├── .gitignore ├── Figure_1_Modality_Gap ├── features_clasp.npy ├── features_clasp_random.npy ├── features_clip.npy ├── features_clip_random.npy ├── features_convirt.npy ├── features_convirt_random.npy ├── features_videoclip.npy ├── features_videoclip_random.npy ├── repr_clasp.ipynb ├── repr_clip.ipynb ├── repr_convirt.ipynb ├── repr_videoclip.ipynb └── visualize.ipynb ├── Figure_2_Cone_Effect ├── Figure_2a_random_init_random_data │ ├── coco-extract.ipynb │ └── visualize.ipynb ├── Figure_2a_random_init_real_data │ ├── coco-extract.ipynb │ └── visualize.ipynb ├── Figure_2a_real_features │ └── real_features.ipynb ├── Figure_2b_random_MLP_layerwise │ ├── bias_linear_relu.ipynb │ └── no_bias │ │ └── linear_relu.ipynb └── Figure_2c_scatter_cones_random_init │ ├── MLP │ └── scatter_cones_linear_relu.ipynb │ ├── random_data │ ├── coco-extract.ipynb │ ├── visualizePCA.ipynb │ └── visualizeUMAP.ipynb │ ├── real_data │ ├── coco-extract.ipynb │ ├── visualizePCA.ipynb │ └── visualizeUMAP.ipynb │ └── real_data_ImageNet_pretrained │ ├── ImageNet-Pretrained-Cones.png │ ├── README.md │ ├── coco-extract.ipynb │ └── visualizeUMAP.ipynb ├── Figure_3_Contrastive_Learning ├── 3d_sphere.ipynb ├── Appendix_3d_sphere.ipynb ├── get_gap_stats.ipynb ├── mismatched_simulation.ipynb └── plot_optimization_exp.ipynb ├── LICENSE ├── README.md ├── Table_1_Implications_CLIP_Zero_Shot ├── shifting │ └── shift_features.ipynb ├── simulation │ └── simulation.ipynb └── training │ ├── datasets.py │ ├── train_clip.py │ └── utils.py ├── Table_2_Implications_CLIP_Fairness ├── coco-extract.ipynb └── shift_CLIP_FairFace_Bias.ipynb ├── docs └── figures │ ├── Figure1.png │ ├── Figure2.jpg │ ├── Figure2ab.png │ ├── Figure2c.png │ ├── Figure3.jpg │ ├── Tables.png │ ├── Theorem1.png │ ├── Theorem2.png │ └── Theorem_variance.png ├── environment.yml └── util ├── gap_amend_std.ipynb └── get_arch.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.pkl 6 | 7 | */dummy_val/* 8 | */val/* 9 | */features/* 10 | */fairface-img-margin025-trainval.zip.zip 11 | */fairface_label_val.csv 12 | 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /Figure_1_Modality_Gap/features_clasp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/Figure_1_Modality_Gap/features_clasp.npy -------------------------------------------------------------------------------- /Figure_1_Modality_Gap/features_clasp_random.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/Figure_1_Modality_Gap/features_clasp_random.npy -------------------------------------------------------------------------------- /Figure_1_Modality_Gap/features_clip.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/Figure_1_Modality_Gap/features_clip.npy -------------------------------------------------------------------------------- /Figure_1_Modality_Gap/features_clip_random.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/Figure_1_Modality_Gap/features_clip_random.npy -------------------------------------------------------------------------------- /Figure_1_Modality_Gap/features_convirt.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/Figure_1_Modality_Gap/features_convirt.npy -------------------------------------------------------------------------------- /Figure_1_Modality_Gap/features_convirt_random.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/Figure_1_Modality_Gap/features_convirt_random.npy -------------------------------------------------------------------------------- /Figure_1_Modality_Gap/features_videoclip.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/Figure_1_Modality_Gap/features_videoclip.npy -------------------------------------------------------------------------------- /Figure_1_Modality_Gap/features_videoclip_random.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/Figure_1_Modality_Gap/features_videoclip_random.npy -------------------------------------------------------------------------------- /Figure_1_Modality_Gap/visualize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from matplotlib import pyplot as plt\n", 11 | "plt.rcParams['figure.dpi'] = 300\n", 12 | "plt.rcParams['savefig.dpi'] = 300\n", 13 | "import seaborn as sns\n", 14 | "sns.set_theme()\n", 15 | "sns.set_context(\"talk\")\n", 16 | "\n", 17 | "import sys\n", 18 | "import os\n", 19 | "sys.path.append('ANONYMOUS_ROOTDIR/develop/open-world/')\n", 20 | "from utils import reduce_and_visualize" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "filenames = sorted([filename for filename in os.listdir() if filename.endswith('.npy')])" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "for filename in filenames:\n", 39 | " print(filename)\n", 40 | " image_features, text_features = np.load(filename)\n", 41 | " reduce_and_visualize(image_features, text_features, connection=True)\n", 42 | " input()" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [] 51 | } 52 | ], 53 | "metadata": { 54 | "interpreter": { 55 | "hash": "bf49421d02fb18daac2fe024769d7389ca36bccb970e26253e571efb021ca22f" 56 | }, 57 | "kernelspec": { 58 | "display_name": "Python 3.8.12 ('dalle')", 59 | "language": "python", 60 | "name": "python3" 61 | }, 62 | "language_info": { 63 | "codemirror_mode": { 64 | "name": "ipython", 65 | "version": 3 66 | }, 67 | "file_extension": ".py", 68 | "mimetype": "text/x-python", 69 | "name": "python", 70 | "nbconvert_exporter": "python", 71 | "pygments_lexer": "ipython3", 72 | "version": "3.8.12" 73 | }, 74 | "orig_nbformat": 4 75 | }, 76 | "nbformat": 4, 77 | "nbformat_minor": 2 78 | } 79 | -------------------------------------------------------------------------------- /Figure_2_Cone_Effect/Figure_2a_random_init_random_data/coco-extract.ipynb: 
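[Editor's note] The `reduce_and_visualize` helper imported in `Figure_1_Modality_Gap/visualize.ipynb` above comes from an external `utils` module (`ANONYMOUS_ROOTDIR/develop/open-world/`) that is not part of this repository. A minimal sketch of a comparable paired-embedding plot, assuming each `features_*.npy` file stores a `(2, N, D)` array of matched image and text features (so it unpacks into two arrays, as in the notebook cell), is given below. The function name, the UMAP choice, and the `connection` argument here are illustrative, not the original implementation.

```python
import numpy as np
import umap
import matplotlib.pyplot as plt

def reduce_and_visualize_sketch(image_features, text_features, connection=True):
    # L2-normalize each embedding, as done elsewhere in the repo
    image_features = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
    text_features = text_features / np.linalg.norm(text_features, axis=-1, keepdims=True)

    # Reduce both modalities with one shared 2-D projection
    stacked = np.concatenate([image_features, text_features], axis=0)
    emb_2d = umap.UMAP(n_components=2, random_state=0).fit_transform(stacked)
    n = len(image_features)
    img_2d, txt_2d = emb_2d[:n], emb_2d[n:]

    plt.scatter(img_2d[:, 0], img_2d[:, 1], s=4, label='image')
    plt.scatter(txt_2d[:, 0], txt_2d[:, 1], s=4, label='text')
    if connection:
        # Connect each matched image-text pair to make the modality gap visible
        for (x0, y0), (x1, y1) in zip(img_2d[:200], txt_2d[:200]):
            plt.plot([x0, x1], [y0, y1], color='grey', linewidth=0.2)
    plt.legend()
    plt.show()

# Usage mirroring the notebook cell above
image_features, text_features = np.load('features_clip.npy')
reduce_and_visualize_sketch(image_features, text_features, connection=True)
```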
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "YPHN7PJgKOzb" 7 | }, 8 | "source": [ 9 | "# Image Feature Pair Extract - CLIP, ResNet18. \n", 10 | "conda activate clip\n", 11 | "\n", 12 | "\n", 13 | "clip_image_features_list (118287, 512)\n", 14 | "target_image_features_list (118287, 512)\n", 15 | "clip_image_features_list (5000, 512)\n", 16 | "target_image_features_list (5000, 512)\n", 17 | "\n", 18 | "Feature extraction complete in 6m 16s" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": { 25 | "colab": { 26 | "base_uri": "https://localhost:8080/" 27 | }, 28 | "id": "C1hkDT38hSaP", 29 | "outputId": "70a44964-883d-4fd0-b95a-2c7f2b19aca9" 30 | }, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "Torch version: 1.7.1\n" 37 | ] 38 | } 39 | ], 40 | "source": [ 41 | "import numpy as np\n", 42 | "import torch\n", 43 | "import pickle\n", 44 | "import time\n", 45 | "print(\"Torch version:\", torch.__version__)\n", 46 | "\n", 47 | "assert torch.__version__.split(\".\") >= [\"1\", \"7\", \"1\"], \"PyTorch 1.7.1 or later is required\"\n", 48 | "\n", 49 | "import os\n", 50 | "import matplotlib.pyplot as plt\n", 51 | "from collections import OrderedDict\n", 52 | "import torch\n", 53 | "\n", 54 | "%matplotlib inline\n", 55 | "%config InlineBackend.figure_format = 'retina'" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# Load CLIP" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 2, 68 | "metadata": { 69 | "colab": { 70 | "base_uri": "https://localhost:8080/" 71 | }, 72 | "id": "uLFS29hnhlY4", 73 | "outputId": "11779e1e-8bdd-4167-c18e-d26bdd6b67db" 74 | }, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']" 80 | ] 81 | }, 82 | "execution_count": 2, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "import clip\n", 89 | "\n", 90 | "clip.available_models()" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# ViT-B-32.json\n", 100 | "# copied from https://github.com/mlfoundations/open_clip/blob/91f6cce16b7bee90b3b5d38ca305b5b3b67cc200/src/training/model_configs/ViT-B-32.json\n", 101 | "model_info = {\n", 102 | " \"embed_dim\": 512,\n", 103 | " \"image_resolution\": 224,\n", 104 | " \"vision_layers\": 12,\n", 105 | " \"vision_width\": 768,\n", 106 | " \"vision_patch_size\": 32,\n", 107 | " \"context_length\": 77,\n", 108 | " \"vocab_size\": 49408,\n", 109 | " \"transformer_width\": 512,\n", 110 | " \"transformer_heads\": 8,\n", 111 | " \"transformer_layers\": 12\n", 112 | "} \n", 113 | "from clip.model import CLIP\n", 114 | "model = CLIP(**model_info)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 4, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "from torchvision import transforms\n", 124 | "input_size = model_info['image_resolution']\n", 125 | "preprocess = transforms.Compose([\n", 126 | " transforms.Resize(input_size),\n", 127 | " transforms.CenterCrop(input_size),\n", 128 | " transforms.ToTensor(),\n", 129 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 130 | " ])" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 
5, 136 | "metadata": { 137 | "colab": { 138 | "base_uri": "https://localhost:8080/" 139 | }, 140 | "id": "IBRVTY9lbGm8", 141 | "outputId": "f06fd2fd-6126-475b-87d0-b10aa3b7da49" 142 | }, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "Model parameters: 151,277,313\n", 149 | "Input resolution: 224\n", 150 | "Context length: 77\n", 151 | "Vocab size: 49408\n" 152 | ] 153 | } 154 | ], 155 | "source": [ 156 | "\n", 157 | "model.cuda().eval()\n", 158 | "input_resolution = model.visual.input_resolution\n", 159 | "context_length = model.context_length\n", 160 | "vocab_size = model.vocab_size\n", 161 | "\n", 162 | "print(\"Model parameters:\", f\"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}\")\n", 163 | "print(\"Input resolution:\", input_resolution)\n", 164 | "print(\"Context length:\", context_length)\n", 165 | "print(\"Vocab size:\", vocab_size)\n", 166 | "\n", 167 | "clip_model = model" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 6, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "data": { 177 | "text/plain": [ 178 | "torchvision.transforms.transforms.Compose" 179 | ] 180 | }, 181 | "execution_count": 6, 182 | "metadata": {}, 183 | "output_type": "execute_result" 184 | } 185 | ], 186 | "source": [ 187 | "type(preprocess)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "# Load Data" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 7, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "loading annotations into memory...\n", 207 | "Done (t=0.10s)\n", 208 | "creating index...\n", 209 | "index created!\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "import torchvision\n", 215 | "from torch.utils.data import DataLoader\n", 216 | "\n", 217 | "def target_transform(caption_list):\n", 218 | " caption = caption_list[0] # only the first caption\n", 219 | " return clip.tokenize(caption)[0]\n", 220 | "\n", 221 | "# coco_train_dataset = torchvision.datasets.CocoCaptions(\n", 222 | "# root = '/home/ubuntu/data/coco/train2017',\n", 223 | "# annFile = '/home/ubuntu/data/coco/annotations/captions_train2017.json',\n", 224 | "# transform=preprocess,\n", 225 | "# target_transform=target_transform,\n", 226 | "# )\n", 227 | "\n", 228 | "coco_val_dataset = torchvision.datasets.CocoCaptions(\n", 229 | " root = '/home/ubuntu/data/coco/val2017',\n", 230 | " annFile = '/home/ubuntu/data/coco/annotations/captions_val2017.json',\n", 231 | " transform=preprocess,\n", 232 | " target_transform=target_transform,\n", 233 | " )" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 8, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "# coco_train_dataloader = DataLoader(coco_train_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)\n", 243 | "coco_val_dataloader = DataLoader(coco_val_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "# ResNet" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 9, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "import torch\n", 260 | "import torch.nn as nn\n", 261 | "import torchvision.models as models\n", 262 | "from torch.autograd import Variable\n", 263 | "\n", 264 | "resnet18 = 
models.resnet18(pretrained=False) # resnet18 = models.resnet18(pretrained=True)\n", 265 | "modules=list(resnet18.children())[:-1]\n", 266 | "resnet18=nn.Sequential(*modules)\n", 267 | "for p in resnet18.parameters():\n", 268 | " p.requires_grad = False\n", 269 | "\n", 270 | "resnet18.cuda().eval()\n", 271 | "target_model = resnet18\n" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": {}, 277 | "source": [ 278 | "# Extractor loop\n" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 10, 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "clip_image_features_list (5000, 512)\n", 291 | "target_image_features_list (5000, 512)\n", 292 | "\n", 293 | "Feature Extraction completed in 0m 45s\n" 294 | ] 295 | } 296 | ], 297 | "source": [ 298 | "since = time.time()\n", 299 | "dataloaders = {\n", 300 | " # 'train': coco_train_dataloader, \n", 301 | " 'val': coco_val_dataloader,\n", 302 | "}\n", 303 | "# Each epoch has a training and validation phase\n", 304 | "for phase in ['val']: # ['train', 'val',]:\n", 305 | "\n", 306 | " clip_model.eval() # Set model to evaluate mode, for extraction\n", 307 | " ##################################\n", 308 | " # Fields to be stored for postprocessing \n", 309 | " ##################################\n", 310 | " clip_image_features_list = []\n", 311 | " clip_text_features_list = []\n", 312 | " target_image_features_list = []\n", 313 | "\n", 314 | " # Iterate over data.\n", 315 | " for inputs, captions in dataloaders[phase]:\n", 316 | " # image_input = inputs.cuda(non_blocking=True)\n", 317 | " # text_input = captions.cuda(non_blocking=True)\n", 318 | "\n", 319 | " batch_size = len(captions)\n", 320 | " image_input = torch.randn((batch_size, 3, 224, 224)).cuda(non_blocking=True)\n", 321 | " text_input = torch.randint(0, 49408, (batch_size, 77)).cuda(non_blocking=True)\n", 322 | "\n", 323 | " \n", 324 | " with torch.set_grad_enabled(False):\n", 325 | " clip_image_features = clip_model.encode_image(image_input).float()\n", 326 | " clip_text_features = clip_model.encode_text(text_input).float()\n", 327 | " target_image_features = target_model(image_input).squeeze() \n", 328 | " ##################################\n", 329 | " # Evaluation book-keeping Field \n", 330 | " ##################################\n", 331 | " clip_image_features_list.append( clip_image_features.cpu().numpy() )\n", 332 | " clip_text_features_list.append( clip_text_features.cpu().numpy() )\n", 333 | " target_image_features_list.append( target_image_features.cpu().numpy() )\n", 334 | "\n", 335 | " ##################################\n", 336 | " # Evaluation book-keeping Field \n", 337 | " ##################################\n", 338 | " clip_image_features_list = np.concatenate( clip_image_features_list, axis=0)\n", 339 | " clip_text_features_list = np.concatenate( clip_text_features_list, axis=0)\n", 340 | " target_image_features_list = np.concatenate( target_image_features_list, axis=0)\n", 341 | " print('clip_image_features_list', clip_image_features_list.shape)\n", 342 | " print('target_image_features_list', target_image_features_list.shape)\n", 343 | "\n", 344 | " dump_result_dict = {\n", 345 | " \"clip_image_features_list\": clip_image_features_list, \n", 346 | " \"clip_text_features_list\" : clip_text_features_list,\n", 347 | " \"target_image_features_list\": target_image_features_list, \n", 348 | " }\n", 349 | " with open(os.path.join('features', 
'feature_dump_{}.pkl'.format(phase) ), \"wb\") as pkl_file:\n", 350 | " pickle.dump(\n", 351 | " dump_result_dict, \n", 352 | " pkl_file, \n", 353 | " )\n", 354 | "\n", 355 | "print()\n", 356 | "\n", 357 | "time_elapsed = time.time() - since\n", 358 | "print('Feature Extraction completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [] 367 | } 368 | ], 369 | "metadata": { 370 | "accelerator": "GPU", 371 | "colab": { 372 | "collapsed_sections": [], 373 | "name": "Interacting with CLIP.ipynb", 374 | "provenance": [] 375 | }, 376 | "kernelspec": { 377 | "display_name": "Python 3", 378 | "name": "python3" 379 | }, 380 | "language_info": { 381 | "codemirror_mode": { 382 | "name": "ipython", 383 | "version": 3 384 | }, 385 | "file_extension": ".py", 386 | "mimetype": "text/x-python", 387 | "name": "python", 388 | "nbconvert_exporter": "python", 389 | "pygments_lexer": "ipython3", 390 | "version": "3.9.7" 391 | }, 392 | "widgets": { 393 | "application/vnd.jupyter.widget-state+json": { 394 | "12e23e2819094ee0a079d4eb77cfc4f9": { 395 | "model_module": "@jupyter-widgets/base", 396 | "model_module_version": "1.2.0", 397 | "model_name": "LayoutModel", 398 | "state": { 399 | "_model_module": "@jupyter-widgets/base", 400 | "_model_module_version": "1.2.0", 401 | "_model_name": "LayoutModel", 402 | "_view_count": null, 403 | "_view_module": "@jupyter-widgets/base", 404 | "_view_module_version": "1.2.0", 405 | "_view_name": "LayoutView", 406 | "align_content": null, 407 | "align_items": null, 408 | "align_self": null, 409 | "border": null, 410 | "bottom": null, 411 | "display": null, 412 | "flex": null, 413 | "flex_flow": null, 414 | "grid_area": null, 415 | "grid_auto_columns": null, 416 | "grid_auto_flow": null, 417 | "grid_auto_rows": null, 418 | "grid_column": null, 419 | "grid_gap": null, 420 | "grid_row": null, 421 | "grid_template_areas": null, 422 | "grid_template_columns": null, 423 | "grid_template_rows": null, 424 | "height": null, 425 | "justify_content": null, 426 | "justify_items": null, 427 | "left": null, 428 | "margin": null, 429 | "max_height": null, 430 | "max_width": null, 431 | "min_height": null, 432 | "min_width": null, 433 | "object_fit": null, 434 | "object_position": null, 435 | "order": null, 436 | "overflow": null, 437 | "overflow_x": null, 438 | "overflow_y": null, 439 | "padding": null, 440 | "right": null, 441 | "top": null, 442 | "visibility": null, 443 | "width": null 444 | } 445 | }, 446 | "1369964d45004b5e95a058910b2a33e6": { 447 | "model_module": "@jupyter-widgets/controls", 448 | "model_module_version": "1.5.0", 449 | "model_name": "HBoxModel", 450 | "state": { 451 | "_dom_classes": [], 452 | "_model_module": "@jupyter-widgets/controls", 453 | "_model_module_version": "1.5.0", 454 | "_model_name": "HBoxModel", 455 | "_view_count": null, 456 | "_view_module": "@jupyter-widgets/controls", 457 | "_view_module_version": "1.5.0", 458 | "_view_name": "HBoxView", 459 | "box_style": "", 460 | "children": [ 461 | "IPY_MODEL_7a5f52e56ede4ac3abe37a3ece007dc9", 462 | "IPY_MODEL_ce8b0faa1a1340b5a504d7b3546b3ccb" 463 | ], 464 | "layout": "IPY_MODEL_12e23e2819094ee0a079d4eb77cfc4f9" 465 | } 466 | }, 467 | "161969cae25a49f38aacd1568d3cac6c": { 468 | "model_module": "@jupyter-widgets/base", 469 | "model_module_version": "1.2.0", 470 | "model_name": "LayoutModel", 471 | "state": { 472 | "_model_module": 
"@jupyter-widgets/base", 473 | "_model_module_version": "1.2.0", 474 | "_model_name": "LayoutModel", 475 | "_view_count": null, 476 | "_view_module": "@jupyter-widgets/base", 477 | "_view_module_version": "1.2.0", 478 | "_view_name": "LayoutView", 479 | "align_content": null, 480 | "align_items": null, 481 | "align_self": null, 482 | "border": null, 483 | "bottom": null, 484 | "display": null, 485 | "flex": null, 486 | "flex_flow": null, 487 | "grid_area": null, 488 | "grid_auto_columns": null, 489 | "grid_auto_flow": null, 490 | "grid_auto_rows": null, 491 | "grid_column": null, 492 | "grid_gap": null, 493 | "grid_row": null, 494 | "grid_template_areas": null, 495 | "grid_template_columns": null, 496 | "grid_template_rows": null, 497 | "height": null, 498 | "justify_content": null, 499 | "justify_items": null, 500 | "left": null, 501 | "margin": null, 502 | "max_height": null, 503 | "max_width": null, 504 | "min_height": null, 505 | "min_width": null, 506 | "object_fit": null, 507 | "object_position": null, 508 | "order": null, 509 | "overflow": null, 510 | "overflow_x": null, 511 | "overflow_y": null, 512 | "padding": null, 513 | "right": null, 514 | "top": null, 515 | "visibility": null, 516 | "width": null 517 | } 518 | }, 519 | "4a61c10fc00c4f04bb00b82e942da210": { 520 | "model_module": "@jupyter-widgets/base", 521 | "model_module_version": "1.2.0", 522 | "model_name": "LayoutModel", 523 | "state": { 524 | "_model_module": "@jupyter-widgets/base", 525 | "_model_module_version": "1.2.0", 526 | "_model_name": "LayoutModel", 527 | "_view_count": null, 528 | "_view_module": "@jupyter-widgets/base", 529 | "_view_module_version": "1.2.0", 530 | "_view_name": "LayoutView", 531 | "align_content": null, 532 | "align_items": null, 533 | "align_self": null, 534 | "border": null, 535 | "bottom": null, 536 | "display": null, 537 | "flex": null, 538 | "flex_flow": null, 539 | "grid_area": null, 540 | "grid_auto_columns": null, 541 | "grid_auto_flow": null, 542 | "grid_auto_rows": null, 543 | "grid_column": null, 544 | "grid_gap": null, 545 | "grid_row": null, 546 | "grid_template_areas": null, 547 | "grid_template_columns": null, 548 | "grid_template_rows": null, 549 | "height": null, 550 | "justify_content": null, 551 | "justify_items": null, 552 | "left": null, 553 | "margin": null, 554 | "max_height": null, 555 | "max_width": null, 556 | "min_height": null, 557 | "min_width": null, 558 | "object_fit": null, 559 | "object_position": null, 560 | "order": null, 561 | "overflow": null, 562 | "overflow_x": null, 563 | "overflow_y": null, 564 | "padding": null, 565 | "right": null, 566 | "top": null, 567 | "visibility": null, 568 | "width": null 569 | } 570 | }, 571 | "5e6adc4592124a4581b85f4c1f3bab4d": { 572 | "model_module": "@jupyter-widgets/controls", 573 | "model_module_version": "1.5.0", 574 | "model_name": "ProgressStyleModel", 575 | "state": { 576 | "_model_module": "@jupyter-widgets/controls", 577 | "_model_module_version": "1.5.0", 578 | "_model_name": "ProgressStyleModel", 579 | "_view_count": null, 580 | "_view_module": "@jupyter-widgets/base", 581 | "_view_module_version": "1.2.0", 582 | "_view_name": "StyleView", 583 | "bar_color": null, 584 | "description_width": "initial" 585 | } 586 | }, 587 | "7a5f52e56ede4ac3abe37a3ece007dc9": { 588 | "model_module": "@jupyter-widgets/controls", 589 | "model_module_version": "1.5.0", 590 | "model_name": "FloatProgressModel", 591 | "state": { 592 | "_dom_classes": [], 593 | "_model_module": "@jupyter-widgets/controls", 594 | "_model_module_version": 
"1.5.0", 595 | "_model_name": "FloatProgressModel", 596 | "_view_count": null, 597 | "_view_module": "@jupyter-widgets/controls", 598 | "_view_module_version": "1.5.0", 599 | "_view_name": "ProgressView", 600 | "bar_style": "success", 601 | "description": "", 602 | "description_tooltip": null, 603 | "layout": "IPY_MODEL_4a61c10fc00c4f04bb00b82e942da210", 604 | "max": 169001437, 605 | "min": 0, 606 | "orientation": "horizontal", 607 | "style": "IPY_MODEL_5e6adc4592124a4581b85f4c1f3bab4d", 608 | "value": 169001437 609 | } 610 | }, 611 | "b597cd6f6cd443aba4bf4491ac7f957e": { 612 | "model_module": "@jupyter-widgets/controls", 613 | "model_module_version": "1.5.0", 614 | "model_name": "DescriptionStyleModel", 615 | "state": { 616 | "_model_module": "@jupyter-widgets/controls", 617 | "_model_module_version": "1.5.0", 618 | "_model_name": "DescriptionStyleModel", 619 | "_view_count": null, 620 | "_view_module": "@jupyter-widgets/base", 621 | "_view_module_version": "1.2.0", 622 | "_view_name": "StyleView", 623 | "description_width": "" 624 | } 625 | }, 626 | "ce8b0faa1a1340b5a504d7b3546b3ccb": { 627 | "model_module": "@jupyter-widgets/controls", 628 | "model_module_version": "1.5.0", 629 | "model_name": "HTMLModel", 630 | "state": { 631 | "_dom_classes": [], 632 | "_model_module": "@jupyter-widgets/controls", 633 | "_model_module_version": "1.5.0", 634 | "_model_name": "HTMLModel", 635 | "_view_count": null, 636 | "_view_module": "@jupyter-widgets/controls", 637 | "_view_module_version": "1.5.0", 638 | "_view_name": "HTMLView", 639 | "description": "", 640 | "description_tooltip": null, 641 | "layout": "IPY_MODEL_161969cae25a49f38aacd1568d3cac6c", 642 | "placeholder": "​", 643 | "style": "IPY_MODEL_b597cd6f6cd443aba4bf4491ac7f957e", 644 | "value": " 169001984/? [00:06<00:00, 25734958.25it/s]" 645 | } 646 | } 647 | } 648 | } 649 | }, 650 | "nbformat": 4, 651 | "nbformat_minor": 0 652 | } 653 | -------------------------------------------------------------------------------- /Figure_2_Cone_Effect/Figure_2a_random_init_real_data/coco-extract.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "YPHN7PJgKOzb" 7 | }, 8 | "source": [ 9 | "# Image Feature Pair Extract - CLIP, ResNet18. 
\n", 10 | "conda activate clip\n", 11 | "\n", 12 | "\n", 13 | "clip_image_features_list (118287, 512)\n", 14 | "target_image_features_list (118287, 512)\n", 15 | "clip_image_features_list (5000, 512)\n", 16 | "target_image_features_list (5000, 512)\n", 17 | "\n", 18 | "Feature extraction complete in 6m 16s" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": { 25 | "colab": { 26 | "base_uri": "https://localhost:8080/" 27 | }, 28 | "id": "C1hkDT38hSaP", 29 | "outputId": "70a44964-883d-4fd0-b95a-2c7f2b19aca9" 30 | }, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "Torch version: 1.7.1\n" 37 | ] 38 | } 39 | ], 40 | "source": [ 41 | "import numpy as np\n", 42 | "import torch\n", 43 | "import pickle\n", 44 | "import time\n", 45 | "print(\"Torch version:\", torch.__version__)\n", 46 | "\n", 47 | "assert torch.__version__.split(\".\") >= [\"1\", \"7\", \"1\"], \"PyTorch 1.7.1 or later is required\"\n", 48 | "\n", 49 | "import os\n", 50 | "import matplotlib.pyplot as plt\n", 51 | "from collections import OrderedDict\n", 52 | "import torch\n", 53 | "\n", 54 | "%matplotlib inline\n", 55 | "%config InlineBackend.figure_format = 'retina'" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# Load CLIP" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 2, 68 | "metadata": { 69 | "colab": { 70 | "base_uri": "https://localhost:8080/" 71 | }, 72 | "id": "uLFS29hnhlY4", 73 | "outputId": "11779e1e-8bdd-4167-c18e-d26bdd6b67db" 74 | }, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']" 80 | ] 81 | }, 82 | "execution_count": 2, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "import clip\n", 89 | "\n", 90 | "clip.available_models()" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# ViT-B-32.json\n", 100 | "# copied from https://github.com/mlfoundations/open_clip/blob/91f6cce16b7bee90b3b5d38ca305b5b3b67cc200/src/training/model_configs/ViT-B-32.json\n", 101 | "model_info = {\n", 102 | " \"embed_dim\": 512,\n", 103 | " \"image_resolution\": 224,\n", 104 | " \"vision_layers\": 12,\n", 105 | " \"vision_width\": 768,\n", 106 | " \"vision_patch_size\": 32,\n", 107 | " \"context_length\": 77,\n", 108 | " \"vocab_size\": 49408,\n", 109 | " \"transformer_width\": 512,\n", 110 | " \"transformer_heads\": 8,\n", 111 | " \"transformer_layers\": 12\n", 112 | "} \n", 113 | "from clip.model import CLIP\n", 114 | "model = CLIP(**model_info)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 4, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "from torchvision import transforms\n", 124 | "input_size = model_info['image_resolution']\n", 125 | "preprocess = transforms.Compose([\n", 126 | " transforms.Resize(input_size),\n", 127 | " transforms.CenterCrop(input_size),\n", 128 | " transforms.ToTensor(),\n", 129 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 130 | " ])" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 5, 136 | "metadata": { 137 | "colab": { 138 | "base_uri": "https://localhost:8080/" 139 | }, 140 | "id": "IBRVTY9lbGm8", 141 | "outputId": "f06fd2fd-6126-475b-87d0-b10aa3b7da49" 142 | }, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | 
"output_type": "stream", 147 | "text": [ 148 | "Model parameters: 151,277,313\n", 149 | "Input resolution: 224\n", 150 | "Context length: 77\n", 151 | "Vocab size: 49408\n" 152 | ] 153 | } 154 | ], 155 | "source": [ 156 | "\n", 157 | "model.cuda().eval()\n", 158 | "input_resolution = model.visual.input_resolution\n", 159 | "context_length = model.context_length\n", 160 | "vocab_size = model.vocab_size\n", 161 | "\n", 162 | "print(\"Model parameters:\", f\"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}\")\n", 163 | "print(\"Input resolution:\", input_resolution)\n", 164 | "print(\"Context length:\", context_length)\n", 165 | "print(\"Vocab size:\", vocab_size)\n", 166 | "\n", 167 | "clip_model = model" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 6, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "data": { 177 | "text/plain": [ 178 | "torchvision.transforms.transforms.Compose" 179 | ] 180 | }, 181 | "execution_count": 6, 182 | "metadata": {}, 183 | "output_type": "execute_result" 184 | } 185 | ], 186 | "source": [ 187 | "type(preprocess)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "# Load Data" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 7, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "loading annotations into memory...\n", 207 | "Done (t=0.14s)\n", 208 | "creating index...\n", 209 | "index created!\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "import torchvision\n", 215 | "from torch.utils.data import DataLoader\n", 216 | "\n", 217 | "def target_transform(caption_list):\n", 218 | " caption = caption_list[0] # only the first caption\n", 219 | " return clip.tokenize(caption)[0]\n", 220 | "\n", 221 | "# coco_train_dataset = torchvision.datasets.CocoCaptions(\n", 222 | "# root = '/home/ubuntu/data/coco/train2017',\n", 223 | "# annFile = '/home/ubuntu/data/coco/annotations/captions_train2017.json',\n", 224 | "# transform=preprocess,\n", 225 | "# target_transform=target_transform,\n", 226 | "# )\n", 227 | "\n", 228 | "coco_val_dataset = torchvision.datasets.CocoCaptions(\n", 229 | " root = '/home/ubuntu/data/coco/val2017',\n", 230 | " annFile = '/home/ubuntu/data/coco/annotations/captions_val2017.json',\n", 231 | " transform=preprocess,\n", 232 | " target_transform=target_transform,\n", 233 | " )" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 8, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "# coco_train_dataloader = DataLoader(coco_train_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)\n", 243 | "coco_val_dataloader = DataLoader(coco_val_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "# ResNet" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 9, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "import torch\n", 260 | "import torch.nn as nn\n", 261 | "import torchvision.models as models\n", 262 | "from torch.autograd import Variable\n", 263 | "\n", 264 | "resnet18 = models.resnet18(pretrained=False) # resnet18 = models.resnet18(pretrained=True)\n", 265 | "modules=list(resnet18.children())[:-1]\n", 266 | "resnet18=nn.Sequential(*modules)\n", 267 | "for p in resnet18.parameters():\n", 268 | " p.requires_grad = False\n", 269 
| "\n", 270 | "resnet18.cuda().eval()\n", 271 | "target_model = resnet18\n" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": {}, 277 | "source": [ 278 | "# Extractor loop\n" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 11, 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "clip_image_features_list (5000, 512)\n", 291 | "target_image_features_list (5000, 512)\n", 292 | "\n", 293 | "Feature Extraction completed in 0m 33s\n" 294 | ] 295 | } 296 | ], 297 | "source": [ 298 | "since = time.time()\n", 299 | "dataloaders = {\n", 300 | " # 'train': coco_train_dataloader, \n", 301 | " 'val': coco_val_dataloader,\n", 302 | "}\n", 303 | "# Each epoch has a training and validation phase\n", 304 | "for phase in ['val']: # ['train', 'val',]:\n", 305 | "\n", 306 | " clip_model.eval() # Set model to evaluate mode, for extraction\n", 307 | " ##################################\n", 308 | " # Fields to be stored for postprocessing \n", 309 | " ##################################\n", 310 | " clip_image_features_list = []\n", 311 | " clip_text_features_list = []\n", 312 | " target_image_features_list = []\n", 313 | "\n", 314 | " # Iterate over data.\n", 315 | " for inputs, captions in dataloaders[phase]:\n", 316 | " image_input = inputs.cuda(non_blocking=True)\n", 317 | " text_input = captions.cuda(non_blocking=True)\n", 318 | " # TODO: add text here\n", 319 | " \n", 320 | " with torch.set_grad_enabled(False):\n", 321 | " clip_image_features = clip_model.encode_image(image_input).float()\n", 322 | " clip_text_features = clip_model.encode_text(text_input).float()\n", 323 | " target_image_features = target_model(image_input).squeeze() \n", 324 | " ##################################\n", 325 | " # Evaluation book-keeping Field \n", 326 | " ##################################\n", 327 | " clip_image_features_list.append( clip_image_features.cpu().numpy() )\n", 328 | " clip_text_features_list.append( clip_text_features.cpu().numpy() )\n", 329 | " target_image_features_list.append( target_image_features.cpu().numpy() )\n", 330 | "\n", 331 | " ##################################\n", 332 | " # Evaluation book-keeping Field \n", 333 | " ##################################\n", 334 | " clip_image_features_list = np.concatenate( clip_image_features_list, axis=0)\n", 335 | " clip_text_features_list = np.concatenate( clip_text_features_list, axis=0)\n", 336 | " target_image_features_list = np.concatenate( target_image_features_list, axis=0)\n", 337 | " print('clip_image_features_list', clip_image_features_list.shape)\n", 338 | " print('target_image_features_list', target_image_features_list.shape)\n", 339 | "\n", 340 | " dump_result_dict = {\n", 341 | " \"clip_image_features_list\": clip_image_features_list, \n", 342 | " \"clip_text_features_list\" : clip_text_features_list,\n", 343 | " \"target_image_features_list\": target_image_features_list, \n", 344 | " }\n", 345 | " with open(os.path.join('features', 'feature_dump_{}.pkl'.format(phase) ), \"wb\") as pkl_file:\n", 346 | " pickle.dump(\n", 347 | " dump_result_dict, \n", 348 | " pkl_file, \n", 349 | " )\n", 350 | "\n", 351 | "print()\n", 352 | "\n", 353 | "time_elapsed = time.time() - since\n", 354 | "print('Feature Extraction completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [] 
363 | } 364 | ], 365 | "metadata": { 366 | "accelerator": "GPU", 367 | "colab": { 368 | "collapsed_sections": [], 369 | "name": "Interacting with CLIP.ipynb", 370 | "provenance": [] 371 | }, 372 | "kernelspec": { 373 | "display_name": "Python 3", 374 | "name": "python3" 375 | }, 376 | "language_info": { 377 | "codemirror_mode": { 378 | "name": "ipython", 379 | "version": 3 380 | }, 381 | "file_extension": ".py", 382 | "mimetype": "text/x-python", 383 | "name": "python", 384 | "nbconvert_exporter": "python", 385 | "pygments_lexer": "ipython3", 386 | "version": "3.9.7" 387 | }, 388 | "widgets": { 389 | "application/vnd.jupyter.widget-state+json": { 390 | "12e23e2819094ee0a079d4eb77cfc4f9": { 391 | "model_module": "@jupyter-widgets/base", 392 | "model_module_version": "1.2.0", 393 | "model_name": "LayoutModel", 394 | "state": { 395 | "_model_module": "@jupyter-widgets/base", 396 | "_model_module_version": "1.2.0", 397 | "_model_name": "LayoutModel", 398 | "_view_count": null, 399 | "_view_module": "@jupyter-widgets/base", 400 | "_view_module_version": "1.2.0", 401 | "_view_name": "LayoutView", 402 | "align_content": null, 403 | "align_items": null, 404 | "align_self": null, 405 | "border": null, 406 | "bottom": null, 407 | "display": null, 408 | "flex": null, 409 | "flex_flow": null, 410 | "grid_area": null, 411 | "grid_auto_columns": null, 412 | "grid_auto_flow": null, 413 | "grid_auto_rows": null, 414 | "grid_column": null, 415 | "grid_gap": null, 416 | "grid_row": null, 417 | "grid_template_areas": null, 418 | "grid_template_columns": null, 419 | "grid_template_rows": null, 420 | "height": null, 421 | "justify_content": null, 422 | "justify_items": null, 423 | "left": null, 424 | "margin": null, 425 | "max_height": null, 426 | "max_width": null, 427 | "min_height": null, 428 | "min_width": null, 429 | "object_fit": null, 430 | "object_position": null, 431 | "order": null, 432 | "overflow": null, 433 | "overflow_x": null, 434 | "overflow_y": null, 435 | "padding": null, 436 | "right": null, 437 | "top": null, 438 | "visibility": null, 439 | "width": null 440 | } 441 | }, 442 | "1369964d45004b5e95a058910b2a33e6": { 443 | "model_module": "@jupyter-widgets/controls", 444 | "model_module_version": "1.5.0", 445 | "model_name": "HBoxModel", 446 | "state": { 447 | "_dom_classes": [], 448 | "_model_module": "@jupyter-widgets/controls", 449 | "_model_module_version": "1.5.0", 450 | "_model_name": "HBoxModel", 451 | "_view_count": null, 452 | "_view_module": "@jupyter-widgets/controls", 453 | "_view_module_version": "1.5.0", 454 | "_view_name": "HBoxView", 455 | "box_style": "", 456 | "children": [ 457 | "IPY_MODEL_7a5f52e56ede4ac3abe37a3ece007dc9", 458 | "IPY_MODEL_ce8b0faa1a1340b5a504d7b3546b3ccb" 459 | ], 460 | "layout": "IPY_MODEL_12e23e2819094ee0a079d4eb77cfc4f9" 461 | } 462 | }, 463 | "161969cae25a49f38aacd1568d3cac6c": { 464 | "model_module": "@jupyter-widgets/base", 465 | "model_module_version": "1.2.0", 466 | "model_name": "LayoutModel", 467 | "state": { 468 | "_model_module": "@jupyter-widgets/base", 469 | "_model_module_version": "1.2.0", 470 | "_model_name": "LayoutModel", 471 | "_view_count": null, 472 | "_view_module": "@jupyter-widgets/base", 473 | "_view_module_version": "1.2.0", 474 | "_view_name": "LayoutView", 475 | "align_content": null, 476 | "align_items": null, 477 | "align_self": null, 478 | "border": null, 479 | "bottom": null, 480 | "display": null, 481 | "flex": null, 482 | "flex_flow": null, 483 | "grid_area": null, 484 | "grid_auto_columns": null, 485 | "grid_auto_flow": 
null, 486 | "grid_auto_rows": null, 487 | "grid_column": null, 488 | "grid_gap": null, 489 | "grid_row": null, 490 | "grid_template_areas": null, 491 | "grid_template_columns": null, 492 | "grid_template_rows": null, 493 | "height": null, 494 | "justify_content": null, 495 | "justify_items": null, 496 | "left": null, 497 | "margin": null, 498 | "max_height": null, 499 | "max_width": null, 500 | "min_height": null, 501 | "min_width": null, 502 | "object_fit": null, 503 | "object_position": null, 504 | "order": null, 505 | "overflow": null, 506 | "overflow_x": null, 507 | "overflow_y": null, 508 | "padding": null, 509 | "right": null, 510 | "top": null, 511 | "visibility": null, 512 | "width": null 513 | } 514 | }, 515 | "4a61c10fc00c4f04bb00b82e942da210": { 516 | "model_module": "@jupyter-widgets/base", 517 | "model_module_version": "1.2.0", 518 | "model_name": "LayoutModel", 519 | "state": { 520 | "_model_module": "@jupyter-widgets/base", 521 | "_model_module_version": "1.2.0", 522 | "_model_name": "LayoutModel", 523 | "_view_count": null, 524 | "_view_module": "@jupyter-widgets/base", 525 | "_view_module_version": "1.2.0", 526 | "_view_name": "LayoutView", 527 | "align_content": null, 528 | "align_items": null, 529 | "align_self": null, 530 | "border": null, 531 | "bottom": null, 532 | "display": null, 533 | "flex": null, 534 | "flex_flow": null, 535 | "grid_area": null, 536 | "grid_auto_columns": null, 537 | "grid_auto_flow": null, 538 | "grid_auto_rows": null, 539 | "grid_column": null, 540 | "grid_gap": null, 541 | "grid_row": null, 542 | "grid_template_areas": null, 543 | "grid_template_columns": null, 544 | "grid_template_rows": null, 545 | "height": null, 546 | "justify_content": null, 547 | "justify_items": null, 548 | "left": null, 549 | "margin": null, 550 | "max_height": null, 551 | "max_width": null, 552 | "min_height": null, 553 | "min_width": null, 554 | "object_fit": null, 555 | "object_position": null, 556 | "order": null, 557 | "overflow": null, 558 | "overflow_x": null, 559 | "overflow_y": null, 560 | "padding": null, 561 | "right": null, 562 | "top": null, 563 | "visibility": null, 564 | "width": null 565 | } 566 | }, 567 | "5e6adc4592124a4581b85f4c1f3bab4d": { 568 | "model_module": "@jupyter-widgets/controls", 569 | "model_module_version": "1.5.0", 570 | "model_name": "ProgressStyleModel", 571 | "state": { 572 | "_model_module": "@jupyter-widgets/controls", 573 | "_model_module_version": "1.5.0", 574 | "_model_name": "ProgressStyleModel", 575 | "_view_count": null, 576 | "_view_module": "@jupyter-widgets/base", 577 | "_view_module_version": "1.2.0", 578 | "_view_name": "StyleView", 579 | "bar_color": null, 580 | "description_width": "initial" 581 | } 582 | }, 583 | "7a5f52e56ede4ac3abe37a3ece007dc9": { 584 | "model_module": "@jupyter-widgets/controls", 585 | "model_module_version": "1.5.0", 586 | "model_name": "FloatProgressModel", 587 | "state": { 588 | "_dom_classes": [], 589 | "_model_module": "@jupyter-widgets/controls", 590 | "_model_module_version": "1.5.0", 591 | "_model_name": "FloatProgressModel", 592 | "_view_count": null, 593 | "_view_module": "@jupyter-widgets/controls", 594 | "_view_module_version": "1.5.0", 595 | "_view_name": "ProgressView", 596 | "bar_style": "success", 597 | "description": "", 598 | "description_tooltip": null, 599 | "layout": "IPY_MODEL_4a61c10fc00c4f04bb00b82e942da210", 600 | "max": 169001437, 601 | "min": 0, 602 | "orientation": "horizontal", 603 | "style": "IPY_MODEL_5e6adc4592124a4581b85f4c1f3bab4d", 604 | "value": 169001437 605 | 
} 606 | }, 607 | "b597cd6f6cd443aba4bf4491ac7f957e": { 608 | "model_module": "@jupyter-widgets/controls", 609 | "model_module_version": "1.5.0", 610 | "model_name": "DescriptionStyleModel", 611 | "state": { 612 | "_model_module": "@jupyter-widgets/controls", 613 | "_model_module_version": "1.5.0", 614 | "_model_name": "DescriptionStyleModel", 615 | "_view_count": null, 616 | "_view_module": "@jupyter-widgets/base", 617 | "_view_module_version": "1.2.0", 618 | "_view_name": "StyleView", 619 | "description_width": "" 620 | } 621 | }, 622 | "ce8b0faa1a1340b5a504d7b3546b3ccb": { 623 | "model_module": "@jupyter-widgets/controls", 624 | "model_module_version": "1.5.0", 625 | "model_name": "HTMLModel", 626 | "state": { 627 | "_dom_classes": [], 628 | "_model_module": "@jupyter-widgets/controls", 629 | "_model_module_version": "1.5.0", 630 | "_model_name": "HTMLModel", 631 | "_view_count": null, 632 | "_view_module": "@jupyter-widgets/controls", 633 | "_view_module_version": "1.5.0", 634 | "_view_name": "HTMLView", 635 | "description": "", 636 | "description_tooltip": null, 637 | "layout": "IPY_MODEL_161969cae25a49f38aacd1568d3cac6c", 638 | "placeholder": "​", 639 | "style": "IPY_MODEL_b597cd6f6cd443aba4bf4491ac7f957e", 640 | "value": " 169001984/? [00:06<00:00, 25734958.25it/s]" 641 | } 642 | } 643 | } 644 | } 645 | }, 646 | "nbformat": 4, 647 | "nbformat_minor": 0 648 | } 649 | -------------------------------------------------------------------------------- /Figure_2_Cone_Effect/Figure_2c_scatter_cones_random_init/random_data/visualizePCA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Visualize COCO features\n", 8 | "\n", 9 | "1. visualize coco features\n", 10 | "2. identify pca-one; what is its cosine similarity with the residual (should be very high)\n", 11 | "3. move along the direction, plot 1-dim loss landscape. [-2,-1,-0.5,0,0.5,1,2]\n", 12 | " - need to have a fn(scalar,), output loss. 
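[Editor's note] The 1-dim loss-landscape sweep described in step 3 of this markdown cell (a `fn(scalar)` that returns a loss for shifts in `[-2, -1, -0.5, 0, 0.5, 1, 2]`) is not shown in this excerpt. A minimal sketch of one way it could be set up, assuming the swept direction is the difference between the modality means and the objective is a standard symmetric contrastive (CLIP-style) loss with a fixed temperature, follows; the helper names and the temperature value are illustrative, not the repo's exact implementation.

```python
import numpy as np

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def clip_loss(image_features, text_features, temperature=0.01):
    # Symmetric InfoNCE / CLIP loss with matched pairs on the diagonal
    img = l2_normalize(image_features)
    txt = l2_normalize(text_features)
    logits = img @ txt.T / temperature          # logits[i, j] = sim(image_i, text_j) / T

    def cross_entropy_diag(l):
        # Mean cross-entropy with the matched pair (diagonal entry) as the target class
        l = l - l.max(axis=1, keepdims=True)    # numerical stability
        log_prob = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return -np.diag(log_prob).mean()

    return (cross_entropy_diag(logits) + cross_entropy_diag(logits.T)) / 2

def loss_along_direction(image_features, text_features,
                         scalars=(-2, -1, -0.5, 0, 0.5, 1, 2)):
    # fn(scalar) -> loss: move image embeddings along the gap between modality means
    img = l2_normalize(image_features)
    txt = l2_normalize(text_features)
    direction = txt.mean(axis=0) - img.mean(axis=0)
    return {s: clip_loss(img + s * direction, txt) for s in scalars}
```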
\n" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import argparse\n", 22 | "import os\n", 23 | "import random\n", 24 | "import shutil\n", 25 | "import time\n", 26 | "import warnings\n", 27 | "from enum import Enum\n", 28 | "import pickle\n", 29 | "import numpy as np\n", 30 | "from collections import defaultdict\n", 31 | "\n", 32 | "import torch\n", 33 | "import torch.nn as nn\n", 34 | "import torch.optim\n", 35 | "from torch.utils.data import Dataset, DataLoader\n", 36 | "import torch.backends.cudnn as cudnn\n", 37 | "\n", 38 | "import glob \n", 39 | "def my_norm(x):\n", 40 | " return x/np.linalg.norm(x, axis=-1, keepdims=True)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "data_dict_list = list()\n", 50 | "\n", 51 | "for pickle_path in glob.glob('./features*/feature_dump_*.pkl'):\n", 52 | " with open(pickle_path, 'rb') as pkl_file:\n", 53 | " data_dict = pickle.load(pkl_file)\n", 54 | " assert len(data_dict['clip_image_features_list']) == len(data_dict['clip_text_features_list'])\n", 55 | " # assert len(data_dict['clip_image_features_list']) == len(data_dict['target_image_features_list'])\n", 56 | " # print('Number of image-text pairs', len(data_dict['clip_image_features_list']))\n", 57 | " data_dict_list.append(data_dict)\n", 58 | "\n", 59 | "print('Number of experiment files loaded', len(data_dict_list))" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# visualize.\n", 69 | "\n", 70 | "from sklearn.decomposition import PCA\n", 71 | "# from sklearn.decomposition import TruncatedSVD as PCA # showns as multiple lines. \n", 72 | "# from sklearn.manifold import TSNE as PCA # \n", 73 | "# import umap\n", 74 | "# from umap import UMAP as PCA\n", 75 | "import pandas as pd\n", 76 | "import matplotlib.pyplot as plt\n", 77 | "%matplotlib inline\n", 78 | "import seaborn as sns\n", 79 | "# sns.set(font_scale=2) # crazy big\n", 80 | "plt.rcParams['figure.dpi'] = 300\n", 81 | "plt.rcParams['savefig.dpi'] = 300\n", 82 | "sns.set_theme()\n" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "# Functionality: given a list of exp, plot one modality. 
\n", 92 | "sns.set_context(\"talk\", font_scale=1.5) # paper, notebook, talk, and poster; font_scale=1.5,\n", 93 | "\n", 94 | "def plot_scattered_cones(data_dict_list, modality_str, draw=True):\n", 95 | " assert modality_str in ['clip_image_features_list', 'clip_text_features_list', 'target_image_features_list']\n", 96 | " print('modality_str: ', modality_str)\n", 97 | " # dataset_size = len(data_dict_list[0][modality_str])\n", 98 | " dataset_size = 5000\n", 99 | "\n", 100 | " total_feature_list = list()\n", 101 | " label_list = list()\n", 102 | " for expriment_idx in range(len(data_dict_list)):\n", 103 | " total_feature_list.append(data_dict_list[expriment_idx][modality_str][:dataset_size])\n", 104 | " label_list.extend(['Random-{}'.format(expriment_idx+1)] * dataset_size)\n", 105 | " total_feature_np = np.concatenate(total_feature_list, axis=0) \n", 106 | " total_feature_np = my_norm(total_feature_np) # L2-normalize\n", 107 | " assert len(total_feature_np) == len(data_dict_list) * dataset_size\n", 108 | "\n", 109 | " pca = PCA(n_components=2)\n", 110 | " pca_result = pca.fit_transform(total_feature_np)\n", 111 | "\n", 112 | " df = pd.DataFrame()\n", 113 | " df['pca_one'] = pca_result[:,0]\n", 114 | " df['pca_two'] = pca_result[:,1] \n", 115 | " df['Random Seed'] = label_list\n", 116 | "\n", 117 | " if draw:\n", 118 | " plt.figure(figsize=(20.0,6.18 * 2))\n", 119 | " p1 = sns.scatterplot(\n", 120 | " x=\"pca_one\", y=\"pca_two\",\n", 121 | " hue=\"Random Seed\",\n", 122 | " data=df,\n", 123 | " legend=True,\n", 124 | " )\n", 125 | " plt.xlabel(\"\")\n", 126 | " plt.ylabel(\"\")\n", 127 | " plt.legend(title='Random Seed', loc='upper left', bbox_to_anchor=(1.00, 1.0, ), prop={'size': 18})\n", 128 | " plt.show()\n", 129 | "\n", 130 | " return df\n" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "df_clip_img = plot_scattered_cones(data_dict_list[:25], 'clip_image_features_list', draw=True)\n", 140 | "df_clip_txt = plot_scattered_cones(data_dict_list[:25], 'clip_text_features_list', draw=True)\n", 141 | "df_resnet = plot_scattered_cones(data_dict_list[:25], 'target_image_features_list', draw=True)\n" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "def draw_df(df):\n", 151 | " plt.figure(figsize=(20.0,6.18 * 2))\n", 152 | " df['Seed'] = df['Random Seed'].str.replace('Random-', '', regex=False)\n", 153 | " p1 = sns.scatterplot(\n", 154 | " x=\"pca_one\", y=\"pca_two\",\n", 155 | " hue=\"Seed\",\n", 156 | " data=df,\n", 157 | " legend=True,\n", 158 | " )\n", 159 | " plt.xlabel(\"\")\n", 160 | " plt.ylabel(\"\")\n", 161 | " plt.legend(title='Random Seed', loc='upper left', bbox_to_anchor=(1.00, 1.0, ), ncol=2) # prop={'size': 50}, \n", 162 | " plt.show()\n", 163 | " return\n", 164 | "\n", 165 | "draw_df(df_clip_img)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [] 174 | } 175 | ], 176 | "metadata": { 177 | "interpreter": { 178 | "hash": "09c077faaa20da841f22e0f4d12b4addb73e00d9291bc78d00732f9f39794f23" 179 | }, 180 | "kernelspec": { 181 | "display_name": "Python 3.9.7 ('clip')", 182 | "language": "python", 183 | "name": "python3" 184 | }, 185 | "language_info": { 186 | "codemirror_mode": { 187 | "name": "ipython", 188 | "version": 3 189 | }, 190 | "file_extension": ".py", 191 | "mimetype": 
"text/x-python", 192 | "name": "python", 193 | "nbconvert_exporter": "python", 194 | "pygments_lexer": "ipython3", 195 | "version": "3.9.7" 196 | }, 197 | "orig_nbformat": 4 198 | }, 199 | "nbformat": 4, 200 | "nbformat_minor": 2 201 | } 202 | -------------------------------------------------------------------------------- /Figure_2_Cone_Effect/Figure_2c_scatter_cones_random_init/real_data/visualizePCA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Visualize COCO features\n", 8 | "\n", 9 | "1. visualize coco features\n", 10 | "2. identify pca-one; what is its cosine similarity with the residual (should be very high)\n", 11 | "3. move along the direction, plot 1-dim loss landscape. [-2,-1,-0.5,0,0.5,1,2]\n", 12 | " - need to have a fn(scalar,), output loss. \n" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import argparse\n", 22 | "import os\n", 23 | "import random\n", 24 | "import shutil\n", 25 | "import time\n", 26 | "import warnings\n", 27 | "from enum import Enum\n", 28 | "import pickle\n", 29 | "import numpy as np\n", 30 | "from collections import defaultdict\n", 31 | "\n", 32 | "import torch\n", 33 | "import torch.nn as nn\n", 34 | "import torch.optim\n", 35 | "from torch.utils.data import Dataset, DataLoader\n", 36 | "import torch.backends.cudnn as cudnn\n", 37 | "\n", 38 | "import glob \n", 39 | "def my_norm(x):\n", 40 | " return x/np.linalg.norm(x, axis=-1, keepdims=True)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "data_dict_list = list()\n", 50 | "\n", 51 | "for pickle_path in glob.glob('./features*/feature_dump_*.pkl'):\n", 52 | " with open(pickle_path, 'rb') as pkl_file:\n", 53 | " data_dict = pickle.load(pkl_file)\n", 54 | " assert len(data_dict['clip_image_features_list']) == len(data_dict['clip_text_features_list'])\n", 55 | " # assert len(data_dict['clip_image_features_list']) == len(data_dict['target_image_features_list'])\n", 56 | " # print('Number of image-text pairs', len(data_dict['clip_image_features_list']))\n", 57 | " data_dict_list.append(data_dict)\n", 58 | "\n", 59 | "print('Number of experiment files loaded', len(data_dict_list))" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# visualize.\n", 69 | "\n", 70 | "from sklearn.decomposition import PCA\n", 71 | "# from sklearn.decomposition import TruncatedSVD as PCA # showns as multiple lines. \n", 72 | "# from sklearn.manifold import TSNE as PCA # \n", 73 | "# import umap\n", 74 | "# from umap import UMAP as PCA\n", 75 | "import pandas as pd\n", 76 | "import matplotlib.pyplot as plt\n", 77 | "%matplotlib inline\n", 78 | "import seaborn as sns\n", 79 | "# sns.set(font_scale=2) # crazy big\n", 80 | "plt.rcParams['figure.dpi'] = 300\n", 81 | "plt.rcParams['savefig.dpi'] = 300\n", 82 | "sns.set_theme()\n" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "# Functionality: given a list of exp, plot one modality. 
\n", 92 | "sns.set_context(\"talk\", font_scale=1.5) # paper, notebook, talk, and poster; font_scale=1.5,\n", 93 | "\n", 94 | "def plot_scattered_cones(data_dict_list, modality_str, draw=True):\n", 95 | " assert modality_str in ['clip_image_features_list', 'clip_text_features_list', 'target_image_features_list']\n", 96 | " print('modality_str: ', modality_str)\n", 97 | " # dataset_size = len(data_dict_list[0][modality_str])\n", 98 | " dataset_size = 5000\n", 99 | "\n", 100 | " total_feature_list = list()\n", 101 | " label_list = list()\n", 102 | " for expriment_idx in range(len(data_dict_list)):\n", 103 | " total_feature_list.append(data_dict_list[expriment_idx][modality_str][:dataset_size])\n", 104 | " label_list.extend(['Random-{}'.format(expriment_idx+1)] * dataset_size)\n", 105 | " total_feature_np = np.concatenate(total_feature_list, axis=0) \n", 106 | " total_feature_np = my_norm(total_feature_np) # L2-normalize\n", 107 | " assert len(total_feature_np) == len(data_dict_list) * dataset_size\n", 108 | "\n", 109 | " pca = PCA(n_components=6)\n", 110 | " pca_result = pca.fit_transform(total_feature_np)\n", 111 | " print('pca.explained_variance_ratio_', pca.explained_variance_ratio_)\n", 112 | " print('pca.singular_values_', pca.singular_values_)\n", 113 | "\n", 114 | " df = pd.DataFrame()\n", 115 | " df['pca_one'] = pca_result[:,0]\n", 116 | " df['pca_two'] = pca_result[:,1] \n", 117 | " df['Random Seed'] = label_list\n", 118 | "\n", 119 | " if draw:\n", 120 | " plt.figure(figsize=(20.0,6.18 * 2))\n", 121 | " p1 = sns.scatterplot(\n", 122 | " x=\"pca_one\", y=\"pca_two\",\n", 123 | " hue=\"Random Seed\",\n", 124 | " data=df,\n", 125 | " legend=True,\n", 126 | " )\n", 127 | " plt.xlabel(\"\")\n", 128 | " plt.ylabel(\"\")\n", 129 | " plt.legend(title='Random Seed', loc='upper left', bbox_to_anchor=(1.00, 1.0, ), prop={'size': 18})\n", 130 | " plt.show()\n", 131 | "\n", 132 | " return df\n" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "df_clip_img = plot_scattered_cones(data_dict_list[:25], 'clip_image_features_list', draw=True)\n", 142 | "df_clip_txt = plot_scattered_cones(data_dict_list[:25], 'clip_text_features_list', draw=True)\n", 143 | "df_resnet = plot_scattered_cones(data_dict_list[:25], 'target_image_features_list', draw=True)\n" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "def draw_df(df):\n", 153 | " plt.figure(figsize=(20.0,6.18 * 2))\n", 154 | " df['Seed'] = df['Random Seed'].str.replace('Random-', '', regex=False)\n", 155 | " p1 = sns.scatterplot(\n", 156 | " x=\"pca_one\", y=\"pca_two\",\n", 157 | " hue=\"Seed\",\n", 158 | " data=df,\n", 159 | " legend=True,\n", 160 | " )\n", 161 | " plt.xlabel(\"\")\n", 162 | " plt.ylabel(\"\")\n", 163 | " plt.legend(title='Random Seed', loc='upper left', bbox_to_anchor=(1.00, 1.0, ), ncol=2) # prop={'size': 50}, \n", 164 | " plt.show()\n", 165 | " return\n", 166 | "\n", 167 | "draw_df(df_clip_img)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "# Plot PCA Singular Values, Explained Variance Ratios. 
\n", 175 | "Kind of anwering Mert's question" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 26, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "name": "stdout", 185 | "output_type": "stream", 186 | "text": [ 187 | "modality_str: clip_image_features_list\n", 188 | "pca.explained_variance_ratio_\n", 189 | "0.043, 0.041, 0.039, 0.038, 0.036, 0.035, 0.035, 0.034, 0.033, 0.032, \n", 190 | "pca.singular_values_ [72.44832 70.31703 68.78217 68.24517 66.22955 65.66144 65.02128\n", 191 | " 64.06602 63.149437 62.50923 61.43108 60.71535 60.435135 59.02705\n", 192 | " 58.74808 57.4058 56.325825 56.2117 55.202732 54.309063 53.766792\n", 193 | " 52.040756 51.68926 49.76612 34.14688 33.398888 32.901985 31.960554\n", 194 | " 31.528515 31.300081 30.672626 30.518982 30.29744 29.762638 29.396282\n", 195 | " 28.373528 28.064127 27.74946 27.346584 27.130186 26.959745 26.397924\n", 196 | " 25.524904 25.109116 24.717733 24.531994 24.060846 23.81253 22.803596\n", 197 | " 20.144312]\n", 198 | "modality_str: clip_text_features_list\n", 199 | "pca.explained_variance_ratio_\n", 200 | "0.043, 0.041, 0.039, 0.037, 0.037, 0.035, 0.034, 0.033, 0.033, 0.031, \n", 201 | "pca.singular_values_ [71.93895 70.64999 68.51955 67.25281 66.71326 65.2795 64.50423\n", 202 | " 63.39669 62.925117 61.176167 59.73097 58.7134 58.423645 57.11752\n", 203 | " 56.474472 55.85696 54.98844 54.659405 54.08874 53.35901 51.593594\n", 204 | " 50.34826 49.43106 48.493847 16.067904 15.492056 15.30791 14.992251\n", 205 | " 14.946433 14.73657 14.656306 14.519942 14.41191 14.366245 14.130468\n", 206 | " 14.007584 13.708626 13.655253 13.45591 13.389069 13.198088 13.179104\n", 207 | " 13.093057 12.848161 12.838188 12.79897 12.603904 12.445068 12.337545\n", 208 | " 12.306129]\n", 209 | "modality_str: target_image_features_list\n", 210 | "pca.explained_variance_ratio_\n", 211 | "0.056, 0.055, 0.054, 0.051, 0.050, 0.050, 0.049, 0.046, 0.044, 0.043, \n", 212 | "pca.singular_values_ [57.44344 56.822586 56.4279 54.55056 54.171036 53.912224\n", 213 | " 53.301693 51.85659 50.885063 50.07982 49.386353 49.12857\n", 214 | " 48.405567 47.63106 47.15982 45.581974 45.29316 45.029636\n", 215 | " 44.288643 43.610165 42.718163 41.86789 40.769337 39.61369\n", 216 | " 4.8666005 4.7441974 4.5143256 4.4266877 4.175692 4.155532\n", 217 | " 4.1449823 4.055484 3.8198297 3.783392 3.687432 3.661967\n", 218 | " 3.6238446 3.5420978 3.483381 3.4556499 3.2627327 3.2502015\n", 219 | " 3.1480756 3.124066 3.0445938 2.9486566 2.828199 2.759845\n", 220 | " 2.7152538 2.6587367]\n" 221 | ] 222 | } 223 | ], 224 | "source": [ 225 | "# Functionality: given a list of exp, plot one modality. 
\n", 226 | "sns.set_context(\"talk\", font_scale=1.5) # paper, notebook, talk, and poster; font_scale=1.5,\n", 227 | "\n", 228 | "def plot_pca_stats(data_dict_list, modality_str, draw=True):\n", 229 | " assert modality_str in ['clip_image_features_list', 'clip_text_features_list', 'target_image_features_list']\n", 230 | " print('modality_str: ', modality_str)\n", 231 | " # dataset_size = len(data_dict_list[0][modality_str])\n", 232 | " dataset_size = 5000\n", 233 | "\n", 234 | " total_feature_list = list()\n", 235 | " label_list = list()\n", 236 | " for expriment_idx in range(len(data_dict_list)):\n", 237 | " total_feature_list.append(data_dict_list[expriment_idx][modality_str][:dataset_size])\n", 238 | " label_list.extend(['Random-{}'.format(expriment_idx+1)] * dataset_size)\n", 239 | " total_feature_np = np.concatenate(total_feature_list, axis=0) \n", 240 | " total_feature_np = my_norm(total_feature_np) # L2-normalize\n", 241 | " assert len(total_feature_np) == len(data_dict_list) * dataset_size\n", 242 | "\n", 243 | " pca = PCA(n_components=50)\n", 244 | " pca_result = pca.fit_transform(total_feature_np)\n", 245 | " print('pca.explained_variance_ratio_')\n", 246 | " for ratio in pca.explained_variance_ratio_[:10]:\n", 247 | " print('{:.3f},'.format(ratio), end=' ')\n", 248 | " print()\n", 249 | "\n", 250 | "\n", 251 | " print('pca.singular_values_', pca.singular_values_)\n", 252 | " return\n", 253 | "\n", 254 | "\n", 255 | "df_clip_img = plot_pca_stats(data_dict_list[:25], 'clip_image_features_list', draw=True)\n", 256 | "df_clip_txt = plot_pca_stats(data_dict_list[:25], 'clip_text_features_list', draw=True)\n", 257 | "df_resnet = plot_pca_stats(data_dict_list[:25], 'target_image_features_list', draw=True)\n" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [] 266 | } 267 | ], 268 | "metadata": { 269 | "interpreter": { 270 | "hash": "09c077faaa20da841f22e0f4d12b4addb73e00d9291bc78d00732f9f39794f23" 271 | }, 272 | "kernelspec": { 273 | "display_name": "Python 3.9.7 ('clip')", 274 | "language": "python", 275 | "name": "python3" 276 | }, 277 | "language_info": { 278 | "codemirror_mode": { 279 | "name": "ipython", 280 | "version": 3 281 | }, 282 | "file_extension": ".py", 283 | "mimetype": "text/x-python", 284 | "name": "python", 285 | "nbconvert_exporter": "python", 286 | "pygments_lexer": "ipython3", 287 | "version": "3.9.7" 288 | }, 289 | "orig_nbformat": 4 290 | }, 291 | "nbformat": 4, 292 | "nbformat_minor": 2 293 | } 294 | -------------------------------------------------------------------------------- /Figure_2_Cone_Effect/Figure_2c_scatter_cones_random_init/real_data_ImageNet_pretrained/ImageNet-Pretrained-Cones.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/Figure_2_Cone_Effect/Figure_2c_scatter_cones_random_init/real_data_ImageNet_pretrained/ImageNet-Pretrained-Cones.png -------------------------------------------------------------------------------- /Figure_2_Cone_Effect/Figure_2c_scatter_cones_random_init/real_data_ImageNet_pretrained/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ![](./ImageNet-Pretrained-Cones.png) 5 | -------------------------------------------------------------------------------- 
/Figure_2_Cone_Effect/Figure_2c_scatter_cones_random_init/real_data_ImageNet_pretrained/coco-extract.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "YPHN7PJgKOzb" 7 | }, 8 | "source": [ 9 | "# If so, will such distinctively different cones remain if randomly initialized models are fully trained?\n", 10 | "\n", 11 | "\n", 12 | "env: conda activate clip\n", 13 | "\n", 14 | "https://github.com/SamsungLabs/pytorch-ensembles" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 24, 20 | "metadata": { 21 | "colab": { 22 | "base_uri": "https://localhost:8080/" 23 | }, 24 | "id": "C1hkDT38hSaP", 25 | "outputId": "70a44964-883d-4fd0-b95a-2c7f2b19aca9" 26 | }, 27 | "outputs": [ 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "Torch version: 1.7.1\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "import numpy as np\n", 38 | "import torch\n", 39 | "import pickle\n", 40 | "import time\n", 41 | "print(\"Torch version:\", torch.__version__)\n", 42 | "\n", 43 | "assert torch.__version__.split(\".\") >= [\"1\", \"7\", \"1\"], \"PyTorch 1.7.1 or later is required\"\n", 44 | "\n", 45 | "import os\n", 46 | "import matplotlib.pyplot as plt\n", 47 | "from collections import OrderedDict\n", 48 | "import torch\n", 49 | "\n", 50 | "%matplotlib inline\n", 51 | "%config InlineBackend.figure_format = 'retina'" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "# Load CLIP" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 25, 64 | "metadata": { 65 | "colab": { 66 | "base_uri": "https://localhost:8080/" 67 | }, 68 | "id": "uLFS29hnhlY4", 69 | "outputId": "11779e1e-8bdd-4167-c18e-d26bdd6b67db" 70 | }, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/plain": [ 75 | "['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']" 76 | ] 77 | }, 78 | "execution_count": 25, 79 | "metadata": {}, 80 | "output_type": "execute_result" 81 | } 82 | ], 83 | "source": [ 84 | "import clip\n", 85 | "\n", 86 | "clip.available_models()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 26, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "# ViT-B-32.json\n", 96 | "# copied from https://github.com/mlfoundations/open_clip/blob/91f6cce16b7bee90b3b5d38ca305b5b3b67cc200/src/training/model_configs/ViT-B-32.json\n", 97 | "model_info = {\n", 98 | " \"embed_dim\": 512,\n", 99 | " \"image_resolution\": 224,\n", 100 | " \"vision_layers\": 12,\n", 101 | " \"vision_width\": 768,\n", 102 | " \"vision_patch_size\": 32,\n", 103 | " \"context_length\": 77,\n", 104 | " \"vocab_size\": 49408,\n", 105 | " \"transformer_width\": 512,\n", 106 | " \"transformer_heads\": 8,\n", 107 | " \"transformer_layers\": 12\n", 108 | "} " 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 27, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "from torchvision import transforms\n", 118 | "input_size = model_info['image_resolution']\n", 119 | "preprocess = transforms.Compose([\n", 120 | " transforms.Resize(input_size),\n", 121 | " transforms.CenterCrop(input_size),\n", 122 | " transforms.ToTensor(),\n", 123 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 124 | " ])" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 28, 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/plain": 
[ 135 | "torchvision.transforms.transforms.Compose" 136 | ] 137 | }, 138 | "execution_count": 28, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "type(preprocess)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "# Load Data" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 29, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "name": "stdout", 161 | "output_type": "stream", 162 | "text": [ 163 | "loading annotations into memory...\n", 164 | "Done (t=0.04s)\n", 165 | "creating index...\n", 166 | "index created!\n" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "import torchvision\n", 172 | "from torch.utils.data import DataLoader\n", 173 | "\n", 174 | "def target_transform(caption_list):\n", 175 | " caption = caption_list[0] # only the first caption\n", 176 | " return clip.tokenize(caption)[0]\n", 177 | "\n", 178 | "# coco_train_dataset = torchvision.datasets.CocoCaptions(\n", 179 | "# root = '/home/ubuntu/data/coco/train2017',\n", 180 | "# annFile = '/home/ubuntu/data/coco/annotations/captions_train2017.json',\n", 181 | "# transform=preprocess,\n", 182 | "# target_transform=target_transform,\n", 183 | "# )\n", 184 | "\n", 185 | "coco_val_dataset = torchvision.datasets.CocoCaptions(\n", 186 | " root = '/home/ubuntu/data/coco/val2017',\n", 187 | " annFile = '/home/ubuntu/data/coco/annotations/captions_val2017.json',\n", 188 | " transform=preprocess,\n", 189 | " target_transform=target_transform,\n", 190 | " )" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 30, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "# coco_train_dataloader = DataLoader(coco_train_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)\n", 200 | "coco_val_dataloader = DataLoader(coco_val_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "# ResNet" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 31, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "import torch\n", 217 | "import torch.nn as nn\n", 218 | "import torchvision.models as models\n", 219 | "from torch.autograd import Variable" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 32, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "\n", 229 | "deepens_imagenet = [\n", 230 | " 'ImageNet-ResNet50-052e7f78e4db--1564492444-1.pth.tar', \n", 231 | " 'ImageNet-ResNet50-1132c260ef75--1564493784-1.pth.tar',\n", 232 | " 'ImageNet-ResNet50-2f817072e8da--1564493734-1.pth.tar',\n", 233 | " 'ImageNet-ResNet50-3177c697fbf4--1564495013-1.pth.tar',\n", 234 | " 'ImageNet-ResNet50-628e11f9fd67--1564481099-1.pth.tar',\n", 235 | " 'ImageNet-ResNet50-743e10f26a38--1564493675-1.pth.tar',\n", 236 | " 'ImageNet-ResNet50-7ded66ec9900--1564481097-1.pth.tar',\n", 237 | " 'ImageNet-ResNet50-8fc5076a66c9--1564481079-1.pth.tar',\n", 238 | " 'ImageNet-ResNet50-a58ab8dd26fc--1564492521-1.pth.tar',\n", 239 | " 'ImageNet-ResNet50-a80e40d84db2--1564492573-1.pth.tar',\n", 240 | " 'ImageNet-ResNet50-be11903315ee--1564481101-1.pth.tar',\n", 241 | "]\n", 242 | "\n", 243 | "def load_model_states(model, filename):\n", 244 | " \"\"\"\n", 245 | " Load a previously saved model states.\n", 246 | " https://github.com/SamsungLabs/pytorch-ensembles\n", 247 | " \"\"\"\n", 248 | " with 
open(filename, 'rb') as f:\n", 249 | " # original saved file with DataParallel\n", 250 | " state_dict = torch.load(f)['state_dict']\n", 251 | " # create new OrderedDict that does not contain `module.`\n", 252 | " from collections import OrderedDict\n", 253 | " new_state_dict = OrderedDict()\n", 254 | " for k, v in state_dict.items():\n", 255 | " name = k[7:] # remove `module.`\n", 256 | " new_state_dict[name] = v\n", 257 | " # load params\n", 258 | " model.load_state_dict(new_state_dict)\n", 259 | "\n" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 33, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "from clip.model import CLIP\n", 269 | "def get_random_init_models(checkpoint_tar_name):\n", 270 | "\n", 271 | " resnet18 = models.resnet50(pretrained=False) # actually resnet 50\n", 272 | " load_model_states(resnet18, '../deepens_imagenet/' + checkpoint_tar_name)\n", 273 | "\n", 274 | " modules=list(resnet18.children())[:-1]\n", 275 | " resnet18=nn.Sequential(*modules)\n", 276 | " for p in resnet18.parameters():\n", 277 | " p.requires_grad = False\n", 278 | "\n", 279 | " resnet18.cuda().eval()\n", 280 | " target_model = resnet18\n", 281 | " return target_model\n" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "# Extractor loop\n" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 34, 294 | "metadata": {}, 295 | "outputs": [ 296 | { 297 | "name": "stdout", 298 | "output_type": "stream", 299 | "text": [ 300 | "target_image_features_list (5000, 2048)\n", 301 | "expriment_idx 0\n", 302 | "Feature Extraction completed in 0m 15s\n", 303 | "target_image_features_list (5000, 2048)\n", 304 | "expriment_idx 1\n", 305 | "Feature Extraction completed in 0m 31s\n", 306 | "target_image_features_list (5000, 2048)\n", 307 | "expriment_idx 2\n", 308 | "Feature Extraction completed in 0m 48s\n", 309 | "target_image_features_list (5000, 2048)\n", 310 | "expriment_idx 3\n", 311 | "Feature Extraction completed in 1m 2s\n", 312 | "target_image_features_list (5000, 2048)\n", 313 | "expriment_idx 4\n", 314 | "Feature Extraction completed in 1m 17s\n", 315 | "target_image_features_list (5000, 2048)\n", 316 | "expriment_idx 5\n", 317 | "Feature Extraction completed in 1m 32s\n", 318 | "target_image_features_list (5000, 2048)\n", 319 | "expriment_idx 6\n", 320 | "Feature Extraction completed in 1m 48s\n", 321 | "target_image_features_list (5000, 2048)\n", 322 | "expriment_idx 7\n", 323 | "Feature Extraction completed in 2m 6s\n", 324 | "target_image_features_list (5000, 2048)\n", 325 | "expriment_idx 8\n", 326 | "Feature Extraction completed in 2m 21s\n", 327 | "target_image_features_list (5000, 2048)\n", 328 | "expriment_idx 9\n", 329 | "Feature Extraction completed in 2m 40s\n", 330 | "target_image_features_list (5000, 2048)\n", 331 | "expriment_idx 10\n", 332 | "Feature Extraction completed in 3m 3s\n" 333 | ] 334 | } 335 | ], 336 | "source": [ 337 | "since = time.time()\n", 338 | "dataloaders = {\n", 339 | " # 'train': coco_train_dataloader, \n", 340 | " 'val': coco_val_dataloader,\n", 341 | "}\n", 342 | "\n", 343 | "\n", 344 | "# Each epoch has a training and validation phase\n", 345 | "for expriment_idx in range(len(deepens_imagenet)):\n", 346 | " phase = 'val'\n", 347 | " target_model = get_random_init_models(checkpoint_tar_name=deepens_imagenet[expriment_idx])\n", 348 | "\n", 349 | " ##################################\n", 350 | " # Fields to be stored for postprocessing \n", 351 | " 
##################################\n", 352 | "\n", 353 | " target_image_features_list = []\n", 354 | "\n", 355 | " # Iterate over data.\n", 356 | " for inputs, captions in dataloaders[phase]:\n", 357 | " image_input = inputs.cuda(non_blocking=True)\n", 358 | " text_input = captions.cuda(non_blocking=True)\n", 359 | " \n", 360 | " with torch.set_grad_enabled(False):\n", 361 | " target_image_features = target_model(image_input).squeeze() \n", 362 | " ##################################\n", 363 | " # Evaluation book-keeping Field \n", 364 | " ##################################\n", 365 | " target_image_features_list.append( target_image_features.cpu().numpy() )\n", 366 | "\n", 367 | " ##################################\n", 368 | " # Evaluation book-keeping Field \n", 369 | " ##################################\n", 370 | " target_image_features_list = np.concatenate( target_image_features_list, axis=0)\n", 371 | " print('target_image_features_list', target_image_features_list.shape)\n", 372 | "\n", 373 | " dump_result_dict = {\n", 374 | " \"target_image_features_list\": target_image_features_list, \n", 375 | " }\n", 376 | " \n", 377 | " feature_dir = 'features200'\n", 378 | " os.makedirs(feature_dir, exist_ok = True) \n", 379 | " with open(os.path.join(feature_dir, 'feature_dump_{}.pkl'.format(expriment_idx) ), \"wb\") as pkl_file:\n", 380 | " pickle.dump(\n", 381 | " dump_result_dict, \n", 382 | " pkl_file, \n", 383 | " )\n", 384 | "\n", 385 | " time_elapsed = time.time() - since\n", 386 | " print('expriment_idx', expriment_idx)\n", 387 | " print('Feature Extraction completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [] 396 | } 397 | ], 398 | "metadata": { 399 | "accelerator": "GPU", 400 | "colab": { 401 | "collapsed_sections": [], 402 | "name": "Interacting with CLIP.ipynb", 403 | "provenance": [] 404 | }, 405 | "kernelspec": { 406 | "display_name": "Python 3", 407 | "name": "python3" 408 | }, 409 | "language_info": { 410 | "codemirror_mode": { 411 | "name": "ipython", 412 | "version": 3 413 | }, 414 | "file_extension": ".py", 415 | "mimetype": "text/x-python", 416 | "name": "python", 417 | "nbconvert_exporter": "python", 418 | "pygments_lexer": "ipython3", 419 | "version": "3.9.7" 420 | }, 421 | "widgets": { 422 | "application/vnd.jupyter.widget-state+json": { 423 | "12e23e2819094ee0a079d4eb77cfc4f9": { 424 | "model_module": "@jupyter-widgets/base", 425 | "model_module_version": "1.2.0", 426 | "model_name": "LayoutModel", 427 | "state": { 428 | "_model_module": "@jupyter-widgets/base", 429 | "_model_module_version": "1.2.0", 430 | "_model_name": "LayoutModel", 431 | "_view_count": null, 432 | "_view_module": "@jupyter-widgets/base", 433 | "_view_module_version": "1.2.0", 434 | "_view_name": "LayoutView", 435 | "align_content": null, 436 | "align_items": null, 437 | "align_self": null, 438 | "border": null, 439 | "bottom": null, 440 | "display": null, 441 | "flex": null, 442 | "flex_flow": null, 443 | "grid_area": null, 444 | "grid_auto_columns": null, 445 | "grid_auto_flow": null, 446 | "grid_auto_rows": null, 447 | "grid_column": null, 448 | "grid_gap": null, 449 | "grid_row": null, 450 | "grid_template_areas": null, 451 | "grid_template_columns": null, 452 | "grid_template_rows": null, 453 | "height": null, 454 | "justify_content": null, 455 | "justify_items": null, 456 | "left": null, 457 | "margin": null, 458 | "max_height": 
null, 459 | "max_width": null, 460 | "min_height": null, 461 | "min_width": null, 462 | "object_fit": null, 463 | "object_position": null, 464 | "order": null, 465 | "overflow": null, 466 | "overflow_x": null, 467 | "overflow_y": null, 468 | "padding": null, 469 | "right": null, 470 | "top": null, 471 | "visibility": null, 472 | "width": null 473 | } 474 | }, 475 | "1369964d45004b5e95a058910b2a33e6": { 476 | "model_module": "@jupyter-widgets/controls", 477 | "model_module_version": "1.5.0", 478 | "model_name": "HBoxModel", 479 | "state": { 480 | "_dom_classes": [], 481 | "_model_module": "@jupyter-widgets/controls", 482 | "_model_module_version": "1.5.0", 483 | "_model_name": "HBoxModel", 484 | "_view_count": null, 485 | "_view_module": "@jupyter-widgets/controls", 486 | "_view_module_version": "1.5.0", 487 | "_view_name": "HBoxView", 488 | "box_style": "", 489 | "children": [ 490 | "IPY_MODEL_7a5f52e56ede4ac3abe37a3ece007dc9", 491 | "IPY_MODEL_ce8b0faa1a1340b5a504d7b3546b3ccb" 492 | ], 493 | "layout": "IPY_MODEL_12e23e2819094ee0a079d4eb77cfc4f9" 494 | } 495 | }, 496 | "161969cae25a49f38aacd1568d3cac6c": { 497 | "model_module": "@jupyter-widgets/base", 498 | "model_module_version": "1.2.0", 499 | "model_name": "LayoutModel", 500 | "state": { 501 | "_model_module": "@jupyter-widgets/base", 502 | "_model_module_version": "1.2.0", 503 | "_model_name": "LayoutModel", 504 | "_view_count": null, 505 | "_view_module": "@jupyter-widgets/base", 506 | "_view_module_version": "1.2.0", 507 | "_view_name": "LayoutView", 508 | "align_content": null, 509 | "align_items": null, 510 | "align_self": null, 511 | "border": null, 512 | "bottom": null, 513 | "display": null, 514 | "flex": null, 515 | "flex_flow": null, 516 | "grid_area": null, 517 | "grid_auto_columns": null, 518 | "grid_auto_flow": null, 519 | "grid_auto_rows": null, 520 | "grid_column": null, 521 | "grid_gap": null, 522 | "grid_row": null, 523 | "grid_template_areas": null, 524 | "grid_template_columns": null, 525 | "grid_template_rows": null, 526 | "height": null, 527 | "justify_content": null, 528 | "justify_items": null, 529 | "left": null, 530 | "margin": null, 531 | "max_height": null, 532 | "max_width": null, 533 | "min_height": null, 534 | "min_width": null, 535 | "object_fit": null, 536 | "object_position": null, 537 | "order": null, 538 | "overflow": null, 539 | "overflow_x": null, 540 | "overflow_y": null, 541 | "padding": null, 542 | "right": null, 543 | "top": null, 544 | "visibility": null, 545 | "width": null 546 | } 547 | }, 548 | "4a61c10fc00c4f04bb00b82e942da210": { 549 | "model_module": "@jupyter-widgets/base", 550 | "model_module_version": "1.2.0", 551 | "model_name": "LayoutModel", 552 | "state": { 553 | "_model_module": "@jupyter-widgets/base", 554 | "_model_module_version": "1.2.0", 555 | "_model_name": "LayoutModel", 556 | "_view_count": null, 557 | "_view_module": "@jupyter-widgets/base", 558 | "_view_module_version": "1.2.0", 559 | "_view_name": "LayoutView", 560 | "align_content": null, 561 | "align_items": null, 562 | "align_self": null, 563 | "border": null, 564 | "bottom": null, 565 | "display": null, 566 | "flex": null, 567 | "flex_flow": null, 568 | "grid_area": null, 569 | "grid_auto_columns": null, 570 | "grid_auto_flow": null, 571 | "grid_auto_rows": null, 572 | "grid_column": null, 573 | "grid_gap": null, 574 | "grid_row": null, 575 | "grid_template_areas": null, 576 | "grid_template_columns": null, 577 | "grid_template_rows": null, 578 | "height": null, 579 | "justify_content": null, 580 | "justify_items": 
null, 581 | "left": null, 582 | "margin": null, 583 | "max_height": null, 584 | "max_width": null, 585 | "min_height": null, 586 | "min_width": null, 587 | "object_fit": null, 588 | "object_position": null, 589 | "order": null, 590 | "overflow": null, 591 | "overflow_x": null, 592 | "overflow_y": null, 593 | "padding": null, 594 | "right": null, 595 | "top": null, 596 | "visibility": null, 597 | "width": null 598 | } 599 | }, 600 | "5e6adc4592124a4581b85f4c1f3bab4d": { 601 | "model_module": "@jupyter-widgets/controls", 602 | "model_module_version": "1.5.0", 603 | "model_name": "ProgressStyleModel", 604 | "state": { 605 | "_model_module": "@jupyter-widgets/controls", 606 | "_model_module_version": "1.5.0", 607 | "_model_name": "ProgressStyleModel", 608 | "_view_count": null, 609 | "_view_module": "@jupyter-widgets/base", 610 | "_view_module_version": "1.2.0", 611 | "_view_name": "StyleView", 612 | "bar_color": null, 613 | "description_width": "initial" 614 | } 615 | }, 616 | "7a5f52e56ede4ac3abe37a3ece007dc9": { 617 | "model_module": "@jupyter-widgets/controls", 618 | "model_module_version": "1.5.0", 619 | "model_name": "FloatProgressModel", 620 | "state": { 621 | "_dom_classes": [], 622 | "_model_module": "@jupyter-widgets/controls", 623 | "_model_module_version": "1.5.0", 624 | "_model_name": "FloatProgressModel", 625 | "_view_count": null, 626 | "_view_module": "@jupyter-widgets/controls", 627 | "_view_module_version": "1.5.0", 628 | "_view_name": "ProgressView", 629 | "bar_style": "success", 630 | "description": "", 631 | "description_tooltip": null, 632 | "layout": "IPY_MODEL_4a61c10fc00c4f04bb00b82e942da210", 633 | "max": 169001437, 634 | "min": 0, 635 | "orientation": "horizontal", 636 | "style": "IPY_MODEL_5e6adc4592124a4581b85f4c1f3bab4d", 637 | "value": 169001437 638 | } 639 | }, 640 | "b597cd6f6cd443aba4bf4491ac7f957e": { 641 | "model_module": "@jupyter-widgets/controls", 642 | "model_module_version": "1.5.0", 643 | "model_name": "DescriptionStyleModel", 644 | "state": { 645 | "_model_module": "@jupyter-widgets/controls", 646 | "_model_module_version": "1.5.0", 647 | "_model_name": "DescriptionStyleModel", 648 | "_view_count": null, 649 | "_view_module": "@jupyter-widgets/base", 650 | "_view_module_version": "1.2.0", 651 | "_view_name": "StyleView", 652 | "description_width": "" 653 | } 654 | }, 655 | "ce8b0faa1a1340b5a504d7b3546b3ccb": { 656 | "model_module": "@jupyter-widgets/controls", 657 | "model_module_version": "1.5.0", 658 | "model_name": "HTMLModel", 659 | "state": { 660 | "_dom_classes": [], 661 | "_model_module": "@jupyter-widgets/controls", 662 | "_model_module_version": "1.5.0", 663 | "_model_name": "HTMLModel", 664 | "_view_count": null, 665 | "_view_module": "@jupyter-widgets/controls", 666 | "_view_module_version": "1.5.0", 667 | "_view_name": "HTMLView", 668 | "description": "", 669 | "description_tooltip": null, 670 | "layout": "IPY_MODEL_161969cae25a49f38aacd1568d3cac6c", 671 | "placeholder": "​", 672 | "style": "IPY_MODEL_b597cd6f6cd443aba4bf4491ac7f957e", 673 | "value": " 169001984/? 
[00:06<00:00, 25734958.25it/s]" 674 | } 675 | } 676 | } 677 | } 678 | }, 679 | "nbformat": 4, 680 | "nbformat_minor": 0 681 | } 682 | -------------------------------------------------------------------------------- /Figure_3_Contrastive_Learning/get_gap_stats.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import argparse\n", 10 | "import os\n", 11 | "import random\n", 12 | "import shutil\n", 13 | "import time\n", 14 | "import warnings\n", 15 | "from enum import Enum\n", 16 | "import pickle\n", 17 | "import numpy as np\n", 18 | "from collections import defaultdict\n", 19 | "\n", 20 | "import torch\n", 21 | "import torch.nn as nn\n", 22 | "import torch.optim\n", 23 | "from torch.utils.data import Dataset, DataLoader\n", 24 | "import torch.backends.cudnn as cudnn\n", 25 | "\n", 26 | "def my_norm(x):\n", 27 | " return x/np.linalg.norm(x, axis=-1, keepdims=True)" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "pickle_path = './features/feature_dump_val.pkl'\n", 37 | "with open(pickle_path, 'rb') as pkl_file:\n", 38 | " data_dict = pickle.load(pkl_file)\n", 39 | " assert len(data_dict['clip_image_features_list']) == len(data_dict['clip_text_features_list'])\n", 40 | " # assert len(data_dict['clip_image_features_list']) == len(data_dict['target_image_features_list'])\n", 41 | " print('Number of image-text pairs', len(data_dict['clip_image_features_list']))" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# Get the gap\n", 51 | "modality_gap = my_norm(my_norm(data_dict['clip_image_features_list']).mean(axis=0) - my_norm(data_dict['clip_text_features_list']).mean(axis=0))\n", 52 | "# # save as a gap vector\n", 53 | "# with open(os.path.join('modality_gap_vector.pkl' ), \"wb\") as pkl_file:\n", 54 | "# pickle.dump(\n", 55 | "# modality_gap, \n", 56 | "# pkl_file, \n", 57 | "# )" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 20, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "modifying_results\n" 70 | ] 71 | }, 72 | { 73 | "data": { 74 | "text/html": [ 75 | "
\n", 76 | "\n", 89 | "\n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | "
distancedelta
01.412827-1.00
11.305014-0.75
21.173114-0.50
31.013107-0.25
40.8221030.00
50.5998460.25
60.3503830.50
70.0838290.75
\n", 140 | "
" 141 | ], 142 | "text/plain": [ 143 | " distance delta\n", 144 | "0 1.412827 -1.00\n", 145 | "1 1.305014 -0.75\n", 146 | "2 1.173114 -0.50\n", 147 | "3 1.013107 -0.25\n", 148 | "4 0.822103 0.00\n", 149 | "5 0.599846 0.25\n", 150 | "6 0.350383 0.50\n", 151 | "7 0.083829 0.75" 152 | ] 153 | }, 154 | "execution_count": 20, 155 | "metadata": {}, 156 | "output_type": "execute_result" 157 | } 158 | ], 159 | "source": [ 160 | "# modifying_results = defaultdict(list)\n", 161 | "\n", 162 | "# for delta in np.arange(-1.0, 1.0, 0.25): \n", 163 | "# modified_text_features = my_norm(data_dict['clip_text_features_list']) + 0.5 * delta * modality_gap\n", 164 | "# modified_text_features = my_norm(modified_text_features)\n", 165 | "\n", 166 | "# modified_image_features = my_norm(data_dict['clip_image_features_list']) - 0.5 * delta * modality_gap\n", 167 | "# modified_image_features = my_norm(modified_image_features)\n", 168 | "\n", 169 | "# distance_sign = np.dot(modality_gap, modified_image_features.mean(axis=0)-modified_text_features.mean(axis=0))\n", 170 | " \n", 171 | "# # Euclidean distance between mass centers\n", 172 | "# modifying_results['distance'].append(\n", 173 | "# np.linalg.norm(\n", 174 | "# modified_image_features.mean(axis=0) - modified_text_features.mean(axis=0)\n", 175 | "# ) * np.sign(distance_sign)\n", 176 | "# )\n", 177 | "# modifying_results['delta'].append(delta)\n", 178 | "\n", 179 | "# import pandas as pd\n", 180 | "# print('modifying_results')\n", 181 | "\n", 182 | "# pd.DataFrame(modifying_results)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 23, 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "name": "stdout", 192 | "output_type": "stream", 193 | "text": [ 194 | "modifying_results\n" 195 | ] 196 | }, 197 | { 198 | "data": { 199 | "text/html": [ 200 | "
\n", 201 | "\n", 214 | "\n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | "
distancedelta
01.287905-1.00
11.224927-0.75
21.134930-0.50
31.005094-0.25
40.8221030.00
50.5842310.25
60.3178480.50
70.0702930.75
\n", 265 | "
" 266 | ], 267 | "text/plain": [ 268 | " distance delta\n", 269 | "0 1.287905 -1.00\n", 270 | "1 1.224927 -0.75\n", 271 | "2 1.134930 -0.50\n", 272 | "3 1.005094 -0.25\n", 273 | "4 0.822103 0.00\n", 274 | "5 0.584231 0.25\n", 275 | "6 0.317848 0.50\n", 276 | "7 0.070293 0.75" 277 | ] 278 | }, 279 | "execution_count": 23, 280 | "metadata": {}, 281 | "output_type": "execute_result" 282 | } 283 | ], 284 | "source": [ 285 | "modifying_results = defaultdict(list)\n", 286 | "\n", 287 | "for delta in np.arange(-1.0, 1.0, 0.05): \n", 288 | " modified_text_features = my_norm(data_dict['clip_text_features_list']) \n", 289 | " modified_text_features = my_norm(modified_text_features)\n", 290 | "\n", 291 | " modified_image_features = my_norm(data_dict['clip_image_features_list']) - 1.0 * delta * modality_gap\n", 292 | " modified_image_features = my_norm(modified_image_features)\n", 293 | "\n", 294 | " distance_sign = np.dot(modality_gap, modified_image_features.mean(axis=0)-modified_text_features.mean(axis=0))\n", 295 | " \n", 296 | " # Euclidean distance between mass centers\n", 297 | " modifying_results['distance'].append(\n", 298 | " np.linalg.norm(\n", 299 | " modified_image_features.mean(axis=0) - modified_text_features.mean(axis=0)\n", 300 | " ) * np.sign(distance_sign)\n", 301 | " )\n", 302 | " modifying_results['delta'].append(delta)\n", 303 | "\n", 304 | "import pandas as pd\n", 305 | "print('modifying_results')\n", 306 | "\n", 307 | "pd.DataFrame(modifying_results)\n" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "# save as a gap vector\n", 317 | "dump_dict = {\n", 318 | " 'modality_gap': modality_gap, \n", 319 | " 'Eucliean distance': modifying_results['distance'], \n", 320 | " 'delta': modifying_results['delta'],\n", 321 | "}\n", 322 | "with open(os.path.join('modality_gap_vector.pkl' ), \"wb\") as pkl_file:\n", 323 | " pickle.dump(\n", 324 | " dump_dict, \n", 325 | " pkl_file, \n", 326 | " )" 327 | ] 328 | } 329 | ], 330 | "metadata": { 331 | "interpreter": { 332 | "hash": "09c077faaa20da841f22e0f4d12b4addb73e00d9291bc78d00732f9f39794f23" 333 | }, 334 | "kernelspec": { 335 | "display_name": "Python 3.9.7 ('clip')", 336 | "language": "python", 337 | "name": "python3" 338 | }, 339 | "language_info": { 340 | "codemirror_mode": { 341 | "name": "ipython", 342 | "version": 3 343 | }, 344 | "file_extension": ".py", 345 | "mimetype": "text/x-python", 346 | "name": "python", 347 | "nbconvert_exporter": "python", 348 | "pygments_lexer": "ipython3", 349 | "version": "3.9.7" 350 | }, 351 | "orig_nbformat": 4 352 | }, 353 | "nbformat": 4, 354 | "nbformat_minor": 2 355 | } 356 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mind the Gap: Understanding the Modality Gap in Multi-modal Contrastive Representation Learning 2 | 3 | [![Website shields.io](https://img.shields.io/website-up-down-green-red/http/shields.io.svg)](https://modalitygap.readthedocs.io) 4 | [![Documentation Status](https://readthedocs.org/projects/modalitygap/badge/?version=latest)](http://modalitygap.readthedocs.io/?badge=latest) 5 | [![MIT license](https://img.shields.io/badge/License-MIT-blue.svg)](https://lbesson.mit-license.org/) 6 | [![OpenReview](https://img.shields.io/badge/OpenReview-S7Evzt9uit3.svg)](https://openreview.net/forum?id=S7Evzt9uit3) 7 | [![Python 3.6](https://img.shields.io/badge/python-3.6-blue.svg)](https://www.python.org/downloads/release/python-360/) 8 | [![Pytorch](https://img.shields.io/badge/Pytorch-1.8-red.svg)](https://shields.io/) 9 | [![Made withJupyter](https://img.shields.io/badge/Made%20with-Jupyter-orange?style=for-the-badge&logo=Jupyter)](https://jupyter.org/try) 10 | 11 | 12 | 13 | This repo provides the PyTorch source code of our paper: 14 | > [Mind the Gap: Understanding the Modality Gap in Multi-modal Contrastive Representation Learning](https://openreview.net/forum?id=S7Evzt9uit3)
15 | > Weixin Liang*, Yuhui Zhang*, Yongchan Kwon*, Serena Yeung, James Zou
16 | > NeurIPS (2022)
17 | > [[PDF]](https://openreview.net/pdf?id=S7Evzt9uit3) 18 | [[Website]](https://modalitygap.readthedocs.io/) 19 | [[Twitter]](https://twitter.com/james_y_zou/status/1503370841957933056) 20 | 21 | 22 | 23 | 24 | 25 | 26 | ## Repo Structure Overview 27 | ```plain 28 | . 29 | ├── README.md 30 | ├── Figure_1_Modality_Gap/ 31 | ├── Figure_2_Cone_Effect/ 32 | ├── Figure_3_Contrastive_Learning/ 33 | ├── Table_1_Implications_CLIP_Zero_Shot/ 34 | ├── Table_2_Implications_CLIP_Fairness/ 35 | ├── util/ 36 | ``` 37 | We organize the code in the order of the figures as presented in the paper. As the folder names indicate, 38 | the `Figure_1_Modality_Gap` folder provides the code for reproducing Figure 1. 39 | 40 | 41 | ## Abstract 42 | *We present modality gap, an intriguing geometric phenomenon of the representation space of multi-modal models. Specifically, we show that different data modalities (e.g. images and text) are embedded at arm's length in their shared representation in multi-modal models such as CLIP. Our systematic analysis demonstrates that this gap is caused by a combination of model initialization and contrastive learning optimization. In model initialization, we show empirically and theoretically that the representation of a common deep neural network is restricted to a narrow cone. As a consequence, in a multi-modal model with two encoders, the representations of the two modalities are clearly apart when the model is initialized. During optimization, contrastive learning keeps the different modalities separate by a certain distance, which is influenced by the temperature parameter in the loss function. Our experiments further demonstrate that varying the modality gap distance has a significant impact in improving the model's downstream zero-shot classification performance and fairness.* 43 | 44 | **TL;DR:** We present modality gap, an intriguing geometric phenomenon of the representation space of multi-modal models. 45 | 46 | ## What is `Modality Gap`? 47 | 48 | 49 | As shown in Figure 1 (b), 50 | CLIP's image embeddings and text embeddings are located in two *completely separate* regions of the embedding space. 51 | We find this phenomenon consistently across various multi-modal models, covering texts, natural images, videos, medical images, and amino-acid sequences. 52 | Interestingly, this phenomenon still holds even when we embed using multi-modal models with random weights (Figure 1 (c)). 53 | 54 | ![](./docs/figures/Figure1.png) 55 | **Figure 1: The pervasive modality gap in multi-modal contrastive representation learning** 56 | 57 | 58 | 59 | 60 | ## How do we explain `Modality Gap`? A three-part explanation. 61 | 62 | While it might seem reasonable to attribute the gap to differences in data distributions or to the different encoder architectures, we showed that these factors are *not* the fundamental cause of the modality gap phenomenon. This paper provides a *three-part explanation* for the modality gap phenomenon. 63 | 64 | 1. The general inductive bias of deep neural architecture creates a cone effect: The effective embedding space is restricted to a narrow cone for pre-trained models or models with random weights. 65 | 66 | 2. Different random initializations create different embedding cones. Since a multi-modal model consists of two encoders, which create different cones at random initialization, this explains how the modality gap is present at initialization. 67 | 68 | 3. The contrastive learning objective commonly used by multi-modal models preserves the gap.
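Throughout the repo, the modality gap is measured as the difference between the centers of the L2-normalized image and text embeddings, and the reported gap distance is the Euclidean norm of that difference (this is how the notebooks in `Figure_3_Contrastive_Learning/` compute it). The snippet below is a minimal NumPy sketch of that measurement; the array names and the random stand-in features are illustrative only, not part of the released code.

```python
import numpy as np

def l2_normalize(x):
    # Project each embedding onto the unit hypersphere.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def modality_gap(image_features, text_features):
    # Gap vector between the two modality centers, and its length (the "gap distance").
    delta = l2_normalize(image_features).mean(axis=0) - l2_normalize(text_features).mean(axis=0)
    return delta, np.linalg.norm(delta)

# Stand-in random features; in practice these come from the paired image/text encoders.
rng = np.random.default_rng(0)
image_features = rng.normal(size=(5000, 512))
text_features = rng.normal(size=(5000, 512)) + 0.5  # artificial offset to mimic a gap
gap_vector, gap_distance = modality_gap(image_features, text_features)
print(f"gap distance: {gap_distance:.3f}")
```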
69 | 70 | 71 | 72 | 73 | ### Part 1: The `Cone Effect` Induces A Modality Gap 74 | 75 | ![](./docs/figures/Figure2ab.png) 76 | **Figure 2 (a,b): The Cone Effect Induces A Modality Gap.** 77 | 78 | #### The cosine similarity between all pairs of embeddings (last-layer feature) 79 | We extract 5,000 embeddings from the final layer of ResNet, Vision Transformer, and Text Transformer, respectively, on MSCOCO Caption. 80 | We then compute the cosine similarity between all possible pairs of the 5,000 embeddings within each model. 81 | 82 | The average cosine similarity is substantially larger than 0, indicating that the embedding space is a narrow cone. 83 | The cone effect also holds for randomly initialized models and for random noise inputs. 84 | 85 | 86 | 87 | #### Effects of nonlinear activation and depth. 88 | As shown in Figure 2 (b), MLPs without non-linear 89 | activation show little cone effect. However, with non-linearity, the average cosine similarity increases rapidly as the number of layers increases. These results indicate that the non-linear activation functions play a crucial role in the cone effect. 90 | 91 | 92 | ### Part 2: Different random initializations create different cones 93 | 94 | ![](./docs/figures/Figure2c.png) 95 | **Figure 2 (c): Different random initializations create different cones.** 96 | 97 | 98 | We randomly initialized a model 25 times, and plotted its extracted embeddings on the same real data via UMAP visualization. We found that each random initialization forms a distinctively different cone. 99 | Since a multi-modal model consists of two encoders, which create different cones at random initialization, this explains how the modality gap is present at initialization. 100 | 101 | 102 | ### Theoretical Analysis: 103 | #### Part 1: The `Cone Effect` Induces A Modality Gap 104 | 105 | 

106 | 107 |

108 | 109 | **Theorem 1:** 110 | Our theoretical analysis shows that under mild assumptions, each neural network layer shrinks the angle between any pair of embedding vectors with high probability, thereby creating more narrow cones in deeper architectures. Here $\phi$ is the ReLU activation function. 111 | 112 | #### Part 2: Different random initializations create different cones 113 | 114 | 115 |

116 | 117 | 118 |

119 | 120 | **Theorem 2:** 121 | We further prove that different random initializations of model weights result in different cones. 122 | More specifically, the variance of an intermediate output mostly comes from the model’s random initialization. 123 | 124 | 125 | 126 | ### Part 3: `Contrastive learning` preserves modality gap 127 | 128 | ![](./docs/figures/Figure3.jpg) 129 | **Figure 3: Contrastive learning preserves modality gap.** 130 | 131 | We hypothesize that the contrastive learning objective encourages the existence of the modality gap. To test this hypothesis, we manually shift CLIP's image embeddings and text embeddings towards closing the gap. 132 | 133 | We found that under CLIP's default temperature $\tau=\frac{1}{100}$, the default gap distance $\| \vec{\Delta}_\text{gap} \|=0.82$ actually achieves the global minimum, and shifting toward closing the gap *increases* the contrastive loss. 134 | 135 | However, when the temperature increases (Figure 3 (c,d)), the repulsive structure and the local minimum gradually disappear, and closing the gap becomes more favorable. 136 | 137 | Together, these results show that contrastive learning keeps the different modalities separate by a certain distance, which is influenced by the temperature parameter in the loss function. 138 | 139 | 140 | ### Modality Gap Implications 141 | 142 | Interestingly, by simply modifying the modality gap's distance, we can improve CLIP's zero-shot performance (Table 1) and fairness (Table 2). 143 | 144 | 145 | ![](./docs/figures/Tables.png) 146 | **Modality Gap Implications: Experiment Results** 147 | 148 | #### Table 1: Zero-shot Performance 149 | One of the most interesting capabilities of CLIP is its strong zero-shot transferability to a variety of downstream tasks without any supervision. 150 | We found that modifying the modality gap can improve zero-shot performance on multiple downstream tasks. 151 | 152 | #### Table 2: Fairness 153 | We found that increasing the gap from $0.82$ to $0.97$ *reduces* denigration harms consistently for *all* races. 154 | Meanwhile, we observe only a minor $0.0008$ top-1 accuracy drop. 155 | It is encouraging that a simple gap offsetting approach can lead to a consistent bias reduction across all races on such a complex model (i.e., CLIP). 
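The gap offsetting used above is a purely post-hoc operation: compute the unit-length gap direction, translate each modality along it, and re-normalize the embeddings back onto the unit hypersphere, as done in `Table_1_Implications_CLIP_Zero_Shot/shifting/shift_features.ipynb` and `Figure_3_Contrastive_Learning/get_gap_stats.ipynb`. Below is a minimal NumPy sketch; the array names, the synthetic features, and the delta grid are illustrative only, and in the experiments the downstream metrics are re-evaluated on the shifted features at each delta.

```python
import numpy as np

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def shift_along_gap(image_features, text_features, delta):
    # Unit vector pointing from the text center to the image center.
    gap = image_features.mean(axis=0) - text_features.mean(axis=0)
    gap = gap / np.linalg.norm(gap)
    # Positive delta moves the two modalities toward each other (closing the gap);
    # negative delta pushes them further apart. Re-normalize to stay on the sphere.
    shifted_images = l2_normalize(image_features - 0.5 * delta * gap)
    shifted_texts = l2_normalize(text_features + 0.5 * delta * gap)
    return shifted_images, shifted_texts

# Synthetic stand-in embeddings; in the experiments these are CLIP features.
rng = np.random.default_rng(0)
image_features = l2_normalize(rng.normal(size=(1000, 512)))
text_features = l2_normalize(rng.normal(size=(1000, 512)) + 0.5)

for delta in np.arange(-1.0, 1.01, 0.25):
    imgs, txts = shift_along_gap(image_features, text_features, delta)
    distance = np.linalg.norm(imgs.mean(axis=0) - txts.mean(axis=0))
    print(f"delta={delta:+.2f}  gap distance={distance:.3f}")
```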
156 | 157 | 158 | ## Citation 159 | ``` 160 | @inproceedings{ 161 | ModalityGap, 162 | title={Mind the Gap: Understanding the Modality Gap in Multi-modal Contrastive Representation Learning}, 163 | author = {Weixin Liang and 164 | Yuhui Zhang and 165 | Yongchan Kwon and 166 | Serena Yeung and 167 | James Zou}, 168 | booktitle={NeurIPS}, 169 | year={2022}, 170 | url={https://openreview.net/forum?id=S7Evzt9uit3} 171 | } 172 | ``` 173 | 174 | ![image](https://user-images.githubusercontent.com/32794044/192175707-93b48fd3-ef06-465e-a43b-ff6a82503a66.png) 175 | -------------------------------------------------------------------------------- /Table_1_Implications_CLIP_Zero_Shot/shifting/shift_features.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Most commonly used\n", 10 | "import sys\n", 11 | "import os\n", 12 | "import json\n", 13 | "import pickle\n", 14 | "import math\n", 15 | "from collections import Counter, defaultdict\n", 16 | "from functools import partial\n", 17 | "from tqdm import tqdm, trange\n", 18 | "from colors import blue, red, green, cyan\n", 19 | "\n", 20 | "# Numerical computation\n", 21 | "import numpy as np\n", 22 | "import torch\n", 23 | "import torch.nn.functional as F\n", 24 | "\n", 25 | "# Visualization\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "import seaborn as sns\n", 28 | "sns.set_theme()\n", 29 | "sns.set_context(\"talk\")\n", 30 | "\n", 31 | "sys.path.append('ANONYMOUS_ROOTDIR/develop/open-world/')\n", 32 | "from utils import svd, reduce_and_visualize, load_clip, encode_clip, encode_clip_classification, train_clip_toy, ce_loss, uniform_loss, dual_ce_loss, simple_ce_loss\n", 33 | "from datasets import ImageCaptionDataset, ClassificationDataset\n", 34 | "\n", 35 | "\n", 36 | "def evaluate_retrieval(image_features, text_features):\n", 37 | " metrics = {}\n", 38 | " sim = image_features @ text_features.T\n", 39 | " for K in [1, 5, 10]:\n", 40 | " pred = sim.argsort(dim=-1)\n", 41 | " text_r = np.mean([i in pred[i, -K:] for i in range(len(pred))])\n", 42 | "\n", 43 | " pred = sim.argsort(dim=0)\n", 44 | " image_r = np.mean([i in pred[-K:, i] for i in range(len(pred))])\n", 45 | "\n", 46 | " metrics[f'Text R@{K}'] = text_r\n", 47 | " metrics[f'Image R@{K}'] = image_r\n", 48 | " return metrics\n", 49 | "\n", 50 | "\n", 51 | "def evaluate_classification(image_features, text_features, labels):\n", 52 | " metrics = {}\n", 53 | " sim = image_features @ text_features.T\n", 54 | " for K in [1, 5, 10]:\n", 55 | " pred = sim.argsort(dim=-1)\n", 56 | " text_r = np.mean([labels[i] in pred[i, -K:] for i in range(len(pred))])\n", 57 | " metrics[f'Hit@{K}'] = text_r\n", 58 | " return metrics\n", 59 | "\n", 60 | "\n", 61 | "def evaluate_binary_classification(image_features, text_features, labels):\n", 62 | " from sklearn.metrics import roc_auc_score\n", 63 | " metrics = {}\n", 64 | " sim = image_features @ text_features.T * 100\n", 65 | " probs = F.softmax(sim, dim=-1)[:, 1]\n", 66 | " roc_auc = roc_auc_score(labels, probs)\n", 67 | " metrics[f'ROC-AUC'] = roc_auc\n", 68 | " return metrics\n", 69 | "\n", 70 | "\n", 71 | "def move_features(image_features, text_features, evaluate_func, direction_vec=None):\n", 72 | " all_metrics = {}\n", 73 | " if direction_vec is None:\n", 74 | " modality_gap = image_features.mean(axis=0) - text_features.mean(axis=0)\n", 75 | " modality_gap = modality_gap / 
modality_gap.norm()\n", 76 | " direction_vec = modality_gap\n", 77 | " \n", 78 | " for delta in np.arange(-5, 5, 0.25):\n", 79 | " modified_text_features = text_features + 0.5 * delta * direction_vec\n", 80 | " modified_text_features /= modified_text_features.norm(dim=-1, keepdim=True)\n", 81 | "\n", 82 | " modified_image_features = image_features - 0.5 * delta * direction_vec\n", 83 | " modified_image_features /= modified_image_features.norm(dim=-1, keepdim=True)\n", 84 | "\n", 85 | " # reduce_and_visualize(modified_image_features.numpy(), modified_text_features.numpy(), methods=['svd', 'pca'], n_dim=2)\n", 86 | "\n", 87 | " preds = (modified_image_features @ modified_text_features.T).argmax(dim=-1)\n", 88 | "\n", 89 | " gap_distance = (modified_text_features.mean(axis=0) - modified_image_features.mean(axis=0)).norm().item()\n", 90 | "\n", 91 | " metrics = evaluate_func(modified_image_features, modified_text_features)\n", 92 | " all_metrics[delta] = (metrics, gap_distance, preds)\n", 93 | "\n", 94 | " print(delta, metrics, gap_distance)\n", 95 | " return all_metrics\n", 96 | "\n", 97 | "\n", 98 | "def move_features_along_hypersphere(image_features, text_features, evaluate_func):\n", 99 | " return \"Impossible\"\n", 100 | "\n", 101 | "\n", 102 | "def plot_metrics(all_metrics, metric_name='Hit@1'):\n", 103 | " xs, ys = [], []\n", 104 | " for delta in sorted(all_metrics.keys()):\n", 105 | " metrics, gap_distance, preds = all_metrics[delta]\n", 106 | " xs.append(gap_distance)\n", 107 | " ys.append(metrics[metric_name])\n", 108 | " print(f'Optimal {metric_name}: {max(ys)}')\n", 109 | "\n", 110 | " minidx = xs.index(min(xs))\n", 111 | " for i in range(minidx + 1, len(xs)): xs[i] = -xs[i]\n", 112 | " plt.plot(xs, ys, 'o-')\n", 113 | " plt.xlabel('Gap Distance')\n", 114 | " plt.ylabel(metric_name)\n", 115 | "\n", 116 | " initial_gap = all_metrics[0][1]\n", 117 | " plt.axvline(initial_gap, color='k', linestyle='--')\n", 118 | "\n", 119 | " plt.show()" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "# Move features along direction computed on downstream tasks\n", 129 | "\n", 130 | "model = load_clip()\n", 131 | "dataset = ClassificationDataset(name='EuroSAT')\n", 132 | "image_features, text_features = encode_clip_classification(model, dataset, prompt='a centered satellite photo of {}.')\n", 133 | "labels = [item[1] for item in dataset]\n", 134 | "metrics = evaluate_classification(image_features, text_features, labels)\n", 135 | "print(metrics)\n", 136 | "\n", 137 | "reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca', 'tsne', 'umap'], n_dim=2)\n", 138 | "\n", 139 | "all_metrics = move_features(image_features, text_features, partial(evaluate_classification, labels=labels))\n", 140 | "plot_metrics(all_metrics, metric_name='Hit@1')" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "# Move features along direction computed on MSCOCO\n", 150 | "\n", 151 | "model = load_clip()\n", 152 | "\n", 153 | "dataset = ImageCaptionDataset()\n", 154 | "image_features, text_features = encode_clip(model, dataset)\n", 155 | "direction_vec = image_features.mean(axis=0) - text_features.mean(axis=0)\n", 156 | "direction_vec = direction_vec / direction_vec.norm()\n", 157 | "\n", 158 | "dataset = ClassificationDataset(name='SVHN')\n", 159 | "image_features, text_features = 
encode_clip_classification(model, dataset, prompt='a street sign of the number: \"{}\".')\n", 160 | "labels = [item[1] for item in dataset]\n", 161 | "metrics = evaluate_classification(image_features, text_features, labels)\n", 162 | "print(metrics)\n", 163 | "\n", 164 | "reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca', 'tsne', 'umap'], n_dim=2)\n", 165 | "\n", 166 | "all_metrics = move_features(image_features, text_features, partial(evaluate_classification, labels=labels), direction_vec)\n", 167 | "plot_metrics(all_metrics, metric_name='Hit@1')" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "# Retrieval\n", 177 | "\n", 178 | "model = load_clip()\n", 179 | "dataset = ImageCaptionDataset()\n", 180 | "image_features, text_features = encode_clip(model, dataset)\n", 181 | "metrics = evaluate_retrieval(image_features, text_features)\n", 182 | "print(metrics)\n", 183 | "reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca', 'tsne', 'umap'], n_dim=2)\n", 184 | "\n", 185 | "all_metrics = move_features(image_features, text_features, evaluate_retrieval)\n", 186 | "plot_metrics(all_metrics, metric_name='Image R@1')\n", 187 | "plot_metrics(all_metrics, metric_name='Text R@1')" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "# Fine-tuning CLIP" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "dataset = ImageCaptionDataset(split='train', max_data_size=50000)\n", 204 | "model = load_clip()\n", 205 | "model.logit_scale.data = torch.log(torch.tensor(100))\n", 206 | "logs, model = train_clip_toy(model, dataset, f'ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_refactor_t100/', batch_size=64, end_epoch=5)" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "dataset = ImageCaptionDataset()\n", 216 | "model = load_clip()\n", 217 | "logs, model = train_clip_toy(model, dataset, f'ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_uniform_refactor/', loss_funcs=[ce_loss, uniform_loss])" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "dataset = ImageCaptionDataset()\n", 227 | "model = load_clip()\n", 228 | "logs, model = train_clip_toy(model, dataset, f'ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_dualloss_refactor/', loss_funcs=[dual_ce_loss])" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "dataset = ImageCaptionDataset()\n", 238 | "model = load_clip()\n", 239 | "logs, model = train_clip_toy(model, dataset, f'ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_removehard_refactor/', loss_funcs=[simple_ce_loss])" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "# Downstream Task using Fine-tuned Models" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "model = load_clip('ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_refactor_t30/model_epoch_1.pt')\n", 256 | "# dataset = 
ImageCaptionDataset(split='train', max_data_size=50000)\n", 257 | "# dataset.data = dataset.data[:500]\n", 258 | "dataset = ImageCaptionDataset(split='val')\n", 259 | "image_features, text_features = encode_clip(model, dataset)\n", 260 | "feature_dist = (image_features.mean(axis=0) - text_features.mean(axis=0)).norm().item()\n", 261 | "print(feature_dist)\n", 262 | "metrics = evaluate_retrieval(image_features, text_features)\n", 263 | "print(metrics)\n", 264 | "reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca'], n_dim=2)\n" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "model = load_clip('ANONYMOUS_ROOTDIR/develop/open-world/exps/pretrained_512d_refactor_t30/model_epoch_1.pt')\n", 274 | "dataset = ClassificationDataset(name='CIFAR10')\n", 275 | "image_features, text_features = encode_clip_classification(model, dataset)\n", 276 | "feature_dist = (image_features.mean(axis=0) - text_features.mean(axis=0)).norm().item()\n", 277 | "print(feature_dist)\n", 278 | "labels = [item[1] for item in dataset]\n", 279 | "metrics = evaluate_classification(image_features, text_features, labels)\n", 280 | "print(metrics)\n", 281 | "reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca'], n_dim=2)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "plt.figure()\n", 291 | "gaps = [0.2384, 0.3028, 0.5524, 0.6352, 0.7961, 1.0006]\n", 292 | "acc_i = [0.0214, 0.1896, 0.1772, 0.2048, 0.2090, 0.1836]\n", 293 | "acc_t = [0.0170, 0.1660, 0.1740, 0.2098, 0.2036, 0.1894]\n", 294 | "xs = [1, 1/10, 1/20, 1/30, 1/50, 1/100]\n", 295 | "plt.plot(xs, gaps, 'o-', label='Gap')\n", 296 | "plt.plot(xs, acc_i, 'o-', label='Image R@1')\n", 297 | "plt.plot(xs, acc_t, 'o-', label='Text R@1')\n", 298 | "# plt a line at y=0.8262\n", 299 | "plt.axhline(y=0.8262, color='k', linestyle='--')\n", 300 | "plt.legend()\n", 301 | "plt.xlabel('Temperature')\n", 302 | "\n", 303 | "plt.figure()\n", 304 | "gaps = [0.9407, 0.6450, 0.8455, 0.9346, 1.0092, 1.1241]\n", 305 | "acc = [0.1918, 0.5036, 0.4525, 0.4544, 0.5065, 0.3348]\n", 306 | "plt.plot(xs, gaps, 'o-', label='Gap')\n", 307 | "plt.plot(xs, acc, 'o-', label='Acc')\n", 308 | "plt.axhline(y=1.1136, color='k', linestyle='--')\n", 309 | "plt.legend()\n", 310 | "plt.xlabel('Temperature')" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": {}, 316 | "source": [ 317 | "# Gap vs Prediction Overlap" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "print((preds1 == preds2).float().mean())\n", 327 | "sim1 = image_features1 @ text_features1.t()\n", 328 | "sim2 = image_features2 @ text_features2.t()\n", 329 | "\n", 330 | "overlaps = []\n", 331 | "for idx in range(len(sim1)):\n", 332 | " top_preds1 = sim1[idx].argsort().tolist()[::-1][:5]\n", 333 | " for pred in top_preds1: print(dataset.data[pred])\n", 334 | " print()\n", 335 | " top_preds2 = sim2[idx].argsort().tolist()[::-1][:5]\n", 336 | " for pred in top_preds2: print(dataset.data[pred])\n", 337 | " overlap = len(set(top_preds1) & set(top_preds2)) / len(set(top_preds1) | set(top_preds2))\n", 338 | " overlaps.append(overlap)\n", 339 | " break\n", 340 | "\n", 341 | "# print(np.mean(overlaps))" 342 | ] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | 
"metadata": {}, 347 | "source": [ 348 | "# Fix initialization" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [ 357 | "model = load_clip('ANONYMOUS_ROOTDIR/develop/open-world/exps/random_t100/model_epoch_1.pt')\n", 358 | "dataset = ImageCaptionDataset(split='train', max_data_size=50000)\n", 359 | "dataset.data = dataset.data[:500]\n", 360 | "# dataset = ImageCaptionDataset(split='val')\n", 361 | "image_features, text_features = encode_clip(model, dataset)\n", 362 | "feature_dist = (image_features.mean(axis=0) - text_features.mean(axis=0)).norm().item()\n", 363 | "print(feature_dist)\n", 364 | "metrics = evaluate_retrieval(image_features, text_features)\n", 365 | "print(metrics)\n", 366 | "reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca'], n_dim=2)\n" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": null, 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "model = load_clip('ANONYMOUS_ROOTDIR/develop/open-world/exps/random_t100_fix_init/model_epoch_1.pt')\n", 376 | "dataset = ImageCaptionDataset(split='train', max_data_size=50000)\n", 377 | "dataset.data = dataset.data[:500]\n", 378 | "w, _, _ = torch.load('ANONYMOUS_ROOTDIR/develop/open-world/exps/random_t100_fix_init/w.pt')\n", 379 | "# dataset = ImageCaptionDataset(split='val')\n", 380 | "image_features, text_features = encode_clip(model, dataset)\n", 381 | "text_features = text_features @ w.T\n", 382 | "feature_dist = (image_features.mean(axis=0) - text_features.mean(axis=0)).norm().item()\n", 383 | "print(feature_dist)\n", 384 | "metrics = evaluate_retrieval(image_features, text_features)\n", 385 | "print(metrics)\n", 386 | "reduce_and_visualize(image_features.numpy(), text_features.numpy(), methods=['svd', 'pca'], n_dim=2)\n" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [] 395 | } 396 | ], 397 | "metadata": { 398 | "kernelspec": { 399 | "display_name": "Python 3 (ipykernel)", 400 | "language": "python", 401 | "name": "python3" 402 | }, 403 | "language_info": { 404 | "codemirror_mode": { 405 | "name": "ipython", 406 | "version": 3 407 | }, 408 | "file_extension": ".py", 409 | "mimetype": "text/x-python", 410 | "name": "python", 411 | "nbconvert_exporter": "python", 412 | "pygments_lexer": "ipython3", 413 | "version": "3.8.12" 414 | } 415 | }, 416 | "nbformat": 4, 417 | "nbformat_minor": 4 418 | } 419 | -------------------------------------------------------------------------------- /Table_1_Implications_CLIP_Zero_Shot/training/datasets.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | import pickle 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import Dataset, DataLoader 11 | 12 | sys.path.append("ANONYMOUS_ROOTDIR/develop/open-world/vision") 13 | from torchvision import transforms 14 | from torchvision.datasets import ( 15 | CIFAR10, 16 | CIFAR100, 17 | MNIST, 18 | ImageNet, 19 | SVHN, 20 | Flowers102, 21 | EuroSAT, 22 | ImageFolder, 23 | ) 24 | 25 | sys.path.append("ANONYMOUS_ROOTDIR/develop/open-world/CLIP") 26 | import clip 27 | 28 | 29 | def get_default_transform(): 30 | return clip.load("ViT-B/16", device="cpu")[1] 31 | 32 | 33 | def get_default_tokenizer(): 34 | return clip.tokenize 35 | 36 
| 37 | class ImageCaptionDataset(Dataset): 38 | def __init__( 39 | self, 40 | base_dir="ANONYMOUS_ROOTDIR/data/COCO/", 41 | split="val", 42 | transform=None, 43 | tokenizer=None, 44 | max_data_size=None, 45 | ): 46 | self.base_dir = base_dir 47 | self.split = split 48 | self.transform = transform if transform is not None else get_default_transform() 49 | self.tokenizer = tokenizer if tokenizer is not None else get_default_tokenizer() 50 | self.max_data_size = max_data_size 51 | 52 | data = json.load(open(f"{self.base_dir}/annotations/captions_{split}2017.json")) 53 | id2file = {item["id"]: item["coco_url"] for item in data["images"]} 54 | id2caption = {item["image_id"]: item["caption"] for item in data["annotations"]} 55 | self.data = [ 56 | (id2file[id].replace("http://images.cocodataset.org/", ""), id2caption[id]) 57 | for id in id2caption 58 | ] 59 | 60 | if self.max_data_size is not None: 61 | np.random.seed(1234) 62 | indices = np.random.choice( 63 | len(self.data), size=max_data_size, replace=False 64 | ) 65 | self.data = [self.data[i] for i in indices] 66 | 67 | def __len__(self): 68 | return len(self.data) 69 | 70 | def __getitem__(self, idx): 71 | filename, caption = self.data[idx] 72 | im = Image.open(f"{self.base_dir}/{filename}") 73 | image_input = self.transform(im) 74 | if len(image_input.shape) == 4: 75 | image_input = image_input[0] 76 | text_input = self.tokenizer(caption) 77 | if len(text_input.shape) == 2: 78 | text_input = text_input[0] 79 | return image_input, text_input 80 | 81 | @staticmethod 82 | def collate_fn(batch): 83 | images, texts = zip(*batch) 84 | images = torch.stack(images, dim=0) 85 | texts = torch.stack(texts, dim=0) 86 | return images, texts 87 | 88 | 89 | class ClassificationDataset(Dataset): 90 | def __init__(self, name="CIFAR100", transform=None, max_data_size=None): 91 | self.name = name 92 | self.transform = transform if transform is not None else get_default_transform() 93 | self.max_data_size = max_data_size 94 | 95 | if self.name in ["CIFAR100", "CIFAR10", "MNIST"]: 96 | self.data = eval(name)( 97 | root=os.path.expanduser("~/.cache"), download=True, train=False 98 | ) 99 | elif self.name in ["ImageNet"]: 100 | self.data = eval(name)(root=os.path.expanduser("~/.cache"), split="val") 101 | # if self.name == "ImageNet": 102 | # self.data.classes = json.load(open(os.path.expanduser('~/.cache/imagenet_classes.json'))) 103 | elif self.name in ["ImageNetSketch", "HateSpeechMeme"]: 104 | self.data = ImageFolder( 105 | root=f"ANONYMOUS_ROOTDIR/data/{self.name}/imagefolder" 106 | ) 107 | if self.name == "ImageNetSketch": 108 | lines = [ 109 | line.strip().split() 110 | for line in open( 111 | f"ANONYMOUS_ROOTDIR/data/{self.name}/map_clsloc.txt" 112 | ) 113 | ] 114 | for line in lines: 115 | assert len(line) == 3 116 | mapping = {line[0]: line[2].replace("_", " ") for line in lines} 117 | self.data.classes = [mapping[id] for id in self.data.classes] 118 | elif self.name in ["SVHN", "Flowers102"]: 119 | self.data = eval(name)( 120 | root=os.path.expanduser("~/.cache"), download=True, split="test" 121 | ) 122 | if self.name == "SVHN": 123 | self.data.classes = [i for i in range(10)] 124 | elif self.name == "Flowers102": 125 | self.data.classes = json.load( 126 | open(os.path.expanduser("~/.cache/flowers-102/mapping.json")) 127 | ) 128 | self.data._labels = [i - 1 for i in self.data._labels] 129 | elif self.name in ["EuroSAT"]: 130 | self.data = eval(name)(root=os.path.expanduser("~/.cache"), download=True) 131 | if self.name == "EuroSAT": 132 | 
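                # The human-readable class names are read from a local
                # ~/.cache/eurosat/mapping.json (an external file this script
                # assumes exists, as with Flowers102 above), presumably so the
                # zero-shot prompt template receives prompt-friendly labels.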
self.data.classes = json.load( 133 | open(os.path.expanduser("~/.cache/eurosat/mapping.json")) 134 | ) 135 | else: 136 | raise ValueError(f"Unknown dataset: {self.name}") 137 | 138 | if self.max_data_size is not None: 139 | raise NotImplementedError 140 | 141 | def __len__(self): 142 | return len(self.data) 143 | 144 | def __getitem__(self, idx): 145 | im, label = self.data[idx] 146 | image_input = self.transform(im) 147 | if len(image_input.shape) == 4: 148 | image_input = image_input[0] 149 | return image_input, label 150 | 151 | @staticmethod 152 | def collate_fn(batch): 153 | images, labels = zip(*batch) 154 | images = torch.stack(images, dim=0) 155 | labels = torch.tensor(labels) 156 | return images, labels 157 | 158 | 159 | if __name__ == "__main__": 160 | dataset = ImageCaptionDataset() 161 | print(dataset[0]) 162 | dataset = ClassificationDataset() 163 | print(dataset[0]) 164 | -------------------------------------------------------------------------------- /Table_1_Implications_CLIP_Zero_Shot/training/train_clip.py: -------------------------------------------------------------------------------- 1 | # Most commonly used 2 | import sys 3 | import os 4 | import json 5 | import pickle 6 | import math 7 | from collections import Counter, defaultdict 8 | from functools import partial 9 | from tqdm import tqdm, trange 10 | from colors import blue, red, green, cyan 11 | 12 | # Numerical computation 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | # Visualization 18 | import matplotlib.pyplot as plt 19 | import seaborn as sns 20 | 21 | sns.set_theme() 22 | sns.set_context("talk") 23 | 24 | sys.path.append("ANONYMOUS_ROOTDIR/develop/open-world/") 25 | from utils import ( 26 | svd, 27 | reduce_and_visualize, 28 | load_clip, 29 | encode_clip, 30 | encode_clip_classification, 31 | train_clip_toy, 32 | train_clip_toy_fix_init, 33 | ce_loss, 34 | uniform_loss, 35 | dual_ce_loss, 36 | simple_ce_loss, 37 | ) 38 | from datasets import ImageCaptionDataset, ClassificationDataset 39 | 40 | 41 | def evaluate_retrieval(image_features, text_features): 42 | metrics = {} 43 | sim = image_features @ text_features.T 44 | for K in [1, 5, 10]: 45 | pred = sim.argsort(dim=-1) 46 | text_r = np.mean([i in pred[i, -K:] for i in range(len(pred))]) 47 | 48 | pred = sim.argsort(dim=0) 49 | image_r = np.mean([i in pred[-K:, i] for i in range(len(pred))]) 50 | 51 | metrics[f"Text R@{K}"] = text_r 52 | metrics[f"Image R@{K}"] = image_r 53 | return metrics 54 | 55 | 56 | def evaluate_classification(image_features, text_features, labels): 57 | metrics = {} 58 | sim = image_features @ text_features.T 59 | for K in [1, 5, 10]: 60 | pred = sim.argsort(dim=-1) 61 | text_r = np.mean([labels[i] in pred[i, -K:] for i in range(len(pred))]) 62 | metrics[f"Hit@{K}"] = text_r 63 | return metrics 64 | 65 | 66 | def evaluate_binary_classification(image_features, text_features, labels): 67 | from sklearn.metrics import roc_auc_score 68 | 69 | metrics = {} 70 | sim = image_features @ text_features.T * 100 71 | probs = F.softmax(sim, dim=-1)[:, 1] 72 | roc_auc = roc_auc_score(labels, probs) 73 | metrics[f"ROC-AUC"] = roc_auc 74 | return metrics 75 | 76 | 77 | def move_features(image_features, text_features, evaluate_func): 78 | all_metrics = {} 79 | 80 | modality_gap = image_features.mean(axis=0) - text_features.mean(axis=0) 81 | modality_gap = modality_gap / modality_gap.norm() 82 | modality_gap.unsqueeze(0) 83 | 84 | for delta in np.arange(-5, 5, 0.25): 85 | modified_text_features = 
text_features + 0.5 * delta * modality_gap 86 | modified_text_features /= modified_text_features.norm(dim=-1, keepdim=True) 87 | 88 | modified_image_features = image_features - 0.5 * delta * modality_gap 89 | modified_image_features /= modified_image_features.norm(dim=-1, keepdim=True) 90 | 91 | # reduce_and_visualize(modified_image_features.numpy(), modified_text_features.numpy(), methods=['svd', 'pca'], n_dim=2) 92 | 93 | preds = (modified_image_features @ modified_text_features.T).argmax(dim=-1) 94 | 95 | gap_distance = ( 96 | (modified_text_features.mean(axis=0) - modified_image_features.mean(axis=0)) 97 | .norm() 98 | .item() 99 | ) 100 | 101 | metrics = evaluate_func(modified_image_features, modified_text_features) 102 | all_metrics[delta] = (metrics, gap_distance, preds) 103 | 104 | print(delta, metrics, gap_distance) 105 | return all_metrics 106 | 107 | 108 | def move_features_along_hypersphere(image_features, text_features, evaluate_func): 109 | return "Impossible" 110 | 111 | 112 | def plot_metrics(all_metrics, metric_name="Hit@1"): 113 | xs, ys = [], [] 114 | for delta in sorted(all_metrics.keys()): 115 | metrics, gap_distance, preds = all_metrics[delta] 116 | xs.append(gap_distance) 117 | ys.append(metrics[metric_name]) 118 | print(f"Optimal {metric_name}: {max(ys)}") 119 | 120 | minidx = xs.index(min(xs)) 121 | for i in range(minidx + 1, len(xs)): 122 | xs[i] = -xs[i] 123 | plt.plot(xs, ys, "o-") 124 | plt.xlabel("Gap Distance") 125 | plt.ylabel(metric_name) 126 | 127 | initial_gap = all_metrics[0][1] 128 | plt.axvline(initial_gap, color="k", linestyle="--") 129 | 130 | plt.show() 131 | 132 | 133 | if __name__ == "__main__": 134 | temperature = int(sys.argv[1]) 135 | print(f"Temperature: {temperature}") 136 | dataset = ImageCaptionDataset(split="train", max_data_size=50000) 137 | model = load_clip("random") 138 | model.logit_scale.data = torch.log(torch.tensor(temperature)) 139 | logs, model = train_clip_toy( 140 | model, 141 | dataset, 142 | f"ANONYMOUS_ROOTDIR/develop/open-world/exps/random_t{temperature}_2/", 143 | batch_size=64, 144 | end_epoch=5, 145 | ) 146 | logs, model = train_clip_toy_fix_init( 147 | model, 148 | dataset, 149 | f"ANONYMOUS_ROOTDIR/develop/open-world/exps/random_t{temperature}_fix_init/", 150 | batch_size=64, 151 | end_epoch=5, 152 | ) 153 | -------------------------------------------------------------------------------- /Table_1_Implications_CLIP_Zero_Shot/training/utils.py: -------------------------------------------------------------------------------- 1 | # Most commonly used 2 | import sys 3 | import os 4 | import json 5 | import pickle 6 | import math 7 | from collections import Counter, defaultdict 8 | from functools import partial 9 | from tqdm import tqdm, trange 10 | from colors import blue, red, green, cyan 11 | 12 | # Numerical computation 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | # Visualization 18 | from matplotlib import pyplot as plt 19 | from mpl_toolkits.mplot3d import Axes3D 20 | from sklearn.decomposition import PCA 21 | from sklearn.manifold import TSNE 22 | from umap.umap_ import UMAP 23 | from sklearn.cluster import KMeans 24 | 25 | # Density estimation 26 | sys.path.append("ANONYMOUS_ROOTDIR/develop/open-world/vonmiseskde") 27 | from vonmiseskde import VonMisesKDE 28 | from sklearn.neighbors import KernelDensity 29 | 30 | # Image processing 31 | from PIL import Image 32 | from torchvision import transforms 33 | from torch.utils.data import Dataset, DataLoader 34 | 35 | # Multimodal 
model 36 | sys.path.append("ANONYMOUS_ROOTDIR/develop/open-world/CLIP") 37 | import clip 38 | from clip.model import CLIP 39 | 40 | 41 | def get_device(): 42 | return "cuda" if torch.cuda.is_available() else "cpu" 43 | 44 | 45 | def load_clip(model_path=None): 46 | device = get_device() 47 | if model_path is None: 48 | print("Loading original model...") 49 | model, _ = clip.load("ViT-B/16", device=device) 50 | model.float() 51 | else: 52 | print(f"Loading model from {model_path}...") 53 | model = CLIP( 54 | embed_dim=512, 55 | image_resolution=224, 56 | vision_layers=12, 57 | vision_width=768, 58 | vision_patch_size=16, 59 | context_length=77, 60 | vocab_size=49408, 61 | transformer_width=512, 62 | transformer_heads=8, 63 | transformer_layers=12, 64 | ).to(device) 65 | if model_path != "random": 66 | model.load_state_dict(torch.load(model_path)) 67 | model.eval() 68 | print(f"Temperature: {model.logit_scale.exp()}") 69 | return model 70 | 71 | 72 | def encode_clip(model, dataset, batch_size=32): 73 | device = get_device() 74 | 75 | dataloader = DataLoader( 76 | dataset, 77 | batch_size=batch_size, 78 | shuffle=False, 79 | num_workers=batch_size // 4, 80 | collate_fn=dataset.collate_fn, 81 | ) 82 | 83 | all_image_features, all_text_features = [], [] 84 | with torch.no_grad(): 85 | for batch in tqdm(dataloader): 86 | image_inputs, text_inputs = batch 87 | image_inputs, text_inputs = image_inputs.to(device), text_inputs.to(device) 88 | 89 | image_features = model.encode_image(image_inputs).cpu() 90 | image_features /= image_features.norm(dim=-1, keepdim=True) 91 | all_image_features.append(image_features) 92 | 93 | text_features = model.encode_text(text_inputs).cpu() 94 | text_features /= text_features.norm(dim=-1, keepdim=True) 95 | all_text_features.append(text_features) 96 | 97 | all_image_features = torch.cat(all_image_features, dim=0) 98 | all_text_features = torch.cat(all_text_features, dim=0) 99 | 100 | return all_image_features, all_text_features 101 | 102 | 103 | def align_loss_(x, y, alpha=2): 104 | return (x - y).norm(p=2, dim=1).pow(alpha).mean() 105 | 106 | 107 | def uniform_loss_(x, t=2): 108 | return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log() 109 | 110 | 111 | def ce_loss(model, image_features, text_features): 112 | loss_func = torch.nn.CrossEntropyLoss() 113 | 114 | logit_scale = model.logit_scale.exp() 115 | logits_per_image = logit_scale * image_features @ text_features.t() 116 | logits_per_text = logits_per_image.t() 117 | 118 | batch_size = image_features.size(0) 119 | device = get_device() 120 | ground_truth = torch.arange(batch_size, dtype=torch.long, device=device) 121 | 122 | loss = ( 123 | loss_func(logits_per_image, ground_truth) 124 | + loss_func(logits_per_text, ground_truth) 125 | ) / 2 126 | return loss 127 | 128 | 129 | def uniform_loss(model, image_features, text_features): 130 | loss = (uniform_loss_(image_features) + uniform_loss_(text_features)) / 2 131 | return loss 132 | 133 | 134 | def dual_ce_loss(model, image_features, text_features): 135 | loss_func = torch.nn.CrossEntropyLoss() 136 | 137 | features = torch.cat([image_features, text_features], 0) 138 | sims = features @ features.t() 139 | 140 | logit_scale = model.logit_scale.exp() 141 | logits = sims * logit_scale 142 | 143 | batch_size = image_features.size(0) 144 | logits_per_image = logits[:batch_size, :].contiguous() 145 | logits_per_image[torch.arange(batch_size), torch.arange(batch_size)] -= 10000 146 | logits_per_text = logits[batch_size:, :].contiguous() 147 | logits_per_text[ 148 | 
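            # index (i, i + batch_size): each text's similarity to itself in the
            # concatenated 2N-candidate pool. As with the image-side masking above,
            # subtracting 10000 effectively removes the self-match from the softmax,
            # so the cross-entropy target for text i is its paired image (index i)
            # among the remaining candidates.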
torch.arange(batch_size), torch.arange(batch_size) + batch_size 149 | ] -= 10000 150 | 151 | device = get_device() 152 | image_ground_truth = ( 153 | torch.arange(batch_size, dtype=torch.long, device=device) + batch_size 154 | ) 155 | text_ground_truth = torch.arange(batch_size, dtype=torch.long, device=device) 156 | 157 | loss = ( 158 | loss_func(logits_per_image, image_ground_truth) 159 | + loss_func(logits_per_text, text_ground_truth) 160 | ) / 2 161 | return loss 162 | 163 | 164 | def simple_ce_loss(model, image_features, text_features): 165 | loss_func = torch.nn.CrossEntropyLoss(reduction="none") 166 | 167 | logit_scale = model.logit_scale.exp() 168 | logits_per_image = logit_scale * image_features @ text_features.t() 169 | logits_per_text = logits_per_image.t() 170 | 171 | preds_per_image = torch.argmax(logits_per_image, dim=1) 172 | preds_per_text = torch.argmax(logits_per_text, dim=1) 173 | 174 | batch_size = image_features.size(0) 175 | device = get_device() 176 | ground_truth = torch.arange(batch_size, dtype=torch.long, device=device) 177 | 178 | correct_per_image = (preds_per_image == ground_truth).float() 179 | correct_per_text = (preds_per_text == ground_truth).float() 180 | 181 | loss_img = (loss_func(logits_per_image, ground_truth) * correct_per_image).sum() / ( 182 | correct_per_image.sum() + 1e-6 183 | ) 184 | loss_text = (loss_func(logits_per_text, ground_truth) * correct_per_text).sum() / ( 185 | correct_per_text.sum() + 1e-6 186 | ) 187 | 188 | loss = (loss_img + loss_text) / 2 189 | return loss 190 | 191 | 192 | def train_clip_toy_fix_init( 193 | model, 194 | dataset, 195 | model_path, 196 | batch_size=32, 197 | start_epoch=0, 198 | end_epoch=10, 199 | loss_funcs=[ce_loss], 200 | ): 201 | if not os.path.exists(model_path): 202 | os.makedirs(model_path) 203 | device = get_device() 204 | 205 | if start_epoch == 0: 206 | print("Training original model...") 207 | torch.save(model.state_dict(), f"{model_path}/model_epoch_{start_epoch}.pt") 208 | else: 209 | print(f"Loading model from {model_path} and continue training...") 210 | assert os.path.exists(f"{model_path}/model_epoch_{start_epoch}.pt") 211 | model.load_state_dict(torch.load(f"{model_path}/model_epoch_{start_epoch}.pt")) 212 | 213 | dataloader = DataLoader( 214 | dataset, 215 | batch_size=batch_size, 216 | shuffle=True, 217 | num_workers=batch_size // 4, 218 | collate_fn=dataset.collate_fn, 219 | drop_last=True, 220 | ) 221 | 222 | all_image_features, all_text_features = encode_clip(model, dataset) 223 | yx = all_image_features.t() @ all_text_features 224 | u, s, v = torch.svd(yx) 225 | w = u @ v.T 226 | torch.save([w, all_image_features, all_text_features], f"{model_path}/w.pt") 227 | all_text_features_transform = all_text_features @ w.T 228 | w = w.to(device) 229 | 230 | optimizer = torch.optim.Adam( 231 | model.parameters(), lr=1e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2 232 | ) 233 | 234 | logs = {} 235 | for epoch in range(start_epoch + 1, end_epoch + 1): 236 | logs[epoch] = [] 237 | bar = tqdm(dataloader) 238 | for i, batch in enumerate(bar): 239 | image_inputs, text_inputs = batch 240 | image_inputs, text_inputs = image_inputs.to(device), text_inputs.to(device) 241 | 242 | image_features = model.encode_image(image_inputs) 243 | text_features = model.encode_text(text_inputs) 244 | 245 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 246 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 247 | text_features = text_features @ w.T 248 | 249 | losses = 
[ 250 | loss_func(model, image_features, text_features) 251 | for loss_func in loss_funcs 252 | ] 253 | loss = sum(losses) 254 | 255 | optimizer.zero_grad() 256 | loss.backward() 257 | optimizer.step() 258 | 259 | logs[epoch].append( 260 | {"loss": loss.item(), "losses": [loss.item() for loss in losses]} 261 | ) 262 | bar.set_description(f"Epoch {epoch}/{end_epoch}, Loss: {logs[epoch][i]}") 263 | 264 | torch.save(model.state_dict(), f"{model_path}/model_epoch_{epoch}.pt") 265 | 266 | epoch_loss = np.mean([item["loss"] for item in logs[epoch]]) 267 | epoch_losses = [ 268 | np.mean([item["losses"][i] for item in logs[epoch]]) 269 | for i in range(len(loss_funcs)) 270 | ] 271 | print(f"Epoch {epoch}: loss = {epoch_loss:.4f}, losses = {epoch_losses}") 272 | return model, logs 273 | 274 | 275 | def train_clip_toy( 276 | model, 277 | dataset, 278 | model_path, 279 | batch_size=32, 280 | start_epoch=0, 281 | end_epoch=10, 282 | loss_funcs=[ce_loss], 283 | ): 284 | if not os.path.exists(model_path): 285 | os.makedirs(model_path) 286 | device = get_device() 287 | 288 | if start_epoch == 0: 289 | print("Training original model...") 290 | torch.save(model.state_dict(), f"{model_path}/model_epoch_{start_epoch}.pt") 291 | else: 292 | print(f"Loading model from {model_path} and continue training...") 293 | assert os.path.exists(f"{model_path}/model_epoch_{start_epoch}.pt") 294 | model.load_state_dict(torch.load(f"{model_path}/model_epoch_{start_epoch}.pt")) 295 | 296 | dataloader = DataLoader( 297 | dataset, 298 | batch_size=batch_size, 299 | shuffle=True, 300 | num_workers=batch_size // 4, 301 | collate_fn=dataset.collate_fn, 302 | drop_last=True, 303 | ) 304 | 305 | optimizer = torch.optim.Adam( 306 | model.parameters(), lr=1e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2 307 | ) 308 | 309 | logs = {} 310 | for epoch in range(start_epoch + 1, end_epoch + 1): 311 | logs[epoch] = [] 312 | bar = tqdm(dataloader) 313 | for i, batch in enumerate(bar): 314 | image_inputs, text_inputs = batch 315 | image_inputs, text_inputs = image_inputs.to(device), text_inputs.to(device) 316 | 317 | image_features = model.encode_image(image_inputs) 318 | text_features = model.encode_text(text_inputs) 319 | 320 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 321 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 322 | 323 | losses = [ 324 | loss_func(model, image_features, text_features) 325 | for loss_func in loss_funcs 326 | ] 327 | loss = sum(losses) 328 | 329 | optimizer.zero_grad() 330 | loss.backward() 331 | optimizer.step() 332 | 333 | logs[epoch].append( 334 | {"loss": loss.item(), "losses": [loss.item() for loss in losses]} 335 | ) 336 | bar.set_description(f"Epoch {epoch}/{end_epoch}, Loss: {logs[epoch][i]}") 337 | 338 | torch.save(model.state_dict(), f"{model_path}/model_epoch_{epoch}.pt") 339 | 340 | epoch_loss = np.mean([item["loss"] for item in logs[epoch]]) 341 | epoch_losses = [ 342 | np.mean([item["losses"][i] for item in logs[epoch]]) 343 | for i in range(len(loss_funcs)) 344 | ] 345 | print(f"Epoch {epoch}: loss = {epoch_loss:.4f}, losses = {epoch_losses}") 346 | return model, logs 347 | 348 | 349 | def encode_clip_classification( 350 | model, dataset, prompt="a photo of a {}.", batch_size=32 351 | ): 352 | device = get_device() 353 | 354 | text_inputs = torch.cat( 355 | [clip.tokenize(prompt.format(c)) for c in dataset.data.classes] 356 | ).to(device) 357 | with torch.no_grad(): 358 | all_text_features = model.encode_text(text_inputs).cpu() 359 | 
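        # Each class name was tokenized above with a single prompt template
        # (default "a photo of a {}."); normalizing both text and image features
        # to unit norm makes the image @ text.T scores cosine similarities, so
        # zero-shot prediction reduces to nearest-prompt retrieval in
        # evaluate_classification.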
all_text_features /= all_text_features.norm(dim=-1, keepdim=True) 360 | 361 | dataloader = DataLoader( 362 | dataset, 363 | batch_size=batch_size, 364 | shuffle=False, 365 | num_workers=batch_size // 4, 366 | collate_fn=dataset.collate_fn, 367 | ) 368 | 369 | all_image_features = [] 370 | with torch.no_grad(): 371 | for batch in tqdm(dataloader): 372 | image_inputs, labels = batch 373 | image_inputs = image_inputs.to(device) 374 | 375 | image_features = model.encode_image(image_inputs).cpu() 376 | image_features /= image_features.norm(dim=-1, keepdim=True) 377 | all_image_features.append(image_features) 378 | 379 | all_image_features = torch.cat(all_image_features, dim=0) 380 | 381 | return all_image_features, all_text_features 382 | 383 | 384 | def svd(X, n_components=2, return_singular_values=False): 385 | U, S, Vt = np.linalg.svd(X) 386 | X_reduce = U[:, :n_components] * S[:n_components] 387 | if return_singular_values: 388 | return X_reduce, S 389 | return X_reduce 390 | 391 | 392 | def visualize_2d(clusters, colors=None, labels=None, connection=False): 393 | assert isinstance(clusters, list) 394 | for cluster in clusters: 395 | assert isinstance(cluster, np.ndarray) 396 | assert cluster.shape[1] == 2 397 | 398 | fig = plt.figure(figsize=(5, 5)) 399 | if colors is None: 400 | colors = ["r" for i in range(len(clusters))] 401 | if labels is None: 402 | labels = [f"cluster_{i}" for i in range(len(clusters))] 403 | for cluster, color, label in zip(clusters, colors, labels): 404 | plt.scatter(cluster[:, 0], cluster[:, 1], c=color, label=label, alpha=0.2) 405 | 406 | if connection: 407 | assert len(clusters) == 2 and len(clusters[0]) == len(clusters[1]) 408 | for i in range(len(clusters[0])): 409 | plt.plot( 410 | [clusters[0][i, 0], clusters[1][i, 0]], 411 | [clusters[0][i, 1], clusters[1][i, 1]], 412 | c="k", 413 | alpha=0.05, 414 | ) 415 | plt.show() 416 | 417 | 418 | def visualize_3d(clusters, colors=None, labels=None, connection=False): 419 | assert isinstance(clusters, list) 420 | assert connection == False 421 | for cluster in clusters: 422 | assert isinstance(cluster, np.ndarray) 423 | assert cluster.shape[1] == 3 424 | 425 | fig = plt.figure() 426 | ax = Axes3D(fig) 427 | if colors is None: 428 | colors = ["r" for i in range(len(clusters))] 429 | if labels is None: 430 | labels = [f"cluster_{i}" for i in range(len(clusters))] 431 | for cluster, color, label in zip(clusters, colors, labels): 432 | ax.scatter( 433 | cluster[:, 0], cluster[:, 1], cluster[:, 2], c=color, label=label, alpha=0.2 434 | ) 435 | ax.set_xlabel("X") 436 | ax.set_ylabel("Y") 437 | ax.set_zlabel("Z") 438 | fig.add_axes(ax) 439 | plt.show() 440 | 441 | 442 | def dim_reduce(features, n_dim=2, methods=["svd", "pca", "tsne", "umap"]): 443 | assert isinstance(features, np.ndarray) 444 | 445 | features_reduce = {} 446 | for method in methods: 447 | if method == "svd": 448 | features_reduce[method] = svd(features, n_components=n_dim) 449 | else: 450 | projector = eval(method.upper())(n_components=n_dim) 451 | features_reduce[method] = projector.fit_transform(features) 452 | return features_reduce 453 | 454 | 455 | def reduce_and_visualize( 456 | image_features, 457 | text_features, 458 | n_dim=2, 459 | methods=["svd", "pca", "tsne", "umap"], 460 | connection=False, 461 | ): 462 | assert isinstance(image_features, np.ndarray) and isinstance( 463 | text_features, np.ndarray 464 | ) 465 | assert n_dim in [2, 3] 466 | 467 | features = np.concatenate([image_features, text_features], axis=0) 468 | features_reduce = 
dim_reduce(features, n_dim=n_dim, methods=methods) 469 | 470 | for i, method in enumerate(methods): 471 | image_features_reduce = features_reduce[method][: len(image_features)] 472 | text_features_reduce = features_reduce[method][len(image_features) :] 473 | eval(f"visualize_{n_dim}d")( 474 | [image_features_reduce, text_features_reduce], 475 | colors=["r", "b"], 476 | connection=connection, 477 | ) 478 | 479 | 480 | def convert_image_to_rgb(image): 481 | return image.convert("RGB") 482 | 483 | 484 | def estimate_density(image_features, text_features): 485 | x_plot = np.linspace(-1.2, 1.2, 100) 486 | y_plot = np.linspace(-1.2, 1.2, 100) 487 | xy_plot = np.array(np.meshgrid(x_plot, y_plot)).reshape(2, -1).T 488 | 489 | kde_image = KernelDensity(kernel="gaussian", bandwidth=0.1).fit(image_features) 490 | image_density = np.exp(kde_image.score_samples(xy_plot)) 491 | 492 | kde_text = KernelDensity(kernel="gaussian", bandwidth=0.1).fit(text_features) 493 | text_density = np.exp(kde_text.score_samples(xy_plot)) 494 | 495 | plt.figure(figsize=(10, 5)) 496 | 497 | plt.subplot(1, 2, 1) 498 | plt.imshow( 499 | image_density.reshape(100, 100), 500 | extent=(-1.2, 1.2, -1.2, 1.2), 501 | origin="lower", 502 | cmap="Reds", 503 | alpha=0.5, 504 | vmin=min([image_density.min(), text_density.min()]), 505 | vmax=max([image_density.max(), text_density.max()]), 506 | ) 507 | plt.scatter(image_features[:, 0], image_features[:, 1], c="red", alpha=0.05) 508 | 509 | plt.subplot(1, 2, 2) 510 | plt.imshow( 511 | text_density.reshape(100, 100), 512 | extent=(-1.2, 1.2, -1.2, 1.2), 513 | origin="lower", 514 | cmap="Blues", 515 | alpha=0.5, 516 | vmin=min([image_density.min(), text_density.min()]), 517 | vmax=max([image_density.max(), text_density.max()]), 518 | ) 519 | plt.scatter(text_features[:, 0], text_features[:, 1], c="blue", alpha=0.05) 520 | 521 | print( 522 | text_density.min(), 523 | text_density.max(), 524 | text_density.mean(), 525 | image_density.min(), 526 | image_density.max(), 527 | image_density.mean(), 528 | ) 529 | 530 | 531 | def estimate_angle_density(image_features, text_features): 532 | image_features_angle = [ 533 | np.arctan2(image_features[i, 1], image_features[i, 0]).item() 534 | for i in range(len(image_features)) 535 | ] 536 | text_features_angle = [ 537 | np.arctan2(text_features[i, 1], text_features[i, 0]).item() 538 | for i in range(len(text_features)) 539 | ] 540 | 541 | kappa = 25 542 | kde_image = VonMisesKDE(image_features_angle, weights=[], kappa=kappa) 543 | kde_text = VonMisesKDE(text_features_angle, weights=[], kappa=kappa) 544 | 545 | test_x = np.linspace(-math.pi, math.pi, 100) 546 | 547 | # # Display individual distributions 548 | # for i in np.arange(0, len(text_features_angle)): 549 | # sample = text_features_angle[i] 550 | # test_y = kde_text.vonMisesPDF(test_x, sample) 551 | # test_y = test_y / test_y.sum() 552 | # plt.plot(test_x, test_y, color='gray', alpha=0.5) 553 | 554 | # Display posterior estimate 555 | plt.figure(figsize=(10, 1)) 556 | 557 | plt.subplot(1, 2, 1) 558 | plt.plot(test_x, kde_image.evaluate(test_x), zorder=20, color="red", alpha=0.5) 559 | plt.fill_between( 560 | test_x, kde_image.evaluate(test_x), step="pre", alpha=0.2, color="red" 561 | ) 562 | plt.xlim(-math.pi, math.pi) 563 | plt.ylim(0, 1) 564 | 565 | plt.subplot(1, 2, 2) 566 | plt.plot(test_x, kde_text.evaluate(test_x), zorder=20, color="blue", alpha=0.5) 567 | plt.fill_between( 568 | test_x, kde_text.evaluate(test_x), step="pre", alpha=0.2, color="blue" 569 | ) 570 | plt.xlim(-math.pi, 
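                 # angles were obtained via np.arctan2 on 2-D features, so the
                 # x-axis spans [-pi, pi]; the von Mises KDE (kappa=25) is the
                 # circular analogue of the Gaussian KDE used in estimate_density above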
math.pi) 571 | plt.ylim(0, 1) 572 | 573 | 574 | if __name__ == "__main__": 575 | ##### Test svd() ##### 576 | X = np.arange(100).reshape(10, 10) 577 | X_2d = svd(X) 578 | assert X_2d.shape == (10, 2) 579 | -------------------------------------------------------------------------------- /Table_2_Implications_CLIP_Fairness/coco-extract.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "YPHN7PJgKOzb" 7 | }, 8 | "source": [ 9 | "# Image Feature Pair Extract - CLIP, ResNet18. \n", 10 | "conda activate clip\n", 11 | "\n", 12 | "\n", 13 | "clip_image_features_list (118287, 512)\n", 14 | "target_image_features_list (118287, 512)\n", 15 | "clip_image_features_list (5000, 512)\n", 16 | "target_image_features_list (5000, 512)\n", 17 | "\n", 18 | "Feature extraction complete in 6m 16s" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": { 25 | "colab": { 26 | "base_uri": "https://localhost:8080/" 27 | }, 28 | "id": "C1hkDT38hSaP", 29 | "outputId": "70a44964-883d-4fd0-b95a-2c7f2b19aca9" 30 | }, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "Torch version: 1.7.1\n" 37 | ] 38 | } 39 | ], 40 | "source": [ 41 | "import numpy as np\n", 42 | "import torch\n", 43 | "import pickle\n", 44 | "import time\n", 45 | "print(\"Torch version:\", torch.__version__)\n", 46 | "\n", 47 | "assert torch.__version__.split(\".\") >= [\"1\", \"7\", \"1\"], \"PyTorch 1.7.1 or later is required\"\n", 48 | "\n", 49 | "import os\n", 50 | "import matplotlib.pyplot as plt\n", 51 | "from collections import OrderedDict\n", 52 | "import torch\n", 53 | "\n", 54 | "%matplotlib inline\n", 55 | "%config InlineBackend.figure_format = 'retina'" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# Load CLIP" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 2, 68 | "metadata": { 69 | "colab": { 70 | "base_uri": "https://localhost:8080/" 71 | }, 72 | "id": "uLFS29hnhlY4", 73 | "outputId": "11779e1e-8bdd-4167-c18e-d26bdd6b67db" 74 | }, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']" 80 | ] 81 | }, 82 | "execution_count": 2, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "import clip\n", 89 | "\n", 90 | "clip.available_models()" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": { 97 | "colab": { 98 | "base_uri": "https://localhost:8080/" 99 | }, 100 | "id": "IBRVTY9lbGm8", 101 | "outputId": "f06fd2fd-6126-475b-87d0-b10aa3b7da49" 102 | }, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "Model parameters: 151,277,313\n", 109 | "Input resolution: 224\n", 110 | "Context length: 77\n", 111 | "Vocab size: 49408\n" 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "model, preprocess = clip.load(\"ViT-B/32\")\n", 117 | "model.cuda().eval()\n", 118 | "input_resolution = model.visual.input_resolution\n", 119 | "context_length = model.context_length\n", 120 | "vocab_size = model.vocab_size\n", 121 | "\n", 122 | "print(\"Model parameters:\", f\"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}\")\n", 123 | "print(\"Input resolution:\", input_resolution)\n", 124 | "print(\"Context length:\", context_length)\n", 125 | "print(\"Vocab size:\", 
vocab_size)\n", 126 | "\n", 127 | "clip_model = model" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 4, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "logit_scale 4.605170249938965\n", 140 | "temperature 0.009999999360491285\n", 141 | "1/temperature 100.00000639508755\n" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "# what is the final learned temperature? \n", 147 | "# np.log(1 / tau) = logit_scale\n", 148 | "# 1 / tau = np.exp(logit_scale)\n", 149 | "# tau = 1/np.exp(logit_scale)\n", 150 | "logit_scale = clip_model.logit_scale.detach().cpu().item()\n", 151 | "print('logit_scale', logit_scale)\n", 152 | "print('temperature', 1/np.exp(logit_scale))\n", 153 | "print('1/temperature', np.exp(logit_scale))" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 5, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "data": { 163 | "text/plain": [ 164 | "torchvision.transforms.transforms.Compose" 165 | ] 166 | }, 167 | "execution_count": 5, 168 | "metadata": {}, 169 | "output_type": "execute_result" 170 | } 171 | ], 172 | "source": [ 173 | "type(preprocess)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "# Load Data" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 6, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "import torchvision\n", 190 | "from torch.utils.data import DataLoader\n", 191 | "\n", 192 | "\n", 193 | "coco_val_dataset = torchvision.datasets.ImageFolder(\n", 194 | " root = './dummy_val',\n", 195 | " transform=preprocess,\n", 196 | " )" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 7, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "coco_val_dataloader = DataLoader(coco_val_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "# Extractor loop\n" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 8, 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "name": "stdout", 222 | "output_type": "stream", 223 | "text": [ 224 | "clip_image_features_list (10954, 512)\n", 225 | "\n", 226 | "Feature Extraction completed in 0m 12s\n" 227 | ] 228 | } 229 | ], 230 | "source": [ 231 | "since = time.time()\n", 232 | "dataloaders = {\n", 233 | " 'val': coco_val_dataloader,\n", 234 | "}\n", 235 | "# Each epoch has a training and validation phase\n", 236 | "for phase in ['val',]:\n", 237 | "\n", 238 | " clip_model.eval() # Set model to evaluate mode, for extraction\n", 239 | " ##################################\n", 240 | " # Fields to be stored for postprocessing \n", 241 | " ##################################\n", 242 | " clip_image_features_list = []\n", 243 | "\n", 244 | " # Iterate over data.\n", 245 | " for inputs, captions in dataloaders[phase]:\n", 246 | " image_input = inputs.cuda(non_blocking=True)\n", 247 | " text_input = captions.cuda(non_blocking=True)\n", 248 | " # TODO: add text here\n", 249 | " \n", 250 | " with torch.set_grad_enabled(False):\n", 251 | " clip_image_features = clip_model.encode_image(image_input).float()\n", 252 | "\n", 253 | " ##################################\n", 254 | " # Evaluation book-keeping Field \n", 255 | " ##################################\n", 256 | " clip_image_features_list.append( clip_image_features.cpu().numpy() 
)\n", 257 | "\n", 258 | " ##################################\n", 259 | " # Evaluation book-keeping Field \n", 260 | " ##################################\n", 261 | " clip_image_features_list = np.concatenate( clip_image_features_list, axis=0)\n", 262 | " print('clip_image_features_list', clip_image_features_list.shape)\n", 263 | "\n", 264 | " dump_result_dict = {\n", 265 | " \"clip_image_features_list\": clip_image_features_list, \n", 266 | " }\n", 267 | " with open(os.path.join('features', 'feature_dump_{}.pkl'.format(phase) ), \"wb\") as pkl_file:\n", 268 | " pickle.dump(\n", 269 | " dump_result_dict, \n", 270 | " pkl_file, \n", 271 | " )\n", 272 | "\n", 273 | "print()\n", 274 | "\n", 275 | "time_elapsed = time.time() - since\n", 276 | "print('Feature Extraction completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [] 285 | } 286 | ], 287 | "metadata": { 288 | "accelerator": "GPU", 289 | "colab": { 290 | "collapsed_sections": [], 291 | "name": "Interacting with CLIP.ipynb", 292 | "provenance": [] 293 | }, 294 | "kernelspec": { 295 | "display_name": "Python 3", 296 | "name": "python3" 297 | }, 298 | "language_info": { 299 | "codemirror_mode": { 300 | "name": "ipython", 301 | "version": 3 302 | }, 303 | "file_extension": ".py", 304 | "mimetype": "text/x-python", 305 | "name": "python", 306 | "nbconvert_exporter": "python", 307 | "pygments_lexer": "ipython3", 308 | "version": "3.9.7" 309 | }, 310 | "widgets": { 311 | "application/vnd.jupyter.widget-state+json": { 312 | "12e23e2819094ee0a079d4eb77cfc4f9": { 313 | "model_module": "@jupyter-widgets/base", 314 | "model_module_version": "1.2.0", 315 | "model_name": "LayoutModel", 316 | "state": { 317 | "_model_module": "@jupyter-widgets/base", 318 | "_model_module_version": "1.2.0", 319 | "_model_name": "LayoutModel", 320 | "_view_count": null, 321 | "_view_module": "@jupyter-widgets/base", 322 | "_view_module_version": "1.2.0", 323 | "_view_name": "LayoutView", 324 | "align_content": null, 325 | "align_items": null, 326 | "align_self": null, 327 | "border": null, 328 | "bottom": null, 329 | "display": null, 330 | "flex": null, 331 | "flex_flow": null, 332 | "grid_area": null, 333 | "grid_auto_columns": null, 334 | "grid_auto_flow": null, 335 | "grid_auto_rows": null, 336 | "grid_column": null, 337 | "grid_gap": null, 338 | "grid_row": null, 339 | "grid_template_areas": null, 340 | "grid_template_columns": null, 341 | "grid_template_rows": null, 342 | "height": null, 343 | "justify_content": null, 344 | "justify_items": null, 345 | "left": null, 346 | "margin": null, 347 | "max_height": null, 348 | "max_width": null, 349 | "min_height": null, 350 | "min_width": null, 351 | "object_fit": null, 352 | "object_position": null, 353 | "order": null, 354 | "overflow": null, 355 | "overflow_x": null, 356 | "overflow_y": null, 357 | "padding": null, 358 | "right": null, 359 | "top": null, 360 | "visibility": null, 361 | "width": null 362 | } 363 | }, 364 | "1369964d45004b5e95a058910b2a33e6": { 365 | "model_module": "@jupyter-widgets/controls", 366 | "model_module_version": "1.5.0", 367 | "model_name": "HBoxModel", 368 | "state": { 369 | "_dom_classes": [], 370 | "_model_module": "@jupyter-widgets/controls", 371 | "_model_module_version": "1.5.0", 372 | "_model_name": "HBoxModel", 373 | "_view_count": null, 374 | "_view_module": "@jupyter-widgets/controls", 375 | "_view_module_version": 
"1.5.0", 376 | "_view_name": "HBoxView", 377 | "box_style": "", 378 | "children": [ 379 | "IPY_MODEL_7a5f52e56ede4ac3abe37a3ece007dc9", 380 | "IPY_MODEL_ce8b0faa1a1340b5a504d7b3546b3ccb" 381 | ], 382 | "layout": "IPY_MODEL_12e23e2819094ee0a079d4eb77cfc4f9" 383 | } 384 | }, 385 | "161969cae25a49f38aacd1568d3cac6c": { 386 | "model_module": "@jupyter-widgets/base", 387 | "model_module_version": "1.2.0", 388 | "model_name": "LayoutModel", 389 | "state": { 390 | "_model_module": "@jupyter-widgets/base", 391 | "_model_module_version": "1.2.0", 392 | "_model_name": "LayoutModel", 393 | "_view_count": null, 394 | "_view_module": "@jupyter-widgets/base", 395 | "_view_module_version": "1.2.0", 396 | "_view_name": "LayoutView", 397 | "align_content": null, 398 | "align_items": null, 399 | "align_self": null, 400 | "border": null, 401 | "bottom": null, 402 | "display": null, 403 | "flex": null, 404 | "flex_flow": null, 405 | "grid_area": null, 406 | "grid_auto_columns": null, 407 | "grid_auto_flow": null, 408 | "grid_auto_rows": null, 409 | "grid_column": null, 410 | "grid_gap": null, 411 | "grid_row": null, 412 | "grid_template_areas": null, 413 | "grid_template_columns": null, 414 | "grid_template_rows": null, 415 | "height": null, 416 | "justify_content": null, 417 | "justify_items": null, 418 | "left": null, 419 | "margin": null, 420 | "max_height": null, 421 | "max_width": null, 422 | "min_height": null, 423 | "min_width": null, 424 | "object_fit": null, 425 | "object_position": null, 426 | "order": null, 427 | "overflow": null, 428 | "overflow_x": null, 429 | "overflow_y": null, 430 | "padding": null, 431 | "right": null, 432 | "top": null, 433 | "visibility": null, 434 | "width": null 435 | } 436 | }, 437 | "4a61c10fc00c4f04bb00b82e942da210": { 438 | "model_module": "@jupyter-widgets/base", 439 | "model_module_version": "1.2.0", 440 | "model_name": "LayoutModel", 441 | "state": { 442 | "_model_module": "@jupyter-widgets/base", 443 | "_model_module_version": "1.2.0", 444 | "_model_name": "LayoutModel", 445 | "_view_count": null, 446 | "_view_module": "@jupyter-widgets/base", 447 | "_view_module_version": "1.2.0", 448 | "_view_name": "LayoutView", 449 | "align_content": null, 450 | "align_items": null, 451 | "align_self": null, 452 | "border": null, 453 | "bottom": null, 454 | "display": null, 455 | "flex": null, 456 | "flex_flow": null, 457 | "grid_area": null, 458 | "grid_auto_columns": null, 459 | "grid_auto_flow": null, 460 | "grid_auto_rows": null, 461 | "grid_column": null, 462 | "grid_gap": null, 463 | "grid_row": null, 464 | "grid_template_areas": null, 465 | "grid_template_columns": null, 466 | "grid_template_rows": null, 467 | "height": null, 468 | "justify_content": null, 469 | "justify_items": null, 470 | "left": null, 471 | "margin": null, 472 | "max_height": null, 473 | "max_width": null, 474 | "min_height": null, 475 | "min_width": null, 476 | "object_fit": null, 477 | "object_position": null, 478 | "order": null, 479 | "overflow": null, 480 | "overflow_x": null, 481 | "overflow_y": null, 482 | "padding": null, 483 | "right": null, 484 | "top": null, 485 | "visibility": null, 486 | "width": null 487 | } 488 | }, 489 | "5e6adc4592124a4581b85f4c1f3bab4d": { 490 | "model_module": "@jupyter-widgets/controls", 491 | "model_module_version": "1.5.0", 492 | "model_name": "ProgressStyleModel", 493 | "state": { 494 | "_model_module": "@jupyter-widgets/controls", 495 | "_model_module_version": "1.5.0", 496 | "_model_name": "ProgressStyleModel", 497 | "_view_count": null, 498 | "_view_module": 
"@jupyter-widgets/base", 499 | "_view_module_version": "1.2.0", 500 | "_view_name": "StyleView", 501 | "bar_color": null, 502 | "description_width": "initial" 503 | } 504 | }, 505 | "7a5f52e56ede4ac3abe37a3ece007dc9": { 506 | "model_module": "@jupyter-widgets/controls", 507 | "model_module_version": "1.5.0", 508 | "model_name": "FloatProgressModel", 509 | "state": { 510 | "_dom_classes": [], 511 | "_model_module": "@jupyter-widgets/controls", 512 | "_model_module_version": "1.5.0", 513 | "_model_name": "FloatProgressModel", 514 | "_view_count": null, 515 | "_view_module": "@jupyter-widgets/controls", 516 | "_view_module_version": "1.5.0", 517 | "_view_name": "ProgressView", 518 | "bar_style": "success", 519 | "description": "", 520 | "description_tooltip": null, 521 | "layout": "IPY_MODEL_4a61c10fc00c4f04bb00b82e942da210", 522 | "max": 169001437, 523 | "min": 0, 524 | "orientation": "horizontal", 525 | "style": "IPY_MODEL_5e6adc4592124a4581b85f4c1f3bab4d", 526 | "value": 169001437 527 | } 528 | }, 529 | "b597cd6f6cd443aba4bf4491ac7f957e": { 530 | "model_module": "@jupyter-widgets/controls", 531 | "model_module_version": "1.5.0", 532 | "model_name": "DescriptionStyleModel", 533 | "state": { 534 | "_model_module": "@jupyter-widgets/controls", 535 | "_model_module_version": "1.5.0", 536 | "_model_name": "DescriptionStyleModel", 537 | "_view_count": null, 538 | "_view_module": "@jupyter-widgets/base", 539 | "_view_module_version": "1.2.0", 540 | "_view_name": "StyleView", 541 | "description_width": "" 542 | } 543 | }, 544 | "ce8b0faa1a1340b5a504d7b3546b3ccb": { 545 | "model_module": "@jupyter-widgets/controls", 546 | "model_module_version": "1.5.0", 547 | "model_name": "HTMLModel", 548 | "state": { 549 | "_dom_classes": [], 550 | "_model_module": "@jupyter-widgets/controls", 551 | "_model_module_version": "1.5.0", 552 | "_model_name": "HTMLModel", 553 | "_view_count": null, 554 | "_view_module": "@jupyter-widgets/controls", 555 | "_view_module_version": "1.5.0", 556 | "_view_name": "HTMLView", 557 | "description": "", 558 | "description_tooltip": null, 559 | "layout": "IPY_MODEL_161969cae25a49f38aacd1568d3cac6c", 560 | "placeholder": "​", 561 | "style": "IPY_MODEL_b597cd6f6cd443aba4bf4491ac7f957e", 562 | "value": " 169001984/? 
[00:06<00:00, 25734958.25it/s]" 563 | } 564 | } 565 | } 566 | } 567 | }, 568 | "nbformat": 4, 569 | "nbformat_minor": 0 570 | } 571 | -------------------------------------------------------------------------------- /docs/figures/Figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/docs/figures/Figure1.png -------------------------------------------------------------------------------- /docs/figures/Figure2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/docs/figures/Figure2.jpg -------------------------------------------------------------------------------- /docs/figures/Figure2ab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/docs/figures/Figure2ab.png -------------------------------------------------------------------------------- /docs/figures/Figure2c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/docs/figures/Figure2c.png -------------------------------------------------------------------------------- /docs/figures/Figure3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/docs/figures/Figure3.jpg -------------------------------------------------------------------------------- /docs/figures/Tables.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/docs/figures/Tables.png -------------------------------------------------------------------------------- /docs/figures/Theorem1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/docs/figures/Theorem1.png -------------------------------------------------------------------------------- /docs/figures/Theorem2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/docs/figures/Theorem2.png -------------------------------------------------------------------------------- /docs/figures/Theorem_variance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Weixin-Liang/Modality-Gap/8e20cb24efa4c5f89aad694f2f65eb43ffc46d10/docs/figures/Theorem_variance.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dalle 2 | channels: 3 | - rmg 4 | - conda-forge 5 | - rdkit 6 | - pytorch 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=1_gnu 11 | - blas=1.0=mkl 12 | - bzip2=1.0.8=h7f98852_4 13 | - ca-certificates=2020.6.20=hecda079_0 14 | - cffi=1.15.0=py38h3931269_0 15 | - cpuonly=2.0=0 16 | - cudatoolkit=10.1.243=h036e899_10 17 | - 
ffmpeg=4.3=hf484d3e_0 18 | - freetype=2.10.4=h0708190_1 19 | - future=0.18.2=py38h578d9bd_4 20 | - gmp=6.2.1=h58526e2_0 21 | - gnutls=3.6.13=h85f3911_1 22 | - intel-openmp=2021.4.0=h06a4308_3561 23 | - jpeg=9b=h024ee3a_2 24 | - lame=3.100=h7f98852_1001 25 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 26 | - libblas=3.9.0=12_linux64_mkl 27 | - libffi=3.4.2=h7f98852_5 28 | - libgcc-ng=11.2.0=h1d223b6_11 29 | - libgomp=11.2.0=h1d223b6_11 30 | - libiconv=1.16=h516909a_0 31 | - libnsl=2.0.0=h7f98852_0 32 | - libpng=1.6.37=h21135ba_2 33 | - libprotobuf=3.19.4=h780b84a_0 34 | - libstdcxx-ng=11.2.0=he4da1e4_11 35 | - libtiff=4.2.0=h85742a9_0 36 | - libuv=1.42.0=h7f98852_0 37 | - libwebp-base=1.2.1=h7f98852_0 38 | - libzlib=1.2.11=h36c2ea0_1013 39 | - lz4-c=1.9.3=h9c3ff4c_1 40 | - mkl=2021.4.0=h06a4308_640 41 | - mkl-service=2.4.0=py38h95df7f1_0 42 | - mkl_fft=1.3.1=py38h8666266_1 43 | - mkl_random=1.2.2=py38h1abd341_0 44 | - ncurses=6.2=h58526e2_4 45 | - nettle=3.6=he412f7d_0 46 | - ninja=1.10.2=h4bd325d_1 47 | - numpy-base=1.21.2=py38h79a1101_0 48 | - olefile=0.46=pyh9f0ad1d_1 49 | - openh264=2.1.1=h780b84a_0 50 | - openssl=3.0.0=h7f98852_2 51 | - pip=21.3.1=pyhd8ed1ab_0 52 | - pycparser=2.21=pyhd8ed1ab_0 53 | - python=3.8.12=hf930737_2_cpython 54 | - python_abi=3.8=2_cp38 55 | - pytorch-mutex=1.0=cpu 56 | - readline=8.1=h46c0cb4_0 57 | - ruamel_yaml=0.15.80=py38h497a2fe_1006 58 | - setuptools=60.2.0=py38h578d9bd_0 59 | - six=1.16.0=pyh6c4a22f_0 60 | - sleef=3.5.1=h9b69904_2 61 | - sqlite=3.37.0=h9cd32fc_0 62 | - tk=8.6.11=h27826a3_1 63 | - torchaudio=0.10.0=py38_cpu 64 | - typing_extensions=4.0.1=pyha770c72_0 65 | - wheel=0.37.1=pyhd8ed1ab_0 66 | - xz=5.2.5=h516909a_1 67 | - yaml=0.2.5=h7f98852_2 68 | - zlib=1.2.11=h36c2ea0_1013 69 | - zstd=1.4.9=ha95c52a_0 70 | - pip: 71 | - absl-py==1.0.0 72 | - aiohttp==3.8.1 73 | - aiosignal==1.2.0 74 | - ansicolors==1.1.8 75 | - antlr4-python3-runtime==4.8 76 | - anyio==3.4.0 77 | - apptools==5.1.0 78 | - argon2-cffi==21.3.0 79 | - argon2-cffi-bindings==21.2.0 80 | - astroid==2.9.3 81 | - async-timeout==4.0.2 82 | - attrs==21.4.0 83 | - babel==2.9.1 84 | - backcall==0.2.0 85 | - bleach==4.1.0 86 | - blessings==1.7 87 | - cachetools==4.2.4 88 | - certifi==2021.10.8 89 | - charset-normalizer==2.0.9 90 | - click==8.0.3 91 | - cloudpickle==2.0.0 92 | - configobj==5.0.6 93 | - cycler==0.11.0 94 | - cython==0.29.26 95 | - debugpy==1.5.1 96 | - decorator==5.1.0 97 | - defusedxml==0.7.1 98 | - einops==0.3.2 99 | - entrypoints==0.3 100 | - envisage==6.0.1 101 | - filelock==3.4.2 102 | - fonttools==4.28.5 103 | - frozenlist==1.2.0 104 | - fsspec==2021.11.1 105 | - ftfy==6.0.3 106 | - fvcore==0.1.5.post20211023 107 | - google-auth==2.3.3 108 | - google-auth-oauthlib==0.4.6 109 | - gpustat==0.6.0 110 | - grpcio==1.43.0 111 | - huggingface-hub==0.0.12 112 | - idna==3.3 113 | - imageio==2.13.5 114 | - importlib-metadata==4.10.0 115 | - importlib-resources==5.4.0 116 | - iopath==0.1.9 117 | - ipykernel==6.6.0 118 | - ipython==7.30.1 119 | - ipython-genutils==0.2.0 120 | - ipywidgets==7.6.5 121 | - isort==5.10.1 122 | - jedi==0.18.1 123 | - jinja2==3.0.3 124 | - joblib==1.1.0 125 | - json5==0.9.6 126 | - jsonschema==4.3.2 127 | - jupyter-client==7.1.0 128 | - jupyter-core==4.9.1 129 | - jupyter-server==1.13.1 130 | - jupyterlab==3.2.5 131 | - jupyterlab-pygments==0.1.2 132 | - jupyterlab-server==2.10.2 133 | - jupyterlab-widgets==1.0.2 134 | - kiwisolver==1.3.2 135 | - lazy-object-proxy==1.7.1 136 | - llvmlite==0.38.0 137 | - markdown==3.3.6 138 | - markupsafe==2.0.1 139 | - 
matplotlib==3.5.1 140 | - matplotlib-inline==0.1.3 141 | - mccabe==0.6.1 142 | - mistune==0.8.4 143 | - mock==4.0.3 144 | - multidict==5.2.0 145 | - nbclassic==0.3.4 146 | - nbclient==0.5.9 147 | - nbconvert==6.3.0 148 | - nbformat==5.1.3 149 | - nest-asyncio==1.5.4 150 | - networkx==2.6.3 151 | - nibabel==3.2.2 152 | - notebook==6.4.6 153 | - numba==0.55.0 154 | - numpy==1.21.5 155 | - nvidia-ml-py3==7.352.0 156 | - oauthlib==3.1.1 157 | - omegaconf==2.1.1 158 | - opencv-python==4.5.5.62 159 | - packaging==21.3 160 | - pandas==1.4.0 161 | - pandocfilters==1.5.0 162 | - parso==0.8.3 163 | - pexpect==4.8.0 164 | - pickleshare==0.7.5 165 | - pillow==8.4.0 166 | - platformdirs==2.4.1 167 | - portalocker==2.3.2 168 | - prometheus-client==0.12.0 169 | - prompt-toolkit==3.0.24 170 | - protobuf==3.19.1 171 | - psutil==5.9.0 172 | - ptyprocess==0.7.0 173 | - pyasn1==0.4.8 174 | - pyasn1-modules==0.2.8 175 | - pycocotools==2.0 176 | - pydeprecate==0.3.1 177 | - pydicom==2.2.2 178 | - pydot==1.4.2 179 | - pyface==7.4.0 180 | - pyflakes==2.4.0 181 | - pygments==2.11.0 182 | - pylint==2.12.2 183 | - pynndescent==0.5.6 184 | - pyparsing==3.0.6 185 | - pyqt5==5.15.6 186 | - pyqt5-qt5==5.15.2 187 | - pyqt5-sip==12.9.1 188 | - pyrsistent==0.18.0 189 | - python-dateutil==2.8.2 190 | - python-magic==0.4.25 191 | - pytorch-lightning==1.5.7 192 | - pytorch-pretrained-vit==0.0.7 193 | - pytz==2021.3 194 | - pywavelets==1.2.0 195 | - pyyaml==6.0 196 | - pyzmq==22.3.0 197 | - regex==2021.11.10 198 | - requests==2.26.0 199 | - requests-oauthlib==1.3.0 200 | - rsa==4.8 201 | - ruamel-yaml-clib==0.2.6 202 | - sacremoses==0.0.47 203 | - scikit-image==0.19.1 204 | - scikit-learn==1.0.2 205 | - scipy==1.7.3 206 | - seaborn==0.11.2 207 | - send2trash==1.8.0 208 | - sniffio==1.2.0 209 | - tabulate==0.8.9 210 | - tensorboard==2.7.0 211 | - tensorboard-data-server==0.6.1 212 | - tensorboard-plugin-wit==1.8.0 213 | - termcolor==1.1.0 214 | - terminado==0.12.1 215 | - testpath==0.5.0 216 | - threadpoolctl==3.0.0 217 | - tifffile==2021.11.2 218 | - timm==0.4.9 219 | - tokenizers==0.10.3 220 | - toml==0.10.2 221 | - torch==1.10.0 222 | - torchmetrics==0.6.2 223 | - torchvision==0.13.0a0+7bb5e41 224 | - tornado==6.1 225 | - tqdm==4.62.3 226 | - traitlets==5.1.1 227 | - traits==6.3.2 228 | - traitsui==7.2.1 229 | - transformers==4.8.1 230 | - triton==1.1.1 231 | - umap-learn==0.5.2 232 | - urllib3==1.26.7 233 | - vtk==9.1.0 234 | - wcwidth==0.2.5 235 | - webencodings==0.5.1 236 | - websocket-client==1.2.3 237 | - werkzeug==2.0.2 238 | - widgetsnbextension==3.5.2 239 | - wrapt==1.13.3 240 | - wslink==1.4.1 241 | - yacs==0.1.8 242 | - yarl==1.7.2 243 | - zipp==3.7.0 244 | prefix: ANONYMOUS_ROOTDIR/develop/miniconda3/envs/dalle 245 | -------------------------------------------------------------------------------- /util/gap_amend_std.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 3, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# % Original model (three seeds): 1.1898 1.1884 1.1892 -> 1.3149 1.2988 1.2836\n", 19 | "# % Amended model (three seeds): 0.0314 0.0551 0.0299 -> 0.7472 0.7195 0.7704\n", 20 | "\n", 21 | "original_init = [1.1898, 1.1884, 1.1892]\n", 22 | "original_trained = [1.3149, 1.2988, 1.2836]\n", 23 | "\n", 24 | "amended_init = 
[0.0314, 0.0551, 0.0299]\n", 25 | "amended_trained = [0.7472, 0.7195, 0.7704]\n" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 4, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "1.1891 0.0006\n", 38 | "1.2991 0.0128\n", 39 | "0.0388 0.0115\n", 40 | "0.7457 0.0208\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "# Report the mean and std of the gap distance over the three seeds for each setting\n", 46 | "for array in [original_init, original_trained, amended_init, amended_trained]:\n", 47 | "    print('{:.4f} {:.4f}'.format(np.mean(array), np.std(array)))\n", 48 | "    " 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [] 57 | } 58 | ], 59 | "metadata": { 60 | "interpreter": { 61 | "hash": "09c077faaa20da841f22e0f4d12b4addb73e00d9291bc78d00732f9f39794f23" 62 | }, 63 | "kernelspec": { 64 | "display_name": "Python 3.9.7 ('clip')", 65 | "language": "python", 66 | "name": "python3" 67 | }, 68 | "language_info": { 69 | "codemirror_mode": { 70 | "name": "ipython", 71 | "version": 3 72 | }, 73 | "file_extension": ".py", 74 | "mimetype": "text/x-python", 75 | "name": "python", 76 | "nbconvert_exporter": "python", 77 | "pygments_lexer": "ipython3", 78 | "version": "3.9.7" 79 | }, 80 | "orig_nbformat": 4 81 | }, 82 | "nbformat": 4, 83 | "nbformat_minor": 2 84 | } 85 | --------------------------------------------------------------------------------