├── gurobi.env
├── requirements.txt
├── setup.py
├── config.ini
├── LICENSE
├── probing
│   ├── args.py
│   ├── logconfig.py
│   ├── config.py
│   ├── loader.py
│   ├── test.py
│   ├── batch_probing.py
│   ├── clusters.py
│   ├── analyzer.py
│   ├── utils.py
│   ├── distanceQ.py
│   ├── probing.py
│   └── space.py
├── .gitignore
├── README.md
└── main.py
/gurobi.env:
--------------------------------------------------------------------------------
1 | OutputFlag 0
2 | GRB_IntParam_OutputFlag 0
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tqdm
3 | torch
4 | ExAssist
5 | joblib
6 | scikit-learn
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | setup(name='probing', packages=['probing'])
--------------------------------------------------------------------------------
/config.ini:
--------------------------------------------------------------------------------
1 | [run]
2 | comments = [probing] SS-role
3 | output_path = results/SS/
4 | 
5 | [data]
6 | common = ./data/SS
7 | 
8 | label_set_path = ${common}/labels/tags.txt
9 | entities_path = ${common}/entities/train.txt
10 | test_entities_path = ${common}/entities/test.txt
11 | 
12 | embeddings_path = ${common}/embeddings/ss-role-fine-tuned-bert-base-uncased/train/0/12.txt
13 | test_embeddings_path = ${common}/embeddings/ss-role-fine-tuned-bert-base-uncased/test/0/12.txt
14 | 
15 | batch_embeddings_path =
16 | 
17 | 
18 | [clustering]
19 | enable_cuda = True
20 | rate = 0.05
21 | mode = probing
22 | probing_cluster_path =
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Utah NLP
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /probing/args.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # 4 | # Author: Yichu Zhou - flyaway1217@gmail.com 5 | # Blog: zhouyichu.com 6 | # 7 | # Python release: 3.8.0 8 | # 9 | # Date: 2020-12-29 15:59:08 10 | # Last modified: 2020-12-29 16:16:42 11 | 12 | """ 13 | Setting the args. 14 | """ 15 | import argparse 16 | 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser( 20 | description='Probing the given embeddings') 21 | # Run section 22 | parser.add_argument('-run_comments', '--run_comments') 23 | parser.add_argument('-run_output_path', '--run_output_path') 24 | 25 | # data section 26 | parser.add_argument('-label_set_path', '--label_set_path') 27 | parser.add_argument('-entities_path', '--entities_path') 28 | parser.add_argument('-embeddings_path', '--embeddings_path') 29 | parser.add_argument('-test_entities_path', '--test_entities_path') 30 | parser.add_argument('-test_embeddings_path', '--test_embeddings_path') 31 | 32 | # clustering setting 33 | parser.add_argument('--enable_cuda', action='store_true') 34 | parser.add_argument('-rate', '--rate', default=0.01) 35 | parser.add_argument('-mode', '--mode') 36 | parser.add_argument('-probing_cluster_path', '--probing_cluster_path') 37 | 38 | args = parser.parse_args() 39 | return args 40 | -------------------------------------------------------------------------------- /probing/logconfig.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # 4 | # Author: Yichu Zhou - flyaway1217@gmail.com 5 | # Blog: zhouyichu.com 6 | # 7 | # Python release: 3.8.0 8 | # 9 | # Date: 2020-07-24 16:12:52 10 | # Last modified: 2020-12-29 12:12:12 11 | 12 | """ 13 | Logger configurations. 
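Call set_log_path() before applying LOGGING_CONFIG via
logging.config.dictConfig, so that the file handler writes into the
run's output directory (this is what main.py does).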
14 | """ 15 | 16 | PACKAGE_NAME = 'probing' 17 | 18 | LOG_FILE = PACKAGE_NAME + '.log' 19 | 20 | LOGGING_CONFIG = { 21 | 'version': 1, 22 | 'disable_existing_loggers': True, 23 | 'formatters': { 24 | 'standard': { 25 | 'format': '%(asctime)s-%(filename)s-%(levelname)s: %(message)s' 26 | }, 27 | }, 28 | 'handlers': { 29 | 'file': { 30 | 'level': 'DEBUG', 31 | 'formatter': 'standard', 32 | 'class': 'logging.FileHandler', 33 | 'filename': LOG_FILE, 34 | }, 35 | 'console': { 36 | 'level': 'DEBUG', 37 | 'formatter': 'standard', 38 | 'class': 'logging.StreamHandler', 39 | # 'stream': 'ext://sys.stdout', 40 | }, 41 | }, 42 | 'loggers': { 43 | PACKAGE_NAME: { 44 | 'level': 'DEBUG', 45 | 'handlers': ['console', 'file'] 46 | }, 47 | '__main__': { 48 | 'level': 'DEBUG', 49 | 'handlers': ['console', 'file'] 50 | }, 51 | }, 52 | } 53 | 54 | 55 | def set_log_path(path): 56 | LOGGING_CONFIG['handlers']['file']['filename'] = path 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | Experiments/*
132 | data/*
--------------------------------------------------------------------------------
/probing/config.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #
4 | # Author: Yichu Zhou - flyaway1217@gmail.com
5 | # Blog: zhouyichu.com
6 | #
7 | # Python release: 3.6.0
8 | #
9 | # Date: 2019-08-14 13:44:20
10 | # Last modified: 2022-03-14 11:18:02
11 | 
12 | """
13 | Load configuration from file.
14 | """
15 | 
16 | from pathlib import Path
17 | 
18 | import torch
19 | 
20 | 
21 | class Config:
22 | def __init__(self, config):
23 | self._get_runpath(config)
24 | 
25 | self._get_data(config)
26 | 
27 | self._get_clustering(config)
28 | 
29 | def _get_clustering(self, config):
30 | self.mode = config.mode
31 | self.probing_cluster_path = config.probing_cluster_path
32 | self.enable_cuda = str(config.enable_cuda).lower() in ('true', '1')  # bool('False') would be True
33 | self.rate = float(config.rate)
34 | # self.iter_step = int(config.iter_step)
35 | cuda = self.enable_cuda and torch.cuda.is_available()
36 | self.device = torch.device('cuda' if cuda else
37 | 'cpu')
38 | 
39 | def _get_runpath(self, config):
40 | output_path = Path(config.output_path)
41 | if not output_path.exists():
42 | output_path.mkdir(parents=True)
43 | else:
44 | if len(list(output_path.iterdir())) != 0:
45 | raise Exception('The results directory is non-empty!')
46 | self.cluster_path = output_path / 'clusters.txt'
47 | self.log_path = output_path / 'log.txt'
48 | self.prediction_path = output_path / 'prediction.txt'
49 | 
50 | # Data cartography
51 | self.data_map_predictions_path = output_path / 'data_map_predictions/'
52 | self.data_map_path = output_path / 'data_map.txt'
53 | 
54 | # Batch probing
55 | self.batch_label_vec_path = output_path / 'batch_vec/'
56 | self.batch_inside_mean_path = output_path / 'inside_mean/'
57 | self.batch_inside_max_path = output_path / 'inside_max/'
58 | self.batch_outside_min_path = output_path / 'outside_min/'
59 | 
60 | self.label_vec_path = output_path / 'vec.txt'
61 | self.inside_mean_path = output_path / 'inside_mean.txt'
62 | self.inside_max_path = output_path / 'inside_max.txt'
63 | self.outside_min_path = output_path / 'outside_min.txt'
64 | 
65 | def _get_data(self, config):
66 | self.entities_path = config.entities_path
67 | # self.test_entities_path = config.test_entities_path
68 | self.label_set_path = config.label_set_path
69 | self.embeddings_path = config.embeddings_path
70 | 
71 | self.batch_embeddings_path = config.batch_embeddings_path
72 | # self.test_embeddings_path = config.test_embeddings_path
73 | # Data cartography
74 | # self.train_indices_path = config.train_indices_path
75 | # self.test_indices_path = config.test_indices_path
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BERT-fine-tuning-analysis
2 | The codebase for the paper: A Closer Look at How Fine-tuning Changes BERT.
3 | 
4 | # Installing
5 | This codebase is derived from [DirectProbe][] and follows
6 | the same installation instructions as [DirectProbeCode][].
7 | 
8 | # Getting Started
9 | 
10 | ## Downloading datasets and running examples
11 | 
12 | 1. Download the pre-packed data from [here][data_url] and
13 | unzip it. The data format is the same as in [DirectProbeCode][].
14 | 2. Assuming all the pre-packed data is placed in the directory
15 | `data`, we can run an experiment using the
16 | configuration from `config.ini`:
17 | 
18 | ```
19 | python main.py
20 | ```
21 | 
22 | ## Results
23 | After probing, you will find the results in the
24 | directory `results/SS/`. (We use the supersense
25 | role task as the running example.)
26 | In this directory, there are 6 files (an illustrative sample follows the list):
27 | - `clusters.txt`: The clustering results. Each line contains
28 | a cluster number for the corresponding training example.
29 | 
30 | - `inside_max.txt`: The maximum pairwise distances inside
31 | each cluster. Each line represents one label.
32 | 
33 | - `inside_mean.txt`: The mean pairwise distances inside each
34 | cluster. Each line represents one label.
35 | 
36 | - `log.txt`: The probing log file.
37 | 
38 | - `outside_min.txt`: The minimum distance to other clusters
39 | for each label. Each line represents one label.
40 | 
41 | - `vec.txt`: Pairwise distances between clusters. Each line
42 | represents a pair of labels and their distance.
43 | 
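For example, `vec.txt` stores one `label1---label2,distance` record per
line, and `outside_min.txt` one `label,distance` record per line (see
`probing/utils.py`). The snippet below is purely illustrative; the label
names and values are hypothetical, not real output:

```
p.Agent---p.Beneficiary,1.2345
p.Agent---p.Circumstance,0.9876
```
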
44 | # Citations
45 | 
46 | ```
47 | @inproceedings{zhou-srikumar-2022-closer,
48 | title = "A Closer Look at How Fine-tuning Changes {BERT}",
49 | author = "Zhou, Yichu and
50 | Srikumar, Vivek",
51 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
52 | month = may,
53 | year = "2022",
54 | address = "Dublin, Ireland",
55 | publisher = "Association for Computational Linguistics",
56 | url = "https://aclanthology.org/2022.acl-long.75",
57 | doi = "10.18653/v1/2022.acl-long.75",
58 | pages = "1046--1061",
59 | abstract = "Given the prevalence of pre-trained contextualized representations in today{'}s NLP, there have been many efforts to understand what information they contain, and why they seem to be universally successful. The most common approach to use these representations involves fine-tuning them for an end task. Yet, how fine-tuning changes the underlying embedding space is less studied. In this work, we study the English BERT family and use two probing techniques to analyze how fine-tuning changes the space. We hypothesize that fine-tuning affects classification performance by increasing the distances between examples associated with different labels. We confirm this hypothesis with carefully designed experiments on five different NLP tasks. Via these experiments, we also discover an exception to the prevailing wisdom that {``}fine-tuning always improves performance{''}. Finally, by comparing the representations before and after fine-tuning, we discover that fine-tuning does not introduce arbitrary changes to representations; instead, it adjusts the representations to downstream tasks while largely preserving the original spatial structure of the data points.",
60 | }
61 | ```
62 | 
63 | [DirectProbe]: https://aclanthology.org/2021.naacl-main.401/
64 | [DirectProbeCode]: https://github.com/utahnlp/DirectProbe
65 | [data_url]: https://drive.google.com/drive/folders/1mlF-O20Zsa_jJG3tjV-vVrIivY71_R5P
--------------------------------------------------------------------------------
/probing/loader.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #
4 | # Author: Yichu Zhou - flyaway1217@gmail.com
5 | # Blog: zhouyichu.com
6 | #
7 | # Python release: 3.8.0
8 | #
9 | # Date: 2021-01-06 10:07:08
10 | # Last modified: 2021-02-05 10:10:56
11 | 
12 | """
13 | Load data
14 | """
15 | 
16 | from typing import Tuple
17 | import logging
18 | 
19 | import numpy as np
20 | 
21 | from probing import utils
22 | from probing.config import Config
23 | 
24 | logger = logging.getLogger(__name__)
25 | 
26 | 
27 | def load_train(
28 | config: Config) -> Tuple[np.array, np.array, np.array, dict]:
29 | """Loading all the necessary input files.
30 | 
31 | This function loads three files:
32 | - entities_path: A file containing the entities and their labels,
33 | one entity per line.
34 | - label_set_path: A file containing all the possible labels.
35 | A separate file is needed because, in some cases,
36 | not all the labels occur in the training set.
37 | - embeddings_path: A file containing all the embeddings,
38 | one vector per line.
39 | """
40 | path = config.entities_path
41 | logger.info('Load entities from ' + path)
42 | entities = utils.load_entities(path)
43 | 
44 | # For debugging
45 | n = len(entities)
46 | # n = 200
47 | annotations = [entities[i].Label for i in range(n)]
48 | entities = [entities[i] for i in range(n)]
49 | 
50 | s = 'Finish loading {a} entities...'
51 | s = s.format(a=str(len(entities)))
52 | logger.info(s)
53 | 
54 | labels = sorted(list(utils.load_labels(config.label_set_path)))
55 | label2idx = {labels[i]: i for i in range(len(labels))}
56 | annotations = [label2idx[t] for t in annotations]
57 | 
58 | logger.info('Label size={a}'.format(a=str(len(labels))))
59 | 
60 | embeddings_path = config.embeddings_path
61 | logger.info('Loading embeddings from ' + str(embeddings_path))
62 | embeddings = utils.load_embeddings(embeddings_path)
63 | embeddings = embeddings[:n]
64 | logger.info('Finish loading embeddings...')
65 | 
66 | assert len(embeddings) == n
67 | 
68 | annotations = np.array(annotations)
69 | labels = np.array(labels)
70 | embeddings = np.array(embeddings)
71 | return annotations, labels, embeddings, label2idx
72 | 
73 | 
74 | def load_test(config: Config):
75 | path = config.test_entities_path
76 | logger.info('Load entities from ' + path)
77 | entities = utils.load_entities(path)
78 | 
79 | # For debugging
80 | n = len(entities)
81 | # n = 30
82 | annotations = [entities[i].Label for i in range(n)]
83 | entities = [entities[i] for i in range(n)]
84 | 
85 | s = 'Finish loading {a} entities...'
86 | s = s.format(a=str(len(entities)))
87 | logger.info(s)
88 | 
89 | labels = sorted(list(utils.load_labels(config.label_set_path)))
90 | label2idx = {labels[i]: i for i in range(len(labels))}
91 | annotations = [label2idx[t] for t in annotations]
92 | 
93 | embeddings_path = config.test_embeddings_path
94 | logger.info('Loading embeddings from ' + embeddings_path)
95 | embeddings = utils.load_embeddings(embeddings_path)
96 | embeddings = embeddings[:n]
97 | logger.info('Finish loading embeddings...')
98 | 
99 | assert len(embeddings) == n
100 | 
101 | annotations = np.array(annotations)
102 | labels = np.array(labels)
103 | embeddings = np.array(embeddings)
104 | assert len(annotations) == len(embeddings)
105 | return annotations, embeddings, label2idx
--------------------------------------------------------------------------------
/probing/test.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #
4 | # Author: Yichu Zhou - flyaway1217@gmail.com
5 | # Blog: zhouyichu.com
6 | #
7 | # Python release: 3.6.0
8 | #
9 | # Date: 2019-08-16 09:58:30
10 | # Last modified: 2020-11-02 11:59:27
11 | 
12 | """
13 | Test for SVM margin.
14 | """
15 | 
16 | import numpy as np
17 | from scipy.spatial.distance import cdist
18 | # from sklearn.svm import LinearSVC
19 | from sklearn.svm import SVC
20 | import gurobipy as gp
21 | from gurobipy import GRB
22 | 
23 | 
24 | def load_embeddings(path):
25 | reval = []
26 | with open(path, encoding='utf8') as f:
27 | for line in f:
28 | s = line.strip().split()
29 | vec = [float(v) for v in s]
30 | reval.append(vec)
31 | return reval
32 | 
33 | 
34 | def distance(X, p):
35 | # clf = LinearSVC(tol=1e-5, loss='hinge', C=100000, max_iter=20000)
36 | clf = SVC(tol=1e-5, C=1000000, kernel='linear', max_iter=20000)
37 | y = [0] * len(X)
38 | y.append(1)
39 | y = np.array(y)
40 | XX = np.concatenate((X, p.reshape(1, -1)))
41 | clf.fit(XX, y)
42 | w = clf.coef_.reshape(-1)
43 | b = clf.intercept_[0]
44 | if clf.score(XX, y) != 1.0:
45 | print(w)
46 | print(b)
47 | return 0
48 | 
49 | print(w)
50 | print(b)
51 | 
52 | d = np.dot(w, p) + b
53 | d = abs(d) / np.linalg.norm(w)
54 | 
55 | d = (d*2)
56 | return d
57 | 
58 | 
59 | def hull2hull(X1, X2):
60 | clf = SVC(tol=1e-4, C=10000, kernel='linear', max_iter=20000)
61 | # clf = LinearSVC(tol=1e-5, loss='hinge', C=100000, max_iter=20000)
62 | y1 = [1] * len(X1)
63 | y2 = [-1] * len(X2)
64 | y1 = np.array(y1)
65 | y2 = np.array(y2)
66 | XX = np.concatenate((X1, X2))
67 | yy = np.concatenate((y1, y2))
68 | clf.fit(XX, yy)
69 | 
70 | w = clf.coef_.reshape(-1)
71 | b = clf.intercept_[0]
72 | print(w)
73 | print(b)
74 | score = clf.score(XX, yy)
75 | if score != 1:
76 | return 0
77 | 
78 | d = np.dot(XX, w) + b
79 | d = abs(d) / np.linalg.norm(w)
80 | 
81 | d = (d*2)
82 | return np.min(d)
83 | 
84 | 
85 | def lp(X1, X2, has_same=False):
86 | """Return 1 when the LP problem is infeasible.
87 | """
88 | if has_same:
89 | dist = cdist(X1, X2)
90 | # The same vector appears on both sides.
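91 | # (note) cdist from scipy computes all pairwise distances between the
92 | # rows of X1 and X2; a zero entry means the two hulls share a point,
93 | # so no separating hyperplane can exist and the LP is infeasible.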
94 | if np.any(dist == 0):
95 | return 1
96 | # logger = logging.getLogger('Probing')
97 | m = X1.shape[1]
98 | model = gp.Model("lp")
99 | 
100 | # Create variables
101 | W = model.addMVar(shape=m+1, lb=-GRB.INFINITY,
102 | ub=GRB.INFINITY, vtype=GRB.CONTINUOUS, name="W")
103 | # Adding the bias
104 | XX1 = np.concatenate((X1, np.ones((X1.shape[0], 1))), axis=1)
105 | XX2 = np.concatenate((X2, np.ones((X2.shape[0], 1))), axis=1)
106 | 
107 | Y1 = np.array([1]*X1.shape[0])
108 | Y2 = np.array([-1]*X2.shape[0])
109 | 
110 | model.addConstr(XX1 @ W >= Y1)
111 | model.addConstr(XX2 @ W <= Y2)
112 | model.setObjective(0, GRB.MINIMIZE)
113 | model.setParam('OutputFlag', False)
114 | # model.setParam('Method', 2)
115 | # model.setParam('FeasibilityTol', 1e-4)
116 | 
117 | # Optimize model
118 | model.update()
119 | model.optimize()
120 | 
121 | print(model.Status)
122 | return int(model.Status != GRB.OPTIMAL)
123 | 
124 | 
125 | if __name__ == '__main__':
126 | # X = [[1, 1], [2, 2], [3, 3], [4, 4],
127 | # [1, 2], [3, 2]]
128 | # X = [[5, 10], [0, 5], [10, 5], [5, 0]]
129 | # p = [10, -5]
130 | # X = np.array(X)
131 | # p = np.array(p)
132 | 
133 | # d = distance(X, p)
134 | # print(d)
135 | 
136 | X1 = [[1, 6], [1, 3], [15, 3], [15, 6]]
137 | X2 = [[1, -3], [1, -1], [2, -3], [2, -1]]
138 | X1 = np.array(X1)
139 | X2 = np.array(X2)
140 | d = hull2hull(X1, X2)
141 | print(d)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #
4 | # Author: Yichu Zhou - flyaway1217@gmail.com
5 | # Blog: zhouyichu.com
6 | #
7 | # Python release: 3.6.0
8 | #
9 | # Date: 2019-07-24 10:36:21
10 | # Last modified: 2022-03-14 11:23:34
11 | 
12 | """
13 | Main entrance.
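Depending on the [clustering] mode in config.ini, this script runs the
full probing pipeline ('probing'), analyzes a set of existing clusters
('prediction'), or analyzes a series of embedding spaces ('batch_probing').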
14 | """ 15 | import logging 16 | import logging.config 17 | import configparser 18 | import numpy as np 19 | 20 | import ExAssist as EA 21 | 22 | from probing import utils 23 | from probing.config import Config 24 | from probing.probing import Probe 25 | from probing.clusters import Cluster 26 | from probing.distanceQ import DistanceQ 27 | from probing.analyzer import Analyzer 28 | import probing.logconfig as cfg 29 | from probing import loader 30 | from probing import batch_probing as bp 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def run(config): 36 | annotations, labels, embeddings, label2idx = loader.load_train(config) 37 | probe = Probe(config) 38 | 39 | clusters = [Cluster([i], [label]) for 40 | i, label in enumerate(annotations)] 41 | 42 | logger.info('Initialize the Distance Queue...') 43 | q = DistanceQ(config, embeddings, clusters, len(labels)) 44 | q = probe.probing(q) 45 | assist = EA.getAssist('Probing') 46 | assist.result['final number'] = len(q) 47 | logger.info('Dumping the clusters...') 48 | utils.write_clusters(config.cluster_path, q) 49 | logger.info('Finish dumping the clusters...') 50 | 51 | config.probing_cluster_path = config.cluster_path 52 | develop_run(config) 53 | 54 | 55 | def develop_run(config): 56 | assist = EA.getAssist('Probing') 57 | s = 'Loading the clusters from {a}' 58 | s = s.format(a=str(config.probing_cluster_path)) 59 | logger.info(s) 60 | annotations, labels, embeddings, label2idx = loader.load_train(config) 61 | clusters_indices = utils.load_clusters(config.probing_cluster_path) 62 | labels_list = utils.assign_labels(clusters_indices, annotations) 63 | assert len(clusters_indices) == len(labels_list) 64 | clusters = [Cluster(indices, labs) for 65 | indices, labs in zip(clusters_indices, labels_list) 66 | if len(indices) > 5] 67 | q = DistanceQ(config, embeddings, clusters, len(labels)) 68 | logger.info('Finish loading the clusters...') 69 | 70 | analyzer = Analyzer(config) 71 | # annotations, embeddings, label2idx = loader.load_test(config) 72 | idx2label = {value: key for key, value in label2idx.items()} 73 | 74 | logger.info('Computing the distances between clusters...') 75 | total_label_pair_dis = analyzer.convex2convex(q) 76 | logger.info('Computing the distances vectors...') 77 | label_dis_vec = analyzer.label_dis_vec(total_label_pair_dis) 78 | logger.info('Computing the outside min distance...') 79 | outside_min_dis = analyzer.outside_min_dis( 80 | total_label_pair_dis) 81 | logger.info('Computing the inside mean distance...') 82 | inside_mean_dis = analyzer.inside_mean_dis(q) 83 | logger.info('Computing the inside max distance...') 84 | inside_max_dis = analyzer.inside_max_dis(q) 85 | 86 | utils.write_label_dis_vecs( 87 | config.label_vec_path, label_dis_vec, idx2label) 88 | utils.write_diss( 89 | config.outside_min_path, outside_min_dis, idx2label) 90 | utils.write_diss( 91 | config.inside_mean_path, inside_mean_dis, idx2label) 92 | utils.write_diss( 93 | config.inside_max_path, inside_max_dis, idx2label) 94 | 95 | outside_min_dis = np.array(outside_min_dis) 96 | inside_mean_dis = np.array(inside_mean_dis) 97 | inside_max_dis = np.array(inside_max_dis) 98 | min_dis = np.min(outside_min_dis[outside_min_dis > 0]) 99 | mean_inside_mean_dis = np.mean(inside_mean_dis[inside_mean_dis > 0]) 100 | max_inside_max_dis = np.max(inside_max_dis[inside_max_dis > 0]) 101 | 102 | s = 'global min dis={a}'.format(a=str(min_dis)) 103 | logger.info(s) 104 | s = 'mean_inside_mean_dis={a}'.format(a=str(mean_inside_mean_dis)) 105 | 
logger.info(s)
106 | s = 'max inside max dis={a}'.format(a=str(max_inside_max_dis))
107 | logger.info(s)
108 | 
109 | max_min_ratio = []
110 | for i, j in zip(inside_max_dis, outside_min_dis):
111 | if i != 0 and j != 0:
112 | max_min_ratio.append(i/j)
113 | s = 'Inside max / outside min = {a}'.format(a=str(np.mean(max_min_ratio)))
114 | logger.info(s)
115 | 
116 | mean_min_ratio = []
117 | for i, j in zip(inside_mean_dis, outside_min_dis):
118 | if i != 0 and j != 0:
119 | mean_min_ratio.append(i/j)
120 | s = 'Inside mean / outside min = {a}'.format(
121 | a=str(np.mean(mean_min_ratio)))
122 | logger.info(s)
123 | 
124 | assist.result['global min dis'] = min_dis
125 | assist.result['mean inside mean dis'] = mean_inside_mean_dis
126 | assist.result['max inside max dis'] = max_inside_max_dis
127 | assist.result['max_min_ratio'] = np.mean(max_min_ratio)
128 | assist.result['mean_min_ratio'] = np.mean(mean_min_ratio)
129 | 
130 | 
131 | def main():
132 | assist = EA.getAssist('Probing')
133 | assist.deactivate()
134 | 
135 | config = configparser.ConfigParser(
136 | interpolation=configparser.ExtendedInterpolation())
137 | config.read('./config.ini', encoding='utf8')
138 | # config = args.get_args()
139 | 
140 | assist.set_config(config)
141 | with EA.start(assist) as assist:
142 | config = Config(assist.config)
143 | cfg.set_log_path(config.log_path)
144 | logging.config.dictConfig(cfg.LOGGING_CONFIG)
145 | if config.mode == 'prediction':
146 | develop_run(config)
147 | elif config.mode == 'probing':
148 | run(config)
149 | elif config.mode == 'batch_probing':
150 | bp.batch_analyze(config)
151 | 
152 | 
153 | if __name__ == '__main__':
154 | # import cProfile
155 | # cProfile.run('main()', sort='cumulative')
156 | main()
--------------------------------------------------------------------------------
/probing/batch_probing.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #
4 | # Author: Yichu Zhou - flyaway1217@gmail.com
5 | # Blog: zhouyichu.com
6 | #
7 | # Python release: 3.8.0
8 | #
9 | # Date: 2021-02-05 09:18:42
10 | # Last modified: 2021-03-03 09:32:32
11 | 
12 | """
13 | Batch probing for a series of embedding spaces.
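Assumes config.batch_embeddings_path contains one sub-directory per
training run, each holding per-layer embedding files named 1.txt
through 12.txt; one result row is produced per label (or label pair),
with one column per layer.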
14 | """
15 | import logging
16 | from pathlib import Path
17 | 
18 | # import ExAssist as EA
19 | import numpy as np
20 | 
21 | from probing import utils
22 | from probing import loader
23 | from probing.analyzer import Analyzer
24 | from probing.probing import Probe
25 | from probing.clusters import Cluster
26 | from probing.distanceQ import DistanceQ
27 | 
28 | logger = logging.getLogger(__name__)
29 | 
30 | 
31 | def batch_analyze(config):
32 | batch_path = Path(config.batch_embeddings_path)
33 | common_label_vec_path = Path(config.batch_label_vec_path)
34 | common_inside_mean_path = Path(config.batch_inside_mean_path)
35 | common_inside_max_path = Path(config.batch_inside_max_path)
36 | common_outside_min_path = Path(config.batch_outside_min_path)
37 | 
38 | if not common_label_vec_path.exists():
39 | common_label_vec_path.mkdir()
40 | if not common_inside_mean_path.exists():
41 | common_inside_mean_path.mkdir()
42 | if not common_inside_max_path.exists():
43 | common_inside_max_path.mkdir()
44 | if not common_outside_min_path.exists():
45 | common_outside_min_path.mkdir()
46 | 
47 | iterations = list(batch_path.iterdir())
48 | 
49 | for iter_path in iterations:
50 | label_vec_path = common_label_vec_path / (iter_path.name + '.txt')
51 | inside_mean_path = common_inside_mean_path / (iter_path.name + '.txt')
52 | outside_min_path = common_outside_min_path / (iter_path.name + '.txt')
53 | inside_max_path = common_inside_max_path / (iter_path.name + '.txt')
54 | 
55 | label_dis_vecs = []
56 | inside_mean_diss = []
57 | outside_min_diss = []
58 | inside_max_diss = []
59 | for layer_idx in range(1, 13):
60 | layer_path = iter_path / (str(layer_idx) + '.txt')
61 | config.embeddings_path = layer_path
62 | logger.info('Loading embeddings...')
63 | data = loader.load_train(config)
64 | annotations = data[0]
65 | labels = data[1]
66 | embeddings = data[2]
67 | label2idx = data[3]
68 | idx2label = {value: key for key, value in label2idx.items()}  # defined here so it exists even if no layer is separable
69 | 
70 | logger.info('Loading clusters...')
71 | clusters_indices = utils.load_clusters(config.probing_cluster_path)
72 | labels_list = utils.assign_labels(clusters_indices, annotations)
73 | assert len(clusters_indices) == len(labels_list)
74 | # Keep only clusters with more than 5 points
75 | clusters = [Cluster(indices, labs) for
76 | indices, labs in zip(clusters_indices, labels_list)
77 | if len(indices) > 5]
78 | q = DistanceQ(config, embeddings, clusters, len(labels))
79 | logger.info('Finish loading the clusters...')
80 | 
81 | probe = Probe(config)
82 | # If overlaps happen
83 | # if True:
84 | if not probe._check_overlaps(q):
85 | s = str(layer_path) + ' IS linearly separable'
86 | logger.info(s)
87 | analyzer = Analyzer(config)
88 | 
89 | logger.info('Computing the distances between clusters...')
90 | total_label_pair_dis = analyzer.convex2convex(q)
91 | logger.info('Computing the distances vectors...')
92 | label_dis_vec = analyzer.label_dis_vec(total_label_pair_dis)
93 | logger.info('Computing the outside min distance...')
94 | outside_min_dis = analyzer.outside_min_dis(
95 | total_label_pair_dis)
96 | logger.info('Computing the inside mean distance...')
97 | inside_mean_dis = analyzer.inside_mean_dis(q)
98 | logger.info('Computing the inside max distance...')
99 | inside_max_dis = analyzer.inside_max_dis(q)
100 | else:
101 | s = str(layer_path) + ' is NOT linearly separable'
102 | logger.info(s)
103 | label_dis_vec = [0] * ((q.label_size-1)*q.label_size // 2)
104 | outside_min_dis = [0] * q.label_size
105 | inside_mean_dis = [0] * q.label_size
inside_max_dis = [0] * q.label_size
107 | 
108 | if len(label_dis_vecs) != 0:
109 | assert len(label_dis_vecs[-1]) == len(label_dis_vec)
110 | if len(outside_min_diss) != 0:
111 | assert len(outside_min_diss[-1]) == len(outside_min_dis)
112 | if len(inside_mean_diss) != 0:
113 | assert len(inside_mean_diss[-1]) == len(inside_mean_dis)
114 | if len(inside_max_diss) != 0:
115 | assert len(inside_max_diss[-1]) == len(inside_max_dis)
116 | 
117 | label_dis_vecs.append(label_dis_vec)
118 | outside_min_diss.append(outside_min_dis)
119 | inside_mean_diss.append(inside_mean_dis)
120 | inside_max_diss.append(inside_max_dis)
121 | 
122 | label_dis_vecs = np.array(label_dis_vecs).transpose()
123 | outside_min_diss = np.array(outside_min_diss).transpose()
124 | inside_mean_diss = np.array(inside_mean_diss).transpose()
125 | inside_max_diss = np.array(inside_max_diss).transpose()
126 | 
127 | # print(label_dis_vecs.shape)
128 | # print(outside_min_diss.shape)
129 | # print(inside_mean_diss.shape)
130 | utils.write_batch_label_dis_vecs(
131 | label_vec_path, label_dis_vecs, idx2label)
132 | utils.write_batch_diss(
133 | outside_min_path, outside_min_diss, idx2label)
134 | utils.write_batch_diss(
135 | inside_mean_path, inside_mean_diss, idx2label)
136 | utils.write_batch_diss(
137 | inside_max_path, inside_max_diss, idx2label)
--------------------------------------------------------------------------------
/probing/clusters.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #
4 | # Author: Yichu Zhou - flyaway1217@gmail.com
5 | # Blog: zhouyichu.com
6 | #
7 | # Python release: 3.6.0
8 | #
9 | # Date: 2020-03-20 10:56:58
10 | # Last modified: 2021-01-07 10:15:13
11 | 
12 | """
13 | Data structure for probing.
14 | """
15 | 
16 | # import logging
17 | from collections import Counter
18 | from functools import total_ordering
19 | import heapq
20 | from typing import List
21 | # from multiprocessing import Pool
22 | 
23 | # import numpy as np
24 | # import torch
25 | # from tqdm import tqdm
26 | # from tqdm import trange
27 | # from joblib import Parallel, delayed
28 | 
29 | 
30 | class Cluster:
31 | # __slots__ is used here because there will be
32 | # so many Cluster objects during probing, and we
33 | # want to save as much memory as possible.
34 | __slots__ = ('indices', 'major_label',
35 | '_hash_value',
36 | 'children', 'labels')
37 | 
38 | def __init__(self, indices: List[int], labels: List[int]):
39 | """Initialize a new cluster from point indices and labels.
40 | 
41 | Args:
42 | - indices: The index of each point.
43 | - labels: The label of each point.
44 | """
45 | assert len(indices) == len(labels)
46 | self.indices = sorted(indices)
47 | self.labels = labels
48 | self.major_label = Counter(labels).most_common(1)[0][0]
49 | 
50 | self._hash_value = ' '.join([str(i) for i in self.indices])
51 | self._hash_value = hash(self._hash_value)
52 | 
53 | # The children set tracks the merge path.
54 | # It can be used to speed up probing in later steps.
55 | self.children = set()
56 | 
57 | @property
58 | def purity(self) -> float:
59 | n = sum([1 for i in self.labels if i == self.major_label])
60 | return n / len(self.labels)
61 | 
62 | @staticmethod
63 | def merge(A: 'Cluster', B: 'Cluster') -> 'Cluster':
64 | """Merge two clusters and produce a new cluster.
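A minimal illustration (the indices and labels here are arbitrary):

    >>> merged = Cluster.merge(Cluster([0, 2], [1, 1]), Cluster([1], [0]))
    >>> merged.indices
    [0, 1, 2]
    >>> merged.major_label
    1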
65 | """
66 | assert type(A) == Cluster
67 | assert type(B) == Cluster
68 | indices = A.indices + B.indices
69 | labels = A.labels + B.labels
70 | reval = Cluster(indices, labels)
71 | reval.children = A.children | B.children
72 | 
73 | # Do not forget A and B themselves.
74 | reval.children.add(A)
75 | reval.children.add(B)
76 | return reval
77 | 
78 | def __hash__(self):
79 | return self._hash_value
80 | 
81 | def __eq__(self, other):
82 | return self._hash_value == other._hash_value
83 | 
84 | def __repr__(self):
85 | n = len(self.indices)
86 | idx = ' '.join([str(self.indices[i]) for i in range(n)])
87 | labels = ' '.join([str(self.labels[i]) for i in range(n)])
88 | s = 'Cluster(Indices:{a}, labels:{b}, major_label={c}, purity={d})'
89 | s = s.format(a=str(idx), b=labels, c=self.major_label, d=self.purity)
90 | return s
91 | 
92 | 
93 | @total_ordering
94 | class ClusterDisPair:
95 | """An intermediate class used to represent the distance
96 | between two clusters.
97 | It should not be exposed to the end user.
98 | """
99 | __slots__ = ('i', 'j', 'dis', '_hash_value')
100 | 
101 | def __init__(self, i: int, j: int, dis: float):
102 | """
103 | Note that i and j are not indices of points;
104 | they are indices of clusters.
105 | 
106 | Args:
107 | i: The index of the cluster.
108 | j: The index of the cluster.
109 | dis: The distance between these two clusters.
110 | """
111 | assert i != j
112 | self.i = min(i, j)
113 | self.j = max(i, j)
114 | self.dis = dis
115 | self._hash_value = hash((self.i, self.j, dis))
116 | 
117 | def __hash__(self):
118 | return self._hash_value
119 | 
120 | def __eq__(self, other):
121 | return self._hash_value == other._hash_value
122 | 
123 | def __lt__(self, other):
124 | return self.dis < other.dis
125 | 
126 | def __le__(self, other):
127 | return self.dis <= other.dis
128 | 
129 | def __repr__(self):
130 | s = 'ClusterDisPair(i:{a}, j:{b}, dis:{c})'
131 | s = s.format(a=str(self.i), b=str(self.j), c=str(self.dis))
132 | return s
133 | 
134 | 
135 | @total_ordering
136 | class ClusterDisList:
137 | """ A heap of cluster distance pairs.
138 | 
139 | Each list holds all the pairwise distances (i, idx) with i < idx,
140 | where i and idx are cluster indices.
141 | """
142 | __slots__ = ('dis_list', 'idx', '_hash_value')
143 | 
144 | def __init__(self, dis_list: List[ClusterDisPair], idx: int):
145 | self.dis_list = dis_list
146 | heapq.heapify(self.dis_list)
147 | self.idx = idx
148 | self._hash_value = hash(idx)
149 | 
150 | def min(self) -> ClusterDisPair:
151 | """Return the minimum-distance pair of this list.
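Note: this pops the pair off the underlying heap, so repeated calls
drain the list; heapq raises IndexError once the list is empty.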
152 | """
153 | return heapq.heappop(self.dis_list)
154 | 
155 | def deactive(self):
156 | self.dis_list = []
157 | 
158 | def __hash__(self):
159 | return self._hash_value
160 | 
161 | def __eq__(self, other):
162 | if not self.dis_list and not other.dis_list:
163 | return True
164 | elif not self.dis_list or not other.dis_list:
165 | return False
166 | else:
167 | return self.dis_list[0] == other.dis_list[0]
168 | 
169 | def __lt__(self, other):
170 | if not self.dis_list:
171 | return False
172 | if not other.dis_list:
173 | return True
174 | return self.dis_list[0] < other.dis_list[0]
175 | 
176 | def __le__(self, other):
177 | if not self.dis_list:
178 | return False
179 | if not other.dis_list:
180 | return True
181 | return self.dis_list[0] <= other.dis_list[0]
182 | 
183 | def __repr__(self):
184 | if not self.dis_list:
185 | s = 'Index:{a} is deactivated'
186 | s = s.format(a=str(self.idx))
187 | else:
188 | s = 'Index:{a} has minimum value {b}'
189 | s = s.format(a=str(self.idx),
190 | b=str(self.dis_list[0]))
191 | return s
192 | 
193 | 
194 | if __name__ == '__main__':
195 | import random
196 | n = 3
197 | array = []
198 | for i in range(n):
199 | for j in range(i+1, n):  # j > i, since ClusterDisPair asserts i != j
200 | array.append(ClusterDisPair(i, j, i+j))
201 | random.shuffle(array)
202 | print(array)
203 | array.sort()
204 | print(array)
--------------------------------------------------------------------------------
/probing/analyzer.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #
4 | # Author: Yichu Zhou - flyaway1217@gmail.com
5 | # Blog: zhouyichu.com
6 | #
7 | # Python release: 3.8.0
8 | #
9 | # Date: 2020-12-29 13:46:09
10 | # Last modified: 2021-03-03 10:48:44
11 | 
12 | """
13 | Analyzing functions.
14 | """
15 | 
16 | import logging
17 | from typing import List, Tuple
18 | import collections
19 | 
20 | from tqdm import tqdm
21 | from joblib import Parallel, delayed
22 | import torch
23 | 
24 | import numpy as np
25 | from probing.distanceQ import DistanceQ
26 | from probing.space import Space
27 | 
28 | logger = logging.getLogger(__name__)
29 | 
30 | 
31 | class Analyzer:
32 | def __init__(self, config):
33 | self.args = config
34 | 
35 | def predict(self, q, ann, embeds):
36 | return self.points2convex(q, ann, embeds)
37 | 
38 | def points2convex(
39 | self,
40 | q: DistanceQ,
41 | ann: np.array,
42 | embeds: np.array
43 | ) -> Tuple[float, List[List[Tuple[int, int, float]]]]:
44 | """
45 | Make predictions for `embeds` based on the distances
46 | between each point in `embeds` and all the clusters.
47 | 
48 | Returns:
49 | - List((cluster_id, major_label, distance)):
50 | the ranking of clusters for each test point
51 | based on the distance.
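The first element of the returned tuple is the accuracy obtained by
predicting the major label of each point's nearest cluster.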
52 | """
53 | assert len(ann) == len(embeds)
54 | clusters = q.clusters
55 | 
56 | logger.info('Computing the distances...')
57 | return_list = []
58 | correct = 0
59 | for i, (label, vec) in tqdm(
60 | enumerate(zip(ann, embeds)), total=len(ann)):
61 | data = []
62 | diss = []
63 | # Select all the points belonging to cluster j
64 | for j in range(len(clusters)):
65 | cls = clusters[j]
66 | indexs = torch.LongTensor(cls.indices)
67 | vecs = q.fix_embeddings[indexs]
68 | vecs = vecs.cpu().numpy()
69 | data.append((vecs, vec))
70 | diss = Parallel(n_jobs=20, prefer='processes', verbose=0,
71 | batch_size='auto')(
72 | delayed(Space.point2hull)(X1, X2) for X1, X2 in data)
73 | 
74 | diss = np.array(diss)
75 | sorted_indices = np.argsort(diss)
76 | preds = [
77 | (j, clusters[j].major_label, diss[j])
78 | for j in sorted_indices]
79 | 
80 | if preds[0][1] == label:
81 | correct += 1
82 | return_list.append(preds)
83 | acc = correct / len(ann)
84 | return acc, return_list
85 | 
86 | def convex2convex(
87 | self,
88 | q: DistanceQ
89 | ) -> np.array:
90 | """Return the distances between labels, averaged over cluster pairs.
91 | 
92 | total_label_pair_dis[i][j] is the mean distance
93 | between label i and label j.
94 | """
95 | data = []
96 | clusters = q.clusters
97 | 
98 | indices = list(range(len(clusters)))
99 | 
100 | # Prepare the embeddings
101 | for i in range(len(q.clusters)):
102 | cls = q.clusters[i]
103 | indexs = torch.LongTensor(cls.indices)
104 | vecs = q.fix_embeddings[indexs]
105 | vecs = vecs.cpu().numpy()
106 | data.append(vecs)
107 | 
108 | total_label_pair_dis = np.empty((q.label_size, q.label_size))
109 | total_label_pair_dis.fill(np.inf)
110 | indexs = [(i, j) for i in indices for j in indices
111 | if i < j]
112 | # Only compute the distances between clusters with different labels
113 | indexs = [(i, j) for i, j in indexs
114 | if clusters[i].major_label != clusters[j].major_label]
115 | 
116 | data = [(data[i], data[j]) for i, j in indexs]
117 | label_pairs = [(clusters[i].major_label, clusters[j].major_label)
118 | for i, j in indexs]
119 | 
120 | diss = Parallel(n_jobs=10, prefer='processes', verbose=0,
121 | batch_size=1)(
122 | delayed(Space.hull2hull)(X1, X2) for X1, X2 in data)
123 | 
124 | assert len(diss) == len(label_pairs)
125 | label_holder = collections.defaultdict(list)
126 | for (i, j), d in zip(label_pairs, diss):
127 | key = (min(i, j), max(i, j))
128 | label_holder[key].append(d)
129 | 
130 | for (i, j), ds in label_holder.items():
131 | d = np.mean(ds)
132 | total_label_pair_dis[i][j] = d
133 | total_label_pair_dis[j][i] = d
134 | return total_label_pair_dis
135 | 
136 | def label_dis_vec(
137 | self,
138 | total_label_pair_dis: np.array) -> List[float]:
139 | reval = []
140 | n = total_label_pair_dis.shape[0]
141 | for i in range(n):
142 | for j in range(i+1, n):
143 | if total_label_pair_dis[i][j] != np.inf:
144 | reval.append(total_label_pair_dis[i][j])
145 | else:
146 | reval.append(0)
147 | return reval
148 | 
149 | def inside_max_dis(
150 | self,
151 | q: DistanceQ) -> List[float]:
152 | """Computing the max distance inside each cluster.
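Returns a list indexed by label id; a label whose clusters produce
no pairwise distances keeps the default value 0.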
153 | """
154 | data = []
155 | labels = []
156 | for i in range(len(q.clusters)):
157 | cls = q.clusters[i]
158 | indexs = torch.LongTensor(cls.indices)
159 | vecs = q.fix_embeddings[indexs].to(self.args.device)
160 | data.append(vecs)
161 | labels.append(cls.major_label)
162 | 
163 | holder = collections.defaultdict(list)
164 | for tag, embeds in zip(labels, data):
165 | pdist = torch.nn.functional.pdist(embeds)
166 | if len(pdist) > 0:
167 | holder[tag].append(torch.max(pdist).cpu().numpy())
168 | 
169 | reval = [0] * q.label_size
170 | for tag, ds in holder.items():
171 | reval[tag] = np.max(ds)
172 | return reval
173 | 
174 | def inside_mean_dis(
175 | self,
176 | q: DistanceQ) -> List[float]:
177 | """Computing the mean distance inside each cluster.
178 | """
179 | data = []
180 | labels = []
181 | for i in range(len(q.clusters)):
182 | cls = q.clusters[i]
183 | indexs = torch.LongTensor(cls.indices)
184 | vecs = q.fix_embeddings[indexs].to(self.args.device)
185 | data.append(vecs)
186 | labels.append(cls.major_label)
187 | 
188 | holder = collections.defaultdict(list)
189 | for tag, embeds in zip(labels, data):
190 | pdist = torch.nn.functional.pdist(embeds)
191 | if len(pdist) > 0:
192 | holder[tag].append(torch.mean(pdist).cpu().numpy())
193 | 
194 | reval = [0] * q.label_size
195 | for tag, ds in holder.items():
196 | reval[tag] = np.mean(ds)
197 | return reval
198 | 
199 | def outside_min_dis(
200 | self,
201 | total_label_pair_dis: np.array) -> List[float]:
202 | """Find the minimum distance to any other label, for each label.
203 | """
204 | min_dis = []
205 | for t in total_label_pair_dis:
206 | min_dis.append(np.min(t))
207 | if min_dis[-1] == np.inf:
208 | min_dis[-1] = 0
209 | assert len(min_dis) == total_label_pair_dis.shape[0]
210 | return min_dis
--------------------------------------------------------------------------------
/probing/utils.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #
4 | # Author: Yichu Zhou - flyaway1217@gmail.com
5 | # Blog: zhouyichu.com
6 | #
7 | # Python release: 3.6.0
8 | #
9 | # Date: 2019-07-24 09:55:34
10 | # Last modified: 2021-03-03 09:44:37
11 | 
12 | """
13 | Some utility functions, including loading and saving data.
14 | """
15 | 
16 | import collections
17 | from typing import List, Tuple, TextIO, Dict
18 | 
19 | import numpy as np
20 | import torch
21 | 
22 | from probing.distanceQ import DistanceQ
23 | 
24 | 
25 | Pair = collections.namedtuple('Pair', ['Entity', 'Label'])
26 | 
27 | 
28 | def load_entities(path: TextIO):
29 | reval = []
30 | with open(path, encoding='utf8') as f:
31 | for line in f:
32 | s = line.strip().split('\t')
33 | reval.append(Pair(*s))
34 | return reval
35 | 
36 | 
37 | def load_labels(path: TextIO):
38 | reval = set()
39 | with open(path, encoding='utf8') as f:
40 | for line in f:
41 | reval.add(line.strip())
42 | return reval
43 | 
44 | 
45 | def load_embeddings(path: TextIO):
46 | reval = []
47 | with open(path, encoding='utf8') as f:
48 | for line in f:
49 | s = line.strip().split()
50 | vec = [float(v) for v in s]
51 | reval.append(vec)
52 | return reval
53 | 
54 | 
55 | def write_predictions(
56 | path: TextIO,
57 | cluster_list: List[List[Tuple[int, str, float]]],
58 | real_labels: List[str]):
59 | """ Write distances to the file.
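Each output line starts with the gold label, followed by tab-separated
'clusterid-label,distance' entries in ranked order.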
60 | """
61 | assert len(cluster_list) == len(real_labels)
62 | with open(path, 'w', encoding='utf8') as f:
63 | for i, label_dis_pair_list in enumerate(cluster_list):
64 | line = [str(real_labels[i])]
65 | for cls_id, label, dis in label_dis_pair_list:
66 | s = '{a}-{b},{c:0.4f}'.format(
67 | a=str(cls_id), b=str(label), c=dis)
68 | line.append(s)
69 | line = '\t'.join(line)
70 | f.write(line+'\n')
71 | 
72 | 
73 | def write_clusters(path: TextIO, q: DistanceQ):
74 | """Write down the clusters.
75 | """
76 | ans = [-1] * len(q.fix_embeddings)
77 | ans = np.array(ans)
78 | indices = torch.nonzero(q.active).reshape(-1)
79 | indices = indices.cpu().numpy().tolist()
80 | for i, idx in enumerate(indices):
81 | t = q.clusters[idx]
82 | ans[t.indices] = i
83 | 
84 | with open(path, 'w', encoding='utf8') as f:
85 | for i in ans:
86 | f.write(str(i)+'\n')
87 | 
88 | 
89 | def load_clusters(path: TextIO) -> List[List[int]]:
90 | """ Load the clusters from the file.
91 | 
92 | Return:
93 | reval[i] is the list of points that belong
94 | to cluster i.
95 | """
96 | cluster_labels = []
97 | with open(path, encoding='utf8') as f:
98 | for line in f:
99 | cluster_labels.append(int(line.strip()))
100 | cluster_num = max(cluster_labels)+1
101 | reval = [[] for _ in range(cluster_num)]
102 | 
103 | for i, v in enumerate(cluster_labels):
104 | reval[v].append(i)
105 | return reval
106 | 
107 | 
108 | def assign_labels(
109 | clusters_indices: List[List[int]],
110 | annotation: np.array) -> List[List[int]]:
111 | """ Assign the point labels to each cluster.
112 | """
113 | labels = []
114 | for cls in clusters_indices:
115 | labs = [annotation[i] for i in cls]
116 | labels.append(labs)
117 | return labels
118 | 
119 | 
120 | def map_to_label(
121 | idx2label: dict,
122 | cluster_list: List[List[Tuple[int, int, float]]],
123 | real_labels: np.array
124 | ) -> Tuple[List[List[Tuple[int, str, float]]], List[str]]:
125 | """Map the int label to str label.
126 | 
127 | Args:
128 | idx2label: A dictionary from int to str.
129 | cluster_list: cluster_list[i][j] is a tuple of
130 | (cluster_id, label, dis),
131 | representing the distance between test point i and the
132 | cluster with that label.
133 | real_labels: np.array. The real int labels for each test point.
134 | """
135 | assert len(cluster_list) == len(real_labels)
136 | real_labels = [idx2label[v] for v in real_labels]
137 | return_list = []
138 | for i, label_dis_pair_list in enumerate(cluster_list):
139 | s = [(cls_id, idx2label[label], dis)
140 | for cls_id, label, dis in label_dis_pair_list]
141 | return_list.append(s)
142 | return return_list, real_labels
143 | 
144 | 
145 | def write_convex_dis(
146 | path: str,
147 | label_pairs: List[Tuple[str, str]],
148 | diss: List[float]):
149 | """Write distances between clusters into the file.
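Each output line has the form '(i-label_i, j-label_j): distance',
where i and j are cluster ids.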
150 | """ 151 | assert len(label_pairs) == len(diss) 152 | with open(path, 'w', encoding='utf8') as f: 153 | for (cls_i, label_i, cls_j, label_j), dis in zip(label_pairs, diss): 154 | s = '({a}-{b}, {c}-{d}): {e:0.4f}\n'.format( 155 | a=str(cls_i), b=str(label_i), 156 | c=str(cls_j), d=str(label_j), 157 | e=dis) 158 | f.write(s) 159 | 160 | 161 | def write_dis_inside_convex( 162 | path: TextIO, 163 | mean_std: List[Tuple[float, float]], 164 | labels: List[str]): 165 | assert len(mean_std) == len(labels) 166 | with open(path, 'w') as f: 167 | for i in range(len(labels)): 168 | tag = labels[i] 169 | mean, std = mean_std[i] 170 | s = '{a} {b:.4f} {c:.4f}\n' 171 | s = s.format(a=str(tag), b=mean, c=std) 172 | f.write(s) 173 | 174 | 175 | def load_indices(path: TextIO) -> List[List[int]]: 176 | indices = [] 177 | with open(path) as f: 178 | for line in f: 179 | s = [int(t) for t in line.strip().split()] 180 | indices.append(s) 181 | return indices 182 | 183 | 184 | def write_data_cartography( 185 | path: TextIO, 186 | data: List[Tuple[float, float, float]]): 187 | """ 188 | Args: 189 | data: a list of (mean_prob, std_prob, correctness) 190 | """ 191 | with open(path, 'w') as f: 192 | for mean, std, corr in data: 193 | s = '{a:.3f}\t{b:.3f}\t{c:.3f}\n'.format( 194 | a=mean, b=std, c=corr) 195 | f.write(s) 196 | 197 | 198 | def write_batch_label_dis_vecs( 199 | path: TextIO, 200 | data: np.array, 201 | idx2label: Dict[int, str]): 202 | n = len(idx2label) 203 | indices = [(i, j) for i in range(n) for j in range(i+1, n)] 204 | assert len(indices) == len(data) 205 | assert 12 == data.shape[1] 206 | with open(path, 'w') as f: 207 | s = [str(i) for i in range(1, 13)] 208 | s = ','.join(s) 209 | s = ' ,' + s 210 | f.write(s+'\n') 211 | for (i, j), vec in zip(indices, data): 212 | A = idx2label[i] 213 | B = idx2label[j] 214 | s = '{a}---{b},'.format(a=A, b=B) 215 | v = [format(i, '0.4f') for i in vec] 216 | s = s + ','.join(v) 217 | f.write(s+'\n') 218 | 219 | 220 | def write_batch_diss( 221 | path: TextIO, 222 | data: np.array, 223 | idx2label: Dict[int, str]): 224 | n = len(data) 225 | assert n == len(idx2label) 226 | assert 12 == data.shape[1] 227 | with open(path, 'w') as f: 228 | s = [str(i) for i in range(1, 13)] 229 | s = ','.join(s) 230 | s = ' ,' + s 231 | f.write(s+'\n') 232 | for i, vec in zip(list(range(n)), data): 233 | A = idx2label[i] 234 | s = '{a},'.format(a=A) 235 | v = [format(i, '0.4f') for i in vec] 236 | s = s + ','.join(v) 237 | f.write(s+'\n') 238 | 239 | 240 | def write_diss( 241 | path: TextIO, 242 | data: List[float], 243 | idx2label: Dict[int, str]): 244 | n = len(data) 245 | assert n == len(idx2label) 246 | with open(path, 'w') as f: 247 | for i, v in zip(list(range(n)), data): 248 | A = idx2label[i] 249 | s = '{a},{b:0.4f}'.format(a=A, b=v) 250 | f.write(s+'\n') 251 | 252 | 253 | def write_label_dis_vecs( 254 | path: TextIO, 255 | data: List[float], 256 | idx2label: Dict[int, str] 257 | ): 258 | n = len(idx2label) 259 | indices = [(i, j) for i in range(n) for j in range(i+1, n)] 260 | assert len(indices) == len(data) 261 | with open(path, 'w') as f: 262 | for (i, j), v in zip(indices, data): 263 | A = idx2label[i] 264 | B = idx2label[j] 265 | s = '{a}---{b},{c:0.4f}'.format(a=A, b=B, c=v) 266 | f.write(s+'\n') 267 | -------------------------------------------------------------------------------- /probing/distanceQ.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # 4 | # Author: 
Yichu Zhou - flyaway1217@gmail.com
5 | # Blog: zhouyichu.com
6 | #
7 | # Python release: 3.6.0
8 | #
9 | # Date: 2020-04-01 10:23:10
10 | # Last modified: 2020-12-29 11:32:13
11 | 
12 | """
13 | Distance Q implementation.
14 | 
15 | This is the main class to maintain the list of clusters.
16 | """
17 | import logging
18 | from typing import List
19 | import heapq
20 | 
21 | import torch
22 | from tqdm import tqdm
23 | from tqdm import trange
24 | import numpy as np
25 | 
26 | from probing.clusters import ClusterDisList
27 | from probing.clusters import ClusterDisPair
28 | from probing.clusters import Cluster
29 | from probing.config import Config
30 | 
31 | logger = logging.getLogger(__name__)
32 | 
33 | 
34 | class DistanceQ:
35 | """ This distance data structure has the
36 | following functions:
37 | 1. Return the min.
38 | 2. Delete Clusters
39 | 3. Add new Clusters.
40 | """
41 | def __init__(
42 | self,
43 | config: Config,
44 | embeddings: np.array,
45 | init_clusters: List[Cluster],
46 | label_size: int):
47 | """In this class, there are a few attributes that need to be
48 | maintained:
49 | 
50 | - clusters: a list of clusters
51 | - radius: record the radius for each cluster
52 | - embeddings: record the center for each cluster
53 | - active: record which cluster is active
54 | - major_labels: record the majority_label for each cluster
55 | """
56 | # logger = logging.getLogger('Probing')
57 | # logger.info('Initialize the Distance Queue...')
58 | self._args = config
59 | args = config
60 | self.label_size = label_size
61 | self.clusters = list(init_clusters)
62 | 
63 | # self.fix_embeddings holds the original embeddings
64 | # and should not be changed during the process
65 | self.fix_embeddings = torch.Tensor(embeddings)
66 | 
67 | # Initialize the center for each cluster
68 | radius = []
69 | centers = []
70 | for cls in self.clusters:
71 | indexs = cls.indices
72 | vecs = self.fix_embeddings[indexs].to(args.device)
73 | center = torch.mean(vecs, 0)
74 | center = center.reshape(1, -1)
75 | centers.append(center)
76 | # Find the maximum distance and
77 | # use it as the radius
78 | dis = torch.cdist(vecs, center)
79 | r = torch.max(dis).reshape(1)
80 | radius.append(r)
81 | self.radius = torch.cat(radius).to(args.device)
82 | 
83 | # self.embeddings is the centers of each cluster
84 | # It changes dynamically during the merging process
85 | self.embeddings = torch.cat(centers).to(args.device)
86 | assert self.radius.shape[0] == len(self.clusters)
87 | assert self.embeddings.shape[0] == len(self.clusters)
88 | 
89 | # Initialize the major labels
90 | major_labels = [t.major_label for t in init_clusters]
91 | self.major_labels = torch.IntTensor(major_labels).to(config.device)
92 | 
93 | self.active = torch.ones(
94 | len(self.clusters)).bool().to(config.device)
95 | self.heap_built = False
96 | 
97 | def remove_pair(self, i, j) -> None:
98 | """ Cluster i and cluster j have been merged and
99 | should be removed from the lists.
100 | """
101 | self.remove(i)
102 | self.remove(j)
103 | if self.heap_built:
104 | heapq.heapify(self.cluster_dis)
105 | 
106 | def remove(self, idx: int) -> None:
107 | """ Remove all the records regarding cluster idx.
108 | """
109 | self.active[idx] = False
110 | if self.heap_built:
111 | self.dis_maps[idx].deactive()
112 | 
113 | def min(self) -> ClusterDisPair:
114 | """ Return and delete the pair with the least distance.
115 | """
116 | while True:
117 | # Find the list of pairs that has the minimum distance.
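# (note) self.cluster_dis is a heap of ClusterDisList objects, each of
# which is itself a heap; popping the outer heap yields the list whose
# head pair is the global minimum over all active pairs.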
127 |     def add(self, newcluster):
128 |         """Add a new cluster.
129 |         """
130 |         args = self._args
131 |         self.clusters.append(newcluster)
132 | 
133 |         # Compute the new center
134 |         indexs = torch.LongTensor(newcluster.indices)
135 |         vecs = self.fix_embeddings[indexs].to(args.device)
136 |         center = torch.mean(vecs, 0)
137 |         center = center.reshape(1, -1)
138 | 
139 |         # Add the cluster distance pairs
140 |         if self.heap_built:
141 |             diss = torch.cdist(self.embeddings[self.active], center)
142 |             diss = diss.reshape(-1).cpu().numpy().tolist()
143 |             nonzeros = torch.nonzero(self.active).reshape(-1)
144 |             nonzeros = nonzeros.cpu().numpy().tolist()
145 |             tmp = []
146 |             m = len(self.clusters)-1
147 | 
148 |             # Build the new distance list for the new cluster:
149 |             # compute the pair-wise distances to all other clusters
150 |             # that have the same label.
151 |             for i, v in enumerate(nonzeros):
152 |                 if self.clusters[v].major_label == newcluster.major_label:
153 |                     tmp.append(ClusterDisPair(v, m, diss[i]))
154 |             tmp = ClusterDisList(tmp, m)
155 |             self.dis_maps[m] = tmp
156 |             heapq.heappush(self.cluster_dis, tmp)
157 | 
158 |         # Update the active mask
159 |         active = torch.Tensor([True]).bool().to(args.device)
160 |         self.active = torch.cat((self.active, active))
161 | 
162 |         # Update the centers
163 |         self.embeddings = torch.cat((self.embeddings, center), 0)
164 | 
165 |         dis = torch.cdist(vecs, center)
166 |         r = torch.max(dis).reshape(1)
167 |         self.radius = torch.cat((self.radius, r))
168 | 
169 |         # Update the major labels
170 |         m = torch.IntTensor([newcluster.major_label]).to(args.device)
171 |         self.major_labels = torch.cat((self.major_labels, m))
172 | 
173 |         assert len(self.embeddings) == len(self.active)
174 |         assert len(self.clusters) == len(self.active)
175 |         assert len(self.radius) == len(self.active)
176 |         assert len(self.major_labels) == len(self.active)
177 |         if self.heap_built:
178 |             assert len(self.cluster_dis) == len(self.active)
179 | 
180 |     def __len__(self):
181 |         return len(torch.nonzero(self.active))
182 | 
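    # Layout of the "double heap" built by build_heaps() below (sketch):
    # dis_maps[j] is a ClusterDisList that owns every pair (i, j) with
    # i < j and a matching major label, kept as a heap; cluster_dis is a
    # heap of those lists, keyed by each list's current minimum.  Popping
    # the global minimum is therefore two cheap heap operations:
    #
    #     dislist = heapq.heappop(self.cluster_dis)   # best list
    #     pair = dislist.min()                        # best pair inside it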
183 |     def build_heaps(self):
184 |         if self.heap_built:
185 |             return
186 |         args = self._args
187 |         logger.info('Initializing the pair-wise distance...')
188 | 
189 |         logger.info('Categorize the clusters based on the label...')
190 |         labels_indices = [[] for _ in range(self.label_size)]
191 |         for i, t in enumerate(self.clusters):
192 |             labels_indices[t.major_label].append(i)
193 | 
194 |         logger.info(
195 |             'Computing the pair-wise distance inside the same label...')
196 |         tmp = [[] for _ in range(len(self.clusters))]
197 |         for i in trange(len(labels_indices), desc='Labels'):
198 |             indices = labels_indices[i]
199 |             idx = torch.LongTensor(indices).to(args.device)
200 |             embeds = self.embeddings[idx]
201 |             # Compute the pair-wise distance between
202 |             # all clusters that have the same label.
203 |             pdist = torch.nn.functional.pdist(embeds)
204 |             pdist = pdist.cpu().numpy().tolist()
205 |             m = len(embeds)
206 |             ij = torch.triu_indices(m, m, 1)
207 |             ij = ij.cpu().numpy().tolist()
208 | 
209 |             s = 'Building list for label {a}'.format(a=str(i))
210 |             for k in trange(len(ij[0]), desc=s, leave=False):
211 |                 a = min(indices[ij[0][k]], indices[ij[1][k]])
212 |                 b = max(indices[ij[0][k]], indices[ij[1][k]])
213 |                 t = ClusterDisPair(a, b, pdist[k])
214 |                 tmp[b].append(t)
215 | 
216 |         self.dis_maps = dict()
217 |         cluster_dis = []
218 |         assert len(tmp[0]) == 0
219 | 
220 |         logger.info('Build the double heaps...')
221 |         for i in tqdm(range(len(tmp))):
222 |             t = ClusterDisList(tmp[i], i)
223 |             self.dis_maps[i] = t
224 |             heapq.heappush(cluster_dis, t)
225 |         self.cluster_dis = cluster_dis
226 |         self.heap_built = True
227 | 
228 |     @staticmethod
229 |     def cleanQ(args, q: 'DistanceQ'):
230 |         """ Build a new clean Q.
231 |         """
232 |         indices = torch.nonzero(q.active).reshape(-1)
233 |         clusters = [q.clusters[i] for i in indices]
234 |         return DistanceQ(args, q.fix_embeddings,
235 |                          clusters, q.label_size)
236 | 
--------------------------------------------------------------------------------
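Before probing.py, it may help to see how DistanceQ is driven. A minimal sketch of one merge step, under stated assumptions (the Cluster constructor signature is assumed here; Config is only assumed to carry the .device attribute used above):

from probing.clusters import Cluster
from probing.distanceQ import DistanceQ

# One singleton cluster per embedding point (constructor signature assumed).
clusters = [Cluster([i], labels[i]) for i in range(len(embeddings))]
q = DistanceQ(config, embeddings, clusters, label_size)
q.build_heaps()

pair = q.min()                                # closest same-label pair
merged = Cluster.merge(q.clusters[pair.i], q.clusters[pair.j])
q.remove_pair(pair.i, pair.j)                 # remove first, then add
q.add(merged)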
/probing/probing.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | #
4 | # Author: Yichu Zhou - flyaway1217@gmail.com
5 | # Blog: zhouyichu.com
6 | #
7 | # Python release: 3.6.0
8 | #
9 | # Date: 2020-02-18 11:05:08
10 | # Last modified: 2022-09-22 21:22:35
11 | 
12 | """
13 | Applying the probing process.
14 | """
15 | 
16 | import logging
17 | 
18 | from typing import List, Tuple
19 | import torch
20 | import numpy as np
21 | from tqdm import tqdm
22 | import ExAssist as EA
23 | 
24 | from probing.space import Space
25 | from probing.clusters import Cluster
26 | from probing.distanceQ import DistanceQ
27 | from probing.config import Config
28 | 
29 | 
30 | logger = logging.getLogger(__name__)
31 | 
32 | 
33 | class Probe:
34 |     """ The clustering strategy:
35 |     1. Keep merging to the end.
36 |     2. We check for overlapping at the end.
37 |     3. If there is overlapping, we trace back to the step
38 |     where the first error happens.
39 |     4. If there is no overlapping, directly return the result.
40 | 
41 |     The reason we do this is that we want to avoid
42 |     unnecessary overlapping checks as much as possible.
43 |     The overlapping check is expensive.
44 |     """
45 |     def __init__(self, config: Config):
46 |         self._args = config
47 |         self.space = Space(self._args)
48 | 
49 |     def probe_to_end(self, q: DistanceQ) -> List[np.ndarray]:
50 |         """ Probing to the end without any overlapping checking.
51 |         """
52 |         q.build_heaps()
53 |         max_num_steps = len(q)
54 |         # Initialize the tracks...
55 |         logger.info('Initializing the tracks...')
56 |         tracks = [self._snapshot(q)]
57 | 
58 |         logger.info('Probing to the end...')
59 |         iterator = tqdm(range(max_num_steps))
60 |         for step in iterator:
61 |             try:
62 |                 pair = q.min()
63 |             except IndexError:
64 |                 iterator.close()
65 |                 break
66 |             i, j = pair.i, pair.j
67 | 
68 |             # This is for debugging. min() should not return a pair
69 |             # whose major_labels are different.
70 |             if q.clusters[i].major_label != q.clusters[j].major_label:
71 |                 logger.error('Cluster pairs have different major_label!')
72 |                 continue
73 |             newcluster = Cluster.merge(q.clusters[i], q.clusters[j])
74 | 
75 |             # It is important to remove first and then add
76 |             q.remove_pair(i, j)
77 |             q.add(newcluster)
78 | 
79 |             # Save every step
80 |             tracks.append(np.copy(tracks[-1]))
81 |             tracks[-1][newcluster.indices] = len(q.clusters)-1
82 |         logger.info('Finish probing to the end...')
83 |         s = tracks[-1].tolist()
84 |         assert len(set(s)) >= len(q)
85 |         return tracks
86 | 
87 |     def _snapshot(self, q: DistanceQ) -> np.ndarray:
88 |         """ Take a snapshot of the current DistanceQ.
89 | 
90 |         This method returns an array `tracks` whose length
91 |         is equal to the number of embedding points.
92 | 
93 |         tracks[i] is the index of the cluster that the i-th
94 |         embedding belongs to.
95 |         """
96 |         indices = torch.nonzero(q.active).reshape(-1)
97 |         indices = indices.cpu().numpy()
98 |         tracks = list(range(len(q.fix_embeddings)))
99 |         tracks = np.array(tracks)
100 |         for i, idx in enumerate(indices):
101 |             t = q.clusters[idx]
102 |             tracks[t.indices] = idx
103 |         return tracks
104 | 
105 |     def probing(self, q: DistanceQ) -> DistanceQ:
106 |         """ Apply the probing procedure.
107 |         """
108 |         assist = EA.getAssist('Probing')
109 |         tracks = self.probe_to_end(q)
110 |         logger.info('Checking the end state...')
111 |         t = self._check_overlaps(q)
112 |         if not t:
113 |             logger.info('There are no overlaps in the end state!')
114 |             logger.info('This space is linear!')
115 |             assist.result['linear'] = 1
116 |             s = 'Final number of clusters: {a}'
117 |             s = s.format(a=str(len(q)))
118 |             logger.info(s)
119 |             return q
120 | 
121 |         logger.info('This space is non-linear!')
122 |         assist.result['linear'] = 0
123 |         # Use coarse search to find the range of the first error
124 |         i, j = self._find_coarse_range(tracks, q)
125 |         # Use binary search to find the first error
126 |         k = self._find_first_error(q, tracks, i, j)
127 | 
128 |         s = 'Found: state {a} is the last state before the first error'
129 |         s = s.format(a=str(k-1))
130 |         logger.info(s)
131 | 
132 |         # Rebuild the DistanceQ using the track
133 |         # before the first error.
134 |         q = self._build_q_from_track(q, tracks[k-1])
135 | 
136 |         # Keep merging and checking.
137 |         q = self._forward(q)
138 | 
139 |         s = 'Final number of clusters: {a}'
140 |         s = s.format(a=str(len(q)))
141 |         logger.info(s)
142 |         return q
143 | 
144 |     def _check_overlaps(self, q: DistanceQ) -> bool:
145 |         """ Check pair-wise clusters for overlapping.
146 | 
147 |         Return true if there is at least one overlap.
148 |         """
149 |         indexs = torch.nonzero(q.active).reshape(-1)
150 |         indexs = indexs.cpu().numpy()
151 |         logger.info('Start overlapping checking...')
152 |         iterator = tqdm(indexs)
153 |         flag = False
154 | 
155 |         for i in iterator:
156 |             if self.space.overlapping(q, q.clusters[i]):
157 |                 flag = True
158 |                 # Here we do not use early stopping
159 |                 # because we want to cache all the errors
160 |                 # in the space object.
161 |                 # It may take time here, but it saves more later.
162 |                 # return flag
163 |         return flag
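    # Timeline sketch for the back-tracking above (illustrative):
    #
    #   state:    0 ... k-1      k      ...      m
    #   overlap?  no ... no     yes     ...     yes
    #
    # probe_to_end() records tracks[0..m]; _find_coarse_range() steps
    # backwards by int(m * rate)+1 until a clean state is found, and
    # _find_first_error() binary-searches inside that window for k, the
    # first state with an overlap.  Merging then restarts from tracks[k-1]
    # with per-merge overlap checks (_forward).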
164 | 
165 |     def _build_clusters_from_track(
166 |             self,
167 |             q: DistanceQ,
168 |             track: np.ndarray):
169 |         assert len(track) == len(q.fix_embeddings)
170 |         cls_set = set(track.tolist())
171 |         clusters = []
172 | 
173 |         # Collect the clusters.
174 |         for n in cls_set:
175 |             clusters.append(q.clusters[n])
176 | 
177 |         # Make sure the number of embedding points is correct.
178 |         s = [len(t.indices) for t in clusters]
179 |         assert sum(s) == len(track)
180 |         return clusters
181 | 
182 |     def _find_coarse_range(
183 |             self,
184 |             tracks: List[np.ndarray],
185 |             q: DistanceQ) -> Tuple[int, int]:
186 |         """Return a range of states that contains the
187 |         first error state.
188 | 
189 |         Returns: (i, j)
190 |         """
191 |         m = len(tracks)-1
192 |         if m < 1000:
193 |             return 0, m
194 |         step = int(m * self._args.rate)+1  # avoid a step of 0
195 |         logger.info('Start coarse search...')
196 |         k = m - step
197 |         while True:
198 |             logger.info('Test for state {a}'.format(a=str(k)))
199 |             newq = self._build_q_from_track(q, tracks[k])
200 |             t = self._check_overlaps(newq)
201 |             if not t:
202 |                 s = 'Found {a}-th state is correct...'
203 |                 s = s.format(a=str(k))
204 |                 logger.info(s)
205 |                 return (k, k+step)
206 |             else:
207 |                 k = k - step
208 |                 k = max(0, k)
209 | 
210 |     def _forward(self, q: DistanceQ) -> DistanceQ:
211 |         """ Merging with overlapping checking.
212 |         """
213 |         q.build_heaps()
214 |         m = len(q)
215 |         logger.info('Start normal forward probing...')
216 |         with tqdm(range(m)) as pbar:
217 |             while True:
218 |                 try:
219 |                     pair = q.min()
220 |                 except IndexError:
221 |                     return q
222 |                 i, j = pair.i, pair.j
223 |                 newcluster = Cluster.merge(q.clusters[i], q.clusters[j])
224 |                 t = self.space.overlapping(q, newcluster)
225 |                 if not t:
226 |                     q.remove_pair(i, j)
227 |                     q.add(newcluster)
228 |                 pbar.update(1)
229 |         return q
230 | 
231 |     # def _closest_set(self, q, test_vec):
232 |     #     args = self._args
233 |     #     embeds = q.fix_embeddings.to(args.device)
234 |     #     vec = torch.Tensor(test_vec).to(args.device)
235 | 
236 |     #     cdist = torch.cdist(embeds, vec.reshape(1, -1))
237 |     #     cdist = cdist.reshape(-1).cpu().numpy()
238 |     #     min_dists = []
239 |     #     for t in q.clusters:
240 |     #         min_dists.append(min(cdist[t.indices]))
241 |     #     n = min(len(q.clusters), 5)
242 |     #     return np.argsort(min_dists)[:n]
243 | 
244 |     def _build_q_from_track(
245 |             self,
246 |             q: DistanceQ,
247 |             track: np.ndarray) -> DistanceQ:
248 |         clusters = self._build_clusters_from_track(q, track)
249 |         newq = DistanceQ(
250 |             self._args, q.fix_embeddings,
251 |             clusters, q.label_size)
252 |         return newq
253 | 
254 |     def _find_first_error(
255 |             self,
256 |             q: DistanceQ,
257 |             tracks: List[np.ndarray],
258 |             i: int,
259 |             j: int) -> int:
260 |         """Use binary search to find the first error.
261 |         """
262 |         logger.info('Start fine search...')
263 |         while i < j:
264 |             k = (i+j) // 2
265 |             logger.info('Test for state {a}'.format(a=str(k)))
266 |             newq = self._build_q_from_track(q, tracks[k])
267 |             t = self._check_overlaps(newq)
268 |             if t:
269 |                 j = k
270 |             else:
271 |                 i = k + 1
272 |         return j
273 | 
--------------------------------------------------------------------------------
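The Probe class is the entry point that ties the queue and the space checker together. A minimal sketch of a run (assuming q was built as in the DistanceQ sketch above; note that probing() records its verdict through ExAssist, which the top-level runner is expected to have configured):

from probing.probing import Probe

probe = Probe(config)
q = probe.probing(q)     # merge until done, back-track on the first overlap
print('final number of clusters:', len(q))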
261 | """ 262 | logger.info('Start fine search...') 263 | while i < j: 264 | k = (i+j) // 2 265 | logger.info('Test for state {a}'.format(a=str(k))) 266 | newq = self._build_q_from_track(q, tracks[k]) 267 | t = self._check_overlaps(newq) 268 | if t: 269 | j = k 270 | else: 271 | i = k + 1 272 | return j 273 | -------------------------------------------------------------------------------- /probing/space.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # 4 | # Author: Yichu Zhou - flyaway1217@gmail.com 5 | # Blog: zhouyichu.com 6 | # 7 | # Python release: 3.6.0 8 | # 9 | # Date: 2020-03-20 10:42:42 10 | # Last modified: 2020-12-29 14:30:21 11 | 12 | import logging 13 | from typing import List, Tuple 14 | 15 | from sklearn.preprocessing import StandardScaler 16 | import torch 17 | import numpy as np 18 | from joblib import Parallel, delayed 19 | # from scipy.spatial import distance 20 | import gurobipy as gp 21 | from gurobipy import GRB 22 | # from sklearn.svm import LinearSVC 23 | from sklearn.svm import SVC 24 | 25 | from probing.distanceQ import DistanceQ 26 | from probing.clusters import Cluster 27 | 28 | Tensor = torch.Tensor 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | class Space: 34 | def __init__(self, args): 35 | self._args = args 36 | self.whitelist = [] 37 | self.blacklist = [] 38 | self.solver = self.check_overlap_detector() 39 | 40 | def check_overlap_detector(self): 41 | """ 42 | If gurobi is valid, we will use gurobi to solve a LP problem. 43 | Otherwise, a hard svm will be used to find the hyperplane. 44 | """ 45 | s1 = ('Gurobi is NOT found in the system, ' 46 | 'we will use sklearn.SCV instead.') 47 | s2 = ('Gurobi IS found in the system, we will use Gurobi.') 48 | try: 49 | gp.Model("lp") 50 | logger.info(s2) 51 | return Space.lp 52 | except Exception: 53 | logger.info(s1) 54 | return Space.hardSVM 55 | 56 | @staticmethod 57 | def point2hull(X: np.array, p: np.array) -> float: 58 | """Compute the distance between cluster X and point p. 59 | """ 60 | # clf = LinearSVC(tol=1e-5, loss='hinge', C=1000, max_iter=20000) 61 | clf = SVC(tol=1e-5, C=10000, kernel='linear', max_iter=20000) 62 | y = [0] * len(X) 63 | y.append(1) 64 | y = np.array(y) 65 | XX = np.concatenate((X, p.reshape(1, -1))) 66 | clf.fit(XX, y) 67 | # Can not separate them 68 | # This means the test point 69 | # is inside of the convex hull. 70 | # As a result, the distance is zero. 71 | if clf.score(XX, y) != 1.0: 72 | return 0 73 | 74 | # Compute the distance from point to hyperplane 75 | w = clf.coef_.reshape(-1) 76 | b = clf.intercept_[0] 77 | d = np.dot(w, p) + b 78 | d = abs(d) / np.linalg.norm(w) 79 | 80 | return 2*d 81 | 82 | @staticmethod 83 | def hull2hull(X1: np.array, X2: np.array) -> float: 84 | """Compute the distance between two convex hulls. 
85 | """ 86 | clf = SVC(tol=1e-5, C=1000000, kernel='linear', max_iter=50000) 87 | # clf = LinearSVC(tol=1e-5, loss='hinge', C=100000, max_iter=20000) 88 | y1 = [1] * len(X1) 89 | y2 = [-1] * len(X2) 90 | y1 = np.array(y1) 91 | y2 = np.array(y2) 92 | XX = np.concatenate((X1, X2)) 93 | yy = np.concatenate((y1, y2)) 94 | clf.fit(XX, yy) 95 | 96 | score = clf.score(XX, yy) 97 | # Two convex hulls overlap 98 | if score != 1: 99 | return 0 100 | 101 | w = clf.coef_.reshape(-1) 102 | b = clf.intercept_[0] 103 | 104 | d = np.dot(XX, w) + b 105 | d = abs(d) / np.linalg.norm(w) 106 | 107 | d = d*2 108 | return np.min(d) 109 | 110 | def overlapping(self, q: DistanceQ, t: Cluster) -> bool: 111 | """ Check if the newcluster overlaps with other clusters 112 | with different labels. 113 | 114 | Return: 115 | True: there is at least one overlapping. 116 | False: there is no overlapping. 117 | """ 118 | # Filter the clusters that have the potential to overlap 119 | # with t. 120 | indexs = self._sphere_overlapping(q, t) 121 | indexs = indexs.cpu().numpy() 122 | if len(indexs) == 0: 123 | return False 124 | 125 | # Filter through the black list 126 | if [i for i in indexs if self._black_cached(q.clusters[i], t)]: 127 | return True 128 | 129 | # Filter through the white list 130 | indexs = [i for i in indexs 131 | if not self._white_cached(q.clusters[i], t)] 132 | 133 | if len(indexs) == 0: 134 | return False 135 | else: 136 | # Apply the LP solver. 137 | return self._lp_overlapping(q, t, indexs) 138 | 139 | def _sphere(self, q: DistanceQ, cluster: Cluster) -> Tensor: 140 | """ Find the clusters that have the potential to overlap 141 | with the new cluster. 142 | """ 143 | args = self._args 144 | # Compute the center for new cluster 145 | indexs = torch.LongTensor(cluster.indices) 146 | vecs = q.fix_embeddings[indexs].to(args.device) 147 | center = torch.mean(vecs, 0) 148 | center = center.reshape(1, -1) 149 | 150 | # Find the maximum distance and 151 | # use it as the radius 152 | dis = torch.cdist(vecs, center) 153 | r = torch.max(dis).reshape(1) 154 | 155 | embeddings = q.embeddings 156 | dist = torch.cdist(embeddings, center).reshape(-1) 157 | assert len(embeddings) == len(dist) 158 | 159 | overlap_mask = (q.radius + r) > dist 160 | return overlap_mask 161 | 162 | def _sphere_overlapping(self, q: DistanceQ, newcluster: Cluster) -> Tensor: 163 | """Check if the newcluster overlaps with other clusters 164 | with different labels in sphere. 165 | """ 166 | args = self._args 167 | # Compute the new center and radius 168 | overlap_mask = self._sphere(q, newcluster) 169 | 170 | # We do not compare the clusters with the same labels 171 | indexs = torch.ones(q.embeddings.shape[0]).bool() 172 | idx = q.major_labels == newcluster.major_label 173 | indexs[idx] = False 174 | indexs = indexs.to(args.device) 175 | assert indexs.shape == overlap_mask.shape 176 | 177 | # 1. potential overlap 178 | # 2. Have different labels 179 | # 3. It is a active cluster 180 | mask = overlap_mask & indexs & q.active 181 | return torch.nonzero(mask).reshape(-1) 182 | 183 | def _lp_overlapping( 184 | self, 185 | q: DistanceQ, 186 | newcluster: Cluster, 187 | indexs: List[int]) -> bool: 188 | """Apply the LP solver to check overlapping. 189 | 190 | Return: 191 | True: there is at least overlapping. 192 | False: there is not overlapping. 
193 | """ 194 | # logger = logging.getLogger('Probing') 195 | cur_vecs = q.fix_embeddings[torch.LongTensor(newcluster.indices)] 196 | X1 = cur_vecs.numpy() 197 | data = [] 198 | for i in indexs: 199 | idxs = q.clusters[i].indices 200 | vecs = q.fix_embeddings[torch.LongTensor(idxs)] 201 | X2 = vecs.numpy() 202 | data.append((X1, X2)) 203 | 204 | # logger.info('Solving {a} LP...'.format(a=str(len(data)))) 205 | results = Parallel(n_jobs=30, prefer='processes', verbose=0, 206 | batch_size='auto')( 207 | delayed(self.solver)(X1, X2) for X1, X2 in data) 208 | 209 | # Add white list to avoid further computation 210 | for i, v in enumerate(results): 211 | cache = self._add_list(newcluster, q.clusters[indexs[i]]) 212 | if v == 0: 213 | self.whitelist.append(cache) 214 | else: 215 | self.blacklist.append(cache) 216 | 217 | if np.sum(results) != 0: 218 | return True 219 | else: 220 | return False 221 | 222 | @staticmethod 223 | def lp(X1: np.array, X2: np.array) -> int: 224 | """Return 1 when the LP problem is infeasible. 225 | """ 226 | # logger = logging.getLogger('Probing') 227 | m = X1.shape[1] 228 | model = gp.Model("lp") 229 | 230 | # Create variables 231 | W = model.addMVar(shape=m+1, lb=-GRB.INFINITY, 232 | ub=GRB.INFINITY, vtype=GRB.CONTINUOUS, name="W") 233 | # Adding the bias 234 | XX1 = np.concatenate((X1, np.ones((X1.shape[0], 1))), axis=1) 235 | XX2 = np.concatenate((X2, np.ones((X2.shape[0], 1))), axis=1) 236 | 237 | Y1 = np.array([1]*X1.shape[0]) 238 | Y2 = np.array([-1]*X2.shape[0]) 239 | 240 | model.addConstr(XX1 @ W >= Y1) 241 | model.addConstr(XX2 @ W <= Y2) 242 | model.setObjective(0, GRB.MINIMIZE) 243 | model.setParam('OutputFlag', False) 244 | model.setParam('FeasibilityTol', 1e-4) 245 | 246 | # Optimize model 247 | model.update() 248 | model.optimize() 249 | 250 | # s = 'Solved a LP problem, status code {a}' 251 | # logger.info(s.format(a=str(model.Status))) 252 | return int(model.Status != GRB.OPTIMAL) 253 | 254 | def hardSVM(X1: np.array, X2: np.array): 255 | """Return 1 when the X1 and X2 are not separable. 256 | """ 257 | clf = SVC(tol=1e-5, C=10000, kernel='linear', max_iter=500000) 258 | # clf = LinearSVC(tol=1e-5, loss='hinge', C=100000, max_iter=20000) 259 | y1 = [0] * len(X1) 260 | y2 = [1] * len(X2) 261 | y1 = np.array(y1) 262 | y2 = np.array(y2) 263 | XX = np.concatenate((X1, X2)) 264 | yy = np.concatenate((y1, y2)) 265 | 266 | scaler = StandardScaler() 267 | scaler.fit(XX) 268 | XX = scaler.transform(XX) 269 | clf.fit(XX, yy) 270 | 271 | score = clf.score(XX, yy) 272 | # Two convex hulls overlap 273 | if score != 1: 274 | return 1 275 | else: 276 | return 0 277 | 278 | def _add_list(self, A: Cluster, B: Cluster) -> Tuple[set, set]: 279 | """ Add A and B, along with all their children 280 | into white list. 281 | """ 282 | setA = set(A.indices) 283 | setB = set(B.indices) 284 | return (setA, setB) 285 | 286 | def _white_cached(self, A: Cluster, B: Cluster) -> bool: 287 | """Test if cluster A and B have been tested. 288 | 289 | Return true if A and B have been tested. 290 | """ 291 | a = set(A.indices) 292 | b = set(B.indices) 293 | for setA, setB in self.whitelist: 294 | if (a <= setA and b <= setB) or (b <= setA and a <= setB): 295 | return True 296 | else: 297 | return False 298 | 299 | def _black_cached(self, A: Cluster, B: Cluster) -> bool: 300 | """Test if any subclusters of A and B have been tested. 301 | 302 | Return true if subclusters of A and B have been tested. 
303 | """ 304 | a = set(A.indices) 305 | b = set(B.indices) 306 | for setA, setB in self.blacklist: 307 | if (setA <= a and setB <= b) or (setA <= b and setB <= a): 308 | return True 309 | else: 310 | return False 311 | --------------------------------------------------------------------------------