├── LICENSE.md ├── NOTICES.txt ├── README.md ├── classitransformers ├── __init__.py ├── configs.py ├── downloader.py ├── inference.py ├── pytransformers.py ├── report.py ├── tfbert.py └── tfelectra.py ├── data_preparation_format.txt ├── datasets ├── bbcfull │ ├── dev.csv │ ├── infer.csv │ ├── test.csv │ ├── test_with_label.csv │ └── train.csv ├── bbcsports │ ├── dev.csv │ ├── infer.csv │ ├── test.csv │ ├── test_with_label.csv │ └── train.csv ├── financial_phrasebank │ ├── dev.csv │ ├── test.csv │ ├── test_with_label.csv │ ├── testing.csv │ ├── train.csv │ ├── training.csv │ └── validation.csv └── semeval_relatioship_classification │ ├── TEST_FILE_FULL.TXT │ ├── TRAIN_FILE.TXT │ ├── test.csv │ └── train.csv ├── electra ├── __init__.py ├── modeling.py ├── optimization.py └── tokenization.py ├── environment.yml ├── models ├── Albert │ ├── config.json │ └── spiece.model ├── Bert_base │ ├── bert_config.json │ └── vocab.txt ├── DistilBert │ ├── config.json │ └── vocab.txt ├── Electra_base │ └── vocab.txt └── Roberta │ ├── config.json │ ├── merges.txt │ └── vocab.json ├── requirements.txt ├── sample_notebooks ├── PYTORCH_classi_transformers_BBCnews_example.ipynb ├── PYTORCH_classi_transformers_BBCsports_example.ipynb ├── PYTORCH_classi_transformers_entity_relationship_example.ipynb ├── PYTORCH_classi_transformers_sentiment_example.ipynb ├── TF_classi_transformers_BBCnews_example.ipynb ├── TF_classi_transformers_BBCsports_example.ipynb ├── TF_classi_transformers_entity_relationship_example.ipynb └── TF_classi_transformers_sentiment_example.ipynb ├── setup.py └── wip /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /NOTICES.txt: -------------------------------------------------------------------------------- 1 | ========================================================================= 2 | == NOTICE file for use with the Apache License, Version 2.0 == 3 | ========================================================================= 4 | 5 | classitransformers 6 | Copyright 2020 FMR LLC. 7 | 8 | This library includes software developed at FMR LLC (https://www.fidelity.com/). 9 | 10 | The library utilizes or derived from the following works and hence the dependency installed separately. 11 | - tensorflow | https://github.com/tensorflow/tensorflow | Apache 2.0 License 12 | - bert-tensorflow | https://github.com/google-research/bert | Apache 2.0 License 13 | - electra | https://github.com/google-research/electra | Apache 2.0 License 14 | - torch | https://github.com/pytorch/pytorch | 3-Clause BSD License 15 | - transformers | https://github.com/huggingface/transformers | Apache 2.0 License 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ClassiTransformers 2 | 3 | [ClassiTransformers](https://pypi.org/project/classitransformers/) is an abstract library based on Tensorflow implementation of BERT and ELECTRA, and transformers library of HuggingFace Inc. 4 | 5 | Currently implemented models 6 | - BERT (Tensorflow) 7 | - ELECTRA (Tensorflow) 8 | - RoBERTa (PyTorch - transformers library) 9 | - ALBERT (PyTorch - transformers library) 10 | - DistilBERT (PyTorch - transformers library) 11 | 12 | Task 13 | - Sequence Classification 14 | 15 | ## Features 16 | 17 | - Works for N-class classification problem where N is any number of classes. 18 | - Easy to use. Takes away all the complexity of writing tensorflow or pytorch codes for training and testing classification models. 19 | - It provides an methods to easily train, test and create deployable models in .pb and .bin format in just 5 steps. 20 | - Hyperparameters can be easily modified without having to change the source code. 21 | - Deployment ready module available. 22 | 23 | # Table of contents 24 | 25 | 26 | - [ClassiTransformers](#classi-transformers) 27 | - [Features](#features) 28 | - [Table of contents](#table-of-contents) 29 | - [Installation](#installation) 30 | - [Usage](#usage) 31 | - [Structure](#structure) 32 | - [Quick Start](#quick-start) 33 | - [Inferencing post Finetuning](#inferencing-post-fine-tuning) 34 | - [Data Preparation](#data-preparation) 35 | - [Setting Configuration](#setting-configuration) 36 | - [Class Methods](#class-methods) 37 | - [Getting Language Models](#getting-language-models) 38 | - [Real Dataset Examples](#real-dataset-examples) 39 | - [Support and Contributions](#support-and-contributions) 40 | - [Acknowledgement](#acknowledgement) 41 | - [License](#license) 42 | 43 | # Installation 44 | Assuming that anaconda environment is already installed, 45 | 46 | - ClassiTransformer can be installed from PyPI using 47 | 48 | ``` 49 | pip install classitransformers 50 | ``` 51 | 52 | - with requirements.txt 53 | 54 | ``` 55 | pip install -r requirements.txt 56 | ``` 57 | 58 | - with yml file, create conda environment 59 | 60 | ``` 61 | conda env create -f environment.yml 62 | source activate env 63 | ``` 64 | 65 | # Usage 66 | 67 | Example notebooks can be found in the `sample_notebooks` directory. 68 | 69 | ## Structure 70 | 71 | * `classitransformers.pytransformers` - Includes all pytorch-based text classification models from transformers library. 72 | * `classitransformers.tfelectra` - Includes tensorflow-based Electra model for text classification 73 | * `classitransformers.tfbert` - Includes tensorflow-based BERT model for text classification 74 | * `classitransformers.report` - Used for reporting performance metrics. (precision, recall, F1, confusion matric) 75 | * `classitransformers.configs` - Used for initializing the hyperparameters of the language models. Also checkas and creates the necessary directories. 76 | * `classitransformers.downloader` - Used for downloading any of the 5 language models. 77 | * `classitransformers.inference` - Used for inferencing in production, Only requires fine-tuned model directory and few hyperparameters. 78 | 79 | 80 | ## Quick Start 81 | Supports text classification with any number of labels. 82 | 83 | ```python 84 | 85 | from classitransformers.report import metrics 86 | from classitransformers.configs import Configs 87 | from classitransformers.pytransformers import TransformersClassification 88 | 89 | config = Configs(pretrained_model_dir = './models/DistilBert/', 90 | model_name='distilbert', 91 | num_train_epochs=3, 92 | train_batch_size = 16, 93 | eval_batch_size = 8, 94 | do_train = True, 95 | do_eval = True, 96 | label_list = ["0", "1", "2"], 97 | max_seq_length = 64, 98 | data_dir = "./datasets/financial_phrasebank/", 99 | output_dir = "./transformers_output_sentiment/") 100 | 101 | model = TransformersClassification(config) 102 | 103 | model.train() 104 | prediction = model.test() 105 | 106 | y_pred = [np.argmax(tup) for tup in prediction] 107 | print(y_pred) 108 | ``` 109 | 110 | ## Inferencing post Fine-tuning 111 | Supports pytorch models to be directly used in deployment/production. 112 | 113 | ```python 114 | from classitransformers.inference import InferenceModel 115 | 116 | model = InferenceModel(num_classes = 3, 117 | max_seq_length = 64, 118 | batch_size = 8, 119 | model_name='distilbert', 120 | classification_model_dir = './transformers_output_sentiment/') 121 | 122 | model.text_inference(['Market gives the best returns among all the options of investemnts']) 123 | ``` 124 | 125 | ## Data Preparation 126 | 127 | The **directory** for input files needs to be **specified** while creating the **config** object. The files should be named **train.csv**, **dev.csv**, **test.csv**. **test.csv** may or may not have the labels. Labels would be useful for generating the report after testing. 128 | 129 | Please check `data_preparation_format.txt` for details. 130 | 131 | ## Setting Configuration 132 | 133 | All the Hyperparameters are listed in the Configs class, and have standard default values. The values can be changed by modifying the parameters passed in the Configs constructor for config object. 134 | 135 | ```python 136 | class Configs(object): 137 | """Fine-tuning hyperparameters.""" 138 | 139 | def __init__(self, pretrained_model_dir = './Albert', 140 | data_dir = "./data/", output_dir = "./albert_output/", 141 | export_dir = None, model_name="albert", 142 | model_size ="base", learning_rate = 5e-5, 143 | num_train_epochs=3.0, train_batch_size = 16, 144 | eval_batch_size = 8, predict_batch_size = 8, do_train = True, 145 | do_eval = True, do_predict = False, label_list = ["0", "1"], 146 | do_lower_case = True, max_seq_length = 256, use_tpu = False, 147 | iterations_per_loop = 1000, save_checkpoint_steps = 1000000, 148 | warmup_proportion = 0.1, export_path ='./exported_bert_model') 149 | ``` 150 | 151 | These are the parameters to be specified for creating the config object of Configs class. 152 | 153 | Args: 154 | * `pretrained_model_dir` : The path for pretrained directory. 155 | * `data_dir` : The path of the directory for the train,dev and test files. 156 | * `output_dir` (optional): The directory where the fine-tuned model will be saved. If not given, model will be saved in the current directory.(checkpoint for TF, .bin for pytorch) 157 | * `export_dir` (optional): The directory where the model to be deployed will be saved.(Currently only for BERT) 158 | * `model_name` : The name of the model. Either of these: 'albert', 'bert', 'electra', 'roberta', 'distilbert' 159 | * `learning_rate`: The learning rate required while training the model. Default is 5e-5. 160 | * `num_training_epochs`: The number of iterations for finetuning the pretrained model for classification task. 161 | * `label_list`: The list of the labels for text classification task. 162 | * `max_seq_length`: Max Sequence Length (multiples of 2) should be ideally just greater than the length of the longest text sentence, to prevent loss of information. 163 | * `export_path`: The export path directory where chkpt format is converted to .pb format. Only set for bert. 164 | 165 | 166 | ## Class Methods 167 | 168 | The class methods do not take any parameters. All the parameters are predefined to improve the clarity of the code. 169 | 170 | **`train()`** 171 | Fine-Tunes(trains) the model and saves the model and config file in the `output_dir` directory. Validation is done after each epoch. 172 | 173 | **`test()`** 174 | Tests the model for test dataset. Returns the prediction labels. 175 | 176 | **`export_model()`** 177 | Exports checkpoint model to .pb fotmat. Used for tensorflow-serviing while inferencing.(Currently only for BERT) 178 | 179 | **`inference()`** 180 | Inference on any input csv in batches using tensorflow serving for .pb model. (Currently only for BERT) 181 | 182 | **`text_inference()`** 183 | Inference on list of sentences as input. 184 | 185 | **`metrics()`** 186 | Prints and returns the accuracy and other metrics. Also prints Confusion Matrix (decorated matrix using matplotlib) 187 | 188 | **`downloader()`** 189 | Downloads pre-trained models just by passing model names in lowercase. 190 | 191 | ## Getting Language Models 192 | 193 | ```python 194 | from classitransformers.downloader import downloader 195 | 196 | # pass name of the model ('albert', 'bert', 'electra', 'roberta', 'distilbert') 197 | downloader('roberta') # Downloads to default dir '../models' 198 | ``` 199 | 200 | ## Real Dataset Examples 201 | 202 | * [BBC News and BBC Sports](http://mlg.ucd.ie/datasets/bbc.html) 203 | * [Financial Phrasebank - 3 class Classification of Financial Statements](https://www.kaggle.com/ankurzing/sentiment-analysis-for-financial-news/kernels) 204 | * [Semeval 2010 Task 8 - Entity Relationship Classification](https://www.aclweb.org/anthology/S10-1006.pdf) 205 | * [Yelp 2013 Dataset - User Rating Classification](https://www.kaggle.com/c/yelp-recsys-2013/data) 206 | 207 | 208 | ## Support and Contributions 209 | 210 | Please submit bug reports and feature requests as Issues. 211 | Contributions are very welcome. 212 | 213 | For additional questions and feedback, please contact us at classitransformers@fmr.com 214 | 215 | ## Acknowledgement 216 | 217 | ClassiTransformers is developed by Emerging Tech Team at Fidelity Investments. 218 | The part of the package was developed as part of an internship program at Fidelity. We thank [Hrushikesh](https://github.com/Hrushikesh-Vazurkar) and [Mayank](https://github.com/mgroovyank) for their contribution to the package. 219 | 220 | ## License 221 | 222 | ClassiTransformers is licensed under the [Apache License 2.0.](LICENSE.md) 223 | -------------------------------------------------------------------------------- /classitransformers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fidelity/classitransformers/ec195296787c2dd4cea445a5916dd0ce66252655/classitransformers/__init__.py -------------------------------------------------------------------------------- /classitransformers/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 FMR LLC 2 | # SPDX-License-Identifer: Apache-2.0 3 | 4 | """Set configurations for language models on a downstream classification task.""" 5 | 6 | import os 7 | 8 | class Configs(object): 9 | 10 | """Fine-tuning hyperparameters.""" 11 | 12 | 13 | """ 14 | Parameters 15 | ---------- 16 | 17 | pretrained_model_dir: Directory path of pretrained models. (check downloader.py) 18 | 19 | model_name: Name of the language model in small case (as string) 20 | 21 | learning_rate: The initial learning rate for Adam. 22 | 23 | num_train_epochs: Total number of training epochs to perform. 24 | 25 | train_batch_size: Total batch size for training. 26 | 27 | eval_batch_size: Total batch size for eval. 28 | 29 | predict_batch_size: Total batch size for predict. 30 | 31 | do_train: Whether to run training. 32 | 33 | do_eval: Whether to run eval on the dev set. 34 | 35 | do_predict: Whether to run training. 36 | 37 | label_list: Number of classes in data 38 | 39 | do_lower_case: Whether to lower case the input text. 40 | Should be True for uncased models and False for cased models 41 | 42 | max_seq_length: The maximum total input sequence length after WordPiece tokenization. 43 | Sequences longer than this will be truncated, and shorter than this will be padded. 44 | 45 | data_dir: The input data dir. Should contain the .csv files 46 | (train.csv, dev.csv, test.csv) for the task. 47 | 48 | output_dir: The output directory where the model checkpoints will be written. 49 | 50 | num_tpu_cores: Only used if `use_tpu` is True. Total number of TPU cores to use. 51 | 52 | gcp_project: [Optional] Project name for the Cloud TPU-enabled project. 53 | If not specified, we will attempt to automatically detect the GCE project from metadata. 54 | 55 | master: [Optional] TensorFlow master URL. 56 | 57 | tpu_zone: [Optional] GCE zone where the Cloud TPU is located in. 58 | If not specified, we will attempt to automatically detect the GCE project from metadata. 59 | 60 | tpu_name: [Optional] GCE zone where the Cloud TPU is located in. 61 | If not specified, we will attempt to automatically detect the GCE project from metadata. 62 | 63 | use_tpu: Whether to use TPU or GPU/CPU. 64 | 65 | iterations_per_loop: How many steps to make in each estimator call. 66 | 67 | save_checkpoints_steps: How often to save the model checkpoint. 68 | 69 | warmup_proportion: Proportion of training to perform linear learning rate warmup for. 70 | E.g., 0.1 = 10% of training. 71 | 72 | export_dir: The output model (.pb format) dir, where the freezed graph and weights get stored (only for BERT). 73 | 74 | tokenizer: Wordpiece tokenizer object. 75 | 76 | export_path: Path of saved .pb model for inferencing purpose (Only for BERT). 77 | 78 | """ 79 | 80 | 81 | def __init__(self, pretrained_model_dir = './Albert', 82 | data_dir = "./data/", 83 | output_dir = "./albert_output/", 84 | export_dir = None, 85 | model_name="", model_size ="base", 86 | learning_rate = 5e-5, num_train_epochs=3.0, train_batch_size = 16, 87 | eval_batch_size = 8, predict_batch_size = 8, do_train = True, 88 | do_eval = True, do_predict = False, label_list = ["0", "1"], 89 | do_lower_case = True, max_seq_length = 256, use_tpu = False, 90 | iterations_per_loop = 1000, save_checkpoint_steps = 1000000, 91 | warmup_proportion = 0.1, export_path ='./exported_bert_model'): 92 | 93 | 94 | # default locations of required files 95 | self.data_dir = data_dir 96 | self.output_dir = output_dir 97 | self.model_name = model_name 98 | self.pretrained_model_dir = pretrained_model_dir 99 | 100 | # general 101 | if self.model_name == "": 102 | raise ValueError('Kindly specify the model name (bert, electra, albert, roberta, distilbert)') 103 | 104 | if not os.path.exists(self.data_dir): 105 | raise ValueError(self.data_dir, ' not found !!') 106 | 107 | if not os.path.exists(self.pretrained_model_dir): 108 | raise ValueError(pretrained_model_dir, ' pretrained model not found !!') 109 | 110 | if not os.path.exists(output_dir): 111 | os.makedirs(output_dir) 112 | 113 | # Common paramteres for all models 114 | 115 | 116 | self.do_train = do_train # train a model ? 117 | self.do_eval = do_eval # evaluate the model ? 118 | self.do_predict = do_predict 119 | self.label_list = label_list 120 | 121 | self.num_train_epochs = num_train_epochs # passes over the dataset during training 122 | self.warmup_proportion = warmup_proportion # how much of training to warm up the LR for 123 | self.learning_rate = learning_rate 124 | self.do_lower_case = do_lower_case 125 | 126 | # Params related to sizes, common to all 127 | 128 | self.max_seq_length = max_seq_length 129 | self.train_batch_size = train_batch_size 130 | self.eval_batch_size = eval_batch_size 131 | self.predict_batch_size = predict_batch_size 132 | 133 | # model 134 | self.embedding_size = None # bert hidden size by default 135 | self.vocab_size = 30522 # number of tokens in the vocabulary 136 | 137 | # training 138 | self.weight_decay_rate = 0.01 139 | self.layerwise_lr_decay = 0.8 140 | self.save_checkpoints_steps = save_checkpoint_steps 141 | self.iterations_per_loop = iterations_per_loop 142 | self.use_tfrecords_if_existing = False 143 | 144 | # TPU settings: Not required, just default init (required for execution) 145 | self.use_tpu = False 146 | self.num_tpu_cores = 1 147 | self.tpu_job_name = None 148 | self.tpu_name = None # cloud TPU to use for training 149 | self.tpu_zone = None # GCE zone where the Cloud TPU is located in 150 | self.gcp_project = None # project name for the Cloud TPU-enabled project 151 | self.master = None 152 | 153 | self.vocab_file = os.path.join(pretrained_model_dir, "vocab.txt") 154 | 155 | # Specific to BERT tensorflow model 156 | if self.model_name == 'bert': 157 | self.init_checkpoint = os.path.join(pretrained_model_dir,'bert_model.ckpt') 158 | self.bert_config_file = os.path.join(pretrained_model_dir,'bert_config.json') 159 | if not os.path.exists(self.bert_config_file): 160 | raise ValueError('bert_config.json not found in pretrained_model_dir!!') 161 | if not os.path.exists(self.vocab_file): 162 | raise ValueError('vocab.txt not found in pretrained_model_dir!!') 163 | 164 | self.export_path = export_path 165 | self.export_dir = export_dir 166 | 167 | if not os.path.exists(self.export_path): 168 | os.makedirs(self.export_path) 169 | 170 | # Specific to Electra Tensorflow model 171 | elif self.model_name == 'electra': 172 | self.init_checkpoint = pretrained_model_dir 173 | self.model_size = model_size # one of "small", "base", or "large" 174 | self.task_names = ["textclassification"] # which tasks to learn 175 | self.model_hparam_overrides = {} 176 | self.num_trials = 1 # how many train+eval runs to perform 177 | self.keep_all_models = True 178 | self.log_examples = False 179 | 180 | if not os.path.exists(self.vocab_file): 181 | raise ValueError('vocab.txt not found in pretrained_model_dir!!') 182 | 183 | # default hyperparameters for different model sizes 184 | if self.model_size == "large": 185 | self.learning_rate = 5e-5 186 | self.layerwise_lr_decay = 0.9 187 | elif self.model_size == "small": 188 | self.embedding_size = 128 189 | 190 | # update defaults with passed-in hyperparameters 191 | self.tasks = { 192 | "textclassification":{ 193 | "type":"classification", 194 | "labels":label_list, 195 | "header":True, 196 | "text_column":1, 197 | "label_column":2 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /classitransformers/downloader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 FMR LLC 2 | # SPDX-License-Identifer: Apache-2.0 3 | 4 | import os 5 | import wget 6 | import shutil 7 | from zipfile import ZipFile 8 | 9 | 10 | def downloader(model, destination_path='../models'): 11 | 12 | """ 13 | Function to download pre-trained models from hugginf face aws repository OR Google's storgage. 14 | 15 | Parameters 16 | ---------- 17 | model : Pick one of these ['bert','electra','roberta','distilbert','albert'] 18 | download_path : Directory whether the model has to be downloaded 19 | 20 | Returns 21 | ------- 22 | 23 | """ 24 | 25 | model_download = { 26 | "bert": 'Bert_base', 27 | "electra": 'Electra_base', 28 | "roberta": 'Roberta', 29 | "distilbert" : 'DistilBert', 30 | "albert": 'Albert' 31 | } 32 | 33 | if model not in model_download: 34 | print("Please pick model name from ['bert','electra','roberta','distilbert','albert']") 35 | return None 36 | 37 | output_dir = os.path.join(destination_path, model_download[model]) 38 | print("Model gets downloaded here: ", output_dir) 39 | 40 | # Create output directory if needed 41 | if not os.path.exists(output_dir): 42 | os.makedirs(output_dir) 43 | 44 | if model == 'roberta': 45 | 46 | 47 | config_file = wget.download('https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json', 48 | os.path.join(output_dir,'config.json')) 49 | vocab_file = wget.download('https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json', 50 | os.path.join(output_dir,'vocab.json')) 51 | merges_file = wget.download('https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt', 52 | os.path.join(output_dir,'merges.txt')) 53 | model_file = wget.download('https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-pytorch_model.bin', 54 | os.path.join(output_dir,'pytorch_model.bin')) 55 | 56 | if model == 'albert': 57 | config_file = wget.download('https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json', 58 | os.path.join(output_dir,'config.json')) 59 | spiece_file = wget.download('https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-spiece.model', 60 | os.path.join(output_dir,'spiece.model')) 61 | model_file = wget.download('https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin', 62 | os.path.join(output_dir,'pytorch_model.bin')) 63 | 64 | if model == 'distilbert': 65 | config_file = wget.download('https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json', 66 | os.path.join(output_dir,'config.json')) 67 | vocab_file = wget.download('https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt', 68 | os.path.join(output_dir,'vocab.json')) 69 | model_file = wget.download('https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-pytorch_model.bin', 70 | os.path.join(output_dir,'pytorch_model.bin')) 71 | 72 | 73 | if model == 'bert': 74 | zip_file = wget.download('https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip', 75 | os.path.join(output_dir,'uncased_L-12_H-768_A-12.zip')) 76 | 77 | with ZipFile(os.path.join(output_dir,'uncased_L-12_H-768_A-12.zip'), "r") as zip_ref: 78 | zip_ref.extractall(output_dir) 79 | 80 | os.remove(os.path.join(output_dir,'uncased_L-12_H-768_A-12.zip')) 81 | 82 | if model == 'electra': 83 | zip_file = wget.download('https://storage.googleapis.com/electra-data/electra_base.zip', 84 | os.path.join(output_dir,'electra_base.zip')) 85 | 86 | with ZipFile(os.path.join(output_dir,'electra_base.zip'), "r") as zip_ref: 87 | zip_ref.extractall(output_dir) 88 | 89 | files = os.listdir(os.path.join(output_dir,'electra_base')) 90 | for f in files: 91 | shutil.move(os.path.join(output_dir,'electra_base',f), output_dir) 92 | 93 | os.remove(os.path.join(output_dir,'electra_base.zip')) 94 | os.rmdir(os.path.join(output_dir,'electra_base')) -------------------------------------------------------------------------------- /classitransformers/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 FMR LLC 2 | # SPDX-License-Identifer: Apache-2.0 3 | 4 | import os 5 | import sys 6 | import time 7 | import torch 8 | import scipy 9 | import random 10 | import logging 11 | import datetime 12 | import numpy as np 13 | import pandas as pd 14 | from torch.utils.data import TensorDataset, random_split 15 | from transformers import AdamW, get_linear_schedule_with_warmup 16 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 17 | from transformers import DistilBertTokenizer, AlbertTokenizer, RobertaTokenizer, BertTokenizer 18 | from transformers import DistilBertForSequenceClassification, AlbertForSequenceClassification, RobertaForSequenceClassification, BertForSequenceClassification 19 | 20 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 21 | 22 | 23 | class InferenceModel: 24 | 25 | def __init__(self, num_classes, max_seq_length, batch_size, model_name, model_path): 26 | 27 | self.num_classes = num_classes 28 | self.classification_model_dir = model_path 29 | self.max_seq_length = max_seq_length 30 | self.predict_batch_size = batch_size 31 | self.model_name = model_name 32 | 33 | if torch.cuda.is_available(): 34 | self.device = torch.device("cuda") 35 | else: 36 | print('No GPU available, using the CPU instead.') 37 | self.device = torch.device("cpu") 38 | 39 | if self.model_name == 'bert': 40 | self.model = BertForSequenceClassification.from_pretrained(self.classification_model_dir, 41 | num_labels=self.num_classes) 42 | self.tokenizer = BertTokenizer.from_pretrained(self.classification_model_dir) 43 | 44 | if self.model_name == 'albert': 45 | self.model = AlbertForSequenceClassification.from_pretrained(self.classification_model_dir, 46 | num_labels=self.num_classes) 47 | self.tokenizer = AlbertTokenizer.from_pretrained(self.classification_model_dir) 48 | 49 | if self.model_name == 'distilbert': 50 | self.model = DistilBertForSequenceClassification.from_pretrained(self.classification_model_dir, 51 | num_labels=self.num_classes) 52 | self.tokenizer = DistilBertTokenizer.from_pretrained(self.classification_model_dir) 53 | 54 | if self.model_name == 'roberta': 55 | self.model = RobertaForSequenceClassification.from_pretrained(self.classification_model_dir, 56 | num_labels=self.num_classes) 57 | self.tokenizer = RobertaTokenizer.from_pretrained(self.classification_model_dir) 58 | 59 | if torch.cuda.is_available(): 60 | self.model.cuda() 61 | 62 | 63 | 64 | 65 | def text_inference(self, texts): 66 | 67 | inf_labels = [] 68 | for i in range(len(texts)): 69 | inf_labels.append(0) 70 | 71 | test_params = tokenize_sentences(texts, self.max_seq_length, self.tokenizer, inf_labels) 72 | prediction_data = TensorDataset(test_params['input_ids'], test_params['attention_masks'], 73 | test_params['labels']) 74 | prediction_sampler = SequentialSampler(prediction_data) 75 | prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, 76 | batch_size=self.predict_batch_size) 77 | 78 | self.model.eval() 79 | predictions = [] 80 | 81 | for batch in prediction_dataloader: 82 | b_input_ids, b_input_mask, b_labels = tuple(t.to(self.device) for t in batch) 83 | with torch.no_grad(): 84 | loss, logits = self.model(b_input_ids, 85 | attention_mask=b_input_mask, 86 | labels=b_labels) 87 | 88 | logits = logits.detach().cpu().numpy() 89 | predictions.extend(logits) 90 | 91 | torch.cuda.empty_cache() 92 | return [scipy.special.expit(pred) for pred in predictions] 93 | 94 | 95 | def tokenize_sentences(sentences,max_seq_len,tokenizer,labels = []): 96 | # Tokenize all of the sentences and map the tokens to thier word IDs. 97 | input_ids = [] 98 | attention_masks = [] 99 | 100 | for sent in sentences: 101 | 102 | encoded_dict = tokenizer.encode_plus( 103 | sent, # Sentence to encode. 104 | add_special_tokens=True, # Add '[CLS]' and '[SEP]' 105 | max_length=max_seq_len, # Pad & truncate all sentences. 106 | pad_to_max_length=True, 107 | return_attention_mask=True, # Construct attn. masks. 108 | return_tensors='pt', # Return pytorch tensors. 109 | ) 110 | 111 | # Add the encoded sentence to the list. 112 | input_ids.append(encoded_dict['input_ids']) 113 | 114 | # And its attention mask (simply differentiates padding from non-padding). 115 | attention_masks.append(encoded_dict['attention_mask']) 116 | 117 | # Convert the lists into tensors. 118 | input_ids = torch.cat(input_ids, dim=0) 119 | attention_masks = torch.cat(attention_masks, dim=0) 120 | try: 121 | labels = torch.tensor(labels) 122 | except: 123 | labels = [] 124 | for i in range(len(sentences)): 125 | labels.append[0] 126 | 127 | return { 128 | 'input_ids': input_ids, 129 | 'attention_masks': attention_masks, 130 | 'labels': labels 131 | } 132 | -------------------------------------------------------------------------------- /classitransformers/pytransformers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 FMR LLC 2 | # SPDX-License-Identifer: Apache-2.0 3 | 4 | import os 5 | import sys 6 | import time 7 | import torch 8 | import scipy 9 | import random 10 | import logging 11 | import datetime 12 | import numpy as np 13 | import pandas as pd 14 | from torch.utils.data import TensorDataset, random_split 15 | from transformers import AdamW, get_linear_schedule_with_warmup 16 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 17 | from transformers import DistilBertTokenizer, AlbertTokenizer, RobertaTokenizer, BertTokenizer 18 | from transformers import DistilBertForSequenceClassification, AlbertForSequenceClassification, RobertaForSequenceClassification, BertForSequenceClassification 19 | 20 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 21 | 22 | 23 | class TransformersClassification: 24 | 25 | def __init__(self, config): 26 | # self.name, self.num_classes, epochs, batchs 27 | 28 | self.Configs = config 29 | self.num_classes = len(config.label_list) 30 | 31 | self.train_logits = [] 32 | self.validation_logits = [] 33 | self.test_logits = [] 34 | 35 | self.train_texts = [] 36 | self.train_labels = [] 37 | self.validation_texts = [] 38 | self.validation_labels = [] 39 | self.test_texts = [] 40 | self.test_labels = [] 41 | 42 | train = pd.read_csv(os.path.join(self.Configs.data_dir, 'train.csv')) 43 | 44 | try: 45 | dev = pd.read_csv(os.path.join(self.Configs.data_dir, 'dev.csv')) 46 | 47 | except: 48 | print('Validation disabled.') 49 | test = pd.read_csv(os.path.join(self.Configs.data_dir, 'test.csv')) 50 | 51 | self.train_texts = train['text'].tolist() 52 | 53 | self.train_labels = train['label'].tolist() 54 | 55 | try: 56 | self.validation_texts = dev['text'].tolist() 57 | self.validation_labels = dev['label'].tolist() 58 | 59 | except: 60 | pass 61 | self.test_texts = test['text'].tolist() 62 | 63 | for i in range(len(self.test_texts)): 64 | self.test_labels.append(0) 65 | 66 | 67 | if torch.cuda.is_available(): 68 | self.device = torch.device("cuda") 69 | else: 70 | print('No GPU available, using the CPU instead.') 71 | self.device = torch.device("cpu") 72 | 73 | if self.Configs.model_name == 'bert': 74 | self.model = BertForSequenceClassification.from_pretrained(self.Configs.pretrained_model_dir, 75 | num_labels=self.num_classes) 76 | self.tokenizer = BertTokenizer.from_pretrained(self.Configs.pretrained_model_dir) 77 | 78 | if self.Configs.model_name == 'albert': 79 | self.model = AlbertForSequenceClassification.from_pretrained(self.Configs.pretrained_model_dir, 80 | num_labels=self.num_classes) 81 | self.tokenizer = AlbertTokenizer.from_pretrained(self.Configs.pretrained_model_dir) 82 | 83 | if self.Configs.model_name == 'distilbert': 84 | self.model = DistilBertForSequenceClassification.from_pretrained(self.Configs.pretrained_model_dir, 85 | num_labels=self.num_classes) 86 | self.tokenizer = DistilBertTokenizer.from_pretrained(self.Configs.pretrained_model_dir) 87 | 88 | if self.Configs.model_name == 'roberta': 89 | self.model = RobertaForSequenceClassification.from_pretrained(self.Configs.pretrained_model_dir, 90 | num_labels=self.num_classes) 91 | self.tokenizer = RobertaTokenizer.from_pretrained(self.Configs.pretrained_model_dir) 92 | 93 | if torch.cuda.is_available(): 94 | self.model.cuda() 95 | 96 | 97 | def train(self): 98 | 99 | # Combine the training inputs into a TensorDataset. 100 | train_params = tokenize_sentences(self.train_texts,self.Configs.max_seq_length,self.tokenizer,self.train_labels) 101 | train_dataset = TensorDataset(train_params['input_ids'], train_params['attention_masks'], 102 | train_params['labels']) 103 | 104 | if self.Configs.do_eval == True: 105 | 106 | validation_params = tokenize_sentences(self.validation_texts, 107 | self.Configs.max_seq_length, self.tokenizer, self.validation_labels) 108 | validation_dataset = TensorDataset(validation_params['input_ids'], validation_params['attention_masks'], 109 | validation_params['labels']) 110 | 111 | 112 | # Create the DataLoaders for our training and validation sets. 113 | # We'll take training samples in random order. 114 | train_dataloader = DataLoader( 115 | train_dataset, # The training samples. 116 | sampler=RandomSampler(train_dataset), # Select batches randomly 117 | batch_size=self.Configs.train_batch_size # Trains with this batch size. 118 | ) 119 | 120 | if self.Configs.do_eval == True: 121 | validation_dataloader = DataLoader( 122 | validation_dataset, # The training samples. 123 | sampler=RandomSampler(validation_dataset), # Select batches randomly 124 | batch_size=self.Configs.eval_batch_size # Trains with this batch size. 125 | ) 126 | 127 | # Note: AdamW is a class from the huggingface library (as opposed to pytorch) 128 | # I believe the 'W' stands for 'Weight Decay fix" 129 | optimizer = AdamW(self.model.parameters(), 130 | lr=self.Configs.learning_rate, # default is 5e-5 131 | eps=1e-8 # args.adam_epsilon - default is 1e-8. 132 | ) 133 | 134 | 135 | epochs = self.Configs.num_train_epochs 136 | 137 | # Total number of training steps is [number of batches] x [number of epochs]. 138 | # (Note that this is not the same as the number of training samples). 139 | total_steps = len(train_dataloader) * epochs 140 | 141 | # Create the learning rate scheduler. 142 | scheduler = get_linear_schedule_with_warmup(optimizer, 143 | num_warmup_steps=total_steps / 10, # Default value in run_glue.py 144 | num_training_steps=total_steps) 145 | 146 | 147 | # Set the seed value all over the place to make this reproducible. 148 | seed_val = 42 149 | 150 | random.seed(seed_val) 151 | np.random.seed(seed_val) 152 | torch.manual_seed(seed_val) 153 | torch.cuda.manual_seed_all(seed_val) 154 | 155 | # We'll store a number of quantities such as training and validation loss, 156 | # validation accuracy, and timings. 157 | training_stats = [] 158 | 159 | # Measure the total training time for the whole run. 160 | total_t0 = time.time() 161 | 162 | # For each epoch... 163 | for epoch_i in range(0, epochs): 164 | 165 | # Perform one full pass over the training set. 166 | 167 | print("") 168 | print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs)) 169 | print('Training...') 170 | 171 | # Measure how long the training epoch takes. 172 | t0 = time.time() 173 | 174 | # Reset the total loss for this epoch. 175 | total_train_loss = 0 176 | train_logits = [] 177 | validation_logits = [] 178 | 179 | self.model.train() 180 | 181 | # For each batch of training data... 182 | for step, batch in enumerate(train_dataloader): 183 | 184 | # Progress update every 40 batches. 185 | if step % 40 == 0 and not step == 0: 186 | # Calculate elapsed time in minutes. 187 | elapsed = format_time(time.time() - t0) 188 | 189 | # Report progress. 190 | print( 191 | ' Batch {:>5,} of {:>5,}. Elapsed: {:}.'.format(step, len(train_dataloader), elapsed)) 192 | 193 | b_input_ids = batch[0].to(self.device) 194 | b_input_mask = batch[1].to(self.device) 195 | b_labels = batch[2].to(self.device) 196 | 197 | self.model.zero_grad() 198 | 199 | loss, logits = self.model(b_input_ids, 200 | 201 | attention_mask=b_input_mask, 202 | labels=b_labels) 203 | 204 | logits = logits.detach().cpu().numpy() 205 | label_ids = b_labels.to('cpu').numpy() 206 | train_logits.append(logits) 207 | 208 | # Accumulate the training loss over all of the batches so that we can 209 | # calculate the average loss at the end. 210 | total_train_loss += loss.item() 211 | 212 | # Perform a backward pass to calculate the gradients. 213 | loss.backward() 214 | 215 | # Clip the norm of the gradients to 1.0. 216 | # This is to help prevent the "exploding gradients" problem. 217 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) 218 | 219 | # Update parameters and take a step using the computed gradient. 220 | optimizer.step() 221 | 222 | # Update the learning rate. 223 | scheduler.step() 224 | 225 | # Calculate the average loss over all of the batches. 226 | avg_train_loss = total_train_loss / len(train_dataloader) 227 | 228 | # Measure how long this epoch took. 229 | training_time = format_time(time.time() - t0) 230 | 231 | print("") 232 | print(" Average training loss: {0:.2f}".format(avg_train_loss)) 233 | print(" Training epcoh took: {:}".format(training_time)) 234 | 235 | torch.cuda.empty_cache() 236 | 237 | if self.Configs.do_eval == True: 238 | 239 | # After the completion of each training epoch, measure our performance on 240 | # our validation set. 241 | 242 | print("") 243 | print("Running Validation...") 244 | 245 | t0 = time.time() 246 | 247 | # Put the model in evaluation mode--the dropout layers behave differently 248 | # during evaluation. 249 | self.model.eval() 250 | 251 | # Tracking variables 252 | total_eval_accuracy = 0 253 | total_eval_loss = 0 254 | nb_eval_steps = 0 255 | 256 | # Evaluate data for one epoch 257 | for batch in validation_dataloader: 258 | 259 | b_input_ids = batch[0].to(self.device) 260 | b_input_mask = batch[1].to(self.device) 261 | b_labels = batch[2].to(self.device) 262 | 263 | # Tell pytorch not to bother with constructing the compute graph during 264 | # the forward pass, since this is only needed for backprop (training). 265 | with torch.no_grad(): 266 | 267 | loss, logits = self.model(b_input_ids, 268 | attention_mask=b_input_mask, 269 | labels=b_labels) 270 | 271 | # Move logits and labels to CPU 272 | logits = logits.detach().cpu().numpy() 273 | label_ids = b_labels.to('cpu').numpy() 274 | validation_logits.append(logits) 275 | # Accumulate the validation loss. 276 | # total_eval_loss += loss.item() 277 | 278 | # Calculate the accuracy for this batch of test sentences, and 279 | # accumulate it over all batches. 280 | total_eval_accuracy += flat_accuracy(logits, label_ids) 281 | 282 | # Report the final accuracy for this validation run. 283 | avg_val_accuracy = total_eval_accuracy / len(validation_dataloader) 284 | print(" Accuracy: {0:.2f}".format(avg_val_accuracy)) 285 | 286 | # Calculate the average loss over all of the batches. 287 | avg_val_loss = total_eval_loss / len(validation_dataloader) 288 | 289 | # Measure how long the validation run took. 290 | validation_time = format_time(time.time() - t0) 291 | 292 | print(" Validation Loss: {0:.2f}".format(avg_val_loss)) 293 | print(" Validation took: {:}".format(validation_time)) 294 | 295 | 296 | # torch.cuda.empty_cache() 297 | 298 | # Record all statistics from this epoch. 299 | training_stats.append( 300 | { 301 | 'epoch': epoch_i + 1, 302 | 'Training Loss': avg_train_loss, 303 | 'Valid. Loss': avg_val_loss if self.Configs.do_eval == True else 'NA', 304 | 'Valid. Accur.': avg_val_accuracy if self.Configs.do_eval == True else 'NA', 305 | 'Training Time': training_time, 306 | 'Validation Time': validation_time if self.Configs.do_eval == True else 'NA' 307 | } 308 | ) 309 | 310 | if self.Configs.do_eval == True: 311 | validation_logits = np.vstack(validation_logits) 312 | 313 | train_logits = np.vstack(train_logits) 314 | 315 | self.train_logits = train_logits 316 | self.validation_logits = validation_logits 317 | 318 | # torch.state_dict(self.model,'/content/checkpoint.pth') 319 | output_dir = self.Configs.output_dir 320 | 321 | # Create output directory if needed 322 | if not os.path.exists(output_dir): 323 | os.makedirs(output_dir) 324 | 325 | print("Saving model to %s" % output_dir) 326 | 327 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 328 | # They can then be reloaded using `from_pretrained()` 329 | model_to_save = self.model.module if hasattr(self.model, 'module') else self.model # Take care of distributed/parallel training 330 | model_to_save.save_pretrained(output_dir) 331 | self.tokenizer.save_pretrained(output_dir) 332 | 333 | print("") 334 | print("Training complete!") 335 | 336 | print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0))) 337 | 338 | 339 | def test(self): 340 | 341 | test_params = tokenize_sentences(self.test_texts,self.Configs.max_seq_length,self.tokenizer,self.test_labels) 342 | 343 | prediction_data = TensorDataset(test_params['input_ids'], test_params['attention_masks'], test_params['labels']) 344 | prediction_sampler = SequentialSampler(prediction_data) 345 | prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=self.Configs.predict_batch_size) 346 | 347 | 348 | # Put model in evaluation mode 349 | self.model.eval() 350 | 351 | # Tracking variables 352 | predictions = [] 353 | 354 | # Predict 355 | for batch in prediction_dataloader: 356 | # Add batch to GPU 357 | # Unpack the inputs from our dataloader 358 | b_input_ids, b_input_mask, b_labels = tuple(t.to(self.device) for t in batch) 359 | 360 | # Telling the model not to compute or store gradients, saving memory and 361 | # speeding up prediction 362 | with torch.no_grad(): 363 | # Forward pass, calculate logit predictions 364 | loss, logits = self.model(b_input_ids, 365 | attention_mask=b_input_mask, 366 | labels=b_labels) 367 | 368 | # Move logits and labels to CPU 369 | logits = logits.detach().cpu().numpy() 370 | 371 | # Store predictions and true labels 372 | predictions.extend(logits) 373 | 374 | self.test_logits = predictions 375 | torch.cuda.empty_cache() 376 | 377 | return [scipy.special.expit(pred) for pred in predictions] 378 | 379 | 380 | def text_inference(self, texts): 381 | 382 | inf_labels = [] 383 | for i in range(len(texts)): 384 | inf_labels.append(0) 385 | 386 | test_params = tokenize_sentences(texts, self.Configs.max_seq_length, self.tokenizer, inf_labels) 387 | prediction_data = TensorDataset(test_params['input_ids'], test_params['attention_masks'], 388 | test_params['labels']) 389 | prediction_sampler = SequentialSampler(prediction_data) 390 | prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, 391 | batch_size=self.Configs.predict_batch_size) 392 | 393 | self.model.eval() 394 | predictions = [] 395 | 396 | for batch in prediction_dataloader: 397 | b_input_ids, b_input_mask, b_labels = tuple(t.to(self.device) for t in batch) 398 | with torch.no_grad(): 399 | loss, logits = self.model(b_input_ids, 400 | attention_mask=b_input_mask, 401 | labels=b_labels) 402 | 403 | logits = logits.detach().cpu().numpy() 404 | predictions.extend(logits) 405 | 406 | self.test_logits = predictions 407 | torch.cuda.empty_cache() 408 | return [scipy.special.expit(pred) for pred in predictions] 409 | 410 | 411 | def add_CLS_and_SEP(sentences,tokenizer): 412 | max_len = 0 413 | 414 | for sent in sentences: 415 | # Tokenize the text and add `[CLS]` and `[SEP]` tokens. 416 | input_ids = tokenizer.encode(sent, add_special_tokens=True) 417 | 418 | # Update the maximum sentence length. 419 | max_len = max(max_len, len(input_ids)) 420 | 421 | return sentences 422 | 423 | 424 | def tokenize_sentences(sentences,max_seq_len,tokenizer,labels = []): 425 | # Tokenize all of the sentences and map the tokens to thier word IDs. 426 | input_ids = [] 427 | attention_masks = [] 428 | 429 | sentences = add_CLS_and_SEP(sentences,tokenizer) 430 | 431 | for sent in sentences: 432 | 433 | encoded_dict = tokenizer.encode_plus( 434 | sent, # Sentence to encode. 435 | add_special_tokens=True, # Add '[CLS]' and '[SEP]' 436 | max_length=max_seq_len, # Pad & truncate all sentences. 437 | pad_to_max_length=True, 438 | return_attention_mask=True, # Construct attn. masks. 439 | return_tensors='pt', # Return pytorch tensors. 440 | ) 441 | 442 | # Add the encoded sentence to the list. 443 | input_ids.append(encoded_dict['input_ids']) 444 | 445 | # And its attention mask (simply differentiates padding from non-padding). 446 | attention_masks.append(encoded_dict['attention_mask']) 447 | 448 | # Convert the lists into tensors. 449 | input_ids = torch.cat(input_ids, dim=0) 450 | attention_masks = torch.cat(attention_masks, dim=0) 451 | try: 452 | labels = torch.tensor(labels) 453 | except: 454 | labels = [] 455 | for i in range(len(sentences)): 456 | labels.append[0] 457 | 458 | return { 459 | 'input_ids': input_ids, 460 | 'attention_masks': attention_masks, 461 | 'labels': labels 462 | } 463 | 464 | 465 | 466 | # Function to calculate the accuracy of our predictions vs labels 467 | def flat_accuracy(preds, labels): 468 | pred_flat = np.argmax(preds, axis=1).flatten() 469 | labels_flat = labels.flatten() 470 | return np.sum(pred_flat == labels_flat) / len(labels_flat) 471 | 472 | 473 | 474 | 475 | def format_time(elapsed): 476 | ''' 477 | Takes a time in seconds and returns a string hh:mm:ss 478 | ''' 479 | # Round to the nearest second. 480 | elapsed_rounded = int(round((elapsed))) 481 | 482 | # Format as hh:mm:ss 483 | return str(datetime.timedelta(seconds=elapsed_rounded)) -------------------------------------------------------------------------------- /classitransformers/report.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 FMR LLC 2 | # SPDX-License-Identifer: Apache-2.0 3 | 4 | import numpy as np 5 | from itertools import product 6 | import matplotlib.pyplot as plt 7 | from sklearn.metrics import accuracy_score, confusion_matrix 8 | from sklearn.metrics import precision_score, recall_score, f1_score 9 | 10 | 11 | def plot_confusion_matrix(cm, classes = [0, 1], cmap=plt.cm.Accent): 12 | 13 | """ function to plot confusion matrix """ 14 | 15 | title = 'Confusion matrix' 16 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 17 | plt.title(title) 18 | plt.colorbar() 19 | tick_marks = np.arange(len(classes)) 20 | plt.xticks(tick_marks, classes, rotation=45) 21 | plt.yticks(tick_marks, classes) 22 | 23 | fmt = 'd' 24 | thresh = cm.max() / 2. 25 | for i, j in product(range(cm.shape[0]), range(cm.shape[1])): 26 | plt.text(j, i, format(cm[i, j], fmt), 27 | horizontalalignment="center", 28 | color="black") 29 | 30 | plt.tight_layout() 31 | plt.ylabel('True label') 32 | plt.xlabel('Predicted label') 33 | 34 | 35 | 36 | def metrics(y, y_pred, average='macro'): 37 | 38 | """ 39 | uses sklearn metrics to calculate accuracy, precision, recall, f1_score. 40 | 41 | Parameters 42 | ---------- 43 | y : actual labels. array-like of shape (n_samples,) 44 | y_pred : predicted labels. array-like of shape (n_samples,) 45 | average: determines the type of averaging performed on the data. default: macro 46 | """ 47 | 48 | cm = confusion_matrix(y, y_pred) 49 | plot_confusion_matrix(cm, np.arange(cm.shape[0])) 50 | print("\n") 51 | print("Accuracy: {0:.3f}".format(accuracy_score(y, y_pred))) 52 | print("Precision: {0:.3f}".format(precision_score(y, y_pred, average=average))) 53 | print("Recall: {0:.3f}".format(recall_score(y, y_pred, average=average))) 54 | print("F1-Score: {0:.3f}".format(f1_score(y, y_pred, average=average))) -------------------------------------------------------------------------------- /data_preparation_format.txt: -------------------------------------------------------------------------------- 1 | For every task, there is data directory parameter as part of Configs class e.g. data_dir='./datasets/name_of_dataset_folder/' 2 | 3 | The data directory should have the following comma separated (csv format) files. 4 | 5 | train.csv (train set) 6 | test.csv (test set) 7 | dev.csv (validation set) 8 | 9 | 10 | train.csv - Used for training the model. 11 | dev.csv - Used as validation set during the model training. 12 | test.csv - Used for scoring purpose after model training finished. 13 | 14 | 15 | Both the training and validation set should contain the following columns in sequence. 16 | 17 | id 18 | text 19 | label (numeric, not text categories) 20 | 21 | 22 | For example. 23 | 24 | id, text, label 25 | 0, Raisio is the site of the main production plan, 2 26 | 1, The production capacity can be tripled without, 0 27 | 28 | 29 | Test set can have the following columns as labels may not be there. 30 | 31 | id 32 | text 33 | 34 | For example. 35 | 36 | id, text 37 | 0, Operating profit was EUR 11.4 mn , up from EURO 38 | 1, The expansion will be delivered in the fourth 39 | -------------------------------------------------------------------------------- /datasets/bbcfull/infer.csv: -------------------------------------------------------------------------------- 1 | id,text 2 | 0,He is very good in writing a book on financial investment money making strategy. There are some transactional charges but the profit we make is awesome. The business people are exchanging bulk money over trades happening every seconds. Even retailers are happy doing trade and adding sum to their investments. 3 | -------------------------------------------------------------------------------- /datasets/bbcsports/infer.csv: -------------------------------------------------------------------------------- 1 | id,text 2 | 0,Rodger Fedrer is one of the great tennis player. Rafael Nadal is very good player at clay court. They have combined won more than 30 grand slams till now. 3 | -------------------------------------------------------------------------------- /electra/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. -------------------------------------------------------------------------------- /electra/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Functions and classes related to optimization (weight updates). 17 | Modified from the original BERT code to allow for having separate learning 18 | rates for different layers of the network. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import collections 26 | import re 27 | import tensorflow.compat.v1 as tf 28 | 29 | 30 | def create_optimizer( 31 | loss, learning_rate, num_train_steps, weight_decay_rate=0.0, use_tpu=False, 32 | warmup_steps=0, warmup_proportion=0, lr_decay_power=1.0, 33 | layerwise_lr_decay_power=-1, n_transformer_layers=None): 34 | """Creates an optimizer and training op.""" 35 | global_step = tf.train.get_or_create_global_step() 36 | learning_rate = tf.train.polynomial_decay( 37 | learning_rate, 38 | global_step, 39 | num_train_steps, 40 | end_learning_rate=0.0, 41 | power=lr_decay_power, 42 | cycle=False) 43 | warmup_steps = max(num_train_steps * warmup_proportion, warmup_steps) 44 | learning_rate *= tf.minimum( 45 | 1.0, tf.cast(global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)) 46 | 47 | if layerwise_lr_decay_power > 0: 48 | learning_rate = _get_layer_lrs(learning_rate, layerwise_lr_decay_power, 49 | n_transformer_layers) 50 | optimizer = AdamWeightDecayOptimizer( 51 | learning_rate=learning_rate, 52 | weight_decay_rate=weight_decay_rate, 53 | beta_1=0.9, 54 | beta_2=0.999, 55 | epsilon=1e-6, 56 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 57 | if use_tpu: 58 | optimizer = tf.tpu.CrossShardOptimizer(optimizer) 59 | 60 | tvars = tf.trainable_variables() 61 | grads = tf.gradients(loss, tvars) 62 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 63 | train_op = optimizer.apply_gradients( 64 | zip(grads, tvars), global_step=global_step) 65 | new_global_step = global_step + 1 66 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 67 | return train_op 68 | 69 | 70 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 71 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 72 | 73 | def __init__(self, 74 | learning_rate, 75 | weight_decay_rate=0.0, 76 | beta_1=0.9, 77 | beta_2=0.999, 78 | epsilon=1e-6, 79 | exclude_from_weight_decay=None, 80 | name="AdamWeightDecayOptimizer"): 81 | """Constructs a AdamWeightDecayOptimizer.""" 82 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 83 | 84 | self.learning_rate = learning_rate 85 | self.weight_decay_rate = weight_decay_rate 86 | self.beta_1 = beta_1 87 | self.beta_2 = beta_2 88 | self.epsilon = epsilon 89 | self.exclude_from_weight_decay = exclude_from_weight_decay 90 | 91 | def _apply_gradients(self, grads_and_vars, learning_rate): 92 | """See base class.""" 93 | assignments = [] 94 | for (grad, param) in grads_and_vars: 95 | if grad is None or param is None: 96 | continue 97 | 98 | param_name = self._get_variable_name(param.name) 99 | 100 | m = tf.get_variable( 101 | name=param_name + "/adam_m", 102 | shape=param.shape.as_list(), 103 | dtype=tf.float32, 104 | trainable=False, 105 | initializer=tf.zeros_initializer()) 106 | v = tf.get_variable( 107 | name=param_name + "/adam_v", 108 | shape=param.shape.as_list(), 109 | dtype=tf.float32, 110 | trainable=False, 111 | initializer=tf.zeros_initializer()) 112 | 113 | # Standard Adam update. 114 | next_m = ( 115 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 116 | next_v = ( 117 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 118 | tf.square(grad))) 119 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 120 | 121 | # Just adding the square of the weights to the loss function is *not* 122 | # the correct way of using L2 regularization/weight decay with Adam, 123 | # since that will interact with the m and v parameters in strange ways. 124 | # 125 | # Instead we want ot decay the weights in a manner that doesn't interact 126 | # with the m/v parameters. This is equivalent to adding the square 127 | # of the weights to the loss with plain (non-momentum) SGD. 128 | if self.weight_decay_rate > 0: 129 | if self._do_use_weight_decay(param_name): 130 | update += self.weight_decay_rate * param 131 | 132 | update_with_lr = learning_rate * update 133 | next_param = param - update_with_lr 134 | 135 | assignments.extend( 136 | [param.assign(next_param), 137 | m.assign(next_m), 138 | v.assign(next_v)]) 139 | 140 | return assignments 141 | 142 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 143 | if isinstance(self.learning_rate, dict): 144 | key_to_grads_and_vars = {} 145 | for grad, var in grads_and_vars: 146 | update_for_var = False 147 | for key in self.learning_rate: 148 | if key in var.name: 149 | update_for_var = True 150 | if key not in key_to_grads_and_vars: 151 | key_to_grads_and_vars[key] = [] 152 | key_to_grads_and_vars[key].append((grad, var)) 153 | if not update_for_var: 154 | raise ValueError("No learning rate specified for variable", var) 155 | assignments = [] 156 | for key, key_grads_and_vars in key_to_grads_and_vars.items(): 157 | assignments += self._apply_gradients(key_grads_and_vars, 158 | self.learning_rate[key]) 159 | else: 160 | assignments = self._apply_gradients(grads_and_vars, self.learning_rate) 161 | return tf.group(*assignments, name=name) 162 | 163 | def _do_use_weight_decay(self, param_name): 164 | """Whether to use L2 weight decay for `param_name`.""" 165 | if not self.weight_decay_rate: 166 | return False 167 | if self.exclude_from_weight_decay: 168 | for r in self.exclude_from_weight_decay: 169 | if re.search(r, param_name) is not None: 170 | return False 171 | return True 172 | 173 | def _get_variable_name(self, param_name): 174 | """Get the variable name from the tensor name.""" 175 | m = re.match("^(.*):\\d+$", param_name) 176 | if m is not None: 177 | param_name = m.group(1) 178 | return param_name 179 | 180 | 181 | def _get_layer_lrs(learning_rate, layer_decay, n_layers): 182 | """Have lower learning rates for layers closer to the input.""" 183 | key_to_depths = collections.OrderedDict({ 184 | "/embeddings/": 0, 185 | "/embeddings_project/": 0, 186 | "task_specific/": n_layers + 2, 187 | }) 188 | for layer in range(n_layers): 189 | key_to_depths["encoder/layer_" + str(layer) + "/"] = layer + 1 190 | return { 191 | key: learning_rate * (layer_decay ** (n_layers + 2 - depth)) 192 | for key, depth in key_to_depths.items() 193 | } 194 | -------------------------------------------------------------------------------- /electra/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tokenization classes, the same as used for BERT.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import unicodedata 24 | import six 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | 29 | def convert_to_unicode(text): 30 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 31 | if six.PY3: 32 | if isinstance(text, str): 33 | return text 34 | elif isinstance(text, bytes): 35 | return text.decode("utf-8", "ignore") 36 | else: 37 | raise ValueError("Unsupported string type: %s" % (type(text))) 38 | elif six.PY2: 39 | if isinstance(text, str): 40 | return text.decode("utf-8", "ignore") 41 | elif isinstance(text, unicode): 42 | return text 43 | else: 44 | raise ValueError("Unsupported string type: %s" % (type(text))) 45 | else: 46 | raise ValueError("Not running on Python2 or Python 3?") 47 | 48 | 49 | def printable_text(text): 50 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 51 | 52 | # These functions want `str` for both Python2 and Python3, but in one case 53 | # it's a Unicode string and in the other it's a byte string. 54 | if six.PY3: 55 | if isinstance(text, str): 56 | return text 57 | elif isinstance(text, bytes): 58 | return text.decode("utf-8", "ignore") 59 | else: 60 | raise ValueError("Unsupported string type: %s" % (type(text))) 61 | elif six.PY2: 62 | if isinstance(text, str): 63 | return text 64 | elif isinstance(text, unicode): 65 | return text.encode("utf-8") 66 | else: 67 | raise ValueError("Unsupported string type: %s" % (type(text))) 68 | else: 69 | raise ValueError("Not running on Python2 or Python 3?") 70 | 71 | 72 | def load_vocab(vocab_file): 73 | """Loads a vocabulary file into a dictionary.""" 74 | vocab = collections.OrderedDict() 75 | index = 0 76 | with tf.io.gfile.GFile(vocab_file, "r") as reader: 77 | while True: 78 | token = convert_to_unicode(reader.readline()) 79 | if not token: 80 | break 81 | token = token.strip() 82 | vocab[token] = index 83 | index += 1 84 | return vocab 85 | 86 | 87 | def convert_by_vocab(vocab, items): 88 | """Converts a sequence of [tokens|ids] using the vocab.""" 89 | output = [] 90 | for item in items: 91 | output.append(vocab[item]) 92 | return output 93 | 94 | 95 | def convert_tokens_to_ids(vocab, tokens): 96 | return convert_by_vocab(vocab, tokens) 97 | 98 | 99 | def convert_ids_to_tokens(inv_vocab, ids): 100 | return convert_by_vocab(inv_vocab, ids) 101 | 102 | 103 | def whitespace_tokenize(text): 104 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 105 | text = text.strip() 106 | if not text: 107 | return [] 108 | tokens = text.split() 109 | return tokens 110 | 111 | 112 | class FullTokenizer(object): 113 | """Runs end-to-end tokenziation.""" 114 | 115 | def __init__(self, vocab_file, do_lower_case=True): 116 | self.vocab = load_vocab(vocab_file) 117 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 118 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 119 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 120 | 121 | def tokenize(self, text): 122 | split_tokens = [] 123 | for token in self.basic_tokenizer.tokenize(text): 124 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 125 | split_tokens.append(sub_token) 126 | 127 | return split_tokens 128 | 129 | def convert_tokens_to_ids(self, tokens): 130 | return convert_by_vocab(self.vocab, tokens) 131 | 132 | def convert_ids_to_tokens(self, ids): 133 | return convert_by_vocab(self.inv_vocab, ids) 134 | 135 | 136 | class BasicTokenizer(object): 137 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 138 | 139 | def __init__(self, do_lower_case=True): 140 | """Constructs a BasicTokenizer. 141 | 142 | Args: 143 | do_lower_case: Whether to lower case the input. 144 | """ 145 | self.do_lower_case = do_lower_case 146 | 147 | def tokenize(self, text): 148 | """Tokenizes a piece of text.""" 149 | text = convert_to_unicode(text) 150 | text = self._clean_text(text) 151 | 152 | # This was added on November 1st, 2018 for the multilingual and Chinese 153 | # models. This is also applied to the English models now, but it doesn't 154 | # matter since the English models were not trained on any Chinese data 155 | # and generally don't have any Chinese data in them (there are Chinese 156 | # characters in the vocabulary because Wikipedia does have some Chinese 157 | # words in the English Wikipedia.). 158 | text = self._tokenize_chinese_chars(text) 159 | 160 | orig_tokens = whitespace_tokenize(text) 161 | split_tokens = [] 162 | for token in orig_tokens: 163 | if self.do_lower_case: 164 | token = token.lower() 165 | token = self._run_strip_accents(token) 166 | split_tokens.extend(self._run_split_on_punc(token)) 167 | 168 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 169 | return output_tokens 170 | 171 | def _run_strip_accents(self, text): 172 | """Strips accents from a piece of text.""" 173 | text = unicodedata.normalize("NFD", text) 174 | output = [] 175 | for char in text: 176 | cat = unicodedata.category(char) 177 | if cat == "Mn": 178 | continue 179 | output.append(char) 180 | return "".join(output) 181 | 182 | def _run_split_on_punc(self, text): 183 | """Splits punctuation on a piece of text.""" 184 | chars = list(text) 185 | i = 0 186 | start_new_word = True 187 | output = [] 188 | while i < len(chars): 189 | char = chars[i] 190 | if _is_punctuation(char): 191 | output.append([char]) 192 | start_new_word = True 193 | else: 194 | if start_new_word: 195 | output.append([]) 196 | start_new_word = False 197 | output[-1].append(char) 198 | i += 1 199 | 200 | return ["".join(x) for x in output] 201 | 202 | def _tokenize_chinese_chars(self, text): 203 | """Adds whitespace around any CJK character.""" 204 | output = [] 205 | for char in text: 206 | cp = ord(char) 207 | if self._is_chinese_char(cp): 208 | output.append(" ") 209 | output.append(char) 210 | output.append(" ") 211 | else: 212 | output.append(char) 213 | return "".join(output) 214 | 215 | def _is_chinese_char(self, cp): 216 | """Checks whether CP is the codepoint of a CJK character.""" 217 | # This defines a "chinese character" as anything in the CJK Unicode block: 218 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 219 | # 220 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 221 | # despite its name. The modern Korean Hangul alphabet is a different block, 222 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 223 | # space-separated words, so they are not treated specially and handled 224 | # like the all of the other languages. 225 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 226 | (cp >= 0x3400 and cp <= 0x4DBF) or # 227 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 228 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 229 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 230 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 231 | (cp >= 0xF900 and cp <= 0xFAFF) or # 232 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 233 | return True 234 | 235 | return False 236 | 237 | def _clean_text(self, text): 238 | """Performs invalid character removal and whitespace cleanup on text.""" 239 | output = [] 240 | for char in text: 241 | cp = ord(char) 242 | if cp == 0 or cp == 0xfffd or _is_control(char): 243 | continue 244 | if _is_whitespace(char): 245 | output.append(" ") 246 | else: 247 | output.append(char) 248 | return "".join(output) 249 | 250 | 251 | class WordpieceTokenizer(object): 252 | """Runs WordPiece tokenziation.""" 253 | 254 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 255 | self.vocab = vocab 256 | self.unk_token = unk_token 257 | self.max_input_chars_per_word = max_input_chars_per_word 258 | 259 | def tokenize(self, text): 260 | """Tokenizes a piece of text into its word pieces. 261 | 262 | This uses a greedy longest-match-first algorithm to perform tokenization 263 | using the given vocabulary. 264 | 265 | For example: 266 | input = "unaffable" 267 | output = ["un", "##aff", "##able"] 268 | 269 | Args: 270 | text: A single token or whitespace separated tokens. This should have 271 | already been passed through `BasicTokenizer. 272 | 273 | Returns: 274 | A list of wordpiece tokens. 275 | """ 276 | 277 | text = convert_to_unicode(text) 278 | 279 | output_tokens = [] 280 | for token in whitespace_tokenize(text): 281 | chars = list(token) 282 | if len(chars) > self.max_input_chars_per_word: 283 | output_tokens.append(self.unk_token) 284 | continue 285 | 286 | is_bad = False 287 | start = 0 288 | sub_tokens = [] 289 | while start < len(chars): 290 | end = len(chars) 291 | cur_substr = None 292 | while start < end: 293 | substr = "".join(chars[start:end]) 294 | if start > 0: 295 | substr = "##" + substr 296 | if substr in self.vocab: 297 | cur_substr = substr 298 | break 299 | end -= 1 300 | if cur_substr is None: 301 | is_bad = True 302 | break 303 | sub_tokens.append(cur_substr) 304 | start = end 305 | 306 | if is_bad: 307 | output_tokens.append(self.unk_token) 308 | else: 309 | output_tokens.extend(sub_tokens) 310 | return output_tokens 311 | 312 | 313 | def _is_whitespace(char): 314 | """Checks whether `chars` is a whitespace character.""" 315 | # \t, \n, and \r are technically contorl characters but we treat them 316 | # as whitespace since they are generally considered as such. 317 | if char == " " or char == "\t" or char == "\n" or char == "\r": 318 | return True 319 | cat = unicodedata.category(char) 320 | if cat == "Zs": 321 | return True 322 | return False 323 | 324 | 325 | def _is_control(char): 326 | """Checks whether `chars` is a control character.""" 327 | # These are technically control characters but we count them as whitespace 328 | # characters. 329 | if char == "\t" or char == "\n" or char == "\r": 330 | return False 331 | cat = unicodedata.category(char) 332 | if cat.startswith("C"): 333 | return True 334 | return False 335 | 336 | 337 | def _is_punctuation(char): 338 | """Checks whether `chars` is a punctuation character.""" 339 | cp = ord(char) 340 | # We treat all non-letter/number ASCII as punctuation. 341 | # Characters such as "^", "$", and "`" are not in the Unicode 342 | # Punctuation class but we treat them as punctuation anyways, for 343 | # consistency. 344 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 345 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 346 | return True 347 | cat = unicodedata.category(char) 348 | if cat.startswith("P"): 349 | return True 350 | return False 351 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: env 2 | dependencies: 3 | - python>=3.6 4 | - anaconda 5 | - pip 6 | - pip: 7 | - tensorflow-gpu==1.15.2 8 | - bert-tensorflow 9 | - torch==1.5.1 10 | - torchvision==0.6.1 11 | - transformers==2.11.0 12 | - wget 13 | -------------------------------------------------------------------------------- /models/Albert/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "AlbertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0, 6 | "bos_token_id": 2, 7 | "classifier_dropout_prob": 0.1, 8 | "down_scale_factor": 1, 9 | "embedding_size": 128, 10 | "eos_token_id": 3, 11 | "gap_size": 0, 12 | "hidden_act": "gelu_new", 13 | "hidden_dropout_prob": 0, 14 | "hidden_size": 768, 15 | "initializer_range": 0.02, 16 | "inner_group_num": 1, 17 | "intermediate_size": 3072, 18 | "layer_norm_eps": 1e-12, 19 | "max_position_embeddings": 512, 20 | "model_type": "albert", 21 | "net_structure_type": 0, 22 | "num_attention_heads": 12, 23 | "num_hidden_groups": 1, 24 | "num_hidden_layers": 12, 25 | "num_memory_blocks": 0, 26 | "pad_token_id": 0, 27 | "type_vocab_size": 2, 28 | "vocab_size": 30000 29 | } 30 | -------------------------------------------------------------------------------- /models/Albert/spiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fidelity/classitransformers/ec195296787c2dd4cea445a5916dd0ce66252655/models/Albert/spiece.model -------------------------------------------------------------------------------- /models/Bert_base/bert_config.json: -------------------------------------------------------------------------------- 1 | {"hidden_size": 768, "hidden_act": "gelu", "initializer_range": 0.02, "vocab_size": 30522, "hidden_dropout_prob": 0.1, "num_attention_heads": 12, "type_vocab_size": 2, "max_position_embeddings": 512, "num_hidden_layers": 12, "intermediate_size": 3072, "attention_probs_dropout_prob": 0.1} -------------------------------------------------------------------------------- /models/DistilBert/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation": "gelu", 3 | "architectures": [ 4 | "DistilBertForMaskedLM" 5 | ], 6 | "attention_dropout": 0.1, 7 | "dim": 768, 8 | "dropout": 0.1, 9 | "hidden_dim": 3072, 10 | "initializer_range": 0.02, 11 | "max_position_embeddings": 512, 12 | "model_type": "distilbert", 13 | "n_heads": 12, 14 | "n_layers": 6, 15 | "pad_token_id": 0, 16 | "qa_dropout": 0.1, 17 | "seq_classif_dropout": 0.2, 18 | "sinusoidal_pos_embds": false, 19 | "tie_weights_": true, 20 | "vocab_size": 30522 21 | } 22 | -------------------------------------------------------------------------------- /models/Roberta/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==1.15.2 2 | bert-tensorflow 3 | torch==1.5.1 4 | torchvision==0.6.1 5 | transformers==2.11.0 6 | pandas 7 | matplotlib 8 | scikit-learn 9 | wget -------------------------------------------------------------------------------- /sample_notebooks/PYTORCH_classi_transformers_BBCnews_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Pytorch Framework for Multi-class Text Classification" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## BBC News Classification - 5 Class Problem\n", 15 | "\n", 16 | "* Consists of 2225 documents from the BBC news website corresponding to stories in five topical areas from 2004-2005.\n", 17 | "* Class Labels: 5
\n", 18 | "\n", 19 | " 0 - Business
\n", 20 | " 1 - Sport
\n", 21 | " 2 - Politics
\n", 22 | " 3 - Tech
\n", 23 | " 4 - Entertainment " 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Reading Data" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 1, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "import sys\n", 40 | "import numpy as np\n", 41 | "import pandas as pd\n", 42 | "\n", 43 | "sys.path.append(\"..\")\n", 44 | "pd.set_option('display.max_colwidth', -1)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "((1424, 3), (445, 2), (356, 3))" 56 | ] 57 | }, 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "output_type": "execute_result" 61 | } 62 | ], 63 | "source": [ 64 | "df_train = pd.read_csv('../datasets/bbcfull/train.csv')\n", 65 | "df_test = pd.read_csv('../datasets/bbcfull/test.csv')\n", 66 | "df_score = pd.read_csv('../datasets/bbcfull/test_with_label.csv')\n", 67 | "df_val = pd.read_csv('../datasets/bbcfull/dev.csv')\n", 68 | "df_train.shape, df_test.shape, df_val.shape" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "1 327\n", 80 | "0 326\n", 81 | "2 267\n", 82 | "3 257\n", 83 | "4 247\n", 84 | "Name: label, dtype: int64" 85 | ] 86 | }, 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "df_train.label.value_counts()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "text/plain": [ 104 | "1 82\n", 105 | "0 82\n", 106 | "2 66\n", 107 | "3 64\n", 108 | "4 62\n", 109 | "Name: label, dtype: int64" 110 | ] 111 | }, 112 | "execution_count": 4, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "df_val.label.value_counts()" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 5, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "text/plain": [ 129 | "1 102\n", 130 | "0 102\n", 131 | "2 84 \n", 132 | "3 80 \n", 133 | "4 77 \n", 134 | "Name: label, dtype: int64" 135 | ] 136 | }, 137 | "execution_count": 5, 138 | "metadata": {}, 139 | "output_type": "execute_result" 140 | } 141 | ], 142 | "source": [ 143 | "df_score.label.value_counts()" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "## Training Transformers (Albert | Roberta | DistilBert) classifier with FrameWork" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 6, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "from classitransformers.report import metrics\n", 160 | "from classitransformers.configs import Configs\n", 161 | "from classitransformers.pytransformers import TransformersClassification" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "## Set Configurations\n", 169 | "\n", 170 | "* Pre-trained folderpath\n", 171 | "* labels list\n", 172 | "* train, test and eval batch size\n", 173 | "* max_seq_len - how long is the each text input sentence !!\n", 174 | "* data_dir - data path folder (optional)\n", 175 | "* output_dir - output path folder (optional)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 7, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "config = Configs(\n", 185 | " pretrained_model_dir = '../models/Roberta',\n", 186 | " model_name = 'roberta',\n", 187 | " num_train_epochs=3,\n", 188 | " train_batch_size = 16,\n", 189 | " eval_batch_size = 8, \n", 190 | " do_train = True, \n", 191 | " do_eval = True, \n", 192 | " max_seq_length = 256,\n", 193 | " label_list = [\"0\", \"1\", \"2\", \"3\", \"4\"],\n", 194 | " data_dir = '../datasets/bbcfull',\n", 195 | " output_dir = '../transformers_output_bbcnews/')" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 8, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "model = TransformersClassification(config)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "## Train Method" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 9, 217 | "metadata": { 218 | "scrolled": true 219 | }, 220 | "outputs": [ 221 | { 222 | "name": "stdout", 223 | "output_type": "stream", 224 | "text": [ 225 | "\n", 226 | "======== Epoch 1 / 3 ========\n", 227 | "Training...\n", 228 | " Batch 40 of 89. Elapsed: 0:00:52.\n", 229 | " Batch 80 of 89. Elapsed: 0:01:44.\n", 230 | "\n", 231 | " Average training loss: 0.59\n", 232 | " Training epcoh took: 0:01:56\n", 233 | "\n", 234 | "Running Validation...\n", 235 | " Accuracy: 0.99\n", 236 | " Validation Loss: 0.00\n", 237 | " Validation took: 0:00:10\n", 238 | "\n", 239 | "======== Epoch 2 / 3 ========\n", 240 | "Training...\n", 241 | " Batch 40 of 89. Elapsed: 0:00:52.\n", 242 | " Batch 80 of 89. Elapsed: 0:01:44.\n", 243 | "\n", 244 | " Average training loss: 0.07\n", 245 | " Training epcoh took: 0:01:56\n", 246 | "\n", 247 | "Running Validation...\n", 248 | " Accuracy: 0.98\n", 249 | " Validation Loss: 0.00\n", 250 | " Validation took: 0:00:10\n", 251 | "\n", 252 | "======== Epoch 3 / 3 ========\n", 253 | "Training...\n", 254 | " Batch 40 of 89. Elapsed: 0:00:52.\n", 255 | " Batch 80 of 89. Elapsed: 0:01:44.\n", 256 | "\n", 257 | " Average training loss: 0.03\n", 258 | " Training epcoh took: 0:01:56\n", 259 | "\n", 260 | "Running Validation...\n", 261 | " Accuracy: 0.98\n", 262 | " Validation Loss: 0.00\n", 263 | " Validation took: 0:00:10\n", 264 | "Saving model to ../transformers_output_bbcnews/\n", 265 | "\n", 266 | "Training complete!\n", 267 | "Total training took 0:06:25 (h:mm:ss)\n" 268 | ] 269 | } 270 | ], 271 | "source": [ 272 | "model.train()" 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": {}, 278 | "source": [ 279 | "## Test Method" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 10, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "prediction = model.test()" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "**Testing a sentence, takes a list of sentences.**" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 11, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "data": { 305 | "text/plain": [ 306 | "[array([0.99918574, 0.19028524, 0.2688493 , 0.23254088, 0.17762943],\n", 307 | " dtype=float32)]" 308 | ] 309 | }, 310 | "execution_count": 11, 311 | "metadata": {}, 312 | "output_type": "execute_result" 313 | } 314 | ], 315 | "source": [ 316 | "model.text_inference(['He is very good in writing a book on financial investment money making strategy. There are some transactional charges but the profit we make is awesome. The business people are exchanging bulk money over trades happening every seconds. Even retailers are happy doing trade and adding sum to their investments.'])" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 12, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "y_pred = [np.argmax(tup) for tup in prediction]\n", 326 | "y = df_score.label" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "## Accuracy and Report" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 13, 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "name": "stdout", 343 | "output_type": "stream", 344 | "text": [ 345 | "\n", 346 | "\n", 347 | "Accuracy: 0.984\n", 348 | "Precision: 0.984\n", 349 | "Recall: 0.985\n", 350 | "F1-Score: 0.985\n" 351 | ] 352 | }, 353 | { 354 | "data": { 355 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAEmCAYAAADr3bIaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XucVXW9//HXewYcroMgjOAgIIngpcDJzAuaefspKprGyVIPdixKLT2ZJ7t4jtnl1DmeB1mni2FWlual1PB4zSwyTSkhM1BEvKGAchMZhOEy8/n9sdboSDCzZjt79l6b9/Px2A/2Wnvt9X3PZs9nvuv2XYoIzMysfVWlDmBmlgculmZmGbhYmpll4GJpZpaBi6WZWQYulmZmGbhY7kAk9Zb0f5Jek/TLt7GeMyT9piuzlYqkwyQ9VeocVv7k8yzLj6SPABcB44BG4DHg6xHx4Ntc71nAp4FDImLL2w5a5iQFMCYiFpU6i+Wfe5ZlRtJFwJXAfwK7AiOA7wMnd8HqRwILd4RCmYWkHqXOYDkSEX6UyQMYAKwDprSzTA1JMV2aPq4EatLXjgBeAj4LLAeWAR9NX7sc2ARsTts4B/gycF2bdY8CAuiRTp8NPEvSu30OOKPN/AfbvO8Q4C/Aa+m/h7R5bRbwVeChdD2/AQZv52drzf+5NvlPASYBC4HVwBfbLH8g8DCwJl32u8BO6WsPpD/L6+nP+6E2678EeBn4eeu89D3vSNtoSKd3A1YCR5T6u+FH6R/uWZaXg4FewG3tLPMl4CBgAjCepGBc2ub1oSRFt56kIH5P0sCIuIykt3pTRPSLiGvaCyKpL/Ad4PiI6E9SEB/bxnKDgDvTZXcBpgN3StqlzWIfAT4K1AE7ARe30/RQks+gHvgP4GrgTODdwGHAf0ganS7bDHwGGEzy2R0FnAcQEYeny4xPf96b2qx/EEkve1rbhiPiGZJCer2kPsBPgJ9GxKx28toOwsWyvOwCrIz2N5PPAL4SEcsjYgVJj/GsNq9vTl/fHBF3kfSqxhaYpwXYT1LviFgWEfO3scwJwNMR8fOI2BIRNwALgJPaLPOTiFgYERuAm0kK/fZsJtk/uxm4kaQQfjsiGtP25wPvAoiIORHxSNru88APgfdl+Jkui4iNaZ63iIirgaeB2cAwkj9OZi6WZWYVMLiDfWm7AS+0mX4hnffGOrYqtuuBfp0NEhGvk2y6fhJYJulOSeMy5GnNVN9m+uVO5FkVEc3p89Zi9kqb1ze0vl/SXpLukPSypLUkPefB7awbYEVENHWwzNXAfsD/RsTGDpa1HYSLZXl5GGgi2U+3PUtJNiFbjUjnFeJ1oE+b6aFtX4yIeyPiGJIe1gKSItJRntZMSwrM1Bk/IMk1JiJqgS8C6uA97Z7+IakfyX7ga4Avp7sZzFwsy0lEvEayn+57kk6R1EdST0nHS/rvdLEbgEslDZE0OF3+ugKbfAw4XNIISQOAL7S+IGlXSZPTfZcbSTbnm7exjruAvSR9RFIPSR8C9gHuKDBTZ/QH1gLr0l7vuVu9/gow+h/e1b5vA3Mi4mMk+2KvetsprSK4WJaZiJhOco7lpcAK4EXgU8Cv00W+BjwKPA78HZibziukrfuAm9J1zeGtBa6K5Kj6UpIjxO8jPXiy1TpWASemy64iOZJ9YkSsLCRTJ11McvCokaTXe9NWr38ZuFbSGkn/1NHKJJ0MHEey6wGS/4cGSWd0WWLLLZ+UbmaWgXuWZmYZuFiamWXgYmlmloGLpZlVBEk/lrRc0rw28wZJuk/S0+m/A9P5kvQdSYskPS6pocP1l9MBnl69ekX//v1LHaPTeg7uWeoIZkW1Ztka1q9Z39E5rJ2y++67R1NTR9cHvGnlypX3RsRx23td0uEkp7j9LCL2S+f9N7A6Ir4p6fPAwIi4RNIkkhG4JgHvJblK7L3ttV9Wo67079+fU089tdQxOm3YOcNKHcGsqGZMndHl62xqaurU7/uMGTPavTorIh6QNGqr2SeTDJYCcC3JwC6XpPN/Fklv8RFJO0saFhHLtrd+b4abWSXbtbUApv/WpfPrSc5hbvUSb71E9x+UVc/SzKwdgyU92mZ6RkQU2uXd1i6FdvdJuliaWV6sjIgDOvmeV1o3ryUNIxknFZKe5O5tlhtOB2MseDPczCrZ7cDU9PlUYGab+f+cHhU/CHitvf2V4J6lmVUISTeQHMwZLOkl4DLgm8DNks4BFgNT0sXvIjkSvohk2MCPdrR+F0szqwgR8eHtvHTUNpYN4PzOrN+b4WZmGbhYmpll4GJpZpaBi6WZWQYulmZmGbhYmpll4GJpZpaBi6WZWQYVeVL63//+dxYsWADAuHHjeOc738nKlSt58MEHaW5uRhITJ06krq6ugzWV1qKHF3HP9HtoaWmhYXIDE6dOLHWkDr32ymv8+su/Zt3qdUii4ZQGDjr9oFLH6lBec7fK43clbyquWK5evZoFCxbwgQ98gKqqKu6++25GjBjB7NmzaWhoYMSIESxevJjZs2dz0kknlTrudrU0t3DXFXdx1v+eRW1dLVeffTVjDxvLkNFDSh2tXVXVVRx74bEMGzeMja9vZMbUGbzjwHc4dxHl9buSNxW3Gb5mzRrq6uro0aMHVVVVDBs2jOeeew5JbN68GYBNmzbRp0+fEidt35InljBo+CAG1g+kumc1+x6zLwseWFDqWB3qP7g/w8YlgyHX9K1hyKghrF2xtsSpOpbX3JDf70reVFyxHDhwIC+//DJNTU1s2bKFxYsX8/rrr3PwwQfzyCOPcP311/PII49w4IEHljpquxqXN1K7a+0b07V1tTSuaCxhos5bs3QNyxYuY/i+w0sdpVPylrsSvit5UNTNcEnHAd8GqoEfRcQ3i9keJMVy/Pjx3HnnnfTs2ZNddtkFSTzxxBMcfPDBjB49mmeeeYYHHniAE044odhxChbbGoe0S++AUlyb1m/i5s/fzHGfOY6afjWljpNZHnPn/buSF0XrWUqqBr4HHA/sA3xY0j7Faq+tcePGcdpppzF58mRqamoYMGAACxcuZI899gBg9OjRLF++vIO1lFZtXS1rX3lzM3Dt8rX0H5yPm7k1b2nm5s/fzDuPeyd7v3/vUsfJLK+58/xdyZNiboYfCCyKiGcjYhNwI8lNgopuw4YNAKxbt47nnnuOPffck759+7JsWTK259KlSxkwYEB3RClY/d71rHpxFa8ufZXmzc3Mv28+Yw8fW+pYHYoIbv/a7QweNZiDP3JwqeNkltfckN/vSt4UczN8WzcEavdWk13lvvvuo6mpiaqqKiZOnEhNTQ2HH344f/rTn2hpaaG6uprDDjusO6IUrKpHFZMunsR1F1xHtAQTTppA3ejyPtUJ4MW/vcjjdz9O3Z51XHXmVQAcde5RjDl0TImTtS+vuSG/35W8KWaxzHRDIEnTgGkA/fr165KGJ0+e/A/zhg4dmrvb7I45dEwuflnbGjFhBJfNvqzUMTotr7lb5fG7kjfF3AzPdEOgiJgREQdExAG9evUqYhwzs8IVs1j+BRgjaQ9JOwGnk9wkyMwsd4q2GR4RWyR9CriX5NShH0fE/GK1Z2ZWTEU9zzIi7iK5i5qZWa5V3BU8ZmbFUHEDaZhZPgxqWcOHX5+ZefkZRcyShXuWZmYZuFiamWXgYmlmloGLpZlZBi6WZmYZuFiamWXgYmlmloGLpZlZBi6WZmYZuFiamWXgYmlmloGLpZlZBi6WZmYZuFiamWXgYmlmFUHSZyTNlzRP0g2SeqW3tZkt6WlJN6W3uClIWY1n2XNwT4adM6zUMTpt2TXLSh2hYHn8vM22JqkeuADYJyI2SLqZ5L5fk4BvRcSNkq4CzgF+UEgb7lmaWaXoAfSW1APoAywDjgR+lb5+LXBKoSt3sTSz3IuIJcD/AItJiuRrwBxgTURsSRd7CagvtA0XSzPLi8GSHm3zmNb6gqSBwMnAHsBuQF/g+G2sIwptvKz2WZqZtWNlRBywndeOBp6LiBUAkm4FDgF2ltQj7V0OB5YW2rh7lmZWCRYDB0nqI0nAUcATwO+BD6bLTAWy3yFtKy6WZpZ7ETGb5EDOXODvJLVtBnAJcJGkRcAuwDWFtuHNcDOrCBFxGXDZVrOfBQ7sivW7Z2lmloGLpZlZBi6WZmYZuFiamWXgYmlmloGPhptZSfRZPYT9b5jW8YJvuLxoWbJwz9LMLAMXSzOzDFwszcwyqPh9loseXsQ90++hpaWFhskNTJw6sdSR3mLWrFksXryY3r17M2XKFACampq4//77aWxspH///hx99NHU1NSwZs0aZs2axcqVK3nPe97D+PHjS5z+H8386kwWPrSQvgP7ct4N55U6TqeU+3elPXnOnhcV3bNsaW7hrivu4owrz+D8G89n3m/mseLZFaWO9RZjx45l0qRJb5n32GOPUV9fz+mnn059fT2PPfYYADU1NRxyyCG8613vKkXUTCacOIEzrzyz1DE6LQ/fle3Jc/Y8qehiueSJJQwaPoiB9QOp7lnNvsfsy4IHFpQ61lsMGzaMmpqat8x74YUX2GuvvQDYa6+9eP755wHo3bs3dXV1VFWV73/byP1H0ru2d6ljdFoevivbk+fseVK+v3VdoHF5I7W71r4xXVtXS+OKxhImymbDhg306dMHgD59+rBhw4YSJ6p8ef2uQL6z50nRiqWkH0taLmlesdroSGxrUGR1fw4rf3n+ruQ5e54Us2f5U+C4Iq6/Q7V1tax9Ze0b02uXr6X/4P4lTJRN7969Wb9+PQDr16+nd+/8bdbmTV6/K5Dv7HlStGIZEQ8Aq4u1/izq965n1YureHXpqzRvbmb+ffMZe/jYUkbKZOTIkSxcuBCAhQsXMnLkyBInqnx5/a5AvrPnSclPHUpvOjQNYMDQAV267qoeVUy6eBLXXXAd0RJMOGkCdaPrurSNt+v+++9n6dKlNDU1cf311/Pud7+bCRMm8Nvf/pYFCxbQr18/jj76aCDpZd52221s2rQJScybN48pU6aw004F3ze+y91y6S08P/d51q9Zz/QTp3PEtCNomNxQ6lgdysN3ZXvynD1PFFHwzc46Xrk0CrgjIvbLsvxue+8W067tzLWi5WHZNctKHaFgw84ZVuoIlgMzps5g6ZNLu3RP6P7aLWZVZf9937nl8jnt3LCs6Cr6aLiZWVdxsTQzy6CYpw7dADwMjJX0kqRzitWWmVmxFe0AT0R8uFjrNjPrbt4MNzPLwMXSzCwDF0szswxcLM3MMnCxNDPLwMXSzCwDF0szswxcLM3MMnCxNDPLwMXSzCwDF0szswxcLM3MMij5SOlmtmN6ekgfjvvQ/tnf8N3iZcnCPUszswxcLM3MMnCxNLOKIGlnSb+StEDSk5IOljRI0n2Snk7/HVjo+l0szaxSfBu4JyLGAeOBJ4HPA/dHxBjg/nS6INs9wCOptr03RsTa9l7fkeT5DolHfOcHpY5QkFkXnFvqCFZG0np1OHA2QERsAjZJOhk4Il3sWmAWcEkhbbR3NHw+EEDb21+2TgcwopAGzcyKYDSwAviJpPHAHOBCYNeIWAYQEcskFXxD9e0Wy4jYvdCVmpkVwWBJj7aZnhERM9LnPYAG4NMRMVvSt3kbm9zbkuk8S0mnA6Mj4j8lDSep1nO6MoiZWQdWRsQB23ntJeCliJidTv+KpFi+ImlY2qscBiwvtPEOD/BI+i7wfuCsdNZ64KpCGzQz62oR8TLwoqSx6ayjgCeA24Gp6bypwMxC28jSszwkIhok/TUNtVrSToU2aGZWJJ8Grk/r07PAR0k6hDdLOgdYDEwpdOVZiuVmSVUkB3WQtAvQUmiDZmbFEBGPAdvaTD+qK9af5TzL7wG3AEMkXQ48CPxXVzRuZpYXHfYsI+JnkuYAR6ezpkTEvOLGMjMrL1lHHaoGNpNsivuqHzPb4WQ5Gv4l4AZgN2A48AtJXyh2MDOzcpKlZ3km8O6IWA8g6eskZ8d/o5jBzMzKSZZN6hd4a1HtQXJY3sxsh9HeQBrfItlHuR6YL+nedPpYkiPiZmY7jPY2w1uPeM8H7mwz/5HixTEzK0/tDaRxTXcGMTMrZx0e4JH0DuDrwD5Ar9b5EbFXEXN1mUUPL+Ke6ffQ0tJCw+QGJk6dWOpImeUp+y+fXM+dz2xAwOide3DJwbXsVJ2M7vedvzRy97NN3P2hIaUN2YE8fd5by3P2vMhygOenwE9IxrE8HrgZuLGImbpMS3MLd11xF2dceQbn33g+834zjxXPrih1rEzylH3F+mZufWo9PzxuED85cReaA373fBMAT63azLpN5X91bJ4+763lOXueZCmWfSLiXoCIeCYiLiUZhajsLXliCYOGD2Jg/UCqe1az7zH7suCBBaWOlUnesjcHbGwOmluCjVuCXfpU0dwSXPXXdXyioV+p43Uob593W3nOnidZiuVGSQKekfRJSScBBY823J0alzdSu+ubd8eoraulcUVjCRNll6fsQ/pU80979+FDv17FabeupO9O4j3Darht4QYOqa9hl97VpY7YoTx93lvLc/Y8yVIsPwP0Ay4ADgU+DvxLR2+StLuk36d3WZsv6cK3F7XzIhkoaatg3Z2iMHnK3rixhT+9tJEbTt6FX506mKYtwb3PbuAPizdy6tjepY6XSZ4+763lOXueZBlIo3Xk4UbeHAA4iy3AZyNirqT+wBxJ90XEEwXkLEhtXS1rX3nzvmprl6+l/+D+3dX825Kn7HNe3sTQftXs3Cv523vY7jX89PHX2dgMZ9y+CoCNW4IzZq7i+pN3KWXU7crT5721PGfPk+32LCXdJunW7T06WnFELIuIuenzRpLbUtZ3XfSO1e9dz6oXV/Hq0ldp3tzM/PvmM/bwsR2/sQzkKXtd32qeWLmZpi1BRDD35c1M2bsPt542mBtPSR41PVS2hRLy9XlvLc/Z86S9nuV3u6oRSaOA/YHZ23htGjANYMDQAV3VJABVPaqYdPEkrrvgOqIlmHDSBOpG52J3a66y7zO4J+8bUcO0u1dTLRgzsAcn7pmPze9Wefq8t5bn7HmiiG3s7+jKBqR+wB+Ar0dEuz3S3fbeLaZdO62oeeytfN9wy2LG1BksfXJpl+4JPeCAkfHoo1/KvLz0iTnt3LCs6Io6NqWkniSjrF/fUaE0MytnRSuW6elG1wBPRsT0YrVjZtYdMhdLSTWdXPehJEfPj5T0WPqY1Ml1mJmVhSzXhh9I0kMcAIyQNB74WER8ur33RcSD+GwvM6sQWXqW3wFOBFYBRMTfyMnljmZmXSVLsayKiBe2mtdcjDBmZuUqyz14Xkw3xUNSNfBpYGFxY5mZlZcsPctzgYuAEcArwEHpPDOzHUaWa8OXA6d3QxYzs7KV5Wj41fCPw5pEhC+1MbMdRpZ9lr9t87wX8AHgxeLEMTMrT1k2w29qOy3p58B9RUtkZlaGCrnccQ9gZFcHMTMrZ1n2Wb7Km/ssq4DVwOeLGcrMrNy0WyzTwTDGA0vSWS1R7DHdzMzKULub4WlhvC0imtOHC6WZ7ZCy7LP8s6SGoicxMytj290Ml9QjIrYAE4GPS3oGeJ1kJKGICBdQM9thtLfP8s9AA3BKN2UxMytb7RVLAUTEM92UxczsbUkH+3kUWBIRJ0raA7gRGATMBc6KiE2FrLu9YjlE0kXbe9G3ijCzMnQhyW23a9Pp/wK+FRE3SroKOAco6C597RXLaqAfHu28ouX1Lom+K6VtTdJw4ATg68BF6amPRwIfSRe5FvgyRSiWyyLiK4Ws1MysCAZLerTN9IyImNFm+krgc0D/dHoXYE16oBrgJaC+0MY73GdpZlYmVm7vvuGSTgSWR8QcSUe0zt7GogWfK95esTyq0JWamXWzQ4HJ6R1ke5Hss7wS2LnNaZDDgaWFNrDdk9IjYnWhKzUz604R8YWIGB4Ro0gGK/9dRJwB/B74YLrYVGBmoW0UMuqQmVleXEJysGcRyT7MawpdUZbBf83MciMiZgGz0ufPAgd2xXpdLM2sNBqbafnDqlKnyMyb4WZmGbhYmpll4GJpZpaBi6WZWQYulmZmGbhYmpll4GJpZpaBi6WZWQYVf1L6zK/OZOFDC+k7sC/n3XBeqeN0yqKHF3HP9HtoaWmhYXIDE6dOLHWkTPKSe/HaLXzlwbVvTC9rbOaj4/syYdeefOvPjWzYHAztV82XDq2lb8/y7lfk5TPPs/L+BnSBCSdO4Mwrzyx1jE5raW7hrivu4owrz+D8G89n3m/mseLZFaWO1aE85R5R24MfTRrEjyYN4ofHDaSmh5g4vIb/eaSRj0/ox49P3IWJu9dw0xPrSx21XXn6zPOs4ovlyP1H0ru2d6ljdNqSJ5YwaPggBtYPpLpnNfsesy8LHlhQ6lgdymvuua9sYrd+1QztV82La5sZX9cTgAOG7sQDizeWOF378vqZ503FF8u8alzeSO2utW9M19bV0riisYSJsslr7t89v5GjRtUAsMfO1Tz0UnJPq1mLN7J8fUspo3Uor5953hStWErqJenPkv4mab6ky4vVViWKbQ3onIOx6/OYe3Nz8KclG3nfiF4AfO6gWmYuXM+0u1ezYXNQ5rsrc/mZ51ExD/BsBI6MiHWSegIPSro7Ih4pYpsVo7aulrWvvHnwYe3ytfQf3L+dd5SHPOaevXQTew3swaDeSVUcMaAHVxw1EIAX127hkaXlvRmex888j4r2NzMS69LJnumj4Ptf7Gjq965n1YureHXpqzRvbmb+ffMZe/jYUsfqUB5z/+6FJo4c1euN6Vebks3ulgh+Pm89J40p733eefzM86iopw6lNzyfA+wJfC8iZhezvW255dJbeH7u86xfs57pJ07niGlH0DC5obtjdFpVjyomXTyJ6y64jmgJJpw0gbrRdaWO1aG85W7aEsxZtomLDnyzJ3b/803MXLgBgMN2r+H40b229/aykLfPPK+KWiwjohmYIGln4DZJ+0XEvLbLSJoGTAMYMHRAl2c47Wundfk6u8uYQ8cw5tAxpY7RaXnK3auHmDllyFvmfXBcHz44rk+JEhUmT595XnXLruuIWEMyzPtx23htRkQcEBEH9Nk5X19QM9txFPNo+JC0R4mk3sDRgE/+MrNcKuZm+DDg2nS/ZRVwc0TcUcT2zMyKpmjFMiIeB/Yv1vrNzLpTmZ9ua2ZWHlwszcwycLE0M8vAxdLMLAMXSzOzDFwszcwycLE0M8ug4u/BY2blqbF5AH9YPbkT7/hi0bJk4Z6lmVkGLpZmZhm4WJqZZeBiaWaWgYulmVkGLpZmZhm4WJqZZeBiaWa5J2l3Sb+X9KSk+ZIuTOcPknSfpKfTfwcW2oaLpZlVgi3AZyNib+Ag4HxJ+wCfB+6PiDHA/el0QVwszSz3ImJZRMxNnzcCTwL1wMnAteli1wKnFNqGL3e0XJp1wbmljlCQI77zg1JHKMgNy1eXOgLAYEmPtpmeEREztl5I0iiSW9rMBnaNiGWQFFRJBd9Q3cXSzPJiZUQc0N4CkvoBtwD/GhFrJXVZ494MN7OKIKknSaG8PiJuTWe/ImlY+vowYHmh63exNLPcU9KFvAZ4MiKmt3npdmBq+nwqMLPQNrwZbmaV4FDgLODvkh5L530R+CZws6RzgMXAlEIbcLE0s9yLiAeB7e2gPKor2vBmuJlZBi6WZmYZuFiamWXgYmlmloGLpZlZBi6WZmYZuFiamWXgYmlmloGLpZlZBi6WZmYZVPzljjO/OpOFDy2k78C+nHfDeaWO0ymLHl7EPdPvoaWlhYbJDUycOrHUkTJx7uJbvHYLX3lw7RvTyxqb+ej4vnxwXB9ufWo9v35qA1VVcNBuNXyyoV8Jk1aOii+WE06cwIFTDuS2y28rdZROaWlu4a4r7uKs/z2L2rparj77asYeNpYho4eUOlq7nLt7jKjtwY8mDQKguSWYctsqJg6v4a8vb+KhlzbyoxMGsVO1eLWppcRJK0fFF8uR+49kzdI1pY7RaUueWMKg4YMYWJ/cX2nfY/ZlwQMLyvaXt5Vzd7+5r2xit37VDO1XzVV/XcdH9unLTtXJmBIDe5XvnrbGnq/yh/pfljpGZuX7Se7gGpc3Urtr7RvTtXW1NK5oLGGibJy7+/3u+Y0cNaoGgJcam3l8xSbOvWc1F973KgtWbS5xuspR9GIpqVrSXyXdUey2KkkQ/ziz60bILxrn7l6bm4M/LdnI+0b0ApJN8sZNwff/30A+uX8/Lv/ja0Rs42ezTuuOnuWFJHdas06oratl7Stv7sBfu3wt/Qf3L2GibJy7e81euom9BvZgUO/kV3lIn2oO370GSew9uCdVgtc2ulh2haIWS0nDgROAHxWznUpUv3c9q15cxatLX6V5czPz75vP2MPHljpWh5y7e/3uhSaOHNXrjemJw2uY+/ImAF5cu4XNLTCgJgdd5Bwo9gGeK4HPAdv9Ey1pGjANYMDQAV0e4JZLb+H5uc+zfs16pp84nSOmHUHD5IYub6erVfWoYtLFk7juguuIlmDCSROoG13wXTy7jXN3n6YtwZxlm7jowDd/vY5/Ry/++5G1fPSOVfSsEp8/uJauvMPhjqxoxVLSicDyiJgj6YjtLZfe93cGwG5779bl2wunfe20rl5ltxlz6BjGHDqm1DE6zbm7R68eYuaUtx6t71ktvnRo13c6rLib4YcCkyU9D9wIHCnpuiK2Z2ZWNEUrlhHxhYgYHhGjgNOB30XEmcVqz8ysmHyepZlZBt1yBU9EzAJmdUdbZmbF4J6lmVkGLpZmZhm4WJqZZeBiaWaWgYulmVkGLpZmZhm4WJqZZeBiaWaWgYulmVkGLpZmZhm4WJqZZeBiaWaWgYulmVkGLpZmZhm4WJpZRZB0nKSnJC2S9PmuXr+LpZnlnqRq4HvA8cA+wIcl7dOVbbhYmlklOBBYFBHPRsQmkvt+ndyVDSiifG7ALmkF8EKRVj8YWFmkdReTc3e/vGYvZu6RETGk48Wyk3QPSeasegFNbaZnpHeHRdIHgeMi4mPp9FnAeyPiU12Vt1tuK5FVV/9ntCXp0Yg4oFjrLxbn7n55zZ633BFxXBeubls3R+/SnqA3w82sErwE7N5mejiwtCsbcLE0s0rwF2BW92guAAAGkElEQVSMpD0k7URy++3bu7KBstoML7IZpQ5QIOfufnnNntfcb1tEbJH0KeBeoBr4cUTM78o2yuoAj5lZufJmuJlZBi6WZmYZuFiaVQhJ2zp9xrpIxRZLSWMlHSypZ3opVK7kNPOekg6QVFPqLJ0haV9J75O0S6mzdJakiekJ2EREuGAWT0UeDZd0KvCfwJL08aikn0bE2tIm65ikvSJiYUQ0S6qOiOZSZ8pC0okkn/kq4GVJl0XEwhLH6pCk44H/Ap4Feko6JyJeLnGsDkmqAvoAP0wm1TcirkoLZlVEtJQ4YsWpuJ6lpJ7Ah4BzIuIoYCbJyaqfk1Rb0nAdSAvOY5J+AdBaMEscq0OSDgH+B5gaEe8HXgW6fNSXribpCODbwMci4hRgE7BfSUNlFBEtEbEOuBa4BjhE0mdaXytpuApVccUyVQuMSZ/fBtwB7AR8pFw3UyT1BT4F/CuwSdJ1kJ+CCXwzIv6aPr8MGJSDzfFXgE9ExJ8lDQXeC3xK0g8lfbBcvytb2ULSGbgWOFDSdEnfUKJSf79LouI+zIjYDEwHTpV0WPpX9kHgMWBiScO1IyJeB/4F+AVwMdCrbcEsZbYMZgO3whv7WmuAkSR/tCjXfYER8WRE/D6dPAf4ftrDfASYQucGeSiVmcDLEXE/8CjwSaA2Eu5hdqGKK5apPwK/Ac6SdHhENEfEL4DdgPGljbZ9EbE0ItZFxErgE0Dv1oIpqUHSuNIm3Lb0823dHyxgDbA6IlZIOgP4mqTepUvYsYj4ekR8LX3+E6A/b73WuFxtAMZK+jhJofwmMELSJ0obq/JU5AGeiGiSdD3JqCNfSIvMRmBXYFlJw2UUEavSL/wVkhaQXML1/hLH6lBEbAHWSXpR0jeAY4GzI2JDiaNtlyRFm0vZJJ1G8l3p0oEYiiEilkp6Efh34PyI+D9J7wcWlThaxanoyx3TC+oPJemlNQHfbrNfLRfSnfaXAMdExN9Lnacj6X6+nsCT6b9HRcTTpU2VTbqP9UzgIuBDETGvxJEykbQ7UBcRc9JpHw0vgooulq3S/Wi524cjaSBwM/DZiHi81Hk6Q9LZwF+6ejCDYkrPpDgGeCYinip1ns7auodsXWuHKJZ5JqlXRDR1vGR58S+uVRoXSzOzDCr1aLiZWZdysTQzy8DF0swsAxdLM7MMXCwrhKRmSY9Jmifpl5L6vI11HSHpjvT5ZEnbHRRD0s6SziugjS9Lujjr/K2W+Wl6n+isbY2SlItzJq18uVhWjg0RMSEi9iMZPeeTbV8sdGCFiLg9Ir7ZziI7A50ulmZ542JZmf4I7Jn2qJ6U9H1gLrC7pGMlPSxpbtoD7Qcg6ThJCyQ9CJzauiJJZ0v6bvp8V0m3Sfpb+jiE5Frkd6S92ivS5f5N0l8kPS7p8jbr+pKkpyT9Fhjb0Q8h6ePpev4m6ZatestHS/qjpIXp0HZIqpZ0RZu2fX20dRkXywojqQdwPNB6aeRY4GcRsT/wOnApcHRENJCMUnORpF7A1cBJwGHA0O2s/jvAHyJiPNAAzCcZt/KZtFf7b5KOJRke70BgAvBuSYdLejfJvZz3JynG78nw49waEe9J23uSZGSgVqOA9wEnAFelP8M5wGsR8Z50/R+XtEeGdsw6VJEDaeygekt6LH3+R5IBYXcDXoiIR9L5BwH7AA+lQzXuBDwMjAOea72GOx3paNo22jgS+Gd4Y9i419JLMts6Nn20XoPfj6R49gdui4j1aRu3Z/iZ9pP0NZJN/X4k94RudXN6+erTkp5Nf4ZjgXe12Z85IG277Edst/LnYlk5NkTEhLYz0oL4ettZwH0R8eGtlptAMkJTVxDwjYj44VZt/GsBbfwUOCUi/pZea35Em9e2XlekbX86ItoWVSSN6mS7Zv/Am+E7lkeAQyXtCSCpj6S9gAXAHpLekS734e28/37g3PS91Upu09FI0mtsdS/wL232hdZLqgMeAD4gqbek/iSb/B3pDyxLB7g4Y6vXpkiqSjOPBp5K2z43XR5JeykZgd7sbXPPcgeSDsZ7NnCD3rzlw6URsVDSNOBOSStJRpbf1r1oLgRmSDoHaAbOjYiHJT2Unppzd7rfcm/g4bRnuw44MyLmSrqJZMT6F0h2FXTk30lGYX+BZB9s26L8FPAHknEnP5mOYfojkn2Zc9Oh4lYAp2T7dMza54E0zMwy8Ga4mVkGLpZmZhm4WJqZZeBiaWaWgYulmVkGLpZmZhm4WJqZZfD/AWLrsi4+BNDGAAAAAElFTkSuQmCC\n", 356 | "text/plain": [ 357 | "
" 358 | ] 359 | }, 360 | "metadata": {}, 361 | "output_type": "display_data" 362 | } 363 | ], 364 | "source": [ 365 | "metrics(y, y_pred)" 366 | ] 367 | } 368 | ], 369 | "metadata": { 370 | "kernelspec": { 371 | "display_name": "conda_pytorch_p36", 372 | "language": "python", 373 | "name": "conda_pytorch_p36" 374 | }, 375 | "language_info": { 376 | "codemirror_mode": { 377 | "name": "ipython", 378 | "version": 3 379 | }, 380 | "file_extension": ".py", 381 | "mimetype": "text/x-python", 382 | "name": "python", 383 | "nbconvert_exporter": "python", 384 | "pygments_lexer": "ipython3", 385 | "version": "3.6.5" 386 | } 387 | }, 388 | "nbformat": 4, 389 | "nbformat_minor": 2 390 | } 391 | -------------------------------------------------------------------------------- /sample_notebooks/PYTORCH_classi_transformers_BBCsports_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# PyTorch Framework for Multi-class Text Classification" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## BBC Sports Classification - 5 Class Problem\n", 15 | "\n", 16 | "* Consists of 737 documents from the BBC Sport website corresponding to sports news articles in five topical areas from 2004-2005.\n", 17 | "* Class Labels: 5
\n", 18 | "\n", 19 | " 0 - Rugby
\n", 20 | " 1 - Tennis
\n", 21 | " 2 - Football
\n", 22 | " 3 - Cricket
\n", 23 | " 4 - Atheletics
" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "\n", 31 | "## Reading Data" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 1, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import sys\n", 41 | "import numpy as np\n", 42 | "import pandas as pd\n", 43 | "\n", 44 | "sys.path.append(\"..\")\n", 45 | "pd.set_option('display.max_colwidth', -1)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "df_train = pd.read_csv('../datasets/bbcsports/train.csv')\n", 55 | "df_test = pd.read_csv('../datasets/bbcsports/test.csv')\n", 56 | "df_score = pd.read_csv('../datasets/bbcsports/test_with_label.csv')\n", 57 | "df_val = pd.read_csv('../datasets/bbcsports/dev.csv')" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 3, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "((471, 3), (148, 2), (118, 3))" 69 | ] 70 | }, 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "output_type": "execute_result" 74 | } 75 | ], 76 | "source": [ 77 | "df_train.shape, df_test.shape, df_val.shape" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 4, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "data": { 87 | "text/plain": [ 88 | "2 169\n", 89 | "0 94 \n", 90 | "3 79 \n", 91 | "4 65 \n", 92 | "1 64 \n", 93 | "Name: label, dtype: int64" 94 | ] 95 | }, 96 | "execution_count": 4, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "df_train.label.value_counts()" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 5, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "data": { 112 | "text/plain": [ 113 | "2 43\n", 114 | "0 23\n", 115 | "3 20\n", 116 | "4 16\n", 117 | "1 16\n", 118 | "Name: label, dtype: int64" 119 | ] 120 | }, 121 | "execution_count": 5, 122 | "metadata": {}, 123 | "output_type": "execute_result" 124 | } 125 | ], 126 | "source": [ 127 | "df_val.label.value_counts()" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "## Training Transformers (Albert | Roberta | DistilBert) classifier with FrameWork" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 6, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "from classitransformers.report import metrics\n", 144 | "from classitransformers.configs import Configs\n", 145 | "from classitransformers.pytransformers import TransformersClassification" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "## Set Configurations\n", 153 | "\n", 154 | "* Pre-trained config filepath, vocab filepath and model checkpoint filepath\n", 155 | "* labels list\n", 156 | "* train, test and eval batch size\n", 157 | "* max_seq_len - how long is the each text input sentence !!\n", 158 | "* data_dir - data path folder (optional)\n", 159 | "* output_dir - output path folder (optional)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 7, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "config = Configs(\n", 169 | " pretrained_model_dir = '../models/Albert',\n", 170 | " model_name = 'albert',\n", 171 | " num_train_epochs=3,\n", 172 | " train_batch_size = 16,\n", 173 | " eval_batch_size = 8, \n", 174 | " do_train = True, \n", 175 | " do_eval = True, \n", 176 | " max_seq_length = 256,\n", 177 | " label_list = [\"0\", \"1\", \"2\", \"3\", \"4\"],\n", 178 | " data_dir = '../datasets/bbcsports',\n", 179 | " output_dir = '../transformers_output_bbcsports/')" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 8, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "model = TransformersClassification(config)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "## Train Method" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 9, 201 | "metadata": { 202 | "scrolled": true 203 | }, 204 | "outputs": [ 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "\n", 210 | "======== Epoch 1 / 3 ========\n", 211 | "Training...\n", 212 | "\n", 213 | " Average training loss: 1.39\n", 214 | " Training epcoh took: 0:00:39\n", 215 | "\n", 216 | "Running Validation...\n", 217 | " Accuracy: 0.92\n", 218 | " Validation Loss: 0.00\n", 219 | " Validation took: 0:00:04\n", 220 | "\n", 221 | "======== Epoch 2 / 3 ========\n", 222 | "Training...\n", 223 | "\n", 224 | " Average training loss: 0.30\n", 225 | " Training epcoh took: 0:00:40\n", 226 | "\n", 227 | "Running Validation...\n", 228 | " Accuracy: 0.98\n", 229 | " Validation Loss: 0.00\n", 230 | " Validation took: 0:00:04\n", 231 | "\n", 232 | "======== Epoch 3 / 3 ========\n", 233 | "Training...\n", 234 | "\n", 235 | " Average training loss: 0.04\n", 236 | " Training epcoh took: 0:00:40\n", 237 | "\n", 238 | "Running Validation...\n", 239 | " Accuracy: 0.98\n", 240 | " Validation Loss: 0.00\n", 241 | " Validation took: 0:00:04\n", 242 | "Saving model to ../transformers_output_bbcsports/\n", 243 | "\n", 244 | "Training complete!\n", 245 | "Total training took 0:02:12 (h:mm:ss)\n" 246 | ] 247 | } 248 | ], 249 | "source": [ 250 | "model.train()" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "## Test Method" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 10, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "prediction = model.test()" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "**Testing a sentence, takes a list of sentences.**" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 11, 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "data": { 283 | "text/plain": [ 284 | "[array([0.25924018, 0.98505855, 0.27304068, 0.23541152, 0.383295 ],\n", 285 | " dtype=float32)]" 286 | ] 287 | }, 288 | "execution_count": 11, 289 | "metadata": {}, 290 | "output_type": "execute_result" 291 | } 292 | ], 293 | "source": [ 294 | "model.text_inference(['Rodger Fedrer is one of the great tennis player. Rafael Nadal is very good player at clay court. They have combined won more than 30 grand slams till now.'])" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 12, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "y_pred = [np.argmax(tup) for tup in prediction]\n", 304 | "y = df_score.label" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": {}, 310 | "source": [ 311 | "## Accuracy and Report" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 13, 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "\n", 324 | "\n", 325 | "Accuracy: 0.980\n", 326 | "Precision: 0.983\n", 327 | "Recall: 0.982\n", 328 | "F1-Score: 0.982\n" 329 | ] 330 | }, 331 | { 332 | "data": { 333 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUUAAAEmCAYAAAD1FIKpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XucVXW9//HXZwaUiwyCMDAMIIqI4yXGidAElSSMCNEDqRUqFYWllkVWVp6THi9pJgc7en6JxxIlr5nhUUI4Fpn8wAuIxk0EHEVuw9UZboozn/PHWoMrnJm9Z5g9a6/N+/l47Mfstdfa3/WeJX7mu27fZe6OiIgE8uIOICKSTVQURUQiVBRFRCJUFEVEIlQURUQiVBRFRCJUFA8hZtbWzP7HzN4zs8cPop1xZja7ObPFxczONLM34s4h2cN0nWL2MbOvAJOAE4AqYDFws7u/cJDtXgp8BzjD3T886KBZzswc6Ofuq+LOIsmhnmKWMbNJwBTgFqAb0Bv4L+D8Zmj+aGDloVAQ02FmreLOIFnI3fXKkhfQEdgJXNjAMocTFM314WsKcHg4byjwLvADoALYAHwtnHcD8AGwL1zHBOB6YHqk7T6AA63C6a8Cawh6q28B4yKfvxD53hnAy8B74c8zIvPmAjcC88J2ZgNd6vndavP/KJL/AmAksBLYBvw0svwgYD6wI1z2LuCwcN7z4e+yK/x9L460/2NgI/Bg7Wfhd/qG6ygLp3sAW4Chcf/b0KvlXuopZpdPA22AJxtY5mfA6UApMICgMFwXmd+doLgWExS+u82sk7v/nKD3+ai7H+Hu9zUUxMzaA78GPu/uHQgK3+I6lusMPBMuexQwGXjGzI6KLPYV4GtAIXAYcE0Dq+5OsA2KgX8D7gUuAT4JnAn8m5kdGy5bDXwf6EKw7YYBVwC4+1nhMgPC3/fRSPudCXrNE6MrdvfVBAXz92bWDvgdcL+7z20gr+QYFcXschSwxRvevR0H/Lu7V7j7ZoIe4KWR+fvC+fvcfSZBL6l/E/PUACebWVt33+DuS+tY5gvAm+7+oLt/6O4PAyuA8yLL/M7dV7r7HuAxgoJen30Ex0/3AY8QFLw73b0qXP9S4BMA7r7Q3ReE6y0H7gHOTuN3+rm7vx/m+Sfufi/wJvAiUETwR0gOISqK2WUr0CXFsa4ewNuR6bfDz/a3cUBR3Q0c0dgg7r6LYJfzW8AGM3vGzE5II09tpuLI9MZG5Nnq7tXh+9qitSkyf0/t983seDN72sw2mlklQU+4SwNtA2x2970plrkXOBn4T3d/P8WykmNUFLPLfGAvwXG0+qwn2PWr1Tv8rCl2Ae0i092jM939WXcfTtBjWkFQLFLlqc20romZGuP/EeTq5+4FwE8BS/GdBi+3MLMjCI7T3gdcHx4ekEOIimIWcff3CI6j3W1mF5hZOzNrbWafN7Nfhos9DFxnZl3NrEu4/PQmrnIxcJaZ9TazjsBPameYWTczGx0eW3yfYDe8uo42ZgLHm9lXzKyVmV0MnAg83cRMjdEBqAR2hr3Ybx8wfxNw7Me+1bA7gYXu/g2CY6W/OeiUkigqilnG3ScTXKN4HbAZWAtcBfwpXOQm4BXgdeAfwKLws6asaw7waNjWQv65kOURnMVeT3BG9mzCkxgHtLEVGBUuu5XgzPEod9/SlEyNdA3BSZwqgl7sowfMvx6YZmY7zOyiVI2Z2fnACIJDBhD8dygzs3HNlliyni7eFhGJUE9RRCRCRVFEJEJFUUQkQkVRRCQiq26Ib922wA/vUBh3jEYrOGp33BFEMmrHhh3s3rE71TWgjdKrVy/fuzfVdfQf2bJly7PuPqI5M9Qlq4ri4R0KOeXiO+KO0Wifu/TVuCOIZNTU8VObvc29e/cyZsyY9DNMnZrqbqVmod1nEZEIFUURkQgVRRGRCBVFEZGIrDrRIiLSVGZWTnAffDXwobsPDEc5epRgVPly4CJ3395QO+opikgu+Yy7l7r7wHD6WuA5d+8HPBdON0hFUURy2fnAtPD9NBoeqxRQURSR5OhiZq9EXhMPmO/AbDNbGJnXzd03AIQ/U94domOKIpIUWyK7xXUZ7O7rzawQmGNmK5qyEvUURSQnuPv68GcFwRMxBwGbzKwIIPxZkaodFUURSTwza29mHWrfA+cCS4CngPHhYuOBGana0u6ziOSCbsCTZgZBXXvI3WeZ2cvAY2Y2AXgHuDBVQyqKIpJ47r4GGFDH51uBYY1pS7vPIiIROddTfL9qM6v/904+2L0DM6PwpHMpGnAeu7a8xVt//Q3V+/ZweEEhx507iVaHtUvdYIxWzV/FrMmzqKmpoWx0GUPGD4k7UkozbpzBynkrad+pPVc8/LGH/2W1JG7vWknOnm1yrqdoefkcPfhrlI67i5O/+Es2vf5ndm9by5q/3E3vMy5lwFd+TedjT2fDoifjjtqgmuoaZt4+k3FTxnHlI1eyZPYSNq/ZHHeslEpHlXLJlEvijtFoSd3ekOzs2SjniuJh7TvTvrAvAPmHtaVt5558sHMre7evo0OPkwDo2GsA21bPjzNmSuuWraNzz850Ku5Efut8Thp+Eiueb9JlVy3q6FOPpm1B27hjNFpStzckO3s2yrmiGLW3chO7Nq/hiO7H0/ao3mx/6yUAtq36/7y/syWe1d50VRVVFHQr2D9dUFhA1eaqGBPltiRv7yRnz0YZLYpmNsLM3jCzVWaW8kbs5lT9wR7e/PNt9DlzAq0Oa0ffYd9h0z9m8o9HJ1G9bw95ea1bMk6jOf7xD5v1CRkSleTtneTs2ShjJ1rMLB+4GxgOvAu8bGZPufuyTK2zVk31h6z88210Of5sOvf9NABtO/Wk5PwbANizfR3byxdmOsZBKSgsoHJT5f7pyopKOnTpEGOi3Jbk7Z3k7Nkokz3FQcAqd1/j7h8AjxCMWJFR7s6av9xF2849KTr1o9Xt270jnF/Dulcep9vJn8t0lINSXFLM1rVb2b5+O9X7qlk6Zyn9z+ofd6ycleTtneTs2SiTl+QUA2sj0+8Cp2VwfQBUbVjOljfm0u6oo3n9ke8B0Ov0S9j73gY2vf5nADr3PZ2uJY26nrPF5bXKY+Q1I5n+3el4jVN6XimFx2b/41+fuO4JyheVs3vHbiaPmszQiUMpG10Wd6yUkrq9IdnZs1Emi2JdRzU+dvAjHOJnIsBhR3Q96JUW9DiR06/6U53zigacd9Dtt6R+g/vRb3C/uGM0ytibxsYdocmSuL1rJTl7tsnk7vO7QK/IdE9g/YELuftUdx/o7gNbty04cLaISIvKZFF8GehnZseY2WHAlwhGrBARyVoZ23129w/N7CrgWSAf+K27L83U+kREmkNG731295nAzEyuQ0SkOeX0HS0iIo2Vc6PkiEgydK7ZwZd3pRwIe7+pGcwSpZ6iiEiEiqKISISKoohIhIqiiEiEiqKISISKoohIhIqiiEiEiqKISISKoohIhIqiiEiEiqKISISKoohIhIqiiEiEiqKISISKoohIRFaNp1hw1G4+d+mrccdotOsH9Yg7QpNd/9LHniUmckhTT1FEJEJFUUQkQkVRRCRCRVFEJEJFUURyhpnlm9mrZvZ0OH2Mmb1oZm+a2aNmdliqNlQURSSXXA0sj0zfBvyHu/cDtgMTUjWgoigiOcHMegJfAP47nDbgHOAP4SLTgAtStaOiKCJJ0cXMXom8Jh4wfwrwI6AmnD4K2OHuH4bT7wLFqVaSVRdvi4g0YIu7D6xrhpmNAircfaGZDa39uI5FPdVKVBRFJBcMBkab2UigDVBA0HM80sxahb3FnkDKW7i0+ywiiefuP3H3nu7eB/gS8Bd3Hwf8FfhiuNh4YEaqtlQURSSX/RiYZGarCI4x3pfqC9p9FpFYtNvWlVMfPvBcSUNuSGspd58LzA3frwEGNSaXeooiIhEqiiIiESqKIiIROX9McdX8VcyaPIuamhrKRpcxZPyQuCPVa+3abVx22e/YuLGSvDxj4sQzufrqYWzbtouLL76X8vKt9OlzFI899k06dWofd9x6JWmbRyU1NyQ7e7bJ6Z5iTXUNM2+fybgp47jykStZMnsJm9dsjjtWvVq1yueOOy5k+fIbWLDgWu6+ey7Llq3n1ltnMWzYCbz55o0MG3YCt946K+6o9UraNq+V1NyQ7OzZKKeL4rpl6+jcszOdijuR3zqfk4afxIrnV8Qdq15FRR0pK+sNQIcObSgpKWLduh3MmPEa48d/GoDx4z/Nn/70WpwxG5S0bV4rqbkh2dmzUU4XxaqKKgq6FeyfLigsoGpzVYyJ0ldevoVXX32H0047hk2bKikq6ggEhbOiInt/h6Ru86TmhmRnz0YZK4pm9lszqzCzJZlaRype122Odd0NmWV27tzL2LH3MGXKRRQUtI07TqMkdZsnNTckO3s2ymRP8X5gRAbbT6mgsIDKTZX7pysrKunQpUOMiVLbt6+asWPvYdy4QYwZUwZAt24FbNjwHgAbNrxHYWH2/g5J3OaQ3NyQ7OzZKGNF0d2fB7Zlqv10FJcUs3XtVrav3071vmqWzllK/7P6xxmpQe7OhAkPUFLSnUmThu//fPToTzBt2nwApk2bz/nnD4grYkpJ2+a1kpobkp09G8V+SU44JtpEgI7dOzZr23mt8hh5zUimf3c6XuOUnldK4bGFzbqO5jRv3moefHABp5xSTGnpjQDccssFXHvtCC66aCr33TeP3r078fjjl8ectH5J2+a1kpobkp09G5l7yuHFmt64WR/gaXc/OZ3le5T08InTGnMvZHa4flCPuCM02fUvpRxJSYSp46eyfvn6Zj1Sear18Ll56f//fmTNDQvrG0+xOeX02WcRkcZSURQRicjkJTkPA/OB/mb2rpmlfIqWiEjcMnaixd2/nKm2RUQyRbvPIiIRKooiIhEqiiIiESqKIiIRKooiIhEqiiIiESqKIiIRKooiIhEqiiIiESqKIiIRKooiIhEqiiIiEbGPvC0ih6Y3u7ZjxMWnpv+FuzKXJUo9RRGRCBVFEZEIFUURkQgVRRGRiHpPtJhZQUNfdPfKhuYfSpL8RLwN922IO0KTFE0oijuC5KiGzj4vBRyIPtawdtqB3hnMJSISi3qLorv3askgIiJNZWZtgOeBwwnq2h/c/edmdgzwCNAZWARc6u4fNNRWWscUzexLZvbT8H1PM/vkwfwCIiLN7H3gHHcfAJQCI8zsdOA24D/cvR+wHUj5VNGURdHM7gI+A1wafrQb+E0Tg4uINDsP7AwnW4cvB84B/hB+Pg24IFVb6fQUz3D3y4G94cq3AYc1NrSIyEHqYmavRF4TozPNLN/MFgMVwBxgNbDD3T8MF3kXKE61knRu89tnZnkEVRczOwqoacQvIiLSHLa4+8D6Zrp7NVBqZkcCTwIldS2WaiXp9BTvBp4AuprZDcALBPvpIiJZx913AHOB04Ejzay289cTSHn9XMqeors/YGYLgc+GH13o7kuaFldEpPmZWVdgn7vvMLO2BPXqNuCvwBcJzkCPB2akaivdUXLygX0EXU/dBSMi2aYImGZm+QQ16jF3f9rMlgGPmNlNwKvAfakaSlkUzexnwFcI9tENeMjMfu/uvziY30BEpLm4++vAx8Yhc/c1wKDGtJVOT/ES4JPuvhvAzG4GFgIqiiKSc9LZFX6bfy6erYA1mYkjIhKvhgaE+A+CY4i7gaVm9mw4fS7BGWgRkZzT0O5z7RnmpcAzkc8XZC6OiEi8GhoQIuVZGhGRXJPO2ee+wM3AiUCb2s/d/fgM5mo2M26cwcp5K2nfqT1XPHxF3HEaZdX8VcyaPIuamhrKRpcxZPyQuCPV66GHHqJ169bk5eVhZowZM4YFCxbw9ttvk5+fT0FBAWeffTaHH3543FHrlaTtfaAkZ8826Zx9vh+4CfgV8HngayToNr/SUaUMunAQT97wZNxRGqWmuoaZt8/k0v+8lILCAu796r30P7M/XY/tGne0ep133nm0abP/7yY9e/Zk0KBB5OXl8eKLL7J48WJOO+20GBPWL4nbu1aSs2ejdM4+t3P3ZwHcfbW7X0cwak4iHH3q0bQtaBt3jEZbt2wdnXt2plNxJ/Jb53PS8JNY8fyKuGM1Ss+ePcnLC/6JFRYWsmvXrpgT1S/J2zvJ2bNROj3F983MgNVm9i1gHVCY2VhSVVFFQbePnghRUFjAuqXrYkzUMDPjmWeewcwoKSmhpOSf78V/44036Nu3b0zpUkva9o5KcvZslE5R/D5wBPBdgmOLHYGvp/qSmfUCHgC6E+xuT3X3O5se9dDidQ3mYR//KFuMHj2a9u3bs2fPHp555hmOPPJIioqC56gsWrSIvLw8jjvuuJhT1i9p2zsqydmzUToDQrwYvq3io4Fm0/Eh8AN3X2RmHYCFZjbH3Zc1Iechp6CwgMpNHz0brLKikg5dOsSYqGHt27cHoG3btvTp04eKigqKiopYuXIl77zzDqNGjSLY4chOSdveUUnOno3qPaZoZk+a2R/re6Vq2N03uPui8H0VsJw0BniUQHFJMVvXbmX7+u1U76tm6Zyl9D+rf9yx6rRv3z4++OCD/e/XrVtH586dWbt2LYsXL+Zzn/scrVqlO/ZIPJK0vQ+U5OzZqKF/qXc110rMrA/Bzdov1jFvIjARoGP3js21yv2euO4JyheVs3vHbiaPmszQiUMpG13W7Otpbnmt8hh5zUimf3c6XuOUnldK4bHZeSh3z549zJ49GwB3p2/fvvTq1YtHHnmE6upqZs6cCQQnW84888w4o9YrSdv7QEnOno3MPeVAtAe3ArMjgL8BN7t7gz3MHiU9fOK0iQ0tIs1Mz32WdEwdP5X1y9c36/GPgQOP9lde+Vnay5tdvrChkbebS0bHRjSz1gSjdv8+VUEUEckGGSuK4WU89wHL3X1yptYjItKc0i6KZtbY+7MGE5ytPsfMFoevkY1sQ0SkRaVz7/Mggh5fR6C3mQ0AvuHu32noe+7+ArpaSkQSJp2e4q+BUcBWAHd/jQTd5ici0hjpFMU8d3/7gM+qMxFGRCRu6VxRuzbchfbwSVnfAVZmNpaISDzS6Sl+G5gE9AY2ETxg+tuZDCUiEpd07n2uAL7UAllERGKXztnne+Hjw3C4u249EZGck84xxf+NvG8D/AuwNjNxRETilc7u86PRaTN7EJiTsUQiIjFqym1+xwBHN3cQEZFskM4xxe18dEwxD9gGXJvJUCIicWmwKIaDOgwgeC4LQI1neqwxEZEYNbj7HBbAJ929OnypIIpITkvnmOJLZpb9Q1WLiDSDenefzayVu38IDAG+aWargV0EI9+4u6tQikjOaeiY4ktAGXBBC2UREYldQ0XRANx9dQtlERGJXUNFsauZTapvph4xICK5qKGimA8cgUbPzmlJfSre9YN6xB2hSa5/aX3cESSFhoriBnf/9xZLIiLSRGbWC3gA6A7UAFPd/U4z6ww8CvQByoGL3H17Q201dEmOeogikhQfAj9w9xKCMV+vNLMTCe6+e87d+wHPkcbdeA0VxWHNkVREJNPcfYO7LwrfVwHLgWLgfGBauNg00riapt7dZ3ffdvBRRUSaTRczeyUyPdXdpx64kJn1AU4FXgS6ufsGCAqnmRWmWkk64ymKiGSDLe4+sKEFzOwI4Ange+5eGQzf0DhNGTpMRCTrmFlrgoL4e3f/Y/jxJjMrCucXARWp2lFPUUTiUVVNzd+2NktT4Yhe9wHLD7iG+ilgPHBr+HNGqrZUFEUkFwwGLgX+YWaLw89+SlAMHzOzCcA7wIWpGlJRFJHEc/cXqP8ywkZdSaNjiiIiESqKIiIRKooiIhEqiiIiESqKIiIRKooiIhE5f0nOjBtnsHLeStp3as8VD18Rd5xGWTV/FbMmz6Kmpoay0WUMGT8k7khpSUrutWu3cdllv2Pjxkry8oyJE8/k6qs/unrjV7+azQ9/+ASbN99Bly5HxJg0taRs8yTI+Z5i6ahSLplySdwxGq2muoaZt89k3JRxXPnIlSyZvYTNazbHHSulJOVu1SqfO+64kOXLb2DBgmu5++65LFsWDAK7du025sxZTu/enWNOmVqStnkS5HxRPPrUo2lb0DbuGI22btk6OvfsTKfiTuS3zuek4Sex4vkVccdKKUm5i4o6UlbWG4AOHdpQUlLEunU7APj+9x/nl78cQ1MGFGhpSdrmSZDzRTGpqiqqKOhWsH+6oLCAqs1VMSZKT1Jzl5dv4dVX3+G0047hqadeo7j4SAYM6BV3rLQkdZtnq4wdUzSzNsDzwOHhev7g7j/P1PpyjeMf/zD7Oy2JzL1z517Gjr2HKVMuolWrfG6+eSazZ38v7lhpS+I2z2aZ7Cm+D5zj7gOAUmCEmZ2ewfXllILCAio3Ve6frqyopEOXDjEmSk/Scu/bV83YsfcwbtwgxowpY/Xqzbz11lYGDLiRPn1+yrvvbqes7CY2bnwv7qj1Sto2z3YZK4oe2BlOtg5fdfxJk7oUlxSzde1Wtq/fTvW+apbOWUr/s/rHHSulJOV2dyZMeICSku5MmjQcgFNOKaai4leUl99Cefkt9OzZiUWLrqN7944xp61fkrZ5EmT0khwzywcWAscBd7v7i5lcX12euO4JyheVs3vHbiaPmszQiUMpG13W0jEaLa9VHiOvGcn0707Ha5zS80opPDblSOqxS1LuefNW8+CDCzjllGJKS28E4JZbLmDkyFNiTtY4SdrmSZDRouju1UCpmR0JPGlmJ7v7kugyZjYRmAjQMQN/jcfeNLbZ22wp/Qb3o9/gfnHHaLSk5B4y5Djc72lwmfLyW1oozcFJyjZPghY5++zuO4C5wIg65k1194HuPrDdke1aIo6ISL0yVhTNrGvYQ8TM2gKfBXTxlIhktUzuPhcB08LjinnAY+7+dAbXJyJy0DJWFN39dYJnr4qIJIbuaBERiVBRFBGJUFEUEYlQURQRiVBRFBGJUFEUEYlQURQRicj5Z7SISHaqqu7I37aNbsQ3fpqxLFHqKYqIRKgoiohEqCiKiESoKIqIRKgoiohEqCiKiESoKIqIRKgoiohEqCiKSE4ws9+aWYWZLYl81tnM5pjZm+HPTqnaUVEUkVxxPx9/ON61wHPu3g94LpxukG7zk0S6/qX1cUdokusH9Yg7QpM83b513BFScvfnzazPAR+fDwwN308jeKrojxtqRz1FEcll3dx9A0D4szDVF9RTFJGk6GJmr0Smp7r71OZeiYqiiCTFFncf2MjvbDKzInffYGZFQEWqL2j3WURy2VPA+PD9eGBGqi+oKIpITjCzh4H5QH8ze9fMJgC3AsPN7E1geDjdIO0+i0hOcPcv1zNrWGPaUU9RRCRCRVFEJEJFUUQkQkVRRCRCRVFEJEJFUUQkQkVRRCRCRVFEJEJFUUQkIufvaFk1fxWzJs+ipqaGstFlDBk/JO5IaUtqduXOvLVrt3HZZb9j48ZK8vKMiRPP5Oqrh7Ft2y4uvvheysu30qfPUTz22Dfp1Kl93HETJad7ijXVNcy8fSbjpozjykeuZMnsJWxesznuWGlJanblbhmtWuVzxx0Xsnz5DSxYcC133z2XZcvWc+utsxg27ATefPNGhg07gVtvnRV31MTJ6Z7iumXr6NyzM52Kg8cynDT8JFY8v4Kux3aNOVlqSc2u3C2jqKgjRUUdAejQoQ0lJUWsW7eDGTNeY+7cHwAwfvynGTr0Dm67bWycUetV1Xo7fyt+PO4YH5PTPcWqiioKuhXsny4oLKBqc1WMidKX1OzK3fLKy7fw6qvvcNppx7BpU+X+YllU1JGKimT8Dtkk40XRzPLN7FUzezrT6zqQ43UEaukUTZPU7Mrdsnbu3MvYsfcwZcpFFBS0jTtOTmiJnuLVwPIWWM/HFBQWULmpcv90ZUUlHbp0iCNKoyU1u3K3nH37qhk79h7GjRvEmDFlAHTrVsCGDe8BsGHDexQWZvfvkI0yWhTNrCfwBeC/M7me+hSXFLN17Va2r99O9b5qls5ZSv+z+scRpdGSml25W4a7M2HCA5SUdGfSpOH7Px89+hNMmzYfgGnT5nP++QPiiphYmT7RMgX4EVDvnyszmwhMBOjYvWOzrjyvVR4jrxnJ9O9Ox2uc0vNKKTw25cO8skJSsyt3y5g3bzUPPriAU04pprT0RgBuueUCrr12BBddNJX77ptH796dePzxy2NOmjwZK4pmNgqocPeFZja0vuXCp3FNBehR0qOOAzsHp9/gfvQb3K+5m20RSc2u3Jk3ZMhxuN9T57znnpvUwmlySyZ3nwcDo82sHHgEOMfMpmdwfSIiBy1jRdHdf+LuPd29D/Al4C/ufkmm1ici0hxy+jpFEZHGapE7Wtx9LjC3JdYlInIw1FMUEYlQURQRiVBRFBGJUFEUEYlQURQRiVBRFBGJUFEUEYlQURQRiVBRFBGJUFEUEYlQURQRiVBRFBGJUFEUEYlQURQRiVBRFJGcYGYjzOwNM1tlZtc2tR0VRRFJPDPLB+4GPg+cCHzZzE5sSlsqiiKSCwYBq9x9jbt/QPBcqPOb0pC5N/sD9JrMzDYDb2eo+S7Algy1nUnK3fKSmj2TuY92967N2aCZzSLInK42wN7I9NTwaaCY2ReBEe7+jXD6UuA0d7+qsbla5HEE6WrujR5lZq+4+8BMtZ8pyt3ykpo9abndfUQzNmd1raIpDWn3WURywbtAr8h0T2B9UxpSURSRXPAy0M/MjjGzwwgeq/xUUxrKqt3nDJsad4AmUu6Wl9TsSc190Nz9QzO7CngWyAd+6+5Lm9JWVp1oERGJm3afRUQiVBRFRCJUFEVyhJnVdVmKNFLOFkUz629mnzaz1uEtQImS0MzHmdlAMzs87iyNYWYnmdnZZnZU3Fkay8yGhBcq4+6uwnjwcvLss5mNAW4B1oWvV8zsfnevjDdZamZ2vLuvdPdqM8t39+q4M6XDzEYRbPOtwEYz+7m7r4w5Vkpm9nngNmAN0NrMJrj7xphjpWRmeUA74J5g0tq7+2/Cwpjn7jUxR0ysnOspmllr4GJggrsPA2YQXNT5IzMriDVcCmFhWWxmDwHUFsaYY6VkZmcAvwLGu/tngO1Ak0cpaSlmNhS4E/iGu19jY0xlAAAFfklEQVQAfACcHGuoNLl7jbvvBKYB9wFnmNn3a+fFGi7hcq4ohgqAfuH7J4GngcOAr2Tr7oWZtQeuAr4HfGBm0yE5hRG41d1fDd//HOicgN3oTcDl7v6SmXUHTgOuMrN7zOyL2fpv5QAfEvzRnwYMMrPJZvYLC+Tq/98ZlXMbzd33AZOBMWZ2ZvhX8wVgMTAk1nANcPddwNeBh4BrgDbRwhhntjS8CPwR9h8LPRw4muCPE9l6rM7dl7v7X8PJCcB/hT3GBcCFNG6wgrjMADa6+3PAK8C3gAIPqMfYBDlXFEN/B2YDl5rZWe5e7e4PAT2AAfFGq5+7r3f3ne6+BbgcaFtbGM2szMxOiDdh3cLtW3u81oAdwDZ332xm44CbzKxtfAlTc/eb3f2m8P3vgA7887202WoP0N/MvklQEG8FepvZ5fHGSq6cPNHi7nvN7PcEo2T8JCwm7wPdgA2xhkuTu28N/2HfbmYrCG5d+kzMsVJy9w+BnWa21sx+AZwLfNXd98QcrV5mZh65tcvMxhL8W2nSgAItyd3Xm9la4F+BK939f8zsM8CqmKMlVk7f5hfeGD6YoNe1F7gzctwrEcKD5z8Ghrv7P+LOk0p4HK41sDz8Oczd34w3VXrCY6CXAJOAi919ScyR0mJmvYBCd18YTuvs80HI6aJYKzzOlbhjLGbWCXgM+IG7vx53nsYws68CLzf1pvw4hFcuDAdWu/sbcedprAN7vNI0h0RRTDIza+Pue1MvmV30P6gklYqiiEhErp59FhFpEhVFEZEIFUURkQgVRRGRCBXFHGFm1Wa22MyWmNnjZtbuINoaamZPh+9Hm1m9gzuY2ZFmdkUT1nG9mV2T7ucHLHN/+JzfdNfVx8wScc2hxE9FMXfscfdSdz+ZYLSXb0VnNnWAAHd/yt1vbWCRI4FGF0WRbKWimJv+DhwX9pCWm9l/AYuAXmZ2rpnNN7NFYY/yCAAzG2FmK8zsBWBMbUNm9lUzuyt8383MnjSz18LXGQT32vYNe6m3h8v90MxeNrPXzeyGSFs/M7M3zOx/gf6pfgkz+2bYzmtm9sQBvd/PmtnfzWxlOOQaZpZvZrdH1q37f6XRVBRzjJm1Aj4P1N4S2B94wN1PBXYB1wGfdfcyglFVJplZG+Be4DzgTKB7Pc3/Gvibuw8AyoClBOMmrg57qT80s3MJhm0bBJQCnzSzs8zskwTP4j2VoOh+Ko1f54/u/qlwfcsJRrKp1Qc4G/gC8Jvwd5gAvOfunwrb/6aZHZPGekT2y8kBIQ5Rbc1scfj+7wQDj/YA3nb3BeHnpwMnAvPCoQIPA+YDJwBv1d6jHI7MM7GOdZwDXAb7hzN7L7wVMerc8FV7j/kRBEWyA/Cku+8O15HOg8pPNrObCHbRjyB4pm+tx8LbNt80szXh73Au8InI8caO4bqzfgRwyR4qirljj7uXRj8IC9+u6EfAHHf/8gHLlRKMKNQcDPiFu99zwDq+14R13A9c4O6vhfdSD43MO7AtD9f9HXePFk/MrE8j1yuHMO0+H1oWAIPN7DgAM2tnZscDK4BjzKxvuNyX6/n+c8C3w+/mW/B4hyqCXmCtZ4GvR45VFptZIfA88C9m1tbMOhDsqqfSAdgQDtQw7oB5F5pZXpj5WOCNcN3fDpfHzI63YERzkbSpp3gICQd9/SrwsH30qIDr3H2lmU0EnjGzLQQjldf1rJKrgalmNgGoBr7t7vPNbF54ycufw+OKJcD8sKe6E7jE3ReZ2aMEI6C/TbCLn8q/Eozq/TbBMdJo8X0D+BvBuIffCsfQ/G+CY42LwiHMNgMXpLd1RAIaEEJEJEK7zyIiESqKIiIRKooiIhEqiiIiESqKIiIRKooiIhEqiiIiEf8HppyUvAd1WKQAAAAASUVORK5CYII=\n", 334 | "text/plain": [ 335 | "
" 336 | ] 337 | }, 338 | "metadata": {}, 339 | "output_type": "display_data" 340 | } 341 | ], 342 | "source": [ 343 | "metrics(y,y_pred)" 344 | ] 345 | } 346 | ], 347 | "metadata": { 348 | "kernelspec": { 349 | "display_name": "conda_pytorch_p36", 350 | "language": "python", 351 | "name": "conda_pytorch_p36" 352 | }, 353 | "language_info": { 354 | "codemirror_mode": { 355 | "name": "ipython", 356 | "version": 3 357 | }, 358 | "file_extension": ".py", 359 | "mimetype": "text/x-python", 360 | "name": "python", 361 | "nbconvert_exporter": "python", 362 | "pygments_lexer": "ipython3", 363 | "version": "3.6.5" 364 | } 365 | }, 366 | "nbformat": 4, 367 | "nbformat_minor": 2 368 | } 369 | -------------------------------------------------------------------------------- /sample_notebooks/PYTORCH_classi_transformers_sentiment_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Pytorch Framework for Multi-class Text Classification" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Sentiment Classification - 3 Class Problem\n", 15 | "\n", 16 | "* Financial Phrasebank Dataset
\n", 17 | "\n", 18 | " * 0 - Positive
\n", 19 | " * 1 - Negative
\n", 20 | " * 2 - Neutral
" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "## Reading Data" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 1, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import sys\n", 37 | "import numpy as np\n", 38 | "import pandas as pd\n", 39 | "\n", 40 | "sys.path.append(\"..\")\n", 41 | "pd.set_option('display.max_colwidth', -1)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "data": { 51 | "text/plain": [ 52 | "((3100, 3), (970, 3), (776, 3))" 53 | ] 54 | }, 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "output_type": "execute_result" 58 | } 59 | ], 60 | "source": [ 61 | "df_train = pd.read_csv('../datasets/financial_phrasebank/training.csv', sep='\\t')\n", 62 | "df_test = pd.read_csv('../datasets/financial_phrasebank/testing.csv', sep='\\t')\n", 63 | "df_val = pd.read_csv('../datasets/financial_phrasebank/validation.csv', sep='\\t')\n", 64 | "\n", 65 | "df_train.shape, df_test.shape, df_val.shape" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "neutral 1821\n", 77 | "positive 892 \n", 78 | "negative 387 \n", 79 | "Name: label, dtype: int64" 80 | ] 81 | }, 82 | "execution_count": 3, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "df_train.label.value_counts()" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 4, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "data": { 98 | "text/plain": [ 99 | "neutral 475\n", 100 | "positive 208\n", 101 | "negative 93 \n", 102 | "Name: label, dtype: int64" 103 | ] 104 | }, 105 | "execution_count": 4, 106 | "metadata": {}, 107 | "output_type": "execute_result" 108 | } 109 | ], 110 | "source": [ 111 | "df_val.label.value_counts()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "## Processing Data in Input format" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 5, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "data_classes = ['positive', 'negative', 'neutral']\n", 128 | "\n", 129 | "df_train['label'] = df_train['label'].apply(data_classes.index)\n", 130 | "df_test['label'] = df_test['label'].apply(data_classes.index)\n", 131 | "df_val['label'] = df_val['label'].apply(data_classes.index)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 6, 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "data": { 141 | "text/html": [ 142 | "
\n", 143 | "\n", 156 | "\n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | "
Unnamed: 0textlabel
0178comparable operating profit totaled EUR 854mn , up from EUR 730mn in 2004 .0
14157Finnish M-real 's Kyro mill is starting negotiations to cut staff .1
23616In 2007 , the Group 's net sales stood at EUR 42 million and it had about 1,445 employees .2
3194Finnish lifting equipment maker Konecranes Oyj said on July 30 , 2008 that its net profit rose to 71.2 mln euro ( $ 111.1 mln ) for the first half of 2008 from 57.1 mln euro ( $ 89.1 mln ) for the same period of 2007 .0
41584Pulkovo park will be ready in 2016 , its first stage of 23,000 sq. m. will be finished in the first quarter of 2010 .2
\n", 198 | "
" 199 | ], 200 | "text/plain": [ 201 | " Unnamed: 0 \\\n", 202 | "0 178 \n", 203 | "1 4157 \n", 204 | "2 3616 \n", 205 | "3 194 \n", 206 | "4 1584 \n", 207 | "\n", 208 | " text \\\n", 209 | "0 comparable operating profit totaled EUR 854mn , up from EUR 730mn in 2004 . \n", 210 | "1 Finnish M-real 's Kyro mill is starting negotiations to cut staff . \n", 211 | "2 In 2007 , the Group 's net sales stood at EUR 42 million and it had about 1,445 employees . \n", 212 | "3 Finnish lifting equipment maker Konecranes Oyj said on July 30 , 2008 that its net profit rose to 71.2 mln euro ( $ 111.1 mln ) for the first half of 2008 from 57.1 mln euro ( $ 89.1 mln ) for the same period of 2007 . \n", 213 | "4 Pulkovo park will be ready in 2016 , its first stage of 23,000 sq. m. will be finished in the first quarter of 2010 . \n", 214 | "\n", 215 | " label \n", 216 | "0 0 \n", 217 | "1 1 \n", 218 | "2 2 \n", 219 | "3 0 \n", 220 | "4 2 " 221 | ] 222 | }, 223 | "execution_count": 6, 224 | "metadata": {}, 225 | "output_type": "execute_result" 226 | } 227 | ], 228 | "source": [ 229 | "df_test.head()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 7, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "df_bert_train = pd.DataFrame({'id':df_train['Unnamed: 0'],\n", 239 | " 'text':df_train['text'],\n", 240 | " 'label':df_train['label']})\n", 241 | "\n", 242 | "df_bert_test = pd.DataFrame({'id':df_test['Unnamed: 0'],\n", 243 | " 'text':df_test['text']})\n", 244 | "\n", 245 | "df_bert_val = pd.DataFrame({'id':df_val['Unnamed: 0'],\n", 246 | " 'text':df_val['text'],\n", 247 | " 'label':df_val['label']})" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 8, 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "df_bert_train.to_csv('../datasets/financial_phrasebank/train.csv',index=False)\n", 257 | "df_bert_test.to_csv('../datasets/financial_phrasebank/test.csv',index=False)\n", 258 | "df_bert_val.to_csv('../datasets/financial_phrasebank/dev.csv',index=False)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": {}, 264 | "source": [ 265 | "## Training Transformers (Albert | Roberta | DistilBert) classifier with FrameWork" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 9, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "from classitransformers.report import metrics\n", 275 | "from classitransformers.configs import Configs\n", 276 | "from classitransformers.pytransformers import TransformersClassification" 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "metadata": {}, 282 | "source": [ 283 | "## Set Configurations\n", 284 | "\n", 285 | "* Pre-trained config filepath, vocab filepath and model checkpoint filepath\n", 286 | "* labels list\n", 287 | "* train, test and eval batch size\n", 288 | "* max_seq_len - how long is the each text input sentence !!\n", 289 | "* data_dir - data path folder (optional)\n", 290 | "* output_dir - output path folder (optional)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 10, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "config = Configs(\n", 300 | " pretrained_model_dir = '../models/DistilBert',\n", 301 | " model_name = 'distilbert',\n", 302 | " num_train_epochs=3,\n", 303 | " train_batch_size = 16,\n", 304 | " eval_batch_size = 8, \n", 305 | " do_train = True, \n", 306 | " do_eval = True, \n", 307 | " max_seq_length = 64,\n", 308 | " label_list = [\"0\", \"1\", \"2\"],\n", 309 | " data_dir = '../datasets/financial_phrasebank/',\n", 310 | " output_dir = '../transformers_output_sentiment/')" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 11, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "model = TransformersClassification(config)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "## Train Method" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 12, 332 | "metadata": { 333 | "scrolled": true 334 | }, 335 | "outputs": [ 336 | { 337 | "name": "stdout", 338 | "output_type": "stream", 339 | "text": [ 340 | "\n", 341 | "======== Epoch 1 / 3 ========\n", 342 | "Training...\n", 343 | " Batch 40 of 194. Elapsed: 0:00:08.\n", 344 | " Batch 80 of 194. Elapsed: 0:00:15.\n", 345 | " Batch 120 of 194. Elapsed: 0:00:23.\n", 346 | " Batch 160 of 194. Elapsed: 0:00:31.\n", 347 | "\n", 348 | " Average training loss: 0.64\n", 349 | " Training epcoh took: 0:00:37\n", 350 | "\n", 351 | "Running Validation...\n", 352 | " Accuracy: 0.83\n", 353 | " Validation Loss: 0.00\n", 354 | " Validation took: 0:00:03\n", 355 | "\n", 356 | "======== Epoch 2 / 3 ========\n", 357 | "Training...\n", 358 | " Batch 40 of 194. Elapsed: 0:00:08.\n", 359 | " Batch 80 of 194. Elapsed: 0:00:15.\n", 360 | " Batch 120 of 194. Elapsed: 0:00:23.\n", 361 | " Batch 160 of 194. Elapsed: 0:00:31.\n", 362 | "\n", 363 | " Average training loss: 0.28\n", 364 | " Training epcoh took: 0:00:37\n", 365 | "\n", 366 | "Running Validation...\n", 367 | " Accuracy: 0.85\n", 368 | " Validation Loss: 0.00\n", 369 | " Validation took: 0:00:03\n", 370 | "\n", 371 | "======== Epoch 3 / 3 ========\n", 372 | "Training...\n", 373 | " Batch 40 of 194. Elapsed: 0:00:08.\n", 374 | " Batch 80 of 194. Elapsed: 0:00:15.\n", 375 | " Batch 120 of 194. Elapsed: 0:00:23.\n", 376 | " Batch 160 of 194. Elapsed: 0:00:31.\n", 377 | "\n", 378 | " Average training loss: 0.11\n", 379 | " Training epcoh took: 0:00:38\n", 380 | "\n", 381 | "Running Validation...\n", 382 | " Accuracy: 0.84\n", 383 | " Validation Loss: 0.00\n", 384 | " Validation took: 0:00:03\n", 385 | "Saving model to ../transformers_output_sentiment/\n", 386 | "\n", 387 | "Training complete!\n", 388 | "Total training took 0:02:03 (h:mm:ss)\n" 389 | ] 390 | } 391 | ], 392 | "source": [ 393 | "model.train()" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "## Model Test" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": 13, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "prediction = model.test()" 410 | ] 411 | }, 412 | { 413 | "cell_type": "markdown", 414 | "metadata": {}, 415 | "source": [ 416 | "**Testing a sentence, takes a list of sentences.**" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 14, 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "data": { 426 | "text/plain": [ 427 | "[array([0.9197695 , 0.07576627, 0.48963723], dtype=float32)]" 428 | ] 429 | }, 430 | "execution_count": 14, 431 | "metadata": {}, 432 | "output_type": "execute_result" 433 | } 434 | ], 435 | "source": [ 436 | "model.text_inference(['Trading gives the best returns among all the options of investemnts'])" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 15, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "y_pred = [np.argmax(tup) for tup in prediction]\n", 446 | "y = df_test.label" 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": {}, 452 | "source": [ 453 | "## Accuracy and Report" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 16, 459 | "metadata": {}, 460 | "outputs": [ 461 | { 462 | "name": "stdout", 463 | "output_type": "stream", 464 | "text": [ 465 | "\n", 466 | "\n", 467 | "Accuracy: 0.866\n", 468 | "Precision: 0.845\n", 469 | "Recall: 0.838\n", 470 | "F1-Score: 0.842\n" 471 | ] 472 | }, 473 | { 474 | "data": { 475 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAEmCAYAAADr3bIaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XmcFPWd//HXe2aQ+xBGATlEBQlKFJGg0RgvNB4oiJLE28iKxuOna2Ji1I0Yc7jrRl3XHOK68b6F4KoRWddbAUERQeSQQy6FERDkcgY+vz+6Bluc6alBhu6B9/Px6Ed3VX276tMt8/ZbVV3fUkRgZma5FeW7ADOz+sBhaWaWgsPSzCwFh6WZWQoOSzOzFByWZmYpOCx3IJIaS/ofSZ9JevwbrOdMSc9vzdryRdJhkqbnuw4rfPLvLAuPpDOAK4FvAauAScDvIuK1b7jes4HLgEMiouIbF1rgJAXQLSJm5bsWq//csywwkq4EbgN+D7QFOgN/BgZshdXvDszYEYIyDUkl+a7B6pGI8KNAHkBL4HNgcI42DcmE6aLkcRvQMFl2BLAA+BmwBFgM/CRZdgPwBVCebGMIMAx4IGvdXYAASpLp84DZZHq3c4Azs+a/lvW+Q4C3gM+S50Oylr0E3Ai8nqzneaC0ms9WWf8vsuofCJwAzACWAddkte8LvAmsSNreAeyULHsl+Syrk8/7o6z1/xL4GLi/cl7ynr2SbfROpncDyoAj8v1vw4/8P9yzLCzfBRoBI3O0uRY4GOgF7E8mMK7LWt6OTOh2IBOIf5K0c0RcT6a3+mhENIuIu3MVIqkpcDtwfEQ0JxOIk6po1xp4JmnbBrgFeEZSm6xmZwA/AXYFdgJ+nmPT7ch8Bx2AXwN3AWcBBwKHAb+WtGfSdgPwz0Apme/uaOBigIj4ftJm/+TzPpq1/tZketlDszccER+SCdIHJTUB/gbcExEv5ajXdhAOy8LSBiiL3LvJZwK/iYglEbGUTI/x7Kzl5cny8oh4lkyvqvsW1rMR6CmpcUQsjoipVbQ5EZgZEfdHREVEPAx8AJyU1eZvETEjItYCj5EJ+uqUkzk+Ww48QiYI/yMiViXbnwrsBxAREyNibLLducCdwOEpPtP1EbE+qecrIuIuYCYwDmhP5n9OZg7LAvMpUFrDsbTdgHlZ0/OSeZvWsVnYrgGa1baQiFhNZtf1ImCxpGckfStFPZU1dcia/rgW9XwaERuS15Vh9knW8rWV75e0t6SnJX0saSWZnnNpjnUDLI2IdTW0uQvoCfxnRKyvoa3tIByWheVNYB2Z43TVWURmF7JS52TellgNNMmabpe9MCJGR8QxZHpYH5AJkZrqqaxp4RbWVBt/IVNXt4hoAVwDqIb35Pz5h6RmZI4D3w0MSw4zmDksC0lEfEbmON2fJA2U1ERSA0nHS/q3pNnDwHWSdpFUmrR/YAs3OQn4vqTOkloCv6pcIKmtpJOTY5fryezOb6hiHc8Ce0s6Q1KJpB8B+wBPb2FNtdEcWAl8nvR6f7rZ8k+APb/2rtz+A5gYEf9E5ljsX79xlbZdcFgWmIi4hcxvLK8DlgLzgUuBvydNfgtMACYD7wFvJ/O2ZFtjgEeTdU3kqwFXROas+iIyZ4gPJzl5stk6PgX6J20/JXMmu39ElG1JTbX0czInj1aR6fU+utnyYcC9klZI+mFNK5M0ADiOzKEHyPx36C3pzK1WsdVb/lG6mVkK7lmamaXgsDQzS8FhaWaWgsPSzCyFghpIoLS0WXTp0qbmhrbJotXl+S7BdgArFq9gzYo1Nf2GtVY6deoU69bVdH3Al8rKykZHxHFbs4baKKiw7NKlDRMm+Oqy2hg2fkt/j26W3vBzh2/1da5bt45Bgwalr2H48JquzqpT3g03M0vBYWlmloLD0swsBYelmVkKDkszsxQclmZmKTgszcxScFiamaXgsDSz7YKkuZLekzRJ0oRkXmtJYyTNTJ53TuZL0u2SZkmaLKl3Tet3WJrZ9uTIiOgVEX2S6auBFyKiG/BCMg1wPNAteQwlc4uSnByWZrY9GwDcm7y+ly/vbzUAuC8yxgKtJLXPtSKHpZnVF6WSJmQ9hm62PIDnJU3MWtY2IhYDJM+7JvM7kLllS6UFfPWOpF9TUANpmJnlUJa1e12VQyNikaRdgTGSPsjRtqoRlHLeY8c9SzPbLkTEouR5CTAS6At8Url7nTwvSZovADplvb0jNdxS2mFpZvWepKaSmle+Bo4FpgBPAecmzc4FRiWvnwLOSc6KHwx8Vrm7Xh3vhpvZ9qAtMFISZHLtoYh4TtJbwGOShgAfAYOT9s8CJwCzgDXAT2ragMPSzOq9iJgN7F/F/E+Bo6uYH8AltdmGd8PNzFJwWJqZpeCwNDNLwWFpZpaCw9LMLAWHpZlZCg5LM7MUHJZmZik4LM3MUnBYmpml4LA0M0vB14abWV603riC01ePqrlhYngd1pKGe5ZmZins8D3L+fOXcc45f+Pjj1dSVCSGDj2Myy8/mscfn8iwYf/DtGkfM3781fTp0wWA8ePnMHToAwBEwLBh/TnllAPy+AkKS9m8Mp649olN08sXLufIoUdy8OkH57GqwrVxw0buOu8umu/SnDNuOYNRvx3F4mmLCYI2ndow8NcD2anJTvku03BYUlJSzB//OJjevTuzatU6DjzwdxxzTA969tyNESMu4sILH/xK+549OzBhwjWUlBSzePFn7L//jZx00n6UlBTn6RMUltLdS7nogYuATBDc0v8WvnXEt/JcVeEa9+g4SruUsn71egCOu+I4GjZrCMDo20Yz/vHxfO/c7+WzREvs8Lvh7du3pHfvzgA0b96IHj3as3DhCnr0aE/37u2+1r5Jk502BeO6deWoqjt5GABz3ppD646tadW+Vb5LKUgrP1nJzNdn0nvAl7esrgzKiKB8fXnVd4qxvNjhe5bZ5s4t4513PuKgg/bI2W7cuDmcf/69zJu3jPvv/4l7ldWYMmYKPY/tme8yCtZztz5Hv0v78cWaL74yf9RvRjHzjZnssscu/ODyH+SpOttcnfYsJR0nabqkWZKurvkd+fP55+s49dQ7ue22H9KiReOcbQ86aA+mTh3GW2/9ij/84TnWrSvfRlXWHxvKNzD91ensc9Q++S6lIM14bQZNWzdltx67fW3ZgF8P4MpnrqR0j1KmjJmSh+qsKnUWlpKKgT8BxwP7AKdLKsi/nPLyDZx66p2ceWZfBg3qXfMbEj16tKdp052YMmVhHVZXP818Yybtu7enWZtm+S6lIH307kdMf2U6tw28jSeue4I5E+Yw4voRm5YXFRexb799mfbitDxWadnqcje8LzAruTcGkh4BBgDv1+E2ay0iGDLkPnr0aMeVVx5TY/s5c8ro1GlnSkqKmTfvU6ZP/4QuXUq3QaX1y5TnvQueS79L+tHvkn4AzJ04lzcefINThp3CsvnLaN2pNRHBjFdnULq7/20ViroMyw7A/KzpBcBBmzeSNBQYCtC5c+s6LKdqr7/+IfffP5Zvf7sDvXrdCMDvfz+Q9esruOyyR1i69HNOPPEOevXqxOjRl/Paa7O46abnaNCgmKIi8ec/n0FpqXtP2crXlTN7/Gz6/6p/vkupXwL+/pu/s371eiKCdt3aceIvTsx3VZZQ5iZndbBiaTDwg4j4p2T6bKBvRFxW3Xv69Nk9Jky4tk7q2V4NG5/zvvBmW8Xwc4ezaNqirXpuvnubBnHn8W1Stz/ywU8mRkSfrVlDbdTlCZ4FQKes6Y6A/7LNrF6qy7B8C+gmaQ9JOwE/Bp6qw+2ZmdWZOjtmGREVki4FRgPFwH9HxNS62p6ZWV2q0x+lR8SzwLN1uQ0zs21hh7/c0cwsDYelmVkKDkszsxQclmZmKTgszcxScFiamaXgsDQzS8FhaWaWgsPSzCwFh6WZbTckFUt6R9LTyfQeksZJminp0WScCiQ1TKZnJcu71LRuh6WZbU8uB7KHl/9X4NaI6AYsB4Yk84cAyyOiK3Br0i4nh6WZbRckdQROBP4rmRZwFFB5I/t7gYHJ6wHJNMnyo5P21XJYmll9USppQtZj6GbLbwN+AWxMptsAKyKiIpleQOYODpB1J4dk+WdJ+2r5VrhmVl+UVTdSuqT+wJKImCjpiMrZVTSNFMuq5LA0s7xosmwXDnh4885hLjfkWngocLKkE4BGQAsyPc1WkkqS3mP23Roq7+SwQFIJ0BJYlmsD3g03s3ovIn4VER0joguZuzL8X0ScCbwInJY0OxcYlbx+KpkmWf5/UcMNyRyWZrY9+yVwpaRZZI5J3p3Mvxtok8y/Eri6phV5N9zMtisR8RLwUvJ6NtC3ijbrgMG1Wa97lmZmKTgszcxScFiamaXgsDQzS8FhaWaWgsPSzCwFh6WZWQoOSzOzFByWZmYpOCzNzFJwWJqZpeCwNDNLwWFpZpaCw9LMLIWCGqJt0epyho1fVHND2+TwhbUaZcqAlzs8nu8SrB5yz9LMLAWHpZlZCg5LM7MUHJZmZik4LM3MUnBYmpml4LA0M0vBYWlmloLD0swsBYelmVkKBXW5o5ntOGbu0oTjfnRA+jfcUXe1pOGepZlZCg5LM7MUHJZmZik4LM3MUqj2BI+kFrneGBErt345ZmaFKdfZ8KlAAMqaVzkdQOc6rMvMrKBUG5YR0WlbFmJmVshSHbOU9GNJ1ySvO0o6sG7LMjMrLDWGpaQ7gCOBs5NZa4C/1mVRZmaFJs0VPIdERG9J7wBExDJJO9VxXWZmBSXNbni5pCIyJ3WQ1AbYWKdVmZnVgqRGksZLelfSVEk3JPP3kDRO0kxJj1Z29CQ1TKZnJcu71LSNNGH5J+BJYJekgNeAf93iT2VmtvWtB46KiP2BXsBxkg4mk1W3RkQ3YDkwJGk/BFgeEV2BW0mRaTWGZUTcB1wH/DuwDBgcEY9swYcxM6sTkfF5MtkgeQRwFPBEMv9eYGDyekAyTbL8aEnZP5P8mrRX8BQD5cAXtXiPmdnWVCppQtZjaPZCScWSJgFLgDHAh8CKiKhImiwAOiSvOwDzAZLlnwFtcm28xhM8kq4FzgBGkvlB+kOSHoyIP6T9hGZmW0FZRPSpbmFEbAB6SWpFJq96VNUsea6qFxlVzNskzdnws4ADI2INgKTfARMBh6WZFZyIWCHpJeBgoJWkkqT32BFYlDRbAHQCFkgqAVqSOcxYrTS71PP4aqiWALNrV76ZWd2RtEvSo0RSY6AfMA14ETgtaXYuMCp5/VQyTbL8/yJiy3qWkm4l0y1dA0yVNDqZPpbMGXEzs0LRHrhXUjGZTuBjEfG0pPeBRyT9FngHuDtpfzdwv6RZZHqUP65pA7l2w6ckz1OBZ7Lmj63dZzAzq1sRMRn42j0qImI20LeK+euAwbXZRq6BNO6ubpmZ2Y4mzdnwvYDfAfsAjSrnR8TedVhXwXjz4Td5Z9Q7IGi7V1sG/MsAShr6Pm+be/Lp+3lmzJMEwYn9TuO0k87mwzkfcOudN7J23Rra7rob117xrzRt0izfpRaMUTeOYsbrM2i6c1MufvhiAJ649gnK5pUBsO7zdTRq1oiLHrgon2VaIs1f/T3Ab8n8KP144CfsIJc7rlyykvGPjufiRy6mQaMGPH7N40wZM4Ve/Xvlu7SCMmfeTJ4Z8yR//reHaVDSgF/eeBEHH/h9/v3P13PReT9n/32/wz9eGMGjf/8b559xWb7LLRi9+vei7+C+jLxh5KZ5p/3utE2vR//HaBo1bVTVWy0P0pwNbxIRowEi4sOIuI7MKEQ7hI0bNlKxvoKNFRspX1dO89Lm+S6p4MxbOJt99t6PRg0bU1xcwv779OG1cS8wf9Fc9tsn87O4A/f/Lq+OHZPnSgvL7gfsTuMWjatcFhG8/7/v0/PYntu4KqtOmrBcn1wG9KGkiySdBOxax3UVhBa7tuC7Z36XWwfcyh9P/CONmjVir4P3yndZBWePzl2Z/P5EPlu1gnXr1zLu7VdZUvYxXTp35Y23XgTg5TeeZ0nZx3mutP74aNJHNG3dlDadc15UYttQmrD8Z6AZ8P+AQ4ELgPNrepOk/5a0RNKUmtoWqrUr1zL9lelcPvJyrnzmSr5Y+wWT/zE532UVnN077sWPTzmfq4ZdwC9vvIi9uuxNcXExv7jkRv7+j4e58Oc/ZM3a1TQoaZDvUuuN955/z73KAlPjMcuIGJe8XMWXAwCncQ9wB3Bf7csqDLPfmk2r3VrRdOemAPQ4sgfz35vPfsfvl+fKCs8J/U7lhH6nAvBfD9zGLm3a0bnjntx8/V0AzF80l7ETX8lnifXGxoqNfPDiBwy9d2jNjW2byfWj9JHkuFYyIgblWnFEvJJmjLhC1rJtSxZOWUj5unJKGpYw5605tO/RPt9lFaTlKz5l51Zt+GTpYl4d9wJ3/OGBTfM2btzIA4/fyck/+GG+y6wXZr81m9IupbRom/MGq7aN5epZ3rEtCkhGDhkK0LJdy22xydQ69uxIj6N6cOc5d1JUXET7vdtz4EDffqgqw27+Z1auWkFxcQmXX3AtzZu15Mmn72fUPzKj+X3v4H4cd9Qpea6ysDx53ZPMfXsua1as4Zb+t3DE0CPofXJvpoyZ4l3wAqQaLof8ZivP9CyfjohU/+V367FbeNejdg5fWKuLEAx4ucPj+S6h3hl+7nAWTVuUc7zH2urTZ/eYMOHa1O2lCyfmGnWornlsSjOzFByWZmYppA5LSQ1rs2JJDwNvAt0lLZA0pKb3mJkVqjTXhvclM5xRS6CzpP2Bf4qInNetRcTpW6dEM7P8S9OzvB3oD3wKEBHvsgNd7mhmBunCsigi5m02b0NdFGNmVqjSjDo0P9kVj2QU4suAGXVblplZYUnTs/wpcCXQGfiEzE2AflqXRZmZFZo014YvIcX9KczMtmdpzobfRRXXiEeEL7Uxsx1GmmOW/5v1uhFwCjC/bsoxMytMaXbDH82elnQ/4CGvzWyHsiWXO+4B7L61CzEzK2Rpjlku58tjlkVkbkh+dV0WZWZWaHKGZXLvnf2BhcmsjVGXY7qZmRWonLvhSTCOjIgNycNBaWY7pDTHLMdL6l3nlZiZFbBc9+ApiYgK4HvABZI+BFYDItPpdICa2Q4j1zHL8UBvYOA2qsXMrGDlCksBRMSH26gWM7OClSssd5F0ZXULI+KWOqjHzKwg5QrLYqAZSQ/TzGxHlissF0fEb7ZZJWZmBSzXT4fcozQzS+QKy6O3WRVmZgWu2rCMiGXbshAzsy0lqZOkFyVNkzRV0uXJ/NaSxkiamTzvnMyXpNslzZI0Oc2FN1sy6pCZWaGpAH4WET3I3PrmEkn7kBn054WI6Aa8wJeDAB0PdEseQ4G/1LQBh6WZ1XsRsTgi3k5erwKmAR2AAcC9SbN7+fIimwHAfZExFmglqX2ubaQZKd3MbOtbtYGNL39am3eUSpqQNT08IoZv3khSF+AAYBzQNiIWQyZQJe2aNOvAV+/4sCCZt7i6jTsszay+KIuIPrkaSGoGPAlcERErM6NMVt20ink5R1XzbriZbRckNSATlA9GxIhk9ieVu9fJ85Jk/gKgU9bbOwKLcq3fYWlm9V4yUPndwLTNLsV+Cjg3eX0uMCpr/jnJWfGDgc8qd9er491wM9seHAqcDbwnaVIy7xrgJuAxSUOAj4DBybJngROAWcAa4Cc1bcBhaWb1XkS8RvVXHX7tApvkrg+X1GYb3g03M0vBYWlmloLD0swsBR+zrOde7vB4vkuodxbfnfOkp1WhvKw83yXknXuWZmYpOCzNzFJwWJqZpeCwNDNLwWFpZpaCw9LMLAWHpZlZCg5LM7MUHJZmZik4LM3MUnBYmpml4LA0M0vBYWlmloLD0swsBYelmVkKDkszsxQclmZmKTgszcxScFiamaXge/CYWV6s2tCSl5edXIt3XFNntaThnqWZWQoOSzOzFByWZmYpOCzNzFJwWJqZpeCwNDNLwWFpZpaCw9LMLAWHpZlZCg5LM7MUfLljFTZu2Mhd591F812ac8YtZzD+8fGMfWQsyxcs56rRV9GkVZN8l1hQRt04ihmvz6Dpzk25+OGLv7LsjQfeYMx/jvH3lnjooYdo0KABRUVFSGLQoEGMHTuWefPmUVxcTIsWLTj88MNp2LAhGzdu5OWXX6asrIyIoFu3bhxwwAH5/gg7LIdlFcY9Oo7SLqWsX70egE77dWLvQ/fmnovvyW9hBapX/170HdyXkTeM/Mr8zz75jNnjZ9OyXcs8VVaYTjrpJBo1arRpumPHjvTt25eioiLGjRvHpEmTOOigg5g9ezYbNmxg8ODBVFRU8Nhjj9G1a1eaN2+ex+p3XN4N38zKT1Yy8/WZ9B7Qe9O89t3b02q3VnmsqrDtfsDuNG7R+GvzR986mn6X9gPloah6pGPHjhQVZf4Ud911V1avXr1pWUVFBRs3bqSiooLi4mIaNGiQrzILnqT/lrRE0pSsea0ljZE0M3neOZkvSbdLmiVpsqTe1a85w2G5medufY5+l/ZD8l/4NzH9lek036U57fZul+9SCooknnnmGUaMGMG0adO+tnz69Ol06tQJgD333JOSkhIeeOABHnroIfbbb7+v9Ejta+4Bjtts3tXACxHRDXghmQY4HuiWPIYCf6lp5XUWlpI6SXpR0jRJUyVdXlfb2lpmvDaDpq2bsluP3fJdSr1Wvq6cV+95lSMvPDLfpRSck08+mVNPPZXjjz+eqVOnsnjx4k3L3n77bYqKiujatSsAS5YsoaioiLPOOovTTz+dyZMns3LlynyVXvAi4hVg2WazBwD3Jq/vBQZmzb8vMsYCrSS1z7X+ujxmWQH8LCLeltQcmChpTES8X4fb/EY+evcjpr8ynZlvzKRifQXrV69nxPUjGHTDoHyXVq8sW7CM5YuW89ez/grAyiUrufOcO7ngbxfQrE2zPFeXX02bNgWgcePGdOnShSVLltC+fXtmzJjBRx99RP/+/Tft1cyaNWvTLnrjxo1p27YtS5cupUWLFvn8CPlUKmlC1vTwiBhew3vaRsRigIhYLGnXZH4HYH5WuwXJvMVUo87CMimwsshVkqYlxRRsWPa7pB/9LukHwNyJc3njwTcclFugbde2XPXcVZumbxt4G0PvGbrDnw0vLy8nIthpp50oLy9n4cKF9O7dm/nz5zNp0iROOukkSkq+/JNs1qwZixYtolu3blRUVLBkyRK+/e1v5/ET5F1ZRPTZSuuq6jhb5HrDNjkbLqkLcAAwbltsb2sb9+g4Xr//dT5f9jl/OfMvdDukGydfW5sRnrdvT173JHPfnsuaFWu4pf8tHDH0CHqfXOPx8h3O2rVref755wGICPbaay86derEI488woYNG3j22WeBzEmeww47jH333ZeXXnqJJ554goige/futGnTJp8foT76RFL7pFfZHliSzF8AdMpq1xFYlGtFisgZpt+YpGbAy8DvImJEFcuHkjnASst2LQ+8YtQVdVqP2eK7q93TsmqMGDGCpUuXbtWznt277ht/vfmx1O2PGtRzYk09y6Rj9nRE9EymbwY+jYibJF0NtI6IX0g6EbgUOAE4CLg9IvrmWnedng2X1AB4EniwqqAEiIjhEdEnIvrs6LtpZrblJD0MvAl0l7RA0hDgJuAYSTOBY5JpgGeB2cAs4C7g4ipW+RV1thuuzFHqu4FpEXFLXW3HzAwgIk6vZtHRVbQN4JLarL8ue5aHAmcDR0malDxOqMPtmZnVmbo8G/4avnbDzLYTvoLHzCwFh6WZWQoOSzOzFByWZmYpOCzNzFLw4L9mlherGizn5Q6P57uM1NyzNDNLwWFpZpaCw9LMLAWHpZlZCg5LM7MUHJZmZik4LM3MUnBYmpml4LA0M0vBYWlmloLD0swsBYelmVkKDkszsxQclmZmKTgszcxScFiamaXgsDQzS8FhaWaWgsPSzCwFh6WZWQoOSzOzFByWZmYpOCzNzFJwWJqZpeCwNDNLwWFpZpaCw9LMLAVFRL5r2ETSUmBevuuoQilQlu8i6hl/Z1umUL+33SNil625QknPkfm8aZVFxHFbs4baKKiwLFSSJkREn3zXUZ/4O9sy/t4Kl3fDzcxScFiamaXgsExneL4LqIf8nW0Zf28FyscszcxScM/SzCwFh6WZWQoOSzOzFByW1ZDUXdJ3JTWQVJzveuoLf1e1I6mrpD6SGua7FsvNJ3iqIGkQ8HtgYfKYANwTESvzWlgBk7R3RMxIXhdHxIZ811ToJPUn8+/sU+Bj4PrK79AKj3uWm5HUAPgRMCQijgZGAZ2AX0hqkdfiClTyRz9J0kMAEbHBPczcJB0C/DtwbkQcCSwHrs5vVZaLw7JqLYBuyeuRwNPATsAZkpS3qgqQpKbApcAVwBeSHgAHZko3RcQ7yevrgdbeHS9cDsvNREQ5cAswSNJhEbEReA2YBHwvr8UVoIhYDZwPPAT8HGiUHZj5rK3AjQNGwKbjvA2B3cn8jxpJbfJXmlXFYVm1V4HngbMlfT8iNkTEQ8BuwP75La3wRMSiiPg8IsqAC4HGlYEpqbekb+W3wsKT/JuqPAYuYAWwLCKWSjoT+K2kxvmr0DZXku8CClFErJP0IBDAr5I/9vVAW2BxXosrcBHxqaQLgZslfQAUA0fmuayCFhEVwOeS5kv6A3AscF5ErM1zaZbFYVmNiFgu6S7gfTK9pXXAWRHxSX4rK3wRUSZpMnA8cExELMh3TYUsOQ7eADgseT46ImbmtyrbnH86lEJyTCmS45dWA0k7A48BP4uIyfmup76QdB7wVkRMzXct9nUOS6sTkhpFxLp811GfSFL4D7JgOSzNzFLw2XAzsxQclmZmKTgszcxScFiamaXgsNxOSNogaZKkKZIel9TkG6zrCElPJ69PllTtAA+SWkm6eAu2MUzSz9PO36zNPZJOq8W2ukiaUtsazbI5LLcfayOiV0T0BL4ALspeqIxa//eOiKci4qYcTVoBtQ5Ls/rGYbl9ehXomvSopkn6M/A20EnSsZLelPR20gNtBiDpOEkfSHoNGFS5IknnSbojed1W0khJ7yaPQ4CbgL2SXu3NSburJL0labKkG7LWda2k6ZL+F+he04eQdEGynnclPblZb7mxW9bKAAACWUlEQVSfpFclzUiGiENSsaSbs7Z94Tf9Is0qOSy3M5JKyFxm+F4yqztwX0QcAKwGrgP6RURvMoMaXympEXAXcBKZS+7aVbP624GXI2J/oDcwlcwYjB8mvdqrJB1LZni7vkAv4EBJ35d0IPBj4AAyYfydFB9nRER8J9neNGBI1rIuwOHAicBfk88wBPgsIr6TrP8CSXuk2I5ZjXxt+PajsaRJyetXgbvJjJI0LyLGJvMPBvYBXk+G5dwJeBP4FjCn8nrkZMSgoVVs4yjgHNg0/NpnyaWN2Y5NHpXjNDYjE57NgZERsSbZxlMpPlNPSb8ls6vfDBidteyx5PLTmZJmJ5/hWGC/rOOZLZNte/Rx+8YcltuPtRHRK3tGEoirs2cBYyLi9M3a9SIzwtLWIOAPEXHnZtu4Ygu2cQ8wMCLeTa6bPiJr2ebrimTbl0VEdqgiqUstt2v2Nd4N37GMBQ6V1BVAUhNJewMfAHtI2itpd3o1738B+Gny3uLkNhuryPQaK40Gzs86FtpB0q7AK8ApkhpLak5ml78mzYHFytzq48zNlg2WVJTUvCcwPdn2T5P2SNpbmZHczb4x9yx3IMnAsucBD+vL2xdcFxEzJA0FnpFURmZk+J5VrOJyYLikIcAG4KcR8aak15Of5vwjOW7ZA3gz6dl+TmZou7clPUpmxPl5ZA4V1ORfyIwoPo/MMdjsUJ4OvExmjNGLkjFI/4vMscy3k2HPlgID0307Zrl5IA0zsxS8G25mloLD0swsBYelmVkKDkszsxQclmZmKTgszcxScFiamaXw/wF8AB/JosKFdgAAAABJRU5ErkJggg==\n", 476 | "text/plain": [ 477 | "
" 478 | ] 479 | }, 480 | "metadata": {}, 481 | "output_type": "display_data" 482 | } 483 | ], 484 | "source": [ 485 | "metrics(y, y_pred)" 486 | ] 487 | } 488 | ], 489 | "metadata": { 490 | "kernelspec": { 491 | "display_name": "conda_pytorch_p36", 492 | "language": "python", 493 | "name": "conda_pytorch_p36" 494 | }, 495 | "language_info": { 496 | "codemirror_mode": { 497 | "name": "ipython", 498 | "version": 3 499 | }, 500 | "file_extension": ".py", 501 | "mimetype": "text/x-python", 502 | "name": "python", 503 | "nbconvert_exporter": "python", 504 | "pygments_lexer": "ipython3", 505 | "version": "3.6.5" 506 | } 507 | }, 508 | "nbformat": 4, 509 | "nbformat_minor": 2 510 | } 511 | -------------------------------------------------------------------------------- /sample_notebooks/TF_classi_transformers_BBCnews_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Tensorflow Framework for Multi-class Text Classification" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## BBC News Classification - 5 Class Problem\n", 15 | "\n", 16 | "* Consists of 2225 documents from the BBC news website corresponding to stories in five topical areas from 2004-2005.\n", 17 | "* Class Labels: 5
\n", 18 | "\n", 19 | " 0 - Business
\n", 20 | " 1 - Sport
\n", 21 | " 2 - Politics
\n", 22 | " 3 - Tech
\n", 23 | " 4 - Entertainment " 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Reading Data" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 1, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "import sys\n", 40 | "import numpy as np\n", 41 | "import pandas as pd\n", 42 | "\n", 43 | "sys.path.append(\"..\")\n", 44 | "pd.set_option('display.max_colwidth', -1)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "((1424, 3), (445, 2), (356, 3))" 56 | ] 57 | }, 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "output_type": "execute_result" 61 | } 62 | ], 63 | "source": [ 64 | "df_train = pd.read_csv('../datasets/bbcfull/train.csv')\n", 65 | "df_test = pd.read_csv('../datasets/bbcfull/test.csv')\n", 66 | "df_score = pd.read_csv('../datasets/bbcfull/test_with_label.csv')\n", 67 | "df_val = pd.read_csv('../datasets/bbcfull/dev.csv')\n", 68 | "df_train.shape, df_test.shape, df_val.shape" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "1 327\n", 80 | "0 326\n", 81 | "2 267\n", 82 | "3 257\n", 83 | "4 247\n", 84 | "Name: label, dtype: int64" 85 | ] 86 | }, 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "df_train.label.value_counts()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "text/plain": [ 104 | "1 82\n", 105 | "0 82\n", 106 | "2 66\n", 107 | "3 64\n", 108 | "4 62\n", 109 | "Name: label, dtype: int64" 110 | ] 111 | }, 112 | "execution_count": 4, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "df_val.label.value_counts()" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 5, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "text/plain": [ 129 | "1 102\n", 130 | "0 102\n", 131 | "2 84 \n", 132 | "3 80 \n", 133 | "4 77 \n", 134 | "Name: label, dtype: int64" 135 | ] 136 | }, 137 | "execution_count": 5, 138 | "metadata": {}, 139 | "output_type": "execute_result" 140 | } 141 | ], 142 | "source": [ 143 | "df_score.label.value_counts()" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "## Training Bert/Electra classifier with FrameWork" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 6, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/__init__.py:1467: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.\n", 163 | "\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "from classitransformers.report import metrics\n", 169 | "from classitransformers.configs import Configs\n", 170 | "\n", 171 | "from classitransformers.tfelectra import ElectraClassification #(for using Electra)\n", 172 | "\n", 173 | "#from classitransformers.tfbert import BertClassification" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "## Set Configurations\n", 181 | "\n", 182 | "* Pre-trained folderpath\n", 183 | "* labels list\n", 184 | "* train, test and eval batch size\n", 185 | "* max_seq_len - how long is the each text input sentence !!\n", 186 | "* data_dir - data path folder (optional)\n", 187 | "* output_dir - output path folder (optional)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 7, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "================================================================================\n", 200 | "Config: model=electra, trial 1/1\n", 201 | "================================================================================\n", 202 | "data_dir ../datasets/bbcfull/\n", 203 | "do_eval True\n", 204 | "do_lower_case True\n", 205 | "do_predict False\n", 206 | "do_train True\n", 207 | "embedding_size None\n", 208 | "eval_batch_size 8\n", 209 | "gcp_project None\n", 210 | "init_checkpoint ../models/Electra_base/\n", 211 | "iterations_per_loop 1000\n", 212 | "keep_all_models True\n", 213 | "label_list ['0', '1', '2', '3', '4']\n", 214 | "layerwise_lr_decay 0.8\n", 215 | "learning_rate 5e-05\n", 216 | "log_examples False\n", 217 | "master None\n", 218 | "max_seq_length 256\n", 219 | "model_hparam_overrides {}\n", 220 | "model_name electra\n", 221 | "model_size base\n", 222 | "num_tpu_cores 1\n", 223 | "num_train_epochs 3\n", 224 | "num_trials 1\n", 225 | "output_dir ../electra_output_bbcfull/\n", 226 | "predict_batch_size 8\n", 227 | "pretrained_model_dir ../models/Electra_base/\n", 228 | "save_checkpoints_steps 1000000\n", 229 | "task_names ['textclassification']\n", 230 | "tasks {'textclassification': {'type': 'classification', 'labels': ['0', '1', '2', '3', '4'], 'header': True, 'text_column': 1, 'label_column': 2}}\n", 231 | "tpu_job_name None\n", 232 | "tpu_name None\n", 233 | "tpu_zone None\n", 234 | "train_batch_size 16\n", 235 | "use_tfrecords_if_existing False\n", 236 | "use_tpu False\n", 237 | "vocab_file ../models/Electra_base/vocab.txt\n", 238 | "vocab_size 30522\n", 239 | "warmup_proportion 0.1\n", 240 | "weight_decay_rate 0.01\n", 241 | "\n", 242 | "Loading dataset textclassification_train\n", 243 | "Existing tfrecords not found so creating\n", 244 | "Writing example 0 of 1424\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "# config = Configs(pretrained_model_dir = '../models/Bert_base',\n", 250 | "# model_name = 'bert'\n", 251 | "# num_train_epochs = 3,\n", 252 | "# train_batch_size = 16, \n", 253 | "# eval_batch_size = 8, \n", 254 | "# predict_batch_size = 8, \n", 255 | "# do_train = True, \n", 256 | "# do_eval = True, \n", 257 | "# label_list = [\"0\", \"1\", \"2\", \"3\", \"4\"], \n", 258 | "# max_seq_length = 256, \n", 259 | "# data_dir='../datasets/bbcfull/',\n", 260 | "# output_dir = '../bert_output_bbcfull/')\n", 261 | "\n", 262 | "\n", 263 | "#model = BertClassification(config)\n", 264 | "\n", 265 | "\n", 266 | "config = Configs(pretrained_model_dir = '../models/Electra_base/',\n", 267 | " model_name = 'electra',\n", 268 | " num_train_epochs=3, \n", 269 | " train_batch_size = 16,\n", 270 | " eval_batch_size = 8, \n", 271 | " do_train = True, \n", 272 | " do_eval = True, \n", 273 | " label_list = [\"0\", \"1\", \"2\", \"3\", \"4\"],\n", 274 | " max_seq_length = 256,\n", 275 | " data_dir = \"../datasets/bbcfull/\", \n", 276 | " output_dir = \"../electra_output_bbcfull/\")\n", 277 | "\n", 278 | "model = ElectraClassification(config) " 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": {}, 284 | "source": [ 285 | "## Train Method" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 8, 291 | "metadata": { 292 | "scrolled": true 293 | }, 294 | "outputs": [ 295 | { 296 | "name": "stdout", 297 | "output_type": "stream", 298 | "text": [ 299 | "================================================================================\n", 300 | "Start training\n", 301 | "================================================================================\n", 302 | "Training for 267 steps\n", 303 | "Building model...\n", 304 | "Building complete\n", 305 | "10/267 = 3.7%, SPS: 0.3, ELAP: 32, ETA: 13:30 - loss: 24.9692\n", 306 | "20/267 = 7.5%, SPS: 0.4, ELAP: 49, ETA: 10:00 - loss: 22.0914\n", 307 | "30/267 = 11.2%, SPS: 0.5, ELAP: 1:06, ETA: 8:38 - loss: 10.4456\n", 308 | "40/267 = 15.0%, SPS: 0.5, ELAP: 1:23, ETA: 7:49 - loss: 4.9485\n", 309 | "50/267 = 18.7%, SPS: 0.5, ELAP: 1:40, ETA: 7:13 - loss: 11.9883\n", 310 | "60/267 = 22.5%, SPS: 0.5, ELAP: 1:57, ETA: 6:44 - loss: 1.3170\n", 311 | "70/267 = 26.2%, SPS: 0.5, ELAP: 2:14, ETA: 6:17 - loss: 3.1973\n", 312 | "80/267 = 30.0%, SPS: 0.5, ELAP: 2:31, ETA: 5:54 - loss: 0.4818\n", 313 | "90/267 = 33.7%, SPS: 0.5, ELAP: 2:48, ETA: 5:31 - loss: 4.8838\n", 314 | "100/267 = 37.5%, SPS: 0.5, ELAP: 3:05, ETA: 5:10 - loss: 0.1226\n", 315 | "110/267 = 41.2%, SPS: 0.5, ELAP: 3:22, ETA: 4:49 - loss: 13.2960\n", 316 | "120/267 = 44.9%, SPS: 0.5, ELAP: 3:40, ETA: 4:29 - loss: 5.2097\n", 317 | "130/267 = 48.7%, SPS: 0.5, ELAP: 3:57, ETA: 4:09 - loss: 1.4426\n", 318 | "140/267 = 52.4%, SPS: 0.6, ELAP: 4:14, ETA: 3:50 - loss: 0.0053\n", 319 | "150/267 = 56.2%, SPS: 0.6, ELAP: 4:31, ETA: 3:31 - loss: 0.0025\n", 320 | "160/267 = 59.9%, SPS: 0.6, ELAP: 4:48, ETA: 3:13 - loss: 0.3520\n", 321 | "170/267 = 63.7%, SPS: 0.6, ELAP: 5:06, ETA: 2:54 - loss: 0.0028\n", 322 | "180/267 = 67.4%, SPS: 0.6, ELAP: 5:23, ETA: 2:36 - loss: 0.0010\n", 323 | "190/267 = 71.2%, SPS: 0.6, ELAP: 5:40, ETA: 2:18 - loss: 0.0012\n", 324 | "200/267 = 74.9%, SPS: 0.6, ELAP: 5:57, ETA: 2:00 - loss: 0.0020\n", 325 | "210/267 = 78.7%, SPS: 0.6, ELAP: 6:14, ETA: 1:42 - loss: 0.0011\n", 326 | "220/267 = 82.4%, SPS: 0.6, ELAP: 6:31, ETA: 1:24 - loss: 6.7390\n", 327 | "230/267 = 86.1%, SPS: 0.6, ELAP: 6:49, ETA: 1:06 - loss: 0.0008\n", 328 | "240/267 = 89.9%, SPS: 0.6, ELAP: 7:06, ETA: 48 - loss: 0.0013\n", 329 | "250/267 = 93.6%, SPS: 0.6, ELAP: 7:23, ETA: 30 - loss: 0.0005\n", 330 | "260/267 = 97.4%, SPS: 0.6, ELAP: 7:40, ETA: 12 - loss: 1.2932\n", 331 | "================================================================================\n", 332 | "Training Finished\n", 333 | "================================================================================\n", 334 | "================================================================================\n", 335 | "Start Evaluation\n", 336 | "================================================================================\n", 337 | "Evaluating textclassification\n", 338 | "Loading dataset textclassification_dev\n", 339 | "Existing tfrecords not found so creating\n", 340 | "Writing example 0 of 356\n", 341 | "Building model...\n", 342 | "Building complete\n", 343 | "textclassification: accuracy: 97.47 - loss: 0.10\n", 344 | "\n" 345 | ] 346 | } 347 | ], 348 | "source": [ 349 | "model.train()" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": {}, 355 | "source": [ 356 | "## Test Methods" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 9, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "name": "stdout", 366 | "output_type": "stream", 367 | "text": [ 368 | "Writing out predictions for [Task(textclassification)] test\n", 369 | "Loading dataset textclassification_test\n", 370 | "Existing tfrecords not found so creating\n", 371 | "Writing example 0 of 445\n", 372 | "Building model...\n", 373 | "Building complete\n", 374 | "Getting predictions for 445 textclassification examples (test)\n" 375 | ] 376 | } 377 | ], 378 | "source": [ 379 | "prediction = model.test()" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": {}, 385 | "source": [ 386 | "**Testing a sentence, takes a list of sentences.**" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 10, 392 | "metadata": {}, 393 | "outputs": [ 394 | { 395 | "name": "stdout", 396 | "output_type": "stream", 397 | "text": [ 398 | "Writing out predictions for [Task(textclassification)] infer\n", 399 | "Loading dataset textclassification_infer\n", 400 | "Existing tfrecords not found so creating\n", 401 | "Writing example 0 of 1\n", 402 | "Building model...\n", 403 | "Building complete\n", 404 | "Getting predictions for 1 textclassification examples (infer)\n" 405 | ] 406 | }, 407 | { 408 | "data": { 409 | "text/plain": [ 410 | "[array([0.81033325, 0.03606391, 0.0256805 , 0.99965596, 0.10343838],\n", 411 | " dtype=float32)]" 412 | ] 413 | }, 414 | "execution_count": 10, 415 | "metadata": {}, 416 | "output_type": "execute_result" 417 | } 418 | ], 419 | "source": [ 420 | "model.text_inference(['He is very good in writing a book on financial investment money making strategy. There are some transactional charges but the profit we make is awesome. The business people are exchanging bulk money over trades happening every seconds. Even retailers are happy doing trade and adding sum to their investments.'])" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": {}, 426 | "source": [ 427 | "## Export the checkpoint model to pb (protobuf) format\n", 428 | "\n", 429 | "**In TensorFlow, the protbuf file contains the graph definition as well as the weights of the model. Thus, a pb file is all you need to be able to run a given trained model.**\n", 430 | "\n", 431 | "Only support of Bert for now" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 11, 437 | "metadata": { 438 | "scrolled": true 439 | }, 440 | "outputs": [], 441 | "source": [ 442 | "# Currently does not have support for Electra\n", 443 | "\n", 444 | "#model.export_model()" 445 | ] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "metadata": {}, 450 | "source": [ 451 | "## Inferencing from Trained Pb Model\n", 452 | "\n", 453 | "* This tensorflow format is used for deployment purpose." 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 12, 459 | "metadata": {}, 460 | "outputs": [], 461 | "source": [ 462 | "# Not supported for Electra\n", 463 | "\n", 464 | "# predictions = model.inference(test_filename = '../datasets/bbcfull/test.csv', batch_size = 256)" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 13, 470 | "metadata": {}, 471 | "outputs": [], 472 | "source": [ 473 | "y_pred = [np.argmax(tup) for tup in prediction]\n", 474 | "y = df_score.label" 475 | ] 476 | }, 477 | { 478 | "cell_type": "markdown", 479 | "metadata": {}, 480 | "source": [ 481 | "## Accuracy and Report" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": 14, 487 | "metadata": {}, 488 | "outputs": [ 489 | { 490 | "name": "stdout", 491 | "output_type": "stream", 492 | "text": [ 493 | "\n", 494 | "\n", 495 | "Accuracy: 0.975\n", 496 | "Precision: 0.975\n", 497 | "Recall: 0.975\n", 498 | "F1-Score: 0.975\n" 499 | ] 500 | }, 501 | { 502 | "data": { 503 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAEmCAYAAADr3bIaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8VfWd//HXJwmEAAl7WMImyubyE6OlIrii/hAVLZXWpQ52aGmtVqfWGa3t/NTWTtufv5+t3caitjqDVelUq2NRS7VIcaGKxQVExAUQwg4mLAmQ+5k/zglekCQnl9ycey7v5+NxH9x7zrnn+87l5pPv2b7H3B0REWlaQdwBRESSQMVSRCQCFUsRkQhULEVEIlCxFBGJQMVSRCQCFctDiJmVmNl/m9lHZva7g1jPZWb2p9bMFhczO9nM3o47h+Q+03mWucfMLgWuA0YANcAi4PvuPv8g13s58HXgJHffc9BBc5yZOTDU3ZfHnUWSTz3LHGNm1wE/Af4N6A0MBH4JXNAKqx8ELDsUCmUUZlYUdwZJEHfXI0ceQBdgGzCliWWKCYrpmvDxE6A4nHca8CHwTWA9UAV8MZx3K7AL2B22MQ24BZiZtu7BgANF4esrgPcIerfvA5elTZ+f9r6TgJeBj8J/T0qbNxf4HvB8uJ4/AT0b+dka8v9LWv4LgYnAMmAzcFPa8qOBF4Gt4bI/B9qH8+aFP8v28Of9fNr6bwDWAv/ZMC18z+FhG5Xh637ABuC0uL8besT/UM8yt4wBOgCPNrHMt4ETgVHAsQQF4ztp8/sQFN0KgoL4CzPr5u43E/RWH3b3zu5+b1NBzKwT8FPgHHcvJSiIiw6wXHfgj+GyPYA7gD+aWY+0xS4FvgiUA+2B65toug/BZ1AB/B/gbuALwPHAycC/mtlh4bL1wDeAngSf3XjgawDufkq4zLHhz/tw2vq7E/Syp6c37O7vEhTSmWbWEfgNcL+7z20irxwiVCxzSw9goze9mXwZ8F13X+/uGwh6jJenzd8dzt/t7rMJelXDM8yTAo42sxJ3r3L3xQdY5lzgHXf/T3ff4+4PAkuB89OW+Y27L3P3ncAsgkLfmN0E+2d3Aw8RFMI73b0mbH8JwR8J3H2hu78UtvsB8Cvg1Ag/083uXhfm2Ye73w0sBxYAfQn+OImoWOaYTUDPZval9QNWpL1eEU7bu479iu0OoHNLg7j7doJN168CVWb2RzMbESFPQ6aKtNdrW5Bnk7vXh88bitm6tPk7G95vZsPM7AkzW2tm1QQ9555NrBtgg7vXNrPM3cDRwM/cva6ZZeUQoWKZW14E6gj20zVmDcEmZIOB4bRMbAc6pr3ukz7T3Z9297MIelhLCYpIc3kaMq3OMFNL/DtBrqHuXgbcBFgz72ny9A8z60ywH/he4JZwN4OIimUucfePCPbT/cLMLjSzjmbWzszOMbP/Gy72IPAdM+tlZj3D5Wdm2OQi4BQzG2hmXYBvNcwws95mdkG477KOYHM+dYB1zAaGmdmlZlZkZp8HjgSeyDBTS5QC1cC2sNd75X7z1wFDWrjOO4FX3P1LBPti7zrolJIXVCxzjLv/f4JzLL9DcCR2FXA18IdwkduAV4DXgTeAV8NpmbQ1B3g4XNdC9i1wBWGONQRHiE/lk8UId98EnEdwBH4TwZHs89x9YyaZWuh6goNHNQS93of3m38LcL+ZbTWzzzW3MjO7AJjAxz/ndUClmV3WaoklsXRSuohIBOpZiohEoGIpIhKBiqWISAQqliIiEeTUQAIdOnTw0tLSuGO0WLue7eKOIJJVW6u2smPrjubOYW2RAQMGeG1tc9cHfGzjxo1Pu/uE1szQEjlVLEtLS5k8eXLcMVqs77S+cUcQyaoZU2e0+jpra2tb9Ps+Y8aM5q7OyipthotIXjCzX5vZejN7M21adzObY2bvhP92C6ebmf3UzJab2etmVtnc+lUsRSRf3EdwUUG6G4Fn3H0o8Ez4GuAcYGj4mE5w6WyTVCxFJC+4+zyCq83SXQDcHz6/n4/HXbgA+A8PvAR0NbMm96epWIpIUvQ0s1fSHtObfwu93b0qfL6W4O4DEIyKtSptuQ/Zd6SsT8ipAzwiIk3Y6O4nZPpmd/fwvkwZUc9SRPLZuobN6/Df9eH01cCAtOX608ywgiqWIpLPHgemhs+nAo+lTf+H8Kj4icBHaZvrB6TNcBHJC2b2IMEN6Hqa2YfAzcAPgVlmNo1gBP+GofpmE9wIbznB6P1fbG79KpYikhfc/ZJGZo0/wLIOXNWS9WszXEQkAhVLEZEIVCxFRCJQsRQRiUDFUkQkgrw8Gv7GG2+wdOlSAEaMGMExxxzDxo0bmT9/PvX19ZgZ48aNo7y8POakTVv+4nKeuuMpUqkUlZMqGTd1XNyRmvXRuo/4wy1/YNvmbZgZlRdWcuLFJ8Ydq1lJzd0gid+VpMm7Yrl582aWLl3KZz7zGQoKCnjyyScZOHAgCxYsoLKykoEDB7Jy5UoWLFjA+eefH3fcRqXqU8y+fTaX/+xyysrLuPuKuxl+8nB6DekVd7QmFRQWcPa1Z9N3RF/qttcxY+oMDh99uHJnUVK/K0mTd5vhW7dupby8nKKiIgoKCujbty/vv/8+Zsbu3bsB2LVrFx07dow5adNWL1lN9/7d6VbRjcJ2hRx11lEsnbc07ljNKu1ZSt8RweAtxZ2K6TW4F9UbqmNO1byk5obkfleSJu+KZbdu3Vi7di21tbXs2bOHlStXsn37dsaMGcNLL73EAw88wEsvvcTo0aPjjtqkmvU1lPUu2/u6rLyMmg01MSZqua1rtlK1rIr+R/WPO0qLJC13PnxXkiCrm+FmNgG4EygE7nH3H2azPQiK5bHHHsvs2bMpKiqiR48emBlLlixhzJgxDBkyhHfffZd58+Zx7rnnZjvOIWvXjl3MunEWE74xgeLOxXHHiSypuSX7stazNLNC4BcEIxIfCVxiZkdmq710I0aMYPLkyUyaNIni4mK6dOnCsmXLOOywwwAYMmQI69evb2Yt8SotL6V63cebgdXrqyntlYybudXvqWfWjbM4ZsIxjDx9ZNxxIktq7iR/V5Ikm5vho4Hl7v6eu+8CHiIYnTjrdu7cCcC2bdt4//33OeKII+jUqRNVVcGgImvWrKFLly5tESVjFSMr2LRqE1vWbKF+dz2L5yxm+CnD447VLHfn8dsep+fgnoy5dEzccSJLam5I7nclabK5GX6gkYg/ncX29pozZw61tbUUFBQwbtw4iouLOeWUU3jhhRdIpVIUFhZy8sknt0WUjBUUFTDx+onMvGYmnnJGnT+K8iG5faoTwKrXVvH6k69TfkQ5d33hLgDGXzmeoWOHxpysaUnNDcn9riRN7KcOhUPDTwfo3Llzq6xz0qRJn5jWp0+fxN1md+jYoYn4ZU03cNRAbl5wc9wxWiypuRsk8buSNNncDI80ErG7z3D3E9z9hA4dOmQxjohI5rJZLF8GhprZYWbWHriYYHRiEZHEydpmuLvvMbOrgacJTh36tbsvzlZ7IiLZlNV9lu4+m2D4dhGRRMu7K3hERLIh9qPhInJo6p7ayiXbH2t+wdCMLGaJQj1LEZEIVCxFRCJQsRQRiUDFUkQkAhVLEZEIVCxFRCJQsRQRiUDFUkQkAhVLEZEIVCxFRCJQsRQRiUDFUkQkAhVLEZEIVCxFRCJQsRQRiSCnxrNs17Mdfaf1jTtGi1XdWxV3hIwl8fMWiYN6liKSF8zsG2a22MzeNLMHzaxDeMPEBWa23MweDm+emBEVSxFJPDOrAK4BTnD3owlukngx8CPgx+5+BLAFmJZpGyqWIpIvioASMysCOgJVwBnAf4Xz7wcuzHTlKpYikhQ9zeyVtMf0hhnuvhr4f8BKgiL5EbAQ2Orue8LFPgQqMm08pw7wiIg0YaO7n3CgGWbWDbgAOAzYCvwOmNCajatnKSL54EzgfXff4O67gUeAsUDXcLMcoD+wOtMGVCxFJB+sBE40s45mZsB4YAnwF+CicJmpQPR77+5HxVJEEs/dFxAcyHkVeIOgts0AbgCuM7PlQA/g3kzb0D5LEckL7n4zcPN+k98DRrfG+tWzFBGJQMVSRCQCFUsRkQi0z1JEYtFxcy+Oe3B68wvudWvWskShnqWISAQqliIiEahYiohEkPf7LJe/uJyn7niKVCpF5aRKxk0dF3ekfcydO5eVK1dSUlLClClTAKitreWZZ56hpqaG0tJSzjzzTIqLi3nnnXd47bXXcHfat2/PuHHj6NGjR8w/wSfl+mfemKTmhmRnT4q87lmm6lPMvn02l/3kMq566Cre/NObbHhvQ9yx9jF8+HAmTpy4z7RFixZRUVHBxRdfTEVFBYsWLQKgtLSU888/nylTpnDccccxb968OCI3KQmf+YEkNTckO3uS5HWxXL1kNd37d6dbRTcK2xVy1FlHsXTe0rhj7aNv374UFxfvM23FihUMGzYMgGHDhvHBBx8A0KdPn73L9u7dm+3bt7dp1iiS8JkfSFJzQ7KzJ0leF8ua9TWU9S7b+7qsvIyaDTUxJopm586ddOzYEYCSkhJ27tz5iWWWLl3KgAED2jpas5L6mSc1NyQ7e5JkrVia2a/NbL2ZvZmtNg4FwQAq+1qzZg1vv/02n/70p2NIJHJoymbP8j5aefDNliotL6V6XfXe19XrqyntVRpjomhKSkrYsWMHADt27KCkpGTvvE2bNvHcc89x9tln06FDh7giNiqpn3lSc0OysydJ1oqlu88DNmdr/VFUjKxg06pNbFmzhfrd9Syes5jhpwyPM1IkgwYNYtmyZQAsW7aMQYMGAbBt2zbmzJnD6aefTteuXeOM2KikfuZJzQ3Jzp4ksZ86FN5HYzpAlz5dWnXdBUUFTLx+IjOvmYmnnFHnj6J8SHmrtnGwnnnmGdasWUNtbS0PPPAAxx9/PKNGjeLPf/4zS5cupbS0lPHjxwOwcOFCamtref7554FgE33y5Mlxxv+EJHzmB5LU3JDs7Eli7p69lZsNBp4Ib03ZrH4j+/n0+1tyrWhuqLq3Ku4IGes7rW/cESQBZkydwZq31nxyB/pBOM76+dyC6L/vXVO3LmzsHjxtIa+PhouItBYVSxGRCLJ56tCDwIvAcDP70MymZastEZFsy9oBHne/JFvrFhFpa9oMFxGJQMVSRCQCFUsRkQhULEVEIlCxFBGJQMVSRCQCFUsRkQhULEVEIlCxFBGJQMVSRCQCFUsRkQhULEVEIoh9pHQROTS906sjEz5/XPQ3/Dx7WaJQz1JEJAIVSxGRCFQsRUQiULEUEYmg0QM8ZlbW1Bvdvbqp+YeSJN8h8bSf/nvcETIy95or444gOcbMugL3AEcDDvwj8DbwMDAY+AD4nLtvyWT9TR0NXxw2mH77y4bXDgzMpEERkSy5E3jK3S8ys/ZAR+Am4Bl3/6GZ3QjcCNyQycobLZbuPiCTFYqItDUz6wKcAlwB4O67gF1mdgFwWrjY/cBcMiyWkfZZmtnFZnZT+Ly/mR2fSWMiIgehp5m9kvaYnjbvMGAD8Bsz+7uZ3WNmnYDe7l4VLrMW6J1p482elG5mPwfaEVTtfwN2AHcBn8q0URGRDGx09xMamVcEVAJfd/cFZnYnwSb3Xu7uZuaZNh6lZ3mSu38FqA0b3Ay0z7RBEZEs+BD40N0XhK//i6B4rjOzvgDhv+szbSBKsdxtZgUEB3Uwsx5AKtMGRURam7uvBVaZ2fBw0nhgCfA4MDWcNhV4LNM2olwb/gvg90AvM7sV+Bxwa6YNiohkydeBB8Ij4e8BXyToEM4ys2nACoL6lZFmi6W7/4eZLQTODCdNcfc3M21QRCQb3H0RcKB9muNbY/1RRx0qBHYTbIrrqh8ROeQ0W/jM7NvAg0A/oD/wWzP7VraDiYjkkig9y38AjnP3HQBm9n3g78APshlMRCSXRNmkrmLfoloUThMROWQ0NZDGjwn2UW4GFpvZ0+Hrs4GX2yaeiEhuaGozvOGI92Lgj2nTX8peHBGR3NTUQBr3tmUQEZFcFuXa8MOB7wNHAh0aprv7sCzmalWp+hR3X3E3pb1KufSOS+OOE9nyF5fz1B1PkUqlqJxUybip4+KO1KjfvbWDP767EwOGdC3ihjFlbNqZ4rvzP6K6zhnWvYibTiqjXaE1u664JOnz3l+SsydFlAM89wG/IRjH8hxgFsFgmomx4OEF9BzcM+4YLZKqTzH79tlc9pPLuOqhq3jzT2+y4b0Nccc6oA076nnk7R38akJ3fnNeD+odnv2gll/9fRtTRnTkgQt6UNremP3uzrijNipJn/f+kpw9SaIUy47u/jSAu7/r7t8hKJqJUL2umneef4fKCyrjjtIiq5espnv/7nSr6EZhu0KOOusols5bGnesRtU71NU79Smnbo/TvaSAv6/bxakDiwH430NKmP/hrphTNi5pn3e6JGdPkijFsi4cSONdM/uqmZ0PlGY5V6t56sdPcebVZ2KWu5t/B1Kzvoay3h/f2aOsvIyaDTUxJmpcr46FfG5kRz7/h0189pGNdGpvDO/ejs7tCigssHCZAjbuqI85aeOS9HnvL8nZkyRKsfwG0Am4BhgLfJng3hZNMrMBZvYXM1tiZovN7NqDi9pyy+Yvo1P3TvQb2a+tmz6k1NSleOHDOh68oAf/NbkntXucv1Xlbi9SJBNRBtJoGB+uBri8BeveA3zT3V81s1JgoZnNcfclGeTMyMrXVvL2vLd554V32FO3h7rtdTxy8yNMvnVyW0XIWGl5KdXrPr4nXPX6akp75WaHfuHaXfTpXEjXDsHf3pMHFPPmht1s252iPuUUFhgbdqTo2bEw5qSNS9Lnvb8kZ0+Spk5Kf5RwDMsDcfcmK044lHtV+LzGzN4CKgjGmGsTZ151JmdeFQyW9MHCD3jhgRcSUSgBKkZWsGnVJras2UJZrzIWz1nM5O/lZvbyToUs2bib2j1OcSG8unY3w3sUcVzv9jy3so4zBnfg6fd2MrZ/7o4ZnaTPe39Jzp4kTfUsf95ajZjZYOA4YMEB5k0HpgN06dOltZpMvIKiAiZeP5GZ18zEU86o80dRPqQ87lgHdGTPdpw6sJjpT26m0GBotyLOO6KEE/sV873nP+Le17YztHsREw8viTtqo5L0ee8vydmTxNwzviVFtAbMOgPPAd9390eaWrbfyH4+/f7pTS0irUz3DZcoZkydwZq31rTqUdITThjkr7zy7cjLm31lYRP34Mm6rI5NaWbtCEZZf6C5QikiksuyViwtOFfnXuAtd78jW+2IiLSFyMXSzIpbuO6xBEfPzzCzReFjYgvXISKSE6JcGz6aoIfYBRhoZscCX3L3rzf1PnefT3CJpIhI4kXpWf4UOA/YBODurwGnZzOUiEiuiVIsC9x9xX7Tcve6NRGRLIhyD55V4aa4m1khwb15l2U3lohIbonSs7wSuA4YCKwDTgyniYgcMqJcG74euLgNsoiI5KwoR8Pv5gDXiLu7LrURkUNGlH2Wf0573gH4DLAqO3FERHJTlM3wfW4hYWb/CczPWiIRkRyUyeWOhwG9WzuIiEgui7LPcgsf77MsADYDN2YzlIhIrmmyWIaDYRwLrA4npTzbY7qJiOSgJjfDw8I4293rw4cKpYgckqLss1xkZsdlPYmISA5r6h48Re6+h+B2EC+b2bvAdoKRhNzdk3UjbhGRg9DUPsu/AZXApDbKIiKSs5oqlgbg7u+2URYRkZzVVLHsZWbXNTZTt4oQkVwTjoz2CrDa3c8zs8OAh4AewELgcnfflcm6myqWhUBnNNp5XkvqXRJ1V0ppxLXAW0BZ+PpHwI/d/SEzuwuYBmT05WmqWFa5+3czWamISFszs/7AucD3gevC88TPAC4NF7kfuIUMi2VTpw6pRykiuaSnmb2S9th/5LOfAP8CpMLXPYCt4Vk9AB8CFZk23lTPcnymKxURyYKN7n7CgWaY2XnAendfaGanZaPxRoulu2/ORoMiIlkwFpgU3m67A8E+yzuBrmnnjPfn40u3WyyTUYdERHKKu3/L3fu7+2CCOzs86+6XAX8BLgoXmwo8lmkbKpYiks9uIDjYs5xgH+a9ma4oykjpIiKtr6ae1HObWn217j4XmBs+fw8Y3RrrVc9SRCQCFUsRkQhULEVEIlCxFBGJQMVSRCQCFUsRkQhULEVEIlCxFBGJIO9PSn/se4+x7PlldOrWia89+LW447TI8heX89QdT5FKpaicVMm4qePijhRJUnKvrN7Dd+dX731dVVPPF4/txOINu1lVUw/Atl0pOrcv4J6J3eOK2ayP1n3EH275A9s2b8PMqLywkhMvPjHuWHkn74vlqPNGMXrKaB699dG4o7RIqj7F7Ntnc/nPLqesvIy7r7ib4ScPp9eQXnFHa1KScg8sK9pbBOtTzpRHNzGufzEXjei4d5lfLqyhU/vc3gArKCzg7GvPpu+IvtRtr2PG1BkcPvrwnPzMkyy3vwWtYNBxgygpK4k7RoutXrKa7v27062iG4XtCjnqrKNYOm9p3LGaldTcr67bRb/OhfTpXLh3mrszd2Ud4wcVx5iseaU9S+k7oi8AxZ2K6TW4F9Ubqpt5l7RU3hfLpKpZX0NZ77K9r8vKy6jZUBNjomiSmvvZD+oYP3jfovj6+t1061BA/7LkbIBtXbOVqmVV9D+qf9xR8k7WiqWZdTCzv5nZa2a22MxuzVZbIgdjd73zwuo6Th3YYZ/pz66oY/zgDo28K/fs2rGLWTfOYsI3JlDcObd7w0mUzZ5lHXCGux8LjAImmJn2OkdUWl5K9bqPN6Wq11dT2qs0xkTRJDH3gjW7GNatiO4lH/861Kecv66q5fQc3wRvUL+nnlk3zuKYCccw8vSRccfJS1krlh7YFr5sFz48W+3lm4qRFWxatYkta7ZQv7uexXMWM/yU4XHHalYScz+7opYz9utBLly7iwFlRfTqWNjIu3KHu/P4bY/Tc3BPxlw6Ju44eSurO2PCe/guBI4AfuHuC7LZ3oH8/ju/54NXP2DH1h3ccd4dnDb9NConVbZ1jBYrKCpg4vUTmXnNTDzljDp/FOVDyuOO1ayk5d65x1lYtYvrRu/b+312RR3jByVjE3zVa6t4/cnXKT+inLu+cBcA468cz9CxQ2NOll+yWizdvR4YZWZdgUfN7Gh3fzN9mfAObdMBuvTp0uoZPnvbZ1t9nW1l6NihifzCJyl3SZHx2JRPnmJz45iyAyydmwaOGsjNC26OO0bea5Oj4e6+leBeGBMOMG+Gu5/g7id07Nrxk28WEckB2Twa3ivsUWJmJcBZQO6fcCcicgDZ3AzvC9wf7rcsAGa5+xNZbE9EJGuyVizd/XXguGytX0SkLekKHhGRCFQsRUQiULEUEYlAxVJEJAIVSxGRCFQsRUQiULEUEYkgOaOaikheqanvwnObJ7XgHTdlLUsU6lmKiESgYikiEoGKpYhIBCqWIiIRqFiKiESgYikiEoGKpYhIBCqWIiIRqFiKSOKZ2QAz+4uZLTGzxWZ2bTi9u5nNMbN3wn+7ZdqGiqWI5IM9wDfd/UjgROAqMzsSuBF4xt2HAs+ErzOiyx0lkeZec2XcETJy2k//Pe4IGXlw/ea4IzTJ3auAqvB5jZm9BVQAFwCnhYvdD8wFbsikDRVLEUmKnmb2StrrGe4+Y/+FzGwwwf2/FgC9w0IKsBbonWnjKpYikhQb3f2EphYws87A74F/cvdqM9s7z93dzDzTxrXPUkTygpm1IyiUD7j7I+HkdWbWN5zfF1if6fpVLEUk8SzoQt4LvOXud6TNehyYGj6fCjyWaRvaDBeRfDAWuBx4w8wWhdNuAn4IzDKzacAK4HOZNqBiKSKJ5+7zAWtk9vjWaEOb4SIiEahYiohEoGIpIhKBiqWISAQqliIiEahYiohEoGIpIhKBiqWISAQqliIiEeT9FTzLX1zOU3c8RSqVonJSJeOmjos7UmRJza7c2beyeg/fnV+993VVTT1fPLYTizfsZlVNPQDbdqXo3L6AeyZ2jytmXsnrYpmqTzH79tlc/rPLKSsv4+4r7mb4ycPpNaRX3NGaldTsyt02BpYV7S2C9SlnyqObGNe/mItGdNy7zC8X1tCpvTYeW0teF8vVS1bTvX93ulUEt9046qyjWDpvac7+AqRLanblbnuvrttFv86F9OlcuHeauzN3ZR13jO8aY7Km1bTbwnMVv4s7RmR5/WenZn0NZb3L9r4uKy+jZkNNjImiS2p25W57z35Qx/jBxftMe339brp1KKB/WV73h9pU1oulmRWa2d/N7IlstyVyqNld77ywuo5TB3bYZ/qzK+oYP7hDI++STLRFz/Ja4K02aOcTSstLqV738U7w6vXVlPYqjSNKiyU1u3K3rQVrdjGsWxHdSz7+Va5POX9dVcvpg4qbeKe0VFaLpZn1B84F7slmO42pGFnBplWb2LJmC/W761k8ZzHDTxkeR5QWS2p25W5bz66o5Yz9epAL1+5iQFkRvToWNvIuyUS2d2j8BPgXoNE/0WY2HZgO0KVPl1ZtvKCogInXT2TmNTPxlDPq/FGUDylv1TayJanZlbvt7NzjLKzaxXWj9/31enZFHeMHaRO8tWWtWJrZecB6d19oZqc1tlx4K8sZAP1G9sv4zmuNGTp2KEPHDm3t1baJpGZX7rZRUmQ8NuWTR+tvHFN2gKXlYGVzM3wsMMnMPgAeAs4ws5lZbE9EJGuyVizd/Vvu3t/dBwMXA8+6+xey1Z6ISDbl9XmWIiKtpU3OWHX3ucDctmhLRCQb1LMUEYlAxVJEJAIVSxGRCFQsRUQiULEUEYlAxVJEJAIVSxGRCFQsRUQiULEUEYlAxVJEJAIVSxGRCFQsRUQiULEUEYlAxVJEJAIVSxHJC2Y2wczeNrPlZnZja69fxVJEEs/MCoFfAOcARwKXmNmRrdmGiqWI5IPRwHJ3f8/ddxHc9+uC1mzA3Fv9hooZM7MNwIosrb4nsDFL684m5W57Sc2ezdyD3P2Tt5I8CGb2FEHmqDoAtWmvZ4R3h8XMLgImuPuXwteXA59296tbK2+b3FYiqtb+z0hnZq+4+wnZWn+2KHfbS2r2pOV29wlxZ2gJbYaLSD7M3TBcAAAGsUlEQVRYDQxIe90/nNZqVCxFJB+8DAw1s8PMrD3B7bcfb80GcmozPMtmxB0gQ8rd9pKaPam5D5q77zGzq4GngULg1+6+uDXbyKkDPCIiuUqb4SIiEahYiohEoGIpkifMzOLOkM/ytlia2XAzG2Nm7cJLoRIloZmPMLMTzKw47iwtYWZHmdmpZtYj7iwtZWbjwhOwcXdXwcyevDwabmaTgX8jOM9qNfCKmd3n7tXxJmuemQ1z92XuXm9mhe5eH3emKMzsPILPfBOw1sxudvdlMcdqlpmdA/wIeA9oZ2bT3H1tzLGaZWYFQEfgV8FL6+Tud4UFs8DdUzFHzDt517M0s3bA54Fp7j4eeIzgZNUbzKws1nDNCAvOIjP7LUBDwYw5VrPM7CTgdmCqu58ObAFafdSX1mZmpwF3Al9y9wuBXcDRsYaKyN1T7r4NuB+4FzjJzL7RMC/WcHkq74plqAwYGj5/FHgCaAdcmqubKWbWCbga+Cdgl5nNhOQUTOBH7v738PnNQPcEbI6vA77i7n8zsz7Ap4GrzexXZnZRrn5X9rOHoDNwPzDazO4wsx9YIF9/v2ORdx+mu+8G7gAmm9nJ4V/Z+cAiYFys4Zrg7tuBfwR+C1wPdEgvmHFmi2AB8Ajs3ddaDAwi+KNFru4LdPe33P0v4ctpwC/DHuaLwEW0bJCHuDwGrHX3Z4BXgK8CZR5QD7MV5V2xDP0V+BNwuZmd4u717v5boB9wbLzRGufua9x9m7tvBL4ClDQUTDOrNLMR8SY8sPDzbdgfbMBWYLO7bzCzy4DbzKwkvoTNc/fvu/tt4fP7CAr9gCbflBt2AsPN7MsEhfKHwEAz+0q8sfJPXh7gcfdaM3sAcOBbYZGpA3oDVbGGi8jdN4Vf+NvNbCnBJVynxxyrWe6+B9hmZqvM7AfA2cAV7r4z5miNMjPztEvZzOyzBN+VNfGlisbd15jZKuBfgavc/b/N7HRgeczR8k5eX+4YXlA/lqCXVgvcmbZfLRHCnfY3AGe5+xtx52lOuJ+vHfBW+O94d38n3lTRhPtYvwBcB3ze3d+MOVIkZjYAKHf3heFrHQ3Pgrwulg3C/WiJ24djZt2AWcA33f31uPO0hJldAbzc2oMZZFN4JsVZwLvu/nbceVpq/x6ytK5DolgmmZl1cPfa5pfMLfrFlXyjYikiEkG+Hg0XEWlVKpYiIhGoWIqIRKBiKSISgYplnjCzejNbZGZvmtnvzKzjQazrNDN7Inw+ycwaHRTDzLqa2dcyaOMWM7s+6vT9lrkvvE901LYGm1kizpmU3KVimT92uvsodz+aYPScr6bPzHRgBXd/3N1/2MQiXYEWF0uRpFGxzE9/BY4Ie1Rvm9l/AG8CA8zsbDN70cxeDXugnQHMbIKZLTWzV4HJDSsysyvM7Ofh895m9qiZvRY+TiK4FvnwsFd7e7jcP5vZy2b2upndmraub5vZMjObDwxv7ocwsy+H63nNzH6/X2/5TDN7JVzfeeHyhWZ2e1rbuj5aWo2KZZ4xsyLgHKDh0sihBKPpHAVsB74DnOnulQSj1FxnZh2Au4HzgeOBPo2s/qfAc+5+LFAJLCYYt/LdsFf7z2Z2dtjmaGAUcLyZnWJmxxPcy3kUMBH4VIQf5xF3/1TY3lsEIwM1GBy2cS5wV/gzTAM+cvdPhev/spkdFqEdkWbl5UAah6gSM1sUPv8rwYCw/YAV7v5SOP1E4Ejg+XCoxvYEw5GNAN5vuIY7HOlo+gHaOAP4B9g7bNxH4SWZ6c4OHw3X4HcmKJ6lwKPuviNs4/EIP9PRZnYbwaZ+Z4J7QjeYFV6++o6ZvRf+DGcD/yttf2aXsO2cH7Fdcp+KZf7Y6e6j0ieEBXF7+iRgjrtfst9y+7zvIBnwA3f/1X5t/FMG67oPuNDdXwuvNT8tbd7+l5552PbX3T29qGJmgzNoW2Qf2gw/tLwEjDWzIyAYnd3MhgFLgcFmdni43CWNvP8Z4MrwvYVm1gWoIeg1Nnga+Me0faEVZlYOzAMuNLMSMysl2ORvTilQFQ5wcdl+86aYWUGYeQjwdtj2leHymNkwC0agFzlo6lkeQsLBeK8AHrSPb/nwHXdfZmbTgT+a2Q6CzfjSA6ziWmCGmU0D6oEr3f1FM3s+PDXnyXC/5UjgxbBnuw34gru/amYPA68B64GXI0T+V4JR2DeE/6ZnWgn8jWCQ3q+GY5jeQ7Av89VwqLgNwIXRPh2RpmkgDRGRCLQZLiISgYqliEgEKpYiIhGoWIqIRKBiKSISgYqliEgEKpYiIhH8DwWig59ii/w1AAAAAElFTkSuQmCC\n", 504 | "text/plain": [ 505 | "
" 506 | ] 507 | }, 508 | "metadata": {}, 509 | "output_type": "display_data" 510 | } 511 | ], 512 | "source": [ 513 | "metrics(y, y_pred)" 514 | ] 515 | } 516 | ], 517 | "metadata": { 518 | "kernelspec": { 519 | "display_name": "conda_tensorflow_p36", 520 | "language": "python", 521 | "name": "conda_tensorflow_p36" 522 | }, 523 | "language_info": { 524 | "codemirror_mode": { 525 | "name": "ipython", 526 | "version": 3 527 | }, 528 | "file_extension": ".py", 529 | "mimetype": "text/x-python", 530 | "name": "python", 531 | "nbconvert_exporter": "python", 532 | "pygments_lexer": "ipython3", 533 | "version": "3.6.10" 534 | } 535 | }, 536 | "nbformat": 4, 537 | "nbformat_minor": 2 538 | } 539 | -------------------------------------------------------------------------------- /sample_notebooks/TF_classi_transformers_BBCsports_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Tensorflow Framework for Multi-class Text Classification" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## BBC News Classification - 5 Class Problem\n", 15 | "\n", 16 | "* Consists of 737 documents from the BBC Sport website corresponding to sports news articles in five topical areas from 2004-2005.\n", 17 | "* Class Labels: 5
\n", 18 | "\n", 19 | " 0 - Rugby
\n", 20 | " 1 - Tennis
\n", 21 | " 2 - Football
\n", 22 | " 3 - Cricket
\n", 23 | " 4 - Atheletics
" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Reading Data" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 1, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "import sys\n", 40 | "import numpy as np\n", 41 | "import pandas as pd\n", 42 | "\n", 43 | "sys.path.append(\"..\")\n", 44 | "pd.set_option('display.max_colwidth', -1)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "((471, 3), (148, 2), (118, 3))" 56 | ] 57 | }, 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "output_type": "execute_result" 61 | } 62 | ], 63 | "source": [ 64 | "df_train = pd.read_csv('../datasets/bbcsports/train.csv')\n", 65 | "df_test = pd.read_csv('../datasets/bbcsports/test.csv')\n", 66 | "df_score = pd.read_csv('../datasets/bbcsports/test_with_label.csv')\n", 67 | "df_val = pd.read_csv('../datasets/bbcsports/dev.csv')\n", 68 | "df_train.shape, df_test.shape, df_val.shape" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "2 169\n", 80 | "0 94 \n", 81 | "3 79 \n", 82 | "4 65 \n", 83 | "1 64 \n", 84 | "Name: label, dtype: int64" 85 | ] 86 | }, 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "df_train.label.value_counts()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "text/plain": [ 104 | "2 43\n", 105 | "0 23\n", 106 | "3 20\n", 107 | "4 16\n", 108 | "1 16\n", 109 | "Name: label, dtype: int64" 110 | ] 111 | }, 112 | "execution_count": 4, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "df_val.label.value_counts()" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 5, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "text/plain": [ 129 | "2 53\n", 130 | "0 30\n", 131 | "3 25\n", 132 | "4 20\n", 133 | "1 20\n", 134 | "Name: label, dtype: int64" 135 | ] 136 | }, 137 | "execution_count": 5, 138 | "metadata": {}, 139 | "output_type": "execute_result" 140 | } 141 | ], 142 | "source": [ 143 | "df_score.label.value_counts()" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "## Training (Bert | Electra) classifier with FrameWork" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 6, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "WARNING:tensorflow:From /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/__init__.py:1467: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.\n", 163 | "\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "from classitransformers.report import metrics\n", 169 | "from classitransformers.configs import Configs\n", 170 | "\n", 171 | "from classitransformers.tfelectra import ElectraClassification #(for using Electra)\n", 172 | "\n", 173 | "#from classitransformers.tfbert import BertClassification" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "## Set Configurations\n", 181 | "\n", 182 | "* Pre-trained config filepath, vocab filepath and model checkpoint filepath\n", 183 | "* labels list\n", 184 | "* train, test and eval batch size\n", 185 | "* max_seq_len - how long is the each text input sentence !!\n", 186 | "* data_dir - data path folder (optional)\n", 187 | "* output_dir - output path folder (optional)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 7, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "================================================================================\n", 200 | "Config: model=electra, trial 1/1\n", 201 | "================================================================================\n", 202 | "data_dir ../datasets/bbcsports/\n", 203 | "do_eval True\n", 204 | "do_lower_case True\n", 205 | "do_predict False\n", 206 | "do_train True\n", 207 | "embedding_size None\n", 208 | "eval_batch_size 8\n", 209 | "gcp_project None\n", 210 | "init_checkpoint ../models/Electra_base/\n", 211 | "iterations_per_loop 1000\n", 212 | "keep_all_models True\n", 213 | "label_list ['0', '1', '2', '3', '4']\n", 214 | "layerwise_lr_decay 0.8\n", 215 | "learning_rate 5e-05\n", 216 | "log_examples False\n", 217 | "master None\n", 218 | "max_seq_length 256\n", 219 | "model_hparam_overrides {}\n", 220 | "model_name electra\n", 221 | "model_size base\n", 222 | "num_tpu_cores 1\n", 223 | "num_train_epochs 3\n", 224 | "num_trials 1\n", 225 | "output_dir ../electra_output_sports/\n", 226 | "predict_batch_size 8\n", 227 | "pretrained_model_dir ../models/Electra_base/\n", 228 | "save_checkpoints_steps 1000000\n", 229 | "task_names ['textclassification']\n", 230 | "tasks {'textclassification': {'type': 'classification', 'labels': ['0', '1', '2', '3', '4'], 'header': True, 'text_column': 1, 'label_column': 2}}\n", 231 | "tpu_job_name None\n", 232 | "tpu_name None\n", 233 | "tpu_zone None\n", 234 | "train_batch_size 16\n", 235 | "use_tfrecords_if_existing False\n", 236 | "use_tpu False\n", 237 | "vocab_file ../models/Electra_base/vocab.txt\n", 238 | "vocab_size 30522\n", 239 | "warmup_proportion 0.1\n", 240 | "weight_decay_rate 0.01\n", 241 | "\n", 242 | "Loading dataset textclassification_train\n", 243 | "Existing tfrecords not found so creating\n", 244 | "Writing example 0 of 471\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "# config = Configs(pretrained_model_dir = '../models/Bert_base',\n", 250 | "# num_train_epochs = 3,\n", 251 | "# train_batch_size = 16, \n", 252 | "# eval_batch_size = 8, \n", 253 | "# predict_batch_size = 8, \n", 254 | "# do_train = True, \n", 255 | "# do_eval = True, \n", 256 | "# label_list = [\"0\", \"1\", \"2\", \"3\", \"4\"], \n", 257 | "# max_seq_length = 256, \n", 258 | "# data_dir='../datasets/bbcsports/',\n", 259 | "# output_dir = '../bert_output_bbcsports/')\n", 260 | "\n", 261 | "\n", 262 | "# model = BertClassification(config)\n", 263 | "\n", 264 | "\n", 265 | "config = Configs(pretrained_model_dir = '../models/Electra_base/',\n", 266 | " model_name='electra',\n", 267 | " num_train_epochs=3, \n", 268 | " train_batch_size = 16,\n", 269 | " eval_batch_size = 8, \n", 270 | " do_train = True, \n", 271 | " do_eval = True, \n", 272 | " label_list = [\"0\", \"1\", \"2\", \"3\", \"4\"],\n", 273 | " max_seq_length = 256,\n", 274 | " data_dir = \"../datasets/bbcsports/\", \n", 275 | " output_dir = \"../electra_output_sports/\")\n", 276 | "\n", 277 | "model = ElectraClassification(config) " 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "## Train Method" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 8, 290 | "metadata": { 291 | "scrolled": true 292 | }, 293 | "outputs": [ 294 | { 295 | "name": "stdout", 296 | "output_type": "stream", 297 | "text": [ 298 | "================================================================================\n", 299 | "Start training\n", 300 | "================================================================================\n", 301 | "Training for 90 steps\n", 302 | "Building model...\n", 303 | "Building complete\n", 304 | "10/90 = 11.1%, SPS: 0.3, ELAP: 31, ETA: 4:10 - loss: 22.2336\n", 305 | "20/90 = 22.2%, SPS: 0.4, ELAP: 48, ETA: 2:48 - loss: 19.4345\n", 306 | "30/90 = 33.3%, SPS: 0.5, ELAP: 1:05, ETA: 2:10 - loss: 18.7599\n", 307 | "40/90 = 44.4%, SPS: 0.5, ELAP: 1:22, ETA: 1:43 - loss: 6.0089\n", 308 | "50/90 = 55.6%, SPS: 0.5, ELAP: 1:39, ETA: 1:19 - loss: 6.5244\n", 309 | "60/90 = 66.7%, SPS: 0.5, ELAP: 1:56, ETA: 58 - loss: 0.8747\n", 310 | "70/90 = 77.8%, SPS: 0.5, ELAP: 2:14, ETA: 38 - loss: 0.6737\n", 311 | "80/90 = 88.9%, SPS: 0.5, ELAP: 2:31, ETA: 19 - loss: 0.1257\n", 312 | "90/90 = 100.0%, SPS: 0.5, ELAP: 2:48, ETA: 0 - loss: 0.0573\n", 313 | "90/90 = 100.0%, SPS: 0.5, ELAP: 2:50, ETA: 0\n", 314 | "================================================================================\n", 315 | "Training Finished\n", 316 | "================================================================================\n", 317 | "================================================================================\n", 318 | "Start Evaluation\n", 319 | "================================================================================\n", 320 | "Evaluating textclassification\n", 321 | "Loading dataset textclassification_dev\n", 322 | "Existing tfrecords not found so creating\n", 323 | "Writing example 0 of 118\n", 324 | "Building model...\n", 325 | "Building complete\n", 326 | "textclassification: accuracy: 95.76 - loss: 0.11\n", 327 | "\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "model.train()" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "## Test Methods" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 9, 345 | "metadata": {}, 346 | "outputs": [ 347 | { 348 | "name": "stdout", 349 | "output_type": "stream", 350 | "text": [ 351 | "Writing out predictions for [Task(textclassification)] test\n", 352 | "Loading dataset textclassification_test\n", 353 | "Existing tfrecords not found so creating\n", 354 | "Writing example 0 of 148\n", 355 | "Building model...\n", 356 | "Building complete\n", 357 | "Getting predictions for 148 textclassification examples (test)\n" 358 | ] 359 | } 360 | ], 361 | "source": [ 362 | "prediction = model.test()" 363 | ] 364 | }, 365 | { 366 | "cell_type": "markdown", 367 | "metadata": {}, 368 | "source": [ 369 | "**Testing a sentence, takes a list of sentences.**" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 10, 375 | "metadata": {}, 376 | "outputs": [ 377 | { 378 | "name": "stdout", 379 | "output_type": "stream", 380 | "text": [ 381 | "Writing out predictions for [Task(textclassification)] infer\n", 382 | "Loading dataset textclassification_infer\n", 383 | "Existing tfrecords not found so creating\n", 384 | "Writing example 0 of 1\n", 385 | "Building model...\n", 386 | "Building complete\n", 387 | "Getting predictions for 1 textclassification examples (infer)\n" 388 | ] 389 | }, 390 | { 391 | "data": { 392 | "text/plain": [ 393 | "[array([0.04234521, 0.99357986, 0.14866598, 0.5092559 , 0.16855492],\n", 394 | " dtype=float32)]" 395 | ] 396 | }, 397 | "execution_count": 10, 398 | "metadata": {}, 399 | "output_type": "execute_result" 400 | } 401 | ], 402 | "source": [ 403 | "model.text_inference(['Rodger Fedrer is one of the great tennis player. Rafael Nadal is very good player at clay court. They have combined won more than 30 grand slams till now.'])" 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "metadata": {}, 409 | "source": [ 410 | "## Export the checkpoint model to pb (protobuf) format\n", 411 | "\n", 412 | "**In TensorFlow, the protbuf file contains the graph definition as well as the weights of the model. Thus, a pb file is all you need to be able to run a given trained model.**\n", 413 | "\n", 414 | "Only support of Bert for now" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 11, 420 | "metadata": { 421 | "scrolled": true 422 | }, 423 | "outputs": [], 424 | "source": [ 425 | "# Currently does not have support for Electra\n", 426 | "\n", 427 | "# model.export_model()" 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": {}, 433 | "source": [ 434 | "## Inferencing from Trained Pb Model\n", 435 | "\n", 436 | "* This tensorflow format is used for deployment purpose." 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 12, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "# Not supported for Electra\n", 446 | "\n", 447 | "# predictions = model.inference(test_filename = '../datasets/bbcsports/test.csv', batch_size = 256)" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 13, 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "y_pred = [np.argmax(tup) for tup in prediction]\n", 457 | "y = df_score.label" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "metadata": {}, 463 | "source": [ 464 | "## Accuracy and Report" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 15, 470 | "metadata": {}, 471 | "outputs": [ 472 | { 473 | "name": "stdout", 474 | "output_type": "stream", 475 | "text": [ 476 | "\n", 477 | "\n", 478 | "Accuracy: 0.993\n", 479 | "Precision: 0.994\n", 480 | "Recall: 0.996\n", 481 | "F1-Score: 0.995\n" 482 | ] 483 | }, 484 | { 485 | "data": { 486 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUUAAAEmCAYAAAD1FIKpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XucVXW9//HXewYUxBkEZQABxQshXpLIsBQvRZiSoj8tzUjxRNE9yzwnK88vupn9PJqe8pzErEjzdiqPHeWHejyZR394F0sQFT0YAgJykQFFYebz+2MtaInM7D3D7Fl7bd/Px2M/Zq/L/q737Ad85rtu36WIwMzMEnV5BzAzqyYuimZmGS6KZmYZLopmZhkuimZmGS6KZmYZLopvI5J6S/oPSa9I+rcdaGeypDu7MlteJB0l6em8c1j1kK9TrD6SPg6cBxwANANzgR9ExH072O5ZwJeAIyJi8w4HrXKSAhgREQvzzmLF4Z5ilZF0HnA5cBEwENgL+Bfg5C5ofm/gmbdDQSyHpB55Z7AqFBF+VckL6AusBz7azjo7kxTNpenrcmDndNmxwIvA14AVwDLg79Jl3wHeADal25gKTAeuy7Q9HAigRzp9DvA8SW/1f4DJmfn3ZT53BPAw8Er684jMsnuA7wH3p+3cCezRxu+2Jf8/ZPKfAkwEngFWA9/MrD8WmAOsTdf9KbBTuuze9HfZkP6+Z2Ta/zrwEnDtlnnpZ/ZLtzEmnd4TWAkcm/e/Db+67+WeYnV5H9ALuKWddb4FvBcYDRxKUhguzCwfRFJch5AUvisl9YuIb5P0Pm+KiF0j4pr2gkjqA/wzcEJENJAUvrnbWa8/cHu67u7AZcDtknbPrPZx4O+AJmAn4Px2Nj2I5DsYAvxv4GrgE8C7gaOAf5S0T7puC/BVYA+S72488HmAiDg6XefQ9Pe9KdN+f5Je87TshiPiOZKCeZ2kXYBfAjMj4p528lqNcVGsLrsDL0f7u7eTge9GxIqIWEnSAzwrs3xTunxTRMwi6SWN7GSeVuBgSb0jYllEzNvOOh8Gno2IayNic0TcACwATsqs88uIeCYiXgNuJinobdlEcvx0E3AjScG7IiKa0+3PJ/ljQEQ8GhEPpNtdBFwFHFPG7/TtiHg9zfMmEXE1sBB4EBhM8kfI3kZcFKvLKmCPEse69gReyEy/kM7b2sY2RfVVYNeOBomIDSS7nJ8Flkm6XdIBZeTZkmlIZvqlDuRZFREt6fstRWt5ZvlrWz4v6R2SbpP0kqR1JD3hPdppG2BlRGwssc7VwMHATyLi9RLrWo1xUawuc4DXSY6jtWUpya7fFnul8zpjA7BLZnpQdmFE3BERE0h6TAtIikWpPFsyLelkpo74V5JcIyKiEfgmoBKfafdyC0m7khynvQaYnh4esLcRF8UqEhGvkBxHu1LSKZJ2kdRT0gmS/k+62g3AhZIGSNojXf+6Tm5yLnC0pL0k9QW+sWWBpIGSTk6PLb5Oshveup02ZgHvkPRxST0knQEcCNzWyUwd0QCsA9anvdjPbbN8ObBvB9u8AngkIj5Fcqz0Zzuc0grFRbHKRMSlJNcoXkhy5nMx8EXg39NVvg88AvwZ+AvwWDqvM9u6C7gpbetR3lzI6tIcS0nOyB7DW4sOEbEKOJHkjPcqkjPHJ0bEy53J1EHnk5zEaSbpxd60zfLpwExJayWdXqoxSScDx/O33/M8YIykyV2W2KqeL942M8twT9HMLMNF0cwsw0XRzCzDRdHMLKOqbojv2bsxdm5oyjtGhzXu/mreEcwqau2ytby69tVS14B2yLBhw2LjxlLX0f/Nyy+/fEdEHN+VGbanqorizg1NHHLGpXnH6LAPnfV43hHMKmrGlBld3ubGjRs59dRTy88wY0apu5W6hHefzcwyXBTNzDJcFM3MMlwUzcwyqupEi5lZZ0laRHIffAuwOSIOS0c5uolkVPlFwOkRsaa9dtxTNLNa8v6IGB0Rh6XTFwB3R8QI4O50ul0uimZWy04GZqbvZ9L+WKWAi6KZFccekh7JvKZtszyAOyU9mlk2MCKWpe9fInlCZrt8TNHMiuLlzG7x9oyLiCWSmoC7JC3ILoyISJ8F3i73FM2sJkTEkvTnCpInYo4FlksaDJD+XFGqHRdFMys8SX0kNWx5DxwHPAn8AZiSrjYFuLVUW959NrNaMBC4RRIkde36iJgt6WHgZklTSZ4yWfKxFC6KZlZ4EfE86fPAt5m/Chjfkba8+2xmllFzPcXWzW8w7/ffIlo2EdFC//2OYNjhZ7Jx3XKeveOf2LyxmT4D9mP/CV+hrr5n3nHbtXDOQmZfNpvW1lbGTBrDuCnj8o5UFufufkXOXm1qriiqvicHnvJd6nfqTWvLZub9/hvstvcYls29lcGHTmKPdxzF83/8V1bM/08GHXJC3nHb1NrSyqxLZnHWT86isamRq8+5mpFHjWTAvgPyjtYu5+5+Rc5ejWpu91kS9Tv1BiBaW4jWFkCse/Ev7L7/EQAMOOD9rHn+wRxTlrZk/hL6D+1PvyH9qO9Zz0ETDmLBvQtKfzBnzt39ipy9GtVcUYSkGP75xq/w6C+m0HfYofTqO4j6nfugunoAdtp1d97YsDrnlO1rXtFM48DGrdONTY00r2zOMVF5nLv7FTl7NapoUZR0vKSnJS2UVPJG7C7bbl097/zY5Yw55+dsWP4sr615sbs2bWYFV7GiKKkeuBI4ATgQOFPSgZXa3vb02HlXGoccwvqXnqbl9Q3prjS8sX4VO/Xp351ROqyhqYF1y9dtnV63Yh0NAxpyTFQe5+5+Rc5ejSrZUxwLLIyI5yPiDeBGkhErKmrTa6+w+fX1ALRufp1XFs+ld7+hNA45hFUL/x8AKxf8kX77jq10lB0yZNQQVi1exZqla2jZ1MK8u+Yx8uiReccqybm7X5GzV6NKnn0eAizOTL8IHF7B7QHwxoY1PPefV0C0EhHsvv+R9NvnPfTuP4xn77iUxQ/+hj577EvTgRMqHWWH1PWoY+L5E7nuy9cRrcHok0bTtG/1P/7VubtfkbNXo9wvyUmH+JkGsNOuO34JQZ89hvPOj/34LfN79R3EIadfssPtd6cRR45gxJEj8o7RYc7d/YqcvdpUcvd5CTAsMz00nfcmETEjIg6LiMN69m7cdrGZWbeqZFF8GBghaR9JOwEfIxmxwsysalVs9zkiNkv6InAHUA/8IiLmVWp7ZmZdoaLHFCNiFjCrktswM+tKNXlHi5lZZ+V+9tnM3p76t67lzA0lB8LeakYFs2S5p2hmluGiaGaW4aJoZpbhomhmluGiaGaW4aJoZpbhomhmluGiaGaW4aJoZpbhomhmluGiaGaW4aJoZpbhomhmluGiaGaW4aJoZpZRVeMpNu7+Kh866/G8Y3TY9LF75h2h06Y/tDTvCGZVxT1FM7MMF0UzswwXRTOzDBdFM7MMF0UzqxmS6iU9Lum2dHofSQ9KWijpJkk7lWrDRdHMasm5wFOZ6R8BP46I/YE1wNRSDbgomllNkDQU+DDw83RawAeA36arzAROKdWOi6KZFcUekh7JvKZts/xy4B+A1nR6d2BtRGxOp18EhpTaSFVdvG1m1o6XI+Kw7S2QdCKwIiIelXTsjmzERdHMasGRwCRJE4FeQCNwBbCbpB5pb3EosKRUQ959NrPCi4hvRMTQiBgOfAz4r4iYDPwR+Ei62hTg1lJtuSiaWS37OnCepIUkxxivKfUB7z6bWS52WT2Ad92w7bmS9nynrLUi4h7gnvT988DYjuRyT9HMLMNF0cwsw0XRzCyj5o8pLpyzkNmXzaa1tZUxk8Ywbsq4vCO1afHi1Zx99i9ZvrwZCaZNO4pzzx3P6tUbOOOMq1m0aBXDh+/OzTd/mn79+uQdt01F+s6zipobip292tR0T7G1pZVZl8xi8uWT+cKNX+DJO59k5fMr847Vph496rn00o8yf/50HnjgAq688h7mz1/KxRfPZvz4A3j22e8xfvwBXHzx7Lyjtqlo3/kWRc0Nxc5ejWq6KC6Zv4T+Q/vTb0g/6nvWc9CEg1hw74K8Y7Vp8OC+jBmzFwANDb0YNWowS5as5dZbn2DKlPcBMGXK+/j3f38iz5jtKtp3vkVRc0Oxs1ejmi6KzSuaaRzYuHW6samR5pXNOSYq36JFL/P443/l8MP3YfnydQwe3BeAQYMaWb58Xc7p2lbU77youaHY2atRxYqipF9IWiHpyUpto1atX7+R0067issvP53Gxt5vWiaJZPAPM6uESvYUfwUcX8H2S2poamBdple1bsU6GgY05JiotE2bWjjttKuYPHksp546BoCBAxtZtuwVAJYte4Wmpur9HYr4nUNxc0Oxs1ejihXFiLgXWF2p9ssxZNQQVi1exZqla2jZ1MK8u+Yx8uiReUZqV0QwdeqvGTVqEOedN2Hr/EmT3snMmXMAmDlzDieffGheEUsq2ne+RVFzQ7GzV6PcL8lJx0SbBtB3UN8ubbuuRx0Tz5/IdV++jmgNRp80mqZ9m7p0G13p/vuf49prH+CQQ4YwevT3ALjoolO44ILjOf30GVxzzf3svXd/br65I7dGda+ifedbFDU3FDt7NVJEVK5xaThwW0QcXM76e47aM6bNrN7/8G2ZPnbPvCN02vSHluYdwQpgxpQZLH1qaZcezH6X9ox76sr//75b63cebWs8xa5U02efzcw6ykXRzCyjkpfk3ADMAUZKelFSyadomZnlrWInWiLizEq1bWZWKd59NjPLcFE0M8twUTQzy3BRNDPLcFE0M8twUTQzy3BRNDPLcFE0M8twUTQzy3BRNDPLcFE0M8twUTQzy8h95G0ze3t6dsAuHH/Gu8r/wE8rlyXLPUUzswwXRTOzDBdFM7MMF0Uzs4w2T7RIamzvgxGxrr3lbydFfiLesmuW5R2hUwZPHZx3BKtR7Z19ngcEkH2s4ZbpAPaqYC4zs1y0WRQjYlh3BjEz6yxJvYB7gZ1J6tpvI+LbkvYBbgR2Bx4FzoqIN9prq6xjipI+Jumb6fuhkt69I7+AmVkXex34QEQcCowGjpf0XuBHwI8jYn9gDVDyqaIli6KknwLvB85KZ70K/KyTwc3Mulwk1qeTPdNXAB8AfpvOnwmcUqqtcnqKR0TEZ4CN6cZXAzt1NLSZ2Q7aQ9Ijmde07EJJ9ZLmAiuAu4DngLURsTld5UVgSKmNlHOb3yZJdSRVF0m7A60d+EXMzLrCyxFxWFsLI6IFGC1pN+AW4IDObKScnuKVwO+AAZK+A9xHsp9uZlZ1ImIt8EfgfcBukrZ0/oYCS0p9vmRPMSJ+LelR4IPprI9GxJOdzGtm1uUkDQA2RcRaSb2BCSSdtz8CHyE5Az0FuLVUW+WOklMPbCLZhfZdMGZWbQYDMyXVk9SomyPiNknzgRslfR94HLimVEMli6KkbwEfJ9lHF3C9pN9ExA935DcwM+sqEfFn4C3jkEXE88DYjrRVTk/xbOBdEfEqgKQfkFRcF0Uzqznl7Aov483Fs0c6z8ys5rQ3IMSPSY4hrgbmSbojnT4OeLh74pmZda/2dp+3nGGeB9yemf9A5eKYmeWrvQEhSp6lMTOrNeWcfd4P+AFwINBry/yIeEcFc3WZW793K8/c/wx9+vXh8zd8Pu84HbJwzkJmXzab1tZWxkwaw7gp4/KO1Kbrr7+enj17UldXhyROPfVUHnjgAV544QXq6+tpbGzkmGOOYeedd847apuK9H1vq8jZq005Z59/BXwf+CfgBODvSG/5K4LRJ45m7EfHcst3bsk7Soe0trQy65JZnPWTs2hsauTqc65m5FEjGbDvgLyjtemkk06iV6+tfzcZOnQoY8eOpa6ujgcffJC5c+dy+OGH55iwbUX8vrcocvZqVM7Z510i4g6AiHguIi4kKY6FsPe79qZ3Y++8Y3TYkvlL6D+0P/2G9KO+Zz0HTTiIBfcuyDtWhwwdOpS6uuSfWFNTExs2bMg5UduK/H0XOXs1Kqen+Ho6IMRzkj5Lcu9gQ2VjWfOKZhoH/u2JEI1NjSyZV/K2zdxI4vbbb0cSo0aNYtSoUW9a/vTTT7PffvvllK60on3fWUXOXo3KKYpfBfoAXyY5ttgX+GSpD0kaBvwaGEiyuz0jIq7ofFSrZpMmTaJPnz689tpr3H777ey2224MHpw8R+Wxxx6jrq6O/fffP+eUZqWVMyDEg+nbZv420Gw5NgNfi4jHJDUAj0q6KyLmdyLn205DUwPrlv/t2WDrVqyjYUD1dtD79OkDQO/evRk+fDgrVqxg8ODBPP300/z1r3/lxBNPRFKJVvJTtO87q8jZq1GbxxQl3SLp9229SjUcEcsi4rH0fTPwFGUM8GiJIaOGsGrxKtYsXUPLphbm3TWPkUePzDvWdm3atIk33nhj6/slS5bQv39/Fi9ezBNPPMGHPvQhevQod+yRfBTp+95WkbNXo/b+pf60qzYiaTjJzdoPbmfZNGAaQN9Bfbtqk1v97sLfseixRby69lUuO/Eyjp12LGMmjeny7XS1uh51TDx/Itd9+TqiNRh90mia9m3KO9Z2vfbaa9x5550ARAT77bcfw4YN48Ybb6SlpYVZs2YBycmWo446Ks+obSrS972tImevRoqo7NU1knYF/gT8ICLa7WHuOWrPmDZzWnurWBfzc5+tHDOmzGDpU0u79PjHYYftHY888q2y15c+82h7I293lYqOjSipJ8mo3b8pVRDNzKpBxYqikqPq1wBPRcRlldqOmVlXKrsoSuro/VlHkpyt/oCkuelrYgfbMDPrVuXc+zyWpMfXF9hL0qHApyLiS+19LiLuIxmp28ysMMrpKf4zcCKwCiAingDeX8lQZmZ5Kaco1kXEC9vMa6lEGDOzvJVzRe3idBc60idlfQl4prKxzMzyUU5P8XPAecBewHLgvek8M7OaU869zyuAj3VDFjOz3JVz9vlqtjOobET41hMzqznlHFP8z8z7XsD/AhZXJo6ZWb7K2X2+KTst6VrgvoolMjPLUWdu89uHZOBYM7OaU84xxTX87ZhiHbAauKCSoczM8tJuUUwHdTiU5LksAK1R6bHGzMxy1O7uc1oAZ0VES/pyQTSzmlbOMcW5kt5V8SRmZlWgzd1nST0iYjPJYwQelvQcsIFk5JuIiOof09/MrIPaO6b4EDAGmNRNWczMctdeURRARDzXTVnMzHLXXlEcIOm8thb6EQNmVovaK4r1wK549OyaVtSn4k0fu2feETpl+kNL845gJbRXFJdFxHe7LYmZWSdJGgb8muRuuwBmRMQVkvoDNwHDgUXA6RGxpr222rskxz1EMyuKzcDXIuJAkjFfvyDpQJK77+6OiBHA3ZRxN157RXF8VyQ1M6u0iFgWEY+l75uBp4AhwMnAzHS1mcAppdpqc/c5IlbveFQzsy6zh6RHMtMzImLGtitJGk5yffWDwMCIWJYueokyBrMpZzxFM7Nq8HJEHNbeCpJ2BX4HfCUi1iXDNyQiIiSVvFW5M0OHmZlVHUk9SQribyLi9+ns5ZIGp8sHAytKteOeopnlo7mF1j+t6pKm0hG9rgGe2uYa6j8AU4CL05+3lmrLRdHMasGRwFnAXyTNTed9k6QY3ixpKvACcHqphlwUzazwIuI+2r6MsENX0viYoplZhouimVmGi6KZWYaLoplZhouimVmGi6KZWUbNX5KzcM5CZl82m9bWVsZMGsO4KePyjlS2omYvSu7Fi1dz9tm/ZPnyZiSYNu0ozj13PNOn/wdXX30fAwbsCsBFF53CxImH5Jy2fUX5zougpotia0srsy6ZxVk/OYvGpkauPudqRh41kgH7Dsg7WklFzV6k3D161HPppR9lzJi9aG7eyLvf/QMmTBgFwFe/Op7zzz8u54TlKdJ3XgQ1vfu8ZP4S+g/tT78h/ajvWc9BEw5iwb0L8o5VlqJmL1LuwYP7MmbMXgA0NPRi1KjBLFmyNudUHVek77wIarooNq9opnFg49bpxqZGmlc255iofEXNXtTcixa9zOOP/5XDD98HgJ/+9B7e+c7v8slPzmTNmg05p2tfUb/zalWxoiipl6SHJD0haZ6k71RqW2Y7Yv36jZx22lVcfvnpNDb25nOfO4bnnvs+c+deyODBffna136bd0TrRpXsKb4OfCAiDgVGA8dLem8Ft/cWDU0NrFu+buv0uhXraBjQ0J0ROq2o2YuWe9OmFk477SomTx7LqaeOAWDgwEbq6+uoq6vj058ex0MPLco3ZAlF+86rXcWKYiTWp5M901fJAR670pBRQ1i1eBVrlq6hZVML8+6ax8ijR3ZnhE4ravYi5Y4Ipk79NaNGDeK88yZsnb9s2Stb399yy1wOPri6nxxYpO+8CCp69llSPfAosD9wZUQ8WMntbauuRx0Tz5/IdV++jmgNRp80mqZ9m7ozQqcVNXuRct9//3Nce+0DHHLIEEaP/h6QXH5zww0PM3fuYiQxfPjuXHXVJ3JO2r4ifedFUNGiGBEtwGhJuwG3SDo4Ip7MriNpGjANoO+gvl2eYcSRIxhx5Igub7c7FDV7UXKPG7c/EVe9ZX61X5O4PUX5zougW84+R8Ra4I/A8dtZNiMiDouIw3bZbZfuiGNm1qZKnn0ekPYQkdQbmAD44ikzq2qV3H0eDMxMjyvWATdHxG0V3J6Z2Q6rWFGMiD+TPHvVzKwwavqOFjOzjnJRNDPLcFE0M8twUTQzy3BRNDPLcFE0M8twUTQzy6jpxxGYWfVqbunLn1ZP6sAnvlmxLFnuKZqZZbgomplluCiamWW4KJqZZbgomplluCiamWW4KJqZZbgomplluCiaWU2Q9AtJKyQ9mZnXX9Jdkp5Nf/Yr1Y6LopnVil/x1ofjXQDcHREjgLvT6Xb5Nj8rpOkPLc07QqdMH7tn3hE65bY+PfOOUFJE3Ctp+DazTwaOTd/PBO4Bvt5eO+4pmlktGxgRy9L3LwEDS33APUUzK4o9JD2SmZ4RETPK/XBEhKQotZ6LopkVxcsRcVgHP7Nc0uCIWCZpMLCi1Ae8+2xmtewPwJT0/RTg1lIfcFE0s5og6QZgDjBS0ouSpgIXAxMkPQt8MJ1ul3efzawmRMSZbSwa35F23FM0M8twUTQzy3BRNDPLcFE0M8twUTQzy3BRNDPLcFE0M8twUTQzy3BRNDPLqPk7WhbOWcjsy2bT2trKmEljGDdlXN6RylbU7M5deYsXr+bss3/J8uXNSDBt2lGce+54Vq/ewBlnXM2iRasYPnx3br750/Tr1yfvuIVS0z3F1pZWZl0yi8mXT+YLN36BJ+98kpXPr8w7VlmKmt25u0ePHvVceulHmT9/Og88cAFXXnkP8+cv5eKLZzN+/AE8++z3GD/+AC6+eHbeUQunpnuKS+Yvof/Q/vQbkjyW4aAJB7Hg3gUM2HdAzslKK2p25+4egwf3ZfDgvgA0NPRi1KjBLFmylltvfYJ77vkaAFOmvI9jj72UH/3otDyjtqm55xr+NOTf8o7xFjXdU2xe0UzjwMat041NjTSvbM4xUfmKmt25u9+iRS/z+ON/5fDD92H58nVbi+WgQY0sX74u53TFU/GiKKle0uOSbqv0tszebtav38hpp13F5ZefTmNj7zctk4SknJIVV3f0FM8FnuqG7bxFQ1MD6zJ/KdetWEfDgIY8onRYUbM7d/fZtKmF0067ismTx3LqqWMAGDiwkWXLXgFg2bJXaGqq7t+hGlW0KEoaCnwY+Hklt9OWIaOGsGrxKtYsXUPLphbm3TWPkUePzCNKhxU1u3N3j4hg6tRfM2rUIM47b8LW+ZMmvZOZM+cAMHPmHE4++dC8IhZWpU+0XA78A9DmnytJ04BpAH0H9e3Sjdf1qGPi+RO57svXEa3B6JNG07RvU5duo1KKmt25u8f99z/Htdc+wCGHDGH06O8BcNFFp3DBBcdz+ukzuOaa+9l77/7cfPO0nJMWT8WKoqQTgRUR8aikY9taL30a1wyAPUftWfJJWx014sgRjDhyRFc32y2Kmt25K2/cuP2JuGq7y+6++7xuTlNbKrn7fCQwSdIi4EbgA5Kuq+D2zMx2WMWKYkR8IyKGRsRw4GPAf0XEJyq1PTOzrlDT1ymamXVUt9zREhH3APd0x7bMzHaEe4pmZhkuimZmGS6KZmYZLopmZhkuimZmGS6KZmYZLopmZhkuimZmGS6KZmYZLopmZhkuimZmGS6KZmYZLopmZhkuimZmGS6KZlYTJB0v6WlJCyVd0Nl2XBTNrPAk1QNXAicABwJnSjqwM225KJpZLRgLLIyI5yPiDZLnQp3cmYYU0eUP0Os0SSuBFyrU/B7AyxVqu5Kcu/sVNXslc+8dEQO6skFJs0kyl6sXsDEzPSN9GiiSPgIcHxGfSqfPAg6PiC92NFe3PI6gXF39pWdJeiQiDqtU+5Xi3N2vqNmLljsijs87w/Z499nMasESYFhmemg6r8NcFM2sFjwMjJC0j6SdSB6r/IfONFRVu88VNiPvAJ3k3N2vqNmLmnuHRcRmSV8E7gDqgV9ExLzOtFVVJ1rMzPLm3WczswwXRTOzDBdFsxohSXlnqAU1WxQljZT0Pkk901uACqWgmfeXdJiknfPO0hGSDpJ0jKTd887SUZLGpRcqExHhwrjjavLss6RTgYtIrlNaAjwi6VcRsS7fZKVJekdEPBMRLZLqI6Il70zlkHQiyXe+CnhJ0rcj4pmcY5Uk6QTgR8DzQE9JUyPipZxjlSSpDtgFuCqZVJ+I+FlaGOsiojXniIVVcz1FST2BM4CpETEeuJXkos6vS2rMNVwJaWGZK+l6gC2FMedYJUk6ArgEmBIR7wfWAJ0epaS7SDoWuAL4VEScArwBHJxrqDJFRGtErAdmAtcAR0j66pZluYYruJoriqlGYET6/hbgNqAn8PFq3b2Q1Af4IvAV4A1J10FxCiPwo4h4PH3/baB/AXajlwOfiYiHJA0CDge+KOkqSR+p1n8r29hM8kd/JjBW0mWSfqhErf7/rqia+9IiYhNwGXCqpKPSv5r3AXOBcbmGa0dEbAA+CVwPnA/0yhbGPLOV4UHg97D1WOjOwN4kf5yo1mN1EfFURPwxnZwK/EvaY5wDfISODVaQl1vvk3v3AAAEnUlEQVSBlyLibuAR4LNAYyTcY+yEmiuKqf8G7gTOknR0RLRExPXAnsCh+UZrW0QsjYj1EfEy8Bmg95bCKGmMpAPyTbh96fe75XitgLXA6ohYKWky8H1JvfNLWFpE/CAivp++/xVJQR/W7oeqw2vASEmfJimIFwN7SfpMvrGKqyZPtETERkm/AQL4RlpMXgcGAstyDVemiFiV/sO+RNICkluX3p9zrJIiYjOwXtJiST8EjgPOiYjXco7WJkmKzK1dkk4j+beyNL9U5YmIpZIWA/8IfCEi/kPS+4GFOUcrrJq+zS+9MfxIkl7XRuCKzHGvQkgPnn8dmBARf8k7TynpcbiewFPpz/ER8Wy+qcqTHgP9BHAecEZEPJlzpLJIGgY0RcSj6bTPPu+Ami6KW6THuQp3jEVSP+Bm4GsR8ee883SEpHOAhzt7U34e0isXJgDPRcTTeefpqG17vNY5b4uiWGSSekXExtJrVhf/B7WiclE0M8uo1bPPZmad4qJoZpbhomhmluGiaGaW4aJYIyS1SJor6UlJ/yZplx1o61hJt6XvJ0lqc3AHSbtJ+nwntjFd0vnlzt9mnV+lz/ktd1vDJRXimkPLn4ti7XgtIkZHxMEko718NruwswMERMQfIuLidlbZDehwUTSrVi6Ktem/gf3THtLTkn4NPAkMk3ScpDmSHkt7lLsCSDpe0gJJjwGnbmlI0jmSfpq+HyjpFklPpK8jSO613S/tpV6Srvf3kh6W9GdJ38m09S1Jz0i6DxhZ6peQ9Om0nSck/W6b3u8HJT2Stndiun69pEsy2/b9v9ZhLoo1RlIP4ARgyy2BI0hGfzkI2ABcCHwwIsaQjKpynqRewNXAScC7gUFtNP/PwJ8i4lBgDDCPZNzE59Je6t9LOi7d5lhgNPBuSUdLejfJs3hHAxOB95Tx6/w+It6Tbu8pkpFsthiebuPDwM/S32Eq8EpEvCdt/9OS9iljO2Zb1eSAEG9TvSXNTd//N8nAo3sCL0TEA+n89wIHAvenQwXuRDJM1gHA/2y5RzkdmWfadrbxAeBs2Dqc2SvprYhZx6WvLfeY70pSJBuAWyLi1XQb5Tyo/GBJ3yfZRd+V5Jm+W9yc3rb5rKTn09/hOOCdmeONfdNtV/0I4FY9XBRrx2sRMTo7Iy18G7KzgLsi4sxt1nvT53aQgB9GxFXbbOMrnWjrV8ApEfFEei/1sZll296KFem2vxQR2eKJpOGd2La9TXn3+e3lAeBISftDMtq3pHcAC4DhkvZL1zuzjc/fDXwu/Wy9pL5AM0kvcIs7gE9mjlUOkdQE3AucIqm3pAaSXfVSGoBl6UANk7dZ9lFJdWnmfYGn021/Ll0fSe9QMqK5WdncU3wbSQd9PQe4QX97VMCFEfGMpGnA7ZJeJdn9bthOE+cCMyRNBVqAz0XEHEn3p5e8/N/0uOIoYE7aU10PfCIiHpN0E/AEsAJ4uIzI/0gyqvfK9Gc201+Bh0gGg/1sOobmz0mONT6WDmG2EjilvG/HLOEBIczMMrz7bGaW4aJoZpbhomhmluGiaGaW4aJoZpbhomhmluGiaGaW8f8BrSeQBJ1/ZVkAAAAASUVORK5CYII=\n", 487 | "text/plain": [ 488 | "
" 489 | ] 490 | }, 491 | "metadata": {}, 492 | "output_type": "display_data" 493 | } 494 | ], 495 | "source": [ 496 | "metrics(y, y_pred)" 497 | ] 498 | } 499 | ], 500 | "metadata": { 501 | "kernelspec": { 502 | "display_name": "conda_tensorflow_p36", 503 | "language": "python", 504 | "name": "conda_tensorflow_p36" 505 | }, 506 | "language_info": { 507 | "codemirror_mode": { 508 | "name": "ipython", 509 | "version": 3 510 | }, 511 | "file_extension": ".py", 512 | "mimetype": "text/x-python", 513 | "name": "python", 514 | "nbconvert_exporter": "python", 515 | "pygments_lexer": "ipython3", 516 | "version": "3.6.6" 517 | } 518 | }, 519 | "nbformat": 4, 520 | "nbformat_minor": 2 521 | } 522 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | with open("requirements.txt") as fh: 7 | required = fh.read().splitlines() 8 | 9 | setuptools.setup( 10 | name="classitransformers", 11 | packages=['classitransformers','electra'], 12 | version="0.0.1", 13 | author="FMR LLC", 14 | author_email="classitransformers@fmr.com", 15 | description="An abstract library for implementing text classification tasks based on various transformers based language models", 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/fidelity/classitransformers", 19 | install_requires=required, 20 | python_requires='>=3.6', 21 | classifiers=[ 22 | "Programming Language :: Python :: 3.6", 23 | "License :: OSI Approved :: Apache Software License", 24 | "Operating System :: OS Independent", 25 | ], 26 | keywords='NLP language_models text_classification bert electra roberta distilbert albert', 27 | 28 | project_urls={ 29 | "Source": "https://github.com/fidelity/classitransformers" 30 | } 31 | ) 32 | 33 | -------------------------------------------------------------------------------- /wip: -------------------------------------------------------------------------------- 1 | The library is fully functional but we are looking to add these updates/features as next steps. 2 | 3 | 1. Add unit tests - tests & fixtures 4 | 2. Documentation 5 | 3. Register on pypi 6 | 4. Add deployment modules. After model is trained, codes for inferencing. 7 | 5. Tensorflow serving for Electra. --------------------------------------------------------------------------------