├── .gitignore ├── Example.ipynb ├── LICENSE ├── README.md ├── images ├── LEMON.png └── explanation.png ├── poetry.lock ├── poetry.toml ├── pyproject.toml └── src └── lemon ├── __init__.py ├── _lemon.py ├── _lemon_utils.py ├── _matching_attribution_explanation.py └── utils ├── __init__.py ├── datasets ├── __init__.py ├── _dataset.py ├── _utils.py └── deepmatcher.py └── matchers ├── __init__.py ├── _magellan.py └── _transformer_matcher.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Nils Barlaug 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LEMON: Explainable Entity Matching 2 | 3 | ![Illustration of LEMON](images/LEMON.png) 4 | 5 | LEMON is an explainability method that addresses the unique challenges of explaining entity matching models. 6 | 7 | 8 | ## Installation 9 | 10 | ```shell 11 | pip install lemon-explain 12 | ``` 13 | or 14 | ```shell 15 | pip install lemon-explain[storage] # Save and load explanations 16 | pip install lemon-explain[matchers] # To run matchers in lemon.utils 17 | pip install lemon-explain[all] # All dependencies 18 | ``` 19 | 20 | ## Usage 21 | 22 | ```python 23 | import lemon 24 | 25 | 26 | # You need a matcher that follows this api: 27 | def predict_proba(records_a, records_b, record_id_pairs): 28 | ... # predict probabilities / confidence scores 29 | return proba 30 | 31 | exp = lemon.explain(records_a, records_b, record_id_pairs, predict_proba) 32 | 33 | # exp can be visualized in a Jupyter notebook or saved to a json file 34 | exp.save("explanation.json") 35 | 36 | ``` 37 | [See the example notebook](https://nbviewer.jupyter.org/github/NilsBarlaug/lemon/blob/main/Example.ipynb) 38 | 39 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NilsBarlaug/lemon/blob/main/Example.ipynb) 40 | 41 | ![Example of explanation from LEMON](images/explanation.png) 42 | 43 | ## Documentation 44 | 45 | ### lemon.explain() 46 | 47 | ``` 48 | lemon.explain( 49 | records_a: pd.DataFrame, 50 | records_b: pd.DataFrame, 51 | record_id_pairs: pd.DataFrame, 52 | predict_proba: Callable, 53 | *, 54 | num_features: int = 5, 55 | dual_explanation: bool = True, 56 | estimate_potential: bool = True, 57 | granularity: str = "counterfactual", 58 | num_samples: int = None, 59 | token_representation: str = "record-bow", 60 | token_patterns: Union[str, 
List[str], Dict] = "[^ ]+", 61 | explain_attrs: bool = False, 62 | attribution_method: str = "lime", 63 | show_progress: bool = True, 64 | random_state: Union[int, np.random.Generator, None] = 0, 65 | return_dict: bool = None, 66 | ) -> Union[MatchingAttributionExplanation, Dict[any, MatchingAttributionExplanation]]: 67 | ``` 68 | 69 | #### Parameters 70 | - **records_a** : pd.DataFrame 71 | - Records from data source a. 72 | - **records_b** : pd.DataFrame 73 | - Records from data source b. 74 | - **record_id_pairs** : pd.DataFrame 75 | - Which record pairs to explain. 76 | Must be a pd.DataFrame with columns `"a.rid"` and `"b.rid"` that reference the index of `records_a` and `records_b` respectively. 77 | - **predict_proba** : Callable 78 | - Matcher function that predicts the probability of match. 79 | Must accept three arguments: `records_a`, `records_b`, and `record_id_pairs`. 80 | Should return array-like (list, np.ndarray, pd.Series, ...) of floats between 0 and 1 - the predicted probability that a record pair is a match - for all record pairs described by `record_id_pairs` in the same order. 81 | - **num_features** : int, default = 5 82 | - The number of features to select for the explanation. 83 | - **dual_explanation** : bool, default = True 84 | - Whether to use dual explanations or not. 85 | - **estimate_potential** : bool, default = True 86 | - Whether to estimate potential or not. 87 | - **granularity** : {"tokens", "attributes", "counterfactual"}, default = "counterfactual" 88 | - The granularity of the explanation. 89 | For more info on `"counterfactual"` granularity see our paper. 90 | - **num_samples** : int, default = None 91 | - The number of neighborhood samples to use. 92 | If None a heuristic will automatically pick the number of samples. 93 | - **token_representation** : {"independent", "shared-bow", "record-bow"}, default = "record-bow" 94 | - Which token representation to use. 95 | - independent: All tokens are unique. 
96 | - shared-bow: Bag-of-words representation shared across both records 97 | - record-bow: Bag-of-words representation per individual record 98 | - **token_patterns** : str, List[str], or Dict, default = `"[^ ]+"` 99 | - Regex patterns for valid tokens in strings. 100 | A single string will be interpreted as a regex pattern and all strings will be tokenized into non-overlapping matches of this pattern. 101 | You can specify a list of patterns to tokenize into non-overlapping matches of any pattern. 102 | For fine-grained control of how different parts of records are tokenized you can provide a dictionary with keys on the format `("a" or "b", attribute_name, "attr" or "val")` and values that are lists of token regex patterns. 103 | - **explain_attrs** : bool, default = False 104 | - Whether to explain attribute names or not. 105 | If True, `predict_proba` should accept the keyword argument `attr_strings` - a list that specifies what strings to use as attributes for each prediction. 106 | Each list element is on the format {("a" or "b", record_column_name): attr_string}. 107 | - **attribution_method** : {"lime", "shap"}, default = "lime" 108 | - Which underlying method to use for contribution estimation. 109 | Note that in order to use shap `estimate_potential` must be False and the shap package must be installed. 110 | - **show_progress** : bool, default = True 111 | - Whether to show progress or not. This is passed to `predict_proba` if it accepts this parameter. 112 | - **return_dict** : bool, default = None 113 | - If True a dictionary of explanations will be returned where the keys are labels from the index of `record_id_pairs`. 114 | If False a single explanation will be returned (an exception is raised if `len(record_id_pairs) > 1`). 115 | If None it will return a single explanation if `len(record_id_pairs) == 1` and a dictionary otherwise. 
116 | 117 | #### Returns 118 | `lemon.MatchingAttributionExplanation` instance or a `Dict[any, lemon.MatchingAttributionExplanation]`, 119 | depending on the input to the `return_dict` parameter. 120 | 121 | 122 | ### lemon.MatchingAttributionExplanation 123 | 124 | #### Attributes 125 | - **record_pair** : pd.DataFrame 126 | - **string_representation** : Dict[Tuple, Union[None, str, TokenizedString]] 127 | - **attributions** : List[Attribution] 128 | - **prediction_score** : float 129 | - **dual** : bool 130 | - **metadata** : Dict[str, any] 131 | 132 | #### Methods 133 | - **save(path: str = None) -> Optional[Dict]** 134 | - Save the explanation to a json file. 135 | If path is not specified a json-serializable dictionary will be returned. 136 | Requires pyarrow to be installed (`pip install lemon-explain[storage]`). 137 | - **static load(path: Union[str, Dict]) -> MatchingAttributionExplanation** 138 | - Load an explanation from a json file. 139 | Instead of a path, one can provide a json-serializable dictionary. 140 | Requires pyarrow to be installed (`pip install lemon-explain[storage]`). 
141 | 142 | ### lemon.Attribution 143 | 144 | #### Attributes 145 | - **weight**: float 146 | - **potential**: Optional[float] 147 | - **positions**: List[Union[Tuple[str, str, str, Optional[int]]]] 148 | - **name**: Optional[str] 149 | -------------------------------------------------------------------------------- /images/LEMON.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NilsBarlaug/lemon/ee82f20253c50eb5a958fc5507b0df8ca51fa317/images/LEMON.png -------------------------------------------------------------------------------- /images/explanation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NilsBarlaug/lemon/ee82f20253c50eb5a958fc5507b0df8ca51fa317/images/explanation.png -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | in-project = true 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "lemon-explain" 3 | version = "0.1.1" 4 | description = "LEMON: Explainable Entity Matching" 5 | license = "MIT" 6 | authors = ["Nils Barlaug "] 7 | readme = "README.md" 8 | homepage = "https://github.com/NilsBarlaug/lemon" 9 | repository = "https://github.com/NilsBarlaug/lemon" 10 | documentation = "https://github.com/NilsBarlaug/lemon" 11 | packages = [ 12 | { include = "lemon", from = "src" } 13 | ] 14 | 15 | [tool.poetry.dependencies] 16 | python = "^3.7" 17 | pandas = "^1" 18 | scikit-learn = "^1.0" 19 | transformers = {version = "^4.10.3", optional = true} 20 | torch = {version = "^1.9.1", optional = true} 21 | py-entitymatching = {version = "^0.4.0", optional = true} 22 | pyarrow = {version = "^5.0.0", optional = true} 23 | 
24 | [tool.poetry.extras] 25 | matchers = ["transformers", "torch", "py_entitymatching"] 26 | storage = ["pyarrow"] 27 | all = ["transformers", "torch", "py_entitymatching", "pyarrow"] 28 | 29 | [tool.poetry.dev-dependencies] 30 | black = "^21.9b0" 31 | isort = "^5.9.3" 32 | jupyterlab = "^3.1.14" 33 | 34 | [tool.black] 35 | line-length = 120 36 | target_version = ['py38'] 37 | include = '\.py$' 38 | 39 | [tool.isort] 40 | line_length=120 41 | multi_line_output=3 42 | include_trailing_comma=true 43 | skip_glob = '^((?!py$).)*$' 44 | known_third_party = [] 45 | 46 | [build-system] 47 | requires = ["poetry-core>=1.0.0"] 48 | build-backend = "poetry.core.masonry.api" 49 | -------------------------------------------------------------------------------- /src/lemon/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils 2 | from ._lemon import explain 3 | from ._lemon_utils import TokenizedString 4 | from ._matching_attribution_explanation import Attribution, MatchingAttributionExplanation 5 | 6 | __all__ = ["explain", "utils", "Attribution", "MatchingAttributionExplanation", "TokenizedString"] 7 | 8 | 9 | def __dir__(): 10 | return __all__ 11 | -------------------------------------------------------------------------------- /src/lemon/_lemon.py: -------------------------------------------------------------------------------- 1 | import math 2 | import re 3 | import warnings 4 | from typing import Callable, Dict, List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from ._lemon_utils import TokenizedString, get_predictions_scores_for_perturbed_record_pairs, perturb_record_pair 10 | from ._matching_attribution_explanation import ( 11 | Attribution, 12 | MatchingAttributionExplanation, 13 | explanation_counterfactual_strength, 14 | ) 15 | 16 | 17 | def _reduce_tokenized_string_granularity(s: TokenizedString, n: int) -> TokenizedString: 18 | num_tokens = len(s) 19 | 
num_output_tokens = math.ceil(num_tokens / n) 20 | group_size = len(s) / num_output_tokens 21 | num_tokens_consumed = 0 22 | 23 | for i in range(num_output_tokens): 24 | to_be_consumed = n 25 | while num_tokens_consumed + to_be_consumed >= (i + 1) * group_size + 1: 26 | to_be_consumed -= 1 27 | s = s.merge(start=i, end=i + to_be_consumed) 28 | num_tokens_consumed += to_be_consumed 29 | return s 30 | 31 | 32 | class _InterpretableRecordPair: 33 | def __init__( 34 | self, 35 | record_pair: pd.DataFrame, 36 | granularity: str, 37 | token_representation: str, 38 | features_a: bool, 39 | features_b: bool, 40 | features_attr: bool, 41 | features_val: bool, 42 | token_regexes: Union[Dict[Tuple[str, str, str], Union[List[str], str]], List[str], str], 43 | ): 44 | if token_representation not in ["independent", "record-bow", "shared-bow"]: 45 | raise ValueError("token_repesentation must be one of ['independent', 'record-bow', 'shared-bow']") 46 | assert granularity in ["tokens", "attributes"] or re.fullmatch("[0-9]+-tokens", granularity) 47 | 48 | self.record_pair = record_pair 49 | self.granularity = granularity 50 | self.token_representation = token_representation 51 | self.features_a = features_a 52 | self.features_b = features_b 53 | self.features_attr = features_attr 54 | self.features_val = features_val 55 | self.token_regexes = token_regexes 56 | if isinstance(self.token_regexes, str): 57 | self.token_regexes = [self.token_regexes] 58 | if isinstance(self.token_regexes, list): 59 | rgs = self.token_regexes 60 | self.token_regexes = {} 61 | for source, attr in self.record_pair.columns.to_list(): 62 | if self.features_attr: 63 | self.token_regexes[(source, attr, "attr")] = rgs 64 | if self.features_val: 65 | self.token_regexes[(source, attr, "val")] = rgs 66 | 67 | self.string_representation = {} 68 | 69 | if self.granularity.endswith("tokens"): 70 | for source, attr in self.record_pair.columns.to_list(): 71 | if features_attr: 72 | self.string_representation[(source, 
attr, "attr")] = TokenizedString.tokenize( 73 | attr, self.token_regexes[(source, attr, "attr")] 74 | ) 75 | if features_val: 76 | if pd.api.types.is_string_dtype(self.record_pair[source, attr]): 77 | val = self.record_pair[source, attr].iloc[0] 78 | if not pd.isna(val): 79 | self.string_representation[(source, attr, "val")] = TokenizedString.tokenize( 80 | val, self.token_regexes[(source, attr, "val")] 81 | ) 82 | 83 | if self.granularity != "tokens": 84 | n = int(self.granularity.split("-")[0]) 85 | for pos, repr in self.string_representation.items(): 86 | if repr is not None: 87 | self.string_representation[pos] = _reduce_tokenized_string_granularity(repr, n) 88 | 89 | self._feature_pos = [] 90 | self._feature_values = [] 91 | inverted_index = {} 92 | value_to_index = {} 93 | value_to_positions = {} 94 | for source in ["a", "b"]: 95 | if (source == "a" and not features_a) or (source == "b" and not features_b): 96 | continue 97 | if token_representation == "record-bow": 98 | value_to_index.clear() 99 | 100 | for attr in self.record_pair[source].columns: 101 | for attr_or_val in ["attr", "val"]: 102 | if (attr_or_val == "attr" and not features_attr) or (attr_or_val == "val" and not features_val): 103 | continue 104 | if (source, attr, attr_or_val) in self.string_representation: 105 | tokenized_string = self.string_representation[(source, attr, attr_or_val)] 106 | for j, t in enumerate(tokenized_string): 107 | if token_representation.endswith("bow"): 108 | if t not in value_to_index: 109 | value_to_index[t] = len(self._feature_pos) 110 | self._feature_pos.append([]) 111 | self._feature_values.append(t) 112 | self._feature_pos[value_to_index[t]].append((source, attr, attr_or_val, j)) 113 | inverted_index[(source, attr, attr_or_val, j)] = value_to_index[t] 114 | else: 115 | inverted_index[(source, attr, attr_or_val, j)] = len(self._feature_pos) 116 | self._feature_pos.append([(source, attr, attr_or_val, j)]) 117 | self._feature_values.append(t) 118 | 119 | if t not 
in value_to_positions: 120 | value_to_positions[t] = [] 121 | value_to_positions[t].append((source, attr, attr_or_val, j)) 122 | else: 123 | val = record_pair[source, attr].iloc[0] if attr_or_val == "val" else attr 124 | if pd.isna(val): 125 | continue 126 | inverted_index[(source, attr, attr_or_val, None)] = len(self._feature_pos) 127 | self._feature_pos.append([(source, attr, attr_or_val, None)]) 128 | self._feature_values.append(val) 129 | 130 | if val not in value_to_positions: 131 | value_to_positions[val] = [] 132 | value_to_positions[val].append((source, attr)) 133 | 134 | def __len__(self) -> int: 135 | return len(self._feature_pos) 136 | 137 | def get_all_pos(self, i: int) -> List[Tuple[str, str, str, Optional[int]]]: 138 | return self._feature_pos[i] 139 | 140 | def get_first_pos(self, i: int) -> Tuple[str, str, str, Optional[int]]: 141 | return self._feature_pos[i][0] 142 | 143 | def get_value(self, i: int) -> any: 144 | return self._feature_values[i] 145 | 146 | 147 | class _InterpretableSamples: 148 | def __init__( 149 | self, 150 | num_samples: int, 151 | record_pair: _InterpretableRecordPair, 152 | random_state: np.random.Generator, 153 | perturb_injection=True, 154 | ): 155 | import sklearn.preprocessing 156 | 157 | self.num_samples = num_samples 158 | self.record_pair = record_pair 159 | self.random_state = random_state 160 | self.perturb_injection = perturb_injection 161 | 162 | num_features = len(record_pair) 163 | 164 | self._X = np.zeros((num_samples, num_features)) 165 | self.distances = [0] 166 | 167 | min_num_exclusions = min(1, max(0, num_features - 1)) 168 | max_num_exclusions = min(max(5, num_features // 5), num_features) 169 | min_num_injections = 0 170 | max_num_injections = min(max(3, num_features // 10), num_features) 171 | 172 | for i in range(1, num_samples): 173 | num_exclusions = random_state.integers(min_num_exclusions, max_num_exclusions) if num_features > 0 else 0 174 | num_injections = ( 175 | 
random_state.integers(min_num_injections, max_num_injections) 176 | if perturb_injection and random_state.random() > 0.5 and num_features > 0 177 | else 0 178 | ) 179 | num_changes = num_injections + num_exclusions 180 | self.distances.append(num_changes / (max_num_exclusions + max_num_injections)) 181 | 182 | exclude_indices = random_state.choice(num_features, replace=False, size=num_exclusions) 183 | self._X[i, exclude_indices] = 1 184 | 185 | inject_indices = random_state.choice(num_features, replace=False, size=num_injections) 186 | self._X[i, inject_indices] = 2 187 | 188 | categories = [[0, 1, 2]] * self._X.shape[1] if self.perturb_injection else [[0, 1]] * self._X.shape[1] 189 | self._X_dummy = sklearn.preprocessing.OneHotEncoder( 190 | categories=categories, drop="first", sparse=False 191 | ).fit_transform(self._X) 192 | 193 | self.distances = np.array(self.distances) 194 | 195 | def features(self, dummy_encode: bool = True) -> np.ndarray: 196 | if dummy_encode: 197 | return self._X_dummy 198 | else: 199 | return self._X 200 | 201 | 202 | def _get_perturbed_record_pairs( 203 | X: np.ndarray, record_pair: _InterpretableRecordPair, random_state: np.random.Generator 204 | ) -> Tuple[pd.DataFrame, List[Dict], List[int]]: 205 | exclusions = [] 206 | injections = [] 207 | for i in range(X.shape[0]): 208 | exclusions.append([record_pair.get_first_pos(j) for j in (X[i] == 1).nonzero()[0]]) 209 | injections.append([record_pair.get_first_pos(j) for j in (X[i] == 2).nonzero()[0]]) 210 | records_pairs, attr_strings, groups = perturb_record_pair( 211 | record_pair=record_pair.record_pair, 212 | string_representation=record_pair.string_representation, 213 | perturbations=list(zip(exclusions, injections)), 214 | random_state=random_state, 215 | ) 216 | return records_pairs, attr_strings, groups 217 | 218 | 219 | def _get_predictions( 220 | X: np.ndarray, 221 | record_pair: _InterpretableRecordPair, 222 | predict_proba: Callable, 223 | random_state, 224 | show_progress: 
bool = False, 225 | ) -> np.ndarray: 226 | record_pairs, attr_strings, groups = _get_perturbed_record_pairs(X, record_pair, random_state) 227 | return get_predictions_scores_for_perturbed_record_pairs( 228 | record_pairs, attr_strings, groups, predict_proba, show_progress 229 | ) 230 | 231 | 232 | def _kernel_fn(d: np.ndarray) -> np.ndarray: 233 | return np.exp(-2 * d) 234 | 235 | 236 | def _forward_selection( 237 | X: np.ndarray, y: np.ndarray, sample_weights: np.ndarray, num_features: int, feature_group_size: int 238 | ) -> np.ndarray: 239 | from sklearn.linear_model import LinearRegression 240 | 241 | used_features = [] 242 | for _ in range(min(num_features, X.shape[1] // feature_group_size)): 243 | max_ = float("-inf") 244 | best = -1 245 | for feature in range(0, X.shape[1], feature_group_size): 246 | if feature in used_features: 247 | continue 248 | feature_group = list(range(feature, feature + feature_group_size)) 249 | X_used = X[:, used_features + feature_group] 250 | clf = LinearRegression(fit_intercept=False) 251 | clf.fit(X_used, y, sample_weight=sample_weights) 252 | score = clf.score(X_used, y, sample_weight=sample_weights) 253 | if score > max_: 254 | best = feature_group 255 | max_ = score 256 | used_features.extend(best) 257 | return np.array(used_features) 258 | 259 | 260 | def _lime( 261 | X: np.ndarray, y: np.ndarray, distances: np.ndarray, num_features: int, feature_group_size: int = 1 262 | ) -> Tuple[Dict[int, Tuple[float, ...]], float, float]: 263 | from sklearn.linear_model import LinearRegression 264 | 265 | sample_weights = _kernel_fn(distances) 266 | sample_weights[0] *= min(distances[1:]) * len(distances) 267 | 268 | y = y - y[0] 269 | 270 | used_features = _forward_selection(X, y, sample_weights, num_features, feature_group_size) 271 | X_used = X[:, used_features] 272 | model = LinearRegression(fit_intercept=False) 273 | model.fit(X_used, y, sample_weight=sample_weights) 274 | 275 | prediction_score = float(model.score(X_used, y, 
sample_weight=sample_weights)) 276 | local_prediction = model.predict(X_used[0:1])[0] 277 | coefs = {} 278 | for f in range(0, len(used_features), feature_group_size): 279 | features = used_features[f : f + feature_group_size].tolist() 280 | coefs[features[0] // feature_group_size] = model.coef_[f : f + feature_group_size].tolist() 281 | 282 | return coefs, prediction_score, local_prediction 283 | 284 | 285 | def _harmonic_mean(x, y): 286 | return 2 * (x * y) / (x + y) 287 | 288 | 289 | def _create_explanation( 290 | interpretable_record_pair: _InterpretableRecordPair, 291 | coefs: Dict[int, Tuple[float, ...]], 292 | prediction_score: float, 293 | dual_explanation: bool, 294 | metadata: Dict, 295 | ) -> MatchingAttributionExplanation: 296 | def str_val(val): 297 | return None if pd.isna(val) else str(val) 298 | 299 | record_pair = interpretable_record_pair.record_pair 300 | string_representation = {} 301 | for source, attr in record_pair.columns.to_list(): 302 | string_representation[(source, attr, "attr")] = interpretable_record_pair.string_representation.get( 303 | (source, attr, "attr"), attr 304 | ) 305 | string_representation[(source, attr, "val")] = interpretable_record_pair.string_representation.get( 306 | (source, attr, "val"), str_val(record_pair[source, attr].iloc[0]) 307 | ) 308 | 309 | attributions = [] 310 | for i in coefs: 311 | weight = -coefs[i][0] 312 | if len(coefs[i]) == 2: 313 | potential = coefs[i][1] 314 | else: 315 | potential = None 316 | attributions.append( 317 | Attribution(weight=weight, potential=potential, positions=interpretable_record_pair.get_all_pos(i)) 318 | ) 319 | 320 | return MatchingAttributionExplanation( 321 | record_pair, string_representation, attributions, prediction_score, dual_explanation, metadata=metadata 322 | ) 323 | 324 | 325 | def _explain( 326 | record_pair: pd.DataFrame, 327 | predict_proba: Callable, 328 | explain_sources: str, 329 | num_features: int, 330 | num_samples: Optional[int], 331 | granularity: str, 
332 | token_representation: str, 333 | token_patterns: Union[str, List[str], Dict], 334 | dual_explanation: bool, 335 | estimate_potential: bool, 336 | explain_attrs: bool, 337 | attribution_method: str, 338 | show_progress: bool, 339 | random_state: np.random.Generator, 340 | ) -> MatchingAttributionExplanation: 341 | if not ( 342 | granularity in ["tokens", "attributes", "counterfactual"] 343 | or re.fullmatch("[0-9]+-tokens", granularity) 344 | or re.fullmatch("counterfactual-x[0-9]+", granularity) 345 | ): 346 | raise ValueError( 347 | "granularity must be 'tokens', 'attributes', 'counterfactual', or on the format '*-tokens' / 'counterfactual-x*' (where * is an integer)" 348 | ) 349 | if attribution_method not in ["lime", "shap"]: 350 | raise ValueError("attribution_method must be either 'lime' or 'shap'") 351 | if attribution_method == "shap" and estimate_potential == True: 352 | raise ValueError("attribution_method='shap' can't be used when estimate_potential=True") 353 | 354 | if granularity.startswith("counterfactual"): 355 | max_tokens_in_attribute = 1 356 | if explain_sources in ["a", "both"]: 357 | for attr, value in record_pair["a"].iloc[0].items(): 358 | max_tokens_in_attribute = max( 359 | max_tokens_in_attribute, 360 | len(TokenizedString.tokenize(value, token_patterns)) if isinstance(value, str) else 1, 361 | ) 362 | if explain_sources in ["b", "both"]: 363 | for attr, value in record_pair["b"].iloc[0].items(): 364 | max_tokens_in_attribute = max( 365 | max_tokens_in_attribute, 366 | len(TokenizedString.tokenize(value, token_patterns)) if isinstance(value, str) else 1, 367 | ) 368 | 369 | if re.fullmatch(".+-x[0-9]+", granularity): 370 | base = int(granularity.split("-")[-1][1:]) 371 | else: 372 | base = 2 373 | 374 | granularities = ["tokens"] 375 | n = base 376 | while n < max_tokens_in_attribute: 377 | granularities.append(f"{n}-tokens") 378 | n *= base 379 | granularities.append("attributes") 380 | 381 | else: 382 | granularities = [granularity] 
383 | 384 | best_explanation: Tuple[float, Optional[MatchingAttributionExplanation]] = (float("-inf"), None) 385 | for g in granularities: 386 | interpretable_record_pair = _InterpretableRecordPair( 387 | record_pair, 388 | granularity=g, 389 | token_representation=token_representation, 390 | features_a=explain_sources in ["a", "both"], 391 | features_b=explain_sources in ["b", "both"], 392 | features_attr=explain_attrs, 393 | features_val=True, 394 | token_regexes=token_patterns, 395 | ) 396 | 397 | if len(interpretable_record_pair) == 0: 398 | return _create_explanation( 399 | interpretable_record_pair, 400 | coefs={}, 401 | prediction_score=float( 402 | np.array( 403 | predict_proba( 404 | records_a=record_pair["a"].rename_axis(index="rid"), 405 | records_b=record_pair["b"].rename_axis(index="rid"), 406 | record_id_pairs=pd.DataFrame( 407 | {"a.rid": [record_pair.index[0]], "b.rid": [record_pair.index[0]]} 408 | ), 409 | ) 410 | )[0] 411 | ), 412 | dual_explanation=dual_explanation, 413 | metadata={"r2_score": None, "granularity": granularity, "token_representation": token_representation}, 414 | ) 415 | 416 | if attribution_method == "lime": 417 | if num_samples is None: 418 | n = len(interpretable_record_pair) 419 | num_samples = max(min(30 * n, 3000), 500) 420 | 421 | samples = _InterpretableSamples( 422 | num_samples=num_samples, 423 | record_pair=interpretable_record_pair, 424 | random_state=random_state, 425 | perturb_injection=estimate_potential, 426 | ) 427 | predictions = _get_predictions( 428 | samples.features(dummy_encode=False), 429 | interpretable_record_pair, 430 | predict_proba, 431 | random_state, 432 | show_progress=show_progress, 433 | ) 434 | distances = samples.distances 435 | 436 | coefs, r2_score, local_prediction = _lime( 437 | samples.features(), 438 | predictions, 439 | distances, 440 | num_features, 441 | feature_group_size=(2 if estimate_potential else 1), 442 | ) 443 | 444 | exp = _create_explanation( 445 | interpretable_record_pair, 
446 | coefs, 447 | float(predictions[0]), 448 | dual_explanation, 449 | metadata={ 450 | "r2_score": r2_score, 451 | "granularity": granularity, 452 | "token_representation": token_representation, 453 | }, 454 | ) 455 | elif attribution_method == "shap": 456 | try: 457 | import shap 458 | except ImportError: 459 | raise ImportError("You need to have the shap library installed to use attribution_method='shap'") 460 | 461 | def wrapped_predict_proba(X): 462 | return _get_predictions( 463 | X, interpretable_record_pair, predict_proba, random_state, show_progress=show_progress 464 | ) 465 | 466 | explainer = shap.KernelExplainer(wrapped_predict_proba, np.ones((1, len(interpretable_record_pair)))) 467 | shap_values = explainer.shap_values( 468 | np.zeros((1, len(interpretable_record_pair))), 469 | nsamples=num_samples if num_samples is not None else "auto", 470 | ) 471 | 472 | coefs = {i: (-w,) for i, w in enumerate(shap_values[0])} 473 | exp = _create_explanation( 474 | interpretable_record_pair, 475 | coefs, 476 | float(wrapped_predict_proba(np.zeros((1, len(interpretable_record_pair))))[0]), 477 | dual_explanation, 478 | metadata={ 479 | "granularity": granularity, 480 | "token_representation": token_representation, 481 | }, 482 | ) 483 | else: 484 | raise AssertionError 485 | 486 | if granularity.startswith("counterfactual"): 487 | counterfactual_strength, predicted_counterfactual_strength, _ = explanation_counterfactual_strength( 488 | exp, predict_proba, random_state 489 | ) 490 | h_mean_cfs = _harmonic_mean( 491 | counterfactual_strength + 0.5 + 1e-6, predicted_counterfactual_strength + 0.5 + 1e-6 492 | ) 493 | if h_mean_cfs > best_explanation[0]: 494 | best_explanation = (h_mean_cfs, exp) 495 | if counterfactual_strength >= 0.1 and predicted_counterfactual_strength >= 0.1: 496 | break 497 | else: 498 | best_explanation = (0, exp) 499 | 500 | return best_explanation[1] 501 | 502 | 503 | def _explain_record_pair( 504 | record_pair, 505 | predict_proba, 506 | 
def _explain_record_pair(
    record_pair,
    predict_proba,
    num_features,
    dual_explanation,
    estimate_potential,
    granularity,
    num_samples,
    token_representation,
    token_patterns,
    explain_attrs,
    attribution_method,
    show_progress,
    random_state,
):
    """Explain one record pair, either jointly or once per source.

    When ``dual_explanation`` is false a single ``_explain`` call with
    ``explain_sources="both"`` is returned as-is. Otherwise ``_explain`` is
    run for source ``"a"`` and then for source ``"b"`` and the two results are
    combined into one ``MatchingAttributionExplanation``:

    * string representations are taken from explanation *a* for keys whose
      first element is ``"a"`` and from explanation *b* for keys starting
      with ``"b"``;
    * attributions are concatenated (a first, then b);
    * the prediction score is the mean of the two scores;
    * metadata is ``{"a": <metadata of a>, "b": <metadata of b>}``.

    A warning is emitted when the two prediction scores differ by more than a
    relative tolerance of 1e-2.
    """
    # Every keyword that is identical for the joint and the per-source calls.
    shared_kwargs = dict(
        num_features=num_features,
        num_samples=num_samples,
        granularity=granularity,
        token_representation=token_representation,
        token_patterns=token_patterns,
        estimate_potential=estimate_potential,
        explain_attrs=explain_attrs,
        attribution_method=attribution_method,
        show_progress=show_progress,
        random_state=random_state,
    )

    if not dual_explanation:
        return _explain(
            record_pair,
            predict_proba,
            explain_sources="both",
            dual_explanation=False,
            **shared_kwargs,
        )

    explanation_a, explanation_b = [
        _explain(
            record_pair,
            predict_proba,
            explain_sources=source,
            dual_explanation=True,
            **shared_kwargs,
        )
        for source in ["a", "b"]
    ]
    if not math.isclose(explanation_a.prediction_score, explanation_b.prediction_score, rel_tol=1e-2):
        warnings.warn(
            f"The prediction score from explanation a and b should be (at least almost) identical, but was {explanation_a.prediction_score} and {explanation_b.prediction_score}"
        )
    string_representation = {
        **{p: s for p, s in explanation_a.string_representation.items() if p[0] == "a"},
        **{p: s for p, s in explanation_b.string_representation.items() if p[0] == "b"},
    }
    return MatchingAttributionExplanation(
        record_pair,
        string_representation,
        attributions=explanation_a.attributions + explanation_b.attributions,
        prediction_score=(explanation_a.prediction_score + explanation_b.prediction_score) / 2,
        dual=True,
        metadata={"a": explanation_a.metadata, "b": explanation_b.metadata},
    )


def explain(
    records_a: pd.DataFrame,
    records_b: pd.DataFrame,
    record_id_pairs: pd.DataFrame,
    predict_proba: Callable,
    *,
    num_features: int = 5,
    dual_explanation: bool = True,
    estimate_potential: bool = True,
    granularity: str = "counterfactual",
    num_samples: Optional[int] = None,
    token_representation: str = "record-bow",
    token_patterns: Union[str, List[str], Dict] = "[^ ]+",
    explain_attrs: bool = False,
    attribution_method: str = "lime",
    show_progress: bool = True,
    random_state: Union[int, np.random.Generator, None] = 0,
    return_dict: Optional[bool] = None,
) -> "Union[MatchingAttributionExplanation, Dict[object, MatchingAttributionExplanation]]":
    """Explain the match prediction for every pair listed in ``record_id_pairs``.

    Args:
        records_a: Records of source a; its index is joined against the
            ``"a.rid"`` column of ``record_id_pairs``.
        records_b: Records of source b; its index is joined against the
            ``"b.rid"`` column of ``record_id_pairs``.
        record_id_pairs: One row per pair to explain, with columns ``"a.rid"``
            and ``"b.rid"``. Its index labels become the keys of the returned
            dict.
        predict_proba: Matcher callable, forwarded unchanged to the per-pair
            explainer.
        num_features, dual_explanation, estimate_potential, granularity,
        num_samples, token_representation, token_patterns, explain_attrs,
        attribution_method: Forwarded unchanged to the per-pair explainer.
        show_progress: With more than one pair a ``tqdm`` bar over the pairs
            is shown; with exactly one pair the flag is forwarded to the
            per-pair explainer instead.
        random_state: Anything accepted by ``numpy.random.default_rng``
            (``None``, an integer seed, a ``Generator``, ...). The generator
            state is captured once and restored before each pair, so every
            pair is explained from the same random state.
        return_dict: ``True`` always returns a dict, ``False`` returns the
            single explanation (exactly one pair required), ``None`` returns
            the bare explanation for one pair and a dict otherwise.

    Returns:
        A single explanation or a dict mapping pair index label to
        explanation, as selected by ``return_dict``.

    Raises:
        TypeError: ``return_dict`` is not ``None``, ``True`` or ``False``.
        ValueError: ``return_dict`` is false but the number of pairs is not 1.
    """
    # Reject a bad ``return_dict`` before doing any (expensive) explanation
    # work instead of after all pairs have been processed.
    if return_dict not in (None, True, False):
        raise TypeError(f"return_dict must be None, True or False, but got {return_dict!r}")
    if return_dict is not None and not return_dict and len(record_id_pairs) != 1:
        raise ValueError("If return_dict=False you can only explain one record pair (but multiple were provided)")

    # default_rng passes an existing Generator through unaltered and also
    # accepts None and any integer-like seed (including numpy integers).
    random_state = np.random.default_rng(random_state)

    records_a = records_a.convert_dtypes()
    records_b = records_b.convert_dtypes()
    # Left-join so that the rows line up one-to-one with record_id_pairs.
    records_a = (
        record_id_pairs[["a.rid"]]
        .merge(records_a, how="left", left_on="a.rid", right_index=True)
        .rename(columns={"a.rid": "rid"})
        .drop(columns="rid")
    )
    records_b = (
        record_id_pairs[["b.rid"]]
        .merge(records_b, how="left", left_on="b.rid", right_index=True)
        .rename(columns={"b.rid": "rid"})
        .drop(columns="rid")
    )
    record_pairs = pd.concat((records_a, records_b), axis=1, keys=["a", "b"], names=["source", "attribute"])

    num_pairs = len(record_pairs)
    if show_progress and num_pairs > 1:
        from tqdm.auto import trange
    else:
        trange = range

    explanations = {}
    initial_state = random_state.bit_generator.state
    for i in trange(num_pairs):
        # Rewind the generator so each pair sees the same random stream.
        random_state.bit_generator.state = initial_state
        explanations[record_pairs.index[i]] = _explain_record_pair(
            record_pairs.iloc[i : i + 1],
            predict_proba,
            num_features=num_features,
            dual_explanation=dual_explanation,
            estimate_potential=estimate_potential,
            granularity=granularity,
            num_samples=num_samples,
            token_representation=token_representation,
            token_patterns=token_patterns,
            explain_attrs=explain_attrs,
            attribution_method=attribution_method,
            show_progress=show_progress and num_pairs == 1,
            random_state=random_state,
        )

    if return_dict is None:
        return_dict = len(explanations) != 1
    if return_dict:
        return explanations
    assert len(explanations) == 1
    return next(iter(explanations.values()))
615 | record_pairs = pd.concat((records_a, records_b), axis=1, keys=["a", "b"], names=["source", "attribute"]) 616 | 617 | explanations = {} 618 | s = random_state.bit_generator.state 619 | if show_progress and len(record_pairs) > 1: 620 | from tqdm.auto import trange 621 | else: 622 | trange = range 623 | for i in trange(len(record_pairs)): 624 | random_state.bit_generator.state = s 625 | explanations[record_pairs.index[i]] = _explain_record_pair( 626 | record_pairs.iloc[i : i + 1], 627 | predict_proba, 628 | num_features=num_features, 629 | dual_explanation=dual_explanation, 630 | estimate_potential=estimate_potential, 631 | granularity=granularity, 632 | num_samples=num_samples, 633 | token_representation=token_representation, 634 | token_patterns=token_patterns, 635 | explain_attrs=explain_attrs, 636 | attribution_method=attribution_method, 637 | show_progress=show_progress and len(record_pairs) == 1, 638 | random_state=random_state, 639 | ) 640 | 641 | if return_dict == True: 642 | return explanations 643 | elif return_dict == False: 644 | assert len(explanations) == 1 645 | return list(explanations.values())[0] 646 | elif return_dict is None: 647 | if len(explanations) == 1: 648 | return list(explanations.values())[0] 649 | else: 650 | return explanations 651 | else: 652 | raise TypeError 653 | -------------------------------------------------------------------------------- /src/lemon/_lemon_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import random 3 | import re 4 | import statistics 5 | from bisect import bisect_left 6 | from functools import lru_cache 7 | from typing import Dict, List, Set, Tuple 8 | 9 | import numpy as np 10 | import pandas as pd 11 | 12 | 13 | class _StringSpans: 14 | __slots__ = ("string", "spans") 15 | 16 | def __init__(self, string, spans): 17 | self.string = string 18 | self.spans = spans 19 | 20 | def __getitem__(self, i): 21 | return self.string[self.spans[2 * i] 
class _StringSpans:
    """Indexable view of substrings of ``string`` delimited by offset pairs.

    Item ``i`` is ``string[spans[2 * i] : spans[2 * i + 1]]``.
    """

    __slots__ = ("string", "spans")

    def __init__(self, string, spans):
        self.string = string
        self.spans = spans

    def __getitem__(self, i):
        return self.string[self.spans[2 * i] : self.spans[2 * i + 1]]

    def __iter__(self):
        return iter(self[i] for i in range(len(self.spans) // 2))


class TokenizedString:
    """A string split into alternating delimiters and tokens.

    ``spans`` is a flat integer array of offsets into ``string``::

        [d0_start, t0_start, t0_end, t1_start, t1_end, ..., last_end]

    so delimiter ``i`` is ``string[spans[2i]:spans[2i+1]]`` and token ``i`` is
    ``string[spans[2i+1]:spans[2i+2]]``. There is always one more delimiter
    than there are tokens (leading and trailing delimiters may be empty).

    ``omitted_tokens`` holds indices of tokens that are left out when the
    string is rendered with :meth:`untokenize`. All modifying methods return a
    new instance rather than changing ``self``.
    """

    def __init__(
        self,
        string: str,
        spans: np.ndarray,
        omitted_tokens: Set[int] = None,
        default_delimiter=None,
    ):
        assert spans[0] == 0 and spans[-1] == len(string)
        self.string = string
        self.spans = spans
        self.omitted_tokens = set() if omitted_tokens is None else omitted_tokens
        self.default_delimiter = default_delimiter
        self._num_tokens = len(self.spans) // 2 - 1
        # Lazily computed result of untokenize(); cached per instance.
        self._untokenized = None

        self.tokens = _StringSpans(string, spans[1:])
        self.delimiters = _StringSpans(string, spans)

        if self.default_delimiter is None:
            if self._num_tokens >= 2:
                # Most common delimiter found between two tokens.
                self.default_delimiter = statistics.mode([self.delimiters[i] for i in range(1, self._num_tokens)])
            else:
                self.default_delimiter = " "

    def save(self):
        """Return a JSON-serializable dict; only valid without omitted tokens."""
        assert not self.omitted_tokens
        return {"string": self.string, "spans": self.spans.tolist(), "default_delimiter": self.default_delimiter}

    @classmethod
    def load(cls, d):
        """Rebuild an instance from the dict produced by :meth:`save`."""
        return cls(d["string"], np.array(d["spans"], dtype=int), default_delimiter=d["default_delimiter"])

    def _update(self, string=None, spans=None, omitted_tokens=None, default_delimiter=None):
        """Return a copy in which every argument that is not ``None`` replaces the current value."""
        return TokenizedString(
            string=string if string is not None else self.string,
            spans=spans if spans is not None else self.spans,
            omitted_tokens=omitted_tokens if omitted_tokens is not None else self.omitted_tokens,
            default_delimiter=default_delimiter if default_delimiter is not None else self.default_delimiter,
        )

    @staticmethod
    def tokenize(s: str, token_regexes: str = "[^ ]+", default_delimiter=None) -> "TokenizedString":
        """Tokenize ``s``: every regex match is a token, everything in between a delimiter.

        ``token_regexes`` is a single pattern or a list of patterns; a list is
        combined into one alternation of non-capturing groups.
        """
        if isinstance(token_regexes, str):
            token_regexes = [token_regexes]
        pattern = "|".join(f"(?:{r})" for r in token_regexes)
        spans = [0]
        for m in re.finditer(pattern, s):
            spans.append(m.start())
            spans.append(m.end())
        spans.append(len(s))
        # default_delimiter must go by keyword: the third positional
        # parameter of the constructor is ``omitted_tokens``.
        return TokenizedString(
            s,
            np.array(spans, dtype=int),
            default_delimiter=default_delimiter,
        )

    @staticmethod
    def from_str(s: str) -> "TokenizedString":
        """Wrap ``s`` as a single token with empty leading and trailing delimiters."""
        return TokenizedString(s, np.array([0, 0, len(s), len(s)], dtype=int))

    @staticmethod
    def from_tokens_and_delimiters(tokens: List[str], delimiters: List[str] = None) -> "TokenizedString":
        """Build a string from ``tokens`` interleaved with ``delimiters``.

        ``delimiters`` needs ``len(tokens) + 1`` entries (leading, between,
        trailing); when omitted, a single space is used for each of them.
        """
        if delimiters is None:
            delimiters = [" "] * (len(tokens) + 1)
        s = [delimiters[0]]
        offset = len(delimiters[0])
        spans = [0, offset]
        for t, d in zip(tokens, delimiters[1:]):
            s.append(t)
            offset += len(t)
            spans.append(offset)
            s.append(d)
            offset += len(d)
            spans.append(offset)

        return TokenizedString(string="".join(s), spans=np.array(spans, dtype=int))

    @staticmethod
    def empty(default_delimiter=" ") -> "TokenizedString":
        """Return an instance with an empty string and no tokens."""
        return TokenizedString("", np.array([0, 0], dtype=int), default_delimiter=default_delimiter)

    def insert(self, i, token, *, before_delimiter=None, delimiters=None):
        """Return a copy with ``token`` inserted at token position ``i``.

        If token ``i`` is currently omitted, its text is replaced by ``token``
        and it is no longer omitted. Otherwise a new token is spliced in
        together with a pair of delimiters: ``before_delimiter`` selects on
        which side of delimiter ``i`` the token lands (default: before it,
        except when appending at the very end), and ``delimiters`` may supply
        the ``[left, right]`` delimiter strings explicitly.
        """
        assert 0 <= i <= self._num_tokens

        if i < self._num_tokens and i in self.omitted_tokens:
            # Re-use the slot of the omitted token.
            cur_start = self.spans[2 * i + 1]
            cur_end = self.spans[2 * i + 2]
            string = self.string[:cur_start] + token + self.string[cur_end:]
            spans = self.spans.copy()
            index_change = (cur_start + len(token)) - cur_end
            spans[2 * i + 2 :] += index_change
            omitted_tokens = self.omitted_tokens - {i}
            return self._update(string, spans, omitted_tokens)

        if before_delimiter is None:
            before_delimiter = i != self._num_tokens

        if delimiters is None:
            if i == 0 and before_delimiter:
                if self._num_tokens > 0:
                    delimiters = ["", self.default_delimiter if self.delimiters[0] == "" else self.delimiters[0]]
                else:
                    delimiters = ["", self.delimiters[0]]
            elif i == self._num_tokens and not before_delimiter:
                if self._num_tokens > 0:
                    delimiters = [self.default_delimiter if self.delimiters[-1] == "" else self.delimiters[-1], ""]
                else:
                    delimiters = [self.delimiters[-1], ""]
            elif before_delimiter:
                if self.delimiters[i - 1] == self.delimiters[i]:
                    delimiters = [self.delimiters[i - 1], self.delimiters[i]]
                else:
                    delimiters = [self.default_delimiter, self.default_delimiter if i < self._num_tokens else ""]
            else:
                if self.delimiters[i] == self.delimiters[i + 1]:
                    delimiters = [self.delimiters[i], self.delimiters[i + 1]]
                else:
                    delimiters = [self.default_delimiter if i > 0 else "", self.default_delimiter]

        # Delimiter i of the original string is replaced by
        # <left delimiter><token><right delimiter>.
        string = "".join(
            (
                self.string[: self.spans[2 * i]],
                delimiters[0],
                token,
                delimiters[1],
                self.string[self.spans[2 * i + 1] :],
            )
        )
        token_start = self.spans[2 * i] + len(delimiters[0])
        token_end = token_start + len(token)
        second_delimiter_end = token_end + len(delimiters[1])
        spans = np.empty(self.spans.size + 2, dtype=int)
        spans[: 2 * i + 1] = self.spans[: 2 * i + 1]
        spans[2 * i + 1] = token_start
        spans[2 * i + 2] = token_end
        spans[2 * i + 3] = second_delimiter_end
        if i < self._num_tokens:
            index_change = second_delimiter_end - self.spans[2 * i + 1]
            spans[2 * i + 4 :] = self.spans[2 * i + 2 :] + index_change

        return self._update(string, spans)

    def omit(self, start, end=None):
        """Return a copy in which tokens ``start`` .. ``end - 1`` are marked as omitted.

        A falsy ``end`` (the default) omits only token ``start``.
        """
        if not end:
            end = start + 1

        assert 0 <= start < self._num_tokens
        assert 0 < end <= self._num_tokens

        return self._update(omitted_tokens=self.omitted_tokens | set(range(start, end)))

    def merge(self, start, end):
        """Return a copy in which tokens ``start`` .. ``end - 1`` (and the delimiters between them) form one token."""
        assert 0 <= start < self._num_tokens
        assert 0 < end <= self._num_tokens

        spans = np.concatenate(
            (
                self.spans[: 2 * start + 2],
                self.spans[2 * end :],
            )
        )

        return self._update(spans=spans)

    def untokenize(self):
        """Return the string with all omitted tokens left out.

        The result is computed once and cached on the instance.
        """
        if self._untokenized is None:
            self._untokenized = self._render()
        return self._untokenized

    def _render(self):
        """Assemble the string, collapsing the two delimiters around each omitted run into one."""
        if not self.omitted_tokens:
            return self.string

        parts = []
        prev_i = -1
        # Delimiter to emit in front of the next run of kept tokens.
        new_delimiter = self.delimiters[0]
        for i in sorted(self.omitted_tokens):
            if i > 0 and prev_i < i - 1:
                parts.append(new_delimiter)
                # Kept run from the start of token prev_i + 1 up to the end of
                # token i - 1. It starts at the token (not at the delimiter in
                # front of it) because that delimiter was just emitted above.
                parts.append(self.string[self.spans[2 * prev_i + 3] : self.spans[2 * i]])
            new_delimiter = self.delimiters[i]
            if new_delimiter != self.delimiters[i + 1]:
                # The delimiters on both sides disagree: drop the delimiter at
                # the string boundaries, fall back to the default elsewhere.
                if i == 0 or i == self._num_tokens - 1:
                    new_delimiter = ""
                else:
                    new_delimiter = self.default_delimiter
            prev_i = i
        parts.append(new_delimiter)
        if prev_i + 1 < self._num_tokens:
            parts.append(self.string[self.spans[2 * prev_i + 3] :])
        return "".join(parts)

    def __len__(self):
        return self._num_tokens

    def __getitem__(self, item):
        return self.tokens[item]

    def __bool__(self):
        return bool(len(self))

    def __str__(self):
        return self.untokenize()

    def __repr__(self):
        return f"<TokenizedString {self.untokenize()!r}>"


def _fast_choice(population, weights, k, std_random):
    """Draw ``k`` items from ``population`` without replacement, proportionally to ``weights``.

    ``weights`` is a numpy array (it is copied, the caller's array is left
    untouched) and ``std_random`` provides ``random()`` returning a float in
    ``[0, 1)``. One call to ``std_random.random()`` is made per drawn item.
    """
    weights = weights.copy()
    last = len(weights) - 1
    samples = []
    for _ in range(k):
        cum_weights = weights.cumsum()
        threshold = cum_weights[-1] * std_random.random()
        # side="right" selects the first index whose cumulative weight is
        # strictly greater than the threshold, so an item whose weight was
        # already zeroed can never be selected again. The clamp keeps the
        # index valid if rounding pushes the threshold up to the total.
        i = min(int(np.searchsorted(cum_weights, threshold, side="right")), last)
        weights[i] = 0
        samples.append(population[i])
    return samples


def _materialize_represenation(repr):
    """Split a representation dict into plain attribute-name and value dicts.

    ``repr`` maps ``(source, attr, "attr" | "val")`` to a value;
    ``TokenizedString`` values are rendered with ``untokenize()``. Returns
    ``(attrs, vals)``, each keyed by ``(source, attr)``.
    """
    rendered = {p: v.untokenize() if isinstance(v, TokenizedString) else v for p, v in repr.items()}
    attrs = {(source, attr): v for (source, attr, attr_or_val), v in rendered.items() if attr_or_val == "attr"}
    vals = {(source, attr): v for (source, attr, attr_or_val), v in rendered.items() if attr_or_val == "val"}
    return attrs, vals
def perturb_record_pair(
    record_pair,
    perturbations,
    string_representation=None,
    random_state: np.random.Generator = None,
    num_injection_sampling: int = None,
    injection_only_append_to_same_attr: bool = False,
) -> Tuple[pd.DataFrame, List[Dict[Tuple[str, str], str]], List[int]]:
    """Create perturbed copies of a record pair by excluding and injecting content.

    Args:
        record_pair: Single-row DataFrame whose columns are
            ``(source, attribute)`` pairs with sources ``"a"`` and ``"b"``.
        perturbations: Iterable of ``(exclusions, injections)``. Both are
            sequences of positions ``(source, attr, "attr" | "val", j)``
            where ``j`` is a token index or ``None`` for the whole value.
        string_representation: Optional mapping
            ``(source, attr, "attr" | "val") -> value`` overriding the default
            representation (the attribute name / the cell value). Token-level
            positions require a ``TokenizedString`` entry here.
        random_state: Generator used to seed the internal ``random.Random``;
            a fresh default generator is created when ``None``.
        num_injection_sampling: Number of sampled variants per perturbation
            that has injections; derived from the data when ``None``.
        injection_only_append_to_same_attr: If true, every injection is
            appended to the same attribute of the other record and exactly
            one variant is produced per perturbation.

    Returns:
        ``(record_pairs, all_attrs, groups)``: a DataFrame with one row per
        perturbed pair, the attribute-name dict of every row, and for every
        row the index of the perturbation it was generated from.
    """
    if string_representation is None:
        string_representation = {}
    representation = string_representation.copy()
    for source, attr in record_pair.columns.to_list():
        if (source, attr, "attr") not in representation:
            representation[(source, attr, "attr")] = attr
        if (source, attr, "val") not in representation:
            representation[(source, attr, "val")] = record_pair[source, attr].iloc[0]

    if random_state is None:
        random_state = np.random.default_rng()
    # Generator.integers returns a numpy integer, which random.Random rejects
    # as a seed on Python 3.11+; convert it to a built-in int first.
    std_random = random.Random(int(random_state.integers(10**9)))

    # For every position kind: the attributes of the *other* record that an
    # injected attribute name / value may be placed into.
    relevant_target_attrs = {}
    for source, attr in record_pair.columns.to_list():
        injection_type = record_pair.dtypes[source, attr]
        target_source = "a" if source == "b" else "b"

        relevant_target_attrs[(source, attr, "attr")] = [c for c in record_pair[target_source].columns if c != attr]

        # Values go into attributes of the same dtype or of string dtype.
        relevant_target_attrs[(source, attr, "val")] = [
            target_attr
            for target_attr in list(record_pair[target_source].columns)
            if str(injection_type) == str(record_pair.dtypes[target_source, target_attr])
            or pd.api.types.is_string_dtype(record_pair.dtypes[target_source, target_attr])
        ]

    # Candidate insertion targets (target_source, attr, kind, token position)
    # with normalized sampling weights.
    relevant_targets = {}
    for (source, attr, attr_or_val), target_attrs in relevant_target_attrs.items():
        targets = []
        target_weights = []
        target_source = "a" if source == "b" else "b"

        if injection_only_append_to_same_attr:
            target_weights.append(1.0)
            targets.append((target_source, attr, attr_or_val, len(representation[(target_source, attr, attr_or_val)])))
        else:
            for target_attr in target_attrs:
                target_pp = (target_source, target_attr, attr_or_val)
                if isinstance(representation[target_pp], TokenizedString):
                    num_target_j = len(representation[target_pp]) + 1
                else:
                    num_target_j = 1
                # The attribute with the same name gets as much weight as all
                # other attributes together; the weight is spread evenly over
                # the insertion positions inside the attribute.
                target_attr_weight = max(1, len(target_attrs) - 1) if attr == target_attr else 1
                target_weights.extend([target_attr_weight / num_target_j for _ in range(num_target_j)])
                targets.extend(
                    [(target_source, target_attr, attr_or_val, target_j) for target_j in range(num_target_j)]
                )

        target_weights = np.array(target_weights)
        target_weights = target_weights / target_weights.sum() if target_weights.size else target_weights
        relevant_targets[(source, attr, attr_or_val)] = (targets, target_weights)

    perturbed_representations = []
    groups = []
    for group_i, (exclusions, injections) in enumerate(perturbations):
        with_exclusions = representation.copy()
        token_exclusions = [p for p in exclusions if p[3] is not None]
        value_exclusions = [p for p in exclusions if p[3] is None]
        for source, attr, attr_or_val, j in token_exclusions:
            with_exclusions[(source, attr, attr_or_val)] = with_exclusions[(source, attr, attr_or_val)].omit(j)
        for source, attr, attr_or_val, _ in value_exclusions:
            existing_value = with_exclusions[(source, attr, attr_or_val)]
            if isinstance(existing_value, TokenizedString):
                new_value = TokenizedString.empty(default_delimiter=existing_value.default_delimiter)
            elif isinstance(existing_value, str):
                new_value = ""
            else:
                new_value = None
            with_exclusions[(source, attr, attr_or_val)] = new_value

        if not injections or not injection_only_append_to_same_attr:
            perturbed_representations.append(with_exclusions)
            groups.append(group_i)

        if injections:
            # Work on a private copy: the list is shuffled in place below and
            # the caller's perturbation must not be reordered.
            injections = list(injections)

            max_injection_sampling = 0
            for source, attr, attr_or_val, j in injections:
                target_source = "a" if source == "b" else "b"
                target_attrs = relevant_target_attrs[(source, attr, attr_or_val)]
                suggested_injection_sampling = 0
                for target_attr in target_attrs:
                    target_value = with_exclusions[(target_source, target_attr, attr_or_val)]
                    if isinstance(target_value, TokenizedString):
                        suggested_injection_sampling += min(3, len(target_value))
                    else:
                        suggested_injection_sampling += 1
                suggested_injection_sampling = min(10, suggested_injection_sampling)
                max_injection_sampling = max(max_injection_sampling, suggested_injection_sampling)

            # NOTE(review): once derived here the value is kept for all later
            # perturbations of this call as well - confirm that this is intended.
            if num_injection_sampling is None:
                num_injection_sampling = max_injection_sampling
            if injection_only_append_to_same_attr:
                num_injection_sampling = 1

            sampled_targets = {}
            for p in injections:
                targets, target_weights = relevant_targets[p[:3]]
                sampled_targets[p] = []
                while len(sampled_targets[p]) < num_injection_sampling and targets:
                    sampled_targets[p].extend(
                        _fast_choice(
                            targets,
                            weights=target_weights,
                            k=min(len(targets), num_injection_sampling - len(sampled_targets[p])),
                            std_random=std_random,
                        )
                    )

            for sampling_i in range(num_injection_sampling):
                perturbed = with_exclusions.copy()
                if injection_only_append_to_same_attr:
                    injections = sorted(injections, reverse=True, key=lambda inj: inj[3])
                else:
                    std_random.shuffle(injections)
                for p in injections:
                    source, attr, attr_or_val, j = p
                    pp = p[:3]

                    if isinstance(representation[pp], TokenizedString):
                        injection_value = representation[pp][j]
                    else:
                        injection_value = representation[pp]

                    if pd.isna(injection_value):
                        continue

                    if not sampled_targets[p]:
                        continue
                    target = sampled_targets[p][sampling_i]
                    target_pp = target[:3]

                    if isinstance(perturbed[target_pp], TokenizedString):
                        perturbed[target_pp] = perturbed[target_pp].insert(
                            target[3], str(injection_value), before_delimiter=std_random.random() < 0.5
                        )
                    elif isinstance(perturbed[target_pp], str):
                        perturbed[target_pp] = TokenizedString.from_str(str(injection_value))
                    elif pd.isna(perturbed[target_pp]):
                        if isinstance(injection_value, str):
                            perturbed[target_pp] = TokenizedString.from_str(str(injection_value))
                        else:
                            perturbed[target_pp] = injection_value
                    else:
                        perturbed[target_pp] = injection_value

                perturbed_representations.append(perturbed)
                groups.append(group_i)

    all_attrs, all_vals = zip(*[_materialize_represenation(repr) for repr in perturbed_representations])
    record_pairs = pd.DataFrame(all_vals, columns=record_pair.columns).astype(record_pair.dtypes)
    return record_pairs, all_attrs, groups


def get_predictions_scores_for_perturbed_record_pairs(
    record_pairs, attr_strings, groups, predict_proba, show_progress
) -> np.ndarray:
    """Score perturbed record pairs and reduce the scores to one value per group.

    Identical rows (same values and same attribute strings) are sent to
    ``predict_proba`` only once. The score of a group is the maximum score
    over all rows that belong to it.

    Args:
        record_pairs: One row per perturbed pair, columns ``(source, attribute)``.
        attr_strings: Attribute-name dict of every row.
        groups: Group index of every row; the last entry is taken as the
            highest group index.
        predict_proba: Called as ``predict_proba(records_a, records_b,
            record_id_pairs, **kwargs)``. ``show_progress`` and
            ``attr_strings`` are passed only if its signature names them.
        show_progress: Forwarded to ``predict_proba`` when supported.

    Returns:
        Array with one score per group; empty when ``groups`` is empty.
    """
    if len(groups) == 0:
        # Nothing to score; avoids indexing into an empty sequence below.
        return np.array([])

    # Avoid running prediction on duplicates
    num_groups = groups[-1] + 1
    dtypes = record_pairs.dtypes
    record_pairs = record_pairs.assign(attr_strings=[str(x) for x in attr_strings], group=groups)
    record_pairs = (
        record_pairs.groupby(by=record_pairs.columns[:-1].to_list(), as_index=False, dropna=False)
        .agg(list)
        .astype(dtypes)
    )
    groups_per_unique_pair = record_pairs["group"]
    # Drop the two helper columns (attr_strings, group) again.
    record_pairs = record_pairs[record_pairs.columns[:-2]]

    records_a = record_pairs["a"].rename_axis(index="rid")
    records_b = record_pairs["b"].rename_axis(index="rid")
    record_id_pairs = pd.DataFrame({"a.rid": range(len(record_pairs)), "b.rid": range(len(record_pairs))}).rename_axis(
        index="pid"
    )

    accepted_parameters = inspect.signature(predict_proba).parameters
    predict_proba_kwargs = {}
    if "show_progress" in accepted_parameters:
        predict_proba_kwargs["show_progress"] = show_progress
    if "attr_strings" in accepted_parameters:
        # NOTE(review): this is the per-row list from before de-duplication,
        # so its length can differ from the number of pairs passed - confirm.
        predict_proba_kwargs["attr_strings"] = attr_strings
    all_predictions = np.array(predict_proba(records_a, records_b, record_id_pairs, **predict_proba_kwargs))

    predictions = [float("-inf")] * num_groups
    for p, groups_for_pair in zip(all_predictions, groups_per_unique_pair):
        for g in groups_for_pair:
            predictions[g] = max(predictions[g], p)

    return np.array(predictions)
predict_proba_kwargs = {} 440 | if "show_progress" in inspect.signature(predict_proba).parameters: 441 | predict_proba_kwargs["show_progress"] = show_progress 442 | if "attr_strings" in inspect.signature(predict_proba).parameters: 443 | predict_proba_kwargs["attr_strings"] = attr_strings 444 | all_predictions = np.array(predict_proba(records_a, records_b, record_id_pairs, **predict_proba_kwargs)) 445 | 446 | predictions = [float("-inf")] * num_groups 447 | for p, groups_for_pair in zip(all_predictions, groups_per_unique_pair): 448 | for g in groups_for_pair: 449 | predictions[g] = max(predictions[g], p) 450 | predictions = np.array(predictions) 451 | 452 | return predictions 453 | -------------------------------------------------------------------------------- /src/lemon/_matching_attribution_explanation.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import colorsys 3 | import dataclasses 4 | import io 5 | import json 6 | from dataclasses import dataclass 7 | from io import BytesIO 8 | from typing import Callable, Dict, List, Optional, Tuple, Union 9 | 10 | import numpy as np 11 | import pandas as pd 12 | 13 | from ._lemon_utils import TokenizedString, get_predictions_scores_for_perturbed_record_pairs, perturb_record_pair 14 | 15 | 16 | def _correct_intensity(w): 17 | if w > 0: 18 | return 0.05 + 0.95 * w 19 | else: 20 | return -0.05 + 0.95 * w 21 | 22 | 23 | def _highlight_string(string, weight, potential): 24 | h = 1 / 3 if weight > 0 else 0 25 | l = np.clip(1 - _correct_intensity(weight + potential - min(0.0, weight)) / 2, 0.5, 1) 26 | s = np.clip(_correct_intensity(abs(weight) / (abs(weight) + potential)), 0, 1) if potential != 0 else 1 27 | rgb = colorsys.hls_to_rgb(h, l, s) 28 | rgb = tuple(int(round(255 * x)) for x in rgb) 29 | r, g, b = rgb 30 | return f'{string}' 31 | 32 | 33 | def _highlight_tokenized_string(s: TokenizedString, attributions): 34 | html = [s.delimiters[0]] 35 | for i, token in 
enumerate(s.tokens): 36 | if i in attributions: 37 | html.append(_highlight_string(token, *attributions[i])) 38 | else: 39 | html.append(token) 40 | html.append(s.delimiters[i + 1]) 41 | return "".join(html) 42 | 43 | 44 | @dataclass 45 | class Attribution: 46 | weight: float 47 | positions: List[Union[Tuple[str, str, str, Optional[int]]]] 48 | potential: float = None 49 | name: str = None 50 | 51 | def to_dict(self): 52 | return dataclasses.asdict(self) 53 | 54 | @property 55 | def consistent_potential(self): 56 | return max(0.0, 0.0 if self.potential is None else self.potential, -self.weight) 57 | 58 | 59 | class MatchingAttributionExplanation: 60 | def __init__( 61 | self, 62 | record_pair: pd.DataFrame, 63 | string_representation: Dict[Tuple, Union[None, str, TokenizedString]], 64 | attributions: List[Attribution], 65 | prediction_score: float, 66 | dual: bool, 67 | *, 68 | metadata: Dict[str, any] = None, 69 | ): 70 | self.record_pair = record_pair 71 | self.string_representation = string_representation 72 | self.attributions = attributions 73 | self.prediction_score = prediction_score 74 | self.dual = dual 75 | self.metadata = metadata if metadata is not None else {} 76 | 77 | def only_attributions_from(self, source) -> "MatchingAttributionExplanation": 78 | return MatchingAttributionExplanation( 79 | record_pair=self.record_pair, 80 | string_representation=self.string_representation, 81 | attributions=[a for a in self.attributions if all(p[0] == source for p in a.positions)], 82 | prediction_score=self.prediction_score, 83 | dual=self.dual, 84 | ) 85 | 86 | def save(self, path: str = None): 87 | record_pair_feather = io.BytesIO() 88 | rp = self.record_pair.copy() 89 | rp.columns = [".".join(c) for c in rp.columns.to_flat_index()] 90 | rp = rp.reset_index() 91 | rp.to_feather(record_pair_feather) 92 | string_representation = [ 93 | (p, val.save() if isinstance(val, TokenizedString) else val) 94 | for p, val in self.string_representation.items() 95 | ] 96 | 
dump = { 97 | "record_pair.feather": base64.b64encode(record_pair_feather.getvalue()).decode(), 98 | "string_representation": string_representation, 99 | "attributions": [a.to_dict() for a in self.attributions], 100 | "prediction_score": self.prediction_score, 101 | "dual": self.dual, 102 | "metadata": self.metadata, 103 | } 104 | if path is None: 105 | return dump 106 | else: 107 | with open(path, "w") as f: 108 | json.dump(dump, f) 109 | 110 | @classmethod 111 | def load(cls, path: Union[str, Dict]): 112 | if isinstance(path, str): 113 | with open(path, "r") as f: 114 | path = json.load(f) 115 | else: 116 | if not isinstance(path, dict): 117 | raise TypeError(f"Expecting dict, but got {path}") 118 | d = path 119 | record_pair_feather = io.BytesIO(base64.b64decode(d["record_pair.feather"].encode())) 120 | rp = pd.read_feather(record_pair_feather).set_index("pid") 121 | rp.columns = pd.MultiIndex.from_tuples([c.split(".", 1) for c in rp.columns], names=["source", "attribute"]) 122 | 123 | string_representation = { 124 | tuple(p): TokenizedString.load(val) if isinstance(val, dict) else val 125 | for p, val in d["string_representation"] 126 | } 127 | 128 | attributions = [Attribution(**a) for a in d["attributions"]] 129 | for a in attributions: 130 | a.positions = [tuple(p) for p in a.positions] 131 | 132 | return cls( 133 | record_pair=rp, 134 | string_representation=string_representation, 135 | attributions=attributions, 136 | prediction_score=d["prediction_score"], 137 | dual=d["dual"], 138 | metadata=d["metadata"], 139 | ) 140 | 141 | def _highlighted_record(self, s): 142 | html = [ 143 | """ 144 | 149 | 150 | 151 | """ 152 | ] 153 | for attr in self.record_pair[s].columns: 154 | html.append('') 155 | attr_attributions = [] 156 | val_attributions = [] 157 | for attribution in self.attributions: 158 | for pos in attribution.positions: 159 | if pos[0] == s and pos[1] == attr: 160 | if pos[2] == "attr": 161 | attr_attributions.append( 162 | (attribution.weight, 
max(0.0, attribution.consistent_potential), pos[3]) 163 | ) 164 | else: 165 | val_attributions.append( 166 | (attribution.weight, max(0.0, attribution.consistent_potential), pos[3]) 167 | ) 168 | 169 | attr_repr = self.string_representation[(s, attr, "attr")] 170 | attr_str = "" if attr_repr is None else str(attr_repr) 171 | if attr_attributions: 172 | if isinstance(attr_repr, TokenizedString): 173 | attr_str = _highlight_tokenized_string(attr_repr, {i: (w, p) for w, p, i in attr_attributions}) 174 | else: 175 | assert len(attr_attributions) == 1 176 | w, p, i = attr_attributions[0] 177 | assert i is None 178 | attr_str = _highlight_string(attr_str, w, p) 179 | 180 | val_repr = self.string_representation[(s, attr, "val")] 181 | val_str = "" if val_repr is None else str(val_repr) 182 | if val_attributions: 183 | if isinstance(val_repr, TokenizedString): 184 | val_str = _highlight_tokenized_string(val_repr, {i: (w, p) for w, p, i in val_attributions}) 185 | else: 186 | assert len(val_attributions) == 1 187 | w, p, i = val_attributions[0] 188 | assert i is None 189 | val_str = _highlight_string(val_str, w, p) 190 | 191 | html.append( 192 | f'' 193 | ) 194 | html.append("") 195 | html.append("
{attr_str}{val_str}
") 196 | return "".join(html) 197 | 198 | def _prediction_header(self, use_percentage): 199 | prediction_score_str = f"{self.prediction_score:.0%}" if use_percentage else f"{self.prediction_score:.2f}" 200 | return f""" 201 |
202 | Prediction: {'Match' if self.prediction_score > 0.5 else 'Not match'} ({prediction_score_str}) 203 |
204 |
205 |
206 |
207 | """ 208 | 209 | def as_html(self, min_attribution=0.0, use_percentage=True): 210 | return f""" 211 |
212 |
{self._prediction_header(use_percentage)}
213 |
{self._highlighted_record("a")}
214 |
{self._highlighted_record("b")}
215 |
{self.plot("a", min_attribution, return_html=True)}
216 |
{self.plot("b", min_attribution, return_html=True)}
217 |
218 | """ 219 | 220 | def _repr_html_(self): 221 | return self.as_html() 222 | 223 | def plot( 224 | self, source, min_attribution=0.0, return_html=False, max_features=5, use_percentage=True, show_values=False 225 | ): 226 | import matplotlib.pyplot as plt 227 | 228 | values = [] 229 | for attribution in self.attributions: 230 | if source != "both" and all(s != source for s, _, _, _ in attribution.positions): 231 | continue 232 | if abs(attribution.weight) < min_attribution and abs(attribution.consistent_potential) < min_attribution: 233 | continue 234 | if attribution.name is not None: 235 | string = attribution.name 236 | else: 237 | s, attr, attr_or_val, j = attribution.positions[0] 238 | val = self.string_representation[(s, attr, attr_or_val)] 239 | if j is None: 240 | if val is None: 241 | string = f"<{attr}>" if attr_or_val == "attr" else f"[{attr}]" 242 | else: 243 | string = "" if val is None else str(val) 244 | if len(string) > 33: 245 | string = f"<{attr}>" if attr_or_val == "attr" else f"[{attr}]" 246 | else: 247 | string = val[j] 248 | if len(string) > 33: 249 | string = string[:30] + "..." 
250 | values.append((attribution.consistent_potential, attribution.weight, string)) 251 | 252 | values.sort(key=lambda v: v[1] + v[0] - min(0.0, v[1]), reverse=True) 253 | values = values[:max_features] 254 | 255 | fig, ax = plt.subplots(figsize=(8, 0.7 * max(1, len(values)))) 256 | 257 | xs = np.arange(len(values)) - len(values) + 0.5 258 | weights = [v[1] for v in values] 259 | potential = [v[1] + v[0] for v in values] 260 | 261 | ax.barh(xs, potential, 0.3, color="#ddd", zorder=2) 262 | ax.barh(xs, weights, 0.3, color=["g" if e > 0 else "r" for e in weights], zorder=2) 263 | for i, (x, (p, w, s)) in enumerate(zip(xs, values)): 264 | ax.text( 265 | 0.02 if w > 0 else -0.02, 266 | x - 0.35, 267 | s, 268 | color="black", 269 | fontsize=12, 270 | verticalalignment="center", 271 | horizontalalignment="left" if w > 0 else "right", 272 | ) 273 | if show_values: 274 | ax.text( 275 | w + 0.03 if w > 0 else w - 0.03, 276 | x, 277 | f"{abs(w):.0%}" if use_percentage else f"{abs(w):.2f}", 278 | color="black", 279 | fontsize=8, 280 | verticalalignment="center", 281 | horizontalalignment="left" if w > 0 else "right", 282 | ) 283 | if w + p - 0.20 > w: 284 | ax.text( 285 | w + p + 0.03, 286 | x, 287 | f"{abs(w+p):.0%}" if use_percentage else f"{abs(w + p):.2f}", 288 | color="black", 289 | fontsize=8, 290 | verticalalignment="center", 291 | horizontalalignment="left", 292 | ) 293 | ax.invert_yaxis() 294 | ax.set_xlim((-1.2, 1.2)) 295 | ax.set_ylim((0.5, -max(1, len(values)))) 296 | ax.spines["left"].set_position("zero") 297 | ax.spines["bottom"].set_position("zero") 298 | ax.spines["right"].set_color("none") 299 | ax.spines["top"].set_color("none") 300 | ax.xaxis.set_ticks_position("bottom") 301 | xticks = [-1, -0.8, -0.6, -0.4, -0.2, 0.2, 0.4, 0.6, 0.8, 1.0] 302 | ax.xaxis.set_ticks(xticks) 303 | if use_percentage: 304 | ax.set_xticklabels([f"{x:.0%}" for x in xticks]) 305 | ax.yaxis.set_visible(False) 306 | ax.vlines(x=xticks, ymin=-1e9, ymax=0, colors="#EEE", linewidth=1) 
def explanation_counterfactual_strength(
    exp: MatchingAttributionExplanation,
    predict_proba: Callable,
    random_state: np.random.RandomState,
    max_features=None,
    injection_only_append_to_same_attr: bool = False,
) -> Tuple[float, float, int]:
    """Measure how well an explanation's attributions predict a counterfactual.

    Greedily selects attributions whose summed weights (or potentials) are
    estimated to move the prediction across the decision boundary, applies
    the corresponding exclusions/injections to the record pair in a single
    perturbation, and re-scores the perturbed pair with ``predict_proba``.

    Args:
        exp: Explanation providing ``prediction_score``, ``attributions``,
            ``record_pair`` and ``string_representation``.
        predict_proba: Scoring callable handed to
            ``get_predictions_scores_for_perturbed_record_pairs``.
        random_state: Random state forwarded to ``perturb_record_pair``.
        max_features: If not ``None``, only the first ``max_features``
            attributions (after sorting) are considered.
        injection_only_append_to_same_attr: Forwarded unchanged to
            ``perturb_record_pair``.

    Returns:
        A tuple ``(actual, expected, num_actions)`` where ``actual`` is the
        signed distance of the re-scored prediction past 0.5 in the flipped
        direction, ``expected`` is the same distance for the estimate derived
        from the attribution values, and ``num_actions`` is the number of
        attributions that were acted upon.
    """
    # Running estimate of the prediction after the selected actions.
    target_prediction = exp.prediction_score
    exclusions = []
    injections = []
    num_actions = 0
    if exp.prediction_score > 0.5:
        # Predicted match: remove the most positive attributions first.
        available_attributions = sorted(exp.attributions, reverse=True, key=lambda a: a.weight)
        if max_features is not None:
            available_attributions = available_attributions[:max_features]
        i = 0
        # Keep going until the estimate drops to 0.4 or below, i.e. 0.1 past the 0.5 boundary.
        while target_prediction > 0.4 and i < len(available_attributions):
            attribution = available_attributions[i]
            if attribution.weight > 0:
                target_prediction -= attribution.weight
                exclusions.extend(attribution.positions)
                num_actions += 1
            else:
                # Sorted descending, so no later attribution has a positive weight either.
                break
            i += 1
    else:
        # Predicted non-match: use the attributions with the largest potential first.
        available_attributions = sorted(exp.attributions, reverse=True, key=lambda a: a.consistent_potential)
        if max_features is not None:
            available_attributions = available_attributions[:max_features]
        i = 0
        # Keep going until the estimate reaches 0.6 or above, i.e. 0.1 past the 0.5 boundary.
        while target_prediction < 0.6 and i < len(available_attributions):
            attribution = available_attributions[i]
            if attribution.consistent_potential > 0.0:
                if -attribution.weight >= attribution.consistent_potential:
                    # Excluding the (negatively weighted) feature is estimated to help at least as much as injecting it.
                    target_prediction -= attribution.weight
                    exclusions.extend(attribution.positions)
                else:
                    target_prediction += attribution.consistent_potential
                    # Only the first position of the attribution is injected.
                    injections.append(attribution.positions[0])
                num_actions += 1
            else:
                # Sorted descending, so no later attribution has a positive potential either.
                break
            i += 1

    # Only entries whose value is a TokenizedString are passed on to the perturbation step.
    string_representation = {p: v for p, v in exp.string_representation.items() if isinstance(v, TokenizedString)}
    record_pairs, attr_strings, groups = perturb_record_pair(
        exp.record_pair,
        perturbations=[(exclusions, injections)],
        string_representation=string_representation,
        random_state=random_state,
        injection_only_append_to_same_attr=injection_only_append_to_same_attr,
    )
    # A single perturbation was submitted, so only the first score is used.
    prediction = get_predictions_scores_for_perturbed_record_pairs(
        record_pairs, attr_strings, groups, predict_proba, show_progress=False
    )[0]

    # Positive values mean the prediction ended up on the other side of 0.5.
    if exp.prediction_score > 0.5:
        return 0.5 - prediction, 0.5 - target_prediction, num_actions
    else:
        return prediction - 0.5, target_prediction - 0.5, num_actions
class Dataset:
    """Two record tables together with id pairs and optional labels."""

    def __init__(
        self,
        records: RecordsLike,
        record_id_pairs: pd.DataFrame,
        labels: pd.Series = None,
    ):
        # Normalise whatever two-element iterable was given into a Records tuple.
        table_a, table_b = Records(*records)
        self.records = Records(table_a, table_b)
        self.record_id_pairs = record_id_pairs
        # ``None`` simply means "no labels available".
        self.labels = labels

    def __repr__(self):
        return f""
def get_cache_path(*paths, cache_dir=None):
    """Join ``paths`` onto the cache directory and return the result.

    A falsy ``cache_dir`` (``None`` or an empty string) falls back to the
    module-level ``_CACHE_DIR``.
    """
    base_dir = cache_dir or _CACHE_DIR
    return os.path.join(base_dir, *paths)
def structured_dblp_acm(root=None, download=True):
    """Return the Structured/DBLP-ACM dataset via ``_deepmatcher_dataset``.

    ``root`` and ``download`` are forwarded unchanged. Both tables are read
    with the same column dtypes.
    """
    table_dtypes = {
        "id": "int64",
        "title": "string",
        "author": "string",
        "venue": "string",
        "year": "Int64",
    }
    archive_url = "http://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/Structured/DBLP-ACM/dblp_acm_exp_data.zip"
    # Each side gets its own dict object, as in a pair of separate literals.
    per_table = {"a": dict(table_dtypes), "b": dict(table_dtypes)}
    return _deepmatcher_dataset(
        "Structured/DBLP-ACM",
        archive_url,
        dtypes=per_table,
        root=root,
        download=download,
    )
def structured_itunes_amazon(root=None, download=True):
    """Return the Structured/iTunes-Amazon dataset via ``_deepmatcher_dataset``.

    ``root`` and ``download`` are forwarded unchanged. Apart from the integer
    ``id``, every column of both tables is read as ``"string"``.
    """
    text_columns = (
        "Song_Name",
        "Artist_Name",
        "Album_Name",
        "Genre",
        "Price",
        "CopyRight",
        "Time",
        "Released",
    )

    def table_dtypes():
        # Fresh dict per table; key order is id first, then the text columns.
        return {"id": "int64", **dict.fromkeys(text_columns, "string")}

    return _deepmatcher_dataset(
        "Structured/iTunes-Amazon",
        "http://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/Structured/iTunes-Amazon/itunes_amazon_exp_data.zip",
        dtypes={"a": table_dtypes(), "b": table_dtypes()},
        root=root,
        download=download,
    )
def dirty_itunes_amazon(root=None, download=True):
    """Return the Dirty/iTunes-Amazon dataset via ``_deepmatcher_dataset``.

    ``root`` and ``download`` are forwarded unchanged. Tables ``a`` and ``b``
    share one column layout: an ``int64`` id followed by string columns.
    """
    dataset_name = "Dirty/iTunes-Amazon"
    archive_url = (
        "http://pages.cs.wisc.edu/~anhai/data1/deepmatcher_data/Dirty/iTunes-Amazon/dirty_itunes_amazon_exp_data.zip"
    )
    layout = [
        ("id", "int64"),
        ("Song_Name", "string"),
        ("Artist_Name", "string"),
        ("Album_Name", "string"),
        ("Genre", "string"),
        ("Price", "string"),
        ("CopyRight", "string"),
        ("Time", "string"),
        ("Released", "string"),
    ]
    # dict(layout) builds an independent mapping for each side.
    dtypes = {side: dict(layout) for side in ("a", "b")}
    return _deepmatcher_dataset(dataset_name, archive_url, dtypes=dtypes, root=root, download=download)
def fit(self, records_a: pd.DataFrame, records_b: pd.DataFrame, record_id_pairs: pd.DataFrame, labels: pd.Series):
    """Train the wrapped Magellan matcher on labelled record pairs.

    Builds py_entitymatching feature vectors for every pair in
    ``record_id_pairs`` and fits ``self.magellan_matcher`` on them; an
    ``em.RFMatcher`` is created when no matcher was supplied.

    Args:
        records_a: Left-hand records; the index is used as the record id.
        records_b: Right-hand records; the index is used as the record id.
        record_id_pairs: Pairs with ``"a.rid"`` and ``"b.rid"`` columns.
            NOTE(review): the index is presumably named ``"pid"``, since
            ``"pid"`` is used as the key column after ``reset_index()`` -- confirm.
        labels: One label per pair; cast to ``int`` and attached by position.
    """
    # Expose the record ids as an explicit "rid" column and register it as the table key.
    records_a = records_a.rename_axis(index="rid")
    records_b = records_b.rename_axis(index="rid")
    A = records_a.reset_index()
    em.set_key(A, "rid")
    B = records_b.reset_index()
    em.set_key(B, "rid")

    # Attribute typing/correspondence is computed without the id column.
    A_without_rid = A.drop(columns="rid")
    B_without_rid = B.drop(columns="rid")

    # Stored on the instance so that later calls can reuse them.
    self._attr_types_a = au.get_attr_types(A_without_rid)
    self._attr_types_b = au.get_attr_types(B_without_rid)

    self._attr_corres = au.get_attr_corres(A_without_rid, B_without_rid)

    feature_table = em.get_features(
        A_without_rid,
        B_without_rid,
        self._attr_types_a,
        self._attr_types_b,
        self._attr_corres,
        tok.get_tokenizers_for_matching(),
        sim.get_sim_funs_for_matching(),
    )

    # Attach both records' attributes to every pair; sort_index() restores the
    # positional order produced by reset_index() after the merges.
    G = (
        record_id_pairs.reset_index()
        .merge(records_a.add_prefix("a."), left_on="a.rid", right_index=True)
        .merge(records_b.add_prefix("b."), left_on="b.rid", right_index=True)
    ).sort_index()

    # Register key, base tables and foreign keys for the pair table.
    em.set_key(G, "pid")
    em.set_ltable(G, A)
    em.set_rtable(G, B)
    em.set_fk_ltable(G, "a.rid")
    em.set_fk_rtable(G, "b.rid")

    H = em.extract_feature_vecs(
        G, feature_table=feature_table, n_jobs=1, show_progress=False
    )  # Use n_jobs=1 to avoid Tkinter problems

    if self.magellan_matcher is None:
        self.magellan_matcher = em.RFMatcher()
    # Missing feature values become 0. Labels are attached as a plain array, so
    # they are matched to rows by position -- assumes H keeps the order of
    # record_id_pairs; TODO confirm.
    self.magellan_matcher.fit(
        table=H.fillna(0).assign(label=labels.astype("int").to_numpy()),
        exclude_attrs=["pid", "a.rid", "b.rid"],
        target_attr="label",
    )
def evaluate(
    self, record_a: pd.DataFrame, records_b: pd.DataFrame, record_id_pairs: pd.DataFrame, labels: pd.Series
):
    """Predict the given pairs and report precision, recall and F1.

    Predictions come from ``self._run_predict``; the scores are computed by
    ``em.eval_matches`` on a table holding the labels and the predictions.
    """
    predictions, _probabilities = self._run_predict(record_a, records_b, record_id_pairs)

    scored_pairs = record_id_pairs.assign(label=labels.astype(int), prediction=predictions)
    scored_pairs = scored_pairs.reset_index()

    # Register key and foreign keys before handing the table to eval_matches.
    em.set_key(scored_pairs, "pid")
    em.set_fk_ltable(scored_pairs, "a.rid")
    em.set_fk_rtable(scored_pairs, "b.rid")

    summary = em.eval_matches(scored_pairs, "label", "prediction")
    return {metric: summary[metric] for metric in ("precision", "recall", "f1")}
def _format_records(records, attr_strings):
    """Serialise every row of ``records`` into one ``COL ... VAL ...`` string.

    ``attr_strings(row_position, column_name)`` supplies the text written
    after ``COL``; missing values are written as an empty string after
    ``VAL``. The result is a single-column (``"record"``) frame of dtype
    ``"string"`` that keeps the index of ``records``.
    """
    column_names = list(records.columns)
    serialised = []
    for row_position, row in enumerate(records.itertuples(index=False, name=None)):
        pieces = []
        for column_name, value in zip(column_names, row):
            label = attr_strings(row_position, column_name)
            shown = "" if pd.isna(value) else value
            pieces.append(f"COL {label} VAL {shown}")
        serialised.append(" ".join(pieces))
    return pd.DataFrame(data={"record": serialised}, index=records.index, dtype="string")
def _compute_metrics(pred):
    """Compute F1, precision and recall for the positive class.

    ``pred.predictions`` is reduced with ``argmax(-1)`` and compared with
    ``pred.label_ids``. Any ratio whose denominator is zero is reported
    as ``0.0``.
    """
    gold = pred.label_ids
    predicted = pred.predictions.argmax(-1)

    true_positives = (predicted * gold).sum()
    predicted_positives = predicted.sum()
    gold_positives = gold.sum()

    precision = true_positives / predicted_positives if predicted_positives > 0 else 0.0
    recall = true_positives / gold_positives if gold_positives > 0 else 0.0

    if precision + recall > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0.0
    return {"f1": f1, "precision": precision, "recall": recall}
def _create_trainer(self, training_args=None, **kwargs) -> Trainer:
    """Build a ``Trainer`` for ``self.model`` with merged training arguments.

    Args:
        training_args: Optional per-call overrides, merged on top of
            ``self.training_args``. The mapping is copied and never modified.
        **kwargs: Passed straight through to ``Trainer`` (for example
            ``train_dataset`` / ``eval_dataset``).

    Returns:
        A ``Trainer`` using ``_compute_metrics`` and ``self.tokenizer``.
    """
    # Work on a copy: the setdefault calls below must not leak "fp16" /
    # "no_cuda" keys into a dict owned by the caller.
    call_args = dict(training_args) if training_args is not None else {}
    call_args.setdefault("fp16", torch.cuda.is_available())
    call_args.setdefault("no_cuda", not torch.cuda.is_available())
    # NOTE(review): call_args now always carries "fp16" and "no_cuda", so these
    # values take precedence over the ones stored in self.training_args unless
    # the caller passes them explicitly for this call -- confirm this is intended.
    return Trainer(
        self.model,
        args=TrainingArguments(**{**self.training_args, **call_args}),
        compute_metrics=_compute_metrics,
        tokenizer=self.tokenizer,
        **kwargs,
    )
def evaluate(
    self,
    records_a: pd.DataFrame,
    records_b: pd.DataFrame,
    record_id_pairs: pd.DataFrame,
    labels: pd.Series,
    *,
    show_progress: bool = True,
) -> Dict:
    """Score the matcher on labelled pairs and return precision, recall and F1.

    The pairs are turned into a dataset with ``self._create_dataset`` and
    run through a freshly created trainer. ``show_progress=False`` disables
    the progress bar.
    """
    labelled_dataset = self._create_dataset(records_a, records_b, record_id_pairs, labels)

    trainer_overrides = {
        "disable_tqdm": not show_progress,
        "skip_memory_metrics": True,
        "seed": np.random.randint(0, 2 ** 31),
    }
    trainer = self._create_trainer(training_args=trainer_overrides)

    prediction_output = trainer.predict(labelled_dataset, metric_key_prefix="test")
    scores = prediction_output.metrics

    # Strip the "test_" prefix from the reported metric names.
    renamed = (
        ("precision", "test_precision"),
        ("recall", "test_recall"),
        ("f1", "test_f1"),
    )
    return {public_name: scores[trainer_name] for public_name, trainer_name in renamed}