├── ruclipsb ├── __init__.py ├── utils.py ├── dataset.py ├── ruclipsb.py └── trainer.py ├── requirements.txt ├── pictures └── Similarity.png ├── setup.py ├── LICENSE ├── README.md └── notebooks ├── ruCLIP_SB_onnx.ipynb └── finetune_ruCLIP_SB.ipynb /ruclipsb/__init__.py: -------------------------------------------------------------------------------- 1 | from .ruclipsb import * 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | timm 3 | transformers==4.12.3 4 | tqdm 5 | -------------------------------------------------------------------------------- /pictures/Similarity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/ruCLIP-SB/HEAD/pictures/Similarity.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="ruclipsb", 8 | py_modules=["ruclipsb"], 9 | version="1.0", 10 | description="", 11 | author="Shahmatov Arseniy", 12 | packages=find_packages(exclude=["tests*"]), 13 | install_requires=[ 14 | str(r) 15 | for r in pkg_resources.parse_requirements( 16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 17 | ) 18 | ], 19 | include_package_data=True, 20 | extras_require={'dev': ['pytest']}, 21 | ) 22 | -------------------------------------------------------------------------------- /ruclipsb/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def tokenize(tokenizer, texts, max_len=77): 5 | input_ids = [] 6 | attention_masks = [] 7 | for sent in texts: 8 | encoded_dict = tokenizer.encode_plus( 9 | sent, 10 | truncation=True, 11 | add_special_tokens = True, 12 | max_length = max_len, 13 | padding='max_length', 14 | return_attention_mask = True, 15 | return_tensors = 'pt', 16 | ) 17 | input_ids.append(encoded_dict['input_ids']) 18 | attention_masks.append(encoded_dict['attention_mask']) 19 | input_ids = torch.cat(input_ids, dim=0) 20 | attention_masks = torch.cat(attention_masks, dim=0) 21 | return input_ids, attention_masks 22 | 23 | 24 | def _convert_image_to_rgb(image): 25 | return image.convert("RGB") 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shahmatov Arseniy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ruCLIP-SB 2 | RuCLIP-SB (Russian Contrastive Language–Image Pretraining SWIN-BERT) is a multimodal model that computes similarity between images and Russian texts and can rank captions against pictures (and pictures against captions). Unlike other versions of the model, it uses BERT as the text encoder and a Swin Transformer as the image encoder. 3 | 4 | ## Our model achieved 37.02% zero-shot accuracy on CIFAR100 and has 39,543,907 parameters. 5 | ### Download URL: [ruCLIP-SB](https://drive.google.com/file/d/1-CghuC9TCIDyn5H3zQS6ho_TNiudzJCX/view?usp=sharing) 6 | 7 | ### Example usage: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cene555/ruCLIP-SB/blob/main/notebooks/evaluate_ruCLIP_SB_latest.ipynb) 8 | 9 | ### Finetuning: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1iGIfr9XD7wQi3rGZjmx1bmm2_qDg9qYy?usp=sharing) 10 | 11 | ### ONNX example: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cene555/ruCLIP-SB/blob/main/notebooks/ruCLIP_SB_onnx.ipynb) 12 | 13 | We trained the model on 2 million images. 14 | 15 | ![image](https://github.com/cene555/ruCLIP-SB/blob/main/pictures/Similarity.png) 16 | 17 | 18 | ### Thanks to Sber AI for their help. 19 | -------------------------------------------------------------------------------- /ruclipsb/dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from torch.utils.data import Dataset 3 | from torchvision import transforms 4 | import transformers 5 | from .utils import _convert_image_to_rgb, tokenize 6 | import pandas as pd 7 | import os 8 | from PIL import Image 9 | try: 10 | from torchvision.transforms import InterpolationMode 11 | BICUBIC = InterpolationMode.BICUBIC 12 | except ImportError: 13 | BICUBIC = Image.BICUBIC 14 | 15 | class RuCLIPSBDataset(Dataset): 16 | def __init__(self, dir, df_path, max_text_len=77): 17 | self.df = pd.read_csv(df_path) 18 | self.dir = dir 19 | self.max_text_len = max_text_len 20 | self.tokenizer = transformers.BertTokenizer.from_pretrained("cointegrated/rubert-tiny") 21 | self.transform = transforms.Compose([ 22 | transforms.Resize(224, interpolation=BICUBIC), 23 | transforms.CenterCrop(224), 24 | _convert_image_to_rgb, 25 | transforms.ToTensor(), 26 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),]) 27 | def __getitem__(self, idx): 28 | # get the image file name and its caption 29 | image_name = self.df['image_name'].iloc[idx] 30 | text = self.df['text'].iloc[idx] 31 | input_ids, attention_mask = tokenize(self.tokenizer, [text], max_len=self.max_text_len) 32 | input_ids, attention_mask = input_ids[0], attention_mask[0] 33 | image = cv2.imread(os.path.join(self.dir, image_name)) 34 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 35 | image = Image.fromarray(image) 36 | image = self.transform(image) 37 | return image, input_ids, attention_mask 38 | def __len__(self): 39 | return len(self.df) 40 |
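`RuCLIPSBDataset` above expects a CSV file with `image_name` and `text` columns plus a directory containing the referenced images; this is the same layout the `Trainer` in `ruclipsb/trainer.py` passes through. A minimal usage sketch — the CSV name, image directory, file name, and caption below are illustrative placeholders, not files shipped with the repository:

```python
# Minimal sketch of using RuCLIPSBDataset directly (paths and CSV contents are hypothetical).
import pandas as pd
import torch
from ruclipsb.dataset import RuCLIPSBDataset

# The CSV must provide 'image_name' and 'text' columns.
pd.DataFrame({
    'image_name': ['dog.jpg'],
    'text': ['собака бежит по траве'],
}).to_csv('captions.csv', index=False)

dataset = RuCLIPSBDataset(dir='images/', df_path='captions.csv', max_text_len=77)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

for image, input_ids, attention_mask in loader:
    print(image.shape)           # [batch, 3, 224, 224]
    print(input_ids.shape)       # [batch, 77]
    print(attention_mask.shape)  # [batch, 77]
    break
```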
-------------------------------------------------------------------------------- /ruclipsb/ruclipsb.py: -------------------------------------------------------------------------------- 1 | from timm import create_model as create_swin 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from transformers import BertModel 7 | 8 | class ruCLIPSB(nn.Module): 9 | def __init__(self,): 10 | super().__init__() 11 | self.visual = swin = create_swin( 12 | 'swin_tiny_patch4_window7_224', pretrained=True, num_classes=0, in_chans=3) #out 768 13 | self.transformer = BertModel.from_pretrained("cointegrated/rubert-tiny") 14 | self.final_ln = nn.Linear(312, 768) 15 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 16 | 17 | @property 18 | def dtype(self): 19 | return self.visual.patch_embed.proj.weight.dtype 20 | 21 | def encode_image(self, image): 22 | return self.visual(image.type(self.dtype)) 23 | 24 | def encode_text(self, input_ids, attention_mask): 25 | x = self.transformer(input_ids=input_ids, attention_mask=attention_mask) 26 | x = x.last_hidden_state[:, 0, :] 27 | x = self.final_ln(x) 28 | return x 29 | 30 | def forward(self, image, input_ids, attention_mask): 31 | image_features = self.encode_image(image) 32 | text_features = self.encode_text(input_ids, attention_mask) 33 | 34 | # normalized features 35 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 36 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 37 | 38 | # cosine similarity as logits 39 | logit_scale = self.logit_scale.exp() 40 | logits_per_image = logit_scale * image_features @ text_features.t() 41 | logits_per_text = logits_per_image.t() 42 | 43 | return logits_per_image, logits_per_text 44 | -------------------------------------------------------------------------------- /ruclipsb/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import get_linear_schedule_with_warmup 3 | from tqdm.auto import tqdm 4 | from .dataset import RuCLIPSBDataset 5 | 6 | 7 | class Trainer: 8 | def __init__(self, train_dataframe, train_dir, 9 | val_dataframe=None, val_dir=None, learning_rate=1e-4, 10 | freeze_image_encoder=True, freeze_text_encoder=False, max_text_len=77, 11 | train_batch_size=64, val_batch_size=64, num_workers=2, 12 | weight_decay=1e-4): 13 | self.train_dataframe = train_dataframe 14 | self.train_dir = train_dir 15 | self.val_dataframe = val_dataframe 16 | self.val_dir = val_dir 17 | self.learning_rate = learning_rate 18 | self.freeze_image_encoder = freeze_image_encoder 19 | self.freeze_text_encoder = freeze_text_encoder 20 | self.max_text_len = max_text_len 21 | self.train_batch_size = train_batch_size 22 | self.val_batch_size = val_batch_size 23 | self.num_workers = num_workers 24 | self.weight_decay = weight_decay 25 | 26 | def train_model(self, model, epochs_num=1, device='cuda', verbose=10): 27 | 28 | is_val = self.val_dataframe is not None and self.val_dir is not None 29 | 30 | model.to(device) 31 | 32 | train_dataset = RuCLIPSBDataset(self.train_dir, self.train_dataframe, self.max_text_len) 33 | 34 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 35 | batch_size=self.train_batch_size, 36 | shuffle=True, 37 | pin_memory=True, 38 | num_workers=self.num_workers) 39 | 40 | if is_val: 41 | val_dataset = RuCLIPSBDataset(self.val_dir, self.val_dataframe, self.max_text_len) 42 | val_loader = 
torch.utils.data.DataLoader(dataset=val_dataset, 43 | batch_size=self.val_batch_size, 44 | shuffle=False, 45 | pin_memory=True, 46 | num_workers=self.num_workers) 47 | 48 | for i, child in enumerate(model.children()):  # children order: [0] = visual (Swin), [1] = transformer (BERT) 49 | if (i == 0 and self.freeze_image_encoder) or (i == 1 and self.freeze_text_encoder): 50 | for param in child.parameters(): 51 | param.requires_grad = False 52 | 53 | loss_img = torch.nn.CrossEntropyLoss() 54 | loss_txt = torch.nn.CrossEntropyLoss() 55 | 56 | optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate, betas=(0.9,0.98), eps=1e-8, weight_decay=self.weight_decay) 57 | total_steps = len(train_loader) * epochs_num 58 | scheduler = get_linear_schedule_with_warmup(optimizer, 59 | num_warmup_steps=0, 60 | num_training_steps=total_steps) 61 | for epoch in range(epochs_num): 62 | model.train()  # re-enable training mode each epoch (validation below switches the model to eval) 63 | print(f'start training epoch {epoch}') 64 | 65 | for i, batch in enumerate(tqdm(train_loader)): 66 | optimizer.zero_grad() 67 | images = batch[0].to(device) 68 | input_ids = batch[1].to(device) 69 | attention_mask = batch[2].to(device) 70 | 71 | logits_per_image, logits_per_text = model(images, input_ids, attention_mask) 72 | ground_truth = torch.arange(batch[1].shape[0], dtype=torch.long).to(device) 73 | img_l = loss_img(logits_per_image, ground_truth) 74 | text_l = loss_txt(logits_per_text, ground_truth) 75 | total_loss = (img_l + text_l)/2 76 | if i % verbose == 0: 77 | print(f'{i}/{len(train_loader)} total_loss {total_loss}') 78 | total_loss.backward() 79 | torch.nn.utils.clip_grad_norm_(model.parameters(), 100) 80 | optimizer.step() 81 | scheduler.step() 82 | if is_val: 83 | print(f'start val epoch {epoch}') 84 | total_loss = 0 85 | model.eval() 86 | with torch.no_grad(): 87 | for i, batch in enumerate(tqdm(val_loader)): 88 | images = batch[0].to(device) 89 | input_ids = batch[1].to(device) 90 | attention_mask = batch[2].to(device) 91 | 92 | logits_per_image, logits_per_text = model(images, input_ids, attention_mask) 93 | ground_truth = torch.arange(batch[1].shape[0], dtype=torch.long).to(device) 94 | img_l = loss_img(logits_per_image, ground_truth).item() 95 | text_l = loss_txt(logits_per_text, ground_truth).item() 96 | total_loss += (img_l + text_l)/2 97 | print(f'val loss = {total_loss / len(val_loader)}') 98 | return model 99 | -------------------------------------------------------------------------------- /notebooks/ruCLIP_SB_onnx.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "ruCLIP_SB_onnx.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [ 9 | "WWXCt_2NLhN_", 10 | "PHb4CAoRL3qC", 11 | "re2sSYAYO3D-", 12 | "ithu4-z0PIm5", 13 | "FWm0GAhWPzSW" 14 | ], 15 | "machine_shape": "hm" 16 | }, 17 | "kernelspec": { 18 | "name": "python3", 19 | "display_name": "Python 3" 20 | }, 21 | "language_info": { 22 | "name": "python" 23 | }, 24 | "accelerator": "GPU" 25 | }, 26 | "cells": [ 27 | { 28 | "cell_type": "markdown", 29 | "source": [ 30 | "# [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cene555/ruCLIP-SB/blob/main/notebooks/ruCLIP_SB_onnx.ipynb)" 31 | ], 32 | "metadata": { 33 | "id": "JsWuTduwaagq" 34 | } 35 | }, 36 | { 37 | "cell_type": "code", 38 | "source": [ 39 | "#@title Allowed Resources\n", 40 | "import multiprocessing\n", 41 | "import torch\n", 42 | "from psutil import virtual_memory\n", 43 | "\n", 44 | "ram_gb = 
round(virtual_memory().total / 1024**3, 1)\n", 45 | "\n", 46 | "print('CPU:', multiprocessing.cpu_count())\n", 47 | "print('RAM GB:', ram_gb)\n", 48 | "print(\"PyTorch version:\", torch.__version__)\n", 49 | "print(\"CUDA version:\", torch.version.cuda)\n", 50 | "print(\"cuDNN version:\", torch.backends.cudnn.version())\n", 51 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 52 | "print(\"device:\", device.type)\n", 53 | "\n", 54 | "!nvidia-smi" 55 | ], 56 | "metadata": { 57 | "cellView": "form", 58 | "colab": { 59 | "base_uri": "https://localhost:8080/" 60 | }, 61 | "id": "6xdy_cPJEYXV", 62 | "outputId": "18168553-b4ec-41e9-f966-9efa9db2fd33" 63 | }, 64 | "execution_count": 3, 65 | "outputs": [ 66 | { 67 | "output_type": "stream", 68 | "name": "stdout", 69 | "text": [ 70 | "CPU: 2\n", 71 | "RAM GB: 12.7\n", 72 | "PyTorch version: 1.10.0+cu111\n", 73 | "CUDA version: 11.1\n", 74 | "cuDNN version: 8005\n", 75 | "device: cuda\n", 76 | "Tue Jan 25 17:45:47 2022 \n", 77 | "+-----------------------------------------------------------------------------+\n", 78 | "| NVIDIA-SMI 495.46 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", 79 | "|-------------------------------+----------------------+----------------------+\n", 80 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 81 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 82 | "| | | MIG M. |\n", 83 | "|===============================+======================+======================|\n", 84 | "| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n", 85 | "| N/A 34C P8 28W / 149W | 3MiB / 11441MiB | 0% Default |\n", 86 | "| | | N/A |\n", 87 | "+-------------------------------+----------------------+----------------------+\n", 88 | " \n", 89 | "+-----------------------------------------------------------------------------+\n", 90 | "| Processes: |\n", 91 | "| GPU GI CI PID Type Process name GPU Memory |\n", 92 | "| ID ID Usage |\n", 93 | "|=============================================================================|\n", 94 | "| No running processes found |\n", 95 | "+-----------------------------------------------------------------------------+\n" 96 | ] 97 | } 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "source": [ 103 | "## Install requirements" 104 | ], 105 | "metadata": { 106 | "id": "WWXCt_2NLhN_" 107 | } 108 | }, 109 | { 110 | "cell_type": "code", 111 | "source": [ 112 | "%%capture\n", 113 | "!pip install git+https://github.com/cene555/ruCLIP-SB.git\n", 114 | "!pip install pymorphy2\n", 115 | "!gdown -O ruCLIP-SB.pkl https://drive.google.com/uc?id=1-CghuC9TCIDyn5H3zQS6ho_TNiudzJCX\n", 116 | "\n", 117 | "!pip install git+https://github.com/Lednik7/CLIP-ONNX.git\n", 118 | "!pip install git+https://github.com/openai/CLIP.git\n", 119 | "!pip install onnxruntime-gpu\n", 120 | "\n", 121 | "!wget -c -O CLIP.png https://github.com/openai/CLIP/blob/main/CLIP.png?raw=true" 122 | ], 123 | "metadata": { 124 | "id": "FWEEtd7Vryaf" 125 | }, 126 | "execution_count": 2, 127 | "outputs": [] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "source": [ 132 | "## Import libraries" 133 | ], 134 | "metadata": { 135 | "id": "PHb4CAoRL3qC" 136 | } 137 | }, 138 | { 139 | "cell_type": "code", 140 | "source": [ 141 | "import torch\n", 142 | "from torchvision import transforms\n", 143 | "import transformers\n", 144 | "from transformers import BertTokenizer\n", 145 | "from ruclipsb import ruCLIPSB\n", 146 | "from ruclipsb.utils import tokenize, _convert_image_to_rgb\n", 147 
| "from PIL import ImageCms, Image\n", 148 | "import cv2\n", 149 | "import numpy as np\n", 150 | "try:\n", 151 | " from torchvision.transforms import InterpolationMode\n", 152 | " BICUBIC = InterpolationMode.BICUBIC\n", 153 | "except ImportError:\n", 154 | " BICUBIC = Image.BICUBIC" 155 | ], 156 | "metadata": { 157 | "id": "cznZ7ozDL5-M" 158 | }, 159 | "execution_count": 1, 160 | "outputs": [] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "source": [ 165 | "import warnings\n", 166 | "\n", 167 | "warnings.filterwarnings(\"ignore\", category=UserWarning)" 168 | ], 169 | "metadata": { 170 | "id": "Q1JZTGGvWVNC" 171 | }, 172 | "execution_count": 3, 173 | "outputs": [] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "source": [ 178 | "torch.manual_seed(1)\n", 179 | "device = torch.device('cpu')" 180 | ], 181 | "metadata": { 182 | "id": "QXNtl3gNRiRr" 183 | }, 184 | "execution_count": 6, 185 | "outputs": [] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "source": [ 190 | "## Load RuCLIP-SB model" 191 | ], 192 | "metadata": { 193 | "id": "ithu4-z0PIm5" 194 | } 195 | }, 196 | { 197 | "cell_type": "code", 198 | "source": [ 199 | "model = ruCLIPSB()\n", 200 | "model.load_state_dict(torch.load('ruCLIP-SB.pkl', map_location=device))\n", 201 | "model = model.half().to(device)\n", 202 | "\n", 203 | "model = model.eval()\n", 204 | "for x in model.parameters(): x.requires_grad = False\n", 205 | "torch.cuda.empty_cache()" 206 | ], 207 | "metadata": { 208 | "colab": { 209 | "base_uri": "https://localhost:8080/" 210 | }, 211 | "id": "RWrR6BzhPKji", 212 | "outputId": "1f6e05e8-2f5e-401d-a7be-3e9f2be0a71c" 213 | }, 214 | "execution_count": 7, 215 | "outputs": [ 216 | { 217 | "output_type": "stream", 218 | "name": "stderr", 219 | "text": [ 220 | "Some weights of the model checkpoint at cointegrated/rubert-tiny were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']\n", 221 | "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 222 | "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" 223 | ] 224 | } 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "source": [ 230 | "tokenizer = BertTokenizer.from_pretrained(\"cointegrated/rubert-tiny\")" 231 | ], 232 | "metadata": { 233 | "id": "0eniQ2HTQggY" 234 | }, 235 | "execution_count": null, 236 | "outputs": [] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "source": [ 241 | "transform = transforms.Compose([\n", 242 | " transforms.Resize(224),\n", 243 | " transforms.CenterCrop(224),\n", 244 | " _convert_image_to_rgb,\n", 245 | " transforms.ToTensor(),\n", 246 | " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 247 | " std=[0.229, 0.224, 0.225]),])" 248 | ], 249 | "metadata": { 250 | "id": "RaFNlbHpQj7i" 251 | }, 252 | "execution_count": 9, 253 | "outputs": [] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "source": [ 258 | "## Prepare functions" 259 | ], 260 | "metadata": { 261 | "id": "3PlskeIMYmxk" 262 | } 263 | }, 264 | { 265 | "cell_type": "code", 266 | "source": [ 267 | "# batch first\n", 268 | "image = transform(Image.open(\"CLIP.png\")).unsqueeze(0).cpu() # [1, 3, 224, 224]\n", 269 | "image_onnx = image.detach().cpu().numpy().astype(np.float32)\n", 270 | "\n", 271 | "# batch first\n", 272 | "texts = ['диаграмма', 'собака', 'кошка']\n", 273 | "text_tokens, attention_mask = tokenize(tokenizer, texts, 77)\n", 274 | "text_tokens, attention_mask = text_tokens.cpu(), attention_mask.cpu() # [3, 77]\n", 275 | "text_onnx = torch.stack([text_tokens, attention_mask]).detach().cpu().numpy().astype(np.int64)" 276 | ], 277 | "metadata": { 278 | "id": "H3mN8xVnWj9M" 279 | }, 280 | "execution_count": 12, 281 | "outputs": [] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "source": [ 286 | "class Textual(torch.nn.Module):\n", 287 | " def __init__(self, model):\n", 288 | " super().__init__()\n", 289 | " self.model = model\n", 290 | "\n", 291 | " def forward(self, input_data):\n", 292 | " input_ids, attention_mask = input_data\n", 293 | " x = self.model.transformer(input_ids=input_ids, attention_mask=attention_mask)\n", 294 | " x = x.last_hidden_state[:, 0, :]\n", 295 | " x = self.model.final_ln(x)\n", 296 | " return x" 297 | ], 298 | "metadata": { 299 | "id": "DcnycYuYF6w1" 300 | }, 301 | "execution_count": 14, 302 | "outputs": [] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "source": [ 307 | "## Convert RuCLIP-SB model to ONNX" 308 | ], 309 | "metadata": { 310 | "id": "WmITDkxDYsv7" 311 | } 312 | }, 313 | { 314 | "cell_type": "code", 315 | "source": [ 316 | "from clip_onnx import clip_onnx\n", 317 | "\n", 318 | "def convert_textual(self, dummy_input):\n", 319 | " textual = Textual(self.model)\n", 320 | " torch.onnx.export(textual, dummy_input, self.textual_path,\n", 321 | " input_names=['input'], output_names=['output'],\n", 322 | " export_params=True, verbose=False, opset_version=14,\n", 323 | " do_constant_folding=True,\n", 324 | " dynamic_axes={'input': {1: 'batch_size'}, 'output': {0: 'batch_size'}})\n", 325 | " self.onnx_checker(self.textual_path)\n", 326 | "\n", 327 | "clip_onnx.convert_textual = convert_textual\n", 328 | "\n", 329 | "visual_path = \"clip_visual.onnx\"\n", 330 | "textual_path = \"clip_textual.onnx\"\n", 331 | "\n", 332 | "dummy_input_text = 
torch.stack([text_tokens, attention_mask]).detach().cpu()" 333 | ], 334 | "metadata": { 335 | "id": "zUUu9wCZCFEg" 336 | }, 337 | "execution_count": null, 338 | "outputs": [] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "source": [ 343 | "onnx_model = clip_onnx(model.float().cpu(), visual_path=visual_path, textual_path=textual_path)\n", 344 | "onnx_model.convert2onnx(image, dummy_input_text, verbose=True)" 345 | ], 346 | "metadata": { 347 | "id": "-TBIiGzwYKMn" 348 | }, 349 | "execution_count": null, 350 | "outputs": [] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "source": [ 355 | "## [ONNX] CUDA inference mode" 356 | ], 357 | "metadata": { 358 | "id": "_KWnKarOY6t-" 359 | } 360 | }, 361 | { 362 | "cell_type": "code", 363 | "source": [ 364 | "# ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']\n", 365 | "onnx_model.start_sessions(providers=[\"CUDAExecutionProvider\"]) # cuda mode" 366 | ], 367 | "metadata": { 368 | "id": "06Y5KogAY6Dj" 369 | }, 370 | "execution_count": null, 371 | "outputs": [] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "source": [ 376 | "image_features = onnx_model.encode_image(image_onnx)\n", 377 | "text_features = onnx_model.encode_text(text_onnx)\n", 378 | "\n", 379 | "logits_per_image, logits_per_text = onnx_model(image_onnx, text_onnx)\n", 380 | "probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()\n", 381 | "\n", 382 | "print(\"Label probs:\", probs) # [[0.9844646 0.01167088 0.00386453]]" 383 | ], 384 | "metadata": { 385 | "colab": { 386 | "base_uri": "https://localhost:8080/" 387 | }, 388 | "id": "IM_vMne7MGEu", 389 | "outputId": "a6ecd7d4-ed50-48e4-b098-82aedf321fc7" 390 | }, 391 | "execution_count": 16, 392 | "outputs": [ 393 | { 394 | "output_type": "stream", 395 | "name": "stdout", 396 | "text": [ 397 | "Label probs: [[0.9844646 0.01167088 0.00386453]]\n" 398 | ] 399 | } 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "source": [ 405 | "%timeit onnx_model.encode_image(image_onnx)" 406 | ], 407 | "metadata": { 408 | "colab": { 409 | "base_uri": "https://localhost:8080/" 410 | }, 411 | "id": "jCzqP0wDYXdt", 412 | "outputId": "9a1ea435-f4d8-4f16-b0f0-712bdd457947" 413 | }, 414 | "execution_count": 17, 415 | "outputs": [ 416 | { 417 | "output_type": "stream", 418 | "name": "stdout", 419 | "text": [ 420 | "10 loops, best of 5: 18 ms per loop\n" 421 | ] 422 | } 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "source": [ 428 | "%timeit onnx_model.encode_text(text_onnx)" 429 | ], 430 | "metadata": { 431 | "colab": { 432 | "base_uri": "https://localhost:8080/" 433 | }, 434 | "id": "4zYofeqqYco8", 435 | "outputId": "ed1f7195-eb0a-4982-eab1-689e30e2bec6" 436 | }, 437 | "execution_count": 18, 438 | "outputs": [ 439 | { 440 | "output_type": "stream", 441 | "name": "stdout", 442 | "text": [ 443 | "100 loops, best of 5: 2.76 ms per loop\n" 444 | ] 445 | } 446 | ] 447 | } 448 | ] 449 | } -------------------------------------------------------------------------------- /notebooks/finetune_ruCLIP_SB.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "finetune_ruCLIP-SB.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [ 9 | "VebmNl7x48k9", 10 | "edtendrV4-c2", 11 | "-7FO-aQTJT3X", 12 | "1fbZkwoZP7Lr", 13 | "O1DCPGa_P-Ut" 14 | ], 15 | "machine_shape": "hm" 16 | }, 17 | "kernelspec": { 18 | "name": "python3", 19 | "display_name": "Python 3" 20 | }, 21 | 
"language_info": { 22 | "name": "python" 23 | }, 24 | "accelerator": "GPU", 25 | "widgets": { 26 | "application/vnd.jupyter.widget-state+json": { 27 | "7091636af08b41c88eb8a9f9179551ee": { 28 | "model_module": "@jupyter-widgets/controls", 29 | "model_name": "HBoxModel", 30 | "model_module_version": "1.5.0", 31 | "state": { 32 | "_view_name": "HBoxView", 33 | "_dom_classes": [], 34 | "_model_name": "HBoxModel", 35 | "_view_module": "@jupyter-widgets/controls", 36 | "_model_module_version": "1.5.0", 37 | "_view_count": null, 38 | "_view_module_version": "1.5.0", 39 | "box_style": "", 40 | "layout": "IPY_MODEL_d8c24ad5ba9f431da269aa99537d3bd5", 41 | "_model_module": "@jupyter-widgets/controls", 42 | "children": [ 43 | "IPY_MODEL_c62ecf993b0e4ddb810d4d4b82e50ef5", 44 | "IPY_MODEL_11b7b2ebb4294b02b15506446507a637", 45 | "IPY_MODEL_b7e3962196584ac4a1a325ad6af7239a" 46 | ] 47 | } 48 | }, 49 | "d8c24ad5ba9f431da269aa99537d3bd5": { 50 | "model_module": "@jupyter-widgets/base", 51 | "model_name": "LayoutModel", 52 | "model_module_version": "1.2.0", 53 | "state": { 54 | "_view_name": "LayoutView", 55 | "grid_template_rows": null, 56 | "right": null, 57 | "justify_content": null, 58 | "_view_module": "@jupyter-widgets/base", 59 | "overflow": null, 60 | "_model_module_version": "1.2.0", 61 | "_view_count": null, 62 | "flex_flow": null, 63 | "width": null, 64 | "min_width": null, 65 | "border": null, 66 | "align_items": null, 67 | "bottom": null, 68 | "_model_module": "@jupyter-widgets/base", 69 | "top": null, 70 | "grid_column": null, 71 | "overflow_y": null, 72 | "overflow_x": null, 73 | "grid_auto_flow": null, 74 | "grid_area": null, 75 | "grid_template_columns": null, 76 | "flex": null, 77 | "_model_name": "LayoutModel", 78 | "justify_items": null, 79 | "grid_row": null, 80 | "max_height": null, 81 | "align_content": null, 82 | "visibility": null, 83 | "align_self": null, 84 | "height": null, 85 | "min_height": null, 86 | "padding": null, 87 | "grid_auto_rows": null, 88 | "grid_gap": null, 89 | "max_width": null, 90 | "order": null, 91 | "_view_module_version": "1.2.0", 92 | "grid_template_areas": null, 93 | "object_position": null, 94 | "object_fit": null, 95 | "grid_auto_columns": null, 96 | "margin": null, 97 | "display": null, 98 | "left": null 99 | } 100 | }, 101 | "c62ecf993b0e4ddb810d4d4b82e50ef5": { 102 | "model_module": "@jupyter-widgets/controls", 103 | "model_name": "HTMLModel", 104 | "model_module_version": "1.5.0", 105 | "state": { 106 | "_view_name": "HTMLView", 107 | "style": "IPY_MODEL_fe468022a67f48cb8d8f1c05621a6cb7", 108 | "_dom_classes": [], 109 | "description": "", 110 | "_model_name": "HTMLModel", 111 | "placeholder": "​", 112 | "_view_module": "@jupyter-widgets/controls", 113 | "_model_module_version": "1.5.0", 114 | "value": "Downloading: 100%", 115 | "_view_count": null, 116 | "_view_module_version": "1.5.0", 117 | "description_tooltip": null, 118 | "_model_module": "@jupyter-widgets/controls", 119 | "layout": "IPY_MODEL_c0ec008b526a4c3486b4638ec18e66ea" 120 | } 121 | }, 122 | "11b7b2ebb4294b02b15506446507a637": { 123 | "model_module": "@jupyter-widgets/controls", 124 | "model_name": "FloatProgressModel", 125 | "model_module_version": "1.5.0", 126 | "state": { 127 | "_view_name": "ProgressView", 128 | "style": "IPY_MODEL_81e0631c175c444e9603918fad9b504d", 129 | "_dom_classes": [], 130 | "description": "", 131 | "_model_name": "FloatProgressModel", 132 | "bar_style": "success", 133 | "max": 632, 134 | "_view_module": "@jupyter-widgets/controls", 135 | "_model_module_version": 
"1.5.0", 136 | "value": 632, 137 | "_view_count": null, 138 | "_view_module_version": "1.5.0", 139 | "orientation": "horizontal", 140 | "min": 0, 141 | "description_tooltip": null, 142 | "_model_module": "@jupyter-widgets/controls", 143 | "layout": "IPY_MODEL_4d9d5b61bb2847848699313b8540aa9c" 144 | } 145 | }, 146 | "b7e3962196584ac4a1a325ad6af7239a": { 147 | "model_module": "@jupyter-widgets/controls", 148 | "model_name": "HTMLModel", 149 | "model_module_version": "1.5.0", 150 | "state": { 151 | "_view_name": "HTMLView", 152 | "style": "IPY_MODEL_84abaf38fd6d422588f9bf7226e9378b", 153 | "_dom_classes": [], 154 | "description": "", 155 | "_model_name": "HTMLModel", 156 | "placeholder": "​", 157 | "_view_module": "@jupyter-widgets/controls", 158 | "_model_module_version": "1.5.0", 159 | "value": " 632/632 [00:00<00:00, 17.2kB/s]", 160 | "_view_count": null, 161 | "_view_module_version": "1.5.0", 162 | "description_tooltip": null, 163 | "_model_module": "@jupyter-widgets/controls", 164 | "layout": "IPY_MODEL_f24f6be70e314f2689017533011611eb" 165 | } 166 | }, 167 | "fe468022a67f48cb8d8f1c05621a6cb7": { 168 | "model_module": "@jupyter-widgets/controls", 169 | "model_name": "DescriptionStyleModel", 170 | "model_module_version": "1.5.0", 171 | "state": { 172 | "_view_name": "StyleView", 173 | "_model_name": "DescriptionStyleModel", 174 | "description_width": "", 175 | "_view_module": "@jupyter-widgets/base", 176 | "_model_module_version": "1.5.0", 177 | "_view_count": null, 178 | "_view_module_version": "1.2.0", 179 | "_model_module": "@jupyter-widgets/controls" 180 | } 181 | }, 182 | "c0ec008b526a4c3486b4638ec18e66ea": { 183 | "model_module": "@jupyter-widgets/base", 184 | "model_name": "LayoutModel", 185 | "model_module_version": "1.2.0", 186 | "state": { 187 | "_view_name": "LayoutView", 188 | "grid_template_rows": null, 189 | "right": null, 190 | "justify_content": null, 191 | "_view_module": "@jupyter-widgets/base", 192 | "overflow": null, 193 | "_model_module_version": "1.2.0", 194 | "_view_count": null, 195 | "flex_flow": null, 196 | "width": null, 197 | "min_width": null, 198 | "border": null, 199 | "align_items": null, 200 | "bottom": null, 201 | "_model_module": "@jupyter-widgets/base", 202 | "top": null, 203 | "grid_column": null, 204 | "overflow_y": null, 205 | "overflow_x": null, 206 | "grid_auto_flow": null, 207 | "grid_area": null, 208 | "grid_template_columns": null, 209 | "flex": null, 210 | "_model_name": "LayoutModel", 211 | "justify_items": null, 212 | "grid_row": null, 213 | "max_height": null, 214 | "align_content": null, 215 | "visibility": null, 216 | "align_self": null, 217 | "height": null, 218 | "min_height": null, 219 | "padding": null, 220 | "grid_auto_rows": null, 221 | "grid_gap": null, 222 | "max_width": null, 223 | "order": null, 224 | "_view_module_version": "1.2.0", 225 | "grid_template_areas": null, 226 | "object_position": null, 227 | "object_fit": null, 228 | "grid_auto_columns": null, 229 | "margin": null, 230 | "display": null, 231 | "left": null 232 | } 233 | }, 234 | "81e0631c175c444e9603918fad9b504d": { 235 | "model_module": "@jupyter-widgets/controls", 236 | "model_name": "ProgressStyleModel", 237 | "model_module_version": "1.5.0", 238 | "state": { 239 | "_view_name": "StyleView", 240 | "_model_name": "ProgressStyleModel", 241 | "description_width": "", 242 | "_view_module": "@jupyter-widgets/base", 243 | "_model_module_version": "1.5.0", 244 | "_view_count": null, 245 | "_view_module_version": "1.2.0", 246 | "bar_color": null, 247 | "_model_module": 
"@jupyter-widgets/controls" 248 | } 249 | }, 250 | "4d9d5b61bb2847848699313b8540aa9c": { 251 | "model_module": "@jupyter-widgets/base", 252 | "model_name": "LayoutModel", 253 | "model_module_version": "1.2.0", 254 | "state": { 255 | "_view_name": "LayoutView", 256 | "grid_template_rows": null, 257 | "right": null, 258 | "justify_content": null, 259 | "_view_module": "@jupyter-widgets/base", 260 | "overflow": null, 261 | "_model_module_version": "1.2.0", 262 | "_view_count": null, 263 | "flex_flow": null, 264 | "width": null, 265 | "min_width": null, 266 | "border": null, 267 | "align_items": null, 268 | "bottom": null, 269 | "_model_module": "@jupyter-widgets/base", 270 | "top": null, 271 | "grid_column": null, 272 | "overflow_y": null, 273 | "overflow_x": null, 274 | "grid_auto_flow": null, 275 | "grid_area": null, 276 | "grid_template_columns": null, 277 | "flex": null, 278 | "_model_name": "LayoutModel", 279 | "justify_items": null, 280 | "grid_row": null, 281 | "max_height": null, 282 | "align_content": null, 283 | "visibility": null, 284 | "align_self": null, 285 | "height": null, 286 | "min_height": null, 287 | "padding": null, 288 | "grid_auto_rows": null, 289 | "grid_gap": null, 290 | "max_width": null, 291 | "order": null, 292 | "_view_module_version": "1.2.0", 293 | "grid_template_areas": null, 294 | "object_position": null, 295 | "object_fit": null, 296 | "grid_auto_columns": null, 297 | "margin": null, 298 | "display": null, 299 | "left": null 300 | } 301 | }, 302 | "84abaf38fd6d422588f9bf7226e9378b": { 303 | "model_module": "@jupyter-widgets/controls", 304 | "model_name": "DescriptionStyleModel", 305 | "model_module_version": "1.5.0", 306 | "state": { 307 | "_view_name": "StyleView", 308 | "_model_name": "DescriptionStyleModel", 309 | "description_width": "", 310 | "_view_module": "@jupyter-widgets/base", 311 | "_model_module_version": "1.5.0", 312 | "_view_count": null, 313 | "_view_module_version": "1.2.0", 314 | "_model_module": "@jupyter-widgets/controls" 315 | } 316 | }, 317 | "f24f6be70e314f2689017533011611eb": { 318 | "model_module": "@jupyter-widgets/base", 319 | "model_name": "LayoutModel", 320 | "model_module_version": "1.2.0", 321 | "state": { 322 | "_view_name": "LayoutView", 323 | "grid_template_rows": null, 324 | "right": null, 325 | "justify_content": null, 326 | "_view_module": "@jupyter-widgets/base", 327 | "overflow": null, 328 | "_model_module_version": "1.2.0", 329 | "_view_count": null, 330 | "flex_flow": null, 331 | "width": null, 332 | "min_width": null, 333 | "border": null, 334 | "align_items": null, 335 | "bottom": null, 336 | "_model_module": "@jupyter-widgets/base", 337 | "top": null, 338 | "grid_column": null, 339 | "overflow_y": null, 340 | "overflow_x": null, 341 | "grid_auto_flow": null, 342 | "grid_area": null, 343 | "grid_template_columns": null, 344 | "flex": null, 345 | "_model_name": "LayoutModel", 346 | "justify_items": null, 347 | "grid_row": null, 348 | "max_height": null, 349 | "align_content": null, 350 | "visibility": null, 351 | "align_self": null, 352 | "height": null, 353 | "min_height": null, 354 | "padding": null, 355 | "grid_auto_rows": null, 356 | "grid_gap": null, 357 | "max_width": null, 358 | "order": null, 359 | "_view_module_version": "1.2.0", 360 | "grid_template_areas": null, 361 | "object_position": null, 362 | "object_fit": null, 363 | "grid_auto_columns": null, 364 | "margin": null, 365 | "display": null, 366 | "left": null 367 | } 368 | }, 369 | "d60078ecf0a743e7a28f523c1ed105b6": { 370 | "model_module": 
"@jupyter-widgets/controls", 371 | "model_name": "HBoxModel", 372 | "model_module_version": "1.5.0", 373 | "state": { 374 | "_view_name": "HBoxView", 375 | "_dom_classes": [], 376 | "_model_name": "HBoxModel", 377 | "_view_module": "@jupyter-widgets/controls", 378 | "_model_module_version": "1.5.0", 379 | "_view_count": null, 380 | "_view_module_version": "1.5.0", 381 | "box_style": "", 382 | "layout": "IPY_MODEL_107f5d59a1504e76a205fc54ef302376", 383 | "_model_module": "@jupyter-widgets/controls", 384 | "children": [ 385 | "IPY_MODEL_9e59a739116d4ff6a2314cfffbf4594e", 386 | "IPY_MODEL_f358ac594bf34885985324fb1f10f371", 387 | "IPY_MODEL_a95b5e967c574b15a40da9497df022b6" 388 | ] 389 | } 390 | }, 391 | "107f5d59a1504e76a205fc54ef302376": { 392 | "model_module": "@jupyter-widgets/base", 393 | "model_name": "LayoutModel", 394 | "model_module_version": "1.2.0", 395 | "state": { 396 | "_view_name": "LayoutView", 397 | "grid_template_rows": null, 398 | "right": null, 399 | "justify_content": null, 400 | "_view_module": "@jupyter-widgets/base", 401 | "overflow": null, 402 | "_model_module_version": "1.2.0", 403 | "_view_count": null, 404 | "flex_flow": null, 405 | "width": null, 406 | "min_width": null, 407 | "border": null, 408 | "align_items": null, 409 | "bottom": null, 410 | "_model_module": "@jupyter-widgets/base", 411 | "top": null, 412 | "grid_column": null, 413 | "overflow_y": null, 414 | "overflow_x": null, 415 | "grid_auto_flow": null, 416 | "grid_area": null, 417 | "grid_template_columns": null, 418 | "flex": null, 419 | "_model_name": "LayoutModel", 420 | "justify_items": null, 421 | "grid_row": null, 422 | "max_height": null, 423 | "align_content": null, 424 | "visibility": null, 425 | "align_self": null, 426 | "height": null, 427 | "min_height": null, 428 | "padding": null, 429 | "grid_auto_rows": null, 430 | "grid_gap": null, 431 | "max_width": null, 432 | "order": null, 433 | "_view_module_version": "1.2.0", 434 | "grid_template_areas": null, 435 | "object_position": null, 436 | "object_fit": null, 437 | "grid_auto_columns": null, 438 | "margin": null, 439 | "display": null, 440 | "left": null 441 | } 442 | }, 443 | "9e59a739116d4ff6a2314cfffbf4594e": { 444 | "model_module": "@jupyter-widgets/controls", 445 | "model_name": "HTMLModel", 446 | "model_module_version": "1.5.0", 447 | "state": { 448 | "_view_name": "HTMLView", 449 | "style": "IPY_MODEL_18aed82927c945eb926e48bd17bbc823", 450 | "_dom_classes": [], 451 | "description": "", 452 | "_model_name": "HTMLModel", 453 | "placeholder": "​", 454 | "_view_module": "@jupyter-widgets/controls", 455 | "_model_module_version": "1.5.0", 456 | "value": "Downloading: 100%", 457 | "_view_count": null, 458 | "_view_module_version": "1.5.0", 459 | "description_tooltip": null, 460 | "_model_module": "@jupyter-widgets/controls", 461 | "layout": "IPY_MODEL_e1f478d1bb864f8899d53ccee9694e47" 462 | } 463 | }, 464 | "f358ac594bf34885985324fb1f10f371": { 465 | "model_module": "@jupyter-widgets/controls", 466 | "model_name": "FloatProgressModel", 467 | "model_module_version": "1.5.0", 468 | "state": { 469 | "_view_name": "ProgressView", 470 | "style": "IPY_MODEL_17a109cf5a374c479a4af09cea3484e2", 471 | "_dom_classes": [], 472 | "description": "", 473 | "_model_name": "FloatProgressModel", 474 | "bar_style": "success", 475 | "max": 47679974, 476 | "_view_module": "@jupyter-widgets/controls", 477 | "_model_module_version": "1.5.0", 478 | "value": 47679974, 479 | "_view_count": null, 480 | "_view_module_version": "1.5.0", 481 | "orientation": 
"horizontal", 482 | "min": 0, 483 | "description_tooltip": null, 484 | "_model_module": "@jupyter-widgets/controls", 485 | "layout": "IPY_MODEL_35c41aa3b1064c91a287ab4e6b5b7de8" 486 | } 487 | }, 488 | "a95b5e967c574b15a40da9497df022b6": { 489 | "model_module": "@jupyter-widgets/controls", 490 | "model_name": "HTMLModel", 491 | "model_module_version": "1.5.0", 492 | "state": { 493 | "_view_name": "HTMLView", 494 | "style": "IPY_MODEL_e5072a91b6914f10af2d63045c27a307", 495 | "_dom_classes": [], 496 | "description": "", 497 | "_model_name": "HTMLModel", 498 | "placeholder": "​", 499 | "_view_module": "@jupyter-widgets/controls", 500 | "_model_module_version": "1.5.0", 501 | "value": " 45.5M/45.5M [00:00<00:00, 62.9MB/s]", 502 | "_view_count": null, 503 | "_view_module_version": "1.5.0", 504 | "description_tooltip": null, 505 | "_model_module": "@jupyter-widgets/controls", 506 | "layout": "IPY_MODEL_c43f948fd4b54f4fbe65bf93095f675c" 507 | } 508 | }, 509 | "18aed82927c945eb926e48bd17bbc823": { 510 | "model_module": "@jupyter-widgets/controls", 511 | "model_name": "DescriptionStyleModel", 512 | "model_module_version": "1.5.0", 513 | "state": { 514 | "_view_name": "StyleView", 515 | "_model_name": "DescriptionStyleModel", 516 | "description_width": "", 517 | "_view_module": "@jupyter-widgets/base", 518 | "_model_module_version": "1.5.0", 519 | "_view_count": null, 520 | "_view_module_version": "1.2.0", 521 | "_model_module": "@jupyter-widgets/controls" 522 | } 523 | }, 524 | "e1f478d1bb864f8899d53ccee9694e47": { 525 | "model_module": "@jupyter-widgets/base", 526 | "model_name": "LayoutModel", 527 | "model_module_version": "1.2.0", 528 | "state": { 529 | "_view_name": "LayoutView", 530 | "grid_template_rows": null, 531 | "right": null, 532 | "justify_content": null, 533 | "_view_module": "@jupyter-widgets/base", 534 | "overflow": null, 535 | "_model_module_version": "1.2.0", 536 | "_view_count": null, 537 | "flex_flow": null, 538 | "width": null, 539 | "min_width": null, 540 | "border": null, 541 | "align_items": null, 542 | "bottom": null, 543 | "_model_module": "@jupyter-widgets/base", 544 | "top": null, 545 | "grid_column": null, 546 | "overflow_y": null, 547 | "overflow_x": null, 548 | "grid_auto_flow": null, 549 | "grid_area": null, 550 | "grid_template_columns": null, 551 | "flex": null, 552 | "_model_name": "LayoutModel", 553 | "justify_items": null, 554 | "grid_row": null, 555 | "max_height": null, 556 | "align_content": null, 557 | "visibility": null, 558 | "align_self": null, 559 | "height": null, 560 | "min_height": null, 561 | "padding": null, 562 | "grid_auto_rows": null, 563 | "grid_gap": null, 564 | "max_width": null, 565 | "order": null, 566 | "_view_module_version": "1.2.0", 567 | "grid_template_areas": null, 568 | "object_position": null, 569 | "object_fit": null, 570 | "grid_auto_columns": null, 571 | "margin": null, 572 | "display": null, 573 | "left": null 574 | } 575 | }, 576 | "17a109cf5a374c479a4af09cea3484e2": { 577 | "model_module": "@jupyter-widgets/controls", 578 | "model_name": "ProgressStyleModel", 579 | "model_module_version": "1.5.0", 580 | "state": { 581 | "_view_name": "StyleView", 582 | "_model_name": "ProgressStyleModel", 583 | "description_width": "", 584 | "_view_module": "@jupyter-widgets/base", 585 | "_model_module_version": "1.5.0", 586 | "_view_count": null, 587 | "_view_module_version": "1.2.0", 588 | "bar_color": null, 589 | "_model_module": "@jupyter-widgets/controls" 590 | } 591 | }, 592 | "35c41aa3b1064c91a287ab4e6b5b7de8": { 593 | "model_module": 
"@jupyter-widgets/base", 594 | "model_name": "LayoutModel", 595 | "model_module_version": "1.2.0", 596 | "state": { 597 | "_view_name": "LayoutView", 598 | "grid_template_rows": null, 599 | "right": null, 600 | "justify_content": null, 601 | "_view_module": "@jupyter-widgets/base", 602 | "overflow": null, 603 | "_model_module_version": "1.2.0", 604 | "_view_count": null, 605 | "flex_flow": null, 606 | "width": null, 607 | "min_width": null, 608 | "border": null, 609 | "align_items": null, 610 | "bottom": null, 611 | "_model_module": "@jupyter-widgets/base", 612 | "top": null, 613 | "grid_column": null, 614 | "overflow_y": null, 615 | "overflow_x": null, 616 | "grid_auto_flow": null, 617 | "grid_area": null, 618 | "grid_template_columns": null, 619 | "flex": null, 620 | "_model_name": "LayoutModel", 621 | "justify_items": null, 622 | "grid_row": null, 623 | "max_height": null, 624 | "align_content": null, 625 | "visibility": null, 626 | "align_self": null, 627 | "height": null, 628 | "min_height": null, 629 | "padding": null, 630 | "grid_auto_rows": null, 631 | "grid_gap": null, 632 | "max_width": null, 633 | "order": null, 634 | "_view_module_version": "1.2.0", 635 | "grid_template_areas": null, 636 | "object_position": null, 637 | "object_fit": null, 638 | "grid_auto_columns": null, 639 | "margin": null, 640 | "display": null, 641 | "left": null 642 | } 643 | }, 644 | "e5072a91b6914f10af2d63045c27a307": { 645 | "model_module": "@jupyter-widgets/controls", 646 | "model_name": "DescriptionStyleModel", 647 | "model_module_version": "1.5.0", 648 | "state": { 649 | "_view_name": "StyleView", 650 | "_model_name": "DescriptionStyleModel", 651 | "description_width": "", 652 | "_view_module": "@jupyter-widgets/base", 653 | "_model_module_version": "1.5.0", 654 | "_view_count": null, 655 | "_view_module_version": "1.2.0", 656 | "_model_module": "@jupyter-widgets/controls" 657 | } 658 | }, 659 | "c43f948fd4b54f4fbe65bf93095f675c": { 660 | "model_module": "@jupyter-widgets/base", 661 | "model_name": "LayoutModel", 662 | "model_module_version": "1.2.0", 663 | "state": { 664 | "_view_name": "LayoutView", 665 | "grid_template_rows": null, 666 | "right": null, 667 | "justify_content": null, 668 | "_view_module": "@jupyter-widgets/base", 669 | "overflow": null, 670 | "_model_module_version": "1.2.0", 671 | "_view_count": null, 672 | "flex_flow": null, 673 | "width": null, 674 | "min_width": null, 675 | "border": null, 676 | "align_items": null, 677 | "bottom": null, 678 | "_model_module": "@jupyter-widgets/base", 679 | "top": null, 680 | "grid_column": null, 681 | "overflow_y": null, 682 | "overflow_x": null, 683 | "grid_auto_flow": null, 684 | "grid_area": null, 685 | "grid_template_columns": null, 686 | "flex": null, 687 | "_model_name": "LayoutModel", 688 | "justify_items": null, 689 | "grid_row": null, 690 | "max_height": null, 691 | "align_content": null, 692 | "visibility": null, 693 | "align_self": null, 694 | "height": null, 695 | "min_height": null, 696 | "padding": null, 697 | "grid_auto_rows": null, 698 | "grid_gap": null, 699 | "max_width": null, 700 | "order": null, 701 | "_view_module_version": "1.2.0", 702 | "grid_template_areas": null, 703 | "object_position": null, 704 | "object_fit": null, 705 | "grid_auto_columns": null, 706 | "margin": null, 707 | "display": null, 708 | "left": null 709 | } 710 | } 711 | } 712 | } 713 | }, 714 | "cells": [ 715 | { 716 | "cell_type": "markdown", 717 | "source": [ 718 | "# Install" 719 | ], 720 | "metadata": { 721 | "id": "VebmNl7x48k9" 722 | } 723 | }, 
724 | { 725 | "cell_type": "code", 726 | "execution_count": 1, 727 | "metadata": { 728 | "colab": { 729 | "base_uri": "https://localhost:8080/" 730 | }, 731 | "id": "-ctvtX2t4zpC", 732 | "outputId": "97d68d28-f269-4455-8f95-35e884aeae96" 733 | }, 734 | "outputs": [ 735 | { 736 | "output_type": "stream", 737 | "name": "stdout", 738 | "text": [ 739 | "Collecting git+https://github.com/cene555/ruCLIP-SB.git\n", 740 | " Cloning https://github.com/cene555/ruCLIP-SB.git to /tmp/pip-req-build-td8cko6v\n", 741 | " Running command git clone -q https://github.com/cene555/ruCLIP-SB.git /tmp/pip-req-build-td8cko6v\n", 742 | "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from ruclipsb==1.0) (1.10.0+cu111)\n", 743 | "Collecting timm\n", 744 | " Downloading timm-0.5.4-py3-none-any.whl (431 kB)\n", 745 | "\u001b[K |████████████████████████████████| 431 kB 10.1 MB/s \n", 746 | "\u001b[?25hCollecting transformers==4.12.3\n", 747 | " Downloading transformers-4.12.3-py3-none-any.whl (3.1 MB)\n", 748 | "\u001b[K |████████████████████████████████| 3.1 MB 87.9 MB/s \n", 749 | "\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from ruclipsb==1.0) (4.62.3)\n", 750 | "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers==4.12.3->ruclipsb==1.0) (21.3)\n", 751 | "Collecting tokenizers<0.11,>=0.10.1\n", 752 | " Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)\n", 753 | "\u001b[K |████████████████████████████████| 3.3 MB 81.3 MB/s \n", 754 | "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.12.3->ruclipsb==1.0) (3.4.2)\n", 755 | "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers==4.12.3->ruclipsb==1.0) (4.10.0)\n", 756 | "Collecting pyyaml>=5.1\n", 757 | " Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n", 758 | "\u001b[K |████████████████████████████████| 596 kB 59.6 MB/s \n", 759 | "\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.12.3->ruclipsb==1.0) (2019.12.20)\n", 760 | "Collecting sacremoses\n", 761 | " Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)\n", 762 | "\u001b[K |████████████████████████████████| 895 kB 55.4 MB/s \n", 763 | "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.12.3->ruclipsb==1.0) (2.23.0)\n", 764 | "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.12.3->ruclipsb==1.0) (1.19.5)\n", 765 | "Collecting huggingface-hub<1.0,>=0.1.0\n", 766 | " Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)\n", 767 | "\u001b[K |████████████████████████████████| 67 kB 6.2 MB/s \n", 768 | "\u001b[?25hRequirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers==4.12.3->ruclipsb==1.0) (3.10.0.2)\n", 769 | "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers==4.12.3->ruclipsb==1.0) (3.0.6)\n", 770 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from 
importlib-metadata->transformers==4.12.3->ruclipsb==1.0) (3.7.0)\n", 771 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.12.3->ruclipsb==1.0) (2.10)\n", 772 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.12.3->ruclipsb==1.0) (2021.10.8)\n", 773 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.12.3->ruclipsb==1.0) (3.0.4)\n", 774 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.12.3->ruclipsb==1.0) (1.24.3)\n", 775 | "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.12.3->ruclipsb==1.0) (1.1.0)\n", 776 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.12.3->ruclipsb==1.0) (1.15.0)\n", 777 | "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.12.3->ruclipsb==1.0) (7.1.2)\n", 778 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from timm->ruclipsb==1.0) (0.11.1+cu111)\n", 779 | "Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->timm->ruclipsb==1.0) (7.1.2)\n", 780 | "Building wheels for collected packages: ruclipsb\n", 781 | " Building wheel for ruclipsb (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 782 | " Created wheel for ruclipsb: filename=ruclipsb-1.0-py3-none-any.whl size=5562 sha256=f497e108f0173806c658f6a965ac70ebaf70c702041ded112e66b5958a4b9a19\n", 783 | " Stored in directory: /tmp/pip-ephem-wheel-cache-rmcf_uy5/wheels/6d/97/0f/f70741c9e95bd88c83a32b20d029078aa97fc3e781b4282e9f\n", 784 | "Successfully built ruclipsb\n", 785 | "Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers, timm, ruclipsb\n", 786 | " Attempting uninstall: pyyaml\n", 787 | " Found existing installation: PyYAML 3.13\n", 788 | " Uninstalling PyYAML-3.13:\n", 789 | " Successfully uninstalled PyYAML-3.13\n", 790 | "Successfully installed huggingface-hub-0.4.0 pyyaml-6.0 ruclipsb-1.0 sacremoses-0.0.47 timm-0.5.4 tokenizers-0.10.3 transformers-4.12.3\n" 791 | ] 792 | } 793 | ], 794 | "source": [ 795 | "!pip install 'git+https://github.com/cene555/ruCLIP-SB.git'" 796 | ] 797 | }, 798 | { 799 | "cell_type": "markdown", 800 | "source": [ 801 | "Download model" 802 | ], 803 | "metadata": { 804 | "id": "mHC0pjfUJiN2" 805 | } 806 | }, 807 | { 808 | "cell_type": "code", 809 | "source": [ 810 | "!gdown https://drive.google.com/uc?id=1-CghuC9TCIDyn5H3zQS6ho_TNiudzJCX " 811 | ], 812 | "metadata": { 813 | "colab": { 814 | "base_uri": "https://localhost:8080/" 815 | }, 816 | "id": "X31Z3oOPJPCF", 817 | "outputId": "9b36cf3f-44f0-46af-91cc-29553e067593" 818 | }, 819 | "execution_count": 2, 820 | "outputs": [ 821 | { 822 | "output_type": "stream", 823 | "name": "stdout", 824 | "text": [ 825 | "Downloading...\n", 826 | "From: https://drive.google.com/uc?id=1-CghuC9TCIDyn5H3zQS6ho_TNiudzJCX\n", 827 | "To: /content/ruCLIP-SB.pkl\n", 828 | "100% 159M/159M [00:00<00:00, 282MB/s]\n" 829 | ] 830 | } 831 | ] 832 | }, 833 | { 834 | "cell_type": "markdown", 835 | "source": [ 836 | "Download dataset" 837 | ], 838 | "metadata": { 839 | "id": "b17X7R1XJlGw" 840 | } 841 | }, 842 | { 843 | 
"cell_type": "code", 844 | "source": [ 845 | "#download this dataset https://drive.google.com/file/d/1XQYR66ndrMPTy_Z4eD5jMgl_BNISZGfq/view?usp=sharing" 846 | ], 847 | "metadata": { 848 | "id": "jYLkE8WbnWkV" 849 | }, 850 | "execution_count": null, 851 | "outputs": [] 852 | }, 853 | { 854 | "cell_type": "code", 855 | "source": [ 856 | "!gdown https://drive.google.com/uc?id=1VRDYOvzvASd8PHSBTbp6En47-I5X88YU" 857 | ], 858 | "metadata": { 859 | "colab": { 860 | "base_uri": "https://localhost:8080/" 861 | }, 862 | "id": "8jHlSkgRJGma", 863 | "outputId": "9bcdb43a-626d-4f59-a28b-72f44b1e0f81" 864 | }, 865 | "execution_count": 14, 866 | "outputs": [ 867 | { 868 | "output_type": "stream", 869 | "name": "stdout", 870 | "text": [ 871 | "Downloading...\n", 872 | "From: https://drive.google.com/uc?id=1VRDYOvzvASd8PHSBTbp6En47-I5X88YU\n", 873 | "To: /content/ru_flickr.csv\n", 874 | "\r 0% 0.00/5.73M [00:00\n", 980 | "
\n", 1106 | " \n", 1107 | " " 1108 | ], 1109 | "text/plain": [ 1110 | " image_name text\n", 1111 | "0 1000092795.jpg Два молодых парня с шаткими волосами смотрят н...\n", 1112 | "1 10002456.jpg Несколько человек в жестких люках управляют ги...\n", 1113 | "2 1000268201.jpg Ребенок в розовом платье поднимается по лестни...\n", 1114 | "3 1000344755.jpg Кто-то в синей рубашке и шляпе стоит на полу и...\n", 1115 | "4 1000366164.jpg Двое мужчин, один в темной рубашке, другой в ч..." 1116 | ] 1117 | }, 1118 | "metadata": {}, 1119 | "execution_count": 7 1120 | } 1121 | ] 1122 | }, 1123 | { 1124 | "cell_type": "code", 1125 | "source": [ 1126 | "train_df, val_df = train_test_split(df, test_size=0.1, random_state=655)" 1127 | ], 1128 | "metadata": { 1129 | "id": "diEEyVEsPyYS" 1130 | }, 1131 | "execution_count": 8, 1132 | "outputs": [] 1133 | }, 1134 | { 1135 | "cell_type": "code", 1136 | "source": [ 1137 | "train_df.to_csv('train.csv', index=False)\n", 1138 | "val_df.to_csv('val.csv', index=False)\n" 1139 | ], 1140 | "metadata": { 1141 | "id": "oRfqx0MkQxwT" 1142 | }, 1143 | "execution_count": 9, 1144 | "outputs": [] 1145 | }, 1146 | { 1147 | "cell_type": "markdown", 1148 | "source": [ 1149 | "# Load model" 1150 | ], 1151 | "metadata": { 1152 | "id": "1fbZkwoZP7Lr" 1153 | } 1154 | }, 1155 | { 1156 | "cell_type": "code", 1157 | "source": [ 1158 | "device = 'cuda'" 1159 | ], 1160 | "metadata": { 1161 | "id": "dxlyKsY0MDlF" 1162 | }, 1163 | "execution_count": 10, 1164 | "outputs": [] 1165 | }, 1166 | { 1167 | "cell_type": "code", 1168 | "source": [ 1169 | "model = ruCLIPSB().to(device)\n", 1170 | "model.load_state_dict(torch.load('/content/ruCLIP-SB.pkl'))" 1171 | ], 1172 | "metadata": { 1173 | "colab": { 1174 | "base_uri": "https://localhost:8080/", 1175 | "height": 0, 1176 | "referenced_widgets": [ 1177 | "7091636af08b41c88eb8a9f9179551ee", 1178 | "d8c24ad5ba9f431da269aa99537d3bd5", 1179 | "c62ecf993b0e4ddb810d4d4b82e50ef5", 1180 | "11b7b2ebb4294b02b15506446507a637", 1181 | "b7e3962196584ac4a1a325ad6af7239a", 1182 | "fe468022a67f48cb8d8f1c05621a6cb7", 1183 | "c0ec008b526a4c3486b4638ec18e66ea", 1184 | "81e0631c175c444e9603918fad9b504d", 1185 | "4d9d5b61bb2847848699313b8540aa9c", 1186 | "84abaf38fd6d422588f9bf7226e9378b", 1187 | "f24f6be70e314f2689017533011611eb", 1188 | "d60078ecf0a743e7a28f523c1ed105b6", 1189 | "107f5d59a1504e76a205fc54ef302376", 1190 | "9e59a739116d4ff6a2314cfffbf4594e", 1191 | "f358ac594bf34885985324fb1f10f371", 1192 | "a95b5e967c574b15a40da9497df022b6", 1193 | "18aed82927c945eb926e48bd17bbc823", 1194 | "e1f478d1bb864f8899d53ccee9694e47", 1195 | "17a109cf5a374c479a4af09cea3484e2", 1196 | "35c41aa3b1064c91a287ab4e6b5b7de8", 1197 | "e5072a91b6914f10af2d63045c27a307", 1198 | "c43f948fd4b54f4fbe65bf93095f675c" 1199 | ] 1200 | }, 1201 | "id": "s3LKBi08JWdh", 1202 | "outputId": "cba94b69-4f5c-42c9-9bdd-76a2eee2e35c" 1203 | }, 1204 | "execution_count": 11, 1205 | "outputs": [ 1206 | { 1207 | "output_type": "stream", 1208 | "name": "stderr", 1209 | "text": [ 1210 | "/usr/local/lib/python3.7/dist-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2157.)\n", 1211 | " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", 1212 | "Downloading: \"https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth\" to /root/.cache/torch/hub/checkpoints/swin_tiny_patch4_window7_224.pth\n" 1213 | ] 1214 | }, 1215 | { 1216 | "output_type": "display_data", 1217 | "data": { 1218 | "application/vnd.jupyter.widget-view+json": { 1219 | "model_id": "7091636af08b41c88eb8a9f9179551ee", 1220 | "version_minor": 0, 1221 | "version_major": 2 1222 | }, 1223 | "text/plain": [ 1224 | "Downloading: 0%| | 0.00/632 [00:00" 1257 | ] 1258 | }, 1259 | "metadata": {}, 1260 | "execution_count": 11 1261 | } 1262 | ] 1263 | }, 1264 | { 1265 | "cell_type": "markdown", 1266 | "source": [ 1267 | "# Train model" 1268 | ], 1269 | "metadata": { 1270 | "id": "O1DCPGa_P-Ut" 1271 | } 1272 | }, 1273 | { 1274 | "cell_type": "code", 1275 | "source": [ 1276 | "trainer = Trainer(train_dataframe='/content/train.csv', train_dir='/content/flickr30k_images/flickr30k_images',\n", 1277 | " val_dataframe='/content/val.csv', val_dir='/content/flickr30k_images/flickr30k_images',)" 1278 | ], 1279 | "metadata": { 1280 | "id": "zW3uWDRsP_9v" 1281 | }, 1282 | "execution_count": 12, 1283 | "outputs": [] 1284 | }, 1285 | { 1286 | "cell_type": "code", 1287 | "source": [ 1288 | "model = trainer.train_model(model, epochs_num=2, device=device, verbose=20)" 1289 | ], 1290 | "metadata": { 1291 | "id": "TuHzZCXOQSzj" 1292 | }, 1293 | "execution_count": null, 1294 | "outputs": [] 1295 | } 1296 | ] 1297 | } --------------------------------------------------------------------------------
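The README's example-usage notebook is not included in this dump, so the sketch below shows plain PyTorch inference with the checkpoint downloaded above (`ruCLIP-SB.pkl`), assuming the preprocessing from `ruclipsb/dataset.py`; the image path `example.jpg` is a placeholder, and the candidate captions are the same ones used in the ONNX notebook:

```python
# Hedged sketch: rank Russian captions for one image with the trained ruCLIP-SB checkpoint.
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import BertTokenizer
from ruclipsb import ruCLIPSB
from ruclipsb.utils import tokenize, _convert_image_to_rgb

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = ruCLIPSB()
model.load_state_dict(torch.load('ruCLIP-SB.pkl', map_location=device))
model = model.to(device).eval()

tokenizer = BertTokenizer.from_pretrained('cointegrated/rubert-tiny')
transform = transforms.Compose([
    transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    _convert_image_to_rgb,
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])

image = transform(Image.open('example.jpg')).unsqueeze(0).to(device)  # [1, 3, 224, 224]
texts = ['диаграмма', 'собака', 'кошка']
input_ids, attention_mask = tokenize(tokenizer, texts, max_len=77)    # [3, 77] each
input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image, input_ids, attention_mask)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print('Label probs:', probs)  # probability of each caption for the image
```

`logits_per_image` has shape `[num_images, num_texts]`, so a softmax over its last dimension gives the probability of each caption for each image — the same quantity the ONNX notebook prints as `Label probs`.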