├── .dockerignore
├── .gitignore
├── Dockerfile.gpu
├── LICENSE
├── README.md
├── core.py
├── dataset.py
├── libraries
│   ├── log.py
│   └── strategies.py
├── main.py
├── model.py
└── static
    ├── cptr_architecture.jpg
    ├── mlm_000.png
    ├── mlm_001.png
    ├── mlm_002.png
    ├── mlm_003.png
    ├── mlm_004.png
    └── mlm_005.png
/.dockerignore:
--------------------------------------------------------------------------------
1 | arxiv/
2 | models/
3 | source/
4 | target/
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # project's static resources
132 | images/
133 | models/
134 | source/
135 | target/
--------------------------------------------------------------------------------
/Dockerfile.gpu:
--------------------------------------------------------------------------------
1 | # base image derivation
2 | FROM nvcr.io/nvidia/pytorch:21.08-py3
3 |
4 | # timezone handler
5 | ARG DEBIAN_FRONTEND=noninteractive
6 | ENV TZ=Europe/Paris
7 |
8 | # initial system requirements
9 | RUN apt-get update --fix-missing && \
10 | apt-get install --yes --no-install-recommends \
11 | tzdata apt-utils dialog gcc git curl pkg-config build-essential ffmpeg
12 |
13 | # user creation
14 | RUN useradd --gid root --create-home solver
15 | WORKDIR /home/solver
16 |
17 | # virtualenv
18 | ENV VIRTUAL_ENV=/opt/venv
19 | RUN chmod -R g+rwx /home/solver && python -m venv $VIRTUAL_ENV --system-site-packages
20 | ENV PATH="$VIRTUAL_ENV/bin:$PATH"
21 |
22 | # python requirements
23 | RUN pip install --upgrade pip && \
24 | pip install torchtext torchvision spacy pyzmq click loguru sentence_transformers rich pandas && \
25 | pip install ftfy regex git+https://github.com/openai/CLIP.git && \
26 | python -m spacy download en_core_web_sm
27 |
28 | # pull source code
29 | COPY . ./
30 |
31 | # env variables
32 | ENV IMAGES='images/'
33 | ENV SOURCE='source/'
34 | ENV TARGET='target/'
35 | ENV MODELS='models/'
36 |
37 | # entrypoint
38 | ENTRYPOINT ["python", "main.py"]
39 | CMD ["--debug"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Ibraheem Khalil Ba
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # transformer-image-captioning
2 | Implementation of the paper CPTR: Full Transformer Network for Image Captioning.
3 |
4 |
5 | ![CPTR architecture](static/cptr_architecture.jpg)
6 |
7 | *architecture of the CPTR model for image captioning*
8 |
9 |
10 |
11 | ---
12 | ---
13 |
14 | # predictions
15 |
16 | ![prediction example 0](static/mlm_000.png)
17 | ![prediction example 1](static/mlm_001.png)
18 | ![prediction example 2](static/mlm_002.png)
19 | ![prediction example 3](static/mlm_003.png)
20 | ![prediction example 4](static/mlm_004.png)
21 | ![prediction example 5](static/mlm_005.png)
22 |
31 | # prerequisites
32 | * git
33 | * python3
34 | * python3-venv
35 | * docker
36 |
37 | # clone the repo and prepare data
38 | ```bash
39 | # clone
40 | git clone https://github.com/Milkymap/transformer-image-captioning
41 | cd transformer-image-captioning
42 | # prepare data
43 | # models is the directory where resnet152 and clip will be stored
44 | # models also holds the checkpoints saved during training
45 | # images contains a set of image files used at inference time (see the docker describe step)
46 | # source contains the data used for training
47 | # the data is organized as follows:
48 | #   images directory : all the images used for training
49 | #   captions.json    : a hashmap (image_file_id => [text, text, text]), see the example below
50 | # target receives extracted artifacts such as features, token ids and the vocabulary
51 | mkdir models images source target
52 | ```
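
For reference, here is a minimal sketch of the expected `captions.json` layout (the file names and captions below are made up for illustration): each key is an image file name found in `source/images`, and each value is a list of reference captions for that image.

```python
import json

# hypothetical example of the captions.json layout: image_file_id => [text, text, ...]
captions = {
    "0001.jpg": [
        "a man riding a horse on a beach",
        "a person rides a brown horse near the ocean",
    ],
    "0002.jpg": [
        "two dogs playing with a ball in the grass",
    ],
}

# assumes the source directory created above
with open("source/captions.json", mode="w") as fp:
    json.dump(captions, fp, indent=2)
```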
53 |
54 | # docker build
55 | ```bash
56 | docker build -t capformer:0.0 -f Dockerfile.gpu .
57 | ```
58 |
59 | # docker run processing step
60 | ```bash
61 | docker run \
62 |     --rm \
63 |     --tty \
64 |     --name capformer \
65 |     --gpus all \
66 |     -v $(pwd)/source:/home/solver/source \
67 |     -v $(pwd)/models:/home/solver/models \
68 |     -v $(pwd)/target:/home/solver/target \
69 |     -v $(pwd)/images:/home/solver/images \
70 |     -e TERM=xterm-256color \
71 |     capformer:0.0 processing \
72 |     --path2images /home/solver/source/images \
73 |     --path2captions /home/solver/source/captions.json \
74 |     --path2vectorizer /home/solver/models/resnet152.th \
75 |     --extension jpg \
76 |     --path2features /home/solver/target/map_img2features.pkl \
77 |     --path2tokenids /home/solver/target/zip_img2tokenids.pkl \
78 |     --path2vocabulary /home/solver/target/vocabulary.pkl
79 | ```
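
As a quick sanity check (not part of the pipeline), the artifacts written by the processing step can be inspected from a Python shell on the host; the paths below assume the default locations used in the command above, and loading the pickled vocabulary requires `torchtext` to be importable.

```python
import pickle as pk

# artifacts produced by the processing subcommand
with open("target/map_img2features.pkl", "rb") as fp:
    map_img2features = pk.load(fp)  # dict: image_name -> (49, 2048) feature matrix
with open("target/zip_img2tokenids.pkl", "rb") as fp:
    zip_img2tokenids = pk.load(fp)  # list of (image_name, token_id array) pairs
with open("target/vocabulary.pkl", "rb") as fp:
    vocabulary = pk.load(fp)        # torchtext vocabulary built from the captions

name, token_ids = zip_img2tokenids[0]
print(name, map_img2features[name].shape, token_ids.shape, len(vocabulary))
```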
80 |
81 |
82 | # docker run learning step
83 | ```bash
84 | docker run \
85 |     --rm \
86 |     --tty \
87 |     --name capformer \
88 |     --gpus all \
89 |     -v $(pwd)/source:/home/solver/source \
90 |     -v $(pwd)/models:/home/solver/models \
91 |     -v $(pwd)/target:/home/solver/target \
92 |     -v $(pwd)/images:/home/solver/images \
93 |     -e TERM=xterm-256color \
94 |     capformer:0.0 \
95 |     learning \
96 |     --path2features /home/solver/target/map_img2features.pkl \
97 |     --path2tokenids /home/solver/target/zip_img2tokenids.pkl \
98 |     --path2vocabulary /home/solver/target/vocabulary.pkl \
99 |     --nb_epochs 92 \
100 |     --bt_size 128 \
101 |     --path2checkpoint /home/solver/models/checkpoint_128.th \
102 |     --checkpoint 16 \
103 |     --start 0
104 | ```
105 |
106 | # docker run describe step
107 | ```bash
108 | docker run \
109 |     --rm \
110 |     --tty \
111 |     --name capformer \
112 |     --gpus all \
113 |     -v $(pwd)/source:/home/solver/source \
114 |     -v $(pwd)/models:/home/solver/models \
115 |     -v $(pwd)/target:/home/solver/target \
116 |     -v $(pwd)/images:/home/solver/images \
117 |     -e TERM=xterm-256color \
118 |     capformer:0.0 \
119 |     describe \
120 |     --path2vectorizer /home/solver/models/resnet152.th \
121 |     --path2ranker /home/solver/models/ranker.pkl \
122 |     --path2vocabulary /home/solver/target/vocabulary.pkl \
123 |     --path2checkpoint /home/solver/models/checkpoint_128.th \
124 |     --beam_width 17 \
125 |     --path2image /home/solver/images/bob.jpg
126 | ```
127 |
128 | # structure of the project
129 |
130 | This project is built on open-source libraries such as **[pytorch, clip (openai), opencv, PIL]**.
131 | It contains:
132 | * **core.py**
133 |     * the main file of the project
134 |     * contains the definition of the transformer
135 |     * based on the paper Attention Is All You Need
136 |     * adds some modifications for handling the multiple outputs of the decoder
137 | * **dataset.py**
138 |     * contains two classes:
139 |         * DatasetForFeaturesExtraction
140 |         * DatasetForTraining
141 | * **model.py**
142 |     * contains the definition of the CPTR model
143 |     * uses the transformer defined in the core module
144 |     * adds some additional modules such as: token_embedding, prediction_head
145 | * **libraries**
146 |     * contains useful functions such as:
147 |         * log handler
148 |         * tokenization
149 |         * features extraction
150 |         * model loading
151 |         * **beam and greedy search** for caption generation (see the sketch after this list)
152 | * **static**
153 |     * contains images and fonts for the readme
154 | * **main.py**
155 |     * the entrypoint of the program
156 |     * defines three subcommands:
157 |         * processing : features extraction and tokenization
158 |         * learning : training loop of the CPTR model
159 |         * describe : generates a caption from an image path
160 | * **.gitignore**
161 | * **.dockerignore**
162 | * **Dockerfile.gpu**
163 | * **LICENSE**
164 | * **README.md**
165 |
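For reference, outside Docker the `describe` subcommand roughly boils down to the steps below. This is a hedged sketch of the code already present in `main.py` and `libraries/strategies.py`, run in an environment with the same python requirements installed; it assumes the checkpoint, vectorizer, vocabulary and image paths shown in the docker examples above are available on disk, and it skips the CLIP-based ranking step for brevity.

```python
import torch as th

from libraries.strategies import (
    SPECIALS2IDX, beam_search, cv2th, deserialize,
    extract_features, load_vectorizer, prepare_image, read_image,
)

device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')

# load the trained captioner, the resnet152 feature extractor and the vocabulary
net = th.load('models/checkpoint_128.th').to(device).eval()
vectorizer = load_vectorizer('models/resnet152.th').to(device).eval()
vocabulary = deserialize('target/vocabulary.pkl')

# image -> (49, 2048) patch features, exactly as in the processing step
th_image = prepare_image(cv2th(read_image('images/bob.jpg')))
embedding = extract_features(vectorizer, th_image[None, ...].to(device)).squeeze(0)
source = th.flatten(embedding, start_dim=1).T[None, ...]  # (1, 49, 2048)

# beam search over the averaged decoder outputs
response = beam_search(
    model=net, source=source,
    BOS=SPECIALS2IDX['<bos>'], EOS=SPECIALS2IDX['<eos>'],
    max_len=64, device=device, beam_width=7, alpha=0.7,
)
for sequence, score in response:
    print(' '.join(vocabulary.lookup_tokens(sequence[1:-1])), float(score))
```
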
166 | # Citations
167 |
168 | ```bibtex
169 | @misc{Liu2021cptr,
170 |   title = {CPTR: Full Transformer Network for Image Captioning},
171 |   author = {Wei Liu and Sihan Chen and Longteng Guo and Xinxin Zhu and Jing Liu},
172 |   year = {2021},
173 |   eprint = {2101.10804},
174 |   archivePrefix = {arXiv},
175 |   primaryClass = {cs.CV}
176 | }
177 | ```
178 |
--------------------------------------------------------------------------------
/core.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import operator as op
3 | import itertools as it, functools as ft
4 |
5 | import torch as th
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | class PositionalEncoding(nn.Module):
10 | def __init__(self, seq_length, in_dim, drop_val=0.1):
11 | super(PositionalEncoding, self).__init__()
12 | pos = np.arange(0, seq_length)[:, None]
13 | idx = np.fromfunction(lambda _,j: j - j % 2, shape=(1, in_dim))
14 | mask = np.fromfunction(lambda _,j: j % 2 == 0, shape=(1, in_dim))
15 |
16 | pnt = pos / (10000 ** (idx / in_dim))
17 | val = np.sin(pnt) * mask + np.cos(pnt) * (1 - mask)
18 |
19 | self.drop_layer = nn.Dropout(drop_val)
20 | self.register_buffer('psne_layer', th.tensor(val).float())
21 |
22 | def forward(self, src):
23 | _, seq_length, _ = src.shape
24 | pos = self.psne_layer[:seq_length, :][None, ...]
25 | return self.drop_layer(src + pos)
26 |
27 | class FeedForwardNetwork(nn.Module):
28 | __THETA = { # map id to non_linear
29 | 0: nn.Identity(),
30 | 1: nn.ReLU(),
31 | 2: nn.GELU(),
32 | 3: nn.Sigmoid(),
33 | 4: nn.Tanh(),
34 | 5: nn.Softmax(dim=-1)
35 | }
36 | def __init__(self, layer_cfg, activations, drop_vals):
37 | super(FeedForwardNetwork, self).__init__()
38 | self.shapes = list(zip(layer_cfg[:-1], layer_cfg[1:]))
39 | self.linears = nn.ModuleList([])
40 | for idx, (in_dim, out_dim) in enumerate(self.shapes):
41 | fn_id = activations[idx]
42 | proba = drop_vals[idx]
43 | block = nn.Sequential(
44 | nn.Linear(in_dim, out_dim),
45 | nn.Dropout(proba) if proba > 0.0 else nn.Identity(),
46 | FeedForwardNetwork.__THETA.get(fn_id, nn.Identity())
47 | )
48 | self.linears.append(block)
49 |
50 | def forward(self, input_batch):
51 | output_batch = ft.reduce( # functools
52 | lambda acc, crr: crr(acc),
53 | self.linears,
54 | input_batch
55 | )
56 | return output_batch
57 |
58 | class MultiHeadCrossAttention(nn.Module):
59 | def __init__(self, in_dim, nb_heads):
60 | super(MultiHeadCrossAttention, self).__init__()
61 | self.nbr_heads = nb_heads
62 | self.heads_dim = in_dim // nb_heads
63 |
64 | self.to_qry = nn.Linear(in_dim, in_dim)
65 | self.to_key = nn.Linear(in_dim, in_dim)
66 | self.to_val = nn.Linear(in_dim, in_dim)
67 | self.to_out = nn.Linear(in_dim, in_dim)
68 |
69 | def __rearrange(self, seq):
70 | bt_size, seq_length, _ = seq.shape # unpack shape
71 | seq = seq.reshape(bt_size, seq_length, self.nbr_heads, self.heads_dim).permute(0, 2, 1, 3)
72 | return seq
73 |
74 | def forward(self, qry, key, val, mask=None, key_padding_mask=None):
75 |
76 | qry = self.to_qry(qry)
77 | key = self.to_key(key)
78 | val = self.to_val(val)
79 |
80 | qry = self.__rearrange(qry)
81 | key = self.__rearrange(key)
82 | val = self.__rearrange(val)
83 |
84 | dim = qry.shape[-1]
85 |
86 | wgt = qry @ key.transpose(-2, -1)
87 | wgt = wgt / np.sqrt(dim)
88 | if mask is not None:
89 | wgt = wgt.masked_fill(mask, float('-inf'))
90 | if key_padding_mask is not None:
91 | cnd = key_padding_mask[:, None, None, :]
92 | wgt = wgt.masked_fill(cnd, float('-inf'))
93 | wgt = th.softmax(wgt, dim=-1)
94 |
95 | res = wgt @ val
96 | res = res.permute(0, 2, 1, 3) # permute head and sequence
97 | res = th.flatten(res, start_dim=2) # concat over heads
98 | res = self.to_out(res)
99 |
100 | return res
101 |
102 | class MultiHeadSelfAttention(nn.Module):
103 | def __init__(self, in_dim, nb_heads):
104 | super(MultiHeadSelfAttention, self).__init__()
105 |
106 | self.nbr_heads = nb_heads
107 | self.heads_dim = in_dim // nb_heads
108 | self.qkv_layer = nn.Linear(in_dim, 3 * in_dim)
109 | self.out_layer = nn.Linear(in_dim, in_dim)
110 |
111 | def forward(self, src, mask=None, key_padding_mask=None):
112 | bt_size, seq_length, _ = src.shape # unpack shape
113 |
114 | qkv = self.qkv_layer(src) # extract query, key and value
115 | qkv = qkv.reshape(bt_size, seq_length, self.nbr_heads, 3 * self.heads_dim)
116 | qkv = qkv.permute(0, 2, 1, 3) # permute head and sequence
117 | qry, key, val = th.chunk(qkv, 3, dim=-1)
118 |
119 | dim = qry.shape[-1]
120 | wgt = qry @ key.transpose(-2, -1) # hidden_dim and sequence_dim
121 | wgt = wgt / np.sqrt(dim) # normalize
122 | if mask is not None:
123 | wgt = wgt.masked_fill(mask, float('-inf'))
124 | if key_padding_mask is not None:
125 | cnd = key_padding_mask[:, None, None, :]
126 | wgt = wgt.masked_fill(cnd, float('-inf'))
127 | wgt = th.softmax(wgt, dim=-1)
128 |
129 | res = wgt @ val
130 | res = res.permute(0, 2, 1, 3) # permute head and sequence
131 | res = th.flatten(res, start_dim=2) # concat over heads
132 | res = self.out_layer(res)
133 |
134 | return res
135 |
136 | class EncoderBlock(nn.Module):
137 | def __init__(self, in_dim, ff_dim, nb_heads, drop_val=0.1, pre_norm=False):
138 | super(EncoderBlock, self).__init__()
139 | assert in_dim % nb_heads == 0
140 |
141 | self.nbr_heads = nb_heads
142 | self.heads_dim = in_dim // nb_heads
143 |
144 | self.mha_layer = MultiHeadSelfAttention(in_dim, nb_heads)
145 | self.ffn_layer = FeedForwardNetwork([in_dim, ff_dim, in_dim], [1, 0], [drop_val, 0.0])
146 |
147 | self.dropout_layer = nn.ModuleDict({
148 | 'mha': nn.Dropout(drop_val),
149 | 'ffn': nn.Dropout(drop_val)
150 | })
151 | self.layer_normalz = nn.ModuleDict({
152 | 'mha': nn.ModuleList([
153 | nn.LayerNorm(in_dim) if pre_norm else nn.Identity(),
154 | nn.LayerNorm(in_dim) if not pre_norm else nn.Identity()
155 | ]),
156 | 'ffn': nn.ModuleList([
157 | nn.LayerNorm(in_dim) if pre_norm else nn.Identity(),
158 | nn.LayerNorm(in_dim) if not pre_norm else nn.Identity()
159 | ])
160 | })
161 |
162 |
163 | def forward(self, src, src_mask=None, src_key_padding_mask=None):
164 | # multi head self attention
165 | tmp = self.layer_normalz['mha'][0](src)
166 | out = self.mha_layer(tmp, src_mask, src_key_padding_mask)
167 | out = self.dropout_layer['mha'](out)
168 | agg = tmp + out
169 | agg = self.layer_normalz['mha'][1](agg)
170 |
171 | # feed forward network
172 | tmp = self.layer_normalz['ffn'][0](agg)
173 | out = self.ffn_layer(tmp)
174 | out = self.dropout_layer['ffn'](out)
175 | agg = tmp + out
176 | agg = self.layer_normalz['ffn'][1](agg)
177 |
178 | return agg
179 |
180 | class DecoderBlock(nn.Module):
181 | def __init__(self, in_dim, ff_dim, nb_heads, drop_val=0.1, pre_norm=False):
182 | super(DecoderBlock, self).__init__()
183 | assert in_dim % nb_heads == 0
184 |
185 | self.nbr_heads = nb_heads
186 | self.heads_dim = in_dim // nb_heads
187 |
188 | self.mha_layer = MultiHeadSelfAttention(in_dim, nb_heads)
189 | self.crx_layer = MultiHeadCrossAttention(in_dim, nb_heads)
190 | self.ffn_layer = FeedForwardNetwork([in_dim, ff_dim, in_dim], [1, 0], [drop_val, 0.0])
191 |
192 | self.dropout_layer = nn.ModuleDict({
193 | 'mha': nn.Dropout(drop_val),
194 | 'crx': nn.Dropout(drop_val),
195 | 'ffn': nn.Dropout(drop_val)
196 | })
197 | self.layer_normalz = nn.ModuleDict({
198 | 'mha': nn.ModuleList([
199 | nn.LayerNorm(in_dim) if pre_norm else nn.Identity(),
200 | nn.LayerNorm(in_dim) if not pre_norm else nn.Identity()
201 | ]),
202 | 'crx': nn.ModuleList([
203 | nn.LayerNorm(in_dim) if pre_norm else nn.Identity(),
204 | nn.LayerNorm(in_dim) if not pre_norm else nn.Identity()
205 | ]),
206 | 'ffn': nn.ModuleList([
207 | nn.LayerNorm(in_dim) if pre_norm else nn.Identity(),
208 | nn.LayerNorm(in_dim) if not pre_norm else nn.Identity()
209 | ])
210 | })
211 |
212 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
213 | # masked multi head attention
214 | tmp = self.layer_normalz['mha'][0](tgt)
215 | out = self.mha_layer(tmp, tgt_mask, tgt_key_padding_mask)
216 | out = self.dropout_layer['mha'](out)
217 | agg = tmp + out # residual
218 | agg = self.layer_normalz['mha'][1](agg)
219 |
220 | # cross multi head attention
221 | tmp = self.layer_normalz['crx'][0](agg)
222 | out = self.crx_layer(tmp, memory, memory, memory_mask, memory_key_padding_mask)
223 | out = self.dropout_layer['crx'](out)
224 | agg = tmp + out # residual
225 | agg = self.layer_normalz['crx'][1](agg)
226 |
227 | # feed forward network
228 | tmp = self.layer_normalz['ffn'][0](agg)
229 |         out = self.ffn_layer(tmp)
230 | out = self.dropout_layer['ffn'](out)
231 | agg = tmp + out # residual
232 | agg = self.layer_normalz['ffn'][1](agg)
233 |
234 | return agg
235 |
236 | class TransformerEncoder(nn.Module):
237 | def __init__(self, nb_layers, in_dim, ff_dim, nb_heads, drop_val=0.1, pre_norm=False):
238 | super(TransformerEncoder, self).__init__()
239 | self.encoders = nn.ModuleList([])
240 | for _ in range(nb_layers):
241 | blk = EncoderBlock(in_dim=in_dim, ff_dim=ff_dim, nb_heads=nb_heads, drop_val=drop_val, pre_norm=pre_norm)
242 | self.encoders.append(blk)
243 |
244 | def forward(self, src, mask=None, key_padding_mask=None):
245 | fnl = ft.reduce(
246 | lambda acc, crr: acc + [crr(acc[-1], mask, key_padding_mask)],
247 | self.encoders,
248 | [src]
249 | )
250 | return fnl[1:] # ignore src
251 |
252 | class TransformerDecoder(nn.Module):
253 | def __init__(self, nb_layers, in_dim, ff_dim, nb_heads, drop_val=0.1, pre_norm=False):
254 | super(TransformerDecoder, self).__init__()
255 | self.decoders = nn.ModuleList([])
256 | for _ in range(nb_layers):
257 | blk = DecoderBlock(in_dim=in_dim, ff_dim=ff_dim, nb_heads=nb_heads, drop_val=drop_val, pre_norm=pre_norm)
258 | self.decoders.append(blk)
259 |
260 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
261 |         fnl = ft.reduce( # fold each decoder block over the previous output
262 |             lambda acc, crr: acc + [crr(acc[-1], memory[-1], tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)],
263 |             self.decoders,
264 |             [tgt]
266 | )
267 | return fnl[1:] # ignore tgt
268 |
269 | class Transformer(nn.Module):
270 | def __init__(self, in_dim, ff_dim, nb_heads, encoder_depth, decoder_depth, drop_val=0.1, pre_norm=False):
271 | super(Transformer, self).__init__()
272 | self.encoder = TransformerEncoder(
273 | nb_layers=encoder_depth,
274 | in_dim=in_dim,
275 | ff_dim=ff_dim,
276 | nb_heads=nb_heads,
277 | drop_val=drop_val,
278 | pre_norm=pre_norm
279 | )
280 | self.decoder = TransformerDecoder(
281 | nb_layers=decoder_depth,
282 | in_dim=in_dim,
283 | ff_dim=ff_dim,
284 | nb_heads=nb_heads,
285 | drop_val=drop_val,
286 | pre_norm=pre_norm
287 | )
288 |
289 | def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
290 | memory = self.encoder(src, src_mask, src_key_padding_mask)
291 | output = self.decoder(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)
292 | return output
293 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 | from os import path
4 | from torch.utils.data import Dataset
5 | from libraries.strategies import pull_images, read_image, prepare_image, deserialize, cv2th
6 |
7 | class DatasetForFeaturesExtraction(Dataset):
8 | def __init__(self, path2images, file_extension='*.jpg'):
9 | self.image_paths = pull_images(path2images, exts=file_extension)
10 | self.image_names = [ path.split(path_)[1] for path_ in self.image_paths ]
11 |
12 | def __len__(self):
13 | return len(self.image_paths)
14 |
15 | def __getitem__(self, index):
16 | path2img = self.image_paths[index]
17 | cv_image = read_image(path2img)
18 | th_image = cv2th(cv_image)
19 | return prepare_image(th_image)
20 |
21 | class DatasetForTraining(Dataset):
22 | def __init__(self, path2tokenids, path2features):
23 | self.tokenids = deserialize(path2tokenids)
24 | self.features = deserialize(path2features)
25 | def __len__(self):
26 | return len(self.tokenids)
27 |
28 | def __getitem__(self, idx):
29 | file_name, ids = self.tokenids[idx]
30 | vec = self.features[file_name]
31 | vec = th.tensor(vec).float()
32 | ids = th.tensor(ids).long()
33 | return vec, ids
--------------------------------------------------------------------------------
/libraries/log.py:
--------------------------------------------------------------------------------
1 | from sys import stdout
2 | from loguru import logger
3 |
4 |
5 | log_format = [
6 | '{time:YYYY-MM-DD hh:mm:ss}',
7 | '{file:^15}',
8 | '{line:03d}',
9 | '{level:^10}',
10 | '{message:<50}'
11 | ]
12 |
13 | log_separator = ' | '
14 |
15 | logger.remove()
16 | logger.add(
17 | sink=stdout,
18 | level='TRACE',
19 | format=log_separator.join(log_format)
20 | )
21 |
22 | if __name__ == '__main__':
23 | logger.debug('check if log is available...!')
--------------------------------------------------------------------------------
/libraries/strategies.py:
--------------------------------------------------------------------------------
1 | import io
2 | import json
3 | import pickle as pk
5 |
6 | import clip
7 |
8 | import cv2
9 | import numpy as np
10 | import torch as th
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | import operator as op
15 | import itertools as it, functools as ft
16 |
17 | from os import path
18 | from glob import glob
19 |
20 | from PIL import Image
21 | from rich.progress import track
22 | from torch.nn.utils.rnn import pad_sequence
23 | from torchvision import models
24 | from torchvision import transforms as T
25 | from torchtext.data.utils import get_tokenizer
26 | from torchtext.vocab import build_vocab_from_iterator
27 |
28 | from libraries.log import logger
29 |
30 | SPECIALS2IDX = {"<pad>": 0, "<unk>": 1, "<bos>": 2, "<eos>": 3}
31 |
32 | def pull_files(endpoint, extension):
33 | file_paths = sorted(glob(path.join(endpoint, extension)))
34 | return file_paths
35 |
36 | def build_tokenizer(tok_name='spacy', lang='en_core_web_sm'):
37 | tokenizer = get_tokenizer(tokenizer=tok_name, language=lang)
38 | return tokenizer
39 |
40 | def yield_tokens(data_iter, tokenizer):
41 | for sample in track(data_iter, description=f'tokenization process'):
42 | sample = sample.strip().lower() # remove trailing keys and lowercase
43 | yield tokenizer(sample)
44 |
45 | def make_vocab(data_iter, tokenizer, map_specials2idx):
46 | vocab = build_vocab_from_iterator(
47 | iterator=yield_tokens(data_iter, tokenizer),
48 | specials=list(map_specials2idx.keys()),
49 | min_freq=1,
50 | special_first=True
51 | )
52 |     vocab.set_default_index(map_specials2idx['<unk>']) # fall back to the index of the <unk> token
53 | return vocab
54 |
55 | def serialize(path2dump, data):
56 | with open(path2dump, mode='wb') as fp:
57 | pk.dump(data, fp)
58 |
59 | def deserialize(path2dump):
60 | with open(path2dump, mode='rb') as fp:
61 | return pk.load(fp)
62 |
70 |
71 | def load_vectorizer(path2models):
72 | if path.isfile(path2models):
73 | features_extractor = th.load(path2models)
74 | else:
75 | model_name = path.split(path2models)[1]
76 | real_name, _ = model_name.split('.')
77 |         endpoint = getattr(models, real_name, None) # None when the name is not a torchvision model
78 | if endpoint is not None:
79 | features_extractor = endpoint(pretrained=True, progress=True)
80 | features_extractor = nn.Sequential(*list(features_extractor.children())[:-2])
81 | for prm in features_extractor.parameters():
82 | prm.requires_grad = False
83 | th.save(features_extractor, path2models)
84 | else:
85 |             raise ValueError(f'{real_name} is not a valid option for torchvision.models')
86 | return features_extractor
87 |
88 | def load_ranker(path2models, device):
89 | if path.isfile(path2models):
90 | with open(path2models, 'rb') as fp:
91 | model, processor = pk.load(fp) # (model, processor)
92 | else:
93 | model, processor = clip.load("ViT-B/32")
94 | with open(path2models, 'wb') as fp:
95 | pk.dump((model, processor), fp)
96 | return model.to(device), processor
97 |
98 | def pull_images(path2images, exts='*.jpg'):
99 | return sorted( glob(path.join(path2images, '**' ,exts), recursive=True) )
100 |
101 | def th2cv(th_image):
102 | red, green, blue = th_image.numpy()
103 | return cv2.merge((blue, green, red))
104 |
105 | def cv2th(cv_image):
106 | blue, green, red = cv2.split(cv_image)
107 | return th.as_tensor(np.stack([red, green, blue]))
108 |
109 | def cv2pil(cv_image):
110 | return Image.fromarray(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
111 |
112 | def pil2cv(pil_image):
113 | return cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)
114 |
115 | def read_image(path2image, size=None):
116 | pl_image = Image.open(path2image).convert('RGB')
117 | cv_image = cv2.cvtColor(np.array(pl_image), cv2.COLOR_RGB2BGR)
118 | if size is not None:
119 | return cv2.resize(cv_image, size, interpolation=cv2.INTER_CUBIC)
120 | return cv_image
121 |
122 | def save_image(cv_image, path2location):
123 | cv2.imwrite(path2location, cv_image)
124 |
125 | def prepare_image(th_image):
126 |     normalized_th_image = th_image / th.max(th_image) # scale to [0, 1]
127 | return T.Compose([
128 | T.RandomHorizontalFlip(p=0.5),
129 | T.Resize((256, 256)),
130 | T.CenterCrop((224, 224)),
131 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
132 |     ])(normalized_th_image)
133 |
134 | def extract_features(extractor, batch_of_images):
135 | with th.no_grad():
136 | features = extractor(batch_of_images)
137 | return features
138 |
139 | def rank_solutions(pil_image, sentences, ranker, processor, device):
140 | image = processor(pil_image).unsqueeze(0).to(device)
141 | tokens = clip.tokenize(sentences).to(device)
142 | with th.no_grad():
143 | logits, _ = ranker(image, tokens)
144 | probabilities = th.softmax(logits, dim=1).cpu().squeeze(0)
145 | lowest = th.min(probabilities)
146 | largest = th.max(probabilities)
147 | normalized_scores = (probabilities - lowest) / (largest - lowest)
148 | return normalized_scores.tolist()
149 |
150 | def custom_fn(batch):
151 | features, token_ids = list(zip(*batch))
152 | features = th.stack(features) # 3d
153 | token_ids = list(token_ids)
154 |     token_ids = pad_sequence(token_ids, batch_first=True, padding_value=SPECIALS2IDX['<pad>'])
155 | return features, token_ids
156 |
157 | def build_mask(seq):
158 | seq_length = seq.shape[1]
159 | mask = np.fromfunction(lambda i,j: j > i, shape=(seq_length, seq_length))
160 | return th.as_tensor(mask)
161 |
162 | def build_key_padding_mask(seq, pad_idx):
163 | seq_key_padding_mask = (seq == pad_idx)
164 | return seq_key_padding_mask
165 |
166 | def greedy_search(model, source, BOS, EOS, max_len, device):
167 | memory = model.encode(source.to(device))
168 | target = th.tensor([[BOS]])
169 | keep_generating = True
170 | while keep_generating:
171 | output = model.decode(target.to(device), memory).squeeze(0)
172 | logits = model.generator(output[-1, :])
173 | scaled_logits = th.log_softmax(logits, dim=-1).squeeze(0)
174 | candidate = th.argmax(scaled_logits)
175 | target = th.cat([target, th.tensor([[candidate]])], dim=1)
176 | keep_generating = (candidate != EOS) and (target.shape[1] < max_len)
177 | return th.flatten(target)
178 |
179 | def beam_search(model, source, BOS, EOS, max_len, device, beam_width, alpha=0.7):
180 | memory = model.encode(source.to(device))
181 | target = th.tensor([[BOS]])
182 | with th.no_grad():
183 | output = model.decode(target.to(device), memory)
184 | output = th.stack([ model.generator(out[0, -1, :]) for out in output ])
185 | logits = th.mean(output, dim=0)
186 |
187 | scaled_logits = th.log_softmax(logits[None, :], dim=1).cpu().squeeze(0) # over vocab size
188 | weights, candidates = th.topk(input=scaled_logits, k=beam_width, largest=True)
189 |
190 | response_tracker = [] # for valid final sequence
191 | sequence_tracker = [] # for current active sequence
192 | for idx in candidates:
193 | option = th.tensor([[idx]]) # a new option into the search tree
194 | sequence = th.cat([target, option], dim=1)
195 | sequence_tracker.append(sequence)
196 |
197 | keep_generating = True
198 | while keep_generating:
199 | input_batch = th.vstack(sequence_tracker)
200 | with th.no_grad():
201 | input_memory = [m.repeat(input_batch.shape[0], 1, 1) for m in memory ]
202 | output = model.decode(input_batch.to(device), input_memory)
203 | logits = th.mean(th.stack([ model.generator(out[:, -1, :]) for out in output ]), dim=0)
204 |
205 | scaled_logits = th.log_softmax(logits, dim=1).cpu()
206 |
207 | length = input_batch.shape[1]
208 | vocab_size = scaled_logits.shape[1]
209 | weighted_logits = (scaled_logits + weights[:, None]) / length ** alpha
210 | weights, candidates = th.topk(th.flatten(weighted_logits), k=beam_width, largest=True)
211 | weights = weights * length ** alpha # denormalize
212 |
213 | weights_tmp = []
214 | sequence_tmp = []
215 | for idx, pos in enumerate(candidates):
216 | row = th.div(pos, vocab_size, rounding_mode='floor') # get relative position over nb_sequences
217 | col = pos % vocab_size # get relative position over vocab_size
218 | sequence = th.cat([sequence_tracker[row], th.tensor([[col]])], dim=1)
219 | if col == EOS:
220 | logger.success('a sentence was generated :)')
221 | flattened_sequence = th.flatten(sequence).tolist()
222 | sequence_score = weights[idx] / len(flattened_sequence) ** alpha
223 | response_tracker.append((flattened_sequence, sequence_score)) # a sentence was built
224 | if len(response_tracker) == beam_width:
225 | keep_generating = False
226 | break # end the for loop over candidates
227 | elif sequence.shape[1] < max_len - 1:
228 | weights_tmp.append(weights[row])
229 | sequence_tmp.append(sequence)
230 | # end for loop over candidates ...!
231 |
232 | if len(sequence_tmp) == 0:
233 | keep_generating = False
234 | else:
235 | weights = th.tensor(weights_tmp)
236 | sequence_tracker = sequence_tmp
237 | # end while search loop ...!
238 | return response_tracker
239 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import click
3 |
4 | from torch.utils.data import DataLoader
5 |
6 | from os import getenv, path
7 | from time import sleep
8 | from rich.progress import track
9 | from dataset import DatasetForFeaturesExtraction, DatasetForTraining
10 | from libraries.log import logger
11 | from libraries.strategies import *
12 |
13 | from model import CaptionTransformer
14 |
15 |
16 | @click.group(chain=False, invoke_without_command=True)
17 | @click.option('--debug/--no-debug', help='debug mode flag', default=True)
18 | @click.pass_context
19 | def router_command(ctx, debug):
20 | ctx.ensure_object(dict)
21 |
22 | models = getenv('MODELS')
23 | source = getenv('SOURCE')
24 | target = getenv('TARGET')
25 | images = getenv('IMAGES')
26 |
27 | assert models is not None and path.isdir(models)
28 | assert source is not None and path.isdir(source)
29 | assert target is not None and path.isdir(target)
30 | assert images is not None and path.isdir(images)
31 |
32 | ctx.obj['debug'] = debug
33 | command = ctx.invoked_subcommand
34 | if command is None:
35 |         logger.debug('no command was called, add the --help option to see the available commands')
36 | else:
37 | logger.debug(f'{command} was called')
38 |
39 | @router_command.command()
40 | @click.option('--path2vectorizer', help='path to models for features extraction', type=click.Path(False))
41 | @click.option('--path2images', help='path to images directory', type=click.Path(True))
42 | @click.option('--path2captions', help='path to captions json file', type=click.Path(True))
43 | @click.option('--extension', help='image file extension', type=click.Choice(['jpg', 'jpeg']))
44 | @click.option('--path2features', help='path to features dump location', type=click.Path(False))
45 | @click.option('--path2tokenids', help='path to tokenids dump location', type=click.Path(False))
46 | @click.option('--path2vocabulary', help='path to vocabulary dump location', type=click.Path(False))
47 | def processing(path2vectorizer, path2images, path2captions, extension, path2features, path2tokenids, path2vocabulary):
48 | device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')
49 |
50 | with open(file=path2captions, mode='r') as fp:
51 | img2captions = json.load(fp)
52 |
53 | captions = list(img2captions.values())
54 | captions = list(it.chain(*captions))
55 |
56 | tokenizer = build_tokenizer(tok_name='spacy', lang='en_core_web_sm')
57 | vocabulary = make_vocab(captions, tokenizer, SPECIALS2IDX)
58 |     logger.success('vocabulary was built')
59 |
60 | serialize(path2vocabulary, vocabulary)
61 |
62 |     bos = th.tensor([SPECIALS2IDX['<bos>']])
63 |     eos = th.tensor([SPECIALS2IDX['<eos>']])
64 |
65 | zip_img2tokenids = []
66 | logger.debug('caption tokenization')
67 | for key, val in track(img2captions.items(), 'build map_img2tokenids'):
68 | for cap in val:
69 | tok = tokenizer(cap.strip().lower())
70 | idx = th.tensor(vocabulary(tok))
71 | seq = th.cat([bos, idx, eos]).numpy() # more effective for storage
72 | zip_img2tokenids.append((key, seq))
73 |
74 | serialize(path2tokenids, zip_img2tokenids)
75 |
76 | logger.debug('features extraction loading')
77 | vectorizer = load_vectorizer(path2vectorizer)
78 | vectorizer.eval()
79 | vectorizer.to(device)
80 |
81 | dataset = DatasetForFeaturesExtraction(path2images, f'*.{extension}')
82 |
83 | logger.debug('extraction will start')
84 | accumulator = []
85 | for sections in track(dataset, 'features extraction'):
86 | embedding = extract_features(vectorizer, sections[None, ...].to(device)).squeeze(0) # (2048, 7, 7)
87 | embedding = th.flatten(embedding, start_dim=1).T.cpu().numpy() # 49, 2048
88 | accumulator.append(embedding)
89 |
90 | image_names = dataset.image_names
91 |     accumulator = np.stack(accumulator) # stack over batch axis ==> (nb_images, 49, 2048)
92 | logger.debug(f'accumulated features shape : {accumulator.shape}')
93 | assert len(image_names) == len(accumulator)
94 | map_img2features = dict(zip(image_names, accumulator))
95 |
96 | serialize(path2features, map_img2features)
97 |
98 | logger.success('features, tokenids and vocabulary were saved')
99 |
100 | @router_command.command()
101 | @click.option('--path2vocabulary', help='path to vocabulary dump location', type=click.Path(True))
102 | @click.option('--path2features', help='path to features dump location', type=click.Path(True))
103 | @click.option('--path2tokenids', help='path to tokenids dump location', type=click.Path(True))
104 | @click.option('--nb_epochs', help='number of epochs', type=int, default=128)
105 | @click.option('--bt_size', help='batch size', type=int, default=32)
106 | @click.option('--path2checkpoint', help='path to checkpoint model', type=click.Path(False))
107 | @click.option('--checkpoint', help='checkpoint period(save model)', type=int, default=16)
108 | @click.option('--start', help='start epoch index', type=int, default=0)
109 | def learning(path2vocabulary, path2features, path2tokenids, nb_epochs, bt_size, path2checkpoint, checkpoint, start):
110 | basepath2models = getenv('MODELS')
111 |
112 | device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')
113 |
114 | logger.debug('load vocabulary')
115 | vocabulary = deserialize(path2vocabulary)
116 | nb_tokens = len(vocabulary)
117 |
118 | logger.debug('build dataset')
119 | dataset = DatasetForTraining(path2tokenids, path2features)
120 | logger.debug(f'size of the dataset : {len(dataset):05d}')
121 | dataloader = DataLoader(dataset, batch_size=bt_size, shuffle=True, collate_fn=custom_fn)
122 | nb_data = len(dataset)
123 |
124 | logger.debug('define network')
125 | if path.isfile(path2checkpoint):
126 | net = th.load(path2checkpoint)
127 | else:
128 | net = CaptionTransformer(
129 | in_dim=2048,
130 | hd_dim=256,
131 | ff_dim=512,
132 | nb_heads=8,
133 | num_encoders=5,
134 | num_decoders=5,
135 | pre_norm=False,
136 | seq_length=64,
137 | nb_tokens=nb_tokens,
138 |             padding_idx=SPECIALS2IDX['<pad>']
139 | )
140 |
141 | net.to(device)
142 | net.train()
143 |
144 | print(net)
145 |
146 | optimizer = th.optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.99), eps=1e-9)
147 |     criterion = nn.CrossEntropyLoss(ignore_index=SPECIALS2IDX['<pad>'])
148 | logger.debug('training will begin ...!')
149 | sleep(1)
150 |
151 | nb_epochs += start
152 | for epoch in range(start, nb_epochs):
153 | counter = 0
154 | for src, tgt in dataloader:
155 | counter += len(tgt)
156 | tgt_input = tgt[:, :-1]
157 | tgt_output = tgt[:, 1:]
158 |
159 | tgt_mask = build_mask(tgt_input).to(device)
160 |             tgt_key_padding_mask = build_key_padding_mask(tgt_input, SPECIALS2IDX['<pad>']).to(device)
161 |
162 | memory = net.encode(src=src.to(device))
163 | output = net.decode(
164 | tgt=tgt_input.to(device),
165 | memory=memory,
166 | tgt_mask=tgt_mask,
167 | tgt_key_padding_mask=tgt_key_padding_mask
168 | )
169 |
170 | logits = [net.generator(out) for out in output ]
171 | logits = [ th.flatten(prb, start_dim=0, end_dim=1) for prb in logits ]
172 | tgt_output = th.flatten(tgt_output)
173 |
174 | optimizer.zero_grad()
175 | errors = [ criterion(prb, tgt_output.to(device)) for prb in logits ]
176 | error = sum(errors)
177 | error.backward()
178 | optimizer.step()
179 |
180 | message = []
181 | for err in errors:
182 | msg = f'{err.cpu().item():07.3f}'
183 | message.append(msg)
184 | message = ' | '.join(message)
185 | logger.debug(f'[{epoch:03d}/{nb_epochs:03d}] [{counter:05d}/{nb_data:05d}] | Loss : {error.cpu().item():07.3f} >> {message}')
186 | # end for loop over batchs
187 |
188 | if epoch % checkpoint == 0:
189 | path2network = path.join(basepath2models, f'checkpoint_{epoch:03d}.th')
190 | th.save(net.cpu(), path2network)
191 | net.to(device)
192 | logger.success(f'a snapshot was saved {path2network}')
193 |
194 | # end for loop over epochs
195 |
196 | path2network = path.join(basepath2models, f'checkpoint_###.th')
197 | th.save(net.cpu(), path2network)
198 | logger.success(f'a snapshot was saved {path2network}')
199 | logger.success('end of training')
200 |
201 |
202 |
203 | @router_command.command()
204 | @click.option('--path2vectorizer', help='name of the stored model(features extractor)', type=str)
205 | @click.option('--path2checkpoint', help='model snapshot filename', type=str)
206 | @click.option('--path2image', help='image to describe', type=str)
207 | @click.option('--path2vocabulary', help='vocabulary object', type=str)
208 | @click.option('--beam_width', help='size of beam', type=int, default=7)
209 | @click.option('--path2ranker', help='name of the ranker model', type=str)
210 | def describe(path2vectorizer, path2checkpoint, path2image, path2vocabulary, beam_width, path2ranker):
211 | device = th.device('cuda:0' if th.cuda.is_available() else 'cpu')
212 |
213 | logger.debug('env variables loading')
214 | logger.debug('features, vocab and token_ids loading')
215 |
216 | if path.isfile(path2checkpoint):
217 | logger.debug('model(snapshot) will be loaded')
218 | net = th.load(path2checkpoint)
219 | net.to(device)
220 | net.eval()
221 |
222 | vocab = deserialize(path2vocabulary)
223 | logger.debug(f'vocab was loaded | len => {len(vocab)}')
224 |
225 | logger.debug(f'load features extractor')
226 |
227 | vectorizer = load_vectorizer(path2vectorizer)
228 | vectorizer.eval()
229 | vectorizer.to(device)
230 |
231 | logger.debug('load ranker clip VIT model')
232 | ranker, processor = load_ranker(path2ranker, device)
233 |
234 | logger.debug('features extraction by resnet152')
235 |
236 | cv_image = read_image(path2image)
237 | th_image = cv2th(cv_image)
238 | th_image = prepare_image(th_image)
239 |
240 | embedding = extract_features(vectorizer, th_image[None, ...].to(device)).squeeze(0)
241 | output_batch = th.flatten(embedding, start_dim=1).T # 49, 2048
242 |
243 | response = beam_search(
244 | model=net,
245 | source=output_batch[None, ...],
246 |         BOS=SPECIALS2IDX['<bos>'],
247 |         EOS=SPECIALS2IDX['<eos>'],
248 | max_len=64,
249 | beam_width=beam_width,
250 | device=device,
251 | alpha=0.7
252 | )
253 |
254 | logger.debug(f'nb generated : {len(response)}')
255 | sentences = []
256 | for sequence, _ in response:
257 |         caption = vocab.lookup_tokens(sequence[1:-1]) # ignore <bos> and <eos>
258 | joined_caption = ' '.join(caption)
259 | sentences.append(joined_caption)
260 |
261 | logger.debug('ranking will begin...!')
262 | pil_image = cv2pil(cv_image)
263 | ranked_scores = rank_solutions(pil_image, sentences, ranker, processor, device)
264 | ranked_response = list(zip(sentences, ranked_scores))
265 | ranked_response = sorted(ranked_response, key=op.itemgetter(1), reverse=True)
266 |
267 | for caption, score in ranked_response:
268 | score = int(score * 100)
269 | logger.debug(f'caption : {caption} | score : {score:03d}')
270 |
271 | if __name__ == '__main__':
272 | router_command(obj={})
273 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 | import torch.nn as nn
4 |
5 | from core import Transformer, PositionalEncoding
6 |
7 | class CaptionTransformer(nn.Module):
8 | def __init__(self, in_dim, hd_dim, ff_dim, nb_heads, num_encoders, num_decoders, pre_norm, seq_length, nb_tokens, padding_idx):
9 | super(CaptionTransformer, self).__init__()
10 | self.embedding_scale = np.sqrt(hd_dim)
11 | self.position_encoder = PositionalEncoding(seq_length, hd_dim)
12 | self.adaptaror = nn.Linear(in_dim, hd_dim)
13 | self.token_embedder = nn.Embedding(nb_tokens, hd_dim, padding_idx)
14 | self.transformer = Transformer(
15 | in_dim=hd_dim,
16 | ff_dim=ff_dim,
17 | nb_heads=nb_heads,
18 | encoder_depth=num_encoders,
19 | decoder_depth=num_decoders,
20 | pre_norm=pre_norm
21 | )
22 | self.generator = nn.Linear(hd_dim, nb_tokens)
23 |
24 |
25 | def encode(self, src, src_mask=None, src_key_padding_mask=None):
26 |         src = self.adaptaror(src) # project the features from in_dim down to hd_dim
27 | src = self.position_encoder(src)
28 | memory = self.transformer.encoder(src, src_mask, src_key_padding_mask)
29 | return memory
30 |
31 | def decode(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
32 | tgt = self.token_embedder(tgt) * self.embedding_scale
33 | tgt = self.position_encoder(tgt)
34 | output = self.transformer.decoder(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)
35 | return output
36 |
37 | def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
38 | src = self.position_encoder(self.adaptaror(src))
39 |         embedded_tgt = self.token_embedder(tgt) * self.embedding_scale # scale as in decode
40 | embedded_tgt = self.position_encoder(embedded_tgt)
41 |
42 | output = self.transformer(
43 | src=src,
44 | tgt=embedded_tgt,
45 | src_mask=src_mask,
46 | tgt_mask=tgt_mask,
47 | memory_mask=memory_mask,
48 | src_key_padding_mask=src_key_padding_mask,
49 | tgt_key_padding_mask=tgt_key_padding_mask,
50 | memory_key_padding_mask=memory_key_padding_mask
51 | )
52 |
53 | return self.generator(output[-1])
54 |
55 |
56 |
--------------------------------------------------------------------------------
/static/cptr_architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/cptr_architecture.jpg
--------------------------------------------------------------------------------
/static/mlm_000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_000.png
--------------------------------------------------------------------------------
/static/mlm_001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_001.png
--------------------------------------------------------------------------------
/static/mlm_002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_002.png
--------------------------------------------------------------------------------
/static/mlm_003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_003.png
--------------------------------------------------------------------------------
/static/mlm_004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_004.png
--------------------------------------------------------------------------------
/static/mlm_005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/milkymap/transformer-image-captioning/6441d684e7b181aefb7c76d3ff548f0995b06d71/static/mlm_005.png
--------------------------------------------------------------------------------