├── evaluation ├── rbo │ ├── __init__.py │ └── rbo.py └── measures.py ├── images ├── loss_function.png ├── multilingual-only.png ├── multimodal-only.png └── multilingual_and_multimodal.png ├── LICENSE ├── utils ├── run_preprocess.py ├── encode_images.py ├── preprocess_wiki.py ├── preprocessing.py └── data_preparation.py ├── README.md ├── training_scripts ├── train_multilingual_contrast.py └── train_M3L_contrast.py ├── networks ├── inference_network.py └── decoding_network.py ├── datasets └── dataset.py ├── data ├── train-example.csv └── test-example.csv └── models ├── multilingual_contrast.py └── M3L_contrast.py /evaluation/rbo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/loss_function.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezosa/M3L-topic-model/HEAD/images/loss_function.png -------------------------------------------------------------------------------- /images/multilingual-only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezosa/M3L-topic-model/HEAD/images/multilingual-only.png -------------------------------------------------------------------------------- /images/multimodal-only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezosa/M3L-topic-model/HEAD/images/multimodal-only.png -------------------------------------------------------------------------------- /images/multilingual_and_multimodal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezosa/M3L-topic-model/HEAD/images/multilingual_and_multimodal.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Elaine Zosa 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /utils/run_preprocess.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from utils.preprocess_wiki import parse_wikipedia_xml_dump, combine_multilingual_wikipedia_articles 3 | 4 | # Preprocessing steps 5 | # Step 1: Parse and clean Wikipedia articles from xml dumps for each language (e.g. EN and DE) 6 | # Wikipedia XML dumps are taken from https://mirror.accum.se/mirror/wikimedia.org/dumps/ (or other mirror sites) 7 | wiki_l1 = "enwiki-20230201-pages-articles.xml" 8 | wiki_l2 = "dewiki-20230201-pages-articles.xml" 9 | parsed_articles_l1 = parse_wikipedia_xml_dump(xml_path=wiki_l1, lang='en') 10 | parsed_articles_l2 = parse_wikipedia_xml_dump(xml_path=wiki_l2, lang='de') 11 | lang1_articles = pd.read_csv(parsed_articles_l1) 12 | lang2_articles = pd.read_csv(parsed_articles_l2) 13 | 14 | 15 | # Step 2: Here we combine the cleaned articles from step 1 into a dataset of aligned Wikipedia titles, image urls and full articles 16 | # 2.1 (optional) 17 | # aligned_titles used in our experiments is provided in the Git repo (see step 2.3 below) 18 | # If you want to use other language pairs, download an xml from https://linguatools.org/tools/corpora/wikipedia-comparable-corpora/ 19 | xml_path = "wikicomp-2014_deen.xml" 20 | # Parse the xml file for the aligned titles and create a csv 21 | aligned_titles_file = align_wikipedia_titles(xml_path, lang_pair='de-en') 22 | 23 | # 2.2 (optional) 24 | # image_url is taken from WIT (https://www.kaggle.com/competitions/wikipedia-image-caption/data) 25 | # we match the image_url to the article using the article title (e.g. en_title) 26 | # the result here should be a csv with the columns: en_title, de_title, image_url 27 | 28 | 29 | # 2.3 Merge cleaned articles with the aligned titles file 30 | aligned_titles_file = "https://github.com/ezosa/M3L-topic-model/blob/master/data/train-titles.csv" 31 | aligned_titles = pd.read_csv(aligned_titles_file) 32 | merged_wiki = combine_multilingual_wikipedia_articles(aligned_titles=aligned_titles, 33 | lang1_articles=lang1_articles, 34 | lang2_articles=lang2_articles, 35 | lang1='en', 36 | lang2='de') 37 | # the result should be a csv with columns: en_title, en_text, de_title, de_text, image_url (see example https://github.com/ezosa/M3L-topic-model/blob/master/data/train-example.csv) 38 | -------------------------------------------------------------------------------- /utils/encode_images.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import tarfile 3 | from tqdm import tqdm 4 | import gzip 5 | import os 6 | import shutil 7 | from PIL import Image 8 | from sentence_transformers import SentenceTransformer 9 | from io import BytesIO 10 | import base64 11 | 12 | # tar file of Wikipedia images from https://www.kaggle.com/c/wikipedia-image-caption/ 13 | 14 | 15 | def unpack_pixels_and_clip_encode(tar_file="image_data_train.tar", outfile="image_data_train_clip.csv", clip_encoder="clip-ViT-B-32",): 16 | """Unpack pixels of Wikipedia images and encode with CLIP.""" 17 | pixels_path = "image_data_train/image_pixels" 18 | tar = tarfile.open(tar_file) 19 | for f in tar.getmembers(): 20 | if os.path.exists(outfile) is False: 21 | if f.name.endswith("gz"): 22 | if f.name.startswith(pixels_path): 23 | tar.extract(f) 24 | print("Extracted tar member:", f.name) 25 | pixels_df = pd.read_csv(f.name, compression='gzip', header=None, 
sep='\t') 26 | image_urls = list(pixels_df[0]) 27 | image_pixels = [] 28 | valid_urls = [] 29 | for i, img_url in enumerate(image_urls): 30 | if '.svg' not in img_url: 31 | img = Image.open(BytesIO(base64.b64decode(pixels_df.iloc[i][1]))) 32 | img = img.convert("RGB") 33 | image_pixels.append(img) 34 | valid_urls.append(img_url) 35 | os.remove(f.name) 36 | img_model = SentenceTransformer(clip_encoder) 37 | encoded_images = img_model.encode(image_pixels) 38 | encdf = pd.DataFrame(encoded_images) 39 | encdf['image_url'] = valid_urls 40 | print("encdf shape:", encdf.shape) 41 | encdf.to_csv(outfile, index=False) 42 | print("Saved encoded file as", outfile) 43 | 44 | 45 | def unpack_resnet_embeddings(tar_file="image_data_train.tar", outfile="image_data_train_resnet.csv"): 46 | """Unpack ResNet embeddings of Wikipedia images.""" 47 | emb_path = "image_data_train/resnet_embeddings" 48 | tar = tarfile.open(tar_file) 49 | with open(outfile, "w") as out: 50 | for f in tar.getmembers(): 51 | if f.name.endswith("gz"): 52 | if f.name.startswith(emb_path): 53 | tar.extract(f) 54 | print(f.name) 55 | with gzip.open(f.name, "rt") as imp: 56 | for line in tqdm(imp): 57 | line = line.strip() 58 | image_url, emb = line.split("\t") 59 | print(image_url+","+emb, file=out) 60 | os.remove(f.name) 61 | print("Done unpacking ResNet encodings") 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Code for our COLING 2022 paper [Multilingual and Multimodal Topic Modelling with Pretrained Embeddings](https://aclanthology.org/2022.coling-1.355) 2 | 3 | ### Abstract 4 | 5 | We present M3L-Contrast, a novel multimodal multilingual (M3L) neural topic model for comparable data that maps multilingual texts and images into a shared topic space using a contrastive objective. As a multilingual topic model, it produces aligned *language-specific topics* and as a multimodal model, it infers textual representations of semantic concepts in images. We also show that our model performs almost as well on unaligned embeddings as it does on aligned embeddings. 6 | 7 | Our proposed topic model is: 8 | - multilingual 9 | - multimodal (image-text) 10 | - multimodal *and* multilingual (M3L) 11 | 12 | Our model is based on the [Contextualized Topic Model](https://github.com/MilaNLProc/contextualized-topic-models) (Bianchi et al., 2021) 13 | 14 | We use the PyTorch Metric Learning library for the [InfoNCE/NTXent loss](https://github.com/KevinMusgrave/pytorch-metric-learning/) (see the usage sketch after the Dataset section below) 15 | 16 | ### Model architecture 17 | 18 | 19 | 20 | ### Dataset 21 | - Aligned articles from the [Wikipedia Comparable Corpora](https://linguatools.org/tools/corpora/wikipedia-comparable-corpora/) 22 | - Images from the [WIT](https://github.com/google-research-datasets/wit) dataset 23 | - We will release the article titles and image urls in the train and test sets (soon!)
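
### Contrastive objective (usage sketch)

The contrastive term pairs the different views of the same document (e.g. its EN text, DE text and image embeddings). The snippet below is only a minimal, illustrative sketch of how `NTXentLoss` from PyTorch Metric Learning treats aligned views as positives; the tensor shapes, temperature and variable names are placeholders, not the exact wiring used in `models/M3L_contrast.py`.

```python
import torch
from pytorch_metric_learning.losses import NTXentLoss

batch_size, dim = 8, 512                    # illustrative sizes
text_view = torch.randn(batch_size, dim)    # e.g. projected EN-text representations
image_view = torch.randn(batch_size, dim)   # e.g. projected image (or DE-text) representations

# Views of the same document share a label, so they form the positive pairs;
# all other pairs in the batch act as negatives.
embeddings = torch.cat([text_view, image_view], dim=0)
labels = torch.arange(batch_size).repeat(2)

loss_fn = NTXentLoss(temperature=0.07)      # InfoNCE/NT-Xent with an illustrative temperature
contrastive_loss = loss_fn(embeddings, labels)
```

In the training scripts this term is weighted by `--cl_weight` and combined with the KL term weighted by `--kl_weight`.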
24 | 25 | ### Talks and slides 26 | - [Slides](https://blogs.helsinki.fi/language-technology/files/2022/11/LT-seminar-Elaine-Zosa-2022-11-10.pdf) and [video](https://unitube.it.helsinki.fi/unitube/embed.html?id=dae2b02d-47e7-46b0-adc3-86da8034ed58) from my talk at the Helsinki Language Technology seminar 27 | 28 | ### Trained models 29 | We shared some of the models we trained: 30 | 31 | - [M3L topic model](https://www.dropbox.com/sh/0lc48k9o2ctzvrl/AADhM2TLq6XxVgNvU0WZ59nZa?dl=0) trained with CLIP embeddings for texts and images 32 | - [M3L topic model](https://www.dropbox.com/sh/ilu7kypztd7pbli/AABCpy6hECPPOSPXRiFN2njFa?dl=0) trained with multilingual SBERT for text and CLIP for images 33 | - [M3L topic model](https://www.dropbox.com/scl/fo/oh6hrif37gynstt8a4wi7/h?dl=0&rlkey=034ozpeiaypfbv6fx9nht85co) trained with monolingual SBERT models for the English and German texts and CLIP for images 34 | 35 | 36 | ### Citation 37 | ``` 38 | @inproceedings{zosa-pivovarova-2022-multilingual, 39 | title = "Multilingual and Multimodal Topic Modelling with Pretrained Embeddings", 40 | author = "Zosa, Elaine and Pivovarova, Lidia", 41 | booktitle = "Proceedings of the 29th International Conference on Computational Linguistics", 42 | month = oct, 43 | year = "2022", 44 | address = "Gyeongju, Republic of Korea", 45 | publisher = "International Committee on Computational Linguistics", 46 | url = "https://aclanthology.org/2022.coling-1.355", 47 | pages = "4037--4048", 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /training_scripts/train_multilingual_contrast.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import torch 4 | import string 5 | import numpy as np 6 | from gensim.corpora import Dictionary 7 | from gensim.models.coherencemodel import CoherenceModel 8 | 9 | from models.multilingual_contrast import MultilingualContrastiveTM 10 | from utils.data_preparation import MultilingualTopicModelDataPreparation 11 | from utils.preprocessing import WhiteSpacePreprocessingMultilingual 12 | 13 | import argparse 14 | argparser = argparse.ArgumentParser() 15 | argparser.add_argument('--model_name', default='MultilingualContrast', type=str) 16 | argparser.add_argument('--data_path', default='data/', type=str) 17 | argparser.add_argument('--save_path', default='trained_models/', type=str) 18 | argparser.add_argument('--train_data', default='wikiarticles.csv', type=str) 19 | argparser.add_argument('--num_topics', default=100, type=int) 20 | argparser.add_argument('--num_epochs', default=100, type=int) 21 | argparser.add_argument('--langs', default='en,de', type=str) 22 | argparser.add_argument('--sbert_model', default='paraphrase-multilingual-mpnet-base-v2', type=str) 23 | argparser.add_argument('--text_enc_dim', default=512, type=int) 24 | argparser.add_argument('--batch_size', default=32, type=int) 25 | argparser.add_argument('--max_seq_length', default=200, type=int) 26 | argparser.add_argument('--kl_weight', default=0.01, type=int, help='weight for the KLD loss') 27 | argparser.add_argument('--cl_weight', default=50, type=int, help='weight for the contrastive loss') 28 | args = argparser.parse_args() 29 | 30 | print("\n" + "-"*5, "Train Contrastive PLTM", "-"*5) 31 | print("model_name:", args.model_name) 32 | print("data_path:", args.data_path) 33 | print("save_path:", args.save_path) 34 | print("train_data:", args.train_data) 35 | print("num_topics:", args.num_topics) 36 | 
print("num_epochs:", args.num_epochs) 37 | print("langs:", args.langs) 38 | print("sbert_model:", args.sbert_model) 39 | print("text_enc_dim:", args.text_enc_dim) 40 | print("batch_size:", args.batch_size) 41 | print("max_seq_length:", args.max_seq_length) 42 | print("kl_weight:", args.kl_weight) 43 | print("cl_weight:", args.cl_weight) 44 | print("-"*50 + "\n") 45 | 46 | # stopwords lang dict 47 | lang_dict = {'en': 'english', 48 | 'de': 'german', 49 | 'fi': 'finnish'} 50 | 51 | # ----- load dataset ----- 52 | df = pd.read_csv(os.path.join(args.data_path, args.train_data)) 53 | # print("df:", df.shape) 54 | languages = args.langs.lower().split(',') 55 | languages = [l.strip() for l in languages] 56 | print('languages:', languages) 57 | 58 | documents = [list(df[lang+'_text']) for lang in languages] 59 | 60 | 61 | # ----- preprocess documents ----- 62 | lang_stopwords = [lang_dict[l] for l in languages] 63 | preproc_pipeline = WhiteSpacePreprocessingMultilingual(documents=documents, 64 | stopwords_languages=lang_stopwords, 65 | max_len=args.max_seq_length) 66 | preprocessed_docs, raw_docs, vocab = preproc_pipeline.preprocess() 67 | for l in range(len(languages)): 68 | print("-"*5, "lang", l, ":", languages[l].upper(), "-"*5) 69 | print('preprocessed_docs:', len(preprocessed_docs[l])) 70 | print('raw_docs:', len(raw_docs[l])) 71 | print('vocab:', len(vocab[l])) 72 | 73 | # ----- encode documents ----- 74 | qt = MultilingualTopicModelDataPreparation(args.sbert_model, vocabularies=vocab) 75 | 76 | training_dataset = qt.fit(text_for_contextual=raw_docs, text_for_bow=preprocessed_docs) 77 | 78 | 79 | # ----- initialize model ----- 80 | loss_weights = {"KL": args.kl_weight, 81 | "CL": args.cl_weight} 82 | 83 | contrast_model = MultilingualContrastiveTM(bow_size=qt.vocab_sizes[0], 84 | contextual_size=args.text_enc_dim, 85 | n_components=args.num_topics, 86 | num_epochs=args.num_epochs, 87 | languages=languages, 88 | batch_size=args.batch_size, 89 | loss_weights=loss_weights) 90 | 91 | # ----- topic inference ----- 92 | contrast_model.fit(training_dataset) 93 | 94 | # ----- save model ----- 95 | save_filepath = os.path.join(args.save_path, args.model_name + '_K' + str(args.num_topics) 96 | + '_epochs' + str(args.num_epochs) 97 | + '_' + args.sbert_model) 98 | contrast_model.save(save_filepath) 99 | print("Done! 
Saved model as", save_filepath) 100 | 101 | -------------------------------------------------------------------------------- /training_scripts/train_M3L_contrast.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from models.M3L_contrast import MultimodalContrastiveTM 4 | from utils.data_preparation import M3LTopicModelDataPreparation 5 | from utils.preprocessing import WhiteSpacePreprocessingM3L 6 | 7 | import argparse 8 | argparser = argparse.ArgumentParser() 9 | argparser.add_argument('--model_name', default='M3LContrast', type=str) 10 | argparser.add_argument('--data_path', default='data/', type=str) 11 | argparser.add_argument('--save_path', default='trained_models/', type=str) 12 | argparser.add_argument('--train_data', default='wikiarticles.csv', type=str) 13 | argparser.add_argument('--num_topics', default=100, type=int) 14 | argparser.add_argument('--num_epochs', default=100, type=int) 15 | argparser.add_argument('--langs', default='en,de', type=str, help='comma-separated lang codes for multilingual') 16 | argparser.add_argument('--sbert_model', default='clip-ViT-B-32-multilingual-v1', type=str) 17 | argparser.add_argument('--image_embeddings', default='wiki_clip.csv', type=str) 18 | argparser.add_argument('--text_enc_dim', default=512, type=int, help='encoding size sbert_model') 19 | argparser.add_argument('--image_enc_dim', default=512, type=int, help='encoding size of image embeddings') 20 | argparser.add_argument('--batch_size', default=32, type=int) 21 | argparser.add_argument('--max_seq_length', default=200, type=int) 22 | argparser.add_argument('--kl_weight', default=0.01, type=int, help='weight for the KLD loss') 23 | argparser.add_argument('--cl_weight', default=50, type=int, help='weight for the contrastive loss') 24 | args = argparser.parse_args() 25 | 26 | print("\n" + "-"*5, "Train M3L-Contrast TM", "-"*5) 27 | print("model_name:", args.model_name) 28 | print("data_path:", args.data_path) 29 | print("save_path:", args.save_path) 30 | print("train_data:", args.train_data) 31 | print("num_topics:", args.num_topics) 32 | print("num_epochs:", args.num_epochs) 33 | print("langs:", args.langs) 34 | print("sbert_model:", args.sbert_model) 35 | print("image_embeddings:", args.image_embeddings) 36 | print("text_enc_dim:", args.text_enc_dim) 37 | print("image_enc_dim:", args.image_enc_dim) 38 | print("batch_size:", args.batch_size) 39 | print("max_seq_length:", args.max_seq_length) 40 | print("kl_weight:", args.kl_weight) 41 | print("cl_weight:", args.cl_weight) 42 | print("-"*40 + "\n") 43 | 44 | 45 | # stopwords lang dict 46 | lang_dict = {'en': 'english', 47 | 'de': 'german'} 48 | 49 | # ----- load dataset ----- 50 | df = pd.read_csv(os.path.join(args.data_path, args.train_data)) 51 | # print("df:", df.shape) 52 | languages = args.langs.lower().split(',') 53 | languages = [l.strip() for l in languages] 54 | print('languages:', languages) 55 | 56 | documents = [list(df[lang+'_text']) for lang in languages] 57 | image_urls = list(df.image_url) 58 | 59 | # ----- preprocess documents ----- 60 | lang_stopwords = [lang_dict[l] for l in languages] 61 | preproc_pipeline = WhiteSpacePreprocessingM3L(documents=documents, 62 | image_urls=image_urls, 63 | stopwords_languages=lang_stopwords, 64 | max_len=args.max_seq_length) 65 | preprocessed_docs, raw_docs, vocab, image_urls = preproc_pipeline.preprocess() 66 | for l in range(len(languages)): 67 | print("-"*5, "lang", l, ":", languages[l].upper(), "-"*5) 68 | 
print('preprocessed_docs:', len(preprocessed_docs[l])) 69 | print('raw_docs:', len(raw_docs[l])) 70 | print('image urls:', len(image_urls)) 71 | print('vocab:', len(vocab[l])) 72 | 73 | # preprocessed_documents: list of list of preprocessed articles (one list for each language) 74 | # raw_docs: list of list of original articles (one list for each language) 75 | # vocab: list of list of words (one list for each language) 76 | 77 | # ----- encode documents ----- 78 | image_emb_file = os.path.join(args.data_path, args.image_embeddings) 79 | qt = M3LTopicModelDataPreparation(args.sbert_model, vocabularies=vocab, image_emb_file=image_emb_file) 80 | 81 | training_dataset = qt.fit(text_for_contextual=raw_docs, text_for_bow=preprocessed_docs, image_urls=image_urls) 82 | 83 | 84 | # ----- initialize model ----- 85 | loss_weights = {"KL": args.kl_weight, 86 | "CL": args.cl_weight} 87 | m3l_contrast = MultimodalContrastiveTM(bow_size=qt.vocab_sizes[0], 88 | contextual_sizes=(args.text_enc_dim, args.image_enc_dim), 89 | n_components=args.num_topics, 90 | num_epochs=args.num_epochs, 91 | languages=languages, 92 | batch_size=args.batch_size, 93 | loss_weights=loss_weights 94 | ) 95 | 96 | # ----- topic inference ----- 97 | m3l_contrast.fit(training_dataset) 98 | 99 | # ----- save model ----- 100 | save_filepath = os.path.join(args.save_path, args.model_name 101 | + "_K" + str(args.num_topics) 102 | + "_epochs" + str(args.num_epochs) 103 | + "_batch" + str(args.batch_size)) 104 | m3l_contrast.save(save_filepath) 105 | 106 | print("Done! Saved model as", save_filepath) -------------------------------------------------------------------------------- /networks/inference_network.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from torch import nn 3 | import torch 4 | 5 | 6 | class ContextualInferenceNetwork(nn.Module): 7 | 8 | """Inference Network.""" 9 | 10 | def __init__(self, input_size, bert_size, output_size, hidden_sizes, 11 | activation='softplus', dropout=0.2, label_size=0): 12 | """ 13 | # TODO: check dropout in main caller 14 | Initialize InferenceNetwork. 15 | 16 | Args 17 | input_size : int, dimension of input 18 | output_size : int, dimension of output 19 | hidden_sizes : tuple, length = n_layers 20 | activation : string, 'softplus' or 'relu', default 'softplus' 21 | dropout : float, default 0.2, default 0.2 22 | """ 23 | super(ContextualInferenceNetwork, self).__init__() 24 | #print('hidden_sizes:', hidden_sizes) 25 | assert isinstance(input_size, int), "input_size must by type int." 26 | assert isinstance(output_size, int), "output_size must be type int." 27 | assert isinstance(hidden_sizes, tuple), \ 28 | "hidden_sizes must be type tuple." 29 | assert activation in ['softplus', 'relu'], \ 30 | "activation must be 'softplus' or 'relu'." 31 | assert dropout >= 0, "dropout must be >= 0." 
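        # The layers defined below form the encoder of the underlying topic VAE: in forward(),
        # the contextual embedding x_bert (optionally concatenated with labels) is passed through
        # input_layer, the activation and the hidden stack, and the f_mu / f_sigma heads with
        # affine-free batch normalisation output mu and log_sigma of the Gaussian that
        # parameterises the document's topic proportions.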
32 | 33 | self.input_size = input_size 34 | self.output_size = output_size 35 | self.hidden_sizes = hidden_sizes 36 | self.dropout = dropout 37 | 38 | if activation == 'softplus': 39 | self.activation = nn.Softplus() 40 | elif activation == 'relu': 41 | self.activation = nn.ReLU() 42 | 43 | self.input_layer = nn.Linear(bert_size + label_size, hidden_sizes[0]) 44 | #self.adapt_bert = nn.Linear(bert_size, hidden_sizes[0]) 45 | 46 | self.hiddens = nn.Sequential(OrderedDict([ 47 | ('l_{}'.format(i), nn.Sequential(nn.Linear(h_in, h_out), self.activation)) 48 | for i, (h_in, h_out) in enumerate(zip(hidden_sizes[:-1], hidden_sizes[1:]))])) 49 | 50 | self.f_mu = nn.Linear(hidden_sizes[-1], output_size) 51 | self.f_mu_batchnorm = nn.BatchNorm1d(output_size, affine=False) 52 | 53 | self.f_sigma = nn.Linear(hidden_sizes[-1], output_size) 54 | self.f_sigma_batchnorm = nn.BatchNorm1d(output_size, affine=False) 55 | 56 | self.dropout_enc = nn.Dropout(p=self.dropout) 57 | 58 | def forward(self, x, x_bert, labels=None): 59 | """Forward pass.""" 60 | 61 | x = x_bert 62 | if labels: 63 | x = torch.cat((x_bert, labels), 1) 64 | #print("linear layer:", self.input_layer.weight) 65 | x = self.input_layer(x) 66 | #print("input_layer:", x) 67 | x = self.activation(x) 68 | x = self.hiddens(x) 69 | #print("hidden:", x) 70 | x = self.dropout_enc(x) 71 | mu = self.f_mu_batchnorm(self.f_mu(x)) 72 | #print("f_mu_batchnorm:", mu) 73 | log_sigma = self.f_sigma_batchnorm(self.f_sigma(x)) 74 | #print("f_sigma_batchnorm:", log_sigma) 75 | 76 | return mu, log_sigma 77 | 78 | 79 | class CombinedInferenceNetwork(nn.Module): 80 | 81 | """Inference Network.""" 82 | 83 | def __init__(self, input_size, bert_size, output_size, hidden_sizes, 84 | activation='softplus', dropout=0.2, label_size=0): 85 | """ 86 | Initialize InferenceNetwork. 87 | 88 | Args 89 | input_size : int, dimension of input 90 | output_size : int, dimension of output 91 | hidden_sizes : tuple, length = n_layers 92 | activation : string, 'softplus' or 'relu', default 'softplus' 93 | dropout : float, default 0.2, default 0.2 94 | """ 95 | super(CombinedInferenceNetwork, self).__init__() 96 | assert isinstance(input_size, int), "input_size must by type int." 97 | assert isinstance(output_size, int), "output_size must be type int." 98 | assert isinstance(hidden_sizes, tuple), \ 99 | "hidden_sizes must be type tuple." 100 | assert activation in ['softplus', 'relu'], \ 101 | "activation must be 'softplus' or 'relu'." 102 | assert dropout >= 0, "dropout must be >= 0." 
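        # Unlike ContextualInferenceNetwork above, this variant also consumes the BoW vector:
        # in forward(), adapt_bert projects the contextual embedding down to input_size (the BoW
        # dimension), the projection is concatenated with the BoW input (and optional labels),
        # and the same f_mu / f_sigma heads with affine-free batch norm produce mu and log_sigma.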
103 | 104 | self.input_size = input_size 105 | self.output_size = output_size 106 | self.hidden_sizes = hidden_sizes 107 | self.dropout = dropout 108 | 109 | if activation == 'softplus': 110 | self.activation = nn.Softplus() 111 | elif activation == 'relu': 112 | self.activation = nn.ReLU() 113 | 114 | 115 | self.adapt_bert = nn.Linear(bert_size, input_size) 116 | #self.bert_layer = nn.Linear(hidden_sizes[0], hidden_sizes[0]) 117 | self.input_layer = nn.Linear(input_size + input_size + label_size, hidden_sizes[0]) 118 | 119 | self.hiddens = nn.Sequential(OrderedDict([ 120 | ('l_{}'.format(i), nn.Sequential(nn.Linear(h_in, h_out), self.activation)) 121 | for i, (h_in, h_out) in enumerate(zip(hidden_sizes[:-1], hidden_sizes[1:]))])) 122 | 123 | self.f_mu = nn.Linear(hidden_sizes[-1], output_size) 124 | self.f_mu_batchnorm = nn.BatchNorm1d(output_size, affine=False) 125 | 126 | self.f_sigma = nn.Linear(hidden_sizes[-1], output_size) 127 | self.f_sigma_batchnorm = nn.BatchNorm1d(output_size, affine=False) 128 | 129 | self.dropout_enc = nn.Dropout(p=self.dropout) 130 | 131 | def forward(self, x, x_bert, labels=None): 132 | """Forward pass.""" 133 | x_bert = self.adapt_bert(x_bert) 134 | 135 | x = torch.cat((x, x_bert), 1) 136 | 137 | if labels is not None: 138 | x = torch.cat((x, labels), 1) 139 | 140 | x = self.input_layer(x) 141 | 142 | x = self.activation(x) 143 | x = self.hiddens(x) 144 | x = self.dropout_enc(x) 145 | mu = self.f_mu_batchnorm(self.f_mu(x)) 146 | log_sigma = self.f_sigma_batchnorm(self.f_sigma(x)) 147 | 148 | return mu, log_sigma 149 | -------------------------------------------------------------------------------- /utils/preprocess_wiki.py: -------------------------------------------------------------------------------- 1 | import string 2 | import re 3 | import os 4 | import pandas as pd 5 | import xml.etree.ElementTree as ET 6 | import mwparserfromhell 7 | 8 | 9 | def parse_and_clean_wikicode(raw_content): 10 | """Strips formatting and unwanted sections from raw page content.""" 11 | wikicode = mwparserfromhell.parse(raw_content) 12 | # Filters for references, tables, and file/image links. 13 | re_rm_wikilink = re.compile("^(?:File|Image|Media):", flags=re.IGNORECASE | re.UNICODE) 14 | def rm_wikilink(obj): 15 | return bool(re_rm_wikilink.match(str(obj.title))) 16 | def rm_tag(obj): 17 | return str(obj.tag) in {"ref", "table"} 18 | def rm_template(obj): 19 | return obj.name.lower() in {"reflist", "notelist", "notelist-ua", "notelist-lr", "notelist-ur", "notelist-lg"} 20 | def try_remove_obj(obj, section): 21 | try: 22 | section.remove(obj) 23 | except ValueError: 24 | # For unknown reasons, objects are sometimes not found. 25 | pass 26 | section_text = [] 27 | # Filter individual sections to clean. 
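    # get_sections(flat=True, include_lead=True, include_headings=True) yields the lead plus each
    # section of the page; the ifilter_* calls below walk every section recursively and remove
    # file/image/media wikilinks, reference-list templates and <ref>/<table> tags, after which
    # strip_code() reduces the remaining wikicode to plain text.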
28 | for section in wikicode.get_sections(flat=True, include_lead=True, include_headings=True): 29 | for obj in section.ifilter_wikilinks(matches=rm_wikilink, recursive=True): 30 | try_remove_obj(obj, section) 31 | for obj in section.ifilter_templates(matches=rm_template, recursive=True): 32 | try_remove_obj(obj, section) 33 | for obj in section.ifilter_tags(matches=rm_tag, recursive=True): 34 | try_remove_obj(obj, section) 35 | section_text.append(section.strip_code().strip()) 36 | return "\n\n".join(section_text) 37 | 38 | 39 | def parse_wikipedia_xml_dump(xml_path, save_path=None, lang='de'): 40 | """ Extracts title and article content from Wikipedia XML dumps (https://dumps.wikimedia.org/) and saves result to CSV""" 41 | print("Parsing XML", xml_path) 42 | filepath = xml_path 43 | f = open(filepath, 'r') 44 | tree = ET.parse(f) 45 | root = tree.getroot() 46 | articles = {lang + '_title': [], 47 | lang + '_text': []} 48 | for child in root: 49 | if 'page' in child.tag: 50 | page = child 51 | article_title = "" 52 | for child2 in page: 53 | if 'title' in child2.tag: 54 | title = child2 55 | if title.text is not None: 56 | article_title = title.text.lower() 57 | if 'revision' in child2.tag: 58 | revision = child2 59 | for child3 in revision: 60 | if 'text' in child3.tag: 61 | txt = child3 62 | if txt.text is not None: 63 | text = txt.text.lower().strip() 64 | # check if article is long enough (by num of chars) 65 | if len(text) > 200 and len(article_title) > 0: 66 | articles[lang + '_title'].append(article_title) 67 | article_text = parse_and_clean_wikicode(text) 68 | articles[lang + '_text'].append(article_text) 69 | if len(articles[lang + '_title']) % 100 == 0: 70 | print("Title:", article_title) 71 | print("Article:", article_text[:100]) 72 | print("="*40) 73 | print("Done parsing", len(articles[lang + '_title']), "articles") 74 | df = pd.DataFrame.from_dict(articles) 75 | if save_path is None: 76 | csv_file = filepath[:-4] + ".csv" 77 | else: 78 | csv_file = os.path.join(save_path, filepath[:-4] + ".csv") 79 | df.to_csv(csv_file, index=False) 80 | print("Done! Saved Wikipedia articles as", csv_file, "!") 81 | return csv_file 82 | 83 | 84 | def align_wikipedia_titles(xml_path, lang_pair='de-en'): 85 | """ Extracts titles for aligned multilingual Wikipedia articles (https://linguatools.org/tools/corpora/wikipedia-monolingual-corpora/) and saves result to CSV""" 86 | print("Language pair:", lang_pair.upper()) 87 | languages = lang_pair.split('-') 88 | lang1 = languages[0] 89 | lang2 = languages[1] 90 | df = {lang+"_title": [] for lang in languages} 91 | wiki = open(xml_path, 'r') 92 | tree = ET.parse(wiki) 93 | root = tree.getroot() 94 | for child in root: 95 | if len(df[lang1+"_title"]) > 10000: 96 | break 97 | if 'article' in child.tag: 98 | article = child 99 | lang1_art_title = article.attrib['name'].lower() 100 | for child2 in article: 101 | if 'crosslanguage_link' in child2.tag: 102 | cross_lang = child2 103 | lang_attrib = cross_lang.attrib['language'] 104 | if lang_attrib == lang2: 105 | lang2_art_title = cross_lang.attrib['name'].lower() 106 | df[lang1+"_title"].append(lang1_art_title) 107 | df[lang2+"_title"].append(lang2_art_title) 108 | print("Links:", len(df[lang1+"_title"])) 109 | save_filename = 'wikipairs_titles_'+lang_pair+'.csv' 110 | df = pd.DataFrame.from_dict(df) 111 | df.to_csv(save_filename, index=None) 112 | print("Done! 
Dumped aligned titles for pair", lang_pair.upper(), "to", save_filename, "!") 113 | 114 | 115 | def combine_multilingual_wikipedia_articles(aligned_titles, lang1_articles, lang2_articles, lang1='de', lang2='en'): 116 | """ Merges the aligned multilingual titles and extracted article contents """ 117 | df_merged = lang1_articles.merge(aligned_titles, on=[lang1 + "_title"]) 118 | df_merged = df_merged.merge(lang2_articles, on=[lang2 + "_title"]) 119 | return df_merged 120 | 121 | 122 | def extract_wikipedia_image_urls(image_tsv='train-00000-of-00005.tsv', lang='en'): 123 | """ Extracts image urls and article titles for a given language from the WIT dataset""" 124 | chunksize = 100000 125 | df_reduced = {'image_url': [], 126 | lang + '_title': []} 127 | with pd.read_csv(image_tsv, sep='\t', chunksize=chunksize) as reader: 128 | for i, chunk in enumerate(reader): 129 | df = chunk[chunk.language == lang] 130 | df_reduced['image_url'].extend(list(df.image_url)) 131 | page_titles = list(df.page_title) 132 | page_titles = [title.lower() for title in page_titles] 133 | df_reduced[lang + '_title'].extend(list(page_titles)) 134 | df_reduced = pd.DataFrame.from_dict(df_reduced) 135 | save_path = image_tsv[:-4] + '-' + lang + '.csv' 136 | df_reduced.to_csv(save_path, index=False) 137 | 138 | 139 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import scipy.sparse 4 | 5 | 6 | # ----- Multimodal and Multilingual (M3L) ----- 7 | class M3LDataset(Dataset): 8 | 9 | """Class to load BoW and the contextualized embeddings for *aligned multilingual* datasets""" 10 | 11 | def __init__(self, X_contextual, X_bow, X_image_emb, idx2token, num_lang=2, is_inference=False): 12 | # during training, data is multilingual AND multimodal 13 | # during inference, data is monolingual AND monomodal (either image or text) 14 | if is_inference is False: 15 | if X_bow[0].shape[0] != X_contextual[0].shape[0]: 16 | raise Exception("Wait! BoW and Contextual Embeddings have different sizes! " 17 | "You might want to check if the BoW preparation method has removed some documents. ") 18 | # else: 19 | # if X_bow.shape[0] != X_contextual.shape[0]: 20 | # raise Exception("Wait! BoW and Contextual Embeddings have different sizes! " 21 | # "You might want to check if the BoW preparation method has removed some documents. ") 22 | # 23 | # if labels is not None: 24 | # if labels.shape[0] != X_bow.shape[0]: 25 | # raise Exception(f"There is something wrong in the length of the labels (size: {labels.shape[0]}) " 26 | # f"and the bow (len: {X_bow.shape[0]}). 
These two numbers should match.") 27 | 28 | self.X_bow = X_bow 29 | self.X_contextual = X_contextual 30 | self.X_image = X_image_emb 31 | self.idx2token = idx2token 32 | self.num_lang = num_lang 33 | self.inference_mode = is_inference 34 | 35 | def __len__(self): 36 | """Return length of dataset.""" 37 | # during training, X_bow, X_contextual and X_image are all available 38 | if self.inference_mode is False: 39 | return self.X_contextual[0].shape[0] 40 | # during inference, either X_contextual or X_image is available, not both 41 | else: 42 | if self.X_contextual is not None: 43 | return self.X_contextual.shape[0] 44 | else: 45 | return self.X_image.shape[0] 46 | 47 | def __getitem__(self, i): 48 | """Return sample from dataset at index i.""" 49 | # TRAINING: dataset is multimodal AND multilingual (X_contextual will have 1 extra row for the image embedding) 50 | if self.inference_mode is False: 51 | if type(self.X_bow[0][i]) == scipy.sparse.csr.csr_matrix: 52 | X_bow_collect = [] 53 | X_contextual_collect = [] 54 | for l in range(self.num_lang): 55 | X_bow = torch.FloatTensor(self.X_bow[l][i].todense()) 56 | X_contextual = torch.FloatTensor(self.X_contextual[l][i]) 57 | X_bow_collect.append(X_bow) 58 | X_contextual_collect.append(X_contextual) 59 | # X_bow_collect: L x vocab_size 60 | X_bow_collect = torch.stack(X_bow_collect) 61 | # X_contextual_collect: L x bert_dim 62 | X_contextual_collect = torch.stack(X_contextual_collect) 63 | # X_image: bert_dim 64 | X_image = torch.FloatTensor(self.X_image[i]) 65 | else: 66 | X_bow_collect = [] 67 | X_contextual_collect = [] 68 | for l in range(self.num_lang): 69 | X_bow = torch.FloatTensor(self.X_bow[i]) 70 | X_contextual = torch.FloatTensor(self.X_contextual[i]) 71 | X_bow_collect.append(X_bow) 72 | X_contextual_collect.append(X_contextual) 73 | # X_bow_collect: L x vocab_size 74 | X_bow_collect = torch.stack(X_bow_collect) 75 | # X_contextual_collect: L x bert_dim 76 | X_contextual_collect = torch.stack(X_contextual_collect) 77 | # X_image: bert_dim 78 | X_image = torch.FloatTensor(self.X_image[i]) 79 | return_dict = {'X_bow': X_bow_collect, 'X_contextual': X_contextual_collect, 'X_image': X_image} 80 | # INFERENCE: dataset is monolingual AND monomodal (either text or image) 81 | else: 82 | # X_bow is just a dummy variable 83 | X_bow = torch.FloatTensor(torch.rand(10)) 84 | if self.X_contextual is not None: 85 | X_test = torch.FloatTensor(self.X_contextual[i]) 86 | else: 87 | X_test = torch.FloatTensor(self.X_image[i]) 88 | return_dict = {'X_contextual': X_test, 'X_bow': X_bow} 89 | return return_dict 90 | 91 | 92 | 93 | # ----- Multilingual ----- 94 | class PLTMDataset(Dataset): 95 | 96 | """Class to load BoW and the contextualized embeddings for *aligned multilingual* datasets""" 97 | 98 | def __init__(self, X_contextual, X_bow, idx2token, labels=None, num_lang=2, is_inference=False): 99 | 100 | # if we are in training mode, dataset is multilingual; inference is monolingual 101 | if is_inference is False: 102 | if X_bow[0].shape[0] != X_contextual[0].shape[0]: 103 | raise Exception("Wait! BoW and Contextual Embeddings have different sizes! " 104 | "You might want to check if the BoW preparation method has removed some documents. ") 105 | else: 106 | if X_bow.shape[0] != X_contextual.shape[0]: 107 | raise Exception("Wait! BoW and Contextual Embeddings have different sizes! " 108 | "You might want to check if the BoW preparation method has removed some documents. 
") 109 | 110 | if labels is not None: 111 | if labels.shape[0] != X_bow.shape[0]: 112 | raise Exception(f"There is something wrong in the length of the labels (size: {labels.shape[0]}) " 113 | f"and the bow (len: {X_bow.shape[0]}). These two numbers should match.") 114 | 115 | self.X_bow = X_bow 116 | self.X_contextual = X_contextual 117 | self.idx2token = idx2token 118 | self.labels = labels 119 | self.num_lang = num_lang 120 | self.inference_mode = is_inference 121 | 122 | def __len__(self): 123 | """Return length of dataset.""" 124 | if self.inference_mode is False: 125 | return self.X_contextual[0].shape[0] 126 | else: 127 | return self.X_contextual.shape[0] 128 | 129 | def __getitem__(self, i): 130 | """Return sample from dataset at index i.""" 131 | # TRAINING: dataset is multilingual 132 | if self.inference_mode is False: 133 | if type(self.X_bow[0][i]) == scipy.sparse.csr.csr_matrix: 134 | X_bow_collect = [] 135 | X_contextual_collect = [] 136 | for l in range(self.num_lang): 137 | X_bow = torch.FloatTensor(self.X_bow[l][i].todense()) 138 | X_contextual = torch.FloatTensor(self.X_contextual[l][i]) 139 | X_bow_collect.append(X_bow) 140 | X_contextual_collect.append(X_contextual) 141 | X_bow_collect = torch.stack(X_bow_collect) 142 | X_contextual_collect = torch.stack(X_contextual_collect) 143 | else: 144 | X_bow_collect = [] 145 | X_contextual_collect = [] 146 | for l in range(self.num_lang): 147 | X_bow = torch.FloatTensor(self.X_bow[i]) 148 | X_contextual = torch.FloatTensor(self.X_contextual[i]) 149 | X_bow_collect.append(X_bow) 150 | X_contextual_collect.append(X_contextual) 151 | X_bow_collect = torch.stack(X_bow_collect) 152 | X_contextual_collect = torch.stack(X_contextual_collect) 153 | 154 | return_dict = {'X_bow': X_bow_collect, 'X_contextual': X_contextual_collect} 155 | # INFERENCE: dataset is monolingual 156 | else: 157 | # we don't care about X_bow during inference 158 | if type(self.X_bow[i]) == scipy.sparse.csr.csr_matrix: 159 | X_bow = torch.FloatTensor(self.X_bow[i].todense()) 160 | X_contextual = torch.FloatTensor(self.X_contextual[i]) 161 | else: 162 | X_bow = torch.FloatTensor(self.X_bow[i]) 163 | X_contextual = torch.FloatTensor(self.X_contextual[i]) 164 | return_dict = {'X_bow': X_bow, 'X_contextual': X_contextual} 165 | 166 | return return_dict 167 | 168 | # ----- Original ----- 169 | 170 | class CTMDataset(Dataset): 171 | 172 | """Class to load BoW and the contextualized embeddings.""" 173 | 174 | def __init__(self, X_contextual, X_bow, idx2token, labels=None): 175 | 176 | if X_bow.shape[0] != len(X_contextual): 177 | raise Exception("Wait! BoW and Contextual Embeddings have different sizes! " 178 | "You might want to check if the BoW preparation method has removed some documents. ") 179 | 180 | if labels is not None: 181 | if labels.shape[0] != X_bow.shape[0]: 182 | raise Exception(f"There is something wrong in the length of the labels (size: {labels.shape[0]}) " 183 | f"and the bow (len: {X_bow.shape[0]}). 
These two numbers should match.") 184 | 185 | self.X_bow = X_bow 186 | self.X_contextual = X_contextual 187 | self.idx2token = idx2token 188 | self.labels = labels 189 | 190 | def __len__(self): 191 | """Return length of dataset.""" 192 | return self.X_bow.shape[0] 193 | 194 | def __getitem__(self, i): 195 | """Return sample from dataset at index i.""" 196 | if type(self.X_bow[i]) == scipy.sparse.csr.csr_matrix: 197 | X_bow = torch.FloatTensor(self.X_bow[i].todense()) 198 | X_contextual = torch.FloatTensor(self.X_contextual[i]) 199 | else: 200 | X_bow = torch.FloatTensor(self.X_bow[i]) 201 | X_contextual = torch.FloatTensor(self.X_contextual[i]) 202 | 203 | return_dict = {'X_bow': X_bow, 'X_contextual': X_contextual} 204 | 205 | if self.labels is not None: 206 | labels = self.labels[i] 207 | if type(labels) == scipy.sparse.csr.csr_matrix: 208 | return_dict["labels"] = torch.FloatTensor(labels.todense()) 209 | else: 210 | return_dict["labels"] = torch.FloatTensor(labels) 211 | 212 | return return_dict 213 | 214 | 215 | -------------------------------------------------------------------------------- /data/train-example.csv: -------------------------------------------------------------------------------- 1 | de_title,en_title,de_text,en_text,image_url 2 | hauspferd,horse,"Das Hauspferd (Equus caballus) ist ein weit verbreitetes Haus- bzw. Nutztier, das in zahlreichen Rassen auf der ganzen Welt existiert. 3 | 4 | Das Hauspferd ist die domestizierte Form des Wildpferdes (Equus ferus), das mit den Eseln und Zebras die Familie der Pferde (Einhufer, Equidae) innerhalb der Ordnung der Unpaarhufer (Perissodactyla) bildet. 5 | 6 | Merkmale 7 | 8 | Äußeres 9 | mini|Unterschiedliche Pferde (Lithographie) 10 | Das Aussehen des Hauspferdes variiert in seinem Körperbau, der Körpergröße, Fell und Farbe. ","The horse (Equus ferus caballus) is one of two extant subspecies of Equus ferus. It is an odd-toed ungulate mammal belonging to the taxonomic family Equidae. The horse has evolved over the past 45 to 55 million years from a small multi-toed creature, Eohippus, into the large, single-toed animal of today. Humans began domesticating horses around 4000 BC, and their domestication is believed to have been widespread by 3000 BC. Horses in the subspecies caballus are domesticated, although some domest",https://upload.wikimedia.org/wikipedia/commons/b/b6/Bhimbetka_rock_paintng1.jpg 11 | altamura,altamura,"Altamura ist eine italienische Stadt in der Metropolitanstadt Bari in Apulien mit Einwohnern (Stand ). 12 | 13 | Geografie 14 | Altamura liegt 45 km südwestlich von Bari an der Grenze zur Basilikata. Die Nachbargemeinden sind Bitonto, Cassano delle Murge, Gravina in Puglia, Grumo Appula, Matera (MT), Ruvo di Puglia, Santeramo in Colle und Toritto. 15 | 16 | Geschichte 17 | Seit dem 8. Jahrhundert v. Chr. wurde die Murge-Anhöhe, auf der Altamura liegt, besiedelt und im 6. bis 3. Jahrhundert v. Chr. wurden in der Gegend M","Altamura (, ; ) is a town and comune of Apulia, in southern Italy. It is located on one of the hills of the Murge plateau in the Metropolitan City of Bari, southwest of Bari, close to the border with Basilicata. , its population amounts to 70,595 inhabitants. 18 | 19 | The city is known for its particular quality of bread called Pane di Altamura, which is sold in numerous other Italian cities. 
The 130,000-year-old calcified Altamura Man was discovered in 1993 in the nearby limestone cave called grotta d",https://upload.wikimedia.org/wikipedia/commons/f/f6/Cesare_Orlandi_-_Altamura.png 20 | fc tschernomorez burgas,psfc chernomorets burgas,"Der FC Tschernomorez Burgas (, englisch Chernomorets Burgas) ist ein Fußballverein aus der bulgarischen Schwarzmeer-Hafenstadt Burgas. Der Verein spielt in der ersten bulgarischen Liga. Präsident des 2005 gegründeten Vereins ist Mitko Sabew. Ex-Bundesliga-Profi Krassimir Balakow trainierte die Mannschaft von 2008 bis Dezember 2010. Der Tschernomorez Pomorie fungiert als zweite Mannschaft des Clubs aus Burgas. 21 | 22 | Geschichte 23 | mini|ehem. „9 Septemvri-Stadion“ 24 | mini|Lasur-Stadion (Innenansicht) 25 | mini|Fa","PSFC Chernomorets Burgas () or simply Chernomorets () was a Bulgarian football club from the city of Burgas, which last competed in Bulgaria's fifth football league, the B Regional Group. The club has never won any major competition, its most notable achievement being a second-place finish in the UEFA Intertoto Cup competition in 2008. 26 | 27 | The club was founded in 2005, following the folding of FC Chernomorets Burgas, that played numerous seasons in the first tier of Bulgarian football. The new Cher",https://upload.wikimedia.org/wikipedia/commons/b/b8/Fanclub_Chernomorez.jpg 28 | salmānu-ašarēd i.,shalmaneser i,"Salmānu-ašarēd I. (Šulmānu-ašarēd I. oder auch Salmanassar I., in Analogie zum biblischen Salmanassar V.) war ein mittelassyrischer König. Sein Name bedeutet: „Salmānu ist der oberste Gott“. Als Sohn von Adad-nārārī I. regierte er nach der assyrischen Königsliste 30 Jahre als König das Assyrische Reich. 29 | 30 | Er war der Sohn und Nachfolger des assyrischen Königs Adad-nārārī I. und führte wie seine Vorgänger die Titel „Statthalter des Bel und Priester des Aššur“. Außerdem nannte er sich Großkönig und ","Shalmaneser I (Shulmanu-asharedu; 1274 BC – 1245 BC or 1265 BC – 1235 BC) was a king of Assyria during the Middle Assyrian Empire (1365 - 1050 BC). Son of Adad-nirari I, he succeeded his father as king in 1265 BC. 31 | 32 | According to his annals, discovered at Assur, in his first year he conquered eight countries in the northwest and destroyed the fortress of Arinnu, the dust of which he brought to Assur. In his second year he defeated Shattuara, king of Hanilgalbat (Mitanni), and his Hittite and Ahlam",https://upload.wikimedia.org/wikipedia/commons/a/aa/KingshalmaneserI.jpg 33 | automobil-weltmeisterschaft 1973,1973 formula one season,"Die Automobil-Weltmeisterschaft 1973 war die 24. Saison der Automobil-Weltmeisterschaft, die heutzutage als Formel-1-Weltmeisterschaft bezeichnet wird. In ihrem Rahmen wurden über 15 Rennen in der Zeit vom 28. Januar 1973 bis zum 7. Oktober 1973 die Fahrerweltmeisterschaft und der Internationale Pokal der Formel-1-Konstrukteure ausgetragen. 34 | 35 | Jackie Stewart gewann zum dritten und letzten Mal die Fahrer-Weltmeisterschaft. Lotus-Ford wurde zum sechsten Mal Konstrukteursweltmeister. 36 | 37 | Der FIA-Ehrenti","The 1973 Formula One season was the 27th season of FIA Formula One motor racing. It featured the 1973 World Championship of Drivers and the 1973 International Cup for F1 Manufacturers, which were contested concurrently over a fifteen-race series that commenced on 28 January and ended on 7 October. There were two new races for the 1973 season – the Brazilian Grand Prix at Interlagos in São Paulo and the Swedish Grand Prix at Anderstorp. 
The season also included two non-championship races which we",https://upload.wikimedia.org/wikipedia/commons/1/12/Lotus_72_JPS.jpg 38 | großer türkenkrieg,great turkish war,"Der Große Türkenkrieg zwischen der Heiligen Liga europäischer Mächte und dem Osmanischen Reich, auch als Großer Türkenkrieg Leopolds I. oder 5. Österreichischer Türkenkrieg bezeichnet, dauerte von 1683 bis 1699. Unter seinem neuen Großwesir und Oberbefehlshaber Kara Mustafa versuchte das Osmanische Reich 1683 zum zweiten Mal (nach der Ersten Wiener Türkenbelagerung 1529), die Kaiserstadt Wien zu erobern und das Tor nach Zentraleuropa aufzustoßen. Das Scheitern dieser Belagerung führte zur kaiser","The Great Turkish War () or the War of the Holy League () was a series of conflicts between the Ottoman Empire and the Holy League consisting of the Habsburg Monarchy, Poland-Lithuania, Venice and Russia. Intensive fighting began in 1683 and ended with the signing of the Treaty of Karlowitz in 1699. The war was a defeat for the Ottoman Empire, which for the first time lost large amounts of territory. It lost lands in Hungary and the Polish–Lithuanian Commonwealth, as well as part of the western ",https://upload.wikimedia.org/wikipedia/commons/1/19/1684_Entsatz_von_Wien_anagoria.JPG 39 | cœur de pirate,cœur de pirate,WEITERLEITUNG Cœur de Pirate,"Béatrice Martin (; born September 22, 1989), better known by her stage name Cœur de pirate (; French for Pirate's Heart), is a Canadian singer-songwriter and pianist. A francophone from Montreal, she sings mostly in French and has been credited in Montreal Mirror with ""bringing la chanson française to a whole new generation of Quebec youth"". 40 | 41 | Career 42 | 43 | Early beginnings 44 | Born in the province of Quebec, Martin started playing the piano when she was only three years old. She entered the Conservatoire ",https://upload.wikimedia.org/wikipedia/commons/e/e4/C%C5%93ur_de_pirate.jpg 45 | janisjarwi,lake yanisyarvi,"Der Janisjarwi (; finnische Bezeichnung Jänisjärvi, ""Hasensee"") ist ein 200 km² großer See in der Republik Karelien in Russland nördlich des Ladogasees gelegen. 46 | 47 | Der See wird vom Jänisjoki durchflossen und zum südlich gelegenen Ladogasee entwässert. 48 | Das Einzugsgebiet des Sees umfasst 3660 km². 49 | 50 | Das Becken des annähernd kreisförmigen Sees ist das Ergebnis eines Meteoriteneinschlags vor 700±5 Millionen Jahren (nach anderen Quellen 698±22 Mio. Jahre) während des Cryogeniums. 51 | 52 | Der Impaktkrater hat e","Lake Yanisyarvi (; ) is a lake in the Republic of Karelia, Russia, located north of and draining to Lake Ladoga. 53 | 54 | The basin of this somewhat circular lake was formed by meteorite impact 700±5 million years ago during the Cryogenian period. The crater is in diameter. 55 | 56 | Prior to World War II, the lake was thought to be the second known volcanic caldera in Finland (the other was Lake Lappajärvi). Both were eventually recognized as impact craters. 57 | 58 | References 59 | 60 | External links 61 | Lake Jänisjärvi Impact C",https://upload.wikimedia.org/wikipedia/commons/9/9e/Janisjarvi_crater_lake_01.jpg 62 | umm er-rasas,umm ar-rasas,"mini|Stadtvignette von Kastron Mefa'a, Detail des Bodenmosaiks in der Stephanuskirche 63 | Umm er-Rasas (; auch: Kastron Mefa’a oder Mefaa oder Mephaon) ist eine archäologische Stätte in Jordanien mit Ruinen vom Ende des 3. bis zum 9. Jahrhundert. 
Sie liegt rund 70 km südlich von Amman und 30 km von Madaba entfernt und repräsentiert vermutlich das in Josua 18, 13 erwähnte Mefaat (auch Mephaat, Mefa'at, Mepha'at). 64 | 65 | Umm er-Rasas war eine ummauerte Siedlung und enthält Ruinen aus römischer und byzantini","Umm ar-Rasas () (Kastrom Mefa'a, Kastron Mefa'a) is located 30 km southeast of Madaba, which is the capital city of the Madaba Governorate in central Jordan. It was once accessible by branches of the King's Highway, and is situated in the semi-arid steppe region of the Jordanian Desert. The site has been allied to the biblical settlement of Mephaat mentioned in the Book of Jeremiah. The Roman military utilized the site as a strategic garrison, but it was later converted and inhabited by Christia",http://upload.wikimedia.org/wikipedia/commons/9/91/Umm_Rasas_Fisherman.JPG 66 | petersburger blutsonntag,bloody sunday (1905),"miniatur|Demonstranten auf dem Weg zum Winterpalast in St. Petersburg 67 | Der Petersburger Blutsonntag (auch Blutiger Sonntag, Roter oder Schwarzer Sonntag) des Jahres 1905 war ein Ereignis in der Geschichte des Russischen Kaiserreichs und Teil der Russischen Revolution von 1905. 68 | 69 | Der Blutsonntag 70 | miniatur|Militär vor dem Winterpalast 71 | In den ersten Januartagen 1905 erfasste ein Generalstreik zuerst die Putilow-Werke und bald darauf auch die Werften, Manufakturen und Webereien. Am Sonntag, dem begab","Bloody Sunday or Red Sunday () is the name given to the events of Sunday, in St Petersburg, Russia, when unarmed demonstrators, led by Father Georgy Gapon, were fired upon by soldiers of the Imperial Guard as they marched towards the Winter Palace to present a petition to Tsar Nicholas II of Russia. 72 | 73 | Bloody Sunday caused grave consequences for the Tsarist autocracy governing Imperial Russia: the events in St. Petersburg provoked public outrage and a series of massive strikes that spread quickly",https://upload.wikimedia.org/wikipedia/commons/5/5c/BloodySunday1905.jpg 74 | -------------------------------------------------------------------------------- /data/test-example.csv: -------------------------------------------------------------------------------- 1 | de_title,en_title,de_text,en_text,image_url 2 | fleckenbauch-avosettkolibri,mountain avocetbill,"Der Fleckenbauch-Avosettkolibri (Opisthoprora euryptera) oder Degenschnabelkolibri ist eine Vogelart aus der Familie der Kolibris (Trochilidae) und die einzige Art der Gattung Opisthoprora. Das Verbreitungsgebiet dieser Art umfasst die Länder Peru, Ecuador und Kolumbien. Der Bestand wird von der IUCN als nicht gefährdet (Least Concern) eingeschätzt. 3 | 4 | Merkmale 5 | Der Fleckenbauch-Avosettkolibri erreicht eine Körperlänge von etwa 10 cm. Der kurze, 13 mm lange Schnabel ist an der Spitze deutlich nach","The mountain avocetbill (Opisthoprora euryptera) is a species of hummingbird in the family Trochilidae, the only member of its genus. 6 | 7 | It is found in Colombia, Ecuador, and Peru. Its natural habitat is subtropical or temperate moist montane forest. 8 | 9 | References 10 | 11 | Category:Trochilinae 12 | Category:Hummingbird species of South America 13 | Category:Birds of the Colombian Andes 14 | Category:Birds of the Ecuadorian Andes 15 | Category:Birds of the Peruvian Andes 16 | mountain avocetbill 17 | Category:Taxonomy articles created by",https://upload.wikimedia.org/wikipedia/commons/e/ef/Opisthoprora_euryptera.jpg 18 | luis salom,luis salom,"Luis Jaime Salom Horrach (* 7. August 1991 in Palma; † 3. 
Juni 2016 in Sant Cugat del Vallès) war ein spanischer Motorradrennfahrer. 19 | 20 | Karriere 21 | Salom gab sein Debüt in der Motorrad-Weltmeisterschaft im Jahre 2009 in der 125-cm³-Klasse für das Team SAG-Castrol. Er wurde in dieser Saison auch als Ersatzfahrer im Team Jack & Jones eingesetzt. Im Jahr 2010 fuhr er zunächst für das Team Lambretta Reparto Corse und wechselte nach zwei Rennen zu Stipa-Molenaar Racing GP. 2011 und 2012 trat er die Renne","Luis Jaime Salom Horrach (7 August 1991 – 3 June 2016) was a Spanish Grand Prix motorcycle racer. Salom died after a practice accident at Circuit de Catalunya, when making contact with his bike and the wall after a high-speed accident. Racing in the Moto2 class since 2014, he finished 41 races, with 3 podium appearances, including a second-place finish at the 2016 Qatar season opener. At the time of his death, Salom ranked 10th in the 2016 Moto2 Championship point standings. Previously he had c",https://upload.wikimedia.org/wikipedia/commons/3/34/Luis_Salom_2010_Silverstone.jpg 22 | takuma satō,takuma sato,"Takuma Satō (; * 28. Januar 1977 in Shinjuku, Tokio) ist ein japanischer Automobilrennfahrer, der seit 2010 in der IndyCar Series aktiv ist. 2001 wurde er britischer Formel-3-Meister. Von 2002 bis 2008 startete Satō für verschiedene Teams zu 90 Formel-1-Rennen. 23 | 24 | Satō gewann das Indianapolis 500 2017, er ist der erste asiatische Fahrer, der dieses Rennen gewinnen konnte. Auch ist er der erste japanische Formel-1-Pilot, der sich für die erste Startreihe eines Grand Prix qualifizierte. 25 | 26 | Karriere 27 | 28 | A","is a Japanese professional racing driver. Sato has raced full-time in the IndyCar Series since 2010 for the Honda-powered KV, Rahal, Foyt, Andretti, and again starting from 2018, the Rahal teams. Sato won the 2017 Indianapolis 500, becoming the first Asian driver to win the Indy 500. He also became the first Japanese driver to win an IndyCar race when he won the 2013 Grand Prix of Long Beach. He competed in Formula One from 2002 to 2008 for the Honda-powered Jordan, BAR and Super Aguri teams, sc",https://upload.wikimedia.org/wikipedia/commons/f/fb/Takuma_Sato_2007_Britain_2.jpg 29 | apoel nikosia,apoel fc,WEITERLEITUNG APOEL Nikosia,"APOEL FC (; short for Αθλητικός Ποδοσφαιρικός Όμιλος Ελλήνων Λευκωσίας, Athletikos Podosferikos Omilos Ellinon Lefkosias, ""Athletic Football Club of Greeks of Nicosia"") is a professional football club based in Nicosia, Cyprus. APOEL is the most popular and the most successful football team in Cyprus with an overall tally of 28 national championships, 21 cups, and 13 super cups. 30 | 31 | APOEL's greatest moment in the European competitions occurred in the 2011–12 season, when the club participated in the",https://upload.wikimedia.org/wikipedia/commons/c/cf/APOEL_-_Chelsea.jpg 32 | doc holliday,doc holliday,"mini|Doc Holliday in Tombstone (um 1882) 33 | 34 | John Henry Holliday (* 14. August 1851 oder Anfang 1852 in Griffin, Georgia; † 8. November 1887 in Glenwood Springs, Colorado), bekannt unter seinem Spitznamen Doc Holliday, war Zahnarzt und einer der berühmtesten Revolverhelden des Wilden Westens. Holliday war an neun Schießereien beteiligt und tötete zwischen drei und sieben Menschen. 35 | 36 | Leben 37 | Doc Holliday war der Sohn von Henry Burroughs Holliday, einem Major der US-Armee, und Alice Jane McKay. Sein Ge","John Henry ""Doc"" Holliday (August 14, 1851 – November 8, 1887) was an American gambler, gunfighter, and dentist. 
A close friend and associate of lawman Wyatt Earp, Holliday is best known for his role in the events leading up to and following the Gunfight at the O.K. Corral. He developed a reputation as having killed more than a dozen men in various altercations, but modern researchers have concluded that, contrary to popular myth-making, Holliday killed only one to three men. Holliday's colorful",https://upload.wikimedia.org/wikipedia/commons/9/96/HollidayandBowler.jpg 38 | volvo 480,volvo 480,"Der Volvo 480 ist ein Kompaktklasse-Coupé bzw. Shooting Brake und das erste frontgetriebene Fahrzeug von Volvo. Es wird zur 400er-Serie gezählt. Das auf dem Entwurf des Niederländers John de Vries basierende Design zeigt gewisse Anlehnungen an den von Sommer 1971 bis Ende 1973 gebauten Volvo P1800 ES. 39 | 40 | Allgemeines 41 | Der Volvo 480 wurde im niederländischen Werk Born (Limburg) gefertigt, aus dem 1991 NedCar hervorging. Bereits frühere Baureihen wie der Volvo 66 und der Volvo 340/360 sowie die Baure","The Volvo 480 is a sporty compact car that was produced in Born, Netherlands, by Volvo from 1986 to 1995. It was the first front-wheel drive car made by the automaker. The 480 was available in only one body style on an automobile platform related to the Volvo 440/460 five door hatchback and four door sedan models. 42 | 43 | It features an unusual four seat, three door hatchback body, somewhere between liftback and estate in form. The 480 was marketed as a coupé in Europe starting in 1986. The compact car",https://upload.wikimedia.org/wikipedia/commons/c/c4/Volvo_480_Heck.jpg 44 | 1037,1037," 45 | 46 | Ereignisse 47 | 48 | Politik und Weltgeschehen 49 | 50 | Heiliges Römisches Reich/Burgund 51 | 52 | 28. Mai: Mit der Constitutio de feudis verfügt Kaiser Konrad II. auf seinem zweiten Italienzug in Cremona die Erblichkeit der Lehen in Reichsitalien auch für den niederen Adel. Der Konflikt zwischen Erzbischof Aribert von Mailand und den Valvassoren wird damit vorläufig beigelegt. 53 | 15. November: In seinem Kampf gegen Kaiser Konrad II. wird Graf Odo II. von Blois auf der Ebene von Honol, zwischen Bar und Verdun, von ein","Year 1037 (MXXXVII) was a common year starting on Saturday (link will display the full calendar) of the Julian calendar. 54 | 55 | Events 56 | By place 57 | 58 | Europe 59 | Spring – A revolt in northern Italy is started by Aribert, archbishop of Milan. King Henry III (the eldest son of Emperor Conrad II) travels south of the Alps to quell it. 60 | February – At an Imperial Diet in Pavia (assembled by Conrad II) Aribert is accused of fomenting a revolt against the Holy Roman Empire, Conrad orders his arrest. 61 | May – Conra",https://upload.wikimedia.org/wikipedia/commons/f/f4/Ferdinand_I_of_Le%C3%B3n_cely.jpg 62 | gary medel,gary medel,"Gary Alexis Medel Soto (* 3. August 1987 in Santiago) ist ein chilenischer Fußballspieler. 63 | 64 | Karriere 65 | 66 | Verein 67 | Gary Medel kommt aus der Jugend von CD Universidad Católica, wo er im Jahr 2005 sein Profidebüt gab. Seine ersten beiden Tore erzielte er zwei Jahre später am 28. Juli 2007 gegen den Lokalrivalen CF Universidad de Chile im Clásico Universitario. Er besiegelte den Sieg durch seine beiden Tore und schoss sich damit in die Herzen der Fans. 68 | 69 | Im Jahr 2008 wurde er vor Lucas Barrios zum Fußba","Gary Alexis Medel Soto (; born 3 August 1987) is a Chilean professional footballer who plays for Italian club Bologna as a defensive midfielder. 
However, he can also play as a defender, and has even been deployed as a centre-back throughout his career, as well as in midfield. Medel has played club football with several teams in numerous countries, starting out with Chilean side Universidad Católica, and later playing for Argentine side Boca Juniors, Spanish side Sevilla, Premier League side Card",https://upload.wikimedia.org/wikipedia/commons/2/2a/Brazil_vs._Chile_in_Mineir%C3%A3o_02.jpg 70 | jaguar,jaguar,"mini|250px|Jaguar 71 | mini|250px|Zum Vergleich: Afrikanischer Leopard (Panthera pardus pardus) 72 | mini|250px|Ein Exemplar mit Melanismus (hier ist die Zeichnung teilweise erkennbar). 73 | mini|250px|Schwarzes Jaguarweibchen mit normal gefärbtem Jungtier (Zoo Salzburg) 74 | mini|250px|Verbreitungsgebiet des Jaguars:ursprünglich (rot und grün) und heute (grün) 75 | mini|250px|Schädel eines Jaguars mit sichtbar starkem Jochbein sowie Unterkiefer 76 | mini|250px|Jaguarmutter, die ihr Junges aufnimmt 77 | 78 | Der Jaguar (Panthera onca","The jaguar (Panthera onca) is a large felid species and the only extant member of the genus Panthera native to the Americas. The jaguar's present range extends from Southwestern United States and Mexico in North America, across much of Central America, and south to Paraguay and northern Argentina in South America. Though there are single cats now living within the Western United States, the species has largely been extirpated from the United States since the early 20th century. It is listed as N",https://upload.wikimedia.org/wikipedia/commons/a/a1/Statuette_Karaj%C3%A0_MHNT.ETH.2011.17.85.jpg 79 | värnamo,värnamo,"Värnamo ist eine Kleinstadt in der schwedischen Provinz Jönköpings län und der historischen Provinz Småland. Die im Tal des Flusses Lagan und an der Europastraße 4 gelegene Stadt, ist Hauptort der gleichnamigen Gemeinde. 80 | 81 | Verkehr 82 | Värnamo liegt an der Nord-Süd verlaufenden Europastraße 4. Der Ort hat einen Bahnhof an der Bahnstrecke Halmstad–Nässjö. 83 | 84 | Sehenswürdigkeiten 85 | Wohnhaus des Designers Bruno Mathsson mit Museum 86 | Park und Freilichtmuseum Apladalen 87 | Kirche von Värnamo aus dem 19. Jahrhunde","Värnamo () is a locality and the seat of Värnamo Municipality, Jönköping County, Sweden, with 19,817 inhabitants in 2018. 88 | 89 | History 90 | Värnamo traces its history back to a village in the medieval age; the first written mention of it stems from the 13th century. It came into existence as a village to the eastern side of a fordable place over the Lagan, a river that for large parts is difficult to travel by. As there are also smaller streams to the south and west of this location, it was considered so",https://upload.wikimedia.org/wikipedia/commons/7/77/Apladalen_Holzh%C3%A4user.jpg 91 | -------------------------------------------------------------------------------- /evaluation/rbo/rbo.py: -------------------------------------------------------------------------------- 1 | """Rank-biased overlap, a ragged sorted list similarity measure. 2 | 3 | See http://doi.acm.org/10.1145/1852102.1852106 for details. All functions 4 | directly corresponding to concepts from the paper are named so that they can be 5 | clearly cross-identified. 6 | 7 | The definition of overlap has been modified to account for ties. Without this, 8 | results for lists with tied items were being inflated. The modification itself 9 | is not mentioned in the paper but seems to be reasonable, see function 10 | ``overlap()``. 
Places in the code which diverge from the spec in the paper 11 | because of this are highlighted with comments. 12 | 13 | The two main functions for performing an RBO analysis are ``rbo()`` and 14 | ``rbo_dict()``; see their respective docstrings for how to use them. 15 | 16 | The following doctest just checks that equivalent specifications of a 17 | problem yield the same result using both functions: 18 | 19 | >>> lst1 = [{"c", "a"}, "b", "d"] 20 | >>> lst2 = ["a", {"c", "b"}, "d"] 21 | >>> ans_rbo = _round(rbo(lst1, lst2, p=.9)) 22 | >>> dct1 = dict(a=1, b=2, c=1, d=3) 23 | >>> dct2 = dict(a=1, b=2, c=2, d=3) 24 | >>> ans_rbo_dict = _round(rbo_dict(dct1, dct2, p=.9, sort_ascending=True)) 25 | >>> ans_rbo == ans_rbo_dict 26 | True 27 | 28 | """ 29 | 30 | from __future__ import division 31 | 32 | import math 33 | from bisect import bisect_left 34 | from collections import namedtuple 35 | 36 | 37 | RBO = namedtuple("RBO", "min res ext") 38 | RBO.__doc__ += ": Result of full RBO analysis" 39 | RBO.min.__doc__ = "Lower bound estimate" 40 | RBO.res.__doc__ = "Residual corresponding to min; min + res is an upper bound estimate" 41 | RBO.ext.__doc__ = "Extrapolated point estimate" 42 | 43 | 44 | def _round(obj): 45 | if isinstance(obj, RBO): 46 | return RBO(_round(obj.min), _round(obj.res), _round(obj.ext)) 47 | else: 48 | return round(obj, 3) 49 | 50 | 51 | def set_at_depth(lst, depth): 52 | ans = set() 53 | for v in lst[:depth]: 54 | if isinstance(v, set): 55 | ans.update(v) 56 | else: 57 | ans.add(v) 58 | return ans 59 | 60 | 61 | def raw_overlap(list1, list2, depth): 62 | """Overlap as defined in the article. 63 | 64 | """ 65 | set1, set2 = set_at_depth(list1, depth), set_at_depth(list2, depth) 66 | return len(set1.intersection(set2)), len(set1), len(set2) 67 | 68 | 69 | def overlap(list1, list2, depth): 70 | """Overlap which accounts for possible ties. 71 | 72 | This isn't mentioned in the paper but should be used in the ``rbo*()`` 73 | functions below, otherwise overlap at a given depth might be > depth which 74 | inflates the result. 75 | 76 | There are no guidelines in the paper as to what's a good way to calculate 77 | this, but a good guess is agreement scaled by the minimum between the 78 | requested depth and the lengths of the considered lists (overlap shouldn't 79 | be larger than the number of ranks in the shorter list, otherwise results 80 | are conspicuously wrong when the lists are of unequal lengths -- rbo_ext is 81 | not between rbo_min and rbo_min + rbo_res. 82 | 83 | >>> overlap("abcd", "abcd", 3) 84 | 3.0 85 | 86 | >>> overlap("abcd", "abcd", 5) 87 | 4.0 88 | 89 | >>> overlap(["a", {"b", "c"}, "d"], ["a", {"b", "c"}, "d"], 2) 90 | 2.0 91 | 92 | >>> overlap(["a", {"b", "c"}, "d"], ["a", {"b", "c"}, "d"], 3) 93 | 3.0 94 | 95 | """ 96 | return agreement(list1, list2, depth) * min(depth, len(list1), len(list2)) 97 | # NOTE: comment the preceding and uncomment the following line if you want 98 | # to stick to the algorithm as defined by the paper 99 | # return raw_overlap(list1, list2, depth)[0] 100 | 101 | 102 | def agreement(list1, list2, depth): 103 | """Proportion of shared values between two sorted lists at given depth. 
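At depth ``d`` this is computed as 2 * len_intersection / (len_set1 + len_set2), i.e. the Dice coefficient of the two depth-``d`` prefixes, which reduces to the plain fraction of shared items when neither list contains ties.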
104 | 105 | >>> _round(agreement("abcde", "abdcf", 1)) 106 | 1.0 107 | >>> _round(agreement("abcde", "abdcf", 3)) 108 | 0.667 109 | >>> _round(agreement("abcde", "abdcf", 4)) 110 | 1.0 111 | >>> _round(agreement("abcde", "abdcf", 5)) 112 | 0.8 113 | >>> _round(agreement([{1, 2}, 3], [1, {2, 3}], 1)) 114 | 0.667 115 | >>> _round(agreement([{1, 2}, 3], [1, {2, 3}], 2)) 116 | 1.0 117 | 118 | """ 119 | len_intersection, len_set1, len_set2 = raw_overlap(list1, list2, depth) 120 | return 2 * len_intersection / (len_set1 + len_set2) 121 | 122 | 123 | def cumulative_agreement(list1, list2, depth): 124 | return (agreement(list1, list2, d) for d in range(1, depth + 1)) 125 | 126 | 127 | def average_overlap(list1, list2, depth=None): 128 | """Calculate average overlap between ``list1`` and ``list2``. 129 | 130 | >>> _round(average_overlap("abcdefg", "zcavwxy", 1)) 131 | 0.0 132 | >>> _round(average_overlap("abcdefg", "zcavwxy", 2)) 133 | 0.0 134 | >>> _round(average_overlap("abcdefg", "zcavwxy", 3)) 135 | 0.222 136 | >>> _round(average_overlap("abcdefg", "zcavwxy", 4)) 137 | 0.292 138 | >>> _round(average_overlap("abcdefg", "zcavwxy", 5)) 139 | 0.313 140 | >>> _round(average_overlap("abcdefg", "zcavwxy", 6)) 141 | 0.317 142 | >>> _round(average_overlap("abcdefg", "zcavwxy", 7)) 143 | 0.312 144 | 145 | """ 146 | depth = min(len(list1), len(list2)) if depth is None else depth 147 | return sum(cumulative_agreement(list1, list2, depth)) / depth 148 | 149 | 150 | def rbo_at_k(list1, list2, p, depth=None): 151 | # ``p**d`` here instead of ``p**(d - 1)`` because enumerate starts at 152 | # 0 153 | depth = min(len(list1), len(list2)) if depth is None else depth 154 | d_a = enumerate(cumulative_agreement(list1, list2, depth)) 155 | return (1 - p) * sum(p ** d * a for (d, a) in d_a) 156 | 157 | 158 | def rbo_min(list1, list2, p, depth=None): 159 | """Tight lower bound on RBO. 160 | 161 | See equation (11) in paper. 162 | 163 | >>> _round(rbo_min("abcdefg", "abcdefg", .9)) 164 | 0.767 165 | >>> _round(rbo_min("abcdefgh", "abcdefg", .9)) 166 | 0.767 167 | 168 | """ 169 | depth = min(len(list1), len(list2)) if depth is None else depth 170 | x_k = overlap(list1, list2, depth) 171 | log_term = x_k * math.log(1 - p) 172 | sum_term = sum( 173 | p ** d / d * (overlap(list1, list2, d) - x_k) for d in range(1, depth + 1) 174 | ) 175 | return (1 - p) / p * (sum_term - log_term) 176 | 177 | 178 | def rbo_res(list1, list2, p): 179 | """Upper bound on residual overlap beyond evaluated depth. 180 | 181 | See equation (30) in paper. 182 | 183 | NOTE: The doctests weren't verified against manual computations but seem 184 | plausible. In particular, for identical lists, ``rbo_min()`` and 185 | ``rbo_res()`` should add up to 1, which is the case. 186 | 187 | >>> _round(rbo_res("abcdefg", "abcdefg", .9)) 188 | 0.233 189 | >>> _round(rbo_res("abcdefg", "abcdefghijklmnopqrstuvwxyz", .9)) 190 | 0.239 191 | 192 | """ 193 | S, L = sorted((list1, list2), key=len) 194 | s, l = len(S), len(L) 195 | x_l = overlap(list1, list2, l) 196 | # since overlap(...) 
can be fractional in the general case of ties and f 197 | # must be an integer --> math.ceil() 198 | f = int(math.ceil(l + s - x_l)) 199 | # upper bound of range() is non-inclusive, therefore + 1 is needed 200 | term1 = s * sum(p ** d / d for d in range(s + 1, f + 1)) 201 | term2 = l * sum(p ** d / d for d in range(l + 1, f + 1)) 202 | term3 = x_l * (math.log(1 / (1 - p)) - sum(p ** d / d for d in range(1, f + 1))) 203 | return p ** s + p ** l - p ** f - (1 - p) / p * (term1 + term2 + term3) 204 | 205 | 206 | def rbo_ext(list1, list2, p): 207 | """RBO point estimate based on extrapolating observed overlap. 208 | 209 | See equation (32) in paper. 210 | 211 | NOTE: The doctests weren't verified against manual computations but seem 212 | plausible. 213 | 214 | >>> _round(rbo_ext("abcdefg", "abcdefg", .9)) 215 | 1.0 216 | >>> _round(rbo_ext("abcdefg", "bacdefg", .9)) 217 | 0.9 218 | 219 | """ 220 | S, L = sorted((list1, list2), key=len) 221 | s, l = len(S), len(L) 222 | x_l = overlap(list1, list2, l) 223 | x_s = overlap(list1, list2, s) 224 | # the paper says overlap(..., d) / d, but it should be replaced by 225 | # agreement(..., d) defined as per equation (28) so that ties are handled 226 | # properly (otherwise values > 1 will be returned) 227 | # sum1 = sum(p**d * overlap(list1, list2, d)[0] / d for d in range(1, l + 1)) 228 | sum1 = sum(p ** d * agreement(list1, list2, d) for d in range(1, l + 1)) 229 | sum2 = sum(p ** d * x_s * (d - s) / s / d for d in range(s + 1, l + 1)) 230 | term1 = (1 - p) / p * (sum1 + sum2) 231 | term2 = p ** l * ((x_l - x_s) / l + x_s / s) 232 | return term1 + term2 233 | 234 | 235 | def rbo(list1, list2, p): 236 | """Complete RBO analysis (lower bound, residual, point estimate). 237 | 238 | ``list`` arguments should be already correctly sorted iterables and each 239 | item should either be an atomic value or a set of values tied for that 240 | rank. ``p`` is the probability of looking for overlap at rank k + 1 after 241 | having examined rank k. 242 | 243 | >>> lst1 = [{"c", "a"}, "b", "d"] 244 | >>> lst2 = ["a", {"c", "b"}, "d"] 245 | >>> _round(rbo(lst1, lst2, p=.9)) 246 | RBO(min=0.489, res=0.477, ext=0.967) 247 | 248 | """ 249 | if not 0 <= p <= 1: 250 | raise ValueError("The ``p`` parameter must be between 0 and 1.") 251 | args = (list1, list2, p) 252 | return RBO(rbo_min(*args), rbo_res(*args), rbo_ext(*args)) 253 | 254 | 255 | def sort_dict(dct, *, ascending=False): 256 | """Sort keys in ``dct`` according to their corresponding values. 257 | 258 | Sorts in descending order by default, because the values are 259 | typically scores, i.e. the higher the better. Specify 260 | ``ascending=True`` if the values are ranks, or some sort of score 261 | where lower values are better. 262 | 263 | Ties are handled by creating sets of tied keys at the given position 264 | in the sorted list. 
265 | 266 | >>> dct = dict(a=1, b=2, c=1, d=3) 267 | >>> list(sort_dict(dct)) == ['d', 'b', {'a', 'c'}] 268 | True 269 | >>> list(sort_dict(dct, ascending=True)) == [{'a', 'c'}, 'b', 'd'] 270 | True 271 | 272 | """ 273 | scores = [] 274 | items = [] 275 | # items should be unique, scores don't have to 276 | for item, score in dct.items(): 277 | if not ascending: 278 | score *= -1 279 | i = bisect_left(scores, score) 280 | if i == len(scores): 281 | scores.append(score) 282 | items.append(item) 283 | elif scores[i] == score: 284 | existing_item = items[i] 285 | if isinstance(existing_item, set): 286 | existing_item.add(item) 287 | else: 288 | items[i] = {existing_item, item} 289 | else: 290 | scores.insert(i, score) 291 | items.insert(i, item) 292 | return items 293 | 294 | 295 | def rbo_dict(dict1, dict2, p, *, sort_ascending=False): 296 | """Wrapper around ``rbo()`` for dict input. 297 | 298 | Each dict maps items to be sorted to the score according to which 299 | they should be sorted. The RBO analysis is then performed on the 300 | resulting sorted lists. 301 | 302 | The sort is descending by default, because scores are typically the 303 | higher the better, but this can be overridden by specifying 304 | ``sort_ascending=True``. 305 | 306 | >>> dct1 = dict(a=1, b=2, c=1, d=3) 307 | >>> dct2 = dict(a=1, b=2, c=2, d=3) 308 | >>> _round(rbo_dict(dct1, dct2, p=.9, sort_ascending=True)) 309 | RBO(min=0.489, res=0.477, ext=0.967) 310 | 311 | """ 312 | list1, list2 = ( 313 | sort_dict(dict1, ascending=sort_ascending), 314 | sort_dict(dict2, ascending=sort_ascending), 315 | ) 316 | return rbo(list1, list2, p) 317 | 318 | 319 | if __name__ in ("__main__", "__console__"): 320 | import doctest 321 | 322 | doctest.testmod() 323 | -------------------------------------------------------------------------------- /utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer 2 | import string 3 | from nltk.corpus import stopwords as stop_words 4 | import warnings 5 | import numpy as np 6 | 7 | # ----- Original ----- 8 | class WhiteSpacePreprocessing(): 9 | """ 10 | Provides a very simple preprocessing script that filters infrequent tokens from text 11 | """ 12 | def __init__(self, documents, stopwords_language="english", vocabulary_size=2000, min_len=10, max_len=200): 13 | """ 14 | 15 | :param documents: list of strings 16 | :param stopwords_language: string of the language of the stopwords (see nltk stopwords) 17 | :param vocabulary_size: the number of most frequent words to include in the documents. Infrequent words will be discarded from the list of preprocessed documents 18 | """ 19 | self.documents = documents 20 | self.stopwords = set(stop_words.words(stopwords_language) + stop_words.words("english")) 21 | self.vocabulary_size = vocabulary_size 22 | self.max_len = max_len 23 | self.min_len = min_len 24 | 25 | def preprocess(self): 26 | """ 27 | Note that if after filtering some documents do not contain words we remove them. That is why we return also the 28 | list of unpreprocessed documents. 
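A minimal usage sketch (``raw_docs`` is a hypothetical list of raw document strings):
sp = WhiteSpacePreprocessing(raw_docs, stopwords_language="english", vocabulary_size=2000)
preprocessed_docs, unpreprocessed_docs, vocab = sp.preprocess()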
29 | 30 | :return: preprocessed documents, unpreprocessed documents and the vocabulary list 31 | """ 32 | print("Preprocessing", len(self.documents), "documents") 33 | print("Max seq length:", self.max_len) 34 | 35 | # --my changes-- truncate raw articles to the first 200 tokens 36 | truncated_docs = [' '.join(doc.split()[:self.max_len]) for doc in self.documents] 37 | self.documents = truncated_docs 38 | # --end my changes -- 39 | preprocessed_docs_tmp = self.documents 40 | preprocessed_docs_tmp = [doc.lower() for doc in preprocessed_docs_tmp] 41 | preprocessed_docs_tmp = [doc.translate( 42 | str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp] 43 | preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 2 and w not in self.stopwords]) 44 | for doc in preprocessed_docs_tmp] 45 | 46 | vectorizer = CountVectorizer(max_features=self.vocabulary_size, token_pattern=r'\b[a-zA-Z]{2,}\b') 47 | vectorizer.fit_transform(preprocessed_docs_tmp) 48 | vocabulary = set(vectorizer.get_feature_names()) 49 | preprocessed_docs_tmp = [' '.join([w for w in doc.split() if w in vocabulary]) 50 | for doc in preprocessed_docs_tmp] 51 | 52 | preprocessed_docs, unpreprocessed_docs = [], [] 53 | for i, doc in enumerate(preprocessed_docs_tmp): 54 | if len(doc.split()) >= self.min_len: 55 | # if len(doc.split()) > self.max_len: 56 | # doc = ' '.join(doc.split()[:self.max_len]) 57 | # if len(self.documents[i].split()) > self.max_len: 58 | # self.documents[i] = ' '.join(self.documents[i].split()[:self.max_len]) 59 | preprocessed_docs.append(doc) 60 | unpreprocessed_docs.append(self.documents[i]) 61 | 62 | return preprocessed_docs, unpreprocessed_docs, list(vocabulary) 63 | 64 | 65 | # ----- Multimodal AND Multilingual (M3L) ----- 66 | 67 | class WhiteSpacePreprocessingM3L(): 68 | """ 69 | Provides a very simple preprocessing script for aligned multimodal AND multilingual data 70 | """ 71 | def __init__(self, documents, image_urls, stopwords_languages, vocabulary_size=2000, min_len=10, max_len=200): 72 | """ 73 | 74 | :param documents: list of lists of strings, e.g. [['good morning', 'thank you'], ['guten morgen', 'danke sie']] 75 | :param stopwords_language: list of strings of the languages (see nltk stopwords) 76 | :param vocabulary_size: the number of most frequent words to include in the documents. Infrequent words will be discarded from the list of preprocessed documents 77 | """ 78 | self.documents = documents 79 | self.image_urls = image_urls 80 | self.num_lang = len(stopwords_languages) 81 | self.languages = stopwords_languages 82 | self.stopwords = [] 83 | for lang in self.languages: 84 | self.stopwords.append(set(stop_words.words(lang))) 85 | # same vocab_size for all langs for now 86 | self.vocabulary_size = vocabulary_size 87 | self.min_len = min_len 88 | self.max_len = max_len 89 | # if user has custom stopwords list 90 | #self.custom_stops = custom_stops 91 | 92 | def preprocess(self): 93 | """ 94 | Note that if after filtering some documents do not contain words we remove them. That is why we return also the 95 | list of unpreprocessed documents. 
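A minimal usage sketch (``en_docs``, ``de_docs`` and ``img_urls`` are hypothetical aligned lists of equal length):
sp = WhiteSpacePreprocessingM3L(documents=[en_docs, de_docs], image_urls=img_urls, stopwords_languages=["english", "german"])
preprocessed_docs, unpreprocessed_docs, vocabularies, kept_image_urls = sp.preprocess()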
96 | 97 | :return: preprocessed documents, unpreprocessed documents and the vocabulary list 98 | """ 99 | 100 | for l in range(self.num_lang): 101 | truncated_docs = [' '.join(doc.split()[:self.max_len]) for doc in self.documents[l]] 102 | self.documents[l] = truncated_docs 103 | 104 | preprocessed_docs_tmp = [] 105 | vocabulary = [] 106 | for l in range(self.num_lang): 107 | preprocessed_docs = [doc.lower() for doc in self.documents[l]] 108 | preprocessed_docs = [doc.translate( 109 | str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs] 110 | preprocessed_docs = [' '.join([w for w in doc.split() if len(w) > 2 and w not in self.stopwords[l]]) 111 | for doc in preprocessed_docs] 112 | vectorizer = CountVectorizer(max_features=self.vocabulary_size, token_pattern=r'\b[a-zA-Z]{2,}\b') 113 | vectorizer.fit_transform(preprocessed_docs) 114 | vocabulary_lang = set(vectorizer.get_feature_names()) 115 | preprocessed_docs = [' '.join([w for w in doc.split() if w in vocabulary_lang]) for doc in preprocessed_docs] 116 | preprocessed_docs_tmp.append(preprocessed_docs) 117 | vocabulary.append(list(vocabulary_lang)) 118 | # print('vocab size:', len(vocabulary_lang)) 119 | 120 | if self.num_lang == 1: 121 | preprocessed_data_final = [[]] 122 | unpreprocessed_data_final = [[]] 123 | image_urls_final = [] 124 | 125 | for i in range(len(preprocessed_docs_tmp[0])): 126 | doc1 = preprocessed_docs_tmp[0][i] 127 | image_url = self.image_urls[i] 128 | if self.min_len <= len(doc1.split()): 129 | preprocessed_data_final[0].append(doc1) 130 | 131 | unpreprocessed_data_final[0].append(self.documents[0][i]) 132 | 133 | image_urls_final.append(image_url) 134 | 135 | 136 | elif self.num_lang == 2: 137 | 138 | preprocessed_data_final = [[], []] 139 | unpreprocessed_data_final = [[], []] 140 | image_urls_final = [] 141 | # docs must be aligned across languages and modalities (text-image) 142 | for i in range(len(preprocessed_docs_tmp[0])): 143 | doc1 = preprocessed_docs_tmp[0][i] 144 | doc2 = preprocessed_docs_tmp[1][i] 145 | image_url = self.image_urls[i] 146 | if self.min_len <= len(doc1.split()) and self.min_len <= len(doc2.split()): 147 | preprocessed_data_final[0].append(doc1) 148 | preprocessed_data_final[1].append(doc2) 149 | 150 | unpreprocessed_data_final[0].append(self.documents[0][i]) 151 | unpreprocessed_data_final[1].append(self.documents[1][i]) 152 | 153 | image_urls_final.append(image_url) 154 | 155 | else: 156 | # TODO: rewrite in generic form for any number of languages 157 | raise NonImplementedError("Cannot process number of languages: %s" %self.num_lang) 158 | 159 | 160 | # preprocessed_data_final is a list of list of strings (processed articles) and image urls 161 | # unpreprocessed_data_final is a list of list of strings (original articles) and image urls 162 | # vocabulary is a list of list of words (separate vocabularies for each language) 163 | return preprocessed_data_final, unpreprocessed_data_final, vocabulary, image_urls_final 164 | 165 | 166 | # ----- Multilingual only ----- 167 | class WhiteSpacePreprocessingMultilingual(): 168 | """ 169 | Provides a very simple preprocessing script for aligned multilingual documents 170 | """ 171 | def __init__(self, documents, stopwords_languages, vocabulary_size=2000, min_len=10, custom_stops=None, max_len=200): 172 | """ 173 | 174 | :param documents: list of lists of strings, e.g. 
[['good morning', 'thank you'], ['guten morgen', 'danke sie']] 175 | :param stopwords_language: list of strings of the languages (see nltk stopwords) 176 | :param vocabulary_size: the number of most frequent words to include in the documents. Infrequent words will be discarded from the list of preprocessed documents 177 | """ 178 | self.documents = documents 179 | self.num_lang = len(stopwords_languages) 180 | self.languages = stopwords_languages 181 | self.stopwords = [] 182 | for lang in self.languages: 183 | self.stopwords.append(set(stop_words.words(lang))) 184 | # same vocab_size for all langs for now 185 | self.vocabulary_size = vocabulary_size 186 | # min/max article length after preprocessing 187 | self.min_len = min_len 188 | self.max_len = max_len 189 | # if user has custom stopwords list 190 | self.custom_stops = custom_stops 191 | 192 | def preprocess(self): 193 | """ 194 | Note that if after filtering some documents do not contain words we remove them. That is why we return also the 195 | list of unpreprocessed documents. 196 | 197 | :return: preprocessed documents, unpreprocessed documents and the vocabulary list 198 | """ 199 | # truncate raw articles to the first max_len tokens 200 | for l in range(self.num_lang): 201 | truncated_docs = [' '.join(doc.split()[:self.max_len]) for doc in self.documents[l]] 202 | self.documents[l] = truncated_docs 203 | 204 | preprocessed_docs_tmp = [] 205 | vocabulary = [] 206 | for l in range(self.num_lang): 207 | print("--- lang", l, ":", self.languages[l], "---") 208 | preprocessed_docs = [doc.lower() for doc in self.documents[l]] 209 | preprocessed_docs = [doc.translate( 210 | str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs] 211 | preprocessed_docs = [' '.join([w for w in doc.split() if len(w) > 2 and w not in self.stopwords[l]]) 212 | for doc in preprocessed_docs] 213 | if self.custom_stops is not None: 214 | preprocessed_docs = [' '.join([w for w in doc.split() if len(w) > 2 and w not in self.custom_stops[l]]) 215 | for doc in preprocessed_docs] 216 | vectorizer = CountVectorizer(max_features=self.vocabulary_size, token_pattern=r'\b[a-zA-Z]{2,}\b') 217 | vectorizer.fit_transform(preprocessed_docs) 218 | vocabulary_lang = set(vectorizer.get_feature_names()) 219 | print('vocabulary_lang:', len(vocabulary_lang)) 220 | preprocessed_docs = [' '.join([w for w in doc.split() if w in vocabulary_lang]) for doc in preprocessed_docs] 221 | preprocessed_docs_tmp.append(preprocessed_docs) 222 | vocabulary.append(list(vocabulary_lang)) 223 | # print('vocab size:', len(vocabulary_lang)) 224 | 225 | preprocessed_docs_final = [[], []] 226 | unpreprocessed_docs_final = [[], []] 227 | # docs must be aligned across languages 228 | for i in range(len(preprocessed_docs_tmp[0])): 229 | doc1 = preprocessed_docs_tmp[0][i] 230 | doc2 = preprocessed_docs_tmp[1][i] 231 | if self.min_len <= len(doc1.split()) and self.min_len <= len(doc2.split()): 232 | # truncate docs if they exceed max_len 233 | # if len(doc1.split()) > self.max_len: 234 | # doc1 = " ".join(doc1.split()[:self.max_len]) 235 | # if len(doc2.split()) > self.max_len: 236 | # doc2 = " ".join(doc2.split()[:self.max_len]) 237 | preprocessed_docs_final[0].append(doc1) 238 | preprocessed_docs_final[1].append(doc2) 239 | unpreprocessed_docs_final[0].append(self.documents[0][i]) 240 | unpreprocessed_docs_final[1].append(self.documents[1][i]) 241 | # preprocessed_docs_final is a list of list of strings (processed articles) 242 | # unpreprocessed_docs_final is a 
list of list of strings (original articles) 243 | # vocabulary is a list of list of words (separate vocabularies for each language) 244 | return preprocessed_docs_final, unpreprocessed_docs_final, vocabulary -------------------------------------------------------------------------------- /utils/data_preparation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sentence_transformers import SentenceTransformer 3 | import scipy.sparse 4 | import warnings 5 | from datasets.dataset import CTMDataset, M3LDataset, PLTMDataset 6 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer 7 | from sklearn.preprocessing import OneHotEncoder 8 | import pandas as pd 9 | 10 | def get_bag_of_words(data, min_length): 11 | """ 12 | Creates the bag of words 13 | """ 14 | vect = [np.bincount(x[x != np.array(None)].astype('int'), minlength=min_length) 15 | for x in data if np.sum(x[x != np.array(None)]) != 0] 16 | 17 | vect = scipy.sparse.csr_matrix(vect) 18 | return vect 19 | 20 | 21 | def bert_embeddings_from_file(text_file, sbert_model_to_load, batch_size=200, max_seq_length=128): 22 | """ 23 | Creates SBERT Embeddings from an input file 24 | """ 25 | model = SentenceTransformer(sbert_model_to_load) 26 | 27 | if max_seq_length is not None: 28 | model.max_seq_length = max_seq_length 29 | 30 | with open(text_file, encoding="utf-8") as filino: 31 | texts = list(map(lambda x: x, filino.readlines())) 32 | 33 | check_max_local_length(max_seq_length, texts) 34 | 35 | return np.array(model.encode(texts, show_progress_bar=True, batch_size=batch_size)) 36 | 37 | 38 | def bert_embeddings_from_list(texts, sbert_model_to_load, batch_size=200, max_seq_length=128): 39 | """ 40 | Creates SBERT Embeddings from a list 41 | """ 42 | print('bert_embeddings_from_list - texts:', len(texts)) 43 | model = SentenceTransformer(sbert_model_to_load) 44 | 45 | if max_seq_length is not None: 46 | model.max_seq_length = max_seq_length 47 | 48 | check_max_local_length(max_seq_length, texts) 49 | 50 | return np.array(model.encode(texts, show_progress_bar=True, batch_size=batch_size)) 51 | 52 | 53 | def check_max_local_length(max_seq_length, texts): 54 | max_local_length = np.max([len(t.split()) for t in texts]) 55 | if max_local_length > max_seq_length: 56 | warnings.simplefilter('always', DeprecationWarning) 57 | warnings.warn(f"the longest document in your collection has {max_local_length} words, the model instead " 58 | f"truncates to {max_seq_length} tokens.") 59 | 60 | 61 | def image_embeddings_from_file(image_urls, embedding_file): 62 | print('image_urls:', len(image_urls)) 63 | if 'resnet' in embedding_file: 64 | embeddings = pd.read_csv(embedding_file, header=None) 65 | # embeddings = embeddings.rename({0: 'image_url'}, axis=1) # doesn't work for me 66 | desired_indices = [embeddings[embeddings[0] == im].index[0] for im in image_urls] 67 | embeddings = embeddings.iloc[desired_indices] 68 | embeddings = np.array(embeddings.drop(columns=[0], axis=1)) 69 | else: 70 | embeddings = pd.read_csv(embedding_file) 71 | desired_indices = [embeddings[embeddings.image_url == im].index[0] for im in image_urls] 72 | embeddings = embeddings.iloc[desired_indices] 73 | embeddings = np.array(embeddings.drop(labels='image_url', axis=1)) 74 | print("image embeddings:", embeddings.shape) 75 | return embeddings 76 | 77 | 78 | # ----- Multimodal and Multilingual (M3L) ----- 79 | 80 | class M3LTopicModelDataPreparation: 81 | 82 | def __init__(self, 
contextualized_model=None, vocabularies=None, image_emb_file=None): 83 | self.contextualized_model = contextualized_model 84 | self.vocabularies = vocabularies 85 | self.id2token = [] 86 | self.vectorizers = [] 87 | self.label_encoder = None 88 | self.vocab_sizes = [] 89 | self.image_emb_file = image_emb_file 90 | 91 | def load(self, contextualized_embeddings, bow_embeddings, image_embeddings, id2token): 92 | return M3LDataset(contextualized_embeddings, bow_embeddings, image_embeddings, id2token) 93 | 94 | # fit is for training data 95 | def fit(self, text_for_contextual, text_for_bow, image_urls): 96 | """ 97 | This method fits the vectorizer and gets the embeddings from the contextual model 98 | 99 | :param text_for_contextual: list of list of unpreprocessed documents to generate the contextualized embeddings 100 | :param text_for_bow: list of list of preprocessed documents for creating the bag-of-words 101 | :param labels: list of labels associated with each document (optional). 102 | 103 | """ 104 | if self.contextualized_model is None: 105 | raise Exception("You should define a contextualized model if you want to create the embeddings") 106 | 107 | # TODO: this count vectorizer removes tokens that have len = 1, might be unexpected for the users 108 | train_image_embeddings = image_embeddings_from_file(image_urls, self.image_emb_file) 109 | print("train_image_embeddings:", train_image_embeddings.shape) 110 | 111 | num_lang = len(text_for_bow) 112 | train_bow_embeddings, train_contextualized_embeddings = [], [] 113 | for l in range(num_lang): 114 | print("----- lang:", l, "-----") 115 | vectorizer = CountVectorizer(vocabulary=self.vocabularies[l]) 116 | train_bow_embeddings_lang = vectorizer.fit_transform(text_for_bow[l]) 117 | # we use the same SBERT model for both languages (multilingual SBERT) 118 | if len(self.contextualized_model.split(",")) == 1: 119 | print('context_model:', self.contextualized_model) 120 | train_contextualized_embeddings_lang = bert_embeddings_from_list(text_for_contextual[l], self.contextualized_model) 121 | # or we can use different SBERT models per language (monolingual SBERTs) 122 | else: 123 | context_model = self.contextualized_model.split(",")[l].strip() 124 | print('context_model:', context_model) 125 | train_contextualized_embeddings_lang = bert_embeddings_from_list(text_for_contextual[l], context_model) 126 | print('train_bow_embeddings_lang:', train_bow_embeddings_lang.shape) 127 | train_bow_embeddings.append(train_bow_embeddings_lang) 128 | train_contextualized_embeddings.append(train_contextualized_embeddings_lang) 129 | print('train_contextualized_embeddings_lang:', train_contextualized_embeddings_lang.shape) 130 | vocab_lang = vectorizer.get_feature_names() 131 | self.vocab_sizes.append(len(self.vocabularies[l])) 132 | id2token_lang = {k: v for k, v in zip(range(0, len(self.vocabularies[l])), self.vocabularies[l])} 133 | self.id2token.append(id2token_lang) 134 | 135 | return M3LDataset(train_contextualized_embeddings, train_bow_embeddings, train_image_embeddings, self.id2token, num_lang) 136 | 137 | # transform is for data during inference--dataset is monolingual during inference 138 | def transform(self, text_for_contextual=None, image_urls=None, lang_index=0): 139 | """ 140 | This methods create the input for the prediction. Essentially, it creates the embeddings with the contextualized 141 | model of choice and with trained vectorizer. 
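A minimal sketch of the inference call pattern (assumes ``qt`` is an M3LTopicModelDataPreparation instance already fitted on the training data; ``test_docs_en`` and ``test_image_urls`` are hypothetical test inputs):
text_dataset = qt.transform(text_for_contextual=test_docs_en, lang_index=0)
image_dataset = qt.transform(image_urls=test_image_urls)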
142 | 143 | If text_for_bow is missing, it should be because we are using ZeroShotTM 144 | """ 145 | if self.contextualized_model is None: 146 | raise Exception("You should define a contextualized model if you want to create the embeddings") 147 | 148 | # get SBERT embeddings for contextualized 149 | if text_for_contextual is not None: 150 | if len(self.contextualized_model.split(",")) == 1: 151 | print('context_model:', self.contextualized_model) 152 | test_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, self.contextualized_model) 153 | else: 154 | context_model = self.contextualized_model.split(",")[lang_index].strip() 155 | print('context_model:', context_model) 156 | test_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, context_model) 157 | # create dummy matrix for Bow 158 | test_bow_embeddings = scipy.sparse.csr_matrix(np.zeros((len(text_for_contextual), 1))) 159 | else: 160 | test_contextualized_embeddings = None 161 | if image_urls is not None: 162 | test_image_embeddings = image_embeddings_from_file(image_urls, self.image_emb_file) 163 | # create dummy matrix for Bow 164 | test_bow_embeddings = scipy.sparse.csr_matrix(np.zeros((len(image_urls), 1))) 165 | else: 166 | test_image_embeddings = None 167 | 168 | return M3LDataset(test_contextualized_embeddings, test_bow_embeddings, test_image_embeddings, self.id2token, is_inference=True) 169 | 170 | # ---- Multilingual only ---- 171 | 172 | class MultilingualTopicModelDataPreparation: 173 | 174 | def __init__(self, contextualized_model=None, vocabularies=None): 175 | self.contextualized_model = contextualized_model 176 | self.vocabularies = vocabularies 177 | self.id2token = [] 178 | self.vectorizers = [] 179 | self.label_encoder = None 180 | self.vocab_sizes = [] 181 | 182 | def load(self, contextualized_embeddings, bow_embeddings, id2token, labels=None): 183 | return PLTMDataset(contextualized_embeddings, bow_embeddings, id2token, labels) 184 | 185 | # fit is for training data 186 | def fit(self, text_for_contextual, text_for_bow, labels=None): 187 | """ 188 | This method fits the vectorizer and gets the embeddings from the contextual model 189 | 190 | :param text_for_contextual: list of list of unpreprocessed documents to generate the contextualized embeddings 191 | :param text_for_bow: list of list of preprocessed documents for creating the bag-of-words 192 | :param labels: list of labels associated with each document (optional). 
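A minimal usage sketch (the SBERT model name is only an example; ``unpreprocessed_docs``, ``preprocessed_docs`` and ``vocabularies`` are the per-language lists produced by the preprocessing step):
qt = MultilingualTopicModelDataPreparation(contextualized_model="paraphrase-multilingual-MiniLM-L12-v2", vocabularies=vocabularies)
training_dataset = qt.fit(text_for_contextual=unpreprocessed_docs, text_for_bow=preprocessed_docs)
A comma-separated string of model names (one per language) can be passed instead to use a different monolingual SBERT model for each language.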
193 | 194 | """ 195 | 196 | if self.contextualized_model is None: 197 | raise Exception("You should define a contextualized model if you want to create the embeddings") 198 | 199 | # TODO: this count vectorizer removes tokens that have len = 1, might be unexpected for the users 200 | 201 | num_lang = len(text_for_bow) 202 | train_bow_embeddings, train_contextualized_embeddings = [], [] 203 | for l in range(num_lang): 204 | print("----- lang:", l, "-----") 205 | vectorizer = CountVectorizer(vocabulary=self.vocabularies[l]) 206 | train_bow_embeddings_lang = vectorizer.fit_transform(text_for_bow[l]) 207 | if len(self.contextualized_model.split(",")) == 1: 208 | print('context_model:', self.contextualized_model) 209 | train_contextualized_embeddings_lang = bert_embeddings_from_list(text_for_contextual[l], self.contextualized_model) 210 | else: 211 | context_model = self.contextualized_model.split(",")[l].strip() 212 | print('context_model:', context_model) 213 | train_contextualized_embeddings_lang = bert_embeddings_from_list(text_for_contextual[l], context_model) 214 | train_bow_embeddings.append(train_bow_embeddings_lang) 215 | train_contextualized_embeddings.append(train_contextualized_embeddings_lang) 216 | print('train_bow_embeddings_lang:', train_bow_embeddings_lang.shape) 217 | print('train_contextualized_embeddings_lang:', train_contextualized_embeddings_lang.shape) 218 | vocab_lang = vectorizer.get_feature_names() 219 | print('vocab_lang:', len(vocab_lang)) 220 | #self.vocab.append(vocab_lang) 221 | self.vocab_sizes.append(len(self.vocabularies[l])) 222 | id2token_lang = {k: v for k, v in zip(range(0, len(self.vocabularies[l])), self.vocabularies[l])} 223 | self.id2token.append(id2token_lang) 224 | 225 | # don't care about labels for now 226 | if labels: 227 | self.label_encoder = OneHotEncoder() 228 | encoded_labels = self.label_encoder.fit_transform(np.array([labels]).reshape(-1, 1)) 229 | else: 230 | encoded_labels = None 231 | 232 | return PLTMDataset(train_contextualized_embeddings, train_bow_embeddings, self.id2token, encoded_labels, num_lang) 233 | 234 | # transform is for data during inference--dataset is monolingual during inference 235 | def transform(self, text_for_contextual, lang_index=0): 236 | """ 237 | This methods create the input for the prediction. Essentially, it creates the embeddings with the contextualized 238 | model of choice and with trained vectorizer. 239 | 240 | If text_for_bow is missing, it should be because we are using ZeroShotTM 241 | """ 242 | 243 | 244 | if self.contextualized_model is None: 245 | raise Exception("You should define a contextualized model if you want to create the embeddings") 246 | 247 | # if text_for_bow is not None: 248 | # test_bow_embeddings = self.vectorizer.transform(text_for_bow) 249 | # else: 250 | # # dummy matrix 251 | # warnings.simplefilter('always', DeprecationWarning) 252 | # warnings.warn("The method did not have in input the text_for_bow parameter. 
This IS EXPECTED if you " 253 | # "are using ZeroShotTM in a cross-lingual setting") 254 | 255 | # create dummy matrix for Bow 256 | test_bow_embeddings = scipy.sparse.csr_matrix(np.zeros((len(text_for_contextual), 1))) 257 | #print('test_bow_embeddings:', test_bow_embeddings.shape) 258 | # get SBERT embeddings for contextualized 259 | if len(self.contextualized_model.split(",")) == 1: 260 | print('context_model:', self.contextualized_model) 261 | test_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, self.contextualized_model) 262 | else: 263 | context_model = self.contextualized_model.split(",")[lang_index].strip() 264 | print('context_model:', context_model) 265 | test_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, context_model) 266 | #test_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, self.contextualized_model) 267 | #print('test_contextualized_embeddings:', test_contextualized_embeddings.shape) 268 | 269 | return PLTMDataset(test_contextualized_embeddings, test_bow_embeddings, self.id2token, is_inference=True) 270 | 271 | 272 | # ---- Original ----- 273 | 274 | class TopicModelDataPreparation: 275 | 276 | def __init__(self, contextualized_model=None): 277 | self.contextualized_model = contextualized_model 278 | self.vocab = [] 279 | self.id2token = {} 280 | self.vectorizer = None 281 | self.label_encoder = None 282 | 283 | def load(self, contextualized_embeddings, bow_embeddings, id2token, labels=None): 284 | return CTMDataset(contextualized_embeddings, bow_embeddings, id2token, labels) 285 | 286 | def fit(self, text_for_contextual, text_for_bow, labels=None): 287 | """ 288 | This method fits the vectorizer and gets the embeddings from the contextual model 289 | 290 | :param text_for_contextual: list of unpreprocessed documents to generate the contextualized embeddings 291 | :param text_for_bow: list of preprocessed documents for creating the bag-of-words 292 | :param labels: list of labels associated with each document (optional). 293 | 294 | """ 295 | 296 | if self.contextualized_model is None: 297 | raise Exception("You should define a contextualized model if you want to create the embeddings") 298 | 299 | # TODO: this count vectorizer removes tokens that have len = 1, might be unexpected for the users 300 | self.vectorizer = CountVectorizer() 301 | 302 | train_bow_embeddings = self.vectorizer.fit_transform(text_for_bow) 303 | train_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, self.contextualized_model) 304 | self.vocab = self.vectorizer.get_feature_names() 305 | self.id2token = {k: v for k, v in zip(range(0, len(self.vocab)), self.vocab)} 306 | 307 | if labels: 308 | self.label_encoder = OneHotEncoder() 309 | encoded_labels = self.label_encoder.fit_transform(np.array([labels]).reshape(-1, 1)) 310 | else: 311 | encoded_labels = None 312 | 313 | return CTMDataset(train_contextualized_embeddings, train_bow_embeddings, self.id2token, encoded_labels) 314 | 315 | def transform(self, text_for_contextual, text_for_bow=None, labels=None): 316 | """ 317 | This methods create the input for the prediction. Essentially, it creates the embeddings with the contextualized 318 | model of choice and with trained vectorizer. 
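A minimal sketch (assumes ``tp`` was already fitted on the training texts with ``tp.fit(...)``; ``new_docs`` is a hypothetical list of unseen, unpreprocessed documents):
test_dataset = tp.transform(text_for_contextual=new_docs)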
319 | 320 | If text_for_bow is missing, it should be because we are using ZeroShotTM 321 | """ 322 | 323 | if self.contextualized_model is None: 324 | raise Exception("You should define a contextualized model if you want to create the embeddings") 325 | 326 | if text_for_bow is not None: 327 | test_bow_embeddings = self.vectorizer.transform(text_for_bow) 328 | else: 329 | # dummy matrix 330 | warnings.simplefilter('always', DeprecationWarning) 331 | warnings.warn("The method did not have in input the text_for_bow parameter. This IS EXPECTED if you " 332 | "are using ZeroShotTM in a cross-lingual setting") 333 | 334 | test_bow_embeddings = scipy.sparse.csr_matrix(np.zeros((len(text_for_contextual), 1))) 335 | test_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, self.contextualized_model) 336 | 337 | if labels: 338 | encoded_labels = self.label_encoder.transform(np.array([labels]).reshape(-1, 1)) 339 | else: 340 | encoded_labels = None 341 | 342 | return CTMDataset(test_contextualized_embeddings, test_bow_embeddings, self.id2token, encoded_labels) 343 | 344 | 345 | 346 | -------------------------------------------------------------------------------- /evaluation/measures.py: -------------------------------------------------------------------------------- 1 | from gensim.corpora.dictionary import Dictionary 2 | from gensim.models.coherencemodel import CoherenceModel 3 | from gensim.models import KeyedVectors 4 | import gensim.downloader as api 5 | import abc 6 | import numpy as np 7 | import itertools 8 | import torch 9 | from sentence_transformers import util 10 | from scipy.spatial.distance import cosine, jensenshannon 11 | from scipy.stats import entropy 12 | 13 | from evaluation.rbo import rbo 14 | 15 | # ----- JSD between two sets of topic distributions ----- 16 | class JSDivergence(abc.ABC): 17 | def __init__(self, topic_distrib1, topic_distrib2): 18 | """ 19 | :param doc_distribution_original_language: numpy array of the topical distribution of 20 | the documents in the original language (dim: num docs x num topics) 21 | :param doc_distribution_unseen_language: numpy array of the topical distribution of the 22 | documents in an unseen language (dim: num docs x num topics) 23 | """ 24 | super().__init__() 25 | self.topics1 = topic_distrib1 26 | self.topics2 = topic_distrib2 27 | if self.topics1.shape[0] != self.topics2.shape[0]: 28 | raise Exception('Distributions of the comparable documents must have the same length') 29 | 30 | def score(self): 31 | """ 32 | :return: average Jensen-Shannon Divergence between the distributions 33 | """ 34 | jsd2 = compute_jsd(self.topics1.T, self.topics2.T) 35 | jsd2[jsd2 == np.inf] = 0 36 | mean_div2 = np.mean(jsd2) 37 | 38 | return mean_div2 39 | 40 | 41 | def compute_jsd(p, q): 42 | p = np.asarray(p) 43 | q = np.asarray(q) 44 | p /= p.sum() 45 | q /= q.sum() 46 | m = (p + q) / 2 47 | return (entropy(p, m) + entropy(q, m)) / 2 48 | 49 | 50 | # ----- Text-image matching ----- 51 | class TextImageMatching(abc.ABC): 52 | """ 53 | :param doc_topics: matrix/tensor of L x D x K which contains the doc-topic distributions for docs in the test set where 54 | L: no. of languages (assume L=2 for now) 55 | D: no. fo docs 56 | K: no. 
of topics 57 | """ 58 | def __init__(self, doc_topics, image_topics, titles): 59 | self.doc_topics = doc_topics 60 | self.image_topics = image_topics 61 | self.titles = titles 62 | 63 | # compute MRR of first relevant image 64 | def mrr_first_score(self): 65 | # get unique article titles 66 | unique_titles = list(set(self.titles)) 67 | mrr_collect = [] 68 | for title in unique_titles: 69 | # get all image indices for this title 70 | true_indices = np.where(np.array(self.titles) == title)[0] 71 | doc_index = true_indices[0] 72 | # get JSD between this article and all images in the test data 73 | doc_dist = np.repeat([self.doc_topics[doc_index]], repeats=self.image_topics.shape[0], axis=0) 74 | scores = compute_jsd(doc_dist, self.image_topics) 75 | # rank images according to lowest JSD 76 | pred_indices = np.argsort(scores) 77 | for rank, index in enumerate(pred_indices): 78 | # found first relevant image 79 | if index in true_indices: 80 | mrr_first = 1/(rank+1) 81 | mrr_collect.append(mrr_first) 82 | break 83 | mrr_collect = np.mean(mrr_collect) 84 | return mrr_collect 85 | 86 | 87 | # compute UAP 88 | def uap_score(self): 89 | # get unique article titles 90 | unique_titles = list(set(self.titles)) 91 | print("unique articles:", len(unique_titles)) 92 | uap_collect = [] 93 | for title in unique_titles: 94 | # get all image indices for this title 95 | true_indices = np.where(np.array(self.titles) == title)[0] 96 | index_text = true_indices[0] 97 | # get JSD between this article and all images in the test data 98 | query_doc_dist = np.repeat([self.doc_topics[index_text]], repeats=self.image_topics.shape[0], axis=0) 99 | scores = compute_jsd(query_doc_dist.T, self.image_topics.T) 100 | # rank images according to lowest JSD 101 | pred_indices = np.argsort(scores) 102 | # find all positions where a relevant image is found and compute precision 103 | prec_values = [] 104 | for rank, index in enumerate(pred_indices): 105 | if index in true_indices: 106 | prec = (len(prec_values)+1)/(rank+1) 107 | prec_values.append(prec) 108 | mean_prec = np.mean(prec_values) 109 | uap_collect.append(mean_prec) 110 | uap = np.mean(uap_collect) 111 | return uap 112 | 113 | # ----- Cross-lingual Document Retrieval ----- 114 | class CrosslingualRetrieval(abc.ABC): 115 | """ 116 | :param doc_topics: matrix/tensor of L x D x K which contains the doc-topic distributions for docs in the test set where 117 | L: no. of languages (assume L=2 for now) 118 | D: no. fo docs 119 | K: no. 
of topics 120 | """ 121 | def __init__(self, doc_topics1, doc_topics2): 122 | self.doc_topics1 = doc_topics1 123 | self.doc_topics2 = doc_topics2 124 | 125 | # MRR to evaluate document retrieval performance 126 | def mrr_score(self): 127 | total_docs = self.doc_topics1.shape[0] 128 | total_MRR = 0 129 | for doc_index in range(total_docs): 130 | # get JSD between query doc in lang1 and candidate docs in lang2 131 | query_doc_distrib = np.repeat([self.doc_topics1[doc_index]], repeats=self.doc_topics2.shape[0], axis=0) 132 | scores = compute_jsd(query_doc_distrib.T, self.doc_topics2.T) 133 | # compute the MRR 134 | pred_indices = np.argsort(scores) 135 | matching_index = np.where((pred_indices == doc_index) == True)[0][0] 136 | # indices are zero-indexed but MRR assumes top position is index-1 so we add 1 to every index 137 | MRR = float(1/(matching_index+1)) 138 | total_MRR += MRR 139 | final_MRR = total_MRR / total_docs 140 | return final_MRR 141 | 142 | 143 | # ----- Original ----- 144 | class Measure: 145 | def __init__(self): 146 | pass 147 | 148 | def score(self): 149 | pass 150 | 151 | 152 | class TopicDiversity(Measure): 153 | def __init__(self, topics): 154 | super().__init__() 155 | self.topics = topics 156 | 157 | def score(self, topk=25): 158 | """ 159 | :param topk: topk words on which the topic diversity will be computed 160 | :return: 161 | """ 162 | if topk > len(self.topics[0]): 163 | raise Exception('Words in topics are less than topk') 164 | else: 165 | unique_words = set() 166 | for t in self.topics: 167 | unique_words = unique_words.union(set(t[:topk])) 168 | td = len(unique_words) / (topk * len(self.topics)) 169 | return td 170 | 171 | 172 | class Coherence(abc.ABC): 173 | """ 174 | :param topics: a list of lists of the top-k words 175 | :param texts: (list of lists of strings) represents the corpus on which the empirical frequencies of words are computed 176 | """ 177 | def __init__(self, topics, texts): 178 | self.topics = topics 179 | self.texts = texts 180 | self.dictionary = Dictionary(self.texts) 181 | 182 | @abc.abstractmethod 183 | def score(self): 184 | pass 185 | 186 | 187 | class CoherenceNPMI(Coherence): 188 | def __init__(self, topics, texts): 189 | super().__init__(topics, texts) 190 | 191 | def score(self, topk=10): 192 | """ 193 | :param topk: how many most likely words to consider in the evaluation 194 | :return: NPMI coherence 195 | """ 196 | if topk > len(self.topics[0]): 197 | raise Exception('Words in topics are less than topk') 198 | else: 199 | npmi = CoherenceModel(topics=self.topics, texts=self.texts, dictionary=self.dictionary, 200 | coherence='c_npmi', topn=topk) 201 | return npmi.get_coherence() 202 | 203 | 204 | class CoherenceUMASS(Coherence): 205 | def __init__(self, topics, texts): 206 | super().__init__(topics, texts) 207 | 208 | def score(self, topk=10): 209 | """ 210 | :param topk: how many most likely words to consider in the evaluation 211 | :return: UMass coherence 212 | """ 213 | if topk > len(self.topics[0]): 214 | raise Exception('Words in topics are less than topk') 215 | else: 216 | umass = CoherenceModel(topics=self.topics, texts=self.texts, dictionary=self.dictionary, 217 | coherence='u_mass', topn=topk) 218 | return umass.get_coherence() 219 | 220 | 221 | class CoherenceUCI(Coherence): 222 | def __init__(self, topics, texts): 223 | super().__init__(topics, texts) 224 | 225 | def score(self, topk=10): 226 | """ 227 | :param topk: how many most likely words to consider in the evaluation 228 | :return: UCI coherence 229 | """ 
230 | if topk > len(self.topics[0]): 231 | raise Exception('Words in topics are less than topk') 232 | else: 233 | uci = CoherenceModel(topics=self.topics, texts=self.texts, dictionary=self.dictionary, 234 | coherence='c_uci', topn=topk) 235 | return uci.get_coherence() 236 | 237 | 238 | class CoherenceCV(Coherence): 239 | def __init__(self, topics, texts): 240 | super().__init__(topics, texts) 241 | 242 | def score(self, topk=10): 243 | """ 244 | :param topk: how many most likely words to consider in the evaluation 245 | :return: C_V coherence 246 | """ 247 | if topk > len(self.topics[0]): 248 | raise Exception('Words in topics are less than topk') 249 | else: 250 | cv = CoherenceModel(topics=self.topics, texts=self.texts, dictionary=self.dictionary, 251 | coherence='c_v', topn=topk) 252 | return cv.get_coherence() 253 | 254 | 255 | class CoherenceWordEmbeddings(Measure): 256 | def __init__(self, topics, word2vec_path=None, binary=False): 257 | """ 258 | :param topics: a list of lists of the top-n most likely words 259 | :param word2vec_path: if word2vec_file is specified, it retrieves the word embeddings file (in word2vec format) to 260 | compute similarities between words, otherwise 'word2vec-google-news-300' is downloaded 261 | :param binary: if the word2vec file is binary 262 | """ 263 | super().__init__() 264 | self.topics = topics 265 | self.binary = binary 266 | if word2vec_path is None: 267 | self.wv = api.load('word2vec-google-news-300') 268 | else: 269 | self.wv = KeyedVectors.load_word2vec_format(word2vec_path, binary=binary) 270 | 271 | def score(self, topk=10, binary= False): 272 | """ 273 | :param topk: how many most likely words to consider in the evaluation 274 | :return: topic coherence computed on the word embeddings similarities 275 | """ 276 | if topk > len(self.topics[0]): 277 | raise Exception('Words in topics are less than topk') 278 | else: 279 | arrays = [] 280 | for index, topic in enumerate(self.topics): 281 | if len(topic) > 0: 282 | local_simi = [] 283 | for word1, word2 in itertools.combinations(topic[0:topk], 2): 284 | if word1 in self.wv.vocab and word2 in self.wv.vocab: 285 | local_simi.append(self.wv.similarity(word1, word2)) 286 | arrays.append(np.mean(local_simi)) 287 | return np.mean(arrays) 288 | 289 | 290 | class InvertedRBO(Measure): 291 | def __init__(self, topics): 292 | """ 293 | :param topics: a list of lists of words 294 | """ 295 | super().__init__() 296 | self.topics = topics 297 | 298 | def score(self, topk = 10, weight=0.9): 299 | """ 300 | :param weight: p (float), default 1.0: Weight of each agreement at depth d: 301 | p**(d-1). When set to 1.0, there is no weight, the rbo returns to average overlap. 
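A minimal usage sketch (``topics`` is a hypothetical list of top-word lists, one per topic):
irbo = InvertedRBO(topics)
diversity = irbo.score(topk=10, weight=0.9)
Values closer to 1 indicate more diverse (less overlapping) topics.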
302 | :return: rank_biased_overlap over the topics 303 | """ 304 | if topk > len(self.topics[0]): 305 | raise Exception('Words in topics are less than topk') 306 | else: 307 | collect = [] 308 | for list1, list2 in itertools.combinations(self.topics, 2): 309 | rbo_val = rbo.rbo(list1[:topk], list2[:topk], p=weight)[2] 310 | collect.append(rbo_val) 311 | return 1 - np.mean(collect) 312 | 313 | 314 | class Matches(Measure): 315 | def __init__(self, doc_distribution_original_language, doc_distribution_unseen_language): 316 | """ 317 | :param doc_distribution_original_language: numpy array of the topical distribution of 318 | the documents in the original language (dim: num docs x num topics) 319 | :param doc_distribution_unseen_language: numpy array of the topical distribution of the 320 | documents in an unseen language (dim: num docs x num topics) 321 | """ 322 | super().__init__() 323 | self.orig_lang_docs = doc_distribution_original_language 324 | self.unseen_lang_docs = doc_distribution_unseen_language 325 | if len(self.orig_lang_docs) != len(self.unseen_lang_docs): 326 | raise Exception('Distributions of the comparable documents must have the same length') 327 | 328 | def score(self): 329 | """ 330 | :return: proportion of matches between the predicted topic in the original language and 331 | the predicted topic in the unseen language of the document distributions 332 | """ 333 | matches = 0 334 | for d1, d2 in zip(self.orig_lang_docs, self.unseen_lang_docs): 335 | if np.argmax(d1) == np.argmax(d2): 336 | matches = matches + 1 337 | return matches/len(self.unseen_lang_docs) 338 | 339 | 340 | class KLDivergence(Measure): 341 | def __init__(self, doc_distribution_original_language, doc_distribution_unseen_language): 342 | """ 343 | :param doc_distribution_original_language: numpy array of the topical distribution of 344 | the documents in the original language (dim: num docs x num topics) 345 | :param doc_distribution_unseen_language: numpy array of the topical distribution of the 346 | documents in an unseen language (dim: num docs x num topics) 347 | """ 348 | super().__init__() 349 | self.orig_lang_docs = doc_distribution_original_language 350 | self.unseen_lang_docs = doc_distribution_unseen_language 351 | if len(self.orig_lang_docs) != len(self.unseen_lang_docs): 352 | raise Exception('Distributions of the comparable documents must have the same length') 353 | 354 | def score(self): 355 | """ 356 | :return: average kullback leibler divergence between the distributions 357 | """ 358 | kl_mean = 0 359 | for d1, d2 in zip(self.orig_lang_docs, self.unseen_lang_docs): 360 | kl_mean = kl_mean + kl_div(d1, d2) 361 | return kl_mean/len(self.unseen_lang_docs) 362 | 363 | 364 | def kl_div(a, b): 365 | a = np.asarray(a, dtype=np.float) 366 | b = np.asarray(b, dtype=np.float) 367 | return np.sum(np.where(a != 0, a * np.log(a / b), 0)) 368 | 369 | 370 | class CentroidDistance(Measure): 371 | def __init__(self, doc_distribution_original_language, doc_distribution_unseen_language, topics, word2vec_path=None, 372 | binary=True, topk=10): 373 | """ 374 | :param doc_distribution_original_language: numpy array of the topical distribution of the 375 | documents in the original language (dim: num docs x num topics) 376 | :param doc_distribution_unseen_language: numpy array of the topical distribution of the 377 | documents in an unseen language (dim: num docs x num topics) 378 | :param topics: a list of lists of the top-n most likely words 379 | :param word2vec_path: if word2vec_file is specified, it 
retrieves the word embeddings 380 | file (in word2vec format) to compute similarities between words, otherwise 381 | 'word2vec-google-news-300' is downloaded 382 | :param binary: if the word2vec file is binary 383 | :param topk: max number of topical words 384 | """ 385 | super().__init__() 386 | self.topics = [t[:topk] for t in topics] 387 | self.orig_lang_docs = doc_distribution_original_language 388 | self.unseen_lang_docs = doc_distribution_unseen_language 389 | if len(self.orig_lang_docs) != len(self.unseen_lang_docs): 390 | raise Exception('Distributions of the comparable documents must have the same length') 391 | 392 | if word2vec_path is None: 393 | self.wv = api.load('word2vec-google-news-300') 394 | else: 395 | self.wv = KeyedVectors.load_word2vec_format(word2vec_path, binary=binary) 396 | 397 | def score(self): 398 | """ 399 | :return: average centroid distance between the words of the most likely topic of the 400 | document distributions 401 | """ 402 | cd = 0 403 | for d1, d2 in zip(self.orig_lang_docs, self.unseen_lang_docs): 404 | top_words_orig = self.topics[np.argmax(d1)] 405 | top_words_unseen = self.topics[np.argmax(d2)] 406 | 407 | centroid_lang = self.get_centroid(top_words_orig) 408 | centroid_en = self.get_centroid(top_words_unseen) 409 | 410 | cd += (1 - cosine(centroid_lang, centroid_en)) 411 | return cd/len(self.unseen_lang_docs) 412 | 413 | def get_centroid(self, word_list): 414 | vector_list = [] 415 | for word in word_list: 416 | if word in self.wv.vocab: 417 | vector_list.append(self.wv.get_vector(word)) 418 | vec = sum(vector_list) 419 | return vec / np.linalg.norm(vec) 420 | 421 | -------------------------------------------------------------------------------- /models/multilingual_contrast.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import multiprocessing as mp 3 | import os 4 | import warnings 5 | from collections import defaultdict 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | import wordcloud 10 | from scipy.special import softmax 11 | from torch import optim 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | from torch.utils.data import DataLoader 14 | import torch.nn.functional as F 15 | from tqdm import tqdm 16 | from contextualized_topic_models.utils.early_stopping.early_stopping import EarlyStopping 17 | 18 | # decooder network 19 | from contextualized_topic_models.networks.decoding_network import ContrastiveDecoderNetwork 20 | 21 | # for contrastive loss 22 | from pytorch_metric_learning.losses import NTXentLoss 23 | 24 | class MultilingualContrastiveTM: 25 | """Class to train the contextualized topic model. This is the more general class that we are keeping to 26 | avoid braking code, users should use the two subclasses ZeroShotTM and CombinedTm to do topic modeling. 
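    A minimal usage sketch, assuming `vocab_size`, `train_dataset`, `val_dataset`
    and `test_dataset` are hypothetical placeholders built with the dataset and
    preprocessing utilities in this repository::

        ctm = MultilingualContrastiveTM(bow_size=vocab_size,
                                        contextual_size=768,   # e.g. sentence-embedding dim
                                        n_components=50,
                                        languages=['en', 'de'])
        ctm.fit(train_dataset, validation_dataset=val_dataset, save_dir="checkpoints/")
        en_topics, de_topics = ctm.get_topic_lists(k=10)
        theta_en = ctm.get_doc_topic_distribution(test_dataset, n_samples=20, lang_index=0)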
27 | 28 | :param bow_size: int, dimension of input 29 | :param contextual_size: int, dimension of input that comes from BERT embeddings 30 | :param inference_type: string, you can choose between the contextual model and the combined model 31 | :param n_components: int, number of topic components, (default 10) 32 | :param model_type: string, 'prodLDA' or 'LDA' (default 'prodLDA') 33 | :param hidden_sizes: tuple, length = n_layers, (default (100, 100)) 34 | :param activation: string, 'softplus', 'relu', (default 'softplus') 35 | :param dropout: float, dropout to use (default 0.2) 36 | :param learn_priors: bool, make priors a learnable parameter (default True) 37 | :param batch_size: int, size of batch to use for training (default 64) 38 | :param lr: float, learning rate to use for training (default 2e-3) 39 | :param momentum: float, momentum to use for training (default 0.99) 40 | :param solver: string, optimizer 'adam' or 'sgd' (default 'adam') 41 | :param num_epochs: int, number of epochs to train for, (default 100) 42 | :param reduce_on_plateau: bool, reduce learning rate by 10x on plateau of 10 epochs (default False) 43 | :param num_data_loader_workers: int, number of data loader workers (default cpu_count). set it to 0 if you are using Windows 44 | :param label_size: int, number of total labels (default: 0) 45 | :param loss_weights: dict, it contains the name of the weight parameter (key) and the weight (value) for each loss. 46 | It supports only the weight parameter beta for now. If None, then the weights are set to 1 (default: None). 47 | 48 | """ 49 | 50 | def __init__(self, bow_size, contextual_size, n_components=10, model_type='prodLDA', 51 | hidden_sizes=(100, 100), activation='softplus', dropout=0.2, learn_priors=True, batch_size=16, 52 | lr=2e-3, momentum=0.99, solver='adam', num_epochs=100, reduce_on_plateau=False, 53 | num_data_loader_workers=mp.cpu_count(), label_size=0, loss_weights=None, languages=None): 54 | 55 | self.device = ( 56 | torch.device("cuda") 57 | if torch.cuda.is_available() 58 | else torch.device("cpu") 59 | ) 60 | 61 | assert isinstance(bow_size, int) and bow_size > 0, \ 62 | "input_size must by type int > 0." 63 | assert isinstance(n_components, int) and bow_size > 0, \ 64 | "n_components must by type int > 0." 65 | assert model_type in ['LDA', 'prodLDA'], \ 66 | "model must be 'LDA' or 'prodLDA'." 67 | assert isinstance(hidden_sizes, tuple), \ 68 | "hidden_sizes must be type tuple." 69 | assert activation in ['softplus', 'relu'], \ 70 | "activation must be 'softplus' or 'relu'." 71 | assert dropout >= 0, "dropout must be >= 0." 72 | assert isinstance(learn_priors, bool), "learn_priors must be boolean." 73 | assert isinstance(batch_size, int) and batch_size > 0, \ 74 | "batch_size must be int > 0." 75 | assert lr > 0, "lr must be > 0." 76 | assert isinstance(momentum, float) and 0 < momentum <= 1, \ 77 | "momentum must be 0 < float <= 1." 78 | assert solver in ['adam', 'sgd'], "solver must be 'adam' or 'sgd'." 79 | assert isinstance(reduce_on_plateau, bool), \ 80 | "reduce_on_plateau must be type bool." 81 | assert isinstance(num_data_loader_workers, int) and num_data_loader_workers >= 0, \ 82 | "num_data_loader_workers must by type int >= 0. set 0 if you are using windows" 83 | 84 | # langs is an array of language codes e.g. 
['en', 'de'] 85 | self.languages = languages 86 | self.num_lang = len(languages) 87 | # bow_size is an array of size n_languages; one bow_size for each language 88 | self.bow_size = bow_size 89 | #n_components is same for all languages 90 | self.n_components = n_components 91 | #model_type is same for all languages 92 | self.model_type = model_type 93 | #hidden_sizes is same for all languages 94 | self.hidden_sizes = hidden_sizes 95 | #activation is same for all languages 96 | self.activation = activation 97 | #dropout is same for all langs 98 | self.dropout = dropout 99 | #learn_priors is same for all langs 100 | self.learn_priors = learn_priors 101 | #batch_size is same for all langs 102 | self.batch_size = batch_size 103 | #lr is same for all langs 104 | self.lr = lr 105 | #contextual_size is same for all langs 106 | self.contextual_size = contextual_size 107 | # same 108 | self.momentum = momentum 109 | # same 110 | self.solver = solver 111 | # same 112 | self.num_epochs = num_epochs 113 | # same 114 | self.reduce_on_plateau = reduce_on_plateau 115 | # name 116 | self.num_data_loader_workers = num_data_loader_workers 117 | 118 | # same for now 119 | if loss_weights: 120 | self.weights = loss_weights 121 | else: 122 | self.weights = {"KL": 0.01, "CL": 50} 123 | 124 | # contrastive decoder 125 | self.model = ContrastiveDecoderNetwork( 126 | bow_size, self.contextual_size, n_components, model_type, hidden_sizes, activation, 127 | dropout, learn_priors, label_size=label_size) 128 | 129 | self.early_stopping = None 130 | 131 | # init optimizers 132 | if self.solver == 'adam': 133 | self.optimizer = optim.Adam( 134 | self.model.parameters(), lr=lr, betas=(self.momentum, 0.99)) 135 | elif self.solver == 'sgd': 136 | self.optimizer = optim.SGD( 137 | self.model.parameters(), lr=lr, momentum=self.momentum) 138 | 139 | # init lr scheduler 140 | if self.reduce_on_plateau: 141 | self.scheduler = ReduceLROnPlateau(self.optimizer, patience=10) 142 | 143 | # performance attributes 144 | self.best_loss_train = float('inf') 145 | 146 | # training attributes 147 | self.model_dir = None 148 | self.train_data = None 149 | self.nn_epoch = None 150 | 151 | # validation attributes 152 | self.validation_data = None 153 | 154 | # learned topics 155 | # best_components: n_components x vocab_size 156 | self.best_components = None 157 | 158 | # Use cuda if available 159 | if torch.cuda.is_available(): 160 | self.USE_CUDA = True 161 | else: 162 | self.USE_CUDA = False 163 | 164 | self.model = self.model.to(self.device) 165 | 166 | def _infoNCE_loss(self, embeddings1, embeddings2, temperature=0.07): 167 | batch_size = embeddings1.shape[0] 168 | labels = torch.arange(batch_size) 169 | labels = torch.cat([labels, labels]) 170 | embeddings_cat = torch.cat([embeddings1, embeddings2]) 171 | loss_func = NTXentLoss() 172 | infonce_loss = loss_func(embeddings_cat, labels) 173 | return infonce_loss 174 | 175 | def _kl_loss1(self, thetas1, thetas2): 176 | theta_kld = F.kl_div(thetas1.log(), thetas2, reduction='sum') 177 | return theta_kld 178 | 179 | def _kl_loss2(self, prior_mean, prior_variance, 180 | posterior_mean, posterior_variance, posterior_log_variance): 181 | # KL term 182 | # var division term 183 | var_division = torch.sum(posterior_variance / prior_variance, dim=1) 184 | # diff means term 185 | diff_means = prior_mean - posterior_mean 186 | diff_term = torch.sum( 187 | (diff_means * diff_means) / prior_variance, dim=1) 188 | # logvar det division term 189 | logvar_det_division = \ 190 | 
prior_variance.log().sum() - posterior_log_variance.sum(dim=1) 191 | # combine terms 192 | KL = 0.5 * ( 193 | var_division + diff_term - self.n_components + logvar_det_division) 194 | return KL 195 | 196 | def _rl_loss(self, true_word_dists, pred_word_dists): 197 | # Reconstruction term 198 | RL = -torch.sum(true_word_dists * torch.log(pred_word_dists + 1e-10), dim=1) 199 | return RL 200 | 201 | def _train_epoch(self, loader): 202 | """Train epoch.""" 203 | self.model.train() 204 | train_loss = 0 205 | samples_processed = 0 206 | 207 | for batch_num, batch_samples in enumerate(loader): 208 | # batch_size x L x vocab_size 209 | X_bow = batch_samples['X_bow'] 210 | X_bow = X_bow.squeeze(dim=2) 211 | 212 | # batch_size x L x bert_size 213 | X_contextual = batch_samples['X_contextual'] 214 | #print('X_contextual:', X_contextual.shape) 215 | 216 | if self.USE_CUDA: 217 | X_bow = X_bow.cuda() 218 | X_contextual = X_contextual.cuda() 219 | 220 | # forward pass 221 | self.model.zero_grad() 222 | prior_mean, prior_variance, posterior_mean1, posterior_variance1, posterior_log_variance1,\ 223 | posterior_mean2, posterior_variance2, posterior_log_variance2, word_dists, thetas, z_samples = self.model(X_bow, X_contextual) 224 | 225 | # backward pass 226 | 227 | # recon_losses for each language 228 | rl_loss1 = self._rl_loss(X_bow[:,0,:], word_dists[0]) 229 | rl_loss2 = self._rl_loss(X_bow[:,1,:], word_dists[1]) 230 | 231 | # KL between distributions of every language pair 232 | kl_cross = self._kl_loss2(posterior_mean1, posterior_variance1, 233 | posterior_mean2, posterior_variance2, posterior_log_variance2) 234 | 235 | # InfoNCE loss/NTXentLoss 236 | infoNCE_cross = self._infoNCE_loss(thetas[0], thetas[1]) 237 | 238 | loss = rl_loss1 + rl_loss2 + \ 239 | self.weights["KL"]*kl_cross + \ 240 | self.weights["CL"]*infoNCE_cross 241 | loss = loss.sum() 242 | loss.backward() 243 | self.optimizer.step() 244 | 245 | # compute train loss 246 | samples_processed += X_bow.size()[0] 247 | train_loss += loss.item() 248 | 249 | train_loss /= samples_processed 250 | 251 | return samples_processed, train_loss 252 | 253 | def fit(self, train_dataset, validation_dataset=None, save_dir=None, verbose=False, patience=5, delta=0): 254 | """ 255 | Train the CTM model. 256 | 257 | :param train_dataset: PyTorch Dataset class for training data. 258 | :param validation_dataset: PyTorch Dataset class for validation data. If not None, the training stops if validation loss doesn't improve after a given patience 259 | :param save_dir: directory to save checkpoint models to. 260 | :param verbose: verbose 261 | :param patience: How long to wait after last time validation loss improved. Default: 5 262 | :param delta: Minimum change in the monitored quantity to qualify as an improvement. Default: 0 263 | 264 | """ 265 | # Print settings to output file 266 | if verbose: 267 | print("Settings: \n\ 268 | N Components: {}\n\ 269 | Topic Prior Mean: {}\n\ 270 | Topic Prior Variance: {}\n\ 271 | Model Type: {}\n\ 272 | Hidden Sizes: {}\n\ 273 | Activation: {}\n\ 274 | Dropout: {}\n\ 275 | Learn Priors: {}\n\ 276 | Learning Rate: {}\n\ 277 | Momentum: {}\n\ 278 | Reduce On Plateau: {}\n\ 279 | Save Dir: {}".format( 280 | self.n_components, 0.0, 281 | 1. - (1. 
/ self.n_components), self.model_type, 282 | self.hidden_sizes, self.activation, self.dropout, self.learn_priors, 283 | self.lr, self.momentum, self.reduce_on_plateau, save_dir)) 284 | 285 | self.model_dir = save_dir 286 | self.train_data = train_dataset 287 | self.validation_data = validation_dataset 288 | if self.validation_data is not None: 289 | self.early_stopping = EarlyStopping(patience=patience, verbose=verbose, path=save_dir, delta=delta) 290 | train_loader = DataLoader( 291 | self.train_data, batch_size=self.batch_size, shuffle=True, 292 | num_workers=self.num_data_loader_workers) 293 | 294 | # init training variables 295 | train_loss = 0 296 | samples_processed = 0 297 | 298 | # train loop 299 | pbar = tqdm(self.num_epochs, position=0, leave=True) 300 | for epoch in range(self.num_epochs): 301 | print("-"*10, "Epoch", epoch+1, "-"*10) 302 | self.nn_epoch = epoch 303 | # train epoch 304 | s = datetime.datetime.now() 305 | sp, train_loss = self._train_epoch(train_loader) 306 | samples_processed += sp 307 | e = datetime.datetime.now() 308 | pbar.update(1) 309 | 310 | if self.validation_data is not None: 311 | validation_loader = DataLoader(self.validation_data, batch_size=self.batch_size, shuffle=True, 312 | num_workers=self.num_data_loader_workers) 313 | # train epoch 314 | s = datetime.datetime.now() 315 | val_samples_processed, val_loss = self._validation(validation_loader) 316 | e = datetime.datetime.now() 317 | 318 | # report 319 | if verbose: 320 | print("Epoch: [{}/{}]\tSamples: [{}/{}]\tValidation Loss: {}\tTime: {}".format( 321 | epoch + 1, self.num_epochs, val_samples_processed, 322 | len(self.validation_data) * self.num_epochs, val_loss, e - s)) 323 | 324 | pbar.set_description("Epoch: [{}/{}]\t Seen Samples: [{}/{}]\tTrain Loss: {}\tValid Loss: {}\tTime: {}".format( 325 | epoch + 1, self.num_epochs, samples_processed, 326 | len(self.train_data) * self.num_epochs, train_loss, val_loss, e - s)) 327 | 328 | self.early_stopping(val_loss, self) 329 | if self.early_stopping.early_stop: 330 | print("Early stopping") 331 | 332 | break 333 | else: 334 | # save last epoch 335 | self.best_components = self.model.beta 336 | if save_dir is not None: 337 | self.save(save_dir) 338 | pbar.set_description("Epoch: [{}/{}]\t Seen Samples: [{}/{}]\tTrain Loss: {}\tTime: {}".format( 339 | epoch + 1, self.num_epochs, samples_processed, 340 | len(self.train_data) * self.num_epochs, train_loss, e - s)) 341 | 342 | #print topics for every epoch 343 | topics = self.get_topics() 344 | for l in range(self.num_lang): 345 | print("-"*10, self.languages[l].upper(), "-"*10) 346 | for k in range(self.n_components): 347 | print("Topic", k+1, ":", ', '.join(topics[l][k])) 348 | 349 | pbar.close() 350 | 351 | def _validation(self, loader): 352 | """Validation epoch.""" 353 | self.model.eval() 354 | val_loss = 0 355 | samples_processed = 0 356 | for batch_samples in loader: 357 | # batch_size x L x vocab_size 358 | X_bow = batch_samples['X_bow'] 359 | print('X_bow orig:', X_bow.shape) 360 | X_bow = X_bow.squeeze(dim=2) 361 | print('X_bow squeeze:', X_bow.shape) 362 | 363 | # batch_size x L x bert_size 364 | X_contextual = batch_samples['X_contextual'] 365 | print('X_contextual:', X_contextual.shape) 366 | 367 | if self.USE_CUDA: 368 | X_bow = X_bow.cuda() 369 | X_contextual = X_contextual.cuda() 370 | 371 | # forward pass 372 | self.model.zero_grad() 373 | posterior_mean1, posterior_variance1, posterior_mean2, posterior_variance2,\ 374 | posterior_log_variance2, word_dists, thetas = 
self.model(X_contextual) 375 | 376 | # backward pass 377 | 378 | # RL for language1 379 | rl_loss1 = self._rl_loss(X_bow[:,0,:], word_dists[0]) 380 | 381 | # RL for language2 382 | rl_loss2 = self._rl_loss(X_bow[:,1,:], word_dists[1]) 383 | 384 | # KL and contrastive loss 385 | kl_loss = self._kl_loss(posterior_mean1, posterior_variance1, posterior_mean2, posterior_variance2, 386 | posterior_log_variance2) 387 | cl_loss = self._contrastive_loss(thetas[0], thetas[1]) 388 | 389 | loss = self.weights["beta"]*kl_loss + rl_loss1 + rl_loss2 + cl_loss 390 | loss = loss.sum() 391 | 392 | # compute train loss 393 | # samples_processed += X_bow.size()[0] 394 | samples_processed += X_bow.size()[0] 395 | val_loss += loss.item() 396 | 397 | val_loss /= samples_processed 398 | 399 | return samples_processed, val_loss 400 | 401 | def get_thetas(self, dataset, n_samples=20): 402 | """ 403 | Get the document-topic distribution for a dataset of topics. Includes multiple sampling to reduce variation via 404 | the parameter n_sample. 405 | 406 | :param dataset: a PyTorch Dataset containing the documents 407 | :param n_samples: the number of sample to collect to estimate the final distribution (the more the better). 408 | """ 409 | return self.get_doc_topic_distribution(dataset, n_samples=n_samples) 410 | 411 | def get_doc_topic_distribution(self, dataset, n_samples=20, lang_index=0): 412 | """ 413 | Get the document-topic distribution for a dataset of topics. Includes multiple sampling to reduce variation via 414 | the parameter n_sample. 415 | 416 | :param dataset: a PyTorch Dataset containing the documents 417 | :param n_samples: the number of sample to collect to estimate the final distribution (the more the better). 418 | """ 419 | self.model.eval() 420 | 421 | loader = DataLoader( 422 | dataset, batch_size=self.batch_size, shuffle=False, 423 | num_workers=self.num_data_loader_workers) 424 | pbar = tqdm(n_samples, position=0, leave=True) 425 | final_thetas = [] 426 | for sample_index in range(n_samples): 427 | with torch.no_grad(): 428 | collect_theta = [] 429 | 430 | for batch_samples in loader: 431 | # batch_size x vocab_size 432 | X_bow = batch_samples['X_bow'] 433 | #print('X_bow orig:', X_bow.shape) 434 | X_bow = X_bow.squeeze(dim=1) 435 | #print('X_bow:', X_bow.shape) 436 | 437 | # batch_size x L x bert_size 438 | X_contextual = batch_samples['X_contextual'] 439 | #print('X_contextual:', X_contextual.shape) 440 | 441 | if self.USE_CUDA: 442 | X_bow = X_bow.cuda() 443 | X_contextual = X_contextual.cuda() 444 | 445 | # forward pass 446 | self.model.zero_grad() 447 | thetas = self.model.get_theta(X_bow, X_contextual, lang_index) 448 | collect_theta.extend(thetas.detach().cpu().numpy()) 449 | 450 | pbar.update(1) 451 | pbar.set_description("Sampling: [{}/{}]".format(sample_index + 1, n_samples)) 452 | 453 | final_thetas.append(np.array(collect_theta)) 454 | pbar.close() 455 | return np.sum(final_thetas, axis=0) / n_samples 456 | 457 | def get_most_likely_topic(self, doc_topic_distribution): 458 | """ get the most likely topic for each document 459 | 460 | :param doc_topic_distribution: ndarray representing the topic distribution of each document 461 | """ 462 | return np.argmax(doc_topic_distribution, axis=0) 463 | 464 | 465 | def get_topics(self, k=10): 466 | """ 467 | Retrieve topic words. 468 | 469 | :param k: int, number of words to return per topic, default 10. 470 | """ 471 | assert k <= self.bow_size, "k must be <= input size." 
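# ---------------------------------------------------------------------------
# Illustrative sketch: feeding the per-language document-topic distributions
# produced by this class into the cross-lingual measures in
# evaluation/measures.py. The import path and `test_dataset` are assumptions,
# not symbols defined in this file.
def example_crosslingual_eval(ctm, test_dataset):
    """Illustrative only: agreement between the two languages' topic predictions."""
    from evaluation.measures import Matches, KLDivergence

    theta_lang1 = ctm.get_doc_topic_distribution(test_dataset, n_samples=20, lang_index=0)
    theta_lang2 = ctm.get_doc_topic_distribution(test_dataset, n_samples=20, lang_index=1)

    matches = Matches(theta_lang1, theta_lang2).score()       # share of identical argmax topics
    mean_kl = KLDivergence(theta_lang1, theta_lang2).score()  # mean KL between paired distributions
    return matches, mean_kl
# ---------------------------------------------------------------------------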
472 | component_dists = self.best_components 473 | #print('component_dists:', component_dists.shape) 474 | topics_all = [] 475 | for l in range(self.num_lang): 476 | topics = defaultdict(list) 477 | for i in range(self.n_components): 478 | _, idxs = torch.topk(component_dists[l][i], k) 479 | component_words = [self.train_data.idx2token[l][idx] 480 | for idx in idxs.cpu().numpy()] 481 | topics[i] = component_words 482 | topics_all.append(topics) 483 | return topics_all 484 | 485 | def get_topic_lists(self, k=10): 486 | """ 487 | Retrieve the lists of topic words. 488 | 489 | :param k: (int) number of words to return per topic, default 10. 490 | """ 491 | assert k <= self.bow_size, "k must be <= input size." 492 | # TODO: collapse this method with the one that just returns the topics 493 | component_dists = self.best_components 494 | topics_all = [] 495 | for l in range(self.num_lang): 496 | topics = [] 497 | for i in range(self.n_components): 498 | _, idxs = torch.topk(component_dists[l][i], k) 499 | component_words = [self.train_data.idx2token[l][idx] 500 | for idx in idxs.cpu().numpy()] 501 | topics.append(component_words) 502 | topics_all.append(topics) 503 | return topics_all 504 | 505 | def _format_file(self): 506 | model_dir = "contextualized_topic_model_nc_{}_tpm_{}_tpv_{}_hs_{}_ac_{}_do_{}_lr_{}_mo_{}_rp_{}". \ 507 | format(self.n_components, 0.0, 1 - (1. / self.n_components), 508 | self.model_type, self.hidden_sizes, self.activation, 509 | self.dropout, self.lr, self.momentum, 510 | self.reduce_on_plateau) 511 | return model_dir 512 | 513 | def save(self, models_dir=None): 514 | """ 515 | Save model. (Experimental Feature, not tested) 516 | 517 | :param models_dir: path to directory for saving NN models. 518 | """ 519 | warnings.simplefilter('always', Warning) 520 | warnings.warn("This is an experimental feature that we has not been fully tested. Refer to the following issue:" 521 | "https://github.com/MilaNLProc/contextualized-topic-models/issues/38", 522 | Warning) 523 | 524 | if (self.model is not None) and (models_dir is not None): 525 | 526 | model_dir = self._format_file() 527 | if not os.path.isdir(os.path.join(models_dir, model_dir)): 528 | os.makedirs(os.path.join(models_dir, model_dir)) 529 | 530 | filename = "epoch_{}".format(self.nn_epoch) + '.pth' 531 | fileloc = os.path.join(models_dir, model_dir, filename) 532 | with open(fileloc, 'wb') as file: 533 | torch.save({'state_dict': self.model.state_dict(), 534 | 'dcue_dict': self.__dict__}, file) 535 | 536 | def load(self, model_dir, epoch): 537 | """ 538 | Load a previously trained model. (Experimental Feature, not tested) 539 | 540 | :param model_dir: directory where models are saved. 541 | :param epoch: epoch of model to load. 542 | """ 543 | 544 | warnings.simplefilter('always', Warning) 545 | warnings.warn("This is an experimental feature that we has not been fully tested. Refer to the following issue:" 546 | "https://github.com/MilaNLProc/contextualized-topic-models/issues/38", 547 | Warning) 548 | 549 | epoch_file = "epoch_" + str(epoch) + ".pth" 550 | model_file = os.path.join(model_dir, epoch_file) 551 | with open(model_file, 'rb') as model_dict: 552 | checkpoint = torch.load(model_dict) 553 | 554 | for (k, v) in checkpoint['dcue_dict'].items(): 555 | setattr(self, k, v) 556 | 557 | self.model.load_state_dict(checkpoint['state_dict']) 558 | 559 | def get_topic_word_matrix(self): 560 | """ 561 | Return the topic-word matrix (dimensions: number of topics x length of the vocabulary). 
562 | If model_type is LDA, the matrix is normalized; otherwise the matrix is unnormalized. 563 | """ 564 | return self.model.topic_word_matrix.cpu().detach().numpy() 565 | 566 | def get_topic_word_distribution(self): 567 | """ 568 | Return the topic-word distribution (dimensions: number of topics x length of the vocabulary). 569 | """ 570 | mat = self.get_topic_word_matrix() 571 | return softmax(mat, axis=1) 572 | 573 | def get_word_distribution_by_topic_id(self, topic): 574 | """ 575 | Return the word probability distribution of a topic sorted by probability. 576 | 577 | :param topic: id of the topic (int) 578 | 579 | :returns list of tuples (word, probability) sorted by the probability in descending order 580 | """ 581 | if topic >= self.n_components: 582 | raise Exception('Topic id must be lower than the number of topics') 583 | else: 584 | wd = self.get_topic_word_distribution() 585 | t = [(word, wd[topic][idx]) for idx, word in self.train_data.idx2token.items()] 586 | t = sorted(t, key=lambda x: -x[1]) 587 | return t 588 | 589 | def get_wordcloud(self, topic_id, n_words=5, background_color="black", width=1000, height=400): 590 | """ 591 | Plotting the wordcloud. It is an adapted version of the code found here: 592 | http://amueller.github.io/word_cloud/auto_examples/simple.html#sphx-glr-auto-examples-simple-py and 593 | here https://github.com/ddangelov/Top2Vec/blob/master/top2vec/Top2Vec.py 594 | 595 | :param topic_id: id of the topic 596 | :param n_words: number of words to show in word cloud 597 | :param background_color: color of the background 598 | :param width: width of the produced image 599 | :param height: height of the produced image 600 | """ 601 | word_score_list = self.get_word_distribution_by_topic_id(topic_id)[:n_words] 602 | word_score_dict = {tup[0]: tup[1] for tup in word_score_list} 603 | plt.figure(figsize=(10, 4), dpi=200) 604 | plt.axis("off") 605 | plt.imshow(wordcloud.WordCloud(width=width, height=height, background_color=background_color 606 | ).generate_from_frequencies(word_score_dict)) 607 | plt.title("Displaying Topic " + str(topic_id), loc='center', fontsize=24) 608 | plt.show() 609 | 610 | def get_predicted_topics(self, dataset, n_samples): 611 | """ 612 | Return the a list containing the predicted topic for each document (length: number of documents). 
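        A minimal call sketch (with `ctm` and `test_dataset` as hypothetical placeholders)::

            topic_ids = ctm.get_predicted_topics(test_dataset, n_samples=20)
            # topic_ids[i] is the index of the most probable topic for document i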
613 | 614 | :param dataset: CTMDataset to infer topics 615 | :param n_samples: number of sampling of theta 616 | :return: the predicted topics 617 | """ 618 | predicted_topics = [] 619 | thetas = self.get_doc_topic_distribution(dataset, n_samples) 620 | 621 | for idd in range(len(dataset)): 622 | predicted_topic = np.argmax(thetas[idd] / np.sum(thetas[idd])) 623 | predicted_topics.append(predicted_topic) 624 | return predicted_topics 625 | 626 | def get_ldavis_data_format(self, vocab, dataset, n_samples): 627 | """ 628 | Returns the data that can be used in input to pyldavis to plot 629 | the topics 630 | """ 631 | term_frequency = dataset.X_bow.toarray().sum(axis=0) 632 | doc_lengths = dataset.X_bow.toarray().sum(axis=1) 633 | term_topic = self.get_topic_word_distribution() 634 | doc_topic_distribution = self.get_doc_topic_distribution(dataset, n_samples=n_samples) 635 | 636 | data = {'topic_term_dists': term_topic, 637 | 'doc_topic_dists': doc_topic_distribution, 638 | 'doc_lengths': doc_lengths, 639 | 'vocab': vocab, 640 | 'term_frequency': term_frequency} 641 | 642 | return data 643 | -------------------------------------------------------------------------------- /networks/decoding_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from networks.inference_network import CombinedInferenceNetwork, ContextualInferenceNetwork 6 | 7 | # ----- Multimodal and Multingual (M3L) ----- 8 | class ContrastiveM3LDecoderNetwork(nn.Module): 9 | 10 | def __init__(self, input_size, bert_sizes=(768, 2048), n_components=10, model_type='prodLDA', 11 | hidden_sizes=(100,100), activation='softplus', dropout=0.2, 12 | learn_priors=True, label_size=0, num_languages=2): 13 | """ 14 | Initialize InferenceNetwork. 15 | 16 | Args 17 | input_size : int, dimension of input 18 | n_components : int, number of topic components, (default 10) 19 | model_type : string, 'prodLDA' or 'LDA' (default 'prodLDA') 20 | hidden_sizes : tuple, length = n_layers, (default (100, 100)) 21 | activation : string, 'softplus', 'relu', (default 'softplus') 22 | learn_priors : bool, make priors learnable parameter 23 | num_languages: no. of languages in dataset 24 | """ 25 | super(ContrastiveM3LDecoderNetwork, self).__init__() 26 | assert isinstance(input_size, int), "input_size must by type int." 27 | assert isinstance(n_components, int) and n_components > 0, \ 28 | "n_components must be type int > 0." 29 | assert model_type in ['prodLDA', 'LDA'], \ 30 | "model type must be 'prodLDA' or 'LDA'" 31 | assert isinstance(hidden_sizes, tuple), \ 32 | "hidden_sizes must be type tuple." 33 | assert activation in ['softplus', 'relu'], \ 34 | "activation must be 'softplus' or 'relu'." 35 | assert dropout >= 0, "dropout must be >= 0." 36 | 37 | # input_size: same as vocab size 38 | self.input_size = input_size 39 | # n_components: no. 
of topics 40 | self.n_components = n_components 41 | self.model_type = model_type 42 | self.hidden_sizes = hidden_sizes 43 | self.activation = activation 44 | self.dropout = dropout 45 | self.learn_priors = learn_priors 46 | 47 | if label_size != 0: 48 | self.label_classification = nn.Linear(n_components, label_size) 49 | 50 | # init prior parameters 51 | # \mu_1k = log \alpha_k + 1/K \sum_i log \alpha_i; 52 | # \alpha = 1 \forall \alpha 53 | # prior_mu is same for all languages 54 | topic_prior_mean = 0.0 55 | self.prior_mean = torch.tensor( 56 | [topic_prior_mean] * n_components) 57 | if torch.cuda.is_available(): 58 | self.prior_mean = self.prior_mean.cuda() 59 | if self.learn_priors: 60 | self.prior_mean = nn.Parameter(self.prior_mean) 61 | 62 | # \Sigma_1kk = 1 / \alpha_k (1 - 2/K) + 1/K^2 \sum_i 1 / \alpha_k; 63 | # \alpha = 1 \forall \alpha 64 | # prior_var is same for all languages 65 | topic_prior_variance = 1. - (1. / self.n_components) 66 | self.prior_variance = torch.tensor( 67 | [topic_prior_variance] * n_components) 68 | if torch.cuda.is_available(): 69 | self.prior_variance = self.prior_variance.cuda() 70 | if self.learn_priors: 71 | self.prior_variance = nn.Parameter(self.prior_variance) 72 | 73 | self.num_languages = num_languages 74 | # each language has their own inference network (assume num_lang=2 for now) --- instantiate separate inference networks for each language/modality 75 | # language 1 inference net 76 | self.inf_net1 = ContextualInferenceNetwork(input_size, bert_sizes[0], n_components, hidden_sizes, activation) 77 | # language 2 inference net 78 | self.inf_net2 = ContextualInferenceNetwork(input_size, bert_sizes[0], n_components, hidden_sizes, activation) 79 | # image 1 inference net 80 | self.inf_net3 = ContextualInferenceNetwork(input_size, bert_sizes[1], n_components, hidden_sizes, activation) 81 | 82 | # topic_word_matrix is K x V, where L = no. of languages 83 | self.topic_word_matrix = None 84 | 85 | # beta dimensions remains the same for multimodal because images have no BOW reconstruction 86 | # beta is L x K x V where L = no. 
of languages 87 | self.beta = torch.Tensor(num_languages, n_components, input_size) 88 | if torch.cuda.is_available(): 89 | self.beta = self.beta.cuda() 90 | self.beta = nn.Parameter(self.beta) 91 | nn.init.xavier_uniform_(self.beta) 92 | 93 | self.beta_batchnorm = nn.BatchNorm1d(input_size, affine=False) 94 | 95 | # dropout on theta 96 | self.drop_theta = nn.Dropout(p=self.dropout) 97 | 98 | @staticmethod 99 | def reparameterize(mu, logvar): 100 | """Reparameterize the theta distribution.""" 101 | std = torch.exp(0.5*logvar) 102 | eps = torch.randn_like(std) 103 | return eps.mul(std).add_(mu) 104 | 105 | def forward(self, x_bow, x_bert, x_image): 106 | """Forward pass.""" 107 | # x_bert: batch_size x L x bert_dim 108 | # pass language1 x_bert to inference net1 (input is batch_size x bert_dim) 109 | posterior_mu1, posterior_log_sigma1 = self.inf_net1(x_bow[:, 0, :], x_bert[:, 0, :]) 110 | posterior_sigma1 = torch.exp(posterior_log_sigma1) 111 | 112 | # pass language2 x_bert to inference net2 (input is batch_size x bert_dim) 113 | posterior_mu2, posterior_log_sigma2 = self.inf_net2(x_bow[:, 1, :], x_bert[:, 1, :]) 114 | posterior_sigma2 = torch.exp(posterior_log_sigma2) 115 | 116 | # pass encoded image to inference net3 (input is batch_size x image_enc_dim) 117 | # x_bow does not matter, inference net will not use it anyway 118 | posterior_mu3, posterior_log_sigma3 = self.inf_net3(x_bow[:, 1, :], x_image) 119 | posterior_sigma3 = torch.exp(posterior_log_sigma3) 120 | 121 | # generate separate thetas for each language 122 | z1 = self.reparameterize(posterior_mu1, posterior_log_sigma1) 123 | z2 = self.reparameterize(posterior_mu2, posterior_log_sigma2) 124 | z3 = self.reparameterize(posterior_mu3, posterior_log_sigma3) 125 | 126 | theta1 = F.softmax(z1, dim=1) 127 | theta2 = F.softmax(z2, dim=1) 128 | theta3 = F.softmax(z3, dim=1) 129 | thetas_no_drop = torch.stack([theta1, theta2, theta3]) 130 | 131 | theta1 = self.drop_theta(theta1) 132 | theta2 = self.drop_theta(theta2) 133 | theta3 = self.drop_theta(theta3) 134 | 135 | thetas = torch.stack([theta1, theta2, theta3]) 136 | 137 | word_dist_collect = [] 138 | for l in range(self.num_languages): 139 | # compute topic-word dist for language l 140 | # separate theta and separate beta for each language 141 | word_dist = F.softmax( 142 | self.beta_batchnorm(torch.matmul(thetas[l], self.beta[l])), dim=1) 143 | word_dist_collect.append(word_dist) 144 | 145 | # word_dist_collect: L x batch_size x input_size 146 | word_dist_collect = torch.stack([w for w in word_dist_collect]) 147 | 148 | # topic_word_matrix and beta should be L x n_components x vocab_size 149 | self.topic_word_matrix = self.beta 150 | 151 | return self.prior_mean, self.prior_variance, posterior_mu1, posterior_sigma1, posterior_log_sigma1, \ 152 | posterior_mu2, posterior_sigma2, posterior_log_sigma2, posterior_mu3, posterior_sigma3, posterior_log_sigma3, \ 153 | word_dist_collect, thetas_no_drop 154 | 155 | def get_theta(self, x, x_bert, lang_index=0): 156 | with torch.no_grad(): 157 | 158 | # we do inference PER LANGUAGE and PER MODALITY, so we use only 1 inference network at a time 159 | # inference with language1 160 | if lang_index == 0: 161 | posterior_mu, posterior_log_sigma = self.inf_net1(x, x_bert) 162 | # inference with language2 163 | elif lang_index == 1: 164 | posterior_mu, posterior_log_sigma = self.inf_net2(x, x_bert) 165 | # inference with image 166 | else: 167 | posterior_mu, posterior_log_sigma = self.inf_net3(x, x_bert) 168 | 169 | # generate samples from theta 170 
| theta = F.softmax( 171 | self.reparameterize(posterior_mu, posterior_log_sigma), dim=1) 172 | 173 | #print('theta:', theta.shape) 174 | return theta 175 | 176 | 177 | # ----- Multilingual ----- 178 | 179 | class ContrastiveMultilingualDecoderNetwork(nn.Module): 180 | 181 | def __init__(self, input_size, bert_size, n_components=10, model_type='prodLDA', 182 | hidden_sizes=(100,100), activation='softplus', dropout=0.2, 183 | learn_priors=True, label_size=0, num_languages=2): 184 | """ 185 | Initialize InferenceNetwork. 186 | 187 | Args 188 | input_size : int, dimension of input 189 | n_components : int, number of topic components, (default 10) 190 | model_type : string, 'prodLDA' or 'LDA' (default 'prodLDA') 191 | hidden_sizes : tuple, length = n_layers, (default (100, 100)) 192 | activation : string, 'softplus', 'relu', (default 'softplus') 193 | learn_priors : bool, make priors learnable parameter 194 | num_languages: no. of languages in dataset 195 | """ 196 | super(ContrastiveMultilingualDecoderNetwork, self).__init__() 197 | assert isinstance(input_size, int), "input_size must by type int." 198 | assert isinstance(n_components, int) and n_components > 0, \ 199 | "n_components must be type int > 0." 200 | assert model_type in ['prodLDA', 'LDA'], \ 201 | "model type must be 'prodLDA' or 'LDA'" 202 | assert isinstance(hidden_sizes, tuple), \ 203 | "hidden_sizes must be type tuple." 204 | assert activation in ['softplus', 'relu'], \ 205 | "activation must be 'softplus' or 'relu'." 206 | assert dropout >= 0, "dropout must be >= 0." 207 | 208 | # input_size: same as vocab size 209 | self.input_size = input_size 210 | # n_components: no. of topics 211 | self.n_components = n_components 212 | self.model_type = model_type 213 | self.hidden_sizes = hidden_sizes 214 | self.activation = activation 215 | self.dropout = dropout 216 | self.learn_priors = learn_priors 217 | 218 | if label_size != 0: 219 | self.label_classification = nn.Linear(n_components, label_size) 220 | 221 | # init prior parameters 222 | # \mu_1k = log \alpha_k + 1/K \sum_i log \alpha_i; 223 | # \alpha = 1 \forall \alpha 224 | # prior_mu is same for all languages 225 | topic_prior_mean = 0.0 226 | self.prior_mean = torch.tensor( 227 | [topic_prior_mean] * n_components) 228 | if torch.cuda.is_available(): 229 | self.prior_mean = self.prior_mean.cuda() 230 | if self.learn_priors: 231 | self.prior_mean = nn.Parameter(self.prior_mean) 232 | 233 | # \Sigma_1kk = 1 / \alpha_k (1 - 2/K) + 1/K^2 \sum_i 1 / \alpha_k; 234 | # \alpha = 1 \forall \alpha 235 | # prior_var is same for all languages 236 | topic_prior_variance = 1. - (1. / self.n_components) 237 | self.prior_variance = torch.tensor( 238 | [topic_prior_variance] * n_components) 239 | if torch.cuda.is_available(): 240 | self.prior_variance = self.prior_variance.cuda() 241 | if self.learn_priors: 242 | self.prior_variance = nn.Parameter(self.prior_variance) 243 | 244 | self.num_languages = num_languages 245 | # each language has their own inference network (assume num_lang=2 for now) ---- instantiate separate inference networks for each language 246 | self.inf_net1 = ContextualInferenceNetwork(input_size, bert_size, n_components, hidden_sizes, activation) 247 | self.inf_net2 = ContextualInferenceNetwork(input_size, bert_size, n_components, hidden_sizes, activation) 248 | 249 | # topic_word_matrix is K x V, where L = no. of languages 250 | self.topic_word_matrix = None 251 | 252 | # beta is L x K x V where L = no. 
of languages 253 | self.beta = torch.Tensor(num_languages, n_components, input_size) 254 | if torch.cuda.is_available(): 255 | self.beta = self.beta.cuda() 256 | self.beta = nn.Parameter(self.beta) 257 | nn.init.xavier_uniform_(self.beta) 258 | 259 | self.beta_batchnorm = nn.BatchNorm1d(input_size, affine=False) 260 | 261 | # dropout on theta 262 | self.drop_theta = nn.Dropout(p=self.dropout) 263 | 264 | @staticmethod 265 | def reparameterize(mu, logvar): 266 | """Reparameterize the theta distribution.""" 267 | std = torch.exp(0.5*logvar) 268 | eps = torch.randn_like(std) 269 | return eps.mul(std).add_(mu) 270 | 271 | def forward(self, x, x_bert): 272 | """Forward pass.""" 273 | # x_bert: batch_size x L x bert_dim 274 | # print('DecoderNet - forward') 275 | # print('x_bert:', x_bert.shape) 276 | # pass to first x_bert to inference net1 (input is batch_size x bert_dim) 277 | posterior_mu1, posterior_log_sigma1 = self.inf_net1(x[:, 0, :], x_bert[:, 0, :]) 278 | posterior_sigma1 = torch.exp(posterior_log_sigma1) 279 | 280 | # pass to second x_bert to inference net2 (input is batch_size x bert_dim) 281 | posterior_mu2, posterior_log_sigma2 = self.inf_net2(x[:, 1, :], x_bert[:, 1, :]) 282 | posterior_sigma2 = torch.exp(posterior_log_sigma2) 283 | 284 | # generate separate thetas for each language 285 | z1 = self.reparameterize(posterior_mu1, posterior_log_sigma1) 286 | z2 = self.reparameterize(posterior_mu2, posterior_log_sigma2) 287 | theta1 = F.softmax(z1, dim=1) 288 | theta2 = F.softmax(z2, dim=1) 289 | # print("mu1:", posterior_mu1) 290 | # print("log_sigma1:", posterior_log_sigma1) 291 | # # print("z1:", z1) 292 | # print("-"*10) 293 | # print("mu2:", posterior_mu2) 294 | # print("log_sigma2:", posterior_log_sigma2) 295 | # # print("z2:", z2) 296 | # print("-" * 10) 297 | 298 | thetas_no_drop = torch.stack([theta1, theta2]) 299 | z_no_drop = torch.stack([z1, z2]) 300 | 301 | theta1 = self.drop_theta(theta1) 302 | theta2 = self.drop_theta(theta2) 303 | 304 | thetas = torch.stack([theta1, theta2]) 305 | 306 | word_dist_collect = [] 307 | for l in range(self.num_languages): 308 | # compute topic-word dist for language l 309 | # separate theta and separate beta for each language 310 | word_dist = F.softmax( 311 | self.beta_batchnorm(torch.matmul(thetas[l], self.beta[l])), dim=1) 312 | word_dist_collect.append(word_dist) 313 | 314 | # word_dist_collect: L x batch_size x input_size 315 | word_dist_collect = torch.stack([w for w in word_dist_collect]) 316 | 317 | # topic_word_matrix and beta should be L x n_components x vocab_size 318 | self.topic_word_matrix = self.beta 319 | #print('beta:', self.beta.shape) 320 | 321 | return self.prior_mean, self.prior_variance, posterior_mu1, posterior_sigma1, posterior_log_sigma1, \ 322 | posterior_mu2, posterior_sigma2, posterior_log_sigma2, word_dist_collect, thetas_no_drop, z_no_drop 323 | 324 | def get_theta(self, x, x_bert, lang_index=0): 325 | with torch.no_grad(): 326 | 327 | # we do inference PER LANGUAGE, so we use only 1 inference network at a time 328 | if lang_index == 0: 329 | posterior_mu, posterior_log_sigma = self.inf_net1(x, x_bert) 330 | else: 331 | posterior_mu, posterior_log_sigma = self.inf_net2(x, x_bert) 332 | 333 | # generate samples from theta 334 | theta = F.softmax( 335 | self.reparameterize(posterior_mu, posterior_log_sigma), dim=1) 336 | 337 | #print('theta:', theta.shape) 338 | return theta 339 | 340 | 341 | # ----- Contrastive ----- 342 | 343 | class ContrastiveDecoderNetwork(nn.Module): 344 | 345 | def __init__(self, 
input_size, bert_size, n_components=10, model_type='prodLDA', 346 | hidden_sizes=(100,100), activation='softplus', dropout=0.2, 347 | learn_priors=True, label_size=0, num_languages=2): 348 | """ 349 | Initialize InferenceNetwork. 350 | 351 | Args 352 | input_size : int, dimension of input 353 | n_components : int, number of topic components, (default 10) 354 | model_type : string, 'prodLDA' or 'LDA' (default 'prodLDA') 355 | hidden_sizes : tuple, length = n_layers, (default (100, 100)) 356 | activation : string, 'softplus', 'relu', (default 'softplus') 357 | learn_priors : bool, make priors learnable parameter 358 | num_languages: no. of languages in dataset 359 | """ 360 | super(ContrastiveDecoderNetwork, self).__init__() 361 | assert isinstance(input_size, int), "input_size must by type int." 362 | assert isinstance(n_components, int) and n_components > 0, \ 363 | "n_components must be type int > 0." 364 | assert model_type in ['prodLDA', 'LDA'], \ 365 | "model type must be 'prodLDA' or 'LDA'" 366 | assert isinstance(hidden_sizes, tuple), \ 367 | "hidden_sizes must be type tuple." 368 | assert activation in ['softplus', 'relu'], \ 369 | "activation must be 'softplus' or 'relu'." 370 | assert dropout >= 0, "dropout must be >= 0." 371 | 372 | # input_size: same as vocab size 373 | self.input_size = input_size 374 | # n_components: no. of topics 375 | self.n_components = n_components 376 | self.model_type = model_type 377 | self.hidden_sizes = hidden_sizes 378 | self.activation = activation 379 | self.dropout = dropout 380 | self.learn_priors = learn_priors 381 | 382 | if label_size != 0: 383 | self.label_classification = nn.Linear(n_components, label_size) 384 | 385 | # init prior parameters 386 | # \mu_1k = log \alpha_k + 1/K \sum_i log \alpha_i; 387 | # \alpha = 1 \forall \alpha 388 | # prior_mu is same for all languages 389 | topic_prior_mean = 0.0 390 | self.prior_mean = torch.tensor( 391 | [topic_prior_mean] * n_components) 392 | if torch.cuda.is_available(): 393 | self.prior_mean = self.prior_mean.cuda() 394 | if self.learn_priors: 395 | self.prior_mean = nn.Parameter(self.prior_mean) 396 | 397 | # \Sigma_1kk = 1 / \alpha_k (1 - 2/K) + 1/K^2 \sum_i 1 / \alpha_k; 398 | # \alpha = 1 \forall \alpha 399 | # prior_var is same for all languages 400 | topic_prior_variance = 1. - (1. / self.n_components) 401 | self.prior_variance = torch.tensor( 402 | [topic_prior_variance] * n_components) 403 | if torch.cuda.is_available(): 404 | self.prior_variance = self.prior_variance.cuda() 405 | if self.learn_priors: 406 | self.prior_variance = nn.Parameter(self.prior_variance) 407 | 408 | self.num_languages = num_languages 409 | # each language has their own inference network (assume num_lang=2 for now) 410 | self.inf_net1 = ContextualInferenceNetwork(input_size, bert_size, n_components, hidden_sizes, activation) 411 | self.inf_net2 = ContextualInferenceNetwork(input_size, bert_size, n_components, hidden_sizes, activation) 412 | 413 | # topic_word_matrix is K x V, where L = no. of languages 414 | self.topic_word_matrix = None 415 | 416 | # beta is L x K x V where L = no. 
of languages 417 | self.beta = torch.Tensor(num_languages, n_components, input_size) 418 | if torch.cuda.is_available(): 419 | self.beta = self.beta.cuda() 420 | self.beta = nn.Parameter(self.beta) 421 | nn.init.xavier_uniform_(self.beta) 422 | 423 | self.beta_batchnorm = nn.BatchNorm1d(input_size, affine=False) 424 | 425 | # dropout on theta 426 | self.drop_theta = nn.Dropout(p=self.dropout) 427 | 428 | @staticmethod 429 | def reparameterize(mu, logvar): 430 | """Reparameterize the theta distribution.""" 431 | std = torch.exp(0.5*logvar) 432 | eps = torch.randn_like(std) 433 | return eps.mul(std).add_(mu) 434 | 435 | def forward(self, x, x_bert): 436 | """Forward pass.""" 437 | # x_bert: batch_size x L x bert_dim 438 | # print('DecoderNet - forward') 439 | # print('x_bert:', x_bert.shape) 440 | # pass to first x_bert to inference net1 (input is batch_size x bert_dim) 441 | posterior_mu1, posterior_log_sigma1 = self.inf_net1(x[:, 0, :], x_bert[:, 0, :]) 442 | posterior_sigma1 = torch.exp(posterior_log_sigma1) 443 | 444 | # pass to second x_bert to inference net2 (input is batch_size x bert_dim) 445 | posterior_mu2, posterior_log_sigma2 = self.inf_net2(x[:, 1, :], x_bert[:, 1, :]) 446 | posterior_sigma2 = torch.exp(posterior_log_sigma2) 447 | 448 | # generate separate thetas for each language 449 | z1 = self.reparameterize(posterior_mu1, posterior_log_sigma1) 450 | z2 = self.reparameterize(posterior_mu2, posterior_log_sigma2) 451 | theta1 = F.softmax(z1, dim=1) 452 | theta2 = F.softmax(z2, dim=1) 453 | # print("mu1:", posterior_mu1) 454 | # print("log_sigma1:", posterior_log_sigma1) 455 | # # print("z1:", z1) 456 | # print("-"*10) 457 | # print("mu2:", posterior_mu2) 458 | # print("log_sigma2:", posterior_log_sigma2) 459 | # # print("z2:", z2) 460 | # print("-" * 10) 461 | 462 | thetas_no_drop = torch.stack([theta1, theta2]) 463 | z_no_drop = torch.stack([z1, z2]) 464 | 465 | theta1 = self.drop_theta(theta1) 466 | theta2 = self.drop_theta(theta2) 467 | 468 | thetas = torch.stack([theta1, theta2]) 469 | 470 | word_dist_collect = [] 471 | for l in range(self.num_languages): 472 | # compute topic-word dist for each language 473 | # separate thetas and betas per language 474 | word_dist = F.softmax( 475 | self.beta_batchnorm(torch.matmul(thetas[l], self.beta[l])), dim=1) 476 | word_dist_collect.append(word_dist) 477 | 478 | # word_dist_collect: L x batch_size x input_size 479 | word_dist_collect = torch.stack([w for w in word_dist_collect]) 480 | 481 | # topic_word_matrix and beta should be L x n_components x vocab_size 482 | self.topic_word_matrix = self.beta 483 | 484 | return self.prior_mean, self.prior_variance, posterior_mu1, posterior_sigma1, posterior_log_sigma1, \ 485 | posterior_mu2, posterior_sigma2, posterior_log_sigma2, word_dist_collect, thetas_no_drop, z_no_drop 486 | 487 | def get_theta(self, x, x_bert, lang_index=0): 488 | with torch.no_grad(): 489 | # we do inference PER LANGUAGE, so we use only 1 inference network at a time 490 | if lang_index == 0: 491 | posterior_mu, posterior_log_sigma = self.inf_net1(x, x_bert) 492 | else: 493 | posterior_mu, posterior_log_sigma = self.inf_net2(x, x_bert) 494 | 495 | # generate samples from theta 496 | theta = F.softmax( 497 | self.reparameterize(posterior_mu, posterior_log_sigma), dim=1) 498 | return theta 499 | 500 | 501 | # ------- Original ------- 502 | class DecoderNetwork(nn.Module): 503 | 504 | def __init__(self, input_size, bert_size, infnet, n_components=10, model_type='prodLDA', 505 | hidden_sizes=(100,100), 
activation='softplus', dropout=0.2, 506 | learn_priors=True, label_size=0): 507 | """ 508 | Initialize InferenceNetwork. 509 | 510 | Args 511 | input_size : int, dimension of input 512 | n_components : int, number of topic components, (default 10) 513 | model_type : string, 'prodLDA' or 'LDA' (default 'prodLDA') 514 | hidden_sizes : tuple, length = n_layers, (default (100, 100)) 515 | activation : string, 'softplus', 'relu', (default 'softplus') 516 | learn_priors : bool, make priors learnable parameter 517 | """ 518 | super(DecoderNetwork, self).__init__() 519 | assert isinstance(input_size, int), "input_size must by type int." 520 | assert isinstance(n_components, int) and n_components > 0, \ 521 | "n_components must be type int > 0." 522 | assert model_type in ['prodLDA', 'LDA'], \ 523 | "model type must be 'prodLDA' or 'LDA'" 524 | assert isinstance(hidden_sizes, tuple), \ 525 | "hidden_sizes must be type tuple." 526 | assert activation in ['softplus', 'relu'], \ 527 | "activation must be 'softplus' or 'relu'." 528 | assert dropout >= 0, "dropout must be >= 0." 529 | 530 | self.input_size = input_size 531 | self.n_components = n_components 532 | self.model_type = model_type 533 | self.hidden_sizes = hidden_sizes 534 | self.activation = activation 535 | self.dropout = dropout 536 | self.learn_priors = learn_priors 537 | self.topic_word_matrix = None 538 | 539 | # print('hidden_sizes:', hidden_sizes) 540 | if infnet == "zeroshot": 541 | self.inf_net = ContextualInferenceNetwork( 542 | input_size, bert_size, n_components, hidden_sizes, activation, label_size=label_size) 543 | elif infnet == "combined": 544 | self.inf_net = CombinedInferenceNetwork( 545 | input_size, bert_size, n_components, hidden_sizes, activation, label_size=label_size) 546 | else: 547 | raise Exception('Missing infnet parameter, options are zeroshot and combined') 548 | 549 | if label_size != 0: 550 | self.label_classification = nn.Linear(n_components, label_size) 551 | 552 | # init prior parameters 553 | # \mu_1k = log \alpha_k + 1/K \sum_i log \alpha_i; 554 | # \alpha = 1 \forall \alpha 555 | topic_prior_mean = 0.0 556 | self.prior_mean = torch.tensor( 557 | [topic_prior_mean] * n_components) 558 | if torch.cuda.is_available(): 559 | self.prior_mean = self.prior_mean.cuda() 560 | if self.learn_priors: 561 | self.prior_mean = nn.Parameter(self.prior_mean) 562 | 563 | # \Sigma_1kk = 1 / \alpha_k (1 - 2/K) + 1/K^2 \sum_i 1 / \alpha_k; 564 | # \alpha = 1 \forall \alpha 565 | topic_prior_variance = 1. - (1. 
/ self.n_components) 566 | self.prior_variance = torch.tensor( 567 | [topic_prior_variance] * n_components) 568 | if torch.cuda.is_available(): 569 | self.prior_variance = self.prior_variance.cuda() 570 | if self.learn_priors: 571 | self.prior_variance = nn.Parameter(self.prior_variance) 572 | 573 | self.beta = torch.Tensor(n_components, input_size) 574 | if torch.cuda.is_available(): 575 | self.beta = self.beta.cuda() 576 | self.beta = nn.Parameter(self.beta) 577 | nn.init.xavier_uniform_(self.beta) 578 | 579 | self.beta_batchnorm = nn.BatchNorm1d(input_size, affine=False) 580 | 581 | # dropout on theta 582 | self.drop_theta = nn.Dropout(p=self.dropout) 583 | 584 | @staticmethod 585 | def reparameterize(mu, logvar): 586 | """Reparameterize the theta distribution.""" 587 | std = torch.exp(0.5*logvar) 588 | eps = torch.randn_like(std) 589 | return eps.mul(std).add_(mu) 590 | 591 | def forward(self, x, x_bert, labels=None): 592 | """Forward pass.""" 593 | # batch_size x n_components 594 | posterior_mu, posterior_log_sigma = self.inf_net(x, x_bert, labels) 595 | posterior_sigma = torch.exp(posterior_log_sigma) 596 | 597 | # generate samples from theta 598 | theta = F.softmax( 599 | self.reparameterize(posterior_mu, posterior_log_sigma), dim=1) 600 | theta = self.drop_theta(theta) 601 | 602 | # prodLDA vs LDA 603 | if self.model_type == 'prodLDA': 604 | # in: batch_size x input_size x n_components 605 | word_dist = F.softmax( 606 | self.beta_batchnorm(torch.matmul(theta, self.beta)), dim=1) 607 | # word_dist: batch_size x input_size 608 | self.topic_word_matrix = self.beta 609 | elif self.model_type == 'LDA': 610 | # simplex constrain on Beta 611 | beta = F.softmax(self.beta_batchnorm(self.beta), dim=1) 612 | self.topic_word_matrix = beta 613 | word_dist = torch.matmul(theta, beta) 614 | # word_dist: batch_size x input_size 615 | else: 616 | raise NotImplementedError("Model Type Not Implemented") 617 | 618 | # classify labels 619 | 620 | estimated_labels = None 621 | 622 | if labels is not None: 623 | estimated_labels = self.label_classification(theta) 624 | 625 | return self.prior_mean, self.prior_variance, posterior_mu, posterior_sigma, posterior_log_sigma, word_dist, estimated_labels 626 | 627 | def get_theta(self, x, x_bert, labels=None): 628 | with torch.no_grad(): 629 | # batch_size x n_components 630 | posterior_mu, posterior_log_sigma = self.inf_net(x, x_bert, labels) 631 | #posterior_sigma = torch.exp(posterior_log_sigma) 632 | 633 | # generate samples from theta 634 | theta = F.softmax( 635 | self.reparameterize(posterior_mu, posterior_log_sigma), dim=1) 636 | 637 | return theta 638 | 639 | -------------------------------------------------------------------------------- /models/M3L_contrast.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import multiprocessing as mp 3 | import os 4 | import warnings 5 | from collections import defaultdict 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | import wordcloud 10 | from scipy.special import softmax 11 | from torch import optim 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | from torch.utils.data import DataLoader 14 | import torch.nn.functional as F 15 | from tqdm import tqdm 16 | from utils.early_stopping.early_stopping import EarlyStopping 17 | 18 | # decooder network 19 | from networks.decoding_network import ContrastiveM3LDecoderNetwork 20 | 21 | # for contrastive loss 22 | from pytorch_metric_learning.losses import NTXentLoss 23 
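# ---------------------------------------------------------------------------
# Illustrative sketch of the reparameterization step used by the decoder
# networks imported above: z = mu + sigma * eps with eps ~ N(0, I), followed by
# a softmax that turns the sample into a document-topic mixture theta. Relies
# only on the torch / F imports at the top of this file; shapes are examples.
def example_reparameterize(mu, log_var):
    """Illustrative only: draw one theta sample from a diagonal Gaussian posterior."""
    std = torch.exp(0.5 * log_var)   # log-variance -> standard deviation
    eps = torch.randn_like(std)      # Gaussian noise with the same shape as std
    z = eps * std + mu               # reparameterized sample, differentiable w.r.t. mu and log_var
    return F.softmax(z, dim=1)       # simplex-valued topic mixture

# e.g. example_reparameterize(torch.zeros(4, 10), torch.zeros(4, 10)) returns a
# batch of four 10-topic mixtures sampled from a standard normal posterior.
# ---------------------------------------------------------------------------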
| 24 | class MultimodalContrastiveTM: 25 | """Class to train the multilingual and multimodal (M3L) contrastive topic model. The interface follows the 26 | contextualized topic model (CTM) implementation that this code is adapted from. 27 | 28 | :param bow_size: int, dimension of the bag-of-words input (vocabulary size) 29 | :param contextual_sizes: tuple, dimensions of the contextual inputs (text embedding size, image embedding size), (default (768, 2048)) 30 | :param languages: list, language codes of the paired languages, e.g. ['en', 'de'] (default None) 31 | :param n_components: int, number of topic components, (default 10) 32 | :param model_type: string, 'prodLDA' or 'LDA' (default 'prodLDA') 33 | :param hidden_sizes: tuple, length = n_layers, (default (100, 100)) 34 | :param activation: string, 'softplus', 'relu', (default 'softplus') 35 | :param dropout: float, dropout to use (default 0.2) 36 | :param learn_priors: bool, make priors a learnable parameter (default True) 37 | :param batch_size: int, size of batch to use for training (default 16) 38 | :param lr: float, learning rate to use for training (default 2e-3) 39 | :param momentum: float, momentum to use for training (default 0.99) 40 | :param solver: string, optimizer 'adam' or 'sgd' (default 'adam') 41 | :param num_epochs: int, number of epochs to train for, (default 100) 42 | :param reduce_on_plateau: bool, reduce learning rate by 10x on plateau of 10 epochs (default False) 43 | :param num_data_loader_workers: int, number of data loader workers (default cpu_count). set it to 0 if you are using Windows 44 | :param label_size: int, number of total labels (default: 0) 45 | :param loss_weights: dict, it maps the name of each loss term (key: 'KL' or 'CL') to its weight (value). 46 | If None, the weights default to {'KL': 1, 'CL': 100} (default: None). 47 | 48 | """ 49 | 50 | def __init__(self, bow_size, contextual_sizes=(768, 2048), n_components=10, model_type='prodLDA', 51 | hidden_sizes=(100, 100), activation='softplus', dropout=0.2, learn_priors=True, batch_size=16, 52 | lr=2e-3, momentum=0.99, solver='adam', num_epochs=100, reduce_on_plateau=False, 53 | num_data_loader_workers=mp.cpu_count(), label_size=0, loss_weights=None, languages=None): 54 | 55 | self.device = ( 56 | torch.device("cuda") 57 | if torch.cuda.is_available() 58 | else torch.device("cpu") 59 | ) 60 | 61 | assert isinstance(bow_size, int) and bow_size > 0, \ 62 | "bow_size must be an int > 0." 63 | assert isinstance(n_components, int) and n_components > 0, \ 64 | "n_components must be an int > 0." 65 | assert model_type in ['LDA', 'prodLDA'], \ 66 | "model must be 'LDA' or 'prodLDA'." 67 | assert isinstance(hidden_sizes, tuple), \ 68 | "hidden_sizes must be type tuple." 69 | assert activation in ['softplus', 'relu'], \ 70 | "activation must be 'softplus' or 'relu'." 71 | assert dropout >= 0, "dropout must be >= 0." 72 | assert isinstance(learn_priors, bool), "learn_priors must be boolean." 73 | assert isinstance(batch_size, int) and batch_size > 0, \ 74 | "batch_size must be int > 0." 75 | assert lr > 0, "lr must be > 0." 76 | assert isinstance(momentum, float) and 0 < momentum <= 1, \ 77 | "momentum must be 0 < float <= 1." 78 | assert solver in ['adam', 'sgd'], "solver must be 'adam' or 'sgd'." 79 | assert isinstance(reduce_on_plateau, bool), \ 80 | "reduce_on_plateau must be type bool." 81 | assert isinstance(num_data_loader_workers, int) and num_data_loader_workers >= 0, \ 82 | "num_data_loader_workers must be an int >= 0. 
set 0 if you are using windows" 83 | 84 | # langs is an array of language codes e.g. ['en', 'de'] 85 | self.languages = languages 86 | self.num_lang = len(languages) 87 | # bow_size is an array of size n_languages; one bow_size for each language 88 | self.bow_size = bow_size 89 | self.n_components = n_components 90 | self.model_type = model_type 91 | self.hidden_sizes = hidden_sizes 92 | self.activation = activation 93 | self.dropout = dropout 94 | self.learn_priors = learn_priors 95 | self.batch_size = batch_size 96 | self.lr = lr 97 | self.contextual_sizes = contextual_sizes 98 | self.momentum = momentum 99 | self.solver = solver 100 | self.num_epochs = num_epochs 101 | self.reduce_on_plateau = reduce_on_plateau 102 | self.num_data_loader_workers = num_data_loader_workers 103 | 104 | if loss_weights: 105 | self.weights = loss_weights 106 | else: 107 | self.weights = {"KL": 1, "CL": 100} 108 | 109 | # contrastive decoder 110 | self.model = ContrastiveM3LDecoderNetwork( 111 | bow_size, self.contextual_sizes, n_components, model_type, hidden_sizes, activation, 112 | dropout, learn_priors, label_size=label_size) 113 | 114 | self.early_stopping = None 115 | 116 | # init optimizers 117 | if self.solver == 'adam': 118 | self.optimizer = optim.Adam( 119 | self.model.parameters(), lr=lr, betas=(self.momentum, 0.99)) 120 | elif self.solver == 'sgd': 121 | self.optimizer = optim.SGD( 122 | self.model.parameters(), lr=lr, momentum=self.momentum) 123 | 124 | # init lr scheduler 125 | if self.reduce_on_plateau: 126 | self.scheduler = ReduceLROnPlateau(self.optimizer, patience=10) 127 | 128 | # performance attributes 129 | self.best_loss_train = float('inf') 130 | 131 | # training attributes 132 | self.model_dir = None 133 | self.train_data = None 134 | self.nn_epoch = None 135 | 136 | # validation attributes 137 | self.validation_data = None 138 | 139 | # learned topics 140 | # best_components: n_components x vocab_size 141 | self.best_components = None 142 | 143 | # Use cuda if available 144 | if torch.cuda.is_available(): 145 | self.USE_CUDA = True 146 | else: 147 | self.USE_CUDA = False 148 | 149 | self.model = self.model.to(self.device) 150 | 151 | def _infoNCE_loss(self, embeddings1, embeddings2, temperature=0.07): 152 | batch_size = embeddings1.shape[0] 153 | labels = torch.arange(batch_size) 154 | labels = torch.cat([labels, labels]) 155 | embeddings_cat = torch.cat([embeddings1, embeddings2]) 156 | loss_func = NTXentLoss() 157 | infonce_loss = loss_func(embeddings_cat, labels) 158 | return infonce_loss 159 | 160 | def _kl_loss1(self, thetas1, thetas2): 161 | theta_kld = F.kl_div(thetas1.log(), thetas2, reduction='sum') 162 | return theta_kld 163 | 164 | def _kl_loss2(self, prior_mean, prior_variance, 165 | posterior_mean, posterior_variance, posterior_log_variance): 166 | # KL term 167 | # var division term 168 | var_division = torch.sum(posterior_variance / prior_variance, dim=1) 169 | # diff means term 170 | diff_means = prior_mean - posterior_mean 171 | diff_term = torch.sum( 172 | (diff_means * diff_means) / prior_variance, dim=1) 173 | # logvar det division term 174 | logvar_det_division = \ 175 | prior_variance.log().sum() - posterior_log_variance.sum(dim=1) 176 | # combine terms 177 | KL = 0.5 * ( 178 | var_division + diff_term - self.n_components + logvar_det_division) 179 | return KL 180 | 181 | def _rl_loss(self, true_word_dists, pred_word_dists): 182 | # Reconstruction term 183 | RL = -torch.sum(true_word_dists * torch.log(pred_word_dists + 1e-10), dim=1) 184 | return RL 185 | 186 | 
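# Notes on the loss helpers above:
# - _kl_loss2 is the closed-form KL divergence between two diagonal Gaussians,
#     KL(q || p) = 0.5 * sum_k [ var_q_k / var_p_k + (mu_p_k - mu_q_k)^2 / var_p_k
#                                - 1 + log var_p_k - log var_q_k ],
#   summed over the K = n_components dimensions (the "- self.n_components" term
#   collects the K copies of -1).
# - _rl_loss is the negative log-likelihood of the observed bag-of-words counts
#   under the predicted word distribution (a cross-entropy term).
# - _infoNCE_loss instantiates NTXentLoss() with the library's default temperature
#   (0.07); its `temperature` argument is currently not forwarded to the loss.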
def _train_epoch(self, loader): 187 | """Train epoch.""" 188 | self.model.train() 189 | train_loss = 0 190 | samples_processed = 0 191 | 192 | for batch_samples in loader: 193 | # batch_size x L x vocab_size 194 | X_bow = batch_samples['X_bow'] 195 | X_bow = X_bow.squeeze(dim=2) 196 | 197 | # batch_size x L x bert_size 198 | X_contextual = batch_samples['X_contextual'] 199 | 200 | # batch_size x image_enc_size 201 | X_image = batch_samples['X_image'] 202 | 203 | if self.USE_CUDA: 204 | X_bow = X_bow.cuda() 205 | X_contextual = X_contextual.cuda() 206 | X_image = X_image.cuda() 207 | 208 | # forward pass 209 | self.model.zero_grad() 210 | prior_mean, prior_variance, posterior_mean1, posterior_variance1, posterior_log_variance1, \ 211 | posterior_mean2, posterior_variance2, posterior_log_variance2, \ 212 | posterior_mean3, posterior_variance3, posterior_log_variance3, \ 213 | word_dists, thetas = self.model(X_bow, X_contextual, X_image) 214 | 215 | # backward pass 216 | 217 | # Recon loss for lang1 and lang2 (no recon loss for image) 218 | rl_loss1 = self._rl_loss(X_bow[:,0,:], word_dists[0]) 219 | rl_loss2 = self._rl_loss(X_bow[:,1,:], word_dists[1]) 220 | 221 | 222 | # KL losses between posterior distributions and a prior distribution 223 | kl_en_prior = self._kl_loss2(prior_mean, prior_variance, 224 | posterior_mean1, posterior_variance1, posterior_log_variance1) 225 | kl_de_prior = self._kl_loss2(prior_mean, prior_variance, 226 | posterior_mean2, posterior_variance2, posterior_log_variance2) 227 | kl_image_prior = self._kl_loss2(prior_mean, prior_variance, 228 | posterior_mean3, posterior_variance3, posterior_log_variance3) 229 | 230 | # KL loss between posterior distributions of paired languages/modalities (computed but not added to the training loss below) 231 | kl_en_de = self._kl_loss2(posterior_mean1, posterior_variance1, 232 | posterior_mean2, posterior_variance2, posterior_log_variance2) 233 | kl_en_image = self._kl_loss2(posterior_mean1, posterior_variance1, 234 | posterior_mean3, posterior_variance3, posterior_log_variance3) 235 | kl_de_image = self._kl_loss2(posterior_mean2, posterior_variance2, 236 | posterior_mean3, posterior_variance3, posterior_log_variance3) 237 | 238 | # InfoNCE loss/NTXentLoss 239 | infoNCE_en_de = self._infoNCE_loss(thetas[0], thetas[1]) 240 | infoNCE_en_image = self._infoNCE_loss(thetas[0], thetas[2]) 241 | infoNCE_de_image = self._infoNCE_loss(thetas[1], thetas[2]) 242 | 243 | loss = rl_loss1 + rl_loss2 \ 244 | + self.weights["KL"] * kl_en_prior + self.weights["KL"] * kl_de_prior + self.weights["KL"] * kl_image_prior \ 245 | + self.weights["CL"]*infoNCE_en_de + self.weights["CL"]*infoNCE_en_image + self.weights["CL"]*infoNCE_de_image 246 | 247 | loss = loss.sum() 248 | loss.backward() 249 | self.optimizer.step() 250 | 251 | # compute train loss 252 | samples_processed += X_bow.size()[0] 253 | train_loss += loss.item() 254 | 255 | train_loss /= samples_processed 256 | 257 | return samples_processed, train_loss 258 | 259 | def fit(self, train_dataset, validation_dataset=None, save_dir=None, verbose=False, patience=5, delta=0): 260 | """ 261 | Train the topic model. 262 | 263 | :param train_dataset: PyTorch Dataset class for training data. 264 | :param validation_dataset: PyTorch Dataset class for validation data. If not None, the training stops if validation loss doesn't improve after a given patience 265 | :param save_dir: directory to save checkpoint models to. 266 | :param verbose: bool, whether to print training settings and progress 267 | :param patience: How long to wait after last time validation loss improved. 
Default: 5 268 | :param delta: Minimum change in the monitored quantity to qualify as an improvement. Default: 0 269 | 270 | """ 271 | # Print settings to output file 272 | if verbose: 273 | print("Settings: \n\ 274 | N Components: {}\n\ 275 | Topic Prior Mean: {}\n\ 276 | Topic Prior Variance: {}\n\ 277 | Model Type: {}\n\ 278 | Hidden Sizes: {}\n\ 279 | Activation: {}\n\ 280 | Dropout: {}\n\ 281 | Learn Priors: {}\n\ 282 | Learning Rate: {}\n\ 283 | Momentum: {}\n\ 284 | Reduce On Plateau: {}\n\ 285 | Save Dir: {}".format( 286 | self.n_components, 0.0, 287 | 1. - (1. / self.n_components), self.model_type, 288 | self.hidden_sizes, self.activation, self.dropout, self.learn_priors, 289 | self.lr, self.momentum, self.reduce_on_plateau, save_dir)) 290 | 291 | self.model_dir = save_dir 292 | self.train_data = train_dataset 293 | self.validation_data = validation_dataset 294 | if self.validation_data is not None: 295 | self.early_stopping = EarlyStopping(patience=patience, verbose=verbose, path=save_dir, delta=delta) 296 | train_loader = DataLoader( 297 | self.train_data, batch_size=self.batch_size, shuffle=True, 298 | num_workers=self.num_data_loader_workers) 299 | 300 | # init training variables 301 | train_loss = 0 302 | samples_processed = 0 303 | 304 | # train loop 305 | pbar = tqdm(self.num_epochs, position=0, leave=True) 306 | for epoch in range(self.num_epochs): 307 | print("-"*10, "Epoch", epoch+1, "-"*10) 308 | self.nn_epoch = epoch 309 | # train epoch 310 | s = datetime.datetime.now() 311 | sp, train_loss = self._train_epoch(train_loader) 312 | samples_processed += sp 313 | e = datetime.datetime.now() 314 | pbar.update(1) 315 | 316 | if self.validation_data is not None: 317 | validation_loader = DataLoader(self.validation_data, batch_size=self.batch_size, shuffle=True, 318 | num_workers=self.num_data_loader_workers) 319 | # train epoch 320 | s = datetime.datetime.now() 321 | val_samples_processed, val_loss = self._validation(validation_loader) 322 | e = datetime.datetime.now() 323 | 324 | # report 325 | if verbose: 326 | print("Epoch: [{}/{}]\tSamples: [{}/{}]\tValidation Loss: {}\tTime: {}".format( 327 | epoch + 1, self.num_epochs, val_samples_processed, 328 | len(self.validation_data) * self.num_epochs, val_loss, e - s)) 329 | 330 | pbar.set_description("Epoch: [{}/{}]\t Seen Samples: [{}/{}]\tTrain Loss: {}\tValid Loss: {}\tTime: {}".format( 331 | epoch + 1, self.num_epochs, samples_processed, 332 | len(self.train_data) * self.num_epochs, train_loss, val_loss, e - s)) 333 | 334 | self.early_stopping(val_loss, self) 335 | if self.early_stopping.early_stop: 336 | print("Early stopping") 337 | 338 | break 339 | else: 340 | # save last epoch 341 | self.best_components = self.model.beta 342 | if save_dir is not None: 343 | self.save(save_dir) 344 | pbar.set_description("Epoch: [{}/{}]\t Seen Samples: [{}/{}]\tTrain Loss: {}\tTime: {}".format( 345 | epoch + 1, self.num_epochs, samples_processed, 346 | len(self.train_data) * self.num_epochs, train_loss, e - s)) 347 | 348 | #print topics for every epoch 349 | topics = self.get_topics() 350 | for l in range(self.num_lang): 351 | print("-"*10, self.languages[l].upper(), "-"*10) 352 | for k in range(self.n_components): 353 | print("Topic", k+1, ":", ', '.join(topics[l][k])) 354 | 355 | pbar.close() 356 | 357 | def _validation(self, loader): 358 | """Validation epoch.""" 359 | self.model.eval() 360 | val_loss = 0 361 | samples_processed = 0 362 | for batch_samples in loader: 363 | # batch_size x L x vocab_size 364 | X_bow = 
batch_samples['X_bow'] 365 | X_bow = X_bow.squeeze(dim=2) 366 | 367 | # batch_size x L x bert_size 368 | X_contextual = batch_samples['X_contextual'] 369 | 370 | # batch_size x image_enc_size 371 | X_image = batch_samples['X_image'] 372 | 373 | # forward pass 374 | self.model.zero_grad() 375 | prior_mean, prior_variance, posterior_mean1, posterior_variance1, posterior_log_variance1, \ 376 | posterior_mean2, posterior_variance2, posterior_log_variance2, \ 377 | posterior_mean3, posterior_variance3, posterior_log_variance3, \ 378 | word_dists, thetas = self.model(X_bow, X_contextual, X_image) 379 | 380 | # compute the validation loss (no backward pass) 381 | 382 | # Recon loss for lang1 and lang2 (no recon loss for image) 383 | rl_loss1 = self._rl_loss(X_bow[:,0,:], word_dists[0]) 384 | rl_loss2 = self._rl_loss(X_bow[:,1,:], word_dists[1]) 385 | 386 | # # KL loss 387 | # kl_en_de = self._kl_loss1(thetas[0], thetas[1]) 388 | # kl_en_image = self._kl_loss1(thetas[0], thetas[2]) 389 | # kl_de_image = self._kl_loss1(thetas[1], thetas[2]) 390 | 391 | # KL loss between posterior distributions of paired languages/modalities 392 | kl_en_de = self._kl_loss2(posterior_mean1, posterior_variance1, 393 | posterior_mean2, posterior_variance2, posterior_log_variance2) 394 | kl_en_image = self._kl_loss2(posterior_mean1, posterior_variance1, 395 | posterior_mean3, posterior_variance3, posterior_log_variance3) 396 | kl_de_image = self._kl_loss2(posterior_mean2, posterior_variance2, 397 | posterior_mean3, posterior_variance3, posterior_log_variance3) 398 | 399 | # InfoNCE loss/NTXentLoss 400 | infoNCE_en_de = self._infoNCE_loss(thetas[0], thetas[1]) 401 | infoNCE_en_image = self._infoNCE_loss(thetas[0], thetas[2]) 402 | infoNCE_de_image = self._infoNCE_loss(thetas[1], thetas[2]) 403 | 404 | loss = rl_loss1 + rl_loss2 \ 405 | + self.weights["CL"]*infoNCE_en_de + self.weights["CL"]*infoNCE_en_image + self.weights["CL"]*infoNCE_de_image \ 406 | + self.weights["KL"] * kl_en_de + self.weights["KL"] * kl_en_image + self.weights["KL"] * kl_de_image 407 | 408 | 409 | loss = loss.sum() 410 | 411 | # compute validation loss 412 | # samples_processed += X_bow.size()[0] 413 | samples_processed += X_bow.size()[0] 414 | val_loss += loss.item() 415 | 416 | val_loss /= samples_processed 417 | 418 | return samples_processed, val_loss 419 | 420 | def get_thetas(self, dataset, n_samples=20): 421 | """ 422 | Get the document-topic distribution for a dataset of documents. Includes multiple sampling to reduce variation via 423 | the parameter n_samples. 424 | 425 | :param dataset: a PyTorch Dataset containing the documents 426 | :param n_samples: the number of samples to collect to estimate the final distribution (the more the better). 427 | """ 428 | return self.get_doc_topic_distribution(dataset, n_samples=n_samples) 429 | 430 | def get_doc_topic_distribution(self, dataset, n_samples=20, lang_index=0): 431 | """ 432 | Get the document-topic distribution for a dataset of documents. Includes multiple sampling to reduce variation via 433 | the parameter n_samples. 434 | 435 | :param dataset: a PyTorch Dataset containing the documents 436 | :param n_samples: the number of samples to collect to estimate the final distribution (the more the better). 
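        :param lang_index: index of the language view used to infer the document-topic distribution (default 0, the first entry of `languages`).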
437 | """ 438 | self.model.eval() 439 | 440 | loader = DataLoader( 441 | dataset, batch_size=self.batch_size, shuffle=False, 442 | num_workers=self.num_data_loader_workers) 443 | pbar = tqdm(n_samples, position=0, leave=True) 444 | final_thetas = [] 445 | for sample_index in range(n_samples): 446 | with torch.no_grad(): 447 | collect_theta = [] 448 | 449 | for batch_samples in loader: 450 | 451 | # batch_size x L x bert_size 452 | X_contextual = batch_samples['X_contextual'] 453 | 454 | if self.USE_CUDA: 455 | X_contextual = X_contextual.cuda() 456 | 457 | # forward pass 458 | self.model.zero_grad() 459 | thetas = self.model.get_theta(x=None, x_bert=X_contextual, lang_index=lang_index) 460 | collect_theta.extend(thetas.detach().cpu().numpy()) 461 | 462 | pbar.update(1) 463 | pbar.set_description("Sampling: [{}/{}]".format(sample_index + 1, n_samples)) 464 | 465 | final_thetas.append(np.array(collect_theta)) 466 | pbar.close() 467 | return np.sum(final_thetas, axis=0) / n_samples 468 | 469 | def get_most_likely_topic(self, doc_topic_distribution): 470 | """ get the most likely topic for each document 471 | 472 | :param doc_topic_distribution: ndarray representing the topic distribution of each document 473 | """ 474 | return np.argmax(doc_topic_distribution, axis=0) 475 | 476 | 477 | def get_topics(self, k=10): 478 | """ 479 | Retrieve topic words. 480 | 481 | :param k: int, number of words to return per topic, default 10. 482 | """ 483 | assert k <= self.bow_size, "k must be <= input size." 484 | component_dists = self.best_components 485 | #print('component_dists:', component_dists.shape) 486 | topics_all = [] 487 | for l in range(self.num_lang): 488 | topics = defaultdict(list) 489 | for i in range(self.n_components): 490 | _, idxs = torch.topk(component_dists[l][i], k) 491 | component_words = [self.train_data.idx2token[l][idx] 492 | for idx in idxs.cpu().numpy()] 493 | topics[i] = component_words 494 | topics_all.append(topics) 495 | return topics_all 496 | 497 | def get_topic_lists(self, k=10): 498 | """ 499 | Retrieve the lists of topic words. 500 | 501 | :param k: (int) number of words to return per topic, default 10. 502 | """ 503 | assert k <= self.bow_size, "k must be <= input size." 504 | # TODO: collapse this method with the one that just returns the topics 505 | component_dists = self.best_components 506 | topics_all = [] 507 | for l in range(self.num_lang): 508 | topics = [] 509 | for i in range(self.n_components): 510 | _, idxs = torch.topk(component_dists[l][i], k) 511 | component_words = [self.train_data.idx2token[l][idx] 512 | for idx in idxs.cpu().numpy()] 513 | topics.append(component_words) 514 | topics_all.append(topics) 515 | return topics_all 516 | 517 | def _format_file(self): 518 | model_dir = "contextualized_topic_model_nc_{}_tpm_{}_tpv_{}_hs_{}_ac_{}_do_{}_lr_{}_mo_{}_rp_{}". \ 519 | format(self.n_components, 0.0, 1 - (1. / self.n_components), 520 | self.model_type, self.hidden_sizes, self.activation, 521 | self.dropout, self.lr, self.momentum, 522 | self.reduce_on_plateau) 523 | return model_dir 524 | 525 | def save(self, models_dir=None): 526 | """ 527 | Save model. (Experimental Feature, not tested) 528 | 529 | :param models_dir: path to directory for saving NN models. 530 | """ 531 | warnings.simplefilter('always', Warning) 532 | warnings.warn("This is an experimental feature that we has not been fully tested. 
Refer to the following issue:" 533 | "https://github.com/MilaNLProc/contextualized-topic-models/issues/38", 534 | Warning) 535 | 536 | if (self.model is not None) and (models_dir is not None): 537 | 538 | model_dir = self._format_file() 539 | if not os.path.isdir(os.path.join(models_dir, model_dir)): 540 | os.makedirs(os.path.join(models_dir, model_dir)) 541 | 542 | filename = "epoch_{}".format(self.nn_epoch) + '.pth' 543 | fileloc = os.path.join(models_dir, model_dir, filename) 544 | with open(fileloc, 'wb') as file: 545 | torch.save({'state_dict': self.model.state_dict(), 546 | 'dcue_dict': self.__dict__}, file) 547 | 548 | def load(self, model_dir, epoch): 549 | """ 550 | Load a previously trained model. (Experimental Feature, not tested) 551 | 552 | :param model_dir: directory where models are saved. 553 | :param epoch: epoch of model to load. 554 | """ 555 | 556 | warnings.simplefilter('always', Warning) 557 | warnings.warn("This is an experimental feature that we has not been fully tested. Refer to the following issue:" 558 | "https://github.com/MilaNLProc/contextualized-topic-models/issues/38", 559 | Warning) 560 | 561 | epoch_file = "epoch_" + str(epoch) + ".pth" 562 | model_file = os.path.join(model_dir, epoch_file) 563 | with open(model_file, 'rb') as model_dict: 564 | checkpoint = torch.load(model_dict) 565 | 566 | for (k, v) in checkpoint['dcue_dict'].items(): 567 | setattr(self, k, v) 568 | 569 | self.model.load_state_dict(checkpoint['state_dict']) 570 | 571 | def get_topic_word_matrix(self): 572 | """ 573 | Return the topic-word matrix (dimensions: number of topics x length of the vocabulary). 574 | If model_type is LDA, the matrix is normalized; otherwise the matrix is unnormalized. 575 | """ 576 | return self.model.topic_word_matrix.cpu().detach().numpy() 577 | 578 | def get_topic_word_distribution(self): 579 | """ 580 | Return the topic-word distribution (dimensions: number of topics x length of the vocabulary). 581 | """ 582 | mat = self.get_topic_word_matrix() 583 | return softmax(mat, axis=1) 584 | 585 | def get_word_distribution_by_topic_id(self, topic): 586 | """ 587 | Return the word probability distribution of a topic sorted by probability. 588 | 589 | :param topic: id of the topic (int) 590 | 591 | :returns list of tuples (word, probability) sorted by the probability in descending order 592 | """ 593 | if topic >= self.n_components: 594 | raise Exception('Topic id must be lower than the number of topics') 595 | else: 596 | wd = self.get_topic_word_distribution() 597 | t = [(word, wd[topic][idx]) for idx, word in self.train_data.idx2token.items()] 598 | t = sorted(t, key=lambda x: -x[1]) 599 | return t 600 | 601 | def get_wordcloud(self, topic_id, n_words=5, background_color="black", width=1000, height=400): 602 | """ 603 | Plotting the wordcloud. 
It is an adapted version of the code found here: 604 | http://amueller.github.io/word_cloud/auto_examples/simple.html#sphx-glr-auto-examples-simple-py and 605 | here https://github.com/ddangelov/Top2Vec/blob/master/top2vec/Top2Vec.py 606 | 607 | :param topic_id: id of the topic 608 | :param n_words: number of words to show in word cloud 609 | :param background_color: color of the background 610 | :param width: width of the produced image 611 | :param height: height of the produced image 612 | """ 613 | word_score_list = self.get_word_distribution_by_topic_id(topic_id)[:n_words] 614 | word_score_dict = {tup[0]: tup[1] for tup in word_score_list} 615 | plt.figure(figsize=(10, 4), dpi=200) 616 | plt.axis("off") 617 | plt.imshow(wordcloud.WordCloud(width=width, height=height, background_color=background_color 618 | ).generate_from_frequencies(word_score_dict)) 619 | plt.title("Displaying Topic " + str(topic_id), loc='center', fontsize=24) 620 | plt.show() 621 | 622 | def get_predicted_topics(self, dataset, n_samples): 623 | """ 624 | Return the a list containing the predicted topic for each document (length: number of documents). 625 | 626 | :param dataset: CTMDataset to infer topics 627 | :param n_samples: number of sampling of theta 628 | :return: the predicted topics 629 | """ 630 | predicted_topics = [] 631 | thetas = self.get_doc_topic_distribution(dataset, n_samples) 632 | 633 | for idd in range(len(dataset)): 634 | predicted_topic = np.argmax(thetas[idd] / np.sum(thetas[idd])) 635 | predicted_topics.append(predicted_topic) 636 | return predicted_topics 637 | 638 | def get_ldavis_data_format(self, vocab, dataset, n_samples): 639 | """ 640 | Returns the data that can be used in input to pyldavis to plot 641 | the topics 642 | """ 643 | term_frequency = dataset.X_bow.toarray().sum(axis=0) 644 | doc_lengths = dataset.X_bow.toarray().sum(axis=1) 645 | term_topic = self.get_topic_word_distribution() 646 | doc_topic_distribution = self.get_doc_topic_distribution(dataset, n_samples=n_samples) 647 | 648 | data = {'topic_term_dists': term_topic, 649 | 'doc_topic_dists': doc_topic_distribution, 650 | 'doc_lengths': doc_lengths, 651 | 'vocab': vocab, 652 | 'term_frequency': term_frequency} 653 | 654 | return data 655 | --------------------------------------------------------------------------------
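A minimal usage sketch of MultimodalContrastiveTM, assuming `train_dataset`, `val_dataset` and `test_dataset` are placeholder PyTorch Dataset objects whose batches provide the 'X_bow', 'X_contextual' and 'X_image' tensors expected by the training loop and which expose a per-language `idx2token` mapping (as get_topics() requires); all numeric values below are illustrative, not prescribed by the repository.

from models.M3L_contrast import MultimodalContrastiveTM

# train_dataset / val_dataset / test_dataset are hypothetical placeholders (see note above)
tm = MultimodalContrastiveTM(
    bow_size=2000,                  # vocabulary size of the bag-of-words input (example value)
    contextual_sizes=(768, 2048),   # (text embedding size, image embedding size)
    n_components=50,                # number of topics (example value)
    languages=['en', 'de'],         # language codes of the paired articles
    loss_weights={"KL": 1, "CL": 100})

tm.fit(train_dataset, validation_dataset=val_dataset, save_dir="checkpoints/", verbose=True)

topics = tm.get_topic_lists(k=10)   # one list of top-10 word lists per language
doc_topics = tm.get_doc_topic_distribution(test_dataset, n_samples=20, lang_index=0)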