├── LICENSE ├── README.md └── exercises ├── week00_PyTorch_Intro ├── Intro+_entry_level_exercise_1_Learn_Pytorch_with_cifar10_tutorial.ipynb ├── PyTorch_Introduction.ipynb ├── Tutorial_+_exersice_1.ipynb └── [solution]_Learn_Pytorch_with_cifar10_tutorial.ipynb ├── week01_AE ├── A01_autoencoder.ipynb └── viz_reconstructions_solution.png ├── week02_RotPred ├── A02_rotpred.ipynb ├── tsne_plot_embeddings_solution.png ├── utils.py └── viz_prediction_solution.png ├── week03_BERT ├── A03_Bert.ipynb ├── curves.png └── utils.py ├── week04_simclr ├── A04_SimCLR_Resnet18.ipynb ├── figs │ ├── fine_tuning_results_stl10.png │ ├── simclr-illustration-loss.png │ └── t-sne-simclr_feats__.png └── utils.py ├── week05_scan ├── A05-image-clustering.ipynb ├── figs │ ├── knn_viz.png │ └── tsne_plot_embeddings_solution__.png └── utils.py ├── week06_distillation ├── A06_Distillation_CIFAR100.ipynb ├── figs │ ├── ViT-tiny_CIFAR100.png │ └── distilled_ViT-tiny_.png └── utils.py ├── week07-08_DINO └── A07_DINO.ipynb ├── week09_10_CLIP ├── A09_CLIP_OOD.ipynb └── utils.py └── week12-13_Proteins ├── A12_proteins.ipynb └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MMBS 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Representation Learning MSc course Summer Semester 2023 2 | 3 | ## Summary 4 | This course is tailored for MSc students of the AI and Data Science Master of the Heinrich Heine University of Dusseldorf. 5 | 6 | We provide all the course materials, including lectures, slides, and exercise classes. 7 | 8 | [YouTube Playlist of videos](https://www.youtube.com/playlist?list=PL3mKiGE4zNJJ83K4c3IBka6eYfe6v71dS) 9 | 10 | 11 | ## Week 1 - Introduction to Representation Learning 12 | - [Lecture](https://www.youtube.com/watch?v=i1-OtPa9doY&list=PL3mKiGE4zNJJ83K4c3IBka6eYfe6v71dS&index=2) || [Slides](https://uni-duesseldorf.sciebo.de/s/2h3pY73kHHIWtUW) 13 | - Introduction to autoencoders for representation learning, early/old traditional approaches. Based on [Bengio et al. 2012 paper](https://arxiv.org/pdf/1206.5538.pdf) 14 | 15 | #### Exercise 16 | Image autoencoders. Learning to use and evaluate the intermediate learned representation. 
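A minimal sketch of what such an image autoencoder might look like (the architecture, 32×32 RGB input, and latent size here are illustrative assumptions, not the assignment's reference solution):

```python
# A minimal sketch, assuming a 32x32 RGB input and a small MLP bottleneck;
# the actual exercise may use different layer sizes and datasets.
import torch
import torch.nn as nn

class TinyAutoencoder(nn.Module):
    def __init__(self, latent_dim=64):
        super().__init__()
        # Encoder: image -> low-dimensional representation
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(3 * 32 * 32, 256), nn.ReLU(),
            nn.Linear(256, latent_dim),
        )
        # Decoder: representation -> reconstructed image
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 3 * 32 * 32),
            nn.Unflatten(1, (3, 32, 32)),
        )

    def forward(self, x):
        z = self.encoder(x)              # the intermediate representation to evaluate
        return self.decoder(z), z

# Training minimizes a reconstruction loss, e.g. nn.MSELoss(), between input and output;
# the learned z can then be probed with a linear classifier or visualized with t-SNE.
```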
17 | 18 | 19 | ## Week 2 - Overview of visual self-supervised learning methods 20 | - [Lecture](https://youtu.be/3Zvo1BihTRE) || [Slides](https://uni-duesseldorf.sciebo.de/s/J5f839uJRKQhW8y) 21 | - Self-supervised learning VS Transfer Learning. Pretext VS Downstream Task 22 | - Pretext tasks covered: Colorization, Jigsaw puzzles, Image inpainting, Shuffle and Learn (Videos), - Classify corrupted images, Rotation Prediction 23 | - Semi-supervised learning: Consistency loss 24 | - A small intro to Contrastive loss (infoNCE) 25 | 26 | #### Exercise 27 | In this exercise, we will train a ResNet18 on the task of rotation prediction. Rotation prediction provides a simple, yet effective way to learn rich representations from unlabeled image data. The basic idea behind rotation prediction is that the network is trained to predict the orientation of a given image after it has been rotated by a certain angle (e.g., 0°, 90°, 180°, or 270°). 28 | 29 | 30 | ## Week 3 - BERT:Learning Natural Language Representations 31 | - [Lecture](https://youtu.be/qCZiR5I47Bo) || [Slides](https://uni-duesseldorf.sciebo.de/s/a3XfMv2HgcxMR8a) 32 | - Natural Language Processing (NLP) basics 33 | - RNN, self-attention, and Transformer recap 34 | - Language pretext tasks 35 | - Pretext tasks for representation learning in NLP. An in-depth look into [BERT](https://arxiv.org/abs/1810.04805). 36 | 37 | 38 | #### Exercise 39 | In this exercise, you will train a small [BERT](https://arxiv.org/abs/1810.04805) model on the IMDB dataset (https://huggingface.co/datasets/imdb). You will then use the model to classify the sentiment of movie reviews and the sentiment of sentences from the Stanford Sentiment Treebank (SST2, https://huggingface.co/datasets/sst2). 40 | 41 | ## Week 4 - Contrastive Learning, SimCLR and mutual information-based proof 42 | - [Lecture](https://youtu.be/RlCqUawKcwA) || [Slides](https://uni-duesseldorf.sciebo.de/s/g5bD2N5QMpOM3lf) || [Notes](https://uni-duesseldorf.sciebo.de/s/zwbaz4mENfnuy28) 43 | - A deep look into contrastive learning, theory, and proof of MI bound. 44 | - [SimCLR Paper](https://arxiv.org/abs/2002.05709) 45 | 46 | #### Exercise 47 | Build and train SimCLR resnet18 on CIFAR10. 48 | 49 | 50 | ## Week 5 - Understanding Contrastive learning & MoCO and image clustering 51 | - [Lecture](https://youtu.be/PE1MT_S9m1k) || [Slides](https://uni-duesseldorf.sciebo.de/s/jZtCrfKRIRA2UmI) || [MoCO implementation](https://uni-duesseldorf.sciebo.de/s/NTnqx68EE630X4a) 52 | - Contrastive Learning, L2 normalization, Properties of contrastive loss 53 | - Momentum encoder (MoCO). Issues and concerns regarding batch normalization 54 | - Multi-view contrastive learning 55 | - Deep Image Clustering: task definition and challenges, K-means and [SCAN](https://arxiv.org/abs/2005.12320), [PMI and TEMI](https://arxiv.org/abs/2303.17896) 56 | 57 | #### Exercise 58 | Use pretrained MoCO ResNet50 for image clustering. 59 | 60 | 61 | ## Week 6 - Vision Transformers and Knowledge Distillation 62 | - [Lecture](https://youtu.be/J_q-PEYikEo) || [Slides](https://uni-duesseldorf.sciebo.de/s/Jbx5bw87vlZrueB) 63 | - Transformer encoder and Vision transformer 64 | - ViTs VS CNNs: receptive field and inductive biases 65 | - Knowledge distillation and the mysteries of model ensembles 66 | - Knowledge distillation in ViTs and masked image modeling 67 | 68 | #### Exercise 69 | Knowledge distillation on CIFAR100 with Vision Transformers. 
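For quick reference, a minimal sketch of a soft-target distillation loss of the kind used in such setups (a temperature-scaled KL term blended with the usual cross-entropy; the temperature and weighting below are illustrative assumptions, not the assignment's values):

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Blend hard-label cross-entropy with a KL divergence between
    temperature-softened teacher and student distributions."""
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    log_student = F.log_softmax(student_logits / T, dim=-1)
    # T*T rescales gradients so the soft term stays comparable to the hard term
    kd = F.kl_div(log_student, soft_targets, reduction="batchmean") * (T * T)
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1 - alpha) * ce
```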
70 | 71 | ## Week 7 - Self-supervised learning without negative samples (BYOL, DINO) 72 | - [Lecture](https://youtu.be/-VqXScgDZnM) || [Slides](https://uni-duesseldorf.sciebo.de/s/iU8owOBDx7PZdMs) 73 | - A small review of self-supervised methods 74 | - A small review of knowledge distillation 75 | - Self-Supervised Learning & knowledge distillation 76 | - An in-depth look into DINO 77 | 78 | #### Exercise (2-week assignment) 79 | In this exercise you will implement and train a DINO model on a medical dataset, the PathMNIST dataset from [medmnist](https://medmnist.com/) consisting of low-resolution images of various colon pathologies. 80 | 81 | 82 | ## Week 8 - Masked-based visual representation learning: MAE, BEiT, iBOT, DINOv2 83 | - [Lecture](https://youtu.be/8KP2SCm1YVo) || [Slides](https://uni-duesseldorf.sciebo.de/s/ifWDLlUdGRYBvD9) 84 | - MAE: https://arxiv.org/abs/2111.06377 85 | - BEiT: BERT-style pre-training in vision: https://arxiv.org/abs/2106.08254 86 | - iBOT: Combining MIM with DINO https://arxiv.org/abs/2111.07832 87 | - DINOv2: https://arxiv.org/abs/2304.07193 88 | 89 | 90 | ## Week 9 - Multimodal representation learning, robustness, and visual anomaly detection 91 | - [Lecture](https://youtu.be/eAf9UjPXmVg) || [Slides](https://uni-duesseldorf.sciebo.de/s/N8mAjFoMLaDJbHZ) 92 | - Defining Robustness and Types of Robustness 93 | - Zero-shot learning 94 | - Contrastive Language Image Pretraining (CLIP) 95 | - Image captioning 96 | - Few-shot learning 97 | - Visual anomaly detection: task definition 98 | - Anomaly detection scores 99 | - Anomaly detection metrics: AUROC 100 | 101 | 102 | #### Exercise (2-week assignment) 103 | Use a CLIP-pre-trained model for out-of-distribution detection. 104 | 105 | 106 | ## Week 10 - Emerging properties of the learned representations and scaling laws 107 | - [Lecture](https://youtu.be/CiXyNHVTxLs) || [Slides](https://uni-duesseldorf.sciebo.de/s/FMc1pDa9js5OTcR) 108 | - Investigating CLIP models and scaling laws 109 | - Determine factor of success of CLIP? 110 | - How does CLIP scale to larger datasets and models? 111 | - OpenCLIP: Scaling laws of CLIP models and connection to NLP scaling laws 112 | - Robustness of CLIP models against image manipulations 113 | - Learned representations of supervised models:CNNs VS Vision Transformers (ViTs), the texture-shape bias 114 | - Robustness and generalization of supervised-pretrained CNNs VS ViTs 115 | - Scaling (Supervised) Vision Transformers 116 | - [Properties of ViT pretrained models](https://theaisummer.com/vit-properties/) 117 | 118 | 119 | ## Week 11 - Investigating the self-supervised learned representations 120 | - [Lecture](https://youtu.be/uUR0yEZ55Vg) || [Slides](https://uni-duesseldorf.sciebo.de/s/NALUEG5AlUzhbI3) 121 | - Limitations of existing vision language models 122 | - Self-supervised VS supervised learned feature representations 123 | - What do vision transformers (ViTs) learn “on their own”? 124 | - MoCOv3 and DINO: https://arxiv.org/abs/2104.14294 125 | - Self-supervised learning in medical imaging 126 | - Investigating the pre-training self-supervised objectives 127 | 128 | #### Exercise 129 | No exercise takes place this week. 130 | 131 | ## Week 12 - Representation Learning in Proteins 132 | - [Lecture](https://youtu.be/ZFazdK7dA7Q) || [Slides](https://uni-duesseldorf.sciebo.de/s/hFCXvnJCpAiPzlR) 133 | - A closer look at the attention mechanism. 
The attention mechanism in Natural Language Translation 134 | - A tiny intro to proteins 135 | - Representing protein sequences with Transformers: BERT masked language modeling VS GPT? 136 | - [ESM](https://www.pnas.org/doi/full/10.1073/pnas.2016239118), [ESMv2])(https://pubmed.ncbi.nlm.nih.gov/36927031/) 137 | - [Looking & combining at the attention maps of a pre-trained Transformer](https://www.biorxiv.org/content/10.1101/2020.12.15.422761v1) 138 | - [Protein Language models generalize beyond natural proteins](https://www.biorxiv.org/content/10.1101/2022.12.21.521521v1) 139 | 140 | #### Exercise 141 | Use a pretrained Protein Language Model 142 | 143 | 144 | ## Week 13 AlphaFold2 145 | - [Lecture]() || [Slides]() 146 | 147 | #### Exercise 148 | [Just play around with an Alphafold notebook](https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb) 149 | 150 | 151 | # Additional info 152 | Feel free to open issues regarding errors on the exercises or missing information and we will try to get back to you. 153 | > Important: Solutions to the exercises are not provided, but you can cross-check your results with the Expected results in the notebook. 154 | -------------------------------------------------------------------------------- /exercises/week01_AE/viz_reconstructions_solution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHU-MMBS/RepresentationLearning_SS2023/9238c4c7648f8d64dd0461e828652cd3606f36a1/exercises/week01_AE/viz_reconstructions_solution.png -------------------------------------------------------------------------------- /exercises/week02_RotPred/tsne_plot_embeddings_solution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHU-MMBS/RepresentationLearning_SS2023/9238c4c7648f8d64dd0461e828652cd3606f36a1/exercises/week02_RotPred/tsne_plot_embeddings_solution.png -------------------------------------------------------------------------------- /exercises/week02_RotPred/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader, Dataset 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | import torch.optim as optim 11 | import torch.utils.data as data 12 | import random 13 | import matplotlib.pyplot as plt 14 | from torchvision import transforms as T 15 | from tqdm import tqdm 16 | 17 | def unnormalize_image(image, mean=0, std=1.0, permute=True): 18 | """ 19 | Transform the given image by un-normalizing it with the given mean and standard deviation. 20 | :param image: image to unnormalize as numpy array. 21 | :param mean: mean to use for unnormalization. 22 | :param std: standard deviation to use for unnormalization. 23 | :param permute: flag to indicate if image should be flipped and its dimensions rearranged. 24 | :return: unnormalized image as numpy array. 
25 | """ 26 | if type(mean) != torch.Tensor: 27 | mean = torch.tensor(mean, dtype=torch.float32) 28 | 29 | if type(std) != torch.Tensor: 30 | std = torch.tensor(std, dtype=torch.float32) 31 | 32 | unnormalize = T.Normalize( 33 | (-mean / std).tolist(), 34 | (1.0 / std).tolist() 35 | ) 36 | 37 | unnormalized_image = unnormalize(image).numpy() 38 | if permute: 39 | return np.transpose(unnormalized_image, (1, 2, 0)) 40 | return unnormalized_image 41 | 42 | def imshow(img, mean=torch.tensor([0.0], dtype=torch.float32), std=torch.tensor([1], dtype=torch.float32)): 43 | """ 44 | shows an image on the screen. mean of 0 and variance of 1 will show the images unchanged in the screen 45 | """ 46 | plt.imshow(unnormalize_image(img, mean=mean, std=std, permute=True)) 47 | 48 | 49 | def prevalidate(model, val_loader,criterion, device): 50 | model.eval() 51 | correct, total = 0, 0 52 | loss_step = [] 53 | with torch.no_grad(): 54 | for data in val_loader: 55 | inp_data,labels = data 56 | inp_data, labels = inp_data.to(device), labels.to(device) 57 | outputs = model(inp_data) 58 | val_loss = criterion(outputs, labels) 59 | loss_step.append(val_loss.item()) 60 | # dont forget to take the means here 61 | val_loss_epoch = np.mean(loss_step) 62 | return val_loss_epoch 63 | 64 | def pretrain_one_epoch(model, optimizer, train_loader, criterion, device): 65 | model.train() 66 | loss_step = [] 67 | for data in train_loader: 68 | # Move the data to the GPU 69 | inp_data, labels = data 70 | inp_data, labels = inp_data.to(device), labels.to(device) 71 | outputs = model(inp_data) 72 | loss = criterion(outputs, labels) 73 | optimizer.zero_grad() 74 | loss.backward() 75 | optimizer.step() 76 | loss_step.append(loss.item()) 77 | # dont forget the means here 78 | loss_curr_epoch = np.mean(loss_step) 79 | return loss_curr_epoch 80 | 81 | def save_model(model, path, epoch, optimizer, val_loss): 82 | torch.save({ 83 | 'epoch': epoch, 84 | 'model_state_dict': model.state_dict(), 85 | 'optimizer_state_dict': optimizer.state_dict(), 86 | 'loss': val_loss, 87 | }, path) 88 | 89 | def pretrain(model, optimizer, num_epochs, train_loader, val_loader, criterion, device): 90 | dict_log = {"train_loss":[], "val_loss":[]} 91 | best_val_loss = 1e8 92 | model = model.to(device) 93 | pbar = tqdm(range(num_epochs)) 94 | for epoch in pbar: 95 | loss_curr_epoch = pretrain_one_epoch(model, optimizer, train_loader, criterion, device) 96 | val_loss = prevalidate(model, val_loader, criterion, device) 97 | 98 | # Print epoch results to screen 99 | msg = (f'Ep {epoch+1}/{num_epochs}: || Loss: Train {loss_curr_epoch:.3f} \t Val {val_loss:.3f}') 100 | pbar.set_description(msg) 101 | 102 | dict_log["train_loss"].append(loss_curr_epoch) 103 | dict_log["val_loss"].append(val_loss) 104 | 105 | # Use this code to save the model with the best validation loss 106 | if val_loss < best_val_loss: 107 | best_val_loss = val_loss 108 | save_model(model, f'best_model_min_val_loss.pth', epoch, optimizer, val_loss) 109 | return dict_log 110 | 111 | 112 | 113 | def validate(model, val_loader, device): 114 | model.eval() 115 | criterion = nn.CrossEntropyLoss() 116 | correct, total = 0, 0 117 | loss_step = [] 118 | with torch.no_grad(): 119 | for data in val_loader: 120 | inp_data,labels = data 121 | inp_data = inp_data.to(device) 122 | labels = labels.to(device) 123 | outputs = model(inp_data) 124 | val_loss = criterion(outputs, labels) 125 | predicted = torch.max(outputs, 1)[1] 126 | total += labels.size(0) 127 | correct += (predicted == labels).sum() 128 | 
loss_step.append(val_loss.item()) 129 | # dont forget to take the means here 130 | val_acc = (100 * correct / total).cpu().numpy() 131 | val_loss_epoch = torch.tensor(loss_step).mean().numpy() 132 | return val_acc , val_loss_epoch 133 | 134 | # Provided 135 | def get_features(model, dataloader, device): 136 | model = model.to(device) 137 | feats, labs = [], [] 138 | for i in dataloader: 139 | inp_data,labels = i 140 | inp_data = inp_data.to(device) 141 | features = model(inp_data) 142 | features = features.cpu().detach().flatten(start_dim=1) 143 | labels = labels.cpu().detach() 144 | feats.append(features) 145 | labs.append(labels) 146 | f = torch.cat(feats, dim=0) 147 | l = torch.cat(labs, dim=0) 148 | return f,l 149 | 150 | 151 | def train_one_epoch(model, optimizer, train_loader, device): 152 | model.train() 153 | criterion = nn.CrossEntropyLoss() 154 | loss_step = [] 155 | correct, total = 0, 0 156 | for data in train_loader: 157 | # Move the data to the GPU 158 | inp_data,labels = data 159 | inp_data = inp_data.to(device) 160 | labels = labels.to(device) 161 | outputs = model(inp_data) 162 | loss = criterion(outputs, labels) 163 | optimizer.zero_grad() 164 | loss.backward() 165 | optimizer.step() 166 | with torch.no_grad(): 167 | _, predicted = torch.max(outputs, 1) 168 | total += labels.size(0) 169 | correct += (predicted == labels).sum() 170 | loss_step.append(loss.item()) 171 | # dont forget the means here 172 | loss_curr_epoch = np.mean(loss_step) 173 | train_acc = (100 * correct / total).cpu() 174 | return loss_curr_epoch, train_acc 175 | 176 | 177 | def linear_eval(model, optimizer, num_epochs, train_loader, val_loader, device): 178 | best_val_loss = 1e8 179 | best_val_acc = 0 180 | model = model.to(device) 181 | dict_log = {"train_acc_epoch":[], "val_acc_epoch":[], "loss_epoch":[], "val_loss":[]} 182 | train_acc, _ = validate(model, train_loader, device) 183 | val_acc, _ = validate(model, val_loader, device) 184 | print(f'Init Accuracy of the model: Train:{train_acc:.3f} \t Val:{val_acc:3f}') 185 | pbar = tqdm(range(num_epochs)) 186 | for epoch in pbar: 187 | loss_curr_epoch, train_acc = train_one_epoch(model, optimizer, train_loader, device) 188 | val_acc, val_loss = validate(model, val_loader, device) 189 | 190 | # Print epoch results to screen 191 | msg = (f'Ep {epoch+1}/{num_epochs}: Accuracy : Train:{train_acc:.2f} \t Val:{val_acc:.2f} || Loss: Train {loss_curr_epoch:.3f} \t Val {val_loss:.3f}') 192 | pbar.set_description(msg) 193 | # Track stats 194 | dict_log["train_acc_epoch"].append(train_acc) 195 | dict_log["val_acc_epoch"].append(val_acc) 196 | dict_log["loss_epoch"].append(loss_curr_epoch) 197 | dict_log["val_loss"].append(val_loss) 198 | 199 | if val_loss < best_val_loss: 200 | best_val_loss = val_loss 201 | torch.save({ 202 | 'epoch': epoch, 203 | 'model_state_dict': model.state_dict(), 204 | 'optimizer_state_dict': optimizer.state_dict(), 205 | 'loss': val_loss, 206 | }, f'best_model_min_val_loss.pth') 207 | 208 | if val_acc > best_val_acc: 209 | best_val_acc = val_acc 210 | torch.save({ 211 | 'epoch': epoch, 212 | 'model_state_dict': model.state_dict(), 213 | 'optimizer_state_dict': optimizer.state_dict(), 214 | 'loss': val_loss, 215 | }, f'best_model_max_val_acc.pth') 216 | return dict_log 217 | 218 | 219 | def load_model(model, path): 220 | checkpoint = torch.load(path) 221 | model.load_state_dict(checkpoint['model_state_dict']) 222 | print(f"Model {path} is loaded from epoch {checkpoint['epoch']} , loss {checkpoint['loss']}") 223 | return model 224 | 
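# Example usage (illustrative sketch only, not part of the assignment solution):
# the helpers above are typically chained as pretrain -> load_model -> linear_eval / get_features.
# The model, dataloaders and hyperparameters below are placeholder assumptions.
if __name__ == "__main__":
    import torchvision.models as models

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = models.resnet18(num_classes=4)  # e.g. 4 rotation classes (0/90/180/270)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    # train_loader / val_loader must be built in the notebook (rotated images + rotation labels):
    # dict_log = pretrain(model, optimizer, 10, train_loader, val_loader, criterion, device)
    # model = load_model(model, "best_model_min_val_loss.pth")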
-------------------------------------------------------------------------------- /exercises/week02_RotPred/viz_prediction_solution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHU-MMBS/RepresentationLearning_SS2023/9238c4c7648f8d64dd0461e828652cd3606f36a1/exercises/week02_RotPred/viz_prediction_solution.png -------------------------------------------------------------------------------- /exercises/week03_BERT/curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHU-MMBS/RepresentationLearning_SS2023/9238c4c7648f8d64dd0461e828652cd3606f36a1/exercises/week03_BERT/curves.png -------------------------------------------------------------------------------- /exercises/week03_BERT/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from pathlib import Path 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torch.nn import functional as F 8 | from datasets import load_dataset 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer 11 | 12 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 13 | 14 | TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased") 15 | MAX_SEQ_LEN = 256 16 | 17 | 18 | class MovingAverage: 19 | def __init__(self, beta=0.99): 20 | self.beta = beta 21 | self.avg = None 22 | 23 | def update(self, value): 24 | if isinstance(value, torch.Tensor): 25 | value = value.item() 26 | if self.avg is None: 27 | self.avg = value 28 | else: 29 | self.avg = self.beta * self.avg + (1 - self.beta) * value 30 | 31 | def get(self): 32 | return self.avg 33 | 34 | 35 | class TextProcessor: 36 | 37 | def __init__(self): 38 | import nltk 39 | self.sent_tokenize = nltk.sent_tokenize 40 | nltk.download('punkt') 41 | 42 | self.CLEAN_HTML = re.compile('<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});') 43 | self.CLEAN_PUNKT = re.compile('[' + re.escape('!"#$%&()*+,.:;<=>?@[\\]^_`{|}~') + ']') 44 | self.CLEAN_WHITE = re.compile(r'\s+') 45 | 46 | def clean_text(self, text): 47 | text = re.sub(self.CLEAN_HTML, ' ', text) 48 | text = re.sub(self.CLEAN_PUNKT, ' ', text.lower()) 49 | text = re.sub(self.CLEAN_WHITE, ' ', text) 50 | return text.strip() 51 | 52 | def __call__(self, text): 53 | return [self.clean_text(sent) for sent in self.sent_tokenize(text)] 54 | 55 | 56 | class IMDBDataset(torch.utils.data.Dataset): 57 | 58 | def __init__(self, train=True): 59 | super().__init__() 60 | split = "unsupervised" if train else "test" 61 | raw_dataset = load_dataset("imdb", split=split) 62 | self.tokens = tokenize_dataset(raw_dataset, f"imdb_{split}", "text") 63 | 64 | def __len__(self): 65 | return len(self.tokens) 66 | 67 | def __getitem__(self, index): 68 | return self.tokens[index] 69 | 70 | 71 | def load_imdb_dataset(train=True): 72 | return IMDBDataset(train) 73 | 74 | 75 | def add_special_tokens(tokens_a, token_b=None): 76 | tokens = torch.cat([ 77 | torch.tensor([TOKENIZER.cls_token_id]), 78 | tokens_a, 79 | torch.tensor([TOKENIZER.sep_token_id]) 80 | ]) 81 | normal_mask = torch.tensor([False] + [True] * len(tokens_a) + [False]) 82 | segment_id = torch.zeros_like(tokens) 83 | if token_b is not None: 84 | tokens = torch.cat([ 85 | tokens, 86 | token_b, 87 | torch.tensor([TOKENIZER.sep_token_id]) 88 | ]) 89 | normal_mask = torch.cat([ 90 | normal_mask, 91 | torch.tensor([True] * len(token_b) + [False]) 92 | ]) 93 | segment_id = torch.cat([ 94 | 
segment_id, 95 | torch.ones(len(token_b) + 1, dtype=torch.int16) 96 | ]) 97 | return dict( 98 | input_ids=tokens.long(), 99 | normal_mask=normal_mask, 100 | segment_ids=segment_id.long()) 101 | 102 | 103 | class SST2Dataset(torch.utils.data.Dataset): 104 | 105 | def __init__(self, train=True): 106 | super().__init__() 107 | split = "train" if train else "validation" 108 | self.raw_dataset = load_dataset("glue", "sst2", split=split) 109 | self.tokens = tokenize_dataset(self.raw_dataset, f"sst2_{split}", "sentence") 110 | 111 | def __len__(self): 112 | return len(self.raw_dataset) 113 | 114 | def __getitem__(self, index): 115 | label = self.raw_dataset[index]['label'] 116 | tokens = self.tokens[index][0] 117 | out = add_special_tokens(tokens) 118 | out.update(labels=label) 119 | return out 120 | 121 | 122 | def load_sst2_dataset(train=True): 123 | return SST2Dataset(train) 124 | 125 | 126 | def tokenize_dataset(dataset, name, key): 127 | path = Path('.cache') 128 | path.mkdir(exist_ok=True) 129 | path = path / f"{name}_tokens.pt" 130 | if path.exists(): 131 | return torch.load(path) 132 | print(f"Tokenizing {name} dataset...") 133 | text_processor = TextProcessor() 134 | out = [] 135 | for i in tqdm(range(len(dataset))): 136 | text = dataset[i][key] 137 | # If the item is a tuple, it is a labeled dataset 138 | if isinstance(text, tuple): 139 | text = text 140 | text = text_processor(text) 141 | ids = TOKENIZER(text, 142 | max_length=MAX_SEQ_LEN, 143 | add_special_tokens=False, 144 | truncation=True)['input_ids'] 145 | ids = [torch.tensor(ids, dtype=torch.int16, device='cpu') for ids in ids] 146 | out.append(ids) 147 | torch.save(out, path) 148 | return out 149 | 150 | ############################################################################################################ 151 | # This is copied from the PyTorch source code and only modified to allow padding 152 | 153 | import collections 154 | import contextlib 155 | import re 156 | import torch 157 | 158 | from typing import Callable, Dict, Optional, Tuple, Type, Union 159 | 160 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 161 | 162 | 163 | def default_convert(data): 164 | r""" 165 | Function that converts each NumPy array element into a :class:`torch.Tensor`. If the input is a `Sequence`, 166 | `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`. 167 | If the input is not an NumPy array, it is left unchanged. 168 | This is used as the default function for collation when both `batch_sampler` and 169 | `batch_size` are NOT defined in :class:`~torch.utils.data.DataLoader`. 170 | The general input type to output type mapping is similar to that 171 | of :func:`~torch.utils.data.default_collate`. See the description there for more details. 
172 | Args: 173 | data: a single data point to be converted 174 | Examples: 175 | >>> # xdoctest: +SKIP 176 | >>> # Example with `int` 177 | >>> default_convert(0) 178 | 0 179 | >>> # Example with NumPy array 180 | >>> default_convert(np.array([0, 1])) 181 | tensor([0, 1]) 182 | >>> # Example with NamedTuple 183 | >>> Point = namedtuple('Point', ['x', 'y']) 184 | >>> default_convert(Point(0, 0)) 185 | Point(x=0, y=0) 186 | >>> default_convert(Point(np.array(0), np.array(0))) 187 | Point(x=tensor(0), y=tensor(0)) 188 | >>> # Example with List 189 | >>> default_convert([np.array([0, 1]), np.array([2, 3])]) 190 | [tensor([0, 1]), tensor([2, 3])] 191 | """ 192 | elem_type = type(data) 193 | if isinstance(data, torch.Tensor): 194 | return data 195 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 196 | and elem_type.__name__ != 'string_': 197 | # array of string classes and object 198 | if elem_type.__name__ == 'ndarray' \ 199 | and np_str_obj_array_pattern.search(data.dtype.str) is not None: 200 | return data 201 | return torch.as_tensor(data) 202 | elif isinstance(data, collections.abc.Mapping): 203 | try: 204 | return elem_type({key: default_convert(data[key]) for key in data}) 205 | except TypeError: 206 | # The mapping type may not support `__init__(iterable)`. 207 | return {key: default_convert(data[key]) for key in data} 208 | elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple 209 | return elem_type(*(default_convert(d) for d in data)) 210 | elif isinstance(data, tuple): 211 | return [default_convert(d) for d in data] # Backwards compatibility. 212 | elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str): 213 | try: 214 | return elem_type([default_convert(d) for d in data]) 215 | except TypeError: 216 | # The sequence type may not support `__init__(iterable)` (e.g., `range`). 217 | return [default_convert(d) for d in data] 218 | else: 219 | return data 220 | 221 | 222 | default_collate_err_msg_format = ( 223 | "default_collate: batch must contain tensors, numpy arrays, numbers, " 224 | "dicts or lists; found {}") 225 | 226 | 227 | def collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): 228 | r""" 229 | General collate function that handles collection type of element within each batch 230 | and opens function registry to deal with specific element types. `default_collate_fn_map` 231 | provides default collate functions for tensors, numpy arrays, numbers and strings. 232 | Args: 233 | batch: a single batch to be collated 234 | collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function. 235 | If the element type isn't present in this dictionary, 236 | this function will go through each key of the dictionary in the insertion order to 237 | invoke the corresponding collate function if the element type is a subclass of the key. 238 | Examples: 239 | >>> # Extend this function to handle batch of tensors 240 | >>> def collate_tensor_fn(batch, *, collate_fn_map): 241 | ... return torch.stack(batch, 0) 242 | >>> def custom_collate(batch): 243 | ... collate_map = {torch.Tensor: collate_tensor_fn} 244 | ... 
return collate(batch, collate_fn_map=collate_map) 245 | >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map` 246 | >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn}) 247 | Note: 248 | Each collate function requires a positional argument for batch and a keyword argument 249 | for the dictionary of collate functions as `collate_fn_map`. 250 | """ 251 | elem = batch[0] 252 | elem_type = type(elem) 253 | 254 | if collate_fn_map is not None: 255 | if elem_type in collate_fn_map: 256 | return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map) 257 | 258 | for collate_type in collate_fn_map: 259 | if isinstance(elem, collate_type): 260 | return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map) 261 | 262 | if isinstance(elem, collections.abc.Mapping): 263 | try: 264 | return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}) 265 | except TypeError: 266 | # The mapping type may not support `__init__(iterable)`. 267 | return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem} 268 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 269 | return elem_type(*(collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch))) 270 | elif isinstance(elem, collections.abc.Sequence): 271 | # check to make sure that the elements in batch have consistent size 272 | it = iter(batch) 273 | elem_size = len(next(it)) 274 | if not all(len(elem) == elem_size for elem in it): 275 | raise RuntimeError('each element in list of batch should be of equal size') 276 | transposed = list(zip(*batch)) # It may be accessed twice, so we use a list. 277 | 278 | if isinstance(elem, tuple): 279 | return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed] # Backwards compatibility. 280 | else: 281 | try: 282 | return elem_type([collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]) 283 | except TypeError: 284 | # The sequence type may not support `__init__(iterable)` (e.g., `range`). 
285 | return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed] 286 | 287 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 288 | 289 | 290 | # This function is new 291 | def padded_stack(tensors, pad_length, dim=0, *, out=None): 292 | padded_tensors = [] 293 | for tensor in tensors: 294 | padding = torch.zeros(pad_length - tensor.size(0), *tensor.shape[1:], dtype=tensor.dtype, device=tensor.device) 295 | padded_tensor = torch.cat([tensor, padding], dim=0) 296 | padded_tensors.append(padded_tensor) 297 | return torch.stack(padded_tensors, dim=dim, out=out) 298 | 299 | 300 | def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): 301 | elem = batch[0] 302 | max_length = max(t.size(0) for t in batch) 303 | out = None 304 | if torch.utils.data.get_worker_info() is not None: 305 | # If we're in a background process, concatenate directly into a 306 | # shared memory tensor to avoid an extra copy 307 | numel = elem[0].numel() * max_length * len(batch) 308 | storage = elem.storage()._new_shared(numel, device=elem.device) 309 | shape = [len(batch), max_length] + list(elem.shape[1:]) 310 | out = elem.new(storage).resize_(shape) 311 | return padded_stack(batch, pad_length=max_length, dim=0, out=out) 312 | 313 | 314 | def collate_numpy_array_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): 315 | elem = batch[0] 316 | # array of string classes and object 317 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 318 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 319 | 320 | return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map) 321 | 322 | 323 | def collate_numpy_scalar_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): 324 | return torch.as_tensor(batch) 325 | 326 | 327 | def collate_float_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): 328 | return torch.tensor(batch, dtype=torch.float64) 329 | 330 | 331 | def collate_int_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): 332 | return torch.tensor(batch) 333 | 334 | 335 | def collate_str_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): 336 | return batch 337 | 338 | 339 | default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {torch.Tensor: collate_tensor_fn} 340 | with contextlib.suppress(ImportError): 341 | import numpy as np 342 | # For both ndarray and memmap (subclass of ndarray) 343 | default_collate_fn_map[np.ndarray] = collate_numpy_array_fn 344 | # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html 345 | # Skip string scalars 346 | default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn 347 | default_collate_fn_map[float] = collate_float_fn 348 | default_collate_fn_map[int] = collate_int_fn 349 | default_collate_fn_map[str] = collate_str_fn 350 | 351 | 352 | def padded_collate(batch): 353 | """Like torch.utils.data.dataloader.default_collate, but pads the data to the maximum length. 
354 | """ 355 | return collate(batch, collate_fn_map=default_collate_fn_map) 356 | 357 | class SST2Model(torch.nn.Module): 358 | 359 | def __init__(self, bert_encoder, train_encoder=True): 360 | """ 361 | Args: 362 | bert_encoder: An instance of a BERTEncoder 363 | train_encoder: wheter the encoder should be trained or not. 364 | """ 365 | super().__init__() 366 | 367 | self.bert_encoder = bert_encoder 368 | for param in self.bert_encoder.parameters(): 369 | param.requires_grad = train_encoder 370 | self.classifier = torch.nn.Linear(bert_encoder.d_model, 1, bias=False) 371 | 372 | def forward(self, input_ids): 373 | """ 374 | Predicts the sentiment of a sentence (positive or negative) 375 | Args: 376 | input_ids: tensor of shape (batch_size, seq_len) containing the token ids of the sentences 377 | Returns: 378 | tensor of shape (batch_size) containing the predicted sentiment 379 | """ 380 | h = self.bert_encoder(input_ids) 381 | return self.classifier(h[:, 0]).view(-1) 382 | 383 | def train_sst2(bert_encoder, train_encoder=False, epochs=3, batch_size=256, lr=1e-3, device='cuda'): 384 | sst2_dataset = load_sst2_dataset(train=True) 385 | loader = DataLoader(sst2_dataset, batch_size=batch_size, shuffle=True, collate_fn=padded_collate, num_workers=4) 386 | sst2_model = SST2Model(bert_encoder, train_encoder=train_encoder).to(device).train() 387 | opt = torch.optim.AdamW(sst2_model.classifier.parameters(), lr=lr) 388 | loss_avg = MovingAverage() 389 | acc_avg = MovingAverage() 390 | for ep in range(epochs): 391 | with tqdm(loader, desc=f'Epoch {ep}') as pbar: 392 | for batch in pbar: 393 | opt.zero_grad(set_to_none=True) 394 | input_ids = batch['input_ids'].to(device) 395 | labels = batch['labels'].float().to(device) 396 | logits = sst2_model(input_ids) 397 | loss = F.binary_cross_entropy_with_logits(logits, labels) 398 | loss.backward() 399 | opt.step() 400 | loss_avg.update(loss) 401 | acc_avg.update(((logits > 0) == labels).float().mean()) 402 | pbar.set_postfix( 403 | loss=loss_avg.get(), 404 | acc=acc_avg.get() 405 | ) 406 | return sst2_model 407 | 408 | @torch.no_grad() 409 | def validate_sst2(model, device): 410 | model.eval() 411 | dataset = load_sst2_dataset(train=False) 412 | loader = DataLoader(dataset, batch_size=64, shuffle=False, collate_fn=padded_collate) 413 | accs = [] 414 | for batch in tqdm(loader): 415 | input_ids = batch['input_ids'].to(device) 416 | labels = batch['labels'].to(device) 417 | logits = model(input_ids) 418 | pred = logits > 0 419 | accs.append((pred == labels).float()) 420 | return torch.cat(accs).mean().item() * 100 -------------------------------------------------------------------------------- /exercises/week04_simclr/A04_SimCLR_Resnet18.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": { 7 | "id": "e-uirVvNW-yY" 8 | }, 9 | "source": [ 10 | "HHU Deep Learning, SS2022/23, 05.05.2023, Prof. Dr. Markus Kollmann\n", 11 | "\n", 12 | "Lecturers and Tutoring is done by Tim Kaiser, Nikolas Adaloglou and Felix Michels.\n", 13 | "\n", 14 | "# Assignment 05 - Contrastive self-supervised learning: SimCLR in STL10 with Resnet18 \n", 15 | "\n", 16 | "\n", 17 | "## Contents\n", 18 | "\n", 19 | "1. Preparation and imports\n", 20 | "2. Implement the augmentation pipeline used in SimCLR\n", 21 | "3. Implement the SimCLR Contrastive loss (NT-Xent)\n", 22 | "4. Load and modify resnet18\n", 23 | "5. 
Gradient Accumulation: Implement the `training_step` and `pretrain_one_epoch_grad_acc`\n", 24 | "6. Putting everything together and train the model\n", 25 | "7. Linear probing + T-SNE visualization of features\n", 26 | "8. Compare SimCLR versus supervised Imagenet-pretrained weights and random init on STL10 train/val split\n", 27 | "9. Plot the val accuracies for the 3 different initializations" 28 | ] 29 | }, 30 | { 31 | "attachments": {}, 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "# Introduction \n", 36 | "\n", 37 | "Contrastive loss is a way of training a machine learning model in a self-supervised manner, where the goal is to learn meaningful representations of the input data without any explicit labels or annotations.\n", 38 | "\n", 39 | "The basic idea is to take a pair of input samples (such as two augmented views from the same image), and compare them to see if they are similar or dissimilar. The model is then trained to push similar pairs closer together in the representation space, while pushing dissimilar pairs farther apart.\n", 40 | "\n", 41 | "To do this, the contrastive loss function measures the similarity between the representations of the two input samples (nominator), and encourages the model to maximize this similarity if the samples are similar, and minimize it if they are dissimilar.\n", 42 | "\n", 43 | "\n", 44 | "You can also advice the [SimCLR Paper](https://arxiv.org/abs/2002.05709)" 45 | ] 46 | }, 47 | { 48 | "attachments": {}, 49 | "cell_type": "markdown", 50 | "metadata": { 51 | "id": "lw5K7r5SQDca" 52 | }, 53 | "source": [ 54 | "# Part I. Preparation and imports" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": { 61 | "id": "1SIad0uuHBlv" 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "import os\n", 66 | "import torch\n", 67 | "import torchvision.models as models\n", 68 | "import numpy as np\n", 69 | "\n", 70 | "import torch\n", 71 | "import torchvision\n", 72 | "import torchvision.transforms as T\n", 73 | "import torch.nn as nn\n", 74 | "import torch.nn.functional as F\n", 75 | "from torchvision.datasets import STL10\n", 76 | "from torch.utils.data import DataLoader\n", 77 | "from torch.optim import Adam\n", 78 | "import tqdm\n", 79 | "\n", 80 | "# Local imports\n", 81 | "from utils import *" 82 | ] 83 | }, 84 | { 85 | "attachments": {}, 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "# Part II. Implement the augmentation pipeline used in SimCLR\n", 90 | "\n", 91 | "In contrastive self-supervised learning, there are several image augmentations that are commonly used to create pairs of images that are transformed versions of each other. These augmentations are designed to ensure that the resulting views have enough differences between them so that the model can learn to distinguish between them, while also preserving the label-related information.\n", 92 | "\n", 93 | "Implement the following transformations **presented in random order**:\n", 94 | "\n", 95 | "\n", 96 | "- Random flipping: This involves randomly flipping the image horizontally or vertically. Choose the one that best fits with a probability of 50%.\n", 97 | "- Normalize the images with an appropriate mean std.\n", 98 | "- Color jitter: This involves randomly changing the brightness, contrast, saturation and hue (20%) of the image. This augmentation helps the model learn to recognize objects or scenes under different lighting conditions. Apply this augmentation with a probability of 80%. 
Distort the brightness, contrast, saturation in the range `[0.2, 1.8]`.\n", 99 | "- Random cropping: This involves randomly cropping a portion of the image to create a new image. We will then resize the images to 64x64 instead of 96x96 to reduce the computational time complexity to train the model. Use a scale of 10-100% of the initial image size. \n", 100 | "- Gaussian blur: This augmentation helps the model learn to recognize objects or scenes that are slightly out of focus. Use a `kernel_size` of 3 and Standard deviation of 0.1 to 2.0.\n", 101 | "\n", 102 | "\n", 103 | "The above augmentations are typically applied randomly to each image in a pair, resulting in two slightly different versions of the same image that can be used for contrastive learning.\n", 104 | "\n", 105 | "Your task is to define the augmentation and decide in which order they should be applied. " 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "class Augment:\n", 115 | " \"\"\"\n", 116 | " A stochastic data augmentation module\n", 117 | " Transforms any given data example randomly\n", 118 | " resulting in two correlated views of the same example,\n", 119 | " denoted x ̃i and x ̃j, which we consider as a positive pair.\n", 120 | " \"\"\"\n", 121 | " def __init__(self, img_size):\n", 122 | " ### START CODE HERE ### (≈ 5 lines of code)\n", 123 | " \n", 124 | " def __call__(self, x):\n", 125 | " # This function applied the same transformation to an image twice.\n", 126 | " \n", 127 | " ### END CODE HERE ###\n", 128 | "\n", 129 | "def load_data( batch_size=128, train_split=\"unlabeled\", test_split=\"test\", transf = T.ToTensor()):\n", 130 | " # Returns a train and validation dataloader for STL10 dataset\n", 131 | " ### START CODE HERE ### (≈ 6 lines of code)\n", 132 | "\n", 133 | " ### END CODE HERE ###\n", 134 | " return train_dl, val_dl" 135 | ] 136 | }, 137 | { 138 | "attachments": {}, 139 | "cell_type": "markdown", 140 | "metadata": { 141 | "id": "4v-Qg5Xpk2Bv" 142 | }, 143 | "source": [ 144 | "# Part III. Implement the SimCLR Contrastive loss (NT-Xent)\n", 145 | "\n", 146 | "Let $sim(u,v)$ note the dot product between 2 normalized $u$ and $v$ (i.e. cosine similarity). Then the loss function for a **positive pair**\n", 147 | "of examples (i,j) is defined as:\n", 148 | "$$\n", 149 | "\\ell_{i, j}=-\\log \\frac{\\exp \\left(\\operatorname{sim}\\left(\\boldsymbol{z}_{i}, \\boldsymbol{z}_{j}\\right) / \\tau\\right)}{\\sum_{k=1}^{2 N} \\mathbb{1}_{[k \\neq i]} \\exp \\left(\\operatorname{sim}\\left(\\boldsymbol{z}_{i}, \\boldsymbol{z}_{k}\\right) / \\tau\\right)}\n", 150 | "$$\n", 151 | "\n", 152 | "where $\\mathbb{1}_{[k \\neq i]} $ ∈{0,1} is an indicator function evaluating to 1 iff $k != i$ and τ denotes a temperature parameter. The final loss is computed by summing all positive pairs and divide by $2\\times N = views \\times batch_{size} $\n", 153 | "\n", 154 | "There are different ways to develop contrastive loss. \n", 155 | "\n", 156 | "\n", 157 | "#### Hints\n", 158 | "Here we provide you with some hints about the main algorithm:\n", 159 | "\n", 160 | "- apply l2 normalization to the features and concatenate them in the batch dimension\n", 161 | "\n", 162 | "- Calculate the similarity/logits of all pairs. 
Output shape:[batch_size $\\times$ views,batch_size $\\times$ views]\n", 163 | "\n", 164 | "- Make Identity matrix as mask with size=(batch_size $\\times$ views, batch_size $\\times$ views)\n", 165 | "\n", 166 | "- Repeat the mask in both direction to the number of views (in simclr number of views = 2)\n", 167 | "for batch_size=5 and 2 views: \n", 168 | "```\n", 169 | "[1., 0., 0., 0., 0., 1., 0., 0., 0., 0.]\n", 170 | "[0., 1., 0., 0., 0., 0., 1., 0., 0., 0.],\n", 171 | "[0., 0., 1., 0., 0., 0., 0., 1., 0., 0.],\n", 172 | "[0., 0., 0., 1., 0., 0., 0., 0., 1., 0.],\n", 173 | "[0., 0., 0., 0., 1., 0., 0., 0., 0., 1.],\n", 174 | "[1., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n", 175 | "[0., 1., 0., 0., 0., 0., 1., 0., 0., 0.],\n", 176 | "[0., 0., 1., 0., 0., 0., 0., 1., 0., 0.],\n", 177 | "[0., 0., 0., 1., 0., 0., 0., 0., 1., 0.],\n", 178 | "[0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]\n", 179 | "```\n", 180 | "\n", 181 | "4. Make a mask to index the positive pairs. mask-out the self-contrast as follows.\n", 182 | "make a mask with the shape of the logits = [batch_size $\\times$ views,batch_size $\\times$ views] that has ones in the diagonals that are +- batch_size from the main diagonal. this will be used to index the positive pairs.\n", 183 | "Example for [6,6] matrix (batch_size=3,views=2):\n", 184 | "```\n", 185 | "[0., 0., 0., 1., 0., 0.],\n", 186 | "[0., 0., 0., 0., 1., 0.],\n", 187 | "[0., 0., 0., 0., 0., 1.],\n", 188 | "[1., 0., 0., 0., 0., 0.],\n", 189 | "[0., 1., 0., 0., 0., 0.],\n", 190 | "[0., 0., 1., 0., 0., 0.]\n", 191 | "``` \n", 192 | "Ones here will be the positive elements for the nominator.\n", 193 | "Alternativly you can use torch.diag() to take the positives from the [6,6] similarity matrix (aka logits)\n", 194 | "\n", 195 | "- Use the positives to form the nominator.Scale down result with the temperature. There are batch_size $\\times$ views positive pairs.\n", 196 | "\n", 197 | "- Calculate the denominator by summing the masked logits in the correct dimension.\n", 198 | "\n", 199 | "- dont forget to apply `-log(result)`\n", 200 | "\n", 201 | "- Calculate the final loss as in the above equation.\n", 202 | "\n", 203 | "\n", 204 | "#### A note on L2 normalization\n", 205 | "\n", 206 | "L2 normalization is a common technique used in contrastive learning to normalize the embedding vectors before computing the contrastive loss. \n", 207 | "\n", 208 | "This is because L2 normalization scales the vectors to have unit length. Without L2 normalization, the magnitude of the embedding vectors can have a large influence on the contrastive loss. \n", 209 | "\n", 210 | "This can result in the optimization process focusing more on adjusting the magnitude of the vectors rather than their direction, leading to suboptimal solutions. \n", 211 | "\n", 212 | "By normalizing the embeddings, the contrastive loss only considers the angular difference between embedding vectors.\n", 213 | "\n", 214 | "\n" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": { 221 | "id": "FeP4ZuZpsyOp" 222 | }, 223 | "outputs": [], 224 | "source": [ 225 | "import torch\n", 226 | "import torch.nn as nn\n", 227 | "\n", 228 | "class ContrastiveLoss(nn.Module):\n", 229 | " \"\"\"\n", 230 | " Vanilla Contrastive loss, also called InfoNceLoss as in SimCLR paper\n", 231 | " There are different ways to develop contrastive loss. 
Here we provide you with some hints about the main algorithm:\n", 232 | " 1- create an Identity matrix as a mask (bsz, bsz)\n", 233 | " 2- repeat the mask in both direction to the number of views (in simclr number of views = 2) in the above code we called it anchor_count\n", 234 | " 3- modify the mask to remove the self contrast cases\n", 235 | " 4- calculate the similarity of two features. *Note: final size should be [bsz, bsz]\n", 236 | " 5- apply the mask on similairty matrix \n", 237 | " 6- calculate the final loss \n", 238 | " \"\"\"\n", 239 | " ### START CODE HERE ### (≈ 19 lines of code)\n", 240 | " \n", 241 | " def forward(self, proj_1, proj_2):\n", 242 | " \"\"\"\n", 243 | " proj_1 and proj_2 are batched embeddings [batch, embedding_dim]\n", 244 | " where corresponding indices are pairs\n", 245 | " z_i, z_j in the SimCLR paper\n", 246 | " \"\"\"\n", 247 | "\n", 248 | " return loss # scalar!\n", 249 | " ### END CODE HERE ###\n", 250 | "\n", 251 | "def test_ContrastiveLoss():\n", 252 | " batch_size = 8\n", 253 | " temperature = 0.1\n", 254 | " criterion = ContrastiveLoss(batch_size, temperature)\n", 255 | " proj_1 = torch.rand(batch_size, 128)\n", 256 | " proj_2 = torch.rand(batch_size, 128)\n", 257 | " loss = criterion(proj_1, proj_2)\n", 258 | " assert loss.shape == torch.Size([]), \"ContrastiveLoss output shape is wrong\"\n", 259 | " assert loss.item() >= 0, \"ContrastiveLoss output is negative\"\n", 260 | " print(\"ContrastiveLoss test passed!\")\n", 261 | "\n", 262 | "test_ContrastiveLoss()" 263 | ] 264 | }, 265 | { 266 | "attachments": {}, 267 | "cell_type": "markdown", 268 | "metadata": { 269 | "id": "F8iM6b8CQjSy" 270 | }, 271 | "source": [ 272 | "# Part IV. Load and modify resnet18\n", 273 | "\n", 274 | "- Load and modify the resnet18.\n", 275 | "- Add an MLP with batch normalization after the resnet18 backbone as illustrate below:\n", 276 | "```python\n", 277 | "Sequential(\n", 278 | " (0): Linear(in_features=in_features, out_features=in_features, bias=False)\n", 279 | " (1): BatchNorm(in_features)\n", 280 | " (2): ReLU()\n", 281 | " (3): Linear(in_features=in_features, out_features=embedding_size, bias=False)\n", 282 | " (4): BatchNorm(embedding_size))\n", 283 | "```" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": { 290 | "id": "WpEEBp7EH7-x" 291 | }, 292 | "outputs": [], 293 | "source": [ 294 | "class ResNetSimCLR(nn.Module):\n", 295 | " def __init__(self, embedding_size=128):\n", 296 | " super(ResNetSimCLR, self).__init__()\n", 297 | " ### START CODE HERE ### (≈ 10 lines of code)\n", 298 | " # load resnet18 pretrained on imagenet\n", 299 | " # self.backbone = ...\n", 300 | " # add mlp projection head\n", 301 | " # self.projection = ....\n", 302 | "\n", 303 | " def forward(self, x, return_embedding=False):\n", 304 | "\n", 305 | " ### END CODE HERE ###" 306 | ] 307 | }, 308 | { 309 | "attachments": {}, 310 | "cell_type": "markdown", 311 | "metadata": { 312 | "id": "ppxywhSH_Xjc" 313 | }, 314 | "source": [ 315 | "# Part V. Implement the `training_step` and `pretrain_one_epoch_grad_acc`\n", 316 | "\n", 317 | "### Gradient accumulation and mixed precision\n", 318 | "\n", 319 | "- `training_step` should load a batch of 2 image views and feed them to the model. The loss function will calculate the implemented SimCLR loss.\n", 320 | "- Gradient accumulation saves the gradient values for $N$ steps. It calculates the gradients and proceeds to the next batch. 
Remember that when you call `loss.backward()` the newly computed gradients are added to the old ones. After N steps, the parameter update is done and the loss shall be scaled down (averaged) by the number of N iterations.\n", 321 | "\n", 322 | "Note: SimCLR training requires a large batch size. You should be to train SimCLR with a batch size of at least 256 on Google Colab.\n", 323 | "\n", 324 | "#### Explanation of accumulated gradients\n", 325 | "\n", 326 | "When training large neural networks, the computational cost of computing the gradient for all of the training examples in the dataset can be prohibitive. Gradient accumulation is a technique used to increase the size of the batch of training samples used to update the weights of the network. \n", 327 | "\n", 328 | "Instead of applying the gradients to the model's parameters after each batch, the gradients are accumulated over a batch of training examples. The accumulated gradients are then used to update the model's parameters. In this way, one reduces the noise in the gradients by averaging them over a batch of training examples, which can lead to more stable updates to the model's parameters. It also allows the model to make larger updates to its parameters, which may speed up the training process.\n", 329 | "\n", 330 | "For example, if we set the batch size to 32, the network would process 32 examples at a time, compute the gradients for each example, and then accumulate the gradients over the 32 examples. After accumulating the gradients for the entire batch, the weights of the network are updated using the average of the accumulated gradients. Thus, for a batch size of 32 you can accumulate gradients every N steps so that you have an effective batch size of 32 $\\times$ N!\n", 331 | "\n", 332 | "> Importantly, gradient accumulation slows down training since gradient updates happen every N steps, but it is expected to see the loss dropping steadily and probably faster, depending on the method.\n", 333 | "\n", 334 | "### Mixed Precision\n", 335 | "\n", 336 | "At this point, we are introducing another technique to optimize GPU memory usage to use larger batch sizes, mixed precision. The idea is to perform as many operations as possible in fp16, instead of the standard fp32, during training. This is not as simple as casting everything to fp16 however, because some operations are sensitive to underflow (being rounded to 0), especially the gradient itself. \n", 337 | "\n", 338 | "Luckily, there is a torch package for this, `torch.cuda.amp`. Feel free to check out the docs [here](https://pytorch.org/docs/stable/amp.html#) and some examples [here](https://pytorch.org/docs/stable/notes/amp_examples.html#amp-examples). This package takes care of the intricate things and you can go ahead and train. \n", 339 | "\n", 340 | "We are using two functions from the package here, `autocast` and `GradScaler`. Autocast is taking care of casting the correct tensors to fp16 and leaving the others unchanged. The GradScaler then makes sure that the gradients in the backward pass avoid numerical instabilities. \n", 341 | "\n", 342 | "Feel free to use this technique in future exercises to save some memory and speed up your training. 
" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "metadata": { 349 | "id": "a5ukADmI_d_H" 350 | }, 351 | "outputs": [], 352 | "source": [ 353 | "from torch.cuda.amp import autocast, GradScaler\n", 354 | "\n", 355 | "def training_step(model, loss_function, data):\n", 356 | " ### START CODE HERE ### (≈ 5 lines of code)\n", 357 | " \n", 358 | " ### END CODE HERE ###\n", 359 | " return loss\n", 360 | "\n", 361 | "def pretrain_one_epoch_grad_acc(model, loss_function, train_dataloader, \n", 362 | " optimizer, device, accum_iter=1, amp=False):\n", 363 | " model.train()\n", 364 | " total_loss = 0\n", 365 | " num_batches = len(train_dataloader)\n", 366 | " optimizer.zero_grad()\n", 367 | " scaler = GradScaler() if amp else None\n", 368 | " for batch_idx,data in enumerate(train_dataloader):\n", 369 | " ### START CODE HERE ### ( > 6 lines of code)\n", 370 | " if amp:\n", 371 | " # ....\n", 372 | " else:\n", 373 | " #.......\n", 374 | " \n", 375 | " # weights update\n", 376 | "\n", 377 | " # scale back the loss\n", 378 | " # total_loss = ....\n", 379 | "\n", 380 | " ### END CODE HERE ###\n", 381 | " return total_loss/num_batches\n", 382 | " \n", 383 | "\n", 384 | "\n", 385 | "def pretrain(model, optimizer, num_epochs, train_loader, criterion, device, accum_iter=1, amp=False):\n", 386 | " dict_log = {\"train_loss\":[]}\n", 387 | " best_loss = 1e8\n", 388 | " model = model.to(device)\n", 389 | " pbar = tqdm(range(num_epochs))\n", 390 | " for epoch in pbar:\n", 391 | " train_loss = pretrain_one_epoch_grad_acc(model, criterion, train_loader, optimizer,\n", 392 | " device, accum_iter, amp=amp)\n", 393 | " msg = (f'Ep {epoch}/{num_epochs}: || Loss: Train {train_loss:.3f}')\n", 394 | " pbar.set_description(msg)\n", 395 | " dict_log[\"train_loss\"].append(train_loss)\n", 396 | " \n", 397 | " # Use this code to save the model with the lowest loss\n", 398 | " if train_loss < best_loss:\n", 399 | " best_val_loss = train_loss\n", 400 | " save_model(model, f'best_model_min_train_loss.pth', epoch, optimizer, train_loss) \n", 401 | " if epoch == num_epochs - 1:\n", 402 | " save_model(model, f'last_model_ep{epoch}.pth', epoch, optimizer, train_loss)\n", 403 | " return dict_log" 404 | ] 405 | }, 406 | { 407 | "attachments": {}, 408 | "cell_type": "markdown", 409 | "metadata": { 410 | "id": "LqNuy5R2AThH" 411 | }, 412 | "source": [ 413 | "# Part VI. Putting everything together and train the model\n", 414 | "\n", 415 | "Hint: ~50 epochs should be sufficient to see the learned features.\n", 416 | "\n", 417 | "A small training trick here. We will exclude batch normalization parameters from weight decay in `define_param_groups`\n", 418 | "\n", 419 | "Note on complexity: 10.7 VRAM used and ~156mins needed. 
Effective batch size>1024, images of 64x64, 60 epochs.\n", 420 | "\n", 421 | "In case you face problem with Google colab, download the model every 5 epochs or better mount you google drive and save the model there in case you disconnect.\n", 422 | "\n", 423 | "Here\n", 424 | "```python\n", 425 | "PATH = './best_model.ckpt'\n", 426 | "torch.save(model_simclr.state_dict(), PATH)\n", 427 | "files.download(PATH)\n", 428 | "```" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": null, 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [ 437 | "class Hparams:\n", 438 | " def __init__(self):\n", 439 | " # This is what we used, feel free to change those parameters.\n", 440 | " # You only need to specify the temperature in the config object\n", 441 | " self.seed = 77777 # randomness seed\n", 442 | " self.device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", 443 | " self.img_size = 64 #image shape\n", 444 | " self.load = False # load pretrained checkpoint\n", 445 | " self.batch_size = 512\n", 446 | " self.lr = 3e-4 # for ADAm only\n", 447 | " self.weight_decay = 1e-6\n", 448 | " self.embedding_size = 128 # papers value is 128\n", 449 | " \n", 450 | " self.epochs = 100\n", 451 | " self.accum_iter = 1 # gradient accumulation\n", 452 | " self.amp = True # automatic mixed precision\n", 453 | " ############################################\n", 454 | " # START CODE HERE ### (≈ 1 line of code)\n", 455 | " self.temperature = ........\n", 456 | " ### END CODE HERE ###\n", 457 | "\n", 458 | "### START CODE HERE ### (>10 lines of code)\n", 459 | "\n", 460 | "\n", 461 | "# Launch training i.e :\n", 462 | "# dict_log = pretrain(model, optimizer, config.epochs,\n", 463 | "# train_dl, criterion, \n", 464 | "# config.device, accum_iter=config.accum_iter,\n", 465 | "# amp=config.amp)\n", 466 | "\n", 467 | "### END CODE HERE ###" 468 | ] 469 | }, 470 | { 471 | "attachments": {}, 472 | "cell_type": "markdown", 473 | "metadata": {}, 474 | "source": [ 475 | "# Part VII. Linear probing + T-SNE visualization of features\n", 476 | "\n", 477 | "As in the previous exercise, check the results of linear probing on the supervised training split and the T-SNE visualization.\n", 478 | "\n", 479 | "Code for the T-SNE visualization exists in `utils.py`." 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": null, 485 | "metadata": {}, 486 | "outputs": [], 487 | "source": [ 488 | "### START CODE HERE ### (> 10 lines of code)\n", 489 | "# model = ResNetSimCLR(embedding_size=config.embedding_size)\n", 490 | "# model = load_model(model, \"simclr.pth\")\n", 491 | "\n", 492 | "\n", 493 | "# Linear evaluation\n", 494 | "\n", 495 | "\n", 496 | "# TSNE plot\n", 497 | "\n", 498 | "\n", 499 | "### END CODE HERE ###" 500 | ] 501 | }, 502 | { 503 | "attachments": {}, 504 | "cell_type": "markdown", 505 | "metadata": {}, 506 | "source": [ 507 | "### Expected results\n", 508 | "```\n", 509 | "Model simclr.pth is loaded from epoch 99 , loss 5.342101926069994\n", 510 | "Ep 199/200: Accuracy : Train:87.80 \t Val:78.41 || Loss: Train 0.360 \t Val 0.612\n", 511 | "```" 512 | ] 513 | }, 514 | { 515 | "attachments": {}, 516 | "cell_type": "markdown", 517 | "metadata": { 518 | "id": "4FrwRzDnAst5" 519 | }, 520 | "source": [ 521 | "# Part VIII. 
Compare SimCLR versus supervised Imagenet-pretrained weights and random init on STL10 train/val split\n", 522 | "\n", 523 | "- Don't forget to use the train split of STL10 for supervised training.\n", 524 | "- For simplicity, don't use augmentations here, although it's possible and it would lead to better results.\n", 525 | "- Since we are not using any augmentations at this step, simclr will have the same results as before.\n", 526 | "\n", 527 | "\n", 528 | "Variants to be tested: \n", 529 | "- SimCLR weights trained for at least 50 epochs\n", 530 | "- Imagenet initialization\n", 531 | "- random initialization\n", 532 | "Afterward, print the best val. accuracy for all 3 models!" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": null, 538 | "metadata": {}, 539 | "outputs": [], 540 | "source": [ 541 | "def main(mode='simclr'):\n", 542 | " device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 543 | " ### START CODE HERE ### (≈ 15 lines of code)\n", 544 | " \n", 545 | " if mode == 'random':\n", 546 | "\n", 547 | " elif mode == 'imagenet':\n", 548 | "\n", 549 | " elif mode == 'simclr':\n", 550 | " \n", 551 | " ### END CODE HERE ###\n", 552 | " optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", 553 | " dict_log = linear_eval(model, optimizer, 20, train_dl, val_dl, device)\n", 554 | " return dict_log\n", 555 | " \n", 556 | "\n", 557 | "dict_log_simclr = main('simclr')\n", 558 | "acc1 = np.max(dict_log_simclr[\"val_acc_epoch\"])\n", 559 | "dict_log_in = main('imagenet')\n", 560 | "acc2 = np.max(dict_log_in[\"val_acc_epoch\"])\n", 561 | "dict_log_ran = main('random')\n", 562 | "acc3 = np.max(dict_log_ran[\"val_acc_epoch\"])\n", 563 | "print(f\"Fine-tuning best results: SimCLR: {acc1:.2f}%, ImageNet: {acc2:.2f} %, Random: {acc3:.2f} %\")" 564 | ] 565 | }, 566 | { 567 | "attachments": {}, 568 | "cell_type": "markdown", 569 | "metadata": { 570 | "id": "HPqA2qOp9vl6" 571 | }, 572 | "source": [ 573 | "### Expected results\n", 574 | "\n", 575 | "By fine-tuning all variants for 20 epochs this is what we got: \n", 576 | "\n", 577 | "```\n", 578 | "Fine-tuning best results: SimCLR: 77.26%, ImageNet: 76.25 %, Random: 53.83 %\n", 579 | "\n", 580 | "```" 581 | ] 582 | }, 583 | { 584 | "attachments": {}, 585 | "cell_type": "markdown", 586 | "metadata": {}, 587 | "source": [ 588 | "# Part IX. Plot the val accuracies for the 3 different initializations" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": null, 594 | "metadata": {}, 595 | "outputs": [], 596 | "source": [ 597 | "# Provided\n", 598 | "plt.figure(figsize=(10, 5))\n", 599 | "plt.plot(dict_log_simclr[\"val_acc_epoch\"], label=\"SimCLR\")\n", 600 | "plt.plot(dict_log_in[\"val_acc_epoch\"], label=\"ImageNet\")\n", 601 | "plt.plot(dict_log_ran[\"val_acc_epoch\"], label=\"Random\")\n", 602 | "plt.legend()\n", 603 | "plt.xlabel(\"Epochs\")\n", 604 | "plt.ylabel(\"Accuracy\")\n", 605 | "plt.title(\"Fine tuning results on STL-10\")\n", 606 | "plt.savefig(\"fine_tuning_results_stl10.png\")\n", 607 | "plt.show()" 608 | ] 609 | }, 610 | { 611 | "attachments": {}, 612 | "cell_type": "markdown", 613 | "metadata": {}, 614 | "source": [ 615 | "# Conclusion and Bonus reads\n", 616 | "\n", 617 | "That's the end of this exercise. If you reached this point, congratulations!\n", 618 | "\n", 619 | "\n", 620 | "### Optional stuff\n", 621 | "\n", 622 | "- Improve SimCLR. 
Add the [LARS optimizer](https://gist.github.com/black0017/3766fc7c62bdd274df664f8ec03715a2) with linear warm + [cosine scheduler](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html?highlight=cosine%20scheduler#torch.optim.lr_scheduler.CosineAnnealingLR) + train for 200 epochs. Then make a new comparison!\n", 623 | "- Train on CIFAR100 and compare rotation prediction VS SimCLR pretraining on both datasets. Which pretext task is likely to work better there?" 624 | ] 625 | } 626 | ], 627 | "metadata": { 628 | "accelerator": "GPU", 629 | "colab": { 630 | "collapsed_sections": [], 631 | "machine_shape": "hm", 632 | "name": "[Exercise 4] - SimCLR Resnet18 Solution.ipynb", 633 | "provenance": [] 634 | }, 635 | "kernelspec": { 636 | "display_name": "Python 3", 637 | "language": "python", 638 | "name": "python3" 639 | }, 640 | "language_info": { 641 | "codemirror_mode": { 642 | "name": "ipython", 643 | "version": 3 644 | }, 645 | "file_extension": ".py", 646 | "mimetype": "text/x-python", 647 | "name": "python", 648 | "nbconvert_exporter": "python", 649 | "pygments_lexer": "ipython3", 650 | "version": "3.8.5" 651 | }, 652 | "vscode": { 653 | "interpreter": { 654 | "hash": "dc5fcf396fe0abd4fa852aee332a0572494dcaf5776820055c87d9b84157f362" 655 | } 656 | } 657 | }, 658 | "nbformat": 4, 659 | "nbformat_minor": 1 660 | } 661 | -------------------------------------------------------------------------------- /exercises/week04_simclr/figs/fine_tuning_results_stl10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHU-MMBS/RepresentationLearning_SS2023/9238c4c7648f8d64dd0461e828652cd3606f36a1/exercises/week04_simclr/figs/fine_tuning_results_stl10.png -------------------------------------------------------------------------------- /exercises/week04_simclr/figs/simclr-illustration-loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHU-MMBS/RepresentationLearning_SS2023/9238c4c7648f8d64dd0461e828652cd3606f36a1/exercises/week04_simclr/figs/simclr-illustration-loss.png -------------------------------------------------------------------------------- /exercises/week04_simclr/figs/t-sne-simclr_feats__.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHU-MMBS/RepresentationLearning_SS2023/9238c4c7648f8d64dd0461e828652cd3606f36a1/exercises/week04_simclr/figs/t-sne-simclr_feats__.png -------------------------------------------------------------------------------- /exercises/week04_simclr/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader, Dataset 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | import torch.optim as optim 11 | import torch.utils.data as data 12 | import random 13 | import matplotlib.pyplot as plt 14 | from torchvision import transforms as T 15 | from tqdm import tqdm 16 | from sklearn.manifold import TSNE 17 | 18 | 19 | def imshow(img, i=0, mean=torch.tensor([0.0], dtype=torch.float32), std=torch.tensor([1], dtype=torch.float32)): 20 | """ 21 | shows an image on the screen. 
mean of 0 and variance of 1 will show the images unchanged in the screen 22 | """ 23 | # undoes the normalization 24 | unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist()) 25 | npimg = unnormalize(img).numpy() 26 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 27 | 28 | 29 | def prevalidate(model, val_loader,criterion, device): 30 | ### START CODE HERE ### (≈ 12 lines of code) 31 | model.eval() 32 | correct, total = 0, 0 33 | loss_step = [] 34 | with torch.no_grad(): 35 | for data in val_loader: 36 | inp_data,labels = data 37 | inp_data, labels = inp_data.to(device), labels.to(device) 38 | outputs = model(inp_data) 39 | val_loss = criterion(outputs, labels) 40 | loss_step.append(val_loss.item()) 41 | # dont forget to take the means here 42 | val_loss_epoch = np.mean(loss_step) 43 | ### END CODE HERE ### 44 | return val_loss_epoch 45 | 46 | def pretrain_one_epoch(model, optimizer, train_loader, criterion, device): 47 | ### START CODE HERE ### (≈ 12 lines of code) 48 | model.train() 49 | loss_step = [] 50 | for data in train_loader: 51 | # Move the data to the GPU 52 | inp_data, labels = data 53 | inp_data, labels = inp_data.to(device), labels.to(device) 54 | outputs = model(inp_data) 55 | loss = criterion(outputs, labels) 56 | optimizer.zero_grad() 57 | loss.backward() 58 | optimizer.step() 59 | loss_step.append(loss.item()) 60 | # dont forget the means here 61 | loss_curr_epoch = np.mean(loss_step) 62 | ### END CODE HERE ### 63 | return loss_curr_epoch 64 | 65 | def save_model(model, path, epoch, optimizer, val_loss): 66 | torch.save({ 67 | 'epoch': epoch, 68 | 'model_state_dict': model.state_dict(), 69 | 'optimizer_state_dict': optimizer.state_dict(), 70 | 'loss': val_loss, 71 | }, path) 72 | 73 | def pretrain(model, optimizer, num_epochs, train_loader, val_loader, criterion, device): 74 | dict_log = {"train_loss":[], "val_loss":[]} 75 | ### START CODE HERE ### (≈ 12 lines of code) 76 | best_val_loss = 1e8 77 | model = model.to(device) 78 | pbar = tqdm(range(num_epochs)) 79 | for epoch in pbar: 80 | loss_curr_epoch = pretrain_one_epoch(model, optimizer, train_loader, criterion, device) 81 | val_loss = prevalidate(model, val_loader, criterion, device) 82 | 83 | # Print epoch results to screen 84 | msg = (f'Ep {epoch}/{num_epochs}: || Loss: Train {loss_curr_epoch:.3f} \t Val {val_loss:.3f}') 85 | pbar.set_description(msg) 86 | 87 | dict_log["train_loss"].append(loss_curr_epoch) 88 | dict_log["val_loss"].append(val_loss) 89 | 90 | # Use this code to save the model with the best validation loss 91 | if val_loss < best_val_loss: 92 | best_val_loss = val_loss 93 | save_model(model, f'best_model_min_val_loss.pth', epoch, optimizer, val_loss) 94 | ### END CODE HERE ### 95 | return dict_log 96 | 97 | 98 | 99 | def validate(model, val_loader, device): 100 | model.eval() 101 | criterion = nn.CrossEntropyLoss() 102 | correct, total = 0, 0 103 | loss_step = [] 104 | with torch.no_grad(): 105 | for data in val_loader: 106 | inp_data,labels = data 107 | inp_data = inp_data.to(device) 108 | labels = labels.to(device) 109 | outputs = model(inp_data) 110 | val_loss = criterion(outputs, labels) 111 | predicted = torch.max(outputs, 1)[1] 112 | total += labels.size(0) 113 | correct += (predicted == labels).sum() 114 | loss_step.append(val_loss.item()) 115 | # dont forget to take the means here 116 | val_acc = (100 * correct / total).cpu().numpy() 117 | val_loss_epoch = torch.tensor(loss_step).mean().numpy() 118 | return val_acc , val_loss_epoch 119 | 120 | # Provided 121 | def 
get_features(model, dataloader, device): 122 | model = model.to(device) 123 | feats, labs = [], [] 124 | for i in dataloader: 125 | inp_data,labels = i 126 | inp_data = inp_data.to(device) 127 | features = model(inp_data) 128 | features = features.cpu().detach().flatten(start_dim=1) 129 | labels = labels.cpu().detach() 130 | feats.append(features) 131 | labs.append(labels) 132 | f = torch.cat(feats, dim=0) 133 | l = torch.cat(labs, dim=0) 134 | return f,l 135 | 136 | 137 | def train_one_epoch(model, optimizer, train_loader, device): 138 | model.train() 139 | criterion = nn.CrossEntropyLoss() 140 | loss_step = [] 141 | correct, total = 0, 0 142 | for data in train_loader: 143 | # Move the data to the GPU 144 | inp_data,labels = data 145 | inp_data = inp_data.to(device) 146 | labels = labels.to(device) 147 | outputs = model(inp_data) 148 | loss = criterion(outputs, labels) 149 | optimizer.zero_grad() 150 | loss.backward() 151 | optimizer.step() 152 | with torch.no_grad(): 153 | _, predicted = torch.max(outputs, 1) 154 | total += labels.size(0) 155 | correct += (predicted == labels).sum() 156 | loss_step.append(loss.item()) 157 | # dont forget the means here 158 | loss_curr_epoch = np.mean(loss_step) 159 | train_acc = (100 * correct / total).cpu() 160 | return loss_curr_epoch, train_acc 161 | 162 | 163 | def linear_eval(model, optimizer, num_epochs, train_loader, val_loader, device): 164 | best_val_loss = 1e8 165 | best_val_acc = 0 166 | model = model.to(device) 167 | dict_log = {"train_acc_epoch":[], "val_acc_epoch":[], "loss_epoch":[], "val_loss":[]} 168 | train_acc, _ = validate(model, train_loader, device) 169 | val_acc, _ = validate(model, val_loader, device) 170 | print(f'Init Accuracy of the model: Train:{train_acc:.3f} \t Val:{val_acc:3f}') 171 | pbar = tqdm(range(num_epochs)) 172 | for epoch in pbar: 173 | loss_curr_epoch, train_acc = train_one_epoch(model, optimizer, train_loader, device) 174 | val_acc, val_loss = validate(model, val_loader, device) 175 | 176 | # Print epoch results to screen 177 | msg = (f'Ep {epoch}/{num_epochs}: Accuracy : Train:{train_acc:.2f} \t Val:{val_acc:.2f} || Loss: Train {loss_curr_epoch:.3f} \t Val {val_loss:.3f}') 178 | pbar.set_description(msg) 179 | # Track stats 180 | dict_log["train_acc_epoch"].append(train_acc) 181 | dict_log["val_acc_epoch"].append(val_acc) 182 | dict_log["loss_epoch"].append(loss_curr_epoch) 183 | dict_log["val_loss"].append(val_loss) 184 | 185 | if val_loss < best_val_loss: 186 | best_val_loss = val_loss 187 | torch.save({ 188 | 'epoch': epoch, 189 | 'model_state_dict': model.state_dict(), 190 | 'optimizer_state_dict': optimizer.state_dict(), 191 | 'loss': val_loss, 192 | }, f'best_model_min_val_loss.pth') 193 | 194 | if val_acc > best_val_acc: 195 | best_val_acc = val_acc 196 | torch.save({ 197 | 'epoch': epoch, 198 | 'model_state_dict': model.state_dict(), 199 | 'optimizer_state_dict': optimizer.state_dict(), 200 | 'loss': val_loss, 201 | }, f'best_model_max_val_acc.pth') 202 | return dict_log 203 | 204 | 205 | def load_model(model, path): 206 | checkpoint = torch.load(path) 207 | model.load_state_dict(checkpoint['model_state_dict']) 208 | print(f"Model {path} is loaded from epoch {checkpoint['epoch']} , loss {checkpoint['loss']}") 209 | return model 210 | 211 | 212 | def default(val, def_val): 213 | return def_val if val is None else val 214 | 215 | def reproducibility(SEED): 216 | torch.manual_seed(SEED) 217 | torch.backends.cudnn.deterministic = True 218 | torch.backends.cudnn.benchmark = False 219 | np.random.seed(SEED) 
220 | if torch.cuda.is_available(): 221 | torch.cuda.manual_seed(SEED) 222 | 223 | def define_param_groups(model, weight_decay, optimizer_name): 224 | def exclude_from_wd_and_adaptation(name): 225 | if 'bn' in name: 226 | return True 227 | if optimizer_name == 'lars' and 'bias' in name: 228 | return True 229 | 230 | param_groups = [ 231 | { 232 | 'params': [p for name, p in model.named_parameters() if not exclude_from_wd_and_adaptation(name)], 233 | 'weight_decay': weight_decay, 234 | 'layer_adaptation': True, 235 | }, 236 | { 237 | 'params': [p for name, p in model.named_parameters() if exclude_from_wd_and_adaptation(name)], 238 | 'weight_decay': 0., 239 | 'layer_adaptation': False, 240 | }, 241 | ] 242 | return param_groups 243 | 244 | 245 | def tsne_plot_embeddings(features, labels, class_names, title="T-SNE plot"): 246 | plt.figure(figsize=(12, 12)) 247 | latent_space_tsne = TSNE(2, verbose = True, n_iter = 2000, metric="cosine", perplexity=50, learning_rate=500) 248 | xa_tsne = latent_space_tsne.fit_transform(features.cpu().numpy()[:, :]) 249 | colors = plt.rcParams["axes.prop_cycle"]() 250 | for class_idx in range(len(class_names)): 251 | c = next(colors)["color"] 252 | plt.scatter(xa_tsne[:,0][labels==class_idx], xa_tsne[:,1][labels==class_idx], color=c, label=class_names[class_idx]) 253 | 254 | plt.legend(class_names, fontsize=18, loc='center left', bbox_to_anchor=(1.05, 0.5)) 255 | if title is not None: 256 | plt.title(title, fontsize=18) 257 | 258 | plt.gca().axes.get_yaxis().set_visible(False) 259 | plt.gca().axes.get_xaxis().set_visible(False) 260 | plt.savefig("tsne_plot_embeddings_solution.png") 261 | plt.show() -------------------------------------------------------------------------------- /exercises/week05_scan/A05-image-clustering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "e-uirVvNW-yY" 7 | }, 8 | "source": [ 9 | "HHU Deep Learning, SS2022/23, 12.05.2023, Prof. Dr. Markus Kollmann\n", 10 | "\n", 11 | "Lecturers and Tutoring is done by Tim Kaiser, Nikolas Adaloglou and Felix Michels.\n", 12 | "\n", 13 | "# Assignment 05 - Image Clustering\n", 14 | "\n", 15 | "\n", 16 | "## Contents\n", 17 | "\n", 18 | "1. Imports, basic utils, augmentations\n", 19 | "2. Load the pretrained MoCO model ResNet50 pretrained on ImageNet\n", 20 | "3. Compute the k-means clustering accuracy using the learned representations\n", 21 | "4. T-SNE visualization of features\n", 22 | "5. Compute the 50-NN\n", 23 | "6. Write a new dataset class to load image pairs\n", 24 | "7. Implement the SCAN loss\n", 25 | "8. Implement the PMI loss. Train the clustering head and compute the validation accuracy\n", 26 | "9. Pretraining code. (Provided, no need to change something here!)\n", 27 | "10. Train with SCAN and PMI using the KNN pairs\n", 28 | "11. Get cluster assignments and evaluate cluster accuracy" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "# Introduction \n", 36 | "\n", 37 | "Image clustering in deep learning can be mathematically described as a process of partitioning a set of images, X, into K clusters, where K is a user-defined parameter representing the number of desired clusters.\n", 38 | "\n", 39 | "Let V(X) be the visual feature representation of the images in X, obtained using a deep learning algorithm such as a convolutional neural network (CNN). 
Each image in X is transformed into a feature vector in V(X), where the dimensions correspond to the learned features of the CNN.\n", 40 | "\n", 41 | "Image clustering is a task in deep learning where an algorithm is used to group similar images together based on their visual characteristics. Ideally, images with similar ground truth labels will belong in the same cluster.\n", 42 | "\n", 43 | "The goal of image clustering is to automatically categorize large sets of images into smaller subsets based on their similarities, which can help in organizing and managing large image datasets.\n", 44 | "\n", 45 | "To accomplish this task, deep learning algorithms use complex mathematical models to analyze and identify patterns within the images, and then group the images that share these patterns into clusters. This process can be useful in a variety of applications, such as image recognition, image search, and content-based image retrieval.\n", 46 | "\n", 47 | "\n", 48 | "[SimCLR Paper](https://arxiv.org/abs/2002.05709)\n", 49 | "\n", 50 | "[MoCo Paper](https://arxiv.org/abs/1911.05722)\n", 51 | "\n", 52 | "[SCAN Paper](https://arxiv.org/abs/2005.12320v2)\n", 53 | "\n", 54 | "[TEMI](https://arxiv.org/abs/2303.17896)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "id": "lw5K7r5SQDca" 61 | }, 62 | "source": [ 63 | "# Part I. Imports, basic utils, augmentations" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "id": "1SIad0uuHBlv" 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "import os\n", 75 | "import torch\n", 76 | "import torchvision.models as models\n", 77 | "import numpy as np\n", 78 | "\n", 79 | "import torch\n", 80 | "import torchvision\n", 81 | "import torchvision.transforms as T\n", 82 | "import torch.nn as nn\n", 83 | "import torch.nn.functional as F\n", 84 | "from torchvision.datasets import STL10\n", 85 | "from torch.utils.data import DataLoader\n", 86 | "from torch.optim import Adam\n", 87 | "import tqdm\n", 88 | "\n", 89 | "import numpy as np\n", 90 | "import matplotlib.pyplot as plt\n", 91 | "from sklearn.cluster import KMeans\n", 92 | "# Local imports\n", 93 | "from utils import *\n", 94 | "\n", 95 | "os.makedirs(\"./figs\", exist_ok=True)\n", 96 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "# Part II: Load the pretrained MoCO model ResNet50 pretrained on ImageNet\n", 104 | "\n", 105 | "[Weights are available in this link](https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar)\n", 106 | "\n", 107 | "You can download the weight by running the terminal command:\n", 108 | "\n", 109 | "`$ wget link_to_model_weights`" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "def load_moco_model(pretrained_path = \"./moco_v2_800ep_pretrain.pth.tar\"):\n", 119 | " ### START CODE HERE ### (≈ 11 lines of code)\n", 120 | " ckpt = torch.load(pretrained_path, map_location='cpu')\n", 121 | " print(ckpt.keys(), ckpt[\"arch\"], ckpt[\"epoch\"])\n", 122 | " state_dict = ckpt[\"state_dict\"]\n", 123 | " state_dict_new = dict()\n", 124 | " for key in state_dict.keys():\n", 125 | " new_key = key.replace(\"module.encoder_q.\",\"\")\n", 126 | " state_dict_new[new_key] = state_dict[key]\n", 127 | " model = getattr(models, 
ckpt[\"arch\"])(pretrained=False)\n", 128 | " model.fc = nn.Identity()\n", 129 | " msg = model.load_state_dict(state_dict_new, strict=False)\n", 130 | " print(\"Loaded model with message:\", msg)\n", 131 | " ### END CODE HERE ###\n", 132 | " model.eval()\n", 133 | " return model\n", 134 | "\n", 135 | "encoder = load_moco_model() " 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "### Expected results\n", 143 | "\n", 144 | "There should be no missing keys, while loading the model. There may be some unexpected keys based on your implementation.\n", 145 | "\n", 146 | "```python\n", 147 | "Loaded model with message: _IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.0.weight', 'fc.0.bias', 'fc.2.weight', 'fc.2.bias'])\n", 148 | "```" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": { 154 | "id": "4v-Qg5Xpk2Bv" 155 | }, 156 | "source": [ 157 | "# Part III: Compute the k-means clustering accuracy using the learned representations\n", 158 | "\n", 159 | "\n", 160 | "- Compute the frozen features representations of the backbone model.\n", 161 | "- Compute the accuracy both for the `train` and `test` split using Kmeans.\n", 162 | "\n", 163 | "Hint: you may use the function 'compute_clustering_metrics' defined in utils.py\n" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": { 170 | "id": "FeP4ZuZpsyOp" 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "transf = T.Compose([\n", 175 | " T.ToTensor(),\n", 176 | " T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n", 177 | "### START CODE HERE ### (≈>10 lines of code)\n", 178 | "\n", 179 | "# compute features for train and val\n", 180 | "\n", 181 | "#Fitt k-means.....\n", 182 | "\n", 183 | "\n", 184 | "# compute clustering metrics for train and val\n", 185 | "train_acc = compute_clustering_metrics(train_labels.cpu().numpy(), train_preds,min_samples_per_class=10)[0]\n", 186 | "val_acc = compute_clustering_metrics(train_labels.cpu().numpy(), train_preds, min_samples_per_class=10)[0]\n", 187 | "\n", 188 | "### END CODE HERE ###\n", 189 | "print(f\"Train acc: {train_acc:.2f}, Val acc: {val_acc:.2f}\")" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "### Expected results\n", 197 | "\n", 198 | "`Train acc: 53.64, Val acc: 53.64`" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "# Part IV. T-SNE visualization of features\n", 206 | "\n", 207 | "As in the previous exercise, check the results of linear probing on the supervised training split and the T-SNE visualization.\n", 208 | "\n", 209 | "Code for the T-SNE visualization exists in `utils.py`." 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "### START CODE HERE ### (≈ 3 line of code)\n", 219 | "# TSNE plot\n", 220 | "\n", 221 | "### END CODE HERE ###" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": {}, 227 | "source": [ 228 | "# Part V. 
Compute the 50-NN\n", 229 | "\n", 230 | "- Load the train features\n", 231 | "- Use the cosine similarity\n", 232 | "- Compute the k=50 nearset neiboughrs(NN) on the feature space of the pretrained ResNet50\n", 233 | "- save the indices of the k-NN.\n", 234 | "- Visualize the top 5 NN for a couple of images (~10)" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "# Provided but optional to use!\n", 244 | "class_names = torchvision.datasets.STL10(root='../data').classes\n", 245 | "def vizualize_pairs(indices, true_labels, train_ds):\n", 246 | " # Visualize the reference image and its 7 nearest neighbors\n", 247 | " ref_ids = [0, 100, 200, 300, 400, 500, 600, 700, 800, 900]\n", 248 | " nn_viz = 6 \n", 249 | " plt.subplots_adjust(wspace=0.4, hspace=0.4)\n", 250 | " plt.figure(figsize = (22,22))\n", 251 | " ax = plt.gca()\n", 252 | " ax.get_xaxis().set_visible(False)\n", 253 | " ax.get_yaxis().set_visible(False)\n", 254 | " for c, ref in enumerate(ref_ids):\n", 255 | " knns = indices[ref, :nn_viz]\n", 256 | " imgs_to_viz = [train_ds[ref][0]]\n", 257 | " true_labels = [train_ds[ref][1]]\n", 258 | " for i in knns:\n", 259 | " imgs_to_viz.append(train_ds[i][0])\n", 260 | " true_labels.append(train_ds[i][1])\n", 261 | " # show the images\n", 262 | " for j in range(nn_viz):\n", 263 | " label = int(true_labels[j])\n", 264 | " plt.subplot(len(ref_ids), nn_viz, (c*nn_viz)+(j+1))\n", 265 | " imshow(imgs_to_viz[j])\n", 266 | " plt.title(f\"{class_names[label]}, Label {label}\", fontsize = 10)\n", 267 | " ax = plt.gca()\n", 268 | " ax.get_xaxis().set_visible(False)\n", 269 | " ax.get_yaxis().set_visible(False)\n", 270 | " plt.savefig(f'./figs/knn_viz', bbox_inches = \"tight\", dpi = 500) \n", 271 | "\n", 272 | "### START CODE HERE ### (≈ 10 line of code)\n", 273 | "\n", 274 | "# compute the similarity matrix\n", 275 | "\n", 276 | "# take top k similar images\n", 277 | "\n", 278 | "\n", 279 | "# save the indices\n", 280 | "\n", 281 | "### END CODE HERE ###" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "# Part VI. Write a new dataset class to load image pairs\n", 289 | "\n", 290 | "- The new dataset class will inherit from `torch.utils.data.Dataset`\n", 291 | "- It will return the representations of 2 images that are in the 50-NN (randomly sampled)." 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "### START CODE HERE (≈ 12 lines of code)\n", 301 | "class PairSTL10(torch.utils.data.Dataset):\n", 302 | " def __init__(self, indices_path=\"./knn_indices.pth\", embeds_path=\"./train_feats.pth\", l2_normalize=True):\n", 303 | "\n", 304 | " def __len__(self):\n", 305 | "\n", 306 | " def __getitem__(self, index):\n", 307 | "\n", 308 | "### END CODE HERE\n", 309 | " \n", 310 | "def test_get_pair():\n", 311 | " dataset = PairSTL10()\n", 312 | " emb1, emb2 = dataset[16]\n", 313 | " print(emb1.shape, emb2.shape)\n", 314 | " assert emb1.shape==emb2.shape \n", 315 | "\n", 316 | "test_get_pair()\n", 317 | "train_loader = torch.utils.data.DataLoader(PairSTL10(), batch_size=128, shuffle=True, num_workers=4)\n", 318 | "data_batch = next(iter(train_loader))" 319 | ] 320 | }, 321 | { 322 | "cell_type": "markdown", 323 | "metadata": {}, 324 | "source": [ 325 | "# Part VII. 
Implement the SCAN loss\n", 326 | "\n", 327 | "Check the SCAN paper, specifically Eq.2 for details." 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "class SCAN(torch.nn.Module):\n", 337 | " def __init__(self, alpha=1):\n", 338 | " super().__init__()\n", 339 | " self.alpha = alpha\n", 340 | "\n", 341 | " def forward(self, proj_1, proj_2):\n", 342 | " # START CODE HERE (≈ 6 line of code)\n", 343 | " \n", 344 | " # dot product\n", 345 | " \n", 346 | "\n", 347 | " # self-entropy regularization\n", 348 | "\n", 349 | " ### END CODE HERE\n", 350 | "\n", 351 | "def test_scan():\n", 352 | " torch.manual_seed(99)\n", 353 | " scan = SCAN(alpha=1)\n", 354 | " proj_1 = torch.randn(100, 128)\n", 355 | " proj_2 = torch.randn(100, 128)\n", 356 | " loss = scan(proj_1, proj_2)\n", 357 | " print(loss)\n", 358 | " assert loss.shape==torch.Size([])\n", 359 | "test_scan()" 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "metadata": {}, 365 | "source": [ 366 | "### Expected results\n", 367 | "\n", 368 | "For alpha=1, output = `tensor(0.0275)`" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "metadata": {}, 374 | "source": [ 375 | "# Part VIII. Implement the PMI loss. Train the clustering head and compute the validation accuracy\n", 376 | "\n", 377 | "Implement the PMI loss based on eq 6,7,8 from the paper https://arxiv.org/pdf/2303.17896.pdf\n", 378 | "\n", 379 | "As a side note we didnt use the symmetrized version of the loss in the exercise: Loss = -PMI, don't forget the sign." 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [ 388 | "class PMI(torch.nn.Module):\n", 389 | " def __init__(self, gamma=1, momentum=0.99, temp=0.1):\n", 390 | " super().__init__()\n", 391 | " self.gamma = gamma\n", 392 | " self.temp = temp\n", 393 | " self.center = None\n", 394 | " self.momentum = momentum\n", 395 | " \n", 396 | " # START CODE HERE (≈ 6 line of code)\n", 397 | " @torch.no_grad()\n", 398 | " def update_ema(self, output):\n", 399 | " \"\"\"\n", 400 | " Update exponential moving average of the center (denominator)\n", 401 | " \"\"\"\n", 402 | " \n", 403 | " def forward(self, proj_1, proj_2):\n", 404 | " \n", 405 | " ### END CODE HERE\n", 406 | "\n", 407 | "def test_pmi():\n", 408 | " torch.manual_seed(99)\n", 409 | " criterion = PMI(gamma=1)\n", 410 | " proj_1 = torch.rand(100, 128)\n", 411 | " proj_2 = torch.rand(100, 128)\n", 412 | " loss = criterion(proj_1, proj_2)\n", 413 | " print(loss)\n", 414 | " assert loss.shape==torch.Size([])\n", 415 | " \n", 416 | "test_pmi()" 417 | ] 418 | }, 419 | { 420 | "cell_type": "markdown", 421 | "metadata": {}, 422 | "source": [ 423 | "### Expected results \n", 424 | "\n", 425 | "`tensor(0.0738)`" 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "metadata": {}, 431 | "source": [ 432 | "# Part IX. PROVIDED: Pretraining code\n", 433 | "\n", 434 | "This part is provided, but please take a look and identify what is changing compared to the standard train loop.\n", 435 | "\n", 436 | "You don't need to code something here, unless there is some inconsitency with the previous parts of the code.\n", 437 | "\n", 438 | "Still, this code works in our proposed solution and it's your job to modify it if it doesnt work well with the previous code based on your implementations." 
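Before the provided training loop below, here is a minimal reference sketch of the SCAN objective from Part VII (Eq. 2 of the SCAN paper). It is written as a standalone function rather than the `SCAN` module of the exercise, the `1e-8` stabilizers are our own choice, and it is not guaranteed to reproduce the expected test value above.

```python
import torch
import torch.nn.functional as F

def scan_loss(logits_anchor, logits_neighbor, alpha=1.0):
    """Sketch of SCAN Eq. 2: consistency between k-NN pairs + entropy regularization."""
    p_anchor = F.softmax(logits_anchor, dim=1)      # cluster probabilities of the anchor
    p_neighbor = F.softmax(logits_neighbor, dim=1)  # cluster probabilities of its neighbour
    # consistency term: an image and its nearest neighbour should share a cluster
    dot = (p_anchor * p_neighbor).sum(dim=1)
    consistency = -torch.log(dot + 1e-8).mean()
    # entropy of the mean assignment: discourages collapsing everything into one cluster
    p_mean = p_anchor.mean(dim=0)
    entropy = -(p_mean * torch.log(p_mean + 1e-8)).sum()
    return consistency - alpha * entropy
```

Minimizing `-alpha * entropy` maximizes the entropy of the average cluster assignment, which plays the role of the λ-weighted regularizer in the paper; `alpha` corresponds to the `alpha` argument of the exercise's `SCAN` class.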
439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": null, 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "import copy \n", 448 | "\n", 449 | "\n", 450 | "def pretrain(model, optimizer, num_epochs, train_loader, criterion, device, prefix=\"scan\", model_ema=False):\n", 451 | " dict_log = {\"train_loss\":[]}\n", 452 | " best_loss = 1e8\n", 453 | " model = model.to(device)\n", 454 | " pbar = tqdm(range(num_epochs))\n", 455 | " for epoch in pbar:\n", 456 | " loss_curr_epoch = pretrain_one_epoch(model, optimizer, train_loader, criterion, device, model_ema=model_ema)\n", 457 | " msg = (f'Ep {epoch}/{num_epochs}: || Loss: Train {loss_curr_epoch:.3f}')\n", 458 | " pbar.set_description(msg)\n", 459 | " dict_log[\"train_loss\"].append(loss_curr_epoch)\n", 460 | " if loss_curr_epoch < best_loss:\n", 461 | " best_loss = loss_curr_epoch\n", 462 | " save_model(model, f'{prefix}_best_model_min_val_loss.pth', epoch, optimizer, best_loss) \n", 463 | " return dict_log\n", 464 | "\n", 465 | "class EMA():\n", 466 | " def __init__(self, alpha, student):\n", 467 | " super().__init__()\n", 468 | " self.alpha = alpha\n", 469 | " self.teacher = copy.deepcopy(student)\n", 470 | " for p in self.teacher.parameters():\n", 471 | " p.requires_grad = False\n", 472 | " \n", 473 | " def update_average(self, old, new):\n", 474 | " if old is None:\n", 475 | " return new\n", 476 | " return old * self.alpha + (1 - self.alpha) * new\n", 477 | " \n", 478 | " def update_teacher(self, student):\n", 479 | " for ema_params, student_params in zip(self.teacher.parameters(), student.parameters()):\n", 480 | " old_weight, student_weight = ema_params.data, student_params.data\n", 481 | " ema_params.data = self.update_average(old_weight, student_weight)\n", 482 | "\n", 483 | "\n", 484 | "def pretrain_one_epoch(model, optimizer, train_loader, criterion, device, model_ema=False):\n", 485 | " \"\"\"\n", 486 | " model: the model to train\n", 487 | " optimizer: the optimizer to use\n", 488 | " train_loader: the train loader\n", 489 | " criterion: the loss function, PMI or SCAN\n", 490 | " device: the device to use\n", 491 | " model_ema: whether to use EMA or not\n", 492 | " \"\"\"\n", 493 | " model.train()\n", 494 | " loss_step = []\n", 495 | " if model_ema:\n", 496 | " ema = EMA(0.99, model)\n", 497 | " for data in train_loader:\n", 498 | " # Move the data to the GPU\n", 499 | " img1, img2 = data\n", 500 | " img1, img2 = img1.to(device), img2.to(device)\n", 501 | " p1 = model(img1)\n", 502 | " p2 = ema.teacher(img2) if model_ema else model(img2)\n", 503 | " loss = criterion(p1, p2)\n", 504 | " optimizer.zero_grad()\n", 505 | " loss.backward()\n", 506 | " optimizer.step()\n", 507 | " loss_step.append(loss.item())\n", 508 | " if model_ema:\n", 509 | " ema.update_teacher(model)\n", 510 | " loss_curr_epoch = np.mean(loss_step)\n", 511 | " return loss_curr_epoch" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": {}, 517 | "source": [ 518 | "# Part X. 
Train with SCAN and PMI using the KNN pairs\n", 519 | "\n", 520 | "- Load the data using the implemented dataloader\n", 521 | "- Create a clustering head\n", 522 | "- Train head using Adam: optimizer, lr=1e-4, weight_decay=1e-6 for 150 epochs.\n", 523 | "- Train with SCAN and PMI and compare them with k-means.\n", 524 | "\n", 525 | "You can use the pretrain function:\n", 526 | "```python\n", 527 | "dict = pretrain(head, optimizer, num_epochs, train_loader, criterion, ......)\n", 528 | "```\n", 529 | "\n", 530 | "Training should **not** take more than 5 minutes for both models.\n", 531 | "\n", 532 | "We used: `PMI(gamma=0.65, momentum=0.9, temp=0.1)` for PMI" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": null, 538 | "metadata": {}, 539 | "outputs": [], 540 | "source": [ 541 | "### START CODE HERE ### (>15 line of code)\n", 542 | "# SCAN\n", 543 | "criterion = SCAN(alpha=......)\n", 544 | "\n", 545 | "optimizer = torch.optim.Adam(scan_head.parameters(), lr=1e-4, weight_decay=1e-6)\n", 546 | "dict_log_scan = pretrain(scan_head, optimizer, num_epochs, train_loader, criterion, device, prefix=\"scan\")\n", 547 | "\n", 548 | "# PMI\n", 549 | "criterion = PMI(.....)\n", 550 | "\n", 551 | "optimizer = torch.optim.Adam(pmi_head.parameters(), lr=1e-4, weight_decay=1e-6)\n", 552 | "dict_log_pmi = pretrain(pmi_head, optimizer, num_epochs, train_loader, criterion, device, prefix=\"pmi\", model_ema=True)\n", 553 | "### END CODE HERE ###" 554 | ] 555 | }, 556 | { 557 | "cell_type": "markdown", 558 | "metadata": {}, 559 | "source": [ 560 | "# Part XI. Get cluster assignments and evaluate cluster accuracy\n", 561 | "\n", 562 | "- Load the model trained with both objectives.\n", 563 | "- Predict cluster assignments.\n", 564 | "- Compute the clustering accuracy using `compute_clustering_metrics`" 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": null, 570 | "metadata": {}, 571 | "outputs": [], 572 | "source": [ 573 | "@torch.no_grad()\n", 574 | "def evaluate_clustering(model):\n", 575 | " model.eval()\n", 576 | " val_feats, val_labels = torch.load(\"val_feats.pth\"), torch.load(\"val_labels.pth\")\n", 577 | " train_feats, train_labels = torch.load(\"train_feats.pth\"), torch.load(\"train_labels.pth\")\n", 578 | " ### START CODE HERE ### (≈ 10 lines of code)\n", 579 | " # normalize feats\n", 580 | " \n", 581 | " # load features and compute logits\n", 582 | "\n", 583 | "\n", 584 | " # compute metrics\n", 585 | " print(\"Unique preds\", np.unique(train_preds), np.unique(val_preds))\n", 586 | " metrics_train = compute_clustering_metrics(train_labels.cpu().numpy(), train_preds, min_samples_per_class=10)\n", 587 | " metrics_val = compute_clustering_metrics(val_labels.cpu().numpy(), val_preds,min_samples_per_class=10)\n", 588 | " return metrics_train[0], metrics_val[0]\n", 589 | " ### END CODE HERE ###\n", 590 | " \n", 591 | "\n", 592 | "# Given but you may need to MODIFY the paths!!!!\n", 593 | "n_clusters = 10\n", 594 | "### START CODE HERE ### (4 lines of code)\n", 595 | "model = ....\n", 596 | "model_scan = load_model(model, \"./scan_best_model_min_val_loss.pth\")\n", 597 | "model = ....\n", 598 | "model_pmi = load_model(model, \"./pmi_best_model_min_val_loss.pth\")\n", 599 | "### END CODE HERE ###\n", 600 | "train_acc, val_acc = evaluate_clustering(model_scan)\n", 601 | "print(f\"SCAN: Train acc: {train_acc:.3f}, Val acc: {val_acc:.3f}\")\n", 602 | "train_acc, val_acc = evaluate_clustering(model_pmi)\n", 603 | "print(f\"PMI: Train acc: {train_acc:.3f}, 
Val acc: {val_acc:.3f}\")" 604 | ] 605 | }, 606 | { 607 | "cell_type": "markdown", 608 | "metadata": {}, 609 | "source": [ 610 | "### Expected results:\n", 611 | "Current best scores! Results may slightly vary between runs.\n", 612 | "```\n", 613 | "Model ./scan_best_model_min_val_loss.pth is loaded from epoch 148 , loss -22.383880043029784\n", 614 | "Model ./pmi_best_model_min_val_loss.pth is loaded from epoch 129 , loss -2.0719790697097777\n", 615 | "Unique preds [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]\n", 616 | "SCAN: Train acc: 74.380, Val acc: 74.450\n", 617 | "Unique preds [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]\n", 618 | "PMI: Train acc: 77.280, Val acc: 78.238\n", 619 | "```" 620 | ] 621 | }, 622 | { 623 | "cell_type": "markdown", 624 | "metadata": {}, 625 | "source": [ 626 | "# Conclusion and Bonus reads\n", 627 | "\n", 628 | "That's the end of this exercise. If you reached this point, congratulations!\n", 629 | "\n", 630 | "Additional things to to (Optional):\n", 631 | "\n", 632 | "- Plot the histogram of class assignments for SCAN and PMI\n", 633 | "- Compute the mean and median max softmax probability for SCAN and PMI\n" 634 | ] 635 | } 636 | ], 637 | "metadata": { 638 | "accelerator": "GPU", 639 | "colab": { 640 | "collapsed_sections": [], 641 | "machine_shape": "hm", 642 | "name": "[Exercise 4] - SimCLR Resnet18 Solution.ipynb", 643 | "provenance": [] 644 | }, 645 | "kernelspec": { 646 | "display_name": "Python 3 (ipykernel)", 647 | "language": "python", 648 | "name": "python3" 649 | }, 650 | "language_info": { 651 | "codemirror_mode": { 652 | "name": "ipython", 653 | "version": 3 654 | }, 655 | "file_extension": ".py", 656 | "mimetype": "text/x-python", 657 | "name": "python", 658 | "nbconvert_exporter": "python", 659 | "pygments_lexer": "ipython3", 660 | "version": "3.7.15" 661 | }, 662 | "vscode": { 663 | "interpreter": { 664 | "hash": "dc5fcf396fe0abd4fa852aee332a0572494dcaf5776820055c87d9b84157f362" 665 | } 666 | } 667 | }, 668 | "nbformat": 4, 669 | "nbformat_minor": 1 670 | } 671 | -------------------------------------------------------------------------------- /exercises/week05_scan/figs/knn_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHU-MMBS/RepresentationLearning_SS2023/9238c4c7648f8d64dd0461e828652cd3606f36a1/exercises/week05_scan/figs/knn_viz.png -------------------------------------------------------------------------------- /exercises/week05_scan/figs/tsne_plot_embeddings_solution__.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHU-MMBS/RepresentationLearning_SS2023/9238c4c7648f8d64dd0461e828652cd3606f36a1/exercises/week05_scan/figs/tsne_plot_embeddings_solution__.png -------------------------------------------------------------------------------- /exercises/week05_scan/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader, Dataset 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | import torch.optim as optim 11 | import torch.utils.data as data 12 | import random 13 | import matplotlib.pyplot as plt 14 | from torchvision import transforms as T 15 | from tqdm import tqdm 16 | from sklearn.manifold import TSNE 17 | 18 | def load_data( batch_size=128, train_split="unlabeled", 
test_split="test", transf = T.ToTensor(), num_workers=2, shuffle=False): 19 | train_ds = torchvision.datasets.STL10(root='../data', split=train_split, transform=transf, download=True) 20 | val_ds = torchvision.datasets.STL10(root='../data', split=test_split, transform=transf, download=True) 21 | train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, num_workers=num_workers, drop_last=False, shuffle=shuffle) 22 | val_dl = DataLoader(dataset=val_ds, batch_size=batch_size, num_workers=num_workers, drop_last=False, shuffle=shuffle) 23 | return train_dl, val_dl 24 | 25 | def imshow(img, mean=torch.tensor([0.0], dtype=torch.float32), std=torch.tensor([1], dtype=torch.float32)): 26 | """ 27 | shows an image on the screen. mean of 0 and variance of 1 will show the images unchanged in the screen 28 | """ 29 | # undoes the normalization 30 | unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist()) 31 | npimg = unnormalize(img).numpy() 32 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 33 | 34 | 35 | def prevalidate(model, val_loader,criterion, device): 36 | ### START CODE HERE ### (≈ 12 lines of code) 37 | model.eval() 38 | correct, total = 0, 0 39 | loss_step = [] 40 | with torch.no_grad(): 41 | for data in val_loader: 42 | inp_data,labels = data 43 | inp_data, labels = inp_data.to(device), labels.to(device) 44 | outputs = model(inp_data) 45 | val_loss = criterion(outputs, labels) 46 | loss_step.append(val_loss.item()) 47 | # dont forget to take the means here 48 | val_loss_epoch = np.mean(loss_step) 49 | ### END CODE HERE ### 50 | return val_loss_epoch 51 | 52 | 53 | def save_model(model, path, epoch, optimizer, val_loss): 54 | torch.save({ 55 | 'epoch': epoch, 56 | 'model_state_dict': model.state_dict(), 57 | 'optimizer_state_dict': optimizer.state_dict(), 58 | 'loss': val_loss, 59 | }, path) 60 | 61 | 62 | def validate(model, val_loader, device): 63 | model.eval() 64 | criterion = nn.CrossEntropyLoss() 65 | correct, total = 0, 0 66 | loss_step = [] 67 | with torch.no_grad(): 68 | for data in val_loader: 69 | inp_data,labels = data 70 | inp_data = inp_data.to(device) 71 | labels = labels.to(device) 72 | outputs = model(inp_data) 73 | val_loss = criterion(outputs, labels) 74 | predicted = torch.max(outputs, 1)[1] 75 | total += labels.size(0) 76 | correct += (predicted == labels).sum() 77 | loss_step.append(val_loss.item()) 78 | # dont forget to take the means here 79 | val_acc = (100 * correct / total).cpu().numpy() 80 | val_loss_epoch = torch.tensor(loss_step).mean().numpy() 81 | return val_acc , val_loss_epoch 82 | 83 | # Provided 84 | @torch.no_grad() 85 | def get_features(model, dataloader, device): 86 | model = model.to(device) 87 | feats, labs = [], [] 88 | for i in dataloader: 89 | inp_data,labels = i 90 | inp_data = inp_data.to(device) 91 | features = model(inp_data) 92 | if features.ndim > 2: 93 | features = features.flatten(start_dim=1) 94 | features = features.cpu().detach() 95 | labels = labels.cpu().detach() 96 | feats.append(features) 97 | labs.append(labels) 98 | f = torch.cat(feats, dim=0) 99 | l = torch.cat(labs, dim=0) 100 | return f,l 101 | 102 | 103 | def train_one_epoch(model, optimizer, train_loader, device): 104 | model.train() 105 | criterion = nn.CrossEntropyLoss() 106 | loss_step = [] 107 | correct, total = 0, 0 108 | for data in train_loader: 109 | # Move the data to the GPU 110 | inp_data,labels = data 111 | inp_data = inp_data.to(device) 112 | labels = labels.to(device) 113 | outputs = model(inp_data) 114 | loss = criterion(outputs, labels) 115 | 
optimizer.zero_grad() 116 | loss.backward() 117 | optimizer.step() 118 | with torch.no_grad(): 119 | _, predicted = torch.max(outputs, 1) 120 | total += labels.size(0) 121 | correct += (predicted == labels).sum() 122 | loss_step.append(loss.item()) 123 | # dont forget the means here 124 | loss_curr_epoch = np.mean(loss_step) 125 | train_acc = (100 * correct / total).cpu() 126 | return loss_curr_epoch, train_acc 127 | 128 | 129 | def linear_eval(model, optimizer, num_epochs, train_loader, val_loader, device): 130 | best_val_loss = 1e8 131 | best_val_acc = 0 132 | model = model.to(device) 133 | dict_log = {"train_acc_epoch":[], "val_acc_epoch":[], "loss_epoch":[], "val_loss":[]} 134 | train_acc, _ = validate(model, train_loader, device) 135 | val_acc, _ = validate(model, val_loader, device) 136 | print(f'Init Accuracy of the model: Train:{train_acc:.3f} \t Val:{val_acc:3f}') 137 | pbar = tqdm(range(num_epochs)) 138 | for epoch in pbar: 139 | loss_curr_epoch, train_acc = train_one_epoch(model, optimizer, train_loader, device) 140 | val_acc, val_loss = validate(model, val_loader, device) 141 | 142 | # Print epoch results to screen 143 | msg = (f'Ep {epoch}/{num_epochs}: Accuracy : Train:{train_acc:.2f} \t Val:{val_acc:.2f} || Loss: Train {loss_curr_epoch:.3f} \t Val {val_loss:.3f}') 144 | pbar.set_description(msg) 145 | # Track stats 146 | dict_log["train_acc_epoch"].append(train_acc) 147 | dict_log["val_acc_epoch"].append(val_acc) 148 | dict_log["loss_epoch"].append(loss_curr_epoch) 149 | dict_log["val_loss"].append(val_loss) 150 | 151 | if val_loss < best_val_loss: 152 | best_val_loss = val_loss 153 | torch.save({ 154 | 'epoch': epoch, 155 | 'model_state_dict': model.state_dict(), 156 | 'optimizer_state_dict': optimizer.state_dict(), 157 | 'loss': val_loss, 158 | }, f'best_model_min_val_loss.pth') 159 | 160 | if val_acc > best_val_acc: 161 | best_val_acc = val_acc 162 | torch.save({ 163 | 'epoch': epoch, 164 | 'model_state_dict': model.state_dict(), 165 | 'optimizer_state_dict': optimizer.state_dict(), 166 | 'loss': val_loss, 167 | }, f'best_model_max_val_acc.pth') 168 | return dict_log 169 | 170 | 171 | def load_model(model, path): 172 | checkpoint = torch.load(path) 173 | model.load_state_dict(checkpoint['model_state_dict']) 174 | print(f"Model {path} is loaded from epoch {checkpoint['epoch']} , loss {checkpoint['loss']}") 175 | return model 176 | 177 | 178 | def default(val, def_val): 179 | return def_val if val is None else val 180 | 181 | def reproducibility(SEED): 182 | torch.manual_seed(SEED) 183 | torch.backends.cudnn.deterministic = True 184 | torch.backends.cudnn.benchmark = False 185 | np.random.seed(SEED) 186 | if torch.cuda.is_available(): 187 | torch.cuda.manual_seed(SEED) 188 | 189 | def define_param_groups(model, weight_decay, optimizer_name): 190 | def exclude_from_wd_and_adaptation(name): 191 | if 'bn' in name: 192 | return True 193 | if optimizer_name == 'lars' and 'bias' in name: 194 | return True 195 | 196 | param_groups = [ 197 | { 198 | 'params': [p for name, p in model.named_parameters() if not exclude_from_wd_and_adaptation(name)], 199 | 'weight_decay': weight_decay, 200 | 'layer_adaptation': True, 201 | }, 202 | { 203 | 'params': [p for name, p in model.named_parameters() if exclude_from_wd_and_adaptation(name)], 204 | 'weight_decay': 0., 205 | 'layer_adaptation': False, 206 | }, 207 | ] 208 | return param_groups 209 | 210 | 211 | def tsne_plot_embeddings(features, labels, class_names, title="T-SNE plot"): 212 | plt.figure(figsize=(12, 12)) 213 | latent_space_tsne = 
TSNE(2, verbose = True, n_iter = 2000, metric="cosine", perplexity=50, learning_rate=500) 214 | xa_tsne = latent_space_tsne.fit_transform(features.cpu().numpy()[:, :]) 215 | colors = plt.rcParams["axes.prop_cycle"]() 216 | for class_idx in range(len(class_names)): 217 | c = next(colors)["color"] 218 | plt.scatter(xa_tsne[:,0][labels==class_idx], xa_tsne[:,1][labels==class_idx], color=c, label=class_names[class_idx]) 219 | 220 | plt.legend(class_names, fontsize=18, loc='center left', bbox_to_anchor=(1.05, 0.5)) 221 | if title is not None: 222 | plt.title(title, fontsize=18) 223 | 224 | plt.gca().axes.get_yaxis().set_visible(False) 225 | plt.gca().axes.get_xaxis().set_visible(False) 226 | plt.savefig("tsne_plot_embeddings_solution.png") 227 | plt.show() 228 | 229 | # based on https://github.com/elad-amrani/self-classifier/blob/e5e3fb98d71bd6961031bbd308826017fd9753ec/src/cls_eval.py 230 | def compute_clustering_metrics(targets, preds, min_samples_per_class, verbose=False): 231 | from sklearn.metrics import normalized_mutual_info_score as nmi 232 | from sklearn.metrics import adjusted_mutual_info_score as adjusted_nmi 233 | from sklearn.metrics import adjusted_rand_score as adjusted_rand_index 234 | from scipy.optimize import linear_sum_assignment 235 | val_nmi = nmi(targets, preds) 236 | val_adjusted_nmi = adjusted_nmi(targets, preds) 237 | val_adjusted_rand_index = adjusted_rand_index(targets, preds) 238 | 239 | # compute accuracy 240 | num_classes = max(targets.max(), preds.max()) + 1 241 | count_matrix = np.zeros((num_classes, num_classes), dtype=np.int32) 242 | for ii in range(preds.shape[0]): 243 | count_matrix[preds[ii], targets[ii]] += 1 244 | reassignment = np.dstack(linear_sum_assignment(count_matrix.max() - count_matrix))[0] 245 | 246 | if len(np.unique(preds)) > len(np.unique(targets)): # if using over-clustering, append remaining clusters to best option 247 | for cls_idx in np.unique(preds): 248 | if reassignment[cls_idx, 1] not in targets: 249 | reassignment[cls_idx, 1] = count_matrix[cls_idx].argmax() 250 | 251 | acc = count_matrix[reassignment[:, 0], reassignment[:, 1]].sum().astype(np.float32) / preds.shape[0] 252 | 253 | # extract max accuracy classes 254 | num_samples_per_class = count_matrix[reassignment[:, 0], :].sum(axis=1) 255 | acc_per_class = np.where(num_samples_per_class >= min_samples_per_class, 256 | count_matrix[reassignment[:, 0], reassignment[:, 1]] / num_samples_per_class, 0) 257 | max_acc_classes = np.argsort(acc_per_class)[::-1] 258 | acc_per_class = acc_per_class[max_acc_classes] 259 | num_samples_per_class = num_samples_per_class[max_acc_classes] 260 | if verbose: 261 | print('=> number of samples: {}'.format(len(targets))) 262 | print('=> number of unique assignments: {}'.format(len(set(preds)))) 263 | print('=> NMI: {:.3f}%'.format(val_nmi * 100.0)) 264 | print('=> Adjusted NMI: {:.3f}%'.format(val_adjusted_nmi * 100.0)) 265 | print('=> Adjusted Rand-Index: {:.3f}%'.format(val_adjusted_rand_index * 100.0)) 266 | print('=> Accuracy: {:.3f}%'.format(acc * 100.0)) 267 | 268 | return acc * 100.0, val_nmi * 100.0, val_adjusted_nmi * 100.0, val_adjusted_rand_index * 100.0, acc_per_class 269 | 270 | 271 | class SST2Model(nn.Module): 272 | def __init__(self, bert_encoder, train_encoder=True): 273 | """ 274 | Args: 275 | bert_encoder: An instance of a BERTEncoder 276 | train_encoder: wheter the encoder should be trained or not. 
277 | """ 278 | super().__init__() 279 | self.bert_encoder = bert_encoder 280 | for param in self.bert_encoder.parameters(): 281 | param.requires_grad = train_encoder 282 | self.classifier = nn.Linear(bert_encoder.d_model, 1, bias=False) 283 | 284 | def forward(self, input_ids): 285 | """ 286 | Predicts the sentiment of a sentence (positive or negative) 287 | Args: 288 | input_ids: tensor of shape (batch_size, seq_len) containing the token ids of the sentences 289 | Returns: 290 | tensor of shape (batch_size) containing the predicted sentiment 291 | """ 292 | h = self.bert_encoder(input_ids) 293 | return self.classifier(h[:, 0]).view(-1) 294 | -------------------------------------------------------------------------------- /exercises/week06_distillation/figs/ViT-tiny_CIFAR100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHU-MMBS/RepresentationLearning_SS2023/9238c4c7648f8d64dd0461e828652cd3606f36a1/exercises/week06_distillation/figs/ViT-tiny_CIFAR100.png -------------------------------------------------------------------------------- /exercises/week06_distillation/figs/distilled_ViT-tiny_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HHU-MMBS/RepresentationLearning_SS2023/9238c4c7648f8d64dd0461e828652cd3606f36a1/exercises/week06_distillation/figs/distilled_ViT-tiny_.png -------------------------------------------------------------------------------- /exercises/week06_distillation/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader, Dataset 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | import torch.optim as optim 11 | import torch.utils.data as data 12 | import random 13 | import matplotlib.pyplot as plt 14 | from torchvision import transforms as T 15 | from tqdm import tqdm 16 | from sklearn.manifold import TSNE 17 | 18 | def load_data( batch_size=128, train_split="unlabeled", test_split="test", transf = T.ToTensor()): 19 | train_ds = torchvision.datasets.STL10(root='../data', split=train_split, transform=transf, download=True) 20 | val_ds = torchvision.datasets.STL10(root='../data', split=test_split, transform=transf, download=True) 21 | train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, num_workers=2, drop_last=True) 22 | val_dl = DataLoader(dataset=val_ds, batch_size=batch_size, num_workers=2, drop_last=True) 23 | return train_dl, val_dl 24 | 25 | def imshow(img, mean=torch.tensor([0.0], dtype=torch.float32), std=torch.tensor([1], dtype=torch.float32)): 26 | """ 27 | shows an image on the screen. 
mean of 0 and variance of 1 will show the images unchanged in the screen 28 | """ 29 | # undoes the normalization 30 | unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist()) 31 | npimg = unnormalize(img).numpy() 32 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 33 | 34 | 35 | def prevalidate(model, val_loader,criterion, device): 36 | ### START CODE HERE ### (≈ 12 lines of code) 37 | model.eval() 38 | correct, total = 0, 0 39 | loss_step = [] 40 | with torch.no_grad(): 41 | for data in val_loader: 42 | inp_data,labels = data 43 | inp_data, labels = inp_data.to(device), labels.to(device) 44 | outputs = model(inp_data) 45 | val_loss = criterion(outputs, labels) 46 | loss_step.append(val_loss.item()) 47 | # dont forget to take the means here 48 | val_loss_epoch = np.mean(loss_step) 49 | ### END CODE HERE ### 50 | return val_loss_epoch 51 | 52 | 53 | def save_model(model, path, epoch, optimizer, val_loss): 54 | torch.save({ 55 | 'epoch': epoch, 56 | 'model_state_dict': model.state_dict(), 57 | 'optimizer_state_dict': optimizer.state_dict(), 58 | 'loss': val_loss, 59 | }, path) 60 | 61 | 62 | def validate(model, val_loader, device): 63 | model.eval() 64 | criterion = nn.CrossEntropyLoss() 65 | correct, total = 0, 0 66 | loss_step = [] 67 | with torch.no_grad(): 68 | for data in val_loader: 69 | inp_data,labels = data 70 | inp_data = inp_data.to(device) 71 | labels = labels.to(device) 72 | outputs = model(inp_data) 73 | val_loss = criterion(outputs, labels) 74 | predicted = torch.max(outputs, 1)[1] 75 | total += labels.size(0) 76 | correct += (predicted == labels).sum() 77 | loss_step.append(val_loss.item()) 78 | # dont forget to take the means here 79 | val_acc = (100 * correct / total).cpu().numpy() 80 | val_loss_epoch = torch.tensor(loss_step).mean().numpy() 81 | return val_acc , val_loss_epoch 82 | 83 | @torch.no_grad() 84 | def get_features(model, dataloader, device): 85 | model = model.to(device) 86 | feats, labs = [], [] 87 | for i in dataloader: 88 | inp_data,labels = i 89 | inp_data = inp_data.to(device) 90 | features = model(inp_data) 91 | features = features.cpu().detach().flatten(start_dim=1) 92 | labels = labels.cpu().detach() 93 | feats.append(features) 94 | labs.append(labels) 95 | f = torch.cat(feats, dim=0) 96 | l = torch.cat(labs, dim=0) 97 | return f,l 98 | 99 | 100 | def train_one_epoch(model, optimizer, train_loader, device): 101 | model.train() 102 | criterion = nn.CrossEntropyLoss() 103 | loss_step = [] 104 | correct, total = 0, 0 105 | for data in train_loader: 106 | # Move the data to the GPU 107 | inp_data,labels = data 108 | inp_data = inp_data.to(device) 109 | labels = labels.to(device) 110 | outputs = model(inp_data) 111 | loss = criterion(outputs, labels) 112 | optimizer.zero_grad() 113 | loss.backward() 114 | optimizer.step() 115 | with torch.no_grad(): 116 | _, predicted = torch.max(outputs, 1) 117 | total += labels.size(0) 118 | correct += (predicted == labels).sum() 119 | loss_step.append(loss.item()) 120 | # dont forget the means here 121 | loss_curr_epoch = np.mean(loss_step) 122 | train_acc = (100 * correct / total).cpu() 123 | return loss_curr_epoch, train_acc 124 | 125 | 126 | def finetune(model, optimizer, num_epochs, train_loader, val_loader, device, prefix='model'): 127 | best_val_loss = 1e8 128 | best_val_acc = 0 129 | model = model.to(device) 130 | dict_log = {"train_acc_epoch":[], "val_acc_epoch":[], "loss_epoch":[], "val_loss":[]} 131 | pbar = tqdm(range(num_epochs)) 132 | for epoch in pbar: 133 | loss_curr_epoch, train_acc = 
train_one_epoch(model, optimizer, train_loader, device) 134 | val_acc, val_loss = validate(model, val_loader, device) 135 | 136 | # Print epoch results to screen 137 | msg = (f'Ep {epoch}/{num_epochs}: Accuracy : Train:{train_acc:.2f} \t Val:{val_acc:.2f} || Loss: Train {loss_curr_epoch:.3f} \t Val {val_loss:.3f}') 138 | pbar.set_description(msg) 139 | # Track stats 140 | dict_log["train_acc_epoch"].append(train_acc) 141 | dict_log["val_acc_epoch"].append(val_acc) 142 | dict_log["loss_epoch"].append(loss_curr_epoch) 143 | dict_log["val_loss"].append(val_loss) 144 | 145 | if val_loss < best_val_loss: 146 | best_val_loss = val_loss 147 | torch.save({ 148 | 'epoch': epoch, 149 | 'model_state_dict': model.state_dict(), 150 | 'optimizer_state_dict': optimizer.state_dict(), 151 | 'loss': val_loss, 152 | }, f'{prefix}_best_model_min_val_loss.pth') 153 | 154 | if val_acc > best_val_acc: 155 | best_val_acc = val_acc 156 | torch.save({ 157 | 'epoch': epoch, 158 | 'model_state_dict': model.state_dict(), 159 | 'optimizer_state_dict': optimizer.state_dict(), 160 | 'loss': val_loss, 161 | }, f'{prefix}_best_model_max_val_acc.pth') 162 | return dict_log 163 | 164 | 165 | def load_model(model, path): 166 | checkpoint = torch.load(path) 167 | model.load_state_dict(checkpoint['model_state_dict']) 168 | print(f"Model {path} is loaded from epoch {checkpoint['epoch']} , loss {checkpoint['loss']}") 169 | return model 170 | 171 | 172 | def default(val, def_val): 173 | return def_val if val is None else val 174 | 175 | def reproducibility(SEED): 176 | torch.manual_seed(SEED) 177 | torch.backends.cudnn.deterministic = True 178 | torch.backends.cudnn.benchmark = False 179 | np.random.seed(SEED) 180 | if torch.cuda.is_available(): 181 | torch.cuda.manual_seed(SEED) 182 | 183 | def define_param_groups(model, weight_decay, optimizer_name): 184 | def exclude_from_wd_and_adaptation(name): 185 | if 'bn' in name: 186 | return True 187 | if optimizer_name == 'lars' and 'bias' in name: 188 | return True 189 | 190 | param_groups = [ 191 | { 192 | 'params': [p for name, p in model.named_parameters() if not exclude_from_wd_and_adaptation(name)], 193 | 'weight_decay': weight_decay, 194 | 'layer_adaptation': True, 195 | }, 196 | { 197 | 'params': [p for name, p in model.named_parameters() if exclude_from_wd_and_adaptation(name)], 198 | 'weight_decay': 0., 199 | 'layer_adaptation': False, 200 | }, 201 | ] 202 | return param_groups 203 | 204 | 205 | def tsne_plot_embeddings(features, labels, class_names, title="T-SNE plot"): 206 | plt.figure(figsize=(12, 12)) 207 | latent_space_tsne = TSNE(2, verbose = True, n_iter = 2000, metric="cosine", perplexity=50, learning_rate=500) 208 | xa_tsne = latent_space_tsne.fit_transform(features.cpu().numpy()[:, :]) 209 | colors = plt.rcParams["axes.prop_cycle"]() 210 | for class_idx in range(len(class_names)): 211 | c = next(colors)["color"] 212 | plt.scatter(xa_tsne[:,0][labels==class_idx], xa_tsne[:,1][labels==class_idx], color=c, label=class_names[class_idx]) 213 | 214 | plt.legend(class_names, fontsize=18, loc='center left', bbox_to_anchor=(1.05, 0.5)) 215 | if title is not None: 216 | plt.title(title, fontsize=18) 217 | 218 | plt.gca().axes.get_yaxis().set_visible(False) 219 | plt.gca().axes.get_xaxis().set_visible(False) 220 | plt.savefig("tsne_plot_embeddings_solution.png") 221 | plt.show() 222 | 223 | # based on https://github.com/elad-amrani/self-classifier/blob/e5e3fb98d71bd6961031bbd308826017fd9753ec/src/cls_eval.py 224 | def compute_clustering_metrics(targets, preds, 
min_samples_per_class, verbose=False): 225 | from sklearn.metrics import normalized_mutual_info_score as nmi 226 | from sklearn.metrics import adjusted_mutual_info_score as adjusted_nmi 227 | from sklearn.metrics import adjusted_rand_score as adjusted_rand_index 228 | from scipy.optimize import linear_sum_assignment 229 | val_nmi = nmi(targets, preds) 230 | val_adjusted_nmi = adjusted_nmi(targets, preds) 231 | val_adjusted_rand_index = adjusted_rand_index(targets, preds) 232 | 233 | # compute accuracy 234 | num_classes = max(targets.max(), preds.max()) + 1 235 | count_matrix = np.zeros((num_classes, num_classes), dtype=np.int32) 236 | for ii in range(preds.shape[0]): 237 | count_matrix[preds[ii], targets[ii]] += 1 238 | reassignment = np.dstack(linear_sum_assignment(count_matrix.max() - count_matrix))[0] 239 | 240 | if len(np.unique(preds)) > len(np.unique(targets)): # if using over-clustering, append remaining clusters to best option 241 | for cls_idx in np.unique(preds): 242 | if reassignment[cls_idx, 1] not in targets: 243 | reassignment[cls_idx, 1] = count_matrix[cls_idx].argmax() 244 | 245 | acc = count_matrix[reassignment[:, 0], reassignment[:, 1]].sum().astype(np.float32) / preds.shape[0] 246 | 247 | 248 | # extract max accuracy classes 249 | num_samples_per_class = count_matrix[reassignment[:, 0], :].sum(axis=1) 250 | acc_per_class = np.where(num_samples_per_class >= min_samples_per_class, 251 | count_matrix[reassignment[:, 0], reassignment[:, 1]] / num_samples_per_class, 0) 252 | max_acc_classes = np.argsort(acc_per_class)[::-1] 253 | acc_per_class = acc_per_class[max_acc_classes] 254 | num_samples_per_class = num_samples_per_class[max_acc_classes] 255 | if verbose: 256 | print('=> number of samples: {}'.format(len(targets))) 257 | print('=> number of unique assignments: {}'.format(len(set(preds)))) 258 | print('=> NMI: {:.3f}%'.format(val_nmi * 100.0)) 259 | print('=> Adjusted NMI: {:.3f}%'.format(val_adjusted_nmi * 100.0)) 260 | print('=> Adjusted Rand-Index: {:.3f}%'.format(val_adjusted_rand_index * 100.0)) 261 | print('=> Accuracy: {:.3f}%'.format(acc * 100.0)) 262 | 263 | return acc * 100.0, val_nmi * 100.0, val_adjusted_nmi * 100.0, val_adjusted_rand_index * 100.0, acc_per_class 264 | 265 | 266 | def plot_stats(dict_log, modelname="", baseline=None, title=None): 267 | fontsize = 14 268 | plt.subplots_adjust(hspace=0.3) 269 | plt.subplot(2,1,1) 270 | x_axis = list(range(len(dict_log["val_acc_epoch"]))) 271 | plt.plot(dict_log["train_acc_epoch"], label=f'{modelname} Train accuracy') 272 | plt.scatter(x_axis, dict_log["train_acc_epoch"]) 273 | plt.plot(dict_log["val_acc_epoch"], label=f'{modelname} Validation accuracy') 274 | plt.scatter(x_axis, dict_log["val_acc_epoch"]) 275 | plt.ylabel('Accuracy in %') 276 | plt.xlabel('Number of Epochs') 277 | plt.title("Accuracy over epochs", fontsize=fontsize) 278 | if baseline is not None: 279 | plt.axhline(y=baseline, color='red', label="Acceptable accuracy") 280 | plt.legend(fontsize=fontsize) 281 | plt.subplot(2,1,2) 282 | plt.plot(dict_log["loss_epoch"] , label="Training") 283 | plt.scatter(x_axis, dict_log["loss_epoch"], ) 284 | plt.plot(dict_log["val_loss"] , label='Validation') 285 | plt.scatter(x_axis, dict_log["val_loss"]) 286 | plt.ylabel('Loss value') 287 | plt.xlabel('Number of Epochs') 288 | plt.title("Loss over epochs", fontsize=fontsize) 289 | plt.legend(fontsize=fontsize) 290 | if title is not None: 291 | plt.savefig(title) -------------------------------------------------------------------------------- 
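The training and evaluation helpers above (`load_data`, `finetune`, `plot_stats`, `load_model`) form a small fine-tuning API that the distillation notebook builds on. A minimal usage sketch of how they compose; the ResNet-18 stand-in model, the epoch count and the learning rate below are illustrative assumptions, not the settings used in the exercise:

```python
# Illustrative sketch only: a torchvision ResNet-18 stands in for the ViT
# student of the exercise; epochs and learning rate are placeholder values.
import torch
import torchvision

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Labeled STL10 splits via the loader defined above (its default split is "unlabeled")
train_dl, val_dl = load_data(batch_size=128, train_split="train", test_split="test")

model = torchvision.models.resnet18(num_classes=10)  # STL10 has 10 classes
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train, checkpoint the best val-loss/val-accuracy models, then save the logged curves
dict_log = finetune(model, optimizer, 5, train_dl, val_dl, device, prefix='resnet18_stl10')
plot_stats(dict_log, modelname="ResNet-18", title="resnet18_stl10_curves.png")
model = load_model(model, 'resnet18_stl10_best_model_max_val_acc.pth')
```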
/exercises/week09_10_CLIP/A09_CLIP_OOD.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "CCfZVwhwMIYN" 7 | }, 8 | "source": [ 9 | "HHU Deep Learning, SS2022/23, 09.06.2023, Prof. Dr. Markus Kollmann\n", 10 | "\n", 11 | "Lectures and tutoring are given by Tim Kaiser, Nikolas Adaloglou and Felix Michels.\n", 12 | "\n", 13 | "# Assignment 10 - Contrastive Language-Image Pre-training for unsupervised out-of-distribution detection\n", 14 | "\n", 15 | "Copyright © 2023 Nikolas Adaloglou, Tim Kaiser and Felix Michels\n", 16 | "\n", 17 | "\n", 18 | "## Contents\n", 19 | "\n", 20 | "1. Basic imports\n", 21 | "2. Get the visual features of the CLIP model\n", 22 | "3. Compute the k-NN similarity as the OOD score\n", 23 | "4. Compute MSP using the text encoder and the label names\n", 24 | "5. Linear probing on the pseudolabels\n", 25 | "6. Mahalanobis distance as OOD score\n", 26 | "7. Mahalanobis distance using the real labels without linear probing\n", 27 | "8. K-means clusters combined with Mahalanobis distance" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "---\n", 35 | "\n", 36 | "## Overview\n", 37 | "We will apply the learned representations of Contrastive Language-Image Pretraining (CLIP) to the downstream task of out-of-distribution detection.\n", 38 | "\n", 39 | "`Note`: We use the pretrained models from open_clip_torch; you can install it with `!pip install open_clip_torch`\n", 40 | "\n", 41 | "We will be using the model 'convnext_base_w' pretrained on 'laion2b_s13b_b82k' throughout this tutorial.\n", 42 | "\n", 43 | "Info and examples on how to use CLIP models for inference are provided in [openclip](https://github.com/mlfoundations/open_clip#usage)\n", 44 | "\n", 45 | "- [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)\n", 46 | "- [Contrastive Language-Image Pretrained (CLIP) Models are Powerful Out-of-Distribution Detectors](https://arxiv.org/abs/2303.05828)\n", 47 | "\n", 48 | "\n", 49 | "\n", 50 | "# Part I. 
Basic imports" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "import os\n", 60 | "from pathlib import Path\n", 61 | "from tqdm import tqdm\n", 62 | "import numpy as np\n", 63 | "from sklearn.cluster import KMeans\n", 64 | "import matplotlib.pyplot as plt\n", 65 | "from sklearn.metrics import roc_auc_score\n", 66 | "\n", 67 | "import torch\n", 68 | "import open_clip\n", 69 | "from torch import nn\n", 70 | "from torch.nn import functional as F\n", 71 | "import torchvision\n", 72 | "import torchvision.transforms as T\n", 73 | "from torch.utils.data import Subset, DataLoader, Dataset\n", 74 | "\n", 75 | "out_dir = Path('./features/').resolve()\n", 76 | "out_dir.mkdir(parents=True, exist_ok=True)\n", 77 | "# Local import\n", 78 | "from utils import *\n", 79 | "\n", 80 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 81 | "\n", 82 | "# Helper function\n", 83 | "def auroc_score(score_in, score_out):\n", 84 | " if type(score_in) == torch.Tensor:\n", 85 | " score_in = score_in.cpu().numpy()\n", 86 | " score_out = score_out.cpu().numpy()\n", 87 | " labels = np.concatenate((np.ones_like(score_in), np.zeros_like(score_out)))\n", 88 | " return roc_auc_score(labels, np.concatenate((score_in, score_out))) * 100" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "# Part II. Get the visual features of the CLIP model\n", 96 | "\n", 97 | "- We will use `CIFAR100` as the in-distribution, and `CIFAR10` as the out-distribution.\n", 98 | "- When you are only loading the visual CLIP backbone, you must remove the final linear layer that projects the features to the shared feature space of the image-text encoder.\n", 99 | "- Load the data, compute the visual features and save them in the `features` folder.\n", 100 | "- For the in-distribution you need both the train and test split, while for the out-distribution, we will only use the validation split.\n", 101 | "\n", 102 | "\n", 103 | "### Optional structure\n", 104 | "\n", 105 | "```python\n", 106 | "def load_datasets(indist=\"CIFAR100\", ood=\"CIFAR10\", batch_size=256, tranform=None):\n", 107 | " # ....\n", 108 | " return indist_train_loader, indist_test_loader, ood_loader\n", 109 | "\n", 110 | "# visual is a boolean that controls whether the visual backbone is only returned or the whole CLIP model\n", 111 | "def get_model(visual, name, pretrained)\n", 112 | " # .....\n", 113 | " if visual:\n", 114 | " return backbone, preprocess\n", 115 | " return model, preprocess, tokenizer\n", 116 | " \n", 117 | "# Load everything .......\n", 118 | "\n", 119 | "feats, labels = get_features(backbone, dl, device)\n", 120 | "# Save features \n", 121 | "# ....\n", 122 | "```" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "### START CODE HERE ### (≈ 31 lines of code)\n", 132 | "\n", 133 | "### END CODE HERE ###\n", 134 | "\n", 135 | "# feature test\n", 136 | "for name, N in [('cifar100_train', 50000), ('cifar100_test', 10000), ('cifar10_test', 10000)]:\n", 137 | " feats = torch.load(f'features/{name}_feats.pt')\n", 138 | " labels = torch.load(f'features/{name}_labels.pt')\n", 139 | " assert feats.shape == (N, 1024)\n", 140 | " assert labels.shape == (N,)\n", 141 | "print('Success!')" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "# Part III. 
Compute the k-NN similarity as the OOD score\n", 149 | "\n", 150 | "- For each test image of in and out distribution compute the top-1 cosine similarity and use it as OOD score.\n", 151 | "- Report the resulting AUROC score.\n", 152 | "- Note: Use the image features and not the images!" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "@torch.no_grad()\n", 162 | "def OOD_classifier_knn(train_features, test_features, k=1):\n", 163 | " ### START CODE HERE ### (≈ 13 lines of code)\n", 164 | " \n", 165 | " ### END CODE HERE ###\n", 166 | " return cos_sim \n", 167 | "\n", 168 | "# load the computed features and compute scores\n", 169 | "### START CODE HERE ### (≈ 5 lines of code)\n", 170 | "indist_train = ...\n", 171 | "indist_test = ...\n", 172 | "ood_test = ...\n", 173 | "score_in = ...\n", 174 | "score_out = ...\n", 175 | "### END CODE HERE ###\n", 176 | "print(f'CIFAR100-->CIFAR10 AUROC: {auroc_score(score_in, score_out):.2f}')" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "### Expected result\n", 184 | "\n", 185 | "```\n", 186 | "CIFAR100-->CIFAR10 AUROC: 83.55\n", 187 | "```\n" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "# Part IV. Compute MSP using the text encoder and the label names\n", 195 | "\n", 196 | "We will now consider the case where the in-distribution label names are available.\n", 197 | "\n", 198 | "Your task is to apply zero-shot classification and get the maximum softmax probability (MSP) as the OOD score.\n", 199 | "\n", 200 | "In short:\n", 201 | "- compute image and text embeddings\n", 202 | "- compute the image-test similarity matrix (logits)\n", 203 | "- apply softmax to the logits for each image to get a probability distribution of the classes.\n", 204 | "- compute maximum softmax probability (MSP)\n", 205 | "\n", 206 | "- `Note`: After loading the saved image features you need to apply the linear projection layer from the visual backbone of CLIP" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "### optional to use\n", 216 | "def compute_logits(model, text_embs, img_embs, device):\n", 217 | " ### START CODE HERE ### (≈ 5 lines of code)\n", 218 | " \n", 219 | " ### END CODE HERE ###\n", 220 | " return logits\n", 221 | "\n", 222 | "### optional to use\n", 223 | "def compute_text_embeds(model, class_tokens):\n", 224 | " ### START CODE HERE ### (≈ 3 lines of code)\n", 225 | " \n", 226 | " ### END CODE HERE ###\n", 227 | " return text_embs\n", 228 | "\n", 229 | "def compute_msp(label_names, model, class_tokens, indist_test, ood_test, device):\n", 230 | " ### START CODE HERE ### (≈ 7 lines of code)\n", 231 | " \n", 232 | " ### END CODE HERE ###\n", 233 | " return score_in, score_out\n", 234 | " \n", 235 | "\n", 236 | "# Load model and features\n", 237 | "### START CODE HERE ### (≈ 4 lines of code)\n", 238 | "indist_test = ...\n", 239 | "ood_test = ...\n", 240 | "model, preprocess, tokenizer = ...\n", 241 | "model = ...\n", 242 | "### END CODE HERE ###\n", 243 | "\n", 244 | "### Provided \n", 245 | "label_names = torchvision.datasets.CIFAR100(root='../data', train=True, download=True).classes\n", 246 | "prompts = ['an image of a ' + lab.replace('_', ' ') for lab in label_names]\n", 247 | "class_tokens = tokenizer(label_names).to(device)\n", 248 | "score_in, 
score_out = compute_msp(label_names, model, class_tokens, indist_test, ood_test, device)\n", 249 | "print(f'CIFAR100-->CIFAR10 AUROC: {auroc_score(score_in, score_out):.2f}')" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "### Expected result\n", 257 | "\n", 258 | "```\n", 259 | "CIFAR100-->CIFAR10 AUROC: 76.38\n", 260 | "```\n" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "# Part V. Linear probing on the pseudolabels\n", 268 | "\n", 269 | "- Your task is to train a linear layer using the CLIP pseudolabels as targets.\n", 270 | "- The pseudolabels are the argmax of the logits computed above, i.e., take the class with the maximum probability as the class label" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "def compute_score_probe_msp(lin_layer, indist_loader, ood_loader, device):\n", 280 | " \"\"\"\n", 281 | " Computes the MSP scores for a linear layer for both in- and out- distribution.\n", 282 | " \"\"\"\n", 283 | " ### START CODE HERE ### (≈ 4 lines of code)\n", 284 | " \n", 285 | " ### END CODE HERE ###\n", 286 | " return score_in, score_out\n", 287 | "\n", 288 | "### START CODE HERE ### (≈ 17 lines of code) \n", 289 | "# get CLIP model\n", 290 | "# get text embeds from label names\n", 291 | "\n", 292 | "# load features\n", 293 | "\n", 294 | "# compute CLIP logits of image features based on text encoder\n", 295 | "\n", 296 | "# get target pseudo labels from CLIP logits\n", 297 | "\n", 298 | "# create dataset and dataloaders for linear probing\n", 299 | "\n", 300 | "### END CODE HERE ###\n", 301 | "\n", 302 | "# The code below is provided based on our implementation. Optional to use!\n", 303 | "# Run linear probing\n", 304 | "embed_dim = train_dataset[0][0].shape[0]\n", 305 | "lin_layer = nn.Linear(embed_dim, 100).to(device)\n", 306 | "optimizer = torch.optim.Adam(lin_layer.parameters(), lr=1e-3)\n", 307 | "num_epochs = 20\n", 308 | "dict_log = linear_eval(lin_layer, optimizer, num_epochs, train_loader, val_loader, device)\n", 309 | "# compute MSP scores\n", 310 | "lin_layer = load_model(lin_layer, \"CLIP_best_max_train_acc.pth\")\n", 311 | "ood_dataset = torch.utils.data.TensorDataset(ood_test, torch.zeros(ood_test.shape[0], dtype=torch.long))\n", 312 | "ood_loader = torch.utils.data.DataLoader(ood_dataset, batch_size=128, shuffle=False, drop_last=False)\n", 313 | "score_in, score_out = compute_score_probe_msp(lin_layer, val_loader, ood_loader, device)\n", 314 | "print(f'CIFAR100-->CIFAR10 AUROC: {auroc_score(score_in, score_out):.2f}')" 315 | ] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "metadata": {}, 320 | "source": [ 321 | "### Expected results\n", 322 | "\n", 323 | "AUROC may slightly vary due to random initialization of linear probing.\n", 324 | "```\n", 325 | "CIFAR100-->CIFAR10 AUROC: 74.81\n", 326 | "```" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "# Part VI. Mahalanobis distance as OOD score\n", 334 | "- Use the output of the linear layer from task 4 as features to compute the Mahalanobis distance and the relative Mahalanobis distance.\n", 335 | "- To compute the Mahalanobis distance group the features by their pseudolabels and compute the mean and covariance matrix for each class." 
336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "### optional to use\n", 345 | "@torch.no_grad()\n", 346 | "def calc_maha_distance(embeds, means_c, inv_cov_c):\n", 347 | " ### START CODE HERE ### (≈ 3 lines of code)\n", 348 | " \n", 349 | " ### END CODE HERE ###\n", 350 | " return dist\n", 351 | "\n", 352 | "def OOD_classifier_maha(train_embeds_in, train_labels_in, test_embeds_in, test_embeds_outs, num_classes,\n", 353 | " relative=False):\n", 354 | " # optional to use our code!\n", 355 | " class_covs = []\n", 356 | " class_means = []\n", 357 | " used_classes = 0\n", 358 | " if type(train_labels_in) == torch.Tensor:\n", 359 | " train_labels_in = train_labels_in.cpu().numpy()\n", 360 | " if type(train_embeds_in) == torch.Tensor:\n", 361 | " train_embeds_in = train_embeds_in.cpu().numpy()\n", 362 | " test_embeds_in = test_embeds_in.cpu().numpy()\n", 363 | " test_embeds_outs = test_embeds_outs.cpu().numpy()\n", 364 | " ### START CODE HERE ### (≈ 23 lines of code)\n", 365 | " # calculate class-wise means and covariances\n", 366 | " \n", 367 | " # estimating the global std from train data\n", 368 | " \n", 369 | " # RMD: subtracting the average train score if relative is True\n", 370 | " \n", 371 | " # Get OOD score for each datapoint\n", 372 | " \n", 373 | " ### END CODE HERE ###\n", 374 | " return scores_in, scores_out\n", 375 | "\n", 376 | "# The code below is provided based on our implementation. Optional to use!\n", 377 | "num_classes = 100\n", 378 | "lin_layer = load_model(lin_layer, \"CLIP_best_max_train_acc.pth\")\n", 379 | "logits_indist_train, indist_pseudolabels_train = get_features(lin_layer, train_loader, device)\n", 380 | "logits_indist_test, indist_pseudolabels_test = get_features(lin_layer, val_loader, device)\n", 381 | "logits_ood, _ = get_features(lin_layer, ood_loader, device)\n", 382 | "# convert to numpy\n", 383 | "indist_pseudolabels_train = indist_pseudolabels_train.cpu().numpy()\n", 384 | "indist_pseudolabels_test = indist_pseudolabels_test.cpu().numpy()\n", 385 | "logits_indist_train = logits_indist_train.cpu().numpy()\n", 386 | "logits_indist_test = logits_indist_test.cpu().numpy()\n", 387 | "logits_ood = logits_ood.cpu().numpy()\n", 388 | "\n", 389 | "# run OOD classifier based on mahalanobis distance\n", 390 | "scores_in, scores_out = OOD_classifier_maha(logits_indist_train, indist_pseudolabels_train, \n", 391 | " logits_indist_test, logits_ood, num_classes, relative=False)\n", 392 | "print(f'Maha: CIFAR100-->CIFAR10 AUROC: {auroc_score(scores_in, scores_out):.2f}')\n", 393 | "scores_in, scores_out = OOD_classifier_maha(logits_indist_train, indist_pseudolabels_train, \n", 394 | " logits_indist_test, logits_ood, num_classes, relative=True)\n", 395 | "print(f'Relative Maha: CIFAR100-->CIFAR10 AUROC: {auroc_score(scores_in, scores_out):.2f}') " 396 | ] 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "metadata": {}, 401 | "source": [ 402 | "### Expected results\n", 403 | "(can differ based on linear probing performance)\n", 404 | "\n", 405 | "```\n", 406 | "Maha: CIFAR100-->CIFAR10 AUROC: 83.31\n", 407 | "Relative Maha: CIFAR100-->CIFAR10 AUROC: 80.88\n", 408 | "```" 409 | ] 410 | }, 411 | { 412 | "cell_type": "markdown", 413 | "metadata": {}, 414 | "source": [ 415 | "# Part VII. 
Mahalanobis distance using the real labels without linear probing\n", 416 | "- Again, compute the (relative) Mahalanobis distance as the OOD score\n", 417 | "- This time, instead of using the pseudolabels and output of the linear probing layer, use the real labels of the training data and the features computed in task 1" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": null, 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [ 426 | "### START CODE HERE ### (≈ 7 lines of code)\n", 427 | "# load features\n", 428 | "\n", 429 | "# load labels\n", 430 | "\n", 431 | "# run OOD classifier based on mahalanobis distance\n", 432 | "\n", 433 | "### END CODE HERE ###\n", 434 | "print(f'Maha: CIFAR100-->CIFAR10 AUROC: {auroc_score(scores_md_in, scores_md_out):.2f}')\n", 435 | "print(f'Relative Maha: CIFAR100-->CIFAR10 AUROC: {auroc_score(scores_rmd_in, scores_rmd_out):.2f}')" 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": {}, 441 | "source": [ 442 | "### Expected results\n", 443 | "```\n", 444 | "Maha: CIFAR100-->CIFAR10 AUROC: 71.71\n", 445 | "Relative Maha: CIFAR100-->CIFAR10 AUROC: 84.93\n", 446 | "```" 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": {}, 452 | "source": [ 453 | "# Part VIII. K-means clusters combined with Mahalanobis distance\n", 454 | "\n", 455 | "The paper [SSD: A Unified Framework for Self-Supervised Outlier Detection](https://arxiv.org/abs/2103.12051) proposed another unsupervised method for OOD detection. Instead of computing the class-wise means from the (real or pseudo) labels, we will now use the cluster centers found by KMeans. In more detail:\n", 456 | "\n", 457 | "- Find k=10,50,100 clusters using KMeans on the in-distribution training data (you can use the sklearn KMeans implementation).\n", 458 | "- Get the cluster centers.\n", 459 | "- Use them as class-wise means for the Mahalanobis distance classifier." 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": null, 465 | "metadata": {}, 466 | "outputs": [], 467 | "source": [ 468 | "# The code below is provided based on our implementation. 
Optional to use!\n", 469 | "# load features - modify names if you use different names\n", 470 | "indist_train = torch.load('features/cifar100_train_feats.pt').cpu().numpy()\n", 471 | "indist_test = torch.load('features/cifar100_test_feats.pt').cpu().numpy()\n", 472 | "ood_test = torch.load('features/cifar10_test_feats.pt').cpu().numpy()\n", 473 | "results_md = []\n", 474 | "results_rmd = []\n", 475 | "for N in [10,50,100]:\n", 476 | "    ### START CODE HERE ### (≈ 7 lines of code)\n", 477 | "    \n", 478 | "    ### END CODE HERE ###\n", 479 | "    print(f'Kmeans (k={N}) + RMD: CIFAR100-->CIFAR10 AUROC: {auroc_rmd:.2f}')\n", 480 | "    print(\"-\"*100)" 481 | ] 482 | }, 483 | { 484 | "cell_type": "markdown", 485 | "metadata": {}, 486 | "source": [ 487 | "### Expected results\n", 488 | "Results can differ based on the KMeans initialization.\n", 489 | "```\n", 490 | "Kmeans (k=10) + MD: CIFAR100-->CIFAR10 AUROC: 67.87\n", 491 | "Kmeans (k=10) + RMD: CIFAR100-->CIFAR10 AUROC: 42.38\n", 492 | "----------------------------------------------------------------------------------------------------\n", 493 | "Kmeans (k=50) + MD: CIFAR100-->CIFAR10 AUROC: 72.18\n", 494 | "Kmeans (k=50) + RMD: CIFAR100-->CIFAR10 AUROC: 58.73\n", 495 | "----------------------------------------------------------------------------------------------------\n", 496 | "Kmeans (k=100) + MD: CIFAR100-->CIFAR10 AUROC: 72.84\n", 497 | "Kmeans (k=100) + RMD: CIFAR100-->CIFAR10 AUROC: 68.59\n", 498 | "----------------------------------------------------------------------------------------------------\n", 499 | "```" 500 | ] 501 | }, 502 | { 503 | "cell_type": "markdown", 504 | "metadata": { 505 | "id": "RK4j971u_4yu" 506 | }, 507 | "source": [ 508 | "That's the end of this exercise. If you reached this point, **congratulations**!"
509 | ] 510 | }, 511 | { 512 | "cell_type": "markdown", 513 | "metadata": {}, 514 | "source": [] 515 | } 516 | ], 517 | "metadata": { 518 | "accelerator": "GPU", 519 | "colab": { 520 | "machine_shape": "hm", 521 | "name": "[Exercise 3 solution] - Self-distillation on CIFAR100 .ipynb", 522 | "provenance": [] 523 | }, 524 | "kernelspec": { 525 | "display_name": "Python 3", 526 | "language": "python", 527 | "name": "python3" 528 | }, 529 | "language_info": { 530 | "codemirror_mode": { 531 | "name": "ipython", 532 | "version": 3 533 | }, 534 | "file_extension": ".py", 535 | "mimetype": "text/x-python", 536 | "name": "python", 537 | "nbconvert_exporter": "python", 538 | "pygments_lexer": "ipython3", 539 | "version": "3.8.5" 540 | }, 541 | "vscode": { 542 | "interpreter": { 543 | "hash": "dc5fcf396fe0abd4fa852aee332a0572494dcaf5776820055c87d9b84157f362" 544 | } 545 | }, 546 | "widgets": { 547 | "application/vnd.jupyter.widget-state+json": { 548 | "15451b89dea54127867d368514cfea78": { 549 | "model_module": "@jupyter-widgets/base", 550 | "model_module_version": "1.2.0", 551 | "model_name": "LayoutModel", 552 | "state": { 553 | "_model_module": "@jupyter-widgets/base", 554 | "_model_module_version": "1.2.0", 555 | "_model_name": "LayoutModel", 556 | "_view_count": null, 557 | "_view_module": "@jupyter-widgets/base", 558 | "_view_module_version": "1.2.0", 559 | "_view_name": "LayoutView", 560 | "align_content": null, 561 | "align_items": null, 562 | "align_self": null, 563 | "border": null, 564 | "bottom": null, 565 | "display": null, 566 | "flex": null, 567 | "flex_flow": null, 568 | "grid_area": null, 569 | "grid_auto_columns": null, 570 | "grid_auto_flow": null, 571 | "grid_auto_rows": null, 572 | "grid_column": null, 573 | "grid_gap": null, 574 | "grid_row": null, 575 | "grid_template_areas": null, 576 | "grid_template_columns": null, 577 | "grid_template_rows": null, 578 | "height": null, 579 | "justify_content": null, 580 | "justify_items": null, 581 | "left": null, 582 | "margin": null, 583 | "max_height": null, 584 | "max_width": null, 585 | "min_height": null, 586 | "min_width": null, 587 | "object_fit": null, 588 | "object_position": null, 589 | "order": null, 590 | "overflow": null, 591 | "overflow_x": null, 592 | "overflow_y": null, 593 | "padding": null, 594 | "right": null, 595 | "top": null, 596 | "visibility": null, 597 | "width": null 598 | } 599 | }, 600 | "4d335008de124ef1890de5087a8254f8": { 601 | "model_module": "@jupyter-widgets/controls", 602 | "model_module_version": "1.5.0", 603 | "model_name": "HTMLModel", 604 | "state": { 605 | "_dom_classes": [], 606 | "_model_module": "@jupyter-widgets/controls", 607 | "_model_module_version": "1.5.0", 608 | "_model_name": "HTMLModel", 609 | "_view_count": null, 610 | "_view_module": "@jupyter-widgets/controls", 611 | "_view_module_version": "1.5.0", 612 | "_view_name": "HTMLView", 613 | "description": "", 614 | "description_tooltip": null, 615 | "layout": "IPY_MODEL_15451b89dea54127867d368514cfea78", 616 | "placeholder": "​", 617 | "style": "IPY_MODEL_a1a228b8e820488aa9d7a39326832b43", 618 | "value": "" 619 | } 620 | }, 621 | "5c84f8f441eb4b65a2947a8b0076b78c": { 622 | "model_module": "@jupyter-widgets/base", 623 | "model_module_version": "1.2.0", 624 | "model_name": "LayoutModel", 625 | "state": { 626 | "_model_module": "@jupyter-widgets/base", 627 | "_model_module_version": "1.2.0", 628 | "_model_name": "LayoutModel", 629 | "_view_count": null, 630 | "_view_module": "@jupyter-widgets/base", 631 | "_view_module_version": "1.2.0", 632 | 
"_view_name": "LayoutView", 633 | "align_content": null, 634 | "align_items": null, 635 | "align_self": null, 636 | "border": null, 637 | "bottom": null, 638 | "display": null, 639 | "flex": null, 640 | "flex_flow": null, 641 | "grid_area": null, 642 | "grid_auto_columns": null, 643 | "grid_auto_flow": null, 644 | "grid_auto_rows": null, 645 | "grid_column": null, 646 | "grid_gap": null, 647 | "grid_row": null, 648 | "grid_template_areas": null, 649 | "grid_template_columns": null, 650 | "grid_template_rows": null, 651 | "height": null, 652 | "justify_content": null, 653 | "justify_items": null, 654 | "left": null, 655 | "margin": null, 656 | "max_height": null, 657 | "max_width": null, 658 | "min_height": null, 659 | "min_width": null, 660 | "object_fit": null, 661 | "object_position": null, 662 | "order": null, 663 | "overflow": null, 664 | "overflow_x": null, 665 | "overflow_y": null, 666 | "padding": null, 667 | "right": null, 668 | "top": null, 669 | "visibility": null, 670 | "width": null 671 | } 672 | }, 673 | "947edc38a98549e79c0906847f20560b": { 674 | "model_module": "@jupyter-widgets/controls", 675 | "model_module_version": "1.5.0", 676 | "model_name": "DescriptionStyleModel", 677 | "state": { 678 | "_model_module": "@jupyter-widgets/controls", 679 | "_model_module_version": "1.5.0", 680 | "_model_name": "DescriptionStyleModel", 681 | "_view_count": null, 682 | "_view_module": "@jupyter-widgets/base", 683 | "_view_module_version": "1.2.0", 684 | "_view_name": "StyleView", 685 | "description_width": "" 686 | } 687 | }, 688 | "9bb9f2adc1f4480ba6ae6b390eda4521": { 689 | "model_module": "@jupyter-widgets/base", 690 | "model_module_version": "1.2.0", 691 | "model_name": "LayoutModel", 692 | "state": { 693 | "_model_module": "@jupyter-widgets/base", 694 | "_model_module_version": "1.2.0", 695 | "_model_name": "LayoutModel", 696 | "_view_count": null, 697 | "_view_module": "@jupyter-widgets/base", 698 | "_view_module_version": "1.2.0", 699 | "_view_name": "LayoutView", 700 | "align_content": null, 701 | "align_items": null, 702 | "align_self": null, 703 | "border": null, 704 | "bottom": null, 705 | "display": null, 706 | "flex": null, 707 | "flex_flow": null, 708 | "grid_area": null, 709 | "grid_auto_columns": null, 710 | "grid_auto_flow": null, 711 | "grid_auto_rows": null, 712 | "grid_column": null, 713 | "grid_gap": null, 714 | "grid_row": null, 715 | "grid_template_areas": null, 716 | "grid_template_columns": null, 717 | "grid_template_rows": null, 718 | "height": null, 719 | "justify_content": null, 720 | "justify_items": null, 721 | "left": null, 722 | "margin": null, 723 | "max_height": null, 724 | "max_width": null, 725 | "min_height": null, 726 | "min_width": null, 727 | "object_fit": null, 728 | "object_position": null, 729 | "order": null, 730 | "overflow": null, 731 | "overflow_x": null, 732 | "overflow_y": null, 733 | "padding": null, 734 | "right": null, 735 | "top": null, 736 | "visibility": null, 737 | "width": null 738 | } 739 | }, 740 | "a0db04ae470a41c098cb9b59a67e899b": { 741 | "model_module": "@jupyter-widgets/base", 742 | "model_module_version": "1.2.0", 743 | "model_name": "LayoutModel", 744 | "state": { 745 | "_model_module": "@jupyter-widgets/base", 746 | "_model_module_version": "1.2.0", 747 | "_model_name": "LayoutModel", 748 | "_view_count": null, 749 | "_view_module": "@jupyter-widgets/base", 750 | "_view_module_version": "1.2.0", 751 | "_view_name": "LayoutView", 752 | "align_content": null, 753 | "align_items": null, 754 | "align_self": null, 755 | 
"border": null, 756 | "bottom": null, 757 | "display": null, 758 | "flex": null, 759 | "flex_flow": null, 760 | "grid_area": null, 761 | "grid_auto_columns": null, 762 | "grid_auto_flow": null, 763 | "grid_auto_rows": null, 764 | "grid_column": null, 765 | "grid_gap": null, 766 | "grid_row": null, 767 | "grid_template_areas": null, 768 | "grid_template_columns": null, 769 | "grid_template_rows": null, 770 | "height": null, 771 | "justify_content": null, 772 | "justify_items": null, 773 | "left": null, 774 | "margin": null, 775 | "max_height": null, 776 | "max_width": null, 777 | "min_height": null, 778 | "min_width": null, 779 | "object_fit": null, 780 | "object_position": null, 781 | "order": null, 782 | "overflow": null, 783 | "overflow_x": null, 784 | "overflow_y": null, 785 | "padding": null, 786 | "right": null, 787 | "top": null, 788 | "visibility": null, 789 | "width": null 790 | } 791 | }, 792 | "a1a228b8e820488aa9d7a39326832b43": { 793 | "model_module": "@jupyter-widgets/controls", 794 | "model_module_version": "1.5.0", 795 | "model_name": "DescriptionStyleModel", 796 | "state": { 797 | "_model_module": "@jupyter-widgets/controls", 798 | "_model_module_version": "1.5.0", 799 | "_model_name": "DescriptionStyleModel", 800 | "_view_count": null, 801 | "_view_module": "@jupyter-widgets/base", 802 | "_view_module_version": "1.2.0", 803 | "_view_name": "StyleView", 804 | "description_width": "" 805 | } 806 | }, 807 | "a83800177edb46bdb5789bd0507c55b1": { 808 | "model_module": "@jupyter-widgets/controls", 809 | "model_module_version": "1.5.0", 810 | "model_name": "FloatProgressModel", 811 | "state": { 812 | "_dom_classes": [], 813 | "_model_module": "@jupyter-widgets/controls", 814 | "_model_module_version": "1.5.0", 815 | "_model_name": "FloatProgressModel", 816 | "_view_count": null, 817 | "_view_module": "@jupyter-widgets/controls", 818 | "_view_module_version": "1.5.0", 819 | "_view_name": "ProgressView", 820 | "bar_style": "success", 821 | "description": "", 822 | "description_tooltip": null, 823 | "layout": "IPY_MODEL_9bb9f2adc1f4480ba6ae6b390eda4521", 824 | "max": 169001437, 825 | "min": 0, 826 | "orientation": "horizontal", 827 | "style": "IPY_MODEL_d6ea4f8b6c304a279b56769a81897f5f", 828 | "value": 169001437 829 | } 830 | }, 831 | "aca8ccf946814899be23fad49d4c4ecb": { 832 | "model_module": "@jupyter-widgets/controls", 833 | "model_module_version": "1.5.0", 834 | "model_name": "HTMLModel", 835 | "state": { 836 | "_dom_classes": [], 837 | "_model_module": "@jupyter-widgets/controls", 838 | "_model_module_version": "1.5.0", 839 | "_model_name": "HTMLModel", 840 | "_view_count": null, 841 | "_view_module": "@jupyter-widgets/controls", 842 | "_view_module_version": "1.5.0", 843 | "_view_name": "HTMLView", 844 | "description": "", 845 | "description_tooltip": null, 846 | "layout": "IPY_MODEL_a0db04ae470a41c098cb9b59a67e899b", 847 | "placeholder": "​", 848 | "style": "IPY_MODEL_947edc38a98549e79c0906847f20560b", 849 | "value": " 169001984/? 
[00:10<00:00, 16869945.76it/s]" 850 | } 851 | }, 852 | "ae807dfd8be647d29b923f286caa47c4": { 853 | "model_module": "@jupyter-widgets/controls", 854 | "model_module_version": "1.5.0", 855 | "model_name": "HBoxModel", 856 | "state": { 857 | "_dom_classes": [], 858 | "_model_module": "@jupyter-widgets/controls", 859 | "_model_module_version": "1.5.0", 860 | "_model_name": "HBoxModel", 861 | "_view_count": null, 862 | "_view_module": "@jupyter-widgets/controls", 863 | "_view_module_version": "1.5.0", 864 | "_view_name": "HBoxView", 865 | "box_style": "", 866 | "children": [ 867 | "IPY_MODEL_4d335008de124ef1890de5087a8254f8", 868 | "IPY_MODEL_a83800177edb46bdb5789bd0507c55b1", 869 | "IPY_MODEL_aca8ccf946814899be23fad49d4c4ecb" 870 | ], 871 | "layout": "IPY_MODEL_5c84f8f441eb4b65a2947a8b0076b78c" 872 | } 873 | }, 874 | "d6ea4f8b6c304a279b56769a81897f5f": { 875 | "model_module": "@jupyter-widgets/controls", 876 | "model_module_version": "1.5.0", 877 | "model_name": "ProgressStyleModel", 878 | "state": { 879 | "_model_module": "@jupyter-widgets/controls", 880 | "_model_module_version": "1.5.0", 881 | "_model_name": "ProgressStyleModel", 882 | "_view_count": null, 883 | "_view_module": "@jupyter-widgets/base", 884 | "_view_module_version": "1.2.0", 885 | "_view_name": "StyleView", 886 | "bar_color": null, 887 | "description_width": "" 888 | } 889 | } 890 | } 891 | } 892 | }, 893 | "nbformat": 4, 894 | "nbformat_minor": 1 895 | } 896 | -------------------------------------------------------------------------------- /exercises/week09_10_CLIP/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics import roc_auc_score 4 | from torch import nn 5 | from tqdm import tqdm 6 | 7 | @torch.no_grad() 8 | def get_features(model, dataloader, device): 9 | model.eval() 10 | model = model.to(device) 11 | feats, labs = [], [] 12 | for i in dataloader: 13 | inp_data,labels = i 14 | inp_data = inp_data.to(device) 15 | features = model(inp_data) 16 | features = features.cpu().detach().flatten(start_dim=1) 17 | labels = labels.cpu().detach() 18 | feats.append(features) 19 | labs.append(labels) 20 | return torch.cat(feats, dim=0), torch.cat(labs, dim=0) 21 | 22 | def auroc_score(score_in, score_out): 23 | score_in = score_in.cpu() 24 | score_out = score_out.cpu() 25 | labels = torch.cat((torch.ones_like(score_in), torch.zeros_like(score_out))) 26 | return roc_auc_score(labels.numpy(), torch.cat((score_in, score_out)).numpy()) * 100 27 | 28 | @torch.no_grad() 29 | def OOD_classifier_knn(train_features, test_features, k=1): 30 | # k = -1 for whole trainset 31 | if k < 0: 32 | k = len(train_features) 33 | 34 | num_chunks = 128 # num of test images in loop 35 | num_test_images = test_features.shape[0] 36 | imgs_per_chunk = num_test_images // num_chunks 37 | cos_sim = torch.zeros(num_test_images).cuda() 38 | 39 | train_features = nn.functional.normalize(train_features, dim=-1, p=2) 40 | test_features = nn.functional.normalize(test_features, dim=-1, p=2) 41 | 42 | for idx in range(0, num_test_images, imgs_per_chunk): 43 | # get the features for test images 44 | idx_next_chunk = min((idx + imgs_per_chunk), num_test_images) 45 | features = test_features[idx : idx_next_chunk, :] 46 | # calculate the metric and compute ood scores 47 | similarity = features @ train_features.T 48 | top_sim, _ = similarity.topk(k, largest=True, sorted=True, dim=-1) 49 | cos_sim[idx: idx_next_chunk] = top_sim.mean(dim=1) 50 | return cos_sim 51 | 52 | 
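# Usage sketch for the helpers above (illustrative; the feature-file names are
# assumptions taken from the accompanying notebook, and OOD_classifier_knn
# expects CUDA tensors because it allocates its score buffer on the GPU):
#
#   train_in  = torch.load('features/cifar100_train_feats.pt').cuda()
#   test_in   = torch.load('features/cifar100_test_feats.pt').cuda()
#   test_out  = torch.load('features/cifar10_test_feats.pt').cuda()
#   score_in  = OOD_classifier_knn(train_in, test_in, k=1)    # top-1 cosine similarity
#   score_out = OOD_classifier_knn(train_in, test_out, k=1)
#   print(f'CIFAR100-->CIFAR10 AUROC: {auroc_score(score_in, score_out):.2f}')
#
# Higher scores mean "more in-distribution"; auroc_score labels the
# in-distribution scores as the positive class.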
def validate(model, val_loader, device): 53 | model.eval() 54 | criterion = nn.CrossEntropyLoss() 55 | correct, total = 0, 0 56 | loss_step = [] 57 | with torch.no_grad(): 58 | for data in val_loader: 59 | inp_data,labels = data 60 | inp_data = inp_data.to(device) 61 | labels = labels.to(device) 62 | outputs = model(inp_data) 63 | val_loss = criterion(outputs, labels) 64 | predicted = torch.max(outputs, 1)[1] 65 | total += labels.size(0) 66 | correct += (predicted == labels).sum() 67 | loss_step.append(val_loss.item()) 68 | # dont forget to take the means here 69 | val_acc = (100 * correct / total).cpu().numpy() 70 | val_loss_epoch = torch.tensor(loss_step).mean().numpy() 71 | return val_acc , val_loss_epoch 72 | 73 | 74 | def train_one_epoch(model, optimizer, train_loader, device): 75 | model.train() 76 | criterion = nn.CrossEntropyLoss() 77 | loss_step = [] 78 | correct, total = 0, 0 79 | for data in train_loader: 80 | # Move the data to the GPU 81 | inp_data,labels = data 82 | inp_data = inp_data.to(device) 83 | labels = labels.to(device) 84 | outputs = model(inp_data) 85 | loss = criterion(outputs, labels) 86 | optimizer.zero_grad() 87 | loss.backward() 88 | optimizer.step() 89 | with torch.no_grad(): 90 | _, predicted = torch.max(outputs, 1) 91 | total += labels.size(0) 92 | correct += (predicted == labels).sum() 93 | loss_step.append(loss.item()) 94 | # dont forget the means here 95 | loss_curr_epoch = np.mean(loss_step) 96 | train_acc = (100 * correct / total).cpu() 97 | return loss_curr_epoch, train_acc 98 | 99 | 100 | def linear_eval(model, optimizer, num_epochs, train_loader, val_loader, device, prefix="CLIP"): 101 | best_acc = 0 102 | model = model.to(device) 103 | dict_log = {"train_acc_epoch":[], "val_acc_epoch":[], "loss_epoch":[], "val_loss":[]} 104 | pbar = tqdm(range(num_epochs)) 105 | for epoch in pbar: 106 | loss_curr_epoch, train_acc = train_one_epoch(model, optimizer, train_loader, device) 107 | val_acc, val_loss = validate(model, val_loader, device) 108 | 109 | # Print epoch results to screen 110 | msg = (f'Ep {epoch}/{num_epochs}: Accuracy : Train:{train_acc:.2f} \t Val:{val_acc:.2f} || Loss: Train {loss_curr_epoch:.3f} \t Val {val_loss:.3f}') 111 | pbar.set_description(msg) 112 | # Track stats 113 | dict_log["train_acc_epoch"].append(train_acc) 114 | dict_log["val_acc_epoch"].append(val_acc) 115 | dict_log["loss_epoch"].append(loss_curr_epoch) 116 | dict_log["val_loss"].append(val_loss) 117 | 118 | if train_acc > best_acc: 119 | best_acc = train_acc 120 | torch.save({ 121 | 'epoch': epoch, 122 | 'model_state_dict': model.state_dict(), 123 | 'optimizer_state_dict': optimizer.state_dict(), 124 | 'best_acc': best_acc, 125 | }, f'{prefix}_best_max_train_acc.pth') 126 | return dict_log 127 | 128 | def load_model(model, path): 129 | checkpoint = torch.load(path) 130 | model.load_state_dict(checkpoint['model_state_dict']) 131 | print(f"Model {path} is loaded from epoch {checkpoint['epoch']}") 132 | return model 133 | -------------------------------------------------------------------------------- /exercises/week12-13_Proteins/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable, Sequence 4 | import pickle 5 | import re 6 | import shutil 7 | import torch 8 | from pathlib import Path 9 | import numpy as np 10 | import torch 11 | from scipy.spatial.distance import squareform, pdist, cdist 12 | import matplotlib.pyplot as plt 
13 | import matplotlib as mpl 14 | import string 15 | 16 | import torch.functional as F 17 | from Bio import SeqIO 18 | import biotite.structure as bs 19 | 20 | TEST_VALS = torch.load('test_vals.pt') 21 | 22 | 23 | def test_func(func): 24 | name = func.__name__ 25 | for inp, outp in TEST_VALS[name]: 26 | assert torch.allclose(func(**inp), outp), f'Function {name} failed on input {inp}' 27 | print(f'Function {name} passed all tests!') 28 | 29 | 30 | 31 | proteinseq_toks = { 32 | 'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-'] 33 | } 34 | 35 | # This is an efficient way to delete lowercase characters and insertion characters from a string 36 | deletekeys = dict.fromkeys(string.ascii_lowercase) 37 | deletekeys["."] = None 38 | deletekeys["*"] = None 39 | translation = str.maketrans(deletekeys) 40 | RawMSA = Sequence[Tuple[str, str]] 41 | 42 | class FastaBatchedDataset(object): 43 | def __init__(self, sequence_labels, sequence_strs): 44 | self.sequence_labels = list(sequence_labels) 45 | self.sequence_strs = list(sequence_strs) 46 | 47 | @classmethod 48 | def from_file(cls, fasta_file): 49 | sequence_labels, sequence_strs = [], [] 50 | cur_seq_label = None 51 | buf = [] 52 | 53 | def _flush_current_seq(): 54 | nonlocal cur_seq_label, buf 55 | if cur_seq_label is None: 56 | return 57 | sequence_labels.append(cur_seq_label) 58 | sequence_strs.append("".join(buf)) 59 | cur_seq_label = None 60 | buf = [] 61 | 62 | with open(fasta_file, "r") as infile: 63 | for line_idx, line in enumerate(infile): 64 | if line.startswith(">"): # label line 65 | _flush_current_seq() 66 | line = line[1:].strip() 67 | if len(line) > 0: 68 | cur_seq_label = line 69 | else: 70 | cur_seq_label = f"seqnum{line_idx:09d}" 71 | else: # sequence line 72 | buf.append(line.strip()) 73 | 74 | _flush_current_seq() 75 | 76 | assert len(set(sequence_labels)) == len( 77 | sequence_labels 78 | ), "Found duplicate sequence labels" 79 | 80 | return cls(sequence_labels, sequence_strs) 81 | 82 | def __len__(self): 83 | return len(self.sequence_labels) 84 | 85 | def __getitem__(self, idx): 86 | return self.sequence_labels[idx], self.sequence_strs[idx] 87 | 88 | def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0): 89 | sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)] 90 | sizes.sort() 91 | batches = [] 92 | buf = [] 93 | max_len = 0 94 | 95 | def _flush_current_buf(): 96 | nonlocal max_len, buf 97 | if len(buf) == 0: 98 | return 99 | batches.append(buf) 100 | buf = [] 101 | max_len = 0 102 | 103 | for sz, i in sizes: 104 | sz += extra_toks_per_seq 105 | if max(sz, max_len) * (len(buf) + 1) > toks_per_batch: 106 | _flush_current_buf() 107 | max_len = max(max_len, sz) 108 | buf.append(i) 109 | 110 | _flush_current_buf() 111 | return batches 112 | 113 | 114 | class Alphabet(object): 115 | def __init__( 116 | self, 117 | standard_toks: Sequence[str], 118 | prepend_toks: Sequence[str] = ("", "", "", ""), 119 | append_toks: Sequence[str] = ("", "", ""), 120 | prepend_bos: bool = True, 121 | append_eos: bool = False, 122 | use_msa: bool = False, 123 | ): 124 | self.standard_toks = list(standard_toks) 125 | self.prepend_toks = list(prepend_toks) 126 | self.append_toks = list(append_toks) 127 | self.prepend_bos = prepend_bos 128 | self.append_eos = append_eos 129 | self.use_msa = use_msa 130 | 131 | self.all_toks = list(self.prepend_toks) 132 | self.all_toks.extend(self.standard_toks) 133 | for i in range((8 - 
(len(self.all_toks) % 8)) % 8): 134 | self.all_toks.append(f"") 135 | self.all_toks.extend(self.append_toks) 136 | 137 | self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)} 138 | 139 | self.unk_idx = self.tok_to_idx[""] 140 | self.padding_idx = self.get_idx("") 141 | self.cls_idx = self.get_idx("") 142 | self.mask_idx = self.get_idx("") 143 | self.eos_idx = self.get_idx("") 144 | self.all_special_tokens = ['', '', '', '', ''] 145 | self.unique_no_split_tokens = self.all_toks 146 | 147 | def __len__(self): 148 | return len(self.all_toks) 149 | 150 | def get_idx(self, tok): 151 | return self.tok_to_idx.get(tok, self.unk_idx) 152 | 153 | def get_tok(self, ind): 154 | return self.all_toks[ind] 155 | 156 | def to_dict(self): 157 | return self.tok_to_idx.copy() 158 | 159 | def get_batch_converter(self, truncation_seq_length: int = None): 160 | if self.use_msa: 161 | return MSABatchConverter(self, truncation_seq_length) 162 | else: 163 | return BatchConverter(self, truncation_seq_length) 164 | 165 | @classmethod 166 | def from_architecture(cls, name: str) -> "Alphabet": 167 | if name in ("ESM-1", "protein_bert_base"): 168 | standard_toks = proteinseq_toks["toks"] 169 | prepend_toks: Tuple[str, ...] = ("", "", "", "") 170 | append_toks: Tuple[str, ...] = ("", "", "") 171 | prepend_bos = True 172 | append_eos = False 173 | use_msa = False 174 | elif name in ("ESM-1b", "roberta_large"): 175 | standard_toks = proteinseq_toks["toks"] 176 | prepend_toks = ("", "", "", "") 177 | append_toks = ("",) 178 | prepend_bos = True 179 | append_eos = True 180 | use_msa = False 181 | elif name in ("MSA Transformer", "msa_transformer"): 182 | standard_toks = proteinseq_toks["toks"] 183 | prepend_toks = ("", "", "", "") 184 | append_toks = ("",) 185 | prepend_bos = True 186 | append_eos = False 187 | use_msa = True 188 | elif "invariant_gvp" in name.lower(): 189 | standard_toks = proteinseq_toks["toks"] 190 | prepend_toks = ("", "", "", "") 191 | append_toks = ("", "", "") 192 | prepend_bos = True 193 | append_eos = False 194 | use_msa = False 195 | else: 196 | raise ValueError("Unknown architecture selected") 197 | return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa) 198 | 199 | def _tokenize(self, text) -> str: 200 | return text.split() 201 | 202 | def tokenize(self, text, **kwargs) -> List[str]: 203 | """ 204 | Converts a string in a sequence of tokens, using the tokenizer. 205 | 206 | Args: 207 | text (:obj:`str`): 208 | The sequence to be encoded. 209 | 210 | Returns: 211 | :obj:`List[str]`: The list of tokens. 
212 | """ 213 | 214 | def split_on_token(tok, text): 215 | result = [] 216 | split_text = text.split(tok) 217 | for i, sub_text in enumerate(split_text): 218 | if i < len(split_text) - 1: 219 | sub_text = sub_text.rstrip() 220 | if i > 0: 221 | sub_text = sub_text.lstrip() 222 | 223 | if i == 0 and not sub_text: 224 | result.append(tok) 225 | elif i == len(split_text) - 1: 226 | if sub_text: 227 | result.append(sub_text) 228 | else: 229 | pass 230 | else: 231 | if sub_text: 232 | result.append(sub_text) 233 | result.append(tok) 234 | return result 235 | 236 | def split_on_tokens(tok_list, text): 237 | if not text.strip(): 238 | return [] 239 | 240 | tokenized_text = [] 241 | text_list = [text] 242 | for tok in tok_list: 243 | tokenized_text = [] 244 | for sub_text in text_list: 245 | if sub_text not in self.unique_no_split_tokens: 246 | tokenized_text.extend(split_on_token(tok, sub_text)) 247 | else: 248 | tokenized_text.append(sub_text) 249 | text_list = tokenized_text 250 | 251 | return list( 252 | itertools.chain.from_iterable( 253 | ( 254 | self._tokenize(token) 255 | if token not in self.unique_no_split_tokens 256 | else [token] 257 | for token in tokenized_text 258 | ) 259 | ) 260 | ) 261 | no_split_token = self.unique_no_split_tokens 262 | tokenized_text = split_on_tokens(no_split_token, text) 263 | return tokenized_text 264 | 265 | def encode(self, text): 266 | return [self.tok_to_idx[tok] for tok in self.tokenize(text)] 267 | 268 | 269 | class BatchConverter(object): 270 | """Callable to convert an unprocessed (labels + strings) batch to a 271 | processed (labels + tensor) batch. 272 | """ 273 | 274 | def __init__(self, alphabet, truncation_seq_length: int = None): 275 | self.alphabet = alphabet 276 | self.truncation_seq_length = truncation_seq_length 277 | 278 | def __call__(self, raw_batch: Sequence[Tuple[str, str]]): 279 | # RoBERTa uses an eos token, while ESM-1 does not. 
280 | batch_size = len(raw_batch) 281 | batch_labels, seq_str_list = zip(*raw_batch) 282 | seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list] 283 | if self.truncation_seq_length: 284 | seq_encoded_list = [seq_str[:self.truncation_seq_length] for seq_str in seq_encoded_list] 285 | max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list) 286 | tokens = torch.empty( 287 | ( 288 | batch_size, 289 | max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos), 290 | ), 291 | dtype=torch.int64, 292 | ) 293 | tokens.fill_(self.alphabet.padding_idx) 294 | labels = [] 295 | strs = [] 296 | 297 | for i, (label, seq_str, seq_encoded) in enumerate( 298 | zip(batch_labels, seq_str_list, seq_encoded_list) 299 | ): 300 | labels.append(label) 301 | strs.append(seq_str) 302 | if self.alphabet.prepend_bos: 303 | tokens[i, 0] = self.alphabet.cls_idx 304 | seq = torch.tensor(seq_encoded, dtype=torch.int64) 305 | tokens[ 306 | i, 307 | int(self.alphabet.prepend_bos) : len(seq_encoded) 308 | + int(self.alphabet.prepend_bos), 309 | ] = seq 310 | if self.alphabet.append_eos: 311 | tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx 312 | 313 | return labels, strs, tokens 314 | 315 | 316 | 317 | def read_fasta( 318 | path, 319 | keep_gaps=True, 320 | keep_insertions=True, 321 | to_upper=False, 322 | ): 323 | with open(path, "r") as f: 324 | for result in read_alignment_lines( 325 | f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper 326 | ): 327 | yield result 328 | 329 | 330 | def read_alignment_lines( 331 | lines, 332 | keep_gaps=True, 333 | keep_insertions=True, 334 | to_upper=False, 335 | ): 336 | seq = desc = None 337 | 338 | def parse(s): 339 | if not keep_gaps: 340 | s = re.sub("-", "", s) 341 | if not keep_insertions: 342 | s = re.sub("[a-z]", "", s) 343 | return s.upper() if to_upper else s 344 | 345 | for line in lines: 346 | # Line may be empty if seq % file_line_width == 0 347 | if len(line) > 0 and line[0] == ">": 348 | if seq is not None: 349 | yield desc, parse(seq) 350 | desc = line.strip().lstrip(">") 351 | seq = "" 352 | else: 353 | assert isinstance(seq, str) 354 | seq += line.strip() 355 | assert isinstance(seq, str) and isinstance(desc, str) 356 | yield desc, parse(seq) 357 | 358 | 359 | def compute_precisions( 360 | predictions: torch.Tensor, 361 | targets: torch.Tensor, 362 | src_lengths: Optional[torch.Tensor] = None, 363 | minsep: int = 6, 364 | maxsep: Optional[int] = None, 365 | override_length: Optional[int] = None, # for casp 366 | ): 367 | if isinstance(predictions, np.ndarray): 368 | predictions = torch.from_numpy(predictions) 369 | if isinstance(targets, np.ndarray): 370 | targets = torch.from_numpy(targets) 371 | if predictions.dim() == 2: 372 | predictions = predictions.unsqueeze(0) 373 | if targets.dim() == 2: 374 | targets = targets.unsqueeze(0) 375 | override_length = (targets[0, 0] >= 0).sum() 376 | 377 | # Check sizes 378 | if predictions.size() != targets.size(): 379 | raise ValueError( 380 | f"Size mismatch. 
Received predictions of size {predictions.size()}, " 381 | f"targets of size {targets.size()}" 382 | ) 383 | device = predictions.device 384 | 385 | batch_size, seqlen, _ = predictions.size() 386 | seqlen_range = torch.arange(seqlen, device=device) 387 | 388 | sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1) 389 | sep = sep.unsqueeze(0) 390 | valid_mask = sep >= minsep 391 | valid_mask = valid_mask & (targets >= 0) # negative targets are invalid 392 | 393 | if maxsep is not None: 394 | valid_mask &= sep < maxsep 395 | 396 | if src_lengths is not None: 397 | valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1) 398 | valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2) 399 | else: 400 | src_lengths = torch.full([batch_size], seqlen, device=device, dtype=torch.long) 401 | 402 | predictions = predictions.masked_fill(~valid_mask, float("-inf")) 403 | 404 | x_ind, y_ind = np.triu_indices(seqlen, minsep) 405 | predictions_upper = predictions[:, x_ind, y_ind] 406 | targets_upper = targets[:, x_ind, y_ind] 407 | 408 | topk = seqlen if override_length is None else max(seqlen, override_length) 409 | indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk] 410 | topk_targets = targets_upper[torch.arange(batch_size).unsqueeze(1), indices] 411 | if topk_targets.size(1) < topk: 412 | topk_targets = F.pad(topk_targets, [0, topk - topk_targets.size(1)]) 413 | 414 | cumulative_dist = topk_targets.type_as(predictions).cumsum(-1) 415 | 416 | gather_lengths = src_lengths.unsqueeze(1) 417 | if override_length is not None: 418 | gather_lengths = override_length * torch.ones_like( 419 | gather_lengths, device=device 420 | ) 421 | 422 | gather_indices = ( 423 | torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths 424 | ).type(torch.long) - 1 425 | 426 | binned_cumulative_dist = cumulative_dist.gather(1, gather_indices) 427 | binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as( 428 | binned_cumulative_dist 429 | ) 430 | 431 | pl5 = binned_precisions[:, 1] 432 | pl2 = binned_precisions[:, 4] 433 | pl = binned_precisions[:, 9] 434 | auc = binned_precisions.mean(-1) 435 | 436 | return {"AUC": auc, "P@L": pl, "P@L2": pl2, "P@L5": pl5} 437 | 438 | 439 | def evaluate_prediction( 440 | predictions: torch.Tensor, 441 | targets: torch.Tensor, 442 | ) -> Dict[str, float]: 443 | if isinstance(targets, np.ndarray): 444 | targets = torch.from_numpy(targets) 445 | contact_ranges = [ 446 | ("local", 3, 6), 447 | ("short", 6, 12), 448 | ("medium", 12, 24), 449 | ("long", 24, None), 450 | ] 451 | metrics = {} 452 | targets = targets.to(predictions.device) 453 | for name, minsep, maxsep in contact_ranges: 454 | rangemetrics = compute_precisions( 455 | predictions, 456 | targets, 457 | minsep=minsep, 458 | maxsep=maxsep, 459 | ) 460 | for key, val in rangemetrics.items(): 461 | metrics[f"{name}_{key}"] = val.item() 462 | return metrics 463 | 464 | class ESMStructuralSplitDataset(torch.utils.data.Dataset): 465 | """ 466 | Structural Split Dataset as described in section A.10 of the supplement of our paper. 467 | https://doi.org/10.1101/622803 468 | 469 | We use the full version of SCOPe 2.07, clustered at 90% sequence identity, 470 | generated on January 23, 2020. 471 | 472 | For each SCOPe domain: 473 | - We extract the sequence from the corresponding PDB file 474 | - We extract the 3D coordinates of the Carbon beta atoms, aligning them 475 | to the sequence. We put NaN where Cb atoms are missing. 
476 | - From the 3D coordinates, we calculate a pairwise distance map, based 477 | on L2 distance 478 | - We use DSSP to generate secondary structure labels for the corresponding 479 | PDB file. This is also aligned to the sequence. We put - where SSP 480 | labels are missing. 481 | 482 | For each SCOPe classification level of family/superfamily/fold (in order of difficulty), 483 | we have split the data into 5 partitions for cross validation. These are provided 484 | in a downloaded splits folder, in the format: 485 | splits/{split_level}/{cv_partition}/{train|valid}.txt 486 | where train is the partition and valid is the concatenation of the remaining 4. 487 | 488 | For each SCOPe domain, we provide a pkl dump that contains: 489 | - seq : The domain sequence, stored as an L-length string 490 | - ssp : The secondary structure labels, stored as an L-length string 491 | - dist : The distance map, stored as an LxL numpy array 492 | - coords : The 3D coordinates, stored as an Lx3 numpy array 493 | 494 | """ 495 | 496 | base_folder = "structural-data" 497 | file_list = [ 498 | # url tar filename filename MD5 Hash 499 | ( 500 | "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz", 501 | "splits.tar.gz", 502 | "splits", 503 | "456fe1c7f22c9d3d8dfe9735da52411d", 504 | ), 505 | ( 506 | "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz", 507 | "pkl.tar.gz", 508 | "pkl", 509 | "644ea91e56066c750cd50101d390f5db", 510 | ), 511 | ] 512 | 513 | def __init__( 514 | self, 515 | split_level, 516 | cv_partition, 517 | split, 518 | root_path=os.path.expanduser("~/.cache/torch/data/esm"), 519 | download=False, 520 | ): 521 | super().__init__() 522 | assert split in [ 523 | "train", 524 | "valid", 525 | ], "split must be 'train' or 'valid'" 526 | self.root_path = root_path 527 | self.base_path = os.path.join(self.root_path, self.base_folder) 528 | 529 | # check if root path has what you need or else download it 530 | if download: 531 | self.download() 532 | 533 | self.split_file = os.path.join( 534 | self.base_path, "splits", split_level, cv_partition, f"{split}.txt" 535 | ) 536 | self.pkl_dir = os.path.join(self.base_path, "pkl") 537 | self.names = [] 538 | with open(self.split_file) as f: 539 | self.names = f.read().splitlines() 540 | 541 | def __len__(self): 542 | return len(self.names) 543 | 544 | def _check_exists(self) -> bool: 545 | for (_, _, filename, _) in self.file_list: 546 | fpath = os.path.join(self.base_path, filename) 547 | if not os.path.exists(fpath) or not os.path.isdir(fpath): 548 | return False 549 | return True 550 | 551 | def download(self): 552 | 553 | if self._check_exists(): 554 | print("Files already downloaded and verified") 555 | return 556 | 557 | from torchvision.datasets.utils import download_url 558 | 559 | for url, tar_filename, filename, md5_hash in self.file_list: 560 | download_path = os.path.join(self.base_path, tar_filename) 561 | download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash) 562 | shutil.unpack_archive(download_path, self.base_path) 563 | 564 | def __getitem__(self, idx): 565 | """ 566 | Returns a dict with the following entries 567 | - seq : Str (domain sequence) 568 | - ssp : Str (SSP labels) 569 | - dist : np.array (distance map) 570 | - coords : np.array (3D coordinates) 571 | """ 572 | name = self.names[idx] 573 | pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl") 574 | with open(pkl_fname, "rb") as f: 575 | obj = pickle.load(f) 576 | return obj 577 | 578 | 579 | def 
plot_contacts_and_predictions( 580 | predictions: Union[torch.Tensor, np.ndarray], 581 | contacts: Union[torch.Tensor, np.ndarray], 582 | ax: Optional[mpl.axes.Axes] = None, 583 | # artists: Optional[ContactAndPredictionArtists] = None, 584 | cmap: str = "Blues", 585 | ms: float = 1, 586 | title: Union[bool, str, Callable[[float], str]] = True, 587 | animated: bool = False, 588 | ) -> None: 589 | 590 | if isinstance(predictions, torch.Tensor): 591 | predictions = predictions.detach().cpu().numpy() 592 | if isinstance(contacts, torch.Tensor): 593 | contacts = contacts.detach().cpu().numpy() 594 | if ax is None: 595 | ax = plt.gca() 596 | 597 | seqlen = contacts.shape[0] 598 | relative_distance = np.add.outer(-np.arange(seqlen), np.arange(seqlen)) 599 | bottom_mask = relative_distance < 0 600 | masked_image = np.ma.masked_where(bottom_mask, predictions) 601 | invalid_mask = np.abs(np.add.outer(np.arange(seqlen), -np.arange(seqlen))) < 6 602 | predictions = predictions.copy() 603 | predictions[invalid_mask] = float("-inf") 604 | 605 | topl_val = np.sort(predictions.reshape(-1))[-seqlen] 606 | pred_contacts = predictions >= topl_val 607 | true_positives = contacts & pred_contacts & ~bottom_mask 608 | false_positives = ~contacts & pred_contacts & ~bottom_mask 609 | other_contacts = contacts & ~pred_contacts & ~bottom_mask 610 | 611 | if isinstance(title, str): 612 | title_text: Optional[str] = title 613 | elif title: 614 | long_range_pl = compute_precisions(predictions, contacts, minsep=24)[ 615 | "P@L" 616 | ].item() 617 | if callable(title): 618 | title_text = title(long_range_pl) 619 | else: 620 | title_text = f"Long Range P@L: {100 * long_range_pl:0.1f}" 621 | else: 622 | title_text = None 623 | 624 | img = ax.imshow(masked_image, cmap=cmap, animated=animated) 625 | oc = ax.plot(*np.where(other_contacts), "o", c="grey", ms=ms)[0] 626 | fn = ax.plot(*np.where(false_positives), "o", c="r", ms=ms)[0] 627 | tp = ax.plot(*np.where(true_positives), "o", c="b", ms=ms)[0] 628 | ti = ax.set_title(title_text) if title_text is not None else None 629 | # artists = ContactAndPredictionArtists(img, oc, fn, tp, ti) 630 | 631 | ax.axis("square") 632 | ax.set_xlim([0, seqlen]) 633 | ax.set_ylim([0, seqlen]) 634 | 635 | 636 | def read_sequence(filename: str) -> Tuple[str, str]: 637 | """ Reads the first (reference) sequences from a fasta or MSA file.""" 638 | record = next(SeqIO.parse(filename, "fasta")) 639 | return record.description, str(record.seq) 640 | 641 | def remove_insertions(sequence: str) -> str: 642 | """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. 
""" 643 | return sequence.translate(translation) 644 | 645 | def read_msa(filename: str) -> List[Tuple[str, str]]: 646 | """ Reads the sequences from an MSA file, automatically removes insertions.""" 647 | return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")] 648 | 649 | def extend(a, b, c, L, A, D): 650 | """ 651 | input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral 652 | output: 4th coord 653 | """ 654 | 655 | def normalize(x): 656 | return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True) 657 | 658 | bc = normalize(b - c) 659 | n = normalize(np.cross(b - a, bc)) 660 | m = [bc, np.cross(n, bc), n] 661 | d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)] 662 | return c + sum([m * d for m, d in zip(m, d)]) 663 | 664 | 665 | def contacts_from_pdb( 666 | structure: bs.AtomArray, 667 | distance_threshold: float = 8.0, 668 | chain: Optional[str] = None, 669 | ) -> np.ndarray: 670 | mask = ~structure.hetero 671 | if chain is not None: 672 | mask &= structure.chain_id == chain 673 | 674 | N = structure.coord[mask & (structure.atom_name == "N")] 675 | CA = structure.coord[mask & (structure.atom_name == "CA")] 676 | C = structure.coord[mask & (structure.atom_name == "C")] 677 | 678 | Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143) 679 | dist = squareform(pdist(Cbeta)) 680 | 681 | contacts = dist < distance_threshold 682 | contacts = contacts.astype(np.int64) 683 | contacts[np.isnan(dist)] = -1 684 | return contacts 685 | 686 | 687 | class BatchConverterContact(object): 688 | def __init__(self, alphabet, truncation_seq_length: int = None): 689 | self.alphabet = alphabet 690 | self.truncation_seq_length = truncation_seq_length 691 | 692 | def __call__(self, raw_batch): 693 | batch_size = len(raw_batch) 694 | batch_labels, seq_str_list, contact_list = zip(*raw_batch) 695 | seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list] 696 | if self.truncation_seq_length: 697 | seq_encoded_list = [seq_str[:self.truncation_seq_length] for seq_str in seq_encoded_list] 698 | max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list) 699 | tokens = torch.empty( 700 | (batch_size, max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos)), 701 | dtype=torch.int64) 702 | tokens.fill_(self.alphabet.padding_idx) 703 | labels = [] 704 | strs = [] 705 | 706 | contacts = torch.zeros((batch_size, max_len, max_len), dtype=torch.float32) 707 | 708 | for i, (label, seq_str, seq_encoded, contact_map) in enumerate( 709 | zip(batch_labels, seq_str_list, seq_encoded_list, contact_list) 710 | ): 711 | labels.append(label) 712 | strs.append(seq_str) 713 | if self.alphabet.prepend_bos: 714 | tokens[i, 0] = self.alphabet.cls_idx 715 | seq = torch.tensor(seq_encoded, dtype=torch.int64) 716 | tokens[ 717 | i, 718 | int(self.alphabet.prepend_bos) : len(seq_encoded) 719 | + int(self.alphabet.prepend_bos), 720 | ] = seq 721 | if self.alphabet.append_eos: 722 | tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx 723 | 724 | contacts[i, :len(contact_map),:len(contact_map)] = contact_map 725 | return tokens, contacts 726 | 727 | def plot_stats(dict_log, modelname="", baseline=None, title=None): 728 | fontsize = 14 729 | plt.subplots_adjust(hspace=0.3) 730 | plt.subplot(2,1,1) 731 | x_axis = list(range(len(dict_log["val_f1_epoch"]))) 732 | plt.plot(dict_log["train_f1_epoch"], label=f'{modelname} Train F1') 733 | plt.scatter(x_axis, dict_log["train_f1_epoch"]) 734 | 
plt.plot(dict_log["val_f1_epoch"], label=f'{modelname} Validation F1') 735 | plt.scatter(x_axis, dict_log["val_f1_epoch"]) 736 | plt.ylabel('F1') 737 | plt.xlabel('Number of Epochs') 738 | plt.title("F1 over epochs", fontsize=fontsize) 739 | if baseline is not None: 740 | plt.axhline(y=baseline, color='red', label="Acceptable F1") 741 | plt.legend(fontsize=fontsize) 742 | plt.subplot(2,1,2) 743 | plt.plot(dict_log["loss_epoch"] , label="Training") 744 | plt.scatter(x_axis, dict_log["loss_epoch"], ) 745 | plt.plot(dict_log["val_loss"] , label='Validation') 746 | plt.scatter(x_axis, dict_log["val_loss"]) 747 | plt.ylabel('Loss value') 748 | plt.xlabel('Number of Epochs') 749 | plt.title("Loss over epochs", fontsize=fontsize) 750 | plt.legend(fontsize=fontsize) 751 | if title is not None: 752 | plt.savefig(title) 753 | --------------------------------------------------------------------------------