├── .gitignore ├── run.sh ├── models ├── de │ ├── example_2017-07-13T085725.498310.ii │ └── 2_1499938269.299021_100.0.model ├── en │ ├── example_2017-04-10T193850.536289.ii │ ├── md_3 │ └── 4_1491902620.876421_10000.0.model └── es │ ├── example_2017-04-12T215308.030747.ii │ └── 2_1492035151.291134_100.0.model ├── download_data.sh ├── README.md ├── LICENSE ├── model.py ├── load_corpora.py ├── testbench.py ├── utils.py ├── eval_lib.py ├── dataset_loader.py ├── clustering.py └── eval.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | dataset 3 | clustering.eng.out 4 | clustering.spa.out 5 | clustering.deu.out 6 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python testbench.py 2 | python eval.py clustering.eng.out dataset/dataset.test.json -f 3 | python eval.py clustering.spa.out dataset/dataset.test.json -f 4 | python eval.py clustering.deu.out dataset/dataset.test.json -f -------------------------------------------------------------------------------- /models/de/example_2017-07-13T085725.498310.ii: -------------------------------------------------------------------------------- 1 | 9 RELEVANCE_TS 2 | 1 Entities_all 3 | 2 Entities_body 4 | 3 Entities_title 5 | 4 Lemmas_all 6 | 5 Lemmas_body 7 | 6 Lemmas_title 8 | 7 NEWEST_TS 9 | 8 OLDEST_TS 10 | 10 Tokens_all 11 | 11 Tokens_body 12 | 12 Tokens_title 13 | -------------------------------------------------------------------------------- /models/en/example_2017-04-10T193850.536289.ii: -------------------------------------------------------------------------------- 1 | 9 RELEVANCE_TS 2 | 1 Entities_all 3 | 2 Entities_body 4 | 3 Entities_title 5 | 4 Lemmas_all 6 | 5 Lemmas_body 7 | 6 Lemmas_title 8 | 7 NEWEST_TS 9 | 8 OLDEST_TS 10 | 10 Tokens_all 11 | 11 Tokens_body 12 | 12 Tokens_title 13 | -------------------------------------------------------------------------------- /models/es/example_2017-04-12T215308.030747.ii: -------------------------------------------------------------------------------- 1 | 1 Entities_all 2 | 2 Entities_body 3 | 3 Entities_title 4 | 4 Lemmas_all 5 | 5 Lemmas_body 6 | 6 Lemmas_title 7 | 7 NEWEST_TS 8 | 8 OLDEST_TS 9 | 9 RELEVANCE_TS 10 | 10 Tokens_all 11 | 11 Tokens_body 12 | 12 Tokens_title 13 | 13 ZINV_POOL_SIZE 14 | 14 ZZINVCLUSTER_SIZE 15 | -------------------------------------------------------------------------------- /models/en/md_3: -------------------------------------------------------------------------------- 1 | 0.0 2 | Entities_all 1.021403965824735 3 | Entities_body 1.071686657406047 4 | Entities_title -0.494689970075746 5 | Lemmas_all 2.761380597640421 6 | Lemmas_body 3.584626976821256 7 | Lemmas_title -0.5835470910968313 8 | NEWEST_TS -3.410517579967286 9 | OLDEST_TS 2.446669764223947 10 | RELEVANCE_TS -0.2715168685905864 11 | Tokens_all 5.367801738248652 12 | Tokens_body 4.107900097257066 13 | Tokens_title -0.5822909115616434 14 | ZINV_POOL_SIZE -0.06738887528081441 15 | ZZINVCLUSTER_SIZE 0.0 -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | mkdir dataset 2 | wget -P dataset ftp://"ftp.priberam.pt|anonymous"@ftp.priberam.pt/SUMMAPublic/Corpora/Clustering/2018.0/dataset/dataset.dev.json 3 | wget -P dataset 
ftp://"ftp.priberam.pt|anonymous"@ftp.priberam.pt/SUMMAPublic/Corpora/Clustering/2018.0/dataset/dataset.test.json 4 | wget -P dataset ftp://"ftp.priberam.pt|anonymous"@ftp.priberam.pt/SUMMAPublic/Corpora/Clustering/2018.0/dataset-tok-ner/clustering.dev.json 5 | wget -P dataset ftp://"ftp.priberam.pt|anonymous"@ftp.priberam.pt/SUMMAPublic/Corpora/Clustering/2018.0/dataset-tok-ner/clustering.test.json -------------------------------------------------------------------------------- /models/de/2_1499938269.299021_100.0.model: -------------------------------------------------------------------------------- 1 | SVM-light Version V6.20 2 | 0 # kernel type 3 | 3 # kernel parameter -d 4 | 1 # kernel parameter -g 5 | 1 # kernel parameter -s 6 | 1 # kernel parameter -r 7 | empty# kernel parameter -u 8 | 13 # highest feature index 9 | 15 # number of training documents 10 | 2 # number of support vectors plus 1 11 | 0 # threshold b, each following line is a SV (starting with alpha*y) 12 | 1 1:2.6352043 2:3.0292611 3:0.37948003 4:5.0083957 5:1.6542898 6:0.092155419 7:1.1313787 8:3.0538087 9:-0.4301641 10:5.0044088 11:1.6542898 12:0.092155419 # 13 | -------------------------------------------------------------------------------- /models/en/4_1491902620.876421_10000.0.model: -------------------------------------------------------------------------------- 1 | SVM-light Version V6.20 2 | 0 # kernel type 3 | 3 # kernel parameter -d 4 | 1 # kernel parameter -g 5 | 1 # kernel parameter -s 6 | 1 # kernel parameter -r 7 | empty# kernel parameter -u 8 | 13 # highest feature index 9 | 25 # number of training documents 10 | 2 # number of support vectors plus 1 11 | 0 # threshold b, each following line is a SV (starting with alpha*y) 12 | 1 1:2.4383616 2:-0.73693264 3:1.4310931 4:0.74489218 5:-1.7037395 6:0.32604426 7:2.1834517 8:5.0244422 9:0.29178089 10:3.8883567 11:1.0086491 12:-1.9079906 # 13 | -------------------------------------------------------------------------------- /models/es/2_1492035151.291134_100.0.model: -------------------------------------------------------------------------------- 1 | SVM-light Version V6.20 2 | 0 # kernel type 3 | 3 # kernel parameter -d 4 | 1 # kernel parameter -g 5 | 1 # kernel parameter -s 6 | 1 # kernel parameter -r 7 | empty# kernel parameter -u 8 | 15 # highest feature index 9 | 25 # number of training documents 10 | 2 # number of support vectors plus 1 11 | 0 # threshold b, each following line is a SV (starting with alpha*y) 12 | 1 1:2.6354764 2:2.3855691 3:-0.48799703 4:4.8541822 5:3.498678 6:1.4854121 7:3.0260785 8:1.3557703 9:-1.7998464 10:3.6655061 11:2.3828528 12:1.0742556 14:1.3829948 # 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ⚠️Supercedence Notice 2 | 3 | This work has been superseded by https://github.com/Priberam/projected-news-clustering. 4 | 5 | # news-clustering 6 | 7 | run download_data.sh to download dataset 8 | run run.sh to execute and print scores 9 | 10 | Implementation of the paper: Multilingual Clustering of Streaming News, Sebastião Miranda, Artūrs Znotiņš, Shay B. Cohen, Guntis Barzdins, In EMNLP 2018 (http://aclweb.org/anthology/D18-1483) 11 | 12 | The original paper, as mentioned above, used proprietary software by Priberam. 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The 3-Clause BSD License 2 | 3 | For Priberam Clustering Software 4 | 5 | Copyright 2018 by PRIBERAM INFORMÁTICA, S.A. (“PRIBERAM”) (www.priberam.com) 6 | 7 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 12 | 13 | 3. Neither the name of the copyright holder (PRIBERAM) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 16 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # The 3-Clause BSD License 2 | # For Priberam Clustering Software 3 | # Copyright 2018 by PRIBERAM INFORMÁTICA, S.A. ("PRIBERAM") (www.priberam.com) 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | # 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | # 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | # 3. Neither the name of the copyright holder (PRIBERAM) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 8 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 9 | import json 10 | 11 | 12 | class Model: 13 | def __init__(self): 14 | self.weights = {} 15 | self.bias = 0 16 | 17 | def load(self, model_path, ii_path): 18 | ii = {} 19 | with open(ii_path) as fii: 20 | ii = dict([(line.split('\t')[0], line.split('\t')[1].strip()) 21 | for line in fii]) 22 | 23 | print(json.dumps(ii)) 24 | with open(model_path) as fm: 25 | for i, line in enumerate(fm): 26 | if i == 10: 27 | self.bias = float(line.split('#')[0].strip()) 28 | elif i == 11: 29 | for f in line.split('#')[0].split(' ')[1:]: 30 | if len(f.split(':')) < 2: 31 | continue 32 | self.weights[ii[f.split(':')[0]]] = float( 33 | f.split(':')[1]) 34 | print(json.dumps(self.weights)) 35 | 36 | def load_raw(self, model_path): 37 | linei = -1 38 | with open(model_path) as fii: 39 | for line in fii: 40 | linei += 1 41 | if linei == 0: 42 | self.bias = float(line) 43 | else: 44 | parts = line.split('\t') 45 | self.weights[parts[0]] = float(parts[1]) 46 | 47 | print(json.dumps(self.weights)) 48 | 49 | -------------------------------------------------------------------------------- /load_corpora.py: -------------------------------------------------------------------------------- 1 | # The 3-Clause BSD License 2 | # For Priberam Clustering Software 3 | # Copyright 2018 by PRIBERAM INFORMÁTICA, S.A. ("PRIBERAM") (www.priberam.com) 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | # 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | # 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | # 3. Neither the name of the copyright holder (PRIBERAM) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 8 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
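# Overview: load() merges the raw dataset JSON (document metadata) with the
# tokenised/NER JSON-lines file that carries the pre-extracted "features",
# keeps only the requested languages and returns a Corpus. build_index()
# sorts the documents chronologically by their "date" field, drops empty
# feature groups and maps each document id to its position in that order.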
9 | import json 10 | import datetime 11 | 12 | 13 | class Corpus: 14 | def __init__(self): 15 | self.index = {} 16 | self.documents = [] 17 | 18 | def build_index(self): 19 | self.documents = sorted(self.documents, key=lambda k: datetime.datetime.strptime( 20 | k["date"], "%Y-%m-%d %H:%M:%S")) 21 | self.index = {} 22 | i = -1 23 | for sorted_document in self.documents: 24 | rem = [] 25 | for fn, fv in sorted_document["features"].items(): 26 | if not fv: 27 | rem.append(fn) 28 | for fr in rem: 29 | del(sorted_document["features"][fr]) 30 | 31 | i += 1 32 | self.index[sorted_document["id"]] = i 33 | 34 | def get_document(self, id): 35 | return self.documents[self.index[id]] 36 | 37 | 38 | def load(dataset_path, dataset_tok_ner_path, languages): 39 | corpus_index = {} 40 | with open(dataset_path, errors="ignore", mode="r") as data_file: 41 | corpus_data = json.load(data_file) 42 | for d in corpus_data: 43 | corpus_index[d["id"]] = d 44 | 45 | with open(dataset_tok_ner_path, errors="ignore", mode="r") as data_file: 46 | for l in data_file: 47 | tok_ner_document = json.loads(l) 48 | corpus_index[tok_ner_document["id"] 49 | ]["features"] = tok_ner_document["features"] 50 | 51 | corpus = Corpus() 52 | corpus.documents = [] 53 | for archive_document in corpus_index.values(): 54 | if archive_document["lang"] not in languages: 55 | continue 56 | corpus.documents.append(archive_document) 57 | 58 | corpus.build_index() 59 | return corpus 60 | -------------------------------------------------------------------------------- /testbench.py: -------------------------------------------------------------------------------- 1 | # The 3-Clause BSD License 2 | # For Priberam Clustering Software 3 | # Copyright 2018 by PRIBERAM INFORMÁTICA, S.A. ("PRIBERAM") (www.priberam.com) 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | # 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | # 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | # 3. Neither the name of the copyright holder (PRIBERAM) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 8 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
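# Overview: test() streams the test documents of one language through
# clustering.Aggregator using the bundled SVM model and writes one
# "document_id<TAB>cluster_id" line per document to clustering.<lang>.out.
# English runs with threshold 0.0 plus a separate merge model (md_3), while
# Spanish and German rely on fixed decision thresholds only.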
9 | 10 | 11 | # python testbench.py 12 | # python eval.py clustering.out E:\Corpora\clustering\processed_clusters\dataset.test.json -f 13 | 14 | import model 15 | import clustering 16 | import load_corpora 17 | import json 18 | import os 19 | 20 | def test(lang, thr, model_path, model_path_ii, merge_model_path=None): 21 | corpus = load_corpora.load(r"dataset/dataset.test.json", 22 | r"dataset/clustering.test.json", set([lang])) 23 | print(lang,"#docs",len(corpus.documents)) 24 | clustering_model = model.Model() 25 | clustering_model.load(model_path, model_path_ii) 26 | 27 | merge_model = None 28 | if merge_model_path: 29 | merge_model = model.Model() 30 | merge_model.load_raw(merge_model_path) 31 | 32 | aggregator = clustering.Aggregator(clustering_model, thr, merge_model) 33 | 34 | for i, d in enumerate(corpus.documents): 35 | print("\r", i, "/", len(corpus.documents), 36 | " | #c= ", len(aggregator.clusters), end="") 37 | aggregator.PutDocument(clustering.Document(d, "???")) 38 | 39 | with open("clustering."+lang+".out", "w") as fo: 40 | ci = 0 41 | for c in aggregator.clusters: 42 | for d in c.ids: 43 | fo.write(d) 44 | fo.write("\t") 45 | fo.write(str(ci)) 46 | fo.write("\n") 47 | ci += 1 48 | 49 | test('eng', 0.0, r'models/en/4_1491902620.876421_10000.0.model', 50 | r'models/en/example_2017-04-10T193850.536289.ii', r'models/en/md_3') 51 | 52 | test('spa', 8.18067, r'models/es/2_1492035151.291134_100.0.model', 53 | r'models/es/example_2017-04-12T215308.030747.ii') 54 | 55 | test('deu', 8.1175, r'models/de/2_1499938269.299021_100.0.model', 56 | r'models/de/example_2017-07-13T085725.498310.ii') 57 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #The 3-Clause BSD License 2 | # For Priberam Clustering Software 3 | # Copyright 2018 by PRIBERAM INFORMÁTICA, S.A. ("PRIBERAM") (www.priberam.com) 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | # 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | # 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | # 3. Neither the name of the copyright holder (PRIBERAM) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 8 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
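# Note: comb is imported below from scipy.misc, which only exists in older
# SciPy releases; in current SciPy the same function (including exact=True)
# lives in scipy.special, so the import may need to be changed to
# "from scipy.special import comb" on newer installations.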
9 | # 10 | # Script utils for clustering evaluation 11 | # Adapted from Arturs Znotins script 12 | # 13 | 14 | from scipy.misc import comb 15 | import numpy as np 16 | 17 | # JavaScript like dictionary: d.key <=> d[key] 18 | # http://stackoverflow.com/a/14620633 19 | class Dict(dict): 20 | def __init__(self, *args, **kwargs): 21 | super(Dict, self).__init__(*args, **kwargs) 22 | self.__dict__ = self 23 | 24 | def __getattribute__(self, key): 25 | try: 26 | return super(Dict, self).__getattribute__(key) 27 | except: 28 | return 29 | 30 | def __delattr__(self, name): 31 | if name in self: 32 | del self[name] 33 | 34 | 35 | def myComb(a,b): 36 | return comb(a,b,exact=True) 37 | 38 | 39 | vComb = np.vectorize(myComb) 40 | 41 | 42 | def get_tp_fp_tn_fn(cooccurrence_matrix): 43 | tp_plus_fp = vComb(cooccurrence_matrix.sum(0, dtype=int),2).sum() 44 | tp_plus_fn = vComb(cooccurrence_matrix.sum(1, dtype=int),2).sum() 45 | tp = vComb(cooccurrence_matrix.astype(int), 2).sum() 46 | fp = tp_plus_fp - tp 47 | fn = tp_plus_fn - tp 48 | tn = comb(cooccurrence_matrix.sum(), 2) - tp - fp - fn 49 | 50 | return [tp, fp, tn, fn] 51 | 52 | 53 | def get_cooccurrence_matrix(true_labels, pred_labels): 54 | assert len(true_labels) == len(pred_labels) 55 | true_label_map = {} 56 | i = 0 57 | for l in true_labels: 58 | if not l in true_label_map: 59 | true_label_map[l] = i 60 | i += 1 61 | hyp_label_map = {} 62 | i = 0 63 | for l in pred_labels: 64 | if not l in hyp_label_map: 65 | hyp_label_map[l] = i 66 | i += 1 67 | m = [[0 for i in range(len(hyp_label_map))] for j in range(len(true_label_map))] 68 | for i in range(len(true_labels)): 69 | m[true_label_map[true_labels[i]]][hyp_label_map[pred_labels[i]]] += 1 70 | return (np.array(m), true_label_map, hyp_label_map) 71 | 72 | 73 | def sum_sparse(a, b, amult=1, bmult=1): 74 | r = {} 75 | for x in a: 76 | r[x[0]] = amult * x[1] 77 | for x in b: 78 | r[x[0]] = bmult * x[1] + (r[x[0]] if x[0] in r else 0) 79 | res = [] 80 | for k, v in r.items(): 81 | res.append((k, v)) 82 | return res 83 | 84 | 85 | def trim_sparse(a, topn=100): 86 | return sorted(a, key=lambda x: x[1], reverse=True)[:topn] -------------------------------------------------------------------------------- /eval_lib.py: -------------------------------------------------------------------------------- 1 | #The 3-Clause BSD License 2 | # For Priberam Clustering Software 3 | # Copyright 2018 by PRIBERAM INFORMÁTICA, S.A. ("PRIBERAM") (www.priberam.com) 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | # 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | # 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | # 3. Neither the name of the copyright holder (PRIBERAM) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 8 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 9 | # 10 | # Script utils for clustering evaluation 11 | # Adapted from Arturs Znotins script 12 | # 13 | 14 | import sys 15 | import os 16 | from math import * 17 | import json 18 | from sklearn import metrics 19 | import numpy as np 20 | import time 21 | from datetime import datetime 22 | import logging 23 | from pprint import pprint 24 | from collections import Counter 25 | from sklearn.metrics import precision_recall_fscore_support 26 | import utils 27 | 28 | def ScoreSet(true_labels, pred_labels, logging="", get_data=False): 29 | cooccurrence_matrix, true_label_map, pred_label_map = utils.get_cooccurrence_matrix(true_labels, pred_labels) 30 | tp, fp, tn, fn = utils.get_tp_fp_tn_fn(cooccurrence_matrix) 31 | 32 | acc = 1. * (tp + tn) / (tp + tn + fp + fn) if tp + tn + fp + fn > 0 else 0 33 | p = 1. * tp / (tp + fp) if tp + fp > 0 else 0 34 | r = 1. * tp / (tp + fn) if tp + fn > 0 else 0 35 | f1 = 2. * p * r / (p + r) if p + r > 0 else 0 36 | 37 | ri = 1. * (tp + tn) / (tp + tn + fp + fn) if tp + tn + fp + fn > 0 else 0 38 | 39 | entropies, purities = [], [] 40 | for cluster in cooccurrence_matrix: 41 | cluster = cluster / float(cluster.sum()) 42 | # ee = (cluster * [log(max(x, 1e-6), 2) for x in cluster]).sum() 43 | pp = cluster.max() 44 | # entropies += [ee] 45 | purities += [pp] 46 | counts = np.array([c.sum() for c in cooccurrence_matrix]) 47 | coeffs = counts / float(counts.sum()) 48 | purity = (coeffs * purities).sum() 49 | # entropy = (coeffs * entropies).sum() 50 | 51 | ari = metrics.adjusted_rand_score(true_labels, pred_labels) 52 | nmi = metrics.normalized_mutual_info_score(true_labels, pred_labels) 53 | ami = metrics.adjusted_mutual_info_score(true_labels, pred_labels) 54 | v_measure = metrics.homogeneity_completeness_v_measure(true_labels, pred_labels) 55 | 56 | pred_cluset = {} 57 | for v in pred_labels: 58 | pred_cluset[v] = True 59 | 60 | true_cluset = {} 61 | for v in true_labels: 62 | true_cluset[v] = True 63 | 64 | s = "{\n\"logging\" : \"" + logging + "\",\n" 65 | s += "\"f1\" : %.5f,\n" % f1 66 | s += "\"p\" : %.5f,\n" % p 67 | s += "\"r\" : %.5f,\n" % r 68 | s += "\"a\" : %.5f,\n" % acc 69 | s += "\"ari\" : %.5f,\n" % ari 70 | s += "\"size_true\" : %.5f,\n" % len(true_labels) 71 | s += "\"size_pred\" : %.5f,\n" % len(pred_labels) 72 | s += "\"num_labels_true\" : %.5f,\n" % len(true_cluset) 73 | s += "\"num_labels_pred\" : %.5f,\n" % len(pred_cluset) 74 | s += "\"ri\" : %.5f,\n" % ri 75 | s += "\"nmi\" : %.5f,\n" % nmi 76 | s += "\"ami\" : %.5f,\n" % ami 77 | s += "\"pur\" : %.5f,\n" % purity 78 | s += "\"hom\" : %.5f,\n" % v_measure[0] 79 | s += "\"comp\" : %.5f,\n" % v_measure[1] 80 | s += "\"v\" : %.5f,\n" % v_measure[2] 81 | s += "\"tp\" : %.5f,\n" % tp 82 | s += "\"fp\" : %.5f,\n" % fp 83 | s += "\"tn\" : %.5f,\n" % tn 84 | s += "\"fn\" : %.5f\n" % fn 85 | s += "}" 86 | 87 | if get_data: 88 | return s 89 | else: 90 | print (s) -------------------------------------------------------------------------------- /dataset_loader.py: 
-------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | import os 4 | import datetime 5 | 6 | 7 | def LoadLinkingDataset(path, allowed_languages=None, allowed_documents=None, limited=True): 8 | dataset = {} 9 | 10 | dataset["linking"] = {} 11 | dataset["bags"] = {} 12 | dataset["ii"] = {} 13 | dataset["ii_langs"] = {} 14 | dataset["i_document_data"] = {} 15 | dataset["documents"] = set() 16 | 17 | ext_list = [".json"] 18 | for dirName, subdirList, fileList in os.walk(path): 19 | for fname in fileList: 20 | if any(fname.endswith(ext) for ext in ext_list): 21 | complete_fname = dirName + "\\" + fname 22 | filedata = None 23 | with open(complete_fname) as file: 24 | filedata = file.read() 25 | 26 | linkage_type = "negative" 27 | if "positive" in complete_fname: 28 | linkage_type = "positive" 29 | 30 | if filedata: 31 | try: 32 | doc_object = json.loads(filedata) 33 | 34 | langs = [] 35 | bag_ids = [] 36 | selected_annotation = {} 37 | for k, v in doc_object.items(): 38 | if k == "meta": 39 | continue 40 | if "info" not in v: 41 | continue 42 | lang = v["info"]["lang"] 43 | 44 | if allowed_languages and lang not in allowed_languages: 45 | continue 46 | 47 | # document bag 48 | bag_id = k 49 | articles = v["articles"]["results"] 50 | allowed_articles = [] 51 | for article in articles: 52 | article_id = article["id"] 53 | if allowed_documents and article_id not in allowed_documents: 54 | continue 55 | allowed_articles.append(article) 56 | 57 | if len(allowed_articles) <= 0: 58 | continue 59 | 60 | bag_ids.append(k) 61 | langs.append(lang) 62 | selected_annotation[k] = (v, allowed_articles) 63 | 64 | if limited and len(bag_ids) < 2: 65 | continue 66 | 67 | if limited: 68 | assert(len(bag_ids) == 2) 69 | if langs[0] == langs[1]: 70 | continue 71 | 72 | bag_ids = [] 73 | for k, v in selected_annotation.items(): 74 | lang = v[0]["info"]["lang"] 75 | bag_event_id = v[0]["info"]["eventUri"] 76 | allowed_articles = v[1] 77 | 78 | assert(len(allowed_articles) > 0) 79 | 80 | if allowed_languages: 81 | assert(lang in allowed_languages) 82 | 83 | bag_id = k 84 | 85 | bag_ids.append(bag_id) 86 | dataset["bags"][bag_id] = {} 87 | dataset["bags"][bag_id]["linked_as"] = linkage_type 88 | dataset["bags"][bag_id]["articles"] = set() 89 | dataset["bags"][bag_id]["lang"] = lang 90 | dataset["bags"][bag_id]["event_id"] = bag_event_id 91 | 92 | for article in allowed_articles: 93 | article_id = article["id"] 94 | dataset["ii"][article_id] = bag_id 95 | dataset["ii_langs"][article_id] = lang 96 | 97 | dataset["i_document_data"][article_id] = {"id": article_id, 98 | "text": article["body"], 99 | "title": article["title"], 100 | "event_id": article["eventUri"], 101 | "duplicate": article["isDuplicate"], 102 | "lang": article["lang"], 103 | "bag_id": bag_id, 104 | "date": datetime.datetime.strptime(article['date'] + ' ' + article['time'], "%Y-%m-%d %H:%M:%S"), 105 | "source": article["source"]["title"]} 106 | 107 | dataset["bags"][bag_id]["articles"].add( 108 | article_id) 109 | dataset["documents"].add(article_id) 110 | 111 | # Create linking annotation 112 | for bag_id_0 in bag_ids: 113 | for bag_id_1 in bag_ids: 114 | if bag_id_0 == bag_id_1: 115 | continue 116 | if bag_id_0 not in dataset["linking"]: 117 | dataset["linking"][bag_id_0] = {} 118 | dataset["linking"][bag_id_0][bag_id_1] = linkage_type 119 | 120 | if bag_id_1 not in dataset["linking"]: 121 | dataset["linking"][bag_id_1] = {} 122 | dataset["linking"][bag_id_1][bag_id_0] = linkage_type 123 | 
124 | except Exception as e: 125 | #raise e 126 | print("Failed reading file", complete_fname, "e:", e) 127 | return dataset 128 | -------------------------------------------------------------------------------- /clustering.py: -------------------------------------------------------------------------------- 1 | #The 3-Clause BSD License 2 | # For Priberam Clustering Software 3 | # Copyright 2018 by PRIBERAM INFORMÁTICA, S.A. ("PRIBERAM") (www.priberam.com) 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | # 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | # 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | # 3. Neither the name of the copyright holder (PRIBERAM) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 8 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
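# Overview: for every incoming Document, Aggregator.PutDocument builds a small
# feature vector against each existing Cluster (per-group cosine similarities
# of the sparse bag-of-features representations, Gaussian-decayed timestamp
# features and an inverse-cluster-size feature) and scores it with the learned
# SVM weights. The document joins the best-scoring cluster if that score beats
# the decision threshold (or, when a merge model is supplied, if that second
# model accepts the candidate); otherwise it starts a new cluster.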
9 | import math 10 | import model 11 | import pdb 12 | import datetime 13 | 14 | 15 | def sparse_dotprod(fv0, fv1): 16 | dotprod = 0 17 | 18 | for f_id_0, f_value_0 in fv0.items(): 19 | if f_id_0 in fv1: 20 | f_value_1 = fv1[f_id_0] 21 | dotprod += f_value_0 * f_value_1 22 | 23 | return dotprod 24 | 25 | 26 | def cosine_bof(d0, d1): 27 | cosine_bof_v = {} 28 | for fn, fv0 in d0.items(): 29 | if fn in d1: 30 | fv1 = d1[fn] 31 | cosine_bof_v[fn] = sparse_dotprod( 32 | fv0, fv1) / math.sqrt(sparse_dotprod(fv0, fv0) * sparse_dotprod(fv1, fv1)) 33 | return cosine_bof_v 34 | 35 | def normalized_gaussian(mean, stddev, x): 36 | return (math.exp(-((x - mean) * (x - mean)) / (2 * stddev * stddev))) 37 | 38 | 39 | def timestamp_feature(tsi, tst, gstddev): 40 | return normalized_gaussian(0, gstddev, (tsi-tst)/(60*60*24.0)) 41 | 42 | def sim_bof_dc(d0, c1): 43 | numdays_stddev = 3.0 44 | bof = cosine_bof(d0.reprs, c1.reprs) 45 | bof["NEWEST_TS"] = timestamp_feature( 46 | d0.timestamp.timestamp(), c1.newest_timestamp.timestamp(), numdays_stddev) 47 | bof["OLDEST_TS"] = timestamp_feature( 48 | d0.timestamp.timestamp(), c1.oldest_timestamp.timestamp(), numdays_stddev) 49 | bof["RELEVANCE_TS"] = timestamp_feature( 50 | d0.timestamp.timestamp(), c1.get_relevance_stamp(), numdays_stddev) 51 | bof["ZZINVCLUSTER_SIZE"] = 1.0 / float(100 if c1.num_docs > 100 else c1.num_docs) 52 | 53 | return bof 54 | 55 | def model_score(bof, model: model.Model): 56 | return sparse_dotprod(bof, model.weights) - model.bias 57 | 58 | 59 | class Document: 60 | def __init__(self, archive_document, group_id): 61 | self.id = archive_document["id"] 62 | self.reprs = archive_document["features"] 63 | self.timestamp = datetime.datetime.strptime( 64 | archive_document["date"], "%Y-%m-%d %H:%M:%S") 65 | self.group_id = group_id 66 | 67 | 68 | class Cluster: 69 | def __init__(self, document): 70 | self.ids = set() 71 | self.num_docs = 0 72 | self.reprs = {} 73 | self.sum_timestamp = 0 74 | self.sumsq_timestamp = 0 75 | self.newest_timestamp = datetime.datetime.strptime( 76 | "1000-01-01 00:00:00", "%Y-%m-%d %H:%M:%S") 77 | self.oldest_timestamp = datetime.datetime.strptime( 78 | "3000-01-01 00:00:00", "%Y-%m-%d %H:%M:%S") 79 | self.add_document(document) 80 | 81 | def get_relevance_stamp(self): 82 | z_score = 0 83 | mean = self.sum_timestamp / self.num_docs 84 | try: 85 | std_dev = math.sqrt((self.sumsq_timestamp / self.num_docs) - (mean*mean)) 86 | except: 87 | std_dev = 0.0 88 | return mean + ((z_score * std_dev) * 3600.0) # its in secods since epoch 89 | 90 | def add_document(self, document): 91 | self.ids.add(document.id) 92 | self.newest_timestamp = max(self.newest_timestamp, document.timestamp) 93 | self.oldest_timestamp = min(self.oldest_timestamp, document.timestamp) 94 | ts_hours = (document.timestamp.timestamp() / 3600.0) 95 | self.sum_timestamp += ts_hours 96 | self.sumsq_timestamp += ts_hours * ts_hours 97 | self.__add_bof(document.reprs) 98 | 99 | def __add_bof(self, reprs0): 100 | for fn, fv0 in reprs0.items(): 101 | if fn in self.reprs: 102 | for f_id_0, f_value_0 in fv0.items(): 103 | if f_id_0 in self.reprs[fn]: 104 | self.reprs[fn][f_id_0] += f_value_0 105 | else: 106 | self.reprs[fn][f_id_0] = f_value_0 107 | else: 108 | self.reprs[fn] = fv0 109 | self.num_docs += 1 110 | 111 | 112 | class Aggregator: 113 | def __init__(self, model: model.Model, thr, merge_model: model.Model = None): 114 | self.clusters = [] 115 | self.model = model 116 | self.thr = thr 117 | self.merge_model = merge_model 118 | 119 | def 
PutDocument(self, document): 120 | best_i = -1 121 | best_s = 0.0 122 | i = -1 123 | bofs = [] 124 | for cluster in self.clusters: 125 | i += 1 126 | 127 | bof = sim_bof_dc(document, cluster) 128 | bofs.append(bof) 129 | score = model_score(bof, self.model) 130 | if score > best_s and (score > self.thr or self.merge_model): 131 | best_s = score 132 | best_i = i 133 | 134 | if best_i != -1 and self.merge_model: 135 | merge_score = model_score(bofs[best_i], self.merge_model) 136 | #print(merge_score) 137 | if merge_score <= 0: 138 | best_i = -1 139 | 140 | if best_i == -1: 141 | self.clusters.append(Cluster(document)) 142 | best_i = len(self.clusters) - 1 143 | else: 144 | self.clusters[best_i].add_document(document) 145 | 146 | return best_i 147 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #The 3-Clause BSD License 2 | # For Priberam Clustering Software 3 | # Copyright 2018 by PRIBERAM INFORMÁTICA, S.A. ("PRIBERAM") (www.priberam.com) 4 | # Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | # 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | # 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | # 3. Neither the name of the copyright holder (PRIBERAM) nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 8 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
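# Overview: compares a predictions file (document_id<TAB>cluster_id lines, as
# written by testbench.py) against the gold dataset JSON. By default it prints
# monolingual clustering scores via eval_lib.ScoreSet; -c evaluates the
# crosslingual clustering for the given languages, and -l additionally scores
# cluster linking against a separate linking dataset loaded by dataset_loader.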
9 | from eval_lib import ScoreSet 10 | import argparse 11 | import json 12 | from dataset_loader import LoadLinkingDataset 13 | from collections import OrderedDict 14 | 15 | language_converter = { "de" : "deu" , "en" : "eng", "es" : "spa"} 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("predfile", help="the file with the predictions") 19 | parser.add_argument("goldfile", help="the file with the gold labels") 20 | parser.add_argument("-f", "--fix_sizes", help="fix gold size to match pred size (used for midway validation)", action="store_true") 21 | parser.add_argument("-c", "--crosslingual", help="Evaluate crosslingual score, args = languages considered", nargs='+') 22 | parser.add_argument("-l", "--eval_linking_dataset", help="Evaluate on linking dataset (arg = path)", nargs='?') 23 | parser.add_argument("-d", "--debug", help="Enable debug prints", action="store_true") 24 | parser.add_argument("-s", "--symetric_formats", help="Use same format for gold and pred", action="store_true") 25 | args = parser.parse_args() 26 | 27 | fix_sizes = False 28 | if args.fix_sizes: 29 | fix_sizes = True 30 | 31 | crosslingual = False 32 | if args.crosslingual: 33 | crosslingual = True 34 | 35 | debug_prints = False 36 | if args.debug: 37 | debug_prints = True 38 | 39 | file_pred = args.predfile 40 | 41 | 42 | clusters_pred = {} 43 | clusters_to_docs_pred = {} 44 | with open(file_pred, errors="ignore") as fp: 45 | for line in fp: 46 | sl = line.split('\t') 47 | clusters_pred[sl[0]] = sl[1] 48 | if(not sl[1] in clusters_to_docs_pred): 49 | clusters_to_docs_pred[sl[1]] = [] 50 | clusters_to_docs_pred[sl[1]].append(sl[0]) 51 | 52 | if(args.eval_linking_dataset): 53 | allowed_documents = set() 54 | for k,v in clusters_pred.items(): 55 | allowed_documents.add(k) 56 | dataset_path = args.eval_linking_dataset 57 | linking_dataset = LoadLinkingDataset(dataset_path, set(['eng', 'deu', 'spa']), allowed_documents, limited=False) 58 | 59 | clusters_gold = {} 60 | clusters_to_docs_gold = {} 61 | if(args.symetric_formats): 62 | with open(args.goldfile, errors="ignore") as fp: 63 | for line in fp: 64 | sl = line.split('\t') 65 | clusters_gold[sl[0]] = sl[1] 66 | if(not sl[1] in clusters_to_docs_gold): 67 | clusters_to_docs_gold[sl[1]] = [] 68 | clusters_to_docs_gold[sl[1]].append(sl[0]) 69 | else: 70 | with open(args.goldfile, "r", errors="ignore") as fg: 71 | documents = json.load(fg) 72 | for document in documents: 73 | if(fix_sizes): 74 | if not document["id"] in clusters_pred: 75 | continue 76 | 77 | clusters_gold[document["id"]] = document["cluster"] 78 | if(not document["cluster"] in clusters_to_docs_gold): 79 | clusters_to_docs_gold[document["cluster"]] = [] 80 | 81 | clusters_to_docs_gold[document["cluster"]].append(document["id"]) 82 | 83 | if debug_prints: 84 | print("loaded files") 85 | 86 | if not crosslingual: 87 | 88 | id_to_index = {} 89 | index = -1 90 | for k, v in clusters_pred.items(): 91 | index += 1 92 | id_to_index[k] = index 93 | 94 | if debug_prints: 95 | print("ii built") 96 | 97 | if(len(clusters_gold) != len(clusters_pred)): 98 | print ("error:", len(clusters_gold), len(clusters_pred)) 99 | for k0,v in clusters_gold.items(): 100 | if k0 not in clusters_pred: 101 | print(k0) 102 | 103 | for k0,v in clusters_pred.items(): 104 | if k0 not in clusters_gold: 105 | print(k0) 106 | 107 | assert(False) 108 | 109 | true_labels = [0 for i in range(len(clusters_gold))] 110 | for k, v in clusters_gold.items(): 111 | true_labels[id_to_index[k]] = v 112 | 113 | pred_labels = [0 for i in 
range(len(clusters_pred))] 114 | for k, v in clusters_pred.items(): 115 | pred_labels[id_to_index[k]] = v 116 | 117 | if debug_prints: 118 | print("running score set...") 119 | print (true_labels) 120 | print (pred_labels) 121 | 122 | ScoreSet(true_labels, pred_labels, "mono-score") 123 | 124 | if debug_prints: 125 | print("score set done.") 126 | 127 | else: 128 | cross_languages = set() 129 | for language in args.crosslingual: 130 | cross_languages.add(language_converter[language]) 131 | 132 | if debug_prints: 133 | print(cross_languages) 134 | 135 | num_mono = 0 136 | gold_cross_clusters = [] 137 | gold_cross_clusters_ids = [] 138 | pred_cross_clusters = [] 139 | pred_cross_clusters_ids = [] 140 | 141 | # This is a bit legacy, needs to be re-written 142 | language_to_index = OrderedDict() 143 | mono_doc_to_cluster_gold = [{} for i in range(len(cross_languages))] 144 | with open(args.goldfile, "r", errors="ignore") as fg: 145 | documents = json.load(fg) 146 | for document in documents: 147 | if document["lang"] not in cross_languages: 148 | continue 149 | if(fix_sizes): 150 | if not document["id"] in clusters_pred: 151 | continue 152 | 153 | if document["lang"] not in language_to_index: 154 | nindex = len(language_to_index) 155 | language_to_index[document["lang"]] = nindex 156 | nindex = language_to_index[document["lang"]] 157 | mono_doc_to_cluster_gold[nindex][document["id"]] = document["cluster"] 158 | 159 | for cluster_id, doc_list in clusters_to_docs_gold.items(): 160 | gold_cross_cluster = set() 161 | for doc_id in doc_list: 162 | 163 | i = -1 164 | for mono_d_t_c in mono_doc_to_cluster_gold: 165 | i += 1 166 | if(doc_id in mono_d_t_c): 167 | gold_cross_cluster.add(mono_d_t_c[doc_id] + "_" + str(i)) 168 | gold_cross_clusters.append(gold_cross_cluster) 169 | gold_cross_clusters_ids.append(cluster_id) 170 | 171 | if(debug_prints): 172 | foo = open("gold_cross.out", "w") 173 | y = -1 174 | for gc in gold_cross_clusters: 175 | y += 1 176 | for mono_id in gc: 177 | if(debug_prints): 178 | foo.write(str(gold_cross_clusters_ids[y]) + "\t" + str(mono_id) + "\n") 179 | if(debug_prints): 180 | foo.close() 181 | 182 | 183 | unique_docs = set() 184 | for cluster_id, doc_list in clusters_to_docs_pred.items(): 185 | pred_cross_cluster = set() 186 | for doc_id in doc_list: 187 | 188 | i = -1 189 | for mono_d_t_c in mono_doc_to_cluster_gold: 190 | i += 1 191 | if(doc_id in mono_d_t_c): 192 | pred_cross_cluster.add(mono_d_t_c[doc_id] + "_" + str(i)) 193 | assert (doc_id not in unique_docs) 194 | unique_docs.add(doc_id) 195 | 196 | pred_cross_clusters.append(pred_cross_cluster) 197 | pred_cross_clusters_ids.append(cluster_id) 198 | 199 | if(debug_prints): 200 | foo = open("pred_cross.out", "w") 201 | y = -1 202 | for pc in pred_cross_clusters: 203 | y += 1 204 | for mono_id in pc: 205 | if(debug_prints): 206 | foo.write(str(pred_cross_clusters_ids[y]) + "\t" + str(mono_id) + "\n") 207 | #print (str(pred_cross_clusters_ids[y]) + "\t" + str(mono_id)) 208 | if(debug_prints): 209 | foo.close() 210 | 211 | id_to_index = {} 212 | index = -1 213 | for pred_cluster in pred_cross_clusters: 214 | for mono_id in pred_cluster: 215 | index += 1 216 | id_to_index[mono_id] = index 217 | if(debug_prints): 218 | print("i", index) 219 | 220 | if debug_prints: 221 | u = set() 222 | for c in pred_cross_clusters: 223 | for p in c: 224 | print (p) 225 | assert (p not in u) 226 | u.add(p) 227 | 228 | # just to check len 229 | gold_id_to_index = {} 230 | index = -1 231 | for gold_cluster in gold_cross_clusters: 232 | for 
mono_id in gold_cluster: 233 | index += 1 234 | gold_id_to_index[mono_id] = index 235 | 236 | if(debug_prints): 237 | print (len(id_to_index)) 238 | #print (len(id_to_index)) 239 | assert(len(id_to_index) == len(gold_id_to_index)) 240 | 241 | if(debug_prints): 242 | print (len(id_to_index)) 243 | 244 | label = -1 245 | pred_labels = [-1 for i in range(len(id_to_index))] 246 | for pred_cluster in pred_cross_clusters: 247 | label += 1 248 | for mono_id in pred_cluster: 249 | #print(id_to_index[mono_id]) 250 | 251 | if (debug_prints): 252 | with open("out.tmp","w") as outtmp: 253 | print (mono_id, file=outtmp) 254 | print (id_to_index, file=outtmp) 255 | print (pred_labels, file=outtmp) 256 | print (id_to_index[mono_id], file=outtmp) 257 | print (len(pred_labels), file=outtmp) 258 | 259 | pred_labels[id_to_index[mono_id]] = label 260 | 261 | #pred_labels = [] 262 | #for l in pred_labels_tmp: 263 | # if l == -1: 264 | # break 265 | # pred_labels.append(l) 266 | 267 | label = -1 268 | true_labels = [-1 for i in range(len(id_to_index))] 269 | for gold_cluster in gold_cross_clusters: 270 | label += 1 271 | for mono_id in gold_cluster: 272 | true_labels[id_to_index[mono_id]] = label 273 | 274 | #true_labels = [] 275 | #for l in true_labels_tmp: 276 | # if l == -1: 277 | # break 278 | # true_labels.append(l) 279 | 280 | #print (true_labels_tmp) 281 | #print (pred_labels_tmp) 282 | if(debug_prints): 283 | print (true_labels) 284 | print (pred_labels) 285 | 286 | additional_scoring = None 287 | if(args.eval_linking_dataset): 288 | #DEBUG 289 | num_annots=set() 290 | for bag_i, bag_l in linking_dataset["linking"].items(): 291 | for bag_j in bag_l: 292 | num_annots.add(bag_i+bag_j) 293 | num_annots.add(bag_j+bag_i) 294 | #assert (linking_dataset["bags"][bag_i]["lang"] != linking_dataset["bags"][bag_j]["lang"]) 295 | #print ("num_annots:",len(num_annots) / 2) 296 | 297 | lcount={} 298 | for document, lang in linking_dataset["ii_langs"].items(): 299 | if lang not in lcount: 300 | lcount[lang] = 1 301 | else: 302 | lcount[lang] += 1 303 | #print(lcount) 304 | 305 | visited_documents = set() 306 | predicted_links_map = {} 307 | visited_pairs = set() 308 | for cross_cluster_id, document_list in clusters_to_docs_pred.items(): 309 | bag_ids = set() 310 | for document in document_list: 311 | if document not in linking_dataset["ii"]: 312 | assert(False) 313 | #continue # document got rejected because it belongs to cloud which has not been linked across langs 314 | bag_id = linking_dataset["ii"][document] 315 | bag_ids.add(bag_id) 316 | visited_documents.add(document) 317 | if(len(bag_ids) <= 0): 318 | assert(False) 319 | 320 | for bag_id_0 in bag_ids: 321 | lang_0 = linking_dataset["bags"][bag_id_0]["lang"] 322 | for bag_id_1 in bag_ids: 323 | if bag_id_0 == bag_id_1 or bag_id_0+bag_id_1 in visited_pairs: 324 | continue 325 | lang_1 = linking_dataset["bags"][bag_id_1]["lang"] 326 | #if lang_0 == lang_1: 327 | # continue 328 | 329 | # predicted: pairs of bags of different languages 330 | 331 | if bag_id_0 not in predicted_links_map: 332 | predicted_links_map[bag_id_0] = set() 333 | predicted_links_map[bag_id_0].add(bag_id_1) 334 | 335 | if bag_id_1 not in predicted_links_map: 336 | predicted_links_map[bag_id_1] = set() 337 | predicted_links_map[bag_id_1].add(bag_id_0) 338 | 339 | visited_pairs.add(bag_id_0+bag_id_1) 340 | visited_pairs.add(bag_id_1+bag_id_0) 341 | 342 | #print ("Visited documents: ",len(visited_documents)) 343 | #print ("dataset documents: ",len(linking_dataset["ii"])) 344 | #print("DEBUG", 
len(visited_pairs) / 2) 345 | #print("DEBUG", len(predicted_links_map)) 346 | visited_data = set() 347 | for bag_t, bag_list in predicted_links_map.items(): 348 | for bag_o in bag_list: 349 | visited_data.add(bag_t+bag_o) 350 | visited_data.add(bag_o+bag_t) 351 | #print ("LEN: ", len(visited_data) / 2) 352 | 353 | #import pdb; pdb.set_trace() 354 | 355 | tp = 0 356 | tn = 0 357 | fp = 0 358 | fn = 0 359 | visited_data = set() 360 | for bag_id_0, bag_links in linking_dataset["linking"].items(): 361 | lang_0 = linking_dataset["bags"][bag_id_0]["lang"] 362 | for bag_id_1, gold_linking in bag_links.items(): 363 | if bag_id_0+bag_id_1 in visited_data: 364 | continue 365 | 366 | lang_1 = linking_dataset["bags"][bag_id_1]["lang"] 367 | if lang_0 == lang_1: 368 | continue # Cannot compare of same language, its a bias because for me those come as gold. 369 | 370 | #if not (bag_id_0 in predicted_links_map): 371 | # continue 372 | #import pdb; pdb.set_trace() 373 | 374 | predicted_link = bag_id_0 in predicted_links_map and bag_id_1 in predicted_links_map[bag_id_0] 375 | 376 | if gold_linking == "positive": 377 | if predicted_link: 378 | tp += 1 379 | else: 380 | fn += 1 381 | elif gold_linking == "negative": 382 | if predicted_link: 383 | fp += 1 384 | else: 385 | tn +=1 386 | else: 387 | assert(False) 388 | 389 | visited_data.add(bag_id_0+bag_id_1) 390 | visited_data.add(bag_id_1+bag_id_0) 391 | 392 | #print("gold-sweep:", len(visited_data) / 2) 393 | 394 | def Scoring (tp,fp,tn,fn): 395 | acc = 1. * (tp + tn) / (tp + tn + fp + fn) if tp + tn + fp + fn > 0 else 0 396 | p = 1. * tp / (tp + fp) if tp + fp > 0 else 0 397 | r = 1. * tp / (tp + fn) if tp + fn > 0 else 0 398 | f1 = 2. * p * r / (p + r) if p + r > 0 else 0 399 | return { 400 | "p" : p, 401 | "r" : r, 402 | "f1" : f1, 403 | "a" : acc, 404 | "tp" : tp, 405 | "fp" : fp, 406 | "tn" : tn, 407 | "fn" : fn, 408 | "tot" : tp+fp+tn+fn} 409 | 410 | additional_scoring = Scoring(tp,fp,tn,fn) 411 | 412 | 413 | eval_object = json.loads(ScoreSet(true_labels, pred_labels, "cross-score", get_data=True)) 414 | if additional_scoring != None: 415 | eval_object["linking_score"] = additional_scoring 416 | print(json.dumps(eval_object)) 417 | --------------------------------------------------------------------------------