├── .gitignore ├── LICENSE ├── README.md ├── config ├── ekman.json ├── group.json └── original.json ├── data ├── ekman │ ├── dev.tsv │ ├── labels.txt │ ├── test.tsv │ └── train.tsv ├── group │ ├── dev.tsv │ ├── labels.txt │ ├── test.tsv │ └── train.tsv └── original │ ├── dev.tsv │ ├── labels.txt │ ├── test.tsv │ └── train.tsv ├── data_loader.py ├── model.py ├── multilabel_pipeline.py ├── requirements.txt ├── run_goemotions.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | ################## 132 | .vscode/ 133 | .idea/ 134 | 135 | cached* 136 | ckpt* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GoEmotions Pytorch 2 | 3 | PyTorch implementation of [GoEmotions](https://github.com/google-research/google-research/tree/master/goemotions) using [Huggingface Transformers](https://github.com/huggingface/transformers) 4 | 5 | ## What is GoEmotions 6 | 7 | A dataset of **58,000 Reddit comments** labeled with **28 emotions** 8 | 9 | - admiration, amusement, anger, annoyance, approval, caring, confusion, curiosity, desire, disappointment, disapproval, disgust, embarrassment, excitement, fear, gratitude, grief, joy, love, nervousness, optimism, pride, realization, relief, remorse, sadness, surprise + neutral 10 | 11 | ## Training Details 12 | 13 | - Uses `bert-base-cased` (same as the paper's code) 14 | - The paper uses **3 taxonomies**. I've also prepared the data relabeled with the `hierarchical grouping` and `ekman` taxonomies. 15 | 16 | 1. **Original GoEmotions** (27 emotions + neutral) 17 | 2. **Hierarchical Grouping** (positive, negative, ambiguous + neutral) 18 | 3. **Ekman** (anger, disgust, fear, joy, sadness, surprise + neutral) 19 | 20 | ### Vocabulary 21 | 22 | - I've replaced `[unused1]` and `[unused2]` with `[NAME]` and `[RELIGION]` in the vocab, respectively. 23 | 24 | ```text 25 | [PAD] 26 | [NAME] 27 | [RELIGION] 28 | [unused3] 29 | [unused4] 30 | ... 31 | ``` 32 | 33 | - I've also set `special_tokens_map.json` as below, so the tokenizer won't split `[NAME]` or `[RELIGION]` into word pieces. 34 | 35 | ```json 36 | { 37 | "unk_token": "[UNK]", 38 | "sep_token": "[SEP]", 39 | "pad_token": "[PAD]", 40 | "cls_token": "[CLS]", 41 | "mask_token": "[MASK]", 42 | "additional_special_tokens": ["[NAME]", "[RELIGION]"] 43 | } 44 | ``` 45 | 46 | ### Requirements 47 | 48 | - torch==1.4.0 49 | - transformers==2.11.0 50 | - attrdict==2.0.1 51 | 52 | ### Hyperparameters 53 | 54 | You can change the hyperparameters via the JSON files in the `config` directory (a minimal loading sketch follows; the defaults are listed in the table below).
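As a reference for how those JSON files are consumed: `run_goemotions.py` picks the file that matches `--taxonomy` and wraps it in an `attrdict.AttrDict`. A minimal sketch for inspecting the values, assuming you run it from the repository root, looks like this:

```python
import json

from attrdict import AttrDict

# Same pattern as run_goemotions.py: load the config that matches the chosen taxonomy.
taxonomy = "original"  # or "group" / "ekman"
with open("config/{}.json".format(taxonomy), "r", encoding="utf-8") as f:
    args = AttrDict(json.load(f))

# Attribute access mirrors how the training code reads its arguments.
print(args.learning_rate)      # 5e-5
print(args.warmup_proportion)  # 0.1
print(args.num_train_epochs)   # 10
print(args.max_seq_len)        # 50
print(args.train_batch_size)   # 16
```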
55 | 56 | | Parameter | Value | 57 | | ----------------- | ---: | 58 | | Learning rate | 5e-5 | 59 | | Warmup proportion | 0.1 | 60 | | Epochs | 10 | 61 | | Max Seq Length | 50 | 62 | | Batch size | 16 | 63 | 64 | ## How to Run 65 | 66 | For the taxonomy, choose one of `original`, `group`, or `ekman`: 67 | 68 | ```bash 69 | $ python3 run_goemotions.py --taxonomy ${TAXONOMY} 70 | 71 | $ python3 run_goemotions.py --taxonomy original 72 | $ python3 run_goemotions.py --taxonomy group 73 | $ python3 run_goemotions.py --taxonomy ekman 74 | ``` 75 | 76 | ## Results 77 | 78 | Best `Macro F1` results 79 | 80 | | Macro F1 (%) | Dev | Test | 81 | | ------------ | :---: | :---: | 82 | | original | 50.16 | 50.30 | 83 | | group | 69.41 | 70.06 | 84 | | ekman | 62.59 | 62.38 | 85 | 86 | ## Pipeline 87 | 88 | - Inference for multi-label classification is handled by a new `MultiLabelPipeline` class. 89 | - The fine-tuned models are already uploaded to Huggingface S3: 90 | - Original GoEmotions Taxonomy: `monologg/bert-base-cased-goemotions-original` 91 | - Hierarchical Group Taxonomy: `monologg/bert-base-cased-goemotions-group` 92 | - Ekman Taxonomy: `monologg/bert-base-cased-goemotions-ekman` 93 | 94 | ### 1. Original GoEmotions Taxonomy 95 | 96 | ```python 97 | from transformers import BertTokenizer 98 | from model import BertForMultiLabelClassification 99 | from multilabel_pipeline import MultiLabelPipeline 100 | from pprint import pprint 101 | 102 | tokenizer = BertTokenizer.from_pretrained("monologg/bert-base-cased-goemotions-original") 103 | model = BertForMultiLabelClassification.from_pretrained("monologg/bert-base-cased-goemotions-original") 104 | 105 | goemotions = MultiLabelPipeline( 106 | model=model, 107 | tokenizer=tokenizer, 108 | threshold=0.3 109 | ) 110 | 111 | texts = [ 112 | "Hey that's a thought! Maybe we need [NAME] to be the celebrity vaccine endorsement!", 113 | "it’s happened before?! love my hometown of beautiful new ken 😂😂", 114 | "I love you, brother.", 115 | "Troll, bro. They know they're saying stupid shit. The motherfucker does nothing but stink up libertarian subs talking shit", 116 | ] 117 | 118 | pprint(goemotions(texts)) 119 | 120 | # Output 121 | [{'labels': ['neutral'], 'scores': [0.9750906]}, 122 | {'labels': ['curiosity', 'love'], 'scores': [0.9694574, 0.9227462]}, 123 | {'labels': ['love'], 'scores': [0.993483]}, 124 | {'labels': ['anger'], 'scores': [0.99225825]}] 125 | ``` 126 | 127 | 128 | ### 2. Group Taxonomy 129 | 130 | ```python 131 | from transformers import BertTokenizer 132 | from model import BertForMultiLabelClassification 133 | from multilabel_pipeline import MultiLabelPipeline 134 | from pprint import pprint 135 | 136 | tokenizer = BertTokenizer.from_pretrained("monologg/bert-base-cased-goemotions-group") 137 | model = BertForMultiLabelClassification.from_pretrained("monologg/bert-base-cased-goemotions-group") 138 | 139 | goemotions = MultiLabelPipeline( 140 | model=model, 141 | tokenizer=tokenizer, 142 | threshold=0.3 143 | ) 144 | 145 | texts = [ 146 | "Hey that's a thought! Maybe we need [NAME] to be the celebrity vaccine endorsement!", 147 | "it’s happened before?! love my hometown of beautiful new ken 😂😂", 148 | "I love you, brother.", 149 | "Troll, bro. They know they're saying stupid shit. 
The motherfucker does nothing but stink up libertarian subs talking shit", 150 | ] 151 | 152 | pprint(goemotions(texts)) 153 | 154 | # Output 155 | [{'labels': ['positive'], 'scores': [0.9989434]}, 156 | {'labels': ['ambiguous', 'positive'], 'scores': [0.99801123, 0.99845874]}, 157 | {'labels': ['positive'], 'scores': [0.99930394]}, 158 | {'labels': ['negative'], 'scores': [0.9984231]}] 159 | ``` 160 | 161 | ### 3. Ekman Taxonomy 162 | 163 | ```python 164 | from transformers import BertTokenizer 165 | from model import BertForMultiLabelClassification 166 | from multilabel_pipeline import MultiLabelPipeline 167 | from pprint import pprint 168 | 169 | tokenizer = BertTokenizer.from_pretrained("monologg/bert-base-cased-goemotions-ekman") 170 | model = BertForMultiLabelClassification.from_pretrained("monologg/bert-base-cased-goemotions-ekman") 171 | 172 | goemotions = MultiLabelPipeline( 173 | model=model, 174 | tokenizer=tokenizer, 175 | threshold=0.3 176 | ) 177 | 178 | texts = [ 179 | "Hey that's a thought! Maybe we need [NAME] to be the celebrity vaccine endorsement!", 180 | "it’s happened before?! love my hometown of beautiful new ken 😂😂", 181 | "I love you, brother.", 182 | "Troll, bro. They know they're saying stupid shit. The motherfucker does nothing but stink up libertarian subs talking shit", 183 | ] 184 | 185 | pprint(goemotions(texts)) 186 | 187 | # Output 188 | [{'labels': ['joy', 'neutral'], 'scores': [0.30459446, 0.9217335]}, 189 | {'labels': ['joy', 'surprise'], 'scores': [0.9981395, 0.99863845]}, 190 | {'labels': ['joy'], 'scores': [0.99910116]}, 191 | {'labels': ['anger'], 'scores': [0.9984291]}] 192 | ``` 193 | 194 | ## Reference 195 | 196 | - [GoEmotions](https://github.com/google-research/google-research/tree/master/goemotions) 197 | - [GoEmotions Github](https://github.com/google-research/google-research/tree/master/goemotions) 198 | - [Huggingface Transformers](https://github.com/huggingface/transformers) 199 | -------------------------------------------------------------------------------- /config/ekman.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "goemotions", 3 | "data_dir": "data/ekman", 4 | "ckpt_dir": "ckpt/ekman", 5 | "output_dir": "bert-base-cased-goemotions-ekman", 6 | "train_file": "train.tsv", 7 | "dev_file": "dev.tsv", 8 | "test_file": "test.tsv", 9 | "label_file": "labels.txt", 10 | "evaluate_test_during_training": false, 11 | "eval_all_checkpoints": true, 12 | "save_optimizer": false, 13 | "do_lower_case": false, 14 | "do_train": true, 15 | "do_eval": true, 16 | "max_seq_len": 50, 17 | "num_train_epochs": 10, 18 | "weight_decay": 0.0, 19 | "gradient_accumulation_steps": 1, 20 | "adam_epsilon": 1e-8, 21 | "warmup_proportion": 0.1, 22 | "max_steps": -1, 23 | "max_grad_norm": 1.0, 24 | "no_cuda": false, 25 | "model_type": "bert", 26 | "model_name_or_path": "bert-base-cased", 27 | "tokenizer_name_or_path": "monologg/bert-base-cased-goemotions-ekman", 28 | "seed": 42, 29 | "train_batch_size": 16, 30 | "eval_batch_size": 32, 31 | "logging_steps": 1000, 32 | "save_steps": 1000, 33 | "learning_rate": 5e-5, 34 | "threshold": 0.3 35 | } 36 | -------------------------------------------------------------------------------- /config/group.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "goemotions", 3 | "data_dir": "data/group", 4 | "ckpt_dir": "ckpt/group", 5 | "output_dir": "bert-base-cased-goemotions-group", 6 | "train_file": "train.tsv", 7 | 
"dev_file": "dev.tsv", 8 | "test_file": "test.tsv", 9 | "label_file": "labels.txt", 10 | "evaluate_test_during_training": false, 11 | "eval_all_checkpoints": true, 12 | "save_optimizer": false, 13 | "do_lower_case": false, 14 | "do_train": true, 15 | "do_eval": true, 16 | "max_seq_len": 50, 17 | "num_train_epochs": 10, 18 | "weight_decay": 0.0, 19 | "gradient_accumulation_steps": 1, 20 | "adam_epsilon": 1e-8, 21 | "warmup_proportion": 0.1, 22 | "max_steps": -1, 23 | "max_grad_norm": 1.0, 24 | "no_cuda": false, 25 | "model_type": "bert", 26 | "model_name_or_path": "bert-base-cased", 27 | "tokenizer_name_or_path": "monologg/bert-base-cased-goemotions-group", 28 | "seed": 777, 29 | "train_batch_size": 16, 30 | "eval_batch_size": 32, 31 | "logging_steps": 1000, 32 | "save_steps": 1000, 33 | "learning_rate": 5e-5, 34 | "threshold": 0.3 35 | } 36 | -------------------------------------------------------------------------------- /config/original.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "goemotions", 3 | "data_dir": "data/original", 4 | "ckpt_dir": "ckpt/original", 5 | "output_dir": "bert-base-cased-goemotions-original", 6 | "train_file": "train.tsv", 7 | "dev_file": "dev.tsv", 8 | "test_file": "test.tsv", 9 | "label_file": "labels.txt", 10 | "evaluate_test_during_training": false, 11 | "eval_all_checkpoints": true, 12 | "save_optimizer": false, 13 | "do_lower_case": false, 14 | "do_train": true, 15 | "do_eval": true, 16 | "max_seq_len": 50, 17 | "num_train_epochs": 10, 18 | "weight_decay": 0.0, 19 | "gradient_accumulation_steps": 1, 20 | "adam_epsilon": 1e-8, 21 | "warmup_proportion": 0.1, 22 | "max_steps": -1, 23 | "max_grad_norm": 1.0, 24 | "no_cuda": false, 25 | "model_type": "bert", 26 | "model_name_or_path": "bert-base-cased", 27 | "tokenizer_name_or_path": "monologg/bert-base-cased-goemotions-original", 28 | "seed": 42, 29 | "train_batch_size": 16, 30 | "eval_batch_size": 32, 31 | "logging_steps": 1000, 32 | "save_steps": 1000, 33 | "learning_rate": 5e-5, 34 | "threshold": 0.3 35 | } 36 | -------------------------------------------------------------------------------- /data/ekman/labels.txt: -------------------------------------------------------------------------------- 1 | anger 2 | disgust 3 | fear 4 | joy 5 | neutral 6 | sadness 7 | surprise -------------------------------------------------------------------------------- /data/group/labels.txt: -------------------------------------------------------------------------------- 1 | ambiguous 2 | negative 3 | neutral 4 | positive -------------------------------------------------------------------------------- /data/original/labels.txt: -------------------------------------------------------------------------------- 1 | admiration 2 | amusement 3 | anger 4 | annoyance 5 | approval 6 | caring 7 | confusion 8 | curiosity 9 | desire 10 | disappointment 11 | disapproval 12 | disgust 13 | embarrassment 14 | excitement 15 | fear 16 | gratitude 17 | grief 18 | joy 19 | love 20 | nervousness 21 | optimism 22 | pride 23 | realization 24 | relief 25 | remorse 26 | sadness 27 | surprise 28 | neutral -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import logging 5 | 6 | import torch 7 | from torch.utils.data import TensorDataset 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class 
InputExample(object): 13 | """ A single training/test example for simple sequence classification. """ 14 | 15 | def __init__(self, guid, text_a, text_b, label): 16 | self.guid = guid 17 | self.text_a = text_a 18 | self.text_b = text_b 19 | self.label = label 20 | 21 | def __repr__(self): 22 | return str(self.to_json_string()) 23 | 24 | def to_dict(self): 25 | """Serializes this instance to a Python dictionary.""" 26 | output = copy.deepcopy(self.__dict__) 27 | return output 28 | 29 | def to_json_string(self): 30 | """Serializes this instance to a JSON string.""" 31 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 32 | 33 | 34 | class InputFeatures(object): 35 | """A single set of features of data.""" 36 | 37 | def __init__(self, input_ids, attention_mask, token_type_ids, label): 38 | self.input_ids = input_ids 39 | self.attention_mask = attention_mask 40 | self.token_type_ids = token_type_ids 41 | self.label = label 42 | 43 | def __repr__(self): 44 | return str(self.to_json_string()) 45 | 46 | def to_dict(self): 47 | """Serializes this instance to a Python dictionary.""" 48 | output = copy.deepcopy(self.__dict__) 49 | return output 50 | 51 | def to_json_string(self): 52 | """Serializes this instance to a JSON string.""" 53 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 54 | 55 | 56 | def convert_examples_to_features( 57 | args, 58 | examples, 59 | tokenizer, 60 | max_length, 61 | ): 62 | processor = GoEmotionsProcessor(args) 63 | label_list_len = len(processor.get_labels()) 64 | 65 | def convert_to_one_hot_label(label): 66 | one_hot_label = [0] * label_list_len 67 | for l in label: 68 | one_hot_label[l] = 1 69 | return one_hot_label 70 | 71 | labels = [convert_to_one_hot_label(example.label) for example in examples] 72 | 73 | batch_encoding = tokenizer.batch_encode_plus( 74 | [(example.text_a, example.text_b) for example in examples], max_length=max_length, pad_to_max_length=True 75 | ) 76 | 77 | features = [] 78 | for i in range(len(examples)): 79 | inputs = {k: batch_encoding[k][i] for k in batch_encoding} 80 | 81 | feature = InputFeatures(**inputs, label=labels[i]) 82 | features.append(feature) 83 | 84 | for i, example in enumerate(examples[:10]): 85 | logger.info("*** Example ***") 86 | logger.info("guid: {}".format(example.guid)) 87 | logger.info("sentence: {}".format(example.text_a)) 88 | logger.info("tokens: {}".format(" ".join([str(x) for x in tokenizer.tokenize(example.text_a)]))) 89 | logger.info("input_ids: {}".format(" ".join([str(x) for x in features[i].input_ids]))) 90 | logger.info("attention_mask: {}".format(" ".join([str(x) for x in features[i].attention_mask]))) 91 | logger.info("token_type_ids: {}".format(" ".join([str(x) for x in features[i].token_type_ids]))) 92 | logger.info("label: {}".format(" ".join([str(x) for x in features[i].label]))) 93 | 94 | return features 95 | 96 | 97 | class GoEmotionsProcessor(object): 98 | """Processor for the GoEmotions data set """ 99 | 100 | def __init__(self, args): 101 | self.args = args 102 | 103 | def get_labels(self): 104 | labels = [] 105 | with open(os.path.join(self.args.data_dir, self.args.label_file), "r", encoding="utf-8") as f: 106 | for line in f: 107 | labels.append(line.rstrip()) 108 | return labels 109 | 110 | @classmethod 111 | def _read_file(cls, input_file): 112 | """Reads a tab separated value file.""" 113 | with open(input_file, "r", encoding="utf-8") as f: 114 | return f.readlines() 115 | 116 | def _create_examples(self, lines, set_type): 117 | """ Creates examples for 
the train, dev and test sets.""" 118 | examples = [] 119 | for (i, line) in enumerate(lines): 120 | guid = "%s-%s" % (set_type, i) 121 | line = line.strip() 122 | items = line.split("\t") 123 | text_a = items[0] 124 | label = list(map(int, items[1].split(","))) 125 | if i % 5000 == 0: 126 | logger.info(line) 127 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 128 | return examples 129 | 130 | def get_examples(self, mode): 131 | """ 132 | Args: 133 | mode: train, dev, test 134 | """ 135 | file_to_read = None 136 | if mode == 'train': 137 | file_to_read = self.args.train_file 138 | elif mode == 'dev': 139 | file_to_read = self.args.dev_file 140 | elif mode == 'test': 141 | file_to_read = self.args.test_file 142 | 143 | logger.info("LOOKING AT {}".format(os.path.join(self.args.data_dir, file_to_read))) 144 | return self._create_examples(self._read_file(os.path.join(self.args.data_dir, 145 | file_to_read)), mode) 146 | 147 | 148 | def load_and_cache_examples(args, tokenizer, mode): 149 | processor = GoEmotionsProcessor(args) 150 | # Load data features from cache or dataset file 151 | cached_features_file = os.path.join( 152 | args.data_dir, 153 | "cached_{}_{}_{}_{}".format( 154 | str(args.task), 155 | list(filter(None, args.model_name_or_path.split("/"))).pop(), 156 | str(args.max_seq_len), 157 | mode 158 | ) 159 | ) 160 | if os.path.exists(cached_features_file): 161 | logger.info("Loading features from cached file %s", cached_features_file) 162 | features = torch.load(cached_features_file) 163 | else: 164 | logger.info("Creating features from dataset file at %s", args.data_dir) 165 | if mode == "train": 166 | examples = processor.get_examples("train") 167 | elif mode == "dev": 168 | examples = processor.get_examples("dev") 169 | elif mode == "test": 170 | examples = processor.get_examples("test") 171 | else: 172 | raise ValueError("For mode, only train, dev, test is available") 173 | features = convert_examples_to_features( 174 | args, examples, tokenizer, max_length=args.max_seq_len 175 | ) 176 | logger.info("Saving features into cached file %s", cached_features_file) 177 | torch.save(features, cached_features_file) 178 | 179 | # Convert to Tensors and build dataset 180 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 181 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 182 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 183 | all_labels = torch.tensor([f.label for f in features], dtype=torch.float) 184 | 185 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels) 186 | return dataset 187 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import BertPreTrainedModel, BertModel 3 | 4 | 5 | class BertForMultiLabelClassification(BertPreTrainedModel): 6 | def __init__(self, config): 7 | super().__init__(config) 8 | self.num_labels = config.num_labels 9 | 10 | self.bert = BertModel(config) 11 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 12 | self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) 13 | self.loss_fct = nn.BCEWithLogitsLoss() 14 | 15 | self.init_weights() 16 | 17 | def forward( 18 | self, 19 | input_ids=None, 20 | attention_mask=None, 21 | token_type_ids=None, 22 | position_ids=None, 23 | 
head_mask=None, 24 | inputs_embeds=None, 25 | labels=None, 26 | ): 27 | outputs = self.bert( 28 | input_ids, 29 | attention_mask=attention_mask, 30 | token_type_ids=token_type_ids, 31 | position_ids=position_ids, 32 | head_mask=head_mask, 33 | inputs_embeds=inputs_embeds, 34 | ) 35 | pooled_output = outputs[1] 36 | 37 | pooled_output = self.dropout(pooled_output) 38 | logits = self.classifier(pooled_output) 39 | 40 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here 41 | 42 | if labels is not None: 43 | loss = self.loss_fct(logits, labels) 44 | outputs = (loss,) + outputs 45 | 46 | return outputs # (loss), logits, (hidden_states), (attentions) 47 | -------------------------------------------------------------------------------- /multilabel_pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | 3 | import numpy as np 4 | from transformers.pipelines import ArgumentHandler 5 | from transformers import ( 6 | Pipeline, 7 | PreTrainedTokenizer, 8 | ModelCard 9 | ) 10 | 11 | 12 | class MultiLabelPipeline(Pipeline): 13 | def __init__( 14 | self, 15 | model: Union["PreTrainedModel", "TFPreTrainedModel"], 16 | tokenizer: PreTrainedTokenizer, 17 | modelcard: Optional[ModelCard] = None, 18 | framework: Optional[str] = None, 19 | task: str = "", 20 | args_parser: ArgumentHandler = None, 21 | device: int = -1, 22 | binary_output: bool = False, 23 | threshold: float = 0.3 24 | ): 25 | super().__init__( 26 | model=model, 27 | tokenizer=tokenizer, 28 | modelcard=modelcard, 29 | framework=framework, 30 | args_parser=args_parser, 31 | device=device, 32 | binary_output=binary_output, 33 | task=task 34 | ) 35 | 36 | self.threshold = threshold 37 | 38 | def __call__(self, *args, **kwargs): 39 | outputs = super().__call__(*args, **kwargs) 40 | scores = 1 / (1 + np.exp(-outputs)) # Sigmoid 41 | results = [] 42 | for item in scores: 43 | labels = [] 44 | scores = [] 45 | for idx, s in enumerate(item): 46 | if s > self.threshold: 47 | labels.append(self.model.config.id2label[idx]) 48 | scores.append(s) 49 | results.append({"labels": labels, "scores": scores}) 50 | return results 51 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | transformers==2.11.0 3 | attrdict==2.0.1 -------------------------------------------------------------------------------- /run_goemotions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import glob 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 10 | from tqdm import tqdm, trange 11 | from attrdict import AttrDict 12 | 13 | from transformers import ( 14 | BertConfig, 15 | BertTokenizer, 16 | AdamW, 17 | get_linear_schedule_with_warmup 18 | ) 19 | 20 | from model import BertForMultiLabelClassification 21 | from utils import ( 22 | init_logger, 23 | set_seed, 24 | compute_metrics 25 | ) 26 | from data_loader import ( 27 | load_and_cache_examples, 28 | GoEmotionsProcessor 29 | ) 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | def train(args, 35 | model, 36 | tokenizer, 37 | train_dataset, 38 | dev_dataset=None, 39 | test_dataset=None): 40 | train_sampler = RandomSampler(train_dataset) 41 | train_dataloader = 
DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 42 | if args.max_steps > 0: 43 | t_total = args.max_steps 44 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 45 | else: 46 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 47 | 48 | # Prepare optimizer and schedule (linear warmup and decay) 49 | no_decay = ['bias', 'LayerNorm.weight'] 50 | optimizer_grouped_parameters = [ 51 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 52 | 'weight_decay': args.weight_decay}, 53 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 54 | ] 55 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 56 | scheduler = get_linear_schedule_with_warmup( 57 | optimizer, 58 | num_warmup_steps=int(t_total * args.warmup_proportion), 59 | num_training_steps=t_total 60 | ) 61 | 62 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( 63 | os.path.join(args.model_name_or_path, "scheduler.pt") 64 | ): 65 | # Load optimizer and scheduler states 66 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 67 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 68 | 69 | # Train! 70 | logger.info("***** Running training *****") 71 | logger.info(" Num examples = %d", len(train_dataset)) 72 | logger.info(" Num Epochs = %d", args.num_train_epochs) 73 | logger.info(" Total train batch size = %d", args.train_batch_size) 74 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 75 | logger.info(" Total optimization steps = %d", t_total) 76 | logger.info(" Logging steps = %d", args.logging_steps) 77 | logger.info(" Save steps = %d", args.save_steps) 78 | 79 | global_step = 0 80 | tr_loss = 0.0 81 | 82 | model.zero_grad() 83 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch") 84 | for _ in train_iterator: 85 | epoch_iterator = tqdm(train_dataloader, desc="Iteration") 86 | for step, batch in enumerate(epoch_iterator): 87 | model.train() 88 | batch = tuple(t.to(args.device) for t in batch) 89 | inputs = { 90 | "input_ids": batch[0], 91 | "attention_mask": batch[1], 92 | "token_type_ids": batch[2], 93 | "labels": batch[3] 94 | } 95 | outputs = model(**inputs) 96 | 97 | loss = outputs[0] 98 | 99 | if args.gradient_accumulation_steps > 1: 100 | loss = loss / args.gradient_accumulation_steps 101 | 102 | loss.backward() 103 | tr_loss += loss.item() 104 | if (step + 1) % args.gradient_accumulation_steps == 0 or ( 105 | len(train_dataloader) <= args.gradient_accumulation_steps 106 | and (step + 1) == len(train_dataloader) 107 | ): 108 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 109 | 110 | optimizer.step() 111 | scheduler.step() 112 | model.zero_grad() 113 | global_step += 1 114 | 115 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 116 | if args.evaluate_test_during_training: 117 | evaluate(args, model, test_dataset, "test", global_step) 118 | else: 119 | evaluate(args, model, dev_dataset, "dev", global_step) 120 | 121 | if args.save_steps > 0 and global_step % args.save_steps == 0: 122 | # Save model checkpoint 123 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 124 | if not os.path.exists(output_dir): 125 | 
os.makedirs(output_dir) 126 | model_to_save = ( 127 | model.module if hasattr(model, "module") else model 128 | ) 129 | model_to_save.save_pretrained(output_dir) 130 | tokenizer.save_pretrained(output_dir) 131 | 132 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 133 | logger.info("Saving model checkpoint to {}".format(output_dir)) 134 | 135 | if args.save_optimizer: 136 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 137 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 138 | logger.info("Saving optimizer and scheduler states to {}".format(output_dir)) 139 | 140 | if args.max_steps > 0 and global_step > args.max_steps: 141 | break 142 | 143 | if args.max_steps > 0 and global_step > args.max_steps: 144 | break 145 | 146 | return global_step, tr_loss / global_step 147 | 148 | 149 | def evaluate(args, model, eval_dataset, mode, global_step=None): 150 | results = {} 151 | eval_sampler = SequentialSampler(eval_dataset) 152 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 153 | 154 | # Eval! 155 | if global_step != None: 156 | logger.info("***** Running evaluation on {} dataset ({} step) *****".format(mode, global_step)) 157 | else: 158 | logger.info("***** Running evaluation on {} dataset *****".format(mode)) 159 | logger.info(" Num examples = {}".format(len(eval_dataset))) 160 | logger.info(" Eval Batch size = {}".format(args.eval_batch_size)) 161 | eval_loss = 0.0 162 | nb_eval_steps = 0 163 | preds = None 164 | out_label_ids = None 165 | 166 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 167 | model.eval() 168 | batch = tuple(t.to(args.device) for t in batch) 169 | 170 | with torch.no_grad(): 171 | inputs = { 172 | "input_ids": batch[0], 173 | "attention_mask": batch[1], 174 | "token_type_ids": batch[2], 175 | "labels": batch[3] 176 | } 177 | outputs = model(**inputs) 178 | tmp_eval_loss, logits = outputs[:2] 179 | 180 | eval_loss += tmp_eval_loss.mean().item() 181 | nb_eval_steps += 1 182 | if preds is None: 183 | preds = 1 / (1 + np.exp(-logits.detach().cpu().numpy())) # Sigmoid 184 | out_label_ids = inputs["labels"].detach().cpu().numpy() 185 | else: 186 | preds = np.append(preds, 1 / (1 + np.exp(-logits.detach().cpu().numpy())), axis=0) # Sigmoid 187 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) 188 | 189 | eval_loss = eval_loss / nb_eval_steps 190 | results = { 191 | "loss": eval_loss 192 | } 193 | preds[preds > args.threshold] = 1 194 | preds[preds <= args.threshold] = 0 195 | result = compute_metrics(out_label_ids, preds) 196 | results.update(result) 197 | 198 | output_dir = os.path.join(args.output_dir, mode) 199 | if not os.path.exists(output_dir): 200 | os.makedirs(output_dir) 201 | 202 | output_eval_file = os.path.join(output_dir, "{}-{}.txt".format(mode, global_step) if global_step else "{}.txt".format(mode)) 203 | with open(output_eval_file, "w") as f_w: 204 | logger.info("***** Eval results on {} dataset *****".format(mode)) 205 | for key in sorted(results.keys()): 206 | logger.info(" {} = {}".format(key, str(results[key]))) 207 | f_w.write(" {} = {}\n".format(key, str(results[key]))) 208 | 209 | return results 210 | 211 | 212 | def main(cli_args): 213 | # Read from config file and make args 214 | config_filename = "{}.json".format(cli_args.taxonomy) 215 | with open(os.path.join("config", config_filename)) as f: 216 | args = AttrDict(json.load(f)) 217 | logger.info("Training/evaluation 
parameters {}".format(args)) 218 | 219 | args.output_dir = os.path.join(args.ckpt_dir, args.output_dir) 220 | 221 | init_logger() 222 | set_seed(args) 223 | 224 | processor = GoEmotionsProcessor(args) 225 | label_list = processor.get_labels() 226 | 227 | config = BertConfig.from_pretrained( 228 | args.model_name_or_path, 229 | num_labels=len(label_list), 230 | finetuning_task=args.task, 231 | id2label={str(i): label for i, label in enumerate(label_list)}, 232 | label2id={label: i for i, label in enumerate(label_list)} 233 | ) 234 | tokenizer = BertTokenizer.from_pretrained( 235 | args.tokenizer_name_or_path, 236 | ) 237 | model = BertForMultiLabelClassification.from_pretrained( 238 | args.model_name_or_path, 239 | config=config 240 | ) 241 | 242 | # GPU or CPU 243 | args.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" 244 | model.to(args.device) 245 | 246 | # Load dataset 247 | train_dataset = load_and_cache_examples(args, tokenizer, mode="train") if args.train_file else None 248 | dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev") if args.dev_file else None 249 | test_dataset = load_and_cache_examples(args, tokenizer, mode="test") if args.test_file else None 250 | 251 | if dev_dataset is None: 252 | args.evaluate_test_during_training = True # If there is no dev dataset, only use test dataset 253 | 254 | if args.do_train: 255 | global_step, tr_loss = train(args, model, tokenizer, train_dataset, dev_dataset, test_dataset) 256 | logger.info(" global_step = {}, average loss = {}".format(global_step, tr_loss)) 257 | 258 | results = {} 259 | if args.do_eval: 260 | checkpoints = list( 261 | os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + "pytorch_model.bin", recursive=True)) 262 | ) 263 | if not args.eval_all_checkpoints: 264 | checkpoints = checkpoints[-1:] 265 | else: 266 | logging.getLogger("transformers.configuration_utils").setLevel(logging.WARN) # Reduce logging 267 | logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 268 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 269 | for checkpoint in checkpoints: 270 | global_step = checkpoint.split("-")[-1] 271 | model = BertForMultiLabelClassification.from_pretrained(checkpoint) 272 | model.to(args.device) 273 | result = evaluate(args, model, test_dataset, mode="test", global_step=global_step) 274 | result = dict((k + "_{}".format(global_step), v) for k, v in result.items()) 275 | results.update(result) 276 | 277 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 278 | with open(output_eval_file, "w") as f_w: 279 | for key in sorted(results.keys()): 280 | f_w.write("{} = {}\n".format(key, str(results[key]))) 281 | 282 | 283 | if __name__ == '__main__': 284 | cli_parser = argparse.ArgumentParser() 285 | 286 | cli_parser.add_argument("--taxonomy", type=str, required=True, help="Taxonomy (original, ekman, group)") 287 | 288 | cli_args = cli_parser.parse_args() 289 | 290 | main(cli_args) 291 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from sklearn.metrics import precision_recall_fscore_support, accuracy_score 9 | 10 | 11 | def init_logger(): 12 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 13 | datefmt='%m/%d/%Y %H:%M:%S', 14 | 
level=logging.INFO) 15 | 16 | 17 | def set_seed(args): 18 | random.seed(args.seed) 19 | np.random.seed(args.seed) 20 | torch.manual_seed(args.seed) 21 | if not args.no_cuda and torch.cuda.is_available(): 22 | torch.cuda.manual_seed_all(args.seed) 23 | 24 | 25 | def compute_metrics(labels, preds): 26 | assert len(preds) == len(labels) 27 | results = dict() 28 | 29 | results["accuracy"] = accuracy_score(labels, preds) 30 | results["macro_precision"], results["macro_recall"], results[ 31 | "macro_f1"], _ = precision_recall_fscore_support( 32 | labels, preds, average="macro") 33 | results["micro_precision"], results["micro_recall"], results[ 34 | "micro_f1"], _ = precision_recall_fscore_support( 35 | labels, preds, average="micro") 36 | results["weighted_precision"], results["weighted_recall"], results[ 37 | "weighted_f1"], _ = precision_recall_fscore_support( 38 | labels, preds, average="weighted") 39 | 40 | return results 41 | --------------------------------------------------------------------------------
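For reference, the figures in the Results table above come from `evaluate()` in `run_goemotions.py`, which applies a sigmoid to the logits, binarizes the probabilities at `threshold` (0.3 in the configs), and passes the resulting multi-hot arrays to `compute_metrics()` from `utils.py`. A small self-contained sketch of that scoring step, using made-up scores rather than real model output:

```python
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Illustrative sigmoid scores for 3 examples over 4 labels (not real model output),
# together with their gold multi-hot labels.
scores = np.array([[0.92, 0.10, 0.05, 0.40],
                   [0.20, 0.85, 0.66, 0.05],
                   [0.60, 0.55, 0.02, 0.01]])
labels = np.array([[1, 0, 0, 1],
                   [0, 1, 1, 0],
                   [1, 0, 0, 0]])

# Same binarization rule as evaluate() in run_goemotions.py.
threshold = 0.3
preds = (scores > threshold).astype(int)

# Same metrics as utils.compute_metrics(): subset accuracy plus macro/micro/weighted F1.
results = {"accuracy": accuracy_score(labels, preds)}
for avg in ("macro", "micro", "weighted"):
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=avg)
    results["{}_f1".format(avg)] = f1

print(results)
```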