├── .github └── images │ ├── arbitrary-sentences.png │ ├── definition-sentences.png │ ├── hyperparameters.png │ └── overview.png ├── .gitignore ├── README.md ├── defsent ├── __init__.py ├── model.py └── pooling.py ├── examples ├── .gitignore ├── poetry.lock ├── pyproject.toml └── src │ └── demo_def2word.py ├── experiments ├── .gitignore ├── README.md ├── configs │ ├── config.yaml │ ├── data_module │ │ └── oxford.yaml │ ├── logger │ │ └── mlflow.yaml │ ├── model │ │ └── default.yaml │ ├── optimizer │ │ ├── adadelta.yaml │ │ ├── adagrad.yaml │ │ ├── adam.yaml │ │ ├── adamax.yaml │ │ ├── adamw.yaml │ │ ├── asgd.yaml │ │ ├── lbfgs.yaml │ │ ├── rmsprop.yaml │ │ ├── rprop.yaml │ │ ├── sgd.yaml │ │ └── sparse_adam.yaml │ ├── scheduler │ │ ├── cosine_annealing.yaml │ │ ├── cosine_annealing_warm_restarts.yaml │ │ ├── cyclic.yaml │ │ ├── exponential.yaml │ │ ├── lambda.yaml │ │ ├── multi_step.yaml │ │ ├── multipricative.yaml │ │ ├── oncyclic.yaml │ │ ├── plateau.yaml │ │ ├── step.yaml │ │ └── warmup.yaml │ ├── tokenizer │ │ └── default.yaml │ └── trainer │ │ └── default.yaml ├── main.py ├── poetry.lock ├── pyproject.toml ├── scripts │ ├── download-dataset.sh │ ├── run-base.sh │ ├── run-bert-base0.sh │ ├── run-bert-base1.sh │ ├── run-bert-large0.sh │ ├── run-bert-large1.sh │ ├── run-large.sh │ ├── run-roberta-base0.sh │ ├── run-roberta-base1.sh │ ├── run-roberta-large0.sh │ └── run-roberta-large1.sh └── src │ ├── data_module.py │ ├── dataset.py │ ├── evaluation │ ├── __init__.py │ ├── def2word.py │ ├── senteval.py │ └── sts.py │ ├── experiment.py │ ├── lr_scheduler.py │ ├── model.py │ ├── pooling.py │ ├── scripts │ └── extract_data_from_ishiwatari.py │ └── utils.py ├── poetry.lock └── pyproject.toml /.github/images/arbitrary-sentences.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hppRC/defsent/d488c2dd374a934613ec8bfe68cdc1ede95b900d/.github/images/arbitrary-sentences.png -------------------------------------------------------------------------------- /.github/images/definition-sentences.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hppRC/defsent/d488c2dd374a934613ec8bfe68cdc1ede95b900d/.github/images/definition-sentences.png -------------------------------------------------------------------------------- /.github/images/hyperparameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hppRC/defsent/d488c2dd374a934613ec8bfe68cdc1ede95b900d/.github/images/hyperparameters.png -------------------------------------------------------------------------------- /.github/images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hppRC/defsent/d488c2dd374a934613ec8bfe68cdc1ede95b900d/.github/images/overview.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually 
these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DefSent: Sentence Embeddings using Definition Sentences 2 | 3 | This repository contains the experiment code, pre-trained models, and examples for our paper [DefSent: Sentence Embeddings using Definition Sentences](https://aclanthology.org/2021.acl-short.52/). 4 | 5 | ACL Anthology: https://aclanthology.org/2021.acl-short.52/ 6 | 7 | ## Overview 8 | 9 | 10 | 11 | ## Getting started 12 | 13 | ### Install from PyPI 14 | 15 | ``` 16 | pip install defsent 17 | ``` 18 | 19 | ### Encode sentences into `torch.Tensor` 20 | 21 | 22 | ```python 23 | from defsent import DefSent 24 | 25 | model = DefSent("cl-nagoya/defsent-bert-base-uncased-cls") 26 | embeddings = model.encode([ 27 | "A woman is playing the guitar.", 28 | "A man is playing guitar.", 29 | ]) 30 | ``` 31 | 32 | ### Predict words from input sentences 33 | 34 | ```python 35 | from defsent import DefSent 36 | 37 | model = DefSent("cl-nagoya/defsent-bert-base-uncased-cls") 38 | predicted_words = model.predict_words([ 39 | "be expensive for (someone)", 40 | "an open-source operating system modelled on unix", 41 | "not bad", 42 | ]) 43
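# predict_words returns a List[List[str]]: for each input sentence, the top-k
# (k=10 by default) vocabulary tokens predicted from its sentence embedding.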
| ``` 44 | 45 | Example results for definition sentences. 46 | 47 | ![](./.github/images/definition-sentences.png) 48 | 49 | Example results for sentences other than definition sentences. 50 | 51 | ![](./.github/images/arbitrary-sentences.png) 52 | 53 | 54 | ## Pretrained checkpoints 55 | 56 | Search: https://huggingface.co/models?search=defsent 57 | 58 | | checkpoint | STS12 | STS13 | STS14 | STS15 | STS16 | STS-B | SICK-R | Avg. | 59 | |--|--|--|--|--|--|--|--|--| 60 | |[defsent-bert-base-uncased-cls](https://huggingface.co/cl-nagoya/defsent-bert-base-uncased-cls)|67.61|80.44|70.12|77.5|76.34|75.25|71.71|74.14| 61 | |[defsent-bert-base-uncased-mean](https://huggingface.co/cl-nagoya/defsent-bert-base-uncased-mean)|68.24|82.62|72.8|78.44|76.79|77.5|71.69|75.44| 62 | |[defsent-bert-base-uncased-max](https://huggingface.co/cl-nagoya/defsent-bert-base-uncased-max)|65.32|82.00|73.00|77.38|75.84|76.74|71.67|74.57| 63 | |[defsent-bert-large-uncased-cls](https://huggingface.co/cl-nagoya/defsent-bert-large-uncased-cls)|67.03|82.41|71.25|80.33|75.43|73.83|73.34|74.8| 64 | |[defsent-bert-large-uncased-mean](https://huggingface.co/cl-nagoya/defsent-bert-large-uncased-mean)|63.93|82.43|73.29|80.52|77.84|78.41|73.39|75.69| 65 | |[defsent-bert-large-uncased-max](https://huggingface.co/cl-nagoya/defsent-bert-large-uncased-max)|60.15|80.70|71.67|77.19|75.71|76.90|72.57|73.55| 66 | |[defsent-roberta-base-cls](https://huggingface.co/cl-nagoya/defsent-roberta-base-cls)|66.13|80.96|72.59|78.33|78.85|78.51|74.44|75.69| 67 | |[defsent-roberta-base-mean](https://huggingface.co/cl-nagoya/defsent-roberta-base-mean)|62.38|78.42|70.79|74.60|77.32|77.38|73.07|73.42| 68 | |[defsent-roberta-base-max](https://huggingface.co/cl-nagoya/defsent-roberta-base-max)|64.61|78.76|70.24|76.07|79.02|78.34|74.54|74.51| 69 | |[defsent-roberta-large-cls](https://huggingface.co/cl-nagoya/defsent-roberta-large-cls)|62.47|79.07|69.87|72.62|77.87|79.11|73.95|73.56| 70 | |[defsent-roberta-large-mean](https://huggingface.co/cl-nagoya/defsent-roberta-large-mean)|57.8|72.98|69.18|72.84|76.50|79.17|74.36|71.83| 71 | |[defsent-roberta-large-max](https://huggingface.co/cl-nagoya/defsent-roberta-large-max)|64.11|81.42|72.52|75.37|80.23|79.16|73.76|75.22| 72 | 73 | ### Hyperparameters for each checkpoint and fine-tuning task performance 74 | 75 | ![](./.github/images/hyperparameters.png) 76 | 77 | 78 | ## Citation 79 | 80 | ```bibtex 81 | @inproceedings{tsukagoshi-etal-2021-defsent, 82 | title = "{D}ef{S}ent: Sentence Embeddings using Definition Sentences", 83 | author = "Tsukagoshi, Hayato and 84 | Sasano, Ryohei and 85 | Takeda, Koichi", 86 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 2: Short Papers)", 87 | month = aug, 88 | year = "2021", 89 | address = "Online", 90 | publisher = "Association for Computational Linguistics", 91 | url = "https://aclanthology.org/2021.acl-short.52", 92 | doi = "10.18653/v1/2021.acl-short.52", 93 | pages = "411--418", 94 | } 95 | ``` -------------------------------------------------------------------------------- /defsent/__init__.py: -------------------------------------------------------------------------------- 1 | from defsent.model import DefSent -------------------------------------------------------------------------------- /defsent/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 |
from torch import Tensor 4 | from typing import List, Tuple, Union 5 | from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer 6 | from defsent.pooling import Pooling 7 | 8 | 9 | class DefSent(nn.Module): 10 | def __init__( 11 | self, 12 | model_name_or_path: str, 13 | device: torch.device = None, 14 | ) -> None: 15 | super(DefSent, self).__init__() 16 | 17 | self.model_name_or_path = model_name_or_path 18 | self.pooling_name = model_name_or_path.rsplit("-", 1)[-1] 19 | 20 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 21 | self.encoder, self.prediction_layer = pretrained_modules(model_name_or_path) 22 | self.pooling = Pooling(pooling_name=self.pooling_name) 23 | 24 | if device is None: 25 | self.device = torch.device("cpu") 26 | else: 27 | self.device = device 28 | self.to(self.device) 29 | 30 | def to(self, device: torch.device) -> None: 31 | self.encoder = self.encoder.to(device) 32 | self.prediction_layer = self.prediction_layer.to(device) 33 | 34 | def forward(self, input_ids: Tensor, attention_mask: Tensor = None) -> Tensor: 35 | embs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state 36 | emb = self.pooling(embs, attention_mask=attention_mask) 37 | return emb 38 | 39 | def calc_word_logits(self, input_ids: Tensor, attention_mask: Tensor = None) -> Tensor: 40 | emb = self(input_ids, attention_mask=attention_mask) 41 | logits = self.prediction_layer(emb) 42 | return logits 43 | 44 | @torch.no_grad() 45 | def encode( 46 | self, 47 | sentences: Union[str, List[str]], 48 | batch_size: int = 16, 49 | ) -> Tensor: 50 | if isinstance(sentences, str): 51 | sentences = [sentences] 52 | 53 | inputs = self.tokenizer( 54 | sentences, 55 | padding=True, 56 | return_tensors="pt", 57 | truncation=True, 58 | ) 59 | data_loader = torch.utils.data.DataLoader( 60 | list(zip(inputs.input_ids, inputs.attention_mask)), 61 | batch_size=batch_size, 62 | ) 63 | all_embs = [] 64 | for input_ids, attention_mask in data_loader: 65 | input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device) 66 | embs = self.forward(input_ids, attention_mask=attention_mask) 67 | # Prevent overuse of memory. 68 | embs = embs.cpu() 69 | all_embs.append(embs) 70 | 71 | embeddings = torch.cat(all_embs, dim=0) 72 | return embeddings 73 | 74 | @torch.no_grad() 75 | def predict_words( 76 | self, 77 | sentences: Union[str, List[str]], 78 | topk: int = 10, 79 | batch_size: int = 16, 80 | ) -> List[List[str]]: 81 | embs = self.encode( 82 | sentences=sentences, 83 | batch_size=batch_size, 84 | ) 85 | logits: Tensor = self.prediction_layer(embs.to(self.device)).cpu() 86 | hypothesis = logits.topk(topk, dim=1).indices 87 | words = [self.tokenizer.convert_ids_to_tokens(hyp_ids) for hyp_ids in hypothesis] 88 | return words 89 | 90 | 91 | def pretrained_modules(model_name_or_path: str) -> Tuple[nn.Module, nn.Module]: 92 | config = AutoConfig.from_pretrained(model_name_or_path) 93 | 94 | if "BertForMaskedLM" in config.architectures: 95 | pretrained_model = AutoModelForMaskedLM.from_pretrained(model_name_or_path) 96 | encoder = pretrained_model.bert 97 | prediction_layer = pretrained_model.cls 98 | 99 | elif "RobertaForMaskedLM" in config.architectures: 100 | pretrained_model = AutoModelForMaskedLM.from_pretrained(model_name_or_path) 101 | encoder = pretrained_model.roberta 102 | prediction_layer = pretrained_model.lm_head 103 | 104 | else: 105 | raise ValueError(f"No such a pre-trained model! 
> {model_name_or_path}") 106 | 107 | return encoder, prediction_layer -------------------------------------------------------------------------------- /defsent/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | 6 | class Pooling(nn.Module): 7 | def __init__(self, pooling_name: str) -> None: 8 | super().__init__() 9 | self.pooling_name = pooling_name 10 | 11 | def forward(self, x: Tensor, attention_mask: Tensor) -> Tensor: 12 | if self.pooling_name == "cls": 13 | return x[:, 0] 14 | 15 | if self.pooling_name == "sep": 16 | # masked tokens are marked as `0` 17 | sent_len = attention_mask.sum(dim=1, keepdim=True) 18 | batch_size = x.size(0) 19 | batch_indices = torch.LongTensor(range(batch_size)) 20 | sep_indices = (sent_len.long() - 1).squeeze() 21 | return x[batch_indices, sep_indices] 22 | 23 | mask_value = 0 if self.pooling_name in ["mean", "sum"] else -1e6 24 | x[attention_mask.long() == 0, :] = mask_value 25 | 26 | if self.pooling_name == "mean": 27 | sent_len = attention_mask.sum(dim=1, keepdim=True) 28 | return x.sum(dim=1) / sent_len 29 | 30 | elif self.pooling_name == "max": 31 | return x.max(dim=1).values 32 | 33 | elif self.pooling_name == "sum": 34 | return x.sum(dim=1) 35 | 36 | else: 37 | raise ValueError(f"No such a pooling name! {self.pooling_name}") 38 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /examples/poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "certifi" 3 | version = "2021.5.30" 4 | description = "Python package for providing Mozilla's CA Bundle." 5 | category = "main" 6 | optional = false 7 | python-versions = "*" 8 | 9 | [[package]] 10 | name = "charset-normalizer" 11 | version = "2.0.4" 12 | description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 13 | category = "main" 14 | optional = false 15 | python-versions = ">=3.5.0" 16 | 17 | [package.extras] 18 | unicode_backport = ["unicodedata2"] 19 | 20 | [[package]] 21 | name = "click" 22 | version = "8.0.1" 23 | description = "Composable command line interface toolkit" 24 | category = "main" 25 | optional = false 26 | python-versions = ">=3.6" 27 | 28 | [package.dependencies] 29 | colorama = {version = "*", markers = "platform_system == \"Windows\""} 30 | importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} 31 | 32 | [[package]] 33 | name = "colorama" 34 | version = "0.4.4" 35 | description = "Cross-platform colored terminal text." 36 | category = "main" 37 | optional = false 38 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 39 | 40 | [[package]] 41 | name = "defsent" 42 | version = "0.1.0" 43 | description = "DefSent: Sentence Embeddings using Definition Sentences" 44 | category = "main" 45 | optional = false 46 | python-versions = "^3.7" 47 | develop = false 48 | 49 | [package.dependencies] 50 | torch = "*" 51 | transformers = "*" 52 | 53 | [package.source] 54 | type = "directory" 55 | url = ".." 56 | 57 | [[package]] 58 | name = "filelock" 59 | version = "3.0.12" 60 | description = "A platform independent file lock." 
61 | category = "main" 62 | optional = false 63 | python-versions = "*" 64 | 65 | [[package]] 66 | name = "huggingface-hub" 67 | version = "0.0.12" 68 | description = "Client library to download and publish models on the huggingface.co hub" 69 | category = "main" 70 | optional = false 71 | python-versions = ">=3.6.0" 72 | 73 | [package.dependencies] 74 | filelock = "*" 75 | importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} 76 | packaging = ">=20.9" 77 | requests = "*" 78 | tqdm = "*" 79 | typing-extensions = "*" 80 | 81 | [package.extras] 82 | all = ["pytest", "black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 83 | dev = ["pytest", "black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 84 | quality = ["black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 85 | testing = ["pytest"] 86 | torch = ["torch"] 87 | 88 | [[package]] 89 | name = "idna" 90 | version = "3.2" 91 | description = "Internationalized Domain Names in Applications (IDNA)" 92 | category = "main" 93 | optional = false 94 | python-versions = ">=3.5" 95 | 96 | [[package]] 97 | name = "importlib-metadata" 98 | version = "4.6.3" 99 | description = "Read metadata from Python packages" 100 | category = "main" 101 | optional = false 102 | python-versions = ">=3.6" 103 | 104 | [package.dependencies] 105 | typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} 106 | zipp = ">=0.5" 107 | 108 | [package.extras] 109 | docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] 110 | perf = ["ipython"] 111 | testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"] 112 | 113 | [[package]] 114 | name = "joblib" 115 | version = "1.0.1" 116 | description = "Lightweight pipelining with Python functions" 117 | category = "main" 118 | optional = false 119 | python-versions = ">=3.6" 120 | 121 | [[package]] 122 | name = "numpy" 123 | version = "1.21.1" 124 | description = "NumPy is the fundamental package for array computing with Python." 125 | category = "main" 126 | optional = false 127 | python-versions = ">=3.7" 128 | 129 | [[package]] 130 | name = "packaging" 131 | version = "21.0" 132 | description = "Core utilities for Python packages" 133 | category = "main" 134 | optional = false 135 | python-versions = ">=3.6" 136 | 137 | [package.dependencies] 138 | pyparsing = ">=2.0.2" 139 | 140 | [[package]] 141 | name = "pyparsing" 142 | version = "2.4.7" 143 | description = "Python parsing module" 144 | category = "main" 145 | optional = false 146 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 147 | 148 | [[package]] 149 | name = "pyyaml" 150 | version = "5.4.1" 151 | description = "YAML parser and emitter for Python" 152 | category = "main" 153 | optional = false 154 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" 155 | 156 | [[package]] 157 | name = "regex" 158 | version = "2021.7.6" 159 | description = "Alternative regular expression module, to replace re." 160 | category = "main" 161 | optional = false 162 | python-versions = "*" 163 | 164 | [[package]] 165 | name = "requests" 166 | version = "2.26.0" 167 | description = "Python HTTP for Humans." 
168 | category = "main" 169 | optional = false 170 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" 171 | 172 | [package.dependencies] 173 | certifi = ">=2017.4.17" 174 | charset-normalizer = {version = ">=2.0.0,<2.1.0", markers = "python_version >= \"3\""} 175 | idna = {version = ">=2.5,<4", markers = "python_version >= \"3\""} 176 | urllib3 = ">=1.21.1,<1.27" 177 | 178 | [package.extras] 179 | socks = ["PySocks (>=1.5.6,!=1.5.7)", "win-inet-pton"] 180 | use_chardet_on_py3 = ["chardet (>=3.0.2,<5)"] 181 | 182 | [[package]] 183 | name = "sacremoses" 184 | version = "0.0.45" 185 | description = "SacreMoses" 186 | category = "main" 187 | optional = false 188 | python-versions = "*" 189 | 190 | [package.dependencies] 191 | click = "*" 192 | joblib = "*" 193 | regex = "*" 194 | six = "*" 195 | tqdm = "*" 196 | 197 | [[package]] 198 | name = "six" 199 | version = "1.16.0" 200 | description = "Python 2 and 3 compatibility utilities" 201 | category = "main" 202 | optional = false 203 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" 204 | 205 | [[package]] 206 | name = "tokenizers" 207 | version = "0.10.3" 208 | description = "Fast and Customizable Tokenizers" 209 | category = "main" 210 | optional = false 211 | python-versions = "*" 212 | 213 | [package.extras] 214 | testing = ["pytest"] 215 | 216 | [[package]] 217 | name = "torch" 218 | version = "1.9.0" 219 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" 220 | category = "main" 221 | optional = false 222 | python-versions = ">=3.6.2" 223 | 224 | [package.dependencies] 225 | typing-extensions = "*" 226 | 227 | [[package]] 228 | name = "tqdm" 229 | version = "4.62.0" 230 | description = "Fast, Extensible Progress Meter" 231 | category = "main" 232 | optional = false 233 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" 234 | 235 | [package.dependencies] 236 | colorama = {version = "*", markers = "platform_system == \"Windows\""} 237 | 238 | [package.extras] 239 | dev = ["py-make (>=0.1.0)", "twine", "wheel"] 240 | notebook = ["ipywidgets (>=6)"] 241 | telegram = ["requests"] 242 | 243 | [[package]] 244 | name = "transformers" 245 | version = "4.9.1" 246 | description = "State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch" 247 | category = "main" 248 | optional = false 249 | python-versions = ">=3.6.0" 250 | 251 | [package.dependencies] 252 | filelock = "*" 253 | huggingface-hub = "0.0.12" 254 | importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} 255 | numpy = ">=1.17" 256 | packaging = "*" 257 | pyyaml = ">=5.1" 258 | regex = "!=2019.12.17" 259 | requests = "*" 260 | sacremoses = "*" 261 | tokenizers = ">=0.10.1,<0.11" 262 | tqdm = ">=4.27" 263 | 264 | [package.extras] 265 | all = ["tensorflow (>=2.3)", "onnxconverter-common", "keras2onnx", "torch (>=1.0)", "jax (>=0.2.8)", "jaxlib (>=0.1.65)", "flax (>=0.3.4)", "optax (>=0.0.8)", "sentencepiece (==0.1.91)", "protobuf", "tokenizers (>=0.10.1,<0.11)", "soundfile", "torchaudio", "pillow", "optuna", "ray", "timm", "codecarbon (==1.2.0)"] 266 | codecarbon = ["codecarbon (==1.2.0)"] 267 | deepspeed = ["deepspeed (>=0.4.3)"] 268 | dev = ["tensorflow (>=2.3)", "onnxconverter-common", "keras2onnx", "torch (>=1.0)", "jax (>=0.2.8)", "jaxlib (>=0.1.65)", "flax (>=0.3.4)", "optax (>=0.0.8)", "sentencepiece (==0.1.91)", "protobuf", "tokenizers (>=0.10.1,<0.11)", "soundfile", "torchaudio", "pillow", "optuna", "ray", "timm", "codecarbon (==1.2.0)", "pytest", 
"pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (==21.4b0)", "sacrebleu (>=1.4.12)", "rouge-score", "nltk", "gitpython", "faiss-cpu", "cookiecutter (==1.7.2)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "docutils (==0.16.0)", "recommonmark", "sphinx (==3.2.1)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinx-copybutton", "sphinxext-opengraph (==0.4.1)", "scikit-learn"] 269 | docs = ["tensorflow (>=2.3)", "onnxconverter-common", "keras2onnx", "torch (>=1.0)", "jax (>=0.2.8)", "jaxlib (>=0.1.65)", "flax (>=0.3.4)", "optax (>=0.0.8)", "sentencepiece (==0.1.91)", "protobuf", "tokenizers (>=0.10.1,<0.11)", "soundfile", "torchaudio", "pillow", "optuna", "ray", "timm", "codecarbon (==1.2.0)", "docutils (==0.16.0)", "recommonmark", "sphinx (==3.2.1)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinx-copybutton", "sphinxext-opengraph (==0.4.1)"] 270 | docs_specific = ["docutils (==0.16.0)", "recommonmark", "sphinx (==3.2.1)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinx-copybutton", "sphinxext-opengraph (==0.4.1)"] 271 | fairscale = ["fairscale (>0.3)"] 272 | flax = ["jax (>=0.2.8)", "jaxlib (>=0.1.65)", "flax (>=0.3.4)", "optax (>=0.0.8)"] 273 | integrations = ["optuna", "ray"] 274 | ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)"] 275 | modelcreation = ["cookiecutter (==1.7.2)"] 276 | onnx = ["onnxconverter-common", "keras2onnx", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] 277 | onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] 278 | optuna = ["optuna"] 279 | quality = ["black (==21.4b0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 280 | ray = ["ray"] 281 | retrieval = ["faiss-cpu", "datasets"] 282 | sagemaker = ["sagemaker (>=2.31.0)"] 283 | sentencepiece = ["sentencepiece (==0.1.91)", "protobuf"] 284 | serving = ["pydantic", "uvicorn", "fastapi", "starlette"] 285 | sklearn = ["scikit-learn"] 286 | speech = ["soundfile", "torchaudio"] 287 | testing = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (==21.4b0)", "sacrebleu (>=1.4.12)", "rouge-score", "nltk", "gitpython", "faiss-cpu", "cookiecutter (==1.7.2)"] 288 | tf = ["tensorflow (>=2.3)", "onnxconverter-common", "keras2onnx"] 289 | tf-cpu = ["tensorflow-cpu (>=2.3)", "onnxconverter-common", "keras2onnx"] 290 | timm = ["timm"] 291 | tokenizers = ["tokenizers (>=0.10.1,<0.11)"] 292 | torch = ["torch (>=1.0)"] 293 | torchhub = ["filelock", "huggingface-hub (==0.0.12)", "importlib-metadata", "numpy (>=1.17)", "packaging", "protobuf", "regex (!=2019.12.17)", "requests", "sacremoses", "sentencepiece (==0.1.91)", "torch (>=1.0)", "tokenizers (>=0.10.1,<0.11)", "tqdm (>=4.27)"] 294 | vision = ["pillow"] 295 | 296 | [[package]] 297 | name = "typing-extensions" 298 | version = "3.10.0.0" 299 | description = "Backported and Experimental Type Hints for Python 3.5+" 300 | category = "main" 301 | optional = false 302 | python-versions = "*" 303 | 304 | [[package]] 305 | name = "urllib3" 306 | version = "1.26.6" 307 | description = "HTTP library with thread-safe connection pooling, file post, and more." 
308 | category = "main" 309 | optional = false 310 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" 311 | 312 | [package.extras] 313 | brotli = ["brotlipy (>=0.6.0)"] 314 | secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] 315 | socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] 316 | 317 | [[package]] 318 | name = "zipp" 319 | version = "3.5.0" 320 | description = "Backport of pathlib-compatible object wrapper for zip files" 321 | category = "main" 322 | optional = false 323 | python-versions = ">=3.6" 324 | 325 | [package.extras] 326 | docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] 327 | testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"] 328 | 329 | [metadata] 330 | lock-version = "1.1" 331 | python-versions = "^3.7" 332 | content-hash = "9bd4696df2d264e7022ef58ad426df690651b6a9ee53238cb36fb47d9d1f2694" 333 | 334 | [metadata.files] 335 | certifi = [ 336 | {file = "certifi-2021.5.30-py2.py3-none-any.whl", hash = "sha256:50b1e4f8446b06f41be7dd6338db18e0990601dce795c2b1686458aa7e8fa7d8"}, 337 | {file = "certifi-2021.5.30.tar.gz", hash = "sha256:2bbf76fd432960138b3ef6dda3dde0544f27cbf8546c458e60baf371917ba9ee"}, 338 | ] 339 | charset-normalizer = [ 340 | {file = "charset-normalizer-2.0.4.tar.gz", hash = "sha256:f23667ebe1084be45f6ae0538e4a5a865206544097e4e8bbcacf42cd02a348f3"}, 341 | {file = "charset_normalizer-2.0.4-py3-none-any.whl", hash = "sha256:0c8911edd15d19223366a194a513099a302055a962bca2cec0f54b8b63175d8b"}, 342 | ] 343 | click = [ 344 | {file = "click-8.0.1-py3-none-any.whl", hash = "sha256:fba402a4a47334742d782209a7c79bc448911afe1149d07bdabdf480b3e2f4b6"}, 345 | {file = "click-8.0.1.tar.gz", hash = "sha256:8c04c11192119b1ef78ea049e0a6f0463e4c48ef00a30160c704337586f3ad7a"}, 346 | ] 347 | colorama = [ 348 | {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, 349 | {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, 350 | ] 351 | defsent = [] 352 | filelock = [ 353 | {file = "filelock-3.0.12-py3-none-any.whl", hash = "sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836"}, 354 | {file = "filelock-3.0.12.tar.gz", hash = "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59"}, 355 | ] 356 | huggingface-hub = [ 357 | {file = "huggingface_hub-0.0.12-py3-none-any.whl", hash = "sha256:5c82ff96897a72e1ed48a94c1796686f120dea05888200522f3994f130c12e6a"}, 358 | {file = "huggingface_hub-0.0.12.tar.gz", hash = "sha256:661b17fab0c475276fd71603ee7e16c3b3d1d6e812e1b29f40144f64d361e59d"}, 359 | ] 360 | idna = [ 361 | {file = "idna-3.2-py3-none-any.whl", hash = "sha256:14475042e284991034cb48e06f6851428fb14c4dc953acd9be9a5e95c7b6dd7a"}, 362 | {file = "idna-3.2.tar.gz", hash = "sha256:467fbad99067910785144ce333826c71fb0e63a425657295239737f7ecd125f3"}, 363 | ] 364 | importlib-metadata = [ 365 | {file = "importlib_metadata-4.6.3-py3-none-any.whl", hash = "sha256:51c6635429c77cf1ae634c997ff9e53ca3438b495f10a55ba28594dd69764a8b"}, 366 | {file = "importlib_metadata-4.6.3.tar.gz", hash = "sha256:0645585859e9a6689c523927a5032f2ba5919f1f7d0e84bd4533312320de1ff9"}, 367 | ] 368 | joblib = [ 369 | {file = "joblib-1.0.1-py3-none-any.whl", hash = 
"sha256:feeb1ec69c4d45129954f1b7034954241eedfd6ba39b5e9e4b6883be3332d5e5"}, 370 | {file = "joblib-1.0.1.tar.gz", hash = "sha256:9c17567692206d2f3fb9ecf5e991084254fe631665c450b443761c4186a613f7"}, 371 | ] 372 | numpy = [ 373 | {file = "numpy-1.21.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38e8648f9449a549a7dfe8d8755a5979b45b3538520d1e735637ef28e8c2dc50"}, 374 | {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:fd7d7409fa643a91d0a05c7554dd68aa9c9bb16e186f6ccfe40d6e003156e33a"}, 375 | {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a75b4498b1e93d8b700282dc8e655b8bd559c0904b3910b144646dbbbc03e062"}, 376 | {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1412aa0aec3e00bc23fbb8664d76552b4efde98fb71f60737c83efbac24112f1"}, 377 | {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e46ceaff65609b5399163de5893d8f2a82d3c77d5e56d976c8b5fb01faa6b671"}, 378 | {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:c6a2324085dd52f96498419ba95b5777e40b6bcbc20088fddb9e8cbb58885e8e"}, 379 | {file = "numpy-1.21.1-cp37-cp37m-win32.whl", hash = "sha256:73101b2a1fef16602696d133db402a7e7586654682244344b8329cdcbbb82172"}, 380 | {file = "numpy-1.21.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7a708a79c9a9d26904d1cca8d383bf869edf6f8e7650d85dbc77b041e8c5a0f8"}, 381 | {file = "numpy-1.21.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95b995d0c413f5d0428b3f880e8fe1660ff9396dcd1f9eedbc311f37b5652e16"}, 382 | {file = "numpy-1.21.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:635e6bd31c9fb3d475c8f44a089569070d10a9ef18ed13738b03049280281267"}, 383 | {file = "numpy-1.21.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a3d5fb89bfe21be2ef47c0614b9c9c707b7362386c9a3ff1feae63e0267ccb6"}, 384 | {file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a326af80e86d0e9ce92bcc1e65c8ff88297de4fa14ee936cb2293d414c9ec63"}, 385 | {file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:791492091744b0fe390a6ce85cc1bf5149968ac7d5f0477288f78c89b385d9af"}, 386 | {file = "numpy-1.21.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0318c465786c1f63ac05d7c4dbcecd4d2d7e13f0959b01b534ea1e92202235c5"}, 387 | {file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a513bd9c1551894ee3d31369f9b07460ef223694098cf27d399513415855b68"}, 388 | {file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:91c6f5fc58df1e0a3cc0c3a717bb3308ff850abdaa6d2d802573ee2b11f674a8"}, 389 | {file = "numpy-1.21.1-cp38-cp38-win32.whl", hash = "sha256:978010b68e17150db8765355d1ccdd450f9fc916824e8c4e35ee620590e234cd"}, 390 | {file = "numpy-1.21.1-cp38-cp38-win_amd64.whl", hash = "sha256:9749a40a5b22333467f02fe11edc98f022133ee1bfa8ab99bda5e5437b831214"}, 391 | {file = "numpy-1.21.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d7a4aeac3b94af92a9373d6e77b37691b86411f9745190d2c351f410ab3a791f"}, 392 | {file = "numpy-1.21.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d9e7912a56108aba9b31df688a4c4f5cb0d9d3787386b87d504762b6754fbb1b"}, 393 | {file = "numpy-1.21.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:25b40b98ebdd272bc3020935427a4530b7d60dfbe1ab9381a39147834e985eac"}, 394 | {file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", 
hash = "sha256:8a92c5aea763d14ba9d6475803fc7904bda7decc2a0a68153f587ad82941fec1"}, 395 | {file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:05a0f648eb28bae4bcb204e6fd14603de2908de982e761a2fc78efe0f19e96e1"}, 396 | {file = "numpy-1.21.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f01f28075a92eede918b965e86e8f0ba7b7797a95aa8d35e1cc8821f5fc3ad6a"}, 397 | {file = "numpy-1.21.1-cp39-cp39-win32.whl", hash = "sha256:88c0b89ad1cc24a5efbb99ff9ab5db0f9a86e9cc50240177a571fbe9c2860ac2"}, 398 | {file = "numpy-1.21.1-cp39-cp39-win_amd64.whl", hash = "sha256:01721eefe70544d548425a07c80be8377096a54118070b8a62476866d5208e33"}, 399 | {file = "numpy-1.21.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d4d1de6e6fb3d28781c73fbde702ac97f03d79e4ffd6598b880b2d95d62ead4"}, 400 | {file = "numpy-1.21.1.zip", hash = "sha256:dff4af63638afcc57a3dfb9e4b26d434a7a602d225b42d746ea7fe2edf1342fd"}, 401 | ] 402 | packaging = [ 403 | {file = "packaging-21.0-py3-none-any.whl", hash = "sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14"}, 404 | {file = "packaging-21.0.tar.gz", hash = "sha256:7dc96269f53a4ccec5c0670940a4281106dd0bb343f47b7471f779df49c2fbe7"}, 405 | ] 406 | pyparsing = [ 407 | {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, 408 | {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, 409 | ] 410 | pyyaml = [ 411 | {file = "PyYAML-5.4.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:3b2b1824fe7112845700f815ff6a489360226a5609b96ec2190a45e62a9fc922"}, 412 | {file = "PyYAML-5.4.1-cp27-cp27m-win32.whl", hash = "sha256:129def1b7c1bf22faffd67b8f3724645203b79d8f4cc81f674654d9902cb4393"}, 413 | {file = "PyYAML-5.4.1-cp27-cp27m-win_amd64.whl", hash = "sha256:4465124ef1b18d9ace298060f4eccc64b0850899ac4ac53294547536533800c8"}, 414 | {file = "PyYAML-5.4.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:bb4191dfc9306777bc594117aee052446b3fa88737cd13b7188d0e7aa8162185"}, 415 | {file = "PyYAML-5.4.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:6c78645d400265a062508ae399b60b8c167bf003db364ecb26dcab2bda048253"}, 416 | {file = "PyYAML-5.4.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:4e0583d24c881e14342eaf4ec5fbc97f934b999a6828693a99157fde912540cc"}, 417 | {file = "PyYAML-5.4.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:72a01f726a9c7851ca9bfad6fd09ca4e090a023c00945ea05ba1638c09dc3347"}, 418 | {file = "PyYAML-5.4.1-cp36-cp36m-manylinux2014_s390x.whl", hash = "sha256:895f61ef02e8fed38159bb70f7e100e00f471eae2bc838cd0f4ebb21e28f8541"}, 419 | {file = "PyYAML-5.4.1-cp36-cp36m-win32.whl", hash = "sha256:3bd0e463264cf257d1ffd2e40223b197271046d09dadf73a0fe82b9c1fc385a5"}, 420 | {file = "PyYAML-5.4.1-cp36-cp36m-win_amd64.whl", hash = "sha256:e4fac90784481d221a8e4b1162afa7c47ed953be40d31ab4629ae917510051df"}, 421 | {file = "PyYAML-5.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5accb17103e43963b80e6f837831f38d314a0495500067cb25afab2e8d7a4018"}, 422 | {file = "PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e1d4970ea66be07ae37a3c2e48b5ec63f7ba6804bdddfdbd3cfd954d25a82e63"}, 423 | {file = "PyYAML-5.4.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cb333c16912324fd5f769fff6bc5de372e9e7a202247b48870bc251ed40239aa"}, 424 | {file = "PyYAML-5.4.1-cp37-cp37m-manylinux2014_s390x.whl", hash = 
"sha256:fe69978f3f768926cfa37b867e3843918e012cf83f680806599ddce33c2c68b0"}, 425 | {file = "PyYAML-5.4.1-cp37-cp37m-win32.whl", hash = "sha256:dd5de0646207f053eb0d6c74ae45ba98c3395a571a2891858e87df7c9b9bd51b"}, 426 | {file = "PyYAML-5.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:08682f6b72c722394747bddaf0aa62277e02557c0fd1c42cb853016a38f8dedf"}, 427 | {file = "PyYAML-5.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d2d9808ea7b4af864f35ea216be506ecec180628aced0704e34aca0b040ffe46"}, 428 | {file = "PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:8c1be557ee92a20f184922c7b6424e8ab6691788e6d86137c5d93c1a6ec1b8fb"}, 429 | {file = "PyYAML-5.4.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:fd7f6999a8070df521b6384004ef42833b9bd62cfee11a09bda1079b4b704247"}, 430 | {file = "PyYAML-5.4.1-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:bfb51918d4ff3d77c1c856a9699f8492c612cde32fd3bcd344af9be34999bfdc"}, 431 | {file = "PyYAML-5.4.1-cp38-cp38-win32.whl", hash = "sha256:fa5ae20527d8e831e8230cbffd9f8fe952815b2b7dae6ffec25318803a7528fc"}, 432 | {file = "PyYAML-5.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:0f5f5786c0e09baddcd8b4b45f20a7b5d61a7e7e99846e3c799b05c7c53fa696"}, 433 | {file = "PyYAML-5.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:294db365efa064d00b8d1ef65d8ea2c3426ac366c0c4368d930bf1c5fb497f77"}, 434 | {file = "PyYAML-5.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:74c1485f7707cf707a7aef42ef6322b8f97921bd89be2ab6317fd782c2d53183"}, 435 | {file = "PyYAML-5.4.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d483ad4e639292c90170eb6f7783ad19490e7a8defb3e46f97dfe4bacae89122"}, 436 | {file = "PyYAML-5.4.1-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:fdc842473cd33f45ff6bce46aea678a54e3d21f1b61a7750ce3c498eedfe25d6"}, 437 | {file = "PyYAML-5.4.1-cp39-cp39-win32.whl", hash = "sha256:49d4cdd9065b9b6e206d0595fee27a96b5dd22618e7520c33204a4a3239d5b10"}, 438 | {file = "PyYAML-5.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:c20cfa2d49991c8b4147af39859b167664f2ad4561704ee74c1de03318e898db"}, 439 | {file = "PyYAML-5.4.1.tar.gz", hash = "sha256:607774cbba28732bfa802b54baa7484215f530991055bb562efbed5b2f20a45e"}, 440 | ] 441 | regex = [ 442 | {file = "regex-2021.7.6-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e6a1e5ca97d411a461041d057348e578dc344ecd2add3555aedba3b408c9f874"}, 443 | {file = "regex-2021.7.6-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:6afe6a627888c9a6cfbb603d1d017ce204cebd589d66e0703309b8048c3b0854"}, 444 | {file = "regex-2021.7.6-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:ccb3d2190476d00414aab36cca453e4596e8f70a206e2aa8db3d495a109153d2"}, 445 | {file = "regex-2021.7.6-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:ed693137a9187052fc46eedfafdcb74e09917166362af4cc4fddc3b31560e93d"}, 446 | {file = "regex-2021.7.6-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:99d8ab206a5270c1002bfcf25c51bf329ca951e5a169f3b43214fdda1f0b5f0d"}, 447 | {file = "regex-2021.7.6-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:b85ac458354165405c8a84725de7bbd07b00d9f72c31a60ffbf96bb38d3e25fa"}, 448 | {file = "regex-2021.7.6-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:3f5716923d3d0bfb27048242a6e0f14eecdb2e2a7fac47eda1d055288595f222"}, 449 | {file = "regex-2021.7.6-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5983c19d0beb6af88cb4d47afb92d96751fb3fa1784d8785b1cdf14c6519407"}, 450 | {file = "regex-2021.7.6-cp36-cp36m-win32.whl", hash = 
"sha256:c92831dac113a6e0ab28bc98f33781383fe294df1a2c3dfd1e850114da35fd5b"}, 451 | {file = "regex-2021.7.6-cp36-cp36m-win_amd64.whl", hash = "sha256:791aa1b300e5b6e5d597c37c346fb4d66422178566bbb426dd87eaae475053fb"}, 452 | {file = "regex-2021.7.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:59506c6e8bd9306cd8a41511e32d16d5d1194110b8cfe5a11d102d8b63cf945d"}, 453 | {file = "regex-2021.7.6-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:564a4c8a29435d1f2256ba247a0315325ea63335508ad8ed938a4f14c4116a5d"}, 454 | {file = "regex-2021.7.6-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:59c00bb8dd8775473cbfb967925ad2c3ecc8886b3b2d0c90a8e2707e06c743f0"}, 455 | {file = "regex-2021.7.6-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:9a854b916806c7e3b40e6616ac9e85d3cdb7649d9e6590653deb5b341a736cec"}, 456 | {file = "regex-2021.7.6-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:db2b7df831c3187a37f3bb80ec095f249fa276dbe09abd3d35297fc250385694"}, 457 | {file = "regex-2021.7.6-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:173bc44ff95bc1e96398c38f3629d86fa72e539c79900283afa895694229fe6a"}, 458 | {file = "regex-2021.7.6-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:15dddb19823f5147e7517bb12635b3c82e6f2a3a6b696cc3e321522e8b9308ad"}, 459 | {file = "regex-2021.7.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ddeabc7652024803666ea09f32dd1ed40a0579b6fbb2a213eba590683025895"}, 460 | {file = "regex-2021.7.6-cp37-cp37m-win32.whl", hash = "sha256:f080248b3e029d052bf74a897b9d74cfb7643537fbde97fe8225a6467fb559b5"}, 461 | {file = "regex-2021.7.6-cp37-cp37m-win_amd64.whl", hash = "sha256:d8bbce0c96462dbceaa7ac4a7dfbbee92745b801b24bce10a98d2f2b1ea9432f"}, 462 | {file = "regex-2021.7.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:edd1a68f79b89b0c57339bce297ad5d5ffcc6ae7e1afdb10f1947706ed066c9c"}, 463 | {file = "regex-2021.7.6-cp38-cp38-manylinux1_i686.whl", hash = "sha256:422dec1e7cbb2efbbe50e3f1de36b82906def93ed48da12d1714cabcd993d7f0"}, 464 | {file = "regex-2021.7.6-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cbe23b323988a04c3e5b0c387fe3f8f363bf06c0680daf775875d979e376bd26"}, 465 | {file = "regex-2021.7.6-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:0eb2c6e0fcec5e0f1d3bcc1133556563222a2ffd2211945d7b1480c1b1a42a6f"}, 466 | {file = "regex-2021.7.6-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:1c78780bf46d620ff4fff40728f98b8afd8b8e35c3efd638c7df67be2d5cddbf"}, 467 | {file = "regex-2021.7.6-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:bc84fb254a875a9f66616ed4538542fb7965db6356f3df571d783f7c8d256edd"}, 468 | {file = "regex-2021.7.6-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:598c0a79b4b851b922f504f9f39a863d83ebdfff787261a5ed061c21e67dd761"}, 469 | {file = "regex-2021.7.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:875c355360d0f8d3d827e462b29ea7682bf52327d500a4f837e934e9e4656068"}, 470 | {file = "regex-2021.7.6-cp38-cp38-win32.whl", hash = "sha256:e586f448df2bbc37dfadccdb7ccd125c62b4348cb90c10840d695592aa1b29e0"}, 471 | {file = "regex-2021.7.6-cp38-cp38-win_amd64.whl", hash = "sha256:2fe5e71e11a54e3355fa272137d521a40aace5d937d08b494bed4529964c19c4"}, 472 | {file = "regex-2021.7.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6110bab7eab6566492618540c70edd4d2a18f40ca1d51d704f1d81c52d245026"}, 473 | {file = "regex-2021.7.6-cp39-cp39-manylinux1_i686.whl", hash = "sha256:4f64fc59fd5b10557f6cd0937e1597af022ad9b27d454e182485f1db3008f417"}, 474 | {file = "regex-2021.7.6-cp39-cp39-manylinux1_x86_64.whl", 
hash = "sha256:89e5528803566af4df368df2d6f503c84fbfb8249e6631c7b025fe23e6bd0cde"}, 475 | {file = "regex-2021.7.6-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:2366fe0479ca0e9afa534174faa2beae87847d208d457d200183f28c74eaea59"}, 476 | {file = "regex-2021.7.6-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:f9392a4555f3e4cb45310a65b403d86b589adc773898c25a39184b1ba4db8985"}, 477 | {file = "regex-2021.7.6-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:2bceeb491b38225b1fee4517107b8491ba54fba77cf22a12e996d96a3c55613d"}, 478 | {file = "regex-2021.7.6-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:f98dc35ab9a749276f1a4a38ab3e0e2ba1662ce710f6530f5b0a6656f1c32b58"}, 479 | {file = "regex-2021.7.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:319eb2a8d0888fa6f1d9177705f341bc9455a2c8aca130016e52c7fe8d6c37a3"}, 480 | {file = "regex-2021.7.6-cp39-cp39-win32.whl", hash = "sha256:eaf58b9e30e0e546cdc3ac06cf9165a1ca5b3de8221e9df679416ca667972035"}, 481 | {file = "regex-2021.7.6-cp39-cp39-win_amd64.whl", hash = "sha256:4c9c3155fe74269f61e27617529b7f09552fbb12e44b1189cebbdb24294e6e1c"}, 482 | {file = "regex-2021.7.6.tar.gz", hash = "sha256:8394e266005f2d8c6f0bc6780001f7afa3ef81a7a2111fa35058ded6fce79e4d"}, 483 | ] 484 | requests = [ 485 | {file = "requests-2.26.0-py2.py3-none-any.whl", hash = "sha256:6c1246513ecd5ecd4528a0906f910e8f0f9c6b8ec72030dc9fd154dc1a6efd24"}, 486 | {file = "requests-2.26.0.tar.gz", hash = "sha256:b8aa58f8cf793ffd8782d3d8cb19e66ef36f7aba4353eec859e74678b01b07a7"}, 487 | ] 488 | sacremoses = [ 489 | {file = "sacremoses-0.0.45-py3-none-any.whl", hash = "sha256:fa93db44bc04542553ba6090818b892f603d02aa0d681e6c5c3023baf17e8564"}, 490 | {file = "sacremoses-0.0.45.tar.gz", hash = "sha256:58176cc28391830789b763641d0f458819bebe88681dac72b41a19c0aedc07e9"}, 491 | ] 492 | six = [ 493 | {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, 494 | {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, 495 | ] 496 | tokenizers = [ 497 | {file = "tokenizers-0.10.3-cp36-cp36m-macosx_10_11_x86_64.whl", hash = "sha256:4ab688daf4692a6c31dfe42f1f3a4a8c22050705eb69d58d3efde9d55f434586"}, 498 | {file = "tokenizers-0.10.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c26dbc3b2a3d71d3d40c50975ec62145932f05aea73f03ea35c48ebd3a717611"}, 499 | {file = "tokenizers-0.10.3-cp36-cp36m-win32.whl", hash = "sha256:6b84673997990b3c260ae2f7c57fdf1f835e316820eff14aca46dc68be3c0c74"}, 500 | {file = "tokenizers-0.10.3-cp36-cp36m-win_amd64.whl", hash = "sha256:2a9ee3ee574d4aa740e099b0ad6ef8e63f52f48cde359bb31801146a5aa614dc"}, 501 | {file = "tokenizers-0.10.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:2f8c5fefef0d0a03be613547e613fbda06b9e6ee0891236649524964c3e54d80"}, 502 | {file = "tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4cc194104c8e427ffc4f54c7866488b42f2b1f6351a6cad0d045ca5ab8108e42"}, 503 | {file = "tokenizers-0.10.3-cp37-cp37m-win32.whl", hash = "sha256:edd8cb85c16b4b65e87ea5ef9d400be9fdd53c4152adbaca8817e16dd3aa480b"}, 504 | {file = "tokenizers-0.10.3-cp37-cp37m-win_amd64.whl", hash = "sha256:7b11b373705d082d43657c08883b79b5330f1952f0668d17488b6b889c4d7feb"}, 505 | {file = "tokenizers-0.10.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = 
"sha256:a7ce0c2f27f7c92aa3f895231de90319acdf960ce2e42ba591edc651fda7d3c9"}, 506 | {file = "tokenizers-0.10.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ae7e40d9c8a77c5a4109731ac3e21633b0c609c56a8b58be6b863da61fa54636"}, 507 | {file = "tokenizers-0.10.3-cp38-cp38-win32.whl", hash = "sha256:a7ce051aafc53c564c9edbc09df300c2bd4f6ce87460fc22a276fed405d1892a"}, 508 | {file = "tokenizers-0.10.3-cp38-cp38-win_amd64.whl", hash = "sha256:91a8c045980594c7c437a52c3da5276eb3c530a662b4ef628ff32d81fb22b543"}, 509 | {file = "tokenizers-0.10.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:1d8867db210d75d97312360ae23b92aeb6a6b5bc65e15c1cd9d204b3fa3fc262"}, 510 | {file = "tokenizers-0.10.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:18c495e700f4588b9a00e58b4c41dc459c36daaa7c39a27faf880eb8f5533ce1"}, 511 | {file = "tokenizers-0.10.3-cp39-cp39-win32.whl", hash = "sha256:ad700fd9da518884fd58bf89f0b6dfeecef9b4e2d2db8765ef259f66d6c14980"}, 512 | {file = "tokenizers-0.10.3-cp39-cp39-win_amd64.whl", hash = "sha256:e9d147e545cdfeca560646c7a703bf287afe45645da426506ccd5eb78aab5ef5"}, 513 | {file = "tokenizers-0.10.3.tar.gz", hash = "sha256:1a5d3b596c6d3a237e1ad7f46c472d467b0246be7fd1a364f12576eb8db8f7e6"}, 514 | ] 515 | torch = [ 516 | {file = "torch-1.9.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:3a2d070cf28860d285d4ab156f3954c0c1d12f4c037aa312a7c029227c0d106b"}, 517 | {file = "torch-1.9.0-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:b296e65e25081af147af936f1e3a1f17f583a9afacfa5309742678ffef728ace"}, 518 | {file = "torch-1.9.0-cp36-cp36m-win_amd64.whl", hash = "sha256:117098d4924b260a24a47c6b3fe37f2ae41f04a2ea2eff9f553ae9210b12fa54"}, 519 | {file = "torch-1.9.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:d6103b9a634993bd967337a1149f9d8b23922f42a3660676239399e15c1b4515"}, 520 | {file = "torch-1.9.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:0164673908e6b291ace592d382eba3e258b3bad009b8078cad8f3b9e00d8f23e"}, 521 | {file = "torch-1.9.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:52548b45efff772fe3810fe91daf34f981ac0ca1a7227f6226fd5693f53b5b88"}, 522 | {file = "torch-1.9.0-cp37-cp37m-win_amd64.whl", hash = "sha256:62c0a7e433681d0861494d1ede96d2485e4dbb3ea8fd867e8419addebf5de1af"}, 523 | {file = "torch-1.9.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:d88333091fd1627894bbf0d6dcef58a90e36bdf0d90a5d4675b5e07e72075511"}, 524 | {file = "torch-1.9.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:1d8139dcc864f48dc316376384f50e47a459284ad1cb84449242f4964e25aaec"}, 525 | {file = "torch-1.9.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:0aa4cca3f16fab40cb8dae6a49d0eccdc8f4ead9d1a6428cd9ba12befe082b2a"}, 526 | {file = "torch-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:646de1bef85d6c7590e98f8ea52e47acdcf58330982e4f5d73f5ca28dea2d552"}, 527 | {file = "torch-1.9.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:e596f0105f748cf09d4763152d8157aaf58d5231232eaf2c5673d4562ba86ad3"}, 528 | {file = "torch-1.9.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:ecc7193fff7741ced3db1f760666c8454d6664956288c54d1b49613b987a42f4"}, 529 | {file = "torch-1.9.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:95eeec3a6c42fd35aca552777b7d9979ed489760423de97c0118a45e849a61f4"}, 530 | {file = "torch-1.9.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8a2b2012b3c7d6019e189496688fa77de7029a220840b406d8302d1c8021a11c"}, 531 | {file = 
"torch-1.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:7e2b14fe5b3a8266cbe2f6740c0195497507974ced7bc21e99971561913a0c28"}, 532 | {file = "torch-1.9.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:0a9e74b5057463ce4e55d9332a5670993fc9e1299c52e1740e505eda106fb355"}, 533 | {file = "torch-1.9.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:569ead6ae6bb0e636df0fc8af660ef03260e630dc5f2f4cf3198027e7b6bb481"}, 534 | ] 535 | tqdm = [ 536 | {file = "tqdm-4.62.0-py2.py3-none-any.whl", hash = "sha256:706dea48ee05ba16e936ee91cb3791cd2ea6da348a0e50b46863ff4363ff4340"}, 537 | {file = "tqdm-4.62.0.tar.gz", hash = "sha256:3642d483b558eec80d3c831e23953582c34d7e4540db86d9e5ed9dad238dabc6"}, 538 | ] 539 | transformers = [ 540 | {file = "transformers-4.9.1-py3-none-any.whl", hash = "sha256:86f3c46efecf114c6886d361c1d6cca14738f0e9d1effadb1e9252770cba55a0"}, 541 | {file = "transformers-4.9.1.tar.gz", hash = "sha256:1c30e38b2e0da15e110d9bb9a627f78de9569b9c6036d6533baf783015c339be"}, 542 | ] 543 | typing-extensions = [ 544 | {file = "typing_extensions-3.10.0.0-py2-none-any.whl", hash = "sha256:0ac0f89795dd19de6b97debb0c6af1c70987fd80a2d62d1958f7e56fcc31b497"}, 545 | {file = "typing_extensions-3.10.0.0-py3-none-any.whl", hash = "sha256:779383f6086d90c99ae41cf0ff39aac8a7937a9283ce0a414e5dd782f4c94a84"}, 546 | {file = "typing_extensions-3.10.0.0.tar.gz", hash = "sha256:50b6f157849174217d0656f99dc82fe932884fb250826c18350e159ec6cdf342"}, 547 | ] 548 | urllib3 = [ 549 | {file = "urllib3-1.26.6-py2.py3-none-any.whl", hash = "sha256:39fb8672126159acb139a7718dd10806104dec1e2f0f6c88aab05d17df10c8d4"}, 550 | {file = "urllib3-1.26.6.tar.gz", hash = "sha256:f57b4c16c62fa2760b7e3d97c35b255512fb6b59a259730f36ba32ce9f8e342f"}, 551 | ] 552 | zipp = [ 553 | {file = "zipp-3.5.0-py3-none-any.whl", hash = "sha256:957cfda87797e389580cb8b9e3870841ca991e2125350677b2ca83a0e99390a3"}, 554 | {file = "zipp-3.5.0.tar.gz", hash = "sha256:f5812b1e007e48cff63449a5e9f4e7ebea716b4111f9c4f9a645f91d579bf0c4"}, 555 | ] 556 | -------------------------------------------------------------------------------- /examples/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "defsent/examples" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["hppRC "] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.7" 9 | defsent = {path = "../"} 10 | 11 | [tool.poetry.dev-dependencies] 12 | 13 | [build-system] 14 | requires = ["poetry-core>=1.0.0"] 15 | build-backend = "poetry.core.masonry.api" 16 | -------------------------------------------------------------------------------- /examples/src/demo_def2word.py: -------------------------------------------------------------------------------- 1 | from defsent import DefSent 2 | 3 | def main(): 4 | model = DefSent("cl-nagoya/defsent-bert-base-uncased-cls") 5 | print("please input any sentences!") 6 | while True: 7 | sentence = input("> ") 8 | [words] = model.predict_words(sentence) 9 | line = " ".join(words) 10 | print(f"predicted: {line}") 11 | 12 | if __name__ == "__main__": 13 | main() -------------------------------------------------------------------------------- /experiments/.gitignore: -------------------------------------------------------------------------------- 1 | dataset 2 | logs 3 | mlruns 4 | models 5 | huggingface 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 
| dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | # pytype static type analyzer 141 | .pytype/ 142 | 143 | # Cython debug symbols 144 | cython_debug/ 145 | 146 | -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | # DefSent: Sentence Embeddings using Definition Sentences / Experiments 2 | 3 | arxiv: [https://arxiv.org/abs/2105.04339](https://arxiv.org/abs/2105.04339) 4 | 5 | ## Installation 6 | 7 | ```bash 8 | poetry install 9 | ``` 10 | 11 | ## Download datasets and run pre-process 12 | 13 | ```bash 14 | bash ./scripts/download-dataset.sh 15 | poetry run python src/scripts/extract_data_from_ishiwatari.py 16 | ``` 17 | 18 | 19 | ## Run an experiment 20 | 21 | ```bash 22 | poetry run python main.py save_model=True model_name=bert-base-uncased pooling_name=CLS 23 | ``` 24 | 25 | For more detailed configurations, see `configs` directory. 26 | We use [hydra](https://github.com/facebookresearch/hydra) for configurations. 
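
Any value in `configs/config.yaml` can be overridden from the command line, and `-m` starts a Hydra multirun sweep. As a minimal sketch (the learning rate and the choice of swept values here are illustrative, not recommended settings), a sweep over the three pooling strategies could look like:

```bash
poetry run python main.py -m \
  save_model=True \
  gpus=[0] \
  model_name=bert-base-uncased \
  pooling_name=CLS,Mean,Max \
  lr=1e-5
```

See `scripts/run-*.sh` for the exact model/pooling/learning-rate combinations used in the experiments.
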
27 | 28 | 29 | ## Start Mlflow Server 30 | 31 | ```bash 32 | poetry run mlflow ui 33 | # access http://127.0.0.1:5000 34 | ``` 35 | 36 | 37 | ## Run Formatter 38 | 39 | ```bash 40 | poetry run pysen run format 41 | ``` 42 | 43 | ## Share models 44 | 45 | ``` 46 | 47 | huggingface-cli repo create defsent-bert-base-uncased-cls 48 | git clone https://huggingface.co/cl-nagoya/defsent-bert-base-uncased-cls 49 | mv /path/to/saved_model/* ./defsent-bert-base-uncased-cls/ 50 | cd ./defsent-bert-base-uncased-cls/ 51 | git add -A 52 | git commit -m ":tada: Add pre-trained model" 53 | ``` -------------------------------------------------------------------------------- /experiments/configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - model: default 6 | - tokenizer: default 7 | - trainer: default 8 | - data_module: oxford 9 | - optimizer: adam 10 | - scheduler: warmup 11 | - logger: mlflow 12 | 13 | # enable color logging 14 | - override hydra/job_logging: colorlog 15 | - override hydra/hydra_logging: colorlog 16 | 17 | 18 | # path to original working directory (that `main.py` was executed from in command line) 19 | # hydra hijacks working directory by changing it to the current log directory, 20 | # so it's useful to have path to original work dir as a special variable 21 | # read more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 22 | work_dir: ${hydra:runtime.cwd} 23 | 24 | # global configurations 25 | experiment_name: Default 26 | gpus: [0] 27 | lr: 1e-5 28 | epochs: 1 29 | batch_size: 16 30 | 31 | model_name: bert-base-uncased 32 | pooling_name: Mean 33 | 34 | dataset_dir: ${work_dir}/dataset/ 35 | save_model: True 36 | 37 | d2w: 38 | topk: 10 39 | save_predictions: True 40 | 41 | sts: 42 | data_dir: ${dataset_dir}/sts 43 | do_whitening: False 44 | to_lower: False 45 | 46 | 47 | hydra: 48 | # output paths for hydra logs 49 | run: 50 | dir: logs/runs/${experiment_name}/${now:%Y-%m-%d}/${now:%H:%M:%S}/${hydra.job.override_dirname} 51 | sweep: 52 | dir: logs/multiruns/${experiment_name}/${now:%Y-%m-%d}/${now:%H:%M:%S}/ 53 | subdir: ${hydra.job.override_dirname} 54 | 55 | job: 56 | # you can set here environment variables that are universal for all users 57 | # for system specific variables (like data paths) it's better to use .env file! 58 | env_set: 59 | # currently there are some issues with running sweeps alongside wandb 60 | # https://github.com/wandb/client/issues/1314 61 | # this env var fixes that 62 | WANDB_START_METHOD: thread 63 | TOKENIZERS_PARALLELISM: 'false' 64 | # you can set here environment variables that are universal for all users 65 | # for system specific variables (like data paths) it's better to use .env file! 66 | # env_set: 67 | config: 68 | # configuration for the ${hydra.job.override_dirname} runtime variable 69 | override_dirname: 70 | kv_sep: '=' 71 | item_sep: '/' -------------------------------------------------------------------------------- /experiments/configs/data_module/oxford.yaml: -------------------------------------------------------------------------------- 1 | # @package data_module 2 | 3 | _target_: src.data_module.DataModule 4 | 5 | batch_size: ${batch_size} 6 | tokenizer: ??? 
7 | data_dir: ${dataset_dir}/oxford -------------------------------------------------------------------------------- /experiments/configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.MLFlowLogger.html 2 | _target_: pytorch_lightning.loggers.MLFlowLogger 3 | 4 | experiment_name: ${experiment_name} 5 | tracking_uri: file://${work_dir}/mlruns 6 | tags: 7 | save_dir: ./mlruns 8 | prefix: '' 9 | artifact_location: -------------------------------------------------------------------------------- /experiments/configs/model/default.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | 3 | _target_: src.model.DefSent 4 | model_name: ${model_name} 5 | pooling_name: ${pooling_name} 6 | 7 | randomize_prediction_layer: False 8 | freeze_prediction_layer: True 9 | freeze_token_embeddings: True 10 | -------------------------------------------------------------------------------- /experiments/configs/optimizer/adadelta.yaml: -------------------------------------------------------------------------------- 1 | # @package optimizer 2 | # https://pytorch.org/docs/stable/generated/torch.optim.Adadelta.html 3 | 4 | _target_: torch.optim.Adadelta 5 | 6 | params: ??? 7 | lr: ${lr} # default: 1.0 8 | rho: 0.9 9 | eps: 1e-06 10 | weight_decay: 0 -------------------------------------------------------------------------------- /experiments/configs/optimizer/adagrad.yaml: -------------------------------------------------------------------------------- 1 | # @package optimizer 2 | # https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html 3 | 4 | _target_: torch.optim.Adagrad 5 | 6 | params: ??? 7 | lr: ${lr} # default: 0.01 8 | lr_decay: 0 9 | weight_decay: 0 10 | initial_accumulator_value: 0 11 | eps: 1e-10 -------------------------------------------------------------------------------- /experiments/configs/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # @package optimizer 2 | # https://pytorch.org/docs/stable/generated/torch.optim.Adam.html 3 | 4 | _target_: torch.optim.Adam 5 | params: ??? 6 | lr: ${lr} # default: 0.001 7 | betas: [0.9, 0.999] 8 | eps: 1e-08 9 | weight_decay: 0 10 | amsgrad: False -------------------------------------------------------------------------------- /experiments/configs/optimizer/adamax.yaml: -------------------------------------------------------------------------------- 1 | # @package optimizer 2 | # https://pytorch.org/docs/stable/generated/torch.optim.Adamax.html 3 | 4 | _target_: torch.optim.Adamax 5 | 6 | params: ??? 7 | lr: ${lr} # default: 0.002 8 | betas: [0.9, 0.999] 9 | eps: 1e-08 10 | weight_decay: 0 -------------------------------------------------------------------------------- /experiments/configs/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | # @package optimizer 2 | # https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html 3 | 4 | _target_: torch.optim.AdamW 5 | 6 | params: ??? 
7 | lr: ${lr} # default: 0.001 8 | betas: [0.9, 0.999] 9 | eps: 1e-08 10 | weight_decay: 0.01 11 | amsgrad: False -------------------------------------------------------------------------------- /experiments/configs/optimizer/asgd.yaml: -------------------------------------------------------------------------------- 1 | # @package optimizer 2 | # https://pytorch.org/docs/stable/generated/torch.optim.ASGD.html 3 | 4 | _target_: torch.optim.ASGD 5 | 6 | params: ??? 7 | lr: ${lr} # default: 0.01 8 | lambd: 0.0001 9 | alpha: 0.75 10 | t0: 1000000.0 11 | weight_decay: 0 -------------------------------------------------------------------------------- /experiments/configs/optimizer/lbfgs.yaml: -------------------------------------------------------------------------------- 1 | # @package optimizer 2 | # https://pytorch.org/docs/stable/generated/torch.optim.LBFGS.html 3 | 4 | _target_: torch.optim.LBFGS 5 | 6 | params: ??? 7 | lr: ${lr} # default: 1.0 8 | max_iter: 20 9 | max_eval: 10 | tolerance_grad: 1e-07 11 | tolerance_change: 1e-09 12 | history_size: 100 13 | line_search_fn: -------------------------------------------------------------------------------- /experiments/configs/optimizer/rmsprop.yaml: -------------------------------------------------------------------------------- 1 | # @package optimizer 2 | # https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html 3 | 4 | _target_: torch.optim.RMSprop 5 | 6 | params: ??? 7 | lr: ${lr} # 0.01 8 | alpha: 0.99 9 | eps: 1e-08 10 | weight_decay: 0 11 | momentum: 0 12 | centered: False -------------------------------------------------------------------------------- /experiments/configs/optimizer/rprop.yaml: -------------------------------------------------------------------------------- 1 | # @package optimizer 2 | # https://pytorch.org/docs/stable/generated/torch.optim.Rprop.html 3 | 4 | _target_: torch.optim.Rprop 5 | 6 | params: ??? 7 | lr: ${lr} # 0.01 8 | etas: [0.5, 1.2] 9 | step_sizes: [1e-06, 50] -------------------------------------------------------------------------------- /experiments/configs/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | # @package optimizer 2 | # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html 3 | 4 | _target_: torch.optim.SGD 5 | 6 | params: ??? 7 | lr: ${lr} 8 | momentum: 0 9 | dampening: 0 10 | weight_decay: 0 11 | nesterov: False -------------------------------------------------------------------------------- /experiments/configs/optimizer/sparse_adam.yaml: -------------------------------------------------------------------------------- 1 | # @package optimizer 2 | # https://pytorch.org/docs/stable/generated/torch.optim.SparseAdam.html 3 | 4 | _target_: torch.optim.SparseAdam 5 | 6 | params: ??? 7 | lr: ${lr} # default: 0.001 8 | betas: [0.9, 0.999] 9 | eps: 1e-08 -------------------------------------------------------------------------------- /experiments/configs/scheduler/cosine_annealing.yaml: -------------------------------------------------------------------------------- 1 | # @package scheduler 2 | # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html 3 | 4 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR 5 | 6 | optimizer: ??? 7 | T_max: ??? 
8 | 9 | eta_min: 0 10 | last_epoch: -1 11 | verbose: False -------------------------------------------------------------------------------- /experiments/configs/scheduler/cosine_annealing_warm_restarts.yaml: -------------------------------------------------------------------------------- 1 | # @package scheduler 2 | # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html 3 | 4 | _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts 5 | 6 | optimizer: ??? 7 | T_0: ??? 8 | 9 | T_mult: 1 10 | eta_min: 0 11 | last_epoch: -1 12 | verbose: False -------------------------------------------------------------------------------- /experiments/configs/scheduler/cyclic.yaml: -------------------------------------------------------------------------------- 1 | # @package scheduler 2 | # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CyclicLR.html 3 | 4 | _target_: torch.optim.lr_scheduler.CyclicLR 5 | 6 | optimizer: ??? 7 | base_lr: ??? 8 | max_lr: ??? 9 | 10 | step_size_up: 2000 11 | step_size_down: 12 | mode: triangular 13 | gamma: 1.0 14 | scale_fn: 15 | scale_mode: cycle 16 | cycle_momentum: True 17 | base_momentum: 0.8 18 | max_momentum: 0.9 19 | last_epoch: -1 20 | verbose: False -------------------------------------------------------------------------------- /experiments/configs/scheduler/exponential.yaml: -------------------------------------------------------------------------------- 1 | # @package scheduler 2 | # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ExponentialLR.html 3 | 4 | _target_: torch.optim.lr_scheduler.ExponentialLR 5 | 6 | optimizer: ??? 7 | gamma: ??? 8 | 9 | last_epoch: -1 10 | verbose: False -------------------------------------------------------------------------------- /experiments/configs/scheduler/lambda.yaml: -------------------------------------------------------------------------------- 1 | # @package scheduler 2 | # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.LambdaLR.html 3 | 4 | _target_: torch.optim.lr_scheduler.LambdaLR 5 | 6 | optimizer: ??? 7 | lr_lambda: ??? 8 | 9 | last_epoch: -1 10 | verbose: False -------------------------------------------------------------------------------- /experiments/configs/scheduler/multi_step.yaml: -------------------------------------------------------------------------------- 1 | # @package scheduler 2 | # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html 3 | 4 | _target_: torch.optim.lr_scheduler.MultiStepLR 5 | 6 | optimizer: ??? 7 | milestones: ??? 8 | 9 | gamma: 0.1 10 | last_epoch: -1 11 | verbose: False -------------------------------------------------------------------------------- /experiments/configs/scheduler/multipricative.yaml: -------------------------------------------------------------------------------- 1 | # @package scheduler 2 | # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiplicativeLR.html 3 | 4 | _target_: torch.optim.lr_scheduler.MultiplicativeLR 5 | 6 | optimizer: ??? 7 | lr_lambda: ??? 8 | 9 | last_epoch: -1 10 | verbose: False -------------------------------------------------------------------------------- /experiments/configs/scheduler/oncyclic.yaml: -------------------------------------------------------------------------------- 1 | # @package scheduler 2 | # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html 3 | 4 | _target_: torch.optim.lr_scheduler.OneCycleLR 5 | 6 | optimizer: ??? 
7 | max_lr: ??? 8 | 9 | total_steps: 10 | epochs: 11 | steps_per_epoch: 12 | pct_start: 0.3 13 | anneal_strategy: cos 14 | cycle_momentum: True 15 | base_momentum: 0.85 16 | max_momentum: 0.95 17 | div_factor: 25.0 18 | final_div_factor: 10000.0 19 | three_phase: False 20 | last_epoch: -1 21 | verbose: False -------------------------------------------------------------------------------- /experiments/configs/scheduler/plateau.yaml: -------------------------------------------------------------------------------- 1 | # @package scheduler 2 | # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html 3 | 4 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 5 | 6 | optimizer: ??? 7 | 8 | mode: min 9 | factor: 0.1 10 | patience: 10 11 | threshold: 0.0001 12 | threshold_mode: rel 13 | cooldown: 0 14 | min_lr: 0 15 | eps: 1e-08 16 | verbose: False -------------------------------------------------------------------------------- /experiments/configs/scheduler/step.yaml: -------------------------------------------------------------------------------- 1 | # @package scheduler 2 | # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html 3 | 4 | _target_: torch.optim.lr_scheduler.StepLR 5 | 6 | optimizer: ??? 7 | step_size: 3 8 | 9 | gamma: 0.1 10 | last_epoch: -1 11 | verbose: False -------------------------------------------------------------------------------- /experiments/configs/scheduler/warmup.yaml: -------------------------------------------------------------------------------- 1 | # @package scheduler 2 | 3 | _target_: src.lr_scheduler.warmup_scheduler 4 | 5 | optimizer: ??? 6 | steps_per_epoch: ??? 7 | epochs: ${epochs} 8 | warmup_ratio: 0.1 -------------------------------------------------------------------------------- /experiments/configs/tokenizer/default.yaml: -------------------------------------------------------------------------------- 1 | # @package tokenizer 2 | 3 | _target_: transformers.AutoTokenizer.from_pretrained 4 | pretrained_model_name_or_path: ${model_name} 5 | 6 | add_prefix_space: True -------------------------------------------------------------------------------- /experiments/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | # @package trainer 2 | _target_: pytorch_lightning.Trainer 3 | 4 | # default parameters of `pytorch_lightning.Trainer` 5 | logger: True 6 | checkpoint_callback: False # default: True 7 | callbacks: 8 | default_root_dir: 9 | gradient_clip_val: 0.0 10 | gradient_clip_algorithm: norm 11 | process_position: 0 12 | num_nodes: 1 13 | num_processes: 1 14 | gpus: ${gpus} 15 | auto_select_gpus: False 16 | tpu_cores: 17 | log_gpu_memory: 18 | progress_bar_refresh_rate: 19 | overfit_batches: 0.0 20 | track_grad_norm: -1 21 | check_val_every_n_epoch: 1 22 | fast_dev_run: False 23 | accumulate_grad_batches: 1 24 | max_epochs: ${epochs} 25 | min_epochs: 26 | max_steps: 27 | min_steps: 28 | max_time: 29 | limit_train_batches: 1.0 30 | limit_val_batches: 1.0 31 | limit_test_batches: 1.0 32 | limit_predict_batches: 1.0 33 | val_check_interval: 1.0 34 | flush_logs_every_n_steps: 100 35 | log_every_n_steps: 50 36 | accelerator: 37 | sync_batchnorm: False 38 | precision: 32 39 | weights_summary: top 40 | weights_save_path: 41 | num_sanity_val_steps: 2 42 | truncated_bptt_steps: 43 | resume_from_checkpoint: 44 | profiler: 45 | benchmark: False 46 | deterministic: False 47 | reload_dataloaders_every_epoch: False 48 | auto_lr_find: False 49 | 
replace_sampler_ddp: True 50 | terminate_on_nan: False 51 | auto_scale_batch_size: False 52 | prepare_data_per_node: True 53 | plugins: 54 | amp_backend: native 55 | amp_level: O2 56 | distributed_backend: # default: None 57 | move_metrics_to_cpu: False 58 | multiple_trainloader_mode: max_size_cycle 59 | stochastic_weight_avg: False -------------------------------------------------------------------------------- /experiments/main.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | from src.experiment import Experiment 4 | 5 | 6 | @hydra.main(config_path="configs/", config_name="config.yaml") 7 | def main(config: DictConfig) -> None: 8 | exp = Experiment(config) 9 | ret = exp.run() 10 | 11 | if config.save_model: 12 | exp.save_model() 13 | 14 | return ret 15 | 16 | 17 | if __name__ == "__main__": 18 | main() 19 | -------------------------------------------------------------------------------- /experiments/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "defsent/experiments" 3 | version = "0.1.0" 4 | description = "DefSent: Sentence Embeddings using Definition Sentences / experiments" 5 | authors = ["hppRC "] 6 | readme = "README.md" 7 | homepage = "https://arxiv.org/abs/2105.04339" 8 | repository = "https://github.com/hppRC/defsent" 9 | 10 | 11 | [tool.poetry.dependencies] 12 | python = "^3.7,<3.10" 13 | # please lookup a compatible PyTorch version with your OS and CUDA from: https://download.pytorch.org/whl/torch_stable.html 14 | torch = {url = "https://download.pytorch.org/whl/cu111/torch-1.9.0%2Bcu111-cp37-cp37m-linux_x86_64.whl"} 15 | tqdm = "^4.61.2" 16 | pytorch-lightning = "^1.3.8" 17 | mlflow = "^1.19.0" 18 | hydra-colorlog = "^1.1.0" 19 | hydra-core = "^1.1.0" 20 | omegaconf = "^2.1.0" 21 | senteval = {git = "https://github.com/facebookresearch/SentEval"} 22 | transformers = "^4.9.0" 23 | sklearn = "^0.0" 24 | scipy = "^1.7.0" 25 | 26 | 27 | [tool.poetry.dev-dependencies] 28 | pysen = {version = "^0.9.1", extras = ["lint"]} 29 | pytest = "^5.2" 30 | 31 | [build-system] 32 | requires = ["poetry-core>=1.0.0"] 33 | build-backend = "poetry.core.masonry.api" 34 | 35 | [tool.pysen] 36 | version = "0.9" 37 | 38 | [tool.pysen.lint] 39 | enable_black = true 40 | enable_flake8 = true 41 | enable_isort = true 42 | enable_mypy = false 43 | mypy_preset = "strict" 44 | py_version = "py37" 45 | [[tool.pysen.lint.mypy_targets]] 46 | paths = ["."] 47 | 48 | -------------------------------------------------------------------------------- /experiments/scripts/download-dataset.sh: -------------------------------------------------------------------------------- 1 | mkdir -p dataset 2 | 3 | wget http://www.tkl.iis.u-tokyo.ac.jp/~ishiwatari/naacl_data.zip 4 | unzip naacl_data.zip 5 | mv ./data ./dataset/ishiwatari 6 | rm naacl_data.zip 7 | 8 | # STS2012 9 | mkdir -p dataset/sts/2012 10 | wget http://ixa2.si.ehu.es/stswiki/images/4/40/STS2012-en-test.zip 11 | unzip STS2012-en-test.zip 12 | mv test-gold dataset/sts/2012/test 13 | rm STS2012-en-test.zip 14 | 15 | # STS2013 16 | mkdir -p dataset/sts/2013 17 | wget http://ixa2.si.ehu.es/stswiki/images/2/2f/STS2013-en-test.zip 18 | unzip STS2013-en-test.zip 19 | mv test-gs dataset/sts/2013/test 20 | rm STS2013-en-test.zip 21 | 22 | # STS2014 23 | mkdir -p dataset/sts/2014 24 | wget http://ixa2.si.ehu.es/stswiki/images/8/8c/STS2014-en-test.zip 25 | unzip STS2014-en-test.zip 26 | mv 
sts-en-test-gs-2014 dataset/sts/2014/test 27 | rm STS2014-en-test.zip 28 | 29 | # STS2015 30 | mkdir -p dataset/sts/2015 31 | wget http://ixa2.si.ehu.es/stswiki/images/d/da/STS2015-en-test.zip 32 | unzip STS2015-en-test.zip 33 | mv test_evaluation_task2a dataset/sts/2015/test 34 | rm STS2015-en-test.zip 35 | 36 | # STS2016 37 | mkdir -p dataset/sts/2016 38 | wget http://ixa2.si.ehu.es/stswiki/images/9/98/STS2016-en-test.zip 39 | unzip STS2016-en-test.zip 40 | mv sts2016-english-with-gs-v1.0 dataset/sts/2016/test 41 | rm STS2016-en-test.zip 42 | 43 | # STS2017 44 | mkdir -p dataset/sts/2017 45 | wget http://ixa2.si.ehu.es/stswiki/images/2/20/Sts2017.eval.v1.1.zip 46 | unzip Sts2017.eval.v1.1.zip 47 | wget http://ixa2.si.ehu.es/stswiki/images/7/70/Sts2017.gs.zip 48 | unzip Sts2017.gs.zip 49 | rm Sts2017.eval.v1.1.zip Sts2017.gs.zip 50 | mv STS2017.eval.v1.1 dataset/sts/2017/input 51 | mv STS2017.gs dataset/sts/2017/gs 52 | 53 | 54 | # STS Benchmark 55 | wget http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz 56 | tar -zxvf Stsbenchmark.tar.gz 57 | mv stsbenchmark dataset/sts/ 58 | rm Stsbenchmark.tar.gz 59 | 60 | 61 | # SICK 62 | wget http://alt.qcri.org/semeval2014/task1/data/uploads/sick_test_annotated.zip 63 | unzip sick_test_annotated.zip -d SICK 64 | mv SICK dataset/sts/ 65 | rm sick_test_annotated.zip -------------------------------------------------------------------------------- /experiments/scripts/run-base.sh: -------------------------------------------------------------------------------- 1 | nohup bash scripts/run-bert-base0.sh > /dev/null 2>&1 & 2 | nohup bash scripts/run-bert-base1.sh > /dev/null 2>&1 & 3 | nohup bash scripts/run-roberta-base0.sh > /dev/null 2>&1 & 4 | nohup bash scripts/run-roberta-base1.sh > /dev/null 2>&1 & -------------------------------------------------------------------------------- /experiments/scripts/run-bert-base0.sh: -------------------------------------------------------------------------------- 1 | poetry run python main.py -m save_model=True gpus=[0] experiment_name=BERT-base-CLS model_name=bert-base-uncased pooling_name=CLS lr=5.656854249492381e-06 +exp_times=0,1,2,3,4 2 | poetry run python main.py -m save_model=True gpus=[0] experiment_name=BERT-base-Mean model_name=bert-base-uncased pooling_name=Mean lr=1.1313708498984761e-05 +exp_times=0,1,2,3,4 3 | poetry run python main.py -m save_model=True gpus=[0] experiment_name=BERT-base-Max model_name=bert-base-uncased pooling_name=Max lr=1.1313708498984761e-05 +exp_times=0,1,2,3,4 4 | -------------------------------------------------------------------------------- /experiments/scripts/run-bert-base1.sh: -------------------------------------------------------------------------------- 1 | poetry run python main.py -m save_model=True gpus=[1] experiment_name=BERT-base-CLS model_name=bert-base-uncased pooling_name=CLS lr=5.656854249492381e-06 +exp_times=5,6,7,8,9 2 | poetry run python main.py -m save_model=True gpus=[1] experiment_name=BERT-base-Mean model_name=bert-base-uncased pooling_name=Mean lr=1.1313708498984761e-05 +exp_times=5,6,7,8,9 3 | poetry run python main.py -m save_model=True gpus=[1] experiment_name=BERT-base-Max model_name=bert-base-uncased pooling_name=Max lr=1.1313708498984761e-05 +exp_times=5,6,7,8,9 -------------------------------------------------------------------------------- /experiments/scripts/run-bert-large0.sh: -------------------------------------------------------------------------------- 1 | poetry run python main.py -m save_model=True gpus=[0] 
experiment_name=BERT-large-CLS model_name=bert-large-uncased pooling_name=CLS lr=5.656854249492381e-06 +exp_times=0,1,2,3,4 2 | poetry run python main.py -m save_model=True gpus=[0] experiment_name=BERT-large-Mean model_name=bert-large-uncased pooling_name=Mean lr=1.1313708498984761e-05 +exp_times=0,1,2,3,4 3 | poetry run python main.py -m save_model=True gpus=[0] experiment_name=BERT-large-Max model_name=bert-large-uncased pooling_name=Max lr=8e-06 +exp_times=0,1,2,3,4 4 | -------------------------------------------------------------------------------- /experiments/scripts/run-bert-large1.sh: -------------------------------------------------------------------------------- 1 | poetry run python main.py -m save_model=True gpus=[1] experiment_name=BERT-large-CLS model_name=bert-large-uncased pooling_name=CLS lr=5.656854249492381e-06 +exp_times=5,6,7,8,9 2 | poetry run python main.py -m save_model=True gpus=[1] experiment_name=BERT-large-Mean model_name=bert-large-uncased pooling_name=Mean lr=1.1313708498984761e-05 +exp_times=5,6,7,8,9 3 | poetry run python main.py -m save_model=True gpus=[1] experiment_name=BERT-large-Max model_name=bert-large-uncased pooling_name=Max lr=8e-06 +exp_times=5,6,7,8,9 -------------------------------------------------------------------------------- /experiments/scripts/run-large.sh: -------------------------------------------------------------------------------- 1 | nohup bash scripts/run-bert-large0.sh > /dev/null 2>&1 & 2 | nohup bash scripts/run-bert-large1.sh > /dev/null 2>&1 & 3 | nohup bash scripts/run-roberta-large0.sh > /dev/null 2>&1 & 4 | nohup bash scripts/run-roberta-large1.sh > /dev/null 2>&1 & -------------------------------------------------------------------------------- /experiments/scripts/run-roberta-base0.sh: -------------------------------------------------------------------------------- 1 | poetry run python main.py -m save_model=True gpus=[2] experiment_name=RoBERTa-base-CLS model_name=roberta-base pooling_name=CLS lr=5.656854249492381e-06 +exp_times=0,1,2,3,4 2 | poetry run python main.py -m save_model=True gpus=[2] experiment_name=RoBERTa-base-Mean model_name=roberta-base pooling_name=Mean lr=8e-06 +exp_times=0,1,2,3,4 3 | poetry run python main.py -m save_model=True gpus=[2] experiment_name=RoBERTa-base-Max model_name=roberta-base pooling_name=Max lr=4e-06 +exp_times=0,1,2,3,4 -------------------------------------------------------------------------------- /experiments/scripts/run-roberta-base1.sh: -------------------------------------------------------------------------------- 1 | poetry run python main.py -m save_model=True gpus=[3] experiment_name=RoBERTa-base-CLS model_name=roberta-base pooling_name=CLS lr=5.656854249492381e-06 +exp_times=5,6,7,8,9 2 | poetry run python main.py -m save_model=True gpus=[3] experiment_name=RoBERTa-base-Mean model_name=roberta-base pooling_name=Mean lr=8e-06 +exp_times=5,6,7,8,9 3 | poetry run python main.py -m save_model=True gpus=[3] experiment_name=RoBERTa-base-Max model_name=roberta-base pooling_name=Max lr=4e-06 +exp_times=5,6,7,8,9 -------------------------------------------------------------------------------- /experiments/scripts/run-roberta-large0.sh: -------------------------------------------------------------------------------- 1 | poetry run python main.py -m save_model=True gpus=[2] experiment_name=RoBERTa-large-CLS model_name=roberta-large pooling_name=CLS lr=4e-06 +exp_times=0,1,2,3,4 2 | poetry run python main.py -m save_model=True gpus=[2] experiment_name=RoBERTa-large-Mean 
model_name=roberta-large pooling_name=Mean lr=4e-06 +exp_times=0,1,2,3,4 3 | poetry run python main.py -m save_model=True gpus=[2] experiment_name=RoBERTa-large-Max model_name=roberta-large pooling_name=Max lr=5.656854249492381e-06 +exp_times=0,1,2,3,4 -------------------------------------------------------------------------------- /experiments/scripts/run-roberta-large1.sh: -------------------------------------------------------------------------------- 1 | poetry run python main.py -m save_model=True gpus=[3] experiment_name=RoBERTa-large-CLS model_name=roberta-large pooling_name=CLS lr=4e-06 +exp_times=5,6,7,8,9 2 | poetry run python main.py -m save_model=True gpus=[3] experiment_name=RoBERTa-large-Mean model_name=roberta-large pooling_name=Mean lr=4e-06 +exp_times=5,6,7,8,9 3 | poetry run python main.py -m save_model=True gpus=[3] experiment_name=RoBERTa-large-Max model_name=roberta-large pooling_name=Max lr=5.656854249492381e-06 +exp_times=5,6,7,8,9 -------------------------------------------------------------------------------- /experiments/src/data_module.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from pathlib import Path 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import pytorch_lightning as pl 7 | import torch 8 | from src.dataset import Dataset 9 | from src.utils import pad_sequence 10 | from torch.functional import Tensor 11 | from torch.utils.data import DataLoader 12 | from transformers import PreTrainedTokenizerBase 13 | 14 | 15 | class DataModule(pl.LightningDataModule): 16 | def __init__( 17 | self, 18 | batch_size: int, 19 | data_dir: Union[Path, str], 20 | tokenizer: PreTrainedTokenizerBase, 21 | ) -> None: 22 | super().__init__() 23 | 24 | self.batch_size = batch_size 25 | self.data_dir = Path(data_dir) 26 | self.tokenizer = tokenizer 27 | 28 | self.train = None 29 | self.val = None 30 | self.test = None 31 | 32 | def text_to_data(self, lines: List[str]) -> Tuple[List[int], List[int]]: 33 | words, definitions = [], [] 34 | for line in lines: 35 | word, definition = line.strip().split("\t") # line: "word\tdefinition" 36 | words.append(word) 37 | definitions.append(definition) 38 | 39 | # encode without special tokens (e.g., [CLS], [SEP], , <\s>) 40 | words_ids = self.tokenizer(words, add_special_tokens=False).input_ids 41 | definitions_ids = self.tokenizer(definitions, truncation=True).input_ids 42 | 43 | filtered_words_ids, filtered_definitions_ids = [], [] 44 | for word_id, definition_ids in zip(words_ids, definitions_ids): 45 | if len(word_id) == 1: 46 | filtered_words_ids.append(word_id) 47 | filtered_definitions_ids.append(definition_ids) 48 | 49 | return (filtered_words_ids, filtered_definitions_ids) 50 | 51 | def collate_fn( 52 | self, data_list: List[Tuple[List[Tensor], List[Tensor]]] 53 | ) -> Tuple[Tensor, Tensor, Tensor]: 54 | word_id_list, definition_ids_list = zip(*data_list) 55 | words_ids = torch.cat(word_id_list, dim=0) 56 | definitions_ids = pad_sequence( 57 | definition_ids_list, 58 | padding_value=self.tokenizer.pad_token_id, 59 | padding_side="right", 60 | ) 61 | attention_mask = (definitions_ids != self.tokenizer.pad_token_id).float() 62 | 63 | return (words_ids, definitions_ids, attention_mask) 64 | 65 | def setup(self, stage: Optional[str] = None) -> None: 66 | # make assignments here (train/valid/test split) 67 | # called on every GPUs 68 | self.train = Dataset( 69 | data_path=self.data_dir / "train.tsv", text_to_data=self.text_to_data, 70 | ) 71 | self.val 
= Dataset( 72 | data_path=self.data_dir / "valid.tsv", text_to_data=self.text_to_data, 73 | ) 74 | self.test = Dataset( 75 | data_path=self.data_dir / "test.tsv", text_to_data=self.text_to_data, 76 | ) 77 | 78 | def train_dataloader(self) -> DataLoader: 79 | return DataLoader( 80 | self.train, 81 | batch_size=self.batch_size, 82 | num_workers=os.cpu_count(), 83 | collate_fn=self.collate_fn, 84 | # pin_memory=True, 85 | ) 86 | 87 | def val_dataloader(self) -> DataLoader: 88 | return DataLoader( 89 | self.val, 90 | batch_size=self.batch_size, 91 | num_workers=os.cpu_count(), 92 | collate_fn=self.collate_fn, 93 | # pin_memory=True, 94 | ) 95 | 96 | def test_dataloader(self) -> DataLoader: 97 | return DataLoader( 98 | self.test, 99 | batch_size=self.batch_size, 100 | num_workers=os.cpu_count(), 101 | collate_fn=self.collate_fn, 102 | # pin_memory=True, 103 | ) 104 | -------------------------------------------------------------------------------- /experiments/src/dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, List, Tuple, Union 3 | 4 | import torch 5 | 6 | 7 | class Dataset(torch.utils.data.Dataset): 8 | def __init__( 9 | self, 10 | data_path: Union[Path, str], 11 | text_to_data: Callable[[List[str]], Tuple[List[int], List[int]]], 12 | ): 13 | with Path(data_path).open() as f: 14 | self.words, self.definitions = text_to_data(f.readlines()) 15 | 16 | assert len(self.words) == len(self.definitions) 17 | 18 | def __len__(self): 19 | return len(self.words) 20 | 21 | def __getitem__(self, key: Union[int, slice]): 22 | if isinstance(key, int): 23 | return ( 24 | torch.LongTensor(self.words[key]), 25 | torch.LongTensor(self.definitions[key]), 26 | ) 27 | elif isinstance(key, slice): 28 | return ( 29 | [torch.LongTensor(word_id) for word_id in self.words[key]], 30 | [ 31 | torch.LongTensor(definition_ids) 32 | for definition_ids in self.definitions[key] 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /experiments/src/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .def2word import Def2WordEvaluationAll 2 | from .sts import STSEvaluation 3 | -------------------------------------------------------------------------------- /experiments/src/evaluation/def2word.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable 3 | 4 | import torch 5 | from src.data_module import DataModule 6 | from src.model import DefSent 7 | from tqdm import tqdm 8 | from transformers import PreTrainedTokenizerBase 9 | 10 | 11 | @torch.no_grad() 12 | def get_mrr(indices, targets): 13 | tmp = targets.view(-1, 1) 14 | targets = tmp.expand_as(indices) 15 | hits = (targets == indices).nonzero(as_tuple=False) 16 | ranks = hits[:, -1] + 1 17 | ranks = ranks.float() 18 | rranks = torch.reciprocal(ranks) 19 | return torch.sum(rranks) 20 | 21 | 22 | class Def2WordEvaluation: 23 | def __init__( 24 | self, 25 | data_module: DataModule, 26 | tokenizer: PreTrainedTokenizerBase, 27 | topk: int = 10, 28 | save_predictions: bool = False, 29 | ) -> None: 30 | self.dm = data_module 31 | self.tokenizer = tokenizer 32 | self.topk = topk 33 | self.save_predictions = save_predictions 34 | 35 | @torch.no_grad() 36 | def __call__(self, model: DefSent, mode: str): 37 | if mode == "train": 38 | dataset = self.dm.train 39 | dataloader = 
self.dm.train_dataloader() 40 | elif mode == "val": 41 | dataset = self.dm.val 42 | dataloader = self.dm.val_dataloader() 43 | elif mode == "test": 44 | dataset = self.dm.test 45 | dataloader = self.dm.test_dataloader() 46 | else: 47 | raise ValueError(f"No such a mode!: {mode}") 48 | 49 | res = [] 50 | mrr_sum = 0 51 | topk_acc_sum = [0] * self.topk 52 | device = model.device 53 | 54 | for batch in tqdm(dataloader): 55 | words_ids, definitions_ids, attention_mask = batch 56 | words_ids, definitions_ids, attention_mask = ( 57 | words_ids.to(device), 58 | definitions_ids.to(device), 59 | attention_mask.to(device), 60 | ) 61 | 62 | logits = model.predict_words(definitions_ids, attention_mask=attention_mask) 63 | hypothesis = logits.topk(self.topk, dim=1).indices 64 | words = self.tokenizer.convert_ids_to_tokens(words_ids) 65 | 66 | for word, definition_ids, hyp_words_ids in zip( 67 | words, definitions_ids, hypothesis 68 | ): 69 | hyp_words = self.tokenizer.convert_ids_to_tokens(hyp_words_ids) 70 | assert len(hyp_words) == self.topk 71 | 72 | if self.save_predictions: 73 | definition_tokens = self.tokenizer.convert_ids_to_tokens( 74 | definition_ids, skip_special_tokens=True 75 | ) 76 | definition = self.tokenizer.convert_tokens_to_string( 77 | definition_tokens 78 | ) 79 | res.append( 80 | {"word": word, "definition": definition, "hyp_words": hyp_words} 81 | ) 82 | 83 | already_found_correct_word = False 84 | for i in range(self.topk): 85 | if hyp_words[i] == word: 86 | already_found_correct_word = True 87 | if already_found_correct_word: 88 | topk_acc_sum[i] += 1 89 | 90 | mrr_sum += get_mrr(hypothesis, words_ids).item() 91 | 92 | ret = { 93 | mode: { 94 | "MRR": mrr_sum / len(dataset) * 100, 95 | "ACC": [cnt / len(dataset) * 100 for cnt in topk_acc_sum], 96 | } 97 | } 98 | if self.save_predictions: 99 | ret[mode]["result"] = res 100 | return ret 101 | 102 | 103 | class Def2WordEvaluationAll: 104 | def __init__( 105 | self, 106 | data_module: DataModule, 107 | tokenizer: PreTrainedTokenizerBase, 108 | topk: int = 10, 109 | save_predictions: bool = False, 110 | log_artifact: Callable[[str], None] = None, 111 | ) -> None: 112 | self.save_predictions = save_predictions 113 | self.def2word_evaluator = Def2WordEvaluation( 114 | data_module=data_module, 115 | tokenizer=tokenizer, 116 | topk=topk, 117 | save_predictions=save_predictions, 118 | ) 119 | self.log_artifact = log_artifact 120 | 121 | def __call__(self, model: DefSent): 122 | if self.save_predictions: 123 | results_dir = Path("./results/def2word-prediction") 124 | results_dir.mkdir(parents=True, exist_ok=True) 125 | 126 | metrics = {} 127 | for mode in ["train", "val", "test"]: 128 | result = self.def2word_evaluator(model, mode=mode) 129 | topk_acc = result[mode]["ACC"] 130 | top1, top3, top10 = topk_acc[0], topk_acc[2], topk_acc[9] 131 | mrr = result[mode]["MRR"] 132 | metrics[mode] = {"MRR": mrr, "top1": top1, "top3": top3, "top10": top10} 133 | 134 | if self.save_predictions: 135 | save_path = results_dir / f"{mode}.txt" 136 | res = result[mode]["result"] 137 | lines = [] 138 | for data in res: 139 | word, definition, hyp_words = ( 140 | data["word"], 141 | data["definition"], 142 | data["hyp_words"], 143 | ) 144 | hyp_line = "\t".join(hyp_words) 145 | lines.append(f"{word}\t[{definition}]\n{hyp_line}\n") 146 | save_path.write_text("\n".join(lines)) 147 | self.log_artifact(save_path) 148 | 149 | return metrics 150 | -------------------------------------------------------------------------------- 
/experiments/src/evaluation/senteval.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import senteval 4 | 5 | 6 | # SentEval prepare and batcher 7 | def prepare(params, samples): 8 | return 9 | 10 | 11 | def batcher(params, batch): 12 | batch = [" ".join(sent) if sent != [] else "." for sent in batch] 13 | embeddings = params["encoder"](batch) 14 | return embeddings 15 | 16 | 17 | class SentEvalEvaluator: 18 | def __init__(self, data_dir): 19 | self.data_dir = data_dir 20 | 21 | def __call__(self, encoder): 22 | # Set params for SentEval 23 | params_senteval = {"task_path": self.data_dir, "usepytorch": True, "kfold": 10} 24 | # params_senteval = {"task_path": self.data_dir, "usepytorch": True, "kfold": 2} 25 | params_senteval["classifier"] = { 26 | "nhid": 0, 27 | "optim": "adam", 28 | "batch_size": 64, 29 | "tenacity": 5, 30 | # "epoch_size": 1, 31 | "epoch_size": 4, 32 | } 33 | params_senteval["encoder"] = encoder 34 | 35 | se = senteval.engine.SE(params_senteval, batcher, prepare) 36 | 37 | # sts = [ 38 | # "STS12", 39 | # "STS13", 40 | # "STS14", 41 | # "STS15", 42 | # "STS16", 43 | # "STSBenchmark", 44 | # "SICKRelatedness", 45 | # ] 46 | classification_tasks = [ 47 | "MR", 48 | "CR", 49 | "SUBJ", 50 | "MPQA", 51 | "SST2", 52 | "TREC", 53 | "MRPC", 54 | # "SICKEntailment", 55 | ] 56 | # probing_tasks = [ 57 | # "Length", 58 | # "WordContent", 59 | # "Depth", 60 | # "TopConstituents", 61 | # "BigramShift", 62 | # "Tense", 63 | # "SubjNumber", 64 | # "ObjNumber", 65 | # "OddManOut", 66 | # "CoordinationInversion", 67 | # ] 68 | 69 | metrics = {} 70 | # for task in classification_tasks + probing_tasks + sts: 71 | for task in classification_tasks: 72 | # for task in se.list_tasks: 73 | print(task) 74 | try: 75 | metrics[task] = { 76 | k: self.convert(v) for k, v in se.eval([task])[task].items() 77 | } 78 | except: 79 | print("error:", task) 80 | 81 | return metrics 82 | 83 | def convert(self, v): 84 | try: 85 | return float(v) 86 | except: 87 | try: 88 | return [float(x) for x in v] 89 | except: 90 | return -1 91 | -------------------------------------------------------------------------------- /experiments/src/evaluation/sts.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, Dict, List, Union 3 | 4 | import numpy as np 5 | import torch 6 | from scipy.stats import pearsonr, spearmanr 7 | from sklearn.metrics.pairwise import ( 8 | paired_cosine_distances, 9 | paired_euclidean_distances, 10 | paired_manhattan_distances, 11 | ) 12 | from torch import Tensor 13 | from tqdm import tqdm 14 | 15 | 16 | # https://arxiv.org/pdf/2104.01767.pdf 17 | def whitening_torch_final(embeddings): 18 | mu = torch.mean(embeddings, dim=0, keepdim=True) 19 | # cov = torch.mm((embeddings - mu).t(), embeddings - mu) 20 | cov = torch.mm((embeddings - mu).t(), embeddings - mu) / embeddings.size(0) 21 | u, s, _ = torch.svd(cov) 22 | W = torch.mm(u, torch.diag(1 / torch.sqrt(s))) 23 | embeddings = torch.mm(embeddings - mu, W) 24 | return embeddings 25 | 26 | 27 | class EmbeddingSimilarityEvaluator: 28 | def __init__( 29 | self, 30 | sentences1: List[str], 31 | sentences2: List[str], 32 | scores: List[float], 33 | batch_size: int = 1024, 34 | name: str = "", 35 | ): 36 | self.sentences1 = sentences1 37 | self.sentences2 = sentences2 38 | self.scores = scores 39 | 40 | # print(name, len(self.sentences1)) 41 | assert len(self.sentences1) == 
len(self.sentences2) 42 | assert len(self.sentences1) == len(self.scores) 43 | 44 | self.name = name 45 | self.batch_size = batch_size 46 | 47 | def __call__( 48 | self, 49 | encoder: Callable[[List[str]], Tensor], 50 | do_whitening: bool = False, 51 | to_lower: bool = False, 52 | ) -> Dict[str, Dict[str, float]]: 53 | if to_lower: 54 | self.sentences1 = [x.lower() for x in self.sentences1] 55 | self.sentences2 = [x.lower() for x in self.sentences2] 56 | 57 | embeddings1 = encoder(self.sentences1, batch_size=self.batch_size) 58 | embeddings2 = encoder(self.sentences2, batch_size=self.batch_size) 59 | 60 | if do_whitening: 61 | num_pairs = embeddings1.shape[0] 62 | embeddings = whitening_torch_final( 63 | torch.cat([embeddings1, embeddings2], dim=0) 64 | ) 65 | embeddings1 = embeddings[:num_pairs, :] 66 | embeddings2 = embeddings[num_pairs:, :] 67 | 68 | cosine_scores = 1 - paired_cosine_distances(embeddings1, embeddings2) 69 | manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2) 70 | euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2) 71 | dot_products = [ 72 | np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2) 73 | ] 74 | 75 | # convert to a premitive float type 76 | eval_pearson = lambda my_score: float(pearsonr(self.scores, my_score)[0]) * 100 77 | eval_spearman = ( 78 | lambda my_score: float(spearmanr(self.scores, my_score)[0]) * 100 79 | ) 80 | 81 | return { 82 | "spearman": { 83 | "cosine": eval_spearman(cosine_scores), 84 | "manhattan": eval_spearman(manhattan_distances), 85 | "euclidean": eval_spearman(euclidean_distances), 86 | "dot": eval_spearman(dot_products), 87 | }, 88 | "pearson": { 89 | "cosine": eval_pearson(cosine_scores), 90 | "manhattan": eval_pearson(manhattan_distances), 91 | "euclidean": eval_pearson(euclidean_distances), 92 | "dot": eval_pearson(dot_products), 93 | }, 94 | } 95 | 96 | 97 | class SICKRelatednessEvaluator(EmbeddingSimilarityEvaluator): 98 | def __init__(self, data_dir: Path): 99 | sentences1, sentences2, scores = [], [], [] 100 | 101 | with (data_dir / "SICK" / "SICK_test_annotated.txt").open() as f: 102 | _ = next(f) 103 | for line in f: 104 | _, sentence1, sentence2, score, *_ = line.strip().split("\t") 105 | sentences1.append(sentence1) 106 | sentences2.append(sentence2) 107 | scores.append(float(score)) 108 | 109 | super().__init__(sentences1, sentences2, scores, name="sick-relatedness") 110 | 111 | 112 | class STSBenchmarkEvaluator(EmbeddingSimilarityEvaluator): 113 | def __init__(self, data_dir: Path): 114 | name = "sts-benchmark" 115 | 116 | datasets = [ 117 | # "sts-train.csv", 118 | # "sts-dev.csv", 119 | "sts-test.csv", 120 | ] 121 | 122 | sentences1, sentences2, scores = [], [], [] 123 | 124 | for dataset in datasets: 125 | with (data_dir / "stsbenchmark" / dataset).open() as f: 126 | for line in f: 127 | _, _, _, _, score, sentence1, sentence2, *_ = line.strip().split( 128 | "\t" 129 | ) 130 | sentences1.append(sentence1) 131 | sentences2.append(sentence2) 132 | scores.append(float(score)) 133 | 134 | super().__init__(list(sentences1), list(sentences2), list(scores), name=name) 135 | 136 | 137 | class STS2016Evaluator(EmbeddingSimilarityEvaluator): 138 | def __init__(self, data_dir: Path): 139 | name = "sts-2016" 140 | 141 | sentences1, sentences2, scores = [], [], [] 142 | datasets = [ 143 | "answer-answer", 144 | "headlines", 145 | "plagiarism", 146 | "postediting", 147 | "question-question", 148 | ] 149 | 150 | for dataset in datasets: 151 | with ( 152 | data_dir / "2016" / 
"test" / f"STS2016.gs.{dataset}.txt" 153 | ).open() as gs, ( 154 | data_dir / "2016" / "test" / f"STS2016.input.{dataset}.txt" 155 | ).open() as f: 156 | for line_input, line_gs in zip(f, gs): 157 | sentence1, sentence2, *_ = line_input.strip().split("\t") 158 | if line_gs.strip() == "": 159 | continue 160 | sentences1.append(sentence1) 161 | sentences2.append(sentence2) 162 | scores.append(float(line_gs.strip())) 163 | 164 | super().__init__(sentences1, sentences2, scores, name=name) 165 | 166 | 167 | class STS2015Evaluator(EmbeddingSimilarityEvaluator): 168 | def __init__(self, data_dir: Path): 169 | name = "sts-2015" 170 | 171 | sentences1, sentences2, scores = [], [], [] 172 | datasets = [ 173 | "answers-forums", 174 | "answers-students", 175 | "belief", 176 | "headlines", 177 | "images", 178 | ] 179 | 180 | for dataset in datasets: 181 | with (data_dir / "2015" / "test" / f"STS.gs.{dataset}.txt").open() as gs, ( 182 | data_dir / "2015" / "test" / f"STS.input.{dataset}.txt" 183 | ).open() as f: 184 | for line_input, line_gs in zip(f, gs): 185 | sentence1, sentence2, *_ = line_input.strip().split("\t") 186 | if line_gs.strip() == "": 187 | continue 188 | sentences1.append(sentence1) 189 | sentences2.append(sentence2) 190 | scores.append(float(line_gs.strip())) 191 | 192 | super().__init__(sentences1, sentences2, scores, name=name) 193 | 194 | 195 | class STS2014Evaluator(EmbeddingSimilarityEvaluator): 196 | def __init__(self, data_dir: Path): 197 | name = "sts-2014" 198 | 199 | sentences1, sentences2, scores = [], [], [] 200 | datasets = [ 201 | "deft-forum", 202 | "deft-news", 203 | "headlines", 204 | "images", 205 | "OnWN", 206 | "tweet-news", 207 | ] 208 | 209 | for dataset in datasets: 210 | with (data_dir / "2014" / "test" / f"STS.gs.{dataset}.txt").open() as gs, ( 211 | data_dir / "2014" / "test" / f"STS.input.{dataset}.txt" 212 | ).open() as f: 213 | for line_input, line_gs in zip(f, gs): 214 | sentence1, sentence2, *_ = line_input.strip().split("\t") 215 | if line_gs.strip() == "": 216 | continue 217 | sentences1.append(sentence1) 218 | sentences2.append(sentence2) 219 | scores.append(float(line_gs.strip())) 220 | 221 | super().__init__(sentences1, sentences2, scores, name=name) 222 | 223 | 224 | class STS2013Evaluator(EmbeddingSimilarityEvaluator): 225 | # STS13 here does not contain the "SMT" subtask due to LICENSE issue 226 | def __init__(self, data_dir: Path): 227 | name = "sts-2013" 228 | 229 | sentences1, sentences2, scores = [], [], [] 230 | datasets = ["FNWN", "headlines", "OnWN"] 231 | 232 | for dataset in datasets: 233 | with (data_dir / "2013" / "test" / f"STS.gs.{dataset}.txt").open() as gs, ( 234 | data_dir / "2013" / "test" / f"STS.input.{dataset}.txt" 235 | ).open() as f: 236 | for line_input, line_gs, *_ in zip(f, gs): 237 | sentence1, sentence2 = line_input.strip().split("\t") 238 | if line_gs.strip() == "": 239 | continue 240 | sentences1.append(sentence1) 241 | sentences2.append(sentence2) 242 | scores.append(float(line_gs.strip())) 243 | 244 | super().__init__(sentences1, sentences2, scores, name=name) 245 | 246 | 247 | class STS2012Evaluator(EmbeddingSimilarityEvaluator): 248 | def __init__(self, data_dir: Path): 249 | name = "sts-2012" 250 | 251 | sentences1, sentences2, scores = [], [], [] 252 | datasets = [ 253 | "MSRpar", 254 | "MSRvid", 255 | "SMTeuroparl", 256 | "surprise.OnWN", 257 | "surprise.SMTnews", 258 | ] 259 | 260 | for dataset in datasets: 261 | with (data_dir / "2012" / "test" / f"STS.gs.{dataset}.txt").open() as gs, ( 262 | data_dir / "2012" 
/ "test" / f"STS.input.{dataset}.txt" 263 | ).open() as f: 264 | for line_input, line_gs in zip(f, gs): 265 | sentence1, sentence2, *_ = line_input.strip().split("\t") 266 | if line_gs.strip() == "": 267 | continue 268 | sentences1.append(sentence1) 269 | sentences2.append(sentence2) 270 | scores.append(float(line_gs.strip())) 271 | 272 | super().__init__(sentences1, sentences2, scores, name=name) 273 | 274 | 275 | class STSEvaluation: 276 | def __init__(self, data_dir: Union[str, Path]): 277 | data_dir = Path(data_dir) 278 | self.sts_evaluators = { 279 | "STS12": STS2012Evaluator(data_dir=data_dir), 280 | "STS13": STS2013Evaluator(data_dir=data_dir), 281 | "STS14": STS2014Evaluator(data_dir=data_dir), 282 | "STS15": STS2015Evaluator(data_dir=data_dir), 283 | "STS16": STS2016Evaluator(data_dir=data_dir), 284 | "STSB": STSBenchmarkEvaluator(data_dir=data_dir), 285 | "SICK-R": SICKRelatednessEvaluator(data_dir=data_dir), 286 | } 287 | 288 | self.metrics = ["spearman", "pearson"] 289 | self.methods = ["cosine", "manhattan", "euclidean", "dot"] 290 | 291 | @torch.no_grad() 292 | def __call__( 293 | self, 294 | encoder: Callable[[List[str]], Tensor], 295 | do_whitening: bool = False, 296 | to_lower: bool = False, 297 | ): 298 | sts_evaluations = {} 299 | for name, evaluator in tqdm(list(self.sts_evaluators.items())): 300 | sts_evaluations[name] = evaluator( 301 | encoder, do_whitening=do_whitening, to_lower=to_lower 302 | ) 303 | 304 | sts_evaluations["AVG"] = {} 305 | for metric in self.metrics: 306 | sts_evaluations["AVG"][metric] = {} 307 | 308 | for method in self.methods: 309 | sts_evaluations["AVG"][metric][method] = 0.0 310 | 311 | for task in self.sts_evaluators: 312 | sts_evaluations["AVG"][metric][method] += sts_evaluations[task][ 313 | metric 314 | ][method] 315 | sts_evaluations["AVG"][metric][method] /= len(self.sts_evaluators) 316 | 317 | return sts_evaluations 318 | -------------------------------------------------------------------------------- /experiments/src/experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Dict, List, Optional, Tuple 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | import torch.nn.functional as F 8 | from hydra.utils import instantiate 9 | from omegaconf import DictConfig, OmegaConf 10 | from pytorch_lightning.utilities import rank_zero_only 11 | from src.data_module import DataModule 12 | from src.evaluation import Def2WordEvaluationAll, STSEvaluation 13 | from src.model import DefSent 14 | from torch import Tensor 15 | from torch.optim import Optimizer 16 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 17 | 18 | 19 | class Experiment(pl.LightningModule): 20 | def __init__(self, config: DictConfig) -> None: 21 | super(Experiment, self).__init__() 22 | self.config: DictConfig = config 23 | logger = instantiate(config.logger) 24 | self.trainer = instantiate( 25 | config.trainer, 26 | logger=logger, 27 | # callbacks=[LearningRateMonitor(logging_interval="step")], 28 | ) 29 | self.model: DefSent = instantiate(config.model) 30 | self.tokenizer: PreTrainedTokenizerBase = instantiate(config.tokenizer) 31 | self.data_module: DataModule = instantiate( 32 | config.data_module, tokenizer=self.tokenizer 33 | ) 34 | 35 | self.def2word_evaluator = Def2WordEvaluationAll( 36 | data_module=self.data_module, 37 | tokenizer=self.tokenizer, 38 | topk=config.d2w.topk, 39 | save_predictions=config.d2w.save_predictions, 40 | 
log_artifact=self.log_artifact, 41 | ) 42 | self.sts_evaluator = STSEvaluation(data_dir=config.sts.data_dir) 43 | 44 | def configure_optimizers(self): 45 | params = (param for param in self.model.parameters() if param.requires_grad) 46 | steps_per_epoch = len(self.data_module.train_dataloader()) 47 | optimizer: Optimizer = instantiate(self.config.optimizer, params=params) 48 | scheduler = instantiate( 49 | self.config.scheduler, optimizer=optimizer, steps_per_epoch=steps_per_epoch 50 | ) 51 | return [optimizer], [scheduler] 52 | 53 | def loss_fn(self, logits: Tensor, labels_ids: Tensor) -> Tensor: 54 | return F.cross_entropy(logits, labels_ids) 55 | 56 | def training_step(self, batch: Tuple[Tensor, Tensor, Tensor], batch_idx: int): 57 | words_ids, definitions_ids, attention_mask = batch 58 | logits = self.model.predict_words( 59 | definitions_ids, attention_mask=attention_mask 60 | ) 61 | loss = self.loss_fn(logits, words_ids) 62 | self.log("train_loss", loss) 63 | return loss 64 | 65 | def validation_step(self, batch: Tuple[Tensor, Tensor, Tensor], batch_idx: int): 66 | words_ids, definitions_ids, attention_mask = batch 67 | logits = self.model.predict_words( 68 | definitions_ids, attention_mask=attention_mask 69 | ) 70 | loss = self.loss_fn(logits, words_ids) 71 | self.log("val_loss", loss) 72 | return loss 73 | 74 | # train your model 75 | def fit(self) -> None: 76 | self.trainer.fit(self, self.data_module) 77 | self.log_hyperparams() 78 | self.log_cwd() 79 | self.log_artifact(".hydra/config.yaml") 80 | self.log_artifact(".hydra/hydra.yaml") 81 | self.log_artifact(".hydra/overrides.yaml") 82 | self.log_artifact("main.log") 83 | 84 | @rank_zero_only 85 | def evaluate(self): 86 | prev_device = self.device 87 | self.to(self.trainer.accelerator_connector.root_gpu) 88 | self.eval() 89 | 90 | metrics = {} 91 | metrics["d2w"] = self.def2word_evaluator(self.model) 92 | metrics["sts"] = self.sts_evaluator( 93 | encoder=self.encode, 94 | do_whitening=self.config.sts.do_whitening, 95 | to_lower=self.config.sts.to_lower, 96 | ) 97 | self.log_main_metrics(metrics) 98 | 99 | metrics_str = OmegaConf.to_yaml(OmegaConf.create(metrics)) 100 | metrics_path = Path("./metrics.yaml") 101 | metrics_path.write_text(metrics_str) 102 | self.log_artifact(metrics_path) 103 | 104 | self.to(prev_device) 105 | 106 | # run your whole experiments 107 | def run(self): 108 | self.fit() 109 | self.evaluate() 110 | 111 | def log_artifact(self, artifact_path: str) -> None: 112 | self.logger.experiment.log_artifact(self.logger.run_id, artifact_path) 113 | 114 | def log_hyperparams(self) -> None: 115 | self.logger.log_hyperparams( 116 | { 117 | "model_name": self.config.model_name, 118 | "pooling_name": self.config.pooling_name, 119 | "batch_size": self.config.batch_size, 120 | "lr": self.config.lr, 121 | "optimizer": self.config.optimizer._target_, 122 | "lr_scheduler": self.config.scheduler._target_, 123 | } 124 | ) 125 | 126 | def log_cwd(self) -> None: 127 | self.logger.log_hyperparams({"_cwd": str(Path.cwd())}) 128 | 129 | def log_main_metrics(self, metrics: Dict) -> None: 130 | main_metrics = { 131 | "d2w_test_MRR": metrics["d2w"]["test"]["MRR"], 132 | "d2w_test_top1": metrics["d2w"]["test"]["top1"], 133 | "d2w_test_top3": metrics["d2w"]["test"]["top3"], 134 | "d2w_test_top10": metrics["d2w"]["test"]["top10"], 135 | "sts_12": metrics["sts"]["STS12"]["spearman"]["cosine"], 136 | "sts_13": metrics["sts"]["STS13"]["spearman"]["cosine"], 137 | "sts_14": metrics["sts"]["STS14"]["spearman"]["cosine"], 138 | "sts_15": 
metrics["sts"]["STS15"]["spearman"]["cosine"], 139 | "sts_16": metrics["sts"]["STS16"]["spearman"]["cosine"], 140 | "sts_B": metrics["sts"]["STSB"]["spearman"]["cosine"], 141 | "sts_SICK-R": metrics["sts"]["SICK-R"]["spearman"]["cosine"], 142 | "sts_AVG": metrics["sts"]["AVG"]["spearman"]["cosine"], 143 | } 144 | self.logger.log_metrics(main_metrics) 145 | 146 | @torch.no_grad() 147 | def encode(self, sentences: List[str], batch_size: Optional[int]) -> Tensor: 148 | inputs = self.tokenizer( 149 | sentences, padding=True, return_tensors="pt", truncation=True, 150 | ) 151 | data_loader = torch.utils.data.DataLoader( 152 | list(zip(inputs.input_ids, inputs.attention_mask)), 153 | batch_size=batch_size or self.config.batch_size, 154 | num_workers=os.cpu_count(), 155 | ) 156 | 157 | all_embs = [] 158 | for batch in data_loader: 159 | sentence_ids, attention_mask = self.transfer_batch_to_device( 160 | batch, self.device 161 | ) 162 | embs = self.model(sentence_ids, attention_mask=attention_mask).cpu() 163 | all_embs.append(embs) 164 | 165 | embeddings = torch.cat(all_embs, dim=0) 166 | return embeddings 167 | 168 | def save_model(self) -> None: 169 | self.model.pretrained_model.save_pretrained("./pretrained") 170 | self.tokenizer.save_pretrained("./pretrained") 171 | -------------------------------------------------------------------------------- /experiments/src/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import LambdaLR 2 | 3 | 4 | class LRPolycy: 5 | def __init__(self, num_warmup_steps: int) -> None: 6 | self.num_warmup_steps = num_warmup_steps 7 | 8 | def __call__(self, current_step: int) -> float: 9 | if current_step < self.num_warmup_steps: 10 | return float(current_step) / float(max(1.0, self.num_warmup_steps)) 11 | return 1.0 12 | 13 | 14 | def warmup_scheduler(optimizer, steps_per_epoch: int, epochs: int, warmup_ratio: float): 15 | num_training_steps = epochs * steps_per_epoch 16 | num_warmup_steps = num_training_steps * warmup_ratio 17 | 18 | lr_scheduler = LambdaLR(optimizer, lr_lambda=LRPolycy(num_warmup_steps)) 19 | scheduler_config = { 20 | "scheduler": lr_scheduler, 21 | "monitor": "val_loss", 22 | "interval": "step", 23 | "frequency": 1, 24 | "strict": True, 25 | } 26 | return scheduler_config 27 | -------------------------------------------------------------------------------- /experiments/src/model.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytorch_lightning as pl 4 | import torch.nn as nn 5 | from src.pooling import NonParametricPooling 6 | from torch import Tensor 7 | from transformers import ( 8 | AlbertForMaskedLM, 9 | BertForMaskedLM, 10 | DebertaForMaskedLM, 11 | PreTrainedModel, 12 | RobertaForMaskedLM, 13 | ) 14 | 15 | 16 | class DefSent(pl.LightningModule): 17 | def __init__( 18 | self, 19 | model_name: str, 20 | pooling_name: str, 21 | randomize_prediction_layer: bool = False, 22 | freeze_prediction_layer: bool = True, 23 | freeze_token_embeddings: bool = True, 24 | ) -> None: 25 | super().__init__() 26 | # When `freeze_prediction_layer or freeze_token_embeddings` is `False`, we should not tie `word_embeddings` and `prediction_layer.decoder`; 27 | # otherwise, when the parameters of one of them are updated, the other will be updated 28 | tie_word_embeddings = freeze_prediction_layer and freeze_token_embeddings 29 | ( 30 | self.pretrained_model, 31 | self.encoder, 32 | self.token_embeddings, 33 | 
self.prediction_layer, 34 | ) = pretrained_modules( 35 | model_name=model_name, tie_word_embeddings=tie_word_embeddings, 36 | ) 37 | 38 | if randomize_prediction_layer: 39 | nn.init.normal_(self.prediction_layer.weight) 40 | if freeze_prediction_layer: 41 | for param in self.prediction_layer.parameters(): 42 | param.requires_grad = False 43 | if freeze_token_embeddings: 44 | for param in self.token_embeddings.parameters(): 45 | param.requires_grad = False 46 | 47 | self.pooling = NonParametricPooling(pooling_name=pooling_name) 48 | 49 | def forward(self, input_ids: Tensor, attention_mask: Tensor = None) -> Tensor: 50 | embs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state 51 | emb = self.pooling(embs, attention_mask=attention_mask) 52 | return emb 53 | 54 | def predict_words(self, input_ids: Tensor, attention_mask: Tensor = None) -> Tensor: 55 | emb = self(input_ids, attention_mask=attention_mask) 56 | logits = self.prediction_layer(emb) 57 | return logits 58 | 59 | 60 | # Each pretrained model have different architecture and name. 61 | # This function performs like an `adapter`. 62 | def pretrained_modules( 63 | model_name: str, tie_word_embeddings: bool, 64 | ) -> Tuple[PreTrainedModel, nn.Module, nn.Module, nn.Module]: 65 | if model_name in [ 66 | "bert-base-uncased", 67 | "bert-large-uncased", 68 | "bert-base-cased", 69 | "bert-large-cased", 70 | "bert-base-multilingual-uncased", 71 | "bert-base-multilingual-cased", 72 | "bert-base-chinese", 73 | "bert-base-german-cased", 74 | "bert-large-uncased-whole-word-masking", 75 | "bert-large-cased-whole-word-masking", 76 | "bert-large-uncased-whole-word-masking-finetuned-squad", 77 | "bert-large-cased-whole-word-masking-finetuned-squad", 78 | "bert-base-cased-finetuned-mrpc", 79 | "bert-base-german-dbmdz-cased", 80 | "bert-base-german-dbmdz-uncased", 81 | "cl-tohoku/bert-base-japanese", 82 | "cl-tohoku/bert-base-japanese-whole-word-masking", 83 | "cl-tohoku/bert-base-japanese-char", 84 | "cl-tohoku/bert-base-japanese-char-whole-word-masking", 85 | "TurkuNLP/bert-base-finnish-cased-v1", 86 | "TurkuNLP/bert-base-finnish-uncased-v1", 87 | "wietsedv/bert-base-dutch-cased", 88 | # See all BERT models at https://huggingface.co/models?filter=bert 89 | ]: 90 | pretrained_model = BertForMaskedLM.from_pretrained( 91 | model_name, tie_word_embeddings=tie_word_embeddings, 92 | ) 93 | encoder = pretrained_model.bert 94 | token_embeddings = pretrained_model.bert.embeddings 95 | prediction_layer = pretrained_model.cls 96 | 97 | elif model_name in [ 98 | "roberta-base", 99 | "roberta-large", 100 | "xlm-roberta-base", 101 | "xlm-roberta-large", 102 | ]: 103 | pretrained_model = RobertaForMaskedLM.from_pretrained( 104 | model_name, tie_word_embeddings=tie_word_embeddings, 105 | ) 106 | encoder = pretrained_model.roberta 107 | token_embeddings = pretrained_model.roberta.embeddings 108 | prediction_layer = pretrained_model.lm_head 109 | 110 | elif model_name in ["albert-base-v2", "albert-large-v2"]: 111 | pretrained_model = AlbertForMaskedLM.from_pretrained( 112 | model_name, tie_word_embeddings=tie_word_embeddings, 113 | ) 114 | encoder = pretrained_model.albert 115 | token_embeddings = pretrained_model.albert.embeddings 116 | prediction_layer = pretrained_model.predictions 117 | 118 | elif model_name in [ 119 | "microsoft/deberta-base", 120 | "microsoft/deberta-large", 121 | "microsoft/deberta-xlarge", 122 | "microsoft/deberta-base-mnli", 123 | "microsoft/deberta-large-mnli", 124 | "microsoft/deberta-xlarge-mnli", 125 | 
"microsoft/deberta-v2-xlarge", 126 | "microsoft/deberta-v2-xxlarge", 127 | "microsoft/deberta-v2-xlarge-mnli", 128 | "microsoft/deberta-v2-xxlarge-mnli", 129 | ]: 130 | pretrained_model = DebertaForMaskedLM.from_pretrained( 131 | model_name, tie_word_embeddings=tie_word_embeddings, 132 | ) 133 | encoder = pretrained_model.deberta 134 | token_embeddings = pretrained_model.deberta.embeddings 135 | prediction_layer = pretrained_model.lm_predictions 136 | 137 | else: 138 | raise ValueError(f"no such a model name! > {model_name}") 139 | 140 | return pretrained_model, encoder, token_embeddings, prediction_layer 141 | -------------------------------------------------------------------------------- /experiments/src/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | 6 | # using @torch.jit.script is slower than this simple implementaion. 7 | class NonParametricPooling(nn.Module): 8 | def __init__(self, pooling_name: str) -> None: 9 | super().__init__() 10 | self.pooling_name = pooling_name 11 | 12 | def forward(self, x: Tensor, attention_mask: Tensor) -> Tensor: 13 | if self.pooling_name == "CLS": 14 | return x[:, 0] 15 | 16 | # masked tokens are marked as `0` 17 | sent_len = attention_mask.sum(dim=1, keepdim=True) 18 | if self.pooling_name == "SEP": 19 | batch_size = x.size(0) 20 | batch_indices = torch.LongTensor(range(batch_size)) 21 | sep_indices = (sent_len.long() - 1).squeeze() 22 | return x[batch_indices, sep_indices] 23 | 24 | mask_value = 0 if self.pooling_name in ["Mean", "Sum"] else -1e6 25 | x[attention_mask.long() == 0, :] = mask_value 26 | 27 | if self.pooling_name == "Mean": 28 | return x.sum(dim=1) / sent_len 29 | 30 | elif self.pooling_name == "Max": 31 | return x.max(dim=1).values 32 | 33 | elif self.pooling_name == "Sum": 34 | return x.sum(dim=1) 35 | 36 | else: 37 | raise ValueError(f"No such a pooling name! 
{self.pooling_name}") 38 | -------------------------------------------------------------------------------- /experiments/src/scripts/extract_data_from_ishiwatari.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | DATASET_DIR = Path("./dataset") 7 | 8 | 9 | def main(dataset_name): 10 | save_dir = DATASET_DIR / dataset_name 11 | save_dir.mkdir(exist_ok=True, parents=True) 12 | 13 | word_def = defaultdict(lambda: []) 14 | 15 | modes = ["train", "valid", "test"] 16 | for mode in modes: 17 | with (DATASET_DIR / "ishiwatari" / dataset_name / f"{mode}.txt").open() as f: 18 | for line in f: 19 | word, _, _, definition, *_ = line.strip().split("\t") 20 | word = word.rsplit("%", 1)[0].lstrip().rstrip() 21 | definition = ( 22 | definition.replace(" .", ".") 23 | .replace(" ,", ",") 24 | .replace(" ;", ";") 25 | .replace("( ", "(") 26 | .replace(" )", ")") 27 | .replace(" '", "'") 28 | ) 29 | definition = re.sub( 30 | r"`` (.*?)''", lambda x: x.group(1).capitalize(), definition 31 | ) 32 | definition = re.sub(r"‘\s*(.*?)\s*’", r"’\1’", definition) 33 | definition = definition.lstrip().rstrip() 34 | word_def[word].append(definition) 35 | 36 | all_words = sorted(word_def.keys()) 37 | 38 | def process(filename, words): 39 | num = 0 40 | lines = [] 41 | for word in words: 42 | definitions = word_def[word] 43 | num += len(definitions) 44 | lines += [f"{word}\t{definition}" for definition in definitions] 45 | 46 | (save_dir / filename).write_text("\n".join(lines)) 47 | return num 48 | 49 | print("sum of\tall lines:\t", process("all.tsv", all_words)) 50 | 51 | random.shuffle(all_words) 52 | train_words = all_words[: len(all_words) * 8 // 10] 53 | valid_words = all_words[len(all_words) * 8 // 10 : len(all_words) * 9 // 10] 54 | test_words = all_words[len(all_words) * 9 // 10 :] 55 | 56 | print("sum of\twords:\t", len(all_words)) 57 | print("sum of\ttrain words:\t", len(train_words)) 58 | print("sum of\tvalid words:\t", len(valid_words)) 59 | print("sum of\ttest words:\t", len(test_words)) 60 | 61 | print("sum of\ttrain lines:\t", process("train.tsv", train_words)) 62 | print("sum of\tvalid lines:\t", process("valid.tsv", valid_words)) 63 | print("sum of\ttest lines:\t", process("test.tsv", test_words)) 64 | 65 | 66 | if __name__ == "__main__": 67 | main("oxford") 68 | # main("wiki") 69 | # main("slang") 70 | -------------------------------------------------------------------------------- /experiments/src/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch.nn as nn 4 | from torch import Tensor 5 | 6 | 7 | def pad_sequence( 8 | sequences: List[Tensor], padding_value: int, padding_side: str = "right" 9 | ): 10 | if padding_side == "right": 11 | return right_side_padding(sequences, padding_value) 12 | elif padding_side == "left": 13 | return left_side_padding(sequences, padding_value) 14 | else: 15 | raise ValueError(f"no such a padding side name! 
> {padding_side}") 16 | 17 | 18 | def right_side_padding(sequences: List[Tensor], padding_value: int): 19 | return nn.utils.rnn.pad_sequence( 20 | sequences, batch_first=True, padding_value=padding_value, 21 | ) 22 | 23 | 24 | def left_side_padding(sequences: List[Tensor], padding_value: int): 25 | max_size = sequences[0].size() 26 | trailing_dims = max_size[1:] 27 | max_len = max([s.size(0) for s in sequences]) 28 | out_dims = (len(sequences), max_len) + trailing_dims 29 | 30 | out_tensor = sequences[0].new_full(out_dims, padding_value) 31 | for i, tensor in enumerate(sequences): 32 | # use index notation to prevent duplicate references to the tensor 33 | length = tensor.size(0) 34 | out_tensor[i, -length:, ...] = tensor 35 | 36 | return out_tensor 37 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "certifi" 3 | version = "2021.5.30" 4 | description = "Python package for providing Mozilla's CA Bundle." 5 | category = "main" 6 | optional = false 7 | python-versions = "*" 8 | 9 | [[package]] 10 | name = "charset-normalizer" 11 | version = "2.0.4" 12 | description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 13 | category = "main" 14 | optional = false 15 | python-versions = ">=3.5.0" 16 | 17 | [package.extras] 18 | unicode_backport = ["unicodedata2"] 19 | 20 | [[package]] 21 | name = "click" 22 | version = "8.0.1" 23 | description = "Composable command line interface toolkit" 24 | category = "main" 25 | optional = false 26 | python-versions = ">=3.6" 27 | 28 | [package.dependencies] 29 | colorama = {version = "*", markers = "platform_system == \"Windows\""} 30 | importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} 31 | 32 | [[package]] 33 | name = "colorama" 34 | version = "0.4.4" 35 | description = "Cross-platform colored terminal text." 36 | category = "main" 37 | optional = false 38 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 39 | 40 | [[package]] 41 | name = "filelock" 42 | version = "3.0.12" 43 | description = "A platform independent file lock." 
44 | category = "main" 45 | optional = false 46 | python-versions = "*" 47 | 48 | [[package]] 49 | name = "huggingface-hub" 50 | version = "0.0.12" 51 | description = "Client library to download and publish models on the huggingface.co hub" 52 | category = "main" 53 | optional = false 54 | python-versions = ">=3.6.0" 55 | 56 | [package.dependencies] 57 | filelock = "*" 58 | importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} 59 | packaging = ">=20.9" 60 | requests = "*" 61 | tqdm = "*" 62 | typing-extensions = "*" 63 | 64 | [package.extras] 65 | all = ["pytest", "black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 66 | dev = ["pytest", "black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 67 | quality = ["black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 68 | testing = ["pytest"] 69 | torch = ["torch"] 70 | 71 | [[package]] 72 | name = "idna" 73 | version = "3.2" 74 | description = "Internationalized Domain Names in Applications (IDNA)" 75 | category = "main" 76 | optional = false 77 | python-versions = ">=3.5" 78 | 79 | [[package]] 80 | name = "importlib-metadata" 81 | version = "4.6.3" 82 | description = "Read metadata from Python packages" 83 | category = "main" 84 | optional = false 85 | python-versions = ">=3.6" 86 | 87 | [package.dependencies] 88 | typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} 89 | zipp = ">=0.5" 90 | 91 | [package.extras] 92 | docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] 93 | perf = ["ipython"] 94 | testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"] 95 | 96 | [[package]] 97 | name = "joblib" 98 | version = "1.0.1" 99 | description = "Lightweight pipelining with Python functions" 100 | category = "main" 101 | optional = false 102 | python-versions = ">=3.6" 103 | 104 | [[package]] 105 | name = "numpy" 106 | version = "1.21.1" 107 | description = "NumPy is the fundamental package for array computing with Python." 108 | category = "main" 109 | optional = false 110 | python-versions = ">=3.7" 111 | 112 | [[package]] 113 | name = "packaging" 114 | version = "21.0" 115 | description = "Core utilities for Python packages" 116 | category = "main" 117 | optional = false 118 | python-versions = ">=3.6" 119 | 120 | [package.dependencies] 121 | pyparsing = ">=2.0.2" 122 | 123 | [[package]] 124 | name = "pyparsing" 125 | version = "2.4.7" 126 | description = "Python parsing module" 127 | category = "main" 128 | optional = false 129 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 130 | 131 | [[package]] 132 | name = "pyyaml" 133 | version = "5.4.1" 134 | description = "YAML parser and emitter for Python" 135 | category = "main" 136 | optional = false 137 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" 138 | 139 | [[package]] 140 | name = "regex" 141 | version = "2021.7.6" 142 | description = "Alternative regular expression module, to replace re." 143 | category = "main" 144 | optional = false 145 | python-versions = "*" 146 | 147 | [[package]] 148 | name = "requests" 149 | version = "2.26.0" 150 | description = "Python HTTP for Humans." 
151 | category = "main" 152 | optional = false 153 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" 154 | 155 | [package.dependencies] 156 | certifi = ">=2017.4.17" 157 | charset-normalizer = {version = ">=2.0.0,<2.1.0", markers = "python_version >= \"3\""} 158 | idna = {version = ">=2.5,<4", markers = "python_version >= \"3\""} 159 | urllib3 = ">=1.21.1,<1.27" 160 | 161 | [package.extras] 162 | socks = ["PySocks (>=1.5.6,!=1.5.7)", "win-inet-pton"] 163 | use_chardet_on_py3 = ["chardet (>=3.0.2,<5)"] 164 | 165 | [[package]] 166 | name = "sacremoses" 167 | version = "0.0.45" 168 | description = "SacreMoses" 169 | category = "main" 170 | optional = false 171 | python-versions = "*" 172 | 173 | [package.dependencies] 174 | click = "*" 175 | joblib = "*" 176 | regex = "*" 177 | six = "*" 178 | tqdm = "*" 179 | 180 | [[package]] 181 | name = "six" 182 | version = "1.16.0" 183 | description = "Python 2 and 3 compatibility utilities" 184 | category = "main" 185 | optional = false 186 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" 187 | 188 | [[package]] 189 | name = "tokenizers" 190 | version = "0.10.3" 191 | description = "Fast and Customizable Tokenizers" 192 | category = "main" 193 | optional = false 194 | python-versions = "*" 195 | 196 | [package.extras] 197 | testing = ["pytest"] 198 | 199 | [[package]] 200 | name = "torch" 201 | version = "1.9.0" 202 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" 203 | category = "main" 204 | optional = false 205 | python-versions = ">=3.6.2" 206 | 207 | [package.dependencies] 208 | typing-extensions = "*" 209 | 210 | [[package]] 211 | name = "tqdm" 212 | version = "4.62.0" 213 | description = "Fast, Extensible Progress Meter" 214 | category = "main" 215 | optional = false 216 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" 217 | 218 | [package.dependencies] 219 | colorama = {version = "*", markers = "platform_system == \"Windows\""} 220 | 221 | [package.extras] 222 | dev = ["py-make (>=0.1.0)", "twine", "wheel"] 223 | notebook = ["ipywidgets (>=6)"] 224 | telegram = ["requests"] 225 | 226 | [[package]] 227 | name = "transformers" 228 | version = "4.9.1" 229 | description = "State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch" 230 | category = "main" 231 | optional = false 232 | python-versions = ">=3.6.0" 233 | 234 | [package.dependencies] 235 | filelock = "*" 236 | huggingface-hub = "0.0.12" 237 | importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} 238 | numpy = ">=1.17" 239 | packaging = "*" 240 | pyyaml = ">=5.1" 241 | regex = "!=2019.12.17" 242 | requests = "*" 243 | sacremoses = "*" 244 | tokenizers = ">=0.10.1,<0.11" 245 | tqdm = ">=4.27" 246 | 247 | [package.extras] 248 | all = ["tensorflow (>=2.3)", "onnxconverter-common", "keras2onnx", "torch (>=1.0)", "jax (>=0.2.8)", "jaxlib (>=0.1.65)", "flax (>=0.3.4)", "optax (>=0.0.8)", "sentencepiece (==0.1.91)", "protobuf", "tokenizers (>=0.10.1,<0.11)", "soundfile", "torchaudio", "pillow", "optuna", "ray", "timm", "codecarbon (==1.2.0)"] 249 | codecarbon = ["codecarbon (==1.2.0)"] 250 | deepspeed = ["deepspeed (>=0.4.3)"] 251 | dev = ["tensorflow (>=2.3)", "onnxconverter-common", "keras2onnx", "torch (>=1.0)", "jax (>=0.2.8)", "jaxlib (>=0.1.65)", "flax (>=0.3.4)", "optax (>=0.0.8)", "sentencepiece (==0.1.91)", "protobuf", "tokenizers (>=0.10.1,<0.11)", "soundfile", "torchaudio", "pillow", "optuna", "ray", "timm", "codecarbon (==1.2.0)", "pytest", 
"pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (==21.4b0)", "sacrebleu (>=1.4.12)", "rouge-score", "nltk", "gitpython", "faiss-cpu", "cookiecutter (==1.7.2)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "docutils (==0.16.0)", "recommonmark", "sphinx (==3.2.1)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinx-copybutton", "sphinxext-opengraph (==0.4.1)", "scikit-learn"] 252 | docs = ["tensorflow (>=2.3)", "onnxconverter-common", "keras2onnx", "torch (>=1.0)", "jax (>=0.2.8)", "jaxlib (>=0.1.65)", "flax (>=0.3.4)", "optax (>=0.0.8)", "sentencepiece (==0.1.91)", "protobuf", "tokenizers (>=0.10.1,<0.11)", "soundfile", "torchaudio", "pillow", "optuna", "ray", "timm", "codecarbon (==1.2.0)", "docutils (==0.16.0)", "recommonmark", "sphinx (==3.2.1)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinx-copybutton", "sphinxext-opengraph (==0.4.1)"] 253 | docs_specific = ["docutils (==0.16.0)", "recommonmark", "sphinx (==3.2.1)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinx-copybutton", "sphinxext-opengraph (==0.4.1)"] 254 | fairscale = ["fairscale (>0.3)"] 255 | flax = ["jax (>=0.2.8)", "jaxlib (>=0.1.65)", "flax (>=0.3.4)", "optax (>=0.0.8)"] 256 | integrations = ["optuna", "ray"] 257 | ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)"] 258 | modelcreation = ["cookiecutter (==1.7.2)"] 259 | onnx = ["onnxconverter-common", "keras2onnx", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] 260 | onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] 261 | optuna = ["optuna"] 262 | quality = ["black (==21.4b0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 263 | ray = ["ray"] 264 | retrieval = ["faiss-cpu", "datasets"] 265 | sagemaker = ["sagemaker (>=2.31.0)"] 266 | sentencepiece = ["sentencepiece (==0.1.91)", "protobuf"] 267 | serving = ["pydantic", "uvicorn", "fastapi", "starlette"] 268 | sklearn = ["scikit-learn"] 269 | speech = ["soundfile", "torchaudio"] 270 | testing = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (==21.4b0)", "sacrebleu (>=1.4.12)", "rouge-score", "nltk", "gitpython", "faiss-cpu", "cookiecutter (==1.7.2)"] 271 | tf = ["tensorflow (>=2.3)", "onnxconverter-common", "keras2onnx"] 272 | tf-cpu = ["tensorflow-cpu (>=2.3)", "onnxconverter-common", "keras2onnx"] 273 | timm = ["timm"] 274 | tokenizers = ["tokenizers (>=0.10.1,<0.11)"] 275 | torch = ["torch (>=1.0)"] 276 | torchhub = ["filelock", "huggingface-hub (==0.0.12)", "importlib-metadata", "numpy (>=1.17)", "packaging", "protobuf", "regex (!=2019.12.17)", "requests", "sacremoses", "sentencepiece (==0.1.91)", "torch (>=1.0)", "tokenizers (>=0.10.1,<0.11)", "tqdm (>=4.27)"] 277 | vision = ["pillow"] 278 | 279 | [[package]] 280 | name = "typing-extensions" 281 | version = "3.10.0.0" 282 | description = "Backported and Experimental Type Hints for Python 3.5+" 283 | category = "main" 284 | optional = false 285 | python-versions = "*" 286 | 287 | [[package]] 288 | name = "urllib3" 289 | version = "1.26.6" 290 | description = "HTTP library with thread-safe connection pooling, file post, and more." 
291 | category = "main" 292 | optional = false 293 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" 294 | 295 | [package.extras] 296 | brotli = ["brotlipy (>=0.6.0)"] 297 | secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] 298 | socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] 299 | 300 | [[package]] 301 | name = "zipp" 302 | version = "3.5.0" 303 | description = "Backport of pathlib-compatible object wrapper for zip files" 304 | category = "main" 305 | optional = false 306 | python-versions = ">=3.6" 307 | 308 | [package.extras] 309 | docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] 310 | testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"] 311 | 312 | [metadata] 313 | lock-version = "1.1" 314 | python-versions = "^3.7" 315 | content-hash = "a6db320632a90159e9df5a30f81c02ed4544646f9633f1b62450b015777e544b" 316 | 317 | [metadata.files] 318 | certifi = [ 319 | {file = "certifi-2021.5.30-py2.py3-none-any.whl", hash = "sha256:50b1e4f8446b06f41be7dd6338db18e0990601dce795c2b1686458aa7e8fa7d8"}, 320 | {file = "certifi-2021.5.30.tar.gz", hash = "sha256:2bbf76fd432960138b3ef6dda3dde0544f27cbf8546c458e60baf371917ba9ee"}, 321 | ] 322 | charset-normalizer = [ 323 | {file = "charset-normalizer-2.0.4.tar.gz", hash = "sha256:f23667ebe1084be45f6ae0538e4a5a865206544097e4e8bbcacf42cd02a348f3"}, 324 | {file = "charset_normalizer-2.0.4-py3-none-any.whl", hash = "sha256:0c8911edd15d19223366a194a513099a302055a962bca2cec0f54b8b63175d8b"}, 325 | ] 326 | click = [ 327 | {file = "click-8.0.1-py3-none-any.whl", hash = "sha256:fba402a4a47334742d782209a7c79bc448911afe1149d07bdabdf480b3e2f4b6"}, 328 | {file = "click-8.0.1.tar.gz", hash = "sha256:8c04c11192119b1ef78ea049e0a6f0463e4c48ef00a30160c704337586f3ad7a"}, 329 | ] 330 | colorama = [ 331 | {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, 332 | {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, 333 | ] 334 | filelock = [ 335 | {file = "filelock-3.0.12-py3-none-any.whl", hash = "sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836"}, 336 | {file = "filelock-3.0.12.tar.gz", hash = "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59"}, 337 | ] 338 | huggingface-hub = [ 339 | {file = "huggingface_hub-0.0.12-py3-none-any.whl", hash = "sha256:5c82ff96897a72e1ed48a94c1796686f120dea05888200522f3994f130c12e6a"}, 340 | {file = "huggingface_hub-0.0.12.tar.gz", hash = "sha256:661b17fab0c475276fd71603ee7e16c3b3d1d6e812e1b29f40144f64d361e59d"}, 341 | ] 342 | idna = [ 343 | {file = "idna-3.2-py3-none-any.whl", hash = "sha256:14475042e284991034cb48e06f6851428fb14c4dc953acd9be9a5e95c7b6dd7a"}, 344 | {file = "idna-3.2.tar.gz", hash = "sha256:467fbad99067910785144ce333826c71fb0e63a425657295239737f7ecd125f3"}, 345 | ] 346 | importlib-metadata = [ 347 | {file = "importlib_metadata-4.6.3-py3-none-any.whl", hash = "sha256:51c6635429c77cf1ae634c997ff9e53ca3438b495f10a55ba28594dd69764a8b"}, 348 | {file = "importlib_metadata-4.6.3.tar.gz", hash = "sha256:0645585859e9a6689c523927a5032f2ba5919f1f7d0e84bd4533312320de1ff9"}, 349 | ] 350 | joblib = [ 351 | {file = "joblib-1.0.1-py3-none-any.whl", hash = "sha256:feeb1ec69c4d45129954f1b7034954241eedfd6ba39b5e9e4b6883be3332d5e5"}, 352 | 
{file = "joblib-1.0.1.tar.gz", hash = "sha256:9c17567692206d2f3fb9ecf5e991084254fe631665c450b443761c4186a613f7"}, 353 | ] 354 | numpy = [ 355 | {file = "numpy-1.21.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38e8648f9449a549a7dfe8d8755a5979b45b3538520d1e735637ef28e8c2dc50"}, 356 | {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:fd7d7409fa643a91d0a05c7554dd68aa9c9bb16e186f6ccfe40d6e003156e33a"}, 357 | {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a75b4498b1e93d8b700282dc8e655b8bd559c0904b3910b144646dbbbc03e062"}, 358 | {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1412aa0aec3e00bc23fbb8664d76552b4efde98fb71f60737c83efbac24112f1"}, 359 | {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e46ceaff65609b5399163de5893d8f2a82d3c77d5e56d976c8b5fb01faa6b671"}, 360 | {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:c6a2324085dd52f96498419ba95b5777e40b6bcbc20088fddb9e8cbb58885e8e"}, 361 | {file = "numpy-1.21.1-cp37-cp37m-win32.whl", hash = "sha256:73101b2a1fef16602696d133db402a7e7586654682244344b8329cdcbbb82172"}, 362 | {file = "numpy-1.21.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7a708a79c9a9d26904d1cca8d383bf869edf6f8e7650d85dbc77b041e8c5a0f8"}, 363 | {file = "numpy-1.21.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95b995d0c413f5d0428b3f880e8fe1660ff9396dcd1f9eedbc311f37b5652e16"}, 364 | {file = "numpy-1.21.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:635e6bd31c9fb3d475c8f44a089569070d10a9ef18ed13738b03049280281267"}, 365 | {file = "numpy-1.21.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a3d5fb89bfe21be2ef47c0614b9c9c707b7362386c9a3ff1feae63e0267ccb6"}, 366 | {file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a326af80e86d0e9ce92bcc1e65c8ff88297de4fa14ee936cb2293d414c9ec63"}, 367 | {file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:791492091744b0fe390a6ce85cc1bf5149968ac7d5f0477288f78c89b385d9af"}, 368 | {file = "numpy-1.21.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0318c465786c1f63ac05d7c4dbcecd4d2d7e13f0959b01b534ea1e92202235c5"}, 369 | {file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a513bd9c1551894ee3d31369f9b07460ef223694098cf27d399513415855b68"}, 370 | {file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:91c6f5fc58df1e0a3cc0c3a717bb3308ff850abdaa6d2d802573ee2b11f674a8"}, 371 | {file = "numpy-1.21.1-cp38-cp38-win32.whl", hash = "sha256:978010b68e17150db8765355d1ccdd450f9fc916824e8c4e35ee620590e234cd"}, 372 | {file = "numpy-1.21.1-cp38-cp38-win_amd64.whl", hash = "sha256:9749a40a5b22333467f02fe11edc98f022133ee1bfa8ab99bda5e5437b831214"}, 373 | {file = "numpy-1.21.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d7a4aeac3b94af92a9373d6e77b37691b86411f9745190d2c351f410ab3a791f"}, 374 | {file = "numpy-1.21.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d9e7912a56108aba9b31df688a4c4f5cb0d9d3787386b87d504762b6754fbb1b"}, 375 | {file = "numpy-1.21.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:25b40b98ebdd272bc3020935427a4530b7d60dfbe1ab9381a39147834e985eac"}, 376 | {file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a92c5aea763d14ba9d6475803fc7904bda7decc2a0a68153f587ad82941fec1"}, 
377 | {file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:05a0f648eb28bae4bcb204e6fd14603de2908de982e761a2fc78efe0f19e96e1"}, 378 | {file = "numpy-1.21.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f01f28075a92eede918b965e86e8f0ba7b7797a95aa8d35e1cc8821f5fc3ad6a"}, 379 | {file = "numpy-1.21.1-cp39-cp39-win32.whl", hash = "sha256:88c0b89ad1cc24a5efbb99ff9ab5db0f9a86e9cc50240177a571fbe9c2860ac2"}, 380 | {file = "numpy-1.21.1-cp39-cp39-win_amd64.whl", hash = "sha256:01721eefe70544d548425a07c80be8377096a54118070b8a62476866d5208e33"}, 381 | {file = "numpy-1.21.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d4d1de6e6fb3d28781c73fbde702ac97f03d79e4ffd6598b880b2d95d62ead4"}, 382 | {file = "numpy-1.21.1.zip", hash = "sha256:dff4af63638afcc57a3dfb9e4b26d434a7a602d225b42d746ea7fe2edf1342fd"}, 383 | ] 384 | packaging = [ 385 | {file = "packaging-21.0-py3-none-any.whl", hash = "sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14"}, 386 | {file = "packaging-21.0.tar.gz", hash = "sha256:7dc96269f53a4ccec5c0670940a4281106dd0bb343f47b7471f779df49c2fbe7"}, 387 | ] 388 | pyparsing = [ 389 | {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, 390 | {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, 391 | ] 392 | pyyaml = [ 393 | {file = "PyYAML-5.4.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:3b2b1824fe7112845700f815ff6a489360226a5609b96ec2190a45e62a9fc922"}, 394 | {file = "PyYAML-5.4.1-cp27-cp27m-win32.whl", hash = "sha256:129def1b7c1bf22faffd67b8f3724645203b79d8f4cc81f674654d9902cb4393"}, 395 | {file = "PyYAML-5.4.1-cp27-cp27m-win_amd64.whl", hash = "sha256:4465124ef1b18d9ace298060f4eccc64b0850899ac4ac53294547536533800c8"}, 396 | {file = "PyYAML-5.4.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:bb4191dfc9306777bc594117aee052446b3fa88737cd13b7188d0e7aa8162185"}, 397 | {file = "PyYAML-5.4.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:6c78645d400265a062508ae399b60b8c167bf003db364ecb26dcab2bda048253"}, 398 | {file = "PyYAML-5.4.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:4e0583d24c881e14342eaf4ec5fbc97f934b999a6828693a99157fde912540cc"}, 399 | {file = "PyYAML-5.4.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:72a01f726a9c7851ca9bfad6fd09ca4e090a023c00945ea05ba1638c09dc3347"}, 400 | {file = "PyYAML-5.4.1-cp36-cp36m-manylinux2014_s390x.whl", hash = "sha256:895f61ef02e8fed38159bb70f7e100e00f471eae2bc838cd0f4ebb21e28f8541"}, 401 | {file = "PyYAML-5.4.1-cp36-cp36m-win32.whl", hash = "sha256:3bd0e463264cf257d1ffd2e40223b197271046d09dadf73a0fe82b9c1fc385a5"}, 402 | {file = "PyYAML-5.4.1-cp36-cp36m-win_amd64.whl", hash = "sha256:e4fac90784481d221a8e4b1162afa7c47ed953be40d31ab4629ae917510051df"}, 403 | {file = "PyYAML-5.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5accb17103e43963b80e6f837831f38d314a0495500067cb25afab2e8d7a4018"}, 404 | {file = "PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e1d4970ea66be07ae37a3c2e48b5ec63f7ba6804bdddfdbd3cfd954d25a82e63"}, 405 | {file = "PyYAML-5.4.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cb333c16912324fd5f769fff6bc5de372e9e7a202247b48870bc251ed40239aa"}, 406 | {file = "PyYAML-5.4.1-cp37-cp37m-manylinux2014_s390x.whl", hash = "sha256:fe69978f3f768926cfa37b867e3843918e012cf83f680806599ddce33c2c68b0"}, 407 | {file = 
"PyYAML-5.4.1-cp37-cp37m-win32.whl", hash = "sha256:dd5de0646207f053eb0d6c74ae45ba98c3395a571a2891858e87df7c9b9bd51b"}, 408 | {file = "PyYAML-5.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:08682f6b72c722394747bddaf0aa62277e02557c0fd1c42cb853016a38f8dedf"}, 409 | {file = "PyYAML-5.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d2d9808ea7b4af864f35ea216be506ecec180628aced0704e34aca0b040ffe46"}, 410 | {file = "PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:8c1be557ee92a20f184922c7b6424e8ab6691788e6d86137c5d93c1a6ec1b8fb"}, 411 | {file = "PyYAML-5.4.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:fd7f6999a8070df521b6384004ef42833b9bd62cfee11a09bda1079b4b704247"}, 412 | {file = "PyYAML-5.4.1-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:bfb51918d4ff3d77c1c856a9699f8492c612cde32fd3bcd344af9be34999bfdc"}, 413 | {file = "PyYAML-5.4.1-cp38-cp38-win32.whl", hash = "sha256:fa5ae20527d8e831e8230cbffd9f8fe952815b2b7dae6ffec25318803a7528fc"}, 414 | {file = "PyYAML-5.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:0f5f5786c0e09baddcd8b4b45f20a7b5d61a7e7e99846e3c799b05c7c53fa696"}, 415 | {file = "PyYAML-5.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:294db365efa064d00b8d1ef65d8ea2c3426ac366c0c4368d930bf1c5fb497f77"}, 416 | {file = "PyYAML-5.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:74c1485f7707cf707a7aef42ef6322b8f97921bd89be2ab6317fd782c2d53183"}, 417 | {file = "PyYAML-5.4.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d483ad4e639292c90170eb6f7783ad19490e7a8defb3e46f97dfe4bacae89122"}, 418 | {file = "PyYAML-5.4.1-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:fdc842473cd33f45ff6bce46aea678a54e3d21f1b61a7750ce3c498eedfe25d6"}, 419 | {file = "PyYAML-5.4.1-cp39-cp39-win32.whl", hash = "sha256:49d4cdd9065b9b6e206d0595fee27a96b5dd22618e7520c33204a4a3239d5b10"}, 420 | {file = "PyYAML-5.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:c20cfa2d49991c8b4147af39859b167664f2ad4561704ee74c1de03318e898db"}, 421 | {file = "PyYAML-5.4.1.tar.gz", hash = "sha256:607774cbba28732bfa802b54baa7484215f530991055bb562efbed5b2f20a45e"}, 422 | ] 423 | regex = [ 424 | {file = "regex-2021.7.6-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e6a1e5ca97d411a461041d057348e578dc344ecd2add3555aedba3b408c9f874"}, 425 | {file = "regex-2021.7.6-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:6afe6a627888c9a6cfbb603d1d017ce204cebd589d66e0703309b8048c3b0854"}, 426 | {file = "regex-2021.7.6-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:ccb3d2190476d00414aab36cca453e4596e8f70a206e2aa8db3d495a109153d2"}, 427 | {file = "regex-2021.7.6-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:ed693137a9187052fc46eedfafdcb74e09917166362af4cc4fddc3b31560e93d"}, 428 | {file = "regex-2021.7.6-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:99d8ab206a5270c1002bfcf25c51bf329ca951e5a169f3b43214fdda1f0b5f0d"}, 429 | {file = "regex-2021.7.6-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:b85ac458354165405c8a84725de7bbd07b00d9f72c31a60ffbf96bb38d3e25fa"}, 430 | {file = "regex-2021.7.6-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:3f5716923d3d0bfb27048242a6e0f14eecdb2e2a7fac47eda1d055288595f222"}, 431 | {file = "regex-2021.7.6-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5983c19d0beb6af88cb4d47afb92d96751fb3fa1784d8785b1cdf14c6519407"}, 432 | {file = "regex-2021.7.6-cp36-cp36m-win32.whl", hash = "sha256:c92831dac113a6e0ab28bc98f33781383fe294df1a2c3dfd1e850114da35fd5b"}, 433 | {file = "regex-2021.7.6-cp36-cp36m-win_amd64.whl", hash = 
"sha256:791aa1b300e5b6e5d597c37c346fb4d66422178566bbb426dd87eaae475053fb"}, 434 | {file = "regex-2021.7.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:59506c6e8bd9306cd8a41511e32d16d5d1194110b8cfe5a11d102d8b63cf945d"}, 435 | {file = "regex-2021.7.6-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:564a4c8a29435d1f2256ba247a0315325ea63335508ad8ed938a4f14c4116a5d"}, 436 | {file = "regex-2021.7.6-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:59c00bb8dd8775473cbfb967925ad2c3ecc8886b3b2d0c90a8e2707e06c743f0"}, 437 | {file = "regex-2021.7.6-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:9a854b916806c7e3b40e6616ac9e85d3cdb7649d9e6590653deb5b341a736cec"}, 438 | {file = "regex-2021.7.6-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:db2b7df831c3187a37f3bb80ec095f249fa276dbe09abd3d35297fc250385694"}, 439 | {file = "regex-2021.7.6-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:173bc44ff95bc1e96398c38f3629d86fa72e539c79900283afa895694229fe6a"}, 440 | {file = "regex-2021.7.6-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:15dddb19823f5147e7517bb12635b3c82e6f2a3a6b696cc3e321522e8b9308ad"}, 441 | {file = "regex-2021.7.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ddeabc7652024803666ea09f32dd1ed40a0579b6fbb2a213eba590683025895"}, 442 | {file = "regex-2021.7.6-cp37-cp37m-win32.whl", hash = "sha256:f080248b3e029d052bf74a897b9d74cfb7643537fbde97fe8225a6467fb559b5"}, 443 | {file = "regex-2021.7.6-cp37-cp37m-win_amd64.whl", hash = "sha256:d8bbce0c96462dbceaa7ac4a7dfbbee92745b801b24bce10a98d2f2b1ea9432f"}, 444 | {file = "regex-2021.7.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:edd1a68f79b89b0c57339bce297ad5d5ffcc6ae7e1afdb10f1947706ed066c9c"}, 445 | {file = "regex-2021.7.6-cp38-cp38-manylinux1_i686.whl", hash = "sha256:422dec1e7cbb2efbbe50e3f1de36b82906def93ed48da12d1714cabcd993d7f0"}, 446 | {file = "regex-2021.7.6-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cbe23b323988a04c3e5b0c387fe3f8f363bf06c0680daf775875d979e376bd26"}, 447 | {file = "regex-2021.7.6-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:0eb2c6e0fcec5e0f1d3bcc1133556563222a2ffd2211945d7b1480c1b1a42a6f"}, 448 | {file = "regex-2021.7.6-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:1c78780bf46d620ff4fff40728f98b8afd8b8e35c3efd638c7df67be2d5cddbf"}, 449 | {file = "regex-2021.7.6-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:bc84fb254a875a9f66616ed4538542fb7965db6356f3df571d783f7c8d256edd"}, 450 | {file = "regex-2021.7.6-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:598c0a79b4b851b922f504f9f39a863d83ebdfff787261a5ed061c21e67dd761"}, 451 | {file = "regex-2021.7.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:875c355360d0f8d3d827e462b29ea7682bf52327d500a4f837e934e9e4656068"}, 452 | {file = "regex-2021.7.6-cp38-cp38-win32.whl", hash = "sha256:e586f448df2bbc37dfadccdb7ccd125c62b4348cb90c10840d695592aa1b29e0"}, 453 | {file = "regex-2021.7.6-cp38-cp38-win_amd64.whl", hash = "sha256:2fe5e71e11a54e3355fa272137d521a40aace5d937d08b494bed4529964c19c4"}, 454 | {file = "regex-2021.7.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6110bab7eab6566492618540c70edd4d2a18f40ca1d51d704f1d81c52d245026"}, 455 | {file = "regex-2021.7.6-cp39-cp39-manylinux1_i686.whl", hash = "sha256:4f64fc59fd5b10557f6cd0937e1597af022ad9b27d454e182485f1db3008f417"}, 456 | {file = "regex-2021.7.6-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:89e5528803566af4df368df2d6f503c84fbfb8249e6631c7b025fe23e6bd0cde"}, 457 | {file = 
"regex-2021.7.6-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:2366fe0479ca0e9afa534174faa2beae87847d208d457d200183f28c74eaea59"}, 458 | {file = "regex-2021.7.6-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:f9392a4555f3e4cb45310a65b403d86b589adc773898c25a39184b1ba4db8985"}, 459 | {file = "regex-2021.7.6-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:2bceeb491b38225b1fee4517107b8491ba54fba77cf22a12e996d96a3c55613d"}, 460 | {file = "regex-2021.7.6-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:f98dc35ab9a749276f1a4a38ab3e0e2ba1662ce710f6530f5b0a6656f1c32b58"}, 461 | {file = "regex-2021.7.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:319eb2a8d0888fa6f1d9177705f341bc9455a2c8aca130016e52c7fe8d6c37a3"}, 462 | {file = "regex-2021.7.6-cp39-cp39-win32.whl", hash = "sha256:eaf58b9e30e0e546cdc3ac06cf9165a1ca5b3de8221e9df679416ca667972035"}, 463 | {file = "regex-2021.7.6-cp39-cp39-win_amd64.whl", hash = "sha256:4c9c3155fe74269f61e27617529b7f09552fbb12e44b1189cebbdb24294e6e1c"}, 464 | {file = "regex-2021.7.6.tar.gz", hash = "sha256:8394e266005f2d8c6f0bc6780001f7afa3ef81a7a2111fa35058ded6fce79e4d"}, 465 | ] 466 | requests = [ 467 | {file = "requests-2.26.0-py2.py3-none-any.whl", hash = "sha256:6c1246513ecd5ecd4528a0906f910e8f0f9c6b8ec72030dc9fd154dc1a6efd24"}, 468 | {file = "requests-2.26.0.tar.gz", hash = "sha256:b8aa58f8cf793ffd8782d3d8cb19e66ef36f7aba4353eec859e74678b01b07a7"}, 469 | ] 470 | sacremoses = [ 471 | {file = "sacremoses-0.0.45-py3-none-any.whl", hash = "sha256:fa93db44bc04542553ba6090818b892f603d02aa0d681e6c5c3023baf17e8564"}, 472 | {file = "sacremoses-0.0.45.tar.gz", hash = "sha256:58176cc28391830789b763641d0f458819bebe88681dac72b41a19c0aedc07e9"}, 473 | ] 474 | six = [ 475 | {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, 476 | {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, 477 | ] 478 | tokenizers = [ 479 | {file = "tokenizers-0.10.3-cp36-cp36m-macosx_10_11_x86_64.whl", hash = "sha256:4ab688daf4692a6c31dfe42f1f3a4a8c22050705eb69d58d3efde9d55f434586"}, 480 | {file = "tokenizers-0.10.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c26dbc3b2a3d71d3d40c50975ec62145932f05aea73f03ea35c48ebd3a717611"}, 481 | {file = "tokenizers-0.10.3-cp36-cp36m-win32.whl", hash = "sha256:6b84673997990b3c260ae2f7c57fdf1f835e316820eff14aca46dc68be3c0c74"}, 482 | {file = "tokenizers-0.10.3-cp36-cp36m-win_amd64.whl", hash = "sha256:2a9ee3ee574d4aa740e099b0ad6ef8e63f52f48cde359bb31801146a5aa614dc"}, 483 | {file = "tokenizers-0.10.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:2f8c5fefef0d0a03be613547e613fbda06b9e6ee0891236649524964c3e54d80"}, 484 | {file = "tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4cc194104c8e427ffc4f54c7866488b42f2b1f6351a6cad0d045ca5ab8108e42"}, 485 | {file = "tokenizers-0.10.3-cp37-cp37m-win32.whl", hash = "sha256:edd8cb85c16b4b65e87ea5ef9d400be9fdd53c4152adbaca8817e16dd3aa480b"}, 486 | {file = "tokenizers-0.10.3-cp37-cp37m-win_amd64.whl", hash = "sha256:7b11b373705d082d43657c08883b79b5330f1952f0668d17488b6b889c4d7feb"}, 487 | {file = "tokenizers-0.10.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:a7ce0c2f27f7c92aa3f895231de90319acdf960ce2e42ba591edc651fda7d3c9"}, 488 | {file = 
"tokenizers-0.10.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ae7e40d9c8a77c5a4109731ac3e21633b0c609c56a8b58be6b863da61fa54636"}, 489 | {file = "tokenizers-0.10.3-cp38-cp38-win32.whl", hash = "sha256:a7ce051aafc53c564c9edbc09df300c2bd4f6ce87460fc22a276fed405d1892a"}, 490 | {file = "tokenizers-0.10.3-cp38-cp38-win_amd64.whl", hash = "sha256:91a8c045980594c7c437a52c3da5276eb3c530a662b4ef628ff32d81fb22b543"}, 491 | {file = "tokenizers-0.10.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:1d8867db210d75d97312360ae23b92aeb6a6b5bc65e15c1cd9d204b3fa3fc262"}, 492 | {file = "tokenizers-0.10.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:18c495e700f4588b9a00e58b4c41dc459c36daaa7c39a27faf880eb8f5533ce1"}, 493 | {file = "tokenizers-0.10.3-cp39-cp39-win32.whl", hash = "sha256:ad700fd9da518884fd58bf89f0b6dfeecef9b4e2d2db8765ef259f66d6c14980"}, 494 | {file = "tokenizers-0.10.3-cp39-cp39-win_amd64.whl", hash = "sha256:e9d147e545cdfeca560646c7a703bf287afe45645da426506ccd5eb78aab5ef5"}, 495 | {file = "tokenizers-0.10.3.tar.gz", hash = "sha256:1a5d3b596c6d3a237e1ad7f46c472d467b0246be7fd1a364f12576eb8db8f7e6"}, 496 | ] 497 | torch = [ 498 | {file = "torch-1.9.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:3a2d070cf28860d285d4ab156f3954c0c1d12f4c037aa312a7c029227c0d106b"}, 499 | {file = "torch-1.9.0-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:b296e65e25081af147af936f1e3a1f17f583a9afacfa5309742678ffef728ace"}, 500 | {file = "torch-1.9.0-cp36-cp36m-win_amd64.whl", hash = "sha256:117098d4924b260a24a47c6b3fe37f2ae41f04a2ea2eff9f553ae9210b12fa54"}, 501 | {file = "torch-1.9.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:d6103b9a634993bd967337a1149f9d8b23922f42a3660676239399e15c1b4515"}, 502 | {file = "torch-1.9.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:0164673908e6b291ace592d382eba3e258b3bad009b8078cad8f3b9e00d8f23e"}, 503 | {file = "torch-1.9.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:52548b45efff772fe3810fe91daf34f981ac0ca1a7227f6226fd5693f53b5b88"}, 504 | {file = "torch-1.9.0-cp37-cp37m-win_amd64.whl", hash = "sha256:62c0a7e433681d0861494d1ede96d2485e4dbb3ea8fd867e8419addebf5de1af"}, 505 | {file = "torch-1.9.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:d88333091fd1627894bbf0d6dcef58a90e36bdf0d90a5d4675b5e07e72075511"}, 506 | {file = "torch-1.9.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:1d8139dcc864f48dc316376384f50e47a459284ad1cb84449242f4964e25aaec"}, 507 | {file = "torch-1.9.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:0aa4cca3f16fab40cb8dae6a49d0eccdc8f4ead9d1a6428cd9ba12befe082b2a"}, 508 | {file = "torch-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:646de1bef85d6c7590e98f8ea52e47acdcf58330982e4f5d73f5ca28dea2d552"}, 509 | {file = "torch-1.9.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:e596f0105f748cf09d4763152d8157aaf58d5231232eaf2c5673d4562ba86ad3"}, 510 | {file = "torch-1.9.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:ecc7193fff7741ced3db1f760666c8454d6664956288c54d1b49613b987a42f4"}, 511 | {file = "torch-1.9.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:95eeec3a6c42fd35aca552777b7d9979ed489760423de97c0118a45e849a61f4"}, 512 | {file = "torch-1.9.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8a2b2012b3c7d6019e189496688fa77de7029a220840b406d8302d1c8021a11c"}, 513 | {file = "torch-1.9.0-cp39-cp39-win_amd64.whl", hash = 
"sha256:7e2b14fe5b3a8266cbe2f6740c0195497507974ced7bc21e99971561913a0c28"}, 514 | {file = "torch-1.9.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:0a9e74b5057463ce4e55d9332a5670993fc9e1299c52e1740e505eda106fb355"}, 515 | {file = "torch-1.9.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:569ead6ae6bb0e636df0fc8af660ef03260e630dc5f2f4cf3198027e7b6bb481"}, 516 | ] 517 | tqdm = [ 518 | {file = "tqdm-4.62.0-py2.py3-none-any.whl", hash = "sha256:706dea48ee05ba16e936ee91cb3791cd2ea6da348a0e50b46863ff4363ff4340"}, 519 | {file = "tqdm-4.62.0.tar.gz", hash = "sha256:3642d483b558eec80d3c831e23953582c34d7e4540db86d9e5ed9dad238dabc6"}, 520 | ] 521 | transformers = [ 522 | {file = "transformers-4.9.1-py3-none-any.whl", hash = "sha256:86f3c46efecf114c6886d361c1d6cca14738f0e9d1effadb1e9252770cba55a0"}, 523 | {file = "transformers-4.9.1.tar.gz", hash = "sha256:1c30e38b2e0da15e110d9bb9a627f78de9569b9c6036d6533baf783015c339be"}, 524 | ] 525 | typing-extensions = [ 526 | {file = "typing_extensions-3.10.0.0-py2-none-any.whl", hash = "sha256:0ac0f89795dd19de6b97debb0c6af1c70987fd80a2d62d1958f7e56fcc31b497"}, 527 | {file = "typing_extensions-3.10.0.0-py3-none-any.whl", hash = "sha256:779383f6086d90c99ae41cf0ff39aac8a7937a9283ce0a414e5dd782f4c94a84"}, 528 | {file = "typing_extensions-3.10.0.0.tar.gz", hash = "sha256:50b6f157849174217d0656f99dc82fe932884fb250826c18350e159ec6cdf342"}, 529 | ] 530 | urllib3 = [ 531 | {file = "urllib3-1.26.6-py2.py3-none-any.whl", hash = "sha256:39fb8672126159acb139a7718dd10806104dec1e2f0f6c88aab05d17df10c8d4"}, 532 | {file = "urllib3-1.26.6.tar.gz", hash = "sha256:f57b4c16c62fa2760b7e3d97c35b255512fb6b59a259730f36ba32ce9f8e342f"}, 533 | ] 534 | zipp = [ 535 | {file = "zipp-3.5.0-py3-none-any.whl", hash = "sha256:957cfda87797e389580cb8b9e3870841ca991e2125350677b2ca83a0e99390a3"}, 536 | {file = "zipp-3.5.0.tar.gz", hash = "sha256:f5812b1e007e48cff63449a5e9f4e7ebea716b4111f9c4f9a645f91d579bf0c4"}, 537 | ] 538 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "defsent" 3 | version = "0.1.0" 4 | description = "DefSent: Sentence Embeddings using Definition Sentences" 5 | authors = ["hppRC "] 6 | readme = "README.md" 7 | homepage = "https://arxiv.org/abs/2105.04339" 8 | repository = "https://github.com/hppRC/defsent" 9 | 10 | include = ["defsent/**/*"] 11 | exclude = ["experiments/**/*", "examples/**/*"] 12 | 13 | [tool.poetry.dependencies] 14 | python = "^3.7" 15 | transformers = "*" 16 | torch = "*" 17 | 18 | [tool.poetry.dev-dependencies] 19 | pysen = {version = "^0.9.1", extras = ["lint"]} 20 | pytest = "^5.2" 21 | 22 | [build-system] 23 | requires = ["poetry-core>=1.0.0"] 24 | build-backend = "poetry.core.masonry.api" 25 | 26 | [tool.pysen] 27 | version = "0.9" 28 | 29 | [tool.pysen.lint] 30 | enable_black = true 31 | enable_flake8 = true 32 | enable_isort = true 33 | enable_mypy = false 34 | mypy_preset = "strict" 35 | py_version = "py37" 36 | [[tool.pysen.lint.mypy_targets]] 37 | paths = ["."] 38 | --------------------------------------------------------------------------------