├── .gitignore ├── 0_Dataset_Exploration.ipynb ├── 1_Architecture.ipynb ├── 2_Training.ipynb ├── 3_Inference.ipynb ├── README.md ├── coco_dataset.py ├── data_loader.py ├── data_loader_val.py ├── gradio_main.py ├── images ├── cnn_rnn_model.png ├── coco-examples.jpg ├── decoder.png ├── encoder-decoder.png ├── encoder.png ├── gradio_demo.png ├── readme.png ├── result.png ├── sample_002.png ├── sample_008.png ├── sample_029.png ├── sample_034.png ├── sample_107.png ├── sample_171.png ├── sample_193.png ├── sample_202.png ├── sample_296.png ├── sample_326.png ├── sample_366.png ├── sample_440.png ├── sample_457.png └── sample_498.png ├── model.py ├── nlp_utils.py ├── requirements.txt ├── vocab.pkl └── vocabulary.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | samples/ 131 | models/ 132 | images.pt 133 | captions.pt 134 | .idea/ 135 | debug_script.py 136 | gradio_logs/ -------------------------------------------------------------------------------- /1_Architecture.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Data Preparation & Architecture Design\n", 8 | "\n", 9 | "In this notebook, we will explain how to load and pre-process data from the [COCO dataset](http://cocodataset.org/#home). We will also design a CNN-RNN model for automatically generating image captions.\n", 10 | "\n", 11 | "The CNN encoder and the RNN decoder are implemented in **model.py**.\n", 12 | "\n", 13 | "Outline of this notebook:\n", 14 | "- [Step 1](#step1): Explore the Data Loader\n", 15 | "- [Step 2](#step2): Use the Data Loader to Obtain Batches\n", 16 | "- [Step 3](#step3): Experiment with the CNN Encoder\n", 17 | "- [Step 4](#step4): Implement the RNN Decoder" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "\n", 25 | "## Step 1: Explore the Data Loader\n", 26 | "\n", 27 | "We wrote a [data loader](http://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader) to load the COCO dataset in batches. 
In the code cell below, we initialize the data loader with the `get_loader` function from **data_loader.py**. \n", 28 | "\n", 29 | "\n", 30 | "The `get_loader` function takes a number of arguments, which you can explore by opening **data_loader.py**. Most of them should be left at their default values. The most important ones are:\n", 31 | "1. **`transform`** - an [image transform](http://pytorch.org/docs/master/torchvision/transforms.html) specifying how to pre-process the images and convert them to PyTorch tensors before they are fed to the CNN encoder. We define this transform in the `transform_train` variable.\n", 32 | "2. **`mode`** - one of `'train'` (loads the training data in batches) or `'test'` (for the test data). We say that the data loader is in training or test mode, respectively. In this notebook we keep the data loader in training mode by setting `mode='train'`.\n", 33 | "3. **`batch_size`** - the batch size. When training the model, this is the number of image-caption pairs used to update the model weights in each training step.\n", 34 | "4. **`vocab_threshold`** - **the minimum number of times a word must appear in the training captions before it is included in the vocabulary**. Words with fewer than `vocab_threshold` occurrences in the training captions are treated as unknown words. \n", 35 | "5. **`vocab_from_file`** - a Boolean that decides whether to load the vocabulary from file. \n", 36 | "\n", 37 | "We describe the `vocab_threshold` and `vocab_from_file` arguments in more detail later in this step. For now, run the code cell below. It may take a couple of minutes to run!"
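To make the `vocab_from_file` flag more concrete, here is a minimal load-or-build sketch. It assumes the vocabulary can be pickled as a plain dictionary. `load_or_build_vocab` and its `build_fn` argument are hypothetical names used only for illustration, and the actual logic in **vocabulary.py** may differ in detail.

```python
import os
import pickle
import tempfile


def load_or_build_vocab(vocab_file, vocab_from_file, build_fn):
    """Hypothetical helper: reuse a pickled vocabulary, or build one and cache it."""
    if vocab_from_file and os.path.exists(vocab_file):
        # Fast path: skip re-tokenizing every training caption.
        with open(vocab_file, "rb") as f:
            return pickle.load(f)
    # Slow path: build from scratch, then save it so later runs can reuse it.
    vocab = build_fn()
    with open(vocab_file, "wb") as f:
        pickle.dump(vocab, f)
    return vocab


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "vocab.pkl")
    built = load_or_build_vocab(path, False, lambda: {"<start>": 0, "<end>": 1, "<unk>": 2})
    # With vocab_from_file=True the cached file wins and build_fn is never called.
    cached = load_or_build_vocab(path, True, lambda: {})
    print(cached == built)  # True
```

Once the cached file is used, any threshold baked into `build_fn` has no effect. This is why `vocab_from_file=False` is needed whenever `vocab_threshold` changes.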
38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 1, 43 | "metadata": { 44 | "pycharm": { 45 | "is_executing": false 46 | }, 47 | "scrolled": false 48 | }, 49 | "outputs": [ 50 | { 51 | "name": "stderr", 52 | "output_type": "stream", 53 | "text": [ 54 | "[nltk_data] Downloading package punkt to /home/masoud/nltk_data...\n", 55 | "[nltk_data] Package punkt is already up-to-date!\n", 56 | "[nltk_data] Downloading package punkt to /home/masoud/nltk_data...\n", 57 | "[nltk_data] Package punkt is already up-to-date!\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "from pycocotools.coco import COCO\n", 63 | "import nltk\n", 64 | "from data_loader import get_loader\n", 65 | "import torch\n", 66 | "import numpy as np\n", 67 | "import torch.utils.data as data\n", 68 | "from torchvision import transforms\n", 69 | "\n", 70 | "nltk.download(\"punkt\")\n", 71 | "\n", 72 | "\n", 73 | "%load_ext autoreload\n", 74 | "%autoreload 2" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 2, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "loading annotations into memory...\n", 87 | "Done (t=0.61s)\n", 88 | "creating index...\n", 89 | "index created!\n", 90 | "[0/414113] Tokenizing captions...\n", 91 | "[100000/414113] Tokenizing captions...\n", 92 | "[200000/414113] Tokenizing captions...\n", 93 | "[300000/414113] Tokenizing captions...\n", 94 | "[400000/414113] Tokenizing captions...\n", 95 | "loading annotations into memory...\n", 96 | "Done (t=0.64s)\n", 97 | "creating index...\n", 98 | "index created!\n", 99 | "Obtaining caption lengths...\n" 100 | ] 101 | }, 102 | { 103 | "name": "stderr", 104 | "output_type": "stream", 105 | "text": [ 106 | "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 414113/414113 [00:45<00:00, 9162.46it/s]\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 
111 | "# Define a transform to pre-process the training images.\n", 112 | "transform_train = transforms.Compose(\n", 113 | " [\n", 114 | " transforms.Resize(256), # smaller edge of image resized to 256\n", 115 | " transforms.RandomCrop(224), # get 224x224 crop from random location\n", 116 | " transforms.RandomHorizontalFlip(), # horizontally flip image with probability=0.5\n", 117 | " transforms.ToTensor(), # convert the PIL Image to a tensor\n", 118 | " transforms.Normalize(\n", 119 | " (0.485, 0.456, 0.406), # normalize image for pre-trained model\n", 120 | " (0.229, 0.224, 0.225),\n", 121 | " ),\n", 122 | " ]\n", 123 | ")\n", 124 | "\n", 125 | "# Set the minimum word count threshold.\n", 126 | "vocab_threshold = 5\n", 127 | "\n", 128 | "# Specify the batch size.\n", 129 | "batch_size = 10\n", 130 | "\n", 131 | "# Path to cocoapi dir\n", 132 | "cocoapi_dir = r\"path/to/cocoapi/dir\"\n", 133 | "\n", 134 | "# Obtain the data loader.\n", 135 | "data_loader = get_loader(\n", 136 | " transform=transform_train,\n", 137 | " mode=\"train\",\n", 138 | " batch_size=batch_size,\n", 139 | " vocab_threshold=vocab_threshold,\n", 140 | " vocab_from_file=False,\n", 141 | " cocoapi_loc=cocoapi_dir,\n", 142 | ")" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "When we run the code cell above, the data loader is stored in the variable `data_loader`. \n", 150 | "\n", 151 | "We can now access the corresponding dataset as `data_loader.dataset`. This dataset is an instance of the `CoCoDataset` class in **coco_dataset.py**. 
If you are unfamiliar with data loaders and datasets, you are encouraged to review [my post](http://www.sefidian.com/2022/03/09/writing-custom-datasets-and-dataloader-in-pytorch/) or [this PyTorch tutorial](http://pytorch.org/tutorials/beginner/data_loading_tutorial.html).\n", 152 | "\n", 153 | "### Exploring the `__getitem__` Method\n", 154 | "\n", 155 | "The `__getitem__` method in the `CoCoDataset` class determines how an image-caption pair is pre-processed before being incorporated into a batch. This is true for all `Dataset` classes in PyTorch; if this is unfamiliar to you, please review [the tutorial linked above](http://pytorch.org/tutorials/beginner/data_loading_tutorial.html). \n", 156 | "\n", 157 | "When the data loader is in training mode, this method first obtains the filename (`path`) of a training image and its corresponding caption (`caption`).\n", 158 | "\n", 159 | "#### Image Pre-Processing \n", 160 | "\n", 161 | "Image pre-processing is relatively straightforward (from the `__getitem__` method in the `CoCoDataset` class):\n", 162 | "```python\n", 163 | "# Convert image to tensor and pre-process using transform\n", 164 | "image = Image.open(os.path.join(self.img_folder, path)).convert('RGB')\n", 165 | "image = self.transform(image)\n", 166 | "```\n", 167 | "After the image named `path` is loaded from the training folder, it is pre-processed with the same transform (`transform_train`) that was supplied when instantiating the data loader. \n", 168 | "\n", 169 | "#### Caption Pre-Processing \n", 170 | "\n", 171 | "The captions also need to be pre-processed before training. 
Our captioning model predicts the next token of a sentence from the previous tokens. To train it, we turn the caption associated with each image into a list of tokenized words, map each word to an integer, and cast the result to a PyTorch tensor.\n", 172 | "\n", 173 | "To understand in more detail how COCO captions are pre-processed, please take a look at the `vocab` instance variable of the `CoCoDataset` class. The code snippet below is pulled from the `__init__` method of the `CoCoDataset` class:\n", 174 | "```python\n", 175 | "def __init__(self, transform, mode, batch_size, vocab_threshold, vocab_file, start_word, \n", 176 | " end_word, unk_word, annotations_file, vocab_from_file, img_folder):\n", 177 | " ...\n", 178 | " self.vocab = Vocabulary(vocab_threshold, vocab_file, start_word,\n", 179 | " end_word, unk_word, annotations_file, vocab_from_file)\n", 180 | " ...\n", 181 | "```\n", 182 | "From the code snippet above, we can see that `data_loader.dataset.vocab` is an instance of the `Vocabulary` class from **vocabulary.py**. You can verify this by looking at the full code in **coco_dataset.py**. \n", 183 | "\n", 184 | "We use this instance to pre-process the COCO captions (from the `__getitem__` method in the `CoCoDataset` class):\n", 185 | "\n", 186 | "```python\n", 187 | "# Convert caption to tensor of word ids.\n", 188 | "tokens = nltk.tokenize.word_tokenize(str(caption).lower()) # line 1\n", 189 | "caption = [] # line 2\n", 190 | "caption.append(self.vocab(self.vocab.start_word)) # line 3\n", 191 | "caption.extend([self.vocab(token) for token in tokens]) # line 4\n", 192 | "caption.append(self.vocab(self.vocab.end_word)) # line 5\n", 193 | "caption = torch.Tensor(caption).long() # line 6\n", 194 | "```\n", 195 | "\n", 196 | "This code converts any string-valued caption to a list of integers and then casts it to a PyTorch tensor. 
To see how this code works, we apply it to the sample caption in the next code cell." 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 3, 202 | "metadata": { 203 | "pycharm": { 204 | "is_executing": false 205 | } 206 | }, 207 | "outputs": [], 208 | "source": [ 209 | "sample_caption = \"A person doing a trick on a rail while riding a skateboard.\"" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "In **`line 1`** of the code snippet, every letter in the caption is converted to lowercase, and the [`nltk.tokenize.word_tokenize`](http://www.nltk.org/) function is used to obtain a list of string-valued tokens. Run the next code cell to visualize the effect on `sample_caption`." 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 4, 222 | "metadata": { 223 | "pycharm": { 224 | "is_executing": false 225 | } 226 | }, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "['a', 'person', 'doing', 'a', 'trick', 'on', 'a', 'rail', 'while', 'riding', 'a', 'skateboard', '.']\n" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "sample_tokens = nltk.tokenize.word_tokenize(str(sample_caption).lower())\n", 238 | "print(sample_tokens)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "**`line 2`** and **`line 3`** initialize an empty list and append an integer to mark the start of a caption. The [paper](https://arxiv.org/pdf/1411.4555.pdf) uses a special start word (and a special end word, which we examine below) to mark the beginning (and end) of a caption.\n", 246 | "\n", 247 | "This special start word (`\"<start>\"`) is set when instantiating the data loader and is passed as a parameter (`start_word`). 
This parameter **must** be kept at its default value (`start_word=\"<start>\"`).\n", 248 | "\n", 249 | "As shown below, the integer `0` is always used to mark the start of a caption." 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 5, 255 | "metadata": { 256 | "pycharm": { 257 | "is_executing": false 258 | } 259 | }, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "Special start word: \n", 266 | "[0]\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "sample_caption = []\n", 272 | "\n", 273 | "start_word = data_loader.dataset.vocab.start_word\n", 274 | "print(\"Special start word:\", start_word)\n", 275 | "sample_caption.append(data_loader.dataset.vocab(start_word))\n", 276 | "print(sample_caption)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "metadata": {}, 282 | "source": [ 283 | "In **`line 4`**, we extend the list with the integers that correspond to each token in the caption." 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 6, 289 | "metadata": { 290 | "pycharm": { 291 | "is_executing": false 292 | } 293 | }, 294 | "outputs": [ 295 | { 296 | "name": "stdout", 297 | "output_type": "stream", 298 | "text": [ 299 | "[0, 3, 98, 754, 3, 396, 39, 3, 1010, 207, 139, 3, 753, 18]\n" 300 | ] 301 | } 302 | ], 303 | "source": [ 304 | "sample_caption.extend([data_loader.dataset.vocab(token) for token in sample_tokens])\n", 305 | "print(sample_caption)" 306 | ] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "metadata": {}, 311 | "source": [ 312 | "**`line 5`** appends a final integer to mark the end of the caption. \n", 313 | "\n", 314 | "As with the special start word, the special end word (`\"<end>\"`) is set when instantiating the data loader and is passed as a parameter (`end_word`). 
This parameter **must** also be kept at its default value (`end_word=\"<end>\"`).\n", 315 | "\n", 316 | "As shown below, the integer `1` is always used to mark the end of a caption." 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 7, 322 | "metadata": { 323 | "pycharm": { 324 | "is_executing": false 325 | } 326 | }, 327 | "outputs": [ 328 | { 329 | "name": "stdout", 330 | "output_type": "stream", 331 | "text": [ 332 | "Special end word: \n", 333 | "[0, 3, 98, 754, 3, 396, 39, 3, 1010, 207, 139, 3, 753, 18, 1]\n" 334 | ] 335 | } 336 | ], 337 | "source": [ 338 | "end_word = data_loader.dataset.vocab.end_word\n", 339 | "print(\"Special end word:\", end_word)\n", 340 | "\n", 341 | "sample_caption.append(data_loader.dataset.vocab(end_word))\n", 342 | "print(sample_caption)" 343 | ] 344 | }, 345 | { 346 | "cell_type": "markdown", 347 | "metadata": {}, 348 | "source": [ 349 | "Finally, in **`line 6`**, we convert the list of integers to a PyTorch tensor and cast it to [long type](http://pytorch.org/docs/master/tensors.html#torch.Tensor.long). You can read more about the different PyTorch tensor types in the [PyTorch documentation](http://pytorch.org/docs/master/tensors.html)."
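The six steps above can be reproduced outside the notebook in a minimal sketch that uses only the standard library. It assumes a tiny hand-written `word2idx` dictionary. It also replaces `nltk.tokenize.word_tokenize` with a simple regular-expression tokenizer (`simple_tokenize`, a hypothetical stand-in that only handles plain words and punctuation). Where **`line 6`** would build a `torch.long` tensor, the sketch returns a plain list.

```python
import re

START, END, UNK = "<start>", "<end>", "<unk>"


def simple_tokenize(text):
    """Hypothetical stand-in for nltk.tokenize.word_tokenize: words and punctuation become separate tokens."""
    return re.findall(r"\w+|[^\w\s]", text.lower())


def encode_caption(caption, word2idx):
    """Map a caption string to a list of word ids wrapped in start/end ids."""
    tokens = simple_tokenize(caption)  # lowercase + tokenize (line 1)
    ids = [word2idx[START]]  # start id first (lines 2-3)
    # One id per token; unseen tokens fall back to the <unk> id (line 4).
    ids.extend(word2idx.get(token, word2idx[UNK]) for token in tokens)
    ids.append(word2idx[END])  # end id last (line 5)
    # Line 6 would instead return torch.tensor(ids, dtype=torch.long).
    return ids


if __name__ == "__main__":
    toy_word2idx = {START: 0, END: 1, UNK: 2, "a": 3, "person": 4, "riding": 5, "skateboard": 6, ".": 7}
    print(encode_caption("A person riding a skateboard.", toy_word2idx))  # [0, 3, 4, 5, 3, 6, 7, 1]
    print(encode_caption("A dog riding a skateboard.", toy_word2idx))  # [0, 3, 2, 5, 3, 6, 7, 1]
```

In the second call, the word `dog` is not in the toy dictionary, so it maps to the `<unk>` id `2`.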
350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 8, 355 | "metadata": { 356 | "pycharm": { 357 | "is_executing": false 358 | } 359 | }, 360 | "outputs": [ 361 | { 362 | "name": "stdout", 363 | "output_type": "stream", 364 | "text": [ 365 | "tensor([ 0, 3, 98, 754, 3, 396, 39, 3, 1010, 207, 139, 3,\n", 366 | " 753, 18, 1])\n" 367 | ] 368 | } 369 | ], 370 | "source": [ 371 | "sample_caption = torch.Tensor(sample_caption).long()\n", 372 | "print(sample_caption)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": {}, 378 | "source": [ 379 | "## IDs\n", 380 | "start : 0\n", 381 | "\n", 382 | "end : 1\n", 383 | "\n", 384 | "unk : 2 " 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": {}, 390 | "source": [ 391 | "In summary, every caption is converted to a list of tokens, with _special_ start and end tokens marking the beginning and end of the sentence:\n", 392 | "```\n", 393 | "[<start>, 'a', 'person', 'doing', 'a', 'trick', 'on', 'a', 'rail', 'while', 'riding', 'a', 'skateboard', '.', <end>]\n", 394 | "```\n", 395 | "This list of tokens is then turned into a list of integers, where every distinct word in the vocabulary has an associated integer value:\n", 396 | "```\n", 397 | "[0, 3, 98, 754, 3, 396, 39, 3, 1010, 207, 139, 3, 753, 18, 1]\n", 398 | "```\n", 399 | "Finally, this list is converted to a PyTorch tensor. All of the captions in the COCO dataset are pre-processed with the same procedure from **`lines 1-6`** described above. \n", 400 | "\n", 401 | "To convert a token to its integer id, we call `data_loader.dataset.vocab` as a function. This works because the `Vocabulary` class in **vocabulary.py** implements the `__call__` method shown below. 
\n", 402 | "\n", 403 | "```python\n", 404 | "def __call__(self, word):\n", 405 | " if not word in self.word2idx:\n", 406 | " return self.word2idx[self.unk_word]\n", 407 | " return self.word2idx[word]\n", 408 | "```\n", 409 | "\n", 410 | "The `word2idx` instance variable is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html#dictionaries) that is indexed by string-valued keys (mostly tokens obtained from training captions). For each key, the corresponding value is the integer that the token is mapped to in the pre-processing step.\n", 411 | "\n", 412 | "Run the cell below to view a subset of this dictionary." 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 9, 418 | "metadata": { 419 | "pycharm": { 420 | "is_executing": false 421 | } 422 | }, 423 | "outputs": [ 424 | { 425 | "data": { 426 | "text/plain": [ 427 | "{'': 0,\n", 428 | " '': 1,\n", 429 | " '': 2,\n", 430 | " 'a': 3,\n", 431 | " 'very': 4,\n", 432 | " 'clean': 5,\n", 433 | " 'and': 6,\n", 434 | " 'well': 7,\n", 435 | " 'decorated': 8,\n", 436 | " 'empty': 9}" 437 | ] 438 | }, 439 | "execution_count": 9, 440 | "metadata": {}, 441 | "output_type": "execute_result" 442 | } 443 | ], 444 | "source": [ 445 | "# Preview the word2idx dictionary.\n", 446 | "dict(list(data_loader.dataset.vocab.word2idx.items())[:10])" 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": {}, 452 | "source": [ 453 | "Next, we print the total number of keys, which is the size of the vocabulary."
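The lookup performed by `__call__`, together with the `len()` call used in the next cell, can be sketched with a small stand-alone class. `TinyVocab` is a hypothetical simplification for illustration, not the repository's `Vocabulary` class. It assigns ids in insertion order, so the special tokens get `0`, `1`, and `2`, matching the ids printed above. It expresses the `<unk>` fallback with `dict.get`, which behaves the same way as the `if`/`return` form in the snippet.

```python
class TinyVocab:
    """Hypothetical, simplified stand-in for the Vocabulary class."""

    def __init__(self, words, start_word="<start>", end_word="<end>", unk_word="<unk>"):
        self.unk_word = unk_word
        self.word2idx = {}
        self.idx2word = {}
        # Special tokens are added first, so they get ids 0, 1 and 2.
        for word in (start_word, end_word, unk_word, *words):
            self.add_word(word)

    def add_word(self, word):
        # Ignore duplicates so every word keeps its first id.
        if word not in self.word2idx:
            idx = len(self.word2idx)
            self.word2idx[word] = idx
            self.idx2word[idx] = word

    def __call__(self, word):
        # Unseen words map to the <unk> id, like the __call__ snippet above.
        return self.word2idx.get(word, self.word2idx[self.unk_word])

    def __len__(self):
        return len(self.word2idx)


if __name__ == "__main__":
    vocab = TinyVocab(["a", "person", "skateboard"])
    print(vocab("person"), vocab("jfkafejw"), len(vocab))  # 4 2 6
```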
454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 10, 459 | "metadata": { 460 | "pycharm": { 461 | "is_executing": false 462 | } 463 | }, 464 | "outputs": [ 465 | { 466 | "name": "stdout", 467 | "output_type": "stream", 468 | "text": [ 469 | "Total number of tokens in vocabulary: 8852\n" 470 | ] 471 | } 472 | ], 473 | "source": [ 474 | "# Print the total number of keys in the word2idx dictionary.\n", 475 | "print(\"Total number of tokens in vocabulary:\", len(data_loader.dataset.vocab))" 476 | ] 477 | }, 478 | { 479 | "cell_type": "markdown", 480 | "metadata": {}, 481 | "source": [ 482 | "In **vocabulary.py**, the `word2idx` dictionary is created by looping over the captions in the training dataset. If a token appears at least `vocab_threshold` times in the training set, it is added as a key to the dictionary and assigned a unique integer. We can change `vocab_threshold` when instantiating the data loader. In general, **smaller** values of `vocab_threshold` yield a **larger** vocabulary. We can check this in the next code cell by decreasing the value of `vocab_threshold` before creating a new data loader. 
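The thresholding rule described above can be sketched in a few lines. `build_word2idx` is a hypothetical, simplified helper. It assumes a regular-expression tokenizer in place of NLTK and places the special tokens `<start>`, `<end>`, and `<unk>` at ids 0 to 2. The real **vocabulary.py** may differ in detail. On the toy captions, a lower threshold keeps more words.

```python
import re
from collections import Counter

SPECIALS = ("<start>", "<end>", "<unk>")


def build_word2idx(captions, vocab_threshold, specials=SPECIALS):
    """Hypothetical sketch: keep every token seen at least vocab_threshold times."""
    counter = Counter()
    for caption in captions:
        # Regex tokenizer as a stand-in for nltk.tokenize.word_tokenize.
        counter.update(re.findall(r"\w+|[^\w\s]", caption.lower()))
    # Special tokens come first, so they always receive ids 0, 1, 2.
    word2idx = {token: idx for idx, token in enumerate(specials)}
    for word, count in counter.items():
        if count >= vocab_threshold:
            word2idx[word] = len(word2idx)
    return word2idx


if __name__ == "__main__":
    captions = ["A cat on a mat.", "A dog on a rug.", "A cat sleeping."]
    for threshold in (3, 2, 1):
        print(threshold, len(build_word2idx(captions, threshold)))
    # Prints "3 5", then "2 7", then "1 11".
```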
" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": null, 488 | "metadata": { 489 | "pycharm": { 490 | "is_executing": false 491 | } 492 | }, 493 | "outputs": [], 494 | "source": [ 495 | "# Modify the minimum word count threshold.\n", 496 | "vocab_threshold = 4\n", 497 | "\n", 498 | "# Obtain the data loader.\n", 499 | "data_loader = get_loader(\n", 500 | " transform=transform_train,\n", 501 | " mode=\"train\",\n", 502 | " batch_size=batch_size,\n", 503 | " vocab_threshold=vocab_threshold,\n", 504 | " vocab_from_file=False,\n", 505 | " cocoapi_loc=cocoapi_dir,\n", 506 | ")" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": null, 512 | "metadata": { 513 | "pycharm": { 514 | "is_executing": false 515 | } 516 | }, 517 | "outputs": [], 518 | "source": [ 519 | "# Print the total number of keys in the word2idx dictionary.\n", 520 | "print(f\"Total number of tokens in vocabulary: {len(data_loader.dataset.vocab)}\")" 521 | ] 522 | }, 523 | { 524 | "cell_type": "markdown", 525 | "metadata": {}, 526 | "source": [ 527 | "There are also a few special keys in the `word2idx` dictionary: the special start word (`\"<start>\"`), the special end word (`\"<end>\"`), and one more special token for unknown words (`\"<unk>\"`). Any token that does not appear in the `word2idx` dictionary is treated as an unknown word and is mapped to the integer `2` during pre-processing."
528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 10, 533 | "metadata": { 534 | "pycharm": { 535 | "is_executing": false 536 | } 537 | }, 538 | "outputs": [ 539 | { 540 | "name": "stdout", 541 | "output_type": "stream", 542 | "text": [ 543 | "Special unknown word: \n", 544 | "All unknown words are mapped to this integer: 2\n" 545 | ] 546 | } 547 | ], 548 | "source": [ 549 | "unk_word = data_loader.dataset.vocab.unk_word\n", 550 | "print(f\"Special unknown word: {unk_word}\")\n", 551 | "\n", 552 | "print(\n", 553 | " f\"All unknown words are mapped to this integer: {data_loader.dataset.vocab(unk_word)}\"\n", 554 | ")" 555 | ] 556 | }, 557 | { 558 | "cell_type": "markdown", 559 | "metadata": {}, 560 | "source": [ 561 | "We can verify this by pre-processing made-up words that never appear in the training captions. " 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": 11, 567 | "metadata": { 568 | "pycharm": { 569 | "is_executing": false 570 | } 571 | }, 572 | "outputs": [ 573 | { 574 | "name": "stdout", 575 | "output_type": "stream", 576 | "text": [ 577 | "2\n", 578 | "2\n" 579 | ] 580 | } 581 | ], 582 | "source": [ 583 | "print(data_loader.dataset.vocab(\"jfkafejw\"))\n", 584 | "print(data_loader.dataset.vocab(\"ieowoqjf\"))" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": 12, 590 | "metadata": {}, 591 | "outputs": [ 592 | { 593 | "name": "stdout", 594 | "output_type": "stream", 595 | "text": [ 596 | "18\n" 597 | ] 598 | } 599 | ], 600 | "source": [ 601 | "print(data_loader.dataset.vocab(\".\"))" 602 | ] 603 | }, 604 | { 605 | "cell_type": "markdown", 606 | "metadata": {}, 607 | "source": [ 608 | "The final thing to mention is the `vocab_from_file` argument that is supplied when creating a data loader. 
To understand this argument, note that when a new data loader is created, the vocabulary (`data_loader.dataset.vocab`) is saved as a [pickle](https://docs.python.org/3/library/pickle.html) file named `vocab.pkl` in the project folder.\n", 609 | "\n", 610 | "If you want to change the value of the `vocab_threshold` argument, you **must** set `vocab_from_file=False` for the change to take effect. \n", 611 | "\n", 612 | "Once you have settled on a value for `vocab_threshold`, create the data loader *one more time* with that value so that the new vocabulary is saved to file. From then on, set `vocab_from_file=True` to load the vocabulary from file and speed up instantiation of the data loader. Building the vocabulary from scratch is the most time-consuming part of instantiating the data loader, so we switch to `vocab_from_file=True` as soon as the configuration is fixed.\n", 613 | "\n", 614 | "If `vocab_from_file=True`, any `vocab_threshold` value supplied when instantiating the data loader is ignored." 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": 13, 620 | "metadata": { 621 | "pycharm": { 622 | "is_executing": false 623 | } 624 | }, 625 | "outputs": [ 626 | { 627 | "name": "stdout", 628 | "output_type": "stream", 629 | "text": [ 630 | "Vocabulary successfully loaded from vocab.pkl file!\n", 631 | "loading annotations into memory...\n", 632 | "Done (t=0.56s)\n", 633 | "creating index...\n", 634 | "index created!\n", 635 | "Obtaining caption lengths...\n" 636 | ] 637 | }, 638 | { 639 | "name": "stderr", 640 | "output_type": "stream", 641 | "text": [ 642 | "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 414113/414113 [00:43<00:00, 9587.01it/s]\n" 643 | ] 644 | } 645 | ], 646 | "source": [ 647 | "# Obtain the data loader (from file). 
This skips rebuilding the vocabulary, so it runs faster than before!\n", 648 | "data_loader = get_loader(\n", 649 | " transform=transform_train,\n", 650 | " mode=\"train\",\n", 651 | " batch_size=batch_size,\n", 652 | " vocab_from_file=True,\n", 653 | " cocoapi_loc=cocoapi_dir,\n", 654 | ")" 655 | ] 656 | }, 657 | { 658 | "cell_type": "markdown", 659 | "metadata": {}, 660 | "source": [ 661 | "The next section explains how to use the data loader to obtain batches of training data." 662 | ] 663 | }, 664 | { 665 | "cell_type": "markdown", 666 | "metadata": {}, 667 | "source": [ 668 | "\n", 669 | "## Step 2: Use the Data Loader to Obtain Batches\n", 670 | "\n", 671 | "The captions in the dataset vary greatly in length. We can see this by examining `data_loader.dataset.caption_lengths`, a Python list with one entry for each training caption (each value is the length of the corresponding caption). \n", 672 | "\n", 673 | "In the code cell below, we use this list to print the number of training captions of each length. As the output shows, the most common caption length is 10 (86,302 captions), and roughly 86% of all captions have a length between 8 and 13. Very short and very long captions are rare. 
" 674 | ] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "execution_count": 14, 679 | "metadata": {}, 680 | "outputs": [ 681 | { 682 | "data": { 683 | "text/plain": [ 684 | "(list, 414113)" 685 | ] 686 | }, 687 | "execution_count": 14, 688 | "metadata": {}, 689 | "output_type": "execute_result" 690 | } 691 | ], 692 | "source": [ 693 | "type(data_loader.dataset.caption_lengths), len(data_loader.dataset.caption_lengths)" 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": 15, 699 | "metadata": { 700 | "pycharm": { 701 | "is_executing": false 702 | } 703 | }, 704 | "outputs": [ 705 | { 706 | "name": "stdout", 707 | "output_type": "stream", 708 | "text": [ 709 | "value: 10 --- count: 86302\n", 710 | "value: 11 --- count: 79971\n", 711 | "value: 9 --- count: 71920\n", 712 | "value: 12 --- count: 57653\n", 713 | "value: 13 --- count: 37668\n", 714 | "value: 14 --- count: 22342\n", 715 | "value: 8 --- count: 20742\n", 716 | "value: 15 --- count: 12839\n", 717 | "value: 16 --- count: 7736\n", 718 | "value: 17 --- count: 4845\n", 719 | "value: 18 --- count: 3101\n", 720 | "value: 19 --- count: 2017\n", 721 | "value: 7 --- count: 1594\n", 722 | "value: 20 --- count: 1453\n", 723 | "value: 21 --- count: 997\n", 724 | "value: 22 --- count: 684\n", 725 | "value: 23 --- count: 533\n", 726 | "value: 24 --- count: 384\n", 727 | "value: 25 --- count: 277\n", 728 | "value: 26 --- count: 214\n", 729 | "value: 27 --- count: 160\n", 730 | "value: 28 --- count: 114\n", 731 | "value: 29 --- count: 87\n", 732 | "value: 30 --- count: 58\n", 733 | "value: 31 --- count: 49\n", 734 | "value: 32 --- count: 44\n", 735 | "value: 34 --- count: 40\n", 736 | "value: 37 --- count: 32\n", 737 | "value: 35 --- count: 31\n", 738 | "value: 33 --- count: 30\n", 739 | "value: 36 --- count: 26\n", 740 | "value: 38 --- count: 18\n", 741 | "value: 39 --- count: 18\n", 742 | "value: 43 --- count: 16\n", 743 | "value: 44 --- count: 16\n", 744 | "value: 48 --- count: 12\n", 745 | 
"value: 45 --- count: 11\n", 746 | "value: 42 --- count: 10\n", 747 | "value: 40 --- count: 9\n", 748 | "value: 49 --- count: 9\n", 749 | "value: 46 --- count: 9\n", 750 | "value: 47 --- count: 7\n", 751 | "value: 50 --- count: 6\n", 752 | "value: 51 --- count: 6\n", 753 | "value: 41 --- count: 6\n", 754 | "value: 52 --- count: 5\n", 755 | "value: 54 --- count: 3\n", 756 | "value: 56 --- count: 2\n", 757 | "value: 6 --- count: 2\n", 758 | "value: 53 --- count: 2\n", 759 | "value: 55 --- count: 2\n", 760 | "value: 57 --- count: 1\n" 761 | ] 762 | } 763 | ], 764 | "source": [ 765 | "from collections import Counter\n", 766 | "\n", 767 | "# Tally the total number of training captions with each length.\n", 768 | "counter = Counter(data_loader.dataset.caption_lengths)\n", 769 | "lengths = sorted(counter.items(), key=lambda pair: pair[1], reverse=True)\n", 770 | "for value, count in lengths:\n", 771 | " print(\"value: %2d --- count: %5d\" % (value, count))" 772 | ] 773 | }, 774 | { 775 | "cell_type": "markdown", 776 | "metadata": {}, 777 | "source": [ 778 | "To generate batches of training data, we begin by first sampling a caption length (where the probability that any length is drawn is proportional to the number of captions with that length in the dataset). Then, we retrieve a batch of size `batch_size` of image-caption pairs, where all captions have the sampled length. This approach for assembling batches matches the procedure in [this paper](https://arxiv.org/pdf/1502.03044.pdf) and has been shown to be computationally efficient without degrading performance.\n", 779 | "\n", 780 | "The code cell below generates a batch. The `get_train_indices` method in the `CoCoDataset` class first samples a caption length, and then samples `batch_size` indices corresponding to training data points with captions of that length. 
These indices are stored below in `indices`.\n", 781 | "\n", 782 | "These indices are supplied to the data loader, which then is used to retrieve the corresponding data points. The pre-processed images and captions in the batch are stored in `images` and `captions`." 783 | ] 784 | }, 785 | { 786 | "cell_type": "code", 787 | "execution_count": 16, 788 | "metadata": { 789 | "pycharm": { 790 | "is_executing": false 791 | }, 792 | "scrolled": false 793 | }, 794 | "outputs": [ 795 | { 796 | "name": "stdout", 797 | "output_type": "stream", 798 | "text": [ 799 | "10\n", 800 | "sampled indices: [388202, 286065, 397083, 408969, 92246, 354550, 13088, 43820, 124136, 374756]\n", 801 | "images.shape: torch.Size([10, 3, 224, 224])\n", 802 | "captions.shape: torch.Size([10, 13])\n" 803 | ] 804 | } 805 | ], 806 | "source": [ 807 | "print(batch_size)\n", 808 | "# Randomly sample a caption length, and sample indices with that length.\n", 809 | "indices = data_loader.dataset.get_train_indices()\n", 810 | "print(\"sampled indices:\", indices)\n", 811 | "\n", 812 | "# Create and assign a batch sampler to retrieve a batch with the sampled indices.\n", 813 | "new_sampler = data.sampler.SubsetRandomSampler(indices=indices)\n", 814 | "data_loader.batch_sampler.sampler = new_sampler\n", 815 | "\n", 816 | "# Obtain the batch.\n", 817 | "images, captions = next(iter(data_loader))\n", 818 | "\n", 819 | "print(\"images.shape:\", images.shape)\n", 820 | "print(\"captions.shape:\", captions.shape)\n", 821 | "\n", 822 | "# Uncomment the lines of code below to print the pre-processed images and captions.\n", 823 | "# print('images:', images)\n", 824 | "# print('captions:', captions)" 825 | ] 826 | }, 827 | { 828 | "cell_type": "markdown", 829 | "metadata": {}, 830 | "source": [ 831 | "**Each time we run the code cell above, a different caption length is sampled**, and a different batch of training data is returned.\n", 832 | "\n", 833 | "We will train our model in the next notebook in this sequence 
(**2_Training.ipynb**).\n", 834 | "\n", 835 | "> Before moving to the next notebook in the sequence (**2_Training.ipynb**), take the time to become familiar with the code in **coco_dataset.py**, **data_loader.py**, and **vocabulary.py**. **Step 1** and **Step 2** of this notebook provide a basic introduction to this code, but the description is not exhaustive.\n", 836 | "\n", 837 | "In the next steps, we focus on learning how to specify a CNN-RNN architecture in PyTorch, towards the goal of image captioning." 838 | ] 839 | }, 840 | { 841 | "cell_type": "markdown", 842 | "metadata": {}, 843 | "source": [ 844 | "# Architecture Details\n", 845 | "\n", 846 | "![Image Captioning CNN-RNN model](images/encoder-decoder.png)\n", 847 | "\n", 848 | "The architecture consists of a CNN encoder and an RNN decoder. The CNN encoder is a ResNet pre-trained on ImageNet. ResNet is a deep convolutional neural network that, like VGG, stacks small convolutional layers, but adds skip (residual) connections. It works very well on tasks like image recognition because each residual block only has to learn the difference between its input and its output, while the identity shortcut passes the input through unchanged; this makes very deep networks easier to train. A network pre-trained on ImageNet is already good at extracting useful low-level and high-level image features, so it naturally serves as a feature encoder for the image we want to caption. Since we are not doing the traditional image classification task, we drop the last fully connected layer and replace it with a new trainable fully connected layer that transforms the final feature map into an encoding that is more useful for the RNN decoder.\n", 849 | "\n", 850 | "RNNs have long been useful in language tasks because they can model sequential data such as text. LSTMs, in particular, carry both long-term and short-term information as memory in the network. Thus, we pick an RNN decoder for the captioning task.
Specifically, following the spirit of sequence-to-sequence (seq2seq) models used in machine translation, I adopted the architecture choices in [this paper](https://arxiv.org/pdf/1411.4555.pdf) and use an LSTM to generate captions from the encoded information produced by the CNN encoder. **I first feed the CNN encoder output, concatenated with the \"START\" token, as the initial input to the RNN decoder.** A fully connected layer applied to the hidden state at each time step outputs a softmax probability distribution over the words in the entire vocabulary. During training, the ground-truth caption tokens are fed in as the decoder inputs; at inference time, we choose the word with the highest probability as the word generated at that time step and feed it back as the input for the next step, continuing until the caption reaches the maximum length or the network generates the \"STOP\" token, which marks the end of the sentence." 851 | ] 852 | }, 853 | { 854 | "cell_type": "markdown", 855 | "metadata": {}, 856 | "source": [ 857 | "\n", 858 | "## Step 3: Experimenting with the CNN Encoder\n", 859 | "\n", 860 | "Run the code cell below to import `EncoderCNN` and `DecoderRNN` from **model.py**. " 861 | ] 862 | }, 863 | { 864 | "cell_type": "code", 865 | "execution_count": 17, 866 | "metadata": { 867 | "pycharm": { 868 | "is_executing": false 869 | } 870 | }, 871 | "outputs": [], 872 | "source": [ 873 | "# Import EncoderCNN and DecoderRNN.\n", 874 | "# Watch for any changes in model.py, and re-load it automatically.\n", 875 | "from model import EncoderCNN, DecoderRNN" 876 | ] 877 | }, 878 | { 879 | "cell_type": "markdown", 880 | "metadata": {}, 881 | "source": [ 882 | "In the next code cell, we define a `device` that we will use to move PyTorch tensors to the GPU (if CUDA is available). Run this code cell before continuing."
883 | ] 884 | }, 885 | { 886 | "cell_type": "code", 887 | "execution_count": 18, 888 | "metadata": { 889 | "pycharm": { 890 | "is_executing": false 891 | } 892 | }, 893 | "outputs": [ 894 | { 895 | "name": "stdout", 896 | "output_type": "stream", 897 | "text": [ 898 | "cuda\n" 899 | ] 900 | } 901 | ], 902 | "source": [ 903 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 904 | "print(device)" 905 | ] 906 | }, 907 | { 908 | "cell_type": "markdown", 909 | "metadata": {}, 910 | "source": [ 911 | "Run the code cell below to instantiate the CNN encoder in `encoder`. \n", 912 | "\n", 913 | "The pre-processed images from the batch in **Step 2** of this notebook are then passed through the encoder, and the output is stored in `features`." 914 | ] 915 | }, 916 | { 917 | "cell_type": "code", 918 | "execution_count": 19, 919 | "metadata": { 920 | "pycharm": { 921 | "is_executing": false 922 | } 923 | }, 924 | "outputs": [ 925 | { 926 | "name": "stdout", 927 | "output_type": "stream", 928 | "text": [ 929 | "type(features): <class 'torch.Tensor'>\n", 930 | "features.shape: torch.Size([10, 256])\n", 931 | "captions.shape: torch.Size([10, 13])\n" 932 | ] 933 | } 934 | ], 935 | "source": [ 936 | "# Specify the dimensionality of the image embedding.\n", 937 | "image_embed_size = 256\n", 938 | "\n", 939 | "# Initialize the encoder.\n", 940 | "encoder = EncoderCNN(image_embed_size)\n", 941 | "\n", 942 | "# Move the encoder to the appropriate device.\n", 943 | "encoder.to(device)\n", 944 | "\n", 945 | "# Move last batch of images (from Step 2) to GPU if CUDA is available.\n", 946 | "images = images.to(device)\n", 947 | "\n", 948 | "# Pass the images through the encoder.\n", 949 | "features = encoder(images)\n", 950 | "\n", 951 | "print(\"type(features):\", type(features))\n", 952 | "print(\"features.shape:\", features.shape)\n", 953 | "print(\"captions.shape:\", captions.shape)\n", 954 | "\n", 955 | "# Check that the encoder satisfies the requirements!\n", 956 | "assert
type(features) == torch.Tensor, \"Encoder output needs to be a PyTorch Tensor.\"\n", 957 | "\n", 958 | "assert (features.shape[0] == batch_size) and (\n", 959 | " features.shape[1] == image_embed_size\n", 960 | "), \"The shape of the encoder output is incorrect.\"" 961 | ] 962 | }, 963 | { 964 | "cell_type": "markdown", 965 | "metadata": {}, 966 | "source": [ 967 | "The encoder uses the pre-trained ResNet-50 architecture (with the final fully connected layer removed) to extract features from a batch of pre-processed images. The output is then flattened into a vector and passed through a `Linear` layer that maps the feature vector to the same size as the word embedding.\n", 968 | "\n", 969 | "![Encoder](images/encoder.png)\n", 970 | "\n", 971 | "We can amend the encoder in **model.py** to experiment with other architectures. In particular, using a [different pre-trained model architecture](http://pytorch.org/docs/master/torchvision/models.html) or adding [batch normalization](http://pytorch.org/docs/master/nn.html#normalization-layers) could be good options. \n", 972 | "\n", 973 | "\n", 974 | "For this project, I incorporated a pre-trained CNN into the encoder. The `EncoderCNN` class takes `image_embed_size` as an input argument, which also corresponds to the dimensionality of the input to the RNN decoder that we will implement in Step 4. When we train the model in the next notebook in this sequence (**2_Training.ipynb**), we can tweak the value of `image_embed_size`.\n", 975 | "\n", 976 | "If you decide to modify the `EncoderCNN` class, save **model.py** and re-execute the code cell above. If the code cell raises an assertion error, follow the error message to fix the code before proceeding. The assert statements ensure that `features` is a PyTorch tensor with shape `[batch_size, image_embed_size]`."
977 | ] 978 | }, 979 | { 980 | "cell_type": "markdown", 981 | "metadata": {}, 982 | "source": [ 983 | "\n", 984 | "## Step 4: Implementing the RNN Decoder\n", 985 | "\n", 986 | "Before executing the next code cell, please read the `__init__` and `forward` methods in the `DecoderRNN` class in **model.py**. We will work on the `sample` method when we reach **3_Inference.ipynb**.\n", 987 | "\n", 988 | "The decoder is an instance of the `DecoderRNN` class and accepts the following as input:\n", 989 | "- the PyTorch tensor `features` containing the embedded image features (output in Step 3, when the last batch of images from Step 2 was passed through `encoder`)\n", 990 | "- a PyTorch tensor corresponding to the last batch of captions (`captions`) from Step 2.\n", 991 | "\n", 992 | "Note that the way I have written the data loader simplifies the code a bit. In particular, every training batch contains pre-processed **captions that all have the same length** (`captions.shape[1]`), so **we do not need to worry about padding**. \n", 993 | "> I have implemented the decoder described in [this paper](https://arxiv.org/pdf/1411.4555.pdf), but you can implement any architecture of your choosing. \n", 994 | "\n", 995 | "Although we will test the decoder using the last batch that is currently stored in the notebook, the decoder can accept an arbitrary batch (of embedded image features and pre-processed captions, where all captions have the same length) as input. \n", 996 | "\n", 997 | "![Decoder](images/decoder.png)\n", 998 | "\n", 999 | "In the code cell below, `outputs` is a PyTorch tensor with size `[batch_size, captions.shape[1], vocab_size]`. The output is designed such that `outputs[i,j,k]` contains the model's predicted score indicating how likely it is that the `j`-th token in the `i`-th caption in the batch is the `k`-th token in the vocabulary.
In the next notebook of the sequence (**2_Training.ipynb**), we provide code to supply these scores to the [`torch.nn.CrossEntropyLoss`](http://pytorch.org/docs/master/nn.html#torch.nn.CrossEntropyLoss) loss function in PyTorch." 1000 | ] 1001 | }, 1002 | { 1003 | "cell_type": "code", 1004 | "execution_count": 20, 1005 | "metadata": {}, 1006 | "outputs": [ 1007 | { 1008 | "name": "stdout", 1009 | "output_type": "stream", 1010 | "text": [ 1011 | "256\n" 1012 | ] 1013 | } 1014 | ], 1015 | "source": [ 1016 | "print(image_embed_size)" 1017 | ] 1018 | }, 1019 | { 1020 | "cell_type": "code", 1021 | "execution_count": 21, 1022 | "metadata": { 1023 | "pycharm": { 1024 | "is_executing": false 1025 | } 1026 | }, 1027 | "outputs": [ 1028 | { 1029 | "name": "stdout", 1030 | "output_type": "stream", 1031 | "text": [ 1032 | "type(outputs): <class 'torch.Tensor'>\n", 1033 | "outputs.shape: torch.Size([10, 13, 8852])\n" 1034 | ] 1035 | } 1036 | ], 1037 | "source": [ 1038 | "# Specify the number of features in the hidden state of the RNN decoder.\n", 1039 | "hidden_size = 512\n", 1040 | "\n", 1041 | "word_embed_size = image_embed_size\n", 1042 | "\n", 1043 | "# Store the size of the vocabulary.\n", 1044 | "vocab_size = len(data_loader.dataset.vocab)\n", 1045 | "\n", 1046 | "# Initialize the decoder.\n", 1047 | "decoder = DecoderRNN(word_embed_size, hidden_size, vocab_size)\n", 1048 | "\n", 1049 | "# Move the decoder to the appropriate device.\n", 1050 | "decoder.to(device)\n", 1051 | "\n", 1052 | "\n", 1053 | "# Move last batch of captions (from Step 2) to GPU if CUDA is available\n", 1054 | "captions = captions.to(device)\n", 1055 | "\n", 1056 | "# Pass the encoder output and captions through the decoder.\n", 1057 | "# outputs[i,j,k] contains the model's predicted score:\n", 1058 | "# how likely the j-th token in the i-th caption in the batch is the k-th token in the vocabulary.\n", 1059 | "\n", 1060 | "outputs = decoder(features, captions) # (bs, cap_length, vocab_size)\n", 1061 | "\n", 1062 | "\n", 1063 |
"print(\"type(outputs):\", type(outputs))\n", 1064 | "print(\"outputs.shape:\", outputs.shape)\n", 1065 | "\n", 1066 | "# Check that the decoder satisfies the requirements!\n", 1067 | "assert type(outputs) == torch.Tensor, \"Decoder output needs to be a PyTorch Tensor.\"\n", 1068 | "assert (\n", 1069 | " (outputs.shape[0] == batch_size)\n", 1070 | " and (outputs.shape[1] == captions.shape[1])\n", 1071 | " and (outputs.shape[2] == vocab_size)\n", 1072 | "), \"The shape of the decoder output is incorrect.\"" 1073 | ] 1074 | }, 1075 | { 1076 | "cell_type": "markdown", 1077 | "metadata": {}, 1078 | "source": [ 1079 | "When we are training the model in the next notebook in this sequence (**2_Training.ipynb**), we can tweak the value of `hidden_size`." 1080 | ] 1081 | } 1082 | ], 1083 | "metadata": { 1084 | "anaconda-cloud": {}, 1085 | "kernelspec": { 1086 | "display_name": "Python 3 (ipykernel)", 1087 | "language": "python", 1088 | "name": "python3" 1089 | }, 1090 | "language_info": { 1091 | "codemirror_mode": { 1092 | "name": "ipython", 1093 | "version": 3 1094 | }, 1095 | "file_extension": ".py", 1096 | "mimetype": "text/x-python", 1097 | "name": "python", 1098 | "nbconvert_exporter": "python", 1099 | "pygments_lexer": "ipython3", 1100 | "version": "3.7.10" 1101 | }, 1102 | "varInspector": { 1103 | "cols": { 1104 | "lenName": 16, 1105 | "lenType": 16, 1106 | "lenVar": 40 1107 | }, 1108 | "kernels_config": { 1109 | "python": { 1110 | "delete_cmd_postfix": "", 1111 | "delete_cmd_prefix": "del ", 1112 | "library": "var_list.py", 1113 | "varRefreshCmd": "print(var_dic_list())" 1114 | }, 1115 | "r": { 1116 | "delete_cmd_postfix": ") ", 1117 | "delete_cmd_prefix": "rm(", 1118 | "library": "var_list.r", 1119 | "varRefreshCmd": "cat(var_dic_list()) " 1120 | } 1121 | }, 1122 | "types_to_exclude": [ 1123 | "module", 1124 | "function", 1125 | "builtin_function_or_method", 1126 | "instance", 1127 | "_Feature" 1128 | ], 1129 | "window_display": false 1130 | } 1131 | }, 1132 | 
"nbformat": 4, 1133 | "nbformat_minor": 2 1134 | } 1135 | -------------------------------------------------------------------------------- /2_Training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Training Network\n", 8 | "\n", 9 | "In this notebook, we will train the CNN-RNN model. We can try out many different architectures and hyperparameters when searching for a good model.\n", 10 | "\n", 11 | "Outline of this notebook:\n", 12 | "- [Step 1](#step1): Training Setup\n", 13 | "- [Step 2](#step2): Training the Model\n", 14 | "- [Step 3](#step3): Validating the Model using BLEU Score" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "\n", 22 | "## Step 1: Training Setup\n", 23 | "\n", 24 | "In this step of the notebook, we will customize the training of the CNN-RNN model by specifying hyperparameters and setting other options that are important to the training procedure. The values we set now will be used when training our model in **Step 2** below.\n", 25 | "\n", 26 | "### Parameters\n", 27 | "\n", 28 | "We begin by setting the following variables:\n", 29 | "- `batch_size` - the batch size of each training batch. It is the number of image-caption pairs used to update the model weights in each training step. \n", 30 | "- `vocab_threshold` - the minimum word count threshold. Note that a larger threshold will result in a smaller vocabulary, whereas a smaller threshold will include rarer words and result in a larger vocabulary. \n", 31 | "- `vocab_from_file` - a Boolean that decides whether to load the vocabulary from file. \n", 32 | "- `embed_size` - the dimensionality of the image and word embeddings.\n", 33 | "- `hidden_size` - the number of features in the hidden state of the RNN decoder.\n", 34 | "- `num_epochs` - the number of epochs to train the model.
We set `num_epochs=3`, but feel free to increase or decrease this number. [This paper](https://arxiv.org/pdf/1502.03044.pdf) trained a captioning model on a single state-of-the-art GPU for 3 days, but we'll soon see that we can get reasonable results in a matter of a few hours! (_But of course, if we want to compete with current research, we will have to train for much longer._)\n", 35 | "- `save_every` - determines how often to save the model weights. We set `save_every=1` to save the model weights after each epoch. This way, after the `i`th epoch, the encoder and decoder weights will be saved in the `models/` folder as `encoder-i.pkl` and `decoder-i.pkl`, respectively.\n", 36 | "- `print_every` - determines how often to print the batch loss to the Jupyter notebook while training. Note that we probably **will not** observe a monotonic decrease in the loss function while training - this is perfectly fine and completely expected! We keep this at its default value of `20` to avoid clogging the notebook.\n", 37 | "- `log_file` - the name of the text file that records the loss and perplexity at every training step.\n", 38 | "\n", 39 | "\n", 40 | "### Image Transformations\n", 41 | "\n", 42 | "I use `transform_train` as described in the previous notebook. The original [ResNet](https://arxiv.org/pdf/1512.03385.pdf) paper, whose architecture our CNN encoder uses, augments training images by resizing the shorter edge (to a random size between 256 and 480 pixels), taking random 224x224 crops, and applying random horizontal flips, and it subtracts the mean pixel value; the torchvision pre-trained models additionally expect inputs normalized with the ImageNet per-channel means and standard deviations. Since our encoder reuses these pre-trained weights, it makes sense to keep the image preprocessing consistent with the original model.
Thus, I use the default `transform_train` as follows:\n", 43 | "\n", 44 | "```python\n", 45 | "transform_train = transforms.Compose([ \n", 46 | " transforms.Resize(256), # smaller edge of image resized to 256\n", 47 | " transforms.RandomCrop(224), # get 224x224 crop from random location\n", 48 | " transforms.RandomHorizontalFlip(), # horizontally flip image with probability=0.5\n", 49 | " transforms.ToTensor(), # convert the PIL Image to a tensor\n", 50 | " transforms.Normalize((0.485, 0.456, 0.406), # normalize image for pre-trained model\n", 51 | " (0.229, 0.224, 0.225))])\n", 52 | "```\n", 53 | "If you modify this transform, keep in mind that:\n", 54 | "- the images in the dataset have varying heights and widths, and \n", 55 | "- when using a pre-trained model, the images must be normalized in the same way as the data the model was pre-trained on.\n", 56 | "\n", 57 | "\n", 58 | "### Hyperparameters\n", 59 | "\n", 60 | "To obtain a strong initial guess for which hyperparameters are likely to work best, I initially consulted [this paper](https://arxiv.org/pdf/1502.03044.pdf) and [this paper](https://arxiv.org/pdf/1411.4555.pdf). I used a minimum word count threshold of **5**, an embedding size of **512**, and a hidden size of **512**, and trained the network for 3 epochs. The loss initially decreased as expected, but after 20 hours of training, running inference on test images suggested that the network had overfitted to the training data: the generated captions were unrelated to the test images. Repeating the inference with the weights saved after each epoch gave similarly unsatisfactory results. I therefore decreased the embedding size to **256** and trained again, this time for only 1 epoch. This time, the network performed much better!
If you are unhappy with the performance, you can return to this notebook to tweak the hyperparameters (and/or the architecture in **model.py**) and re-train the model.\n", 61 | "\n", 62 | "\n", 63 | "### Trainable Parameters\n", 64 | "\n", 65 | "We can specify a Python list containing the learnable parameters of the model. For instance, if we decide to make all weights in the decoder trainable, but only want to train the weights in the embedding layer of the encoder, then we should set `params` to something like:\n", 66 | "\n", 67 | "```python\n", 68 | "params = list(decoder.parameters()) + list(encoder.embed.parameters()) \n", 69 | "```\n", 70 | "\n", 71 | "I decided to freeze the pre-trained ResNet layers and train only the encoder's new final (embedding) layer, because ResNet is already pre-trained on ImageNet and extracts good features. We could fine-tune the entire ResNet for better performance, but ResNet is a large, deep architecture with many parameters, so freezing it makes training faster, which matters because the RNN decoder is already slow to train. In my experiments, the frozen pre-trained ResNet features indeed worked well. Since the last layer of the CNN encoder transforms the CNN feature map into the input that the RNN needs, it makes sense to train this new fully connected layer from scratch. \n", 72 | "\n", 73 | "The RNN decoder is completely new and not part of the pre-trained ResNet, so we train all of its parameters.\n", 74 | "\n", 75 | "### Optimizer\n", 76 | "\n", 77 | "Finally, we need to select an [optimizer](http://pytorch.org/docs/master/optim.html#torch.optim.Optimizer). I chose the Adam optimizer to minimize the [cross-entropy loss](https://medium.com/swlh/cross-entropy-loss-in-pytorch-c010faf97bab) because it is one of the most popular and effective optimizers: it combines momentum (a running average of past gradients) with per-parameter adaptive learning rates (scaled by a running average of squared gradients, as in RMSProp)."
78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 83, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "The autoreload extension is already loaded. To reload it, use:\n", 90 | " %reload_ext autoreload\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "# Watch for any changes in model.py, and re-load it automatically.\n", 96 | "import math\n", 97 | "from model import EncoderCNN, DecoderRNN\n", 98 | "from data_loader import get_loader\n", 99 | "from data_loader_val import get_loader as val_get_loader\n", 100 | "from pycocotools.coco import COCO\n", 101 | "from torchvision import transforms\n", 102 | "from tqdm.notebook import tqdm\n", 103 | "import torch.nn as nn\n", 104 | "import torch\n", 105 | "import torch.utils.data as data\n", 106 | "from collections import defaultdict\n", 107 | "import json\n", 108 | "import os\n", 109 | "import sys\n", 110 | "import numpy as np\n", 111 | "from nlp_utils import clean_sentence, bleu_score\n", 112 | "\n", 113 | "%load_ext autoreload\n", 114 | "%autoreload 2" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 2, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "# Setting hyperparameters\n", 124 | "batch_size = 128 # batch size\n", 125 | "vocab_threshold = 5 # minimum word count threshold\n", 126 | "vocab_from_file = True # if True, load existing vocab file\n", 127 | "embed_size = 256 # dimensionality of image and word embeddings\n", 128 | "hidden_size = 512 # number of features in hidden state of the RNN decoder\n", 129 | "num_epochs = 3 # number of training epochs\n", 130 | "save_every = 1 # determines frequency of saving model weights\n", 131 | "print_every = 20 # determines window for printing average loss\n", 132 | "log_file = \"training_log.txt\" # name of file with saved training loss and perplexity\n", 133 | "# Path to cocoapi dir\n", 134 | "cocoapi_dir = r\"path/to/cocoapi/dir\"\n", 135 | 
"\n", 136 | "\n", 137 | "# Amend the image transform below.\n", 138 | "transform_train = transforms.Compose(\n", 139 | " [\n", 140 | " # smaller edge of image resized to 256\n", 141 | " transforms.Resize(256),\n", 142 | " # get 224x224 crop from random location\n", 143 | " transforms.RandomCrop(224),\n", 144 | " # horizontally flip image with probability=0.5\n", 145 | " transforms.RandomHorizontalFlip(),\n", 146 | " # convert the PIL Image to a tensor\n", 147 | " transforms.ToTensor(),\n", 148 | " transforms.Normalize(\n", 149 | " (0.485, 0.456, 0.406), # normalize image for pre-trained model\n", 150 | " (0.229, 0.224, 0.225),\n", 151 | " ),\n", 152 | " ]\n", 153 | ")" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 3, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "Vocabulary successfully loaded from vocab.pkl file!\n", 166 | "loading annotations into memory...\n", 167 | "Done (t=1.08s)\n", 168 | "creating index...\n", 169 | "index created!\n", 170 | "Obtaining caption lengths...\n" 171 | ] 172 | }, 173 | { 174 | "name": "stderr", 175 | "output_type": "stream", 176 | "text": [ 177 | "100%|█████████████████████████████████████████████████████████████████████████| 414113/414113 [01:21<00:00, 5075.39it/s]\n" 178 | ] 179 | } 180 | ], 181 | "source": [ 182 | "# Build data loader.\n", 183 | "data_loader = get_loader(\n", 184 | " transform=transform_train,\n", 185 | " mode=\"train\",\n", 186 | " batch_size=batch_size,\n", 187 | " vocab_threshold=vocab_threshold,\n", 188 | " vocab_from_file=vocab_from_file,\n", 189 | " cocoapi_loc=cocoapi_dir,\n", 190 | ")" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 4, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "# The size of the vocabulary.\n", 200 | "vocab_size = len(data_loader.dataset.vocab)\n", 201 | "\n", 202 | "# Initializing the encoder and decoder\n", 203 | "encoder 
= EncoderCNN(embed_size)\n", 204 | "decoder = DecoderRNN(embed_size, hidden_size, vocab_size)\n", 205 | "\n", 206 | "# Move models to device\n", 207 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 208 | "encoder.to(device)\n", 209 | "decoder.to(device)\n", 210 | "\n", 211 | "# Defining the loss function\n", 212 | "criterion = (\n", 213 | " nn.CrossEntropyLoss().cuda() if torch.cuda.is_available() else nn.CrossEntropyLoss()\n", 214 | ")\n", 215 | "\n", 216 | "# Specifying the learnable parameters of the model\n", 217 | "params = list(decoder.parameters()) + list(encoder.embed.parameters())\n", 218 | "\n", 219 | "# Defining the optimizer\n", 220 | "optimizer = torch.optim.Adam(params, lr=0.001)\n", 221 | "\n", 222 | "# Set the total number of training steps per epoch\n", 223 | "total_step = math.ceil(len(data_loader.dataset) / data_loader.batch_sampler.batch_size)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 5, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "3236\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "print(total_step)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": {}, 246 | "source": [ 247 | "\n", 248 | "## Step 2: Training the Model\n", 249 | "\n", 250 | "To resume training from saved weights, note the names of the files containing the encoder and decoder weights that we'd like to load (`encoder_file` and `decoder_file`).
Then we can load the weights using the lines below:\n", 251 | "\n", 252 | "```python\n", 253 | "# Load previously saved weights before resuming training.\n", 254 | "encoder.load_state_dict(torch.load(os.path.join('./models', encoder_file)))\n", 255 | "decoder.load_state_dict(torch.load(os.path.join('./models', decoder_file)))\n", 256 | "```\n", 257 | "\n", 258 | "It is good practice to take extensive notes and record the settings used in each training run while trying out parameters.\n", 259 | "\n", 260 | "### A Note on Tuning Hyperparameters\n", 261 | "\n", 262 | "To figure out how well the model is doing, we can look at how the training loss and [perplexity](http://www.sefidian.com/2022/05/11/understanding-perplexity-for-language-models/) evolve during training. However, this will not tell us whether our model is overfitting to the training data, and, unfortunately, **overfitting is a problem that is commonly encountered when training image captioning models**. \n", 263 | "\n", 264 | "In this project, I do not have strict requirements regarding the performance of the model; the goal is to demonstrate that the model has learned **_something_** when we generate captions on the test data. For now, I train the model for 3 epochs without worrying about performance. Then, we will go to the next notebook in the sequence (**3_Inference.ipynb**) to see how the model performs on the test data. We can then come back to this notebook, amend the hyperparameters (if necessary), and re-train the model.\n", 265 | "\n", 266 | "You can read about some approaches to minimizing overfitting in section 4.3.1 of [this paper](http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=7505636). In the next step of this notebook, I provide some guidance for assessing the performance on the validation dataset."
267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "metadata": { 273 | "scrolled": true 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "# Open the training log file.\n", 278 | "f = open(log_file, \"w\")\n", 279 | "\n", 280 | "for epoch in range(1, num_epochs + 1):\n", 281 | " for i_step in range(1, total_step + 1):\n", 282 | "\n", 283 | " # Randomly sample a caption length, and sample indices with that length.\n", 284 | " indices = data_loader.dataset.get_train_indices()\n", 285 | " # Create and assign a batch sampler to retrieve a batch with the sampled indices.\n", 286 | " new_saosmpler = data.sampler.SubsetRandomSampler(indices=indices)\n", 287 | " data_loader.batch_sampler.sampler = new_sampler\n", 288 | "\n", 289 | " # Obtain the batch.\n", 290 | " images, captions = next(iter(data_loader))\n", 291 | "\n", 292 | " # Move batch of images and captions to GPU if CUDA is available.\n", 293 | " images = images.to(device)\n", 294 | " captions = captions.to(device)\n", 295 | "\n", 296 | " # Zero the gradients.\n", 297 | " decoder.zero_grad()\n", 298 | " encoder.zero_grad()\n", 299 | "\n", 300 | " # Passing the inputs through the CNN-RNN model\n", 301 | " features = encoder(images)\n", 302 | " outputs = decoder(features, captions)\n", 303 | "\n", 304 | " # Calculating the batch loss.\n", 305 | " loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))\n", 306 | "\n", 307 | " # # Uncomment to debug\n", 308 | " # print(outputs.shape, captions.shape)\n", 309 | " # # torch.Size([bs, cap_len, vocab_size]) torch.Size([bs, cap_len])\n", 310 | "\n", 311 | " # print(outputs.view(-1, vocab_size).shape, captions.view(-1).shape)\n", 312 | " # # torch.Size([bs*cap_len, vocab_size]) torch.Size([bs*cap_len])\n", 313 | "\n", 314 | " # Backwarding pass\n", 315 | " loss.backward()\n", 316 | "\n", 317 | " # Updating the parameters in the optimizer\n", 318 | " optimizer.step()\n", 319 | "\n", 320 | " # Getting training statistics\n", 
321 | " stats = (\n", 322 | " f\"Epoch [{epoch}/{num_epochs}], Step [{i_step}/{total_step}], \"\n", 323 | " f\"Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):.4f}\"\n", 324 | " )\n", 325 | "\n", 326 | " # Write training statistics to the log file.\n", 327 | " f.write(stats + \"\\n\")\n", 328 | " f.flush()\n", 329 | "\n", 330 | " # Print training statistics (on different line).\n", 331 | " if i_step % print_every == 0:\n", 332 | " print(\"\\r\" + stats)\n", 333 | "\n", 334 | " # Save the weights.\n", 335 | " if epoch % save_every == 0:\n", 336 | " torch.save(\n", 337 | " decoder.state_dict(), os.path.join(\"./models\", \"decoder-%d.pkl\" % epoch)\n", 338 | " )\n", 339 | " torch.save(\n", 340 | " encoder.state_dict(), os.path.join(\"./models\", \"encoder-%d.pkl\" % epoch)\n", 341 | " )\n", 342 | "\n", 343 | "# Close the training log file.\n", 344 | "f.close()" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "### Uncomment the code below to save the models" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 16, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "# torch.save(decoder.state_dict(), os.path.join('./models', 'decoder-final.pkl'))\n", 361 | "# torch.save(encoder.state_dict(), os.path.join('./models', 'encoder-final.pkl'))" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "metadata": {}, 367 | "source": [ 368 | "\n", 369 | "## Step 3: Validating the Model using BLEU Score\n", 370 | "\n", 371 | "One way to detect overfitting is to measure performance on a validation set. To do this, we first need to complete all of the steps in the next notebook in the sequence (**3_Inference.ipynb**); as part of that notebook, please see the `sample` method in the `DecoderRNN` class, which uses the RNN decoder to generate captions. 
\n", 372 | "\n", 373 | "To validate our model, I created a new file named **data_loader_val.py** containing the code for obtaining the data loader for the validation data. We can access:\n", 374 | "- the validation images at filepath `'/opt/cocoapi/images/val2014/'`, and\n", 375 | "- the validation image caption annotation file at filepath `'/opt/cocoapi/annotations/captions_val2014.json'`.\n", 376 | "\n", 377 | "The suggested approach to validating the model involves creating a .json file such as [this one](https://github.com/cocodataset/cocoapi/blob/master/results/captions_val2014_fakecap_results.json) containing the model's predicted captions for the validation images. Then, we can write our own script or use one that we can [find online](https://github.com/tylin/coco-caption) to calculate the BLEU score of our model. Read more about the BLEU score, along with other evaluation metrics (such as METEOR and CIDEr), in section 4.1 of [this paper](https://arxiv.org/pdf/1411.4555.pdf). For more information about how to use the annotation file, check out the [website](http://cocodataset.org/#download) for the COCO dataset."
378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 84, 383 | "metadata": {}, 384 | "outputs": [ 385 | { 386 | "name": "stdout", 387 | "output_type": "stream", 388 | "text": [ 389 | "Vocabulary successfully loaded from vocab.pkl file!\n" 390 | ] 391 | }, 392 | { 393 | "data": { 394 | "text/plain": [ 395 | "DecoderRNN(\n", 396 | " (embed): Embedding(9955, 256)\n", 397 | " (lstm): LSTM(256, 512, batch_first=True)\n", 398 | " (linear): Linear(in_features=512, out_features=9955, bias=True)\n", 399 | ")" 400 | ] 401 | }, 402 | "execution_count": 84, 403 | "metadata": {}, 404 | "output_type": "execute_result" 405 | } 406 | ], 407 | "source": [ 408 | "transform_test = transforms.Compose(\n", 409 | " [\n", 410 | " transforms.Resize(224),\n", 411 | " transforms.ToTensor(),\n", 412 | " transforms.Normalize(\n", 413 | " (0.485, 0.456, 0.406), # normalize image for pre-trained model\n", 414 | " (0.229, 0.224, 0.225),\n", 415 | " ),\n", 416 | " ]\n", 417 | ")\n", 418 | "\n", 419 | "\n", 420 | "# Create the data loader.\n", 421 | "val_data_loader = val_get_loader(\n", 422 | " transform=transform_test, mode=\"valid\", cocoapi_loc=cocoapi_dir\n", 423 | ")\n", 424 | "\n", 425 | "encoder_file = \"encoder-3.pkl\"\n", 426 | "decoder_file = \"decoder-3.pkl\"\n", 427 | "\n", 428 | "# Initialize the encoder and decoder.\n", 429 | "encoder = EncoderCNN(embed_size)\n", 430 | "decoder = DecoderRNN(embed_size, hidden_size, vocab_size)\n", 431 | "\n", 432 | "# Moving models to GPU if CUDA is available.\n", 433 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 434 | "encoder.to(device)\n", 435 | "decoder.to(device)\n", 436 | "\n", 437 | "# Loading the trained weights\n", 438 | "encoder.load_state_dict(torch.load(os.path.join(\"./models\", encoder_file)))\n", 439 | "decoder.load_state_dict(torch.load(os.path.join(\"./models\", decoder_file)))\n", 440 | "\n", 441 | "encoder.eval()\n", 442 | "decoder.eval()" 443 | ] 444 | }, 445 | { 446 | 
"cell_type": "code", 447 | "execution_count": 85, 448 | "metadata": {}, 449 | "outputs": [ 450 | { 451 | "data": { 452 | "application/vnd.jupyter.widget-view+json": { 453 | "model_id": "cda5deb725e5442d8f896cf20ff1189e", 454 | "version_major": 2, 455 | "version_minor": 0 456 | }, 457 | "text/plain": [ 458 | " 0%| | 0/40504 [00:00:image_captioning $`. The `(captioning_env)` indicates that your environment has been activated, and you can proceed with further package installations. 29 | 30 | 6. Before you can experiment with the code, you'll have to make sure that you have all the libraries and dependencies required to support this project. You will mainly need Python3.7+, PyTorch and its torchvision, OpenCV, and Matplotlib. You can install dependencies using: 31 | ``` 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | 7. Navigate back to the repo. (Also, your source environment should still be activated at this point.) 36 | ```shell 37 | cd image_captioning 38 | ``` 39 | 40 | 8. Open the directory of notebooks, using the below command. You'll see all of the project files appear in your local environment; open the first notebook and follow the instructions. 41 | ```shell 42 | jupyter notebook 43 | ``` 44 | 45 | 9. Once you open any of the project notebooks, make sure you are in the correct `captioning_env` environment by clicking `Kernel > Change Kernel > captioning_env`. 46 | 47 | 48 | ## Dataset 49 | ### About MS COCO dataset 50 | The Microsoft **C**ommon **O**bjects in **CO**ntext (MS COCO) dataset is a large-scale dataset for scene understanding. The dataset is commonly used to train and benchmark object detection, segmentation, and captioning algorithms. 51 | 52 | ![Sample Coco Example](images/coco-examples.jpg) 53 | 54 | You can read more about the dataset on the [website](http://cocodataset.org/#home), [research paper](https://arxiv.org/pdf/1405.0312.pdf), or Appendix section at the end of this page. 55 | 56 | ### Install COCO API 57 | 58 | 1. 
Clone this repo: https://github.com/cocodataset/cocoapi 59 | ``` 60 | git clone https://github.com/cocodataset/cocoapi.git 61 | ``` 62 | 63 | 2. Set up the COCO API (also described in the readme [here](https://github.com/cocodataset/cocoapi)) 64 | ``` 65 | cd cocoapi/PythonAPI 66 | make 67 | cd .. 68 | ``` 69 | 70 | 3. Download the following data from http://cocodataset.org/#download (described below) 71 | 72 | * Under **Annotations**, download: 73 | * **2014 Train/Val annotations [241MB]** (extract captions_train2014.json and captions_val2014.json, and place at locations cocoapi/annotations/captions_train2014.json and cocoapi/annotations/captions_val2014.json, respectively) 74 | * **2014 Testing Image info [1MB]** (extract image_info_test2014.json and place at location cocoapi/annotations/image_info_test2014.json) 75 | 76 | * Under **Images**, download: 77 | * **2014 Train images [83K/13GB]** (extract the train2014 folder and place at location cocoapi/images/train2014/) 78 | * **2014 Val images [41K/6GB]** (extract the val2014 folder and place at location cocoapi/images/val2014/) 79 | * **2014 Test images [41K/6GB]** (extract the test2014 folder and place at location cocoapi/images/test2014/) 80 | 81 | ## Jupyter Notebooks 82 | The project is structured as a series of Jupyter notebooks that should be run in sequential order: 83 | 84 | ### [0. Dataset Exploration notebook](0_Dataset_Exploration.ipynb) 85 | 86 | This notebook initializes the [COCO API](https://github.com/cocodataset/cocoapi) (the "pycocotools" library) used to access data from the MS COCO (Common Objects in Context) dataset, which is "commonly used to train and benchmark object detection, segmentation, and captioning algorithms." 87 | 88 | ### [1. Architecture notebook](1_Architecture.ipynb) 89 | 90 | This notebook uses pycocotools, torchvision transforms, and NLTK to preprocess the images and the captions for network training. 
It also explores details of the EncoderCNN, which uses the pretrained [ResNet-50 architecture from torchvision.models](https://pytorch.org/docs/master/torchvision/models.html#id3). The EncoderCNN and DecoderRNN are implemented in the [model.py](model.py) file. 91 | 92 | The core architecture follows an encoder-decoder design, where the encoder is a ResNet CNN pretrained on ImageNet and the decoder is a basic one-layer LSTM. 93 | 94 | #### Architecture Details 95 | ![encoder-decoder-architecture](images/encoder-decoder.png) 96 | 97 | The left half of the diagram depicts the "EncoderCNN", which encodes the critical information contained in a regular picture file into a "feature vector" of a specific size. That feature vector is fed into the "DecoderRNN" on the right half of the diagram, which is "unfolded" in time: each box labeled "LSTM" represents the same cell at a different time step. Each word appearing as output at the top is fed back to the network as input (at the bottom) in a subsequent time step until the entire caption is generated. The arrow pointing right that connects the LSTM boxes represents the hidden state, the network's "memory", which is also fed back to the LSTM at each time step. 98 | 99 | The architecture consists of a CNN encoder and an RNN decoder. The CNN encoder is a ResNet pre-trained on ImageNet; ResNet is a deep convolutional neural network with skip (residual) connections. It works well on tasks like image recognition because the identity shortcuts let each block learn only the residual with respect to its input, which makes very deep networks much easier to train. A network pre-trained on ImageNet already extracts useful low-level and high-level features for image tasks, so it naturally serves as the feature encoder for the images we want to caption. 
Since we are not doing the traditional image classification task, we drop the last fully connected layer and replace it with a new trainable fully connected layer that transforms the final feature map into an encoding that is more useful for the RNN decoder. 100 | 101 | RNNs have long been useful in language tasks because they can model sequential data such as language, and LSTMs in particular keep both long-term and short-term information as memory in the network. Thus, we pick an RNN decoder for the captioning task. Following the spirit of sequence-to-sequence (seq2seq) models used in translation, I adopted the architecture choices in [this paper](https://arxiv.org/pdf/1411.4555.pdf) and use an LSTM to generate captions from the information encoded by the CNN encoder. Specifically, **I first feed the CNN encoder output, followed by the "START" token, as the initial inputs to the RNN decoder.** At each time step, a fully connected layer maps the LSTM hidden state to a softmax probability distribution over the words in the entire vocabulary, and the word with the highest probability is chosen as the word generated at that time step. We then feed this predicted word back as the input for the next step, and continue until the caption reaches the maximum length or the network generates the "STOP" token, which marks the end of the sentence. 102 | 103 | #### LSTM Decoder 104 | In the project, we pass all our inputs as a sequence to an LSTM. A sequence looks like this: first a feature vector that is extracted from an input image, then a start word, then the next word, the next word, and so on. 105 | 106 | #### Embedding Dimension 107 | The LSTM expects every input in a sequence to have the same size, so we embed the feature vector and each word into vectors of length `embed_size`. 
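To make the input/target alignment described above concrete, here is a minimal plain-Python sketch. It is not the code in `model.py`: it assumes the caption is already wrapped as `[start_id, w1, ..., wn, end_id]` (as `CoCoDataset.__getitem__` builds it), uses a placeholder string for the image feature vector, and the names `build_decoder_io`, `greedy_decode`, and the `step` callable are hypothetical stand-ins for the embedding, one LSTM step, and the final linear layer.

```python
from typing import Any, Callable, List, Sequence, Tuple

FEATURE = "[image feature]"  # hypothetical placeholder for the EncoderCNN output vector


def build_decoder_io(caption_ids: Sequence[int]) -> Tuple[List[Any], List[int]]:
    """Teacher-forcing layout for one caption [start, w1, ..., wn, end].

    Input:  the image feature, then every token except the final end token.
    Target: the full caption, so each input slot predicts the next token.
    """
    inputs: List[Any] = [FEATURE] + list(caption_ids[:-1])
    targets = list(caption_ids)
    return inputs, targets


def greedy_decode(
    step: Callable[[Any, Any], Tuple[Sequence[float], Any]],
    feature: Any,
    end_id: int,
    max_len: int = 20,
) -> List[int]:
    """Greedy sampling: pick the highest-scoring word and feed it back in.

    `step(x, state)` is a hypothetical stand-in for one LSTM step followed by
    the linear layer; it returns (scores over the vocabulary, new state).
    """
    words: List[int] = []
    x, state = feature, None  # the image feature is the first input
    for _ in range(max_len):
        scores, state = step(x, state)
        word = max(range(len(scores)), key=lambda i: scores[i])  # argmax
        words.append(word)
        if word == end_id:  # stop once the end token is produced
            break
        x = word  # feed the prediction back as the next input
    return words
```

For a four-token caption, `build_decoder_io([0, 5, 7, 1])` returns inputs `['[image feature]', 0, 5, 7]` and targets `[0, 5, 7, 1]`. Both have the same number of positions, which is one layout consistent with the `[bs, cap_len, vocab_size]` and `[bs, cap_len]` shapes noted in the training loop's debug comments.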
108 | 109 | #### Using my trained model 110 | You can [download](https://drive.google.com/file/d/1s3aRAdt8ZMqn53UUSSFLCDBJ_F-KNbpE/view?usp=sharing) my trained models and unzip the `captioning_models.zip` file into the project's `models` directory for your own experimentation. 111 | 112 | Feel free to experiment with alternative architectures, such as a bidirectional LSTM with an attention mechanism! 113 | 114 | ### [2. Training notebook](2_Training.ipynb) 115 | 116 | This notebook covers the selection of hyperparameter values and the training of the EncoderCNN and DecoderRNN, and explains how the hyperparameters were chosen. 117 | 118 | 119 | #### Parameters 120 | 121 | - `batch_size` - the batch size of each training batch. It is the number of image-caption pairs used to update the model weights in each training step. 122 | - `vocab_threshold` - the minimum word count threshold. Note that a larger threshold will result in a smaller vocabulary, whereas a smaller threshold will include rarer words and result in a larger vocabulary. 123 | - `vocab_from_file` - a Boolean that decides whether to load the vocabulary from file. 124 | - `embed_size` - the dimensionality of the image and word embeddings. 125 | - `hidden_size` - the number of features in the hidden state of the RNN decoder. 126 | - `num_epochs` - the number of epochs to train the model. We set `num_epochs=3`, but feel free to increase or decrease this number. [This paper](https://arxiv.org/pdf/1502.03044.pdf) trained a captioning model on a single state-of-the-art GPU for 3 days, but we'll soon see that we can get reasonable results in a matter of a few hours! (_But of course, if we want to compete with current research, we will have to train for much longer._) 127 | - `save_every` - determines how often to save the model weights. We set `save_every=1` to save the model weights after each epoch. 
This way, after the `i`th epoch, the encoder and decoder weights will be saved in the `models/` folder as `encoder-i.pkl` and `decoder-i.pkl`, respectively. 128 | - `print_every` - determines how often to print the batch loss to the Jupyter notebook while training. Note that we probably **will not** observe a monotonic decrease in the loss function while training; this is perfectly fine and completely expected! We keep this at its default value of `20` to avoid clogging the notebook. 129 | - `log_file` - the name of the text file that records, for every step, how the loss and perplexity evolved during training. 130 | 131 | 132 | #### Image Transformations 133 | 134 | The original [ResNet](https://arxiv.org/pdf/1512.03385.pdf) paper, whose architecture our CNN encoder uses, trains on images whose shorter edge is randomly rescaled to between 256 and 480 pixels, takes random 224x224 crops, and randomly flips images horizontally; the network itself applies batch normalization after each convolution. To preserve the performance of the pre-trained ResNet, it makes sense to keep the image preprocessing close to what the original model used. Thus, I use the default `transform_train` as follows: 135 | 136 | ``` 137 | transform_train = transforms.Compose([ 138 | transforms.Resize(256), # smaller edge of image resized to 256 139 | transforms.RandomCrop(224), # get 224x224 crop from random location 140 | transforms.RandomHorizontalFlip(), # horizontally flip image with probability=0.5 141 | transforms.ToTensor(), # convert the PIL Image to a tensor 142 | transforms.Normalize((0.485, 0.456, 0.406), # normalize image for pre-trained model 143 | (0.229, 0.224, 0.225))]) 144 | ``` 145 | If you modify this transform, keep in mind that: 146 | - the images in the dataset have varying heights and widths, and 147 | - a pre-trained model expects the same normalization that was applied when it was trained. 
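To see what the two caveats above mean numerically, here is a small plain-Python sketch. `resized_size` imitates how `transforms.Resize(256)` with an integer argument scales the shorter edge while keeping the aspect ratio. Rounding the longer edge down to an integer is an assumption, since torchvision's exact rounding may differ between versions. `normalize_pixel` imitates `ToTensor()` followed by `Normalize` with the ImageNet statistics used in `transform_train`. Both helper names are hypothetical and not part of this repo.

```python
from typing import Sequence, Tuple

# ImageNet channel statistics, as passed to transforms.Normalize in transform_train
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def resized_size(width: int, height: int, short_edge: int = 256) -> Tuple[int, int]:
    """(width, height) after scaling the shorter edge to `short_edge`.

    The aspect ratio is preserved; the longer edge is truncated to an int
    (an assumption about the rounding used by torchvision).
    """
    if width <= height:
        return short_edge, int(short_edge * height / width)
    return int(short_edge * width / height), short_edge


def normalize_pixel(
    rgb: Sequence[int],
    mean: Sequence[float] = IMAGENET_MEAN,
    std: Sequence[float] = IMAGENET_STD,
) -> Tuple[float, ...]:
    """ToTensor() maps 0..255 to 0..1; Normalize then applies (x - mean) / std."""
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, mean, std))


if __name__ == "__main__":
    # A 640x480 landscape photo keeps its shape; the short edge becomes 256.
    print(resized_size(640, 480))  # (341, 256), large enough for RandomCrop(224)
    # An orange pixel after ToTensor + Normalize
    print([round(v, 3) for v in normalize_pixel((255, 128, 0))])  # [2.249, 0.205, -1.804]
```

Because both sides of the resized image are at least 256 pixels, `RandomCrop(224)` always fits, whatever the original height and width. Skipping or changing the mean/std would feed the pre-trained ResNet values on a different scale from the one it was trained on.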
148 | 149 | 150 | #### Hyperparameters 151 | 152 | To obtain a strong initial guess for which hyperparameters are likely to work best, I initially consulted [this paper](https://arxiv.org/pdf/1502.03044.pdf) and [this paper](https://arxiv.org/pdf/1411.4555.pdf). I used a minimum word count threshold of **5**, an embedding size of **512**, and a hidden size of **512** as well, and trained the network for 3 epochs. At first, the loss decreased as expected, but after 20 hours of training, running inference on test images showed that the network appeared to have overfitted to the training data: the generated captions were not related to the test images at all. I repeated the inference with the model saved after every epoch, and it still performed unsatisfactorily. Thus, I decreased the embedding size to **256** and trained again, this time for only 1 epoch. The network performs much better this time! If you are unhappy with the performance, you can return to this notebook to tweak the hyperparameters (and/or the architecture in **model.py**) and re-train the model. 153 | 154 | 155 | #### Trainable Parameters 156 | 157 | We can specify a Python list containing the learnable parameters of the model. For instance, if we decide to make all weights in the decoder trainable, but only want to train the weights in the embedding layer of the encoder, then we should set `params` to something like: 158 | 159 | ``` 160 | params = list(decoder.parameters()) + list(encoder.embed.parameters()) 161 | ``` 162 | 163 | I decided to freeze all but the last layer of ResNet, because the network is already pre-trained on ImageNet and performs well. We could still fine-tune the entire ResNet for better performance, but ResNet is a large, deep architecture with many parameters, and freezing them makes training faster, which matters because the RNN decoder is already slow to train. Empirical results suggest that the pre-trained ResNet indeed does a good job. 
Since the last layer of the CNN encoder transforms the CNN feature map into the representation the RNN needs, it makes sense to train this new fully connected layer from scratch. 164 | 165 | The RNN decoder is completely new and not part of the pre-trained ResNet, so we also train all of its parameters. 166 | 167 | #### Optimizer 168 | 169 | We need to select an [optimizer](http://pytorch.org/docs/master/optim.html#torch.optim.Optimizer). I chose the Adam optimizer to minimize the [CrossEntropyLoss](https://medium.com/swlh/cross-entropy-loss-in-pytorch-c010faf97bab) because it is one of the most popular and effective optimizers. It combines momentum (a running average of past gradients) with per-parameter adaptive learning rates in the style of RMSProp. 170 | 171 | ### [3. Inference notebook](3_Inference.ipynb) 172 | 173 | This notebook uses the trained networks to generate captions for additional images. No rigorous validation or accuracy measurement was performed; captions were only generated for sample images. 174 | 175 | 176 | ## Results 177 | Here are some predictions from the model. 178 | 179 | ### Some good results 180 | ![sample_171](images/sample_171.png?raw=true)
181 | ![sample_440](images/sample_440.png?raw=true)
182 | ![sample_457](images/sample_457.png?raw=true)
183 | ![sample_002](images/sample_002.png?raw=true)
184 | ![sample_029](images/sample_029.png?raw=true)
185 | ![sample_107](images/sample_107.png?raw=true)
186 | ![sample_202](images/sample_202.png?raw=true) 187 | 188 | 189 | ### Some not so good results 190 | 191 | ![sample_296](images/sample_296.png?raw=true)
192 | ![sample_008](images/sample_008.png?raw=true)
193 | ![sample_193](images/sample_193.png?raw=true)
194 | ![sample_034](images/sample_034.png?raw=true)
195 | ![sample_326](images/sample_326.png?raw=true)
196 | ![sample_366](images/sample_366.png?raw=true)
197 | ![sample_498](images/sample_498.png?raw=true) 198 | 199 | ## Deploy and share an image captioning service using Gradio 200 | 201 | [Gradio](https://www.gradio.app/) is a package that allows users to create simple web apps with just a few lines of code. It serves much the same purpose as Streamlit and Flask but is much simpler to use. Many types of web interface components are available, including sketchpads, text boxes, file upload buttons, and webcam input. By using these components to receive various types of data as input, machine learning tasks such as classification and regression can easily be demoed. 202 | 203 | You can deploy an interactive version of the image captioning service in your browser by running the following command. Please don't forget to set the `cocoapi_dir` and encoder/decoder model paths to the correct values. 204 | 205 | ```shell 206 | python gradio_main.py 207 | ``` 208 | 209 | Access the service at the local URL: http://127.0.0.1:7860/ 210 | 211 | ![gradio_demo](images/gradio_demo.png) 212 | 213 | 214 | 215 | ## Future work 216 | Further improvements could come from exploring other hyperparameters and architectures, and from training for more epochs. 217 | 218 | ## Appendix: More about COCO dataset API 219 | COCO is a large image dataset designed for object detection, segmentation, person keypoints detection, stuff segmentation, and caption generation. This package provides Matlab, Python, and Lua APIs that assist in loading, parsing, and visualizing the annotations in COCO. Please visit http://cocodataset.org/ for more information on COCO, including the data, paper, and tutorials. The exact format of the annotations is also described on the COCO website. The Matlab and Python APIs are complete; the Lua API provides only basic functionality. 220 | 221 | In addition to this API, please download both the COCO images and annotations in order to run the demos and use the API. 
Both are available on the project website. 222 | - Please download, unzip, and place the images in: coco/images/ 223 | - Please download and place the annotations in: coco/annotations/ 224 | 225 | For substantially more details on the API please see [COCO Home Page](http://cocodataset.org/#home). 226 | 227 | After downloading the images and annotations, run the Matlab, Python, or Lua demos for example usage. 228 | 229 | To install: 230 | - For Matlab, add coco/MatlabApi to the Matlab path (OSX/Linux binaries provided) 231 | - For Python, run "make" under coco/PythonAPI 232 | - For Lua, run “luarocks make LuaAPI/rocks/coco-scm-1.rockspec” under coco/ 233 | 234 | 235 | Note: This project is a part of [Udacity Computer Vision Nanodegree Program](https://www.udacity.com/course/computer-vision-nanodegree--nd891). -------------------------------------------------------------------------------- /coco_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import nltk 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from pycocotools.coco import COCO 9 | from torch.utils import data as data 10 | from tqdm import tqdm 11 | 12 | from vocabulary import Vocabulary 13 | 14 | 15 | class CoCoDataset(data.Dataset): 16 | def __init__( 17 | self, 18 | transform, 19 | mode, 20 | batch_size, 21 | vocab_threshold, 22 | vocab_file, 23 | start_word, 24 | end_word, 25 | unk_word, 26 | annotations_file, 27 | vocab_from_file, 28 | img_folder, 29 | ): 30 | self.transform = transform 31 | self.mode = mode 32 | self.batch_size = batch_size 33 | self.img_folder = img_folder 34 | # create vocabulary from the captions 35 | self.vocab = Vocabulary( 36 | vocab_threshold, 37 | vocab_file, 38 | start_word, 39 | end_word, 40 | unk_word, 41 | annotations_file, 42 | vocab_from_file, 43 | ) 44 | if self.mode == "train": 45 | self.coco = COCO(annotations_file) 46 | self.ids = list(self.coco.anns.keys()) 47 | 
print("Obtaining caption lengths...") 48 | 49 | # get list of tokens for each caption 50 | tokenized_captions = [ 51 | nltk.tokenize.word_tokenize( 52 | str(self.coco.anns[self.ids[index]]["caption"]).lower() 53 | ) 54 | for index in tqdm(np.arange(len(self.ids))) 55 | ] 56 | 57 | # get len of each caption 58 | self.caption_lengths = [len(token) for token in tokenized_captions] 59 | else: 60 | test_info = json.loads(open(annotations_file).read()) 61 | self.paths = [item["file_name"] for item in test_info["images"]] 62 | 63 | def __getitem__(self, index): 64 | # obtain image and caption if in training mode 65 | if self.mode == "train": 66 | ann_id = self.ids[index] 67 | caption = self.coco.anns[ann_id]["caption"] 68 | img_id = self.coco.anns[ann_id]["image_id"] 69 | path = self.coco.loadImgs(img_id)[0]["file_name"] 70 | 71 | # Convert image to tensor and pre-process using transform 72 | image = Image.open(os.path.join(self.img_folder, path)).convert("RGB") 73 | image = self.transform(image) 74 | 75 | # Convert caption to tensor of word ids. 
76 | tokens = nltk.tokenize.word_tokenize(str(caption).lower()) 77 | caption = [self.vocab(self.vocab.start_word)] 78 | caption.extend([self.vocab(token) for token in tokens]) 79 | caption.append(self.vocab(self.vocab.end_word)) 80 | caption = torch.Tensor(caption).long() 81 | 82 | # return pre-processed image and caption tensors 83 | return image, caption 84 | 85 | # obtain image if in test mode 86 | else: 87 | path = self.paths[index] 88 | 89 | # Convert image to tensor and pre-process using transform 90 | pil_image = Image.open(os.path.join(self.img_folder, path)).convert("RGB") 91 | orig_image = np.array(pil_image) 92 | image = self.transform(pil_image) 93 | 94 | # return original image and pre-processed image tensor 95 | return orig_image, image 96 | 97 | def get_train_indices(self): 98 | # select random len 99 | sel_length = np.random.choice(self.caption_lengths) 100 | # find indices of captions having specific length 101 | all_indices = np.where( 102 | [ 103 | self.caption_lengths[i] == sel_length 104 | for i in np.arange(len(self.caption_lengths)) 105 | ] 106 | )[0] 107 | # select only limited (batch size) number of them 108 | indices = list(np.random.choice(all_indices, size=self.batch_size)) 109 | return indices 110 | 111 | def __len__(self): 112 | if self.mode == "train": 113 | return len(self.ids) 114 | else: 115 | return len(self.paths) 116 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import nltk 4 | import torch.utils.data as data 5 | 6 | from coco_dataset import CoCoDataset 7 | 8 | nltk.download("punkt") 9 | 10 | 11 | def get_loader( 12 | transform, 13 | mode="train", 14 | batch_size=1, 15 | vocab_threshold=None, 16 | vocab_file="./vocab.pkl", 17 | start_word="", 18 | end_word="", 19 | unk_word="", 20 | vocab_from_file=True, 21 | num_workers=0, 22 | cocoapi_loc="/opt", 23 | ): 24 | """Returns 
the data loader. 25 | Args: 26 | transform: Image transform. 27 | mode: One of 'train' or 'test'. 28 | batch_size: Batch size (if in testing mode, must have batch_size=1). 29 | vocab_threshold: Minimum word count threshold. 30 | vocab_file: File containing the vocabulary. 31 | start_word: Special word denoting sentence start. 32 | end_word: Special word denoting sentence end. 33 | unk_word: Special word denoting unknown words. 34 | vocab_from_file: If False, create vocab from scratch and override any existing vocab_file. 35 | If True, load vocab from existing vocab_file, if it exists. 36 | num_workers: Number of subprocesses to use for data loading. 37 | cocoapi_loc: The location of the folder containing the COCO API: https://github.com/cocodataset/cocoapi 38 | """ 39 | 40 | assert mode in ["train", "test"], "mode must be one of 'train' or 'test'." 41 | 42 | if not vocab_from_file: 43 | assert ( 44 | mode == "train" 45 | ), "To generate vocab from captions file, must be in training mode (mode='train')." 46 | 47 | # Based on mode (train, val, test), obtain img_folder and annotations_file. 48 | if mode == "train": 49 | if vocab_from_file: 50 | assert os.path.exists( 51 | vocab_file 52 | ), "vocab_file does not exist. Change vocab_from_file to False to create vocab_file." 53 | img_folder = os.path.join(cocoapi_loc, "cocoapi/images/train2014/") 54 | annotations_file = os.path.join( 55 | cocoapi_loc, "cocoapi/annotations/captions_train2014.json" 56 | ) 57 | 58 | elif mode == "test": 59 | assert batch_size == 1, "Please change batch_size to 1 if testing the model." 60 | assert os.path.exists( 61 | vocab_file 62 | ), "Must first generate vocab.pkl from training data." 63 | assert vocab_from_file, "Change vocab_from_file to True." 
64 | img_folder = os.path.join(cocoapi_loc, "cocoapi/images/test2014/") 65 | annotations_file = os.path.join( 66 | cocoapi_loc, "cocoapi/annotations/image_info_test2014.json" 67 | ) 68 | else: 69 | raise ValueError(f"Invalid mode: {mode}") 70 | 71 | # COCO caption dataset. 72 | dataset = CoCoDataset( 73 | transform=transform, 74 | mode=mode, 75 | batch_size=batch_size, 76 | vocab_threshold=vocab_threshold, 77 | vocab_file=vocab_file, 78 | start_word=start_word, 79 | end_word=end_word, 80 | unk_word=unk_word, 81 | annotations_file=annotations_file, 82 | vocab_from_file=vocab_from_file, 83 | img_folder=img_folder, 84 | ) 85 | 86 | if mode == "train": 87 | # Randomly sample a caption length, and sample indices with that length. 88 | indices = dataset.get_train_indices() 89 | # Create and assign a batch sampler to retrieve a batch with the sampled indices. 90 | initial_sampler = data.sampler.SubsetRandomSampler(indices=indices) 91 | # data loader for COCO dataset. 92 | data_loader = data.DataLoader( 93 | dataset=dataset, 94 | num_workers=num_workers, 95 | batch_sampler=data.sampler.BatchSampler( 96 | sampler=initial_sampler, batch_size=dataset.batch_size, drop_last=False 97 | ), 98 | ) 99 | else: 100 | data_loader = data.DataLoader( 101 | dataset=dataset, 102 | batch_size=dataset.batch_size, 103 | shuffle=True, 104 | num_workers=num_workers, 105 | ) 106 | 107 | return data_loader 108 | -------------------------------------------------------------------------------- /data_loader_val.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import nltk 5 | import numpy as np 6 | import torch 7 | import torch.utils.data as data 8 | from PIL import Image 9 | from pycocotools.coco import COCO 10 | from tqdm import tqdm 11 | 12 | from vocabulary import Vocabulary 13 | 14 | 15 | def get_loader( 16 | transform, 17 | mode="valid", 18 | batch_size=1, 19 | vocab_threshold=None, 20 | vocab_file="./vocab.pkl", 21 | 
start_word="", 22 | end_word="", 23 | unk_word="", 24 | vocab_from_file=True, 25 | num_workers=0, 26 | cocoapi_loc="/opt", 27 | ): 28 | """Returns the data loader. 29 | Args: 30 | transform: Image transform. 31 | mode: One of 'train', 'valid', or 'test'. 32 | batch_size: Batch size (in validation or testing mode, must have batch_size=1). 33 | vocab_threshold: Minimum word count threshold. 34 | vocab_file: File containing the vocabulary. 35 | start_word: Special word denoting sentence start. 36 | end_word: Special word denoting sentence end. 37 | unk_word: Special word denoting unknown words. 38 | vocab_from_file: If False, create vocab from scratch and override any existing vocab_file. 39 | If True, load vocab from existing vocab_file, if it exists. 40 | num_workers: Number of subprocesses to use for data loading. 41 | cocoapi_loc: The location of the folder containing the COCO API: https://github.com/cocodataset/cocoapi 42 | """ 43 | 44 | assert mode in ["train", "valid", "test"], "mode must be one of 'train', 'valid', or 'test'." 45 | if not vocab_from_file: 46 | assert ( 47 | mode == "train" 48 | ), "To generate vocab from captions file, must be in training mode (mode='train')." 49 | 50 | # Based on mode (train, val, test), obtain img_folder and annotations_file. 51 | if mode == "train": 52 | if vocab_from_file: 53 | assert os.path.exists( 54 | vocab_file 55 | ), "vocab_file does not exist. Change vocab_from_file to False to create vocab_file." 56 | img_folder = os.path.join(cocoapi_loc, "cocoapi/images/train2014/") 57 | annotations_file = os.path.join( 58 | cocoapi_loc, "cocoapi/annotations/captions_train2014.json" 59 | ) 60 | elif mode == "test": 61 | assert batch_size == 1, "Please change batch_size to 1 if testing the model." 62 | assert os.path.exists( 63 | vocab_file 64 | ), "Must first generate vocab.pkl from training data." 65 | assert vocab_from_file, "Change vocab_from_file to True." 
66 | img_folder = os.path.join(cocoapi_loc, "cocoapi/images/test2014/") 67 | annotations_file = os.path.join( 68 | cocoapi_loc, "cocoapi/annotations/image_info_test2014.json" 69 | ) 70 | elif mode == "valid": 71 | assert batch_size == 1, "Please change batch_size to 1 if validating the model." 72 | assert os.path.exists( 73 | vocab_file 74 | ), "Must first generate vocab.pkl from training data." 75 | assert vocab_from_file, "Change vocab_from_file to True." 76 | img_folder = os.path.join(cocoapi_loc, "cocoapi/images/val2014/") 77 | annotations_file = os.path.join( 78 | cocoapi_loc, "cocoapi/annotations/captions_val2014.json" 79 | ) 80 | else: 81 | raise ValueError(f"Invalid mode: {mode}") 82 | # COCO caption dataset. 83 | dataset = CoCoDataset( 84 | transform=transform, 85 | mode=mode, 86 | batch_size=batch_size, 87 | vocab_threshold=vocab_threshold, 88 | vocab_file=vocab_file, 89 | start_word=start_word, 90 | end_word=end_word, 91 | unk_word=unk_word, 92 | annotations_file=annotations_file, 93 | vocab_from_file=vocab_from_file, 94 | img_folder=img_folder, 95 | ) 96 | 97 | if mode == "train": 98 | # Randomly sample a caption length, and sample indices with that length. 99 | indices = dataset.get_train_indices() 100 | # Create and assign a batch sampler to retrieve a batch with the sampled indices. 101 | initial_sampler = data.sampler.SubsetRandomSampler(indices=indices) 102 | # data loader for COCO dataset.
103 | data_loader = data.DataLoader( 104 | dataset=dataset, 105 | num_workers=num_workers, 106 | batch_sampler=data.sampler.BatchSampler( 107 | sampler=initial_sampler, batch_size=dataset.batch_size, drop_last=False 108 | ), 109 | ) 110 | else: 111 | data_loader = data.DataLoader( 112 | dataset=dataset, 113 | batch_size=dataset.batch_size, 114 | shuffle=True, 115 | num_workers=num_workers, 116 | ) 117 | 118 | return data_loader 119 | 120 | 121 | class CoCoDataset(data.Dataset): 122 | def __init__( 123 | self, 124 | transform, 125 | mode, 126 | batch_size, 127 | vocab_threshold, 128 | vocab_file, 129 | start_word, 130 | end_word, 131 | unk_word, 132 | annotations_file, 133 | vocab_from_file, 134 | img_folder, 135 | ): 136 | self.transform = transform 137 | self.mode = mode 138 | self.batch_size = batch_size 139 | self.vocab = Vocabulary( 140 | vocab_threshold, 141 | vocab_file, 142 | start_word, 143 | end_word, 144 | unk_word, 145 | annotations_file, 146 | vocab_from_file, 147 | ) 148 | self.img_folder = img_folder 149 | if self.mode == "train": 150 | self.coco = COCO(annotations_file) 151 | self.ids = list(self.coco.anns.keys()) 152 | print("Obtaining caption lengths...") 153 | all_tokens = [ 154 | nltk.tokenize.word_tokenize( 155 | str(self.coco.anns[self.ids[index]]["caption"]).lower() 156 | ) 157 | for index in tqdm(np.arange(len(self.ids))) 158 | ] 159 | self.caption_lengths = [len(token) for token in all_tokens] 160 | else: 161 | test_info = json.loads(open(annotations_file).read()) 162 | self.paths = [item["file_name"] for item in test_info["images"]] 163 | 164 | def __getitem__(self, index): 165 | # obtain image and caption if in training mode 166 | if self.mode == "train": 167 | ann_id = self.ids[index] 168 | caption = self.coco.anns[ann_id]["caption"] 169 | img_id = self.coco.anns[ann_id]["image_id"] 170 | path = self.coco.loadImgs(img_id)[0]["file_name"] 171 | 172 | # Convert image to tensor and pre-process using transform 173 | image = 
Image.open(os.path.join(self.img_folder, path)).convert("RGB") 174 | image = self.transform(image) 175 | 176 | # Convert caption to tensor of word ids. 177 | tokens = nltk.tokenize.word_tokenize(str(caption).lower()) 178 | caption = [] 179 | caption.append(self.vocab(self.vocab.start_word)) 180 | caption.extend([self.vocab(token) for token in tokens]) 181 | caption.append(self.vocab(self.vocab.end_word)) 182 | caption = torch.tensor(caption, dtype=torch.long) 183 | 184 | # return pre-processed image and caption tensors 185 | return image, caption 186 | 187 | elif self.mode == "valid": 188 | path = self.paths[index] 189 | image_id = int(path.split("/")[0].split(".")[0].split("_")[-1]) 190 | pil_image = Image.open(os.path.join(self.img_folder, path)).convert("RGB") 191 | image = self.transform(pil_image) 192 | 193 | # return the COCO image id (parsed from the file name) and pre-processed image tensor 194 | return image_id, image 195 | # obtain image if in test mode 196 | else: 197 | path = self.paths[index] 198 | 199 | # Convert image to tensor and pre-process using transform 200 | pil_image = Image.open(os.path.join(self.img_folder, path)).convert("RGB") 201 | orig_image = np.array(pil_image) 202 | image = self.transform(pil_image) 203 | 204 | # return original image and pre-processed image tensor 205 | return orig_image, image 206 | 207 | def get_train_indices(self): 208 | sel_length = np.random.choice(self.caption_lengths) 209 | all_indices = np.where( 210 | [ 211 | self.caption_lengths[i] == sel_length 212 | for i in np.arange(len(self.caption_lengths)) 213 | ] 214 | )[0] 215 | indices = list(np.random.choice(all_indices, size=self.batch_size)) 216 | return indices 217 | 218 | def __len__(self): 219 | if self.mode == "train": 220 | return len(self.ids) 221 | else: 222 | return len(self.paths) 223 | -------------------------------------------------------------------------------- /gradio_main.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | import
gradio as gr 4 | import torch 5 | from gradio import SimpleCSVLogger 6 | from torchvision import transforms 7 | 8 | from data_loader import get_loader 9 | from model import DecoderRNN, EncoderCNN 10 | from nlp_utils import clean_sentence 11 | 12 | cocoapi_dir = r"path/to/cocoapi/dir" 13 | 14 | # Defining a deterministic transform to pre-process the test images. 15 | transform_test = transforms.Compose( 16 | [ 17 | transforms.Resize(256), # smaller edge of image resized to 256 18 | transforms.CenterCrop(224), # get 224x224 crop from the image center 19 | # no random horizontal flip at inference, so predictions are repeatable 20 | transforms.ToTensor(), # convert the PIL Image to a tensor 21 | transforms.Normalize( 22 | (0.485, 0.456, 0.406), # ImageNet mean/std expected by the pre-trained ResNet 23 | (0.229, 0.224, 0.225), 24 | ), 25 | ] 26 | ) 27 | 28 | # Creating the data loader (used here only to access the vocabulary). 29 | data_loader = get_loader(transform=transform_test, mode="test", cocoapi_loc=cocoapi_dir) 30 | 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | 33 | # Specify the saved models to load. 34 | encoder_file = "encoder-3.pkl" 35 | decoder_file = "decoder-3.pkl" 36 | 37 | # These must match the values used when training the saved models. 38 | embed_size = 256 39 | hidden_size = 512 40 | 41 | # The size of the vocabulary. 42 | vocab_size = len(data_loader.dataset.vocab) 43 | 44 | # Initialize the encoder and decoder, and set each to inference mode. 45 | encoder = EncoderCNN(embed_size) 46 | decoder = DecoderRNN(embed_size, hidden_size, vocab_size) 47 | encoder.eval() 48 | decoder.eval() 49 | 50 | # Load the trained weights (map_location lets GPU-saved weights load on a CPU-only machine). 51 | encoder.load_state_dict(torch.load(os.path.join("./models", encoder_file), map_location=device)) 52 | decoder.load_state_dict(torch.load(os.path.join("./models", decoder_file), map_location=device)) 53 | 54 | # Move models to GPU if CUDA is available.
55 | encoder.to(device) 56 | decoder.to(device) 57 | 58 | 59 | def predict_caption(image): 60 | if image is None: 61 | return "Please select an image" 62 | 63 | image = transform_test(image).unsqueeze(0) # add batch dimension: (1, 3, 224, 224) 64 | with torch.no_grad(): 65 | # Moving image PyTorch tensor to GPU if CUDA is available. 66 | image = image.to(device) 67 | 68 | # Obtaining the embedded image features. 69 | features = encoder(image).unsqueeze(1) # (1, embed_size) -> (1, 1, embed_size) 70 | 71 | # Passing the embedded image features through the model to get a predicted caption. 72 | output = decoder.sample(features) 73 | 74 | sentence = clean_sentence(output, data_loader.dataset.vocab.idx2word) 75 | 76 | return sentence 77 | 78 | 79 | gr.Interface( 80 | fn=predict_caption, 81 | inputs=gr.Image(type="pil", image_mode="RGB"), 82 | outputs=gr.Textbox(label="Predicted caption"), 83 | flagging_dir="./gradio_logs", 84 | flagging_callback=SimpleCSVLogger(), 85 | ).launch(share=True, server_port=7860) 86 | -------------------------------------------------------------------------------- /images/cnn_rnn_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/cnn_rnn_model.png -------------------------------------------------------------------------------- /images/coco-examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/coco-examples.jpg -------------------------------------------------------------------------------- /images/decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/decoder.png 
-------------------------------------------------------------------------------- /images/encoder-decoder.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/encoder-decoder.png -------------------------------------------------------------------------------- /images/encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/encoder.png -------------------------------------------------------------------------------- /images/gradio_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/gradio_demo.png -------------------------------------------------------------------------------- /images/readme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/readme.png -------------------------------------------------------------------------------- /images/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/result.png -------------------------------------------------------------------------------- /images/sample_002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_002.png -------------------------------------------------------------------------------- /images/sample_008.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_008.png -------------------------------------------------------------------------------- /images/sample_029.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_029.png -------------------------------------------------------------------------------- /images/sample_034.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_034.png -------------------------------------------------------------------------------- /images/sample_107.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_107.png -------------------------------------------------------------------------------- /images/sample_171.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_171.png -------------------------------------------------------------------------------- /images/sample_193.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_193.png -------------------------------------------------------------------------------- /images/sample_202.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_202.png 
-------------------------------------------------------------------------------- /images/sample_296.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_296.png -------------------------------------------------------------------------------- /images/sample_326.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_326.png -------------------------------------------------------------------------------- /images/sample_366.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_366.png -------------------------------------------------------------------------------- /images/sample_440.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_440.png -------------------------------------------------------------------------------- /images/sample_457.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_457.png -------------------------------------------------------------------------------- /images/sample_498.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/images/sample_498.png -------------------------------------------------------------------------------- /model.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | # ----------- Encoder ------------ 7 | class EncoderCNN(nn.Module): 8 | def __init__(self, embed_size): 9 | super(EncoderCNN, self).__init__() 10 | resnet = models.resnet50(pretrained=True) 11 | # freeze the pre-trained ResNet-50 weights; only self.embed is trained 12 | for param in resnet.parameters(): 13 | param.requires_grad_(False) 14 | 15 | modules = list(resnet.children())[:-1] # drop the final fully connected (classification) layer 16 | self.resnet = nn.Sequential(*modules) 17 | self.embed = nn.Linear(resnet.fc.in_features, embed_size) 18 | 19 | def forward(self, images): 20 | features = self.resnet(images) # (bs, 2048, 1, 1) after average pooling 21 | features = features.view(features.size(0), -1) # (bs, 2048) 22 | features = self.embed(features) # (bs, embed_size) 23 | return features 24 | 25 | 26 | # --------- Decoder ---------- 27 | class DecoderRNN(nn.Module): 28 | def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1): 29 | """ 30 | Args: 31 | embed_size: size of the image features and of the word embeddings 32 | hidden_size: hidden size of the LSTM 33 | vocab_size: size of the vocabulary 34 | num_layers: number of layers of the LSTM 35 | """ 36 | super(DecoderRNN, self).__init__() 37 | 38 | # Assigning hidden dimension 39 | self.hidden_dim = hidden_size 40 | # Map each word index to a dense word embedding tensor of embed_size 41 | self.embed = nn.Embedding(vocab_size, embed_size) 42 | # LSTM layer; batch_first=True means inputs are (bs, seq_len, embed_size) 43 | self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True) 44 | # Linear layer mapping each LSTM hidden state to scores over the vocabulary 45 | self.linear = nn.Linear(hidden_size, vocab_size) 46 | # Stores the latest (hidden, cell) state; forward() overwrites it but never feeds it back into the LSTM 47 | self.hidden = (torch.zeros(1, 1, 
hidden_size), torch.zeros(1, 1, hidden_size)) 48 | 49 | def forward(self, features, captions): 50 | """ 51 | Args: 52 | features: features tensor. shape is (bs, embed_size) 53 | captions: captions tensor.
shape is (bs, cap_length) 54 | Returns: 55 | outputs: unnormalized scores over the vocabulary, shape (bs, cap_length, vocab_size) 56 | 57 | """ 58 | # drop the final (end-word) token from captions and embed the rest 59 | cap_embedding = self.embed( 60 | captions[:, :-1] 61 | ) # (bs, cap_length) -> (bs, cap_length-1, embed_size) 62 | 63 | # prepend the image features to the caption embeddings as the input at t=0. 64 | # [bs, embed_size] => [bs, 1, embed_size] concat [bs, cap_length-1, embed_size] 65 | # => [bs, cap_length, embed_size] 66 | embeddings = torch.cat((features.unsqueeze(dim=1), cap_embedding), dim=1) 67 | 68 | # run the LSTM over the whole sequence. 69 | # first value: the hidden states for every time step. second value: the final (hidden, cell) state 70 | lstm_out, self.hidden = self.lstm( 71 | embeddings 72 | ) # lstm_out: (bs, cap_length, hidden_size); h_n and c_n: (num_layers, bs, hidden_size) 73 | outputs = self.linear(lstm_out) # (bs, cap_length, vocab_size) 74 | 75 | return outputs 76 | 77 | def sample(self, inputs, states=None, max_len=20): 78 | """ 79 | Greedily decodes a caption from the embedded image features (inputs) and returns the predicted 80 | sentence as a list of word indices, at most max_len long. 81 | Args: 82 | inputs: shape is (1, 1, embed_size) 83 | states: initial (hidden, cell) state of the LSTM; None means zeros 84 | max_len: maximum length of the predicted sentence 85 | 86 | Returns: 87 | res: list of predicted word indices 88 | """ 89 | res = [] 90 | 91 | # Greedy decoding: feed each predicted word's embedding and the LSTM state back in to get the next word 92 | for _ in range(max_len): 93 | lstm_out, states = self.lstm( 94 | inputs, states 95 | ) # lstm_out: (1, 1, hidden_size) 96 | outputs = self.linear(lstm_out.squeeze(dim=1)) # outputs: (1, vocab_size) 97 | _, predicted_idx = outputs.max(dim=1) # predicted_idx: (1,) 98 | res.append(predicted_idx.item()) 99 | # stop once the end word (index 1 in this vocabulary) is predicted 100 | if predicted_idx == 1: 101 | break 102 | inputs = self.embed(predicted_idx) # inputs: (1, embed_size) 103 | # prepare input for next iteration
104 | inputs = inputs.unsqueeze(1) # inputs: (1, 1, embed_size) 105 | 106 | return res 107 | -------------------------------------------------------------------------------- /nlp_utils.py: -------------------------------------------------------------------------------- 1 | from nltk.translate.bleu_score import corpus_bleu 2 | 3 | 4 | def clean_sentence(output, idx2word): 5 | sentence = "" 6 | for i in output: 7 | word = idx2word[i] 8 | if i == 0: # start word (index 0, added first by Vocabulary.build_vocab) 9 | continue 10 | if i == 1: # end word (index 1): stop here 11 | break 12 | if i == 18: # hard-coded index that is attached without a leading space 13 | sentence = sentence + word 14 | else: 15 | sentence = sentence + " " + word 16 | return sentence.strip() 17 | 18 | 19 | def bleu_score(true_sentences, predicted_sentences): 20 | hypotheses = [] 21 | references = [] 22 | for img_id in set(true_sentences.keys()).intersection( 23 | set(predicted_sentences.keys()) 24 | ): 25 | img_refs = [cap.split() for cap in true_sentences[img_id]] # whitespace-tokenized reference captions 26 | references.append(img_refs) 27 | hypotheses.append(predicted_sentences[img_id][0].strip().split()) # first predicted caption, tokenized 28 | 29 | return corpus_bleu(references, hypotheses) # default weights give cumulative BLEU-4 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pycocotools~=2.0.4 2 | matplotlib~=3.5.2 3 | scikit-image 4 | torch~=1.8.1 5 | nltk~=3.7 6 | torchvision~=0.9.1 7 | numpy~=1.21.6 8 | pillow~=9.1.1 9 | tqdm~=4.64.0 10 | jupyter==1.0.0 11 | opencv-python==4.6.0.66 12 | gradio==3.1.1 -------------------------------------------------------------------------------- /vocab.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/iamirmasoud/image_captioning/a95af17e303767fc59f64b2ca274b94d31a6a2a2/vocab.pkl -------------------------------------------------------------------------------- /vocabulary.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import pickle 3 | from collections import
Counter 4 | 5 | import nltk 6 | from pycocotools.coco import COCO 7 | 8 | 9 | class Vocabulary(object): 10 | def __init__( 11 | self, 12 | vocab_threshold, 13 | vocab_file="./vocab.pkl", 14 | start_word="", 15 | end_word="", 16 | unk_word="", 17 | annotations_file="../cocoapi/annotations/captions_train2014.json", 18 | vocab_from_file=False, 19 | ): 20 | """Initialize the vocabulary. 21 | Args: 22 | vocab_threshold: Minimum word count threshold. 23 | vocab_file: File containing the vocabulary. 24 | start_word: Special word denoting sentence start. 25 | end_word: Special word denoting sentence end. 26 | unk_word: Special word denoting unknown words. 27 | annotations_file: Path for train annotation file. 28 | vocab_from_file: If False, create vocab from scratch and override any existing vocab_file 29 | If True, load vocab from existing vocab_file, if it exists 30 | """ 31 | self.vocab_threshold = vocab_threshold 32 | self.vocab_file = vocab_file 33 | self.start_word = start_word 34 | self.end_word = end_word 35 | self.unk_word = unk_word 36 | self.annotations_file = annotations_file 37 | self.vocab_from_file = vocab_from_file 38 | self.get_vocab() 39 | 40 | def get_vocab(self): 41 | """Load the vocabulary from file OR build the vocabulary from scratch.""" 42 | if os.path.exists(self.vocab_file) and self.vocab_from_file: 43 | with open(self.vocab_file, "rb") as f: 44 | vocab = pickle.load(f) 45 | self.word2idx = vocab.word2idx 46 | self.idx2word = vocab.idx2word 47 | print("Vocabulary successfully loaded from vocab.pkl file!") 48 | 49 | # create a new vocab file 50 | else: 51 | self.build_vocab() 52 | with open(self.vocab_file, "wb") as f: 53 | pickle.dump(self, f) 54 | 55 | def build_vocab(self): 56 | """Populate the dictionaries for converting tokens to integers (and vice-versa).""" 57 | self.init_vocab() 58 | self.add_word(self.start_word) 59 | self.add_word(self.end_word) 60 | self.add_word(self.unk_word) 61 | self.add_captions() 62 | 63 | def init_vocab(self): 64 
| """Initialize the dictionaries for converting tokens to integers (and vice-versa).""" 65 | self.word2idx = {} 66 | self.idx2word = {} 67 | self.idx = 0 68 | 69 | def add_word(self, word): 70 | """Add a token to the vocabulary.""" 71 | if word not in self.word2idx: 72 | self.word2idx[word] = self.idx 73 | self.idx2word[self.idx] = word 74 | self.idx += 1 75 | 76 | def add_captions(self): 77 | """Loop over training captions and add all tokens to the vocabulary that meet or exceed the threshold.""" 78 | coco = COCO(self.annotations_file) 79 | counter = Counter() 80 | ids = coco.anns.keys() 81 | for i, idx in enumerate(ids): 82 | caption = str(coco.anns[idx]["caption"]) 83 | tokens = nltk.tokenize.word_tokenize(caption.lower()) 84 | counter.update(tokens) 85 | 86 | if i % 100000 == 0: 87 | print("[%d/%d] Tokenizing captions..." % (i, len(ids))) 88 | 89 | # keep only words that occur at least vocab_threshold times in the final vocabulary 90 | words = [word for word, cnt in counter.items() if cnt >= self.vocab_threshold] 91 | 92 | for word in words: 93 | self.add_word(word) 94 | 95 | def __call__(self, word): 96 | if word not in self.word2idx: # map out-of-vocabulary words to the unknown word 97 | return self.word2idx[self.unk_word] 98 | return self.word2idx[word] 99 | 100 | def __len__(self): 101 | return len(self.word2idx) 102 | --------------------------------------------------------------------------------
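In training mode, `get_loader` feeds the indices returned by `CoCoDataset.get_train_indices` into a `SubsetRandomSampler` wrapped in a `BatchSampler`. That method picks one caption length at random (uniformly over annotations, so frequent lengths are chosen more often) and then draws `batch_size` annotation indices with exactly that length, with replacement. Every caption in the batch therefore has the same number of tokens, and the caption tensors stack without padding. Below is a minimal stdlib-only sketch of that sampling step, assuming a plain list of caption lengths; `sample_same_length_batch` is a hypothetical helper, not part of the repository.

```python
import random


def sample_same_length_batch(caption_lengths, batch_size, rng=random):
    """Pick one caption length, then batch_size indices with that length.

    Hypothetical stdlib version of CoCoDataset.get_train_indices.
    """
    # Choosing uniformly over captions makes common lengths proportionally likelier.
    sel_length = rng.choice(caption_lengths)
    candidates = [i for i, n in enumerate(caption_lengths) if n == sel_length]
    # Sample with replacement, like np.random.choice's default (replace=True).
    return rng.choices(candidates, k=batch_size)


if __name__ == "__main__":
    lengths = [10, 11, 10, 12, 10, 11, 13, 10]
    batch = sample_same_length_batch(lengths, batch_size=4, rng=random.Random(0))
    print(batch, {lengths[i] for i in batch})  # the set holds a single length
```

As written, `get_loader` calls `get_train_indices()` only once, so the returned training loader yields one batch of one length; a training loop that wants a new length on every step would need to re-sample the indices and install a fresh sampler each iteration.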
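`Vocabulary.build_vocab` adds the start, end and unknown words first, so they receive indices 0, 1 and 2, and then adds every caption token whose count is at least `vocab_threshold`. `clean_sentence` and `DecoderRNN.sample` rely on this ordering by treating index 0 as the start word and index 1 as the end word, which only holds when the three special words are distinct strings. The sketch below reproduces that logic with the standard library; the token spellings `<start>`, `<end>` and `<unk>`, the whitespace tokenizer (standing in for `nltk.tokenize.word_tokenize`), and the helpers `build_word2idx` and `encode` are all assumptions for illustration.

```python
from collections import Counter


def build_word2idx(captions, vocab_threshold,
                   start_word="<start>", end_word="<end>", unk_word="<unk>"):
    """Hypothetical stdlib sketch of Vocabulary.build_vocab (special words first)."""
    word2idx = {}

    def add_word(word):
        # Same rule as Vocabulary.add_word: the first insertion fixes the index.
        if word not in word2idx:
            word2idx[word] = len(word2idx)

    for special in (start_word, end_word, unk_word):
        add_word(special)

    # Count lower-cased tokens; str.split stands in for nltk's word_tokenize.
    counter = Counter()
    for caption in captions:
        counter.update(caption.lower().split())

    # Keep words that occur at least vocab_threshold times.
    for word, cnt in counter.items():
        if cnt >= vocab_threshold:
            add_word(word)
    return word2idx


def encode(caption, word2idx, start_word="<start>", end_word="<end>", unk_word="<unk>"):
    """Wrap a caption in start/end words and map unknown tokens to unk_word."""
    ids = [word2idx.get(t, word2idx[unk_word]) for t in caption.lower().split()]
    return [word2idx[start_word]] + ids + [word2idx[end_word]]


if __name__ == "__main__":
    w2i = build_word2idx(["A dog runs", "a dog sleeps", "A cat runs"], vocab_threshold=2)
    print(w2i)  # {'<start>': 0, '<end>': 1, '<unk>': 2, 'a': 3, 'dog': 4, 'runs': 5}
    print(encode("A dog sleeps", w2i))  # [0, 3, 4, 2, 1]
```

Passing the same string for all three special words collapses them into a single index, which would break the index-0/index-1 convention the decoding code depends on.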
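`DecoderRNN.forward` drops the last caption token, embeds the rest, and prepends the image features as the input at t=0, so the LSTM runs for `cap_length` steps and returns scores of shape `(bs, cap_length, vocab_size)`, one position per caption token. Assuming the training loop (in `2_Training.ipynb`, not shown here) scores `outputs` against `captions` position by position, each input is trained to predict the token that follows it, and the image features are trained to predict the start word. The hypothetical helper below makes that alignment explicit on plain lists of word indices, with the string `"<image>"` standing in for the image features.

```python
def teacher_forcing_pairs(caption_ids, image_token="<image>"):
    """Pair each decoder input with the caption token scored at the same position.

    Hypothetical illustration of the alignment in DecoderRNN.forward.
    """
    # forward() drops the last token and puts the image features first.
    inputs = [image_token] + list(caption_ids[:-1])
    targets = list(caption_ids)
    assert len(inputs) == len(targets)  # cap_length steps in, cap_length scores out
    return list(zip(inputs, targets))


if __name__ == "__main__":
    # Illustrative indices, e.g. start=0, a=3, dog=4, runs=5, end=1
    for inp, tgt in teacher_forcing_pairs([0, 3, 4, 5, 1]):
        print(f"input {inp!r:>9} -> target {tgt}")
```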
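`DecoderRNN.sample` decodes greedily: starting from the image features and a `None` LSTM state, each step takes the highest-scoring word, appends its index, stops once index 1 (the end word) appears, and otherwise feeds that word's embedding back in, for at most `max_len` steps. The framework-free sketch below reproduces that control flow; `step` and `embed` are hypothetical callables standing in for one `self.lstm` plus `self.linear` call and for `self.embed(...).unsqueeze(1)`, and the scripted `toy_step` exists only to exercise the loop.

```python
def greedy_decode(step, embed, first_input, end_idx=1, max_len=20):
    """Greedy decoding loop modeled on DecoderRNN.sample (hypothetical sketch).

    step(inputs, state) -> (scores, state): one recurrent step; scores is a list
    embed(idx) -> the next input for the chosen word index
    """
    res = []
    inputs, state = first_input, None  # None mirrors sample(..., states=None)
    for _ in range(max_len):
        scores, state = step(inputs, state)
        # argmax over the vocabulary scores
        predicted_idx = max(range(len(scores)), key=scores.__getitem__)
        res.append(predicted_idx)
        if predicted_idx == end_idx:  # the end word is kept in res, as in sample()
            break
        inputs = embed(predicted_idx)
    return res


# Toy step that "predicts" a fixed script of word indices, one per call.
SCRIPT = [3, 4, 5, 1]


def toy_step(inputs, state):
    t = 0 if state is None else state
    scores = [0.0] * 6
    scores[SCRIPT[min(t, len(SCRIPT) - 1)]] = 1.0
    return scores, t + 1


if __name__ == "__main__":
    print(greedy_decode(toy_step, embed=lambda idx: idx, first_input="image features"))
    # prints [3, 4, 5, 1]
```

Turning the result into text then follows `clean_sentence`: skip index 0, stop at index 1, and look up the remaining indices in `idx2word`.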