├── assets
│   ├── CLIP.png
│   ├── cat.jpg
│   ├── dog.jpg
│   └── map.jpg
├── .pre-commit-config.yaml
├── requirements.txt
├── src
│   ├── config.py
│   ├── download_coco_data.py
│   ├── model_loss.py
│   ├── custom_model.py
│   └── clip_dl.py
├── README.md
├── LICENSE
├── clip_training.py
└── .gitignore

/assets/CLIP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RustamyF/clip-multimodal-ml/HEAD/assets/CLIP.png
--------------------------------------------------------------------------------
/assets/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RustamyF/clip-multimodal-ml/HEAD/assets/cat.jpg
--------------------------------------------------------------------------------
/assets/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RustamyF/clip-multimodal-ml/HEAD/assets/dog.jpg
--------------------------------------------------------------------------------
/assets/map.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RustamyF/clip-multimodal-ml/HEAD/assets/map.jpg
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
-   repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
    -   id: check-yaml
    -   id: end-of-file-fixer
    -   id: trailing-whitespace
-   repo: https://github.com/psf/black
    rev: 22.3.0
    hooks:
    -   id: black
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
Pillow==9.4.0
torch @ https://download.pytorch.org/whl/cu118/torch-2.1.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=a81b554184492005543ddc32e96469f9369d778dedd195d73bda9bed407d6589
torchvision @ https://download.pytorch.org/whl/cu118/torchvision-0.16.0%2Bcu118-cp310-cp310-linux_x86_64.whl#sha256=033712f65d45afe806676c4129dfe601ad1321d9e092df62b15847c02d4061dc
transformers==4.35.2
scikit-learn==1.2.2
datasets==2.15.0
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass


@dataclass
class Config:
    """
    Configuration class for the CLIP training script.
    """

    embed_dim: int = 512  # Embedding dimension
    transformer_embed_dim: int = 768  # Transformer embedding dimension
    max_len: int = 32  # Maximum text length
    text_model: str = "distilbert-base-multilingual-cased"  # Text model name
    epochs: int = 5  # Number of training epochs
    batch_size: int = 128  # Batch size
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CLIP
This repository presents CLIP model training and serving. CLIP, which stands for Contrastive Language-Image Pretraining, is a deep learning model developed by OpenAI in 2021. It's a powerful vision-and-language model that bridges the gap between images and their textual descriptions.
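
At its core, CLIP trains an image encoder and a text encoder so that matching image-caption pairs score higher than every mismatched pair in a batch. The snippet below is a minimal, self-contained sketch of that objective; random tensors stand in for the encoder outputs, while the actual loss lives in `src/model_loss.py` and the training loop in `clip_training.py`.

```python
import torch
import torch.nn.functional as F

# Toy batch: 8 image embeddings and 8 caption embeddings, L2-normalized
# just like the encoders in this repo normalize their outputs.
batch_size, embed_dim = 8, 512
image_embed = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
caption_embed = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)

# Pairwise cosine similarities: entry (i, j) compares caption i with image j.
similarity = caption_embed @ image_embed.T

# Matching pairs sit on the diagonal, so row i (and column i) should predict class i.
labels = torch.arange(batch_size)
loss = (
    F.cross_entropy(similarity, labels)      # caption -> image direction
    + F.cross_entropy(similarity.T, labels)  # image -> caption direction
) / 2
print(f"Contrastive loss: {loss.item():.4f}")
```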

![App Screenshot](assets/CLIP.png)

## CLIP Training
To run the CLIP training pipeline, use the following command:

```bash
python clip_training.py
```

By default, training uses the Flickr30k dataset. To train on the COCO dataset instead, change `coco_dataset = False` to `coco_dataset = True` in `clip_training.py`.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Fahim Rustamy

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/src/download_coco_data.py:
--------------------------------------------------------------------------------
import os
from zipfile import ZipFile
import urllib.request

root_dir = "datasets"
annotations_dir = os.path.join(root_dir, "annotations")
images_dir = os.path.join(root_dir, "train2017")
annotation_file = os.path.join(annotations_dir, "captions_train2017.json")

# Download caption annotation files
if not os.path.exists(annotations_dir):
    annotation_zip_url = (
        "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    )
    annotation_zip_path = os.path.join(os.path.abspath("."), "captions.zip")
    urllib.request.urlretrieve(annotation_zip_url, annotation_zip_path)
    with ZipFile(annotation_zip_path, "r") as zip_ref:
        zip_ref.extractall(annotations_dir)
    os.remove(annotation_zip_path)

# Download image files
if not os.path.exists(images_dir):
    image_zip_url = "http://images.cocodataset.org/zips/train2017.zip"
    image_zip_path = os.path.join(os.path.abspath("."), "train2017.zip")
    urllib.request.urlretrieve(image_zip_url, image_zip_path)
    with ZipFile(image_zip_path, "r") as zip_ref:
        zip_ref.extractall(images_dir)
    os.remove(image_zip_path)

print("Dataset is downloaded and extracted successfully.")
--------------------------------------------------------------------------------
/src/model_loss.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F
import torch


def CLIP_loss(logits: torch.Tensor) -> torch.Tensor:
    """
    Calculate the symmetric (image-to-text and text-to-image) cross-entropy loss.

    Args:
    - logits (torch.Tensor): The input tensor containing unnormalized logits.

    Returns:
    - torch.Tensor: The computed custom cross-entropy loss.

    Example:
    >>> logits = torch.rand((batch_size, num_classes))
    >>> loss = CLIP_loss(logits)
    """

    n = logits.shape[1]

    # Create labels tensor (matching pairs lie on the diagonal)
    labels = torch.arange(n)

    # Bring logits to the CPU so they are on the same device as the labels
    logits = logits.to("cpu")

    # Calculate cross entropy losses along axis 0 and 1
    loss_i = F.cross_entropy(logits.transpose(0, 1), labels, reduction="mean")
    loss_t = F.cross_entropy(logits, labels, reduction="mean")

    # Calculate the final loss as the average of both directions
    loss = (loss_i + loss_t) / 2

    return loss


def metrics(similarity: torch.Tensor):
    """Top-1 image-to-caption and caption-to-image retrieval accuracy for a batch."""
    y = torch.arange(len(similarity)).to(similarity.device)
    img2cap_match_idx = similarity.argmax(dim=1)
    cap2img_match_idx = similarity.argmax(dim=0)

    img_acc = (img2cap_match_idx == y).float().mean()
    cap_acc = (cap2img_match_idx == y).float().mean()

    return img_acc, cap_acc


# Example usage
# logits = torch.rand(8, 8)  # Simulated logits
# loss = CLIP_loss(logits)
# print(f"Contrastive loss: {loss}")
--------------------------------------------------------------------------------
/clip_training.py:
--------------------------------------------------------------------------------
import os
import subprocess

import torch
from torch.utils.data import DataLoader

from src.custom_model import CustomModel
from src.clip_dl import CocoDataset, Flickr30kDataset
from src.config import Config

coco_dataset = False
# Create the CLIP dataset
if coco_dataset:
    if "datasets" not in os.listdir():
        print("COCO dataset is not downloaded! Running the download script ...")
        subprocess.run(["python", "src/download_coco_data.py"])

    clip_dataset = CocoDataset(root_dir="datasets")
else:
    clip_dataset = Flickr30kDataset()


# Create the DataLoader
clip_dataloader = DataLoader(
    clip_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=4
)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Create an instance of the model
model = CustomModel().to(device)

# Define optimizer over the trainable projection heads of both encoders
optimizer = torch.optim.Adam(
    [
        {"params": model.vision_encoder.parameters()},
        {"params": model.caption_encoder.parameters()},
    ],
    lr=model.lr,
)


# Training loop
num_epochs = Config.epochs
batch_zero = True
for epoch in range(num_epochs):
    model.train()
    for batch in clip_dataloader:
        image = batch["image"].to(device)
        text = batch["caption"]
        loss, img_acc, cap_acc = model(image, text)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the loss on the very first batch as a starting reference
        if batch_zero:
            print(f"Epoch [{0}/{num_epochs}], Batch Loss: {loss.item()}")
            batch_zero = False

    # Print training statistics
    print(f"Epoch [{epoch+1}/{num_epochs}], Batch Loss: {loss.item()}")

print("Training complete.")
--------------------------------------------------------------------------------
/src/custom_model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch
from transformers import AutoModel, AutoTokenizer, BertTokenizer
from .config import Config
from .model_loss import CLIP_loss, metrics


class Projection(nn.Module):
    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embed1 = self.linear1(x)
        embed2 = self.drop(self.linear2(F.gelu(embed1)))
        embeds = self.layer_norm(embed1 + embed2)
        return embeds


class VisionEncoder(nn.Module):
    def __init__(self, d_out: int) -> None:
        super().__init__()
        # Pretrained ResNet-34 backbone with the classification head removed
        base = models.resnet34(pretrained=True)
        d_in = base.fc.in_features
        base.fc = nn.Identity()
        self.base = base
        self.projection = Projection(d_in, d_out)
        # Freeze the backbone; only the projection head is trained
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        projected_vec = self.projection(self.base(x))
        projection_len = torch.norm(projected_vec, dim=-1, keepdim=True)
        return projected_vec / projection_len


class TextEncoder(nn.Module):
    def __init__(self, d_out: int) -> None:
        super().__init__()
        self.base = AutoModel.from_pretrained(Config.text_model)
        self.projection = Projection(Config.transformer_embed_dim, d_out)
        # Freeze the transformer; only the projection head is trained
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        out = self.base(x)[0]
        out = out[:, 0, :]  # get CLS token output
        projected_vec = self.projection(out)
        projection_len = torch.norm(projected_vec, dim=-1, keepdim=True)
        return projected_vec / projection_len

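
# Example usage of the encoders above (a rough sketch with illustrative shapes,
# not part of the training pipeline): both encoders return L2-normalized
# embeddings of size Config.embed_dim, so a plain dot product acts as a cosine
# similarity between captions and images.
# vision_encoder = VisionEncoder(Config.embed_dim)
# text_encoder = TextEncoder(Config.embed_dim)
# images = torch.randn(4, 3, 224, 224)                     # batch of 4 RGB images
# token_ids = torch.randint(0, 1000, (4, Config.max_len))  # 4 tokenized captions
# image_embed = vision_encoder(images)                     # (4, Config.embed_dim)
# caption_embed = text_encoder(token_ids)                  # (4, Config.embed_dim)
# similarity = caption_embed @ image_embed.T               # (4, 4) cosine similarities
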
class CustomModel(nn.Module):
    def __init__(self, lr: float = 1e-3) -> None:
        super().__init__()
        self.vision_encoder = VisionEncoder(Config.embed_dim)
        self.caption_encoder = TextEncoder(Config.embed_dim)
        self.tokenizer = Tokenizer(
            AutoTokenizer.from_pretrained(Config.text_model, use_fast=False)
        )
        self.lr = lr
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def forward(self, images, text):
        text = self.tokenizer(text).to(self.device)

        image_embed = self.vision_encoder(images)
        caption_embed = self.caption_encoder(text["input_ids"])
        similarity = caption_embed @ image_embed.T

        loss = CLIP_loss(similarity)
        img_acc, cap_acc = metrics(similarity)
        return loss, img_acc, cap_acc


class Tokenizer:
    def __init__(self, tokenizer: BertTokenizer) -> None:
        self.tokenizer = tokenizer

    def __call__(self, x: str) -> AutoTokenizer:
        return self.tokenizer(
            x,
            max_length=Config.max_len,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
--------------------------------------------------------------------------------
/src/clip_dl.py:
--------------------------------------------------------------------------------
import os
import json
from PIL import Image
from torch.utils.data import Dataset
import collections
from torchvision import transforms
from datasets import load_dataset


class CocoDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ]
        )
        annotations_dir = os.path.join(root_dir, "annotations")
        annotation_file = os.path.join(
            annotations_dir, "annotations", "captions_train2017.json"
        )

        self.caption_list, self.image_path_list = self.load_annotations(annotation_file)

    def load_annotations(self, annotation_file):
        with open(annotation_file, "r") as f:
            annotations = json.load(f)["annotations"]

        # Group captions by the image they describe
        image_path_to_caption = collections.defaultdict(list)
        for element in annotations:
            caption = element["caption"].lower().rstrip(".")
            image_path = os.path.join(
                self.root_dir,
                "train2017",
                "train2017",
                "%012d.jpg" % (element["image_id"]),
            )
            image_path_to_caption[image_path].append(caption)
        image_paths = list(image_path_to_caption.keys())
        caption_list, image_path_list = self.training_list(
            image_paths, image_path_to_caption
        )

        return caption_list, image_path_list

    def training_list(self, image_paths, image_path_to_caption):
        captions_per_image = 2
        caption_list = []
        image_path_list = []
        for image_path in image_paths:
            captions = image_path_to_caption[image_path][:captions_per_image]
            caption_list.extend(captions)
            image_path_list.extend([image_path] * len(captions))

        return caption_list, image_path_list

    def __len__(self):
        return len(self.caption_list)

    def __getitem__(self, idx):
        image_path = self.image_path_list[idx]
        caption = self.caption_list[idx]

        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return {"image": image, "caption": caption}


class Flickr30kDataset(Dataset):
    def __init__(self):
        self.dataset = load_dataset("nlphuji/flickr30k", cache_dir="./huggingface_data")
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ]
        )
        self.cap_per_image = 2

    def __len__(self):
        return self.dataset.num_rows["test"] * self.cap_per_image

    def __getitem__(self, idx):
        original_idx = idx // self.cap_per_image
        image = self.dataset["test"][original_idx]["image"].convert("RGB")
        image = self.transform(image)

        # Each image has multiple captions; use the first cap_per_image of them
        caption = self.dataset["test"][original_idx]["caption"][
            idx % self.cap_per_image
        ]

        return {"image": image, "caption": caption}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

--------------------------------------------------------------------------------