├── .github └── workflows │ └── test.yaml ├── .gitignore ├── .project-root ├── LICENSE ├── README.md ├── create_splits_of_varying_difficulty.py ├── pyproject.toml ├── src └── familiarity │ ├── __init__.py │ ├── embedding_models.py │ ├── logger.py │ ├── metric.py │ └── utils.py └── tests ├── conftest.py ├── test_embedding_models.py ├── test_logger.py ├── test_metric.py └── test_utils.py /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | paths-ignore: 7 | - "README.md" 8 | - ".gitignore" 9 | - "LICENSE" 10 | pull_request: 11 | branches: [ "main" ] 12 | paths-ignore: 13 | - "README.md" 14 | - ".gitignore" 15 | - "LICENSE" 16 | 17 | jobs: 18 | run: 19 | name: "Run Tests" 20 | runs-on: ubuntu-latest 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | python-version: ["3.9", "3.10", "3.11", "3.12"] 25 | 26 | steps: 27 | - uses: actions/checkout@v4 28 | 29 | - name: Set up Python ${{ matrix.python-version }} 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | 34 | # Set up caching for pip dependencies 35 | - name: Cache pip dependencies 36 | uses: actions/cache@v4 37 | with: 38 | path: ~/.cache/pip 39 | key: "pip-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }}" 40 | restore-keys: | 41 | pip-${{ matrix.python-version }}- 42 | 43 | # Install project in editable mode with dev dependencies 44 | - name: Install dependencies 45 | run: | 46 | python -m pip install --upgrade pip 47 | pip install -e .[dev] 48 | pip install -e .[testing] 49 | 50 | # Run Ruff for linting 51 | - name: Lint 52 | run: | 53 | ruff check . --fix 54 | 55 | # Run the tests 56 | - name: Run Tests 57 | run: | 58 | pytest --maxfail=5 --disable-warnings 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | .idea/
163 | .vscode/
164 | 
165 | results/
166 | 
--------------------------------------------------------------------------------
/.project-root:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flairNLP/familiarity/ed34cf9a62cb248c58e4fb9c518b599e512bfd39/.project-root
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Jonas Golde
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Label Shift Estimation for Named Entity Recognition using Familiarity
2 | 
3 | **Our paper was accepted to NAACL 2025 🎉 See our [paper](https://arxiv.org/abs/2412.10121) and find the datasets on the [Hugging Face Hub](https://huggingface.co/flair)!**
4 | 
5 | This repository computes the label shift for zero-shot NER settings using the Familiarity metric. The metric uses the semantic similarity between the set of labels seen during training and the set used for evaluation to indicate how "familiar" the trained model will be with the evaluation labels.
6 | 
7 | ## Installation
8 | ```bash
9 | conda create -n familiarity python=3.11
10 | conda activate familiarity
11 | pip install -e .
12 | ```
13 | 
14 | ## Usage
15 | ```python
16 | import numpy as np
17 | from familiarity import compute_metric
18 | train_labels_set = ["person", "location", "building", "eagle", "restaurant", "util"]
19 | train_probs = [0.4, 0.1, 0.1, 0.1, 0.1, 0.2]
20 | train_labels = np.random.choice(train_labels_set, size=30000, p=train_probs).tolist()
21 | 
22 | test_labels_set = ["human", "organization", "building", "review", "researcher", "car"]
23 | test_probs = [0.5, 0.2, 0.05, 0.05, 0.1, 0.1]
24 | test_labels = np.random.choice(test_labels_set, size=30000, p=test_probs).tolist()
25 | 
26 | compute_metric(
27 |     train_labels=train_labels,
28 |     test_labels=test_labels,
29 |     model_name_or_path="distilbert-base-cased",
30 |     save_results=True,
31 |     save_embeddings=True,
32 | )
33 | ```
34 | 
35 | If you want to create splits of varying difficulty as we did in the paper, please refer to the script `create_splits_of_varying_difficulty.py`.
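For reference, `compute_new_splits` in that script expects each dataset to be a list of dictionaries, one per sentence, with a tokenized text and inclusive token-level entity spans (see the docstring in the script; `create_splits` additionally requires the `new_splits.pkl` file produced by `compute_new_splits`). A purely illustrative data point, with made-up tokens and labels, could look like this:

```python
{
    "tokenized_text": ["Barack", "Obama", "visited", "Berlin", "."],
    "ner": [[0, 1, "person"], [3, 3, "location"]],
}
```

Besides Hugging Face transformer models, `compute_metric` also works with sentence-transformers models and static GloVe or fastText word vectors; the backend is inferred from the model name or file path (see `infer_embedding_model` in `src/familiarity/embedding_models.py`). A minimal sketch using a sentence-transformers model, reusing the labels from the usage example above:

```python
compute_metric(
    train_labels=train_labels,
    test_labels=test_labels,
    model_name_or_path="sentence-transformers/all-mpnet-base-v2",
)
```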
36 | 37 | ## Citation 38 | ``` 39 | @misc{golde2024familiaritybetterevaluationzeroshot, 40 | title={Familiarity: Better Evaluation of Zero-Shot Named Entity Recognition by Quantifying Label Shifts in Synthetic Training Data}, 41 | author={Jonas Golde and Patrick Haller and Max Ploner and Fabio Barth and Nicolaas Jedema and Alan Akbik}, 42 | year={2024}, 43 | eprint={2412.10121}, 44 | archivePrefix={arXiv}, 45 | primaryClass={cs.CL}, 46 | url={https://arxiv.org/abs/2412.10121}, 47 | } 48 | ``` 49 | -------------------------------------------------------------------------------- /create_splits_of_varying_difficulty.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | from typing import Dict, List 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from datasets import Dataset, DatasetDict 9 | from sentence_transformers import SentenceTransformer 10 | from torch.nn.functional import cosine_similarity 11 | from tqdm import tqdm 12 | 13 | 14 | def create_splits_for_hf_hub(train_dataset: str): 15 | # Dataset format should be a list of dictionaries, where each dictionary represents a data point. 16 | path_to_train_data = f"path/to/train/{train_dataset}.json" 17 | with open(path_to_train_data, "r") as f: 18 | data = json.load(f) 19 | 20 | for filter_by in ["entropy", "max"]: 21 | dataset_dict = DatasetDict() 22 | for setting in ["easy", "medium", "hard"]: 23 | new_split = create_splits( 24 | data, 25 | train_dataset, 26 | filter_by=filter_by, 27 | setting=setting, 28 | ) 29 | 30 | hf_format = [convert_to_hf_format(data_point) for data_point in new_split] 31 | 32 | ds = Dataset.from_pandas(pd.DataFrame(data=hf_format)) 33 | dataset_dict[setting] = ds 34 | 35 | dataset_dict.push_to_hub(f"{train_dataset}_{filter_by}_splits") 36 | 37 | 38 | def convert_to_hf_format(data_point): 39 | tags = ["O"] * len(data_point["tokenized_text"]) 40 | spans = [] 41 | for ent in data_point["ner"]: 42 | start, end, label = ent[0], ent[1], ent[2] 43 | spans.append({"start": start, "end": end, "label": label}) 44 | if start == end: 45 | tags[start] = "B-" + label 46 | else: 47 | try: 48 | tags[start] = "B-" + label 49 | tags[start + 1 : end + 1] = ["I-" + label] * (end - start) 50 | except IndexError: 51 | pass 52 | return {"tokens": data_point["tokenized_text"], "ner_tags": tags, "spans": spans} 53 | 54 | 55 | def create_splits( 56 | dataset: List[Dict], 57 | dataset_name: str, # The name of the dataset for which the splits should be created 58 | filter_by: str = "entropy", 59 | setting: str = "medium", 60 | ): 61 | try: 62 | df = pd.read_pickle("new_splits.pkl") 63 | except FileNotFoundError: 64 | raise FileNotFoundError("Please run the compute_new_splits function first to generate the data.") 65 | df = df[(df["train_dataset"] == dataset_name)] 66 | 67 | selected_entity_types = [] 68 | for benchmark_name in df["eval_dataset"].unique(): 69 | _df = df[(df["eval_dataset"] == benchmark_name)].copy() 70 | 71 | # The thresholds are dataset specific and may need to be adjusted to account for dataset with different characteristics 72 | if filter_by == "entropy": 73 | low_threshold = df[filter_by].quantile(0.01) 74 | high_threshold = df[filter_by].quantile(0.95) 75 | elif filter_by == "max": 76 | low_threshold = df[filter_by].quantile(0.05) 77 | high_threshold = df[filter_by].quantile(0.99) 78 | 79 | medium_lower_threshold = df[filter_by].quantile(0.495) 80 | medium_upper_threshold = df[filter_by].quantile(0.505) 81 | 82 | # Define 
conditions and choices for categorization 83 | conditions = [ 84 | _df[filter_by] <= low_threshold, # Bottom 85 | _df[filter_by].between(medium_lower_threshold, medium_upper_threshold), # Middle 86 | _df[filter_by] >= high_threshold, # Top 87 | ] 88 | choices = ["easy", "medium", "hard"] if filter_by == "entropy" else ["hard", "medium", "easy"] 89 | 90 | # Use np.select to create the new column based on the conditions 91 | _df["difficulty"] = np.select(conditions, choices, default="not relevant") 92 | 93 | selected_entity_types.extend(_df[_df["difficulty"] == setting]["entity"].tolist()) 94 | 95 | new_dataset = [] 96 | for dp in tqdm(dataset): 97 | matched_entities = [x for x in dp["ner"] if x[-1].lower().strip() in selected_entity_types] 98 | if matched_entities: 99 | new_np = copy.deepcopy(dp) 100 | new_np["ner"] = matched_entities 101 | new_dataset.append(new_np) 102 | 103 | return new_dataset 104 | 105 | 106 | def compute_new_splits(): 107 | # TODO: you need to load the data into two variables: 'benchmarks' and 'training_datasets'. 108 | # 'benchmarks' should be a dictionary with the benchmark names as keys and the (list of distinct) entity types as values. 109 | # 'training_datasets' should be a dictionary with the training dataset names as keys and the (list of distinct) entity types as values. 110 | # We process multiple benchmarks and training datasets in this example, but you can adjust the code to fit your needs. 111 | # Further, we stick with the following dataset layout: list of dictionaries, where each dictionary represents a data point. 112 | # For example: [{'tokenized_text': [...], 'ner': [(start, end, entity_type), ...]}, ...] 113 | 114 | benchmarks = {} 115 | for benchmark_name in ['path/to/eval/dataset1.json', 'path/to/eval/dataset2.json']: 116 | # Data loading logic here, e.g.: 117 | # tokens, entity_types = load_eval_dataset(benchmark_name) 118 | # benchmarks[benchmark_name] = list(entity_types) 119 | pass 120 | 121 | training_datasets = {} 122 | for train_dataset_name in ['path/to/train/dataset1.json', 'path/to/train/dataset2.json']: 123 | # Data loading logic here, e.g.: 124 | # tokens, entity_types = load_train_dataset(train_dataset_name) 125 | # training_datasets[train_dataset_name] = list(entity_types) 126 | pass 127 | 128 | batch_size = 256 129 | model = SentenceTransformer("all-mpnet-base-v2").to("cuda") 130 | eval_encodings = {} 131 | for benchmark_name, entity_types in benchmarks.items(): 132 | embeddings = model.encode(entity_types, convert_to_tensor=True, device="cuda") 133 | eval_encodings[benchmark_name] = embeddings 134 | 135 | results = {} 136 | for dataset_name, entity_types in training_datasets.items(): 137 | for i in tqdm(range(0, len(entity_types), batch_size)): 138 | dataset_name = dataset_name.split(".")[0] 139 | batch = entity_types[i : i + batch_size] 140 | embeddings = model.encode(batch, convert_to_tensor=True, device="cuda") 141 | for benchmark_name, eval_embeddings in eval_encodings.items(): 142 | similarities = torch.clamp( 143 | cosine_similarity( 144 | embeddings.unsqueeze(1), 145 | eval_embeddings.unsqueeze(0), 146 | dim=2, 147 | ), 148 | min=0.0, 149 | max=1.0, 150 | ) 151 | probabilities = torch.nn.functional.softmax(similarities / 0.01, dim=1) 152 | entropy_values = -torch.sum(probabilities * torch.log(probabilities + 1e-10), dim=1) 153 | max_values, _ = torch.max(similarities, dim=1) 154 | 155 | if dataset_name not in results: 156 | results[dataset_name] = {} 157 | if benchmark_name not in results[dataset_name]: 158 | 
results[dataset_name][benchmark_name] = {} 159 | 160 | for j, entity in enumerate(batch): 161 | if entity not in results[dataset_name][benchmark_name]: 162 | results[dataset_name][benchmark_name][entity] = {} 163 | results[dataset_name][benchmark_name][entity]["entropy"] = entropy_values[j].cpu().numpy().item() 164 | results[dataset_name][benchmark_name][entity]["max"] = max_values[j].cpu().numpy().item() 165 | 166 | entries = [] 167 | for dataset_name, eval_comparisons in results.items(): 168 | for benchmark_name, mapping in eval_comparisons.items(): 169 | for entity, values in mapping.items(): 170 | entries.append( 171 | { 172 | "entity": entity, 173 | "entropy": values["entropy"], 174 | "max": values["max"], 175 | "eval_dataset": benchmark_name, 176 | "train_dataset": dataset_name, 177 | } 178 | ) 179 | df = pd.DataFrame.from_dict(entries, orient="columns") 180 | df.to_pickle("new_splits.pkl") 181 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "familiarity" 3 | dynamic = ["version"] 4 | description = "Estimating label shift and transfer difficulty using Familiarity." 5 | authors = [{ name = "Jonas Golde", email = "jonas.max.golde@hu-berlin.de" }] 6 | readme = "README.md" 7 | requires-python = ">3.8" 8 | license = { file = "LICENSE" } 9 | dependencies = [ 10 | "transformers", 11 | "sentence-transformers", 12 | "datasets", 13 | "huggingface_hub", 14 | "numpy", 15 | "prettytable", 16 | "rootutils" 17 | ] 18 | classifiers = [ 19 | "Development Status :: 4 - Beta", 20 | "License :: OSI Approved :: MIT License", 21 | "Programming Language :: Python :: 3.8", 22 | "Programming Language :: Python :: 3.9", 23 | "Programming Language :: Python :: 3.10", 24 | "Programming Language :: Python :: 3.11", 25 | ] 26 | 27 | [build-system] 28 | requires = ["setuptools >= 61.0"] 29 | build-backend = "setuptools.build_meta" 30 | 31 | [project.optional-dependencies] 32 | testing = ["pytest"] 33 | dev = ["black", "isort", "ruff"] 34 | 35 | [tool.setuptools] 36 | packages = ["familiarity"] 37 | package-dir = { "" = "src" } 38 | 39 | [tool.black] 40 | target-version = ["py38"] 41 | line-length = 120 42 | skip-string-normalization = true 43 | 44 | [tool.isort] 45 | profile = "black" 46 | line_length = 120 47 | known_third_party = [ 48 | "transformers", 49 | "sentence_transformers", 50 | "datasets", 51 | "huggingface_hub", 52 | "numpy", 53 | "prettytable", 54 | "rootutils" 55 | ] 56 | 57 | [tool.ruff] 58 | line-length = 120 59 | 60 | [tool.ruff.lint] 61 | ignore = ["F405"] 62 | -------------------------------------------------------------------------------- /src/familiarity/__init__.py: -------------------------------------------------------------------------------- 1 | from familiarity.metric import compute_metric 2 | 3 | __all__ = ["compute_metric"] 4 | -------------------------------------------------------------------------------- /src/familiarity/embedding_models.py: -------------------------------------------------------------------------------- 1 | import io 2 | import re 3 | from abc import ABC, abstractmethod 4 | from pathlib import Path 5 | from typing import Dict, List, Tuple, Union 6 | 7 | import numpy as np 8 | import torch 9 | from huggingface_hub import repo_exists 10 | from sentence_transformers import SentenceTransformer 11 | from tqdm import tqdm 12 | from transformers import AutoModel, AutoTokenizer 13 | 14 | from familiarity.utils import 
get_device
15 | 
16 | 
17 | class LabelEmbeddingModel(ABC):
18 |     def __init__(self):
19 |         self.device = get_device()
20 | 
21 |     @abstractmethod
22 |     def embed(self, batch: List[str]) -> np.array:
23 |         """Abstract method to compute embeddings for a batch of words.
24 | 
25 |         Args:
26 |             batch: List of strings (words).
27 | 
28 |         Returns:
29 |             numpy array containing the embeddings.
30 |         """
31 |         pass
32 | 
33 | 
34 | class FastTextModel(LabelEmbeddingModel):
35 |     def __init__(self, model_path: Union[Path, str]):
36 |         super().__init__()
37 |         self.unk_token = "<unk>"  # sentinel for out-of-vocabulary words
38 |         self.pad_token = "<pad>"  # sentinel used to pad multi-word labels
39 |         words, embeddings = self.load_vectors(Path(model_path))
40 |         self.words = words
41 |         self.embeddings = embeddings.to(self.device)
42 | 
43 |     def load_vectors(self, model_name_or_path: Path) -> Tuple[List[str], torch.nn.Embedding]:
44 |         file = io.open(
45 |             model_name_or_path,
46 |             "r",
47 |             encoding="utf-8",
48 |             newline="\n",
49 |             errors="ignore",
50 |         )
51 |         n, d = map(int, file.readline().split())
52 |         words = []
53 |         embeddings = []
54 |         for line in tqdm(file.readlines(), desc="Loading FastText"):
55 |             tokens = line.strip().split(" ")
56 |             words.append(tokens[0])
57 |             embeddings.append(torch.tensor(list(map(float, tokens[1:]))))
58 | 
59 |         words = {w: i for i, w in enumerate(words)}
60 |         words[self.unk_token] = len(words)
61 |         words[self.pad_token] = len(words)
62 | 
63 |         embeddings = torch.stack(embeddings)
64 |         unk_embedding = torch.mean(embeddings, dim=0)
65 |         padding_embedding = torch.zeros(1, embeddings.size(1))
66 |         embeddings = torch.cat([embeddings, unk_embedding.unsqueeze(0), padding_embedding], dim=0)
67 |         embeddings = torch.nn.Embedding.from_pretrained(embeddings)
68 |         return words, embeddings
69 | 
70 |     def embed(self, batch: List[str]) -> np.array:
71 |         nested_batch = [re.split(r"[-/_ ]", label.lower()) for label in batch]
72 |         max_length = max(len(inner_list) for inner_list in nested_batch)
73 | 
74 |         input_ids = torch.LongTensor(
75 |             [
76 |                 [self.words.get(label, self.words.get(self.unk_token)) for label in labels]
77 |                 + [self.words.get(self.pad_token)] * (max_length - len(labels))
78 |                 for labels in nested_batch
79 |             ]
80 |         ).to(self.device)
81 | 
82 |         mask = input_ids != self.words.get(self.pad_token)
83 | 
84 |         embeddings = torch.sum(self.embeddings(input_ids), dim=1) / mask.sum(dim=1).unsqueeze(1)
85 | 
86 |         return embeddings.cpu().numpy()
87 | 
88 | 
89 | class GloveModel(LabelEmbeddingModel):
90 |     def __init__(self, model_path: Union[Path, str]):
91 |         super().__init__()
92 |         self.unk_token = "<unk>"  # sentinel for out-of-vocabulary words
93 |         self.pad_token = "<pad>"  # sentinel used to pad multi-word labels
94 |         words, embeddings = self.load_vectors(Path(model_path))
95 |         self.words = words
96 |         self.embeddings = embeddings.to(self.device)
97 | 
98 |     def load_vectors(self, model_name_or_path: Path) -> Tuple[Dict[str, int], torch.nn.Embedding]:
99 |         word_embedding_pairs = []
100 |         with open(model_name_or_path, "r", encoding="utf-8") as f:
101 |             for line in tqdm(f.readlines(), desc="Loading GloVe"):
102 |                 parts = line.split(" ")
103 |                 word = parts[0]
104 |                 vector = torch.tensor([float(x) for x in parts[1:]])
105 |                 word_embedding_pairs.append((word, vector))
106 | 
107 |         words, embeddings = zip(*word_embedding_pairs)
108 |         words = {w: i for i, w in enumerate(words)}
109 |         words[self.unk_token] = len(words)
110 |         words[self.pad_token] = len(words)
111 | 
112 |         embeddings = torch.stack(embeddings)
113 |         unk_embedding = torch.mean(embeddings, dim=0)
114 |         padding_embedding = torch.zeros(1, embeddings.size(1))
115 |         embeddings = torch.cat([embeddings, unk_embedding.unsqueeze(0),
padding_embedding], dim=0) 116 | embeddings = torch.nn.Embedding.from_pretrained(embeddings) 117 | return words, embeddings 118 | 119 | def embed(self, batch: List[str]) -> np.array: 120 | nested_batch = [re.split(r"[-/_ ]", label.lower()) for label in batch] 121 | max_length = max(len(inner_list) for inner_list in nested_batch) 122 | 123 | input_ids = torch.LongTensor( 124 | [ 125 | [self.words.get(label, self.words.get(self.unk_token)) for label in labels] 126 | + [self.words.get(self.pad_token)] * (max_length - len(labels)) 127 | for labels in nested_batch 128 | ] 129 | ).to(self.device) 130 | 131 | mask = input_ids != self.words.get(self.pad_token) 132 | 133 | embeddings = torch.sum(self.embeddings(input_ids), dim=1) / mask.sum(dim=1).unsqueeze(1) 134 | return embeddings.cpu().numpy() 135 | 136 | 137 | class TransformerModel(LabelEmbeddingModel): 138 | def __init__(self, model_name_or_path: Union[Path, str], pooling: str = "mean"): 139 | super().__init__() 140 | self.model = AutoModel.from_pretrained(model_name_or_path).to(self.device) 141 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 142 | self.pooling = pooling 143 | 144 | def embed(self, batch: List[str]) -> np.array: 145 | inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(self.device) 146 | 147 | with torch.no_grad(): 148 | outputs = self.model(**inputs) 149 | 150 | if self.pooling == "mean": 151 | return outputs.last_hidden_state.mean(dim=1).cpu().numpy() 152 | else: 153 | return outputs.last_hidden_state[:, 0, :].cpu().numpy() 154 | 155 | 156 | class SentenceTransformerModel(LabelEmbeddingModel): 157 | def __init__(self, model_name_or_path: Union[Path, str]): 158 | super().__init__() 159 | self.model = SentenceTransformer(model_name_or_path).to(self.device) 160 | 161 | def embed(self, batch: List[str]) -> np.array: 162 | embedding = self.model.encode(batch, convert_to_tensor=True) 163 | return embedding.cpu().numpy() 164 | 165 | 166 | def infer_embedding_model(model_name_or_path: Union[Path, str]): 167 | try: 168 | if "glove" in model_name_or_path: 169 | embedding_model_type = "glove" 170 | elif "sentence-transformers" in model_name_or_path: 171 | embedding_model_type = "sentence-transformers" 172 | elif "wiki-news" in model_name_or_path or "crawl" in model_name_or_path: 173 | embedding_model_type = "fasttext" 174 | elif repo_exists(model_name_or_path): 175 | embedding_model_type = "transformers" 176 | else: 177 | embedding_model_type = None 178 | except Exception as e: 179 | print(f"Could not infer model type: {e}") 180 | embedding_model_type = None 181 | 182 | return embedding_model_type 183 | 184 | 185 | def load_embedding_model(model_name_or_path: str, embedding_model_type: str = None) -> torch.nn.Embedding: 186 | if not embedding_model_type: 187 | embedding_model_type = infer_embedding_model(model_name_or_path) 188 | if not embedding_model_type: 189 | raise ValueError( 190 | "Embedding model type can't be inferred. Please provide any of ['glove', 'fasttext', 'sentence-transformers', 'transformers'] as 'embedding_model_type' argument." 191 | ) 192 | assert embedding_model_type in [ 193 | 'glove', 194 | 'fasttext', 195 | 'sentence-transformers', 196 | 'transformers', 197 | ], f"{embedding_model_type} is not supported. It must be any of ['glove', 'fasttext', 'sentence-transformers', 'transformers']." 
198 | 199 | if embedding_model_type == "glove": 200 | model = GloveModel(model_name_or_path) 201 | elif embedding_model_type == "sentence-transformers": 202 | model = SentenceTransformerModel(model_name_or_path) 203 | elif embedding_model_type == "fasttext": 204 | model = FastTextModel(model_name_or_path) 205 | elif embedding_model_type == "transformers": 206 | model = TransformerModel(model_name_or_path) 207 | return model 208 | -------------------------------------------------------------------------------- /src/familiarity/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging import Logger 3 | from pathlib import Path 4 | 5 | 6 | def setup_logger(output_path: Path) -> Logger: 7 | """Setup python logger.""" 8 | # Create the log file path 9 | log_file = output_path / "compute_metric.log" 10 | 11 | # Create a logger 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.INFO) 14 | 15 | # Clear any existing handlers (important for pytest re-runs) 16 | if logger.hasHandlers(): 17 | logger.handlers.clear() 18 | 19 | # Add file handler 20 | file_handler = logging.FileHandler(log_file) 21 | file_handler.setLevel(logging.INFO) 22 | file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) 23 | logger.addHandler(file_handler) 24 | 25 | # Optional: Add console handler 26 | console_handler = logging.StreamHandler() 27 | console_handler.setLevel(logging.INFO) 28 | console_handler.setFormatter(logging.Formatter("%(message)s")) 29 | logger.addHandler(console_handler) 30 | 31 | return logger 32 | -------------------------------------------------------------------------------- /src/familiarity/metric.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import Counter 3 | from pathlib import Path 4 | from typing import List, Union 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | from familiarity.embedding_models import LabelEmbeddingModel, load_embedding_model 11 | from familiarity.logger import setup_logger 12 | from familiarity.utils import ( 13 | clipped_cosine_similarity, 14 | combine_counters, 15 | cumsum_until, 16 | df_to_prettytable, 17 | iterate_dict_in_batches, 18 | make_output_path, 19 | ) 20 | 21 | 22 | def compute_embeddings( 23 | train_labels_count: Counter, 24 | test_labels_count: Counter, 25 | model: LabelEmbeddingModel, 26 | batch_size: int = 32, 27 | output_path: Path = None, 28 | save_embeddings: bool = False, 29 | ) -> pd.DataFrame: 30 | 31 | embedding_df = pd.DataFrame(columns=["label", "count_train", "count_test", "embedding"]) 32 | iterator = iterate_dict_in_batches(combine_counters(train_labels_count, test_labels_count), batch_size) 33 | for batch in tqdm(iterator, desc="Embedding Labels..."): 34 | words, counts = zip(*batch.items()) 35 | train_counts, test_counts = zip(*counts) 36 | embeddings = model.embed(words) 37 | embedding_df = pd.concat( 38 | [ 39 | embedding_df, 40 | pd.DataFrame( 41 | { 42 | "label": words, 43 | "count_train": train_counts, 44 | "count_test": test_counts, 45 | "embedding": list(embeddings), 46 | } 47 | ), 48 | ] 49 | ) 50 | 51 | if save_embeddings and output_path: 52 | embedding_df.to_pickle(output_path / "embedding_df.pkl") 53 | 54 | return embedding_df 55 | 56 | 57 | def compute_similarities( 58 | embedding_df: pd.DataFrame, 59 | output_path: Path = None, 60 | save_embeddings: bool = False, 61 | ) -> pd.DataFrame: 62 | 
train_df = ( 63 | embedding_df.loc[embedding_df["count_train"] > 0] 64 | .drop(columns=["count_test"]) 65 | .rename(columns={"count_train": "count"}) 66 | ) 67 | test_df = ( 68 | embedding_df.loc[embedding_df["count_test"] > 0] 69 | .drop(columns=["count_train"]) 70 | .rename(columns={"count_test": "count"}) 71 | ) 72 | similarity_df = pd.merge(train_df, test_df, how="cross", suffixes=["_train", "_test"]) 73 | similarity_df["similarity"] = similarity_df.apply( 74 | lambda row: clipped_cosine_similarity(row["embedding_train"], row["embedding_test"]), 75 | axis=1, 76 | ) 77 | similarity_df.drop(columns=["embedding_train", "embedding_test"], inplace=True) 78 | 79 | if save_embeddings: 80 | similarity_df.to_pickle(output_path / "similarity_df.pkl") 81 | 82 | return similarity_df 83 | 84 | 85 | def compute_familiarity( 86 | similarity_df: pd.DataFrame, 87 | k: int = 1000, 88 | weighting: str = "zipf", 89 | output_path: Path = None, 90 | save_embeddings: bool = False, 91 | ) -> pd.DataFrame: 92 | familiarity_data = [] 93 | 94 | for label_test in similarity_df["label_test"].unique(): 95 | test_label_df = similarity_df[similarity_df["label_test"] == label_test] 96 | test_label_df = test_label_df.sort_values("similarity", ascending=False) 97 | counts = cumsum_until(test_label_df["count_train"], k) 98 | sims = test_label_df["similarity"][: len(counts)] 99 | familiarity = weighted_average(sims, counts, k, weighting=weighting) 100 | familiarity_data.append({"label": label_test, "familiarity": familiarity}) 101 | 102 | familiarity_df = pd.DataFrame(familiarity_data) 103 | 104 | if save_embeddings: 105 | familiarity_df.to_pickle(output_path / "familiarity_df.pkl") 106 | 107 | return familiarity_df 108 | 109 | 110 | def weighted_average( 111 | similarities: List[float], 112 | counts: List[int], 113 | k: int, 114 | weighting: str = "zipf", 115 | ) -> float: 116 | if weighting not in ["unweighted", "linear_decay", "zipf"]: 117 | raise ValueError(f"Possible weighting options: unweighted, linear_decay, zipf. 
{weighting} is not an option.") 118 | 119 | if weighting == "unweighted": 120 | return np.dot(np.array(similarities), np.array(counts)) / k 121 | 122 | if weighting == "linear_decay": 123 | linear_decay_weights = np.arange(1, k + 1, 1)[::-1] / k 124 | return np.dot(linear_decay_weights, np.repeat(similarities, counts)) / np.sum(linear_decay_weights) 125 | 126 | if weighting == "zipf": 127 | zipf_weights = 1 / np.arange(1, k + 1, 1) 128 | return np.dot(zipf_weights, np.repeat(similarities, counts)) / np.sum(zipf_weights) 129 | 130 | 131 | def compute_metric( 132 | train_labels: List[str], 133 | test_labels: List[str], 134 | model_name_or_path: Union[Path, str], 135 | batch_size: int = 32, 136 | k: int = 1000, 137 | weighting: str = "zipf", 138 | save_results: bool = False, 139 | save_embeddings: bool = False, 140 | ) -> None: 141 | 142 | train_labels_count = Counter(train_labels) 143 | test_labels_count = Counter(test_labels) 144 | model = load_embedding_model(model_name_or_path) 145 | 146 | output_path = None 147 | if save_results or save_embeddings: 148 | output_path = make_output_path() 149 | 150 | logger = setup_logger(output_path) if output_path else logging.getLogger(__name__) 151 | logger.info(50 * '-') 152 | logger.info(f"Train Labels Counter: {train_labels_count}") 153 | logger.info(f"Test Labels Counter: {test_labels_count}") 154 | logger.info(f"Model: {model_name_or_path}") 155 | logger.info(f"k-cutoff: {k}") 156 | logger.info(50 * '-' + "\n") 157 | 158 | embedding_df = compute_embeddings( 159 | train_labels_count=train_labels_count, 160 | test_labels_count=test_labels_count, 161 | model=model, 162 | batch_size=batch_size, 163 | output_path=output_path, 164 | save_embeddings=save_embeddings, 165 | ) 166 | 167 | similarity_df = compute_similarities(embedding_df, output_path=output_path, save_embeddings=save_embeddings) 168 | familiarity_df = compute_familiarity( 169 | similarity_df, k=k, weighting=weighting, output_path=output_path, save_embeddings=save_embeddings 170 | ) 171 | logger.info("Results:\n") 172 | logger.info(df_to_prettytable(familiarity_df)) 173 | -------------------------------------------------------------------------------- /src/familiarity/utils.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from datetime import datetime 3 | from itertools import islice 4 | from pathlib import Path 5 | from typing import Any, Dict, Iterator, List, Tuple 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import rootutils 10 | import torch 11 | from prettytable import PrettyTable 12 | 13 | 14 | def get_device() -> str: 15 | """Determine if GPU available.""" 16 | return "cuda" if torch.cuda.is_available() else "cpu" 17 | 18 | 19 | def make_output_path(base_path: Path = None) -> Path: 20 | if not base_path: 21 | base_path = rootutils.find_root(search_from=__file__, indicator=".project-root") 22 | 23 | current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 24 | output_path = Path(base_path / f"results/{current_datetime}") 25 | output_path.mkdir(parents=True) 26 | return output_path 27 | 28 | 29 | def df_to_prettytable(df: pd.DataFrame) -> PrettyTable: 30 | # Create a PrettyTable instance with column headers 31 | df = df.round(3) 32 | table = PrettyTable() 33 | table.field_names = df.columns.tolist() 34 | 35 | # Add each row from the DataFrame to the PrettyTable 36 | for idx, row in df.iterrows(): 37 | table.add_row(row, divider=True if idx + 1 == len(df) else False) 38 | 39 | 
table.add_row(["Marco-Avg. Familiarity", round(df["familiarity"].mean().item(), 3)]) 40 | 41 | return table 42 | 43 | 44 | def cumsum_until(counts: List[int], k: int) -> List[int]: 45 | """Cummulative sum of list of counts until k entries.""" 46 | cumsum = 0 47 | result = [] 48 | 49 | for count in counts: 50 | if cumsum + count >= k: 51 | result.append(k - cumsum) 52 | break 53 | else: 54 | cumsum += count 55 | result.append(count) 56 | 57 | return result 58 | 59 | 60 | def clipped_cosine_similarity(vec1: np.array, vec2: np.array) -> float: 61 | """Cosine similarity between two numpy arrays.""" 62 | dot_product = np.dot(vec1, vec2) 63 | 64 | norm_a = np.linalg.norm(vec1) 65 | norm_b = np.linalg.norm(vec2) 66 | 67 | if norm_a == 0 or norm_b == 0: 68 | return 0 69 | cos_sim = dot_product / (norm_a * norm_b) 70 | return max(cos_sim, 0) 71 | 72 | 73 | def combine_counters(train_counter: Counter, test_counter: Counter) -> Dict[str, Tuple[int, int]]: 74 | """Create a combined dictionary where each entry is a tuple (train_count, test_count)""" 75 | combined = dict( 76 | sorted( 77 | { 78 | word: (train_counter.get(word, 0), test_counter.get(word, 0)) 79 | for word in set(train_counter) | set(test_counter) 80 | }.items() 81 | ) 82 | ) 83 | return combined 84 | 85 | 86 | def iterate_dict_in_batches(d: Dict[Any, Any], batch_size: int) -> Iterator: 87 | """Iterator over dictionary""" 88 | it = iter(d.items()) 89 | for _ in range(0, len(d), batch_size): 90 | batch = dict(islice(it, batch_size)) 91 | yield batch 92 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from familiarity.embedding_models import LabelEmbeddingModel 4 | 5 | 6 | @pytest.fixture(scope="module") 7 | def dummy_ner_train(): 8 | np.random.seed(42) 9 | train_labels_set = [ 10 | "person", 11 | "location", 12 | "building", 13 | "eagle", 14 | "restaurant", 15 | "util", 16 | ] 17 | train_probs = [0.4, 0.1, 0.1, 0.1, 0.1, 0.2] 18 | train_labels = np.random.choice(train_labels_set, size=30000, p=train_probs).tolist() 19 | 20 | return train_labels 21 | 22 | 23 | @pytest.fixture(scope="module") 24 | def dummy_ner_test(): 25 | np.random.seed(42) 26 | test_labels_set = [ 27 | "human", 28 | "organization", 29 | "building", 30 | "review", 31 | "researcher", 32 | "car", 33 | ] 34 | test_probs = [0.5, 0.2, 0.05, 0.05, 0.1, 0.1] 35 | test_labels = np.random.choice(test_labels_set, size=30000, p=test_probs).tolist() 36 | return test_labels 37 | 38 | 39 | @pytest.fixture(scope="module") 40 | def tiny_glove_file(tmp_path_factory): 41 | glove_content = "word1 0.1 0.2 0.3\n" "word2 0.4 0.5 0.6" 42 | glove_path = tmp_path_factory.mktemp("models") / "tiny_glove.txt" 43 | glove_path.write_text(glove_content) 44 | return glove_path 45 | 46 | 47 | @pytest.fixture(scope="module") 48 | def tiny_fasttext_file(tmp_path_factory): 49 | glove_content = "2 3\n" "word1 0.1 0.2 0.3\n" "word2 0.4 0.5 0.6" 50 | glove_path = tmp_path_factory.mktemp("models") / "tiny_fasttext.txt" 51 | glove_path.write_text(glove_content) 52 | return glove_path 53 | 54 | 55 | @pytest.fixture(scope="module") 56 | def sample_embedding_model(): 57 | class MockEmbeddingModel(LabelEmbeddingModel): 58 | def embed(self, batch) -> np.array: 59 | np.random.seed(42) 60 | return np.stack([np.random.rand(10) for word in batch]) 61 | 62 | return MockEmbeddingModel() 63 | 
-------------------------------------------------------------------------------- /tests/test_embedding_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from familiarity.embedding_models import ( 4 | FastTextModel, 5 | GloveModel, 6 | SentenceTransformerModel, 7 | TransformerModel, 8 | infer_embedding_model, 9 | load_embedding_model, 10 | ) 11 | 12 | 13 | def test_fasttext_model(tiny_fasttext_file): 14 | model = FastTextModel(model_path=tiny_fasttext_file) 15 | result = model.embed(["word1", "word2", "unknown_word"]) 16 | assert isinstance(result, np.ndarray) 17 | assert np.all(result[-1] == np.mean(result[:2], axis=0)) 18 | assert result.shape == (3, 3) 19 | 20 | 21 | def test_glove_model(tiny_glove_file): 22 | model = GloveModel(model_path=tiny_glove_file) 23 | result = model.embed(["word1", "word2", "unknown_word"]) 24 | assert isinstance(result, np.ndarray) 25 | assert np.all(result[-1] == np.mean(result[:2], axis=0)) 26 | assert result.shape == (3, 3) 27 | 28 | 29 | def test_transformer_model(): 30 | transformer_model = TransformerModel(model_name_or_path="distilbert-base-uncased") 31 | result = transformer_model.embed(["This is a test sentence.", "Another sentence."]) 32 | assert isinstance(result, np.ndarray) 33 | assert result.shape[0] == 2 # Two sentences 34 | 35 | 36 | def test_sentence_transformer_model(): 37 | model = SentenceTransformerModel("sentence-transformers/paraphrase-albert-small-v2") 38 | result = model.embed(["This is a test sentence.", "Another sentence."]) 39 | assert isinstance(result, np.ndarray) 40 | assert result.shape == (2, 768) 41 | 42 | 43 | @pytest.mark.parametrize( 44 | "model_path, expected_type", 45 | [ 46 | ("glove-xyz", "glove"), 47 | ("sentence-transformers/some-model", "sentence-transformers"), 48 | ("wiki-news-300d-1M-subword.vec", "fasttext"), 49 | ("distilbert-base-uncased", "transformers"), 50 | ], 51 | ) 52 | def test_infer_embedding_model(model_path, expected_type): 53 | assert infer_embedding_model(model_path) == expected_type 54 | 55 | 56 | @pytest.mark.parametrize("model_path", ["random-gibberish", "unknown-format", "1234!@#$"]) 57 | def test_infer_embedding_model_invalid(model_path): 58 | assert infer_embedding_model(model_path) is None 59 | 60 | 61 | @pytest.mark.parametrize("model_path", ["random-gibberish", "unknown-format", "1234!@#$"]) 62 | def test_load_embedding_model_invalid(model_path): 63 | with pytest.raises(ValueError, match="Embedding model type can't be inferred."): 64 | load_embedding_model(model_path) 65 | 66 | 67 | def test_load_embedding_model_glove(tiny_glove_file): 68 | model = load_embedding_model(model_name_or_path=tiny_glove_file, embedding_model_type="glove") 69 | assert isinstance(model, GloveModel) 70 | 71 | 72 | def test_load_embedding_model_sentence_transformer(): 73 | model = load_embedding_model( 74 | model_name_or_path="sentence-transformers/paraphrase-albert-small-v2", 75 | embedding_model_type="sentence-transformers", 76 | ) 77 | assert isinstance(model, SentenceTransformerModel) 78 | 79 | 80 | def test_load_embedding_model_fasttext(tiny_fasttext_file): 81 | model = load_embedding_model(model_name_or_path=tiny_fasttext_file, embedding_model_type="fasttext") 82 | assert isinstance(model, FastTextModel) 83 | 84 | 85 | def test_load_embedding_model_transformers(): 86 | model = load_embedding_model(model_name_or_path="distilbert-base-uncased", embedding_model_type="transformers") 87 | assert isinstance(model, 
TransformerModel) 88 | -------------------------------------------------------------------------------- /tests/test_logger.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from familiarity.logger import setup_logger 4 | 5 | 6 | def test_setup_logger(tmp_path: Path, capsys, caplog): 7 | output_path = tmp_path / "test_output" 8 | output_path.mkdir(parents=True, exist_ok=True) 9 | 10 | # Call setup_logger with a temporary path 11 | logger = setup_logger(output_path) 12 | 13 | # Check if the log file path is correct 14 | log_file = output_path / "compute_metric.log" 15 | assert log_file.exists(), "Log file was not created." 16 | 17 | # Test writing a log message 18 | test_message = "This is a test log message." 19 | logger.info(test_message) 20 | 21 | # Flush the handlers to ensure all output is written 22 | for handler in logger.handlers: 23 | handler.flush() 24 | 25 | # Check if the log message was written to the file 26 | with open(log_file, "r") as f: 27 | log_content = f.read() 28 | assert test_message in log_content, "Log file does not contain the expected log message." 29 | 30 | # Check log capture with caplog as a fallback 31 | assert any(test_message in record.message for record in caplog.records), "Log message not found in caplog records." 32 | -------------------------------------------------------------------------------- /tests/test_metric.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | import pytest 4 | from familiarity.metric import compute_embeddings, compute_familiarity, compute_similarities, weighted_average 5 | 6 | 7 | def test_compute_embeddings(dummy_ner_train, dummy_ner_test, sample_embedding_model, tmp_path): 8 | train_counter = Counter(dummy_ner_train) 9 | test_counter = Counter(dummy_ner_test) 10 | embedding_df = compute_embeddings( 11 | train_labels_count=train_counter, 12 | test_labels_count=test_counter, 13 | model=sample_embedding_model, 14 | output_path=tmp_path, 15 | save_embeddings=True, 16 | ) 17 | assert "embedding" in embedding_df.columns 18 | assert len(embedding_df) == 11 19 | assert embedding_df["count_train"].sum() == 30000 20 | assert embedding_df["count_test"].sum() == 30000 21 | assert (tmp_path / "embedding_df.pkl").exists() 22 | 23 | 24 | def test_compute_similarities(dummy_ner_train, dummy_ner_test, sample_embedding_model, tmp_path): 25 | train_counter = Counter(dummy_ner_train) 26 | test_counter = Counter(dummy_ner_test) 27 | embedding_df = compute_embeddings( 28 | train_labels_count=train_counter, 29 | test_labels_count=test_counter, 30 | model=sample_embedding_model, 31 | ) 32 | similarity_df = compute_similarities(embedding_df, output_path=tmp_path, save_embeddings=True) 33 | assert "similarity" in similarity_df.columns 34 | assert len(similarity_df) == (len(train_counter) * len(test_counter)) 35 | assert similarity_df["similarity"].min() >= 0 36 | assert similarity_df["similarity"].max() <= 1 37 | assert (tmp_path / "similarity_df.pkl").exists() 38 | 39 | 40 | def test_compute_familiarity(dummy_ner_train, dummy_ner_test, sample_embedding_model, tmp_path): 41 | train_counter = Counter(dummy_ner_train) 42 | test_counter = Counter(dummy_ner_test) 43 | embedding_df = compute_embeddings( 44 | train_labels_count=train_counter, 45 | test_labels_count=test_counter, 46 | model=sample_embedding_model, 47 | ) 48 | similarity_df = compute_similarities(embedding_df) 49 | familiarity_df = 
compute_familiarity( 50 | similarity_df, k=2, weighting="zipf", output_path=tmp_path, save_embeddings=True 51 | ) 52 | assert "familiarity" in familiarity_df.columns 53 | assert len(familiarity_df) == len(test_counter) 54 | assert pytest.approx(familiarity_df[familiarity_df["label"] == "building"]["familiarity"].iloc[0]) == 1 55 | assert pytest.approx(familiarity_df[familiarity_df["label"] == "car"]["familiarity"].iloc[0]) == 0.907777 56 | assert pytest.approx(familiarity_df[familiarity_df["label"] == "review"]["familiarity"].iloc[0]) == 0.912969 57 | assert (tmp_path / "familiarity_df.pkl").exists() 58 | 59 | 60 | @pytest.mark.parametrize( 61 | "sims, counts, weighting, k, gold_result", 62 | [ 63 | ([0.95, 0.8], [200, 200], "zipf", 400, 0.93420), 64 | ([0.95, 0.8, 0.7], [200, 500, 100], "zipf", 800, 0.91956), 65 | ([0.95, 0.8], [200, 200], "linear_decay", 400, 0.91240), 66 | ([0.95, 0.8, 0.7], [200, 500, 100], "linear_decay", 800, 0.86401), 67 | ([0.95, 0.8], [200, 200], "unweighted", 400, 0.875), 68 | ([0.95, 0.8, 0.7], [200, 500, 100], "unweighted", 800, 0.825), 69 | ], 70 | ) 71 | def test_weighted_average(sims, counts, weighting, k, gold_result): 72 | result = weighted_average(sims, counts, k=k, weighting=weighting) 73 | assert pytest.approx(result, 0.001) == gold_result 74 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter 3 | from pathlib import Path 4 | from typing import Any, Dict 5 | 6 | import numpy as np 7 | import pytest 8 | from familiarity.utils import ( 9 | clipped_cosine_similarity, 10 | combine_counters, 11 | cumsum_until, 12 | iterate_dict_in_batches, 13 | make_output_path, 14 | ) 15 | 16 | 17 | def test_make_output_path(tmp_path): 18 | output_path = make_output_path(base_path=tmp_path) 19 | 20 | assert isinstance(output_path, Path) 21 | assert os.path.exists(output_path) 22 | assert os.path.isdir(output_path) 23 | 24 | 25 | @pytest.mark.parametrize( 26 | "values, threshold, expected", 27 | [ 28 | ([10, 20, 30], 25, [10, 15]), # Threshold reached within two elements 29 | ([5, 10, 15, 20], 30, [5, 10, 15]), # Accumulated sum stops before reaching the last element 30 | ([5, 10, 15], 10, [5, 5]), # Threshold reached exactly with two elements 31 | ([50, 20], 40, [40]), # Single element truncated to meet the threshold 32 | ([5, 5, 5], 50, [5, 5, 5]), # Threshold not reached, all values returned 33 | ], 34 | ) 35 | def test_cumsum_until(values, threshold, expected): 36 | assert cumsum_until(values, threshold) == expected 37 | 38 | 39 | @pytest.mark.parametrize( 40 | "vec1, vec2, expected", 41 | [ 42 | (np.array([1, 0]), np.array([1, 0]), 1.0), # Identical vectors 43 | (np.array([1, 0]), np.array([-1, 0]), 0.0), # Opposite vectors, should clip to 0 44 | (np.array([1, 1]), np.array([1, 1]), 1.0), # Same direction 45 | (np.array([1, 0]), np.array([0, 1]), 0.0), # Orthogonal vectors 46 | (np.array([0, 0]), np.array([1, 1]), 0.0), # vec1 is zero vector, expect 0 47 | (np.array([1, 1]), np.array([0, 0]), 0.0), # vec2 is zero vector, expect 0 48 | (np.array([0, 0]), np.array([0, 0]), 0.0), # Both are zero vectors 49 | (np.array([1, 2, 3]), np.array([4, 5, 6]), 0.974631), # General positive case 50 | (np.array([1, -1]), np.array([-1, 1]), 0.0), # Negative similarity, should clip to 0 51 | ], 52 | ) 53 | def test_clipped_cosine_similarity(vec1, vec2, expected): 54 | result = 
clipped_cosine_similarity(vec1, vec2) 55 | assert result == pytest.approx(expected, 0.00001) 56 | 57 | 58 | @pytest.mark.parametrize( 59 | "train_counter, test_counter, expected", 60 | [ 61 | # Basic test case with some overlapping words 62 | ( 63 | Counter({"apple": 3, "banana": 2}), 64 | Counter({"banana": 4, "cherry": 1}), 65 | {"apple": (3, 0), "banana": (2, 4), "cherry": (0, 1)}, 66 | ), 67 | # Case where train_counter has unique words 68 | ( 69 | Counter({"apple": 3}), 70 | Counter({"banana": 4}), 71 | {"apple": (3, 0), "banana": (0, 4)}, 72 | ), 73 | # Case where both counters are empty 74 | ( 75 | Counter(), 76 | Counter(), 77 | {}, 78 | ), 79 | # Case where test_counter is empty 80 | ( 81 | Counter({"apple": 3, "banana": 2}), 82 | Counter(), 83 | {"apple": (3, 0), "banana": (2, 0)}, 84 | ), 85 | # Case where train_counter is empty 86 | ( 87 | Counter(), 88 | Counter({"banana": 4, "cherry": 1}), 89 | {"banana": (0, 4), "cherry": (0, 1)}, 90 | ), 91 | # Case with more complex counters 92 | ( 93 | Counter({"apple": 2, "banana": 1, "cherry": 5}), 94 | Counter({"banana": 3, "cherry": 2, "date": 4}), 95 | {"apple": (2, 0), "banana": (1, 3), "cherry": (5, 2), "date": (0, 4)}, 96 | ), 97 | ], 98 | ) 99 | def test_combine_counters(train_counter, test_counter, expected): 100 | result = combine_counters(train_counter, test_counter) 101 | assert result == expected 102 | 103 | 104 | @pytest.mark.parametrize( 105 | "d, batch_size, expected_batches", 106 | [ 107 | # Basic case: dictionary with 5 items and batch size of 2 108 | ( 109 | {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}, 110 | 2, 111 | [{"a": 1, "b": 2}, {"c": 3, "d": 4}, {"e": 5}], 112 | ), 113 | # Dictionary with fewer items than batch size 114 | ( 115 | {"a": 1, "b": 2}, 116 | 5, 117 | [{"a": 1, "b": 2}], 118 | ), 119 | # Dictionary with exact multiples of batch size 120 | ( 121 | {"a": 1, "b": 2, "c": 3, "d": 4}, 122 | 2, 123 | [{"a": 1, "b": 2}, {"c": 3, "d": 4}], 124 | ), 125 | # Dictionary with batch size of 1 (each item in its own batch) 126 | ( 127 | {"a": 1, "b": 2, "c": 3}, 128 | 1, 129 | [{"a": 1}, {"b": 2}, {"c": 3}], 130 | ), 131 | # Empty dictionary should yield no batches 132 | ( 133 | {}, 134 | 2, 135 | [], 136 | ), 137 | ], 138 | ) 139 | def test_iterate_dict_in_batches(d: Dict[Any, Any], batch_size: int, expected_batches: list[Dict[Any, Any]]): 140 | result_batches = list(iterate_dict_in_batches(d, batch_size)) 141 | assert result_batches == expected_batches 142 | --------------------------------------------------------------------------------
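The `zipf` weighting in `weighted_average` (`src/familiarity/metric.py`) weights the k most similar training-label occurrences by 1/rank, so frequent and highly similar training labels dominate the score. The following standalone snippet (illustrative, NumPy only) reproduces the first gold value from `test_weighted_average` in `tests/test_metric.py`:

```python
import numpy as np

# weighted_average([0.95, 0.8], [200, 200], k=400, weighting="zipf"):
# expand the similarities by their train counts and weight them by 1/rank.
sims, counts, k = [0.95, 0.8], [200, 200], 400
zipf_weights = 1 / np.arange(1, k + 1)
value = np.dot(zipf_weights, np.repeat(sims, counts)) / np.sum(zipf_weights)
print(round(value, 5))  # ~0.9342, matching the gold value in the test
```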