├── .gitignore
├── README.md
├── cifar100_example_subset_creation.ipynb
├── configs.py
├── data_proc
│   ├── __init__.py
│   ├── augmentation.py
│   └── dataset.py
├── evaluate
│   ├── __init__.py
│   ├── checkpoint.py
│   └── lbfgs.py
├── examples
│   ├── cifar10
│   │   ├── cifar10-0.05-sas-subset-indices.pkl
│   │   ├── cifar10-0.10-sas-subset-indices.pkl
│   │   ├── cifar10-0.15-sas-subset-indices.pkl
│   │   └── cifar10-0.20-sas-subset-indices.pkl
│   ├── cifar100
│   │   ├── cifar100-0.2-sas-subset-indices.pkl
│   │   ├── cifar100-0.4-sas-subset-indices.pkl
│   │   ├── cifar100-0.6-sas-subset-indices.pkl
│   │   └── cifar100-0.8-sas-subset-indices.pkl
│   ├── imagenet
│   │   ├── imagenet-0.2-rand-balanced-indices.pkl
│   │   ├── imagenet-0.2-sas-subset-indices.pkl
│   │   ├── imagenet-0.6-random-balanced-idx.pkl
│   │   └── imagenet-0.6-sas-subset-indices.pkl
│   └── stl10
│       ├── stl10-0.2-sas-subset-indices.pkl
│       ├── stl10-0.4-sas-subset-indices.pkl
│       ├── stl10-0.6-sas-subset-indices.pkl
│       └── stl10-0.8-sas-subset-indices.pkl
├── final_subsets
│   ├── cifar10-cl-core-idx.pkl
│   ├── cifar10-rand-balanced-idx.pkl
│   ├── cifar100-cl-core-idx.pkl
│   ├── cifar100-rand-balanced-idx.pkl
│   ├── stl10-cl-core-idx.pkl
│   └── stl10-rand-balanced-idx.pkl
├── linear_probe.py
├── projection_heads
│   ├── __init__.py
│   ├── critic.py
│   ├── gradient_linear_clf.py
│   └── lbfgs_linear_clf.py
├── proxy-cifar100-resnet10-399-critic.pt
├── proxy-cifar100-resnet10-399-net.pt
├── requirements.txt
├── resnet.py
├── sas-pip
│   ├── sas
│   │   ├── __init__.py
│   │   ├── approx_latent_classes.py
│   │   ├── submodular_maximization.py
│   │   └── subset_dataset.py
│   └── setup.py
├── simclr.py
├── trainer.py
└── util.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | sas.egg-info
3 | build
4 | wandb
5 | *.pt
6 | imagenet-subset
7 | *.pkl
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Data Efficient Contrastive Learning (ICML 2023)
2 | 
3 | ## Abstract
4 | 
5 | Self-supervised learning (SSL) learns high-quality representations from large pools of unlabeled training data. As datasets grow larger, it becomes crucial to identify the examples that contribute the most to learning such representations. This enables efficient SSL by reducing the volume of data required for learning high-quality representations. Nevertheless, quantifying the value of examples for SSL has remained an open question. In this work, we address this for the first time, by proving that examples that contribute the most to contrastive SSL are those that have the most similar augmentations to other examples, in expectation. We provide rigorous guarantees for the generalization performance of SSL on such subsets. Empirically, we discover, perhaps surprisingly, the subsets that contribute the most to SSL are those that contribute the least to supervised learning. Through extensive experiments, we show we can safely exclude 20% of examples from CIFAR100 and 40% from STL10 and TinyImageNet, without affecting downstream task performance. We also show that our subsets outperform random subsets by more than 2% on CIFAR10. We also demonstrate that these subsets are effective across contrastive SSL methods (evaluated on SimCLR, MoCo, SimSiam, BYOL).
6 | 
7 | [Project Page](https://sjoshi804.github.io/data-efficient-contrastive-learning/)
8 | 
9 | [Paper](https://proceedings.mlr.press/v202/joshi23b.html)
10 | 
11 | ## BibTex Citation
12 | 
13 | Please cite this if you use this code / paper in your work. 
14 | 
15 | ```bibtex
16 | @InProceedings{pmlr-v202-joshi23b,
17 |   title = {Data-Efficient Contrastive Self-supervised Learning: Most Beneficial Examples for Supervised Learning Contribute the Least},
18 |   author = {Joshi, Siddharth and Mirzasoleiman, Baharan},
19 |   booktitle = {Proceedings of the 40th International Conference on Machine Learning},
20 |   pages = {15356--15370},
21 |   year = {2023},
22 |   editor = {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},
23 |   volume = {202},
24 |   series = {Proceedings of Machine Learning Research},
25 |   month = {23--29 Jul},
26 |   publisher = {PMLR},
27 |   pdf = {https://proceedings.mlr.press/v202/joshi23b/joshi23b.pdf},
28 |   url = {https://proceedings.mlr.press/v202/joshi23b.html},
29 | }
30 | ```
31 | 
32 | ## Examples
33 | 
34 | Example subsets can be found in `/examples`
35 | 
36 | Subsets Provided:
37 | 
38 | - CIFAR100 - 20%, 40%, 60%, 80% subsets
39 | - STL10 - 20%, 40%, 60%, 80% subsets
40 | - CIFAR10 - 5%, 10%, 15%, 20% subsets
41 | - TinyImageNet (coming soon)
42 | - ImageNet (coming soon)
43 | 
44 | To get the subset indices:
45 | 
46 | ```python
47 | import pickle
48 | 
49 | with open("<path-to-subset-indices>.pkl", "rb") as f:
50 |     subset_indices = pickle.load(f)
51 | ```
52 | 
53 | Then pass these indices to `CustomSubsetDataset` as the `subset_indices` argument to get the corresponding subset dataset object. (More details below.)
54 | 
55 | ## Sample Usage
56 | 
57 | See `cifar100_example_subset_creation.ipynb` for a complete example of how to create a subset with the provided proxy models.
58 | 
59 | ```bash
60 | pip install sas-pip/
61 | pip install -r requirements.txt
62 | ```
63 | 
64 | The samples below show how to choose subsets of CIFAR100.
65 | 
66 | ### SAS (default)
67 | 
68 | ```python
69 | # Approximate Latent Classes
70 | from sas.approx_latent_classes import clip_approx
71 | from sas.subset_dataset import SASSubsetDataset
72 | import random
73 | 
74 | cifar100 = torchvision.datasets.CIFAR100("/data/cifar100/", transform=transforms.ToTensor())
75 | device = "cuda:0"
76 | 
77 | rand_labeled_examples_indices = random.sample(range(len(cifar100)), 500)
78 | rand_labeled_examples_labels = [cifar100[i][1] for i in rand_labeled_examples_indices]
79 | 
80 | partition = clip_approx(
81 |     img_trainset=cifar100,
82 |     labeled_example_indices=rand_labeled_examples_indices,
83 |     labeled_examples_labels=rand_labeled_examples_labels,
84 |     num_classes=100,
85 |     device=device
86 | )
87 | 
88 | # Get Subset
89 | proxy_model = torch.load("cifar100-proxy-encoder.pt").module.to(device)
90 | subset_dataset = SASSubsetDataset(
91 |     dataset=cifar100,
92 |     subset_fraction=0.2,
93 |     num_downstream_classes=100,
94 |     device=device,
95 |     proxy_model=proxy_model,
96 |     approx_latent_class_partition=partition,
97 |     verbose=True
98 | )
99 | ```
100 | 
101 | ### SAS (CLIP 0-shot Latent Classes)
102 | 
103 | ```python
104 | # Approximate Latent Classes
105 | from sas.approx_latent_classes import clip_0shot_approx
106 | 
107 | partition = clip_0shot_approx(
108 |     img_trainset=cifar100,
109 |     class_names=cifar100_classes,  # list of the 100 CIFAR100 class name strings
110 |     device=device
111 | )
112 | ```
113 | 
114 | ### SAS (k-Means Latent Classes)
115 | 
116 | ```python
117 | # Approximate Latent Classes
118 | from sas.approx_latent_classes import kmeans_approx
119 | 
120 | partition = kmeans_approx(
121 |     trainset=cifar100,
122 |     proxy_model=net,
123 |     num_classes=100,
124 |     device=device
125 | )
126 | ```
127 | 
128 | ### Random Subset
129 | 
130 | ```python
131 | from sas.subset_dataset import RandomSubsetDataset
132 | 
133 | cifar100 = torchvision.datasets.CIFAR100("/data/cifar100/", transform=transforms.ToTensor())
134 | subset_dataset = RandomSubsetDataset(cifar100, subset_fraction=0.2)
135 | 
136 | ```
137 | 
138 | ### Custom Subset
139 | 
140 | ```python
141 | from sas.subset_dataset import CustomSubsetDataset
142 | 
143 | cifar100 = torchvision.datasets.CIFAR100("/data/cifar100/", transform=transforms.ToTensor())
144 | subset_dataset = CustomSubsetDataset(cifar100, subset_indices=range(10000))
145 | ```
146 | 
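147 | ### Using a Subset with a DataLoader
148 | 
149 | Any of the subset datasets above can be wrapped in a standard PyTorch `DataLoader` for contrastive training. The snippet below is a minimal, illustrative sketch: the transform pipeline and loader settings are placeholders, the index file is one of the files provided in `/examples`, and `CIFAR100Augment` refers to the sample implementation in the next section.
150 | 
151 | ```python
152 | import pickle
153 | 
154 | import torch
155 | import torchvision.transforms as transforms
156 | from sas.subset_dataset import CustomSubsetDataset
157 | 
158 | # Load precomputed subset indices (any index file from /examples works here)
159 | with open("examples/cifar100/cifar100-0.2-sas-subset-indices.pkl", "rb") as f:
160 |     subset_indices = pickle.load(f)
161 | 
162 | # CIFAR100Augment (next section) returns multiple augmented views per example
163 | augment = transforms.Compose([
164 |     transforms.RandomResizedCrop(32),
165 |     transforms.RandomHorizontalFlip(),
166 |     transforms.ToTensor(),
167 | ])
168 | trainset = CIFAR100Augment("/data/cifar100/", transform=augment, train=True, download=True)
169 | subset = CustomSubsetDataset(trainset, subset_indices=subset_indices)
170 | trainloader = torch.utils.data.DataLoader(subset, batch_size=512, shuffle=True, num_workers=4)
171 | ```
172 | 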
173 | ## Sample Implementation of Compatible Augmented Dataset (Required for Contrastive Learning)
174 | 
175 | ```python
176 | from typing import Callable
177 | 
178 | import torchvision
179 | from PIL import Image
180 | 
181 | class CIFAR100Augment(torchvision.datasets.CIFAR100):
182 |     def __init__(self, root: str, transform: Callable, n_augmentations: int = 2, train: bool = True, download: bool = False):
183 |         super().__init__(
184 |             root=root,
185 |             train=train,
186 |             transform=transform,
187 |             download=download
188 |         )
189 |         self.n_augmentations = n_augmentations
190 | 
191 |     def __getitem__(self, index):
192 |         """
193 |         Args:
194 |             index (int): Index
195 | 
196 |         Returns:
197 |             list: `n_augmentations` augmented views of the image at `index`.
198 |         """
199 |         img = self.data[index]
200 |         pil_img = Image.fromarray(img)
201 |         imgs = []
202 |         for _ in range(self.n_augmentations):
203 |             imgs.append(self.transform(pil_img))
204 |         return imgs
205 | ```
206 | 
207 | ## Downloading TinyImageNet
208 | 
209 | Follow the steps here to correctly download and format TinyImageNet: https://github.com/tjmoon0104/pytorch-tiny-imagenet/
210 | 
--------------------------------------------------------------------------------
/cifar100_example_subset_creation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "Ensure latest version of package is installed"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "code",
12 |    "execution_count": 1,
13 |    "metadata": {},
14 |    "outputs": [
15 |     {
16 |      "name": "stdout",
17 |      "output_type": "stream",
18 |      "text": [
19 |       "Processing ./sas-pip\n",
20 |       "  Preparing metadata (setup.py) ... 
\u001b[?25ldone\n", 21 | "\u001b[?25hRequirement already satisfied: torch in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from sas==1.0) (2.0.1)\n", 22 | "Requirement already satisfied: torchvision in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from sas==1.0) (0.15.2)\n", 23 | "Requirement already satisfied: numpy in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from sas==1.0) (1.24.3)\n", 24 | "Requirement already satisfied: fast-pytorch-kmeans in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from sas==1.0) (0.1.6)\n", 25 | "Requirement already satisfied: pynvml in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from fast-pytorch-kmeans->sas==1.0) (11.4.1)\n", 26 | "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (11.10.3.66)\n", 27 | "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (11.4.0.1)\n", 28 | "Requirement already satisfied: sympy in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (1.12)\n", 29 | "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (11.7.101)\n", 30 | "Requirement already satisfied: networkx in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (3.1)\n", 31 | "Requirement already satisfied: typing-extensions in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (4.5.0)\n", 32 | "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (10.9.0.58)\n", 33 | "Requirement already satisfied: filelock in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (3.12.0)\n", 34 | "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (11.7.91)\n", 35 | "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (11.7.99)\n", 36 | "Requirement already satisfied: jinja2 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (3.1.2)\n", 37 | "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (11.7.4.91)\n", 38 | "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (2.14.3)\n", 39 | "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (11.7.99)\n", 40 | "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (8.5.0.96)\n", 41 | "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (10.2.10.91)\n", 42 | "Requirement already satisfied: triton==2.0.0 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torch->sas==1.0) (2.0.0)\n", 43 | "Requirement already satisfied: 
setuptools in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch->sas==1.0) (67.8.0)\n", 44 | "Requirement already satisfied: wheel in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch->sas==1.0) (0.40.0)\n", 45 | "Requirement already satisfied: lit in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from triton==2.0.0->torch->sas==1.0) (16.0.5)\n", 46 | "Requirement already satisfied: cmake in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from triton==2.0.0->torch->sas==1.0) (3.26.3)\n", 47 | "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torchvision->sas==1.0) (9.5.0)\n", 48 | "Requirement already satisfied: requests in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from torchvision->sas==1.0) (2.31.0)\n", 49 | "Requirement already satisfied: MarkupSafe>=2.0 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from jinja2->torch->sas==1.0) (2.1.2)\n", 50 | "Requirement already satisfied: charset-normalizer<4,>=2 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from requests->torchvision->sas==1.0) (3.1.0)\n", 51 | "Requirement already satisfied: idna<4,>=2.5 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from requests->torchvision->sas==1.0) (3.4)\n", 52 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from requests->torchvision->sas==1.0) (2.0.2)\n", 53 | "Requirement already satisfied: certifi>=2017.4.17 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from requests->torchvision->sas==1.0) (2023.5.7)\n", 54 | "Requirement already satisfied: mpmath>=0.19 in /home/sjoshi/anaconda3/envs/clip/lib/python3.10/site-packages (from sympy->torch->sas==1.0) (1.3.0)\n", 55 | "Building wheels for collected packages: sas\n", 56 | " Building wheel for sas (setup.py) ... 
\u001b[?25ldone\n", 57 | "\u001b[?25h Created wheel for sas: filename=sas-1.0-py3-none-any.whl size=6289 sha256=6e8f8d3141702ae426b4a9635e99beaa3da3ecf8c32cedb7dcc76cad8522aca4\n", 58 | " Stored in directory: /home/sjoshi/.cache/pip/wheels/4e/07/53/a089817b38c15451794418a74eb8812ee557a2982d04e9d60a\n", 59 | "Successfully built sas\n", 60 | "Installing collected packages: sas\n", 61 | " Attempting uninstall: sas\n", 62 | " Found existing installation: sas 1.0\n", 63 | " Uninstalling sas-1.0:\n", 64 | " Successfully uninstalled sas-1.0\n", 65 | "Successfully installed sas-1.0\n", 66 | "Note: you may need to restart the kernel to use updated packages.\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "%pip install sas-pip/" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "Load Data" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 2, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "import torchvision\n", 88 | "from torchvision import transforms\n", 89 | "\n", 90 | "cifar100 = torchvision.datasets.CIFAR100(\"/data/cifar100/\", transform=transforms.ToTensor())\n", 91 | "device = \"cuda:6\"" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "Partition into approximate latent classes" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 3, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "from sas.approx_latent_classes import clip_approx\n", 108 | "from sas.subset_dataset import SASSubsetDataset\n", 109 | "import random \n", 110 | "\n", 111 | "rand_labeled_examples_indices = random.sample(range(len(cifar100)), 500)\n", 112 | "rand_labeled_examples_labels = [cifar100[i][1] for i in rand_labeled_examples_indices]\n", 113 | "\n", 114 | "partition = clip_approx(\n", 115 | " img_trainset=cifar100,\n", 116 | " labeled_example_indices=rand_labeled_examples_indices, \n", 117 | " labeled_examples_labels=rand_labeled_examples_labels,\n", 118 | " num_classes=100,\n", 119 | " device=device\n", 120 | ")" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "Load proxy model" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 4, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "from torch import nn \n", 137 | "\n", 138 | "class ProxyModel(nn.Module):\n", 139 | " def __init__(self, net, critic):\n", 140 | " super().__init__()\n", 141 | " self.net = net\n", 142 | " self.critic = critic\n", 143 | " def forward(self, x):\n", 144 | " return self.critic.project(self.net(x))" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "Determine subset" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 5, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "name": "stderr", 161 | "output_type": "stream", 162 | "text": [ 163 | "Subset Selection:: 100%|██████████| 100/100 [00:03<00:00, 29.65it/s]" 164 | ] 165 | }, 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "Subset Size: 10000\n", 171 | "Discarded 40000 examples\n" 172 | ] 173 | }, 174 | { 175 | "name": "stderr", 176 | "output_type": "stream", 177 | "text": [ 178 | "\n" 179 | ] 180 | } 181 | ], 182 | "source": [ 183 | "import torch \n", 184 | "\n", 185 | "net = torch.load(\"proxy-cifar100-resnet10-399-net.pt\")\n", 186 | "critic = torch.load(\"proxy-cifar100-resnet10-399-critic.pt\")\n", 187 | 
"proxy_model = ProxyModel(net, critic)\n", 188 | " \n", 189 | "subset_dataset = SASSubsetDataset(\n", 190 | " dataset=cifar100,\n", 191 | " subset_fraction=0.2,\n", 192 | " num_downstream_classes=100,\n", 193 | " device=device,\n", 194 | " proxy_model=proxy_model,\n", 195 | " approx_latent_class_partition=partition,\n", 196 | " verbose=True\n", 197 | ")" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "Save subset to file" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 6, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "subset_dataset.save_to_file(\"cifar100-0.2-sas-indices.pkl\")" 214 | ] 215 | } 216 | ], 217 | "metadata": { 218 | "kernelspec": { 219 | "display_name": "clip", 220 | "language": "python", 221 | "name": "python3" 222 | }, 223 | "language_info": { 224 | "codemirror_mode": { 225 | "name": "ipython", 226 | "version": 3 227 | }, 228 | "file_extension": ".py", 229 | "mimetype": "text/x-python", 230 | "name": "python", 231 | "nbconvert_exporter": "python", 232 | "pygments_lexer": "ipython3", 233 | "version": "3.10.4" 234 | }, 235 | "orig_nbformat": 4 236 | }, 237 | "nbformat": 4, 238 | "nbformat_minor": 2 239 | } 240 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import json 3 | 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | 7 | from collections import namedtuple 8 | from data_proc.augmentation import ColourDistortion 9 | from data_proc.dataset import * 10 | from resnet import * 11 | 12 | class SupportedDatasets(Enum): 13 | CIFAR10 = "cifar10" 14 | CIFAR100 = "cifar100" 15 | TINY_IMAGENET = "tiny_imagenet" 16 | IMAGENET = "imagenet" 17 | STL10 = "stl10" 18 | 19 | Datasets = namedtuple('Datasets', 'trainset testset clftrainset num_classes stem') 20 | 21 | def get_datasets(dataset: str, augment_clf_train=False, add_indices_to_data=False, num_positive=2): 22 | 23 | CACHED_MEAN_STD = { 24 | SupportedDatasets.CIFAR10.value: ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 25 | SupportedDatasets.CIFAR100.value: ((0.5071, 0.4865, 0.4409), (0.2009, 0.1984, 0.2023)), 26 | SupportedDatasets.STL10.value: ((0.4409, 0.4279, 0.3868), (0.2309, 0.2262, 0.2237)), 27 | SupportedDatasets.TINY_IMAGENET.value: ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 28 | SupportedDatasets.IMAGENET.value: ((0.485, 0.456, 0.3868), (0.2309, 0.2262, 0.2237)) 29 | } 30 | 31 | PATHS = { 32 | SupportedDatasets.CIFAR10.value: '/data/cifar10/', 33 | SupportedDatasets.CIFAR100.value: '/data/cifar100/', 34 | SupportedDatasets.STL10.value: '/data/stl10/', 35 | SupportedDatasets.TINY_IMAGENET.value: '/data/tiny_imagenet/', 36 | SupportedDatasets.IMAGENET.value: '/data/ILSVRC/' 37 | } 38 | 39 | try: 40 | with open('dataset-paths.json', 'r') as f: 41 | local_paths = json.load(f) 42 | PATHS.update(local_paths) 43 | except FileNotFoundError: 44 | pass 45 | root = PATHS[dataset] 46 | 47 | # Data 48 | if dataset == SupportedDatasets.STL10.value: 49 | img_size = 96 50 | elif dataset == SupportedDatasets.IMAGENET.value: 51 | img_size = 224 52 | elif dataset == SupportedDatasets.TINY_IMAGENET.value: 53 | img_size = 64 54 | else: 55 | img_size = 32 56 | 57 | transform_train = transforms.Compose([ 58 | transforms.RandomResizedCrop(img_size, interpolation=Image.BICUBIC), 59 | transforms.RandomHorizontalFlip(), 60 | ColourDistortion(s=0.5), 61 | 
48 |     root = PATHS[dataset]
49 | 
50 |     # Data
51 |     if dataset == SupportedDatasets.STL10.value:
52 |         img_size = 96
53 |     elif dataset == SupportedDatasets.IMAGENET.value:
54 |         img_size = 224
55 |     elif dataset == SupportedDatasets.TINY_IMAGENET.value:
56 |         img_size = 64
57 |     else:
58 |         img_size = 32
59 | 
60 |     transform_train = transforms.Compose([
61 |         transforms.RandomResizedCrop(img_size, interpolation=Image.BICUBIC),
62 |         transforms.RandomHorizontalFlip(),
63 |         ColourDistortion(s=0.5),
64 |         transforms.ToTensor(),
65 |         transforms.Normalize(*CACHED_MEAN_STD[dataset]),
66 |     ])
67 | 
68 |     if dataset == SupportedDatasets.IMAGENET.value:
69 |         transform_test = transforms.Compose([
70 |             transforms.Resize(256),
71 |             transforms.CenterCrop(224),
72 |             transforms.ToTensor(),
73 |             transforms.Normalize(*CACHED_MEAN_STD[dataset]),
74 |         ])
75 |     else:
76 |         transform_test = transforms.Compose([
77 |             transforms.ToTensor(),
78 |             transforms.Normalize(*CACHED_MEAN_STD[dataset]),
79 |         ])
80 | 
81 |     if augment_clf_train:
82 |         transform_clftrain = transforms.Compose([
83 |             transforms.RandomResizedCrop(img_size, interpolation=Image.BICUBIC),
84 |             transforms.RandomHorizontalFlip(),
85 |             transforms.ToTensor(),
86 |             transforms.Normalize(*CACHED_MEAN_STD[dataset]),
87 |         ])
88 |     else:
89 |         transform_clftrain = transform_test
90 | 
91 |     trainset = testset = clftrainset = num_classes = stem = None
92 | 
93 |     if dataset == SupportedDatasets.CIFAR100.value:
94 |         if add_indices_to_data:
95 |             dset = add_indices(torchvision.datasets.CIFAR100)
96 |         else:
97 |             dset = torchvision.datasets.CIFAR100
98 |         trainset = CIFAR100Augment(root=root, train=True, download=True, transform=transform_train, n_augmentations=num_positive)
99 |         clftrainset = dset(root=root, train=True, download=True, transform=transform_clftrain)
100 |         testset = dset(root=root, train=False, download=True, transform=transform_test)
101 |         num_classes = 100
102 |         stem = StemCIFAR
103 | 
104 |     elif dataset == SupportedDatasets.CIFAR10.value:
105 |         if add_indices_to_data:
106 |             dset = add_indices(torchvision.datasets.CIFAR10)
107 |         else:
108 |             dset = torchvision.datasets.CIFAR10
109 |         trainset = CIFAR10Augment(root=root, train=True, download=True, transform=transform_train, n_augmentations=num_positive)
110 |         clftrainset = dset(root=root, train=True, download=True, transform=transform_clftrain)
111 |         testset = dset(root=root, train=False, download=True, transform=transform_test)
112 |         num_classes = 10
113 |         stem = StemCIFAR
114 |     elif dataset == SupportedDatasets.STL10.value:
115 |         if add_indices_to_data:
116 |             dset = add_indices(torchvision.datasets.STL10)
117 |         else:
118 |             dset = torchvision.datasets.STL10
119 |         trainset = STL10Augment(root=root, split='train+unlabeled', download=True, transform=transform_train)
120 |         clftrainset = dset(root=root, split='train', download=True, transform=transform_clftrain)
121 |         testset = dset(root=root, split='test', download=True, transform=transform_test)
122 |         num_classes = 10
123 |         stem = StemSTL
124 | 
125 |     elif dataset == SupportedDatasets.TINY_IMAGENET.value:
126 |         if add_indices_to_data:
127 |             raise NotImplementedError("Not implemented for TinyImageNet")
128 |         trainset = ImageFolderAugment(root=f"{root}train/", transform=transform_train, n_augmentations=num_positive)
129 |         clftrainset = ImageFolder(root=f"{root}train/", transform=transform_clftrain)
130 |         testset = ImageFolder(root=f"{root}test/", transform=transform_test)
131 |         num_classes = 200
132 |         stem = StemCIFAR
133 | 
134 |     elif dataset == SupportedDatasets.IMAGENET.value:
135 |         if add_indices_to_data:
136 |             raise NotImplementedError("Not implemented for ImageNet")
137 |         trainset = ImageNetAugment(root=f"{root}train_full/", transform=transform_train, n_augmentations=num_positive)
138 |         clftrainset = ImageNet(root=f"{root}train_full/", transform=transform_clftrain)
139 |         testset = ImageNet(root=f"{root}test/", transform=transform_test)
140 |         num_classes = 1000
141 |         stem = StemImageNet
142 | 
143 |     return Datasets(trainset=trainset, testset=testset, clftrainset=clftrainset, num_classes=num_classes, stem=stem)
--------------------------------------------------------------------------------
/data_proc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/data_proc/__init__.py
--------------------------------------------------------------------------------
/data_proc/augmentation.py:
--------------------------------------------------------------------------------
1 | from PIL import ImageFilter
2 | from torchvision import transforms
3 | 
4 | 
5 | def ColourDistortion(s=1.0):
6 |     # s is the strength of color distortion.
7 |     color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
8 |     rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
9 |     rnd_gray = transforms.RandomGrayscale(p=0.2)
10 |     color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
11 |     return color_distort
12 | 
13 | 
14 | def BlurOrSharpen(radius=2.):
15 |     blur = GaussianBlur(radius=radius)
16 |     full_transform = transforms.RandomApply([blur], p=.5)
17 |     return full_transform
18 | 
19 | 
20 | class ImageFilterTransform(object):
21 | 
22 |     def __init__(self):
23 |         raise NotImplementedError
24 | 
25 |     def __call__(self, img):
26 |         return img.filter(self.filter)
27 | 
28 | 
29 | class GaussianBlur(ImageFilterTransform):
30 | 
31 |     def __init__(self, radius=2.):
32 |         self.filter = ImageFilter.GaussianBlur(radius=radius)
33 | 
34 | 
35 | class Sharpen(ImageFilterTransform):
36 | 
37 |     def __init__(self):
38 |         self.filter = ImageFilter.SHARPEN
--------------------------------------------------------------------------------
/data_proc/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Callable, Optional
3 | 
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | import torchvision
8 | from PIL import Image
9 | from torchvision.datasets import ImageFolder
10 | 
11 | class CIFAR10Augment(torchvision.datasets.CIFAR10):
12 |     def __init__(self, root: str, transform: Callable, n_augmentations: int = 2, train: bool = True, download: bool = False):
13 |         super().__init__(
14 |             root=root,
15 |             train=train,
16 |             transform=transform,
17 |             download=download
18 |         )
19 |         self.n_augmentations = n_augmentations
20 | 
21 |     def __getitem__(self, index):
22 |         """
23 |         Args:
24 |             index (int): Index
25 | 
26 |         Returns:
27 |             List of augmented views of element at index
28 |         """
29 |         img = self.data[index]
30 |         pil_img = Image.fromarray(img)
31 |         imgs = []
32 |         for _ in range(self.n_augmentations):
33 |             imgs.append(self.transform(pil_img))
34 |         return imgs
35 | 
36 | class STL10Augment(torchvision.datasets.STL10):
37 |     def __init__(
38 |         self,
39 |         root: str,
40 |         split: str,
41 |         transform: Callable,
42 |         n_augmentations: int = 2,
43 |         download: bool = False,
44 |     ) -> None:
45 |         super().__init__(
46 |             root=root,
47 |             split=split,
48 |             transform=transform,
49 |             download=download)
50 |         self.n_augmentations = n_augmentations
51 | 
52 |     def __getitem__(self, index):
53 |         """
54 |         Args:
55 |             index (int): Index
56 | 
57 |         Returns:
58 |             list: `n_augmentations` augmented views of the image at `index`.
59 |         """
60 |         img = self.data[index]
61 | 
62 |         # doing this so that it is consistent with all other datasets
63 |         # to return a PIL Image
64 |         pil_img = Image.fromarray(np.transpose(img, (1, 2, 0)))
65 |         imgs = []
66 |         for _ in range(self.n_augmentations):
67 |             imgs.append(self.transform(pil_img))
68 |         return imgs
69 | 
70 | 
71 | class CIFAR100Augment(CIFAR10Augment):
72 |     """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
73 | 
74 |     This is a subclass of the `CIFAR10Augment` Dataset.
75 |     """
76 |     base_folder = 'cifar-100-python'
77 |     url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
78 |     filename = "cifar-100-python.tar.gz"
79 |     tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
80 |     train_list = [
81 |         ['train', '16019d7e3df5f24257cddd939b257f8d'],
82 |     ]
83 | 
84 |     test_list = [
85 |         ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
86 |     ]
87 |     meta = {
88 |         'filename': 'meta',
89 |         'key': 'fine_label_names',
90 |         'md5': '7973b15100ade9c7d40fb424638fde48',
91 |     }
92 | 
93 | 
94 | class ImageFolderAugment(ImageFolder):
95 |     def __init__(self, root: str, transform: Callable, n_augmentations: int = 2):
96 |         super().__init__(
97 |             root=root,
98 |             transform=transform,
99 |         )
100 |         self.n_augmentations = n_augmentations
101 | 
102 |     def __getitem__(self, index):
103 |         """
104 |         Args:
105 |             index (int): Index
106 |         Returns:
107 |             list: `n_augmentations` augmented views of the sample at `index`.
108 |         """
109 |         path, _ = self.samples[index]
110 |         pil_img = self.loader(path)
111 |         imgs = []
112 |         for _ in range(self.n_augmentations):
113 |             imgs.append(self.transform(pil_img))
114 |         return imgs
115 | 
116 | def add_indices(dataset_cls):
117 |     class NewClass(dataset_cls):
118 |         def __getitem__(self, item):
119 |             output = super(NewClass, self).__getitem__(item)
120 |             return (*output, item)
121 | 
122 |     return NewClass
123 | 
124 | class ImageNet(torch.utils.data.Dataset):
125 |     def __init__(self, root, transform=None):
126 |         self.root = root
127 |         df = pd.read_csv(os.path.join(root, "labels.csv"), on_bad_lines='skip')
128 |         self.images = df["image"]
129 |         self.labels = df["label"]
130 |         self.transform = transform
131 | 
132 |     def __len__(self):
133 |         return len(self.labels)
134 | 
135 |     def __getitem__(self, idx):
136 |         image = Image.open(os.path.join(self.root, self.images[idx])).convert('RGB')
137 |         if self.transform is not None:
138 |             image = self.transform(image)
139 |         label = self.labels[idx]
140 |         return image, label
141 | 
142 | 
143 | class ImageNetAugment(torch.utils.data.Dataset):
144 |     def __init__(self, root, transform, n_augmentations=2):
145 |         self.root = root
146 |         self.transform = transform
147 |         self.n_augmentations = n_augmentations
148 |         df = pd.read_csv(os.path.join(root, "labels.csv"), on_bad_lines='skip')
149 |         self.images = df["image"]
150 |         self.labels = df["label"]
151 | 
152 |     def __len__(self):
153 |         return len(self.images)
154 | 
155 |     def __getitem__(self, idx):
156 |         pil_img = Image.open(os.path.join(self.root, self.images[idx])).convert('RGB')
157 |         imgs = []
158 |         for _ in range(self.n_augmentations):
159 |             imgs.append(self.transform(pil_img))
160 |         return imgs
--------------------------------------------------------------------------------
/evaluate/__init__.py:
--------------------------------------------------------------------------------
1 | from .checkpoint import *
2 | from .lbfgs import *
3 | 
--------------------------------------------------------------------------------
/evaluate/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | 
5 | 
6 | def save_checkpoint(net, clf, critic, epoch, args, file_name):
7 |     # Save checkpoint.
8 |     print('Saving..')
9 |     state = {
10 |         'net': net.state_dict(),
11 |         'clf': clf.state_dict(),
12 |         'critic': critic.state_dict(),
13 |         'epoch': epoch,
14 |         'args': vars(args)
15 |     }
16 |     # os.path.isdir/os.mkdir do not expand '~', so expand it explicitly
17 |     checkpoint_dir = os.path.expanduser('~/efficient-contrastive-learning/checkpoint')
18 |     if not os.path.isdir(checkpoint_dir):
19 |         os.makedirs(checkpoint_dir, exist_ok=True)
20 |     destination = os.path.join(checkpoint_dir, f"{file_name}.pth")
21 |     torch.save(state, destination)
--------------------------------------------------------------------------------
/evaluate/lbfgs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from tqdm import tqdm
7 | 
8 | 
9 | def encode_train_set(clftrainloader, device, net):
10 |     net.eval()
11 | 
12 |     store = []
13 |     with torch.no_grad():
14 |         t = tqdm(enumerate(clftrainloader), desc='Encoded: **/** ', total=len(clftrainloader),
15 |                  bar_format='{desc}{bar}{r_bar}')
16 |         for batch_idx, (inputs, targets) in t:
17 |             inputs, targets = inputs.to(device), targets.to(device)
18 |             representation = net(inputs)
19 |             store.append((representation, targets))
20 | 
21 |             t.set_description('Encoded %d/%d' % (batch_idx, len(clftrainloader)))
22 | 
23 |     X, y = zip(*store)
24 |     X, y = torch.cat(X, dim=0), torch.cat(y, dim=0)
25 |     return X, y
26 | 
27 | def encode_train_set_w_augmentations(trainloader, device, net, critic=None, num_pos = 1):
28 |     net.eval()
29 | 
30 |     X = []
31 |     for _ in range(num_pos):
32 |         X.append([])
33 |     y = []
34 | 
35 |     with torch.no_grad():
36 |         t = tqdm(enumerate(trainloader), desc='Encoded: **/** ', total=len(trainloader),
37 |                  bar_format='{desc}{bar}{r_bar}')
38 |         for batch_idx, (input, targets, _) in t:
39 |             targets = targets.to(device)
40 |             for i in range(num_pos):
41 |                 x = None
42 |                 if num_pos > 2:
43 |                     x = input[:, i, :, :, :].to(device)
44 |                 else:
45 |                     x = input[i].to(device)
46 |                 if critic is not None:
47 |                     X[i].append(critic.project(net(x)))
48 |                 else:
49 |                     X[i].append(net(x))
50 |             y.append(targets)
51 |             t.set_description('Encoded %d/%d' % (batch_idx, len(trainloader)))
52 | 
53 |     y = torch.cat(y, dim=0)
54 |     for i, X_i in enumerate(X):
55 |         X[i] = torch.cat(X_i, dim=0)
56 |     return X, y
57 | 
58 | def encode_train_set_projection(trainloader, device, net, critic):
59 |     net.eval()
60 |     critic.eval()
61 |     store = []
62 |     with torch.no_grad():
63 |         t = tqdm(enumerate(trainloader), desc='Encoded: **/** ', total=len(trainloader),
64 |                  bar_format='{desc}{bar}{r_bar}')
65 |         for batch_idx, train_tuple in t:
66 |             x1 = train_tuple[0]
67 |             x1 = x1.to(device)
68 |             representation = critic.project(net(x1))
69 |             store.append(representation)
70 |             t.set_description('Encoded Projections %d/%d' % (batch_idx, len(trainloader)))
71 |     X = torch.cat(store, dim=0)
72 |     return X
73 | 
74 | def train_clf(X, y, representation_dim, num_classes, device, reg_weight=1e-3, iter=500):
75 |     print('\nL2 Regularization weight: %g' % reg_weight)
76 | 
77 |     criterion = nn.CrossEntropyLoss()
78 |     n_lbfgs_steps = iter
79 | 
80 |     # Should be reset after each epoch for a completely independent evaluation
81 |     clf = nn.Linear(representation_dim, num_classes).to(device)
82 |     clf_optimizer = optim.LBFGS(clf.parameters())
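83 |     # Note: torch.optim.LBFGS re-evaluates the objective multiple times per
84 |     # step, so the forward/backward pass must be wrapped in a closure that is
85 |     # handed to clf_optimizer.step(closure) below.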
86 |     clf.train()
87 | 
88 |     t = tqdm(range(n_lbfgs_steps), desc='Loss: **** | Train Acc: ****% ', bar_format='{desc}{bar}{r_bar}')
89 |     for _ in t:
90 |         def closure():
91 |             clf_optimizer.zero_grad()
92 |             raw_scores = clf(X)
93 |             loss = criterion(raw_scores, y)
94 |             loss += reg_weight * clf.weight.pow(2).sum()
95 |             loss.backward()
96 | 
97 |             _, predicted = raw_scores.max(1)
98 |             correct = predicted.eq(y).sum().item()
99 | 
100 |             t.set_description('Loss: %.3f | Train Acc: %.3f%% ' % (loss, 100. * correct / y.shape[0]))
101 | 
102 |             return loss
103 | 
104 |         clf_optimizer.step(closure)
105 | 
106 |     return clf
107 | 
108 | 
109 | def test_clf(testloader, device, net, clf):
110 |     criterion = nn.CrossEntropyLoss()
111 |     net.eval()
112 |     clf.eval()
113 |     test_clf_loss = 0
114 |     correct = 0
115 |     top5_correct = 0
116 |     total = 0
117 |     acc_per_point = []
118 |     with torch.no_grad():
119 |         t = tqdm(enumerate(testloader), total=len(testloader), desc='Loss: **** | Test Acc: ****% ',
120 |                  bar_format='{desc}{bar}{r_bar}')
121 |         for batch_idx, (inputs, targets) in t:
122 |             inputs, targets = inputs.to(device), targets.to(device)
123 |             representation = net(inputs)
124 |             raw_scores = clf(representation)
125 |             clf_loss = criterion(raw_scores, targets)
126 |             test_clf_loss += clf_loss.item()
127 |             _, predicted = raw_scores.max(1)
128 |             total += targets.size(0)
129 |             acc_per_point.append(predicted.eq(targets))
130 |             correct += acc_per_point[-1].sum().item()
131 |             # Track top-5 accuracy as well: callers unpack (acc, top5_acc)
132 |             _, top5_pred = raw_scores.topk(min(5, raw_scores.size(1)), dim=1)
133 |             top5_correct += top5_pred.eq(targets.view(-1, 1)).any(dim=1).sum().item()
134 |             t.set_description('Loss: %.3f | Test Acc: %.3f%% ' % (test_clf_loss / (batch_idx + 1), 100. * correct / total))
135 | 
136 |     acc = 100. * correct / total
137 |     top5_acc = 100. * top5_correct / total
138 |     return acc, top5_acc
139 | 
140 | def top5accuracy(output, target, topk=(5,)):
141 |     """Computes the accuracy over the k top predictions for the specified values of k"""
142 |     with torch.no_grad():
143 |         maxk = max(topk)
144 |         batch_size = target.size(0)
145 | 
146 |         _, pred = output.topk(maxk, 1, True, True)
147 |         pred = pred.t()
148 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
149 | 
150 |         res = []
151 |         for k in topk:
152 |             correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
153 |             res.append(correct_k.mul_(100.0 / batch_size).item())
154 |         return res
--------------------------------------------------------------------------------
/examples/cifar10/cifar10-0.05-sas-subset-indices.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/cifar10/cifar10-0.05-sas-subset-indices.pkl
--------------------------------------------------------------------------------
/examples/cifar10/cifar10-0.10-sas-subset-indices.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/cifar10/cifar10-0.10-sas-subset-indices.pkl
--------------------------------------------------------------------------------
/examples/cifar10/cifar10-0.15-sas-subset-indices.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/cifar10/cifar10-0.15-sas-subset-indices.pkl
--------------------------------------------------------------------------------
/examples/cifar10/cifar10-0.20-sas-subset-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/cifar10/cifar10-0.20-sas-subset-indices.pkl -------------------------------------------------------------------------------- /examples/cifar100/cifar100-0.2-sas-subset-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/cifar100/cifar100-0.2-sas-subset-indices.pkl -------------------------------------------------------------------------------- /examples/cifar100/cifar100-0.4-sas-subset-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/cifar100/cifar100-0.4-sas-subset-indices.pkl -------------------------------------------------------------------------------- /examples/cifar100/cifar100-0.6-sas-subset-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/cifar100/cifar100-0.6-sas-subset-indices.pkl -------------------------------------------------------------------------------- /examples/cifar100/cifar100-0.8-sas-subset-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/cifar100/cifar100-0.8-sas-subset-indices.pkl -------------------------------------------------------------------------------- /examples/imagenet/imagenet-0.2-rand-balanced-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/imagenet/imagenet-0.2-rand-balanced-indices.pkl -------------------------------------------------------------------------------- /examples/imagenet/imagenet-0.2-sas-subset-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/imagenet/imagenet-0.2-sas-subset-indices.pkl -------------------------------------------------------------------------------- /examples/imagenet/imagenet-0.6-random-balanced-idx.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/imagenet/imagenet-0.6-random-balanced-idx.pkl -------------------------------------------------------------------------------- /examples/imagenet/imagenet-0.6-sas-subset-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/imagenet/imagenet-0.6-sas-subset-indices.pkl 
-------------------------------------------------------------------------------- /examples/stl10/stl10-0.2-sas-subset-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/stl10/stl10-0.2-sas-subset-indices.pkl -------------------------------------------------------------------------------- /examples/stl10/stl10-0.4-sas-subset-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/stl10/stl10-0.4-sas-subset-indices.pkl -------------------------------------------------------------------------------- /examples/stl10/stl10-0.6-sas-subset-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/stl10/stl10-0.6-sas-subset-indices.pkl -------------------------------------------------------------------------------- /examples/stl10/stl10-0.8-sas-subset-indices.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/examples/stl10/stl10-0.8-sas-subset-indices.pkl -------------------------------------------------------------------------------- /final_subsets/cifar10-cl-core-idx.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/final_subsets/cifar10-cl-core-idx.pkl -------------------------------------------------------------------------------- /final_subsets/cifar10-rand-balanced-idx.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/final_subsets/cifar10-rand-balanced-idx.pkl -------------------------------------------------------------------------------- /final_subsets/cifar100-cl-core-idx.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/final_subsets/cifar100-cl-core-idx.pkl -------------------------------------------------------------------------------- /final_subsets/cifar100-rand-balanced-idx.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/final_subsets/cifar100-rand-balanced-idx.pkl -------------------------------------------------------------------------------- /final_subsets/stl10-cl-core-idx.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/final_subsets/stl10-cl-core-idx.pkl -------------------------------------------------------------------------------- /final_subsets/stl10-rand-balanced-idx.pkl: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/final_subsets/stl10-rand-balanced-idx.pkl
--------------------------------------------------------------------------------
/linear_probe.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | 
5 | import numpy as np
6 | from sas.subset_dataset import CustomSubsetDataset
7 | import torch
8 | import torch.multiprocessing as mp
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | from torch.distributed import destroy_process_group, init_process_group
12 | from torch.nn.parallel import DistributedDataParallel as DDP
13 | from torch.utils.data.distributed import DistributedSampler
14 | from tqdm import tqdm
15 | 
16 | import wandb
17 | from configs import SupportedDatasets, get_datasets
18 | from evaluate.lbfgs import test_clf
19 | from resnet import *
20 | from util import Random
21 | 
22 | 
23 | def main(rank: int, world_size: int, args: argparse.Namespace):
24 |     # Determine Device
25 |     device = rank
26 |     if args.distributed:
27 |         device = args.device_ids[rank]
28 |         torch.cuda.set_device(args.device_ids[rank])
29 |         args.lr *= world_size
30 | 
31 |     # WandB Logging
32 |     if not args.distributed or rank == 0:
33 |         wandb.init(
34 |             project="data-efficient-contrastive-learning-linear-probe",
35 |             config=args
36 |         )
37 | 
38 |     if args.distributed:
39 |         ddp_setup(rank, world_size, str(args.port))
40 | 
41 |     # Set all seeds
42 |     torch.manual_seed(args.seed)
43 |     np.random.seed(args.seed)
44 |     Random(args.seed)
45 | 
46 |     print('==> Preparing data..')
47 |     datasets = get_datasets(args.dataset)
48 | 
49 |     testloader = torch.utils.data.DataLoader(
50 |         dataset=CustomSubsetDataset(datasets.testset, subset_indices=range(1000)),
51 |         batch_size=args.batch_size,
52 |         shuffle=False,
53 |         num_workers=4,
54 |         pin_memory=True
55 |     )
56 |     clftrainloader = torch.utils.data.DataLoader(
57 |         dataset=datasets.clftrainset,
58 |         batch_size=args.batch_size,
59 |         shuffle=not args.distributed,
60 |         sampler=DistributedSampler(datasets.clftrainset, shuffle=True) if args.distributed else None,
61 |         num_workers=4,
62 |         pin_memory=True
63 |     )
64 | 
65 |     ##############################################################
66 |     # Model and Optimizer
67 |     ##############################################################
68 | 
69 |     net = torch.load(args.encoder).to(device)
70 | 
71 |     clf = nn.Linear(net.representation_dim, datasets.num_classes).to(device)
72 |     if args.distributed:
73 |         clf = DDP(clf, device_ids=[device])
74 | 
75 |     criterion = nn.CrossEntropyLoss()
76 |     clf_optimizer = optim.SGD(clf.parameters(), lr=args.lr, momentum=args.momentum, nesterov=args.nesterov,
77 |                               weight_decay=args.weight_decay)
78 | 
79 |     ##############################################################
80 |     # Train Function
81 |     ##############################################################
82 | 
83 |     def train_clf(epoch):
84 |         print('\nEpoch %d' % epoch)
85 |         net.eval()
86 |         clf.train()
87 |         train_loss = 0
88 |         t = tqdm(enumerate(clftrainloader), desc='Loss: **** ', total=len(clftrainloader), bar_format='{desc}{bar}{r_bar}')
89 |         for batch_idx, (inputs, targets) in t:
90 |             clf_optimizer.zero_grad()
91 |             inputs, targets = inputs.to(device), targets.to(device)
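92 |             # Detach the encoder output: gradients flow only into the linear
93 |             # classifier, so the pretrained representation stays frozen.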
94 |             representation = net(inputs).detach()
95 |             predictions = clf(representation)
96 |             loss = criterion(predictions, targets)
97 |             loss.backward()
98 |             clf_optimizer.step()
99 | 
100 |             train_loss += loss.item()
101 | 
102 |             t.set_description('Loss: %.3f ' % (train_loss / (batch_idx + 1)))
103 |         return train_loss
104 | 
105 |     ##############################################################
106 |     # Main Loop
107 |     ##############################################################
108 |     best_acc = 0
109 |     for epoch in range(args.num_epochs):
110 |         train_loss = train_clf(epoch)
111 |         if not args.distributed or rank == 0:
112 |             acc, top5acc = test_clf(testloader, device, net, clf)
113 |             wandb.log(
114 |                 {
115 |                     "test":
116 |                     {
117 |                         "acc": acc,
118 |                         "top5acc": top5acc
119 |                     },
120 |                     "train":
121 |                     {
122 |                         "loss": train_loss
123 |                     }
124 |                 },
125 |                 step=epoch
126 |             )
127 |             if acc > best_acc:
128 |                 best_acc = acc
129 | 
130 |     if not args.distributed or rank == 0:
131 |         print("Best test accuracy", best_acc, "%")
132 |         wandb.log(
133 |             {
134 |                 "test":
135 |                 {
136 |                     "best_acc": best_acc
137 |                 }
138 |             }
139 |         )
140 | 
141 |     if args.distributed:
142 |         destroy_process_group()
143 | 
144 | ##############################################################
145 | # Distributed Training Setup
146 | ##############################################################
147 | def ddp_setup(rank: int, world_size: int, port: str):
148 |     os.environ["MASTER_ADDR"] = "localhost"
149 |     os.environ["MASTER_PORT"] = port
150 |     init_process_group(backend="nccl", rank=rank, world_size=world_size)
151 | 
152 | ##############################################################
153 | # Script Entry Point
154 | ##############################################################
155 | if __name__ == "__main__":
156 |     parser = argparse.ArgumentParser(description='PyTorch Linear Probe')
157 |     parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
158 |     parser.add_argument("--momentum", default=0.9, type=float, help='SGD momentum')
159 |     parser.add_argument("--batch-size", type=int, default=512, help='Training batch size')
160 |     parser.add_argument("--num-epochs", type=int, default=2, help='Number of training epochs')
161 |     parser.add_argument("--weight-decay", type=float, default=1e-6, help='Weight decay on the linear classifier')
162 |     parser.add_argument("--nesterov", action="store_true", help="Turn on Nesterov style momentum")
163 |     parser.add_argument("--encoder", type=str, default='ckpt.pth', help='Pretrained encoder')
164 |     parser.add_argument('--temperature', type=float, default=0.5, help='InfoNCE temperature')
165 |     parser.add_argument('--dataset', type=str, default=str(SupportedDatasets.CIFAR100.value), help='dataset',
166 |                         choices=[x.value for x in SupportedDatasets])
167 |     parser.add_argument('--device', type=int, default=-1, help="GPU number to use")
168 |     parser.add_argument("--device-ids", nargs = "+", default = None, help = "Specify device ids if using multiple gpus")
169 |     parser.add_argument('--port', type=int, default=random.randint(49152, 65535), help="free port to use")
170 |     parser.add_argument('--seed', type=int, default=0, help="Seed for randomness")
171 | 
172 |     # Parse arguments
173 |     args = parser.parse_args()
174 | 
175 |     # Arguments check and initialize global variables
176 |     device = "cpu"
177 |     device_ids = None
178 |     distributed = False
179 |     if torch.cuda.is_available():
180 |         if args.device_ids is None:
181 |             if args.device >= 0:
182 |                 device = args.device
183 |             else:
184 |                 device = 0
185 |         else:
186 |             distributed = True
187 |             device_ids = [int(id) for id in args.device_ids]
188 |     args.device = device
189 |     args.device_ids = device_ids
190 |     args.distributed = distributed
191 |     if distributed:
192 |         mp.spawn(
193 |             fn=main,
194 |             args=(len(device_ids), args),
195 |             nprocs=len(device_ids)
196 |         )
197 |     else:
198 |         main(device, 1, args)
--------------------------------------------------------------------------------
/projection_heads/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/projection_heads/__init__.py
--------------------------------------------------------------------------------
/projection_heads/critic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from util import Random
4 | 
5 | class LinearCritic(nn.Module):
6 | 
7 |     def __init__(self, latent_dim, temperature=1., num_negatives = -1):
8 |         super(LinearCritic, self).__init__()
9 |         self.temperature = temperature
10 |         self.projection_dim = 128
11 |         self.w1 = nn.Linear(latent_dim, latent_dim, bias=False)
12 |         self.bn1 = nn.BatchNorm1d(latent_dim)
13 |         self.relu = nn.ReLU()
14 |         self.w2 = nn.Linear(latent_dim, self.projection_dim, bias=False)
15 |         self.bn2 = nn.BatchNorm1d(self.projection_dim, affine=False)
16 |         self.cossim = nn.CosineSimilarity(dim=-1)
17 |         self.num_negatives = num_negatives
18 | 
19 |     def project(self, h):
20 |         return self.bn2(self.w2(self.relu(self.bn1(self.w1(h)))))
21 | 
22 |     def compute_sim(self, z):
23 |         # Project each view through the projection head
24 |         p = []
25 |         for i in range(len(z)):
26 |             p.append(self.project(z[i]))
27 | 
28 |         # Temperature-scaled pairwise cosine similarities between all views
29 |         sim = {}
30 |         for i in range(len(p)):
31 |             for j in range(i, len(p)):
32 |                 sim[(i, j)] = self.cossim(p[i].unsqueeze(-2), p[j].unsqueeze(-3)) / self.temperature
33 | 
34 |         # Mask self-similarities so an example is never scored against itself
35 |         d = sim[(0,1)].shape[-1]
36 |         for i in range(len(p)):
37 |             sim[(i,i)][..., range(d), range(d)] = float('-inf')
38 | 
39 |         # Assemble the blocks into one full similarity matrix over all views
40 |         for i in range(len(p)):
41 |             sim[i] = torch.cat([sim[(j, i)].transpose(-1, -2) for j in range(0, i)] + [sim[(i, j)] for j in range(i, len(p))], dim=-1)
42 |         sim = torch.cat([sim[i] for i in range(len(p))], dim=-2)
43 | 
44 |         return sim
45 | 
46 |     def forward(self, z):
47 |         return self.compute_sim(z)
--------------------------------------------------------------------------------
/projection_heads/gradient_linear_clf.py:
--------------------------------------------------------------------------------
1 | '''This script trains the downstream classifier using gradients (for large datasets).'''
2 | import argparse
3 | import os
4 | 
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from tqdm import tqdm
10 | 
11 | from configs import get_datasets
12 | from evaluate.lbfgs import test_clf
13 | from resnet import *
14 | 
15 | parser = argparse.ArgumentParser(description='Train downstream classifier with gradients.')
16 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
17 | parser.add_argument("--momentum", default=0.9, type=float, help='SGD momentum')
18 | parser.add_argument("--batch-size", type=int, default=512, help='Training batch size')
19 | parser.add_argument("--num-epochs", type=int, default=90, help='Number of training epochs')
20 | parser.add_argument("--num-workers", type=int, default=2, help='Number of threads for data loaders')
21 | parser.add_argument("--weight-decay", type=float, default=1e-6, 
help='Weight decay on the linear classifier') 22 | parser.add_argument("--nesterov", action="store_true", help="Turn on Nesterov style momentum") 23 | parser.add_argument("--load-from", type=str, default='ckpt.pth', help='File to load from') 24 | args = parser.parse_args() 25 | 26 | # Load checkpoint. 27 | print('==> Loading settings from checkpoint..') 28 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 29 | resume_from = os.path.join('./checkpoint', args.load_from) 30 | checkpoint = torch.load(resume_from) 31 | args.dataset = checkpoint['args']['dataset'] 32 | args.arch = checkpoint['args']['arch'] 33 | 34 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 35 | best_acc = 0 36 | 37 | # Data 38 | print('==> Preparing data..') 39 | _, testset, clftrainset, num_classes, stem = get_datasets(args.dataset, augment_clf_train=True) 40 | 41 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, 42 | num_workers=args.num_workers, pin_memory=True) 43 | clftrainloader = torch.utils.data.DataLoader(clftrainset, batch_size=args.batch_size, shuffle=True, 44 | num_workers=args.num_workers, pin_memory=True) 45 | 46 | # Model 47 | print('==> Building model..') 48 | ############################################################## 49 | # Encoder 50 | ############################################################## 51 | if args.arch == 'resnet18': 52 | net = ResNet18(stem=stem) 53 | elif args.arch == 'resnet34': 54 | net = ResNet34(stem=stem) 55 | elif args.arch == 'resnet50': 56 | net = ResNet50(stem=stem) 57 | else: 58 | raise ValueError("Bad architecture specification") 59 | net = net.to(device) 60 | 61 | ############################################################## 62 | # Classifier 63 | ############################################################## 64 | clf = nn.Linear(net.representation_dim, num_classes).to(device) 65 | 66 | if device == 'cuda': 67 | repr_dim = net.representation_dim 68 | net = torch.nn.DataParallel(net) 69 | net.representation_dim = repr_dim 70 | cudnn.benchmark = True 71 | 72 | print('==> Loading encoder from checkpoint..') 73 | net.load_state_dict(checkpoint['net']) 74 | 75 | criterion = nn.CrossEntropyLoss() 76 | clf_optimizer = optim.SGD(clf.parameters(), lr=args.lr, momentum=args.momentum, nesterov=args.nesterov, 77 | weight_decay=args.weight_decay) 78 | 79 | 80 | def train_clf(epoch): 81 | print('\nEpoch %d' % epoch) 82 | net.eval() 83 | clf.train() 84 | train_loss = 0 85 | t = tqdm(enumerate(clftrainloader), desc='Loss: **** ', total=len(clftrainloader), bar_format='{desc}{bar}{r_bar}') 86 | for batch_idx, (inputs, targets) in t: 87 | clf_optimizer.zero_grad() 88 | inputs, targets = inputs.to(device), targets.to(device) 89 | representation = net(inputs).detach() 90 | predictions = clf(representation) 91 | loss = criterion(predictions, targets) 92 | loss.backward() 93 | clf_optimizer.step() 94 | 95 | train_loss += loss.item() 96 | 97 | t.set_description('Loss: %.3f ' % (train_loss / (batch_idx + 1))) 98 | 99 | 100 | for epoch in range(args.num_epochs): 101 | train_clf(epoch) 102 | acc, _ = test_clf(testloader, device, net, clf) 103 | if acc > best_acc: 104 | best_acc = acc 105 | print("Best test accuracy", best_acc, "%") 106 | -------------------------------------------------------------------------------- /projection_heads/lbfgs_linear_clf.py: -------------------------------------------------------------------------------- 1 | '''This script tunes the L2 reg weight of the final classifier.''' 2 | import 
argparse 3 | import os 4 | import math 5 | 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | 9 | from configs import get_datasets 10 | from evaluate import encode_train_set, train_clf, test_clf 11 | from resnet import * 12 | 13 | parser = argparse.ArgumentParser(description='Tune regularization coefficient of downstream classifier.') 14 | parser.add_argument("--num-workers", type=int, default=2, help='Number of threads for data loaders') 15 | parser.add_argument("--load-from", type=str, default='ckpt.pth', help='File to load from') 16 | args = parser.parse_args() 17 | 18 | # Load checkpoint. 19 | print('==> Loading settings from checkpoint..') 20 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 21 | resume_from = os.path.join('./checkpoint', args.load_from) 22 | checkpoint = torch.load(resume_from) 23 | args.dataset = checkpoint['args']['dataset'] 24 | args.arch = checkpoint['args']['arch'] 25 | 26 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 27 | 28 | # Data 29 | print('==> Preparing data..') 30 | _, testset, clftrainset, num_classes, stem = get_datasets(args.dataset) 31 | 32 | testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False, num_workers=args.num_workers, 33 | pin_memory=True) 34 | clftrainloader = torch.utils.data.DataLoader(clftrainset, batch_size=1000, shuffle=False, num_workers=args.num_workers, 35 | pin_memory=True) 36 | 37 | # Model 38 | print('==> Building model..') 39 | ############################################################## 40 | # Encoder 41 | ############################################################## 42 | if args.arch == 'resnet18': 43 | net = ResNet18(stem=stem) 44 | elif args.arch == 'resnet34': 45 | net = ResNet34(stem=stem) 46 | elif args.arch == 'resnet50': 47 | net = ResNet50(stem=stem) 48 | else: 49 | raise ValueError("Bad architecture specification") 50 | net = net.to(device) 51 | 52 | if device == 'cuda': 53 | repr_dim = net.representation_dim 54 | net = torch.nn.DataParallel(net) 55 | net.representation_dim = repr_dim 56 | cudnn.benchmark = True 57 | 58 | print('==> Loading encoder from checkpoint..') 59 | net.load_state_dict(checkpoint['net']) 60 | 61 | 62 | best_acc = 0 63 | X, y = encode_train_set(clftrainloader, device, net) 64 | for reg_weight in torch.exp(math.log(10) * torch.linspace(-7, -3, 16, dtype=torch.float, device=device)): 65 | clf = train_clf(X, y, net.representation_dim, num_classes, device, reg_weight=reg_weight) 66 | acc, _ = test_clf(testloader, device, net, clf) 67 | if acc > best_acc: 68 | best_acc = acc 69 | print("Best test accuracy", best_acc, "%") 70 | -------------------------------------------------------------------------------- /proxy-cifar100-resnet10-399-critic.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/proxy-cifar100-resnet10-399-critic.pt -------------------------------------------------------------------------------- /proxy-cifar100-resnet10-399-net.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/proxy-cifar100-resnet10-399-net.pt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 |
asposestorage==1.0.2 2 | clip==1.0 3 | fast_pytorch_kmeans==0.1.6 4 | numpy==1.23.1 5 | pandas==1.5.1 6 | Pillow==9.2.0 7 | setuptools==63.4.1 8 | torch==1.12.1 9 | torchvision==0.13.1 10 | tqdm==4.64.1 11 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 13 | 14 | 15 | class StemCIFAR(nn.Module): 16 | def __init__(self): 17 | super(StemCIFAR, self).__init__() 18 | self.in_planes = 64 19 | self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(self.in_planes) 21 | 22 | def forward(self, inputs): 23 | return F.relu(self.bn1(self.conv1(inputs))) 24 | 25 | class StemSTL(StemCIFAR): 26 | def __init__(self): 27 | super(StemSTL, self).__init__() 28 | self.maxpool = nn.MaxPool2d(kernel_size=3) 29 | 30 | def forward(self, inputs): 31 | out = F.relu(self.bn1(self.conv1(inputs))) 32 | out = self.maxpool(out) 33 | return out 34 | 35 | class StemTinyImageNet(StemCIFAR): 36 | def __init__(self): 37 | super(StemTinyImageNet, self).__init__() 38 | self.maxpool = nn.MaxPool2d(kernel_size=2) 39 | 40 | def forward(self, inputs): 41 | out = F.relu(self.bn1(self.conv1(inputs))) 42 | out = self.maxpool(out) 43 | return out 44 | 45 | class StemImageNet(nn.Module): 46 | def __init__(self): 47 | super(StemImageNet, self).__init__() 48 | self.inplanes = 64 49 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 50 | self.bn1 = nn.BatchNorm2d(self.inplanes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = self.bn1(x) 57 | x = self.relu(x) 58 | x = self.maxpool(x) 59 | return x 60 | 61 | class ResNet(nn.Module): 62 | 63 | def __init__(self, block, layers, num_classes=None, zero_init_residual=False, 64 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 65 | norm_layer=None, stem=StemCIFAR): 66 | super(ResNet, self).__init__() 67 | if norm_layer is None: 68 | norm_layer = nn.BatchNorm2d 69 | self._norm_layer = norm_layer 70 | 71 | self.inplanes = 64 72 | self.dilation = 1 73 | if replace_stride_with_dilation is None: 74 | # each element in the tuple indicates if we should replace 75 | # the 2x2 stride with a dilated convolution instead 76 | replace_stride_with_dilation = [False, False, False] 77 | if len(replace_stride_with_dilation) != 3: 78 | raise ValueError("replace_stride_with_dilation should be None " 79 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 80 | self.groups = groups 81 | self.base_width = width_per_group 82 | self.stem = stem() 83 | self.layer1 = self._make_layer(block, 64, layers[0]) 84 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 85 | dilate=replace_stride_with_dilation[0]) 86 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 87 | dilate=replace_stride_with_dilation[1]) 88 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 89 |
dilate=replace_stride_with_dilation[2]) 90 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 91 | if num_classes is not None: 92 | self.fc = nn.Linear(512 * block.expansion, num_classes) 93 | else: 94 | self.fc = None 95 | 96 | for m in self.modules(): 97 | if isinstance(m, nn.Conv2d): 98 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 99 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 100 | nn.init.constant_(m.weight, 1) 101 | nn.init.constant_(m.bias, 0) 102 | 103 | # Zero-initialize the last BN in each residual branch, 104 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 105 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 106 | if zero_init_residual: 107 | for m in self.modules(): 108 | if isinstance(m, Bottleneck): 109 | nn.init.constant_(m.bn3.weight, 0) 110 | elif isinstance(m, BasicBlock): 111 | nn.init.constant_(m.bn2.weight, 0) 112 | 113 | self.representation_dim = 512 * block.expansion 114 | 115 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 116 | norm_layer = self._norm_layer 117 | downsample = None 118 | previous_dilation = self.dilation 119 | if dilate: 120 | self.dilation *= stride 121 | stride = 1 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | conv1x1(self.inplanes, planes * block.expansion, stride), 125 | norm_layer(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 130 | self.base_width, previous_dilation, norm_layer)) 131 | self.inplanes = planes * block.expansion 132 | for _ in range(1, blocks): 133 | layers.append(block(self.inplanes, planes, groups=self.groups, 134 | base_width=self.base_width, dilation=self.dilation, 135 | norm_layer=norm_layer)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def _forward_impl(self, x): 140 | # See note [TorchScript super()] 141 | x = self.stem(x) 142 | 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | 148 | x = self.avgpool(x) 149 | x = torch.flatten(x, 1) 150 | if self.fc: 151 | x = self.fc(x) 152 | 153 | return x 154 | 155 | def forward(self, x): 156 | return self._forward_impl(x) 157 | 158 | def ResNet10(**kwargs): 159 | return ResNet(BasicBlock, [1,1,1,1], **kwargs) 160 | 161 | def ResNet18(**kwargs): 162 | return ResNet(BasicBlock, [2,2,2,2], **kwargs) 163 | 164 | def ResNet34(**kwargs): 165 | return ResNet(BasicBlock, [3,4,6,3], **kwargs) 166 | 167 | def ResNet50(**kwargs): 168 | return ResNet(Bottleneck, [3,4,6,3], **kwargs) 169 | 170 | def ResNet101(**kwargs): 171 | return ResNet(Bottleneck, [3,4,23,3], **kwargs) 172 | 173 | def ResNet152(**kwargs): 174 | return ResNet(Bottleneck, [3,8,36,3], **kwargs) 175 | -------------------------------------------------------------------------------- /sas-pip/sas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjoshi804/sas-data-efficient-contrastive-learning/5e5741884c2e5dcdea65467989b1ab0720430eca/sas-pip/sas/__init__.py -------------------------------------------------------------------------------- /sas-pip/sas/approx_latent_classes.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import List 3 | 4 | import clip 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 
8 | import torch.optim as optim 9 | from fast_pytorch_kmeans import KMeans 10 | from tqdm import tqdm 11 | 12 | def clip_approx( 13 | img_trainset: torch.utils.data.Dataset, 14 | labeled_example_indices: List[int], 15 | labeled_examples_labels: np.array, 16 | num_classes: int, 17 | device: torch.device, 18 | batch_size: int = 512, 19 | verbose: bool = False, 20 | ): 21 | Z = encode_using_clip( 22 | img_trainset=img_trainset, 23 | device=device, 24 | batch_size=batch_size, 25 | verbose=verbose 26 | ) 27 | clf = train_linear_classifier( 28 | X=Z[labeled_example_indices], 29 | y=torch.tensor(labeled_examples_labels), 30 | representation_dim=len(Z[0]), 31 | num_classes=num_classes, 32 | device=device, 33 | verbose=verbose 34 | ) 35 | preds = [] 36 | for start_idx in range(0, len(Z), batch_size): 37 | preds.append(torch.argmax(clf(Z[start_idx:start_idx + batch_size]).detach(), dim=1).cpu()) 38 | preds = torch.cat(preds).numpy() 39 | 40 | return partition_from_preds(preds) 41 | 42 | def clip_0shot_approx( 43 | img_trainset: torch.utils.data.Dataset, 44 | class_names: List[str], 45 | device: torch.device, 46 | verbose: bool = False, 47 | ): 48 | model, preprocess = clip.load("ViT-B/32", device=device) 49 | img_trainset = deepcopy(img_trainset) 50 | img_trainset.transform = preprocess 51 | 52 | zeroshot_weights = zeroshot_classifier( 53 | class_names=class_names, 54 | device=device, 55 | verbose=verbose 56 | ) 57 | logits = [] 58 | loader = torch.utils.data.DataLoader(img_trainset, batch_size=32, num_workers=2) 59 | with torch.no_grad(): 60 | for input in tqdm(loader, "0-shot classification using provided text names for classes", disable=not verbose): 61 | # predict 62 | image_features = model.encode_image(input[0].to(device=device)) 63 | image_features /= image_features.norm(dim=-1, keepdim=True) 64 | logits.append(100. 
* image_features @ zeroshot_weights) 65 | 66 | preds = [] 67 | for logit in logits: 68 | preds.append(logit.topk(1, 1, True, True)[1].t()[0]) 69 | preds = torch.cat(preds).cpu().numpy() 70 | return partition_from_preds(preds) 71 | 72 | def kmeans_approx( 73 | trainset: torch.utils.data.Dataset, 74 | proxy_model: nn.Module, 75 | num_classes: int, 76 | device: torch.device, 77 | verbose: bool = False 78 | ): 79 | proxy_model.eval() 80 | Z = [] 81 | with torch.no_grad(): 82 | loader = torch.utils.data.DataLoader(trainset, batch_size=32, num_workers=2) 83 | for input in tqdm(loader, "Encoding data using proxy model provided", disable=not verbose): 84 | Z.append(proxy_model(input[0].to(device))) 85 | Z = torch.cat(Z, dim=0).to("cpu") 86 | 87 | if verbose: 88 | print(f"KMeans: clustering into {num_classes} clusters.") 89 | 90 | kmeans = KMeans(n_clusters=num_classes, mode='euclidean', verbose=int(verbose), max_iter=1000) 91 | preds = kmeans.fit_predict(Z).cpu().numpy() 92 | return partition_from_preds(preds) 93 | 94 | def encode_using_clip( 95 | img_trainset: torch.utils.data.Dataset, 96 | device: torch.device, 97 | batch_size=512, 98 | verbose: bool = False, 99 | ): 100 | model, preprocess = clip.load("ViT-B/32", device=device) 101 | img_trainset = deepcopy(img_trainset) 102 | img_trainset.transform = preprocess 103 | 104 | loader = torch.utils.data.DataLoader(img_trainset, batch_size=batch_size, num_workers=8) 105 | Z = [] 106 | with torch.no_grad(): 107 | for input in tqdm(loader, desc="Encoding images using CLIP", disable=not verbose): 108 | Z.append(model.encode_image(input[0].to(device))) 109 | Z = torch.cat(Z, dim=0).to(torch.float32) 110 | return Z 111 | 112 | def partition_from_preds(preds): 113 | partition = {} 114 | for i, pred in enumerate(preds): 115 | if pred not in partition: 116 | partition[pred] = [] 117 | partition[pred].append(i) 118 | return partition 119 | 120 | def train_linear_classifier( 121 | X: torch.tensor, 122 | y: torch.tensor, 123 | representation_dim: int, 124 | num_classes: int, 125 | device: torch.device, 126 | reg_weight: float = 1e-3, 127 | n_lbfgs_steps: int = 500, 128 | verbose=False, 129 | ): 130 | if verbose: 131 | print('\nL2 Regularization weight: %g' % reg_weight) 132 | 133 | criterion = nn.CrossEntropyLoss() 134 | X_gpu = X.to(device) 135 | y_gpu = y.to(device) 136 | 137 | # Should be reset after each epoch for a completely independent evaluation 138 | clf = nn.Linear(representation_dim, num_classes).to(device) 139 | clf_optimizer = optim.LBFGS(clf.parameters()) 140 | clf.train() 141 | 142 | for _ in tqdm(range(n_lbfgs_steps), desc="Training linear classifier using fraction of labels", disable=not verbose): 143 | def closure(): 144 | clf_optimizer.zero_grad() 145 | raw_scores = clf(X_gpu) 146 | loss = criterion(raw_scores, y_gpu) 147 | loss += reg_weight * clf.weight.pow(2).sum() 148 | loss.backward() 149 | return loss 150 | clf_optimizer.step(closure) 151 | return clf 152 | 153 | def zeroshot_classifier( 154 | class_names: List[str], 155 | device: torch.device, 156 | verbose: bool = False 157 | ): 158 | templates = [ 159 | 'itap of the {}.', 160 | 'a bad photo of the {}', 161 | 'a origami {}.', 162 | 'a photo of the large {}.', 163 | 'a {} in a video game.', 164 | 'art of the {}.', 165 | 'a photo of the small {}.', 166 | ] 167 | 168 | model, _ = clip.load("ViT-B/32") 169 | model = model.to(device) 170 | 171 | with torch.no_grad(): 172 | zeroshot_weights = [] 173 | for classname in tqdm(class_names, desc="Creating zero shot classifier", disable=not verbose): 174 | texts =
[template.format(classname) for template in templates] #format with class 175 | texts = clip.tokenize(texts).to(device) #tokenize 176 | class_embeddings = model.encode_text(texts) #embed with text encoder 177 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 178 | class_embedding = class_embeddings.mean(dim=0) 179 | class_embedding /= class_embedding.norm() 180 | zeroshot_weights.append(class_embedding) 181 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1) 182 | return zeroshot_weights -------------------------------------------------------------------------------- /sas-pip/sas/submodular_maximization.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | 3 | from tqdm import tqdm 4 | 5 | def _heappush_max(heap, item): 6 | heap.append(item) 7 | heapq._siftdown_max(heap, 0, len(heap)-1) 8 | 9 | def _heappop_max(heap): 10 | """Maxheap version of a heappop.""" 11 | lastelt = heap.pop() # raises appropriate IndexError if heap is empty 12 | if heap: 13 | returnitem = heap[0] 14 | heap[0] = lastelt 15 | heapq._siftup_max(heap, 0) 16 | return returnitem 17 | return lastelt 18 | 19 | def lazy_greedy(F, V, B, verbose=False): 20 | """ 21 | Args 22 | - F: Submodular Objective 23 | - V: list of indices of columns of Similarity Matrix 24 | - B: Budget of subset (int) 25 | """ 26 | sset = [] 27 | 28 | order = [] 29 | heapq._heapify_max(order) 30 | [_heappush_max(order, (F.inc(sset, index), index)) for index in V] 31 | 32 | if verbose: 33 | print("Starting lazy greedy selection") 34 | 35 | with tqdm(total=B, disable=not verbose) as pbar: 36 | while order and len(sset) < B: 37 | el = _heappop_max(order) 38 | improv = F.inc(sset, el[1]) 39 | 40 | #if improv >= 0: 41 | if not order: 42 | sset.append(el[1]) 43 | F.add(el[1]) 44 | pbar.update(1) 45 | else: 46 | top = _heappop_max(order) 47 | if improv >= top[0]: 48 | sset.append(el[1]) 49 | F.add(el[1]) 50 | pbar.update(1) 51 | else: 52 | _heappush_max(order, (improv, el[1])) 53 | _heappush_max(order, top) 54 | return sset -------------------------------------------------------------------------------- /sas-pip/sas/subset_dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Dict, List, Optional 3 | import math 4 | import pickle 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import Dataset 11 | 12 | from sas.submodular_maximization import lazy_greedy 13 | from tqdm import tqdm 14 | 15 | # Efficient alternate implementation of np.block 16 | def efficient_block(mat_list:List[List[np.ndarray]]): 17 | if type(mat_list[0]) is not list: 18 | mat_list = [mat_list] 19 | 20 | x_size = 0 21 | y_size = 0 22 | 23 | for i in mat_list[0]: 24 | x_size += i.shape[1] 25 | 26 | for j in range(len(mat_list)): 27 | y_size += mat_list[j][0].shape[0] 28 | 29 | output_data = np.zeros((y_size, x_size)) 30 | 31 | x_cursor = 0 32 | y_cursor = 0 33 | 34 | for mat_row in mat_list: 35 | y_offset = 0 36 | 37 | for matrix_ in mat_row: 38 | shape_ = matrix_.shape 39 | output_data[y_cursor: y_cursor + shape_[0], x_cursor: x_cursor + shape_[1]] = matrix_ 40 | x_cursor += shape_[1] 41 | y_offset = shape_[0] 42 | 43 | y_cursor += y_offset 44 | x_cursor = 0 45 | 46 | return output_data 47 | 48 | """ 49 | Base Subset Dataset (Abstract Base Class) 50 | """ 51 | class BaseSubsetDataset(ABC, Dataset): 52 | def __init__( 53 | self, 54 | dataset: Dataset, 55 | subset_fraction: 
float, 56 | verbose: bool = False 57 | ): 58 | """ 59 | :param dataset: Original Dataset 60 | :type dataset: Dataset 61 | :param subset_fraction: Fractional size of subset 62 | :type subset_fraction: float 63 | :param verbose: verbose 64 | :type verbose: boolean 65 | """ 66 | self.dataset = dataset 67 | self.subset_fraction = subset_fraction 68 | self.len_dataset = len(self.dataset) 69 | self.subset_size = int(self.len_dataset * self.subset_fraction) 70 | self.subset_indices = None 71 | self.verbose = verbose 72 | 73 | def initialization_complete(self): 74 | if self.verbose: 75 | print(f"Subset Size: {self.subset_size}") 76 | print(f"Discarded {self.len_dataset - self.subset_size} examples") 77 | 78 | def __len__(self): 79 | return self.subset_size 80 | 81 | def __getitem__(self, index): 82 | # Get the index for the corresponding item in the original dataset 83 | original_index = self.subset_indices[index] 84 | 85 | # Get the item from the original dataset at the corresponding index 86 | original_item = self.dataset[original_index] 87 | 88 | return original_item 89 | 90 | def save_to_file(self, filename): 91 | with open(filename, "wb") as f: 92 | pickle.dump(self.subset_indices, f) 93 | 94 | """ 95 | Random Subset 96 | """ 97 | class RandomSubsetDataset(BaseSubsetDataset): 98 | def __init__( 99 | self, 100 | dataset: Dataset, 101 | subset_fraction: float, 102 | partition: Optional[Dict[int, List[int]]] = None, 103 | verbose: bool = False 104 | ): 105 | """ 106 | :param dataset: Original Dataset 107 | :type dataset: Dataset 108 | :param subset_fraction: Fractional size of subset 109 | :type subset_fraction: float 110 | :param verbose: verbose 111 | :type verbose: boolean 112 | """ 113 | super().__init__( 114 | dataset=dataset, 115 | subset_fraction=subset_fraction, 116 | verbose=verbose 117 | ) 118 | 119 | self.subset_indices = [] 120 | if partition is not None: 121 | if self.verbose: 122 | print("Partition provided => returning balanced random subset from all latent classes") 123 | self.subset_indices = RandomSubsetDataset.get_random_balanced_indices(partition, subset_fraction) 124 | else: 125 | if self.verbose: 126 | print("No partition => random subset from full data") 127 | self.subset_indices = random.sample(range(self.len_dataset), self.subset_size) 128 | self.initialization_complete() 129 | 130 | def __len__(self): 131 | return self.subset_size 132 | 133 | def __getitem__(self, index): 134 | # Get the index for the corresponding item in the original dataset 135 | original_index = self.subset_indices[index] 136 | 137 | # Get the item from the original dataset at the corresponding index 138 | original_item = self.dataset[original_index] 139 | 140 | return original_item 141 | 142 | @staticmethod 143 | def get_random_balanced_indices(partition: Dict[int, List[int]], subset_fraction: float): 144 | """ 145 | Randomly selects a subset of fractional size = 'subset_fraction' from each latent class in partition. 146 | The subset selected from each list is the same fraction of the whole. 147 | 148 | Parameters: 149 | - partition: Dict[int, List[int]] 150 | - subset_fraction: float 151 | 152 | Returns: 153 | - selected_subset: List containing the selected subset. 
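Example (illustrative): with partition = {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]} and subset_fraction = 0.5, two indices are sampled from each latent class, giving e.g. [1, 3, 4, 6]; the exact indices depend on the random state.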
154 | """ 155 | 156 | def random_subset_with_fixed_size(original_list, subset_size): 157 | subset_size = min(subset_size, len(original_list)) 158 | return random.sample(original_list, subset_size) 159 | 160 | selected_subset = [] 161 | 162 | for key in partition.keys(): 163 | subset_size = int(len(partition[key]) * subset_fraction) 164 | subset = random_subset_with_fixed_size(partition[key], subset_size) 165 | selected_subset.extend(subset) 166 | 167 | return selected_subset 168 | 169 | """ 170 | Custom Subset 171 | """ 172 | class CustomSubsetDataset(BaseSubsetDataset): 173 | def __init__( 174 | self, 175 | dataset: Dataset, 176 | subset_indices: List[int], 177 | verbose: bool = False, 178 | ): 179 | """ 180 | :param dataset: Original Dataset 181 | :type dataset: Dataset 182 | :param subset_fraction: Fractional size of subset 183 | :type subset_fraction: float 184 | :param subset_indices: Indices of custom subset 185 | :type subset_indices: List[int] 186 | :param verbose: verbose 187 | :type verbose: boolean 188 | """ 189 | super().__init__( 190 | dataset=dataset, 191 | subset_fraction=1.0, 192 | verbose=verbose 193 | ) 194 | self.subset_size = len(subset_indices) 195 | self.subset_fraction = self.subset_size / len(dataset) 196 | self.subset_indices = subset_indices 197 | self.initialization_complete() 198 | 199 | """ 200 | Subsets that maximize Augmentation Similarity Subset Dataset 201 | """ 202 | class SubsetSelectionObjective: 203 | def __init__(self, distance, threshold=0, verbose=False): 204 | ''' 205 | :param distance: (n, n) matrix specifying pairwise augmentation distance 206 | :type distance: np.array 207 | :param threshold: minimum cosine similarity to consider to be significant (default=0) 208 | :type threshold: float 209 | ''' 210 | self.distance = distance 211 | self.threshold = threshold 212 | self.verbose = verbose 213 | if self.verbose: 214 | print("Masking pairwise distance matrix") 215 | for i in range(len(self.distance)): 216 | self.distance[i] *= (self.distance[i] >= self.threshold) 217 | 218 | def inc(self, sset, i): 219 | return np.sum(self.distance[i]) - np.sum(self.distance[np.ix_(sset, [i])]) 220 | 221 | def add(self, i): 222 | self.distance[:][i] = 0 223 | return 224 | 225 | class SASSubsetDataset(BaseSubsetDataset): 226 | def __init__( 227 | self, 228 | dataset: Dataset, 229 | subset_fraction: float, 230 | num_downstream_classes: int, 231 | device: torch.device, 232 | approx_latent_class_partition: Dict[int, int], 233 | proxy_model: Optional[nn.Module] = None, 234 | augmentation_distance: Optional[Dict[int, np.array]] = None, 235 | num_runs=1, 236 | pairwise_distance_block_size: int = 1024, 237 | threshold: float = 0.0, 238 | verbose: bool = False 239 | ): 240 | """ 241 | dataset: Dataset 242 | Original dataset 243 | 244 | subset_fraction: float 245 | Fractional size of subset. 246 | 247 | num_downstream_classes: int 248 | Number of downstream classes (can be an estimate). 249 | 250 | proxy_model: nn.Module 251 | Proxy model to calculate the augmentation distance (and kmeans clustering if the avoid clip option is chosen). 252 | 253 | augmentation_distance: Dict[int, np.array] 254 | Pass a precomputed dictionary containing augmentation distance for each latent class. 255 | 256 | num_augmentations: int 257 | Number of augmentations to consider while approximating the augmentation distance. 258 | 259 | pairwise_distance_block_size: int 260 | Block size for calculating pairwise distance. 
This is just to optimize GPU usage while calculating pairwise distance and will not affect the subset created in any way. 261 | 262 | verbose: boolean 263 | Verbosity of the output. 264 | """ 265 | super().__init__( 266 | dataset=dataset, 267 | subset_fraction=subset_fraction, 268 | verbose=verbose 269 | ) 270 | self.device = device 271 | self.num_downstream_classes = num_downstream_classes 272 | self.proxy_model = proxy_model 273 | self.partition = approx_latent_class_partition 274 | self.augmentation_distance = augmentation_distance 275 | self.num_runs = num_runs 276 | self.pairwise_distance_block_size = pairwise_distance_block_size 277 | 278 | if self.augmentation_distance is None: 279 | self.augmentation_distance = self.approximate_augmentation_distance() 280 | 281 | 282 | class_wise_idx = {} 283 | for latent_class in tqdm(self.partition.keys(), desc="Subset Selection", disable=not verbose): 284 | F = SubsetSelectionObjective(self.augmentation_distance[latent_class].copy(), threshold=threshold, verbose=self.verbose) 285 | class_wise_idx[latent_class] = lazy_greedy(F, range(len(self.augmentation_distance[latent_class])), len(self.augmentation_distance[latent_class]), verbose=self.verbose) 286 | class_wise_idx[latent_class] = [self.partition[latent_class][i] for i in class_wise_idx[latent_class]] 287 | 288 | 289 | self.subset_indices = [] 290 | for latent_class in class_wise_idx.keys(): 291 | l = len(class_wise_idx[latent_class]) 292 | self.subset_indices.extend(class_wise_idx[latent_class][:int(self.subset_fraction * l)]) 293 | 294 | self.initialization_complete() 295 | 296 | 297 | def approximate_augmentation_distance(self): 298 | self.proxy_model = self.proxy_model.to(self.device) 299 | 300 | # Initialize augmentation distance with all 0s 301 | augmentation_distance = {} 302 | Z = self.encode_trainset() 303 | for latent_class in tqdm(list(self.partition.keys()), desc="Computing augmentation distance", disable=not(self.verbose)): 304 | Z_partition = Z[self.partition[latent_class]] 305 | pairwise_distance = SASSubsetDataset.pairwise_distance(Z_partition, Z_partition, verbose=self.verbose) 306 | augmentation_distance[latent_class] = pairwise_distance.copy() 307 | return augmentation_distance 308 | 309 | def encode_trainset(self): 310 | trainloader = torch.utils.data.DataLoader(self.dataset, batch_size=self.pairwise_distance_block_size, shuffle=False, num_workers=2, pin_memory=True) 311 | with torch.no_grad(): 312 | Z = [] 313 | for input in tqdm(trainloader, desc="Encoding trainset", disable=not(self.verbose)): 314 | Z.append(self.proxy_model(input[0].to(self.device))) 315 | return torch.cat(Z, dim=0) 316 | 317 | def encode_augmented_trainset(self, num_positives=1): 318 | trainloader = torch.utils.data.DataLoader(self.dataset, batch_size=self.pairwise_distance_block_size, shuffle=False, num_workers=2, pin_memory=True) 319 | with torch.no_grad(): 320 | Z = [] 321 | for _ in range(num_positives): 322 | Z.append([]) 323 | for X in tqdm(trainloader, desc="Encoding augmented trainset", disable=not(self.verbose)): 324 | for j in range(num_positives): 325 | Z[j].append(self.proxy_model(X[j].to(self.device))) 326 | for i in range(num_positives): 327 | Z[i] = torch.cat(Z[i], dim=0) 328 | Z = torch.cat(Z, dim=0) 329 | return Z 330 | 331 | @staticmethod 332 | def pairwise_distance(Z1: torch.tensor, Z2: torch.tensor, block_size: int = 1024, verbose=False): 333 | similarity_matrices = [] 334 | for i in tqdm(range(Z1.shape[0] // block_size + 1),
desc="Computing pairwise distances", disable=not(verbose)): 335 | similarity_matrices_i = [] 336 | e = Z1[i*block_size:(i+1)*block_size] 337 | for j in range(Z2.shape[0] // block_size + 1): 338 | e_t = Z2[j*block_size:(j+1)*block_size].t() 339 | similarity_matrices_i.append( 340 | np.array( 341 | torch.cosine_similarity(e[:, :, None], e_t[None, :, :]).detach().cpu() 342 | ) 343 | ) 344 | similarity_matrices.append(similarity_matrices_i) 345 | similarity_matrix = efficient_block(similarity_matrices) 346 | 347 | return similarity_matrix 348 | -------------------------------------------------------------------------------- /sas-pip/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | VERSION = '1.0' 4 | DESCRIPTION = 'A python package implementing the data-efficient contrastive learning proposed in https://arxiv.org/abs/2302.09195 by S. Joshi and B. Mirzasoleiman' 5 | 6 | setup( 7 | name="sas", 8 | version=VERSION, 9 | author="Siddharth Joshi", 10 | author_email="sjoshi804@cs.ucla.edu", 11 | description=DESCRIPTION, 12 | long_description_content_type="text/markdown", 13 | long_description=DESCRIPTION, 14 | packages=find_packages(), 15 | install_requires=['torch', 'torchvision', 'numpy', 'fast-pytorch-kmeans'], 16 | ) -------------------------------------------------------------------------------- /simclr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | from datetime import datetime 5 | import random 6 | 7 | import numpy as np 8 | import sas.subset_dataset 9 | import torch 10 | import torch.multiprocessing as mp 11 | import torch.optim as optim 12 | from torch.distributed import destroy_process_group, init_process_group 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.utils.data.distributed import DistributedSampler 15 | import wandb 16 | 17 | from configs import SupportedDatasets, get_datasets 18 | from projection_heads.critic import LinearCritic 19 | from resnet import * 20 | from trainer import Trainer 21 | from util import Random 22 | 23 | def main(rank: int, world_size: int, args): 24 | 25 | # Determine Device 26 | device = rank 27 | if args.distributed: 28 | device = args.device_ids[rank] 29 | torch.cuda.set_device(args.device_ids[rank]) 30 | args.lr *= world_size 31 | 32 | # WandB Logging 33 | if not args.distributed or rank == 0: 34 | wandb.init( 35 | project="data-efficient-contrastive-learning", 36 | config=args 37 | ) 38 | 39 | if args.distributed: 40 | args.batch_size = int(args.batch_size / world_size) 41 | 42 | # Set all seeds 43 | torch.manual_seed(args.seed) 44 | np.random.seed(args.seed) 45 | Random(args.seed) 46 | 47 | print('==> Preparing data..') 48 | datasets = get_datasets(args.dataset) 49 | 50 | ############################################################## 51 | # Load Subset Indices 52 | ############################################################## 53 | 54 | if args.random_subset: 55 | trainset = sas.subset_dataset.RandomSubsetDataset( 56 | dataset=datasets.trainset, 57 | subset_fraction=args.subset_fraction 58 | ) 59 | elif args.subset_indices != "": 60 | with open(args.subset_indices, "rb") as f: 61 | subset_indices = pickle.load(f) 62 | trainset = sas.subset_dataset.CustomSubsetDataset( 63 | dataset=datasets.trainset, 64 | subset_indices=subset_indices 65 | ) 66 | else: 67 | trainset = datasets.trainset 68 | print("subset_size:", len(trainset)) 69 | 
70 | # Model 71 | print('==> Building model..') 72 | 73 | ############################################################## 74 | # Encoder 75 | ############################################################## 76 | 77 | if args.arch == 'resnet10': 78 | net = ResNet10(stem=datasets.stem) 79 | elif args.arch == 'resnet18': 80 | net = ResNet18(stem=datasets.stem) 81 | elif args.arch == 'resnet34': 82 | net = ResNet34(stem=datasets.stem) 83 | elif args.arch == 'resnet50': 84 | net = ResNet50(stem=datasets.stem) 85 | else: 86 | raise ValueError("Bad architecture specification") 87 | 88 | ############################################################## 89 | # Critic 90 | ############################################################## 91 | 92 | critic = LinearCritic(net.representation_dim, temperature=args.temperature).to(device) 93 | 94 | # DCL Setup 95 | optimizer = optim.Adam(list(net.parameters()) + list(critic.parameters()), lr=args.lr, weight_decay=1e-6) 96 | if args.dataset == SupportedDatasets.TINY_IMAGENET.value: 97 | optimizer = optim.Adam(list(net.parameters()) + list(critic.parameters()), lr=2 * args.lr, weight_decay=1e-6) 98 | 99 | 100 | ############################################################## 101 | # Data Loaders 102 | ############################################################## 103 | 104 | trainloader = torch.utils.data.DataLoader( 105 | dataset=trainset, 106 | batch_size=args.batch_size, 107 | shuffle=(not args.distributed), 108 | sampler=DistributedSampler(trainset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True) if args.distributed else None, 109 | num_workers=4, 110 | pin_memory=True, 111 | ) 112 | 113 | clftrainloader = torch.utils.data.DataLoader( 114 | dataset=datasets.clftrainset, 115 | batch_size=args.batch_size, 116 | shuffle=False, 117 | num_workers=4, 118 | pin_memory=True 119 | ) 120 | 121 | testloader = torch.utils.data.DataLoader( 122 | dataset=datasets.testset, 123 | batch_size=args.batch_size, 124 | shuffle=False, 125 | num_workers=4, 126 | pin_memory=True, 127 | ) 128 | 129 | ############################################################## 130 | # Main Loop (Train, Test) 131 | ############################################################## 132 | 133 | # Date Time String 134 | DT_STRING = "".join(str(datetime.now()).split()) 135 | 136 | if args.distributed: 137 | ddp_setup(rank, world_size, str(args.port)) 138 | 139 | net = net.to(device) 140 | critic = critic.to(device) 141 | if args.distributed: 142 | net = DDP(net, device_ids=[device]) 143 | 144 | trainer = Trainer( 145 | device=device, 146 | distributed=args.distributed, 147 | rank=rank if args.distributed else 0, 148 | world_size=world_size, 149 | net=net, 150 | critic=critic, 151 | trainloader=trainloader, 152 | clftrainloader=clftrainloader, 153 | testloader=testloader, 154 | num_classes=datasets.num_classes, 155 | optimizer=optimizer, 156 | ) 157 | 158 | for epoch in range(0, args.num_epochs): 159 | print(f"step: {epoch}") 160 | 161 | train_loss = trainer.train() 162 | print(f"train_loss: {train_loss}") 163 | if not args.distributed or rank == 0: 164 | wandb.log( 165 | data={"train": { 166 | "loss": train_loss, 167 | }}, 168 | step=epoch 169 | ) 170 | 171 | if (args.test_freq > 0) and (not args.distributed or rank == 0) and ((epoch + 1) % args.test_freq == 0): 172 | test_acc = trainer.test() 173 | print(f"test_acc: {test_acc}") 174 | wandb.log( 175 | data={"test": { 176 | "acc": test_acc, 177 | }}, 178 | step=epoch 179 | ) 180 | 181 | # Checkpoint Model 182 | if (args.checkpoint_freq > 0) and 
((not args.distributed or rank == 0) and (epoch + 1) % args.checkpoint_freq == 0): 183 | trainer.save_checkpoint(prefix=f"{DT_STRING}-{args.dataset}-{args.arch}-{epoch}") 184 | 185 | if not args.distributed or rank == 0: 186 | print(f"best_test_acc: {trainer.best_acc}") 187 | wandb.log( 188 | data={"test": { 189 | "best_acc": trainer.best_acc, 190 | }} 191 | ) 192 | wandb.finish(quiet=True) 193 | 194 | if args.distributed: 195 | destroy_process_group() 196 | 197 | 198 | ############################################################## 199 | # Distributed Training Setup 200 | ############################################################## 201 | def ddp_setup(rank: int, world_size: int, port: str): 202 | os.environ["MASTER_ADDR"] = "localhost" 203 | os.environ["MASTER_PORT"] = port 204 | init_process_group(backend="nccl", rank=rank, world_size=world_size) 205 | 206 | if __name__ == "__main__": 207 | parser = argparse.ArgumentParser(description='PyTorch Contrastive Learning.') 208 | parser.add_argument('--temperature', type=float, default=0.5, help='InfoNCE temperature') 209 | parser.add_argument("--batch-size", type=int, default=512, help='Training batch size') 210 | parser.add_argument("--lr", type=float, default=1e-3, help='learning rate') 211 | parser.add_argument("--num-epochs", type=int, default=400, help='Number of training epochs') 212 | parser.add_argument("--arch", type=str, default='resnet18', help='Encoder architecture', 213 | choices=['resnet10', 'resnet18', 'resnet34', 'resnet50']) 214 | parser.add_argument("--test-freq", type=int, default=10, help='Frequency to fit a linear clf with L-BFGS for testing' 215 | 'Not appropriate for large datasets. Set 0 to avoid ' 216 | 'classifier only training here.') 217 | parser.add_argument("--checkpoint-freq", type=int, default=400, help="How often to checkpoint model") 218 | parser.add_argument('--dataset', type=str, default=str(SupportedDatasets.CIFAR100.value), help='dataset', 219 | choices=[x.value for x in SupportedDatasets]) 220 | parser.add_argument('--subset-indices', type=str, default="", help="Path to subset indices") 221 | parser.add_argument('--random-subset', action="store_true", help="Random subset") 222 | parser.add_argument('--subset-fraction', type=float, help="Size of Subset as fraction (only needed for random subset)") 223 | parser.add_argument('--device', type=int, default=-1, help="GPU number to use") 224 | parser.add_argument("--device-ids", nargs = "+", default = None, help = "Specify device ids if using multiple gpus") 225 | parser.add_argument('--port', type=int, default=random.randint(49152, 65535), help="free port to use") 226 | parser.add_argument('--seed', type=int, default=0, help="Seed for randomness") 227 | 228 | # Parse arguments 229 | args = parser.parse_args() 230 | 231 | # Arguments check and initialize global variables 232 | device = "cpu" 233 | device_ids = None 234 | distributed = False 235 | if torch.cuda.is_available(): 236 | if args.device_ids is None: 237 | if args.device >= 0: 238 | device = args.device 239 | else: 240 | device = 0 241 | else: 242 | distributed = True 243 | device_ids = [int(id) for id in args.device_ids] 244 | args.device = device 245 | args.device_ids = device_ids 246 | args.distributed = distributed 247 | if distributed: 248 | mp.spawn( 249 | fn=main, 250 | args=(len(device_ids), args), 251 | nprocs=len(device_ids) 252 | ) 253 | else: 254 | main(device, 1, args) -------------------------------------------------------------------------------- /trainer.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from torch import Tensor, nn 4 | import torch 5 | import torch.distributed as dist 6 | from torch.optim import Optimizer 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | 10 | from evaluate.lbfgs import encode_train_set, train_clf, test_clf 11 | from projection_heads.critic import LinearCritic 12 | 13 | class Trainer(): 14 | def __init__( 15 | self, 16 | net: nn.Module, 17 | critic: LinearCritic, 18 | trainloader: DataLoader, 19 | clftrainloader: DataLoader, 20 | testloader: DataLoader, 21 | num_classes: int, 22 | optimizer: Optimizer, 23 | device: torch.device, 24 | distributed: bool, 25 | rank: int = 0, 26 | world_size: int = 1, 27 | lr_scheduler = None, 28 | ): 29 | """ 30 | :param device: Device to run on (GPU) 31 | :param net: encoder network 32 | :param critic: projection head 33 | :param trainset: training data 34 | :param clftrainloader: dataloader for train data (for linear probe) 35 | :param optimizer: Optimizer for the encoder network (net) 36 | :param lr_scheduler: learning rate scheduler 37 | """ 38 | self.device = device 39 | self.rank = rank 40 | self.net = net 41 | self.critic = critic 42 | self.trainloader = trainloader 43 | self.clftrainloader = clftrainloader 44 | self.testloader = testloader 45 | self.num_classes = num_classes 46 | self.encoder_optimizer = optimizer 47 | self.lr_scheduler = lr_scheduler 48 | self.distributed = distributed 49 | self.world_size = world_size 50 | 51 | self.criterion = nn.CrossEntropyLoss() 52 | self.best_acc = 0 53 | self.best_rare_acc = 0 54 | 55 | ######################################### 56 | # Loss Functions # 57 | ######################################### 58 | def un_supcon_loss(self, z: Tensor, num_positive: int): 59 | batch_size = int(len(z) / num_positive) 60 | 61 | if self.distributed: 62 | all_z = [torch.zeros_like(z) for _ in range(self.world_size)] 63 | dist.all_gather(all_z, z) 64 | # Move all tensors to the same device 65 | aug_z = [] 66 | for i in range(num_positive): 67 | aug_z.append([]) 68 | for rank in range(self.world_size): 69 | if rank == self.rank: 70 | aug_z[-1].append(z[i * batch_size: (i+1) * batch_size]) 71 | else: 72 | aug_z[-1].append(all_z[rank][i * batch_size: (i+1) * batch_size]) 73 | z = [torch.cat(aug_z_i, dim=0) for aug_z_i in aug_z] 74 | else: 75 | aug_z = [] 76 | for i in range(num_positive): 77 | aug_z.append(z[i * batch_size : (i + 1) * batch_size]) 78 | z = aug_z 79 | 80 | sim = self.critic(z) 81 | #print(sim) 82 | log_sum_exp_sim = torch.log(torch.sum(torch.exp(sim), dim=1)) 83 | # Positive Pairs Mask 84 | p_targets = torch.cat([torch.tensor(range(int(len(sim) / num_positive)))] * num_positive) 85 | #len(p_targets) 86 | pos_pairs = (p_targets.unsqueeze(1) == p_targets.unsqueeze(0)).to(self.device) 87 | #print(pos_pairs) 88 | inf_mask = (sim != float('-inf')).to(self.device) 89 | pos_pairs = torch.logical_and(pos_pairs, inf_mask) 90 | pos_count = torch.sum(pos_pairs, dim=1) 91 | pos_sims = torch.nansum(sim * pos_pairs, dim=-1) 92 | return torch.mean(-pos_sims / pos_count + log_sum_exp_sim) 93 | 94 | ######################################### 95 | # Train & Test Modules # 96 | ######################################### 97 | def train(self): 98 | self.net.train() 99 | self.critic.train() 100 | 101 | # Training Loop (over batches in epoch) 102 | train_loss = 0 103 | t = tqdm(enumerate(self.trainloader), desc='Loss: **** ', total=len(self.trainloader), 
bar_format='{desc}{bar}{r_bar}') 104 | for batch_idx, inputs in t: 105 | num_positive = len(inputs) 106 | x = torch.cat(inputs, dim=0).to(self.device) 107 | self.encoder_optimizer.zero_grad() 108 | z = self.net(x) 109 | loss = self.un_supcon_loss(z, num_positive) 110 | loss.backward() 111 | 112 | self.encoder_optimizer.step() 113 | train_loss += loss.item() 114 | t.set_description('Loss: %.3f ' % (train_loss / (batch_idx + 1))) 115 | 116 | if self.lr_scheduler is not None: 117 | self.lr_scheduler.step() 118 | print("lr:", self.lr_scheduler.get_last_lr()[0]) 119 | 120 | return train_loss / len(self.trainloader) 121 | 122 | def test(self): 123 | X, y = encode_train_set(self.clftrainloader, self.device, self.net) 124 | representation_dim = self.net.module.representation_dim if self.distributed else self.net.representation_dim 125 | clf = train_clf(X, y, representation_dim, self.num_classes, self.device, reg_weight=1e-5, iter=100) 126 | acc, _ = test_clf(self.testloader, self.device, self.net, clf) 127 | 128 | if acc > self.best_acc: 129 | self.best_acc = acc 130 | 131 | return acc 132 | 133 | def save_checkpoint(self, prefix): 134 | if self.world_size > 1: 135 | torch.save(self.net.module, f"{prefix}-net.pt") 136 | else: 137 | torch.save(self.net, f"{prefix}-net.pt") 138 | torch.save(self.critic, f"{prefix}-critic.pt") -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | def save_np_array(template_file_name: str, array_label: str, np_array: np.array): 5 | with open(f"{template_file_name}-{array_label}.npy", 'wb+') as npyfile: 6 | np.save(npyfile, np_array, allow_pickle=False, fix_imports=False) 7 | 8 | class Singleton (type): 9 | _instances = {} 10 | def __call__(cls, *args, **kwargs): 11 | if cls not in cls._instances: 12 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 13 | return cls._instances[cls] 14 | 15 | class Random(metaclass=Singleton): 16 | def __init__(self, seed=0): 17 | self.random = random 18 | self.random.seed(seed) --------------------------------------------------------------------------------