├── .gitignore ├── LICENSE ├── README.md ├── data ├── .gitignore └── label.txt ├── data_loader.py ├── main.py ├── model.py ├── requirements.txt ├── server ├── .gitignore ├── Dockerfile ├── app.py ├── download_models.py ├── label.txt ├── model.py ├── request.sh ├── requirements.txt ├── run_server.sh └── templates │ ├── post.html │ └── result.html ├── swagger.yaml ├── trainer.py ├── utils.py └── vgg_feature.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | ################## 132 | .vscode 133 | .idea 134 | 135 | legacy 136 | 137 | model/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hashtag Prediction with Pytorch 2 | 3 | Multimodal hashtag prediction from instagram 4 | 5 | ## Overview 6 | 7 |
8 |
9 |
36 |
37 |
38 | - Copy the id. It will be used when you send the request.
39 |
40 |
41 |
42 |
43 | ### 3. Request
44 |
45 | ```bash
46 | # URL
47 | localhost:80/predict?image_id=1DGu9R5a9jpkY-fy79VrGFmCdJigzTMC-&text=20%20days%20till%20Christmas%20%F0%9F%98%8D%F0%9F%8E%85&max_seq_len=20&n_label=10
48 | ```
49 |
50 | ## Run on Ainize
51 |
52 | [](https://ainize.web.app/redirect?git_repo=github.com/monologg/hashtag-prediction-pytorch)
53 |
54 |
55 |
56 |
57 | 1. `image_id` : the share id you copied from Google Drive above
58 | 2. `text` : the caption text, as on an Instagram post
59 | 3. `max_seq_len`: maximum sequence length after tokenization
60 | 4. `n_label`: the number of hashtags to predict (see the Python sketch after the example URL below)
61 |
62 | ```bash
63 | https://endpoint.ainize.ai/monologg/hashtag/predict?image_id={image_id}&text={text}&max_seq_len={max_seq_len}&n_label={n_label}
64 | ```
65 |
66 | ```bash
67 | # URL
68 | https://endpoint.ainize.ai/monologg/hashtag/predict?image_id=1DGu9R5a9jpkY-fy79VrGFmCdJigzTMC-&text=20%20days%20till%20Christmas%20%F0%9F%98%8D%F0%9F%8E%85&max_seq_len=20&n_label=10
69 | ```
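A minimal Python client sketch (hypothetical; it assumes the third-party `requests` package is installed and reuses the example values from the URL above):

```python
import requests

params = {
    "image_id": "1DGu9R5a9jpkY-fy79VrGFmCdJigzTMC-",  # Google Drive share id
    "text": "20 days till Christmas 😍🎅",              # caption text
    "max_seq_len": 20,                                  # maximum sequence length
    "n_label": 10,                                      # number of hashtags to predict
}

# The same parameters work against a local server, e.g. "http://localhost:80/predict".
response = requests.get("https://endpoint.ainize.ai/monologg/hashtag/predict", params=params)
print(response.status_code)
print(response.text)  # the endpoint returns a rendered HTML page
```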
70 |
71 | ### Result in HTML
72 |
73 |
74 |
75 |
76 | ## Reference
77 |
78 | - [Huggingface Transformers](https://github.com/huggingface/transformers)
79 | - [ALBERT Paper](https://arxiv.org/abs/1909.11942)
80 |
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *.tsv
2 | *.h5
3 | cached*
--------------------------------------------------------------------------------
/data/label.txt:
--------------------------------------------------------------------------------
1 | [UNK]
2 | adventure
3 | amazing
4 | art
5 | beach
6 | beautiful
7 | beauty
8 | bestoftheday
9 | blue
10 | car
11 | cat
12 | clouds
13 | coffee
14 | cool
15 | cute
16 | dog
17 | explore
18 | f4f
19 | family
20 | fashion
21 | fit
22 | fitness
23 | flowers
24 | follow
25 | follow4follow
26 | followforfollow
27 | followme
28 | food
29 | foodie
30 | foodporn
31 | friends
32 | fun
33 | funny
34 | girl
35 | girls
36 | green
37 | gym
38 | hair
39 | happy
40 | holiday
41 | hot
42 | igers
43 | instadaily
44 | instagood
45 | instagram
46 | instagramers
47 | instalike
48 | instamood
49 | instapic
50 | l4l
51 | landscape
52 | life
53 | lifestyle
54 | like
55 | like4like
56 | likeforlike
57 | lol
58 | love
59 | makeup
60 | me
61 | model
62 | motivation
63 | mountain
64 | mountains
65 | music
66 | nature
67 | naturephotography
68 | night
69 | nofilter
70 | ootd
71 | outfit
72 | party
73 | photo
74 | photographer
75 | photography
76 | photooftheday
77 | picoftheday
78 | pink
79 | pretty
80 | rain
81 | red
82 | sea
83 | selfie
84 | sky
85 | smile
86 | snow
87 | style
88 | summer
89 | sun
90 | sunset
91 | swag
92 | tbt
93 | throwbackthursday
94 | travel
95 | travelphotography
96 | tree
97 | vacation
98 | vsco
99 | vscocam
100 | wanderlust
101 | winter
102 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import json
4 | import logging
5 | import h5py
6 |
7 | import torch
8 | from torch.utils.data import TensorDataset
9 |
10 | from utils import get_label
11 | from vgg_feature import load_vgg_features
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | class InputExample(object):
17 | """
18 | A single training/test example for multimodal hashtag classification.
19 |
20 | Args:
21 | guid: Unique id for the example.
22 | image_num: string. The image number used to look up the image's
23 | VGG feature in the h5 feature file.
24 | text_a: string. The untokenized caption text of the example.
25 | label_lst: list of int. Indices of the example's hashtag labels.
26 | """
27 |
28 | def __init__(self, guid, image_num, text_a, label_lst):
29 | self.guid = guid
30 | self.image_num = image_num
31 | self.text_a = text_a
32 | self.label_lst = label_lst
33 |
34 | def __repr__(self):
35 | return str(self.to_json_string())
36 |
37 | def to_dict(self):
38 | """Serializes this instance to a Python dictionary."""
39 | output = copy.deepcopy(self.__dict__)
40 | return output
41 |
42 | def to_json_string(self):
43 | """Serializes this instance to a JSON string."""
44 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
45 |
46 |
47 | class InputFeatures(object):
48 | """A single set of features of data."""
49 |
50 | def __init__(self, input_ids, attention_mask, token_type_ids, label_id):
51 | self.input_ids = input_ids
52 | self.attention_mask = attention_mask
53 | self.token_type_ids = token_type_ids
54 | self.label_id = label_id
55 |
56 | def __repr__(self):
57 | return str(self.to_json_string())
58 |
59 | def to_dict(self):
60 | """Serializes this instance to a Python dictionary."""
61 | output = copy.deepcopy(self.__dict__)
62 | return output
63 |
64 | def to_json_string(self):
65 | """Serializes this instance to a JSON string."""
66 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
67 |
68 |
69 | class HashtagProcessor(object):
70 | """Processor for the Hashtag data set """
71 |
72 | def __init__(self, args):
73 | self.args = args
74 | self.labels = get_label(args)
75 |
76 | @classmethod
77 | def _read_file(cls, input_file, quotechar=None):
78 | """Reads a tab separated value file."""
79 | with open(input_file, "r", encoding="utf-8") as f:
80 | lines = []
81 | for line in f:
82 | lines.append(line.strip())
83 | return lines
84 |
85 | def _create_examples(self, lines, set_type):
86 | """Creates examples for the training and dev sets."""
87 | examples = []
88 | for (i, line) in enumerate(lines):
89 | line = line.split('\t')
90 | guid = "%s-%s" % (set_type, i)
91 | image_num = line[0]
92 | text_a = line[1]
93 | if self.args.do_lower_case:
94 | text_a = text_a.lower()
95 | tag_lst = line[2].split()
96 | label_num_lst = [self.labels.index(tag) if tag in self.labels else self.labels.index("[UNK]") for tag in tag_lst]
97 | if i % 1000 == 0:
98 | logger.info(line)
99 | examples.append(InputExample(guid=guid, image_num=image_num, text_a=text_a, label_lst=label_num_lst))
100 | return examples
101 |
102 | def get_examples(self, mode):
103 | """
104 | Args:
105 | mode: train, dev, test
106 | """
107 | file_to_read = None
108 | if mode == 'train':
109 | file_to_read = self.args.train_file
110 | elif mode == 'dev':
111 | file_to_read = self.args.dev_file
112 | elif mode == 'test':
113 | file_to_read = self.args.test_file
114 |
115 | logger.info("LOOKING AT {}".format(os.path.join(self.args.data_dir, file_to_read)))
116 | return self._create_examples(self._read_file(os.path.join(self.args.data_dir, file_to_read)), mode)
117 |
118 |
119 | processors = {
120 | "hashtag": HashtagProcessor,
121 | }
122 |
123 |
124 | def convert_examples_to_features(examples, max_seq_len, tokenizer,
125 | label_len,
126 | cls_token_segment_id=0,
127 | pad_token_segment_id=0,
128 | sequence_a_segment_id=0,
129 | mask_padding_with_zero=True):
130 | # Setting based on the current model type
131 | cls_token = tokenizer.cls_token
132 | sep_token = tokenizer.sep_token
133 | pad_token_id = tokenizer.pad_token_id
134 |
135 | features = []
136 | for (ex_index, example) in enumerate(examples):
137 | if ex_index % 5000 == 0:
138 | logger.info("Writing example %d of %d" % (ex_index, len(examples)))
139 |
140 | tokens = tokenizer.tokenize(example.text_a)
141 |
142 | # Account for [CLS] and [SEP]
143 | special_tokens_count = 2
144 | if len(tokens) > max_seq_len - special_tokens_count:
145 | tokens = tokens[:(max_seq_len - special_tokens_count)]
146 |
147 | # Add [SEP] token
148 | tokens += [sep_token]
149 | token_type_ids = [sequence_a_segment_id] * len(tokens)
150 |
151 | # Add [CLS] token
152 | tokens = [cls_token] + tokens
153 | token_type_ids = [cls_token_segment_id] + token_type_ids
154 |
155 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
156 |
157 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
158 | # tokens are attended to.
159 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
160 |
161 | # Zero-pad up to the sequence length.
162 | padding_length = max_seq_len - len(input_ids)
163 | input_ids = input_ids + ([pad_token_id] * padding_length)
164 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
165 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
166 |
167 | assert len(input_ids) == max_seq_len, "Error with input length {} vs {}".format(len(input_ids), max_seq_len)
168 | assert len(attention_mask) == max_seq_len, "Error with attention mask length {} vs {}".format(len(attention_mask), max_seq_len)
169 | assert len(token_type_ids) == max_seq_len, "Error with token type length {} vs {}".format(len(token_type_ids), max_seq_len)
170 |
171 | label_id = [0] * label_len
172 | for label_idx in example.label_lst:
173 | label_id[label_idx] = 1
174 |
175 | if ex_index < 5:
176 | logger.info("*** Example ***")
177 | logger.info("guid: %s" % example.guid)
178 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
179 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
180 | logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
181 | logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
182 | logger.info("label_id: %s " % " ".join([str(x) for x in label_id]))
183 |
184 | features.append(
185 | InputFeatures(input_ids=input_ids,
186 | attention_mask=attention_mask,
187 | token_type_ids=token_type_ids,
188 | label_id=label_id,
189 | ))
190 |
191 | return features
192 |
193 |
194 | def get_image_nums(args, filename):
195 | img_ids = []
196 | with open(os.path.join(args.data_dir, filename), 'r', encoding='utf-8') as f:
197 | for line in f:
198 | img_num, _, _ = line.split('\t')
199 | img_ids.append(img_num)
200 | return img_ids
201 |
202 |
203 | def load_examples(args, tokenizer, mode):
204 | processor = processors[args.task](args)
205 |
206 | # Load data features from dataset file
207 | # NOTE: Get image features
208 | # Load data features from cache or dataset file
209 | cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}'.format(args.task, mode))
210 | cached_img_features_file = os.path.join(args.data_dir, 'cached_img_{}_{}'.format(args.task, mode))
211 |
212 | if os.path.exists(cached_features_file) and os.path.exists(cached_img_features_file):
213 | logger.info("Loading features from cached file %s", cached_features_file)
214 | features = torch.load(cached_features_file)
215 | logger.info("Loading img features from cached file %s", cached_img_features_file)
216 | all_img_features = torch.load(cached_img_features_file)
217 | else:
218 | logger.info("Creating features from dataset file at %s", args.data_dir)
219 | img_feature_file = h5py.File(os.path.join(args.data_dir, args.h5_filename), 'r')
220 | if mode == "train":
221 | examples = processor.get_examples("train")
222 | img_ids = get_image_nums(args, args.train_file)
223 | all_img_features = load_vgg_features(img_feature_file, img_ids)
224 | elif mode == "dev":
225 | examples = processor.get_examples("dev")
226 | img_ids = get_image_nums(args, args.dev_file)
227 | all_img_features = load_vgg_features(img_feature_file, img_ids)
228 | elif mode == "test":
229 | examples = processor.get_examples("test")
230 | img_ids = get_image_nums(args, args.test_file)
231 | all_img_features = load_vgg_features(img_feature_file, img_ids)
232 | else:
233 | raise Exception("Only train, dev, and test modes are available")
234 |
235 | label_len = len(get_label(args))
236 |
237 | features = convert_examples_to_features(examples, args.max_seq_len, tokenizer, label_len)
238 | logger.info("Saving features into cached file %s", cached_features_file)
239 | torch.save(features, cached_features_file)
240 | torch.save(all_img_features, cached_img_features_file)
241 |
242 | # Convert to Tensors and build dataset
243 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
244 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
245 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
246 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
247 |
248 | print(all_input_ids.size())
249 | print(all_attention_mask.size())
250 | print(all_token_type_ids.size())
251 | print(all_label_ids.size())
252 | print(all_img_features.size())
253 |
254 | dataset = TensorDataset(all_input_ids, all_attention_mask,
255 | all_token_type_ids, all_label_ids, all_img_features)
256 | return dataset
257 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from trainer import Trainer
4 | from utils import init_logger, load_tokenizer, MODEL_CLASSES, MODEL_PATH_MAP
5 | from data_loader import load_examples
6 |
7 |
8 | def main(args):
9 | init_logger()
10 | tokenizer = load_tokenizer(args)
11 | train_dataset = load_examples(args, tokenizer, mode="train")
12 | dev_dataset = load_examples(args, tokenizer, mode="dev")
13 | test_dataset = load_examples(args, tokenizer, mode="test")
14 | trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)
15 |
16 | if args.do_train:
17 | trainer.train()
18 |
19 | if args.do_eval:
20 | trainer.load_model()
21 | trainer.evaluate("test")
22 |
23 |
24 | if __name__ == '__main__':
25 | parser = argparse.ArgumentParser()
26 |
27 | parser.add_argument("--task", default="hashtag", type=str, help="The name of the task to train")
28 | parser.add_argument("--model_dir", default="./model", type=str, help="Path to save, load model")
29 | parser.add_argument("--data_dir", default="./data", type=str, help="The input data dir")
30 | parser.add_argument("--train_file", default="train.tsv", type=str, help="Train file")
31 | parser.add_argument("--dev_file", default="dev.tsv", type=str, help="Dev file")
32 | parser.add_argument("--test_file", default="test.tsv", type=str, help="Test file")
33 | parser.add_argument("--h5_filename", default="img_vgg_feature_224.h5", type=str, help="Image file")
34 | parser.add_argument("--label_file", default="label.txt", type=str, help="Label file")
35 |
36 | parser.add_argument("--model_type", default="albert", type=str, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
37 |
38 | parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
39 | parser.add_argument("--batch_size", default=16, type=int, help="Batch size for training and evaluation.")
40 | parser.add_argument("--max_seq_len", default=50, type=int, help="The maximum total input sequence length after tokenization.")
41 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
42 | parser.add_argument("--num_train_epochs", default=5.0, type=float, help="Total number of training epochs to perform.")
43 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
44 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
45 | help="Number of updates steps to accumulate before performing a backward/update pass.")
46 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
47 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
48 | parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
49 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
50 | parser.add_argument("--dropout_rate", default=0.1, type=float, help="Dropout for fully-connected layers")
51 |
52 | parser.add_argument('--logging_steps', type=int, default=500, help="Log every X updates steps.")
53 | parser.add_argument('--save_steps', type=int, default=500, help="Save checkpoint every X updates steps.")
54 |
55 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
56 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the test set.")
57 | parser.add_argument("--do_lower_case", action="store_true", help="Whether to lowercase the text (For uncased model)")
58 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
59 |
60 | args = parser.parse_args()
61 |
62 | args.model_name_or_path = MODEL_PATH_MAP[args.model_type]
63 | main(args)
64 |
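# Example training run (a sketch; it assumes train.tsv, dev.tsv, test.tsv and
# img_vgg_feature_224.h5 already exist under ./data):
#   python3 main.py --do_train --do_eval --do_lower_case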
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import AlbertModel
4 | from transformers.modeling_albert import AlbertPreTrainedModel
5 |
6 |
7 | class FCLayer(nn.Module):
8 | def __init__(self, input_dim, output_dim, dropout_rate=0., use_activation=True):
9 | super(FCLayer, self).__init__()
10 | self.use_activation = use_activation
11 | self.dropout = nn.Dropout(dropout_rate)
12 | self.linear = nn.Linear(input_dim, output_dim)
13 | self.relu = nn.ReLU()
14 |
15 | def forward(self, x):
16 | x = self.dropout(x)
17 | if self.use_activation:
18 | x = self.relu(x)
19 | return self.linear(x)
20 |
21 |
22 | class HashtagClassifier(AlbertPreTrainedModel):
23 | def __init__(self, bert_config, args):
24 | super(HashtagClassifier, self).__init__(bert_config)
25 | self.albert = AlbertModel.from_pretrained(args.model_name_or_path, config=bert_config)  # Load pretrained ALBERT
26 |
27 | self.num_labels = bert_config.num_labels
28 |
29 | self.text_classifier = FCLayer(bert_config.hidden_size, 100, args.dropout_rate, use_activation=False)
30 | self.img_classifier = FCLayer(512*7*7, 100, args.dropout_rate, use_activation=False)
31 |
32 | self.label_classifier = FCLayer(200, self.num_labels, args.dropout_rate, use_activation=True)
33 |
34 | def forward(self, input_ids, attention_mask, token_type_ids, labels, img_features):
35 | outputs = self.albert(input_ids, attention_mask=attention_mask,
36 | token_type_ids=token_type_ids) # sequence_output, pooled_output, (hidden_states), (attentions)
37 | pooled_output = outputs[1] # [CLS]
38 | text_tensors = self.text_classifier(pooled_output)
39 |
40 | # NOTE Concat text feature and img features [512, 7, 7]
41 | img_flatten = torch.flatten(img_features, start_dim=1)
42 | img_tensors = self.img_classifier(img_flatten)
43 |
44 | # Concat -> fc_layer
45 | logits = self.label_classifier(torch.cat((text_tensors, img_tensors), -1))
46 | logits = torch.sigmoid(logits)
47 |
48 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
49 |
50 | # Multi-label loss: BCE over the sigmoid outputs
51 | if labels is not None:
52 | loss_fct = nn.BCELoss()
53 | loss = loss_fct(logits, labels.float())
54 |
55 | outputs = (loss,) + outputs
56 |
57 | return outputs # (loss), logits, (hidden_states), (attentions)
58 |
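# Expected input shapes (B = batch size, L = max_seq_len, C = num_labels):
#   input_ids / attention_mask / token_type_ids : (B, L) long tensors
#   labels                                       : (B, C) multi-hot tensor (cast to float for BCE), or None at inference
#   img_features                                 : (B, 512, 7, 7) float tensor of VGG16 convolutional features
# The image features are flattened (512*7*7), projected to 100 dims, concatenated with the
# 100-dim projection of the pooled [CLS] output, and passed through the final 200 -> C layer.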
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | emoji
2 | transformers==2.2.1
3 | h5py
4 | torchvision
5 | torch==1.1.0
6 | scikit-learn
--------------------------------------------------------------------------------
/server/.gitignore:
--------------------------------------------------------------------------------
1 | static/
--------------------------------------------------------------------------------
/server/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:18.04
2 |
3 | # Install some basic utilities
4 | RUN apt-get update && apt-get install -y \
5 | curl \
6 | ca-certificates \
7 | sudo \
8 | git \
9 | bzip2 \
10 | libx11-6 \
11 | vim \
12 | unzip \
13 | python3-pip \
14 | python3-dev\
15 | && rm -rf /var/lib/apt/lists/*
16 |
17 | # Create a working directory
18 | RUN mkdir /app
19 | WORKDIR /app
20 |
21 | ### Without this Python thinks we're ASCII and unicode chars fail
22 | ENV LANG C.UTF-8
23 |
24 | # Working directory
25 | COPY . /app
26 | WORKDIR /app
27 |
28 | # Install requirements
29 | RUN pip3 install -r requirements.txt --no-cache-dir
30 |
31 | # Save pretrained model in advance
32 | RUN python3 download_models.py
33 |
34 | # Run the server
35 | EXPOSE 80
36 | CMD [ "python3", "app.py"]
--------------------------------------------------------------------------------
/server/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import gdown
5 |
6 | import torch
7 | import emoji
8 | import numpy as np
9 | import tensorflow as tf
10 | from tensorflow.keras.preprocessing import image
11 | from torchvision.models import vgg16
12 |
13 | from tensorflow.keras.applications.vgg16 import VGG16
14 | from tensorflow.keras.applications.vgg16 import preprocess_input
15 |
16 | from flask import Flask, jsonify, request, render_template, url_for
17 |
18 | from transformers import AlbertTokenizer, AlbertConfig
19 | from model import HashtagClassifier
20 |
21 |
22 | from werkzeug.utils import secure_filename
23 | from flask_uploads import UploadSet, configure_uploads, IMAGES
24 |
25 |
26 | def get_label():
27 | return [label.strip() for label in open('label.txt', 'r', encoding='utf-8')]
28 |
29 |
30 | app = Flask(__name__)
31 |
32 | PHOTO_FOLDER = 'static'
33 | app.config['UPLOAD_FOLDER'] = PHOTO_FOLDER
34 |
35 |
36 | photos = UploadSet('photos', IMAGES)
37 | app.config['UPLOADED_PHOTOS_DEST'] = PHOTO_FOLDER
38 | configure_uploads(app, photos)
39 |
40 |
41 | tokenizer = None
42 | model = None
43 | args = None
44 | label_lst = get_label()
45 |
46 | vgg_model = VGG16(weights='imagenet', include_top=False)
47 |
48 |
49 | DOWNLOAD_URL_MAP = {
50 | 'hashtag': {
51 | 'pytorch_model': ('https://drive.google.com/uc?id=1zs5xGh43KUDnzbw-ntTb4kBU5w8bslI8', 'pytorch_model.bin'),
52 | 'config': ('https://drive.google.com/uc?id=1LVb7BlC3_0jVLei7a8llDH7qIv49wQcs', 'config.json'),
53 | 'training_config': ('https://drive.google.com/uc?id=1uBP_64wdHPb-N6x89LRXLfXdQqoRg75B', 'training_config.bin')
54 | }
55 | }
56 |
57 |
58 | def download(url, filename, cachedir='~/hashtag/'):
59 | f_cachedir = os.path.expanduser(cachedir)
60 | os.makedirs(f_cachedir, exist_ok=True)
61 | file_path = os.path.join(f_cachedir, filename)
62 | if os.path.isfile(file_path):
63 | print('Using cached model')
64 | return file_path
65 | gdown.download(url, file_path, quiet=False)
66 | return file_path
67 |
68 |
69 | def download_model(cachedir='~/hashtag/'):
70 | download(DOWNLOAD_URL_MAP['hashtag']['pytorch_model'][0], DOWNLOAD_URL_MAP['hashtag']['pytorch_model'][1], cachedir)
71 | download(DOWNLOAD_URL_MAP['hashtag']['config'][0], DOWNLOAD_URL_MAP['hashtag']['config'][1], cachedir)
72 | download(DOWNLOAD_URL_MAP['hashtag']['training_config'][0], DOWNLOAD_URL_MAP['hashtag']['training_config'][1], cachedir)
73 |
74 |
75 | def init_model(cachedir='~/hashtag/', no_cuda=True):
76 | global tokenizer, model
77 |
78 | f_cachedir = os.path.expanduser(cachedir)
79 | bert_config = AlbertConfig.from_pretrained(f_cachedir)
80 | model = HashtagClassifier.from_pretrained(f_cachedir, config=bert_config)
81 | device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
82 | model.to(device)
83 | model.eval()
84 |
85 | tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
86 |
87 |
88 | def convert_texts_to_tensors(texts, max_seq_len, no_cuda=True):
89 | input_ids = []
90 | attention_mask = []
91 | token_type_ids = []
92 | for text in texts:
93 | input_id = tokenizer.encode(text, add_special_tokens=True)
94 | input_id = input_id[:max_seq_len]
95 |
96 | attention_id = [1] * len(input_id)
97 |
98 | # Zero padding
99 | padding_length = max_seq_len - len(input_id)
100 | input_id = input_id + ([tokenizer.pad_token_id] * padding_length)
101 | attention_id = attention_id + ([0] * padding_length)
102 |
103 | input_ids.append(input_id)
104 | attention_mask.append(attention_id)
105 | token_type_ids.append([0]*max_seq_len)
106 |
107 | # Change list to torch tensor
108 | device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
109 |
110 | input_ids = torch.tensor(input_ids, dtype=torch.long).to(device)
111 | attention_mask = torch.tensor(attention_mask, dtype=torch.long).to(device)
112 | token_type_ids = torch.tensor(token_type_ids, dtype=torch.long).to(device)
113 | return input_ids, attention_mask, token_type_ids
114 |
115 |
116 | def img_to_tensor(img_path, no_cuda):
117 | img = image.load_img(img_path, target_size=(224, 224))
118 | img_data = image.img_to_array(img)
119 | img_data = np.expand_dims(img_data, axis=0)
120 | img_data = preprocess_input(img_data)
121 |
122 | vgg16_feature = vgg_model.predict(img_data)
123 |
124 | feat = np.transpose(vgg16_feature, (0, 3, 1, 2))
125 | # Change list to torch tensor
126 | device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
127 | return torch.tensor(feat, dtype=torch.float).to(device)
128 |
129 |
130 | @app.route("/predict", methods=["POST", "GET"])
131 | def predict():
132 | img_id = request.args.get('image_id')
133 | text = request.args.get('text')
134 | max_seq_len = int(request.args.get('max_seq_len'))
135 | n_label = int(request.args.get('n_label'))
136 |
137 | # Prediction
138 | img_link = "https://drive.google.com/uc?id={}".format(img_id)
139 | download(img_link, "{}.jpg".format(img_id), cachedir=app.config['UPLOAD_FOLDER'])
140 | img_tensor = img_to_tensor(os.path.join(app.config['UPLOAD_FOLDER'], "{}.jpg".format(img_id)), args.no_cuda)
141 |
142 | texts = [emoji.demojize(text.lower())]
143 |
144 | input_ids, attention_mask, token_type_ids = convert_texts_to_tensors(texts, max_seq_len, args.no_cuda)
145 | with torch.no_grad():
146 | outputs = model(input_ids, attention_mask, token_type_ids, None, img_tensor)
147 | logits = outputs[0]
148 |
149 | _, top_idx = logits.topk(n_label)
150 |
151 | preds = []
152 | print(top_idx)
153 | for idx in top_idx[0]:
154 | preds.append("#{}".format(label_lst[idx]))
155 |
156 | return render_template("result.html", user_image="./{}/{}".format(app.config['UPLOAD_FOLDER'], "{}.jpg".format(img_id)), text=text, tag=" ".join(preds))
157 |
158 |
159 | if __name__ == "__main__":
160 | parser = argparse.ArgumentParser()
161 |
162 | parser.add_argument("-p", "--port_num", type=int, default=80, help="Port Number")
163 | parser.add_argument("-n", "--no_cuda", action="store_true", help="Avoid using CUDA when available")
164 | args = parser.parse_args()
165 |
166 | download_model()
167 | print("Initializing the model...")
168 | init_model(no_cuda=args.no_cuda)
169 |
170 | app.run(host="0.0.0.0", debug=False, port=args.port_num)
171 |
--------------------------------------------------------------------------------
/server/download_models.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.applications.vgg16 import VGG16
2 | from app import download_model
3 |
4 | vgg_model = VGG16(weights='imagenet', include_top=False)
5 | download_model()
6 |
--------------------------------------------------------------------------------
/server/label.txt:
--------------------------------------------------------------------------------
1 | [UNK]
2 | adventure
3 | amazing
4 | art
5 | beach
6 | beautiful
7 | beauty
8 | bestoftheday
9 | blue
10 | car
11 | cat
12 | clouds
13 | coffee
14 | cool
15 | cute
16 | dog
17 | explore
18 | f4f
19 | family
20 | fashion
21 | fit
22 | fitness
23 | flowers
24 | follow
25 | follow4follow
26 | followforfollow
27 | followme
28 | food
29 | foodie
30 | foodporn
31 | friends
32 | fun
33 | funny
34 | girl
35 | girls
36 | green
37 | gym
38 | hair
39 | happy
40 | holiday
41 | hot
42 | igers
43 | instadaily
44 | instagood
45 | instagram
46 | instagramers
47 | instalike
48 | instamood
49 | instapic
50 | l4l
51 | landscape
52 | life
53 | lifestyle
54 | like
55 | like4like
56 | likeforlike
57 | lol
58 | love
59 | makeup
60 | me
61 | model
62 | motivation
63 | mountain
64 | mountains
65 | music
66 | nature
67 | naturephotography
68 | night
69 | nofilter
70 | ootd
71 | outfit
72 | party
73 | photo
74 | photographer
75 | photography
76 | photooftheday
77 | picoftheday
78 | pink
79 | pretty
80 | rain
81 | red
82 | sea
83 | selfie
84 | sky
85 | smile
86 | snow
87 | style
88 | summer
89 | sun
90 | sunset
91 | swag
92 | tbt
93 | throwbackthursday
94 | travel
95 | travelphotography
96 | tree
97 | vacation
98 | vsco
99 | vscocam
100 | wanderlust
101 | winter
102 |
--------------------------------------------------------------------------------
/server/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import AlbertModel
4 | from transformers.modeling_albert import AlbertPreTrainedModel
5 |
6 |
7 | class FCLayer(nn.Module):
8 | def __init__(self, input_dim, output_dim, dropout_rate=0., use_activation=True):
9 | super(FCLayer, self).__init__()
10 | self.use_activation = use_activation
11 | self.dropout = nn.Dropout(dropout_rate)
12 | self.linear = nn.Linear(input_dim, output_dim)
13 | self.relu = nn.ReLU()
14 |
15 | def forward(self, x):
16 | x = self.dropout(x)
17 | if self.use_activation:
18 | x = self.relu(x)
19 | return self.linear(x)
20 |
21 |
22 | class HashtagClassifier(AlbertPreTrainedModel):
23 | def __init__(self, bert_config):
24 | super(HashtagClassifier, self).__init__(bert_config)
25 | self.albert = AlbertModel(bert_config)  # ALBERT backbone; pretrained weights are loaded via HashtagClassifier.from_pretrained
26 |
27 | self.num_labels = bert_config.num_labels
28 |
29 | self.text_classifier = FCLayer(bert_config.hidden_size, 100, 0, use_activation=False)
30 | self.img_classifier = FCLayer(512*7*7, 100, 0, use_activation=False)
31 |
32 | self.label_classifier = FCLayer(200, self.num_labels, 0, use_activation=True)
33 |
34 | def forward(self, input_ids, attention_mask, token_type_ids, labels, img_features):
35 | outputs = self.albert(input_ids, attention_mask=attention_mask,
36 | token_type_ids=token_type_ids) # sequence_output, pooled_output, (hidden_states), (attentions)
37 | pooled_output = outputs[1] # [CLS]
38 | text_tensors = self.text_classifier(pooled_output)
39 |
40 | # NOTE Concat text feature and img features [512, 7, 7]
41 | img_flatten = torch.flatten(img_features, start_dim=1)
42 | img_tensors = self.img_classifier(img_flatten)
43 |
44 | # Concat -> fc_layer
45 | logits = self.label_classifier(torch.cat((text_tensors, img_tensors), -1))
46 | logits = torch.sigmoid(logits)
47 |
48 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
49 |
50 | # Multi-label loss: BCE over the sigmoid outputs
51 | if labels is not None:
52 | loss_fct = nn.BCELoss()
53 | loss = loss_fct(logits, labels.float())
54 |
55 | outputs = (loss,) + outputs
56 |
57 | return outputs # (loss), logits, (hidden_states), (attentions)
58 |
--------------------------------------------------------------------------------
/server/request.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # The /predict endpoint reads its arguments from the query string, so pass them as URL parameters.
3 | curl -G "http://0.0.0.0:80/predict" --data-urlencode "image_id=1okfka3J9d2KBwVqNNSoqeWzQ2YfONijy" --data-urlencode "text=I am very cool. Come to my place~~~" --data-urlencode "max_seq_len=20" --data-urlencode "n_label=10"
4 | curl -G "http://0.0.0.0:80/predict" --data-urlencode "image_id=1DGu9R5a9jpkY-fy79VrGFmCdJigzTMC-" --data-urlencode "text=20 days till Christmas 😍🎅" --data-urlencode "max_seq_len=20" --data-urlencode "n_label=10"
--------------------------------------------------------------------------------
/server/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==2.2.1
2 | torchvision
3 | torch==1.1.0
4 | scikit-learn
5 | emoji
6 | gdown
7 | tensorflow==2.0.0a0
8 | flask
9 | flask-uploads
--------------------------------------------------------------------------------
/server/run_server.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | docker run -d -p 80:80 --memory="2g" adieujw/hashtag:latest
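# Note: this assumes the image has already been built, e.g. from the server/ directory:
#   docker build -t adieujw/hashtag:latest .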
--------------------------------------------------------------------------------
/server/templates/post.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
Caption: {{ text }}
9 |Tag: {{ tag }}
10 | 11 | 12 | -------------------------------------------------------------------------------- /swagger.yaml: -------------------------------------------------------------------------------- 1 | openapi: 3.0.0 # Open api version 2 | 3 | info: # Set basic infomation 4 | title: hashtag 5 | version: 0.1.0 6 | 7 | servers: # Set your server endpoint 8 | - url: https://endpoint.ainize.ai/monologg/hashtag 9 | 10 | paths: 11 | /predict: # GET method path 12 | get: 13 | parameters: # Set parameter values here 14 | - name: image_id # Set parameter name 15 | in: query # Select amongst query, header, path, and cookie 16 | default: 1DGu9R5a9jpkY-fy79VrGFmCdJigzTMC- 17 | required: true 18 | allowReserved: true # Option for percent-encoding, default; false 19 | - name: text 20 | in: query 21 | required: true 22 | default: 20 days till Christmas 😍🎅 23 | allowReserved: true # Option for percent-encoding, default; false 24 | - name: max_seq_len 25 | in: query 26 | required: true 27 | default: 30 28 | allowReserved: true # Option for percent-encoding, default; false 29 | - name: n_label 30 | in: query 31 | required: true 32 | default: 10 33 | allowReserved: true # Option for percent-encoding, default; false 34 | responses: # Set response 35 | "200": 36 | description: OK 37 | content: 38 | text/html: 39 | schema: 40 | type: string 41 | "400": 42 | description: Bad Request Error 43 | default: 44 | description: Unexpected Error 45 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from tqdm import tqdm, trange 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 8 | from transformers import AdamW, get_linear_schedule_with_warmup 9 | 10 | from utils import set_seed, compute_metrics, get_label, MODEL_CLASSES 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class Trainer(object): 16 | def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None): 17 | self.args = args 18 | self.train_dataset = train_dataset 19 | self.dev_dataset = dev_dataset 20 | self.test_dataset = test_dataset 21 | 22 | self.label_lst = get_label(args) 23 | self.num_labels = len(self.label_lst) 24 | 25 | self.config_class, self.model_class, _ = MODEL_CLASSES[args.model_type] 26 | 27 | self.bert_config = self.config_class.from_pretrained(args.model_name_or_path, num_labels=self.num_labels, finetuning_task=args.task) 28 | self.model = self.model_class(self.bert_config, args) 29 | 30 | # GPU or CPU 31 | self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" 32 | self.model.to(self.device) 33 | 34 | def train(self): 35 | train_sampler = RandomSampler(self.train_dataset) 36 | train_dataloader = DataLoader(self.train_dataset, sampler=train_sampler, batch_size=self.args.batch_size) 37 | 38 | if self.args.max_steps > 0: 39 | t_total = self.args.max_steps 40 | self.args.num_train_epochs = self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1 41 | else: 42 | t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs 43 | 44 | # Prepare optimizer and schedule (linear warmup and decay) 45 | no_decay = ['bias', 'LayerNorm.weight'] 46 | optimizer_grouped_parameters = [ 47 | {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 48 | 'weight_decay': 
self.args.weight_decay}, 49 | {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 50 | ] 51 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon) 52 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total) 53 | 54 | # Train! 55 | logger.info("***** Running training *****") 56 | logger.info(" Num examples = %d", len(self.train_dataset)) 57 | logger.info(" Num Epochs = %d", self.args.num_train_epochs) 58 | logger.info(" Total train batch size = %d", self.args.batch_size) 59 | logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) 60 | logger.info(" Total optimization steps = %d", t_total) 61 | 62 | global_step = 0 63 | tr_loss = 0.0 64 | self.model.zero_grad() 65 | 66 | train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch") 67 | set_seed(self.args) 68 | 69 | for _ in train_iterator: 70 | epoch_iterator = tqdm(train_dataloader, desc="Iteration") 71 | for step, batch in enumerate(epoch_iterator): 72 | self.model.train() 73 | batch = tuple(t.to(self.device) for t in batch) # GPU or CPU 74 | inputs = {'input_ids': batch[0], 75 | 'attention_mask': batch[1], 76 | 'token_type_ids': batch[2], 77 | 'labels': batch[3], 78 | 'img_features': batch[4]} 79 | outputs = self.model(**inputs) 80 | loss = outputs[0] 81 | 82 | if self.args.gradient_accumulation_steps > 1: 83 | loss = loss / self.args.gradient_accumulation_steps 84 | 85 | loss.backward() 86 | 87 | tr_loss += loss.item() 88 | if (step + 1) % self.args.gradient_accumulation_steps == 0: 89 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 90 | 91 | optimizer.step() 92 | scheduler.step() # Update learning rate schedule 93 | self.model.zero_grad() 94 | global_step += 1 95 | 96 | if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0: 97 | self.evaluate("dev") # Only test set available for NSMC 98 | 99 | if self.args.save_steps > 0 and global_step % self.args.save_steps == 0: 100 | self.save_model() 101 | 102 | if 0 < self.args.max_steps < global_step: 103 | epoch_iterator.close() 104 | break 105 | 106 | if 0 < self.args.max_steps < global_step: 107 | train_iterator.close() 108 | break 109 | 110 | return global_step, tr_loss / global_step 111 | 112 | def evaluate(self, mode): 113 | # We use test dataset because semeval doesn't have dev dataset 114 | if mode == 'test': 115 | dataset = self.test_dataset 116 | elif mode == 'dev': 117 | dataset = self.dev_dataset 118 | else: 119 | raise Exception("Only dev and test dataset available") 120 | 121 | eval_sampler = SequentialSampler(dataset) 122 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.batch_size) 123 | 124 | # Eval! 
125 | logger.info("***** Running evaluation on %s dataset *****", mode) 126 | logger.info(" Num examples = %d", len(dataset)) 127 | logger.info(" Batch size = %d", self.args.batch_size) 128 | eval_loss = 0.0 129 | nb_eval_steps = 0 130 | preds = None 131 | out_label_ids = None 132 | 133 | self.model.eval() 134 | 135 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 136 | batch = tuple(t.to(self.device) for t in batch) 137 | with torch.no_grad(): 138 | inputs = {'input_ids': batch[0], 139 | 'attention_mask': batch[1], 140 | 'token_type_ids': batch[2], 141 | 'labels': batch[3], 142 | 'img_features': batch[4]} 143 | outputs = self.model(**inputs) 144 | tmp_eval_loss, logits = outputs[:2] 145 | 146 | eval_loss += tmp_eval_loss.mean().item() 147 | nb_eval_steps += 1 148 | 149 | if preds is None: 150 | preds = logits.detach().cpu().numpy() 151 | out_label_ids = inputs['labels'].detach().cpu().numpy() 152 | else: 153 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 154 | out_label_ids = np.append( 155 | out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0) 156 | 157 | eval_loss = eval_loss / nb_eval_steps 158 | results = { 159 | "loss": eval_loss 160 | } 161 | 162 | # preds = np.argmax(preds, axis=1) 163 | preds[preds >= 0.5] = 1 164 | preds[preds < 0.5] = 0 165 | result = compute_metrics(preds, out_label_ids) 166 | results.update(result) 167 | 168 | logger.info("***** Eval results *****") 169 | for key in sorted(results.keys()): 170 | logger.info(" %s = %s", key, str(results[key])) 171 | 172 | return results 173 | 174 | def save_model(self): 175 | # Save model checkpoint (Overwrite) 176 | output_dir = os.path.join(self.args.model_dir) 177 | 178 | if not os.path.exists(output_dir): 179 | os.makedirs(output_dir) 180 | model_to_save = self.model.module if hasattr(self.model, 'module') else self.model 181 | model_to_save.save_pretrained(output_dir) 182 | torch.save(self.args, os.path.join(output_dir, 'training_config.bin')) 183 | logger.info("Saving model checkpoint to %s", output_dir) 184 | 185 | def load_model(self): 186 | # Check whether model exists 187 | if not os.path.exists(self.args.model_dir): 188 | raise Exception("Model doesn't exists! 
Train first!") 189 | 190 | try: 191 | self.bert_config = self.config_class.from_pretrained(self.args.model_dir) 192 | logger.info("***** Config loaded *****") 193 | self.model = self.model_class.from_pretrained(self.args.model_dir, config=self.bert_config, args=self.args) 194 | self.model.to(self.device) 195 | logger.info("***** Model Loaded *****") 196 | except: 197 | raise Exception("Some model files might be missing...") 198 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | 5 | import torch 6 | import numpy as np 7 | from sklearn.metrics import f1_score 8 | 9 | from transformers import AlbertConfig, AlbertTokenizer 10 | 11 | from model import HashtagClassifier 12 | 13 | MODEL_CLASSES = { 14 | 'albert': (AlbertConfig, HashtagClassifier, AlbertTokenizer), 15 | } 16 | 17 | MODEL_PATH_MAP = { 18 | 'albert': 'albert-base-v2' 19 | } 20 | 21 | 22 | def get_label(args): 23 | return [label.strip() for label in open(os.path.join(args.data_dir, args.label_file), 'r', encoding='utf-8')] 24 | 25 | 26 | def load_tokenizer(args): 27 | return MODEL_CLASSES[args.model_type][2].from_pretrained(args.model_name_or_path) 28 | 29 | 30 | def init_logger(): 31 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 32 | datefmt='%m/%d/%Y %H:%M:%S', 33 | level=logging.INFO) 34 | 35 | 36 | def set_seed(args): 37 | random.seed(args.seed) 38 | np.random.seed(args.seed) 39 | torch.manual_seed(args.seed) 40 | if not args.no_cuda and torch.cuda.is_available(): 41 | torch.cuda.manual_seed_all(args.seed) 42 | 43 | 44 | def compute_metrics(preds, labels): 45 | assert len(preds) == len(labels) 46 | return acc_and_f1(preds, labels) 47 | 48 | 49 | def simple_accuracy(preds, labels): 50 | return (preds == labels).mean() 51 | 52 | 53 | def acc_and_f1(preds, labels, average='macro'): 54 | acc = simple_accuracy(preds, labels) 55 | f1 = f1_score(y_true=labels, y_pred=preds, average=average) 56 | return { 57 | "acc": acc, 58 | "f1": f1, 59 | } 60 | -------------------------------------------------------------------------------- /vgg_feature.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import numpy as np 6 | 7 | 8 | def load_vgg_features(img_feature_file, img_ids): 9 | img_features = [] 10 | for image_num in img_ids: 11 | feature = img_feature_file.get(image_num) 12 | if not feature: 13 | print("Errrrr") 14 | img_features.append(np.array(feature)) 15 | img_tensor_features = torch.tensor(img_features, dtype=torch.float).squeeze(1) 16 | print(img_tensor_features.size()) 17 | return img_tensor_features 18 | --------------------------------------------------------------------------------