├── .gitignore ├── LICENSE ├── README.md ├── config ├── atomic-2020 │ ├── bart-base_baseline │ │ └── v01.yml │ ├── bart-base_experiments │ │ └── v_all_joint_v01.yml │ ├── bart-large_baseline │ │ └── v01.yml │ ├── bart-large_experiments │ │ ├── v01.yml │ │ └── v_all_joint_v01.yml │ └── datasets.yml ├── atomic │ ├── bart-base_baseline │ │ └── v01.yml │ ├── bart-base_experiments │ │ └── v_all_joint_v01.yml │ ├── bart-large_baseline │ │ └── v01.yml │ ├── bart-large_experiments │ │ ├── v01.yml │ │ └── v_all_joint_v01.yml │ └── datasets.yml ├── conceptnet │ ├── bart-base_baseline │ │ └── v01.yml │ ├── bart-base_experiments │ │ └── v_all_joint_v01.yml │ ├── bart-large_baseline │ │ └── v01.yml │ ├── bart-large_experiments │ │ ├── v01.yml │ │ └── v_all_joint_v01.yml │ └── datasets.yml └── tokenizer_config.yml ├── models ├── bart.py ├── distance_func.py ├── head_proj_layer.py ├── loss_func.py └── model_utils.py ├── requirements.txt ├── scripts ├── feature_learn.py ├── finetune.py └── inference.py ├── src ├── data_utils.py ├── feed_model.py ├── finetune │ ├── finetune_trainer.py │ └── finetune_utils.py ├── lr_schedule.py ├── rec_adam.py ├── rec_adam_wrapper.py ├── sampler.py ├── sampler_utils.py ├── tokenizer.py ├── train_utils.py ├── trainer.py └── utils.py └── system_eval ├── automatic_eval.py ├── evaluation ├── LICENSE ├── README.md ├── __init__.py ├── bert_score │ ├── __init__.py │ ├── bert_score.py │ ├── score.py │ └── utils.py ├── bleu │ ├── .gitignore │ ├── LICENSE │ ├── __init__.py │ ├── bleu.py │ └── bleu_scorer.py ├── cider │ ├── __init__.py │ ├── cider.py │ └── cider_scorer.py ├── eval.py ├── meteor │ ├── __init__.py │ ├── meteor-1.5.jar │ ├── meteor.py │ └── meteor_nltk.py └── rouge │ ├── __init__.py │ └── rouge.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | shell/* 2 | *__pycache__ 3 | 4 | human_eval 5 | *.ipynb 6 | *gpt* 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # solar-framework_commonsense-inference 2 | Code release for "Learning from Missing Relations: Contrastive Learning with Commonsense Knowledge Graphs for Commonsense Inference" 3 | 4 | 5 | Download the similarity matrix: 6 | ```bash 7 | gdown https://drive.google.com/uc?id=1EHMIZXP_T1UfSzCWv9Is16n8DRW3dGJx 8 | ``` 9 | 10 | ### Preprocessing 11 | 12 | #### Download knowledge graphs 13 | Commonsense knowledge graph sources: 14 | 15 | * [ConceptNet](https://home.ttic.edu/~kgimpel/commonsense.html) 16 | * [ATOMIC](https://allenai.org/data/atomic) 17 | * [ATOMIC-2020](https://allenai.org/data/atomic-2020) 18 | 19 | #### Simple preprocess 20 | * Download the files above and preprocess each knowledge graph into tab-separated (TSV) format, one `subject \t relation \t object` triple per line. 21 | ``` 22 | # examples 23 | subject1 \t relation1 \t object1 24 | subject2 \t relation2 \t object2 25 | ... 26 | ``` 27 | 28 | * Set the preprocessed data paths in config/{dataset}/datasets.yml 29 | ``` 30 | name: 'atomic' 31 | truncate: 32 | subj_len: 25 33 | obj_len: 25 34 | dir: 35 | train: {your path} 36 | dev: {your path} 37 | test: {your path} 38 | sim: # <- Similarity matrix; download it from the URL above. 39 | train: {your path} 40 | dev: {your path} 41 | 42 | ``` 43 | 44 | ### Fine-tuning 45 | 46 | ``` 47 | python scripts/finetune.py --dataset_type {dataset} --model_name {model_name} --model_size {model_size} 48 | ``` 49 | 50 | ### Pre-training 51 | 52 | ``` 53 | python scripts/feature_learn.py --dataset_type {dataset} --model_name {model_name} --model_size {model_size} 54 | ``` 55 | 56 | -------------------------------------------------------------------------------- /config/atomic-2020/bart-base_baseline/v01.yml: -------------------------------------------------------------------------------- 1 | # Fine-Tuning 2 | 3 | model: 4 | name: 'bart-base' 5 | pretrained_model: 'facebook/bart-base' 6 | tokenize_model: 'facebook/bart-base' 7 | 8 | opt: 9 | lr_scheduler: "linear" 10 | warmup_steps: 200 11 | clip_grad_norm: 1.0 12 | weight_decay: 0 13 | output_dropout_p: 0.1 14 | optimizer: 'adam' 15 | adam_beta_1: 0.9 16 | adam_beta_2: 0.999 17 | adam_eps: 1E-08 18 | 19 | log: 20 | tb_period: 10 21 | val_period: 1000 22 | save_period: 5000 23 | -------------------------------------------------------------------------------- /config/atomic-2020/bart-base_experiments/v_all_joint_v01.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | task: 3 | max_enc: 55 4 | max_dec: 55 5 | init_with_shuffle: true # Only For Training dataset 6 | gen_task: 7 | weight: 0.0 8 | loss_func: "cross-entropy" 9 | con_task: 10 | weight: 0.8 11 | method: "cluster" # cluster, naive 12 | cluster_contrast: 13 | group_num: 32 14 | pos_subj_min: 0.85 15 | sampling_method: "adv" # random, adv 16 | adv_sampling: 17 | min_sim: 0.2 18 | max_sim: 0.7 19 | loss_func: "NT-Logistic" 20 | rec_inf_shu_task: 21 | weight: 0.2 22 | method: "naive" 23 | loss_func: "cross-entropy" 24 | no_crpt_prob: 0.25 25 | subj_crpt_prob: 0.25 26 | rel_crpt_prob: 0.25 27 | obj_crpt_prob: 0.25 28 | shuffle_prob: 0.5 29 | den_task: 30 | weight: 0.0 31 | method: "naive" 32 | subj_mask_prob: 0.33 33 | rel_mask_prob: 0.33 34 | obj_mask_prob: 0.34 35 | hint_prob: 0.3 36 | hint_from_the_front: true 37 | loss_func: "cross-entropy" 38 | 39 | model: 40 | name: 'bart-base' 41 |
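# (added, hedged note: the task weights above read as coefficients of a joint objective, roughly L = 0.8 * L_con + 0.2 * L_rec_inf_shu, with gen_task and den_task disabled at weight 0.0 in this all-joint variant; the authoritative combination lives in src/trainer.py)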
pretrained_model: 'facebook/bart-base' 42 | tokenize_model: 'facebook/bart-base' 43 | task_adaptor_options: 44 | common: 45 | use_task_prefix: true 46 | con_task: 47 | format: 'enc-dec' 48 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | 49 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id 50 | rec_inf_shu_task: 51 | format: 'enc-dec' 52 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id 53 | 54 | contrastive_head: 55 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear 56 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint, 57 | multi-head: 58 | head_num: 8 59 | head_dim: 512 60 | pool_method: "maxpool" 61 | multi-head-deeper: 62 | head_num: 12 63 | inner_hidden_mul: 3 64 | head_dim: 512 65 | pool_method: "maxpool" 66 | 67 | opt: 68 | lr_scheduler: "linear" 69 | warmup_steps: 200 70 | clip_grad_norm: 1.0 71 | weight_decay: 0 72 | output_dropout_p: 0.1 73 | optimizer: 'adam' 74 | adam_beta_1: 0.9 75 | adam_beta_2: 0.999 76 | adam_eps: 1E-08 77 | temperature: 0.1 #0.1 78 | use_l2: true 79 | log: 80 | tb_period: 10 81 | val_period: 1000 82 | save_period: 10000 83 | -------------------------------------------------------------------------------- /config/atomic-2020/bart-large_baseline/v01.yml: -------------------------------------------------------------------------------- 1 | # Fine-Tuning 2 | 3 | model: 4 | name: 'bart-large' 5 | pretrained_model: 'facebook/bart-large' 6 | tokenize_model: 'facebook/bart-large' 7 | 8 | opt: 9 | lr_scheduler: "linear" 10 | warmup_steps: 200 11 | clip_grad_norm: 1.0 12 | weight_decay: 0 13 | output_dropout_p: 0.1 14 | optimizer: 'adam' 15 | adam_beta_1: 0.9 16 | adam_beta_2: 0.999 17 | adam_eps: 1E-08 18 | 19 | log: 20 | tb_period: 10 21 | val_period: 1000 22 | save_period: 5000 23 | -------------------------------------------------------------------------------- /config/atomic-2020/bart-large_experiments/v01.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | 3 | task: 4 | max_enc: 55 5 | max_dec: 55 6 | init_with_shuffle: true # Only For Training dataset 7 | gen_task: 8 | weight: 0.0 9 | loss_func: "cross-entropy" 10 | con_task: 11 | weight: 0.8 12 | method: "cluster" # cluster, naive 13 | cluster_contrast: 14 | group_num: 16 15 | pos_subj_min: 0.75 16 | sampling_method: "adv" # random, adv 17 | adv_sampling: 18 | min_sim: 0.4 19 | max_sim: 0.6 20 | loss_func: "NT-Logistic" 21 | rec_inf_shu_task: 22 | weight: 0.2 23 | method: "naive" 24 | loss_func: "cross-entropy" 25 | no_crpt_prob: 0.25 26 | subj_crpt_prob: 0.25 27 | rel_crpt_prob: 0.25 28 | obj_crpt_prob: 0.25 29 | shuffle_prob: 0.5 30 | den_task: 31 | weight: 0.0 32 | method: "naive" 33 | subj_mask_prob: 0.33 34 | rel_mask_prob: 0.33 35 | obj_mask_prob: 0.34 36 | hint_prob: 0.3 37 | hint_from_the_front: true 38 | loss_func: "cross-entropy" 39 | 40 | model: 41 | name: 'bart-base' 42 | pretrained_model: 'facebook/bart-base' 43 | tokenize_model: 'facebook/bart-base' 44 | task_adaptor_options: 45 | common: 46 | use_task_prefix: true 47 | con_task: 48 | format: 'enc-dec' 49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | 50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id 51 | rec_inf_shu_task: 52 | format: 'enc-dec' 53 | dec_input_ids: 
'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id 54 | 55 | contrastive_head: 56 | name: 'multi-head' # multi-head-deeper | multi-head | non-linear 57 | multi-head: 58 | head_num: 4 59 | head_dim: 512 60 | pool_method: "maxpool" 61 | multi-head-deeper: 62 | head_num: 12 63 | inner_hidden_mul: 3 64 | head_dim: 512 65 | pool_method: "maxpool" 66 | 67 | opt: 68 | lr_scheduler: "linear" 69 | warmup_steps: 200 70 | clip_grad_norm: 1.0 71 | weight_decay: 0 72 | output_dropout_p: 0.1 73 | optimizer: 'adam' 74 | adam_beta_1: 0.9 75 | adam_beta_2: 0.999 76 | adam_eps: 1E-08 77 | temperature: 0.1 #0.1 78 | use_l2: true 79 | log: 80 | tb_period: 10 81 | val_period: 1000 82 | save_period: 10000 83 | -------------------------------------------------------------------------------- /config/atomic-2020/bart-large_experiments/v_all_joint_v01.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | task: 3 | max_enc: 25 4 | max_dec: 25 5 | init_with_shuffle: true # Only For Training dataset 6 | gen_task: 7 | weight: 0.0 8 | loss_func: "cross-entropy" 9 | con_task: 10 | weight: 0.8 11 | method: "cluster" # cluster, naive 12 | cluster_contrast: 13 | group_num: 4 14 | pos_subj_min: 0.85 15 | sampling_method: "adv" # random, adv 16 | adv_sampling: 17 | min_sim: 0.2 18 | max_sim: 0.7 19 | loss_func: "NT-Logistic" 20 | rec_inf_shu_task: 21 | weight: 0.2 22 | method: "naive" 23 | loss_func: "cross-entropy" 24 | no_crpt_prob: 0.25 25 | subj_crpt_prob: 0.25 26 | rel_crpt_prob: 0.25 27 | obj_crpt_prob: 0.25 28 | shuffle_prob: 0.5 29 | den_task: 30 | weight: 0.0 31 | method: "naive" 32 | subj_mask_prob: 0.33 33 | rel_mask_prob: 0.33 34 | obj_mask_prob: 0.34 35 | hint_prob: 0.3 36 | hint_from_the_front: true 37 | loss_func: "cross-entropy" 38 | 39 | model: 40 | name: 'bart-large' 41 | pretrained_model: 'facebook/bart-large' 42 | tokenize_model: 'facebook/bart-large' 43 | task_adaptor_options: 44 | common: 45 | use_task_prefix: true 46 | con_task: 47 | format: 'enc-dec' 48 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | 49 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id 50 | rec_inf_shu_task: 51 | format: 'enc-dec' 52 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id 53 | 54 | contrastive_head: 55 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear 56 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint, 57 | multi-head: 58 | head_num: 8 59 | head_dim: 512 60 | pool_method: "maxpool" 61 | multi-head-deeper: 62 | head_num: 12 63 | inner_hidden_mul: 3 64 | head_dim: 512 65 | pool_method: "maxpool" 66 | 67 | opt: 68 | lr_scheduler: "linear" 69 | warmup_steps: 200 70 | clip_grad_norm: 1.0 71 | weight_decay: 0 72 | output_dropout_p: 0.1 73 | optimizer: 'adam' 74 | adam_beta_1: 0.9 75 | adam_beta_2: 0.999 76 | adam_eps: 1E-08 77 | temperature: 0.1 #0.1 78 | use_l2: true 79 | log: 80 | tb_period: 10 81 | val_period: 1000 82 | save_period: 10000 83 | -------------------------------------------------------------------------------- /config/atomic-2020/datasets.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | 3 | name: 'atomic-2020' 4 | truncate: 5 | subj_len: 25 6 | obj_len: 25 7 | dir: 8 | train: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic-2020/train.tsv' 9 | dev: 
'/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic-2020/dev.tsv' 10 | test : '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic-2020/test.tsv' 11 | sim: 12 | train: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/atomic-2020/train_subj_dist__a_concept_related_to.pkl' 13 | dev: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/atomic-2020/dev_subj_dist__a_concept_related_to.pkl' 14 | -------------------------------------------------------------------------------- /config/atomic/bart-base_baseline/v01.yml: -------------------------------------------------------------------------------- 1 | # Fine-Tuning 2 | 3 | model: 4 | name: 'bart-base' 5 | pretrained_model: 'facebook/bart-base' 6 | tokenize_model: 'facebook/bart-base' 7 | 8 | opt: 9 | lr_scheduler: "linear" 10 | warmup_steps: 200 11 | clip_grad_norm: 1.0 12 | weight_decay: 0 13 | output_dropout_p: 0.1 14 | optimizer: 'adam' 15 | adam_beta_1: 0.9 16 | adam_beta_2: 0.999 17 | adam_eps: 1E-08 18 | 19 | log: 20 | tb_period: 10 21 | val_period: 1000 22 | save_period: 5000 23 | -------------------------------------------------------------------------------- /config/atomic/bart-base_experiments/v_all_joint_v01.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | 3 | task: 4 | max_enc: 25 5 | max_dec: 25 6 | init_with_shuffle: true # Only For Training dataset 7 | gen_task: 8 | weight: 0.0 9 | loss_func: "cross-entropy" 10 | con_task: 11 | weight: 0.8 12 | method: "cluster" # cluster, naive 13 | cluster_contrast: 14 | group_num: 32 15 | pos_subj_min: 0.8 16 | sampling_method: "adv" # random, adv 17 | adv_sampling: 18 | min_sim: 0.2 19 | max_sim: 0.7 20 | loss_func: "NT-Logistic" 21 | rec_inf_shu_task: 22 | weight: 0.2 23 | method: "naive" 24 | loss_func: "cross-entropy" 25 | no_crpt_prob: 0.25 26 | subj_crpt_prob: 0.25 27 | rel_crpt_prob: 0.25 28 | obj_crpt_prob: 0.25 29 | shuffle_prob: 0.5 30 | den_task: 31 | weight: 0.0 32 | method: "naive" 33 | subj_mask_prob: 0.33 34 | rel_mask_prob: 0.33 35 | obj_mask_prob: 0.34 36 | hint_prob: 0.3 37 | hint_from_the_front: true 38 | loss_func: "cross-entropy" 39 | 40 | model: 41 | name: 'bart-base' 42 | pretrained_model: 'facebook/bart-base' 43 | tokenize_model: 'facebook/bart-base' 44 | task_adaptor_options: 45 | common: 46 | use_task_prefix: true 47 | con_task: 48 | format: 'enc-dec' 49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | 50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id 51 | rec_inf_shu_task: 52 | format: 'enc-dec' 53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id 54 | 55 | 56 | contrastive_head: 57 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear 58 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint, 59 | multi-head: 60 | head_num: 8 61 | head_dim: 512 62 | pool_method: "maxpool" 63 | multi-head-deeper: 64 | head_num: 12 65 | inner_hidden_mul: 3 66 | head_dim: 512 67 | pool_method: "maxpool" 68 | 69 | opt: 70 | lr_scheduler: "linear" 71 | warmup_steps: 200 72 | clip_grad_norm: 1.0 73 | weight_decay: 0 74 | output_dropout_p: 0.1 75 | optimizer: 'adam' 76 | adam_beta_1: 0.9 77 | adam_beta_2: 0.999 78 | adam_eps: 1E-08 79 | temperature: 0.1 #0.1 80 | use_l2: true 81 | log: 82 | tb_period: 10 83 | val_period: 1000 84 | save_period: 10000 85 | 
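The adv_sampling bounds in the experiment config above (min_sim: 0.2, max_sim: 0.7) act as a band-pass filter on negatives: a candidate tuple must be similar enough to the anchor's subject to be a non-trivial negative, but not so similar that it is likely an unlabeled positive. Below is a minimal sketch of that filtering over one row of the precomputed similarity matrix; the helper name `sample_adv_negatives` and the fallback rule are illustrative assumptions, not the actual src/sampler.py API:

```python
import numpy as np


def sample_adv_negatives(sim_row, anchor_idx, k, min_sim=0.2, max_sim=0.7, rng=None):
    """Pick k negative indices whose subject similarity to the anchor lies in [min_sim, max_sim]."""
    rng = rng if rng is not None else np.random.default_rng()
    # Band-pass filter: drop trivially easy (< min_sim) and near-positive (> max_sim) candidates.
    candidates = np.where((sim_row >= min_sim) & (sim_row <= max_sim))[0]
    candidates = candidates[candidates != anchor_idx]
    if candidates.size == 0:
        # Fallback: sample uniformly from everything except the anchor itself.
        candidates = np.delete(np.arange(sim_row.shape[0]), anchor_idx)
    return rng.choice(candidates, size=min(k, candidates.size), replace=False)
```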
-------------------------------------------------------------------------------- /config/atomic/bart-large_baseline/v01.yml: -------------------------------------------------------------------------------- 1 | # Fine-Tuning 2 | 3 | model: 4 | name: 'bart-large' 5 | pretrained_model: 'facebook/bart-large' 6 | tokenize_model: 'facebook/bart-large' 7 | 8 | opt: 9 | lr_scheduler: "linear" 10 | warmup_steps: 200 11 | clip_grad_norm: 1.0 12 | weight_decay: 0 13 | output_dropout_p: 0.1 14 | optimizer: 'adam' 15 | adam_beta_1: 0.9 16 | adam_beta_2: 0.999 17 | adam_eps: 1E-08 18 | 19 | log: 20 | tb_period: 10 21 | val_period: 1000 22 | save_period: 5000 23 | -------------------------------------------------------------------------------- /config/atomic/bart-large_experiments/v01.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | 3 | task: 4 | max_enc: 55 5 | max_dec: 55 6 | init_with_shuffle: true # Only For Training dataset 7 | gen_task: 8 | weight: 0.0 9 | loss_func: "cross-entropy" 10 | con_task: 11 | weight: 0.8 12 | method: "cluster" # cluster, naive 13 | cluster_contrast: 14 | group_num: 32 15 | pos_subj_min: 0.85 16 | sampling_method: "adv" # random, adv 17 | adv_sampling: 18 | min_sim: 0.2 19 | max_sim: 0.7 20 | loss_func: "NT-Logistic" 21 | rec_inf_shu_task: 22 | weight: 0.2 23 | method: "naive" 24 | loss_func: "cross-entropy" 25 | no_crpt_prob: 0.25 26 | subj_crpt_prob: 0.25 27 | rel_crpt_prob: 0.25 28 | obj_crpt_prob: 0.25 29 | shuffle_prob: 0.5 30 | den_task: 31 | weight: 0.0 32 | method: "naive" 33 | subj_mask_prob: 0.33 34 | rel_mask_prob: 0.33 35 | obj_mask_prob: 0.34 36 | hint_prob: 0.3 37 | hint_from_the_front: true 38 | loss_func: "cross-entropy" 39 | 40 | model: 41 | name: 'bart-large' 42 | pretrained_model: 'facebook/bart-large' 43 | tokenize_model: 'facebook/bart-large' 44 | task_adaptor_options: 45 | common: 46 | use_task_prefix: true 47 | con_task: 48 | format: 'enc-dec' 49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | 50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id 51 | den_task: 52 | format: 'naive' 53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id 54 | rec_inf_shu_task: 55 | format: 'naive' 56 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id 57 | 58 | contrastive_head: 59 | name: 'multi-head' # multi-head-deeper | multi-head | non-linear 60 | multi-head: 61 | head_num: 8 62 | head_dim: 512 63 | pool_method: "maxpool" 64 | multi-head-deeper: 65 | head_num: 12 66 | inner_hidden_mul: 3 67 | head_dim: 512 68 | pool_method: "maxpool" 69 | 70 | opt: 71 | lr_scheduler: "linear" 72 | warmup_steps: 200 73 | clip_grad_norm: 1.0 74 | weight_decay: 0 75 | output_dropout_p: 0.1 76 | optimizer: 'adam' 77 | adam_beta_1: 0.9 78 | adam_beta_2: 0.999 79 | adam_eps: 1E-08 80 | temperature: 0.1 #0.1 81 | use_l2: true 82 | log: 83 | tb_period: 10 84 | val_period: 1000 85 | save_period: 10000 86 | -------------------------------------------------------------------------------- /config/atomic/bart-large_experiments/v_all_joint_v01.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | 3 | task: 4 | max_enc: 25 5 | max_dec: 25 6 | init_with_shuffle: true # Only For Training dataset 7 | gen_task: 8 | weight: 0.0 9 | loss_func: "cross-entropy" 10 | con_task: 11 | weight: 
0.8 12 | method: "cluster" # cluster, naive 13 | cluster_contrast: 14 | group_num: 4 15 | pos_subj_min: 0.85 16 | sampling_method: "adv" # random, adv 17 | adv_sampling: 18 | min_sim: 0.2 19 | max_sim: 0.7 20 | loss_func: "NT-Logistic" 21 | rec_inf_shu_task: 22 | weight: 0.2 23 | method: "naive" 24 | loss_func: "cross-entropy" 25 | no_crpt_prob: 0.25 26 | subj_crpt_prob: 0.25 27 | rel_crpt_prob: 0.25 28 | obj_crpt_prob: 0.25 29 | shuffle_prob: 0.5 30 | den_task: 31 | weight: 0.0 32 | method: "naive" 33 | subj_mask_prob: 0.33 34 | rel_mask_prob: 0.33 35 | obj_mask_prob: 0.34 36 | hint_prob: 0.3 37 | hint_from_the_front: true 38 | loss_func: "cross-entropy" 39 | 40 | model: 41 | name: 'bart-large' 42 | pretrained_model: 'facebook/bart-large' 43 | tokenize_model: 'facebook/bart-large' 44 | task_adaptor_options: 45 | common: 46 | use_task_prefix: true 47 | con_task: 48 | format: 'enc-dec' 49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | 50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id 51 | rec_inf_shu_task: 52 | format: 'enc-dec' 53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id 54 | 55 | 56 | contrastive_head: 57 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear 58 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint, 59 | multi-head: 60 | head_num: 8 61 | head_dim: 512 62 | pool_method: "maxpool" 63 | multi-head-deeper: 64 | head_num: 12 65 | inner_hidden_mul: 3 66 | head_dim: 512 67 | pool_method: "maxpool" 68 | 69 | opt: 70 | lr_scheduler: "linear" 71 | warmup_steps: 200 72 | clip_grad_norm: 1.0 73 | weight_decay: 0 74 | output_dropout_p: 0.1 75 | optimizer: 'adam' 76 | adam_beta_1: 0.9 77 | adam_beta_2: 0.999 78 | adam_eps: 1E-08 79 | temperature: 0.1 #0.1 80 | use_l2: true 81 | log: 82 | tb_period: 10 83 | val_period: 1000 84 | save_period: 10000 85 | -------------------------------------------------------------------------------- /config/atomic/datasets.yml: -------------------------------------------------------------------------------- 1 | name: 'atomic' 2 | truncate: 3 | subj_len: 25 4 | obj_len: 25 5 | dir: 6 | train: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic/train.tsv' 7 | dev: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic/dev.tsv' 8 | test: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic/test.tsv' 9 | sim: 10 | train: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/atomic/train_subj_dist__a_concept_related_to.pkl' 11 | dev: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/atomic/dev_subj_dist__a_concept_related_to.pkl' 12 | -------------------------------------------------------------------------------- /config/conceptnet/bart-base_baseline/v01.yml: -------------------------------------------------------------------------------- 1 | # Fine-Tuning 2 | 3 | model: 4 | name: 'bart-base' 5 | pretrained_model: 'facebook/bart-base' 6 | tokenize_model: 'facebook/bart-base' 7 | 8 | opt: 9 | lr_scheduler: "linear" 10 | warmup_steps: 200 11 | clip_grad_norm: 1.0 12 | weight_decay: 0 13 | output_dropout_p: 0.1 14 | optimizer: 'adam' 15 | adam_beta_1: 0.9 16 | adam_beta_2: 0.999 17 | adam_eps: 1E-08 18 | 19 | log: 20 | tb_period: 10 21 | val_period: 1000 22 | save_period: 5000 23 | -------------------------------------------------------------------------------- 
/config/conceptnet/bart-base_experiments/v_all_joint_v01.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | 3 | task: 4 | max_enc: 25 5 | max_dec: 25 6 | init_with_shuffle: true # Only For Training dataset 7 | gen_task: 8 | weight: 0.0 9 | loss_func: "cross-entropy" 10 | con_task: 11 | weight: 0.8 12 | method: "cluster" # cluster, naive 13 | cluster_contrast: 14 | group_num: 32 15 | pos_subj_min: 0.8 16 | sampling_method: "adv" # random, adv 17 | adv_sampling: 18 | min_sim: 0.2 19 | max_sim: 0.7 20 | loss_func: "NT-Logistic" 21 | rec_inf_shu_task: 22 | weight: 0.2 23 | method: "naive" 24 | loss_func: "cross-entropy" 25 | no_crpt_prob: 0.25 26 | subj_crpt_prob: 0.25 27 | rel_crpt_prob: 0.25 28 | obj_crpt_prob: 0.25 29 | shuffle_prob: 0.5 30 | den_task: 31 | weight: 0.0 32 | method: "naive" 33 | subj_mask_prob: 0.33 34 | rel_mask_prob: 0.33 35 | obj_mask_prob: 0.34 36 | hint_prob: 0.3 37 | hint_from_the_front: true 38 | loss_func: "cross-entropy" 39 | 40 | model: 41 | name: 'bart-base' 42 | pretrained_model: 'facebook/bart-base' 43 | tokenize_model: 'facebook/bart-base' 44 | task_adaptor_options: 45 | common: 46 | use_task_prefix: true 47 | con_task: 48 | format: 'enc-dec' 49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | 50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id 51 | rec_inf_shu_task: 52 | format: 'enc-dec' 53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id 54 | 55 | 56 | contrastive_head: 57 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear 58 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint, 59 | multi-head: 60 | head_num: 8 61 | head_dim: 512 62 | pool_method: "maxpool" 63 | multi-head-deeper: 64 | head_num: 12 65 | inner_hidden_mul: 3 66 | head_dim: 512 67 | pool_method: "maxpool" 68 | 69 | opt: 70 | lr_scheduler: "linear" 71 | warmup_steps: 200 72 | clip_grad_norm: 1.0 73 | weight_decay: 0 74 | output_dropout_p: 0.1 75 | optimizer: 'adam' 76 | adam_beta_1: 0.9 77 | adam_beta_2: 0.999 78 | adam_eps: 1E-08 79 | temperature: 0.1 #0.1 80 | use_l2: true 81 | log: 82 | tb_period: 10 83 | val_period: 1000 84 | save_period: 10000 85 | -------------------------------------------------------------------------------- /config/conceptnet/bart-large_baseline/v01.yml: -------------------------------------------------------------------------------- 1 | # Fine-Tuning 2 | 3 | model: 4 | name: 'bart-large' 5 | pretrained_model: 'facebook/bart-large' 6 | tokenize_model: 'facebook/bart-large' 7 | 8 | opt: 9 | lr_scheduler: "linear" 10 | warmup_steps: 200 11 | clip_grad_norm: 1.0 12 | weight_decay: 0 13 | output_dropout_p: 0.1 14 | optimizer: 'adam' 15 | adam_beta_1: 0.9 16 | adam_beta_2: 0.999 17 | adam_eps: 1E-08 18 | 19 | log: 20 | tb_period: 10 21 | val_period: 1000 22 | save_period: 5000 -------------------------------------------------------------------------------- /config/conceptnet/bart-large_experiments/v01.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | 3 | task: 4 | max_enc: 55 5 | max_dec: 55 6 | init_with_shuffle: true # Only For Training dataset 7 | gen_task: 8 | weight: 0.0 9 | loss_func: "cross-entropy" 10 | con_task: 11 | weight: 0.8 12 | method: "cluster" # cluster, naive 13 | cluster_contrast: 14 | group_num: 32 15 | pos_subj_min: 0.85 16 | 
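# (added, hedged note: judging by the key names, the 'cluster' method buckets each batch into group_num clusters and treats tuples whose subject similarity reaches pos_subj_min as positives of the anchor; the authoritative rule is in src/sampler.py)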
sampling_method: "adv" # random, adv 17 | adv_sampling: 18 | min_sim: 0.2 19 | max_sim: 0.7 20 | loss_func: "NT-Logistic" 21 | rec_inf_shu_task: 22 | weight: 0.2 23 | method: "naive" 24 | loss_func: "cross-entropy" 25 | no_crpt_prob: 0.25 26 | subj_crpt_prob: 0.25 27 | rel_crpt_prob: 0.25 28 | obj_crpt_prob: 0.25 29 | shuffle_prob: 0.5 30 | den_task: 31 | weight: 0.0 32 | method: "naive" 33 | subj_mask_prob: 0.33 34 | rel_mask_prob: 0.33 35 | obj_mask_prob: 0.34 36 | hint_prob: 0.3 37 | hint_from_the_front: true 38 | loss_func: "cross-entropy" 39 | 40 | model: 41 | name: 'bart-large' 42 | pretrained_model: 'facebook/bart-large' 43 | tokenize_model: 'facebook/bart-large' 44 | task_adaptor_options: 45 | common: 46 | use_task_prefix: true 47 | con_task: 48 | format: 'enc-dec' 49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | 50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id 51 | den_task: 52 | format: 'naive' 53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id 54 | rec_inf_shu_task: 55 | format: 'naive' 56 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id 57 | 58 | contrastive_head: 59 | name: 'multi-head' # multi-head-deeper | multi-head | non-linear 60 | multi-head: 61 | head_num: 8 62 | head_dim: 512 63 | pool_method: "maxpool" 64 | multi-head-deeper: 65 | head_num: 12 66 | inner_hidden_mul: 3 67 | head_dim: 512 68 | pool_method: "maxpool" 69 | 70 | opt: 71 | lr_scheduler: "linear" 72 | warmup_steps: 200 73 | clip_grad_norm: 1.0 74 | weight_decay: 0 75 | output_dropout_p: 0.1 76 | optimizer: 'adam' 77 | adam_beta_1: 0.9 78 | adam_beta_2: 0.999 79 | adam_eps: 1E-08 80 | temperature: 0.1 #0.1 81 | use_l2: true 82 | log: 83 | tb_period: 10 84 | val_period: 1000 85 | save_period: 10000 86 | -------------------------------------------------------------------------------- /config/conceptnet/bart-large_experiments/v_all_joint_v01.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | 3 | task: 4 | max_enc: 25 5 | max_dec: 25 6 | init_with_shuffle: true # Only For Training dataset 7 | gen_task: 8 | weight: 0.0 9 | loss_func: "cross-entropy" 10 | con_task: 11 | weight: 0.8 12 | method: "cluster" # cluster, naive 13 | cluster_contrast: 14 | group_num: 4 15 | pos_subj_min: 0.85 16 | sampling_method: "adv" # random, adv 17 | adv_sampling: 18 | min_sim: 0.2 19 | max_sim: 0.7 20 | loss_func: "NT-Logistic" 21 | rec_inf_shu_task: 22 | weight: 0.2 23 | method: "naive" 24 | loss_func: "cross-entropy" 25 | no_crpt_prob: 0.25 26 | subj_crpt_prob: 0.25 27 | rel_crpt_prob: 0.25 28 | obj_crpt_prob: 0.25 29 | shuffle_prob: 0.5 30 | den_task: 31 | weight: 0.0 32 | method: "naive" 33 | subj_mask_prob: 0.33 34 | rel_mask_prob: 0.33 35 | obj_mask_prob: 0.34 36 | hint_prob: 0.3 37 | hint_from_the_front: true 38 | loss_func: "cross-entropy" 39 | 40 | model: 41 | name: 'bart-large' 42 | pretrained_model: 'facebook/bart-large' 43 | tokenize_model: 'facebook/bart-large' 44 | task_adaptor_options: 45 | common: 46 | use_task_prefix: true 47 | con_task: 48 | format: 'enc-dec' 49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | 50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id 51 | rec_inf_shu_task: 52 | format: 'enc-dec' 53 | dec_input_ids: 
'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id 54 | 55 | 56 | contrastive_head: 57 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear 58 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint, 59 | multi-head: 60 | head_num: 8 61 | head_dim: 512 62 | pool_method: "maxpool" 63 | multi-head-deeper: 64 | head_num: 12 65 | inner_hidden_mul: 3 66 | head_dim: 512 67 | pool_method: "maxpool" 68 | 69 | opt: 70 | lr_scheduler: "linear" 71 | warmup_steps: 200 72 | clip_grad_norm: 1.0 73 | weight_decay: 0 74 | output_dropout_p: 0.1 75 | optimizer: 'adam' 76 | adam_beta_1: 0.9 77 | adam_beta_2: 0.999 78 | adam_eps: 1E-08 79 | temperature: 0.1 #0.1 80 | use_l2: true 81 | log: 82 | tb_period: 10 83 | val_period: 1000 84 | save_period: 10000 85 | -------------------------------------------------------------------------------- /config/conceptnet/datasets.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | 3 | name: 'conceptnet' 4 | truncate: 5 | subj_len: 25 6 | obj_len: 25 7 | dir: 8 | train: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/conceptnet/train.tsv' 9 | dev: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/conceptnet/dev.tsv' 10 | test: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/conceptnet/test.tsv' 11 | sim: 12 | train: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/conceptnet/train_subj_dist__a_concept_related_to.pkl' 13 | dev: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/conceptnet/dev_subj_dist__a_concept_related_to.pkl' 14 | -------------------------------------------------------------------------------- /config/tokenizer_config.yml: -------------------------------------------------------------------------------- 1 | # Test Version 2 | 3 | special_tokens: 4 | bos_token: '' 5 | eos_token: '' 6 | pad_token: '' 7 | mask_token: '' 8 | sep_token: '' 9 | 10 | additional_tokens: 11 | gen_token: '' 12 | blk_token: '' 13 | none_token: '' 14 | mask_gen_task_token: '' 15 | mask_gen_task_token_for_dec: '' 16 | gen_task_token: '' 17 | gen_task_token_for_dec: '' 18 | con_task_token: '' 19 | con_task_token_for_s: '' 20 | con_task_token_for_ro: '' 21 | den_task_token: '' 22 | den_task_token_for_dec: '' 23 | rec_task_token: '' 24 | rec_task_token_for_dec: '' 25 | cla_task_token: '' 26 | 27 | relation_tokens: 28 | - '' 29 | - '' 30 | - '' 31 | - '' 32 | - '' 33 | - '' 34 | - '' 35 | - '' 36 | - '' 37 | - '' 38 | - '' 39 | - '' 40 | - '' 41 | - '' 42 | - '' 43 | - '' 44 | - '' 45 | - '' 46 | - '' 47 | - '' 48 | - '' 49 | - '' 50 | - '' 51 | - '' 52 | - '' 53 | - '' 54 | - '' 55 | - '' 56 | - '' 57 | - '' 58 | - '' 59 | - '' 60 | - '' 61 | - '' 62 | - '' 63 | - '' 64 | - '' 65 | - '' 66 | - '' 67 | - '' 68 | - '' 69 | - '' 70 | - '' 71 | - '' 72 | - '' 73 | - '' 74 | - '' 75 | - '' 76 | - '' 77 | - '' 78 | - '' 79 | -------------------------------------------------------------------------------- /models/bart.py: -------------------------------------------------------------------------------- 1 | from transformers import BartModel 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from models.head_proj_layer import * 6 | from transformers import BartTokenizer, BartForConditionalGeneration, BartModel 7 | from models.model_utils import convert_model 8 | from torch.nn import Linear 9 | 10 | 11 | class CometBART(BartModel): 12 | def modify_lm_heads(self, vocab_len, lm_head_dropout_p): 13 | 
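# Grow the embedding matrix to cover the added special/relation tokens, then attach a fresh LM head (hidden_size -> vocab_len) with its own dropout.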
self.resize_token_embeddings(vocab_len) 14 | self.lm_head = nn.Linear(self.config.hidden_size, vocab_len) 15 | self.lm_head_dropout = nn.Dropout(p=lm_head_dropout_p) 16 | 17 | def add_proj_layer(self, proj_options, proj_head_dropout_p, shared_proj_layer=True): 18 | self.proj_head_dropout = nn.Dropout(p=proj_head_dropout_p) 19 | self.proj_options = proj_options 20 | self.seq_method = proj_options['pool_method'] 21 | if shared_proj_layer is True: 22 | self.proj_head = self._get_proj_head() 23 | else: 24 | self.proj_head_s = self._get_proj_head() 25 | self.proj_head_ro = self._get_proj_head() 26 | 27 | def _get_proj_head(self): 28 | proj_layer_type = self.proj_options['proj_layer_type'] 29 | hidden_size = self.config.hidden_size 30 | 31 | if proj_layer_type == 'non-linear': 32 | head_layer = NonLinearHeadProjLayer(hidden_size) 33 | elif proj_layer_type == 'multi-head': 34 | sub_opt = self.proj_options[proj_layer_type] 35 | head_num = sub_opt['head_num'] 36 | head_dim = int(hidden_size / head_num) if sub_opt['head_dim'] == -1 else sub_opt['head_dim'] 37 | head_layer = MultiHeadProjLayer(hidden_size, head_num, head_dim) elif proj_layer_type == 'multi-head-deeper': sub_opt = self.proj_options[proj_layer_type] head_num = sub_opt['head_num'] head_dim = int(hidden_size / head_num) if sub_opt['head_dim'] == -1 else sub_opt['head_dim'] head_layer = MultiHeadProjLayerDeeper(hidden_size, head_num, head_dim, sub_opt['inner_hidden_mul']) 38 | else: 39 | raise NotImplementedError 40 | 41 | return head_layer 42 | 43 | def forward_conditional_gen(self, enc_input_ids, enc_att_mask, dec_input_ids, dec_att_mask): 44 | outputs = super().forward(input_ids=enc_input_ids, attention_mask=enc_att_mask, 45 | decoder_input_ids=dec_input_ids, decoder_attention_mask=dec_att_mask) 46 | 47 | last_hidden_state = outputs.last_hidden_state 48 | last_hidden_state = self.lm_head_dropout(last_hidden_state) 49 | lm_logits = self.lm_head(last_hidden_state) 50 | 51 | return lm_logits 52 | 53 | def forward_latent_feature(self, enc_input_ids, enc_att_mask, dec_input_ids, dec_att_mask, for_s=None): 54 | outputs = super().forward(input_ids=enc_input_ids, attention_mask=enc_att_mask, 55 | decoder_input_ids=dec_input_ids, decoder_attention_mask=dec_att_mask) 56 | 57 | seq_feature = self.get_sequence_feature(outputs, enc_att_mask) 58 | seq_feature = self.proj_head_dropout(seq_feature) 59 | latent_vec = self._forward_projection(seq_feature, for_s) 60 | 61 | return latent_vec 62 | 63 | def get_sequence_feature(self, outputs, enc_att_mask): 64 | if self.seq_method == 'dec_bos': 65 | dec_last_hidden_state = outputs.last_hidden_state 66 | seq_feature = self._get_seq_feature_from_dec_bos(dec_last_hidden_state) 67 | elif self.seq_method == 'mean_pool_enc': 68 | enc_last_hidden_state = outputs.encoder_last_hidden_state 69 | seq_feature = self._get_seq_feature_from_mean_pool_enc(enc_last_hidden_state, enc_att_mask) 70 | elif self.seq_method == 'all_joint': 71 | enc_last_hidden_state = outputs.encoder_last_hidden_state 72 | dec_last_hidden_state = outputs.last_hidden_state 73 | seq_feature = self._get_seq_feature_from_mean_pool_enc_dec( 74 | enc_last_hidden_state, enc_att_mask, dec_last_hidden_state) 75 | else: 76 | raise NotImplementedError 77 | 78 | return seq_feature 79 | 80 | def _get_seq_feature_from_dec_bos(self, dec_last_hidden_state): 81 | seq_feature = dec_last_hidden_state[:, 0, :] 82 | return seq_feature 83 | 84 | def _get_seq_feature_from_mean_pool_enc(self, enc_last_hidden_state, att_mask): 85 | seq_feature = enc_last_hidden_state * att_mask.unsqueeze(-1) # (B, S, H) 86 | seq_feature = torch.sum(seq_feature, dim=1) # (B, H) 87 | seq_feature = seq_feature / torch.sum(att_mask, -1, keepdim=True) 88 | return seq_feature 89 | 90 | def _get_seq_feature_from_mean_pool_enc_dec(self, enc_last_hidden_state, enc_att_mask, dec_last_hidden_state):
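# 'all_joint' pooling below: mask out padding, sum the encoder states together with the decoder's BOS state, and divide by (valid encoder length + 1), i.e. a mean over every real encoder token plus one decoder token.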
91 | seq_feature = enc_last_hidden_state * enc_att_mask.unsqueeze(-1) # (B, S, H) 92 | seq_feature_with_dec_bos = torch.cat((seq_feature, dec_last_hidden_state[:, :1, :]),dim=1) 93 | seq_feature = torch.sum(seq_feature_with_dec_bos, dim=1) # (B, H) 94 | seq_feature = seq_feature / (torch.sum(enc_att_mask, -1, keepdim=True)+1) 95 | return seq_feature 96 | 97 | def _forward_projection(self, sequence_feature, for_s): 98 | if for_s is None: 99 | latent_vec = self.proj_head(sequence_feature) 100 | elif for_s is True: 101 | latent_vec = self.proj_head_s(sequence_feature) 102 | elif for_s is False: 103 | latent_vec = self.proj_head_ro(sequence_feature) 104 | else: 105 | raise NotImplementedError 106 | 107 | return latent_vec 108 | 109 | 110 | def convert_BARTModel_to_BartForConditionalGeneration(bart_model, params): 111 | device = bart_model.device 112 | # model_name = f'facebook/bart-{size}' 113 | gen_model = BartForConditionalGeneration.from_pretrained(params) 114 | vocab_len, hidden = bart_model.lm_head.weight.shape 115 | use_bias = False if bart_model.lm_head.bias is None else True 116 | gen_model.resize_token_embeddings(vocab_len) 117 | gen_model.lm_head = Linear(hidden, vocab_len, bias=use_bias) 118 | 119 | gen_model = convert_model(bart_model, gen_model).to(device) 120 | 121 | return gen_model 122 | -------------------------------------------------------------------------------- /models/distance_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class NonLinearHeadDistance(torch.nn.Module): 5 | def __init__(self): 6 | super(NonLinearHeadDistance, self).__init__() 7 | 8 | def forward(self, batch_s, batch_ro): 9 | l2_s = batch_s / torch.linalg.norm(batch_s, dim=-1, ord=2, keepdim=True) 10 | l2_ro = batch_ro / torch.linalg.norm(batch_ro, dim=-1, ord=2, keepdim=True) 11 | 12 | l2_ro = l2_ro.transpose(0, 1) 13 | 14 | final_dist = torch.matmul(l2_s, l2_ro) 15 | return final_dist 16 | 17 | 18 | class MultiHeadDistance(torch.nn.Module): 19 | def __init__(self, head_option): 20 | super(MultiHeadDistance, self).__init__() 21 | self.head_num = head_option['head_num'] 22 | self.head_dim = head_option['head_dim'] 23 | self.pool = None 24 | if head_option['pool_method'] == 'maxpool': 25 | self.max_pool = torch.nn.MaxPool1d(self.head_num) 26 | else: 27 | raise NotImplementedError 28 | 29 | def forward(self, batch_s, batch_ro): 30 | batch_s = batch_s.view(list(batch_s.shape[:-1]) + [self.head_num, self.head_dim]) 31 | batch_ro = batch_ro.view(list(batch_ro.shape[:-1]) + [self.head_num, self.head_dim]) 32 | 33 | l2_s = batch_s / torch.linalg.norm(batch_s, dim=-1, ord=2, keepdim=True) 34 | l2_ro = batch_ro / torch.linalg.norm(batch_ro, dim=-1, ord=2, keepdim=True) 35 | 36 | l2_s = l2_s.permute([1, 0, 2]) 37 | l2_ro = l2_ro.permute([1, 0, 2]) 38 | l2_ro = l2_ro.transpose(1, 2) 39 | 40 | dist = torch.matmul(l2_s, l2_ro) 41 | dist = dist.permute([1, 2, 0]) 42 | 43 | final_dist = self.max_pool(dist).squeeze(-1) 44 | return final_dist 45 | 46 | -------------------------------------------------------------------------------- /models/head_proj_layer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class NonLinearHeadProjLayer(nn.Module): 5 | def __init__(self, input_hidden_size): 6 | super(NonLinearHeadProjLayer, self).__init__() 7 | self.linear_1 = nn.Linear(input_hidden_size, input_hidden_size) 8 | self.relu = nn.ReLU() 9 | self.linear_2 = nn.Linear(input_hidden_size, 
input_hidden_size) 10 | 11 | def forward(self, x): 12 | x = self.linear_1(x) 13 | x = self.relu(x) 14 | x = self.linear_2(x) 15 | return x 16 | 17 | 18 | class MultiHeadProjLayer(nn.Module): 19 | def __init__(self, input_hidden_size, head_num, head_size): 20 | super(MultiHeadProjLayer, self).__init__() 21 | self.linear_1 = nn.Linear(input_hidden_size, input_hidden_size) 22 | self.relu = nn.ReLU() 23 | self.head_size = head_size if head_size != -1 else int(input_hidden_size / head_num) 24 | self.multi_head_proj = nn.Linear(input_hidden_size, head_num * head_size) 25 | 26 | def forward(self, x): 27 | x = self.linear_1(x) 28 | x = self.relu(x) 29 | x = self.multi_head_proj(x) 30 | return x 31 | 32 | 33 | class MultiHeadProjLayerDeeper(nn.Module): 34 | def __init__(self, input_hidden_size, head_num, head_size, inner_hidden_mul): 35 | super(MultiHeadProjLayerDeeper, self).__init__() 36 | self.linear_1 = nn.Linear(input_hidden_size, input_hidden_size * inner_hidden_mul) 37 | self.relu = nn.ReLU() 38 | self.head_size = head_size if head_size != -1 else int(input_hidden_size / head_num) 39 | self.linear_2 = nn.Linear(input_hidden_size * inner_hidden_mul, input_hidden_size) 40 | self.multi_head_proj = nn.Linear(input_hidden_size, head_num * head_size) 41 | 42 | def forward(self, x): 43 | x = self.linear_1(x) 44 | x = self.relu(x) 45 | x = self.linear_2(x) 46 | x = self.relu(x) 47 | x = self.multi_head_proj(x) 48 | return x -------------------------------------------------------------------------------- /models/loss_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | EPS = 0.0001 3 | 4 | 5 | class NT_Logistic: 6 | def __init__(self, temp=0.1): 7 | self.temp = temp 8 | self.sigmoid = torch.nn.Sigmoid() 9 | 10 | def get_loss(self, dist, pos_sample): 11 | neg_sample = torch.ones_like(pos_sample) - pos_sample 12 | pos_dist = torch.mul(dist, pos_sample) 13 | neg_dist = torch.mul(dist, neg_sample) 14 | avg = (pos_sample / torch.sum(pos_sample, -1)) + (neg_sample / torch.sum(neg_sample, -1)) 15 | logged_dist = torch.log(self.sigmoid((pos_dist - neg_dist) / self.temp) + 0.0001) 16 | 17 | pos_dists = torch.sum(pos_dist) / torch.sum(pos_sample) 18 | neg_dists = torch.sum(neg_dist) / torch.sum(neg_sample) 19 | avg_dist = torch.sum(torch.mul(logged_dist, avg), -1) 20 | avg_dist = -torch.mean(avg_dist) 21 | return avg_dist, pos_dists, neg_dists 22 | 23 | 24 | ''' 25 | class NT_Logistic: 26 | def __init__(self, temp=0.1): 27 | self.temp = temp 28 | self.sigmoid = torch.nn.Sigmoid() 29 | 30 | def get_loss(self, dist, pos_sample): 31 | pos_sample = pos_sample 32 | neg_sample = torch.ones_like(pos_sample) - pos_sample 33 | 34 | pos_dist, pos_num = self.get_each_loss(dist, pos_sample) 35 | neg_dist, neg_num = self.get_each_loss(dist, neg_sample) 36 | logged_dist = torch.log(self.sigmoid((pos_dist - neg_dist) / self.temp) + 0.0001) 37 | 38 | avg = torch.zeros_like(dist) 39 | if pos_num.item() != 0: 40 | avg += pos_sample / pos_num 41 | if neg_num.item() != 0: 42 | avg += neg_sample / neg_num 43 | 44 | loss = -torch.sum(logged_dist * avg) 45 | print(loss) 46 | print(pos_num, neg_num) 47 | print(logged_dist[0][:50]) 48 | #print(avg[0]) 49 | pos_dist_total = torch.sum(pos_dist) 50 | neg_dist_total = torch.sum(neg_dist) 51 | 52 | pos_dists = (pos_dist_total.item(), pos_num.item()) 53 | neg_dists = (neg_dist_total.item(), neg_num.item()) 54 | 55 | return loss, pos_dists, neg_dists 56 | 57 | 58 | 59 | 60 | pos_dist = torch.mul(dist, pos_sample) + EPS 61 | neg_dist = 
torch.mul(dist, neg_sample) + EPS 62 | avg = (pos_sample / (torch.sum(pos_sample, -1) + EPS)) + (neg_sample / (torch.sum(neg_sample, -1) + EPS)) 63 | logged_dist = torch.log(self.sigmoid((pos_dist - neg_dist) / self.temp) + 0.0001) 64 | 65 | pos_dists = torch.sum(pos_dist) / torch.sum(pos_sample) + EPS 66 | neg_dists = torch.sum(neg_dist) / torch.sum(neg_sample) + EPS 67 | avg_dist = torch.sum(torch.mul(logged_dist, avg), -1) 68 | avg_dist = -torch.mean(avg_dist) 69 | return avg_dist, pos_dists, neg_dists 70 | 71 | def get_each_loss(self, dist, sample_mask): 72 | sample_num = torch.sum(sample_mask) 73 | sample_dist = torch.mul(dist, sample_mask) 74 | return sample_dist, sample_num 75 | ''' -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear 2 | from transformers import BartForConditionalGeneration 3 | 4 | def convert_model(src_model, dst_model): 5 | params_src = src_model.named_parameters() 6 | params_dst = dst_model.named_parameters() 7 | dict_src = dict(params_src) 8 | dict_dst = dict(params_dst) 9 | match_fail = 0 10 | for name in dict_src: 11 | dst_name = f'model.{name}' 12 | if dst_name in dict_dst: 13 | dict_dst[dst_name].data.copy_(dict_src[name].data) 14 | elif name in dict_dst: 15 | dict_dst[name].data.copy_(dict_src[name].data) 16 | else: 17 | match_fail += 1 18 | print(f'Unmatched layer of dst_model to src_model : layer name : {dst_name}') 19 | 20 | if match_fail == 0: 21 | print('All layered are matched') 22 | 23 | return dst_model 24 | 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | beautifulsoup4==4.11.1 3 | bert-score==0.3.11 4 | cachetools==4.2.4 5 | certifi==2022.5.18.1 6 | charset-normalizer==2.0.12 7 | click==8.1.3 8 | cycler==0.11.0 9 | filelock==3.7.1 10 | fonttools==4.33.3 11 | gdown==4.4.0 12 | google-auth==1.35.0 13 | google-auth-oauthlib==0.4.6 14 | grpcio==1.46.3 15 | huggingface-hub==0.7.0 16 | idna==3.3 17 | importlib-metadata==4.11.4 18 | joblib==1.1.0 19 | kiwisolver==1.4.2 20 | Markdown==3.3.7 21 | matplotlib==3.5.2 22 | nltk==3.5 23 | numpy==1.22.4 24 | oauthlib==3.2.0 25 | packaging==21.3 26 | pandas==1.4.2 27 | Pillow==9.1.1 28 | protobuf==3.14.0 29 | pyasn1==0.4.8 30 | pyasn1-modules==0.2.8 31 | pyparsing==3.0.9 32 | PySocks==1.7.1 33 | python-dateutil==2.8.2 34 | pytz==2022.1 35 | PyYAML==6.0 36 | regex==2022.4.24 37 | requests==2.27.1 38 | requests-oauthlib==1.3.1 39 | rsa==4.8 40 | six==1.16.0 41 | soupsieve==2.3.2.post1 42 | tabulate==0.8.9 43 | tensorboard==2.5.0 44 | tensorboard-data-server==0.6.1 45 | tensorboard-plugin-wit==1.8.1 46 | tokenizers==0.12.1 47 | torch==1.10.1+cu111 48 | torchaudio==0.10.1+rocm4.1 49 | torchvision==0.11.2+cu111 50 | tqdm==4.64.0 51 | transformers==4.19.2 52 | typing-extensions==4.2.0 53 | urllib3==1.26.9 54 | Werkzeug==2.1.2 55 | zipp==3.8.0 56 | -------------------------------------------------------------------------------- /scripts/feature_learn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import json 7 | import yaml 8 | from tqdm import tqdm 9 | import argparse 10 | import numpy as np 11 | import torch 12 | 13 | from torch.utils.data import DataLoader 14 | from src.utils import 
load_logger, load_yaml 15 | from src.sampler import get_data_sampler, load_datasets 16 | from src.lr_schedule import WarmupLinearScheduler 17 | from torch.utils.tensorboard import SummaryWriter 18 | from torch.cuda.amp import autocast 19 | from torch.cuda.amp import GradScaler 20 | from torch.nn.utils import clip_grad_norm_ 21 | from distutils.util import strtobool as _bool 22 | from src.train_utils import get_data_feeder_from_sampler 23 | from models.distance_func import * 24 | from models.loss_func import * 25 | from src.trainer import * 26 | from copy import deepcopy 27 | 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | parser.add_argument('--mode', type=str, default=None) 32 | parser.add_argument('--file_dir', type=str, default='/mnt/data/user8/solar-commonsense_inference') 33 | parser.add_argument('--dataset_type', type=str, default='atomic-2020')#, required=True) 34 | parser.add_argument('--model_name', type=str, default='gpt2', help="bart | gpt2")#, required=True) 35 | parser.add_argument('--model_size', type=str, default='base', help="base | large")#, required=True) 36 | parser.add_argument('--main_yml', type=str, default='v01.yml')#, required=True) 37 | parser.add_argument('--tknz_yml', type=str, default='tokenizer_config.yml') 38 | parser.add_argument('--log_level', type=str, default='INFO') 39 | parser.add_argument('--log', type=str, default='NEWvTEST')#, required=True) 40 | parser.add_argument('--random_seed', type=int, default=42) 41 | parser.add_argument('--load_model', type=str, default=None) 42 | 43 | # Optimize 44 | parser.add_argument('--batch_size', type=int, default=128) 45 | parser.add_argument('--update_batch', type=int, default=128) 46 | parser.add_argument('--iter_per_epoch', type=int, default=40000) 47 | parser.add_argument('--dev_iter_per_epoch', type=int, default=1) 48 | parser.add_argument('--epoch_num', type=int, default=1) 49 | 50 | parser.add_argument("--learning_rate", default=0.00001, type=float, help="The initial learning rate for Adam.") 51 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 52 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 53 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 54 | parser.add_argument('--lm_head_dropout_p', default=0.1, type=float) 55 | parser.add_argument('--proj_head_dropout_p', default=0.1, type=float) 56 | 57 | parser.add_argument("--recadam_anneal_fun", type=str, default='sigmoid', choices=["sigmoid", "linear", 'constant'], 58 | help="the type of annealing function in RecAdam. Default sigmoid") 59 | parser.add_argument("--recadam_anneal_k", type=float, default=0.0001, help="k for the annealing function in RecAdam.") 60 | parser.add_argument("--recadam_anneal_t0", type=int, default=5000, help="t0 for the annealing function in RecAdam.") 61 | parser.add_argument("--recadam_anneal_w", type=float, default=1.0, 62 | help="Weight for the annealing function in RecAdam. Default 1.0.") 63 | parser.add_argument("--recadam_pretrain_cof", type=float, default=5000.0, 64 | help="Coefficient of the quadratic penalty in RecAdam. Default 5000.0.") 65 | parser.add_argument('--warmup_steps', type=int, default=100) 66 | 67 | # ETC 68 | parser.add_argument('--share_proj_layer', type=_bool, default=True, help= 69 | 'if true, share projection layer between s_representation and ro_representation' 70 | 'if false, do not share projection layer between them. 
each have each own projection layer') 71 | 72 | args = parser.parse_args() 73 | 74 | if args.batch_size > args.update_batch: 75 | args.batch_size = args.update_batch 76 | 77 | # Path Setting 78 | np.random.seed(args.random_seed) 79 | 80 | # - log, [tb, ckpt, gen, eval] 81 | main_config_path = f'config/{args.dataset_type}/{args.model_name}-{args.model_size}_experiments/{args.main_yml}' 82 | tknz_config_path = f'config/{args.tknz_yml}' 83 | dataset_config_path = f'config/{args.dataset_type}/datasets.yml' 84 | 85 | log_dir = f'{args.file_dir}/log/{args.dataset_type}/{args.model_name}-{args.model_size}_experiments/{args.log}' 86 | logging_dir = log_dir + '/logging.log' 87 | tb_dir = log_dir + '/tb' 88 | gen_dir = log_dir + '/gen' 89 | eval_dir = log_dir + '/eval' 90 | ckpt_dir = log_dir + '/ckpt' 91 | 92 | if not os.path.exists(log_dir): 93 | os.makedirs(log_dir) 94 | os.mkdir(tb_dir) 95 | os.mkdir(gen_dir) 96 | os.mkdir(eval_dir) 97 | os.mkdir(ckpt_dir) 98 | 99 | # Initialize Log 100 | logger = load_logger(logging_dir, args.log_level) 101 | logger.info('Logger is Successfully initialized !') 102 | tb_writer = SummaryWriter(tb_dir) 103 | 104 | # Loading YML file 105 | main_config = load_yaml(main_config_path) 106 | tknz_config = load_yaml(tknz_config_path) 107 | dataset_cfg = load_yaml(dataset_config_path) 108 | 109 | with open(os.path.join(log_dir, 'main_config.json'), 'w') as f: 110 | json.dump(main_config, f) 111 | with open(os.path.join(log_dir, 'tknz_config.json'), 'w') as f: 112 | json.dump(tknz_config, f) 113 | with open(os.path.join(log_dir, 'argparse.json'), 'w') as f: 114 | json.dump(args.__dict__, f) 115 | 116 | task_cfg = main_config['task'] 117 | model_cfg = main_config['model'] 118 | opt_cfg = main_config['opt'] 119 | log_cfg = main_config['log'] 120 | 121 | # Tokenizer Setting with args.tknz_yml 122 | if args.model_name == 'bart': 123 | from transformers import BartTokenizer as Tokenizer 124 | from models.bart import CometBART as CometModel 125 | from src.tokenizer import BartCSKGTokenizer as CSKGTokenizer 126 | 127 | elif args.model_name == 'gpt2': 128 | from transformers import GPT2Tokenizer as Tokenizer 129 | from models.gpt2 import CometGPT2 as CometModel 130 | from src.tokenizer import Gpt2CSKGTokenizer as CSKGTokenizer 131 | 132 | elif 't5' in model_cfg['name']: 133 | raise NotImplementedError 134 | else: 135 | raise NotImplementedError 136 | 137 | anneal_targets = load_yaml(f'models/pretrained_params/pretrained_{args.model_name}-{args.model_size}.yml')['pretrained_weights'] 138 | 139 | _tokenizer = Tokenizer.from_pretrained(model_cfg['pretrained_model']) 140 | tokenizer = CSKGTokenizer(_tokenizer, tknz_config) 141 | vocab_len = len(tokenizer) + 1 142 | 143 | model = CometModel.from_pretrained(model_cfg['pretrained_model']) 144 | model.modify_lm_heads(vocab_len, args.lm_head_dropout_p) 145 | model.add_proj_layer(model_cfg['contrastive_head'], args.proj_head_dropout_p) 146 | 147 | pretrained_model = deepcopy(model) 148 | 149 | # Load Dataset 150 | dataset, sim_mat = load_datasets(dataset_cfg, tokenizer, logger) 151 | 152 | # Initialize Sampler 153 | sampler = get_data_sampler(task_cfg, dataset, sim_mat, tokenizer, logger, args) 154 | train_sampler = sampler['train'] 155 | dev_sampler = sampler['dev'] 156 | 157 | # Connect sampler to data_feeding adaptor 158 | options = main_config['model']['task_adaptor_options'] 159 | 160 | iter_num = 0 161 | assert args.update_batch >= args.batch_size 162 | stack_num = int(args.update_batch / args.batch_size) 163 | assert int(stack_num * 
args.batch_size) == int(args.update_batch)

train_data_feeder = get_data_feeder_from_sampler(train_sampler, options, tokenizer, task_cfg, args.batch_size, args.iter_per_epoch * stack_num)
dev_data_feeder = get_data_feeder_from_sampler(dev_sampler, options, tokenizer, task_cfg, args.batch_size, args.dev_iter_per_epoch)

if model_cfg['contrastive_head']['proj_layer_type'] in ['multi-head', 'multi-head-deeper']:
    distance_model = MultiHeadDistance(model_cfg['contrastive_head']['multi-head'])
elif model_cfg['contrastive_head']['proj_layer_type'] == 'non-linear':
    distance_model = NonLinearHeadDistance()
else:
    raise NotImplementedError

usable_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if usable_cuda else "cpu")

model.to(device)
pretrained_model.to(device)

from src.rec_adam_wrapper import get_adam_optimizer
optim, scheduler = get_adam_optimizer(model, pretrained_model, anneal_targets, args, args.iter_per_epoch * args.epoch_num)

global_steps = 0
model.train()
optim.zero_grad()

if args.share_proj_layer:
    share_s = share_ro = None
else:
    share_s = True
    share_ro = False

share_proj_layer = (share_s, share_ro)

if task_cfg['con_task']['loss_func'] == 'NT-Logistic':
    con_loss_func = NT_Logistic(opt_cfg['temperature'])
else:
    raise NotImplementedError

auxil_loss_func = torch.nn.CrossEntropyLoss(reduction='none')


con_task_loader = DataLoader(train_data_feeder['con_task'], batch_size=1,
                             drop_last=False, shuffle=False, num_workers=4)

if options['con_task']['format'] == 'enc-dec':
    con_task_train = ConTaskTrainerForEncDec(model, device, distance_model, con_loss_func,
                                             log_cfg['tb_period'], tb_writer, share_proj_layer, args)

elif options['con_task']['format'] == 'dec':
    con_task_train = ConTaskTrainerForDec(model, device, distance_model, con_loss_func,
                                          log_cfg['tb_period'], tb_writer, share_proj_layer, args)

else:
    raise NotImplementedError

if 'rec_inf_shu_task' in train_data_feeder:
    auxil_task_loader = DataLoader(train_data_feeder['rec_inf_shu_task'], batch_size=args.batch_size,
                                   drop_last=True, shuffle=False, num_workers=10)
    auxil_task_name = 'rec_inf_shu_task'
    auxil_task_train = RecTaskTrainerForEncDec(model, device, auxil_loss_func, log_cfg['tb_period'], tb_writer)

elif 'mask_gen_task' in train_data_feeder:
    auxil_task_loader = DataLoader(train_data_feeder['mask_gen_task'], batch_size=args.batch_size,
                                   drop_last=True, shuffle=False, num_workers=10)
    auxil_task_name = 'mask_gen_task'
    auxil_task_train = MaskGenTaskTrainerForDec(model, device, auxil_loss_func, log_cfg['tb_period'], tb_writer)

else:
    raise NotImplementedError

torch.save(model, os.path.join(ckpt_dir, 'model-{}-steps.ckpt'.format(global_steps)))
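
# Note on the accumulation scheme in the loop below (descriptive comment, added for
# clarity): each optimizer update aggregates stack_num micro-batches, where
# update_batch == batch_size * stack_num (asserted above). Every micro-batch loss is
# divided by stack_num before backward(), so the accumulated gradient equals the
# gradient of the mean loss over the full update batch. For example, --update_batch 128
# with --batch_size 32 gives stack_num = 4: four backward() calls accumulate into
# .grad, then a single optim.step() / scheduler.step() fires.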
#########
# for task_name, task_feeder in train_data_feeder.items():
#     task_feeder.sampler._init_cursor(args.random_seed)
#
# for con_task_data, auxil_task_data in tqdm(zip(con_task_loader, auxil_task_loader)):
#     multi_task_loss = []
########
for e in range(args.epoch_num):
    model.train()
    for task_name, task_feeder in train_data_feeder.items():
        task_feeder.sampler._init_cursor(args.random_seed)

    for con_task_data, auxil_task_data in tqdm(zip(con_task_loader, auxil_task_loader)):
        multi_task_loss = []

        with autocast():
            con_task_data = {i: con_task_data[i][0] for i in con_task_data}
            con_loss = con_task_train.train(con_task_data, global_steps, iter_num == 0)
            multi_task_loss.append((con_loss.item(), task_cfg['con_task']['weight']))
            con_loss *= task_cfg['con_task']['weight']
            con_loss /= stack_num
        con_loss.backward()

        with autocast():
            auxil_loss = auxil_task_train.train(auxil_task_data, global_steps, iter_num == 0)
            multi_task_loss.append((auxil_loss.item(), task_cfg[auxil_task_name]['weight']))
            auxil_loss *= task_cfg[auxil_task_name]['weight']
            auxil_loss /= stack_num
        auxil_loss.backward()

        iter_num += 1
        if iter_num != stack_num:
            continue
        iter_num = 0

        clip_grad_norm_(model.parameters(), opt_cfg['clip_grad_norm'])
        _, anneal_lambda = optim.step()
        # clear the accumulated gradients before the next accumulation window
        optim.zero_grad()
        scheduler.step()
        global_steps += 1

        if global_steps % log_cfg['tb_period'] == 0:
            avg_multi_task_loss = sum([i[0] for i in multi_task_loss]) / len(multi_task_loss)
            weighted_multi_task_loss = sum([i[0] * i[1] for i in multi_task_loss])
            tb_writer.add_scalar('train/total_lr', optim.param_groups[0]['lr'], global_steps)
            tb_writer.add_scalar('train/multi-task_loss', avg_multi_task_loss, global_steps)
            tb_writer.add_scalar('train/multi-task_loss(weighted)', weighted_multi_task_loss, global_steps)
            tb_writer.add_scalar('train/anneal_lambda', anneal_lambda, global_steps)
            tb_writer.flush()

        if global_steps % log_cfg['save_period'] == 0:
            torch.save(model, os.path.join(ckpt_dir, 'model-{}-steps.ckpt'.format(global_steps)))

torch.save(model, os.path.join(ckpt_dir, 'model-{}-steps.ckpt'.format(global_steps)))

--------------------------------------------------------------------------------
/scripts/finetune.py:
--------------------------------------------------------------------------------
import sys, os

sys.path.append(os.getcwd())

import json
import yaml
from tqdm import tqdm
import argparse
import numpy as np
import torch
from copy import deepcopy

from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
from torch.nn.utils import clip_grad_norm_
from distutils.util import strtobool as _bool
from torch.utils.data import DataLoader

from src.utils import load_logger, load_yaml

from src.finetune.finetune_utils import *
from src.finetune.finetune_trainer import get_finetune_trainer
from models.model_utils import convert_model
from src.lr_schedule import WarmupLinearScheduler

os.environ['CUDA_VISIBLE_DEVICES'] = '3'
parser = argparse.ArgumentParser()

parser.add_argument('--mode', type=str, default=None)
parser.add_argument('--file_dir', type=str, default='/mnt/data/user8/solar-commonsense_inference')
parser.add_argument('--dataset_type', type=str, default='conceptnet')
parser.add_argument('--model_name', type=str, default='bart', help="bart | gpt2")#, required=True)
parser.add_argument('--model_size', type=str, default='large', help="base | large")#, required=True)
parser.add_argument('--exp_type', type=str, default='baseline', help='baseline | experiments')#, required=True)
parser.add_argument('--main_yml', type=str,
default='v01.yml') 37 | parser.add_argument('--tknz_yml', type=str, default='tokenizer_config.yml') 38 | parser.add_argument('--log_level', type=str, default='INFO') 39 | parser.add_argument('--log', type=str, default='vTEST1')#, required=True) 40 | parser.add_argument('--random_seed', type=int, default=42) 41 | parser.add_argument('--load_model', type=str, default=None) 42 | 43 | parser.add_argument('--same_hwang', type=_bool, default=False) 44 | 45 | parser.add_argument('--lr', type=float, default=0.00001) 46 | parser.add_argument('--batch_size', type=int, default=64) 47 | parser.add_argument('--update_batch', type=int, default=64) 48 | parser.add_argument('--epoch_num', type=int, default=2) 49 | parser.add_argument('--output_drop', type=float, default=0.1) 50 | parser.add_argument('--remove_except_best', type=_bool, default=True) 51 | parser.add_argument('--patience', type=int, default=2) 52 | 53 | 54 | args = parser.parse_args() 55 | 56 | if args.load_model is not None: 57 | assert args.exp_type == 'experiments' 58 | 59 | args.batch_size = args.batch_size if args.update_batch >= args.batch_size else args.update_batch 60 | 61 | np.random.seed(args.random_seed) 62 | 63 | main_config_path = f'config/{args.dataset_type}/{args.model_name}-{args.model_size}_baseline/{args.main_yml}' 64 | tknz_config_path = f'config/{args.tknz_yml}' 65 | 66 | dataset_config_path = f'config/{args.dataset_type}/datasets.yml' 67 | 68 | log_dir = f'{args.file_dir}/log_fntn/{args.dataset_type}/{args.model_name}-{args.model_size}_{args.exp_type}/{args.log}' 69 | 70 | logging_dir = log_dir + '/logging.log' 71 | tb_dir = log_dir + '/tb' 72 | gen_dir = log_dir + '/gen' 73 | eval_dir = log_dir + '/eval' 74 | ckpt_dir = log_dir + '/ckpt' 75 | 76 | if not os.path.exists(log_dir): 77 | os.makedirs(log_dir) 78 | os.mkdir(tb_dir) 79 | os.mkdir(gen_dir) 80 | os.mkdir(eval_dir) 81 | os.mkdir(ckpt_dir) 82 | 83 | logger = load_logger(logging_dir, args.log_level) 84 | logger.info('Logger is Successfully initialized !') 85 | tb_writer = SummaryWriter(tb_dir) 86 | 87 | main_config = load_yaml(main_config_path) 88 | tknz_config = load_yaml(tknz_config_path) 89 | dataset_cfg = load_yaml(dataset_config_path) 90 | 91 | with open(os.path.join(log_dir, 'main_config.json'), 'w') as f: 92 | json.dump(main_config, f) 93 | with open(os.path.join(log_dir, 'tknz_config.json'), 'w') as f: 94 | json.dump(tknz_config, f) 95 | with open(os.path.join(log_dir, 'argparse.json'), 'w') as f: 96 | json.dump(args.__dict__, f) 97 | 98 | model_cfg = main_config['model'] 99 | opt_cfg = main_config['opt'] 100 | log_cfg = main_config['log'] 101 | 102 | # Model Selection 103 | MODEL_TYPE = None 104 | if 'bart' in model_cfg['name']: 105 | from transformers import BartTokenizer as Tokenizer 106 | from models.bart import CometBART as CometModel 107 | from src.tokenizer import BartCSKGTokenizer as CSKGTokenizer 108 | from models.bart import convert_BARTModel_to_BartForConditionalGeneration as model2gen 109 | MODEL_TYPE = 'enc-dec' 110 | 111 | elif 't5' in model_cfg['name']: 112 | raise NotImplementedError 113 | 114 | elif 'gpt2' in model_cfg['name']: 115 | from transformers import GPT2Tokenizer as Tokenizer 116 | from models.gpt2 import CometGPT2 as CometModel 117 | from src.tokenizer import Gpt2CSKGTokenizer as CSKGTokenizer 118 | from models.gpt2 import convert_GPT2Model_to_GPT2LMHeadModel as model2gen 119 | MODEL_TYPE = 'dec' 120 | 121 | else: 122 | raise NotImplementedError 123 | 124 | _tokenizer = Tokenizer.from_pretrained(model_cfg['pretrained_model']) 125 
tokenizer = CSKGTokenizer(_tokenizer, tknz_config)

vocab_len = len(tokenizer) + 1

if args.load_model is None:
    model = CometModel.from_pretrained(model_cfg['pretrained_model'])
    model.modify_lm_heads(vocab_len, args.output_drop)
    logger.info("Model is loaded from : {}".format(model_cfg['pretrained_model']))
else:
    src_model = torch.load(args.load_model)
    logger.info("Model is loaded from : {}".format(args.load_model))
    dst_model = CometModel.from_pretrained(model_cfg['pretrained_model'])
    dst_model.modify_lm_heads(vocab_len, args.output_drop)
    model = convert_model(src_model, dst_model)

dataset = load_fntn_datasets(dataset_cfg, tokenizer, logger)

fntn_train_dataset = get_finetune_dataset(dataset['train'], dataset_cfg, tokenizer, logger, 'train', MODEL_TYPE)
fntn_dev_dataset = get_finetune_dataset(dataset['dev'], dataset_cfg, tokenizer, logger, 'dev', MODEL_TYPE)

eval_dev_dataset = get_eval_dataset(dataset['dev'], tokenizer, logger, 'dev', MODEL_TYPE)
eval_test_dataset = get_eval_dataset(dataset['test'], tokenizer, logger, 'test', MODEL_TYPE)

fntn_train_loader = DataLoader(fntn_train_dataset,
                               batch_size=args.batch_size, drop_last=True, shuffle=True, num_workers=20)
fntn_dev_loader = DataLoader(fntn_dev_dataset,
                             batch_size=args.batch_size, drop_last=True, shuffle=True, num_workers=20)

usable_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if usable_cuda else "cpu")
model.to(device)

# Optimizer Settings
# --same_hwang switches to the plain AdamW setup; otherwise Adam uses the betas and
# weight decay from the experiment config.
if args.same_hwang:
    optim = torch.optim.AdamW(model.parameters(), lr=args.lr, eps=1e-8)
else:
    optim = torch.optim.Adam(model.parameters(),
                             lr=args.lr, betas=(opt_cfg['adam_beta_1'], opt_cfg['adam_beta_2']), weight_decay=opt_cfg['weight_decay'])

total_steps = int((len(fntn_train_dataset) / args.update_batch) * args.epoch_num) + 100
lr_schedule = WarmupLinearScheduler(optim, args.lr, opt_cfg['warmup_steps'], total_steps)
loss_func = torch.nn.CrossEntropyLoss(reduction='none')
global_step = 0
scaler = GradScaler()
model.train()
optim.zero_grad()

iter_num = 0

assert args.update_batch >= args.batch_size
stack_num = int(args.update_batch / args.batch_size)
assert int(stack_num * args.batch_size) == int(args.update_batch)

fntn_trainer = get_finetune_trainer(model, device, loss_func, log_cfg['tb_period'], tb_writer, MODEL_TYPE)

ckpt_loss = {}

for e in range(1, args.epoch_num+1):
    model.train()
    fntn_train_loader.dataset.shuffle()
    patience = args.patience
    for sample in tqdm(fntn_train_loader, desc='[Train] Epoch {}/{}'.format(e, args.epoch_num), ncols=130):
        with autocast():
            loss = fntn_trainer.train(sample, global_step, 'train', iter_num == 0)
            loss /= stack_num
        loss.backward()
        iter_num += 1
        if iter_num != stack_num:
            continue
        iter_num = 0

        lr_schedule(global_step)
        clip_grad_norm_(model.parameters(), opt_cfg['clip_grad_norm'])
        optim.step()
        optim.zero_grad()
        global_step += 1

        if global_step % log_cfg['tb_period'] == 0:
            tb_writer.add_scalar('train/lr', optim.param_groups[0]['lr'], global_step)
            tb_writer.flush()

    model.eval()
    loss_list = list()
    for
sample in tqdm(fntn_dev_loader, desc='[Validate]', ncols=130): 210 | with autocast(): 211 | loss = fntn_trainer.train(sample, global_step, 'dev') 212 | loss_list.append(loss.item()) 213 | loss = sum(loss_list) / len(loss_list) 214 | tb_writer.add_scalar('dev/loss', loss, global_step) 215 | 216 | save_name = os.path.join(ckpt_dir, 'model-{}-epoch.ckpt'.format(e)) 217 | 218 | torch.save(model, save_name) 219 | ckpt_loss[save_name] = loss 220 | 221 | for _ckpt, _loss in ckpt_loss.items(): 222 | if loss > _loss: 223 | patience -= 1 224 | if patience <= 0: 225 | logger.info('Patience is over. End') 226 | break 227 | 228 | tokenizer_save_name = os.path.join(ckpt_dir, 'tokenizer.torch-pkl') 229 | torch.save(tokenizer, tokenizer_save_name) 230 | 231 | with open(os.path.join(ckpt_dir, 'ckpt_loss.json'), 'w') as f: 232 | json.dump(ckpt_loss, f) 233 | 234 | with open(os.path.join(ckpt_dir, 'ckpt_loss.json'), 'r') as f: 235 | ckpt_loss = json.load(f) 236 | 237 | best_ckpt = None 238 | best_loss = 10000 239 | 240 | for _ckpt, _loss in ckpt_loss.items(): 241 | if best_loss > _loss: 242 | best_loss = _loss 243 | best_ckpt = _ckpt 244 | 245 | if args.remove_except_best: 246 | for _ckpt in ckpt_loss: 247 | if _ckpt == best_ckpt: 248 | continue 249 | os.remove(os.path.join(_ckpt)) 250 | 251 | decode = tokenizer.tokenizer.decode 252 | 253 | model = torch.load(best_ckpt).to('cpu') 254 | if MODEL_TYPE == 'enc-dec': 255 | gen_model = model2gen(model, model_cfg['pretrained_model']).to(device) 256 | else: 257 | gen_model = model.to(device) 258 | 259 | test_decode_results = { 260 | 'info': {'log': log_dir, 'ckpt': best_ckpt}, 261 | 'content': list()} 262 | 263 | greedy_test_decode_results = deepcopy(test_decode_results) 264 | greedy_test_decode_results['info']['decode_method'] = 'greedy' 265 | beam5_test_decode_results = deepcopy(test_decode_results) 266 | beam5_test_decode_results['info']['decode_method'] = 'beam5' 267 | nucl_test_decode_results = deepcopy(test_decode_results) 268 | nucl_test_decode_results['info']['decode_method'] = 'nucl' 269 | 270 | if MODEL_TYPE == 'enc-dec': 271 | for sample in tqdm(eval_test_dataset, ncols=130): 272 | src = sample['src'] 273 | refs = sample['ref'] 274 | enc_input_ids = torch.tensor(src).to(device).view(1, -1) 275 | enc_att_masks = torch.ones_like(enc_input_ids).to(device) 276 | 277 | inputs = {'input_ids': enc_input_ids, 'attention_mask': enc_att_masks} 278 | greedy_output = gen_model.generate(**inputs, early_stopping=True, 279 | bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id) 280 | # try: 281 | # beam5_output = gen_model.generate(**inputs, num_beams=5, early_stopping=True) #max_length=60, 282 | # except: 283 | # beam5_output = gen_model.generate(**inputs, num_beams=5, early_stopping=True, max_length=60) 284 | greedy_str = decode(greedy_output.tolist()[0]) 285 | # beam5_str = decode(beam5_output.tolist()[0]) 286 | 287 | if '' in greedy_str: 288 | greedy_str = greedy_str[greedy_str.find('') + 1 + len(''):].strip() 289 | # if '' in beam5_str: 290 | # beam5_str = beam5_str[beam5_str.find('')+1 + len(''):].strip() 291 | 292 | _input = decode(src) 293 | _refs = list() 294 | for ref in refs: 295 | _ref = decode(ref) 296 | _refs.append(_ref) 297 | 298 | greedy_test_decode_results['content'].append({'input': _input, 'output': greedy_str, 'refs': _refs}) 299 | # beam5_test_decode_results['content'].append({'input': _input, 'output': beam5_str, 'refs': _refs}) 300 | 301 | with open(os.path.join(log_dir, 'greedy_gen_examples.json'), 'w') as f: 302 | 
json.dump(greedy_test_decode_results, f) 303 | # 304 | # with open(os.path.join(log_dir, 'beam5_gen_examples.json'), 'w') as f: 305 | # json.dump(beam5_test_decode_results, f) 306 | 307 | else: 308 | 309 | for sample in tqdm(eval_test_dataset, ncols=130): 310 | src = sample['src'] 311 | refs = sample['ref'] 312 | enc_input_ids = torch.tensor(src).to(device).view(1, -1) 313 | enc_att_masks = torch.ones_like(enc_input_ids).to(device) 314 | 315 | inputs = {'input_ids': enc_input_ids, 'att_masks': enc_att_masks} 316 | 317 | for i in range(30): 318 | output = model.forward_conditional_gen(**inputs) # ['input_ids'], att_masks=inputs['attention_mask']) 319 | gen_token_id = int(torch.argmax(output[:, -1, :], -1)) 320 | old_inputs = {key: val.tolist() for key, val in inputs.items()} 321 | old_inputs['input_ids'][0].append(gen_token_id) 322 | old_inputs['att_masks'][0].append(1) 323 | if gen_token_id == tokenizer.eos_token_id: 324 | break 325 | inputs = {key: torch.tensor(val).to(device) for key, val in old_inputs.items()} 326 | 327 | greedy_output = inputs['input_ids'] 328 | 329 | greedy_str = decode(greedy_output.tolist()[0]) 330 | 331 | if '' in greedy_str: 332 | greedy_str = greedy_str[greedy_str.find('') + 1 + len(''):].strip() 333 | 334 | _input = decode(src) 335 | print(_input) 336 | print(greedy_str) 337 | print('----------') 338 | _refs = list() 339 | for ref in refs: 340 | _ref = decode(ref) 341 | _refs.append(_ref) 342 | 343 | greedy_test_decode_results['content'].append({'input': _input, 'output': greedy_str, 'refs': _refs}) 344 | 345 | with open(os.path.join(log_dir, 'greedy_gen_examples_FIXED.json'), 'w') as f: 346 | json.dump(greedy_test_decode_results, f) 347 | -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.getcwd()) 4 | 5 | import json 6 | import yaml 7 | from tqdm import tqdm 8 | import argparse 9 | import numpy as np 10 | import torch 11 | 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch.cuda.amp import autocast 14 | from torch.cuda.amp import GradScaler 15 | from torch.nn.utils import clip_grad_norm_ 16 | from distutils.util import strtobool as _bool 17 | from torch.utils.data import DataLoader 18 | 19 | from src.utils import load_logger, load_yaml 20 | 21 | from src.finetune.finetune_utils import * 22 | from src.finetune.finetune_trainer import FineTuneTrainer 23 | from models import bart 24 | from src.lr_schedule import WarmupLinearScheduler 25 | 26 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 27 | 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | parser.add_argument('--mode', type=str, default=None) 32 | parser.add_argument('--file_dir', type=str, default='/mnt/data/user8/solar-commonsense_inference') 33 | parser.add_argument('--dataset_type', type=str, default='conceptnet') 34 | parser.add_argument('--model_name', type=str, default='bart', help="bart | gpt2") 35 | parser.add_argument('--model_size', type=str, default='large', help="base | large") 36 | parser.add_argument('--exp_type', type=str, default='baseline', help='baseline | experiments') 37 | parser.add_argument('--log_level', type=str, default='INFO') 38 | parser.add_argument('--random_seed', type=int, default=42) 39 | parser.add_argument('--load_model', type=str, default='v07') 40 | parser.add_argument('--use_greedy', type=_bool, default=False) 41 | parser.add_argument('--use_beam', type=int, default=10 ) 42 | 
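
# The two decode flags are independent (descriptive note, added for clarity):
# --use_greedy toggles a greedy decoding pass, while any --use_beam value > 1 also
# runs beam search with that beam width (0 or 1 disables it). A hypothetical
# invocation such as
#     python scripts/inference.py --load_model v07 --use_greedy True --use_beam 10
# would therefore write both greedy_gen_examples.json and beam-10_gen_examples.json
# into the selected run's log directory.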
args = parser.parse_args()

np.random.seed(args.random_seed)

use_greedy = args.use_greedy
use_beam = args.use_beam > 1
beam_num = args.use_beam

log_dir = f'{args.file_dir}/log_fntn/{args.dataset_type}/{args.model_name}-{args.model_size}_{args.exp_type}/{args.load_model}'

main_config_path = f'{log_dir}/main_config.json'
tknz_config_path = f'{log_dir}/tknz_config.json'

dataset_config_path = f'config/{args.dataset_type}/datasets.yml'

print(f'Log Location : {log_dir}')
logging_dir = log_dir + '/inf_logging.log'
tb_dir = log_dir + '/tb'
gen_dir = log_dir + '/gen'
eval_dir = log_dir + '/eval'
ckpt_dir = log_dir + '/ckpt'

if not os.path.exists(log_dir):
    raise Exception

logger = load_logger(logging_dir, args.log_level)
logger.info('Logger is Successfully initialized !')

main_config = load_yaml(main_config_path)
tknz_config = load_yaml(tknz_config_path)
dataset_cfg = load_yaml(dataset_config_path)

model_cfg = main_config['model']
opt_cfg = main_config['opt']
log_cfg = main_config['log']

usable_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if usable_cuda else "cpu")


# Model Selection
MODEL_TYPE = None
if 'bart' in model_cfg['name']:
    from transformers import BartTokenizer as Tokenizer
    from models.bart import CometBART as CometModel
    from src.tokenizer import BartCSKGTokenizer as CSKGTokenizer
    from models.bart import convert_BARTModel_to_BartForConditionalGeneration as model2gen
    MODEL_TYPE = 'enc-dec'

elif 't5' in model_cfg['name']:
    raise NotImplementedError

elif 'gpt2' in model_cfg['name']:
    raise NotImplementedError

else:
    raise NotImplementedError

_tokenizer = Tokenizer.from_pretrained(model_cfg['pretrained_model'])
tokenizer = CSKGTokenizer(_tokenizer, tknz_config)

vocab_len = len(tokenizer) + 1

with open(os.path.join(ckpt_dir, 'ckpt_loss.json'), 'r') as f:
    ckpt_loss = json.load(f)

best_ckpt = None
best_loss = 10000

for _ckpt, _loss in ckpt_loss.items():
    if best_loss > _loss:
        best_loss = _loss
        best_ckpt = _ckpt

logger.info("Model will be loaded from : {}".format(best_ckpt))

try:
    if str(device) == 'cpu':
        model = torch.load(best_ckpt, map_location=torch.device('cpu'))
    else:
        model = torch.load(best_ckpt)

except:
    # fall back to the alternate log root used by some runs
    best_ckpt = best_ckpt.replace('log_fntn', 'log2_fntn')
    model = torch.load(best_ckpt)

gen_model = model2gen(model, model_cfg['pretrained_model']).to(device)
gen_model.eval()

decode = tokenizer.tokenizer.decode


dataset = load_fntn_datasets(dataset_cfg, tokenizer, logger)
eval_test_dataset = get_eval_dataset(dataset['test'], tokenizer, logger, 'test', 'dec')

test_decode_results = {
    'info': {'log': log_dir, 'ckpt': best_ckpt},
    'content': list()}

from copy import deepcopy

if use_greedy:
    greedy_results = deepcopy(test_decode_results)
    greedy_results['info']['decode_method'] = 'greedy'

if use_beam:
    beam_results = deepcopy(test_decode_results)
    beam_results['info']['decode_method'] = f'beam{beam_num}'


for sample in tqdm(eval_test_dataset, ncols=130):
src = sample['src'] 157 | refs = sample['ref'] 158 | 159 | _input = decode(src) 160 | _refs = list() 161 | for ref in refs: 162 | _ref = decode(ref) 163 | _refs.append(_ref) 164 | 165 | enc_input_ids = torch.tensor(src).to(device).view(1, -1) 166 | enc_att_masks = torch.ones_like(enc_input_ids).to(device) 167 | 168 | inputs = {'input_ids': enc_input_ids, 'attention_mask': enc_att_masks} 169 | if use_greedy: 170 | greedy_output = gen_model.generate(**inputs, early_stopping=True, 171 | bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id) 172 | greedy_str = decode(greedy_output.tolist()[0]) 173 | greedy_results['content'].append({'input': _input, 'output': greedy_str, 'refs': _refs}) 174 | 175 | if use_beam: 176 | beam_output = gen_model.generate(**inputs, num_beams=beam_num, num_return_sequences=beam_num, early_stopping=True, 177 | bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id) #max_length=60, 178 | beam_str = [decode(beam.tolist()).replace('','') for beam in beam_output] 179 | beam_results['content'].append({'input': _input, 'output': beam_str, 'refs': _refs}) 180 | 181 | if use_greedy: 182 | with open(os.path.join(log_dir, 'greedy_gen_examples.json'), 'w') as f: 183 | json.dump(greedy_results, f) 184 | 185 | if use_beam: 186 | with open(os.path.join(log_dir, f'beam-{beam_num}_gen_examples.json'), 'w') as f: 187 | json.dump(beam_results, f) 188 | # with open(os.path.join(log_dir, 'beam5_gen_examples.json'), 'w') as f: 189 | # json.dump(beam5_test_decode_results, f) 190 | # 191 | # for sample in tqdm(eval_test_dataset, ncols=130): 192 | # src = sample['src'] 193 | # refs = sample['ref'] 194 | # enc_input_ids = torch.tensor(src).to(device).view(1, -1) 195 | # enc_att_masks = torch.ones_like(enc_input_ids).to(device) 196 | # 197 | # inputs = {'input_ids': enc_input_ids, 'attention_mask': enc_att_masks} 198 | # 199 | # print('\n\nstart') 200 | # greedy_output = gen_model.generate(**inputs)#, early_stopping=True) 201 | # try: 202 | # beam5_output = gen_model.generate(**inputs, num_beams=5, early_stopping=True) #max_length=60, 203 | # except: 204 | # print('beam err occur') 205 | # beam5_output = gen_model.generate(**inputs, num_beams=5, early_stopping=True, max_length=60) 206 | # print('end') 207 | # greedy_str = decode(greedy_output.tolist()[0]) 208 | # beam5_str = decode(beam5_output.tolist()[0]) 209 | # 210 | # if '' in greedy_str: 211 | # greedy_str = greedy_str[greedy_str.find('') + 1 + len(''):].strip() 212 | # if '' in beam5_str: 213 | # beam5_str = beam5_str[beam5_str.find('')+1 + len(''):].strip() 214 | # 215 | # _input = decode(src) 216 | # _refs = list() 217 | # for ref in refs: 218 | # _ref = decode(ref) 219 | # _refs.append(_ref) 220 | # 221 | # greedy_test_decode_results['content'].append({'input': _input, 'output': greedy_str, 'refs': _refs}) 222 | # beam5_test_decode_results['content'].append({'input': _input, 'output': beam5_str, 'refs': _refs}) 223 | # 224 | # 225 | # with open(os.path.join(log_dir, 'greedy_gen_examples.json'), 'w') as f: 226 | # json.dump(greedy_test_decode_results, f) 227 | # 228 | # with open(os.path.join(log_dir, 'beam5_gen_examples.json'), 'w') as f: 229 | # json.dump(beam5_test_decode_results, f) 230 | # 231 | # # ----- TEST ----- # 232 | # 233 | # for i in range(10): 234 | # output = gen_model(**inputs).logits 235 | # last_toks = torch.argmax(output[:, -1, :], -1) 236 | # old_inputs = {key: val.tolist() for key, val in inputs.items()} 237 | # old_inputs['input_ids'][0].append(int(last_toks)) 238 | # 
old_inputs['attention_mask'][0].append(1) 239 | # inputs = {key : torch.tensor(val).to(device) for key, val in old_inputs.items()} 240 | # 241 | # 242 | # # -- Non LM model -- # 243 | # 244 | # for i in range(10): 245 | # output = model.forward_conditional_gen(input_ids=inputs['input_ids'], att_masks=inputs['attention_mask']) 246 | # last_toks = torch.argmax(output[:, -1, :], -1) 247 | # old_inputs = {key: val.tolist() for key, val in inputs.items()} 248 | # old_inputs['input_ids'][0].append(int(last_toks)) 249 | # old_inputs['attention_mask'][0].append(1) 250 | # inputs = {key : torch.tensor(val).to(device) for key, val in old_inputs.items()} 251 | -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import csv 4 | import operator 5 | import random 6 | 7 | 8 | def read_csv(input_file, quotechar='"', delimiter=",", skip_header=False): 9 | """Reads a tab separated value file.""" 10 | with open(input_file, "r") as f: 11 | reader = csv.reader(f, delimiter=delimiter, quotechar=quotechar, quoting=csv.QUOTE_ALL, skipinitialspace=True) 12 | lines = [] 13 | for line in reader: 14 | if sys.version_info[0] == 2: 15 | line = list(unicode(cell, 'utf-8') for cell in line) 16 | lines.append(line) 17 | if skip_header: 18 | lines = lines[1:] 19 | return lines 20 | 21 | 22 | def write_tsv(output_file, data, header=False): 23 | keys = list(data[0].keys()) 24 | with open(output_file, 'w') as f: 25 | w = csv.DictWriter(f, keys, delimiter='\t', lineterminator='\n') 26 | if header: 27 | w.writeheader() 28 | for r in data: 29 | entry = {k: r[k] for k in keys} 30 | w.writerow(entry) 31 | 32 | 33 | def write_array2tsv(output_file, data, header=False): 34 | keys = range(len(data[0])) 35 | with open(output_file, 'w') as f: 36 | w = csv.DictWriter(f, keys, delimiter='\t', lineterminator='\n') 37 | if header: 38 | w.writeheader() 39 | for r in data: 40 | entry = {k: r[k] for k in keys} 41 | w.writerow(entry) 42 | 43 | 44 | def write_csv(filename, data, fieldnames): 45 | with open(filename, 'w', newline='') as csvfile: 46 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 47 | 48 | writer.writeheader() 49 | for d in data: 50 | formatted_d = {} 51 | for key, val in d.items(): 52 | formatted_d[key] = json.dumps(val) 53 | writer.writerow(formatted_d) 54 | 55 | 56 | def read_jsonl(filename): 57 | data = [] 58 | with open(filename, "r") as f: 59 | for line in f: 60 | data.append(json.loads(line)) 61 | return data 62 | 63 | 64 | def write_items(output_file, items): 65 | with open(output_file, 'w') as f: 66 | for concept in items: 67 | f.write(concept + "\n") 68 | f.close() 69 | 70 | 71 | def write_jsonl(f, d): 72 | write_items(f, [json.dumps(r) for r in d]) 73 | 74 | 75 | def count_relation(d): 76 | relation_count = {} 77 | prefix_count = {} 78 | head_count = {} 79 | for l in d: 80 | r = l[1] 81 | if r not in relation_count.keys(): 82 | relation_count[r] = 0 83 | relation_count[r] += 1 84 | 85 | prefix = l[0]+l[1] 86 | if prefix not in prefix_count.keys(): 87 | prefix_count[prefix] = 0 88 | prefix_count[prefix] += 1 89 | 90 | head = l[0] 91 | if head not in head_count.keys(): 92 | head_count[head] = 0 93 | head_count[head] += 1 94 | 95 | sorted_relation_count = dict(sorted(relation_count.items(), key=operator.itemgetter(1), reverse=True)) 96 | sorted_prefix_count = dict(sorted(prefix_count.items(), key=operator.itemgetter(1), reverse=True)) 97 | 
sorted_head_count = dict(sorted(head_count.items(), key=operator.itemgetter(1), reverse=True)) 98 | 99 | print("Relations:") 100 | for r in sorted_relation_count.keys(): 101 | print(r, sorted_relation_count[r]) 102 | 103 | print("\nPrefixes:") 104 | print("uniq prefixes: ", len(sorted_prefix_count.keys())) 105 | i = 0 106 | for r in sorted_prefix_count.keys(): 107 | print(r, sorted_prefix_count[r]) 108 | i += 1 109 | if i > 20: 110 | break 111 | 112 | print("\nHeads:") 113 | i = 0 114 | for r in sorted_head_count.keys(): 115 | print(r, sorted_head_count[r]) 116 | i += 1 117 | if i > 20: 118 | break 119 | 120 | 121 | def get_head_set(d): 122 | return set([l[0] for l in d]) 123 | 124 | 125 | def head_based_split(data, dev_size, test_size, head_size_threshold=500, dev_heads=[], test_heads=[]): 126 | """ 127 | :param data: the tuples to split according to the heads, where the head is the first element of each tuple 128 | :param dev_size: target size of the dev set 129 | :param test_size: target size of the test set 130 | :param head_size_threshold: Maximum number of tuples a head can be involved in, 131 | in order to be considered for the dev/test set' 132 | :param dev_heads: heads that are forced to belong to the dev set 133 | :param test_heads: heads that are forced to belong to the test set 134 | :return: 135 | """ 136 | head_count = {} 137 | for l in data: 138 | head = l[0] 139 | if head not in head_count.keys(): 140 | head_count[head] = 0 141 | head_count[head] += 1 142 | 143 | remaining_heads = dict(head_count) 144 | 145 | test_selected_heads = {} 146 | test_head_total_count = 0 147 | 148 | for h in test_heads: 149 | if h in remaining_heads: 150 | c = remaining_heads[h] 151 | test_selected_heads[h] = c 152 | test_head_total_count += c 153 | remaining_heads.pop(h) 154 | 155 | while test_head_total_count < test_size: 156 | h = random.sample(remaining_heads.keys(), 1)[0] 157 | c = remaining_heads[h] 158 | if c < head_size_threshold: 159 | test_selected_heads[h] = c 160 | test_head_total_count += c 161 | remaining_heads.pop(h) 162 | 163 | test = [l for l in data if l[0] in test_selected_heads.keys()] 164 | 165 | dev_selected_heads = {} 166 | dev_head_total_count = 0 167 | 168 | for h in dev_heads: 169 | if h in remaining_heads: 170 | c = remaining_heads[h] 171 | dev_selected_heads[h] = c 172 | dev_head_total_count += c 173 | remaining_heads.pop(h) 174 | 175 | while dev_head_total_count < dev_size: 176 | h = random.sample(remaining_heads.keys(), 1)[0] 177 | c = remaining_heads[h] 178 | if c < head_size_threshold: 179 | dev_selected_heads[h] = c 180 | dev_head_total_count += c 181 | remaining_heads.pop(h) 182 | 183 | dev = [l for l in data if l[0] in dev_selected_heads.keys()] 184 | 185 | dev_test_heads = set(list(dev_selected_heads.keys()) + list(test_selected_heads.keys())) 186 | train = [l for l in data if l[0] not in dev_test_heads] 187 | 188 | return train, dev, test 189 | 190 | 191 | def remove_prefix(text, prefix): 192 | return text[text.startswith(prefix) and len(prefix):] 193 | -------------------------------------------------------------------------------- /src/finetune/finetune_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_finetune_trainer(model, device, loss_func, tb_period, writer, model_type): 5 | if model_type == 'enc-dec': 6 | return FineTuneTrainerForEncDec(model, device, loss_func, tb_period, writer) 7 | elif model_type == 'dec': 8 | return FineTuneTrainerForDec(model, device, loss_func, 
tb_period, writer)
    raise NotImplementedError


class FineTuneTrainer:
    def __init__(self, model, device, loss_func, tb_period, writer):
        self.model = model
        self.device = device
        self.loss_func = loss_func
        self.tb_period = tb_period
        self.writer = writer

    def train(self, sample, global_step, _type='train', save_tb=True):
        raise NotImplementedError


class FineTuneTrainerForEncDec(FineTuneTrainer):
    def __init__(self, model, device, loss_func, tb_period, writer):
        super(FineTuneTrainerForEncDec, self).__init__(model, device, loss_func, tb_period, writer)

    def train(self, sample, global_step, _type='train', save_tb=True):
        enc_input_ids = sample['enc_input_ids'].to(self.device)
        enc_att_masks = sample['enc_att_masks'].to(self.device)
        dec_input_ids = sample['dec_input_ids'].to(self.device)
        dec_att_masks = sample['dec_att_masks'].to(self.device)
        dec_label_ids = sample['dec_label_ids'].to(self.device)

        lm_logits = self.model.forward_conditional_gen(enc_input_ids, enc_att_masks, dec_input_ids, dec_att_masks)

        B, S, V = lm_logits.shape
        # token-level cross-entropy, averaged over the non-padding positions of each sequence
        loss = self.loss_func(lm_logits.view(B * S, V), dec_label_ids.view(-1)).view(B, S)
        loss = torch.sum(loss * dec_att_masks, dim=-1) / torch.sum(dec_att_masks, dim=-1)
        loss = torch.mean(loss)
        if global_step % self.tb_period == 0 and _type == 'train' and save_tb:
            self.writer.add_scalar('{}/loss'.format(_type), loss.item(), global_step)
            self.writer.flush()

        return loss


class FineTuneTrainerForDec(FineTuneTrainer):
    def __init__(self, model, device, loss_func, tb_period, writer):
        super(FineTuneTrainerForDec, self).__init__(model, device, loss_func, tb_period, writer)

    def train(self, sample, global_step, _type='train', save_tb=True):
        input_ids = sample['input_ids'].to(self.device)
        att_masks = sample['att_masks'].to(self.device)
        label_ids = sample['label_ids'].to(self.device)

        lm_logits = self.model.forward_conditional_gen(input_ids, att_masks)

        B, S, V = lm_logits.shape
        loss = self.loss_func(lm_logits.view(B * S, V), label_ids.view(-1)).view(B, S)
        loss = torch.sum(loss * att_masks, dim=-1) / torch.sum(att_masks, dim=-1)
        loss = torch.mean(loss)
        if global_step % self.tb_period == 0 and _type == 'train' and save_tb:
            self.writer.add_scalar('{}/loss'.format(_type), loss.item(), global_step)
            self.writer.flush()

        return loss

--------------------------------------------------------------------------------
/src/finetune/finetune_utils.py:
--------------------------------------------------------------------------------
from src.sampler_utils import load_raw_tsv
from torch.utils.data import Dataset
from src.sampler import BaseDataSampler
import random
import torch

random.seed(42)


def load_fntn_datasets(dataset_cfg, tokenizer, logger):
    logger.info('Target Dataset : {}'.format(dataset_cfg['name']))
    dataset = {key: load_raw_tsv(dataset_cfg['dir'][key], tokenizer, logger, dataset_cfg['truncate'])
               for key in dataset_cfg['dir']}

    return dataset


def get_finetune_dataset(dataset, dataset_cfg, tokenizer, logger, _type, model_type='enc-dec'):
    if model_type == 'enc-dec':
        fntn_dataset = FineTuneDatasetForEncDec(dataset, tokenizer, dataset_cfg['truncate'])
    elif model_type == 'dec':
        fntn_dataset = FineTuneDatasetForDec(dataset, tokenizer,
dataset_cfg['truncate']) 23 | else: 24 | raise NotImplementedError 25 | logger.info("Load Fine-tuning datasets [{}] : {}".format(_type, len(fntn_dataset))) 26 | return fntn_dataset 27 | 28 | 29 | def get_eval_dataset(dataset, tokenizer, logger, _type, model_type='enc-dec'): 30 | triple_dict = dict() 31 | for row in dataset: 32 | s, r, o = row 33 | key = s + '||' + r 34 | if triple_dict.get(key) is None: 35 | triple_dict[key] = list() 36 | triple_dict[key].append(o) 37 | 38 | outputs = list() 39 | for key in triple_dict: 40 | src = key.split('||') 41 | ref = list(set(triple_dict[key])) 42 | outputs.append({'src': src, 'ref': ref}) 43 | 44 | logger.info("Load Evaluation datasets [{}] : {}".format(_type, len(outputs))) 45 | 46 | if model_type == 'enc-dec': 47 | return EvalDatasetForEncDec(outputs, tokenizer, logger) 48 | elif model_type == 'dec': 49 | return EvalDatasetForDec(outputs, tokenizer, logger) 50 | else: 51 | raise NotImplementedError 52 | 53 | 54 | class EvalDataset(Dataset): 55 | def __init__(self, dataset, tokenizer, logger): 56 | self.dataset = dataset 57 | self.tokenizer = tokenizer 58 | self.logger = logger 59 | self.bos = self.tokenizer.bos_token_id 60 | self.eos = self.tokenizer.eos_token_id 61 | self.sep = self.tokenizer.sep_token_id 62 | self.gen = self.tokenizer.gen_token_id 63 | 64 | def formatting(self, sample): 65 | raise NotImplementedError 66 | 67 | def __len__(self): 68 | return len(self.dataset) 69 | 70 | def __getitem__(self, idx): 71 | sample = self.dataset[idx] 72 | return self.formatting(sample) 73 | 74 | 75 | class EvalDatasetForEncDec(EvalDataset): 76 | def __init__(self, dataset, tokenizer, logger): 77 | super(EvalDatasetForEncDec, self).__init__(dataset, tokenizer, logger) 78 | 79 | def formatting(self, sample): 80 | dp = {'src': None, 'ref': list()} 81 | 82 | src, ref_list = sample['src'], sample['ref'] 83 | s, r = [self.tokenizer(i) for i in src] 84 | _input = [self.bos] + s + [self.sep] + r + [self.sep] + [self.gen] + [self.eos] 85 | dp['src'] = _input 86 | for ref in ref_list: 87 | ref_token = self.tokenizer(ref) 88 | _output = [self.bos] + ref_token + [self.eos] 89 | dp['ref'].append(_output) 90 | 91 | if len(dp['ref']) == 0: 92 | raise Exception 93 | 94 | return dp 95 | 96 | 97 | class EvalDatasetForDec(EvalDataset): 98 | def __init__(self, dataset, tokenizer, logger): 99 | super(EvalDatasetForDec, self).__init__(dataset, tokenizer, logger) 100 | 101 | def __len__(self): 102 | return len(self.dataset) 103 | 104 | def formatting(self, sample): 105 | dp = {'src': None, 'ref': list()} 106 | 107 | src, ref_list = sample['src'], sample['ref'] 108 | s, r = [self.tokenizer(i) for i in src] 109 | _input = [self.bos] + s + [self.sep] + r + [self.sep] + [self.gen] 110 | dp['src'] = _input 111 | for ref in ref_list: 112 | ref_token = self.tokenizer(ref) 113 | _output = ref_token + [self.eos] 114 | dp['ref'].append(_output) 115 | 116 | if len(dp['ref']) == 0: 117 | raise Exception 118 | 119 | return dp 120 | 121 | 122 | class FineTuneDataset(Dataset): 123 | def __init__(self, dataset, tokenizer): 124 | super(FineTuneDataset, self).__init__() 125 | self.dataset = dataset 126 | self.tokenizer = tokenizer 127 | self.gen = self.tokenizer.gen_token_id 128 | self.bos = self.tokenizer.bos_token_id 129 | self.eos = self.tokenizer.eos_token_id 130 | self.sep = self.tokenizer.sep_token_id 131 | self.pad = self.tokenizer.pad_token_id 132 | self.shuffle() 133 | 134 | def __len__(self): 135 | return len(self.dataset) 136 | 137 | def formatting(self, sample): 138 | raise 
NotImplementedError 139 | 140 | def shuffle(self): 141 | random.shuffle(self.dataset) 142 | 143 | def __getitem__(self, idx): 144 | sample = self.dataset[idx] 145 | sample = self.formatting(sample) 146 | return sample 147 | 148 | 149 | class FineTuneDatasetForEncDec(FineTuneDataset): 150 | def __init__(self, dataset, tokenizer, truncate): 151 | super(FineTuneDatasetForEncDec, self).__init__(dataset, tokenizer) 152 | self.enc_len, self.dec_len = truncate['subj_len'] + 6, truncate['obj_len'] + 5 153 | 154 | def formatting(self, sample): 155 | s, r, o = sample 156 | s_tokens = self.tokenizer(s) 157 | r_tokens = self.tokenizer(r) 158 | o_tokens = self.tokenizer(o) 159 | 160 | _input = [self.bos] 161 | _input.extend(s_tokens + [self.sep]) 162 | _input.extend(r_tokens + [self.sep]) 163 | _input.extend([self.gen]) 164 | _input.extend([self.pad] * (self.enc_len - len(_input))) 165 | 166 | _output = [self.bos] + o_tokens + [self.eos] 167 | _output.extend([self.pad] * (self.dec_len - len(_output))) 168 | 169 | _enc_input_ids = torch.tensor(_input) 170 | _enc_att_masks = torch.ones_like(_enc_input_ids) * (_enc_input_ids != self.pad) 171 | _dec_origin = torch.tensor(_output) 172 | _dec_input_ids = _dec_origin[:-1].clone().detach() 173 | _dec_att_masks = torch.ones_like(_dec_input_ids) * (_dec_input_ids != self.pad) 174 | _dec_output_ids = _dec_origin[1:].clone().detach() 175 | 176 | output = {'enc_input_ids': _enc_input_ids, 177 | 'enc_att_masks': _enc_att_masks, 178 | 'dec_input_ids': _dec_input_ids, 179 | 'dec_att_masks': _dec_att_masks, 180 | 'dec_label_ids': _dec_output_ids} 181 | 182 | return output 183 | 184 | 185 | class FineTuneDatasetForDec(FineTuneDataset): 186 | def __init__(self, dataset, tokenizer, truncate): 187 | super(FineTuneDatasetForDec, self).__init__(dataset, tokenizer) 188 | self.dec_len = truncate['subj_len'] + 6 + truncate['obj_len'] + 5 189 | 190 | def __len__(self): 191 | return len(self.dataset) 192 | 193 | def formatting(self, sample): 194 | s, r, o = sample 195 | s_tokens = self.tokenizer(s) 196 | r_tokens = self.tokenizer(r) 197 | o_tokens = self.tokenizer(o) 198 | 199 | _input = [self.bos] 200 | _input.extend(s_tokens + [self.sep]) 201 | _input.extend(r_tokens + [self.sep, self.gen]) 202 | _input.extend(o_tokens + [self.eos]) 203 | 204 | _input.extend([self.pad] * (self.dec_len - len(_input))) 205 | _input_ids = _input[:-1] 206 | _label_ids = _input[1:] 207 | 208 | _input_ids = torch.tensor(_input_ids) 209 | _label_ids = torch.tensor(_label_ids) 210 | 211 | _att_masks = torch.ones_like(_input_ids) * (_input_ids != self.pad) 212 | output = {'input_ids': _input_ids, 213 | 'att_masks': _att_masks, 214 | 'label_ids': _label_ids} 215 | 216 | return output 217 | 218 | -------------------------------------------------------------------------------- /src/lr_schedule.py: -------------------------------------------------------------------------------- 1 | 2 | class WarmupLinearScheduler(object): 3 | def __init__(self, optimizer, max_lr, warmup_steps, total_steps): 4 | self.optimizer = optimizer 5 | self.lr = 0.0 6 | self.total_steps = total_steps 7 | self.warmup_steps = warmup_steps 8 | self.max_lr = max_lr 9 | self.warmup_increase = None 10 | self.lr_decay = None 11 | self._calculate() 12 | self._adapt_lr() 13 | 14 | def __call__(self, step): 15 | if step <= self.warmup_steps: 16 | self.lr += self.warmup_increase 17 | else: 18 | self.lr -= self.lr_decay 19 | assert not self.lr < 0 20 | self._adapt_lr() 21 | 22 | def _calculate(self): 23 | self.warmup_increase = self.max_lr / 
self.warmup_steps
        self.lr_decay = self.max_lr / (self.total_steps - self.warmup_steps)

    def _adapt_lr(self):
        for g in self.optimizer.param_groups:
            g['lr'] = self.lr
--------------------------------------------------------------------------------
/src/rec_adam.py:
--------------------------------------------------------------------------------
import logging
import math
import numpy as np

import torch
from torch.optim import Optimizer


logger = logging.getLogger(__name__)


def anneal_function(function, step, k, t0, weight):
    if function == 'sigmoid':
        return float(1 / (1 + np.exp(-k * (step - t0)))) * weight
    elif function == 'linear':
        return min(1, step / t0) * weight
    elif function == 'constant':
        return weight
    else:
        raise ValueError('Unknown anneal function: {}'.format(function))


class RecAdam(Optimizer):
    """ Implementation of the RecAdam optimizer, a variant of the Adam optimizer.
    Parameters:
        lr (float): learning rate. Default 1e-3.
        betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999)
        eps (float): Adam's epsilon. Default: 1e-6
        weight_decay (float): Weight decay. Default: 0.0
        correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in the Bert TF repository). Default True.
        anneal_fun (str): a hyperparameter of the annealing function that selects the shape of the curve. Default 'sigmoid'.
        anneal_k (float): a hyperparameter of the annealing function that sets the slope of the curve. Choice: [0.05, 0.1, 0.2, 0.5, 1]
        anneal_t0 (float): a hyperparameter of the annealing function that sets the midpoint of the curve. Choice: [100, 250, 500, 1000]
        anneal_w (float): a hyperparameter of the annealing function that sets the scale of the curve. Default 1.0.
        pretrain_cof (float): the coefficient of the quadratic penalty. Default 5000.0.
        pretrain_params (list of tensors): the corresponding group of params in the pretrained model.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True,
                 anneal_fun='sigmoid', anneal_k=0, anneal_t0=0, anneal_w=1.0, pretrain_cof=5000.0, pretrain_params=None):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias,
                        anneal_fun=anneal_fun, anneal_k=anneal_k, anneal_t0=anneal_t0, anneal_w=anneal_w,
                        pretrain_cof=pretrain_cof, pretrain_params=pretrain_params)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
59 | """ 60 | anneal_lambda = -1 61 | loss = None 62 | if closure is not None: 63 | loss = closure() 64 | for group in self.param_groups: 65 | for p, pp in zip(group["params"], group["pretrain_params"]): 66 | if p.grad is None: 67 | continue 68 | grad = p.grad.data 69 | if grad.is_sparse: 70 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 71 | 72 | state = self.state[p] 73 | 74 | # State initialization 75 | if len(state) == 0: 76 | state["step"] = 0 77 | # Exponential moving average of gradient values 78 | state["exp_avg"] = torch.zeros_like(p.data) 79 | # Exponential moving average of squared gradient values 80 | state["exp_avg_sq"] = torch.zeros_like(p.data) 81 | 82 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 83 | beta1, beta2 = group["betas"] 84 | 85 | state["step"] += 1 86 | 87 | # Decay the first and second moment running average coefficient 88 | # In-place operations to update the averages at the same time 89 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 90 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 91 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 92 | 93 | step_size = group["lr"] 94 | if group["correct_bias"]: 95 | bias_correction1 = 1.0 - beta1 ** state["step"] 96 | bias_correction2 = 1.0 - beta2 ** state["step"] 97 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 98 | 99 | # With RecAdam method, the optimization objective is 100 | # Loss = lambda(t)*Loss_T + (1-lambda(t))*Loss_S 101 | # Loss = lambda(t)*Loss_T + (1-lambda(t))*\gamma/2*\sum((\theta_i-\theta_i^*)^2) 102 | if group['anneal_w'] > 0.0: 103 | # We calculate the lambda as the annealing function 104 | anneal_lambda = anneal_function(group['anneal_fun'], state["step"], group['anneal_k'], 105 | group['anneal_t0'], group['anneal_w']) 106 | assert anneal_lambda <= group['anneal_w'] 107 | # The loss of the target task is multiplied by lambda(t) 108 | p.data.addcdiv_(-step_size * anneal_lambda, exp_avg, denom) 109 | # Add the quadratic penalty to simulate the pretraining tasks 110 | p.data.add_(-group["lr"] * (group['anneal_w'] - anneal_lambda) * group["pretrain_cof"], p.data - pp.data) 111 | else: 112 | p.data.addcdiv_(-step_size, exp_avg, denom) 113 | 114 | # Just adding the square of the weights to the loss function is *not* 115 | # the correct way of using L2 regularization/weight decay with Adam, 116 | # since that will interact with the m and v parameters in strange ways. 117 | # 118 | # Instead we want to decay the weights in a manner that doesn't interact 119 | # with the m/v parameters. This is equivalent to adding the square 120 | # of the weights to the loss with plain (non-momentum) SGD. 
121 |                 # Add weight decay at the end (fixed version)
122 |                 if group["weight_decay"] > 0.0:
123 |                     p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
124 | 
125 |         return loss, anneal_lambda
--------------------------------------------------------------------------------
/src/rec_adam_wrapper.py:
--------------------------------------------------------------------------------
1 | from src.rec_adam import RecAdam, anneal_function
2 | from transformers import get_linear_schedule_with_warmup
3 | 
4 | ##########
5 | '''
6 | #params = [p for n, p in model.named_parameters()]
7 | pretrained_weights = anneal_targets
8 | new_model = model
9 | no_decay = ["bias", "layer_norm.weight"]
10 | for n, p in model.named_parameters():
11 |     for nd in no_decay:
12 |         if nd in n:
13 |             print(n)
14 | 
15 | for n, p in new_model.named_parameters():
16 |     if not any(nd in n for nd in no_decay):
17 |         print(n)
18 |         if n in pretrained_weights:
19 |             print(n)
20 | pretrained_weights.keys()
21 | 
22 | any(nd in n for nd in no_decay)
23 | [p for n, p in new_model.named_parameters() if
24 |  not any(nd in n for nd in no_decay) and n in pretrained_weights]
25 | '''
26 | ##########
27 | 
28 | def get_adam_optimizer(new_model, pretrained_model, pretrained_weights, args, t_total):
29 |     no_decay = ["bias", "LayerNorm.weight"]
30 | 
31 |     optimizer_grouped_parameters = [
32 |         {
33 |             "params": [p for n, p in new_model.named_parameters() if
34 |                        not any(nd in n for nd in no_decay) and n in pretrained_weights],
35 |             "weight_decay": args.weight_decay,
36 |             "anneal_w": args.recadam_anneal_w,
37 |             "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
38 |                                 not any(nd in p_n for nd in no_decay) and p_n in pretrained_weights]
39 |         },
40 |         {
41 |             "params": [p for n, p in new_model.named_parameters() if
42 |                        not any(nd in n for nd in no_decay) and n not in pretrained_weights],
43 |             "weight_decay": args.weight_decay,
44 |             "anneal_w": 0.0,
45 |             "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
46 |                                 not any(nd in p_n for nd in no_decay) and p_n not in pretrained_weights]
47 |         },
48 |         {
49 |             "params": [p for n, p in new_model.named_parameters() if
50 |                        any(nd in n for nd in no_decay) and n in pretrained_weights],
51 |             "weight_decay": 0.0,
52 |             "anneal_w": args.recadam_anneal_w,
53 |             "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
54 |                                 any(nd in p_n for nd in no_decay) and p_n in pretrained_weights]
55 |         },
56 |         {
57 |             "params": [p for n, p in new_model.named_parameters() if
58 |                        any(nd in n for nd in no_decay) and n not in pretrained_weights],
59 |             "weight_decay": 0.0,
60 |             "anneal_w": 0.0,
61 |             "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
62 |                                 any(nd in p_n for nd in no_decay) and p_n not in pretrained_weights]
63 |         }
64 |     ]
65 | 
66 |     optimizer = RecAdam(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon,
67 |                         anneal_fun=args.recadam_anneal_fun,
68 |                         anneal_k=args.recadam_anneal_k, anneal_t0=args.recadam_anneal_t0,
69 |                         pretrain_cof=args.recadam_pretrain_cof)
70 | 
71 |     scheduler = get_linear_schedule_with_warmup(
72 |         optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
73 |     )
74 | 
75 |     return optimizer, scheduler
76 | 
--------------------------------------------------------------------------------
/src/sampler_utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pickle as pkl
3 | from tqdm import tqdm
4 | 
from copy import deepcopy 5 | import torch 6 | import os 7 | 8 | def load_raw_tsv(raw_path, tokenizer, logger, truncate): 9 | data = pd.read_csv(raw_path, delimiter='\t', header=None) 10 | BLK_TOKEN = tokenizer.blk_token 11 | triples = list() 12 | subj_len = truncate['subj_len'] 13 | obj_len = truncate['obj_len'] 14 | 15 | cached_path = raw_path[:-4] + '_cached_{}_{}.torch-pkl'.format(subj_len, obj_len) 16 | if os.path.exists(cached_path): 17 | logger.info('Load from {}'.format(cached_path)) 18 | return torch.load(cached_path) 19 | 20 | for row in tqdm(data.iloc, desc='loading_raw_file...', ncols=70): 21 | 22 | s, r, o = row 23 | 24 | if pd.isna(s) or pd.isna(r) or pd.isna(o): 25 | continue 26 | s = s.replace('___', BLK_TOKEN) 27 | r = '<{}>'.format(r) 28 | if o != o: # pass nan 29 | continue 30 | o = o.replace('___', BLK_TOKEN) 31 | 32 | if len(tokenizer(s)) > subj_len or len(tokenizer(o)) > obj_len: 33 | continue 34 | triples.append([s, r, o]) 35 | 36 | logger.info('Total loaded raw samples : {}'.format(len(triples))) 37 | 38 | torch.save(triples, cached_path) 39 | 40 | return triples 41 | 42 | 43 | def load_atomic_sim_pkl(raw_path, tokenizer, logger): 44 | logger.info('Load Similar matrix from {}'.format(raw_path)) 45 | with open(raw_path, 'rb') as f: 46 | data = pkl.load(f) 47 | blk_token = tokenizer.blk_token 48 | output = deepcopy(data) 49 | 50 | for s, i in data['s2i'].items(): 51 | s = s.replace('___', blk_token) 52 | output['s2i'][s] = i 53 | output['i2s'][i] = s 54 | return output 55 | -------------------------------------------------------------------------------- /src/tokenizer.py: -------------------------------------------------------------------------------- 1 | def adapt_commonsense_tokenizer(tokenizer, config): 2 | tokenizer.add_tokens([config['additional_tokens'][key] for key in config['additional_tokens']]) 3 | tokenizer.add_tokens([tokens for tokens in config['relation_tokens']]) 4 | tokenizer.add_special_tokens(config['special_tokens']) 5 | return tokenizer 6 | 7 | class BaseCSKGTokenizer: 8 | def __init__(self, tokenizer, config): 9 | self.tokenizer = tokenizer 10 | self.config = config 11 | 12 | self.tokenizer.add_tokens([self.config['additional_tokens'][key] for key in self.config['additional_tokens']]) 13 | self.tokenizer.add_tokens([tokens for tokens in self.config['relation_tokens']]) 14 | self.tokenizer.add_special_tokens(self.config['special_tokens']) 15 | 16 | # Additional Token Ids 17 | self.gen_token = self.config['additional_tokens']['gen_token'] 18 | self.gen_token_id = self.tokenize(self.gen_token)[0] 19 | self.blk_token = self.config['additional_tokens']['blk_token'] 20 | self.blk_token_id = self.tokenize(self.blk_token)[0] 21 | self.none_token = self.config['additional_tokens']['none_token'] 22 | self.none_token_id = self.tokenize(self.none_token)[0] 23 | self.gen_task_token = self.config['additional_tokens']['gen_task_token'] 24 | self.gen_task_token_id = self.tokenize(self.gen_task_token)[0] 25 | 26 | # Contrastive Task Tokens 27 | self.con_task_token = self.config['additional_tokens']['con_task_token'] 28 | self.con_task_token_id = self.tokenize(self.con_task_token)[0] 29 | self.con_task_token_s = self.config['additional_tokens']['con_task_token_for_s'] 30 | self.con_task_token_s_id = self.tokenize(self.con_task_token_s)[0] 31 | self.con_task_token_ro = self.config['additional_tokens']['con_task_token_for_ro'] 32 | self.con_task_token_ro_id = self.tokenize(self.con_task_token_ro)[0] 33 | 34 | # Reconstruct Task Tokens 35 | self.rec_task_token = 
self.config['additional_tokens']['rec_task_token']
36 |         self.rec_task_token_id = self.tokenize(self.rec_task_token)[0]
37 | 
38 |         # Mask Generate Task Tokens (this config entry is shared with the reconstruct task)
39 |         self.mask_gen_task_token = self.config['additional_tokens']['rec_task_token']
40 |         self.mask_gen_task_token_id = self.tokenize(self.mask_gen_task_token)[0]
41 | 
42 |         # Special Token Ids
43 |         self.bos_token = self.tokenizer.bos_token
44 |         self.bos_token_id = self.tokenizer.bos_token_id
45 |         self.eos_token = self.tokenizer.eos_token
46 |         self.eos_token_id = self.tokenizer.eos_token_id
47 |         self.pad_token = self.tokenizer.pad_token
48 |         self.pad_token_id = self.tokenizer.pad_token_id
49 |         self.mask_token = self.tokenizer.mask_token
50 |         self.mask_token_id = self.tokenizer.mask_token_id
51 |         self.sep_token = self.tokenizer.sep_token
52 |         self.sep_token_id = self.tokenizer.sep_token_id
53 | 
54 |     def tokenize(self, sequence):
55 |         raise NotImplementedError
56 | 
57 |     def __call__(self, sequence):
58 |         return self.tokenize(sequence)
59 | 
60 |     def __len__(self):
61 |         return len(self.tokenizer)
62 | 
63 | 
64 | class BartCSKGTokenizer(BaseCSKGTokenizer):
65 |     def __init__(self, tokenizer, config):
66 |         super(BartCSKGTokenizer, self).__init__(tokenizer, config)
67 | 
68 |     def tokenize(self, seq):
69 |         assert type(seq) is str
70 |         return self.tokenizer(seq)['input_ids'][1:-1]
71 | 
72 | 
73 | class Gpt2CSKGTokenizer(BaseCSKGTokenizer):
74 |     def __init__(self, tokenizer, config):
75 |         super(Gpt2CSKGTokenizer, self).__init__(tokenizer, config)
76 | 
77 |     def tokenize(self, seq):
78 |         assert type(seq) is str
79 |         return self.tokenizer(seq)['input_ids']
80 | 
81 | 
--------------------------------------------------------------------------------
/src/train_utils.py:
--------------------------------------------------------------------------------
1 | from src.feed_model import *
2 | from torch.utils.data import DataLoader
3 | 
4 | 
5 | def get_data_feeder_from_sampler(sampler, options, tokenizer, task_cfg, batch_size, iter_per_epoch):
6 |     data_feeder = dict()
7 |     enc_len = task_cfg['max_enc'] + 5
8 |     dec_len = task_cfg['max_dec'] + 5
9 |     for task_key, task_sampler in sampler.items():
10 |         for option_key, val in options['common'].items():
11 |             options[task_key][option_key] = val
12 | 
13 | 
14 |         if task_key == 'con_task':
15 |             if task_cfg['con_task']['method'] == 'cluster':
16 |                 group_num = task_cfg['con_task']['cluster_contrast']['group_num']
17 |             elif task_cfg['con_task']['method'] == 'naive':
18 |                 group_num = batch_size
19 |             else:
20 |                 raise NotImplementedError
21 | 
22 |             task_adaptor = ConTaskAdaptor(
23 |                 options[task_key], task_sampler, tokenizer, enc_len, dec_len, batch_size, iter_per_epoch, group_num)
24 |             data_feeder[task_key] = task_adaptor
25 | 
26 |         elif task_key == 'rec_inf_shu_task':
27 |             task_adaptor = RecInfShuTaskAdaptor(
28 |                 options[task_key], task_sampler, tokenizer, enc_len, dec_len, batch_size, iter_per_epoch * batch_size)
29 |             data_feeder[task_key] = task_adaptor
30 | 
31 |         elif task_key == 'mask_gen_task':
32 |             task_adaptor = MaskGenTaskAdaptor(
33 |                 options[task_key], task_sampler, tokenizer, enc_len, dec_len, batch_size, iter_per_epoch * batch_size)
34 |             data_feeder[task_key] = task_adaptor
35 | 
36 |     return data_feeder
37 | 
38 | 
39 | class WrappedLoader:
40 |     def __init__(self, dataset, batch_size):
41 |         self.loader = DataLoader(dataset, batch_size=batch_size, drop_last=False, shuffle=True, num_workers=10)
42 | 
43 | 
44 |     def __getitem__(self, idx):  # idx is ignored; each access yields freshly shuffled batches
45 |         for data in self.loader:
46 |             yield data
47 | 
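To make the tokenizer wrapper above concrete, here is a minimal, hedged usage sketch. The token strings and relation tokens below are placeholders (the real values live in config/tokenizer_config.yml), and only the config keys read by BaseCSKGTokenizer are shown:

from transformers import BartTokenizer

from src.tokenizer import BartCSKGTokenizer

# Placeholder config mirroring the keys BaseCSKGTokenizer reads.
cskg_config = {
    'additional_tokens': {
        'gen_token': '[GEN]',
        'blk_token': '[BLANK]',
        'none_token': '[NONE]',
        'gen_task_token': '[GEN_TASK]',
        'con_task_token': '[CON_TASK]',
        'con_task_token_for_s': '[CON_S]',
        'con_task_token_for_ro': '[CON_RO]',
        'rec_task_token': '[REC_TASK]',
    },
    'relation_tokens': ['<xWant>', '<xNeed>'],
    'special_tokens': {},
}

tok = BartCSKGTokenizer(BartTokenizer.from_pretrained('facebook/bart-base'), cskg_config)
ids = tok('PersonX goes to the store')  # BART ids with BOS/EOS stripped by tokenize()
print(len(tok), tok.blk_token_id)       # vocabulary grows by the added tokens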
-------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class RecTaskTrainerForEncDec: 5 | def __init__(self, model, device, loss_func, tb_period, writer, _type='train'): 6 | self.model = model 7 | self.device = device 8 | self.loss_func = loss_func 9 | self.tb_period = tb_period 10 | self.writer = writer 11 | self._type = _type 12 | 13 | def train(self, task_data, global_step, save=False): 14 | enc_input_ids = task_data['enc_input_ids'].to(self.device) 15 | enc_att_mask = task_data['enc_att_mask'].to(self.device) 16 | dec_input_ids = task_data['dec_input_ids'].to(self.device) 17 | dec_att_mask = task_data['dec_att_mask'].to(self.device) 18 | dec_label_ids = task_data['dec_label_ids'].to(self.device) 19 | 20 | lm_logits = self.model.forward_conditional_gen(enc_input_ids, enc_att_mask, dec_input_ids, dec_att_mask) 21 | 22 | B, S, V = lm_logits.shape 23 | 24 | loss = self.loss_func(lm_logits.view(B * S, V), dec_label_ids.view(-1)).view(B, S) 25 | loss = torch.sum(torch.mul(loss, dec_att_mask), -1) / torch.sum(dec_att_mask, -1) 26 | total_loss = torch.mean(loss) 27 | 28 | if (global_step % self.tb_period == 0) and save is True: 29 | self.writer.add_scalar('{}_rec_task/total_loss'.format(self._type), total_loss.item(), global_step) 30 | self.writer.flush() 31 | 32 | return total_loss 33 | 34 | 35 | class MaskGenTaskTrainerForDec: 36 | def __init__(self, model, device, loss_func, tb_period, writer, _type='train'): 37 | self.model = model 38 | self.device = device 39 | self.loss_func = loss_func 40 | self.tb_period = tb_period 41 | self.writer = writer 42 | self._type = _type 43 | 44 | def train(self, task_data, global_step, save=False): 45 | input_ids = task_data['input_ids'].to(self.device) 46 | att_mask = task_data['att_mask'].to(self.device) 47 | label_ids = task_data['label_ids'].to(self.device) 48 | 49 | lm_logits = self.model.forward_conditional_gen(input_ids, att_mask) 50 | 51 | B, S, V = lm_logits.shape 52 | 53 | loss = self.loss_func(lm_logits.view(B * S, V), label_ids.view(-1)).view(B, S) 54 | loss = torch.sum(torch.mul(loss, att_mask), -1) / torch.sum(att_mask, -1) 55 | total_loss = torch.mean(loss) 56 | 57 | if (global_step % self.tb_period == 0) and save is True: 58 | self.writer.add_scalar('{}_mask_gen_task/total_loss'.format(self._type), total_loss.item(), global_step) 59 | self.writer.flush() 60 | 61 | return total_loss 62 | 63 | 64 | 65 | class ConTaskTrainerBase: 66 | def __init__(self, model, device, distance_model, loss_func, tb_period, writer, share_proj_layer, args, _type='train'): 67 | self.model = model 68 | self.device = device 69 | self.distance_model = distance_model 70 | self.loss_func = loss_func 71 | self.tb_period = tb_period 72 | self.writer = writer 73 | self.share_s, self.share_ro = share_proj_layer 74 | self._type = _type 75 | 76 | class ConTaskTrainerForEncDec(ConTaskTrainerBase): 77 | def train(self, task_data, global_step, save=False): 78 | enc_input_ids_s = task_data['enc_input_ids_s'].to(self.device) 79 | enc_input_ids_ro = task_data['enc_input_ids_ro'].to(self.device) 80 | enc_att_mask_s = task_data['enc_att_mask_s'].to(self.device) 81 | enc_att_mask_ro = task_data['enc_att_mask_ro'].to(self.device) 82 | dec_input_ids_s = task_data['dec_input_ids_s'].to(self.device) 83 | dec_input_ids_ro = task_data['dec_input_ids_ro'].to(self.device) 84 | dec_att_mask_s = task_data['dec_att_mask_s'].to(self.device) 85 | 
dec_att_mask_ro = task_data['dec_att_mask_ro'].to(self.device) 86 | pos_sample = task_data['pos_samples'].to(self.device) 87 | 88 | s_repre = self.model.forward_latent_feature( 89 | enc_input_ids_s, enc_att_mask_s, dec_input_ids_s, dec_att_mask_s)#, None, self.share_s) 90 | ro_repre = self.model.forward_latent_feature( 91 | enc_input_ids_ro, enc_att_mask_ro, dec_input_ids_ro, dec_att_mask_ro)#, None, self.share_ro) 92 | 93 | dist_s_ro = self.distance_model(s_repre, ro_repre) 94 | dist_s_s = self.distance_model(s_repre, s_repre) 95 | dist_ro_ro = self.distance_model(ro_repre, ro_repre) 96 | 97 | sro_loss, sro_pos_dists, sro_neg_dists = self.loss_func.get_loss(dist_s_ro, pos_sample) 98 | ss_loss, ss_pos_dists, ss_neg_dists = self.loss_func.get_loss(dist_s_s, pos_sample) 99 | roro_loss, roro_pos_dists, roro_neg_dists = self.loss_func.get_loss(dist_ro_ro, pos_sample) 100 | 101 | loss = sro_loss + ss_loss + roro_loss 102 | pos_dists = sro_pos_dists + ss_pos_dists + roro_pos_dists 103 | neg_dists = sro_neg_dists + ss_neg_dists + roro_neg_dists 104 | 105 | if (global_step % self.tb_period == 0) and save is True: 106 | self.writer.add_scalar('{}_con_task/total_loss'.format(self._type), float(loss), global_step) 107 | self.writer.add_scalar('{}_con_task/s-ro_loss'.format(self._type), float(sro_loss), global_step) 108 | self.writer.add_scalar('{}_con_task/s-s_loss'.format(self._type), float(ss_loss), global_step) 109 | self.writer.add_scalar('{}_con_task/ro-ro_loss'.format(self._type), float(roro_loss), global_step) 110 | self.writer.add_scalar('{}_con_task/total_pos_dists'.format(self._type), float(pos_dists), global_step) 111 | self.writer.add_scalar('{}_con_task/total_neg_dists'.format(self._type), float(neg_dists), global_step) 112 | self.writer.add_scalar('{}_con_task/total_pos-neg'.format(self._type), float(pos_dists - neg_dists), global_step) 113 | self.writer.add_scalar('{}_con_task/s-ro_pos-neg'.format(self._type), float(sro_pos_dists - sro_neg_dists), global_step) 114 | self.writer.add_scalar('{}_con_task/s-s_pos-neg'.format(self._type), float(ss_pos_dists - ss_neg_dists), global_step) 115 | self.writer.add_scalar('{}_con_task/ro-ro_pos-neg'.format(self._type), float(roro_pos_dists - roro_neg_dists), global_step) 116 | self.writer.flush() 117 | 118 | return loss 119 | 120 | 121 | class ConTaskTrainerForDec(ConTaskTrainerBase): 122 | def train(self, task_data, global_step, save=False): 123 | input_ids_s = task_data['input_ids_s'].to(self.device) 124 | input_ids_ro = task_data['input_ids_ro'].to(self.device) 125 | att_mask_s = task_data['att_mask_s'].to(self.device) 126 | att_mask_ro = task_data['att_mask_ro'].to(self.device) 127 | pos_sample = task_data['pos_samples'].to(self.device) 128 | 129 | s_repre = self.model.forward_latent_feature( 130 | input_ids_s, att_mask_s) 131 | ro_repre = self.model.forward_latent_feature( 132 | input_ids_ro, att_mask_ro) 133 | 134 | dist_s_ro = self.distance_model(s_repre, ro_repre) 135 | dist_s_s = self.distance_model(s_repre, s_repre) 136 | dist_ro_ro = self.distance_model(ro_repre, ro_repre) 137 | 138 | sro_loss, sro_pos_dists, sro_neg_dists = self.loss_func.get_loss(dist_s_ro, pos_sample) 139 | ss_loss, ss_pos_dists, ss_neg_dists = self.loss_func.get_loss(dist_s_s, pos_sample) 140 | roro_loss, roro_pos_dists, roro_neg_dists = self.loss_func.get_loss(dist_ro_ro, pos_sample) 141 | 142 | loss = sro_loss + ss_loss + roro_loss 143 | pos_dists = sro_pos_dists + ss_pos_dists + roro_pos_dists 144 | neg_dists = sro_neg_dists + ss_neg_dists + roro_neg_dists 145 | 
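        # the objective sums one cross-view term (s vs. r+o) and two within-view
        # terms (s vs. s, r+o vs. r+o); the pooled pos/neg distances are tracked
        # below purely as TensorBoard diagnostics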
146 | if (global_step % self.tb_period == 0) and save is True: 147 | self.writer.add_scalar('{}_con_task/total_loss'.format(self._type), float(loss), global_step) 148 | self.writer.add_scalar('{}_con_task/s-ro_loss'.format(self._type), float(sro_loss), global_step) 149 | self.writer.add_scalar('{}_con_task/s-s_loss'.format(self._type), float(ss_loss), global_step) 150 | self.writer.add_scalar('{}_con_task/ro-ro_loss'.format(self._type), float(roro_loss), global_step) 151 | self.writer.add_scalar('{}_con_task/total_pos_dists'.format(self._type), float(pos_dists), global_step) 152 | self.writer.add_scalar('{}_con_task/total_neg_dists'.format(self._type), float(neg_dists), global_step) 153 | self.writer.add_scalar('{}_con_task/total_pos-neg'.format(self._type), float(pos_dists - neg_dists), global_step) 154 | self.writer.add_scalar('{}_con_task/s-ro_pos-neg'.format(self._type), float(sro_pos_dists - sro_neg_dists), global_step) 155 | self.writer.add_scalar('{}_con_task/s-s_pos-neg'.format(self._type), float(ss_pos_dists - ss_neg_dists), global_step) 156 | self.writer.add_scalar('{}_con_task/ro-ro_pos-neg'.format(self._type), float(roro_pos_dists - roro_neg_dists), global_step) 157 | self.writer.flush() 158 | 159 | return loss 160 | 161 | def mean_list(items): 162 | if type(items[0]) not in (list, tuple): 163 | return sum(items) / len(items) 164 | 165 | vals = 0 166 | nums = 0 167 | for val, num in items: 168 | vals += val 169 | nums += num 170 | 171 | return vals/nums -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import yaml 3 | 4 | 5 | def load_logger(log_dir, log_level): 6 | logger = logging.getLogger('CSKG') 7 | if log_level == 'INFO': 8 | lv = logging.INFO 9 | elif log_level == 'ERROR': 10 | lv = logging.ERROR 11 | elif log_level == 'DEBUG': 12 | lv = logging.DEBUG 13 | else: 14 | raise NotImplementedError 15 | logger.setLevel(lv) 16 | 17 | formatter = logging.Formatter('%(asctime)s [%(name)s] [%(levelname)s] :: %(message)s') 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | file_handler = logging.FileHandler(log_dir) 21 | file_handler.setFormatter(formatter) 22 | 23 | logger.addHandler(stream_handler) 24 | logger.addHandler(file_handler) 25 | 26 | return logger 27 | 28 | def load_yaml(f): 29 | if type(f) is str: 30 | with open(f, 'r') as fp: 31 | config = yaml.load(fp, Loader=yaml.FullLoader) 32 | else: 33 | raise NotImplementedError 34 | 35 | return config -------------------------------------------------------------------------------- /system_eval/automatic_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import json 5 | sys.path.append(os.path.join(os.getcwd(), 'system_eval')) 6 | 7 | import numpy as np 8 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 9 | from utils import read_jsonl, remove_prefix 10 | from evaluation.eval import QGEvalCap 11 | from tabulate import tabulate 12 | from tqdm import tqdm 13 | 14 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 15 | 16 | def get_refs_preds(l, type=1): 17 | if type==1: 18 | tails = l["fact"]["tails"] 19 | head = l["fact"]["head"] 20 | prompt = l["prompt"] 21 | generations = l["generations"] 22 | gens = [remove_prefix(g, prompt).strip() for g in generations] 23 | if type==2: 24 | tails = l["refs"] 25 | head = l["input"] 26 | gens = 
[l["output"]] 27 | if type==3: 28 | tails = l["fact"]["tails"] 29 | head = l["fact"]["head"] 30 | gens = l["generations"] 31 | return gens, tails, head 32 | 33 | def get2(l): 34 | return list(zip(*l))[1] 35 | 36 | 37 | def topk_eval(model_name, data, data_type, k): 38 | topk_gts = {} 39 | topk_res = {} 40 | topk_exact_match = [] 41 | topk_exact_match_not_none = [] 42 | topk_bleu_score = [] 43 | 44 | topk_is_head = [] 45 | func = SmoothingFunction() 46 | # for i, l in enumerate(data): 47 | for i, l in tqdm(enumerate(data)): 48 | (gens, tails, head) = get_refs_preds(l, type=data_type) 49 | gens = [g.replace('', '').replace('', '').strip() for g in gens] 50 | tails = [t.replace('', '').replace('', '').strip() for t in tails] 51 | head = head.replace('', '').replace('', '').strip() 52 | sentence_tails = [t.lower() for t in tails] 53 | split_tails = [t.lower().replace('', '').replace('', '').split() for t in tails] 54 | 55 | for (j, g) in enumerate(gens[:k]): 56 | key = str(i) + "_" + str(j) 57 | topk_gts[key] = sentence_tails 58 | topk_res[key] = [g.lower()] 59 | 60 | b = sentence_bleu(split_tails, g.lower().split(), weights=(0.5, 0.5), smoothing_function=func.method1) 61 | topk_bleu_score.append((l, b)) 62 | if g in sentence_tails: 63 | topk_exact_match.append((l, 1)) 64 | if g != "none": 65 | topk_exact_match_not_none.append((l, 1)) 66 | else: 67 | topk_exact_match.append((l, 0)) 68 | if g != "none": 69 | topk_exact_match_not_none.append((l, 0)) 70 | if g == head: 71 | topk_is_head.append((l, 1)) 72 | else: 73 | topk_is_head.append((l, 0)) 74 | 75 | print("---------------TOP K={}---------------".format(k)) 76 | #print(np.mean(get2(topk_exact_match))) 77 | #print(np.mean(get2(topk_exact_match_not_none))) 78 | print(np.mean(get2(topk_bleu_score))) 79 | QGEval = QGEvalCap(model_name, topk_gts, topk_res) 80 | score, scores = QGEval.evaluate() 81 | scores["Exact_match"] = np.mean(get2(topk_exact_match)) 82 | #scores["TailIsHead"] = np.mean(get2(topk_is_head)) 83 | return score, topk_bleu_score, scores 84 | 85 | 86 | def eval(data_file, data_type, model_name): 87 | data = read_jsonl(data_file) 88 | return topk_eval(model_name, data, data_type, k=1) 89 | 90 | def toRow(name, results, columns): 91 | return [name] + [format(float(results[c]), '#.3f') for c in columns] 92 | 93 | parser = argparse.ArgumentParser() 94 | 95 | parser.add_argument('--mode', type=str, default='pycharm') 96 | parser.add_argument('--file_dir', type=str, default='/mnt/data/user8/solar-commonsense_inference/log_fntn') 97 | parser.add_argument('--dataset_type', type=str, default='atomic') 98 | parser.add_argument('--model_name', type=str, default='bart') 99 | parser.add_argument('--model_size', type=str, default='large') 100 | parser.add_argument('--exp_type', type=str, default='baseline') 101 | parser.add_argument('--target', type=str, default=None) 102 | 103 | args = parser.parse_args() 104 | 105 | targets_list = list() 106 | 107 | target_path = f'{args.file_dir}/{args.dataset_type}/{args.model_name}-{args.model_size}_{args.exp_type}' 108 | if args.target is None: 109 | target_list = os.listdir(target_path) 110 | else: 111 | target_list = [args.target] 112 | 113 | target_list.sort() 114 | 115 | decode_type = ['greedy'] 116 | print(target_list) 117 | 118 | #target_list = target_list[:7] 119 | for target_file in target_list: 120 | for decode in decode_type: 121 | input_file = f'{target_path}/{target_file}/{decode}_gen_examples.json' 122 | output_file = f'{target_path}/{target_file}/eval/{decode}_results.txt' 123 | 
results_per_sample_file = f'{target_path}/{target_file}/eval/{decode}_results_per_sample.pkl'
124 |         try:
125 |             with open(input_file, 'r') as f:
126 |                 gen_payload = json.load(f)
127 |         except (OSError, json.JSONDecodeError):
128 |             continue
129 |         # Eval
130 |         print('TEST target : {}'.format(gen_payload['info']['ckpt']))
131 |         print(f'Decoded Type {decode}')
132 |         gen_data = gen_payload['content']
133 | 
134 |         scores, topk_bleu_score, score_list = topk_eval(model_name='BART-ATOMIC2020', data=gen_data, data_type=2, k=1)
135 | 
136 |         results_per_sample = list()
137 | 
138 |         for idx, sample in enumerate(gen_data):
139 |             sample_result = dict()
140 |             sample_result.update(sample)
141 | 
142 |             for key in score_list:
143 |                 if type(score_list[key]) is not list:
144 |                     continue
145 |                 val = score_list[key][idx]
146 |                 sample_result[key] = val
147 | 
148 |             results_per_sample.append(sample_result)
149 | 
150 |         print(scores)
151 |         for key in scores:
152 |             print(round(float(scores[key]) * 100, 4), end='\t')
153 | 
154 |         with open(output_file, 'w') as f:
155 |             for key, item in scores.items():
156 |                 f.write('{} : {}\n'.format(key, item))
157 | 
158 |         print('\n\n')
159 | 
160 |         import pickle as pkl
161 |         with open(results_per_sample_file, 'wb') as f:
162 |             pkl.dump(results_per_sample, f)
--------------------------------------------------------------------------------
/system_eval/evaluation/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Manish Joshi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /system_eval/evaluation/README.md: -------------------------------------------------------------------------------- 1 | # qgeval 2 | Calculate Bleu, METEOR, ROUGE and CIDEr score 3 | 4 | Usage: 5 | 6 | ```python anli_evaluation/eval.py --gen_file GENERATIONS_FILE --keys MODEL_KEYS[comma-separated list] --results_file RESULTS_FILE``` 7 | -------------------------------------------------------------------------------- /system_eval/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongho94/solar-framework_commonsense-inference/50e99ba0b0f5ae2315c72f5f35f6d385499283e1/system_eval/evaluation/__init__.py -------------------------------------------------------------------------------- /system_eval/evaluation/bert_score/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongho94/solar-framework_commonsense-inference/50e99ba0b0f5ae2315c72f5f35f6d385499283e1/system_eval/evaluation/bert_score/__init__.py -------------------------------------------------------------------------------- /system_eval/evaluation/bert_score/bert_score.py: -------------------------------------------------------------------------------- 1 | from bert_score import score 2 | # Code for BertScore reused from original implementation: https://github.com/Tiiiger/bert_score 3 | 4 | class BertScore: 5 | def __init__(self): 6 | self._hypo_for_image = {} 7 | self.ref_for_image = {} 8 | 9 | def compute_score(self, gts, res): 10 | 11 | assert(gts.keys() == res.keys()) 12 | imgIds = gts.keys() 13 | 14 | hyp_input = [] 15 | ref_input = [] 16 | same_indices = [] 17 | for id in imgIds: 18 | hypo = res[id] 19 | ref = gts[id] 20 | 21 | # Sanity check. 22 | assert(type(hypo) is list) 23 | assert(len(hypo) == 1) 24 | assert(type(ref) is list) 25 | assert(len(ref) >= 1) 26 | 27 | hyp_input += [hypo[0]] * len(ref) 28 | ref_input += ref 29 | same_indices.append(len(ref_input)) 30 | 31 | p, r, f_scores = score(hyp_input, ref_input, model_type="bert-base-uncased") 32 | 33 | prev_idx = 0 34 | aggreg_f1_scores = [] 35 | for idx in same_indices: 36 | aggreg_f1_scores.append(f_scores[prev_idx: idx].mean().cpu().item()) 37 | prev_idx = idx 38 | 39 | return sum(aggreg_f1_scores)/len(aggreg_f1_scores), aggreg_f1_scores 40 | 41 | def method(self): 42 | return "Bert Score" 43 | -------------------------------------------------------------------------------- /system_eval/evaluation/bert_score/score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import torch 5 | from collections import defaultdict 6 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | from .utils import get_idf_dict, bert_cos_score_idf,\ 12 | get_bert_embedding, bert_types 13 | 14 | __all__ = ['score', 'plot_example'] 15 | 16 | def score(cands, refs, bert="bert-base-multilingual-cased", 17 | num_layers=8, verbose=False, no_idf=False, batch_size=64): 18 | """ 19 | BERTScore metric. 
20 | Args: 21 | - :param: `cands` (list of str): candidate sentences 22 | - :param: `refs` (list of str): reference sentences 23 | - :param: `bert` (str): bert specification 24 | - :param: `num_layers` (int): the layer of representation to use 25 | - :param: `verbose` (bool): turn on intermediate status update 26 | - :param: `no_idf` (bool): do not use idf weighting 27 | - :param: `batch_size` (int): bert score processing batch size 28 | """ 29 | assert len(cands) == len(refs) 30 | assert bert in bert_types 31 | 32 | tokenizer = BertTokenizer.from_pretrained(bert) 33 | model = BertModel.from_pretrained(bert) 34 | model.eval() 35 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 36 | model.to(device) 37 | 38 | # drop unused layers 39 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]]) 40 | 41 | if no_idf: 42 | idf_dict = defaultdict(lambda: 1.) 43 | # set idf for [SEP] and [CLS] to 0 44 | idf_dict[101] = 0 45 | idf_dict[102] = 0 46 | else: 47 | if verbose: 48 | print('preparing IDF dict...') 49 | start = time.perf_counter() 50 | idf_dict = get_idf_dict(refs, tokenizer) 51 | if verbose: 52 | print('done in {:.2f} seconds'.format(time.perf_counter() - start)) 53 | 54 | if verbose: 55 | print('calculating scores...') 56 | start = time.perf_counter() 57 | all_preds = bert_cos_score_idf(model, refs, cands, tokenizer, idf_dict, 58 | verbose=verbose, device=device, batch_size=batch_size) 59 | 60 | P = all_preds[:, 0].cpu() 61 | R = all_preds[:, 1].cpu() 62 | F1 = all_preds[:, 2].cpu() 63 | if verbose: 64 | print('done in {:.2f} seconds'.format(time.perf_counter() - start)) 65 | 66 | return P, R, F1 67 | 68 | def plot_example(h, r, verbose=False, bert="bert-base-multilingual-cased", 69 | num_layers=8, fname=''): 70 | """ 71 | BERTScore metric. 72 | Args: 73 | - :param: `h` (str): a candidate sentence 74 | - :param: `r` (str): a reference sentence 75 | - :param: `verbose` (bool): turn on intermediate status update 76 | - :param: `bert` (str): bert specification 77 | - :param: `num_layers` (int): the layer of representation to use 78 | """ 79 | assert bert in bert_types 80 | 81 | if verbose: 82 | print('loading BERT model...') 83 | tokenizer = BertTokenizer.from_pretrained(bert) 84 | model = BertModel.from_pretrained(bert) 85 | model.eval() 86 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 87 | model.to(device) 88 | 89 | h_tokens = ['[CLS]'] + tokenizer.tokenize(h) + ['[SEP]'] 90 | r_tokens = ['[CLS]'] + tokenizer.tokenize(r) + ['[SEP]'] 91 | 92 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]]) 93 | idf_dict = defaultdict(lambda: 1.) 
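    # a constant idf of 1 disables idf re-weighting for this single-pair
    # visualization, so every word piece contributes equally to the similarity map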
94 | 95 | ref_embedding, ref_lens, ref_masks, padded_idf = get_bert_embedding([r], model, tokenizer, idf_dict, 96 | device=device) 97 | hyp_embedding, ref_lens, ref_masks, padded_idf = get_bert_embedding([h], model, tokenizer, idf_dict, 98 | device=device) 99 | 100 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1)) 101 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1)) 102 | 103 | batch_size = ref_embedding.size(1) 104 | 105 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2)).cpu() 106 | sim = sim.squeeze(0).numpy() 107 | 108 | # remove [CLS] and [SEP] tokens 109 | r_tokens = r_tokens[1:-1] 110 | h_tokens = h_tokens[1:-1] 111 | sim = sim[1:-1,1:-1] 112 | 113 | fig, ax = plt.subplots(figsize=(len(r_tokens)*0.8, len(h_tokens)*0.8)) 114 | im = ax.imshow(sim, cmap='Blues') 115 | 116 | # We want to show all ticks... 117 | ax.set_xticks(np.arange(len(r_tokens))) 118 | ax.set_yticks(np.arange(len(h_tokens))) 119 | # ... and label them with the respective list entries 120 | ax.set_xticklabels(r_tokens, fontsize=10) 121 | ax.set_yticklabels(h_tokens, fontsize=10) 122 | plt.xlabel("Refernce", fontsize=10) 123 | plt.ylabel("Candidate", fontsize=10) 124 | 125 | # Rotate the tick labels and set their alignment. 126 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 127 | rotation_mode="anchor") 128 | 129 | # Loop over data dimensions and create text annotations. 130 | for i in range(len(h_tokens)): 131 | for j in range(len(r_tokens)): 132 | text = ax.text(j, i, '{:.3f}'.format(sim[i, j]), 133 | ha="center", va="center", color="k" if sim[i, j] < 0.6 else "w") 134 | 135 | # P = sim.max(1).mean() 136 | # R = sim.max(0).mean() 137 | # F1 = 2 * P * R / (P + R) 138 | 139 | fig.tight_layout() 140 | # plt.title("BERT-F1: {:.3f}".format(F1), fontsize=10) 141 | if fname != "": 142 | print("Saved figure to file: ", fname+".png") 143 | plt.savefig(fname+'.png', dpi=100) 144 | plt.show() 145 | -------------------------------------------------------------------------------- /system_eval/evaluation/bert_score/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import log 3 | from itertools import chain 4 | from collections import defaultdict, Counter 5 | from multiprocessing import Pool 6 | from functools import partial 7 | from tqdm.auto import tqdm 8 | 9 | __all__ = ['bert_types'] 10 | 11 | bert_types = [ 12 | 'bert-base-uncased', 13 | 'bert-large-uncased', 14 | 'bert-base-cased', 15 | 'bert-large-cased', 16 | 'bert-base-multilingual-uncased', 17 | 'bert-base-multilingual-cased', 18 | 'bert-base-chinese', 19 | ] 20 | 21 | def padding(arr, pad_token, dtype=torch.long): 22 | lens = torch.LongTensor([len(a) for a in arr]) 23 | max_len = lens.max().item() 24 | padded = torch.ones(len(arr), max_len, dtype=dtype) * pad_token 25 | mask = torch.zeros(len(arr), max_len, dtype=torch.long) 26 | for i, a in enumerate(arr): 27 | padded[i, :lens[i]] = torch.tensor(a, dtype=dtype) 28 | mask[i, :lens[i]] = 1 29 | return padded, lens, mask 30 | 31 | 32 | def bert_encode(model, x, attention_mask): 33 | model.eval() 34 | x_seg = torch.zeros_like(x, dtype=torch.long) 35 | with torch.no_grad(): 36 | x_encoded_layers, pooled_output = model(x, x_seg, attention_mask=attention_mask, output_all_encoded_layers=False) 37 | return x_encoded_layers 38 | 39 | 40 | def process(a, tokenizer=None): 41 | if not tokenizer is None: 42 | a = ["[CLS]"]+tokenizer.tokenize(a)+["[SEP]"] 43 | a = tokenizer.convert_tokens_to_ids(a) 
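    # return a set so each word piece counts at most once per sentence --
    # document frequency, not term frequency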
44 | return set(a) 45 | 46 | 47 | def get_idf_dict(arr, tokenizer, nthreads=4): 48 | """ 49 | Returns mapping from word piece index to its inverse document frequency. 50 | Args: 51 | - :param: `arr` (list of str) : sentences to process. 52 | - :param: `tokenizer` : a BERT tokenizer corresponds to `model`. 53 | - :param: `nthreads` (int) : number of CPU threads to use 54 | """ 55 | idf_count = Counter() 56 | num_docs = len(arr) 57 | 58 | process_partial = partial(process, tokenizer=tokenizer) 59 | 60 | with Pool(nthreads) as p: 61 | idf_count.update(chain.from_iterable(p.map(process_partial, arr))) 62 | 63 | idf_dict = defaultdict(lambda : log((num_docs+1)/(1))) 64 | idf_dict.update({idx:log((num_docs+1)/(c+1)) for (idx, c) in idf_count.items()}) 65 | return idf_dict 66 | 67 | 68 | def collate_idf(arr, tokenize, numericalize, idf_dict, 69 | pad="[PAD]", device='cuda:0'): 70 | """ 71 | Helper function that pads a list of sentences to hvae the same length and 72 | loads idf score for words in the sentences. 73 | Args: 74 | - :param: `arr` (list of str): sentences to process. 75 | - :param: `tokenize` : a function that takes a string and return list 76 | of tokens. 77 | - :param: `numericalize` : a function that takes a list of tokens and 78 | return list of token indexes. 79 | - :param: `idf_dict` (dict): mapping a word piece index to its 80 | inverse document frequency 81 | - :param: `pad` (str): the padding token. 82 | - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda' 83 | """ 84 | arr = [["[CLS]"]+tokenize(a)+["[SEP]"] for a in arr] 85 | arr = [numericalize(a) for a in arr] 86 | 87 | idf_weights = [[idf_dict[i] for i in a] for a in arr] 88 | 89 | pad_token = numericalize([pad])[0] 90 | 91 | padded, lens, mask = padding(arr, pad_token, dtype=torch.long) 92 | padded_idf, _, _ = padding(idf_weights, pad_token, dtype=torch.float) 93 | 94 | padded = padded.to(device=device) 95 | mask = mask.to(device=device) 96 | lens = lens.to(device=device) 97 | return padded, padded_idf, lens, mask 98 | 99 | 100 | def get_bert_embedding(all_sens, model, tokenizer, idf_dict, 101 | batch_size=-1, device='cuda:0'): 102 | """ 103 | Compute BERT embedding in batches. 104 | Args: 105 | - :param: `all_sens` (list of str) : sentences to encode. 106 | - :param: `model` : a BERT model from `pytorch_pretrained_bert`. 107 | - :param: `tokenizer` : a BERT tokenizer corresponds to `model`. 108 | - :param: `idf_dict` (dict) : mapping a word piece index to its 109 | inverse document frequency 110 | - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda' 111 | """ 112 | 113 | padded_sens, padded_idf, lens, mask = collate_idf(all_sens, 114 | tokenizer.tokenize, tokenizer.convert_tokens_to_ids, 115 | idf_dict, 116 | device=device) 117 | 118 | if batch_size == -1: batch_size = len(all_sens) 119 | 120 | embeddings = [] 121 | with torch.no_grad(): 122 | for i in range(0, len(all_sens), batch_size): 123 | batch_embedding = bert_encode(model, padded_sens[i:i+batch_size], 124 | attention_mask=mask[i:i+batch_size]) 125 | # batch_embedding = torch.stack(batch_embedding) 126 | embeddings.append(batch_embedding) 127 | del batch_embedding 128 | 129 | total_embedding = torch.cat(embeddings, dim=0) 130 | 131 | return total_embedding, lens, mask, padded_idf 132 | 133 | 134 | def greedy_cos_idf(ref_embedding, ref_lens, ref_masks, ref_idf, 135 | hyp_embedding, hyp_lens, hyp_masks, hyp_idf): 136 | """ 137 | Compute greedy matching based on cosine similarity. 
138 | Args: 139 | - :param: `ref_embedding` (torch.Tensor): 140 | embeddings of reference sentences, BxKxd, 141 | B: batch size, K: longest length, d: bert dimenison 142 | - :param: `ref_lens` (list of int): list of reference sentence length. 143 | - :param: `ref_masks` (torch.LongTensor): BxKxK, BERT attention mask for 144 | reference sentences. 145 | - :param: `ref_idf` (torch.Tensor): BxK, idf score of each word 146 | piece in the reference setence 147 | - :param: `hyp_embedding` (torch.Tensor): 148 | embeddings of candidate sentences, BxKxd, 149 | B: batch size, K: longest length, d: bert dimenison 150 | - :param: `hyp_lens` (list of int): list of candidate sentence length. 151 | - :param: `hyp_masks` (torch.LongTensor): BxKxK, BERT attention mask for 152 | candidate sentences. 153 | - :param: `hyp_idf` (torch.Tensor): BxK, idf score of each word 154 | piece in the candidate setence 155 | """ 156 | 157 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1)) 158 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1)) 159 | 160 | batch_size = ref_embedding.size(0) 161 | 162 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2)) 163 | masks = torch.bmm(hyp_masks.unsqueeze(2).float(), ref_masks.unsqueeze(1).float()) 164 | masks = masks.expand(batch_size, masks.size(1), masks.size(2))\ 165 | .contiguous().view_as(sim) 166 | 167 | masks = masks.float().to(sim.device) 168 | sim = sim * masks 169 | 170 | word_precision = sim.max(dim=2)[0] 171 | word_recall = sim.max(dim=1)[0] 172 | 173 | hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True)) 174 | ref_idf.div_(ref_idf.sum(dim=1, keepdim=True)) 175 | precision_scale = hyp_idf.to(word_precision.device) 176 | recall_scale = ref_idf.to(word_recall.device) 177 | P = (word_precision * precision_scale).sum(dim=1) 178 | R = (word_recall * recall_scale).sum(dim=1) 179 | 180 | F = 2 * P * R / (P + R) 181 | return P, R, F 182 | 183 | def bert_cos_score_idf(model, refs, hyps, tokenizer, idf_dict, 184 | verbose=False, batch_size=64, device='cuda:0'): 185 | """ 186 | Compute BERTScore. 187 | Args: 188 | - :param: `model` : a BERT model in `pytorch_pretrained_bert` 189 | - :param: `refs` (list of str): reference sentences 190 | - :param: `hyps` (list of str): candidate sentences 191 | - :param: `tokenzier` : a BERT tokenizer corresponds to `model` 192 | - :param: `idf_dict` : a dictionary mapping a word piece index to its 193 | inverse document frequency 194 | - :param: `verbose` (bool): turn on intermediate status update 195 | - :param: `batch_size` (int): bert score processing batch size 196 | - :param: `device` (str): device to use, e.g. 
'cpu' or 'cuda' 197 | """ 198 | preds = [] 199 | iter_range = range(0, len(refs), batch_size) 200 | if verbose: iter_range = tqdm(iter_range) 201 | for batch_start in iter_range: 202 | batch_refs = refs[batch_start:batch_start+batch_size] 203 | batch_hyps = hyps[batch_start:batch_start+batch_size] 204 | ref_stats = get_bert_embedding(batch_refs, model, tokenizer, idf_dict, 205 | device=device) 206 | hyp_stats = get_bert_embedding(batch_hyps, model, tokenizer, idf_dict, 207 | device=device) 208 | 209 | P, R, F1 = greedy_cos_idf(*ref_stats, *hyp_stats) 210 | preds.append(torch.stack((P, R, F1), dim=1).cpu()) 211 | preds = torch.cat(preds, dim=0) 212 | return preds 213 | -------------------------------------------------------------------------------- /system_eval/evaluation/bleu/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /system_eval/evaluation/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /system_eval/evaluation/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /system_eval/evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from evaluation.bleu.bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | assert(gts.keys() == res.keys()) 23 | imgIds = gts.keys() 24 | bleu_scorer = BleuScorer(n=self._n) 25 | for id in imgIds: 26 | hypo = res[id] 27 | # print("hypo:", hypo) 28 | ref = gts[id] 29 | # print("ref:", ref) 30 | 31 | # Sanity check. 
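            # (exactly one hypothesis per id, and at least one reference)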
32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | #score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1) 41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=0) 42 | 43 | # return (bleu, bleu_info) 44 | return score, scores 45 | 46 | def method(self): 47 | return "Bleu" 48 | -------------------------------------------------------------------------------- /system_eval/evaluation/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | def precook(s, n=4, out=False): 24 | """Takes a string as input and returns an object that can be given to 25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 26 | can take string arguments as well.""" 27 | words = s.split() 28 | counts = defaultdict(int) 29 | for k in range(1,n+1): 30 | for i in range(len(words)-k+1): 31 | ngram = tuple(words[i:i+k]) 32 | counts[ngram] += 1 33 | return (len(words), counts) 34 | 35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 36 | '''Takes a list of reference sentences for a single segment 37 | and returns an object that encapsulates everything that BLEU 38 | needs to know about them.''' 39 | 40 | reflen = [] 41 | maxcounts = {} 42 | for ref in refs: 43 | rl, counts = precook(ref, n) 44 | reflen.append(rl) 45 | for (ngram,count) in counts.items(): 46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 47 | 48 | # Calculate effective reference sentence length. 49 | if eff == "shortest": 50 | reflen = min(reflen) 51 | elif eff == "average": 52 | reflen = float(sum(reflen))/len(reflen) 53 | 54 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 55 | 56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 57 | 58 | return (reflen, maxcounts) 59 | 60 | def cook_test(test, tup, eff=None, n=4): 61 | '''Takes a test sentence and returns an object that 62 | encapsulates everything that BLEU needs to know about it.''' 63 | 64 | (reflen, refmaxcounts) = tup 65 | testlen, counts = precook(test, n, True) 66 | 67 | result = {} 68 | 69 | # Calculate effective reference sentence length. 
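    # 'closest' picks the reference length nearest the hypothesis length;
    # compute_score() later uses this length for the brevity penalty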
70 | 71 | if eff == "closest": 72 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 73 | else: ## i.e., "average" or "shortest" or None 74 | result["reflen"] = reflen 75 | 76 | result["testlen"] = testlen 77 | 78 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)] 79 | 80 | result['correct'] = [0]*n 81 | for (ngram, count) in counts.items(): 82 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 83 | 84 | return result 85 | 86 | class BleuScorer(object): 87 | """Bleu scorer. 88 | """ 89 | 90 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 91 | # special_reflen is used in oracle (proportional effective ref len for a node). 92 | 93 | def copy(self): 94 | ''' copy the refs.''' 95 | new = BleuScorer(n=self.n) 96 | new.ctest = copy.copy(self.ctest) 97 | new.crefs = copy.copy(self.crefs) 98 | new._score = None 99 | return new 100 | 101 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 102 | ''' singular instance ''' 103 | 104 | self.n = n 105 | self.crefs = [] 106 | self.ctest = [] 107 | self.cook_append(test, refs) 108 | self.special_reflen = special_reflen 109 | 110 | def cook_append(self, test, refs): 111 | '''called by constructor and __iadd__ to avoid creating new instances.''' 112 | 113 | if refs is not None: 114 | self.crefs.append(cook_refs(refs)) 115 | if test is not None: 116 | cooked_test = cook_test(test, self.crefs[-1]) 117 | self.ctest.append(cooked_test) ## N.B.: -1 118 | else: 119 | self.ctest.append(None) # lens of crefs and ctest have to match 120 | 121 | self._score = None ## need to recompute 122 | 123 | def ratio(self, option=None): 124 | self.compute_score(option=option) 125 | return self._ratio 126 | 127 | def score_ratio(self, option=None): 128 | '''return (bleu, len_ratio) pair''' 129 | return (self.fscore(option=option), self.ratio(option=option)) 130 | 131 | def score_ratio_str(self, option=None): 132 | return "%.4f (%.2f)" % self.score_ratio(option) 133 | 134 | def reflen(self, option=None): 135 | self.compute_score(option=option) 136 | return self._reflen 137 | 138 | def testlen(self, option=None): 139 | self.compute_score(option=option) 140 | return self._testlen 141 | 142 | def retest(self, new_test): 143 | if type(new_test) is str: 144 | new_test = [new_test] 145 | assert len(new_test) == len(self.crefs), new_test 146 | self.ctest = [] 147 | for t, rs in zip(new_test, self.crefs): 148 | self.ctest.append(cook_test(t, rs)) 149 | self._score = None 150 | 151 | return self 152 | 153 | def rescore(self, new_test): 154 | ''' replace test(s) with new test(s), and returns the new score.''' 155 | 156 | return self.retest(new_test).compute_score() 157 | 158 | def size(self): 159 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 160 | return len(self.crefs) 161 | 162 | def __iadd__(self, other): 163 | '''add an instance (e.g., from another sentence).''' 164 | 165 | if type(other) is tuple: 166 | ## avoid creating new BleuScorer instances 167 | self.cook_append(other[0], other[1]) 168 | else: 169 | assert self.compatible(other), "incompatible BLEUs." 
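            # merge the other scorer's cooked hypotheses/references and
            # invalidate the cached score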
170 | self.ctest.extend(other.ctest) 171 | self.crefs.extend(other.crefs) 172 | self._score = None ## need to recompute 173 | 174 | return self 175 | 176 | def compatible(self, other): 177 | return isinstance(other, BleuScorer) and self.n == other.n 178 | 179 | def single_reflen(self, option="average"): 180 | return self._single_reflen(self.crefs[0][0], option) 181 | 182 | def _single_reflen(self, reflens, option=None, testlen=None): 183 | 184 | if option == "shortest": 185 | reflen = min(reflens) 186 | elif option == "average": 187 | reflen = float(sum(reflens))/len(reflens) 188 | elif option == "closest": 189 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 190 | else: 191 | assert False, "unsupported reflen option %s" % option 192 | 193 | return reflen 194 | 195 | def recompute_score(self, option=None, verbose=0): 196 | self._score = None 197 | return self.compute_score(option, verbose) 198 | 199 | def compute_score(self, option=None, verbose=0): 200 | n = self.n 201 | small = 1e-9 202 | tiny = 1e-15 ## so that if guess is 0 still return 0 203 | bleu_list = [[] for _ in range(n)] 204 | 205 | if self._score is not None: 206 | return self._score 207 | 208 | if option is None: 209 | option = "average" if len(self.crefs) == 1 else "closest" 210 | 211 | self._testlen = 0 212 | self._reflen = 0 213 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 214 | 215 | # for each sentence 216 | for comps in self.ctest: 217 | testlen = comps['testlen'] 218 | self._testlen += testlen 219 | 220 | if self.special_reflen is None: ## need computation 221 | reflen = self._single_reflen(comps['reflen'], option, testlen) 222 | else: 223 | reflen = self.special_reflen 224 | 225 | self._reflen += reflen 226 | 227 | for key in ['guess','correct']: 228 | for k in range(n): 229 | totalcomps[key][k] += comps[key][k] 230 | 231 | # append per image bleu score 232 | bleu = 1. 233 | for k in range(n): 234 | bleu *= (float(comps['correct'][k]) + tiny) \ 235 | /(float(comps['guess'][k]) + small) 236 | bleu_list[k].append(bleu ** (1./(k+1))) 237 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 238 | if ratio < 1: 239 | for k in range(n): 240 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 241 | 242 | if verbose > 1: 243 | print(comps, reflen) 244 | 245 | totalcomps['reflen'] = self._reflen 246 | totalcomps['testlen'] = self._testlen 247 | 248 | bleus = [] 249 | bleu = 1. 
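        # corpus-level BLEU-k: geometric mean of the pooled modified n-gram
        # precisions (smoothed by tiny/small), scaled below by the brevity
        # penalty when the corpus is shorter than its references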
250 | for k in range(n): 251 | bleu *= float(totalcomps['correct'][k] + tiny) \ 252 | / (totalcomps['guess'][k] + small) 253 | bleus.append(bleu ** (1./(k+1))) 254 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 255 | if ratio < 1: 256 | for k in range(n): 257 | bleus[k] *= math.exp(1 - 1/ratio) 258 | 259 | if verbose > 0: 260 | print(totalcomps) 261 | print("ratio:", ratio) 262 | 263 | self._score = bleus 264 | return self._score, bleu_list 265 | -------------------------------------------------------------------------------- /system_eval/evaluation/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /system_eval/evaluation/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from evaluation.cider.cider_scorer import CiderScorer 11 | import pdb 12 | 13 | class Cider: 14 | """ 15 | Main Class to compute the CIDEr metric 16 | 17 | """ 18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 19 | # set cider to sum over 1 to 4-grams 20 | self._n = n 21 | # set the standard deviation parameter for gaussian penalty 22 | self._sigma = sigma 23 | 24 | def compute_score(self, gts, res): 25 | """ 26 | Main function to compute CIDEr score 27 | :param hypo_for_image (dict) : dictionary with key and value 28 | ref_for_image (dict) : dictionary with key and value 29 | :return: cider (float) : computed CIDEr score for the corpus 30 | """ 31 | 32 | assert(gts.keys() == res.keys()) 33 | imgIds = gts.keys() 34 | 35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 36 | 37 | for id in imgIds: 38 | hypo = res[id] 39 | ref = gts[id] 40 | 41 | # Sanity check. 42 | assert(type(hypo) is list) 43 | assert(len(hypo) == 1) 44 | assert(type(ref) is list) 45 | assert(len(ref) > 0) 46 | 47 | cider_scorer += (hypo[0], ref) 48 | 49 | (score, scores) = cider_scorer.compute_score() 50 | 51 | return score, scores 52 | 53 | def method(self): 54 | return "CIDEr" 55 | -------------------------------------------------------------------------------- /system_eval/evaluation/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import pdb 9 | import math 10 | 11 | def precook(s, n=4, out=False): 12 | """ 13 | Takes a string as input and returns an object that can be given to 14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 15 | can take string arguments as well. 
16 | :param s: string : sentence to be converted into ngrams 17 | :param n: int : number of ngrams for which representation is calculated 18 | :return: term frequency vector for occuring ngrams 19 | """ 20 | words = s.split() 21 | counts = defaultdict(int) 22 | for k in range(1,n+1): 23 | for i in range(len(words)-k+1): 24 | ngram = tuple(words[i:i+k]) 25 | counts[ngram] += 1 26 | return counts 27 | 28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 29 | '''Takes a list of reference sentences for a single segment 30 | and returns an object that encapsulates everything that BLEU 31 | needs to know about them. 32 | :param refs: list of string : reference sentences for some image 33 | :param n: int : number of ngrams for which (ngram) representation is calculated 34 | :return: result (list of dict) 35 | ''' 36 | return [precook(ref, n) for ref in refs] 37 | 38 | def cook_test(test, n=4): 39 | '''Takes a test sentence and returns an object that 40 | encapsulates everything that BLEU needs to know about it. 41 | :param test: list of string : hypothesis sentence for some image 42 | :param n: int : number of ngrams for which (ngram) representation is calculated 43 | :return: result (dict) 44 | ''' 45 | return precook(test, n, True) 46 | 47 | class CiderScorer(object): 48 | """CIDEr scorer. 49 | """ 50 | 51 | def copy(self): 52 | ''' copy the refs.''' 53 | new = CiderScorer(n=self.n) 54 | new.ctest = copy.copy(self.ctest) 55 | new.crefs = copy.copy(self.crefs) 56 | return new 57 | 58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 59 | ''' singular instance ''' 60 | self.n = n 61 | self.sigma = sigma 62 | self.crefs = [] 63 | self.ctest = [] 64 | self.document_frequency = defaultdict(float) 65 | self.cook_append(test, refs) 66 | self.ref_len = None 67 | 68 | def cook_append(self, test, refs): 69 | '''called by constructor and __iadd__ to avoid creating new instances.''' 70 | 71 | if refs is not None: 72 | self.crefs.append(cook_refs(refs)) 73 | if test is not None: 74 | self.ctest.append(cook_test(test)) ## N.B.: -1 75 | else: 76 | self.ctest.append(None) # lens of crefs and ctest have to match 77 | 78 | def size(self): 79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 80 | return len(self.crefs) 81 | 82 | def __iadd__(self, other): 83 | '''add an instance (e.g., from another sentence).''' 84 | 85 | if type(other) is tuple: 86 | ## avoid creating new CiderScorer instances 87 | self.cook_append(other[0], other[1]) 88 | else: 89 | self.ctest.extend(other.ctest) 90 | self.crefs.extend(other.crefs) 91 | 92 | return self 93 | def compute_doc_freq(self): 94 | ''' 95 | Compute term frequency for reference data. 96 | This will be used to compute idf (inverse document frequency later) 97 | The term frequency is stored in the object 98 | :return: None 99 | ''' 100 | for refs in self.crefs: 101 | # refs, k ref captions of one image 102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 103 | self.document_frequency[ngram] += 1 104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 105 | 106 | def compute_cider(self): 107 | def counts2vec(cnts): 108 | """ 109 | Function maps counts of ngram to vector of tfidf weights. 110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 111 | The n-th entry of array denotes length of n-grams. 
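            Assuming the standard CIDEr formulation, the tf-idf weight of an
            n-gram g over a corpus of |I| images is tf(g) * log(|I| / df(g));
            since self.ref_len holds log(|I|) and the local df below holds
            log(df(g)), the code computes this as term_freq * (self.ref_len - df).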
112 | :param cnts: 113 | :return: vec (array of dict), norm (array of float), length (int) 114 | """ 115 | vec = [defaultdict(float) for _ in range(self.n)] 116 | length = 0 117 | norm = [0.0 for _ in range(self.n)] 118 | for (ngram,term_freq) in cnts.items(): 119 | # give word count 1 if it doesn't appear in reference corpus 120 | df = np.log(max(1.0, self.document_frequency[ngram])) 121 | # ngram index 122 | n = len(ngram)-1 123 | # tf (term_freq) * idf (precomputed idf) for n-grams 124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 125 | # compute norm for the vector. the norm will be used for computing similarity 126 | norm[n] += pow(vec[n][ngram], 2) 127 | 128 | if n == 1: 129 | length += term_freq 130 | norm = [np.sqrt(n) for n in norm] 131 | return vec, norm, length 132 | 133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 134 | ''' 135 | Compute the cosine similarity of two vectors. 136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 137 | :param vec_ref: array of dictionary for vector corresponding to reference 138 | :param norm_hyp: array of float for vector corresponding to hypothesis 139 | :param norm_ref: array of float for vector corresponding to reference 140 | :param length_hyp: int containing length of hypothesis 141 | :param length_ref: int containing length of reference 142 | :return: array of score for each n-grams cosine similarity 143 | ''' 144 | delta = float(length_hyp - length_ref) 145 | # measure consine similarity 146 | val = np.array([0.0 for _ in range(self.n)]) 147 | for n in range(self.n): 148 | # ngram 149 | for (ngram,count) in vec_hyp[n].items(): 150 | # vrama91 : added clipping 151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 152 | 153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 154 | val[n] /= (norm_hyp[n]*norm_ref[n]) 155 | 156 | assert(not math.isnan(val[n])) 157 | # vrama91: added a length based gaussian penalty 158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 159 | return val 160 | 161 | # compute log reference length 162 | self.ref_len = np.log(float(len(self.crefs))) 163 | 164 | scores = [] 165 | for test, refs in zip(self.ctest, self.crefs): 166 | # compute vector for test captions 167 | vec, norm, length = counts2vec(test) 168 | # compute vector for ref captions 169 | score = np.array([0.0 for _ in range(self.n)]) 170 | for ref in refs: 171 | vec_ref, norm_ref, length_ref = counts2vec(ref) 172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 173 | # change by vrama91 - mean of ngram scores, instead of sum 174 | score_avg = np.mean(score) 175 | # divide by number of references 176 | score_avg /= len(refs) 177 | # multiply score by 10 178 | score_avg *= 10.0 179 | # append score of an image to the score list 180 | scores.append(score_avg) 181 | return scores 182 | 183 | def compute_score(self, option=None, verbose=0): 184 | # compute idf 185 | self.compute_doc_freq() 186 | # assert to check document frequency 187 | assert(len(self.ctest) >= max(self.document_frequency.values())) 188 | # compute cider score 189 | score = self.compute_cider() 190 | # debug 191 | # print score 192 | return np.mean(np.array(score)), np.array(score) 193 | -------------------------------------------------------------------------------- /system_eval/evaluation/eval.py: -------------------------------------------------------------------------------- 1 | from evaluation.bleu.bleu import Bleu 2 | from evaluation.meteor.meteor_nltk import Meteor 3 | from 
evaluation.rouge.rouge import Rouge 4 | from evaluation.cider.cider import Cider 5 | from evaluation.bert_score.bert_score import BertScore 6 | from collections import defaultdict 7 | from argparse import ArgumentParser 8 | 9 | import sys 10 | import json 11 | #reload(sys) 12 | #sys.setdefaultencoding('utf-8') 13 | 14 | class QGEvalCap: 15 | def __init__(self, model_key, gts, res, results_file=None): 16 | self.gts = gts 17 | self.res = res 18 | self.results_file = results_file 19 | self.model_key = model_key 20 | 21 | def evaluate(self): 22 | output = [] 23 | scorers = [ 24 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 25 | (Meteor(),"METEOR"), 26 | (Rouge(), "ROUGE_L"), 27 | (Cider(), "CIDEr"), 28 | (BertScore(), "Bert Score") 29 | ] 30 | 31 | # ================================================= 32 | # Compute scores 33 | # ================================================= 34 | score_dict = {} 35 | scores_dict = {} 36 | #scores_dict["model_key"] = self.model_key 37 | for scorer, method in scorers: 38 | # print 'computing %s score...'%(scorer.method()) 39 | score, scores = scorer.compute_score(self.gts, self.res) 40 | if type(method) == list: 41 | for sc, scs, m in zip(score, scores, method): 42 | #print("%s: %0.5f"%(m, sc)) 43 | output.append(sc) 44 | score_dict[m] = str(sc) 45 | scores_dict[m] = list(scs) 46 | else: 47 | #print("%s: %0.5f"%(method, score)) 48 | output.append(score) 49 | score_dict[method] = score 50 | scores_dict[method] = list(scores) 51 | 52 | if self.results_file != None: 53 | with open(self.results_file, "a") as f: 54 | f.write(json.dumps(score_dict)+"\n") 55 | 56 | return score_dict, scores_dict 57 | 58 | def eval(model_key, sources, references, predictions, results_file=None): 59 | """ 60 | Given a filename, calculate the metric scores for that prediction file 61 | isDin: boolean value to check whether input file is DirectIn.txt 62 | """ 63 | 64 | pairs = [] 65 | 66 | for tup in sources: 67 | pair = {} 68 | pair['tokenized_sentence'] = tup 69 | pairs.append(pair) 70 | 71 | cnt = 0 72 | for line in references: 73 | pairs[cnt]['tokenized_question'] = line 74 | cnt += 1 75 | 76 | output = predictions 77 | 78 | for idx, pair in enumerate(pairs): 79 | pair['prediction'] = output[idx] 80 | 81 | ## eval 82 | from evaluation.eval import QGEvalCap 83 | import json 84 | from json import encoder 85 | encoder.FLOAT_REPR = lambda o: format(o, '.4f') 86 | 87 | res = defaultdict(lambda: []) 88 | gts = defaultdict(lambda: []) 89 | for pair in pairs[:]: 90 | key = pair['tokenized_sentence'] 91 | #res[key] = [pair['prediction']] 92 | res[key] = pair['prediction'] 93 | 94 | ## gts 95 | gts[key].append(pair['tokenized_question']) 96 | 97 | QGEval = QGEvalCap(model_key, gts, res, results_file) 98 | return QGEval.evaluate() 99 | 100 | 101 | def preprocess(file_name, keys): 102 | with open(file_name) as f: 103 | data = f.readlines() 104 | generations = [json.loads(elem) for elem in data] 105 | 106 | predictions = {} 107 | references = {} 108 | sources = {} 109 | keys_list = keys if keys!=None else generations[0]["generations"].keys() 110 | for key in keys_list: 111 | references[key] = [] 112 | predictions[key] = [] 113 | sources[key] = [] 114 | 115 | for elem in generations: 116 | label = elem["label"] 117 | hyp = elem["hyp"+label] 118 | for key in keys_list: 119 | if key in elem["generations"]: 120 | references[key].append(hyp) 121 | predictions[key].append(elem["generations"][key]) 122 | sources[key].append((elem["obs1"], elem["obs2"])) 123 | 124 | return sources, references, 
predictions 125 | 126 | 127 | if __name__ == "__main__": 128 | parser = ArgumentParser() 129 | parser.add_argument("-gen_file", "--gen_file", dest="gen_file", help="generations file with gold/references") 130 | parser.add_argument("--keys", type=str, default=None, help="comma-separated list of model keys") 131 | parser.add_argument("--results_file", default="eval_results.jsonl") 132 | args = parser.parse_args() 133 | 134 | print("scores: \n") 135 | keys=None 136 | if args.keys: 137 | keys = args.keys.split(",") 138 | 139 | sources, references, predictions = preprocess(args.gen_file, keys) 140 | for key in references.keys(): 141 | print("\nEvaluating %s" %key) 142 | eval(key, sources[key], references[key], predictions[key], args.results_file) 143 | 144 | -------------------------------------------------------------------------------- /system_eval/evaluation/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /system_eval/evaluation/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongho94/solar-framework_commonsense-inference/50e99ba0b0f5ae2315c72f5f35f6d385499283e1/system_eval/evaluation/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /system_eval/evaluation/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import threading 10 | 11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 
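# The wrapper below drives meteor-1.5.jar interactively over stdin/stdout:
# _stat() sends each pair as "SCORE ||| ref_1 ||| ... ||| hyp" and reads back
# the sufficient statistics, and compute_score() then sends all collected
# statistics as one "EVAL ||| stats ||| ..." line to read per-segment scores
# followed by the aggregate score. Two caveats in the code as given: under
# Python 3 the pipes carry bytes, so the unencoded stdin.write() calls in
# compute_score() will raise a TypeError unless the strings are encoded the
# way _stat() encodes them; and the whitespace replace in _stat()/_score()
# was presumably collapsing double spaces before a character was lost in
# transcription.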
12 | METEOR_JAR = 'meteor-1.5.jar' 13 | # print METEOR_JAR 14 | 15 | class Meteor: 16 | 17 | def __init__(self): 18 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 19 | '-', '-', '-stdio', '-l', 'en', 20 | '-norm', 21 | # '-t', 'adq' 22 | # '-p', '0.85 0.2 0.6 0.75' # alpha beta gamma delta'', 23 | # '-a', 'data/paraphrase-en.gz', '-m', 'exact stem paraphrase'] 24 | ] 25 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 26 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 27 | stdin=subprocess.PIPE, \ 28 | stdout=subprocess.PIPE, \ 29 | stderr=subprocess.PIPE) 30 | # Used to guarantee thread safety 31 | self.lock = threading.Lock() 32 | 33 | def compute_score(self, gts, res): 34 | assert(gts.keys() == res.keys()) 35 | imgIds = gts.keys() 36 | scores = [] 37 | 38 | eval_line = 'EVAL' 39 | self.lock.acquire() 40 | for i in imgIds: 41 | assert(len(res[i]) == 1) 42 | stat = self._stat(res[i][0], gts[i]) 43 | eval_line += ' ||| {}'.format(stat) 44 | 45 | print('{}\n'.format(eval_line)) 46 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 47 | print(self.meteor_p.stdout.readline().strip()) 48 | 49 | for i in range(0,len(imgIds)): 50 | scores.append(float(self.meteor_p.stdout.readline().strip())) 51 | score = float(self.meteor_p.stdout.readline().strip()) 52 | self.lock.release() 53 | 54 | return score, scores 55 | 56 | def method(self): 57 | return "METEOR" 58 | 59 | def _stat(self, hypothesis_str, reference_list): 60 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 61 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 62 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 63 | # print score_line 64 | str_in = '{}\n'.format(score_line) 65 | #self.meteor_p.communicate(str_in.encode('utf=8')) 66 | self.meteor_p.stdin.write(str_in.encode('utf=8')) 67 | return self.meteor_p.stdout.readline().strip() 68 | 69 | def _score(self, hypothesis_str, reference_list): 70 | self.lock.acquire() 71 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 72 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 73 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 74 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 75 | stats = self.meteor_p.stdout.readline().strip() 76 | eval_line = 'EVAL ||| {}'.format(stats) 77 | # EVAL ||| stats 78 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 79 | score = float(self.meteor_p.stdout.readline().strip()) 80 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice 81 | # thanks for Andrej for pointing this out 82 | score = float(self.meteor_p.stdout.readline().strip()) 83 | self.lock.release() 84 | return score 85 | 86 | def __del__(self): 87 | self.lock.acquire() 88 | self.meteor_p.stdin.close() 89 | self.meteor_p.kill() 90 | self.meteor_p.wait() 91 | self.lock.release() 92 | -------------------------------------------------------------------------------- /system_eval/evaluation/meteor/meteor_nltk.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | import os 7 | import sys 8 | import nltk 9 | from nltk.translate.meteor_score import meteor_score 10 | 11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 
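# This NLTK-based variant avoids the Java subprocess entirely and reports the
# corpus score as the mean of sentence-level METEOR scores. Depending on the
# installed NLTK version, nltk.translate.meteor_score.meteor_score may require
# pre-tokenized input (lists of tokens rather than raw strings) and needs the
# WordNet data to be available (nltk.download('wordnet')).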
12 | #METEOR_JAR = 'meteor-1.5.jar' 13 | # print METEOR_JAR 14 | 15 | class Meteor: 16 | 17 | def __init__(self): 18 | pass 19 | 20 | def compute_score(self, gts, res): 21 | assert(gts.keys() == res.keys()) 22 | imgIds = gts.keys() 23 | scores = [] 24 | 25 | for i in imgIds: 26 | assert(len(res[i]) == 1) 27 | score = round(meteor_score(gts[i], res[i][0]), 4) 28 | scores.append(score) 29 | #print('{}\n'.format(eval_line)) 30 | #self.meteor_p.stdin.write('{}\n'.format(eval_line)) 31 | #print(self.meteor_p.stdout.readline().strip()) 32 | 33 | #for i in range(0,len(imgIds)): 34 | # scores.append(float(self.meteor_p.stdout.readline().strip())) 35 | #score = float(self.meteor_p.stdout.readline().strip()) 36 | #self.lock.release() 37 | 38 | return sum(scores)/len(scores), scores 39 | 40 | def method(self): 41 | return "METEOR" 42 | 43 | -------------------------------------------------------------------------------- /system_eval/evaluation/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'vrama91' 2 | -------------------------------------------------------------------------------- /system_eval/evaluation/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if(len(string)< len(sub)): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1,len(string)+1): 29 | if(string[i-1] == sub[j-1]): 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | 40 | ''' 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | assert(len(candidate)==1) 53 | assert(len(refs)>0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs/float(len(token_c))) 66 | rec.append(lcs/float(len(token_r))) 67 | 68 | prec_max = max(prec) 
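        # take the best precision and recall over all references, then combine
        # them with an F-measure; beta = 1.2 weights recall slightly higher
        # than precision:
        #   F = (1 + beta^2) * P * R / (R + beta^2 * P)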
69 | rec_max = max(rec) 70 | 71 | if(prec_max!=0 and rec_max !=0): 72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert(gts.keys() == res.keys()) 86 | imgIds = gts.keys() 87 | 88 | score = [] 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | score.append(self.calc_score(hypo, ref)) 94 | 95 | # Sanity check. 96 | assert(type(hypo) is list) 97 | assert(len(hypo) == 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | print("len score:", len(score)) 103 | return average_score, np.array(score) 104 | 105 | def method(self): 106 | return "Rouge" 107 | -------------------------------------------------------------------------------- /system_eval/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import csv 4 | import operator 5 | import random 6 | 7 | 8 | def read_csv(input_file, quotechar='"', delimiter=",", skip_header=False): 9 | """Reads a tab separated value file.""" 10 | with open(input_file, "r") as f: 11 | reader = csv.reader(f, delimiter=delimiter, quotechar=quotechar, quoting=csv.QUOTE_ALL, skipinitialspace=True) 12 | lines = [] 13 | for line in reader: 14 | if sys.version_info[0] == 2: 15 | line = list(unicode(cell, 'utf-8') for cell in line) 16 | lines.append(line) 17 | if skip_header: 18 | lines = lines[1:] 19 | return lines 20 | 21 | 22 | def write_tsv(output_file, data, header=False): 23 | keys = list(data[0].keys()) 24 | with open(output_file, 'w') as f: 25 | w = csv.DictWriter(f, keys, delimiter='\t', lineterminator='\n') 26 | if header: 27 | w.writeheader() 28 | for r in data: 29 | entry = {k: r[k] for k in keys} 30 | w.writerow(entry) 31 | 32 | 33 | def write_array2tsv(output_file, data, header=False): 34 | keys = range(len(data[0])) 35 | with open(output_file, 'w') as f: 36 | w = csv.DictWriter(f, keys, delimiter='\t', lineterminator='\n') 37 | if header: 38 | w.writeheader() 39 | for r in data: 40 | entry = {k: r[k] for k in keys} 41 | w.writerow(entry) 42 | 43 | 44 | def write_csv(filename, data, fieldnames): 45 | with open(filename, 'w', newline='') as csvfile: 46 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 47 | 48 | writer.writeheader() 49 | for d in data: 50 | formatted_d = {} 51 | for key, val in d.items(): 52 | formatted_d[key] = json.dumps(val) 53 | writer.writerow(formatted_d) 54 | 55 | 56 | def read_jsonl(filename): 57 | data = [] 58 | with open(filename, "r") as f: 59 | for line in f: 60 | data.append(json.loads(line)) 61 | return data 62 | 63 | 64 | def write_items(output_file, items): 65 | with open(output_file, 'w') as f: 66 | for concept in items: 67 | f.write(concept + "\n") 68 | f.close() 69 | 70 | 71 | def write_jsonl(f, d): 72 | write_items(f, [json.dumps(r) for r in d]) 73 | 74 | 75 | def count_relation(d): 76 | relation_count = {} 77 | prefix_count = {} 78 | 
head_count = {} 79 | for l in d: 80 | r = l[1] 81 | if r not in relation_count.keys(): 82 | relation_count[r] = 0 83 | relation_count[r] += 1 84 | 85 | prefix = l[0]+l[1] 86 | if prefix not in prefix_count.keys(): 87 | prefix_count[prefix] = 0 88 | prefix_count[prefix] += 1 89 | 90 | head = l[0] 91 | if head not in head_count.keys(): 92 | head_count[head] = 0 93 | head_count[head] += 1 94 | 95 | sorted_relation_count = dict(sorted(relation_count.items(), key=operator.itemgetter(1), reverse=True)) 96 | sorted_prefix_count = dict(sorted(prefix_count.items(), key=operator.itemgetter(1), reverse=True)) 97 | sorted_head_count = dict(sorted(head_count.items(), key=operator.itemgetter(1), reverse=True)) 98 | 99 | print("Relations:") 100 | for r in sorted_relation_count.keys(): 101 | print(r, sorted_relation_count[r]) 102 | 103 | print("\nPrefixes:") 104 | print("uniq prefixes: ", len(sorted_prefix_count.keys())) 105 | i = 0 106 | for r in sorted_prefix_count.keys(): 107 | print(r, sorted_prefix_count[r]) 108 | i += 1 109 | if i > 20: 110 | break 111 | 112 | print("\nHeads:") 113 | i = 0 114 | for r in sorted_head_count.keys(): 115 | print(r, sorted_head_count[r]) 116 | i += 1 117 | if i > 20: 118 | break 119 | 120 | 121 | def get_head_set(d): 122 | return set([l[0] for l in d]) 123 | 124 | 125 | def head_based_split(data, dev_size, test_size, head_size_threshold=500, dev_heads=[], test_heads=[]): 126 | """ 127 | :param data: the tuples to split according to the heads, where the head is the first element of each tuple 128 | :param dev_size: target size of the dev set 129 | :param test_size: target size of the test set 130 | :param head_size_threshold: Maximum number of tuples a head can be involved in, 131 | in order to be considered for the dev/test set' 132 | :param dev_heads: heads that are forced to belong to the dev set 133 | :param test_heads: heads that are forced to belong to the test set 134 | :return: 135 | """ 136 | head_count = {} 137 | for l in data: 138 | head = l[0] 139 | if head not in head_count.keys(): 140 | head_count[head] = 0 141 | head_count[head] += 1 142 | 143 | remaining_heads = dict(head_count) 144 | 145 | test_selected_heads = {} 146 | test_head_total_count = 0 147 | 148 | for h in test_heads: 149 | if h in remaining_heads: 150 | c = remaining_heads[h] 151 | test_selected_heads[h] = c 152 | test_head_total_count += c 153 | remaining_heads.pop(h) 154 | 155 | while test_head_total_count < test_size: 156 | h = random.sample(remaining_heads.keys(), 1)[0] 157 | c = remaining_heads[h] 158 | if c < head_size_threshold: 159 | test_selected_heads[h] = c 160 | test_head_total_count += c 161 | remaining_heads.pop(h) 162 | 163 | test = [l for l in data if l[0] in test_selected_heads.keys()] 164 | 165 | dev_selected_heads = {} 166 | dev_head_total_count = 0 167 | 168 | for h in dev_heads: 169 | if h in remaining_heads: 170 | c = remaining_heads[h] 171 | dev_selected_heads[h] = c 172 | dev_head_total_count += c 173 | remaining_heads.pop(h) 174 | 175 | while dev_head_total_count < dev_size: 176 | h = random.sample(remaining_heads.keys(), 1)[0] 177 | c = remaining_heads[h] 178 | if c < head_size_threshold: 179 | dev_selected_heads[h] = c 180 | dev_head_total_count += c 181 | remaining_heads.pop(h) 182 | 183 | dev = [l for l in data if l[0] in dev_selected_heads.keys()] 184 | 185 | dev_test_heads = set(list(dev_selected_heads.keys()) + list(test_selected_heads.keys())) 186 | train = [l for l in data if l[0] not in dev_test_heads] 187 | 188 | return train, dev, test 189 | 190 | 191 | def 
remove_prefix(text, prefix):
    # `text.startswith(prefix)` returns a bool: `True and len(prefix)` is
    # len(prefix) while `False and len(prefix)` is 0, so the slice drops the
    # prefix only when it is actually present.
192 |     return text[text.startswith(prefix) and len(prefix):]
193 | 
--------------------------------------------------------------------------------
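For reference, a minimal sketch of how the scorers above fit together outside
of eval.py's generation-file format: build `gts` (references) and `res`
(hypotheses) dicts keyed by example id, where every hypothesis value is a
single-element list, and hand both to each scorer's compute_score(). The data
below is illustrative, the import paths assume the system_eval/evaluation
package layout shown above, and Meteor/BertScore are omitted here because they
additionally need the Java jar and model downloads, respectively.

from evaluation.bleu.bleu import Bleu
from evaluation.rouge.rouge import Rouge
from evaluation.cider.cider import Cider

# Illustrative inputs: each hypothesis is a single-element list (the scorers
# assert this) and each reference list may hold several strings.
res = {
    "ex1": ["a man rides a horse"],
    "ex2": ["a dog runs in the park"],
}
gts = {
    "ex1": ["a person is riding a horse", "a man on a horse"],
    "ex2": ["a dog is running through a park"],
}

for scorer, name in [(Bleu(4), "BLEU"), (Rouge(), "ROUGE_L"), (Cider(), "CIDEr")]:
    score, per_example = scorer.compute_score(gts, res)
    # Bleu(4) returns a list of four corpus-level scores (BLEU-1..BLEU-4);
    # Rouge and Cider return a single float plus per-example scores.
    print(name, score)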