├── .github └── workflows │ └── docs.yml ├── .gitignore ├── LICENSE ├── README.md ├── banner.png ├── examples ├── classification │ ├── NotoSansDisplay-Regular.ttf │ ├── rnn.py │ └── rnn.yaml └── translation │ ├── NotoSansDisplay-Regular.ttf │ ├── nmt.yaml │ └── nmt_transformer.py ├── poetry.lock ├── pyproject.toml ├── tests ├── hash.py └── visual.py └── text_embeddings ├── __init__.py ├── base └── __init__.py ├── byte ├── __init__.py ├── byt5.py └── charformer.py ├── hash ├── __init__.py ├── canine.py ├── pqrnn.py └── util.py ├── pruning ├── __init__.py └── ltp.py ├── visual ├── __init__.py └── vtr.py └── x └── __init__.py /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: github pages 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-22.04 11 | steps: 12 | - uses: actions/checkout@v4 13 | 14 | - name: Setup Python 15 | uses: actions/setup-python@v5 16 | with: 17 | python-version: '3.8' 18 | 19 | - name: Upgrade pip 20 | run: | 21 | # install pip>=20.1 to use "pip cache dir" 22 | python3 -m pip install --upgrade pip 23 | 24 | - name: Get pip cache dir 25 | id: pip-cache 26 | run: echo "dir=$(pip cache dir)" >> "$GITHUB_OUTPUT" 27 | 28 | - name: Cache dependencies 29 | uses: actions/cache@v4 30 | with: 31 | path: ${{ steps.pip-cache.outputs.dir }} 32 | key: ${{ runner.os }}-pip-${{ hashFiles('**/poetry.lock') }} # requirements.txt is only generated later in this job 33 | restore-keys: | 34 | ${{ runner.os }}-pip- 35 | 36 | - name: Install dependencies 37 | run: python3 -m pip install poetry; poetry export -o requirements.txt --without-hashes 38 | 39 | - name: Install rest of the dependencies 40 | run: python3 -m pip install -r requirements.txt 41 | 42 | - run: pdoc --html text_embeddings --output-dir docs 43 | 44 | - name: Deploy 45 | uses: peaceiris/actions-gh-pages@v3 46 | with: 47 | github_token: ${{ secrets.GITHUB_TOKEN }} 48 | publish_dir: ./docs/text_embeddings 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # VSCode 2 | .vscode/settings.json 3 | *logs 4 | .coverage 5 | .data 6 | Noto_Sans 7 | flores_test_sets 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | cobertura.xml 140 | .coverage -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Chenghao MOU 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![banner](./banner.png) 2 | [![PyPI version](https://badge.fury.io/py/text-embeddings.svg)](https://badge.fury.io/py/text-embeddings) [![Codacy Badge](https://app.codacy.com/project/badge/Grade/112e50abd97444a4aca06f94fb7e8873)](https://www.codacy.com/gh/ChenghaoMou/embeddings/dashboard?utm_source=github.com&utm_medium=referral&utm_content=ChenghaoMou/embeddings&utm_campaign=Badge_Grade) [![Codacy Badge](https://app.codacy.com/project/badge/Coverage/112e50abd97444a4aca06f94fb7e8873)](https://www.codacy.com/gh/ChenghaoMou/embeddings/dashboard?utm_source=github.com&utm_medium=referral&utm_content=ChenghaoMou/embeddings&utm_campaign=Badge_Coverage) 3 | 4 | ## Features 5 | 6 | - [x] `VTRTokenizer` from [Robust Open-Vocabulary Translation from Visual Text Representations](https://t.co/l9E6rL8O5p?amp=1) 7 | - [x] `PQRNNTokenizer` from [Advancing NLP with Efficient Projection-Based Model Architectures](https://ai.googleblog.com/2020/09/advancing-nlp-with-efficient-projection.html) 8 | - [x] `CANINETokenizer` from [CANINE: Pre-training an Efficient Tokenization-Free Encoder for Language Representation](https://arxiv.org/abs/2103.06874) 9 | - [x] `ByT5Tokenizer` from [ByT5: Towards a token-free future with pre-trained byte-to-byte models](https://arxiv.org/pdf/2105.13626.pdf) 10 | - [x] `GBST` and `ByteTokenizer` from [Charformer: Fast Character Transformers via Gradient-based Subword Tokenization](https://arxiv.org/abs/2106.12672) 11 | - [x] `LTPMultiHeadAttention` from [Learned Token Pruning for 
Transformers](https://arxiv.org/abs/2107.00910) 12 | - [x] `X` and `XLoss`, a model inspired from [PonderNet](https://arxiv.org/abs/2107.05407) and [Perceiver](https://arxiv.org/abs/2103.03206), with Byte Embeddings. 13 | 14 | ## Examples 15 | 16 | - [x] [Machine Translation](examples/translation/nmt_transformer.py) 17 | - [x] [Text Classification](examples/classification/rnn.py) 18 | 19 | ## Installation 20 | 21 | ```bash 22 | pip install text-embeddings --upgrade 23 | ``` 24 | 25 | ## Documentation 26 | 27 | [Link](https://chenghaomou.github.io/embeddings/) 28 | 29 | ## Example Usage 30 | 31 | ```python 32 | from text_embeddings.visual import VTRTokenizer 33 | 34 | data = [ 35 | "Hello world!", 36 | "¡Hola Mundo!", 37 | "你好,世界!", 38 | ] 39 | 40 | tokenizer = VTRTokenizer( 41 | font_size=14, 42 | window_size=10, 43 | font="resources/NotoSans-Regular.ttf", 44 | max_length=36 45 | ) 46 | 47 | results = tokenizer( 48 | text=data, 49 | text_pair=data, 50 | add_special_tokens=True, 51 | padding="longest", 52 | return_tensors='pt', 53 | truncation="longest_first", 54 | return_attention_mask=True, 55 | return_special_tokens_mask=True, 56 | return_length=True, 57 | prepend_batch_axis=True, 58 | return_overflowing_tokens=False, 59 | ) 60 | 61 | assert results["input_ids"].shape == (3, results["input_ids"].shape[1], 14, 10) 62 | assert results["attention_mask"].shape == (3, results["input_ids"].shape[1]) 63 | assert results["token_type_ids"].shape == (3, results["input_ids"].shape[1]) 64 | assert results["length"].shape == (3, ) 65 | ``` 66 | 67 | ## Write Your Own Embedding Tokenizer 68 | 69 | ```python 70 | import numpy as np 71 | from typing import Optional, List, Dict 72 | from text_embeddings.base import EmbeddingTokenizer 73 | 74 | 75 | class MyOwnTokenizer(EmbeddingTokenizer): 76 | 77 | def __init__( 78 | self, 79 | model_input_names: Optional[List[str]] = None, 80 | special_tokens: Optional[Dict[str, np.ndarray]] = None, 81 | max_length: Optional[int] = 2048, 82 | 
): 83 | super().__init__(model_input_names, special_tokens, max_length) 84 | 85 | def text2embeddings(self, text: str) -> np.ndarray: 86 | 87 | sequence_length = 10 88 | dimensions = (10, 10, 10) # each token is mapped to a 3-d array 89 | return np.zeros((sequence_length, *dimensions)) 90 | 91 | def create_padding_token_embedding(self, input_embeddings=None) -> np.ndarray: 92 | 93 | # let's create a consistent 3-d array 94 | return np.zeros((10, 10, 10)) 95 | 96 | ``` 97 | 98 | ## Example Usage for GBST 99 | 100 | ```python 101 | import torch.onnx # nightly torch only 102 | from text_embeddings.byte.charformer import GBST, ByteTokenizer 103 | 104 | model = GBST( 105 | embed_size=128, 106 | max_block_size=4, 107 | downsampling_factor=2, 108 | score_calibration=True, 109 | vocab_size=259, 110 | ) 111 | 112 | tokenizer = ByteTokenizer() 113 | results = tokenizer( 114 | ["Life is like a box of chocolates.", "Coding is fun."], 115 | add_special_tokens=True, 116 | padding="longest", 117 | truncation="longest_first", 118 | ) 119 | 120 | # Export the model 121 | torch.onnx.export( 122 | model, 123 | torch.tensor(results["input_ids"], requires_grad=True).long(), 124 | "gbst.onnx", 125 | export_params=True, 126 | opset_version=11, 127 | do_constant_folding=True, 128 | input_names=["input"], 129 | output_names=["output"], 130 | dynamic_axes={ 131 | "input": {0: "batch_size", 1: "sequence_length"}, 132 | "output": {0: "batch_size"}, 133 | }, 134 | ) 135 | ``` 136 | -------------------------------------------------------------------------------- /banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenghaoMou/embeddings/d414ebfbfc50eb218727ceac207035bf731c4701/banner.png -------------------------------------------------------------------------------- /examples/classification/NotoSansDisplay-Regular.ttf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChenghaoMou/embeddings/d414ebfbfc50eb218727ceac207035bf731c4701/examples/classification/NotoSansDisplay-Regular.ttf -------------------------------------------------------------------------------- /examples/classification/rnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-05-22 14:50:34 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | import torch 7 | import numpy as np 8 | import torch.nn as nn 9 | import pytorch_lightning as pl 10 | from datasets import load_dataset 11 | from torch.utils.data import DataLoader 12 | from typing import Optional 13 | from text_embeddings.visual import VTRTokenizer 14 | from einops import rearrange 15 | 16 | 17 | class Model(pl.LightningModule): # 2-layer bidirectional GRU over flattened VTR slices, mean-pooled, then ReLU + linear classifier 18 | def __init__( 19 | self, hidden: int = 128, learning_rate: float = 1e-3, num_labels: int = 20 20 | ): 21 | super().__init__() 22 | self.model = nn.GRU( 23 | hidden, hidden, num_layers=2, bidirectional=True, batch_first=True 24 | ) 25 | self.nonlinear = nn.ReLU() 26 | self.fc = nn.Linear(hidden * 2, num_labels) 27 | self.loss = nn.CrossEntropyLoss() # no ignore_index: every label in [0, num_labels) is a real class 28 | self.lr = learning_rate 29 | 30 | def forward(self, batch): # returns (batch, num_labels) logits 31 | 32 | embeddings = batch["input_ids"].float() 33 | logits, _ = self.model(rearrange(embeddings, "b s h w -> b s (h w)")) # GRU input size `hidden` must equal h * w 34 | logits = torch.cat( # re-joins the two direction halves in their original order, i.e. currently a no-op 35 | [ 36 | logits[:, :, : logits.shape[-1] // 2], 37 | logits[:, :, logits.shape[-1] // 2 :], 38 | ], 39 | dim=-1, 40 | ) 41 | logits = torch.mean(logits, dim=1) # NOTE(review): padded time steps are averaged in too; attention_mask is not used 42 | logits = self.nonlinear(logits) 43 | logits = self.fc(logits) 44 | 45 | return logits 46 | 47 | def training_step(self, batch, batch_idx): 48 | 49 | inputs, labels = batch 50 | logits = self.forward(inputs) 51 | return {"loss": self.loss(logits, labels)} 52 | 53 | def validation_step(self, batch, batch_idx): 54 | inputs, labels = batch 55 | logits = self.forward(inputs) 56 | # 
logger.debug(f"{labels.shape, logits.shape}") 57 | loss = self.loss(logits, labels) 58 | 59 | self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True) 60 | return {"val_loss": loss} 61 | 62 | def configure_optimizers(self): 63 | 64 | return {"optimizer": torch.optim.Adam(self.parameters(), lr=self.lr)} 65 | 66 | 67 | class DataModule(pl.LightningDataModule): # loads a `datasets` dataset and renders its text with VTRTokenizer 68 | def __init__( 69 | self, 70 | dataset_name: str, 71 | font_path="/home/chenghaomou/embeddings/Noto_Sans/NotoSans-Regular.ttf", # NOTE(review): author-local path; rnn.yaml passes a repo-relative font instead 72 | font_size: int = 16, 73 | window_size: int = 8, 74 | stride: int = 5, 75 | batch_size: int = 8, 76 | subtask: Optional[str] = None, 77 | ): 78 | super().__init__() 79 | self.dataset = ( 80 | load_dataset(dataset_name, subtask) 81 | if subtask 82 | else load_dataset(dataset_name) 83 | ) 84 | self.tokenizer = VTRTokenizer( 85 | font=font_path, window_size=window_size, font_size=font_size, stride=stride 86 | ) 87 | self.batch_size = batch_size 88 | 89 | def setup(self, stage=None): 90 | self.train = self.dataset["train"] 91 | self.val = self.dataset["test"] # NOTE(review): the "test" split doubles as the validation / early-stopping set 92 | 93 | def train_dataloader(self): 94 | return DataLoader( 95 | [{"text": x["text"], "label": x["label"]} for x in self.train], 96 | batch_size=self.batch_size, 97 | collate_fn=self.collate_fn, 98 | num_workers=4, shuffle=True, # reshuffle the training examples every epoch 99 | ) 100 | 101 | def val_dataloader(self): 102 | return DataLoader( 103 | [{"text": x["text"], "label": x["label"]} for x in self.val], 104 | batch_size=self.batch_size, 105 | collate_fn=self.collate_fn, 106 | num_workers=4, 107 | ) 108 | 109 | def collate_fn(self, examples): # -> (tokenizer output, LongTensor of labels) 110 | 111 | text = [e["text"] for e in examples] 112 | labels = [e["label"] for e in examples] 113 | 114 | results = self.tokenizer( 115 | text, 116 | return_tensors="pt", 117 | padding="longest", 118 | truncation="longest_first", 119 | return_attention_mask=True, 120 | return_token_type_ids=False, 121 | ) 122 | return results, torch.from_numpy(np.asarray(labels)).long() 123 | 124 | 125 | if __name__ == "__main__": 126 | 127 | 
from pytorch_lightning.utilities.cli import LightningCLI 128 | 129 | cli = LightningCLI(Model, DataModule) 130 | -------------------------------------------------------------------------------- /examples/classification/rnn.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | trainer: 3 | logger: true 4 | checkpoint_callback: true 5 | callbacks: 6 | - class_path: pytorch_lightning.callbacks.EarlyStopping 7 | init_args: 8 | patience: 5 9 | monitor: val_loss 10 | stopping_threshold: 1.0e-4 11 | divergence_threshold: 9.0 12 | check_finite: true 13 | gradient_clip_val: 1.0 14 | gradient_clip_algorithm: norm 15 | process_position: 0 16 | num_nodes: 1 17 | num_processes: 1 18 | gpus: 1 19 | progress_bar_refresh_rate: 50 20 | check_val_every_n_epoch: 1 21 | fast_dev_run: false 22 | accumulate_grad_batches: 1 23 | max_epochs: 5 24 | min_epochs: 1 25 | limit_train_batches: 1.0 26 | limit_val_batches: 1.0 27 | limit_test_batches: 1.0 28 | limit_predict_batches: 1.0 29 | val_check_interval: 1.0 30 | flush_logs_every_n_steps: 100 31 | log_every_n_steps: 50 32 | precision: 16 33 | weights_summary: top 34 | deterministic: true 35 | terminate_on_nan: true 36 | amp_backend: native 37 | amp_level: O2 38 | model: 39 | hidden: 128 40 | learning_rate: 1.0e-4 41 | num_labels: 20 42 | data: 43 | dataset_name: tweet_eval 44 | font_path: "examples/classification/NotoSansDisplay-Regular.ttf" 45 | font_size: 16 46 | window_size: 8 47 | stride: 5 48 | batch_size: 64 49 | subtask: emoji 50 | -------------------------------------------------------------------------------- /examples/translation/NotoSansDisplay-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenghaoMou/embeddings/d414ebfbfc50eb218727ceac207035bf731c4701/examples/translation/NotoSansDisplay-Regular.ttf -------------------------------------------------------------------------------- 
/examples/translation/nmt.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | logger: true 4 | checkpoint_callback: true 5 | callbacks: 6 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 7 | init_args: 8 | logging_interval: step 9 | default_root_dir: null 10 | gradient_clip_val: 1.0 11 | gradient_clip_algorithm: norm 12 | process_position: 0 13 | num_nodes: 1 14 | num_processes: 1 15 | gpus: 1 16 | auto_select_gpus: false 17 | tpu_cores: null 18 | log_gpu_memory: null 19 | progress_bar_refresh_rate: null 20 | overfit_batches: 0.0 21 | track_grad_norm: -1 22 | check_val_every_n_epoch: 1 23 | fast_dev_run: false 24 | accumulate_grad_batches: 1 25 | max_epochs: null 26 | min_epochs: null 27 | max_steps: null 28 | min_steps: null 29 | max_time: null 30 | limit_train_batches: 1.0 31 | limit_val_batches: 1.0 32 | limit_test_batches: 1.0 33 | limit_predict_batches: 1.0 34 | val_check_interval: 1.0 35 | flush_logs_every_n_steps: 100 36 | log_every_n_steps: 50 37 | accelerator: null 38 | sync_batchnorm: false 39 | precision: 16 40 | weights_summary: top 41 | weights_save_path: null 42 | num_sanity_val_steps: 2 43 | truncated_bptt_steps: null 44 | resume_from_checkpoint: null 45 | profiler: null 46 | benchmark: false 47 | deterministic: false 48 | reload_dataloaders_every_epoch: false 49 | auto_lr_find: false 50 | replace_sampler_ddp: true 51 | terminate_on_nan: false 52 | auto_scale_batch_size: false 53 | prepare_data_per_node: true 54 | plugins: null 55 | amp_backend: native 56 | amp_level: O2 57 | distributed_backend: null 58 | move_metrics_to_cpu: false 59 | multiple_trainloader_mode: max_size_cycle 60 | stochastic_weight_avg: false 61 | model: 62 | vocab_size: 30004 # = data.vocab_size + 4 special ids ([pad], [oov], [bos], [eos]) 63 | d_model: 128 # must equal data.font_size * data.window_size (16 * 8) 64 | nhead: 8 65 | num_encoder_layers: 6 66 | num_decoder_layers: 6 67 | dim_feedforward: 2048 68 | max_seq_length: 512 69 | pos_dropout: 0.3 70 | trans_dropout: 0.3 71 | warmup_steps: 4000 72 | lr: 
1.0e-3 73 | data: 74 | font_path: examples/translation/NotoSansDisplay-Regular.ttf 75 | font_size: 16 76 | window_size: 8 77 | stride: 4 78 | test_size: 0.2 79 | max_seq_length: 512 80 | vocab_size: 30000 # model.vocab_size must be at least this + 4 special ids 81 | batch_size: 4096 # padded-token budget per batch (see DummySampler) 82 | -------------------------------------------------------------------------------- /examples/translation/nmt_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import torch 7 | import torch.nn as nn 8 | from datasets import load_dataset 9 | from einops import rearrange 10 | from loguru import logger 11 | from sklearn.model_selection import train_test_split as tts 12 | from spacy.lang.en import English 13 | from text_embeddings.visual import VTRTokenizer 14 | from torch.optim import Adam 15 | from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler 16 | from tqdm import tqdm 17 | from transformers import get_cosine_schedule_with_warmup 18 | 19 | 20 | def gen_no_peek_mask(length: int) -> torch.Tensor: 21 | """Generate an N by N mask for autoregressive attention.
22 | 23 | Parameters 24 | ---------- 25 | length : int 26 | Length of the sequence 27 | 28 | Returns 29 | ------- 30 | torch.Tensor 31 | An N by N float mask where allowed positions are marked 32 | as zeros while others are negative infinities 33 | """ 34 | mask = rearrange(torch.triu(torch.ones(length, length)) == 1, "h w -> w h") 35 | mask = ( 36 | mask.float() 37 | .masked_fill(mask == 0, float("-inf")) 38 | .masked_fill(mask == 1, float(0.0)) 39 | ) 40 | 41 | return mask 42 | 43 | 44 | class Translator(pl.LightningModule): # VTR-slice source encoder + word-level target decoder built on nn.Transformer 45 | def __init__( 46 | self, 47 | vocab_size: int = 10000 + 4, # data-module vocab_size + 4 special ids ([pad], [oov], [bos], [eos]) 48 | d_model: int = 512, 49 | nhead: int = 8, 50 | num_encoder_layers: int = 6, 51 | num_decoder_layers: int = 6, 52 | dim_feedforward: int = 2048, 53 | max_seq_length: int = 96, 54 | pos_dropout: float = 0.1, 55 | trans_dropout: float = 0.1, 56 | warmup_steps: int = 4000, 57 | lr: float = 1e-09, 58 | ): 59 | super().__init__() 60 | self.d_model = d_model 61 | self.lr = lr 62 | self.warmup_steps = warmup_steps 63 | self.max_seq_length = max_seq_length 64 | self.embed_tgt = nn.Embedding(vocab_size, d_model) 65 | self.pos_enc = PositionalEncoding(d_model, pos_dropout, max_seq_length) 66 | logger.debug( 67 | f"lr: {self.lr}, d_model: {d_model}, max_seq_length: {max_seq_length}" 68 | ) 69 | 70 | self.transformer = nn.Transformer( 71 | d_model, 72 | nhead, 73 | num_encoder_layers, 74 | num_decoder_layers, 75 | dim_feedforward, 76 | trans_dropout, 77 | ) 78 | self.fc = nn.Linear(d_model, vocab_size) 79 | self.loss = nn.CrossEntropyLoss(ignore_index=0, reduction="mean") 80 | self.init_weights() 81 | 82 | def init_weights(self): 83 | 84 | for p in self.parameters(): 85 | if p.dim() > 1: 86 | nn.init.xavier_normal_(p) 87 | 88 | def forward( 89 | self, 90 | src, 91 | tgt, 92 | src_key_padding_mask, 93 | tgt_key_padding_mask, 94 | memory_key_padding_mask, 95 | tgt_mask, 96 | ): 97 | src = rearrange(src, "n s h w -> s n h w") 98 | tgt = rearrange(tgt, "n t -> t n") 99 | 100 | src = 
self.pos_enc(rearrange(src, "s n h w -> s n (h w)")) 101 | tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model)) 102 | 103 | output = self.transformer( 104 | src, 105 | tgt, 106 | tgt_mask=tgt_mask, 107 | src_key_padding_mask=src_key_padding_mask, 108 | tgt_key_padding_mask=tgt_key_padding_mask, 109 | memory_key_padding_mask=memory_key_padding_mask, 110 | ) 111 | 112 | output = rearrange(output, "t n e -> n t e") 113 | 114 | return self.fc(output) 115 | 116 | def training_step(self, batch, batch_idx): 117 | 118 | src, src_key_padding_mask, tgt, tgt_key_padding_mask, ratio = batch 119 | memory_key_padding_mask = src_key_padding_mask.clone() 120 | tgt_inp, tgt_out = tgt[:, :-1], tgt[:, 1:] 121 | tgt_mask = gen_no_peek_mask(tgt_inp.shape[1]).to(self.device) 122 | 123 | outputs = self.forward( 124 | src, 125 | tgt_inp, 126 | src_key_padding_mask, 127 | tgt_key_padding_mask[:, :-1], 128 | memory_key_padding_mask, 129 | tgt_mask, 130 | ) 131 | loss = self.loss( 132 | rearrange(outputs, "b t v -> (b t) v"), rearrange(tgt_out, "b o -> (b o)") 133 | ) 134 | self.log("batch_ratio", ratio, prog_bar=True) 135 | return {"loss": loss} 136 | 137 | def validation_step(self, batch, batch_idx): 138 | 139 | src, src_key_padding_mask, tgt, tgt_key_padding_mask, _ = batch 140 | memory_key_padding_mask = src_key_padding_mask.clone() 141 | tgt_inp, tgt_out = tgt[:, :-1], tgt[:, 1:] 142 | tgt_mask = gen_no_peek_mask(tgt_inp.shape[1]).to(self.device) 143 | 144 | outputs = self.forward( 145 | src, 146 | tgt_inp, 147 | src_key_padding_mask, 148 | tgt_key_padding_mask[:, :-1], 149 | memory_key_padding_mask, 150 | tgt_mask, 151 | ) 152 | loss = self.loss( 153 | rearrange(outputs, "b t v -> (b t) v"), rearrange(tgt_out, "b o -> (b o)") 154 | ) 155 | 156 | return {"val_loss": loss} 157 | 158 | def validation_epoch_end(self, outputs) -> None: 159 | 160 | loss = torch.mean(torch.stack([o["val_loss"] for o in outputs])) 161 | self.log("val_loss", loss, prog_bar=True, on_epoch=True) 162 
| self.log("val_ppl", torch.exp(loss), prog_bar=True, on_epoch=True) 163 | logger.debug( 164 | self.decode("Strategie republikánské strany proti Obamovu znovuzvolení") 165 | ) 166 | 167 | def configure_optimizers(self): 168 | logger.debug(f"lr: {self.lr}") 169 | optimizer = Adam(self.parameters(), betas=(0.9, 0.98), lr=self.lr, eps=1e-9) 170 | scheduler = get_cosine_schedule_with_warmup(optimizer, self.warmup_steps, 18020) # NOTE(review): total training steps is hard-coded 171 | return [optimizer], [ 172 | { 173 | "scheduler": scheduler, 174 | "interval": "step", 175 | "frequency": 1, 176 | "monitor": "val_loss", 177 | "strict": True, 178 | "name": "lr", 179 | } 180 | ] 181 | 182 | def decode(self, sentence): # greedy decoding of a single source sentence (used for logging) 183 | source_tokenizer = self.trainer.datamodule.tokenizer 184 | src = source_tokenizer.text2embeddings(sentence) 185 | src = rearrange( 186 | torch.from_numpy(np.asarray(src)).float().to(self.device).unsqueeze(0), "n s h w -> s n h w" 187 | ) 188 | src = self.pos_enc(rearrange(src, "s n h w -> s n (h w)")) 189 | 190 | memory = self.transformer.encoder(src) 191 | 192 | ids = [2] 193 | tgt = torch.from_numpy(np.asarray([ids])).long().to(self.device) 194 | 195 | while True: 196 | tgt = rearrange(tgt, "n t -> t n") 197 | tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model)) 198 | output = self.transformer.decoder( 199 | tgt, memory, tgt_mask=gen_no_peek_mask(tgt.shape[0]).to(self.device) 200 | ) 201 | output = rearrange(output, "t n e -> n t e") 202 | logits = self.fc(output) 203 | idx = torch.argmax(logits[0], dim=-1)[-1].item() 204 | if idx == 3 or len(ids) == self.max_seq_length: # 3 == [eos] 205 | break 206 | ids.append(idx) 207 | tgt = torch.from_numpy(np.asarray([ids])).long().to(self.device) 208 | 209 | return " ".join( 210 | [self.trainer.datamodule.idx2token.get(token_id, "[unk]") for token_id in ids[1:]] # skip the leading [bos] 211 | ) 212 | 213 | 214 | class PositionalEncoding(nn.Module): 215 | def __init__(self, d_model, dropout=0.1, max_len=100): 216 | super(PositionalEncoding, self).__init__() 217 | self.dropout = nn.Dropout(p=dropout) 218 | 219 | pe = torch.zeros(max_len, 
d_model) 220 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 221 | div_term = torch.exp( 222 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) 223 | ) 224 | pe[:, 0::2] = torch.sin(position * div_term) 225 | pe[:, 1::2] = torch.cos(position * div_term) 226 | pe = pe.unsqueeze(0).transpose(0, 1) 227 | self.register_buffer("pe", pe) 228 | 229 | def forward(self, x): 230 | x = x + self.pe[: x.size(0), :] 231 | return self.dropout(x) 232 | 233 | 234 | class TranslationDataModule(pl.LightningDataModule): # cs->en pairs from the WMT14 validation + test splits 235 | def __init__( 236 | self, 237 | font_path="/home/chenghaomou/embeddings/Noto_Sans/NotoSans-Regular.ttf", # NOTE(review): author-local path; nmt.yaml passes a repo-relative font instead 238 | font_size: int = 16, 239 | window_size: int = 8, 240 | stride: int = 5, 241 | test_size: float = 0.2, 242 | max_seq_length: int = 96, 243 | vocab_size: int = 10000, 244 | batch_size: int = 2048, 245 | ): 246 | super().__init__() 247 | dataset = load_dataset("wmt14", "cs-en") 248 | sentences = [ 249 | (x["translation"]["cs"], x["translation"]["en"]) 250 | for x in dataset["validation"] 251 | ] + [(x["translation"]["cs"], x["translation"]["en"]) for x in dataset["test"]] 252 | source_tokenizer = VTRTokenizer( 253 | font=font_path, window_size=window_size, font_size=font_size, stride=stride 254 | ) 255 | source, target = zip(*sentences) 256 | 257 | nlp = English() 258 | target_tokenizer = nlp.tokenizer 259 | target_tokens = [ 260 | [t.text for t in d] for d in tqdm(target_tokenizer.pipe(target)) 261 | ] 262 | target_vocab = { 263 | t: i + 4 264 | for i, (t, _) in enumerate( 265 | Counter([t for doc in target_tokens for t in doc]).most_common( 266 | vocab_size 267 | ) 268 | ) 269 | } 270 | target_vocab["[pad]"] = 0 271 | target_vocab["[oov]"] = 1 272 | target_vocab["[bos]"] = 2 273 | target_vocab["[eos]"] = 3 274 | 275 | ids = [ 276 | ( 277 | np.asarray(source_tokenizer.text2embeddings(s)), 278 | [2] + [target_vocab.get(x, 1) for x in t] + [3], 279 | ) 280 | for s, t in tqdm(zip(source, target_tokens)) 281 | ] 282 | 
ids = [ 283 | (x, y) 284 | for x, y in ids 285 | if len(x) <= max_seq_length and len(y) <= max_seq_length 286 | ] 287 | 288 | self.ids = ids 289 | self.test_size = test_size 290 | self.batch_size = batch_size 291 | self.font_size = font_size 292 | self.window_size = window_size 293 | self.tokenizer = source_tokenizer 294 | self.idx2token = {i: t for t, i in target_vocab.items()} 295 | 296 | def setup(self, stage=None): 297 | self.train, self.val = tts(self.ids, test_size=self.test_size) 298 | 299 | def train_dataloader(self): 300 | return DataLoader( 301 | DummyDataset(self.train), 302 | batch_sampler=DummySampler( 303 | SequentialSampler( 304 | sorted(range(len(self.train)), key=lambda x: len(self.train[x][0])) 305 | ), 306 | batch_size=self.batch_size, 307 | drop_last=False, 308 | ids=self.train, 309 | ), 310 | collate_fn=self.collate_fn, 311 | num_workers=8, 312 | ) 313 | 314 | def val_dataloader(self): 315 | return DataLoader( 316 | DummyDataset(self.val), 317 | batch_sampler=DummySampler( 318 | SequentialSampler( 319 | sorted(range(len(self.val)), key=lambda x: len(self.val[x][0])) 320 | ), 321 | batch_size=self.batch_size, 322 | drop_last=False, 323 | ids=self.val, 324 | ), 325 | collate_fn=self.collate_fn, 326 | num_workers=8, 327 | ) 328 | 329 | def collate_fn(self, batch): 330 | source_input_ids = np.zeros( 331 | ( 332 | len(batch), 333 | max(map(lambda x: len(x[0]), batch)), 334 | self.font_size, 335 | self.window_size, 336 | ) 337 | ) 338 | target_input_ids = np.zeros((len(batch), max(map(lambda x: len(x[1]), batch)))) 339 | source_mask = np.zeros( 340 | (len(batch), max(map(lambda x: len(x[0]), batch))), dtype=bool 341 | ) 342 | target_mask = np.zeros_like(target_input_ids, dtype=bool) 343 | 344 | for i, (source, target) in enumerate(batch): 345 | source_input_ids[i, : len(source), :, :] = source 346 | target_input_ids[i, : len(target)] = target 347 | source_mask[i, len(source) :] = True 348 | target_mask[i, len(target) :] = True 349 | 350 | return ( 
351 | torch.from_numpy(source_input_ids).float(), 352 | torch.from_numpy(source_mask), 353 | torch.from_numpy(target_input_ids).long(), 354 | torch.from_numpy(target_mask), 355 | np.count_nonzero((~source_mask).astype(int)) / source_mask.size, 356 | ) 357 | 358 | 359 | class DummyDataset(Dataset): 360 | def __init__(self, data): 361 | self.data = data 362 | 363 | def __getitem__(self, key): 364 | return self.data[key] 365 | 366 | def __len__(self): 367 | return len(self.data) 368 | 369 | 370 | class DummySampler(BatchSampler): # token-budget batching: keeps (#items * longest src+tgt length) <= batch_size where possible 371 | def __init__(self, sampler, batch_size, drop_last, ids): 372 | self.sampler = sampler 373 | self.batch_size = batch_size 374 | self.drop_last = drop_last 375 | 376 | batch = [] 377 | batches = [] 378 | curr_max = 0 379 | for idx in self.sampler: 380 | curr_token = len(ids[idx][0]) + len(ids[idx][1]) 381 | curr_max = max(curr_max, curr_token) 382 | if batch and curr_max * (len(batch) + 1) > self.batch_size: # flush before the new item would push the padded size over budget 383 | batches.append(batch[:]) 384 | batch = [idx] 385 | curr_max = curr_token 386 | else: 387 | batch.append(idx) 388 | if batch and not self.drop_last: 389 | batches.append(batch[:]) 390 | 391 | self.batches = batches 392 | 393 | def __iter__(self): 394 | for batch in self.batches: 395 | yield batch 396 | 397 | def __len__(self): 398 | return len(self.batches) 399 | 400 | 401 | if __name__ == "__main__": 402 | 403 | from pytorch_lightning.utilities.cli import LightningCLI 404 | 405 | cli = LightningCLI(Translator, TranslationDataModule) 406 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "absl-py" 3 | version = "1.1.0" 4 | description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." 
5 | category = "dev" 6 | optional = false 7 | python-versions = ">=3.6" 8 | 9 | [[package]] 10 | name = "aiohttp" 11 | version = "3.8.1" 12 | description = "Async http client/server framework (asyncio)" 13 | category = "dev" 14 | optional = false 15 | python-versions = ">=3.6" 16 | 17 | [package.dependencies] 18 | aiosignal = ">=1.1.2" 19 | async-timeout = ">=4.0.0a3,<5.0" 20 | attrs = ">=17.3.0" 21 | charset-normalizer = ">=2.0,<3.0" 22 | frozenlist = ">=1.1.1" 23 | multidict = ">=4.5,<7.0" 24 | yarl = ">=1.0,<2.0" 25 | 26 | [package.extras] 27 | speedups = ["aiodns", "brotli", "cchardet"] 28 | 29 | [[package]] 30 | name = "aiosignal" 31 | version = "1.2.0" 32 | description = "aiosignal: a list of registered asynchronous callbacks" 33 | category = "dev" 34 | optional = false 35 | python-versions = ">=3.6" 36 | 37 | [package.dependencies] 38 | frozenlist = ">=1.1.0" 39 | 40 | [[package]] 41 | name = "async-timeout" 42 | version = "4.0.2" 43 | description = "Timeout context manager for asyncio programs" 44 | category = "dev" 45 | optional = false 46 | python-versions = ">=3.6" 47 | 48 | [[package]] 49 | name = "atomicwrites" 50 | version = "1.4.1" 51 | description = "Atomic file writes." 
52 | category = "main" 53 | optional = false 54 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 55 | 56 | [[package]] 57 | name = "attrs" 58 | version = "21.4.0" 59 | description = "Classes Without Boilerplate" 60 | category = "main" 61 | optional = false 62 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 63 | 64 | [package.extras] 65 | dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"] 66 | docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] 67 | tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] 68 | tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"] 69 | 70 | [[package]] 71 | name = "cachetools" 72 | version = "5.2.0" 73 | description = "Extensible memoizing collections and decorators" 74 | category = "dev" 75 | optional = false 76 | python-versions = "~=3.7" 77 | 78 | [[package]] 79 | name = "certifi" 80 | version = "2022.6.15" 81 | description = "Python package for providing Mozilla's CA Bundle." 82 | category = "main" 83 | optional = false 84 | python-versions = ">=3.6" 85 | 86 | [[package]] 87 | name = "charset-normalizer" 88 | version = "2.1.0" 89 | description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 90 | category = "main" 91 | optional = false 92 | python-versions = ">=3.6.0" 93 | 94 | [package.extras] 95 | unicode_backport = ["unicodedata2"] 96 | 97 | [[package]] 98 | name = "colorama" 99 | version = "0.4.5" 100 | description = "Cross-platform colored terminal text." 
101 | category = "main" 102 | optional = false 103 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 104 | 105 | [[package]] 106 | name = "datasets" 107 | version = "1.18.4" 108 | description = "HuggingFace community-driven open-source library of datasets" 109 | category = "dev" 110 | optional = false 111 | python-versions = "*" 112 | 113 | [package.dependencies] 114 | aiohttp = "*" 115 | dill = "*" 116 | fsspec = {version = ">=2021.05.0", extras = ["http"]} 117 | huggingface-hub = ">=0.1.0,<1.0.0" 118 | multiprocess = "*" 119 | numpy = ">=1.17" 120 | packaging = "*" 121 | pandas = "*" 122 | pyarrow = ">=3.0.0,<4.0.0 || >4.0.0" 123 | requests = ">=2.19.0" 124 | responses = "<0.19" 125 | tqdm = ">=4.62.1" 126 | xxhash = "*" 127 | 128 | [package.extras] 129 | apache-beam = ["apache-beam (>=2.26.0)"] 130 | audio = ["librosa"] 131 | benchmarks = ["numpy (==1.18.5)", "tensorflow (==2.3.0)", "torch (==1.6.0)", "transformers (==3.0.2)"] 132 | dev = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore", "boto3", "botocore", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "bert-score (>=0.3.6)", "rouge-score", "sacrebleu", "scipy", "seqeval", "scikit-learn", "jiwer", "sentencepiece", "torchmetrics (==0.6.0)", "mauve-text", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "wget (>=3.2)", "pytorch-nlp (==0.5.0)", "pytorch-lightning", "fastBPE (==0.1.0)", "fairseq", "black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)", "importlib-resources"] 133 | docs = ["docutils 
(==0.16.0)", "recommonmark", "sphinx (==3.1.2)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinxext-opengraph (==0.4.1)", "sphinx-copybutton", "fsspec (<2021.9.0)", "s3fs", "sphinx-panels", "sphinx-inline-tabs", "myst-parser", "Markdown (!=3.3.5)"] 134 | quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)"] 135 | s3 = ["fsspec", "boto3", "botocore", "s3fs"] 136 | tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)"] 137 | tensorflow_gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] 138 | tests = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore", "boto3", "botocore", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "bert-score (>=0.3.6)", "rouge-score", "sacrebleu", "scipy", "seqeval", "scikit-learn", "jiwer", "sentencepiece", "torchmetrics (==0.6.0)", "mauve-text", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "wget (>=3.2)", "pytorch-nlp (==0.5.0)", "pytorch-lightning", "fastBPE (==0.1.0)", "fairseq", "importlib-resources"] 139 | torch = ["torch"] 140 | vision = ["Pillow (>=6.2.1)"] 141 | 142 | [[package]] 143 | name = "dill" 144 | version = "0.3.5.1" 145 | description = "serialize all of python" 146 | category = "dev" 147 | optional = false 148 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" 149 | 150 | [package.extras] 151 | graph = ["objgraph (>=1.7.2)"] 152 | 153 | [[package]] 154 | name = "einops" 155 | version = "0.3.2" 156 | description = "A new flavour of deep learning operations" 157 | category = 
"main" 158 | optional = false 159 | python-versions = "*" 160 | 161 | [[package]] 162 | name = "filelock" 163 | version = "3.7.1" 164 | description = "A platform independent file lock." 165 | category = "main" 166 | optional = false 167 | python-versions = ">=3.7" 168 | 169 | [package.extras] 170 | docs = ["furo (>=2021.8.17b43)", "sphinx (>=4.1)", "sphinx-autodoc-typehints (>=1.12)"] 171 | testing = ["covdefaults (>=1.2.0)", "coverage (>=4)", "pytest (>=4)", "pytest-cov", "pytest-timeout (>=1.4.2)"] 172 | 173 | [[package]] 174 | name = "frozenlist" 175 | version = "1.3.0" 176 | description = "A list-like structure which implements collections.abc.MutableSequence" 177 | category = "dev" 178 | optional = false 179 | python-versions = ">=3.7" 180 | 181 | [[package]] 182 | name = "fsspec" 183 | version = "2022.5.0" 184 | description = "File-system specification" 185 | category = "dev" 186 | optional = false 187 | python-versions = ">=3.7" 188 | 189 | [package.dependencies] 190 | aiohttp = {version = "*", optional = true, markers = "extra == \"http\""} 191 | requests = {version = "*", optional = true, markers = "extra == \"http\""} 192 | 193 | [package.extras] 194 | abfs = ["adlfs"] 195 | adl = ["adlfs"] 196 | arrow = ["pyarrow (>=1)"] 197 | dask = ["dask", "distributed"] 198 | dropbox = ["dropboxdrivefs", "requests", "dropbox"] 199 | entrypoints = ["importlib-metadata"] 200 | fuse = ["fusepy"] 201 | gcs = ["gcsfs"] 202 | git = ["pygit2"] 203 | github = ["requests"] 204 | gs = ["gcsfs"] 205 | gui = ["panel"] 206 | hdfs = ["pyarrow (>=1)"] 207 | http = ["requests", "aiohttp"] 208 | libarchive = ["libarchive-c"] 209 | oci = ["ocifs"] 210 | s3 = ["s3fs"] 211 | sftp = ["paramiko"] 212 | smb = ["smbprotocol"] 213 | ssh = ["paramiko"] 214 | tqdm = ["tqdm"] 215 | 216 | [[package]] 217 | name = "google-auth" 218 | version = "2.9.1" 219 | description = "Google Authentication Library" 220 | category = "dev" 221 | optional = false 222 | python-versions = 
">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" 223 | 224 | [package.dependencies] 225 | cachetools = ">=2.0.0,<6.0" 226 | pyasn1-modules = ">=0.2.1" 227 | rsa = {version = ">=3.1.4,<5", markers = "python_version >= \"3.6\""} 228 | six = ">=1.9.0" 229 | 230 | [package.extras] 231 | aiohttp = ["requests (>=2.20.0,<3.0.0dev)", "aiohttp (>=3.6.2,<4.0.0dev)"] 232 | enterprise_cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"] 233 | pyopenssl = ["pyopenssl (>=20.0.0)"] 234 | reauth = ["pyu2f (>=0.1.5)"] 235 | 236 | [[package]] 237 | name = "google-auth-oauthlib" 238 | version = "0.4.6" 239 | description = "Google Authentication Library" 240 | category = "dev" 241 | optional = false 242 | python-versions = ">=3.6" 243 | 244 | [package.dependencies] 245 | google-auth = ">=1.0.0" 246 | requests-oauthlib = ">=0.7.0" 247 | 248 | [package.extras] 249 | tool = ["click (>=6.0.0)"] 250 | 251 | [[package]] 252 | name = "grpcio" 253 | version = "1.38.1" 254 | description = "HTTP/2-based RPC framework" 255 | category = "main" 256 | optional = false 257 | python-versions = "*" 258 | 259 | [package.dependencies] 260 | six = ">=1.5.2" 261 | 262 | [package.extras] 263 | protobuf = ["grpcio-tools (>=1.38.1)"] 264 | 265 | [[package]] 266 | name = "huggingface-hub" 267 | version = "0.8.1" 268 | description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" 269 | category = "main" 270 | optional = false 271 | python-versions = ">=3.7.0" 272 | 273 | [package.dependencies] 274 | filelock = "*" 275 | packaging = ">=20.9" 276 | pyyaml = ">=5.1" 277 | requests = "*" 278 | tqdm = "*" 279 | typing-extensions = ">=3.7.4.3" 280 | 281 | [package.extras] 282 | all = ["pytest", "pytest-cov", "datasets", "soundfile", "black (>=22.0,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 283 | dev = ["pytest", "pytest-cov", "datasets", "soundfile", "black (>=22.0,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 284 | fastai = ["toml", "fastai 
(>=2.4)", "fastcore (>=1.3.27)"] 285 | quality = ["black (>=22.0,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"] 286 | tensorflow = ["tensorflow", "pydot", "graphviz"] 287 | testing = ["pytest", "pytest-cov", "datasets", "soundfile"] 288 | torch = ["torch"] 289 | 290 | [[package]] 291 | name = "idna" 292 | version = "3.3" 293 | description = "Internationalized Domain Names in Applications (IDNA)" 294 | category = "main" 295 | optional = false 296 | python-versions = ">=3.5" 297 | 298 | [[package]] 299 | name = "importlib-metadata" 300 | version = "4.12.0" 301 | description = "Read metadata from Python packages" 302 | category = "main" 303 | optional = false 304 | python-versions = ">=3.7" 305 | 306 | [package.dependencies] 307 | zipp = ">=0.5" 308 | 309 | [package.extras] 310 | docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"] 311 | perf = ["ipython"] 312 | testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.3)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"] 313 | 314 | [[package]] 315 | name = "iniconfig" 316 | version = "1.1.1" 317 | description = "iniconfig: brain-dead simple config-ini parsing" 318 | category = "main" 319 | optional = false 320 | python-versions = "*" 321 | 322 | [[package]] 323 | name = "joblib" 324 | version = "1.1.0" 325 | description = "Lightweight pipelining with Python functions" 326 | category = "main" 327 | optional = false 328 | python-versions = ">=3.6" 329 | 330 | [[package]] 331 | name = "loguru" 332 | version = "0.5.3" 333 | description = "Python logging made (stupidly) simple" 334 | category = "main" 335 | optional = false 336 | python-versions = ">=3.5" 337 | 338 | [package.dependencies] 339 | colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} 340 | win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} 341 | 342 | 
[package.extras] 343 | dev = ["codecov (>=2.0.15)", "colorama (>=0.3.4)", "flake8 (>=3.7.7)", "tox (>=3.9.0)", "tox-travis (>=0.12)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "Sphinx (>=2.2.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "black (>=19.10b0)", "isort (>=5.1.1)"] 344 | 345 | [[package]] 346 | name = "mako" 347 | version = "1.2.1" 348 | description = "A super-fast templating language that borrows the best ideas from the existing templating languages." 349 | category = "main" 350 | optional = false 351 | python-versions = ">=3.7" 352 | 353 | [package.dependencies] 354 | MarkupSafe = ">=0.9.2" 355 | 356 | [package.extras] 357 | babel = ["babel"] 358 | lingua = ["lingua"] 359 | testing = ["pytest"] 360 | 361 | [[package]] 362 | name = "markdown" 363 | version = "3.4.1" 364 | description = "Python implementation of Markdown." 365 | category = "main" 366 | optional = false 367 | python-versions = ">=3.7" 368 | 369 | [package.dependencies] 370 | importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} 371 | 372 | [package.extras] 373 | testing = ["coverage", "pyyaml"] 374 | 375 | [[package]] 376 | name = "markupsafe" 377 | version = "2.1.1" 378 | description = "Safely add untrusted strings to HTML/XML markup." 379 | category = "main" 380 | optional = false 381 | python-versions = ">=3.7" 382 | 383 | [[package]] 384 | name = "mmh3" 385 | version = "3.0.0" 386 | description = "Python wrapper for MurmurHash (MurmurHash3), a set of fast and robust hash functions." 
387 | category = "main" 388 | optional = false 389 | python-versions = "*" 390 | 391 | [[package]] 392 | name = "multidict" 393 | version = "6.0.2" 394 | description = "multidict implementation" 395 | category = "dev" 396 | optional = false 397 | python-versions = ">=3.7" 398 | 399 | [[package]] 400 | name = "multiprocess" 401 | version = "0.70.13" 402 | description = "better multiprocessing and multithreading in python" 403 | category = "dev" 404 | optional = false 405 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" 406 | 407 | [package.dependencies] 408 | dill = ">=0.3.5.1" 409 | 410 | [[package]] 411 | name = "numpy" 412 | version = "1.23.1" 413 | description = "NumPy is the fundamental package for array computing with Python." 414 | category = "main" 415 | optional = false 416 | python-versions = ">=3.8" 417 | 418 | [[package]] 419 | name = "oauthlib" 420 | version = "3.2.0" 421 | description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" 422 | category = "dev" 423 | optional = false 424 | python-versions = ">=3.6" 425 | 426 | [package.extras] 427 | rsa = ["cryptography (>=3.0.0)"] 428 | signals = ["blinker (>=1.4.0)"] 429 | signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] 430 | 431 | [[package]] 432 | name = "packaging" 433 | version = "21.3" 434 | description = "Core utilities for Python packages" 435 | category = "main" 436 | optional = false 437 | python-versions = ">=3.6" 438 | 439 | [package.dependencies] 440 | pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" 441 | 442 | [[package]] 443 | name = "pandas" 444 | version = "1.4.3" 445 | description = "Powerful data structures for data analysis, time series, and statistics" 446 | category = "dev" 447 | optional = false 448 | python-versions = ">=3.8" 449 | 450 | [package.dependencies] 451 | numpy = [ 452 | {version = ">=1.18.5", markers = "platform_machine != \"aarch64\" and platform_machine != \"arm64\" and 
python_version < \"3.10\""}, 453 | {version = ">=1.19.2", markers = "platform_machine == \"aarch64\" and python_version < \"3.10\""}, 454 | {version = ">=1.20.0", markers = "platform_machine == \"arm64\" and python_version < \"3.10\""}, 455 | {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, 456 | ] 457 | python-dateutil = ">=2.8.1" 458 | pytz = ">=2020.1" 459 | 460 | [package.extras] 461 | test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] 462 | 463 | [[package]] 464 | name = "pdoc3" 465 | version = "0.9.2" 466 | description = "Auto-generate API documentation for Python projects." 467 | category = "main" 468 | optional = false 469 | python-versions = ">= 3.5" 470 | 471 | [package.dependencies] 472 | mako = "*" 473 | markdown = ">=3.0" 474 | 475 | [[package]] 476 | name = "pillow" 477 | version = "8.4.0" 478 | description = "Python Imaging Library (Fork)" 479 | category = "main" 480 | optional = false 481 | python-versions = ">=3.6" 482 | 483 | [[package]] 484 | name = "pluggy" 485 | version = "1.0.0" 486 | description = "plugin and hook calling mechanisms for python" 487 | category = "main" 488 | optional = false 489 | python-versions = ">=3.6" 490 | 491 | [package.extras] 492 | dev = ["pre-commit", "tox"] 493 | testing = ["pytest", "pytest-benchmark"] 494 | 495 | [[package]] 496 | name = "protobuf" 497 | version = "3.20.1" 498 | description = "Protocol Buffers" 499 | category = "dev" 500 | optional = false 501 | python-versions = ">=3.7" 502 | 503 | [[package]] 504 | name = "py" 505 | version = "1.11.0" 506 | description = "library with cross-python path, ini-parsing, io, code, log facilities" 507 | category = "main" 508 | optional = false 509 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 510 | 511 | [[package]] 512 | name = "pyarrow" 513 | version = "8.0.0" 514 | description = "Python library for Apache Arrow" 515 | category = "dev" 516 | optional = false 517 | python-versions = ">=3.7" 518 | 519 | 
[package.dependencies] 520 | numpy = ">=1.16.6" 521 | 522 | [[package]] 523 | name = "pyasn1" 524 | version = "0.4.8" 525 | description = "ASN.1 types and codecs" 526 | category = "dev" 527 | optional = false 528 | python-versions = "*" 529 | 530 | [[package]] 531 | name = "pyasn1-modules" 532 | version = "0.2.8" 533 | description = "A collection of ASN.1-based protocols modules." 534 | category = "dev" 535 | optional = false 536 | python-versions = "*" 537 | 538 | [package.dependencies] 539 | pyasn1 = ">=0.4.6,<0.5.0" 540 | 541 | [[package]] 542 | name = "pydeprecate" 543 | version = "0.3.2" 544 | description = "Deprecation tooling" 545 | category = "dev" 546 | optional = false 547 | python-versions = ">=3.6" 548 | 549 | [[package]] 550 | name = "pyparsing" 551 | version = "3.0.9" 552 | description = "pyparsing module - Classes and methods to define and execute parsing grammars" 553 | category = "main" 554 | optional = false 555 | python-versions = ">=3.6.8" 556 | 557 | [package.extras] 558 | diagrams = ["railroad-diagrams", "jinja2"] 559 | 560 | [[package]] 561 | name = "pytest" 562 | version = "6.2.5" 563 | description = "pytest: simple powerful testing with Python" 564 | category = "main" 565 | optional = false 566 | python-versions = ">=3.6" 567 | 568 | [package.dependencies] 569 | atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} 570 | attrs = ">=19.2.0" 571 | colorama = {version = "*", markers = "sys_platform == \"win32\""} 572 | iniconfig = "*" 573 | packaging = "*" 574 | pluggy = ">=0.12,<2.0" 575 | py = ">=1.8.2" 576 | toml = "*" 577 | 578 | [package.extras] 579 | testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] 580 | 581 | [[package]] 582 | name = "python-dateutil" 583 | version = "2.8.2" 584 | description = "Extensions to the standard Python datetime module" 585 | category = "dev" 586 | optional = false 587 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" 588 | 589 | 
[package.dependencies] 590 | six = ">=1.5" 591 | 592 | [[package]] 593 | name = "pytorch-lightning" 594 | version = "1.6.5" 595 | description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate." 596 | category = "dev" 597 | optional = false 598 | python-versions = ">=3.7" 599 | 600 | [package.dependencies] 601 | fsspec = {version = ">=2021.05.0,<2021.06.0 || >2021.06.0", extras = ["http"]} 602 | numpy = ">=1.17.2" 603 | packaging = ">=17.0" 604 | protobuf = "<=3.20.1" 605 | pyDeprecate = ">=0.3.1" 606 | PyYAML = ">=5.4" 607 | tensorboard = ">=2.2.0" 608 | torch = ">=1.8" 609 | torchmetrics = ">=0.4.1" 610 | tqdm = ">=4.57.0" 611 | typing-extensions = ">=4.0.0" 612 | 613 | [package.extras] 614 | all = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,!=10.15.0.a)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>=6.4)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas", "torchvision (>=0.9)", "gym[classic_control] (>=0.17.0)", "ipython", "fairscale (>=0.4.5)", "deepspeed", "horovod (>=0.21.2,!=0.24.0)", "hivemind (>=1.0.1)"] 615 | deepspeed = ["deepspeed"] 616 | dev = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,!=10.15.0.a)", "neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)", "coverage (>=6.4)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "cloudpickle 
(>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas"] 617 | examples = ["torchvision (>=0.9)", "gym[classic_control] (>=0.17.0)", "ipython"] 618 | extra = ["matplotlib (>3.1)", "torchtext (>=0.9)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "gcsfs (>=2021.5.0)", "rich (>=10.2.2,!=10.15.0.a)"] 619 | fairscale = ["fairscale (>=0.4.5)"] 620 | hivemind = ["hivemind (>=1.0.1)"] 621 | horovod = ["horovod (>=0.21.2,!=0.24.0)"] 622 | loggers = ["neptune-client (>=0.10.0)", "comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)"] 623 | strategies = ["fairscale (>=0.4.5)", "deepspeed", "horovod (>=0.21.2,!=0.24.0)", "hivemind (>=1.0.1)"] 624 | test = ["coverage (>=6.4)", "codecov (>=2.1)", "pytest (>=6.0)", "pytest-rerunfailures (>=10.2)", "mypy (>=0.920)", "flake8 (>=3.9.2)", "pre-commit (>=1.0)", "pytest-forked", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime", "pandas"] 625 | 626 | [[package]] 627 | name = "pytz" 628 | version = "2022.1" 629 | description = "World timezone definitions, modern and historical" 630 | category = "dev" 631 | optional = false 632 | python-versions = "*" 633 | 634 | [[package]] 635 | name = "pyyaml" 636 | version = "6.0" 637 | description = "YAML parser and emitter for Python" 638 | category = "main" 639 | optional = false 640 | python-versions = ">=3.6" 641 | 642 | [[package]] 643 | name = "regex" 644 | version = "2022.7.9" 645 | description = "Alternative regular expression module, to replace re." 646 | category = "main" 647 | optional = false 648 | python-versions = ">=3.6" 649 | 650 | [[package]] 651 | name = "requests" 652 | version = "2.28.1" 653 | description = "Python HTTP for Humans." 
654 | category = "main" 655 | optional = false 656 | python-versions = ">=3.7, <4" 657 | 658 | [package.dependencies] 659 | certifi = ">=2017.4.17" 660 | charset-normalizer = ">=2,<3" 661 | idna = ">=2.5,<4" 662 | urllib3 = ">=1.21.1,<1.27" 663 | 664 | [package.extras] 665 | socks = ["PySocks (>=1.5.6,!=1.5.7)"] 666 | use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] 667 | 668 | [[package]] 669 | name = "requests-oauthlib" 670 | version = "1.3.1" 671 | description = "OAuthlib authentication support for Requests." 672 | category = "dev" 673 | optional = false 674 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 675 | 676 | [package.dependencies] 677 | oauthlib = ">=3.0.0" 678 | requests = ">=2.0.0" 679 | 680 | [package.extras] 681 | rsa = ["oauthlib[signedtoken] (>=3.0.0)"] 682 | 683 | [[package]] 684 | name = "responses" 685 | version = "0.18.0" 686 | description = "A utility library for mocking out the `requests` Python library." 687 | category = "dev" 688 | optional = false 689 | python-versions = ">=3.7" 690 | 691 | [package.dependencies] 692 | requests = ">=2.0,<3.0" 693 | urllib3 = ">=1.25.10" 694 | 695 | [package.extras] 696 | tests = ["pytest (>=4.6)", "coverage (>=6.0.0)", "pytest-cov", "pytest-localserver", "flake8", "types-mock", "types-requests", "mypy"] 697 | 698 | [[package]] 699 | name = "rsa" 700 | version = "4.8" 701 | description = "Pure-Python RSA implementation" 702 | category = "dev" 703 | optional = false 704 | python-versions = ">=3.6,<4" 705 | 706 | [package.dependencies] 707 | pyasn1 = ">=0.1.3" 708 | 709 | [[package]] 710 | name = "scikit-learn" 711 | version = "1.1.1" 712 | description = "A set of python modules for machine learning and data mining" 713 | category = "main" 714 | optional = false 715 | python-versions = ">=3.8" 716 | 717 | [package.dependencies] 718 | joblib = ">=1.0.0" 719 | numpy = ">=1.17.3" 720 | scipy = ">=1.3.2" 721 | threadpoolctl = ">=2.0.0" 722 | 723 | [package.extras] 724 | benchmark = ["matplotlib 
(>=3.1.2)", "pandas (>=1.0.5)", "memory-profiler (>=0.57.0)"] 725 | docs = ["matplotlib (>=3.1.2)", "scikit-image (>=0.14.5)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)", "memory-profiler (>=0.57.0)", "sphinx (>=4.0.1)", "sphinx-gallery (>=0.7.0)", "numpydoc (>=1.2.0)", "Pillow (>=7.1.2)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] 726 | examples = ["matplotlib (>=3.1.2)", "scikit-image (>=0.14.5)", "pandas (>=1.0.5)", "seaborn (>=0.9.0)"] 727 | tests = ["matplotlib (>=3.1.2)", "scikit-image (>=0.14.5)", "pandas (>=1.0.5)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "flake8 (>=3.8.2)", "black (>=22.3.0)", "mypy (>=0.770)", "pyamg (>=4.0.0)", "numpydoc (>=1.2.0)"] 728 | 729 | [[package]] 730 | name = "scipy" 731 | version = "1.8.1" 732 | description = "SciPy: Scientific Library for Python" 733 | category = "main" 734 | optional = false 735 | python-versions = ">=3.8,<3.11" 736 | 737 | [package.dependencies] 738 | numpy = ">=1.17.3,<1.25.0" 739 | 740 | [[package]] 741 | name = "six" 742 | version = "1.16.0" 743 | description = "Python 2 and 3 compatibility utilities" 744 | category = "main" 745 | optional = false 746 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" 747 | 748 | [[package]] 749 | name = "tensorboard" 750 | version = "2.9.0" 751 | description = "TensorBoard lets you watch Tensors Flow" 752 | category = "dev" 753 | optional = false 754 | python-versions = ">=3.6" 755 | 756 | [package.dependencies] 757 | absl-py = ">=0.4" 758 | google-auth = ">=1.6.3,<3" 759 | google-auth-oauthlib = ">=0.4.1,<0.5" 760 | grpcio = ">=1.24.3" 761 | markdown = ">=2.6.8" 762 | numpy = ">=1.12.0" 763 | protobuf = ">=3.9.2" 764 | requests = ">=2.21.0,<3" 765 | tensorboard-data-server = ">=0.6.0,<0.7.0" 766 | tensorboard-plugin-wit = ">=1.6.0" 767 | werkzeug = ">=1.0.1" 768 | 769 | [[package]] 770 | name = "tensorboard-data-server" 771 | version = "0.6.1" 772 | description = "Fast data loading for TensorBoard" 773 | category = "dev" 774 | optional = false 775 
| python-versions = ">=3.6" 776 | 777 | [[package]] 778 | name = "tensorboard-plugin-wit" 779 | version = "1.8.1" 780 | description = "What-If Tool TensorBoard plugin." 781 | category = "dev" 782 | optional = false 783 | python-versions = "*" 784 | 785 | [[package]] 786 | name = "threadpoolctl" 787 | version = "3.1.0" 788 | description = "threadpoolctl" 789 | category = "main" 790 | optional = false 791 | python-versions = ">=3.6" 792 | 793 | [[package]] 794 | name = "tokenizers" 795 | version = "0.12.1" 796 | description = "Fast and Customizable Tokenizers" 797 | category = "main" 798 | optional = false 799 | python-versions = "*" 800 | 801 | [package.extras] 802 | docs = ["sphinx", "sphinx-rtd-theme", "setuptools-rust"] 803 | testing = ["pytest", "requests", "numpy", "datasets"] 804 | 805 | [[package]] 806 | name = "toml" 807 | version = "0.10.2" 808 | description = "Python Library for Tom's Obvious, Minimal Language" 809 | category = "main" 810 | optional = false 811 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 812 | 813 | [[package]] 814 | name = "torch" 815 | version = "1.12.0" 816 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" 817 | category = "main" 818 | optional = false 819 | python-versions = ">=3.7.0" 820 | 821 | [package.dependencies] 822 | typing-extensions = "*" 823 | 824 | [[package]] 825 | name = "torchmetrics" 826 | version = "0.9.2" 827 | description = "PyTorch native Metrics" 828 | category = "dev" 829 | optional = false 830 | python-versions = ">=3.7" 831 | 832 | [package.dependencies] 833 | numpy = ">=1.17.2" 834 | packaging = "*" 835 | torch = ">=1.3.1" 836 | 837 | [package.extras] 838 | all = ["pystoi", "torchvision (>=0.8)", "pycocotools", "sphinx-copybutton (>=0.3)", "nbsphinx (>=0.8)", "docutils (>=0.16)", "sphinxcontrib-fulltoc (>=1.0)", "sphinx-paramlinks (>=0.5.1)", "sphinx-autodoc-typehints (>=1.0)", "sphinx (>=4.0,<5.0)", "myst-parser", "sphinx-togglebutton (>=0.2)", "pandoc 
(>=1.0)", "sphinxcontrib-mockautodoc", "torchvision", "torch-fidelity", "lpips", "scipy", "pytorch-lightning (>=1.5)", "mypy (>=0.790)", "transformers (>=4.0)", "pytorch-msssim", "scikit-image (>0.17.1)", "fast-bss-eval (>=0.1.0)", "rouge-score (>=0.0.4)", "torch-complex", "bert-score (==0.3.10)", "cloudpickle (>=1.3)", "jiwer (>=2.3.0)", "pypesq", "coverage (>5.2)", "codecov (>=2.1)", "huggingface-hub (<0.7)", "sacrebleu (>=2.0.0)", "pytest-cov (>2.10)", "psutil", "pre-commit (>=1.0)", "requests", "pytest (>=6.0.0,<7.0.0)", "twine (>=3.2)", "scikit-learn (>1.0,<1.1.1)", "pytest-doctestplus (>=0.9.0)", "pytest-timeout", "phmdoctest (>=1.1.1)", "check-manifest", "mir-eval (>=0.6)", "fire", "nltk (>=3.6)", "regex (>=2021.9.24)", "tqdm (>=4.41.0)"] 839 | audio = ["pystoi"] 840 | detection = ["torchvision (>=0.8)", "pycocotools"] 841 | docs = ["sphinx-copybutton (>=0.3)", "nbsphinx (>=0.8)", "docutils (>=0.16)", "sphinxcontrib-fulltoc (>=1.0)", "sphinx-paramlinks (>=0.5.1)", "sphinx-autodoc-typehints (>=1.0)", "sphinx (>=4.0,<5.0)", "myst-parser", "sphinx-togglebutton (>=0.2)", "pandoc (>=1.0)", "sphinxcontrib-mockautodoc"] 842 | image = ["torchvision", "torch-fidelity", "lpips", "scipy"] 843 | integrate = ["pytorch-lightning (>=1.5)"] 844 | test = ["mypy (>=0.790)", "transformers (>=4.0)", "pytorch-msssim", "scikit-image (>0.17.1)", "fast-bss-eval (>=0.1.0)", "pycocotools", "rouge-score (>=0.0.4)", "torch-complex", "bert-score (==0.3.10)", "cloudpickle (>=1.3)", "jiwer (>=2.3.0)", "pypesq", "coverage (>5.2)", "codecov (>=2.1)", "huggingface-hub (<0.7)", "sacrebleu (>=2.0.0)", "pytest-cov (>2.10)", "psutil", "pre-commit (>=1.0)", "requests", "pytest (>=6.0.0,<7.0.0)", "twine (>=3.2)", "scikit-learn (>1.0,<1.1.1)", "pytest-doctestplus (>=0.9.0)", "pytest-timeout", "phmdoctest (>=1.1.1)", "check-manifest", "mir-eval (>=0.6)", "fire"] 845 | text = ["nltk (>=3.6)", "regex (>=2021.9.24)", "tqdm (>=4.41.0)"] 846 | 847 | [[package]] 848 | name = "tqdm" 849 | version = 
"4.64.0" 850 | description = "Fast, Extensible Progress Meter" 851 | category = "main" 852 | optional = false 853 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" 854 | 855 | [package.dependencies] 856 | colorama = {version = "*", markers = "platform_system == \"Windows\""} 857 | 858 | [package.extras] 859 | dev = ["py-make (>=0.1.0)", "twine", "wheel"] 860 | notebook = ["ipywidgets (>=6)"] 861 | slack = ["slack-sdk"] 862 | telegram = ["requests"] 863 | 864 | [[package]] 865 | name = "transformers" 866 | version = "4.20.1" 867 | description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" 868 | category = "main" 869 | optional = false 870 | python-versions = ">=3.7.0" 871 | 872 | [package.dependencies] 873 | filelock = "*" 874 | huggingface-hub = ">=0.1.0,<1.0" 875 | numpy = ">=1.17" 876 | packaging = ">=20.0" 877 | pyyaml = ">=5.1" 878 | regex = "!=2019.12.17" 879 | requests = "*" 880 | tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.13" 881 | tqdm = ">=4.27" 882 | 883 | [package.extras] 884 | all = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "torch (>=1.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.1)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)"] 885 | audio = ["librosa", "pyctcdecode (>=0.3.0)", "phonemizer"] 886 | codecarbon = ["codecarbon (==1.2.0)"] 887 | deepspeed = ["deepspeed (>=0.6.5)"] 888 | deepspeed-testing = ["deepspeed (>=0.6.5)", "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "dill (<0.3.5)", "pytest-timeout", "black (>=22.3,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.1)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "optuna"] 
889 | dev = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "torch (>=1.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.1)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)", "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "dill (<0.3.5)", "pytest-timeout", "black (>=22.3,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "hf-doc-builder", "scikit-learn"] 890 | dev-tensorflow = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "dill (<0.3.5)", "pytest-timeout", "black (>=22.3,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.1)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "pillow", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "hf-doc-builder", "scikit-learn", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"] 891 | dev-torch = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "dill (<0.3.5)", "pytest-timeout", "black (>=22.3,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.1)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "torch (>=1.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers 
(>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "hf-doc-builder", "scikit-learn", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] 892 | docs = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "torch (>=1.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.1)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)", "hf-doc-builder"] 893 | docs_specific = ["hf-doc-builder"] 894 | fairscale = ["fairscale (>0.3)"] 895 | flax = ["jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)"] 896 | flax-speech = ["librosa", "pyctcdecode (>=0.3.0)", "phonemizer"] 897 | ftfy = ["ftfy"] 898 | integrations = ["optuna", "ray", "sigopt"] 899 | ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)"] 900 | modelcreation = ["cookiecutter (==1.7.3)"] 901 | onnx = ["onnxconverter-common", "tf2onnx", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] 902 | onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] 903 | optuna = ["optuna"] 904 | quality = ["black (>=22.3,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)"] 905 | ray = ["ray"] 906 | retrieval = ["faiss-cpu", "datasets"] 907 | sagemaker = ["sagemaker (>=2.31.0)"] 908 | sentencepiece = ["sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.1)"] 909 | serving = ["pydantic", "uvicorn", "fastapi", "starlette"] 910 | sigopt = ["sigopt"] 911 | sklearn = ["scikit-learn"] 912 | speech = ["torchaudio", "librosa", 
"pyctcdecode (>=0.3.0)", "phonemizer"] 913 | testing = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "dill (<0.3.5)", "pytest-timeout", "black (>=22.3,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.1)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)"] 914 | tf = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx"] 915 | tf-cpu = ["tensorflow-cpu (>=2.3)", "onnxconverter-common", "tf2onnx"] 916 | tf-speech = ["librosa", "pyctcdecode (>=0.3.0)", "phonemizer"] 917 | timm = ["timm"] 918 | tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.13)"] 919 | torch = ["torch (>=1.0)"] 920 | torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"] 921 | torchhub = ["filelock", "huggingface-hub (>=0.1.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf (<=3.20.1)", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.0)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "tqdm (>=4.27)"] 922 | vision = ["pillow"] 923 | 924 | [[package]] 925 | name = "typing-extensions" 926 | version = "4.3.0" 927 | description = "Backported and Experimental Type Hints for Python 3.7+" 928 | category = "main" 929 | optional = false 930 | python-versions = ">=3.7" 931 | 932 | [[package]] 933 | name = "urllib3" 934 | version = "1.26.10" 935 | description = "HTTP library with thread-safe connection pooling, file post, and more." 
936 | category = "main" 937 | optional = false 938 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, <4" 939 | 940 | [package.extras] 941 | brotli = ["brotlicffi (>=0.8.0)", "brotli (>=1.0.9)", "brotlipy (>=0.6.0)"] 942 | secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] 943 | socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] 944 | 945 | [[package]] 946 | name = "werkzeug" 947 | version = "2.1.2" 948 | description = "The comprehensive WSGI web application library." 949 | category = "dev" 950 | optional = false 951 | python-versions = ">=3.7" 952 | 953 | [package.extras] 954 | watchdog = ["watchdog"] 955 | 956 | [[package]] 957 | name = "win32-setctime" 958 | version = "1.1.0" 959 | description = "A small Python utility to set file creation time on Windows" 960 | category = "main" 961 | optional = false 962 | python-versions = ">=3.5" 963 | 964 | [package.extras] 965 | dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"] 966 | 967 | [[package]] 968 | name = "xxhash" 969 | version = "3.0.0" 970 | description = "Python binding for xxHash" 971 | category = "dev" 972 | optional = false 973 | python-versions = ">=3.6" 974 | 975 | [[package]] 976 | name = "yarl" 977 | version = "1.7.2" 978 | description = "Yet another URL library" 979 | category = "dev" 980 | optional = false 981 | python-versions = ">=3.6" 982 | 983 | [package.dependencies] 984 | idna = ">=2.0" 985 | multidict = ">=4.0" 986 | 987 | [[package]] 988 | name = "zipp" 989 | version = "3.8.1" 990 | description = "Backport of pathlib-compatible object wrapper for zip files" 991 | category = "main" 992 | optional = false 993 | python-versions = ">=3.7" 994 | 995 | [package.extras] 996 | docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)", "jaraco.tidelift (>=1.4)"] 997 | testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.3)", "jaraco.itertools", "func-timeout", "pytest-black 
(>=0.3.7)", "pytest-mypy (>=0.9.1)"] 998 | 999 | [metadata] 1000 | lock-version = "1.1" 1001 | python-versions = ">=3.8,<3.11" 1002 | content-hash = "be24d9496911193c229f57d7e6a8ae9ca51c5f97a09861158210b568a4ecd622" 1003 | 1004 | [metadata.files] 1005 | absl-py = [ 1006 | {file = "absl-py-1.1.0.tar.gz", hash = "sha256:3aa39f898329c2156ff525dfa69ce709e42d77aab18bf4917719d6f260aa6a08"}, 1007 | {file = "absl_py-1.1.0-py3-none-any.whl", hash = "sha256:db97287655e30336938f8058d2c81ed2be6af1d9b6ebbcd8df1080a6c7fcd24e"}, 1008 | ] 1009 | aiohttp = [] 1010 | aiosignal = [] 1011 | async-timeout = [] 1012 | atomicwrites = [] 1013 | attrs = [] 1014 | cachetools = [ 1015 | {file = "cachetools-5.2.0-py3-none-any.whl", hash = "sha256:f9f17d2aec496a9aa6b76f53e3b614c965223c061982d434d160f930c698a9db"}, 1016 | {file = "cachetools-5.2.0.tar.gz", hash = "sha256:6a94c6402995a99c3970cc7e4884bb60b4a8639938157eeed436098bf9831757"}, 1017 | ] 1018 | certifi = [] 1019 | charset-normalizer = [] 1020 | colorama = [] 1021 | datasets = [ 1022 | {file = "datasets-1.18.4-py3-none-any.whl", hash = "sha256:e13695ad7aeda2af4430ac1a0b62def9c4b60bb4cc14dbaa240e6683cac50c49"}, 1023 | {file = "datasets-1.18.4.tar.gz", hash = "sha256:8f28a7afc2f894c68cb017335a32812f443fe41bc59c089cbd15d7412d3f7f96"}, 1024 | ] 1025 | dill = [] 1026 | einops = [ 1027 | {file = "einops-0.3.2-py3-none-any.whl", hash = "sha256:285f3c75620897acb8b5580170c88121f010c77ce130bc7b9f220179009dafe0"}, 1028 | {file = "einops-0.3.2.tar.gz", hash = "sha256:5200e413539f0377f4177ef00dc019968f4177c49b1db3e836c7883df2a5fe2e"}, 1029 | ] 1030 | filelock = [] 1031 | frozenlist = [] 1032 | fsspec = [] 1033 | google-auth = [ 1034 | {file = "google-auth-2.9.1.tar.gz", hash = "sha256:14292fa3429f2bb1e99862554cde1ee730d6840ebae067814d3d15d8549c0888"}, 1035 | {file = "google_auth-2.9.1-py2.py3-none-any.whl", hash = "sha256:5a7eed0cb0e3a83989fad0b59fe1329dfc8c479543039cd6fd1e01e9adf39475"}, 1036 | ] 1037 | google-auth-oauthlib = [ 1038 | {file = 
"google-auth-oauthlib-0.4.6.tar.gz", hash = "sha256:a90a072f6993f2c327067bf65270046384cda5a8ecb20b94ea9a687f1f233a7a"}, 1039 | {file = "google_auth_oauthlib-0.4.6-py2.py3-none-any.whl", hash = "sha256:3f2a6e802eebbb6fb736a370fbf3b055edcb6b52878bf2f26330b5e041316c73"}, 1040 | ] 1041 | grpcio = [ 1042 | {file = "grpcio-1.38.1-cp27-cp27m-macosx_10_10_x86_64.whl", hash = "sha256:118479436bda25b369e2dc1cd0921790fbfaea1ec663e4ee7095c4c325694495"}, 1043 | {file = "grpcio-1.38.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:7adfbd4e22647f880c9ed86b2be7f6d7a7dbbb8adc09395808cc7a4d021bc328"}, 1044 | {file = "grpcio-1.38.1-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:87b4b1977b52d5e0873a5e396340d2443640ba760f4fa23e93a38997ecfbcd5b"}, 1045 | {file = "grpcio-1.38.1-cp27-cp27m-win32.whl", hash = "sha256:3a25e1a46f51c80d06b66223f61938b9ffda37f2824ca65749c49b758137fac2"}, 1046 | {file = "grpcio-1.38.1-cp27-cp27m-win_amd64.whl", hash = "sha256:b5ea9902fc2990af993b74862282b49ae0b8de8a64ca3b4a8dda26a3163c3bb4"}, 1047 | {file = "grpcio-1.38.1-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:8ccde1df51eeaddf5515edc41bde2ea43a834a288914eae9ce4287399be108f5"}, 1048 | {file = "grpcio-1.38.1-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:0e193feaf4ebc72f6af57d7b8a08c0b8e43ebbd76f81c6f1e55d013557602dfd"}, 1049 | {file = "grpcio-1.38.1-cp35-cp35m-macosx_10_10_intel.whl", hash = "sha256:b16e1967709392a0ec4b10b4374a72eb062c47c168a189606c9a7ea7b36593a8"}, 1050 | {file = "grpcio-1.38.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:4bc60f8372c3ab06f41279163c5d558bf95195bb3f68e35ed19f95d4fbd53d71"}, 1051 | {file = "grpcio-1.38.1-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:a433d3740a9ef7bc34a18e2b12bf72b25e618facdfd09871167b30fd8e955fed"}, 1052 | {file = "grpcio-1.38.1-cp35-cp35m-manylinux2014_i686.whl", hash = "sha256:d49f250c3ffbe83ba2d03e3500e03505576a985f7c5f77172a9531058347aa68"}, 1053 | {file = "grpcio-1.38.1-cp35-cp35m-manylinux2014_x86_64.whl", hash = 
"sha256:6e137d014cf4162e5a796777012452516d92547717c1b4914fb71ce4e41817b5"}, 1054 | {file = "grpcio-1.38.1-cp35-cp35m-win32.whl", hash = "sha256:5ff4802d9b3704e680454289587e1cc146bb0d953cf3c9296e2d96441a6a8e88"}, 1055 | {file = "grpcio-1.38.1-cp35-cp35m-win_amd64.whl", hash = "sha256:4c19578b35715e110c324b27c18ab54a56fccc4c41b8f651b1d1da5a64e0d605"}, 1056 | {file = "grpcio-1.38.1-cp36-cp36m-linux_armv7l.whl", hash = "sha256:6edf68d4305e08f6f8c45bfaa9dc04d527ab5a1562aaf0c452fa921fbe90eb23"}, 1057 | {file = "grpcio-1.38.1-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:ddd33c90b0c95eca737c9f6db7e969a48d23aed72cecb23f3b8aac009ca2cfb4"}, 1058 | {file = "grpcio-1.38.1-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:c83481501533824fe341c17d297bbec1ec584ec46b352f98ce12bf16740615c4"}, 1059 | {file = "grpcio-1.38.1-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:3e85bba6f0e0c454a90b8fea16b59db9c6d19ddf9cc95052b2d4ca77b22d46d6"}, 1060 | {file = "grpcio-1.38.1-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:dcfcb147c18272a22a592251a49830b3c7abc82385ffff34916c2534175d885e"}, 1061 | {file = "grpcio-1.38.1-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:419af4f577a3d5d9f386aeacf4c4992f90016f84cbceb11ecd832101b1f7f9c9"}, 1062 | {file = "grpcio-1.38.1-cp36-cp36m-manylinux_2_24_aarch64.whl", hash = "sha256:cd7ddb5b6ffcbd3691990df20f260a888c8bd770d57480a97da1b756fb1be5c0"}, 1063 | {file = "grpcio-1.38.1-cp36-cp36m-win32.whl", hash = "sha256:d4179d96b0ce27602756185c1a00d088c9c1feb0cc17a36f8a66eec6ddddbc0c"}, 1064 | {file = "grpcio-1.38.1-cp36-cp36m-win_amd64.whl", hash = "sha256:96d78d9edf3070770cefd1822bc220d8cccad049b818a70a3c630052e9f15490"}, 1065 | {file = "grpcio-1.38.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:8ab27a6626c2038e13c1b250c5cd22da578f182364134620ec298b4ccfc85722"}, 1066 | {file = "grpcio-1.38.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:532ab738351aad2cdad80f4355123652e08b207281f3923ce51fb2b58692dd4c"}, 1067 | {file = 
"grpcio-1.38.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:e4a8a371ad02bf31576bcd99093cea3849e19ca1e9eb63fc0b2c0f1db1132f7d"}, 1068 | {file = "grpcio-1.38.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:89af675d38bf490384dae85151768b8434e997cece98e5d1eb6fcb3c16d6af12"}, 1069 | {file = "grpcio-1.38.1-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:ff9ebc416e815161d89d2fd22d1a91acf3b810ef800dae38c402d19d203590bf"}, 1070 | {file = "grpcio-1.38.1-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:3db0680fee9e55022677abda186e73e3c019c59ed83e1550519250dc97cf6793"}, 1071 | {file = "grpcio-1.38.1-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:a77d1f47e5e82504c531bc9dd22c093ff093b6706ec8bcdad228464ef3a5dd54"}, 1072 | {file = "grpcio-1.38.1-cp37-cp37m-win32.whl", hash = "sha256:549beb5646137b78534a312a3b80b2b8b1ea01058b38a711d42d6b54b20b6c2b"}, 1073 | {file = "grpcio-1.38.1-cp37-cp37m-win_amd64.whl", hash = "sha256:3eb960c2f9e031f0643b53bab67733a9544d82f42d0714338183d14993d2a23c"}, 1074 | {file = "grpcio-1.38.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:e90cda2ccd4bdb89a3cd5dc11771c3b8394817d5caaa1ae36042bc96a428c10e"}, 1075 | {file = "grpcio-1.38.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:26af85ae0a7ff8e8f8f550255bf85551df86a89883c11721c0756b71bc1019be"}, 1076 | {file = "grpcio-1.38.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:947bdba3ebcd93a7cef537d6405bc5667d1caf818fa8bbd2e2cc952ec8f97e09"}, 1077 | {file = "grpcio-1.38.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:6d898441ada374f76e0b5354d7e240e1c0e905a1ebcb1e95d9ffd99c88f63700"}, 1078 | {file = "grpcio-1.38.1-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:59f5fb4ba219a11fdc1c23e17c93ca3090480a8cde4370c980908546ffc091e6"}, 1079 | {file = "grpcio-1.38.1-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:cddd61bff66e42ef334f8cb9e719951e479b5ad2cb75c00338aac8de28e17484"}, 1080 | {file = "grpcio-1.38.1-cp38-cp38-manylinux_2_24_aarch64.whl", hash = 
"sha256:c323265a4f18f586e8de84fda12b48eb3bd48395294aa2b8c05307ac1680299d"}, 1081 | {file = "grpcio-1.38.1-cp38-cp38-win32.whl", hash = "sha256:72e8358c751da9ab4f8653a3b67b2a3bb7e330ee57cb26439c6af358d6eac032"}, 1082 | {file = "grpcio-1.38.1-cp38-cp38-win_amd64.whl", hash = "sha256:278e131bfbc57bab112359b98930b0fdbf81aa0ba2cdfc6555c7a5119d7e2117"}, 1083 | {file = "grpcio-1.38.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:44efa41ac36f6bcbf4f64d6479b3031cceea28cf6892a77f15bd1c22611bff9d"}, 1084 | {file = "grpcio-1.38.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:cf6c3bfa403e055380fe90844beb4fe8e9448edab5d2bf40d37d208dbb2f768c"}, 1085 | {file = "grpcio-1.38.1-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:5efa68fc3fe0c439e2858215f2224bfb7242c35079538d58063f68a0d5d5ec33"}, 1086 | {file = "grpcio-1.38.1-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:2a179b2565fa85a134933acc7845f9d4c12e742c802b4f50bf2fd208bf8b741e"}, 1087 | {file = "grpcio-1.38.1-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:b1624123710fa701988a8a43994de78416e5010ac1508f64ed41e2577358604a"}, 1088 | {file = "grpcio-1.38.1-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:6a225440015db88ec4625a2a41c21582a50cce7ffbe38dcbbb416c7180352516"}, 1089 | {file = "grpcio-1.38.1-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:e891b0936aab73550d673dd3bbf89fa9577b3db1a61baecea480afd36fdb1852"}, 1090 | {file = "grpcio-1.38.1-cp39-cp39-win32.whl", hash = "sha256:889518ce7c2a0609a3cffb7b667669a39b3410e869ff38e087bf7eeadad62e5d"}, 1091 | {file = "grpcio-1.38.1-cp39-cp39-win_amd64.whl", hash = "sha256:77054f24d46498d9696c809da7810b67bccf6153f9848ea48331708841926d82"}, 1092 | {file = "grpcio-1.38.1.tar.gz", hash = "sha256:1f79d8a24261e3c12ec3a6c25945ff799ae09874fd24815bc17c2dc37715ef6c"}, 1093 | ] 1094 | huggingface-hub = [] 1095 | idna = [] 1096 | importlib-metadata = [] 1097 | iniconfig = [] 1098 | joblib = [ 1099 | {file = "joblib-1.1.0-py2.py3-none-any.whl", hash = 
"sha256:f21f109b3c7ff9d95f8387f752d0d9c34a02aa2f7060c2135f465da0e5160ff6"}, 1100 | {file = "joblib-1.1.0.tar.gz", hash = "sha256:4158fcecd13733f8be669be0683b96ebdbbd38d23559f54dca7205aea1bf1e35"}, 1101 | ] 1102 | loguru = [ 1103 | {file = "loguru-0.5.3-py3-none-any.whl", hash = "sha256:f8087ac396b5ee5f67c963b495d615ebbceac2796379599820e324419d53667c"}, 1104 | {file = "loguru-0.5.3.tar.gz", hash = "sha256:b28e72ac7a98be3d28ad28570299a393dfcd32e5e3f6a353dec94675767b6319"}, 1105 | ] 1106 | mako = [ 1107 | {file = "Mako-1.2.1-py3-none-any.whl", hash = "sha256:df3921c3081b013c8a2d5ff03c18375651684921ae83fd12e64800b7da923257"}, 1108 | {file = "Mako-1.2.1.tar.gz", hash = "sha256:f054a5ff4743492f1aa9ecc47172cb33b42b9d993cffcc146c9de17e717b0307"}, 1109 | ] 1110 | markdown = [ 1111 | {file = "Markdown-3.4.1-py3-none-any.whl", hash = "sha256:08fb8465cffd03d10b9dd34a5c3fea908e20391a2a90b88d66362cb05beed186"}, 1112 | {file = "Markdown-3.4.1.tar.gz", hash = "sha256:3b809086bb6efad416156e00a0da66fe47618a5d6918dd688f53f40c8e4cfeff"}, 1113 | ] 1114 | markupsafe = [] 1115 | mmh3 = [ 1116 | {file = "mmh3-3.0.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:23912dde2ad4f701926948dd8e79a0e42b000f73962806f153931f52985e1e07"}, 1117 | {file = "mmh3-3.0.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:07f1308a410dc406d6a3c282a685728d00a87f3ed684f012671b96d6cc6a41c3"}, 1118 | {file = "mmh3-3.0.0-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:167cbc2b5ae27f3bccd797a2e8a9e7561791bee4cc2885f2c140eedc5df000ef"}, 1119 | {file = "mmh3-3.0.0-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:8fb833c2942917eff54f984b067d93e5a3c54dbb00720323460cdfed9292835f"}, 1120 | {file = "mmh3-3.0.0-cp36-cp36m-win32.whl", hash = "sha256:b7d26d0243ed9a5b8bf7aa8c53697cb79dff1e1d207f42396b7a7cb2a62298b7"}, 1121 | {file = "mmh3-3.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:2b6c79fc314b34b911245b460a79b601fff39bb807521fb7ed7c15cacf0394ac"}, 1122 | {file = 
"mmh3-3.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6d0b3e9def1fdfe4eadd35ee26bf72bd715ba97711f7101302d54c9d2e70ba27"}, 1123 | {file = "mmh3-3.0.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:8803d28c17cf898f5f00c0433e8b13d51fa3bb4ebecf59872ba1eaa20d94128a"}, 1124 | {file = "mmh3-3.0.0-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:01e456edf9cc381298a590923aadd1c0bf9934d93433099a5001d656112437c2"}, 1125 | {file = "mmh3-3.0.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:ff69ddc2d46e3e42720840b6b4f7bfb032fd1e677fac347fdfff6e4d9fd01212"}, 1126 | {file = "mmh3-3.0.0-cp37-cp37m-win32.whl", hash = "sha256:e08a5d81a2ff53625953290187bed4ae96a6972e2b5cd5984a6ebc5a9aab256c"}, 1127 | {file = "mmh3-3.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:12484ac80373db77d8a6beb7615e7dac8b6c3fb118905311a51450b4fc4a24d1"}, 1128 | {file = "mmh3-3.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:93c96e657e9bf9e9ef12ddaeae9f109c0b3134146e2eff2cbddde5a34190920e"}, 1129 | {file = "mmh3-3.0.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:9097be65aa95460bc68b6108601da8894757532450daf74034e4eaecd536acca"}, 1130 | {file = "mmh3-3.0.0-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:19874e12acb4119ef1ef83062ef4ac953c3343dd07a67ede8fa096d0393f34be"}, 1131 | {file = "mmh3-3.0.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:4589adcb609d1547aac7c1ac1064eb27cdd44b65b7e8a114e2971cd3b7110306"}, 1132 | {file = "mmh3-3.0.0-cp38-cp38-win32.whl", hash = "sha256:7a311efd4ecf122f21392ec6bf447c620cc783d20bdb9aec60bb469a54318419"}, 1133 | {file = "mmh3-3.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:3566d1455fa4a09f8fb1aa5b37f68914949674f9aa2bd630e9fdf344207f55b5"}, 1134 | {file = "mmh3-3.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:92fdffd63edb67c30dbaba18a7448d762209c0e678b0c9d577d17b30362b59a3"}, 1135 | {file = "mmh3-3.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3e52b869572c09db0c1a483f6e9cedbccfae8a282d95e552d3d4bd0712ab3196"}, 1136 | {file = 
"mmh3-3.0.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:f1cce018cc82a8a6287e6aeb139e441129837b810f2ddf372e3ff7f0fefb0947"}, 1137 | {file = "mmh3-3.0.0-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:0fd09c4b61fcddbcf0a87d5463b4e6d2919896736a67efc5248d5c74c1c9c742"}, 1138 | {file = "mmh3-3.0.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:c17fe2e276edd37ad8a6aff3b1663d3479c2c5c5993539c1050422a1dae33033"}, 1139 | {file = "mmh3-3.0.0-cp39-cp39-win32.whl", hash = "sha256:150439b906b4deaf6d796b2c2d11fb6159f08d02330d97723071ab3bf43b51df"}, 1140 | {file = "mmh3-3.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:bd870aedd9189eff1cf4e1687869b56c7e9461ee869789139c3e704009e5c227"}, 1141 | {file = "mmh3-3.0.0.tar.gz", hash = "sha256:d1ec578c09a07d3518ec9be540b87546397fa3455de73c166fcce51eaa5c41c5"}, 1142 | ] 1143 | multidict = [] 1144 | multiprocess = [] 1145 | numpy = [ 1146 | {file = "numpy-1.23.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b15c3f1ed08df4980e02cc79ee058b788a3d0bef2fb3c9ca90bb8cbd5b8a3a04"}, 1147 | {file = "numpy-1.23.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9ce242162015b7e88092dccd0e854548c0926b75c7924a3495e02c6067aba1f5"}, 1148 | {file = "numpy-1.23.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0d7447679ae9a7124385ccf0ea990bb85bb869cef217e2ea6c844b6a6855073"}, 1149 | {file = "numpy-1.23.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3119daed207e9410eaf57dcf9591fdc68045f60483d94956bee0bfdcba790953"}, 1150 | {file = "numpy-1.23.1-cp310-cp310-win32.whl", hash = "sha256:3ab67966c8d45d55a2bdf40701536af6443763907086c0a6d1232688e27e5447"}, 1151 | {file = "numpy-1.23.1-cp310-cp310-win_amd64.whl", hash = "sha256:1865fdf51446839ca3fffaab172461f2b781163f6f395f1aed256b1ddc253622"}, 1152 | {file = "numpy-1.23.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:aeba539285dcf0a1ba755945865ec61240ede5432df41d6e29fab305f4384db2"}, 1153 | {file = 
"numpy-1.23.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7e8229f3687cdadba2c4faef39204feb51ef7c1a9b669247d49a24f3e2e1617c"}, 1154 | {file = "numpy-1.23.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68b69f52e6545af010b76516f5daaef6173e73353e3295c5cb9f96c35d755641"}, 1155 | {file = "numpy-1.23.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1408c3527a74a0209c781ac82bde2182b0f0bf54dea6e6a363fe0cc4488a7ce7"}, 1156 | {file = "numpy-1.23.1-cp38-cp38-win32.whl", hash = "sha256:47f10ab202fe4d8495ff484b5561c65dd59177949ca07975663f4494f7269e3e"}, 1157 | {file = "numpy-1.23.1-cp38-cp38-win_amd64.whl", hash = "sha256:37e5ebebb0eb54c5b4a9b04e6f3018e16b8ef257d26c8945925ba8105008e645"}, 1158 | {file = "numpy-1.23.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:173f28921b15d341afadf6c3898a34f20a0569e4ad5435297ba262ee8941e77b"}, 1159 | {file = "numpy-1.23.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:876f60de09734fbcb4e27a97c9a286b51284df1326b1ac5f1bf0ad3678236b22"}, 1160 | {file = "numpy-1.23.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35590b9c33c0f1c9732b3231bb6a72d1e4f77872390c47d50a615686ae7ed3fd"}, 1161 | {file = "numpy-1.23.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a35c4e64dfca659fe4d0f1421fc0f05b8ed1ca8c46fb73d9e5a7f175f85696bb"}, 1162 | {file = "numpy-1.23.1-cp39-cp39-win32.whl", hash = "sha256:c2f91f88230042a130ceb1b496932aa717dcbd665350beb821534c5c7e15881c"}, 1163 | {file = "numpy-1.23.1-cp39-cp39-win_amd64.whl", hash = "sha256:37ece2bd095e9781a7156852e43d18044fd0d742934833335599c583618181b9"}, 1164 | {file = "numpy-1.23.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8002574a6b46ac3b5739a003b5233376aeac5163e5dcd43dd7ad062f3e186129"}, 1165 | {file = "numpy-1.23.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d732d17b8a9061540a10fda5bfeabca5785700ab5469a5e9b93aca5e2d3a5fb"}, 1166 | 
{file = "numpy-1.23.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:55df0f7483b822855af67e38fb3a526e787adf189383b4934305565d71c4b148"}, 1167 | {file = "numpy-1.23.1.tar.gz", hash = "sha256:d748ef349bfef2e1194b59da37ed5a29c19ea8d7e6342019921ba2ba4fd8b624"}, 1168 | ] 1169 | oauthlib = [ 1170 | {file = "oauthlib-3.2.0-py3-none-any.whl", hash = "sha256:6db33440354787f9b7f3a6dbd4febf5d0f93758354060e802f6c06cb493022fe"}, 1171 | {file = "oauthlib-3.2.0.tar.gz", hash = "sha256:23a8208d75b902797ea29fd31fa80a15ed9dc2c6c16fe73f5d346f83f6fa27a2"}, 1172 | ] 1173 | packaging = [] 1174 | pandas = [] 1175 | pdoc3 = [ 1176 | {file = "pdoc3-0.9.2.tar.gz", hash = "sha256:9df5d931f25f353c69c46819a3bd03ef96dd286f2a70bb1b93a23a781f91faa1"}, 1177 | ] 1178 | pillow = [ 1179 | {file = "Pillow-8.4.0-cp310-cp310-macosx_10_10_universal2.whl", hash = "sha256:81f8d5c81e483a9442d72d182e1fb6dcb9723f289a57e8030811bac9ea3fef8d"}, 1180 | {file = "Pillow-8.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3f97cfb1e5a392d75dd8b9fd274d205404729923840ca94ca45a0af57e13dbe6"}, 1181 | {file = "Pillow-8.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb9fc393f3c61f9054e1ed26e6fe912c7321af2f41ff49d3f83d05bacf22cc78"}, 1182 | {file = "Pillow-8.4.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d82cdb63100ef5eedb8391732375e6d05993b765f72cb34311fab92103314649"}, 1183 | {file = "Pillow-8.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:62cc1afda735a8d109007164714e73771b499768b9bb5afcbbee9d0ff374b43f"}, 1184 | {file = "Pillow-8.4.0-cp310-cp310-win32.whl", hash = "sha256:e3dacecfbeec9a33e932f00c6cd7996e62f53ad46fbe677577394aaa90ee419a"}, 1185 | {file = "Pillow-8.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:620582db2a85b2df5f8a82ddeb52116560d7e5e6b055095f04ad828d1b0baa39"}, 1186 | {file = "Pillow-8.4.0-cp36-cp36m-macosx_10_10_x86_64.whl", hash = 
"sha256:1bc723b434fbc4ab50bb68e11e93ce5fb69866ad621e3c2c9bdb0cd70e345f55"}, 1187 | {file = "Pillow-8.4.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:72cbcfd54df6caf85cc35264c77ede902452d6df41166010262374155947460c"}, 1188 | {file = "Pillow-8.4.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70ad9e5c6cb9b8487280a02c0ad8a51581dcbbe8484ce058477692a27c151c0a"}, 1189 | {file = "Pillow-8.4.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25a49dc2e2f74e65efaa32b153527fc5ac98508d502fa46e74fa4fd678ed6645"}, 1190 | {file = "Pillow-8.4.0-cp36-cp36m-win32.whl", hash = "sha256:93ce9e955cc95959df98505e4608ad98281fff037350d8c2671c9aa86bcf10a9"}, 1191 | {file = "Pillow-8.4.0-cp36-cp36m-win_amd64.whl", hash = "sha256:2e4440b8f00f504ee4b53fe30f4e381aae30b0568193be305256b1462216feff"}, 1192 | {file = "Pillow-8.4.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:8c803ac3c28bbc53763e6825746f05cc407b20e4a69d0122e526a582e3b5e153"}, 1193 | {file = "Pillow-8.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8a17b5d948f4ceeceb66384727dde11b240736fddeda54ca740b9b8b1556b29"}, 1194 | {file = "Pillow-8.4.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1394a6ad5abc838c5cd8a92c5a07535648cdf6d09e8e2d6df916dfa9ea86ead8"}, 1195 | {file = "Pillow-8.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:792e5c12376594bfcb986ebf3855aa4b7c225754e9a9521298e460e92fb4a488"}, 1196 | {file = "Pillow-8.4.0-cp37-cp37m-win32.whl", hash = "sha256:d99ec152570e4196772e7a8e4ba5320d2d27bf22fdf11743dd882936ed64305b"}, 1197 | {file = "Pillow-8.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:7b7017b61bbcdd7f6363aeceb881e23c46583739cb69a3ab39cb384f6ec82e5b"}, 1198 | {file = "Pillow-8.4.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:d89363f02658e253dbd171f7c3716a5d340a24ee82d38aab9183f7fdf0cdca49"}, 1199 | {file = 
"Pillow-8.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0a0956fdc5defc34462bb1c765ee88d933239f9a94bc37d132004775241a7585"}, 1200 | {file = "Pillow-8.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b7bb9de00197fb4261825c15551adf7605cf14a80badf1761d61e59da347779"}, 1201 | {file = "Pillow-8.4.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:72b9e656e340447f827885b8d7a15fc8c4e68d410dc2297ef6787eec0f0ea409"}, 1202 | {file = "Pillow-8.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5a4532a12314149d8b4e4ad8ff09dde7427731fcfa5917ff16d0291f13609df"}, 1203 | {file = "Pillow-8.4.0-cp38-cp38-win32.whl", hash = "sha256:82aafa8d5eb68c8463b6e9baeb4f19043bb31fefc03eb7b216b51e6a9981ae09"}, 1204 | {file = "Pillow-8.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:066f3999cb3b070a95c3652712cffa1a748cd02d60ad7b4e485c3748a04d9d76"}, 1205 | {file = "Pillow-8.4.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:5503c86916d27c2e101b7f71c2ae2cddba01a2cf55b8395b0255fd33fa4d1f1a"}, 1206 | {file = "Pillow-8.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4acc0985ddf39d1bc969a9220b51d94ed51695d455c228d8ac29fcdb25810e6e"}, 1207 | {file = "Pillow-8.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b052a619a8bfcf26bd8b3f48f45283f9e977890263e4571f2393ed8898d331b"}, 1208 | {file = "Pillow-8.4.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:493cb4e415f44cd601fcec11c99836f707bb714ab03f5ed46ac25713baf0ff20"}, 1209 | {file = "Pillow-8.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8831cb7332eda5dc89b21a7bce7ef6ad305548820595033a4b03cf3091235ed"}, 1210 | {file = "Pillow-8.4.0-cp39-cp39-win32.whl", hash = "sha256:5e9ac5f66616b87d4da618a20ab0a38324dbe88d8a39b55be8964eb520021e02"}, 1211 | {file = "Pillow-8.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:3eb1ce5f65908556c2d8685a8f0a6e989d887ec4057326f6c22b24e8a172c66b"}, 1212 | 
{file = "Pillow-8.4.0-pp36-pypy36_pp73-macosx_10_10_x86_64.whl", hash = "sha256:ddc4d832a0f0b4c52fff973a0d44b6c99839a9d016fe4e6a1cb8f3eea96479c2"}, 1213 | {file = "Pillow-8.4.0-pp36-pypy36_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9a3e5ddc44c14042f0844b8cf7d2cd455f6cc80fd7f5eefbe657292cf601d9ad"}, 1214 | {file = "Pillow-8.4.0-pp36-pypy36_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c70e94281588ef053ae8998039610dbd71bc509e4acbc77ab59d7d2937b10698"}, 1215 | {file = "Pillow-8.4.0-pp37-pypy37_pp73-macosx_10_10_x86_64.whl", hash = "sha256:3862b7256046fcd950618ed22d1d60b842e3a40a48236a5498746f21189afbbc"}, 1216 | {file = "Pillow-8.4.0-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a4901622493f88b1a29bd30ec1a2f683782e57c3c16a2dbc7f2595ba01f639df"}, 1217 | {file = "Pillow-8.4.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84c471a734240653a0ec91dec0996696eea227eafe72a33bd06c92697728046b"}, 1218 | {file = "Pillow-8.4.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:244cf3b97802c34c41905d22810846802a3329ddcb93ccc432870243211c79fc"}, 1219 | {file = "Pillow-8.4.0.tar.gz", hash = "sha256:b8e2f83c56e141920c39464b852de3719dfbfb6e3c99a2d8da0edf4fb33176ed"}, 1220 | ] 1221 | pluggy = [] 1222 | protobuf = [ 1223 | {file = "protobuf-3.20.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3cc797c9d15d7689ed507b165cd05913acb992d78b379f6014e013f9ecb20996"}, 1224 | {file = "protobuf-3.20.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ff8d8fa42675249bb456f5db06c00de6c2f4c27a065955917b28c4f15978b9c3"}, 1225 | {file = "protobuf-3.20.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cd68be2559e2a3b84f517fb029ee611546f7812b1fdd0aa2ecc9bc6ec0e4fdde"}, 1226 | {file = "protobuf-3.20.1-cp310-cp310-win32.whl", hash = "sha256:9016d01c91e8e625141d24ec1b20fed584703e527d28512aa8c8707f105a683c"}, 1227 | {file = 
"protobuf-3.20.1-cp310-cp310-win_amd64.whl", hash = "sha256:32ca378605b41fd180dfe4e14d3226386d8d1b002ab31c969c366549e66a2bb7"}, 1228 | {file = "protobuf-3.20.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9be73ad47579abc26c12024239d3540e6b765182a91dbc88e23658ab71767153"}, 1229 | {file = "protobuf-3.20.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:097c5d8a9808302fb0da7e20edf0b8d4703274d140fd25c5edabddcde43e081f"}, 1230 | {file = "protobuf-3.20.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e250a42f15bf9d5b09fe1b293bdba2801cd520a9f5ea2d7fb7536d4441811d20"}, 1231 | {file = "protobuf-3.20.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cdee09140e1cd184ba9324ec1df410e7147242b94b5f8b0c64fc89e38a8ba531"}, 1232 | {file = "protobuf-3.20.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:af0ebadc74e281a517141daad9d0f2c5d93ab78e9d455113719a45a49da9db4e"}, 1233 | {file = "protobuf-3.20.1-cp37-cp37m-win32.whl", hash = "sha256:755f3aee41354ae395e104d62119cb223339a8f3276a0cd009ffabfcdd46bb0c"}, 1234 | {file = "protobuf-3.20.1-cp37-cp37m-win_amd64.whl", hash = "sha256:62f1b5c4cd6c5402b4e2d63804ba49a327e0c386c99b1675c8a0fefda23b2067"}, 1235 | {file = "protobuf-3.20.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:06059eb6953ff01e56a25cd02cca1a9649a75a7e65397b5b9b4e929ed71d10cf"}, 1236 | {file = "protobuf-3.20.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:cb29edb9eab15742d791e1025dd7b6a8f6fcb53802ad2f6e3adcb102051063ab"}, 1237 | {file = "protobuf-3.20.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:69ccfdf3657ba59569c64295b7d51325f91af586f8d5793b734260dfe2e94e2c"}, 1238 | {file = "protobuf-3.20.1-cp38-cp38-win32.whl", hash = "sha256:dd5789b2948ca702c17027c84c2accb552fc30f4622a98ab5c51fcfe8c50d3e7"}, 1239 | {file = "protobuf-3.20.1-cp38-cp38-win_amd64.whl", hash = "sha256:77053d28427a29987ca9caf7b72ccafee011257561259faba8dd308fda9a8739"}, 1240 | {file = 
"protobuf-3.20.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f50601512a3d23625d8a85b1638d914a0970f17920ff39cec63aaef80a93fb7"}, 1241 | {file = "protobuf-3.20.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:284f86a6207c897542d7e956eb243a36bb8f9564c1742b253462386e96c6b78f"}, 1242 | {file = "protobuf-3.20.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7403941f6d0992d40161aa8bb23e12575637008a5a02283a930addc0508982f9"}, 1243 | {file = "protobuf-3.20.1-cp39-cp39-win32.whl", hash = "sha256:db977c4ca738dd9ce508557d4fce0f5aebd105e158c725beec86feb1f6bc20d8"}, 1244 | {file = "protobuf-3.20.1-cp39-cp39-win_amd64.whl", hash = "sha256:7e371f10abe57cee5021797126c93479f59fccc9693dafd6bd5633ab67808a91"}, 1245 | {file = "protobuf-3.20.1-py2.py3-none-any.whl", hash = "sha256:adfc6cf69c7f8c50fd24c793964eef18f0ac321315439d94945820612849c388"}, 1246 | {file = "protobuf-3.20.1.tar.gz", hash = "sha256:adc31566d027f45efe3f44eeb5b1f329da43891634d61c75a5944e9be6dd42c9"}, 1247 | ] 1248 | py = [] 1249 | pyarrow = [] 1250 | pyasn1 = [ 1251 | {file = "pyasn1-0.4.8-py2.4.egg", hash = "sha256:fec3e9d8e36808a28efb59b489e4528c10ad0f480e57dcc32b4de5c9d8c9fdf3"}, 1252 | {file = "pyasn1-0.4.8-py2.5.egg", hash = "sha256:0458773cfe65b153891ac249bcf1b5f8f320b7c2ce462151f8fa74de8934becf"}, 1253 | {file = "pyasn1-0.4.8-py2.6.egg", hash = "sha256:5c9414dcfede6e441f7e8f81b43b34e834731003427e5b09e4e00e3172a10f00"}, 1254 | {file = "pyasn1-0.4.8-py2.7.egg", hash = "sha256:6e7545f1a61025a4e58bb336952c5061697da694db1cae97b116e9c46abcf7c8"}, 1255 | {file = "pyasn1-0.4.8-py2.py3-none-any.whl", hash = "sha256:39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d"}, 1256 | {file = "pyasn1-0.4.8-py3.1.egg", hash = "sha256:78fa6da68ed2727915c4767bb386ab32cdba863caa7dbe473eaae45f9959da86"}, 1257 | {file = "pyasn1-0.4.8-py3.2.egg", hash = "sha256:08c3c53b75eaa48d71cf8c710312316392ed40899cb34710d092e96745a358b7"}, 1258 | {file = "pyasn1-0.4.8-py3.3.egg", hash = 
"sha256:03840c999ba71680a131cfaee6fab142e1ed9bbd9c693e285cc6aca0d555e576"}, 1259 | {file = "pyasn1-0.4.8-py3.4.egg", hash = "sha256:7ab8a544af125fb704feadb008c99a88805126fb525280b2270bb25cc1d78a12"}, 1260 | {file = "pyasn1-0.4.8-py3.5.egg", hash = "sha256:e89bf84b5437b532b0803ba5c9a5e054d21fec423a89952a74f87fa2c9b7bce2"}, 1261 | {file = "pyasn1-0.4.8-py3.6.egg", hash = "sha256:014c0e9976956a08139dc0712ae195324a75e142284d5f87f1a87ee1b068a359"}, 1262 | {file = "pyasn1-0.4.8-py3.7.egg", hash = "sha256:99fcc3c8d804d1bc6d9a099921e39d827026409a58f2a720dcdb89374ea0c776"}, 1263 | {file = "pyasn1-0.4.8.tar.gz", hash = "sha256:aef77c9fb94a3ac588e87841208bdec464471d9871bd5050a287cc9a475cd0ba"}, 1264 | ] 1265 | pyasn1-modules = [ 1266 | {file = "pyasn1-modules-0.2.8.tar.gz", hash = "sha256:905f84c712230b2c592c19470d3ca8d552de726050d1d1716282a1f6146be65e"}, 1267 | {file = "pyasn1_modules-0.2.8-py2.4.egg", hash = "sha256:0fe1b68d1e486a1ed5473f1302bd991c1611d319bba158e98b106ff86e1d7199"}, 1268 | {file = "pyasn1_modules-0.2.8-py2.5.egg", hash = "sha256:fe0644d9ab041506b62782e92b06b8c68cca799e1a9636ec398675459e031405"}, 1269 | {file = "pyasn1_modules-0.2.8-py2.6.egg", hash = "sha256:a99324196732f53093a84c4369c996713eb8c89d360a496b599fb1a9c47fc3eb"}, 1270 | {file = "pyasn1_modules-0.2.8-py2.7.egg", hash = "sha256:0845a5582f6a02bb3e1bde9ecfc4bfcae6ec3210dd270522fee602365430c3f8"}, 1271 | {file = "pyasn1_modules-0.2.8-py2.py3-none-any.whl", hash = "sha256:a50b808ffeb97cb3601dd25981f6b016cbb3d31fbf57a8b8a87428e6158d0c74"}, 1272 | {file = "pyasn1_modules-0.2.8-py3.1.egg", hash = "sha256:f39edd8c4ecaa4556e989147ebf219227e2cd2e8a43c7e7fcb1f1c18c5fd6a3d"}, 1273 | {file = "pyasn1_modules-0.2.8-py3.2.egg", hash = "sha256:b80486a6c77252ea3a3e9b1e360bc9cf28eaac41263d173c032581ad2f20fe45"}, 1274 | {file = "pyasn1_modules-0.2.8-py3.3.egg", hash = "sha256:65cebbaffc913f4fe9e4808735c95ea22d7a7775646ab690518c056784bc21b4"}, 1275 | {file = "pyasn1_modules-0.2.8-py3.4.egg", hash = 
"sha256:15b7c67fabc7fc240d87fb9aabf999cf82311a6d6fb2c70d00d3d0604878c811"}, 1276 | {file = "pyasn1_modules-0.2.8-py3.5.egg", hash = "sha256:426edb7a5e8879f1ec54a1864f16b882c2837bfd06eee62f2c982315ee2473ed"}, 1277 | {file = "pyasn1_modules-0.2.8-py3.6.egg", hash = "sha256:cbac4bc38d117f2a49aeedec4407d23e8866ea4ac27ff2cf7fb3e5b570df19e0"}, 1278 | {file = "pyasn1_modules-0.2.8-py3.7.egg", hash = "sha256:c29a5e5cc7a3f05926aff34e097e84f8589cd790ce0ed41b67aed6857b26aafd"}, 1279 | ] 1280 | pydeprecate = [ 1281 | {file = "pyDeprecate-0.3.2-py3-none-any.whl", hash = "sha256:ed86b68ed837e6465245904a3de2f59bf9eef78ac7a2502ee280533d04802457"}, 1282 | {file = "pyDeprecate-0.3.2.tar.gz", hash = "sha256:d481116cc5d7f6c473e7c4be820efdd9b90a16b594b350276e9e66a6cb5bdd29"}, 1283 | ] 1284 | pyparsing = [] 1285 | pytest = [ 1286 | {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, 1287 | {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, 1288 | ] 1289 | python-dateutil = [] 1290 | pytorch-lightning = [ 1291 | {file = "pytorch-lightning-1.6.5.tar.gz", hash = "sha256:8d521f2619b9db2ada5bbaf9713330d01460e75a11e4bc0bc2ca25fd37c47c57"}, 1292 | {file = "pytorch_lightning-1.6.5-py3-none-any.whl", hash = "sha256:00c9205d26aa354defdd4dd592b2dded33916d6e0c180ccffbb06cf34dc67ccf"}, 1293 | ] 1294 | pytz = [] 1295 | pyyaml = [] 1296 | regex = [] 1297 | requests = [] 1298 | requests-oauthlib = [ 1299 | {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"}, 1300 | {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"}, 1301 | ] 1302 | responses = [] 1303 | rsa = [ 1304 | {file = "rsa-4.8-py3-none-any.whl", hash = "sha256:95c5d300c4e879ee69708c428ba566c59478fd653cc3a22243eeb8ed846950bb"}, 1305 | {file = 
"rsa-4.8.tar.gz", hash = "sha256:5c6bd9dc7a543b7fe4304a631f8a8a3b674e2bbfc49c2ae96200cdbe55df6b17"}, 1306 | ] 1307 | scikit-learn = [ 1308 | {file = "scikit-learn-1.1.1.tar.gz", hash = "sha256:3e77b71e8e644f86c8b5be7f1c285ef597de4c384961389ee3e9ca36c445b256"}, 1309 | {file = "scikit_learn-1.1.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:102f51797cd8944bf44a038d106848ddf2804f2c1edf7aea45fba81a4fdc4d80"}, 1310 | {file = "scikit_learn-1.1.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:723cdb278b1fa57a55f68945bc4e501a2f12abe82f76e8d21e1806cbdbef6fc5"}, 1311 | {file = "scikit_learn-1.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33cf061ed0b79d647a3e4c3f6c52c412172836718a7cd4d11c1318d083300133"}, 1312 | {file = "scikit_learn-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47464c110eaa9ed9d1fe108cb403510878c3d3a40f110618d2a19b2190a3e35c"}, 1313 | {file = "scikit_learn-1.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:542ccd2592fe7ad31f5c85fed3a3deb3e252383960a85e4b49a629353fffaba4"}, 1314 | {file = "scikit_learn-1.1.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:3be10d8d325821ca366d4fe7083d87c40768f842f54371a9c908d97c45da16fc"}, 1315 | {file = "scikit_learn-1.1.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b2db720e13e697d912a87c1a51194e6fb085dc6d8323caa5ca51369ca6948f78"}, 1316 | {file = "scikit_learn-1.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e851f8874398dcd50d1e174e810e9331563d189356e945b3271c0e19ee6f4d6f"}, 1317 | {file = "scikit_learn-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b928869072366dc138762fe0929e7dc88413f8a469aebc6a64adc10a9226180c"}, 1318 | {file = "scikit_learn-1.1.1-cp38-cp38-win32.whl", hash = "sha256:e9d228ced1214d67904f26fb820c8abbea12b2889cd4aa8cda20a4ca0ed781c1"}, 1319 | {file = "scikit_learn-1.1.1-cp38-cp38-win_amd64.whl", hash = 
"sha256:f2d5b5d6e87d482e17696a7bfa03fe9515fdfe27e462a4ad37f3d7774a5e2fd6"}, 1320 | {file = "scikit_learn-1.1.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:0403ad13f283e27d43b0ad875f187ec7f5d964903d92d1ed06c51439560ecea0"}, 1321 | {file = "scikit_learn-1.1.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8fe80df08f5b9cee5dd008eccc672e543976198d790c07e5337f7dfb67eaac05"}, 1322 | {file = "scikit_learn-1.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ff56d07b9507fbe07ca0f4e5c8f3e171f74a429f998da03e308166251316b34"}, 1323 | {file = "scikit_learn-1.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2dad2bfc502344b869d4a3f4aa7271b2a5f4fe41f7328f404844c51612e2c58"}, 1324 | {file = "scikit_learn-1.1.1-cp39-cp39-win32.whl", hash = "sha256:22145b60fef02e597a8e7f061ebc7c51739215f11ce7fcd2ca9af22c31aa9f86"}, 1325 | {file = "scikit_learn-1.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:45c0f6ae523353f1d99b85469d746f9c497410adff5ba8b24423705b6956a86e"}, 1326 | ] 1327 | scipy = [] 1328 | six = [] 1329 | tensorboard = [ 1330 | {file = "tensorboard-2.9.0-py3-none-any.whl", hash = "sha256:bd78211076dca5efa27260afacfaa96cd05c7db12a6c09cc76a1d6b2987ca621"}, 1331 | ] 1332 | tensorboard-data-server = [ 1333 | {file = "tensorboard_data_server-0.6.1-py3-none-any.whl", hash = "sha256:809fe9887682d35c1f7d1f54f0f40f98bb1f771b14265b453ca051e2ce58fca7"}, 1334 | {file = "tensorboard_data_server-0.6.1-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:fa8cef9be4fcae2f2363c88176638baf2da19c5ec90addb49b1cde05c95c88ee"}, 1335 | {file = "tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl", hash = "sha256:d8237580755e58eff68d1f3abefb5b1e39ae5c8b127cc40920f9c4fb33f4b98a"}, 1336 | ] 1337 | tensorboard-plugin-wit = [ 1338 | {file = "tensorboard_plugin_wit-1.8.1-py3-none-any.whl", hash = "sha256:ff26bdd583d155aa951ee3b152b3d0cffae8005dc697f72b44a8e8c2a77a8cbe"}, 1339 | ] 1340 | threadpoolctl = [ 1341 | {file = 
"threadpoolctl-3.1.0-py3-none-any.whl", hash = "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b"}, 1342 | {file = "threadpoolctl-3.1.0.tar.gz", hash = "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"}, 1343 | ] 1344 | tokenizers = [] 1345 | toml = [ 1346 | {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, 1347 | {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, 1348 | ] 1349 | torch = [ 1350 | {file = "torch-1.12.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:3322d33a06e440d715bb214334bd41314c94632d9a2f07d22006bf21da3a2be4"}, 1351 | {file = "torch-1.12.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:2568f011dddeb5990d8698cc375d237f14568ffa8489854e3b94113b4b6b7c8b"}, 1352 | {file = "torch-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:e3e8348edca3e3cee5a67a2b452b85c57712efe1cc3ffdb87c128b3dde54534e"}, 1353 | {file = "torch-1.12.0-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:349ea3ba0c0e789e0507876c023181f13b35307aebc2e771efd0e045b8e03e84"}, 1354 | {file = "torch-1.12.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:13c7cca6b2ea3704d775444f02af53c5f072d145247e17b8cd7813ac57869f03"}, 1355 | {file = "torch-1.12.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:60d06ee2abfa85f10582d205404d52889d69bcbb71f7e211cfc37e3957ac19ca"}, 1356 | {file = "torch-1.12.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:a1325c9c28823af497cbf443369bddac9ac59f67f1e600f8ab9b754958e55b76"}, 1357 | {file = "torch-1.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:fb47291596677570246d723ee6abbcbac07eeba89d8f83de31e3954f21f44879"}, 1358 | {file = "torch-1.12.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:abbdc5483359b9495dc76e3bd7911ccd2ddc57706c117f8316832e31590af871"}, 1359 | {file = "torch-1.12.0-cp37-none-macosx_11_0_arm64.whl", hash = 
"sha256:72207b8733523388c49d43ffcc4416d1d8cd64c40f7826332e714605ace9b1d2"}, 1360 | {file = "torch-1.12.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0986685f2ec8b7c4d3593e8cfe96be85d462943f1a8f54112fc48d4d9fbbe903"}, 1361 | {file = "torch-1.12.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:0399746f83b4541bcb5b219a18dbe8cade760aba1c660d2748a38c6dc338ebc7"}, 1362 | {file = "torch-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:7ddb167827170c4e3ff6a27157414a00b9fef93dea175da04caf92a0619b7aee"}, 1363 | {file = "torch-1.12.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:2143d5fe192fd908b70b494349de5b1ac02854a8a902bd5f47d13d85b410e430"}, 1364 | {file = "torch-1.12.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:44a3804e9bb189574f5d02ccc2dc6e32e26a81b3e095463b7067b786048c6072"}, 1365 | {file = "torch-1.12.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:844f1db41173b53fe40c44b3e04fcca23a6ce00ac328b7099f2800e611766845"}, 1366 | {file = "torch-1.12.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:63341f96840a223f277e498d2737b39da30d9f57c7a1ef88857b920096317739"}, 1367 | {file = "torch-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:201abf43a99bb4980cc827dd4b38ac28f35e4dddac7832718be3d5479cafd2c1"}, 1368 | {file = "torch-1.12.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:c0313438bc36448ffd209f5fb4e5f325b3af158cdf61c8829b8ddaf128c57816"}, 1369 | {file = "torch-1.12.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:5ed69d5af232c5c3287d44cef998880dadcc9721cd020e9ae02f42e56b79c2e4"}, 1370 | ] 1371 | torchmetrics = [ 1372 | {file = "torchmetrics-0.9.2-py3-none-any.whl", hash = "sha256:ced006295c95c4555df0b8dea92960c00e3303de0da878fcf27e394df4757827"}, 1373 | {file = "torchmetrics-0.9.2.tar.gz", hash = "sha256:8178c9242e243318093d9b7237738a504535193d2006da6e58b0ed4003e318d2"}, 1374 | ] 1375 | tqdm = [] 1376 | transformers = [] 1377 | typing-extensions = [ 1378 | {file = "typing_extensions-4.3.0-py3-none-any.whl", hash = 
"sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02"}, 1379 | {file = "typing_extensions-4.3.0.tar.gz", hash = "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6"}, 1380 | ] 1381 | urllib3 = [] 1382 | werkzeug = [ 1383 | {file = "Werkzeug-2.1.2-py3-none-any.whl", hash = "sha256:72a4b735692dd3135217911cbeaa1be5fa3f62bffb8745c5215420a03dc55255"}, 1384 | {file = "Werkzeug-2.1.2.tar.gz", hash = "sha256:1ce08e8093ed67d638d63879fd1ba3735817f7a80de3674d293f5984f25fb6e6"}, 1385 | ] 1386 | win32-setctime = [ 1387 | {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, 1388 | {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, 1389 | ] 1390 | xxhash = [] 1391 | yarl = [] 1392 | zipp = [ 1393 | {file = "zipp-3.8.1-py3-none-any.whl", hash = "sha256:47c40d7fe183a6f21403a199b3e4192cca5774656965b0a4988ad2f8feb5f009"}, 1394 | {file = "zipp-3.8.1.tar.gz", hash = "sha256:05b45f1ee8f807d0cc928485ca40a07cb491cf092ff587c0df9cb1fd154848d2"}, 1395 | ] 1396 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "text-embeddings" 3 | version = "0.2.0" 4 | description = "zero-vocab or low-vocab embeddings" 5 | authors = ["Chenghao Mou "] 6 | license = "MIT" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.8,<3.11" 10 | loguru = "^0.5.3" 11 | 12 | transformers = "^4.8.2" 13 | numpy = "^1.21.1" 14 | einops = "^0.3.0" 15 | Pillow = "^8.3.1" 16 | pytest = "^6.2.4" 17 | mmh3 = "^3.0.0" 18 | pdoc3 = "^0.9.2" 19 | grpcio = "1.38.1" 20 | scikit-learn = "^1.1.1" 21 | scipy = "^1.8.1" 22 | tokenizers = "^0.12.1" 23 | torch = "^1.12.0" 24 | 25 | [tool.poetry.dev-dependencies] 26 | datasets = "^1.9.0" 27 | tqdm = "^4.61.2" 28 | pytorch-lightning = "^1.3.8" 29 
# 30 | [build-system]
# 31 | requires = ["poetry-core>=1.0.0"]
# 32 | build-backend = "poetry.core.masonry.api"
# 33 |
# --------------------------------------------------------------------------------
# /tests/hash.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date    : 2021-04-17 08:29:54
# @Author  : Chenghao Mou (mouchenghao@gmail.com)

"""Shape tests for the hash-based tokenizers (CANINE and PQRNN)."""

import pytest
from text_embeddings.hash import CANINETokenizer, PQRNNTokenizer
from transformers.tokenization_utils_base import PaddingStrategy

# Twelve space-separated repetitions: the long sample in the batch.
_LONG_TEXT = " ".join(["Hello world!"] * 12)


@pytest.mark.parametrize(
    (
        "text_pair",
        "add_special_tokens",
        "stride",
        "padding",
        "truncation",
        "return_attention_mask",
        "return_special_tokens_mask",
        "return_length",
    ),
    [
        (True, True, 5, "longest", "longest_first", True, True, True),
        (True, True, 5, "longest", "longest_first", True, True, False),
        (True, True, 5, "longest", "longest_first", True, False, True),
        (True, True, 5, "longest", "longest_first", False, True, True),
        (True, False, 5, "longest", "longest_first", True, False, True),
        (False, False, 5, "longest", "longest_first", True, False, True),
    ],
)
def test_canine_tokenizer(
    text_pair: bool,
    add_special_tokens: bool,
    stride: int,
    padding,
    truncation,
    return_attention_mask,
    return_special_tokens_mask,
    return_length,
):
    """Encode a small mixed-script batch with CANINETokenizer and check output shapes."""
    data = [_LONG_TEXT, "Hóla!", "你好,世界!"]

    embedder = CANINETokenizer(hash_size=768, max_length=2048)

    results = embedder(
        text=data,
        text_pair=data if text_pair else None,
        add_special_tokens=add_special_tokens,
        stride=stride,
        padding=padding,
        return_tensors="pt",
        truncation=truncation,
        return_attention_mask=return_attention_mask,
        return_special_tokens_mask=return_special_tokens_mask,
        return_length=return_length,
        prepend_batch_axis=True,
        return_overflowing_tokens=False,
    )

    sequence_length = results["input_ids"].shape[1]

    assert sequence_length <= embedder.max_length
    if return_special_tokens_mask and add_special_tokens:
        assert results["special_tokens_mask"].shape == (3, sequence_length)

    # One row per input text; the last axis is expected to equal hash_size (768).
    assert results["input_ids"].shape == (3, sequence_length, 768)
    if return_length:
        assert results["length"].shape == (3,)


@pytest.mark.parametrize(
    (
        "text_pair",
        "add_special_tokens",
        "stride",
        "padding",
        "truncation",
        "return_attention_mask",
        "return_special_tokens_mask",
        "return_length",
    ),
    [
        (True, True, 5, PaddingStrategy.LONGEST, "longest_first", True, True, True),
        (True, True, 5, PaddingStrategy.LONGEST, "longest_first", True, True, False),
        (True, True, 5, PaddingStrategy.LONGEST, "longest_first", True, False, True),
        (True, True, 5, PaddingStrategy.LONGEST, "longest_first", False, True, True),
        (True, False, 5, PaddingStrategy.LONGEST, "longest_first", True, False, True),
        (False, False, 5, PaddingStrategy.LONGEST, "longest_first", True, False, True),
    ],
)
def test_pqrnn_tokenizer(
    text_pair: bool,
    add_special_tokens: bool,
    stride: int,
    padding,
    truncation,
    return_attention_mask,
    return_special_tokens_mask,
    return_length,
):
    """Encode a small mixed-script batch with PQRNNTokenizer and check output shapes."""
    data = [_LONG_TEXT, "Hóla!", "你好,世界!"]

    embedder = PQRNNTokenizer(hash_size=768, max_length=512)

    results = embedder(
        text=data,
        text_pair=data if text_pair else None,
        add_special_tokens=add_special_tokens,
        stride=stride,
        padding=padding,
        return_tensors="pt",
        truncation=truncation,
        return_attention_mask=return_attention_mask,
        return_special_tokens_mask=return_special_tokens_mask,
        return_length=return_length,
        prepend_batch_axis=True,
        return_overflowing_tokens=False,
    )

    sequence_length = results["input_ids"].shape[1]

    assert sequence_length <= embedder.max_length
    if return_special_tokens_mask and add_special_tokens:
        assert results["special_tokens_mask"].shape == (3, sequence_length)

    # One row per input text; the last axis is expected to equal hash_size (768).
    assert results["input_ids"].shape == (3, sequence_length, 768)
    if return_length:
        assert results["length"].shape == (3,)
    assert results["token_type_ids"].shape == (3, sequence_length)

# --------------------------------------------------------------------------------
# /tests/visual.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date    : 2021-04-17 08:29:54
# @Author  : Chenghao Mou (mouchenghao@gmail.com)

"""Tests for text rendering and the visual (VTR) tokenizer."""

from pathlib import Path

import pytest
import tempfile
from loguru import logger
from text_embeddings.visual import text2image, VTRTokenizer
from transformers.tokenization_utils_base import PaddingStrategy

# NOTE: the font does not appear to be committed to the repository
# (``Noto_Sans`` is listed in .gitignore); place it under
# resources/Noto_Sans/ before running these tests.
font_path = str(
    Path(__file__).parent.parent / "resources/Noto_Sans/NotoSans-Regular.ttf"
)
logger.debug(f"Using font_path: {font_path}")

# Twelve space-separated repetitions: the long sample in the batch.
_LONG_TEXT = " ".join(["Hello world!"] * 12)


@pytest.mark.parametrize(("text",), [("Hello world!",), ("Hóla!",)])
def test_text2image(text: str):
    """Render ``text`` with the test font and save the result as a PNG file."""
    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        img = text2image(text, font=font_path)
        img.save(tmp.name)


@pytest.mark.parametrize(("text",), [("Hello world!",), ("Hóla!",)])
def test_text2embeddings(text: str):
    """Embed ``text`` and check that every token slice is (font_size, window_size)."""
    embedder = VTRTokenizer(font_size=14, window_size=10, font=font_path, max_length=36)

    embeddings = embedder.text2embeddings(text)

    # test_vtr_tokenizer expects trailing dims (14, 10) for these same settings
    # after batching, so each per-token slice must already have that shape.
    assert len(embeddings) > 0
    assert tuple(embeddings.shape[1:]) == (14, 10)


@pytest.mark.parametrize(
    (
        "text_pair",
        "add_special_tokens",
        "stride",
        "padding",
        "truncation",
        "return_attention_mask",
        "return_special_tokens_mask",
        "return_length",
    ),
    [
        (True, True, 5, PaddingStrategy.LONGEST, "longest_first", True, True, True),
        (True, True, 5, PaddingStrategy.LONGEST, "longest_first", True, True, False),
        (True, True, 5, PaddingStrategy.LONGEST, "longest_first", True, False, True),
        (True, True, 5, PaddingStrategy.LONGEST, "longest_first", False, True, True),
        (True, False, 5, PaddingStrategy.LONGEST, "longest_first", True, False, True),
        (False, False, 5, PaddingStrategy.LONGEST, "longest_first", True, False, True),
        # (True, True, 5, PaddingStrategy.DO_NOT_PAD, "longest_first", True, True, True),
    ],
)
def test_vtr_tokenizer(
    text_pair: bool,
    add_special_tokens: bool,
    stride: int,
    padding,
    truncation,
    return_attention_mask,
    return_special_tokens_mask,
    return_length,
):
    """Encode a small mixed-script batch with VTRTokenizer and check output shapes."""
    data = [_LONG_TEXT, "Hóla!", "你好,世界!"]

    embedder = VTRTokenizer(font_size=14, window_size=10, font=font_path, max_length=36)

    results = embedder(
        text=data,
        text_pair=data if text_pair else None,
        add_special_tokens=add_special_tokens,
        stride=stride,
        padding=padding,
        return_tensors="pt",
        truncation=truncation,
        return_attention_mask=return_attention_mask,
        return_special_tokens_mask=return_special_tokens_mask,
        return_length=return_length,
        prepend_batch_axis=True,
        return_overflowing_tokens=False,
    )

    sequence_length = results["input_ids"].shape[1]

    assert sequence_length <= embedder.max_length
    if return_special_tokens_mask and add_special_tokens:
        assert results["special_tokens_mask"].shape == (3, sequence_length)

    # The expected shape is identical with or without special tokens (the
    # original duplicated this assertion in both branches of an if/else):
    # trailing dims match (font_size, window_size) = (14, 10).
    assert results["input_ids"].shape == (3, sequence_length, 14, 10)
    if return_length:
        assert results["length"].shape == (3,)
    assert results["token_type_ids"].shape == (3, sequence_length)

# --------------------------------------------------------------------------------
# /text_embeddings/__init__.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date    : 2021-04-17 08:49:42
# @Author  : Chenghao Mou (mouchenghao@gmail.com)

"""text_embeddings is a package for no-vocabulary text embeddings."""
# --------------------------------------------------------------------------------
# /text_embeddings/base/__init__.py:
# --------------------------------------------------------------------------------
# 1 | #!/usr/bin/env python
# 2 | # -*- coding: utf-8 -*-
3 | # @Date : 2021-04-22 20:43:06 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | """base covers all the base classes, functions for other embedding based tokenizers.""" 7 | 8 | import abc 9 | from typing import List, Optional, Union, Dict 10 | from itertools import zip_longest 11 | 12 | import numpy as np 13 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy, TruncationStrategy, TensorType, BatchEncoding, EncodedInput, is_torch_available, to_py_obj, TextInput 14 | 15 | def is_torch(x) -> bool: # pragma: no 16 | """ 17 | Helper function to check whether the input is a torch tensor. 18 | 19 | Parameters 20 | ---------- 21 | x : [type] 22 | Input data 23 | 24 | Returns 25 | ------- 26 | bool 27 | Boolean value indicating whether the input is a torch tensor 28 | """ 29 | import torch 30 | return isinstance(x, torch.Tensor) 31 | 32 | class EmbeddingTokenizer(PreTrainedTokenizerBase): 33 | """ 34 | Embedding based tokenizer. It assumes each token is mapped to a tensor instead of an index number. 35 | This implementation borrows most implementation from huggingface's transformers library. 
36 | 37 | Parameters 38 | ---------- 39 | model_input_names : Optional[List[str]], optional 40 | Required model input names, by default None 41 | special_tokens : Optional[Dict[str, np.ndarray]], optional 42 | Required model special tokens, by default None 43 | max_length : Optional[int], optional 44 | Maximum sequence length supported by the model, by default 2048 45 | """ 46 | 47 | def __init__( 48 | self, 49 | model_input_names: Optional[List[str]] = None, 50 | special_tokens: Optional[Dict[str, np.ndarray]] = None, 51 | max_length: Optional[int] = 2048, 52 | ): 53 | self.model_input_names = model_input_names 54 | self.special_tokens = special_tokens 55 | self.max_length = max_length 56 | 57 | @abc.abstractmethod 58 | def text2embeddings(self, text: str) -> np.ndarray: 59 | raise NotImplementedError('This function is not implemented') 60 | 61 | def __call__( 62 | self, 63 | text: Union[TextInput, List[TextInput]], 64 | text_pair: Optional[Union[TextInput, List[TextInput]]] = None, 65 | add_special_tokens: bool = True, 66 | padding: Union[bool, str, PaddingStrategy] = False, 67 | truncation: Union[bool, str, TruncationStrategy] = False, 68 | max_length: Optional[int] = None, 69 | pad_to_multiple_of: Optional[int] = None, 70 | return_tensors: Optional[Union[str, TensorType]] = None, 71 | return_token_type_ids: Optional[bool] = None, 72 | return_attention_mask: Optional[bool] = None, 73 | return_overflowing_tokens: bool = False, 74 | return_special_tokens_mask: bool = False, 75 | return_length: bool = False, 76 | **kwargs, 77 | ) -> BatchEncoding: 78 | """ 79 | Tokenize the text into a sequence of image blocks. 
80 | 81 | Parameters 82 | ---------- 83 | text : Union[TextInput, List[TextInput]] 84 | A single text or a list of text 85 | text_pair : Optional[Union[TextInput, List[TextInput]]], optional 86 | A single text or a list of text, by default None 87 | add_special_tokens : bool, optional 88 | Whether to add special tokens to the data, by default True 89 | padding : Union[bool, str, PaddingStrategy], optional 90 | The padding strategy, by default False 91 | truncation : Union[bool, str, TruncationStrategy], optional 92 | The truncation strategy, by default False 93 | max_length : Optional[int], optional 94 | Maximum sequence length, overriding the class variable, by default None 95 | pad_to_multiple_of : Optional[int], optional 96 | Padding parameters, by default None 97 | return_tensors : Optional[Union[str, TensorType]], optional 98 | Return tensors in `pt`, 'tf' or 'np', by default None 99 | return_token_type_ids : Optional[bool], optional 100 | Return token type ids, by default None 101 | return_attention_mask : Optional[bool], optional 102 | Return attention mask, by default None 103 | return_overflowing_tokens : bool, optional 104 | Return overflowing tokens, by default False 105 | return_special_tokens_mask : bool, optional 106 | Return special token mask, by default False 107 | return_length : bool, optional 108 | Return length, by default False 109 | 110 | Returns 111 | ------- 112 | BatchEncoding 113 | A BatchEncoding object 114 | """ 115 | if self.special_tokens is None: 116 | self.special_tokens = { 117 | "CLS": self.text2embeddings("[CLS]"), 118 | "SEP": self.text2embeddings("[SEP]"), 119 | } 120 | 121 | if add_special_tokens and text_pair: 122 | actual_max_length = self.max_length - len(self.special_tokens["SEP"]) * 2 - len(self.special_tokens["CLS"]) 123 | else: 124 | actual_max_length = self.max_length 125 | 126 | batch_outputs = {} 127 | text = text if isinstance(text, list) else [text] 128 | text_pair = text_pair if isinstance(text_pair, list) else 
[text_pair] 129 | 130 | if isinstance(padding, str): 131 | padding = PaddingStrategy(padding) 132 | 133 | if isinstance(truncation, str): 134 | truncation = TruncationStrategy(truncation) 135 | 136 | for first_text, second_text in zip_longest(text, text_pair, fillvalue=None): 137 | 138 | first_embeddings = self.text2embeddings(first_text) 139 | second_embeddings = self.text2embeddings(second_text) 140 | 141 | outputs = self.prepare_for_model( 142 | first_embeddings, 143 | second_embeddings, 144 | add_special_tokens=add_special_tokens, 145 | padding=PaddingStrategy.DO_NOT_PAD, # we pad in batch afterward 146 | truncation=truncation, 147 | max_length=max_length or actual_max_length, 148 | pad_to_multiple_of=None, # we pad in batch afterward 149 | return_attention_mask=False, # we pad in batch afterward 150 | return_token_type_ids=return_token_type_ids, 151 | return_overflowing_tokens=return_overflowing_tokens, 152 | return_special_tokens_mask=return_special_tokens_mask, 153 | return_length=return_length, 154 | return_tensors=None, # We convert the whole batch to tensors at the end 155 | prepend_batch_axis=False, 156 | ) 157 | 158 | for key, value in outputs.items(): 159 | if key not in batch_outputs: 160 | batch_outputs[key] = [] 161 | batch_outputs[key].append(value) 162 | 163 | batch_outputs = self.pad( 164 | batch_outputs, 165 | padding=padding, 166 | max_length=max_length or actual_max_length, 167 | pad_to_multiple_of=pad_to_multiple_of, 168 | return_attention_mask=return_attention_mask, 169 | ) 170 | 171 | batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) 172 | 173 | return batch_outputs 174 | 175 | def build_inputs_with_special_tokens( 176 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 177 | ) -> List[int]: 178 | if token_ids_1 is None: 179 | return token_ids_0 180 | 181 | return np.concatenate( 182 | [ 183 | self.special_tokens["CLS"], 184 | token_ids_0, 185 | self.special_tokens["SEP"], 186 | token_ids_1, 187 | 
self.special_tokens["SEP"], 188 | ], 189 | axis=0 190 | ) 191 | 192 | def prepare_for_model( 193 | self, 194 | ids: List[int], 195 | pair_ids: Optional[List[int]] = None, 196 | add_special_tokens: bool = True, 197 | padding: Union[bool, str, PaddingStrategy] = False, 198 | truncation: Union[bool, str, TruncationStrategy] = False, 199 | max_length: Optional[int] = None, 200 | stride: int = 0, 201 | pad_to_multiple_of: Optional[int] = None, 202 | return_tensors: Optional[Union[str, TensorType]] = None, 203 | return_token_type_ids: Optional[bool] = None, 204 | return_attention_mask: Optional[bool] = None, 205 | return_overflowing_tokens: bool = False, 206 | return_special_tokens_mask: bool = False, 207 | return_length: bool = False, 208 | prepend_batch_axis: bool = False, 209 | **kwargs 210 | ): 211 | 212 | pair = bool(pair_ids is not None) 213 | len_ids = len(ids) 214 | len_pair_ids = len(pair_ids) if pair else 0 215 | if return_token_type_ids and not add_special_tokens: 216 | raise ValueError( 217 | "Asking to return token_type_ids while setting add_special_tokens to False " 218 | "results in an undefined behavior. Please set add_special_tokens to True or " 219 | "set return_token_type_ids to None." 
220 | ) 221 | 222 | # Load from model defaults 223 | if return_token_type_ids is None: 224 | return_token_type_ids = "token_type_ids" in self.model_input_names 225 | if return_attention_mask is None: 226 | return_attention_mask = "attention_mask" in self.model_input_names 227 | 228 | encoded_inputs = {} 229 | 230 | # Compute the total size of the returned encodings 231 | total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) 232 | 233 | # Truncation: Handle max sequence length 234 | overflowing_tokens = [] 235 | if truncation != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: 236 | ids, pair_ids, overflowing_tokens = self.truncate_sequences( 237 | ids, 238 | pair_ids=pair_ids, 239 | num_tokens_to_remove=total_len - max_length, 240 | truncation_strategy=truncation, 241 | stride=stride, 242 | ) 243 | 244 | if return_overflowing_tokens: 245 | encoded_inputs["overflowing_tokens"] = overflowing_tokens 246 | encoded_inputs["num_truncated_tokens"] = total_len - max_length 247 | 248 | # Add special tokens 249 | if add_special_tokens: 250 | sequence = self.build_inputs_with_special_tokens(ids, pair_ids) 251 | token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) 252 | else: 253 | sequence = np.concatenate([ids, pair_ids], axis=0) if pair is True else ids 254 | token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) 255 | 256 | # Build output dictionary 257 | encoded_inputs["input_ids"] = sequence 258 | 259 | if return_token_type_ids: 260 | encoded_inputs["token_type_ids"] = token_type_ids 261 | if return_special_tokens_mask: 262 | if add_special_tokens: 263 | encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) 264 | else: 265 | encoded_inputs["special_tokens_mask"] = [0] * len(sequence) 266 | 267 | # Padding 268 | if padding != PaddingStrategy.DO_NOT_PAD or return_attention_mask: 269 | encoded_inputs = self.pad( 270 | 
encoded_inputs, 271 | max_length=max_length, 272 | padding=padding, 273 | pad_to_multiple_of=pad_to_multiple_of, 274 | return_attention_mask=return_attention_mask, 275 | ) 276 | 277 | if return_length: 278 | encoded_inputs["length"] = len(encoded_inputs["input_ids"]) 279 | 280 | batch_outputs = BatchEncoding( 281 | encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis 282 | ) 283 | 284 | return batch_outputs 285 | 286 | def num_special_tokens_to_add(self, pair: bool = False) -> int: 287 | return 0 if not pair else 3 288 | 289 | def get_special_tokens_mask( 290 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 291 | ) -> List[int]: 292 | if token_ids_1 is None: 293 | return [0 for _ in token_ids_0] 294 | return [1 for _ in self.special_tokens["CLS"]] + [0 for _ in token_ids_0] + [1 for _ in self.special_tokens["SEP"]] + [0 for _ in token_ids_1] + [1 for _ in self.special_tokens["SEP"]] 295 | 296 | def create_token_type_ids_from_sequences( 297 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 298 | ) -> List[int]: 299 | 300 | if token_ids_1 is None: 301 | return len(token_ids_0) * [0] 302 | return [0]*len(self.special_tokens["CLS"]) + [0] * len(token_ids_0) + [0]*len(self.special_tokens["SEP"]) + [1] * len(token_ids_1) + [0]*len(self.special_tokens["SEP"]) 303 | 304 | def pad( 305 | self, 306 | encoded_inputs: Union[ 307 | BatchEncoding, 308 | List[BatchEncoding], 309 | Dict[str, EncodedInput], 310 | Dict[str, List[EncodedInput]], 311 | List[Dict[str, EncodedInput]], 312 | ], 313 | padding: Union[bool, str, PaddingStrategy] = True, 314 | max_length: Optional[int] = None, 315 | pad_to_multiple_of: Optional[int] = None, 316 | return_attention_mask: Optional[bool] = None, 317 | return_tensors: Optional[Union[str, TensorType]] = None, 318 | ) -> BatchEncoding: 319 | 320 | # If we have a list of dicts, let's convert it in a dict of lists 321 | # We do this to allow using this method as a collate_fn 
function in PyTorch Dataloader 322 | if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)): 323 | encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()} 324 | 325 | # The model's main input name, usually `input_ids`, has be passed for padding 326 | if self.model_input_names[0] not in encoded_inputs: 327 | raise ValueError( 328 | "You should supply an encoding or a list of encodings to this method" 329 | f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}" 330 | ) 331 | 332 | required_input = encoded_inputs[self.model_input_names[0]] 333 | 334 | if required_input is None: 335 | if return_attention_mask: 336 | encoded_inputs["attention_mask"] = [] 337 | return encoded_inputs 338 | 339 | # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects 340 | # and rebuild them afterwards if no return_tensors is specified 341 | # Note that we lose the specific device the tensor may be on for PyTorch 342 | 343 | first_element = required_input[0] 344 | if isinstance(first_element, (list, tuple)): 345 | # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element. 346 | index = 0 347 | while len(required_input[index]) == 0: 348 | index += 1 349 | if index < len(required_input): 350 | first_element = required_input[index][0] 351 | # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. 352 | if not isinstance(first_element, (int, list, tuple)): 353 | if is_torch_available() and is_torch(first_element): 354 | return_tensors = "pt" if return_tensors is None else return_tensors 355 | elif isinstance(first_element, np.ndarray): 356 | return_tensors = "np" if return_tensors is None else return_tensors 357 | else: 358 | raise ValueError( 359 | f"type of {first_element} unknown: {type(first_element)}. 
" 360 | f"Should be one of a python, numpy or pytorch object." 361 | ) 362 | 363 | for key, value in encoded_inputs.items(): 364 | encoded_inputs[key] = to_py_obj(value) 365 | 366 | required_input = encoded_inputs[self.model_input_names[0]] 367 | if required_input and not isinstance(required_input[0], (list, tuple)): 368 | encoded_inputs = self._pad( 369 | encoded_inputs, 370 | max_length=max_length, 371 | padding_strategy=padding, 372 | pad_to_multiple_of=pad_to_multiple_of, 373 | return_attention_mask=return_attention_mask, 374 | ) 375 | return BatchEncoding(encoded_inputs, tensor_type=return_tensors) 376 | 377 | batch_size = len(required_input) 378 | assert all( 379 | len(v) == batch_size for v in encoded_inputs.values() 380 | ), "Some items in the output dictionary have a different batch size than others." 381 | 382 | if padding == PaddingStrategy.LONGEST: 383 | max_length = max(len(inputs) for inputs in required_input) 384 | padding = PaddingStrategy.MAX_LENGTH 385 | 386 | batch_outputs = {} 387 | for i in range(batch_size): 388 | inputs = dict((k, v[i]) for k, v in encoded_inputs.items()) 389 | outputs = self._pad( 390 | inputs, 391 | max_length=max_length, 392 | padding_strategy=padding, 393 | pad_to_multiple_of=pad_to_multiple_of, 394 | return_attention_mask=return_attention_mask, 395 | ) 396 | 397 | for key, value in outputs.items(): 398 | if key not in batch_outputs: 399 | batch_outputs[key] = [] 400 | batch_outputs[key].append(value) 401 | 402 | return BatchEncoding(batch_outputs, tensor_type=return_tensors) 403 | 404 | def create_padding_token_embedding(self, input_embeddings=None): 405 | raise NotImplementedError('This function is not implemented') 406 | 407 | def _pad( 408 | self, 409 | encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], 410 | max_length: Optional[int] = None, 411 | padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, 412 | pad_to_multiple_of: Optional[int] = None, 413 | return_attention_mask: Optional[bool] = 
None, 414 | ) -> dict: 415 | 416 | # Load from model defaults 417 | if return_attention_mask is None: 418 | return_attention_mask = "attention_mask" in self.model_input_names 419 | 420 | required_input = encoded_inputs[self.model_input_names[0]] 421 | if padding_strategy == PaddingStrategy.LONGEST: 422 | max_length = len(required_input) 423 | 424 | if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): 425 | max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of 426 | 427 | needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length 428 | 429 | if needs_to_be_padded: 430 | difference = max_length - len(required_input) 431 | if "token_type_ids" in encoded_inputs and isinstance(encoded_inputs["token_type_ids"], int): 432 | encoded_inputs["token_type_ids"] = [encoded_inputs["token_type_ids"]] 433 | if self.padding_side == "right": 434 | if return_attention_mask: 435 | encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference 436 | if "token_type_ids" in encoded_inputs: 437 | encoded_inputs["token_type_ids"] = ( 438 | encoded_inputs["token_type_ids"] + [1] * difference 439 | ) 440 | if "special_tokens_mask" in encoded_inputs: 441 | encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference 442 | 443 | encoded_inputs[self.model_input_names[0]] = required_input + [self.create_padding_token_embedding(input_embeddings=required_input)] * difference 444 | elif self.padding_side == "left": 445 | if return_attention_mask: 446 | encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input) 447 | if "token_type_ids" in encoded_inputs: 448 | encoded_inputs["token_type_ids"] = [0] * difference + encoded_inputs[ 449 | "token_type_ids" 450 | ] 451 | if "special_tokens_mask" in encoded_inputs: 452 | encoded_inputs["special_tokens_mask"] = [1] * difference + 
encoded_inputs["special_tokens_mask"] 453 | encoded_inputs[self.model_input_names[0]] = [self.create_padding_token_embedding(input_embeddings=required_input)] * difference + required_input 454 | else: 455 | raise ValueError("Invalid padding strategy:" + str(self.padding_side)) 456 | elif return_attention_mask and "attention_mask" not in encoded_inputs: 457 | encoded_inputs["attention_mask"] = [1] * len(required_input) 458 | 459 | return encoded_inputs -------------------------------------------------------------------------------- /text_embeddings/byte/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-07-18 14:36:09 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | from text_embeddings.byte.byt5 import ByT5Tokenizer 7 | from text_embeddings.byte.charformer import ByteTokenizer, GBST 8 | 9 | __all__ = ['ByT5Tokenizer', 'GBST', 'ByteTokenizer'] 10 | -------------------------------------------------------------------------------- /text_embeddings/byte/byt5.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-06-02 08:10:13 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | """From ByT5: Towards a token-free future with pre-trained byte-to-byte models.""" 7 | 8 | import numpy as np 9 | from typing import Optional, List, Dict 10 | from text_embeddings.base import EmbeddingTokenizer 11 | from loguru import logger 12 | 13 | 14 | class ByT5Tokenizer(EmbeddingTokenizer): 15 | """Embed text into byte sequences. This is different from other tokenizers because it still has a small vocabulary where each byte is mapped to an index. 
16 | 17 | Parameters 18 | ---------- 19 | embed_size : int, optional 20 | The size of the embedding, by default 259 (256 + 3 special tokens) 21 | model_input_names : Optional[List[str]], optional 22 | Required inputs of the downstream model, by default it uses the same names as a BERT — ["input_ids", "token_type_ids", "attention_mask"] 23 | special_tokens : Optional[Dict[str, np.ndarray]], optional 24 | Special tokens for the downstream model, by default it uses the same special tokens as a BERT — {"CLS": "[CLS]", "SEP": "[SEP]"} 25 | max_length : Optional[int], optional 26 | Maximum character length, by default 1024 27 | 28 | Examples 29 | -------- 30 | >>> tokenizer = ByT5Tokenizer() 31 | >>> e = tokenizer.text2embeddings("This is a test message") 32 | >>> e.shape 33 | (22, 259) 34 | >>> np.equal(np.max(e, axis=1), np.ones((len(e)))).all() 35 | True 36 | """ 37 | 38 | def __init__( 39 | self, 40 | embed_size: int = 259, 41 | model_input_names: Optional[List[str]] = None, 42 | special_tokens: Optional[Dict[str, np.ndarray]] = None, 43 | max_length: Optional[int] = 1024, 44 | ): 45 | super().__init__(model_input_names, special_tokens, max_length) 46 | self.embed_size = embed_size 47 | self.model_input_names = model_input_names 48 | self.special_tokens = special_tokens 49 | self.max_length = max_length 50 | 51 | if self.model_input_names is None: 52 | logger.warning( 53 | 'Using default model_input_names values ["input_ids", "token_type_ids", "attention_mask"]' 54 | ) 55 | self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"] 56 | 57 | if self.special_tokens is None: 58 | logger.warning("Using default special_tokens values") 59 | self.special_tokens = { 60 | "SEP": np.zeros((self.embed_size,)), 61 | "CLS": np.zeros((self.embed_size,)), 62 | } 63 | self.special_tokens["CLS"][1] = 1 64 | self.special_tokens["SEP"][2] = 1 65 | 66 | logger.info("Be sure to add an embedding layer when using a ByT5Tokenizer.") 67 | 68 | def text2embeddings(self, text: 
str) -> np.ndarray: 69 | """Convert text into an numpy array, in (sequence_length, embed_size) shape. 70 | 71 | Parameters 72 | ---------- 73 | text : str 74 | Input text 75 | 76 | Returns 77 | ------- 78 | np.ndarray 79 | An array in (sequence_length, embed_size) shape 80 | """ 81 | if not text: 82 | return None 83 | 84 | b = text.encode("utf-8", errors="ignore") 85 | 86 | result = np.zeros((len(b), self.embed_size)) 87 | for i, byte in enumerate(b): 88 | result[i][byte + 3] = 1 89 | 90 | return result 91 | 92 | def create_padding_token_embedding(self, input_embeddings=None) -> np.ndarray: 93 | """Create a padding token embedding. 94 | 95 | Parameters 96 | ---------- 97 | input_embeddings : np.ndarray, optional 98 | Embedded input, by default None 99 | 100 | Returns 101 | ------- 102 | np.ndarray 103 | A padding token embedding compatible with the input 104 | """ 105 | e = np.zeros((self.embed_size,)) 106 | e[0] = 1 107 | return e 108 | -------------------------------------------------------------------------------- /text_embeddings/byte/charformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-06-27 09:47:26 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | 7 | """This is from paper Charformer: Fast Character Transformers via Gradient-based Subword Tokenization.""" 8 | import math 9 | from typing import Dict, List, Optional 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from einops import rearrange, repeat 15 | from loguru import logger 16 | from transformers.file_utils import PaddingStrategy 17 | from text_embeddings.base import EmbeddingTokenizer 18 | 19 | 20 | class ByteTokenizer(EmbeddingTokenizer): 21 | """Embed text into byte sequences. This is different from other tokenizers because it still needs a small vocabulary where each byte is mapped to an index. 
22 | 23 | Parameters 24 | ---------- 25 | model_input_names : Optional[List[str]], optional 26 | Required inputs of the downstream model, by default it uses the same names as a BERT — ["input_ids", "token_type_ids", "attention_mask"] 27 | special_tokens : Optional[Dict[str, np.ndarray]], optional 28 | Special tokens for the downstream model, by default it uses the same special tokens as a BERT — {"CLS": "[CLS]", "SEP": "[SEP]"} 29 | max_length : Optional[int], optional 30 | Maximum character length, by default 1024 31 | 32 | Examples 33 | -------- 34 | >>> from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy 35 | >>> tokenizer = ByteTokenizer() 36 | >>> e = tokenizer.text2embeddings("This is a test message") 37 | >>> e.shape 38 | (22, 1) 39 | >>> r = tokenizer(["This is a test message", "This is another test message"], padding=PaddingStrategy.LONGEST) 40 | >>> r["input_ids"].shape 41 | (2, 28) 42 | """ 43 | 44 | def __init__( 45 | self, 46 | model_input_names: Optional[List[str]] = None, 47 | special_tokens: Optional[Dict[str, np.ndarray]] = None, 48 | max_length: Optional[int] = 1024, 49 | ): 50 | super().__init__(model_input_names, special_tokens, max_length) 51 | self.embed_size = 1 52 | self.model_input_names = model_input_names 53 | self.special_tokens = special_tokens 54 | self.max_length = max_length 55 | 56 | if self.model_input_names is None: 57 | logger.warning( 58 | 'Using default model_input_names values ["input_ids", "token_type_ids", "attention_mask"]' 59 | ) 60 | self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"] 61 | 62 | if self.special_tokens is None: 63 | logger.warning("Using default special_tokens values") 64 | self.special_tokens = { 65 | "SEP": np.zeros((self.embed_size,)), 66 | "CLS": np.zeros((self.embed_size,)), 67 | } 68 | self.special_tokens["CLS"] = 1 69 | self.special_tokens["SEP"] = 2 70 | 71 | logger.info("Be sure to add an embedding layer when using a ByteTokenizer.") 72 | 73 | 
def text2embeddings(self, text: str) -> np.ndarray: 74 | """Convert text into an numpy array, in (sequence_length, embed_size) shape. 75 | 76 | Parameters 77 | ---------- 78 | text : str 79 | Input text 80 | 81 | Returns 82 | ------- 83 | np.ndarray 84 | An array in (sequence_length, embed_size) shape 85 | """ 86 | if not text: 87 | return None 88 | 89 | b = text.encode("utf-8", errors="ignore") 90 | 91 | result = np.zeros((len(b), self.embed_size)) 92 | for i, byte in enumerate(b): 93 | result[i] = byte + len(self.special_tokens) + 1 94 | 95 | return result 96 | 97 | def create_padding_token_embedding(self, input_embeddings=None) -> np.ndarray: 98 | """Create a padding token embedding. 99 | 100 | Parameters 101 | ---------- 102 | input_embeddings : np.ndarray, optional 103 | Embedded input, by default None 104 | 105 | Returns 106 | ------- 107 | np.ndarray 108 | A padding token embedding compatible with the input 109 | """ 110 | e = np.zeros((self.embed_size,)) 111 | return e 112 | 113 | def __call__(self, *args, **kwargs): 114 | results = super().__call__(*args, **kwargs) 115 | results["input_ids"] = np.squeeze(results["input_ids"], axis=-1) 116 | return results 117 | 118 | 119 | class PositionalEncoding(nn.Module): 120 | def __init__(self, d_model, dropout=0.1, max_len=5000): 121 | super(PositionalEncoding, self).__init__() 122 | self.dropout = nn.Dropout(p=dropout) 123 | 124 | pe = torch.zeros(max_len, d_model) 125 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 126 | div_term = torch.exp( 127 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) 128 | ) 129 | pe[:, 0::2] = torch.sin(position * div_term) 130 | pe[:, 1::2] = torch.cos(position * div_term) 131 | pe = pe.unsqueeze(0).transpose(0, 1) 132 | self.register_buffer("pe", pe) 133 | 134 | def forward(self, x): 135 | 136 | x = rearrange(x, "b h l -> l b h") 137 | pe_ = repeat( 138 | self.pe, 139 | "s b h -> (repeat s) b h", 140 | repeat=torch.div(x.shape[0], 
self.pe.shape[0], rounding_mode="trunc") + 1, 141 | ) 142 | x = x + pe_[: x.shape[0], :] 143 | return rearrange(self.dropout(x), "l b h -> b h l") 144 | 145 | 146 | class GBST(nn.Module): 147 | """Gradient-based Subword Tokenization module from the paper: 148 | Charformer: Fast Character Transformers via Gradient-based Subword Tokenization. 149 | 150 | Parameters 151 | ---------- 152 | embed_size : int, optional 153 | The embedding size for each byte/character, by default 259 154 | max_block_size : int, optional 155 | Every subword token of length from 1 to max_block_size are considered, by default 4 156 | downsampling_factor : int, optional 157 | Downsampling rate from byte sequence to the final sequence, by default 2 158 | score_calibration : bool, optional 159 | To calibrate the scores with a self-attention like step, by default True 160 | vocab_size : int, optional 161 | The size of the byte vocabulary, by default 256 162 | 163 | Examples 164 | -------- 165 | >>> model = GBST( 166 | ... embed_size=128, 167 | ... max_block_size=4, 168 | ... downsampling_factor=2, 169 | ... score_calibration=True, 170 | ... vocab_size=256, 171 | ... 
) 172 | >>> tokenizer = ByteTokenizer() 173 | >>> results = tokenizer(["Life is like a box of chocolates.", "Coding is fun."], add_special_tokens=True) 174 | >>> results["input_ids"].shape 175 | (2, 1024) 176 | >>> hidden = model(torch.tensor(results["input_ids"]).long()) 177 | >>> hidden.shape 178 | torch.Size([2, 512, 128]) 179 | """ 180 | 181 | def __init__( 182 | self, 183 | embed_size: int = 256, 184 | max_block_size: int = 4, 185 | downsampling_factor: int = 2, 186 | score_calibration: bool = True, 187 | vocab_size: int = 256, 188 | ): 189 | super().__init__() 190 | self.vocab_size = vocab_size 191 | self.max_block_size = max_block_size 192 | self.score_calibration = score_calibration 193 | self.downsampling_factor = downsampling_factor 194 | self.embed_size = embed_size 195 | 196 | self.byte_embedding = nn.Embedding( 197 | self.vocab_size, self.embed_size, padding_idx=0 198 | ) 199 | self.block_position_embedding = PositionalEncoding( 200 | self.embed_size, max_len=self.max_block_size 201 | ) 202 | 203 | self.avg_pools = nn.ModuleDict( 204 | { 205 | str(i): nn.AvgPool1d(i, ceil_mode=True) 206 | for i in range(1, self.max_block_size + 1) 207 | } 208 | ) 209 | self.block_scorer = nn.Linear(self.embed_size, 1) 210 | self.down_sampler = nn.AvgPool1d(self.downsampling_factor, ceil_mode=True) 211 | 212 | self.apply(self._init_weights) 213 | 214 | def _init_weights(self, module): 215 | if isinstance(module, nn.Linear): 216 | module.weight.data.normal_(mean=0.0, std=0.02) 217 | if module.bias is not None: 218 | module.bias.data.zero_() 219 | elif isinstance(module, nn.Embedding): 220 | module.weight.data.normal_(mean=0.0, std=0.02) 221 | if module.padding_idx is not None: 222 | module.weight.data[module.padding_idx].zero_() 223 | elif isinstance(module, nn.LayerNorm): 224 | module.bias.data.zero_() 225 | module.weight.data.fill_(1.0) 226 | 227 | def forward(self, input): 228 | 229 | byte_embeddings = self.byte_embedding(input) 230 | sequence_length = 
byte_embeddings.shape[1] 231 | 232 | Xs = [] 233 | X_scores = [] 234 | 235 | for block_size in range(1, self.max_block_size + 1): 236 | positioned_embeddings = rearrange(byte_embeddings, "b l h -> b h l") 237 | positioned_embeddings = self.block_position_embedding(positioned_embeddings) 238 | 239 | # b h s 240 | Xb = self.avg_pools[str(block_size)](positioned_embeddings) 241 | # b 1 s 242 | Xb_scores = rearrange( 243 | self.block_scorer(rearrange(Xb, "b h s -> b s h")), "b s 1 -> b 1 s" 244 | ) 245 | # b h l 246 | Xb_ = Xb.repeat_interleave(repeats=block_size, dim=2) 247 | # b 1 l 248 | Xb_scores_ = Xb_scores.repeat_interleave(repeats=block_size, dim=2) 249 | 250 | Xs.append(Xb_[:, :, :sequence_length]) 251 | X_scores.append(Xb_scores_[:, :, :sequence_length]) 252 | 253 | # b M l 254 | scores = torch.cat(X_scores, dim=1) 255 | # b l M 1 256 | scores = rearrange(torch.softmax(scores, dim=1), "b M l -> b l M 1") 257 | 258 | if self.score_calibration: 259 | # b l M 1 260 | scores = ( 261 | torch.softmax(scores @ rearrange(scores, "b l M 1 -> b l 1 M"), dim=-1) 262 | @ scores 263 | ) 264 | 265 | # b l h M 266 | Xs = rearrange(torch.stack(Xs, dim=0), "M b h l -> b l h M") 267 | Xs = rearrange(Xs @ scores, "b l h 1 -> b h l") 268 | Xs = rearrange(self.down_sampler(Xs), "b h s -> b s h") 269 | 270 | return Xs 271 | 272 | 273 | if __name__ == "__main__": 274 | 275 | import torch.onnx # nightly torch only 276 | from transformers.tokenization_utils_base import PaddingStrategy 277 | 278 | model = GBST( 279 | embed_size=128, 280 | max_block_size=4, 281 | downsampling_factor=2, 282 | score_calibration=True, 283 | vocab_size=259, 284 | ) 285 | 286 | tokenizer = ByteTokenizer() 287 | results = tokenizer( 288 | ["Life is like a box of chocolates.", "Coding is fun."], 289 | add_special_tokens=True, 290 | padding=PaddingStrategy.LONGEST, 291 | truncation="longest_first", 292 | ) 293 | 294 | # Export the model 295 | torch.onnx.export( 296 | model, 297 | 
torch.tensor(results["input_ids"], requires_grad=True).long(), 298 | "gbst.onnx", 299 | export_params=True, 300 | opset_version=11, 301 | do_constant_folding=True, 302 | input_names=["input"], 303 | output_names=["output"], 304 | dynamic_axes={ 305 | "input": {0: "batch_size", 1: "sequence_length"}, 306 | "output": {0: "batch_size"}, 307 | }, 308 | ) 309 | -------------------------------------------------------------------------------- /text_embeddings/hash/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-04-22 20:58:54 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | """Hash related tokenizers.""" 7 | 8 | from .canine import CANINETokenizer 9 | from .pqrnn import PQRNNTokenizer 10 | 11 | __all__ = ['PQRNNTokenizer', 'CANINETokenizer'] 12 | -------------------------------------------------------------------------------- /text_embeddings/hash/canine.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-04-18 09:06:29 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | """From CANINE: Pre-training an Efficient Tokenization-Free Encoder for Language Representation.""" 7 | 8 | import numpy as np 9 | from typing import Optional, List, Dict 10 | from text_embeddings.hash.util import murmurhash 11 | from text_embeddings.base import EmbeddingTokenizer 12 | from loguru import logger 13 | 14 | 15 | class CANINETokenizer(EmbeddingTokenizer): 16 | """ 17 | A character hashing tokenizer/embedder from [CANINE: Pre-training an Efficient Tokenization-Free Encoder for Language Representation](https://arxiv.org/abs/2103.06874) 18 | 19 | Parameters 20 | ---------- 21 | hash_size : int, optional 22 | The embedding size of each character, by default 768 23 | model_input_names : Optional[List[str]], optional 24 | Required inputs of the downstream 
model, by default it uses the same names as a BERT — ["input_ids", "token_type_ids", "attention_mask"] 25 | special_tokens : Optional[Dict[str, np.ndarray]], optional 26 | Special tokens for the downstream model, by default it uses the same special tokens as a BERT — {"CLS": "[CLS]", "SEP": "[SEP]"} 27 | max_length : Optional[int], optional 28 | Maximum character length, by default 2048 29 | 30 | Examples 31 | -------- 32 | >>> from text_embeddings.hash import CANINETokenizer 33 | >>> from transformers.tokenization_utils_base import * 34 | >>> tokenier = CANINETokenizer() 35 | >>> results = tokenier(text=['This is a sentence.', 'This is another sentence.'], padding=PaddingStrategy.LONGEST, truncation="longest_first", add_special_tokens=False) 36 | >>> assert results['input_ids'].shape == (2, 25, 768), results['input_ids'].shape 37 | """ 38 | 39 | def __init__( 40 | self, 41 | hash_size: int = 768, 42 | model_input_names: Optional[List[str]] = None, 43 | special_tokens: Optional[Dict[str, np.ndarray]] = None, 44 | max_length: Optional[int] = 2048, 45 | ): 46 | super().__init__(model_input_names, special_tokens, max_length) 47 | self.hash_size = hash_size 48 | self.model_input_names = model_input_names 49 | self.special_tokens = special_tokens 50 | self.max_length = max_length 51 | 52 | if self.model_input_names is None: 53 | logger.warning( 54 | 'Using default model_input_names values ["input_ids", "token_type_ids", "attention_mask"]' 55 | ) 56 | self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"] 57 | 58 | def text2embeddings(self, text: str) -> np.ndarray: 59 | """Convert text into an numpy array, in (sequence_length, hash_size) shape. 
60 | 61 | Parameters 62 | ---------- 63 | text : str 64 | Input text 65 | 66 | Returns 67 | ------- 68 | np.ndarray 69 | An array in (sequence_length, hash_size) shape 70 | """ 71 | if not text: 72 | return None 73 | 74 | result = np.zeros((len(text), self.hash_size)) 75 | for i, char in enumerate(text): 76 | result[i] = murmurhash(char, feature_size=self.hash_size * 2) 77 | 78 | return result 79 | 80 | def create_padding_token_embedding(self, input_embeddings=None) -> np.ndarray: 81 | """Create a padding token embedding. 82 | 83 | Parameters 84 | ---------- 85 | input_embeddings : [type], optional 86 | Embeddings already encoded, by default None 87 | 88 | Returns 89 | ------- 90 | np.ndarray 91 | An embedding array in (hash_size) 92 | """ 93 | return np.zeros((self.hash_size,)) 94 | -------------------------------------------------------------------------------- /text_embeddings/hash/pqrnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-04-18 09:06:29 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | """https://ai.googleblog.com/2020/09/advancing-nlp-with-efficient-projection.html""" 7 | 8 | import numpy as np 9 | from typing import Optional, List, Dict 10 | from text_embeddings.hash.util import murmurhash 11 | from text_embeddings.base import EmbeddingTokenizer 12 | from loguru import logger 13 | 14 | class PQRNNTokenizer(EmbeddingTokenizer): 15 | """ 16 | Boundary-based hashing embeddings based on [PQRNN](https://ai.googleblog.com/2020/09/advancing-nlp-with-efficient-projection.html) 17 | 18 | Parameters 19 | ---------- 20 | hash_size : int, optional 21 | The size of the hashing embedding, by default 768 22 | model_input_names : Optional[List[str]], optional 23 | Required inputs of the downstream model, by default it uses the same names as a BERT — ["input_ids", "token_type_ids", "attention_mask"] 24 | special_tokens : Optional[Dict[str, 
np.ndarray]], optional 25 | Special tokens for the downstream model, by default it uses the same special tokens as a BERT — {"CLS": "[CLS]", "SEP": "[SEP]"} 26 | max_length : Optional[int], optional 27 | Maximum token length, by default 2048 28 | 29 | Examples 30 | -------- 31 | >>> from text_embeddings.hash import PQRNNTokenizer 32 | >>> from transformers.tokenization_utils_base import * 33 | >>> tokenier = PQRNNTokenizer() 34 | >>> results = tokenier(text=['This is a sentence.', 'This is another sentence.'], padding=PaddingStrategy.LONGEST, truncation="longest_first", add_special_tokens=False) 35 | >>> assert results['input_ids'].shape == (2, 4, 768) 36 | """ 37 | 38 | def __init__( 39 | self, 40 | hash_size: int = 768, 41 | model_input_names: Optional[List[str]] = None, 42 | special_tokens: Optional[Dict[str, np.ndarray]] = None, 43 | max_length: Optional[int] = 2048, 44 | ): 45 | super().__init__(model_input_names, special_tokens, max_length) 46 | self.hash_size = hash_size 47 | self.model_input_names = model_input_names 48 | self.special_tokens = special_tokens 49 | self.max_length = max_length 50 | 51 | if self.model_input_names is None: 52 | logger.warning('Using default model_input_names values ["input_ids", "token_type_ids", "attention_mask"]') 53 | self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"] 54 | 55 | def text2embeddings(self, text: str) -> np.ndarray: 56 | """Convert text into an numpy array, in (sequence_length, 1, hash_size) shape. 
57 | 58 | Parameters 59 | ---------- 60 | text : str 61 | Input text 62 | 63 | Returns 64 | ------- 65 | np.ndarray 66 | An array in (sequence_length, 1, hash_size) shape 67 | """ 68 | if not text: 69 | return None 70 | 71 | tokens = text.split(" ") 72 | result = np.zeros((len(tokens), self.hash_size)) 73 | for i, token in enumerate(tokens): 74 | result[i] = murmurhash(token, feature_size=self.hash_size*2) 75 | 76 | return result 77 | 78 | def create_padding_token_embedding(self, input_embeddings=None) -> np.ndarray: 79 | """Create a padding token embedding 80 | 81 | Parameters 82 | ---------- 83 | input_embeddings : [type], optional 84 | [description], by default None 85 | 86 | Returns 87 | ------- 88 | np.ndarray 89 | An empty embedding in (hash_size, ) shape 90 | """ 91 | 92 | return np.zeros((self.hash_size, )) 93 | -------------------------------------------------------------------------------- /text_embeddings/hash/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-04-22 21:03:09 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | """Python translation of the original code in tensorflow.""" 7 | 8 | from typing import List 9 | 10 | import mmh3 11 | 12 | kMul: int = 0xC6A4A7935BD1E995 13 | kMul2: int = 0x9E3779B97F4A7835 14 | 15 | kMappingTable: List[int] = [0, 1, -1, 0] 16 | 17 | 18 | def shift_mix(val): 19 | return val ^ (val >> 47) 20 | 21 | 22 | def get_more_bits(hash1, hash2): 23 | 24 | hash1 = shift_mix(hash1) * kMul 25 | hash2 ^= hash1 26 | newhigh = shift_mix(hash1) 27 | newlow = shift_mix(hash2 * kMul2) * kMul2 28 | 29 | return newlow, newhigh 30 | 31 | 32 | def murmurhash(token: str, feature_size: int = 512) -> List[int]: 33 | """ 34 | Hash a token into a list of feature_size integers (-1, 0, or 1). 
35 | 36 | Parameters 37 | ---------- 38 | token : str 39 | Input token string 40 | feature_size : int, optional 41 | The target size of the hash embedding, by default 512 42 | 43 | Returns 44 | ------- 45 | List[int] 46 | A list of feature_size trinary integers 47 | """ 48 | hash_low = 0 49 | hash_high = 0 50 | hash_codes = [] 51 | 52 | for i in range(0, feature_size, 64): 53 | if i == 0: 54 | hash_low, hash_high = mmh3.hash64(token, signed=False) 55 | else: 56 | hash_low, hash_high = get_more_bits(hash_low, hash_high) 57 | hash_codes.append(hash_low) 58 | hash_codes.append(hash_high) 59 | 60 | projection: List[int] = [] 61 | for code in hash_codes: 62 | while code: 63 | if len(projection) >= feature_size // 2: 64 | break 65 | projection.append(kMappingTable[code & 3]) 66 | code = code >> 2 67 | if len(projection) >= feature_size // 2: 68 | break 69 | return projection[: feature_size // 2] 70 | -------------------------------------------------------------------------------- /text_embeddings/pruning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenghaoMou/embeddings/d414ebfbfc50eb218727ceac207035bf731c4701/text_embeddings/pruning/__init__.py -------------------------------------------------------------------------------- /text_embeddings/pruning/ltp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-07-18 14:39:02 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch import Tensor 12 | 13 | 14 | class LTPMultiHeadAttention(nn.MultiheadAttention): 15 | def __init__( 16 | self, 17 | temperature, 18 | embed_dim, 19 | num_heads, 20 | dropout=0.0, 21 | bias=True, 22 | add_bias_kv=False, 23 | add_zero_attn=False, 24 | kdim=None, 25 | vdim=None, 
26 | batch_first=False, 27 | device=None, 28 | dtype=None, 29 | ) -> None: 30 | """ 31 | Examples 32 | -------- 33 | >>> attention = LTPMultiHeadAttention(1, 512, 8, 0.5, batch_first=True) 34 | >>> x = torch.rand((10, 128, 512)) 35 | >>> output, weights, norm = attention(x, x, x) 36 | >>> output.shape 37 | torch.Size([10, 128, 512]) 38 | >>> weights.shape 39 | torch.Size([10, 128, 128]) 40 | """ 41 | super().__init__( 42 | embed_dim, 43 | num_heads, 44 | dropout, 45 | bias, 46 | add_bias_kv, 47 | add_zero_attn, 48 | kdim, 49 | vdim, 50 | batch_first, 51 | device, 52 | dtype, 53 | ) 54 | self.temperature = temperature 55 | self.soft_threshold = nn.Parameter(torch.rand(1), requires_grad=True) 56 | 57 | def forward( 58 | self, 59 | query: Tensor, 60 | key: Tensor, 61 | value: Tensor, 62 | key_padding_mask: Optional[Tensor] = None, 63 | need_weights: bool = True, 64 | attn_mask: Optional[Tensor] = None, 65 | ) -> Tuple[Tensor, Optional[Tensor]]: 66 | r""" 67 | Args: 68 | query, key, value: map a query and a set of key-value pairs to an output. 69 | See "Attention Is All You Need" for more details. 70 | key_padding_mask: if provided, specified padding elements in the key will 71 | be ignored by the attention. When given a binary mask and a value is True, 72 | the corresponding value on the attention layer will be ignored. When given 73 | a byte mask and a value is non-zero, the corresponding value on the attention 74 | layer will be ignored 75 | need_weights: output attn_output_weights. 76 | attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all 77 | the batches while a 3D mask allows to specify a different mask for the entries of each batch. 78 | 79 | Shapes for inputs: 80 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 81 | the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``. 
82 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 83 | the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``. 84 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 85 | the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``. 86 | - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 87 | If a ByteTensor is provided, the non-zero positions will be ignored while the position 88 | with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the 89 | value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. 90 | - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the 91 | source sequence length. 92 | 93 | If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence 94 | length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend 95 | the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend 96 | while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` 97 | is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor 98 | is provided, it will be added to the attention weight. 99 | 100 | Shapes for outputs: 101 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 102 | E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``. 103 | - attn_output_weights: :math:`(N, L, S)` where N is the batch size, 104 | L is the target sequence length, S is the source sequence length. 
105 | """ 106 | if self.batch_first: 107 | query, key, value = [x.transpose(1, 0) for x in (query, key, value)] 108 | 109 | if not self._qkv_same_embed_dim: 110 | attn_output, attn_output_weights = F.multi_head_attention_forward( 111 | query, 112 | key, 113 | value, 114 | self.embed_dim, 115 | self.num_heads, 116 | self.in_proj_weight, 117 | self.in_proj_bias, 118 | self.bias_k, 119 | self.bias_v, 120 | self.add_zero_attn, 121 | self.dropout, 122 | self.out_proj.weight, 123 | self.out_proj.bias, 124 | training=self.training, 125 | key_padding_mask=key_padding_mask, 126 | need_weights=need_weights, 127 | attn_mask=attn_mask, 128 | use_separate_proj_weight=True, 129 | q_proj_weight=self.q_proj_weight, 130 | k_proj_weight=self.k_proj_weight, 131 | v_proj_weight=self.v_proj_weight, 132 | ) 133 | else: 134 | attn_output, attn_output_weights = F.multi_head_attention_forward( 135 | query, 136 | key, 137 | value, 138 | self.embed_dim, 139 | self.num_heads, 140 | self.in_proj_weight, 141 | self.in_proj_bias, 142 | self.bias_k, 143 | self.bias_v, 144 | self.add_zero_attn, 145 | self.dropout, 146 | self.out_proj.weight, 147 | self.out_proj.bias, 148 | training=self.training, 149 | key_padding_mask=key_padding_mask, 150 | need_weights=need_weights, 151 | attn_mask=attn_mask, 152 | ) 153 | 154 | # (N, L, S) -> (N, S/L) 155 | scores = torch.mean(attn_output_weights, dim=1) 156 | pruning_mask = F.sigmoid((scores - self.soft_threshold) / self.temperature) 157 | attn_output = attn_output.transpose(1, 0) 158 | attn_output = pruning_mask[:, :, None] * attn_output 159 | 160 | if self.batch_first: 161 | return ( 162 | attn_output, 163 | attn_output_weights, 164 | torch.sum(torch.norm(pruning_mask, p=1) / self.num_heads), 165 | ) 166 | else: 167 | return ( 168 | attn_output.transpose(1, 0), 169 | attn_output_weights, 170 | torch.sum(torch.norm(pruning_mask, p=1) / self.num_heads), 171 | ) 172 | -------------------------------------------------------------------------------- 
/text_embeddings/visual/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-04-22 20:59:35 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | """Visual information based tokenizers.""" 7 | 8 | from .vtr import VTRTokenizer, text2image 9 | 10 | __all__ = ['VTRTokenizer', 'text2image'] -------------------------------------------------------------------------------- /text_embeddings/visual/vtr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-04-17 08:08:04 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | # @Description : From Robust Open-­Vocabulary Translation from Visual Text Representations 6 | 7 | """Robust Open­ Vocabulary Translation from Visual Text Representations""" 8 | 9 | from typing import List, Optional, Dict 10 | 11 | import numpy as np 12 | from PIL import Image 13 | from PIL import ImageDraw 14 | from PIL import ImageFont 15 | from numpy.lib.stride_tricks import sliding_window_view 16 | from loguru import logger 17 | 18 | from text_embeddings.base import EmbeddingTokenizer 19 | 20 | def text2image(text: str, font: str, font_size: int = 14) -> Image: 21 | """Convert text into an image and return the image. 
Reference: https://gist.github.com/destan/5540702 22 | 23 | Parameters 24 | ---------- 25 | text : str 26 | Text to encode 27 | font : str 28 | Name of the font to use 29 | font_size : int, optional 30 | Size of the font, by default 14 31 | 32 | Returns 33 | ------- 34 | Image 35 | Encoded image 36 | """ 37 | 38 | image_font = ImageFont.truetype(font, max(font_size - 2, 8)) 39 | text = text.replace("\n", " ") 40 | 41 | line_width, _ = image_font.getsize(text) 42 | 43 | img = Image.new("L", (line_width, font_size)) 44 | draw = ImageDraw.Draw(img) 45 | draw.text(xy=(0, 0), text=text, fill="#FFFFFF", font=image_font) 46 | 47 | return img 48 | 49 | class VTRTokenizer(EmbeddingTokenizer): 50 | """ 51 | Render the text into a series of image blocks. Reference [VTR](https://t.co/l9E6rL8O5p?amp=1) 52 | 53 | Parameters 54 | ---------- 55 | window_size : int, optional 56 | The width of the image window, by default 10 57 | stride: int optional 58 | The stride used to generate image windows, by default 10 59 | font : str, optional 60 | Path to the font file, by default "resources/Noto_Sans/NotoSans-Regular.ttf" 61 | font_size : int, optional 62 | The size of the font in pixels, might be smaller than the actual image height, by default 14 63 | model_input_names : List[str], optional 64 | Required inputs of the downstream model, by default it uses the same names as a BERT — ["input_ids", "token_type_ids", "attention_mask"] 65 | special_tokens : Optional[Dict[str, np.ndarray]], optional 66 | Special tokens for the downstream model, by default it uses the same special tokens as a BERT — {"CLS": "[CLS]", "SEP": "[SEP]"} 67 | max_length : Optional[int], optional 68 | Maximum sequence length, by default 25 69 | 70 | Examples 71 | -------- 72 | >>> from text_embeddings.visual import VTRTokenizer 73 | >>> from transformers.tokenization_utils_base import * 74 | >>> tokenier = VTRTokenizer() 75 | >>> results = tokenier(text=['This is a sentence.', 'This is another sentence.'], 
padding=PaddingStrategy.LONGEST, truncation="longest_first", add_special_tokens=False) 76 | >>> assert results['input_ids'].shape == (2, 13, 14, 10), results['input_ids'].shape 77 | """ 78 | 79 | def __init__( 80 | self, 81 | window_size: int = 10, 82 | stride: int = 10, 83 | font: str = "resources/Noto_Sans/NotoSans-Regular.ttf", 84 | font_size: int = 14, 85 | model_input_names: List[str] = None, 86 | special_tokens: Optional[Dict[str, np.ndarray]] = None, 87 | max_length: Optional[int] = 25, 88 | ): 89 | super().__init__(model_input_names, special_tokens, max_length) 90 | self.font_size = font_size 91 | self.window_size = window_size 92 | self.stride = stride 93 | self.font = font 94 | 95 | if self.model_input_names is None: 96 | logger.warning('Using default model_input_names values ["input_ids", "token_type_ids", "attention_mask"]') 97 | self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"] 98 | 99 | def text2embeddings(self, text: str) -> np.ndarray: 100 | """Convert text into an numpy array, in (sequence_length, font_size, window_size) shape. 101 | 102 | Parameters 103 | ---------- 104 | text : str 105 | Input text 106 | 107 | Returns 108 | ------- 109 | np.ndarray 110 | An array in (sequence_length, height, width) shape 111 | """ 112 | if not text: 113 | return None 114 | 115 | image = text2image(text, font=self.font, font_size=self.font_size) 116 | image_array = np.asarray(image) 117 | 118 | return np.squeeze( 119 | sliding_window_view(image_array, (image_array.shape[0], min(self.window_size, image_array.shape[1]))), 120 | axis=0, 121 | )[:: self.stride] 122 | 123 | def create_padding_token_embedding(self, input_embeddings=None) -> np.ndarray: 124 | """Create a padding token embedding for an empty window. 
125 | 126 | Parameters 127 | ---------- 128 | input_embeddings : [type], optional 129 | Embeddings already encoded, by default None 130 | 131 | Returns 132 | ------- 133 | np.ndarray 134 | An empty array in (font_size, window_size) shape 135 | """ 136 | return np.zeros((len(input_embeddings[0]), self.window_size)) 137 | -------------------------------------------------------------------------------- /text_embeddings/x/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021-08-19 12:47:53 4 | # @Author : Chenghao Mou (mouchenghao@gmail.com) 5 | 6 | """X is a Perceiver-based encoder model that incorporates byte hash embeddings, learned token pruning and layer wise adaptive computation (inspired from PonderNet).""" 7 | 8 | import math 9 | from typing import Callable 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from torch import Tensor 15 | from einops import repeat, rearrange 16 | from transformers import CanineModel 17 | 18 | class PositionalEncoding(nn.Module): 19 | 20 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first: bool = False): 21 | super().__init__() 22 | self.dropout = nn.Dropout(p=dropout) 23 | 24 | position = torch.arange(max_len).unsqueeze(1) 25 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 26 | pe = torch.zeros(max_len, 1, d_model) 27 | pe[:, 0, 0::2] = torch.sin(position * div_term) 28 | pe[:, 0, 1::2] = torch.cos(position * div_term) 29 | self.register_buffer('pe', pe) 30 | self.batch_first = batch_first 31 | 32 | def forward(self, x: Tensor) -> Tensor: 33 | """ 34 | Args: 35 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 36 | """ 37 | if self.batch_first: 38 | x = x.transpose(1, 0) 39 | x = x + self.pe[:x.size(0)] 40 | return self.dropout(x) if not self.batch_first else self.dropout(x).transpose(0, 1) 41 | 42 | 43 | class 
AttentionWrapper(nn.Module): 44 | def __init__( 45 | self, 46 | attention_class: Callable, 47 | embed_dim: int, 48 | num_heads: int, 49 | ff_dim: int, 50 | dropout: float, 51 | batch_first: bool, 52 | is_cross_attention: bool, 53 | ): 54 | super().__init__() 55 | 56 | self.is_cross_attention = is_cross_attention 57 | self.pre_attention_q_norm = nn.LayerNorm(embed_dim) 58 | self.pre_attention_kv_norm = ( 59 | nn.LayerNorm(embed_dim) if is_cross_attention else None 60 | ) 61 | 62 | self.attention = attention_class( 63 | embed_dim=embed_dim, 64 | num_heads=num_heads, 65 | dropout=dropout, 66 | batch_first=batch_first, 67 | ) 68 | self.ff = nn.Sequential( 69 | nn.Linear(embed_dim, ff_dim), 70 | nn.GELU(), 71 | nn.Dropout(dropout), 72 | nn.Linear(ff_dim, embed_dim), 73 | ) 74 | 75 | def forward( 76 | self, 77 | query: Tensor, 78 | key: Tensor = None, 79 | value: Tensor = None, 80 | mask: Tensor = None, 81 | ): 82 | query = self.pre_attention_q_norm(query) 83 | key = ( 84 | self.pre_attention_kv_norm(key) 85 | if key is not None and self.pre_attention_kv_norm is not None 86 | else key 87 | ) 88 | value = ( 89 | self.pre_attention_kv_norm(value) 90 | if value is not None and self.pre_attention_kv_norm is not None 91 | else value 92 | ) 93 | 94 | # mask is only useful for cross attention, ignore attention weights 95 | attn_output, *_ = ( 96 | self.attention(query, key, value, key_padding_mask=mask) 97 | if self.is_cross_attention 98 | else self.attention(query, query, query) 99 | ) 100 | output = attn_output + query 101 | output = self.ff(output) + output 102 | 103 | return output 104 | 105 | 106 | class XLayer(nn.Module): 107 | def __init__( 108 | self, 109 | embed_dim: int, 110 | num_cross_attention_heads: int, 111 | num_latent_attention_heads: int, 112 | num_latent_layers: int, 113 | ff_dim: int, 114 | dropout: float, 115 | batch_first: bool, 116 | latent_attention: Callable, 117 | ): 118 | super().__init__() 119 | 120 | self.cross_attention = AttentionWrapper( 121 | 
attention_class=nn.MultiheadAttention, 122 | embed_dim=embed_dim, 123 | num_heads=num_cross_attention_heads, 124 | ff_dim=ff_dim, 125 | dropout=dropout, 126 | batch_first=batch_first, 127 | is_cross_attention=True, 128 | ) 129 | 130 | # pesudo transfomer 131 | self.latent_attentions = nn.ModuleList( 132 | [ 133 | AttentionWrapper( 134 | attention_class=latent_attention, 135 | embed_dim=embed_dim, 136 | num_heads=num_latent_attention_heads, 137 | ff_dim=ff_dim, 138 | dropout=dropout, 139 | batch_first=batch_first, 140 | is_cross_attention=False, 141 | ) 142 | for _ in range(num_latent_layers) 143 | ] 144 | ) 145 | 146 | def forward( 147 | self, 148 | query: Tensor, 149 | key: Tensor = None, 150 | value: Tensor = None, 151 | mask: Tensor = None, 152 | ): 153 | o = self.cross_attention( 154 | query, 155 | key, 156 | value, 157 | mask=mask, 158 | ) 159 | 160 | for attn in self.latent_attentions: 161 | o = attn(o) 162 | 163 | return o 164 | 165 | 166 | class X(nn.Module): 167 | def __init__( 168 | self, 169 | num_classes: int, 170 | latent_dim: int, 171 | num_layers: int, 172 | embed_dim: int, 173 | num_cross_attention_heads: int, 174 | num_latent_attention_heads: int, 175 | num_latent_layers: int, 176 | ff_dim: int, 177 | dropout: float, 178 | batch_first: bool, 179 | max_length: int, 180 | latent_attention: Callable, 181 | ): 182 | super().__init__() 183 | 184 | self.embedding = CanineModel.from_pretrained('google/canine-s') 185 | self.embedding_ff = nn.Linear(768, embed_dim) 186 | self.layers = nn.ModuleList( 187 | [ 188 | XLayer( 189 | embed_dim=embed_dim, 190 | num_cross_attention_heads=num_cross_attention_heads, 191 | num_latent_attention_heads=num_latent_attention_heads, 192 | num_latent_layers=num_latent_layers, 193 | ff_dim=ff_dim, 194 | dropout=dropout, 195 | batch_first=batch_first, 196 | latent_attention=latent_attention, 197 | ) 198 | for _ in range(num_layers) 199 | ] 200 | ) 201 | 202 | self.num_classes = num_classes 203 | self.latent = 
nn.Parameter(torch.rand((latent_dim, embed_dim))) 204 | self.output_layer = nn.Linear(embed_dim, self.num_classes) 205 | self.lambda_layer = nn.Sequential(nn.Linear(embed_dim, 1), nn.Sigmoid()) 206 | 207 | def forward( 208 | self, 209 | **inputs 210 | ): 211 | 212 | mask = inputs.get("attention_mask", None) 213 | with torch.no_grad(): 214 | outputs = self.embedding(**inputs) 215 | x = outputs.last_hidden_state 216 | 217 | x = self.embedding_ff(x) 218 | batch_size, *_ = x.shape 219 | un_halted_prob = x.new_ones((batch_size,)) 220 | halted = x.new_zeros((batch_size,)) 221 | 222 | latent = repeat( 223 | rearrange(self.latent, "N D -> 1 N D"), "1 N D -> B N D", B=batch_size 224 | ) 225 | 226 | probas = [] 227 | preds = [] 228 | 229 | p_m = x.new_zeros((batch_size,)) 230 | y_m = x.new_zeros((batch_size, self.num_classes)) 231 | 232 | for i, layer in enumerate(self.layers): 233 | latent = layer(latent, x, x, mask) 234 | 235 | # calculate halting probability for current layer 236 | layer_lambda = ( 237 | x.new_ones((batch_size,)) 238 | if i == len(self.layers) - 1 239 | else self.lambda_layer(torch.mean(latent, dim=1)) 240 | ) 241 | # calculate current prediction from current layer 242 | layer_predictions = self.output_layer(torch.mean(latent, dim=1)) 243 | 244 | # conditional halting probability for current layer: previously not halted * halting now 245 | layer_halted_prob = un_halted_prob * layer_lambda.view(-1) 246 | un_halted_prob = un_halted_prob * (1 - layer_lambda.view(-1)) 247 | 248 | # Halt based on the halting probability 249 | sampling = torch.bernoulli(layer_lambda.reshape(-1)) 250 | halt = sampling * (1 - halted) 251 | 252 | probas.append(layer_halted_prob) 253 | preds.append(layer_predictions) 254 | 255 | p_m = p_m * (1 - halt) + layer_halted_prob * halt 256 | 257 | y_m = y_m * repeat( 258 | 1 - halt, "B -> B C", C=self.num_classes 259 | ) + layer_predictions * repeat(halt, "B -> B C", C=self.num_classes) 260 | 261 | halted = halted + halt 262 | 263 | if not 
self.training and halted.sum() == batch_size: 264 | break 265 | 266 | return torch.stack(probas), torch.stack(preds), p_m, y_m 267 | 268 | 269 | class ReconstructionLoss(nn.Module): 270 | def __init__(self, loss_fn: Callable): 271 | super().__init__() 272 | self.loss_fn = loss_fn 273 | 274 | def forward(self, probas, preds, labels): 275 | 276 | total = preds.new_tensor(0.0) 277 | for layer_probas, layer_preds in zip(probas, preds): 278 | layer_loss = layer_probas * self.loss_fn(layer_preds, labels) 279 | total = total + layer_loss.mean() 280 | 281 | return total 282 | 283 | 284 | class RegularizationLoss(nn.Module): 285 | def __init__(self, lambda_p: float, max_layers: int): 286 | super().__init__() 287 | p_g = torch.zeros((max_layers,)) 288 | not_halted = 1.0 289 | for k in range(max_layers): 290 | p_g[k] = lambda_p * not_halted 291 | not_halted = not_halted * (1 - lambda_p) 292 | 293 | self.p_g = nn.Parameter(p_g, requires_grad=False) 294 | self.kl_div = nn.KLDivLoss(reduction="batchmean") 295 | 296 | def forward(self, probas): 297 | probas = probas.transpose(0, 1) 298 | p_g = self.p_g[None, : probas.shape[1]].expand_as(probas) 299 | 300 | return self.kl_div(probas.log(), p_g) 301 | 302 | 303 | class XLoss(nn.Module): 304 | def __init__(self, loss_fn: Callable, lambda_p: float, max_layers: int): 305 | super().__init__() 306 | self.reconstruction_loss = ReconstructionLoss(loss_fn) 307 | self.regularization_loss = RegularizationLoss(lambda_p, max_layers) 308 | 309 | def forward(self, probas, preds, labels): 310 | 311 | return self.reconstruction_loss( 312 | probas, preds, labels 313 | ) + self.regularization_loss(probas) 314 | --------------------------------------------------------------------------------