├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── atomic-2020
│   │   ├── bart-base_baseline
│   │   │   └── v01.yml
│   │   ├── bart-base_experiments
│   │   │   └── v_all_joint_v01.yml
│   │   ├── bart-large_baseline
│   │   │   └── v01.yml
│   │   ├── bart-large_experiments
│   │   │   ├── v01.yml
│   │   │   └── v_all_joint_v01.yml
│   │   └── datasets.yml
│   ├── atomic
│   │   ├── bart-base_baseline
│   │   │   └── v01.yml
│   │   ├── bart-base_experiments
│   │   │   └── v_all_joint_v01.yml
│   │   ├── bart-large_baseline
│   │   │   └── v01.yml
│   │   ├── bart-large_experiments
│   │   │   ├── v01.yml
│   │   │   └── v_all_joint_v01.yml
│   │   └── datasets.yml
│   ├── conceptnet
│   │   ├── bart-base_baseline
│   │   │   └── v01.yml
│   │   ├── bart-base_experiments
│   │   │   └── v_all_joint_v01.yml
│   │   ├── bart-large_baseline
│   │   │   └── v01.yml
│   │   ├── bart-large_experiments
│   │   │   ├── v01.yml
│   │   │   └── v_all_joint_v01.yml
│   │   └── datasets.yml
│   └── tokenizer_config.yml
├── models
│   ├── bart.py
│   ├── distance_func.py
│   ├── head_proj_layer.py
│   ├── loss_func.py
│   └── model_utils.py
├── requirements.txt
├── scripts
│   ├── feature_learn.py
│   ├── finetune.py
│   └── inference.py
├── src
│   ├── data_utils.py
│   ├── feed_model.py
│   ├── finetune
│   │   ├── finetune_trainer.py
│   │   └── finetune_utils.py
│   ├── lr_schedule.py
│   ├── rec_adam.py
│   ├── rec_adam_wrapper.py
│   ├── sampler.py
│   ├── sampler_utils.py
│   ├── tokenizer.py
│   ├── train_utils.py
│   ├── trainer.py
│   └── utils.py
└── system_eval
    ├── automatic_eval.py
    ├── evaluation
    │   ├── LICENSE
    │   ├── README.md
    │   ├── __init__.py
    │   ├── bert_score
    │   │   ├── __init__.py
    │   │   ├── bert_score.py
    │   │   ├── score.py
    │   │   └── utils.py
    │   ├── bleu
    │   │   ├── .gitignore
    │   │   ├── LICENSE
    │   │   ├── __init__.py
    │   │   ├── bleu.py
    │   │   └── bleu_scorer.py
    │   ├── cider
    │   │   ├── __init__.py
    │   │   ├── cider.py
    │   │   └── cider_scorer.py
    │   ├── eval.py
    │   ├── meteor
    │   │   ├── __init__.py
    │   │   ├── meteor-1.5.jar
    │   │   ├── meteor.py
    │   │   └── meteor_nltk.py
    │   └── rouge
    │       ├── __init__.py
    │       └── rouge.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | shell/*
2 | *__pycache__
3 |
4 | human_eval
5 | *.ipynb
6 | *gpt*
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # solar-framework_commonsense-inference
2 | Code release for "Learning from Missing Relations: Contrastive Learning with Commonsense Knowledge Graphs for Commonsense Inference"
3 |
4 |
5 | Download the similarity matrix:
6 | ```bash
7 | gdown https://drive.google.com/uc?id=1EHMIZXP_T1UfSzCWv9Is16n8DRW3dGJx
8 | ```
9 |
10 | ### Preprocessing
11 |
12 | #### Download knowledge graphs
13 | Commonsense Knowledge Graph Sources:
14 |
15 | * [ConceptNet](https://home.ttic.edu/~kgimpel/commonsense.html)
16 | * [ATOMIC](https://allenai.org/data/atomic)
17 | * [ATOMIC-2020](https://allenai.org/data/atomic-2020)
18 |
19 | #### Simple preprocess
20 | * Download the files above and preprocess each knowledge graph into tab-separated TSV format, one triple per line (a minimal conversion sketch follows the example):
21 | ```
22 | # examples
23 | subject1 \t relation1 \t object1
24 | subject2 \t relation2 \t object2
25 | ...
26 | ```
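
The exact raw formats differ by source, so the conversion is left to a short script. Below is a minimal sketch, assuming a generic comma-separated input with subject, relation, and object columns; the `raw.csv` and `train.tsv` paths are placeholders, not files shipped with this repo:

```python
import csv

# Hypothetical paths; point these at the downloaded dump and the desired TSV output.
with open('raw.csv', newline='', encoding='utf-8') as fin, \
        open('train.tsv', 'w', newline='', encoding='utf-8') as fout:
    reader = csv.reader(fin)
    writer = csv.writer(fout, delimiter='\t')
    for subj, rel, obj in reader:
        # One subject \t relation \t object triple per line, as in the example above.
        writer.writerow([subj.strip(), rel.strip(), obj.strip()])
```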
27 |
28 | * Modify the preprocessed data paths in config/{dataset}/datasets.yml (see the loading sketch below):
29 | ```
30 | name: 'atomic'
31 | truncate:
32 | subj_len: 25
33 | obj_len: 25
34 | dir:
35 | train: {your path}
36 | dev: {your path}
37 | test: {your path}
38 |   sim: # <- the similarity matrix; you can download it from the URL above.
39 | train: {your path}
40 | dev: {your path}
41 |
42 | ```
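
The training scripts consume these files through the repo's own YAML helpers (`load_yaml` in `src/utils.py`); a minimal stand-in with plain PyYAML, assuming only the keys shown above, looks like this:

```python
import yaml

# Load the dataset config and pull out the fields the data pipeline relies on.
with open('config/atomic/datasets.yml', encoding='utf-8') as f:
    cfg = yaml.safe_load(f)

train_path = cfg['dir']['train']        # TSV of subject/relation/object triples
sim_path = cfg['sim']['train']          # pickled subject-similarity matrix
subj_len = cfg['truncate']['subj_len']  # truncation length for subjects
```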
43 |
44 | ### Fine-tuning
45 |
46 | ```
47 | python scripts/finetune.py --dataset_type {dataset} --model_name {model_name} --model_size {model_size}
48 | ```
49 |
50 | ### Pre-training
51 |
52 | ```
53 | python scripts/feature_learn.py --dataset_type {dataset} --model_name {model_name} --model_size {model_size}
54 | ```
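
For example, assuming `{dataset}` is one of the config directory names (`atomic`, `atomic-2020`, `conceptnet`) and the sizes follow the config folders, pre-training then fine-tuning BART-base on ATOMIC would be `python scripts/feature_learn.py --dataset_type atomic --model_name bart --model_size base` followed by the matching `scripts/finetune.py` call; check each script's `argparse` definitions for the exact accepted values.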
55 |
56 |
--------------------------------------------------------------------------------
/config/atomic-2020/bart-base_baseline/v01.yml:
--------------------------------------------------------------------------------
1 | # Fine-Tuning
2 |
3 | model:
4 | name: 'bart-base'
5 | pretrained_model: 'facebook/bart-base'
6 | tokenize_model: 'facebook/bart-base'
7 |
8 | opt:
9 | lr_scheduler: "linear"
10 | warmup_steps: 200
11 | clip_grad_norm: 1.0
12 | weight_decay: 0
13 | output_dropout_p: 0.1
14 | optimizer: 'adam'
15 | adam_beta_1: 0.9
16 | adam_beta_2: 0.999
17 | adam_eps: 1E-08
18 |
19 | log:
20 | tb_period: 10
21 | val_period: 1000
22 | save_period: 5000
23 |
--------------------------------------------------------------------------------
/config/atomic-2020/bart-base_experiments/v_all_joint_v01.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 | task:
3 | max_enc: 55
4 | max_dec: 55
5 | init_with_shuffle: true # Only For Training dataset
6 | gen_task:
7 | weight: 0.0
8 | loss_func: "cross-entropy"
9 | con_task:
10 | weight: 0.8
11 | method: "cluster" # cluster, naive
12 | cluster_contrast:
13 | group_num: 32
14 | pos_subj_min: 0.85
15 | sampling_method: "adv" # random, adv
16 | adv_sampling:
17 | min_sim: 0.2
18 | max_sim: 0.7
19 | loss_func: "NT-Logistic"
20 | rec_inf_shu_task:
21 | weight: 0.2
22 | method: "naive"
23 | loss_func: "cross-entropy"
24 | no_crpt_prob: 0.25
25 | subj_crpt_prob: 0.25
26 | rel_crpt_prob: 0.25
27 | obj_crpt_prob: 0.25
28 | shuffle_prob: 0.5
29 | den_task:
30 | weight: 0.0
31 | method: "naive"
32 | subj_mask_prob: 0.33
33 | rel_mask_prob: 0.33
34 | obj_mask_prob: 0.34
35 | hint_prob: 0.3
36 | hint_from_the_front: true
37 | loss_func: "cross-entropy"
38 |
39 | model:
40 | name: 'bart-base'
41 | pretrained_model: 'facebook/bart-base'
42 | tokenize_model: 'facebook/bart-base'
43 | task_adaptor_options:
44 | common:
45 | use_task_prefix: true
46 | con_task:
47 | format: 'enc-dec'
48 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id |
49 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id
50 | rec_inf_shu_task:
51 | format: 'enc-dec'
52 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id
53 |
54 | contrastive_head:
55 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear
56 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint,
57 | multi-head:
58 | head_num: 8
59 | head_dim: 512
60 | pool_method: "maxpool"
61 | multi-head-deeper:
62 | head_num: 12
63 | inner_hidden_mul: 3
64 | head_dim: 512
65 | pool_method: "maxpool"
66 |
67 | opt:
68 | lr_scheduler: "linear"
69 | warmup_steps: 200
70 | clip_grad_norm: 1.0
71 | weight_decay: 0
72 | output_dropout_p: 0.1
73 | optimizer: 'adam'
74 | adam_beta_1: 0.9
75 | adam_beta_2: 0.999
76 | adam_eps: 1E-08
77 | temperature: 0.1 #0.1
78 | use_l2: true
79 | log:
80 | tb_period: 10
81 | val_period: 1000
82 | save_period: 10000
83 |
--------------------------------------------------------------------------------
/config/atomic-2020/bart-large_baseline/v01.yml:
--------------------------------------------------------------------------------
1 | # Fine-Tuning
2 |
3 | model:
4 | name: 'bart-large'
5 | pretrained_model: 'facebook/bart-large'
6 | tokenize_model: 'facebook/bart-large'
7 |
8 | opt:
9 | lr_scheduler: "linear"
10 | warmup_steps: 200
11 | clip_grad_norm: 1.0
12 | weight_decay: 0
13 | output_dropout_p: 0.1
14 | optimizer: 'adam'
15 | adam_beta_1: 0.9
16 | adam_beta_2: 0.999
17 | adam_eps: 1E-08
18 |
19 | log:
20 | tb_period: 10
21 | val_period: 1000
22 | save_period: 5000
23 |
--------------------------------------------------------------------------------
/config/atomic-2020/bart-large_experiments/v01.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 |
3 | task:
4 | max_enc: 55
5 | max_dec: 55
6 | init_with_shuffle: true # Only For Training dataset
7 | gen_task:
8 | weight: 0.0
9 | loss_func: "cross-entropy"
10 | con_task:
11 | weight: 0.8
12 | method: "cluster" # cluster, naive
13 | cluster_contrast:
14 | group_num: 16
15 | pos_subj_min: 0.75
16 | sampling_method: "adv" # random, adv
17 | adv_sampling:
18 | min_sim: 0.4
19 | max_sim: 0.6
20 | loss_func: "NT-Logistic"
21 | rec_inf_shu_task:
22 | weight: 0.2
23 | method: "naive"
24 | loss_func: "cross-entropy"
25 | no_crpt_prob: 0.25
26 | subj_crpt_prob: 0.25
27 | rel_crpt_prob: 0.25
28 | obj_crpt_prob: 0.25
29 | shuffle_prob: 0.5
30 | den_task:
31 | weight: 0.0
32 | method: "naive"
33 | subj_mask_prob: 0.33
34 | rel_mask_prob: 0.33
35 | obj_mask_prob: 0.34
36 | hint_prob: 0.3
37 | hint_from_the_front: true
38 | loss_func: "cross-entropy"
39 |
40 | model:
41 |   name: 'bart-large'
42 |   pretrained_model: 'facebook/bart-large'
43 |   tokenize_model: 'facebook/bart-large'
44 | task_adaptor_options:
45 | common:
46 | use_task_prefix: true
47 | con_task:
48 | format: 'enc-dec'
49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id |
50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id
51 | rec_inf_shu_task:
52 | format: 'enc-dec'
53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id
54 |
55 | contrastive_head:
56 |   proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear
57 | multi-head:
58 | head_num: 4
59 | head_dim: 512
60 | pool_method: "maxpool"
61 | multi-head-deeper:
62 | head_num: 12
63 | inner_hidden_mul: 3
64 | head_dim: 512
65 | pool_method: "maxpool"
66 |
67 | opt:
68 | lr_scheduler: "linear"
69 | warmup_steps: 200
70 | clip_grad_norm: 1.0
71 | weight_decay: 0
72 | output_dropout_p: 0.1
73 | optimizer: 'adam'
74 | adam_beta_1: 0.9
75 | adam_beta_2: 0.999
76 | adam_eps: 1E-08
77 | temperature: 0.1 #0.1
78 | use_l2: true
79 | log:
80 | tb_period: 10
81 | val_period: 1000
82 | save_period: 10000
83 |
--------------------------------------------------------------------------------
/config/atomic-2020/bart-large_experiments/v_all_joint_v01.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 | task:
3 | max_enc: 25
4 | max_dec: 25
5 | init_with_shuffle: true # Only For Training dataset
6 | gen_task:
7 | weight: 0.0
8 | loss_func: "cross-entropy"
9 | con_task:
10 | weight: 0.8
11 | method: "cluster" # cluster, naive
12 | cluster_contrast:
13 | group_num: 4
14 | pos_subj_min: 0.85
15 | sampling_method: "adv" # random, adv
16 | adv_sampling:
17 | min_sim: 0.2
18 | max_sim: 0.7
19 | loss_func: "NT-Logistic"
20 | rec_inf_shu_task:
21 | weight: 0.2
22 | method: "naive"
23 | loss_func: "cross-entropy"
24 | no_crpt_prob: 0.25
25 | subj_crpt_prob: 0.25
26 | rel_crpt_prob: 0.25
27 | obj_crpt_prob: 0.25
28 | shuffle_prob: 0.5
29 | den_task:
30 | weight: 0.0
31 | method: "naive"
32 | subj_mask_prob: 0.33
33 | rel_mask_prob: 0.33
34 | obj_mask_prob: 0.34
35 | hint_prob: 0.3
36 | hint_from_the_front: true
37 | loss_func: "cross-entropy"
38 |
39 | model:
40 | name: 'bart-large'
41 | pretrained_model: 'facebook/bart-large'
42 | tokenize_model: 'facebook/bart-large'
43 | task_adaptor_options:
44 | common:
45 | use_task_prefix: true
46 | con_task:
47 | format: 'enc-dec'
48 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id |
49 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id
50 | rec_inf_shu_task:
51 | format: 'enc-dec'
52 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id
53 |
54 | contrastive_head:
55 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear
56 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint,
57 | multi-head:
58 | head_num: 8
59 | head_dim: 512
60 | pool_method: "maxpool"
61 | multi-head-deeper:
62 | head_num: 12
63 | inner_hidden_mul: 3
64 | head_dim: 512
65 | pool_method: "maxpool"
66 |
67 | opt:
68 | lr_scheduler: "linear"
69 | warmup_steps: 200
70 | clip_grad_norm: 1.0
71 | weight_decay: 0
72 | output_dropout_p: 0.1
73 | optimizer: 'adam'
74 | adam_beta_1: 0.9
75 | adam_beta_2: 0.999
76 | adam_eps: 1E-08
77 | temperature: 0.1 #0.1
78 | use_l2: true
79 | log:
80 | tb_period: 10
81 | val_period: 1000
82 | save_period: 10000
83 |
--------------------------------------------------------------------------------
/config/atomic-2020/datasets.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 |
3 | name: 'atomic-2020'
4 | truncate:
5 | subj_len: 25
6 | obj_len: 25
7 | dir:
8 | train: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic-2020/train.tsv'
9 | dev: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic-2020/dev.tsv'
10 |   test: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic-2020/test.tsv'
11 | sim:
12 | train: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/atomic-2020/train_subj_dist__a_concept_related_to.pkl'
13 | dev: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/atomic-2020/dev_subj_dist__a_concept_related_to.pkl'
14 |
--------------------------------------------------------------------------------
/config/atomic/bart-base_baseline/v01.yml:
--------------------------------------------------------------------------------
1 | # Fine-Tuning
2 |
3 | model:
4 | name: 'bart-base'
5 | pretrained_model: 'facebook/bart-base'
6 | tokenize_model: 'facebook/bart-base'
7 |
8 | opt:
9 | lr_scheduler: "linear"
10 | warmup_steps: 200
11 | clip_grad_norm: 1.0
12 | weight_decay: 0
13 | output_dropout_p: 0.1
14 | optimizer: 'adam'
15 | adam_beta_1: 0.9
16 | adam_beta_2: 0.999
17 | adam_eps: 1E-08
18 |
19 | log:
20 | tb_period: 10
21 | val_period: 1000
22 | save_period: 5000
23 |
--------------------------------------------------------------------------------
/config/atomic/bart-base_experiments/v_all_joint_v01.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 |
3 | task:
4 | max_enc: 25
5 | max_dec: 25
6 | init_with_shuffle: true # Only For Training dataset
7 | gen_task:
8 | weight: 0.0
9 | loss_func: "cross-entropy"
10 | con_task:
11 | weight: 0.8
12 | method: "cluster" # cluster, naive
13 | cluster_contrast:
14 | group_num: 32
15 | pos_subj_min: 0.8
16 | sampling_method: "adv" # random, adv
17 | adv_sampling:
18 | min_sim: 0.2
19 | max_sim: 0.7
20 | loss_func: "NT-Logistic"
21 | rec_inf_shu_task:
22 | weight: 0.2
23 | method: "naive"
24 | loss_func: "cross-entropy"
25 | no_crpt_prob: 0.25
26 | subj_crpt_prob: 0.25
27 | rel_crpt_prob: 0.25
28 | obj_crpt_prob: 0.25
29 | shuffle_prob: 0.5
30 | den_task:
31 | weight: 0.0
32 | method: "naive"
33 | subj_mask_prob: 0.33
34 | rel_mask_prob: 0.33
35 | obj_mask_prob: 0.34
36 | hint_prob: 0.3
37 | hint_from_the_front: true
38 | loss_func: "cross-entropy"
39 |
40 | model:
41 | name: 'bart-base'
42 | pretrained_model: 'facebook/bart-base'
43 | tokenize_model: 'facebook/bart-base'
44 | task_adaptor_options:
45 | common:
46 | use_task_prefix: true
47 | con_task:
48 | format: 'enc-dec'
49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id |
50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id
51 | rec_inf_shu_task:
52 | format: 'enc-dec'
53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id
54 |
55 |
56 | contrastive_head:
57 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear
58 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint,
59 | multi-head:
60 | head_num: 8
61 | head_dim: 512
62 | pool_method: "maxpool"
63 | multi-head-deeper:
64 | head_num: 12
65 | inner_hidden_mul: 3
66 | head_dim: 512
67 | pool_method: "maxpool"
68 |
69 | opt:
70 | lr_scheduler: "linear"
71 | warmup_steps: 200
72 | clip_grad_norm: 1.0
73 | weight_decay: 0
74 | output_dropout_p: 0.1
75 | optimizer: 'adam'
76 | adam_beta_1: 0.9
77 | adam_beta_2: 0.999
78 | adam_eps: 1E-08
79 | temperature: 0.1 #0.1
80 | use_l2: true
81 | log:
82 | tb_period: 10
83 | val_period: 1000
84 | save_period: 10000
85 |
--------------------------------------------------------------------------------
/config/atomic/bart-large_baseline/v01.yml:
--------------------------------------------------------------------------------
1 | # Fine-Tuning
2 |
3 | model:
4 | name: 'bart-large'
5 | pretrained_model: 'facebook/bart-large'
6 | tokenize_model: 'facebook/bart-large'
7 |
8 | opt:
9 | lr_scheduler: "linear"
10 | warmup_steps: 200
11 | clip_grad_norm: 1.0
12 | weight_decay: 0
13 | output_dropout_p: 0.1
14 | optimizer: 'adam'
15 | adam_beta_1: 0.9
16 | adam_beta_2: 0.999
17 | adam_eps: 1E-08
18 |
19 | log:
20 | tb_period: 10
21 | val_period: 1000
22 | save_period: 5000
23 |
--------------------------------------------------------------------------------
/config/atomic/bart-large_experiments/v01.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 |
3 | task:
4 | max_enc: 55
5 | max_dec: 55
6 | init_with_shuffle: true # Only For Training dataset
7 | gen_task:
8 | weight: 0.0
9 | loss_func: "cross-entropy"
10 | con_task:
11 | weight: 0.8
12 | method: "cluster" # cluster, naive
13 | cluster_contrast:
14 | group_num: 32
15 | pos_subj_min: 0.85
16 | sampling_method: "adv" # random, adv
17 | adv_sampling:
18 | min_sim: 0.2
19 | max_sim: 0.7
20 | loss_func: "NT-Logistic"
21 | rec_inf_shu_task:
22 | weight: 0.2
23 | method: "naive"
24 | loss_func: "cross-entropy"
25 | no_crpt_prob: 0.25
26 | subj_crpt_prob: 0.25
27 | rel_crpt_prob: 0.25
28 | obj_crpt_prob: 0.25
29 | shuffle_prob: 0.5
30 | den_task:
31 | weight: 0.0
32 | method: "naive"
33 | subj_mask_prob: 0.33
34 | rel_mask_prob: 0.33
35 | obj_mask_prob: 0.34
36 | hint_prob: 0.3
37 | hint_from_the_front: true
38 | loss_func: "cross-entropy"
39 |
40 | model:
41 | name: 'bart-large'
42 | pretrained_model: 'facebook/bart-large'
43 | tokenize_model: 'facebook/bart-large'
44 | task_adaptor_options:
45 | common:
46 | use_task_prefix: true
47 | con_task:
48 | format: 'enc-dec'
49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id |
50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id
51 | den_task:
52 | format: 'naive'
53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id
54 | rec_inf_shu_task:
55 | format: 'naive'
56 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id
57 |
58 | contrastive_head:
59 |   proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear
60 | multi-head:
61 | head_num: 8
62 | head_dim: 512
63 | pool_method: "maxpool"
64 | multi-head-deeper:
65 | head_num: 12
66 | inner_hidden_mul: 3
67 | head_dim: 512
68 | pool_method: "maxpool"
69 |
70 | opt:
71 | lr_scheduler: "linear"
72 | warmup_steps: 200
73 | clip_grad_norm: 1.0
74 | weight_decay: 0
75 | output_dropout_p: 0.1
76 | optimizer: 'adam'
77 | adam_beta_1: 0.9
78 | adam_beta_2: 0.999
79 | adam_eps: 1E-08
80 | temperature: 0.1 #0.1
81 | use_l2: true
82 | log:
83 | tb_period: 10
84 | val_period: 1000
85 | save_period: 10000
86 |
--------------------------------------------------------------------------------
/config/atomic/bart-large_experiments/v_all_joint_v01.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 |
3 | task:
4 | max_enc: 25
5 | max_dec: 25
6 | init_with_shuffle: true # Only For Training dataset
7 | gen_task:
8 | weight: 0.0
9 | loss_func: "cross-entropy"
10 | con_task:
11 | weight: 0.8
12 | method: "cluster" # cluster, naive
13 | cluster_contrast:
14 | group_num: 4
15 | pos_subj_min: 0.85
16 | sampling_method: "adv" # random, adv
17 | adv_sampling:
18 | min_sim: 0.2
19 | max_sim: 0.7
20 | loss_func: "NT-Logistic"
21 | rec_inf_shu_task:
22 | weight: 0.2
23 | method: "naive"
24 | loss_func: "cross-entropy"
25 | no_crpt_prob: 0.25
26 | subj_crpt_prob: 0.25
27 | rel_crpt_prob: 0.25
28 | obj_crpt_prob: 0.25
29 | shuffle_prob: 0.5
30 | den_task:
31 | weight: 0.0
32 | method: "naive"
33 | subj_mask_prob: 0.33
34 | rel_mask_prob: 0.33
35 | obj_mask_prob: 0.34
36 | hint_prob: 0.3
37 | hint_from_the_front: true
38 | loss_func: "cross-entropy"
39 |
40 | model:
41 | name: 'bart-large'
42 | pretrained_model: 'facebook/bart-large'
43 | tokenize_model: 'facebook/bart-large'
44 | task_adaptor_options:
45 | common:
46 | use_task_prefix: true
47 | con_task:
48 | format: 'enc-dec'
49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id |
50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id
51 | rec_inf_shu_task:
52 | format: 'enc-dec'
53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id
54 |
55 |
56 | contrastive_head:
57 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear
58 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint,
59 | multi-head:
60 | head_num: 8
61 | head_dim: 512
62 | pool_method: "maxpool"
63 | multi-head-deeper:
64 | head_num: 12
65 | inner_hidden_mul: 3
66 | head_dim: 512
67 | pool_method: "maxpool"
68 |
69 | opt:
70 | lr_scheduler: "linear"
71 | warmup_steps: 200
72 | clip_grad_norm: 1.0
73 | weight_decay: 0
74 | output_dropout_p: 0.1
75 | optimizer: 'adam'
76 | adam_beta_1: 0.9
77 | adam_beta_2: 0.999
78 | adam_eps: 1E-08
79 | temperature: 0.1 #0.1
80 | use_l2: true
81 | log:
82 | tb_period: 10
83 | val_period: 1000
84 | save_period: 10000
85 |
--------------------------------------------------------------------------------
/config/atomic/datasets.yml:
--------------------------------------------------------------------------------
1 | name: 'atomic'
2 | truncate:
3 | subj_len: 25
4 | obj_len: 25
5 | dir:
6 | train: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic/train.tsv'
7 | dev: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic/dev.tsv'
8 | test: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/atomic/test.tsv'
9 | sim:
10 | train: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/atomic/train_subj_dist__a_concept_related_to.pkl'
11 | dev: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/atomic/dev_subj_dist__a_concept_related_to.pkl'
12 |
--------------------------------------------------------------------------------
/config/conceptnet/bart-base_baseline/v01.yml:
--------------------------------------------------------------------------------
1 | # Fine-Tuning
2 |
3 | model:
4 | name: 'bart-base'
5 | pretrained_model: 'facebook/bart-base'
6 | tokenize_model: 'facebook/bart-base'
7 |
8 | opt:
9 | lr_scheduler: "linear"
10 | warmup_steps: 200
11 | clip_grad_norm: 1.0
12 | weight_decay: 0
13 | output_dropout_p: 0.1
14 | optimizer: 'adam'
15 | adam_beta_1: 0.9
16 | adam_beta_2: 0.999
17 | adam_eps: 1E-08
18 |
19 | log:
20 | tb_period: 10
21 | val_period: 1000
22 | save_period: 5000
23 |
--------------------------------------------------------------------------------
/config/conceptnet/bart-base_experiments/v_all_joint_v01.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 |
3 | task:
4 | max_enc: 25
5 | max_dec: 25
6 | init_with_shuffle: true # Only For Training dataset
7 | gen_task:
8 | weight: 0.0
9 | loss_func: "cross-entropy"
10 | con_task:
11 | weight: 0.8
12 | method: "cluster" # cluster, naive
13 | cluster_contrast:
14 | group_num: 32
15 | pos_subj_min: 0.8
16 | sampling_method: "adv" # random, adv
17 | adv_sampling:
18 | min_sim: 0.2
19 | max_sim: 0.7
20 | loss_func: "NT-Logistic"
21 | rec_inf_shu_task:
22 | weight: 0.2
23 | method: "naive"
24 | loss_func: "cross-entropy"
25 | no_crpt_prob: 0.25
26 | subj_crpt_prob: 0.25
27 | rel_crpt_prob: 0.25
28 | obj_crpt_prob: 0.25
29 | shuffle_prob: 0.5
30 | den_task:
31 | weight: 0.0
32 | method: "naive"
33 | subj_mask_prob: 0.33
34 | rel_mask_prob: 0.33
35 | obj_mask_prob: 0.34
36 | hint_prob: 0.3
37 | hint_from_the_front: true
38 | loss_func: "cross-entropy"
39 |
40 | model:
41 | name: 'bart-base'
42 | pretrained_model: 'facebook/bart-base'
43 | tokenize_model: 'facebook/bart-base'
44 | task_adaptor_options:
45 | common:
46 | use_task_prefix: true
47 | con_task:
48 | format: 'enc-dec'
49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id |
50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id
51 | rec_inf_shu_task:
52 | format: 'enc-dec'
53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id
54 |
55 |
56 | contrastive_head:
57 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear
58 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint,
59 | multi-head:
60 | head_num: 8
61 | head_dim: 512
62 | pool_method: "maxpool"
63 | multi-head-deeper:
64 | head_num: 12
65 | inner_hidden_mul: 3
66 | head_dim: 512
67 | pool_method: "maxpool"
68 |
69 | opt:
70 | lr_scheduler: "linear"
71 | warmup_steps: 200
72 | clip_grad_norm: 1.0
73 | weight_decay: 0
74 | output_dropout_p: 0.1
75 | optimizer: 'adam'
76 | adam_beta_1: 0.9
77 | adam_beta_2: 0.999
78 | adam_eps: 1E-08
79 | temperature: 0.1 #0.1
80 | use_l2: true
81 | log:
82 | tb_period: 10
83 | val_period: 1000
84 | save_period: 10000
85 |
--------------------------------------------------------------------------------
/config/conceptnet/bart-large_baseline/v01.yml:
--------------------------------------------------------------------------------
1 | # Fine-Tuning
2 |
3 | model:
4 | name: 'bart-large'
5 | pretrained_model: 'facebook/bart-large'
6 | tokenize_model: 'facebook/bart-large'
7 |
8 | opt:
9 | lr_scheduler: "linear"
10 | warmup_steps: 200
11 | clip_grad_norm: 1.0
12 | weight_decay: 0
13 | output_dropout_p: 0.1
14 | optimizer: 'adam'
15 | adam_beta_1: 0.9
16 | adam_beta_2: 0.999
17 | adam_eps: 1E-08
18 |
19 | log:
20 | tb_period: 10
21 | val_period: 1000
22 | save_period: 5000
--------------------------------------------------------------------------------
/config/conceptnet/bart-large_experiments/v01.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 |
3 | task:
4 | max_enc: 55
5 | max_dec: 55
6 | init_with_shuffle: true # Only For Training dataset
7 | gen_task:
8 | weight: 0.0
9 | loss_func: "cross-entropy"
10 | con_task:
11 | weight: 0.8
12 | method: "cluster" # cluster, naive
13 | cluster_contrast:
14 | group_num: 32
15 | pos_subj_min: 0.85
16 | sampling_method: "adv" # random, adv
17 | adv_sampling:
18 | min_sim: 0.2
19 | max_sim: 0.7
20 | loss_func: "NT-Logistic"
21 | rec_inf_shu_task:
22 | weight: 0.2
23 | method: "naive"
24 | loss_func: "cross-entropy"
25 | no_crpt_prob: 0.25
26 | subj_crpt_prob: 0.25
27 | rel_crpt_prob: 0.25
28 | obj_crpt_prob: 0.25
29 | shuffle_prob: 0.5
30 | den_task:
31 | weight: 0.0
32 | method: "naive"
33 | subj_mask_prob: 0.33
34 | rel_mask_prob: 0.33
35 | obj_mask_prob: 0.34
36 | hint_prob: 0.3
37 | hint_from_the_front: true
38 | loss_func: "cross-entropy"
39 |
40 | model:
41 | name: 'bart-large'
42 | pretrained_model: 'facebook/bart-large'
43 | tokenize_model: 'facebook/bart-large'
44 | task_adaptor_options:
45 | common:
46 | use_task_prefix: true
47 | con_task:
48 | format: 'enc-dec'
49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id |
50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id
51 | den_task:
52 | format: 'naive'
53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id
54 | rec_inf_shu_task:
55 | format: 'naive'
56 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id
57 |
58 | contrastive_head:
59 |   proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear
60 | multi-head:
61 | head_num: 8
62 | head_dim: 512
63 | pool_method: "maxpool"
64 | multi-head-deeper:
65 | head_num: 12
66 | inner_hidden_mul: 3
67 | head_dim: 512
68 | pool_method: "maxpool"
69 |
70 | opt:
71 | lr_scheduler: "linear"
72 | warmup_steps: 200
73 | clip_grad_norm: 1.0
74 | weight_decay: 0
75 | output_dropout_p: 0.1
76 | optimizer: 'adam'
77 | adam_beta_1: 0.9
78 | adam_beta_2: 0.999
79 | adam_eps: 1E-08
80 | temperature: 0.1 #0.1
81 | use_l2: true
82 | log:
83 | tb_period: 10
84 | val_period: 1000
85 | save_period: 10000
86 |
--------------------------------------------------------------------------------
/config/conceptnet/bart-large_experiments/v_all_joint_v01.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 |
3 | task:
4 | max_enc: 25
5 | max_dec: 25
6 | init_with_shuffle: true # Only For Training dataset
7 | gen_task:
8 | weight: 0.0
9 | loss_func: "cross-entropy"
10 | con_task:
11 | weight: 0.8
12 | method: "cluster" # cluster, naive
13 | cluster_contrast:
14 | group_num: 4
15 | pos_subj_min: 0.85
16 | sampling_method: "adv" # random, adv
17 | adv_sampling:
18 | min_sim: 0.2
19 | max_sim: 0.7
20 | loss_func: "NT-Logistic"
21 | rec_inf_shu_task:
22 | weight: 0.2
23 | method: "naive"
24 | loss_func: "cross-entropy"
25 | no_crpt_prob: 0.25
26 | subj_crpt_prob: 0.25
27 | rel_crpt_prob: 0.25
28 | obj_crpt_prob: 0.25
29 | shuffle_prob: 0.5
30 | den_task:
31 | weight: 0.0
32 | method: "naive"
33 | subj_mask_prob: 0.33
34 | rel_mask_prob: 0.33
35 | obj_mask_prob: 0.34
36 | hint_prob: 0.3
37 | hint_from_the_front: true
38 | loss_func: "cross-entropy"
39 |
40 | model:
41 | name: 'bart-large'
42 | pretrained_model: 'facebook/bart-large'
43 | tokenize_model: 'facebook/bart-large'
44 | task_adaptor_options:
45 | common:
46 | use_task_prefix: true
47 | con_task:
48 | format: 'enc-dec'
49 | dec_input_ids_s: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id |
50 | dec_input_ids_ro: 'con_task_token_id' # bos_token_id | con_task_token_id | con_task_token_s_id | con_task_token_ro_id
51 | rec_inf_shu_task:
52 | format: 'enc-dec'
53 | dec_input_ids: 'bos_token_id' # bos_token_id | den_task_token_id | den_task_token_for_dec_id
54 |
55 |
56 | contrastive_head:
57 | proj_layer_type: 'multi-head' # multi-head-deeper | multi-head | non-linear
58 | pool_method: 'all_joint' # dec_bos, mean_pool_enc, all_joint,
59 | multi-head:
60 | head_num: 8
61 | head_dim: 512
62 | pool_method: "maxpool"
63 | multi-head-deeper:
64 | head_num: 12
65 | inner_hidden_mul: 3
66 | head_dim: 512
67 | pool_method: "maxpool"
68 |
69 | opt:
70 | lr_scheduler: "linear"
71 | warmup_steps: 200
72 | clip_grad_norm: 1.0
73 | weight_decay: 0
74 | output_dropout_p: 0.1
75 | optimizer: 'adam'
76 | adam_beta_1: 0.9
77 | adam_beta_2: 0.999
78 | adam_eps: 1E-08
79 | temperature: 0.1 #0.1
80 | use_l2: true
81 | log:
82 | tb_period: 10
83 | val_period: 1000
84 | save_period: 10000
85 |
--------------------------------------------------------------------------------
/config/conceptnet/datasets.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 |
3 | name: 'conceptnet'
4 | truncate:
5 | subj_len: 25
6 | obj_len: 25
7 | dir:
8 | train: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/conceptnet/train.tsv'
9 | dev: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/conceptnet/dev.tsv'
10 | test: '/mnt/data/user8/solar-commonsense_inference/data/tuple_data/conceptnet/test.tsv'
11 | sim:
12 | train: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/conceptnet/train_subj_dist__a_concept_related_to.pkl'
13 | dev: '/mnt/data/user8/solar-commonsense_inference/data/sim_mat/conceptnet/dev_subj_dist__a_concept_related_to.pkl'
14 |
--------------------------------------------------------------------------------
/config/tokenizer_config.yml:
--------------------------------------------------------------------------------
1 | # Test Version
2 |
3 | special_tokens:
4 |   bos_token: '<s>' # standard facebook/bart special tokens
5 |   eos_token: '</s>'
6 |   pad_token: '<pad>'
7 |   mask_token: '<mask>'
8 |   sep_token: '</s>'
9 |
10 | additional_tokens:
11 | gen_token: ''
12 | blk_token: ''
13 | none_token: ''
14 | mask_gen_task_token: ''
15 | mask_gen_task_token_for_dec: ''
16 | gen_task_token: ''
17 | gen_task_token_for_dec: ''
18 | con_task_token: ''
19 | con_task_token_for_s: ''
20 | con_task_token_for_ro: ''
21 | den_task_token: ''
22 | den_task_token_for_dec: ''
23 | rec_task_token: ''
24 | rec_task_token_for_dec: ''
25 | cla_task_token: ''
26 |
27 | relation_tokens:
28 | - ''
29 | - ''
30 | - ''
31 | - ''
32 | - ''
33 | - ''
34 | - ''
35 | - ''
36 | - ''
37 | - ''
38 | - ''
39 | - ''
40 | - ''
41 | - ''
42 | - ''
43 | - ''
44 | - ''
45 | - ''
46 | - ''
47 | - ''
48 | - ''
49 | - ''
50 | - ''
51 | - ''
52 | - ''
53 | - ''
54 | - ''
55 | - ''
56 | - ''
57 | - ''
58 | - ''
59 | - ''
60 | - ''
61 | - ''
62 | - ''
63 | - ''
64 | - ''
65 | - ''
66 | - ''
67 | - ''
68 | - ''
69 | - ''
70 | - ''
71 | - ''
72 | - ''
73 | - ''
74 | - ''
75 | - ''
76 | - ''
77 | - ''
78 | - ''
79 |
--------------------------------------------------------------------------------
/models/bart.py:
--------------------------------------------------------------------------------
1 | from transformers import BartModel
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from models.head_proj_layer import *
6 | from transformers import BartForConditionalGeneration
7 | from models.model_utils import convert_model
8 | from torch.nn import Linear
9 |
10 |
11 | class CometBART(BartModel):
12 | def modify_lm_heads(self, vocab_len, lm_head_dropout_p):
13 | self.resize_token_embeddings(vocab_len)
14 | self.lm_head = nn.Linear(self.config.hidden_size, vocab_len)
15 | self.lm_head_dropout = nn.Dropout(p=lm_head_dropout_p)
16 |
17 | def add_proj_layer(self, proj_options, proj_head_dropout_p, shared_proj_layer=True):
18 | self.proj_head_dropout = nn.Dropout(p=proj_head_dropout_p)
19 | self.proj_options = proj_options
20 | self.seq_method = proj_options['pool_method']
21 | if shared_proj_layer is True:
22 | self.proj_head = self._get_proj_head()
23 | else:
24 | self.proj_head_s = self._get_proj_head()
25 | self.proj_head_ro = self._get_proj_head()
26 |
27 | def _get_proj_head(self):
28 | proj_layer_type = self.proj_options['proj_layer_type']
29 | hidden_size = self.config.hidden_size
30 |
31 | if proj_layer_type == 'non-linear':
32 |             head_layer = NonLinearHeadProjLayer(hidden_size)
33 | elif proj_layer_type == 'multi-head':
34 | sub_opt = self.proj_options[proj_layer_type]
35 | head_num = sub_opt['head_num']
36 | head_dim = int(hidden_size / head_num) if sub_opt['head_dim'] == -1 else sub_opt['head_dim']
37 | head_layer = MultiHeadProjLayer(hidden_size, head_num, head_dim)
38 | else:
39 | raise NotImplementedError
40 |
41 | return head_layer
42 |
43 | def forward_conditional_gen(self, enc_input_ids, enc_att_mask, dec_input_ids, dec_att_mask):
44 | outputs = super().forward(input_ids=enc_input_ids, attention_mask=enc_att_mask,
45 | decoder_input_ids=dec_input_ids, decoder_attention_mask=dec_att_mask)
46 |
47 | last_hidden_state = outputs.last_hidden_state
48 | last_hidden_state = self.lm_head_dropout(last_hidden_state)
49 | lm_logits = self.lm_head(last_hidden_state)
50 |
51 | return lm_logits
52 |
53 | def forward_latent_feature(self, enc_input_ids, enc_att_mask, dec_input_ids, dec_att_mask, for_s=None):
54 | outputs = super().forward(input_ids=enc_input_ids, attention_mask=enc_att_mask,
55 | decoder_input_ids=dec_input_ids, decoder_attention_mask=dec_att_mask)
56 |
57 | seq_feature = self.get_sequence_feature(outputs, enc_att_mask)
58 | seq_feature = self.proj_head_dropout(seq_feature)
59 | latent_vec = self._forward_projection(seq_feature, for_s)
60 |
61 | return latent_vec
62 |
63 | def get_sequence_feature(self, outputs, enc_att_mask):
64 | if self.seq_method == 'dec_bos':
65 | dec_last_hidden_state = outputs.last_hidden_state
66 | seq_feature = self._get_seq_feature_from_dec_bos(dec_last_hidden_state)
67 | elif self.seq_method == 'mean_pool_enc':
68 | enc_last_hidden_state = outputs.encoder_last_hidden_state
69 | seq_feature = self._get_seq_feature_from_mean_pool_enc(enc_last_hidden_state, enc_att_mask)
70 | elif self.seq_method == 'all_joint':
71 | enc_last_hidden_state = outputs.encoder_last_hidden_state
72 | dec_last_hidden_state = outputs.last_hidden_state
73 | seq_feature = self._get_seq_feature_from_mean_pool_enc_dec(
74 | enc_last_hidden_state, enc_att_mask, dec_last_hidden_state)
75 | else:
76 | raise NotImplementedError
77 |
78 | return seq_feature
79 |
80 | def _get_seq_feature_from_dec_bos(self, dec_last_hidden_state):
81 | seq_feature = dec_last_hidden_state[:, 0, :]
82 | return seq_feature
83 |
84 | def _get_seq_feature_from_mean_pool_enc(self, enc_last_hidden_state, att_mask):
85 | seq_feature = enc_last_hidden_state * att_mask.unsqueeze(-1) # (B, S, H)
86 | seq_feature = torch.sum(seq_feature, dim=1) # (B, H)
87 | seq_feature = seq_feature / torch.sum(att_mask, -1, keepdim=True)
88 | return seq_feature
89 |
90 | def _get_seq_feature_from_mean_pool_enc_dec(self, enc_last_hidden_state, enc_att_mask, dec_last_hidden_state):
91 |         seq_feature = enc_last_hidden_state * enc_att_mask.unsqueeze(-1)  # zero out padded positions: (B, S, H)
92 |         seq_feature_with_dec_bos = torch.cat((seq_feature, dec_last_hidden_state[:, :1, :]), dim=1)  # append the decoder BOS state: (B, S+1, H)
93 |         seq_feature = torch.sum(seq_feature_with_dec_bos, dim=1)  # (B, H)
94 |         seq_feature = seq_feature / (torch.sum(enc_att_mask, -1, keepdim=True) + 1)  # mean over valid encoder tokens plus the BOS state
95 | return seq_feature
96 |
97 | def _forward_projection(self, sequence_feature, for_s):
98 | if for_s is None:
99 | latent_vec = self.proj_head(sequence_feature)
100 | elif for_s is True:
101 | latent_vec = self.proj_head_s(sequence_feature)
102 | elif for_s is False:
103 | latent_vec = self.proj_head_ro(sequence_feature)
104 | else:
105 | raise NotImplementedError
106 |
107 | return latent_vec
108 |
109 |
110 | def convert_BARTModel_to_BartForConditionalGeneration(bart_model, params):
111 | device = bart_model.device
112 |         # params: a pretrained checkpoint name or path accepted by from_pretrained
113 | gen_model = BartForConditionalGeneration.from_pretrained(params)
114 | vocab_len, hidden = bart_model.lm_head.weight.shape
115 | use_bias = False if bart_model.lm_head.bias is None else True
116 | gen_model.resize_token_embeddings(vocab_len)
117 | gen_model.lm_head = Linear(hidden, vocab_len, bias=use_bias)
118 |
119 | gen_model = convert_model(bart_model, gen_model).to(device)
120 |
121 | return gen_model
122 |
--------------------------------------------------------------------------------
/models/distance_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class NonLinearHeadDistance(torch.nn.Module):
5 | def __init__(self):
6 | super(NonLinearHeadDistance, self).__init__()
7 |
8 | def forward(self, batch_s, batch_ro):
9 | l2_s = batch_s / torch.linalg.norm(batch_s, dim=-1, ord=2, keepdim=True)
10 | l2_ro = batch_ro / torch.linalg.norm(batch_ro, dim=-1, ord=2, keepdim=True)
11 |
12 | l2_ro = l2_ro.transpose(0, 1)
13 |
14 | final_dist = torch.matmul(l2_s, l2_ro)
15 | return final_dist
16 |
17 |
18 | class MultiHeadDistance(torch.nn.Module):
19 | def __init__(self, head_option):
20 | super(MultiHeadDistance, self).__init__()
21 | self.head_num = head_option['head_num']
22 | self.head_dim = head_option['head_dim']
23 |         self.max_pool = None
24 | if head_option['pool_method'] == 'maxpool':
25 | self.max_pool = torch.nn.MaxPool1d(self.head_num)
26 | else:
27 | raise NotImplementedError
28 |
29 | def forward(self, batch_s, batch_ro):
30 |         batch_s = batch_s.view(list(batch_s.shape[:-1]) + [self.head_num, self.head_dim])  # (B, head_num, head_dim)
31 |         batch_ro = batch_ro.view(list(batch_ro.shape[:-1]) + [self.head_num, self.head_dim])
32 | 
33 |         l2_s = batch_s / torch.linalg.norm(batch_s, dim=-1, ord=2, keepdim=True)  # unit-normalize each head
34 |         l2_ro = batch_ro / torch.linalg.norm(batch_ro, dim=-1, ord=2, keepdim=True)
35 | 
36 |         l2_s = l2_s.permute([1, 0, 2])  # (head_num, B, head_dim)
37 |         l2_ro = l2_ro.permute([1, 0, 2])
38 |         l2_ro = l2_ro.transpose(1, 2)  # (head_num, head_dim, B)
39 | 
40 |         dist = torch.matmul(l2_s, l2_ro)  # per-head cosine similarity: (head_num, B, B)
41 |         dist = dist.permute([1, 2, 0])  # (B, B, head_num)
42 | 
43 |         final_dist = self.max_pool(dist).squeeze(-1)  # max over heads: (B, B)
44 | return final_dist
45 |
46 |
--------------------------------------------------------------------------------
/models/head_proj_layer.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class NonLinearHeadProjLayer(nn.Module):
5 | def __init__(self, input_hidden_size):
6 | super(NonLinearHeadProjLayer, self).__init__()
7 | self.linear_1 = nn.Linear(input_hidden_size, input_hidden_size)
8 | self.relu = nn.ReLU()
9 | self.linear_2 = nn.Linear(input_hidden_size, input_hidden_size)
10 |
11 | def forward(self, x):
12 | x = self.linear_1(x)
13 | x = self.relu(x)
14 | x = self.linear_2(x)
15 | return x
16 |
17 |
18 | class MultiHeadProjLayer(nn.Module):
19 | def __init__(self, input_hidden_size, head_num, head_size):
20 | super(MultiHeadProjLayer, self).__init__()
21 | self.linear_1 = nn.Linear(input_hidden_size, input_hidden_size)
22 | self.relu = nn.ReLU()
23 | self.head_size = head_size if head_size != -1 else int(input_hidden_size / head_num)
24 |         self.multi_head_proj = nn.Linear(input_hidden_size, head_num * self.head_size)
25 |
26 | def forward(self, x):
27 | x = self.linear_1(x)
28 | x = self.relu(x)
29 | x = self.multi_head_proj(x)
30 | return x
31 |
32 |
33 | class MultiHeadProjLayerDeeper(nn.Module):
34 | def __init__(self, input_hidden_size, head_num, head_size, inner_hidden_mul):
35 | super(MultiHeadProjLayerDeeper, self).__init__()
36 | self.linear_1 = nn.Linear(input_hidden_size, input_hidden_size * inner_hidden_mul)
37 | self.relu = nn.ReLU()
38 | self.head_size = head_size if head_size != -1 else int(input_hidden_size / head_num)
39 | self.linear_2 = nn.Linear(input_hidden_size * inner_hidden_mul, input_hidden_size)
40 |         self.multi_head_proj = nn.Linear(input_hidden_size, head_num * self.head_size)
41 |
42 | def forward(self, x):
43 | x = self.linear_1(x)
44 | x = self.relu(x)
45 | x = self.linear_2(x)
46 | x = self.relu(x)
47 | x = self.multi_head_proj(x)
48 | return x
--------------------------------------------------------------------------------
/models/loss_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 | EPS = 0.0001
3 |
4 |
5 | class NT_Logistic:
6 | def __init__(self, temp=0.1):
7 | self.temp = temp
8 | self.sigmoid = torch.nn.Sigmoid()
9 |
10 | def get_loss(self, dist, pos_sample):
11 |         neg_sample = torch.ones_like(pos_sample) - pos_sample  # complement of the positive-pair mask
12 |         pos_dist = torch.mul(dist, pos_sample)  # distances kept only at positive pairs
13 |         neg_dist = torch.mul(dist, neg_sample)  # distances kept only at negative pairs
14 |         avg = (pos_sample / torch.sum(pos_sample, -1)) + (neg_sample / torch.sum(neg_sample, -1))  # per-row weights averaging positives and negatives separately
15 |         logged_dist = torch.log(self.sigmoid((pos_dist - neg_dist) / self.temp) + EPS)
16 | 
17 |         pos_dists = torch.sum(pos_dist) / torch.sum(pos_sample)  # mean positive distance (for logging)
18 |         neg_dists = torch.sum(neg_dist) / torch.sum(neg_sample)  # mean negative distance (for logging)
19 | avg_dist = torch.sum(torch.mul(logged_dist, avg), -1)
20 | avg_dist = -torch.mean(avg_dist)
21 | return avg_dist, pos_dists, neg_dists
22 |
23 |
--------------------------------------------------------------------------------
/models/model_utils.py:
--------------------------------------------------------------------------------
1 | # Utilities for copying weights between matching BART model variants.
2 | 
3 |
4 | def convert_model(src_model, dst_model):
5 | params_src = src_model.named_parameters()
6 | params_dst = dst_model.named_parameters()
7 | dict_src = dict(params_src)
8 | dict_dst = dict(params_dst)
9 | match_fail = 0
10 | for name in dict_src:
11 | dst_name = f'model.{name}'
12 | if dst_name in dict_dst:
13 | dict_dst[dst_name].data.copy_(dict_src[name].data)
14 | elif name in dict_dst:
15 | dict_dst[name].data.copy_(dict_src[name].data)
16 | else:
17 | match_fail += 1
18 |             print(f'Unmatched layer: no parameter in dst_model for src_model layer {name}')
19 |
20 |     if match_fail == 0:
21 |         print('All layers are matched')
22 |
23 | return dst_model
24 |
25 |
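26 | # Usage sketch (paths and model names are placeholders): copy every parameter
27 | # of a feature-learned checkpoint into a fresh generation model.
28 | #   src = torch.load('ckpt/model-40000-steps.ckpt')
29 | #   dst = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
30 | #   dst = convert_model(src, dst)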
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | beautifulsoup4==4.11.1
3 | bert-score==0.3.11
4 | cachetools==4.2.4
5 | certifi==2022.5.18.1
6 | charset-normalizer==2.0.12
7 | click==8.1.3
8 | cycler==0.11.0
9 | filelock==3.7.1
10 | fonttools==4.33.3
11 | gdown==4.4.0
12 | google-auth==1.35.0
13 | google-auth-oauthlib==0.4.6
14 | grpcio==1.46.3
15 | huggingface-hub==0.7.0
16 | idna==3.3
17 | importlib-metadata==4.11.4
18 | joblib==1.1.0
19 | kiwisolver==1.4.2
20 | Markdown==3.3.7
21 | matplotlib==3.5.2
22 | nltk==3.5
23 | numpy==1.22.4
24 | oauthlib==3.2.0
25 | packaging==21.3
26 | pandas==1.4.2
27 | Pillow==9.1.1
28 | protobuf==3.14.0
29 | pyasn1==0.4.8
30 | pyasn1-modules==0.2.8
31 | pyparsing==3.0.9
32 | PySocks==1.7.1
33 | python-dateutil==2.8.2
34 | pytz==2022.1
35 | PyYAML==6.0
36 | regex==2022.4.24
37 | requests==2.27.1
38 | requests-oauthlib==1.3.1
39 | rsa==4.8
40 | six==1.16.0
41 | soupsieve==2.3.2.post1
42 | tabulate==0.8.9
43 | tensorboard==2.5.0
44 | tensorboard-data-server==0.6.1
45 | tensorboard-plugin-wit==1.8.1
46 | tokenizers==0.12.1
47 | torch==1.10.1+cu111
48 | torchaudio==0.10.1+rocm4.1
49 | torchvision==0.11.2+cu111
50 | tqdm==4.64.0
51 | transformers==4.19.2
52 | typing-extensions==4.2.0
53 | urllib3==1.26.9
54 | Werkzeug==2.1.2
55 | zipp==3.8.0
56 |
--------------------------------------------------------------------------------
/scripts/feature_learn.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | sys.path.append(os.getcwd())
5 |
6 | import json
7 | import yaml
8 | from tqdm import tqdm
9 | import argparse
10 | import numpy as np
11 | import torch
12 |
13 | from torch.utils.data import DataLoader
14 | from src.utils import load_logger, load_yaml
15 | from src.sampler import get_data_sampler, load_datasets
16 | from src.lr_schedule import WarmupLinearScheduler
17 | from torch.utils.tensorboard import SummaryWriter
18 | from torch.cuda.amp import autocast
19 | from torch.cuda.amp import GradScaler
20 | from torch.nn.utils import clip_grad_norm_
21 | from distutils.util import strtobool as _bool
22 | from src.train_utils import get_data_feeder_from_sampler
23 | from models.distance_func import *
24 | from models.loss_func import *
25 | from src.trainer import *
26 | from copy import deepcopy
27 |
28 |
29 | parser = argparse.ArgumentParser()
30 |
31 | parser.add_argument('--mode', type=str, default=None)
32 | parser.add_argument('--file_dir', type=str, default='/mnt/data/user8/solar-commonsense_inference')
33 | parser.add_argument('--dataset_type', type=str, default='atomic-2020')#, required=True)
34 | parser.add_argument('--model_name', type=str, default='gpt2', help="bart | gpt2")#, required=True)
35 | parser.add_argument('--model_size', type=str, default='base', help="base | large")#, required=True)
36 | parser.add_argument('--main_yml', type=str, default='v01.yml')#, required=True)
37 | parser.add_argument('--tknz_yml', type=str, default='tokenizer_config.yml')
38 | parser.add_argument('--log_level', type=str, default='INFO')
39 | parser.add_argument('--log', type=str, default='NEWvTEST')#, required=True)
40 | parser.add_argument('--random_seed', type=int, default=42)
41 | parser.add_argument('--load_model', type=str, default=None)
42 |
43 | # Optimize
44 | parser.add_argument('--batch_size', type=int, default=128)
45 | parser.add_argument('--update_batch', type=int, default=128)
46 | parser.add_argument('--iter_per_epoch', type=int, default=40000)
47 | parser.add_argument('--dev_iter_per_epoch', type=int, default=1)
48 | parser.add_argument('--epoch_num', type=int, default=1)
49 |
50 | parser.add_argument("--learning_rate", default=0.00001, type=float, help="The initial learning rate for Adam.")
51 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
52 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
53 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
54 | parser.add_argument('--lm_head_dropout_p', default=0.1, type=float)
55 | parser.add_argument('--proj_head_dropout_p', default=0.1, type=float)
56 |
57 | parser.add_argument("--recadam_anneal_fun", type=str, default='sigmoid', choices=["sigmoid", "linear", 'constant'],
58 | help="the type of annealing function in RecAdam. Default sigmoid")
59 | parser.add_argument("--recadam_anneal_k", type=float, default=0.0001, help="k for the annealing function in RecAdam.")
60 | parser.add_argument("--recadam_anneal_t0", type=int, default=5000, help="t0 for the annealing function in RecAdam.")
61 | parser.add_argument("--recadam_anneal_w", type=float, default=1.0,
62 | help="Weight for the annealing function in RecAdam. Default 1.0.")
63 | parser.add_argument("--recadam_pretrain_cof", type=float, default=5000.0,
64 | help="Coefficient of the quadratic penalty in RecAdam. Default 5000.0.")
65 | parser.add_argument('--warmup_steps', type=int, default=100)
66 |
67 | # ETC
68 | parser.add_argument('--share_proj_layer', type=_bool, default=True, help=
69 |                     'if true, share the projection layer between s_representation and ro_representation; '
70 |                     'if false, each has its own projection layer')
71 |
72 | args = parser.parse_args()
73 |
74 | if args.batch_size > args.update_batch:
75 | args.batch_size = args.update_batch
76 |
77 | # Path Setting
78 | np.random.seed(args.random_seed)
79 |
80 | # - log, [tb, ckpt, gen, eval]
81 | main_config_path = f'config/{args.dataset_type}/{args.model_name}-{args.model_size}_experiments/{args.main_yml}'
82 | tknz_config_path = f'config/{args.tknz_yml}'
83 | dataset_config_path = f'config/{args.dataset_type}/datasets.yml'
84 |
85 | log_dir = f'{args.file_dir}/log/{args.dataset_type}/{args.model_name}-{args.model_size}_experiments/{args.log}'
86 | logging_dir = log_dir + '/logging.log'
87 | tb_dir = log_dir + '/tb'
88 | gen_dir = log_dir + '/gen'
89 | eval_dir = log_dir + '/eval'
90 | ckpt_dir = log_dir + '/ckpt'
91 |
92 | if not os.path.exists(log_dir):
93 | os.makedirs(log_dir)
94 | os.mkdir(tb_dir)
95 | os.mkdir(gen_dir)
96 | os.mkdir(eval_dir)
97 | os.mkdir(ckpt_dir)
98 |
99 | # Initialize Log
100 | logger = load_logger(logging_dir, args.log_level)
101 | logger.info('Logger successfully initialized!')
102 | tb_writer = SummaryWriter(tb_dir)
103 |
104 | # Loading YML file
105 | main_config = load_yaml(main_config_path)
106 | tknz_config = load_yaml(tknz_config_path)
107 | dataset_cfg = load_yaml(dataset_config_path)
108 |
109 | with open(os.path.join(log_dir, 'main_config.json'), 'w') as f:
110 | json.dump(main_config, f)
111 | with open(os.path.join(log_dir, 'tknz_config.json'), 'w') as f:
112 | json.dump(tknz_config, f)
113 | with open(os.path.join(log_dir, 'argparse.json'), 'w') as f:
114 | json.dump(args.__dict__, f)
115 |
116 | task_cfg = main_config['task']
117 | model_cfg = main_config['model']
118 | opt_cfg = main_config['opt']
119 | log_cfg = main_config['log']
120 |
121 | # Tokenizer Setting with args.tknz_yml
122 | if args.model_name == 'bart':
123 | from transformers import BartTokenizer as Tokenizer
124 | from models.bart import CometBART as CometModel
125 | from src.tokenizer import BartCSKGTokenizer as CSKGTokenizer
126 |
127 | elif args.model_name == 'gpt2':
128 | from transformers import GPT2Tokenizer as Tokenizer
129 | from models.gpt2 import CometGPT2 as CometModel
130 | from src.tokenizer import Gpt2CSKGTokenizer as CSKGTokenizer
131 |
132 | elif 't5' in model_cfg['name']:
133 | raise NotImplementedError
134 | else:
135 | raise NotImplementedError
136 |
137 | anneal_targets = load_yaml(f'models/pretrained_params/pretrained_{args.model_name}-{args.model_size}.yml')['pretrained_weights']
138 |
139 | _tokenizer = Tokenizer.from_pretrained(model_cfg['pretrained_model'])
140 | tokenizer = CSKGTokenizer(_tokenizer, tknz_config)
141 | vocab_len = len(tokenizer) + 1
142 |
143 | model = CometModel.from_pretrained(model_cfg['pretrained_model'])
144 | model.modify_lm_heads(vocab_len, args.lm_head_dropout_p)
145 | model.add_proj_layer(model_cfg['contrastive_head'], args.proj_head_dropout_p)
146 |
147 | pretrained_model = deepcopy(model)
148 |
149 | # Load Dataset
150 | dataset, sim_mat = load_datasets(dataset_cfg, tokenizer, logger)
151 |
152 | # Initialize Sampler
153 | sampler = get_data_sampler(task_cfg, dataset, sim_mat, tokenizer, logger, args)
154 | train_sampler = sampler['train']
155 | dev_sampler = sampler['dev']
156 |
157 | # Connect sampler to data_feeding adaptor
158 | options = main_config['model']['task_adaptor_options']
159 |
160 | iter_num = 0
161 | assert args.update_batch >= args.batch_size
162 | stack_num = int(args.update_batch / args.batch_size)
163 | assert int(stack_num * args.batch_size) == int(args.update_batch)
164 |
165 | train_data_feeder = get_data_feeder_from_sampler(train_sampler, options, tokenizer, task_cfg, args.batch_size, args.iter_per_epoch * stack_num)
166 | dev_data_feeder = get_data_feeder_from_sampler(dev_sampler, options, tokenizer, task_cfg, args.batch_size, args.dev_iter_per_epoch)
167 |
168 | if model_cfg['contrastive_head']['proj_layer_type'] in ['multi-head', 'multi-head-deeper']:
169 | distance_model = MultiHeadDistance(model_cfg['contrastive_head']['multi-head'])
170 | elif model_cfg['contrastive_head']['proj_layer_type'] == 'non-linear':
171 | distance_model = NonLinearHeadDistance()
172 | else:
173 | raise NotImplementedError
174 |
175 | usable_cuda = torch.cuda.is_available()
176 | device = torch.device("cuda:0" if usable_cuda else "cpu")
177 |
178 | model.to(device)
179 | pretrained_model.to(device)
180 |
181 | from src.rec_adam_wrapper import get_adam_optimizer
182 | optim, scheduler = get_adam_optimizer(model, pretrained_model, anneal_targets, args, args.iter_per_epoch * args.epoch_num)
183 |
184 | global_steps = 0
185 | model.train()
186 | optim.zero_grad()
187 |
188 | if args.share_proj_layer:
189 | share_s = share_ro = None
190 | else:
191 | share_s = True
192 | share_ro = False
193 |
194 | share_proj_layer = (share_s, share_ro)
195 |
196 | if task_cfg['con_task']['loss_func'] == 'NT-Logistic':
197 | con_loss_func = NT_Logistic(opt_cfg['temperature'])
198 | else:
199 | raise NotImplementedError
200 |
201 | auxil_loss_func = torch.nn.CrossEntropyLoss(reduction='none')
202 |
203 |
204 | con_task_loader = DataLoader(train_data_feeder['con_task'], batch_size=1,
205 | drop_last=False, shuffle=False, num_workers=4)
206 |
207 | if options['con_task']['format'] == 'enc-dec':
208 | con_task_train = ConTaskTrainerForEncDec(model, device, distance_model, con_loss_func,
209 | log_cfg['tb_period'], tb_writer, share_proj_layer, args)
210 |
211 | elif options['con_task']['format'] == 'dec':
212 | con_task_train = ConTaskTrainerForDec(model, device, distance_model, con_loss_func,
213 | log_cfg['tb_period'], tb_writer, share_proj_layer, args)
214 |
215 | if 'rec_inf_shu_task' in train_data_feeder:
216 | auxil_task_loader = DataLoader(train_data_feeder['rec_inf_shu_task'], batch_size=args.batch_size,
217 | drop_last=True, shuffle=False, num_workers=10)
218 | auxil_task_name = 'rec_inf_shu_task'
219 | auxil_task_train = RecTaskTrainerForEncDec(model, device, auxil_loss_func, log_cfg['tb_period'], tb_writer)
220 |
221 | elif 'mask_gen_task' in train_data_feeder:
222 | auxil_task_loader = DataLoader(train_data_feeder['mask_gen_task'], batch_size=args.batch_size,
223 | drop_last=True, shuffle=False, num_workers=10)
224 | auxil_task_name = 'mask_gen_task'
225 | auxil_task_train = MaskGenTaskTrainerForDec(model, device, auxil_loss_func, log_cfg['tb_period'], tb_writer)
226 |
227 | else:
228 | raise NotImplementedError
229 |
230 | torch.save(model, os.path.join(ckpt_dir, 'model-{}-steps.ckpt'.format(global_steps)))
231 |
232 | #########
233 | # for task_name, task_feeder in train_data_feeder.items():
234 | # task_feeder.sampler._init_cursor(args.random_seed)
235 | #
236 | # for con_task_data, auxil_task_data in tqdm(zip(con_task_loader, auxil_task_loader)):
237 | # multi_task_loss = []
238 | ########
239 | for e in range(args.epoch_num):
240 | model.train()
241 | for task_name, task_feeder in train_data_feeder.items():
242 | task_feeder.sampler._init_cursor(args.random_seed)
243 |
244 | for con_task_data, auxil_task_data in tqdm(zip(con_task_loader, auxil_task_loader)):
245 | multi_task_loss = []
246 |
247 | with autocast():
248 | con_task_data = {i: con_task_data[i][0] for i in con_task_data}
249 | con_loss = con_task_train.train(con_task_data, global_steps, iter_num == 0)
250 | multi_task_loss.append((con_loss.item(), task_cfg['con_task']['weight']))
251 | con_loss *= task_cfg['con_task']['weight']
252 | con_loss /= stack_num
253 | con_loss.backward()
254 |
255 | with autocast():
256 | auxil_loss = auxil_task_train.train(auxil_task_data, global_steps, iter_num == 0)
257 | multi_task_loss.append((auxil_loss.item(), task_cfg[auxil_task_name]['weight']))
258 | auxil_loss *= task_cfg[auxil_task_name]['weight']
259 | auxil_loss /= stack_num
260 | auxil_loss.backward()
261 |
262 | iter_num += 1
263 | if iter_num != stack_num:
264 | continue
265 | iter_num = 0
266 |
267 | clip_grad_norm_(model.parameters(), opt_cfg['clip_grad_norm'])
268 | _, anneal_lambda = optim.step()
269 | scheduler.step()
270 | global_steps += 1
271 |
272 | if global_steps % log_cfg['tb_period'] == 0:
273 | avg_multi_task_loss = sum([i[0] for i in multi_task_loss]) / len(multi_task_loss)
274 | weighted_multi_task_loss = sum([i[0] * i[1] for i in multi_task_loss])
275 | tb_writer.add_scalar('train/total_lr', optim.param_groups[0]['lr'], global_steps)
276 | tb_writer.add_scalar('train/multi-task_loss', avg_multi_task_loss, global_steps)
277 | tb_writer.add_scalar('train/multi-task_loss(weighted)', weighted_multi_task_loss, global_steps)
278 | tb_writer.add_scalar('train/anneal_lambda', anneal_lambda, global_steps)
279 | tb_writer.flush()
280 |
281 | if global_steps % log_cfg['save_period'] == 0:
282 | torch.save(model, os.path.join(ckpt_dir, 'model-{}-steps.ckpt'.format(global_steps)))
283 |
284 | torch.save(model, os.path.join(ckpt_dir, 'model-{}-steps.ckpt'.format(global_steps)))
285 |
286 |
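287 | # Example launch (arguments are illustrative; adjust --file_dir to your setup):
288 | #   python scripts/feature_learn.py --dataset_type atomic-2020 --model_name bart \
289 | #       --model_size base --main_yml v_all_joint_v01.yml --log my_run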
--------------------------------------------------------------------------------
/scripts/finetune.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 |
3 | sys.path.append(os.getcwd())
4 |
5 | import json
6 | import yaml
7 | from tqdm import tqdm
8 | import argparse
9 | import numpy as np
10 | import torch
11 | from copy import deepcopy
12 |
13 | from torch.utils.tensorboard import SummaryWriter
14 | from torch.cuda.amp import autocast
15 | from torch.cuda.amp import GradScaler
16 | from torch.nn.utils import clip_grad_norm_
17 | from distutils.util import strtobool as _bool
18 | from torch.utils.data import DataLoader
19 |
20 | from src.utils import load_logger, load_yaml
21 |
22 | from src.finetune.finetune_utils import *
23 | from src.finetune.finetune_trainer import get_finetune_trainer
24 | from models.model_utils import convert_model
25 | from src.lr_schedule import WarmupLinearScheduler
26 |
27 | os.environ['CUDA_VISIBLE_DEVICES'] = '3'
28 | parser = argparse.ArgumentParser()
29 |
30 | parser.add_argument('--mode', type=str, default=None)
31 | parser.add_argument('--file_dir', type=str, default='/mnt/data/user8/solar-commonsense_inference')
32 | parser.add_argument('--dataset_type', type=str, default='conceptnet')
33 | parser.add_argument('--model_name', type=str, default='bart', help="bart | gpt2")#, required=True)
34 | parser.add_argument('--model_size', type=str, default='large', help="base | large")#, required=True)
35 | parser.add_argument('--exp_type', type=str, default='baseline', help='baseline | experiments')#, required=True)
36 | parser.add_argument('--main_yml', type=str, default='v01.yml')
37 | parser.add_argument('--tknz_yml', type=str, default='tokenizer_config.yml')
38 | parser.add_argument('--log_level', type=str, default='INFO')
39 | parser.add_argument('--log', type=str, default='vTEST1')#, required=True)
40 | parser.add_argument('--random_seed', type=int, default=42)
41 | parser.add_argument('--load_model', type=str, default=None)
42 |
43 | parser.add_argument('--same_hwang', type=_bool, default=False)
44 |
45 | parser.add_argument('--lr', type=float, default=0.00001)
46 | parser.add_argument('--batch_size', type=int, default=64)
47 | parser.add_argument('--update_batch', type=int, default=64)
48 | parser.add_argument('--epoch_num', type=int, default=2)
49 | parser.add_argument('--output_drop', type=float, default=0.1)
50 | parser.add_argument('--remove_except_best', type=_bool, default=True)
51 | parser.add_argument('--patience', type=int, default=2)
52 |
53 |
54 | args = parser.parse_args()
55 |
56 | if args.load_model is not None:
57 | assert args.exp_type == 'experiments'
58 |
59 | args.batch_size = args.batch_size if args.update_batch >= args.batch_size else args.update_batch
60 |
61 | np.random.seed(args.random_seed)
62 |
63 | main_config_path = f'config/{args.dataset_type}/{args.model_name}-{args.model_size}_baseline/{args.main_yml}'
64 | tknz_config_path = f'config/{args.tknz_yml}'
65 |
66 | dataset_config_path = f'config/{args.dataset_type}/datasets.yml'
67 |
68 | log_dir = f'{args.file_dir}/log_fntn/{args.dataset_type}/{args.model_name}-{args.model_size}_{args.exp_type}/{args.log}'
69 |
70 | logging_dir = log_dir + '/logging.log'
71 | tb_dir = log_dir + '/tb'
72 | gen_dir = log_dir + '/gen'
73 | eval_dir = log_dir + '/eval'
74 | ckpt_dir = log_dir + '/ckpt'
75 |
76 | if not os.path.exists(log_dir):
77 | os.makedirs(log_dir)
78 | os.mkdir(tb_dir)
79 | os.mkdir(gen_dir)
80 | os.mkdir(eval_dir)
81 | os.mkdir(ckpt_dir)
82 |
83 | logger = load_logger(logging_dir, args.log_level)
84 | logger.info('Logger successfully initialized!')
85 | tb_writer = SummaryWriter(tb_dir)
86 |
87 | main_config = load_yaml(main_config_path)
88 | tknz_config = load_yaml(tknz_config_path)
89 | dataset_cfg = load_yaml(dataset_config_path)
90 |
91 | with open(os.path.join(log_dir, 'main_config.json'), 'w') as f:
92 | json.dump(main_config, f)
93 | with open(os.path.join(log_dir, 'tknz_config.json'), 'w') as f:
94 | json.dump(tknz_config, f)
95 | with open(os.path.join(log_dir, 'argparse.json'), 'w') as f:
96 | json.dump(args.__dict__, f)
97 |
98 | model_cfg = main_config['model']
99 | opt_cfg = main_config['opt']
100 | log_cfg = main_config['log']
101 |
102 | # Model Selection
103 | MODEL_TYPE = None
104 | if 'bart' in model_cfg['name']:
105 | from transformers import BartTokenizer as Tokenizer
106 | from models.bart import CometBART as CometModel
107 | from src.tokenizer import BartCSKGTokenizer as CSKGTokenizer
108 | from models.bart import convert_BARTModel_to_BartForConditionalGeneration as model2gen
109 | MODEL_TYPE = 'enc-dec'
110 |
111 | elif 't5' in model_cfg['name']:
112 | raise NotImplementedError
113 |
114 | elif 'gpt2' in model_cfg['name']:
115 | from transformers import GPT2Tokenizer as Tokenizer
116 | from models.gpt2 import CometGPT2 as CometModel
117 | from src.tokenizer import Gpt2CSKGTokenizer as CSKGTokenizer
118 | from models.gpt2 import convert_GPT2Model_to_GPT2LMHeadModel as model2gen
119 | MODEL_TYPE = 'dec'
120 |
121 | else:
122 | raise NotImplementedError
123 |
124 | _tokenizer = Tokenizer.from_pretrained(model_cfg['pretrained_model'])
125 | tokenizer = CSKGTokenizer(_tokenizer, tknz_config)
126 |
127 | vocab_len = len(tokenizer) + 1
128 |
129 | if args.load_model is None:
130 | model = CometModel.from_pretrained(model_cfg['pretrained_model'])
131 | model.modify_lm_heads(vocab_len, args.output_drop)
132 | logger.info("Model is loaded from : {}".format(model_cfg['pretrained_model']))
133 | else:
134 | args.load_model = os.path.join(args.load_model)
135 | src_model = torch.load(args.load_model)
136 | logger.info("Model is loaded from : {}".format(args.load_model))
137 | dst_model = CometModel.from_pretrained(model_cfg['pretrained_model'])
138 | dst_model.modify_lm_heads(vocab_len, args.output_drop)
139 | model = convert_model(src_model, dst_model)
140 |
141 | dataset = load_fntn_datasets(dataset_cfg, tokenizer, logger)
142 |
143 | fntn_train_dataset = get_finetune_dataset(dataset['train'], dataset_cfg, tokenizer, logger, 'train', MODEL_TYPE)
144 | fntn_dev_dataset = get_finetune_dataset(dataset['dev'], dataset_cfg, tokenizer, logger, 'dev', MODEL_TYPE)
145 |
146 | eval_dev_dataset = get_eval_dataset(dataset['dev'], tokenizer, logger, 'dev', MODEL_TYPE)
147 | eval_test_dataset = get_eval_dataset(dataset['test'], tokenizer, logger, 'test', MODEL_TYPE)
148 |
149 | fntn_train_loader = DataLoader(fntn_train_dataset,
150 | batch_size=args.batch_size, drop_last=True, shuffle=True, num_workers=20)
151 | fntn_dev_loader = DataLoader(fntn_dev_dataset,
152 | batch_size=args.batch_size, drop_last=True, shuffle=True, num_workers=20)
153 |
154 | usable_cuda = torch.cuda.is_available()
155 | device = torch.device("cuda:0" if usable_cuda else "cpu")
156 | model.to(device)
157 |
158 | # Optimizer Settings
159 | if args.same_hwang:  # bool flag; checking 'is not None' was always True and forced the AdamW branch
160 | optim = torch.optim.AdamW(model.parameters(), lr=args.lr, eps=1e-8)
161 | else:
162 | optim = torch.optim.Adam(model.parameters(),
163 | lr=args.lr, betas=(opt_cfg['adam_beta_1'], opt_cfg['adam_beta_2']), weight_decay=opt_cfg['weight_decay'])
164 |
165 | total_steps = int((len(fntn_train_dataset) / args.update_batch) * args.epoch_num) + 100
166 | lr_schedule = WarmupLinearScheduler(optim, args.lr, opt_cfg['warmup_steps'], total_steps)
167 | loss_func = torch.nn.CrossEntropyLoss(reduction='none')
168 | global_step = 0
169 | scaler = GradScaler()
170 | model.train()
171 | optim.zero_grad()
172 |
173 | iter_num = 0
174 |
175 | assert args.update_batch >= args.batch_size
176 | stack_num = int(args.update_batch / args.batch_size)
177 | assert int(stack_num * args.batch_size) == int(args.update_batch)
178 |
179 | fntn_trainer = get_finetune_trainer(model, device, loss_func, log_cfg['tb_period'], tb_writer, MODEL_TYPE)
180 |
181 | ckpt_loss = {}
182 |
183 | for e in range(1, args.epoch_num+1):
184 | model.train()
185 | fntn_train_loader.dataset.shuffle()
186 | patience = args.patience
187 | for sample in tqdm(fntn_train_loader, desc='[Train] Epoch {}/{}'.format(e, args.epoch_num), ncols=130):
188 | with autocast():
189 | loss = fntn_trainer.train(sample, global_step, 'train', iter_num == 0)
190 | loss /= stack_num
191 | loss.backward()
192 | iter_num += 1
193 | if iter_num != stack_num:
194 | continue
195 | iter_num = 0
196 |
197 | lr_schedule(global_step)
198 | clip_grad_norm_(model.parameters(), opt_cfg['clip_grad_norm'])
199 | optim.step()
200 | optim.zero_grad()
201 | global_step += 1
202 |
203 | if global_step % log_cfg['tb_period'] == 0:
204 | tb_writer.add_scalar('train/lr', optim.param_groups[0]['lr'], global_step)
205 | tb_writer.flush()
206 |
207 | model.eval()
208 | loss_list = list()
209 | for sample in tqdm(fntn_dev_loader, desc='[Validate]', ncols=130):
210 | with autocast():
211 | loss = fntn_trainer.train(sample, global_step, 'dev')
212 | loss_list.append(loss.item())
213 | loss = sum(loss_list) / len(loss_list)
214 | tb_writer.add_scalar('dev/loss', loss, global_step)
215 |
216 | save_name = os.path.join(ckpt_dir, 'model-{}-epoch.ckpt'.format(e))
217 |
218 | torch.save(model, save_name)
219 | ckpt_loss[save_name] = loss
220 |
221 |     for _ckpt, _loss in ckpt_loss.items():
222 |         if loss > _loss:
223 |             patience -= 1
224 |     if patience <= 0:  # checked after the loop so `break` exits the epoch loop
225 |         logger.info('Patience is over. End')
226 |         break
227 |
228 | tokenizer_save_name = os.path.join(ckpt_dir, 'tokenizer.torch-pkl')
229 | torch.save(tokenizer, tokenizer_save_name)
230 |
231 | with open(os.path.join(ckpt_dir, 'ckpt_loss.json'), 'w') as f:
232 | json.dump(ckpt_loss, f)
233 |
234 | with open(os.path.join(ckpt_dir, 'ckpt_loss.json'), 'r') as f:
235 | ckpt_loss = json.load(f)
236 |
237 | best_ckpt = None
238 | best_loss = 10000
239 |
240 | for _ckpt, _loss in ckpt_loss.items():
241 | if best_loss > _loss:
242 | best_loss = _loss
243 | best_ckpt = _ckpt
244 |
245 | if args.remove_except_best:
246 | for _ckpt in ckpt_loss:
247 | if _ckpt == best_ckpt:
248 | continue
249 | os.remove(os.path.join(_ckpt))
250 |
251 | decode = tokenizer.tokenizer.decode
252 |
253 | model = torch.load(best_ckpt).to('cpu')
254 | if MODEL_TYPE == 'enc-dec':
255 | gen_model = model2gen(model, model_cfg['pretrained_model']).to(device)
256 | else:
257 | gen_model = model.to(device)
258 |
259 | test_decode_results = {
260 | 'info': {'log': log_dir, 'ckpt': best_ckpt},
261 | 'content': list()}
262 |
263 | greedy_test_decode_results = deepcopy(test_decode_results)
264 | greedy_test_decode_results['info']['decode_method'] = 'greedy'
265 | beam5_test_decode_results = deepcopy(test_decode_results)
266 | beam5_test_decode_results['info']['decode_method'] = 'beam5'
267 | nucl_test_decode_results = deepcopy(test_decode_results)
268 | nucl_test_decode_results['info']['decode_method'] = 'nucl'
269 |
270 | if MODEL_TYPE == 'enc-dec':
271 | for sample in tqdm(eval_test_dataset, ncols=130):
272 | src = sample['src']
273 | refs = sample['ref']
274 | enc_input_ids = torch.tensor(src).to(device).view(1, -1)
275 | enc_att_masks = torch.ones_like(enc_input_ids).to(device)
276 |
277 | inputs = {'input_ids': enc_input_ids, 'attention_mask': enc_att_masks}
278 | greedy_output = gen_model.generate(**inputs, early_stopping=True,
279 | bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)
280 | # try:
281 | # beam5_output = gen_model.generate(**inputs, num_beams=5, early_stopping=True) #max_length=60,
282 | # except:
283 | # beam5_output = gen_model.generate(**inputs, num_beams=5, early_stopping=True, max_length=60)
284 | greedy_str = decode(greedy_output.tolist()[0])
285 | # beam5_str = decode(beam5_output.tolist()[0])
286 |
287 | if '' in greedy_str:
288 | greedy_str = greedy_str[greedy_str.find('') + 1 + len(''):].strip()
289 | # if '' in beam5_str:
290 | # beam5_str = beam5_str[beam5_str.find('')+1 + len(''):].strip()
291 |
292 | _input = decode(src)
293 | _refs = list()
294 | for ref in refs:
295 | _ref = decode(ref)
296 | _refs.append(_ref)
297 |
298 | greedy_test_decode_results['content'].append({'input': _input, 'output': greedy_str, 'refs': _refs})
299 | # beam5_test_decode_results['content'].append({'input': _input, 'output': beam5_str, 'refs': _refs})
300 |
301 | with open(os.path.join(log_dir, 'greedy_gen_examples.json'), 'w') as f:
302 | json.dump(greedy_test_decode_results, f)
303 | #
304 | # with open(os.path.join(log_dir, 'beam5_gen_examples.json'), 'w') as f:
305 | # json.dump(beam5_test_decode_results, f)
306 |
307 | else:
308 |
309 | for sample in tqdm(eval_test_dataset, ncols=130):
310 | src = sample['src']
311 | refs = sample['ref']
312 | enc_input_ids = torch.tensor(src).to(device).view(1, -1)
313 | enc_att_masks = torch.ones_like(enc_input_ids).to(device)
314 |
315 | inputs = {'input_ids': enc_input_ids, 'att_masks': enc_att_masks}
316 |
317 | for i in range(30):
318 | output = model.forward_conditional_gen(**inputs) # ['input_ids'], att_masks=inputs['attention_mask'])
319 | gen_token_id = int(torch.argmax(output[:, -1, :], -1))
320 | old_inputs = {key: val.tolist() for key, val in inputs.items()}
321 | old_inputs['input_ids'][0].append(gen_token_id)
322 | old_inputs['att_masks'][0].append(1)
323 | if gen_token_id == tokenizer.eos_token_id:
324 | break
325 | inputs = {key: torch.tensor(val).to(device) for key, val in old_inputs.items()}
326 |
327 | greedy_output = inputs['input_ids']
328 |
329 | greedy_str = decode(greedy_output.tolist()[0])
330 |
331 | if '' in greedy_str:
332 | greedy_str = greedy_str[greedy_str.find('') + 1 + len(''):].strip()
333 |
334 | _input = decode(src)
335 | print(_input)
336 | print(greedy_str)
337 | print('----------')
338 | _refs = list()
339 | for ref in refs:
340 | _ref = decode(ref)
341 | _refs.append(_ref)
342 |
343 | greedy_test_decode_results['content'].append({'input': _input, 'output': greedy_str, 'refs': _refs})
344 |
345 | with open(os.path.join(log_dir, 'greedy_gen_examples_FIXED.json'), 'w') as f:
346 | json.dump(greedy_test_decode_results, f)
347 |
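348 | # Example launch (arguments are illustrative; adjust --file_dir to your setup):
349 | #   python scripts/finetune.py --dataset_type conceptnet --model_name bart \
350 | #       --model_size large --exp_type baseline --main_yml v01.yml --log run01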
--------------------------------------------------------------------------------
/scripts/inference.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 |
3 | sys.path.append(os.getcwd())
4 |
5 | import json
6 | import yaml
7 | from tqdm import tqdm
8 | import argparse
9 | import numpy as np
10 | import torch
11 |
12 | from torch.utils.tensorboard import SummaryWriter
13 | from torch.cuda.amp import autocast
14 | from torch.cuda.amp import GradScaler
15 | from torch.nn.utils import clip_grad_norm_
16 | from distutils.util import strtobool as _bool
17 | from torch.utils.data import DataLoader
18 |
19 | from src.utils import load_logger, load_yaml
20 |
21 | from src.finetune.finetune_utils import *
22 | from src.finetune.finetune_trainer import FineTuneTrainer
23 | from models import bart
24 | from src.lr_schedule import WarmupLinearScheduler
25 |
26 | os.environ['CUDA_VISIBLE_DEVICES'] = '2'
27 |
28 |
29 | parser = argparse.ArgumentParser()
30 |
31 | parser.add_argument('--mode', type=str, default=None)
32 | parser.add_argument('--file_dir', type=str, default='/mnt/data/user8/solar-commonsense_inference')
33 | parser.add_argument('--dataset_type', type=str, default='conceptnet')
34 | parser.add_argument('--model_name', type=str, default='bart', help="bart | gpt2")
35 | parser.add_argument('--model_size', type=str, default='large', help="base | large")
36 | parser.add_argument('--exp_type', type=str, default='baseline', help='baseline | experiments')
37 | parser.add_argument('--log_level', type=str, default='INFO')
38 | parser.add_argument('--random_seed', type=int, default=42)
39 | parser.add_argument('--load_model', type=str, default='v07')
40 | parser.add_argument('--use_greedy', type=_bool, default=False)
41 | parser.add_argument('--use_beam', type=int, default=10 )
42 |
43 | args = parser.parse_args()
44 |
45 | np.random.seed(args.random_seed)
46 |
47 | use_greedy = args.use_greedy
48 | use_beam = args.use_beam > 1
49 | beam_num = args.use_beam
50 |
51 | log_dir = f'{args.file_dir}/log_fntn/{args.dataset_type}/{args.model_name}-{args.model_size}_{args.exp_type}/{args.load_model}'
52 |
53 | main_config_path = f'{log_dir}/main_config.json'
54 | tknz_config_path = f'{log_dir}/tknz_config.json'
55 |
56 | dataset_config_path = f'config/{args.dataset_type}/datasets.yml'
57 |
58 | print(f'Log Location : {log_dir}')
59 | logging_dir = log_dir + '/inf_logging.log'
60 | tb_dir = log_dir + '/tb'
61 | gen_dir = log_dir + '/gen'
62 | eval_dir = log_dir + '/eval'
63 | ckpt_dir = log_dir + '/ckpt'
64 |
65 | if not os.path.exists(log_dir):
66 |     raise FileNotFoundError(f'log directory does not exist: {log_dir}')
67 |
68 | logger = load_logger(logging_dir, args.log_level)
69 | logger.info('Logger successfully initialized!')
70 |
71 | main_config = load_yaml(main_config_path)
72 | tknz_config = load_yaml(tknz_config_path)
73 | dataset_cfg = load_yaml(dataset_config_path)
74 |
75 | model_cfg = main_config['model']
76 | opt_cfg = main_config['opt']
77 | log_cfg = main_config['log']
78 |
79 | usable_cuda = torch.cuda.is_available()
80 | device = torch.device("cuda:0" if usable_cuda else "cpu")
81 |
82 |
83 | # Model Selection
84 | MODEL_TYPE = None
85 | if 'bart' in model_cfg['name']:
86 | from transformers import BartTokenizer as Tokenizer
87 | from models.bart import CometBART as CometModel
88 | from src.tokenizer import BartCSKGTokenizer as CSKGTokenizer
89 | from models.bart import convert_BARTModel_to_BartForConditionalGeneration as model2gen
90 | MODEL_TYPE = 'enc-dec'
91 |
92 | elif 't5' in model_cfg['name']:
93 | raise NotImplementedError
94 |
95 | elif 'gpt2' in model_cfg['name']:
96 | raise NotImplementedError
97 |
98 | else:
99 | raise NotImplementedError
100 |
101 | _tokenizer = Tokenizer.from_pretrained(model_cfg['pretrained_model'])
102 | tokenizer = CSKGTokenizer(_tokenizer, tknz_config)
103 |
104 | vocab_len = len(tokenizer) + 1
105 |
106 | with open(os.path.join(ckpt_dir, 'ckpt_loss.json'), 'r') as f:
107 | ckpt_loss = json.load(f)
108 |
109 | best_ckpt = None
110 | best_loss = 10000
111 |
112 | for _ckpt, _loss in ckpt_loss.items():
113 | if best_loss > _loss:
114 | best_loss = _loss
115 | best_ckpt = _ckpt
116 |
117 | logger.info("Model will be loaded from : {}".format(best_ckpt))
118 |
119 | try:
120 | if str(device) == 'cpu':
121 | model = torch.load(best_ckpt, map_location=torch.device('cpu'))
122 | else:
123 | model = torch.load(best_ckpt)
124 |
125 | except Exception:  # stored checkpoint path may live under the alternate log root
126 |     best_ckpt = best_ckpt.replace('log_fntn', 'log2_fntn')
127 |     model = torch.load(best_ckpt)
128 |
129 | gen_model = model2gen(model, model_cfg['pretrained_model']).to(device)
130 |
131 | gen_model.to(device)
132 | gen_model.eval()
133 |
134 | decode = tokenizer.tokenizer.decode
135 |
136 |
137 | dataset = load_fntn_datasets(dataset_cfg, tokenizer, logger)
138 | eval_test_dataset = get_eval_dataset(dataset['test'], tokenizer, logger, 'test', 'dec')
139 |
140 | test_decode_results = {
141 | 'info': {'log': log_dir, 'ckpt': best_ckpt},
142 | 'content': list()}
143 |
144 | from copy import deepcopy
145 |
146 | if use_greedy:
147 | greedy_results = deepcopy(test_decode_results)
148 | greedy_results['info']['decode_method'] = 'greedy'
149 |
150 | if use_beam:
151 | beam_results = deepcopy(test_decode_results)
152 | beam_results['info']['decode_method'] = f'beam{beam_num}'
153 |
154 |
155 | for sample in tqdm(eval_test_dataset, ncols=130):
156 | src = sample['src']
157 | refs = sample['ref']
158 |
159 | _input = decode(src)
160 | _refs = list()
161 | for ref in refs:
162 | _ref = decode(ref)
163 | _refs.append(_ref)
164 |
165 | enc_input_ids = torch.tensor(src).to(device).view(1, -1)
166 | enc_att_masks = torch.ones_like(enc_input_ids).to(device)
167 |
168 | inputs = {'input_ids': enc_input_ids, 'attention_mask': enc_att_masks}
169 | if use_greedy:
170 | greedy_output = gen_model.generate(**inputs, early_stopping=True,
171 | bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)
172 | greedy_str = decode(greedy_output.tolist()[0])
173 | greedy_results['content'].append({'input': _input, 'output': greedy_str, 'refs': _refs})
174 |
175 | if use_beam:
176 | beam_output = gen_model.generate(**inputs, num_beams=beam_num, num_return_sequences=beam_num, early_stopping=True,
177 | bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id) #max_length=60,
178 | beam_str = [decode(beam.tolist()).replace('','') for beam in beam_output]
179 | beam_results['content'].append({'input': _input, 'output': beam_str, 'refs': _refs})
180 |
181 | if use_greedy:
182 | with open(os.path.join(log_dir, 'greedy_gen_examples.json'), 'w') as f:
183 | json.dump(greedy_results, f)
184 |
185 | if use_beam:
186 | with open(os.path.join(log_dir, f'beam-{beam_num}_gen_examples.json'), 'w') as f:
187 | json.dump(beam_results, f)
188 | # with open(os.path.join(log_dir, 'beam5_gen_examples.json'), 'w') as f:
189 | # json.dump(beam5_test_decode_results, f)
190 | #
191 | # for sample in tqdm(eval_test_dataset, ncols=130):
192 | # src = sample['src']
193 | # refs = sample['ref']
194 | # enc_input_ids = torch.tensor(src).to(device).view(1, -1)
195 | # enc_att_masks = torch.ones_like(enc_input_ids).to(device)
196 | #
197 | # inputs = {'input_ids': enc_input_ids, 'attention_mask': enc_att_masks}
198 | #
199 | # print('\n\nstart')
200 | # greedy_output = gen_model.generate(**inputs)#, early_stopping=True)
201 | # try:
202 | # beam5_output = gen_model.generate(**inputs, num_beams=5, early_stopping=True) #max_length=60,
203 | # except:
204 | # print('beam err occur')
205 | # beam5_output = gen_model.generate(**inputs, num_beams=5, early_stopping=True, max_length=60)
206 | # print('end')
207 | # greedy_str = decode(greedy_output.tolist()[0])
208 | # beam5_str = decode(beam5_output.tolist()[0])
209 | #
210 | # if '' in greedy_str:
211 | # greedy_str = greedy_str[greedy_str.find('') + 1 + len(''):].strip()
212 | # if '' in beam5_str:
213 | # beam5_str = beam5_str[beam5_str.find('')+1 + len(''):].strip()
214 | #
215 | # _input = decode(src)
216 | # _refs = list()
217 | # for ref in refs:
218 | # _ref = decode(ref)
219 | # _refs.append(_ref)
220 | #
221 | # greedy_test_decode_results['content'].append({'input': _input, 'output': greedy_str, 'refs': _refs})
222 | # beam5_test_decode_results['content'].append({'input': _input, 'output': beam5_str, 'refs': _refs})
223 | #
224 | #
225 | # with open(os.path.join(log_dir, 'greedy_gen_examples.json'), 'w') as f:
226 | # json.dump(greedy_test_decode_results, f)
227 | #
228 | # with open(os.path.join(log_dir, 'beam5_gen_examples.json'), 'w') as f:
229 | # json.dump(beam5_test_decode_results, f)
230 | #
231 | # # ----- TEST ----- #
232 | #
233 | # for i in range(10):
234 | # output = gen_model(**inputs).logits
235 | # last_toks = torch.argmax(output[:, -1, :], -1)
236 | # old_inputs = {key: val.tolist() for key, val in inputs.items()}
237 | # old_inputs['input_ids'][0].append(int(last_toks))
238 | # old_inputs['attention_mask'][0].append(1)
239 | # inputs = {key : torch.tensor(val).to(device) for key, val in old_inputs.items()}
240 | #
241 | #
242 | # # -- Non LM model -- #
243 | #
244 | # for i in range(10):
245 | # output = model.forward_conditional_gen(input_ids=inputs['input_ids'], att_masks=inputs['attention_mask'])
246 | # last_toks = torch.argmax(output[:, -1, :], -1)
247 | # old_inputs = {key: val.tolist() for key, val in inputs.items()}
248 | # old_inputs['input_ids'][0].append(int(last_toks))
249 | # old_inputs['attention_mask'][0].append(1)
250 | # inputs = {key : torch.tensor(val).to(device) for key, val in old_inputs.items()}
251 |
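252 |
253 | # Example launch (arguments are illustrative): decode the best checkpoint of a
254 | # finished fine-tuning run with 10-beam search and write the generations to JSON.
255 | #   python scripts/inference.py --dataset_type conceptnet --model_name bart \
256 | #       --model_size large --exp_type baseline --load_model v07 --use_beam 10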
--------------------------------------------------------------------------------
/src/data_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import csv
4 | import operator
5 | import random
6 |
7 |
8 | def read_csv(input_file, quotechar='"', delimiter=",", skip_header=False):
9 |     """Reads a delimited text file (comma-separated by default)."""
10 | with open(input_file, "r") as f:
11 | reader = csv.reader(f, delimiter=delimiter, quotechar=quotechar, quoting=csv.QUOTE_ALL, skipinitialspace=True)
12 | lines = []
13 | for line in reader:
14 | if sys.version_info[0] == 2:
15 | line = list(unicode(cell, 'utf-8') for cell in line)
16 | lines.append(line)
17 | if skip_header:
18 | lines = lines[1:]
19 | return lines
20 |
21 |
22 | def write_tsv(output_file, data, header=False):
23 | keys = list(data[0].keys())
24 | with open(output_file, 'w') as f:
25 | w = csv.DictWriter(f, keys, delimiter='\t', lineterminator='\n')
26 | if header:
27 | w.writeheader()
28 | for r in data:
29 | entry = {k: r[k] for k in keys}
30 | w.writerow(entry)
31 |
32 |
33 | def write_array2tsv(output_file, data, header=False):
34 | keys = range(len(data[0]))
35 | with open(output_file, 'w') as f:
36 | w = csv.DictWriter(f, keys, delimiter='\t', lineterminator='\n')
37 | if header:
38 | w.writeheader()
39 | for r in data:
40 | entry = {k: r[k] for k in keys}
41 | w.writerow(entry)
42 |
43 |
44 | def write_csv(filename, data, fieldnames):
45 | with open(filename, 'w', newline='') as csvfile:
46 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
47 |
48 | writer.writeheader()
49 | for d in data:
50 | formatted_d = {}
51 | for key, val in d.items():
52 | formatted_d[key] = json.dumps(val)
53 | writer.writerow(formatted_d)
54 |
55 |
56 | def read_jsonl(filename):
57 | data = []
58 | with open(filename, "r") as f:
59 | for line in f:
60 | data.append(json.loads(line))
61 | return data
62 |
63 |
64 | def write_items(output_file, items):
65 | with open(output_file, 'w') as f:
66 | for concept in items:
67 | f.write(concept + "\n")
68 |     # the with-block closes the file; no explicit close needed
69 |
70 |
71 | def write_jsonl(f, d):
72 | write_items(f, [json.dumps(r) for r in d])
73 |
74 |
75 | def count_relation(d):
76 | relation_count = {}
77 | prefix_count = {}
78 | head_count = {}
79 | for l in d:
80 | r = l[1]
81 | if r not in relation_count.keys():
82 | relation_count[r] = 0
83 | relation_count[r] += 1
84 |
85 | prefix = l[0]+l[1]
86 | if prefix not in prefix_count.keys():
87 | prefix_count[prefix] = 0
88 | prefix_count[prefix] += 1
89 |
90 | head = l[0]
91 | if head not in head_count.keys():
92 | head_count[head] = 0
93 | head_count[head] += 1
94 |
95 | sorted_relation_count = dict(sorted(relation_count.items(), key=operator.itemgetter(1), reverse=True))
96 | sorted_prefix_count = dict(sorted(prefix_count.items(), key=operator.itemgetter(1), reverse=True))
97 | sorted_head_count = dict(sorted(head_count.items(), key=operator.itemgetter(1), reverse=True))
98 |
99 | print("Relations:")
100 | for r in sorted_relation_count.keys():
101 | print(r, sorted_relation_count[r])
102 |
103 | print("\nPrefixes:")
104 | print("uniq prefixes: ", len(sorted_prefix_count.keys()))
105 | i = 0
106 | for r in sorted_prefix_count.keys():
107 | print(r, sorted_prefix_count[r])
108 | i += 1
109 | if i > 20:
110 | break
111 |
112 | print("\nHeads:")
113 | i = 0
114 | for r in sorted_head_count.keys():
115 | print(r, sorted_head_count[r])
116 | i += 1
117 | if i > 20:
118 | break
119 |
120 |
121 | def get_head_set(d):
122 | return set([l[0] for l in d])
123 |
124 |
125 | def head_based_split(data, dev_size, test_size, head_size_threshold=500, dev_heads=[], test_heads=[]):
126 | """
127 | :param data: the tuples to split according to the heads, where the head is the first element of each tuple
128 | :param dev_size: target size of the dev set
129 | :param test_size: target size of the test set
130 | :param head_size_threshold: Maximum number of tuples a head can be involved in,
131 | in order to be considered for the dev/test set'
132 | :param dev_heads: heads that are forced to belong to the dev set
133 | :param test_heads: heads that are forced to belong to the test set
134 | :return:
135 | """
136 | head_count = {}
137 | for l in data:
138 | head = l[0]
139 | if head not in head_count.keys():
140 | head_count[head] = 0
141 | head_count[head] += 1
142 |
143 | remaining_heads = dict(head_count)
144 |
145 | test_selected_heads = {}
146 | test_head_total_count = 0
147 |
148 | for h in test_heads:
149 | if h in remaining_heads:
150 | c = remaining_heads[h]
151 | test_selected_heads[h] = c
152 | test_head_total_count += c
153 | remaining_heads.pop(h)
154 |
155 | while test_head_total_count < test_size:
156 |         h = random.sample(list(remaining_heads.keys()), 1)[0]  # list() needed: random.sample rejects dict views on newer Pythons
157 | c = remaining_heads[h]
158 | if c < head_size_threshold:
159 | test_selected_heads[h] = c
160 | test_head_total_count += c
161 | remaining_heads.pop(h)
162 |
163 | test = [l for l in data if l[0] in test_selected_heads.keys()]
164 |
165 | dev_selected_heads = {}
166 | dev_head_total_count = 0
167 |
168 | for h in dev_heads:
169 | if h in remaining_heads:
170 | c = remaining_heads[h]
171 | dev_selected_heads[h] = c
172 | dev_head_total_count += c
173 | remaining_heads.pop(h)
174 |
175 | while dev_head_total_count < dev_size:
176 |         h = random.sample(list(remaining_heads.keys()), 1)[0]
177 | c = remaining_heads[h]
178 | if c < head_size_threshold:
179 | dev_selected_heads[h] = c
180 | dev_head_total_count += c
181 | remaining_heads.pop(h)
182 |
183 | dev = [l for l in data if l[0] in dev_selected_heads.keys()]
184 |
185 | dev_test_heads = set(list(dev_selected_heads.keys()) + list(test_selected_heads.keys()))
186 | train = [l for l in data if l[0] not in dev_test_heads]
187 |
188 | return train, dev, test
189 |
190 |
191 | def remove_prefix(text, prefix):
192 | return text[text.startswith(prefix) and len(prefix):]
193 |
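194 | # Usage sketch for head_based_split (values are illustrative): tuples are
195 | # (head, relation, tail), and every tuple sharing a head lands in one split,
196 | # so no head leaks across train/dev/test.
197 | #   data = [('go to a movie', 'xWant', 'to buy popcorn'), ...]
198 | #   train, dev, test = head_based_split(data, dev_size=100, test_size=100)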
--------------------------------------------------------------------------------
/src/finetune/finetune_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def get_finetune_trainer(model, device, loss_func, tb_period, writer, model_type):
5 | if model_type == 'enc-dec':
6 | return FineTuneTrainerForEncDec(model, device, loss_func, tb_period, writer)
7 | elif model_type == 'dec':
8 | return FineTuneTrainerForDec(model, device, loss_func, tb_period, writer)
9 |
10 |
11 | class FineTuneTrainer:
12 | def __init__(self, model, device, loss_func, tb_period, writer):
13 | self.model = model
14 | self.device = device
15 | self.loss_func = loss_func
16 | self.tb_period = tb_period
17 | self.writer = writer
18 |
19 | def train(self, sample, global_step, _type='train', save_tb=True):
20 | raise NotImplementedError
21 |
22 |
23 | class FineTuneTrainerForEncDec(FineTuneTrainer):
24 | def __init__(self, model, device, loss_func, tb_period, writer):
25 | super(FineTuneTrainerForEncDec, self).__init__(model, device, loss_func, tb_period, writer)
26 |
27 | def train(self, sample, global_step, _type='train', save_tb=True):
28 | enc_input_ids = sample['enc_input_ids'].to(self.device)
29 | enc_att_masks = sample['enc_att_masks'].to(self.device)
30 | dec_input_ids = sample['dec_input_ids'].to(self.device)
31 | dec_att_masks = sample['dec_att_masks'].to(self.device)
32 | dec_label_ids = sample['dec_label_ids'].to(self.device)
33 |
34 | lm_logits = self.model.forward_conditional_gen(enc_input_ids, enc_att_masks, dec_input_ids, dec_att_masks)
35 |
36 | B, S, V = lm_logits.shape
37 | loss = self.loss_func(lm_logits.view(B * S, V), dec_label_ids.view(-1)).view(B, S)
38 | loss = torch.sum(loss * dec_att_masks, dim=-1) / torch.sum(dec_att_masks, dim=-1)
39 | loss = torch.mean(loss)
40 | if global_step % self.tb_period == 0 and _type == 'train' and save_tb:
41 |             self.writer.add_scalar('{}/loss'.format(_type), loss.item(), global_step)
42 | self.writer.flush()
43 |
44 | return loss
45 |
46 |
47 | class FineTuneTrainerForDec(FineTuneTrainer):
48 | def __init__(self, model, device, loss_func, tb_period, writer):
49 | super(FineTuneTrainerForDec, self).__init__(model, device, loss_func, tb_period, writer)
50 |
51 | def train(self, sample, global_step, _type='train', save_tb=True):
52 | input_ids = sample['input_ids'].to(self.device)
53 | att_masks = sample['att_masks'].to(self.device)
54 | label_ids = sample['label_ids'].to(self.device)
55 |
56 | lm_logits = self.model.forward_conditional_gen(input_ids, att_masks)
57 |
58 | B, S, V = lm_logits.shape
59 | loss = self.loss_func(lm_logits.view(B * S, V), label_ids.view(-1)).view(B, S)
60 | loss = torch.sum(loss * att_masks, dim=-1) / torch.sum(att_masks, dim=-1)
61 | loss = torch.mean(loss)
62 | if global_step % self.tb_period == 0 and _type == 'train' and save_tb:
63 |             self.writer.add_scalar('{}/loss'.format(_type), loss.item(), global_step)
64 | self.writer.flush()
65 |
66 | return loss
67 |
68 |
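69 | # Note: both trainers assume `loss_func` was built with reduction='none'
70 | # (as in scripts/finetune.py), so the per-token cross-entropy of shape (B, S)
71 | # can be masked by the attention mask before averaging per sequence.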
--------------------------------------------------------------------------------
/src/finetune/finetune_utils.py:
--------------------------------------------------------------------------------
1 | from src.sampler_utils import load_raw_tsv
2 | from torch.utils.data import Dataset
3 | from src.sampler import BaseDataSampler
4 | import random
5 | import torch
6 |
7 | random.seed(42)
8 |
9 |
10 | def load_fntn_datasets(dataset_cfg, tokenizer, logger):
11 | logger.info('Target Dataset : {}'.format(dataset_cfg['name']))
12 | dataset = {key: load_raw_tsv(dataset_cfg['dir'][key], tokenizer, logger, dataset_cfg['truncate'])
13 | for key in dataset_cfg['dir']}
14 |
15 | return dataset
16 |
17 |
18 | def get_finetune_dataset(dataset, dataset_cfg, tokenizer, logger, _type, model_type='enc-dec'):
19 | if model_type == 'enc-dec':
20 | fntn_dataset = FineTuneDatasetForEncDec(dataset, tokenizer, dataset_cfg['truncate'])
21 | elif model_type == 'dec':
22 | fntn_dataset = FineTuneDatasetForDec(dataset, tokenizer, dataset_cfg['truncate'])
23 | else:
24 | raise NotImplementedError
25 | logger.info("Load Fine-tuning datasets [{}] : {}".format(_type, len(fntn_dataset)))
26 | return fntn_dataset
27 |
28 |
29 | def get_eval_dataset(dataset, tokenizer, logger, _type, model_type='enc-dec'):
30 | triple_dict = dict()
31 | for row in dataset:
32 | s, r, o = row
33 | key = s + '||' + r
34 | if triple_dict.get(key) is None:
35 | triple_dict[key] = list()
36 | triple_dict[key].append(o)
37 |
38 | outputs = list()
39 | for key in triple_dict:
40 | src = key.split('||')
41 | ref = list(set(triple_dict[key]))
42 | outputs.append({'src': src, 'ref': ref})
43 |
44 | logger.info("Load Evaluation datasets [{}] : {}".format(_type, len(outputs)))
45 |
46 | if model_type == 'enc-dec':
47 | return EvalDatasetForEncDec(outputs, tokenizer, logger)
48 | elif model_type == 'dec':
49 | return EvalDatasetForDec(outputs, tokenizer, logger)
50 | else:
51 | raise NotImplementedError
52 |
53 |
54 | class EvalDataset(Dataset):
55 | def __init__(self, dataset, tokenizer, logger):
56 | self.dataset = dataset
57 | self.tokenizer = tokenizer
58 | self.logger = logger
59 | self.bos = self.tokenizer.bos_token_id
60 | self.eos = self.tokenizer.eos_token_id
61 | self.sep = self.tokenizer.sep_token_id
62 | self.gen = self.tokenizer.gen_token_id
63 |
64 | def formatting(self, sample):
65 | raise NotImplementedError
66 |
67 | def __len__(self):
68 | return len(self.dataset)
69 |
70 | def __getitem__(self, idx):
71 | sample = self.dataset[idx]
72 | return self.formatting(sample)
73 |
74 |
75 | class EvalDatasetForEncDec(EvalDataset):
76 | def __init__(self, dataset, tokenizer, logger):
77 | super(EvalDatasetForEncDec, self).__init__(dataset, tokenizer, logger)
78 |
79 | def formatting(self, sample):
80 | dp = {'src': None, 'ref': list()}
81 |
82 | src, ref_list = sample['src'], sample['ref']
83 | s, r = [self.tokenizer(i) for i in src]
84 | _input = [self.bos] + s + [self.sep] + r + [self.sep] + [self.gen] + [self.eos]
85 | dp['src'] = _input
86 | for ref in ref_list:
87 | ref_token = self.tokenizer(ref)
88 | _output = [self.bos] + ref_token + [self.eos]
89 | dp['ref'].append(_output)
90 |
91 | if len(dp['ref']) == 0:
92 | raise Exception
93 |
94 | return dp
95 |
96 |
97 | class EvalDatasetForDec(EvalDataset):
98 | def __init__(self, dataset, tokenizer, logger):
99 | super(EvalDatasetForDec, self).__init__(dataset, tokenizer, logger)
100 |
101 | def __len__(self):
102 | return len(self.dataset)
103 |
104 | def formatting(self, sample):
105 | dp = {'src': None, 'ref': list()}
106 |
107 | src, ref_list = sample['src'], sample['ref']
108 | s, r = [self.tokenizer(i) for i in src]
109 | _input = [self.bos] + s + [self.sep] + r + [self.sep] + [self.gen]
110 | dp['src'] = _input
111 | for ref in ref_list:
112 | ref_token = self.tokenizer(ref)
113 | _output = ref_token + [self.eos]
114 | dp['ref'].append(_output)
115 |
116 | if len(dp['ref']) == 0:
117 | raise Exception
118 |
119 | return dp
120 |
121 |
122 | class FineTuneDataset(Dataset):
123 | def __init__(self, dataset, tokenizer):
124 | super(FineTuneDataset, self).__init__()
125 | self.dataset = dataset
126 | self.tokenizer = tokenizer
127 | self.gen = self.tokenizer.gen_token_id
128 | self.bos = self.tokenizer.bos_token_id
129 | self.eos = self.tokenizer.eos_token_id
130 | self.sep = self.tokenizer.sep_token_id
131 | self.pad = self.tokenizer.pad_token_id
132 | self.shuffle()
133 |
134 | def __len__(self):
135 | return len(self.dataset)
136 |
137 | def formatting(self, sample):
138 | raise NotImplementedError
139 |
140 | def shuffle(self):
141 | random.shuffle(self.dataset)
142 |
143 | def __getitem__(self, idx):
144 | sample = self.dataset[idx]
145 | sample = self.formatting(sample)
146 | return sample
147 |
148 |
149 | class FineTuneDatasetForEncDec(FineTuneDataset):
150 | def __init__(self, dataset, tokenizer, truncate):
151 | super(FineTuneDatasetForEncDec, self).__init__(dataset, tokenizer)
152 | self.enc_len, self.dec_len = truncate['subj_len'] + 6, truncate['obj_len'] + 5
153 |
154 | def formatting(self, sample):
155 | s, r, o = sample
156 | s_tokens = self.tokenizer(s)
157 | r_tokens = self.tokenizer(r)
158 | o_tokens = self.tokenizer(o)
159 |
160 | _input = [self.bos]
161 | _input.extend(s_tokens + [self.sep])
162 | _input.extend(r_tokens + [self.sep])
163 | _input.extend([self.gen])
164 | _input.extend([self.pad] * (self.enc_len - len(_input)))
165 |
166 | _output = [self.bos] + o_tokens + [self.eos]
167 | _output.extend([self.pad] * (self.dec_len - len(_output)))
168 |
169 | _enc_input_ids = torch.tensor(_input)
170 | _enc_att_masks = torch.ones_like(_enc_input_ids) * (_enc_input_ids != self.pad)
171 | _dec_origin = torch.tensor(_output)
172 | _dec_input_ids = _dec_origin[:-1].clone().detach()
173 | _dec_att_masks = torch.ones_like(_dec_input_ids) * (_dec_input_ids != self.pad)
174 | _dec_output_ids = _dec_origin[1:].clone().detach()
175 |
176 | output = {'enc_input_ids': _enc_input_ids,
177 | 'enc_att_masks': _enc_att_masks,
178 | 'dec_input_ids': _dec_input_ids,
179 | 'dec_att_masks': _dec_att_masks,
180 | 'dec_label_ids': _dec_output_ids}
181 |
182 | return output
183 |
184 |
185 | class FineTuneDatasetForDec(FineTuneDataset):
186 | def __init__(self, dataset, tokenizer, truncate):
187 | super(FineTuneDatasetForDec, self).__init__(dataset, tokenizer)
188 | self.dec_len = truncate['subj_len'] + 6 + truncate['obj_len'] + 5
189 |
190 | def __len__(self):
191 | return len(self.dataset)
192 |
193 | def formatting(self, sample):
194 | s, r, o = sample
195 | s_tokens = self.tokenizer(s)
196 | r_tokens = self.tokenizer(r)
197 | o_tokens = self.tokenizer(o)
198 |
199 | _input = [self.bos]
200 | _input.extend(s_tokens + [self.sep])
201 | _input.extend(r_tokens + [self.sep, self.gen])
202 | _input.extend(o_tokens + [self.eos])
203 |
204 | _input.extend([self.pad] * (self.dec_len - len(_input)))
205 | _input_ids = _input[:-1]
206 | _label_ids = _input[1:]
207 |
208 | _input_ids = torch.tensor(_input_ids)
209 | _label_ids = torch.tensor(_label_ids)
210 |
211 | _att_masks = torch.ones_like(_input_ids) * (_input_ids != self.pad)
212 | output = {'input_ids': _input_ids,
213 | 'att_masks': _att_masks,
214 | 'label_ids': _label_ids}
215 |
216 | return output
217 |
218 |
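219 | # Formatting sketch (token names are the tokenizer's special ids): the enc-dec
220 | # dataset builds encoder input [bos] subj [sep] rel [sep] [gen] [pad]... and
221 | # derives decoder inputs/labels from [bos] obj [eos] [pad]... shifted by one,
222 | # while the dec-only dataset packs the whole triple into a single sequence.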
--------------------------------------------------------------------------------
/src/lr_schedule.py:
--------------------------------------------------------------------------------
1 |
2 | class WarmupLinearScheduler(object):
3 | def __init__(self, optimizer, max_lr, warmup_steps, total_steps):
4 | self.optimizer = optimizer
5 | self.lr = 0.0
6 | self.total_steps = total_steps
7 | self.warmup_steps = warmup_steps
8 | self.max_lr = max_lr
9 | self.warmup_increase = None
10 | self.lr_decay = None
11 | self._calculate()
12 | self._adapt_lr()
13 |
14 | def __call__(self, step):
15 | if step <= self.warmup_steps:
16 | self.lr += self.warmup_increase
17 | else:
18 | self.lr -= self.lr_decay
19 | assert not self.lr < 0
20 | self._adapt_lr()
21 |
22 | def _calculate(self):
23 | self.warmup_increase = self.max_lr / self.warmup_steps
24 | self.lr_decay = self.max_lr / (self.total_steps - self.warmup_steps)
25 |
26 | def _adapt_lr(self):
27 | for g in self.optimizer.param_groups:
28 | g['lr'] = self.lr
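29 |
30 | # Worked example (illustrative numbers): with max_lr=1e-5, warmup_steps=100 and
31 | # total_steps=1000, the lr grows by 1e-7 per step for the first 100 steps, then
32 | # decays by about 1.11e-8 per step toward 0 at step 1000. The scheduler is
33 | # called once per optimizer update, as scripts/finetune.py does.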
--------------------------------------------------------------------------------
/src/rec_adam.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import numpy as np
4 |
5 | import torch
6 | from torch.optim import Optimizer
7 |
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def anneal_function(function, step, k, t0, weight):
13 | if function == 'sigmoid':
14 | return float(1 / (1 + np.exp(-k * (step - t0)))) * weight
15 | elif function == 'linear':
16 | return min(1, step / t0) * weight
17 | elif function == 'constant':
18 | return weight
19 | else:
20 |         raise ValueError(f'unknown anneal function: {function}')
21 |
22 |
23 | class RecAdam(Optimizer):
24 | """ Implementation of RecAdam optimizer, a variant of Adam optimizer.
25 | Parameters:
26 | lr (float): learning rate. Default 1e-3.
27 |         betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999)
28 |         eps (float): Adam's epsilon. Default: 1e-6
29 |         weight_decay (float): Weight decay. Default: 0.0
30 |         correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g., as in the BERT TF repository). Default: True.
31 |         anneal_fun (str): a hyperparameter for the annealing function; decides the shape of the curve. Default: 'sigmoid'.
32 |         anneal_k (float): a hyperparameter for the annealing function; decides the slope of the curve. Choices: [0.05, 0.1, 0.2, 0.5, 1]
33 |         anneal_t0 (float): a hyperparameter for the annealing function; decides the midpoint of the curve. Choices: [100, 250, 500, 1000]
34 |         anneal_w (float): a hyperparameter for the annealing function; decides the scale of the curve. Default: 1.0.
35 |         pretrain_cof (float): the coefficient of the quadratic penalty. Default: 5000.0.
36 | pretrain_params (list of tensors): the corresponding group of params in the pretrained model.
37 | """
38 |
39 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True,
40 | anneal_fun='sigmoid', anneal_k=0, anneal_t0=0, anneal_w=1.0, pretrain_cof=5000.0, pretrain_params=None):
41 | if lr < 0.0:
42 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
43 |         if not 0.0 <= betas[0] < 1.0:
44 |             raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[0]))
45 |         if not 0.0 <= betas[1] < 1.0:
46 |             raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0)".format(betas[1]))
47 | if not 0.0 <= eps:
48 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
49 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias,
50 | anneal_fun=anneal_fun, anneal_k=anneal_k, anneal_t0=anneal_t0, anneal_w=anneal_w,
51 | pretrain_cof=pretrain_cof, pretrain_params=pretrain_params)
52 | super().__init__(params, defaults)
53 |
54 | def step(self, closure=None):
55 | """Performs a single optimization step.
56 | Arguments:
57 | closure (callable, optional): A closure that reevaluates the model
58 | and returns the loss.
59 | """
60 | anneal_lambda = -1
61 | loss = None
62 | if closure is not None:
63 | loss = closure()
64 | for group in self.param_groups:
65 | for p, pp in zip(group["params"], group["pretrain_params"]):
66 | if p.grad is None:
67 | continue
68 | grad = p.grad.data
69 | if grad.is_sparse:
70 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
71 |
72 | state = self.state[p]
73 |
74 | # State initialization
75 | if len(state) == 0:
76 | state["step"] = 0
77 | # Exponential moving average of gradient values
78 | state["exp_avg"] = torch.zeros_like(p.data)
79 | # Exponential moving average of squared gradient values
80 | state["exp_avg_sq"] = torch.zeros_like(p.data)
81 |
82 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
83 | beta1, beta2 = group["betas"]
84 |
85 | state["step"] += 1
86 |
87 | # Decay the first and second moment running average coefficient
88 | # In-place operations to update the averages at the same time
89 |                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
90 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
91 | denom = exp_avg_sq.sqrt().add_(group["eps"])
92 |
93 | step_size = group["lr"]
94 | if group["correct_bias"]:
95 | bias_correction1 = 1.0 - beta1 ** state["step"]
96 | bias_correction2 = 1.0 - beta2 ** state["step"]
97 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
98 |
99 | # With RecAdam method, the optimization objective is
100 | # Loss = lambda(t)*Loss_T + (1-lambda(t))*Loss_S
101 | # Loss = lambda(t)*Loss_T + (1-lambda(t))*\gamma/2*\sum((\theta_i-\theta_i^*)^2)
102 | if group['anneal_w'] > 0.0:
103 | # We calculate the lambda as the annealing function
104 | anneal_lambda = anneal_function(group['anneal_fun'], state["step"], group['anneal_k'],
105 | group['anneal_t0'], group['anneal_w'])
106 | assert anneal_lambda <= group['anneal_w']
107 | # The loss of the target task is multiplied by lambda(t)
108 |                     p.data.addcdiv_(exp_avg, denom, value=-step_size * anneal_lambda)
109 |                     # Add the quadratic penalty to simulate the pretraining tasks
110 |                     p.data.add_(p.data - pp.data, alpha=-group["lr"] * (group['anneal_w'] - anneal_lambda) * group["pretrain_cof"])
111 |                 else:
112 |                     p.data.addcdiv_(exp_avg, denom, value=-step_size)
113 |
114 | # Just adding the square of the weights to the loss function is *not*
115 | # the correct way of using L2 regularization/weight decay with Adam,
116 | # since that will interact with the m and v parameters in strange ways.
117 | #
118 | # Instead we want to decay the weights in a manner that doesn't interact
119 | # with the m/v parameters. This is equivalent to adding the square
120 | # of the weights to the loss with plain (non-momentum) SGD.
121 | # Add weight decay at the end (fixed version)
122 | if group["weight_decay"] > 0.0:
123 |                     p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
124 |
125 | return loss, anneal_lambda
--------------------------------------------------------------------------------
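To see what the annealing coefficient lambda(t) does, a small sketch (not from the repo) that evaluates `anneal_function` along a sigmoid schedule: lambda starts near 0, where the quadratic pull toward the pretrained weights dominates, and approaches `anneal_w`, where the update reduces to plain Adam on the target task.

```python
from src.rec_adam import anneal_function

for step in (0, 250, 500, 750, 1000):
    lam = anneal_function('sigmoid', step, k=0.1, t0=500, weight=1.0)
    print(step, round(lam, 4))
# 0 -> ~0.0, 500 -> 0.5, 1000 -> ~1.0
```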
/src/rec_adam_wrapper.py:
--------------------------------------------------------------------------------
1 | from src.rec_adam import RecAdam, anneal_function
2 | from transformers import get_linear_schedule_with_warmup
3 |
4 | ##########
5 | '''
6 | #params = [p for n, p in model.named_parameters()]
7 | pretrained_weights = anneal_targets
8 | new_model = model
9 | no_decay = ["bias", "layer_norm.weight"]
10 | for n, p in model.named_parameters():
11 | for nd in no_decay:
12 | if nd in n:
13 | print(n)
14 |
15 | for n, p in new_model.named_parameters():
16 | if not any(nd in n for nd in no_decay):
17 | print(n)
18 | if n in pretrained_weights:
19 | print(n)
20 | pretrained_weights.keys()
21 |
22 | any(nd in n for nd in no_decay)
23 | [p for n, p in new_model.named_parameters() if
24 | not any(nd in n for nd in no_decay) and n in pretrained_weights]
25 | '''
26 | ##########
27 |
28 | def get_adam_optimizer(new_model, pretrained_model, pretrained_weights, args, t_total):
29 | no_decay = ["bias", "LayerNorm.weight"]
30 |
31 | optimizer_grouped_parameters = [
32 | {
33 | "params": [p for n, p in new_model.named_parameters() if
34 | not any(nd in n for nd in no_decay) and n in pretrained_weights],
35 | "weight_decay": args.weight_decay,
36 | "anneal_w": args.recadam_anneal_w,
37 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
38 | not any(nd in p_n for nd in no_decay) and p_n in pretrained_weights]
39 | },
40 | {
41 | "params": [p for n, p in new_model.named_parameters() if
42 | not any(nd in n for nd in no_decay) and n not in pretrained_weights],
43 | "weight_decay": args.weight_decay,
44 | "anneal_w": 0.0,
45 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
46 | not any(nd in p_n for nd in no_decay) and p_n not in pretrained_weights]
47 | },
48 | {
49 | "params": [p for n, p in new_model.named_parameters() if
50 | any(nd in n for nd in no_decay) and n in pretrained_weights],
51 | "weight_decay": 0.0,
52 | "anneal_w": args.recadam_anneal_w,
53 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
54 |                                 any(nd in p_n for nd in no_decay) and p_n in pretrained_weights]
55 | },
56 | {
57 | "params": [p for n, p in new_model.named_parameters() if
58 | any(nd in n for nd in no_decay) and n not in pretrained_weights],
59 | "weight_decay": 0.0,
60 | "anneal_w": 0.0,
61 | "pretrain_params": [p_p for p_n, p_p in pretrained_model.named_parameters() if
62 | any(nd in p_n for nd in no_decay) and p_n not in pretrained_weights]
63 | }
64 | ]
65 |
66 | optimizer = RecAdam(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon,
67 | anneal_fun=args.recadam_anneal_fun,
68 | anneal_k=args.recadam_anneal_k, anneal_t0=args.recadam_anneal_t0,
69 | pretrain_cof=args.recadam_pretrain_cof)
70 |
71 | scheduler = get_linear_schedule_with_warmup(
72 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
73 | )
74 |
75 | return optimizer, scheduler
76 |
--------------------------------------------------------------------------------
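`get_adam_optimizer` reads a handful of hyperparameters from `args`; the attribute names below are exactly those the function accesses, while the values are illustrative assumptions.

```python
from argparse import Namespace

args = Namespace(
    weight_decay=0.01,
    learning_rate=1e-5,
    adam_epsilon=1e-8,
    warmup_steps=500,
    recadam_anneal_w=1.0,        # annealing only for params shared with the pretrained model
    recadam_anneal_fun='sigmoid',
    recadam_anneal_k=0.1,
    recadam_anneal_t0=1000,
    recadam_pretrain_cof=5000.0,
)
# optimizer, scheduler = get_adam_optimizer(
#     new_model, pretrained_model, pretrained_weights, args, t_total=10000)
```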
/src/sampler_utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pickle as pkl
3 | from tqdm import tqdm
4 | from copy import deepcopy
5 | import torch
6 | import os
7 |
8 | def load_raw_tsv(raw_path, tokenizer, logger, truncate):
9 |     data = pd.read_csv(raw_path, delimiter='\t', header=None)
10 |     BLK_TOKEN = tokenizer.blk_token
11 |     triples = list()
12 |     subj_len = truncate['subj_len']
13 |     obj_len = truncate['obj_len']
14 |
15 |     cached_path = raw_path[:-4] + '_cached_{}_{}.torch-pkl'.format(subj_len, obj_len)
16 |     if os.path.exists(cached_path):
17 |         logger.info('Load from {}'.format(cached_path))
18 |         return torch.load(cached_path)
19 |
20 |     for row in tqdm(data.iloc, desc='loading_raw_file...', ncols=70):
21 |
22 | s, r, o = row
23 |
24 | if pd.isna(s) or pd.isna(r) or pd.isna(o):
25 | continue
26 | s = s.replace('___', BLK_TOKEN)
27 | r = '<{}>'.format(r)
28 |         # NaN objects are already filtered out by the pd.isna check above,
29 |         # so no separate o != o test is needed here.
30 | o = o.replace('___', BLK_TOKEN)
31 |
32 | if len(tokenizer(s)) > subj_len or len(tokenizer(o)) > obj_len:
33 | continue
34 | triples.append([s, r, o])
35 |
36 | logger.info('Total loaded raw samples : {}'.format(len(triples)))
37 |
38 | torch.save(triples, cached_path)
39 |
40 | return triples
41 |
42 |
43 | def load_atomic_sim_pkl(raw_path, tokenizer, logger):
44 |     logger.info('Load similarity matrix from {}'.format(raw_path))
45 |     with open(raw_path, 'rb') as f:
46 |         data = pkl.load(f)
47 |     blk_token = tokenizer.blk_token
48 |     output = deepcopy(data)
49 |
50 |     for s, i in data['s2i'].items():
51 | s = s.replace('___', blk_token)
52 | output['s2i'][s] = i
53 | output['i2s'][i] = s
54 | return output
55 |
--------------------------------------------------------------------------------
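`load_raw_tsv` expects a headerless, three-column TSV of (subject, relation, object) triples, with `___` marking blanks that get mapped to the tokenizer's `blk_token`. A toy file for illustration (the relation names here are examples in the ATOMIC style):

```python
import csv

triples = [
    ('PersonX buys ___', 'xWant', 'to use it'),
    ('PersonX goes to the gym', 'xEffect', 'gets fit'),
]
with open('toy_triples.tsv', 'w', newline='') as f:
    csv.writer(f, delimiter='\t').writerows(triples)
```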
/src/tokenizer.py:
--------------------------------------------------------------------------------
1 | def adapt_commonsense_tokenizer(tokenizer, config):
2 | tokenizer.add_tokens([config['additional_tokens'][key] for key in config['additional_tokens']])
3 | tokenizer.add_tokens([tokens for tokens in config['relation_tokens']])
4 | tokenizer.add_special_tokens(config['special_tokens'])
5 | return tokenizer
6 |
7 | class BaseCSKGTokenizer:
8 | def __init__(self, tokenizer, config):
9 | self.tokenizer = tokenizer
10 | self.config = config
11 |
12 | self.tokenizer.add_tokens([self.config['additional_tokens'][key] for key in self.config['additional_tokens']])
13 | self.tokenizer.add_tokens([tokens for tokens in self.config['relation_tokens']])
14 | self.tokenizer.add_special_tokens(self.config['special_tokens'])
15 |
16 | # Additional Token Ids
17 | self.gen_token = self.config['additional_tokens']['gen_token']
18 | self.gen_token_id = self.tokenize(self.gen_token)[0]
19 | self.blk_token = self.config['additional_tokens']['blk_token']
20 | self.blk_token_id = self.tokenize(self.blk_token)[0]
21 | self.none_token = self.config['additional_tokens']['none_token']
22 | self.none_token_id = self.tokenize(self.none_token)[0]
23 | self.gen_task_token = self.config['additional_tokens']['gen_task_token']
24 | self.gen_task_token_id = self.tokenize(self.gen_task_token)[0]
25 |
26 | # Contrastive Task Tokens
27 | self.con_task_token = self.config['additional_tokens']['con_task_token']
28 | self.con_task_token_id = self.tokenize(self.con_task_token)[0]
29 | self.con_task_token_s = self.config['additional_tokens']['con_task_token_for_s']
30 | self.con_task_token_s_id = self.tokenize(self.con_task_token_s)[0]
31 | self.con_task_token_ro = self.config['additional_tokens']['con_task_token_for_ro']
32 | self.con_task_token_ro_id = self.tokenize(self.con_task_token_ro)[0]
33 |
34 | # Reconstruct Task Tokens
35 | self.rec_task_token = self.config['additional_tokens']['rec_task_token']
36 | self.rec_task_token_id = self.tokenize(self.rec_task_token)[0]
37 |
38 |         # Mask Generate Task Tokens (shares the reconstruct task token)
39 | self.mask_gen_task_token = self.config['additional_tokens']['rec_task_token']
40 | self.mask_gen_task_token_id = self.tokenize(self.rec_task_token)[0]
41 |
42 | # Special Token Ids
43 | self.bos_token = self.tokenizer.bos_token
44 | self.bos_token_id = self.tokenizer.bos_token_id
45 | self.eos_token = self.tokenizer.eos_token
46 | self.eos_token_id = self.tokenizer.eos_token_id
47 | self.pad_token = self.tokenizer.pad_token
48 | self.pad_token_id = self.tokenizer.pad_token_id
49 | self.mask_token = self.tokenizer.mask_token
50 | self.mask_token_id = self.tokenizer.mask_token_id
51 | self.sep_token = self.tokenizer.sep_token
52 | self.sep_token_id = self.tokenizer.sep_token_id
53 |
54 | def tokenize(self, sequence):
55 | raise NotImplementedError
56 |
57 | def __call__(self, sequence):
58 | return self.tokenize(sequence)
59 |
60 | def __len__(self):
61 | return len(self.tokenizer)
62 |
63 |
64 | class BartCSKGTokenizer(BaseCSKGTokenizer):
65 | def __init__(self, tokenizer, config):
66 | super(BartCSKGTokenizer, self).__init__(tokenizer, config)
67 |
68 | def tokenize(self, seq):
69 | assert type(seq) is str
70 | return self.tokenizer(seq)['input_ids'][1:-1]
71 |
72 |
73 | class Gpt2CSKGTokenizer(BaseCSKGTokenizer):
74 | def __init__(self, tokenizer, config):
75 | super(Gpt2CSKGTokenizer, self).__init__(tokenizer, config)
76 |
77 | def tokenize(self, seq):
78 | assert type(seq) is str
79 | return self.tokenizer(seq)['input_ids']
80 |
81 |
--------------------------------------------------------------------------------
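A sketch of the config dict that `BaseCSKGTokenizer` reads. The key names are taken from the class body; every token string below is an assumption for illustration (the real values live in `config/tokenizer_config.yml`).

```python
from transformers import BartTokenizer
from src.tokenizer import BartCSKGTokenizer

config = {
    'additional_tokens': {
        'gen_token': '[GEN]',            # all token strings here are illustrative
        'blk_token': '[BLANK]',
        'none_token': '[NONE]',
        'gen_task_token': '[GEN_TASK]',
        'con_task_token': '[CON_TASK]',
        'con_task_token_for_s': '[CON_S]',
        'con_task_token_for_ro': '[CON_RO]',
        'rec_task_token': '[REC_TASK]',
    },
    'relation_tokens': ['<xWant>', '<xEffect>'],
    'special_tokens': {},  # BART already defines bos/eos/pad/sep/mask
}

tokenizer = BartCSKGTokenizer(BartTokenizer.from_pretrained('facebook/bart-base'), config)
print(tokenizer('PersonX goes to the gym'))  # ids without the surrounding bos/eos
```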
/src/train_utils.py:
--------------------------------------------------------------------------------
1 | from src.feed_model import *
2 | from torch.utils.data import DataLoader
3 |
4 |
5 | def get_data_feeder_from_sampler(sampler, options, tokenizer, task_cfg, batch_size, iter_per_epoch):
6 | data_feeder = dict()
7 | enc_len = task_cfg['max_enc'] + 5
8 | dec_len = task_cfg['max_dec'] + 5
9 | for task_key, task_sampler in sampler.items():
10 | for option_key, val in options['common'].items():
11 | options[task_key][option_key] = val
12 |
13 |
14 | if task_key == 'con_task':
15 | if task_cfg['con_task']['method'] == 'cluster':
16 | group_num = task_cfg['con_task']['cluster_contrast']['group_num']
17 | elif task_cfg['con_task']['method'] == 'naive':
18 | group_num = batch_size
19 | else:
20 | raise NotImplementedError
21 |
22 | task_adaptor = ConTaskAdaptor(
23 | options[task_key], task_sampler, tokenizer, enc_len, dec_len, batch_size, iter_per_epoch, group_num)
24 | data_feeder[task_key] = task_adaptor
25 |
26 | elif task_key == 'rec_inf_shu_task':
27 | task_adaptor = RecInfShuTaskAdaptor(
28 | options[task_key], task_sampler, tokenizer, enc_len, dec_len, batch_size, iter_per_epoch * batch_size)
29 | data_feeder[task_key] = task_adaptor
30 |
31 | elif task_key == 'mask_gen_task':
32 | task_adaptor = MaskGenTaskAdaptor(
33 | options[task_key], task_sampler, tokenizer, enc_len, dec_len, batch_size, iter_per_epoch * batch_size)
34 | data_feeder[task_key] = task_adaptor
35 |
36 | return data_feeder
37 |
38 |
39 | class WrappedLoader:
40 |     def __init__(self, dataset, batch_size):
41 |         self.loader = DataLoader(dataset, batch_size=batch_size, drop_last=False, shuffle=True, num_workers=10)
42 |
43 |     def __getitem__(self, idx):
44 |         # idx is ignored; every access returns a fresh generator over the loader
45 |         for data in self.loader:
46 |             yield data
47 |
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class RecTaskTrainerForEncDec:
5 | def __init__(self, model, device, loss_func, tb_period, writer, _type='train'):
6 | self.model = model
7 | self.device = device
8 | self.loss_func = loss_func
9 | self.tb_period = tb_period
10 | self.writer = writer
11 | self._type = _type
12 |
13 | def train(self, task_data, global_step, save=False):
14 | enc_input_ids = task_data['enc_input_ids'].to(self.device)
15 | enc_att_mask = task_data['enc_att_mask'].to(self.device)
16 | dec_input_ids = task_data['dec_input_ids'].to(self.device)
17 | dec_att_mask = task_data['dec_att_mask'].to(self.device)
18 | dec_label_ids = task_data['dec_label_ids'].to(self.device)
19 |
20 | lm_logits = self.model.forward_conditional_gen(enc_input_ids, enc_att_mask, dec_input_ids, dec_att_mask)
21 |
22 | B, S, V = lm_logits.shape
23 |
24 | loss = self.loss_func(lm_logits.view(B * S, V), dec_label_ids.view(-1)).view(B, S)
25 | loss = torch.sum(torch.mul(loss, dec_att_mask), -1) / torch.sum(dec_att_mask, -1)
26 | total_loss = torch.mean(loss)
27 |
28 | if (global_step % self.tb_period == 0) and save is True:
29 | self.writer.add_scalar('{}_rec_task/total_loss'.format(self._type), total_loss.item(), global_step)
30 | self.writer.flush()
31 |
32 | return total_loss
33 |
34 |
35 | class MaskGenTaskTrainerForDec:
36 | def __init__(self, model, device, loss_func, tb_period, writer, _type='train'):
37 | self.model = model
38 | self.device = device
39 | self.loss_func = loss_func
40 | self.tb_period = tb_period
41 | self.writer = writer
42 | self._type = _type
43 |
44 | def train(self, task_data, global_step, save=False):
45 | input_ids = task_data['input_ids'].to(self.device)
46 | att_mask = task_data['att_mask'].to(self.device)
47 | label_ids = task_data['label_ids'].to(self.device)
48 |
49 | lm_logits = self.model.forward_conditional_gen(input_ids, att_mask)
50 |
51 | B, S, V = lm_logits.shape
52 |
53 | loss = self.loss_func(lm_logits.view(B * S, V), label_ids.view(-1)).view(B, S)
54 | loss = torch.sum(torch.mul(loss, att_mask), -1) / torch.sum(att_mask, -1)
55 | total_loss = torch.mean(loss)
56 |
57 | if (global_step % self.tb_period == 0) and save is True:
58 | self.writer.add_scalar('{}_mask_gen_task/total_loss'.format(self._type), total_loss.item(), global_step)
59 | self.writer.flush()
60 |
61 | return total_loss
62 |
63 |
64 |
65 | class ConTaskTrainerBase:
66 | def __init__(self, model, device, distance_model, loss_func, tb_period, writer, share_proj_layer, args, _type='train'):
67 | self.model = model
68 | self.device = device
69 | self.distance_model = distance_model
70 | self.loss_func = loss_func
71 | self.tb_period = tb_period
72 | self.writer = writer
73 | self.share_s, self.share_ro = share_proj_layer
74 | self._type = _type
75 |
76 | class ConTaskTrainerForEncDec(ConTaskTrainerBase):
77 | def train(self, task_data, global_step, save=False):
78 | enc_input_ids_s = task_data['enc_input_ids_s'].to(self.device)
79 | enc_input_ids_ro = task_data['enc_input_ids_ro'].to(self.device)
80 | enc_att_mask_s = task_data['enc_att_mask_s'].to(self.device)
81 | enc_att_mask_ro = task_data['enc_att_mask_ro'].to(self.device)
82 | dec_input_ids_s = task_data['dec_input_ids_s'].to(self.device)
83 | dec_input_ids_ro = task_data['dec_input_ids_ro'].to(self.device)
84 | dec_att_mask_s = task_data['dec_att_mask_s'].to(self.device)
85 | dec_att_mask_ro = task_data['dec_att_mask_ro'].to(self.device)
86 | pos_sample = task_data['pos_samples'].to(self.device)
87 |
88 | s_repre = self.model.forward_latent_feature(
89 | enc_input_ids_s, enc_att_mask_s, dec_input_ids_s, dec_att_mask_s)#, None, self.share_s)
90 | ro_repre = self.model.forward_latent_feature(
91 | enc_input_ids_ro, enc_att_mask_ro, dec_input_ids_ro, dec_att_mask_ro)#, None, self.share_ro)
92 |
93 | dist_s_ro = self.distance_model(s_repre, ro_repre)
94 | dist_s_s = self.distance_model(s_repre, s_repre)
95 | dist_ro_ro = self.distance_model(ro_repre, ro_repre)
96 |
97 | sro_loss, sro_pos_dists, sro_neg_dists = self.loss_func.get_loss(dist_s_ro, pos_sample)
98 | ss_loss, ss_pos_dists, ss_neg_dists = self.loss_func.get_loss(dist_s_s, pos_sample)
99 | roro_loss, roro_pos_dists, roro_neg_dists = self.loss_func.get_loss(dist_ro_ro, pos_sample)
100 |
101 | loss = sro_loss + ss_loss + roro_loss
102 | pos_dists = sro_pos_dists + ss_pos_dists + roro_pos_dists
103 | neg_dists = sro_neg_dists + ss_neg_dists + roro_neg_dists
104 |
105 | if (global_step % self.tb_period == 0) and save is True:
106 | self.writer.add_scalar('{}_con_task/total_loss'.format(self._type), float(loss), global_step)
107 | self.writer.add_scalar('{}_con_task/s-ro_loss'.format(self._type), float(sro_loss), global_step)
108 | self.writer.add_scalar('{}_con_task/s-s_loss'.format(self._type), float(ss_loss), global_step)
109 | self.writer.add_scalar('{}_con_task/ro-ro_loss'.format(self._type), float(roro_loss), global_step)
110 | self.writer.add_scalar('{}_con_task/total_pos_dists'.format(self._type), float(pos_dists), global_step)
111 | self.writer.add_scalar('{}_con_task/total_neg_dists'.format(self._type), float(neg_dists), global_step)
112 | self.writer.add_scalar('{}_con_task/total_pos-neg'.format(self._type), float(pos_dists - neg_dists), global_step)
113 | self.writer.add_scalar('{}_con_task/s-ro_pos-neg'.format(self._type), float(sro_pos_dists - sro_neg_dists), global_step)
114 | self.writer.add_scalar('{}_con_task/s-s_pos-neg'.format(self._type), float(ss_pos_dists - ss_neg_dists), global_step)
115 | self.writer.add_scalar('{}_con_task/ro-ro_pos-neg'.format(self._type), float(roro_pos_dists - roro_neg_dists), global_step)
116 | self.writer.flush()
117 |
118 | return loss
119 |
120 |
121 | class ConTaskTrainerForDec(ConTaskTrainerBase):
122 | def train(self, task_data, global_step, save=False):
123 | input_ids_s = task_data['input_ids_s'].to(self.device)
124 | input_ids_ro = task_data['input_ids_ro'].to(self.device)
125 | att_mask_s = task_data['att_mask_s'].to(self.device)
126 | att_mask_ro = task_data['att_mask_ro'].to(self.device)
127 | pos_sample = task_data['pos_samples'].to(self.device)
128 |
129 | s_repre = self.model.forward_latent_feature(
130 | input_ids_s, att_mask_s)
131 | ro_repre = self.model.forward_latent_feature(
132 | input_ids_ro, att_mask_ro)
133 |
134 | dist_s_ro = self.distance_model(s_repre, ro_repre)
135 | dist_s_s = self.distance_model(s_repre, s_repre)
136 | dist_ro_ro = self.distance_model(ro_repre, ro_repre)
137 |
138 | sro_loss, sro_pos_dists, sro_neg_dists = self.loss_func.get_loss(dist_s_ro, pos_sample)
139 | ss_loss, ss_pos_dists, ss_neg_dists = self.loss_func.get_loss(dist_s_s, pos_sample)
140 | roro_loss, roro_pos_dists, roro_neg_dists = self.loss_func.get_loss(dist_ro_ro, pos_sample)
141 |
142 | loss = sro_loss + ss_loss + roro_loss
143 | pos_dists = sro_pos_dists + ss_pos_dists + roro_pos_dists
144 | neg_dists = sro_neg_dists + ss_neg_dists + roro_neg_dists
145 |
146 | if (global_step % self.tb_period == 0) and save is True:
147 | self.writer.add_scalar('{}_con_task/total_loss'.format(self._type), float(loss), global_step)
148 | self.writer.add_scalar('{}_con_task/s-ro_loss'.format(self._type), float(sro_loss), global_step)
149 | self.writer.add_scalar('{}_con_task/s-s_loss'.format(self._type), float(ss_loss), global_step)
150 | self.writer.add_scalar('{}_con_task/ro-ro_loss'.format(self._type), float(roro_loss), global_step)
151 | self.writer.add_scalar('{}_con_task/total_pos_dists'.format(self._type), float(pos_dists), global_step)
152 | self.writer.add_scalar('{}_con_task/total_neg_dists'.format(self._type), float(neg_dists), global_step)
153 | self.writer.add_scalar('{}_con_task/total_pos-neg'.format(self._type), float(pos_dists - neg_dists), global_step)
154 | self.writer.add_scalar('{}_con_task/s-ro_pos-neg'.format(self._type), float(sro_pos_dists - sro_neg_dists), global_step)
155 | self.writer.add_scalar('{}_con_task/s-s_pos-neg'.format(self._type), float(ss_pos_dists - ss_neg_dists), global_step)
156 | self.writer.add_scalar('{}_con_task/ro-ro_pos-neg'.format(self._type), float(roro_pos_dists - roro_neg_dists), global_step)
157 | self.writer.flush()
158 |
159 | return loss
160 |
161 | def mean_list(items):
162 | if type(items[0]) not in (list, tuple):
163 | return sum(items) / len(items)
164 |
165 | vals = 0
166 | nums = 0
167 | for val, num in items:
168 | vals += val
169 | nums += num
170 |
171 | return vals/nums
--------------------------------------------------------------------------------
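The trainers above all share the same per-sequence masked mean of token-level cross entropy: losses are zeroed on padding via the attention mask, summed per sequence, then divided by the number of real tokens. A self-contained sketch of that reduction (with a `reduction='none'` loss, as the code assumes for `loss_func`):

```python
import torch

loss_func = torch.nn.CrossEntropyLoss(reduction='none')
B, S, V = 2, 4, 8
lm_logits = torch.randn(B, S, V)
label_ids = torch.randint(0, V, (B, S))
att_mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]])  # trailing zeros mark padding

loss = loss_func(lm_logits.view(B * S, V), label_ids.view(-1)).view(B, S)
loss = torch.sum(torch.mul(loss, att_mask), -1) / torch.sum(att_mask, -1)
total_loss = torch.mean(loss)  # every sequence contributes equally, padding ignored
```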
/src/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import yaml
3 |
4 |
5 | def load_logger(log_dir, log_level):
6 | logger = logging.getLogger('CSKG')
7 | if log_level == 'INFO':
8 | lv = logging.INFO
9 | elif log_level == 'ERROR':
10 | lv = logging.ERROR
11 | elif log_level == 'DEBUG':
12 | lv = logging.DEBUG
13 | else:
14 | raise NotImplementedError
15 | logger.setLevel(lv)
16 |
17 | formatter = logging.Formatter('%(asctime)s [%(name)s] [%(levelname)s] :: %(message)s')
18 | stream_handler = logging.StreamHandler()
19 | stream_handler.setFormatter(formatter)
20 | file_handler = logging.FileHandler(log_dir)
21 | file_handler.setFormatter(formatter)
22 |
23 | logger.addHandler(stream_handler)
24 | logger.addHandler(file_handler)
25 |
26 | return logger
27 |
28 | def load_yaml(f):
29 | if type(f) is str:
30 | with open(f, 'r') as fp:
31 | config = yaml.load(fp, Loader=yaml.FullLoader)
32 | else:
33 | raise NotImplementedError
34 |
35 | return config
--------------------------------------------------------------------------------
/system_eval/automatic_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import json
5 | sys.path.append(os.path.join(os.getcwd(), 'system_eval'))
6 |
7 | import numpy as np
8 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
9 | from utils import read_jsonl, remove_prefix
10 | from evaluation.eval import QGEvalCap
11 | from tabulate import tabulate
12 | from tqdm import tqdm
13 |
14 | os.environ['CUDA_VISIBLE_DEVICES'] = '3'
15 |
16 | def get_refs_preds(l, type=1):
17 | if type==1:
18 | tails = l["fact"]["tails"]
19 | head = l["fact"]["head"]
20 | prompt = l["prompt"]
21 | generations = l["generations"]
22 | gens = [remove_prefix(g, prompt).strip() for g in generations]
23 | if type==2:
24 | tails = l["refs"]
25 | head = l["input"]
26 | gens = [l["output"]]
27 | if type==3:
28 | tails = l["fact"]["tails"]
29 | head = l["fact"]["head"]
30 | gens = l["generations"]
31 | return gens, tails, head
32 |
33 | def get2(l):
34 | return list(zip(*l))[1]
35 |
36 |
37 | def topk_eval(model_name, data, data_type, k):
38 | topk_gts = {}
39 | topk_res = {}
40 | topk_exact_match = []
41 | topk_exact_match_not_none = []
42 | topk_bleu_score = []
43 |
44 | topk_is_head = []
45 | func = SmoothingFunction()
46 | # for i, l in enumerate(data):
47 | for i, l in tqdm(enumerate(data)):
48 | (gens, tails, head) = get_refs_preds(l, type=data_type)
49 |         gens = [g.strip() for g in gens]
50 |         tails = [t.strip() for t in tails]
51 |         head = head.strip()
52 |         sentence_tails = [t.lower() for t in tails]
53 |         split_tails = [t.lower().split() for t in tails]
54 |
55 | for (j, g) in enumerate(gens[:k]):
56 | key = str(i) + "_" + str(j)
57 | topk_gts[key] = sentence_tails
58 | topk_res[key] = [g.lower()]
59 |
60 | b = sentence_bleu(split_tails, g.lower().split(), weights=(0.5, 0.5), smoothing_function=func.method1)
61 | topk_bleu_score.append((l, b))
62 | if g in sentence_tails:
63 | topk_exact_match.append((l, 1))
64 | if g != "none":
65 | topk_exact_match_not_none.append((l, 1))
66 | else:
67 | topk_exact_match.append((l, 0))
68 | if g != "none":
69 | topk_exact_match_not_none.append((l, 0))
70 | if g == head:
71 | topk_is_head.append((l, 1))
72 | else:
73 | topk_is_head.append((l, 0))
74 |
75 | print("---------------TOP K={}---------------".format(k))
76 | #print(np.mean(get2(topk_exact_match)))
77 | #print(np.mean(get2(topk_exact_match_not_none)))
78 | print(np.mean(get2(topk_bleu_score)))
79 | QGEval = QGEvalCap(model_name, topk_gts, topk_res)
80 | score, scores = QGEval.evaluate()
81 | scores["Exact_match"] = np.mean(get2(topk_exact_match))
82 | #scores["TailIsHead"] = np.mean(get2(topk_is_head))
83 | return score, topk_bleu_score, scores
84 |
85 |
86 | def eval(data_file, data_type, model_name):
87 | data = read_jsonl(data_file)
88 | return topk_eval(model_name, data, data_type, k=1)
89 |
90 | def toRow(name, results, columns):
91 | return [name] + [format(float(results[c]), '#.3f') for c in columns]
92 |
93 | parser = argparse.ArgumentParser()
94 |
95 | parser.add_argument('--mode', type=str, default='pycharm')
96 | parser.add_argument('--file_dir', type=str, default='/mnt/data/user8/solar-commonsense_inference/log_fntn')
97 | parser.add_argument('--dataset_type', type=str, default='atomic')
98 | parser.add_argument('--model_name', type=str, default='bart')
99 | parser.add_argument('--model_size', type=str, default='large')
100 | parser.add_argument('--exp_type', type=str, default='baseline')
101 | parser.add_argument('--target', type=str, default=None)
102 |
103 | args = parser.parse_args()
104 |
105 | target_list = list()
106 |
107 | target_path = f'{args.file_dir}/{args.dataset_type}/{args.model_name}-{args.model_size}_{args.exp_type}'
108 | if args.target is None:
109 | target_list = os.listdir(target_path)
110 | else:
111 | target_list = [args.target]
112 |
113 | target_list.sort()
114 |
115 | decode_type = ['greedy']
116 | print(target_list)
117 |
118 | #target_list = target_list[:7]
119 | for target_file in target_list:
120 | for decode in decode_type:
121 | input_file = f'{target_path}/{target_file}/{decode}_gen_examples.json'
122 | output_file = f'{target_path}/{target_file}/eval/{decode}_results.txt'
123 | results_per_sample_file = f'{target_path}/{target_file}/eval/{decode}_results_per_sample.pkl'
124 | try:
125 | with open(input_file, 'r') as f:
126 | input_file = json.load(f)
127 |         except (IOError, json.JSONDecodeError):
128 | continue
129 | # Eval
130 | print('TEST target : {}'.format(input_file['info']['ckpt']))
131 | print(f'Decoded Type {decode}')
132 | gen_data = input_file['content']
133 |
134 | scores, topk_bleu_score, score_list = topk_eval(model_name='BART-ATOMIC2020', data=gen_data, data_type=2, k=1)
135 |
136 | results_per_sample = list()
137 |
138 | for idx, sample in enumerate(gen_data):
139 | sample_result = dict()
140 | sample_result.update(sample)
141 |
142 | for key in score_list:
143 | if type(score_list[key]) is not list:
144 | continue
145 | val = score_list[key][idx]
146 | sample_result[key] = val
147 |
148 | results_per_sample.append(sample_result)
149 |
150 | print(scores)
151 | for key in scores:
152 | print(round(float(scores[key]) * 100, 4), end='\t')
153 |
154 | with open(output_file, 'w') as f:
155 | for key, item in scores.items():
156 | f.write('{} : {}\n'.format(key, item))
157 |
158 | print('\n\n')
159 |
160 | import pickle as pkl
161 | with open(results_per_sample_file, 'wb') as f:
162 | pkl.dump(results_per_sample, f)
--------------------------------------------------------------------------------
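For reference, the record shapes `get_refs_preds` expects for each `data_type`; the field names are read straight from the function, the values are illustrative.

```python
record_type_1 = {'fact': {'head': 'PersonX goes to the gym', 'tails': ['gets fit']},
                 'prompt': 'PersonX goes to the gym xEffect',
                 'generations': ['PersonX goes to the gym xEffect gets fit']}

record_type_2 = {'input': 'PersonX goes to the gym xEffect',   # the type this script evaluates
                 'refs': ['gets fit', 'becomes tired'],
                 'output': 'gets fit'}

record_type_3 = {'fact': {'head': 'PersonX goes to the gym', 'tails': ['gets fit']},
                 'generations': ['gets fit']}
```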
/system_eval/evaluation/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Manish Joshi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/system_eval/evaluation/README.md:
--------------------------------------------------------------------------------
1 | # qgeval
2 | Calculates BLEU, METEOR, ROUGE, and CIDEr scores
3 |
4 | Usage:
5 |
6 | ```python anli_evaluation/eval.py --gen_file GENERATIONS_FILE --keys MODEL_KEYS[comma-separated list] --results_file RESULTS_FILE```
7 |
--------------------------------------------------------------------------------
/system_eval/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongho94/solar-framework_commonsense-inference/50e99ba0b0f5ae2315c72f5f35f6d385499283e1/system_eval/evaluation/__init__.py
--------------------------------------------------------------------------------
/system_eval/evaluation/bert_score/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongho94/solar-framework_commonsense-inference/50e99ba0b0f5ae2315c72f5f35f6d385499283e1/system_eval/evaluation/bert_score/__init__.py
--------------------------------------------------------------------------------
/system_eval/evaluation/bert_score/bert_score.py:
--------------------------------------------------------------------------------
1 | from bert_score import score
2 | # Code for BertScore reused from original implementation: https://github.com/Tiiiger/bert_score
3 |
4 | class BertScore:
5 | def __init__(self):
6 | self._hypo_for_image = {}
7 | self.ref_for_image = {}
8 |
9 | def compute_score(self, gts, res):
10 |
11 | assert(gts.keys() == res.keys())
12 | imgIds = gts.keys()
13 |
14 | hyp_input = []
15 | ref_input = []
16 | same_indices = []
17 | for id in imgIds:
18 | hypo = res[id]
19 | ref = gts[id]
20 |
21 | # Sanity check.
22 | assert(type(hypo) is list)
23 | assert(len(hypo) == 1)
24 | assert(type(ref) is list)
25 | assert(len(ref) >= 1)
26 |
27 | hyp_input += [hypo[0]] * len(ref)
28 | ref_input += ref
29 | same_indices.append(len(ref_input))
30 |
31 | p, r, f_scores = score(hyp_input, ref_input, model_type="bert-base-uncased")
32 |
33 | prev_idx = 0
34 | aggreg_f1_scores = []
35 | for idx in same_indices:
36 | aggreg_f1_scores.append(f_scores[prev_idx: idx].mean().cpu().item())
37 | prev_idx = idx
38 |
39 | return sum(aggreg_f1_scores)/len(aggreg_f1_scores), aggreg_f1_scores
40 |
41 | def method(self):
42 | return "Bert Score"
43 |
--------------------------------------------------------------------------------
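An assumed usage sketch for the wrapper above: both dicts map an example id to a list, with exactly one hypothesis per id and at least one reference (the sanity-check asserts in `compute_score` enforce this).

```python
from evaluation.bert_score.bert_score import BertScore

gts = {'0': ['gets fit', 'becomes tired'],  # one or more references per id
       '1': ['to use it']}
res = {'0': ['gets fit'],                   # exactly one hypothesis per id
       '1': ['to try it out']}

mean_f1, per_example_f1 = BertScore().compute_score(gts, res)
```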
/system_eval/evaluation/bert_score/score.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import torch
5 | from collections import defaultdict
6 | from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
7 | import matplotlib
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 |
11 | from .utils import get_idf_dict, bert_cos_score_idf,\
12 | get_bert_embedding, bert_types
13 |
14 | __all__ = ['score', 'plot_example']
15 |
16 | def score(cands, refs, bert="bert-base-multilingual-cased",
17 | num_layers=8, verbose=False, no_idf=False, batch_size=64):
18 | """
19 | BERTScore metric.
20 | Args:
21 | - :param: `cands` (list of str): candidate sentences
22 | - :param: `refs` (list of str): reference sentences
23 | - :param: `bert` (str): bert specification
24 | - :param: `num_layers` (int): the layer of representation to use
25 | - :param: `verbose` (bool): turn on intermediate status update
26 | - :param: `no_idf` (bool): do not use idf weighting
27 | - :param: `batch_size` (int): bert score processing batch size
28 | """
29 | assert len(cands) == len(refs)
30 | assert bert in bert_types
31 |
32 | tokenizer = BertTokenizer.from_pretrained(bert)
33 | model = BertModel.from_pretrained(bert)
34 | model.eval()
35 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
36 | model.to(device)
37 |
38 | # drop unused layers
39 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]])
40 |
41 | if no_idf:
42 | idf_dict = defaultdict(lambda: 1.)
43 | # set idf for [SEP] and [CLS] to 0
44 | idf_dict[101] = 0
45 | idf_dict[102] = 0
46 | else:
47 | if verbose:
48 | print('preparing IDF dict...')
49 | start = time.perf_counter()
50 | idf_dict = get_idf_dict(refs, tokenizer)
51 | if verbose:
52 | print('done in {:.2f} seconds'.format(time.perf_counter() - start))
53 |
54 | if verbose:
55 | print('calculating scores...')
56 | start = time.perf_counter()
57 | all_preds = bert_cos_score_idf(model, refs, cands, tokenizer, idf_dict,
58 | verbose=verbose, device=device, batch_size=batch_size)
59 |
60 | P = all_preds[:, 0].cpu()
61 | R = all_preds[:, 1].cpu()
62 | F1 = all_preds[:, 2].cpu()
63 | if verbose:
64 | print('done in {:.2f} seconds'.format(time.perf_counter() - start))
65 |
66 | return P, R, F1
67 |
68 | def plot_example(h, r, verbose=False, bert="bert-base-multilingual-cased",
69 | num_layers=8, fname=''):
70 | """
71 | BERTScore metric.
72 | Args:
73 | - :param: `h` (str): a candidate sentence
74 | - :param: `r` (str): a reference sentence
75 | - :param: `verbose` (bool): turn on intermediate status update
76 | - :param: `bert` (str): bert specification
77 | - :param: `num_layers` (int): the layer of representation to use
78 | """
79 | assert bert in bert_types
80 |
81 | if verbose:
82 | print('loading BERT model...')
83 | tokenizer = BertTokenizer.from_pretrained(bert)
84 | model = BertModel.from_pretrained(bert)
85 | model.eval()
86 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
87 | model.to(device)
88 |
89 | h_tokens = ['[CLS]'] + tokenizer.tokenize(h) + ['[SEP]']
90 | r_tokens = ['[CLS]'] + tokenizer.tokenize(r) + ['[SEP]']
91 |
92 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]])
93 | idf_dict = defaultdict(lambda: 1.)
94 |
95 | ref_embedding, ref_lens, ref_masks, padded_idf = get_bert_embedding([r], model, tokenizer, idf_dict,
96 | device=device)
97 | hyp_embedding, ref_lens, ref_masks, padded_idf = get_bert_embedding([h], model, tokenizer, idf_dict,
98 | device=device)
99 |
100 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
101 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))
102 |
103 | batch_size = ref_embedding.size(1)
104 |
105 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2)).cpu()
106 | sim = sim.squeeze(0).numpy()
107 |
108 | # remove [CLS] and [SEP] tokens
109 | r_tokens = r_tokens[1:-1]
110 | h_tokens = h_tokens[1:-1]
111 | sim = sim[1:-1,1:-1]
112 |
113 | fig, ax = plt.subplots(figsize=(len(r_tokens)*0.8, len(h_tokens)*0.8))
114 | im = ax.imshow(sim, cmap='Blues')
115 |
116 | # We want to show all ticks...
117 | ax.set_xticks(np.arange(len(r_tokens)))
118 | ax.set_yticks(np.arange(len(h_tokens)))
119 | # ... and label them with the respective list entries
120 | ax.set_xticklabels(r_tokens, fontsize=10)
121 | ax.set_yticklabels(h_tokens, fontsize=10)
122 |     plt.xlabel("Reference", fontsize=10)
123 | plt.ylabel("Candidate", fontsize=10)
124 |
125 | # Rotate the tick labels and set their alignment.
126 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
127 | rotation_mode="anchor")
128 |
129 | # Loop over data dimensions and create text annotations.
130 | for i in range(len(h_tokens)):
131 | for j in range(len(r_tokens)):
132 | text = ax.text(j, i, '{:.3f}'.format(sim[i, j]),
133 | ha="center", va="center", color="k" if sim[i, j] < 0.6 else "w")
134 |
135 | # P = sim.max(1).mean()
136 | # R = sim.max(0).mean()
137 | # F1 = 2 * P * R / (P + R)
138 |
139 | fig.tight_layout()
140 | # plt.title("BERT-F1: {:.3f}".format(F1), fontsize=10)
141 | if fname != "":
142 | print("Saved figure to file: ", fname+".png")
143 | plt.savefig(fname+'.png', dpi=100)
144 | plt.show()
145 |
--------------------------------------------------------------------------------
/system_eval/evaluation/bert_score/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from math import log
3 | from itertools import chain
4 | from collections import defaultdict, Counter
5 | from multiprocessing import Pool
6 | from functools import partial
7 | from tqdm.auto import tqdm
8 |
9 | __all__ = ['bert_types']
10 |
11 | bert_types = [
12 | 'bert-base-uncased',
13 | 'bert-large-uncased',
14 | 'bert-base-cased',
15 | 'bert-large-cased',
16 | 'bert-base-multilingual-uncased',
17 | 'bert-base-multilingual-cased',
18 | 'bert-base-chinese',
19 | ]
20 |
21 | def padding(arr, pad_token, dtype=torch.long):
22 | lens = torch.LongTensor([len(a) for a in arr])
23 | max_len = lens.max().item()
24 | padded = torch.ones(len(arr), max_len, dtype=dtype) * pad_token
25 | mask = torch.zeros(len(arr), max_len, dtype=torch.long)
26 | for i, a in enumerate(arr):
27 | padded[i, :lens[i]] = torch.tensor(a, dtype=dtype)
28 | mask[i, :lens[i]] = 1
29 | return padded, lens, mask
30 |
31 |
32 | def bert_encode(model, x, attention_mask):
33 | model.eval()
34 | x_seg = torch.zeros_like(x, dtype=torch.long)
35 | with torch.no_grad():
36 | x_encoded_layers, pooled_output = model(x, x_seg, attention_mask=attention_mask, output_all_encoded_layers=False)
37 | return x_encoded_layers
38 |
39 |
40 | def process(a, tokenizer=None):
41 | if not tokenizer is None:
42 | a = ["[CLS]"]+tokenizer.tokenize(a)+["[SEP]"]
43 | a = tokenizer.convert_tokens_to_ids(a)
44 | return set(a)
45 |
46 |
47 | def get_idf_dict(arr, tokenizer, nthreads=4):
48 | """
49 | Returns mapping from word piece index to its inverse document frequency.
50 | Args:
51 | - :param: `arr` (list of str) : sentences to process.
52 |         - :param: `tokenizer` : a BERT tokenizer corresponding to `model`.
53 | - :param: `nthreads` (int) : number of CPU threads to use
54 | """
55 | idf_count = Counter()
56 | num_docs = len(arr)
57 |
58 | process_partial = partial(process, tokenizer=tokenizer)
59 |
60 | with Pool(nthreads) as p:
61 | idf_count.update(chain.from_iterable(p.map(process_partial, arr)))
62 |
63 | idf_dict = defaultdict(lambda : log((num_docs+1)/(1)))
64 | idf_dict.update({idx:log((num_docs+1)/(c+1)) for (idx, c) in idf_count.items()})
65 | return idf_dict
66 |
67 |
68 | def collate_idf(arr, tokenize, numericalize, idf_dict,
69 | pad="[PAD]", device='cuda:0'):
70 | """
71 |     Helper function that pads a list of sentences to have the same length and
72 | loads idf score for words in the sentences.
73 | Args:
74 | - :param: `arr` (list of str): sentences to process.
75 | - :param: `tokenize` : a function that takes a string and return list
76 | of tokens.
77 | - :param: `numericalize` : a function that takes a list of tokens and
78 | return list of token indexes.
79 | - :param: `idf_dict` (dict): mapping a word piece index to its
80 | inverse document frequency
81 | - :param: `pad` (str): the padding token.
82 | - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda'
83 | """
84 | arr = [["[CLS]"]+tokenize(a)+["[SEP]"] for a in arr]
85 | arr = [numericalize(a) for a in arr]
86 |
87 | idf_weights = [[idf_dict[i] for i in a] for a in arr]
88 |
89 | pad_token = numericalize([pad])[0]
90 |
91 | padded, lens, mask = padding(arr, pad_token, dtype=torch.long)
92 | padded_idf, _, _ = padding(idf_weights, pad_token, dtype=torch.float)
93 |
94 | padded = padded.to(device=device)
95 | mask = mask.to(device=device)
96 | lens = lens.to(device=device)
97 | return padded, padded_idf, lens, mask
98 |
99 |
100 | def get_bert_embedding(all_sens, model, tokenizer, idf_dict,
101 | batch_size=-1, device='cuda:0'):
102 | """
103 | Compute BERT embedding in batches.
104 | Args:
105 | - :param: `all_sens` (list of str) : sentences to encode.
106 | - :param: `model` : a BERT model from `pytorch_pretrained_bert`.
107 |         - :param: `tokenizer` : a BERT tokenizer corresponding to `model`.
108 | - :param: `idf_dict` (dict) : mapping a word piece index to its
109 | inverse document frequency
110 | - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda'
111 | """
112 |
113 | padded_sens, padded_idf, lens, mask = collate_idf(all_sens,
114 | tokenizer.tokenize, tokenizer.convert_tokens_to_ids,
115 | idf_dict,
116 | device=device)
117 |
118 | if batch_size == -1: batch_size = len(all_sens)
119 |
120 | embeddings = []
121 | with torch.no_grad():
122 | for i in range(0, len(all_sens), batch_size):
123 | batch_embedding = bert_encode(model, padded_sens[i:i+batch_size],
124 | attention_mask=mask[i:i+batch_size])
125 | # batch_embedding = torch.stack(batch_embedding)
126 | embeddings.append(batch_embedding)
127 | del batch_embedding
128 |
129 | total_embedding = torch.cat(embeddings, dim=0)
130 |
131 | return total_embedding, lens, mask, padded_idf
132 |
133 |
134 | def greedy_cos_idf(ref_embedding, ref_lens, ref_masks, ref_idf,
135 | hyp_embedding, hyp_lens, hyp_masks, hyp_idf):
136 | """
137 | Compute greedy matching based on cosine similarity.
138 | Args:
139 | - :param: `ref_embedding` (torch.Tensor):
140 | embeddings of reference sentences, BxKxd,
141 |                 B: batch size, K: longest length, d: bert dimension
142 | - :param: `ref_lens` (list of int): list of reference sentence length.
143 | - :param: `ref_masks` (torch.LongTensor): BxKxK, BERT attention mask for
144 | reference sentences.
145 | - :param: `ref_idf` (torch.Tensor): BxK, idf score of each word
146 |             piece in the reference sentence
147 | - :param: `hyp_embedding` (torch.Tensor):
148 | embeddings of candidate sentences, BxKxd,
149 |                 B: batch size, K: longest length, d: bert dimension
150 | - :param: `hyp_lens` (list of int): list of candidate sentence length.
151 | - :param: `hyp_masks` (torch.LongTensor): BxKxK, BERT attention mask for
152 | candidate sentences.
153 | - :param: `hyp_idf` (torch.Tensor): BxK, idf score of each word
154 |             piece in the candidate sentence
155 | """
156 |
157 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
158 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))
159 |
160 | batch_size = ref_embedding.size(0)
161 |
162 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
163 | masks = torch.bmm(hyp_masks.unsqueeze(2).float(), ref_masks.unsqueeze(1).float())
164 | masks = masks.expand(batch_size, masks.size(1), masks.size(2))\
165 | .contiguous().view_as(sim)
166 |
167 | masks = masks.float().to(sim.device)
168 | sim = sim * masks
169 |
170 | word_precision = sim.max(dim=2)[0]
171 | word_recall = sim.max(dim=1)[0]
172 |
173 | hyp_idf.div_(hyp_idf.sum(dim=1, keepdim=True))
174 | ref_idf.div_(ref_idf.sum(dim=1, keepdim=True))
175 | precision_scale = hyp_idf.to(word_precision.device)
176 | recall_scale = ref_idf.to(word_recall.device)
177 | P = (word_precision * precision_scale).sum(dim=1)
178 | R = (word_recall * recall_scale).sum(dim=1)
179 |
180 | F = 2 * P * R / (P + R)
181 | return P, R, F
182 |
183 | def bert_cos_score_idf(model, refs, hyps, tokenizer, idf_dict,
184 | verbose=False, batch_size=64, device='cuda:0'):
185 | """
186 | Compute BERTScore.
187 | Args:
188 | - :param: `model` : a BERT model in `pytorch_pretrained_bert`
189 | - :param: `refs` (list of str): reference sentences
190 | - :param: `hyps` (list of str): candidate sentences
191 |         - :param: `tokenizer` : a BERT tokenizer corresponding to `model`
192 | - :param: `idf_dict` : a dictionary mapping a word piece index to its
193 | inverse document frequency
194 | - :param: `verbose` (bool): turn on intermediate status update
195 | - :param: `batch_size` (int): bert score processing batch size
196 | - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda'
197 | """
198 | preds = []
199 | iter_range = range(0, len(refs), batch_size)
200 | if verbose: iter_range = tqdm(iter_range)
201 | for batch_start in iter_range:
202 | batch_refs = refs[batch_start:batch_start+batch_size]
203 | batch_hyps = hyps[batch_start:batch_start+batch_size]
204 | ref_stats = get_bert_embedding(batch_refs, model, tokenizer, idf_dict,
205 | device=device)
206 | hyp_stats = get_bert_embedding(batch_hyps, model, tokenizer, idf_dict,
207 | device=device)
208 |
209 | P, R, F1 = greedy_cos_idf(*ref_stats, *hyp_stats)
210 | preds.append(torch.stack((P, R, F1), dim=1).cpu())
211 | preds = torch.cat(preds, dim=0)
212 | return preds
213 |
--------------------------------------------------------------------------------
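A toy illustration (not from the repo) of the greedy matching at the heart of `greedy_cos_idf`: each candidate token takes its best-matching reference token for precision, each reference token its best-matching candidate token for recall; with uniform idf weights the weighted sums reduce to plain means.

```python
import torch
import torch.nn.functional as F

hyp = F.normalize(torch.randn(1, 3, 4), dim=-1)  # 3 candidate tokens, dim 4
ref = F.normalize(torch.randn(1, 5, 4), dim=-1)  # 5 reference tokens, dim 4

sim = torch.bmm(hyp, ref.transpose(1, 2))  # (1, 3, 5) pairwise cosine similarity
P = sim.max(dim=2)[0].mean(dim=1)          # best reference match per candidate token
R = sim.max(dim=1)[0].mean(dim=1)          # best candidate match per reference token
F1 = 2 * P * R / (P + R)
print(P.item(), R.item(), F1.item())
```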
/system_eval/evaluation/bleu/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
--------------------------------------------------------------------------------
/system_eval/evaluation/bleu/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/system_eval/evaluation/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/system_eval/evaluation/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | from evaluation.bleu.bleu_scorer import BleuScorer
12 |
13 |
14 | class Bleu:
15 | def __init__(self, n=4):
16 |         # by default, compute BLEU score up to n = 4
17 | self._n = n
18 | self._hypo_for_image = {}
19 | self.ref_for_image = {}
20 |
21 | def compute_score(self, gts, res):
22 | assert(gts.keys() == res.keys())
23 | imgIds = gts.keys()
24 | bleu_scorer = BleuScorer(n=self._n)
25 | for id in imgIds:
26 | hypo = res[id]
27 | # print("hypo:", hypo)
28 | ref = gts[id]
29 | # print("ref:", ref)
30 |
31 | # Sanity check.
32 | assert(type(hypo) is list)
33 | assert(len(hypo) == 1)
34 | assert(type(ref) is list)
35 | assert(len(ref) >= 1)
36 |
37 | bleu_scorer += (hypo[0], ref)
38 |
39 | #score, scores = bleu_scorer.compute_score(option='shortest')
40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1)
41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=0)
42 |
43 | # return (bleu, bleu_info)
44 | return score, scores
45 |
46 | def method(self):
47 | return "Bleu"
48 |
--------------------------------------------------------------------------------
/system_eval/evaluation/bleu/bleu_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # bleu_scorer.py
4 | # David Chiang
5 |
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 |
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 |
14 | '''Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 |
19 | import copy
20 | import sys, math, re
21 | from collections import defaultdict
22 |
23 | def precook(s, n=4, out=False):
24 | """Takes a string as input and returns an object that can be given to
25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
26 | can take string arguments as well."""
27 | words = s.split()
28 | counts = defaultdict(int)
29 | for k in range(1,n+1):
30 | for i in range(len(words)-k+1):
31 | ngram = tuple(words[i:i+k])
32 | counts[ngram] += 1
33 | return (len(words), counts)
34 |
35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
36 | '''Takes a list of reference sentences for a single segment
37 | and returns an object that encapsulates everything that BLEU
38 | needs to know about them.'''
39 |
40 | reflen = []
41 | maxcounts = {}
42 | for ref in refs:
43 | rl, counts = precook(ref, n)
44 | reflen.append(rl)
45 | for (ngram,count) in counts.items():
46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
47 |
48 | # Calculate effective reference sentence length.
49 | if eff == "shortest":
50 | reflen = min(reflen)
51 | elif eff == "average":
52 | reflen = float(sum(reflen))/len(reflen)
53 |
54 |     ## lhuang: N.B.: leave reflen computation to the very end!!
55 |
56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
57 |
58 | return (reflen, maxcounts)
59 |
60 | def cook_test(test, tup, eff=None, n=4):
61 | '''Takes a test sentence and returns an object that
62 | encapsulates everything that BLEU needs to know about it.'''
63 |
64 | (reflen, refmaxcounts) = tup
65 | testlen, counts = precook(test, n, True)
66 |
67 | result = {}
68 |
69 | # Calculate effective reference sentence length.
70 |
71 | if eff == "closest":
72 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
73 | else: ## i.e., "average" or "shortest" or None
74 | result["reflen"] = reflen
75 |
76 | result["testlen"] = testlen
77 |
78 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
79 |
80 | result['correct'] = [0]*n
81 | for (ngram, count) in counts.items():
82 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
83 |
84 | return result
85 |
86 | class BleuScorer(object):
87 | """Bleu scorer.
88 | """
89 |
90 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
91 | # special_reflen is used in oracle (proportional effective ref len for a node).
92 |
93 | def copy(self):
94 | ''' copy the refs.'''
95 | new = BleuScorer(n=self.n)
96 | new.ctest = copy.copy(self.ctest)
97 | new.crefs = copy.copy(self.crefs)
98 | new._score = None
99 | return new
100 |
101 | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
102 | ''' singular instance '''
103 |
104 | self.n = n
105 | self.crefs = []
106 | self.ctest = []
107 | self.cook_append(test, refs)
108 | self.special_reflen = special_reflen
109 |
110 | def cook_append(self, test, refs):
111 | '''called by constructor and __iadd__ to avoid creating new instances.'''
112 |
113 | if refs is not None:
114 | self.crefs.append(cook_refs(refs))
115 | if test is not None:
116 | cooked_test = cook_test(test, self.crefs[-1])
117 | self.ctest.append(cooked_test) ## N.B.: -1
118 | else:
119 | self.ctest.append(None) # lens of crefs and ctest have to match
120 |
121 | self._score = None ## need to recompute
122 |
123 | def ratio(self, option=None):
124 | self.compute_score(option=option)
125 | return self._ratio
126 |
127 | def score_ratio(self, option=None):
128 | '''return (bleu, len_ratio) pair'''
129 | return (self.fscore(option=option), self.ratio(option=option))
130 |
131 | def score_ratio_str(self, option=None):
132 | return "%.4f (%.2f)" % self.score_ratio(option)
133 |
134 | def reflen(self, option=None):
135 | self.compute_score(option=option)
136 | return self._reflen
137 |
138 | def testlen(self, option=None):
139 | self.compute_score(option=option)
140 | return self._testlen
141 |
142 | def retest(self, new_test):
143 | if type(new_test) is str:
144 | new_test = [new_test]
145 | assert len(new_test) == len(self.crefs), new_test
146 | self.ctest = []
147 | for t, rs in zip(new_test, self.crefs):
148 | self.ctest.append(cook_test(t, rs))
149 | self._score = None
150 |
151 | return self
152 |
153 | def rescore(self, new_test):
154 | ''' replace test(s) with new test(s), and returns the new score.'''
155 |
156 | return self.retest(new_test).compute_score()
157 |
158 | def size(self):
159 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
160 | return len(self.crefs)
161 |
162 | def __iadd__(self, other):
163 | '''add an instance (e.g., from another sentence).'''
164 |
165 | if type(other) is tuple:
166 | ## avoid creating new BleuScorer instances
167 | self.cook_append(other[0], other[1])
168 | else:
169 | assert self.compatible(other), "incompatible BLEUs."
170 | self.ctest.extend(other.ctest)
171 | self.crefs.extend(other.crefs)
172 | self._score = None ## need to recompute
173 |
174 | return self
175 |
176 | def compatible(self, other):
177 | return isinstance(other, BleuScorer) and self.n == other.n
178 |
179 | def single_reflen(self, option="average"):
180 | return self._single_reflen(self.crefs[0][0], option)
181 |
182 | def _single_reflen(self, reflens, option=None, testlen=None):
183 |
184 | if option == "shortest":
185 | reflen = min(reflens)
186 | elif option == "average":
187 | reflen = float(sum(reflens))/len(reflens)
188 | elif option == "closest":
189 | reflen = min((abs(l-testlen), l) for l in reflens)[1]
190 | else:
191 | assert False, "unsupported reflen option %s" % option
192 |
193 | return reflen
194 |
195 | def recompute_score(self, option=None, verbose=0):
196 | self._score = None
197 | return self.compute_score(option, verbose)
198 |
199 | def compute_score(self, option=None, verbose=0):
200 | n = self.n
201 | small = 1e-9
202 | tiny = 1e-15 ## so that if guess is 0 still return 0
203 | bleu_list = [[] for _ in range(n)]
204 |
205 |         if self._score is not None:
206 |             return self._score  # N.B.: the cached path returns only the corpus score, without bleu_list
207 |
208 | if option is None:
209 | option = "average" if len(self.crefs) == 1 else "closest"
210 |
211 | self._testlen = 0
212 | self._reflen = 0
213 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
214 |
215 | # for each sentence
216 | for comps in self.ctest:
217 | testlen = comps['testlen']
218 | self._testlen += testlen
219 |
220 | if self.special_reflen is None: ## need computation
221 | reflen = self._single_reflen(comps['reflen'], option, testlen)
222 | else:
223 | reflen = self.special_reflen
224 |
225 | self._reflen += reflen
226 |
227 | for key in ['guess','correct']:
228 | for k in range(n):
229 | totalcomps[key][k] += comps[key][k]
230 |
231 | # append per image bleu score
232 | bleu = 1.
233 | for k in range(n):
234 | bleu *= (float(comps['correct'][k]) + tiny) \
235 | /(float(comps['guess'][k]) + small)
236 | bleu_list[k].append(bleu ** (1./(k+1)))
237 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
238 | if ratio < 1:
239 | for k in range(n):
240 | bleu_list[k][-1] *= math.exp(1 - 1/ratio)
241 |
242 | if verbose > 1:
243 | print(comps, reflen)
244 |
245 | totalcomps['reflen'] = self._reflen
246 | totalcomps['testlen'] = self._testlen
247 |
248 | bleus = []
249 | bleu = 1.
250 | for k in range(n):
251 | bleu *= float(totalcomps['correct'][k] + tiny) \
252 | / (totalcomps['guess'][k] + small)
253 | bleus.append(bleu ** (1./(k+1)))
254 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
255 | if ratio < 1:
256 | for k in range(n):
257 | bleus[k] *= math.exp(1 - 1/ratio)
258 |
259 | if verbose > 0:
260 | print(totalcomps)
261 | print("ratio:", ratio)
262 |
263 |         self._ratio = ratio  # cache the length ratio so ratio()/score_ratio() work after scoring
264 |         self._score = bleus
265 |         return self._score, bleu_list
266 |
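267 | # Illustrative usage sketch (editor's addition, not part of the original file).
268 | # BleuScorer accumulates (hypothesis, references) pairs and scores the corpus;
269 | # inputs are assumed to be whitespace-tokenized strings.
270 | if __name__ == "__main__":
271 |     scorer = BleuScorer(n=4)
272 |     scorer += ("the cat sat on the mat", ["the cat is on the mat"])
273 |     scorer += ("a dog runs", ["the dog runs fast", "a dog is running"])
274 |     score, _ = scorer.compute_score(option="closest")
275 |     print(score)  # list of four floats: corpus BLEU-1..BLEU-4
276 |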
--------------------------------------------------------------------------------
/system_eval/evaluation/cider/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/system_eval/evaluation/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | from evaluation.cider.cider_scorer import CiderScorer
11 | import pdb
12 |
13 | class Cider:
14 | """
15 | Main Class to compute the CIDEr metric
16 |
17 | """
18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
19 | # set cider to sum over 1 to 4-grams
20 | self._n = n
21 | # set the standard deviation parameter for gaussian penalty
22 | self._sigma = sigma
23 |
24 | def compute_score(self, gts, res):
25 | """
26 | Main function to compute CIDEr score
27 |         :param gts (dict) : maps each example id to a list of reference sentences
28 |         :param res (dict) : maps each example id to a single-element list holding the hypothesis
29 | :return: cider (float) : computed CIDEr score for the corpus
30 | """
31 |
32 | assert(gts.keys() == res.keys())
33 | imgIds = gts.keys()
34 |
35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
36 |
37 | for id in imgIds:
38 | hypo = res[id]
39 | ref = gts[id]
40 |
41 | # Sanity check.
42 | assert(type(hypo) is list)
43 | assert(len(hypo) == 1)
44 | assert(type(ref) is list)
45 | assert(len(ref) > 0)
46 |
47 | cider_scorer += (hypo[0], ref)
48 |
49 | (score, scores) = cider_scorer.compute_score()
50 |
51 | return score, scores
52 |
53 | def method(self):
54 | return "CIDEr"
55 |
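56 | # Illustrative usage sketch (editor's addition, not part of the original file).
57 | # Both dicts map an example id to a list of strings; each res value must be a
58 | # single-element list:
59 | if __name__ == "__main__":
60 |     gts = {"ex1": ["a man is eating food", "someone eats a meal"]}
61 |     res = {"ex1": ["a man eats food"]}
62 |     score, scores = Cider().compute_score(gts, res)
63 |     print(score, scores)
64 |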
--------------------------------------------------------------------------------
/system_eval/evaluation/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import pdb
9 | import math
10 |
11 | def precook(s, n=4, out=False):
12 | """
13 | Takes a string as input and returns an object that can be given to
14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
15 | can take string arguments as well.
16 | :param s: string : sentence to be converted into ngrams
17 | :param n: int : number of ngrams for which representation is calculated
18 |     :return: term frequency counts for occurring ngrams
19 | """
20 | words = s.split()
21 | counts = defaultdict(int)
22 | for k in range(1,n+1):
23 | for i in range(len(words)-k+1):
24 | ngram = tuple(words[i:i+k])
25 | counts[ngram] += 1
26 | return counts
27 |
28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
29 |     '''Takes a list of reference sentences for a single segment
30 |     and returns an object that encapsulates everything that CIDEr
31 |     needs to know about them.
32 | :param refs: list of string : reference sentences for some image
33 | :param n: int : number of ngrams for which (ngram) representation is calculated
34 | :return: result (list of dict)
35 | '''
36 | return [precook(ref, n) for ref in refs]
37 |
38 | def cook_test(test, n=4):
39 |     '''Takes a test sentence and returns an object that
40 |     encapsulates everything that CIDEr needs to know about it.
41 | :param test: list of string : hypothesis sentence for some image
42 | :param n: int : number of ngrams for which (ngram) representation is calculated
43 | :return: result (dict)
44 | '''
45 | return precook(test, n, True)
46 |
47 | class CiderScorer(object):
48 | """CIDEr scorer.
49 | """
50 |
51 | def copy(self):
52 | ''' copy the refs.'''
53 | new = CiderScorer(n=self.n)
54 | new.ctest = copy.copy(self.ctest)
55 | new.crefs = copy.copy(self.crefs)
56 | return new
57 |
58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
59 | ''' singular instance '''
60 | self.n = n
61 | self.sigma = sigma
62 | self.crefs = []
63 | self.ctest = []
64 | self.document_frequency = defaultdict(float)
65 | self.cook_append(test, refs)
66 | self.ref_len = None
67 |
68 | def cook_append(self, test, refs):
69 | '''called by constructor and __iadd__ to avoid creating new instances.'''
70 |
71 | if refs is not None:
72 | self.crefs.append(cook_refs(refs))
73 | if test is not None:
74 |             self.ctest.append(cook_test(test))
75 | else:
76 | self.ctest.append(None) # lens of crefs and ctest have to match
77 |
78 | def size(self):
79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
80 | return len(self.crefs)
81 |
82 | def __iadd__(self, other):
83 | '''add an instance (e.g., from another sentence).'''
84 |
85 | if type(other) is tuple:
86 | ## avoid creating new CiderScorer instances
87 | self.cook_append(other[0], other[1])
88 | else:
89 | self.ctest.extend(other.ctest)
90 | self.crefs.extend(other.crefs)
91 |
92 | return self
93 | def compute_doc_freq(self):
94 | '''
95 |         Compute document frequency for the n-grams in the reference data.
96 |         This will be used to compute idf (inverse document frequency) later.
97 |         The document frequency is stored on the object.
98 | :return: None
99 | '''
100 | for refs in self.crefs:
101 | # refs, k ref captions of one image
102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
103 | self.document_frequency[ngram] += 1
104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
105 |
106 | def compute_cider(self):
107 | def counts2vec(cnts):
108 | """
109 | Function maps counts of ngram to vector of tfidf weights.
110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
111 | The n-th entry of array denotes length of n-grams.
112 | :param cnts:
113 | :return: vec (array of dict), norm (array of float), length (int)
114 | """
115 | vec = [defaultdict(float) for _ in range(self.n)]
116 | length = 0
117 | norm = [0.0 for _ in range(self.n)]
118 | for (ngram,term_freq) in cnts.items():
119 |             # clip document frequency to at least 1 so unseen ngrams get the maximum idf
120 | df = np.log(max(1.0, self.document_frequency[ngram]))
121 | # ngram index
122 | n = len(ngram)-1
123 | # tf (term_freq) * idf (precomputed idf) for n-grams
124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
125 | # compute norm for the vector. the norm will be used for computing similarity
126 | norm[n] += pow(vec[n][ngram], 2)
127 |
128 | if n == 1:
129 | length += term_freq
130 | norm = [np.sqrt(n) for n in norm]
131 | return vec, norm, length
132 |
133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
134 | '''
135 | Compute the cosine similarity of two vectors.
136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
137 | :param vec_ref: array of dictionary for vector corresponding to reference
138 | :param norm_hyp: array of float for vector corresponding to hypothesis
139 | :param norm_ref: array of float for vector corresponding to reference
140 | :param length_hyp: int containing length of hypothesis
141 | :param length_ref: int containing length of reference
142 | :return: array of score for each n-grams cosine similarity
143 | '''
144 | delta = float(length_hyp - length_ref)
145 |             # measure cosine similarity
146 | val = np.array([0.0 for _ in range(self.n)])
147 | for n in range(self.n):
148 | # ngram
149 | for (ngram,count) in vec_hyp[n].items():
150 | # vrama91 : added clipping
151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
152 |
153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
154 | val[n] /= (norm_hyp[n]*norm_ref[n])
155 |
156 | assert(not math.isnan(val[n]))
157 | # vrama91: added a length based gaussian penalty
158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
159 | return val
160 |
161 | # compute log reference length
162 | self.ref_len = np.log(float(len(self.crefs)))
163 |
164 | scores = []
165 | for test, refs in zip(self.ctest, self.crefs):
166 | # compute vector for test captions
167 | vec, norm, length = counts2vec(test)
168 | # compute vector for ref captions
169 | score = np.array([0.0 for _ in range(self.n)])
170 | for ref in refs:
171 | vec_ref, norm_ref, length_ref = counts2vec(ref)
172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
173 | # change by vrama91 - mean of ngram scores, instead of sum
174 | score_avg = np.mean(score)
175 | # divide by number of references
176 | score_avg /= len(refs)
177 | # multiply score by 10
178 | score_avg *= 10.0
179 | # append score of an image to the score list
180 | scores.append(score_avg)
181 | return scores
182 |
183 | def compute_score(self, option=None, verbose=0):
184 | # compute idf
185 | self.compute_doc_freq()
186 | # assert to check document frequency
187 | assert(len(self.ctest) >= max(self.document_frequency.values()))
188 | # compute cider score
189 | score = self.compute_cider()
190 | # debug
191 | # print score
192 | return np.mean(np.array(score)), np.array(score)
193 |
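194 | # Illustrative example (editor's addition, not part of the original file):
195 | # precook counts n-grams up to order n, keyed by token tuples.
196 | if __name__ == "__main__":
197 |     print(dict(precook("the cat the cat", n=2)))
198 |     # -> {('the',): 2, ('cat',): 2, ('the', 'cat'): 2, ('cat', 'the'): 1}
199 |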
--------------------------------------------------------------------------------
/system_eval/evaluation/eval.py:
--------------------------------------------------------------------------------
1 | from evaluation.bleu.bleu import Bleu
2 | from evaluation.meteor.meteor_nltk import Meteor
3 | from evaluation.rouge.rouge import Rouge
4 | from evaluation.cider.cider import Cider
5 | from evaluation.bert_score.bert_score import BertScore
6 | from collections import defaultdict
7 | from argparse import ArgumentParser
8 |
9 | import sys
10 | import json
11 | #reload(sys)
12 | #sys.setdefaultencoding('utf-8')
13 |
14 | class QGEvalCap:
15 | def __init__(self, model_key, gts, res, results_file=None):
16 | self.gts = gts
17 | self.res = res
18 | self.results_file = results_file
19 | self.model_key = model_key
20 |
21 | def evaluate(self):
22 | output = []
23 | scorers = [
24 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
25 | (Meteor(),"METEOR"),
26 | (Rouge(), "ROUGE_L"),
27 | (Cider(), "CIDEr"),
28 | (BertScore(), "Bert Score")
29 | ]
30 |
31 | # =================================================
32 | # Compute scores
33 | # =================================================
34 | score_dict = {}
35 | scores_dict = {}
36 | #scores_dict["model_key"] = self.model_key
37 | for scorer, method in scorers:
38 | # print 'computing %s score...'%(scorer.method())
39 | score, scores = scorer.compute_score(self.gts, self.res)
40 | if type(method) == list:
41 | for sc, scs, m in zip(score, scores, method):
42 | #print("%s: %0.5f"%(m, sc))
43 | output.append(sc)
44 | score_dict[m] = str(sc)
45 | scores_dict[m] = list(scs)
46 | else:
47 | #print("%s: %0.5f"%(method, score))
48 | output.append(score)
49 | score_dict[method] = score
50 | scores_dict[method] = list(scores)
51 |
52 |         if self.results_file is not None:
53 | with open(self.results_file, "a") as f:
54 | f.write(json.dumps(score_dict)+"\n")
55 |
56 | return score_dict, scores_dict
57 |
58 | def eval(model_key, sources, references, predictions, results_file=None):
59 | """
60 |     Given sources, references, and predictions for one model key,
61 |     compute the metric scores (optionally appending them to results_file).
62 | """
63 |
64 | pairs = []
65 |
66 | for tup in sources:
67 | pair = {}
68 | pair['tokenized_sentence'] = tup
69 | pairs.append(pair)
70 |
71 | cnt = 0
72 | for line in references:
73 | pairs[cnt]['tokenized_question'] = line
74 | cnt += 1
75 |
76 | output = predictions
77 |
78 | for idx, pair in enumerate(pairs):
79 | pair['prediction'] = output[idx]
80 |
81 |     ## eval
82 |     # QGEvalCap is defined above in this module; json is already imported at
83 |     # module level, so no local imports are needed here.
84 |     from json import encoder
85 |     encoder.FLOAT_REPR = lambda o: format(o, '.4f')  # only affects Python 2 output formatting
86 |
87 | res = defaultdict(lambda: [])
88 | gts = defaultdict(lambda: [])
89 | for pair in pairs[:]:
90 | key = pair['tokenized_sentence']
91 | #res[key] = [pair['prediction']]
92 | res[key] = pair['prediction']
93 |
94 | ## gts
95 | gts[key].append(pair['tokenized_question'])
96 |
97 | QGEval = QGEvalCap(model_key, gts, res, results_file)
98 | return QGEval.evaluate()
99 |
100 |
101 | def preprocess(file_name, keys):
102 | with open(file_name) as f:
103 | data = f.readlines()
104 | generations = [json.loads(elem) for elem in data]
105 |
106 | predictions = {}
107 | references = {}
108 | sources = {}
109 |     keys_list = keys if keys is not None else generations[0]["generations"].keys()
110 | for key in keys_list:
111 | references[key] = []
112 | predictions[key] = []
113 | sources[key] = []
114 |
115 | for elem in generations:
116 | label = elem["label"]
117 | hyp = elem["hyp"+label]
118 | for key in keys_list:
119 | if key in elem["generations"]:
120 | references[key].append(hyp)
121 | predictions[key].append(elem["generations"][key])
122 | sources[key].append((elem["obs1"], elem["obs2"]))
123 |
124 | return sources, references, predictions
125 |
126 |
127 | if __name__ == "__main__":
128 | parser = ArgumentParser()
129 | parser.add_argument("-gen_file", "--gen_file", dest="gen_file", help="generations file with gold/references")
130 | parser.add_argument("--keys", type=str, default=None, help="comma-separated list of model keys")
131 | parser.add_argument("--results_file", default="eval_results.jsonl")
132 | args = parser.parse_args()
133 |
134 | print("scores: \n")
135 | keys=None
136 | if args.keys:
137 | keys = args.keys.split(",")
138 |
139 | sources, references, predictions = preprocess(args.gen_file, keys)
140 | for key in references.keys():
141 | print("\nEvaluating %s" %key)
142 | eval(key, sources[key], references[key], predictions[key], args.results_file)
143 |
144 |
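145 | # Input format note (editor's addition): preprocess() expects one JSON object
146 | # per line, shaped roughly like
147 | #   {"obs1": "...", "obs2": "...", "label": "1", "hyp1": "...", "hyp2": "...",
148 | #    "generations": {"model_key": "..."}}
149 | # where "hyp" + label selects the gold hypothesis used as the reference.
150 |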
--------------------------------------------------------------------------------
/system_eval/evaluation/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/system_eval/evaluation/meteor/meteor-1.5.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongho94/solar-framework_commonsense-inference/50e99ba0b0f5ae2315c72f5f35f6d385499283e1/system_eval/evaluation/meteor/meteor-1.5.jar
--------------------------------------------------------------------------------
/system_eval/evaluation/meteor/meteor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Python wrapper for METEOR implementation, by Xinlei Chen
4 | # Acknowledge Michael Denkowski for the generous discussion and help
5 |
6 | import os
7 | import sys
8 | import subprocess
9 | import threading
10 |
11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
12 | METEOR_JAR = 'meteor-1.5.jar'
13 | # print METEOR_JAR
14 |
15 | class Meteor:
16 |
17 | def __init__(self):
18 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
19 | '-', '-', '-stdio', '-l', 'en',
20 | '-norm',
21 | # '-t', 'adq'
22 | # '-p', '0.85 0.2 0.6 0.75' # alpha beta gamma delta'',
23 | # '-a', 'data/paraphrase-en.gz', '-m', 'exact stem paraphrase']
24 | ]
25 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \
26 | cwd=os.path.dirname(os.path.abspath(__file__)), \
27 | stdin=subprocess.PIPE, \
28 | stdout=subprocess.PIPE, \
29 | stderr=subprocess.PIPE)
30 | # Used to guarantee thread safety
31 | self.lock = threading.Lock()
32 |
33 | def compute_score(self, gts, res):
34 | assert(gts.keys() == res.keys())
35 | imgIds = gts.keys()
36 | scores = []
37 |
38 | eval_line = 'EVAL'
39 | self.lock.acquire()
40 | for i in imgIds:
41 | assert(len(res[i]) == 1)
42 | stat = self._stat(res[i][0], gts[i])
43 | eval_line += ' ||| {}'.format(stat)
44 |
45 |         # stdin/stdout of the subprocess are byte streams on Python 3, so
46 |         # encode on write, flush, and decode on read.
47 |         self.meteor_p.stdin.write('{}\n'.format(eval_line).encode('utf-8'))
48 |         self.meteor_p.stdin.flush()
49 |
50 |         for i in range(0, len(imgIds)):
51 |             scores.append(float(self.meteor_p.stdout.readline().strip()))
52 |         score = float(self.meteor_p.stdout.readline().strip())
53 |         self.lock.release()
54 |
55 |         return score, scores
56 |
57 |     def method(self):
58 |         return "METEOR"
59 |
60 |     def _stat(self, hypothesis_str, reference_list):
61 |         # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
62 |         hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
63 |         score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
64 |         str_in = '{}\n'.format(score_line)
65 |         self.meteor_p.stdin.write(str_in.encode('utf-8'))
66 |         self.meteor_p.stdin.flush()
67 |         return self.meteor_p.stdout.readline().decode('utf-8').strip()
68 |
69 |     def _score(self, hypothesis_str, reference_list):
70 |         self.lock.acquire()
71 |         # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
72 |         hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
73 |         score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
74 |         self.meteor_p.stdin.write('{}\n'.format(score_line).encode('utf-8'))
75 |         self.meteor_p.stdin.flush()
76 |         stats = self.meteor_p.stdout.readline().decode('utf-8').strip()
77 |         eval_line = 'EVAL ||| {}'.format(stats)
78 |         self.meteor_p.stdin.write('{}\n'.format(eval_line).encode('utf-8'))
79 |         self.meteor_p.stdin.flush()
80 |         # the jar returns two values (an average and an overall score), so read twice
81 |         score = float(self.meteor_p.stdout.readline().strip())
82 |         score = float(self.meteor_p.stdout.readline().strip())
83 |         self.lock.release()
84 |         return score
85 |
86 |     def __del__(self):
87 |         self.lock.acquire()
88 |         self.meteor_p.stdin.close()
89 |         self.meteor_p.kill()
90 |         self.meteor_p.wait()
91 |         self.lock.release()
92 |
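93 | # Protocol note (editor's addition): in -stdio mode the jar is line-oriented --
94 | # each "SCORE ||| refs ||| hyp" line yields one stats line, and an
95 | # "EVAL ||| stats [||| stats ...]" line yields one score per segment followed
96 | # by the aggregate score.
97 |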
--------------------------------------------------------------------------------
/system_eval/evaluation/meteor/meteor_nltk.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Python wrapper for METEOR implementation, by Xinlei Chen
4 | # Acknowledge Michael Denkowski for the generous discussion and help
5 |
6 | import os
7 | import sys
8 | import nltk
9 | from nltk.translate.meteor_score import meteor_score
10 |
11 | # This variant scores with NLTK's meteor_score implementation instead of the
12 | # meteor-1.5.jar subprocess wrapper (see meteor.py for the jar-based version),
13 | # so no jar file is required.
14 |
15 | class Meteor:
16 |
17 | def __init__(self):
18 | pass
19 |
20 | def compute_score(self, gts, res):
21 | assert(gts.keys() == res.keys())
22 | imgIds = gts.keys()
23 | scores = []
24 |
25 | for i in imgIds:
26 | assert(len(res[i]) == 1)
27 |             # N.B.: nltk >= 3.6 expects pre-tokenized references and hypothesis
28 |             # here; older nltk versions accept raw strings.
29 |             score = round(meteor_score(gts[i], res[i][0]), 4)
30 |             scores.append(score)
31 |
32 |         return sum(scores)/len(scores), scores
33 |
34 |     def method(self):
35 |         return "METEOR"
36 |
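37 | # Illustrative usage sketch (editor's addition, not part of the original file);
38 | # requires nltk and its 'wordnet' corpus. Raw strings work on nltk < 3.6;
39 | # newer nltk expects token lists instead.
40 | if __name__ == "__main__":
41 |     m = Meteor()
42 |     print(m.compute_score({"ex1": ["the cat sat"]}, {"ex1": ["a cat sat"]}))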
--------------------------------------------------------------------------------
/system_eval/evaluation/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'vrama91'
2 |
--------------------------------------------------------------------------------
/system_eval/evaluation/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 | def my_lcs(string, sub):
14 | """
15 | Calculates longest common subsequence for a pair of tokenized strings
16 | :param string : list of str : tokens from a string split using whitespace
17 | :param sub : list of str : shorter string, also split using whitespace
18 |     :returns: length (int): length of the longest common subsequence between the two strings
19 |
20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
21 | """
22 | if(len(string)< len(sub)):
23 | sub, string = string, sub
24 |
25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
26 |
27 | for j in range(1,len(sub)+1):
28 | for i in range(1,len(string)+1):
29 | if(string[i-1] == sub[j-1]):
30 | lengths[i][j] = lengths[i-1][j-1] + 1
31 | else:
32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
33 |
34 | return lengths[len(string)][len(sub)]
35 |
36 | class Rouge():
37 | '''
38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
39 |
40 | '''
41 | def __init__(self):
42 |         # vrama91: updated the value below based on discussion with Hovy
43 | self.beta = 1.2
44 |
45 | def calc_score(self, candidate, refs):
46 | """
47 | Compute ROUGE-L score given one candidate and references for an image
48 | :param candidate: str : candidate sentence to be evaluated
49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
50 |         :returns score: float (ROUGE-L score for the candidate evaluated against references)
51 | """
52 | assert(len(candidate)==1)
53 | assert(len(refs)>0)
54 | prec = []
55 | rec = []
56 |
57 | # split into tokens
58 | token_c = candidate[0].split(" ")
59 |
60 | for reference in refs:
61 | # split into tokens
62 | token_r = reference.split(" ")
63 | # compute the longest common subsequence
64 | lcs = my_lcs(token_r, token_c)
65 | prec.append(lcs/float(len(token_c)))
66 | rec.append(lcs/float(len(token_r)))
67 |
68 | prec_max = max(prec)
69 | rec_max = max(rec)
70 |
71 | if(prec_max!=0 and rec_max !=0):
72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
73 | else:
74 | score = 0.0
75 | return score
76 |
77 | def compute_score(self, gts, res):
78 | """
79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
80 | Invoked by evaluate_captions.py
81 |         :param gts: dict : reference sentences with "image name" keys and lists of "tokenized sentences" as values
82 |         :param res: dict : candidate / test sentences with "image name" keys and single-element lists as values
83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
84 | """
85 | assert(gts.keys() == res.keys())
86 | imgIds = gts.keys()
87 |
88 | score = []
89 | for id in imgIds:
90 | hypo = res[id]
91 | ref = gts[id]
92 |
93 |             # Sanity check before scoring.
94 |             assert(type(hypo) is list)
95 |             assert(len(hypo) == 1)
96 |             assert(type(ref) is list)
97 |             assert(len(ref) > 0)
98 |
99 |             score.append(self.calc_score(hypo, ref))
100 |
101 |         average_score = np.mean(np.array(score))
102 |         return average_score, np.array(score)
103 |
104 |     def method(self):
105 |         return "Rouge"
106 |
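107 | # Illustrative usage sketch (editor's addition, not part of the original file):
108 | # ROUGE-L is an LCS-based F-measure with recall weighted by beta.
109 | if __name__ == "__main__":
110 |     rouge = Rouge()
111 |     print(rouge.calc_score(["the cat sat on the mat"],
112 |                            ["the cat was sitting on the mat"]))
113 |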
--------------------------------------------------------------------------------
/system_eval/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import csv
4 | import operator
5 | import random
6 |
7 |
8 | def read_csv(input_file, quotechar='"', delimiter=",", skip_header=False):
9 |     """Reads a delimited text file (comma separated by default)."""
10 | with open(input_file, "r") as f:
11 | reader = csv.reader(f, delimiter=delimiter, quotechar=quotechar, quoting=csv.QUOTE_ALL, skipinitialspace=True)
12 | lines = []
13 | for line in reader:
14 | if sys.version_info[0] == 2:
15 | line = list(unicode(cell, 'utf-8') for cell in line)
16 | lines.append(line)
17 | if skip_header:
18 | lines = lines[1:]
19 | return lines
20 |
21 |
22 | def write_tsv(output_file, data, header=False):
23 | keys = list(data[0].keys())
24 | with open(output_file, 'w') as f:
25 | w = csv.DictWriter(f, keys, delimiter='\t', lineterminator='\n')
26 | if header:
27 | w.writeheader()
28 | for r in data:
29 | entry = {k: r[k] for k in keys}
30 | w.writerow(entry)
31 |
32 |
33 | def write_array2tsv(output_file, data, header=False):
34 | keys = range(len(data[0]))
35 | with open(output_file, 'w') as f:
36 | w = csv.DictWriter(f, keys, delimiter='\t', lineterminator='\n')
37 | if header:
38 | w.writeheader()
39 | for r in data:
40 | entry = {k: r[k] for k in keys}
41 | w.writerow(entry)
42 |
43 |
44 | def write_csv(filename, data, fieldnames):
45 | with open(filename, 'w', newline='') as csvfile:
46 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
47 |
48 | writer.writeheader()
49 | for d in data:
50 | formatted_d = {}
51 | for key, val in d.items():
52 | formatted_d[key] = json.dumps(val)
53 | writer.writerow(formatted_d)
54 |
55 |
56 | def read_jsonl(filename):
57 | data = []
58 | with open(filename, "r") as f:
59 | for line in f:
60 | data.append(json.loads(line))
61 | return data
62 |
63 |
64 | def write_items(output_file, items):
65 | with open(output_file, 'w') as f:
66 | for concept in items:
67 | f.write(concept + "\n")
68 |     # the with-block closes the file; no explicit close needed
69 |
70 |
71 | def write_jsonl(f, d):
72 | write_items(f, [json.dumps(r) for r in d])
73 |
74 |
75 | def count_relation(d):
76 | relation_count = {}
77 | prefix_count = {}
78 | head_count = {}
79 | for l in d:
80 | r = l[1]
81 | if r not in relation_count.keys():
82 | relation_count[r] = 0
83 | relation_count[r] += 1
84 |
85 | prefix = l[0]+l[1]
86 | if prefix not in prefix_count.keys():
87 | prefix_count[prefix] = 0
88 | prefix_count[prefix] += 1
89 |
90 | head = l[0]
91 | if head not in head_count.keys():
92 | head_count[head] = 0
93 | head_count[head] += 1
94 |
95 | sorted_relation_count = dict(sorted(relation_count.items(), key=operator.itemgetter(1), reverse=True))
96 | sorted_prefix_count = dict(sorted(prefix_count.items(), key=operator.itemgetter(1), reverse=True))
97 | sorted_head_count = dict(sorted(head_count.items(), key=operator.itemgetter(1), reverse=True))
98 |
99 | print("Relations:")
100 | for r in sorted_relation_count.keys():
101 | print(r, sorted_relation_count[r])
102 |
103 | print("\nPrefixes:")
104 | print("uniq prefixes: ", len(sorted_prefix_count.keys()))
105 | i = 0
106 | for r in sorted_prefix_count.keys():
107 | print(r, sorted_prefix_count[r])
108 | i += 1
109 | if i > 20:
110 | break
111 |
112 | print("\nHeads:")
113 | i = 0
114 | for r in sorted_head_count.keys():
115 | print(r, sorted_head_count[r])
116 | i += 1
117 | if i > 20:
118 | break
119 |
120 |
121 | def get_head_set(d):
122 | return set([l[0] for l in d])
123 |
124 |
125 | def head_based_split(data, dev_size, test_size, head_size_threshold=500, dev_heads=[], test_heads=[]):
126 | """
127 | :param data: the tuples to split according to the heads, where the head is the first element of each tuple
128 | :param dev_size: target size of the dev set
129 | :param test_size: target size of the test set
130 | :param head_size_threshold: Maximum number of tuples a head can be involved in,
131 | in order to be considered for the dev/test set'
132 | :param dev_heads: heads that are forced to belong to the dev set
133 | :param test_heads: heads that are forced to belong to the test set
134 | :return:
135 | """
136 | head_count = {}
137 | for l in data:
138 | head = l[0]
139 | if head not in head_count.keys():
140 | head_count[head] = 0
141 | head_count[head] += 1
142 |
143 | remaining_heads = dict(head_count)
144 |
145 | test_selected_heads = {}
146 | test_head_total_count = 0
147 |
148 | for h in test_heads:
149 | if h in remaining_heads:
150 | c = remaining_heads[h]
151 | test_selected_heads[h] = c
152 | test_head_total_count += c
153 | remaining_heads.pop(h)
154 |
155 | while test_head_total_count < test_size:
156 |         h = random.sample(list(remaining_heads.keys()), 1)[0]  # list() needed: random.sample requires a sequence on Python 3.11+
157 | c = remaining_heads[h]
158 | if c < head_size_threshold:
159 | test_selected_heads[h] = c
160 | test_head_total_count += c
161 | remaining_heads.pop(h)
162 |
163 | test = [l for l in data if l[0] in test_selected_heads.keys()]
164 |
165 | dev_selected_heads = {}
166 | dev_head_total_count = 0
167 |
168 | for h in dev_heads:
169 | if h in remaining_heads:
170 | c = remaining_heads[h]
171 | dev_selected_heads[h] = c
172 | dev_head_total_count += c
173 | remaining_heads.pop(h)
174 |
175 | while dev_head_total_count < dev_size:
176 |         h = random.sample(list(remaining_heads.keys()), 1)[0]
177 | c = remaining_heads[h]
178 | if c < head_size_threshold:
179 | dev_selected_heads[h] = c
180 | dev_head_total_count += c
181 | remaining_heads.pop(h)
182 |
183 | dev = [l for l in data if l[0] in dev_selected_heads.keys()]
184 |
185 | dev_test_heads = set(list(dev_selected_heads.keys()) + list(test_selected_heads.keys()))
186 | train = [l for l in data if l[0] not in dev_test_heads]
187 |
188 | return train, dev, test
189 |
190 |
191 | def remove_prefix(text, prefix):
192 |     return text[len(prefix):] if text.startswith(prefix) else text
193 |
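194 | # Illustrative usage sketch (editor's addition, not part of the original file).
195 | # head_based_split keeps all tuples sharing a head in the same split:
196 | if __name__ == "__main__":
197 |     data = [("PersonX eats", "xEffect", "gets full"),
198 |             ("PersonX eats", "xWant", "to nap"),
199 |             ("PersonX runs", "xEffect", "gets tired"),
200 |             ("PersonX codes", "xWant", "to ship")]
201 |     train, dev, test = head_based_split(data, dev_size=1, test_size=1)
202 |     print(len(train), len(dev), len(test))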
--------------------------------------------------------------------------------