├── .gitignore ├── LICENSE ├── README.md ├── antu ├── __init__.py ├── io │ ├── __init__.py │ ├── configurators │ │ ├── .tmp_test_ini_configurator.ini.swp │ │ ├── __init__.py │ │ └── ini_configurator.py │ ├── dataset_readers │ │ ├── __init__.py │ │ └── dataset_reader.py │ ├── datasets │ │ ├── __init__.py │ │ ├── dataset.py │ │ └── single_task_dataset.py │ ├── ext_embedding_readers.py │ ├── fields │ │ ├── __init__.py │ │ ├── char_token_field.py │ │ ├── field.py │ │ ├── index_field.py │ │ ├── label_field.py │ │ ├── map_token_field.py │ │ ├── meta_field.py │ │ ├── raw_token_field.py │ │ ├── sequence_label_field.py │ │ ├── text_field.py │ │ ├── token_field.py │ │ └── tokenizer_field.py │ ├── instance.py │ ├── token_indexers │ │ ├── __init__.py │ │ ├── char_token_indexer.py │ │ ├── dynamic_token_indexer.py │ │ ├── sequence_token_indexer.py │ │ ├── single_id_token_indexer.py │ │ └── token_indexer.py │ └── vocabulary.py ├── nn │ ├── __init__.py │ └── dynet │ │ ├── __init__.py │ │ ├── attention │ │ ├── __init__.py │ │ ├── biaffine.py │ │ ├── biaffine_matrix.py │ │ ├── multi_head.py │ │ └── single.py │ │ ├── classifiers │ │ ├── __init__.py │ │ └── nn_classifier.py │ │ ├── embedding │ │ ├── __init__.py │ │ ├── bert.py │ │ ├── position.py │ │ ├── segment.py │ │ └── token.py │ │ ├── functional │ │ ├── __init__.py │ │ ├── gelu.py │ │ └── leaky_relu.py │ │ ├── init │ │ ├── __init__.py │ │ ├── init_wrap.py │ │ └── orthogonal_initializer.py │ │ ├── modules │ │ ├── BERT.py │ │ ├── __init__.py │ │ ├── attention_mechanism.py │ │ ├── dynet_model.py │ │ ├── feed_forward.py │ │ ├── graph_nn_unit.py │ │ ├── layer_norm.py │ │ ├── linear.py │ │ ├── perceptron.py │ │ ├── sublayer.py │ │ └── transformer.py │ │ ├── seq2seq_encoders │ │ ├── .rnn_builder.py.swp │ │ ├── __init__.py │ │ ├── rnn_builder.py │ │ └── seq2seq_encoder.py │ │ └── seq2vec_encoders │ │ ├── __init__.py │ │ └── char2word_embedder.py └── utils │ ├── __init__.py │ ├── case_sensitive_configurator.py │ ├── dual_channel_logger.py │ ├── padding_function.py │ └── top_k_indexes.py ├── doc ├── Makefile ├── make.bat └── source │ ├── api │ ├── antu.io.dataset_readers.rst │ ├── antu.io.datasets.rst │ ├── antu.io.fields.rst │ ├── antu.io.rst │ ├── antu.io.token_indexers.rst │ ├── antu.nn.dynet.rst │ ├── antu.nn.rst │ ├── antu.rst │ ├── modules.rst │ └── setup.rst │ ├── conf.py │ └── index.rst ├── examples └── dependency_parsing │ ├── conllu_reader.py │ └── train_parser.py ├── requirements.txt ├── setup.py └── test ├── io ├── configurators │ └── ini_configurator_test.py ├── fields │ └── text_field_test.py ├── instance_test.py ├── token_indexers │ ├── char_token_indexer_test.py │ └── single_id_token_indexer_test.py └── vocabulary_test.py └── nn └── dynet ├── modules ├── linear_test.py └── perceptron_test.py └── seq2seq_encoders └── rnn_builder_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AntU 2 | Universal data IO and neural network modules in NLP tasks. 3 | 4 | + **data IO** is an universal module in Natural Language Processing system and not based on any framework (like TensorFlow, PyTorch, MXNet, Dynet...). 5 | + **neural network** module contains the neural network structures commonly used in NLP tasks. 
We want to design commonly used structures for each neural network framework. We will continue to develop this module. 6 | 7 | 8 | 9 | # Requirements 10 | 11 | + Python>=3.6 12 | + bidict==0.17.5 13 | + numpy==1.15.4 14 | + numpydoc==0.8.0 15 | + overrides==1.9 16 | + pytest==4.0.2 17 | 18 | ##### If you need dynet neural network: 19 | 20 | + dynet>=2.0 21 | 22 | 23 | 24 | # Installing via pip 25 | 26 | ```bash 27 | pip install antu 28 | ``` 29 | 30 | 31 | 32 | # Resources 33 | 34 | + [Documentation](https://antu.readthedocs.io/en/latest/index.html) 35 | + [Source Code](https://github.com/AntNLP/antu) 36 | 37 | -------------------------------------------------------------------------------- /antu/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AntNLP/antu/3256ada0784401b9677d9568e81f3a8792eebee7/antu/__init__.py -------------------------------------------------------------------------------- /antu/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .vocabulary import Vocabulary 2 | from .instance import Instance 3 | -------------------------------------------------------------------------------- /antu/io/configurators/.tmp_test_ini_configurator.ini.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AntNLP/antu/3256ada0784401b9677d9568e81f3a8792eebee7/antu/io/configurators/.tmp_test_ini_configurator.ini.swp -------------------------------------------------------------------------------- /antu/io/configurators/__init__.py: -------------------------------------------------------------------------------- 1 | from .ini_configurator import IniConfigurator 2 | -------------------------------------------------------------------------------- /antu/io/configurators/ini_configurator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, TypeVar 2 | from ...utils.case_sensitive_configurator import CaseSensConfigParser 3 | import argparse 4 | import os 5 | import ast 6 | 7 | 8 | BaseObj = TypeVar("BaseObj", int, float, str, list, set, dict) 9 | BASEOBJ = {int, float, str, list, set, dict} 10 | 11 | 12 | class safe_var_sub(dict): 13 | """ Safe Variable Substitution """ 14 | 15 | def __miss__(self, key): 16 | raise RuntimeError('Attribute (%s) does not exist.' % (key)) 17 | 18 | 19 | def str_to_baseobj(s: str) -> BaseObj: 20 | """ 21 | Converts a string to the corresponding base python type value. 22 | 23 | Parameters 24 | ---------- 25 | s : ``str`` 26 | string like "123", "12.3", "[1, 2, 3]" ... 27 | 28 | Returns 29 | ------- 30 | res : ``BaseObj`` 31 | "123" -> int(123) 32 | "12.3" -> float(12.3) 33 | ... 34 | """ 35 | try: 36 | res = eval(s) 37 | # res = ast.literal_eval(s.format_map(vars())) 38 | except BaseException: 39 | return s 40 | if (s in globals() or s in locals()) and type(res) not in BASEOBJ: 41 | return s 42 | else: 43 | return res 44 | 45 | 46 | class IniConfigurator: 47 | """ 48 | Reads and stores the configuration in the ini Format file. 49 | 50 | Parameters 51 | ---------- 52 | config_file : ``str`` 53 | Path to the configuration file. 54 | extra_args : ``Dict[str, str]``, optional (default=``dict()``) 55 | The configuration of the command line input. 
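    Examples
    --------
    A minimal usage sketch; the file name ``model.ini`` and its keys are only
    illustrative, and note that the implementation expects ``extra_args`` as a flat
    ``['--key', 'value']`` list (such as the unparsed remainder returned by
    ``argparse.ArgumentParser.parse_known_args``) rather than a dict. Given a file
    containing ``[Model]``, ``lr = 0.001`` and ``layers = [400, 400]``, every key
    becomes a read-only attribute whose value is parsed into a basic Python type::

        cfg = IniConfigurator('model.ini', extra_args=['--lr', '0.01'])
        cfg.lr      # 0.01  (float, overridden by the command-line pair)
        cfg.layers  # [400, 400]  (list)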
56 | """ 57 | 58 | def __init__(self, 59 | config_file: str, 60 | extra_args: Dict[str, str] = dict()) -> None: 61 | 62 | config = CaseSensConfigParser() 63 | config.read(config_file) 64 | if extra_args: 65 | extra_args = ( 66 | dict([(k[2:], v) 67 | for k, v in zip(extra_args[0::2], extra_args[1::2])])) 68 | attr_name = set() 69 | for section in config.sections(): 70 | for k, v in config.items(section): 71 | if k in extra_args: 72 | v = type(v)(extra_args[k]) 73 | config.set(section, k, v) 74 | 75 | if k in attr_name: 76 | raise RuntimeError( 77 | 'Attribute (%s) has already appeared.' % (k)) 78 | else: 79 | attr_name.update(k) 80 | super(IniConfigurator, self).__setattr__(k, str_to_baseobj(v)) 81 | 82 | with open(config_file, 'w') as fout: 83 | config.write(fout) 84 | 85 | print('Loaded config file sucessfully.') 86 | for section in config.sections(): 87 | for k, v in config.items(section): 88 | print(k, v) 89 | 90 | def __setattr__(self, name, value): 91 | raise RuntimeError('Try to set the attribute (%s) of the constant ' 92 | 'class (%s).' % (name, self.__class__.__name__)) 93 | -------------------------------------------------------------------------------- /antu/io/dataset_readers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AntNLP/antu/3256ada0784401b9677d9568e81f3a8792eebee7/antu/io/dataset_readers/__init__.py -------------------------------------------------------------------------------- /antu/io/dataset_readers/dataset_reader.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | from abc import ABCMeta, abstractmethod 3 | from antu.io.instance import Instance 4 | 5 | 6 | class DatasetReader(metaclass=ABCMeta): 7 | 8 | @abstractmethod 9 | def read(self, file_path: str) -> List[Instance]: 10 | pass 11 | 12 | @abstractmethod 13 | def input_to_instance(self, inputs: str) -> Instance: 14 | pass -------------------------------------------------------------------------------- /antu/io/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AntNLP/antu/3256ada0784401b9677d9568e81f3a8792eebee7/antu/io/datasets/__init__.py -------------------------------------------------------------------------------- /antu/io/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | from abc import ABCMeta, abstractmethod 3 | from antu.io.vocabulary import Vocabulary 4 | from antu.io.instance import Instance 5 | 6 | 7 | class Dataset(metaclass=ABCMeta): 8 | 9 | vocabulary_set: Vocabulary = {} 10 | datasets: Dict[str, List[Instance]] = {} 11 | 12 | @abstractmethod 13 | def build_dataset(): 14 | pass 15 | 16 | 17 | -------------------------------------------------------------------------------- /antu/io/datasets/single_task_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Callable, Union, Set 2 | from overrides import overrides 3 | from antu.io.vocabulary import Vocabulary 4 | from antu.io.instance import Instance 5 | from antu.io.datasets.dataset import Dataset 6 | from antu.io.dataset_readers.dataset_reader import DatasetReader 7 | from antu.utils.padding_function import shadow_padding 8 | import random 9 | from itertools import cycle 10 | 11 | 12 | class DatasetSetting: 13 | 14 | def __init__(self, file_path: str, 
is_train: bool): 15 | self.file_path = file_path 16 | self.is_train = is_train 17 | 18 | 19 | class SingleTaskDataset: 20 | 21 | def __init__( 22 | self, 23 | vocabulary: Vocabulary, 24 | datasets_settings: Dict[str, DatasetSetting], 25 | reader: DatasetReader): 26 | self.vocabulary = vocabulary 27 | self.datasets_settings = datasets_settings 28 | self.datasets = dict() 29 | self.reader = reader 30 | 31 | def build_dataset( 32 | self, 33 | counters: Dict[str, Dict[str, int]], 34 | min_count: Union[int, Dict[str, int]] = dict(), 35 | no_pad_namespace: Set[str] = set(), 36 | no_unk_namespace: Set[str] = set()) -> None: 37 | 38 | for name, setting in self.datasets_settings.items(): 39 | self.datasets[name] = self.reader.read(setting.file_path) 40 | if setting.is_train: 41 | for ins in self.datasets[name]: 42 | ins.count_vocab_items(counters) 43 | self.vocabulary.extend_from_counter( 44 | counters, min_count, no_pad_namespace, no_unk_namespace) 45 | for name in self.datasets: 46 | for ins in self.datasets[name]: 47 | ins.index_fields(self.vocabulary) 48 | 49 | def get_dataset(self, name: str) -> List[Instance]: 50 | return self.datasets[name] 51 | 52 | def get_batches( 53 | self, 54 | name: str, 55 | size: int, 56 | ordered: bool=False, 57 | cmp: Callable[[Instance, Instance], int]=None, 58 | is_infinite: bool=False) -> List[List[int]]: 59 | #print(self.datasets[name]) 60 | if ordered: self.datasets[name].sort(key=cmp) 61 | 62 | num = len(self.datasets[name]) # Number of Instances 63 | result = [] 64 | for beg in range(0, num, size): 65 | ins_batch = self.datasets[name][beg: beg+size] 66 | idx_batch = [ins.index_fields(self.vocabulary) for ins in ins_batch] 67 | indexes, masks = shadow_padding(idx_batch, self.vocabulary) 68 | yield indexes, masks 69 | result.append((indexes, masks)) 70 | 71 | while is_infinite: 72 | random.shuffle(result) 73 | for indexes, masks in result: 74 | yield indexes, masks 75 | 76 | # def build_batches(self, ) 77 | -------------------------------------------------------------------------------- /antu/io/ext_embedding_readers.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | import numpy 3 | import os 4 | import gzip 5 | 6 | 7 | def glove_reader(file_path: str, only_word: bool = False) -> Tuple[List[str], List[List[float]]]: 8 | if os.path.isfile(file_path): 9 | word = [] 10 | vector = [] 11 | if '.gz' != file_path[-3:]: 12 | with open(file_path, 'r') as fp: 13 | for w in fp: 14 | w_list = w.strip().split(' ') 15 | word.append(w_list[0]) 16 | if not only_word: 17 | vector.append(list(map(float, w_list[1:]))) 18 | return word if only_word else (word, vector) 19 | else: 20 | with gzip.open(file_path, 'rt') as fp: 21 | for w in fp: 22 | w_list = w.strip().split(' ') 23 | word.append(w_list[0]) 24 | if not only_word: 25 | vector.append(list(map(float, w_list[1:]))) 26 | return word if only_word else (word, vector) 27 | else: 28 | raise RuntimeError("Glove file (%s) does not exist.") 29 | 30 | 31 | def fasttext_reader(file_path: str, only_word: bool = False) -> Tuple[List[str], List[List[float]]]: 32 | if os.path.isfile(file_path): 33 | word = [] 34 | vector = [] 35 | if '.gz' != file_path[-3:]: 36 | with open(file_path, 'r') as fp: 37 | w_dim = int(fp.readline().strip().split(' ')[1]) 38 | for w in fp: 39 | w_list = w.strip().split(' ') 40 | if len(w_list)-1 != w_dim: 41 | continue 42 | word.append(w_list[0]) 43 | if not only_word: 44 | vector.append(list(map(float, w_list[1:]))) 45 | return word if 
only_word else (word, vector) 46 | else: 47 | with gzip.open(file_path, 'rt') as fp: 48 | w_dim = int(fp.readline().strip().split(' ')[1]) 49 | for w in fp: 50 | w_list = w.strip().split(' ') 51 | if len(w_list)-1 != w_dim: 52 | continue 53 | word.append(w_list[0]) 54 | if not only_word: 55 | vector.append(list(map(float, w_list[1:]))) 56 | return word if only_word else (word, vector) 57 | else: 58 | raise RuntimeError("Fasttext file (%s) does not exist." % file_path) 59 | -------------------------------------------------------------------------------- /antu/io/fields/__init__.py: -------------------------------------------------------------------------------- 1 | from .field import Field 2 | from .index_field import IndexField 3 | from .label_field import LabelField 4 | from .meta_field import MetaField 5 | from .sequence_label_field import SequenceLabelField 6 | from .text_field import TextField 7 | from .char_token_field import CharTokenField 8 | from .map_token_field import MapTokenField 9 | from .raw_token_field import RawTokenField 10 | from .tokenizer_field import TokenizerField 11 | from .token_field import TokenField 12 | -------------------------------------------------------------------------------- /antu/io/fields/char_token_field.py: -------------------------------------------------------------------------------- 1 | from . import Field 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class CharTokenField(Field): 8 | """Char token field: splits each token into characters 9 | """ 10 | 11 | def __init__(self, namespace, source_key): 12 | """This function sets the namespace name and the dataset source key 13 | 14 | Arguments: 15 | namespace {str} -- namespace name 16 | source_key {str} -- indicates the key in the text data 17 | """ 18 | 19 | self.namespace = namespace 20 | self.source_key = source_key 21 | super().__init__() 22 | 23 | def count_vocab_items(self, counter, sentences): 24 | """This function counts each token's characters in sentences, 25 | then updates the counter 26 | 27 | Arguments: 28 | counter {dict} -- counter 29 | sentences {list} -- text content after preprocessing 30 | """ 31 | 32 | for sentence in sentences: 33 | for token in sentence[self.source_key]: 34 | for char in token: 35 | counter[self.namespace][str(char)] += 1 36 | 37 | logger.info( 38 | "Count sentences {} to update counter namespace {} successfully.". 39 | format(self.source_key, self.namespace)) 40 | 41 | def index(self, instance, vocab, sentences): 42 | """This function indexes tokens using the vocabulary, 43 | then updates the instance 44 | 45 | Arguments: 46 | instance {dict} -- numerical representation of text data 47 | vocab {Vocabulary} -- vocabulary 48 | sentences {list} -- text content after preprocessing 49 | """ 50 | 51 | for sentence in sentences: 52 | token_num_repr = [] 53 | for token in sentence[self.source_key]: 54 | token_num_repr.append([ 55 | vocab.get_token_index(char, self.namespace) 56 | for char in token 57 | ]) 58 | instance[self.namespace].append(token_num_repr) 59 | 60 | logger.info( 61 | "Index sentences {} to construct instance namespace {} successfully." 62 | .format(self.source_key, self.namespace)) 63 | -------------------------------------------------------------------------------- /antu/io/fields/field.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from .. 
import Vocabulary 5 | 6 | 7 | class Field(metaclass=ABCMeta): 8 | """ 9 | A ``Field`` is an ingredient of a data instance. In most NLP tasks, ``Field`` 10 | stores data of string types. It contains one or more indexers that map string 11 | data to the corresponding index. Data instances are collections of fields. 12 | """ 13 | @abstractmethod 14 | def count_vocab_items(self, counter: Dict[str, Dict[str, int]]) -> None: 15 | """ 16 | We count the number of strings if the string needs to be mapped to one 17 | or more integers. You can pass directly if there is no string that needs 18 | to be mapped. 19 | 20 | Parameters 21 | ---------- 22 | counter : ``Dict[str, Dict[str, int]]`` 23 | ``counter`` is used to count the number of each item. The first key 24 | represents the namespace of the vocabulary, and the second key represents 25 | the string of the item. 26 | """ 27 | pass 28 | 29 | @abstractmethod 30 | def index(self, vocab: Vocabulary) -> None: 31 | """ 32 | Gets one or more index mappings for each element in the Field. 33 | 34 | Parameters 35 | ---------- 36 | vocab : ``Vocabulary`` 37 | ``vocab`` is used to get the index of each item. 38 | """ 39 | pass 40 | -------------------------------------------------------------------------------- /antu/io/fields/index_field.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Iterator 2 | from overrides import overrides 3 | from ..token_indexers import TokenIndexer 4 | from .. import Vocabulary 5 | from . import Field 6 | 7 | 8 | class IndexField(Field): 9 | """ 10 | A ``IndexField`` is an integer field, and we can use it to store data ID. 11 | 12 | Parameters 13 | ---------- 14 | name : ``str`` 15 | Field name. This is necessary and must be unique (not the same as other 16 | field names). 17 | tokens : ``List[str]`` 18 | Field content that contains a list of string. 19 | """ 20 | 21 | def __init__(self, 22 | name: str, 23 | tokens: List[str]): 24 | self.name = name 25 | self.tokens = [int(x) for x in tokens] 26 | 27 | def __iter__(self) -> Iterator[str]: 28 | return iter(self.tokens) 29 | 30 | def __getitem__(self, idx: int) -> str: 31 | return self.tokens[idx] 32 | 33 | def __len__(self) -> int: 34 | return len(self.tokens) 35 | 36 | def __str__(self) -> str: 37 | return '{}: [{}]'.format(self.name, ', '.join(self.tokens)) 38 | 39 | @overrides 40 | def count_vocab_items(self, counters: Dict[str, Dict[str, int]]) -> None: 41 | """ 42 | ``IndexField`` doesn't need index operation. 43 | """ 44 | pass 45 | 46 | @overrides 47 | def index(self, vocab: Vocabulary) -> None: 48 | """ 49 | ``IndexField`` doesn't need index operation. 50 | """ 51 | # self.indexes = dict() 52 | # self.indexes[self.name] = self.tokens 53 | self.indexes = self.tokens 54 | -------------------------------------------------------------------------------- /antu/io/fields/label_field.py: -------------------------------------------------------------------------------- 1 | from typing import List, Iterator, Dict 2 | 3 | from overrides import overrides 4 | 5 | from ..token_indexers import TokenIndexer 6 | from .. import Vocabulary 7 | from . 
import Field 8 | 9 | 10 | class LabelField(Field): 11 | 12 | def __init__(self, 13 | name: str, 14 | tokens: str, 15 | indexers: List[TokenIndexer]): 16 | self.name = name 17 | self.tokens = [tokens] 18 | self.indexers = indexers 19 | 20 | def __iter__(self) -> Iterator[str]: 21 | return iter(self.tokens) 22 | 23 | def __getitem__(self, idx: int) -> str: 24 | return self.tokens[idx] 25 | 26 | def __len__(self) -> int: 27 | return len(self.tokens[0]) 28 | 29 | def __str__(self) -> str: 30 | return '{}: {}'.format(self.name, self.tokens[0]) 31 | 32 | @overrides 33 | def count_vocab_items(self, counters: Dict[str, Dict[str, int]]) -> None: 34 | for idxer in self.indexers: 35 | idxer.count_vocab_items(self.tokens[0], counters) 36 | 37 | @overrides 38 | def index(self, vocab: Vocabulary) -> None: 39 | self.indexes = {} 40 | for idxer in self.indexers: 41 | self.indexes.update(idxer.tokens_to_indices(self.tokens, vocab)) 42 | -------------------------------------------------------------------------------- /antu/io/fields/map_token_field.py: -------------------------------------------------------------------------------- 1 | from . import Field 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class MapTokenField(Field): 8 | """Map token field: processes mappings (dicts) of tokens 9 | """ 10 | 11 | def __init__(self, namespace, source_key): 12 | """This function sets the namespace name and the dataset source key 13 | 14 | Arguments: 15 | namespace {str} -- namespace 16 | source_key {str} -- indicates the key in the text data 17 | """ 18 | 19 | self.namespace = namespace 20 | self.source_key = source_key 21 | super().__init__() 22 | 23 | def count_vocab_items(self, counter, sentences): 24 | """This function counts each dict's values in sentences, 25 | then updates the counter; each sentence entry is a dict 26 | 27 | Arguments: 28 | counter {dict} -- counter 29 | sentences {list} -- text content after preprocessing, a list of dicts 30 | """ 31 | 32 | for sentence in sentences: 33 | for value in sentence[self.source_key].values(): 34 | counter[self.namespace][str(value)] += 1 35 | 36 | logger.info( 37 | "Count sentences {} to update counter namespace {} successfully.". 38 | format(self.source_key, self.namespace)) 39 | 40 | def index(self, instance, vocab, sentences): 41 | """This function indexes tokens using the vocabulary, then updates the instance 42 | 43 | Arguments: 44 | instance {dict} -- numerical representation of text data 45 | vocab {Vocabulary} -- vocabulary 46 | sentences {list} -- text content after preprocessing 47 | """ 48 | 49 | for sentence in sentences: 50 | instance[self.namespace].append({ 51 | key: vocab.get_token_index(value, self.namespace) 52 | for key, value in sentence[self.source_key].items() 53 | }) 54 | 55 | logger.info( 56 | "Index sentences {} to construct instance namespace {} successfully." 57 | .format(self.source_key, self.namespace)) 58 | -------------------------------------------------------------------------------- /antu/io/fields/meta_field.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Iterator 2 | from overrides import overrides 3 | from ..token_indexers import TokenIndexer 4 | from .. import Vocabulary 5 | from . import Field 6 | 7 | 8 | class MetaField(Field): 9 | """ 10 | A ``MetaField`` stores metadata that is kept as-is and does not need indexing, such as a data ID. 11 | 12 | Parameters 13 | ---------- 14 | name : ``str`` 15 | Field name. This is necessary and must be unique (not the same as other 16 | field names). 
17 | tokens : ``List[str]`` 18 | Field content that contains a list of strings. 19 | """ 20 | 21 | def __init__(self, 22 | name: str, 23 | tokens: List[str]): 24 | self.name = name 25 | self.tokens = tokens 26 | 27 | def __iter__(self) -> Iterator[str]: 28 | return iter(self.tokens) 29 | 30 | def __getitem__(self, idx: int) -> str: 31 | return self.tokens[idx] 32 | 33 | def __len__(self) -> int: 34 | return len(self.tokens) 35 | 36 | def __str__(self) -> str: 37 | return '{}: [{}]'.format(self.name, (self.tokens)) 38 | 39 | @overrides 40 | def count_vocab_items(self, counters: Dict[str, Dict[str, int]]) -> None: 41 | """ 42 | ``MetaField`` doesn't need the counting operation. 43 | """ 44 | pass 45 | 46 | @overrides 47 | def index(self, vocab: Vocabulary) -> None: 48 | """ 49 | ``MetaField`` doesn't need the indexing operation; the raw tokens are kept as the indexes. 50 | """ 51 | # self.indexes = dict() 52 | # self.indexes[self.name] = self.tokens 53 | self.indexes = self.tokens 54 | -------------------------------------------------------------------------------- /antu/io/fields/raw_token_field.py: -------------------------------------------------------------------------------- 1 | from . import Field 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class RawTokenField(Field): 8 | """This class preserves the raw text of tokens 9 | """ 10 | 11 | def __init__(self, namespace, source_key): 12 | """This function sets the namespace name and the dataset source key 13 | 14 | Arguments: 15 | namespace {str} -- namespace 16 | source_key {str} -- indicates the key in the text data 17 | """ 18 | 19 | self.namespace = str(namespace) 20 | self.source_key = str(source_key) 21 | super().__init__() 22 | 23 | def count_vocab_items(self, counter, sentences): 24 | """ `RawTokenField` doesn't update the counter 25 | 26 | Arguments: 27 | counter {dict} -- counter 28 | sentences {list} -- text content after preprocessing 29 | """ 30 | 31 | pass 32 | 33 | def index(self, instance, vocab, sentences): 34 | """This function doesn't use the vocabulary; 35 | it preserves the raw text of sentences (tokens) 36 | 37 | Arguments: 38 | instance {dict} -- numerical representation of text data 39 | vocab {Vocabulary} -- vocabulary 40 | sentences {list} -- text content after preprocessing 41 | """ 42 | 43 | for sentence in sentences: 44 | instance[self.namespace].append( 45 | [token for token in sentence[self.source_key]]) 46 | 47 | logger.info( 48 | "Index sentences {} to construct instance namespace {} successfully." 49 | .format(self.source_key, self.namespace)) 50 | -------------------------------------------------------------------------------- /antu/io/fields/sequence_label_field.py: -------------------------------------------------------------------------------- 1 | from typing import List, Iterator, Dict 2 | from overrides import overrides 3 | from ..token_indexers import TokenIndexer 4 | from .. import Vocabulary 5 | from . 
import Field 6 | 7 | 8 | class SequenceLabelField(Field): 9 | 10 | def __init__(self, 11 | name: str, 12 | tokens: List[str], 13 | indexers: List[TokenIndexer]): 14 | self.name = name 15 | self.tokens = tokens 16 | self.indexers = indexers 17 | 18 | def __iter__(self) -> Iterator[str]: 19 | return iter(self.tokens) 20 | 21 | def __getitem__(self, idx: int) -> str: 22 | return self.tokens[idx] 23 | 24 | def __len__(self) -> int: 25 | return len(self.tokens) 26 | 27 | def __str__(self) -> str: 28 | return '{}: [{}]'.format(self.name, ', '.join(self.tokens)) 29 | 30 | @overrides 31 | def count_vocab_items(self, counters: Dict[str, Dict[str, int]]) -> None: 32 | for idxer in self.indexers: 33 | for token in self.tokens: 34 | idxer.count_vocab_items(token, counters) 35 | 36 | @overrides 37 | def index(self, vocab: Vocabulary) -> None: 38 | self.indexes = {} 39 | for idxer in self.indexers: 40 | self.indexes.update(idxer.tokens_to_indices(self.tokens, vocab)) 41 | -------------------------------------------------------------------------------- /antu/io/fields/text_field.py: -------------------------------------------------------------------------------- 1 | from typing import List, Iterator, Dict 2 | from overrides import overrides 3 | from ..token_indexers import TokenIndexer 4 | from .. import Vocabulary 5 | from . import Field 6 | 7 | 8 | class TextField(Field): 9 | """ 10 | A ``TextField`` is a data field that is commonly used in NLP tasks, and we 11 | can use it to store text sequences such as sentences, paragraphs, POS tags, 12 | and so on. 13 | 14 | Parameters 15 | ---------- 16 | name : ``str`` 17 | Field name. This is necessary and must be unique (not the same as other 18 | field names). 19 | tokens : ``List[str]`` 20 | Field content that contains a list of string. 21 | indexers : ``List[TokenIndexer]``, optional (default=``list()``) 22 | Indexer list that defines the vocabularies associated with the field. 23 | """ 24 | 25 | def __init__(self, 26 | name: str, 27 | tokens: List[str], 28 | indexers: List[TokenIndexer] = list()): 29 | self.name = name 30 | self.tokens = tokens 31 | self.indexers = indexers 32 | 33 | def __iter__(self) -> Iterator[str]: 34 | return iter(self.tokens) 35 | 36 | def __getitem__(self, idx: int) -> str: 37 | return self.tokens[idx] 38 | 39 | def __len__(self) -> int: 40 | return len(self.tokens) 41 | 42 | def __str__(self) -> str: 43 | return '{}: [{}]'.format(self.name, ', '.join(self.tokens)) 44 | 45 | @overrides 46 | def count_vocab_items(self, counters: Dict[str, Dict[str, int]]) -> None: 47 | """ 48 | We count the number of strings if the string needs to be counted to some 49 | counters. You can pass directly if there is no string that needs 50 | to be counted. 51 | 52 | Parameters 53 | ---------- 54 | counters : ``Dict[str, Dict[str, int]]`` 55 | Element statistics for datasets. if field indexers indicate that 56 | this field is related to some counters, we use field content to 57 | update the counters. 58 | """ 59 | for idxer in self.indexers: 60 | for token in self.tokens: 61 | idxer.count_vocab_items(token, counters) 62 | 63 | @overrides 64 | def index(self, vocab: Vocabulary) -> None: 65 | """ 66 | Gets one or more index mappings for each element in the Field. 67 | 68 | Parameters 69 | ---------- 70 | vocab : ``Vocabulary`` 71 | ``vocab`` is used to get the index of each item. 
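        Examples
        --------
        A minimal sketch; ``vocab`` is assumed to be an ``antu.io.Vocabulary``
        whose ``'word'`` namespace has already been built from the counters
        (e.g. via ``extend_from_counter``)::

            from collections import Counter
            from antu.io.fields import TextField
            from antu.io.token_indexers import SingleIdTokenIndexer

            field = TextField('sent', ['This', 'is', 'a', 'test'],
                              [SingleIdTokenIndexer(['word'])])
            counters = {'word': Counter()}
            field.count_vocab_items(counters)  # counters['word']['This'] == 1
            field.index(vocab)                 # field.indexes == {'word': [i1, i2, i3, i4]}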
72 | """ 73 | self.indexes = {} 74 | for idxer in self.indexers: 75 | self.indexes.update(idxer.tokens_to_indices(self.tokens, vocab)) 76 | -------------------------------------------------------------------------------- /antu/io/fields/token_field.py: -------------------------------------------------------------------------------- 1 | from . import Field 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class TokenField(Field): 8 | """Token field: regards each sentence as a list of tokens 9 | """ 10 | 11 | def __init__(self, namespace, source_key): 12 | """This function sets the namespace name and the dataset source key 13 | 14 | Arguments: 15 | namespace {str} -- namespace 16 | source_key {str} -- indicates the key in the text data 17 | """ 18 | 19 | self.namespace = str(namespace) 20 | self.source_key = str(source_key) 21 | super().__init__() 22 | 23 | def count_vocab_items(self, counter, sentences): 24 | """This function counts tokens in sentences, 25 | then updates the counter 26 | 27 | Arguments: 28 | counter {dict} -- counter 29 | sentences {list} -- text content after preprocessing 30 | """ 31 | 32 | for sentence in sentences: 33 | for token in sentence[self.source_key]: 34 | counter[self.namespace][str(token)] += 1 35 | 36 | logger.info( 37 | "Count sentences {} to update counter namespace {} successfully.". 38 | format(self.source_key, self.namespace)) 39 | 40 | def index(self, instance, vocab, sentences): 41 | """This function indexes tokens using the vocabulary, 42 | then updates the instance 43 | 44 | Arguments: 45 | instance {dict} -- numerical representation of text data 46 | vocab {Vocabulary} -- vocabulary 47 | sentences {list} -- text content after preprocessing 48 | """ 49 | 50 | for sentence in sentences: 51 | instance[self.namespace].append([ 52 | vocab.get_token_index(token, self.namespace) 53 | for token in sentence[self.source_key] 54 | ]) 55 | 56 | logger.info( 57 | "Index sentences {} to construct instance namespace {} successfully." 58 | .format(self.source_key, self.namespace)) 59 | -------------------------------------------------------------------------------- /antu/io/fields/tokenizer_field.py: -------------------------------------------------------------------------------- 1 | from . 
import Field 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class TokenizerField(Field): 8 | """This class using tokenizer to tokenize 9 | """ 10 | 11 | def __init__(self, namespace, source_key, tokenizer): 12 | """This function set namespace name and dataset source key 13 | 14 | Arguments: 15 | namespace {str} -- namespace 16 | source_key {str} -- indicate key in text data 17 | tokenizer {Callable} -- tokenizer function 18 | """ 19 | 20 | self.namespace = str(namespace) 21 | self.source_key = str(source_key) 22 | self.tokenizer = tokenizer 23 | super().__init__() 24 | 25 | def count_vocab_items(self, counter, sentences): 26 | """ `TokenizerField` doesn't update counter 27 | 28 | Arguments: 29 | counter {dict} -- counter 30 | sentences {list} -- text content after preprocessing 31 | """ 32 | 33 | pass 34 | 35 | def index(self, instance, vocab, sentences): 36 | """This function indexes token using vocabulary, 37 | then update instance 38 | 39 | Arguments: 40 | instance {dict} -- numerical represenration of text data 41 | vocab {Vocabulary} -- vocabulary 42 | sentences {list} -- text content after preprocessing 43 | """ 44 | 45 | for sentence in sentences: 46 | tokenized_tokens = [vocab.get_token_index('[CLS]', self.namespace)] 47 | token_index = [] 48 | index = 1 49 | for token in sentence[self.source_key]: 50 | tokenized_token = self.tokenizer(token) 51 | token_index.append(index) 52 | index += len(tokenized_token) 53 | tokenized_tokens.extend( 54 | [vocab.get_token_index(item, self.namespace) for item in tokenized_token]) 55 | tokenized_tokens.append( 56 | vocab.get_token_index('[SEP]', self.namespace)) 57 | instance[self.namespace].append(tokenized_tokens) 58 | instance[self.namespace + '_index'].append(token_index) 59 | 60 | logger.info("Index sentences {} to construct instance namespace {} successfully.".format( 61 | self.source_key, self.namespace)) 62 | -------------------------------------------------------------------------------- /antu/io/instance.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, MutableMapping, Mapping, TypeVar, List, Set 2 | 3 | from . import Vocabulary 4 | from .fields import Field 5 | 6 | Indices = TypeVar("Indices", List[int], List[List[int]]) 7 | 8 | 9 | class Instance(Mapping[str, Field]): 10 | """ 11 | An ``Instance`` is a collection (list) of multiple data fields. 12 | 13 | Parameters 14 | ---------- 15 | fields : ``List[Field]``, optional (default=``None``) 16 | A list of multiple data fields. 17 | """ 18 | 19 | def __init__(self, fields: List[Field] = None) -> None: 20 | self.fields = fields 21 | self._fields_dict = {} 22 | for field in fields: 23 | self._fields_dict[field.name] = field 24 | self.indexed = False # Indicates whether the instance has been indexed 25 | 26 | def __getitem__(self, key: str) -> Field: 27 | return self._fields_dict[key] 28 | 29 | def __iter__(self): 30 | return iter(self.fields) 31 | 32 | def __len__(self) -> int: 33 | return len(self.fields) 34 | 35 | def add_field(self, field: Field) -> None: 36 | """ 37 | Add the field to the existing ``Instance``. 38 | 39 | Parameters 40 | ---------- 41 | field : ``Field`` 42 | Which field needs to be added. 
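        For example (the field name ``'sent_id'`` is only illustrative)::

            ins.add_field(MetaField('sent_id', ['42']))

        If the instance has already been indexed, the new field is indexed the
        next time ``index_fields`` is called.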
43 | """ 44 | self.fields.append(field) 45 | self._fields_dict[field.name] = field 46 | self.indexed = False  # force re-indexing so the new field is indexed on the next index_fields() call 47 | 48 | def count_vocab_items(self, counter: Dict[str, Dict[str, int]]) -> None: 49 | """ 50 | Increments counts in the given ``counter`` for all of the vocabulary 51 | items in all of the ``Fields`` in this ``Instance``. 52 | 53 | Parameters 54 | ---------- 55 | counter : ``Dict[str, Dict[str, int]]`` 56 | We count the number of strings if the string needs to be counted to 57 | some counters. 58 | """ 59 | for field in self.fields: 60 | field.count_vocab_items(counter) 61 | 62 | def index_fields(self, vocab: Vocabulary) -> Dict[str, Dict[str, Indices]]: 63 | """ 64 | Indexes all fields in this ``Instance`` using the provided ``Vocabulary``. 65 | This `mutates` the current object; it does not return a new ``Instance``. 66 | A ``DataIterator`` will call this on each pass through a dataset; we use the ``indexed`` 67 | flag to make sure that indexing only happens once. 68 | This means that if for some reason you modify your vocabulary after you've 69 | indexed your instances, you might get unexpected behavior. 70 | 71 | Parameters 72 | ---------- 73 | vocab : ``Vocabulary`` 74 | ``vocab`` is used to get the index of each item. 75 | 76 | Returns 77 | ------- 78 | res : ``Dict[str, Dict[str, Indices]]`` 79 | Returns the Indices corresponding to the instance. The first key is 80 | the field name and the second key is the vocabulary name. 81 | """ 82 | if not self.indexed: 83 | self.indexed = True 84 | for field in self.fields: 85 | field.index(vocab) 86 | res = {} 87 | for field in self.fields: 88 | res[field.name] = field.indexes 89 | return res 90 | 91 | def dynamic_index_fields(self, vocab: Vocabulary, dynamic_fields: Set[str]) -> Dict[str, Dict[str, Indices]]: 92 | """ 93 | Indexes all fields in this ``Instance`` using the provided ``Vocabulary``. 94 | This `mutates` the current object; it does not return a new ``Instance``. 95 | A ``DataIterator`` will call this on each pass through a dataset; we use the ``indexed`` 96 | flag to make sure that indexing only happens once. 97 | This means that if for some reason you modify your vocabulary after you've 98 | indexed your instances, you might get unexpected behavior. 99 | 100 | Parameters 101 | ---------- 102 | vocab : ``Vocabulary`` 103 | ``vocab`` is used to get the index of each item. dynamic_fields : ``Set[str]`` Names of the fields that are (re-)indexed on every call. 104 | 105 | Returns 106 | ------- 107 | res : ``Dict[str, Dict[str, Indices]]`` 108 | Returns the Indices corresponding to the instance. The first key is 109 | the field name and the second key is the vocabulary name. 
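        Examples
        --------
        A sketch; the field name ``'oracle'`` is only illustrative::

            # 'oracle' is re-indexed on every call, while the remaining fields
            # keep the indexes computed on the first pass.
            res = ins.dynamic_index_fields(vocab, {'oracle'})
            # each entry of res maps a field name to its {vocabulary_name: indices} dict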
110 | """ 111 | if not self.indexed: 112 | self.indexed = True 113 | for field in self.fields: 114 | if field.name not in dynamic_fields: 115 | field.index(vocab) 116 | 117 | res = {} 118 | for field in self.fields: 119 | if field.name in dynamic_fields: 120 | field.index(vocab) 121 | res[field.name] = field.indexes 122 | return res 123 | -------------------------------------------------------------------------------- /antu/io/token_indexers/__init__.py: -------------------------------------------------------------------------------- 1 | from .token_indexer import TokenIndexer 2 | from .char_token_indexer import CharTokenIndexer 3 | from .dynamic_token_indexer import DynamicTokenIndexer 4 | from .sequence_token_indexer import SequenceTokenIndexer 5 | from .single_id_token_indexer import SingleIdTokenIndexer 6 | -------------------------------------------------------------------------------- /antu/io/token_indexers/char_token_indexer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Callable, TypeVar 2 | from overrides import overrides 3 | from .. import Vocabulary 4 | from . import TokenIndexer 5 | Indices = TypeVar("Indices", List[int], List[List[int]]) 6 | 7 | 8 | class CharTokenIndexer(TokenIndexer): 9 | """ 10 | A ``CharTokenIndexer`` determines how string token get represented as 11 | arrays of list of character indices in a model. 12 | 13 | Parameters 14 | ---------- 15 | related_vocabs : ``List[str]`` 16 | Which vocabularies are related to the indexer. 17 | transform : ``Callable[[str,], str]``, optional (default=``lambda x:x``) 18 | What changes need to be made to the token when counting or indexing. 19 | Commonly used are lowercase transformation functions. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | related_vocabs: List[str], 25 | transform: Callable[[str, ], str] = lambda x: x) -> None: 26 | self.related_vocabs = related_vocabs 27 | self.transform = transform 28 | 29 | @overrides 30 | def count_vocab_items( 31 | self, 32 | token: str, 33 | counters: Dict[str, Dict[str, int]]) -> None: 34 | """ 35 | Each character in the token is counted directly as an element. 36 | 37 | Parameters 38 | ---------- 39 | counter : ``Dict[str, Dict[str, int]]`` 40 | We count the number of strings if the string needs to be counted to 41 | some counters. 42 | """ 43 | for vocab_name in self.related_vocabs: 44 | if vocab_name in counters: 45 | for ch in token: 46 | counters[vocab_name][self.transform(ch)] += 1 47 | 48 | @overrides 49 | def tokens_to_indices( 50 | self, 51 | tokens: List[str], 52 | vocab: Vocabulary) -> Dict[str, List[List[int]]]: 53 | """ 54 | Takes a list of tokens and converts them to one or more sets of indices. 55 | During the indexing process, each token item corresponds to a list of 56 | index in the vocabulary. 57 | 58 | Parameters 59 | ---------- 60 | vocab : ``Vocabulary`` 61 | ``vocab`` is used to get the index of each item. 
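        Returns
        -------
        res : ``Dict[str, List[List[int]]]``
            One list of character indices per token, e.g. if the tokens are
            ['ab', 'c'] and the character indices are a:3, b:4, c:5, the result
            will be {'vocab_name' : [[3, 4], [5]]}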
62 | """ 63 | res = {} 64 | for vocab_name in self.related_vocabs: 65 | index_list = [] 66 | 67 | for token in tokens: 68 | index_list.append( 69 | [vocab.get_token_index(self.transform(ch), vocab_name) 70 | for ch in token]) 71 | res[vocab_name] = index_list 72 | return res 73 | -------------------------------------------------------------------------------- /antu/io/token_indexers/dynamic_token_indexer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Callable, TypeVar 2 | from overrides import overrides 3 | from .. import Vocabulary 4 | from . import TokenIndexer 5 | Indices = TypeVar("Indices", List[int], List[List[int]]) 6 | 7 | 8 | class DynamicTokenIndexer(TokenIndexer): 9 | """ 10 | A ``SingleIdTokenIndexer`` determines how string token get represented as 11 | arrays of single id indexes in a model. 12 | 13 | Parameters 14 | ---------- 15 | related_vocabs : ``List[str]`` 16 | Which vocabularies are related to the indexer. 17 | transform_for_count : ``Callable[[str,], List]``, optional (default=``lambda x:[x]``) 18 | What changes need to be made to the token when counting. 19 | Commonly used is dynamic oracle. 20 | transform_for_index : ``Callable[[str,], List]``, optional (default=``lambda x:[x]``) 21 | What changes need to be made to the token when indexing. 22 | Commonly used is dynamic oracle. 23 | """ 24 | 25 | def __init__(self, 26 | related_vocabs: List[str], 27 | transform_for_count: Callable[[str, ], List] = lambda x: [x], 28 | transform_for_index: Callable[[str, ], List] = lambda x: [x]) -> None: 29 | self.related_vocabs = related_vocabs 30 | self.transform_for_count = transform_for_count 31 | self.transform_for_index = transform_for_index 32 | 33 | @overrides 34 | def count_vocab_items( 35 | self, 36 | token: str, 37 | counters: Dict[str, Dict[str, int]]) -> None: 38 | """ 39 | The token is counted directly as an element. 40 | 41 | Parameters 42 | ---------- 43 | counter : ``Dict[str, Dict[str, int]]`` 44 | We count the number of strings if the string needs to be counted to 45 | some counters. 46 | """ 47 | for vocab_name in self.related_vocabs: 48 | if vocab_name in counters: 49 | for item in self.transform_for_count(token): 50 | counters[vocab_name][item] += 1 51 | 52 | @overrides 53 | def tokens_to_indices( 54 | self, 55 | tokens: List[str], 56 | vocab: Vocabulary) -> Dict[str, List[int]]: 57 | """ 58 | Takes a list of tokens and converts them to one or more sets of indices. 59 | During the indexing process, each item corresponds to an index in the 60 | vocabulary. 61 | 62 | Parameters 63 | ---------- 64 | vocab : ``Vocabulary`` 65 | ``vocab`` is used to get the index of each item. 66 | 67 | Returns 68 | ------- 69 | res : ``Dict[str, List[int]]`` 70 | if the token and index list is [w1:5, w2:3, w3:0], the result will 71 | be {'vocab_name' : [5, 3, 0]} 72 | """ 73 | res = {} 74 | for index_name in self.related_vocabs: 75 | index_list = [ 76 | [vocab.get_token_index(item, index_name) 77 | for item in self.transform_for_index(tok)] 78 | for tok in tokens] 79 | res[index_name] = index_list 80 | return res 81 | -------------------------------------------------------------------------------- /antu/io/token_indexers/sequence_token_indexer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Callable, TypeVar 2 | from overrides import overrides 3 | from .. import Vocabulary 4 | from . 
import TokenIndexer 5 | Indices = TypeVar("Indices", List[int], List[List[int]]) 6 | 7 | 8 | class SequenceTokenIndexer(TokenIndexer): 9 | """ 10 | A ``SequenceTokenIndexer`` determines how each string token gets represented as 11 | a sequence of indices in a model. 12 | 13 | Parameters 14 | ---------- 15 | related_vocabs : ``List[str]`` 16 | Which vocabularies are related to the indexer. 17 | transform : ``Callable[[str,], str]``, optional (default=``lambda x:x``) 18 | What changes need to be made to the token when counting or indexing. 19 | Commonly used are lowercase transformation functions. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | related_vocabs: List[str], 25 | transform: Callable[[str, ], str] = lambda x: x) -> None: 26 | self.related_vocabs = related_vocabs 27 | self.transform = transform 28 | 29 | @overrides 30 | def count_vocab_items( 31 | self, 32 | token: str, 33 | counters: Dict[str, Dict[str, int]]) -> None: 34 | """ 35 | Each element produced by ``transform`` is counted. 36 | 37 | Parameters 38 | ---------- 39 | counter : ``Dict[str, Dict[str, int]]`` 40 | We count the number of strings if the string needs to be counted to 41 | some counters. 42 | """ 43 | for vocab_name in self.related_vocabs: 44 | if vocab_name in counters: 45 | for item in self.transform(token): 46 | counters[vocab_name][item] += 1 47 | 48 | @overrides 49 | def tokens_to_indices( 50 | self, 51 | tokens: List[str], 52 | vocab: Vocabulary) -> Dict[str, List[List[int]]]: 53 | """ 54 | Takes a list of tokens and converts them to one or more sets of indices. 55 | During the indexing process, each item corresponds to an index in the 56 | vocabulary. 57 | 58 | Parameters 59 | ---------- 60 | vocab : ``Vocabulary`` 61 | ``vocab`` is used to get the index of each item. 62 | 63 | Returns 64 | ------- 65 | res : ``Dict[str, List[List[int]]]`` 66 | if the first token maps to the index list [5, 3] and the second to [0], 67 | the result will be {'vocab_name' : [[5, 3], [0]]} 68 | """ 69 | res = {} 70 | for index_name in self.related_vocabs: 71 | index_list = [[vocab.get_token_index(item, index_name) 72 | for item in self.transform(tok)] 73 | for tok in tokens] 74 | res[index_name] = index_list 75 | return res 76 |
-------------------------------------------------------------------------------- /antu/io/token_indexers/single_id_token_indexer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Callable, TypeVar 2 | from overrides import overrides 3 | from .. import Vocabulary 4 | from . import TokenIndexer 5 | Indices = TypeVar("Indices", List[int], List[List[int]]) 6 | 7 | 8 | class SingleIdTokenIndexer(TokenIndexer): 9 | """ 10 | A ``SingleIdTokenIndexer`` determines how string tokens get represented as 11 | arrays of single id indices in a model. 12 | 13 | Parameters 14 | ---------- 15 | related_vocabs : ``List[str]`` 16 | Which vocabularies are related to the indexer. 17 | transform : ``Callable[[str,], str]``, optional (default=``lambda x:x``) 18 | What changes need to be made to the token when counting or indexing. 19 | Commonly used are lowercase transformation functions. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | related_vocabs: List[str], 25 | transform: Callable[[str, ], str] = lambda x: x) -> None: 26 | self.related_vocabs = related_vocabs 27 | self.transform = transform 28 | 29 | @overrides 30 | def count_vocab_items( 31 | self, 32 | token: str, 33 | counters: Dict[str, Dict[str, int]]) -> None: 34 | """ 35 | The token is counted directly as an element.
36 | 37 | Parameters 38 | ---------- 39 | counter : ``Dict[str, Dict[str, int]]`` 40 | We count the number of strings if the string needs to be counted to 41 | some counters. 42 | """ 43 | for vocab_name in self.related_vocabs: 44 | if vocab_name in counters: 45 | counters[vocab_name][self.transform(token)] += 1 46 | 47 | @overrides 48 | def tokens_to_indices( 49 | self, 50 | tokens: List[str], 51 | vocab: Vocabulary) -> Dict[str, List[int]]: 52 | """ 53 | Takes a list of tokens and converts them to one or more sets of indices. 54 | During the indexing process, each item corresponds to an index in the 55 | vocabulary. 56 | 57 | Parameters 58 | ---------- 59 | vocab : ``Vocabulary`` 60 | ``vocab`` is used to get the index of each item. 61 | 62 | Returns 63 | ------- 64 | res : ``Dict[str, List[int]]`` 65 | if the token and index list is [w1:5, w2:3, w3:0], the result will 66 | be {'vocab_name' : [5, 3, 0]} 67 | """ 68 | res = {} 69 | for index_name in self.related_vocabs: 70 | index_list = [vocab.get_token_index(self.transform(tok), index_name) 71 | for tok in tokens] 72 | res[index_name] = index_list 73 | return res 74 | -------------------------------------------------------------------------------- /antu/io/token_indexers/token_indexer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, TypeVar, Callable 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from .. import Vocabulary 5 | 6 | Indices = TypeVar("Indices", List[int], List[List[int]]) 7 | 8 | 9 | class TokenIndexer(metaclass=ABCMeta): 10 | """ 11 | A ``TokenIndexer`` determines how string tokens get represented as arrays of 12 | indices in a model. 13 | """ 14 | 15 | @abstractmethod 16 | def count_vocab_items( 17 | self, 18 | token: str, 19 | counter: Dict[str, Dict[str, int]]) -> None: 20 | """ 21 | Defines how each token in the field is counted. In most cases, just use 22 | the string as a key. However, for character-level ``TokenIndexer``, you 23 | need to traverse each character in the string. 24 | 25 | Parameters 26 | ---------- 27 | counter : ``Dict[str, Dict[str, int]]`` 28 | We count the number of strings if the string needs to be counted to 29 | some counters. 30 | """ 31 | pass 32 | 33 | @abstractmethod 34 | def tokens_to_indices( 35 | self, 36 | tokens: List[str], 37 | vocab: Vocabulary) -> Dict[str, Indices]: 38 | """ 39 | Takes a list of tokens and converts them to one or more sets of indices. 40 | This could be just an ID for each token from the vocabulary. 41 | 42 | Parameters 43 | ---------- 44 | vocab : ``Vocabulary`` 45 | ``vocab`` is used to get the index of each item. 46 | """ 47 | pass 48 | -------------------------------------------------------------------------------- /antu/io/vocabulary.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union, List, Set, TypeVar 2 | from bidict import bidict 3 | 4 | DEFAULT_PAD_TOKEN = "*@PAD@*" 5 | DEFAULT_UNK_TOKEN = "*@UNK@*" 6 | 7 | 8 | class Vocabulary(object): 9 | """ 10 | Parameters 11 | ---------- 12 | counters : ``Dict[str, Dict[str, int]]``, optional (default= ``dict()`` ) 13 | Element statistics for datasets. 14 | min_count : ``Dict[str, int]``, optional (default= ``dict()`` ) 15 | Defines the minimum number of occurrences when some counter are 16 | converted to vocabulary. 17 | pretrained_vocab : ``Dict[str, List[str]]``, optional (default= ``dict()`` 18 | External pre-trained vocabulary. 
19 | intersection_vocab : ``Dict[str, str]``, optional (default= ``dict()`` ) 20 | Defines the intersection with which vocabulary takes, when loading some 21 | oversized pre-trained vocabulary. 22 | no_pad_namespace : ``Set[str]``, optional (default= ``set()`` ) 23 | Defines which vocabularies do not have `pad` token. 24 | no_unk_namespace : ``Set[str]``, optional (default= ``set()`` ) 25 | Defines which vocabularies do not have `oov` token. 26 | """ 27 | 28 | def __init__(self, 29 | counters: Dict[str, Dict[str, int]] = dict(), 30 | min_count: Dict[str, int] = dict(), 31 | pretrained_vocab: Dict[str, List[str]] = dict(), 32 | intersection_vocab: Dict[str, str] = dict(), 33 | no_pad_namespace: Set[str] = set(), 34 | no_unk_namespace: Set[str] = set()): 35 | 36 | self._PAD_token = DEFAULT_PAD_TOKEN 37 | self._UNK_token = DEFAULT_UNK_TOKEN 38 | self.min_count = min_count 39 | self.intersection_vocab = intersection_vocab 40 | self.no_unk_namespace = no_unk_namespace 41 | self.no_pad_namespace = no_pad_namespace 42 | self.vocab_cnt = {} 43 | self.vocab = {} 44 | 45 | for vocab_name, counter in dict(counters, **pretrained_vocab).items(): 46 | self.vocab[vocab_name] = bidict() 47 | cnt = 0 48 | 49 | # Handle unknown token 50 | if vocab_name not in no_unk_namespace: 51 | self.vocab[vocab_name][self._UNK_token] = cnt 52 | cnt += 1 53 | 54 | # Handle padding token 55 | if vocab_name not in no_pad_namespace: 56 | self.vocab[vocab_name][self._PAD_token] = cnt 57 | cnt += 1 58 | 59 | # Build Vocabulary from Dataset Counter 60 | if isinstance(counter, dict): 61 | minn = (min_count[vocab_name] 62 | if min_count and vocab_name in min_count else 0) 63 | for key, value in counter.items(): 64 | if value >= minn: 65 | self.vocab[vocab_name][key] = cnt 66 | cnt += 1 67 | 68 | # Build Vocabulary from Pretrained Vocabulary List 69 | elif isinstance(counter, list): 70 | is_intersection = vocab_name in intersection_vocab 71 | target_vocab = (self.vocab[intersection_vocab[vocab_name]] 72 | if is_intersection else {}) 73 | for key in counter: 74 | if not is_intersection or key in target_vocab: 75 | self.vocab[vocab_name][key] = cnt 76 | cnt += 1 77 | self.vocab_cnt[vocab_name] = cnt 78 | 79 | def extend_from_pretrained_vocab( 80 | self, 81 | pretrained_vocab: Dict[str, List[str]], 82 | intersection_vocab: Dict[str, str] = dict(), 83 | no_pad_namespace: Set[str] = set(), 84 | no_unk_namespace: Set[str] = set()) -> None: 85 | """ 86 | Extend the vocabulary from the pre-trained vocabulary after defining 87 | the vocabulary. 88 | 89 | Parameters 90 | ---------- 91 | pretrained_vocab : ``Dict[str, List[str]]`` 92 | External pre-trained vocabulary. 93 | intersection_vocab : ``Dict[str, str]``, optional (default= ``dict()`` ) 94 | Defines the intersection with which vocabulary takes, when loading some 95 | oversized pre-trained vocabulary. 96 | no_pad_namespace : ``Set[str]``, optional (default= ``set()`` ) 97 | Defines which vocabularies do not have `pad` token. 98 | no_unk_namespace : ``Set[str]``, optional (default= ``set()`` ) 99 | Defines which vocabularies do not have `oov` token. 
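Examples -------- A minimal usage sketch (the 'word' and 'glove' vocabulary names here are only illustrative):: vocab = Vocabulary(counters={'word': {'the': 1}}) vocab.extend_from_pretrained_vocab({'glove': ['the', 'cat']}) vocab.get_token_index('cat', 'glove')  # -> 3, since indices 0 and 1 of 'glove' are taken by the UNK and PAD tokens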
100 | """ 101 | self.no_unk_namespace.update(no_unk_namespace) 102 | self.no_pad_namespace.update(no_pad_namespace) 103 | self.intersection_vocab.update(intersection_vocab) 104 | for vocab_name, counter in pretrained_vocab.items(): 105 | self.vocab[vocab_name] = bidict() 106 | 107 | cnt = 0 108 | # Handle unknown token 109 | if vocab_name not in no_unk_namespace: 110 | self.vocab[vocab_name][self._UNK_token] = cnt 111 | cnt += 1 112 | 113 | # Handle padding token 114 | if vocab_name not in no_pad_namespace: 115 | self.vocab[vocab_name][self._PAD_token] = cnt 116 | cnt += 1 117 | 118 | # Build Vocabulary from Pretrained Vocabulary List 119 | is_intersection = vocab_name in intersection_vocab 120 | target_vocab = (self.vocab[intersection_vocab[vocab_name]] 121 | if is_intersection else {}) 122 | for key in counter: 123 | if not is_intersection or key in target_vocab: 124 | self.vocab[vocab_name][key] = cnt 125 | cnt += 1 126 | self.vocab_cnt[vocab_name] = cnt 127 | 128 | def extend_from_counter( 129 | self, 130 | counters: Dict[str, Dict[str, int]], 131 | min_count: Union[int, Dict[str, int]] = dict(), 132 | no_pad_namespace: Set[str] = set(), 133 | no_unk_namespace: Set[str] = set()) -> None: 134 | """ 135 | Extend the vocabulary from the dataset statistic counters after defining 136 | the vocabulary. 137 | 138 | Parameters 139 | ---------- 140 | counters : ``Dict[str, Dict[str, int]]`` 141 | Element statistics for datasets. 142 | min_count : ``Dict[str, int]``, optional (default= ``dict()`` ) 143 | Defines the minimum number of occurrences when some counter are 144 | converted to vocabulary. 145 | no_pad_namespace : ``Set[str]``, optional (default= ``set()`` ) 146 | Defines which vocabularies do not have `pad` token. 147 | no_unk_namespace : ``Set[str]``, optional (default= ``set()`` ) 148 | Defines which vocabularies do not have `oov` token. 149 | """ 150 | self.no_unk_namespace.update(no_unk_namespace) 151 | self.no_pad_namespace.update(no_pad_namespace) 152 | self.min_count.update(min_count) 153 | 154 | for vocab_name, counter in counters.items(): 155 | self.vocab[vocab_name] = bidict() 156 | cnt = 0 157 | # Handle unknown token 158 | if vocab_name not in no_unk_namespace: 159 | self.vocab[vocab_name][self._UNK_token] = cnt 160 | cnt += 1 161 | 162 | # Handle padding token 163 | if vocab_name not in no_pad_namespace: 164 | self.vocab[vocab_name][self._PAD_token] = cnt 165 | cnt += 1 166 | 167 | # Build Vocabulary from Dataset Counter 168 | minn = (min_count[vocab_name] 169 | if min_count and vocab_name in min_count else 0) 170 | for key, value in counter.items(): 171 | if value >= minn: 172 | self.vocab[vocab_name][key] = cnt 173 | cnt += 1 174 | self.vocab_cnt[vocab_name] = cnt 175 | 176 | def add_token_to_namespace(self, token: str, namespace: str) -> None: 177 | """ 178 | Extend the vocabulary by add token to vocabulary namespace. 179 | 180 | Parameters 181 | ---------- 182 | token : ``str`` 183 | The token that needs to be added. 184 | namespace : ``str`` 185 | Which vocabulary needs to be added to. 186 | """ 187 | self.vocab[namespace][token] = self.vocab_cnt[namespace] 188 | self.vocab_cnt[namespace] += 1 189 | 190 | def get_token_index(self, token: str, vocab_name: str) -> int: 191 | """ 192 | Gets the index of a token in the vocabulary. 193 | 194 | Parameters 195 | ---------- 196 | token : ``str`` 197 | Gets the index of which token. 198 | namespace : ``str`` 199 | Which vocabulary this token belongs to. 
200 | 201 | Returns 202 | ------- 203 | Index : ``int`` 204 | """ 205 | if token in self.vocab[vocab_name]: 206 | return self.vocab[vocab_name][token] 207 | elif vocab_name not in self.no_unk_namespace: 208 | return self.vocab[vocab_name][self._UNK_token] 209 | else: 210 | raise RuntimeError( 211 | 'Tried to get the index of the OOV token (%s) from vocabulary ' 212 | '(%s), which has no unknown token' % (token, vocab_name)) 213 | 214 | def get_token_from_index(self, index: int, vocab_name: str) -> str: 215 | """ 216 | Gets the token of an index in the vocabulary. 217 | 218 | Parameters 219 | ---------- 220 | index : ``int`` 221 | Gets the token of which index. 222 | namespace : ``str`` 223 | Which vocabulary this index belongs to. 224 | 225 | Returns 226 | ------- 227 | Token : ``str`` 228 | """ 229 | if index < self.vocab_cnt[vocab_name]: 230 | return self.vocab[vocab_name].inv[index] 231 | else: 232 | raise RuntimeError( 233 | 'Index (%d) out of vocabulary (%s) range' 234 | % (index, vocab_name)) 235 | 236 | def get_vocab_size(self, namespace: str) -> int: 237 | """ 238 | Gets the size of a vocabulary. 239 | 240 | Parameters 241 | ---------- 242 | namespace : ``str`` 243 | Which vocabulary. 244 | 245 | Returns 246 | ------- 247 | Vocabulary size : ``int`` 248 | """ 249 | return len(self.vocab[namespace]) 250 | 251 | def get_padding_index(self, namespace: str) -> int: 252 | if namespace not in self.no_pad_namespace: 253 | return self.vocab[namespace][self._PAD_token] 254 | else: 255 | raise RuntimeError("(%s) doesn't have a PAD token." % (namespace)) 256 | 257 | def get_unknow_index(self, namespace: str) -> int: 258 | if namespace not in self.no_unk_namespace: 259 | return self.vocab[namespace][self._UNK_token] 260 | else: 261 | raise RuntimeError("(%s) doesn't have an UNK token." % (namespace)) 262 |
-------------------------------------------------------------------------------- /antu/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AntNLP/antu/3256ada0784401b9677d9568e81f3a8792eebee7/antu/nn/__init__.py
-------------------------------------------------------------------------------- /antu/nn/dynet/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules.dynet_model import dy_model 2 | from .modules.linear import Linear 3 | from .modules.perceptron import MLP 4 | from .modules.graph_nn_unit import GraphNNUnit 5 | 6 | from .functional.leaky_relu import leaky_relu 7 | 8 | from .classifiers.nn_classifier import BiaffineLabelClassifier, PointerLabelClassifier 9 |
-------------------------------------------------------------------------------- /antu/nn/dynet/attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .biaffine import BiaffineAttention 2 | from .biaffine_matrix import BiaffineMatAttention 3 |
-------------------------------------------------------------------------------- /antu/nn/dynet/attention/biaffine.py: -------------------------------------------------------------------------------- 1 | import dynet as dy 2 | import numpy as np 3 | from ..modules import dy_model 4 | from ..init import init_wrap 5 | 6 | 7 | @dy_model 8 | class BiaffineAttention(object): 9 | """This builds Pointer Networks labeled Classifier: 10 | ..
math:: 11 | \\begin{split} 12 | f_{ptr}(\\boldsymbol{h}_i, \\boldsymbol{s}_t) &= 13 | \\boldsymbol{V}_a{}^\\top 14 | \\text{tanh}(\\boldsymbol{W}_1 \\boldsymbol{h}_i + 15 | \\boldsymbol{W}_2 \\boldsymbol{s}_t) \\\\ 16 | \\boldsymbol{p}_t &= \\text{softmax}(f_{ptr}(\\boldsymbol{h}_i, \\boldsymbol{s}_t)) \\\\ 17 | \end{split} 18 | :param model dynet.ParameterCollection: 19 | :param l_dim int: Row dimension of :math:`\\boldsymbol{V}` 20 | :param v_dim int: Column dimension of :math:`\\boldsymbol{V}` 21 | :param h_dim int: Dimension of :math:`\\boldsymbol{h}` 22 | :param s_dim int: Dimension of :math:`\\boldsymbol{s}` 23 | :returns: probatilistic vector :math:`\\boldsymbol{p}_t` 24 | :rtype: dynet.Expression 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model, 30 | h_dim: int, s_dim: int, n_label: int, 31 | bias=False, init=dy.ConstInitializer(0.)): 32 | pc = model.add_subcollection() 33 | if bias: 34 | if n_label == 1: 35 | self.B = pc.add_parameters((h_dim,), init=0) 36 | else: 37 | self.V = pc.add_parameters((n_label, h_dim+s_dim), init=0) 38 | self.B = pc.add_parameters((n_label,), init=0) 39 | init_U = init_wrap(init, (h_dim*n_label, s_dim)) 40 | self.U = pc.add_parameters((h_dim*n_label, s_dim), init=init_U) 41 | self.h_dim, self.s_dim, self.n_label = h_dim, s_dim, n_label 42 | self.pc, self.bias = pc, bias 43 | self.spec = (h_dim, s_dim, n_label, bias, init) 44 | 45 | def __call__(self, h, s): 46 | # hT -> ((L, h_dim), B), s -> ((s_dim, L), B) 47 | hT = dy.transpose(h) 48 | lin = self.U * s # ((h_dim*n_label, L), B) 49 | if self.n_label > 1: 50 | lin = dy.reshape(lin, (self.h_dim, self.n_label)) 51 | blin = hT * lin 52 | if self.n_label == 1: 53 | return blin + (hT * self.B if self.bias else 0) 54 | else: 55 | return dy.transpose(blin)+(self.V*dy.concatenate([h, s])+self.B if self.bias else 0) 56 | 57 | -------------------------------------------------------------------------------- /antu/nn/dynet/attention/biaffine_matrix.py: -------------------------------------------------------------------------------- 1 | import _dynet as dy 2 | import numpy as np 3 | from ..modules import dy_model 4 | from ..init import init_wrap 5 | 6 | 7 | @dy_model 8 | class BiaffineMatAttention(object): 9 | """This builds Pointer Networks labeled Classifier: 10 | .. 
math:: 11 | \\begin{split} 12 | f_{ptr}(\\boldsymbol{h}_i, \\boldsymbol{s}_t) &= 13 | \\boldsymbol{V}_a{}^\\top 14 | \\text{tanh}(\\boldsymbol{W}_1 \\boldsymbol{h}_i + 15 | \\boldsymbol{W}_2 \\boldsymbol{s}_t) \\\\ 16 | \\boldsymbol{p}_t &= \\text{softmax}(f_{ptr}(\\boldsymbol{h}_i, \\boldsymbol{s}_t)) \\\\ 17 | \end{split} 18 | :param model dynet.ParameterCollection: 19 | :param l_dim int: Row dimension of :math:`\\boldsymbol{V}` 20 | :param v_dim int: Column dimension of :math:`\\boldsymbol{V}` 21 | :param h_dim int: Dimension of :math:`\\boldsymbol{h}` 22 | :param s_dim int: Dimension of :math:`\\boldsymbol{s}` 23 | :returns: probabilistic vector :math:`\\boldsymbol{p}_t` 24 | :rtype: dynet.Expression 25 | """ 26 | def __init__( 27 | self, 28 | model, 29 | h_dim: int, s_dim: int, n_label: int, 30 | h_bias=False, s_bias=False, init=dy.ConstInitializer(0.)): 31 | pc = model.add_subcollection() 32 | h_dim += s_bias 33 | s_dim += h_bias 34 | init_U = init_wrap(init, (h_dim*n_label, s_dim)) 35 | self.U = pc.add_parameters((h_dim*n_label, s_dim), init=init_U) 36 | self.h_dim, self.s_dim, self.n_label = h_dim, s_dim, n_label 37 | self.pc, self.h_bias, self.s_bias = pc, h_bias, s_bias 38 | self.spec = (h_dim, s_dim, n_label, h_bias, s_bias, init) 39 | 40 | def __call__(self, h, s): 41 | # hT -> ((L, h_dim), B), s -> ((s_dim, L), B) 42 | if len(h.dim()[0]) == 2: 43 | L = h.dim()[0][1] 44 | if self.h_bias: s = dy.concatenate([s, dy.inputTensor(np.ones((1, L), dtype=np.float32))]) 45 | if self.s_bias: h = dy.concatenate([h, dy.inputTensor(np.ones((1, L), dtype=np.float32))]) 46 | else: 47 | if self.h_bias: s = dy.concatenate([s, dy.inputTensor(np.ones((1,), dtype=np.float32))]) 48 | if self.s_bias: h = dy.concatenate([h, dy.inputTensor(np.ones((1,), dtype=np.float32))]) 49 | hT = dy.transpose(h) 50 | lin = self.U * s # ((h_dim*n_label, L), B) 51 | if self.n_label > 1: 52 | lin = dy.reshape(lin, (self.h_dim, self.n_label)) 53 | 54 | blin = hT * lin 55 | if self.n_label == 1: 56 | return blin 57 | else: 58 | return dy.transpose(blin) 59 | 60 | 61 | @staticmethod 62 | def from_spec(spec, model): 63 | """Create and return a new instance with the needed parameters. 64 | 65 | It is one of the prerequisites for Dynet save/load method. 66 | """ 67 | h_dim, s_dim, n_label, h_bias, s_bias, init = spec 68 | return BiaffineMatAttention(model, h_dim, s_dim, n_label, h_bias, s_bias, init) 69 | 70 | def param_collection(self): 71 | """Return a :code:`dynet.ParameterCollection` object with the parameters. 72 | 73 | It is one of the prerequisites for Dynet save/load method. 74 | """ 75 | return self.pc 76 |
-------------------------------------------------------------------------------- /antu/nn/dynet/attention/multi_head.py: -------------------------------------------------------------------------------- 1 | from .single import Attention 2 | import torch.nn as nn 3 | 4 | class MultiHeadedAttention: 5 | """ 6 | Takes in the model size and the number of heads.
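Each of the ``h`` heads attends over ``d_k = d_model // h`` dimensional projections of the query, key and value; with ``d_model = 768`` and ``h = 12``, for example, each head works on 64-dimensional slices (these sizes are only illustrative).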
7 | """ 8 | 9 | def __init__(self, h, d_model, dropout=0.1): 10 | super().__init__() 11 | assert d_model % h == 0 12 | 13 | # We assume d_v always equals d_k 14 | self.d_k = d_model // h 15 | self.h = h 16 | 17 | self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)]) 18 | self.output_linear = nn.Linear(d_model, d_model) 19 | self.attention = Attention() 20 | 21 | self.dropout = nn.Dropout(p=dropout) 22 | 23 | def forward(self, query, key, value, mask=None): 24 | batch_size = query.size(0) 25 | 26 | # 1) Do all the linear projections in batch from d_model => h x d_k 27 | query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) 28 | for l, x in zip(self.linear_layers, (query, key, value))] 29 | 30 | # 2) Apply attention on all the projected vectors in batch. 31 | x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout) 32 | 33 | # 3) "Concat" using a view and apply a final linear. 34 | x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k) 35 | 36 | return self.output_linear(x) -------------------------------------------------------------------------------- /antu/nn/dynet/attention/single.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | 4 | import math 5 | 6 | 7 | class Attention: 8 | """ 9 | Compute 'Scaled Dot Product Attention 10 | """ 11 | 12 | def forward(self, query, key, value, mask=None, dropout=None): 13 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 14 | / math.sqrt(query.size(-1)) 15 | 16 | if mask is not None: 17 | scores = scores.masked_fill(mask == 0, -1e9) 18 | 19 | p_attn = F.softmax(scores, dim=-1) 20 | 21 | if dropout is not None: 22 | p_attn = dropout(p_attn) 23 | 24 | return torch.matmul(p_attn, value), p_attn -------------------------------------------------------------------------------- /antu/nn/dynet/classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | from .nn_classifier import BiaffineLabelClassifier 2 | from .nn_classifier import PointerLabelClassifier 3 | 4 | -------------------------------------------------------------------------------- /antu/nn/dynet/classifiers/nn_classifier.py: -------------------------------------------------------------------------------- 1 | import dynet as dy 2 | import numpy as np 3 | from ..modules import dy_model 4 | 5 | 6 | @dy_model 7 | class PointerLabelClassifier(object): 8 | """This builds Pointer Networks labeled Classifier: 9 | .. 
math:: 10 | \\begin{split} 11 | f_{ptr}(\\boldsymbol{h}_i, \\boldsymbol{s}_t) &= 12 | \\boldsymbol{V}_a{}^\\top 13 | \\text{tanh}(\\boldsymbol{W}_1 \\boldsymbol{h}_i + 14 | \\boldsymbol{W}_2 \\boldsymbol{s}_t) \\\\ 15 | \\boldsymbol{p}_t &= \\text{softmax}(f_{ptr}(\\boldsymbol{h}_i, \\boldsymbol{s}_t)) \\\\ 16 | \end{split} 17 | :param model dynet.ParameterCollection: 18 | :param l_dim int: Row dimension of :math:`\\boldsymbol{V}` 19 | :param v_dim int: Column dimension of :math:`\\boldsymbol{V}` 20 | :param h_dim int: Dimension of :math:`\\boldsymbol{h}` 21 | :param s_dim int: Dimension of :math:`\\boldsymbol{s}` 22 | :returns: probatilistic vector :math:`\\boldsymbol{p}_t` 23 | :rtype: dynet.Expression 24 | """ 25 | 26 | def __init__(self, model, l_dim, v_dim, h_dim, s_dim, layers=1): 27 | pc = model.add_subcollection() 28 | self.layers = layers 29 | self.V = ([pc.add_parameters((1, v_dim)) for _ in range(layers-1)] 30 | +[pc.add_parameters((l_dim, v_dim))]) 31 | 32 | self.W1 = [pc.add_parameters((v_dim, h_dim)) for _ in range(layers)] 33 | 34 | self.W2 = ([pc.add_parameters((v_dim, s_dim))] 35 | +[pc.add_parameters((v_dim, h_dim+s_dim)) for _ in range(layers-1)]) 36 | 37 | self.B1 = pc.add_parameters( 38 | (l_dim, h_dim), init=dy.ConstInitializer(0)) 39 | self.B2 = pc.add_parameters( 40 | (l_dim, s_dim), init=dy.ConstInitializer(0)) 41 | 42 | # Only single layer support 43 | #self._W1 = pc.add_parameters((v_dim, h_dim)) 44 | #self._W2 = pc.add_parameters((v_dim, s_dim)) 45 | #self._V = pc.add_parameters((l_dim, v_dim)) 46 | 47 | self.pc = pc 48 | self.spec = l_dim, v_dim, h_dim, s_dim, layers 49 | 50 | def __call__(self, x, h_matrix, noprob=False): 51 | s_t = x 52 | for i in range(self.layers-1): 53 | e_t = self.V[i] * dy.tanh(self.W1[i]*h_matrix + self.W2[i]*s_t) 54 | a_t = dy.softmax(dy.transpose(e_t)) 55 | c_t = h_matrix * a_t 56 | s_t = dy.concatenate([x, c_t]) 57 | 58 | e_t = self.V[-1] * dy.tanh(self.W1[-1]*h_matrix + 59 | self.W2[-1]*s_t) + self.B1 * h_matrix + self.B2 * s_t 60 | 61 | if len(h_matrix.dim()[0]) > 1: 62 | e_t = dy.reshape( 63 | e_t, (self.V[-1].dim()[0][0] * h_matrix.dim()[0][1],)) 64 | if not noprob: 65 | p_t = dy.softmax(e_t) 66 | return p_t 67 | else: 68 | return e_t 69 | 70 | 71 | @dy_model 72 | class BiaffineLabelClassifier(object): 73 | """This builds Pointer Networks labeled Classifier: 74 | .. 
math:: 75 | \\begin{split} 76 | f_{ptr}(\\boldsymbol{h}_i, \\boldsymbol{s}_t) &= 77 | \\boldsymbol{V}_a{}^\\top 78 | \\text{tanh}(\\boldsymbol{W}_1 \\boldsymbol{h}_i + 79 | \\boldsymbol{W}_2 \\boldsymbol{s}_t) \\\\ 80 | \\boldsymbol{p}_t &= \\text{softmax}(f_{ptr}(\\boldsymbol{h}_i, \\boldsymbol{s}_t)) \\\\ 81 | \end{split} 82 | :param model dynet.ParameterCollection: 83 | :param l_dim int: Row dimension of :math:`\\boldsymbol{V}` 84 | :param v_dim int: Column dimension of :math:`\\boldsymbol{V}` 85 | :param h_dim int: Dimension of :math:`\\boldsymbol{h}` 86 | :param s_dim int: Dimension of :math:`\\boldsymbol{s}` 87 | :returns: probatilistic vector :math:`\\boldsymbol{p}_t` 88 | :rtype: dynet.Expression 89 | """ 90 | 91 | def __init__(self, model, h_dim, s_dim, n_label, h_bias=False, s_bias=False): 92 | pc = model.add_subcollection() 93 | if h_bias: 94 | h_dim += 1 95 | if s_bias: 96 | s_dim += 1 97 | if n_label == 1: 98 | self.U = pc.add_parameters( 99 | (h_dim, s_dim), init=dy.ConstInitializer(0.)) 100 | else: 101 | self.U = pc.add_parameters( 102 | (h_dim*n_label, s_dim), init=dy.ConstInitializer(0.)) 103 | self.pc = pc 104 | self.h_dim = h_dim 105 | self.s_dim = s_dim 106 | self.n_label = n_label 107 | self.h_bias = h_bias 108 | self.s_bias = s_bias 109 | self.spec = (h_dim, s_dim, n_label, h_bias, s_bias) 110 | 111 | def __call__(self, h, s): 112 | if self.h_bias: 113 | if len(h.dim()[0]) == 2: 114 | h = dy.concatenate( 115 | [h, dy.inputTensor(np.ones((1, h.dim()[0][1]), dtype=np.float32))]) 116 | else: 117 | h = dy.concatenate( 118 | [h, dy.inputTensor(np.ones((1,), dtype=np.float32))]) 119 | if self.s_bias: 120 | if len(s.dim()[0]) == 2: 121 | s = dy.concatenate( 122 | [s, dy.inputTensor(np.ones((1, s.dim()[0][1]), dtype=np.float32))]) 123 | else: 124 | s = dy.concatenate( 125 | [s, dy.inputTensor(np.ones((1,), dtype=np.float32))]) 126 | lin = self.U * s 127 | if self.n_label > 1: 128 | lin = dy.reshape(lin, (self.h_dim, self.n_label)) 129 | blin = dy.transpose(h) * lin 130 | return blin 131 | -------------------------------------------------------------------------------- /antu/nn/dynet/embedding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AntNLP/antu/3256ada0784401b9677d9568e81f3a8792eebee7/antu/nn/dynet/embedding/__init__.py -------------------------------------------------------------------------------- /antu/nn/dynet/embedding/bert.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .token import TokenEmbedding 3 | from .position import PositionalEmbedding 4 | from .segment import SegmentEmbedding 5 | 6 | 7 | class BERTEmbedding(nn.Module): 8 | """ 9 | BERT Embedding which is consisted with under features 10 | 1. TokenEmbedding : normal embedding matrix 11 | 2. PositionalEmbedding : adding positional information using sin, cos 12 | 2. 
SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2) 13 | sum of all these features are output of BERTEmbedding 14 | """ 15 | 16 | def __init__(self, vocab_size, embed_size, dropout=0.1): 17 | """ 18 | :param vocab_size: total vocab size 19 | :param embed_size: embedding size of token embedding 20 | :param dropout: dropout rate 21 | """ 22 | super().__init__() 23 | self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size) 24 | self.position = PositionalEmbedding(d_model=self.token.embedding_dim) 25 | self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim) 26 | self.dropout = nn.Dropout(p=dropout) 27 | self.embed_size = embed_size 28 | 29 | def forward(self, sequence, segment_label): 30 | x = self.token(sequence) + self.position(sequence) + self.segment(segment_label) 31 | return self.dropout(x) -------------------------------------------------------------------------------- /antu/nn/dynet/embedding/position.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class PositionalEmbedding: 5 | 6 | def __init__(self, d_model, max_len=512): 7 | 8 | # Compute the positional encodings once in log space. 9 | 10 | pe = torch.zeros(max_len, d_model).float() 11 | pe.require_grad = False 12 | 13 | position = torch.arange(0, max_len).float().unsqueeze(1) 14 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 15 | 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | 19 | pe = pe.unsqueeze(0) 20 | self.register_buffer('pe', pe) 21 | 22 | def forward(self, x): 23 | return self.pe[:, :x.size(1)] 24 | -------------------------------------------------------------------------------- /antu/nn/dynet/embedding/segment.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SegmentEmbedding(nn.Embedding): 5 | def __init__(self, embed_size=512): 6 | super().__init__(3, embed_size, padding_idx=0) -------------------------------------------------------------------------------- /antu/nn/dynet/embedding/token.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class TokenEmbedding(nn.Embedding): 5 | def __init__(self, vocab_size, embed_size=512): 6 | super().__init__(vocab_size, embed_size, padding_idx=0) -------------------------------------------------------------------------------- /antu/nn/dynet/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .gelu import GELU 2 | from .leaky_relu import leaky_relu 3 | -------------------------------------------------------------------------------- /antu/nn/dynet/functional/gelu.py: -------------------------------------------------------------------------------- 1 | import dynet as dy 2 | import math 3 | 4 | 5 | def GELU(x): 6 | return 0.5 * x * (1 + dy.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * dy.pow(x, 3)))) 7 | 8 | 9 | ''' 10 | class GELU: 11 | """ 12 | Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU 13 | """ 14 | 15 | def __call__(self, x): 16 | return 0.5 * x * (1 + dy.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * dy.pow(x, 3)))) 17 | ''' 18 | -------------------------------------------------------------------------------- /antu/nn/dynet/functional/leaky_relu.py: -------------------------------------------------------------------------------- 
1 | import dynet as dy 2 | 3 | 4 | def leaky_relu(x, a): 5 | return dy.bmax(a*x, x) 6 | -------------------------------------------------------------------------------- /antu/nn/dynet/init/__init__.py: -------------------------------------------------------------------------------- 1 | from .orthogonal_initializer import get_orthogonal_matrix 2 | from .orthogonal_initializer import OrthogonalInitializer 3 | from .init_wrap import init_wrap 4 | -------------------------------------------------------------------------------- /antu/nn/dynet/init/init_wrap.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import dynet as dy 3 | from . import OrthogonalInitializer 4 | 5 | 6 | def init_wrap( 7 | init: dy.PyInitializer, 8 | size: Tuple[int]) -> dy.PyInitializer: 9 | 10 | if init == OrthogonalInitializer: 11 | return dy.NumpyInitializer(init.init(size)) 12 | elif isinstance(init, dy.PyInitializer) == True: 13 | return init 14 | else: 15 | raise RuntimeError('%s is not a instance of dy.PyInitializer.' % init) 16 | 17 | -------------------------------------------------------------------------------- /antu/nn/dynet/init/orthogonal_initializer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dynet as dy 3 | 4 | 5 | def get_orthogonal_matrix(output_size, input_size): 6 | """ 7 | adopted from Timothy Dozat https://github.com/tdozat/Parser/blob/master/lib/linalg.py 8 | """ 9 | I = np.eye(output_size) 10 | lr = .1 11 | eps = .05/(output_size + input_size) 12 | success = False 13 | tries = 0 14 | while not success and tries < 10: 15 | Q = np.random.randn(input_size, output_size) / np.sqrt(output_size) 16 | for i in range(100): 17 | QTQmI = Q.T.dot(Q) - I 18 | loss = np.sum(QTQmI**2 / 2) 19 | Q2 = Q**2 20 | Q -= lr*Q.dot(QTQmI) / (np.abs(Q2 + Q2.sum(axis=0, 21 | keepdims=True) + Q2.sum(axis=1, keepdims=True) - 1) + eps) 22 | if np.max(Q) > 1e6 or loss > 1e6 or not np.isfinite(loss): 23 | tries += 1 24 | lr /= 2 25 | break 26 | success = True 27 | if success: 28 | print('Orthogonal pretrainer (%d, %d) loss: %.2e' % (output_size, 29 | input_size, loss)) 30 | else: 31 | print('Orthogonal pretrainer (%d, %d) failed, using non-orthogonal' 32 | 'random matrix', (output_size, input_size)) 33 | Q=np.random.randn(input_size, output_size) / np.sqrt(output_size) 34 | return np.transpose(Q.astype(np.float32)) 35 | 36 | 37 | class OrthogonalInitializer(dy.PyInitializer): 38 | 39 | @classmethod 40 | def init(cls, size): 41 | assert len(size) == 2 42 | return get_orthogonal_matrix(size[0], size[1]) 43 | 44 | -------------------------------------------------------------------------------- /antu/nn/dynet/modules/BERT.py: -------------------------------------------------------------------------------- 1 | from .transformer import TransformerBlock 2 | from antu.nn.dynet.embedding import BERTEmbedding 3 | 4 | 5 | class BERT: 6 | """ 7 | BERT model : Bidirectional Encoder Representations from Transformers. 
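A minimal construction sketch (the vocabulary size here is illustrative; the remaining sizes mirror the defaults of ``__init__`` below):: bert = BERT(vocab_size=30000)  # hidden=768, n_layers=12, attn_heads=12, feed-forward hidden size 4*hidden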
8 | """ 9 | 10 | def __init__( 11 | self, 12 | vocab_size: int, 13 | hidden: int = 768, 14 | n_layers: int = 12, 15 | attn_heads: int = 12, 16 | dropout: float = 0.1): 17 | """ 18 | :param vocab_size: vocab_size of total words 19 | :param hidden: BERT model hidden size 20 | :param n_layers: numbers of Transformer blocks(layers) 21 | :param attn_heads: number of attention heads 22 | :param dropout: dropout rate 23 | """ 24 | super().__init__() 25 | self.hidden = hidden 26 | self.n_layers = n_layers 27 | self.attn_heads = attn_heads 28 | 29 | # paper noted they used 4*hidden_size for ff_network_hidden_size 30 | self.feed_forward_hidden = hidden * 4 31 | 32 | # embedding for BERT, sum of positional, segment, token embeddings 33 | self.embedding = BERTEmbedding( 34 | vocab_size=vocab_size, embed_size=hidden) 35 | 36 | # multi-layers transformer blocks, deep network 37 | self.transformer_blocks = [ 38 | TransformerBlock(hidden, attn_heads, hidden * 4, dropout) 39 | for _ in range(n_layers)] 40 | 41 | def forward(self, x, segment_info): 42 | # attention masking for padded token 43 | mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1) 44 | 45 | # embedding the indexed sequence to sequence of vectors 46 | x = self.embedding(x, segment_info) 47 | 48 | # running over multiple transformer blocks 49 | for transformer in self.transformer_blocks: 50 | x = transformer.forward(x, mask) 51 | 52 | return x 53 | -------------------------------------------------------------------------------- /antu/nn/dynet/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .dynet_model import dy_model 2 | -------------------------------------------------------------------------------- /antu/nn/dynet/modules/attention_mechanism.py: -------------------------------------------------------------------------------- 1 | """Attention mechanism module 2 | Attention is most commonly used in many NLP tasks. 3 | It can be used in any sequence model to look back at past states. 4 | We have implemented some of the attention mechanisms used. 5 | """ 6 | 7 | import dynet as dy 8 | from . import dy_model 9 | 10 | 11 | @dy_model 12 | class VanillaAttention(object): 13 | """This computes Additive attention: 14 | The original attention mechanism (Bahdanau et al., 2015) 15 | uses a one-hidden layer feed-forward network to calculate the attention alignment: 16 | .. 
math:: 17 | \\begin{split} 18 | f_{att}(\\boldsymbol{h}_i, \\boldsymbol{s}_t) &= 19 | \\boldsymbol{v}_a{}^\\top 20 | \\text{tanh}(\\boldsymbol{W}_1 \\boldsymbol{h}_i + 21 | \\boldsymbol{W}_2 \\boldsymbol{s}_t) \\\\ 22 | \\boldsymbol{a}_t &= \\text{softmax}(f_{att}(\\boldsymbol{h}_i, \\boldsymbol{s}_t)) \\\\ 23 | \\boldsymbol{c}_t &= \sum_i a_t^i \\boldsymbol{h}_i \\\\ 24 | \end{split} 25 | :param model dynet.ParameterCollection: 26 | :param v_dim int: Dimension of :math:`\\boldsymbol{v}` 27 | :param h_dim int: Dimension of :math:`\\boldsymbol{h}` 28 | :param s_dim int: Dimension of :math:`\\boldsymbol{s}` 29 | :returns: attention vector :math:`\\boldsymbol{c}_t` 30 | :rtype: dynet.Expression 31 | """ 32 | 33 | def __init__(self, model, v_dim, h_dim, s_dim): 34 | pc = model.add_subcollection() 35 | 36 | self.W1 = pc.add_parameters((v_dim, h_dim)) 37 | self.W2 = pc.add_parameters((v_dim, s_dim)) 38 | self.v = pc.add_parameters((1, v_dim)) 39 | 40 | self.pc = pc 41 | self.spec = v_dim, h_dim, s_dim 42 | 43 | def __call__(self, s_t, h_matrix): 44 | e_t = self.v * dy.tanh(self.W1*h_matrix + self.W2 * s_t) 45 | a_t = dy.softmax(dy.transpose(e_t)) 46 | c_t = h_matrix * a_t 47 | return c_t 48 | -------------------------------------------------------------------------------- /antu/nn/dynet/modules/dynet_model.py: -------------------------------------------------------------------------------- 1 | def dy_model(cls): 2 | 3 | def param_collection(self): 4 | return self.pc 5 | 6 | @staticmethod 7 | def from_spec(spec, model): 8 | return cls(model, *spec) 9 | 10 | cls.from_spec, cls.param_collection = from_spec, param_collection 11 | return cls 12 | -------------------------------------------------------------------------------- /antu/nn/dynet/modules/feed_forward.py: -------------------------------------------------------------------------------- 1 | import dynet as dy 2 | from ..functional import GELU 3 | from . import dy_model 4 | 5 | 6 | @dy_model 7 | class PositionwiseFeedForward: 8 | "Implements FFN equation." 9 | 10 | def __init__( 11 | self, 12 | model: dy.ParameterCollection, 13 | in_dim: int, 14 | hid_dim: int, 15 | p: int = 0.1): 16 | 17 | pc = model.add_subcollection() 18 | self.W1 = Linear(pc, in_dim, hid_dim) 19 | self.W2 = Linear(pc, hid_dim, in_dim) 20 | self.p = p 21 | self.pc = pc 22 | self.spec = (in_dim, hid_dim, p) 23 | 24 | def __call__(self, x, is_train=False): 25 | p = self.p if is_train else 0 26 | return self.W2(dy.dropout(GELU(self.W1(x)), p)) 27 | 28 | -------------------------------------------------------------------------------- /antu/nn/dynet/modules/graph_nn_unit.py: -------------------------------------------------------------------------------- 1 | import dynet as dy 2 | from . 
import dy_model 3 | from ..init import init_wrap 4 | 5 | 6 | @dy_model 7 | class GraphNNUnit(object): 8 | """docstring for GraphNNUnit""" 9 | 10 | def __init__( 11 | self, 12 | model: dy.ParameterCollection, 13 | h_dim: int, 14 | d_dim: int, 15 | f=dy.tanh, 16 | init: dy.PyInitializer = dy.GlorotInitializer()): 17 | 18 | pc = model.add_subcollection() 19 | init_W = init_wrap(init, (h_dim, d_dim)) 20 | self.W = pc.add_parameters((h_dim, d_dim), init=init_W) 21 | init_B = init_wrap(init, (h_dim, h_dim)) 22 | self.B = pc.add_parameters((h_dim, h_dim), init=init_B) 23 | 24 | self.pc, self.f = pc, f 25 | self.spec = (h_dim, d_dim, f, init) 26 | 27 | def __call__(self, H, D): 28 | return self.f(self.W * H + self.B * D) 29 | -------------------------------------------------------------------------------- /antu/nn/dynet/modules/layer_norm.py: -------------------------------------------------------------------------------- 1 | import dynet as dy 2 | from . import dy_model 3 | 4 | 5 | @dy_model 6 | class LayerNorm: 7 | "Construct a layernorm module (See citation for details)." 8 | 9 | def __init__(self, model, features, eps=1e-6): 10 | pc = model.add_subcollection() 11 | 12 | self.a_2 = pc.add_parameters(features, init=1) 13 | self.b_2 = pc.add_parameters(features, init=0) 14 | self.eps = eps 15 | self.spec = (features, eps) 16 | 17 | 18 | def __call__(self, x): 19 | # mean = x.mean(-1, keepdim=True) 20 | mean = dy.mean_elems(x) 21 | # std = x.std(-1, keepdim=True) 22 | std = dy.std(x) 23 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 24 | -------------------------------------------------------------------------------- /antu/nn/dynet/modules/linear.py: -------------------------------------------------------------------------------- 1 | import dynet as dy 2 | import math 3 | from . import dy_model 4 | from ..init import init_wrap 5 | 6 | 7 | @dy_model 8 | class Linear: 9 | "Construct a Affine Transformation." 10 | 11 | def __init__( 12 | self, 13 | model: dy.ParameterCollection, 14 | in_dim: int, 15 | out_dim: int, 16 | bias: bool = True, 17 | init: dy.PyInitializer = dy.GlorotInitializer()): 18 | 19 | pc = model.add_subcollection() 20 | init = init_wrap(init, (out_dim, in_dim)) 21 | self.W = pc.add_parameters((out_dim, in_dim), init=init) 22 | if bias: 23 | self.b = pc.add_parameters((out_dim,), init=0) 24 | self.pc = pc 25 | self.bias = bias 26 | self.spec = (in_dim, out_dim, bias, init) 27 | 28 | def __call__(self, x): 29 | b = self.b if self.bias else 0 30 | return self.W * x + b 31 | -------------------------------------------------------------------------------- /antu/nn/dynet/modules/perceptron.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import dynet as dy 3 | from . 
import dy_model 4 | from ..init import init_wrap 5 | 6 | 7 | @dy_model 8 | class MLP(object): 9 | """docstring for MLP""" 10 | 11 | def __init__( 12 | self, 13 | model: dy.ParameterCollection, 14 | sizes: List[int], 15 | f: 'nonlinear' = dy.tanh, 16 | p: float = 0.0, 17 | bias: bool = True, 18 | init: dy.PyInitializer = dy.GlorotInitializer()): 19 | 20 | pc = model.add_subcollection() 21 | self.W = [ 22 | pc.add_parameters((x, y), init=init_wrap(init, (x, y))) 23 | for x, y in zip(sizes[1:], sizes[:-1])] 24 | if bias: 25 | self.b = [pc.add_parameters((y,), init=0) for y in sizes[1:]] 26 | 27 | self.pc, self.f, self.p, self.bias = pc, f, p, bias 28 | self.spec = (sizes, f, p, bias, init) 29 | 30 | def __call__(self, x, train=False): 31 | h = x 32 | for i in range(len(self.W[:-1])): 33 | h = self.f(self.W[i]*h + (self.b[i] if self.bias else 0)) 34 | if train: 35 | if len(h.dim()[0]) > 1: 36 | h = dy.dropout_dim(h, 1, self.p) 37 | else: 38 | h = dy.dropout(h, self.p) 39 | return self.W[-1]*h + (self.b[-1] if self.bias else 0) 40 | -------------------------------------------------------------------------------- /antu/nn/dynet/modules/sublayer.py: -------------------------------------------------------------------------------- 1 | from .layer_norm import LayerNorm 2 | from . import dy_model 3 | 4 | 5 | @dy_model 6 | class SublayerConnection: 7 | """ 8 | A residual connection followed by a layer norm. 9 | Note for code simplicity the norm is first as opposed to last. 10 | """ 11 | 12 | def __init__(self, model, size, p): 13 | pc = model.add_subcollection() 14 | 15 | self.norm = LayerNorm(pc, size) 16 | self.p = p 17 | self.spec = (size, p) 18 | 19 | def __call__(self, x, sublayer): 20 | "Apply residual connection to any sublayer with the same size." 21 | return x + dy.dropout(sublayer(self.norm(x)), self.p) 22 | -------------------------------------------------------------------------------- /antu/nn/dynet/modules/transformer.py: -------------------------------------------------------------------------------- 1 | from antu.nn.dynet.attention import MultiHeadedAttention 2 | from antu.nn.dynet.utils import SublayerConnection, PositionwiseFeedForward 3 | import dynet as dy 4 | 5 | 6 | class TransformerBlock: 7 | """ 8 | Bidirectional Encoder = Transformer (self-attention) 9 | Transformer = MultiHead_Attention + Feed_Forward with sublayer connection 10 | """ 11 | 12 | def __init__( 13 | self, 14 | hidden: int, 15 | attn_heads: int, 16 | feed_forward_hidden: int, 17 | dropout: float): 18 | """ 19 | :param hidden: hidden size of transformer 20 | :param attn_heads: head sizes of multi-head attention 21 | :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size 22 | :param dropout: dropout rate 23 | """ 24 | 25 | self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden) 26 | self.feed_forward = PositionwiseFeedForward( 27 | d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout) 28 | self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout) 29 | self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout) 30 | self.dropout = dy.Dropout(p=dropout) 31 | 32 | def forward(self, x, mask): 33 | x = self.input_sublayer( 34 | x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask)) 35 | x = self.output_sublayer(x, self.feed_forward) 36 | return self.dropout(x) 37 | -------------------------------------------------------------------------------- /antu/nn/dynet/seq2seq_encoders/.rnn_builder.py.swp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AntNLP/antu/3256ada0784401b9677d9568e81f3a8792eebee7/antu/nn/dynet/seq2seq_encoders/.rnn_builder.py.swp -------------------------------------------------------------------------------- /antu/nn/dynet/seq2seq_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .seq2seq_encoder import Seq2SeqEncoder 2 | from .rnn_builder import DeepBiRNNBuilder 3 | from .rnn_builder import orthonormal_VanillaLSTMBuilder 4 | from .rnn_builder import orthonormal_CompactVanillaLSTMBuilder -------------------------------------------------------------------------------- /antu/nn/dynet/seq2seq_encoders/rnn_builder.py: -------------------------------------------------------------------------------- 1 | import dynet as dy 2 | import numpy as np 3 | from . import Seq2SeqEncoder 4 | from ..init import get_orthogonal_matrix 5 | from ..modules import dy_model 6 | 7 | 8 | @dy_model 9 | class DeepBiRNNBuilder(Seq2SeqEncoder): 10 | """This builds deep bidirectional LSTM: 11 | 12 | The original attention mechanism (Bahdanau et al., 2015) 13 | uses a one-hidden layer feed-forward network to calculate the attention alignment: 14 | 15 | :param model dynet.ParameterCollection: 16 | :param n_layers int: Number of LSTM layers 17 | :param x_dim int: Dimension of LSTM input :math:`\\boldsymbol{x}` 18 | :param h_dim int: Dimension of LSTM hidden state :math:`\\boldsymbol{h}` 19 | :param LSTMBuilder dynet._RNNBuilder: Dynet LSTM type 20 | :param param_init bool: Initializes LSTM with parameter 21 | :returns: (last_output, outputs) 22 | :rtype: tuple 23 | """ 24 | 25 | def __init__(self, model, n_layers, x_dim, h_dim, LSTMBuilder): 26 | pc = model.add_subcollection() 27 | self.DeepBiLSTM = [] 28 | f = LSTMBuilder(1, x_dim, h_dim, pc) 29 | b = LSTMBuilder(1, x_dim, h_dim, pc) 30 | self.DeepBiLSTM.append((f, b)) 31 | for i in range(n_layers-1): 32 | f = LSTMBuilder(1, h_dim*2, h_dim, pc) 33 | b = LSTMBuilder(1, h_dim*2, h_dim, pc) 34 | self.DeepBiLSTM.append((f, b)) 35 | 36 | self.pc = pc 37 | self.spec = (n_layers, x_dim, h_dim, LSTMBuilder) 38 | 39 | def __call__(self, inputs, init_vecs=None, p_x=0., p_h=0., out_mask=None, drop_mask=False, train=False): 40 | batch_size = inputs[0].dim()[1] 41 | 42 | if out_mask is not None: 43 | mask = dy.inputTensor(out_mask, True) 44 | for fnn, bnn in self.DeepBiLSTM: 45 | f, b = fnn.initial_state(update=True), bnn.initial_state(update=True) 46 | if train: 47 | fnn.set_dropouts(p_x, p_h) 48 | bnn.set_dropouts(p_x, p_h) 49 | if drop_mask: 50 | fnn.set_dropout_masks(batch_size) 51 | bnn.set_dropout_masks(batch_size) 52 | else: 53 | fnn.set_dropouts(0., 0.) 54 | bnn.set_dropouts(0., 0.) 
55 | if drop_mask: 56 | fnn.set_dropout_masks(batch_size) 57 | bnn.set_dropout_masks(batch_size) 58 | fs, bs = f.transduce(inputs), b.transduce(inputs[::-1]) 59 | inputs = [dy.concatenate([f, b]) for f, b in zip(fs, bs[::-1])] 60 | if out_mask is not None: 61 | inputs = [x*mask[i] for i, x in enumerate(inputs)] 62 | return inputs 63 | 64 | 65 | def orthonormal_VanillaLSTMBuilder(n_layers, x_dim, h_dim, pc): 66 | builder = dy.VanillaLSTMBuilder(n_layers, x_dim, h_dim, pc) 67 | 68 | for layer, params in enumerate(builder.get_parameters()): 69 | W = get_orthogonal_matrix( 70 | h_dim, h_dim + (h_dim if layer > 0 else x_dim)) 71 | W_h, W_x = W[:, :h_dim], W[:, h_dim:] 72 | params[0].set_value(np.concatenate([W_x]*4, 0)) 73 | params[1].set_value(np.concatenate([W_h]*4, 0)) 74 | b = np.zeros(4*h_dim, dtype=np.float32) 75 | b[h_dim:2*h_dim] = -1.0 76 | params[2].set_value(b) 77 | return builder 78 | 79 | 80 | def orthonormal_CompactVanillaLSTMBuilder(n_layers, x_dim, h_dim, pc): 81 | builder = dy.CompactVanillaLSTMBuilder(n_layers, x_dim, h_dim, pc) 82 | 83 | for layer, params in enumerate(builder.get_parameters()): 84 | W = get_orthogonal_matrix( 85 | h_dim, h_dim + (h_dim if layer > 0 else x_dim)) 86 | W_h, W_x = W[:, :h_dim], W[:, h_dim:] 87 | params[0].set_value(np.concatenate([W_x]*4, 0)) 88 | params[1].set_value(np.concatenate([W_h]*4, 0)) 89 | b = np.zeros(4*h_dim, dtype=np.float32) 90 | b[h_dim:2*h_dim] = -1.0 91 | params[2].set_value(b) 92 | return builder 93 |
-------------------------------------------------------------------------------- /antu/nn/dynet/seq2seq_encoders/seq2seq_encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import List 3 | import dynet as dy 4 | 5 | 6 | class Seq2SeqEncoder(metaclass=ABCMeta): 7 | """docstring for Seq2SeqEncoder""" 8 | 9 | @abstractmethod 10 | def __call__(self, inputs: List[dy.Expression]) -> List[dy.Expression]: 11 | pass 12 | 13 |
-------------------------------------------------------------------------------- /antu/nn/dynet/seq2vec_encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AntNLP/antu/3256ada0784401b9677d9568e81f3a8792eebee7/antu/nn/dynet/seq2vec_encoders/__init__.py
-------------------------------------------------------------------------------- /antu/nn/dynet/seq2vec_encoders/char2word_embedder.py: -------------------------------------------------------------------------------- 1 | import dynet as dy 2 | from ..modules import dy_model 3 | 4 | 5 | @dy_model 6 | class Char2WordCNNEmbedder(object): 7 | """This builds a char-to-word embedder with a CNN: 8 | :param model dynet.ParameterCollection: 9 | :param n_char int: Number of characters 10 | :param char_dim int: Dimension of char embedding 11 | :param n_filter int: Number of CNN filters 12 | :param win_sizes list: Filter width list 13 | :returns: c2w_emb 14 | :rtype: list 15 | """ 16 | 17 | def __init__(self, model, n_char, char_dim, n_filter, win_sizes): 18 | pc = model.add_subcollection() 19 | 20 | self.clookup = pc.add_lookup_parameters((n_char, char_dim)) 21 | self.Ws = [pc.add_parameters((char_dim, size, 1, n_filter), 22 | init=dy.GlorotInitializer(gain=0.5)) 23 | for size in win_sizes] 24 | self.bs = [pc.add_parameters((n_filter), 25 | init=dy.ConstInitializer(0)) 26 | for _ in win_sizes] 27 | 28 | self.win_sizes = win_sizes 29 | self.pc = pc 30 | self.spec = (n_char, char_dim, n_filter, win_sizes) 31 |
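# __call__ (below) builds one word vector per token: character embeddings are looked up via c2i / self.clookup, one dy.conv2d per window size in win_sizes is applied, followed by the given activation, max-pooling over the character positions, and concatenation of the pooled filter outputs; the whole sentence is processed as a single dynet batch.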
32 | def __call__(self, sentence, c2i, maxn_char, act, train=False): 33 | words_batch = [] 34 | for token in sentence: 35 | chars_emb = [self.clookup[int(c2i.get(c, 0))] for c in token.chars] 36 | c2w = dy.concatenate_cols(chars_emb) 37 | c2w = dy.reshape(c2w, tuple(list(c2w.dim()[0]) + [1])) 38 | words_batch.append(c2w) 39 | 40 | words_batch = dy.concatenate_to_batch(words_batch) 41 | convds = [dy.conv2d(words_batch, W, stride=( 42 | 1, 1), is_valid=True) for W in self.Ws] 43 | actds = [act(convd) for convd in convds] 44 | poolds = [dy.maxpooling2d(actd, ksize=(1, maxn_char-win_size+1), stride=(1, 1)) 45 | for win_size, actd in zip(self.win_sizes, actds)] 46 | words_batch = [dy.reshape(poold, (poold.dim()[0][2],)) 47 | for poold in poolds] 48 | words_batch = dy.concatenate([out for out in words_batch]) 49 | 50 | c2w_emb = [] 51 | for idx, token in enumerate(sentence): 52 | c2w_emb.append(dy.pick_batch_elem(words_batch, idx)) 53 | return c2w_emb 54 | -------------------------------------------------------------------------------- /antu/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AntNLP/antu/3256ada0784401b9677d9568e81f3a8792eebee7/antu/utils/__init__.py -------------------------------------------------------------------------------- /antu/utils/case_sensitive_configurator.py: -------------------------------------------------------------------------------- 1 | from configparser import ConfigParser 2 | from overrides import overrides 3 | 4 | 5 | class CaseSensConfigParser(ConfigParser): 6 | 7 | def __init__(self, defaults=None): 8 | ConfigParser.__init__(self, defaults=None) 9 | 10 | @overrides 11 | def optionxform(self, optionstr): 12 | return optionstr 13 | -------------------------------------------------------------------------------- /antu/utils/dual_channel_logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def dual_channel_logger( 5 | name: str, 6 | level: int=logging.DEBUG, 7 | file_path: str=None, 8 | file_model: str='w+', 9 | file_level: int=logging.DEBUG, 10 | console_level: int=logging.DEBUG, 11 | formatter: str='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 12 | time_formatter: str='%y-%m-%d %H:%M:%S',) -> logging.Logger: 13 | 14 | logger = logging.getLogger(name) 15 | logger.setLevel(logging.DEBUG) 16 | formatter = logging.Formatter(formatter, time_formatter) 17 | console_handler = logging.StreamHandler() 18 | console_handler.setLevel(console_level) 19 | console_handler.setFormatter(formatter) 20 | logger.addHandler(console_handler) 21 | if file_path: 22 | file_handler = logging.FileHandler(file_path, file_model) 23 | file_handler.setLevel(file_level) 24 | file_handler.setFormatter(formatter) 25 | logger.addHandler(file_handler) 26 | return logger 27 | 28 | 29 | -------------------------------------------------------------------------------- /antu/utils/padding_function.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def shadow_padding(batch_input, vocabulary): 4 | maxlen = 0 5 | for ins in batch_input: 6 | for field_name in ins: 7 | if isinstance(ins[field_name], dict): 8 | for vocab_name in ins[field_name]: 9 | maxlen = max(maxlen, len(ins[field_name][vocab_name])) 10 | else: 11 | maxlen = max(maxlen, len(ins[field_name])) 12 | 13 | masks = dict() 14 | inputs = dict() 15 | for ins in batch_input: 16 | for field_name in ins: 17 | if isinstance(ins[field_name], 
list): 18 | if field_name not in masks: 19 | masks[field_name] = dict() 20 | masks[field_name]['1D'] = list() 21 | if field_name not in inputs: 22 | inputs[field_name] = list() 23 | padding_length = maxlen - len(ins[field_name]) 24 | inputs[field_name].append(ins[field_name] + [0] * padding_length) 25 | ins_mask = [1]*(maxlen-padding_length) + [0]*padding_length 26 | masks[field_name]['1D'].append(ins_mask) 27 | else: 28 | # Build batch input 29 | if field_name not in masks: 30 | masks[field_name] = dict() 31 | if field_name not in inputs: 32 | inputs[field_name] = dict() 33 | for vocab_name in ins[field_name]: 34 | padding_length = maxlen - len(ins[field_name][vocab_name]) 35 | if vocab_name not in vocabulary.no_pad_namespace: 36 | padding_index = vocabulary.get_padding_index(vocab_name) 37 | else: padding_index = 0 38 | padding_list = [padding_index] * padding_length 39 | ins_input = ins[field_name][vocab_name] + padding_list 40 | ins_mask = [1]*(maxlen-padding_length) + [0]*padding_length 41 | if vocab_name not in masks[field_name]: 42 | masks[field_name][vocab_name] = dict() 43 | masks[field_name][vocab_name]['1D'] = list() 44 | if vocab_name not in inputs[field_name]: 45 | inputs[field_name][vocab_name] = list() 46 | masks[field_name][vocab_name]['1D'].append(ins_mask) 47 | inputs[field_name][vocab_name].append(ins_input) 48 | 49 | # Build [1D], [2D], [Flat] masks 50 | zero = [0] * maxlen 51 | for _, field in masks.items(): 52 | if '1D' not in field: 53 | for _, vocab in field.items(): 54 | vocab['2D'] = list() # batch_size * sent_len 55 | vocab['flat'] = list() 56 | for ins in vocab['1D']: 57 | no_pad = sum(ins) 58 | vocab['2D'].append([ins] * no_pad) 59 | vocab['2D'][-1].extend([zero] * (maxlen-no_pad)) 60 | vocab['flat'].extend(ins) 61 | else: 62 | field['2D'] = list() # batch_size * sent_len 63 | field['flat'] = list() 64 | for ins in field['1D']: 65 | no_pad = sum(ins) 66 | field['2D'].append([ins] * no_pad) 67 | field['2D'][-1].extend([zero] * (maxlen-no_pad)) 68 | field['flat'].extend(ins) 69 | 70 | return inputs, masks -------------------------------------------------------------------------------- /antu/utils/top_k_indexes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def top_k_2D_col_indexes(arr: np.ndarray, k: int): 5 | assert (len(arr.shape) == 2 and k >= 0 and k <= arr.size) 6 | tot_size = arr.size 7 | num_row = arr.shape[0] 8 | res = np.argpartition(arr.T.reshape((tot_size,)), -k)[-k:] // num_row 9 | return res 10 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /doc/source/api/antu.io.dataset_readers.rst: -------------------------------------------------------------------------------- 1 | antu.io.dataset\_readers package 2 | ================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | antu.io.dataset\_readers.dataset\_reader module 8 | ----------------------------------------------- 9 | 10 | .. automodule:: antu.io.dataset_readers.dataset_reader 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: antu.io.dataset_readers 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /doc/source/api/antu.io.datasets.rst: -------------------------------------------------------------------------------- 1 | antu.io.datasets package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | antu.io.datasets.dataset module 8 | ------------------------------- 9 | 10 | .. automodule:: antu.io.datasets.dataset 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: antu.io.datasets 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /doc/source/api/antu.io.fields.rst: -------------------------------------------------------------------------------- 1 | antu.io.fields package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | antu.io.fields.field module 8 | --------------------------- 9 | 10 | .. automodule:: antu.io.fields.field 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | antu.io.fields.index\_field module 16 | ---------------------------------- 17 | 18 | .. automodule:: antu.io.fields.index_field 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | antu.io.fields.sequence\_label\_field module 24 | -------------------------------------------- 25 | 26 | .. automodule:: antu.io.fields.sequence_label_field 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | antu.io.fields.text\_field module 32 | --------------------------------- 33 | 34 | .. 
automodule:: antu.io.fields.text_field 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: antu.io.fields 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /doc/source/api/antu.io.rst: -------------------------------------------------------------------------------- 1 | antu.io package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | antu.io.dataset_readers 10 | antu.io.datasets 11 | antu.io.fields 12 | antu.io.token_indexers 13 | 14 | Submodules 15 | ---------- 16 | 17 | antu.io.instance module 18 | ----------------------- 19 | 20 | .. automodule:: antu.io.instance 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | antu.io.vocabulary module 26 | ------------------------- 27 | 28 | .. automodule:: antu.io.vocabulary 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | 34 | Module contents 35 | --------------- 36 | 37 | .. automodule:: antu.io 38 | :members: 39 | :undoc-members: 40 | :show-inheritance: 41 | -------------------------------------------------------------------------------- /doc/source/api/antu.io.token_indexers.rst: -------------------------------------------------------------------------------- 1 | antu.io.token\_indexers package 2 | =============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | antu.io.token\_indexers.char\_token\_indexer module 8 | --------------------------------------------------- 9 | 10 | .. automodule:: antu.io.token_indexers.char_token_indexer 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | antu.io.token\_indexers.single\_id\_token\_indexer module 16 | --------------------------------------------------------- 17 | 18 | .. automodule:: antu.io.token_indexers.single_id_token_indexer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | antu.io.token\_indexers.token\_indexer module 24 | --------------------------------------------- 25 | 26 | .. automodule:: antu.io.token_indexers.token_indexer 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. automodule:: antu.io.token_indexers 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /doc/source/api/antu.nn.dynet.rst: -------------------------------------------------------------------------------- 1 | antu.nn.dynet package 2 | ===================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | antu.nn.dynet.attention\_mechanism module 8 | ----------------------------------------- 9 | 10 | .. automodule:: antu.nn.dynet.attention_mechanism 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | antu.nn.dynet.char2word\_embedder module 16 | ---------------------------------------- 17 | 18 | .. automodule:: antu.nn.dynet.char2word_embedder 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | antu.nn.dynet.initializer module 24 | -------------------------------- 25 | 26 | .. automodule:: antu.nn.dynet.initializer 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | antu.nn.dynet.multi\_layer\_perception module 32 | --------------------------------------------- 33 | 34 | .. 
automodule:: antu.nn.dynet.multi_layer_perception 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | antu.nn.dynet.nn\_classifier module 40 | ----------------------------------- 41 | 42 | .. automodule:: antu.nn.dynet.nn_classifier 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | antu.nn.dynet.rnn\_builder module 48 | --------------------------------- 49 | 50 | .. automodule:: antu.nn.dynet.rnn_builder 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | 56 | Module contents 57 | --------------- 58 | 59 | .. automodule:: antu.nn.dynet 60 | :members: 61 | :undoc-members: 62 | :show-inheritance: 63 | -------------------------------------------------------------------------------- /doc/source/api/antu.nn.rst: -------------------------------------------------------------------------------- 1 | antu.nn package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | antu.nn.dynet 10 | 11 | Module contents 12 | --------------- 13 | 14 | .. automodule:: antu.nn 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | -------------------------------------------------------------------------------- /doc/source/api/antu.rst: -------------------------------------------------------------------------------- 1 | antu package 2 | ============ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | antu.io 10 | antu.nn 11 | 12 | Module contents 13 | --------------- 14 | 15 | .. automodule:: antu 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | -------------------------------------------------------------------------------- /doc/source/api/modules.rst: -------------------------------------------------------------------------------- 1 | antu 2 | ==== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | antu 8 | setup 9 | -------------------------------------------------------------------------------- /doc/source/api/setup.rst: -------------------------------------------------------------------------------- 1 | setup module 2 | ============ 3 | 4 | .. automodule:: setup 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath('../..')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'antu' 23 | copyright = '2018, AntNLP' 24 | author = 'AntNLP' 25 | 26 | # The short X.Y version 27 | version = '' 28 | # The full version, including alpha/beta/rc tags 29 | release = '0.0.1' 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # If your documentation needs a minimal Sphinx version, state it here. 
35 | # 36 | # needs_sphinx = '1.0' 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | extensions = [ 42 | 'sphinx.ext.autodoc', 43 | 'sphinx.ext.doctest', 44 | 'sphinx.ext.intersphinx', 45 | 'sphinx.ext.mathjax', 46 | 'sphinx.ext.viewcode', 47 | 'sphinx.ext.coverage', 48 | 'numpydoc' 49 | ] 50 | 51 | # Add any paths that contain templates here, relative to this directory. 52 | templates_path = ['_templates'] 53 | 54 | # The suffix(es) of source filenames. 55 | # You can specify multiple suffix as a list of string: 56 | # 57 | # source_suffix = ['.rst', '.md'] 58 | source_suffix = '.rst' 59 | 60 | # The master toctree document. 61 | master_doc = 'index' 62 | 63 | # The language for content autogenerated by Sphinx. Refer to documentation 64 | # for a list of supported languages. 65 | # 66 | # This is also used if you do content translation via gettext catalogs. 67 | # Usually you set "language" from the command line for these cases. 68 | language = None 69 | 70 | # List of patterns, relative to source directory, that match files and 71 | # directories to ignore when looking for source files. 72 | # This pattern also affects html_static_path and html_extra_path. 73 | exclude_patterns = [] 74 | 75 | # The name of the Pygments (syntax highlighting) style to use. 76 | pygments_style = None 77 | 78 | 79 | # -- Options for HTML output ------------------------------------------------- 80 | 81 | # The theme to use for HTML and HTML Help pages. See the documentation for 82 | # a list of builtin themes. 83 | # 84 | html_theme = 'sphinx_rtd_theme' 85 | 86 | # Theme options are theme-specific and customize the look and feel of a theme 87 | # further. For a list of options available for each theme, see the 88 | # documentation. 89 | # 90 | # html_theme_options = {} 91 | 92 | # Add any paths that contain custom static files (such as style sheets) here, 93 | # relative to this directory. They are copied after the builtin static files, 94 | # so a file named "default.css" will overwrite the builtin "default.css". 95 | html_static_path = ['_static'] 96 | 97 | # Custom sidebar templates, must be a dictionary that maps document names 98 | # to template names. 99 | # 100 | # The default sidebars (for documents that don't match any pattern) are 101 | # defined by theme itself. Builtin themes are using these templates by 102 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 103 | # 'searchbox.html']``. 104 | # 105 | # html_sidebars = {} 106 | 107 | 108 | # -- Options for HTMLHelp output --------------------------------------------- 109 | 110 | # Output file base name for HTML help builder. 111 | htmlhelp_basename = 'antudoc' 112 | 113 | 114 | # -- Options for LaTeX output ------------------------------------------------ 115 | 116 | latex_elements = { 117 | # The paper size ('letterpaper' or 'a4paper'). 118 | # 119 | # 'papersize': 'letterpaper', 120 | 121 | # The font size ('10pt', '11pt' or '12pt'). 122 | # 123 | # 'pointsize': '10pt', 124 | 125 | # Additional stuff for the LaTeX preamble. 126 | # 127 | # 'preamble': '', 128 | 129 | # Latex figure (float) alignment 130 | # 131 | # 'figure_align': 'htbp', 132 | } 133 | 134 | # Grouping the document tree into LaTeX files. List of tuples 135 | # (source start file, target name, title, 136 | # author, documentclass [howto, manual, or own class]). 
137 | latex_documents = [ 138 | (master_doc, 'antu.tex', 'antu Documentation', 139 | 'AntNLP', 'manual'), 140 | ] 141 | 142 | 143 | # -- Options for manual page output ------------------------------------------ 144 | 145 | # One entry per manual page. List of tuples 146 | # (source start file, name, description, authors, manual section). 147 | man_pages = [ 148 | (master_doc, 'antu', 'antu Documentation', 149 | [author], 1) 150 | ] 151 | 152 | 153 | # -- Options for Texinfo output ---------------------------------------------- 154 | 155 | # Grouping the document tree into Texinfo files. List of tuples 156 | # (source start file, target name, title, author, 157 | # dir menu entry, description, category) 158 | texinfo_documents = [ 159 | (master_doc, 'antu', 'antu Documentation', 160 | author, 'antu', 'One line description of project.', 161 | 'Miscellaneous'), 162 | ] 163 | 164 | 165 | # -- Options for Epub output ------------------------------------------------- 166 | 167 | # Bibliographic Dublin Core info. 168 | epub_title = project 169 | 170 | # The unique identifier of the text. This can be a ISBN number 171 | # or the project homepage. 172 | # 173 | # epub_identifier = '' 174 | 175 | # A unique identification for the text. 176 | # 177 | # epub_uid = '' 178 | 179 | # A list of files that should not be packed into the epub file. 180 | epub_exclude_files = ['search.html'] 181 | 182 | 183 | # -- Extension configuration ------------------------------------------------- 184 | 185 | # -- Options for intersphinx extension --------------------------------------- 186 | 187 | # Example configuration for intersphinx: refer to the Python standard library. 188 | intersphinx_mapping = {'https://docs.python.org/': None} -------------------------------------------------------------------------------- /doc/source/index.rst: -------------------------------------------------------------------------------- 1 | .. antu documentation master file, created by 2 | sphinx-quickstart on Mon Dec 24 17:36:31 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to antu's documentation! 7 | ================================ 8 | 9 | Universal data IO and neural network modules in NLP tasks. 10 | 11 | * The *data IO* module is a universal component of Natural Language Processing systems and does not depend on any particular framework (TensorFlow, PyTorch, MXNet, DyNet, ...). 12 | * The *neural network* module contains the network structures commonly used in NLP tasks. We aim to provide commonly used structures for each neural network framework and will continue to develop this module. 13 | 14 | .. 
toctree:: 15 | :maxdepth: 2 16 | :caption: Contents: 17 | 18 | api/antu.io 19 | api/antu.nn 20 | 21 | Indices and tables 22 | ================== 23 | 24 | * :ref:`genindex` 25 | * :ref:`modindex` 26 | * :ref:`search` 27 | -------------------------------------------------------------------------------- /examples/dependency_parsing/conllu_reader.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Dict 2 | from overrides import overrides 3 | import re, sys 4 | from collections import Counter 5 | from antu.io.instance import Instance 6 | from antu.io.fields.field import Field 7 | from antu.io.fields.text_field import TextField 8 | from antu.io.fields.index_field import IndexField 9 | from antu.io.token_indexers.token_indexer import TokenIndexer 10 | from antu.io.token_indexers.single_id_token_indexer import SingleIdTokenIndexer 11 | from antu.io.dataset_readers.dataset_reader import DatasetReader 12 | 13 | 14 | class PTBReader(DatasetReader): 15 | 16 | def __init__( 17 | self, 18 | field_list: List[str], 19 | root: str, 20 | spacer: str): 21 | 22 | self.field_list = field_list 23 | self.root = root 24 | self.spacer = spacer 25 | 26 | def _read(self, file_path: str) -> Instance: 27 | with open(file_path, 'rt') as fp: 28 | root_token = re.split(self.spacer, self.root) 29 | tokens = [[item,] for item in root_token] 30 | for line in fp: 31 | token = re.split(self.spacer, line.strip()) 32 | if line.strip() == '': 33 | if len(tokens[0]) > 1: yield tokens 34 | tokens = [[item,] for item in root_token] 35 | else: 36 | for idx, item in enumerate(token): 37 | tokens[idx].append(item) 38 | if len(tokens[0]) > 1: yield tokens 39 | 40 | @overrides 41 | def read(self, file_path: str) -> List[Instance]: 42 | # Build indexers 43 | indexers = dict() 44 | word_indexer = SingleIdTokenIndexer( 45 | ['word', 'glove'], (lambda x:x.casefold())) 46 | indexers['word'] = [word_indexer,] 47 | tag_indexer = SingleIdTokenIndexer(['tag']) 48 | indexers['tag'] = [tag_indexer,] 49 | rel_indexer = SingleIdTokenIndexer(['rel']) 50 | indexers['rel'] = [rel_indexer,] 51 | 52 | # Build instance list 53 | res = [] 54 | for sentence in self._read(file_path): 55 | res.append(self.input_to_instance(sentence, indexers)) 56 | return res 57 | 58 | @overrides 59 | def input_to_instance( 60 | self, 61 | inputs: List[List[str]], 62 | indexers: Dict[str, List[TokenIndexer]]) -> Instance: 63 | fields = [] 64 | if 'word' in self.field_list: 65 | fields.append(TextField('word', inputs[1], indexers['word'])) 66 | if 'tag' in self.field_list: 67 | fields.append(TextField('tag', inputs[3], indexers['tag'])) 68 | if 'head' in self.field_list: 69 | fields.append(IndexField('head', inputs[6])) 70 | if 'rel' in self.field_list: 71 | fields.append(TextField('rel', inputs[7], indexers['rel'])) 72 | return Instance(fields) -------------------------------------------------------------------------------- /examples/dependency_parsing/train_parser.py: -------------------------------------------------------------------------------- 1 | import argparse, _pickle, math, os, random, sys, time, logging 2 | random.seed(666) 3 | import numpy as np 4 | np.random.seed(666) 5 | from collections import Counter 6 | from antu.io.vocabulary import Vocabulary 7 | from antu.io.ext_embedding_readers import glove_reader 8 | from antu.io.datasets.single_task_dataset import DatasetSetting, SingleTaskDataset 9 | from utils.conllu_reader import PTBReader 10 | 11 | 12 | def main(): 13 | # Configuration file 
processing 14 | ... 15 | 16 | # DyNet setting 17 | ... 18 | 19 | # Build the dataset of the training process 20 | ## Build data reader 21 | data_reader = PTBReader( 22 | field_list=['word', 'tag', 'head', 'rel'], 23 | root='0\t**root**\t_\t**rpos**\t_\t_\t0\t**rrel**\t_\t_', 24 | spacer=r'[\t]',) 25 | ## Build vocabulary with pretrained glove 26 | vocabulary = Vocabulary() 27 | g_word, _ = glove_reader(cfg.GLOVE) 28 | pretrained_vocabs = {'glove': g_word} 29 | vocabulary.extend_from_pretrained_vocab(pretrained_vocabs) 30 | ## Setup datasets 31 | datasets_settings = { 32 | 'train': DatasetSetting(cfg.TRAIN, True), 33 | 'dev': DatasetSetting(cfg.DEV, True), 34 | 'test': DatasetSetting(cfg.TEST, True),} 35 | datasets = SingleTaskDataset(vocabulary, datasets_settings, data_reader) 36 | counters = {'word': Counter(), 'tag': Counter(), 'rel': Counter()} 37 | datasets.build_dataset( 38 | counters, no_pad_namespace={'rel'}, no_unk_namespace={'rel'}) 39 | 40 | # Build model 41 | ... 42 | 43 | # Train model 44 | train_batch = datasets.get_batches('train', cfg.TRAIN_BATCH_SIZE, True, cmp, True) 45 | valid_batch = datasets.get_batches('dev', cfg.TEST_BATCH_SIZE, True, cmp, False) 46 | test_batch = datasets.get_batches('test', cfg.TEST_BATCH_SIZE, True, cmp, False) 47 | 48 | 49 | 50 | if __name__ == '__main__': 51 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bidict==0.17.5 2 | numpy==1.15.4 3 | numpydoc==0.8.0 4 | overrides==1.9 5 | pytest==4.0.2 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | try: 5 | from setuptools import setup, find_packages 6 | except: 7 | from distutils.core import setup 8 | 9 | def read_file(fname): 10 | return open(os.path.join(os.path.dirname(__file__), fname)).read() 11 | 12 | setup( 13 | name='antu', 14 | version='0.0.5a', 15 | author='AntNLP', 16 | author_email='taoji.cs@gmail.com', 17 | description='Universal data IO and neural network modules in NLP tasks', 18 | long_description = read_file("README.md"), 19 | license='Apache', 20 | packages=find_packages(), 21 | install_requires=[], 22 | classifiers = [ 23 | 'License :: OSI Approved :: Apache Software License', 24 | 'Programming Language :: Python :: 3 :: Only', 25 | 'Topic :: Documentation :: Sphinx', 26 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 27 | ], 28 | url ="https://github.com/AntNLP/antu", 29 | zip_safe=True, 30 | include_package_data=True, 31 | platforms='any', 32 | ) 33 | -------------------------------------------------------------------------------- /test/io/configurators/ini_configurator_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from antu.io.configurators import IniConfigurator 3 | import os 4 | 5 | 6 | class TestIniConfigurator: 7 | 8 | def setup(self): 9 | with open('tmp_test_ini_configurator.ini', 'w') as f: 10 | test = [ 11 | "[Test1]\n", 12 | "A = 123 \n", 13 | "B = 1.1 \n", 14 | "[Test2]\n", 15 | "C = add\n", 16 | "E = 1+3\n", 17 | "D = 7*2\n", 18 | "F = %(E)s*2\n", 19 | "G = %(C)sdda\n", 20 | ] 21 | f.writelines(test) 22 | 23 | def test_ini_configurator(self): 24 | cfg = IniConfigurator('tmp_test_ini_configurator.ini') 25 | assert cfg.A == 123 26 | assert cfg.B == 1.1 27 | assert cfg.C == 'add' 28 | 
assert cfg.E == 4 29 | assert cfg.D == 14 30 | assert cfg.F == 7 31 | assert cfg.G == 'adddda' 32 | 33 | def teardown(self): 34 | tmp_file = 'tmp_test_ini_configurator.ini' 35 | if os.path.exists(tmp_file): 36 | os.remove(tmp_file) 37 | -------------------------------------------------------------------------------- /test/io/fields/text_field_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from antu.io.fields import TextField 3 | 4 | 5 | class TestTextField: 6 | 7 | def test_textfield(self): 8 | sentence = ['This', 'is', 'a', 'test', 'sentence', '.'] 9 | sent = TextField('sentence', sentence) 10 | print(sent) 11 | assert sent[0] == 'This' 12 | assert sent[-1] == '.' 13 | assert str(sent) == 'sentence: [This, is, a, test, sentence, .]' 14 | -------------------------------------------------------------------------------- /test/io/instance_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from antu.io.token_indexers import SingleIdTokenIndexer, CharTokenIndexer 3 | from antu.io.fields import TextField 4 | from collections import Counter 5 | from antu.io import Vocabulary, Instance 6 | 7 | 8 | class TestInstance: 9 | 10 | def test_instance(self): 11 | sentence = ['This', 'is', 'is', 'a', 'a', 'test', 'sentence'] 12 | counter = {'my_word': Counter(), 'my_char': Counter()} 13 | vocab = Vocabulary() 14 | glove = ['This', 'is', 'glove', 'sentence', 'vocabulary'] 15 | vocab.extend_from_pretrained_vocab({'glove': glove}) 16 | single_id = SingleIdTokenIndexer(['my_word', 'glove']) 17 | char = CharTokenIndexer(['my_char']) 18 | sent = TextField('sentence', sentence, [single_id, char]) 19 | data = Instance([sent]) 20 | 21 | # Test count_vocab_items() 22 | data.count_vocab_items(counter) 23 | assert counter['my_word']['This'] == 1 24 | assert counter['my_word']['is'] == 2 25 | assert counter['my_word']['That'] == 0 26 | assert counter['my_char']['s'] == 5 27 | assert counter['my_char']['T'] == 1 28 | assert counter['my_char']['t'] == 3 29 | assert counter['my_char']['A'] == 0 30 | 31 | vocab.extend_from_counter(counter) 32 | 33 | # Test index() 34 | result = data.index_fields(vocab) 35 | assert result['sentence']['glove'] == [2, 3, 3, 0, 0, 0, 5] 36 | assert result['sentence']['my_word'] == [2, 3, 3, 4, 4, 5, 6] 37 | assert result['sentence']['my_char'][0] == [2, 3, 4, 5] # 'This' 38 | assert result['sentence']['my_char'][1] == result['sentence']['my_char'][2] 39 | assert result['sentence']['my_char'][3] == result['sentence']['my_char'][4] 40 | -------------------------------------------------------------------------------- /test/io/token_indexers/char_token_indexer_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from antu.io.token_indexers import CharTokenIndexer 3 | from antu.io.fields import TextField 4 | from collections import Counter 5 | from antu.io import Vocabulary 6 | 7 | 8 | class TestCharTokenIndexer: 9 | 10 | def test_char_token_indexer(self): 11 | sentence = ['This', 'is', 'is', 'a', 'a', 'test', 'sentence'] 12 | counter = {'my_char': Counter()} 13 | vocab = Vocabulary() 14 | glove = ['a', 'b', 'c', 'd', 'e'] 15 | vocab.extend_from_pretrained_vocab({'glove': glove}) 16 | indexer = CharTokenIndexer(['my_char', 'glove']) 17 | sent = TextField('sentence', sentence, [indexer]) 18 | 19 | # Test count_vocab_items() 20 | sent.count_vocab_items(counter) 21 | assert counter['my_char']['s'] == 5 22 | assert 
counter['my_char']['T'] == 1 23 | assert counter['my_char']['t'] == 3 24 | assert counter['my_char']['A'] == 0 25 | 26 | vocab.extend_from_counter(counter) 27 | 28 | # Test index() 29 | sent.index(vocab) 30 | assert sent.indexes['glove'][0] == [0, 0, 0, 0] # 'This' 31 | assert sent.indexes['glove'][3] == [2] # 'a' 32 | assert sent.indexes['my_char'][0] == [2, 3, 4, 5] # 'This' 33 | -------------------------------------------------------------------------------- /test/io/token_indexers/single_id_token_indexer_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from antu.io.token_indexers import SingleIdTokenIndexer 3 | from antu.io.fields import TextField 4 | from collections import Counter 5 | from antu.io import Vocabulary 6 | 7 | 8 | class TestSingleIdTokenIndexer: 9 | 10 | def test_single_id_token_indexer(self): 11 | sentence = ['This', 'is', 'is', 'a', 'a', 'test', 'sentence'] 12 | counter = {'my_word': Counter()} 13 | vocab = Vocabulary() 14 | glove = ['This', 'is', 'glove', 'sentence', 'vocabulary'] 15 | vocab.extend_from_pretrained_vocab({'glove': glove}) 16 | indexer = SingleIdTokenIndexer(['my_word', 'glove']) 17 | sent = TextField('sentence', sentence, [indexer]) 18 | 19 | # Test count_vocab_items() 20 | sent.count_vocab_items(counter) 21 | assert counter['my_word']['This'] == 1 22 | assert counter['my_word']['is'] == 2 23 | assert counter['my_word']['That'] == 0 24 | 25 | vocab.extend_from_counter(counter) 26 | 27 | # Test index() 28 | sent.index(vocab) 29 | assert sent.indexes['glove'] == [2, 3, 3, 0, 0, 0, 5] 30 | assert sent.indexes['my_word'] == [2, 3, 3, 4, 4, 5, 6] 31 | -------------------------------------------------------------------------------- /test/io/vocabulary_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from antu.io import Vocabulary 3 | from collections import Counter 4 | 5 | 6 | class TestVocabulary: 7 | 8 | def test_extend_from_pretrained_vocab(self): 9 | vocab = Vocabulary() 10 | 11 | # Test extend a vocabulary from a simple pretained vocab 12 | pretrained_vocabs = {'glove': ['a', 'b', 'c']} 13 | vocab.extend_from_pretrained_vocab(pretrained_vocabs) 14 | assert vocab.get_token_index('a', 'glove') == 2 15 | assert vocab.get_token_index('c', 'glove') == 4 16 | assert vocab.get_token_index('d', 'glove') == 0 17 | 18 | # Test extend a vocabulary from a pretained vocabulary, 19 | # and intersect with another vocabulary. 
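# (Intersection means that tokens of 'w2v' which are missing from 'glove', e.g. 'd', fall back to the UNK index 0; the assertions below check exactly this.)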
20 | pretrained_vocabs = {'w2v': ['b', 'c', 'd']} 21 | vocab.extend_from_pretrained_vocab(pretrained_vocabs, {'w2v': 'glove'}) 22 | assert vocab.get_token_index('b', 'w2v') == 2 23 | assert vocab.get_token_index('d', 'w2v') == 0 24 | assert vocab.get_token_from_index(2, 'w2v') == 'b' 25 | with pytest.raises(RuntimeError) as excinfo: 26 | vocab.get_token_from_index(4, 'w2v') 27 | assert excinfo.type == RuntimeError 28 | 29 | # Test extend a vocabulary from a no oov pretained vocabulary 30 | pretrained_vocabs = {'glove_nounk': ['a', 'b', 'c']} 31 | vocab.extend_from_pretrained_vocab( 32 | pretrained_vocabs, no_unk_namespace={'glove_nounk', }) 33 | assert vocab.get_token_index('a', 'glove_nounk') == 1 34 | assert vocab.get_token_index('c', 'glove_nounk') == 3 35 | with pytest.raises(RuntimeError) as excinfo: 36 | vocab.get_token_index('d', 'glove_nounk') 37 | assert excinfo.type == RuntimeError 38 | 39 | # Test extend a vocabulary from a no oov and pad pretained vocabulary 40 | pretrained_vocabs = {'glove_nounk_nopad': ['a', 'b', 'c']} 41 | vocab.extend_from_pretrained_vocab( 42 | pretrained_vocabs, 43 | no_unk_namespace={'glove_nounk_nopad', }, 44 | no_pad_namespace={"glove_nounk_nopad"}) 45 | assert vocab.get_token_index('a', 'glove_nounk_nopad') == 0 46 | assert vocab.get_token_index('c', 'glove_nounk_nopad') == 2 47 | with pytest.raises(RuntimeError) as excinfo: 48 | vocab.get_token_index('d', 'glove_nounk_nopad') 49 | assert excinfo.type == RuntimeError 50 | 51 | def test_extend_from_counter(self): 52 | vocab = Vocabulary() 53 | 54 | # Test extend a vocabulary from a simple counter 55 | counter = {'w': Counter(["This", "is", "a", "test", "sentence", '.'])} 56 | vocab.extend_from_counter(counter) 57 | assert vocab.get_token_index('a', 'w') == 4 58 | assert vocab.get_token_index('.', 'w') == 7 59 | assert vocab.get_token_index('That', 'w') == 0 60 | 61 | # Test extend a vocabulary from a counter with min_count 62 | counter = {'w_m': Counter(['This', 'is', 'is'])} 63 | min_count = {'w_m': 2} 64 | vocab.extend_from_counter(counter, min_count) 65 | assert vocab.get_token_index('is', 'w_m') == 2 66 | assert vocab.get_token_index('This', 'w_m') == 0 67 | assert vocab.get_token_index('That', 'w_m') == 0 68 | 69 | # Test extend a vocabulary from a counter without oov token 70 | counter = {'w_nounk': Counter(['This', 'is'])} 71 | vocab.extend_from_counter(counter, no_unk_namespace={'w_nounk', }) 72 | with pytest.raises(RuntimeError) as excinfo: 73 | vocab.get_token_index('That', 'w_nounk') 74 | assert excinfo.type == RuntimeError 75 | assert vocab.get_token_index('This', 'w_nounk') == 1 76 | 77 | # Test extend a vocabulary from a counter without pad & unk token 78 | counter = {'w_nounk_nopad': Counter(['This', 'is', 'a'])} 79 | vocab.extend_from_counter( 80 | counter, 81 | no_unk_namespace={'w_nounk_nopad'}, 82 | no_pad_namespace={'w_nounk_nopad'}) 83 | with pytest.raises(RuntimeError) as excinfo: 84 | vocab.get_token_index('That', 'w_nounk_nopad') 85 | assert excinfo.type == RuntimeError 86 | assert vocab.get_token_index('This', 'w_nounk_nopad') == 0 87 | 88 | def test_vocabulary(self): 89 | pretrained_vocabs = { 90 | 'glove': ['a', 'b', 'c'], 91 | 'w2v': ['b', 'c', 'd'], 92 | 'glove_nounk': ['a', 'b', 'c'], 93 | 'glove_nounk_nopad': ['a', 'b', 'c']} 94 | 95 | counters = { 96 | 'w': Counter(["This", "is", "a", "test", "sentence", '.']), 97 | 'w_m': Counter(['This', 'is', 'is']), 98 | 'w_nounk': Counter(['This', 'is']), 99 | 'w_nounk_nopad': Counter(['This', 'is', 'a'])} 100 | 101 | vocab = 
Vocabulary( 102 | counters=counters, 103 | min_count={'w_m': 2}, 104 | pretrained_vocab=pretrained_vocabs, 105 | intersection_vocab={'w2v': 'glove'}, 106 | no_pad_namespace={'glove_nounk_nopad', 'w_nounk_nopad'}, 107 | no_unk_namespace={ 108 | 'glove_nounk', 'w_nounk', 'glove_nounk_nopad', 'w_nounk_nopad'}) 109 | 110 | # Test glove 111 | print(vocab.get_vocab_size('glove')) 112 | assert vocab.get_token_index('a', 'glove') == 2 113 | assert vocab.get_token_index('c', 'glove') == 4 114 | assert vocab.get_token_index('d', 'glove') == 0 115 | 116 | # Test w2v 117 | assert vocab.get_token_index('b', 'w2v') == 2 118 | assert vocab.get_token_index('d', 'w2v') == 0 119 | assert vocab.get_token_from_index(2, 'w2v') == 'b' 120 | with pytest.raises(RuntimeError) as excinfo: 121 | vocab.get_token_from_index(4, 'w2v') 122 | assert excinfo.type == RuntimeError 123 | 124 | # Test glove_nounk 125 | assert vocab.get_token_index('a', 'glove_nounk') == 1 126 | assert vocab.get_token_index('c', 'glove_nounk') == 3 127 | with pytest.raises(RuntimeError) as excinfo: 128 | vocab.get_token_index('d', 'glove_nounk') 129 | assert excinfo.type == RuntimeError 130 | 131 | # Test glove_nounk_nopad 132 | assert vocab.get_token_index('a', 'glove_nounk_nopad') == 0 133 | assert vocab.get_token_index('c', 'glove_nounk_nopad') == 2 134 | with pytest.raises(RuntimeError) as excinfo: 135 | vocab.get_token_index('d', 'glove_nounk_nopad') 136 | assert excinfo.type == RuntimeError 137 | 138 | # Test w 139 | assert vocab.get_token_index('a', 'w') == 4 140 | assert vocab.get_token_index('.', 'w') == 7 141 | assert vocab.get_token_index('That', 'w') == 0 142 | 143 | # Test w_m 144 | assert vocab.get_token_index('is', 'w_m') == 2 145 | assert vocab.get_token_index('This', 'w_m') == 0 146 | assert vocab.get_token_index('That', 'w_m') == 0 147 | 148 | # Test w_nounk 149 | with pytest.raises(RuntimeError) as excinfo: 150 | vocab.get_token_index('That', 'w_nounk') 151 | assert excinfo.type == RuntimeError 152 | assert vocab.get_token_index('This', 'w_nounk') == 1 153 | 154 | # Test w_nounk_nopad 155 | with pytest.raises(RuntimeError) as excinfo: 156 | vocab.get_token_index('That', 'w_nounk_nopad') 157 | assert excinfo.type == RuntimeError 158 | assert vocab.get_token_index('This', 'w_nounk_nopad') == 0 159 | -------------------------------------------------------------------------------- /test/nn/dynet/modules/linear_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import math 3 | import dynet as dy 4 | import numpy as np 5 | from antu.nn.dynet import Linear 6 | from antu.nn.dynet.init import OrthogonalInitializer 7 | 8 | 9 | class TestLinear: 10 | 11 | def test_linear(self): 12 | pc = dy.ParameterCollection() 13 | init = OrthogonalInitializer 14 | affine = Linear(pc, in_dim=10, out_dim=5, bias=True, init=init) 15 | x = dy.random_normal((10,)) 16 | y = affine(x) 17 | assert y.dim() == ((5,), 1) 18 | 19 | init = dy.ConstInitializer(1) 20 | affine = Linear(pc, in_dim=10, out_dim=1, bias=False, init=init) 21 | x = dy.random_normal((10,)) 22 | y = affine(x) 23 | assert math.fabs(np.sum(y.npvalue())-np.sum(x.npvalue())) < 1e-6 24 | -------------------------------------------------------------------------------- /test/nn/dynet/modules/perceptron_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import dynet as dy 3 | from antu.nn.dynet import MLP 4 | from antu.nn.dynet.init import OrthogonalInitializer 5 | 6 | 7 | 
class TestPerceptron: 8 | 9 | def test_perceptron(self): 10 | pc = dy.ParameterCollection() 11 | init = OrthogonalInitializer 12 | mlp = MLP(pc, [10, 8, 5], init=init) 13 | x = dy.random_normal((10,)) 14 | y = mlp(x, True) 15 | assert y.dim() == ((5,), 1) 16 | 17 | mlp_batch = MLP(pc, [10, 8, 5], p=0.5, init=init) 18 | x = dy.random_normal((10,), batch_size=5) 19 | y = mlp_batch(x, True) 20 | assert y.dim() == ((5,), 5) 21 | -------------------------------------------------------------------------------- /test/nn/dynet/seq2seq_encoders/rnn_builder_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import math 3 | import dynet as dy 4 | import numpy as np 5 | from antu.nn.dynet.seq2seq_encoders import DeepBiRNNBuilder 6 | from antu.nn.dynet.seq2seq_encoders import orthonormal_VanillaLSTMBuilder 7 | 8 | 9 | class TestDeepBiRNNBuilder: 10 | 11 | def test_DeepBiRNNBuilder(self): 12 | pc = dy.ParameterCollection() 13 | 14 | ENC = DeepBiRNNBuilder(pc, 2, 50, 20, orthonormal_VanillaLSTMBuilder) 15 | x = [dy.random_normal((50,)) for _ in range(10)] 16 | y = ENC(x, p_x=0.33, p_h=0.33, train=True) 17 | assert len(y) == 10 18 | assert y[0].dim() == ((40, ), 1) 19 | 20 | --------------------------------------------------------------------------------
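A minimal, untested sketch of calling the padding helper in antu/utils/padding_function.py directly; the batch layout (each instance maps a field name either to a plain index list or to a per-namespace dict of index lists) and the two-step vocabulary construction are assumptions inferred from the code and tests above:

from collections import Counter
from antu.io import Vocabulary
from antu.utils.padding_function import shadow_padding

# Two indexed instances of different lengths: 'head' is a plain index list,
# 'word' is indexed under two vocabulary namespaces.
batch = [
    {'word': {'word': [2, 3, 4], 'glove': [2, 3, 4]}, 'head': [0, 1, 1]},
    {'word': {'word': [2, 4], 'glove': [2, 4]}, 'head': [0, 1]},
]

vocab = Vocabulary()
vocab.extend_from_pretrained_vocab({'glove': ['a', 'b', 'c']})
vocab.extend_from_counter({'word': Counter(['a', 'b', 'c'])})

inputs, masks = shadow_padding(batch, vocab)
# Every sequence is padded to the longest length in the batch (3 here):
#   inputs['head']      -> [[0, 1, 1], [0, 1, 0]]   (plain fields are padded with 0)
#   masks['head']['1D'] -> [[1, 1, 1], [1, 1, 0]]
# Dict fields are padded with the padding index of their own namespace, and
# '2D' and 'flat' variants of every mask are filled in as well.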