├── .gitignore ├── README.md ├── data ├── .keep ├── 30k-vocab-filtered.json └── 30k-vocab-filtered.txt ├── download_openwebtext.sh ├── inference.sh ├── mini_coil ├── __init__.py ├── convert_idf.py ├── data_pipeline │ ├── __init__.py │ ├── abstract_uploader.py │ ├── augment_data.py │ ├── cluster_words.py │ ├── combine_models.py │ ├── compress_dimentions.py │ ├── convert_openwebtext.py │ ├── distance_matrix.py │ ├── download_validation.py │ ├── encode_and_filter.py │ ├── encode_data.py │ ├── encode_targets.py │ ├── load_sentences.py │ ├── pre_encoder.py │ ├── prepare_vocab.py │ ├── read_pre_encoded.py │ ├── split_sentences.py │ ├── split_train_val.py │ ├── stopwords.py │ ├── upload_compressed_to_qdrant.py │ ├── upload_to_qdrant.py │ └── vocab_resolver.py ├── filtering.py ├── model │ ├── __init__.py │ ├── cosine_loss.py │ ├── decoder.py │ ├── encoder.py │ ├── encoder_numpy.py │ ├── mini_coil_inference.py │ ├── mse_loss.py │ ├── triplet_loss.py │ └── word_encoder.py ├── read_data.py ├── settings.py ├── tokenizer.py ├── train.py ├── training │ ├── __init__.py │ ├── coil_module.py │ ├── data_loader.py │ ├── train.py │ ├── train_word.py │ ├── train_word_triplet.py │ ├── triplet_dataloader.py │ ├── triplet_word_module.py │ ├── try_checkpoint.py │ └── word_module.py ├── triplet_loss.py └── visualize_encoder.py ├── poetry.lock ├── pyproject.toml ├── run_eval.sh ├── split_sentences.sh ├── tests ├── 01_embedding_maker.py ├── 02_matrix_create.py ├── 03_matrix_triplets.py ├── 04_umap_emb.py ├── __init__.py ├── embed_minicoil.py └── visualize_embeddings.py ├── train-triplets.sh ├── train.sh └── unpack_openwebtext.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | .idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # miniCOIL 3 | 4 | MiniCOIL is a contextualized per-word embedding model. 5 | MiniCOIL generates small-size embeddings for each word in a sentence, but these embeddings can only be compared with 6 | embeddings of the same word in different sentences (contexts). 7 | This restriction allows generating extremely small embeddings (8d or even 4d) while still preserving the context of the word. 8 | 9 | ## Usage 10 | 11 | MiniCOIL embeddings might be useful in information retrieval tasks where we need to resolve the meaning of a word in the context of a sentence. 12 | For example, many words have different meanings depending on the context, such as "bank" (river bank or financial institution). 13 | 14 | MiniCOIL allows encoding the precise meaning of a word, but unlike traditional word embeddings it won't dilute exact matches with other words in the vocabulary. 15 | 16 | MiniCOIL is not trained in an end-to-end fashion, which means that it can't assign relative importance to the words in a sentence. 17 | However, it can be combined with a BM25-like scoring formula and used in search engines.
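As a rough illustration of that combination (a minimal sketch, not the exact formula used by this repository): `query_terms` and `doc_terms` below are hypothetical dicts mapping a word to `(idf, miniCOIL vector)` and `(term frequency, miniCOIL vector)` respectively, and `k`/`b` are the usual BM25 parameters.

```python
import numpy as np

def bm25_like_minicoil_score(query_terms, doc_terms, doc_len, avg_doc_len, k=1.2, b=0.75):
    """Sketch: BM25-style term weight, rescaled by per-word miniCOIL similarity."""
    score = 0.0
    for word, (idf, q_vec) in query_terms.items():
        if word not in doc_terms:
            continue  # miniCOIL vectors are only comparable for the same word
        tf, d_vec = doc_terms[word]
        # Standard BM25 term-frequency saturation
        tf_part = (tf * (k + 1)) / (tf + k * (1 - b + b * doc_len / avg_doc_len))
        # Cosine similarity of the small per-word embeddings: how well the meanings match
        sim = float(np.dot(q_vec, d_vec) /
                    (np.linalg.norm(q_vec) * np.linalg.norm(d_vec) + 1e-9))
        score += idf * tf_part * sim
    return score
```

The exact match still drives the score (a word absent from the document contributes nothing); the small per-word embedding only rescales its BM25 contribution by how close the meanings are.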
18 | 19 | ## Architecture 20 | 21 | MiniCOIL is designed to be compatible with foundational transformer models, such as SentenceTransformers. 22 | There are two main reasons for this: 23 | 24 | - We don't want to spend enormous resources on training MiniCOIL. 25 | - We want to be able to combine MiniCOIL embedding inference with dense embedding inference in a single step. 26 | 27 | Technically, MiniCOIL is a simple array of linear layers (one for each word in the vocabulary) that are trained 28 | to compress the word embeddings into a small size. That makes MiniCOIL a paper-thin layer on top of the transformer model. 29 | 30 | ### Training process 31 | 32 | MiniCOIL is trained on the principle of skip-gram models, adapted to transformer models: we want to predict the context from the word. 33 | In the case of transformer models, we predict sentence embeddings from word embeddings. 34 | 35 | Naturally, this process can be separated into two steps: encoding and decoding (similar to autoencoders), with a small-size embedding in the middle. 36 | 37 | Since we want to make MiniCOIL compatible with many transformer models, we can replace the decoder step with compressed embeddings of some larger model, 38 | so the encoder can be trained independently for each input model. 39 | 40 | The training process is as follows: 41 | 42 | 1. Download the dataset (we use OpenWebText) 43 | 1. Convert the dataset into a readable format with `mini_coil.data_pipeline.convert_openwebtext` 44 | 1. Split the data into sentences with `mini_coil.data_pipeline.split_sentences` 45 | 1. Encode sentences with a transformer model and save the embeddings to disk (about 350M embeddings for OpenWebText) with `mini_coil.data_pipeline.encode_targets` 46 | 1. Upload the encoded sentences to Qdrant, so we can sample sentences containing specified words, with `mini_coil.data_pipeline.upload_to_qdrant` 47 | 1. For triplet-based training, follow [train-triplets.sh](./train-triplets.sh). For each word it will: 48 | 1. Generate a distance matrix based on the large embeddings 49 | 2. Augment sentences 50 | 3. Encode sentences with the small model 51 | 4. Train the per-word encoder 52 | 1. Merge the per-word encoders into a single model with `mini_coil.data_pipeline.combine_models` 53 | 1. Make visualizations 54 | 1. Benchmark 55 | 1. Quantize (optional) 56 | 1.
Convert to ONNX 57 | -------------------------------------------------------------------------------- /data/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/miniCOIL/2d939737e89bacd797cb6c3e312135070a93a56d/data/.keep -------------------------------------------------------------------------------- /download_openwebtext.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Download OpenWebText 5 | 6 | mkdir -p data/openwebtext 7 | 8 | for index in {0..20} 9 | do 10 | index=$(printf "%02d" $index) 11 | wget https://huggingface.co/datasets/Skylion007/openwebtext/resolve/main/subsets/urlsf_subset${index}.tar?download=true -O data/openwebtext/urlsf_subset${index}.tar 12 | done 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Run inference for sentences 4 | 5 | ARCHIVE_DIR=data/openwebtext 6 | 7 | for file in $ARCHIVE_DIR/openwebtext-*-sentences.txt.gz 8 | do 9 | # file = data/openwebtext/openwebtext-00-sentences.txt.gz 10 | 11 | file_name=$(basename $file) 12 | # file_name = openwebtext-00-sentences.txt.gz 13 | 14 | # Filename without extension 15 | 16 | file_name_no_ext=${file_name%.txt.gz} 17 | 18 | python -m mini_coil.data_pipeline.encode_targets \ 19 | --input-file $file \ 20 | --output-file $ARCHIVE_DIR/${file_name_no_ext}-emb.npy \ 21 | --batch-size 32 \ 22 | --device-count 4 \ 23 | --use-cuda 24 | 25 | done 26 | 27 | 28 | -------------------------------------------------------------------------------- /mini_coil/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/miniCOIL/2d939737e89bacd797cb6c3e312135070a93a56d/mini_coil/__init__.py -------------------------------------------------------------------------------- /mini_coil/convert_idf.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import os 4 | import tqdm 5 | import math 6 | from snowballstemmer import stemmer as get_stemmer 7 | import pickle 8 | 9 | from mini_coil.settings import DATA_DIR 10 | 11 | 12 | class IDFVocab: 13 | def __init__(self, idf: Dict[str, int]): 14 | self.idf_vocab = idf 15 | self.num_docs = 10_000_000 16 | 17 | def get_idf(self, token: str) -> float: 18 | num_tokens = self.idf_vocab.get(token, 0) 19 | return math.log( 20 | (self.num_docs - num_tokens + 0.5) / (num_tokens + 0.5) + 1.0 21 | ) 22 | 23 | def save_vocab_pkl(self, path: str): 24 | with open(path, "wb") as f: 25 | pickle.dump(self, f) 26 | 27 | @classmethod 28 | def load_vocab_pkl(cls, path: str): 29 | with open(path, "rb") as f: 30 | return pickle.load(f) 31 | 32 | 33 | class IdfConverter: 34 | 35 | def __init__(self): 36 | self.stemmer = get_stemmer("english") 37 | self.vocab = {} 38 | 39 | def add_token(self, token: str, count: int): 40 | stemmed_token = self.stemmer.stemWord(token) 41 | 42 | if stemmed_token in self.vocab: 43 | self.vocab[stemmed_token] += count 44 | else: 45 | self.vocab[stemmed_token] = count 46 | 47 | def read_vocab(self, path: str): 48 | with open(path) as f: 49 | for line in tqdm.tqdm(f): 50 | token, count = line.split(",") 51 | self.add_token(token, int(count)) 52 | 53 | def to_idf_vocab(self) -> IDFVocab: 54 | return IDFVocab(self.vocab) 55 | 56 | 57 | def main(): 
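    # Build a stemmed IDF vocabulary from a CSV of "token,count" lines and pickle it
    # for later lookup via IDFVocab.get_idf(). The IDF follows the BM25 formula with a
    # fixed corpus size of 10,000,000 documents; for example, a stem seen in 1,000
    # documents gets idf = log((10_000_000 - 1_000 + 0.5) / (1_000 + 0.5) + 1.0) ≈ 9.21.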
58 | converter = IdfConverter() 59 | 60 | vocab_path = os.path.join(DATA_DIR, "wp_word_idfs.csv") 61 | 62 | converter.read_vocab(vocab_path) 63 | 64 | idf_vocab = converter.to_idf_vocab() 65 | 66 | idf_vocab.save_vocab_pkl(os.path.join(DATA_DIR, "idf_vocab.pkl")) 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/miniCOIL/2d939737e89bacd797cb6c3e312135070a93a56d/mini_coil/data_pipeline/__init__.py -------------------------------------------------------------------------------- /mini_coil/data_pipeline/abstract_uploader.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import hashlib 4 | import os 5 | from typing import Iterable 6 | 7 | from qdrant_client import QdrantClient, models 8 | import tqdm 9 | 10 | QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333") 11 | QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "") 12 | 13 | 14 | def read_abstracts(path: str) -> Iterable[str]: 15 | with gzip.open(path, "rt") as f: 16 | for line in f: 17 | line = line.strip() 18 | if len(line) == 0: 19 | continue 20 | yield line.strip() 21 | 22 | 23 | def compute_hash(text: str) -> str: 24 | return hashlib.sha256(text.encode()).hexdigest() 25 | 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--input-file", type=str) 30 | parser.add_argument("--collection-name", type=str, default="coil-abstracts") 31 | parser.add_argument("--recreate-collection", action="store_true") 32 | parser.add_argument("--parallel", type=int, default=1) 33 | args = parser.parse_args() 34 | 35 | collection_name = args.collection_name 36 | 37 | qdrant = QdrantClient( 38 | url=QDRANT_URL, 39 | api_key=QDRANT_API_KEY, 40 | prefer_grpc=True, 41 | ) 42 | 43 | def data_iter(): 44 | abstracts = read_abstracts(args.input_file) 45 | for abstract in abstracts: 46 | abs_hash = compute_hash(abstract) 47 | 48 | # Compute hash from the text and convert it to UUID 49 | hash_uuid = hashlib.md5(abs_hash.encode()).hexdigest() 50 | 51 | yield models.PointStruct( 52 | id=hash_uuid, 53 | vector={}, 54 | payload={"abstract": abstract, "abs_hash": abs_hash} 55 | ) 56 | 57 | collection_exists = qdrant.collection_exists(collection_name) 58 | 59 | if not collection_exists or args.recreate_collection: 60 | qdrant.delete_collection(collection_name) 61 | 62 | qdrant.create_collection( 63 | collection_name=collection_name, 64 | vectors_config={}, 65 | optimizers_config=models.OptimizersConfigDiff( 66 | max_segment_size=2_000_000, 67 | max_optimization_threads=1, # Run one optimization per shard 68 | ), 69 | shard_number=6, 70 | ) 71 | 72 | qdrant.create_payload_index( 73 | collection_name, 74 | "abs_hash", 75 | field_schema=models.KeywordIndexParams( 76 | type=models.KeywordIndexType.KEYWORD, 77 | on_disk=True, 78 | ) 79 | ) 80 | 81 | qdrant.upload_points( 82 | collection_name, 83 | points=tqdm.tqdm(data_iter()), 84 | parallel=args.parallel, 85 | ) 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/augment_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import random 6 
| import re 7 | from typing import Tuple, List 8 | 9 | from mini_coil.settings import DATA_DIR 10 | 11 | logging.basicConfig( 12 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 13 | level=logging.INFO 14 | ) 15 | logging.getLogger("data-augment").setLevel(logging.WARNING) 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def extract_window_around_target( 20 | words: List[str], 21 | target_idx: int, 22 | min_words: int = 4, 23 | max_words: int = 6 24 | ) -> Tuple[int, int]: 25 | left_window = random.randint(1, 3) 26 | right_window = random.randint(1, 3) 27 | 28 | start = max(0, target_idx - left_window) 29 | end = min(len(words), target_idx + right_window + 1) 30 | 31 | while end - start < min_words: 32 | if end < len(words): 33 | end += 1 34 | elif start > 0: 35 | start -= 1 36 | else: 37 | break 38 | 39 | if end - start > max_words: 40 | if random.choice([True, False]): 41 | end = start + max_words 42 | else: 43 | start = end - max_words 44 | 45 | return start, end 46 | 47 | 48 | def create_snippet(sentence: str, target_word_forms: List[str]) -> str: 49 | words = sentence.split() 50 | 51 | if not words: 52 | return sentence 53 | 54 | normalized_target_forms = set( 55 | re.sub(r"\W+", "", target_word.lower()) 56 | for target_word in target_word_forms 57 | ) 58 | 59 | try: 60 | target_indices = [ 61 | i for i, w in enumerate(words) 62 | if re.sub(r"\W+", "", w.lower()) in normalized_target_forms 63 | ] 64 | 65 | if not target_indices: 66 | return sentence 67 | 68 | target_idx = random.choice(target_indices) 69 | start, end = extract_window_around_target(words, target_idx) 70 | 71 | return " ".join(words[start:end]) 72 | 73 | except Exception as e: 74 | logger.exception("Error creating snippet: ", e) 75 | return sentence 76 | 77 | 78 | def process_file(input_path: str, output_path: str, target_word_forms: List[str]) -> None: 79 | try: 80 | with open(input_path, "r", encoding="utf-8") as fin, \ 81 | open(output_path, "w", encoding="utf-8") as fout: 82 | 83 | for i, line in enumerate(fin): 84 | line = line.strip() 85 | if not line: 86 | continue 87 | 88 | try: 89 | data = json.loads(line) 90 | data["line_number"] = i 91 | 92 | fout.write(json.dumps(data, ensure_ascii=False) + "\n") 93 | 94 | data_copy = dict(data) 95 | data_copy["sentence"] = create_snippet(data["sentence"], target_word_forms) 96 | 97 | if data_copy["sentence"] != data["sentence"]: 98 | fout.write(json.dumps(data_copy, ensure_ascii=False) + "\n") 99 | 100 | except json.JSONDecodeError: 101 | continue 102 | except Exception as e: 103 | continue 104 | 105 | except Exception as e: 106 | raise Exception(f"Error processing file: {str(e)}") 107 | 108 | 109 | def main(): 110 | default_vocab_path = os.path.join(DATA_DIR, "30k-vocab-filtered.json") 111 | 112 | parser = argparse.ArgumentParser(description="Create snippets around target words in sentences") 113 | parser.add_argument("--input-file", required=True, help="Path to input .jsonl file") 114 | parser.add_argument("--output-file", required=True, help="Path to output .jsonl file") 115 | parser.add_argument("--target-word", required=True, help="Target word to create snippet around") 116 | parser.add_argument("--vocab-path", type=str, default=default_vocab_path) 117 | args = parser.parse_args() 118 | 119 | vocab = json.load(open(args.vocab_path)) 120 | 121 | if args.target_word not in vocab: 122 | print(f"WARNING: word {args.target_word} not found in vocab, using as is") 123 | word_forms = [args.target_word] 124 | else: 125 | word_forms = 
vocab[args.target_word] 126 | 127 | process_file(args.input_file, args.output_file, word_forms) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/cluster_words.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from typing import List 5 | 6 | import numpy as np 7 | from scipy.sparse import csr_matrix 8 | from sklearn.cluster import KMeans 9 | from sklearn.metrics import silhouette_score 10 | 11 | 12 | def load_matrix(matrix_path): 13 | with open(matrix_path, "r") as f: 14 | result = json.load(f) 15 | 16 | offsets_row = np.array(result['offsets_row']) 17 | offsets_col = np.array(result['offsets_col']) 18 | scores = np.array(result['scores']) 19 | 20 | matrix = csr_matrix((scores, (offsets_row, offsets_col))) 21 | 22 | matrix = matrix + matrix.T 23 | 24 | return matrix 25 | 26 | 27 | def plot_embeddings(embeddings, labels, save_path: str): 28 | import matplotlib.pyplot as plt 29 | 30 | plt.scatter(embeddings[:, 0], embeddings[:, 1], c=labels, s=1) 31 | 32 | plt.savefig(save_path) 33 | plt.close() 34 | 35 | 36 | def find_extrema_score(scores: List[float]) -> int: 37 | """ 38 | Find all scores which is an extrema of the function, meaning score before and after are lower 39 | """ 40 | extrema_index = [] 41 | 42 | for i in range(1, len(scores) - 1): 43 | if scores[i] > scores[i - 1] and scores[i] > scores[i + 1]: 44 | extrema_index.append(i) 45 | 46 | if len(extrema_index) == 0: 47 | return 0 48 | 49 | # return latest extrema 50 | return extrema_index[-1] 51 | 52 | 53 | def main(): 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--matrix-path", type=str) 56 | parser.add_argument("--vector-path", type=str) 57 | parser.add_argument("--output-dir", type=str, default=None) 58 | 59 | args = parser.parse_args() 60 | 61 | matrix_path = args.matrix_path 62 | 63 | matrix = load_matrix(matrix_path) 64 | 65 | print(matrix.shape) 66 | 67 | vector_path = args.vector_path 68 | vectors = np.load(vector_path, mmap_mode='r') 69 | 70 | print(vectors.shape) 71 | 72 | a = vectors[:, 0] 73 | b = vectors[:, 1] 74 | 75 | z = np.sqrt(1 + np.sum(vectors ** 2, axis=1)) 76 | 77 | disk_a = a / (1 + z) 78 | disk_b = b / (1 + z) 79 | 80 | translated = np.stack([disk_a, disk_b], axis=1) 81 | 82 | # Inverse values of matrix to (1 - value) 83 | 84 | scores = [] 85 | start_at = 2 86 | 87 | for i in range(start_at, 10): 88 | clusterer = KMeans(n_clusters=i, random_state=10) 89 | cluster_labels = clusterer.fit_predict(translated) 90 | 91 | if args.output_dir: 92 | if not os.path.exists(args.output_dir): 93 | os.makedirs(args.output_dir, exist_ok=True) 94 | save_path = os.path.join(args.output_dir, f"cluster_{i}.png") 95 | plot_embeddings(vectors, cluster_labels, save_path) 96 | 97 | silhouette_avg = silhouette_score(translated, cluster_labels) # , metric="precomputed") 98 | 99 | print(f"Silhouette Score for {i} clusters: {silhouette_avg}") 100 | scores.append(silhouette_avg) 101 | 102 | extrema_clusters = find_extrema_score(scores) + start_at 103 | print(f"Extrema extrema_clusters: {extrema_clusters}") 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/combine_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 
import json 5 | from typing import List, Dict 6 | 7 | import tqdm 8 | import torch 9 | import numpy as np 10 | 11 | from mini_coil.data_pipeline.stopwords import english_stopwords 12 | from mini_coil.data_pipeline.vocab_resolver import VocabResolver 13 | from mini_coil.model.encoder import Encoder 14 | from mini_coil.model.word_encoder import WordEncoder 15 | 16 | 17 | def load_vocab(vocab_path) -> Dict[str, List[str]]: 18 | with open(vocab_path) as f: 19 | vocab = json.load(f) 20 | return vocab 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--models-dir", type=str) 26 | parser.add_argument("--vocab-path", type=str) 27 | parser.add_argument("--output-path", type=str) 28 | parser.add_argument("--input-dim", type=int, default=512) 29 | parser.add_argument("--output-dim", type=int, default=4) 30 | args = parser.parse_args() 31 | 32 | vocab = load_vocab(args.vocab_path) 33 | filtered_vocab = [] 34 | 35 | for word in vocab.keys(): 36 | if word in english_stopwords: 37 | continue 38 | model_path = os.path.join(args.models_dir, f"model-{word}.ptch") 39 | if os.path.exists(model_path): 40 | filtered_vocab.append(word) 41 | 42 | params = [torch.zeros(args.input_dim, args.output_dim)] # Extra zero tensor, as first word is vocab starts from 1 43 | 44 | vocab_resolver = VocabResolver() 45 | 46 | for word in tqdm.tqdm(filtered_vocab): 47 | model_path = os.path.join(args.models_dir, f"model-{word}.ptch") 48 | encoder = WordEncoder(args.input_dim, args.output_dim) 49 | encoder.load_state_dict(torch.load(model_path, weights_only=True)) 50 | 51 | encode_param = encoder.encoder_weights.data 52 | params.append(encode_param) 53 | 54 | vocab_resolver.add_word(word) 55 | 56 | vocab_size = vocab_resolver.vocab_size() 57 | 58 | combined_params = torch.stack(params, dim=0) 59 | 60 | print("combined_params", combined_params.shape) 61 | print("vocab_size", vocab_size) 62 | 63 | combined_params_numpy = combined_params.numpy() 64 | 65 | # Make sure the output directory exists 66 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 67 | 68 | # Save numpy file as well 69 | np.save(args.output_path + ".npy", combined_params_numpy) 70 | 71 | encoder = Encoder( 72 | input_dim=args.input_dim, 73 | output_dim=args.output_dim, 74 | vocab_size=vocab_size, 75 | ) 76 | 77 | encoder.encoder_weights.data = combined_params 78 | 79 | torch.save(encoder.state_dict(), args.output_path) 80 | 81 | vocab_resolver.save_json_vocab(args.output_path + ".vocab") 82 | 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/compress_dimentions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path 3 | import time 4 | from os import getenv 5 | 6 | import numpy as np 7 | from qdrant_client import QdrantClient, models 8 | from scipy.sparse import csr_matrix 9 | 10 | from mini_coil.settings import DATA_DIR 11 | 12 | DEFAULT_SAMPLE_SIZE = 4000 13 | DEFAULT_LIMIT = 20 14 | 15 | QDRANT_URL = os.environ.get("QDRANT_URL", getenv("QDRANT_URL", "http://localhost:80")) 16 | QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", getenv("QDRANT_API_KEY", "")) 17 | 18 | 19 | def query_qdrant_matrix_api( 20 | collection_name: str, 21 | sample_size: int = DEFAULT_SAMPLE_SIZE, 22 | limit: int = DEFAULT_LIMIT, 23 | word: str = None, 24 | ) -> models.SearchMatrixOffsetsResponse: 25 | time_start = time.time() 26 | qdrant = QdrantClient( 27 | 
url=QDRANT_URL, 28 | api_key=QDRANT_API_KEY, port=80, timeout=1000) 29 | 30 | existing_sample_size = qdrant.count( 31 | collection_name=collection_name, 32 | exact=True, 33 | count_filter=models.Filter( 34 | must=[ 35 | models.FieldCondition( 36 | key="sentence", 37 | match=models.MatchText(text=word) 38 | ) 39 | ] 40 | ) 41 | ).count 42 | 43 | if existing_sample_size < sample_size: 44 | print(f'''Only {existing_sample_size} samples available for "{word}"''') 45 | 46 | response = qdrant.search_matrix_offsets( 47 | collection_name=collection_name, 48 | sample=sample_size, 49 | limit=limit, 50 | query_filter=models.Filter( 51 | must=[ 52 | models.FieldCondition( 53 | key="sentence", 54 | match=models.MatchText(text=word) 55 | ) 56 | ] 57 | ), 58 | timeout=1000, 59 | ) 60 | 61 | elapsed = time.time() - time_start 62 | 63 | print(f"Elapsed time: {elapsed}") 64 | 65 | return response 66 | 67 | 68 | def compress_matrix( 69 | matrix: csr_matrix, 70 | dim: int = 2, 71 | n_neighbours: int = 20 72 | ): 73 | from umap import UMAP 74 | 75 | n_components = dim 76 | 77 | umap = UMAP( 78 | metric="precomputed", 79 | n_components=n_components, 80 | output_metric="hyperboloid", 81 | n_neighbors=n_neighbours, 82 | ) 83 | 84 | start_time = time.time() 85 | compressed_matrix = umap.fit_transform(matrix) 86 | print(f"Umap fit_transform time: {time.time() - start_time}") 87 | return compressed_matrix 88 | 89 | 90 | def closest_points(vectors: np.ndarray, vector: np.ndarray, precision_neighbours: int = 10): 91 | """ 92 | Select top n closest points to the given vector using cosine similarity 93 | """ 94 | from sklearn.metrics.pairwise import cosine_similarity 95 | 96 | similarities = cosine_similarity(vectors, vector.reshape(1, -1)) 97 | 98 | indices = np.argsort(similarities, axis=0)[::-1] 99 | 100 | return indices[:precision_neighbours].flatten() 101 | 102 | 103 | def estimate_precision(matrix: csr_matrix, compressed_vectors: np.ndarray, precision_n: int = 100, 104 | precision_neighbours: int = 10) -> float: 105 | import numpy as np 106 | 107 | precision = [] 108 | random_indices = np.random.choice(len(compressed_vectors), size=precision_n, replace=False) 109 | 110 | for i in random_indices: 111 | closest = closest_points(compressed_vectors, compressed_vectors[i], precision_neighbours) 112 | closest = closest[closest != i] 113 | 114 | precision.append(len(set(closest) & set(matrix[i].indices)) / len(closest)) 115 | 116 | return np.mean(precision) 117 | 118 | 119 | def plot_embeddings(embeddings, save_path: str): 120 | import matplotlib.pyplot as plt 121 | 122 | plt.scatter(embeddings[:, 0], embeddings[:, 1], s=1) 123 | plt.savefig(save_path) 124 | plt.close() 125 | 126 | 127 | def get_matrix(collection_name: str, word: str, output_dir, sample_size: int = DEFAULT_SAMPLE_SIZE, 128 | limit: int = DEFAULT_LIMIT): 129 | retry = 0 130 | 131 | while retry < 3: 132 | try: 133 | result = query_qdrant_matrix_api(collection_name, word=word, sample_size=sample_size, limit=limit) 134 | offsets_row = np.array(result.offsets_row) 135 | offsets_col = np.array(result.offsets_col) 136 | scores = np.array(result.scores) 137 | 138 | matrix = csr_matrix((scores, (offsets_row, offsets_col))) 139 | 140 | # make sure that the matrix is symmetric 141 | matrix = matrix + matrix.T 142 | 143 | sparse_matrix_path = os.path.join(output_dir, f"sparse_matrix_{word}.json") 144 | with open(sparse_matrix_path, "w") as f: 145 | f.write(result.model_dump_json()) 146 | 147 | return matrix 148 | except Exception as e: 149 | print(f"Error: {e}") 150 
| retry += 1 151 | continue 152 | 153 | 154 | def main(): 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument("--word", type=str) 157 | parser.add_argument("--collection-name", type=str, default="coil") 158 | parser.add_argument("--dim", type=int, default=4) 159 | parser.add_argument("--output-dir", type=str, default=None) 160 | parser.add_argument("--plot", action="store_true") 161 | parser.add_argument("--sample-size", type=int, default=DEFAULT_SAMPLE_SIZE) 162 | parser.add_argument("--overwrite", action="store_true") 163 | parser.add_argument('--limit', type=int, default=20) 164 | parser.add_argument('--n_neighbours', type=int, default=20) 165 | parser.add_argument('--precision_n', type=int, default=100) 166 | parser.add_argument('--precision_neighbours', type=int, default=10) 167 | 168 | args = parser.parse_args() 169 | 170 | collection_name = args.collection_name 171 | 172 | word = args.word 173 | 174 | if args.output_dir is None: 175 | output_dir = os.path.join(DATA_DIR, "test") 176 | else: 177 | output_dir = args.output_dir 178 | 179 | path = os.path.join(output_dir, f"compressed_matrix_{word}.npy") 180 | 181 | if os.path.exists(path) and not args.overwrite: 182 | print(f"File {path} already exists. Skipping") 183 | return 184 | 185 | if not os.path.exists(output_dir): 186 | os.makedirs(output_dir, exist_ok=True) 187 | 188 | matrix = get_matrix(collection_name, word, sample_size=args.sample_size, limit=args.limit, 189 | output_dir=args.output_dir) 190 | 191 | compressed_vectors = compress_matrix(matrix, dim=args.dim, n_neighbours=args.n_neighbours) 192 | 193 | np.save(path, compressed_vectors) 194 | 195 | if args.plot: 196 | compressed_vectors_2d = compressed_vectors[:, :2] 197 | 198 | plot_embeddings(compressed_vectors_2d, os.path.join(output_dir, f"compressed_matrix_{word}.png")) 199 | 200 | a = compressed_vectors[:, 0] 201 | b = compressed_vectors[:, 1] 202 | 203 | z = np.sqrt(1 + np.sum(compressed_vectors ** 2, axis=1)) 204 | 205 | disk_a = a / (1 + z) 206 | disk_b = b / (1 + z) 207 | 208 | plot_embeddings(np.stack([disk_a, disk_b], axis=1), 209 | os.path.join(output_dir, f"compressed_matrix_{word}_hyperboloid.png")) 210 | 211 | precision = estimate_precision(matrix, compressed_vectors, precision_n=args.precision_n, 212 | precision_neighbours=args.precision_neighbours) 213 | print(f"Precision: {precision}") 214 | 215 | # precision_2d = estimate_precision(matrix, compressed_vectors_2d) 216 | # print(f"Precision 2d: {precision_2d}") 217 | 218 | 219 | if __name__ == "__main__": 220 | main() 221 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/convert_openwebtext.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script converts the OpenWebText dataset to the usable format. 3 | 4 | Default format, apparently, was created by data scientists, who have no idea how to properly store data. 5 | So they did this: 6 | 7 | Main file (tar) 8 | - folder 9 | - sub-archive file (xz compressed) 10 | - tar file 11 | - text file 12 | - text file 13 | - text file 14 | 15 | This script will convert all this nonsense to the simple compressed texts file, 16 | which is possible to decompress on the fly with the CLI tools any teapot can run. 
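For example, assuming the output was written to data/openwebtext.txt.gz (the path is whatever
is passed via --output-file), the result can be streamed without unpacking it first:

    zcat data/openwebtext.txt.gz | head -n 5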
17 | """ 18 | 19 | import os 20 | import tarfile 21 | import glob 22 | import tqdm 23 | import gzip 24 | import argparse 25 | from typing import Iterable 26 | 27 | 28 | from mini_coil.settings import DATA_DIR 29 | 30 | 31 | def iterate_archives(archive_dir) -> Iterable[str]: 32 | path_to_xz_archives = os.path.join(archive_dir, "*.xz") 33 | all_files = glob.glob(path_to_xz_archives) 34 | for archive in tqdm.tqdm(all_files): 35 | yield archive 36 | 37 | 38 | def read_files_from_tar_xz(archive_path: str) -> Iterable[str]: 39 | with tarfile.open(archive_path, "r:xz") as tar: 40 | for member in tar.getmembers(): 41 | yield tar.extractfile(member).read().decode("utf-8") 42 | 43 | 44 | def read_texts(archive_dir) -> Iterable[str]: 45 | for archive in iterate_archives(archive_dir): 46 | for text in read_files_from_tar_xz(archive): 47 | yield text 48 | 49 | 50 | def main(): 51 | arg_parser = argparse.ArgumentParser() 52 | 53 | arg_parser.add_argument("--output-file", type=str, default=None) 54 | arg_parser.add_argument("--archive-dir", type=str, default=os.path.join(DATA_DIR, "openwebtext")) 55 | 56 | args = arg_parser.parse_args() 57 | 58 | output_file = args.output_file 59 | archive_dir = args.archive_dir 60 | 61 | # output_file = os.path.join(DATA_DIR, "openwebtext.txt.gz") 62 | 63 | with gzip.open(output_file, "wt") as f: 64 | for text in tqdm.tqdm(read_texts(archive_dir)): 65 | f.write(text) 66 | f.write("\n") 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/distance_matrix.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample sentences with a specific word from qdrant and build full distance matrix 3 | """ 4 | import json 5 | import os 6 | import time 7 | from typing import List 8 | 9 | from qdrant_client import QdrantClient, models 10 | import numpy as np 11 | import argparse 12 | 13 | from mini_coil.settings import DATA_DIR 14 | 15 | DEFAULT_SAMPLE_SIZE = 4000 16 | 17 | QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:80") 18 | QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "") 19 | 20 | 21 | def query_sentences( 22 | collection_name: str, 23 | words: List[str], 24 | sample_size: int = DEFAULT_SAMPLE_SIZE, 25 | ) -> tuple[np.ndarray, list[dict]]: 26 | 27 | print(QDRANT_URL, QDRANT_API_KEY) 28 | qdrant = QdrantClient( 29 | url=QDRANT_URL, 30 | api_key=QDRANT_API_KEY, 31 | timeout=1000, 32 | ) 33 | 34 | 35 | response = qdrant.query_points( 36 | collection_name=collection_name, 37 | query=models.SampleQuery(sample=models.Sample.RANDOM), 38 | query_filter=models.Filter( 39 | should=[ 40 | models.FieldCondition( 41 | key="sentence", 42 | match=models.MatchText(text=word) 43 | ) 44 | for word in words 45 | ] 46 | ), 47 | limit=sample_size, 48 | with_payload=True, 49 | with_vectors=True, 50 | ) 51 | 52 | vectors = np.array([point.vector for point in response.points]) 53 | payloads = [{ 54 | "id": point.id, 55 | **point.payload 56 | } for point in response.points] 57 | 58 | return vectors, payloads 59 | 60 | def cosine_similarity_matrix(vectors: np.ndarray) -> np.ndarray: 61 | """ 62 | Compute NxN cosine similarity matrix for N vectors. 
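    Rows are L2-normalized (with a small epsilon so all-zero vectors do not cause a
    division by zero), so entry (i, j) is the cosine of the angle between vectors i and j,
    and the diagonal is ~1.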
63 | """ 64 | norms = np.linalg.norm(vectors, axis=1) 65 | normalized = vectors / (norms[:, np.newaxis] + 1e-9) 66 | distances = normalized @ normalized.T 67 | return distances 68 | 69 | 70 | def main(): 71 | default_vocab_path = os.path.join(DATA_DIR, "30k-vocab-filtered.json") 72 | 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument("--word", type=str) 75 | parser.add_argument("--collection-name", type=str, default="coil") 76 | parser.add_argument("--output-matrix", type=str) 77 | parser.add_argument("--output-sentences", type=str) 78 | parser.add_argument("--sample-size", type=int, default=DEFAULT_SAMPLE_SIZE) 79 | parser.add_argument("--vocab-path", type=str, default=default_vocab_path) 80 | 81 | args = parser.parse_args() 82 | 83 | start_time = time.time() 84 | 85 | vocab = json.load(open(args.vocab_path)) 86 | 87 | if args.word not in vocab: 88 | print(f"WARNING: word {args.word} not found in vocab, using as is") 89 | forms = [args.word] 90 | else: 91 | forms = vocab[args.word] 92 | 93 | vectors, payloads = query_sentences(args.collection_name, forms, args.sample_size) 94 | elapsed_time = time.time() - start_time 95 | print(f"Query time: {elapsed_time}") 96 | distances = cosine_similarity_matrix(vectors) 97 | elapsed_time = time.time() - start_time 98 | print(f"Matrix calculation time: {elapsed_time}") 99 | 100 | # create directory if not exists 101 | os.makedirs(os.path.dirname(args.output_matrix), exist_ok=True) 102 | os.makedirs(os.path.dirname(args.output_sentences), exist_ok=True) 103 | 104 | np.save(args.output_matrix, distances) 105 | 106 | with open(args.output_sentences, "w") as f: 107 | for payload in payloads: 108 | f.write(json.dumps(payload)) 109 | f.write("\n") 110 | 111 | 112 | if __name__ == '__main__': 113 | 114 | main() 115 | 116 | def test_distance_matrix(): 117 | vectors = np.array([ 118 | [1, 0, 1], 119 | [0, 1, 1], 120 | [0, 1, 0] 121 | ]) 122 | 123 | distances = cosine_similarity_matrix(vectors) 124 | 125 | print(distances) 126 | 127 | assert distances.shape == (3, 3) 128 | 129 | assert distances[0, 0] > .9999 130 | assert distances[1, 1] > .9999 131 | assert distances[2, 2] > .9999 132 | 133 | assert distances[0, 2] < distances[0, 1] 134 | 135 | # test_distance_matrix() -------------------------------------------------------------------------------- /mini_coil/data_pipeline/download_validation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | from typing import List 5 | 6 | from qdrant_client import QdrantClient, models 7 | 8 | from mini_coil.settings import DATA_DIR 9 | 10 | DEFAULT_SAMPLE_SIZE = 1000 11 | 12 | QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:80") 13 | QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "") 14 | 15 | def query_sentences( 16 | collection_name: str, 17 | words: List[str], 18 | sample_size: int = DEFAULT_SAMPLE_SIZE, 19 | ) -> List[str]: 20 | 21 | qdrant = QdrantClient( 22 | url=QDRANT_URL, 23 | api_key=QDRANT_API_KEY, 24 | timeout=1000, 25 | ) 26 | 27 | response = qdrant.query_points( 28 | collection_name=collection_name, 29 | query=models.SampleQuery(sample=models.Sample.RANDOM), 30 | query_filter=models.Filter( 31 | should=[ 32 | models.FieldCondition( 33 | key="sentence", 34 | match=models.MatchText(text=word) 35 | ) for word in words 36 | ] 37 | ), 38 | limit=sample_size, 39 | with_payload=True, 40 | with_vectors=False, 41 | ) 42 | 43 | return [point.payload["sentence"] for point in response.points] 44 | 45 | 46 | def 
main(): 47 | default_vocab_path = os.path.join(DATA_DIR, "30k-vocab-filtered.json") 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--word", type=str) 51 | parser.add_argument("--collection-name", type=str, default="coil-validation") 52 | parser.add_argument("--output-sentences", type=str) 53 | parser.add_argument("--sample-size", type=int, default=DEFAULT_SAMPLE_SIZE) 54 | parser.add_argument("--vocab-path", type=str, default=default_vocab_path) 55 | 56 | args = parser.parse_args() 57 | 58 | vocab = json.load(open(args.vocab_path)) 59 | 60 | if args.word not in vocab: 61 | print(f"WARNING: word {args.word} not found in vocab, using as is") 62 | forms = [args.word] 63 | else: 64 | forms = vocab[args.word] 65 | 66 | sentences = query_sentences( 67 | collection_name=args.collection_name, 68 | words=forms, 69 | sample_size=args.sample_size, 70 | ) 71 | 72 | with open(args.output_sentences, "w") as f: 73 | for sentence in sentences: 74 | f.write(sentence + "\n") 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/encode_and_filter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from typing import List, Iterable, Optional, Dict 5 | 6 | import numpy as np 7 | import tqdm 8 | from fastembed.late_interaction.token_embeddings import TokenEmbeddingsModel 9 | from npy_append_array import NpyAppendArray 10 | 11 | from mini_coil.data_pipeline.vocab_resolver import VocabResolver, VocabTokenizerTokenizer 12 | 13 | 14 | def load_model(model_name): 15 | return TokenEmbeddingsModel(model_name=model_name, threads=1) 16 | 17 | 18 | def read_sentences(file_path: str, limit_length: int = 4096) -> Iterable[Dict[str, str]]: 19 | with open(file_path, "r") as f: 20 | for line in f: 21 | doc = json.loads(line) 22 | yield { 23 | "sentence": doc["sentence"][:limit_length], 24 | "line_number": doc["line_number"], 25 | } 26 | 27 | 28 | def encode_and_filter(model_name: Optional[str], word: str, docs: List[dict]) -> Iterable[np.ndarray]: 29 | model = load_model(model_name) 30 | vocab_resolver = VocabResolver(tokenizer=VocabTokenizerTokenizer(model.tokenizer)) 31 | vocab_resolver.add_word(word) 32 | 33 | sentences = [doc["sentence"] for doc in docs] 34 | 35 | for embedding, sentence in zip(model.embed(sentences, batch_size=2), sentences): 36 | token_ids = np.array(model.tokenize([sentence])[0].ids) 37 | word_mask, counts, oov, _forms = vocab_resolver.resolve_tokens(token_ids) 38 | word_mask = word_mask.astype(bool) 39 | total_tokens = np.sum(word_mask) 40 | if total_tokens == 0: 41 | yield np.zeros(embedding.shape[1]) 42 | continue 43 | 44 | word_embeddings = embedding[word_mask] 45 | avg_embedding = np.mean(word_embeddings, axis=0) 46 | yield avg_embedding 47 | 48 | 49 | def main(): 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("--sentences-file", type=str) 52 | parser.add_argument("--output-file", type=str) 53 | parser.add_argument("--output-line-numbers-file", type=str) 54 | parser.add_argument("--word", type=str) 55 | args = parser.parse_args() 56 | 57 | model_name = "jinaai/jina-embeddings-v2-small-en-tokens" 58 | 59 | input_file = args.sentences_file 60 | docs = list(read_sentences(input_file, limit_length=1024)) 61 | 62 | embeddings = encode_and_filter( 63 | model_name=model_name, 64 | word=args.word, 65 | docs=docs 66 | ) 67 | 68 | output_file = args.output_file 69 | 
os.makedirs(os.path.dirname(output_file), exist_ok=True) 70 | 71 | text_np_emb_file = NpyAppendArray(output_file, delete_if_exists=True) 72 | line_numbers_file = args.output_line_numbers_file 73 | line_numbers = [] 74 | 75 | for doc, emb in tqdm.tqdm(zip(docs, embeddings), total=len(docs)): 76 | emb_conv = emb.reshape(1, -1) 77 | text_np_emb_file.append(emb_conv) 78 | line_numbers.append(int(doc["line_number"])) 79 | 80 | text_np_emb_file.close() 81 | np.save(line_numbers_file, np.array(line_numbers)) 82 | 83 | text_np_emb_file = np.load(output_file, mmap_mode='r') 84 | print(f"text_np_emb_file {output_file} shape:", text_np_emb_file.shape) 85 | print(f"line_numbers {line_numbers_file} shape:", np.load(line_numbers_file).shape) 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/encode_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script takes raw training data and applies initial embeddings to it. 3 | 4 | The output of the process, for each abstract: 5 | 6 | * List of token ids 7 | * List of per-token embeddings 8 | * Aggregate embedding of the abstract 9 | 10 | This script can potentially generate huge amounts of data, so 11 | it will write directly to disk. 12 | 13 | """ 14 | 15 | import os 16 | from typing import Iterable 17 | import argparse 18 | 19 | import numpy as np 20 | import tqdm 21 | from npy_append_array import NpyAppendArray 22 | 23 | from mini_coil.data_pipeline.pre_encoder import PreEncoder 24 | from mini_coil.settings import DATA_DIR 25 | from mini_coil.data_pipeline.vocab_resolver import VocabResolver 26 | 27 | 28 | def read_texts(path: str) -> Iterable[str]: 29 | with open(path, "r") as f: 30 | for line in tqdm.tqdm(f): 31 | line = line.strip() 32 | if len(line) > 0: 33 | yield line 34 | 35 | 36 | def iter_batch(iterable, size): 37 | batch = [] 38 | for item in iterable: 39 | batch.append(item) 40 | if len(batch) >= size: 41 | yield batch 42 | batch = [] 43 | if len(batch) > 0: 44 | yield batch 45 | 46 | 47 | def main(): 48 | input_file = "bat.txt" 49 | 50 | default_input_data_path = os.path.join(DATA_DIR, "test", input_file) 51 | default_output_dir = os.path.join(DATA_DIR, "test") 52 | default_vocab_path = os.path.join(DATA_DIR, "test", "vocab.txt") 53 | 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--input-file", type=str, default=default_input_data_path) 56 | parser.add_argument("--output-dir", type=str, default=default_output_dir) 57 | parser.add_argument('--vocab-path', type=str, default=default_vocab_path) 58 | args = parser.parse_args() 59 | 60 | model_repository = "sentence-transformers/all-MiniLM-L6-v2" 61 | model_save_path = os.path.join(DATA_DIR, "all_miniLM_L6_v2.onnx") 62 | 63 | vocab_path = args.vocab_path 64 | 65 | batch_size = 32 66 | 67 | pre_encoder = PreEncoder(model_repository, model_save_path) 68 | 69 | vocab_resolver = VocabResolver(model_repository) 70 | vocab_resolver.load_vocab(vocab_path) 71 | 72 | total_token_emb_offset = 0 73 | offsets = [0] 74 | 75 | output_dir = args.output_dir 76 | 77 | # Create the output directory if it doesn't exist 78 | os.makedirs(output_dir, exist_ok=True) 79 | 80 | token_np_emb_file = NpyAppendArray( 81 | os.path.join(output_dir, "token_embeddings.npy"), 82 | delete_if_exists=True 83 | ) 84 | text_np_emb_file = NpyAppendArray( 85 | os.path.join(output_dir, "text_embeddings.npy"), 86 | delete_if_exists=True 87 | ) 88 | 
tokens_np_file = NpyAppendArray( 89 | os.path.join(output_dir, "tokens.npy"), 90 | delete_if_exists=True 91 | ) 92 | offsets_file = os.path.join(output_dir, "offsets.npy") 93 | 94 | for batch in iter_batch(read_texts(args.input_file), batch_size): 95 | batch_output = pre_encoder.encode(batch) 96 | 97 | batch_offsets, flattened_vocab_ids, flattened_token_embeddings = vocab_resolver.filter( 98 | batch_output["token_ids"], 99 | batch_output["token_embeddings"] 100 | ) 101 | 102 | tokens_np_file.append(flattened_vocab_ids) 103 | token_np_emb_file.append(flattened_token_embeddings) 104 | text_np_emb_file.append(batch_output["text_embeddings"]) 105 | 106 | for row_id, offset in enumerate(batch_offsets): 107 | global_offset = total_token_emb_offset + offset 108 | offsets.append(global_offset) 109 | total_token_emb_offset = global_offset 110 | 111 | offsets = np.array(offsets) 112 | 113 | np.save(offsets_file, offsets) 114 | 115 | token_np_emb_file.close() 116 | text_np_emb_file.close() 117 | tokens_np_file.close() 118 | 119 | print(total_token_emb_offset) 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/encode_targets.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script takes raw training data and applies initial embeddings to it. 3 | 4 | The output of the process, for each abstract: 5 | 6 | * List of token ids 7 | * List of per-token embeddings 8 | * Aggregate embedding of the abstract 9 | 10 | This script can potentially generate huge amounts of data, so 11 | it will write directly to disk. 12 | 13 | """ 14 | 15 | import argparse 16 | import gzip 17 | import itertools 18 | import os 19 | from typing import Iterable 20 | 21 | import numpy as np 22 | import tqdm 23 | from fastembed import TextEmbedding 24 | from npy_append_array import NpyAppendArray 25 | 26 | from mini_coil.settings import DATA_DIR 27 | 28 | 29 | def read_texts(path: str) -> Iterable[str]: 30 | with gzip.open(path, "rt") as f: 31 | for line in f: 32 | line = line.strip() 33 | _abs_hash, sentence = line.split("\t") 34 | yield sentence 35 | 36 | 37 | def main(): 38 | input_file = "bat.txt" 39 | 40 | default_input_data_path = os.path.join(DATA_DIR, "test", input_file) 41 | default_output_file = os.path.join(DATA_DIR, "test") 42 | 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--input-file", type=str, default=default_input_data_path) 45 | parser.add_argument("--output-file", type=str, default=default_output_file) 46 | parser.add_argument("--model-name", type=str, default="mixedbread-ai/mxbai-embed-large-v1") 47 | parser.add_argument("--use-cuda", action="store_true") 48 | parser.add_argument("--device-count", type=int, default=None) 49 | parser.add_argument("--max-count", type=int, default=None) 50 | parser.add_argument("--batch-size", type=int, default=1024) 51 | 52 | args = parser.parse_args() 53 | 54 | model_name = args.model_name 55 | 56 | device_ids = [i for i in range(args.device_count)] if args.device_count else None 57 | 58 | lazy_load = True if device_ids is not None else False 59 | 60 | model = TextEmbedding( 61 | model_name=model_name, 62 | cuda=args.use_cuda, 63 | device_ids=device_ids, 64 | lazy_load=lazy_load 65 | ) 66 | 67 | parallel = len(device_ids) if device_ids else None 68 | 69 | batch_size = args.batch_size 70 | 71 | output_file = args.output_file 72 | output_dir = os.path.basename(output_file) 73 | 74 | 
os.makedirs(output_dir, exist_ok=True) 75 | 76 | text_np_emb_file = NpyAppendArray(output_file, delete_if_exists=True) 77 | 78 | text_iterator = read_texts(args.input_file) 79 | 80 | if args.max_count: 81 | text_iterator = itertools.islice(text_iterator, args.max_count) 82 | 83 | for vector in tqdm.tqdm(model.embed( 84 | text_iterator, 85 | batch_size=batch_size, 86 | parallel=parallel 87 | )): 88 | # Convert to float16 and reshape from (dim,) to (1, dim) 89 | # vector_fp16 = vector.astype(np.float16).reshape(1, -1) 90 | 91 | # Do not convert to float16, just reshape 92 | vector_conv = vector.reshape(1, -1) 93 | text_np_emb_file.append(vector_conv) 94 | 95 | text_np_emb_file.close() 96 | 97 | # Check the output file shape 98 | 99 | text_np_emb_file = np.load(output_file, mmap_mode='r') 100 | 101 | print(f"text_np_emb_file {output_file} shape:", text_np_emb_file.shape) 102 | 103 | 104 | if __name__ == "__main__": 105 | main() 106 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/load_sentences.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from os import getenv 5 | from typing import List 6 | 7 | from qdrant_client import QdrantClient 8 | 9 | QDRANT_URL = os.environ.get("QDRANT_URL", getenv("QDRANT_URL", "http://localhost:80")) 10 | QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", getenv("QDRANT_API_KEY", "")) 11 | 12 | client = QdrantClient( 13 | url=QDRANT_URL, 14 | api_key=QDRANT_API_KEY, 15 | prefer_grpc=False, 16 | port=80, 17 | timeout=999, 18 | https=False, 19 | ) 20 | 21 | 22 | # Iterate over vocabulary and distance matrixes to select which abstracts and sentences should we used for inference 23 | # Additionally save the words accosiated with the selected abstracts and sentences 24 | 25 | def load_vocabulary(path: str) -> List[str]: 26 | vocabulary = [] 27 | with open(path, 'r') as f: 28 | for line in f: 29 | vocabulary.append(line.strip()) 30 | 31 | return vocabulary 32 | 33 | 34 | def load_matrix_ids(directory: str, word: str): 35 | path = os.path.join(directory, f"sparse_matrix_{word}.json") 36 | if not os.path.exists(path): 37 | return None 38 | 39 | with open(path, 'r') as f: 40 | matrix = json.load(f) 41 | ids = matrix['ids'] 42 | return ids 43 | 44 | 45 | def load_sentences(collection_name: str, ids: List[str]) -> List[dict]: 46 | points_to_abstracts = {} 47 | batch_size = 10000 48 | 49 | for i in range(0, len(ids), batch_size): 50 | print("Retrieving batch", i + 1, "of", len(ids) // batch_size) 51 | batch_ids = ids[i:i + batch_size] 52 | points = client.retrieve( 53 | collection_name, 54 | batch_ids, 55 | with_payload=True, 56 | with_vectors=False, 57 | ) 58 | batch_points = dict((point.id, point.payload) for point in points) 59 | points_to_abstracts.update(batch_points) 60 | 61 | result = [] 62 | for point_id in ids: 63 | if point_id not in points_to_abstracts: 64 | print(f"Point {point_id} not found in collection {collection_name}") 65 | exit(1) 66 | 67 | payload = points_to_abstracts[point_id] 68 | ln = payload["line_number"] 69 | 70 | result.append({ 71 | "id": point_id, 72 | "line_number": ln, 73 | "sentence": payload["sentence"], 74 | "abs_hash": payload["abs_hash"] 75 | }) 76 | 77 | return result 78 | 79 | 80 | def main(): 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument("--abstracts-collection-name", type=str, default="coil-abstracts") 83 | parser.add_argument("--sentences-collection-name", type=str, 
default="coil") 84 | parser.add_argument("--word", type=str) 85 | parser.add_argument("--matrix-dir", type=str) 86 | parser.add_argument("--output-dir", type=str) 87 | args = parser.parse_args() 88 | 89 | sentences_collection_name = args.sentences_collection_name 90 | abstracts_collection_name = args.abstracts_collection_name 91 | 92 | word = args.word 93 | 94 | ids = load_matrix_ids(args.matrix_dir, word) 95 | if ids is None: 96 | print(f"Matrix for word {word} does not exist") 97 | return 98 | 99 | sentences_data = load_sentences(sentences_collection_name, ids) 100 | 101 | output_dir = args.output_dir 102 | if not os.path.exists(output_dir): 103 | os.makedirs(output_dir, exist_ok=True) 104 | 105 | output_path = os.path.join(output_dir, f"sentences-{word}.jsonl") 106 | 107 | with open(output_path, 'w') as f: 108 | for sentence in sentences_data: 109 | f.write(json.dumps(sentence) + '\n') 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/pre_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import List 4 | 5 | import numpy as np 6 | from transformers import AutoModel, AutoTokenizer 7 | import torch 8 | import onnxruntime as ort 9 | 10 | from mini_coil.settings import DATA_DIR 11 | 12 | 13 | def cosine_similarity(rows_a: np.ndarray, rows_b: np.ndarray): 14 | """ 15 | Compute a matrix of cosine distances between two sets of vectors. 16 | """ 17 | # Normalize the vectors 18 | rows_a = rows_a / np.linalg.norm(rows_a, axis=1, keepdims=True) 19 | rows_b = rows_b / np.linalg.norm(rows_b, axis=1, keepdims=True) 20 | 21 | # Compute the cosine similarity 22 | return np.dot(rows_a, rows_b.T) 23 | 24 | 25 | def download_and_save_onnx(model_repository, model_save_path): 26 | # Load model and tokenizer 27 | model = AutoModel.from_pretrained(model_repository, trust_remote_code=True) 28 | tokenizer = AutoTokenizer.from_pretrained(model_repository) 29 | 30 | texts = [ 31 | "Hello, this is a test.", 32 | "This is another test, a bit longer than the first one.", 33 | ] 34 | 35 | # Prepare dummy input for model export 36 | inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True) 37 | 38 | # Prepare the model for exporting 39 | model.eval() 40 | 41 | # Export the model to ONNX 42 | with torch.no_grad(): 43 | torch.onnx.export( 44 | model, 45 | args=(inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids']), 46 | f=model_save_path, 47 | input_names=['input_ids', 'attention_mask', 'token_type_ids'], 48 | output_names=[ 49 | 'last_hidden_state', 50 | 'embedding' 51 | ], 52 | dynamic_axes={ 53 | 'input_ids': {0: 'batch_size', 1: 'sequence_length'}, 54 | 'attention_mask': {0: 'batch_size', 1: 'sequence_length'}, 55 | 'token_type_ids': {0: 'batch_size', 1: 'sequence_length'}, 56 | 'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}, 57 | 'embedding': {0: 'batch_size'} 58 | }, 59 | opset_version=14 60 | ) 61 | 62 | 63 | class PreEncoder: 64 | 65 | def __init__(self, model_repository: str, model_path: str): 66 | self.tokenizer = AutoTokenizer.from_pretrained(model_repository) 67 | self.session = ort.InferenceSession(model_path) 68 | 69 | def encode(self, texts: List[str]): 70 | inputs = self.tokenizer(texts, return_tensors="np", padding=True, truncation=True) 71 | # { 72 | # 'input_ids': array( 73 | # [ 74 | # [101, 7592, 1010, 2023, 2003, 1037, 3231, 1012, 102], 75 | 
# [101, 2023, 2003, 2178, 3231, 1012, 102, 0, 0] 76 | # ] 77 | # ), 78 | # 'token_type_ids': array( 79 | # [ 80 | # [0, 0, 0, 0, 0, 0, 0, 0, 0], 81 | # [0, 0, 0, 0, 0, 0, 0, 0, 0] 82 | # ] 83 | # ), 84 | # 'attention_mask': array( 85 | # [ 86 | # [1, 1, 1, 1, 1, 1, 1, 1, 1], 87 | # [1, 1, 1, 1, 1, 1, 1, 0, 0] 88 | # ] 89 | # ) 90 | # } 91 | 92 | outputs = self.session.run(None, {**inputs}) 93 | 94 | # (batch_size, sequence_length) 95 | input_ids = inputs['input_ids'] 96 | 97 | # (batch_size, sequence_length, embedding_size) 98 | token_embeddings = outputs[0] 99 | 100 | # (batch_size, embedding_size) 101 | text_embeddings = outputs[1] 102 | 103 | # (batch_size) 104 | number_of_tokens = np.sum(inputs['attention_mask'], axis=1) 105 | 106 | return { 107 | 'number_of_tokens': number_of_tokens, 108 | 'token_ids': input_ids, 109 | 'token_embeddings': token_embeddings, 110 | 'text_embeddings': text_embeddings 111 | } 112 | 113 | 114 | def check_similarity(): 115 | from sentence_transformers import SentenceTransformer 116 | 117 | # model_repository = "Alibaba-NLP/gte-large-en-v1.5" 118 | model_repository = "mixedbread-ai/mxbai-embed-large-v1" 119 | 120 | model = SentenceTransformer(model_repository, trust_remote_code=True, device="cpu") 121 | 122 | text_a = "The bat flew out of the cave." 123 | text_b = "He is a baseball player. He knows how to swing a bat." 124 | text_c = "A bat can use echolocation to navigate in the dark." 125 | text_d = "It was just a cricket bat." 126 | text_e = "And guess who the orphans have at bat!" 127 | text_f = "Eric Byrnes, never with an at bat in Yankee Stadium and they don't get much bigger than this one." 128 | 129 | texts = [text_a, text_b, text_c, text_d, text_e, text_f] 130 | 131 | time_start = time.time() 132 | 133 | embeddings = model.encode(texts) 134 | 135 | print("Time taken to encode:", time.time() - time_start) 136 | 137 | original_matrix = cosine_similarity(embeddings, embeddings) 138 | 139 | print("original similarity matrix\n", original_matrix) 140 | 141 | texts = [ 142 | "java developer intern", 143 | "coffee from java island", 144 | "java programming language", 145 | "java is located in indonesia", 146 | ] 147 | 148 | embeddings = model.encode(texts) 149 | 150 | original_matrix = cosine_similarity(embeddings, embeddings) 151 | 152 | print("original similarity matrix\n", original_matrix) 153 | 154 | 155 | if __name__ == "__main__": 156 | # check_similarity() 157 | # 158 | # exit(0) 159 | 160 | # Specify the Hugging Face repository and the local path for saving the ONNX model 161 | model_repository = "jinaai/jina-embeddings-v2-base-en" 162 | model_save_path = os.path.join(DATA_DIR, "jina-embeddings-v2-base-en.onnx") 163 | 164 | download_and_save_onnx(model_repository, model_save_path) 165 | 166 | pre_encoder = PreEncoder(model_repository, model_save_path) 167 | 168 | text_a = "The bat flew out of the cave." 169 | text_b = "He is a baseball player. He knows how to swing a bat." 170 | text_c = "A bat can use echolocation to navigate in the dark." 171 | text_d = "It was just a cricket bat." 172 | text_e = "And guess who the orphans have at bat!" 173 | text_f = "Eric Byrnes, never with an at bat in Yankee Stadium and they don't get much bigger than this one." 
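# Expected output of PreEncoder.encode() for this example batch, sketched from the return
# statement of encode() above. The 768-dim embedding size is an assumption based on
# jina-embeddings-v2-base-en; max_seq_len depends on the longest tokenized sentence.
#
#   out = pre_encoder.encode(texts)
#   out["number_of_tokens"]   # (6,)                  non-padding tokens per sentence
#   out["token_ids"]          # (6, max_seq_len)      padded input ids
#   out["token_embeddings"]   # (6, max_seq_len, 768) per-token embeddings
#   out["text_embeddings"]    # (6, 768)              one pooled vector per sentence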
174 | 175 | texts = [text_a, text_b, text_c, text_d, text_e, text_f] 176 | 177 | text_embeddings = pre_encoder.encode(texts)["text_embeddings"] 178 | 179 | original_matrix = cosine_similarity(text_embeddings, text_embeddings) 180 | 181 | print("original similarity matrix\n", original_matrix) 182 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/prepare_vocab.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | from typing import List 4 | import tqdm 5 | import json 6 | 7 | from py_rust_stemmers import SnowballStemmer 8 | from nltk import WordNetLemmatizer 9 | 10 | 11 | def read_source_vocab(path) -> List[str]: 12 | with open(path, "r") as f: 13 | return [line.strip() for line in f] 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--input-file", type=str) 19 | parser.add_argument("--output-file", type=str) 20 | args = parser.parse_args() 21 | 22 | import nltk 23 | 24 | nltk.download('wordnet') 25 | 26 | source_vocab = read_source_vocab(args.input_file) 27 | 28 | # Remove stopwords and words with less than 3 characters 29 | # As vocab is sorted by frequency, define stopwords as first 100 words if the word is less than 5 30 | 31 | target_vocab = [word for idx, word in enumerate(source_vocab) if len(word) > (5 if idx < 100 else 2)] 32 | 33 | stemmer = SnowballStemmer("english") 34 | lemmatizer = WordNetLemmatizer() 35 | 36 | normalized_words = defaultdict(set) 37 | 38 | for word in tqdm.tqdm(target_vocab): 39 | anchor_word = stemmer.stem_word(word) 40 | lemmatized_word = lemmatizer.lemmatize(word) 41 | 42 | normalized_words[anchor_word].add(word) 43 | normalized_words[anchor_word].add(lemmatized_word) 44 | 45 | normalized_words = {k: list(v) for k, v in normalized_words.items()} 46 | 47 | with open(args.output_file, "w") as f: 48 | json.dump(normalized_words, f, indent=4) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/read_pre_encoded.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | import numpy as np 5 | 6 | from mini_coil.settings import DATA_DIR 7 | 8 | 9 | class PreEncodedReader: 10 | def __init__(self, path): 11 | self.path = path 12 | token_np_emb_file = os.path.join(self.path, "token_embeddings.npy") 13 | text_np_emb_file = os.path.join(self.path, "text_embeddings.npy") 14 | tokens_np_file = os.path.join(self.path, "tokens.npy") 15 | offsets_file = os.path.join(self.path, "offsets.npy") 16 | 17 | self.offsets = np.load(offsets_file) 18 | 19 | self.token_embeddings = np.load(token_np_emb_file, mmap_mode='r') 20 | self.text_embeddings = np.load(text_np_emb_file, mmap_mode='r') 21 | self.token_ids = np.load(tokens_np_file, mmap_mode='r') 22 | 23 | # print("self.offsets", self.offsets.shape) 24 | # print("self.token_embeddings", self.token_embeddings.shape) 25 | # print("self.text_embeddings", self.text_embeddings.shape) 26 | # print("self.tokens", self.tokens.shape) 27 | 28 | def __len__(self): 29 | return len(self.offsets) - 1 30 | 31 | def read_one(self, idx: int) -> Dict[str, np.ndarray]: 32 | start = self.offsets[idx] 33 | end = self.offsets[idx + 1] 34 | token_ids = self.token_ids[start:end] 35 | 36 | return { 37 | 'token_embeddings': self.token_embeddings[start:end], 38 | 
'text_embeddings': self.text_embeddings[idx], 39 | 'token_ids': token_ids 40 | } 41 | 42 | def read(self, from_idx: int, to_idx: int) -> Dict[str, np.ndarray]: 43 | token_ids_batch = [] 44 | token_embeddings_batch = [] 45 | text_embeddings_batch = [] 46 | 47 | for idx in range(from_idx, to_idx): 48 | data = self.read_one(idx) 49 | 50 | token_ids_batch.append(data['token_ids']) 51 | token_embeddings_batch.append(data['token_embeddings']) 52 | text_embeddings_batch.append(data['text_embeddings']) 53 | 54 | # (batch_size, embedding_size) 55 | text_embeddings_batch = np.stack(text_embeddings_batch) 56 | 57 | # token_embeddings_batch and token_ids_batch require padding 58 | max_len = max(len(x) for x in token_ids_batch) 59 | token_embeddings_padded = np.zeros((len(token_embeddings_batch), max_len, token_embeddings_batch[0].shape[1])) 60 | token_ids_padded = np.zeros((len(token_ids_batch), max_len), dtype=np.int64) 61 | 62 | for i, (token_ids, token_embeddings) in enumerate(zip(token_ids_batch, token_embeddings_batch)): 63 | token_embeddings_padded[i, :len(token_ids)] = token_embeddings 64 | token_ids_padded[i, :len(token_ids)] = token_ids 65 | 66 | return { 67 | 'token_embeddings': token_embeddings_padded, 68 | 'text_embeddings': text_embeddings_batch, 69 | 'token_ids': token_ids_padded 70 | } 71 | 72 | 73 | def main(): 74 | path = os.path.join(DATA_DIR, "test") 75 | reader = PreEncodedReader(path) 76 | batch = reader.read(20, 25) 77 | print("token_embeddings", batch['token_embeddings'].shape) 78 | print("text_embeddings", batch['text_embeddings'].shape) 79 | print("token_ids", batch['token_ids']) 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/split_sentences.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import gzip 4 | 5 | from typing import Iterable 6 | 7 | import tqdm 8 | from sentence_splitter import SentenceSplitter 9 | 10 | 11 | def compute_hash(text: str) -> str: 12 | return hashlib.sha256(text.encode()).hexdigest() 13 | 14 | 15 | def read_abstracts(path: str) -> Iterable[str]: 16 | with gzip.open(path, "rt") as f: 17 | for line in f: 18 | yield line.strip() 19 | 20 | 21 | def sentence_splitter(abstracts: Iterable[str]) -> Iterable[str]: 22 | splitter = SentenceSplitter(language='en') 23 | for abstract in abstracts: 24 | if len(abstract) == 0: 25 | continue 26 | abstract_hash = compute_hash(abstract) 27 | for sentence in splitter.split(abstract): 28 | yield abstract_hash, sentence 29 | 30 | 31 | def main(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--input-file", type=str) 34 | parser.add_argument("--output-file", type=str) 35 | args = parser.parse_args() 36 | 37 | abstracts = read_abstracts(args.input_file) 38 | 39 | sentences = sentence_splitter(abstracts) 40 | 41 | with gzip.open(args.output_file, "wt") as f: 42 | for abs_hash, sentence in tqdm.tqdm(sentences): 43 | f.write(f"{abs_hash}\t{sentence}\n") 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/split_train_val.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Iterable 4 | import random 5 | import tqdm 6 | 7 | from mini_coil.settings import DATA_DIR 8 | 9 | 10 | def read_abstracts(path: str) -> 
Iterable[str]: 11 | with open(path, "r") as f: 12 | for line in f: 13 | yield line.strip() 14 | 15 | 16 | def main(): 17 | default_input_data_path = os.path.join(DATA_DIR, "test", "bat.txt") 18 | default_train_path = os.path.join(DATA_DIR, "test", "train.txt") 19 | default_valid_path = os.path.join(DATA_DIR, "test", "valid.txt") 20 | 21 | default_split_ratio = 0.8 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--input-file", type=str, default=default_input_data_path) 25 | parser.add_argument("--out-train", type=str, default=default_train_path) 26 | parser.add_argument('--out-valid', type=str, default=default_valid_path) 27 | parser.add_argument('--split-ratio', type=float, default=default_split_ratio) 28 | args = parser.parse_args() 29 | 30 | abstracts = read_abstracts(args.input_file) 31 | 32 | train_file = open(args.out_train, "w") 33 | valid_file = open(args.out_valid, "w") 34 | 35 | for i, abstract in tqdm.tqdm(enumerate(abstracts)): 36 | if random.random() < args.split_ratio: 37 | train_file.write(abstract + "\n") 38 | else: 39 | valid_file.write(abstract + "\n") 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/stopwords.py: -------------------------------------------------------------------------------- 1 | english_stopwords = { 2 | "i", 3 | "me", 4 | "my", 5 | "myself", 6 | "we", 7 | "our", 8 | "ours", 9 | "ourselves", 10 | "you", 11 | "your", 12 | "yours", 13 | "yourself", 14 | "yourselves", 15 | "he", 16 | "him", 17 | "his", 18 | "himself", 19 | "she", 20 | "her", 21 | "hers", 22 | "herself", 23 | "it", 24 | "its", 25 | "itself", 26 | "they", 27 | "them", 28 | "their", 29 | "theirs", 30 | "themselves", 31 | "what", 32 | "which", 33 | "who", 34 | "whom", 35 | "this", 36 | "that", 37 | "these", 38 | "those", 39 | "am", 40 | "is", 41 | "are", 42 | "was", 43 | "were", 44 | "be", 45 | "been", 46 | "being", 47 | "have", 48 | "has", 49 | "had", 50 | "having", 51 | "do", 52 | "does", 53 | "did", 54 | "doing", 55 | "a", 56 | "an", 57 | "the", 58 | "and", 59 | "but", 60 | "if", 61 | "or", 62 | "because", 63 | "as", 64 | "until", 65 | "while", 66 | "of", 67 | "at", 68 | "by", 69 | "for", 70 | "with", 71 | "about", 72 | "against", 73 | "between", 74 | "into", 75 | "through", 76 | "during", 77 | "before", 78 | "after", 79 | "above", 80 | "below", 81 | "to", 82 | "from", 83 | "up", 84 | "down", 85 | "in", 86 | "out", 87 | "on", 88 | "off", 89 | "over", 90 | "under", 91 | "again", 92 | "further", 93 | "then", 94 | "once", 95 | "here", 96 | "there", 97 | "when", 98 | "where", 99 | "why", 100 | "how", 101 | "all", 102 | "any", 103 | "both", 104 | "each", 105 | "few", 106 | "more", 107 | "most", 108 | "other", 109 | "some", 110 | "such", 111 | "no", 112 | "nor", 113 | "not", 114 | "only", 115 | "own", 116 | "same", 117 | "so", 118 | "than", 119 | "too", 120 | "very", 121 | "s", 122 | "t", 123 | "can", 124 | "will", 125 | "just", 126 | "don", 127 | "should", 128 | "now", 129 | "d", 130 | "ll", 131 | "m", 132 | "o", 133 | "re", 134 | "ve", 135 | "y", 136 | "ain", 137 | "aren", 138 | "couldn", 139 | "didn", 140 | "doesn", 141 | "hadn", 142 | "hasn", 143 | "haven", 144 | "isn", 145 | "ma", 146 | "mightn", 147 | "mustn", 148 | "needn", 149 | "shan", 150 | "shouldn", 151 | "wasn", 152 | "weren", 153 | "won", 154 | "wouldn", 155 | } 156 | -------------------------------------------------------------------------------- 
/mini_coil/data_pipeline/upload_compressed_to_qdrant.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import os 4 | from typing import Iterable 5 | import json 6 | import itertools 7 | 8 | from qdrant_client import QdrantClient, models 9 | import numpy as np 10 | 11 | QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333") 12 | QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "") 13 | 14 | 15 | def load_sentences(path: str) -> Iterable[dict]: 16 | with open(path, "r") as f: 17 | for line in f: 18 | yield json.loads(line) 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--sentences-path", type=str, default=None) 24 | parser.add_argument("--compressed-path", type=str) 25 | parser.add_argument("--collection-name", type=str, default="coil-targets") 26 | parser.add_argument("--recreate-collection", action="store_true") 27 | parser.add_argument("--word", type=str) 28 | parser.add_argument("--limit", type=int, default=None) 29 | args = parser.parse_args() 30 | 31 | client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY) 32 | 33 | vectors = np.load(args.compressed_path) 34 | 35 | dim = vectors.shape[1] 36 | 37 | collection_name = args.collection_name 38 | 39 | collection_exists = client.collection_exists(collection_name) 40 | 41 | if collection_exists and args.recreate_collection: 42 | client.delete_collection(collection_name) 43 | collection_exists = False 44 | 45 | if not collection_exists: 46 | client.create_collection( 47 | collection_name=collection_name, 48 | vectors_config=models.VectorParams( 49 | size=dim, 50 | distance=models.Distance.COSINE, 51 | on_disk=True, 52 | ), 53 | hnsw_config=models.HnswConfigDiff( 54 | m=0, 55 | payload_m=16, 56 | ) 57 | ) 58 | 59 | client.create_payload_index( 60 | collection_name=collection_name, 61 | field_name="word", 62 | field_schema=models.KeywordIndexParams( 63 | type=models.KeywordIndexType.KEYWORD, 64 | is_tenant=True, 65 | on_disk=True 66 | ) 67 | ) 68 | 69 | payloads = load_sentences(args.sentences_path) 70 | 71 | if args.limit: 72 | vectors = vectors[:args.limit] 73 | payloads = itertools.islice(payloads, args.limit) 74 | 75 | client.upload_collection( 76 | collection_name=collection_name, 77 | ids=list(map(lambda x: hashlib.md5((args.word + str(x)).encode()).hexdigest(), range(len(vectors)))), 78 | vectors=vectors, 79 | payload=map(lambda x: {"word": args.word, **x}, payloads), 80 | ) 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/upload_to_qdrant.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import os 4 | from os import getenv 5 | from typing import Iterable 6 | 7 | import numpy as np 8 | import tqdm 9 | from qdrant_client import QdrantClient, models 10 | 11 | QDRANT_URL = os.environ.get("QDRANT_URL", getenv("QDRANT_URL", "http://localhost:80")) 12 | QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", getenv("QDRANT_API_KEY", "")) 13 | 14 | 15 | def read_texts(path: str) -> Iterable[str]: 16 | is_gz = path.endswith(".gz") 17 | 18 | if is_gz: 19 | import gzip 20 | with gzip.open(path, "rt") as f: 21 | for line in f: 22 | abs_hash, sentence = line.strip().split("\t") 23 | yield abs_hash, sentence 24 | else: 25 | with open(path, "r") as f: 26 | for line in f: 27 | abs_hash, sentence = line.strip().split("\t") 28 | 
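# Each input line is expected to be "<abstract_hash>\t<sentence>", i.e. the tab-separated
# format written by split_sentences.py.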
yield abs_hash, sentence 29 | 30 | 31 | def embed_texts(texts: Iterable[str]) -> np.ndarray: 32 | from sentence_transformers import SentenceTransformer 33 | model_repository = "mixedbread-ai/mxbai-embed-large-v1" 34 | model = SentenceTransformer(model_repository, trust_remote_code=True) 35 | 36 | texts = list(texts) 37 | 38 | embeddings = model.encode(texts) 39 | 40 | return embeddings 41 | 42 | 43 | def main(): 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--input-emb", type=str, default=None) 46 | parser.add_argument("--input-text", type=str) 47 | parser.add_argument("--collection-name", type=str, default="coil") 48 | parser.add_argument("--recreate-collection", action="store_true") 49 | parser.add_argument("--parallel", type=int, default=1) 50 | parser.add_argument("--skip-first", type=int, default=0) 51 | args = parser.parse_args() 52 | 53 | collection_name = args.collection_name 54 | 55 | if args.input_emb is None: 56 | embeddings = embed_texts(map(lambda x: x[0], read_texts(args.input_text))) 57 | else: 58 | embeddings = np.load(args.input_emb, mmap_mode='r') 59 | 60 | def data_iter(): 61 | skipped = 0 62 | texts_iter = read_texts(args.input_text) 63 | for (abs_hash, sentence), emb in zip(texts_iter, embeddings): 64 | if skipped < args.skip_first: 65 | skipped += 1 66 | continue 67 | # Compute hash from the text and convert it to UUID 68 | hash_uuid = hashlib.md5((sentence + abs_hash).encode()).hexdigest() 69 | 70 | yield models.PointStruct( 71 | id=hash_uuid, 72 | vector=emb.tolist(), 73 | payload={"sentence": sentence, "abs_hash": abs_hash} 74 | ) 75 | 76 | qdrant = QdrantClient( 77 | url=QDRANT_URL, 78 | api_key=QDRANT_API_KEY, 79 | prefer_grpc=True, 80 | ) 81 | 82 | collection_exists = qdrant.collection_exists(collection_name) 83 | 84 | if not collection_exists or args.recreate_collection: 85 | qdrant.delete_collection(collection_name) 86 | 87 | qdrant.create_collection( 88 | collection_name, 89 | vectors_config=models.VectorParams( 90 | size=len(embeddings[0]), 91 | distance=models.Distance.COSINE, 92 | datatype=models.Datatype.FLOAT16, 93 | on_disk=True, 94 | ), 95 | hnsw_config=models.HnswConfigDiff( 96 | m=0, 97 | max_indexing_threads=1, 98 | ), 99 | optimizers_config=models.OptimizersConfigDiff( 100 | max_segment_size=2_000_000, 101 | max_optimization_threads=1, # Run one optimization per shard 102 | ), 103 | shard_number=6, 104 | ) 105 | 106 | qdrant.create_payload_index( 107 | collection_name, 108 | "sentence", 109 | field_schema=models.TextIndexParams( 110 | type=models.TextIndexType.TEXT, 111 | tokenizer=models.TokenizerType.WORD, 112 | lowercase=True, 113 | on_disk=True, 114 | ) 115 | ) 116 | 117 | qdrant.upload_points( 118 | collection_name, 119 | points=tqdm.tqdm(data_iter()), 120 | parallel=args.parallel, 121 | ) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /mini_coil/data_pipeline/vocab_resolver.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Iterable, Tuple, List 3 | 4 | from py_rust_stemmers import SnowballStemmer 5 | 6 | import numpy as np 7 | from tokenizers import Tokenizer 8 | from transformers import AutoTokenizer 9 | 10 | from mini_coil.data_pipeline.stopwords import english_stopwords 11 | 12 | 13 | class VocabTokenizer: 14 | 15 | def tokenize(self, sentence: str) -> np.ndarray: 16 | raise NotImplementedError() 17 | 18 | def 
convert_ids_to_tokens(self, token_ids: np.ndarray) -> list: 19 | raise NotImplementedError() 20 | 21 | 22 | class VocabTokenizerAutoTokenizer(VocabTokenizer): 23 | def __init__(self, model_repository: str): 24 | self.auto_tokenizer = AutoTokenizer.from_pretrained(model_repository) 25 | 26 | def tokenize(self, sentence: str) -> np.ndarray: 27 | return np.array(self.auto_tokenizer(sentence).input_ids) 28 | 29 | def convert_ids_to_tokens(self, token_ids: np.ndarray) -> list: 30 | return self.auto_tokenizer.convert_ids_to_tokens(token_ids) 31 | 32 | 33 | class VocabTokenizerTokenizer(VocabTokenizer): 34 | def __init__(self, tokenizer: Tokenizer): 35 | self.tokenizer = tokenizer 36 | 37 | def tokenize(self, sentence: str) -> np.ndarray: 38 | return np.array(self.tokenizer.encode(sentence).ids) 39 | 40 | def convert_ids_to_tokens(self, token_ids: np.ndarray) -> list: 41 | return [self.tokenizer.id_to_token(token_id) for token_id in token_ids] 42 | 43 | 44 | class VocabResolver: 45 | def __init__(self, model_repository: str = None, tokenizer: VocabTokenizer = None): 46 | # Word to id mapping 47 | self.vocab = {} 48 | # Id to word mapping 49 | self.words = [] 50 | # Lemma to word mapping 51 | self.stem_mapping = {} 52 | self.tokenizer: VocabTokenizer = tokenizer 53 | self.stemmer = SnowballStemmer("english") 54 | if model_repository is not None and tokenizer is None: 55 | self.tokenizer = VocabTokenizerAutoTokenizer(model_repository) 56 | 57 | def tokenize(self, sentence: str) -> np.ndarray: 58 | return self.tokenizer.tokenize(sentence) 59 | 60 | def lookup_word(self, word_id: int) -> str: 61 | if word_id == 0: 62 | return "UNK" 63 | return self.words[word_id - 1] 64 | 65 | def convert_ids_to_tokens(self, token_ids: np.ndarray) -> list: 66 | return self.tokenizer.convert_ids_to_tokens(token_ids) 67 | 68 | def vocab_size(self): 69 | return len(self.vocab) + 1 70 | 71 | def save_vocab(self, path): 72 | with open(path, "w") as f: 73 | for word in self.words: 74 | f.write(word + "\n") 75 | 76 | def save_json_vocab(self, path): 77 | import json 78 | with open(path, "w") as f: 79 | json.dump({ 80 | "vocab": self.words, 81 | "stem_mapping": self.stem_mapping 82 | }, f, indent=2) 83 | 84 | def load_json_vocab(self, path): 85 | import json 86 | with open(path, "r") as f: 87 | data = json.load(f) 88 | self.words = data["vocab"] 89 | self.vocab = {word: idx + 1 for idx, word in enumerate(self.words)} 90 | self.stem_mapping = data["stem_mapping"] 91 | 92 | 93 | def add_word(self, word): 94 | if word not in self.vocab: 95 | self.vocab[word] = len(self.vocab) + 1 96 | self.words.append(word) 97 | stem = self.stemmer.stem_word(word) 98 | if stem not in self.stem_mapping: 99 | self.stem_mapping[stem] = word 100 | else: 101 | existing_word = self.stem_mapping[stem] 102 | if len(existing_word) > len(word): 103 | # Prefer shorter words for the same stem 104 | # Example: "swim" is preferred over "swimming" 105 | self.stem_mapping[stem] = word 106 | 107 | def load_vocab(self, path): 108 | with open(path, "r") as f: 109 | for line in f: 110 | self.add_word(line.strip()) 111 | 112 | @classmethod 113 | def _reconstruct_bpe( 114 | self, bpe_tokens: Iterable[Tuple[int, str]] 115 | ) -> List[Tuple[str, List[int]]]: 116 | result = [] 117 | acc = "" 118 | acc_idx = [] 119 | 120 | continuing_subword_prefix = "##" 121 | continuing_subword_prefix_len = len(continuing_subword_prefix) 122 | 123 | for idx, token in bpe_tokens: 124 | 125 | if token.startswith(continuing_subword_prefix): 126 | acc += 
token[continuing_subword_prefix_len:] 127 | acc_idx.append(idx) 128 | else: 129 | if acc: 130 | result.append((acc, acc_idx)) 131 | acc_idx = [] 132 | acc = token 133 | acc_idx.append(idx) 134 | 135 | if acc: 136 | result.append((acc, acc_idx)) 137 | 138 | return result 139 | 140 | def resolve_tokens(self, token_ids: np.ndarray) -> (np.ndarray, dict, dict, dict): 141 | """ 142 | Mark known tokens (including composed tokens) with vocab ids. 143 | 144 | Args: 145 | token_ids: (seq_len) - list of ids of tokens 146 | Example: 147 | [ 148 | 101, 3897, 19332, 12718, 23348, 149 | 1010, 1996, 7151, 2296, 4845, 150 | 2359, 2005, 4234, 1010, 4332, 151 | 2871, 3191, 2062, 102 152 | ] 153 | 154 | returns: 155 | - token_ids with vocab ids 156 | [ 157 | 0, 151, 151, 0, 0, 158 | 912, 0, 0, 0, 332, 159 | 332, 332, 0, 7121, 191, 160 | 0, 0, 332, 0 161 | ] 162 | - counts of each token 163 | { 164 | 151: 1, 165 | 332: 3, 166 | 7121: 1, 167 | 191: 1, 168 | 912: 1 169 | } 170 | - oov counts of each token 171 | { 172 | "the": 1, 173 | "a": 1, 174 | "[CLS]": 1, 175 | "[SEP]": 1, 176 | ... 177 | } 178 | - forms of each token 179 | { 180 | "hello": ["hello"], 181 | "world": ["worlds", "world", "worlding"], 182 | } 183 | 184 | """ 185 | 186 | # deep copy of token_ids 187 | token_ids = token_ids.copy() 188 | tokens = self.convert_ids_to_tokens(token_ids) 189 | tokens_mapping = self._reconstruct_bpe(enumerate(tokens)) 190 | 191 | counts = defaultdict(int) 192 | oov_count = defaultdict(int) 193 | 194 | forms = defaultdict(list) 195 | 196 | for token, mapped_token_ids in tokens_mapping: 197 | vocab_id = 0 198 | if token in english_stopwords: 199 | vocab_id = 0 200 | elif token in self.vocab: 201 | vocab_id = self.vocab[token] 202 | forms[token].append(token) 203 | elif token in self.stem_mapping: 204 | vocab_id = self.vocab[self.stem_mapping[token]] 205 | forms[self.stem_mapping[token]].append(token) 206 | else: 207 | stem = self.stemmer.stem_word(token) 208 | if stem in self.stem_mapping: 209 | vocab_id = self.vocab[self.stem_mapping[stem]] 210 | forms[self.stem_mapping[stem]].append(token) 211 | 212 | for token_id in mapped_token_ids: 213 | token_ids[token_id] = vocab_id 214 | 215 | if vocab_id == 0: 216 | oov_count[token] += 1 217 | else: 218 | counts[vocab_id] += 1 219 | 220 | return token_ids, counts, oov_count, forms 221 | 222 | def token_ids_to_vocab_batch(self, token_ids: np.ndarray) -> np.ndarray: 223 | """ 224 | Mark known tokens (including composed tokens) with vocab ids. 225 | 226 | Args: 227 | token_ids: (batch_size, seq_len) - list of ids of tokens 228 | Example: 229 | [ 230 | [101, 3897, 19332, 12718, 23348], 231 | [1010, 1996, 7151, 2296, 4845], 232 | [2359, 2005, 4234, 1010, 4332], 233 | [2871, 3191, 2062, 102, 0] 234 | ] 235 | 236 | """ 237 | 238 | for i in range(token_ids.shape[0]): 239 | self.resolve_tokens(token_ids[i]) 240 | 241 | return token_ids 242 | 243 | def filter( 244 | self, 245 | token_ids: np.ndarray, 246 | token_embeddings: np.ndarray, 247 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 248 | """ 249 | Filter out tokens that are not in the vocab. 
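Shape sketch (illustrative numbers): for a batch of token_ids with shape (2, 5) in which
only three positions resolve to known vocabulary words, this returns num_tokens = [2, 1],
flattened token_ids of shape (3,), and flattened token_embeddings of shape (3, embedding_size).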
250 | 251 | Args: 252 | token_ids: (batch_size, seq_len) - list of ids of tokens 253 | token_embeddings: (batch_size, seq_len, embedding_size) - embeddings of tokens 254 | 255 | Returns: 256 | - number of tokens in each sequence - (batch_size) 257 | - filtered and flattened token_ids - (total_tokens_size) 258 | - filtered and flattened token_embeddings - (total_tokens_size, embedding_size) 259 | """ 260 | 261 | # (batch_size, seq_len) 262 | filtered_token_ids = self.token_ids_to_vocab_batch(token_ids) 263 | 264 | # (batch_size, seq_len) 265 | mask = filtered_token_ids.__ne__(0) 266 | 267 | # (batch_size) 268 | num_tokens = mask.sum(axis=1) 269 | 270 | # (total_tokens_size) 271 | filtered_token_ids = filtered_token_ids[mask] 272 | 273 | # (total_tokens_size, embedding_size) 274 | filtered_token_embeddings = token_embeddings[mask] 275 | 276 | return num_tokens, filtered_token_ids, filtered_token_embeddings 277 | 278 | 279 | def test_basic_resolver(): 280 | resolver = VocabResolver() 281 | 282 | resolver.add_word("bat") 283 | resolver.add_word("nicolls") 284 | 285 | token_ids = np.array([ 286 | 101, 3897, 19332, 12718, 23348, 287 | 1010, 1996, 7151, 2296, 4845, 288 | 2359, 2005, 4234, 1010, 4332, 289 | 2871, 3191, 2062, 102 290 | ]) 291 | 292 | token_ids, counts, oov, _forms = resolver.resolve_tokens(token_ids) 293 | 294 | expected = np.array([0, 0, 2, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) 295 | 296 | assert np.all(np.equal(token_ids, expected)) 297 | 298 | batch = np.array([ 299 | [101, 3897, 19332, 12718, 23348], 300 | [1010, 1996, 7151, 2296, 4845], 301 | [2359, 2005, 4234, 1010, 4332], 302 | [2871, 3191, 2062, 102, 0] 303 | ]) 304 | 305 | batch = resolver.token_ids_to_vocab_batch(batch) 306 | 307 | expected = np.array([ 308 | [0, 0, 2, 2, 0], 309 | [0, 0, 1, 0, 0], 310 | [0, 0, 0, 0, 0], 311 | [0, 0, 0, 0, 0] 312 | ]) 313 | 314 | assert np.all(np.equal(batch, expected)) 315 | 316 | 317 | def main(): 318 | import os 319 | from mini_coil.settings import DATA_DIR 320 | 321 | resolver = VocabResolver(model_repository="jinaai/jina-embeddings-v2-small-en") 322 | 323 | resolver.load_json_vocab(os.path.join(DATA_DIR, "minicoil.ptch.vocab")) 324 | 325 | sentence = "I like to swim close to the bank of the river, cause I am not a very good swimmer. He swims slow." 
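# With the stem mapping, inflected forms such as "swims" are expected to resolve to the same
# vocab entry as "swim" (assuming "swim" is present in the loaded vocab), and the matched
# surface forms are collected in `forms`. Stopwords such as "i", "to" and "the" are assigned
# vocab id 0 and show up in the `oov` counts instead.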
326 | 327 | token_ids = np.array(resolver.tokenizer.tokenize(sentence)) 328 | 329 | word_ids, counts, oov, forms = resolver.resolve_tokens(token_ids) 330 | 331 | print("word_ids", word_ids) 332 | 333 | print("counts", counts) 334 | 335 | print("oov", oov) 336 | 337 | print("forms", forms) 338 | 339 | 340 | if __name__ == "__main__": 341 | main() 342 | -------------------------------------------------------------------------------- /mini_coil/filtering.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import os 3 | import string 4 | from collections import defaultdict 5 | from typing import Set, List, Dict 6 | 7 | from mini_coil.convert_idf import IDFVocab 8 | from mini_coil.read_data import read_data 9 | from mini_coil.settings import DATA_DIR 10 | from mini_coil.tokenizer import WordTokenizer 11 | from snowballstemmer import stemmer as get_stemmer 12 | 13 | 14 | @dataclasses.dataclass 15 | class TripletStats: 16 | pos_score: float 17 | neg_score: float 18 | overlapped_pos: int 19 | overlapped_neg: int 20 | 21 | pos: List[str] 22 | neg: List[str] 23 | 24 | 25 | class TripletFilter: 26 | @classmethod 27 | def load_stopwords(cls) -> Set[str]: 28 | path = os.path.join(DATA_DIR, 'stopwords.txt') 29 | with open(path) as f: 30 | return set(f.read().splitlines()) 31 | 32 | def __init__( 33 | self, 34 | k: float = 1.2, 35 | b: float = 0.75, 36 | avg_len: float = 64.0, 37 | ): 38 | self.k = k 39 | self.b = b 40 | 41 | self.avg_len = avg_len 42 | 43 | vocab_path = os.path.join(DATA_DIR, "idf_vocab.pkl") 44 | 45 | self.stopwords = self.load_stopwords() 46 | self.tokenizer = WordTokenizer 47 | self.stemmer = get_stemmer("english") 48 | self.punctuation = set(string.punctuation) 49 | self.idf_vocab: IDFVocab = IDFVocab.load_vocab_pkl(vocab_path) 50 | 51 | def _stem(self, tokens: List[str]) -> List[str]: 52 | stemmed_tokens = [] 53 | for token in tokens: 54 | if token in self.punctuation: 55 | continue 56 | 57 | if token in self.stopwords: 58 | continue 59 | 60 | stemmed_token = self.stemmer.stemWord(token) 61 | 62 | if stemmed_token: 63 | stemmed_tokens.append(stemmed_token) 64 | return stemmed_tokens 65 | 66 | def _term_frequency(self, tokens: List[str]) -> Dict[str, float]: 67 | """Calculate the term frequency part of the BM25 formula. 68 | 69 | ( 70 | f(q_i, d) * (k + 1) 71 | ) / ( 72 | f(q_i, d) + k * (1 - b + b * (|d| / avg_len)) 73 | ) 74 | 75 | Args: 76 | tokens (List[str]): The list of tokens in the document. 77 | 78 | Returns: 79 | Dict[int, float]: The token_id to term frequency mapping. 
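Example (a sketch using the default parameters k=1.2, b=0.75, avg_len=64.0): a stemmed
token occurring twice in a 10-token document gets
tf = 2 * (1.2 + 1) / (2 + 1.2 * (1 - 0.75 + 0.75 * 10 / 64)) ≈ 1.80.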
80 | """ 81 | tf_map = {} 82 | counter = defaultdict(int) 83 | for stemmed_token in tokens: 84 | counter[stemmed_token] += 1 85 | 86 | doc_len = len(tokens) 87 | for stemmed_token in counter: 88 | num_occurrences = counter[stemmed_token] 89 | tf_map[stemmed_token] = num_occurrences * (self.k + 1) / ( 90 | num_occurrences + self.k * (1 - self.b + self.b * doc_len / self.avg_len) 91 | ) 92 | return tf_map 93 | 94 | def tokenize_and_stem(self, text: str) -> List[str]: 95 | tokens = self.tokenizer.tokenize(text.lower()) 96 | stemmed_tokens = self._stem(tokens) 97 | return stemmed_tokens 98 | 99 | def get_bm25_score(self, tokens: List[str]) -> float: 100 | tf_map = self._term_frequency(tokens) 101 | bm25_score = 0.0 102 | for token in tf_map: 103 | idf = self.idf_vocab.get_idf(token) 104 | bm25_score += idf * tf_map[token] 105 | return bm25_score 106 | 107 | def check_triplet(self, query: str, pos: str, neg: str) -> TripletStats: 108 | """ 109 | Extract tokens which present in query from pos and neg 110 | 111 | Compare the BM25 score of pos and neg overlap with query 112 | 113 | If the BM25 score of pos is higher than neg, return True, else False 114 | """ 115 | 116 | query_tokens = set(self.tokenize_and_stem(query)) 117 | pos_tokens = self.tokenize_and_stem(pos) 118 | neg_tokens = self.tokenize_and_stem(neg) 119 | 120 | # Keep token count in documents 121 | pos_overlap = [token for token in pos_tokens if token in query_tokens] 122 | neg_overlap = [token for token in neg_tokens if token in query_tokens] 123 | 124 | pos_bm25 = self.get_bm25_score(pos_overlap) 125 | neg_bm25 = self.get_bm25_score(neg_overlap) 126 | 127 | return TripletStats( 128 | pos_score=pos_bm25, 129 | neg_score=neg_bm25, 130 | overlapped_pos=len(pos_overlap), 131 | overlapped_neg=len(neg_overlap), 132 | pos=pos_overlap, 133 | neg=neg_overlap 134 | ) 135 | 136 | 137 | if __name__ == "__main__": 138 | triplet_filter = TripletFilter() 139 | 140 | n = 0 141 | interesting = 0 142 | trivial = 0 143 | tie = 0 144 | 145 | for (query, pos, neg) in read_data(): 146 | if n > 1000: 147 | break 148 | n += 1 149 | 150 | stats = triplet_filter.check_triplet(query, pos, neg) 151 | 152 | if stats.overlapped_pos == 0 or stats.overlapped_neg == 0: 153 | continue 154 | 155 | if stats.pos_score < stats.neg_score: 156 | interesting += 1 157 | print(f"Query: {query}") 158 | print(f"Pos: {pos}") 159 | print(f"Neg: {neg}") 160 | print(f"Pos score: {stats.pos_score}") 161 | print(f"Neg score: {stats.neg_score}") 162 | print(f"Pos: {stats.pos}") 163 | print(f"Neg: {stats.neg}") 164 | print("-------------------") 165 | 166 | elif stats.pos_score > stats.neg_score: 167 | trivial += 1 168 | else: 169 | tie += 1 170 | 171 | print(f"Interesting: {interesting}") 172 | print(f"Trivial: {trivial}") 173 | print(f"Tie: {tie}") -------------------------------------------------------------------------------- /mini_coil/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/miniCOIL/2d939737e89bacd797cb6c3e312135070a93a56d/mini_coil/model/__init__.py -------------------------------------------------------------------------------- /mini_coil/model/cosine_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from torch import nn 6 | 7 | 8 | class CosineLoss(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | self.loss = nn.MSELoss() 13 | 14 | @classmethod 15 | def 
cosine_distance( 16 | cls, 17 | mapping: Optional[torch.LongTensor], 18 | prediction: torch.Tensor, 19 | target: torch.Tensor 20 | ) -> torch.Tensor: 21 | # (flatten_batch, output_dim) 22 | if mapping is None: 23 | mapped_target = target 24 | else: 25 | mapped_target = target[mapping] 26 | 27 | # Cosine similarity 28 | # (flatten_batch) 29 | prediction_norm = torch.norm(prediction, dim=1) 30 | 31 | # (flatten_batch) 32 | target_norm = torch.norm(mapped_target, dim=1) 33 | 34 | # If prediction_norm or target_norm is zero, exclude them from the calculation 35 | mask1 = prediction_norm < 1e-6 36 | mask2 = target_norm < 1e-6 37 | mask = mask1 + mask2 38 | 39 | prediction = prediction[~mask] 40 | mapped_target = mapped_target[~mask] 41 | prediction_norm = prediction_norm[~mask] 42 | target_norm = target_norm[~mask] 43 | 44 | # Pairwise cosine similarity 45 | # (flatten_batch) 46 | cosine_similarity = torch.einsum('bi,bi->b', prediction, mapped_target) / (prediction_norm * target_norm) 47 | 48 | # Cosine distance 49 | # (flatten_batch) 50 | cosine_distance = 1 - cosine_similarity 51 | 52 | return cosine_distance 53 | 54 | def forward( 55 | self, 56 | mapping: Optional[torch.LongTensor], 57 | prediction: torch.Tensor, 58 | target: torch.Tensor 59 | ) -> torch.Tensor: 60 | """ 61 | Calculate the mean cosine distance between the prediction and the target. 62 | 63 | Args: 64 | mapping: (flatten_batch) - association between the prediction and the target. 65 | prediction: (flatten_batch, output_dim) - prediction of the context 66 | target: (num_abstracts, output_dim) - target context 67 | 68 | Returns: 69 | loss: () - mean squared error 70 | """ 71 | 72 | cosine_distance = self.cosine_distance(mapping, prediction, target) 73 | 74 | loss = cosine_distance.mean() 75 | 76 | return loss 77 | 78 | 79 | def test_cosine_loss(): 80 | loss = CosineLoss() 81 | prediction = torch.tensor([ 82 | [0.1, 0.2, 0.3], 83 | [0.4, 0.5, 0.6], 84 | [1.7, 1.9, 2.1], 85 | [1.3, 1.4, 1.5], 86 | ]) 87 | target = torch.tensor([ 88 | [1.6, 1.7, 1.8], 89 | [1.9, 2.0, 2.1], 90 | [1.7, 1.9, 2.1], 91 | [0.0, 1.0, 0.0], 92 | ]) 93 | mapping = torch.tensor([0, 1, 2, 3]) 94 | 95 | cosine_distance = loss.cosine_distance(mapping, prediction, target) 96 | 97 | assert cosine_distance.shape == (4,) 98 | assert cosine_distance[2] == 0.0 99 | 100 | 101 | if __name__ == "__main__": 102 | test_cosine_loss() 103 | print("CosineLoss test passed") 104 | -------------------------------------------------------------------------------- /mini_coil/model/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from torch import nn 6 | from torch.nn import init 7 | 8 | 9 | class Decoder(nn.Module): 10 | """ 11 | Decoder reverses the process of the Encoder. 12 | It takes compressed per-word representation converts it into the context vector. 13 | This is similar to Autoencoder, but the intermediate layer have 2 components: 14 | - Compressed contextualized word vector 15 | - ID of the vocabulary word 16 | 17 | Decoder is only used during training. 
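Shape sketch (sizes are illustrative): with input_dim=4, output_dim=768 and vocab_size=10000,
forward() maps vocab_ids of shape (flatten_batch,) plus compressed vectors of shape
(flatten_batch, 4) to context predictions of shape (flatten_batch, 768).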
18 | """ 19 | 20 | def __init__( 21 | self, 22 | input_dim: int, # Dimension of the internal representation (4) 23 | output_dim: int, # Size of the context vector (768) 24 | vocab_size: int, # Size of the vocabulary (10000) 25 | device=None, 26 | dtype=None 27 | ): 28 | factory_kwargs = {'device': device, 'dtype': dtype} 29 | super().__init__() 30 | self.input_dim = input_dim 31 | self.output_dim = output_dim 32 | self.vocab_size = vocab_size 33 | 34 | # First step of the decoder 35 | # Per-word matrix to convert very small per-word representation to universal intermediate representation 36 | self.decoder_weights = nn.Parameter(torch.zeros((vocab_size, input_dim, output_dim), **factory_kwargs)) 37 | 38 | self.output_activation = nn.Tanh() 39 | 40 | self.reset_parameters() 41 | 42 | def reset_parameters(self) -> None: 43 | # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with 44 | # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see 45 | # https://github.com/pytorch/pytorch/issues/57109 46 | init.kaiming_uniform_(self.decoder_weights, a=math.sqrt(5)) 47 | 48 | def forward(self, 49 | vocab_ids: torch.LongTensor, 50 | compressed: torch.Tensor) -> torch.Tensor: 51 | """ 52 | Convert compressed representation into prediction of the context. 53 | 54 | Args: 55 | vocab_ids: (flatten_batch) - list pairs of vocab_ids 56 | compressed: (flatten_batch, input_dim) - compressed representation of the words 57 | """ 58 | 59 | # Convert compressed representation into intermediate representation 60 | 61 | # # Select decoder weights according to vocab_ids 62 | 63 | # (flatten_batch, input_dim, intermediate_dim) 64 | decoder_weights = self.decoder_weights[vocab_ids] 65 | 66 | # Apply decoder weights to compressed representation 67 | # (flatten_batch, output_dim) 68 | output_raw = torch.einsum('bdi,bd->bi', decoder_weights, compressed) 69 | 70 | # Convert intermediate representation into prediction of the context 71 | # (flatten_batch, output_dim) 72 | prediction = self.output_activation(output_raw) 73 | 74 | return prediction 75 | 76 | 77 | def test_decoder(): 78 | decoder = Decoder(input_dim=4, output_dim=386, vocab_size=10000) 79 | vocab_ids = torch.randint(0, 10000, (10,)) 80 | compressed = torch.randn(10, 4) 81 | prediction = decoder(vocab_ids, compressed) 82 | assert prediction.shape == (10, 386) 83 | 84 | 85 | if __name__ == '__main__': 86 | test_decoder() 87 | -------------------------------------------------------------------------------- /mini_coil/model/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This model converts token embeddings into a compressed representation. 
3 | Each token have its own linear encoder 4 | 5 | """ 6 | import math 7 | from typing import Tuple 8 | 9 | import torch 10 | 11 | from torch import nn 12 | from torch.nn import init 13 | 14 | 15 | class Encoder(nn.Module): 16 | """ 17 | Encoder(768, 128, 4, 10000) 18 | 19 | Will look like this: 20 | 21 | 22 | Per-word 23 | Encoder Matrix 24 | ┌─────────────────────┐ 25 | │ Token Embedding(768)├──────┐ (10k, 128, 4) 26 | └─────────────────────┘ │ ┌─────────┐ 27 | │ │ │ 28 | ┌─────────────────────┐ │ ┌─┴───────┐ │ 29 | │ │ │ │ │ │ 30 | └─────────────────────┘ │ ┌─┴───────┐ │ │ ┌─────────┐ 31 | └────►│ │ │ ├─────►│Tanh │ 32 | ┌─────────────────────┐ │ │ │ │ └─────────┘ 33 | │ │ │ │ ├─┘ 34 | └─────────────────────┘ │ ├─┘ 35 | │ │ 36 | ┌─────────────────────┐ └─────────┘ 37 | │ │ 38 | └─────────────────────┘ 39 | 40 | Final liner transformation is accompanied by a non-linear activation function: Tanh. 41 | 42 | Tanh is used to ensure that the output is in the range [-1, 1]. 43 | It would be easier to visually interpret the output of the model, assuming that each dimension 44 | would need to encode a type of semantic cluster. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | input_dim: int, 50 | # intermediate_dim: int, 51 | output_dim: int, 52 | vocab_size: int, 53 | device=None, 54 | dtype=None 55 | ): 56 | factory_kwargs = {'device': device, 'dtype': dtype} 57 | super().__init__() 58 | self.input_dim = input_dim 59 | # self.intermediate_dim = intermediate_dim 60 | self.output_dim = output_dim 61 | self.vocab_size = vocab_size 62 | 63 | # DISABLED 64 | # Before training embeddings for individual words, we lower dimensionality of the original embeddings 65 | # using universal linear layer, shared across all words 66 | 67 | # self.intermediate_layer = nn.Sequential( 68 | # nn.Linear(input_dim, intermediate_dim, **factory_kwargs), 69 | # nn.Tanh(), 70 | # ) 71 | 72 | # For each word in the vocabulary, we have a linear encoder 73 | self.encoder_weights = nn.Parameter(torch.zeros((vocab_size, input_dim, output_dim), **factory_kwargs)) 74 | 75 | self.activation = nn.Tanh() 76 | 77 | self.reset_parameters() 78 | 79 | def reset_parameters(self) -> None: 80 | # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with 81 | # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see 82 | # https://github.com/pytorch/pytorch/issues/57109 83 | init.kaiming_uniform_(self.encoder_weights, a=math.sqrt(5)) 84 | 85 | @classmethod 86 | def convert_vocab_ids( 87 | cls, 88 | vocab_ids: torch.LongTensor, 89 | ): 90 | """ 91 | Convert vocab ids into unique-able format. 92 | 93 | Args: 94 | vocab_ids: (batch_size, seq_len) - list of word ids for each embedding. 
95 | 96 | Convert each number into a pair of (value, batch_id) 97 | 98 | vocab_ids = [ 99 | [7, 3, 6, 6, 2], 100 | [1, 2, 4, 0, 0] 101 | ] 102 | 103 | output: 104 | 105 | vocab_ids = [ 106 | [ 107 | [7, 0], 108 | [3, 0], 109 | [6, 0], 110 | [6, 0], 111 | [2, 0], 112 | ], 113 | [ 114 | [1, 1], 115 | [2, 1], 116 | [4, 1], 117 | [0, 1], 118 | [0, 1], 119 | ] 120 | ] 121 | """ 122 | batch_size, seq_len = vocab_ids.size() 123 | batch_ids = torch.arange(batch_size, device=vocab_ids.device).unsqueeze(1).expand(batch_size, seq_len) 124 | return torch.stack((vocab_ids, batch_ids), dim=2) 125 | 126 | @classmethod 127 | def avg_by_vocab_ids( 128 | cls, 129 | vocab_ids: torch.LongTensor, 130 | embeddings: torch.Tensor, 131 | ) -> Tuple[torch.LongTensor, torch.Tensor]: 132 | """ 133 | Example: 134 | vocab_ids = [ 135 | [7, 3, 6, 6, 2], 136 | [1, 2, 4, 0, 0] 137 | ] 138 | 139 | embeddings = [ 140 | [ 141 | [0.1, 0.2, 0.3], 142 | [0.4, 0.5, 0.6], 143 | [0.7, 0.8, 0.9], 144 | [1.0, 1.1, 1.2], 145 | [1.3, 1.4, 1.5], 146 | ], 147 | [ 148 | [1.6, 1.7, 1.8], 149 | [1.9, 2.0, 2.1], 150 | [2.2, 2.3, 2.4], 151 | [0.0, 0.0, 0.0], 152 | [0.0, 0.0, 0.0], 153 | ] 154 | ] 155 | 156 | output: 157 | vocab_ids = [ 158 | [7, 3, 6, 2], 159 | [1, 2, 4, 0] 160 | ] 161 | 162 | embeddings = [ 163 | [ 164 | [0.1, 0.2, 0.3], 165 | [0.4, 0.5, 0.6], 166 | [0.85,0.95,1.05], 167 | [1.3, 1.4, 1.5], 168 | ], 169 | [ 170 | [1.6, 1.7, 1.8], 171 | [1.9, 2.0, 2.1], 172 | [2.2, 2.3, 2.4], 173 | [0.0, 0.0, 0.0], 174 | ] 175 | ] 176 | 177 | returns: 178 | (total_unique_words_per_batch, 2), (total_unique_words_per_batch, input_dim) 179 | 180 | Returns unique (vocab_id, batch_id) pairs and their corresponding avg of embeddings. 181 | """ 182 | 183 | # (batch_size * seq_len, 2) - token id -> batch id 184 | # tensor([ 185 | # [7, 0], 186 | # [3, 0], 187 | # [6, 0], 188 | # [6, 0], 189 | # [2, 0], 190 | # [1, 1], 191 | # [2, 1], 192 | # [4, 1], 193 | # [0, 1], 194 | # [0, 1] 195 | # ]) 196 | flattened_vocab_ids = cls.convert_vocab_ids(vocab_ids).flatten(start_dim=0, end_dim=1) 197 | 198 | # (batch_size * seq_len, input_dim) 199 | flattened_embeddings = embeddings.flatten(start_dim=0, end_dim=1) 200 | 201 | # Unique vocab ids per batch element 202 | # unique_flattened_vocab_ids - (total_unique_vocab_ids, 2) 203 | # inverse_indices - (batch_size * seq_len) 204 | unique_flattened_vocab_ids, inverse_indices = flattened_vocab_ids.unique(dim=0, return_inverse=True) 205 | 206 | # Sum up embeddings for each unique vocab id 207 | # (total_unique_vocab_ids, input_dim) 208 | unique_flattened_embeddings = torch.zeros( 209 | (unique_flattened_vocab_ids.size(0), embeddings.size(2)), device=embeddings.device) 210 | 211 | # Count of unique vocab ids 212 | # (total_unique_vocab_ids) 213 | unique_flattened_vocab_ids_count = torch.zeros(unique_flattened_vocab_ids.size(0), device=embeddings.device).long() 214 | 215 | # Count unique vocab ids (a bit hacky way to do it) 216 | # (total_unique_vocab_ids) 217 | source_of_ones = torch.ones_like(inverse_indices) 218 | unique_flattened_vocab_ids_count.index_add_(0, inverse_indices, source_of_ones) 219 | 220 | # Sum up embeddings for each unique vocab id 221 | # (total_unique_vocab_ids, input_dim) 222 | unique_flattened_embeddings.index_add_(0, inverse_indices, flattened_embeddings) 223 | 224 | # Average embeddings 225 | # (total_unique_vocab_ids, input_dim) 226 | unique_flattened_embeddings /= unique_flattened_vocab_ids_count.unsqueeze(1) 227 | 228 | return unique_flattened_vocab_ids, unique_flattened_embeddings 229 | 230 
| def forward( 231 | self, 232 | vocab_ids: torch.LongTensor, 233 | embeddings: torch.Tensor, 234 | ) -> Tuple[torch.LongTensor, torch.Tensor]: 235 | """ 236 | Args: 237 | vocab_ids: (batch_size, seq_len) - list of word ids for each embedding. 238 | embeddings: (batch_size, seq_len, input_dim) - list of token embeddings, obtained from the transformer. 239 | 240 | Vocab ids may have duplicates. In this case embeddings should be summed up. 241 | 242 | Returns: 243 | (total_unique_words_per_batch, 2), (total_unique_words_per_batch, output_dim) 244 | 245 | """ 246 | # (total_unique_vocab_ids, 2), (total_unique_vocab_ids, input_dim) 247 | unique_flattened_vocab_ids_and_batch_ids, unique_flattened_embeddings = \ 248 | self.avg_by_vocab_ids(vocab_ids, embeddings) 249 | 250 | # Generate intermediate embeddings 251 | 252 | # DISABLED 253 | # (total_unique_vocab_ids, intermediate_dim) 254 | # unique_flattened_embeddings = self.intermediate_layer(unique_flattened_embeddings) 255 | 256 | # Select which linear encoders to use for each embedding 257 | 258 | # Select linear encoders ids 259 | # (total_unique_vocab_ids) 260 | unique_flattened_vocab_ids = unique_flattened_vocab_ids_and_batch_ids[:, 0] 261 | 262 | # Select linear encoders 263 | # (total_unique_vocab_ids, input_dim, output_dim) 264 | unique_encoder_weights = self.encoder_weights[unique_flattened_vocab_ids] 265 | 266 | # (total_unique_vocab_ids, output_dim) 267 | unique_flattened_encoded = torch.einsum('bi,bio->bo', unique_flattened_embeddings, unique_encoder_weights) 268 | 269 | # Apply activation function 270 | unique_flattened_encoded = self.activation(unique_flattened_encoded) 271 | 272 | return unique_flattened_vocab_ids_and_batch_ids, unique_flattened_encoded 273 | -------------------------------------------------------------------------------- /mini_coil/model/encoder_numpy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encoder model for a single word. 3 | 4 | Same as `Encoder`, but instead of being a PyTorch model, it is a pure Numpy. 5 | By doing this, we avoid dependency on PyTorch, and we can use this model in a pure Python environment. 6 | 7 | This model is not trainable, and should only be used for inference. 8 | """ 9 | 10 | import numpy as np 11 | 12 | 13 | class EncoderNumpy: 14 | """ 15 | Encoder(768, 128, 4, 10000) 16 | 17 | Will look like this: 18 | 19 | 20 | Per-word 21 | Encoder Matrix 22 | ┌─────────────────────┐ 23 | │ Token Embedding(768)├──────┐ (10k, 128, 4) 24 | └─────────────────────┘ │ ┌─────────┐ 25 | │ │ │ 26 | ┌─────────────────────┐ │ ┌─┴───────┐ │ 27 | │ │ │ │ │ │ 28 | └─────────────────────┘ │ ┌─┴───────┐ │ │ ┌─────────┐ 29 | └────►│ │ │ ├─────►│Tanh │ 30 | ┌─────────────────────┐ │ │ │ │ └─────────┘ 31 | │ │ │ │ ├─┘ 32 | └─────────────────────┘ │ ├─┘ 33 | │ │ 34 | ┌─────────────────────┐ └─────────┘ 35 | │ │ 36 | └─────────────────────┘ 37 | 38 | Final liner transformation is accompanied by a non-linear activation function: Tanh. 39 | 40 | Tanh is used to ensure that the output is in the range [-1, 1]. 41 | It would be easier to visually interpret the output of the model, assuming that each dimension 42 | would need to encode a type of semantic cluster. 
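Usage sketch (the weights file name is an assumption; in practice the array is exported
from the trained torch Encoder and has shape (vocab_size, input_dim, output_dim)):

    weights = np.load("encoder_weights.npy")
    encoder = EncoderNumpy(weights)
    ids_and_batch, encoded = encoder.forward(vocab_ids, embeddings)

where vocab_ids is a (batch_size, seq_len) int array and embeddings is a
(batch_size, seq_len, input_dim) float array.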
43 | """ 44 | 45 | def __init__( 46 | self, 47 | weights: np.ndarray, 48 | ): 49 | self.weights = weights 50 | self.vocab_size, self.input_dim, self.output_dim = weights.shape 51 | 52 | self.encoder_weights = weights 53 | 54 | # Activation function 55 | self.activation = np.tanh 56 | 57 | @staticmethod 58 | def convert_vocab_ids(vocab_ids: np.ndarray) -> np.ndarray: 59 | """ 60 | Convert vocab_ids of shape (batch_size, seq_len) into (batch_size, seq_len, 2) 61 | by appending batch_id alongside each vocab_id. 62 | """ 63 | batch_size, seq_len = vocab_ids.shape 64 | batch_ids = np.arange(batch_size, dtype=vocab_ids.dtype).reshape(batch_size, 1) 65 | batch_ids = np.repeat(batch_ids, seq_len, axis=1) 66 | # Stack vocab_ids and batch_ids along the last dimension 67 | combined = np.stack((vocab_ids, batch_ids), axis=2) 68 | return combined 69 | 70 | @classmethod 71 | def avg_by_vocab_ids(cls, vocab_ids: np.ndarray, embeddings: np.ndarray): 72 | """ 73 | Takes: 74 | vocab_ids: (batch_size, seq_len) int array 75 | embeddings: (batch_size, seq_len, input_dim) float array 76 | 77 | Returns: 78 | unique_flattened_vocab_ids: (total_unique, 2) array of [vocab_id, batch_id] 79 | unique_flattened_embeddings: (total_unique, input_dim) averaged embeddings 80 | """ 81 | batch_size, seq_len = vocab_ids.shape 82 | input_dim = embeddings.shape[2] 83 | 84 | # Flatten vocab_ids and embeddings 85 | # flattened_vocab_ids: (batch_size*seq_len, 2) 86 | flattened_vocab_ids = cls.convert_vocab_ids(vocab_ids).reshape(-1, 2) 87 | 88 | # flattened_embeddings: (batch_size*seq_len, input_dim) 89 | flattened_embeddings = embeddings.reshape(-1, input_dim) 90 | 91 | # Find unique (vocab_id, batch_id) pairs 92 | unique_flattened_vocab_ids, inverse_indices = np.unique(flattened_vocab_ids, axis=0, return_inverse=True) 93 | 94 | # Prepare arrays to accumulate sums 95 | unique_count = unique_flattened_vocab_ids.shape[0] 96 | unique_flattened_embeddings = np.zeros((unique_count, input_dim), dtype=embeddings.dtype) 97 | unique_flattened_count = np.zeros(unique_count, dtype=np.int32) 98 | 99 | # Use np.add.at to accumulate sums based on inverse indices 100 | np.add.at(unique_flattened_embeddings, inverse_indices, flattened_embeddings) 101 | np.add.at(unique_flattened_count, inverse_indices, 1) 102 | 103 | # Compute averages 104 | unique_flattened_embeddings /= unique_flattened_count[:, None] 105 | 106 | return unique_flattened_vocab_ids, unique_flattened_embeddings 107 | 108 | def forward(self, vocab_ids: np.ndarray, embeddings: np.ndarray): 109 | """ 110 | Args: 111 | vocab_ids: (batch_size, seq_len) int array 112 | embeddings: (batch_size, seq_len, input_dim) float array 113 | 114 | Returns: 115 | unique_flattened_vocab_ids_and_batch_ids: (total_unique, 2) 116 | unique_flattened_encoded: (total_unique, output_dim) 117 | """ 118 | # Average embeddings for duplicate vocab_ids 119 | unique_flattened_vocab_ids_and_batch_ids, unique_flattened_embeddings = self.avg_by_vocab_ids(vocab_ids, 120 | embeddings) 121 | 122 | # Select the encoder weights for each unique vocab_id 123 | unique_flattened_vocab_ids = unique_flattened_vocab_ids_and_batch_ids[:, 0] 124 | 125 | # unique_encoder_weights: (total_unique, input_dim, output_dim) 126 | unique_encoder_weights = self.encoder_weights[unique_flattened_vocab_ids] 127 | 128 | # Compute linear transform: (total_unique, output_dim) 129 | # Using Einstein summation for matrix multiplication: 130 | # 'bi,bio->bo' means: for each "b" (batch element), multiply embeddings (b,i) by weights (b,i,o) -> 
(b,o) 131 | unique_flattened_encoded = np.einsum('bi,bio->bo', unique_flattened_embeddings, unique_encoder_weights) 132 | 133 | # Apply Tanh activation 134 | unique_flattened_encoded = self.activation(unique_flattened_encoded) 135 | 136 | return unique_flattened_vocab_ids_and_batch_ids, unique_flattened_encoded 137 | 138 | 139 | -------------------------------------------------------------------------------- /mini_coil/model/mini_coil_inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | End-to-end inference of the miniCOIL model. 3 | This includes sentence transformer, vocabulary resolver, and the coil post-encoder. 4 | """ 5 | import itertools 6 | import json 7 | from typing import List, Iterable 8 | 9 | import ipdb 10 | import numpy as np 11 | import torch 12 | from fastembed.late_interaction.token_embeddings import TokenEmbeddingsModel 13 | 14 | from mini_coil.data_pipeline.vocab_resolver import VocabResolver, VocabTokenizerTokenizer 15 | from mini_coil.model.encoder_numpy import EncoderNumpy 16 | 17 | 18 | class MiniCOIL: 19 | 20 | def __init__( 21 | self, 22 | vocab_path: str, 23 | word_encoder_path: str, 24 | input_dim: int = 512, 25 | sentence_encoder_model: str = "jinaai/jina-embeddings-v2-small-en-tokens" 26 | ): 27 | self.sentence_encoder_model = sentence_encoder_model 28 | self.sentence_encoder = TokenEmbeddingsModel(model_name=sentence_encoder_model, threads=1) 29 | 30 | self.vocab_path = vocab_path 31 | self.vocab_resolver = VocabResolver(tokenizer=VocabTokenizerTokenizer(self.sentence_encoder.tokenizer)) 32 | self.vocab_resolver.load_json_vocab(vocab_path) 33 | 34 | self.input_dim = input_dim 35 | self.output_dim = None 36 | 37 | self.word_encoder_path = word_encoder_path 38 | 39 | self.word_encoder = None 40 | 41 | self.load_encoder_numpy() 42 | 43 | def load_encoder_numpy(self): 44 | weights = np.load(self.word_encoder_path) 45 | self.word_encoder = EncoderNumpy(weights) 46 | assert self.word_encoder.input_dim == self.input_dim 47 | self.output_dim = self.word_encoder.output_dim 48 | 49 | def encode(self, sentences: Iterable[str], parallel=None) -> List[dict]: 50 | """ 51 | Encode the given word in the context of the sentences. 
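Returns one dict per input sentence, mapping each recognized vocabulary word to its
{"count", "word_id", "embedding"} entry; out-of-vocabulary words are kept with
word_id -1 and a placeholder embedding of [1]. Illustrative output for a single
sentence (values are made up):

    [{"bat": {"count": 1, "word_id": 1234, "embedding": [0.12, -0.78, 0.31, 0.95]}}]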
52 | """ 53 | 54 | sentences1, sentences2 = itertools.tee(sentences, 2) 55 | 56 | result = [] 57 | 58 | with torch.no_grad(): 59 | for embedding, sentence in zip(self.sentence_encoder.embed(sentences1, batch_size=4, parallel=parallel), sentences2): 60 | token_ids = np.array(self.sentence_encoder.tokenize([sentence])[0].ids) 61 | 62 | word_ids, counts, oov, forms = self.vocab_resolver.resolve_tokens(token_ids) 63 | 64 | # Size: (1, words) 65 | word_ids = np.expand_dims(word_ids, axis=0) 66 | # Size: (1, words, embedding_size) 67 | embedding = np.expand_dims(embedding, axis=0) 68 | 69 | with ipdb.launch_ipdb_on_exception(): 70 | # Size of word_ids_mapping: (unique_words, 2) - [vocab_id, batch_id] 71 | # Size of embeddings: (unique_words, embedding_size) 72 | ids_mapping, embeddings = self.word_encoder.forward(word_ids, embedding) 73 | 74 | # Size of counts: (unique_words) 75 | words_ids = ids_mapping[:, 0] 76 | 77 | sentence_result = {} 78 | 79 | words = [self.vocab_resolver.lookup_word(word_id) for word_id in words_ids] 80 | 81 | for word, word_id, emb in zip(words, words_ids, embeddings): 82 | if word_id == 0: 83 | continue 84 | 85 | sentence_result[word] = { 86 | "count": int(counts[word_id]), 87 | "word_id": int(word_id), 88 | "embedding": emb.tolist() 89 | } 90 | 91 | for oov_word, count in oov.items(): 92 | sentence_result[oov_word] = { 93 | "count": int(count), 94 | "word_id": -1, 95 | "embedding": [1] 96 | } 97 | 98 | result.append(sentence_result) 99 | 100 | return result 101 | 102 | 103 | def main(): 104 | import argparse 105 | 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument("--vocab-path", type=str) 108 | parser.add_argument("--word-encoder-path", type=str) 109 | parser.add_argument("--sentences", type=str, nargs='+') 110 | parser.add_argument("--dim", type=int, default=4) 111 | args = parser.parse_args() 112 | 113 | # Catch exception with ipdb 114 | 115 | with ipdb.launch_ipdb_on_exception(): 116 | model = MiniCOIL( 117 | vocab_path=args.vocab_path, 118 | word_encoder_path=args.word_encoder_path, 119 | ) 120 | 121 | for embeddings in model.encode(args.sentences): 122 | for word, embedding in embeddings.items(): 123 | print(word, json.dumps(embedding)) 124 | 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /mini_coil/model/mse_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from torch import nn 6 | 7 | 8 | class MSELoss(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | self.loss = nn.MSELoss() 13 | 14 | def forward( 15 | self, 16 | mapping: Optional[torch.LongTensor], 17 | prediction: torch.Tensor, 18 | target: torch.Tensor 19 | ) -> torch.Tensor: 20 | """ 21 | Calculate the mean squared error between the prediction and the target. 22 | 23 | Args: 24 | mapping: (flatten_batch) - association between the prediction and the target. 
25 | prediction: (flatten_batch, output_dim) - prediction of the context 26 | target: (num_abstracts, output_dim) - target context 27 | 28 | Returns: 29 | loss: () - mean squared error 30 | """ 31 | if mapping is None: 32 | mapped_target = target 33 | else: 34 | mapped_target = target[mapping] 35 | 36 | loss = self.loss(prediction, mapped_target) 37 | return loss 38 | 39 | 40 | def test_mse_loss(): 41 | loss = MSELoss() 42 | prediction = torch.tensor([ 43 | [0.1, 0.2, 0.3], 44 | [0.4, 0.5, 0.6], 45 | [1.7, 1.9, 2.1], 46 | [1.3, 1.4, 1.5], 47 | ]) 48 | target = torch.tensor([ 49 | [1.6, 1.7, 1.8], 50 | [1.9, 2.0, 2.1], 51 | [1.7, 1.9, 2.1], 52 | [0.0, 1.0, 0.0], 53 | ]) 54 | mapping = torch.tensor([0, 1, 2, 3]) 55 | 56 | loss1 = loss(mapping, prediction, target) 57 | 58 | prediction = torch.tensor([ 59 | [0.1, 0.2, 0.3], 60 | [0.4, 1.5, 0.6], 61 | [1.7, 1.9, 2.1], 62 | [1.3, 1.4, 1.5], 63 | ]) 64 | target = torch.tensor([ 65 | [1.6, 1.7, 1.8], 66 | [1.9, 2.0, 2.1], 67 | [1.7, 1.9, 2.1], 68 | [0.0, 1.0, 0.0], 69 | ]) 70 | 71 | loss2 = loss(mapping, prediction, target) 72 | 73 | assert loss1 > loss2 74 | 75 | 76 | if __name__ == "__main__": 77 | test_mse_loss() 78 | print("MseLoss test passed") 79 | -------------------------------------------------------------------------------- /mini_coil/model/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def cosine_distance(x1: torch.Tensor, x2: torch.Tensor, eps=1e-8): 6 | """ 7 | Calculate pairwise cosine distance between two tensors. 8 | 9 | Args: 10 | x1: (batch_size, embedding_size) - embeddings of the first elements 11 | x2: (batch_size, embedding_size) - embeddings of the second elements 12 | eps: float - small value to avoid division by zero 13 | 14 | Returns: 15 | distance: (batch_size) - cosine distance between the elements 16 | """ 17 | 18 | dot_product = torch.sum(x1 * x2, dim=1) 19 | x1_norm = torch.norm(x1, p=2, dim=1) 20 | x2_norm = torch.norm(x2, p=2, dim=1) 21 | 22 | return 1.0 - dot_product / (x1_norm * x2_norm + eps) 23 | 24 | 25 | class TripletLoss(nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | 29 | def forward( 30 | self, 31 | embeddings: torch.Tensor, 32 | triplet_indices: torch.Tensor, 33 | margins: torch.Tensor 34 | ): 35 | """ 36 | Calculate the triplet loss with custom margin for each triplet. 37 | 38 | Args: 39 | embeddings: (batch_size, embedding_size) - embeddings of the elements 40 | triplet_indices: (num_triplets, 3) - indices of the triplets. 
Each row consists of: anchor, positive, negative 41 | margins: (num_triplets) - margins for each triplet 42 | 43 | Returns: 44 | loss: triplet loss for the batch 45 | """ 46 | 47 | # (num_triplets, embedding_size) 48 | anchor = embeddings[triplet_indices[:, 0]] 49 | # (num_triplets, embedding_size) 50 | positive = embeddings[triplet_indices[:, 1]] 51 | # (num_triplets, embedding_size) 52 | negative = embeddings[triplet_indices[:, 2]] 53 | 54 | # (num_triplets) 55 | positive_distance = cosine_distance(anchor, positive) 56 | # (num_triplets) 57 | negative_distance = cosine_distance(anchor, negative) 58 | 59 | # (num_triplets) 60 | # Example: 61 | # positive_distance = [0.1, 0.2, 0.5] 62 | # negative_distance = [0.3, 0.3, 0.2] 63 | # margins = [0.2, 0.2, 0.3] 64 | # loss = [ 65 | # 0.1 - 0.3 + 0.2 = 0.0, 66 | # 0.2 - 0.3 + 0.2 = 0.1, 67 | # 0.5 - 0.2 + 0.3 = 0.6 68 | # ] 69 | loss = torch.relu(positive_distance - negative_distance + margins) 70 | number_failed_triplets = torch.sum(loss > 0).item() 71 | return loss.mean(), number_failed_triplets 72 | 73 | def test_cosine_distance(): 74 | x1 = torch.tensor([ 75 | [0.1, 0.2, 0.3], 76 | [0.1, 0.2, 0.3], 77 | [0.1, 0.2, 0.3], 78 | [0.1, 0.2, 0.3], 79 | ]) 80 | x2 = torch.tensor([ 81 | [0.3, 0.2, 0.1], 82 | [3.0, 2.0, 1.0], 83 | [0.0, 0.5, 1.0], 84 | [1.0, 1.0, 1.0], 85 | ]) 86 | 87 | distance = cosine_distance(x1, x2) 88 | 89 | print(distance) 90 | 91 | assert distance[0] > distance[2] 92 | assert distance[0] - distance[1] < 1e-6 93 | assert distance[1] > distance[2] 94 | assert distance[3] > distance[2] 95 | 96 | 97 | if __name__ == "__main__": 98 | test_cosine_distance() 99 | -------------------------------------------------------------------------------- /mini_coil/model/word_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encoder model for a single word. 3 | 4 | This is an intermediate form of the model, which is designed to be trained independently for each word. 5 | The model is a simple linear encoder that compresses the input token embeddings into a smaller representation. 6 | 7 | Additionally, it includes layers that simulate quantization into int8. 8 | 9 | After training, all word encoders are combined into a single model defined in `encoder.py`. 10 | """ 11 | 12 | import math 13 | 14 | import torch 15 | 16 | from torch import nn 17 | from torch.nn import init 18 | 19 | 20 | class WordEncoder(nn.Module): 21 | """ 22 | WordEncoder(512, 4) 23 | 24 | Will look like this: 25 | 26 | 27 | Linear transformation 28 | ┌─────────────────────┐ ┌─────────┐ ┌─────────┐ 29 | │ Token Embedding(512)├─────►│512->4 ├───►│Tanh ├──► 4d representation 30 | └─────────────────────┘ └─────────┘ └─────────┘ 31 | 32 | 33 | The final linear transformation is followed by a non-linear activation function: Tanh. 34 | 35 | Tanh is used to ensure that the output is in the range [-1, 1]. 36 | This makes the output of the model easier to interpret visually, assuming that each dimension 37 | encodes a type of semantic cluster.
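    Minimal shape-level example (illustrative; the 512 -> 4 sizes match the diagram above, and the output values depend on the learned weights):

        encoder = WordEncoder(input_dim=512, output_dim=4)
        out = encoder(torch.randn(2, 512))  # token embeddings for two occurrences of the word
        # out.shape == (2, 4); every value lies in [-1, 1] because of the Tanh activation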
38 | """ 39 | 40 | def __init__( 41 | self, 42 | input_dim: int, 43 | output_dim: int, 44 | device=None, 45 | dtype=None, 46 | dropout: float = 0.05 47 | ): 48 | factory_kwargs = {'device': device, 'dtype': dtype} 49 | super().__init__() 50 | self.input_dim = input_dim 51 | self.output_dim = output_dim 52 | self.device = device 53 | self.dtype = dtype 54 | 55 | self.dropout = nn.Dropout(dropout) 56 | 57 | # self.quant = torch.quantization.QuantStub() 58 | # self.dequant = torch.quantization.DeQuantStub() 59 | 60 | self.encoder_weights = nn.Parameter(torch.zeros((input_dim, output_dim), **factory_kwargs)) 61 | 62 | self.activation = nn.Tanh() 63 | 64 | self.reset_parameters() 65 | 66 | def reset_parameters(self) -> None: 67 | # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with 68 | # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see 69 | # https://github.com/pytorch/pytorch/issues/57109 70 | init.kaiming_uniform_(self.encoder_weights, a=math.sqrt(5)) 71 | 72 | def forward(self, word_embeddings: torch.Tensor) -> torch.Tensor: 73 | """ 74 | Forward pass of the model. 75 | 76 | Args: 77 | word_embeddings: (batch_size, input_dim) - input token embeddings 78 | 79 | Returns: 80 | (batch_size, output_dim) - compressed representation of the input 81 | """ 82 | # word_embeddings = self.quant(word_embeddings) 83 | word_embeddings = self.dropout(word_embeddings) 84 | compressed = self.activation(word_embeddings @ self.encoder_weights) 85 | return compressed 86 | # return self.dequant(compressed) 87 | -------------------------------------------------------------------------------- /mini_coil/read_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from typing import Iterator, Tuple 4 | import pyarrow.parquet as pq 5 | 6 | from mini_coil.settings import DATA_DIR 7 | 8 | 9 | def read_data() -> Iterator[Tuple[str, str, str]]: 10 | parquet_files = glob.glob(os.path.join(DATA_DIR, '*.parquet')) 11 | 12 | for file in parquet_files: 13 | table = pq.read_table(file).to_pandas() 14 | 15 | for row in table.itertuples(): 16 | yield ( 17 | row.query, 18 | row.pos, 19 | row.neg 20 | ) 21 | 22 | 23 | def main(): 24 | 25 | for query, pos, neg in read_data(): 26 | print(query) 27 | print(pos) 28 | print(neg) 29 | break 30 | # import ipdb; ipdb.set_trace() 31 | 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /mini_coil/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data') 4 | 5 | 6 | if __name__ == '__main__': 7 | print("DATA_DIR:", DATA_DIR) 8 | -------------------------------------------------------------------------------- /mini_coil/tokenizer.py: -------------------------------------------------------------------------------- 1 | # This code is a modified copy of the `NLTKWordTokenizer` class from `NLTK` library. 2 | 3 | import re 4 | from typing import List 5 | 6 | 7 | class WordTokenizer: 8 | """The tokenizer is "destructive" such that the regexes applied will munge the 9 | input string to a state beyond re-construction. 10 | """ 11 | 12 | # Starting quotes. 
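    # For example (assuming the behavior matches the NLTK tokenizer this class is copied from):
    # an opening double quote is rewritten to `` by the rules below and a closing one to '' by
    # ENDING_QUOTES further down, so '"Good"' ends up tokenized as ['``', 'Good', "''"].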
13 | STARTING_QUOTES = [ 14 | (re.compile("([«“‘„]|[`]+)", re.U), r" \1 "), 15 | (re.compile(r"^\""), r"``"), 16 | (re.compile(r"(``)"), r" \1 "), 17 | (re.compile(r"([ \(\[{<])(\"|\'{2})"), r"\1 `` "), 18 | (re.compile(r"(?i)(\')(?!re|ve|ll|m|t|s|d|n)(\w)\b", re.U), r"\1 \2"), 19 | ] 20 | 21 | # Ending quotes. 22 | ENDING_QUOTES = [ 23 | (re.compile("([»”’])", re.U), r" \1 "), 24 | (re.compile(r"''"), " '' "), 25 | (re.compile(r'"'), " '' "), 26 | (re.compile(r"([^' ])('[sS]|'[mM]|'[dD]|') "), r"\1 \2 "), 27 | (re.compile(r"([^' ])('ll|'LL|'re|'RE|'ve|'VE|n't|N'T) "), r"\1 \2 "), 28 | ] 29 | 30 | # Punctuation. 31 | PUNCTUATION = [ 32 | (re.compile(r'([^\.])(\.)([\]\)}>"\'' "»”’ " r"]*)\s*$", re.U), r"\1 \2 \3 "), 33 | (re.compile(r"([:,])([^\d])"), r" \1 \2"), 34 | (re.compile(r"([:,])$"), r" \1 "), 35 | ( 36 | re.compile(r"\.{2,}", re.U), 37 | r" \g<0> ", 38 | ), 39 | (re.compile(r"[;@#$%&]"), r" \g<0> "), 40 | ( 41 | re.compile(r'([^\.])(\.)([\]\)}>"\']*)\s*$'), 42 | r"\1 \2\3 ", 43 | ), # Handles the final period. 44 | (re.compile(r"[?!]"), r" \g<0> "), 45 | (re.compile(r"([^'])' "), r"\1 ' "), 46 | ( 47 | re.compile(r"[*]", re.U), 48 | r" \g<0> ", 49 | ), 50 | ] 51 | 52 | # Pads parentheses 53 | PARENS_BRACKETS = (re.compile(r"[\]\[\(\)\{\}\<\>]"), r" \g<0> ") 54 | DOUBLE_DASHES = (re.compile(r"--"), r" -- ") 55 | 56 | # List of contractions adapted from Robert MacIntyre's tokenizer. 57 | CONTRACTIONS2 = [ 58 | re.compile(pattern) 59 | for pattern in ( 60 | r"(?i)\b(can)(?#X)(not)\b", 61 | r"(?i)\b(d)(?#X)('ye)\b", 62 | r"(?i)\b(gim)(?#X)(me)\b", 63 | r"(?i)\b(gon)(?#X)(na)\b", 64 | r"(?i)\b(got)(?#X)(ta)\b", 65 | r"(?i)\b(lem)(?#X)(me)\b", 66 | r"(?i)\b(more)(?#X)('n)\b", 67 | r"(?i)\b(wan)(?#X)(na)(?=\s)", 68 | ) 69 | ] 70 | CONTRACTIONS3 = [ 71 | re.compile(pattern) 72 | for pattern in (r"(?i) ('t)(?#X)(is)\b", r"(?i) ('t)(?#X)(was)\b") 73 | ] 74 | 75 | @classmethod 76 | def tokenize(cls, text: str) -> List[str]: 77 | """Return a tokenized copy of `text`. 78 | 79 | >>> s = '''Good muffins cost $3.88 (roughly 3,36 euros)\nin New York.''' 80 | >>> WordTokenizer().tokenize(s) 81 | ['Good', 'muffins', 'cost', '$', '3.88', '(', 'roughly', '3,36', 'euros', ')', 'in', 'New', 'York', '.'] 82 | 83 | Args: 84 | text: The text to be tokenized. 85 | 86 | Returns: 87 | A list of tokens. 88 | """ 89 | for regexp, substitution in cls.STARTING_QUOTES: 90 | text = regexp.sub(substitution, text) 91 | 92 | for regexp, substitution in cls.PUNCTUATION: 93 | text = regexp.sub(substitution, text) 94 | 95 | # Handles parentheses. 96 | regexp, substitution = cls.PARENS_BRACKETS 97 | text = regexp.sub(substitution, text) 98 | 99 | # Handles double dash. 
100 | regexp, substitution = cls.DOUBLE_DASHES 101 | text = regexp.sub(substitution, text) 102 | 103 | # add extra space to make things easier 104 | text = " " + text + " " 105 | 106 | for regexp, substitution in cls.ENDING_QUOTES: 107 | text = regexp.sub(substitution, text) 108 | 109 | for regexp in cls.CONTRACTIONS2: 110 | text = regexp.sub(r" \1 \2 ", text) 111 | for regexp in cls.CONTRACTIONS3: 112 | text = regexp.sub(r" \1 \2 ", text) 113 | return text.split() 114 | -------------------------------------------------------------------------------- /mini_coil/train.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/miniCOIL/2d939737e89bacd797cb6c3e312135070a93a56d/mini_coil/train.py -------------------------------------------------------------------------------- /mini_coil/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/miniCOIL/2d939737e89bacd797cb6c3e312135070a93a56d/mini_coil/training/__init__.py -------------------------------------------------------------------------------- /mini_coil/training/coil_module.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import lightning as L 4 | import numpy as np 5 | import torch 6 | from torch import optim 7 | 8 | from mini_coil.model.cosine_loss import CosineLoss 9 | from mini_coil.model.decoder import Decoder 10 | from mini_coil.model.encoder import Encoder 11 | 12 | 13 | class MiniCoil(L.LightningModule): 14 | def __init__( 15 | self, 16 | encoder: Encoder, 17 | decoder: Decoder, 18 | ): 19 | super().__init__() 20 | self.encoder: Encoder = encoder 21 | self.decoder: Decoder = decoder 22 | self.loss = CosineLoss() 23 | 24 | def configure_optimizers(self): 25 | optimizer = optim.Adam(self.parameters(), lr=1e-3) 26 | return optimizer 27 | # return { 28 | # "optimizer": optimizer, 29 | # "lr_scheduler": { 30 | # "scheduler": ReduceLROnPlateau( 31 | # optimizer, 32 | # mode='min', 33 | # factor=0.5, 34 | # patience=50, 35 | # verbose=True, 36 | # threshold=1e-5, 37 | # ), 38 | # "monitor": "val_loss", 39 | # # "frequency": "indicates how often the metric is updated", 40 | # # If "monitor" references validation metrics, then "frequency" should be set to a 41 | # # multiple of "trainer.check_val_every_n_epoch". 
42 | # }, 43 | # } 44 | 45 | def encode_decode_loss( 46 | self, 47 | batch: Dict[str, np.ndarray], 48 | ): 49 | """ 50 | batch: 51 | { 52 | 'token_embeddings': np array of shape (batch_size, max_len, embedding_size), 53 | 'text_embeddings': np array of shape (batch_size, embedding_size), 54 | 'token_ids': np array of shape (batch_size, max_len) 55 | } 56 | """ 57 | 58 | # import ipdb; ipdb.set_trace() 59 | 60 | token_ids = torch.from_numpy(batch['token_ids']).to(self.device) 61 | token_embeddings = torch.from_numpy(batch['token_embeddings']).to(self.device).float() 62 | text_embeddings = torch.from_numpy(batch['text_embeddings']).to(self.device).float() 63 | 64 | encoded = self.encoder( 65 | token_ids, 66 | token_embeddings 67 | ) 68 | 69 | # (total_unique_words_per_batch, 2) 70 | # [ 71 | # [vocab_id1, id_in_batch1], 72 | # [vocab_id2, id_in_batch1], 73 | # [vocab_id3, id_in_batch2], 74 | # ] 75 | unique_flattened_vocab_ids_and_batch_ids = encoded[0] 76 | 77 | # (total_unique_words_per_batch, compressed_dim) 78 | unique_flattened_encoded = encoded[1] 79 | 80 | # (total_unique_words_per_batch) 81 | vocab_ids = unique_flattened_vocab_ids_and_batch_ids[:, 0] 82 | # (total_unique_words_per_batch) 83 | word_to_text_id = unique_flattened_vocab_ids_and_batch_ids[:, 1] 84 | 85 | # (total_unique_words_per_batch, embedding_size) 86 | decompressed = self.decoder( 87 | vocab_ids, 88 | unique_flattened_encoded, 89 | ) 90 | 91 | return self.loss(word_to_text_id, decompressed, text_embeddings) 92 | 93 | def training_step( 94 | self, 95 | batch: Dict[str, np.ndarray], 96 | batch_idx: int 97 | ): 98 | loss = self.encode_decode_loss(batch) 99 | 100 | self.log("train_loss", loss) 101 | 102 | return loss 103 | 104 | def validation_step( 105 | self, 106 | batch: Dict[str, np.ndarray], 107 | batch_idx: int 108 | ): 109 | batch_size = batch['token_ids'].shape[0] 110 | 111 | loss = self.encode_decode_loss(batch) 112 | self.log("val_loss", loss, batch_size=batch_size) 113 | 114 | return loss 115 | -------------------------------------------------------------------------------- /mini_coil/training/data_loader.py: -------------------------------------------------------------------------------- 1 | from mini_coil.data_pipeline.read_pre_encoded import PreEncodedReader 2 | 3 | 4 | class PreEncodedLoader: 5 | 6 | def __init__(self, path: str, batch_size: int = 32): 7 | self.reader = PreEncodedReader(path) 8 | self.batch_size = batch_size 9 | 10 | def __iter__(self): 11 | total_batches = len(self.reader) // self.batch_size 12 | for i in range(total_batches): 13 | yield self.reader.read(i * self.batch_size, (i + 1) * self.batch_size) 14 | -------------------------------------------------------------------------------- /mini_coil/training/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import lightning as L 4 | from ipdb import launch_ipdb_on_exception 5 | from lightning.pytorch.callbacks import ModelCheckpoint 6 | 7 | from mini_coil.data_pipeline.vocab_resolver import VocabResolver 8 | from mini_coil.model.decoder import Decoder 9 | from mini_coil.model.encoder import Encoder 10 | from mini_coil.settings import DATA_DIR 11 | from mini_coil.training.coil_module import MiniCoil 12 | from mini_coil.training.data_loader import PreEncodedLoader 13 | 14 | 15 | def get_encoder(vocab_size): 16 | return Encoder( 17 | input_dim=384, 18 | output_dim=4, 19 | vocab_size=vocab_size, 20 | ) 21 | 22 | 23 | def get_decoder(vocab_size): 24 | return Decoder( 25 | input_dim=4, 
26 | output_dim=384, 27 | vocab_size=vocab_size, 28 | ) 29 | 30 | 31 | def get_model(vocab_size): 32 | return MiniCoil( 33 | encoder=get_encoder(vocab_size), 34 | decoder=get_decoder(vocab_size), 35 | ) 36 | 37 | 38 | def main(): 39 | batch_size = 64 40 | model_repository = "sentence-transformers/all-MiniLM-L6-v2" 41 | 42 | data_path = os.path.join(DATA_DIR, "test") 43 | 44 | train_path = os.path.join(data_path, "train") 45 | valid_path = os.path.join(data_path, "valid") 46 | 47 | train_loader = PreEncodedLoader(train_path, batch_size) 48 | valid_loader = PreEncodedLoader(valid_path, batch_size) 49 | 50 | test_vocab_path = os.path.join(DATA_DIR, "test", "vocab.txt") 51 | 52 | vocab_resolver = VocabResolver(model_repository) 53 | vocab_resolver.load_vocab(test_vocab_path) 54 | 55 | mini_coil = get_model(vocab_resolver.vocab_size()) 56 | 57 | checkpoint_callback = ModelCheckpoint( 58 | every_n_epochs=10, 59 | ) 60 | 61 | trainer = L.Trainer( 62 | max_epochs=1000, 63 | callbacks=[checkpoint_callback], 64 | ) 65 | 66 | # catch with ipdb 67 | with launch_ipdb_on_exception(): 68 | trainer.fit( 69 | model=mini_coil, 70 | train_dataloaders=train_loader, 71 | val_dataloaders=valid_loader 72 | ) 73 | 74 | 75 | if __name__ == "__main__": 76 | main() 77 | -------------------------------------------------------------------------------- /mini_coil/training/train_word.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | 6 | import lightning as L 7 | import torch 8 | from ipdb import launch_ipdb_on_exception 9 | from lightning.pytorch.loggers import CSVLogger 10 | 11 | from mini_coil.model.word_encoder import WordEncoder 12 | from mini_coil.training.word_module import WordModule 13 | 14 | 15 | def get_encoder(input_dim, output_dim, dropout: float = 0.05): 16 | return WordEncoder( 17 | input_dim=input_dim, 18 | output_dim=output_dim, 19 | dropout=dropout 20 | ) 21 | 22 | 23 | class DataLoader: 24 | 25 | def __init__( 26 | self, 27 | embeddings: np.ndarray, 28 | targets: np.ndarray, 29 | batch_size: int = 200, 30 | use_cuda: bool = True, 31 | ): 32 | if use_cuda: 33 | embeddings = torch.from_numpy(embeddings).float().cuda() 34 | targets = torch.from_numpy(targets).float().cuda() 35 | else: 36 | embeddings = torch.from_numpy(embeddings).float() 37 | targets = torch.from_numpy(targets).float() 38 | self.embeddings = embeddings 39 | self.targets = targets 40 | self.batch_size = batch_size 41 | 42 | def __iter__(self): 43 | total_batches = self.embeddings.shape[0] // self.batch_size 44 | 45 | for i in range(total_batches): 46 | from_idx = i * self.batch_size 47 | to_idx = (i + 1) * self.batch_size 48 | yield { 49 | 'word_embeddings': self.embeddings[from_idx:to_idx], 50 | 'target_embeddings': self.targets[from_idx:to_idx], 51 | } 52 | 53 | 54 | def split_train_val(embeddings: np.ndarray, target: np.ndarray, val_size=0.1): 55 | """ 56 | Take last N elements of the embeddings and target as validation set. 57 | Records are already shuffled, so we can just take the last N elements. 
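        Illustrative example (shapes only; the 512/4 dimensions are arbitrary here):

            emb = np.random.rand(1000, 512)
            tgt = np.random.rand(1000, 4)
            tr_emb, tr_tgt, va_emb, va_tgt = split_train_val(emb, tgt, val_size=0.1)
            # tr_emb.shape == (900, 512), va_emb.shape == (100, 512): the last 10% become validation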
58 | """ 59 | val_size = int(embeddings.shape[0] * val_size) 60 | train_embeddings = embeddings[:-val_size] 61 | train_target = target[:-val_size] 62 | val_embeddings = embeddings[-val_size:] 63 | val_target = target[-val_size:] 64 | 65 | return train_embeddings, train_target, val_embeddings, val_target 66 | 67 | 68 | def main(): 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument("--embedding-path", type=str) 71 | parser.add_argument("--target-path", type=str) 72 | parser.add_argument('--output-path', type=str) 73 | parser.add_argument('--log-dir', type=str) 74 | parser.add_argument('--gpu', action='store_true') 75 | parser.add_argument("--epochs", type=int, default=500) 76 | parser.add_argument("--val-size", type=float, default=0.1) 77 | parser.add_argument("--batch-size", type=int, default=200) 78 | parser.add_argument("--dropout", type=float, default=0.05) 79 | parser.add_argument("--lr", type=float, default=2e-3) 80 | parser.add_argument("--factor", type=float, default=0.5) 81 | parser.add_argument("--patience", type=int, default=5) 82 | 83 | 84 | args = parser.parse_args() 85 | 86 | embedding = np.load(args.embedding_path) 87 | 88 | target = np.load(args.target_path) 89 | 90 | train_embeddings, train_target, val_embeddings, val_target = split_train_val(embedding, target, val_size=args.val_size) 91 | 92 | input_dim = train_embeddings.shape[1] 93 | output_dim = train_target.shape[1] 94 | 95 | encoder_load = get_encoder(input_dim, output_dim, dropout=args.dropout) 96 | 97 | encoder_prepared = encoder_load 98 | 99 | # ToDo: 100 | # encoder_load.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86') 101 | # encoder_prepared = torch.ao.quantization.prepare_qat(encoder_load.train()) 102 | 103 | if args.gpu: 104 | accelerator = 'auto' 105 | else: 106 | accelerator = 'cpu' 107 | torch.set_num_threads(1) 108 | 109 | trainer = L.Trainer( 110 | max_epochs=args.epochs, 111 | enable_checkpointing=False, 112 | logger=CSVLogger(args.log_dir), 113 | enable_progress_bar=True, 114 | accelerator=accelerator, 115 | ) 116 | 117 | default_batch_size = args.batch_size 118 | val_batch = min(val_embeddings.shape[0] - 1, default_batch_size) 119 | 120 | train_loader = DataLoader(train_embeddings, train_target, batch_size=args.batch_size, use_cuda=args.gpu) 121 | valid_loader = DataLoader(val_embeddings, val_target, use_cuda=args.gpu, batch_size=val_batch) 122 | 123 | with launch_ipdb_on_exception(): 124 | trainer.fit( 125 | model=WordModule(encoder_prepared, lr=args.lr, factor=args.factor, patience=args.patience), 126 | train_dataloaders=train_loader, 127 | val_dataloaders=valid_loader 128 | ) 129 | 130 | output_dir = os.path.dirname(args.output_path) 131 | 132 | if not os.path.exists(output_dir): 133 | os.makedirs(output_dir, exist_ok=True) 134 | 135 | torch.save(encoder_prepared.state_dict(), args.output_path) 136 | 137 | # Try to read the saved model 138 | encoder_load = get_encoder(input_dim, output_dim) 139 | 140 | # encoder_load.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86') 141 | # encoder_prepared_load = torch.ao.quantization.prepare_qat(encoder_load) 142 | encoder_load.load_state_dict(torch.load(args.output_path, weights_only=True)) 143 | 144 | 145 | if __name__ == "__main__": 146 | main() 147 | -------------------------------------------------------------------------------- /mini_coil/training/train_word_triplet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Tuple 4 | 5 | 
import lightning as L 6 | import numpy as np 7 | import torch 8 | from lightning.pytorch.loggers import CSVLogger 9 | 10 | from mini_coil.model.word_encoder import WordEncoder 11 | from mini_coil.training.triplet_dataloader import TripletDataloader 12 | from mini_coil.training.triplet_word_module import TripletWordModule 13 | 14 | 15 | def get_encoder(input_dim, output_dim, dropout: float = 0.05): 16 | return WordEncoder( 17 | input_dim=input_dim, 18 | output_dim=output_dim, 19 | dropout=dropout 20 | ) 21 | 22 | 23 | def split_train_val( 24 | embeddings: np.ndarray, 25 | similarity_matrix: np.ndarray, 26 | line_numbers: np.ndarray, 27 | batch_size: int = 32, 28 | val_size=0.2, 29 | ) -> Tuple[TripletDataloader, TripletDataloader]: 30 | from_train = 0 31 | total_size = min(similarity_matrix.shape[0], len(line_numbers)) 32 | from_val = int(total_size * (1 - val_size)) 33 | to_train = from_val 34 | to_val = total_size 35 | 36 | train_dataloader = TripletDataloader( 37 | embeddings=embeddings, 38 | line_numbers=line_numbers, 39 | similarity_matrix=similarity_matrix, 40 | range_from=from_train, 41 | range_to=to_train, 42 | batch_size=batch_size, 43 | epoch_size=64_000, 44 | min_margin=0.1 45 | ) 46 | 47 | val_dataloader = TripletDataloader( 48 | embeddings=embeddings, 49 | line_numbers=line_numbers, 50 | similarity_matrix=similarity_matrix, 51 | range_from=from_val, 52 | range_to=to_val, 53 | batch_size=batch_size, 54 | epoch_size=6400, 55 | min_margin=0.1 56 | ) 57 | 58 | return train_dataloader, val_dataloader 59 | 60 | 61 | 62 | def main(): 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument("--embedding-path", type=str) 65 | parser.add_argument("--distance-matrix-path", type=str) 66 | parser.add_argument("--output-dim", type=int, default=4) 67 | parser.add_argument('--output-path', type=str) 68 | parser.add_argument('--log-dir', type=str) 69 | parser.add_argument("--epochs", type=int, default=500) 70 | parser.add_argument("--val-size", type=float, default=0.2) 71 | parser.add_argument("--batch-size", type=int, default=256) 72 | parser.add_argument("--dropout", type=float, default=0.05) 73 | parser.add_argument("--lr", type=float, default=2e-3) 74 | parser.add_argument("--factor", type=float, default=0.5) 75 | parser.add_argument("--patience", type=int, default=5) 76 | parser.add_argument("--line-numbers-path", type=str) 77 | 78 | args = parser.parse_args() 79 | 80 | embedding = np.load(args.embedding_path) 81 | 82 | line_numbers_path = args.line_numbers_path 83 | if not os.path.exists(line_numbers_path): 84 | raise FileNotFoundError(f"Expected line numbers file: {line_numbers_path}") 85 | line_numbers = np.load(line_numbers_path) 86 | 87 | distance_matrix = np.load(args.distance_matrix_path) 88 | 89 | train_loader, valid_loader = split_train_val( 90 | embeddings=embedding, 91 | similarity_matrix=distance_matrix, 92 | line_numbers=line_numbers, 93 | val_size=args.val_size, 94 | batch_size=args.batch_size 95 | ) 96 | 97 | input_dim = embedding.shape[1] 98 | output_dim = args.output_dim 99 | 100 | encoder_load = get_encoder(input_dim, output_dim, dropout=args.dropout) 101 | 102 | encoder_prepared = encoder_load 103 | 104 | accelerator = 'cpu' 105 | torch.set_num_threads(1) 106 | 107 | trainer = L.Trainer( 108 | max_epochs=args.epochs, 109 | enable_checkpointing=False, 110 | logger=CSVLogger(args.log_dir), 111 | # logger=TensorBoardLogger(args.log_dir), 112 | enable_progress_bar=True, 113 | accelerator=accelerator, 114 | ) 115 | 116 | # with launch_ipdb_on_exception(): 117 | 
trainer.fit( 118 | model=TripletWordModule( 119 | encoder_prepared, 120 | lr=args.lr, 121 | factor=args.factor, 122 | patience=args.patience), 123 | train_dataloaders=train_loader, 124 | val_dataloaders=valid_loader 125 | ) 126 | 127 | output_dir = os.path.dirname(args.output_path) 128 | 129 | if not os.path.exists(output_dir): 130 | os.makedirs(output_dir, exist_ok=True) 131 | 132 | torch.save(encoder_prepared.state_dict(), args.output_path) 133 | 134 | # Try to read the saved model 135 | encoder_load = get_encoder(input_dim, output_dim) 136 | encoder_load.load_state_dict(torch.load(args.output_path, weights_only=True)) 137 | 138 | 139 | if __name__ == "__main__": 140 | main() 141 | -------------------------------------------------------------------------------- /mini_coil/training/triplet_dataloader.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Iterable, Tuple, Dict 3 | import numpy as np 4 | 5 | 6 | def sample_triplets(distance_matrix: np.ndarray, margin: float) -> Iterable[Tuple[int, int, int, float]]: 7 | size = distance_matrix.shape[0] 8 | while True: 9 | x, y, z = random.sample(range(size), 3) 10 | 11 | dxy, dxz, dyz = distance_matrix[x, y], distance_matrix[x, z], distance_matrix[y, z] 12 | 13 | x_anchor_dist = abs(dxy - dxz) 14 | y_anchor_dist = abs(dxy - dyz) 15 | z_anchor_dist = abs(dxz - dyz) 16 | 17 | if x_anchor_dist > margin and x_anchor_dist > y_anchor_dist and x_anchor_dist > z_anchor_dist: 18 | anchor = x 19 | if dxy > dxz: 20 | positive = z 21 | negative = y 22 | else: 23 | positive = y 24 | negative = z 25 | yield anchor, positive, negative, x_anchor_dist 26 | continue 27 | 28 | if y_anchor_dist > margin and y_anchor_dist > x_anchor_dist and y_anchor_dist > z_anchor_dist: 29 | anchor = y 30 | if dxy > dyz: 31 | positive = z 32 | negative = x 33 | else: 34 | positive = x 35 | negative = z 36 | yield anchor, positive, negative, y_anchor_dist 37 | continue 38 | 39 | if z_anchor_dist > margin and z_anchor_dist > x_anchor_dist and z_anchor_dist > y_anchor_dist: 40 | anchor = z 41 | if dxz > dyz: 42 | positive = y 43 | negative = x 44 | else: 45 | positive = x 46 | negative = y 47 | yield anchor, positive, negative, z_anchor_dist 48 | continue 49 | 50 | 51 | class TripletDataloader: 52 | 53 | def __init__( 54 | self, 55 | embeddings: np.ndarray, 56 | similarity_matrix: np.ndarray, 57 | line_numbers: np.ndarray, 58 | min_margin: float = 0.1, 59 | batch_size: int = 32, 60 | # Subset of the similarity matrix to use 61 | range_from: int = 0, 62 | range_to: int = None, 63 | epoch_size: int = 3200, 64 | ): 65 | self.embeddings = embeddings 66 | self.line_numbers = line_numbers 67 | self.min_margin = min_margin 68 | self.batch_size = batch_size 69 | self.range_from = range_from 70 | 71 | max_range = min(similarity_matrix.shape[0], len(line_numbers)) 72 | self.range_to = range_to if range_to is not None else max_range 73 | 74 | self.matrix_to_embeddings = [[] for _ in range(similarity_matrix.shape[0])] 75 | for i, line_number in enumerate(line_numbers.tolist()): 76 | self.matrix_to_embeddings[line_number].append(i) 77 | 78 | self.distance_matrix = 1.0 - similarity_matrix[self.range_from:self.range_to, self.range_from:self.range_to] 79 | 80 | self.epoch_size = epoch_size 81 | 82 | def __iter__(self) -> Iterable[Dict[str, np.ndarray]]: 83 | rows = [] 84 | triplets = [] 85 | margins = [] 86 | n = 0 87 | for anchor, positive, negative, margin in sample_triplets(self.distance_matrix, self.min_margin): 88 | 
anchor_emb_id = random.choice(self.matrix_to_embeddings[anchor + self.range_from]) 89 | positive_emb_id = random.choice(self.matrix_to_embeddings[positive + self.range_from]) 90 | negative_emb_id = random.choice(self.matrix_to_embeddings[negative + self.range_from]) 91 | 92 | rows.append(self.embeddings[anchor_emb_id]) 93 | rows.append(self.embeddings[positive_emb_id]) 94 | rows.append(self.embeddings[negative_emb_id]) 95 | 96 | triplets.append((len(rows) - 3, len(rows) - 2, len(rows) - 1)) 97 | margins.append(margin) 98 | n += 1 99 | if len(triplets) >= self.batch_size: 100 | yield { 101 | "embeddings": np.array(rows), 102 | "triplets": np.array(triplets), 103 | "margins": np.array(margins) 104 | } 105 | rows = [] 106 | triplets = [] 107 | margins = [] 108 | if n >= self.epoch_size: 109 | break 110 | 111 | def test_triplet_dataloader(): 112 | embeddings = np.array([ 113 | [1, 0, 0, 0], 114 | [1.1, 0, 0, 0], 115 | [0, 1, 0, 0], 116 | [0, 1.1, 0, 0], 117 | [0, 0, 1, 0], 118 | [0, 0, 0, 1] 119 | ]) 120 | 121 | similarity_matrix = np.array([ 122 | [1, 0.5, 0.1, 0.2], 123 | [0.5, 1, 0.3, 0.4], 124 | [0.1, 0.3, 1, 0.5], 125 | [0.2, 0.4, 0.5, 1], 126 | ]) 127 | 128 | line_numbers = np.array([0, 0, 1, 1, 2, 3]) 129 | 130 | dataloader = TripletDataloader( 131 | embeddings, 132 | similarity_matrix, 133 | line_numbers, 134 | min_margin=0.1, 135 | batch_size=2, 136 | range_from=0, 137 | range_to=4 138 | ) 139 | 140 | batch = next(iter(dataloader)) 141 | 142 | print(dataloader.matrix_to_embeddings) 143 | 144 | print(batch["embeddings"]) 145 | print(batch["triplets"]) 146 | print(batch["margins"]) 147 | 148 | 149 | if __name__ == '__main__': 150 | test_triplet_dataloader() 151 | -------------------------------------------------------------------------------- /mini_coil/training/triplet_word_module.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | import torch 3 | from torch import optim 4 | from torch.optim.lr_scheduler import ReduceLROnPlateau 5 | 6 | from mini_coil.model.triplet_loss import TripletLoss 7 | from mini_coil.model.word_encoder import WordEncoder 8 | 9 | 10 | class TripletWordModule(L.LightningModule): 11 | def __init__(self, encoder: WordEncoder, lr: float = 2e-3, factor: float = 0.5, patience: int = 5): 12 | super().__init__() 13 | self.encoder = encoder 14 | self.loss = TripletLoss() 15 | self.lr = lr 16 | self.factor = factor 17 | self.patience = patience 18 | 19 | def configure_optimizers(self): 20 | optimizer = optim.Adam(self.parameters(), lr=self.lr) 21 | return { 22 | "optimizer": optimizer, 23 | "lr_scheduler": { 24 | "scheduler": ReduceLROnPlateau( 25 | optimizer, 26 | mode="min", 27 | factor=self.factor, 28 | patience=self.patience, 29 | verbose=True, 30 | threshold=1e-4 31 | ), 32 | "monitor": "val_loss", 33 | }, 34 | } 35 | 36 | def training_step(self, batch, batch_idx): 37 | embeddings = torch.from_numpy(batch["embeddings"]).float().to(self.device) 38 | triplets = torch.from_numpy(batch["triplets"]).long().to(self.device) 39 | margins = torch.from_numpy(batch["margins"]).float().to(self.device) 40 | encoded = self.encoder(embeddings) 41 | loss_val, _ = self.loss( 42 | encoded, 43 | triplets, 44 | margins, 45 | ) 46 | self.log( 47 | "train_loss", 48 | loss_val, 49 | on_epoch=True, 50 | on_step=False, 51 | prog_bar=True, 52 | logger=True, 53 | batch_size=embeddings.size(0) 54 | ) 55 | return loss_val 56 | 57 | def validation_step(self, batch, batch_idx): 58 | embeddings = 
torch.from_numpy(batch["embeddings"]).float().to(self.device) 59 | triplets = torch.from_numpy(batch["triplets"]).long().to(self.device) 60 | margins = torch.from_numpy(batch["margins"]).float().to(self.device) 61 | encoded = self.encoder(embeddings) 62 | loss_val, number_failed_triplets = self.loss( 63 | encoded, 64 | triplets, 65 | margins, 66 | ) 67 | self.log( 68 | "val_loss", 69 | loss_val, 70 | batch_size=embeddings.size(0) 71 | ) 72 | self.log( 73 | "val_failed_triplets", 74 | number_failed_triplets, 75 | batch_size=embeddings.size(0) 76 | ) 77 | return loss_val 78 | -------------------------------------------------------------------------------- /mini_coil/training/try_checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from mini_coil.data_pipeline.pre_encoder import PreEncoder 6 | from mini_coil.data_pipeline.vocab_resolver import VocabResolver 7 | from mini_coil.settings import DATA_DIR 8 | from mini_coil.training.coil_module import MiniCoil 9 | from mini_coil.training.train import get_encoder, get_decoder 10 | 11 | 12 | def cosine_similarity(rows_a: torch.Tensor, rows_b: torch.Tensor) -> torch.Tensor: 13 | """ 14 | Compute a matrix of cosine distances between two sets of vectors. 15 | """ 16 | # Normalize the vectors 17 | rows_a = rows_a / torch.norm(rows_a, dim=1, keepdim=True) 18 | rows_b = rows_b / torch.norm(rows_b, dim=1, keepdim=True) 19 | 20 | # Compute the cosine similarity 21 | return torch.mm(rows_a, rows_b.T) 22 | 23 | 24 | def main(): 25 | text_a = "The bat flew out of the cave." 26 | text_b = "A bat can use echolocation to navigate in the dark." 27 | text_c = "He is a baseball player. He knows how to swing a bat." 28 | text_d = "Eric Byrnes, never with an at bat in Yankee Stadium and they don't get much bigger than this one." 29 | text_e = "It was just a cricket bat." 30 | text_f = "And guess who the orphans have at bat!" 
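    # The six sentences above contrast two senses of "bat" (the animal vs. baseball/cricket equipment);
    # the rest of this script checks how well the compressed embeddings preserve the similarity ranking
    # of the original sentence embeddings.
    #
    # Quick sanity check of the cosine_similarity helper defined above (exact for this toy input):
    #   rows = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
    #   cosine_similarity(rows, rows)  ->  [[1.0, 0.0], [0.0, 1.0]]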
31 | 32 | model_repository = "sentence-transformers/all-MiniLM-L6-v2" 33 | model_save_path = os.path.join(DATA_DIR, "all_miniLM_L6_v2.onnx") 34 | test_vocab_path = os.path.join(DATA_DIR, "test", "vocab.txt") 35 | 36 | vocab_resolver = VocabResolver(model_repository) 37 | vocab_resolver.load_vocab(test_vocab_path) 38 | 39 | pre_encoder = PreEncoder(model_repository, model_save_path) 40 | 41 | texts = [text_a, text_b, text_c, text_d, text_e, text_f] 42 | 43 | result = pre_encoder.encode(texts) 44 | 45 | resolved_token_ids = vocab_resolver.token_ids_to_vocab_batch( 46 | result["token_ids"] 47 | ) 48 | 49 | text_embeddings = torch.from_numpy(result["text_embeddings"]) 50 | 51 | print("here") 52 | 53 | version = "8" 54 | 55 | model = MiniCoil.load_from_checkpoint( 56 | os.path.join(DATA_DIR, "..", "lightning_logs", f"version_{version}", "checkpoints", 57 | "epoch=999-step=50000.ckpt"), 58 | encoder=get_encoder(vocab_resolver.vocab_size()), 59 | decoder=get_decoder(vocab_resolver.vocab_size()), 60 | ) 61 | 62 | encoder = model.encoder 63 | encoder.eval() 64 | 65 | resolved_token_ids_torch = torch.from_numpy(resolved_token_ids).to(model.device) 66 | token_embeddings_torch = torch.from_numpy(result["token_embeddings"]).to(model.device).float() 67 | 68 | encoded = encoder( 69 | resolved_token_ids_torch, 70 | token_embeddings_torch 71 | ) 72 | 73 | # (batch size, embedding size) 74 | selected_token_embeddings_torch = token_embeddings_torch[resolved_token_ids_torch != 0] 75 | 76 | print(encoded[0]) 77 | 78 | bat_tokens = encoded[1][len(texts):] 79 | 80 | print(bat_tokens) 81 | 82 | # Matrix cosine similarity 83 | matrix = cosine_similarity(bat_tokens, bat_tokens) 84 | 85 | print("compressed similarity matrix\n", matrix) 86 | 87 | original_matrix = cosine_similarity(text_embeddings, text_embeddings) 88 | 89 | print("original similarity matrix\n", original_matrix) 90 | 91 | # token_matrix = cosine_similarity(token_embeddings_torch, token_embeddings_torch) 92 | # 93 | # print("token similarity matrix\n", token_matrix) 94 | 95 | # Replace values in each row with it's index in sorted order 96 | 97 | # (batch size, batch size) 98 | sorted_matrix, indices = torch.sort(matrix, descending=True) 99 | invert_indices = torch.argsort(indices) 100 | print("indices\n", invert_indices) 101 | 102 | # (batch size, batch size) 103 | sorted_original_matrix, original_indices = torch.sort(original_matrix, descending=True) 104 | invert_original_indices = torch.argsort(original_indices) 105 | print("original indices\n", invert_original_indices) 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /mini_coil/training/word_module.py: -------------------------------------------------------------------------------- 1 | from random import random 2 | from typing import Dict 3 | 4 | import lightning as L 5 | import numpy as np 6 | import torch 7 | from torch import optim 8 | from torch.optim.lr_scheduler import ReduceLROnPlateau 9 | 10 | from mini_coil.model.cosine_loss import CosineLoss 11 | from mini_coil.model.word_encoder import WordEncoder 12 | 13 | 14 | class WordModule(L.LightningModule): 15 | def __init__( 16 | self, 17 | encoder: WordEncoder, 18 | lr: float = 2e-3, 19 | factor: float = 0.5, 20 | patience: int = 5, 21 | ): 22 | super().__init__() 23 | self.encoder: WordEncoder = encoder 24 | self.loss = CosineLoss() 25 | self.lr = lr 26 | self.factor = factor 27 | self.patience = patience 28 | 29 | def configure_optimizers(self): 30 | 
optimizer = optim.Adam(self.parameters(), lr=self.lr) 31 | return { 32 | "optimizer": optimizer, 33 | "lr_scheduler": { 34 | "scheduler": ReduceLROnPlateau( 35 | optimizer, 36 | mode='min', 37 | factor=self.factor, 38 | patience=self.patience, 39 | verbose=True, 40 | threshold=1e-4, 41 | ), 42 | "monitor": "val_loss", 43 | # "frequency": "indicates how often the metric is updated", 44 | # If "monitor" references validation metrics, then "frequency" should be set to a 45 | # multiple of "trainer.check_val_every_n_epoch". 46 | }, 47 | } 48 | 49 | def encode_decode_loss( 50 | self, 51 | batch: Dict[str, np.ndarray], 52 | ): 53 | """ 54 | batch: 55 | { 56 | 'word_embeddings': np array of shape (batch_size, embedding_size), 57 | 'target_embeddings': np array of shape (batch_size, compressed_dim), 58 | } 59 | """ 60 | 61 | # import ipdb; ipdb.set_trace() 62 | 63 | word_embeddings = batch['word_embeddings'] 64 | target_embeddings = batch['target_embeddings'] 65 | 66 | # (batch_size, compressed_dim) 67 | encoded = self.encoder( 68 | word_embeddings 69 | ) 70 | 71 | # one-to-one mapping of the encoded to the target embeddings 72 | # (batch_size) 73 | mapping = None # torch.arange(encoded.size(0), device=encoded.device) 74 | 75 | loss = self.loss(mapping, encoded, target_embeddings) 76 | 77 | return loss 78 | 79 | def training_step( 80 | self, 81 | batch: Dict[str, np.ndarray], 82 | batch_idx: int 83 | ): 84 | loss = self.encode_decode_loss(batch) 85 | self.log("train_loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True) 86 | 87 | return loss 88 | 89 | def validation_step( 90 | self, 91 | batch: Dict[str, np.ndarray], 92 | batch_idx: int 93 | ): 94 | batch_size = batch['word_embeddings'].shape[0] 95 | loss = self.encode_decode_loss(batch) 96 | self.log("val_loss", loss, batch_size=batch_size) 97 | 98 | return loss 99 | -------------------------------------------------------------------------------- /mini_coil/triplet_loss.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Optional, Tuple 3 | 4 | import pytest 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import LongTensor, Tensor 8 | 9 | 10 | class Distance(str, Enum): 11 | """An enumerator to pass distance metric names across the package.""" 12 | 13 | EUCLIDEAN = "euclidean" 14 | COSINE = "cosine" 15 | DOT_PRODUCT = "dot_product" 16 | MANHATTAN = "manhattan" 17 | 18 | 19 | class BaseDistance: 20 | def distance_matrix(self, x: Tensor, y: Optional[Tensor] = None) -> Tensor: 21 | raise NotImplementedError() 22 | 23 | 24 | class Cosine(BaseDistance): 25 | def distance_matrix(self, x: Tensor, y: Optional[Tensor] = None) -> Tensor: 26 | if y is None: 27 | y = x 28 | x_norm = F.normalize(x, p=2, dim=1) 29 | y_norm = F.normalize(y, p=2, dim=1) 30 | return 1.0 - torch.mm(x_norm, y_norm.t()) 31 | 32 | 33 | def get_triplet_mask(labels: torch.Tensor) -> torch.Tensor: 34 | """Creates a 3D mask of valid triplets for the batch-all strategy. 35 | 36 | Given a batch of labels with `shape = (batch_size,)` 37 | the number of possible triplets that can be formed is: 38 | batch_size^3, i.e. cube of batch_size, 39 | which can be represented as a tensor with `shape = (batch_size, batch_size, batch_size)`. 40 | However, a triplet is valid if: 41 | `labels[i] == labels[j] and labels[i] != labels[k]` 42 | and `i`, `j` and `k` are distinct indices. 
43 | This function calculates a mask indicating which ones of all the possible triplets 44 | are actually valid triplets based on the given criteria above. 45 | 46 | Args: 47 | labels (torch.Tensor): Labels associated with embeddings in the batch. Shape: (batch_size,) 48 | 49 | Returns: 50 | torch.Tensor: Triplet mask. Shape: (batch_size, batch_size, batch_size) 51 | """ 52 | indices_equal = torch.eye(labels.size()[0], dtype=torch.bool, device=labels.device) 53 | indices_not_equal = torch.logical_not(indices_equal) 54 | 55 | i_not_equal_j = indices_not_equal.unsqueeze(2) 56 | i_not_equal_k = indices_not_equal.unsqueeze(1) 57 | j_not_equal_k = indices_not_equal.unsqueeze(0) 58 | 59 | distinct_indices = torch.logical_and( 60 | torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k 61 | ) 62 | 63 | labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) 64 | i_equal_j = labels_equal.unsqueeze(2) 65 | i_equal_k = labels_equal.unsqueeze(1) 66 | valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k)) 67 | 68 | mask = torch.logical_and(distinct_indices, valid_indices) 69 | return mask 70 | 71 | 72 | def get_anchor_positive_mask( 73 | labels_a: torch.Tensor, labels_b: Optional[torch.Tensor] = None 74 | ) -> Tensor: 75 | """Creates a 2D mask of valid anchor-positive pairs. 76 | 77 | Args: 78 | labels_a (torch.Tensor): Labels associated with embeddings in the batch A. Shape: (batch_size_a,) 79 | labels_b (torch.Tensor): Labels associated with embeddings in the batch B. Shape: (batch_size_b,) 80 | If `labels_b is None`, it assigns `labels_a` to `labels_b`. 81 | 82 | Returns: 83 | torch.Tensor: Anchor-positive mask. Shape: (batch_size_a, batch_size_b) 84 | """ 85 | if labels_b is None: 86 | labels_b = labels_a 87 | 88 | mask = labels_a.expand(labels_b.shape[0], labels_a.shape[0]).t() == labels_b.expand( 89 | labels_a.shape[0], labels_b.shape[0] 90 | ) 91 | 92 | if torch.equal(labels_a, labels_b): 93 | indices_equal = torch.eye( 94 | labels_a.size()[0], dtype=torch.bool, device=labels_a.device 95 | ) 96 | indices_not_equal = torch.logical_not(indices_equal) 97 | mask = torch.logical_and(indices_not_equal, mask) 98 | 99 | return mask 100 | 101 | 102 | def get_anchor_negative_mask( 103 | labels_a: torch.Tensor, labels_b: Optional[torch.Tensor] = None 104 | ) -> Tensor: 105 | if labels_b is None: 106 | labels_b = labels_a 107 | 108 | mask = labels_a.expand(labels_b.shape[0], labels_a.shape[0]).t() != labels_b.expand( 109 | labels_a.shape[0], labels_b.shape[0] 110 | ) 111 | return mask 112 | 113 | 114 | class TripletLoss(torch.nn.Module): 115 | """Implements Triplet Loss as defined in https://arxiv.org/abs/1503.03832 116 | 117 | It supports batch-all, batch-hard and batch-semihard strategies for online triplet mining. 118 | 119 | Args: 120 | margin: Margin value to push negative examples 121 | apart. 122 | distance_metric_name: Name of the distance function, e.g., 123 | :class:`~quaterion.distances.Distance`. 124 | mining: Triplet mining strategy. One of 125 | `"all"`, `"hard"`, `"semi_hard"`. 126 | soft: If `True`, use soft margin variant of Hard Triplet Loss. Ignored in all other cases. 127 | """ 128 | 129 | def __init__( 130 | self, 131 | margin: float = 0.5, 132 | distance_metric: str = Distance.COSINE, 133 | mining: str = "hard", 134 | soft: bool = False 135 | ): 136 | super().__init__() 137 | 138 | mining_types = ["all", "hard", "semi_hard"] 139 | if mining not in mining_types: 140 | raise ValueError( 141 | f"Unrecognized mining strategy: {mining}. 
Must be one of {', '.join(mining_types)}" 142 | ) 143 | 144 | self._margin = margin 145 | self._mining = mining 146 | self._soft = soft 147 | 148 | if distance_metric == Distance.COSINE: 149 | self.distance_metric = Cosine() 150 | else: 151 | raise ValueError(f"Currently only cosine distance is implemented") 152 | 153 | def _hard_triplet_loss( 154 | self, 155 | embeddings_a: Tensor, 156 | groups_a: LongTensor, 157 | embeddings_b: Tensor, 158 | groups_b: LongTensor, 159 | ) -> Tensor: 160 | """ 161 | Calculates Triplet Loss with hard mining between two sets of embeddings. 162 | 163 | Args: 164 | embeddings_a: (batch_size_a, vector_length) - Batch of embeddings. 165 | groups_a: (batch_size_a,) - Batch of labels associated with `embeddings_a` 166 | embeddings_b: (batch_size_b, vector_length) - Batch of embeddings. 167 | groups_b: (batch_size_b,) - Batch of labels associated with `embeddings_b` 168 | 169 | Returns: 170 | torch.Tensor: Scalar loss value. 171 | """ 172 | dists = self.distance_metric.distance_matrix(embeddings_a, embeddings_b) 173 | 174 | anchor_positive_mask = get_anchor_positive_mask(groups_a, groups_b).float() 175 | anchor_positive_dists = anchor_positive_mask * dists 176 | hardest_positive_dists = anchor_positive_dists.max(dim=1)[0] 177 | 178 | anchor_negative_mask = get_anchor_negative_mask(groups_a, groups_b).float() 179 | anchor_negative_dists = dists + dists.max(dim=1, keepdim=True)[0] * ( 180 | 1.0 - anchor_negative_mask 181 | ) 182 | hardest_negative_dists = anchor_negative_dists.min(dim=1)[0] 183 | 184 | triplet_loss = ( 185 | F.softplus(hardest_positive_dists - hardest_negative_dists) 186 | if self._soft 187 | else F.relu( 188 | (hardest_positive_dists - hardest_negative_dists) 189 | / hardest_negative_dists.mean() 190 | + self._margin 191 | ) 192 | ) 193 | 194 | return triplet_loss.mean() 195 | 196 | def _semi_hard_triplet_loss( 197 | self, 198 | embeddings_a: Tensor, 199 | groups_a: Tensor, 200 | embeddings_b: Tensor, 201 | groups_b: Tensor, 202 | ) -> Tensor: 203 | """Compute triplet loss with semi-hard mining as described in https://arxiv.org/abs/1703.07737 204 | 205 | It encourages the positive distances to be smaller than the minimum negative distance 206 | among which are at least greater than the positive distance plus the margin 207 | (called semi-hard negative), 208 | i.e., D(a, p) < D(a, n) < D(a, p) + margin. 209 | If no such negative exists, it uses the largest negative distance instead. 210 | 211 | Inspired by https://github.com/tensorflow/addons/blob/master/tensorflow_addons/losses/triplet.py 212 | 213 | Args: 214 | embeddings_a: shape: (batch_size_a, vector_length) - Output embeddings from the 215 | encoder. 216 | groups_a: shape: (batch_size_a,) - Group ids associated with embeddings. 217 | embeddings_b: shape: (batch_size_b, vector_length) - Batch of embeddings 218 | groups_b: shape: (batch_size_b,) - Group ids associated with `embeddings_b` 219 | 220 | Returns: 221 | Tensor: zero-size tensor, semi-hard triplet loss value.
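        Worked example of the semi-hard condition (illustrative numbers): with margin = 0.5 and
        D(a, p) = 0.3, a negative at D(a, n) = 0.6 is semi-hard (0.3 < 0.6 < 0.8), a negative at
        D(a, n) = 0.9 is an easy one, and a negative at D(a, n) = 0.2 is a hard one.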
222 | """ 223 | distances = self.distance_metric.distance_matrix(embeddings_a, embeddings_b) 224 | 225 | positive_indices = groups_a[:, None] == groups_b[None, :] 226 | negative_indices = groups_a[:, None] != groups_b[None, :] 227 | 228 | pos_distance = torch.masked_select(distances, positive_indices) 229 | neg_distance = torch.masked_select(distances, negative_indices) 230 | 231 | basic_loss = pos_distance[:, None] - neg_distance[None, :] + self._margin 232 | zero_loss = torch.clamp(basic_loss, min=0.0) 233 | semi_hard_loss = torch.clamp(zero_loss, max=self._margin) 234 | 235 | return torch.mean(semi_hard_loss) 236 | 237 | def forward( 238 | self, 239 | embeddings: Tensor, 240 | groups: LongTensor, 241 | ) -> Tensor: 242 | """Calculates Triplet Loss with specified embeddings and labels. 243 | 244 | Args: 245 | embeddings: shape: (batch_size, vector_length) - Batch of embeddings. 246 | groups: shape: (batch_size,) - Batch of labels associated with `embeddings` 247 | 248 | Returns: 249 | torch.Tensor: Scalar loss value. 250 | """ 251 | if self._mining == "all": 252 | dists = self.distance_metric.distance_matrix(embeddings) 253 | 254 | anchor_positive_dists = dists.unsqueeze(2) 255 | anchor_negative_dists = dists.unsqueeze(1) 256 | triplet_loss = anchor_positive_dists - anchor_negative_dists + self._margin 257 | 258 | mask = get_triplet_mask(groups).float() 259 | triplet_loss = mask * triplet_loss 260 | triplet_loss = F.relu(triplet_loss) 261 | 262 | num_positive_triplets = torch.sum((triplet_loss > 1e-16).float()) 263 | triplet_loss = torch.sum(triplet_loss) / (num_positive_triplets + 1e-16) 264 | 265 | elif self._mining == "hard": 266 | triplet_loss = self._hard_triplet_loss( 267 | embeddings, groups, embeddings, groups 268 | ) 269 | else: 270 | triplet_loss = self._semi_hard_triplet_loss( 271 | embeddings, groups, embeddings, groups 272 | ) 273 | 274 | return triplet_loss 275 | 276 | def xbm_loss( 277 | self, 278 | embeddings: Tensor, 279 | groups: LongTensor, 280 | memory_embeddings: Tensor, 281 | memory_groups: LongTensor, 282 | ) -> Tensor: 283 | """Implement XBM loss computation for this loss. 284 | 285 | Args: 286 | embeddings: shape: (batch_size, vector_length) - Output embeddings from the 287 | encoder. 288 | groups: shape: (batch_size,) - Group ids associated with embeddings. 289 | memory_embeddings: shape: (memory_buffer_size, vector_length) - Embeddings stored 290 | in a ring buffer 291 | memory_groups: shape: (memory_buffer_size,) - Groups ids associated with `memory_embeddings` 292 | 293 | Returns: 294 | Tensor: zero-size tensor, XBM loss value. 
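        Illustrative usage (shapes only; the tensors are random, so the loss value itself is meaningless):

            loss_fn = TripletLoss(margin=0.5, mining="hard")
            emb, groups = torch.randn(8, 4), torch.randint(0, 2, (8,))
            mem_emb, mem_groups = torch.randn(32, 4), torch.randint(0, 2, (32,))
            xbm = loss_fn.xbm_loss(emb, groups, mem_emb, mem_groups)  # scalar tensor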
295 | """ 296 | if len(memory_groups) == 0 or self._mining == "all": 297 | return torch.tensor( 298 | 0, device=embeddings.device 299 | ) 300 | 301 | return ( 302 | self._hard_triplet_loss( 303 | embeddings, groups, memory_embeddings, memory_groups 304 | ) 305 | if self._mining == "hard" 306 | else self._semi_hard_triplet_loss( 307 | embeddings, groups, memory_embeddings, memory_groups 308 | ) 309 | ) 310 | 311 | 312 | def generate_test_data(batch_size: int = 6) -> Tuple[torch.Tensor, torch.Tensor]: 313 | embeddings = torch.randn(batch_size, 4) 314 | groups = torch.tensor([0, 0, 1, 1, 2, 2]) 315 | return embeddings, groups 316 | 317 | 318 | def test_cosine_distance(): 319 | x = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) 320 | distance = Cosine() 321 | result = distance.distance_matrix(x) 322 | expected = torch.tensor([[0.0, 1.0], [1.0, 0.0]]) 323 | torch.testing.assert_close(result, expected) 324 | 325 | 326 | def test_get_triplet_mask(): 327 | labels = torch.tensor([0, 0, 1, 1]) 328 | mask = get_triplet_mask(labels) 329 | assert mask.shape == (4, 4, 4) 330 | assert mask[0, 1, 2].item() 331 | assert not mask[0, 0, 1].item() 332 | 333 | 334 | def test_get_anchor_positive_mask(): 335 | labels = torch.tensor([0, 0, 1, 1]) 336 | mask = get_anchor_positive_mask(labels) 337 | expected = torch.tensor([ 338 | [False, True, False, False], 339 | [True, False, False, False], 340 | [False, False, False, True], 341 | [False, False, True, False] 342 | ]) 343 | torch.testing.assert_close(mask, expected) 344 | 345 | 346 | def test_get_anchor_negative_mask(): 347 | labels = torch.tensor([0, 0, 1, 1]) 348 | mask = get_anchor_negative_mask(labels) 349 | expected = torch.tensor([ 350 | [False, False, True, True], 351 | [False, False, True, True], 352 | [True, True, False, False], 353 | [True, True, False, False] 354 | ]) 355 | torch.testing.assert_close(mask, expected) 356 | 357 | 358 | def test_triplet_loss_init(): 359 | with pytest.raises(ValueError): 360 | TripletLoss(mining="invalid") 361 | 362 | with pytest.raises(ValueError): 363 | TripletLoss(distance_metric="manhattan") 364 | 365 | 366 | def test_triplet_loss_forward_all(): 367 | loss_fn = TripletLoss(margin=0.5, mining="all") 368 | embeddings, groups = generate_test_data() 369 | loss = loss_fn(embeddings, groups) 370 | assert isinstance(loss.item(), float) 371 | assert loss.item() >= 0 372 | 373 | 374 | def test_triplet_loss_forward_hard(): 375 | loss_fn = TripletLoss(margin=0.5, mining="hard") 376 | embeddings, groups = generate_test_data() 377 | loss = loss_fn(embeddings, groups) 378 | assert isinstance(loss.item(), float) 379 | assert loss.item() >= 0 380 | 381 | 382 | def test_triplet_loss_forward_semi_hard(): 383 | loss_fn = TripletLoss(margin=0.5, mining="semi_hard") 384 | embeddings, groups = generate_test_data() 385 | loss = loss_fn(embeddings, groups) 386 | assert isinstance(loss.item(), float) 387 | assert loss.item() >= 0 388 | 389 | 390 | def test_triplet_loss_soft(): 391 | loss_fn = TripletLoss(margin=0.5, mining="hard", soft=True) 392 | embeddings, groups = generate_test_data() 393 | loss = loss_fn(embeddings, groups) 394 | assert isinstance(loss.item(), float) 395 | assert loss.item() >= 0 396 | 397 | 398 | def test_triplet_loss_xbm(): 399 | loss_fn = TripletLoss(margin=0.5, mining="hard") 400 | embeddings, groups = generate_test_data(6) 401 | memory_embeddings = torch.randn(4, 4) 402 | memory_groups = torch.tensor([0, 0, 1, 1]) 403 | 404 | loss = loss_fn.xbm_loss(embeddings, groups, memory_embeddings, memory_groups) 405 | assert 
isinstance(loss.item(), float) 406 | assert loss.item() >= 0 407 | 408 | loss_fn = TripletLoss(margin=0.5, mining="semi_hard") 409 | loss = loss_fn.xbm_loss(embeddings, groups, memory_embeddings, memory_groups) 410 | assert isinstance(loss.item(), float) 411 | assert loss.item() >= 0 412 | 413 | empty_memory = torch.tensor([]) 414 | assert loss_fn.xbm_loss(embeddings, groups, empty_memory, empty_memory).item() == 0 415 | 416 | 417 | if __name__ == "__main__": 418 | test_functions = [ 419 | test_cosine_distance, 420 | test_get_triplet_mask, 421 | test_get_anchor_positive_mask, 422 | test_get_anchor_negative_mask, 423 | test_triplet_loss_init, 424 | test_triplet_loss_forward_all, 425 | test_triplet_loss_forward_hard, 426 | test_triplet_loss_forward_semi_hard, 427 | test_triplet_loss_xbm, 428 | test_triplet_loss_soft 429 | ] 430 | 431 | failed_tests = 0 432 | for test in test_functions: 433 | try: 434 | test() 435 | print(f"✓ {test.__name__}") 436 | except Exception as e: 437 | failed_tests += 1 438 | print(f"✗ {test.__name__}") 439 | print(f" Error: {str(e)}") 440 | 441 | if failed_tests: 442 | print(f"\n{failed_tests} tests failed") 443 | exit(1) 444 | else: 445 | print("\nAll tests passed successfully!") 446 | exit(0) 447 | -------------------------------------------------------------------------------- /mini_coil/visualize_encoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from mini_coil.data_pipeline.encode_and_filter import encode_and_filter 8 | from mini_coil.training.train_word import get_encoder 9 | 10 | 11 | def plot_embeddings( 12 | embeddings, 13 | special_point_x: float, 14 | special_point_y: float, 15 | save_path: str 16 | ): 17 | """ 18 | Plot scatter plot and also one additional red dot at special point 19 | """ 20 | import matplotlib.pyplot as plt 21 | 22 | plt.scatter(embeddings[:, 0], embeddings[:, 1], s=1) 23 | plt.scatter(special_point_x, special_point_y, color='red') 24 | plt.savefig(save_path) 25 | plt.close() 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--encoder-path", type=str) 31 | parser.add_argument("--embedding-path", type=str) 32 | parser.add_argument("--output-dir", type=str) 33 | parser.add_argument("--word", type=str) 34 | args = parser.parse_args() 35 | 36 | model_name = "jinaai/jina-embeddings-v2-small-en-tokens" 37 | word = args.word 38 | 39 | sentences = [ 40 | "The bat flew out of the cave at dusk, its wings silhouetted against the twilight sky.", 41 | "A small bat darted through the trees, barely visible in the moonlight", 42 | "Bat dung has been mined as guano from caves and used as fertiliser.", 43 | "pokemon bat is a flying type pokemon.", 44 | "She swung the baseball bat with precision, hitting the ball right out of the park.", 45 | "He gripped the cricket bat tightly, ready to face the next ball.", 46 | "She didn’t even bat an eyelash when he told her the shocking news.", 47 | "She didn't bat an eyelash when he announced the surprise, keeping her composure.", 48 | "in the batman comics, the bat was a superhero.", 49 | ] 50 | 51 | embeddings = np.array(list(encode_and_filter( 52 | model_name=model_name, 53 | word=word, 54 | docs=sentences 55 | ))) 56 | 57 | print(embeddings.shape) 58 | 59 | encoder = get_encoder(512, 4) 60 | 61 | encoder.load_state_dict(torch.load(args.encoder_path, weights_only=True)) 62 | 63 | encoder.eval() 64 | 65 | with torch.no_grad(): 66 | encoded = 
encoder(torch.from_numpy(embeddings).float()) 67 | encoded = encoded.cpu().numpy() 68 | 69 | for enc, sentence in zip(encoded.tolist(), sentences): 70 | print(enc, sentence) 71 | 72 | embeddings = np.load(args.embedding_path) 73 | 74 | if not os.path.exists(args.output_dir): 75 | os.makedirs(args.output_dir, exist_ok=True) 76 | 77 | embeddings_x = embeddings[:, 0] 78 | embeddings_y = embeddings[:, 1] 79 | 80 | # Normalize the embeddings to the length of 1 81 | 82 | length = np.sqrt(embeddings_x ** 2 + embeddings_y ** 2) 83 | 84 | # embeddings_x /= length 85 | # embeddings_y /= length 86 | 87 | for i in range(encoded.shape[0]): 88 | plot_embeddings( 89 | embeddings=np.column_stack([embeddings_x, embeddings_y]), 90 | special_point_x=encoded[i, 0], 91 | special_point_y=encoded[i, 1], 92 | save_path=os.path.join(args.output_dir, f"encoded_{i}.png") 93 | ) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "train-bm" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["generall "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.10,<3.13" 10 | pyarrow = "^17.0.0" 11 | pandas = "^2.2.2" 12 | ipdb = "^0.13.13" 13 | pystemmer = "^2.2.0.1" 14 | snowballstemmer = "^2.2.0" 15 | tqdm = "^4.66.4" 16 | torch = "^2.4.0" 17 | onnx = "^1.16.2" 18 | onnxruntime = "^1.18.1" 19 | transformers = "^4.44.0" 20 | npy-append-array = "^0.9.16" 21 | lightning = {extras = ["extra"], version = "^2.4.0"} 22 | sentence-splitter = "^1.4" 23 | sentence-transformers = "^3.0.1" 24 | qdrant-client = "^1.12.0" 25 | umap-learn = "^0.5.6" 26 | scipy = "^1.14.1" 27 | matplotlib = "^3.9.2" 28 | seaborn = "^0.13.2" 29 | fastembed = { git = "https://github.com/qdrant/fastembed.git", branch = "token-embeddings" } 30 | nltk = "^3.9.1" 31 | py-rust-stemmers = "^0.1.3" 32 | 33 | 34 | [build-system] 35 | requires = ["poetry-core"] 36 | build-backend = "poetry.core.masonry.api" 37 | -------------------------------------------------------------------------------- /run_eval.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | 3 | set -e # exit on error 4 | set -u # exit on using unset variable 5 | set -o pipefail # exit on error in pipe 6 | 7 | mkdir -p tests/em/t || true 8 | CURRENT_DIR=$(pwd -L) 9 | DIM=4 10 | 11 | WORD_TO_TEST=vector 12 | MINICOIL_MODEL=triplets-8000-mxbai-large-jina-small-${DIM}-augmented/full-models/model 13 | 14 | SENTENCES_FILE="${CURRENT_DIR}/data/validation/${WORD_TO_TEST}-validation.txt" 15 | 16 | 17 | function create_embeddings() { 18 | echo "Creating embeddings" 19 | python -m tests.01_embedding_maker \ 20 | --minicoil-test-word "${WORD_TO_TEST}" \ 21 | --input-file "${SENTENCES_FILE}" \ 22 | --vocab-path "${CURRENT_DIR}/data/${MINICOIL_MODEL}.vocab" \ 23 | --word-encoder-path "${CURRENT_DIR}/data/${MINICOIL_MODEL}.npy" \ 24 | --output-random "${CURRENT_DIR}/tests/em/random.npy" \ 25 | --output-mixedbread "${CURRENT_DIR}/tests/em/mixedbread.npy" \ 26 | --output-jina "${CURRENT_DIR}/tests/em/jina_small.npy" \ 27 | --output-minicoil "${CURRENT_DIR}/tests/em/minicoil.npy" \ 28 | --dim "${DIM}" 29 | } 30 | 31 | function create_matrix() { 32 | echo "Creating distance matrix: Minicoil" 33 | python -m tests.02_matrix_create \ 34 | --input "${CURRENT_DIR}/tests/em/minicoil.npy" \ 35 | --output "${CURRENT_DIR}/tests/em/distance_matrix_minicoil.npy" 36 | 37 | echo "Creating distance matrix: Mixedbread" 38 | python -m tests.02_matrix_create \ 39 | --input "${CURRENT_DIR}/tests/em/mixedbread.npy" \ 40 | --output "${CURRENT_DIR}/tests/em/distance_matrix_mixedbread.npy" 41 | 42 | echo "Creating distance matrix: Jina" 43 | python -m tests.02_matrix_create \ 44 | --input "${CURRENT_DIR}/tests/em/jina_small.npy" \ 45 | --output "${CURRENT_DIR}/tests/em/distance_matrix_jina_small.npy" 46 | 47 | echo "Creating distance matrix: Random" 48 | python -m tests.02_matrix_create \ 49 | --input "${CURRENT_DIR}/tests/em/random.npy" \ 50 | --output "${CURRENT_DIR}/tests/em/distance_matrix_random.npy" 51 | } 52 | 53 | function evaluate() { 54 | echo "Evaluating" 55 | python -m tests.03_matrix_triplets \ 56 | --distance-matrix-base-path "${CURRENT_DIR}/tests/em/distance_matrix_mixedbread.npy" \ 57 | --distance-matrix-eval-path "${CURRENT_DIR}/tests/em/distance_matrix_minicoil.npy" \ 58 | --sample-size 10000 \ 59 | --base-margin 0.1 \ 60 | --eval-margin 0.01 61 | 62 | python -m tests.03_matrix_triplets \ 63 | --distance-matrix-base-path "${CURRENT_DIR}/tests/em/distance_matrix_mixedbread.npy" \ 64 | --distance-matrix-eval-path "${CURRENT_DIR}/tests/em/distance_matrix_random.npy" \ 65 | --sample-size 10000 \ 66 | --base-margin 0.0000123 \ 67 | --eval-margin 0.0000123 68 | } 69 | 70 | 71 | 72 | 73 | 74 | create_embeddings 75 | create_matrix 76 | evaluate 77 | 78 | 79 | 80 | 81 | 82 | 83 | #################################### 84 | # UMAP COMPARE (disabled by default) 85 | 86 | function umap_compare() { 87 | python -m tests.04_umap_emb \ 88 | --embeddings "${CURRENT_DIR}/tests/em/mixedbread.npy" \ 89 | --output-umap "${CURRENT_DIR}/tests/em/mixedbread_umap.npy" \ 90 | --umap-components 4 \ 91 | --n-neighbors 20 92 | 93 | python -m tests.02_matrix_create \ 94 | --input "${CURRENT_DIR}/tests/em/mixedbread_umap.npy" \ 95 | --output "${CURRENT_DIR}/tests/em/distance_matrix_mixedbread_umap.npy" 96 | 97 | python -m tests.03_matrix_triplets \ 98 | --distance-matrix-base-path "${CURRENT_DIR}/tests/em/distance_matrix_mixedbread.npy" \ 99 | --distance-matrix-eval-path "${CURRENT_DIR}/tests/em/distance_matrix_mixedbread_umap.npy" \ 100 | --sample-size 100000 \ 101 | --base-margin 0.1 \ 102 | --eval-margin 
0.01 103 | } 104 | 105 | #umap_compare 106 | -------------------------------------------------------------------------------- /split_sentences.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Split sentences in OpenWebText 5 | 6 | 7 | ARCHIVE_DIR=data/openwebtext 8 | 9 | # Read files with the extension openwebtext-*.txt.gz and convert them to openwebtext-*-sentences.txt.gz 10 | 11 | for file in $ARCHIVE_DIR/openwebtext-*.txt.gz 12 | do 13 | # file = data/openwebtext/openwebtext-00.txt.gz 14 | 15 | file_name=$(basename $file) 16 | # file_name = openwebtext-00.txt.gz 17 | 18 | # Filename without extension 19 | 20 | file_name_no_ext=${file_name%.txt.gz} 21 | 22 | python -m mini_coil.data_pipeline.split_sentences --input-file $file --output-file $ARCHIVE_DIR/${file_name_no_ext}-sentences.txt.gz & 23 | done 24 | 25 | 26 | wait $(jobs -p) -------------------------------------------------------------------------------- /tests/01_embedding_maker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import tqdm 5 | from fastembed.text import TextEmbedding 6 | from sentence_transformers import SentenceTransformer 7 | 8 | from mini_coil.model.mini_coil_inference import MiniCOIL 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--input-file", type=str, required=True, help="Path to the input file containing sentences") 14 | parser.add_argument("--output-mixedbread", type=str, required=False, help="Path to the output file for mixedbread") 15 | parser.add_argument("--output-jina", type=str, required=False, help="Path to the output file for jina2") 16 | parser.add_argument("--output-minicoil", type=str, required=False, help="Path to the output file for minicoil") 17 | parser.add_argument("--output-random", type=str, required=False, help="Path to the output file for random") 18 | 19 | parser.add_argument("--vocab-path", type=str, required=True, help="Path to the vocabulary file (minicoil)") 20 | parser.add_argument("--word-encoder-path", type=str, required=True, help="Path to the word encoder file (minicoil)") 21 | parser.add_argument("--use-cuda", action="store_true", default=False, help="Use CUDA for jina2base") 22 | parser.add_argument("--minicoil-test-word", type=str, required=True, help="Word to test for minicoil") 23 | 24 | parser.add_argument("--dim", type=int, default=4, help="Output dimension for minicoil") 25 | 26 | args = parser.parse_args() 27 | 28 | with open(args.input_file, "r") as f: 29 | lines = [x.strip() for x in f if len(x.strip()) > 0] 30 | 31 | skip_mixed_bread = args.output_mixedbread is None 32 | skip_jina_small = args.output_jina is None 33 | skip_minicoil = args.output_minicoil is None 34 | skip_random = args.output_random is None 35 | 36 | if not skip_mixed_bread: 37 | model_mixed = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", trust_remote_code=True) 38 | if args.use_cuda: 39 | model_mixed = model_mixed.to("cuda") 40 | emb_mixed = model_mixed.encode(lines, batch_size=32, show_progress_bar=True) 41 | np.save(args.output_mixedbread, emb_mixed) 42 | del model_mixed 43 | 44 | if not skip_jina_small: 45 | model_jina = TextEmbedding("jinaai/jina-embeddings-v2-small-en", cuda=args.use_cuda) 46 | emb_jina = np.stack(list(model_jina.embed(tqdm.tqdm(lines), batch_size=8))) 47 | np.save(args.output_jina, emb_jina) 48 | del model_jina 49 | 50 | if not skip_random: 51 | emb_random = 
np.random.randn(len(lines), 768) 52 | np.save(args.output_random, emb_random) 53 | 54 | if not skip_minicoil: 55 | model_minicoil = MiniCOIL( 56 | vocab_path=args.vocab_path, 57 | word_encoder_path=args.word_encoder_path, 58 | sentence_encoder_model="jinaai/jina-embeddings-v2-small-en-tokens", 59 | ) 60 | emb_mc_list = [] 61 | for line in lines: 62 | row = model_minicoil.encode([line])[0] 63 | v = [] 64 | if args.minicoil_test_word not in row: 65 | zeros = np.zeros((model_minicoil.output_dim,)) 66 | emb_mc_list.append(zeros) 67 | else: 68 | emb_mc_list.append(row[args.minicoil_test_word]["embedding"]) 69 | emb_mc = np.stack(emb_mc_list) 70 | np.save(args.output_minicoil, emb_mc) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /tests/02_matrix_create.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | 6 | def calculate_distance_matrix(input_path, output_path): 7 | embeddings = np.load(input_path) 8 | norms = np.linalg.norm(embeddings, axis=1) 9 | normalized = embeddings / (norms[:, np.newaxis] + 1e-9) 10 | distances = 1 - normalized @ normalized.T 11 | mean_dist = np.mean(distances) 12 | median_dist = np.median(distances) 13 | np.save(output_path, distances) 14 | print(f"Matrix shape: {distances.shape}") 15 | print(f"Distance range: {distances.min():.6f} to {distances.max():.6f}") 16 | print(f"Mean distance: {mean_dist:.6f}") 17 | print(f"Median distance: {median_dist:.6f}") 18 | 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--input', required=True, help='Path to input embeddings .npy file') 23 | parser.add_argument('--output', required=True, help='Path to output similarity matrix .npy file') 24 | args = parser.parse_args() 25 | 26 | calculate_distance_matrix(args.input, args.output) 27 | -------------------------------------------------------------------------------- /tests/03_matrix_triplets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Iterable, Tuple 4 | 5 | import random 6 | import numpy as np 7 | 8 | 9 | def check_triplet(distance_matrix, anchor, positive, negative, margin): 10 | """ 11 | Example: 12 | 13 | negative_distance = 0.8 # bigger = less similar 14 | positive_distance = 0.5 # smaller = more similar 15 | 16 | margin = 0.1 17 | 18 | negative_distance - positive_distance = 0.3 # more than margin, therefore True 19 | 20 | --- 21 | 22 | negative_distance = 0.5 # less similar 23 | positive_distance = 0.8 # more similar 24 | 25 | margin = 0.1 26 | 27 | negative_distance - positive_distance = -0.3 # less than margin, therefore False 28 | """ 29 | 30 | pos_dist = distance_matrix[anchor, positive] 31 | neg_dist = distance_matrix[anchor, negative] 32 | 33 | return neg_dist - pos_dist > margin 34 | 35 | 36 | def sample_triplets(distance_matrix: np.ndarray, margin: float) -> Iterable[Tuple[int, int, int]]: 37 | size = distance_matrix.shape[0] 38 | while True: 39 | x, y, z = random.sample(range(size), 3) 40 | 41 | dxy, dxz, dyz = distance_matrix[x, y], distance_matrix[x, z], distance_matrix[y, z] 42 | 43 | x_anchor_dist = abs(dxy - dxz) 44 | y_anchor_dist = abs(dxy - dyz) 45 | z_anchor_dist = abs(dxz - dyz) 46 | 47 | if x_anchor_dist > margin and x_anchor_dist > y_anchor_dist and x_anchor_dist > z_anchor_dist: 48 | anchor = x 49 | if dxy > dxz: 50 | positive = z 
51 | negative = y 52 | else: 53 | positive = y 54 | negative = z 55 | yield anchor, positive, negative 56 | continue 57 | 58 | if y_anchor_dist > margin and y_anchor_dist > x_anchor_dist and y_anchor_dist > z_anchor_dist: 59 | anchor = y 60 | if dxy > dyz: 61 | positive = z 62 | negative = x 63 | else: 64 | positive = x 65 | negative = z 66 | yield anchor, positive, negative 67 | continue 68 | 69 | if z_anchor_dist > margin and z_anchor_dist > x_anchor_dist and z_anchor_dist > y_anchor_dist: 70 | anchor = z 71 | if dxz > dyz: 72 | positive = y 73 | negative = x 74 | else: 75 | positive = x 76 | negative = y 77 | yield anchor, positive, negative 78 | continue 79 | 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument("--distance-matrix-base-path", type=str, required=True) 84 | parser.add_argument("--distance-matrix-eval-path", type=str, required=True) 85 | parser.add_argument("--sample-size", type=int, required=True) 86 | parser.add_argument("--base-margin", type=float, default=0.2) 87 | parser.add_argument("--eval-margin", type=float, default=0.0) 88 | args = parser.parse_args() 89 | 90 | # Ground truth distance matrix 91 | base_distance_matrix = np.load(args.distance_matrix_base_path) 92 | 93 | # Evaluated distance matrix 94 | eval_distance_matrix = np.load(args.distance_matrix_eval_path) 95 | 96 | results = [] 97 | 98 | # Sample triplets from the base distance matrix 99 | n = 0 100 | for anchor, positive, negative in sample_triplets(base_distance_matrix, args.base_margin): 101 | n += 1 102 | if n >= args.sample_size: 103 | break 104 | 105 | results.append(check_triplet(eval_distance_matrix, anchor, positive, negative, args.eval_margin)) 106 | 107 | total_matches = sum(results) 108 | 109 | print(f"Total matches: {total_matches} out of {len(results)} ({total_matches / len(results):.2%})") 110 | -------------------------------------------------------------------------------- /tests/04_umap_emb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from umap import UMAP 5 | 6 | 7 | def create_umap_embeddings(output: str, 8 | embeddings: np.ndarray, 9 | n_components, 10 | n_neighbors, 11 | metric: str = 'cosine', ): 12 | 13 | # umap = UMAP( 14 | # metric="precomputed", 15 | # n_components=n_components, 16 | # output_metric="hyperboloid", 17 | # n_neighbors=n_neighbours, 18 | # ) 19 | 20 | umap = UMAP( 21 | n_components=n_components, 22 | n_neighbors=n_neighbors, 23 | metric=metric, 24 | random_state=42, 25 | output_metric="hyperboloid", 26 | ) 27 | 28 | umap_embeddings = umap.fit_transform(embeddings) 29 | np.save(output, umap_embeddings) 30 | 31 | 32 | def main(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--embeddings", type=str, required=True, 35 | help="Path to original embeddings .npy file") 36 | parser.add_argument("--output-umap", type=str, required=True, 37 | help="Path to save UMAP embeddings") 38 | parser.add_argument("--umap-components", type=int, 39 | help="Number of UMAP components") 40 | parser.add_argument("--n-neighbors", type=int, 41 | help="UMAP n_neighbors parameter") 42 | 43 | args = parser.parse_args() 44 | 45 | create_umap_embeddings( 46 | output=args.output_umap, 47 | embeddings=np.load(args.embeddings), 48 | n_components=args.umap_components, 49 | n_neighbors=args.n_neighbors, 50 | ) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | 
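
The three scripts above form the evaluation loop: tests/02_matrix_create.py turns embeddings into a cosine-distance matrix, tests/03_matrix_triplets.py measures how often triplet orderings taken from a base matrix survive in an evaluated matrix, and tests/04_umap_emb.py builds a UMAP baseline. A self-contained sketch of that loop on synthetic data follows; the random data, the cosine_distance_matrix helper, and the simplified triplet sampler (a plain margin check instead of the anchor-selection rule in tests/03_matrix_triplets.py) are all hypothetical, and only numpy is assumed.

import random

import numpy as np


def cosine_distance_matrix(embeddings: np.ndarray) -> np.ndarray:
    # Same construction as tests/02_matrix_create.py: 1 - cosine similarity.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalized = embeddings / (norms + 1e-9)
    return 1 - normalized @ normalized.T


rng = np.random.default_rng(42)
base_emb = rng.normal(size=(200, 32))                          # stand-in for the large "base" model
eval_emb = base_emb[:, :4] + 0.05 * rng.normal(size=(200, 4))  # stand-in for a compressed model

base_dm = cosine_distance_matrix(base_emb)
eval_dm = cosine_distance_matrix(eval_emb)

# Count how often triplets that are unambiguous in the base matrix (margin 0.1)
# keep the same ordering in the evaluated matrix.
matches, trials = 0, 0
while trials < 1000:
    a, p, n = random.sample(range(base_dm.shape[0]), 3)
    if base_dm[a, n] - base_dm[a, p] <= 0.1:
        continue  # the base model is not confident about this triplet, skip it
    trials += 1
    matches += int(eval_dm[a, n] - eval_dm[a, p] > 0.0)

print(f"Triplet agreement: {matches / trials:.2%}")

The printed ratio plays the same role as the "Total matches" line of tests/03_matrix_triplets.py: the closer it is to 1.0, the better the evaluated embeddings preserve the neighborhood structure of the base model.
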
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/miniCOIL/2d939737e89bacd797cb6c3e312135070a93a56d/tests/__init__.py -------------------------------------------------------------------------------- /tests/embed_minicoil.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | import numpy as np 5 | import tqdm 6 | 7 | from mini_coil.model.mini_coil_inference import MiniCOIL 8 | 9 | ANECDOTAL_EXAMPLES = { 10 | "vector": [ 11 | "vector search", 12 | "vector index", 13 | "vector space", 14 | "vector image", 15 | "vector graphics", 16 | "vector illustration", 17 | ], 18 | "bat": [ 19 | "baseball bat", 20 | "swing a bat", 21 | "aluminum bat", 22 | "bat in the cave", 23 | "vampire bat", 24 | "bat flying", 25 | ], 26 | "calcul": [ 27 | "I know how to use calculators", 28 | "open the calculator app", 29 | "calculating this won't be easy", 30 | "numerical calculations", 31 | ], 32 | "life": [ 33 | "How much money is sufficient to live a peaceful life?", 34 | "How much money do we need to live life happily?", 35 | ] 36 | } 37 | 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--vocab-path", type=str, required=True, help="Path to the vocabulary file (minicoil)") 42 | parser.add_argument("--word-encoder-path", type=str, required=True, help="Path to the word encoder file (minicoil)") 43 | parser.add_argument("--input-file", type=str, required=True, help="Path to the input file containing sentences") 44 | parser.add_argument("--output", type=str, required=True, help="Path to the output minicoil embeddings") 45 | parser.add_argument("--word", type=str, required=True, help="Word to test for minicoil") 46 | 47 | args = parser.parse_args() 48 | 49 | model_minicoil = MiniCOIL( 50 | vocab_path=args.vocab_path, 51 | word_encoder_path=args.word_encoder_path, 52 | sentence_encoder_model="jinaai/jina-embeddings-v2-small-en-tokens", 53 | ) 54 | emb_mc_list = [] 55 | 56 | lines = open(args.input_file).read().splitlines() 57 | 58 | encoded = model_minicoil.encode(tqdm.tqdm(lines), parallel=4) 59 | 60 | for row in encoded: 61 | if args.word not in row: 62 | zeros = np.zeros((model_minicoil.output_dim,)) 63 | emb_mc_list.append(zeros) 64 | else: 65 | emb_mc_list.append(row[args.word]["embedding"]) 66 | emb_mc = np.stack(emb_mc_list) 67 | 68 | # create output dir is not exists 69 | os.makedirs(os.path.dirname(args.output), exist_ok=True) 70 | 71 | np.save(args.output, emb_mc) 72 | 73 | # If the word happens to be in the anecdotal examples, print the embeddings 74 | if args.word in ANECDOTAL_EXAMPLES: 75 | examples = ANECDOTAL_EXAMPLES[args.word] 76 | encoded = model_minicoil.encode(examples) 77 | for example, row in zip(examples, encoded): 78 | word_emb = row[args.word]["embedding"] 79 | print(word_emb, " - ", example) 80 | 81 | 82 | if __name__ == "__main__": 83 | main() -------------------------------------------------------------------------------- /tests/visualize_embeddings.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import argparse 4 | import os 5 | 6 | def main(): 7 | """ 8 | Take numpy file with produced embeddings and visualize first 2 dimensions. 9 | 10 | Make big plot to see the distribution of the embeddings. 
11 | 12 | In addition to scatterplot, print histogram of circular distribution of the embeddings around 0,0. 13 | Take the angle of the embedding and plot histogram of the angles. 14 | 15 | """ 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--input", type=str, required=True, help="Path to the input numpy file") 19 | parser.add_argument("--output", type=str, required=True, help="Path to the output visualization") 20 | args = parser.parse_args() 21 | 22 | emb = np.load(args.input) 23 | 24 | plt.figure(figsize=(10, 10)) 25 | plt.scatter(emb[:, 0], emb[:, 1]) 26 | 27 | os.makedirs(os.path.dirname(args.output), exist_ok=True) 28 | 29 | plt.savefig(args.output + ".png") 30 | 31 | print(f"Visualization saved to {args.output}") 32 | 33 | embedding_angles = np.arctan2(emb[:, 1], emb[:, 0]) 34 | 35 | plt.figure(figsize=(10, 10)) 36 | plt.hist(embedding_angles, bins=100) 37 | 38 | plt.savefig(args.output + "_histogram.png") 39 | 40 | print(f"Histogram saved to {args.output}_histogram.png") 41 | 42 | if __name__ == "__main__": 43 | main() 44 | 45 | -------------------------------------------------------------------------------- /train-triplets.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | set -e # exit on error 4 | set -u # exit on using unset variable 5 | set -o pipefail # exit on error in pipe 6 | 7 | CURRENT_DIR=$(pwd -L) 8 | #COLLECTION_NAME=$1 9 | TARGET_WORD=${1:-"vector"} 10 | DIM=4 11 | SAMPLES=8000 12 | LMODEL=mxbai-large 13 | IMODEL=jina-small 14 | 15 | 16 | INPUT_DIR=data/triplets-${SAMPLES}-${LMODEL} 17 | 18 | MODEL_DIR=${INPUT_DIR}-${IMODEL}-${DIM}-augmented 19 | 20 | WORD_MODELS_DIR=${MODEL_DIR}/word-models 21 | FULL_MODEL_DIR=${MODEL_DIR}/full-models 22 | 23 | 24 | generate_distance_matrix() { 25 | echo "Generate distance matrix for given word" 26 | python -m mini_coil.data_pipeline.distance_matrix \ 27 | --word ${TARGET_WORD} \ 28 | --output-matrix ${INPUT_DIR}/distance_matrix/dm-${TARGET_WORD}.npy \ 29 | --output-sentences ${INPUT_DIR}/target_sentences/sentences-${TARGET_WORD}.jsonl \ 30 | --sample-size ${SAMPLES} 31 | } 32 | 33 | 34 | augment_data() { 35 | python -m mini_coil.data_pipeline.augment_data \ 36 | --input-file ${INPUT_DIR}/target_sentences/sentences-${TARGET_WORD}.jsonl \ 37 | --output-file ${INPUT_DIR}/target_sentences/sentences-${TARGET_WORD}-augmented.jsonl \ 38 | --target-word "${TARGET_WORD}" 39 | echo "Augmented data" 40 | } 41 | 42 | 43 | encode_sentences() { 44 | # Encode sentences with smaller transformer model 45 | python -m mini_coil.data_pipeline.encode_and_filter \ 46 | --sentences-file ${INPUT_DIR}/target_sentences/sentences-${TARGET_WORD}-augmented.jsonl \ 47 | --output-file ${INPUT_DIR}-${IMODEL}/word-emb-${TARGET_WORD}.npy \ 48 | --output-line-numbers-file ${INPUT_DIR}-${IMODEL}/line-numbers-${TARGET_WORD}.npy \ 49 | --word "${TARGET_WORD}" 50 | echo "Encoded sentences" 51 | } 52 | 53 | 54 | train_encoder() { 55 | #Train encoder **for each word** 56 | python -m mini_coil.training.train_word_triplet \ 57 | --embedding-path ${INPUT_DIR}-${IMODEL}/word-emb-${TARGET_WORD}.npy \ 58 | --line-numbers-path ${INPUT_DIR}-${IMODEL}/line-numbers-${TARGET_WORD}.npy \ 59 | --distance-matrix-path ${INPUT_DIR}/distance_matrix/dm-${TARGET_WORD}.npy \ 60 | --output-dim ${DIM} \ 61 | --output-path ${MODEL_DIR}/model-${TARGET_WORD}.ptch \ 62 | --log-dir ${INPUT_DIR}-${IMODEL}-augmented/train_logs/log_"${TARGET_WORD}" \ 63 | --epochs 60 64 | echo "Trained model" 65 | } 66 | 67 | 68 
| combine_models() { 69 | ## Merge encoders for each word into a single model 70 | python -m mini_coil.data_pipeline.combine_models \ 71 | --models-dir ${MODEL_DIR} \ 72 | --vocab-path "${CURRENT_DIR}/data/30k-vocab-filtered.json" \ 73 | --output-path ${FULL_MODEL_DIR}/model \ 74 | --output-dim "${DIM}" 75 | } 76 | 77 | 78 | download_validation_data() { 79 | # Download validation data 80 | python -m mini_coil.data_pipeline.download_validation \ 81 | --word ${TARGET_WORD} \ 82 | --output-sentences data/validation/${TARGET_WORD}-validation.txt 83 | } 84 | 85 | 86 | embed_sentences() { 87 | # Embed a bunch of sentences 88 | python -m tests.embed_minicoil \ 89 | --vocab-path ${FULL_MODEL_DIR}/model.vocab \ 90 | --word-encoder-path ${FULL_MODEL_DIR}/model.npy \ 91 | --input-file data/validation/${TARGET_WORD}-validation.txt \ 92 | --word ${TARGET_WORD} \ 93 | --output ${WORD_MODELS_DIR}/validation/${TARGET_WORD}-validation.npy 94 | } 95 | 96 | visualize_embeddings() { 97 | # Plot the embeddings 98 | python -m tests.visualize_embeddings \ 99 | --input ${WORD_MODELS_DIR}/validation/${TARGET_WORD}-validation.npy \ 100 | --output ${WORD_MODELS_DIR}/validation-viz/${TARGET_WORD}-plot 101 | } 102 | 103 | cleaenup() { 104 | rm ${INPUT_DIR}/distance_matrix/dm-${TARGET_WORD}.npy 105 | rm ${INPUT_DIR}/target_sentences/sentences-${TARGET_WORD}.jsonl 106 | rm ${INPUT_DIR}/target_sentences/sentences-${TARGET_WORD}-augmented.jsonl 107 | rm ${INPUT_DIR}-${IMODEL}/word-emb-${TARGET_WORD}.npy 108 | rm ${INPUT_DIR}-${IMODEL}/line-numbers-${TARGET_WORD}.npy 109 | } 110 | 111 | main() { 112 | # Skip if the model file already exists 113 | MODEL_FILE_NAME=${MODEL_DIR}/model-${TARGET_WORD}.ptch 114 | 115 | if [ -f "$MODEL_FILE_NAME" ]; then 116 | echo "Model file already exists. Skipping training." 117 | exit 0 118 | fi 119 | 120 | 121 | generate_distance_matrix 122 | augment_data 123 | encode_sentences 124 | train_encoder 125 | # combine_models 126 | # # Validate the model 127 | # download_validation_data 128 | # embed_sentences 129 | # visualize_embeddings 130 | cleaenup 131 | } 132 | 133 | main "$@" -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | set -e # exit on error 4 | set -u # exit on using unset variable 5 | set -o pipefail # exit on error in pipe 6 | 7 | CURRENT_DIR=$(pwd -L) 8 | #COLLECTION_NAME=$1 9 | TARGET_WORD=bat 10 | DIM=4 11 | SAMPLES=4000 12 | NEIGHBORS=20 13 | LMODEL=mxbai-large 14 | 15 | . venv/bin/activate || . 
.venv/bin/activate || true 16 | 17 | ## Convert dataset into readable format 18 | 19 | ## Split data into sentences 20 | #python -m mini_coil.data_pipeline.split_sentences \ 21 | # --input-file "${CURRENT_DIR}/data/openwebtext-1920-sentences-"${TARGET_WORD}".txt.gz" \ 22 | # --output-file "${CURRENT_DIR}/data/openwebtext-sentences/openwebtext-1920-splitted-"${TARGET_WORD}".txt.gz" 23 | 24 | ## Encode sentences with transformer model 25 | #python -m mini_coil.data_pipeline.encode_targets \ 26 | # --input-file "${CURRENT_DIR}/data/openwebtext-1920-sentences-"${TARGET_WORD}".txt.gz" \ 27 | # --output-file "${CURRENT_DIR}/data/output/openwebtext-1920-splitted-"${TARGET_WORD}"-encodings" 28 | 29 | ## Upload encoded sentences to Qdrant 30 | #python -m mini_coil.data_pipeline.upload_to_qdrant \ 31 | # --input-emb 32 | # --input-text 33 | # --collection-name ${COLLECTION_NAME} 34 | 35 | # Sample sentences with specified words and apply dimensionality reduction 36 | #python -m mini_coil.data_pipeline.compress_dimentions \ 37 | # --output-dir data/umap-"${SAMPLES}"-"${NEIGHBORS}"-"${DIM}"d-"${LMODEL}" \ 38 | # --sample-size "${SAMPLES}" --dim "${DIM}" --word "${TARGET_WORD}" --overwrite \ 39 | # --limit ${NEIGHBORS} --n_neighbours ${NEIGHBORS} 40 | 41 | echo "Compressed dimentions" 42 | ## Download sampled sentences 43 | python -m mini_coil.data_pipeline.load_sentences \ 44 | --word "${TARGET_WORD}" \ 45 | --matrix-dir data/umap-"${SAMPLES}"-"${NEIGHBORS}"-"${DIM}"d-"${LMODEL}" \ 46 | --output-dir data/umap-"${SAMPLES}"-"${NEIGHBORS}"-"${DIM}"d-"${LMODEL}"-sentences 47 | 48 | echo "Loaded sentences" 49 | ## Encode sentences with smaller transformer model 50 | python -m mini_coil.data_pipeline.encode_and_filter \ 51 | --sentences-file data/umap-"${SAMPLES}"-"${NEIGHBORS}"-"${DIM}"d-"${LMODEL}"-sentences/sentences-"${TARGET_WORD}".jsonl \ 52 | --output-file data/umap-"${SAMPLES}"-"${NEIGHBORS}"-"${DIM}"d-"${LMODEL}"-input/word-emb-"${TARGET_WORD}".npy \ 53 | --word "${TARGET_WORD}" \ 54 | --sample-size "${SAMPLES}" 55 | 56 | echo "Encoded sentences" 57 | #Train encoder **for each word** 58 | python -m mini_coil.training.train_word \ 59 | --embedding-path data/umap-"${SAMPLES}"-"${NEIGHBORS}"-"${DIM}"d-"${LMODEL}"-input/word-emb-"${TARGET_WORD}".npy \ 60 | --target-path data/umap-"${SAMPLES}"-"${NEIGHBORS}"-"${DIM}"d-"${LMODEL}"/compressed_matrix_"${TARGET_WORD}".npy \ 61 | --log-dir data/train_logs/log_"${TARGET_WORD}" \ 62 | --output-path data/umap-"${SAMPLES}"-"${NEIGHBORS}"-"${DIM}"d-"${LMODEL}"-models/model-"${TARGET_WORD}".ptch \ 63 | --epochs 500 64 | ## --gpu 65 | 66 | echo "Combined models" 67 | ## Merge encoders for each word into a single model 68 | python -m mini_coil.data_pipeline.combine_models \ 69 | --models-dir "${CURRENT_DIR}/data/umap-${SAMPLES}-${NEIGHBORS}-${DIM}d-${LMODEL}-models" \ 70 | --vocab-path "${CURRENT_DIR}/data/30k-vocab-filtered.txt" \ 71 | --output-path "data/model_${SAMPLES}_${DIM}d" \ 72 | --output-dim "${DIM}" 73 | 74 | -------------------------------------------------------------------------------- /unpack_openwebtext.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Unpack OpenWebText tar files 5 | 6 | ARCHIVE_DIR=data/openwebtext 7 | 8 | rm -rf $ARCHIVE_DIR/openwebtext_subset* 9 | 10 | rm -f $ARCHIVE_DIR/openwebtext 11 | 12 | for archive in $ARCHIVE_DIR/*.tar 13 | do 14 | # archive = data/openwebtext/urlsf_subset00.tar 15 | 16 | file_name=$(basename $archive) 17 | # file_name = urlsf_subset00.tar 18 
| 19 | # index = 00 20 | index=${file_name: -6:2} 21 | 22 | tar -xvf $archive -C $ARCHIVE_DIR 23 | 24 | mv $ARCHIVE_DIR/openwebtext $ARCHIVE_DIR/openwebtext_subset${index} 25 | 26 | python -m mini_coil.data_pipeline.convert_openwebtext \ 27 | --output-file $ARCHIVE_DIR/openwebtext-${index}.txt.gz \ 28 | --archive-dir $ARCHIVE_DIR/openwebtext_subset${index} & 29 | 30 | done 31 | 32 | 33 | 34 | wait $(jobs -p) --------------------------------------------------------------------------------
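
split_sentences.sh and unpack_openwebtext.sh share the same fan-out pattern: launch one backgrounded job per archive, then wait for all of them. The sketch below reproduces that orchestration in Python; it is hypothetical (not part of the repository), shells out to the same mini_coil.data_pipeline.convert_openwebtext entry point with the flags used above, and assumes the tar archives have already been extracted into the per-subset directories.

import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

ARCHIVE_DIR = Path("data/openwebtext")


def convert(subset_dir: Path) -> int:
    # data/openwebtext/openwebtext_subset00 -> index "00"
    index = subset_dir.name.removeprefix("openwebtext_subset")
    cmd = [
        sys.executable, "-m", "mini_coil.data_pipeline.convert_openwebtext",
        "--output-file", str(ARCHIVE_DIR / f"openwebtext-{index}.txt.gz"),
        "--archive-dir", str(subset_dir),
    ]
    return subprocess.run(cmd, check=True).returncode


if __name__ == "__main__":
    subsets = sorted(p for p in ARCHIVE_DIR.glob("openwebtext_subset*") if p.is_dir())
    # A bounded pool keeps the parallelism explicit instead of relying on
    # unbounded background shell jobs.
    with ThreadPoolExecutor(max_workers=4) as pool:
        list(pool.map(convert, subsets))
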