├── neural ├── train_crossencoder.py ├── train_biencoder.py ├── modified_sbert │ ├── data_loaders.py │ ├── train.py │ └── clu_evaluators.py └── neural_eval.py ├── README.md ├── inference_at_scale ├── scaled_lsh.py ├── data_fns.py ├── test_sets.py └── biencoder_at_scale.py ├── rule_based ├── small_lsh.py ├── ngrams.py └── rule_based_utils.py ├── environment.yml └── utils.py /neural/train_crossencoder.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from datetime import datetime 3 | 4 | from modified_sbert.train import train_crossencoder 5 | 6 | 7 | def extract_raw_data(dataset_path): 8 | 9 | raw_data = pd.read_csv(dataset_path, sep='\t', encoding='utf-8') 10 | 11 | sentence_1_list = [str(i) for i in list(raw_data["Text 1"])] 12 | sentence_2_list = [str(i) for i in list(raw_data["Text 2"])] 13 | labels = list(raw_data["Label"]) 14 | 15 | return {'sentence_1': sentence_1_list, 'sentence_2': sentence_2_list, "labels": labels} 16 | 17 | 18 | if __name__ == '__main__': 19 | 20 | data_path = '' 21 | 22 | train_crossencoder( 23 | train_data=extract_raw_data(f'{data_path}/train_set.csv'), 24 | dev_data=extract_raw_data(f'{data_path}/dev_set.csv'), 25 | model_name='roberta-base', 26 | lr=2e-05, 27 | train_batch_size=32, 28 | num_epochs=5, 29 | warm_up_perc=0.2, 30 | eval_per_epoch=10, 31 | model_save_path=f'output/{datetime.now().strftime("%Y-%m-%d_%H-%M")}' 32 | ) 33 | -------------------------------------------------------------------------------- /neural/train_biencoder.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from datetime import datetime 3 | 4 | from sentence_transformers import losses 5 | 6 | from modified_sbert.train import train_biencoder 7 | 8 | 9 | def extract_raw_data(dataset_path): 10 | 11 | raw_data = pd.read_csv(dataset_path, sep='\t', encoding='utf-8') 12 | 13 | sentence_1_list = [str(i) for i in list(raw_data["Text 1"])] 14 | sentence_2_list = [str(i) for i in list(raw_data["Text 2"])] 15 | labels = list(raw_data["Label"]) 16 | 17 | return {'sentence_1': sentence_1_list, 'sentence_2': sentence_2_list, "labels": labels} 18 | 19 | 20 | if __name__ == '__main__': 21 | 22 | data_path = '' 23 | 24 | train_biencoder( 25 | train_data=extract_raw_data(f'{data_path}/train_set.csv'), 26 | dev_data=extract_raw_data(f'{data_path}/dev_set.csv'), 27 | base_model='sentence-transformers/all-mpnet-base-v2', 28 | add_pooling_layer=False, 29 | train_batch_size=32, 30 | num_epochs=16, 31 | warmup_epochs=16, 32 | loss_fn='contrastive', 33 | loss_params={'distance_metric': losses.SiameseDistanceMetric.COSINE_DISTANCE, 'margin': 0.2}, 34 | model_save_path=f'output/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}', 35 | ) 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NEWS-COPY 2 | 3 | Code for our paper "Noise-Robust De-Duplication at Scale" 4 | [[NBER](https://www.nber.org/papers/w30726)], [[arxiv](https://arxiv.org/abs/2210.04261)], [[ICLR](https://openreview.net/forum?id=bAz2DBS35i)] 5 | 6 | This repo includes: 7 | - NEWS-COPY dataset: 27,210 document dataset, with 122,876 positive duplicate pairs, for studying noise-robust de-duplication. 8 | - Rule-based de-duplication methods: hashing and N-gram overlap. 
9 | - Neural de-duplication methods: a contrastively trained bi-encoder, and a "re-rank" style approach combining a bi- and cross-encoder. Plus pre-trained models for both of these.
10 | - Inference at scale for hashing and the bi-encoder methods.
11 | 
12 | If you find this work useful, please cite the following paper:
13 | 
14 |     @inproceedings{silcock-etal-2023-noise,
15 |       title = "Noise-Robust De-Duplication at Scale",
16 |       author = "Silcock, Emily and D'Amico-Wong, Luca and Yang, Jinglin and Dell, Melissa",
17 |       booktitle = "International Conference on Learning Representations (ICLR)",
18 |       year = "2023",
19 |     }
20 | 
21 | ### Installation
22 | 
23 |     git clone https://github.com/dell-research-harvard/NEWS-COPY.git
24 |     cd NEWS-COPY
25 |     conda env create -f environment.yml
26 | 
27 | 
28 | ### Data
29 | - Historical Newspapers: train, evaluation and test sets can be downloaded [here](https://www.dropbox.com/sh/so3iw4xecayyrow/AAAiy5FhDf0WpUeHFzxO1SIza?dl=0). For more detail, see the paper above.
30 | - C4: C4 can be downloaded thanks to AllenAI - see https://github.com/allenai/allennlp/discussions/5056
31 | 
32 | 
33 | ### Rule-based
34 | Codebase for rule-based de-duplication methods: N-gram overlap and locality-sensitive hashing. These predominate in the literature, but are significantly outperformed by the neural methods below.
35 | 
36 | 
37 | ### Neural
38 | Training and evaluation scripts for a contrastively trained bi-encoder, and a "re-rank" style approach combining a bi- and cross-encoder. These outperform the rule-based approaches above.
39 | 
40 | Pretrained models for both the bi-encoder and cross-encoder can be found [here](https://www.dropbox.com/sh/so3iw4xecayyrow/AAAiy5FhDf0WpUeHFzxO1SIza?dl=0).
41 | 
42 | ### Inference at scale
43 | Inference at scale for hashing (LSH) and the bi-encoder methods over C4 and SuperGLUE. The bi-encoder scales well, de-duplicating a 10-million-text corpus on a single GPU card in a matter of hours.
44 | 
--------------------------------------------------------------------------------
/inference_at_scale/scaled_lsh.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | 
3 | from datasketch import MinHash, LeanMinHash, MinHashLSH
4 | import pickle
5 | 
6 | import data_fns
7 | 
8 | 
9 | def remove_odd_characters(text):
10 |     ''' Removes punctuation and unknown characters. '''
11 |     chars_to_remove = r'"#$%&\()*+/:;<=>@[\\]^_`{|}~.?,!\''
12 | 
13 |     text = text.replace("-\n", "").replace("\n", " ")
14 |     text = text.translate(str.maketrans('', '', chars_to_remove))
15 |     text = text.encode('ascii', 'ignore').decode()
16 | 
17 |     return text
18 | 
19 | 
20 | def minhash(text, n_gram_size, num_hashes):
21 |     ''' Returns a MinHash object for a given text. '''
22 |     text = remove_odd_characters(text)
23 | 
24 |     m = MinHash(num_perm=num_hashes)
25 |     words = text.split()
26 | 
27 |     if len(words) > n_gram_size:
28 |         n_grams = list(zip(*[words[i:] for i in range(n_gram_size)]))
29 |         n_grams = [" ".join(list(x)) for x in n_grams]
30 |     # If text is too short, just hash the entire text
31 |     else:
32 |         n_grams = [text]
33 | 
34 |     m.update_batch([s.encode('utf-8') for s in n_grams])
35 | 
36 |     return LeanMinHash(m)
37 | 
38 | 
39 | def get_hashes_batched(articles, n_gram_size, num_hashes):
40 |     ''' Returns hashed articles.
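    Each article is cleaned with remove_odd_characters(), shingled into word n-grams of
    length `n_gram_size`, and MinHashed with `num_hashes` permutations; the result is a
    list of LeanMinHash objects in the same order as `articles`.

    Illustrative usage (values mirror the __main__ block below):
        hashes = get_hashes_batched(corpus, n_gram_size=10, num_hashes=30)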
''' 41 | minhashes = [] 42 | 43 | for article in tqdm(articles): 44 | minhashes.append(minhash(article, n_gram_size, num_hashes)) 45 | 46 | return minhashes 47 | 48 | 49 | def lsh_similar(minhashes, num_hashes, bands, rows, n_gram_size, texts): 50 | ''' Creates edges between articles given hashes. ''' 51 | 52 | lsh = MinHashLSH(num_perm=num_hashes, params=(bands, rows)) 53 | for i, hsh in enumerate(tqdm(minhashes)): 54 | # Check if duplicate of already seen item 55 | for j in lsh.query(hsh): 56 | yield (j, i) 57 | # Add to the seen items 58 | lsh.insert(i, hsh) 59 | 60 | 61 | if __name__ == '__main__': 62 | 63 | corpus = data_fns.open_c4_by_url(pattern="patents.google.com", name="patents") 64 | # corpus = data_fns.get_super_glue() 65 | 66 | num_hashes = 30 67 | n_gram_size = 10 68 | 69 | hashes = get_hashes_batched(corpus, n_gram_size, num_hashes) 70 | 71 | similar_all = list(lsh_similar(hashes, num_hashes, 15, 2, n_gram_size, corpus)) 72 | total_count = len(similar_all) 73 | print("Total Edges: ", total_count) 74 | 75 | with open(f'', 'wb') as f: 76 | pickle.dump(similar_all, f, protocol=4) 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /neural/modified_sbert/data_loaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | 5 | from sentence_transformers.readers import InputExample 6 | 7 | currentdir = os.path.dirname(os.path.realpath(__file__)) 8 | parentdir = os.path.dirname(currentdir) 9 | grandparentdir = os.path.dirname(parentdir) 10 | sys.path.append(parentdir) 11 | sys.path.append(grandparentdir) 12 | 13 | import utils 14 | 15 | 16 | def load_data_as_individuals(data, type): 17 | 18 | sentence_1_list = data['sentence_1'] 19 | sentence_2_list = data['sentence_2'] 20 | labels = data['labels'] 21 | 22 | # Organise by cluster 23 | edges_list = [] 24 | for i in range(len(sentence_1_list)): 25 | if labels[i] == "same": 26 | edges_list.append([sentence_1_list[i], sentence_2_list[i]]) 27 | 28 | cluster_dict = utils.clusters_from_edges(edges_list) 29 | 30 | # Pull out texts and cluster IDs 31 | indv_data = [] 32 | guid = 1 33 | for cluster_id in list(cluster_dict.keys()): 34 | 35 | for text in cluster_dict[cluster_id]: 36 | indv_data.append(InputExample(guid=guid, texts=[text], label=cluster_id)) 37 | 38 | guid += 1 39 | 40 | print(f'{len(indv_data)} {type} examples') 41 | 42 | return indv_data 43 | 44 | 45 | def load_data_as_pairs(data, type): 46 | 47 | sentence_1_list = data['sentence_1'] 48 | sentence_2_list = data['sentence_2'] 49 | labels = data['labels'] 50 | 51 | label2int = {"same": 1, "different": 0, 1: 1, 0: 0} 52 | 53 | paired_data = [] 54 | for i in range(len(sentence_1_list)): 55 | label_id = label2int[labels[i]] 56 | paired_data.append(InputExample(texts=[sentence_1_list[i], sentence_2_list[i]], label=float(label_id))) 57 | 58 | print(f'{len(paired_data)} {type} pairs') 59 | 60 | return paired_data 61 | 62 | 63 | def load_data_as_triplets(data, type): 64 | 65 | sentence_1_list = data['sentence_1'] 66 | sentence_2_list = data['sentence_2'] 67 | labels = data['labels'] 68 | 69 | # Create dict of examples where you have labels, at the anchor level 70 | def add_to_samples(sent1, sent2, label): 71 | if sent1 not in anchor_dict: 72 | anchor_dict[sent1] = {'same': set(), 'different': set()} 73 | anchor_dict[sent1][label].add(sent2) 74 | 75 | anchor_dict = {} 76 | for i in range(len(sentence_1_list)): 77 | add_to_samples(sentence_1_list[i], 
sentence_2_list[i], labels[i]) 78 | add_to_samples(sentence_2_list[i], sentence_1_list[i], labels[i]) #Also add the opposite 79 | 80 | # Create triplets 81 | triplet_data = [] 82 | for anchor, others in anchor_dict.items(): 83 | while len(others['same']) > 0 and len(others['different']) > 0: 84 | 85 | same_sent = random.choice(list(others['same'])) 86 | dif_sent = random.choice(list(others['different'])) 87 | 88 | triplet_data.append(InputExample(texts=[anchor, same_sent, dif_sent])) 89 | 90 | others['same'].remove(same_sent) 91 | others['different'].remove(dif_sent) 92 | 93 | print(f'{len(triplet_data)} {type} triplets') 94 | 95 | return triplet_data 96 | -------------------------------------------------------------------------------- /rule_based/small_lsh.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from itertools import combinations 3 | from datasketch import MinHash 4 | from collections import defaultdict 5 | import json 6 | import copy 7 | import sys 8 | import os 9 | from os.path import dirname as up 10 | 11 | 12 | sys.path.append(up(up(up(os.path.realpath(__file__))))) 13 | 14 | import rule_based_utils 15 | 16 | 17 | def get_counts(data_dict, num_hashes): 18 | ''' Gets parwise counts of shared hashes.''' 19 | 20 | pairwise_counts = defaultdict(int) 21 | hash_dict = [defaultdict(list) for _ in range(num_hashes)] 22 | m1 = MinHash(num_perm=num_hashes) 23 | 24 | # iterate through all the articles 25 | for date in data_dict: 26 | for index, ngrams in enumerate(data_dict[date]['ngram_list']): 27 | m1.clear() 28 | article_id = data_dict[date]['id_list'][index] 29 | 30 | # hash ngrams 31 | for ngram in ngrams: 32 | m1.update(ngram.encode('utf8')) 33 | 34 | # add hashes to hash table 35 | for index, hash_val in enumerate(list(m1.digest())): 36 | hash_dict[index][hash_val].append(article_id) 37 | 38 | # iterate through hash table and form edges 39 | for table in hash_dict: 40 | for val in table: 41 | if len(table[val]) >= 2: 42 | for pair in combinations(table[val], 2): 43 | pairwise_counts[pair] += 1 44 | 45 | return pairwise_counts 46 | 47 | 48 | def get_edges(counts, threshold): 49 | ''' Get edges given pairwise counts and threshold. 
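    A pair of article ids becomes an edge when the number of MinHash positions on which
    the two articles collide (as counted in get_counts) is at least `threshold`.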
''' 50 | edges = set() 51 | 52 | for pair in counts: 53 | if counts[pair] >= threshold: 54 | edges.add(pair) 55 | 56 | return list(edges) 57 | 58 | 59 | if __name__ == '__main__': 60 | 61 | cleaned_text, cleaned_ids = rule_based_utils.gather_data(data_file_path='') 62 | ground_truth_path = '' 63 | 64 | # Set hyperparameters 65 | num_hashes = 10 66 | 67 | for (n_gram_size, thresholds) in tqdm([(10, [3, 4]), (15, [3, 4]), (20, [2, 3]), (25, [2, 3])]): 68 | 69 | # Get n-grams data for articles 70 | data_dict = rule_based_utils.get_ngrams(copy.deepcopy(cleaned_text), copy.deepcopy(cleaned_ids), n_gram_size=n_gram_size, concat=True, char=True) 71 | 72 | # Calculate n-gram overlaps and return pairs that meet overlap threshold 73 | counts = get_counts(copy.deepcopy(data_dict), num_hashes=num_hashes) 74 | 75 | for threshold in tqdm(thresholds): 76 | 77 | # Get edges from pairwise counts, threshold 78 | edges_list = get_edges(counts, threshold) 79 | 80 | # Get evaluation metrics 81 | edge_results, cluster_results, full_results = rule_based_utils.get_eval_metrics(edges_list, ground_truth_path, cleaned_ids, community_detection=True) 82 | 83 | # Add to dictionary of grid-search results 84 | edge_metrics[str((num_hashes, threshold, n_gram_size))] = edge_results 85 | cluster_metrics[str((num_hashes, threshold, n_gram_size))] = cluster_results 86 | full_metrics[str((num_hashes, threshold, n_gram_size))] = full_results 87 | 88 | # Create dictionary to export results 89 | results = {"edge_metrics": edge_metrics, "cluster_metrics": cluster_metrics, "full_metrics": full_metrics} 90 | 91 | # Save results 92 | j = json.dumps(results, indent=4) 93 | f = open('', 'w') 94 | print(j, file=f) 95 | f.close() 96 | -------------------------------------------------------------------------------- /inference_at_scale/data_fns.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import gzip 3 | from glob import glob 4 | from tqdm import tqdm 5 | from datetime import datetime 6 | 7 | from datasets import load_dataset 8 | 9 | 10 | def open_realnews(): 11 | 12 | start = datetime.now() 13 | file_list = [f'/c4/realnewslike/c4-train.{str(i).zfill(5)}-of-00512.json.gz' for i in range(512)] 14 | file_list.extend(glob('/c4/realnewslike/c4-validation**')) 15 | 16 | corpus = [] 17 | print("Loading data ...") 18 | for file in tqdm(file_list): 19 | 20 | with gzip.open(file, 'r') as fin: 21 | json_bytes = fin.read() 22 | json_str = json_bytes.decode('utf-8') 23 | str_split = json_str.split("\n") 24 | for string in str_split: 25 | if len(string) != 0: 26 | text = string.split('"text":"')[1].split('","timestamp"')[0] 27 | corpus.append(text) 28 | 29 | print(len(corpus), "files in corpus") 30 | print("Time taken:", datetime.now() - start) 31 | 32 | return corpus 33 | 34 | 35 | def open_c4_by_url(pattern="patents.google.com", name="patents"): 36 | 37 | start = datetime.now() 38 | 39 | full_corpus = [] 40 | for set in ['train', 'validation']: 41 | file_list = glob(f'/c4/en/c4-{set}**.json.gz') 42 | 43 | corpus = [] 44 | print("Loading data ...") 45 | for file in tqdm(file_list): 46 | 47 | with gzip.open(file, 'r') as fin: 48 | json_bytes = fin.read() 49 | json_str = json_bytes.decode('utf-8') 50 | str_split = json_str.split("\n") 51 | for string in str_split: 52 | if len(string) != 0: 53 | 54 | url = string.split('"url":"')[1].split('"')[-2] 55 | if pattern in url: 56 | 57 | text = string.split('"text":"')[1].split('","timestamp"')[0] 58 | corpus.append(text) 59 | 60 | print(len(corpus), 
f"files in {set} set") 61 | print("Time taken:", datetime.now() - start) 62 | 63 | with open(f"/c4/{name}_{set}.pkl", "wb") as f: 64 | pickle.dump(corpus, f, protocol=4) 65 | 66 | full_corpus.extend(corpus) 67 | 68 | return corpus 69 | 70 | 71 | def get_super_glue(): 72 | ''' Retrieves data from SuperGLUE. ''' 73 | 74 | # Which sections of data to load for SuperGLUE 75 | split_dict = { 76 | 'wsc': [['text']], 77 | 'boolq': [['passage']], 78 | 'cb': [['premise', 'hypothesis']], 79 | 'copa': [['premise', 'choice1'], ['premise', 'choice2']], 80 | 'multirc': [['paragraph']], 81 | 'record': [['passage']], 82 | 'rte': [['premise', 'hypothesis']], 83 | 'wic': [['sentence1'], ['sentence2']], 84 | } 85 | 86 | for ds in split_dict.keys(): 87 | print(f"**** {ds} ****") 88 | 89 | # Select text data 90 | dataset = load_dataset("super_glue", ds) 91 | 92 | dev_set = dataset['validation'] 93 | texts = [] 94 | for feat_list in split_dict[ds]: 95 | for i in range(len(dev_set)): 96 | text_list = [dev_set[feat][i] for feat in feat_list] 97 | texts.append(" ".join(text_list)) 98 | 99 | print(len(texts), "texts in corpus") 100 | corpus = list(set(texts)) 101 | print(len(corpus), "texts in corpus after deduplication") 102 | 103 | return corpus -------------------------------------------------------------------------------- /rule_based/ngrams.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from datetime import datetime 3 | from itertools import product 4 | import multiprocessing 5 | 6 | import json 7 | import copy 8 | import sys 9 | import os 10 | from os.path import dirname as up 11 | 12 | currentdir = os.path.dirname(os.path.realpath(__file__)) 13 | parentdir = os.path.dirname(currentdir) 14 | grandparentdir = os.path.dirname(parentdir) 15 | sys.path.append(parentdir) 16 | sys.path.append(grandparentdir) 17 | 18 | from rule_based import rule_based_utils 19 | 20 | 21 | def ngram_overlap(data_dict, overlap): 22 | '''Returns edges between articles given overlap threshold.''' 23 | 24 | global data 25 | 26 | data = data_dict 27 | 28 | # Compare pairs of passages, calculate overlap percentage and keep pair if meets overlap threshold 29 | print("\n Calculating overlaps ...") 30 | cores = multiprocessing.cpu_count() 31 | pool = multiprocessing.Pool(processes=cores) 32 | list_of_edges_lists = pool.starmap(compare_passage_pairs, [(date, overlap) for date in list(data.keys())]) 33 | pool.close() 34 | 35 | # Collapse into single list 36 | edges_list = [item for sublist in list_of_edges_lists for item in sublist] 37 | return edges_list 38 | 39 | 40 | def compare_passage_pairs(date_str, overlap): 41 | """Module for calculating N-Gram overlap on multiple cores""" 42 | 43 | same_day_dat = data[date_str] 44 | n_same_day = len(same_day_dat['art_list']) 45 | edges = [] 46 | 47 | # iterate through all other dates 48 | for alt_date in list(data.keys()): 49 | date_dt = datetime.strptime(date_str, "%b-%d-%Y") 50 | new_dt = datetime.strptime(alt_date, "%b-%d-%Y") 51 | 52 | # only compare to later dates (to prevent repetitions) 53 | if date_dt >= new_dt: 54 | new_day_dat = data[alt_date] 55 | n_new_day = len(new_day_dat['art_list']) 56 | 57 | print(f"\n Computing N-gram overlap between {date_str} and {alt_date} passages...") 58 | for i, j in product(range(n_same_day), range(n_new_day)): 59 | compare_overlap_to_threshold(i, j, data_i=same_day_dat, data_j=new_day_dat, overlap=overlap, outfile=edges) 60 | 61 | return edges 62 | 63 | 64 | def compare_overlap_to_threshold(i, 
j, data_i, data_j, overlap, outfile): 65 | 66 | # count ngram overlaps 67 | passage_i_set = data_i['ngram_list'][i] 68 | passage_j_set = data_j['ngram_list'][j] 69 | intersect = passage_i_set.intersection(passage_j_set) 70 | overlap_count = len(intersect) 71 | 72 | # compute percentage of possible ngrams that overlapped 73 | if len(passage_i_set) != 0 and len(passage_j_set) != 0: 74 | # overlap_pct = overlap_count / min(len(passage_i_set), len(passage_j_set)) 75 | overlap_pct = overlap_count / len(passage_i_set.union(passage_j_set)) 76 | else: 77 | overlap_pct = 0 78 | 79 | # compare to overlap threshold and add edge if meets threshold 80 | if overlap_pct >= overlap: 81 | id_i = data_i['id_list'][i] 82 | id_j = data_j['id_list'][j] 83 | text_i = data_i['art_list'][i] 84 | text_j = data_j['art_list'][j] 85 | outfile.append(( 86 | id_i, 87 | id_j, 88 | { 89 | 'text_1': text_i, 90 | 'text_2': text_j, 91 | 'overlap': overlap_pct 92 | } 93 | )) 94 | 95 | 96 | if __name__ == '__main__': 97 | 98 | cleaned_text, cleaned_ids = rule_based_utils.gather_data(data_file_path='') 99 | 100 | edge_metrics = {} 101 | cluster_metrics = {} 102 | full_metrics = {} 103 | 104 | ground_truth_path = '' 105 | 106 | for (n_gram_size, overlaps) in tqdm([(15, [0.4])]): 107 | for overlap in overlaps: 108 | edge_metrics[str((overlap, n_gram_size))] = {} 109 | cluster_metrics[str((overlap, n_gram_size))] = {} 110 | full_metrics[str((overlap, n_gram_size))] = {} 111 | 112 | data_dict = rule_based_utils.get_ngrams(copy.deepcopy(cleaned_text), copy.deepcopy(cleaned_ids), n_gram_size=n_gram_size, concat=False, char=True) 113 | 114 | # Calculate n-gram overlaps and return pairs that meet overlap threshold 115 | edges_list = ngram_overlap(copy.deepcopy(data_dict), overlap=overlap) 116 | 117 | # Get evaluation metrics 118 | edge_results, cluster_results, full_results = rule_based_utils.get_eval_metrics(edges_list, ground_truth_path, cleaned_ids, community_detection=True) 119 | 120 | # Add to dictionary of grid-search results 121 | edge_metrics[str((overlap, n_gram_size))] = edge_results 122 | cluster_metrics[str((overlap, n_gram_size))] = cluster_results 123 | full_metrics[str((overlap, n_gram_size))] = full_results 124 | 125 | # Create dictionary to export results 126 | results = {"edge_metrics": edge_metrics, "cluster_metrics": cluster_metrics, "full_metrics": full_metrics} 127 | 128 | # Save results 129 | j = json.dumps(results, indent=4) 130 | f = open('', 'w') 131 | print(j, file=f) 132 | f.close() 133 | -------------------------------------------------------------------------------- /inference_at_scale/test_sets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python 3.8 3 | Compares SuperGlue to RealNews and pulls out duplicates between the sets. 
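Pipeline: RealNews-like bi-encoder embeddings are reloaded from disk, the validation split
of each SuperGLUE task is embedded with the supplied bi-encoder, FAISS retrieves the nearest
RealNews neighbours for every SuperGLUE text, and pairs whose similarity clears the
bi-encoder threshold are written out as JSON in 'pairs'.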
4 | """ 5 | 6 | import pickle 7 | import json 8 | from tqdm import tqdm 9 | from datetime import datetime 10 | import math 11 | import numpy as np 12 | import os 13 | 14 | from datasets import load_dataset 15 | import faiss 16 | 17 | import data_fns 18 | from biencoder_at_scale import embed 19 | 20 | 21 | def find_real_news_neighbours(query_embeddings, corpus_embeddings, k=900, d=768, corpus_batch_size=10000000): 22 | 23 | """ 24 | Pull k nearest neighbours to every text in query embeddings from each batch of corpus embeddings 25 | """ 26 | 27 | start_time = datetime.now() 28 | 29 | # Initialise faiss 30 | res = faiss.StandardGpuResources() 31 | 32 | n_corpus_batches = math.ceil(corpus_embeddings.shape[0] / corpus_batch_size) 33 | print("Total batches:", n_corpus_batches) 34 | 35 | # Batch over corpus 36 | dist_list = [] 37 | nn_list = [] 38 | 39 | for j in range(n_corpus_batches): 40 | 41 | print(f"***** Corpus batch {j} *****") 42 | 43 | gpu_index_flat = faiss.GpuIndexFlatIP(res, d) 44 | gpu_index_flat.add(corpus_embeddings[(corpus_batch_size * j):corpus_batch_size * (j + 1)]) 45 | 46 | D, I = gpu_index_flat.search(query_embeddings, k) 47 | dist_list.append(D) 48 | 49 | # Adjust IDs for not starting at 0 50 | I_adj = I + j * corpus_batch_size 51 | nn_list.append(I_adj) 52 | 53 | gpu_index_flat.reset() 54 | 55 | end_time = datetime.now() 56 | print("total time elapsed: ", (end_time - start_time)) 57 | 58 | dist_list = np.concatenate(dist_list, axis=1) 59 | nn_list = np.concatenate(nn_list, axis=1) 60 | 61 | print(dist_list.shape) 62 | print(nn_list.shape) 63 | 64 | return dist_list, nn_list 65 | 66 | 67 | def compare_to_superglue(rn_embedding_file_list, biencoder_model, biencoder_threshold, working_directory): 68 | 69 | """ 70 | Find duplicates between realnews and all the test sets in superglue. 71 | Biencoder embeddings are loaded from saved file (can be computed in biencoder_at_scale.py) 72 | SuperGLUE texts downloaded from Datasets. 
73 | Texts of duplicates are saved in 'pairs' 74 | """ 75 | 76 | os.makedirs(working_directory, exist_ok=True) 77 | os.makedirs(f'{working_directory}/embeddings', exist_ok=True) 78 | os.makedirs(f'{working_directory}/pairs', exist_ok=True) 79 | 80 | # Load real news embeddings 81 | print("Loading real news embeddings from...", rn_embedding_file_list) 82 | rn_embeddings = [] 83 | for file in tqdm(rn_embedding_file_list): 84 | with open(file, 'rb') as f: 85 | rn_embeddings.append(pickle.load(f)) 86 | real_news_embeddings = np.concatenate(rn_embeddings, axis=0) 87 | 88 | # Load real news texts 89 | real_news_corpus = data_fns.open_realnews() 90 | 91 | # Load SuperGlue data 92 | # Which sections of data to load for SuperGLUE 93 | split_dict = { 94 | 'wsc': [['text']], 95 | 'boolq': [['passage']], 96 | 'cb': [['premise', 'hypothesis']], 97 | 'copa': [['premise', 'choice1'], ['premise', 'choice2']], 98 | 'multirc': [['paragraph']], 99 | 'record': [['passage']], 100 | 'rte': [['premise', 'hypothesis']], 101 | 'wic': [['sentence1'], ['sentence2']], 102 | } 103 | 104 | for ds in split_dict.keys(): 105 | print(f"**** {ds} ****") 106 | 107 | # Select text data 108 | dataset = load_dataset("super_glue", ds) 109 | 110 | dev_set = dataset['validation'] 111 | texts = [] 112 | for feat_list in split_dict[ds]: 113 | for i in range(len(dev_set)): 114 | text_list = [dev_set[feat][i] for feat in feat_list] 115 | texts.append(" ".join(text_list)) 116 | 117 | print(len(texts), "texts in corpus") 118 | corpus = list(set(texts)) 119 | print(len(corpus), "texts in corpus after removing exact duplicates") 120 | 121 | # Embed 122 | embeddings = embed( 123 | corpus, 124 | trained_model=biencoder_model, 125 | save_stem=f'{working_directory}/embeddings/{ds}', 126 | batch_size=1024 127 | ) 128 | 129 | # Compare to real news embeddings 130 | dist, nn = find_real_news_neighbours(embeddings, real_news_embeddings, k=900, d=768, corpus_batch_size=5000000) 131 | 132 | # Subset to neighbours under threshold 133 | under_th = [nn[i][(dist[i] >= biencoder_threshold)] for i in range(len(nn))] 134 | 135 | # Convert to edges 136 | nn_edge_list = [(i, j) for i in range(len(under_th)) for j in under_th[i] if j != i] 137 | print("Pairs found:", len(nn_edge_list)) 138 | 139 | print(f"Number of {ds} texts with at least one duplicate:", len(set([edge[0] for edge in nn_edge_list]))) 140 | 141 | # Match back to texts 142 | out_list = [] 143 | for edge in nn_edge_list: 144 | out_list.append({ 145 | str(edge[0]): str(corpus[edge[0]]), 146 | str(edge[1]): str(real_news_corpus[edge[1]]), 147 | }) 148 | 149 | with open(f'{working_directory}/pairs/{ds}.json', 'w') as f: 150 | json.dump(out_list, f, indent=4) 151 | 152 | 153 | if __name__ == '__main__': 154 | 155 | rn_embedding_file_list = [f"/embeddings/real_news_like_embeddings_{i}.pkl" for i in range(0, 21)] 156 | 157 | compare_to_superglue( 158 | rn_embedding_file_list=rn_embedding_file_list, 159 | biencoder_model='', 160 | biencoder_threshold=0.94, 161 | working_directory='' 162 | ) 163 | 164 | 165 | -------------------------------------------------------------------------------- /neural/neural_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python 3.8 3 | Evaluation for all neural models. 
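A filtering step (bi-encoder similarity or n-gram overlap) proposes candidate pairs, the
cross-encoder optionally re-scores them, and community detection cleans the resulting graph
before edge- and cluster-level metrics are computed: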
Biencoder + cross encoder (optional) + community detection 4 | """ 5 | 6 | import json 7 | import os 8 | import sys 9 | 10 | import numpy as np 11 | import pandas as pd 12 | from itertools import combinations 13 | import logging 14 | 15 | from sentence_transformers import LoggingHandler, util, SentenceTransformer 16 | from sentence_transformers.cross_encoder import CrossEncoder 17 | from transformers import logging as lg 18 | 19 | currentdir = os.path.dirname(os.path.realpath(__file__)) 20 | parentdir = os.path.dirname(currentdir) 21 | grandparentdir = os.path.dirname(parentdir) 22 | sys.path.append(parentdir) 23 | sys.path.append(grandparentdir) 24 | 25 | import utils 26 | from rule_based import rule_based_utils, ngrams 27 | 28 | # Config logging 29 | lg.set_verbosity_error() 30 | logging.basicConfig(format='%(asctime)s - %(message)s', 31 | datefmt='%Y-%m-%d %H:%M:%S', 32 | level=logging.INFO, 33 | handlers=[LoggingHandler()]) 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | def filter_data( 38 | corpus_dict, 39 | filter_type='ngrams', 40 | parameters={'n_gram_size': 5, 'overlap': 0.01} 41 | ): 42 | 43 | """ 44 | Return all pairs from some filtering step 45 | - If biencoder: return all pairs closer than biencoder threshold 46 | - If ngrams return all pairs with ngrams in common 47 | - If none, return all pairs 48 | """ 49 | 50 | if filter_type == 'ngrams': 51 | 52 | cleaned_id_list, cleaned_text_list = rule_based_utils.clean_text( 53 | corpus_dict, 54 | first_n_tok=None, 55 | min_tok=None, 56 | spell_check="symspell" 57 | ) 58 | 59 | data_dict = rule_based_utils.get_ngrams(cleaned_text_list, cleaned_id_list, n_gram_size=parameters['n_gram_size']) 60 | 61 | # Calculate n-gram overlaps and return pairs that meet overlap threshold 62 | pairs_to_compare = ngrams.ngram_overlap(data_dict, overlap=parameters['overlap']) 63 | 64 | elif filter_type == 'biencoder': 65 | 66 | # Initialise model 67 | embedder = SentenceTransformer(parameters['model']) 68 | 69 | corpus = [] 70 | for art_id in list(corpus_dict.keys()): 71 | corpus.append(corpus_dict[art_id]['article']) 72 | 73 | # Embed data 74 | print("Embedding corpus ...") 75 | corpus_embeddings = embedder.encode(corpus, show_progress_bar=True, batch_size=512) 76 | corpus_embeddings = corpus_embeddings / np.linalg.norm(corpus_embeddings, axis=1, keepdims=True) 77 | 78 | corpus_ids = list(corpus_dict.keys()) 79 | 80 | # Compute distances 81 | cosine_scores = util.cos_sim(corpus_embeddings, corpus_embeddings) 82 | 83 | # Compare to threshold 84 | above_threshold = cosine_scores > parameters['threshold'] 85 | upper_only = np.triu(np.ones((len(corpus_ids), len(corpus_ids))) - np.identity(len(corpus_ids))) 86 | result = above_threshold * upper_only 87 | indices = [index for index, value in np.ndenumerate(result) if value] 88 | pairs_to_compare = [[corpus_ids[pair[0]], corpus_ids[pair[1]]] for pair in indices] 89 | 90 | else: 91 | pairs_to_compare = [list(comb) for comb in combinations(list(corpus_dict.keys()), 2)] 92 | 93 | print("Number of pairs to compare:", len(pairs_to_compare)) 94 | 95 | inf_samples = [] 96 | for pair in pairs_to_compare: 97 | inf_samples.append([str(corpus_dict[pair[0]]['article']), str(corpus_dict[pair[1]]['article'])]) 98 | 99 | return inf_samples, pairs_to_compare 100 | 101 | 102 | def crossencoder_inference( 103 | inf_samples, 104 | pairs_to_compare, 105 | trained_model_path, 106 | batch_size=128, 107 | save_dir=None 108 | ): 109 | 110 | """ 111 | Run cross-encoder inference over set of pairs 112 | """ 113 | 114 | # Run 
evaluation 115 | trained_model = CrossEncoder(trained_model_path, num_labels=1) 116 | 117 | dev_results = pd.read_csv(f"{trained_model_path}/CEBinaryClassificationEvaluator_dev_results.csv") 118 | threshold = list(dev_results['F1_Threshold'])[-1] 119 | 120 | inference = trained_model.predict(inf_samples, batch_size=batch_size, apply_softmax=True, show_progress_bar=True) 121 | 122 | outputs = [] 123 | for i in range(len(inference)): 124 | if inference[i] > threshold: 125 | outputs.append(pairs_to_compare[i]) 126 | 127 | if save_dir: 128 | os.makedirs(save_dir, exist_ok=True) 129 | with open(f'{save_dir}/predicted_edges.json', 'w') as of: 130 | json.dump(outputs, of, indent=4) 131 | 132 | return outputs 133 | 134 | 135 | def evaluate(gt_path, inf_data_path, biencoder_model, cross_encoder=False, cross_encoder_model=''): 136 | 137 | """ 138 | Run filtering step (see filter_data), then cross-encoder (optional). Evaluate results 139 | """ 140 | 141 | with open(inf_data_path) as f: 142 | inf_data = json.load(f) 143 | 144 | # Filter edges 145 | filtered_data, pred_ids = filter_data( 146 | inf_data, 147 | filter_type='biencoder', 148 | # parameters={'n_gram_size': 3, 'overlap': 0.05}, 149 | parameters={'threshold': 0.94, 'model': biencoder_model} # 0.92 is best with cross-encoder, 0.94 for biencoder alone 150 | ) 151 | 152 | # Run cross-encoder 153 | if cross_encoder: 154 | pred_ids = crossencoder_inference( 155 | filtered_data, 156 | pred_ids, 157 | trained_model_path=cross_encoder_model, 158 | batch_size=1024, 159 | ) 160 | 161 | # Community detection 162 | cd_edges = utils.detect_communities_nx(pred_ids, resolution=1) 163 | 164 | # Evaluate 165 | utils.evaluate(cd_edges, gt_edge_path=gt_path) 166 | utils.cluster_eval(cd_edges, gt_edges=gt_path, all_ids=list(inf_data.keys())) 167 | 168 | 169 | if __name__ == '__main__': 170 | 171 | evaluate( 172 | gt_path='', 173 | inf_data_path='', 174 | biencoder_model='', 175 | cross_encoder_model='', 176 | ) 177 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: wc_comp 2 | channels: 3 | - nvidia 4 | - pytorch 5 | - rapidsai 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=2_gnu 11 | - arrow-cpp=9.0.0=py38he270906_2_cpu 12 | - aws-c-cal=0.5.11=h95a6274_0 13 | - aws-c-common=0.6.2=h7f98852_0 14 | - aws-c-event-stream=0.2.7=h3541f99_13 15 | - aws-c-io=0.10.5=hfb6a706_0 16 | - aws-checksums=0.1.11=ha31a3da_7 17 | - aws-sdk-cpp=1.8.186=hecaee15_4 18 | - bokeh=3.0.1=pyhd8ed1ab_0 19 | - brotlipy=0.7.0=py38h0a891b7_1005 20 | - bzip2=1.0.8=h7f98852_4 21 | - c-ares=1.18.1=h7f98852_0 22 | - ca-certificates=2022.10.11=h06a4308_0 23 | - cachetools=5.2.0=pyhd8ed1ab_0 24 | - certifi=2022.9.24=py38h06a4308_0 25 | - cffi=1.15.1=py38h4a40e3a_2 26 | - click=8.1.3=py38h578d9bd_1 27 | - cloudpickle=2.2.0=pyhd8ed1ab_0 28 | - contourpy=1.0.6=py38h43d8883_0 29 | - cryptography=38.0.3=py38h2b5fc30_0 30 | - cubinlinker=0.2.0=py38h7144610_1 31 | - cuda-python=11.7.1=py38h7525318_1 32 | - cudatoolkit=11.3.1=h9edb442_10 33 | - cudf=22.10.01=cuda_11_py38_gca9a422da9_2 34 | - cupy=11.2.0=py38h405e1b6_0 35 | - cytoolz=0.12.0=py38h0a891b7_1 36 | - dask=2022.9.2=pyhd8ed1ab_0 37 | - dask-core=2022.9.2=pyhd8ed1ab_0 38 | - dask-cuda=22.10.00=py38_g382e519_0 39 | - dask-cudf=22.10.01=cuda_11_py38_gca9a422da9_2 40 | - distributed=2022.9.2=pyhd8ed1ab_0 41 | - dlpack=0.5=h9c3ff4c_0 42 | - 
faiss-gpu=1.7.3=py3.8_h28a55e0_0_cuda11.3 43 | - fastavro=1.7.0=py38h0a891b7_0 44 | - fastrlock=0.8=py38hfa26641_3 45 | - freetype=2.12.1=hca18f0e_0 46 | - fsspec=2022.10.0=pyhd8ed1ab_0 47 | - gflags=2.2.2=he1b5a44_1004 48 | - glog=0.6.0=h6f12383_0 49 | - grpc-cpp=1.47.1=hbad87ad_6 50 | - heapdict=1.0.1=py_0 51 | - idna=3.4=pyhd8ed1ab_0 52 | - importlib-metadata=5.0.0=pyha770c72_1 53 | - intel-openmp=2022.1.0=h9e868ea_3769 54 | - jinja2=3.1.2=pyhd8ed1ab_1 55 | - joblib=1.2.0=pyhd8ed1ab_0 56 | - jpeg=9e=h166bdaf_2 57 | - keyutils=1.6.1=h166bdaf_0 58 | - krb5=1.19.3=h3790be6_0 59 | - lcms2=2.14=h6ed2654_0 60 | - ld_impl_linux-64=2.39=hc81fddc_0 61 | - lerc=4.0.0=h27087fc_0 62 | - libabseil=20220623.0=cxx17_h48a1fff_5 63 | - libblas=3.9.0=16_linux64_openblas 64 | - libbrotlicommon=1.0.9=h166bdaf_8 65 | - libbrotlidec=1.0.9=h166bdaf_8 66 | - libbrotlienc=1.0.9=h166bdaf_8 67 | - libcblas=3.9.0=16_linux64_openblas 68 | - libcrc32c=1.1.2=h9c3ff4c_0 69 | - libcudf=22.10.01=cuda11_gca9a422da9_2 70 | - libcugraph=22.10.01=cuda11_g9cf07eb6_0 71 | - libcugraphops=22.10.00=cuda11_g4257014_0 72 | - libcurl=7.86.0=h7bff187_1 73 | - libcusolver=11.4.1.48=0 74 | - libcusparse=11.7.5.86=0 75 | - libdeflate=1.14=h166bdaf_0 76 | - libedit=3.1.20191231=he28a2e2_2 77 | - libev=4.33=h516909a_1 78 | - libevent=2.1.10=h9b69904_4 79 | - libfaiss=1.7.3=hfc2d529_0_cuda11.3 80 | - libffi=3.4.2=h7f98852_5 81 | - libgcc-ng=12.2.0=h65d4601_19 82 | - libgfortran-ng=12.2.0=h69a702a_19 83 | - libgfortran5=12.2.0=h337968e_19 84 | - libgomp=12.2.0=h65d4601_19 85 | - libgoogle-cloud=2.1.0=h9ebe8e8_2 86 | - liblapack=3.9.0=16_linux64_openblas 87 | - libllvm11=11.1.0=he0ac6c6_5 88 | - libnghttp2=1.47.0=hdcd2b5c_1 89 | - libnsl=2.0.0=h7f98852_0 90 | - libopenblas=0.3.21=pthreads_h78a6416_3 91 | - libpng=1.6.38=h753d276_0 92 | - libprotobuf=3.20.1=h6239696_4 93 | - libraft-distance=22.10.01=cuda11_gf7d2335_0 94 | - libraft-headers=22.10.01=cuda11_gf7d2335_0 95 | - librmm=22.10.01=cuda11_gd98b8719_0 96 | - libsqlite=3.39.4=h753d276_0 97 | - libssh2=1.10.0=haa6b8db_3 98 | - libstdcxx-ng=12.2.0=h46fd767_19 99 | - libthrift=0.16.0=h491838f_2 100 | - libtiff=4.4.0=h55922b4_4 101 | - libutf8proc=2.8.0=h166bdaf_0 102 | - libuuid=2.32.1=h7f98852_1000 103 | - libwebp-base=1.2.4=h166bdaf_0 104 | - libxcb=1.13=h7f98852_1004 105 | - libzlib=1.2.13=h166bdaf_4 106 | - llvmlite=0.39.1=py38h38d86a4_1 107 | - locket=1.0.0=pyhd8ed1ab_0 108 | - lz4=4.0.2=py38h1bf946c_0 109 | - lz4-c=1.9.3=h9c3ff4c_1 110 | - markupsafe=2.1.1=py38h0a891b7_2 111 | - mkl=2022.1.0=hc2b9512_224 112 | - msgpack-python=1.0.4=py38h43d8883_1 113 | - nccl=2.14.3.1=h0800d71_0 114 | - ncurses=6.3=h27087fc_1 115 | - numba=0.56.3=py38h9a4aae9_0 116 | - numpy=1.23.4=py38h7042d01_1 117 | - nvtx=0.2.3=py38h0a891b7_2 118 | - openjpeg=2.5.0=h7d73246_1 119 | - openssl=1.1.1s=h7f8727e_0 120 | - orc=1.7.6=h6c59b99_0 121 | - packaging=21.3=pyhd8ed1ab_0 122 | - pandas=1.5.1=py38h8f669ce_1 123 | - parquet-cpp=1.5.1=2 124 | - partd=1.3.0=pyhd8ed1ab_0 125 | - pillow=9.2.0=py38h9eb91d8_3 126 | - pip=22.3.1=pyhd8ed1ab_0 127 | - protobuf=3.20.1=py38hfa26641_0 128 | - psutil=5.9.4=py38h0a891b7_0 129 | - pthread-stubs=0.4=h36c2ea0_1001 130 | - ptxcompiler=0.7.0=py38h7525318_2 131 | - pyarrow=9.0.0=py38h097c49a_2_cpu 132 | - pycparser=2.21=pyhd8ed1ab_0 133 | - pylibcugraph=22.10.01=cuda11_py38_g9cf07eb6_0 134 | - pylibraft=22.10.01=cuda11_py38_gf7d2335_0 135 | - pynvml=11.4.1=pyhd8ed1ab_0 136 | - pyopenssl=22.1.0=pyhd8ed1ab_0 137 | - pyparsing=3.0.9=pyhd8ed1ab_0 138 | - pysocks=1.7.1=py38h578d9bd_5 
139 | - python=3.8.13=h582c2e5_0_cpython 140 | - python-dateutil=2.8.2=pyhd8ed1ab_0 141 | - python_abi=3.8=2_cp38 142 | - pytz=2022.6=pyhd8ed1ab_0 143 | - pyyaml=6.0=py38h0a891b7_5 144 | - raft-dask=22.10.01=cuda11_py38_gf7d2335_0 145 | - re2=2022.06.01=h27087fc_0 146 | - readline=8.1.2=h0f457ee_0 147 | - rmm=22.10.01=cuda11_py38_gd98b8719_0 148 | - s2n=1.0.10=h9b69904_0 149 | - setuptools=65.5.1=pyhd8ed1ab_0 150 | - six=1.16.0=pyh6c4a22f_0 151 | - snappy=1.1.9=hbd366e4_2 152 | - sortedcontainers=2.4.0=pyhd8ed1ab_0 153 | - spdlog=1.8.5=h4bd325d_1 154 | - sqlite=3.39.4=h4ff8645_0 155 | - tblib=1.7.0=pyhd8ed1ab_0 156 | - tk=8.6.12=h27826a3_0 157 | - toolz=0.12.0=pyhd8ed1ab_0 158 | - tornado=6.1=py38h0a891b7_3 159 | - typing_extensions=4.4.0=pyha770c72_0 160 | - ucx=1.13.1=h538f049_0 161 | - ucx-proc=1.0.0=gpu 162 | - ucx-py=0.28.00=py38_g8292636_0 163 | - urllib3=1.26.11=pyhd8ed1ab_0 164 | - wheel=0.38.3=pyhd8ed1ab_0 165 | - xorg-libxau=1.0.9=h7f98852_0 166 | - xorg-libxdmcp=1.1.3=h7f98852_0 167 | - xyzservices=2022.9.0=pyhd8ed1ab_0 168 | - xz=5.2.6=h166bdaf_0 169 | - yaml=0.2.5=h7f98852_2 170 | - zict=2.2.0=pyhd8ed1ab_0 171 | - zipp=3.10.0=pyhd8ed1ab_0 172 | - zlib=1.2.13=h166bdaf_4 173 | - zstd=1.5.2=h6239696_4 174 | - pip: 175 | - aiohttp==3.8.3 176 | - aiosignal==1.3.1 177 | - async-timeout==4.0.2 178 | - attrs==22.1.0 179 | - charset-normalizer==2.1.1 180 | - cugraph==22.10.1+0.g9cf07eb6.dirty 181 | - cython==0.29.32 182 | - datasets==2.6.1 183 | - datasketch==1.5.8 184 | - dill==0.3.5.1 185 | - docker-pycreds==0.4.0 186 | - editdistpy==0.1.3 187 | - frozenlist==1.3.3 188 | - gitdb==4.0.9 189 | - gitpython==3.1.29 190 | - hdbscan==0.8.29 191 | - markov-clustering==0.0.6.dev0 192 | - multidict==6.0.2 193 | - multiprocess==0.70.13 194 | - networkx==2.8.8 195 | - pathtools==0.1.2 196 | - promise==2.3 197 | - requests==2.28.1 198 | - responses==0.18.0 199 | - scipy==1.9.3 200 | - sentence-transformers==2.2.2 201 | - sentry-sdk==1.11.0 202 | - setproctitle==1.3.2 203 | - shortuuid==1.0.11 204 | - smmap==5.0.0 205 | - symspellpy==6.7.7 206 | - wandb==0.13.5 207 | - xxhash==3.1.0 208 | - yarl==1.8.1 209 | prefix: /home/silcock/.conda/envs/wc_comp 210 | 211 | -------------------------------------------------------------------------------- /inference_at_scale/biencoder_at_scale.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python 3.8 3 | Runs biencoder inference over C4 at scale. 4 | """ 5 | 6 | import pickle 7 | from tqdm import tqdm 8 | from glob import glob 9 | import os 10 | import sys 11 | 12 | from datetime import datetime 13 | import numpy as np 14 | import math 15 | 16 | import faiss 17 | from sentence_transformers import SentenceTransformer 18 | 19 | currentdir = os.path.dirname(os.path.realpath(__file__)) 20 | parentdir = os.path.dirname(currentdir) 21 | grandparentdir = os.path.dirname(parentdir) 22 | sys.path.append(parentdir) 23 | sys.path.append(grandparentdir) 24 | 25 | import utils 26 | import data_fns 27 | 28 | def embed( 29 | corpus, 30 | trained_model, 31 | save_stem, 32 | batch_size=64, 33 | ): 34 | 35 | """ 36 | Embeds a list of text (corpus) using supplied biencoder model. Saves embeddings in chunks of 1M . 
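    Each chunk is L2-normalised to unit length and pickled to '{save_stem}_{i}.pkl'; the
    concatenated embedding matrix is also returned.

    Illustrative call (paths are placeholders), matching how run_inference() uses it:
        embeddings = embed(corpus, trained_model='<biencoder-path>',
                           save_stem='<working-dir>/embeddings/embeddings', batch_size=1024)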
37 | """ 38 | 39 | # Initialise model 40 | embedder = SentenceTransformer(trained_model) 41 | 42 | # Embed data 43 | print("Embedding corpus ...") 44 | all_embeddings = [] 45 | 46 | chunk_size = 1000000 47 | nchunks = math.ceil(len(corpus)/chunk_size) 48 | 49 | for i in range(nchunks): 50 | 51 | print(f"Chunk {i}/{nchunks}") 52 | 53 | corpus_embeddings = embedder.encode(corpus[(chunk_size*i):chunk_size*(i+1)], show_progress_bar=True, batch_size=batch_size) 54 | 55 | # Normalize the embeddings to unit length 56 | corpus_embeddings = corpus_embeddings / np.linalg.norm(corpus_embeddings, axis=1, keepdims=True) 57 | 58 | with open(f'{save_stem}_{i}.pkl', 'wb') as f: 59 | pickle.dump(corpus_embeddings, f) 60 | 61 | all_embeddings.append(corpus_embeddings) 62 | 63 | corpus_embeddings = np.concatenate(all_embeddings, axis=0) 64 | print(len(corpus_embeddings), "embeddings in corpus") 65 | 66 | return corpus_embeddings 67 | 68 | 69 | def find_nearest_neighbours(embedding_list, save_stem, k=5, d=768, query_batch_size=1000, corpus_batch_size=10000000, normalize=False): 70 | 71 | """ 72 | Takes list of embeddings and compares each embedding to each other, retuning k nearest neighbours, per corpus batch 73 | Nearest neighbours and distances are saved. 74 | Batches approach https://davidefiocco.github.io/nearest-neighbor-search-with-faiss/ 75 | """ 76 | 77 | if normalize: 78 | faiss.normalize_L2(embedding_list) 79 | 80 | start_time = datetime.now() 81 | 82 | # Initialise faiss 83 | res = faiss.StandardGpuResources() 84 | 85 | n_query_batches = math.ceil(embedding_list.shape[0] / query_batch_size) 86 | n_corpus_batches = math.ceil(embedding_list.shape[0] / corpus_batch_size) 87 | print("Total batches:", n_query_batches * n_corpus_batches) 88 | 89 | # Batch over corpus 90 | for j in range(n_corpus_batches): 91 | 92 | print(f"***** Corpus batch {j} *****") 93 | 94 | gpu_index_flat = faiss.GpuIndexFlatIP(res, d) 95 | gpu_index_flat.add(embedding_list[(corpus_batch_size*j):corpus_batch_size*(j+1)]) 96 | 97 | for i in range(n_query_batches): 98 | print(f"\n Corpus batch {j}, query batch {i}: {(i*j) +i}/{n_query_batches * n_corpus_batches}") 99 | start_batch = datetime.now() 100 | D, I = gpu_index_flat.search(embedding_list[(query_batch_size*i):query_batch_size*(i+1)], k) 101 | end_batch = datetime.now() 102 | 103 | print("Time taken by batch was ", end_batch-start_batch) 104 | print("Saving intermediate results...") 105 | 106 | # Adjust IDs for not starting at 0 107 | I_adj = I + j*corpus_batch_size 108 | 109 | # Save intermediate results for this batch 110 | with open(f"{save_stem}/nn_list_batch_{i}_{j}.pkl", "wb") as f: 111 | pickle.dump(I_adj, f, protocol=4) 112 | with open(f"{save_stem}/dist_list_batch_{i}_{j}.pkl", "wb") as f: 113 | pickle.dump(D, f, protocol=4) 114 | 115 | gpu_index_flat.reset() 116 | 117 | end_time = datetime.now() 118 | print("total time elapsed: ", (end_time-start_time)) 119 | 120 | 121 | def subset_nn_data(saved_stem, threshold): 122 | 123 | """ 124 | Reloads nearest neighbours, subsets to those below a distance threshold and compiles across multiple batches. 
125 | """ 126 | 127 | print("Reloading nearest neighbours ...") 128 | 129 | nn_files = glob(f"{saved_stem}/dist_list_batch*.pkl") 130 | 131 | i_list = set([int(file.split("_")[-2]) for file in nn_files]) 132 | j_list = set([int(file.split("_")[-1].split(".")[0]) for file in nn_files]) 133 | i_list = sorted(i_list) 134 | j_list = sorted(j_list) 135 | 136 | under_th = [] 137 | 138 | for i in tqdm(i_list): # For all batches of queries 139 | 140 | dist_list = [] 141 | nn_list = [] 142 | 143 | for j in j_list: # Grab results for all batches of the corpus 144 | with open(f"{saved_stem}/dist_list_batch_{i}_{j}.pkl", 'rb') as f: 145 | dist_list.append(pickle.load(f)) 146 | with open(f"{saved_stem}/nn_list_batch_{i}_{j}.pkl", 'rb') as f: 147 | nn_list.append(pickle.load(f)) 148 | 149 | dist_list = np.concatenate(dist_list, axis=1) 150 | nn_list = np.concatenate(nn_list, axis=1) 151 | 152 | under_th.extend([nn_list[i][(dist_list[i] >= threshold)] for i in range(len(nn_list))]) 153 | 154 | print(len(under_th)) 155 | 156 | return under_th 157 | 158 | 159 | def nearest_neighbours_to_pairs(final_under_thresh, save_dir): 160 | 161 | """ 162 | Reformat nearest neighbour list as pairs 163 | """ 164 | 165 | print("Sanity check ...") 166 | for i in tqdm(range(len(final_under_thresh))): 167 | assert i in final_under_thresh[i] 168 | 169 | print("Creating pairs ...") 170 | edge_list = [(i, j) for i in range(len(final_under_thresh)) for j in final_under_thresh[i] if j != i] 171 | 172 | # Remove edges that are in twice 173 | print("Removing duplicate edges ...") 174 | edge_list = list({*map(tuple, map(sorted, edge_list))}) 175 | 176 | print("Total edges:", len(edge_list)) 177 | 178 | # Save 179 | with open(f'{save_dir}/nn_edge_list.pkl', 'wb') as f: 180 | pickle.dump(edge_list, f, protocol=4) 181 | 182 | return edge_list 183 | 184 | 185 | def run_inference(corpus, working_directory, biencoder_model, biencoder_threshold): 186 | 187 | """ 188 | From a corpus, finds all pairs that are under a certain distance apart. Forms clusters, either suing connected 189 | components or community detection, and then imposes transitivity over these clusters. 190 | Save cudf of edge ID pairs. 
191 | """ 192 | 193 | os.makedirs(working_directory, exist_ok=True) 194 | os.makedirs(f'{working_directory}/embeddings', exist_ok=True) 195 | os.makedirs(f'{working_directory}/all_nearest_neighbours', exist_ok=True) 196 | os.makedirs(f'{working_directory}/edges/', exist_ok=True) 197 | 198 | # Create embeddings 199 | corpus_embeddings = embed( 200 | corpus, 201 | trained_model=biencoder_model, 202 | save_stem=f'{working_directory}/embeddings/embeddings', 203 | batch_size=1024 204 | ) 205 | 206 | # Return list of nearest neighbours and their distances 207 | find_nearest_neighbours( 208 | corpus_embeddings, 209 | save_stem=f'{working_directory}/all_nearest_neighbours/', 210 | k=900, 211 | d=768, 212 | query_batch_size=150000, 213 | corpus_batch_size=10000000 214 | ) 215 | 216 | # Reload and subset nearest neighbour data, to pairs under given threshold 217 | under_th = subset_nn_data(f'{working_directory}/all_nearest_neighbours/', threshold=biencoder_threshold) 218 | 219 | nn_edges = nearest_neighbours_to_pairs(under_th, save_dir=f'{working_directory}/edges/') 220 | 221 | # Put on graph 222 | G = utils.cnx_make_graph_from_edges(nn_edges) 223 | 224 | # Pull all edges from connected components 225 | # utils.gpu_connected_components(G, save_file='f'{working_directory}/edges/connected_comps_edges.pkl') 226 | 227 | # Pull all edges after running community detection 228 | utils.gpu_connected_components(G, save_file=f'{working_directory}/edges/community_edges.pkl', detect_communities=True) 229 | 230 | 231 | if __name__ == '__main__': 232 | 233 | # C4 data 234 | # corpus = data_fns.open_realnews() 235 | corpus = data_fns.open_c4_by_url(pattern="patents.google.com", name="patents") 236 | 237 | run_inference( 238 | corpus, 239 | working_directory="", 240 | biencoder_model='', 241 | biencoder_threshold=0.94 242 | ) 243 | -------------------------------------------------------------------------------- /rule_based/rule_based_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from tqdm import tqdm 4 | import re 5 | import os 6 | from os.path import dirname as up 7 | import sys 8 | from transformers import BertTokenizerFast 9 | import logging 10 | 11 | from symspellpy import SymSpell, Verbosity 12 | 13 | currentdir = os.path.dirname(os.path.realpath(__file__)) 14 | parentdir = os.path.dirname(currentdir) 15 | grandparentdir = os.path.dirname(parentdir) 16 | sys.path.append(parentdir) 17 | sys.path.append(grandparentdir) 18 | 19 | import utils 20 | 21 | 22 | def set_global_logging_level(level=logging.ERROR, prefices=[""]): 23 | """ 24 | Override logging levels of different modules based on their name as a prefix. 25 | It needs to be invoked after the modules have been loaded so that their loggers have been initialized. 26 | """ 27 | prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })') 28 | for name in logging.root.manager.loggerDict: 29 | if re.match(prefix_re, name): 30 | logging.getLogger(name).setLevel(level) 31 | 32 | 33 | def spellcheck(list_of_texts, spell_check_type): 34 | ''' Runs spell-checker over the list of texts. 
''' 35 | 36 | if spell_check_type == "symspell": 37 | spell_checked_texts = symspell_check_ocr(list_of_texts) 38 | if spell_check_type == "fixed": 39 | spell_checked_texts = fixed_dict(list_of_texts) 40 | if spell_check_type is None: 41 | return list_of_texts 42 | 43 | return spell_checked_texts 44 | 45 | 46 | def symspell_setup(resource_dir="", edit_distance=2): 47 | 48 | sym_spell = SymSpell(max_dictionary_edit_distance=edit_distance, prefix_length=7) 49 | 50 | dictionary_path = os.path.join(resource_dir, "frequency_dictionary_en_82_765.txt") 51 | bigram_path = os.path.join(resource_dir, "frequency_bigramdictionary_en_243_342.txt") 52 | 53 | print("Dictionary Path:", dictionary_path) 54 | sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1) 55 | sym_spell.load_bigram_dictionary(bigram_path, term_index=0, count_index=2) 56 | 57 | return sym_spell 58 | 59 | 60 | def fixed_dict(ocr_article_clean_texts): 61 | '''Very flexible spell checker.''' 62 | 63 | sym_spell = symspell_setup(edit_distance=5) 64 | 65 | ocr_spell_texts = [] 66 | 67 | print("\n Spell checking ...") 68 | for text in tqdm(ocr_article_clean_texts): 69 | spell_corr = [] 70 | for input_term in text.split(): 71 | suggestions = sym_spell.lookup(input_term, Verbosity.TOP, max_edit_distance=None, include_unknown=True, 72 | transfer_casing=True) 73 | spell_corr.append(suggestions[0].term) 74 | ocr_spell_texts.append(" ".join(spell_corr)) 75 | 76 | return ocr_spell_texts 77 | 78 | 79 | def symspell_check_ocr(ocr_article_clean_texts): 80 | """Corrects spelling of OCR article texts""" 81 | 82 | sym_spell = symspell_setup() 83 | 84 | ocr_spell_texts = [] 85 | 86 | print("\n Spell checking ...") 87 | for text in tqdm(ocr_article_clean_texts): 88 | spell_corr = [] 89 | for input_term in text.split(): 90 | suggestions = sym_spell.lookup(input_term, Verbosity.CLOSEST, max_edit_distance=2, include_unknown=True, 91 | transfer_casing=True) 92 | spell_corr.append(suggestions[0].term) 93 | ocr_spell_texts.append(" ".join(spell_corr)) 94 | 95 | return ocr_spell_texts 96 | 97 | 98 | def remove_odd_characters(list_of_texts): 99 | ''' Removes punctuation, unknown characters. ''' 100 | chars_to_remove = r'"#$%&\()*+/:;<=>@[\\]^_`{|}~.?,!\'' 101 | ocr_article_clean_texts = [] 102 | 103 | for text in list_of_texts: 104 | text = text.replace("-\n", "").replace("\n", " ") 105 | text = text.translate(str.maketrans('', '', chars_to_remove)) 106 | text = text.encode('ascii', 'ignore').decode() 107 | ocr_article_clean_texts.append(text) 108 | 109 | return ocr_article_clean_texts 110 | 111 | 112 | def clean_text(corpus_dict, first_n_tok=None, min_tok=None, spell_check="symspell"): 113 | ''' Cleans texts by removing punctuation, optionally spell-checking. 
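    Each text is the byline concatenated with the article body. `first_n_tok` truncates a
    text to its first n wordpiece tokens, `min_tok` (when set alongside `first_n_tok`) keeps
    only texts longer than min_tok tokens, and `spell_check` names the checker passed to
    spellcheck().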
''' 114 | 115 | cleaned_ids = [] 116 | org_texts = [] 117 | 118 | # instantiate tokenizer 119 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 120 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased") 121 | set_global_logging_level(logging.ERROR, ["transformers", "BertTokenizerFast"]) 122 | 123 | for key in list(corpus_dict.keys()): 124 | text = corpus_dict[key]['byline'] + " " + corpus_dict[key]['article'] 125 | 126 | if first_n_tok is not None: 127 | tokens = tokenizer.encode(text, truncation=False) 128 | text = tokenizer.decode(tokens[1:first_n_tok]) 129 | if min_tok is not None: 130 | if len(tokens) > min_tok: 131 | cleaned_ids.append(key) 132 | org_texts.append(text) 133 | else: 134 | cleaned_ids.append(key) 135 | org_texts.append(text) 136 | 137 | cleaned_texts = remove_odd_characters(org_texts) 138 | 139 | if spell_check: 140 | cleaned_texts = spellcheck(cleaned_texts, spell_check_type=spell_check) 141 | 142 | return cleaned_ids, cleaned_texts 143 | 144 | 145 | def gather_data(data_file_path, min_tok=None, n_tok=None, spell_check=None): 146 | ''' Gathers article data. ''' 147 | 148 | with open(data_file_path) as f: 149 | corpus_dict = json.load(f) 150 | 151 | cleaned_id_list, cleaned_text_list = clean_text( 152 | corpus_dict, 153 | first_n_tok=n_tok, 154 | min_tok=min_tok, 155 | spell_check=spell_check 156 | ) 157 | 158 | return cleaned_text_list, cleaned_id_list 159 | 160 | 161 | def list_ngrams(list_of_texts, n_gram_size=5, concat=False, char=False): 162 | ''' Returns list of n-grams given list of texts. ''' 163 | 164 | # Create list of all n-grams in all passages 165 | ngram_sets = [] 166 | for passage in list_of_texts: 167 | # Creates character-based n-grams 168 | if char: 169 | words = passage.split() 170 | passage = " ".join(words) 171 | n_grams = list(zip(*[passage[i:] for i in range(n_gram_size)])) 172 | # Creates word-based n-grams 173 | else: 174 | words = passage.split() 175 | n_grams = list(zip(*[words[i:] for i in range(n_gram_size)])) 176 | 177 | # concatenates n-grams instead of leaving them as tuples 178 | if concat: 179 | n_grams = [" ".join(x) for x in n_grams] 180 | 181 | ngram_sets.append(set(n_grams)) 182 | 183 | return ngram_sets 184 | 185 | 186 | def get_ngrams(cleaned_text_list, cleaned_id_list, n_gram_size, concat=False, char=False): 187 | ''' Returns formatted dictionary of n-grams. ''' 188 | 189 | # Create list of n-grams for each article 190 | n_gram_list = list_ngrams(cleaned_text_list, n_gram_size, concat=concat, char=char) 191 | 192 | # Split into dictionary item per day 193 | date_list = [] 194 | for art_id in cleaned_id_list: 195 | date_list.append("-".join(art_id.split("-")[-5:-2])) 196 | unique_date_list = list(set(date_list)) 197 | 198 | data_dict = {} 199 | for date in unique_date_list: 200 | indices = [i for i, x in enumerate(date_list) if x == date] 201 | data_dict[date] = { 202 | "id_list": [cleaned_id_list[i] for i in indices], 203 | "art_list": [cleaned_text_list[i] for i in indices], 204 | "ngram_list": [n_gram_list[i] for i in indices] 205 | } 206 | 207 | return data_dict 208 | 209 | 210 | def get_eval_metrics(edges, gt_path, ids, community_detection=False): 211 | ''' Gets evaluation metrics after clustering. 
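    Returns three dicts: edge_metrics (precision/recall/F1 on the raw predicted edges),
    cluster_metrics (the same metrics after optional community detection and imposing
    transitivity), and full_metrics (RI, ARI, NMI and AMI over the induced clustering).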
''' 212 | 213 | # Store different metrics for clustering 214 | edge_metrics = {} 215 | cluster_metrics = {} 216 | full_metrics = {} 217 | 218 | # Perform edge-level evaluation 219 | metrics = utils.evaluate(edges, gt_edge_path=gt_path) 220 | for value in metrics: 221 | edge_metrics[value] = metrics[value] 222 | 223 | # Optionally perform community detection 224 | if community_detection: 225 | edges = utils.detect_communities_nx(edges) 226 | 227 | # Impose transitivty for edges 228 | cluster_dict = utils.clusters_from_edges(edges) 229 | edges = utils.edges_from_clusters(cluster_dict) 230 | 231 | # Perform cluster-level evaluation (Recall, Precision, F1) 232 | metrics = utils.evaluate(edges, gt_edge_path=gt_path) 233 | for value in metrics: 234 | cluster_metrics[value] = metrics[value] 235 | 236 | # Perform cluster-level evaluation (RI, ARI, NMI, AMI) 237 | metrics = utils.cluster_eval(pred_edges=edges, gt_edges=gt_path, all_ids=ids) 238 | for value in metrics: 239 | full_metrics[value] = metrics[value] 240 | 241 | return edge_metrics, cluster_metrics, full_metrics 242 | -------------------------------------------------------------------------------- /neural/modified_sbert/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import DataLoader 7 | from sentence_transformers import models, losses, datasets, evaluation, LoggingHandler, SentenceTransformer, util 8 | from sentence_transformers.datasets import SentenceLabelDataset 9 | from sentence_transformers.cross_encoder import CrossEncoder 10 | from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator 11 | 12 | import logging 13 | from transformers import logging as lg 14 | 15 | from modified_sbert import data_loaders, clu_evaluators 16 | 17 | 18 | lg.set_verbosity_error() 19 | logging.basicConfig(format='%(asctime)s - %(message)s', 20 | datefmt='%Y-%m-%d %H:%M:%S', 21 | level=logging.INFO, 22 | handlers=[LoggingHandler()]) 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class SupConLoss(nn.Module): 27 | """ 28 | Source: https://github.com/UKPLab/sentence-transformers/issues/1604 29 | Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 30 | """ 31 | 32 | def __init__(self, model, temperature=0.07, contrast_mode='all', 33 | base_temperature=0.07): 34 | super(SupConLoss, self).__init__() 35 | self.model = model 36 | self.temperature = temperature 37 | self.contrast_mode = contrast_mode 38 | self.base_temperature = base_temperature 39 | 40 | def forward(self, sentence_features, labels=None, mask=None): 41 | """Compute loss for model. If both `labels` and `mask` are None, 42 | it degenerates to SimCLR unsupervised loss: 43 | https://arxiv.org/pdf/2002.05709.pdf 44 | Args: 45 | features: hidden vector of shape [bsz, n_views, ...]. 46 | labels: ground truth of shape [bsz]. 47 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 48 | has the same class as sample i. Can be asymmetric. 49 | Returns: 50 | A loss scalar. 
51 | """ 52 | features = self.model(sentence_features[0])['sentence_embedding'] 53 | 54 | #Nils: Normalize embeddings 55 | features = torch.nn.functional.normalize(features, p=2, dim=1) 56 | 57 | ## Nils: Add n_views dimension 58 | features = torch.unsqueeze(features, 1) 59 | 60 | device = features.device 61 | 62 | if len(features.shape) < 3: 63 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 64 | 'at least 3 dimensions are required') 65 | if len(features.shape) > 3: 66 | features = features.view(features.shape[0], features.shape[1], -1) 67 | 68 | batch_size = features.shape[0] 69 | if labels is not None and mask is not None: 70 | raise ValueError('Cannot define both `labels` and `mask`') 71 | elif labels is None and mask is None: 72 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 73 | elif labels is not None: 74 | labels = labels.contiguous().view(-1, 1) 75 | if labels.shape[0] != batch_size: 76 | raise ValueError('Num of labels does not match num of features') 77 | mask = torch.eq(labels, labels.T).float().to(device) 78 | else: 79 | mask = mask.float().to(device) 80 | 81 | contrast_count = features.shape[1] 82 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 83 | if self.contrast_mode == 'one': 84 | anchor_feature = features[:, 0] 85 | anchor_count = 1 86 | elif self.contrast_mode == 'all': 87 | anchor_feature = contrast_feature 88 | anchor_count = contrast_count 89 | else: 90 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 91 | 92 | # compute logits 93 | anchor_dot_contrast = torch.div( 94 | torch.matmul(anchor_feature, contrast_feature.T), 95 | self.temperature) 96 | # for numerical stability 97 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 98 | logits = anchor_dot_contrast - logits_max.detach() 99 | 100 | # tile mask 101 | mask = mask.repeat(anchor_count, contrast_count) 102 | # mask-out self-contrast cases 103 | logits_mask = torch.scatter( 104 | torch.ones_like(mask), 105 | 1, 106 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 107 | 0 108 | ) 109 | mask = mask * logits_mask 110 | 111 | # compute log_prob 112 | exp_logits = torch.exp(logits) * logits_mask 113 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 114 | 115 | # compute mean of log-likelihood over positive 116 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 117 | 118 | # loss 119 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 120 | loss = loss.view(anchor_count, batch_size).mean() 121 | 122 | return loss 123 | 124 | 125 | def train_biencoder( 126 | train_data: dict = None, 127 | dev_data: dict = None, 128 | base_model='sentence-transformers/all-MiniLM-L12-v2', 129 | add_pooling_layer=False, 130 | train_batch_size=64, 131 | num_epochs=10, 132 | warmup_epochs=1, 133 | loss_fn='contrastive', 134 | loss_params=None, 135 | model_save_path="output", 136 | ): 137 | 138 | os.makedirs(model_save_path, exist_ok=True) 139 | 140 | # Base language model 141 | if add_pooling_layer: 142 | word_embedding_model = models.Transformer(base_model, max_seq_length=512) 143 | pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='mean') 144 | model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) 145 | else: 146 | model = SentenceTransformer(base_model) 147 | 148 | # Loss functions 149 | if loss_fn == "contrastive": 150 | train_loss = losses.OnlineContrastiveLoss( 151 | model=model, 152 | 
distance_metric=loss_params['distance_metric'], 153 | margin=loss_params['margin'] 154 | ) 155 | 156 | train_samples = data_loaders.load_data_as_pairs(train_data, type="neural") 157 | train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size) 158 | 159 | elif loss_fn == "cosine": 160 | train_loss = losses.CosineSimilarityLoss(model=model) 161 | 162 | train_samples = data_loaders.load_data_as_pairs(train_data, type="neural") 163 | train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size) 164 | 165 | elif loss_fn == "triplet": 166 | train_loss = losses.TripletLoss( 167 | model=model, 168 | distance_metric=loss_params['distance_metric'], 169 | triplet_margin=loss_params['margin'] 170 | ) 171 | 172 | train_samples = data_loaders.load_data_as_triplets(train_data, type="neural") 173 | train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size) 174 | 175 | elif loss_fn == "mnrl": 176 | train_loss = losses.MultipleNegativesRankingLoss(model=model) 177 | 178 | train_samples = data_loaders.load_data_as_triplets(train_data, type="neural") 179 | 180 | # Special dataloader that avoids duplicates within a batch 181 | train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=train_batch_size) 182 | 183 | elif loss_fn == "supcon": 184 | train_loss = SupConLoss(model=model)  # local SupConLoss defined above (not part of sentence_transformers.losses) 185 | 186 | train_samples = data_loaders.load_data_as_individuals(train_data, type="neural") 187 | 188 | # Special dataset "SentenceLabelDataset" to wrap our train_set 189 | # It yields batches that contain at least two samples with the same label 190 | train_data_sampler = SentenceLabelDataset(train_samples) 191 | train_dataloader = DataLoader(train_data_sampler, batch_size=train_batch_size) 192 | 193 | # Evaluate with multiple evaluators 194 | dev_pairs = data_loaders.load_data_as_pairs(dev_data, type="dev") 195 | # dev_triplets = data_loaders.load_data_as_triplets(dev_data, type="dev") 196 | 197 | evaluators = [ 198 | evaluation.BinaryClassificationEvaluator.from_input_examples(dev_pairs), 199 | # evaluation.EmbeddingSimilarityEvaluator.from_input_examples(dev_pairs), 200 | # evaluation.TripletEvaluator.from_input_examples(dev_triplets), 201 | clu_evaluators.ClusterEvaluator.from_input_examples(dev_pairs, cluster_type="agglomerative") 202 | ] 203 | 204 | seq_evaluator = evaluation.SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[-1]) 205 | 206 | logger.info("Evaluate model before training") 207 | seq_evaluator(model, epoch=0, steps=0, output_path=model_save_path) 208 | 209 | # Train the model 210 | model.fit( 211 | train_objectives=[(train_dataloader, train_loss)], 212 | evaluator=seq_evaluator, 213 | epochs=num_epochs, 214 | warmup_steps=math.ceil(len(train_dataloader) * warmup_epochs), 215 | output_path=model_save_path, 216 | evaluation_steps=112, 217 | checkpoint_save_steps=112, 218 | checkpoint_path=model_save_path, 219 | save_best_model=True, 220 | checkpoint_save_total_limit=10 221 | ) 222 | 223 | 224 | def train_crossencoder( 225 | train_data, 226 | dev_data, 227 | model_name, 228 | lr, 229 | train_batch_size, 230 | num_epochs, 231 | warm_up_perc, 232 | eval_per_epoch, 233 | model_save_path, 234 | ): 235 | 236 | model = CrossEncoder(model_name, num_labels=1) 237 | 238 | train = data_loaders.load_data_as_pairs(train_data, type="neural") 239 | dev = data_loaders.load_data_as_pairs(dev_data, type="dev") 240 | 241 | # Wrap train_samples, which is a list of InputExample, in a pytorch DataLoader 242
| train_dataloader = DataLoader(train, shuffle=True, batch_size=train_batch_size) 243 | 244 | # Evaluate with multiple evaluators 245 | evaluators = [ 246 | CEBinaryClassificationEvaluator.from_input_examples(dev, name='dev'), 247 | clu_evaluators.CEClusterEvaluator.from_input_examples(dev, name='dev'), 248 | ] 249 | 250 | seq_evaluator = evaluation.SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[0]) 251 | 252 | warmup_steps = math.ceil(len(train_dataloader) * num_epochs * warm_up_perc) 253 | logger.info("Warmup-steps: {}".format(warmup_steps)) 254 | 255 | # Train the model 256 | model.fit(train_dataloader=train_dataloader, 257 | evaluator=seq_evaluator, 258 | epochs=num_epochs, 259 | evaluation_steps=int(len(train_dataloader)*(1/eval_per_epoch)), 260 | loss_fct=torch.nn.BCEWithLogitsLoss(), 261 | optimizer_params={"lr": lr}, 262 | warmup_steps=warmup_steps, 263 | output_path=model_save_path) 264 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | from itertools import combinations 4 | from datetime import datetime 5 | 6 | from tqdm import tqdm 7 | import random 8 | 9 | from sklearn.metrics import average_precision_score, adjusted_mutual_info_score, rand_score, adjusted_rand_score, normalized_mutual_info_score 10 | from sklearn.cluster import AgglomerativeClustering, DBSCAN 11 | import hdbscan 12 | 13 | import networkx as nx 14 | import networkx.algorithms.community as nx_comm 15 | import cugraph as cnx 16 | import cudf as gd 17 | 18 | 19 | def cluster(cluster_type, cluster_params, corpus_embeddings, corpus_ids=None): 20 | 21 | """ 22 | Perform specified clustering method 23 | """ 24 | 25 | if cluster_type not in ["agglomerative", "HDBScan", "SLINK"]: 26 | raise ValueError('cluster_type must be "agglomerative", "HDBScan" or "SLINK"') 27 | if cluster_type == "agglomerative": 28 | if "threshold" not in cluster_params: 29 | raise ValueError('cluster_params must contain "threshold"') 30 | if "clustering linkage" not in cluster_params: 31 | raise ValueError('cluster_params must contain "clustering linkage"') 32 | if "metric" not in cluster_params: 33 | raise ValueError('cluster_params must contain "metric"') 34 | if cluster_type == "HDBScan": 35 | if "min cluster size" not in cluster_params: 36 | raise ValueError('cluster_params must contain "min cluster size"') 37 | if "min samples" not in cluster_params: 38 | raise ValueError('cluster_params must contain "min samples"') 39 | if cluster_type == "SLINK": 40 | if "min cluster size" not in cluster_params: 41 | raise ValueError('cluster_params must contain "min cluster size"') 42 | if "threshold" not in cluster_params: 43 | raise ValueError('cluster_params must contain "threshold"') 44 | if "metric" not in cluster_params: 45 | raise ValueError('cluster_params must contain "metric"') 46 | 47 | if cluster_type == "agglomerative": 48 | clustering_model = AgglomerativeClustering( 49 | n_clusters=None, 50 | distance_threshold=cluster_params["threshold"], 51 | linkage=cluster_params["clustering linkage"], 52 | affinity=cluster_params["metric"] 53 | ) 54 | 55 | if cluster_type == "SLINK": 56 | clustering_model = DBSCAN( 57 | eps=cluster_params["threshold"], 58 | min_samples=cluster_params["min cluster size"], 59 | metric=cluster_params["metric"] 60 | ) 61 | 62 | if cluster_type == "HDBScan": 63 | clustering_model =
hdbscan.HDBSCAN( 64 | min_cluster_size=cluster_params["min cluster size"], 65 | min_samples=cluster_params["min samples"], 66 | gen_min_span_tree=True 67 | ) 68 | 69 | clustering_model.fit(corpus_embeddings) 70 | cluster_assignment = clustering_model.labels_ 71 | 72 | clustered_ids = {} 73 | for sentence_id, cluster_id in enumerate(cluster_assignment): 74 | if int(cluster_id) not in clustered_ids: 75 | clustered_ids[int(cluster_id)] = [] 76 | 77 | if corpus_ids: 78 | clustered_ids[int(cluster_id)].append(corpus_ids[sentence_id]) 79 | else: 80 | clustered_ids[int(cluster_id)].append(sentence_id) 81 | 82 | # HDBScan and SLINK have a cluster (-1) where they put all the unassigned nodes 83 | if cluster_type in ("HDBScan", "SLINK") and -1 in clustered_ids: 84 | del clustered_ids[-1] 85 | 86 | return clustered_ids 87 | 88 | 89 | def clusters_from_edges(edges_list): 90 | """Identify clusters of passages given a list of edges""" 91 | 92 | # clusters via NetworkX 93 | G = nx.Graph() 94 | G.add_edges_from(edges_list) 95 | sub_graphs = [G.subgraph(c).copy() for c in nx.connected_components(G)] 96 | 97 | sub_graph_dict = {} 98 | for i in range(len(sub_graphs)): 99 | sub_graph_dict[i] = list(sub_graphs[i].nodes()) 100 | 101 | return sub_graph_dict 102 | 103 | 104 | def edges_from_clusters(cluster_dict): 105 | """ 106 | Convert every pair in a cluster into an edge 107 | """ 108 | cluster_edges = [] 109 | for cluster_id in list(cluster_dict.keys()): 110 | art_ids_list = cluster_dict[cluster_id] 111 | edge_list = [list(comb) for comb in combinations(art_ids_list, 2)] 112 | cluster_edges.extend(edge_list) 113 | 114 | return cluster_edges 115 | 116 | 117 | def evaluate(pred_edges, gt_edge_path=None, gt_edges=None, print_metrics=True, print_incorrects=False, two_way=True, save_incorrects=False): 118 | 119 | """ 120 | Return F1, recall and precision from set of predicted edges and gt set 121 | """ 122 | 123 | if not gt_edges and not gt_edge_path: 124 | raise ValueError("either gt_edge_path or gt_edges must be specified") 125 | 126 | # Prep ground truth 127 | if not gt_edges: 128 | with open(gt_edge_path) as f: 129 | gt_edges = json.load(f) 130 | 131 | set_gt = set(map(tuple, gt_edges)) 132 | 133 | # Prep preds 134 | pred_edges_list = [[edge[0], edge[1]] for edge in pred_edges] 135 | set_preds = set(map(tuple, pred_edges_list)) 136 | 137 | # Metrics 138 | if two_way: 139 | tps = len([i for i in set_gt if i in set_preds or (i[1], i[0]) in set_preds]) 140 | fps = len([i for i in set_preds if i not in set_gt and (i[1], i[0]) not in set_gt]) 141 | fns = len([i for i in set_gt if i not in set_preds and (i[1], i[0]) not in set_preds]) 142 | else: 143 | tps = len([i for i in set_gt if i in set_preds]) 144 | fps = len([i for i in set_preds if i not in set_gt]) 145 | fns = len([i for i in set_gt if i not in set_preds]) 146 | 147 | if tps + fps > 0: 148 | precision = tps / (tps + fps) 149 | else: 150 | precision = 0 151 | if tps + fns > 0: 152 | recall = tps / (tps + fns) 153 | else: 154 | recall = 0 155 | if precision + recall > 0: 156 | f_score = 2 * (precision * recall) / (precision + recall) 157 | else: 158 | f_score = 0 159 | 160 | metrics = {"precision": precision, "recall": recall, "f_score": f_score, "tps": tps, "fps": fps, "fns": fns} 161 | 162 | # Look at wrong ones 163 | if print_incorrects: 164 | fp_list = [i for i in set_preds if i not in set_gt] 165 | fn_list = [i for i in set_gt if i not in set_preds] 166 | 167 | print(fn_list) 168 | print(len(fn_list)) 169 | print(tps, fps, fns) 170 | 171 | if
print_metrics: 172 | print(metrics) 173 | 174 | if save_incorrects: 175 | fp_list = [i for i in set_preds if i not in set_gt and (i[1], i[0]) not in set_gt] 176 | fn_list = [i for i in set_gt if i not in set_preds and (i[1], i[0]) not in set_preds] 177 | 178 | print(tps, fps, fns) 179 | 180 | fp_list = random.sample(fp_list, 50) 181 | fn_list = random.sample(fn_list, 50) 182 | 183 | return fp_list, fn_list 184 | 185 | else: 186 | 187 | return metrics 188 | 189 | 190 | def cluster_eval(pred_edges, gt_edges, all_ids): 191 | 192 | """ 193 | Return RI, ARI, NMI, AMI, from set of predicted edges and gt set 194 | """ 195 | 196 | pred_clusters = clusters_from_edges(pred_edges) 197 | 198 | with open(gt_edges) as f: 199 | gt_edges = json.load(f) 200 | 201 | set_gt = set(map(tuple, gt_edges)) 202 | gt_clusters = clusters_from_edges(set_gt) 203 | 204 | # get dictionary mapping article to cluster number 205 | pred_dict = {} 206 | pred_count = 0 207 | for cluster in pred_clusters: 208 | for article in pred_clusters[cluster]: 209 | pred_dict[article] = pred_count 210 | pred_count += 1 211 | 212 | gt_dict = {} 213 | gt_count = 0 214 | for cluster in gt_clusters: 215 | for article in gt_clusters[cluster]: 216 | gt_dict[article] = gt_count 217 | gt_count += 1 218 | 219 | # fill in clusters with unclustered articles 220 | full_pred_clusters = [] 221 | full_gt_clusters = [] 222 | for article in all_ids: 223 | if article in pred_dict: 224 | full_pred_clusters.append(pred_dict[article]) 225 | else: 226 | full_pred_clusters.append(pred_count) 227 | pred_count += 1 228 | 229 | if article in gt_dict: 230 | full_gt_clusters.append(gt_dict[article]) 231 | else: 232 | full_gt_clusters.append(gt_count) 233 | gt_count += 1 234 | 235 | assert len(full_pred_clusters) == len(full_gt_clusters) 236 | 237 | RI = rand_score(full_pred_clusters, full_gt_clusters) 238 | ARI = adjusted_rand_score(full_pred_clusters, full_gt_clusters) 239 | NMI = normalized_mutual_info_score(full_pred_clusters, full_gt_clusters) 240 | AMI = adjusted_mutual_info_score(full_pred_clusters, full_gt_clusters) 241 | 242 | print({"RI": RI, "ARI": ARI, "NMI": NMI, "AMI": AMI}) 243 | 244 | return {"RI": RI, "ARI": ARI, "NMI": NMI, "AMI": AMI} 245 | 246 | 247 | def detect_communities_nx(edges, resolution=1): 248 | 249 | """Louvain community detection using nx""" 250 | 251 | G = nx.Graph() 252 | G.add_edges_from(edges) 253 | 254 | communities = nx_comm.louvain_communities(G, resolution=resolution) 255 | 256 | sub_graph_dict = {} 257 | for i in range(len(communities)): 258 | sub_graph_dict[i] = list(communities[i]) 259 | 260 | return edges_from_clusters(sub_graph_dict) 261 | 262 | 263 | def cnx_make_graph_from_edges(edge_list): 264 | 265 | """Make a graph from list of lists of neighbors""" 266 | 267 | time_graph_start = datetime.now() 268 | 269 | # Build edges into a gpu dataframe 270 | edge_df = gd.DataFrame({'src': gd.Series([i[0] for i in edge_list]), 'dst': gd.Series([i[1] for i in edge_list])}) 271 | 272 | # Make graph 273 | G = cnx.Graph() 274 | G.from_cudf_edgelist(edge_df, source='src', destination='dst') 275 | 276 | print("Number of nodes:", cnx.structure.graph_implementation.simpleGraphImpl.number_of_vertices(G)) 277 | print("Number of edges before imposing transistivty:", cnx.structure.graph_implementation.simpleGraphImpl.number_of_edges(G)) 278 | 279 | time_graph_end = datetime.now() 280 | print("Time taken to make graph: ", time_graph_end-time_graph_start) 281 | 282 | return G 283 | 284 | 285 | def gpu_connected_components(G, save_file, 
detect_communities=False): 286 | 287 | """ 288 | Impose transitivity and return edges, either with or without community detection 289 | """ 290 | 291 | time_cc_start = datetime.now() 292 | 293 | print("Imposing transitivity ...") 294 | 295 | if detect_communities: 296 | ccs, _ = cnx.louvain(G, resolution=1) 297 | ccs = ccs.rename(columns={"partition": "labels"}) 298 | else: 299 | ccs = cnx.connected_components(G) 300 | 301 | print("Distinct connected components: ", ccs.labels.nunique()) 302 | 303 | total_perms = 0 304 | total_reduced = 0 305 | 306 | def get_edges(df_1, df_2, total_perms, total_reduced): 307 | 308 | df_1 = df_1.merge(df_2, on='labels', how='inner') 309 | df_1 = df_1.drop(['labels'], axis=1) 310 | total_perms += len(df_1) 311 | 312 | df_1 = df_1[df_1['vertex_x'] < df_1['vertex_y']] # remove both directions and loops 313 | total_reduced += len(df_1) 314 | 315 | return df_1, total_perms, total_reduced 316 | 317 | lengths = [] 318 | all_edges = [] 319 | 320 | ccs_pd = ccs.to_pandas() 321 | 322 | for label in tqdm(ccs_pd.labels.unique()): 323 | sub_frame = ccs[ccs['labels'] == label] 324 | 325 | lengths.append(len(sub_frame)) 326 | 327 | if len(sub_frame) < 50000: # Larger subframes don't fit on GPU, so run on CPU (though slower!) 328 | edge_df, total_perms, total_reduced = get_edges(sub_frame, sub_frame, total_perms, total_reduced) 329 | all_edges.append(edge_df) 330 | 331 | else: 332 | sub_frame_A = sub_frame[:30000] 333 | sub_frame_B = sub_frame[30000:] 334 | 335 | edge_df, total_perms, total_reduced = get_edges(sub_frame_A, sub_frame_A, total_perms, total_reduced) 336 | all_edges.append(edge_df) 337 | edge_df, total_perms, total_reduced = get_edges(sub_frame_A, sub_frame_B, total_perms, total_reduced) 338 | all_edges.append(edge_df) 339 | edge_df, total_perms, total_reduced = get_edges(sub_frame_B, sub_frame_A, total_perms, total_reduced) 340 | all_edges.append(edge_df) 341 | edge_df, total_perms, total_reduced = get_edges(sub_frame_B, sub_frame_B, total_perms, total_reduced) 342 | all_edges.append(edge_df) 343 | 344 | squares = [i*i for i in lengths] 345 | 346 | assert total_perms == sum(squares) 347 | assert total_reduced == (sum(squares) - len(ccs))/2 348 | 349 | time_cc_end = datetime.now() 350 | 351 | edges = gd.concat(all_edges) 352 | assert len(edges) == total_reduced 353 | 354 | print("Time taken to find connected components: ", time_cc_end-time_cc_start) 355 | print("Number of edges after imposing transitivity:", len(edges)) 356 | 357 | with open(save_file, 'wb') as f: 358 | pickle.dump(edges, f) 359 | -------------------------------------------------------------------------------- /neural/modified_sbert/clu_evaluators.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import csv 5 | from itertools import combinations 6 | from typing import List 7 | 8 | from sentence_transformers import evaluation, LoggingHandler 9 | from sentence_transformers.readers import InputExample 10 | 11 | import logging 12 | from transformers import logging as lg 13 | 14 | currentdir = os.path.dirname(os.path.realpath(__file__)) 15 | parentdir = os.path.dirname(currentdir) 16 | sys.path.append(parentdir) 17 | 18 | import utils 19 | 20 | lg.set_verbosity_error() 21 | logging.basicConfig(format='%(asctime)s - %(message)s', 22 | datefmt='%Y-%m-%d %H:%M:%S', 23 | level=logging.INFO, 24 | handlers=[LoggingHandler()]) 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class 
ClusterEvaluator(evaluation.SentenceEvaluator): 29 | 30 | """ 31 | Evaluate a model based on allocation of texts into correct clusters. 32 | Embeddings are clustered with the specified clustering algorithm using cosine distance. Best clustering parameters 33 | (distance threshold) are found using an approximate search method to speed up evaluation. 34 | 35 | All possible combinations of articles are split into pairs, with positives being in the same cluster and negatives 36 | being in different clusters. 37 | Metrics are precision, recall and F1. 38 | 39 | Returned metrics are F1 along with the optimal clustering threshold. 40 | 41 | The results are written in a CSV. If a CSV already exists, then values are appended. 42 | 43 | :param sentences1: The first column of sentences 44 | :param sentences2: The second column of sentences 45 | :param labels: labels[i] is the label for the pair (sentences1[i], sentences2[i]). Must be 0 or 1 46 | :param name: Name for the output 47 | :param batch_size: Batch size used to compute embeddings 48 | :param show_progress_bar: If true, prints a progress bar 49 | :param write_csv: Write results to a CSV file 50 | :param cluster_type: Clustering algorithm to use. Supports "agglomerative" (hierarchical), "SLINK", "HDBScan" 51 | 52 | Modelled on: https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/evaluation/BinaryClassificationEvaluator.py 53 | """ 54 | 55 | def __init__( 56 | self, 57 | sentences1: List[str], 58 | sentences2: List[str], 59 | labels: List[int], 60 | name: str = '', 61 | batch_size: int = 512, 62 | show_progress_bar: bool = False, 63 | write_csv: bool = True, 64 | cluster_type: str = "agglomerative" 65 | ): 66 | 67 | self.sentences1 = sentences1 68 | self.sentences2 = sentences2 69 | self.labels = labels 70 | self.cluster_type = cluster_type 71 | 72 | assert len(self.sentences1) == len(self.sentences2) 73 | assert len(self.sentences1) == len(self.labels) 74 | for label in labels: 75 | assert (label == 0 or label == 1) 76 | 77 | self.write_csv = write_csv 78 | self.name = name 79 | self.batch_size = batch_size 80 | if show_progress_bar is None: 81 | show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG) 82 | self.show_progress_bar = show_progress_bar 83 | 84 | self.csv_file = "clustering_evaluation" + ("_"+name if name else '') + "_results.csv" 85 | self.csv_headers = ["epoch", "steps", "accuracy", "accuracy_threshold", "f1", "precision", "recall", "f1_threshold"] 86 | 87 | @classmethod 88 | def from_input_examples(cls, examples: List[InputExample], **kwargs): 89 | sentences1 = [] 90 | sentences2 = [] 91 | scores = [] 92 | 93 | for example in examples: 94 | sentences1.append(example.texts[0]) 95 | sentences2.append(example.texts[1]) 96 | scores.append(example.label) 97 | return cls(sentences1, sentences2, scores, **kwargs) 98 | 99 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 100 | 101 | if epoch != -1: 102 | if steps == -1: 103 | out_txt = f" after epoch {epoch}:" 104 | else: 105 | out_txt = f" in epoch {epoch} after {steps} steps:" 106 | else: 107 | out_txt = ":" 108 | 109 | logger.info("Cluster Evaluation of the model on " + self.name + " dataset" + out_txt) 110 | 111 | scores = self.compute_metrices(model) 112 | 113 | #Main score is F1 114 | main_score = scores['f1'] 115 | 116 | file_output_data = [epoch, steps] 117 | 118 | for score in self.csv_headers: 119 | if score in scores: 120 |
file_output_data.append(scores[score]) 121 | 122 | if output_path is not None and self.write_csv: 123 | csv_path = os.path.join(output_path, self.csv_file) 124 | if not os.path.isfile(csv_path): 125 | with open(csv_path, newline='', mode="w", encoding="utf-8") as f: 126 | writer = csv.writer(f) 127 | writer.writerow(self.csv_headers) 128 | writer.writerow(file_output_data) 129 | else: 130 | with open(csv_path, newline='', mode="a", encoding="utf-8") as f: 131 | writer = csv.writer(f) 132 | writer.writerow(file_output_data) 133 | 134 | return main_score 135 | 136 | def compute_metrices(self, model): 137 | 138 | sentences = [] 139 | labels = [] 140 | for i in range(len(self.sentences1)): 141 | 142 | if self.sentences1[i] not in sentences: 143 | sentences.append(self.sentences1[i]) 144 | s1_id = sentences.index(self.sentences1[i]) 145 | if self.sentences2[i] not in sentences: 146 | sentences.append(self.sentences2[i]) 147 | s2_id = sentences.index(self.sentences2[i]) 148 | 149 | if self.labels[i] == 1: 150 | labels.append([s1_id, s2_id]) 151 | 152 | embeddings = model.encode(sentences, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar) 153 | 154 | # Normalize the embeddings to unit length 155 | embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True) 156 | 157 | def cluster_eval(threshold, embeddings, labels, cluster_type='agglomerative'): 158 | 159 | clustered_ids = utils.cluster( 160 | cluster_type, 161 | cluster_params={"threshold": threshold, "clustering linkage": 'average', "metric": 'cosine', "min cluster size": 2}, 162 | corpus_embeddings=embeddings 163 | ) 164 | 165 | # Convert every pair in a cluster into an edge 166 | cluster_edges = utils.edges_from_clusters(clustered_ids) 167 | 168 | metrics = utils.evaluate(pred_edges=cluster_edges, gt_edges=labels, print_metrics=False) 169 | 170 | total = len(list(combinations(range(len(embeddings)), 2))) 171 | cluster_tn = total - metrics["tps"] - metrics["fps"] - metrics["fns"] 172 | 173 | metrics["accuracy"] = (metrics["tps"] + cluster_tn)/total 174 | 175 | return metrics 176 | 177 | tenths = {} 178 | for threshold in [0.01] + [round(x, 2) for x in (np.linspace(0.1, 0.9, 9))] + [0.99]: 179 | tenths[threshold] = cluster_eval(threshold, embeddings, labels, cluster_type=self.cluster_type) 180 | 181 | def best_threshold(dictionary, metric): 182 | 183 | ths = list(dictionary.keys()) 184 | scores = [] 185 | 186 | for th in ths: 187 | scores.append(dictionary[th][metric]) 188 | 189 | sorted_scores = sorted(scores) 190 | 191 | best_score = sorted_scores[-1] 192 | second_best_score = sorted_scores[-2] 193 | third_best_score = sorted_scores[-3] 194 | 195 | if best_score == second_best_score: 196 | best_indices = [i for i, x in enumerate(scores) if x == best_score] 197 | best_thresholds = [] 198 | for idx in best_indices: 199 | best_thresholds.append(ths[idx]) 200 | 201 | best_th = max(best_thresholds) 202 | second_best_th = min(best_thresholds) 203 | 204 | elif second_best_score == third_best_score: 205 | second_indices = [i for i, x in enumerate(scores) if x == second_best_score] 206 | second_thresholds = [] 207 | for idx in second_indices: 208 | second_thresholds.append(ths[idx]) 209 | 210 | best_th = max(second_thresholds) 211 | second_best_th = min(second_thresholds) 212 | 213 | else: 214 | best_idx = scores.index(best_score) 215 | best_th = ths[best_idx] 216 | 217 | second_best_idx = scores.index(second_best_score) 218 | second_best_th = ths[second_best_idx] 219 | 220 | return best_th, second_best_th 221 | 222 
| max_f1_th, second_f1_th = best_threshold(tenths, metric='f_score') 223 | max_acc_th, second_acc_th = best_threshold(tenths, metric='accuracy') 224 | 225 | min_th = min(max_f1_th, second_f1_th, max_acc_th, second_acc_th) 226 | max_th = max(max_f1_th, second_f1_th, max_acc_th, second_acc_th) 227 | 228 | hundreths = {} 229 | for threshold in np.arange(min_th, max_th, 0.01): 230 | hundreths[threshold] = cluster_eval(threshold, embeddings, labels, cluster_type=self.cluster_type) 231 | 232 | hd_max_f1_th, _ = best_threshold(hundreths, 'f_score') 233 | hd_max_acc_th, _ = best_threshold(hundreths, 'accuracy') 234 | 235 | acc = hundreths[hd_max_acc_th]['accuracy'] 236 | acc_threshold = hd_max_acc_th 237 | 238 | f1 = hundreths[hd_max_f1_th]['f_score'] 239 | precision = hundreths[hd_max_f1_th]['precision'] 240 | recall = hundreths[hd_max_f1_th]['recall'] 241 | f1_threshold = hd_max_f1_th 242 | 243 | logger.info("Cluster Accuracy: {:.2f}\t(Threshold: {:.2f})".format(acc * 100, acc_threshold)) 244 | logger.info("Cluster F1: {:.2f}\t(Threshold: {:.2f})".format(f1 * 100, f1_threshold)) 245 | logger.info("Cluster Precision: {:.2f}".format(precision * 100)) 246 | logger.info("Cluster Recall: {:.2f}\n".format(recall * 100)) 247 | 248 | output_scores = { 249 | 'accuracy': acc, 250 | 'accuracy_threshold': acc_threshold, 251 | 'f1': f1, 252 | 'f1_threshold': f1_threshold, 253 | 'precision': precision, 254 | 'recall': recall, 255 | } 256 | 257 | return output_scores 258 | 259 | 260 | class CEClusterEvaluator(): 261 | 262 | def __init__(self, sentence_pairs: List[List[str]], labels: List[int], name: str='', write_csv: bool = True): 263 | assert len(sentence_pairs) == len(labels) 264 | for label in labels: 265 | assert (label == 0 or label == 1) 266 | 267 | self.sentence_pairs = sentence_pairs 268 | self.labels = np.asarray(labels) 269 | self.name = name 270 | 271 | self.csv_file = "CEClusterEvaluator" + ("_" + name if name else '') + "_results.csv" 272 | self.csv_headers = ["epoch", "steps", "accuracy", "accuracy_threshold", "f1", "f1_threshold", "precision", "recall"] 273 | self.write_csv = write_csv 274 | 275 | @classmethod 276 | def from_input_examples(cls, examples: List[InputExample], **kwargs): 277 | sentence_pairs = [] 278 | labels = [] 279 | 280 | for example in examples: 281 | sentence_pairs.append(example.texts) 282 | labels.append(example.label) 283 | return cls(sentence_pairs, labels, **kwargs) 284 | 285 | def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float: 286 | if epoch != -1: 287 | if steps == -1: 288 | out_txt = " after epoch {}:".format(epoch) 289 | else: 290 | out_txt = " in epoch {} after {} steps:".format(epoch, steps) 291 | else: 292 | out_txt = ":" 293 | 294 | logger.info("CEClusterEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt) 295 | 296 | pred_scores = model.predict(self.sentence_pairs, convert_to_numpy=True, show_progress_bar=False) 297 | 298 | acc, acc_threshold, f1, precision, recall, f1_threshold = self.find_best_acc_and_f1(pred_scores, self.sentence_pairs, self.labels) 299 | 300 | logger.info("Cluster Accuracy: {:.2f}\t(Threshold: {:.2f})".format(acc * 100, acc_threshold)) 301 | logger.info("Cluster F1: {:.2f}\t(Threshold: {:.2f})".format(f1 * 100, f1_threshold)) 302 | logger.info("Cluster Precision: {:.2f}".format(precision * 100)) 303 | logger.info("Cluster Recall: {:.2f}\n".format(recall * 100)) 304 | 305 | 306 | if output_path is not None and self.write_csv: 307 | csv_path = os.path.join(output_path, self.csv_file) 308 | output_file_exists =
os.path.isfile(csv_path) 309 | with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f: 310 | writer = csv.writer(f) 311 | if not output_file_exists: 312 | writer.writerow(self.csv_headers) 313 | 314 | writer.writerow([epoch, steps, acc, acc_threshold, f1, f1_threshold, precision, recall]) 315 | 316 | return f1 317 | 318 | @staticmethod 319 | def find_best_acc_and_f1(scores, sentence_pairs, labels): 320 | 321 | assert len(scores) == len(labels) 322 | 323 | sentences = [] 324 | pair_ids = [] 325 | for pair in sentence_pairs: 326 | if pair[0] not in sentences: 327 | sentences.append(pair[0]) 328 | s1_id = sentences.index(pair[0]) 329 | if pair[1] not in sentences: 330 | sentences.append(pair[1]) 331 | s2_id = sentences.index(pair[1]) 332 | pair_ids.append([s1_id, s2_id]) 333 | 334 | gt_edges = [pair_ids[i] for i in range(len(pair_ids)) if labels[i] == 1] 335 | gt_edges = utils.edges_from_clusters(utils.clusters_from_edges(gt_edges)) # Impose transitivity 336 | 337 | total_possible_edges = len(labels) 338 | 339 | thds = list(set([round(score, 2) for score in scores])) 340 | accuracies = [] 341 | precisions = [] 342 | recalls = [] 343 | f1s = [] 344 | 345 | for th in thds: 346 | 347 | preds = [pair_ids[i] for i in range(len(pair_ids)) if scores[i] > th] 348 | 349 | # Impose transitivity 350 | pred_edges = utils.edges_from_clusters(utils.clusters_from_edges(preds)) 351 | 352 | metrics = utils.evaluate(pred_edges=pred_edges, gt_edges=gt_edges, print_metrics=False) 353 | precisions.append(metrics['precision']) 354 | recalls.append(metrics['recall']) 355 | f1s.append(metrics['f_score']) 356 | 357 | cluster_tn = total_possible_edges - metrics["tps"] - metrics["fps"] - metrics["fns"] 358 | accuracies.append((metrics["tps"] + cluster_tn) / total_possible_edges) 359 | 360 | # Find max values 361 | max_acc = max(accuracies) 362 | acc_idx = accuracies.index(max_acc) 363 | acc_threshold = thds[acc_idx] 364 | 365 | max_f1 = max(f1s) 366 | f1_idx = f1s.index(max_f1) 367 | precision = precisions[f1_idx] 368 | recall = recalls[f1_idx] 369 | f1_threshold = thds[f1_idx] 370 | 371 | return max_acc, acc_threshold, max_f1, precision, recall, f1_threshold 372 | 373 | 374 | 375 | --------------------------------------------------------------------------------
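For orientation, below is a minimal usage sketch of the pair-level evaluation helpers in utils.py (clusters_from_edges, edges_from_clusters, evaluate) that the rule-based and neural evaluation code above builds on. The toy article IDs, predicted edges and temporary ground-truth file are illustrative assumptions rather than repo data, and running it assumes the full environment from environment.yml, since utils.py imports cuGraph, cuDF and hdbscan at module level.

    import json
    import tempfile

    import utils

    # Predicted duplicate pairs, e.g. from thresholding bi-encoder similarities (toy IDs)
    pred_edges = [["art-1", "art-2"], ["art-2", "art-3"], ["art-7", "art-8"]]

    # Impose transitivity: art-1/art-2/art-3 form one connected component,
    # so the implied pair art-1 <-> art-3 is added as an edge
    pred_edges = utils.edges_from_clusters(utils.clusters_from_edges(pred_edges))

    # Ground-truth duplicate pairs can be passed as a path to a JSON list of pairs
    gt_edges = [["art-1", "art-2"], ["art-1", "art-3"], ["art-2", "art-3"], ["art-5", "art-6"]]
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(gt_edges, f)
        gt_path = f.name

    # Precision, recall and F1 over unordered pairs (two_way=True handles edge direction)
    metrics = utils.evaluate(pred_edges, gt_edge_path=gt_path)
    print(metrics["precision"], metrics["recall"], metrics["f_score"])

These are the same helpers that ClusterEvaluator, CEClusterEvaluator and get_eval_metrics rely on internally, so this is also the shape of data they expect: edges as two-element lists of article IDs, with ground truth supplied either in memory or as a JSON file path.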