├── README.md
├── img
│   ├── logograms1_en_postprocess.png
│   ├── model.svg
│   ├── vae2.svg
│   └── wordsGif.gif
├── language_data
│   ├── README.md
│   └── bash_scripts
│       ├── create128dimEmbs_fromWords.sh
│       ├── download_embeddings.sh
│       └── get_evaluation.sh
├── project_proposal.pdf
├── src
│   ├── MapperLayer.py
│   ├── Omniglot.py
│   ├── Omniglot_triplet.py
│   ├── Siamese.py
│   ├── configs
│   │   ├── bbvae.yaml
│   │   ├── bbvae_CompleteRun.yaml
│   │   ├── bbvae_setup2.yaml
│   │   ├── bbvae_setup3.yml
│   │   ├── bbvae_setup4.yml
│   │   └── betatc_vae.yaml
│   ├── decoder_with_discriminator.py
│   ├── experiment.py
│   ├── logogram_language_generator.py
│   ├── lossyMapper_train.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── base.cpython-36.pyc
│   │   │   ├── beta_vae.cpython-36.pyc
│   │   │   ├── betatc_vae.cpython-36.pyc
│   │   │   ├── cat_vae.cpython-36.pyc
│   │   │   ├── cvae.cpython-36.pyc
│   │   │   ├── dfcvae.cpython-36.pyc
│   │   │   ├── dip_vae.cpython-36.pyc
│   │   │   ├── fvae.cpython-36.pyc
│   │   │   ├── gamma_vae.cpython-36.pyc
│   │   │   ├── hvae.cpython-36.pyc
│   │   │   ├── info_vae.cpython-36.pyc
│   │   │   ├── iwae.cpython-36.pyc
│   │   │   ├── joint_vae.cpython-36.pyc
│   │   │   ├── logcosh_vae.cpython-36.pyc
│   │   │   ├── lvae.cpython-36.pyc
│   │   │   ├── miwae.cpython-36.pyc
│   │   │   ├── mssim_vae.cpython-36.pyc
│   │   │   ├── swae.cpython-36.pyc
│   │   │   ├── types_.cpython-36.pyc
│   │   │   ├── vampvae.cpython-36.pyc
│   │   │   ├── vanilla_vae.cpython-36.pyc
│   │   │   ├── vq_vae.cpython-36.pyc
│   │   │   └── wae_mmd.cpython-36.pyc
│   │   ├── base.py
│   │   └── beta_vae.py
│   ├── randomSamples_fromEmbeddings
│   │   ├── sample_fasttext_embs_en.pickle
│   │   ├── sample_fasttext_embs_en_128.pickle
│   │   ├── sample_fasttext_embs_es_128.pickle
│   │   ├── sample_fasttext_embs_fr.pickle
│   │   ├── sample_fasttext_embs_fr_128.pickle
│   │   ├── sample_fasttext_embs_it_128.pickle
│   │   └── sample_from_embeds.py
│   ├── requirements.txt
│   ├── run.py
│   ├── test.py
│   ├── train_mapper_and_discriminator.py
│   ├── train_mapper_with_siamese.py
│   ├── train_siamese.py
│   ├── umwe2vae.py
│   ├── umwe_mappers
│   │   ├── best_mapping_es2en.t7
│   │   ├── best_mapping_fr2en.t7
│   │   └── best_mapping_it2en.t7
│   ├── utils.py
│   └── vae_with_norm.py
└── umwe
    ├── LICENSE
    ├── README.md
    ├── data
    │   └── get_evaluation.sh
    ├── demo.ipynb
    ├── evaluate.py
    ├── src
    │   ├── __init__.py
    │   ├── dico_builder.py
    │   ├── dictionary.py
    │   ├── evaluation
    │   │   ├── __init__.py
    │   │   ├── evaluator.py
    │   │   ├── sent_translation.py
    │   │   ├── word_translation.py
    │   │   └── wordsim.py
    │   ├── logger.py
    │   ├── models.py
    │   ├── trainer.py
    │   └── utils.py
    ├── supervised.py
    └── unsupervised.py

/README.md:
--------------------------------------------------------------------------------
1 | # Logogram Language Generator
2 | 
3 | This project was developed within the scope of inzva AI Projects #5 (August-November 2020). Check out the other projects on the inzva GitHub.
4 | 
5 | 
6 | [Medium post](https://medium.com/@selim.seker00/logogram-language-generator-eb003293d51d)
7 | 
8 | [Recording](https://www.youtube.com/watch?v=hTKlFtC7NMw) of the final project presentation (in Turkish)
9 | 
10 | 
11 | 
12 | ![wordsGif](./img/wordsGif.gif)

13 | ![logograms1_en_postprocess](./img/logograms1_en_postprocess.png)
14 | 
15 | 
16 | 
17 | 
18 | ## Readings
19 | 
20 | #### Multilingual embedding:
21 | 
22 | * https://arxiv.org/pdf/1808.08933.pdf
23 | * https://ai.facebook.com/tools/muse/
24 | 
25 | * http://jalammar.github.io/illustrated-word2vec/
26 | 
27 | * https://ruder.io/cross-lingual-embeddings/
28 | 
29 | 
30 | 
31 | ##### About linear mapping between language embeddings:
32 | 
33 | * https://arxiv.org/pdf/1309.4168.pdf
34 | 
35 | * https://www.aclweb.org/anthology/N15-1104.pdf
36 | 
37 | 
38 | 
39 | ##### About fastText (pre-trained monolingual embeddings in UMWE):
40 | 
41 | * https://fasttext.cc/
42 | 
43 | * https://www.aclweb.org/anthology/Q17-1010.pdf
44 | 
45 | 
46 | 
47 | 
48 | 
49 | #### Generative Models:
50 | 
51 | * https://arxiv.org/pdf/1312.6114.pdf
52 | 
53 | * https://arxiv.org/pdf/1511.05644.pdf
54 | 
55 | * https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
56 | 
57 | * https://towardsdatascience.com/generating-images-with-autoencoders-77fd3a8dd368
58 | 
59 | * https://theaisummer.com/Autoencoder/
60 | 
61 | * https://towardsdatascience.com/generative-variational-autoencoder-for-high-resolution-image-synthesis-48dd98d4dcc2
62 | 
63 | 
64 | ## Datasets
65 | 
66 | * https://github.com/brendenlake/omniglot
67 | * https://fasttext.cc/
68 | 
69 | 
70 | #### Drive Link for Docs
71 | 
72 | * https://drive.google.com/drive/folders/1md5ZayU85yaKnx0d6LM63gR1NwQNd3pX?usp=sharing
73 | 
74 | 
75 | 
76 | 
77 | ![model](./img/model.svg)
78 | 
--------------------------------------------------------------------------------
/img/logograms1_en_postprocess.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/img/logograms1_en_postprocess.png
--------------------------------------------------------------------------------
/img/wordsGif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/img/wordsGif.gif
--------------------------------------------------------------------------------
/language_data/README.md:
--------------------------------------------------------------------------------
1 | Here is the procedure we applied for UMWE training:
2 | 1. We set our VAE's latent dimension to 128, while the fastText monolingual vectors are 300-dimensional.
3 | 2. UMWE takes its training data from wiki.vec files (whitespace-separated plain text: "word 300dVector").
4 | 3. So we need to reduce the vectors in the wiki.vec files (a separate file for each language).
5 | 4. Besides the .vec files, fastText provides a .bin embedding binary for each language. We first need to reduce the binaries' dimension (there is a script for that in the fastText repo).
6 | 5. After reducing each .bin file, we again use a fastText script to extract embedding vectors for a word list.
7 | 6. To use that script, we need to split the word list out of the .vec files (a small Python script does it; see the sketch after this list).
8 | 7. Then create the new 128-dimensional .vec files.
9 | 8. Don't forget to change the hyperparameters when training UMWE with 128-dim embeddings.
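A minimal sketch of step 6, assuming the standard fastText `.vec` layout (a `num_words dim` header line followed by one `word v1 ... vN` entry per line); the paths in the usage comment are illustrative:

```python
# extract_words.py - splits the word list out of a fastText .vec file (step 6)
import sys

def extract_words(vec_path, out_path):
    with open(vec_path, encoding="utf-8") as src, open(out_path, "w", encoding="utf-8") as dst:
        next(src)  # skip the "num_words dim" header line
        for line in src:
            dst.write(line.split(" ", 1)[0] + "\n")  # keep the token before the first space

if __name__ == "__main__":
    # e.g.: python extract_words.py wiki.en.vec wiki_wordsOnly/wiki.en.128.vec
    extract_words(sys.argv[1], sys.argv[2])
```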
10 | 
--------------------------------------------------------------------------------
/language_data/bash_scripts/create128dimEmbs_fromWords.sh:
--------------------------------------------------------------------------------
1 | # don't forget to prepend a "num_of_words emb_dim" header line for umwe,
2 | # e.g. with GNU sed: sed -i "1i $(wc -l < ../wiki_vectors_128/wiki.en.vec) 128" ../wiki_vectors_128/wiki.en.vec
3 | #./fasttext print-word-vectors ../cc.en.128.bin < ../wiki_wordsOnly/wiki.en.128.vec > ../wiki_vectors_128/wiki.en.vec
4 | #./fasttext print-word-vectors ../cc.es.128.bin < ../wiki_wordsOnly/wiki.es.128.vec > ../wiki_vectors_128/wiki.es.vec
5 | #./fasttext print-word-vectors ../cc.fr.128.bin < ../wiki_wordsOnly/wiki.fr.128.vec > ../wiki_vectors_128/wiki.fr.vec
6 | #./fasttext print-word-vectors ../cc.it.128.bin < ../wiki_wordsOnly/wiki.it.128.vec > ../wiki_vectors_128/wiki.it.vec
7 | ../fastText/fasttext print-word-vectors ../../vae/model_checkpoints/cc_bins_128/cc.tr.128.bin < ../wiki_wordsOnly/wiki.tr.128.vec > ../wiki_vectors_128/wiki.tr.vec
8 | 
--------------------------------------------------------------------------------
/language_data/bash_scripts/download_embeddings.sh:
--------------------------------------------------------------------------------
1 | 
2 | ### This bash script downloads the fastText monolingual embedding vectors for:
3 | ### [en, fr, es, it, tr]
4 | 
5 | #curl -o wiki.en.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.en.vec
6 | #curl -o wiki.fr.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.fr.vec
7 | #curl -o wiki.es.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.es.vec
8 | #curl -o wiki.it.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.it.vec
9 | #curl -o wiki.tr.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.tr.vec
10 | #curl -o wiki.ru.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.ru.vec
11 | # the crawl binaries are served gzipped, so save them as .gz and unpack afterwards
12 | curl -o cc.fr.300.bin.gz https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.bin.gz
13 | curl -o cc.es.300.bin.gz https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.es.300.bin.gz
14 | curl -o cc.it.300.bin.gz https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.it.300.bin.gz
15 | curl -o cc.tr.300.bin.gz https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tr.300.bin.gz
16 | gunzip cc.*.300.bin.gz
--------------------------------------------------------------------------------
/language_data/bash_scripts/get_evaluation.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | en_analogy='https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/word2vec/source-archive.zip' 9 | dl_path='https://dl.fbaipublicfiles.com/arrival' 10 | semeval_2017='http://alt.qcri.org/semeval2017/task2/data/uploads' 11 | europarl='http://www.statmt.org/europarl/v7' 12 | 13 | #declare -A wordsim_lg 14 | #wordsim_lg=(["en"]="EN_MC-30.txt EN_MTurk-287.txt EN_RG-65.txt EN_VERB-143.txt EN_WS-353-REL.txt EN_YP-130.txt EN_MEN-TR-3k.txt EN_MTurk-771.txt EN_RW-STANFORD.txt EN_SIMLEX-999.txt EN_WS-353-ALL.txt EN_WS-353-SIM.txt" ["es"]="ES_MC-30.txt ES_RG-65.txt ES_WS-353.txt" ["de"]="DE_GUR350.txt DE_GUR65.txt DE_SIMLEX-999.txt DE_WS-353.txt DE_ZG222.txt" ["fr"]="FR_RG-65.txt" ["it"]="IT_SIMLEX-999.txt IT_WS-353.txt") 15 | 16 | declare -A wordsim_lg 17 | wordsim_lg=(["fr"]="FR_RG-65.txt" ["it"]="IT_SIMLEX-999.txt IT_WS-353.txt") 18 | 19 | 20 | mkdir monolingual crosslingual 21 | 22 | ## English word analogy task 23 | curl -Lo source-archive.zip $en_analogy 24 | mkdir -p monolingual/en/ 25 | unzip -p source-archive.zip word2vec/trunk/questions-words.txt > monolingual/en/questions-words.txt 26 | rm source-archive.zip 27 | 28 | 29 | ## Downloading en-{} or {}-en dictionaries 30 | #lgs="af ar bg bn bs ca cs da de el en es et fa fi fr he hi hr hu id it ja ko lt lv mk ms nl no pl pt ro ru sk sl sq sv ta th tl tr uk vi zh" 31 | lgs="fr he hi hr hu id it ja ko lt lv mk ms nl no pl pt ro ru sk sl sq sv ta th tl tr uk vi zh" 32 | 33 | mkdir -p crosslingual/dictionaries/ 34 | for lg in ${lgs} 35 | do 36 | for suffix in .txt .0-5000.txt .5000-6500.txt 37 | do 38 | fname=en-$lg$suffix 39 | curl -Lo crosslingual/dictionaries/$fname $dl_path/dictionaries/$fname 40 | fname=$lg-en$suffix 41 | curl -Lo crosslingual/dictionaries/$fname $dl_path/dictionaries/$fname 42 | done 43 | done 44 | 45 | ## Download European dictionaries 46 | for src_lg in de es fr it pt 47 | do 48 | for tgt_lg in de es fr it pt 49 | do 50 | if [ $src_lg != $tgt_lg ] 51 | then 52 | for suffix in .txt .0-5000.txt .5000-6500.txt 53 | do 54 | fname=$src_lg-$tgt_lg$suffix 55 | curl -Lo crosslingual/dictionaries/$fname $dl_path/dictionaries/$fname 56 | done 57 | fi 58 | done 59 | done 60 | 61 | ## Download Dinu et al. 
dictionaries 62 | for fname in OPUS_en_it_europarl_train_5K.txt OPUS_en_it_europarl_test.txt 63 | do 64 | echo $fname 65 | curl -Lo crosslingual/dictionaries/$fname $dl_path/dictionaries/$fname 66 | done 67 | 68 | ## Monolingual wordsim tasks 69 | for lang in "${!wordsim_lg[@]}" 70 | do 71 | echo $lang 72 | mkdir monolingual/$lang 73 | for wsim in ${wordsim_lg[$lang]} 74 | do 75 | echo $wsim 76 | curl -Lo monolingual/$lang/$wsim $dl_path/$lang/$wsim 77 | done 78 | done 79 | 80 | ## SemEval 2017 monolingual and cross-lingual wordsim tasks 81 | # 1) Task1: monolingual 82 | curl -Lo semeval2017-task2.zip $semeval_2017/semeval2017-task2.zip 83 | unzip semeval2017-task2.zip 84 | 85 | fdir='SemEval17-Task2/test/subtask1-monolingual' 86 | for lang in en es de fa it 87 | do 88 | mkdir -p monolingual/$lang 89 | uplang=`echo $lang | awk '{print toupper($0)}'` 90 | paste $fdir/data/$lang.test.data.txt $fdir/keys/$lang.test.gold.txt > monolingual/$lang/${uplang}_SEMEVAL17.txt 91 | done 92 | 93 | # 2) Task2: cross-lingual 94 | mkdir -p crosslingual/wordsim 95 | fdir='SemEval17-Task2/test/subtask2-crosslingual' 96 | for lg_pair in de-es de-fa de-it en-de en-es en-fa en-it es-fa es-it it-fa 97 | do 98 | echo $lg_pair 99 | paste $fdir/data/$lg_pair.test.data.txt $fdir/keys/$lg_pair.test.gold.txt > crosslingual/wordsim/$lg_pair-SEMEVAL17.txt 100 | done 101 | rm semeval2017-task2.zip 102 | rm -r SemEval17-Task2/ 103 | 104 | ## Europarl for sentence retrieval 105 | # TODO: set to true to activate download of Europarl (slow) 106 | if false; then 107 | mkdir -p crosslingual/europarl 108 | # Tokenize EUROPARL with MOSES 109 | echo 'Cloning Moses github repository (for tokenization scripts)...' 110 | git clone https://github.com/moses-smt/mosesdecoder.git 111 | SCRIPTS=mosesdecoder/scripts 112 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 113 | 114 | for lg_pair in it-en # es-en etc 115 | do 116 | curl -Lo $lg_pair.tgz $europarl/$lg_pair.tgz 117 | tar -xvf $lg_pair.tgz 118 | rm $lg_pair.tgz 119 | lgs=(${lg_pair//-/ }) 120 | for lg in ${lgs[0]} ${lgs[1]} 121 | do 122 | cat europarl-v7.$lg_pair.$lg | $TOKENIZER -threads 8 -l $lg -no-escape > euro.$lg.txt 123 | rm europarl-v7.$lg_pair.$lg 124 | done 125 | 126 | paste euro.${lgs[0]}.txt euro.${lgs[1]}.txt | shuf > euro.paste.txt 127 | rm euro.${lgs[0]}.txt euro.${lgs[1]}.txt 128 | 129 | cut -f1 euro.paste.txt > crosslingual/europarl/europarl-v7.$lg_pair.${lgs[0]} 130 | cut -f2 euro.paste.txt > crosslingual/europarl/europarl-v7.$lg_pair.${lgs[1]} 131 | rm euro.paste.txt 132 | done 133 | 134 | rm -rf mosesdecoder 135 | fi 136 | -------------------------------------------------------------------------------- /project_proposal.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/project_proposal.pdf -------------------------------------------------------------------------------- /src/MapperLayer.py: -------------------------------------------------------------------------------- 1 | from Omniglot import Omniglot 2 | import yaml 3 | import argparse 4 | import numpy as np 5 | 6 | from models import * 7 | from experiment import VAEXperiment 8 | import torch.backends.cudnn as cudnn 9 | from pytorch_lightning import Trainer 10 | from pytorch_lightning.logging import TestTubeLogger 11 | import torchvision.utils as vutils 12 | from torch.utils.data import DataLoader 13 | import torch.optim as optim 14 | from torchvision import 
transforms 15 | 16 | import torch 17 | from torch import nn 18 | import torch.optim as optim 19 | from torchvision.transforms.functional import adjust_contrast 20 | 21 | 22 | import pickle 23 | import random 24 | 25 | 26 | 27 | class umwe2vae(nn.Module): 28 | def __init__(self,vae_model, in_dim=300, out_dim=128): 29 | super(umwe2vae, self).__init__() 30 | self.vae_model = vae_model 31 | self.fc = nn.Linear(in_dim, out_dim) 32 | 33 | def forward(self, x): 34 | h = self.fc(x) 35 | # y = self.vae_model.decode(h) 36 | return h 37 | # here used to live post-processing 38 | #out = torch.zeros(y.shape) 39 | #for i in range(y.shape[0]): 40 | # out[i] = adjust_contrast(y[i], contrast_factor=2.5) 41 | #return out 42 | 43 | def loss(self, x, alpha=1, beta=1): 44 | middle = x[:,:,1:-1,1:-1] 45 | ne = x[:,:,0:-2,0:-2] 46 | n = x[:,:,0:-2,1:-1] 47 | nw = x[:,:,0:-2,2:] 48 | e = x[:,:,1:-1,0:-2] 49 | w = x[:,:,1:-1,2:] 50 | se = x[:,:,2:,0:-2] 51 | s = x[:,:,2:,1:-1] 52 | sw = x[:,:,2:,2:] 53 | 54 | return alpha * torch.mean(sum([torch.abs(middle-ne),torch.abs(middle-n),torch.abs(middle-nw),torch.abs(middle-e),torch.abs(middle-w),torch.abs(middle-se),torch.abs(middle-s),torch.abs(middle-sw)]) / 8.) - beta * torch.mean(torch.abs(x-0.5)) 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | class EmbeddingMapping(nn.Module): 64 | def __init__(self, device, embedding_vector_dim = 300, decoder_input_dim=128): 65 | super(EmbeddingMapping, self).__init__() 66 | self.device = device 67 | self.embedding_vector_dim = embedding_vector_dim 68 | self.decoder_input_dim = decoder_input_dim 69 | self.mapper_numlayer = 3 70 | 71 | self.linear_layers = [] 72 | self.batch_norms = [] 73 | for layer in range(0, self.mapper_numlayer-1): 74 | self.linear_layers.append(nn.Linear(embedding_vector_dim, embedding_vector_dim)) 75 | self.batch_norms.append(nn.BatchNorm1d(embedding_vector_dim)) 76 | 77 | # final layer 78 | self.linear_layers.append(nn.Linear(embedding_vector_dim, decoder_input_dim)) 79 | self.batch_norms.append(nn.BatchNorm1d(decoder_input_dim)) 80 | 81 | 82 | self.linear_layers = nn.ModuleList(self.linear_layers) 83 | self.batch_norms = nn.ModuleList(self.batch_norms) 84 | 85 | 86 | self.relu = nn.ReLU() 87 | 88 | def forward(self, embedding_vector): 89 | inp = embedding_vector 90 | for layer in range(self.mapper_numlayer): 91 | out = self.linear_layers[layer](inp) 92 | out = self.batch_norms[layer](out) 93 | out = self.relu(out) 94 | inp = out 95 | return out 96 | 97 | 98 | 99 | class MultilingualMapper(nn.Module): 100 | def __init__(self, device, embedding_vector_dim = 300, decoder_input_dim=128): 101 | super(MultilingualMapper, self).__init__() 102 | self.device = device 103 | self.embedding_vector_dim = embedding_vector_dim 104 | self.decoder_input_dim = decoder_input_dim 105 | self.mapper_numlayer = 3 106 | 107 | self.linear_layers = [] 108 | self.batch_norms = [] 109 | for layer in range(0, self.mapper_numlayer-1): 110 | self.linear_layers.append(nn.Linear(embedding_vector_dim, embedding_vector_dim)) 111 | self.batch_norms.append(nn.BatchNorm1d(embedding_vector_dim)) 112 | 113 | # final layer 114 | self.linear_layers.append(nn.Linear(embedding_vector_dim, decoder_input_dim)) 115 | self.batch_norms.append(nn.BatchNorm1d(decoder_input_dim)) 116 | 117 | 118 | self.linear_layers = nn.ModuleList(self.linear_layers) 119 | self.batch_norms = nn.ModuleList(self.batch_norms) 120 | self.relu = nn.ReLU() 121 | 122 | self.bce = nn.BCEWithLogitsLoss() 123 | 124 | def forward(self, embedding_vector): 125 | inp = embedding_vector 
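        # same trunk as EmbeddingMapping above: mapper_numlayer blocks of
        # Linear -> BatchNorm1d -> ReLU, with the last block projecting the
        # 300-d embedding down to the 128-d decoder input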
126 | for layer in range(self.mapper_numlayer): 127 | out = self.linear_layers[layer](inp) 128 | out = self.batch_norms[layer](out) 129 | out = self.relu(out) 130 | inp = out 131 | return out 132 | 133 | def triplet_loss(self, sameWords_diffLangs, diffWords_sameLangs): 134 | return self.bce(sameWords_diffLangs, torch.ones(sameWords_diffLangs.shape).to(self.device)) + self.bce(diffWords_sameLangs, torch.zeros(diffWords_sameLangs.shape).to(self.device)) 135 | 136 | 137 | -------------------------------------------------------------------------------- /src/Omniglot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from PIL import Image 4 | from torchvision import transforms 5 | import os 6 | 7 | class Omniglot(Dataset): 8 | def __init__(self, split="train", transform=None): 9 | self.split = split 10 | if(transform==None): 11 | self.transform = transforms.ToTensor() 12 | else: 13 | self.transform = transform 14 | 15 | if(split=="train"): 16 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh'] 17 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55] 18 | 19 | elif(split=="test"): 20 | self.alphabet_names = ['Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG'] 21 | self.character_nums = [20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26] 22 | 23 | else: # all splits 24 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh', 'Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG'] 25 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55, 20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26] 26 | 27 | self.n_images = sum(self.character_nums) * 20 28 | 29 | def __len__(self): 30 | return self.n_images 31 | 32 | def __getitem__(self, idx): 33 | j = idx % 20 34 | i = idx // 20 35 | label = i 36 | 37 | for k in range(len(self.alphabet_names)): 38 | 
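            # walk the per-alphabet character counts to find which alphabet the
            # flat character index i falls into; i is reduced as earlier
            # alphabets are skipped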
if(i < self.character_nums[k]): 39 | alphabet = self.alphabet_names[k]+"/" 40 | if(self.split=="train"): 41 | folder = "images_background/" 42 | elif(self.split=="test"): 43 | folder = "images_evaluation/" 44 | else: 45 | folder = "images_background/" if k<30 else "images_evaluation/" 46 | break 47 | else: 48 | i -= self.character_nums[k] 49 | 50 | # /content/omniglot 51 | # char_name = "/content/omniglot/"+folder+alphabet+"character"+(str(i+1) if i>8 else "0"+str(i+1))+"/" 52 | char_name = "omniglot/"+folder+alphabet+"character"+(str(i+1) if i>8 else "0"+str(i+1))+"/" 53 | example_list = sorted(os.listdir(char_name)) 54 | img = Image.open(char_name+example_list[j]) 55 | 56 | return (self.transform(img.convert('L')), label) 57 | -------------------------------------------------------------------------------- /src/Omniglot_triplet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from PIL import Image 4 | from torchvision import transforms 5 | import os 6 | import random 7 | 8 | class Omniglot(Dataset): 9 | def __init__(self, split="train", transform=None): 10 | self.split = split 11 | if(transform==None): 12 | self.transform = transforms.ToTensor() 13 | else: 14 | self.transform = transform 15 | 16 | if(split=="train"): 17 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh'] 18 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55] 19 | 20 | elif(split=="test"): 21 | self.alphabet_names = ['Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG'] 22 | self.character_nums = [20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26] 23 | 24 | else: # all splits 25 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh', 'Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG'] 26 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55, 20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 
40, 30, 45, 46, 28, 23, 25, 42, 26] 27 | 28 | self.n_images = sum(self.character_nums) * 20 29 | 30 | def __len__(self): 31 | return self.n_images 32 | 33 | def __getitem__(self, idx): 34 | j = idx % 20 35 | i = idx // 20 36 | label = i 37 | 38 | for k in range(len(self.alphabet_names)): 39 | if(i < self.character_nums[k]): 40 | alphabet = self.alphabet_names[k]+"/" 41 | if(self.split=="train"): 42 | folder = "images_background/" 43 | elif(self.split=="test"): 44 | folder = "images_evaluation/" 45 | else: 46 | folder = "images_background/" if k<30 else "images_evaluation/" 47 | break 48 | else: 49 | i -= self.character_nums[k] 50 | 51 | char_name = "omniglot/"+folder+alphabet+"character"+(str(i+1) if i>8 else "0"+str(i+1))+"/" 52 | example_list = sorted(os.listdir(char_name)) 53 | img = Image.open(char_name+example_list[j]) 54 | 55 | return (self.transform(img.convert('L')), label) 56 | 57 | class Omniglot_triplet(Dataset): 58 | def __init__(self, split="train", transform=None): 59 | self.split = split 60 | if(transform==None): 61 | self.transform = transforms.ToTensor() 62 | else: 63 | self.transform = transform 64 | 65 | if(split=="train"): 66 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh'] 67 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55] 68 | 69 | elif(split=="test"): 70 | self.alphabet_names = ['Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG'] 71 | self.character_nums = [20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26] 72 | 73 | else: # all splits 74 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh', 'Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG'] 75 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55, 20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26] 76 | 77 | self.n_images = sum(self.character_nums) * 20 78 | 79 | def __len__(self): 80 | return self.n_images 81 | 82 | def 
get_single_image(self, idx):
83 |         j = idx % 20
84 |         i = idx // 20
85 |         label = i
86 | 
87 |         for k in range(len(self.alphabet_names)):
88 |             if(i < self.character_nums[k]):
89 |                 alphabet = self.alphabet_names[k]+"/"
90 |                 if(self.split=="train"):
91 |                     folder = "images_background/"
92 |                 elif(self.split=="test"):
93 |                     folder = "images_evaluation/"
94 |                 else:
95 |                     folder = "images_background/" if k<30 else "images_evaluation/"
96 |                 break
97 |             else:
98 |                 i -= self.character_nums[k]
99 | 
100 |         # char_name = "omniglot/"+folder+alphabet+"character"+(str(i+1) if i>8 else "0"+str(i+1))+"/"
101 |         char_name = "./omniglot/"+folder+alphabet+"character"+(str(i+1) if i>8 else "0"+str(i+1))+"/"
102 |         example_list = sorted(os.listdir(char_name))
103 |         img = Image.open(char_name+example_list[j])
104 | 
105 |         return self.transform(img.convert('L'))
106 | 
107 |     def __getitem__(self, idx):
108 |         j = idx % 20
109 |         prange = list(range(idx-j, idx)) + list(range(idx+1, idx-j+20))  # other drawings of the same character (positives)
110 |         nrange = list(range(idx-j)) + list(range(idx-j+20,self.n_images))  # drawings of any other character (negatives)
111 |         idx_p = random.choice(prange)
112 |         idx_n = random.choice(nrange)
113 | 
114 |         return (self.get_single_image(idx),self.get_single_image(idx_p),self.get_single_image(idx_n))
115 | 
116 | 
--------------------------------------------------------------------------------
/src/Siamese.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | 
6 | class Siamese(nn.Module):
7 | 
8 |     def __init__(self):
9 |         super(Siamese, self).__init__()
10 |         self.encoder = nn.Sequential(
11 |             nn.Conv2d(1, 64, 10), # 64@96*96
12 |             nn.ReLU(inplace=True),
13 |             nn.MaxPool2d(2), # 64@48*48
14 |             nn.Conv2d(64, 128, 7),
15 |             nn.ReLU(), # 128@42*42
16 |             nn.MaxPool2d(2), # 128@21*21
17 |             nn.Conv2d(128, 128, 4),
18 |             nn.ReLU(), # 128@18*18
19 |             nn.MaxPool2d(2), # 128@9*9
20 |             nn.Conv2d(128, 256, 4),
21 |             nn.ReLU(), # 256@6*6
22 |             nn.Flatten(),
23 |             nn.Linear(9216, 4096),
24 |             nn.Sigmoid()
25 |         )
26 | 
27 |         self.classifier = nn.Linear(4096, 1)
28 |         self.bce_fn = nn.BCEWithLogitsLoss()
29 | 
30 |     def forward(self, x1, x2):
31 |         z1 = self.encoder.forward(x1)
32 |         z2 = self.encoder.forward(x2)
33 | 
34 |         z = torch.abs(z1 - z2)
35 |         y = self.classifier.forward(z)
36 |         return y
37 | 
38 |     def triplet_loss(self, yp, yn):  # BCE on the similarity logits: positive pairs -> 1, negative pairs -> 0
39 |         return self.bce_fn(yp, torch.ones(yp.shape).cuda()) + self.bce_fn(yn, torch.zeros(yn.shape).cuda())
40 | 
41 |     def train_triplet(self, loader, epochs=100):
42 |         self.train()
43 |         print("new_siamese")
44 |         optimizer = optim.Adam(self.parameters(), lr=0.00001)
45 |         scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
46 |         minLoss = 1000.5
47 |         patience = 0
48 |         maxPatience = 5
49 |         for epoch in range(epochs):
50 |             total_loss = 0
51 |             for i,(xa,xp,xn) in enumerate(loader):
52 |                 yp = self.forward(xa.cuda(), xp.cuda())
53 |                 yn = self.forward(xa.cuda(), xn.cuda())
54 | 
55 |                 loss = self.triplet_loss(yp, yn)
56 |                 optimizer.zero_grad()  # was missing: without it, gradients accumulate across batches
57 |                 loss.backward()
58 |                 optimizer.step()
59 | 
60 |                 total_loss += loss.detach()  # detach so each batch's graph is not kept alive
61 |             print("Epoch: %d Loss: %f" % (epoch+1, total_loss.item()/len(loader)))
62 |             scheduler.step()
63 | 
64 |             if total_loss.item()/len(loader) < minLoss:
65 |                 minLoss = total_loss.item()/len(loader)
66 |                 patience = 0
67 |             else:
68 |                 patience +=1
69 |             print("patienceCounter: ", patience)
70 |             if patience == maxPatience:
71 |                 print("early stopping: ", minLoss)
72 |                 break
73 |         torch.save(self.state_dict(), "siamese.ckpt")
74 |         #torch.save({'model_state_dict':
self.state_dict()}, "siamese_EPOCH100.ckpt") 75 | 76 | def eval_triplet(self, loader): 77 | self.eval() 78 | correct,total = 0,0 79 | for i,(xa,xp,xn) in enumerate(loader): 80 | yp = self.forward(xa.cuda(), xp.cuda()) 81 | yn = self.forward(xa.cuda(), xn.cuda()) 82 | correct += torch.sum((yp>0).float()) + torch.sum((yn<0).float()) 83 | #correct += torch.sum(torch.FloatTensor(yp>0)) + torch.sum(torch.FloatTensor(yn<0)) 84 | #correct += torch.sum(yp>0) + torch.sum(yn<0) 85 | total += yp.shape[0] + yn.shape[0] 86 | 87 | if i==100: 88 | break 89 | 90 | print("Accuracy: %f" % (correct/total).item()) 91 | 92 | -------------------------------------------------------------------------------- /src/configs/bbvae.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | name: 'BetaVAE' 3 | in_channels: 1 4 | latent_dim: 300 5 | loss_type: 'B' 6 | gamma: 10.0 7 | max_capacity: 25 8 | Capacity_max_iter: 10000 9 | 10 | exp_params: 11 | dataset: omniglot 12 | data_path: "../../shared/Data/" 13 | img_size: 64 14 | batch_size: 144 # Better to have a square number 15 | LR: 0.00005 16 | weight_decay: 0.0 17 | scheduler_gamma: 0.95 18 | 19 | trainer_params: 20 | gpus: 1 21 | max_nb_epochs: 50 22 | max_epochs: 50 23 | 24 | logging_params: 25 | save_dir: "logs/" 26 | name: "BetaVAE_B" 27 | manual_seed: 1265 28 | -------------------------------------------------------------------------------- /src/configs/bbvae_CompleteRun.yaml: -------------------------------------------------------------------------------- 1 | # latent_dim 300 to 128 2 | # batch_size 144 to 64 3 | 4 | model_params: 5 | name: 'BetaVAE' 6 | in_channels: 1 7 | latent_dim: 128 8 | loss_type: 'B' 9 | gamma: 10.0 10 | max_capacity: 25 11 | Capacity_max_iter: 10000 12 | 13 | exp_params: 14 | dataset: omniglot 15 | data_path: "../../shared/Data/" 16 | img_size: 64 17 | batch_size: 64 # Better to have a square number 18 | LR: 0.00005 19 | weight_decay: 0.0 20 | scheduler_gamma: 0.95 21 | 22 | trainer_params: 23 | gpus: 1 24 | max_nb_epochs: 70 25 | max_epochs: 70 26 | 27 | logging_params: 28 | save_dir: "logs/" 29 | name: "BetaVAE_B_completeRun" 30 | manual_seed: 1265 -------------------------------------------------------------------------------- /src/configs/bbvae_setup2.yaml: -------------------------------------------------------------------------------- 1 | # latent_dim 300 to 128 2 | # batch_size 144 to 64 3 | 4 | model_params: 5 | name: 'BetaVAE' 6 | in_channels: 1 7 | latent_dim: 128 8 | loss_type: 'B' 9 | gamma: 10.0 10 | max_capacity: 25 11 | Capacity_max_iter: 10000 12 | 13 | exp_params: 14 | dataset: omniglot 15 | data_path: "../../shared/Data/" 16 | img_size: 64 17 | batch_size: 64 # Better to have a square number 18 | LR: 0.00005 19 | weight_decay: 0.0 20 | scheduler_gamma: 0.95 21 | 22 | trainer_params: 23 | gpus: 1 24 | max_nb_epochs: 50 25 | max_epochs: 50 26 | 27 | logging_params: 28 | save_dir: "logs/" 29 | name: "BetaVAE_B_setup2_run2" 30 | manual_seed: 1265 -------------------------------------------------------------------------------- /src/configs/bbvae_setup3.yml: -------------------------------------------------------------------------------- 1 | # batch_size 144 to 64 2 | 3 | model_params: 4 | name: "BetaVAE" 5 | in_channels: 1 6 | latent_dim: 300 7 | loss_type: "B" 8 | gamma: 10.0 9 | max_capacity: 25 10 | Capacity_max_iter: 10000 11 | 12 | exp_params: 13 | dataset: omniglot 14 | data_path: "/content/" 15 | img_size: 64 16 | batch_size: 64 # Better to have a square number 17 
|   LR: 0.00005
18 |   weight_decay: 0.0
19 |   scheduler_gamma: 0.95
20 | 
21 | trainer_params:
22 |   gpus: 1
23 |   max_nb_epochs: 50
24 |   max_epochs: 50
25 | 
26 | logging_params:
27 |   save_dir: "logs/"
28 |   name: "BetaVAE_setup3_run2"
29 |   manual_seed: 1265
30 | 
--------------------------------------------------------------------------------
/src/configs/bbvae_setup4.yml:
--------------------------------------------------------------------------------
1 | # batch_size 144 to 64
2 | # latent_dim 128 to 64
3 | 
4 | model_params:
5 |   name: "BetaVAE"
6 |   in_channels: 1
7 |   latent_dim: 64
8 |   loss_type: "B"
9 |   gamma: 10.0
10 |   max_capacity: 25
11 |   Capacity_max_iter: 10000
12 | 
13 | exp_params:
14 |   dataset: omniglot
15 |   data_path: "../../shared/Data/"
16 |   img_size: 64
17 |   batch_size: 64 # Better to have a square number
18 |   LR: 0.00005
19 |   weight_decay: 0.0
20 |   scheduler_gamma: 0.95
21 | 
22 | trainer_params:
23 |   gpus: 1
24 |   max_nb_epochs: 50
25 |   max_epochs: 50
26 | 
27 | logging_params:
28 |   save_dir: "logs/"
29 |   name: "BetaVAE_setup4"
30 |   manual_seed: 1265
31 | 
--------------------------------------------------------------------------------
/src/configs/betatc_vae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'BetaTCVAE'
3 |   in_channels: 3
4 |   latent_dim: 10
5 |   anneal_steps: 10000
6 |   alpha: 1.
7 |   beta: 6.
8 |   gamma: 1.
9 | 
10 | exp_params:
11 |   dataset: celeba
12 |   data_path: "../../shared/momo/Data/"
13 |   img_size: 64
14 |   batch_size: 144 # Better to have a square number
15 |   LR: 0.001
16 |   weight_decay: 0.0
17 | #  scheduler_gamma: 0.99
18 | 
19 | trainer_params:
20 |   gpus: 1
21 |   max_nb_epochs: 50
22 |   max_epochs: 50
23 | 
24 | logging_params:
25 |   save_dir: "logs/"
26 |   name: "BetaTCVAE"
27 |   manual_seed: 1265
28 | 
--------------------------------------------------------------------------------
/src/decoder_with_discriminator.py:
--------------------------------------------------------------------------------
1 | from Omniglot import Omniglot
2 | import yaml
3 | import argparse
4 | import numpy as np
5 | import torch  # explicit import; torch is otherwise only exposed through the star import below
6 | from models import *
7 | from experiment import VAEXperiment
8 | import torch.backends.cudnn as cudnn
9 | from pytorch_lightning import Trainer
10 | from pytorch_lightning.logging import TestTubeLogger
11 | import torchvision.utils as vutils
12 | from torch.utils.data import DataLoader
13 | import torch.optim as optim
14 | from torchvision import transforms
15 | from torch import nn  # explicit import, same reason as torch above
16 | import pickle
17 | import random
18 | 
19 | import fasttext
20 | import fasttext.util
21 | 
22 | 
23 | # 3-layer MLP that maps a 300-d word embedding to the VAE's 128-d latent input
24 | class EmbeddingMapping(nn.Module):
25 |     def __init__(self, device, embedding_vector_dim = 300, decoder_input_dim=128):
26 |         super(EmbeddingMapping, self).__init__()
27 |         self.device = device
28 |         self.embedding_vector_dim = embedding_vector_dim
29 |         self.decoder_input_dim = decoder_input_dim
30 |         self.linear_mapping1 = nn.Linear(embedding_vector_dim, embedding_vector_dim)
31 |         self.linear_mapping2 = nn.Linear(embedding_vector_dim, embedding_vector_dim)
32 |         self.linear_mapping3 = nn.Linear(embedding_vector_dim, decoder_input_dim)
33 |         self.relu = nn.ReLU()
34 |         self.nrm1 = nn.BatchNorm1d(embedding_vector_dim)
35 |         self.nrm2 = nn.BatchNorm1d(decoder_input_dim)
36 | 
37 |     def forward(self, embedding_vector):
38 |         l1 = self.linear_mapping1(embedding_vector)
39 |         l1 = self.nrm1(l1)
40 |         l1 = self.relu(l1)
41 |         l2 = self.linear_mapping2(l1)
42 |         l2 = self.nrm1(l2)
43 |         l2 = self.relu(l2)
44 |         l3 = self.linear_mapping3(l2)
45 |         l3 = self.nrm2(l3)
46 |         l3 =
self.relu(l3) 47 | 48 | return l3 49 | 50 | 51 | 52 | 53 | 54 | class Discriminator(nn.Module): 55 | def __init__(self, ngpu): 56 | super(Discriminator, self).__init__() 57 | self.ngpu = ngpu 58 | self.ndf = 64 59 | self.nc = 1 60 | self.main = nn.Sequential( 61 | # input is (nc) x 64 x 64 62 | nn.Conv2d(self.nc, self.ndf, 4, 2, 1, bias=False), 63 | nn.LeakyReLU(0.2, inplace=True), 64 | # state size. (self.ndf) x 32 x 32 65 | nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False), 66 | nn.BatchNorm2d(self.ndf * 2), 67 | nn.LeakyReLU(0.2, inplace=True), 68 | # state size. (self.ndf*2) x 16 x 16 69 | nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False), 70 | nn.BatchNorm2d(self.ndf * 4), 71 | nn.LeakyReLU(0.2, inplace=True), 72 | # state size. (self.ndf*4) x 8 x 8 73 | nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False), 74 | nn.BatchNorm2d(self.ndf * 8), 75 | nn.LeakyReLU(0.2, inplace=True), 76 | # state size. (self.ndf*8) x 4 x 4 77 | nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False), 78 | nn.Sigmoid() 79 | ) 80 | 81 | def forward(self, input): 82 | return self.main(input) 83 | 84 | 85 | def weights_init(m): 86 | classname = m.__class__.__name__ 87 | if classname.find('Conv') != -1: 88 | nn.init.normal_(m.weight.data, 0.0, 0.02) 89 | elif classname.find('BatchNorm') != -1: 90 | nn.init.normal_(m.weight.data, 1.0, 0.02) 91 | nn.init.constant_(m.bias.data, 0) 92 | 93 | 94 | def get_dataLoader(batch_size): 95 | transform = data_transforms() 96 | dataset = Omniglot(split="train", transform=transform) 97 | num_train_imgs = len(dataset) 98 | return DataLoader(dataset, 99 | batch_size= batch_size, 100 | shuffle = True, 101 | drop_last=True) 102 | 103 | 104 | def data_transforms(): 105 | SetRange = transforms.Lambda(lambda X: 2 * X - 1.) 106 | SetScale = transforms.Lambda(lambda X: X/X.sum(0).expand_as(X)) 107 | transform = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()]) 108 | return transform 109 | 110 | 111 | def randomSample_from_embeddings(batch_size): 112 | with open("/content/drive/My Drive/vae/logogram-language-generator-master/sample_fasttext_embs.pickle", "rb") as f: 113 | random_sample = pickle.load(f) 114 | 115 | fake_batch = random.sample(list(random_sample.values()), batch_size) 116 | fake_batch = torch.stack(fake_batch, dim=0) 117 | return fake_batch 118 | 119 | 120 | 121 | 122 | 123 | def trainer(vae_model, mapper, netD, batch_size, device): 124 | # Training Loop 125 | # Lists to keep track of progress 126 | img_list = [] 127 | G_losses = [] 128 | D_losses = [] 129 | iters = 0 130 | num_epochs = 5 131 | dataloader = get_dataLoader(batch_size=batch_size) 132 | 133 | # Initialize BCELoss function 134 | criterion = nn.BCELoss() 135 | lr = 0.00001 136 | beta1 = 0.5 137 | # Setup Adam optimizers for both G and D 138 | optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999)) 139 | optimizerG = optim.Adam(mapper.parameters(), lr=lr, betas=(beta1, 0.999)) 140 | 141 | 142 | 143 | # Create batch of latent vectors that we will use to visualize 144 | # the progression of the generator 145 | fixed_noise = randomSample_from_embeddings(batch_size).to(device) 146 | 147 | criterion = nn.BCELoss() 148 | 149 | # Establish convention for real and fake labels during training 150 | real_label = 1. 151 | fake_label = 0. 
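    # A summary of the loop below: the frozen VAE decoder renders mapped
    # embeddings into glyph images, so the embedding mapper plays the role of
    # the generator. D is trained to separate real Omniglot glyphs from decoded
    # fakes, and the mapper is then updated so that its decoded glyphs fool D.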
152 | 153 | 154 | print("Starting Training Loop...") 155 | # For each epoch 156 | for epoch in range(num_epochs): 157 | # For each batch in the dataloader 158 | for i, data in enumerate(dataloader, 0): 159 | 160 | ############################ 161 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 162 | ########################### 163 | ## Train with all-real batch 164 | netD.zero_grad() 165 | # Format batch 166 | real_cpu = data[0].to(device) 167 | b_size = real_cpu.size(0) 168 | label = torch.full((b_size,), real_label, dtype=torch.float, device=device) 169 | # Forward pass real batch through D 170 | output = netD(real_cpu).view(-1) 171 | # Calculate loss on all-real batch 172 | errD_real = criterion(output, label) 173 | # Calculate gradients for D in backward pass 174 | errD_real.backward() 175 | D_x = output.mean().item() 176 | 177 | ## Train with all-fake batch 178 | # Generate batch of latent vectors 179 | embeddings = randomSample_from_embeddings(batch_size).to(device) 180 | # Generate fake image batch with G 181 | fake = mapper(embeddings).to(device) 182 | fake = vae_model.decode(fake) 183 | 184 | label.fill_(fake_label) 185 | # Classify all fake batch with D 186 | output = netD(fake.detach()).view(-1) 187 | # Calculate D's loss on the all-fake batch 188 | errD_fake = criterion(output, label) 189 | # Calculate the gradients for this batch 190 | errD_fake.backward() 191 | D_G_z1 = output.mean().item() 192 | # Add the gradients from the all-real and all-fake batches 193 | errD = errD_real + errD_fake 194 | # Update D 195 | optimizerD.step() 196 | 197 | ############################ 198 | # (2) Update G network: maximize log(D(G(z))) 199 | ########################### 200 | mapper.zero_grad() 201 | label.fill_(real_label) # fake labels are real for generator cost 202 | # Since we just updated D, perform another forward pass of all-fake batch through D 203 | output = netD(fake).view(-1) 204 | # Calculate G's loss based on this output 205 | errG = criterion(output, label) 206 | # Calculate gradients for G 207 | errG.backward() 208 | D_G_z2 = output.mean().item() 209 | # Update G 210 | optimizerG.step() 211 | 212 | # Output training stats 213 | if i % 50 == 0: 214 | print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' 215 | % (epoch, num_epochs, i, len(dataloader), 216 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) 217 | randomsamples = [] 218 | # for i in range(64): 219 | 220 | with torch.no_grad(): 221 | mapped = mapper(torch.randn(64,300).to(device)) 222 | randomsamples.append(mapped.to(device)) 223 | 224 | embed_samples = [] 225 | # for i in range(64): 226 | one_embed = randomSample_from_embeddings(64).to(device) 227 | with torch.no_grad(): 228 | mapped_embed = mapper(one_embed).to(device) 229 | embed_samples.append(mapped_embed) 230 | 231 | 232 | 233 | randomsamples = torch.stack(randomsamples, dim=0) 234 | recons_randoms = vae_model.decode(randomsamples) 235 | 236 | embedsamples = torch.stack(embed_samples, dim=0) 237 | recons_embeds = vae_model.decode(embedsamples) 238 | 239 | 240 | 241 | vutils.save_image(recons_randoms.data, 242 | f"./vae_with_disc/test_random{i}.png", 243 | normalize=True, 244 | nrow=12) 245 | vutils.save_image(recons_embeds.data, 246 | f"./vae_with_disc/test_embed{i}.png", 247 | normalize=True, 248 | nrow=12) 249 | 250 | # Save Losses for plotting later 251 | G_losses.append(errG.item()) 252 | D_losses.append(errD.item()) 253 | 254 | # Check how the generator is doing by saving G's output on fixed_noise 255 | if 
(iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)): 256 | with torch.no_grad(): 257 | fake = mapper(fixed_noise).detach().to(device) 258 | img_list.append(vutils.make_grid(fake, padding=2, normalize=True)) 259 | # vutils.save_image(img_list, 260 | # f"./vae_with_disc/samples_{i}.png", 261 | # normalize=True, 262 | # nrow=12) 263 | 264 | iters += 1 265 | return vae_model, mapper, netD 266 | 267 | def load_vae_model(vae_checkpointPath, config): 268 | model = vae_models[config['model_params']['name']](**config['model_params']) 269 | experiment = VAEXperiment(model, 270 | config['exp_params']) 271 | 272 | 273 | checkpoint = torch.load(vae_checkpointPath, map_location=lambda storage, loc: storage) 274 | new_ckpoint = {} 275 | for k in checkpoint["state_dict"].keys(): 276 | newKey = k.split("model.")[1] 277 | new_ckpoint[newKey] = checkpoint["state_dict"][k] 278 | 279 | model.load_state_dict(new_ckpoint) 280 | model.eval() 281 | return model 282 | 283 | 284 | def main(): 285 | print("on main") 286 | 287 | 288 | with open("./configs/bbvae_setup2.yaml", 'r') as file: 289 | try: 290 | config = yaml.safe_load(file) 291 | except yaml.YAMLError as exc: 292 | print(exc) 293 | 294 | vae_checkpointPath = "logs/BetaVAE_B_setup2_run2/final_model_checkpoint.ckpt" 295 | batch_size = 64 296 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu") 297 | print(device) 298 | 299 | vae_model = load_vae_model(vae_checkpointPath, config).to(device) 300 | 301 | # fixed_noise = randomSample_from_embeddings(batch_size).to(device) 302 | # fixed_noise = torch.randn() 303 | 304 | 305 | 306 | mapper = EmbeddingMapping(device, 300, 128).to(device) 307 | mapper.apply(weights_init) 308 | 309 | netD = Discriminator(device).to(device) 310 | netD.apply(weights_init) 311 | vae_model, mapper, netD = trainer(vae_model=vae_model, mapper=mapper, netD=netD, batch_size=batch_size, device=device) 312 | 313 | with open("/content/drive/My Drive/vae/logogram-language-generator-master/fasttext_hello_world.pickle", "rb") as f: 314 | helloworld = pickle.load(f) 315 | 316 | # helloworld = random.sample(list(helloworld.values()), batch_size) 317 | hello = torch.from_numpy(helloworld["hello"]).to(device) 318 | world = torch.from_numpy(helloworld["world"]).to(device) 319 | 320 | 321 | helloworld = torch.stack([hello, world], dim=0).to(device) 322 | 323 | embed_samples = [] 324 | with torch.no_grad(): 325 | mapped_embed = mapper(helloworld).to(device) 326 | embed_samples.append(mapped_embed) 327 | embedsamples = torch.stack(embed_samples, dim=0) 328 | recons_embeds = vae_model.decode(embedsamples) 329 | vutils.save_image(recons_embeds.data, 330 | "./vae_with_disc/helloWorld.png", 331 | normalize=True, 332 | nrow=12) 333 | 334 | 335 | 336 | 337 | print("ALL DONE!") 338 | 339 | 340 | 341 | 342 | 343 | if __name__ == "__main__": 344 | main() 345 | -------------------------------------------------------------------------------- /src/experiment.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import optim 4 | from models import BaseVAE 5 | from models.types_ import * 6 | from utils import data_loader 7 | import pytorch_lightning as pl 8 | from torchvision import transforms 9 | import torchvision.utils as vutils 10 | from torchvision.datasets import CelebA 11 | from torch.utils.data import DataLoader 12 | from Omniglot import Omniglot 13 | 14 | 15 | class VAEXperiment(pl.LightningModule): 16 | 17 | def __init__(self, 18 | 
vae_model: BaseVAE, 19 | params: dict) -> None: 20 | super(VAEXperiment, self).__init__() 21 | 22 | self.model = vae_model 23 | self.params = params 24 | self.curr_device = None 25 | self.hold_graph = False 26 | try: 27 | self.hold_graph = self.params['retain_first_backpass'] 28 | except: 29 | pass 30 | 31 | def forward(self, input: Tensor, **kwargs) -> Tensor: 32 | return self.model(input, **kwargs) 33 | 34 | def training_step(self, batch, batch_idx, optimizer_idx = 0): 35 | real_img, labels = batch 36 | self.curr_device = real_img.device 37 | 38 | results = self.forward(real_img, labels = labels) 39 | train_loss = self.model.loss_function(*results, 40 | M_N = self.params['batch_size']/ self.num_train_imgs, 41 | optimizer_idx=optimizer_idx, 42 | batch_idx = batch_idx) 43 | 44 | self.logger.experiment.log({key: val.item() for key, val in train_loss.items()}) 45 | 46 | return train_loss 47 | 48 | def validation_step(self, batch, batch_idx, optimizer_idx = 0): 49 | real_img, labels = batch 50 | self.curr_device = real_img.device 51 | 52 | results = self.forward(real_img, labels = labels) 53 | val_loss = self.model.loss_function(*results, 54 | M_N = self.params['batch_size']/ self.num_val_imgs, 55 | optimizer_idx = optimizer_idx, 56 | batch_idx = batch_idx) 57 | 58 | return val_loss 59 | 60 | def validation_end(self, outputs): 61 | avg_loss = torch.stack([x['loss'] for x in outputs]).mean() 62 | tensorboard_logs = {'avg_val_loss': avg_loss} 63 | self.sample_images() 64 | ############################### 65 | ############################### 66 | return {'val_loss': avg_loss, 'log': tensorboard_logs} 67 | 68 | def sample_images(self): 69 | # Get sample reconstruction image 70 | test_input, test_label = next(iter(self.sample_dataloader)) 71 | test_input = test_input.to(self.curr_device) 72 | test_label = test_label.to(self.curr_device) 73 | recons = self.model.generate(test_input, labels = test_label) 74 | vutils.save_image(recons.data, 75 | f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/" 76 | f"recons_{self.logger.name}_{self.current_epoch}.png", 77 | normalize=True, 78 | nrow=12) 79 | 80 | # vutils.save_image(test_input.data, 81 | # f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/" 82 | # f"real_img_{self.logger.name}_{self.current_epoch}.png", 83 | # normalize=True, 84 | # nrow=12) 85 | 86 | try: 87 | samples = self.model.sample(144, 88 | self.curr_device, 89 | labels = test_label) 90 | vutils.save_image(samples.cpu().data, 91 | f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/" 92 | f"{self.logger.name}_{self.current_epoch}.png", 93 | normalize=True, 94 | nrow=12) 95 | except: 96 | pass 97 | 98 | 99 | del test_input, recons #, samples 100 | 101 | 102 | def configure_optimizers(self): 103 | 104 | optims = [] 105 | scheds = [] 106 | 107 | optimizer = optim.Adam(self.model.parameters(), 108 | lr=self.params['LR'], 109 | weight_decay=self.params['weight_decay']) 110 | optims.append(optimizer) 111 | # Check if more than 1 optimizer is required (Used for adversarial training) 112 | try: 113 | if self.params['LR_2'] is not None: 114 | optimizer2 = optim.Adam(getattr(self.model,self.params['submodel']).parameters(), 115 | lr=self.params['LR_2']) 116 | optims.append(optimizer2) 117 | except: 118 | pass 119 | 120 | try: 121 | if self.params['scheduler_gamma'] is not None: 122 | scheduler = optim.lr_scheduler.ExponentialLR(optims[0], 123 | gamma = self.params['scheduler_gamma']) 124 | scheds.append(scheduler) 125 | 126 | # Check 
if another scheduler is required for the second optimizer 127 | try: 128 | if self.params['scheduler_gamma_2'] is not None: 129 | scheduler2 = optim.lr_scheduler.ExponentialLR(optims[1], 130 | gamma = self.params['scheduler_gamma_2']) 131 | scheds.append(scheduler2) 132 | except: 133 | pass 134 | return optims, scheds 135 | except: 136 | return optims 137 | 138 | @data_loader 139 | def train_dataloader(self): 140 | transform = self.data_transforms() 141 | 142 | if self.params['dataset'] == 'celeba': 143 | dataset = CelebA(root = self.params['data_path'], 144 | split = "train", 145 | transform=transform, 146 | download=False) 147 | elif self.params['dataset'] == 'omniglot': 148 | dataset = Omniglot(split="train", transform=transform) 149 | else: 150 | raise ValueError('Undefined dataset type') 151 | 152 | self.num_train_imgs = len(dataset) 153 | return DataLoader(dataset, 154 | batch_size= self.params['batch_size'], 155 | shuffle = True, 156 | drop_last=True) 157 | 158 | @data_loader 159 | def val_dataloader(self): 160 | transform = self.data_transforms() 161 | 162 | if self.params['dataset'] == 'celeba': 163 | self.sample_dataloader = DataLoader(CelebA(root = self.params['data_path'], 164 | split = "test", 165 | transform=transform, 166 | download=False), 167 | batch_size= 144, 168 | shuffle = True, 169 | drop_last=True) 170 | self.num_val_imgs = len(self.sample_dataloader) 171 | elif self.params['dataset'] == 'omniglot': 172 | self.sample_dataloader = DataLoader(Omniglot(split="test", transform=transform), batch_size=144, shuffle=True, drop_last=True) # batch_size may change 173 | self.num_val_imgs = len(self.sample_dataloader) 174 | else: 175 | raise ValueError('Undefined dataset type') 176 | 177 | return self.sample_dataloader 178 | 179 | def data_transforms(self): 180 | 181 | SetRange = transforms.Lambda(lambda X: 2 * X - 1.) 
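        # SetRange rescales ToTensor's [0,1] output to [-1,1] and is applied
        # only in the celeba branch below; SetScale is not referenced in either
        # branch.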
182 |         SetScale = transforms.Lambda(lambda X: X/X.sum(0).expand_as(X))
183 | 
184 |         if self.params['dataset'] == 'celeba':
185 |             transform = transforms.Compose([transforms.RandomHorizontalFlip(),
186 |                                             transforms.CenterCrop(148),
187 |                                             transforms.Resize(self.params['img_size']),
188 |                                             transforms.ToTensor(),
189 |                                             SetRange])
190 |         elif self.params['dataset'] == 'omniglot':
191 |             #todo
192 |             transform = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()])
193 |         else:
194 |             raise ValueError('Undefined dataset type')
195 |         return transform
196 | 
197 | 
--------------------------------------------------------------------------------
/src/logogram_language_generator.py:
--------------------------------------------------------------------------------
1 | from Omniglot import Omniglot
2 | import yaml
3 | import argparse
4 | import numpy as np
5 | 
6 | from models import *
7 | from experiment import VAEXperiment
8 | import torch.backends.cudnn as cudnn
9 | from pytorch_lightning import Trainer
10 | from pytorch_lightning.logging import TestTubeLogger
11 | import torchvision.utils as vutils
12 | from torch.utils.data import DataLoader
13 | import torch.optim as optim
14 | from torchvision import transforms
15 | import imageio
16 | from scipy.interpolate import interp1d
17 | import pickle
18 | import random
19 | import fasttext
20 | from torchvision.transforms.functional import *
21 | from PIL import Image, ImageFont, ImageDraw
22 | from MapperLayer import umwe2vae
23 | 
24 | from MapperLayer import EmbeddingMapping
25 | from MapperLayer import MultilingualMapper
26 | import matplotlib.pyplot as plt
27 | import math
28 | import torch  # explicit import; torch is otherwise only exposed through the star imports above
29 | parser = argparse.ArgumentParser(description='Logogram Language Generator Main script')
30 | parser.add_argument('--embd_random_samples', help = 'path to the embedding samples', default='./randomSamples_fromEmbeddings/')
31 | parser.add_argument("--vae_ckp_path", type=str, default="./model_checkpoints/final_model_checkpoint.ckpt", help="checkpoint path of vae")
32 | parser.add_argument("--config_path", type=str, default="./configs/bbvae_CompleteRun.yaml", help="path to the vae config yaml")
33 | parser.add_argument("--export_path", type=str, default="./outputs/", help="directory the generated outputs are written to")
34 | 
35 | # parser.add_argument("--test_onHelloWorld", type=bool, default=True, help="")
36 | # parser.add_argument("--emb_vector_dim", type=int, default=300, help="")
37 | # parser.add_argument("--vae_latent_dim", type=int, default=128, help="")
38 | # parser.add_argument("--mapper_numlayer", type=int, default=3, help="")
39 | 
40 | parser.add_argument("--umwe_mappers_path", type=str, default="./model_checkpoints/", help="directory with the trained umwe mappings (best_mapping_*2en.t7)")
41 | parser.add_argument("--words_path", type=str, default="./words.txt", help='input word list, one "word - lang" pair per line')
42 | 
43 | # "standard" for statistical normalization, "mapper" for the linear layered mapper trained with a discriminator
44 | parser.add_argument("--norm_option", type=str, default="standard", help='"standard", "mapper-disc" or "mapper-lossy"')
45 | parser.add_argument("--mapper_layer_ckp", type=str, default="./lossy_mapper/umwe2vae.ckpt", help="checkpoint of the mapper layer (used with the mapper-* norm options)")
46 | 
47 | 
48 | parser.add_argument("--emb_bins_path", type=str, default="./model_checkpoints/cc_bins_128/", help="directory with the 128-dim fastText .bin models")
49 | 
50 | parser.add_argument("--gif", type=str, default="True", help='"True"/"False": whether to also export gifs')
51 | 
52 | parser.add_argument("--siamese_mapper", type=str, default="False", help='"True" to use the mapper trained with the siamese network')
53 | 
54 | 
55 | args = parser.parse_args()
56 | with open(args.words_path, "r") as f:
57 |     args.words = f.readlines()
58 | for word in range(len(args.words)):
59 |     args.words[word] =
args.words[word].split("\n")[0] 60 | 61 | args.latent_dim = 128 62 | args.emb_vector_dim = 128 63 | 64 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu") 65 | 66 | 67 | def randomSample_from_embeddings(batch_size, lang): 68 | with open(args.embd_random_samples+"sample_fasttext_embs_"+lang+"_128.pickle", "rb") as f: 69 | random_sample = pickle.load(f) 70 | if batch_size == -1: 71 | batch_size = len(list(random_sample.values())) 72 | fake_batch = random.sample(list(random_sample.values()), batch_size) 73 | fake_batch = torch.stack(fake_batch, dim=0) 74 | return fake_batch 75 | 76 | def load_vae_model(vae_checkpointPath): 77 | with open(args.config_path, 'r') as file: 78 | try: 79 | config = yaml.safe_load(file) 80 | except yaml.YAMLError as exc: 81 | print(exc) 82 | 83 | model = vae_models[config['model_params']['name']](**config['model_params']) 84 | experiment = VAEXperiment(model, 85 | config['exp_params']) 86 | 87 | 88 | checkpoint = torch.load(vae_checkpointPath, map_location=lambda storage, loc: storage) 89 | new_ckpoint = {} 90 | for k in checkpoint["state_dict"].keys(): 91 | newKey = k.split("model.")[1] 92 | new_ckpoint[newKey] = checkpoint["state_dict"][k] 93 | 94 | model.load_state_dict(new_ckpoint) 95 | model.eval() 96 | return model 97 | 98 | def load_mapper(vae_model, mapper_type="mapper-disc"): 99 | if mapper_type == "mapper-disc": 100 | model = EmbeddingMapping(device, args.emb_vector_dim, args.latent_dim) 101 | model.load_state_dict(torch.load(args.mapper_layer_ckp)) 102 | model.eval() 103 | elif mapper_type == "mapper-lossy": 104 | model = umwe2vae(vae_model, in_dim=128, out_dim=128) 105 | model.load_state_dict(torch.load(args.mapper_layer_ckp)) 106 | model.eval() 107 | elif mapper_type == "siamese_mapper": 108 | model = MultilingualMapper(device, args.emb_vector_dim, args.latent_dim) 109 | model.load_state_dict(torch.load("./layerWithSiamese/multilingualMapper_checkpoint.pt")) 110 | model.eval() 111 | return model.to(device) 112 | 113 | def embed_words(): 114 | embs = [] 115 | lang = "" 116 | for word in args.words: 117 | if word.split(" - ")[1] != lang: 118 | lang = word.split(" - ")[1] 119 | ft = fasttext.load_model(args.emb_bins_path+"cc."+lang+".128.bin") 120 | ft.get_dimension() 121 | emb = torch.from_numpy(ft.get_word_vector(word.split(" - ")[0])).to(device) 122 | if lang != "en": 123 | multilingual_mapper = torch.from_numpy(torch.load(args.umwe_mappers_path+"best_mapping_"+lang+"2en.t7")).to(device) 124 | emb = torch.matmul(emb, multilingual_mapper) 125 | embs.append(emb) 126 | embs = torch.stack(embs, dim=0) 127 | return embs 128 | # numpy to tensor word embs 129 | # torch.from_numpy(helloworld["hello"]).to(device) 130 | 131 | # best_mapping = "./best_mapping_fr2en.t7" 132 | # Readed_t7 = torch.from_numpy(torch.load(best_mapping)).to(device) 133 | # mapped_hello = torch.matmul(hello, Readed_t7) 134 | # mapped_world = torch.matmul(world, Readed_t7) 135 | 136 | 137 | def normalize_embeddings(embeddings, vae_model): 138 | normalized_embs = [] 139 | if args.norm_option == "standard": 140 | for word in range(embeddings.shape[0]): 141 | sample_embeds = randomSample_from_embeddings(-1, "en").to(device) 142 | std, mean = torch.std_mean(input=sample_embeds, dim=0, unbiased=True) 143 | normalized_embs.append((embeddings[word,:] - mean) / std) 144 | normalized_embs = torch.stack(normalized_embs, dim=0) 145 | 146 | elif args.norm_option == "mapper-disc" or args.norm_option == "mapper-lossy": 147 | mapperLayer = load_mapper(vae_model, args.norm_option) 
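        # With the "mapper-disc" / "mapper-lossy" options a trained mapping network,
        # rather than per-dimension standardization, turns the fastText embedding into
        # a latent code the VAE decoder can render.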
148 |         with torch.no_grad():
149 |             normalized_embs = mapperLayer(embeddings).to(device)
150 |     elif args.norm_option == "none":
151 |         normalized_embs = embeddings
152 | 
153 |     return normalized_embs
154 | 
155 | def interpolate_me_baby(emb1,emb2,n):
156 |     """
157 |     Linearly interpolates between two tensors of the same shape.
158 |     n: number of interpolation steps; use a large value for a smooth transition.
159 |     """
160 | 
161 |     shapes = emb1.shape
162 |     emb1_flattened_n = emb1.flatten().cpu().numpy()
163 |     emb2_flattened_n = emb2.flatten().cpu().numpy()
164 |     f = interp1d(x=[0,1], y=np.vstack([emb1_flattened_n,emb2_flattened_n]),axis=0)
165 |     y = f(np.linspace(0, 1, n))
166 |     L = [torch.reshape(torch.from_numpy(step).float(), shapes).to(device) for step in y]
167 |     return torch.stack(L, dim=0) # return shape [n, latent_dim]
168 | 
169 | def add_text_to_image(tensir, word): # tensir: a 2-D image tensor
170 |     width, height = tensir.shape
171 |     img_pil = transforms.functional.to_pil_image(tensir) #.resize((width,height + 30),0) # would add room below the image for the caption
172 |     draw = ImageDraw.Draw(img_pil)
173 |     fontsize = 40
174 |     font = ImageFont.truetype("./TimesNewRoman400.ttf", size=fontsize) # the font has to be provided as a downloaded TrueType (.ttf) file
175 |     wordSize = font.getsize(word)
176 |     #draw.text(xy = (width//3,height+15), text = word, fill = "white",font = font, anchor = "ms") # anchor-based placement needs a configured font; the manual placement below centers the caption near the bottom instead
177 |     draw.text(xy = (int(width/2)-int(wordSize[0]/2),height-wordSize[1]-5), text = word, fill = "black", font = font)
178 |     processed_tensir = to_tensor(img_pil)
179 |     return torch.reshape(processed_tensir, [width,height])
180 | 
181 | def make_transparent(tensor_img):
182 |     img = transforms.functional.to_pil_image(tensor_img)
183 |     img = img.convert("RGBA")
184 |     datas = img.getdata()
185 | 
186 |     newData = []
187 |     for item in datas:
188 |         if item[0] == 255 and item[1] == 255 and item[2] == 255: # turn pure-white pixels fully transparent
189 |             newData.append((255, 255, 255, 0))
190 |         else:
191 |             newData.append(item)
192 | 
193 |     img.putdata(newData)
194 |     return to_tensor(img)
195 | 
196 | 
197 | def savegrid(ims, rows=None, cols=None, fill=True, showax=False):
198 |     if (rows is None) != (cols is None):
199 |         raise ValueError("Set either both rows and cols or neither.")
200 | 
201 |     if rows is None:
202 |         rows = len(ims)
203 |         cols = 1
204 | 
205 |     gridspec_kw = {'wspace': 0, 'hspace': 0} if fill else {}
206 |     fig,axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw)
207 | 
208 |     if fill:
209 |         bleed = 0
210 |         fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))
211 | 
212 |     for ax,im in zip(axarr.ravel(), ims):
213 |         ax.imshow(im, cmap='gray')
214 |         if not showax:
215 |             ax.set_axis_off()
216 | 
217 |     kwargs = {'pad_inches': .01} if fill else {}
218 |     fig.savefig(args.export_path+'logograms.png', **kwargs)
219 | 
220 | def post_proccess(imgs): # upscale, boost contrast, then binarize the decoder output
221 |     imgs = list(map(lambda x:resize(to_pil_image(x),256), imgs))
222 |     imgs = list(map(lambda x:adjust_saturation(x,1.5), imgs))
223 |     imgs = list(map(lambda x:adjust_gamma(x,1.5), imgs))
224 |     imgs = list(map(lambda x:adjust_contrast(x,2), imgs))
225 |     imgs = list(map(lambda x:to_tensor(x), imgs))
226 |     for x in imgs: # push bright pixels to pure white...
227 |         x[x>0.5]=1
228 |     for x in imgs: # ...and dark pixels to pure black
229 |         x[x<0.4]=0
230 | 
231 |     #save_image(imgs, 'img3.png')
232 |     return imgs[0]
233 | 
234 | def main():
235 |     print("started...")
236 | 
237 |     vae_model = load_vae_model(args.vae_ckp_path).to(device)
238 | 
239 |     word_embeddings = 
embed_words() 240 | 241 | normalized_embeddings = normalize_embeddings(word_embeddings, vae_model) 242 | 243 | if args.siamese_mapper == "True": 244 | siam_mapper = load_mapper(vae_model, mapper_type="siamese_mapper") 245 | with torch.no_grad(): 246 | normalized_embeddings = siam_mapper(normalized_embeddings).to(device) 247 | normalized_embeddings = normalize_embeddings(normalized_embeddings, vae_model) 248 | 249 | 250 | if args.gif == "True": 251 | n = 5 252 | interpolations = [] 253 | for word in range(normalized_embeddings.shape[0]-1): 254 | w1 = normalized_embeddings[word,:128] 255 | w2 = normalized_embeddings[word+1,:128] 256 | 257 | intplt = interpolate_me_baby(w1, w2, n) 258 | interpolations.append(intplt) 259 | interpolations = torch.cat(interpolations, dim=0).to(device) 260 | 261 | images = [] 262 | for word in range(interpolations.shape[0]): 263 | if word % n == 0: 264 | word_duration = 10 265 | textWord = args.words[int(word/n)] #.split(" - ")[0] 266 | else: 267 | word_duration = 1 268 | textWord = "" 269 | 270 | for duration in range(word_duration): 271 | logogram = vae_model.decode(interpolations[word,:128].float()) 272 | # logogram = torch.cat([torch.zeros(1,1,64,64).float(), logogram.data.cpu()], dim=1).cpu() 273 | # images.append(logogram) 274 | processed = post_proccess(torch.reshape(logogram.data, [1,64,64]).cpu()) 275 | processed = torch.reshape(processed, [256,256]) 276 | processed = add_text_to_image(processed, textWord) 277 | images.append(processed) 278 | imageio.mimwrite(args.export_path+"helloToWorld.gif", images) 279 | 280 | else: 281 | images = [] 282 | for word in range(normalized_embeddings.shape[0]): 283 | logogram = vae_model.decode(normalized_embeddings[word,:128].float()) 284 | processed = post_proccess(torch.reshape(logogram.data, [1,64,64]).cpu()) 285 | processed = torch.reshape(processed, [256,256]) 286 | processed = add_text_to_image(processed, args.words[word]) 287 | images.append(processed) 288 | savegrid(images, cols=int(math.sqrt(len(args.words))), rows=int(len(args.words)/int(math.sqrt(len(args.words))))) 289 | 290 | 291 | 292 | 293 | # vutils.save_image(logogram.data, 294 | # args.export_path+args.words[word]+".png", 295 | # normalize=True, 296 | # nrow=12) 297 | 298 | 299 | 300 | print("ALL DONE!") 301 | 302 | 303 | # python logogram_language_generator.py --embd_random_samples ./randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle --vae_ckp_path ./model_checkpoints/final_model_checkpoint.ckpt --config_path ./configs/bbvae_CompleteRun.yaml --export_path ./outputs/ --umwe_mappers_path ./model_checkpoints/ --words_path words.txt --norm_option standard --mapper_layer_ckp ./vae_with_disc/mapper_checkpoint.pt --emb_bins_path ./model_checkpoints/cc_bins_128/ --gif False 304 | 305 | 306 | if __name__ == "__main__": 307 | main() 308 | -------------------------------------------------------------------------------- /src/lossyMapper_train.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import numpy as np 4 | 5 | from models import * 6 | from experiment import VAEXperiment 7 | import torch.backends.cudnn as cudnn 8 | from pytorch_lightning import Trainer 9 | from pytorch_lightning.logging import TestTubeLogger 10 | 11 | from torchvision.utils import save_image 12 | from umwe2vae import umwe2vae 13 | 14 | 15 | parser = argparse.ArgumentParser(description='Generic runner for VAE models') 16 | parser.add_argument('--config', '-c', 17 | dest="filename", 18 | metavar='FILE', 
19 | help = 'path to the config file', 20 | default='configs/vae.yaml') 21 | 22 | args = parser.parse_args() 23 | 24 | 25 | def randomSample_from_embeddings(batch_size): 26 | with open(args.embd_random_samples, "rb") as f: 27 | random_sample = pickle.load(f) 28 | 29 | fake_batch = random.sample(list(random_sample.values()), batch_size) 30 | fake_batch = torch.stack(fake_batch, dim=0) 31 | return fake_batch 32 | 33 | 34 | with open(args.filename, 'r') as file: 35 | try: 36 | config = yaml.safe_load(file) 37 | except yaml.YAMLError as exc: 38 | print(exc) 39 | 40 | model = vae_models[config['model_params']['name']](**config['model_params']) 41 | 42 | checkpoint = torch.load("./model_checkpoints/final_model_checkpoint.ckpt") 43 | state_dict = {} 44 | for k in checkpoint['state_dict'].keys(): 45 | state_dict[k[6:]] = checkpoint['state_dict'][k] 46 | 47 | model.load_state_dict(state_dict) 48 | 49 | ldr = [torch.randn((64,128)) for _ in range(10)] 50 | 51 | inp = torch.randn((5,128)) 52 | u2v = umwe2vae(model, 128, 128) 53 | save_image([u2v(inp)[0],u2v(inp)[1],u2v(inp)[2],u2v(inp)[3],u2v(inp)[4]], 'img1.png') 54 | u2v.train(ldr) 55 | 56 | save_image([u2v(inp)[0],u2v(inp)[1],u2v(inp)[2],u2v(inp)[3],u2v(inp)[4]], 'img2.png') 57 | 58 | checkpoint = torch.load("./umwe2vae.ckpt") 59 | u2v.load_state_dict(checkpoint['model_state_dict']) 60 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .vanilla_vae import * 3 | from .gamma_vae import * 4 | from .beta_vae import * 5 | from .wae_mmd import * 6 | from .cvae import * 7 | from .hvae import * 8 | from .vampvae import * 9 | from .iwae import * 10 | from .dfcvae import * 11 | from .mssim_vae import MSSIMVAE 12 | from .fvae import * 13 | from .cat_vae import * 14 | from .joint_vae import * 15 | from .info_vae import * 16 | # from .twostage_vae import * 17 | from .lvae import LVAE 18 | from .logcosh_vae import * 19 | from .swae import * 20 | from .miwae import * 21 | from .vq_vae import * 22 | from .betatc_vae import * 23 | from .dip_vae import * 24 | 25 | 26 | # Aliases 27 | VAE = VanillaVAE 28 | GaussianVAE = VanillaVAE 29 | CVAE = ConditionalVAE 30 | GumbelVAE = CategoricalVAE 31 | 32 | vae_models = {'HVAE':HVAE, 33 | 'LVAE':LVAE, 34 | 'IWAE':IWAE, 35 | 'SWAE':SWAE, 36 | 'MIWAE':MIWAE, 37 | 'VQVAE':VQVAE, 38 | 'DFCVAE':DFCVAE, 39 | 'DIPVAE':DIPVAE, 40 | 'BetaVAE':BetaVAE, 41 | 'InfoVAE':InfoVAE, 42 | 'WAE_MMD':WAE_MMD, 43 | 'VampVAE': VampVAE, 44 | 'GammaVAE':GammaVAE, 45 | 'MSSIMVAE':MSSIMVAE, 46 | 'JointVAE':JointVAE, 47 | 'BetaTCVAE':BetaTCVAE, 48 | 'FactorVAE':FactorVAE, 49 | 'LogCoshVAE':LogCoshVAE, 50 | 'VanillaVAE':VanillaVAE, 51 | 'ConditionalVAE':ConditionalVAE, 52 | 'CategoricalVAE':CategoricalVAE} 53 | -------------------------------------------------------------------------------- /src/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/base.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/base.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/beta_vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/beta_vae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/betatc_vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/betatc_vae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/cat_vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/cat_vae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/cvae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/cvae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/dfcvae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/dfcvae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/dip_vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/dip_vae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/fvae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/fvae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/gamma_vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/gamma_vae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/hvae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/hvae.cpython-36.pyc 
-------------------------------------------------------------------------------- /src/models/__pycache__/info_vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/info_vae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/iwae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/iwae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/joint_vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/joint_vae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/logcosh_vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/logcosh_vae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/lvae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/lvae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/miwae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/miwae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/mssim_vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/mssim_vae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/swae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/swae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/types_.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/types_.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/vampvae.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/vampvae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/vanilla_vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/vanilla_vae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/vq_vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/vq_vae.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/wae_mmd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/wae_mmd.cpython-36.pyc -------------------------------------------------------------------------------- /src/models/base.py: -------------------------------------------------------------------------------- 1 | from .types_ import * 2 | from torch import nn 3 | from abc import abstractmethod 4 | 5 | class BaseVAE(nn.Module): 6 | 7 | def __init__(self) -> None: 8 | super(BaseVAE, self).__init__() 9 | 10 | def encode(self, input: Tensor) -> List[Tensor]: 11 | raise NotImplementedError 12 | 13 | def decode(self, input: Tensor) -> Any: 14 | raise NotImplementedError 15 | 16 | def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor: 17 | raise RuntimeWarning() 18 | 19 | def generate(self, x: Tensor, **kwargs) -> Tensor: 20 | raise NotImplementedError 21 | 22 | @abstractmethod 23 | def forward(self, *inputs: Tensor) -> Tensor: 24 | pass 25 | 26 | @abstractmethod 27 | def loss_function(self, *inputs: Any, **kwargs) -> Tensor: 28 | pass 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /src/models/beta_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class BetaVAE(BaseVAE): 9 | 10 | num_iter = 0 # Global static variable to keep track of iterations 11 | 12 | def __init__(self, 13 | in_channels: int, 14 | latent_dim: int, 15 | hidden_dims: List = None, 16 | beta: int = 4, 17 | gamma:float = 1000., 18 | max_capacity: int = 25, 19 | Capacity_max_iter: int = 1e5, 20 | loss_type:str = 'B', 21 | **kwargs) -> None: 22 | super(BetaVAE, self).__init__() 23 | 24 | self.latent_dim = latent_dim 25 | self.beta = beta 26 | self.gamma = gamma 27 | self.loss_type = loss_type 28 | self.C_max = torch.Tensor([max_capacity]) 29 | self.C_stop_iter = Capacity_max_iter 30 | 31 | modules = [] 32 | if hidden_dims is None: 33 | hidden_dims = [32, 64, 128, 256, 512] 34 | 35 | # Build Encoder 36 | for h_dim in hidden_dims: 37 | modules.append( 38 | nn.Sequential( 39 | nn.Conv2d(in_channels, out_channels=h_dim, 40 | kernel_size= 3, stride= 2, padding = 1), 41 | nn.BatchNorm2d(h_dim), 42 | nn.LeakyReLU()) 43 | ) 44 | in_channels = h_dim 45 | 46 
| self.encoder = nn.Sequential(*modules)
 47 |         self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
 48 |         self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
 49 | 
 50 | 
 51 |         # Build Decoder
 52 |         modules = []
 53 | 
 54 |         self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
 55 | 
 56 |         hidden_dims.reverse()
 57 | 
 58 |         for i in range(len(hidden_dims) - 1):
 59 |             modules.append(
 60 |                 nn.Sequential(
 61 |                     nn.ConvTranspose2d(hidden_dims[i],
 62 |                                        hidden_dims[i + 1],
 63 |                                        kernel_size=3,
 64 |                                        stride = 2,
 65 |                                        padding=1,
 66 |                                        output_padding=1),
 67 |                     nn.BatchNorm2d(hidden_dims[i + 1]),
 68 |                     nn.LeakyReLU())
 69 |             )
 70 | 
 71 | 
 72 | 
 73 |         self.decoder = nn.Sequential(*modules)
 74 | 
 75 |         self.final_layer = nn.Sequential(
 76 |                             nn.ConvTranspose2d(hidden_dims[-1],
 77 |                                                hidden_dims[-1],
 78 |                                                kernel_size=3,
 79 |                                                stride=2,
 80 |                                                padding=1,
 81 |                                                output_padding=1),
 82 |                             nn.BatchNorm2d(hidden_dims[-1]),
 83 |                             nn.LeakyReLU(),
 84 |                             nn.Conv2d(hidden_dims[-1], out_channels= 1,
 85 |                                       kernel_size= 3, padding= 1),
 86 |                             nn.Tanh())
 87 | 
 88 |     def encode(self, input: Tensor) -> List[Tensor]:
 89 |         """
 90 |         Encodes the input by passing through the encoder network
 91 |         and returns the latent codes.
 92 |         :param input: (Tensor) Input tensor to encoder [N x C x H x W]
 93 |         :return: (Tensor) List of latent codes
 94 |         """
 95 |         result = self.encoder(input)
 96 |         result = torch.flatten(result, start_dim=1)
 97 | 
 98 |         # Split the result into mu and var components
 99 |         # of the latent Gaussian distribution
100 |         mu = self.fc_mu(result)
101 |         log_var = self.fc_var(result)
102 | 
103 |         return [mu, log_var]
104 | 
105 |     def decode(self, z: Tensor) -> Tensor:
106 |         # map the latent code back to the decoder's 512 x 2 x 2 input feature map
107 |         result = self.decoder_input(z)
108 |         result = result.view(-1, 512, 2, 2)
109 |         result = self.decoder(result)
110 |         result = self.final_layer(result)
111 |         return result
112 | 
113 |     def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
114 |         """
115 |         Will a single z be enough to compute the expectation
116 |         for the loss??
117 |         :param mu: (Tensor) Mean of the latent Gaussian
118 |         :param logvar: (Tensor) Log variance of the latent Gaussian
119 |         :return:
120 |         """
121 |         std = torch.exp(0.5 * logvar)
122 |         eps = torch.randn_like(std)
123 |         return eps * std + mu
124 | 
125 |     def forward(self, input: Tensor, **kwargs) -> Tensor:
126 |         mu, log_var = self.encode(input)
127 |         z = self.reparameterize(mu, log_var)
128 |         return [self.decode(z), input, mu, log_var]
129 | 
130 |     def loss_function(self,
131 |                       *args,
132 |                       **kwargs) -> dict:
133 |         self.num_iter += 1
134 |         recons = args[0]
135 |         input = args[1]
136 |         mu = args[2]
137 |         log_var = args[3]
138 |         kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
139 | 
140 |         recons_loss = F.mse_loss(recons, input)
141 | 
142 |         kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
143 | 
144 |         if self.loss_type == 'H': # https://openreview.net/forum?id=Sy2fzU9gl
145 |             loss = recons_loss + self.beta * kld_weight * kld_loss
146 |         elif self.loss_type == 'B': # https://arxiv.org/pdf/1804.03599.pdf
147 |             self.C_max = self.C_max.to(input.device)
148 |             C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
149 |             loss = recons_loss + self.gamma * kld_weight * (kld_loss - C).abs()
150 |         else:
151 |             raise ValueError('Undefined loss type.')
152 | 
153 |         return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}
154 | 
155 |     def sample(self,
156 |                num_samples:int,
157 |                current_device: int, **kwargs) -> Tensor:
158 |         """
159 |         Samples from the latent space and return the corresponding
160 |         image space map.
161 |         :param num_samples: (Int) Number of samples
162 |         :param current_device: (Int) Device to run the model
163 |         :return: (Tensor)
164 |         """
165 |         z = torch.randn(num_samples,
166 |                         self.latent_dim)
167 | 
168 |         z = z.to(current_device)
169 | 
170 |         samples = self.decode(z)
171 |         return samples
172 | 
173 |     def generate(self, x: Tensor, **kwargs) -> Tensor:
174 |         """
175 |         Given an input image x, returns the reconstructed image
176 |         :param x: (Tensor) [B x C x H x W]
177 |         :return: (Tensor) [B x C x H x W]
178 |         """
179 | 
180 |         return self.forward(x)[0]
--------------------------------------------------------------------------------
/src/randomSamples_fromEmbeddings/sample_fasttext_embs_en.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_en.pickle
--------------------------------------------------------------------------------
/src/randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle
--------------------------------------------------------------------------------
/src/randomSamples_fromEmbeddings/sample_fasttext_embs_es_128.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_es_128.pickle
--------------------------------------------------------------------------------
/src/randomSamples_fromEmbeddings/sample_fasttext_embs_fr.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_fr.pickle -------------------------------------------------------------------------------- /src/randomSamples_fromEmbeddings/sample_fasttext_embs_fr_128.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_fr_128.pickle -------------------------------------------------------------------------------- /src/randomSamples_fromEmbeddings/sample_fasttext_embs_it_128.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_it_128.pickle -------------------------------------------------------------------------------- /src/randomSamples_fromEmbeddings/sample_from_embeds.py: -------------------------------------------------------------------------------- 1 | import fasttext 2 | import random 3 | import torch 4 | import pickle 5 | 6 | langs = ["fr", "it", "es", "en"] 7 | 8 | for lang in langs: 9 | ft = fasttext.load_model("../model_checkpoints/cc_bins_128/cc."+lang+".128.bin") 10 | ft.get_dimension() 11 | size = 10000 12 | random_words = random.sample(ft.get_words(), size) 13 | samples = {} 14 | for i in range(size): 15 | samples[random_words[i]] = torch.from_numpy(ft.get_word_vector(random_words[i])) 16 | filename = "sample_fasttext_embs_"+lang+"_128.pickle" 17 | outfile = open(filename,'wb') 18 | 19 | pickle.dump(samples,outfile) 20 | outfile.close() 21 | 22 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==0.6.0 2 | PyYAML==5.1.2 3 | tensorboard==2.1.0 4 | tensorboardX==1.6 5 | terminado==0.8.1 6 | test-tube==0.7.0 7 | torch==1.4.0 8 | torchfile==0.1.0 9 | torchnet==0.0.4 10 | torchsummary==1.5.1 11 | torchvision==0.5.0 12 | 13 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import numpy as np 4 | 5 | from models import * 6 | from experiment import VAEXperiment 7 | import torch.backends.cudnn as cudnn 8 | from pytorch_lightning import Trainer 9 | from pytorch_lightning.logging import TestTubeLogger 10 | 11 | 12 | parser = argparse.ArgumentParser(description='Generic runner for VAE models') 13 | parser.add_argument('--config', '-c', 14 | dest="filename", 15 | metavar='FILE', 16 | help = 'path to the config file', 17 | default='configs/vae.yaml') 18 | 19 | args = parser.parse_args() 20 | with open(args.filename, 'r') as file: 21 | try: 22 | config = yaml.safe_load(file) 23 | except yaml.YAMLError as exc: 24 | print(exc) 25 | 26 | 27 | tt_logger = TestTubeLogger( 28 | save_dir=config['logging_params']['save_dir'], 29 | name=config['logging_params']['name'], 30 | debug=False, 31 | create_git_tag=False, 32 | ) 33 | 34 | # For reproducibility 35 | 
torch.manual_seed(config['logging_params']['manual_seed']) 36 | np.random.seed(config['logging_params']['manual_seed']) 37 | cudnn.deterministic = True 38 | cudnn.benchmark = False 39 | 40 | model = vae_models[config['model_params']['name']](**config['model_params']) 41 | experiment = VAEXperiment(model, 42 | config['exp_params']) 43 | # print(f"{tt_logger.save_dir}"+f"{tt_logger.name}"+"/final_model_checkpoint.ckpt") 44 | print(f"{tt_logger.name}") 45 | 46 | runner = Trainer(default_save_path=f"{tt_logger.save_dir}", 47 | min_nb_epochs=1, 48 | logger=tt_logger, 49 | log_save_interval=100, 50 | train_percent_check=1., 51 | val_percent_check=1., 52 | num_sanity_val_steps=5, 53 | early_stop_callback = False, 54 | **config['trainer_params']) 55 | 56 | print(f"======= Training {config['model_params']['name']} =======") 57 | runner.fit(experiment) 58 | runner.save_checkpoint(f"{tt_logger.save_dir}"+f"{tt_logger.name}"+"/final_model_checkpoint.ckpt") 59 | 60 | 61 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | print("hello from nodejs with python") 2 | 3 | for i in range(1000): 4 | print(i) 5 | -------------------------------------------------------------------------------- /src/train_mapper_and_discriminator.py: -------------------------------------------------------------------------------- 1 | from Omniglot import Omniglot 2 | import yaml 3 | import argparse 4 | import numpy as np 5 | 6 | from models import * 7 | from experiment import VAEXperiment 8 | import torch.backends.cudnn as cudnn 9 | from pytorch_lightning import Trainer 10 | from pytorch_lightning.logging import TestTubeLogger 11 | import torchvision.utils as vutils 12 | from torch.utils.data import DataLoader 13 | import torch.optim as optim 14 | from torchvision import transforms 15 | 16 | import pickle 17 | import random 18 | 19 | import argparse 20 | 21 | # python train_mapper_and_discriminator.py --embd_random_samples ./randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle --vae_ckp_path ./model_checkpoints/final_model_checkpoint.ckpt --config_path ./configs/bbvae_CompleteRun.yaml 22 | 23 | #python logogram_language_generator.py --embd_random_samples ./randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle 24 | #--vae_ckp_path ./model_checkpoints/final_model_checkpoint.ckpt --config_path ./configs/bbvae_CompleteRun.yaml 25 | # --export_path ./outputs/ --words_path words.txt --norm_option standard 26 | # --emb_bins_path ./model_checkpoints/cc_bins_128/ 27 | 28 | parser = argparse.ArgumentParser(description='Mapper training for feeding embeddings to VAE') 29 | parser.add_argument('--embd_random_samples', help = 'path to the embedding samples', default='/content/drive/My Drive/vae/logogram-language-generator-master/sample_fasttext_embs.pickle') 30 | parser.add_argument("--vae_ckp_path", type=str, default="logs/BetaVAE_B_setup2_run2/final_model_checkpoint.ckpt", help="checkpoint path of vae") 31 | parser.add_argument("--config_path", type=str, default="./configs/bbvae_setup2.yaml", help="config") 32 | parser.add_argument("--export_path", type=str, default="./vae_with_disc/", help="export") 33 | parser.add_argument("--test_onHelloWorld", type=bool, default=False, help="") 34 | parser.add_argument("--emb_vector_dim", type=int, default=300, help="") 35 | parser.add_argument("--vae_latent_dim", type=int, default=128, help="") 36 | parser.add_argument("--mapper_numlayer", type=int, 
default=3, help="") 37 | 38 | 39 | args = parser.parse_args() 40 | 41 | 42 | class EmbeddingMapping(nn.Module): 43 | def __init__(self, device, embedding_vector_dim = 300, decoder_input_dim=128): 44 | super(EmbeddingMapping, self).__init__() 45 | self.device = device 46 | self.embedding_vector_dim = embedding_vector_dim 47 | self.decoder_input_dim = decoder_input_dim 48 | self.mapper_numlayer = args.mapper_numlayer 49 | 50 | self.linear_layers = [] 51 | self.batch_norms = [] 52 | for layer in range(0, self.mapper_numlayer-1): 53 | self.linear_layers.append(nn.Linear(embedding_vector_dim, embedding_vector_dim)) 54 | self.batch_norms.append(nn.BatchNorm1d(embedding_vector_dim)) 55 | 56 | # final layer 57 | self.linear_layers.append(nn.Linear(embedding_vector_dim, decoder_input_dim)) 58 | self.batch_norms.append(nn.BatchNorm1d(decoder_input_dim)) 59 | 60 | 61 | self.linear_layers = nn.ModuleList(self.linear_layers) 62 | self.batch_norms = nn.ModuleList(self.batch_norms) 63 | 64 | 65 | self.relu = nn.ReLU() 66 | 67 | def forward(self, embedding_vector): 68 | inp = embedding_vector 69 | for layer in range(self.mapper_numlayer): 70 | out = self.linear_layers[layer](inp) 71 | out = self.batch_norms[layer](out) 72 | out = self.relu(out) 73 | inp = out 74 | return out 75 | 76 | class Discriminator(nn.Module): 77 | def __init__(self, ngpu): 78 | super(Discriminator, self).__init__() 79 | self.ngpu = ngpu 80 | self.ndf = 64 81 | self.nc = 1 82 | self.main = nn.Sequential( 83 | # input is (nc) x 64 x 64 84 | nn.Conv2d(self.nc, self.ndf, 4, 2, 1, bias=False), 85 | nn.LeakyReLU(0.2, inplace=True), 86 | # state size. (self.ndf) x 32 x 32 87 | nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False), 88 | nn.BatchNorm2d(self.ndf * 2), 89 | nn.LeakyReLU(0.2, inplace=True), 90 | # state size. (self.ndf*2) x 16 x 16 91 | nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False), 92 | nn.BatchNorm2d(self.ndf * 4), 93 | nn.LeakyReLU(0.2, inplace=True), 94 | # state size. (self.ndf*4) x 8 x 8 95 | nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False), 96 | nn.BatchNorm2d(self.ndf * 8), 97 | nn.LeakyReLU(0.2, inplace=True), 98 | # state size. (self.ndf*8) x 4 x 4 99 | nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False), 100 | nn.Sigmoid() 101 | ) 102 | 103 | def forward(self, input): 104 | return self.main(input) 105 | 106 | 107 | def weights_init(m): 108 | classname = m.__class__.__name__ 109 | if classname.find('Conv') != -1: 110 | nn.init.normal_(m.weight.data, 0.0, 0.02) 111 | elif classname.find('BatchNorm') != -1: 112 | nn.init.normal_(m.weight.data, 1.0, 0.02) 113 | nn.init.constant_(m.bias.data, 0) 114 | 115 | 116 | def get_dataLoader(batch_size): 117 | transform = data_transforms() 118 | dataset = Omniglot(split="train", transform=transform) 119 | num_train_imgs = len(dataset) 120 | return DataLoader(dataset, 121 | batch_size= batch_size, 122 | shuffle = True, 123 | drop_last=True) 124 | 125 | 126 | def data_transforms(): 127 | SetRange = transforms.Lambda(lambda X: 2 * X - 1.) 
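    # SetRange / SetScale mirror the lambdas in experiment.py's data_transforms, but the
    # Omniglot transform built below only resizes and converts to a [0, 1] tensor.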
128 | SetScale = transforms.Lambda(lambda X: X/X.sum(0).expand_as(X)) 129 | transform = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()]) 130 | return transform 131 | 132 | 133 | def randomSample_from_embeddings(batch_size): 134 | with open(args.embd_random_samples, "rb") as f: 135 | random_sample = pickle.load(f) 136 | 137 | fake_batch = random.sample(list(random_sample.values()), batch_size) 138 | fake_batch = torch.stack(fake_batch, dim=0) 139 | return fake_batch 140 | 141 | 142 | 143 | 144 | 145 | def trainer(vae_model, mapper, netD, batch_size, device): 146 | # Training Loop 147 | # Lists to keep track of progress 148 | img_list = [] 149 | G_losses = [] 150 | D_losses = [] 151 | iters = 0 152 | num_epochs = 10 153 | dataloader = get_dataLoader(batch_size=batch_size) 154 | 155 | # Initialize BCELoss function 156 | criterion = nn.BCELoss() 157 | lr = 0.00001 158 | beta1 = 0.5 159 | # Setup Adam optimizers for both G and D 160 | optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999)) 161 | optimizerG = optim.Adam(mapper.parameters(), lr=lr, betas=(beta1, 0.999)) 162 | 163 | 164 | 165 | # Create batch of latent vectors that we will use to visualize 166 | # the progression of the generator 167 | fixed_noise = randomSample_from_embeddings(batch_size).to(device) 168 | 169 | criterion = nn.BCELoss() 170 | 171 | # Establish convention for real and fake labels during training 172 | real_label = 1. 173 | fake_label = 0. 174 | 175 | 176 | print("Starting Training Loop...") 177 | # For each epoch 178 | for epoch in range(num_epochs): 179 | # For each batch in the dataloader 180 | for i, data in enumerate(dataloader, 0): 181 | 182 | ############################ 183 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 184 | ########################### 185 | ## Train with all-real batch 186 | netD.zero_grad() 187 | # Format batch 188 | real_cpu = data[0].to(device) 189 | b_size = real_cpu.size(0) 190 | label = torch.full((b_size,), real_label, dtype=torch.float, device=device) 191 | # Forward pass real batch through D 192 | output = netD(real_cpu).view(-1) 193 | # Calculate loss on all-real batch 194 | errD_real = criterion(output, label) 195 | # Calculate gradients for D in backward pass 196 | errD_real.backward() 197 | D_x = output.mean().item() 198 | 199 | ## Train with all-fake batch 200 | # Generate batch of latent vectors 201 | embeddings = randomSample_from_embeddings(batch_size).to(device) 202 | # Generate fake image batch with G 203 | fake = mapper(embeddings).to(device) 204 | fake = vae_model.decode(fake) 205 | 206 | label.fill_(fake_label) 207 | # Classify all fake batch with D 208 | output = netD(fake.detach()).view(-1) 209 | # Calculate D's loss on the all-fake batch 210 | errD_fake = criterion(output, label) 211 | # Calculate the gradients for this batch 212 | errD_fake.backward() 213 | D_G_z1 = output.mean().item() 214 | # Add the gradients from the all-real and all-fake batches 215 | errD = errD_real + errD_fake 216 | # Update D 217 | optimizerD.step() 218 | 219 | ############################ 220 | # (2) Update G network: maximize log(D(G(z))) 221 | ########################### 222 | mapper.zero_grad() 223 | label.fill_(real_label) # fake labels are real for generator cost 224 | # Since we just updated D, perform another forward pass of all-fake batch through D 225 | output = netD(fake).view(-1) 226 | # Calculate G's loss based on this output 227 | errG = criterion(output, label) 228 | # Calculate gradients for G 229 | errG.backward() 
230 | D_G_z2 = output.mean().item() 231 | # Update G 232 | optimizerG.step() 233 | 234 | # Output training stats 235 | if i % 50 == 0: 236 | print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' 237 | % (epoch, num_epochs, i, len(dataloader), 238 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) 239 | 240 | # Save Losses for plotting later 241 | G_losses.append(errG.item()) 242 | D_losses.append(errD.item()) 243 | 244 | # Check how the generator is doing by saving G's output on fixed_noise 245 | if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)): 246 | randomsamples = [] 247 | with torch.no_grad(): 248 | mapped = mapper(torch.randn(batch_size,args.emb_vector_dim).to(device)) 249 | randomsamples.append(mapped.to(device)) 250 | 251 | embed_samples = [] 252 | one_embed = randomSample_from_embeddings(batch_size).to(device) 253 | with torch.no_grad(): 254 | mapped_embed = mapper(one_embed).to(device) 255 | embed_samples.append(mapped_embed) 256 | 257 | randomsamples = torch.stack(randomsamples, dim=0) 258 | recons_randoms = vae_model.decode(randomsamples) 259 | 260 | embedsamples = torch.stack(embed_samples, dim=0) 261 | recons_embeds = vae_model.decode(embedsamples) 262 | 263 | vutils.save_image(recons_randoms.data, 264 | f"{args.export_path}test_random{i}.png", 265 | normalize=True, 266 | nrow=12) 267 | vutils.save_image(recons_embeds.data, 268 | f"{args.export_path}test_embed{i}.png", 269 | normalize=True, 270 | nrow=12) 271 | 272 | iters += 1 273 | return vae_model, mapper, netD 274 | 275 | def load_vae_model(vae_checkpointPath, config): 276 | model = vae_models[config['model_params']['name']](**config['model_params']) 277 | experiment = VAEXperiment(model, 278 | config['exp_params']) 279 | 280 | 281 | checkpoint = torch.load(vae_checkpointPath, map_location=lambda storage, loc: storage) 282 | new_ckpoint = {} 283 | for k in checkpoint["state_dict"].keys(): 284 | newKey = k.split("model.")[1] 285 | new_ckpoint[newKey] = checkpoint["state_dict"][k] 286 | 287 | model.load_state_dict(new_ckpoint) 288 | model.eval() 289 | return model 290 | 291 | 292 | def main(): 293 | print("on main") 294 | with open(args.config_path, 'r') as file: 295 | try: 296 | config = yaml.safe_load(file) 297 | except yaml.YAMLError as exc: 298 | print(exc) 299 | 300 | batch_size = 64 301 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu") 302 | print(device) 303 | 304 | vae_model = load_vae_model(args.vae_ckp_path, config).to(device) 305 | 306 | mapper = EmbeddingMapping(device, args.emb_vector_dim, args.vae_latent_dim).to(device) 307 | mapper.apply(weights_init) 308 | 309 | netD = Discriminator(device).to(device) 310 | netD.apply(weights_init) 311 | vae_model, mapper, netD = trainer(vae_model=vae_model, mapper=mapper, netD=netD, batch_size=batch_size, device=device) 312 | 313 | torch.save(mapper.state_dict(), args.export_path+"mapper_checkpoint.pt") 314 | torch.save(netD.state_dict(), args.export_path+"discriminator_checkpoint.pt") 315 | 316 | 317 | 318 | if args.test_onHelloWorld: 319 | with open("/content/drive/My Drive/vae/logogram-language-generator-master/fasttext_hello_world.pickle", "rb") as f: 320 | helloworld = pickle.load(f) 321 | 322 | hello = torch.from_numpy(helloworld["hello"]).to(device) 323 | world = torch.from_numpy(helloworld["world"]).to(device) 324 | 325 | 326 | helloworld = torch.stack([hello, world], dim=0).to(device) 327 | 328 | embed_samples = [] 329 | with torch.no_grad(): 330 | mapped_embed = 
mapper(helloworld).to(device)
331 |             embed_samples.append(mapped_embed)
332 |         embedsamples = torch.stack(embed_samples, dim=0)
333 |         recons_embeds = vae_model.decode(embedsamples)
334 |         vutils.save_image(recons_embeds.data,
335 |                           args.export_path+"helloWorld.png",
336 |                           normalize=True,
337 |                           nrow=12)
338 | 
339 | 
340 | 
341 | 
342 |     print("ALL DONE!")
343 | 
344 | 
345 | 
346 | 
347 | 
348 | if __name__ == "__main__":
349 |     main()
350 | 
--------------------------------------------------------------------------------
/src/train_siamese.py:
--------------------------------------------------------------------------------
 1 | from Omniglot_triplet import Omniglot_triplet
 2 | from torch.utils.data import DataLoader
 3 | from torchvision.utils import save_image
 4 | from Siamese import Siamese
 5 | import torch
 6 | 
 7 | #omni = Omniglot_triplet()
 8 | #dloader = DataLoader(omni,batch_size=16,shuffle=True)
 9 | 
10 | siam = Siamese().cuda()
11 | #siam.train_triplet(dloader)
12 | siam.load_state_dict(torch.load("siamese.ckpt"))
13 | siam.eval()
14 | omnieval = Omniglot_triplet(split="test")
15 | dloader = DataLoader(omnieval,batch_size=16,shuffle=True)
16 | siam.eval_triplet(dloader)
--------------------------------------------------------------------------------
/src/umwe2vae.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from torch import nn
 3 | import torch.optim as optim
 4 | from torchvision.transforms.functional import adjust_contrast
 5 | 
 6 | class umwe2vae(nn.Module):
 7 |     def __init__(self, vae_model, in_dim=300, out_dim=128):
 8 |         super(umwe2vae, self).__init__()
 9 |         self.vae_model = vae_model
10 |         #self.fc = nn.Linear(in_dim, out_dim)
11 |         self.mapper_numlayer = 3
12 | 
13 |         self.linear_layers = []
14 |         self.batch_norms = []
15 |         for layer in range(0, self.mapper_numlayer-1):
16 |             self.linear_layers.append(nn.Linear(in_dim, in_dim))
17 |             self.batch_norms.append(nn.BatchNorm1d(in_dim))
18 | 
19 |         # final layer
20 |         self.linear_layers.append(nn.Linear(in_dim, out_dim))
21 |         self.batch_norms.append(nn.BatchNorm1d(out_dim))
22 | 
23 | 
24 |         self.linear_layers = nn.ModuleList(self.linear_layers)
25 |         self.batch_norms = nn.ModuleList(self.batch_norms)
26 | 
27 | 
28 |         self.relu = nn.ReLU()
29 | 
30 | 
31 | 
32 | 
33 |     def forward(self, x):
34 |         inp = x
35 |         for layer in range(self.mapper_numlayer):
36 |             out = self.linear_layers[layer](inp)
37 |             out = self.batch_norms[layer](out)
38 |             out = self.relu(out)
39 |             inp = out
40 |         return out
41 | 
42 |         # h = self.fc(x)
43 |         # y = self.vae_model.decode(h)
44 |         # return h
45 |         # here used to live post-processing
46 |         #out = torch.zeros(y.shape)
47 |         #for i in range(y.shape[0]):
48 |         #    out[i] = adjust_contrast(y[i], contrast_factor=2.5)
49 |         #return out
50 | 
51 |     def loss(self, x, alpha=1, beta=1): # regularizer used to train the mapper through the frozen decoder
52 |         middle = x[:,:,1:-1,1:-1] # interior pixels of the decoded image batch
53 |         ne = x[:,:,0:-2,0:-2]     # the eight views below are the same region shifted by one pixel
54 |         n = x[:,:,0:-2,1:-1]
55 |         nw = x[:,:,0:-2,2:]
56 |         e = x[:,:,1:-1,0:-2]
57 |         w = x[:,:,1:-1,2:]
58 |         se = x[:,:,2:,0:-2]
59 |         s = x[:,:,2:,1:-1]
60 |         sw = x[:,:,2:,2:]
61 | 
62 |         return alpha * torch.mean(sum([torch.abs(middle-ne),torch.abs(middle-n),torch.abs(middle-nw),torch.abs(middle-e),torch.abs(middle-w),torch.abs(middle-se),torch.abs(middle-s),torch.abs(middle-sw)]) / 8.) 
- beta * torch.mean(torch.abs(x-0.5)) # 8-neighbor smoothness term, minus a term pushing pixels away from mid-gray (toward black/white)
63 | 
64 |     def train(self, loader, lr=0.001, epochs=5): # note: this shadows nn.Module.train(mode)
65 |         optimizer = optim.Adam(self.parameters(), lr=lr)
66 | 
67 |         for epoch in range(epochs):
68 |             total_loss = 0
69 |             for _,inp in enumerate(loader):
70 |                 optimizer.zero_grad()
71 | 
72 |                 out = self.forward(inp)
73 |                 out = self.vae_model.decode(out)
74 |                 loss = self.loss(out)
75 | 
76 |                 loss.backward()
77 |                 optimizer.step()
78 | 
79 |                 total_loss += loss.item() # .item() keeps the running total from retaining the autograd graph
80 | 
81 |             print("Epoch: %d Loss: %f" % (epoch+1, total_loss/len(loader)))
82 | 
83 | 
84 |         torch.save(self.state_dict(), "umwe2vae.ckpt")
85 | 
86 | 
--------------------------------------------------------------------------------
/src/umwe_mappers/best_mapping_es2en.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/umwe_mappers/best_mapping_es2en.t7
--------------------------------------------------------------------------------
/src/umwe_mappers/best_mapping_fr2en.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/umwe_mappers/best_mapping_fr2en.t7
--------------------------------------------------------------------------------
/src/umwe_mappers/best_mapping_it2en.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/umwe_mappers/best_mapping_it2en.t7
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
 1 | import pytorch_lightning as pl
 2 | 
 3 | 
 4 | ## Utils to handle newer PyTorch Lightning changes from version 0.6
 5 | ## ==================================================================================================== ##
 6 | 
 7 | 
 8 | def data_loader(fn):
 9 |     """
10 |     Decorator to handle the deprecation of data_loader from 0.7
11 |     :param fn: User defined data loader function
12 |     :return: A wrapper for the data_loader function
13 |     """
14 | 
15 |     def func_wrapper(self):
16 |         try: # Works for version 0.6.0
17 |             return pl.data_loader(fn)(self)
18 | 
19 |         except: # Works for version > 0.6.0
20 |             return fn(self)
21 | 
22 |     return func_wrapper
--------------------------------------------------------------------------------
/src/vae_with_norm.py:
--------------------------------------------------------------------------------
 1 | from Omniglot import Omniglot
 2 | import yaml
 3 | import argparse
 4 | import numpy as np
 5 | 
 6 | from models import *
 7 | from experiment import VAEXperiment
 8 | import torch.backends.cudnn as cudnn
 9 | from pytorch_lightning import Trainer
10 | from pytorch_lightning.logging import TestTubeLogger
11 | import torchvision.utils as vutils
12 | from torch.utils.data import DataLoader
13 | import torch.optim as optim
14 | from torchvision import transforms
15 | 
16 | import pickle
17 | import random
18 | 
19 | # import fasttext
20 | # import fasttext.util
21 | 
22 | # NOTE: trainer() below is a leftover copy of the GAN loop in
23 | # train_mapper_and_discriminator.py; it is never called from main() and it
24 | # references names (num_epochs, criterion, the optimizers, ...) that this file never defines.
25 | def get_dataLoader(batch_size):
26 |     transform = data_transforms()
27 |     dataset = Omniglot(split="train", transform=transform)
28 |     num_train_imgs = len(dataset)
29 |     return DataLoader(dataset,
30 |                       batch_size= batch_size,
31 |                       shuffle = True,
32 |                       drop_last=True)
33 | 
34 | 
35 | def 
data_transforms(): 36 | SetRange = transforms.Lambda(lambda X: 2 * X - 1.) 37 | SetScale = transforms.Lambda(lambda X: X/X.sum(0).expand_as(X)) 38 | transform = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()]) 39 | return transform 40 | 41 | 42 | def randomSample_from_embeddings(batch_size, fr=""): 43 | with open("/content/drive/My Drive/vae/logogram-language-generator-master/sample_fasttext_embs"+fr+".pickle", "rb") as f: 44 | random_sample = pickle.load(f) 45 | 46 | fake_batch = random.sample(list(random_sample.values()), batch_size) 47 | fake_batch = torch.stack(fake_batch, dim=0) 48 | return fake_batch 49 | 50 | 51 | 52 | 53 | 54 | def trainer(vae_model, mapper, netD, batch_size, device): 55 | # Training Loop 56 | # Lists to keep track of progress 57 | img_list = [] 58 | dataloader = get_dataLoader(batch_size=batch_size) 59 | 60 | # Create batch of latent vectors that we will use to visualize 61 | # the progression of the generator 62 | fixed_noise = randomSample_from_embeddings(batch_size).to(device) 63 | 64 | 65 | 66 | 67 | print("Starting Training Loop...") 68 | # For each epoch 69 | for epoch in range(num_epochs): 70 | # For each batch in the dataloader 71 | for i, data in enumerate(dataloader, 0): 72 | 73 | ############################ 74 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 75 | ########################### 76 | ## Train with all-real batch 77 | netD.zero_grad() 78 | # Format batch 79 | real_cpu = data[0].to(device) 80 | b_size = real_cpu.size(0) 81 | label = torch.full((b_size,), real_label, dtype=torch.float, device=device) 82 | # Forward pass real batch through D 83 | output = netD(real_cpu).view(-1) 84 | # Calculate loss on all-real batch 85 | errD_real = criterion(output, label) 86 | # Calculate gradients for D in backward pass 87 | errD_real.backward() 88 | D_x = output.mean().item() 89 | 90 | ## Train with all-fake batch 91 | # Generate batch of latent vectors 92 | embeddings = randomSample_from_embeddings(batch_size).to(device) 93 | # Generate fake image batch with G 94 | fake = mapper(embeddings).to(device) 95 | fake = vae_model.decode(fake) 96 | 97 | label.fill_(fake_label) 98 | # Classify all fake batch with D 99 | output = netD(fake.detach()).view(-1) 100 | # Calculate D's loss on the all-fake batch 101 | errD_fake = criterion(output, label) 102 | # Calculate the gradients for this batch 103 | errD_fake.backward() 104 | D_G_z1 = output.mean().item() 105 | # Add the gradients from the all-real and all-fake batches 106 | errD = errD_real + errD_fake 107 | # Update D 108 | optimizerD.step() 109 | 110 | ############################ 111 | # (2) Update G network: maximize log(D(G(z))) 112 | ########################### 113 | mapper.zero_grad() 114 | label.fill_(real_label) # fake labels are real for generator cost 115 | # Since we just updated D, perform another forward pass of all-fake batch through D 116 | output = netD(fake).view(-1) 117 | # Calculate G's loss based on this output 118 | errG = criterion(output, label) 119 | # Calculate gradients for G 120 | errG.backward() 121 | D_G_z2 = output.mean().item() 122 | # Update G 123 | optimizerG.step() 124 | 125 | # Output training stats 126 | if i % 50 == 0: 127 | print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' 128 | % (epoch, num_epochs, i, len(dataloader), 129 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) 130 | randomsamples = [] 131 | # for i in range(64): 132 | 133 | with torch.no_grad(): 134 | mapped = 
mapper(torch.randn(64,300).to(device)) 135 | randomsamples.append(mapped.to(device)) 136 | 137 | embed_samples = [] 138 | # for i in range(64): 139 | one_embed = randomSample_from_embeddings(64).to(device) 140 | with torch.no_grad(): 141 | mapped_embed = mapper(one_embed).to(device) 142 | embed_samples.append(mapped_embed) 143 | 144 | 145 | 146 | randomsamples = torch.stack(randomsamples, dim=0) 147 | recons_randoms = vae_model.decode(randomsamples) 148 | 149 | embedsamples = torch.stack(embed_samples, dim=0) 150 | recons_embeds = vae_model.decode(embedsamples) 151 | 152 | 153 | 154 | vutils.save_image(recons_randoms.data, 155 | f"./vae_with_disc/test_random{i}.png", 156 | normalize=True, 157 | nrow=12) 158 | vutils.save_image(recons_embeds.data, 159 | f"./vae_with_disc/test_embed{i}.png", 160 | normalize=True, 161 | nrow=12) 162 | 163 | # Save Losses for plotting later 164 | G_losses.append(errG.item()) 165 | D_losses.append(errD.item()) 166 | 167 | # Check how the generator is doing by saving G's output on fixed_noise 168 | if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)): 169 | with torch.no_grad(): 170 | fake = mapper(fixed_noise).detach().to(device) 171 | img_list.append(vutils.make_grid(fake, padding=2, normalize=True)) 172 | # vutils.save_image(img_list, 173 | # f"./vae_with_disc/samples_{i}.png", 174 | # normalize=True, 175 | # nrow=12) 176 | 177 | iters += 1 178 | return vae_model, mapper, netD 179 | 180 | def load_vae_model(vae_checkpointPath, config): 181 | model = vae_models[config['model_params']['name']](**config['model_params']) 182 | experiment = VAEXperiment(model, 183 | config['exp_params']) 184 | 185 | 186 | checkpoint = torch.load(vae_checkpointPath, map_location=lambda storage, loc: storage) 187 | new_ckpoint = {} 188 | for k in checkpoint["state_dict"].keys(): 189 | newKey = k.split("model.")[1] 190 | new_ckpoint[newKey] = checkpoint["state_dict"][k] 191 | 192 | model.load_state_dict(new_ckpoint) 193 | model.eval() 194 | return model 195 | 196 | 197 | def main(): 198 | print("on main") 199 | 200 | 201 | with open("./configs/bbvae_setup2.yaml", 'r') as file: 202 | try: 203 | config = yaml.safe_load(file) 204 | except yaml.YAMLError as exc: 205 | print(exc) 206 | 207 | vae_checkpointPath = "logs/BetaVAE_B_setup2_run2/final_model_checkpoint.ckpt" 208 | batch_size = 64 209 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu") 210 | print(device) 211 | 212 | vae_model = load_vae_model(vae_checkpointPath, config).to(device) 213 | 214 | # fixed_noise = randomSample_from_embeddings(batch_size).to(device) 215 | # fixed_noise = torch.randn() 216 | 217 | best_mapping = "./best_mapping_fr2en.t7" 218 | Readed_t7 = torch.from_numpy(torch.load(best_mapping)).to(device) 219 | 220 | with open("/content/drive/My Drive/vae/logogram-language-generator-master/fasttext_hello_world_fr.pickle", "rb") as f: 221 | helloworld = pickle.load(f) 222 | 223 | # helloworld = random.sample(list(helloworld.values()), batch_size) 224 | hello = torch.from_numpy(helloworld["hello"]).to(device) 225 | world = torch.from_numpy(helloworld["world"]).to(device) 226 | # helloworld = torch.stack([hello, world], dim=0).to(device) 227 | 228 | print(type(hello)) 229 | print(type(Readed_t7)) 230 | mapped_hello = torch.matmul(hello, Readed_t7) 231 | mapped_world = torch.matmul(world, Readed_t7) 232 | 233 | 234 | 235 | 236 | sample_embeds = randomSample_from_embeddings(batch_size, "_fr").to(device) 237 | 238 | std, mean = torch.std_mean(input=sample_embeds, dim=0, 
unbiased=True)
239 |     # norm_func = transforms.Normalize(mean, std)
240 |     ## norm here
241 |     # normalized = norm_func(hello)
242 | 
243 |     normalized_hello = (mapped_hello - mean) / std
244 |     normalized_world = (mapped_world - mean) / std
245 | 
246 |     recons_embeds = vae_model.decode(normalized_hello[:128])  # keep only the first 128 dims to match the VAE latent size
247 |     vutils.save_image(recons_embeds.data,
248 |                       "./vae_with_disc/hello_normalized2_fr.png",
249 |                       normalize=True,
250 |                       nrow=12)
251 |     recons_embeds = vae_model.decode(normalized_world[:128])
252 |     vutils.save_image(recons_embeds.data,
253 |                       "./vae_with_disc/World_normalized2_fr.png",
254 |                       normalize=True,
255 |                       nrow=12)
256 | 
257 | 
258 | 
259 |     print("ALL DONE!")
260 | 
261 | 
262 | 
263 | 
264 | 
265 | if __name__ == "__main__":
266 |     main()
--------------------------------------------------------------------------------
/umwe/README.md:
--------------------------------------------------------------------------------
1 | ## Unsupervised Multilingual Word Embeddings
2 | This repo contains the source code for our paper:
3 | 
4 | [**Unsupervised Multilingual Word Embeddings**](http://aclweb.org/anthology/D18-1024)
5 | 
6 | [Xilun Chen](http://www.cs.cornell.edu/~xlchen/), 7 | [Claire Cardie](http://www.cs.cornell.edu/home/cardie/) 8 |
9 | EMNLP 2018 10 |
11 | [paper](http://aclweb.org/anthology/D18-1024),
12 | [bibtex](http://aclweb.org/anthology/D18-1024.bib)
13 | 
14 | ![Highlight](http://www.cs.cornell.edu/~xlchen/assets/images/umwe_highlight.png)
15 | 
16 | ## Dependencies
17 | * Python 3.6 with [NumPy](http://www.numpy.org/)/[SciPy](https://www.scipy.org/)
18 | * [PyTorch](http://pytorch.org/) 0.4
19 | * [Faiss](https://github.com/facebookresearch/faiss) (recommended) for fast nearest neighbor search (CPU or GPU).
20 | 
21 | Faiss is *optional* for GPU users - though Faiss-GPU will greatly speed up nearest neighbor search - and *highly recommended* for CPU users. Faiss can be installed using "conda install faiss-cpu -c pytorch" or "conda install faiss-gpu -c pytorch".
22 | 
23 | ## Get evaluation datasets
24 | The evaluation datasets from [MUSE](https://github.com/facebookresearch/MUSE) can be downloaded by simply running (in data/):
25 | 
26 | ```bash
27 | ./get_evaluation.sh
28 | ```
29 | *Note: Requires bash 4. The download of Europarl is disabled by default (slow), you can enable it [here](https://github.com/facebookresearch/MUSE/blob/master/data/get_evaluation.sh#L99-L100).*
30 | 
31 | ## Get monolingual word embeddings
32 | For pre-trained monolingual word embeddings, we adopt the same [fastText Wikipedia embeddings](https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md) recommended by the MUSE authors.
33 | 
34 | For example, you can download the English (en) embeddings this way:
35 | ```bash
36 | # English fastText Wikipedia embeddings
37 | curl -Lo data/fasttext-vectors/wiki.en.vec https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.vec
38 | ```
39 | 
40 | ## Learn unsupervised multilingual word embeddings
41 | ```bash
42 | python unsupervised.py --src_langs de fr es it pt --tgt_lang en
43 | ```
44 | 
45 | ## Learn supervised multilingual word embeddings
46 | This was not explored in the paper, but our MPSR method can in principle be applied to the supervised setting as well.
47 | No adversarial training will take place, and the Procrustes method in MUSE is replaced by our MPSR algorithm to learn multilingual embeddings.
48 | For example, the following command can be used to train (weakly) supervised UMWEs with identical character strings serving as supervision.
49 | ```bash
50 | python supervised.py --src_langs de fr es it pt --tgt_lang en --dico_train identical_char
51 | ```
52 | 
53 | ## Evaluate cross-lingual embeddings (CPU|GPU)
54 | To save time, not all language pairs are evaluated during training.
55 | You may run the evaluation script to perform a full evaluation.
56 | The test results are saved in 'evaluate.log' in the model folder.
57 | ```bash
58 | python evaluate.py --src_langs de fr es it pt --tgt_lang en --eval_pairs all --exp_id [your exp_id]
59 | ```
60 | 
61 | ## Word embedding format
62 | By default, the aligned embeddings are exported to a text format at the end of experiments: `--export txt`. Exporting embeddings to a text file can take a while if you have a lot of embeddings. For a very fast export, you can set `--export pth` to export the embeddings in a PyTorch binary file, or simply disable the export (`--export ""`).
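For instance, here is a minimal sketch of reading a text export back into NumPy. `load_txt_embeddings` is a hypothetical helper, not part of this repo, and the exact file path depends on your experiment folder; `.pth` exports can instead be loaded directly with `torch.load`:

```python
import numpy as np

def load_txt_embeddings(path, max_vocab=200000):
    """Read whitespace-separated 'word v1 ... vd' lines into (words, matrix)."""
    words, vecs = [], []
    with open(path, "r", encoding="utf-8") as f:
        first = f.readline().rstrip().split(" ")
        if len(first) != 2:  # no 'count dim' header line: the first line is already a vector
            words.append(first[0])
            vecs.append([float(x) for x in first[1:]])
        for line in f:
            if len(words) >= max_vocab:
                break
            if not line.strip():
                continue
            tokens = line.rstrip().split(" ")
            words.append(tokens[0])
            vecs.append([float(x) for x in tokens[1:]])
    return words, np.asarray(vecs, dtype=np.float32)

words, emb = load_txt_embeddings("debug/my_exp_id/vectors-en.txt")  # illustrative path
```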
63 | 64 | When loading embeddings, the model can load: 65 | * PyTorch binary files previously generated by UMWE (.pth files) 66 | * fastText binary files previously generated by fastText (.bin files) 67 | * text files (text file with one word embedding per line) 68 | 69 | The two first options are very fast and can load 1 million embeddings in a few seconds, while loading text files can take a while. 70 | 71 | ## License 72 | 73 | This work is developed based on [MUSE](https://github.com/facebookresearch/MUSE) by Alexis Conneau, Guillaume Lample, et al. from Facebook Inc., used under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/). 74 | 75 | This work is licensed under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) by Xilun Chen. 76 | -------------------------------------------------------------------------------- /umwe/data/get_evaluation.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | en_analogy='https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/word2vec/source-archive.zip' 9 | aws_path='https://s3.amazonaws.com/arrival' 10 | semeval_2017='http://alt.qcri.org/semeval2017/task2/data/uploads' 11 | europarl='http://www.statmt.org/europarl/v7' 12 | 13 | declare -A wordsim_lg 14 | wordsim_lg=(["en"]="EN_MC-30.txt EN_MTurk-287.txt EN_RG-65.txt EN_VERB-143.txt EN_WS-353-REL.txt EN_YP-130.txt EN_MEN-TR-3k.txt EN_MTurk-771.txt EN_RW-STANFORD.txt EN_SIMLEX-999.txt EN_WS-353-ALL.txt EN_WS-353-SIM.txt" ["es"]="ES_MC-30.txt ES_RG-65.txt ES_WS-353.txt" ["de"]="DE_GUR350.txt DE_GUR65.txt DE_SIMLEX-999.txt DE_WS-353.txt DE_ZG222.txt" ["fr"]="FR_RG-65.txt" ["it"]="IT_SIMLEX-999.txt IT_WS-353.txt") 15 | 16 | mkdir monolingual crosslingual 17 | 18 | ## English word analogy task 19 | curl -Lo source-archive.zip $en_analogy 20 | mkdir -p monolingual/en/ 21 | unzip -p source-archive.zip word2vec/trunk/questions-words.txt > monolingual/en/questions-words.txt 22 | rm source-archive.zip 23 | 24 | 25 | ## Downloading en-{} or {}-en dictionaries 26 | lgs="af ar bg bn bs ca cs da de el en es et fa fi fr he hi hr hu id it ja ko lt lv mk ms nl no pl pt ro ru sk sl sq sv ta th tl tr uk vi zh" 27 | mkdir -p crosslingual/dictionaries/ 28 | for lg in ${lgs} 29 | do 30 | for suffix in .txt .0-5000.txt .5000-6500.txt 31 | do 32 | fname=en-$lg$suffix 33 | curl -Lo crosslingual/dictionaries/$fname $aws_path/dictionaries/$fname 34 | fname=$lg-en$suffix 35 | curl -Lo crosslingual/dictionaries/$fname $aws_path/dictionaries/$fname 36 | done 37 | done 38 | 39 | ## Download European dictionaries 40 | for src_lg in de es fr it pt 41 | do 42 | for tgt_lg in de es fr it pt 43 | do 44 | if [ $src_lg != $tgt_lg ] 45 | then 46 | for suffix in .txt .0-5000.txt .5000-6500.txt 47 | do 48 | fname=$src_lg-$tgt_lg$suffix 49 | curl -Lo crosslingual/dictionaries/$fname $aws_path/dictionaries/european/$fname 50 | done 51 | fi 52 | done 53 | done 54 | 55 | ## Download Dinu et al. 
dictionaries 56 | for fname in OPUS_en_it_europarl_train_5K.txt OPUS_en_it_europarl_test.txt 57 | do 58 | echo $fname 59 | curl -Lo crosslingual/dictionaries/$fname $aws_path/dictionaries/$fname 60 | done 61 | 62 | ## Monolingual wordsim tasks 63 | for lang in "${!wordsim_lg[@]}" 64 | do 65 | echo $lang 66 | mkdir monolingual/$lang 67 | for wsim in ${wordsim_lg[$lang]} 68 | do 69 | echo $wsim 70 | curl -Lo monolingual/$lang/$wsim $aws_path/$lang/$wsim 71 | done 72 | done 73 | 74 | ## SemEval 2017 monolingual and cross-lingual wordsim tasks 75 | # 1) Task1: monolingual 76 | curl -Lo semeval2017-task2.zip $semeval_2017/semeval2017-task2.zip 77 | unzip semeval2017-task2.zip 78 | 79 | fdir='SemEval17-Task2/test/subtask1-monolingual' 80 | for lang in en es de fa it 81 | do 82 | mkdir -p monolingual/$lang 83 | uplang=`echo $lang | awk '{print toupper($0)}'` 84 | paste $fdir/data/$lang.test.data.txt $fdir/keys/$lang.test.gold.txt > monolingual/$lang/${uplang}_SEMEVAL17.txt 85 | done 86 | 87 | # 2) Task2: cross-lingual 88 | mkdir -p crosslingual/wordsim 89 | fdir='SemEval17-Task2/test/subtask2-crosslingual' 90 | for lg_pair in de-es de-fa de-it en-de en-es en-fa en-it es-fa es-it it-fa 91 | do 92 | echo $lg_pair 93 | paste $fdir/data/$lg_pair.test.data.txt $fdir/keys/$lg_pair.test.gold.txt > crosslingual/wordsim/$lg_pair-SEMEVAL17.txt 94 | done 95 | rm semeval2017-task2.zip 96 | rm -r SemEval17-Task2/ 97 | 98 | ## Europarl for sentence retrieval 99 | # TODO: set to true to activate download of Europarl (slow) 100 | if false; then 101 | mkdir -p crosslingual/europarl 102 | # Tokenize EUROPARL with MOSES 103 | echo 'Cloning Moses github repository (for tokenization scripts)...' 104 | git clone https://github.com/moses-smt/mosesdecoder.git 105 | SCRIPTS=mosesdecoder/scripts 106 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 107 | 108 | for lg_pair in it-en # es-en etc 109 | do 110 | curl -Lo $lg_pair.tgz $europarl/$lg_pair.tgz 111 | tar -xvf it-en.tgz 112 | rm it-en.tgz 113 | lgs=(${lg_pair//-/ }) 114 | for lg in ${lgs[0]} ${lgs[1]} 115 | do 116 | cat europarl-v7.$lg_pair.$lg | $TOKENIZER -threads 8 -l $lg -no-escape > euro.$lg.txt 117 | rm europarl-v7.$lg_pair.$lg 118 | done 119 | 120 | paste euro.${lgs[0]}.txt euro.${lgs[1]}.txt | shuf > euro.paste.txt 121 | rm euro.${lgs[0]}.txt euro.${lgs[1]}.txt 122 | 123 | cut -f1 euro.paste.txt > crosslingual/europarl/europarl-v7.$lg_pair.${lgs[0]} 124 | cut -f2 euro.paste.txt > crosslingual/europarl/europarl-v7.$lg_pair.${lgs[1]} 125 | rm euro.paste.txt 126 | done 127 | 128 | rm -rf mosesdecoder 129 | fi 130 | -------------------------------------------------------------------------------- /umwe/evaluate.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright (c) 2017-present, Facebook, Inc. 2 | # Modified work Copyright (c) 2018, Xilun Chen 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | # 8 | 9 | # python evaluate.py --src_langs de fr es it pt --tgt_lang en --eval_pairs all 10 | 11 | import os 12 | import argparse 13 | from collections import OrderedDict 14 | 15 | from src.utils import bool_flag, initialize_exp 16 | from src.models import build_model 17 | from src.trainer import Trainer 18 | from src.evaluation import Evaluator 19 | 20 | # default path to embeddings embeddings if not otherwise specified 21 | EMB_DIR = 'data/fasttext-vectors/' 22 | 23 | 24 | # main 25 | parser = argparse.ArgumentParser(description='Evaluation') 26 | parser.add_argument("--verbose", type=int, default=2, help="Verbose level (2:debug, 1:info, 0:warning)") 27 | parser.add_argument("--exp_path", type=str, default="", help="Where to store experiment logs and models") 28 | parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name") 29 | parser.add_argument("--exp_id", type=str, default="", help="Experiment ID") 30 | # parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU") 31 | parser.add_argument("--device", type=str, default="cuda", help="Run on GPU or CPU") 32 | # data 33 | parser.add_argument("--src_langs", type=str, nargs='+', default=[], help="Source languages") 34 | parser.add_argument("--tgt_lang", type=str, default="", help="Target language") 35 | # evaluation 36 | parser.add_argument("--eval_pairs", type=str, nargs='+', default=[], help="Language pairs to evaluate. e.g. ['en-de', 'de-fr']") 37 | parser.add_argument("--dico_eval", type=str, default="default", help="Path to evaluation dictionary") 38 | parser.add_argument("--dict_suffix", type=str, default="5000-6500.txt", help="suffix to use for word translation (0-5000.txt or 5000-6500.txt or txt)") 39 | parser.add_argument("--semeval_ignore_oov", type=bool_flag, default=True, help="Whether to ignore OOV in SEMEVAL evaluation (the original authors used True)") 40 | # reload pre-trained embeddings 41 | parser.add_argument("--src_embs", type=str, nargs='+', default=[], help="Reload source embeddings (should be in the same order as in src_langs)") 42 | parser.add_argument("--tgt_emb", type=str, default="", help="Reload target embeddings") 43 | parser.add_argument("--max_vocab", type=int, default=200000, help="Maximum vocabulary size (-1 to disable)") 44 | parser.add_argument("--emb_dim", type=int, default=300, help="Embedding dimension") 45 | parser.add_argument("--normalize_embeddings", type=str, default="", help="Normalize embeddings before training") 46 | 47 | 48 | # parse parameters 49 | params = parser.parse_args() 50 | 51 | # post-processing options 52 | params.src_N = len(params.src_langs) 53 | params.all_langs = params.src_langs + [params.tgt_lang] 54 | # load default embeddings 55 | if len(params.src_embs) == 0: 56 | params.src_embs = [] 57 | for lang in params.src_langs: 58 | params.src_embs.append(os.path.join(EMB_DIR, f'wiki.{lang}.vec')) 59 | if len(params.tgt_emb) == 0: 60 | params.tgt_emb = os.path.join(EMB_DIR, f'wiki.{params.tgt_lang}.vec') 61 | # expand 'all' in eval_pairs 62 | if 'all' in params.eval_pairs: 63 | params.eval_pairs = [] 64 | for lang1 in params.all_langs: 65 | for lang2 in params.all_langs: 66 | if lang1 != lang2: 67 | params.eval_pairs.append(f'{lang1}-{lang2}') 68 | 69 | # check parameters 70 | assert len(params.src_langs) > 0, "source language undefined" 71 | assert all([os.path.isfile(emb) for emb in params.src_embs]) 72 | assert not params.tgt_lang or os.path.isfile(params.tgt_emb) 73 | assert params.dico_eval == 'default' or 
os.path.isfile(params.dico_eval) 74 | 75 | # build logger / model / trainer / evaluator 76 | logger = initialize_exp(params, dump_params=False, log_name='evaluate.log') 77 | embs, mappings, _ = build_model(params, False) 78 | trainer = Trainer(embs, mappings, None, params) 79 | trainer.reload_best() 80 | evaluator = Evaluator(trainer) 81 | 82 | # run evaluations 83 | to_log = OrderedDict({'n_iter': 0}) 84 | all_wt = [] 85 | evaluator.monolingual_wordsim(to_log) 86 | for eval_pair in params.eval_pairs: 87 | parts = eval_pair.split('-') 88 | assert len(parts) == 2, 'Invalid format for evaluation pairs.' 89 | src_lang, tgt_lang = parts[0], parts[1] 90 | logger.info(f'Evaluating language pair: {src_lang} - {tgt_lang}') 91 | evaluator.crosslingual_wordsim(to_log, src_lang=src_lang, tgt_lang=tgt_lang) 92 | evaluator.word_translation(to_log, src_lang=src_lang, tgt_lang=tgt_lang) 93 | all_wt.append(to_log[f'{src_lang}-{tgt_lang}_precision_at_1-csls_knn_10']) 94 | evaluator.sent_translation(to_log, src_lang=src_lang, tgt_lang=tgt_lang) 95 | 96 | logger.info(f"Overall Word Translation Precision@1 over {len(all_wt)} language pairs: {sum(all_wt)/len(all_wt)}") 97 | -------------------------------------------------------------------------------- /umwe/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/umwe/src/__init__.py -------------------------------------------------------------------------------- /umwe/src/dico_builder.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright (c) 2017-present, Facebook, Inc. 2 | # Modified work Copyright (c) 2018, Xilun Chen 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | from logging import getLogger 10 | import torch 11 | 12 | from .utils import get_nn_avg_dist 13 | 14 | 15 | logger = getLogger() 16 | 17 | 18 | def get_candidates(emb1, emb2, params): 19 | """ 20 | Get best translation pairs candidates. 
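
    For the 'csls_knn_K' methods, the candidate score follows the CSLS
    criterion (a sketch; emb1/emb2 are assumed to be L2-normalized, as done
    by build_dictionary below):

        score(x, y) = 2 * cos(x, y) - r1(x) - r2(y)

    where r1/r2 are the average cosine similarities of x and y to their K
    nearest neighbors in the other language (see get_nn_avg_dist).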
21 | """ 22 | bs = 128 23 | 24 | all_scores = [] 25 | all_targets = [] 26 | 27 | # number of source words to consider 28 | n_src = emb1.size(0) 29 | if params.dico_max_rank > 0 and not params.dico_method.startswith('invsm_beta_'): 30 | n_src = params.dico_max_rank 31 | 32 | # nearest neighbors 33 | if params.dico_method == 'nn': 34 | 35 | # for every source word 36 | for i in range(0, n_src, bs): 37 | 38 | # compute target words scores 39 | scores = emb2.mm(emb1[i:min(n_src, i + bs)].transpose(0, 1)).transpose(0, 1) 40 | best_scores, best_targets = scores.topk(2, dim=1, largest=True, sorted=True) 41 | 42 | # update scores / potential targets 43 | all_scores.append(best_scores.cpu()) 44 | all_targets.append(best_targets.cpu()) 45 | 46 | all_scores = torch.cat(all_scores, 0) 47 | all_targets = torch.cat(all_targets, 0) 48 | 49 | # inverted softmax 50 | elif params.dico_method.startswith('invsm_beta_'): 51 | 52 | beta = float(params.dico_method[len('invsm_beta_'):]) 53 | 54 | # for every target word 55 | for i in range(0, emb2.size(0), bs): 56 | 57 | # compute source words scores 58 | scores = emb1.mm(emb2[i:i + bs].transpose(0, 1)) 59 | scores.mul_(beta).exp_() 60 | scores.div_(scores.sum(0, keepdim=True).expand_as(scores)) 61 | 62 | best_scores, best_targets = scores.topk(2, dim=1, largest=True, sorted=True) 63 | 64 | # update scores / potential targets 65 | all_scores.append(best_scores.cpu()) 66 | all_targets.append((best_targets + i).cpu()) 67 | 68 | all_scores = torch.cat(all_scores, 1) 69 | all_targets = torch.cat(all_targets, 1) 70 | 71 | all_scores, best_targets = all_scores.topk(2, dim=1, largest=True, sorted=True) 72 | all_targets = all_targets.gather(1, best_targets) 73 | 74 | # contextual dissimilarity measure 75 | elif params.dico_method.startswith('csls_knn_'): 76 | 77 | knn = params.dico_method[len('csls_knn_'):] 78 | assert knn.isdigit() 79 | knn = int(knn) 80 | 81 | # average distances to k nearest neighbors 82 | average_dist1 = torch.from_numpy(get_nn_avg_dist(emb2, emb1, knn)) 83 | average_dist2 = torch.from_numpy(get_nn_avg_dist(emb1, emb2, knn)) 84 | average_dist1 = average_dist1.type_as(emb1) 85 | average_dist2 = average_dist2.type_as(emb2) 86 | 87 | # for every source word 88 | for i in range(0, n_src, bs): 89 | 90 | # compute target words scores 91 | scores = emb2.mm(emb1[i:min(n_src, i + bs)].transpose(0, 1)).transpose(0, 1) 92 | scores.mul_(2) 93 | scores.sub_(average_dist1[i:min(n_src, i + bs)][:, None] + average_dist2[None, :]) 94 | best_scores, best_targets = scores.topk(2, dim=1, largest=True, sorted=True) 95 | 96 | # update scores / potential targets 97 | all_scores.append(best_scores.cpu()) 98 | all_targets.append(best_targets.cpu()) 99 | 100 | all_scores = torch.cat(all_scores, 0) 101 | all_targets = torch.cat(all_targets, 0) 102 | 103 | all_pairs = torch.cat([ 104 | torch.arange(0, all_targets.size(0)).long().unsqueeze(1), 105 | all_targets[:, 0].unsqueeze(1) 106 | ], 1) 107 | 108 | # sanity check 109 | assert all_scores.size() == all_pairs.size() == (n_src, 2) 110 | 111 | # sort pairs by score confidence 112 | diff = all_scores[:, 0] - all_scores[:, 1] 113 | reordered = diff.sort(0, descending=True)[1] 114 | all_scores = all_scores[reordered] 115 | all_pairs = all_pairs[reordered] 116 | 117 | # max dico words rank 118 | if params.dico_max_rank > 0: 119 | selected = all_pairs.max(1)[0] <= params.dico_max_rank 120 | mask = selected.unsqueeze(1).expand_as(all_scores).clone() 121 | all_scores = all_scores.masked_select(mask).view(-1, 2) 122 | all_pairs = 
all_pairs.masked_select(mask).view(-1, 2) 123 | if len(all_pairs) == 0: 124 | return [] 125 | 126 | # max dico size 127 | if params.dico_max_size > 0: 128 | all_scores = all_scores[:params.dico_max_size] 129 | all_pairs = all_pairs[:params.dico_max_size] 130 | 131 | # min dico size 132 | diff = all_scores[:, 0] - all_scores[:, 1] 133 | if params.dico_min_size > 0: 134 | diff[:params.dico_min_size] = 1e9 135 | 136 | # confidence threshold 137 | if params.dico_threshold > 0: 138 | mask = diff > params.dico_threshold 139 | logger.info("Selected %i / %i pairs above the confidence threshold." % (mask.sum(), diff.size(0))) 140 | mask = mask.unsqueeze(1).expand_as(all_pairs).clone() 141 | all_pairs = all_pairs.masked_select(mask).view(-1, 2) 142 | 143 | return all_pairs 144 | 145 | 146 | def build_dictionary(src_emb, tgt_emb, params, s2t_candidates=None, t2s_candidates=None): 147 | """ 148 | Build a training dictionary given current embeddings / mapping. 149 | """ 150 | logger.info("Building the train dictionary ...") 151 | s2t = 'S2T' in params.dico_build 152 | t2s = 'T2S' in params.dico_build 153 | assert s2t or t2s 154 | 155 | if s2t: 156 | if s2t_candidates is None: 157 | s2t_candidates = get_candidates(src_emb, tgt_emb, params) 158 | if t2s: 159 | if t2s_candidates is None: 160 | t2s_candidates = get_candidates(tgt_emb, src_emb, params) 161 | t2s_candidates = torch.cat([t2s_candidates[:, 1:], t2s_candidates[:, :1]], 1) 162 | 163 | if params.dico_build == 'S2T': 164 | dico = s2t_candidates 165 | elif params.dico_build == 'T2S': 166 | dico = t2s_candidates 167 | else: 168 | s2t_candidates = set([(a, b) for a, b in s2t_candidates.numpy()]) 169 | t2s_candidates = set([(a, b) for a, b in t2s_candidates.numpy()]) 170 | if params.dico_build == 'S2T|T2S': 171 | final_pairs = s2t_candidates | t2s_candidates 172 | else: 173 | assert params.dico_build == 'S2T&T2S' 174 | final_pairs = s2t_candidates & t2s_candidates 175 | if len(final_pairs) == 0: 176 | logger.warning("Empty intersection ...") 177 | return None 178 | dico = torch.LongTensor(list([[int(a), int(b)] for (a, b) in final_pairs])) 179 | 180 | if len(dico) == 0: 181 | return None 182 | logger.info('New train dictionary of %i pairs.' % dico.size(0)) 183 | return dico.to(params.device) 184 | -------------------------------------------------------------------------------- /umwe/src/dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | 10 | 11 | logger = getLogger() 12 | 13 | 14 | class Dictionary(object): 15 | 16 | def __init__(self, id2word, word2id, lang): 17 | assert len(id2word) == len(word2id) 18 | self.id2word = id2word 19 | self.word2id = word2id 20 | self.lang = lang 21 | self.check_valid() 22 | 23 | def __len__(self): 24 | """ 25 | Returns the number of words in the dictionary. 26 | """ 27 | return len(self.id2word) 28 | 29 | def __getitem__(self, i): 30 | """ 31 | Returns the word of the specified index. 32 | """ 33 | return self.id2word[i] 34 | 35 | def __contains__(self, w): 36 | """ 37 | Returns whether a word is in the dictionary. 38 | """ 39 | return w in self.word2id 40 | 41 | def __eq__(self, y): 42 | """ 43 | Compare the dictionary with another one. 
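
        Two dictionaries compare equal iff they cover the same language and
        the same id-to-word assignment, e.g. (illustrative):

            >>> a = Dictionary({0: 'hello', 1: 'world'}, {'hello': 0, 'world': 1}, 'en')
            >>> b = Dictionary({0: 'hello', 1: 'world'}, {'hello': 0, 'world': 1}, 'en')
            >>> a == b
            True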
44 | """ 45 | self.check_valid() 46 | y.check_valid() 47 | if len(self.id2word) != len(y): 48 | return False 49 | return self.lang == y.lang and all(self.id2word[i] == y[i] for i in range(len(y))) 50 | 51 | def check_valid(self): 52 | """ 53 | Check that the dictionary is valid. 54 | """ 55 | assert len(self.id2word) == len(self.word2id) 56 | for i in range(len(self.id2word)): 57 | assert self.word2id[self.id2word[i]] == i 58 | 59 | def index(self, word): 60 | """ 61 | Returns the index of the specified word. 62 | """ 63 | return self.word2id[word] 64 | 65 | def prune(self, max_vocab): 66 | """ 67 | Limit the vocabulary size. 68 | """ 69 | assert max_vocab >= 1 70 | self.id2word = {k: v for k, v in self.id2word.items() if k < max_vocab} 71 | self.word2id = {v: k for k, v in self.id2word.items()} 72 | self.check_valid() 73 | -------------------------------------------------------------------------------- /umwe/src/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .wordsim import get_wordsim_scores, get_crosslingual_wordsim_scores, get_wordanalogy_scores 2 | from .word_translation import get_word_translation_accuracy 3 | from .sent_translation import get_sent_translation_accuracy, load_europarl_data 4 | from .evaluator import Evaluator 5 | -------------------------------------------------------------------------------- /umwe/src/evaluation/sent_translation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import io 10 | from logging import getLogger 11 | import numpy as np 12 | import torch 13 | 14 | from src.utils import bow_idf, get_nn_avg_dist 15 | 16 | 17 | EUROPARL_DIR = 'data/crosslingual/europarl' 18 | 19 | 20 | logger = getLogger() 21 | 22 | 23 | def load_europarl_data(lg1, lg2, n_max=1e10, lower=True): 24 | """ 25 | Load data parallel sentences 26 | """ 27 | if not (os.path.isfile(os.path.join(EUROPARL_DIR, 'europarl-v7.%s-%s.%s' % (lg1, lg2, lg1))) or 28 | os.path.isfile(os.path.join(EUROPARL_DIR, 'europarl-v7.%s-%s.%s' % (lg2, lg1, lg1)))): 29 | return None 30 | 31 | if os.path.isfile(os.path.join(EUROPARL_DIR, 'europarl-v7.%s-%s.%s' % (lg2, lg1, lg1))): 32 | lg1, lg2 = lg2, lg1 33 | 34 | # load sentences 35 | data = {lg1: [], lg2: []} 36 | for lg in [lg1, lg2]: 37 | fname = os.path.join(EUROPARL_DIR, 'europarl-v7.%s-%s.%s' % (lg1, lg2, lg)) 38 | 39 | with io.open(fname, 'r', encoding='utf-8') as f: 40 | for i, line in enumerate(f): 41 | if i >= n_max: 42 | break 43 | line = line.lower() if lower else line 44 | data[lg].append(line.rstrip().split()) 45 | 46 | # get only unique sentences for each language 47 | assert len(data[lg1]) == len(data[lg2]) 48 | data[lg1] = np.array(data[lg1]) 49 | data[lg2] = np.array(data[lg2]) 50 | data[lg1], indices = np.unique(data[lg1], return_index=True) 51 | data[lg2] = data[lg2][indices] 52 | data[lg2], indices = np.unique(data[lg2], return_index=True) 53 | data[lg1] = data[lg1][indices] 54 | 55 | # shuffle sentences 56 | rng = np.random.RandomState(1234) 57 | perm = rng.permutation(len(data[lg1])) 58 | data[lg1] = data[lg1][perm] 59 | data[lg2] = data[lg2][perm] 60 | 61 | logger.info("Loaded europarl %s-%s (%i sentences)." 
% (lg1, lg2, len(data[lg1]))) 62 | return data 63 | 64 | 65 | def get_sent_translation_accuracy(data, lg1, word2id1, emb1, lg2, word2id2, emb2, 66 | n_keys, n_queries, method, idf): 67 | 68 | """ 69 | Given parallel sentences from Europarl, evaluate the 70 | sentence translation accuracy using the precision@k. 71 | """ 72 | # get word vectors dictionaries 73 | emb1 = emb1.cpu().detach().numpy() 74 | emb2 = emb2.cpu().detach().numpy() 75 | word_vec1 = dict([(w, emb1[word2id1[w]]) for w in word2id1]) 76 | word_vec2 = dict([(w, emb2[word2id2[w]]) for w in word2id2]) 77 | word_vect = {lg1: word_vec1, lg2: word_vec2} 78 | lg_keys = lg2 79 | lg_query = lg1 80 | 81 | # get n_keys pairs of sentences 82 | keys = data[lg_keys][:n_keys] 83 | keys = bow_idf(keys, word_vect[lg_keys], idf_dict=idf[lg_keys]) 84 | 85 | # get n_queries query pairs from these n_keys pairs 86 | rng = np.random.RandomState(1234) 87 | idx_query = rng.choice(range(n_keys), size=n_queries, replace=False) 88 | queries = data[lg_query][idx_query] 89 | queries = bow_idf(queries, word_vect[lg_query], idf_dict=idf[lg_query]) 90 | 91 | # normalize embeddings 92 | queries = torch.from_numpy(queries).float() 93 | queries = queries / queries.norm(2, 1, keepdim=True).expand_as(queries) 94 | keys = torch.from_numpy(keys).float() 95 | keys = keys / keys.norm(2, 1, keepdim=True).expand_as(keys) 96 | 97 | # nearest neighbors 98 | if method == 'nn': 99 | scores = keys.mm(queries.transpose(0, 1)).transpose(0, 1) 100 | scores = scores.cpu() 101 | 102 | # inverted softmax 103 | elif method.startswith('invsm_beta_'): 104 | beta = float(method[len('invsm_beta_'):]) 105 | scores = keys.mm(queries.transpose(0, 1)).transpose(0, 1) 106 | scores.mul_(beta).exp_() 107 | scores.div_(scores.sum(0, keepdim=True).expand_as(scores)) 108 | scores = scores.cpu() 109 | 110 | # contextual dissimilarity measure 111 | elif method.startswith('csls_knn_'): 112 | knn = method[len('csls_knn_'):] 113 | assert knn.isdigit() 114 | knn = int(knn) 115 | # average distances to k nearest neighbors 116 | knn = method[len('csls_knn_'):] 117 | assert knn.isdigit() 118 | knn = int(knn) 119 | average_dist_keys = torch.from_numpy(get_nn_avg_dist(queries, keys, knn)) 120 | average_dist_queries = torch.from_numpy(get_nn_avg_dist(keys, queries, knn)) 121 | # scores 122 | scores = keys.mm(queries.transpose(0, 1)).transpose(0, 1) 123 | scores.mul_(2) 124 | scores.sub_(average_dist_queries[:, None].float() + average_dist_keys[None, :].float()) 125 | scores = scores.cpu() 126 | 127 | results = [] 128 | top_matches = scores.topk(10, 1, True)[1] 129 | for k in [1, 5, 10]: 130 | top_k_matches = (top_matches[:, :k] == torch.from_numpy(idx_query)[:, None]).sum(1) 131 | precision_at_k = 100 * top_k_matches.float().numpy().mean() 132 | logger.info("%i queries (%s) - %s - Precision at k = %i: %f" % 133 | (len(top_k_matches), lg_query.upper(), method, k, precision_at_k)) 134 | results.append(('sent-precision_at_%i' % k, precision_at_k)) 135 | 136 | return results 137 | -------------------------------------------------------------------------------- /umwe/src/evaluation/word_translation.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright (c) 2017-present, Facebook, Inc. 2 | # Modified work Copyright (c) 2018, Xilun Chen 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | # 8 | 9 | import os 10 | import io 11 | from logging import getLogger 12 | import numpy as np 13 | import torch 14 | 15 | from ..utils import get_nn_avg_dist 16 | 17 | 18 | DIC_EVAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'data', 'crosslingual', 'dictionaries') 19 | 20 | 21 | logger = getLogger() 22 | 23 | 24 | def load_identical_char_dico(word2id1, word2id2): 25 | """ 26 | Build a dictionary of identical character strings. 27 | """ 28 | pairs = [(w1, w1) for w1 in word2id1.keys() if w1 in word2id2] 29 | if len(pairs) == 0: 30 | raise Exception("No identical character strings were found. " 31 | "Please specify a dictionary.") 32 | 33 | logger.info("Found %i pairs of identical character strings." % len(pairs)) 34 | 35 | # sort the dictionary by source word frequencies 36 | pairs = sorted(pairs, key=lambda x: word2id1[x[0]]) 37 | dico = torch.LongTensor(len(pairs), 2) 38 | for i, (word1, word2) in enumerate(pairs): 39 | dico[i, 0] = word2id1[word1] 40 | dico[i, 1] = word2id2[word2] 41 | 42 | return dico 43 | 44 | 45 | def load_dictionary(path, word2id1, word2id2): 46 | """ 47 | Return a torch tensor of size (n, 2) where n is the size of the 48 | loader dictionary, and sort it by source word frequency. 49 | """ 50 | assert os.path.isfile(path) 51 | 52 | pairs = [] 53 | not_found = 0 54 | not_found1 = 0 55 | not_found2 = 0 56 | 57 | with io.open(path, 'r', encoding='utf-8') as f: 58 | for _, line in enumerate(f): 59 | assert line == line.lower() 60 | word1, word2 = line.rstrip().split() 61 | if word1 in word2id1 and word2 in word2id2: 62 | pairs.append((word1, word2)) 63 | else: 64 | not_found += 1 65 | not_found1 += int(word1 not in word2id1) 66 | not_found2 += int(word2 not in word2id2) 67 | 68 | logger.info("Found %i pairs of words in %s (%i unique). " 69 | "%i other pairs contained at least one unknown word " 70 | "(%i in lang1, %i in lang2)" 71 | % (len(pairs), path, len(set([x for x, _ in pairs])), 72 | not_found, not_found1, not_found2)) 73 | 74 | # sort the dictionary by source word frequencies 75 | pairs = sorted(pairs, key=lambda x: word2id1[x[0]]) 76 | dico = torch.LongTensor(len(pairs), 2) 77 | for i, (word1, word2) in enumerate(pairs): 78 | dico[i, 0] = word2id1[word1] 79 | dico[i, 1] = word2id2[word2] 80 | 81 | return dico 82 | 83 | 84 | def get_word_translation_accuracy(lang1, word2id1, emb1, lang2, word2id2, emb2, method, dico_eval): 85 | """ 86 | Given source and target word embeddings, and a dictionary, 87 | evaluate the translation accuracy using the precision@k. 
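
    Illustrative call (assumes the MUSE bilingual dictionaries were
    downloaded via data/get_evaluation.sh):

        results = get_word_translation_accuracy(
            'en', word2id_en, emb_en, 'de', word2id_de, emb_de,
            method='csls_knn_10', dico_eval='default')
        # [('precision_at_1', ...), ('precision_at_5', ...), ('precision_at_10', ...)]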
88 | """ 89 | if dico_eval == 'default': 90 | path = os.path.join(DIC_EVAL_PATH, '%s-%s.5000-6500.txt' % (lang1, lang2)) 91 | elif dico_eval == 'train': 92 | path = os.path.join(DIC_EVAL_PATH, '%s-%s.0-5000.txt' % (lang1, lang2)) 93 | elif dico_eval == 'all': 94 | path = os.path.join(DIC_EVAL_PATH, '%s-%s.txt' % (lang1, lang2)) 95 | else: 96 | raise NotImplemented(dico_eval) 97 | dico = load_dictionary(path, word2id1, word2id2).to(emb1.device) 98 | 99 | assert dico[:, 0].max() < emb1.size(0) 100 | assert dico[:, 1].max() < emb2.size(0) 101 | 102 | # normalize word embeddings 103 | emb1 = emb1 / emb1.norm(2, 1, keepdim=True).expand_as(emb1) 104 | emb2 = emb2 / emb2.norm(2, 1, keepdim=True).expand_as(emb2) 105 | 106 | # nearest neighbors 107 | if method == 'nn': 108 | query = emb1[dico[:, 0]] 109 | scores = query.mm(emb2.transpose(0, 1)) 110 | 111 | # inverted softmax 112 | elif method.startswith('invsm_beta_'): 113 | beta = float(method[len('invsm_beta_'):]) 114 | bs = 128 115 | word_scores = [] 116 | for i in range(0, emb2.size(0), bs): 117 | scores = emb1.mm(emb2[i:i + bs].transpose(0, 1)) 118 | scores.mul_(beta).exp_() 119 | scores.div_(scores.sum(0, keepdim=True).expand_as(scores)) 120 | word_scores.append(scores.index_select(0, dico[:, 0])) 121 | scores = torch.cat(word_scores, 1) 122 | 123 | # contextual dissimilarity measure 124 | elif method.startswith('csls_knn_'): 125 | # average distances to k nearest neighbors 126 | knn = method[len('csls_knn_'):] 127 | assert knn.isdigit() 128 | knn = int(knn) 129 | average_dist1 = get_nn_avg_dist(emb2, emb1, knn) 130 | average_dist2 = get_nn_avg_dist(emb1, emb2, knn) 131 | average_dist1 = torch.from_numpy(average_dist1).type_as(emb1) 132 | average_dist2 = torch.from_numpy(average_dist2).type_as(emb2) 133 | # queries / scores 134 | query = emb1[dico[:, 0]] 135 | scores = query.mm(emb2.transpose(0, 1)) 136 | scores.mul_(2) 137 | scores.sub_(average_dist1[dico[:, 0]][:, None]) 138 | scores.sub_(average_dist2[None, :]) 139 | 140 | else: 141 | raise Exception('Unknown method: "%s"' % method) 142 | 143 | results = [] 144 | top_matches = scores.topk(10, 1, True)[1] 145 | for k in [1, 5, 10]: 146 | top_k_matches = top_matches[:, :k] 147 | _matching = (top_k_matches == dico[:, 1][:, None].expand_as(top_k_matches)).sum(1) 148 | # allow for multiple possible translations 149 | matching = {} 150 | for i, src_id in enumerate(dico[:, 0].cpu().numpy()): 151 | matching[src_id] = min(matching.get(src_id, 0) + _matching[i], 1) 152 | # evaluate precision@k 153 | precision_at_k = 100 * np.mean(list(matching.values())) 154 | logger.info("%i source words - %s - Precision at k = %i: %f" % 155 | (len(matching), method, k, precision_at_k)) 156 | results.append(('precision_at_%i' % k, precision_at_k)) 157 | 158 | return results 159 | -------------------------------------------------------------------------------- /umwe/src/evaluation/wordsim.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright (c) 2017-present, Facebook, Inc. 2 | # Modified work Copyright (c) 2018, Xilun Chen 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | # 8 | 9 | import os 10 | import io 11 | from logging import getLogger 12 | import numpy as np 13 | import torch 14 | from scipy.stats import spearmanr 15 | 16 | 17 | MONOLINGUAL_EVAL_PATH = 'data/monolingual' 18 | SEMEVAL17_EVAL_PATH = 'data/crosslingual/wordsim' 19 | 20 | 21 | logger = getLogger() 22 | 23 | 24 | def get_word_pairs(path, lower=True): 25 | """ 26 | Return a list of (word1, word2, score) tuples from a word similarity file. 27 | """ 28 | assert os.path.isfile(path) and type(lower) is bool 29 | word_pairs = [] 30 | with io.open(path, 'r', encoding='utf-8') as f: 31 | for line in f: 32 | line = line.rstrip() 33 | line = line.lower() if lower else line 34 | line = line.split() 35 | # ignore phrases, only consider words 36 | if len(line) != 3: 37 | assert len(line) > 3 38 | assert 'SEMEVAL17' in os.path.basename(path) or 'EN-IT_MWS353' in path 39 | continue 40 | word_pairs.append((line[0], line[1], float(line[2]))) 41 | return word_pairs 42 | 43 | 44 | def get_word_id(word, word2id, lower): 45 | """ 46 | Get a word ID. 47 | If the model does not use lowercase and the evaluation file is lowercased, 48 | we might be able to find an associated word. 49 | """ 50 | assert type(lower) is bool 51 | word_id = word2id.get(word) 52 | if word_id is None and not lower: 53 | word_id = word2id.get(word.capitalize()) 54 | if word_id is None and not lower: 55 | word_id = word2id.get(word.title()) 56 | return word_id 57 | 58 | 59 | def get_spearman_rho(word2id1, embeddings1, path, lower, 60 | word2id2=None, embeddings2=None, ignore_oov=True): 61 | """ 62 | Compute monolingual or cross-lingual word similarity score. 63 | """ 64 | assert not ((word2id2 is None) ^ (embeddings2 is None)) 65 | word2id2 = word2id1 if word2id2 is None else word2id2 66 | embeddings2 = embeddings1 if embeddings2 is None else embeddings2 67 | assert len(word2id1) == embeddings1.shape[0] 68 | assert len(word2id2) == embeddings2.shape[0] 69 | assert type(lower) is bool 70 | word_pairs = get_word_pairs(path) 71 | not_found = 0 72 | pred = [] 73 | gold = [] 74 | for word1, word2, similarity in word_pairs: 75 | id1 = get_word_id(word1, word2id1, lower) 76 | id2 = get_word_id(word2, word2id2, lower) 77 | if id1 is None or id2 is None: 78 | not_found += 1 79 | if not ignore_oov: 80 | gold.append(similarity) 81 | pred.append(0.5) 82 | continue 83 | u = embeddings1[id1] 84 | v = embeddings2[id2] 85 | score = u.dot(v) / (np.linalg.norm(u) * np.linalg.norm(v)) 86 | gold.append(similarity) 87 | pred.append(score) 88 | return spearmanr(gold, pred).correlation, len(gold), not_found 89 | 90 | 91 | def get_wordsim_scores(language, word2id, embeddings, lower=True): 92 | """ 93 | Return monolingual word similarity scores. 
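
    Illustrative call:

        scores = get_wordsim_scores('en', word2id, embeddings)
        # {'EN_SIMLEX-999': <rho>, 'EN_WS-353-ALL': <rho>, ...},
        # or None if data/monolingual/en does not exist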
94 | """ 95 | dirpath = os.path.join(MONOLINGUAL_EVAL_PATH, language) 96 | if not os.path.isdir(dirpath): 97 | return None 98 | 99 | scores = {} 100 | separator = "=" * (30 + 1 + 10 + 1 + 13 + 1 + 12) 101 | pattern = "%30s %10s %13s %12s" 102 | logger.info(separator) 103 | logger.info(pattern % ("Dataset", "Found", "Not found", "Rho")) 104 | logger.info(separator) 105 | 106 | for filename in list(os.listdir(dirpath)): 107 | if filename.startswith('%s_' % (language.upper())): 108 | filepath = os.path.join(dirpath, filename) 109 | coeff, found, not_found = get_spearman_rho(word2id, embeddings, filepath, lower) 110 | logger.info(pattern % (filename[:-4], str(found), str(not_found), "%.4f" % coeff)) 111 | scores[filename[:-4]] = coeff 112 | logger.info(separator) 113 | 114 | return scores 115 | 116 | 117 | def get_wordanalogy_scores(language, word2id, embeddings, lower=True): 118 | """ 119 | Return (english) word analogy score 120 | """ 121 | dirpath = os.path.join(MONOLINGUAL_EVAL_PATH, language) 122 | if not os.path.isdir(dirpath) or language not in ["en"]: 123 | return None 124 | 125 | # normalize word embeddings 126 | embeddings = embeddings / np.sqrt((embeddings ** 2).sum(1))[:, None] 127 | 128 | # scores by category 129 | scores = {} 130 | 131 | word_ids = {} 132 | queries = {} 133 | 134 | with io.open(os.path.join(dirpath, 'questions-words.txt'), 'r', encoding='utf-8') as f: 135 | for line in f: 136 | # new line 137 | line = line.rstrip() 138 | if lower: 139 | line = line.lower() 140 | 141 | # new category 142 | if ":" in line: 143 | assert line[1] == ' ' 144 | category = line[2:] 145 | assert category not in scores 146 | scores[category] = {'n_found': 0, 'n_not_found': 0, 'n_correct': 0} 147 | word_ids[category] = [] 148 | queries[category] = [] 149 | continue 150 | 151 | # get word IDs 152 | assert len(line.split()) == 4, line 153 | word1, word2, word3, word4 = line.split() 154 | word_id1 = get_word_id(word1, word2id, lower) 155 | word_id2 = get_word_id(word2, word2id, lower) 156 | word_id3 = get_word_id(word3, word2id, lower) 157 | word_id4 = get_word_id(word4, word2id, lower) 158 | 159 | # if at least one word is not found 160 | if any(x is None for x in [word_id1, word_id2, word_id3, word_id4]): 161 | scores[category]['n_not_found'] += 1 162 | continue 163 | else: 164 | scores[category]['n_found'] += 1 165 | word_ids[category].append([word_id1, word_id2, word_id3, word_id4]) 166 | # generate query vector and get nearest neighbors 167 | query = embeddings[word_id1] - embeddings[word_id2] + embeddings[word_id4] 168 | query = query / np.linalg.norm(query) 169 | 170 | queries[category].append(query) 171 | 172 | # Compute score for each category 173 | for cat in queries: 174 | qs = torch.from_numpy(np.vstack(queries[cat])) 175 | keys = torch.from_numpy(embeddings.T) 176 | values = qs.mm(keys).cpu().numpy() 177 | 178 | # be sure we do not select input words 179 | for i, ws in enumerate(word_ids[cat]): 180 | for wid in [ws[0], ws[1], ws[3]]: 181 | values[i, wid] = -1e9 182 | scores[cat]['n_correct'] = np.sum(values.argmax(axis=1) == [ws[2] for ws in word_ids[cat]]) 183 | 184 | # pretty print 185 | separator = "=" * (30 + 1 + 10 + 1 + 13 + 1 + 12) 186 | pattern = "%30s %10s %13s %12s" 187 | logger.info(separator) 188 | logger.info(pattern % ("Category", "Found", "Not found", "Accuracy")) 189 | logger.info(separator) 190 | 191 | # compute and log accuracies 192 | accuracies = {} 193 | for k in sorted(scores.keys()): 194 | v = scores[k] 195 | accuracies[k] = float(v['n_correct']) / 
max(v['n_found'], 1) 196 | logger.info(pattern % (k, str(v['n_found']), str(v['n_not_found']), "%.4f" % accuracies[k])) 197 | logger.info(separator) 198 | 199 | return accuracies 200 | 201 | 202 | def get_crosslingual_wordsim_scores(lang1, word2id1, embeddings1, 203 | lang2, word2id2, embeddings2, 204 | lower=True, ignore_oov=True): 205 | """ 206 | Return cross-lingual word similarity scores. 207 | """ 208 | f1 = os.path.join(SEMEVAL17_EVAL_PATH, '%s-%s-SEMEVAL17.txt' % (lang1, lang2)) 209 | f2 = os.path.join(SEMEVAL17_EVAL_PATH, '%s-%s-SEMEVAL17.txt' % (lang2, lang1)) 210 | if not (os.path.exists(f1) or os.path.exists(f2)): 211 | return None 212 | 213 | if os.path.exists(f1): 214 | coeff, found, not_found = get_spearman_rho( 215 | word2id1, embeddings1, f1, 216 | lower, word2id2, embeddings2, ignore_oov 217 | ) 218 | elif os.path.exists(f2): 219 | coeff, found, not_found = get_spearman_rho( 220 | word2id2, embeddings2, f2, 221 | lower, word2id1, embeddings1, ignore_oov 222 | ) 223 | 224 | scores = {} 225 | separator = "=" * (30 + 1 + 10 + 1 + 13 + 1 + 12) 226 | pattern = "%30s %10s %13s %12s" 227 | logger.info(separator) 228 | logger.info(pattern % ("Dataset", "Found", "Not found", "Rho")) 229 | logger.info(separator) 230 | 231 | task_name = '%s_%s_SEMEVAL17' % (lang1.upper(), lang2.upper()) 232 | logger.info(pattern % (task_name, str(found), str(not_found), "%.4f" % coeff)) 233 | scores[task_name] = coeff 234 | if not scores: 235 | return None 236 | logger.info(separator) 237 | 238 | return scores 239 | -------------------------------------------------------------------------------- /umwe/src/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import logging 9 | import time 10 | from datetime import timedelta 11 | 12 | 13 | class LogFormatter(): 14 | 15 | def __init__(self): 16 | self.start_time = time.time() 17 | 18 | def format(self, record): 19 | elapsed_seconds = round(record.created - self.start_time) 20 | 21 | prefix = "%s - %s - %s" % ( 22 | record.levelname, 23 | time.strftime('%x %X'), 24 | timedelta(seconds=elapsed_seconds) 25 | ) 26 | message = record.getMessage() 27 | message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3)) 28 | return "%s - %s" % (prefix, message) 29 | 30 | 31 | def create_logger(filepath, vb=2): 32 | """ 33 | Create a logger. 
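
    Usage (illustrative):

        logger = create_logger('train.log', vb=1)  # DEBUG to file, INFO to console
        logger.info('starting')
        logger.reset_time()  # restart the elapsed-time counter in the log prefix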
34 | """ 35 | # create log formatter 36 | log_formatter = LogFormatter() 37 | 38 | # create file handler and set level to debug 39 | file_handler = logging.FileHandler(filepath, "a") 40 | file_handler.setLevel(logging.DEBUG) 41 | file_handler.setFormatter(log_formatter) 42 | 43 | # create console handler and set level to info 44 | log_level = logging.DEBUG if vb == 2 else logging.INFO if vb == 1 else logging.WARNING 45 | console_handler = logging.StreamHandler() 46 | console_handler.setLevel(log_level) 47 | console_handler.setFormatter(log_formatter) 48 | 49 | # create logger and set level to debug 50 | logger = logging.getLogger() 51 | logger.handlers = [] 52 | logger.setLevel(logging.DEBUG) 53 | logger.propagate = False 54 | logger.addHandler(file_handler) 55 | logger.addHandler(console_handler) 56 | 57 | # reset logger elapsed time 58 | def reset_time(): 59 | log_formatter.start_time = time.time() 60 | logger.reset_time = reset_time 61 | 62 | return logger 63 | -------------------------------------------------------------------------------- /umwe/src/models.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright (c) 2017-present, Facebook, Inc. 2 | # Modified work Copyright (c) 2018, Xilun Chen 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from .utils import load_embeddings, normalize_embeddings 13 | 14 | 15 | class Discriminator(nn.Module): 16 | 17 | def __init__(self, params, lang): 18 | super(Discriminator, self).__init__() 19 | 20 | self.lang = lang 21 | self.emb_dim = params.emb_dim 22 | self.dis_layers = params.dis_layers 23 | self.dis_hid_dim = params.dis_hid_dim 24 | self.dis_dropout = params.dis_dropout 25 | self.dis_input_dropout = params.dis_input_dropout 26 | 27 | layers = [nn.Dropout(self.dis_input_dropout)] 28 | for i in range(self.dis_layers + 1): 29 | input_dim = self.emb_dim if i == 0 else self.dis_hid_dim 30 | output_dim = 1 if i == self.dis_layers else self.dis_hid_dim 31 | layers.append(nn.Linear(input_dim, output_dim)) 32 | if i < self.dis_layers: 33 | layers.append(nn.LeakyReLU(0.2)) 34 | layers.append(nn.Dropout(self.dis_dropout)) 35 | layers.append(nn.Sigmoid()) 36 | self.layers = nn.Sequential(*layers) 37 | 38 | def forward(self, x): 39 | assert x.dim() == 2 and x.size(1) == self.emb_dim 40 | return self.layers(x).view(-1) 41 | 42 | 43 | def build_model(params, with_dis): 44 | """ 45 | Build all components of the model. 
46 | """ 47 | # source embeddings 48 | params.vocabs, _src_embs, embs = {}, {}, {} 49 | for i, lang in enumerate(params.src_langs): 50 | dico, emb = load_embeddings(params, lang, params.src_embs[i]) 51 | params.vocabs[lang] = dico 52 | _src_embs[lang] = emb 53 | for i, lang in enumerate(params.src_langs): 54 | src_emb = nn.Embedding(len(params.vocabs[lang]), params.emb_dim, sparse=True) 55 | src_emb.weight.data.copy_(_src_embs[lang]) 56 | embs[lang] = src_emb 57 | 58 | # target embeddings 59 | if params.tgt_lang: 60 | tgt_dico, _tgt_emb = load_embeddings(params, params.tgt_lang, params.tgt_emb) 61 | params.vocabs[params.tgt_lang] = tgt_dico 62 | tgt_emb = nn.Embedding(len(tgt_dico), params.emb_dim, sparse=True) 63 | tgt_emb.weight.data.copy_(_tgt_emb) 64 | embs[params.tgt_lang] = tgt_emb 65 | else: 66 | tgt_emb = None 67 | 68 | # mappings 69 | mappings = {lang: nn.Linear(params.emb_dim, params.emb_dim, 70 | bias=False) for lang in params.src_langs} 71 | # set tgt mapping to fixed identity matrix 72 | tgt_map = nn.Linear(params.emb_dim, params.emb_dim, bias=False) 73 | tgt_map.weight.data.copy_(torch.diag(torch.ones(params.emb_dim))) 74 | for p in tgt_map.parameters(): 75 | p.requires_grad = False 76 | mappings[params.tgt_lang] = tgt_map 77 | if getattr(params, 'map_id_init', True): 78 | for mapping in mappings.values(): 79 | mapping.weight.data.copy_(torch.diag(torch.ones(params.emb_dim))) 80 | 81 | # discriminators 82 | discriminators = {lang: Discriminator(params, lang) 83 | for lang in params.all_langs} if with_dis else None 84 | 85 | for lang in params.all_langs: 86 | embs[lang] = embs[lang].to(params.device) 87 | mappings[lang] = mappings[lang].to(params.device) 88 | if with_dis: 89 | discriminators[lang] = discriminators[lang].to(params.device) 90 | 91 | # normalize embeddings 92 | params.lang_mean = {} 93 | for lang, emb in embs.items(): 94 | params.lang_mean[lang] = normalize_embeddings(emb.weight.detach(), params.normalize_embeddings) 95 | 96 | return embs, mappings, discriminators 97 | -------------------------------------------------------------------------------- /umwe/src/trainer.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright (c) 2017-present, Facebook, Inc. 2 | # Modified work Copyright (c) 2018, Xilun Chen 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import itertools 10 | import os 11 | from logging import getLogger 12 | import random 13 | 14 | import numpy as np 15 | import scipy 16 | import scipy.linalg 17 | import torch 18 | from torch.nn import functional as F 19 | from torch import optim 20 | 21 | from .utils import get_optimizer, load_embeddings, normalize_embeddings, export_embeddings 22 | from .utils import clip_parameters, apply_mapping 23 | from .dico_builder import build_dictionary 24 | from .evaluation.word_translation import DIC_EVAL_PATH, load_identical_char_dico, load_dictionary 25 | 26 | 27 | logger = getLogger() 28 | 29 | 30 | class Trainer(object): 31 | 32 | def __init__(self, embs, mappings, discriminators, params): 33 | """ 34 | Initialize trainer script. 
35 | """ 36 | self.embs = embs 37 | self.vocabs = params.vocabs 38 | self.mappings = mappings 39 | self.discriminators = discriminators 40 | self.params = params 41 | self.dicos = {} 42 | 43 | # optimizers 44 | if hasattr(params, 'map_optimizer'): 45 | optim_fn, optim_params = get_optimizer(params.map_optimizer) 46 | self.map_optimizer = optim_fn(itertools.chain(*[m.parameters() 47 | for l,m in mappings.items() if l!=params.tgt_lang]), **optim_params) 48 | if hasattr(params, 'dis_optimizer'): 49 | optim_fn, optim_params = get_optimizer(params.dis_optimizer) 50 | self.dis_optimizer = optim_fn(itertools.chain(*[d.parameters() 51 | for d in discriminators.values()]), **optim_params) 52 | else: 53 | assert discriminators is None 54 | if hasattr(params, 'mpsr_optimizer'): 55 | optim_fn, optim_params = get_optimizer(params.mpsr_optimizer) 56 | self.mpsr_optimizer = optim_fn(itertools.chain(*[m.parameters() 57 | for l,m in self.mappings.items() if l!=self.params.tgt_lang]), **optim_params) 58 | 59 | # best validation score 60 | self.best_valid_metric = -1e12 61 | 62 | self.decrease_lr = False 63 | self.decrease_mpsr_lr = False 64 | 65 | def get_dis_xy(self, lang1, lang2, volatile): 66 | """ 67 | Get discriminator input batch / output target. 68 | Encode from lang1, decode to lang2 and then discriminate 69 | """ 70 | # select random word IDs 71 | bs = self.params.batch_size 72 | mf = self.params.dis_most_frequent 73 | assert mf <= min(map(len, self.vocabs.values())) 74 | src_ids = torch.LongTensor(bs).random_(len(self.vocabs[lang1]) if mf == 0 else mf) 75 | tgt_ids = torch.LongTensor(bs).random_(len(self.vocabs[lang2]) if mf == 0 else mf) 76 | src_ids = src_ids.to(self.params.device) 77 | tgt_ids = tgt_ids.to(self.params.device) 78 | 79 | with torch.set_grad_enabled(not volatile): 80 | # get word embeddings 81 | src_emb = self.embs[lang1](src_ids).detach() 82 | tgt_emb = self.embs[lang2](tgt_ids).detach() 83 | # map 84 | src_emb = self.mappings[lang1](src_emb) 85 | # decode 86 | src_emb = F.linear(src_emb, self.mappings[lang2].weight.t()) 87 | 88 | # input / target 89 | x = torch.cat([src_emb, tgt_emb], 0) 90 | y = torch.FloatTensor(2 * bs).zero_() 91 | # 0 indicates real (lang2) samples 92 | y[:bs] = 1 - self.params.dis_smooth 93 | y[bs:] = self.params.dis_smooth 94 | y = y.to(self.params.device) 95 | 96 | return x, y 97 | 98 | def get_mpsr_xy(self, lang1, lang2, volatile): 99 | """ 100 | Get input batch / output target for MPSR. 101 | """ 102 | # select random word IDs 103 | bs = self.params.batch_size 104 | dico = self.dicos[(lang1, lang2)] 105 | indices = torch.from_numpy(np.random.randint(0, len(dico), bs)).to(self.params.device) 106 | dico = dico.index_select(0, indices) 107 | src_ids = dico[:, 0].to(self.params.device) 108 | tgt_ids = dico[:, 1].to(self.params.device) 109 | 110 | with torch.set_grad_enabled(not volatile): 111 | # get word embeddings 112 | src_emb = self.embs[lang1](src_ids).detach() 113 | tgt_emb = self.embs[lang2](tgt_ids).detach() 114 | # map 115 | src_emb = self.mappings[lang1](src_emb) 116 | # decode 117 | src_emb = F.linear(src_emb, self.mappings[lang2].weight.t()) 118 | 119 | return src_emb, tgt_emb 120 | 121 | def dis_step(self, stats): 122 | """ 123 | Train the discriminator. 
124 | """ 125 | for disc in self.discriminators.values(): 126 | disc.train() 127 | 128 | # loss 129 | loss = 0 130 | # for each target language 131 | for lang2 in self.params.all_langs: 132 | # random select a source language 133 | lang1 = random.choice(self.params.all_langs) 134 | 135 | x, y = self.get_dis_xy(lang1, lang2, volatile=True) 136 | preds = self.discriminators[lang2](x.detach()) 137 | loss += F.binary_cross_entropy(preds, y) 138 | 139 | # check NaN 140 | if (loss != loss).any(): 141 | logger.error("NaN detected (discriminator)") 142 | exit() 143 | stats['DIS_COSTS'].append(loss.item()) 144 | 145 | # optim 146 | self.dis_optimizer.zero_grad() 147 | loss.backward() 148 | self.dis_optimizer.step() 149 | for d in self.discriminators: 150 | clip_parameters(d, self.params.dis_clip_weights) 151 | 152 | def mapping_step(self, stats): 153 | """ 154 | Fooling discriminator training step. 155 | """ 156 | if self.params.dis_lambda == 0: 157 | return 0 158 | 159 | for disc in self.discriminators.values(): 160 | disc.eval() 161 | 162 | # loss 163 | loss = 0 164 | for lang1 in self.params.all_langs: 165 | lang2 = random.choice(self.params.all_langs) 166 | 167 | x, y = self.get_dis_xy(lang1, lang2, volatile=False) 168 | preds = self.discriminators[lang2](x) 169 | loss += F.binary_cross_entropy(preds, 1 - y) 170 | loss = self.params.dis_lambda * loss 171 | 172 | # check NaN 173 | if (loss != loss).any(): 174 | logger.error("NaN detected (fool discriminator)") 175 | exit() 176 | 177 | # optim 178 | self.map_optimizer.zero_grad() 179 | loss.backward() 180 | self.map_optimizer.step() 181 | self.orthogonalize() 182 | 183 | return len(self.params.all_langs) * self.params.batch_size 184 | 185 | def load_training_dico(self, dico_train): 186 | """ 187 | Load training dictionary. 188 | """ 189 | # load dicos for all lang pairs 190 | for i, lang1 in enumerate(self.params.all_langs): 191 | for j, lang2 in enumerate(self.params.all_langs): 192 | if lang1 == lang2: 193 | idx = torch.arange(self.params.dico_max_rank).long().view(self.params.dico_max_rank, 1) 194 | self.dicos[(lang1, lang2)] = torch.cat([idx, idx], dim=1).to(self.params.device) 195 | else: 196 | word2id1 = self.vocabs[lang1].word2id 197 | word2id2 = self.vocabs[lang2].word2id 198 | 199 | # identical character strings 200 | if dico_train == "identical_char": 201 | self.dicos[(lang1, lang2)] = load_identical_char_dico(word2id1, word2id2) 202 | # use one of the provided dictionary 203 | elif dico_train == "default": 204 | filename = '%s-%s.0-5000.txt' % (lang1, lang2) 205 | self.dicos[(lang1, lang2)] = load_dictionary( 206 | os.path.join(DIC_EVAL_PATH, filename), 207 | word2id1, word2id2 208 | ) 209 | # TODO dictionary provided by the user 210 | else: 211 | # self.dicos[(lang1, lang2)] = load_dictionary(dico_train, word2id1, word2id2) 212 | raise NotImplemented(dico_train) 213 | self.dicos[(lang1, lang2)] = self.dicos[(lang1, lang2)].to(self.params.device) 214 | 215 | def build_dictionary(self): 216 | """ 217 | Build dictionaries from aligned embeddings. 
218 | """ 219 | # build dicos for all lang pairs 220 | for i, lang1 in enumerate(self.params.all_langs): 221 | for j, lang2 in enumerate(self.params.all_langs): 222 | if i < j: 223 | src_emb = self.embs[lang1].weight 224 | src_emb = apply_mapping(self.mappings[lang1], src_emb).detach() 225 | tgt_emb = self.embs[lang2].weight 226 | tgt_emb = apply_mapping(self.mappings[lang2], tgt_emb).detach() 227 | src_emb = src_emb / src_emb.norm(2, 1, keepdim=True).expand_as(src_emb) 228 | tgt_emb = tgt_emb / tgt_emb.norm(2, 1, keepdim=True).expand_as(tgt_emb) 229 | self.dicos[(lang1, lang2)] = build_dictionary(src_emb, tgt_emb, self.params) 230 | elif i > j: 231 | self.dicos[(lang1, lang2)] = self.dicos[(lang2, lang1)][:, [1,0]] 232 | else: 233 | idx = torch.arange(self.params.dico_max_rank).long().view(self.params.dico_max_rank, 1) 234 | self.dicos[(lang1, lang2)] = torch.cat([idx, idx], dim=1).to(self.params.device) 235 | 236 | def mpsr_step(self, stats): 237 | # loss 238 | loss = 0 239 | for lang1 in self.params.all_langs: 240 | lang2 = random.choice(self.params.all_langs) 241 | 242 | x, y = self.get_mpsr_xy(lang1, lang2, volatile=False) 243 | loss += F.mse_loss(x, y) 244 | # check NaN 245 | if (loss != loss).any(): 246 | logger.error("NaN detected (fool discriminator)") 247 | exit() 248 | 249 | stats['MPSR_COSTS'].append(loss.item()) 250 | # optim 251 | self.mpsr_optimizer.zero_grad() 252 | loss.backward() 253 | self.mpsr_optimizer.step() 254 | 255 | if self.params.mpsr_orthogonalize: 256 | self.orthogonalize() 257 | 258 | return len(self.params.all_langs) * self.params.batch_size 259 | 260 | def orthogonalize(self): 261 | """ 262 | Orthogonalize the mapping. 263 | """ 264 | if self.params.map_beta > 0: 265 | for mapping in self.mappings.values(): 266 | W = mapping.weight.detach() 267 | beta = self.params.map_beta 268 | W.copy_((1 + beta) * W - beta * W.mm(W.transpose(0, 1).mm(W))) 269 | 270 | def update_lr(self, to_log, metric): 271 | """ 272 | Update learning rate when using SGD. 273 | """ 274 | if 'sgd' not in self.params.map_optimizer: 275 | return 276 | old_lr = self.map_optimizer.param_groups[0]['lr'] 277 | new_lr = max(self.params.min_lr, old_lr * self.params.lr_decay) 278 | if new_lr < old_lr: 279 | logger.info("Decreasing learning rate: %.8f -> %.8f" % (old_lr, new_lr)) 280 | self.map_optimizer.param_groups[0]['lr'] = new_lr 281 | 282 | if self.params.lr_shrink < 1 and to_log[metric] >= -1e7: 283 | if to_log[metric] < self.best_valid_metric: 284 | logger.info("Validation metric is smaller than the best: %.5f vs %.5f" 285 | % (to_log[metric], self.best_valid_metric)) 286 | # decrease the learning rate, only if this is the 287 | # second time the validation metric decreases 288 | if self.decrease_lr: 289 | old_lr = self.map_optimizer.param_groups[0]['lr'] 290 | self.map_optimizer.param_groups[0]['lr'] *= self.params.lr_shrink 291 | logger.info("Shrinking the learning rate: %.5f -> %.5f" 292 | % (old_lr, self.map_optimizer.param_groups[0]['lr'])) 293 | self.decrease_lr = True 294 | 295 | def update_mpsr_lr(self, to_log, metric): 296 | """ 297 | Update learning rate when using SGD. 
298 | """ 299 | if 'sgd' not in self.params.mpsr_optimizer: 300 | return 301 | old_lr = self.mpsr_optimizer.param_groups[0]['lr'] 302 | new_lr = max(self.params.min_lr, old_lr * self.params.lr_decay) 303 | if new_lr < old_lr: 304 | logger.info("Decreasing learning rate: %.8f -> %.8f" % (old_lr, new_lr)) 305 | self.mpsr_optimizer.param_groups[0]['lr'] = new_lr 306 | 307 | if self.params.lr_shrink < 1 and to_log[metric] >= -1e7: 308 | if to_log[metric] < self.best_valid_metric: 309 | logger.info("Validation metric is smaller than the best: %.5f vs %.5f" 310 | % (to_log[metric], self.best_valid_metric)) 311 | # decrease the learning rate, only if this is the 312 | # second time the validation metric decreases 313 | if self.decrease_mpsr_lr: 314 | old_lr = self.mpsr_optimizer.param_groups[0]['lr'] 315 | self.mpsr_optimizer.param_groups[0]['lr'] *= self.params.lr_shrink 316 | logger.info("Shrinking the learning rate: %.5f -> %.5f" 317 | % (old_lr, self.mpsr_optimizer.param_groups[0]['lr'])) 318 | self.decrease_mpsr_lr = True 319 | 320 | def save_best(self, to_log, metric): 321 | """ 322 | Save the best model for the given validation metric. 323 | """ 324 | # best mapping for the given validation criterion 325 | if to_log[metric] > self.best_valid_metric: 326 | # new best mapping 327 | self.best_valid_metric = to_log[metric] 328 | logger.info('* Best value for "%s": %.5f' % (metric, to_log[metric])) 329 | # save the mapping 330 | tgt_lang = self.params.tgt_lang 331 | for src_lang in self.params.src_langs: 332 | W = self.mappings[src_lang].weight.detach().cpu().numpy() 333 | path = os.path.join(self.params.exp_path, 334 | f'best_mapping_{src_lang}2{tgt_lang}.t7') 335 | logger.info(f'* Saving the {src_lang} to {tgt_lang} mapping to %s ...' % path) 336 | torch.save(W, path) 337 | 338 | def reload_best(self): 339 | """ 340 | Reload the best mapping. 341 | """ 342 | tgt_lang = self.params.tgt_lang 343 | for src_lang in self.params.src_langs: 344 | path = os.path.join(self.params.exp_path, 345 | f'best_mapping_{src_lang}2{tgt_lang}.t7') 346 | logger.info(f'* Reloading the best {src_lang} to {tgt_lang} model from {path} ...') 347 | # reload the model 348 | assert os.path.isfile(path) 349 | to_reload = torch.from_numpy(torch.load(path)) 350 | W = self.mappings[src_lang].weight.detach() 351 | assert to_reload.size() == W.size() 352 | W.copy_(to_reload.type_as(W)) 353 | 354 | def export(self): 355 | """ 356 | Export embeddings. 
357 | """ 358 | params = self.params 359 | # load all embeddings 360 | logger.info("Reloading embeddings for mapping ...") 361 | params.vocabs[params.tgt_lang], tgt_emb = load_embeddings(params, params.tgt_lang, 362 | params.tgt_emb, full_vocab=True) 363 | normalize_embeddings(tgt_emb, params.normalize_embeddings, 364 | mean=params.lang_mean[params.tgt_lang]) 365 | # export target embeddings 366 | export_embeddings(tgt_emb, self.params.tgt_lang, self.params) 367 | # export all source embeddings 368 | for i, src_lang in enumerate(self.params.src_langs): 369 | params.vocabs[src_lang], src_emb = load_embeddings(params, src_lang, 370 | params.src_embs[i], full_vocab=True) 371 | logger.info(f"Map {src_lang} embeddings to the target space ...") 372 | src_emb = apply_mapping(self.mappings[src_lang], src_emb) 373 | export_embeddings(src_emb, src_lang, self.params) 374 | -------------------------------------------------------------------------------- /umwe/supervised.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright (c) 2017-present, Facebook, Inc. 2 | # Modified work Copyright (c) 2018, Xilun Chen 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import os 10 | import json 11 | import argparse 12 | from collections import OrderedDict 13 | import numpy as np 14 | import time 15 | import torch 16 | 17 | from src.utils import bool_flag, initialize_exp 18 | from src.models import build_model 19 | from src.trainer import Trainer 20 | from src.evaluation import Evaluator 21 | 22 | 23 | VALIDATION_METRIC_SUP = 'precision_at_1-csls_knn_10' 24 | VALIDATION_METRIC_UNSUP = 'mean_cosine-csls_knn_10-S2T-10000' 25 | # default path to embeddings embeddings if not otherwise specified 26 | EMB_DIR = 'data/fasttext-vectors/' 27 | 28 | # main 29 | parser = argparse.ArgumentParser(description='Supervised training') 30 | parser.add_argument("--seed", type=int, default=-1, help="Initialization seed") 31 | parser.add_argument("--verbose", type=int, default=2, help="Verbose level (2:debug, 1:info, 0:warning)") 32 | parser.add_argument("--exp_path", type=str, default="", help="Where to store experiment logs and models") 33 | parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name") 34 | parser.add_argument("--exp_id", type=str, default="", help="Experiment ID") 35 | # parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU") 36 | parser.add_argument("--device", type=str, default="cuda", help="Run on GPU or CPU") 37 | parser.add_argument("--export", type=str, default="txt", help="Export embeddings after training (txt / pth)") 38 | 39 | # data 40 | parser.add_argument("--src_langs", type=str, nargs='+', default=['de', 'es', 'fr', 'it', 'pt'], help="Source languages") 41 | parser.add_argument("--tgt_lang", type=str, default='es', help="Target language") 42 | parser.add_argument("--emb_dim", type=int, default=300, help="Embedding dimension") 43 | parser.add_argument("--max_vocab", type=int, default=200000, help="Maximum vocabulary size (-1 to disable)") 44 | # training refinement 45 | parser.add_argument("--n_refinement", type=int, default=5, help="Number of refinement iterations (0 to disable the refinement procedure)") 46 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size") 47 | parser.add_argument("--map_beta", type=float, default=0.001, help="Beta for 
orthogonalization") 48 | # MPSR parameters 49 | parser.add_argument("--mpsr_optimizer", type=str, default="adam", help="Multilingual Pseudo-Supervised Refinement optimizer") 50 | parser.add_argument("--mpsr_orthogonalize", type=bool_flag, default=True, help="During MPSR, whether to perform orthogonalization") 51 | parser.add_argument("--mpsr_n_steps", type=int, default=30000, help="Number of optimization steps for MPSR") 52 | # dictionary creation parameters (for refinement) 53 | parser.add_argument("--dico_train", type=str, default="default", help="Path to training dictionary (default or identical_char)") 54 | parser.add_argument("--dico_eval", type=str, default="default", help="Path to evaluation dictionary") 55 | parser.add_argument("--dico_method", type=str, default='csls_knn_10', help="Method used for dictionary generation (nn/invsm_beta_30/csls_knn_10)") 56 | parser.add_argument("--dico_build", type=str, default='S2T&T2S', help="S2T,T2S,S2T|T2S,S2T&T2S") 57 | parser.add_argument("--dico_threshold", type=float, default=0, help="Threshold confidence for dictionary generation") 58 | parser.add_argument("--dico_max_rank", type=int, default=10000, help="Maximum dictionary words rank (0 to disable)") 59 | parser.add_argument("--dico_min_size", type=int, default=0, help="Minimum generated dictionary size (0 to disable)") 60 | parser.add_argument("--dico_max_size", type=int, default=0, help="Maximum generated dictionary size (0 to disable)") 61 | parser.add_argument("--semeval_ignore_oov", type=bool_flag, default=True, help="Whether to ignore OOV in SEMEVAL evaluation (the original authors used True)") 62 | # reload pre-trained embeddings 63 | parser.add_argument("--src_embs", type=str, nargs='+', default=[], help="Reload source embeddings (should be in the same order as in src_langs)") 64 | parser.add_argument("--tgt_emb", type=str, default='', help="Reload target embeddings") 65 | parser.add_argument("--normalize_embeddings", type=str, default="", help="Normalize embeddings before training") 66 | 67 | 68 | # parse parameters 69 | params = parser.parse_args() 70 | 71 | # post-processing options 72 | params.src_N = len(params.src_langs) 73 | params.all_langs = params.src_langs + [params.tgt_lang] 74 | # load default embeddings if no embeddings specified 75 | if len(params.src_embs) == 0: 76 | params.src_embs = [] 77 | for lang in params.src_langs: 78 | params.src_embs.append(os.path.join(EMB_DIR, f'wiki.{lang}.vec')) 79 | if len(params.tgt_emb) == 0: 80 | params.tgt_emb = os.path.join(EMB_DIR, f'wiki.{params.tgt_lang}.vec') 81 | 82 | # check parameters 83 | assert not params.device.lower().startswith('cuda') or torch.cuda.is_available() 84 | assert params.dico_train in ["identical_char", "default"] or os.path.isfile(params.dico_train) 85 | assert params.dico_build in ["S2T", "T2S", "S2T|T2S", "S2T&T2S"] 86 | assert params.dico_max_size == 0 or params.dico_max_size < params.dico_max_rank 87 | assert params.dico_max_size == 0 or params.dico_max_size > params.dico_min_size 88 | assert all([os.path.isfile(emb) for emb in params.src_embs]) 89 | assert os.path.isfile(params.tgt_emb) 90 | assert params.dico_eval == 'default' or os.path.isfile(params.dico_eval) 91 | assert params.export in ["", "txt", "pth"] 92 | 93 | # build logger / model / trainer / evaluator 94 | logger = initialize_exp(params) 95 | # N+1 embeddings, N mappings , N+1 discriminators 96 | embs, mappings, discriminators = build_model(params, False) 97 | trainer = Trainer(embs, mappings, discriminators, params) 98 | evaluator = 
Evaluator(trainer) 99 | 100 | # load a training dictionary. if a dictionary path is not provided, use a default 101 | # one ("default") or create one based on identical character strings ("identical_char") 102 | trainer.load_training_dico(params.dico_train) 103 | 104 | # define the validation metric 105 | VALIDATION_METRIC = VALIDATION_METRIC_UNSUP if params.dico_train == 'identical_char' else VALIDATION_METRIC_SUP 106 | logger.info("Validation metric: %s" % VALIDATION_METRIC) 107 | 108 | """ 109 | Learning loop for Procrustes Iterative Learning 110 | """ 111 | for n_epoch in range(params.n_refinement + 1): 112 | 113 | logger.info('Starting iteration %i...' % n_epoch) 114 | 115 | # build a dictionary from aligned embeddings (unless 116 | # it is the first iteration and we use the init one) 117 | if n_epoch > 0 or not hasattr(trainer, 'dicos'): 118 | trainer.build_dictionary() 119 | 120 | # optimize MPSR 121 | tic = time.time() 122 | n_words_mpsr = 0 123 | stats = {'MPSR_COSTS': []} 124 | for n_iter in range(params.mpsr_n_steps): 125 | # mpsr training step 126 | n_words_mpsr += trainer.mpsr_step(stats) 127 | # log stats 128 | if n_iter % 500 == 0: 129 | stats_str = [('MPSR_COSTS', 'MPSR loss')] 130 | stats_log = ['%s: %.4f' % (v, np.mean(stats[k])) 131 | for k, v in stats_str if len(stats[k]) > 0] 132 | stats_log.append('%i samples/s' % int(n_words_mpsr / (time.time() - tic))) 133 | logger.info(('%06i - ' % n_iter) + ' - '.join(stats_log)) 134 | # reset 135 | tic = time.time() 136 | n_words_mpsr = 0 137 | for k, _ in stats_str: 138 | del stats[k][:] 139 | 140 | # embeddings evaluation 141 | to_log = OrderedDict({'n_epoch': n_epoch}) 142 | evaluator.all_eval(to_log) 143 | 144 | # JSON log / save best model / end of epoch 145 | logger.info("__log__:%s" % json.dumps(to_log)) 146 | trainer.save_best(to_log, VALIDATION_METRIC) 147 | logger.info('End of iteration %i.\n\n' % n_epoch) 148 | 149 | # update the learning rate (effective only if using SGD for MPSR) 150 | trainer.update_mpsr_lr(to_log, VALIDATION_METRIC) 151 | 152 | # export embeddings 153 | if params.export: 154 | trainer.reload_best() 155 | trainer.export() 156 | -------------------------------------------------------------------------------- /umwe/unsupervised.py: -------------------------------------------------------------------------------- 1 | # Original work Copyright (c) 2017-present, Facebook, Inc. 2 | # Modified work Copyright (c) 2018, Xilun Chen 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | #
8 | 
9 | import os
10 | import time
11 | import json
12 | import argparse
13 | from collections import OrderedDict
14 | import numpy as np
15 | import torch
16 | 
17 | from src.utils import bool_flag, initialize_exp
18 | from src.models import build_model
19 | from src.trainer import Trainer
20 | from src.evaluation import Evaluator
21 | 
22 | 
23 | VALIDATION_METRIC = 'mean_cosine-csls_knn_10-S2T-10000'
24 | # default path to embeddings if not otherwise specified
25 | EMB_DIR = 'data/fasttext-vectors/'
26 | 
27 | 
28 | # main
29 | parser = argparse.ArgumentParser(description='Unsupervised training')
30 | parser.add_argument("--seed", type=int, default=-1, help="Initialization seed")
31 | parser.add_argument("--verbose", type=int, default=2, help="Verbose level (2:debug, 1:info, 0:warning)")
32 | parser.add_argument("--exp_path", type=str, default="", help="Where to store experiment logs and models")
33 | parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
34 | parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
35 | # parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
36 | parser.add_argument("--device", type=str, default="cuda", help="Run on GPU or CPU")
37 | parser.add_argument("--export", type=str, default="txt", help="Export embeddings after training (txt / pth)")
38 | # data
39 | parser.add_argument("--src_langs", type=str, nargs='+', default=['de', 'es', 'fr', 'it', 'pt'], help="Source languages")
40 | parser.add_argument("--tgt_lang", type=str, default='en', help="Target language")
41 | parser.add_argument("--emb_dim", type=int, default=300, help="Embedding dimension")
42 | parser.add_argument("--max_vocab", type=int, default=200000, help="Maximum vocabulary size (-1 to disable)")
43 | # mapping
44 | parser.add_argument("--map_id_init", type=bool_flag, default=True, help="Initialize the mapping as an identity matrix")
45 | parser.add_argument("--map_beta", type=float, default=0.001, help="Beta for orthogonalization")
46 | # discriminator
47 | parser.add_argument("--dis_layers", type=int, default=2, help="Discriminator layers")
48 | parser.add_argument("--dis_hid_dim", type=int, default=2048, help="Discriminator hidden layer dimensions")
49 | parser.add_argument("--dis_dropout", type=float, default=0., help="Discriminator dropout")
50 | parser.add_argument("--dis_input_dropout", type=float, default=0.1, help="Discriminator input dropout")
51 | parser.add_argument("--dis_steps", type=int, default=5, help="Discriminator steps")
52 | parser.add_argument("--dis_lambda", type=float, default=1, help="Discriminator loss feedback coefficient")
53 | parser.add_argument("--dis_most_frequent", type=int, default=75000, help="Select embeddings of the k most frequent words for discrimination (0 to disable)")
54 | parser.add_argument("--dis_smooth", type=float, default=0.1, help="Discriminator smooth predictions")
55 | parser.add_argument("--dis_clip_weights", type=float, default=0, help="Clip discriminator weights (0 to disable)")
56 | # adversarial training
57 | parser.add_argument("--adversarial", type=bool_flag, default=True, help="Use adversarial training")
58 | parser.add_argument("--n_epochs", type=int, default=5, help="Number of epochs")
59 | parser.add_argument("--epoch_size", type=int, default=1000000, help="Iterations per epoch")
60 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
61 | parser.add_argument("--map_optimizer", type=str, default="sgd,lr=0.1", 
help="Mapping optimizer") 62 | parser.add_argument("--dis_optimizer", type=str, default="sgd,lr=0.1", help="Discriminator optimizer") 63 | parser.add_argument("--lr_decay", type=float, default=0.98, help="Learning rate decay (SGD only)") 64 | parser.add_argument("--min_lr", type=float, default=1e-6, help="Minimum learning rate (SGD only)") 65 | parser.add_argument("--lr_shrink", type=float, default=0.5, help="Shrink the learning rate if the validation metric decreases (1 to disable)") 66 | # training refinement 67 | parser.add_argument("--n_refinement", type=int, default=5, help="Number of refinement iterations (0 to disable the refinement procedure)") 68 | # MPSR parameters 69 | parser.add_argument("--mpsr_optimizer", type=str, default="adam", help="Multilingual Pseudo-Supervised Refinement optimizer") 70 | parser.add_argument("--mpsr_orthogonalize", type=bool_flag, default=True, help="During MPSR, whether to perform orthogonalization") 71 | parser.add_argument("--mpsr_n_steps", type=int, default=30000, help="Number of optimization steps for MPSR") 72 | # dictionary creation parameters (for refinement) 73 | # default uses .5000-6500.txt; train uses .0-5000.txt; all uses .txt 74 | parser.add_argument("--dico_eval", type=str, default="default", help="Path to evaluation dictionary") 75 | parser.add_argument("--dico_method", type=str, default='csls_knn_10', help="Method used for dictionary generation (nn/invsm_beta_30/csls_knn_10)") 76 | parser.add_argument("--dico_build", type=str, default='S2T&T2S', help="S2T,T2S,S2T|T2S,S2T&T2S") 77 | parser.add_argument("--dico_threshold", type=float, default=0, help="Threshold confidence for dictionary generation") 78 | parser.add_argument("--dico_max_rank", type=int, default=15000, help="Maximum dictionary words rank (0 to disable)") 79 | parser.add_argument("--dico_min_size", type=int, default=0, help="Minimum generated dictionary size (0 to disable)") 80 | parser.add_argument("--dico_max_size", type=int, default=0, help="Maximum generated dictionary size (0 to disable)") 81 | parser.add_argument("--semeval_ignore_oov", type=bool_flag, default=True, help="Whether to ignore OOV in SEMEVAL evaluation (the original authors used True)") 82 | # reload pre-trained embeddings 83 | parser.add_argument("--src_embs", type=str, nargs='+', default=[], help="Reload source embeddings (should be in the same order as in src_langs)") 84 | parser.add_argument("--tgt_emb", type=str, default="", help="Reload target embeddings") 85 | parser.add_argument("--normalize_embeddings", type=str, default="", help="Normalize embeddings before training") 86 | 87 | 88 | # parse parameters 89 | params = parser.parse_args() 90 | 91 | # post-processing options 92 | params.src_N = len(params.src_langs) 93 | params.all_langs = params.src_langs + [params.tgt_lang] 94 | # load default embeddings if no embeddings specified 95 | if len(params.src_embs) == 0: 96 | params.src_embs = [] 97 | for lang in params.src_langs: 98 | params.src_embs.append(os.path.join(EMB_DIR, f'wiki.{lang}.vec')) 99 | if len(params.tgt_emb) == 0: 100 | params.tgt_emb = os.path.join(EMB_DIR, f'wiki.{params.tgt_lang}.vec') 101 | 102 | # check parameters 103 | assert not params.device.lower().startswith('cuda') or torch.cuda.is_available() 104 | assert 0 <= params.dis_dropout < 1 105 | assert 0 <= params.dis_input_dropout < 1 106 | assert 0 <= params.dis_smooth < 0.5 107 | assert params.dis_lambda > 0 and params.dis_steps > 0 108 | assert 0 < params.lr_shrink <= 1 109 | assert all([os.path.isfile(emb) for emb in 
params.src_embs]) 110 | assert os.path.isfile(params.tgt_emb) 111 | assert params.dico_eval == 'default' or os.path.isfile(params.dico_eval) 112 | assert params.export in ["", "txt", "pth"] 113 | 114 | # build model / trainer / evaluator 115 | logger = initialize_exp(params) 116 | # N+1 embeddings, N mappings , N+1 discriminators 117 | embs, mappings, discriminators = build_model(params, True) 118 | trainer = Trainer(embs, mappings, discriminators, params) 119 | evaluator = Evaluator(trainer) 120 | 121 | 122 | """ 123 | Learning loop for Multilingual Adversarial Training 124 | """ 125 | if params.adversarial: 126 | logger.info('----> MULTILINGUAL ADVERSARIAL TRAINING <----\n\n') 127 | 128 | # training loop 129 | for n_epoch in range(params.n_epochs): 130 | 131 | logger.info('Starting adversarial training epoch %i...' % n_epoch) 132 | tic = time.time() 133 | n_words_proc = 0 134 | stats = {'DIS_COSTS': []} 135 | 136 | for n_iter in range(0, params.epoch_size, params.batch_size): 137 | 138 | # discriminator training 139 | for _ in range(params.dis_steps): 140 | trainer.dis_step(stats) 141 | 142 | # mapping training (discriminator fooling) 143 | n_words_proc += trainer.mapping_step(stats) 144 | 145 | # log stats 146 | if n_iter % 500 == 0: 147 | stats_str = [('DIS_COSTS', 'Discriminator loss')] 148 | stats_log = ['%s: %.4f' % (v, np.mean(stats[k])) 149 | for k, v in stats_str if len(stats[k]) > 0] 150 | stats_log.append('%i samples/s' % int(n_words_proc / (time.time() - tic))) 151 | logger.info(('%06i - ' % n_iter) + ' - '.join(stats_log)) 152 | 153 | # reset 154 | tic = time.time() 155 | n_words_proc = 0 156 | for k, _ in stats_str: 157 | del stats[k][:] 158 | 159 | # embeddings / discriminator evaluation 160 | to_log = OrderedDict({'n_epoch': n_epoch}) 161 | evaluator.all_eval(to_log) 162 | evaluator.eval_all_dis(to_log) 163 | 164 | # JSON log / save best model / end of epoch 165 | logger.info("__log__:%s" % json.dumps(to_log)) 166 | trainer.save_best(to_log, VALIDATION_METRIC) 167 | logger.info('End of epoch %i.\n\n' % n_epoch) 168 | 169 | # update the learning rate (stop if too small) 170 | trainer.update_lr(to_log, VALIDATION_METRIC) 171 | if trainer.map_optimizer.param_groups[0]['lr'] < params.min_lr: 172 | logger.info('Learning rate < 1e-6. BREAK.') 173 | break 174 | 175 | 176 | """ 177 | Learning loop for Multilingual Pseudo-Supervised Refinement 178 | """ 179 | if params.n_refinement > 0: 180 | # Get the best mapping according to VALIDATION_METRIC 181 | logger.info('----> MULTILINGUAL PSEUDO-SUPERVISED REFINEMENT <----\n\n') 182 | trainer.reload_best() 183 | 184 | # training loop 185 | for n_epoch in range(params.n_refinement): 186 | 187 | logger.info('Starting refinement iteration %i...' 
% n_epoch) 188 | 189 | # build a dictionary from aligned embeddings 190 | trainer.build_dictionary() 191 | 192 | # optimize MPSR 193 | tic = time.time() 194 | n_words_mpsr = 0 195 | stats = {'MPSR_COSTS': []} 196 | for n_iter in range(params.mpsr_n_steps): 197 | # mpsr training step 198 | n_words_mpsr += trainer.mpsr_step(stats) 199 | # log stats 200 | if n_iter % 500 == 0: 201 | stats_str = [('MPSR_COSTS', 'MPSR loss')] 202 | stats_log = ['%s: %.4f' % (v, np.mean(stats[k])) 203 | for k, v in stats_str if len(stats[k]) > 0] 204 | stats_log.append('%i samples/s' % int(n_words_mpsr / (time.time() - tic))) 205 | logger.info(('%06i - ' % n_iter) + ' - '.join(stats_log)) 206 | # reset 207 | tic = time.time() 208 | n_words_mpsr = 0 209 | for k, _ in stats_str: 210 | del stats[k][:] 211 | 212 | # embeddings evaluation 213 | to_log = OrderedDict({'n_mpsr_epoch': n_epoch}) 214 | evaluator.all_eval(to_log) 215 | 216 | # JSON log / save best model / end of epoch 217 | logger.info("__log__:%s" % json.dumps(to_log)) 218 | trainer.save_best(to_log, VALIDATION_METRIC) 219 | logger.info('End of refinement iteration %i.\n\n' % n_epoch) 220 | 221 | # update the learning rate (effective only if using SGD for MPSR) 222 | trainer.update_mpsr_lr(to_log, VALIDATION_METRIC) 223 | 224 | 225 | # export embeddings 226 | if params.export: 227 | trainer.reload_best() 228 | trainer.export() 229 | --------------------------------------------------------------------------------
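A hedged end-to-end example for this project (illustrative file names; the
flags are the ones defined in unsupervised.py above, the language set matches
the es/fr/it-to-en mappings shipped in src/umwe_mappers, and 128-dimensional
.vec files come from the pipeline described in language_data/README.md):

    python unsupervised.py --src_langs es fr it --tgt_lang en --emb_dim 128 \
        --src_embs wiki.es.128.vec wiki.fr.128.vec wiki.it.128.vec \
        --tgt_emb wiki.en.128.vec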