├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md └── src ├── __init__.py ├── dataloader.py ├── datasets.py ├── models.py ├── py.typed ├── requirements.txt └── train.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 
45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Source codes for ``Efficient Classification of Long Documents Using Transformers'' 2 | 3 | Please refer to our paper for more details and cite our paper if you find this repo useful: 4 | 5 | ``` 6 | @inproceedings{park-etal-2022-efficient, 7 | title = "Efficient Classification of Long Documents Using Transformers", 8 | author = "Park, Hyunji and 9 | Vyas, Yogarshi and 10 | Shah, Kashif", 11 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)", 12 | month = may, 13 | year = "2022", 14 | address = "Dublin, Ireland", 15 | publisher = "Association for Computational Linguistics", 16 | url = "https://aclanthology.org/2022.acl-short.79", 17 | doi = "10.18653/v1/2022.acl-short.79", 18 | pages = "702--709", 19 | } 20 | ``` 21 | 22 | ## Instructions 23 | 24 | ### 1. Install required libraries 25 | 26 | ``` 27 | pip install -r requirements.txt 28 | python -m spacy download en_core_web_sm 29 | ``` 30 | 31 | ### 2. 
Prepare the datasets 32 | 33 | #### Hyperpartisan News Detection 34 | 35 | * Available at https://zenodo.org/record/1489920 36 | * Download the datasets 37 | 38 | ``` 39 | mkdir data/hyperpartisan 40 | wget -P data/hyperpartisan/ https://zenodo.org/record/1489920/files/articles-training-byarticle-20181122.zip 41 | wget -P data/hyperpartisan/ https://zenodo.org/record/1489920/files/ground-truth-training-byarticle-20181122.zip 42 | unzip data/hyperpartisan/articles-training-byarticle-20181122.zip -d data/hyperpartisan 43 | unzip data/hyperpartisan/ground-truth-training-byarticle-20181122.zip -d data/hyperpartisan 44 | rm data/hyperpartisan/*zip 45 | ``` 46 | 47 | * Prepare the datasets with the resulting xml files and this preprocessing script (following [Longformer](https://arxiv.org/abs/2004.05150)): 48 | 49 | #### 20NewsGroups 50 | 51 | * Originally available at http://qwone.com/~jason/20Newsgroups/ 52 | * Running `train.py` with the `--data 20news` flag will download and prepare the data available via `sklearn.datasets` (following [CogLTX](https://proceedings.neurips.cc/paper/2020/file/96671501524948bc3937b4b30d0e57b9-Paper.pdf)). 53 | We adopt the train/dev/test split from [this ToBERT paper](https://ieeexplore.ieee.org/document/9003958). 54 | 55 | #### EURLEX-57K 56 | 57 | * Available at http://nlp.cs.aueb.gr/software_and_datasets/EURLEX57K/ 58 | * Download the datasets 59 | 60 | ``` 61 | mkdir data/EURLEX57K 62 | wget -O data/EURLEX57K/datasets.zip http://nlp.cs.aueb.gr/software_and_datasets/EURLEX57K/datasets.zip 63 | unzip data/EURLEX57K/datasets.zip -d data/EURLEX57K 64 | rm data/EURLEX57K/datasets.zip 65 | rm -rf data/EURLEX57K/__MACOSX 66 | mv data/EURLEX57K/dataset/* data/EURLEX57K 67 | rm -rf data/EURLEX57K/dataset 68 | wget -O data/EURLEX57K/EURLEX57K.json http://nlp.cs.aueb.gr/software_and_datasets/EURLEX57K/eurovoc_en.json 69 | ``` 70 | 71 | * Running `train.py` with the `--data eurlex` flag reads and prepares the data from `data/EURLEX57K/{train, dev, test}/*.json` files 72 | * Running `train.py` with the `--data eurlex --inverted` flag creates Inverted EURLEX data by inverting the order of the sections 73 | * `data/EURLEX57K/EURLEX57K.json` contains label information. 74 | 75 | #### CMU Book Summary Dataset 76 | 77 | * Available at http://www.cs.cmu.edu/~dbamman/booksummaries.html 78 | 79 | ``` 80 | wget -P data/ http://www.cs.cmu.edu/~dbamman/data/booksummaries.tar.gz 81 | tar -xf data/booksummaries.tar.gz -C data 82 | ``` 83 | 84 | * Running `train.py` with the `--data books` flag reads and prepares the data from `data/booksummaries/booksummaries.txt` 85 | * Running `train.py` with the `--data books --pairs` flag creates Paired Book Summary by combining pairs of summaries and their labels 86 | 87 | 88 | ### 3. Run the models 89 | 90 | ``` 91 | e.g. python train.py --model_name bertplusrandom --data books --pairs --batch_size 8 --epochs 20 --lr 3e-05 92 | ``` 93 | 94 | cf.
Note that we use the source code for the CogLTX model: 95 | 96 | ### Hyperparameters used 97 | 98 | #### Hyperpartisan 99 | 100 | | Parameter | BERT | BERT+TextRank | BERT+Random | Longformer | ToBERT | 101 | |------------|-------|---------------|-------------|---------------------------------------------------|--------| 102 | | Batch size | 8 | 8 | 8 | 16 | 8 | 103 | | Epochs | 20 | 20 | 20 | 20 | 20 | 104 | | LR | 3e-05 | 3e-05 | 5e-05 | 5e-05 | 5e-05 | 105 | | Scheduler | NA | NA | NA | [warmup](https://arxiv.org/abs/2004.05150) | NA | 106 | 107 | #### 20NewsGroups, Book Summary, Paired Book Summary 108 | 109 | | Parameter | BERT | BERT+TextRank | BERT+Random | Longformer | ToBERT | 110 | |------------|-------|---------------|-------------|---------------------------------------------------|--------| 111 | | Batch size | 8 | 8 | 8 | 16 | 8 | 112 | | Epochs | 20 | 20 | 20 | 20 | 20 | 113 | | LR | 3e-05 | 3e-05 | 3e-05 | 0.005 | 3e-05 | 114 | | Scheduler | NA | NA | NA | [warmup](https://arxiv.org/abs/2004.05150) | NA | 115 | 116 | #### EURLEX, Inverted EURLEX 117 | 118 | | Parameter | BERT | BERT+TextRank | BERT+Random | Longformer | ToBERT | 119 | |------------|-------|---------------|-------------|---------------------------------------------------|--------| 120 | | Batch size | 8 | 8 | 8 | 16 | 8 | 121 | | Epochs | 20 | 20 | 20 | 20 | 20 | 122 | | LR | 5e-05 | 5e-05 | 5e-05 | 0.005 | 5e-05 | 123 | | Scheduler | NA | NA | NA | [warmup](https://arxiv.org/abs/2004.05150) | NA | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Implement your code here. 2 | -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import glob 4 | import tqdm 5 | import jsonlines 6 | import pandas as pd 7 | import re 8 | from sklearn import preprocessing 9 | from sklearn.preprocessing import MultiLabelBinarizer 10 | from sklearn.datasets import fetch_20newsgroups 11 | from sklearn.model_selection import train_test_split 12 | from pytorch_lightning import seed_everything 13 | 14 | def read_hyperpartisan_data(hyper_file_path): 15 | """ 16 | Read a jsonl file for Hyperpartisan News Detection data and return lists of documents and labels 17 | :param hyper_file_path: path to a jsonl file 18 | :return: lists of documents and labels 19 | """ 20 | documents = [] 21 | labels = [] 22 | with jsonlines.open(hyper_file_path) as reader: 23 | for doc in tqdm.tqdm(reader): 24 | documents.append(doc['text']) 25 | labels.append(doc['label']) 26 | 27 | return documents, labels 28 | 29 | def prepare_hyperpartisan_data(hyper_path='./data/hyperpartisan'): 30 | """ 31 | Load the Hyperpartisan News Detection data and prepare the datasets 32 | :param hyper_path: path to the dataset files, {train, dev, test}.jsonl 33 | :return: dicts of lists of documents and labels and number of labels 34 | """ 35 | if not os.path.exists(hyper_path): 36 | raise Exception("Data path not found: {}".format(hyper_path)) 37 | 38 | text_set = {} 39 | label_set = {} 40 | for split in ['train', 'dev', 'test']: 41 | file_path = os.path.join(hyper_path, split + '.jsonl') 42 | text_set[split], label_set[split] = read_hyperpartisan_data(file_path) 43 | 44 | enc = preprocessing.LabelBinarizer() 45 | enc.fit(label_set['train']) 46 | num_labels = 1 # 
binary classification 47 | # vectorize labels as zeros and ones 48 | vectorized_labels = {} 49 | for split in ['train', 'dev', 'test']: 50 | vectorized_labels[split] = enc.transform(label_set[split]) 51 | return text_set, vectorized_labels, num_labels 52 | 53 | def clean_20news_data(text_str): 54 | """ 55 | Clean up 20NewsGroups text data, from CogLTX: https://github.com/Sleepychord/CogLTX/blob/main/20news/process_20news.py 56 | // SPDX-License-Identifier: MIT 57 | :param text_str: text string to clean up 58 | :return: clean text string 59 | """ 60 | tmp_doc = [] 61 | for words in text_str.split(): 62 | if ':' in words or '@' in words or len(words) > 60: 63 | pass 64 | else: 65 | c = re.sub(r'[>|-]', '', words) 66 | # c = words.replace('>', '').replace('-', '') 67 | if len(c) > 0: 68 | tmp_doc.append(c) 69 | tmp_doc = ' '.join(tmp_doc) 70 | tmp_doc = re.sub(r'\([A-Za-z \.]*[A-Z][A-Za-z \.]*\) ', '', tmp_doc) 71 | return tmp_doc 72 | 73 | def prepare_20news_data(): 74 | """ 75 | Load the 20NewsGroups datasets and split the original train set into train/dev sets 76 | :return: dicts of lists of documents and labels and number of labels 77 | """ 78 | text_set = {} 79 | label_set = {} 80 | test_set = fetch_20newsgroups(subset='test', random_state=21) 81 | text_set['test'] = [clean_20news_data(text) for text in test_set.data] 82 | label_set['test'] = test_set.target 83 | 84 | train_set = fetch_20newsgroups(subset='train', random_state=21) 85 | train_text = [clean_20news_data(text) for text in train_set.data] 86 | train_label = train_set.target 87 | 88 | # take 10% of the train set as the dev set 89 | text_set['train'], text_set['dev'], label_set['train'], label_set['dev'] = train_test_split(train_text, 90 | train_label, 91 | test_size=0.10, 92 | random_state=21) 93 | enc = preprocessing.LabelEncoder() 94 | enc.fit(label_set['train']) 95 | num_labels = len(enc.classes_) 96 | 97 | # vectorize labels as zeros and ones 98 | vectorized_labels = {} 99 | for split in ['train', 'dev', 'test']: 100 | vectorized_labels[split] = enc.transform(label_set[split]) 101 | 102 | return text_set, vectorized_labels, num_labels 103 | 104 | def prepare_eurlex_data(inverted, eur_path='./data/EURLEX57K'): 105 | """ 106 | Load EURLEX-57K dataset and prepare the datasets 107 | :param inverted: whether to invert the section order or not 108 | :param eur_path: path to the EURLEX files 109 | :return: dicts of lists of documents and labels and number of labels 110 | """ 111 | if not os.path.exists(eur_path): 112 | raise Exception("Data path not found: {}".format(eur_path)) 113 | 114 | text_set = {'train': [], 'dev': [], 'test': []} 115 | label_set = {'train': [], 'dev': [], 'test': []} 116 | 117 | for split in ['train', 'dev', 'test']: 118 | file_paths = glob.glob(os.path.join(eur_path, split, '*.json')) 119 | for file_path in tqdm.tqdm(sorted(file_paths)): 120 | text, tags = read_eurlex_file(file_path, inverted) 121 | text_set[split].append(text) 122 | label_set[split].append(tags) 123 | 124 | vectorized_labels, num_labels = vectorize_labels(label_set) 125 | 126 | return text_set, vectorized_labels, num_labels 127 | 128 | def read_eurlex_file(eur_file_path, inverted): 129 | """ 130 | Read each json file and return lists of documents and labels 131 | :param eur_file_path: path to a json file 132 | :param inverted: whether to invert the section order or not 133 | :return: list of documents and labels 134 | """ 135 | tags = [] 136 | with open(eur_file_path) as file: 137 | data = json.load(file) 138 | sections = [] 139 | text = 
'' 140 | if inverted: 141 | sections.extend(data['main_body']) 142 | sections.append(data['recitals']) 143 | sections.append(data['header']) 144 | 145 | else: 146 | sections.append(data['header']) 147 | sections.append(data['recitals']) 148 | sections.extend(data['main_body']) 149 | 150 | text = '\n'.join(sections) 151 | 152 | for concept in data['concepts']: 153 | tags.append(concept) 154 | 155 | return text, tags 156 | 157 | def parse_json_column(genre_data): 158 | """ 159 | Read genre information as a json string and convert it to a dict 160 | :param genre_data: genre data to be converted 161 | :return: dict of genre names 162 | """ 163 | try: 164 | return json.loads(genre_data) 165 | except Exception as e: 166 | return None # when genre information is missing 167 | 168 | def load_booksummaries_data(book_path): 169 | """ 170 | Load the Book Summary data and split it into train/dev/test sets 171 | :param book_path: path to the booksummaries.txt file 172 | :return: train, dev, test as pandas data frames 173 | """ 174 | book_df = pd.read_csv(book_path, sep='\t', names=["Wikipedia article ID", 175 | "Freebase ID", 176 | "Book title", 177 | "Author", 178 | "Publication date", 179 | "genres", 180 | "summary"], 181 | converters={'genres': parse_json_column}) 182 | book_df = book_df.dropna(subset=['genres', 'summary']) # remove rows missing any genres or summaries 183 | book_df['word_count'] = book_df['summary'].str.split().str.len() 184 | book_df = book_df[book_df['word_count'] >= 10] 185 | train = book_df.sample(frac=0.8, random_state=22) 186 | rest = book_df.drop(train.index) 187 | dev = rest.sample(frac=0.5, random_state=22) 188 | test = rest.drop(dev.index) 189 | return train, dev, test 190 | 191 | def prepare_book_summaries(pairs, book_path='data/booksummaries/booksummaries.txt'): 192 | """ 193 | Load the Book Summary data and prepare the datasets 194 | :param pairs: whether to combine pairs of documents or not 195 | :param book_path: path to the booksummaries.txt file 196 | :return: dicts of lists of documents and labels and number of labels 197 | """ 198 | if not os.path.exists(book_path): 199 | raise Exception("Data not found: {}".format(book_path)) 200 | 201 | text_set = {'train': [], 'dev': [], 'test': []} 202 | label_set = {'train': [], 'dev': [], 'test': []} 203 | train, dev, test = load_booksummaries_data(book_path) 204 | 205 | if not pairs: 206 | text_set['train'] = train['summary'].tolist() 207 | text_set['dev'] = dev['summary'].tolist() 208 | text_set['test'] = test['summary'].tolist() 209 | 210 | train_genres = train['genres'].tolist() 211 | label_set['train'] = [list(genre.values()) for genre in train_genres] 212 | dev_genres = dev['genres'].tolist() 213 | label_set['dev'] = [list(genre.values()) for genre in dev_genres] 214 | test_genres = test['genres'].tolist() 215 | label_set['test'] = [list(genre.values()) for genre in test_genres] 216 | else: 217 | train_temp = train['summary'].tolist() 218 | dev_temp = dev['summary'].tolist() 219 | test_temp = test['summary'].tolist() 220 | 221 | train_genres = train['genres'].tolist() 222 | train_genres_temp = [list(genre.values()) for genre in train_genres] 223 | dev_genres = dev['genres'].tolist() 224 | dev_genres_temp = [list(genre.values()) for genre in dev_genres] 225 | test_genres = test['genres'].tolist() 226 | test_genres_temp = [list(genre.values()) for genre in test_genres] 227 | 228 | for i in range(0, len(train_temp) - 1, 2): 229 | text_set['train'].append(train_temp[i] + train_temp[i+1]) 230 | 
label_set['train'].append(list(set(train_genres_temp[i] + train_genres_temp[i+1]))) 231 | 232 | for i in range(0, len(dev_temp) - 1, 2): 233 | text_set['dev'].append(dev_temp[i] + dev_temp[i+1]) 234 | label_set['dev'].append(list(set(dev_genres_temp[i] + dev_genres_temp[i+1]))) 235 | 236 | for i in range(0, len(test_temp) - 1, 2): 237 | text_set['test'].append(test_temp[i] + test_temp[i+1]) 238 | label_set['test'].append(list(set(test_genres_temp[i] + test_genres_temp[i+1]))) 239 | 240 | vectorized_labels, num_labels = vectorize_labels(label_set) 241 | return text_set, vectorized_labels, num_labels 242 | 243 | def vectorize_labels(all_labels): 244 | """ 245 | Combine labels across all data and reformat the labels e.g. [[1, 2], ..., [123, 343, 4] ] --> [[0, 1, 1, ... 0], ...] 246 | Only used for multi-label classification 247 | :param all_labels: dict with labels with keys 'train', 'dev', 'test' 248 | :return: dict of vectorized labels per split and total number of labels 249 | """ 250 | all_set = [] 251 | for split in all_labels: 252 | for labels in all_labels[split]: 253 | all_set.extend(labels) 254 | all_set = list(set(all_set)) 255 | 256 | mlb = MultiLabelBinarizer() 257 | mlb.fit([all_set]) 258 | num_labels = len(mlb.classes_) 259 | 260 | print(f'Total number of labels: {num_labels}') 261 | 262 | result = {} 263 | for split in all_labels: 264 | result[split] = mlb.transform(all_labels[split]) 265 | 266 | return result, num_labels 267 | 268 | if __name__ == "__main__": 269 | seed_everything(3456) 270 | hyper_text_set, hyper_label_set, hyper_num_labels = prepare_hyperpartisan_data() 271 | assert hyper_num_labels == 1 272 | assert len(hyper_text_set['train']) == len(hyper_label_set['train']) == 516 273 | assert len(hyper_text_set['dev']) == len(hyper_label_set['dev']) == 64 274 | assert len(hyper_text_set['test']) == len(hyper_label_set['test']) == 65 275 | news_text_set, news_label_set, news_num_labels = prepare_20news_data() 276 | assert news_num_labels == 20 277 | assert len(news_text_set['train']) == len(news_label_set['train']) == 10182 278 | assert len(news_text_set['dev']) == len(news_label_set['dev']) == 1132 279 | assert len(news_text_set['test']) == len(news_label_set['test']) == 7532 280 | eur_text_set, eur_label_set, eur_num_labels = prepare_eurlex_data(False) 281 | assert eur_num_labels == 4271 282 | assert len(eur_text_set['train']) == len(eur_label_set['train']) == 45000 283 | assert len(eur_text_set['dev']) == len(eur_label_set['dev']) == 6000 284 | assert len(eur_text_set['test']) == len(eur_label_set['test']) == 6000 285 | inverted_text_set, inverted_label_set, inverted_num_labels = prepare_eurlex_data(True) 286 | assert inverted_num_labels == 4271 287 | assert len(inverted_text_set['train']) == len(inverted_label_set['train']) == 45000 288 | assert len(inverted_text_set['dev']) == len(inverted_label_set['dev']) == 6000 289 | assert len(inverted_text_set['test']) == len(inverted_label_set['test']) == 6000 290 | book_text_set, book_label_set, book_num_labels = prepare_book_summaries(False) 291 | assert book_num_labels == 227 292 | assert len(book_text_set['train']) == len(book_label_set['train']) == 10230 293 | assert len(book_text_set['dev']) == len(book_label_set['dev']) == 1279 294 | assert len(book_text_set['test']) == len(book_label_set['test']) == 1279 295 | pair_text_set, pair_label_set, pair_num_labels = prepare_book_summaries(True) 296 | assert pair_num_labels == 227 297 | assert len(pair_text_set['train']) == len(pair_label_set['train']) == 5115 298 | assert 
len(pair_text_set['dev']) == len(pair_label_set['dev']) == 639 299 | assert len(pair_text_set['test']) == len(pair_label_set['test']) == 639 300 | 301 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import spacy 3 | import pytextrank 4 | import random 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | class TruncatedDataset(Dataset): 9 | def __init__(self, text, labels, tokenizer, max_len): 10 | self.tokenizer = tokenizer 11 | self.text = text 12 | self.labels = labels 13 | self.max_len = max_len 14 | 15 | def __len__(self): 16 | return len(self.text) 17 | 18 | def __getitem__(self, index): 19 | text = str(self.text[index]) 20 | text = " ".join(text.split()) 21 | 22 | inputs = self.tokenizer.encode_plus( 23 | text=text, 24 | text_pair=None, 25 | add_special_tokens=True, 26 | max_length=self.max_len, 27 | truncation=True, 28 | padding='max_length', 29 | return_token_type_ids=True 30 | ) 31 | ids = inputs['input_ids'] 32 | mask = inputs['attention_mask'] 33 | token_type_ids = inputs["token_type_ids"] 34 | 35 | return { 36 | 'ids': torch.tensor(ids), 37 | 'mask': torch.tensor(mask), 38 | 'token_type_ids': torch.tensor(token_type_ids), 39 | 'labels': torch.tensor(self.labels[index]) 40 | } 41 | 42 | class TruncatedPlusTextRankDataset(Dataset): 43 | def __init__(self, text, labels, tokenizer, max_len): 44 | self.tokenizer = tokenizer 45 | self.text = text 46 | self.labels = labels 47 | self.max_len = max_len 48 | 49 | def __len__(self): 50 | return len(self.text) 51 | 52 | def apply_textrank(self, text): 53 | nlp = spacy.load("en_core_web_sm") 54 | nlp.add_pipe("textrank") 55 | doc = nlp(text) 56 | num_phrases = len(list(doc._.phrases)) 57 | num_sents = len(list(doc.sents)) 58 | tr = doc._.textrank 59 | running_length = 0 60 | key_sents_idx = [] 61 | key_sents = [] 62 | for sentence in tr.summary(limit_phrases=num_phrases, limit_sentences=num_sents, preserve_order=False): 63 | if running_length <= (self.max_len - 2): 64 | sentence_str = str(sentence) 65 | sentence_tokens = self.tokenizer.tokenize(sentence_str) 66 | running_length += len(sentence_tokens) 67 | key_sents.append(sentence_str) 68 | key_sents_idx.append(sentence.sent.start) 69 | 70 | reorder_idx = list(np.argsort(key_sents_idx)) 71 | selected_text = '' 72 | for idx in reorder_idx: 73 | selected_text += key_sents[idx] + ' ' 74 | return selected_text 75 | 76 | def __getitem__(self, index): 77 | text = str(self.text[index]) 78 | text = " ".join(text.split()) 79 | 80 | inputs = self.tokenizer.encode_plus( 81 | text=text, 82 | text_pair=None, 83 | add_special_tokens=True, 84 | max_length=self.max_len, 85 | truncation=True, 86 | padding='max_length', 87 | return_attention_mask=True, 88 | return_token_type_ids=True, 89 | return_overflowing_tokens=True 90 | ) 91 | 92 | if inputs.get("overflowing_tokens"): 93 | # select key sentences if text is longer than max length 94 | selected_text = self.apply_textrank(text) 95 | 96 | second_inputs = self.tokenizer.encode_plus( 97 | text=selected_text, 98 | text_pair=None, 99 | add_special_tokens=True, 100 | max_length=self.max_len, 101 | truncation=True, 102 | padding='max_length', 103 | return_attention_mask=True, 104 | return_token_type_ids=True, 105 | return_overflowing_tokens=True 106 | ) 107 | else: 108 | second_inputs = inputs 109 | 110 | ids = (inputs['input_ids'], second_inputs['input_ids']) 111 | mask = 
(inputs['attention_mask'], second_inputs['attention_mask']) 112 | token_type_ids = (inputs["token_type_ids"], second_inputs["token_type_ids"]) 113 | 114 | return { 115 | 'ids': torch.tensor(ids), 116 | 'mask': torch.tensor(mask), 117 | 'token_type_ids': torch.tensor(token_type_ids), 118 | 'labels': torch.tensor(self.labels[index]) 119 | } 120 | 121 | class TruncatedPlusRandomDataset(Dataset): 122 | def __init__(self, text, labels, tokenizer, max_len): 123 | self.tokenizer = tokenizer 124 | self.text = text 125 | self.labels = labels 126 | self.max_len = max_len 127 | 128 | def __len__(self): 129 | return len(self.text) 130 | 131 | def select_random_sents(self, text): 132 | nlp = spacy.load("en_core_web_sm") 133 | doc = nlp(text) 134 | sents = list(doc.sents) 135 | running_length = 0 136 | sent_idxs = list(range(len(sents))) 137 | selected_idx = [] 138 | while running_length <= (self.max_len - 2) and sent_idxs: 139 | idx = random.choice(sent_idxs) 140 | sent_idxs.remove(idx) 141 | sentence = str(sents[idx]) 142 | sentence_tokens = self.tokenizer.tokenize(sentence) 143 | running_length += len(sentence_tokens) 144 | selected_idx.append(idx) 145 | 146 | reorder_idx = sorted(selected_idx) 147 | selected_text = '' 148 | for idx in reorder_idx: 149 | selected_text += str(sents[idx]) + ' ' 150 | return selected_text 151 | 152 | def __getitem__(self, index): 153 | text = str(self.text[index]) 154 | text = " ".join(text.split()) 155 | 156 | inputs = self.tokenizer.encode_plus( 157 | text=text, 158 | text_pair=None, 159 | add_special_tokens=True, 160 | max_length=self.max_len, 161 | truncation=True, 162 | padding='max_length', 163 | return_attention_mask=True, 164 | return_token_type_ids=True, 165 | return_overflowing_tokens=True 166 | ) 167 | 168 | if inputs.get("overflowing_tokens"): 169 | # select random sentences if text is longer than max length 170 | selected_text = self.select_random_sents(text) 171 | second_inputs = self.tokenizer.encode_plus( 172 | text=selected_text, 173 | text_pair=None, 174 | add_special_tokens=True, 175 | max_length=self.max_len, 176 | truncation=True, 177 | padding='max_length', 178 | return_attention_mask=True, 179 | return_token_type_ids=True, 180 | return_overflowing_tokens=True 181 | ) 182 | else: 183 | second_inputs = inputs 184 | 185 | ids = (inputs['input_ids'], second_inputs['input_ids']) 186 | mask = (inputs['attention_mask'], second_inputs['attention_mask']) 187 | token_type_ids = (inputs["token_type_ids"], second_inputs["token_type_ids"]) 188 | 189 | return { 190 | 'ids': torch.tensor(ids), 191 | 'mask': torch.tensor(mask), 192 | 'token_type_ids': torch.tensor(token_type_ids), 193 | 'labels': torch.tensor(self.labels[index]) 194 | } 195 | 196 | class ChunkDataset(Dataset): 197 | def __init__(self, text, labels, tokenizer, chunk_len=200, overlap_len=50): 198 | self.tokenizer = tokenizer 199 | self.text = text 200 | self.labels = labels 201 | self.overlap_len = overlap_len 202 | self.chunk_len = chunk_len 203 | 204 | def __len__(self): 205 | return len(self.labels) 206 | 207 | def chunk_tokenizer(self, tokenized_data, targets): 208 | input_ids_list = [] 209 | attention_mask_list = [] 210 | token_type_ids_list = [] 211 | targets_list = [] 212 | 213 | previous_input_ids = tokenized_data["input_ids"] 214 | previous_attention_mask = tokenized_data["attention_mask"] 215 | previous_token_type_ids = tokenized_data["token_type_ids"] 216 | remain = tokenized_data.get("overflowing_tokens") 217 | 218 | input_ids_list.append(torch.tensor(previous_input_ids, 
dtype=torch.long)) 219 | attention_mask_list.append(torch.tensor(previous_attention_mask, dtype=torch.long)) 220 | token_type_ids_list.append(torch.tensor(previous_token_type_ids, dtype=torch.long)) 221 | targets_list.append(torch.tensor(targets, dtype=torch.long)) 222 | 223 | if remain: # if there is any overflowing tokens 224 | # remain = torch.tensor(remain, dtype=torch.long) 225 | idxs = range(len(remain) + self.chunk_len) 226 | idxs = idxs[(self.chunk_len - self.overlap_len - 2) 227 | ::(self.chunk_len - self.overlap_len - 2)] 228 | input_ids_first_overlap = previous_input_ids[-( 229 | self.overlap_len + 1):-1] 230 | start_token = [101] 231 | end_token = [102] 232 | 233 | for i, idx in enumerate(idxs): 234 | if i == 0: 235 | input_ids = input_ids_first_overlap + remain[:idx] 236 | elif i == len(idxs): 237 | input_ids = remain[idx:] 238 | elif previous_idx >= len(remain): 239 | break 240 | else: 241 | input_ids = remain[(previous_idx - self.overlap_len):idx] 242 | 243 | previous_idx = idx 244 | 245 | nb_token = len(input_ids) + 2 246 | attention_mask = np.ones(self.chunk_len) 247 | attention_mask[nb_token:self.chunk_len] = 0 248 | token_type_ids = np.zeros(self.chunk_len) 249 | input_ids = start_token + input_ids + end_token 250 | if self.chunk_len - nb_token > 0: 251 | padding = np.zeros(self.chunk_len - nb_token) 252 | input_ids = np.concatenate([input_ids, padding]) 253 | 254 | input_ids_list.append(torch.tensor(input_ids, dtype=torch.long)) 255 | attention_mask_list.append(torch.tensor(attention_mask, dtype=torch.long)) 256 | token_type_ids_list.append(torch.tensor(token_type_ids, dtype=torch.long)) 257 | targets_list.append(torch.tensor(targets, dtype=torch.long)) 258 | 259 | return ({ 260 | 'ids': input_ids_list, 261 | 'mask': attention_mask_list, 262 | 'token_type_ids': token_type_ids_list, 263 | 'targets': targets_list, 264 | 'len': [torch.tensor(len(targets_list), dtype=torch.long)] 265 | }) 266 | 267 | def __getitem__(self, index): 268 | text = " ".join(str(self.text[index]).split()) 269 | targets = self.labels[index] 270 | 271 | data = self.tokenizer.encode_plus( 272 | text=text, 273 | text_pair=None, 274 | add_special_tokens=True, 275 | max_length=self.chunk_len, 276 | truncation=True, 277 | pad_to_max_length=True, 278 | return_token_type_ids=True, 279 | return_overflowing_tokens=True 280 | ) 281 | 282 | chunk_token = self.chunk_tokenizer(data, targets) 283 | return chunk_token -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import LongformerModel, BertModel 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class BERTPlus(torch.nn.Module): 7 | def __init__(self, dropout_rate, num_labels): 8 | super(BERTPlus, self).__init__() 9 | self.bert = BertModel.from_pretrained('bert-base-uncased') 10 | self.dropout = torch.nn.Dropout(dropout_rate) 11 | self.classifier = torch.nn.Linear(768*2, num_labels) 12 | 13 | def forward(self, ids, mask, token_type_ids): 14 | _, truncated_output = self.bert(ids[:, 0,:], attention_mask=mask[:, 0,:], token_type_ids=token_type_ids[:, 0,:], return_dict=False) 15 | _, additional_text_output = self.bert(ids[:, 1,:], attention_mask=mask[:, 1,:], token_type_ids=token_type_ids[:, 1,:], return_dict=False) 16 | concat_output = torch.cat((truncated_output, additional_text_output), dim=1) # batch_size, 768*2 17 | drop_output = self.dropout(concat_output) # batch_size, 768*2 18 
| logits = self.classifier(drop_output) # batch_size, num_labels 19 | return logits 20 | 21 | class BERTClass(torch.nn.Module): 22 | def __init__(self, dropout_rate, num_labels): 23 | super(BERTClass, self).__init__() 24 | self.bert = BertModel.from_pretrained('bert-base-uncased') 25 | self.dropout = torch.nn.Dropout(dropout_rate) 26 | self.classifier = torch.nn.Linear(768, num_labels) 27 | 28 | def forward(self, ids, mask, token_type_ids): 29 | _, bert_output = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False) 30 | drop_output = self.dropout(bert_output) 31 | logits = self.classifier(drop_output) 32 | return logits 33 | 34 | class LongformerClass(torch.nn.Module): 35 | def __init__(self, num_labels): 36 | super(LongformerClass, self).__init__() 37 | self.longformer = LongformerModel.from_pretrained('allenai/longformer-base-4096', add_pooling_layer=False, 38 | gradient_checkpointing=True) 39 | self.classifier = LongformerClassificationHead(hidden_size=768, hidden_dropout_prob=0.1, num_labels=num_labels) 40 | 41 | def forward(self, ids, mask, token_type_ids): 42 | # Initialize global attention on CLS token 43 | global_attention_mask = torch.zeros_like(ids) 44 | global_attention_mask[:, 0] = 1 45 | sequence_output, _ = self.longformer(ids, attention_mask=mask, global_attention_mask=global_attention_mask, 46 | token_type_ids=token_type_ids, return_dict=False) 47 | logits = self.classifier(sequence_output) 48 | return logits 49 | 50 | class LongformerClassificationHead(torch.nn.Module): 51 | # This class is from https://huggingface.co/transformers/_modules/transformers/models/longformer 52 | # /modeling_longformer.html#LongformerForSequenceClassification 53 | """Head for sentence-level classification tasks.""" 54 | 55 | def __init__(self, hidden_size, hidden_dropout_prob, num_labels): # config from transformers.LongformerConfig.from_pretrained('allenai/longformer-base-4096') 56 | super().__init__() 57 | self.dense = torch.nn.Linear(hidden_size, hidden_size) 58 | self.dropout = torch.nn.Dropout(hidden_dropout_prob) 59 | self.out_proj = torch.nn.Linear(hidden_size, num_labels) 60 | 61 | def forward(self, hidden_states, **kwargs): 62 | hidden_states = hidden_states[:, 0, :] # take token (equiv. 
to [CLS]) 63 | hidden_states = self.dropout(hidden_states) 64 | hidden_states = self.dense(hidden_states) 65 | hidden_states = torch.tanh(hidden_states) 66 | hidden_states = self.dropout(hidden_states) 67 | output = self.out_proj(hidden_states) 68 | return output 69 | 70 | class ToBERTModel(torch.nn.Module): 71 | def __init__(self, num_labels, device): 72 | super(ToBERTModel, self).__init__() 73 | self.bert = BertModel.from_pretrained('bert-base-uncased') 74 | self.trans = torch.nn.TransformerEncoderLayer(d_model=768, nhead=2) 75 | self.fc = torch.nn.Linear(768, 30) 76 | self.classifier = torch.nn.Linear(30, num_labels) 77 | self.device = device 78 | 79 | def forward(self, ids, mask, token_type_ids, length): 80 | _, pooled_out = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False) 81 | 82 | chunks_emb = pooled_out.split_with_sizes(length) 83 | batch_emb_pad = torch.nn.utils.rnn.pad_sequence( 84 | chunks_emb, padding_value=0, batch_first=True) 85 | batch_emb = batch_emb_pad.transpose(0, 1) # (B,L,D) -> (L,B,D) 86 | padding_mask = np.zeros([batch_emb.shape[1], batch_emb.shape[0]]) # Batch size, Sequence length 87 | for idx in range(len(padding_mask)): 88 | padding_mask[idx][length[idx]:] = 1 # padding key = 1 ignored 89 | 90 | padding_mask = torch.tensor(padding_mask).to(self.device, dtype=torch.bool) 91 | trans_output = self.trans(batch_emb, src_key_padding_mask=padding_mask) 92 | mean_pool = torch.mean(trans_output, dim=0) # Batch size, 768 93 | fc_output = self.fc(mean_pool) 94 | relu_output = F.relu(fc_output) 95 | logits = self.classifier(relu_output) 96 | 97 | return logits -------------------------------------------------------------------------------- /src/py.typed: -------------------------------------------------------------------------------- 1 | # Marker file that indicates this package supports typing 2 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.0 2 | pytorch-lightning>=1.3.7 3 | transformers>=4.6.1 4 | spacy>=3.0.6 5 | sklearn>=0.22.1 6 | pytextrank==3.1.1 7 | torchmetrics==0.3.2 8 | jsonlines -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import time 3 | import os 4 | import argparse 5 | import numpy as np 6 | import logging 7 | import glob 8 | import torchmetrics 9 | import pytorch_lightning as pl 10 | from pytorch_lightning import seed_everything 11 | from pytorch_lightning.callbacks import ModelCheckpoint 12 | from pytorch_lightning.loggers import TensorBoardLogger 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | from torch import cuda 17 | from torch.utils.data import DataLoader 18 | from transformers import LongformerTokenizer, BertTokenizer 19 | from transformers.optimization import get_linear_schedule_with_warmup 20 | from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score 21 | 22 | import dataloader 23 | import models 24 | import datasets 25 | 26 | def chunk_collate_fn(batches): 27 | """ 28 | Create batches for ChunkDataset 29 | """ 30 | return [{key: torch.stack(value) for key, value in batch.items()} for batch in batches] 31 | 32 | def create_dataloader(dataset_class, text_set, label_set, tokenizer, max_length, batch_size, num_workers): 33 | """ 34 | Create 
appropriate dataloaders for the given data 35 | :param dataset_class: Dataset to use as defined in datasets.py 36 | :param text_set: dict of lists of texts for train/dev/test splits, keys=['train', 'dev', 'test'] 37 | :param label_set: dict of lists of labels for train/dev/test splits, keys=['train', 'dev', 'test'] 38 | :param tokenizer: tokenizer of choice e.g. LongformerTokenizer, BertTokenizer 39 | :param max_length: maximum length of sequence e.g. 512 40 | :param batch_size: batch size for dataloaders 41 | :param num_workers: number of workers for dataloaders 42 | :return: set of dataloaders for train/dev/test splits, keys=['train', 'dev', 'test'] 43 | """ 44 | dataloaders = {} 45 | 46 | if 'train' in text_set.keys(): 47 | split = 'train' 48 | dataset = dataset_class(text_set[split], label_set[split], tokenizer, max_length) 49 | if isinstance(dataset, datasets.ChunkDataset): 50 | dataloaders[split] = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, 51 | pin_memory=True, collate_fn=chunk_collate_fn) 52 | else: 53 | dataloaders[split] = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, 54 | pin_memory=True) 55 | 56 | for split in ['dev', 'test']: 57 | dataset = dataset_class(text_set[split], label_set[split], tokenizer, max_length) 58 | if isinstance(dataset, datasets.ChunkDataset): 59 | dataloaders[split] = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, 60 | pin_memory=True, collate_fn=chunk_collate_fn) 61 | else: 62 | dataloaders[split] = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, 63 | pin_memory=True) 64 | 65 | return dataloaders 66 | 67 | def get_long_texts_and_labels(text_dict, label_dict, tokenizer, max_length=512): 68 | """ 69 | Find texts that have more than a given max token length and their labels 70 | :param text_dict: dict of lists of texts for train/dev/test splits, keys=['train', 'dev', 'test'] 71 | :param label_dict: dict of lists of labels for train/dev/test splits, keys=['train', 'dev', 'test'] 72 | :param tokenizer: tokenizer of choice e.g. LongformerTokenizer, BertTokenizer 73 | :param max_length: maximum length of sequence e.g. 
512 74 | :return: dicts of lists of texts with more than the max token length and their labels 75 | """ 76 | long_text_set = {'dev': [], 'test': []} 77 | long_label_set = {'dev': [], 'test': []} 78 | for split in ['dev', 'test']: 79 | long_text_idx = [] 80 | for idx, text in enumerate(text_dict[split]): 81 | if len(tokenizer.tokenize(text)) > (max_length - 2): 82 | long_text_idx.append(idx) 83 | long_text_set[split] = [text_dict[split][i] for i in long_text_idx] 84 | long_label_set[split] = [label_dict[split][i] for i in long_text_idx] 85 | return long_text_set, long_label_set 86 | 87 | class Classification(pl.LightningModule): 88 | """ 89 | Pytorch Lightning module to train all models 90 | """ 91 | def __init__(self, model, lr, scheduler, label_type, chunk, num_labels, dataset_size, epochs, batch_size): 92 | super().__init__() 93 | self.model = model 94 | self.lr = lr 95 | self.scheduler = scheduler 96 | self.label_type = label_type 97 | self.chunk = chunk 98 | self.num_labels = num_labels 99 | self.dataset_size = dataset_size 100 | self.epochs = epochs 101 | self.batch_size = batch_size 102 | if self.label_type == 'binary_class': 103 | self.eval_metric = torchmetrics.Accuracy(num_classes=self.num_labels) 104 | elif self.label_type == 'multi_label': 105 | self.eval_metric = torchmetrics.F1(num_classes=self.num_labels, average='micro') 106 | elif self.label_type == 'multi_class': 107 | self.eval_metric = torchmetrics.Accuracy(num_classes=self.num_labels, multiclass=True) 108 | 109 | def training_step(self, batch, batch_idx): 110 | start = time.time() 111 | metrics = {} 112 | if self.chunk: 113 | ids = [data['ids'] for data in batch] 114 | mask = [data['mask'] for data in batch] 115 | token_type_ids = [data['token_type_ids'] for data in batch] 116 | targets = [data['targets'][0] for data in batch] 117 | length = [data['len'] for data in batch] 118 | 119 | ids = torch.cat(ids) 120 | mask = torch.cat(mask) 121 | token_type_ids = torch.cat(token_type_ids) 122 | targets = torch.stack(targets) 123 | length = torch.cat(length) 124 | length = [x.item() for x in length] 125 | 126 | ids = ids.to(self.device) 127 | mask = mask.to(self.device) 128 | token_type_ids = token_type_ids.to(self.device) 129 | y = targets.to(self.device) 130 | 131 | y_hat = self.model(ids, mask, token_type_ids, length) 132 | else: 133 | ids = batch['ids'].to(self.device) 134 | mask = batch['mask'].to(self.device) 135 | token_type_ids = batch['token_type_ids'].to(self.device) 136 | y = batch['labels'].to(self.device) 137 | 138 | y_hat = self.model(ids, mask, token_type_ids) 139 | 140 | if self.label_type == 'multi_label' or self.label_type == 'binary_class': 141 | loss = F.binary_cross_entropy_with_logits(y_hat, y.float()) # sigmoid + binary cross entropy loss 142 | preds = torch.sigmoid(y_hat) 143 | 144 | elif self.label_type == 'multi_class': 145 | loss = F.cross_entropy(y_hat, y) # softmax + cross entropy loss 146 | preds = torch.softmax(y_hat, dim=-1) 147 | 148 | metrics['loss'] = loss 149 | 150 | self.log('train_eval_metric', self.eval_metric(preds, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) 151 | self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 152 | self.log('losses', {'train_loss': loss}, on_step=False, on_epoch=True, prog_bar=False, logger=True) 153 | self.log('train_time', time.time() - start, on_step=False, on_epoch=True, prog_bar=True, logger=True) 154 | 155 | return metrics 156 | 157 | def validation_step(self, batch, batch_idx, prefix='val_'): 158 | start = 
time.time() 159 | metrics = {} 160 | if self.chunk: 161 | ids = [data['ids'] for data in batch] 162 | mask = [data['mask'] for data in batch] 163 | token_type_ids = [data['token_type_ids'] for data in batch] 164 | targets = [data['targets'][0] for data in batch] 165 | length = [data['len'] for data in batch] 166 | 167 | ids = torch.cat(ids) 168 | mask = torch.cat(mask) 169 | token_type_ids = torch.cat(token_type_ids) 170 | targets = torch.stack(targets) 171 | length = torch.cat(length) 172 | length = [x.item() for x in length] 173 | 174 | ids = ids.to(self.device) 175 | mask = mask.to(self.device) 176 | token_type_ids = token_type_ids.to(self.device) 177 | y = targets.to(self.device) 178 | 179 | y_hat = self.model(ids, mask, token_type_ids, length) 180 | 181 | else: 182 | ids = batch['ids'].to(self.device) 183 | mask = batch['mask'].to(self.device) 184 | token_type_ids = batch['token_type_ids'].to(self.device) 185 | y = batch['labels'].to(self.device) 186 | 187 | y_hat = self.model(ids, mask, token_type_ids) 188 | 189 | if self.label_type == 'multi_label' or self.label_type == 'binary_class': 190 | loss = F.binary_cross_entropy_with_logits(y_hat, y.float()) # sigmoid + loss 191 | preds = torch.sigmoid(y_hat) 192 | 193 | elif self.label_type == 'multi_class': 194 | loss = F.cross_entropy(y_hat, y) # softmax + loss 195 | preds = torch.softmax(y_hat, dim=-1) 196 | 197 | metrics[prefix + 'loss'] = loss 198 | metrics['preds'] = preds 199 | metrics['y'] = y 200 | 201 | self.log(prefix + 'eval_metric', self.eval_metric(preds, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) 202 | self.log(prefix + 'loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) 203 | self.log('losses', {prefix + 'loss': loss}, on_step=False, on_epoch=True, prog_bar=False, logger=True) 204 | self.log(prefix + 'time', time.time() - start, on_step=False, on_epoch=True, prog_bar=True, logger=True) 205 | 206 | return metrics 207 | 208 | def validation_epoch_end(self, outputs, prefix='val_'): 209 | labels = [] 210 | predictions = [] 211 | for output in outputs: 212 | for out_labels in output['y'].detach().cpu(): 213 | labels.append(out_labels) 214 | for out_predictions in output['preds'].detach().cpu(): 215 | predictions.append(out_predictions) 216 | 217 | labels = torch.stack(labels).int() 218 | predictions = torch.stack(predictions) 219 | 220 | y_pred = predictions.numpy() 221 | y_true = labels.numpy() 222 | 223 | if self.label_type == 'multi_label' or self.label_type == 'binary_class': 224 | y_pred_labels = np.where(y_pred > 0.5, 1, 0) 225 | 226 | elif self.label_type == 'multi_class': 227 | y_pred_labels = np.argmax(y_pred, axis=1) 228 | 229 | logging.info("Epoch: {}".format(self.current_epoch)) 230 | 231 | logging.info( 232 | prefix + 'accuracy: {}'.format(accuracy_score(y_true, y_pred_labels))) 233 | 234 | if self.label_type == 'binary_class': 235 | average_type = 'macro' 236 | logging.info(prefix + average_type + '_precision: {}'.format(precision_score(y_true, y_pred_labels, average=average_type))) 237 | logging.info( 238 | prefix + average_type + '_recall: {}'.format(recall_score(y_true, y_pred_labels, average=average_type))) 239 | logging.info( 240 | prefix + average_type + '_f1: {}'.format(f1_score(y_true, y_pred_labels, average=average_type))) 241 | 242 | else: 243 | for average_type in ['micro', 'macro', 'weighted']: 244 | logging.info(prefix + average_type + '_precision: {}'.format(precision_score(y_true, y_pred_labels, average=average_type))) 245 | logging.info( 246 | prefix + 
average_type + '_recall: {}'.format(recall_score(y_true, y_pred_labels, average=average_type))) 247 | logging.info( 248 | prefix + average_type + '_f1: {}'.format(f1_score(y_true, y_pred_labels, average=average_type))) 249 | 250 | 251 | def test_step(self, batch, batch_idx): 252 | metrics = self.validation_step(batch, batch_idx, 'test_') 253 | return metrics 254 | 255 | def test_epoch_end(self, outputs): 256 | self.validation_epoch_end(outputs, prefix="test_") 257 | 258 | def configure_optimizers(self): 259 | opt = {} 260 | optimizer = optim.Adam(self.model.parameters(), lr=self.lr) 261 | opt['optimizer'] = optimizer 262 | if not self.scheduler: 263 | return opt 264 | else: 265 | num_steps = self.dataset_size * self.epochs / self.batch_size 266 | scheduler = get_linear_schedule_with_warmup( 267 | optimizer, num_warmup_steps=num_steps * 0.1, num_training_steps=num_steps 268 | ) 269 | opt['lr_scheduler'] = scheduler 270 | return opt 271 | 272 | if __name__ == "__main__": 273 | warnings.simplefilter(action='ignore', category=FutureWarning) # ignore future warnings 274 | parser = argparse.ArgumentParser() 275 | parser.add_argument('--model_name', type=str, required=True, 276 | help="Model name: {bert, bertplustextrank, bertplusrandom, longformer, tobert}") 277 | parser.add_argument('--data', type=str, required=True, 278 | help="Dataset name: {eurlex, hyperpartisan, books, 20news}") 279 | parser.add_argument('--batch_size', type=int, required=True, help="Batch size") 280 | parser.add_argument('--lr', type=float, required=False, help="Learning rate e.g. 0.005, 5e-05") 281 | parser.add_argument('--epochs', type=int, required=True, help="Number of epochs") 282 | parser.add_argument("--scheduler", action='store_true', help="Use a warmup scheduler with warmup steps of 0.1 of " 283 | "the total training steps") 284 | parser.add_argument("--num_workers", type=int, default=0, help="Number of data loader workers") 285 | parser.add_argument('--model_dir', type=str, default='./ckpts/', help="Path to save the best model") 286 | parser.add_argument("--seed", type=int, default=3456, help="Random seed") 287 | parser.add_argument("--inverted", action='store_true', help="Use the Inverted EURLEX dataset") 288 | parser.add_argument("--pairs", action='store_true', help="Use the Paired Book Summary dataset") 289 | parser.add_argument("--eval", action='store_true', help="Evaluate only, do not train") 290 | parser.add_argument("--ckpt", type=str, help="Path to a saved ckpt for continued training or evaluation" 291 | "e.g. 
bert_hyperpartisan_b8_e20_s3456_lr3e-05--epoch=17.ckpt") 292 | 293 | args = parser.parse_args() 294 | 295 | device = 'cuda' if cuda.is_available() else 'cpu' 296 | # sets seeds for numpy, torch, python.random and PYTHONHASHSEED 297 | seed_everything(args.seed, workers=True) 298 | dropout_rate = 0.1 299 | chunk = True if args.model_name.lower() == 'tobert' else False 300 | 301 | if args.data.lower() == 'eurlex': 302 | label_type = 'multi_label' 303 | text_set, label_set, num_labels = dataloader.prepare_eurlex_data(inverted=args.inverted) 304 | 305 | elif args.data.lower() == 'books': 306 | label_type = 'multi_label' 307 | text_set, label_set, num_labels = dataloader.prepare_book_summaries(pairs=args.pairs) 308 | 309 | elif args.data.lower() == '20news': 310 | label_type = 'multi_class' 311 | text_set, label_set, num_labels = dataloader.prepare_20news_data() 312 | 313 | elif args.data.lower() == 'hyperpartisan': 314 | label_type = 'binary_class' 315 | text_set, label_set, num_labels = dataloader.prepare_hyperpartisan_data() 316 | 317 | else: 318 | raise Exception("Data not found: {}".format(args.data)) 319 | 320 | dataset_size = len(label_set['train']) # to calculate the num of steps for warm up scheduler 321 | 322 | if args.model_name.lower() == 'longformer': 323 | tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096', do_lower_case=True) 324 | max_length = 4096 325 | else: 326 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) 327 | max_length = 512 328 | 329 | if args.model_name.lower() == 'bert': 330 | model = models.BERTClass(dropout_rate, num_labels) 331 | dataset_class = datasets.TruncatedDataset 332 | 333 | elif args.model_name.lower() == 'bertplustextrank': 334 | model = models.BERTPlus(dropout_rate, num_labels) 335 | dataset_class = datasets.TruncatedPlusTextRankDataset 336 | 337 | elif args.model_name.lower() == 'bertplusrandom': 338 | model = models.BERTPlus(dropout_rate, num_labels) 339 | dataset_class = datasets.TruncatedPlusRandomDataset 340 | 341 | elif args.model_name.lower() == 'longformer': 342 | model = models.LongformerClass(num_labels) 343 | dataset_class = datasets.TruncatedDataset 344 | 345 | elif args.model_name.lower() == 'tobert': 346 | max_length = 200 # divide documents into chunks up to 200 tokens 347 | model = models.ToBERTModel(num_labels, device) 348 | dataset_class = datasets.ChunkDataset 349 | 350 | else: 351 | raise Exception("Model not found: {}".format(args.model_name)) 352 | 353 | dataloaders = create_dataloader(dataset_class, text_set, label_set, tokenizer, max_length, args.batch_size, 354 | args.num_workers) 355 | 356 | long_text_set, long_label_set = get_long_texts_and_labels(text_set, label_set, tokenizer) 357 | long_dataloaders = create_dataloader(dataset_class, long_text_set, long_label_set, tokenizer, max_length, 358 | args.batch_size, args.num_workers) 359 | 360 | model.to(device) 361 | 362 | if not os.path.exists(args.model_dir): 363 | os.makedirs(args.model_dir) 364 | print(f'Model directory created: {args.model_dir}') 365 | 366 | task = Classification(model, args.lr, args.scheduler, label_type, chunk, num_labels, dataset_size, args.epochs, 367 | args.batch_size) 368 | 369 | inverted_str = '_inverted' if args.inverted else '' 370 | pairs_str = '_pairs' if args.pairs else '' 371 | scheduler_str = '_warmup' if args.scheduler else '' 372 | 373 | output_model_name = args.model_name + '_' + args.data + inverted_str + pairs_str + '_b' + str(args.batch_size) + \ 374 | '_e' + str(args.epochs) 
+ '_s' + str(args.seed) + '_lr' + str(args.lr) + scheduler_str 375 | 376 | logging.basicConfig(filename=output_model_name + '.log', level=logging.DEBUG) 377 | logger = TensorBoardLogger('tb_logs', name=output_model_name) 378 | 379 | for arg in vars(args): 380 | logging.info("{}: {}".format(arg, getattr(args, arg))) 381 | 382 | if not args.eval: # train mode 383 | ckpt_config = ModelCheckpoint( 384 | monitor="val_eval_metric_epoch", 385 | verbose=False, 386 | save_top_k=1, 387 | save_weights_only=False, 388 | mode='max', 389 | every_n_val_epochs=1, 390 | dirpath=args.model_dir, 391 | filename=output_model_name + "--{epoch}" 392 | ) 393 | if args.ckpt: 394 | trainer = pl.Trainer(logger=logger, 395 | callbacks=ckpt_config, 396 | gpus=1, 397 | deterministic=True, 398 | log_gpu_memory='min_max', 399 | num_sanity_val_steps=0, 400 | max_epochs=args.epochs, 401 | resume_from_checkpoint=args.model_dir + args.ckpt) 402 | 403 | else: 404 | trainer = pl.Trainer(logger=logger, 405 | callbacks=ckpt_config, 406 | gpus=1, 407 | deterministic=True, 408 | log_gpu_memory='min_max', 409 | num_sanity_val_steps=0, 410 | max_epochs=args.epochs) 411 | 412 | print("Training: {}".format(output_model_name)) 413 | trainer.fit(model=task, train_dataloader=dataloaders['train'], val_dataloaders=dataloaders['dev']) 414 | 415 | for _ckpt in range(len(trainer.checkpoint_callbacks)): 416 | logging.info("Testing") 417 | paths = trainer.checkpoint_callbacks[_ckpt] 418 | ckpt_path = trainer.checkpoint_callbacks[_ckpt].best_model_path 419 | logging.info("Checkpoint path: {}".format(ckpt_path)) 420 | metrics = trainer.test(test_dataloaders=dataloaders['test'], ckpt_path=ckpt_path) 421 | for metric in metrics: 422 | for key in metric: 423 | logging.info("{}: {}".format(key, metric[key])) 424 | 425 | for split in ['dev', 'test']: 426 | logging.info("Evaluating on long documents in the {} set only".format(split)) 427 | metrics = trainer.test(test_dataloaders=long_dataloaders[split], ckpt_path=ckpt_path) 428 | for metric in metrics: 429 | for key in metric: 430 | logging.info("long_{}_{}: {}".format(split, key, metric[key])) 431 | 432 | else: # eval mode 433 | if args.ckpt: 434 | ckpt_paths = glob.glob(args.model_dir + args.ckpt) 435 | else: 436 | ckpt_paths = glob.glob(args.model_dir + output_model_name + '*.ckpt') 437 | 438 | logging.info("Evaluating: {}".format(output_model_name)) 439 | 440 | for ckpt_path in ckpt_paths: 441 | logging.info("Checkpoint path: {}".format(ckpt_path)) 442 | task.load_from_checkpoint(ckpt_path, model=model, lr=args.lr, scheduler=args.scheduler, label_type=label_type, 443 | num_labels=num_labels) 444 | 445 | trainer = pl.Trainer(gpus=1) 446 | 447 | for split in ['dev', 'test']: 448 | logging.info("Evaluating on all documents in the {} set".format(split)) 449 | metrics = trainer.test(model= task, test_dataloaders=dataloaders[split]) 450 | for metric in metrics: 451 | for key in metric: 452 | logging.info("all_{}_{}: {}".format(split, key, metric[key])) 453 | 454 | for split in ['dev', 'test']: 455 | logging.info("Evaluating on long documents in the {} set only".format(split)) 456 | metrics = trainer.test(model= task, test_dataloaders=long_dataloaders[split]) 457 | for metric in metrics: 458 | for key in metric: 459 | logging.info("long_{}_{}: {}".format(split, key, metric[key])) --------------------------------------------------------------------------------
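Usage note: `src/train.py` is driven entirely by the command-line flags defined above; a typical training run looks like `python src/train.py --model_name bert --data hyperpartisan --batch_size 8 --epochs 20 --lr 3e-05`, which yields checkpoints named along the lines of the `--ckpt` help example (`bert_hyperpartisan_b8_e20_s3456_lr3e-05--epoch=17.ckpt`, i.e. model, dataset, batch size, epochs, seed, and learning rate joined into `output_model_name`).

The long-document evaluation path relies on `get_long_texts_and_labels` to pull out only the dev/test documents whose tokenized length exceeds the encoder's input limit; `__main__` then wraps those subsets with `create_dataloader` so the trained model can be scored separately on them. The snippet below is a minimal sketch of that filter in isolation, not part of the repository: it assumes the function is importable as `src.train.get_long_texts_and_labels` and that its `max_length` parameter (defined earlier in the file, not shown in this excerpt) defaults to 512, and the toy texts and labels are purely illustrative.

```python
# Minimal sketch: exercising the long-document filter on toy data.
# Assumptions (not visible in this excerpt): the import path `src.train` and a
# default max_length of 512; both may differ in the actual repository layout.
from transformers import BertTokenizer

from src.train import get_long_texts_and_labels  # assumed import path

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Toy dev/test splits: one short document and one that clearly exceeds 512 wordpieces.
text_set = {
    'dev': ['a short document', 'a much longer document ' * 200],
    'test': ['another short document'],
}
label_set = {
    'dev': [0, 1],
    'test': [0],
}

long_text_set, long_label_set = get_long_texts_and_labels(text_set, label_set, tokenizer)

# Only the long dev document (and its label) survives the filter.
print(len(long_text_set['dev']), len(long_label_set['dev']))  # expected: 1 1
print(len(long_text_set['test']))                             # expected: 0
```

This mirrors how the script builds `long_dataloaders`: the filtered splits are fed through the same dataset class as the full splits, so the "long documents only" test results logged at the end come from the identical preprocessing pipeline, just restricted to documents the truncating models cannot read in full.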