├── tests ├── __init__.py ├── test_loss.py └── test_ops.py ├── torch_impl ├── __init__.py ├── util.py └── MultipleNegativeRankingLoss.py ├── trainer ├── __init__.py ├── loss │ ├── __init__.py │ ├── basic.py │ └── custom.py ├── utils │ ├── __init__.py │ └── ops.py ├── dataloader │ └── __init__.py └── train.py ├── dataset_list ├── __init__.py └── stackexchange │ ├── __init__.py │ ├── download_archive.py │ ├── convert_title_body.py │ └── download_archive_file_list.tsv ├── examples ├── parity_torch_flax │ ├── __init__.py │ └── model.py ├── README.md ├── pytorch_train_script │ ├── MultiDatasetDataLoader.py │ └── training.py └── nils_flax_script │ ├── MultiDatasetDataLoader.py │ └── train.py ├── code-search-net ├── requirements.txt ├── readme.md └── train_code_search_net.py ├── pyproject.toml ├── dataset ├── download_data.py ├── dataset.py └── datasets_list.tsv ├── .github └── workflows │ └── build.yml ├── evaluation └── metrics.py ├── .gitignore ├── README.md ├── datasets └── stackexchange │ └── transforms.py └── conversational-model └── multi_context_train.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_impl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset_list/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trainer/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trainer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trainer/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset_list/stackexchange/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/parity_torch_flax/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/parity_torch_flax/model.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code-search-net/requirements.txt: -------------------------------------------------------------------------------- 1 | optax 2 | flax 3 | tqdm 4 | transformers 5 | datasets 6 | wandb -------------------------------------------------------------------------------- /code-search-net/readme.md: -------------------------------------------------------------------------------- 1 | # code-search-net 2 | 3 | This directory holds code for `code-search-net`. 
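The script is intended to train Flax/JAX sentence embeddings on the CodeSearchNet corpus of (comment, code) pairs listed in `dataset/datasets_list.tsv`. The snippet below is a rough, hypothetical sketch of the in-batch-negatives objective used throughout this repository; it is not an excerpt from `train_code_search_net.py`, and the batch size of 8 and embedding width of 384 are placeholders.

```python
# Minimal sketch: score a batch of (docstring, code) embedding pairs with the
# multiple-negatives ranking loss implemented in this repository.
# Run from the repository root so that the `trainer` package is importable.
import jax
from trainer.loss.custom import multiple_negatives_ranking_loss

key_doc, key_code = jax.random.split(jax.random.PRNGKey(0))
doc_embeddings = jax.random.normal(key_doc, (8, 384))    # stand-in for pooled docstring embeddings
code_embeddings = jax.random.normal(key_code, (8, 384))  # stand-in for pooled code embeddings

# Each docstring is matched against its own code snippet; the other snippets
# in the batch act as in-batch negatives.
loss = multiple_negatives_ranking_loss(doc_embeddings, code_embeddings)
print(loss)
```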
4 | 5 | **Setting up** 6 | 7 | ```shell 8 | pip3 install -r requirements.txt 9 | ``` 10 | 11 | **Initiate training** 12 | 13 | ```shell 14 | python3 train_code_search_net.py 15 | ``` 16 | 17 | Training arguments can be controlled using `TrainingArgs` present in [train_code_search_net.py](train_code_search_net.py). 18 | 19 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | Some examples to help with this project. 3 | 4 | 5 | ## pytorch_train_script 6 | A training script using [sentence-transformers](https://www.sbert.net) and PyTorch to train on the dataset format specified for this project. 7 | 8 | It uses [nreimers/MiniLM-L6-H384-uncased](https://huggingface.co/nreimers/MiniLM-L6-H384-uncased) as the model and trains for 2k steps with a batch size of 256. 9 | 10 | This configuration is used as a reference to quickly benchmark datasets. -------------------------------------------------------------------------------- /trainer/loss/basic.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import numpy as jnp 3 | import optax 4 | 5 | 6 | @jax.jit 7 | def jax_cross_entropy_loss(scores, labels): 8 | """ 9 | :param scores: logit scores of shape (batch_size, class_count) 10 | :param labels: integer class labels of shape (batch_size,) 11 | :return: mean softmax cross-entropy over the batch 12 | """ 13 | class_count = scores.shape[1] 14 | assert scores.shape[0] == len(labels) 15 | one_hot = jax.nn.one_hot(labels, num_classes=class_count) 16 | soft_max = optax.softmax_cross_entropy(scores, one_hot) 17 | return jnp.mean(soft_max) 18 | -------------------------------------------------------------------------------- /dataset_list/stackexchange/download_archive.py: -------------------------------------------------------------------------------- 1 | """ 2 | Downloads all archive files for stackexchange from: https://archive.org/download/stackexchange 3 | 4 | Requires: 5 | pip install sentence-transformers 6 | 7 | for the download 8 | """ 9 | 10 | from sentence_transformers import util 11 | import os 12 | 13 | with open('download_archive_file_list.tsv') as fIn: 14 | for line in fIn: 15 | name = line.strip().split("\t")[0] 16 | output_path = os.path.join("archive", name) 17 | if os.path.exists(output_path): 18 | continue 19 | 20 | if name.endswith('.7z') and '.meta.' not in name: 21 | print("Download:", name) 22 | util.http_get("https://archive.org/download/stackexchange/"+name, output_path) -------------------------------------------------------------------------------- /trainer/loss/custom.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import numpy as jnp 3 | from .basic import jax_cross_entropy_loss 4 | from ..utils.ops import cos_sim 5 | 6 | 7 | @jax.jit 8 | def multiple_negatives_ranking_loss(embeddings_a: jnp.DeviceArray, embeddings_b: jnp.DeviceArray, 9 | scale: float = 20.0, similarity_fct=cos_sim): 10 | """ 11 | 12 | :param embeddings_a: anchor embeddings of shape (batch_size, embedding_dim) 13 | :param embeddings_b: positive embeddings (optionally followed by hard negatives); if passing additional hard negatives, use jnp.concatenate([positives, negatives], axis=0) as input.
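        The first len(embeddings_a) rows of embeddings_b are then treated as the positives, the remaining rows act only as additional negatives, and the labels used internally stay jnp.arange(len(embeddings_a)).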
14 | :param scale: the similarity scores are multiplied by this value before the softmax 15 | :param similarity_fct: similarity function between the embeddings (cos_sim by default; can be swapped for a dot product) 16 | :return: mean cross-entropy loss over the batch 17 | """ 18 | assert (len(embeddings_a) <= len(embeddings_b)) 19 | scores = similarity_fct(embeddings_a, embeddings_b) * scale 20 | assert scores.shape == (len(embeddings_a), len(embeddings_b)) 21 | 22 | labels = jnp.arange(len(scores), dtype=jnp.int64) 23 | return jax_cross_entropy_loss(scores, labels) 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "flax-sentence-embeddings" 3 | version = "0.1.0" 4 | description = "Flax Sentence Embeddings on 1B+ training pairs" 5 | authors = ["Nils Reimer"] 6 | 7 | packages = [ 8 | { include = "datasets" } 9 | ] 10 | 11 | [tool.poetry.dependencies] 12 | python = "^3.6" 13 | numpy = "*" 14 | flax = "*" 15 | jax = "*" 16 | torch = "*" 17 | pytest = "^3.4" 18 | internetarchive = "*" 19 | transformers = "*" 20 | sentence_transformers = "*" 21 | py7zr = "*" 22 | tqdm = "*" 23 | loguru = "*" 24 | datasets = "^1.8.0" 25 | wandb = "*" 26 | 27 | [tool.poetry.dev-dependencies] 28 | pytest = "^3.4" 29 | black = "*" 30 | pylama = "*" 31 | 32 | [tool.black] 33 | line-length = 100 34 | target-version = ['py37'] 35 | include = '\.pyi?$' 36 | exclude = ''' 37 | 38 | ( 39 | /( 40 | \.eggs # exclude a few common directories in the 41 | | \.git # root of the project 42 | | \.hg 43 | | \.mypy_cache 44 | | \.tox 45 | | \.venv 46 | | _build 47 | | buck-out 48 | | build 49 | | dist 50 | )/ 51 | | foo.py # also separately exclude a file named foo.py in 52 | # the root of the project 53 | ) 54 | ''' 55 | -------------------------------------------------------------------------------- /trainer/train.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import random, numpy as jnp 3 | from flax import linen as nn 4 | from trainer.loss.custom import multiple_negatives_ranking_loss 5 | from jax.config import config 6 | 7 | # Dummy version 8 | batch_size = 20 9 | embedding_size = 250 10 | 11 | 12 | def demo_train_step(model, params, input): 13 | # This can be integrated with existing scripts; it is only here for demo purposes.
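    # Below, the Dense layer's 3 * embedding_size outputs per example are split into three
    # embedding_size-wide chunks and treated as (anchor, positive, hard negative) embeddings,
    # matching the input format described in trainer/loss/custom.py. This wiring is only for
    # the demo; a real training step would obtain the three embeddings from an encoder.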
14 | 15 | def loss(params): 16 | preds = model.apply(params, input) 17 | preds = jnp.reshape(preds, (preds.shape[0], -1, embedding_size)) 18 | return multiple_negatives_ranking_loss(preds[:, 0], jnp.concatenate([preds[:, 1], preds[:, 2]], axis=0)) 19 | 20 | loss, grad = jax.value_and_grad(loss)(params) 21 | return loss, grad 22 | 23 | 24 | def main(): 25 | key = random.PRNGKey(0) 26 | key1, key2 = random.split(key) 27 | 28 | dummy_model = nn.Dense(features=3 * embedding_size) 29 | dummy_input = random.normal(key1, (batch_size, 200)) 30 | params = dummy_model.init(key2, dummy_input) 31 | 32 | value, grad = demo_train_step(dummy_model, params, dummy_input) 33 | print("Value : ", value) 34 | print("Grad : ", grad) 35 | 36 | 37 | if __name__ == "__main__": 38 | config.update("jax_enable_x64", True) 39 | main() 40 | -------------------------------------------------------------------------------- /dataset/download_data.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import pandas as pd 4 | import urllib.request 5 | 6 | def download_dataset(url, file_name): 7 | urllib.request.urlretrieve(url, file_name) 8 | return 9 | 10 | if __name__ == "__main__": 11 | 12 | parser = argparse.ArgumentParser(description='Download raw corpora given a dataset list.') 13 | parser.add_argument('--dataset_list') 14 | parser.add_argument('--data_path') 15 | args = parser.parse_args(args=sys.argv[1:]) 16 | 17 | datasets = pd.read_csv( 18 | args.dataset_list, 19 | index_col=0, 20 | sep='\t', 21 | dtype={ 22 | 'Description': str, 23 | 'Size (#Pairs)': str, 24 | 'Performance': float, 25 | 'Download link': str, 26 | 'Source': str}) 27 | datasets['Size (#Pairs)'] = datasets['Size (#Pairs)'].str.replace(',', '').astype(int) 28 | datasets = datasets.to_dict(orient='index') 29 | 30 | print('Downloading {:,} datasets into {}.'.format(len(datasets), args.data_path)) 31 | 32 | for d in datasets.keys(): 33 | print('Downloading dataset {} ({:,} pairs) ... 
'.format(d, datasets[d]['Size (#Pairs)']), end='', flush=True) 34 | download_dataset( 35 | datasets[d]['Download link'], 36 | os.path.join(os.path.abspath(args.data_path), d + '.json.gz')) 37 | print('\033[32m' + 'Done' + '\033[0m') -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: [3.6] 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | #---------------------------------------------- 20 | # ----- install & configure poetry ----- 21 | #---------------------------------------------- 22 | - name: Install Poetry 23 | uses: snok/install-poetry@v1.1.6 24 | with: 25 | virtualenvs-create: true 26 | virtualenvs-in-project: true 27 | #---------------------------------------------- 28 | # load cached venv if cache exists 29 | #---------------------------------------------- 30 | - name: Load cached venv 31 | id: cached-poetry-dependencies 32 | uses: actions/cache@v2 33 | with: 34 | path: .venv 35 | key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }} 36 | #---------------------------------------------- 37 | # install dependencies if cache does not exist 38 | #---------------------------------------------- 39 | - name: Install dependencies 40 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 41 | run: poetry install --no-interaction --no-root 42 | #---------------------------------------------- 43 | # install your root project, if required 44 | #---------------------------------------------- 45 | - name: Install library 46 | run: poetry install --no-interaction 47 | - name: Test with pytest 48 | run: | 49 | source .venv/bin/activate 50 | pytest tests/ 51 | 52 | -------------------------------------------------------------------------------- /torch_impl/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def cos_sim(a: Tensor, b: Tensor): 6 | """ 7 | Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j. 8 | :return: Matrix with res[i][j] = cos_sim(a[i], b[j]) 9 | """ 10 | if not isinstance(a, torch.Tensor): 11 | a = torch.tensor(a) 12 | 13 | if not isinstance(b, torch.Tensor): 14 | b = torch.tensor(b) 15 | 16 | if len(a.shape) == 1: 17 | a = a.unsqueeze(0) 18 | 19 | if len(b.shape) == 1: 20 | b = b.unsqueeze(0) 21 | 22 | a_norm = torch.nn.functional.normalize(a, p=2, dim=1) 23 | b_norm = torch.nn.functional.normalize(b, p=2, dim=1) 24 | return torch.mm(a_norm, b_norm.transpose(0, 1)) 25 | 26 | def mean_pooling(model_output, attention_mask): 27 | """ 28 | Returns mean pooled embeddings from the last layer of a PyTorch based HuggingFace Transformer model. 29 | """ 30 | embeddings = model_output[0] 31 | attention_mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.shape).float() 32 | sum_embeddings = torch.sum(embeddings * attention_mask_expanded, 1) 33 | sum_mask = torch.clamp(attention_mask_expanded.sum(1), min=1e-9) 34 | return sum_embeddings / sum_mask 35 | 36 | def max_pooling(model_output, attention_mask): 37 | """ 38 | Returns max pooled embeddings from the last layer of a PyTorch based HuggingFace Transformer model. 
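    Note: padded positions are zeroed (not masked to -inf) before the max, so a dimension whose valid values are all negative will max to 0; the Flax version in trainer/utils/ops.py behaves the same way, which keeps the parity tests consistent.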
39 | """ 40 | embeddings = model_output[0] 41 | attention_mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.shape).float() 42 | return torch.max(embeddings * attention_mask_expanded, 1).values 43 | 44 | def cls_pooling(model_output): 45 | """ 46 | Returns [CLS] token embedding from the last layer of a PyTorch based HuggingFace Transformer model. 47 | """ 48 | return model_output[0][:, 0] # 1st token is the [CLS] token -------------------------------------------------------------------------------- /evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from os import name 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | def recall_k(n,sim_func,contexts, responses): 6 | """ 7 | Recall 1-of-N metric as used in conveRT paper. 8 | 9 | Recall@k takes N responses to the given conversational context, where only one response is relevant. 10 | It indicates whether the relevant response occurs in the top k ranked candidate responses. 11 | The 1-of-N metric is obtained when k=1. This effectively means that, for each query, 12 | we indicate if the correct response is the top ranked response among N candidates. 13 | The final score is the average across all queries 14 | 15 | from https://github.com/PolyAI-LDN/conversational-datasets/blob/master/baselines/run_baseline.py 16 | 17 | :param n: Number of response candidates to be passed to a context for retrieval 18 | :param contexts: context embeddings - shape (num_embs,emb_dim) 19 | :param responses: response embeddings - shape (num_embs,emb_dim) 20 | :param sim_func: similarity function - cosine or dot product, which returns an similarity scores of shape (num_embs,num_embs) 21 | :return: recall score 22 | 23 | """ 24 | accuracy_numerator = 0.0 25 | accuracy_denominator = 0.0 26 | for i in tqdm(range(0, len(contexts), n)): 27 | context_batch = contexts[i:i + n] 28 | responses_batch = responses[i:i + n] 29 | if len(context_batch) != n: 30 | break 31 | 32 | # Shuffle the responses. 33 | permutation = np.arange(n) 34 | np.random.shuffle(permutation) 35 | context_batch_shuffled = [context_batch[j] for j in permutation] 36 | 37 | predictions = np.argmax(sim_func(context_batch_shuffled, responses_batch),axis=1) 38 | accuracy_numerator += np.equal(predictions, permutation).mean() 39 | accuracy_denominator += 1.0 40 | 41 | accuracy = 100 * accuracy_numerator / accuracy_denominator 42 | return accuracy -------------------------------------------------------------------------------- /trainer/utils/ops.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import numpy as jnp 3 | 4 | 5 | @jax.jit 6 | def cos_sim(a, b): 7 | a = normalize_L2(a) 8 | b = normalize_L2(b) 9 | return a @ b.T 10 | 11 | 12 | @jax.jit 13 | def normalize_L2(embedding): 14 | return embedding / jnp.maximum(jnp.linalg.norm(embedding, ord=2, axis=1, keepdims=True), 1e-12) 15 | 16 | 17 | @jax.jit 18 | def mean_pooling(model_output, attention_mask): 19 | """ 20 | This function applies mean pooling to contextualized embeddings produced by the last layer of a Flax based HuggingFace Transformer model. 
21 | 22 | :param model_output: model output from a model of type `FlaxPreTrainedModel` 23 | :param attention_mask: attention mask in the model input 24 | :return: mean pooled embeddings 25 | """ 26 | embeddings = model_output[0] 27 | attention_mask_expanded = jnp.broadcast_to(jnp.expand_dims(attention_mask, -1), embeddings.shape) 28 | sum_embeddings = jnp.sum(embeddings * attention_mask_expanded, 1) 29 | sum_mask = jnp.clip(attention_mask_expanded.sum(1), a_min=1e-9) 30 | return sum_embeddings / sum_mask 31 | 32 | 33 | @jax.jit 34 | def max_pooling(model_output, attention_mask): 35 | """ 36 | This function applies max pooling to contextualized embeddings produced by the last layer of a Flax based HuggingFace Transformer model. 37 | 38 | :param model_output: model output from a model of type `FlaxPreTrainedModel` 39 | :param attention_mask: attention mask in the model input 40 | :return: max pooled embeddings 41 | """ 42 | embeddings = model_output[0] 43 | attention_mask_expanded = jnp.broadcast_to(jnp.expand_dims(attention_mask, -1), embeddings.shape) 44 | return jnp.max(embeddings * attention_mask_expanded, 1) 45 | 46 | 47 | @jax.jit 48 | def cls_pooling(model_output): 49 | """ 50 | This function returns the [CLS] token embedding produced by the last layer of a Flax based HuggingFace Transformer model. 51 | 52 | :param model_output: model output from a model of type `FlaxPreTrainedModel` 53 | 54 | :return: [CLS] token embedding 55 | """ 56 | return model_output[0][:, 0] 57 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import gzip 3 | import json, torch 4 | from torch.utils.data import Dataset, DataLoader, IterableDataset 5 | from transformers import BertTokenizer 6 | 7 | """ 8 | Implementation of a lazy dataloader in PyTorch 9 | """ 10 | 11 | class TextIterator: 12 | def __init__(self, text_iterator, batch_size, num_workers, transform=None): 13 | self.batch_size = batch_size 14 | self.iter_number = 0 15 | self.num_workers = num_workers 16 | self.text_iterator = text_iterator 17 | self.transform = transform 18 | 19 | def __iter__(self): 20 | return self.text_iterator 21 | 22 | def __next__(self): 23 | if self.iter_number == self.batch_size: 24 | self.iter_number = 0 25 | for _ in range(self.batch_size * (self.num_workers - 1)): 26 | next(self.text_iterator) 27 | self.iter_number += 1 28 | answer, question = json.loads(next(self.text_iterator)) 29 | sample = {'question': question, 'answer': answer} 30 | sample = copy.deepcopy(sample) 31 | if self.transform: 32 | sample = self.transform(sample) 33 | return sample 34 | 35 | def __del__(self): 36 | self.text_iterator.close() 37 | 38 | 39 | class TextSimpleIterator: 40 | def __init__(self, text_iterator, transform=None): 41 | self.text_iterator = text_iterator 42 | self.transform = transform 43 | 44 | def __iter__(self): 45 | return self.text_iterator 46 | 47 | def __next__(self): 48 | answer, question = json.loads(next(self.text_iterator)) 49 | sample = {'question': question, 'answer': answer} 50 | if self.transform: 51 | sample = self.transform(sample) 52 | return sample 53 | 54 | def __del__(self): 55 | self.text_iterator.close() 56 | 57 | 58 | class IterableCorpusDataset(IterableDataset): 59 | def __init__(self, file_path, batch_size, num_workers, start=0, transform=None): 60 | self.file_path = file_path 61 | self.batch_size = batch_size 
62 | self.num_workers = num_workers 63 | self.start = start 64 | self.transform = transform 65 | 66 | def __iter__(self): 67 | worker_info = torch.utils.data.get_worker_info() 68 | dataset_itr = gzip.open(self.file_path, "rb") 69 | if worker_info is None: 70 | dataset_itr = gzip.open(self.file_path, "rb") 71 | for _ in range(self.start): 72 | next(dataset_itr) 73 | return TextSimpleIterator(dataset_itr, self.transform) 74 | else: 75 | worker_id = worker_info.id 76 | for _ in range(self.start): 77 | next(dataset_itr) 78 | for _ in range(self.batch_size * worker_id): 79 | next(dataset_itr) 80 | return TextIterator(dataset_itr, self.batch_size, self.num_workers, self.transform) 81 | 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ 140 | 141 | poetry.lock 142 | requirements.txt 143 | .vscode 144 | 145 | data 146 | 147 | -------------------------------------------------------------------------------- /torch_impl/MultipleNegativeRankingLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from typing import Iterable, Dict 4 | from .util import cos_sim 5 | 6 | 7 | class MultipleNegativesRankingLoss(nn.Module): 8 | """ 9 | This loss expects as input a batch consisting of sentence pairs (a_1, p_1), (a_2, p_2)..., (a_n, p_n) 10 | where we assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair. 11 | For each a_i, it uses all other p_j as negative samples, i.e., for a_i, we have 1 positive example (p_i) and 12 | n-1 negative examples (p_j). It then minimizes the negative log-likehood for softmax normalized scores. 13 | This loss function works great to train embeddings for retrieval setups where you have positive pairs (e.g. (query, relevant_doc)) 14 | as it will sample in each batch n-1 negative docs randomly. 15 | The performance usually increases with increasing batch sizes. 16 | For more information, see: https://arxiv.org/pdf/1705.00652.pdf 17 | (Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4) 18 | You can also provide one or multiple hard negatives per anchor-positive pair by structering the data like this: 19 | (a_1, p_1, n_1), (a_2, p_2, n_2) 20 | Here, n_1 is a hard negative for (a_1, p_1). The loss will use for the pair (a_i, p_i) all p_j (j!=i) and all n_j as negatives. 21 | Example:: 22 | from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses 23 | from sentence_transformers.readers import InputExample 24 | model = SentenceTransformer('distilbert-base-nli-mean-tokens') 25 | train_examples = [InputExample(texts=['Anchor 1', 'Positive 1']), 26 | InputExample(texts=['Anchor 2', 'Positive 2'])] 27 | train_dataset = SentencesDataset(train_examples, model) 28 | train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size) 29 | train_loss = losses.MultipleNegativesRankingLoss(model=model) 30 | """ 31 | 32 | def __init__(self, scale: float = 20.0, similarity_fct=cos_sim): 33 | """ 34 | :param model: SentenceTransformer model 35 | :param scale: Output of similarity function is multiplied by scale value 36 | :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. 
Can also be set to dot product (and then set scale to 1) 37 | """ 38 | super(MultipleNegativesRankingLoss, self).__init__() 39 | self.scale = scale 40 | self.similarity_fct = similarity_fct 41 | self.cross_entropy_loss = nn.CrossEntropyLoss() 42 | 43 | def forward(self, embeddings_a, embeddings_b, labels: Tensor): 44 | scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale 45 | labels = torch.tensor(range(len(scores)), dtype=torch.long, 46 | device=scores.device) # Example a[i] should match with b[i] 47 | return self.cross_entropy_loss(scores, labels) 48 | 49 | def get_config_dict(self): 50 | return {'scale': self.scale, 'similarity_fct': self.similarity_fct.__name__} 51 | -------------------------------------------------------------------------------- /dataset_list/stackexchange/convert_title_body.py: -------------------------------------------------------------------------------- 1 | """ 2 | Converts the archive from 3 | https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml 4 | 5 | To jsonl format 6 | 7 | python convert_archives.py input_folder output_folder 8 | 9 | Returns (title, body) pairs that pass certain quality checks 10 | """ 11 | 12 | import os 13 | import glob 14 | import json 15 | import gzip 16 | import random 17 | from typing import List, Any, IO, Dict 18 | from tqdm import tqdm 19 | import xml.etree.ElementTree as ET 20 | import re 21 | import py7zr 22 | import sys 23 | 24 | input_folder = sys.argv[1] 25 | output_folder = sys.argv[2] 26 | os.makedirs(output_folder, exist_ok=False) 27 | 28 | 29 | random.seed(42) 30 | 31 | min_title_len = 20 32 | min_body_len = 20 33 | max_body_len = 4096 34 | min_score = 0 35 | 36 | large_stackexchange_threshold = 10000 #Stackexchange smaller than this go to a special output file 37 | small_stackexchange_filepath = os.path.join(output_folder, "small_stackexchanges.jsonl") 38 | 39 | def parse_posts(f: IO[Any]) -> List[Dict]: 40 | tree = ET.parse(f) 41 | posts = tree.getroot() 42 | pairs = [] 43 | num_questions = 0 44 | 45 | for post in posts: 46 | data = post.attrib 47 | if data["PostTypeId"] == "1": # focus just on questions for now, not answers 48 | num_questions += 1 49 | # remove all HTML tags (including links!) 
and normalize whitespace 50 | title = re.sub("<.*?>", "", data["Title"]).strip() 51 | body = re.sub("<.*?>", "", data["Body"]).strip() 52 | tags_str = data["Tags"] 53 | tags = re.findall(r"<(.*?)>", tags_str) 54 | score = int(data["Score"]) 55 | 56 | if len(title) < min_title_len or len(body) < min_body_len or len(body) > max_body_len or score < min_score: 57 | continue 58 | 59 | pairs.append({'texts': [title, body], 'tags': tags}) 60 | print("Questions:", num_questions) 61 | print("Questions after filter:", len(pairs)) 62 | return pairs 63 | 64 | 65 | def extract_posts(stack_exchange_file: str) -> List[Dict]: 66 | with py7zr.SevenZipFile(stack_exchange_file, mode="r") as z: 67 | fs = z.read(targets=["Posts.xml"]) 68 | if fs is not None and "Posts.xml" in fs: 69 | posts = parse_posts(fs["Posts.xml"]) 70 | return posts 71 | return [] 72 | 73 | 74 | def convert_to_jsonl_gz(input_file: str, output_file: str) -> None: 75 | posts = extract_posts(input_file) 76 | random.shuffle(posts) 77 | if len(posts) == 0: 78 | return 79 | 80 | if len(posts) >= large_stackexchange_threshold: 81 | fOut = gzip.open(output_file, "wt") 82 | else: 83 | fOut = open(small_stackexchange_filepath, "a") 84 | 85 | for post in posts: 86 | fOut.write(json.dumps(post)) 87 | fOut.write("\n") 88 | 89 | fOut.close() 90 | 91 | 92 | 93 | 94 | for filepath in sorted(glob.glob(os.path.join(input_folder, "*.7z")), key=os.path.getsize, reverse=True): 95 | name = os.path.basename(filepath.strip(".7z")) 96 | output_path = os.path.join(output_folder, f"{name}.jsonl.gz") 97 | print(filepath) 98 | convert_to_jsonl_gz(filepath, output_path) 99 | -------------------------------------------------------------------------------- /examples/pytorch_train_script/MultiDatasetDataLoader.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import random 4 | 5 | class MultiDatasetDataLoader: 6 | def __init__(self, datasets, batch_size_pairs, batch_size_triplets=None, dataset_size_temp=-1, allow_swap=True): 7 | self.allow_swap = allow_swap 8 | self.batch_size_pairs = batch_size_pairs 9 | self.batch_size_triplets = batch_size_pairs if batch_size_triplets is None else batch_size_triplets 10 | 11 | # Compute dataset weights 12 | self.dataset_lengths = list(map(len, datasets)) 13 | self.dataset_lengths_sum = sum(self.dataset_lengths) 14 | 15 | weights = [] 16 | if dataset_size_temp > 0: # Scale probability with dataset size 17 | for dataset in datasets: 18 | prob = len(dataset) / self.dataset_lengths_sum 19 | weights.append(max(1, int(math.pow(prob, 1 / dataset_size_temp) * 1000))) 20 | else: # Equal weighting of all datasets 21 | weights = [100] * len(datasets) 22 | 23 | logging.info("Dataset lenghts and weights: {}".format(list(zip(self.dataset_lengths, weights)))) 24 | 25 | self.dataset_idx = [] 26 | self.dataset_idx_pointer = 0 27 | 28 | for idx, weight in enumerate(weights): 29 | self.dataset_idx.extend([idx] * weight) 30 | random.shuffle(self.dataset_idx) 31 | 32 | self.datasets = [] 33 | for dataset in datasets: 34 | random.shuffle(dataset) 35 | self.datasets.append({ 36 | 'elements': dataset, 37 | 'pointer': 0, 38 | }) 39 | 40 | def __iter__(self): 41 | for _ in range(int(self.__len__())): 42 | # Select dataset 43 | if self.dataset_idx_pointer >= len(self.dataset_idx): 44 | self.dataset_idx_pointer = 0 45 | random.shuffle(self.dataset_idx) 46 | 47 | dataset_idx = self.dataset_idx[self.dataset_idx_pointer] 48 | self.dataset_idx_pointer += 1 49 | 50 | # Select batch from this 
dataset 51 | dataset = self.datasets[dataset_idx] 52 | batch_size = self.batch_size_pairs if len(dataset['elements'][0].texts) == 2 else self.batch_size_triplets 53 | 54 | batch = [] 55 | texts_in_batch = set() 56 | guid_in_batch = set() 57 | while len(batch) < batch_size: 58 | example = dataset['elements'][dataset['pointer']] 59 | 60 | valid_example = True 61 | # First check if one of the texts in already in the batch 62 | for text in example.texts: 63 | text_norm = text.strip().lower() 64 | if text_norm in texts_in_batch: 65 | valid_example = False 66 | 67 | texts_in_batch.add(text_norm) 68 | 69 | # If the example has a label, check if label is in batch 70 | if example.guid is not None: 71 | valid_example = valid_example and example.guid not in guid_in_batch 72 | guid_in_batch.add(example.guid) 73 | 74 | 75 | if valid_example: 76 | if self.allow_swap and random.random() > 0.5: 77 | example.texts[0], example.texts[1] = example.texts[1], example.texts[0] 78 | 79 | batch.append(example) 80 | 81 | dataset['pointer'] += 1 82 | if dataset['pointer'] >= len(dataset['elements']): 83 | dataset['pointer'] = 0 84 | random.shuffle(dataset['elements']) 85 | 86 | yield self.collate_fn(batch) if self.collate_fn is not None else batch 87 | 88 | def __len__(self): 89 | return int(self.dataset_lengths_sum / self.batch_size_pairs) -------------------------------------------------------------------------------- /tests/test_loss.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from torch_impl.MultipleNegativeRankingLoss import MultipleNegativesRankingLoss 4 | from trainer.loss.custom import multiple_negatives_ranking_loss 5 | from jax import value_and_grad 6 | from jax import random 7 | import jax.numpy as jnp 8 | import numpy as onp 9 | import optax 10 | from jax import nn as jax_nn 11 | from torch import nn as torch_nn 12 | from trainer.loss.custom import jax_cross_entropy_loss 13 | from jax.config import config 14 | 15 | config.update("jax_enable_x64", True) 16 | 17 | class LossTest(unittest.TestCase): 18 | def test_jax_cross_entropy_loss(self): 19 | key = random.PRNGKey(0) 20 | key, a_key, b_key = random.split(key, 3) 21 | 22 | sample_count = 200 23 | label_count = 400 24 | scores = random.normal(a_key, (sample_count, label_count)) 25 | torch_scores = torch.tensor(onp.asarray(scores), requires_grad=True) 26 | 27 | labels = random.randint(b_key, (sample_count, ), minval=0, maxval=label_count - 1) 28 | torch_labels = torch.tensor(onp.asarray(labels), dtype=torch.long) 29 | jax_labels = jax_nn.one_hot(labels, num_classes=label_count) 30 | 31 | jax_cross_entropy = jnp.mean(optax.softmax_cross_entropy(scores, jax_labels)) 32 | jax_padded_cross_entropy = jax_cross_entropy_loss(scores, labels) 33 | torch_cross_entropy = torch_nn.CrossEntropyLoss() 34 | torch_cross_entropy = torch_cross_entropy.forward(torch_scores, torch_labels) 35 | 36 | assert onp.all(onp.abs(torch_cross_entropy.item() - jax_cross_entropy) < 0.001) 37 | assert onp.all(onp.abs(torch_cross_entropy.item() - jax_padded_cross_entropy) < 0.001) 38 | 39 | 40 | def test_multiple_negatives_ranking_loss(self): 41 | """Tests the correct computation of multiple_negatives_ranking_loss""" 42 | key = random.PRNGKey(0) 43 | key, a_key, b_key = random.split(key, 3) 44 | a = random.normal(a_key, (20, 200)) 45 | b = random.normal(b_key, (20, 200)) 46 | 47 | a_torch = torch.tensor(onp.asarray(a), requires_grad=True) 48 | b_torch = torch.tensor(onp.asarray(b), requires_grad=True) 49 | 50 
| torch_loss = MultipleNegativesRankingLoss() 51 | torch_loss = torch_loss.forward(a_torch, b_torch, None) 52 | torch_loss.backward() 53 | torch_grad = a_torch.grad.numpy() 54 | 55 | jax_loss, jax_grad = value_and_grad(multiple_negatives_ranking_loss)(a, b) 56 | assert abs(torch_loss.item() - jax_loss) <= 0.001, "loss : {} vs {}".format(jax_loss, torch_loss.item()) 57 | 58 | assert onp.all(onp.abs(torch_grad - jax_grad) < 0.001) 59 | 60 | def test_multiple_negatives_ranking_loss_triple(self): 61 | """Tests the correct computation of multiple_negatives_ranking_loss, with hard negatives""" 62 | key = random.PRNGKey(0) 63 | key, a_key, b_key, c_key = random.split(key, 4) 64 | a = random.normal(a_key, (20, 200)) 65 | b = random.normal(b_key, (20, 200)) 66 | c = random.normal(c_key, (20, 200)) 67 | 68 | a_torch = torch.tensor(onp.asarray(a), requires_grad=True) 69 | b_torch = torch.tensor(onp.asarray(b), requires_grad=True) 70 | c_torch = torch.tensor(onp.asarray(c), requires_grad=True) 71 | 72 | comp_torch = torch.cat([b_torch, c_torch]) 73 | 74 | torch_loss = MultipleNegativesRankingLoss() 75 | torch_loss = torch_loss.forward(a_torch, comp_torch, None) 76 | torch_loss.backward() 77 | torch_grad = a_torch.grad.numpy() 78 | 79 | jax_loss, jax_grad = value_and_grad(multiple_negatives_ranking_loss)(a, jnp.concatenate([b, c], axis=0)) 80 | assert abs(torch_loss.item() - jax_loss) <= 0.001, "loss : {} vs {}".format(jax_loss, torch_loss.item()) 81 | 82 | assert onp.all(onp.abs(torch_grad - jax_grad) < 0.001) -------------------------------------------------------------------------------- /examples/pytorch_train_script/training.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is an example how to train with sentence-transformers. 3 | 4 | It trains the model just for 2k steps using equal weighting of all provided dataset files. 5 | 6 | Run: 7 | python training.py exp-name file1.jsonl.gz [file2.jsonl.gz] ... 
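    e.g.: python training.py stackexchange-small stackexchange_title_body_small.jsonl.gz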
8 | 9 | """ 10 | import math 11 | from sentence_transformers import models, losses, datasets 12 | from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample 13 | from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator 14 | import logging 15 | from datetime import datetime 16 | import sys 17 | import os 18 | import gzip 19 | import csv 20 | from MultiDatasetDataLoader import MultiDatasetDataLoader 21 | from shutil import copyfile 22 | import json 23 | 24 | #### Just some code to print debug information to stdout 25 | logging.basicConfig(format='%(asctime)s - %(message)s', 26 | datefmt='%Y-%m-%d %H:%M:%S', 27 | level=logging.INFO, 28 | handlers=[LoggingHandler()]) 29 | #### /print debug information to stdout 30 | 31 | exp_name = sys.argv[1] 32 | 33 | 34 | model_name = 'nreimers/MiniLM-L6-H384-uncased' 35 | batch_size_pairs = 256 36 | batch_size_triplets = 256 37 | steps_per_epoch = 2000 38 | 39 | num_epochs = 1 40 | max_seq_length = 128 41 | use_amp = True 42 | warmup_steps = 500 43 | 44 | ##### 45 | 46 | output_path = 'output/training_data_benchmark-{}-norm-{}'.format(model_name.replace("/", "-"), exp_name) 47 | logging.info("Output: "+output_path) 48 | if os.path.exists(output_path): 49 | exit() 50 | 51 | 52 | # Write train script to output path 53 | os.makedirs(output_path, exist_ok=True) 54 | 55 | train_script_path = os.path.join(output_path, 'train_script.py') 56 | copyfile(__file__, train_script_path) 57 | with open(train_script_path, 'a') as fOut: 58 | fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv)) 59 | 60 | ## SentenceTransformer model 61 | word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) 62 | pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension()) 63 | norm = models.Normalize() 64 | model = SentenceTransformer(modules=[word_embedding_model, pooling_model, norm]) 65 | 66 | datasets = [] 67 | for filepath in sys.argv[2:]: 68 | filepath = filepath.strip() 69 | dataset = [] 70 | 71 | with gzip.open(filepath, 'rt', encoding='utf8') as fIn: 72 | for line in fIn: 73 | data = json.loads(line.strip()) 74 | 75 | if not isinstance(data, dict): 76 | data = {'guid': None, 'texts': data} 77 | 78 | dataset.append(InputExample(guid=data.get('guid', None), texts=data['texts'])) 79 | if len(dataset) >= (steps_per_epoch * batch_size_pairs * 2): 80 | break 81 | 82 | datasets.append(dataset) 83 | logging.info("{}: {}".format(filepath, len(dataset))) 84 | 85 | 86 | # Special data loader to load from multiple datasets 87 | train_dataloader = MultiDatasetDataLoader(datasets, batch_size_pairs=batch_size_pairs, batch_size_triplets=batch_size_triplets) 88 | 89 | 90 | # Our training loss 91 | train_loss = losses.MultipleNegativesRankingLoss(model, scale=20, similarity_fct=util.dot_score) 92 | 93 | 94 | 95 | # Configure the training 96 | logging.info("Warmup-steps: {}".format(warmup_steps)) 97 | 98 | # Train the model 99 | model.fit(train_objectives=[(train_dataloader, train_loss)], 100 | evaluator=None, 101 | epochs=1, 102 | warmup_steps=warmup_steps, 103 | steps_per_epoch=steps_per_epoch, 104 | scheduler='warmupconstant', #Remove this line when you train on larger datasets. 
After warmup, LR will be constant 105 | use_amp=use_amp 106 | ) 107 | 108 | 109 | model.save(output_path) -------------------------------------------------------------------------------- /tests/test_ops.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as onp 3 | import jax.numpy as jnp 4 | import trainer.utils.ops as jax_util 5 | import torch 6 | from torch_impl import util as torch_util 7 | import torch.nn.functional as torch_F 8 | 9 | 10 | class UtilTest(unittest.TestCase): 11 | def test_cos_sim(self): 12 | """Tests the correct computation of utils.ops.cos_sim""" 13 | a = onp.random.randn(50, 100) 14 | b = onp.random.randn(50, 100) 15 | 16 | pytorch_cos_scores = torch_util.cos_sim(a, b).numpy() 17 | jax_cos_scores = onp.asarray(jax_util.cos_sim(a, b)) 18 | 19 | assert pytorch_cos_scores.shape == jax_cos_scores.shape 20 | for i in range(len(jax_cos_scores)): 21 | for j in range(len(jax_cos_scores[0])): 22 | assert abs(pytorch_cos_scores[i][j] - jax_cos_scores[i][j]) < 0.001, "Output : torch - {}, jax - {}" \ 23 | .format(pytorch_cos_scores[i], jax_cos_scores[i]) 24 | 25 | def test_normalize_L2(self): 26 | """Tests the correct computation of normalize_L2""" 27 | a = onp.random.randn(50, 100) 28 | 29 | pytorch_normalize = torch_F.normalize(torch.tensor(a), p=2, dim=1).numpy() 30 | jax_normalize = onp.asarray(jax_util.normalize_L2(a)) 31 | 32 | assert pytorch_normalize.shape == jax_normalize.shape 33 | for i in range(len(pytorch_normalize)): 34 | assert onp.all(onp.abs(pytorch_normalize[i] - jax_normalize[i]) < 0.001), "Output : torch - {}, jax - {}" \ 35 | .format(pytorch_normalize[i], jax_normalize[i]) 36 | 37 | def test_mean_pooling(self): 38 | """Tests the correct computation of mean_pooling""" 39 | batch_size = 3 40 | max_seq_len = 128 41 | embedding_size = 768 42 | 43 | model_outputs = (onp.random.randn(batch_size, max_seq_len, embedding_size),) 44 | attention_mask = onp.random.randint(2, size=(batch_size, max_seq_len)) 45 | 46 | model_outputs_pt = torch.tensor(model_outputs) 47 | attention_mask_pt = torch.tensor(attention_mask) 48 | embeddings_pt = torch_util.mean_pooling(model_outputs_pt, attention_mask_pt) 49 | 50 | model_outputs_jax = jnp.asarray(model_outputs) 51 | attention_mask_jax = jnp.asarray(attention_mask) 52 | embeddings_jax = jax_util.mean_pooling(model_outputs_jax, attention_mask_jax) 53 | 54 | assert embeddings_pt.numpy().shape == onp.array(embeddings_jax).shape 55 | assert onp.all(onp.abs(embeddings_pt.numpy() - onp.array(embeddings_jax)) < 0.001) 56 | 57 | def test_max_pooling(self): 58 | """Tests the correct computation of max_pooling""" 59 | batch_size = 3 60 | max_seq_len = 128 61 | embedding_size = 768 62 | 63 | model_outputs = (onp.random.randn(batch_size, max_seq_len, embedding_size),) 64 | attention_mask = onp.random.randint(2, size=(batch_size, max_seq_len)) 65 | 66 | model_outputs_pt = torch.tensor(model_outputs) 67 | attention_mask_pt = torch.tensor(attention_mask) 68 | embeddings_pt = torch_util.max_pooling(model_outputs_pt, attention_mask_pt) 69 | 70 | model_outputs_jax = jnp.asarray(model_outputs) 71 | attention_mask_jax = jnp.asarray(attention_mask) 72 | embeddings_jax = jax_util.max_pooling(model_outputs_jax, attention_mask_jax) 73 | 74 | assert embeddings_pt.numpy().shape == onp.array(embeddings_jax).shape 75 | assert onp.all(onp.abs(embeddings_pt.numpy() - onp.array(embeddings_jax)) < 0.001) 76 | 77 | def test_max_pooling(self): 78 | """Tests the correct computation of 
max_pooling""" 79 | batch_size = 3 80 | max_seq_len = 128 81 | embedding_size = 768 82 | 83 | model_outputs = (onp.random.randn(batch_size, max_seq_len, embedding_size),) 84 | attention_mask = onp.random.randint(2, size=(batch_size, max_seq_len)) 85 | 86 | model_outputs_pt = torch.tensor(model_outputs) 87 | attention_mask_pt = torch.tensor(attention_mask) 88 | embeddings_pt = torch_util.max_pooling(model_outputs_pt, attention_mask_pt) 89 | 90 | model_outputs_jax = jnp.asarray(model_outputs) 91 | attention_mask_jax = jnp.asarray(attention_mask) 92 | embeddings_jax = jax_util.max_pooling(model_outputs_jax, attention_mask_jax) 93 | 94 | assert embeddings_pt.numpy().shape == onp.array(embeddings_jax).shape 95 | assert onp.all(onp.abs(embeddings_pt.numpy() - onp.array(embeddings_jax)) < 0.001) 96 | 97 | def test_cls_pooling(self): 98 | """Tests the correct computation of cls_pooling""" 99 | batch_size = 3 100 | max_seq_len = 128 101 | embedding_size = 768 102 | 103 | model_outputs = (onp.random.randn(batch_size, max_seq_len, embedding_size),) 104 | 105 | model_outputs_pt = torch.tensor(model_outputs) 106 | embeddings_pt = torch_util.cls_pooling(model_outputs_pt) 107 | 108 | model_outputs_jax = jnp.asarray(model_outputs) 109 | embeddings_jax = jax_util.cls_pooling(model_outputs_jax) 110 | 111 | assert embeddings_pt.numpy().shape == onp.array(embeddings_jax).shape 112 | assert onp.all(onp.abs(embeddings_pt.numpy() - onp.array(embeddings_jax)) < 0.001) 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # flax-sentence-embeddings 2 | 3 | This repository will be used to share code for the Flax / JAX community event to train sentence embeddings on 1B+ training pairs. 4 | 5 | You can add your code by creating a pull request. 6 | 7 | 8 | ## Dataloading 9 | 10 | ### Dowload data 11 | 12 | You can download the data using this basic python script at the root of the project. 13 | Download should be completed in about 20 minutes given your connection speed. Total size on disk is arround 25G. 14 | 15 | ```bash 16 | python dataset/download_data.py --dataset_list=dataset/datasets_list.tsv --data_path=PATH_TO_STORE_DATASETS 17 | ``` 18 | On a different note: 19 | 20 | There is another directory called `dataset_list`, which contains a subdirectory called `stackexchange`. This subdirectory contains the script to download the compressed stackexchange `xml` files from the internet archive. Once downloaded, these compressed stackexchange xml files need to be converted to the `jsonl` format for training purpose. To transform, use the the `datasets/stackexchange/transforms.py` script, which generates the required input data format for training. A small file restructuring or cleanup is required, as the directories `dataset`, `datasets` and `dataset_list` could be confusing. 21 | 22 | ### Dataloading 23 | 24 | First implementation of the dataloader takes as input a single `jsonl.gz` file. 25 | It creates a pointer on the file such that samples are loaded one by one. 26 | The implementation is based on `torch` standard `Dataloader` and `Dataset` classes. 27 | The class supports `num_worker>0` such that data loading is done in a background process on the CPU, i.e. the data is loaded and tokenized in parallel to training the network. 28 | This avoid to create a bottleneck from I/O and tokenization. 
The implementation currently return `{'anchor': '...,' 'positive': '...'}` 29 | 30 | ``` 31 | from dataset.dataset import IterableCorpusDataset 32 | 33 | corpus_dataset = IterableCorpusDataset( 34 | file_path=os.path.join(PATH_TO_STORE_DATASETS, 'stackexchange_duplicate_questions_title_title.json.gz'), 35 | batch_size=2, 36 | num_workers=2, 37 | transform=None) 38 | 39 | corpus_dataset_itr = iter(corpus_dataset) 40 | next(corpus_dataset_itr) 41 | 42 | # {'anchor': 'Can anyone explain all these Developer Options?', 43 | # 'positive': 'what is the advantage of using the GPU rendering options in Android?'} 44 | 45 | def collate(batch_input_str): 46 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 47 | batch = {'anchor': tokenizer.batch_encode_plus([b['anchor'] for b in batch_input_str], pad_to_max_length=True), 48 | 'positive': tokenizer.batch_encode_plus([b['positive'] for b in batch_input_str], pad_to_max_length=True)} 49 | return batch 50 | 51 | corpus_dataloader = DataLoader( 52 | corpus_dataset, 53 | batch_size=2, 54 | num_workers=2, 55 | collate_fn=collate, 56 | pin_memory=False, 57 | drop_last=True, 58 | shuffle=False) 59 | 60 | print(next(iter(corpus_dataloader))) 61 | 62 | # {'anchor': {'input_ids': [[101, 4531, 2019, 2523, 2090, 2048, 4725, 1997, 2966, 8830, 1998, 1037, 7142, 8023, 102, 0, 0, 0], [101, 1039, 1001, 10463, 5164, 1061, 2100, 2100, 24335, 26876, 11927, 4779, 4779, 2102, 2000, 3058, 7292, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}, 'positive': {'input_ids': [[101, 1045, 2031, 2182, 2007, 2033, 1010, 2048, 4725, 1997, 8830, 1025, 1037, 3115, 2729, 4118, 1010, 1998, 1037, 17009, 8830, 1012, 2367, 3633, 4374, 2367, 4118, 1010, 2049, 2035, 18154, 11095, 1012, 1045, 2572, 2667, 2000, 2424, 1996, 2523, 1997, 1996, 17009, 8830, 1998, 1037, 1005, 2092, 2108, 3556, 1005, 2029, 2003, 1037, 15973, 3643, 1012, 2054, 2003, 1996, 2190, 2126, 2000, 2424, 2151, 8924, 1029, 1041, 1012, 1043, 1012, 8833, 6553, 26237, 2944, 1029, 102], [101, 1045, 2572, 2667, 2000, 10463, 1037, 5164, 3058, 2046, 1037, 4289, 2005, 29296, 3058, 7292, 1012, 1996, 4289, 2003, 2066, 1024, 1000, 2297, 2692, 20958, 2620, 17134, 19317, 19317, 1000, 1045, 2228, 2023, 1041, 16211, 4570, 2000, 1061, 2100, 2100, 24335, 26876, 11927, 4779, 4779, 2102, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0]]}} 63 | 64 | ``` 65 | 66 | 67 | ## Installation 68 | 69 | ### Poetry 70 | 71 | A Poetry `pyproject.toml` is provided to manage dependencies in a virtualenv. Check https://python-poetry.org/ 72 | 73 | Once you've installed Poetry, you can activate the virtual env and update the dependencies: 74 | 75 | ``` 76 | poetry shell 77 | poetry update 78 | poetry install 79 | ``` 80 | 81 | ### requirements.txt 82 | 83 | Generate it once on your platform with the following command. 84 | 85 | ``` 86 | poetry export -f requirements.txt --output requirements.txt 87 | ``` 88 | 89 | ### Rust compiler for Hugging Face tokenizers 90 | 91 | - Hugging Face tokenizers require a Rust compiler, so install one. 92 | 93 | ### Custom libs 94 | 95 | - If you want a specific version of any library, edit the pyproject.toml, add it and/or replace "*" with it. 96 | 97 | ## Running Tests 98 | 99 | Call this in the project folder to execute the unit tests. 100 | 101 | ``` 102 | python -m unittest discover -s tests 103 | ``` 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /examples/nils_flax_script/MultiDatasetDataLoader.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import random 4 | 5 | class MultiDatasetDataLoader: 6 | def __init__(self, datasets, batch_size_pairs, batch_size_triplets=None, dataset_size_temp=-1, allow_swap=True, random_batch_fraction=0): 7 | self.allow_swap = allow_swap 8 | self.collate_fn = None 9 | self.batch_size_pairs = batch_size_pairs 10 | self.batch_size_triplets = batch_size_pairs if batch_size_triplets is None else batch_size_triplets 11 | self.random_batch_fraction = random_batch_fraction 12 | 13 | # Compute dataset weights 14 | self.dataset_lengths = list(map(len, datasets)) 15 | self.dataset_lengths_sum = sum(self.dataset_lengths) 16 | 17 | weights = [] 18 | if dataset_size_temp > 0: # Scale probability with dataset size 19 | for dataset in datasets: 20 | prob = len(dataset) / self.dataset_lengths_sum 21 | weights.append(max(1, int(math.pow(prob, 1 / dataset_size_temp) * 1000))) 22 | else: # Equal weighting of all datasets 23 | weights = [100] * len(datasets) 24 | 25 | logging.info("Dataset lengths and weights: {}".format(list(zip(self.dataset_lengths, weights)))) 26 | 27 | self.dataset_idx = [] 28 | self.dataset_idx_pointer = 0 29 | 30 | for idx, weight in enumerate(weights): 31 | self.dataset_idx.extend([idx] * weight) 32 | random.shuffle(self.dataset_idx) 33 | 34 | self.datasets = [] 35 | for dataset in datasets: 36 | random.shuffle(dataset) 37 | self.datasets.append({ 38 | 'elements': dataset, 39 | 'pointer': 0, 40 | }) 41 | 42 | def __iter__(self): 43 | for _ in range(int(self.__len__())): 44 | if self.random_batch_fraction > 0 and len(self.datasets) > 1 and self.random_batch_fraction > random.random(): 45 | batch = self.batch_all_datasets() 46 | else: 47 | batch = self.batch_one_dataset() 48 | 49 | yield self.collate_fn(batch) if self.collate_fn is not None else batch 50 | 51 | def batch_one_dataset(self): 52 | # Select dataset 53 | if self.dataset_idx_pointer >= len(self.dataset_idx): 54 | self.dataset_idx_pointer = 0 55 | random.shuffle(self.dataset_idx) 56 | 57 | dataset_idx = self.dataset_idx[self.dataset_idx_pointer] 58 | self.dataset_idx_pointer += 1 59 | 60 | # Select batch from this dataset 61 | dataset = self.datasets[dataset_idx] 62 | batch_size = self.batch_size_pairs if len(dataset['elements'][0].texts) == 2 else self.batch_size_triplets
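        # Pair datasets (two texts per example) use batch_size_pairs; anything with more texts (e.g. triplets with a hard negative) uses batch_size_triplets.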
63 | 64 | batch = [] 65 | texts_in_batch = set() 66 | guid_in_batch = set() 67 | while len(batch) < batch_size: 68 | example = dataset['elements'][dataset['pointer']] 69 | 70 | valid_example = True 71 | # First check if one of the texts in already in the batch 72 | for text in example.texts: 73 | text_norm = text.strip().lower() 74 | if text_norm in texts_in_batch: 75 | valid_example = False 76 | 77 | texts_in_batch.add(text_norm) 78 | 79 | # If the example has a label, check if label is in batch 80 | if example.guid is not None: 81 | valid_example = valid_example and example.guid not in guid_in_batch 82 | guid_in_batch.add(example.guid) 83 | 84 | 85 | if valid_example: 86 | if self.allow_swap and random.random() > 0.5: 87 | example.texts[0], example.texts[1] = example.texts[1], example.texts[0] 88 | 89 | batch.append(example) 90 | 91 | dataset['pointer'] += 1 92 | if dataset['pointer'] >= len(dataset['elements']): 93 | dataset['pointer'] = 0 94 | random.shuffle(dataset['elements']) 95 | 96 | return batch 97 | 98 | def batch_all_datasets(self): 99 | batch_size = None 100 | text_length = None 101 | batch = [] 102 | texts_in_batch = set() 103 | guid_in_batch = set() 104 | while batch_size is None or len(batch) < batch_size: 105 | # Select dataset 106 | if self.dataset_idx_pointer >= len(self.dataset_idx): 107 | self.dataset_idx_pointer = 0 108 | random.shuffle(self.dataset_idx) 109 | 110 | dataset_idx = self.dataset_idx[self.dataset_idx_pointer] 111 | self.dataset_idx_pointer += 1 112 | 113 | # Select batch from this dataset 114 | dataset = self.datasets[dataset_idx] 115 | 116 | if batch_size is None: #First example in a batch 117 | batch_size = self.batch_size_pairs if len(dataset['elements'][0].texts) == 2 else self.batch_size_triplets 118 | text_length = len(dataset['elements'][0].texts) 119 | else: #Additional example, check if format is the same 120 | if len(dataset['elements'][0].texts) != text_length: 121 | continue 122 | 123 | #Get the example 124 | example = dataset['elements'][dataset['pointer']] 125 | 126 | valid_example = True 127 | # First check if one of the texts in already in the batch 128 | for text in example.texts: 129 | text_norm = text.strip().lower() 130 | if text_norm in texts_in_batch: 131 | valid_example = False 132 | 133 | texts_in_batch.add(text_norm) 134 | 135 | # If the example has a label, check if label is in batch 136 | if example.guid is not None: 137 | valid_example = valid_example and example.guid not in guid_in_batch 138 | guid_in_batch.add(example.guid) 139 | 140 | if valid_example: 141 | if self.allow_swap and random.random() > 0.5: 142 | example.texts[0], example.texts[1] = example.texts[1], example.texts[0] 143 | 144 | batch.append(example) 145 | 146 | dataset['pointer'] += 1 147 | if dataset['pointer'] >= len(dataset['elements']): 148 | dataset['pointer'] = 0 149 | random.shuffle(dataset['elements']) 150 | 151 | return batch 152 | 153 | def __len__(self): 154 | return int(self.dataset_lengths_sum / self.batch_size_pairs) -------------------------------------------------------------------------------- /dataset/datasets_list.tsv: -------------------------------------------------------------------------------- 1 | Name Description Size (#Pairs) Performance Download link Source 2 | stackexchange_title_body_small (Title, Body) pairs from different StackExchanges 364,001 59.83 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/stackexchange_title_body_small.jsonl.gz 3 | gooaq_pairs (Question, Answer)-Pairs from 
Google auto suggest 3,012,496 59.06 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/gooaq_pairs.jsonl.gz https://github.com/allenai/gooaq 4 | msmarco-query_passage_negative (Query, Answer_Passage, hard_negative) from MS MARCO dataset 9,144,553 58.76 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/msmarco-query_passage_negative.jsonl.gz https://microsoft.github.io/msmarco/ 5 | yahoo_answers_title_answer (Title, Answer) pairs from Yahoo Answers 1,198,260 58.65 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/yahoo_answers_title_answer.jsonl.gz https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset 6 | stackexchange_duplicate_questions_title_title (Title, Title) pairs of duplicate questions from StackExchange 304,525 58.47 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/stackexchange_duplicate_questions_title_title.jsonl.gz 7 | msmarco-query_passage (Query, Answer_Passage) from MS MARCO dataset 532,751 58.28 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/msmarco-query_passage.jsonl.gz https://microsoft.github.io/msmarco/ 8 | eli5_question_answer (Question, Answer)-Pairs from ELI5 dataset 325,475 58.24 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/eli5_question_answer.jsonl.gz https://huggingface.co/datasets/eli5 9 | yahoo_answers_title_question (Title, Question_Body) pairs from Yahoo Answers 659,896 58.05 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/yahoo_answers_title_question.jsonl.gz https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset 10 | squad_pairs (Question, Answer_Passage) Pairs from SQuAD dataset 87,599 58.02 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/squad_pairs.jsonl.gz https://rajpurkar.github.io/SQuAD-explorer/ 11 | yahoo_answers_question_answer (Question_Body, Answer) pairs from Yahoo Answers 681,164 57.74 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/yahoo_answers_question_answer.jsonl.gz https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset 12 | NQ-train_pairs Training pairs (query, answer_passage) from the NQ dataset 100,231 57.48 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/NQ-train_pairs.jsonl.gz https://ai.google.com/research/NaturalQuestions/ 13 | quora_duplicates Duplicate question pairs from Quora 103,663 57.36 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/quora_duplicates.jsonl.gz https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs 14 | WikiAnswers_pairs Duplicate questions pairs from WikiAnswers 77,427,422 57.34 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/WikiAnswers_pairs.jsonl.gz https://github.com/afader/oqa#wikianswers-corpus 15 | stackexchange_duplicate_questions_title-body_title-body (Title+Body, Title+Body) pairs of duplicate questions from StackExchange 250,460 57.3 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/stackexchange_duplicate_questions_title-body_title-body.jsonl.gz 16 | S2ORC_citation_pairs Citation Pairs (Title, Title) of scientific publications from the S2OR corpus 
52,603,982 57.28 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/S2ORC_citation_pairs.jsonl.gz http://s2-public-api-prod.us-west-2.elasticbeanstalk.com/corpus/ 17 | stackexchange_duplicate_questions_body_body (Body, Body) pairs of duplicate questions from StackExchange 250,519 57.26 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/stackexchange_duplicate_questions_body_body.jsonl.gz 18 | quora_duplicates_triplets Triplets (question, duplicate_question, hard_negative) from Quora 103,663 56.97 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/quora_duplicates_triplets.jsonl.gz https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs 19 | AllNLI "Combination of SNLI + MultiNLI 20 | Triplets: (Anchor, Entailment_Text, Contradiction_Text)" 277,230 56.57 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/AllNLI.jsonl.gz "https://nlp.stanford.edu/projects/snli/ 21 | https://cims.nyu.edu/~sbowman/multinli/" 22 | specter_train_triples Triplets (Title, related_title, hard_negative) for Scientific Publications from Specter 684,100 56.32 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/specter_train_triples.jsonl.gz https://arxiv.org/abs/2004.07180 23 | SimpleWiki Matched pairs (English_Wikipedia, Simple_English_Wikipedia) 102,225 56.15 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/SimpleWiki.jsonl.gz https://cs.pomona.edu/~dkauchak/simplification/ 24 | PAQ_pairs Training pairs (query, answer_passage) from the PAQ dataset 64,371,441 56.11 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/PAQ_pairs.jsonl.gz https://github.com/facebookresearch/PAQ 25 | altlex Matched pairs (English_Wikipedia, Simple_English_Wikipedia) 112,696 55.95 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/altlex.jsonl.gz https://github.com/chridey/altlex/ 26 | CodeSearchNet CodeSearchNet corpus is a dataset of 2 milllion (comment, code) pairs from opensource libraries hosted on GitHub. It contains code and documentation for several programming languages. 
1,151,414 55.80 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/codesearchnet.jsonl.gz https://huggingface.co/datasets/code_search_net 27 | sentence-compression Pairs (long_text, short_text) about sentence-compression 180,000 55.63 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/sentence-compression.jsonl.gz https://github.com/google-research-datasets/sentence-compression 28 | TriviaQA_pairs Pairs (query, answer) from TriviaQA dataset 73,346 55.56 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/TriviaQA_pairs.jsonl.gz https://huggingface.co/datasets/trivia_qa 29 | flickr30k_captions Image caption pairs for the same image from Flickr30k dataset 317,695 54.68 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/flickr30k_captions.jsonl.gz https://shannon.cs.illinois.edu/DenotationGraph/ 30 | coco_captions Image caption pairs for the same image from COCO dataset 828,395 53.77 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/coco_captions.jsonl.gz http://cocodataset.org/ 31 | fever_train Training data from the FEVER corpus 139,051 52.63 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/fever_train.jsonl.gz https://arxiv.org/abs/1803.05355 -------------------------------------------------------------------------------- /examples/nils_flax_script/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from the code-search-net example 3 | 4 | !! THE CODE IS NOT READY YET !! 5 | 6 | Open TODO: 7 | - Save the model 8 | - Evaluate the model if it actually learns sensible embeddings. E.g. 
evaluate on STS benchmark dataset 9 | - Compare results with PyTorch training script if comparable 10 | """ 11 | from jax.config import config 12 | 13 | from dataclasses import dataclass, field 14 | from functools import partial 15 | from typing import Callable, List, Union 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | import optax 20 | from flax import jax_utils, struct, traverse_util 21 | from flax.training import train_state 22 | from flax.training.common_utils import shard 23 | from tqdm.auto import tqdm 24 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 25 | from trainer.loss.custom import multiple_negatives_ranking_loss 26 | 27 | from transformers import AutoTokenizer, FlaxBertModel, FlaxAutoModel 28 | import sys 29 | import gzip 30 | import json 31 | import logging 32 | from sentence_transformers import InputExample 33 | from MultiDatasetDataLoader import MultiDatasetDataLoader 34 | from trainer.utils.ops import normalize_L2, mean_pooling 35 | 36 | 37 | @dataclass 38 | class TrainingArgs: 39 | model_id: str = "microsoft/MiniLM-L12-H384-uncased" 40 | max_epochs: int = 2 41 | batch_size: int = 256 42 | seed: int = 42 43 | lr: float = 2e-5 44 | init_lr: float = 0 45 | warmup_steps: int = 500 46 | weight_decay: float = 1e-2 47 | 48 | input1_maxlen: int = 128 49 | input2_maxlen: int = 128 50 | 51 | tr_data_files: List[str] = field( 52 | default_factory=lambda: [ 53 | "data/quora_duplicates.jsonl.gz", 54 | ] 55 | ) 56 | 57 | 58 | def warmup_and_constant(lr, init_lr, warmup_steps): 59 | warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps) 60 | constant_fn = optax.constant_schedule(value=lr) 61 | lr = optax.join_schedules(schedules=[warmup_fn, constant_fn], boundaries=[warmup_steps]) 62 | return lr 63 | 64 | 65 | def build_tx(lr, init_lr, warmup_steps, weight_decay): 66 | def weight_decay_mask(params): 67 | params = traverse_util.flatten_dict(params) 68 | mask = {k: (v[-1] != "bias" and v[-2:] != ("LayerNorm", "scale")) for k, v in params.items()} 69 | return traverse_util.unflatten_dict(mask) 70 | lr = warmup_and_constant(lr, init_lr, warmup_steps) 71 | tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask) 72 | return tx, lr 73 | 74 | 75 | class TrainState(train_state.TrainState): 76 | loss_fn: Callable = struct.field(pytree_node=False) 77 | scheduler_fn: Callable = struct.field(pytree_node=False) 78 | 79 | 80 | 81 | @partial(jax.pmap, axis_name="batch") 82 | def train_step(state, model_input1, model_input2, drp_rng): 83 | train = True 84 | new_drp_rng, drp_rng = jax.random.split(drp_rng, 2) 85 | 86 | def loss_fn(params, model_input1, model_input2, drp_rng): 87 | def _forward(model_input): 88 | attention_mask = model_input["attention_mask"] 89 | model_output = state.apply_fn(**model_input, params=params, train=train, dropout_rng=drp_rng) 90 | 91 | embedding = mean_pooling(model_output, attention_mask) 92 | embedding = normalize_L2(embedding) 93 | 94 | # gather all the embeddings on same device for calculation loss over global batch 95 | embedding = jax.lax.all_gather(embedding, axis_name="batch") 96 | embedding = jnp.reshape(embedding, (-1, embedding.shape[-1])) 97 | 98 | return embedding 99 | 100 | embedding1, embedding2 = _forward(model_input1), _forward(model_input2) 101 | return state.loss_fn(embedding1, embedding2) 102 | 103 | grad_fn = jax.value_and_grad(loss_fn) 104 | loss, grads = grad_fn(state.params, model_input1, model_input2, drp_rng) 105 | state = 
state.apply_gradients(grads=grads) 106 | 107 | step = jax.lax.pmean(state.step, axis_name="batch") 108 | metrics = {"tr_loss": loss, "lr": state.scheduler_fn(step)} 109 | 110 | return state, metrics, new_drp_rng 111 | 112 | 113 | def get_batched_dataset(dataset, batch_size, seed=None): 114 | if seed is not None: 115 | dataset = dataset.shuffle(seed=seed) 116 | for i in range(len(dataset) // batch_size): 117 | batch = dataset[i*batch_size: (i+1)*batch_size] 118 | yield dict(batch) 119 | 120 | 121 | def data_collator(batch, tokenizer): 122 | texts1 = [e.texts[0] for e in batch] 123 | texts2 = [e.texts[1] for e in batch] 124 | 125 | model_input1 = tokenizer(texts1, return_tensors="jax", max_length=128, truncation=True, padding=True, pad_to_multiple_of=32) 126 | model_input2 = tokenizer(texts2, return_tensors="jax", max_length=128, truncation=True, padding=True, pad_to_multiple_of=32) 127 | model_input1, model_input2 = dict(model_input1), dict(model_input2) 128 | return shard(model_input1), shard(model_input2) 129 | 130 | 131 | def main(args, train_dataloader): 132 | config.update("jax_enable_x64", True) 133 | model = FlaxAutoModel.from_pretrained(args.model_id) 134 | tokenizer = AutoTokenizer.from_pretrained(args.model_id) 135 | 136 | 137 | tx_args = { 138 | "lr": args.lr, 139 | "init_lr": args.init_lr, 140 | "warmup_steps": args.warmup_steps, 141 | "weight_decay": args.weight_decay, 142 | } 143 | tx, lr = build_tx(**tx_args) 144 | 145 | state = TrainState.create( 146 | apply_fn=model.__call__, 147 | params=model.params, 148 | tx=tx, 149 | loss_fn=multiple_negatives_ranking_loss, 150 | scheduler_fn=lr, 151 | ) 152 | state = jax_utils.replicate(state) 153 | 154 | rng = jax.random.PRNGKey(args.seed) 155 | drp_rng = jax.random.split(rng, jax.device_count()) 156 | 157 | print("Train steps:", len(train_dataloader)) 158 | for epoch in range(args.max_epochs): 159 | # training step 160 | for batch in tqdm(train_dataloader, total=len(train_dataloader), desc=f"Running epoch-{epoch}"): 161 | model_input1, model_input2 = data_collator(batch, tokenizer) 162 | state, metrics, drp_rng = train_step(state, model_input1, model_input2, drp_rng) 163 | 164 | # evaluation step 165 | # for batch in get_batched_dataset(val_dataset, args.batch_size, seed=None): 166 | # model_input1, model_input2 = data_collator(batch) 167 | # state, metric = val_step(state, model_input1, model_input2) 168 | 169 | 170 | if __name__ == '__main__': 171 | 172 | steps_per_epoch = 2000 173 | batch_size_pairs = 256 174 | batch_size_triplets = 256 175 | 176 | datasets = [] 177 | for filepath in sys.argv[1:]: 178 | filepath = filepath.strip() 179 | dataset = [] 180 | 181 | with gzip.open(filepath, 'rt', encoding='utf8') as fIn: 182 | for line in fIn: 183 | data = json.loads(line.strip()) 184 | 185 | if not isinstance(data, dict): 186 | data = {'guid': None, 'texts': data} 187 | 188 | dataset.append(InputExample(guid=data.get('guid', None), texts=data['texts'])) 189 | if len(dataset) >= (steps_per_epoch * batch_size_pairs * 2): 190 | break 191 | 192 | datasets.append(dataset) 193 | logging.info("{}: {}".format(filepath, len(dataset))) 194 | 195 | train_dataloader = MultiDatasetDataLoader(datasets, batch_size_pairs=batch_size_pairs, batch_size_triplets=batch_size_triplets, random_batch_fraction=0.25) 196 | 197 | args = TrainingArgs() 198 | main(args, train_dataloader) -------------------------------------------------------------------------------- /datasets/stackexchange/transforms.py: 
--------------------------------------------------------------------------------
1 | """
2 | Converts the archive from
3 | https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml
4 | 
5 | to JSONL format
6 | 
7 | python convert_archives.py input_folder output_folder
8 | 
9 | Returns the following combinations that pass certain quality checks:
10 | -> title, body combination
11 | -> title, highest_score_answer combination
12 | -> title + body, highest_score_answer combination
13 | -> title + body, highest_score_answer, lowest_score_answer combination
14 | """
15 | #=====================================================================================
16 | import os
17 | import glob
18 | import json
19 | import gzip
20 | import random
21 | from typing import List, Any, IO, Dict
22 | from tqdm import tqdm
23 | import xml.etree.ElementTree as ET
24 | import re
25 | import py7zr
26 | import sys
27 | 
28 | input_folder = sys.argv[1]
29 | output_folder = sys.argv[2]
30 | os.makedirs(output_folder, exist_ok=False)
31 | 
32 | title_body_folder = output_folder + "/TitleBody/"
33 | os.makedirs(title_body_folder, exist_ok=False)
34 | 
35 | title_answer_folder = output_folder + "/TitleAnswer/"
36 | os.makedirs(title_answer_folder, exist_ok=False)
37 | 
38 | titlebody_answer_folder = output_folder + "/TitleBodyAnswer/"
39 | os.makedirs(titlebody_answer_folder, exist_ok=False)
40 | 
41 | titlebody_best_worst_answer_folder = output_folder + "/TitleBodyBestWorstAnswer/"
42 | os.makedirs(titlebody_best_worst_answer_folder, exist_ok=False)
43 | 
44 | random.seed(42)
45 | 
46 | min_title_len = 20
47 | min_body_len = 20
48 | max_body_len = 4096
49 | min_score = 0
50 | #=======================================================================================
51 | #Creates a dictionary entry for every question, with the question id as key and a list of the question/answer details as value
52 | def create_dict_for_questions(posts):
53 |     num_questions = 0
54 |     mydict = {}
55 |     for post in posts:
56 |         data = post.attrib
57 |         if data["PostTypeId"] == "1":
58 |             add_data = []
59 |             title = re.sub("<.*?>", "", data["Title"]).strip()
60 |             body = re.sub("<.*?>", "", data["Body"]).strip()
61 |             num_questions += 1
62 |             add_data.append(title)
63 |             add_data.append(body)
64 |             score = int(data["Score"])
65 |             if len(title) < min_title_len or len(body) < min_body_len or len(body) > max_body_len or score < min_score: #Skip questions whose title/body length or score is outside the limits above
66 |                 continue
67 |             mydict[int(data["Id"])] = add_data
68 | 
69 |     #For every answer, tracks the best/worst answer for its corresponding question and updates the entry accordingly.
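    # Sketch of the value layout that the loop below and the extract_* helpers rely on
    # (entries grow from 2 to 6 elements once at least one answer has been seen):
    #   mydict[question_id] = [title, body, best_answer, best_score, worst_answer, worst_score]
    # e.g., with made-up values:
    #   mydict[42] = ["How do I ...?", "Question body ...", "Top-scored answer", 17, "Low-scored answer", -2]
    # Indices [2]/[3] always hold the highest-scored answer seen so far, indices [4]/[5] the lowest-scored one.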
70 |     for post in posts:
71 |         data = post.attrib
72 |         if data["PostTypeId"] == "2":
73 |             q_id = int(data["ParentId"])
74 |             if (q_id in mydict.keys()): #Only keep answers whose question passed the filters above
75 |                 answer = re.sub("<.*?>", "", data["Body"]).strip()
76 |                 score = int(data["Score"])
77 |                 if len(mydict[q_id]) <= 2: #First answer seen for this question
78 |                     mydict[q_id].append(answer) #Initialize the best answer
79 |                     mydict[q_id].append(score)
80 |                     mydict[q_id].append(answer) #Initialize the worst answer
81 |                     mydict[q_id].append(score)
82 |                 else:
83 |                     if mydict[q_id][3] < score: #New answer scores higher than the current best answer
84 |                         mydict[q_id][3] = score
85 |                         mydict[q_id][2] = answer
86 |                     elif mydict[q_id][5] > score: #New answer scores lower than the current worst answer
87 |                         mydict[q_id][5] = score
88 |                         mydict[q_id][4] = answer
89 |     return mydict
90 | #==========================================================================================
91 | def extract_title_body(mydict):
92 |     pairs = [] #title_body combination
93 |     for key in mydict:
94 |         pairs.append([ mydict[key][0],mydict[key][1] ]) #title, body
95 |     return pairs
96 | 
97 | def extract_title_highestscored(mydict):
98 |     pairs = [] #title_highestScoreAnswer
99 |     for key in mydict:
100 |         if len(mydict[key])>2:
101 |             pairs.append([ mydict[key][0],mydict[key][2] ]) #title, highest scored answer
102 |     return pairs
103 | 
104 | def extract_title_body_highscore(mydict):
105 |     pairs = [] #title_body_highestScoreAnswer
106 |     for key in mydict:
107 |         if len(mydict[key])>2:
108 |             pairs.append([ mydict[key][0]+ " " +mydict[key][1], mydict[key][2] ]) #title+body, highest scored answer
109 |     return pairs
110 | 
111 | def extract_title_body_highscore_lowscore(mydict):
112 |     pairs = [] #title_body_highestScoreAnswer_lowestScoredAnswer
113 |     for key in mydict:
114 |         if len(mydict[key])>5:
115 |             if ((mydict[key][3] - mydict[key][5] >= 100) or (mydict[key][5]<0 )): #If the best and worst answers differ by at least 100 votes, or the worst answer has a negative score
116 |                 pairs.append([ mydict[key][0]+ " " +mydict[key][1], mydict[key][2], mydict[key][4] ]) #title+body, highest scored, lowest scored
117 |     return pairs
118 | #===========================================================================================
119 | def parse_posts(f: IO[Any]) -> List[Dict]:
120 |     tree = ET.parse(f)
121 |     posts = tree.getroot()
122 |     pairs = []
123 |     tags = []
124 |     num_questions = 0
125 |     mydict = create_dict_for_questions(posts)
126 |     return mydict
127 | #============================================================================================
128 | def extract_posts(stack_exchange_file: str) -> List[Dict]:
129 |     with py7zr.SevenZipFile(stack_exchange_file, mode="r") as z:
130 |         fs = z.read(targets=["Posts.xml"])
131 |         if fs is not None and "Posts.xml" in fs:
132 |             posts = parse_posts(fs["Posts.xml"])
133 |             return posts
134 |     return []
135 | #=============================================================================================
136 | def convert_to_jsonl_gz(input_file: str, output_file: str) -> None:
137 |     mydict = extract_posts(input_file)
138 |     #save title_body combination
139 |     posts = extract_title_body(mydict)
140 |     random.shuffle(posts)
141 |     output_file = os.path.join(title_body_folder, f"{name}.jsonl.gz")
142 |     if len(posts) == 0:
143 |         return
144 |     fOut = gzip.open(output_file, "wt")
145 |     for post in posts:
146 |         fOut.write(json.dumps(post))
147 |         fOut.write("\n")
148 |     fOut.close()
149 |     #save title_highestScoreAnswer combination
150 |
posts = extract_title_highestscored(mydict) 151 | random.shuffle(posts) 152 | output_file = os.path.join(title_answer_folder, f"{name}.jsonl.gz") 153 | if len(posts) == 0: 154 | return 155 | fOut = gzip.open(output_file, "wt") 156 | for post in posts: 157 | fOut.write(json.dumps(post)) 158 | fOut.write("\n") 159 | fOut.close() 160 | #save title_body_highestScoreAnswer combination 161 | posts = extract_title_body_highscore(mydict) 162 | random.shuffle(posts) 163 | output_file = os.path.join(titlebody_answer_folder, f"{name}.jsonl.gz") 164 | if len(posts) == 0: 165 | return 166 | fOut = gzip.open(output_file, "wt") 167 | for post in posts: 168 | fOut.write(json.dumps(post)) 169 | fOut.write("\n") 170 | fOut.close() 171 | #save title_body_highestScoreAnswer_lowlyScoredAnswer combination 172 | posts = extract_title_body_highscore_lowscore(mydict) 173 | random.shuffle(posts) 174 | output_file = os.path.join(titlebody_best_worst_answer_folder, f"{name}.jsonl.gz") 175 | if len(posts) == 0: 176 | return 177 | fOut = gzip.open(output_file, "wt") 178 | for post in posts: 179 | fOut.write(json.dumps(post)) 180 | fOut.write("\n") 181 | fOut.close() 182 | #======================================================================================================== 183 | for filepath in sorted(glob.glob(os.path.join(input_folder, "*.7z")), key=os.path.getsize, reverse=True): 184 | name = os.path.basename(filepath.strip(".7z")) 185 | output_path = os.path.join(output_folder, f"{name}.jsonl.gz") 186 | print(filepath) 187 | convert_to_jsonl_gz(filepath, output_path) 188 | -------------------------------------------------------------------------------- /conversational-model/multi_context_train.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from functools import partial 3 | from typing import Callable, List, Union 4 | 5 | import jax 6 | from jax.config import config 7 | config.update('jax_enable_x64', True) 8 | import jax.numpy as jnp 9 | import optax 10 | from flax import jax_utils, struct, traverse_util 11 | from flax.training import train_state 12 | from flax.training.common_utils import shard 13 | from tqdm.auto import tqdm 14 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 15 | from transformers import AutoTokenizer, FlaxBertModel 16 | from datasets import load_dataset 17 | 18 | 19 | 20 | 21 | @dataclass 22 | class TrainingArgs: 23 | model_id: str = "bert-base-uncased" 24 | max_epochs: int = 5 25 | batch_size: int = 2 26 | seed: int = 42 27 | lr: float = 2e-5 28 | init_lr: float = 1e-5 29 | warmup_steps: int = 2000 30 | weight_decay: float = 1e-3 31 | 32 | current_context_maxlen: int = 128 33 | past_context_maxlen: int = 128 34 | response_maxlen:int = 128 35 | max_past_contexts:int = 10 36 | 37 | tr_data_files: List[str] = field( 38 | default_factory=lambda: [ 39 | "data/dummy.jsonl", 40 | ] 41 | ) 42 | 43 | 44 | def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps): 45 | decay_steps = num_train_steps - warmup_steps 46 | warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps) 47 | decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps) 48 | lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]) 49 | return lr 50 | 51 | 52 | def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay): 53 | def weight_decay_mask(params): 54 | params = traverse_util.flatten_dict(params) 55 | 
        mask = {k: (v[-1] != "bias" and v[-2:] != ("LayerNorm", "scale")) for k, v in params.items()}
56 |         return traverse_util.unflatten_dict(mask)
57 |     lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)
58 |     tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
59 |     return tx, lr
60 | 
61 | 
62 | class TrainState(train_state.TrainState):
63 |     loss_fn: Callable = struct.field(pytree_node=False)
64 |     scheduler_fn: Callable = struct.field(pytree_node=False)
65 | 
66 | 
67 | def multiple_negative_ranking_loss(embedding1, embedding2):
68 |     def _cross_entropy(logits):
69 |         bsz = logits.shape[-1]
70 |         labels = (jnp.arange(bsz)[..., None] == jnp.arange(bsz)[None]).astype("f4")
71 |         logits = jax.nn.log_softmax(logits, axis=-1)
72 |         loss = -jnp.sum(labels * logits, axis=-1)
73 |         return loss
74 | 
75 |     batch_similarity = jnp.dot(embedding1, jnp.transpose(embedding2))
76 |     return _cross_entropy(batch_similarity)
77 | 
78 | 
79 | @partial(jax.pmap, axis_name="batch")
80 | def train_step(state, current_context_input, response_input, past_context_input, drp_rng):
81 |     train = True
82 |     new_drp_rng, drp_rng = jax.random.split(drp_rng, 2)
83 | 
84 |     def loss_fn(params, current_context_input, response_input, past_context_input, drp_rng):
85 |         def _forward(model_input):
86 |             attention_mask = model_input["attention_mask"][..., None]
87 |             embedding = state.apply_fn(**model_input, params=params, train=train, dropout_rng=drp_rng)[0]
88 |             attention_mask = jnp.broadcast_to(attention_mask, jnp.shape(embedding))
89 | 
90 |             embedding = embedding * attention_mask
91 |             embedding = jnp.mean(embedding, axis=1)
92 | 
93 |             modulus = jnp.sqrt(jnp.sum(jnp.square(embedding), axis=-1, keepdims=True)) # L2 norm of each pooled embedding
94 |             embedding = embedding / jnp.maximum(modulus, 1e-12) # L2-normalize so the dot products in the loss are cosine similarities
95 | 
96 |             # gather all the embeddings on the same device to compute the loss over the global batch
97 |             embedding = jax.lax.all_gather(embedding, axis_name="batch")
98 |             embedding = jnp.reshape(embedding, (-1, embedding.shape[-1]))
99 | 
100 |             return embedding
101 | 
102 |         current_context_emb, response_emb, past_context_emb = _forward(current_context_input), _forward(response_input), _forward(past_context_input)
103 |         full_context_emb = (current_context_emb + past_context_emb)/2
104 |         current_context_response_loss = state.loss_fn(current_context_emb, response_emb)
105 |         past_context_response_loss = state.loss_fn(past_context_emb, response_emb)
106 |         full_context_response_loss = state.loss_fn(full_context_emb, response_emb)
107 |         # the loss combines
108 |         # 1) the interaction between the immediate context and its accompanying response,
109 |         # 2) the interaction of the response with up to N past contexts from the conversation history,
110 |         # and 3) the interaction of the full (averaged) context with the response
111 |         loss = (current_context_response_loss + past_context_response_loss + full_context_response_loss) / 3
112 |         return jnp.mean(loss)
113 | 
114 |     grad_fn = jax.value_and_grad(loss_fn)
115 |     loss, grads = grad_fn(state.params, current_context_input, response_input, past_context_input, drp_rng)
116 |     state = state.apply_gradients(grads=grads)
117 | 
118 |     step = jax.lax.pmean(state.step, axis_name="batch")
119 |     metrics = {"tr_loss": loss, "lr": state.scheduler_fn(step)}
120 | 
121 |     return state, metrics, new_drp_rng
122 | 
123 | 
124 | def get_batched_dataset(dataset, batch_size, seed=None):
125 |     if seed is not None:
126 |         dataset = dataset.shuffle(seed=seed)
127 |     for i in range(len(dataset) // batch_size):
128 |         batch = dataset[i*batch_size: (i+1)*batch_size]
129 |
yield dict(batch) 130 | 131 | 132 | @dataclass 133 | class DataCollator: 134 | tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] 135 | current_context_maxlen: int = 128 136 | response_maxlen: int = 128 137 | past_context_maxlen: int = 128 138 | max_past_contexts:int = 10 139 | 140 | def _prepare_past_context(self,batch): 141 | """ 142 | concatenation of past contexts - contexts are sorted to have the most recent context first and so on 143 | """ 144 | past_contexts = [] 145 | keys = list(batch.keys()) 146 | keys.sort() 147 | past_context_tuples = zip(*[batch[key] for key in keys if key.startswith("context/")]) 148 | for past_context_tuple in past_context_tuples: 149 | past_context_tuple = tuple(p_ctxt for p_ctxt in past_context_tuple if p_ctxt is not None) 150 | past_contexts.append(" ".join(past_context_tuple[:self.max_past_contexts])) 151 | return past_contexts 152 | 153 | def __call__(self, batch): 154 | # Currently only static padding; TODO: change below for adding dynamic padding support 155 | past_contexts = self._prepare_past_context(batch) 156 | current_context_input = self.tokenizer(batch["context"], return_tensors="jax", max_length=self.current_context_maxlen, truncation=True, padding="max_length") 157 | response_input = self.tokenizer(batch["response"], return_tensors="jax", max_length=self.response_maxlen, truncation=True, padding="max_length") 158 | past_context_input = self.tokenizer(past_contexts, return_tensors="jax",max_length=self.past_context_maxlen, truncation=True, padding="max_length") 159 | current_context_input, response_input,past_context_input = dict(current_context_input), dict(response_input),dict(past_context_input) 160 | return shard(current_context_input), shard(response_input), shard(past_context_input) 161 | 162 | 163 | def main(args): 164 | # code is generic to any other model as well 165 | model = FlaxBertModel.from_pretrained(args.model_id) 166 | tokenizer = AutoTokenizer.from_pretrained(args.model_id) 167 | 168 | data_collator = DataCollator( 169 | tokenizer=tokenizer, 170 | current_context_maxlen=args.current_context_maxlen, 171 | response_maxlen=args.response_maxlen, 172 | past_context_maxlen=args.past_context_maxlen, 173 | max_past_contexts = args.max_past_contexts 174 | ) 175 | 176 | tr_dataset = load_dataset("json", data_files=args.tr_data_files, split="train") 177 | columns_to_remove = ['response_author', 'context_author', 'subreddit', 'thread_id'] 178 | tr_dataset = tr_dataset.remove_columns(columns_to_remove) 179 | # drop extra batch from the end 180 | num_tr_samples = len(tr_dataset) - len(tr_dataset) % args.batch_size 181 | tr_dataset = tr_dataset.shuffle(seed=args.seed).select(range(num_tr_samples)) 182 | print(tr_dataset) 183 | 184 | tx_args = { 185 | "lr": args.lr, 186 | "init_lr": args.init_lr, 187 | "warmup_steps": args.warmup_steps, 188 | "num_train_steps": (len(tr_dataset) // args.batch_size) * args.max_epochs, 189 | "weight_decay": args.weight_decay, 190 | } 191 | tx, lr = build_tx(**tx_args) 192 | 193 | state = TrainState.create( 194 | apply_fn=model.__call__, 195 | params=model.params, 196 | tx=tx, 197 | loss_fn=multiple_negative_ranking_loss, 198 | scheduler_fn=lr, 199 | ) 200 | state = jax_utils.replicate(state) 201 | 202 | rng = jax.random.PRNGKey(args.seed) 203 | drp_rng = jax.random.split(rng, jax.device_count()) 204 | for epoch in range(args.max_epochs): 205 | # training step 206 | batch_iterator = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch) 207 | for batch in tqdm(batch_iterator, desc=f"Running 
epoch-{epoch}"): 208 | current_context_input, response_input,past_context_input = data_collator(batch) 209 | state, metrics, drp_rng = train_step(state,current_context_input, response_input,past_context_input, drp_rng) 210 | print(metrics) 211 | 212 | # evaluation step 213 | # for batch in get_batched_dataset(val_dataset, args.batch_size, seed=None): 214 | # model_input1, model_input2 = data_collator(batch) 215 | # state, metric = val_step(state, model_input1, model_input2) 216 | 217 | if __name__ == "__main__": 218 | import json,os 219 | 220 | jsons = [{'context': 'Taste good though. ', 221 | 'context/0': 'Deer are, in fact, massive ****.', 222 | 'context/1': "Freaking deer. You can always tell the country people. They can't stand deer. Just giant garden destroying rats. \n\nHell, that's", 223 | 'context/3': 'I lived in Germany when I was 5. I have this memory of my dad stopping to pick up a hedgehog that was in the middle of the road', 224 | 'context/2': "Kinda like when visitors from more populated areas see deer here. They're not at all rare but people act like they are", 225 | 'context_author': 'KakarotMaag', 226 | 'response': "Ground venison mixed with beef fat is the best burger I've ever had.", 227 | 'response_author': 'SG804', 228 | 'subreddit': 'gifs', 229 | 'thread_id': '98fur0'},{ 230 | 'context/1': "Hello, how are you?", 231 | 'context/0': "I am fine. And you?", 232 | 'context': "Great. What do you think of the weather?", 233 | 'response': "It doesn't feel like February.", 234 | 'context_author': 'KakarotMaag', 235 | 'response_author': 'SG804', 236 | 'subreddit': 'gifs', 237 | 'thread_id': '98fur0' 238 | }] 239 | 240 | os.makedirs("data/",exist_ok=True) 241 | with open('data/dummy.jsonl', 'w') as outfile: 242 | for entry in jsons: 243 | json.dump(entry, outfile) 244 | outfile.write('\n') 245 | 246 | args = TrainingArgs() 247 | main(args) 248 | 249 | -------------------------------------------------------------------------------- /code-search-net/train_code_search_net.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field, asdict, replace 2 | from functools import partial 3 | from typing import Callable, List, Union 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import optax 8 | from flax import jax_utils, struct, traverse_util 9 | from flax.training import train_state 10 | from flax.serialization import to_bytes, from_bytes 11 | from flax.training.common_utils import shard 12 | from tqdm.auto import tqdm 13 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 14 | from trainer.loss.custom import multiple_negatives_ranking_loss 15 | 16 | import wandb 17 | import json 18 | import os 19 | 20 | from datasets import load_dataset, DatasetDict 21 | from transformers import AutoTokenizer, FlaxAutoModel 22 | 23 | 24 | @dataclass 25 | class TrainingArgs: 26 | model_id: str = "bert-base-uncased" 27 | max_epochs: int = 2 28 | batch_size_per_device: int = 32 29 | seed: int = 42 30 | lr: float = 2e-5 31 | init_lr: float = 1e-5 32 | warmup_steps: int = 2000 33 | weight_decay: float = 1e-3 34 | 35 | input1_maxlen: int = 128 36 | input2_maxlen: int = 128 37 | 38 | logging_steps: int = 20 39 | save_dir: str = "checkpoints" 40 | 41 | tr_data_files: List[str] = field( 42 | default_factory=lambda: [ 43 | "data/dummy.jsonl", 44 | ] 45 | ) 46 | 47 | val_data_files: List[str] = field( 48 | default_factory=lambda: [ 49 | "data/dummy.jsonl", 50 | ] 51 | ) 52 | 53 | def __post_init__(self): 54 | 
        self.batch_size = self.batch_size_per_device * jax.device_count()
55 | 
56 | 
57 | def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
58 |     decay_steps = num_train_steps - warmup_steps
59 |     warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
60 |     decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps)
61 |     lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
62 |     return lr
63 | 
64 | 
65 | def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
66 |     def weight_decay_mask(params):
67 |         params = traverse_util.flatten_dict(params)
68 |         mask = {k: (v[-1] != "bias" and v[-2:] != ("LayerNorm", "scale")) for k, v in params.items()}
69 |         return traverse_util.unflatten_dict(mask)
70 |     lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)
71 |     tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
72 |     return tx, lr
73 | 
74 | 
75 | class TrainState(train_state.TrainState):
76 |     loss_fn: Callable = struct.field(pytree_node=False)
77 |     scheduler_fn: Callable = struct.field(pytree_node=False)
78 | 
79 | 
80 | @partial(jax.pmap, axis_name="batch")
81 | def train_step(state, model_input1, model_input2, drp_rng):
82 |     train = True
83 |     new_drp_rng, drp_rng = jax.random.split(drp_rng, 2)
84 | 
85 |     def loss_fn(params, model_input1, model_input2, drp_rng):
86 |         def _forward(model_input):
87 |             attention_mask = model_input["attention_mask"][..., None]
88 |             embedding = state.apply_fn(**model_input, params=params, train=train, dropout_rng=drp_rng)[0]
89 |             attention_mask = jnp.broadcast_to(attention_mask, jnp.shape(embedding))
90 | 
91 |             embedding = embedding * attention_mask
92 |             embedding = jnp.mean(embedding, axis=1)
93 | 
94 |             modulus = jnp.sqrt(jnp.sum(jnp.square(embedding), axis=-1, keepdims=True)) # L2 norm of each pooled embedding
95 |             embedding = embedding / jnp.maximum(modulus, 1e-12) # L2-normalize before computing similarities in the loss
96 | 
97 |             # gather all the embeddings on the same device to compute the loss over the global batch
98 |             embedding = jax.lax.all_gather(embedding, axis_name="batch")
99 |             embedding = jnp.reshape(embedding, (-1, embedding.shape[-1]))
100 | 
101 |             return embedding
102 | 
103 |         embedding1, embedding2 = _forward(model_input1), _forward(model_input2)
104 |         return state.loss_fn(embedding1, embedding2)
105 | 
106 |     grad_fn = jax.value_and_grad(loss_fn)
107 |     loss, grads = grad_fn(state.params, model_input1, model_input2, drp_rng)
108 |     state = state.apply_gradients(grads=grads)
109 | 
110 |     metrics = {"tr_loss": loss, "lr": state.scheduler_fn(state.step)}
111 |     return state, metrics, new_drp_rng
112 | 
113 | 
114 | @partial(jax.pmap, axis_name="batch")
115 | def val_step(state, model_inputs1, model_inputs2):
116 |     train = False
117 | 
118 |     def _forward(model_input):
119 |         attention_mask = model_input["attention_mask"][..., None]
120 |         embedding = state.apply_fn(**model_input, params=state.params, train=train)[0]
121 |         attention_mask = jnp.broadcast_to(attention_mask, jnp.shape(embedding))
122 | 
123 |         embedding = embedding * attention_mask
124 |         embedding = jnp.mean(embedding, axis=1)
125 | 
126 |         modulus = jnp.sqrt(jnp.sum(jnp.square(embedding), axis=-1, keepdims=True)) # L2 norm of each pooled embedding
127 |         embedding = embedding / jnp.maximum(modulus, 1e-12) # L2-normalize before computing similarities in the loss
128 | 
129 |         # gather all the embeddings on the same device to compute the loss over the global batch
130 |         embedding = jax.lax.all_gather(embedding, axis_name="batch")
131 |         embedding = jnp.reshape(embedding, (-1, embedding.shape[-1]))
132 | 
133 |         return embedding
134 | 
135 |     embedding1, embedding2 =
_forward(model_inputs1), _forward(model_inputs2) 136 | loss = state.loss_fn(embedding1, embedding2) 137 | return jnp.mean(loss) 138 | 139 | 140 | def get_batched_dataset(dataset, batch_size, seed=None): 141 | if seed is not None: 142 | dataset = dataset.shuffle(seed=seed) 143 | for i in range(len(dataset) // batch_size): 144 | batch = dataset[i*batch_size: (i+1)*batch_size] 145 | yield dict(batch) 146 | 147 | 148 | @dataclass 149 | class DataCollator: 150 | tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] 151 | input1_maxlen: int = 128 152 | input2_maxlen: int = 128 153 | 154 | def __call__(self, batch): 155 | # Currently only static padding; TODO: change below for adding dynamic padding support 156 | model_input1 = self.tokenizer(batch["docstring"], return_tensors="jax", max_length=self.input1_maxlen, truncation=True, padding="max_length") 157 | model_input2 = self.tokenizer(batch["code"], return_tensors="jax", max_length=self.input2_maxlen, truncation=True, padding="max_length") 158 | model_input1, model_input2 = dict(model_input1), dict(model_input2) 159 | return shard(model_input1), shard(model_input2) 160 | 161 | 162 | def save_checkpoint(save_dir, state, save_fn=None, training_args=None): 163 | print(f"saving checkpoint in {save_dir}", end=" ... ") 164 | 165 | os.makedirs(save_dir, exist_ok=True) 166 | state = jax_utils.unreplicate(state) 167 | 168 | if save_fn is not None: 169 | # saving model in HF fashion 170 | save_fn(save_dir, params=state.params) 171 | else: 172 | path = os.path.join(save_dir, "flax_model.msgpack") 173 | with open(path, "wb") as f: 174 | f.write(to_bytes(state.params)) 175 | 176 | # this will save optimizer states 177 | path = os.path.join(save_dir, "opt_state.msgpack") 178 | with open(path, "wb") as f: 179 | f.write(to_bytes(state.opt_state)) 180 | 181 | if training_args is not None: 182 | path = os.path.join(save_dir, "training_args.json") 183 | with open(path, "w") as f: 184 | json.dump(asdict(training_args), f) 185 | 186 | print("done!!") 187 | 188 | 189 | def prepare_dataset(args): 190 | tr_dataset = load_dataset("json", data_files=args.tr_data_files, split="train") 191 | val_dataset = load_dataset("json", data_files=args.val_data_files, split="train") 192 | 193 | # ensures similar processing to all splits at once 194 | dataset = DatasetDict(train=tr_dataset, validation=val_dataset) 195 | 196 | columns_to_remove = ['repo', 'path', 'func_name', 'original_string', 'sha', 'url', 'partition'] 197 | dataset = dataset.remove_columns(columns_to_remove) 198 | 199 | # drop extra batch from the end 200 | for split in dataset: 201 | num_samples = len(dataset[split]) - len(dataset[split]) % args.batch_size 202 | dataset[split] = dataset[split].shuffle(seed=args.seed).select(range(num_samples)) 203 | 204 | print(dataset) 205 | tr_dataset, val_dataset = dataset["train"], dataset["validation"] 206 | return tr_dataset, val_dataset 207 | 208 | 209 | def main(args, logger): 210 | os.makedirs(args.save_dir, exist_ok=True) 211 | 212 | model = FlaxAutoModel.from_pretrained(args.model_id) 213 | tokenizer = AutoTokenizer.from_pretrained(args.model_id) 214 | 215 | data_collator = DataCollator( 216 | tokenizer=tokenizer, 217 | input1_maxlen=args.input1_maxlen, 218 | input2_maxlen=args.input2_maxlen, 219 | ) 220 | 221 | tr_dataset, val_dataset = prepare_dataset(args) 222 | 223 | tx_args = { 224 | "lr": args.lr, 225 | "init_lr": args.init_lr, 226 | "warmup_steps": args.warmup_steps, 227 | "num_train_steps": (len(tr_dataset) // args.batch_size) * args.max_epochs, 228 | 
"weight_decay": args.weight_decay, 229 | } 230 | tx, lr = build_tx(**tx_args) 231 | 232 | state = TrainState.create( 233 | apply_fn=model.__call__, 234 | params=model.params, 235 | tx=tx, 236 | loss_fn=multiple_negatives_ranking_loss, 237 | scheduler_fn=lr, 238 | ) 239 | state = jax_utils.replicate(state) 240 | 241 | rng = jax.random.PRNGKey(args.seed) 242 | drp_rng = jax.random.split(rng, jax.device_count()) 243 | for epoch in range(args.max_epochs): 244 | # training step 245 | total = len(tr_dataset) // args.batch_size 246 | batch_iterator = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch) 247 | for i, batch in tqdm(enumerate(batch_iterator), desc=f"Running epoch-{epoch}", total=total): 248 | model_input1, model_input2 = data_collator(batch) 249 | state, metrics, drp_rng = train_step(state, model_input1, model_input2, drp_rng) 250 | 251 | if (i + 1) % args.logging_steps == 0: 252 | tr_loss = jax_utils.unreplicate(metrics["tr_loss"]).item() 253 | tqdm.write(str(dict(tr_loss=tr_loss, step=i+1))) 254 | logger.log({ 255 | "tr_loss": tr_loss, 256 | "step": i + 1, 257 | }, commit=True) 258 | 259 | # evaluation 260 | val_loss = jnp.array(0.) 261 | total = len(val_dataset) // args.batch_size 262 | val_batch_iterator = get_batched_dataset(val_dataset, args.batch_size, seed=None) 263 | for j, batch in tqdm(enumerate(val_batch_iterator), desc=f"evaluating after epoch-{epoch}", total=total): 264 | model_input1, model_input2 = data_collator(batch) 265 | val_step_loss = val_step(state, model_input1, model_input2) 266 | val_loss += jax_utils.unreplicate(val_step_loss) 267 | 268 | val_loss = val_loss.item() / (j + 1) 269 | print(f"val_loss: {val_loss}") 270 | logger.log({"val_loss": val_loss}, commit=True) 271 | 272 | save_dir = args.save_dir + f"-epoch-{epoch}" 273 | save_checkpoint(save_dir, state, save_fn=model.save_pretrained, training_args=args) 274 | 275 | 276 | if __name__ == '__main__': 277 | 278 | args = TrainingArgs() 279 | logger = wandb.init(project="code-search-net", config=asdict(args)) 280 | logging_dict = dict(logger.config); logging_dict["save_dir"] += f"-{logger.id}" 281 | args = replace(args, **logging_dict) 282 | 283 | print(args) 284 | main(args, logger) 285 | -------------------------------------------------------------------------------- /dataset_list/stackexchange/download_archive_file_list.tsv: -------------------------------------------------------------------------------- 1 | 3dprinting.meta.stackexchange.com.7z 07-Jun-2021 11:10 651.9K 2 | 3dprinting.stackexchange.com.7z 07-Jun-2021 11:10 13.2M 3 | academia.meta.stackexchange.com.7z 07-Jun-2021 11:10 4.6M 4 | academia.stackexchange.com.7z 07-Jun-2021 11:10 134.0M 5 | ai.meta.stackexchange.com.7z 07-Jun-2021 11:10 834.6K 6 | ai.stackexchange.com.7z 07-Jun-2021 11:10 23.9M 7 | android.meta.stackexchange.com.7z 07-Jun-2021 11:11 2.9M 8 | android.stackexchange.com.7z 07-Jun-2021 11:11 105.0M 9 | anime.meta.stackexchange.com.7z 07-Jun-2021 11:11 3.9M 10 | anime.stackexchange.com.7z 07-Jun-2021 11:11 31.0M 11 | apple.meta.stackexchange.com.7z 07-Jun-2021 11:11 4.3M 12 | apple.stackexchange.com.7z 07-Jun-2021 11:11 236.0M 13 | arduino.meta.stackexchange.com.7z 07-Jun-2021 11:11 929.8K 14 | arduino.stackexchange.com.7z 07-Jun-2021 11:11 65.9M 15 | askubuntu.com.7z 07-Jun-2021 11:12 861.1M 16 | astronomy.meta.stackexchange.com.7z 07-Jun-2021 11:12 911.0K 17 | astronomy.stackexchange.com.7z 07-Jun-2021 11:12 33.1M 18 | aviation.meta.stackexchange.com.7z 07-Jun-2021 11:12 2.2M 19 | aviation.stackexchange.com.7z 07-Jun-2021 
11:12 78.2M 20 | avp.meta.stackexchange.com.7z 07-Jun-2021 11:12 546.9K 21 | avp.stackexchange.com.7z 07-Jun-2021 11:12 16.7M 22 | beer.meta.stackexchange.com.7z 07-Jun-2021 11:12 249.0K 23 | beer.stackexchange.com.7z 07-Jun-2021 11:12 3.5M 24 | bicycles.meta.stackexchange.com.7z 07-Jun-2021 11:12 1.4M 25 | bicycles.stackexchange.com.7z 07-Jun-2021 11:12 55.2M 26 | bioinformatics.meta.stackexchange.com.7z 07-Jun-2021 11:12 322.0K 27 | bioinformatics.stackexchange.com.7z 07-Jun-2021 11:12 10.6M 28 | biology.meta.stackexchange.com.7z 07-Jun-2021 11:12 2.8M 29 | biology.stackexchange.com.7z 07-Jun-2021 11:12 66.8M 30 | bitcoin.meta.stackexchange.com.7z 07-Jun-2021 11:12 1.2M 31 | bitcoin.stackexchange.com.7z 07-Jun-2021 11:13 54.5M 32 | blender.meta.stackexchange.com.7z 07-Jun-2021 11:13 2.5M 33 | blender.stackexchange.com.7z 07-Jun-2021 11:13 125.0M 34 | boardgames.meta.stackexchange.com.7z 07-Jun-2021 11:13 2.2M 35 | boardgames.stackexchange.com.7z 07-Jun-2021 11:13 34.7M 36 | bricks.meta.stackexchange.com.7z 07-Jun-2021 11:13 583.6K 37 | bricks.stackexchange.com.7z 07-Jun-2021 11:13 8.6M 38 | buddhism.meta.stackexchange.com.7z 07-Jun-2021 11:13 2.0M 39 | buddhism.stackexchange.com.7z 07-Jun-2021 11:13 42.2M 40 | cardano.meta.stackexchange.com.7z 07-Jun-2021 11:13 60.1K 41 | cardano.stackexchange.com.7z 07-Jun-2021 11:13 556.6K 42 | chemistry.meta.stackexchange.com.7z 07-Jun-2021 11:13 4.8M 43 | chemistry.stackexchange.com.7z 07-Jun-2021 11:13 95.2M 44 | chess.meta.stackexchange.com.7z 07-Jun-2021 11:13 902.0K 45 | chess.stackexchange.com.7z 07-Jun-2021 11:13 24.7M 46 | chinese.meta.stackexchange.com.7z 07-Jun-2021 11:13 1.2M 47 | chinese.stackexchange.com.7z 07-Jun-2021 11:13 26.4M 48 | christianity.meta.stackexchange.com.7z 07-Jun-2021 11:13 6.5M 49 | christianity.stackexchange.com.7z 07-Jun-2021 11:14 87.3M 50 | civicrm.meta.stackexchange.com.7z 07-Jun-2021 11:14 161.4K 51 | civicrm.stackexchange.com.7z 07-Jun-2021 11:14 18.3M 52 | codegolf.meta.stackexchange.com.7z 07-Jun-2021 11:14 20.8M 53 | codegolf.stackexchange.com.7z 07-Jun-2021 11:14 321.2M 54 | codereview.meta.stackexchange.com.7z 07-Jun-2021 11:14 9.7M 55 | codereview.stackexchange.com.7z 07-Jun-2021 11:14 458.7M 56 | coffee.meta.stackexchange.com.7z 07-Jun-2021 11:14 246.6K 57 | coffee.stackexchange.com.7z 07-Jun-2021 11:14 3.9M 58 | cogsci.meta.stackexchange.com.7z 07-Jun-2021 11:14 2.0M 59 | cogsci.stackexchange.com.7z 07-Jun-2021 11:14 22.9M 60 | computergraphics.meta.stackexchange.com.7z 07-Jun-2021 11:14 360.9K 61 | computergraphics.stackexchange.com.7z 07-Jun-2021 11:14 9.0M 62 | conlang.meta.stackexchange.com.7z 07-Jun-2021 11:15 173.5K 63 | conlang.stackexchange.com.7z 07-Jun-2021 11:15 1.5M 64 | cooking.meta.stackexchange.com.7z 07-Jun-2021 11:15 3.1M 65 | cooking.stackexchange.com.7z 07-Jun-2021 11:15 65.4M 66 | craftcms.meta.stackexchange.com.7z 07-Jun-2021 11:15 197.1K 67 | craftcms.stackexchange.com.7z 07-Jun-2021 11:15 19.3M 68 | crafts.meta.stackexchange.com.7z 07-Jun-2021 11:15 538.8K 69 | crafts.stackexchange.com.7z 07-Jun-2021 11:15 5.6M 70 | crypto.meta.stackexchange.com.7z 07-Jun-2021 11:15 2.0M 71 | crypto.stackexchange.com.7z 07-Jun-2021 11:23 76.8M 72 | cs.meta.stackexchange.com.7z 07-Jun-2021 11:23 2.8M 73 | cs.stackexchange.com.7z 07-Jun-2021 11:23 97.6M 74 | cseducators.meta.stackexchange.com.7z 07-Jun-2021 11:23 631.1K 75 | cseducators.stackexchange.com.7z 07-Jun-2021 11:23 7.7M 76 | cstheory.meta.stackexchange.com.7z 07-Jun-2021 11:23 2.2M 77 | cstheory.stackexchange.com.7z 07-Jun-2021 11:23 35.7M 
78 | datascience.meta.stackexchange.com.7z 07-Jun-2021 11:23 741.9K 79 | datascience.stackexchange.com.7z 07-Jun-2021 11:23 61.7M 80 | dba.meta.stackexchange.com.7z 07-Jun-2021 11:23 3.2M 81 | dba.stackexchange.com.7z 07-Jun-2021 11:24 255.9M 82 | devops.meta.stackexchange.com.7z 07-Jun-2021 11:24 393.9K 83 | devops.stackexchange.com.7z 07-Jun-2021 11:24 11.2M 84 | diy.meta.stackexchange.com.7z 07-Jun-2021 11:24 1.5M 85 | diy.stackexchange.com.7z 07-Jun-2021 11:24 139.0M 86 | drones.meta.stackexchange.com.7z 07-Jun-2021 11:24 153.2K 87 | drones.stackexchange.com.7z 07-Jun-2021 11:24 1.6M 88 | drupal.meta.stackexchange.com.7z 07-Jun-2021 11:24 2.5M 89 | drupal.stackexchange.com.7z 07-Jun-2021 11:24 144.7M 90 | dsp.meta.stackexchange.com.7z 07-Jun-2021 11:24 805.0K 91 | dsp.stackexchange.com.7z 07-Jun-2021 11:24 60.5M 92 | earthscience.meta.stackexchange.com.7z 07-Jun-2021 11:24 1.0M 93 | earthscience.stackexchange.com.7z 07-Jun-2021 11:25 16.7M 94 | ebooks.meta.stackexchange.com.7z 07-Jun-2021 11:25 290.7K 95 | ebooks.stackexchange.com.7z 07-Jun-2021 11:25 3.9M 96 | economics.meta.stackexchange.com.7z 07-Jun-2021 11:25 1.2M 97 | economics.stackexchange.com.7z 07-Jun-2021 11:25 29.3M 98 | electronics.meta.stackexchange.com.7z 07-Jun-2021 11:25 6.1M 99 | electronics.stackexchange.com.7z 07-Jun-2021 11:25 437.0M 100 | elementaryos.meta.stackexchange.com.7z 07-Jun-2021 11:25 256.2K 101 | elementaryos.stackexchange.com.7z 07-Jun-2021 11:25 12.4M 102 | ell.meta.stackexchange.com.7z 07-Jun-2021 11:25 5.1M 103 | ell.stackexchange.com.7z 07-Jun-2021 11:25 162.6M 104 | emacs.meta.stackexchange.com.7z 07-Jun-2021 11:25 835.1K 105 | emacs.stackexchange.com.7z 07-Jun-2021 11:26 41.7M 106 | engineering.meta.stackexchange.com.7z 07-Jun-2021 11:26 1,020.3K 107 | engineering.stackexchange.com.7z 07-Jun-2021 11:26 27.6M 108 | english.meta.stackexchange.com.7z 07-Jun-2021 11:26 16.6M 109 | english.stackexchange.com.7z 07-Jun-2021 11:26 352.9M 110 | eosio.meta.stackexchange.com.7z 07-Jun-2021 11:26 75.4K 111 | eosio.stackexchange.com.7z 07-Jun-2021 11:26 3.6M 112 | es.meta.stackoverflow.com.7z 07-Jun-2021 11:26 7.8M 113 | esperanto.meta.stackexchange.com.7z 07-Jun-2021 11:26 187.8K 114 | esperanto.stackexchange.com.7z 07-Jun-2021 11:27 3.6M 115 | ethereum.meta.stackexchange.com.7z 07-Jun-2021 11:27 878.9K 116 | ethereum.stackexchange.com.7z 07-Jun-2021 11:27 61.7M 117 | expatriates.meta.stackexchange.com.7z 07-Jun-2021 11:27 458.5K 118 | expatriates.stackexchange.com.7z 07-Jun-2021 11:27 13.1M 119 | expressionengine.meta.stackexchange.com.7z 07-Jun-2021 11:27 351.3K 120 | expressionengine.stackexchange.com.7z 07-Jun-2021 11:27 17.8M 121 | fitness.meta.stackexchange.com.7z 07-Jun-2021 11:27 869.6K 122 | fitness.stackexchange.com.7z 07-Jun-2021 11:27 26.9M 123 | freelancing.meta.stackexchange.com.7z 07-Jun-2021 11:27 376.0K 124 | freelancing.stackexchange.com.7z 07-Jun-2021 11:27 7.3M 125 | french.meta.stackexchange.com.7z 07-Jun-2021 11:27 1.2M 126 | french.stackexchange.com.7z 07-Jun-2021 11:27 34.6M 127 | gamedev.meta.stackexchange.com.7z 07-Jun-2021 11:27 3.3M 128 | gamedev.stackexchange.com.7z 07-Jun-2021 11:28 147.3M 129 | gaming.meta.stackexchange.com.7z 07-Jun-2021 11:28 14.2M 130 | gaming.stackexchange.com.7z 07-Jun-2021 11:28 194.5M 131 | gardening.meta.stackexchange.com.7z 07-Jun-2021 11:28 944.9K 132 | gardening.stackexchange.com.7z 07-Jun-2021 11:28 32.9M 133 | genealogy.meta.stackexchange.com.7z 07-Jun-2021 11:28 1.9M 134 | genealogy.stackexchange.com.7z 07-Jun-2021 11:28 13.3M 135 | 
german.meta.stackexchange.com.7z 07-Jun-2021 11:28 2.2M
136 | german.stackexchange.com.7z 07-Jun-2021 11:28 49.6M
137 | gis.meta.stackexchange.com.7z 07-Jun-2021 11:28 4.0M
138 | gis.stackexchange.com.7z 07-Jun-2021 11:29 296.6M
139 | graphicdesign.meta.stackexchange.com.7z 07-Jun-2021 11:29 3.2M
140 | graphicdesign.stackexchange.com.7z 07-Jun-2021 11:29 79.4M
141 | ham.meta.stackexchange.com.7z 07-Jun-2021 11:29 497.0K
142 | ham.stackexchange.com.7z 07-Jun-2021 11:29 12.9M
143 | hardwarerecs.meta.stackexchange.com.7z 07-Jun-2021 11:29 946.7K
144 | hardwarerecs.stackexchange.com.7z 07-Jun-2021 11:29 8.2M
145 | health.meta.stackexchange.com.7z 07-Jun-2021 11:29 1.7M
146 | health.stackexchange.com.7z 07-Jun-2021 11:29 18.0M
147 | hermeneutics.meta.stackexchange.com.7z 07-Jun-2021 11:29 2.8M
148 | hermeneutics.stackexchange.com.7z 07-Jun-2021 11:29 73.3M
149 | hinduism.meta.stackexchange.com.7z 07-Jun-2021 11:30 3.1M
150 | hinduism.stackexchange.com.7z 07-Jun-2021 11:30 54.8M
151 | history.meta.stackexchange.com.7z 07-Jun-2021 11:30 2.9M
152 | history.stackexchange.com.7z 07-Jun-2021 11:30 76.4M
153 | homebrew.meta.stackexchange.com.7z 07-Jun-2021 11:30 362.5K
154 | homebrew.stackexchange.com.7z 07-Jun-2021 11:30 13.1M
155 | hsm.meta.stackexchange.com.7z 07-Jun-2021 11:30 449.5K
156 | hsm.stackexchange.com.7z 07-Jun-2021 11:30 12.6M
157 | interpersonal.meta.stackexchange.com.7z 07-Jun-2021 11:30 4.6M
158 | interpersonal.stackexchange.com.7z 07-Jun-2021 11:30 30.2M
159 | iot.meta.stackexchange.com.7z 07-Jun-2021 11:30 503.6K
160 | iot.stackexchange.com.7z 07-Jun-2021 11:30 5.7M
161 | iota.meta.stackexchange.com.7z 07-Jun-2021 11:30 91.9K
162 | iota.stackexchange.com.7z 07-Jun-2021 11:30 1.9M
163 | islam.meta.stackexchange.com.7z 07-Jun-2021 11:31 3.4M
164 | islam.stackexchange.com.7z 07-Jun-2021 11:31 40.9M
165 | italian.meta.stackexchange.com.7z 07-Jun-2021 11:31 600.8K
166 | italian.stackexchange.com.7z 07-Jun-2021 11:31 9.2M
167 | ja.meta.stackoverflow.com.7z 07-Jun-2021 11:31 3.1M
168 | ja.stackoverflow.com.7z 07-Jun-2021 11:31 60.1M
169 | japanese.meta.stackexchange.com.7z 07-Jun-2021 11:31 2.7M
170 | japanese.stackexchange.com.7z 07-Jun-2021 11:31 52.0M
171 | joomla.meta.stackexchange.com.7z 07-Jun-2021 11:31 380.8K
172 | joomla.stackexchange.com.7z 07-Jun-2021 11:31 13.7M
173 | judaism.meta.stackexchange.com.7z 07-Jun-2021 11:31 5.4M
174 | judaism.stackexchange.com.7z 07-Jun-2021 11:32 104.1M
175 | korean.meta.stackexchange.com.7z 07-Jun-2021 11:32 243.2K
176 | korean.stackexchange.com.7z 07-Jun-2021 11:32 3.7M
177 | languagelearning.meta.stackexchange.com.7z 07-Jun-2021 11:32 551.9K
178 | languagelearning.stackexchange.com.7z 07-Jun-2021 11:32 3.9M
179 | latin.meta.stackexchange.com.7z 07-Jun-2021 11:32 691.1K
180 | latin.stackexchange.com.7z 07-Jun-2021 11:32 14.4M
181 | law.meta.stackexchange.com.7z 07-Jun-2021 11:32 1.4M
182 | law.stackexchange.com.7z 07-Jun-2021 11:33 56.3M
183 | license.txt 23-Jan-2014 17:05 1.9K
184 | lifehacks.meta.stackexchange.com.7z 07-Jun-2021 11:33 1.0M
185 | lifehacks.stackexchange.com.7z 07-Jun-2021 11:33 11.0M
186 | linguistics.meta.stackexchange.com.7z 07-Jun-2021 11:33 1.1M
187 | linguistics.stackexchange.com.7z 07-Jun-2021 11:33 29.2M
188 | literature.meta.stackexchange.com.7z 07-Jun-2021 11:33 2.2M
189 | literature.stackexchange.com.7z 07-Jun-2021 11:33 17.8M
190 | magento.meta.stackexchange.com.7z 07-Jun-2021 13:09 1.6M
191 | magento.stackexchange.com.7z 07-Jun-2021 13:10 182.7M
192 | martialarts.meta.stackexchange.com.7z 07-Jun-2021 13:10 798.1K
193 | martialarts.stackexchange.com.7z 07-Jun-2021 13:10 10.2M
194 | materials.meta.stackexchange.com.7z 07-Jun-2021 13:10 362.3K
195 | materials.stackexchange.com.7z 07-Jun-2021 13:10 4.6M
196 | math.meta.stackexchange.com.7z 07-Jun-2021 13:10 41.7M
197 | math.stackexchange.com.7z 07-Jun-2021 14:18 2.6G
198 | matheducators.meta.stackexchange.com.7z 07-Jun-2021 13:10 866.7K
199 | matheducators.stackexchange.com.7z 07-Jun-2021 13:10 17.7M
200 | mathematica.meta.stackexchange.com.7z 07-Jun-2021 13:10 3.4M
201 | mathematica.stackexchange.com.7z 07-Jun-2021 13:10 230.7M
202 | mathoverflow.net.7z 07-Jun-2021 13:11 369.6M
203 | mechanics.meta.stackexchange.com.7z 07-Jun-2021 13:11 1.2M
204 | mechanics.stackexchange.com.7z 07-Jun-2021 13:11 48.4M
205 | meta.askubuntu.com.7z 07-Jun-2021 13:11 18.1M
206 | meta.mathoverflow.net.7z 07-Jun-2021 13:11 7.1M
207 | meta.serverfault.com.7z 07-Jun-2021 13:11 9.6M
208 | meta.stackexchange.com.7z 07-Jun-2021 13:12 345.6M
209 | meta.stackoverflow.com.7z 07-Jun-2021 13:12 266.2M
210 | meta.superuser.com.7z 07-Jun-2021 13:12 17.6M
211 | moderators.meta.stackexchange.com.7z 07-Jun-2021 13:12 483.8K
212 | moderators.stackexchange.com.7z 07-Jun-2021 13:12 2.8M
213 | monero.meta.stackexchange.com.7z 07-Jun-2021 13:12 218.9K
214 | monero.stackexchange.com.7z 07-Jun-2021 13:12 7.3M
215 | money.meta.stackexchange.com.7z 07-Jun-2021 13:12 2.2M
216 | money.stackexchange.com.7z 07-Jun-2021 13:12 97.0M
217 | movies.meta.stackexchange.com.7z 07-Jun-2021 13:12 4.7M
218 | movies.stackexchange.com.7z 07-Jun-2021 13:12 67.8M
219 | music.meta.stackexchange.com.7z 07-Jun-2021 13:12 2.7M
220 | music.stackexchange.com.7z 07-Jun-2021 13:13 75.6M
221 | musicfans.meta.stackexchange.com.7z 07-Jun-2021 13:13 624.2K
222 | musicfans.stackexchange.com.7z 07-Jun-2021 13:13 6.2M
223 | mythology.meta.stackexchange.com.7z 07-Jun-2021 13:13 557.7K
224 | mythology.stackexchange.com.7z 07-Jun-2021 13:13 7.5M
225 | networkengineering.meta.stackexchange.com.7z 07-Jun-2021 13:13 1.2M
226 | networkengineering.stackexchange.com.7z 07-Jun-2021 13:13 38.6M
227 | opendata.meta.stackexchange.com.7z 07-Jun-2021 13:13 473.4K
228 | opendata.stackexchange.com.7z 07-Jun-2021 13:13 10.6M
229 | opensource.meta.stackexchange.com.7z 07-Jun-2021 13:13 889.1K
230 | opensource.stackexchange.com.7z 07-Jun-2021 13:13 12.0M
231 | or.meta.stackexchange.com.7z 07-Jun-2021 13:13 342.2K
232 | or.stackexchange.com.7z 07-Jun-2021 13:13 5.8M
233 | outdoors.meta.stackexchange.com.7z 07-Jun-2021 13:13 1.5M
234 | outdoors.stackexchange.com.7z 07-Jun-2021 13:13 24.9M
235 | parenting.meta.stackexchange.com.7z 07-Jun-2021 13:13 2.0M
236 | parenting.stackexchange.com.7z 07-Jun-2021 13:14 35.5M
237 | patents.meta.stackexchange.com.7z 07-Jun-2021 13:14 447.2K
238 | patents.stackexchange.com.7z 07-Jun-2021 13:14 10.2M
239 | pets.meta.stackexchange.com.7z 07-Jun-2021 13:14 1.2M
240 | pets.stackexchange.com.7z 07-Jun-2021 13:14 20.9M
241 | philosophy.meta.stackexchange.com.7z 07-Jun-2021 13:14 2.8M
242 | philosophy.stackexchange.com.7z 07-Jun-2021 13:14 79.9M
243 | photo.meta.stackexchange.com.7z 07-Jun-2021 13:14 4.1M
244 | photo.stackexchange.com.7z 07-Jun-2021 13:14 91.4M
245 | physics.meta.stackexchange.com.7z 07-Jun-2021 13:14 13.6M
246 | physics.stackexchange.com.7z 07-Jun-2021 13:15 510.4M
247 | pm.meta.stackexchange.com.7z 07-Jun-2021 13:16 1.1M
248 | pm.stackexchange.com.7z 07-Jun-2021 13:16 23.2M
249 | poker.meta.stackexchange.com.7z 07-Jun-2021 13:16 311.7K
250 | poker.stackexchange.com.7z 07-Jun-2021 13:16 5.8M
251 | politics.meta.stackexchange.com.7z 07-Jun-2021 13:16 3.6M
252 | politics.stackexchange.com.7z 07-Jun-2021 13:16 72.7M
253 | portuguese.meta.stackexchange.com.7z 07-Jun-2021 13:16 538.9K
254 | portuguese.stackexchange.com.7z 07-Jun-2021 13:16 8.6M
255 | pt.meta.stackoverflow.com.7z 07-Jun-2021 13:16 12.3M
256 | pt.stackoverflow.com.7z 07-Jun-2021 13:17 412.8M
257 | puzzling.meta.stackexchange.com.7z 07-Jun-2021 13:17 5.7M
258 | puzzling.stackexchange.com.7z 07-Jun-2021 13:17 92.9M
259 | quant.meta.stackexchange.com.7z 07-Jun-2021 13:17 687.4K
260 | quant.stackexchange.com.7z 07-Jun-2021 13:17 40.3M
261 | quantumcomputing.meta.stackexchange.com.7z 07-Jun-2021 13:17 727.7K
262 | quantumcomputing.stackexchange.com.7z 07-Jun-2021 13:17 14.0M
263 | raspberrypi.meta.stackexchange.com.7z 07-Jun-2021 13:17 1.7M
264 | raspberrypi.stackexchange.com.7z 07-Jun-2021 13:17 83.7M
265 | readme.txt 23-Jan-2014 00:46 4.6K
266 | retrocomputing.meta.stackexchange.com.7z 07-Jun-2021 13:17 1.0M
267 | retrocomputing.stackexchange.com.7z 07-Jun-2021 13:17 25.4M
268 | reverseengineering.meta.stackexchange.com.7z 07-Jun-2021 13:17 503.4K
269 | reverseengineering.stackexchange.com.7z 07-Jun-2021 13:17 22.5M
270 | robotics.meta.stackexchange.com.7z 07-Jun-2021 13:17 551.0K
271 | robotics.stackexchange.com.7z 07-Jun-2021 13:17 15.4M
272 | rpg.meta.stackexchange.com.7z 07-Jun-2021 13:18 12.6M
273 | rpg.stackexchange.com.7z 07-Jun-2021 13:18 217.0M
274 | ru.meta.stackoverflow.com.7z 07-Jun-2021 13:18 15.2M
275 | ru.stackoverflow.com.7z 07-Jun-2021 13:19 749.4M
276 | rus.meta.stackexchange.com.7z 07-Jun-2021 13:19 553.3K
277 | rus.stackexchange.com.7z 07-Jun-2021 13:19 39.1M
278 | russian.meta.stackexchange.com.7z 07-Jun-2021 13:19 480.5K
279 | russian.stackexchange.com.7z 07-Jun-2021 13:19 13.7M
280 | salesforce.meta.stackexchange.com.7z 07-Jun-2021 13:19 2.1M
281 | salesforce.stackexchange.com.7z 07-Jun-2021 13:20 208.2M
282 | scicomp.meta.stackexchange.com.7z 07-Jun-2021 13:20 680.9K
283 | scicomp.stackexchange.com.7z 07-Jun-2021 13:20 28.4M
284 | scifi.meta.stackexchange.com.7z 07-Jun-2021 13:20 13.9M
285 | scifi.stackexchange.com.7z 07-Jun-2021 13:20 260.3M
286 | security.meta.stackexchange.com.7z 07-Jun-2021 13:20 4.4M
287 | security.stackexchange.com.7z 07-Jun-2021 13:21 204.3M
288 | serverfault.com.7z 07-Jun-2021 13:22 685.6M
289 | sharepoint.meta.stackexchange.com.7z 07-Jun-2021 13:22 1.4M
290 | sharepoint.stackexchange.com.7z 07-Jun-2021 13:22 148.4M
291 | sitecore.meta.stackexchange.com.7z 07-Jun-2021 13:22 438.2K
292 | sitecore.stackexchange.com.7z 07-Jun-2021 13:22 20.6M
293 | skeptics.meta.stackexchange.com.7z 07-Jun-2021 13:22 6.0M
294 | skeptics.stackexchange.com.7z 07-Jun-2021 13:22 63.7M
295 | softwareengineering.meta.stackexchange.com.7z 07-Jun-2021 13:22 13.9M
296 | softwareengineering.stackexchange.com.7z 07-Jun-2021 13:22 288.8M
297 | softwarerecs.meta.stackexchange.com.7z 07-Jun-2021 13:22 2.6M
298 | softwarerecs.stackexchange.com.7z 07-Jun-2021 13:22 41.0M
299 | sound.meta.stackexchange.com.7z 07-Jun-2021 13:22 463.7K
300 | sound.stackexchange.com.7z 07-Jun-2021 13:23 22.9M
301 | space.meta.stackexchange.com.7z 07-Jun-2021 13:23 2.3M
302 | space.stackexchange.com.7z 07-Jun-2021 13:23 58.0M
303 | spanish.meta.stackexchange.com.7z 07-Jun-2021 13:23 2.0M
304 | spanish.stackexchange.com.7z 07-Jun-2021 13:23 28.3M
305 | sports.meta.stackexchange.com.7z 07-Jun-2021 13:23 1.1M
306 | sports.stackexchange.com.7z 07-Jun-2021 13:23 13.7M
307 | sqa.meta.stackexchange.com.7z 07-Jun-2021 13:23 560.2K
308 | sqa.stackexchange.com.7z 07-Jun-2021 13:23 31.4M
309 | stackapps.com.7z 07-Jun-2021 13:23 12.9M
310 | stackexchange_archive.torrent 07-Jun-2021 23:34 458.7K
311 | stackexchange_files.xml 07-Jun-2021 23:35 103.6K
312 | stackexchange_meta.sqlite 07-Jun-2021 18:53 1.1M
313 | stackexchange_meta.xml 07-Jun-2021 23:34 3.4K
314 | stackexchange_reviews.xml 07-Jun-2021 23:35 17.0K
315 | stats.meta.stackexchange.com.7z 07-Jun-2021 13:25 8.0M
316 | stats.stackexchange.com.7z 07-Jun-2021 13:25 442.1M
317 | stellar.meta.stackexchange.com.7z 07-Jun-2021 13:25 90.0K
318 | stellar.stackexchange.com.7z 07-Jun-2021 13:25 2.4M
319 | superuser.com.7z 07-Jun-2021 13:27 994.8M
320 | sustainability.meta.stackexchange.com.7z 07-Jun-2021 13:27 558.2K
321 | sustainability.stackexchange.com.7z 07-Jun-2021 13:27 7.5M
322 | tex.meta.stackexchange.com.7z 07-Jun-2021 13:27 10.9M
323 | tex.stackexchange.com.7z 07-Jun-2021 13:28 620.5M
324 | tezos.meta.stackexchange.com.7z 07-Jun-2021 13:28 74.3K
325 | tezos.stackexchange.com.7z 07-Jun-2021 13:28 2.3M
326 | tor.meta.stackexchange.com.7z 07-Jun-2021 13:28 312.4K
327 | tor.stackexchange.com.7z 07-Jun-2021 13:28 10.3M
328 | travel.meta.stackexchange.com.7z 07-Jun-2021 13:28 5.4M
329 | travel.stackexchange.com.7z 07-Jun-2021 13:28 118.2M
330 | tridion.meta.stackexchange.com.7z 07-Jun-2021 13:28 442.6K
331 | tridion.stackexchange.com.7z 07-Jun-2021 13:28 14.8M
332 | ukrainian.meta.stackexchange.com.7z 07-Jun-2021 13:28 569.8K
333 | ukrainian.stackexchange.com.7z 07-Jun-2021 13:28 6.9M
334 | unix.meta.stackexchange.com.7z 07-Jun-2021 13:28 6.8M
335 | unix.stackexchange.com.7z 07-Jun-2021 13:28 539.2M
336 | ux.meta.stackexchange.com.7z 07-Jun-2021 13:29 2.6M
337 | ux.stackexchange.com.7z 07-Jun-2021 13:29 98.2M
338 | vegetarianism.meta.stackexchange.com.7z 07-Jun-2021 13:29 385.8K
339 | vegetarianism.stackexchange.com.7z 07-Jun-2021 13:29 2.5M
340 | vi.meta.stackexchange.com.7z 07-Jun-2021 13:29 714.5K
341 | vi.stackexchange.com.7z 07-Jun-2021 13:29 24.8M
342 | webapps.meta.stackexchange.com.7z 07-Jun-2021 13:29 2.5M
343 | webapps.stackexchange.com.7z 07-Jun-2021 13:29 62.6M
344 | webmasters.meta.stackexchange.com.7z 07-Jun-2021 13:29 1.7M
345 | webmasters.stackexchange.com.7z 07-Jun-2021 13:29 75.6M
346 | windowsphone.meta.stackexchange.com.7z 07-Jun-2021 13:29 345.8K
347 | windowsphone.stackexchange.com.7z 07-Jun-2021 13:29 5.4M
348 | woodworking.meta.stackexchange.com.7z 07-Jun-2021 13:29 352.3K
349 | woodworking.stackexchange.com.7z 07-Jun-2021 13:29 10.4M
350 | wordpress.meta.stackexchange.com.7z 07-Jun-2021 13:29 3.2M
351 | wordpress.stackexchange.com.7z 07-Jun-2021 13:29 209.3M
352 | workplace.meta.stackexchange.com.7z 07-Jun-2021 13:30 8.0M
353 | workplace.stackexchange.com.7z 07-Jun-2021 13:30 145.5M
354 | worldbuilding.meta.stackexchange.com.7z 07-Jun-2021 13:30 9.8M
355 | worldbuilding.stackexchange.com.7z 07-Jun-2021 13:30 260.5M
356 | writers.meta.stackexchange.com.7z 07-Jun-2021 13:30 3.1M
357 | writers.stackexchange.com.7z 07-Jun-2021 13:30 51.3M
--------------------------------------------------------------------------------