├── .dockerignore ├── .gitignore ├── Dockerfile.gpu ├── LICENSE ├── README.md ├── core.py ├── dataset.py ├── libraries ├── log.py └── strategies.py ├── main.py ├── model.py └── static ├── cptr_architecture.jpg ├── mlm_000.png ├── mlm_001.png ├── mlm_002.png ├── mlm_003.png ├── mlm_004.png └── mlm_005.png /.dockerignore: -------------------------------------------------------------------------------- 1 | arxiv/ 2 | models/ 3 | source/ 4 | target/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # projects static ressources 132 | images/ 133 | models/ 134 | source/ 135 | target/ -------------------------------------------------------------------------------- /Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | # base image derivation 2 | FROM nvcr.io/nvidia/pytorch:21.08-py3 3 | 4 | # timezone handler 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | ENV TZ=Europe/Paris 7 | 8 | # initial system requirements 9 | RUN apt-get update --fix-missing && \ 10 | apt-get install --yes --no-install-recommends \ 11 | tzdata apt-utils dialog gcc git curl pkg-config build-essential ffmpeg 12 | 13 | # user creation 14 | RUN useradd --gid root --create-home solver 15 | WORKDIR /home/solver 16 | 17 | # virtualenv 18 | ENV VIRTUAL_ENV=/opt/venv 19 | RUN chmod -R g+rwx /home/solver && python -m venv $VIRTUAL_ENV --system-site-packages 20 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 21 | 22 | # python requirements 23 | RUN pip install --upgrade pip && \ 24 | pip install torchtext torchvision spacy pyzmq click loguru sentence_transformers rich pandas && \ 25 | pip install ftfy regex git+https://github.com/openai/CLIP.git && \ 26 | python -m spacy download en_core_web_sm 27 | 28 | # pull source code 29 | COPY . ./ 30 | 31 | # env variables 32 | ENV IMAGES='images/' 33 | ENV SOURCE='source/' 34 | ENV TARGET='target/' 35 | ENV MODELS='models/' 36 | 37 | # entrypoint 38 | ENTRYPOINT ["python", "main.py"] 39 | CMD ["--debug"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ibraheem Khalil Ba 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # transformer-image-captioning 2 | Implementation of the paper CPTR : FULL TRANSFORMER NETWORK FOR IMAGE CAPTIONING 3 | 4 |

5 | 6 | 7 | architecture of the CPTR model for image captioning (figure: static/cptr_architecture.jpg) 8 | 9 | 10 | 11 | --- 12 | --- 13 | 14 | # predictions 15 | prediction samples (figures: static/mlm_000.png ... static/mlm_005.png) 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |

30 | 31 | # prerequisites 32 | * git 33 | * python3 34 | * python3-venv 35 | * docker 36 | 37 | # clone the repo and prepare data 38 | ```bash 39 | # clone 40 | git clone https://github.com/Milkymap/transformer-image-captioning 41 | cd transformer-image-captioning 42 | # prepare data 43 | # models is the space where resnet152 and clip will be saved 44 | # models also contains the checkpoints saved during training 45 | # images contains a set of image files used at inference time (see the docker describe step) 46 | # source contains the data used for training 47 | # the data has the following layout 48 | # images directory: contains all images used for training 49 | # captions.json: a hashmap (image_file_id => [text, text, text]) 50 | # target contains extracted features such as vectors, tokenizer, vocabulary 51 | mkdir models images source target 52 | ``` 53 | 54 | # docker build and run 55 | ```bash 56 | docker build -t capformer:0.0 -f Dockerfile.gpu . 57 | ``` 58 | 59 | # docker run processing step 60 | ```bash 61 | docker run \ 62 | --rm \ 63 | --tty \ 64 | --name capformer \ 65 | --gpus all \ 66 | -v $(pwd)/source:/home/solver/source \ 67 | -v $(pwd)/models:/home/solver/models \ 68 | -v $(pwd)/target:/home/solver/target \ 69 | -v $(pwd)/images:/home/solver/images \ 70 | -e TERM=xterm-256color \ 71 | capformer:0.0 processing \ 72 | --path2images /home/solver/source/images \ 73 | --path2captions /home/solver/source/captions.json \ 74 | --path2vectorizer /home/solver/models/resnet152.th \ 75 | --extension jpg \ 76 | --path2features /home/solver/target/map_img2features.pkl \ 77 | --path2tokenids /home/solver/target/zip_img2tokenids.pkl \ 78 | --path2vocabulary /home/solver/target/vocabulary.pkl 79 | ``` 80 | 81 | 82 | # docker run learning step 83 | ```bash 84 | docker run \ 85 | --rm \ 86 | --tty \ 87 | --name capformer \ 88 | --gpus all \ 89 | -v $(pwd)/source:/home/solver/source \ 90 | -v $(pwd)/models:/home/solver/models \ 91 | -v $(pwd)/target:/home/solver/target \ 92 | -v $(pwd)/images:/home/solver/images \ 93 | -e TERM=xterm-256color \ 94 | capformer:0.0 \ 95 | learning \ 96 | --path2features /home/solver/target/map_img2features.pkl \ 97 | --path2tokenids /home/solver/target/zip_img2tokenids.pkl \ 98 | --path2vocabulary /home/solver/target/vocabulary.pkl \ 99 | --nb_epochs 92 \ 100 | --bt_size 128 \ 101 | --path2checkpoint /home/solver/models/checkpoint_128.th \ 102 | --checkpoint 16 \ 103 | --start 0 104 | ``` 105 | 106 | # docker run describe step 107 | ```bash 108 | docker run \ 109 | --rm \ 110 | --tty \ 111 | --name capformer \ 112 | --gpus all \ 113 | -v $(pwd)/source:/home/solver/source \ 114 | -v $(pwd)/models:/home/solver/models \ 115 | -v $(pwd)/target:/home/solver/target \ 116 | -v $(pwd)/images:/home/solver/images \ 117 | -e TERM=xterm-256color \ 118 | capformer:0.0 \ 119 | describe \ 120 | --path2vectorizer /home/solver/models/resnet152.th \ 121 | --path2ranker /home/solver/models/ranker.pkl \ 122 | --path2vocabulary /home/solver/target/vocabulary.pkl \ 123 | --path2checkpoint /home/solver/models/checkpoint_128.th \ 124 | --beam_width 17 \ 125 | --path2image /home/solver/images/bob.jpg 126 | ``` 127 | 128 | # structure of the project 129 | 130 | this project is based on open-source libraries such as **[pytorch, clip(openai), opencv, PIL]** 131 | It contains: 132 | * **core.py** 133 | * this is the main file of the project 134 | * it contains the definition of the transformer 135 | * it is based on the paper Attention Is All You Need 136 | * I added some modifications to handle the multiple outputs of the decoder 137 | * **dataset.py** 138 | * this file contains two classes: 139 | *
DatasetForFeaturesExtraction 140 | * DatasetForTraining 141 | * **model.py** 142 | * this file contains the definition of the CPTR model 143 | * it uses the transformer defined in the core module 144 | * it has some additional modules such as: token_embedding, prediction_head 145 | * **libraries** 146 | * contains useful functions such as: 147 | * log handler 148 | * tokenization 149 | * features extraction 150 | * model loading 151 | * **beam and greedy search** for caption generation 152 | * **static** 153 | * contains images and fonts for the README 154 | * **main.py** 155 | * this is the entrypoint of the program 156 | * it defines three subcommands: 157 | * processing: features extraction and tokenization 158 | * learning: training loop of the CPTR model 159 | * describe: generates a caption from an image path 160 | * **.gitignore** 161 | * **.dockerignore** 162 | * **Dockerfile.gpu** 163 | * **LICENSE** 164 | * **README.md** 165 | 166 | # Citations 167 | 168 | ```bibtex 169 | @misc{Liu2021cptr, 170 | title = {CPTR: FULL TRANSFORMER NETWORK FOR IMAGE CAPTIONING}, 171 | author = {Wei Liu and Sihan Chen and Longteng Guo and Xinxin Zhu and Jing Liu}, 172 | year = {2021}, 173 | eprint = {2101.10804}, 174 | archivePrefix = {arXiv}, 175 | primaryClass = {cs.CV} 176 | } 177 | ``` 178 | -------------------------------------------------------------------------------- /core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import operator as op 3 | import itertools as it, functools as ft 4 | 5 | import torch as th 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class PositionalEncoding(nn.Module): 10 | def __init__(self, seq_length, in_dim, drop_val=0.1): 11 | super(PositionalEncoding, self).__init__() 12 | pos = np.arange(0, seq_length)[:, None] 13 | idx = np.fromfunction(lambda _,j: j - j % 2, shape=(1, in_dim)) 14 | mask = np.fromfunction(lambda _,j: j % 2 == 0, shape=(1, in_dim)) 15 | 16 | pnt = pos / (10000 ** (idx / in_dim)) 17 | val = np.sin(pnt) * mask + np.cos(pnt) * (1 - mask) 18 | 19 | self.drop_layer = nn.Dropout(drop_val) 20 | self.register_buffer('psne_layer', th.tensor(val).float()) 21 | 22 | def forward(self, src): 23 | _, seq_length, _ = src.shape 24 | pos = self.psne_layer[:seq_length, :][None, ...]
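# slice the precomputed sinusoidal table to the current sequence length and add a leading batch axis so it broadcasts over the batch when added to src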
25 | return self.drop_layer(src + pos) 26 | 27 | class FeedForwardNetwork(nn.Module): 28 | __THETA = { # map id to non_linear 29 | 0: nn.Identity(), 30 | 1: nn.ReLU(), 31 | 2: nn.GELU(), 32 | 3: nn.Sigmoid(), 33 | 4: nn.Tanh(), 34 | 5: nn.Softmax(dim=-1) 35 | } 36 | def __init__(self, layer_cfg, activations, drop_vals): 37 | super(FeedForwardNetwork, self).__init__() 38 | self.shapes = list(zip(layer_cfg[:-1], layer_cfg[1:])) 39 | self.linears = nn.ModuleList([]) 40 | for idx, (in_dim, out_dim) in enumerate(self.shapes): 41 | fn_id = activations[idx] 42 | proba = drop_vals[idx] 43 | block = nn.Sequential( 44 | nn.Linear(in_dim, out_dim), 45 | nn.Dropout(proba) if proba > 0.0 else nn.Identity(), 46 | FeedForwardNetwork.__THETA.get(fn_id, nn.Identity()) 47 | ) 48 | self.linears.append(block) 49 | 50 | def forward(self, input_batch): 51 | output_batch = ft.reduce( # functools 52 | lambda acc, crr: crr(acc), 53 | self.linears, 54 | input_batch 55 | ) 56 | return output_batch 57 | 58 | class MultiHeadCrossAttention(nn.Module): 59 | def __init__(self, in_dim, nb_heads): 60 | super(MultiHeadCrossAttention, self).__init__() 61 | self.nbr_heads = nb_heads 62 | self.heads_dim = in_dim // nb_heads 63 | 64 | self.to_qry = nn.Linear(in_dim, in_dim) 65 | self.to_key = nn.Linear(in_dim, in_dim) 66 | self.to_val = nn.Linear(in_dim, in_dim) 67 | self.to_out = nn.Linear(in_dim, in_dim) 68 | 69 | def __rearrange(self, seq): 70 | bt_size, seq_length, _ = seq.shape # unpack shape 71 | seq = seq.reshape(bt_size, seq_length, self.nbr_heads, self.heads_dim).permute(0, 2, 1, 3) 72 | return seq 73 | 74 | def forward(self, qry, key, val, mask=None, key_padding_mask=None): 75 | 76 | qry = self.to_qry(qry) 77 | key = self.to_key(key) 78 | val = self.to_val(val) 79 | 80 | qry = self.__rearrange(qry) 81 | key = self.__rearrange(key) 82 | val = self.__rearrange(val) 83 | 84 | dim = qry.shape[-1] 85 | 86 | wgt = qry @ key.transpose(-2, -1) 87 | wgt = wgt / np.sqrt(dim) 88 | if mask is not None: 89 | wgt = wgt.masked_fill(mask, float('-inf')) 90 | if key_padding_mask is not None: 91 | cnd = key_padding_mask[:, None, None, :] 92 | wgt = wgt.masked_fill(cnd, float('-inf')) 93 | wgt = th.softmax(wgt, dim=-1) 94 | 95 | res = wgt @ val 96 | res = res.permute(0, 2, 1, 3) # permute head and sequence 97 | res = th.flatten(res, start_dim=2) # concat over heads 98 | res = self.to_out(res) 99 | 100 | return res 101 | 102 | class MultiHeadSelfAttention(nn.Module): 103 | def __init__(self, in_dim, nb_heads): 104 | super(MultiHeadSelfAttention, self).__init__() 105 | 106 | self.nbr_heads = nb_heads 107 | self.heads_dim = in_dim // nb_heads 108 | self.qkv_layer = nn.Linear(in_dim, 3 * in_dim) 109 | self.out_layer = nn.Linear(in_dim, in_dim) 110 | 111 | def forward(self, src, mask=None, key_padding_mask=None): 112 | bt_size, seq_length, _ = src.shape # unpack shape 113 | 114 | qkv = self.qkv_layer(src) # extract query, key and value 115 | qkv = qkv.reshape(bt_size, seq_length, self.nbr_heads, 3 * self.heads_dim) 116 | qkv = qkv.permute(0, 2, 1, 3) # permute head and sequence 117 | qry, key, val = th.chunk(qkv, 3, dim=-1) 118 | 119 | dim = qry.shape[-1] 120 | wgt = qry @ key.transpose(-2, -1) # hidden_dim and sequence_dim 121 | wgt = wgt / np.sqrt(dim) # normalize 122 | if mask is not None: 123 | wgt = wgt.masked_fill(mask, float('-inf')) 124 | if key_padding_mask is not None: 125 | cnd = key_padding_mask[:, None, None, :] 126 | wgt = wgt.masked_fill(cnd, float('-inf')) 127 | wgt = th.softmax(wgt, dim=-1) 128 | 129 | res = wgt @ val 130 | 
res = res.permute(0, 2, 1, 3) # permute head and sequence 131 | res = th.flatten(res, start_dim=2) # concat over heads 132 | res = self.out_layer(res) 133 | 134 | return res 135 | 136 | class EncoderBlock(nn.Module): 137 | def __init__(self, in_dim, ff_dim, nb_heads, drop_val=0.1, pre_norm=False): 138 | super(EncoderBlock, self).__init__() 139 | assert in_dim % nb_heads == 0 140 | 141 | self.nbr_heads = nb_heads 142 | self.heads_dim = in_dim // nb_heads 143 | 144 | self.mha_layer = MultiHeadSelfAttention(in_dim, nb_heads) 145 | self.ffn_layer = FeedForwardNetwork([in_dim, ff_dim, in_dim], [1, 0], [drop_val, 0.0]) 146 | 147 | self.dropout_layer = nn.ModuleDict({ 148 | 'mha': nn.Dropout(drop_val), 149 | 'ffn': nn.Dropout(drop_val) 150 | }) 151 | self.layer_normalz = nn.ModuleDict({ 152 | 'mha': nn.ModuleList([ 153 | nn.LayerNorm(in_dim) if pre_norm else nn.Identity(), 154 | nn.LayerNorm(in_dim) if not pre_norm else nn.Identity() 155 | ]), 156 | 'ffn': nn.ModuleList([ 157 | nn.LayerNorm(in_dim) if pre_norm else nn.Identity(), 158 | nn.LayerNorm(in_dim) if not pre_norm else nn.Identity() 159 | ]) 160 | }) 161 | 162 | 163 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 164 | # multi head self attention 165 | tmp = self.layer_normalz['mha'][0](src) 166 | out = self.mha_layer(tmp, src_mask, src_key_padding_mask) 167 | out = self.dropout_layer['mha'](out) 168 | agg = tmp + out 169 | agg = self.layer_normalz['mha'][1](agg) 170 | 171 | # feed forward network 172 | tmp = self.layer_normalz['ffn'][0](agg) 173 | out = self.ffn_layer(tmp) 174 | out = self.dropout_layer['ffn'](out) 175 | agg = tmp + out 176 | agg = self.layer_normalz['ffn'][1](agg) 177 | 178 | return agg 179 | 180 | class DecoderBlock(nn.Module): 181 | def __init__(self, in_dim, ff_dim, nb_heads, drop_val=0.1, pre_norm=False): 182 | super(DecoderBlock, self).__init__() 183 | assert in_dim % nb_heads == 0 184 | 185 | self.nbr_heads = nb_heads 186 | self.heads_dim = in_dim // nb_heads 187 | 188 | self.mha_layer = MultiHeadSelfAttention(in_dim, nb_heads) 189 | self.crx_layer = MultiHeadCrossAttention(in_dim, nb_heads) 190 | self.ffn_layer = FeedForwardNetwork([in_dim, ff_dim, in_dim], [1, 0], [drop_val, 0.0]) 191 | 192 | self.dropout_layer = nn.ModuleDict({ 193 | 'mha': nn.Dropout(drop_val), 194 | 'crx': nn.Dropout(drop_val), 195 | 'ffn': nn.Dropout(drop_val) 196 | }) 197 | self.layer_normalz = nn.ModuleDict({ 198 | 'mha': nn.ModuleList([ 199 | nn.LayerNorm(in_dim) if pre_norm else nn.Identity(), 200 | nn.LayerNorm(in_dim) if not pre_norm else nn.Identity() 201 | ]), 202 | 'crx': nn.ModuleList([ 203 | nn.LayerNorm(in_dim) if pre_norm else nn.Identity(), 204 | nn.LayerNorm(in_dim) if not pre_norm else nn.Identity() 205 | ]), 206 | 'ffn': nn.ModuleList([ 207 | nn.LayerNorm(in_dim) if pre_norm else nn.Identity(), 208 | nn.LayerNorm(in_dim) if not pre_norm else nn.Identity() 209 | ]) 210 | }) 211 | 212 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None): 213 | # masked multi head attention 214 | tmp = self.layer_normalz['mha'][0](tgt) 215 | out = self.mha_layer(tmp, tgt_mask, tgt_key_padding_mask) 216 | out = self.dropout_layer['mha'](out) 217 | agg = tmp + out # residual 218 | agg = self.layer_normalz['mha'][1](agg) 219 | 220 | # cross multi head attention 221 | tmp = self.layer_normalz['crx'][0](agg) 222 | out = self.crx_layer(tmp, memory, memory, memory_mask, memory_key_padding_mask) 223 | out = self.dropout_layer['crx'](out) 224 | agg = tmp + out # 
residual 225 | agg = self.layer_normalz['crx'][1](agg) 226 | 227 | # feed forward network 228 | tmp = self.layer_normalz['ffn'][0](agg) 229 | out = self.ffn_layer(agg) 230 | out = self.dropout_layer['ffn'](out) 231 | agg = tmp + out # residual 232 | agg = self.layer_normalz['ffn'][1](agg) 233 | 234 | return agg 235 | 236 | class TransformerEncoder(nn.Module): 237 | def __init__(self, nb_layers, in_dim, ff_dim, nb_heads, drop_val=0.1, pre_norm=False): 238 | super(TransformerEncoder, self).__init__() 239 | self.encoders = nn.ModuleList([]) 240 | for _ in range(nb_layers): 241 | blk = EncoderBlock(in_dim=in_dim, ff_dim=ff_dim, nb_heads=nb_heads, drop_val=drop_val, pre_norm=pre_norm) 242 | self.encoders.append(blk) 243 | 244 | def forward(self, src, mask=None, key_padding_mask=None): 245 | fnl = ft.reduce( 246 | lambda acc, crr: acc + [crr(acc[-1], mask, key_padding_mask)], 247 | self.encoders, 248 | [src] 249 | ) 250 | return fnl[1:] # ignore src 251 | 252 | class TransformerDecoder(nn.Module): 253 | def __init__(self, nb_layers, in_dim, ff_dim, nb_heads, drop_val=0.1, pre_norm=False): 254 | super(TransformerDecoder, self).__init__() 255 | self.decoders = nn.ModuleList([]) 256 | for _ in range(nb_layers): 257 | blk = DecoderBlock(in_dim=in_dim, ff_dim=ff_dim, nb_heads=nb_heads, drop_val=drop_val, pre_norm=pre_norm) 258 | self.decoders.append(blk) 259 | 260 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None): 261 | lng = len(memory) - 1 262 | fnl = ft.reduce( 263 | lambda acc,crr: acc + [crr[1](acc[-1], memory[-1], tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)], 264 | enumerate(self.decoders), 265 | [tgt] 266 | ) 267 | return fnl[1:] # ignore tgt 268 | 269 | class Transformer(nn.Module): 270 | def __init__(self, in_dim, ff_dim, nb_heads, encoder_depth, decoder_depth, drop_val=0.1, pre_norm=False): 271 | super(Transformer, self).__init__() 272 | self.encoder = TransformerEncoder( 273 | nb_layers=encoder_depth, 274 | in_dim=in_dim, 275 | ff_dim=ff_dim, 276 | nb_heads=nb_heads, 277 | drop_val=drop_val, 278 | pre_norm=pre_norm 279 | ) 280 | self.decoder = TransformerDecoder( 281 | nb_layers=decoder_depth, 282 | in_dim=in_dim, 283 | ff_dim=ff_dim, 284 | nb_heads=nb_heads, 285 | drop_val=drop_val, 286 | pre_norm=pre_norm 287 | ) 288 | 289 | def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None): 290 | memory = self.encoder(src, src_mask, src_key_padding_mask) 291 | output = self.decoder(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask) 292 | return output 293 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | from os import path 4 | from torch.utils.data import Dataset 5 | from libraries.strategies import pull_images, read_image, prepare_image, deserialize, cv2th 6 | 7 | class DatasetForFeaturesExtraction(Dataset): 8 | def __init__(self, path2images, file_extension='*.jpg'): 9 | self.image_paths = pull_images(path2images, exts=file_extension) 10 | self.image_names = [ path.split(path_)[1] for path_ in self.image_paths ] 11 | 12 | def __len__(self): 13 | return len(self.image_paths) 14 | 15 | def __getitem__(self, index): 16 | path2img = self.image_paths[index] 17 | cv_image = read_image(path2img) 18 | th_image = 
cv2th(cv_image) 19 | return prepare_image(th_image) 20 | 21 | class DatasetForTraining(Dataset): 22 | def __init__(self, path2tokenids, path2features): 23 | self.tokenids = deserialize(path2tokenids) 24 | self.features = deserialize(path2features) 25 | def __len__(self): 26 | return len(self.tokenids) 27 | 28 | def __getitem__(self, idx): 29 | file_name, ids = self.tokenids[idx] 30 | vec = self.features[file_name] 31 | vec = th.tensor(vec).float() 32 | ids = th.tensor(ids).long() 33 | return vec, ids -------------------------------------------------------------------------------- /libraries/log.py: -------------------------------------------------------------------------------- 1 | from sys import stdout 2 | from loguru import logger 3 | 4 | 5 | log_format = [ 6 | '{time:YYYY-MM-DD hh:mm:ss}', 7 | '{file:^15}', 8 | '{line:03d}', 9 | '{level:^10}', 10 | '{message:<50}' 11 | ] 12 | 13 | log_separator = ' | ' 14 | 15 | logger.remove() 16 | logger.add( 17 | sink=stdout, 18 | level='TRACE', 19 | format=log_separator.join(log_format) 20 | ) 21 | 22 | if __name__ == '__main__': 23 | logger.debug('check if log is available...!') -------------------------------------------------------------------------------- /libraries/strategies.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | from multiprocessing.sharedctypes import Value 4 | import pickle as pk 5 | 6 | import clip 7 | 8 | import cv2 9 | import numpy as np 10 | import torch as th 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | import operator as op 15 | import itertools as it, functools as ft 16 | 17 | from os import path 18 | from glob import glob 19 | 20 | from PIL import Image 21 | from rich.progress import track 22 | from torch.nn.utils.rnn import pad_sequence 23 | from torchvision import models 24 | from torchvision import transforms as T 25 | from torchtext.data.utils import get_tokenizer 26 | from torchtext.vocab import build_vocab_from_iterator 27 | 28 | from libraries.log import logger 29 | 30 | SPECIALS2IDX = {"": 0, "": 1, "": 2, "": 3} 31 | 32 | def pull_files(endpoint, extension): 33 | file_paths = sorted(glob(path.join(endpoint, extension))) 34 | return file_paths 35 | 36 | def build_tokenizer(tok_name='spacy', lang='en_core_web_sm'): 37 | tokenizer = get_tokenizer(tokenizer=tok_name, language=lang) 38 | return tokenizer 39 | 40 | def yield_tokens(data_iter, tokenizer): 41 | for sample in track(data_iter, description=f'tokenization process'): 42 | sample = sample.strip().lower() # remove trailing keys and lowercase 43 | yield tokenizer(sample) 44 | 45 | def make_vocab(data_iter, tokenizer, map_specials2idx): 46 | vocab = build_vocab_from_iterator( 47 | iterator=yield_tokens(data_iter, tokenizer), 48 | specials=list(map_specials2idx.keys()), 49 | min_freq=1, 50 | special_first=True 51 | ) 52 | vocab.set_default_index(map_specials2idx['']) # index of the token 53 | return vocab 54 | 55 | def serialize(path2dump, data): 56 | with open(path2dump, mode='wb') as fp: 57 | pk.dump(data, fp) 58 | 59 | def deserialize(path2dump): 60 | with open(path2dump, mode='rb') as fp: 61 | return pk.load(fp) 62 | 63 | def serialize(path2dump, data): 64 | with open(path2dump, mode='wb') as fp: 65 | pk.dump(data, fp) 66 | 67 | def deserialize(path2dump): 68 | with open(path2dump, mode='rb') as fp: 69 | return pk.load(fp) 70 | 71 | def load_vectorizer(path2models): 72 | if path.isfile(path2models): 73 | features_extractor = th.load(path2models) 74 | else: 75 | 
model_name = path.split(path2models)[1] 76 | real_name, _ = model_name.split('.') 77 | endpoint = op.attrgetter(real_name)(models) 78 | if endpoint is not None: 79 | features_extractor = endpoint(pretrained=True, progress=True) 80 | features_extractor = nn.Sequential(*list(features_extractor.children())[:-2]) 81 | for prm in features_extractor.parameters(): 82 | prm.requires_grad = False 83 | th.save(features_extractor, path2models) 84 | else: 85 | raise Value(f'{real_name} is not a valid option for torchvision.models') 86 | return features_extractor 87 | 88 | def load_ranker(path2models, device): 89 | if path.isfile(path2models): 90 | with open(path2models, 'rb') as fp: 91 | model, processor = pk.load(fp) # (model, processor) 92 | else: 93 | model, processor = clip.load("ViT-B/32") 94 | with open(path2models, 'wb') as fp: 95 | pk.dump((model, processor), fp) 96 | return model.to(device), processor 97 | 98 | def pull_images(path2images, exts='*.jpg'): 99 | return sorted( glob(path.join(path2images, '**' ,exts), recursive=True) ) 100 | 101 | def th2cv(th_image): 102 | red, green, blue = th_image.numpy() 103 | return cv2.merge((blue, green, red)) 104 | 105 | def cv2th(cv_image): 106 | blue, green, red = cv2.split(cv_image) 107 | return th.as_tensor(np.stack([red, green, blue])) 108 | 109 | def cv2pil(cv_image): 110 | return Image.fromarray(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)) 111 | 112 | def pil2cv(pil_image): 113 | return cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR) 114 | 115 | def read_image(path2image, size=None): 116 | pl_image = Image.open(path2image).convert('RGB') 117 | cv_image = cv2.cvtColor(np.array(pl_image), cv2.COLOR_RGB2BGR) 118 | if size is not None: 119 | return cv2.resize(cv_image, size, interpolation=cv2.INTER_CUBIC) 120 | return cv_image 121 | 122 | def save_image(cv_image, path2location): 123 | cv2.imwrite(path2location, cv_image) 124 | 125 | def prepare_image(th_image): 126 | normalied_th_image = th_image / th.max(th_image) 127 | return T.Compose([ 128 | T.RandomHorizontalFlip(p=0.5), 129 | T.Resize((256, 256)), 130 | T.CenterCrop((224, 224)), 131 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 132 | ])(normalied_th_image) 133 | 134 | def extract_features(extractor, batch_of_images): 135 | with th.no_grad(): 136 | features = extractor(batch_of_images) 137 | return features 138 | 139 | def rank_solutions(pil_image, sentences, ranker, processor, device): 140 | image = processor(pil_image).unsqueeze(0).to(device) 141 | tokens = clip.tokenize(sentences).to(device) 142 | with th.no_grad(): 143 | logits, _ = ranker(image, tokens) 144 | probabilities = th.softmax(logits, dim=1).cpu().squeeze(0) 145 | lowest = th.min(probabilities) 146 | largest = th.max(probabilities) 147 | normalized_scores = (probabilities - lowest) / (largest - lowest) 148 | return normalized_scores.tolist() 149 | 150 | def custom_fn(batch): 151 | features, token_ids = list(zip(*batch)) 152 | features = th.stack(features) # 3d 153 | token_ids = list(token_ids) 154 | token_ids = pad_sequence(token_ids, batch_first=True, padding_value=SPECIALS2IDX['']) 155 | return features, token_ids 156 | 157 | def build_mask(seq): 158 | seq_length = seq.shape[1] 159 | mask = np.fromfunction(lambda i,j: j > i, shape=(seq_length, seq_length)) 160 | return th.as_tensor(mask) 161 | 162 | def build_key_padding_mask(seq, pad_idx): 163 | seq_key_padding_mask = (seq == pad_idx) 164 | return seq_key_padding_mask 165 | 166 | def greedy_search(model, source, BOS, EOS, max_len, device): 167 | memory 
= model.encode(source.to(device)) 168 | target = th.tensor([[BOS]]) 169 | keep_generating = True 170 | while keep_generating: 171 | output = model.decode(target.to(device), memory).squeeze(0) 172 | logits = model.generator(output[-1, :]) 173 | scaled_logits = th.log_softmax(logits, dim=-1).squeeze(0) 174 | candidate = th.argmax(scaled_logits) 175 | target = th.cat([target, th.tensor([[candidate]])], dim=1) 176 | keep_generating = (candidate != EOS) and (target.shape[1] < max_len) 177 | return th.flatten(target) 178 | 179 | def beam_search(model, source, BOS, EOS, max_len, device, beam_width, alpha=0.7): 180 | memory = model.encode(source.to(device)) 181 | target = th.tensor([[BOS]]) 182 | with th.no_grad(): 183 | output = model.decode(target.to(device), memory) 184 | output = th.stack([ model.generator(out[0, -1, :]) for out in output ]) 185 | logits = th.mean(output, dim=0) 186 | 187 | scaled_logits = th.log_softmax(logits[None, :], dim=1).cpu().squeeze(0) # over vocab size 188 | weights, candidates = th.topk(input=scaled_logits, k=beam_width, largest=True) 189 | 190 | response_tracker = [] # for valid final sequence 191 | sequence_tracker = [] # for current active sequence 192 | for idx in candidates: 193 | option = th.tensor([[idx]]) # a new option into the search tree 194 | sequence = th.cat([target, option], dim=1) 195 | sequence_tracker.append(sequence) 196 | 197 | keep_generating = True 198 | while keep_generating: 199 | input_batch = th.vstack(sequence_tracker) 200 | with th.no_grad(): 201 | input_memory = [m.repeat(input_batch.shape[0], 1, 1) for m in memory ] 202 | output = model.decode(input_batch.to(device), input_memory) 203 | logits = th.mean(th.stack([ model.generator(out[:, -1, :]) for out in output ]), dim=0) 204 | 205 | scaled_logits = th.log_softmax(logits, dim=1).cpu() 206 | 207 | length = input_batch.shape[1] 208 | vocab_size = scaled_logits.shape[1] 209 | weighted_logits = (scaled_logits + weights[:, None]) / length ** alpha 210 | weights, candidates = th.topk(th.flatten(weighted_logits), k=beam_width, largest=True) 211 | weights = weights * length ** alpha # denormalize 212 | 213 | weights_tmp = [] 214 | sequence_tmp = [] 215 | for idx, pos in enumerate(candidates): 216 | row = th.div(pos, vocab_size, rounding_mode='floor') # get relative position over nb_sequences 217 | col = pos % vocab_size # get relative position over vocab_size 218 | sequence = th.cat([sequence_tracker[row], th.tensor([[col]])], dim=1) 219 | if col == EOS: 220 | logger.success('a sentence was generated :)') 221 | flattened_sequence = th.flatten(sequence).tolist() 222 | sequence_score = weights[idx] / len(flattened_sequence) ** alpha 223 | response_tracker.append((flattened_sequence, sequence_score)) # a sentence was built 224 | if len(response_tracker) == beam_width: 225 | keep_generating = False 226 | break # end the for loop over candidates 227 | elif sequence.shape[1] < max_len - 1: 228 | weights_tmp.append(weights[row]) 229 | sequence_tmp.append(sequence) 230 | # end for loop over candidates ...! 231 | 232 | if len(sequence_tmp) == 0: 233 | keep_generating = False 234 | else: 235 | weights = th.tensor(weights_tmp) 236 | sequence_tracker = sequence_tmp 237 | # end while search loop ...! 
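# response_tracker holds up to beam_width completed candidates as (token_id_list, length-normalized log-probability) pairs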
238 | return response_tracker 239 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from pickletools import optimize 2 | import click 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from os import getenv, path 7 | from time import sleep 8 | from rich.progress import track 9 | from dataset import DatasetForFeaturesExtraction, DatasetForTraining 10 | from libraries.log import logger 11 | from libraries.strategies import * 12 | 13 | from model import CaptionTransformer 14 | 15 | 16 | @click.group(chain=False, invoke_without_command=True) 17 | @click.option('--debug/--no-debug', help='debug mode flag', default=True) 18 | @click.pass_context 19 | def router_command(ctx, debug): 20 | ctx.ensure_object(dict) 21 | 22 | models = getenv('MODELS') 23 | source = getenv('SOURCE') 24 | target = getenv('TARGET') 25 | images = getenv('IMAGES') 26 | 27 | assert models is not None and path.isdir(models) 28 | assert source is not None and path.isdir(source) 29 | assert target is not None and path.isdir(target) 30 | assert images is not None and path.isdir(images) 31 | 32 | ctx.obj['debug'] = debug 33 | command = ctx.invoked_subcommand 34 | if command is None: 35 | logger.debug('no command was called, add --help option to see the avaiables command') 36 | else: 37 | logger.debug(f'{command} was called') 38 | 39 | @router_command.command() 40 | @click.option('--path2vectorizer', help='path to models for features extraction', type=click.Path(False)) 41 | @click.option('--path2images', help='path to images directory', type=click.Path(True)) 42 | @click.option('--path2captions', help='path to captions json file', type=click.Path(True)) 43 | @click.option('--extension', help='image file extension', type=click.Choice(['jpg', 'jpeg'])) 44 | @click.option('--path2features', help='path to features dump location', type=click.Path(False)) 45 | @click.option('--path2tokenids', help='path to tokenids dump lication', type=click.Path(False)) 46 | @click.option('--path2vocabulary', help='path to vacabulary dump location', type=click.Path(False)) 47 | def processing(path2vectorizer, path2images, path2captions, extension, path2features, path2tokenids, path2vocabulary): 48 | device = th.device('cuda:0' if th.cuda.is_available() else 'cpu') 49 | 50 | with open(file=path2captions, mode='r') as fp: 51 | img2captions = json.load(fp) 52 | 53 | captions = list(img2captions.values()) 54 | captions = list(it.chain(*captions)) 55 | 56 | tokenizer = build_tokenizer(tok_name='spacy', lang='en_core_web_sm') 57 | vocabulary = make_vocab(captions, tokenizer, SPECIALS2IDX) 58 | logger.success('vocaulary was built') 59 | 60 | serialize(path2vocabulary, vocabulary) 61 | 62 | bos = th.tensor([SPECIALS2IDX['']]) 63 | eos = th.tensor([SPECIALS2IDX['']]) 64 | 65 | zip_img2tokenids = [] 66 | logger.debug('caption tokenization') 67 | for key, val in track(img2captions.items(), 'build map_img2tokenids'): 68 | for cap in val: 69 | tok = tokenizer(cap.strip().lower()) 70 | idx = th.tensor(vocabulary(tok)) 71 | seq = th.cat([bos, idx, eos]).numpy() # more effective for storage 72 | zip_img2tokenids.append((key, seq)) 73 | 74 | serialize(path2tokenids, zip_img2tokenids) 75 | 76 | logger.debug('features extraction loading') 77 | vectorizer = load_vectorizer(path2vectorizer) 78 | vectorizer.eval() 79 | vectorizer.to(device) 80 | 81 | dataset = DatasetForFeaturesExtraction(path2images, f'*.{extension}') 82 | 83 | 
logger.debug('extraction will start') 84 | accumulator = [] 85 | for sections in track(dataset, 'features extraction'): 86 | embedding = extract_features(vectorizer, sections[None, ...].to(device)).squeeze(0) # (2048, 7, 7) 87 | embedding = th.flatten(embedding, start_dim=1).T.cpu().numpy() # 49, 2048 88 | accumulator.append(embedding) 89 | 90 | image_names = dataset.image_names 91 | accumulator = np.stack(accumulator) # stack over batch axis ==> (nb_images, 49, 512) 92 | logger.debug(f'accumulated features shape : {accumulator.shape}') 93 | assert len(image_names) == len(accumulator) 94 | map_img2features = dict(zip(image_names, accumulator)) 95 | 96 | serialize(path2features, map_img2features) 97 | 98 | logger.success('features, tokenids and vocabulary were saved') 99 | 100 | @router_command.command() 101 | @click.option('--path2vocabulary', help='path to vacabulary dump location', type=click.Path(True)) 102 | @click.option('--path2features', help='path to features dump location', type=click.Path(True)) 103 | @click.option('--path2tokenids', help='path to tokenids dump lication', type=click.Path(True)) 104 | @click.option('--nb_epochs', help='number of epochs', type=int, default=128) 105 | @click.option('--bt_size', help='batch size', type=int, default=32) 106 | @click.option('--path2checkpoint', help='path to checkpoint model', type=click.Path(False)) 107 | @click.option('--checkpoint', help='checkpoint period(save model)', type=int, default=16) 108 | @click.option('--start', help='start epoch index', type=int, default=0) 109 | def learning(path2vocabulary, path2features, path2tokenids, nb_epochs, bt_size, path2checkpoint, checkpoint, start): 110 | basepath2models = getenv('MODELS') 111 | 112 | device = th.device('cuda:0' if th.cuda.is_available() else 'cpu') 113 | 114 | logger.debug('load vocabulary') 115 | vocabulary = deserialize(path2vocabulary) 116 | nb_tokens = len(vocabulary) 117 | 118 | logger.debug('build dataset') 119 | dataset = DatasetForTraining(path2tokenids, path2features) 120 | logger.debug(f'size of the dataset : {len(dataset):05d}') 121 | dataloader = DataLoader(dataset, batch_size=bt_size, shuffle=True, collate_fn=custom_fn) 122 | nb_data = len(dataset) 123 | 124 | logger.debug('define network') 125 | if path.isfile(path2checkpoint): 126 | net = th.load(path2checkpoint) 127 | else: 128 | net = CaptionTransformer( 129 | in_dim=2048, 130 | hd_dim=256, 131 | ff_dim=512, 132 | nb_heads=8, 133 | num_encoders=5, 134 | num_decoders=5, 135 | pre_norm=False, 136 | seq_length=64, 137 | nb_tokens=nb_tokens, 138 | padding_idx=SPECIALS2IDX[''] 139 | ) 140 | 141 | net.to(device) 142 | net.train() 143 | 144 | print(net) 145 | 146 | optimizer = th.optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.99), eps=1e-9) 147 | criterion = nn.CrossEntropyLoss(ignore_index=SPECIALS2IDX['']) 148 | logger.debug('training will begin ...!') 149 | sleep(1) 150 | 151 | nb_epochs += start 152 | for epoch in range(start, nb_epochs): 153 | counter = 0 154 | for src, tgt in dataloader: 155 | counter += len(tgt) 156 | tgt_input = tgt[:, :-1] 157 | tgt_output = tgt[:, 1:] 158 | 159 | tgt_mask = build_mask(tgt_input).to(device) 160 | tgt_key_padding_mask = build_key_padding_mask(tgt_input, SPECIALS2IDX['']).to(device) 161 | 162 | memory = net.encode(src=src.to(device)) 163 | output = net.decode( 164 | tgt=tgt_input.to(device), 165 | memory=memory, 166 | tgt_mask=tgt_mask, 167 | tgt_key_padding_mask=tgt_key_padding_mask 168 | ) 169 | 170 | logits = [net.generator(out) for out in output ] 171 | logits = [ 
th.flatten(prb, start_dim=0, end_dim=1) for prb in logits ] 172 | tgt_output = th.flatten(tgt_output) 173 | 174 | optimizer.zero_grad() 175 | errors = [ criterion(prb, tgt_output.to(device)) for prb in logits ] 176 | error = sum(errors) 177 | error.backward() 178 | optimizer.step() 179 | 180 | message = [] 181 | for err in errors: 182 | msg = f'{err.cpu().item():07.3f}' 183 | message.append(msg) 184 | message = ' | '.join(message) 185 | logger.debug(f'[{epoch:03d}/{nb_epochs:03d}] [{counter:05d}/{nb_data:05d}] | Loss : {error.cpu().item():07.3f} >> {message}') 186 | # end for loop over batchs 187 | 188 | if epoch % checkpoint == 0: 189 | path2network = path.join(basepath2models, f'checkpoint_{epoch:03d}.th') 190 | th.save(net.cpu(), path2network) 191 | net.to(device) 192 | logger.success(f'a snapshot was saved {path2network}') 193 | 194 | # end for loop over epochs 195 | 196 | path2network = path.join(basepath2models, f'checkpoint_###.th') 197 | th.save(net.cpu(), path2network) 198 | logger.success(f'a snapshot was saved {path2network}') 199 | logger.success('end of training') 200 | 201 | 202 | 203 | @router_command.command() 204 | @click.option('--path2vectorizer', help='name of the stored model(features extractor)', type=str) 205 | @click.option('--path2checkpoint', help='model snapshot filename', type=str) 206 | @click.option('--path2image', help='image to describe', type=str) 207 | @click.option('--path2vocabulary', help='vocabulary object', type=str) 208 | @click.option('--beam_width', help='size of beam', type=int, default=7) 209 | @click.option('--path2ranker', help='name of the ranker model', type=str) 210 | def describe(path2vectorizer, path2checkpoint, path2image, path2vocabulary, beam_width, path2ranker): 211 | device = th.device('cuda:0' if th.cuda.is_available() else 'cpu') 212 | 213 | logger.debug('env variables loading') 214 | logger.debug('features, vocab and token_ids loading') 215 | 216 | if path.isfile(path2checkpoint): 217 | logger.debug('model(snapshot) will be loaded') 218 | net = th.load(path2checkpoint) 219 | net.to(device) 220 | net.eval() 221 | 222 | vocab = deserialize(path2vocabulary) 223 | logger.debug(f'vocab was loaded | len => {len(vocab)}') 224 | 225 | logger.debug(f'load features extractor') 226 | 227 | vectorizer = load_vectorizer(path2vectorizer) 228 | vectorizer.eval() 229 | vectorizer.to(device) 230 | 231 | logger.debug('load ranker clip VIT model') 232 | ranker, processor = load_ranker(path2ranker, device) 233 | 234 | logger.debug('features extraction by resnet152') 235 | 236 | cv_image = read_image(path2image) 237 | th_image = cv2th(cv_image) 238 | th_image = prepare_image(th_image) 239 | 240 | embedding = extract_features(vectorizer, th_image[None, ...].to(device)).squeeze(0) 241 | output_batch = th.flatten(embedding, start_dim=1).T # 49, 2048 242 | 243 | response = beam_search( 244 | model=net, 245 | source=output_batch[None, ...], 246 | BOS=SPECIALS2IDX[''], 247 | EOS=SPECIALS2IDX[''], 248 | max_len=64, 249 | beam_width=beam_width, 250 | device=device, 251 | alpha=0.7 252 | ) 253 | 254 | logger.debug(f'nb generated : {len(response)}') 255 | sentences = [] 256 | for sequence, _ in response: 257 | caption = vocab.lookup_tokens(sequence[1:-1]) # ignore and 258 | joined_caption = ' '.join(caption) 259 | sentences.append(joined_caption) 260 | 261 | logger.debug('ranking will begin...!') 262 | pil_image = cv2pil(cv_image) 263 | ranked_scores = rank_solutions(pil_image, sentences, ranker, processor, device) 264 | ranked_response = list(zip(sentences, 
ranked_scores)) 265 | ranked_response = sorted(ranked_response, key=op.itemgetter(1), reverse=True) 266 | 267 | for caption, score in ranked_response: 268 | score = int(score * 100) 269 | logger.debug(f'caption : {caption} | score : {score:03d}') 270 | 271 | if __name__ == '__main__': 272 | router_command(obj={}) 273 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | import torch.nn as nn 4 | 5 | from core import Transformer, PositionalEncoding 6 | 7 | class CaptionTransformer(nn.Module): 8 | def __init__(self, in_dim, hd_dim, ff_dim, nb_heads, num_encoders, num_decoders, pre_norm, seq_length, nb_tokens, padding_idx): 9 | super(CaptionTransformer, self).__init__() 10 | self.embedding_scale = np.sqrt(hd_dim) 11 | self.position_encoder = PositionalEncoding(seq_length, hd_dim) 12 | self.adaptaror = nn.Linear(in_dim, hd_dim) 13 | self.token_embedder = nn.Embedding(nb_tokens, hd_dim, padding_idx) 14 | self.transformer = Transformer( 15 | in_dim=hd_dim, 16 | ff_dim=ff_dim, 17 | nb_heads=nb_heads, 18 | encoder_depth=num_encoders, 19 | decoder_depth=num_decoders, 20 | pre_norm=pre_norm 21 | ) 22 | self.generator = nn.Linear(hd_dim, nb_tokens) 23 | 24 | 25 | def encode(self, src, src_mask=None, src_key_padding_mask=None): 26 | src = self.adaptaror(src) # reduce the dimension for in_dim to hd_dim 27 | src = self.position_encoder(src) 28 | memory = self.transformer.encoder(src, src_mask, src_key_padding_mask) 29 | return memory 30 | 31 | def decode(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None): 32 | tgt = self.token_embedder(tgt) * self.embedding_scale 33 | tgt = self.position_encoder(tgt) 34 | output = self.transformer.decoder(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask) 35 | return output 36 | 37 | def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None): 38 | src = self.position_encoder(self.adaptaror(src)) 39 | embedded_tgt = self.token_embedder(tgt) 40 | embedded_tgt = self.position_encoder(embedded_tgt) 41 | 42 | output = self.transformer( 43 | src=src, 44 | tgt=embedded_tgt, 45 | src_mask=src_mask, 46 | tgt_mask=tgt_mask, 47 | memory_mask=memory_mask, 48 | src_key_padding_mask=src_key_padding_mask, 49 | tgt_key_padding_mask=tgt_key_padding_mask, 50 | memory_key_padding_mask=memory_key_padding_mask 51 | ) 52 | 53 | return self.generator(output[-1]) 54 | 55 | 56 | -------------------------------------------------------------------------------- /static/cptr_architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/cptr_architecture.jpg -------------------------------------------------------------------------------- /static/mlm_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_000.png -------------------------------------------------------------------------------- /static/mlm_001.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_001.png -------------------------------------------------------------------------------- /static/mlm_002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_002.png -------------------------------------------------------------------------------- /static/mlm_003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_003.png -------------------------------------------------------------------------------- /static/mlm_004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_004.png -------------------------------------------------------------------------------- /static/mlm_005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_005.png --------------------------------------------------------------------------------
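
The processing subcommand expects `source/captions.json` in the layout described in the README above: a hashmap from image file name to a list of caption strings. As a minimal sketch of that layout (the image names and captions below are hypothetical placeholders, not part of the repository), such a file could be written like this:

```python
# minimal sketch of source/captions.json: a mapping image_file_id => [caption, caption, ...]
# the image names and captions are hypothetical placeholders
import json

captions = {
    "0001.jpg": [
        "a man riding a horse on the beach",
        "a person rides a brown horse near the ocean"
    ],
    "0002.jpg": [
        "two dogs playing with a ball in the grass"
    ]
}

# source/ is created by the mkdir step shown in the README
with open("source/captions.json", mode="w") as fp:
    json.dump(captions, fp, indent=2)
```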