├── .gitignore ├── LICENSE ├── NOTICE.txt ├── README.md ├── argument-classification ├── README.md ├── inference.py ├── train.py ├── train_ibm.sh ├── train_ukp.sh ├── train_ukp_all_data.sh └── ukp_evaluation.py └── argument-similarity ├── README.md ├── SigmoidBERT.py ├── datasets └── ukp_aspect │ └── make_splits.py ├── evaluation_with_clustering.py ├── evaluation_without_clustering.py ├── inference.py ├── train.py ├── train_misra.sh ├── train_misra_all.sh ├── train_ukp.sh └── train_ukp_all.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *.tsv 2 | *.bin 3 | .idea 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /NOTICE.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UKPLab/acl2019-BERT-argument-classification-and-clustering/72f643b06a06b9ba82a25df2c134664fc26f84f3/NOTICE.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Argument Classification and Clustering using BERT 2 | In our publication [Classification and Clustering of Arguments with Contextualized Word Embeddings](https://arxiv.org/abs/1906.09821) (ACL 2019) we fine-tuned the BERT network to: 3 | - Perform sentential argument classification (i.e., given a sentence with an argument for a controversial topic, classify this sentence as pro, con, or no argument). Details can be found in [argument-classification/README.md](argument-classification/README.md) 4 | - Estimate the argument similarity (0...1) given two sentences. This argument similarity score can be used in conjunction with hierarchical agglomerative clustering to perform aspect-based argument clustering. Details can be found in [argument-similarity/README.md](argument-similarity/README.md) 5 | 6 | 7 | # Citation 8 | If you find the implementation useful, please cite the following paper: [Classification and Clustering of Arguments with Contextualized Word Embeddings](https://arxiv.org/abs/1906.09821) 9 | 10 | ``` 11 | @InProceedings{Reimers:2019:ACL, 12 | author = {Reimers, Nils and Schiller, Benjamin and Beck, Tilman and Daxenberger, Johannes and Stab, Christian and Gurevych, Iryna}, 13 | title = {{Classification and Clustering of Arguments with Contextualized Word Embeddings}}, 14 | booktitle = {Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, 15 | month = {07}, 16 | year = {2019}, 17 | address = {Florence, Italy}, 18 | pages = {567--578}, 19 | url = {https://arxiv.org/abs/1906.09821} 20 | } 21 | ``` 22 | 23 | 24 | 25 | Contact person: Nils Reimers, Rnils@web.de 26 | 27 | https://www.ukp.tu-darmstadt.de/ 28 | 29 | 30 | Don't hesitate to send us an e-mail or report an issue if something is broken (and it shouldn't be) or if you have further questions. 31 | 32 | > This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication. 33 | 34 | # Setup 35 | 36 | This repository requires Python 3.5+ and PyTorch 0.4.1/1.0.0. It uses [pytorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT/) version 0.6.2. See the [pytorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT/) readme for further details on the installation. Usually, you can install it as follows: 37 | ``` 38 | pip install pytorch-pretrained-bert==0.6.2 scikit-learn scipy 39 | ``` 40 | 41 | # Argument Classification 42 | Please see [argument-classification/README.md](argument-classification/README.md) for full details. 43 | 44 | Given a sentence and a topic, classify whether the sentence is a pro, con, or no argument. For example: 45 | ``` 46 | Topic: zoo 47 | Sentence: Zoo confinement is psychologically damaging to animals.
48 | Output Label: Argument_against 49 | ``` 50 | 51 | You can download pre-trained models from here, which were trained on all eight topics of the [UKP Sentential Argument Mining Corpus](https://www.informatik.tu-darmstadt.de/ukp/research_6/data/argumentation_mining_1/ukp_sentential_argument_mining_corpus/index.en.jsp): 52 | - [argument_classification_ukp_all_data.zip](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/models/argument_classification_ukp_all_data.zip) 53 | - [argument_classification_ukp_all_data_large_model.zip](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/models/argument_classification_ukp_all_data_large_model.zip) 54 | 55 | 56 | See [argument-classification/inference.py](argument-classification/inference.py) for an example of how to use these models to classify new sentences. 57 | 58 | In a leave-one-topic-out evaluation, the BERT model achieves the following performance. 59 | 60 | ![Classification Performance](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/images/table_classification_results.png) 61 | 62 | 63 | # Argument Similarity & Clustering 64 | See [argument-similarity/README.md](argument-similarity/README.md) for full details. 65 | 66 | Given two sentences, the code in [argument-similarity](argument-similarity/) returns a value between 0 and 1 indicating the similarity between the arguments. This can be used with agglomerative clustering to find and cluster similar arguments. 67 | 68 | You can download two pre-trained models: 69 | - [argument_similarity_ukp_aspects_all.zip](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/models/argument_similarity_ukp_aspects_all.zip) - trained on the complete [UKP Argument Aspect Similarity Corpus](https://www.informatik.tu-darmstadt.de/ukp/research_6/data/argumentation_mining_1/ukp_argument_aspect_similarity_corpus/ukp_argument_aspect_similarity_corpus.en.jsp) 70 | - [argument_similarity_misra_all.zip](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/models/argument_similarity_misra_all.zip) - trained on the complete [Argument Facet Similarity (AFS) Corpus](https://nlds.soe.ucsc.edu/node/44) from Misra et al. 71 | 72 | 73 | See [argument-similarity/inference.py](argument-similarity/inference.py) for an example. This example computes the pairwise similarity between arguments on different topics. 74 | The output should be something like this for the model trained on the UKP corpus: 75 | ``` 76 | Predicted similarities (sorted by similarity): 77 | Sentence A: Eating meat is not cruel or unethical; it is a natural part of the cycle of life. 78 | Sentence B: It is cruel and unethical to kill animals for food when vegetarian options are available 79 | Similarity: 0.99436545 80 | 81 | Sentence A: Zoos are detrimental to animals' physical health. 82 | Sentence B: Zoo confinement is psychologically damaging to animals. 83 | Similarity: 0.99386144 84 | 85 | [...] 86 | 87 | Sentence A: It is cruel and unethical to kill animals for food when vegetarian options are available 88 | Sentence B: Rising levels of human-produced gases released into the atmosphere create a greenhouse effect that traps heat and causes global warming.
89 | Similarity: 0.0057242378 90 | ``` 91 | 92 | With the Misra AFS model, the output should be something like this: 93 | ``` 94 | Predicted similarities (sorted by similarity): 95 | Sentence A: Zoos are detrimental to animals' physical health. 96 | Sentence B: Zoo confinement is psychologically damaging to animals. 97 | Similarity: 0.8723387 98 | 99 | Sentence A: Eating meat is not cruel or unethical; it is a natural part of the cycle of life. 100 | Sentence B: It is cruel and unethical to kill animals for food when vegetarian options are available 101 | Similarity: 0.77635074 102 | 103 | [...] 104 | 105 | Sentence A: Zoos produce helpful scientific research. 106 | Sentence B: Eating meat is not cruel or unethical; it is a natural part of the cycle of life. 107 | Similarity: 0.20616204 108 | ``` 109 | 110 | 111 | ## Argument Similarity Performance 112 | 113 | ![UKP Aspects Performance](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/images/table_UKP_Aspects_results.png) 114 | 115 | ![AFS Performance](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/images/table_AFS_results.png) 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /argument-classification/README.md: -------------------------------------------------------------------------------- 1 | # Argument Classification 2 | This folder contains code to fine-tune BERT for argument classification: Is a given sentence a pro, con, or no argument for a given topic? 3 | 4 | **Example:** 5 | ``` 6 | Topic: zoo 7 | Sentence: Zoo confinement is psychologically damaging to animals. 8 | Output Label: Argument_against 9 | ``` 10 | 11 | 12 | ## Setup 13 | For the setup, see the [README.md](https://github.com/UKPLab/acl2019-BERT-argument-classification-and-clustering/) in the main folder. 14 | 15 | 16 | ## Training 17 | We trained (and evaluated) our models on the [UKP Sentential Argument Mining Corpus](https://www.informatik.tu-darmstadt.de/ukp/research_6/data/argumentation_mining_1/ukp_sentential_argument_mining_corpus/index.en.jsp), which contains 25,492 sentences annotated across eight controversial topics. 18 | 19 | Due to copyright issues, we cannot distribute the corpus directly. You need to download it from the website and run a Java program to reconstruct the corpus from the sources. More information can be found in the README.txt of the UKP Sentential Argument Mining Corpus. 20 | 21 | Once you have re-created the UKP Sentential Argument Mining Corpus, you can fine-tune BERT by running the `train_ukp.sh` script: 22 | ``` 23 | ./train_ukp.sh 24 | ``` 25 | 26 | This fine-tunes BERT on seven topics and evaluates the performance on the eighth topic. 27 | 28 | 29 | If you want to train BERT on all 25,492 sentences from the UKP Argument Corpus, run `train_ukp_all_data.sh`. 30 | 31 | We also provide a data reader and script for the [IBM Debater dataset](http://www.research.ibm.com/haifa/dept/vst/debating_data.shtml). See `train_ibm.sh` for how to train BERT on this dataset. As before, you first need to download the corpus and unzip it to `datasets/ibm/`. 32 | 33 | **Note:** Training on GPU leads to non-deterministic results. For scientific experiments, we recommend training with multiple random seeds and averaging the results.
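As a minimal sketch (not part of this repository), such a multi-seed experiment can be scripted on top of `train.py`. The arguments mirror `train_ukp.sh`; the held-out topic "abortion" and the output directory naming are illustrative choices. The resulting prediction files can afterwards be averaged with `python ukp_evaluation.py bert_output/ukp/abortion_seed*/test_predictions.txt`:
```
import subprocess

# Fine-tune once per random seed. Each run writes test_predictions.txt
# into its own output directory, which ukp_evaluation.py then averages.
for seed in [1, 2, 3]:
    subprocess.run([
        "python", "train.py",
        "--task_name", "ukp-topic-sentence",
        "--do_train", "--do_eval", "--do_lower_case",
        "--seed", str(seed),
        "--binarize_labels", "0",
        "--data_dir", "./datasets/ukp/data/complete/",
        "--bert_model", "bert-base-uncased",
        "--max_seq_length", "64",
        "--train_batch_size", "16",
        "--test_set", "abortion",
        "--learning_rate", "2e-5",
        "--num_train_epochs", "2.0",
        "--output_dir", "bert_output/ukp/abortion_seed%d/" % seed,
    ], check=True)
```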
34 | 35 | ## Inference 36 | You can use `inference.py` to classify new arguments on new topics: 37 | ``` 38 | python inference.py 39 | ``` 40 | 41 | You must specify the model path: 42 | ``` 43 | model_path = 'bert_output/argument_classification_ukp_all_data/' 44 | ``` 45 | 46 | Download and unzip a pre-trained model from here: 47 | - [argument_classification_ukp_all_data.zip](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/models/argument_classification_ukp_all_data.zip) 48 | - [argument_classification_ukp_all_data_large_model.zip](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/models/argument_classification_ukp_all_data_large_model.zip) 49 | 50 | These models were trained on all eight topics of the [UKP Sentential Argument Mining Corpus](https://www.informatik.tu-darmstadt.de/ukp/research_6/data/argumentation_mining_1/ukp_sentential_argument_mining_corpus/index.en.jsp). The topics are: abortion, cloning, death penalty, gun control, marijuana legalization, minimum wage, nuclear energy, school uniforms. 51 | 52 | The model can be applied to arguments from other topics, for example, keeping animals in zoos (a topic that was not in the training data): 53 | ``` 54 | Predicted labels: 55 | Topic: zoo 56 | Sentence: A zoo is a facility in which all animals are housed within enclosures, displayed to the public, and in which they may also breed. 57 | Gold label: NoArgument 58 | Predicted label: NoArgument 59 | 60 | Topic: zoo 61 | Sentence: Zoos produce helpful scientific research. 62 | Gold label: Argument_for 63 | Predicted label: Argument_for 64 | 65 | Topic: zoo 66 | Sentence: Zoos save species from extinction and other dangers. 67 | Gold label: Argument_for 68 | Predicted label: Argument_for 69 | 70 | Topic: zoo 71 | Sentence: Zoo confinement is psychologically damaging to animals. 72 | Gold label: Argument_against 73 | Predicted label: Argument_against 74 | 75 | Topic: zoo 76 | Sentence: Zoos are detrimental to animals' physical health. 77 | Gold label: Argument_against 78 | Predicted label: Argument_against 79 | 80 | Topic: autonomous cars 81 | Sentence: Zoos are detrimental to animals' physical health. 82 | Gold label: NoArgument 83 | Predicted label: NoArgument 84 | ``` 85 | 86 | Note that when you change the topic of an argument, as in the last example, the model correctly identifies that the sentence is not an argument for or against 'autonomous cars'. 87 | 88 | 89 | 90 | ## Performance 91 | 92 | In a cross-topic evaluation, the BERT model achieves the following performance. 93 | 94 | ![Classification Performance](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/images/table_classification_results.png) 95 | 96 | 97 | See our paper ([Classification and Clustering of Arguments with Contextualized Word Embeddings](https://arxiv.org/abs/1906.09821)) for further details. 98 | 99 | For the computation of the macro F1-score for the UKP corpus, see `ukp_evaluation.py`. 100 | -------------------------------------------------------------------------------- /argument-classification/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Runs a pre-trained BERT model for argument classification.
3 | 4 | You can download pre-trained models here: https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/models/argument_classification_ukp_all_data.zip 5 | 6 | The model 'bert_output/argument_classification_ukp_all_data/' (the unzipped download) was trained on all eight topics (abortion, cloning, death penalty, gun control, marijuana legalization, minimum wage, nuclear energy, school uniforms) from the Stab et al. corpus (UKP Sentential Argument 7 | Mining Corpus). 8 | 9 | Usage: python inference.py 10 | 11 | """ 12 | 13 | from pytorch_pretrained_bert.modeling import BertForSequenceClassification 14 | from pytorch_pretrained_bert.tokenization import BertTokenizer 15 | import torch 16 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler 17 | import numpy as np 18 | 19 | from train import InputExample, convert_examples_to_features 20 | 21 | 22 | num_labels = 3 23 | model_path = 'bert_output/argument_classification_ukp_all_data/' 24 | label_list = ["NoArgument", "Argument_against", "Argument_for"] 25 | max_seq_length = 64 26 | eval_batch_size = 8 27 | 28 | # Input examples. The model expects text_a to be the topic 29 | # and text_b to be the sentence. label is an optional value, only used when we print the output in this script. 30 | 31 | input_examples = [ 32 | InputExample(text_a='zoo', text_b='A zoo is a facility in which all animals are housed within enclosures, displayed to the public, and in which they may also breed. ', label='NoArgument'), 33 | InputExample(text_a='zoo', text_b='Zoos produce helpful scientific research. ', label='Argument_for'), 34 | InputExample(text_a='zoo', text_b='Zoos save species from extinction and other dangers.', label='Argument_for'), 35 | InputExample(text_a='zoo', text_b='Zoo confinement is psychologically damaging to animals.', label='Argument_against'), 36 | InputExample(text_a='zoo', text_b='Zoos are detrimental to animals\' physical health.', label='Argument_against'), 37 | InputExample(text_a='autonomous cars', text_b='Zoos are detrimental to animals\' physical health.', label='NoArgument'), 38 | ] 39 | 40 | 41 | 42 | 43 | 44 | tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=True) 45 | eval_features = convert_examples_to_features(input_examples, label_list, max_seq_length, tokenizer) 46 | 47 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 48 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 49 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 50 | 51 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids) 52 | eval_sampler = SequentialSampler(eval_data) 53 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size) 54 | 55 | 56 | 57 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 58 | model = BertForSequenceClassification.from_pretrained(model_path, num_labels=num_labels) 59 | model.to(device) 60 | model.eval() 61 | 62 | predicted_labels = [] 63 | with torch.no_grad(): 64 | for input_ids, input_mask, segment_ids in eval_dataloader: 65 | input_ids = input_ids.to(device) 66 | input_mask = input_mask.to(device) 67 | segment_ids = segment_ids.to(device) 68 | 69 | 70 | logits = model(input_ids, segment_ids, input_mask) 71 | logits = logits.detach().cpu().numpy() 72 | 73 | for prediction in np.argmax(logits, axis=1): 74 |
predicted_labels.append(label_list[prediction]) 75 | 76 | print("Predicted labels:") 77 | for idx in range(len(input_examples)): 78 | example = input_examples[idx] 79 | print("Topic:", example.text_a) 80 | print("Sentence:", example.text_b) 81 | print("Gold label:", example.label) 82 | print("Predicted label:", predicted_labels[idx]) 83 | print("") 84 | 85 | -------------------------------------------------------------------------------- /argument-classification/train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This code trains a classifier that classifies a topic+sentence into the categories: NoArgument, Argument_for, Argument_against 3 | # The code is based on the HuggingFace classification example: https://github.com/huggingface/pytorch-pretrained-BERT/blob/v0.6.2/examples/run_classifier.py 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import glob 10 | import os 11 | import csv 12 | import logging 13 | import argparse 14 | import random 15 | from tqdm import tqdm, trange 16 | 17 | import numpy as np 18 | import torch 19 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 20 | from torch.utils.data.distributed import DistributedSampler 21 | 22 | from pytorch_pretrained_bert.tokenization import BertTokenizer 23 | from pytorch_pretrained_bert.modeling import BertForSequenceClassification 24 | from pytorch_pretrained_bert.optimization import BertAdam 25 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME 26 | 27 | logging.basicConfig(format='%(message)s', # alternative: '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 28 | datefmt='%m/%d/%Y %H:%M:%S', 29 | level=logging.INFO) 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | class InputExample(object): 34 | """A single training/test example for simple sequence classification.""" 35 | 36 | def __init__(self, text_a, text_b=None, label=None, guid=None): 37 | """Constructs an InputExample. 38 | 39 | Args: 40 | text_a: string. The untokenized text of the first sequence. For single 41 | sequence tasks, only this sequence must be specified. 42 | text_b: (Optional) string. The untokenized text of the second sequence. 43 | Must only be specified for sequence pair tasks. 44 | label: (Optional) string. The label of the example. This should be 45 | specified for train and dev examples, but not for test examples. 46 | guid: (Optional) Unique id for the example.
47 | """ 48 | self.text_a = text_a 49 | self.text_b = text_b 50 | self.label = label 51 | self.guid = guid 52 | 53 | 54 | class InputFeatures(object): 55 | """A single set of features of data.""" 56 | 57 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 58 | self.input_ids = input_ids 59 | self.input_mask = input_mask 60 | self.segment_ids = segment_ids 61 | self.label_id = label_id 62 | 63 | 64 | 65 | 66 | class UKPProcessor(object): 67 | def __init__(self, binarize_labels=False, use_all_data=False): 68 | self.binarize_labels = binarize_labels 69 | self.use_all_data=use_all_data 70 | 71 | def _read_dataset(self, data_dir): 72 | sentences = [] 73 | for filepath in glob.glob(os.path.join(data_dir,'*.tsv')): 74 | 75 | lines = open(filepath).readlines() 76 | 77 | for idx in range(1, len(lines)): 78 | line = lines[idx] 79 | splits = line.strip().split('\t') 80 | topic = splits[0] 81 | text = splits[4] 82 | label = self._convert_label(splits[5].strip()) 83 | data_split_set = splits[6] #train/dev/test value 84 | 85 | sentences.append([text, label, topic, data_split_set]) 86 | 87 | return sentences 88 | 89 | def get_train_examples(self, data_dir, test_topic): 90 | logger.info("Get train examples, Dev set: " + test_topic) 91 | return self._create_examples(self._read_dataset(data_dir), "train", test_topic) 92 | 93 | def get_test_examples(self, data_dir, test_topic): 94 | return self._create_examples(self._read_dataset(data_dir), "test", test_topic) 95 | 96 | def get_labels(self): 97 | if self.binarize_labels: 98 | return ["NoArgument", "Argument"] 99 | else: 100 | return ["NoArgument", "Argument_against", "Argument_for"] 101 | 102 | def _convert_label(self, label): 103 | if self.binarize_labels and label != 'NoArgument': 104 | return "Argument" 105 | 106 | return label 107 | 108 | def _create_examples(self, data_tuples, set_type, test_topic): 109 | """Creates examples for the training and dev sets.""" 110 | examples = [] 111 | for (i, data_tuple) in enumerate(data_tuples): 112 | guid = "%s-%s" % (set_type, i) 113 | sentence, label, topic, data_split_set = data_tuple 114 | 115 | if self.use_all_data: 116 | examples.append(self._get_input_example(guid, topic, sentence, label)) 117 | else: 118 | if data_split_set != set_type: #Train only on train_split, evaluate only on test set 119 | continue 120 | 121 | if set_type == 'test' and topic == test_topic: 122 | examples.append(self._get_input_example(guid, topic, sentence, label)) 123 | 124 | if set_type == 'train' and topic != test_topic: 125 | examples.append(self._get_input_example(guid, topic, sentence, label)) 126 | 127 | return examples 128 | 129 | def _get_input_example(self, guid, topic, sentence, label): 130 | return InputExample(guid=guid, text_a=sentence, text_b=None, label=label) 131 | 132 | 133 | class UKPProcessorTopicSentence(UKPProcessor): 134 | def _get_input_example(self, guid, topic, sentence, label): 135 | return InputExample(guid=guid, text_a=topic, text_b=sentence, label=label) 136 | 137 | 138 | class UKPProcessorSentenceTopic(UKPProcessor): 139 | def _get_input_example(self, guid, topic, sentence, label): 140 | return InputExample(guid=guid, text_a=sentence, text_b=topic, label=label) 141 | 142 | 143 | class IBMProcessor(object): 144 | def __init__(self, binarize_labels=False, use_all_data=False): 145 | pass 146 | 147 | def _read_dataset(self, filepath): 148 | sentences = [] 149 | 150 | with open(filepath) as fIn: 151 | csvreader = csv.DictReader(fIn) 152 | for row in csvreader: 153 | sentences.append(row) 154 | 
155 | return sentences 156 | 157 | def get_train_examples(self, data_dir, test_topic): 158 | return self._create_examples(self._read_dataset(os.path.join(data_dir, 'train.csv')), 'train') 159 | 160 | def get_test_examples(self, data_dir, test_topic): 161 | return self._create_examples(self._read_dataset(os.path.join(data_dir, 'test.csv')), 'test') 162 | 163 | def get_labels(self): 164 | return ["0", "1"] 165 | 166 | 167 | def _create_examples(self, data_tuples, set_type): 168 | """Creates examples for the training and dev sets.""" 169 | examples = [] 170 | for (i, row) in enumerate(data_tuples): 171 | guid = "%s-%s" % (set_type, i) 172 | examples.append(self._get_input_example(guid, row)) 173 | 174 | return examples 175 | 176 | def _get_input_example(self, guid, row): 177 | topic = row['topic'] 178 | topic_concept = row['the concept of the topic'] 179 | sentence = row['candidate'] 180 | label = row['label'] 181 | return InputExample(guid=guid, text_a=sentence, text_b=None, label=label) 182 | 183 | 184 | class IBMProcessorTopicSentence(IBMProcessor): 185 | def _get_input_example(self, guid, row): 186 | topic = row['topic'] 187 | topic_concept = row['the concept of the topic'] 188 | sentence = row['candidate'] 189 | label = row['label'] 190 | return InputExample(guid=guid, text_a=topic, text_b=sentence, label=label) 191 | 192 | class IBMProcessorConceptSentence(IBMProcessor): 193 | def _get_input_example(self, guid, row): 194 | topic = row['topic'] 195 | topic_concept = row['the concept of the topic'] 196 | sentence = row['candidate'] 197 | label = row['label'] 198 | return InputExample(guid=guid, text_a=topic_concept, text_b=sentence, label=label) 199 | 200 | 201 | 202 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 203 | """Loads a data file into a list of `InputBatch`s.""" 204 | 205 | label_map = {label: i for i, label in enumerate(label_list)} 206 | tokens_a_longer_max_seq_length = 0 207 | features = [] 208 | for (ex_index, example) in enumerate(examples): 209 | tokens_a = tokenizer.tokenize(example.text_a) 210 | tokens_b = None 211 | 212 | len_tokens_a = len(tokens_a) 213 | len_tokens_b = 0 214 | 215 | 216 | 217 | if example.text_b: 218 | tokens_b = tokenizer.tokenize(example.text_b) 219 | len_tokens_b = len(tokens_b) 220 | # Modifies `tokens_a` and `tokens_b` in place so that the total 221 | # length is less than the specified length. 222 | # Account for [CLS], [SEP], [SEP] with "- 3" 223 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 224 | else: 225 | # Account for [CLS] and [SEP] with "- 2" 226 | if len(tokens_a) > max_seq_length - 2: 227 | tokens_a = tokens_a[:(max_seq_length - 2)] 228 | 229 | if (len_tokens_a + len_tokens_b) > (max_seq_length - 2): 230 | tokens_a_longer_max_seq_length += 1 231 | 232 | # The convention in BERT is: 233 | # (a) For sequence pairs: 234 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 235 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 236 | # (b) For single sequences: 237 | # tokens: [CLS] the dog is hairy . [SEP] 238 | # type_ids: 0 0 0 0 0 0 0 239 | # 240 | # Where "type_ids" are used to indicate whether this is the first 241 | # sequence or the second sequence. The embedding vectors for `type=0` and 242 | # `type=1` were learned during pre-training and are added to the wordpiece 243 | # embedding vector (and position vector). 
This is not *strictly* necessary 244 | # since the [SEP] token unambiguously separates the sequences, but it makes 245 | # it easier for the model to learn the concept of sequences. 246 | # 247 | # For classification tasks, the first vector (corresponding to [CLS]) is 248 | # used as the "sentence vector". Note that this only makes sense because 249 | # the entire model is fine-tuned. 250 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] 251 | segment_ids = [0] * len(tokens) 252 | 253 | if tokens_b: 254 | tokens += tokens_b + ["[SEP]"] 255 | segment_ids += [1] * (len(tokens_b) + 1) 256 | 257 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 258 | 259 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 260 | # tokens are attended to. 261 | input_mask = [1] * len(input_ids) 262 | 263 | # Zero-pad up to the sequence length. 264 | padding = [0] * (max_seq_length - len(input_ids)) 265 | input_ids += padding 266 | input_mask += padding 267 | segment_ids += padding 268 | 269 | assert len(input_ids)==max_seq_length 270 | assert len(input_mask)==max_seq_length 271 | assert len(segment_ids)==max_seq_length 272 | 273 | label_id = label_map[example.label] 274 | if ex_index < 1 and example.guid is not None and example.guid.startswith('train'): 275 | logger.info("\n\n*** Example ***") 276 | logger.info("guid: %s" % (example.guid)) 277 | logger.info("tokens: %s" % " ".join( 278 | [str(x) for x in tokens])) 279 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 280 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 281 | logger.info( 282 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 283 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 284 | logger.info("\n\n") 285 | 286 | features.append( 287 | InputFeatures(input_ids=input_ids, 288 | input_mask=input_mask, 289 | segment_ids=segment_ids, 290 | label_id=label_id)) 291 | 292 | logger.info(":: Sentences longer than max_sequence_length: %d" % (tokens_a_longer_max_seq_length)) 293 | logger.info(":: Num sentences: %d" % (len(examples))) 294 | return features 295 | 296 | 297 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 298 | """Truncates a sequence pair in place to the maximum length.""" 299 | 300 | # This is a simple heuristic which will always truncate the longer sequence 301 | # one token at a time. This makes more sense than truncating an equal percent 302 | # of tokens from each, since if one sequence is very short then each token 303 | # that's truncated likely contains more information than a longer sequence. 304 | while True: 305 | total_length = len(tokens_a) + len(tokens_b) 306 | if total_length <= max_length: 307 | break 308 | if len(tokens_a) > len(tokens_b): 309 | tokens_a.pop() 310 | else: 311 | tokens_b.pop() 312 | 313 | 314 | def accuracy(out, labels): 315 | outputs = np.argmax(out, axis=1) 316 | return np.sum(outputs==labels) 317 | 318 | 319 | def warmup_linear(x, warmup=0.002): 320 | if x < warmup: 321 | return x / warmup 322 | return 1.0 - x 323 | 324 | 325 | def main(): 326 | parser = argparse.ArgumentParser() 327 | 328 | ## Required parameters 329 | parser.add_argument("--data_dir", 330 | default=None, 331 | type=str, 332 | required=True, 333 | help="The input data dir.
Should contain the .tsv files (or other data files) for the task.") 334 | parser.add_argument("--bert_model", default=None, type=str, required=True, 335 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 336 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 337 | "bert-base-multilingual-cased, bert-base-chinese.") 338 | parser.add_argument("--task_name", 339 | default=None, 340 | type=str, 341 | required=True, 342 | help="The name of the task to train.") 343 | parser.add_argument("--output_dir", 344 | default=None, 345 | type=str, 346 | required=True, 347 | help="The output directory where the model predictions and checkpoints will be written.") 348 | 349 | ## Other parameters 350 | parser.add_argument("--max_seq_length", 351 | default=128, 352 | type=int, 353 | help="The maximum total input sequence length after WordPiece tokenization. \n" 354 | "Sequences longer than this will be truncated, and sequences shorter \n" 355 | "than this will be padded.") 356 | parser.add_argument("--do_train", 357 | action='store_true', 358 | help="Whether to run training.") 359 | parser.add_argument("--do_eval", 360 | action='store_true', 361 | help="Whether to run eval on the dev set.") 362 | parser.add_argument("--do_lower_case", 363 | action='store_true', 364 | help="Set this flag if you are using an uncased model.") 365 | parser.add_argument("--train_batch_size", 366 | default=32, 367 | type=int, 368 | help="Total batch size for training.") 369 | parser.add_argument("--eval_batch_size", 370 | default=8, 371 | type=int, 372 | help="Total batch size for eval.") 373 | parser.add_argument("--learning_rate", 374 | default=5e-5, 375 | type=float, 376 | help="The initial learning rate for Adam.") 377 | parser.add_argument("--num_train_epochs", 378 | default=3.0, 379 | type=float, 380 | help="Total number of training epochs to perform.") 381 | parser.add_argument("--warmup_proportion", 382 | default=0.1, 383 | type=float, 384 | help="Proportion of training to perform linear learning rate warmup for. " 385 | "E.g., 0.1 = 10%% of training.") 386 | parser.add_argument("--no_cuda", 387 | action='store_true', 388 | help="Whether not to use CUDA when available") 389 | parser.add_argument("--local_rank", 390 | type=int, 391 | default=-1, 392 | help="local_rank for distributed training on gpus") 393 | parser.add_argument('--seed', 394 | type=int, 395 | default=42, 396 | help="random seed for initialization") 397 | parser.add_argument('--gradient_accumulation_steps', 398 | type=int, 399 | default=1, 400 | help="Number of updates steps to accumulate before performing a backward/update pass.") 401 | parser.add_argument('--fp16', 402 | action='store_true', 403 | help="Whether to use 16-bit float precision instead of 32-bit") 404 | parser.add_argument('--loss_scale', 405 | type=float, default=0, 406 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 407 | "0 (default value): dynamic loss scaling.\n" 408 | "Positive power of 2: static loss scaling value.\n") 409 | 410 | parser.add_argument("--test_set", 411 | default=None, 412 | type=str, 413 | help="Name of test set.") 414 | 415 | parser.add_argument("--binarize_labels", 416 | default=0, 417 | type=int, 418 | help="If set to 1, merge Argument_for and Argument_against into a single Argument label.") 419 | 420 | 421 | parser.add_argument("--use_all_data", 422 | default=0, 423 | type=int, 424 | help="If set to 1, train on all available data (no train/test split).") 425 | 426 | args = parser.parse_args() 427 | 428 | 429 | 430 | processors = { 431 | "ukp-sentence": UKPProcessor, 432 | "ukp-topic-sentence": UKPProcessorTopicSentence, 433 | "ukp-sentence-topic": UKPProcessorSentenceTopic, 434 | "ibm-sentence": IBMProcessor, 435 | "ibm-topic-sentence": IBMProcessorTopicSentence, 436 | "ibm-concept-sentence": IBMProcessorConceptSentence, 437 | } 438 | 439 | binarize_labels = False if args.binarize_labels == 0 else True 440 | use_all_data = False if args.use_all_data==0 else True 441 | 442 | if args.test_set is not None: 443 | logger.info("Test set: "+args.test_set) 444 | 445 | if args.local_rank==-1 or args.no_cuda: 446 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 447 | n_gpu = torch.cuda.device_count() 448 | else: 449 | torch.cuda.set_device(args.local_rank) 450 | device = torch.device("cuda", args.local_rank) 451 | n_gpu = 1 452 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 453 | torch.distributed.init_process_group(backend='nccl') 454 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 455 | device, n_gpu, bool(args.local_rank!=-1), args.fp16)) 456 | 457 | if args.gradient_accumulation_steps < 1: 458 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 459 | args.gradient_accumulation_steps)) 460 | 461 | args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) 462 | 463 | random.seed(args.seed) 464 | np.random.seed(args.seed) 465 | torch.manual_seed(args.seed) 466 | if n_gpu > 0: 467 | torch.cuda.manual_seed_all(args.seed) 468 | 469 | if not args.do_train and not args.do_eval: 470 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 471 | 472 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 473 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 474 | os.makedirs(args.output_dir, exist_ok=True) 475 | 476 | with open(os.path.join(args.output_dir, 'parameters.txt'), 'w') as fOut: 477 | fOut.write(str(args)) 478 | 479 | task_name = args.task_name.lower() 480 | 481 | if task_name not in processors: 482 | raise ValueError("Task not found: %s" % (task_name)) 483 | 484 | processor = processors[task_name](binarize_labels, use_all_data) 485 | label_list = processor.get_labels() 486 | num_labels = len(label_list) 487 | 488 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 489 | 490 | train_examples = None 491 | num_train_steps = None 492 | if args.do_train: 493 | train_examples = processor.get_train_examples(args.data_dir, args.test_set) 494 | num_train_steps = int( 495 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) 496 | 497 | # Prepare model 498 | model = BertForSequenceClassification.from_pretrained(args.bert_model, 499 | cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
'distributed_{}'.format( 500 | args.local_rank), 501 | num_labels=num_labels) 502 | if args.fp16: 503 | model.half() 504 | model.to(device) 505 | if args.local_rank!=-1: 506 | try: 507 | from apex.parallel import DistributedDataParallel as DDP 508 | except ImportError: 509 | raise ImportError( 510 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 511 | 512 | model = DDP(model) 513 | elif n_gpu > 1: 514 | model = torch.nn.DataParallel(model) 515 | 516 | # Prepare optimizer 517 | param_optimizer = list(model.named_parameters()) 518 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 519 | optimizer_grouped_parameters = [ 520 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 521 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 522 | ] 523 | t_total = num_train_steps 524 | if args.local_rank!=-1: 525 | t_total = t_total // torch.distributed.get_world_size() 526 | if args.fp16: 527 | try: 528 | from apex.optimizers import FP16_Optimizer 529 | from apex.optimizers import FusedAdam 530 | except ImportError: 531 | raise ImportError( 532 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 533 | 534 | optimizer = FusedAdam(optimizer_grouped_parameters, 535 | lr=args.learning_rate, 536 | bias_correction=False, 537 | max_grad_norm=1.0) 538 | if args.loss_scale==0: 539 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 540 | else: 541 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 542 | 543 | else: 544 | optimizer = BertAdam(optimizer_grouped_parameters, 545 | lr=args.learning_rate, 546 | warmup=args.warmup_proportion, 547 | t_total=t_total) 548 | 549 | global_step = 0 550 | nb_tr_steps = 0 551 | tr_loss = 0 552 | if args.do_train: 553 | with open(os.path.join(args.output_dir, "train_sentences.csv"), "w") as writer: 554 | for idx, example in enumerate(train_examples): 555 | writer.write("%s\t%s\t%s\n" % (example.label, example.text_a, example.text_b)) 556 | 557 | 558 | train_features = convert_examples_to_features( 559 | train_examples, label_list, args.max_seq_length, tokenizer) 560 | logger.info("***** Running training *****") 561 | logger.info(" Labels = %s", ", ".join(label_list)) 562 | logger.info(" Num examples = %d", len(train_examples)) 563 | logger.info(" Batch size = %d", args.train_batch_size) 564 | logger.info(" Num steps = %d", num_train_steps) 565 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 566 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 567 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 568 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 569 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 570 | if args.local_rank==-1: 571 | train_sampler = RandomSampler(train_data) 572 | else: 573 | train_sampler = DistributedSampler(train_data) 574 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 575 | 576 | model.train() 577 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 578 | tr_loss = 0 579 | nb_tr_examples, nb_tr_steps = 0, 0 580 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 581 | batch = tuple(t.to(device) for t in batch) 582 | 
input_ids, input_mask, segment_ids, label_ids = batch 583 | loss = model(input_ids, segment_ids, input_mask, label_ids) 584 | if n_gpu > 1: 585 | loss = loss.mean() # mean() to average on multi-gpu. 586 | if args.gradient_accumulation_steps > 1: 587 | loss = loss / args.gradient_accumulation_steps 588 | 589 | if args.fp16: 590 | optimizer.backward(loss) 591 | else: 592 | loss.backward() 593 | 594 | tr_loss += loss.item() 595 | nb_tr_examples += input_ids.size(0) 596 | nb_tr_steps += 1 597 | if (step + 1) % args.gradient_accumulation_steps==0: 598 | # modify learning rate with special warm up BERT uses 599 | lr_this_step = args.learning_rate * warmup_linear(global_step / t_total, args.warmup_proportion) 600 | for param_group in optimizer.param_groups: 601 | param_group['lr'] = lr_this_step 602 | optimizer.step() 603 | optimizer.zero_grad() 604 | global_step += 1 605 | 606 | if args.do_eval: 607 | eval_results_filename = "test_results_epoch_%d.txt" % (epoch) 608 | eval_prediction_filename = "test_predictions_epoch_%d.txt" % (epoch) 609 | do_evaluation(processor, args, label_list, tokenizer, model, device, tr_loss, nb_tr_steps, global_step, 610 | task_name, eval_results_filename, eval_prediction_filename) 611 | 612 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 613 | # Save a trained model 614 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 615 | 616 | # If we save using the predefined names, we can load using `from_pretrained` 617 | output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) 618 | output_config_file = os.path.join(args.output_dir, CONFIG_NAME) 619 | 620 | torch.save(model_to_save.state_dict(), output_model_file) 621 | model_to_save.config.to_json_file(output_config_file) 622 | tokenizer.save_vocabulary(args.output_dir) 623 | 624 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 625 | output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") 626 | if args.do_train: 627 | torch.save(model_to_save.state_dict(), output_model_file) 628 | else: 629 | model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels) 630 | model.to(device) 631 | 632 | 633 | if args.do_eval and (args.local_rank==-1 or torch.distributed.get_rank()==0): 634 | eval_results_filename = "test_results.txt" 635 | eval_prediction_filename = "test_predictions.txt" 636 | do_evaluation(processor, args, label_list, tokenizer, model, device, tr_loss, nb_tr_steps, global_step, task_name, eval_results_filename, eval_prediction_filename) 637 | 638 | 639 | def do_evaluation(processor, args, label_list, tokenizer, model, device, tr_loss, nb_tr_steps, global_step, task_name, eval_results_filename, eval_prediction_filename): 640 | eval_examples = processor.get_test_examples(args.data_dir, args.test_set) 641 | eval_features = convert_examples_to_features( 642 | eval_examples, label_list, args.max_seq_length, tokenizer) 643 | logger.info("***** Running evaluation *****") 644 | logger.info(" Num examples = %d", len(eval_examples)) 645 | logger.info(" Batch size = %d", args.eval_batch_size) 646 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 647 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 648 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 649 | all_label_ids = torch.tensor([f.label_id for f in eval_features], 
dtype=torch.long) 650 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 651 | # Run prediction for full data 652 | eval_sampler = SequentialSampler(eval_data) 653 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 654 | 655 | model.eval() 656 | eval_accuracy = 0 657 | nb_eval_steps, nb_eval_examples = 0, 0 658 | 659 | predicted_labels = [] 660 | for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): 661 | input_ids = input_ids.to(device) 662 | input_mask = input_mask.to(device) 663 | segment_ids = segment_ids.to(device) 664 | label_ids = label_ids.to(device) 665 | 666 | with torch.no_grad(): 667 | logits = model(input_ids, segment_ids, input_mask) 668 | 669 | logits = logits.detach().cpu().numpy() 670 | label_ids = label_ids.to('cpu').numpy() 671 | tmp_eval_accuracy = accuracy(logits, label_ids) 672 | 673 | for prediction in np.argmax(logits, axis=1): 674 | predicted_labels.append(label_list[prediction]) 675 | 676 | eval_accuracy += tmp_eval_accuracy 677 | 678 | nb_eval_examples += input_ids.size(0) 679 | nb_eval_steps += 1 680 | 681 | 682 | eval_accuracy = eval_accuracy / nb_eval_examples 683 | loss = tr_loss / nb_tr_steps if args.do_train else None 684 | result = { 685 | 'eval_accuracy': eval_accuracy, 686 | 'global_step': global_step, 687 | 'train_loss': loss} 688 | 689 | output_eval_file = os.path.join(args.output_dir, eval_results_filename) 690 | with open(output_eval_file, "w") as writer: 691 | logger.info("\n\n\n***** Eval results *****") 692 | for key in sorted(result.keys()): 693 | logger.info(" %s = %s", key, str(result[key])) 694 | writer.write("%s = %s\n" % (key, str(result[key]))) 695 | logger.info("\n\n\n") 696 | 697 | 698 | output_pred_file = os.path.join(args.output_dir, eval_prediction_filename) 699 | with open(output_pred_file, "w") as writer: 700 | for idx, example in enumerate(eval_examples): 701 | gold_label = example.label 702 | pred_label = predicted_labels[idx] 703 | 704 | text_a = example.text_a.replace("\n", " ") 705 | text_b = example.text_b.replace("\n", " ") if example.text_b is not None else "None" 706 | 707 | writer.write("%s\t%s\t%s\t%s\n" % (gold_label, pred_label, text_a, text_b)) 708 | 709 | 710 | if __name__=="__main__": 711 | main() 712 | 713 | 714 | 715 | 716 | 717 | -------------------------------------------------------------------------------- /argument-classification/train_ibm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | python train.py --task_name ibm-topic-sentence --do_train --do_eval --seed 1 --do_lower_case --binarize_labels 0 --data_dir ./datasets/ibm/ --bert_model bert-base-uncased --max_seq_length 64 --train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir "bert_output/ibm/" 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /argument-classification/train_ukp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | for testset in "abortion" "cloning" "death penalty" "gun control" "marijuana legalization" "minimum wage" "nuclear energy" "school uniforms" 4 | do 5 | python train.py --task_name ukp-topic-sentence --do_train --do_eval --seed 1 --do_lower_case --binarize_labels 0 --data_dir ./datasets/ukp/data/complete/ --bert_model bert-base-uncased --max_seq_length 64 --train_batch_size 16 --test_set "$testset"
--learning_rate 2e-5 --num_train_epochs 2.0 --output_dir "bert_output/ukp/bert-base-topic-sentence/${testset}_test_topic/" 6 | 7 | done 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /argument-classification/train_ukp_all_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python train.py --task_name ukp-topic-sentence --do_train --use_all_data=1 --seed 1 --do_lower_case --binarize_labels 0 --data_dir ./datasets/ukp/data/complete/ --bert_model bert-base-uncased --max_seq_length 64 --train_batch_size 16 --test_set "" --learning_rate 2e-5 --num_train_epochs 2.0 --output_dir "bert_output/ukp/bert-base-topic-sentence/all_data/" 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /argument-classification/ukp_evaluation.py: -------------------------------------------------------------------------------- 1 | # This code runs the macro-weighted F1-computation for the UKP Sentential Argument Mining Corpus (Table 1 in the Paper Classification and Clustering of Arguments with Contextualized Word Embeddings) 2 | # Usage: python ukp_evaluation.py bert_output/ukp/bert-base-topic-sentence/*/test_predictions.txt 3 | # Expected output: 4 | # ... 5 | # ===================== Overall results ====================" 6 | # Average F1 score over all topics 7 | # test_predictions.txt (8) 61.69% 8 | # 9 | # Note: output values can change due to the non-determinism of GPU computations 10 | 11 | 12 | from __future__ import print_function 13 | from sklearn.metrics import recall_score, precision_score, f1_score 14 | from sklearn.utils.multiclass import unique_labels 15 | import numpy as np 16 | import sys 17 | import os 18 | 19 | def analyze_predictions(filepath): 20 | total_sent = 0 21 | correct_sent = 0 22 | count = {} 23 | 24 | y_true = [] 25 | y_pred = [] 26 | 27 | for line in open(filepath, encoding='utf8'): 28 | splits = line.strip().split("\t") 29 | gold = splits[0] 30 | pred = splits[1] 31 | 32 | total_sent += 1 33 | if gold == pred: 34 | correct_sent += 1 35 | 36 | if gold not in count: 37 | count[gold] = {} 38 | 39 | if pred not in count[gold]: 40 | count[gold][pred] = 0 41 | 42 | count[gold][pred] += 1 43 | 44 | y_true.append(gold) 45 | y_pred.append(pred) 46 | 47 | print("gold - pred - Confusion Matrix") 48 | for gold_label in sorted(count.keys()): 49 | for pred_label in sorted(count[gold_label].keys()): 50 | print("%s - %s: %d" % (gold_label, pred_label, count[gold_label][pred_label])) 51 | 52 | 53 | print(":: BERT ::") 54 | print("Acc: %.2f%%" % (correct_sent/total_sent*100) ) 55 | labels = unique_labels(y_true, y_pred) 56 | prec = precision_score(y_true, y_pred, average=None) 57 | rec = recall_score(y_true, y_pred, average=None) 58 | f1 = f1_score(y_true, y_pred, average=None) 59 | 60 | arg_f1 = [] 61 | for idx, label in enumerate(labels): 62 | print("\n:: F1 for "+label+" ::") 63 | print("Prec: %.2f%%" % (prec[idx]*100)) 64 | print("Recall: %.2f%%" % (rec[idx]*100)) 65 | print("F1: %.2f%%" % (f1[idx]*100)) 66 | 67 | if label in labels: 68 | if label != 'NoArgument': 69 | arg_f1.append(f1[idx]) 70 | 71 | 72 | print("\n:: Macro Weighted for all ::") 73 | print("F1: %.2f%%" % (np.mean(f1)*100)) 74 | 75 | prec_mapping = {key:value for key, value in zip(labels, prec)} 76 | rec_mapping = {key:value for key, value in zip(labels, rec)} 77 | return np.mean(f1), prec_mapping, rec_mapping 78 | 79 | results = {} 80 | prec_results = {} 81 | rec_results = {} 
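# (Aggregation: results[topic][filename] collects the macro-F1 per topic and prediction file;
#  prec_results / rec_results hold the per-label precision / recall mappings. Per-topic means
#  are printed first, then averaged across all topics for the overall scores below.)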
82 | for filepath in sys.argv[1:]: 83 | print("\n\n===================== "+filepath+" ====================") 84 | f1, prec, rec = analyze_predictions(filepath) 85 | 86 | folder = filepath.split('/')[-2] 87 | topic = folder.split('_')[0].split('(')[0] 88 | filename = os.path.basename(filepath) 89 | 90 | if topic not in results: 91 | results[topic] = {} 92 | prec_results[topic] = {} 93 | rec_results[topic] = {} 94 | 95 | if filename not in results[topic]: 96 | results[topic][filename] = [] 97 | prec_results[topic][filename] = [] 98 | rec_results[topic][filename] = [] 99 | 100 | 101 | results[topic][filename].append(f1) 102 | prec_results[topic][filename].append(prec) 103 | rec_results[topic][filename].append(rec) 104 | 105 | 106 | 107 | print("\n\n===================== Overall results ====================") 108 | model_f1 = {} 109 | model_prec = {} 110 | model_rec = {} 111 | 112 | for topic in sorted(results.keys()): 113 | print(topic) 114 | for filename in sorted(results[topic].keys()): 115 | topic_f1_mean = np.mean(results[topic][filename]) 116 | print("%s (%d): %.4f" % (filename, len(results[topic][filename]), topic_f1_mean)) 117 | 118 | if filename not in model_f1: 119 | model_f1[filename] = [] 120 | model_prec[filename] = [] 121 | model_rec[filename] = [] 122 | 123 | model_f1[filename].append(results[topic][filename]) 124 | for prec in prec_results[topic][filename]: 125 | model_prec[filename].append(prec) 126 | 127 | for rec in rec_results[topic][filename]: 128 | model_rec[filename].append(rec) 129 | print("") 130 | 131 | 132 | print("\n\n==========================================") 133 | print("Average F1 score over all topics") 134 | for filename in model_f1: 135 | print("%s (%d) %.2f%%" % (filename, len(model_f1[filename]), np.mean(model_f1[filename])*100)) 136 | 137 | 138 | print("\n\n==========================================") 139 | print("P_arg score over all topics") 140 | for filename in model_prec: 141 | prec_pos = [prec_result['Argument_for'] for prec_result in model_prec[filename]] 142 | print("P_arg+ %s (%d): %.4f" % (filename, len(prec_pos), np.mean(prec_pos))) 143 | 144 | prec_neg = [prec_result['Argument_against'] for prec_result in model_prec[filename]] 145 | print("P_arg- %s (%d): %.4f" % (filename, len(prec_neg), np.mean(prec_neg))) 146 | 147 | print("\n\n==========================================") 148 | print("R_arg score over all topics") 149 | for filename in model_rec: 150 | rec_pos = [rec_result['Argument_for'] for rec_result in model_rec[filename]] 151 | print("R_arg+ %s (%d): %.4f" % (filename, len(rec_pos), np.mean(rec_pos))) 152 | 153 | rec_neg = [rec_result['Argument_against'] for rec_result in model_rec[filename]] 154 | print("R_arg- %s (%d): %.4f" % (filename, len(rec_neg), np.mean(rec_neg))) -------------------------------------------------------------------------------- /argument-similarity/README.md: -------------------------------------------------------------------------------- 1 | # Argument Similarity 2 | This folder contains code to fine-tune BERT to estimate the similarity between two arguments. We fine-tune BERT either on the [UKP Argument Aspect Similarity Corpus](https://www.informatik.tu-darmstadt.de/ukp/research_6/data/argumentation_mining_1/ukp_argument_aspect_similarity_corpus/ukp_argument_aspect_similarity_corpus.en.jsp) or on the [Argument Facet Similarity (AFS) Corpus](https://nlds.soe.ucsc.edu/node/44) from Misra et al., 2016. 
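For a quick impression, below is a minimal scoring sketch condensed from `inference.py`; the `model_path` is an assumption and should point to one of the unzipped pre-trained models described in the Example section:

```
import torch
from pytorch_pretrained_bert.tokenization import BertTokenizer
from train import InputExample, convert_examples_to_features
from SigmoidBERT import SigmoidBERT

model_path = 'bert_output/ukp_aspects_all'  # assumed: unzipped pre-trained model (see Example section)
examples = [InputExample(text_a='Zoos save species from extinction and other dangers.',
                         text_b='Zoos produce helpful scientific research.', label=-1)]

tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=True)
features = convert_examples_to_features(examples, 64, tokenizer)  # max_seq_length = 64

input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)

model = SigmoidBERT.from_pretrained(model_path)
model.eval()
with torch.no_grad():
    scores = model(input_ids, segment_ids, input_mask)  # sigmoid outputs in [0, 1]
print(scores[:, 0])  # higher = more similar
```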
3 | 
4 | ## Setup
5 | For the setup, see the [README.md](https://github.com/UKPLab/acl2019-BERT-argument-classification-and-clustering/) in the main folder.
6 | 
7 | ## Example
8 | You can download two pre-trained models:
9 | - [argument_similarity_ukp_aspects_all.zip](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/models/argument_similarity_ukp_aspects_all.zip) - trained on the complete UKP Aspects Corpus
10 | - [argument_similarity_misra_all.zip](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/models/argument_similarity_misra_all.zip) - trained on the complete AFS corpus from Misra et al.
11 | 
12 | Download and unzip these models. In `inference.py`, update the `model_path` variable to point to your unzipped model:
13 | ```
14 | model_path = 'bert_output/ukp_aspects_all'
15 | ```
16 | 
17 | And then run it:
18 | ```
19 | python inference.py
20 | ```
21 | 
22 | The output should be something like this for the model trained on the UKP corpus:
23 | ```
24 | Predicted similarities (sorted by similarity):
25 | Sentence A: Eating meat is not cruel or unethical; it is a natural part of the cycle of life.
26 | Sentence B: It is cruel and unethical to kill animals for food when vegetarian options are available
27 | Similarity: 0.99436545
28 | 
29 | Sentence A: Zoos are detrimental to animals' physical health.
30 | Sentence B: Zoo confinement is psychologically damaging to animals.
31 | Similarity: 0.99386144
32 | 
33 | [...]
34 | 
35 | Sentence A: It is cruel and unethical to kill animals for food when vegetarian options are available
36 | Sentence B: Rising levels of human-produced gases released into the atmosphere create a greenhouse effect that traps heat and causes global warming.
37 | Similarity: 0.0057242378
38 | ```
39 | 
40 | With the Misra AFS model, the output should be something like this:
41 | ```
42 | Predicted similarities (sorted by similarity):
43 | Sentence A: Zoos are detrimental to animals' physical health.
44 | Sentence B: Zoo confinement is psychologically damaging to animals.
45 | Similarity: 0.8723387
46 | 
47 | Sentence A: Eating meat is not cruel or unethical; it is a natural part of the cycle of life.
48 | Sentence B: It is cruel and unethical to kill animals for food when vegetarian options are available
49 | Similarity: 0.77635074
50 | 
51 | [...]
52 | 
53 | Sentence A: Zoos produce helpful scientific research.
54 | Sentence B: Eating meat is not cruel or unethical; it is a natural part of the cycle of life.
55 | Similarity: 0.20616204
56 | ```
57 | 
58 | ## Training UKP Aspects Corpus
59 | Download the [UKP Argument Aspect Similarity Corpus](https://www.informatik.tu-darmstadt.de/ukp/research_6/data/argumentation_mining_1/ukp_argument_aspect_similarity_corpus/ukp_argument_aspect_similarity_corpus.en.jsp) and unzip it into the `datasets` folder, so that the file `datasets/ukp_aspect/UKP_ASPECT.tsv` exists.
60 | 
61 | In our experiments, we used 4-fold cross-topic validation. To generate the 4 splits, run `datasets/ukp_aspect/make_splits.py`. This generates 4 folders with respective train/dev/test.tsv files that can be used for training, tuning, and testing on each fold (see the layout sketch below).
62 | 
63 | Run `train_ukp.sh` to train on the UKP Aspects Corpus using this 4-fold cross-topic validation. `train_ukp_all.sh` fine-tunes BERT on all 28 topics of the UKP Aspects corpus, without any dev/test split.
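After running `make_splits.py`, the expected layout is (as written by the script):

```
datasets/ukp_aspect/splits/
├── 0/              # train.tsv, dev.tsv, test.tsv for fold 0
├── 1/
├── 2/
├── 3/
└── all_data.tsv    # all 28 topics in a single file
```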
64 | 
65 | ## Training Argument Facet Similarity (AFS) Corpus
66 | 
67 | The [Argument Facet Similarity (AFS) Corpus](https://nlds.soe.ucsc.edu/node/44) must be downloaded from that website and unzipped into the `datasets/misra/` folder, i.e., the file `datasets/misra/ArgPairs_DP.csv` should exist after unzipping the AFS corpus.
68 | 
69 | Run `train_misra.sh` to train on the Misra AFS Corpus. `train_misra_all.sh` fine-tunes BERT on all 3 topics of the AFS data, without any development or test set.
70 | 
71 | ## Performance
72 | See our paper ([Classification and Clustering of Arguments with Contextualized Word Embeddings](https://arxiv.org/abs/1906.09821)) for further details.
73 | 
74 | ### UKP Aspect Corpus
75 | The performance on the UKP Aspects Corpus is evaluated in a 4-fold cross-topic setup. See `evaluation_with_clustering.py` and `evaluation_without_clustering.py` to compute the performance scores (invocation sketch at the end of this README). In the paper, we achieved the following performances:
76 | 
77 | ![UKP Aspects Performance](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/images/table_UKP_Aspects_results.png)
78 | 
79 | ### AFS Corpus
80 | Misra et al., 2016, used 10-fold cross-validation. However, their setup has the drawback that the test data contains sentences that were already seen in the training set (only the specific combination of two sentences was unseen at test time).
81 | 
82 | Instead of 10-fold cross-validation, we propose cross-topic evaluation. This also allows estimating how well the model generalizes to new, unseen topics.
83 | 
84 | In the paper, we achieve the following correlation scores:
85 | 
86 | ![AFS Performance](https://public.ukp.informatik.tu-darmstadt.de/reimers/2019_acl-BERT-argument-classification-and-clustering/images/table_AFS_results.png)
87 | 
88 | Note: The code published here is a cleaner and nicer version of the code we used for the paper. Results you get from this published code are slightly different from what is reported in the paper, partly due to randomness, partly (maybe) due to the slight adaptations made for publication. Results achieved with this published implementation are usually slightly higher than what was published.
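As referenced in the Performance section above, both evaluation scripts can be invoked directly. Their `main()` functions currently assume prediction files under `bert_output/ukp/seed-1/splits/<split>/` and a fixed epoch; edit `bert_experiment` and `epoch` there if your training output differs:

```
python evaluation_without_clustering.py   # pairwise thresholding, threshold tuned on each dev split
python evaluation_with_clustering.py      # greedy hierarchical clustering on top of the pairwise scores
```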
89 | 
--------------------------------------------------------------------------------
/argument-similarity/SigmoidBERT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import pytorch_pretrained_bert.modeling
4 | 
5 | 
6 | class SigmoidBERT(pytorch_pretrained_bert.modeling.BertPreTrainedModel):
7 |     def __init__(self, config, num_labels=1):
8 |         super(SigmoidBERT, self).__init__(config)
9 |         assert num_labels==1
10 |         self.num_labels = num_labels
11 |         self.bert = pytorch_pretrained_bert.modeling.BertModel(config)
12 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
13 |         self.lin_layer = nn.Linear(config.hidden_size, num_labels)
14 |         self.apply(self.init_bert_weights)
15 | 
16 |     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
17 |         encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
18 | 
19 |         #sent_encoding = pooled_output
20 |         sent_encoding = encoded_layers[:, 0, :]  # [CLS] token of the last layer (instead of the pooled output)
21 | 
22 |         sent_encoding = self.lin_layer(sent_encoding)
23 |         logits = torch.sigmoid(sent_encoding)  # similarity score in [0, 1]
24 | 
25 |         if labels is not None:
26 |             loss_fct = nn.BCELoss()
27 |             loss = loss_fct(logits[:, 0], labels.view(-1))
28 |             return loss
29 |         else:
30 |             return logits
--------------------------------------------------------------------------------
/argument-similarity/datasets/ukp_aspect/make_splits.py:
--------------------------------------------------------------------------------
1 | """
2 | This script takes the UKP Argument Aspect Similarity Corpus (https://www.informatik.tu-darmstadt.de/ukp/research_6/data/argumentation_mining_1/ukp_argument_aspect_similarity_corpus/ukp_argument_aspect_similarity_corpus.en.jsp) and creates a 4-fold cross-topic split.
3 | 
4 | To run this script, download the corpus and unzip it so that this folder contains the UKP_ASPECT.tsv file.
5 | """
6 | 
7 | 
8 | import os
9 | 
10 | topic_splits = [
11 |     {
12 |         'train': ['Wind power', 'Nanotechnology', '3d printing', 'Cryptocurrency', 'Virtual reality', 'Gene editing', 'Public surveillance', 'Genetic diagnosis', 'Geoengineering', 'Gmo', 'Organ donation', 'Recycling', 'Offshore drilling', 'Robotic surgery', 'Cloud storing', 'Electric cars', 'Stem cell research'],
13 |         'dev': ['Hydrogen fuel cells', 'Electronic voting', 'Drones', 'Solar energy'],
14 |         'test': ['Tissue engineering', 'Big data', 'Fracking', 'Social networks', 'Net neutrality', 'Hydroelectric dams', 'Internet of things']
15 |     },
16 |     {
17 |         'train': ['Wind power', '3d printing', 'Cryptocurrency', 'Tissue engineering', 'Gene editing', 'Virtual reality', 'Big data', 'Fracking', 'Public surveillance', 'Genetic diagnosis', 'Hydroelectric dams', 'Drones', 'Gmo', 'Organ donation', 'Solar energy', 'Electronic voting', 'Electric cars'],
18 |         'dev': ['Net neutrality', 'Internet of things', 'Hydrogen fuel cells', 'Social networks'],
19 |         'test': ['Nanotechnology', 'Geoengineering', 'Recycling', 'Offshore drilling', 'Robotic surgery', 'Cloud storing', 'Stem cell research']
20 |     },
21 |     {
22 |         'train': ['Nanotechnology', 'Tissue engineering', 'Hydrogen fuel cells', 'Big data', 'Cloud storing', 'Fracking', 'Net neutrality', 'Hydroelectric dams', 'Geoengineering', 'Gmo', 'Recycling', 'Offshore drilling', 'Robotic surgery', 'Internet of things', 'Electronic voting', 'Genetic diagnosis', 'Stem cell research'],
23 |         'dev': ['Gene editing', 'Solar energy', 'Social networks', '3d printing'],
24 |         'test': ['Wind power', 'Cryptocurrency', 'Virtual reality', 'Public surveillance', 'Drones', 'Organ donation', 'Electric cars']
25 |     },
26 |     {
27 |         'train': ['Wind power', 'Nanotechnology', 'Tissue engineering', 'Virtual reality', 'Big data', 'Fracking', 'Public surveillance', 'Social networks', 'Net neutrality', 'Drones', 'Recycling', 'Organ donation', 'Offshore drilling', 'Robotic surgery', 'Cloud storing', 'Electric cars', 'Stem cell research'],
28 |         'dev': ['Cryptocurrency', 'Hydroelectric dams', 'Geoengineering', 'Internet of things'],
29 |         'test': ['3d printing', 'Gene editing', 'Hydrogen fuel cells', 'Gmo', 'Solar energy', 'Electronic voting', 'Genetic diagnosis']
30 |     }
31 | ]
32 | 
33 | sentences = {}
34 | 
35 | with open('UKP_ASPECT.tsv') as fIn:
36 |     next(fIn) #Skip header line
37 |     for line in fIn:
38 |         line = line.strip()
39 |         topic = line.split('\t')[0]
40 |         if topic not in sentences:
41 |             sentences[topic] = []
42 | 
43 |         sentences[topic].append(line)
44 | 
45 | 
46 | for split_idx in range(len(topic_splits)):
47 |     for dataset_split in ['train', 'dev', 'test']:
48 |         folder = os.path.join("splits", str(split_idx))
49 |         os.makedirs(folder, exist_ok=True)
50 | 
51 |         with open(os.path.join(folder, dataset_split+'.tsv'), 'w') as fOut:
52 |             for topic in topic_splits[split_idx][dataset_split]:
53 |                 for sentence in sentences[topic]:
54 |                     fOut.write(sentence+"\n")
55 |                     fOut.flush()
56 | 
57 | print("Splits created")
58 | 
59 | ## Create all_data.tsv
60 | with open(os.path.join('splits', 'all_data.tsv'), 'w') as fOut:
61 |     for topic in sentences:
62 |         for sentence in sentences[topic]:
63 |             fOut.write(sentence + "\n")
64 | 
65 | print("all_data.tsv created")
66 | print("topics:", sentences.keys())
--------------------------------------------------------------------------------
/argument-similarity/evaluation_with_clustering.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluates the performance on the UKP ASPECT Corpus with hierarchical clustering (Table 2 in our paper).
3 | 
4 | Greedy hierarchical clustering.
5 | Merges two clusters if the pairwise mean cluster similarity is larger than a threshold.
6 | Merges clusters with highest similarity first 7 | Uses dev set to determine the threshold for supervised systems 8 | """ 9 | import numpy as np 10 | import scipy 11 | import scipy.spatial.distance 12 | import csv 13 | import os 14 | from sklearn.metrics import f1_score 15 | from collections import defaultdict 16 | 17 | class VectorSimilarityScorer: 18 | def __init__(self, sentence_vectors): 19 | self.vector_lookup = {} 20 | 21 | for line in open(sentence_vectors): 22 | sentence, vector_str = line.strip().split('\t') 23 | vector = np.asarray(list(map(float, vector_str.split(" ")))) 24 | self.vector_lookup[sentence] = vector 25 | 26 | self.cache = defaultdict(dict) 27 | 28 | def get_similarity(self, sentence_a, sentence_b): 29 | if sentence_a not in self.cache or sentence_b not in self.cache[sentence_a]: 30 | vector_a = self.vector_lookup[sentence_a] 31 | vector_b = self.vector_lookup[sentence_b] 32 | cosine_sim = 1 - scipy.spatial.distance.cosine(vector_a, vector_b) 33 | self.cache[sentence_a][sentence_b] = cosine_sim 34 | self.cache[sentence_b][sentence_a] = cosine_sim 35 | return self.cache[sentence_b][sentence_a] 36 | 37 | 38 | class PairwisePredictionSimilarityScorer: 39 | def __init__(self, predictions_file): 40 | self.score_lookup = defaultdict(dict) 41 | for line in open(predictions_file): 42 | splits = line.strip().split('\t') 43 | score = float(splits[-1]) 44 | sentence_a = splits[0].strip() 45 | sentence_b = splits[1].strip() 46 | self.score_lookup[sentence_a][sentence_b] = score 47 | self.score_lookup[sentence_b][sentence_a] = score 48 | 49 | 50 | def get_similarity(self, sentence_a, sentence_b): 51 | return self.score_lookup[sentence_a][sentence_b] 52 | 53 | class PriorityQueue(object): 54 | def __init__(self): 55 | self.queue = [] 56 | 57 | def __str__(self): 58 | return ' '.join([str(i) for i in self.queue]) 59 | 60 | # for checking if the queue is empty 61 | def isEmpty(self): 62 | return len(self.queue) == 0 63 | 64 | # for inserting an element in the queue 65 | def insert(self, data): 66 | self.queue.append(data) 67 | 68 | # Removes all element addressing a cluster key 69 | def remove_clusters(self, cluster_key): 70 | i = 0 71 | while i < len(self.queue): 72 | ele = self.queue[i] 73 | if ele['cluster_a'] == cluster_key or ele['cluster_b'] == cluster_key: 74 | del self.queue[i] 75 | else: 76 | i += 1 77 | 78 | # for popping an element based on Priority 79 | def pop(self): 80 | max = 0 81 | for i in range(len(self.queue)): 82 | if self.queue[i]['cluster_sim'] > self.queue[max]['cluster_sim']: 83 | max = i 84 | item = self.queue[max] 85 | del self.queue[max] 86 | return item 87 | 88 | 89 | class HierachicalClustering: 90 | """ 91 | Simple clustering algorithm. Merges two clusters, if the cluster similarity is larger than the threshold. 92 | Highest similarities first. 
93 | """ 94 | def __init__(self, similarity_score_function, testfile, np_mode=np.mean): 95 | self.compute_similarity_score = similarity_score_function 96 | self.test_data, self.clusters = self.read_gold_data(testfile) 97 | self.np_mode = np_mode 98 | 99 | def read_gold_data(self, testfile): 100 | test_data = {} 101 | unique_sentences = {} 102 | 103 | with open(testfile, 'r') as csvfile: 104 | csvreader = csv.reader(csvfile, delimiter='\t', quotechar=None) 105 | for splits in csvreader: 106 | splits = map(str.strip, splits) 107 | topic, sentence_a, sentence_b, label = splits 108 | label_bin = '1' if label in ['SS', 'HS'] else '0' 109 | 110 | if topic not in test_data: 111 | test_data[topic] = [] 112 | 113 | test_data[topic].append({'topic': topic, 'sentence_a': sentence_a, 'sentence_b': sentence_b, 'label': label, 114 | 'label_bin': label_bin}) 115 | 116 | if topic not in unique_sentences: 117 | unique_sentences[topic] = set() 118 | 119 | unique_sentences[topic].add(sentence_a) 120 | unique_sentences[topic].add(sentence_b) 121 | 122 | cluster_info = {} 123 | for topic in unique_sentences: 124 | topic_sentences = unique_sentences[topic] 125 | cluster_info[topic] = {} 126 | for idx, sentence in enumerate(topic_sentences): 127 | cluster_info[topic][idx] = [sentence] 128 | 129 | return test_data, cluster_info 130 | 131 | 132 | 133 | def compute_cluster_sim(self, cluster_a, cluster_b): 134 | scores = [] 135 | for sentence_a in cluster_a: 136 | for sentence_b in cluster_b: 137 | scores.append(self.compute_similarity_score(sentence_a, sentence_b)) 138 | 139 | return self.np_mode(scores) 140 | 141 | def cluster_topics(self, threshold): 142 | for topic in self.clusters: 143 | #print("\nRun clustering for:", topic) 144 | topic_cluster = self.clusters[topic] 145 | self.run_clustering(topic_cluster, threshold) 146 | return self.clusters 147 | 148 | 149 | 150 | def run_clustering(self, clusters, threshold): 151 | queue = PriorityQueue() 152 | 153 | #Initial cluster sim computation 154 | cluster_ids = list(clusters.keys()) 155 | for i in range(0, len(cluster_ids)-1): 156 | for j in range(i+1, len(cluster_ids)): 157 | cluster_a = cluster_ids[i] 158 | cluster_b = cluster_ids[j] 159 | 160 | cluster_sim = self.compute_cluster_sim(clusters[cluster_a], clusters[cluster_b]) 161 | element = {'cluster_sim': cluster_sim, 'cluster_a': cluster_a, 'cluster_b': cluster_b} 162 | queue.insert(element) 163 | 164 | while not queue.isEmpty(): 165 | element = queue.pop() 166 | if element['cluster_sim'] <= threshold: 167 | break 168 | 169 | #print("Merge", element, "size_a:", len(clusters[element['cluster_a']]), "size_b:", len(clusters[element['cluster_b']])) 170 | #Merge cluster with highest sim 171 | self.merge_clusters(clusters, element['cluster_a'], element['cluster_b']) 172 | 173 | #Remove all element involving cluster_a or cluster_b 174 | queue.remove_clusters(element['cluster_a']) 175 | queue.remove_clusters(element['cluster_b']) 176 | 177 | #Recompute cluster sim for all clusters with cluster_a and cluster_b 178 | cluster_a = element['cluster_a'] 179 | for cluster_b in clusters.keys(): 180 | if cluster_a != cluster_b: 181 | cluster_sim = self.compute_cluster_sim(clusters[cluster_a], clusters[cluster_b]) 182 | element = {'cluster_sim': cluster_sim, 'cluster_a': cluster_a, 'cluster_b': cluster_b} 183 | queue.insert(element) 184 | 185 | def merge_clusters(self, clusters, key_a, key_b): 186 | clusters[key_a] += clusters[key_b] 187 | del clusters[key_b] 188 | 189 | 190 | ###################################### 191 | # 192 
| # Some help functions 193 | # 194 | ###################################### 195 | 196 | def get_clustering(similarity_function, testfile, threshold): 197 | cluster_alg = HierachicalClustering(similarity_function, testfile) 198 | clusters = cluster_alg.cluster_topics(threshold) 199 | return clusters 200 | 201 | def write_output_file(clusters, output_file): 202 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 203 | fOut = open(output_file, 'w') 204 | 205 | for topic in clusters: 206 | topic_cluster = clusters[topic] 207 | for cluster_id in topic_cluster: 208 | for sentence in topic_cluster[cluster_id]: 209 | fOut.write("\t".join([str(cluster_id), topic, sentence.replace("\n", " ").replace("\t", " ")])) 210 | fOut.write("\n") 211 | 212 | 213 | def evaluate(clusters, labels_file, print_scores=False): 214 | all_f1_means = [] 215 | all_f1_sim = [] 216 | all_f1_dissim = [] 217 | 218 | test_data = defaultdict(list) 219 | with open(labels_file, 'r') as csvfile: 220 | csvreader = csv.reader(csvfile, delimiter='\t', quotechar=None) 221 | for splits in csvreader: 222 | splits = map(str.strip, splits) 223 | label_topic, sentence_a, sentence_b, label = splits 224 | label_bin = '1' if label in ['SS', 'HS'] else '0' 225 | 226 | 227 | test_data[label_topic].append( 228 | {'topic': label_topic, 'sentence_a': sentence_a, 'sentence_b': sentence_b, 'label': label, 229 | 'label_bin': label_bin}) 230 | 231 | for topic in clusters: 232 | topic_cluster = clusters[topic] 233 | sentences_cluster_id = {} 234 | for cluster_id in topic_cluster: 235 | for sentence in topic_cluster[cluster_id]: 236 | sentences_cluster_id[sentence] = cluster_id 237 | 238 | topic_test_data = test_data[topic] 239 | 240 | 241 | y_true = np.zeros(len(topic_test_data)) 242 | y_pred = np.zeros(len(topic_test_data)) 243 | 244 | for idx, test_annotation in enumerate(topic_test_data): 245 | sentence_a = test_annotation['sentence_a'] 246 | sentence_b = test_annotation['sentence_b'] 247 | label = test_annotation['label_bin'] 248 | 249 | if label=='1': 250 | y_true[idx] = 1 251 | 252 | if sentences_cluster_id[sentence_a] == sentences_cluster_id[sentence_b]: 253 | y_pred[idx] = 1 254 | 255 | f_sim = f1_score(y_true, y_pred, pos_label=1) 256 | f_dissim = f1_score(y_true, y_pred, pos_label=0) 257 | f_mean = np.mean([f_sim, f_dissim]) 258 | all_f1_sim.append(f_sim) 259 | all_f1_dissim.append(f_dissim) 260 | all_f1_means.append(f_mean) 261 | 262 | if print_scores: 263 | print("F-Sim: %.2f%%" % (f_sim * 100)) 264 | print("F-Dissim: %.2f%%" % (f_dissim * 100)) 265 | print("F-Mean: %.2f%%" % (f_mean * 100)) 266 | acc = np.sum(y_true==y_pred) / len(y_true) 267 | print("Acc: %.2f%%" % (acc * 100)) 268 | 269 | return np.mean(all_f1_sim), np.mean(all_f1_dissim), np.mean(all_f1_means) 270 | 271 | 272 | 273 | 274 | 275 | ###################################### 276 | # 277 | # Functions for pairwise classification approaches 278 | # 279 | ###################################### 280 | def trained_pairwise_prediction_clustering(bert_experiment, epoch): 281 | 282 | print("Epoch:", epoch) 283 | 284 | all_f1_sim = [] 285 | all_f1_dissim = [] 286 | all_f1 = [] 287 | for split in [0, 1, 2, 3]: 288 | print("\n==================") 289 | print("Split:", split) 290 | dev_file = './datasets/ukp_aspect/splits/%d/dev.tsv' % (split) 291 | test_file = './datasets/ukp_aspect/splits/%d/test.tsv' % (split) 292 | output_file = None #'output/bert-base-uncased/%s/seed-%d/%d/test_clusters.tsv' % (transitive, seed, split) 293 | 294 | dev_sim_scorer = 
PairwisePredictionSimilarityScorer("%s/%d/dev_predictions_epoch_%d.tsv" % (bert_experiment, split, epoch)) 295 | test_sim_scorer = PairwisePredictionSimilarityScorer("%s/%d/test_predictions_epoch_%d.tsv" % (bert_experiment, split, epoch)) 296 | 297 | best_f1 = 0 298 | best_threshold = 0 299 | 300 | for threshold_int in range(0, 20): 301 | threshold = threshold_int / 20 302 | clusters = get_clustering(dev_sim_scorer.get_similarity, dev_file, threshold) 303 | f1_sim, f1_dissim, f1 = evaluate(clusters, dev_file) 304 | 305 | if f1 > best_f1: 306 | best_f1 = f1 307 | best_threshold = threshold 308 | 309 | print("Best threshold on dev:", best_threshold) 310 | 311 | # Evaluate on test 312 | clusters = get_clustering(test_sim_scorer.get_similarity, test_file, best_threshold) 313 | if output_file != None: 314 | write_output_file(clusters, output_file) 315 | f1_sim, f1_dissim, f1 = evaluate(clusters, test_file) 316 | 317 | all_f1_sim.append(f1_sim) 318 | all_f1_dissim.append(f1_dissim) 319 | all_f1.append(f1) 320 | 321 | print("Test-Performance on this split:") 322 | print("F-Mean: %.4f" % (f1)) 323 | print("F-sim: %.4f" % (f1_sim)) 324 | print("F-dissim: %.4f" % (f1_dissim)) 325 | 326 | print("\n\n=========== Averaged performance over all splits ==========") 327 | print("F-Mean: %.4f" % (np.mean(all_f1))) 328 | print("F-sim: %.4f" % (np.mean(all_f1_sim))) 329 | print("F-dissim: %.4f" % (np.mean(all_f1_dissim))) 330 | return np.mean(all_f1) 331 | 332 | 333 | def main(): 334 | bert_experiment = 'bert_output/ukp/seed-1/splits' 335 | trained_pairwise_prediction_clustering(bert_experiment, epoch=3) 336 | 337 | 338 | if __name__ == '__main__': 339 | main() -------------------------------------------------------------------------------- /argument-similarity/evaluation_without_clustering.py: -------------------------------------------------------------------------------- 1 | """ 2 | Computes the F1-scores without clustering (Table 2 in the paper). 
3 | """ 4 | import numpy as np 5 | import csv 6 | from sklearn.metrics import f1_score 7 | from collections import defaultdict 8 | 9 | 10 | class PairwisePredictionSimilarityScorer: 11 | def __init__(self, predictions_file): 12 | self.score_lookup = defaultdict(dict) 13 | for line in open(predictions_file): 14 | splits = line.strip().split('\t') 15 | score = float(splits[-1]) 16 | sentence_a = splits[0].strip() 17 | sentence_b = splits[1].strip() 18 | self.score_lookup[sentence_a][sentence_b] = score 19 | self.score_lookup[sentence_b][sentence_a] = score 20 | 21 | 22 | def get_similarity(self, sentence_a, sentence_b): 23 | return self.score_lookup[sentence_a][sentence_b] 24 | 25 | 26 | 27 | 28 | ###################################### 29 | # 30 | # Some help functions 31 | # 32 | ###################################### 33 | def evaluate(similarity_score_function, labels_file, threshold, print_scores=False): 34 | all_f1_means = [] 35 | all_f1_sim = [] 36 | all_f1_dissim = [] 37 | 38 | test_data = defaultdict(list) 39 | with open(labels_file, 'r') as csvfile: 40 | csvreader = csv.reader(csvfile, delimiter='\t', quotechar=None) 41 | for splits in csvreader: 42 | splits = map(str.strip, splits) 43 | label_topic, sentence_a, sentence_b, label = splits 44 | label_bin = '1' if label in ['SS', 'HS'] else '0' 45 | 46 | test_data[label_topic].append({'topic': label_topic, 'sentence_a': sentence_a, 'sentence_b': sentence_b, 'label': label, 47 | 'label_bin': label_bin}) 48 | 49 | for topic in test_data: 50 | topic_test_data = test_data[topic] 51 | y_true = np.zeros(len(topic_test_data)) 52 | y_pred = np.zeros(len(topic_test_data)) 53 | 54 | for idx, test_annotation in enumerate(topic_test_data): 55 | sentence_a = test_annotation['sentence_a'] 56 | sentence_b = test_annotation['sentence_b'] 57 | label = test_annotation['label_bin'] 58 | 59 | if label == '1': 60 | y_true[idx] = 1 61 | 62 | if similarity_score_function(sentence_a, sentence_b) > threshold: 63 | y_pred[idx] = 1 64 | 65 | 66 | 67 | 68 | f_sim = f1_score(y_true, y_pred, pos_label=1) 69 | f_dissim = f1_score(y_true, y_pred, pos_label=0) 70 | f_mean = np.mean([f_sim, f_dissim]) 71 | all_f1_sim.append(f_sim) 72 | all_f1_dissim.append(f_dissim) 73 | all_f1_means.append(f_mean) 74 | 75 | if print_scores: 76 | print("F-Sim: %.2f%%" % (f_sim * 100)) 77 | print("F-Dissim: %.2f%%" % (f_dissim * 100)) 78 | print("F-Mean: %.2f%%" % (f_mean * 100)) 79 | acc = np.sum(y_true==y_pred) / len(y_true) 80 | print("Acc: %.2f%%" % (acc * 100)) 81 | 82 | return np.mean(all_f1_sim), np.mean(all_f1_dissim), np.mean(all_f1_means) 83 | 84 | 85 | 86 | ###################################### 87 | # 88 | # Functions for pairwise classification approaches 89 | # 90 | ###################################### 91 | def trained_pairwise_prediction_clustering(bert_experiment, epoch): 92 | 93 | print("Epoch:", epoch) 94 | 95 | all_f1_sim = [] 96 | all_f1_dissim = [] 97 | all_f1 = [] 98 | for split in [0, 1, 2, 3]: 99 | print("\n==================") 100 | print("Split:", split) 101 | dev_file = './datasets/ukp_aspect/splits/%d/dev.tsv' % (split) 102 | test_file = './datasets/ukp_aspect/splits/%d/test.tsv' % (split) 103 | 104 | dev_sim_scorer = PairwisePredictionSimilarityScorer("%s/%d/dev_predictions_epoch_%d.tsv" % (bert_experiment, split, epoch)) 105 | test_sim_scorer = PairwisePredictionSimilarityScorer("%s/%d/test_predictions_epoch_%d.tsv" % (bert_experiment, split, epoch)) 106 | 107 | best_f1 = 0 108 | best_threshold = 0 109 | 110 | for threshold_int in range(0, 20): 111 | 
threshold = threshold_int / 20
112 |             f1_sim, f1_dissim, f1 = evaluate(dev_sim_scorer.get_similarity, dev_file, threshold)
113 | 
114 |             if f1 > best_f1:
115 |                 best_f1 = f1
116 |                 best_threshold = threshold
117 | 
118 |         print("Best threshold on dev:", best_threshold, "F1:", best_f1)
119 | 
120 |         # Evaluate on test
121 |         f1_sim, f1_dissim, f1 = evaluate(test_sim_scorer.get_similarity, test_file, best_threshold)
122 | 
123 |         all_f1_sim.append(f1_sim)
124 |         all_f1_dissim.append(f1_dissim)
125 |         all_f1.append(f1)
126 | 
127 |         print("Test-Performance on this split:")
128 |         print("F-Mean: %.4f" % (f1))
129 |         print("F-sim: %.4f" % (f1_sim))
130 |         print("F-dissim: %.4f" % (f1_dissim))
131 | 
132 | 
133 | 
134 |     print("\n\n=========== Averaged performance over all splits ==========")
135 |     print("F-Mean: %.4f" % (np.mean(all_f1)))
136 |     print("F-sim: %.4f" % (np.mean(all_f1_sim)))
137 |     print("F-dissim: %.4f" % (np.mean(all_f1_dissim)))
138 |     return np.mean(all_f1)
139 | 
140 | 
141 | def main():
142 |     bert_experiment = 'bert_output/ukp/seed-1/splits/'
143 |     trained_pairwise_prediction_clustering(bert_experiment, epoch=2)
144 | 
145 | 
146 | if __name__ == '__main__':
147 |     main()
--------------------------------------------------------------------------------
/argument-similarity/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | This code runs a fine-tuned BERT model to estimate the similarity between two arguments.
3 | 
4 | In this example, we include arguments from three topics (courtesy of www.procon.org): Zoos, Vegetarianism, Climate Change
5 | 
6 | Each argument is compared against every other argument. Arguments from the same topic, e.g. on zoos, should be ranked higher
7 | than arguments from different topics.
8 | 
9 | Usage: python inference.py
10 | """
11 | 
12 | from pytorch_pretrained_bert.tokenization import BertTokenizer
13 | import torch
14 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
15 | from train import InputExample, convert_examples_to_features
16 | from SigmoidBERT import SigmoidBERT
17 | 
18 | # See the README.md for where to download pre-trained models
19 | #model_path = 'bert_output/ukp_aspects_all' #ukp_aspects_all model: trained on all topics of the UKP Aspects corpus
20 | model_path = 'bert_output/misra_all' #misra_all model: trained on all 3 topics from Misra et al., 2016
21 | 
22 | 
23 | max_seq_length = 64
24 | eval_batch_size = 8
25 | 
26 | arguments = ['Zoos save species from extinction and other dangers.',
27 |              'Zoos produce helpful scientific research.',
28 |              'Zoos are detrimental to animals\' physical health.',
29 |              'Zoo confinement is psychologically damaging to animals.',
30 |              'Eating meat is not cruel or unethical; it is a natural part of the cycle of life. ',
31 |              'It is cruel and unethical to kill animals for food when vegetarian options are available',
32 |              'Overwhelming scientific consensus says human activity is primarily responsible for global climate change.',
33 |              'Rising levels of human-produced gases released into the atmosphere create a greenhouse effect that traps heat and causes global warming.'
34 |              ]
35 | 
36 | #Compare every argument with every other argument
37 | input_examples = []
38 | output_examples = []
39 | 
40 | for i in range(0, len(arguments)-1):
41 |     for j in range(i+1, len(arguments)):
42 |         input_examples.append(InputExample(text_a=arguments[i], text_b=arguments[j], label=-1))
43 |         output_examples.append([arguments[i], arguments[j]])
44 | 
45 | 
46 | tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=True)
47 | eval_features = convert_examples_to_features(input_examples, max_seq_length, tokenizer)
48 | 
49 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
50 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
51 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
52 | 
53 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
54 | eval_sampler = SequentialSampler(eval_data)
55 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size)
56 | 
57 | 
58 | 
59 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60 | model = SigmoidBERT.from_pretrained(model_path)
61 | model.to(device)
62 | model.eval()
63 | 
64 | predicted_logits = []
65 | with torch.no_grad():
66 |     for input_ids, input_mask, segment_ids in eval_dataloader:
67 |         input_ids = input_ids.to(device)
68 |         input_mask = input_mask.to(device)
69 |         segment_ids = segment_ids.to(device)
70 | 
71 |         logits = model(input_ids, segment_ids, input_mask)
72 |         logits = logits.detach().cpu().numpy()
73 |         predicted_logits.extend(logits[:, 0])
74 | 
75 | 
76 | 
77 | 
78 | for idx in range(len(predicted_logits)):
79 |     output_examples[idx].append(predicted_logits[idx])
80 | 
81 | #Sort by similarity
82 | output_examples = sorted(output_examples, key=lambda x: x[2], reverse=True)
83 | 
84 | print("Predicted similarities (sorted by similarity):")
85 | for idx in range(len(output_examples)):
86 |     example = output_examples[idx]
87 |     print("Sentence A:", example[0])
88 |     print("Sentence B:", example[1])
89 |     print("Similarity:", example[2])
90 |     print("")
91 | 
92 | 
--------------------------------------------------------------------------------
/argument-similarity/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """BERT finetuning runner."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | 
23 | import gzip
24 | import csv
25 | import os
26 | import logging
27 | import argparse
28 | import random
29 | from tqdm import tqdm, trange
30 | 
31 | import scipy.stats
32 | import numpy as np
33 | import torch
34 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
35 | from torch.utils.data.distributed import DistributedSampler
36 | import torch.nn.functional as F
37 | from pytorch_pretrained_bert.tokenization import BertTokenizer
38 | from pytorch_pretrained_bert.optimization import BertAdam
39 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
40 | from collections import defaultdict
41 | 
42 | from SigmoidBERT import SigmoidBERT
43 | 
44 | logging.basicConfig(format='%(message)s',
45 |                     datefmt='%m/%d/%Y %H:%M:%S',
46 |                     level=logging.INFO)
47 | logger = logging.getLogger(__name__)
48 | 
49 | 
50 | class InputExample(object):
51 |     """A single training/test example for simple sequence classification."""
52 | 
53 |     def __init__(self, text_a, text_b=None, label=None, guid=None):
54 |         """Constructs an InputExample.
55 | 
56 |         Args:
57 |             guid: Unique id for the example.
58 |             text_a: string. The untokenized text of the first sequence. For single
59 |                 sequence tasks, only this sequence must be specified.
60 |             text_b: (Optional) string. The untokenized text of the second sequence.
61 |                 Only must be specified for sequence pair tasks.
62 |             label: (Optional) string. The label of the example. This should be
63 |                 specified for train and dev examples, but not for test examples.
64 |         """
65 |         self.guid = guid
66 |         self.text_a = text_a
67 |         self.text_b = text_b
68 |         self.label = label
69 | 
70 | 
71 | class InputFeatures(object):
72 |     """A single set of features of data."""
73 | 
74 |     def __init__(self, input_ids, input_mask, segment_ids, label_id):
75 |         self.input_ids = input_ids
76 |         self.input_mask = input_mask
77 |         self.segment_ids = segment_ids
78 |         self.label_id = label_id
79 | 
80 | 
81 | 
82 | 
83 | class UKPAspectsProcessor(object):
84 |     def _read_dataset(self, filepath):
85 |         sentences = defaultdict(lambda: defaultdict(dict))
86 | 
87 |         with open(filepath, 'r') as fIn:
88 |             for line in fIn:
89 |                 splits = line.strip().split('\t')
90 |                 assert len(splits)==4
91 | 
92 |                 topic = splits[0].strip()
93 |                 sentence_a = splits[1].strip()
94 |                 sentence_b = splits[2].strip()
95 |                 label = splits[-1].strip()
96 | 
97 |                 #Binarize the label
98 |                 bin_label = 1 if label in ['SS', 'HS'] else 0
99 | 
100 |                 sentences[topic][sentence_a][sentence_b] = bin_label
101 |                 sentences[topic][sentence_b][sentence_a] = bin_label
102 | 
103 |         return sentences
104 | 
105 |     def get_examples(self, data_dir, train_file, dev_file, test_file, data_set):
106 |         """See base class."""
107 |         logging.info("Get "+ data_set+ " examples")
108 | 
109 |         if data_set == 'train':
110 |             data_file = os.path.join(data_dir, train_file)
111 |             return self._get_train_examples(data_file, data_set)
112 |         elif data_set == 'dev':
113 |             data_file = os.path.join(data_dir, dev_file)
114 |             return self._get_test_examples(data_file, data_set)
115 |         else:
116 |             data_file = os.path.join(data_dir, test_file)
117 |             return self._get_test_examples(data_file, data_set)
118 | 
119 |     def _get_train_examples(self, data_file, data_set):
120 |         sentences = self._read_dataset(data_file)
121 |         topics = list(sentences.keys())
122 | 
random.shuffle(topics) 123 | 124 | examples = [] 125 | logging.info("Topics: " + str(topics)) 126 | 127 | for topic in topics: 128 | for sentence_a in sentences[topic].keys(): 129 | for sentence_b in sentences[topic][sentence_a].keys(): 130 | guid = "%s-%d" % (data_set, len(examples)) 131 | label = sentences[topic][sentence_a][sentence_b] 132 | examples.append(InputExample(guid=guid, text_a=sentence_a, text_b=sentence_b, label=label)) 133 | 134 | return examples 135 | 136 | 137 | def _get_test_examples(self, data_file, data_set): 138 | sentences = self._read_dataset(data_file) 139 | topics = list(sentences.keys()) 140 | logging.info("Topics: "+str(topics)) 141 | 142 | examples = [] 143 | for topic in topics: 144 | unique_sentences = list(sentences[topic].keys()) 145 | for i in range(len(unique_sentences)-1): 146 | for j in range(i+1, len(unique_sentences)): 147 | guid = "%s-%d" % (data_set, len(examples)) 148 | sentence_a = unique_sentences[i] 149 | sentence_b = unique_sentences[j] 150 | label = -1 151 | 152 | if sentence_b in sentences[topic][sentence_a]: 153 | label = sentences[topic][sentence_a][sentence_b] 154 | 155 | examples.append(InputExample(guid=guid, text_a=sentence_a, text_b=sentence_b, label=label)) 156 | 157 | return examples 158 | 159 | 160 | 161 | class MisraProcessor(object): 162 | def get_examples(self, data_dir, train_topic, dev_topic, test_topic, data_set): 163 | topics = set(['DP', 'GC', 'GM']) 164 | 165 | 166 | if data_set == 'test': 167 | filepath = os.path.join(data_dir, 'ArgPairs_' + test_topic + '.csv') 168 | sentences = self._read_dataset(filepath, test_topic) 169 | return self._get_examples(sentences, data_set) 170 | elif data_set=='dev': 171 | filepath = os.path.join(data_dir, 'ArgPairs_' + dev_topic + '.csv') 172 | sentences = self._read_dataset(filepath, dev_topic) 173 | return self._get_examples(sentences, data_set) 174 | else: 175 | all_train_examples = [] 176 | 177 | for topic in topics: 178 | if topic == dev_topic or topic == test_topic: 179 | continue 180 | 181 | filepath = os.path.join(data_dir, 'ArgPairs_' + topic + '.csv') 182 | sentences = self._read_dataset(filepath, topic, add_symmetry = True) 183 | all_train_examples.extend(self._get_examples(sentences, data_set)) 184 | 185 | return all_train_examples 186 | 187 | 188 | def _read_dataset(self, filepath, topic, add_symmetry = False): 189 | logging.info("Read file: "+filepath) 190 | sentences = defaultdict(lambda: defaultdict(dict)) 191 | 192 | with open(filepath, 'r', encoding='iso-8859-1') as csvfile: 193 | csvreader = csv.reader(csvfile, delimiter=',', quotechar='"') 194 | headers = next(csvreader) 195 | 196 | for splits in csvreader: 197 | assert(len(splits)==11) 198 | label = float(splits[0].strip())/5 199 | sentence_a = splits[-1].strip() 200 | sentence_b = splits[-2].strip() 201 | 202 | sentences[topic][sentence_a][sentence_b] = label 203 | 204 | if add_symmetry: 205 | sentences[topic][sentence_b][sentence_a] = label 206 | 207 | return sentences 208 | 209 | 210 | 211 | def _get_examples(self, sentences, data_set): 212 | topics = list(sentences.keys()) 213 | examples = [] 214 | for topic in topics: 215 | for sentence_a in sentences[topic].keys(): 216 | for sentence_b in sentences[topic][sentence_a].keys(): 217 | guid = "%s-%d" % (data_set, len(examples)) 218 | label = sentences[topic][sentence_a][sentence_b] 219 | examples.append(InputExample(guid=guid, text_a=sentence_a, text_b=sentence_b, label=label)) 220 | 221 | return examples 222 | 223 | 224 | def get_test_examples(self, data_file, 
data_set):
225 |         return self.get_train_examples(data_file, data_set)  # NOTE: unused legacy helper; MisraProcessor provides get_examples() instead
226 | 
227 | 
228 | 
229 | 
230 | 
231 | 
232 | def convert_examples_to_features(examples, max_seq_length, tokenizer):
233 |     """Loads a data file into a list of `InputBatch`s."""
234 |     tokens_a_longer_max_seq_length = 0
235 |     features = []
236 |     for (ex_index, example) in enumerate(examples):
237 |         tokens_a = tokenizer.tokenize(example.text_a)
238 |         tokens_b = None
239 | 
240 |         len_tokens_a = len(tokens_a)
241 |         len_tokens_b = 0
242 | 
243 | 
244 | 
245 |         if example.text_b:
246 |             tokens_b = tokenizer.tokenize(example.text_b)
247 |             len_tokens_b = len(tokens_b)
248 |             # Modifies `tokens_a` and `tokens_b` in place so that the total
249 |             # length is less than the specified length.
250 |             # Account for [CLS], [SEP], [SEP] with "- 3"
251 |             _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
252 |         else:
253 |             # Account for [CLS] and [SEP] with "- 2"
254 |             if len(tokens_a) > max_seq_length - 2:
255 |                 tokens_a = tokens_a[:(max_seq_length - 2)]
256 | 
257 |         if (len_tokens_a + len_tokens_b) > (max_seq_length - 2):
258 |             tokens_a_longer_max_seq_length += 1
259 | 
260 |         # The convention in BERT is:
261 |         # (a) For sequence pairs:
262 |         #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
263 |         #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
264 |         # (b) For single sequences:
265 |         #  tokens:   [CLS] the dog is hairy . [SEP]
266 |         #  type_ids: 0   0   0   0  0     0 0
267 |         #
268 |         # Where "type_ids" are used to indicate whether this is the first
269 |         # sequence or the second sequence. The embedding vectors for `type=0` and
270 |         # `type=1` were learned during pre-training and are added to the wordpiece
271 |         # embedding vector (and position vector). This is not *strictly* necessary
272 |         # since the [SEP] token unambiguously separates the sequences, but it makes
273 |         # it easier for the model to learn the concept of sequences.
274 |         #
275 |         # For classification tasks, the first vector (corresponding to [CLS]) is
276 |         # used as the "sentence vector". Note that this only makes sense because
277 |         # the entire model is fine-tuned.
278 |         tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
279 |         segment_ids = [0] * len(tokens)
280 | 
281 |         if tokens_b:
282 |             tokens += tokens_b + ["[SEP]"]
283 |             segment_ids += [1] * (len(tokens_b) + 1)
284 | 
285 |         input_ids = tokenizer.convert_tokens_to_ids(tokens)
286 | 
287 |         # The mask has 1 for real tokens and 0 for padding tokens. Only real
288 |         # tokens are attended to.
289 |         input_mask = [1] * len(input_ids)
290 | 
291 |         # Zero-pad up to the sequence length.
292 | padding = [0] * (max_seq_length - len(input_ids)) 293 | input_ids += padding 294 | input_mask += padding 295 | segment_ids += padding 296 | 297 | assert len(input_ids)==max_seq_length 298 | assert len(input_mask)==max_seq_length 299 | assert len(segment_ids)==max_seq_length 300 | 301 | label_id = float(example.label) 302 | 303 | 304 | if ex_index < 1 and example.guid is not None and example.guid.startswith('train-'): 305 | logger.info("\n\n*** Example ***") 306 | logger.info("guid: %s" % (example.guid)) 307 | logger.info("tokens: %s" % " ".join( 308 | [str(x) for x in tokens])) 309 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 310 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 311 | logger.info( 312 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 313 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 314 | 315 | features.append( 316 | InputFeatures(input_ids=input_ids, 317 | input_mask=input_mask, 318 | segment_ids=segment_ids, 319 | label_id=label_id)) 320 | 321 | logger.info(":: Sentences longer than max_sequence_length: %d" % (tokens_a_longer_max_seq_length)) 322 | logger.info(":: Num sentences: %d" % (len(examples))) 323 | return features 324 | 325 | 326 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 327 | """Truncates a sequence pair in place to the maximum length.""" 328 | 329 | # This is a simple heuristic which will always truncate the longer sequence 330 | # one token at a time. This makes more sense than truncating an equal percent 331 | # of tokens from each, since if one sequence is very short then each token 332 | # that's truncated likely contains more information than a longer sequence. 333 | while True: 334 | total_length = len(tokens_a) + len(tokens_b) 335 | if total_length <= max_length: 336 | break 337 | if len(tokens_a) > len(tokens_b): 338 | tokens_a.pop() 339 | else: 340 | tokens_b.pop() 341 | 342 | 343 | def accuracy(predicted_logits, gold_labels): 344 | assert len(predicted_logits) == len(gold_labels) 345 | 346 | num_labels = 0 347 | num_correct = 0 348 | 349 | for predicted_logit, gold_label in zip(predicted_logits, gold_labels): 350 | if gold_label < 0: #Labels < 0 indicate non-existent labels 351 | continue 352 | 353 | num_labels += 1 354 | 355 | #Binarize gold and predicted label 356 | if (gold_label < 0.5 and predicted_logit < 0.5) or (gold_label >= 0.5 and predicted_logit >= 0.5): 357 | num_correct += 1 358 | 359 | 360 | return num_correct / num_labels 361 | 362 | 363 | def warmup_linear(x, warmup=0.002): 364 | if x < warmup: 365 | return x / warmup 366 | return 1.0 - x 367 | 368 | 369 | def main(): 370 | parser = argparse.ArgumentParser() 371 | 372 | ## Required parameters 373 | parser.add_argument("--data_dir", 374 | default=None, 375 | type=str, 376 | required=True, 377 | help="The input data dir.") 378 | parser.add_argument("--train_file", 379 | default=None, 380 | type=str) 381 | parser.add_argument("--dev_file", 382 | default=None, 383 | type=str) 384 | parser.add_argument("--test_file", 385 | default=None, 386 | type=str) 387 | 388 | 389 | parser.add_argument("--bert_model", default=None, type=str, required=True, 390 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 391 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 392 | "bert-base-multilingual-cased, bert-base-chinese.") 393 | parser.add_argument("--task_name", 394 | default=None, 395 | type=str, 396 | required=True, 397 | help="The name 
of the task to train.") 398 | parser.add_argument("--output_dir", 399 | default=None, 400 | type=str, 401 | required=True, 402 | help="The output directory where the model predictions and checkpoints will be written.") 403 | 404 | ## Other parameters 405 | parser.add_argument("--max_seq_length", 406 | default=128, 407 | type=int, 408 | help="The maximum total input sequence length after WordPiece tokenization. \n" 409 | "Sequences longer than this will be truncated, and sequences shorter \n" 410 | "than this will be padded.") 411 | parser.add_argument("--do_train", 412 | action='store_true', 413 | help="Whether to run training.") 414 | parser.add_argument("--do_eval", 415 | action='store_true', 416 | help="Whether to run eval on the dev set.") 417 | parser.add_argument("--do_lower_case", 418 | action='store_true', 419 | help="Set this flag if you are using an uncased model.") 420 | parser.add_argument("--train_batch_size", 421 | default=32, 422 | type=int, 423 | help="Total batch size for training.") 424 | parser.add_argument("--eval_batch_size", 425 | default=8, 426 | type=int, 427 | help="Total batch size for eval.") 428 | parser.add_argument("--learning_rate", 429 | default=5e-5, 430 | type=float, 431 | help="The initial learning rate for Adam.") 432 | parser.add_argument("--num_train_epochs", 433 | default=3.0, 434 | type=float, 435 | help="Total number of training epochs to perform.") 436 | parser.add_argument("--warmup_proportion", 437 | default=0.1, 438 | type=float, 439 | help="Proportion of training to perform linear learning rate warmup for. " 440 | "E.g., 0.1 = 10%% of training.") 441 | parser.add_argument("--no_cuda", 442 | action='store_true', 443 | help="Whether not to use CUDA when available") 444 | parser.add_argument("--local_rank", 445 | type=int, 446 | default=-1, 447 | help="local_rank for distributed training on gpus") 448 | parser.add_argument('--seed', 449 | type=int, 450 | default=42, 451 | help="random seed for initialization") 452 | parser.add_argument('--gradient_accumulation_steps', 453 | type=int, 454 | default=1, 455 | help="Number of updates steps to accumulate before performing a backward/update pass.") 456 | parser.add_argument('--fp16', 457 | action='store_true', 458 | help="Whether to use 16-bit float precision instead of 32-bit") 459 | parser.add_argument('--loss_scale', 460 | type=float, default=0, 461 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 462 | "0 (default value): dynamic loss scaling.\n" 463 | "Positive power of 2: static loss scaling value.\n") 464 | 465 | 466 | 467 | 468 | args = parser.parse_args() 469 | 470 | 471 | 472 | processors = { 473 | "ukp_aspects": UKPAspectsProcessor, 474 | "misra": MisraProcessor, 475 | } 476 | 477 | 478 | 479 | 480 | 481 | if args.local_rank==-1 or args.no_cuda: 482 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 483 | n_gpu = torch.cuda.device_count() 484 | else: 485 | torch.cuda.set_device(args.local_rank) 486 | device = torch.device("cuda", args.local_rank) 487 | n_gpu = 1 488 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 489 | torch.distributed.init_process_group(backend='nccl') 490 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 491 | device, n_gpu, bool(args.local_rank!=-1), args.fp16)) 492 | 493 | if args.gradient_accumulation_steps < 1: 494 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 495 | args.gradient_accumulation_steps)) 496 | 497 | args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) 498 | 499 | random.seed(args.seed) 500 | np.random.seed(args.seed) 501 | torch.manual_seed(args.seed) 502 | if n_gpu > 0: 503 | torch.cuda.manual_seed_all(args.seed) 504 | 505 | if not args.do_train and not args.do_eval: 506 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 507 | 508 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 509 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 510 | os.makedirs(args.output_dir, exist_ok=True) 511 | 512 | with open(os.path.join(args.output_dir, 'config.txt'), 'w') as fOut: 513 | fOut.write(str(args)) 514 | 515 | task_name = args.task_name.lower() 516 | 517 | if task_name not in processors: 518 | raise ValueError("Task not found: %s" % (task_name)) 519 | 520 | processor = processors[task_name]() 521 | num_labels = 1 522 | 523 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 524 | 525 | train_examples = None 526 | num_train_steps = None 527 | if args.do_train: 528 | train_examples = processor.get_examples(args.data_dir, args.train_file, args.dev_file, args.test_file, 'train') 529 | num_train_steps = int( 530 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) 531 | 532 | # Prepare model 533 | model = SigmoidBERT.from_pretrained(args.bert_model, cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank), num_labels=num_labels) 534 | if args.fp16: 535 | model.half() 536 | model.to(device) 537 | if args.local_rank!=-1: 538 | try: 539 | from apex.parallel import DistributedDataParallel as DDP 540 | except ImportError: 541 | raise ImportError( 542 | "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 543 | 544 | model = DDP(model) 545 | elif n_gpu > 1: 546 | model = torch.nn.DataParallel(model) 547 | 548 | # Prepare optimizer 549 | param_optimizer = list(model.named_parameters()) 550 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 551 | optimizer_grouped_parameters = [ 552 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 553 | {'params': [p for n, p in 
555 |     t_total = num_train_steps
556 |     if args.local_rank != -1:
557 |         t_total = t_total // torch.distributed.get_world_size()
558 |     if args.fp16:
559 |         try:
560 |             from apex.optimizers import FP16_Optimizer
561 |             from apex.optimizers import FusedAdam
562 |         except ImportError:
563 |             raise ImportError(
564 |                 "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
565 | 
566 |         optimizer = FusedAdam(optimizer_grouped_parameters,
567 |                               lr=args.learning_rate,
568 |                               bias_correction=False,
569 |                               max_grad_norm=1.0)
570 |         if args.loss_scale == 0:
571 |             optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
572 |         else:
573 |             optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
574 | 
575 |     else:
576 |         optimizer = BertAdam(optimizer_grouped_parameters,
577 |                              lr=args.learning_rate,
578 |                              warmup=args.warmup_proportion,
579 |                              t_total=t_total)
580 | 
581 |     global_step = 0
582 | 
583 | 
584 |     if args.do_train:
585 |         # Log the training pairs for reproducibility
586 |         with open(os.path.join(args.output_dir, "train_sentences.csv"), "w") as writer:
587 |             for idx, example in enumerate(train_examples):
588 |                 writer.write("%s\t%s\t%s\n" % (example.label, example.text_a, example.text_b))
589 | 
590 | 
591 |         train_features = convert_examples_to_features(train_examples, args.max_seq_length, tokenizer)
592 |         logger.info("\n\n***** Running training *****")
593 |         logger.info("  Num examples = %d", len(train_examples))
594 |         logger.info("  Batch size = %d", args.train_batch_size)
595 |         logger.info("  Num steps = %d", num_train_steps)
596 |         all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
597 |         all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
598 |         all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
599 |         all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)
600 |         train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
601 |         if args.local_rank == -1:
602 |             train_sampler = RandomSampler(train_data)
603 |         else:
604 |             train_sampler = DistributedSampler(train_data)
605 |         train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
606 | 
607 | 
608 |         model.train()
609 |         for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
610 |             tr_loss = 0
611 |             nb_tr_examples, nb_tr_steps = 0, 0
612 |             for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
613 |                 batch = tuple(t.to(device) for t in batch)
614 |                 input_ids, input_mask, segment_ids, label_ids = batch
615 |                 loss = model(input_ids, segment_ids, input_mask, label_ids)
616 |                 if n_gpu > 1:
617 |                     loss = loss.mean()  # mean() to average on multi-GPU
618 |                 if args.gradient_accumulation_steps > 1:
619 |                     loss = loss / args.gradient_accumulation_steps
620 | 
621 |                 if args.fp16:
622 |                     optimizer.backward(loss)
623 |                 else:
624 |                     loss.backward()
625 | 
626 |                 tr_loss += loss.item()
627 |                 nb_tr_examples += input_ids.size(0)
628 |                 nb_tr_steps += 1
629 |                 if (step + 1) % args.gradient_accumulation_steps == 0:
630 |                     # Modify the learning rate with the special warmup BERT uses
631 |                     lr_this_step = args.learning_rate * warmup_linear(global_step / t_total, args.warmup_proportion)
632 |                     for param_group in optimizer.param_groups:
633 |                         param_group['lr'] = lr_this_step
634 |                     optimizer.step()
635 |                     optimizer.zero_grad()
636 |                     global_step += 1
637 | 
638 |             # Dev set
639 |             if args.dev_file is not None:
640 |                 eval_set = 'dev'
641 |                 eval_results_filename = "%s_results_epoch_%d.txt" % (eval_set, epoch + 1)
642 |                 eval_prediction_filename = "%s_predictions_epoch_%d.tsv" % (eval_set, epoch + 1)
643 |                 do_evaluation(processor, args, tokenizer, model, device, global_step,
644 |                               eval_set, eval_results_filename, eval_prediction_filename)
645 | 
646 |             # Test set
647 |             if args.test_file is not None:
648 |                 eval_set = 'test'
649 |                 eval_results_filename = "%s_results_epoch_%d.txt" % (eval_set, epoch + 1)
650 |                 eval_prediction_filename = "%s_predictions_epoch_%d.tsv" % (eval_set, epoch + 1)
651 | 
652 |                 do_evaluation(processor, args, tokenizer, model, device, global_step,
653 |                               eval_set, eval_results_filename, eval_prediction_filename)
654 | 
655 |         # Save a trained model
656 |         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
657 | 
658 |         # If we save using the predefined names, we can load using `from_pretrained`
659 |         output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
660 |         output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
661 | 
662 |         torch.save(model_to_save.state_dict(), output_model_file)
663 |         model_to_save.config.to_json_file(output_config_file)
664 |         tokenizer.save_vocabulary(args.output_dir)
665 | 
666 |     model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
667 |     output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
668 |     if args.do_train:
669 |         torch.save(model_to_save.state_dict(), output_model_file)  # redundant with the WEIGHTS_NAME save above
670 | 
671 | def do_evaluation(processor, args, tokenizer, model, device, global_step, eval_set, eval_results_filename, eval_prediction_filename):
672 |     eval_examples = processor.get_examples(args.data_dir, args.train_file, args.dev_file, args.test_file, eval_set)
673 |     eval_features = convert_examples_to_features(eval_examples, args.max_seq_length, tokenizer)
674 |     logger.info("\n\n\n***** Running evaluation *****")
675 |     logger.info("  Num examples = %d", len(eval_examples))
676 |     logger.info("  Batch size = %d", args.eval_batch_size)
677 |     all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
678 |     all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
679 |     all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
680 |     all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)
681 |     eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
682 |     # Run prediction for full data
683 |     eval_sampler = SequentialSampler(eval_data)
684 |     eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
685 | 
686 |     model.eval()
687 | 
688 |     nb_eval_steps, nb_eval_examples = 0, 0
689 |     gold_labels = [f.label_id for f in eval_features]
690 | 
691 |     predicted_logits = []
692 |     for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
693 |         input_ids = input_ids.to(device)
694 |         input_mask = input_mask.to(device)
695 |         segment_ids = segment_ids.to(device)
696 | 
697 | 
698 |         with torch.no_grad():
699 |             logits = model(input_ids, segment_ids, input_mask)
700 | 
701 |         logits = logits.detach().cpu().numpy()
702 |         predicted_logits.extend(logits[:, 0])
703 | 
704 |         nb_eval_examples += input_ids.size(0)
705 |         nb_eval_steps += 1
706 | 
707 |     eval_accuracy = accuracy(predicted_logits, gold_labels)
708 | 
709 |     eval_spearman, eval_pearson = -999, -999  # sentinel values if the correlations cannot be computed
710 | 
711 |     try:
712 |         eval_spearman, _ = scipy.stats.spearmanr(gold_labels, predicted_logits)
713 |     except Exception:
714 |         pass
715 | 
716 |     try:
717 |         eval_pearson, _ = scipy.stats.pearsonr(gold_labels, predicted_logits)
718 |     except Exception:
719 |         pass
720 | 
721 |     result = {
722 |         'eval_accuracy': eval_accuracy,
723 |         'eval_spearman': eval_spearman,
724 |         'eval_pearson': eval_pearson,
725 |         'global_step': global_step,
726 |     }
727 | 
728 |     output_eval_file = os.path.join(args.output_dir, eval_results_filename)
729 |     with open(output_eval_file, "w") as writer:
730 |         logger.info("\n\n\n***** Eval results *****")
731 |         for key in sorted(result.keys()):
732 |             logger.info("  %s = %s", key, str(result[key]))
733 |             writer.write("%s = %s\n" % (key, str(result[key])))
734 | 
735 | 
736 |     output_pred_file = os.path.join(args.output_dir, eval_prediction_filename)
737 |     with open(output_pred_file, "w") as writer:
738 |         for idx, example in enumerate(eval_examples):
739 |             gold_label = example.label
740 |             pred_logits = predicted_logits[idx]
741 |             writer.write("\t".join([example.text_a.replace("\n", " ").replace("\t", " "), example.text_b.replace("\n", " ").replace("\t", " "), str(gold_label), str(pred_logits)]))
742 |             writer.write("\n")
743 | 
744 | 
745 | if __name__ == "__main__":
746 |     main()
747 | 
748 | 
749 | 
750 | 
751 | 
752 | 
--------------------------------------------------------------------------------
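
Note: the training loop above adjusts the learning rate by hand via warmup_linear (line 631). A minimal sketch of the schedule this relies on, mirroring the behavior of warmup_linear in older pytorch_pretrained_bert releases (an illustrative approximation, not the library source):

    def warmup_linear(x, warmup=0.002):
        # x is the fraction of training completed so far, i.e. global_step / t_total
        if x < warmup:
            return x / warmup  # linear ramp-up from 0 to 1 during the warmup phase
        return 1.0 - x         # then linear decay back towards 0

    # Example with values close to the defaults above: peak lr 5e-5, 10% warmup, 1000 steps
    peak_lr, total_steps = 5e-5, 1000
    for step in (50, 100, 500, 1000):
        print(step, peak_lr * warmup_linear(step / total_steps, warmup=0.1))
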
/argument-similarity/train_misra.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | for test_topic in "DP" "GC" "GM"
3 | do
4 |     python train.py --task_name misra --do_train --seed 2 --do_eval --do_lower_case --data_dir ./datasets/misra/ --test_file "${test_topic}" --bert_model bert-base-uncased --max_seq_length 64 --train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir "./bert_output/misra/${test_topic}"
5 | done
6 | 
7 | 
--------------------------------------------------------------------------------
/argument-similarity/train_misra_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | python train.py --task_name misra --do_train --seed 2 --do_eval --do_lower_case --data_dir ./datasets/misra/ --bert_model bert-base-uncased --max_seq_length 64 --train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir "./bert_output/misra_all/"
3 | 
4 | 
5 | 
--------------------------------------------------------------------------------
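
Note: because train.py saves with the predefined WEIGHTS_NAME/CONFIG_NAME, a trained model can be reloaded with from_pretrained from any of the output directories produced by these scripts; the repository's inference.py does this properly. Below is only a minimal scoring sketch, assuming SigmoidBERT's forward signature as used in train.py, the standard BERT pair encoding, and a hypothetical output directory and sentence pair:

    import torch
    from pytorch_pretrained_bert import BertTokenizer
    from SigmoidBERT import SigmoidBERT  # class/module names as used in this repository

    model_dir = "./bert_output/misra_all/"  # hypothetical; any trained output_dir
    tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=True)
    model = SigmoidBERT.from_pretrained(model_dir, num_labels=1)
    model.eval()

    # Hypothetical sentence pair; encode as [CLS] a [SEP] b [SEP] with segment ids 0/1
    text_a = "Cloning could help infertile couples."
    text_b = "Cloning may one day cure diseases."
    tokens = ["[CLS]"] + tokenizer.tokenize(text_a) + ["[SEP]"]
    segment_ids = [0] * len(tokens)
    tokens_b = tokenizer.tokenize(text_b) + ["[SEP]"]
    tokens += tokens_b
    segment_ids += [1] * len(tokens_b)

    input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
    input_mask = torch.ones_like(input_ids)
    segment_ids = torch.tensor([segment_ids])

    with torch.no_grad():
        logits = model(input_ids, segment_ids, input_mask)  # presumably a sigmoid similarity in [0, 1]
    print(logits[:, 0].item())
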
"test.tsv" --bert_model bert-base-uncased --max_seq_length 64 --train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir "./bert_output/ukp/seed-1/splits/${split}" 6 | done 7 | 8 | -------------------------------------------------------------------------------- /argument-similarity/train_ukp_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python train.py --task_name ukp_aspects --do_train --seed 1 --do_eval --do_lower_case --data_dir ./datasets/ukp_aspect/splits/ --train_file "all_data.tsv" --bert_model bert-base-uncased --max_seq_length 64 --train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir "./bert_output/ukp_aspects_all" 3 | 4 | 5 | --------------------------------------------------------------------------------