├── README.md
├── img
│   ├── logograms1_en_postprocess.png
│   ├── model.svg
│   ├── vae2.svg
│   └── wordsGif.gif
├── language_data
│   ├── README.md
│   └── bash_scripts
│       ├── create128dimEmbs_fromWords.sh
│       ├── download_embeddings.sh
│       └── get_evaluation.sh
├── project_proposal.pdf
├── src
│   ├── MapperLayer.py
│   ├── Omniglot.py
│   ├── Omniglot_triplet.py
│   ├── Siamese.py
│   ├── configs
│   │   ├── bbvae.yaml
│   │   ├── bbvae_CompleteRun.yaml
│   │   ├── bbvae_setup2.yaml
│   │   ├── bbvae_setup3.yml
│   │   ├── bbvae_setup4.yml
│   │   └── betatc_vae.yaml
│   ├── decoder_with_discriminator.py
│   ├── experiment.py
│   ├── logogram_language_generator.py
│   ├── lossyMapper_train.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── base.cpython-36.pyc
│   │   │   ├── beta_vae.cpython-36.pyc
│   │   │   ├── betatc_vae.cpython-36.pyc
│   │   │   ├── cat_vae.cpython-36.pyc
│   │   │   ├── cvae.cpython-36.pyc
│   │   │   ├── dfcvae.cpython-36.pyc
│   │   │   ├── dip_vae.cpython-36.pyc
│   │   │   ├── fvae.cpython-36.pyc
│   │   │   ├── gamma_vae.cpython-36.pyc
│   │   │   ├── hvae.cpython-36.pyc
│   │   │   ├── info_vae.cpython-36.pyc
│   │   │   ├── iwae.cpython-36.pyc
│   │   │   ├── joint_vae.cpython-36.pyc
│   │   │   ├── logcosh_vae.cpython-36.pyc
│   │   │   ├── lvae.cpython-36.pyc
│   │   │   ├── miwae.cpython-36.pyc
│   │   │   ├── mssim_vae.cpython-36.pyc
│   │   │   ├── swae.cpython-36.pyc
│   │   │   ├── types_.cpython-36.pyc
│   │   │   ├── vampvae.cpython-36.pyc
│   │   │   ├── vanilla_vae.cpython-36.pyc
│   │   │   ├── vq_vae.cpython-36.pyc
│   │   │   └── wae_mmd.cpython-36.pyc
│   │   ├── base.py
│   │   └── beta_vae.py
│   ├── randomSamples_fromEmbeddings
│   │   ├── sample_fasttext_embs_en.pickle
│   │   ├── sample_fasttext_embs_en_128.pickle
│   │   ├── sample_fasttext_embs_es_128.pickle
│   │   ├── sample_fasttext_embs_fr.pickle
│   │   ├── sample_fasttext_embs_fr_128.pickle
│   │   ├── sample_fasttext_embs_it_128.pickle
│   │   └── sample_from_embeds.py
│   ├── requirements.txt
│   ├── run.py
│   ├── test.py
│   ├── train_mapper_and_discriminator.py
│   ├── train_mapper_with_siamese.py
│   ├── train_siamese.py
│   ├── umwe2vae.py
│   ├── umwe_mappers
│   │   ├── best_mapping_es2en.t7
│   │   ├── best_mapping_fr2en.t7
│   │   └── best_mapping_it2en.t7
│   ├── utils.py
│   └── vae_with_norm.py
└── umwe
    ├── LICENSE
    ├── README.md
    ├── data
    │   └── get_evaluation.sh
    ├── demo.ipynb
    ├── evaluate.py
    ├── src
    │   ├── __init__.py
    │   ├── dico_builder.py
    │   ├── dictionary.py
    │   ├── evaluation
    │   │   ├── __init__.py
    │   │   ├── evaluator.py
    │   │   ├── sent_translation.py
    │   │   ├── word_translation.py
    │   │   └── wordsim.py
    │   ├── logger.py
    │   ├── models.py
    │   ├── trainer.py
    │   └── utils.py
    ├── supervised.py
    └── unsupervised.py
/README.md:
--------------------------------------------------------------------------------
1 | # Logogram Language Generator
2 |
3 | This project was developed within the scope of inzva AI Projects #5 (August-November 2020). Check out other projects on the inzva GitHub.
4 |
5 |
6 | [Medium post](https://medium.com/@selim.seker00/logogram-language-generator-eb003293d51d)
7 |
8 | [Recording](https://www.youtube.com/watch?v=hTKlFtC7NMw) of the final project presentation (in Turkish)
9 |
10 |
11 |
12 | 
13 | 
14 |
15 |
16 |
17 |
18 | ## Readings
19 |
20 | #### Multilingual embedding:
21 |
22 | * https://arxiv.org/pdf/1808.08933.pdf
23 | * https://ai.facebook.com/tools/muse/
24 |
25 | * http://jalammar.github.io/illustrated-word2vec/
26 |
27 | * https://ruder.io/cross-lingual-embeddings/
28 |
29 |
30 |
31 | ##### About linear mapping between language embeddings:
32 |
33 | * https://arxiv.org/pdf/1309.4168.pdf
34 |
35 | * https://www.aclweb.org/anthology/N15-1104.pdf
36 |
37 |
38 |
39 | ##### About fastText (pre-trained monolingual embeddings in UMWE):
40 |
41 | * https://fasttext.cc/
42 |
43 | * https://www.aclweb.org/anthology/Q17-1010.pdf
44 |
45 |
46 |
47 |
48 |
49 | #### Generative Models:
50 |
51 | * https://arxiv.org/pdf/1312.6114.pdf
52 |
53 | * https://arxiv.org/pdf/1511.05644.pdf
54 |
55 | * https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
56 |
57 | * https://towardsdatascience.com/generating-images-with-autoencoders-77fd3a8dd368
58 |
59 | * https://theaisummer.com/Autoencoder/
60 |
61 | * https://towardsdatascience.com/generative-variational-autoencoder-for-high-resolution-image-synthesis-48dd98d4dcc2
62 |
63 |
64 | ## Datasets
65 |
66 | * https://github.com/brendenlake/omniglot
67 | * https://fasttext.cc/
68 |
69 |
70 | #### Drive Link for Docs
71 |
72 | * https://drive.google.com/drive/folders/1md5ZayU85yaKnx0d6LM63gR1NwQNd3pX?usp=sharing
73 |
74 |
75 |
76 |
77 | 
78 |
--------------------------------------------------------------------------------
/img/logograms1_en_postprocess.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/img/logograms1_en_postprocess.png
--------------------------------------------------------------------------------
/img/wordsGif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/img/wordsGif.gif
--------------------------------------------------------------------------------
/language_data/README.md:
--------------------------------------------------------------------------------
1 | here is the procedure we applied for umwe training:
2 | 1. we specified our vae's latent dim as 128, while the fasttext monolingual vectors are 300-dimensional
3 | 2. umwe takes its training data from wiki.vec files (whitespace-separated plain text -> "word 300dVector")
4 | 3. so we need to reduce the vectors in the wiki.vec files (a separate file for each language)
5 | 4. in fasttext, besides the .vec files there are .bin embedding binaries for each language. we first need to reduce the binaries' dimension (there is a script for that in the fasttext repo)
6 | 5. after reducing each .bin file, we again use a fasttext script to extract embedding vectors for a word list
7 | 6. to use that script we need to split the word list out of the .vec files (just write a python script for that; see the sketch below)
8 | 7. then create the new 128-dimensional .vec files
9 | 8. don't forget to change the hyperparameter for training umwe with 128-dim embeddings
10 |
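11 | a rough sketch of steps 4-7 in python (assuming the `fasttext` pip package; file names are just examples - the bash route via `fasttext print-word-vectors` is in bash_scripts/create128dimEmbs_fromWords.sh):
12 |
13 | ```python
14 | import fasttext
15 | import fasttext.util
16 |
17 | # step 4: reduce the 300-dim binary to 128 dims (PCA under the hood)
18 | ft = fasttext.load_model("cc.en.300.bin")
19 | fasttext.util.reduce_model(ft, 128)
20 | ft.save_model("cc.en.128.bin")
21 |
22 | # step 6: split the word list out of the original .vec file
23 | with open("wiki.en.vec") as f_in, open("wiki.en.words.txt", "w") as f_out:
24 |     next(f_in)  # first line is the "num_words dim" header
25 |     for line in f_in:
26 |         f_out.write(line.split(" ", 1)[0] + "\n")
27 |
28 | # step 7: write the new 128-dim .vec file, with the header umwe expects
29 | words = [w.strip() for w in open("wiki.en.words.txt")]
30 | with open("wiki.en.128.vec", "w") as f_out:
31 |     f_out.write(f"{len(words)} {ft.get_dimension()}\n")
32 |     for w in words:
33 |         vec = ft.get_word_vector(w)
34 |         f_out.write(w + " " + " ".join(f"{x:.4f}" for x in vec) + "\n")
35 | ```
36 |
37 | for step 8, the flag in the umwe training scripts should be something like `--emb_dim 128` (check the argparse defaults in umwe/unsupervised.py).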
--------------------------------------------------------------------------------
/language_data/bash_scripts/create128dimEmbs_fromWords.sh:
--------------------------------------------------------------------------------
1 | # don't forget to add "num_of_words emb_dim" as the first line of each output file for umwe (sed can insert a line at the head of a file; see the example at the bottom)
2 |
3 | #./fasttext print-word-vectors ../cc.en.128.bin < ../wiki_wordsOnly/wiki.en.128.vec > ../wiki_vectors_128/wiki.en.vec
4 | #./fasttext print-word-vectors ../cc.es.128.bin < ../wiki_wordsOnly/wiki.es.128.vec > ../wiki_vectors_128/wiki.es.vec
5 | #./fasttext print-word-vectors ../cc.fr.128.bin < ../wiki_wordsOnly/wiki.fr.128.vec > ../wiki_vectors_128/wiki.fr.vec
6 | #./fasttext print-word-vectors ../cc.it.128.bin < ../wiki_wordsOnly/wiki.it.128.vec > ../wiki_vectors_128/wiki.it.vec
7 | ../fastText/fasttext print-word-vectors ../../vae/model_checkpoints/cc_bins_128/cc.tr.128.bin < ../wiki_wordsOnly/wiki.tr.128.vec > ../wiki_vectors_128/wiki.tr.vec
8 |
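9 | # a sketch of that header fix (example file; 128 is the embedding dim, GNU sed assumed):
10 | #sed -i "1i $(wc -l < ../wiki_vectors_128/wiki.tr.vec) 128" ../wiki_vectors_128/wiki.tr.vec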
--------------------------------------------------------------------------------
/language_data/bash_scripts/download_embeddings.sh:
--------------------------------------------------------------------------------
1 |
2 | ### This bash script downloads the fastText monolingual embedding vectors for:
3 | ### [en, fr, es, it, tr]
4 |
5 | #curl -o wiki.en.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.en.vec
6 | #curl -o wiki.fr.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.fr.vec
7 | #curl -o wiki.es.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.es.vec
8 | #curl -o wiki.it.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.it.vec
9 | #curl -o wiki.tr.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.tr.vec
10 | #curl -o wiki.ru.vec https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.ru.vec
11 | # the crawl binaries are gzipped; download and then decompress them
12 | curl -o cc.fr.300.bin.gz https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.bin.gz
13 | curl -o cc.es.300.bin.gz https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.es.300.bin.gz
14 | curl -o cc.it.300.bin.gz https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.it.300.bin.gz
15 | curl -o cc.tr.300.bin.gz https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tr.300.bin.gz
16 | gunzip cc.fr.300.bin.gz cc.es.300.bin.gz cc.it.300.bin.gz cc.tr.300.bin.gz
17 |
--------------------------------------------------------------------------------
/language_data/bash_scripts/get_evaluation.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | en_analogy='https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/word2vec/source-archive.zip'
9 | dl_path='https://dl.fbaipublicfiles.com/arrival'
10 | semeval_2017='http://alt.qcri.org/semeval2017/task2/data/uploads'
11 | europarl='http://www.statmt.org/europarl/v7'
12 |
13 | #declare -A wordsim_lg
14 | #wordsim_lg=(["en"]="EN_MC-30.txt EN_MTurk-287.txt EN_RG-65.txt EN_VERB-143.txt EN_WS-353-REL.txt EN_YP-130.txt EN_MEN-TR-3k.txt EN_MTurk-771.txt EN_RW-STANFORD.txt EN_SIMLEX-999.txt EN_WS-353-ALL.txt EN_WS-353-SIM.txt" ["es"]="ES_MC-30.txt ES_RG-65.txt ES_WS-353.txt" ["de"]="DE_GUR350.txt DE_GUR65.txt DE_SIMLEX-999.txt DE_WS-353.txt DE_ZG222.txt" ["fr"]="FR_RG-65.txt" ["it"]="IT_SIMLEX-999.txt IT_WS-353.txt")
15 |
16 | declare -A wordsim_lg
17 | wordsim_lg=(["fr"]="FR_RG-65.txt" ["it"]="IT_SIMLEX-999.txt IT_WS-353.txt")
18 |
19 |
20 | mkdir monolingual crosslingual
21 |
22 | ## English word analogy task
23 | curl -Lo source-archive.zip $en_analogy
24 | mkdir -p monolingual/en/
25 | unzip -p source-archive.zip word2vec/trunk/questions-words.txt > monolingual/en/questions-words.txt
26 | rm source-archive.zip
27 |
28 |
29 | ## Downloading en-{} or {}-en dictionaries
30 | #lgs="af ar bg bn bs ca cs da de el en es et fa fi fr he hi hr hu id it ja ko lt lv mk ms nl no pl pt ro ru sk sl sq sv ta th tl tr uk vi zh"
31 | lgs="fr he hi hr hu id it ja ko lt lv mk ms nl no pl pt ro ru sk sl sq sv ta th tl tr uk vi zh"
32 |
33 | mkdir -p crosslingual/dictionaries/
34 | for lg in ${lgs}
35 | do
36 | for suffix in .txt .0-5000.txt .5000-6500.txt
37 | do
38 | fname=en-$lg$suffix
39 | curl -Lo crosslingual/dictionaries/$fname $dl_path/dictionaries/$fname
40 | fname=$lg-en$suffix
41 | curl -Lo crosslingual/dictionaries/$fname $dl_path/dictionaries/$fname
42 | done
43 | done
44 |
45 | ## Download European dictionaries
46 | for src_lg in de es fr it pt
47 | do
48 | for tgt_lg in de es fr it pt
49 | do
50 | if [ $src_lg != $tgt_lg ]
51 | then
52 | for suffix in .txt .0-5000.txt .5000-6500.txt
53 | do
54 | fname=$src_lg-$tgt_lg$suffix
55 | curl -Lo crosslingual/dictionaries/$fname $dl_path/dictionaries/$fname
56 | done
57 | fi
58 | done
59 | done
60 |
61 | ## Download Dinu et al. dictionaries
62 | for fname in OPUS_en_it_europarl_train_5K.txt OPUS_en_it_europarl_test.txt
63 | do
64 | echo $fname
65 | curl -Lo crosslingual/dictionaries/$fname $dl_path/dictionaries/$fname
66 | done
67 |
68 | ## Monolingual wordsim tasks
69 | for lang in "${!wordsim_lg[@]}"
70 | do
71 | echo $lang
72 | mkdir monolingual/$lang
73 | for wsim in ${wordsim_lg[$lang]}
74 | do
75 | echo $wsim
76 | curl -Lo monolingual/$lang/$wsim $dl_path/$lang/$wsim
77 | done
78 | done
79 |
80 | ## SemEval 2017 monolingual and cross-lingual wordsim tasks
81 | # 1) Task1: monolingual
82 | curl -Lo semeval2017-task2.zip $semeval_2017/semeval2017-task2.zip
83 | unzip semeval2017-task2.zip
84 |
85 | fdir='SemEval17-Task2/test/subtask1-monolingual'
86 | for lang in en es de fa it
87 | do
88 | mkdir -p monolingual/$lang
89 | uplang=`echo $lang | awk '{print toupper($0)}'`
90 | paste $fdir/data/$lang.test.data.txt $fdir/keys/$lang.test.gold.txt > monolingual/$lang/${uplang}_SEMEVAL17.txt
91 | done
92 |
93 | # 2) Task2: cross-lingual
94 | mkdir -p crosslingual/wordsim
95 | fdir='SemEval17-Task2/test/subtask2-crosslingual'
96 | for lg_pair in de-es de-fa de-it en-de en-es en-fa en-it es-fa es-it it-fa
97 | do
98 | echo $lg_pair
99 | paste $fdir/data/$lg_pair.test.data.txt $fdir/keys/$lg_pair.test.gold.txt > crosslingual/wordsim/$lg_pair-SEMEVAL17.txt
100 | done
101 | rm semeval2017-task2.zip
102 | rm -r SemEval17-Task2/
103 |
104 | ## Europarl for sentence retrieval
105 | # TODO: set to true to activate download of Europarl (slow)
106 | if false; then
107 | mkdir -p crosslingual/europarl
108 | # Tokenize EUROPARL with MOSES
109 | echo 'Cloning Moses github repository (for tokenization scripts)...'
110 | git clone https://github.com/moses-smt/mosesdecoder.git
111 | SCRIPTS=mosesdecoder/scripts
112 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
113 |
114 | for lg_pair in it-en # es-en etc
115 | do
116 | curl -Lo $lg_pair.tgz $europarl/$lg_pair.tgz
117 | tar -xvf $lg_pair.tgz
118 | rm $lg_pair.tgz
119 | lgs=(${lg_pair//-/ })
120 | for lg in ${lgs[0]} ${lgs[1]}
121 | do
122 | cat europarl-v7.$lg_pair.$lg | $TOKENIZER -threads 8 -l $lg -no-escape > euro.$lg.txt
123 | rm europarl-v7.$lg_pair.$lg
124 | done
125 |
126 | paste euro.${lgs[0]}.txt euro.${lgs[1]}.txt | shuf > euro.paste.txt
127 | rm euro.${lgs[0]}.txt euro.${lgs[1]}.txt
128 |
129 | cut -f1 euro.paste.txt > crosslingual/europarl/europarl-v7.$lg_pair.${lgs[0]}
130 | cut -f2 euro.paste.txt > crosslingual/europarl/europarl-v7.$lg_pair.${lgs[1]}
131 | rm euro.paste.txt
132 | done
133 |
134 | rm -rf mosesdecoder
135 | fi
136 |
--------------------------------------------------------------------------------
/project_proposal.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/project_proposal.pdf
--------------------------------------------------------------------------------
/src/MapperLayer.py:
--------------------------------------------------------------------------------
1 | from Omniglot import Omniglot
2 | import yaml
3 | import argparse
4 | import numpy as np
5 |
6 | from models import *
7 | from experiment import VAEXperiment
8 | import torch.backends.cudnn as cudnn
9 | from pytorch_lightning import Trainer
10 | from pytorch_lightning.logging import TestTubeLogger
11 | import torchvision.utils as vutils
12 | from torch.utils.data import DataLoader
13 | import torch.optim as optim
14 | from torchvision import transforms
15 |
16 | import torch
17 | from torch import nn
18 | import torch.optim as optim
19 | from torchvision.transforms.functional import adjust_contrast
20 |
21 |
22 | import pickle
23 | import random
24 |
25 |
26 |
27 | class umwe2vae(nn.Module):
28 | def __init__(self,vae_model, in_dim=300, out_dim=128):
29 | super(umwe2vae, self).__init__()
30 | self.vae_model = vae_model
31 | self.fc = nn.Linear(in_dim, out_dim)
32 |
33 | def forward(self, x):
34 | h = self.fc(x)
35 | # y = self.vae_model.decode(h)
36 | return h
37 | # here used to live post-processing
38 | #out = torch.zeros(y.shape)
39 | #for i in range(y.shape[0]):
40 | # out[i] = adjust_contrast(y[i], contrast_factor=2.5)
41 | #return out
42 |
43 | def loss(self, x, alpha=1, beta=1):
44 | middle = x[:,:,1:-1,1:-1]
45 | ne = x[:,:,0:-2,0:-2]
46 | n = x[:,:,0:-2,1:-1]
47 | nw = x[:,:,0:-2,2:]
48 | e = x[:,:,1:-1,0:-2]
49 | w = x[:,:,1:-1,2:]
50 | se = x[:,:,2:,0:-2]
51 | s = x[:,:,2:,1:-1]
52 | sw = x[:,:,2:,2:]
53 |
54 | return alpha * torch.mean(sum([torch.abs(middle-ne),torch.abs(middle-n),torch.abs(middle-nw),torch.abs(middle-e),torch.abs(middle-w),torch.abs(middle-se),torch.abs(middle-s),torch.abs(middle-sw)]) / 8.) - beta * torch.mean(torch.abs(x-0.5))
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 | class EmbeddingMapping(nn.Module):
64 | def __init__(self, device, embedding_vector_dim = 300, decoder_input_dim=128):
65 | super(EmbeddingMapping, self).__init__()
66 | self.device = device
67 | self.embedding_vector_dim = embedding_vector_dim
68 | self.decoder_input_dim = decoder_input_dim
69 | self.mapper_numlayer = 3
70 |
71 | self.linear_layers = []
72 | self.batch_norms = []
73 | for layer in range(0, self.mapper_numlayer-1):
74 | self.linear_layers.append(nn.Linear(embedding_vector_dim, embedding_vector_dim))
75 | self.batch_norms.append(nn.BatchNorm1d(embedding_vector_dim))
76 |
77 | # final layer
78 | self.linear_layers.append(nn.Linear(embedding_vector_dim, decoder_input_dim))
79 | self.batch_norms.append(nn.BatchNorm1d(decoder_input_dim))
80 |
81 |
82 | self.linear_layers = nn.ModuleList(self.linear_layers)
83 | self.batch_norms = nn.ModuleList(self.batch_norms)
84 |
85 |
86 | self.relu = nn.ReLU()
87 |
88 | def forward(self, embedding_vector):
89 | inp = embedding_vector
90 | for layer in range(self.mapper_numlayer):
91 | out = self.linear_layers[layer](inp)
92 | out = self.batch_norms[layer](out)
93 | out = self.relu(out)
94 | inp = out
95 | return out
96 |
97 |
98 |
99 | class MultilingualMapper(nn.Module):
100 | def __init__(self, device, embedding_vector_dim = 300, decoder_input_dim=128):
101 | super(MultilingualMapper, self).__init__()
102 | self.device = device
103 | self.embedding_vector_dim = embedding_vector_dim
104 | self.decoder_input_dim = decoder_input_dim
105 | self.mapper_numlayer = 3
106 |
107 | self.linear_layers = []
108 | self.batch_norms = []
109 | for layer in range(0, self.mapper_numlayer-1):
110 | self.linear_layers.append(nn.Linear(embedding_vector_dim, embedding_vector_dim))
111 | self.batch_norms.append(nn.BatchNorm1d(embedding_vector_dim))
112 |
113 | # final layer
114 | self.linear_layers.append(nn.Linear(embedding_vector_dim, decoder_input_dim))
115 | self.batch_norms.append(nn.BatchNorm1d(decoder_input_dim))
116 |
117 |
118 | self.linear_layers = nn.ModuleList(self.linear_layers)
119 | self.batch_norms = nn.ModuleList(self.batch_norms)
120 | self.relu = nn.ReLU()
121 |
122 | self.bce = nn.BCEWithLogitsLoss()
123 |
124 | def forward(self, embedding_vector):
125 | inp = embedding_vector
126 | for layer in range(self.mapper_numlayer):
127 | out = self.linear_layers[layer](inp)
128 | out = self.batch_norms[layer](out)
129 | out = self.relu(out)
130 | inp = out
131 | return out
132 |
133 | def triplet_loss(self, sameWords_diffLangs, diffWords_sameLangs):
134 | return self.bce(sameWords_diffLangs, torch.ones(sameWords_diffLangs.shape).to(self.device)) + self.bce(diffWords_sameLangs, torch.zeros(diffWords_sameLangs.shape).to(self.device))
135 |
136 |
137 |
--------------------------------------------------------------------------------
/src/Omniglot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from PIL import Image
4 | from torchvision import transforms
5 | import os
6 |
7 | class Omniglot(Dataset):
8 | def __init__(self, split="train", transform=None):
9 | self.split = split
10 | if(transform==None):
11 | self.transform = transforms.ToTensor()
12 | else:
13 | self.transform = transform
14 |
15 | if(split=="train"):
16 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh']
17 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55]
18 |
19 | elif(split=="test"):
20 | self.alphabet_names = ['Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG']
21 | self.character_nums = [20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26]
22 |
23 | else: # all splits
24 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh', 'Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG']
25 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55, 20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26]
26 |
27 | self.n_images = sum(self.character_nums) * 20
28 |
29 | def __len__(self):
30 | return self.n_images
31 |
32 | def __getitem__(self, idx):
33 |         j = idx % 20   # drawing index (each character has 20 example drawings)
34 |         i = idx // 20  # global character index across all alphabets
35 |         label = i
36 |
37 |         for k in range(len(self.alphabet_names)):  # locate the alphabet this character falls into
38 | if(i < self.character_nums[k]):
39 | alphabet = self.alphabet_names[k]+"/"
40 | if(self.split=="train"):
41 | folder = "images_background/"
42 | elif(self.split=="test"):
43 | folder = "images_evaluation/"
44 | else:
45 | folder = "images_background/" if k<30 else "images_evaluation/"
46 | break
47 | else:
48 | i -= self.character_nums[k]
49 |
50 | # /content/omniglot
51 | # char_name = "/content/omniglot/"+folder+alphabet+"character"+(str(i+1) if i>8 else "0"+str(i+1))+"/"
52 | char_name = "omniglot/"+folder+alphabet+"character"+(str(i+1) if i>8 else "0"+str(i+1))+"/"
53 | example_list = sorted(os.listdir(char_name))
54 | img = Image.open(char_name+example_list[j])
55 |
56 | return (self.transform(img.convert('L')), label)
57 |
--------------------------------------------------------------------------------
/src/Omniglot_triplet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from PIL import Image
4 | from torchvision import transforms
5 | import os
6 | import random
7 |
8 | class Omniglot(Dataset):
9 | def __init__(self, split="train", transform=None):
10 | self.split = split
11 | if(transform==None):
12 | self.transform = transforms.ToTensor()
13 | else:
14 | self.transform = transform
15 |
16 | if(split=="train"):
17 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh']
18 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55]
19 |
20 | elif(split=="test"):
21 | self.alphabet_names = ['Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG']
22 | self.character_nums = [20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26]
23 |
24 | else: # all splits
25 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh', 'Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG']
26 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55, 20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26]
27 |
28 | self.n_images = sum(self.character_nums) * 20
29 |
30 | def __len__(self):
31 | return self.n_images
32 |
33 | def __getitem__(self, idx):
34 | j = idx % 20
35 | i = idx // 20
36 | label = i
37 |
38 | for k in range(len(self.alphabet_names)):
39 | if(i < self.character_nums[k]):
40 | alphabet = self.alphabet_names[k]+"/"
41 | if(self.split=="train"):
42 | folder = "images_background/"
43 | elif(self.split=="test"):
44 | folder = "images_evaluation/"
45 | else:
46 | folder = "images_background/" if k<30 else "images_evaluation/"
47 | break
48 | else:
49 | i -= self.character_nums[k]
50 |
51 | char_name = "omniglot/"+folder+alphabet+"character"+(str(i+1) if i>8 else "0"+str(i+1))+"/"
52 | example_list = sorted(os.listdir(char_name))
53 | img = Image.open(char_name+example_list[j])
54 |
55 | return (self.transform(img.convert('L')), label)
56 |
57 | class Omniglot_triplet(Dataset):
58 | def __init__(self, split="train", transform=None):
59 | self.split = split
60 | if(transform==None):
61 | self.transform = transforms.ToTensor()
62 | else:
63 | self.transform = transform
64 |
65 | if(split=="train"):
66 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh']
67 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55]
68 |
69 | elif(split=="test"):
70 | self.alphabet_names = ['Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG']
71 | self.character_nums = [20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26]
72 |
73 | else: # all splits
74 | self.alphabet_names = ['Alphabet_of_the_Magi', 'Anglo-Saxon_Futhorc', 'Arcadian', 'Armenian', 'Asomtavruli_(Georgian)', 'Balinese', 'Bengali', 'Blackfoot_(Canadian_Aboriginal_Syllabics)', 'Braille', 'Burmese_(Myanmar)', 'Cyrillic', 'Early_Aramaic', 'Futurama', 'Grantha', 'Greek', 'Gujarati', 'Hebrew', 'Inuktitut_(Canadian_Aboriginal_Syllabics)', 'Japanese_(hiragana)', 'Japanese_(katakana)', 'Korean', 'Latin', 'Malay_(Jawi_-_Arabic)', 'Mkhedruli_(Georgian)', 'N_Ko', 'Ojibwe_(Canadian_Aboriginal_Syllabics)', 'Sanskrit', 'Syriac_(Estrangelo)', 'Tagalog', 'Tifinagh', 'Angelic', 'Atemayar_Qelisayer', 'Atlantean', 'Aurek-Besh', 'Avesta', 'Ge_ez', 'Glagolitic', 'Gurmukhi', 'Kannada', 'Keble', 'Malayalam', 'Manipuri', 'Mongolian', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Sylheti', 'Syriac_(Serto)', 'Tengwar', 'Tibetan', 'ULOG']
75 | self.character_nums = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55, 20, 26, 26, 26, 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26]
76 |
77 | self.n_images = sum(self.character_nums) * 20
78 |
79 | def __len__(self):
80 | return self.n_images
81 |
82 | def get_single_image(self, idx):
83 | j = idx % 20
84 | i = idx // 20
85 | label = i
86 |
87 | for k in range(len(self.alphabet_names)):
88 | if(i < self.character_nums[k]):
89 | alphabet = self.alphabet_names[k]+"/"
90 | if(self.split=="train"):
91 | folder = "images_background/"
92 | elif(self.split=="test"):
93 | folder = "images_evaluation/"
94 | else:
95 | folder = "images_background/" if k<30 else "images_evaluation/"
96 | break
97 | else:
98 | i -= self.character_nums[k]
99 |
100 | # char_name = "omniglot/"+folder+alphabet+"character"+(str(i+1) if i>8 else "0"+str(i+1))+"/"
101 | char_name = "./omniglot/"+folder+alphabet+"character"+(str(i+1) if i>8 else "0"+str(i+1))+"/"
102 | example_list = sorted(os.listdir(char_name))
103 | img = Image.open(char_name+example_list[j])
104 |
105 | return self.transform(img.convert('L'))
106 |
107 | def __getitem__(self, idx):
108 |         j = idx % 20  # drawing index of the anchor within its character
109 |         prange = list(range(idx-j, idx)) + list(range(idx+1, idx-j+20))  # other drawings of the same character (positives)
110 |         nrange = list(range(idx-j)) + list(range(idx-j+20,self.n_images))  # drawings of any other character (negatives)
111 | idx_p = random.choice(prange)
112 | idx_n = random.choice(nrange)
113 |
114 | return (self.get_single_image(idx),self.get_single_image(idx_p),self.get_single_image(idx_n))
115 |
116 |
--------------------------------------------------------------------------------
/src/Siamese.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 |
6 | class Siamese(nn.Module):
7 |
8 | def __init__(self):
9 | super(Siamese, self).__init__()
10 | self.encoder = nn.Sequential(
11 | nn.Conv2d(1, 64, 10), # 64@96*96
12 | nn.ReLU(inplace=True),
13 | nn.MaxPool2d(2), # 64@48*48
14 | nn.Conv2d(64, 128, 7),
15 | nn.ReLU(), # 128@42*42
16 | nn.MaxPool2d(2), # 128@21*21
17 | nn.Conv2d(128, 128, 4),
18 | nn.ReLU(), # 128@18*18
19 | nn.MaxPool2d(2), # 128@9*9
20 | nn.Conv2d(128, 256, 4),
21 | nn.ReLU(), # 256@6*6
22 | nn.Flatten(),
23 | nn.Linear(9216, 4096),
24 | nn.Sigmoid()
25 | )
26 |
27 | self.classifier = nn.Linear(4096, 1)
28 | self.bce_fn = nn.BCEWithLogitsLoss()
29 |
30 | def forward(self, x1, x2):
31 | z1 = self.encoder.forward(x1)
32 | z2 = self.encoder.forward(x2)
33 |
34 | z = torch.abs(z1 - z2)
35 | y = self.classifier.forward(z)
36 | return y
37 |
38 | def triplet_loss(self, yp, yn):
39 | return self.bce_fn(yp, torch.ones(yp.shape).cuda()) + self.bce_fn(yn, torch.zeros(yn.shape).cuda())
40 |
41 | def train_triplet(self, loader, epochs=100):
42 | self.train()
43 | print("new_siamese")
44 | optimizer = optim.Adam(self.parameters(), lr=0.00001)
45 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
46 | minLoss = 1000.5
47 | patience = 0
48 | maxPatience = 5
49 | for epoch in range(epochs):
50 | total_loss = 0
51 | for i,(xa,xp,xn) in enumerate(loader):
52 | yp = self.forward(xa.cuda(), xp.cuda())
53 | yn = self.forward(xa.cuda(), xn.cuda())
54 |
55 | loss = self.triplet_loss(yp, yn)
56 |                 optimizer.zero_grad()
57 | loss.backward()
58 | optimizer.step()
59 |
60 |                 total_loss += loss.detach()  # detach so the autograd graph is freed each step
61 | print("Epoch: %d Loss: %f" % (epoch+1, total_loss.item()/len(loader)))
62 | scheduler.step()
63 |
64 | if total_loss.item()/len(loader) < minLoss:
65 | minLoss = total_loss.item()/len(loader)
66 | patience = 0
67 | else:
68 | patience +=1
69 | print("patienceCounter: ", patience)
70 | if patience == maxPatience:
71 | print("early stopping: ", minLoss)
72 | break
73 | torch.save(self.state_dict(), "siamese.ckpt")
74 | #torch.save({'model_state_dict': self.state_dict()}, "siamese_EPOCH100.ckpt")
75 |
76 | def eval_triplet(self, loader):
77 | self.eval()
78 | correct,total = 0,0
79 | for i,(xa,xp,xn) in enumerate(loader):
80 | yp = self.forward(xa.cuda(), xp.cuda())
81 | yn = self.forward(xa.cuda(), xn.cuda())
82 | correct += torch.sum((yp>0).float()) + torch.sum((yn<0).float())
83 | #correct += torch.sum(torch.FloatTensor(yp>0)) + torch.sum(torch.FloatTensor(yn<0))
84 | #correct += torch.sum(yp>0) + torch.sum(yn<0)
85 | total += yp.shape[0] + yn.shape[0]
86 |
87 | if i==100:
88 | break
89 |
90 | print("Accuracy: %f" % (correct/total).item())
91 |
92 |
--------------------------------------------------------------------------------
/src/configs/bbvae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 | name: 'BetaVAE'
3 | in_channels: 1
4 | latent_dim: 300
5 | loss_type: 'B'
6 | gamma: 10.0
7 | max_capacity: 25
8 | Capacity_max_iter: 10000
9 |
10 | exp_params:
11 | dataset: omniglot
12 | data_path: "../../shared/Data/"
13 | img_size: 64
14 | batch_size: 144 # Better to have a square number
15 | LR: 0.00005
16 | weight_decay: 0.0
17 | scheduler_gamma: 0.95
18 |
19 | trainer_params:
20 | gpus: 1
21 | max_nb_epochs: 50
22 | max_epochs: 50
23 |
24 | logging_params:
25 | save_dir: "logs/"
26 | name: "BetaVAE_B"
27 | manual_seed: 1265
28 |
--------------------------------------------------------------------------------
/src/configs/bbvae_CompleteRun.yaml:
--------------------------------------------------------------------------------
1 | # latent_dim 300 to 128
2 | # batch_size 144 to 64
3 |
4 | model_params:
5 | name: 'BetaVAE'
6 | in_channels: 1
7 | latent_dim: 128
8 | loss_type: 'B'
9 | gamma: 10.0
10 | max_capacity: 25
11 | Capacity_max_iter: 10000
12 |
13 | exp_params:
14 | dataset: omniglot
15 | data_path: "../../shared/Data/"
16 | img_size: 64
17 | batch_size: 64 # Better to have a square number
18 | LR: 0.00005
19 | weight_decay: 0.0
20 | scheduler_gamma: 0.95
21 |
22 | trainer_params:
23 | gpus: 1
24 | max_nb_epochs: 70
25 | max_epochs: 70
26 |
27 | logging_params:
28 | save_dir: "logs/"
29 | name: "BetaVAE_B_completeRun"
30 | manual_seed: 1265
--------------------------------------------------------------------------------
/src/configs/bbvae_setup2.yaml:
--------------------------------------------------------------------------------
1 | # latent_dim 300 to 128
2 | # batch_size 144 to 64
3 |
4 | model_params:
5 | name: 'BetaVAE'
6 | in_channels: 1
7 | latent_dim: 128
8 | loss_type: 'B'
9 | gamma: 10.0
10 | max_capacity: 25
11 | Capacity_max_iter: 10000
12 |
13 | exp_params:
14 | dataset: omniglot
15 | data_path: "../../shared/Data/"
16 | img_size: 64
17 | batch_size: 64 # Better to have a square number
18 | LR: 0.00005
19 | weight_decay: 0.0
20 | scheduler_gamma: 0.95
21 |
22 | trainer_params:
23 | gpus: 1
24 | max_nb_epochs: 50
25 | max_epochs: 50
26 |
27 | logging_params:
28 | save_dir: "logs/"
29 | name: "BetaVAE_B_setup2_run2"
30 | manual_seed: 1265
--------------------------------------------------------------------------------
/src/configs/bbvae_setup3.yml:
--------------------------------------------------------------------------------
1 | # batch_size 144 to 64
2 |
3 | model_params:
4 | name: "BetaVAE"
5 | in_channels: 1
6 | latent_dim: 300
7 | loss_type: "B"
8 | gamma: 10.0
9 | max_capacity: 25
10 | Capacity_max_iter: 10000
11 |
12 | exp_params:
13 | dataset: omniglot
14 | data_path: "/content/"
15 | img_size: 64
16 | batch_size: 64 # Better to have a square number
17 | LR: 0.00005
18 | weight_decay: 0.0
19 | scheduler_gamma: 0.95
20 |
21 | trainer_params:
22 | gpus: 1
23 | max_nb_epochs: 50
24 | max_epochs: 50
25 |
26 | logging_params:
27 | save_dir: "logs/"
28 | name: "BetaVAE_setup3_run2"
29 | manual_seed: 1265
30 |
--------------------------------------------------------------------------------
/src/configs/bbvae_setup4.yml:
--------------------------------------------------------------------------------
1 | # batch_size 144 to 64
2 | # latent_dim 128 to 64
3 |
4 | model_params:
5 | name: "BetaVAE"
6 | in_channels: 1
7 | latent_dim: 64
8 | loss_type: "B"
9 | gamma: 10.0
10 | max_capacity: 25
11 | Capacity_max_iter: 10000
12 |
13 | exp_params:
14 | dataset: omniglot
15 | data_path: "../../shared/Data/"
16 | img_size: 64
17 | batch_size: 64 # Better to have a square number
18 | LR: 0.00005
19 | weight_decay: 0.0
20 | scheduler_gamma: 0.95
21 |
22 | trainer_params:
23 | gpus: 1
24 | max_nb_epochs: 50
25 | max_epochs: 50
26 |
27 | logging_params:
28 | save_dir: "logs/"
29 | name: "BetaVAE_setup4"
30 | manual_seed: 1265
31 |
--------------------------------------------------------------------------------
/src/configs/betatc_vae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 | name: 'BetaTCVAE'
3 | in_channels: 3
4 | latent_dim: 10
5 | anneal_steps: 10000
6 | alpha: 1.
7 | beta: 6.
8 | gamma: 1.
9 |
10 | exp_params:
11 | dataset: celeba
12 | data_path: "../../shared/momo/Data/"
13 | img_size: 64
14 | batch_size: 144 # Better to have a square number
15 | LR: 0.001
16 | weight_decay: 0.0
17 | # scheduler_gamma: 0.99
18 |
19 | trainer_params:
20 | gpus: 1
21 | max_nb_epochs: 50
22 | max_epochs: 50
23 |
24 | logging_params:
25 | save_dir: "logs/"
26 | name: "BetaTCVAE"
27 | manual_seed: 1265
28 |
--------------------------------------------------------------------------------
/src/decoder_with_discriminator.py:
--------------------------------------------------------------------------------
1 | from Omniglot import Omniglot
2 | import yaml
3 | import argparse
4 | import numpy as np
5 |
6 | from models import *
7 | from experiment import VAEXperiment
8 | import torch.backends.cudnn as cudnn
9 | from pytorch_lightning import Trainer
10 | from pytorch_lightning.logging import TestTubeLogger
11 | import torchvision.utils as vutils
12 | from torch.utils.data import DataLoader
13 | import torch.optim as optim
14 | from torchvision import transforms
15 |
16 | import pickle
17 | import random
18 |
19 | import fasttext
20 | import fasttext.util
21 | import torch
22 | from torch import nn
23 |
24 | class EmbeddingMapping(nn.Module):
25 | def __init__(self, device, embedding_vector_dim = 300, decoder_input_dim=128):
26 | super(EmbeddingMapping, self).__init__()
27 | self.device = device
28 | self.embedding_vector_dim = embedding_vector_dim
29 | self.decoder_input_dim = decoder_input_dim
30 | self.linear_mapping1 = nn.Linear(embedding_vector_dim, embedding_vector_dim)
31 | self.linear_mapping2 = nn.Linear(embedding_vector_dim, embedding_vector_dim)
32 | self.linear_mapping3 = nn.Linear(embedding_vector_dim, decoder_input_dim)
33 | self.relu = nn.ReLU()
34 | self.nrm1 = nn.BatchNorm1d(embedding_vector_dim)
35 | self.nrm2 = nn.BatchNorm1d(decoder_input_dim)
36 |
37 | def forward(self, embedding_vector):
38 | l1 = self.linear_mapping1(embedding_vector)
39 | l1 = self.nrm1(l1)
40 | l1 = self.relu(l1)
41 | l2 = self.linear_mapping2(l1)
42 |         l2 = self.nrm1(l2)  # note: reuses the layer-1 batch norm; the MapperLayer.py version keeps a separate norm per layer
43 | l2 = self.relu(l2)
44 | l3 = self.linear_mapping3(l2)
45 | l3 = self.nrm2(l3)
46 | l3 = self.relu(l3)
47 |
48 | return l3
49 |
50 |
51 |
52 |
53 |
54 | class Discriminator(nn.Module):
55 | def __init__(self, ngpu):
56 | super(Discriminator, self).__init__()
57 | self.ngpu = ngpu
58 | self.ndf = 64
59 | self.nc = 1
60 | self.main = nn.Sequential(
61 | # input is (nc) x 64 x 64
62 | nn.Conv2d(self.nc, self.ndf, 4, 2, 1, bias=False),
63 | nn.LeakyReLU(0.2, inplace=True),
64 | # state size. (self.ndf) x 32 x 32
65 | nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),
66 | nn.BatchNorm2d(self.ndf * 2),
67 | nn.LeakyReLU(0.2, inplace=True),
68 | # state size. (self.ndf*2) x 16 x 16
69 | nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),
70 | nn.BatchNorm2d(self.ndf * 4),
71 | nn.LeakyReLU(0.2, inplace=True),
72 | # state size. (self.ndf*4) x 8 x 8
73 | nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),
74 | nn.BatchNorm2d(self.ndf * 8),
75 | nn.LeakyReLU(0.2, inplace=True),
76 | # state size. (self.ndf*8) x 4 x 4
77 | nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
78 | nn.Sigmoid()
79 | )
80 |
81 | def forward(self, input):
82 | return self.main(input)
83 |
84 |
85 | def weights_init(m):
86 | classname = m.__class__.__name__
87 | if classname.find('Conv') != -1:
88 | nn.init.normal_(m.weight.data, 0.0, 0.02)
89 | elif classname.find('BatchNorm') != -1:
90 | nn.init.normal_(m.weight.data, 1.0, 0.02)
91 | nn.init.constant_(m.bias.data, 0)
92 |
93 |
94 | def get_dataLoader(batch_size):
95 | transform = data_transforms()
96 | dataset = Omniglot(split="train", transform=transform)
97 | num_train_imgs = len(dataset)
98 | return DataLoader(dataset,
99 | batch_size= batch_size,
100 | shuffle = True,
101 | drop_last=True)
102 |
103 |
104 | def data_transforms():
105 | SetRange = transforms.Lambda(lambda X: 2 * X - 1.)
106 | SetScale = transforms.Lambda(lambda X: X/X.sum(0).expand_as(X))
107 | transform = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()])
108 | return transform
109 |
110 |
111 | def randomSample_from_embeddings(batch_size):
112 | with open("/content/drive/My Drive/vae/logogram-language-generator-master/sample_fasttext_embs.pickle", "rb") as f:
113 | random_sample = pickle.load(f)
114 |
115 | fake_batch = random.sample(list(random_sample.values()), batch_size)
116 | fake_batch = torch.stack(fake_batch, dim=0)
117 | return fake_batch
118 |
119 |
120 |
121 |
122 |
123 | def trainer(vae_model, mapper, netD, batch_size, device):
124 | # Training Loop
125 | # Lists to keep track of progress
126 | img_list = []
127 | G_losses = []
128 | D_losses = []
129 | iters = 0
130 | num_epochs = 5
131 | dataloader = get_dataLoader(batch_size=batch_size)
132 |
133 | # Initialize BCELoss function
134 | criterion = nn.BCELoss()
135 | lr = 0.00001
136 | beta1 = 0.5
137 | # Setup Adam optimizers for both G and D
138 | optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
139 | optimizerG = optim.Adam(mapper.parameters(), lr=lr, betas=(beta1, 0.999))
140 |
141 |
142 |
143 | # Create batch of latent vectors that we will use to visualize
144 | # the progression of the generator
145 | fixed_noise = randomSample_from_embeddings(batch_size).to(device)
146 |
147 | criterion = nn.BCELoss()
148 |
149 | # Establish convention for real and fake labels during training
150 | real_label = 1.
151 | fake_label = 0.
152 |
153 |
154 | print("Starting Training Loop...")
155 | # For each epoch
156 | for epoch in range(num_epochs):
157 | # For each batch in the dataloader
158 | for i, data in enumerate(dataloader, 0):
159 |
160 | ############################
161 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
162 | ###########################
163 | ## Train with all-real batch
164 | netD.zero_grad()
165 | # Format batch
166 | real_cpu = data[0].to(device)
167 | b_size = real_cpu.size(0)
168 | label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
169 | # Forward pass real batch through D
170 | output = netD(real_cpu).view(-1)
171 | # Calculate loss on all-real batch
172 | errD_real = criterion(output, label)
173 | # Calculate gradients for D in backward pass
174 | errD_real.backward()
175 | D_x = output.mean().item()
176 |
177 | ## Train with all-fake batch
178 | # Generate batch of latent vectors
179 | embeddings = randomSample_from_embeddings(batch_size).to(device)
180 | # Generate fake image batch with G
181 | fake = mapper(embeddings).to(device)
182 | fake = vae_model.decode(fake)
183 |
184 | label.fill_(fake_label)
185 | # Classify all fake batch with D
186 | output = netD(fake.detach()).view(-1)
187 | # Calculate D's loss on the all-fake batch
188 | errD_fake = criterion(output, label)
189 | # Calculate the gradients for this batch
190 | errD_fake.backward()
191 | D_G_z1 = output.mean().item()
192 | # Add the gradients from the all-real and all-fake batches
193 | errD = errD_real + errD_fake
194 | # Update D
195 | optimizerD.step()
196 |
197 | ############################
198 | # (2) Update G network: maximize log(D(G(z)))
199 | ###########################
200 | mapper.zero_grad()
201 | label.fill_(real_label) # fake labels are real for generator cost
202 | # Since we just updated D, perform another forward pass of all-fake batch through D
203 | output = netD(fake).view(-1)
204 | # Calculate G's loss based on this output
205 | errG = criterion(output, label)
206 | # Calculate gradients for G
207 | errG.backward()
208 | D_G_z2 = output.mean().item()
209 | # Update G
210 | optimizerG.step()
211 |
212 | # Output training stats
213 | if i % 50 == 0:
214 | print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
215 | % (epoch, num_epochs, i, len(dataloader),
216 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
217 | randomsamples = []
218 | # for i in range(64):
219 |
220 | with torch.no_grad():
221 | mapped = mapper(torch.randn(64,300).to(device))
222 | randomsamples.append(mapped.to(device))
223 |
224 | embed_samples = []
225 | # for i in range(64):
226 | one_embed = randomSample_from_embeddings(64).to(device)
227 | with torch.no_grad():
228 | mapped_embed = mapper(one_embed).to(device)
229 | embed_samples.append(mapped_embed)
230 |
231 |
232 |
233 | randomsamples = torch.stack(randomsamples, dim=0)
234 | recons_randoms = vae_model.decode(randomsamples)
235 |
236 | embedsamples = torch.stack(embed_samples, dim=0)
237 | recons_embeds = vae_model.decode(embedsamples)
238 |
239 |
240 |
241 | vutils.save_image(recons_randoms.data,
242 | f"./vae_with_disc/test_random{i}.png",
243 | normalize=True,
244 | nrow=12)
245 | vutils.save_image(recons_embeds.data,
246 | f"./vae_with_disc/test_embed{i}.png",
247 | normalize=True,
248 | nrow=12)
249 |
250 | # Save Losses for plotting later
251 | G_losses.append(errG.item())
252 | D_losses.append(errD.item())
253 |
254 | # Check how the generator is doing by saving G's output on fixed_noise
255 | if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
256 | with torch.no_grad():
257 |                     fake = vae_model.decode(mapper(fixed_noise)).detach().to(device)  # decode to images before logging the grid
258 | img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
259 | # vutils.save_image(img_list,
260 | # f"./vae_with_disc/samples_{i}.png",
261 | # normalize=True,
262 | # nrow=12)
263 |
264 | iters += 1
265 | return vae_model, mapper, netD
266 |
267 | def load_vae_model(vae_checkpointPath, config):
268 | model = vae_models[config['model_params']['name']](**config['model_params'])
269 | experiment = VAEXperiment(model,
270 | config['exp_params'])
271 |
272 |
273 | checkpoint = torch.load(vae_checkpointPath, map_location=lambda storage, loc: storage)
274 | new_ckpoint = {}
275 | for k in checkpoint["state_dict"].keys():
276 | newKey = k.split("model.")[1]
277 | new_ckpoint[newKey] = checkpoint["state_dict"][k]
278 |
279 | model.load_state_dict(new_ckpoint)
280 | model.eval()
281 | return model
282 |
283 |
284 | def main():
285 | print("on main")
286 |
287 |
288 | with open("./configs/bbvae_setup2.yaml", 'r') as file:
289 | try:
290 | config = yaml.safe_load(file)
291 | except yaml.YAMLError as exc:
292 | print(exc)
293 |
294 | vae_checkpointPath = "logs/BetaVAE_B_setup2_run2/final_model_checkpoint.ckpt"
295 | batch_size = 64
296 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
297 | print(device)
298 |
299 | vae_model = load_vae_model(vae_checkpointPath, config).to(device)
300 |
301 | # fixed_noise = randomSample_from_embeddings(batch_size).to(device)
302 | # fixed_noise = torch.randn()
303 |
304 |
305 |
306 | mapper = EmbeddingMapping(device, 300, 128).to(device)
307 | mapper.apply(weights_init)
308 |
309 | netD = Discriminator(device).to(device)
310 | netD.apply(weights_init)
311 | vae_model, mapper, netD = trainer(vae_model=vae_model, mapper=mapper, netD=netD, batch_size=batch_size, device=device)
312 |
313 | with open("/content/drive/My Drive/vae/logogram-language-generator-master/fasttext_hello_world.pickle", "rb") as f:
314 | helloworld = pickle.load(f)
315 |
316 | # helloworld = random.sample(list(helloworld.values()), batch_size)
317 | hello = torch.from_numpy(helloworld["hello"]).to(device)
318 | world = torch.from_numpy(helloworld["world"]).to(device)
319 |
320 |
321 | helloworld = torch.stack([hello, world], dim=0).to(device)
322 |
323 | embed_samples = []
324 | with torch.no_grad():
325 | mapped_embed = mapper(helloworld).to(device)
326 | embed_samples.append(mapped_embed)
327 | embedsamples = torch.stack(embed_samples, dim=0)
328 | recons_embeds = vae_model.decode(embedsamples)
329 | vutils.save_image(recons_embeds.data,
330 | "./vae_with_disc/helloWorld.png",
331 | normalize=True,
332 | nrow=12)
333 |
334 |
335 |
336 |
337 | print("ALL DONE!")
338 |
339 |
340 |
341 |
342 |
343 | if __name__ == "__main__":
344 | main()
345 |
--------------------------------------------------------------------------------
/src/experiment.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import optim
4 | from models import BaseVAE
5 | from models.types_ import *
6 | from utils import data_loader
7 | import pytorch_lightning as pl
8 | from torchvision import transforms
9 | import torchvision.utils as vutils
10 | from torchvision.datasets import CelebA
11 | from torch.utils.data import DataLoader
12 | from Omniglot import Omniglot
13 |
14 |
15 | class VAEXperiment(pl.LightningModule):
16 |
17 | def __init__(self,
18 | vae_model: BaseVAE,
19 | params: dict) -> None:
20 | super(VAEXperiment, self).__init__()
21 |
22 | self.model = vae_model
23 | self.params = params
24 | self.curr_device = None
25 | self.hold_graph = False
26 | try:
27 | self.hold_graph = self.params['retain_first_backpass']
28 | except:
29 | pass
30 |
31 | def forward(self, input: Tensor, **kwargs) -> Tensor:
32 | return self.model(input, **kwargs)
33 |
34 | def training_step(self, batch, batch_idx, optimizer_idx = 0):
35 | real_img, labels = batch
36 | self.curr_device = real_img.device
37 |
38 | results = self.forward(real_img, labels = labels)
39 | train_loss = self.model.loss_function(*results,
40 | M_N = self.params['batch_size']/ self.num_train_imgs,
41 | optimizer_idx=optimizer_idx,
42 | batch_idx = batch_idx)
43 |
44 | self.logger.experiment.log({key: val.item() for key, val in train_loss.items()})
45 |
46 | return train_loss
47 |
48 | def validation_step(self, batch, batch_idx, optimizer_idx = 0):
49 | real_img, labels = batch
50 | self.curr_device = real_img.device
51 |
52 | results = self.forward(real_img, labels = labels)
53 | val_loss = self.model.loss_function(*results,
54 | M_N = self.params['batch_size']/ self.num_val_imgs,
55 | optimizer_idx = optimizer_idx,
56 | batch_idx = batch_idx)
57 |
58 | return val_loss
59 |
60 | def validation_end(self, outputs):
61 | avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
62 | tensorboard_logs = {'avg_val_loss': avg_loss}
63 | self.sample_images()
64 | ###############################
65 | ###############################
66 | return {'val_loss': avg_loss, 'log': tensorboard_logs}
67 |
68 | def sample_images(self):
69 | # Get sample reconstruction image
70 | test_input, test_label = next(iter(self.sample_dataloader))
71 | test_input = test_input.to(self.curr_device)
72 | test_label = test_label.to(self.curr_device)
73 | recons = self.model.generate(test_input, labels = test_label)
74 | vutils.save_image(recons.data,
75 | f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
76 | f"recons_{self.logger.name}_{self.current_epoch}.png",
77 | normalize=True,
78 | nrow=12)
79 |
80 | # vutils.save_image(test_input.data,
81 | # f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
82 | # f"real_img_{self.logger.name}_{self.current_epoch}.png",
83 | # normalize=True,
84 | # nrow=12)
85 |
86 | try:
87 | samples = self.model.sample(144,
88 | self.curr_device,
89 | labels = test_label)
90 | vutils.save_image(samples.cpu().data,
91 | f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
92 | f"{self.logger.name}_{self.current_epoch}.png",
93 | normalize=True,
94 | nrow=12)
95 | except:
96 | pass
97 |
98 |
99 | del test_input, recons #, samples
100 |
101 |
102 | def configure_optimizers(self):
103 |
104 | optims = []
105 | scheds = []
106 |
107 | optimizer = optim.Adam(self.model.parameters(),
108 | lr=self.params['LR'],
109 | weight_decay=self.params['weight_decay'])
110 | optims.append(optimizer)
111 | # Check if more than 1 optimizer is required (Used for adversarial training)
112 | try:
113 | if self.params['LR_2'] is not None:
114 | optimizer2 = optim.Adam(getattr(self.model,self.params['submodel']).parameters(),
115 | lr=self.params['LR_2'])
116 | optims.append(optimizer2)
117 | except:
118 | pass
119 |
120 | try:
121 | if self.params['scheduler_gamma'] is not None:
122 | scheduler = optim.lr_scheduler.ExponentialLR(optims[0],
123 | gamma = self.params['scheduler_gamma'])
124 | scheds.append(scheduler)
125 |
126 | # Check if another scheduler is required for the second optimizer
127 | try:
128 | if self.params['scheduler_gamma_2'] is not None:
129 | scheduler2 = optim.lr_scheduler.ExponentialLR(optims[1],
130 | gamma = self.params['scheduler_gamma_2'])
131 | scheds.append(scheduler2)
132 | except:
133 | pass
134 | return optims, scheds
135 | except:
136 | return optims
137 |
138 | @data_loader
139 | def train_dataloader(self):
140 | transform = self.data_transforms()
141 |
142 | if self.params['dataset'] == 'celeba':
143 | dataset = CelebA(root = self.params['data_path'],
144 | split = "train",
145 | transform=transform,
146 | download=False)
147 | elif self.params['dataset'] == 'omniglot':
148 | dataset = Omniglot(split="train", transform=transform)
149 | else:
150 | raise ValueError('Undefined dataset type')
151 |
152 | self.num_train_imgs = len(dataset)
153 | return DataLoader(dataset,
154 | batch_size= self.params['batch_size'],
155 | shuffle = True,
156 | drop_last=True)
157 |
158 | @data_loader
159 | def val_dataloader(self):
160 | transform = self.data_transforms()
161 |
162 | if self.params['dataset'] == 'celeba':
163 | self.sample_dataloader = DataLoader(CelebA(root = self.params['data_path'],
164 | split = "test",
165 | transform=transform,
166 | download=False),
167 | batch_size= 144,
168 | shuffle = True,
169 | drop_last=True)
170 | self.num_val_imgs = len(self.sample_dataloader)
171 | elif self.params['dataset'] == 'omniglot':
172 | self.sample_dataloader = DataLoader(Omniglot(split="test", transform=transform), batch_size=144, shuffle=True, drop_last=True) # batch_size may change
173 | self.num_val_imgs = len(self.sample_dataloader)
174 | else:
175 | raise ValueError('Undefined dataset type')
176 |
177 | return self.sample_dataloader
178 |
179 | def data_transforms(self):
180 |
181 | SetRange = transforms.Lambda(lambda X: 2 * X - 1.)
182 | SetScale = transforms.Lambda(lambda X: X/X.sum(0).expand_as(X))
183 |
184 | if self.params['dataset'] == 'celeba':
185 | transform = transforms.Compose([transforms.RandomHorizontalFlip(),
186 | transforms.CenterCrop(148),
187 | transforms.Resize(self.params['img_size']),
188 | transforms.ToTensor(),
189 | SetRange])
190 | elif self.params['dataset'] == 'omniglot':
191 | #todo
192 | transform = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()])
193 | else:
194 | raise ValueError('Undefined dataset type')
195 | return transform
196 |
197 |
--------------------------------------------------------------------------------
/src/logogram_language_generator.py:
--------------------------------------------------------------------------------
1 | from Omniglot import Omniglot
2 | import yaml
3 | import argparse
4 | import numpy as np
5 |
6 | from models import *
7 | from experiment import VAEXperiment
8 | import torch.backends.cudnn as cudnn
9 | from pytorch_lightning import Trainer
10 | from pytorch_lightning.logging import TestTubeLogger
11 | import torchvision.utils as vutils
12 | from torch.utils.data import DataLoader
13 | import torch.optim as optim
14 | from torchvision import transforms
15 | import imageio
16 | from scipy.interpolate import interp1d
17 | import pickle
18 | import random
19 | import fasttext
20 | from torchvision.transforms.functional import *
21 | from PIL import Image, ImageFont, ImageDraw
22 | from MapperLayer import umwe2vae
23 |
24 | from MapperLayer import EmbeddingMapping
25 | from MapperLayer import MultilingualMapper
26 | import matplotlib.pyplot as plt
27 | import math
28 |
29 | parser = argparse.ArgumentParser(description='Logogram Language Generator Main script')
30 | parser.add_argument('--embd_random_samples', help = 'path to the embedding samples', default='./randomSamples_fromEmbeddings/')
31 | parser.add_argument("--vae_ckp_path", type=str, default="./model_checkpoints/final_model_checkpoint.ckpt", help="checkpoint path of vae")
32 | parser.add_argument("--config_path", type=str, default="./configs/bbvae_CompleteRun.yaml", help="config")
33 | parser.add_argument("--export_path", type=str, default="./outputs/", help="directory for generated outputs")
34 |
35 | # parser.add_argument("--test_onHelloWorld", type=bool, default=True, help="")
36 | # parser.add_argument("--emb_vector_dim", type=int, default=300, help="")
37 | # parser.add_argument("--vae_latent_dim", type=int, default=128, help="")
38 | # parser.add_argument("--mapper_numlayer", type=int, default=3, help="")
39 |
40 | parser.add_argument("--umwe_mappers_path", type=str, default="./model_checkpoints/", help="directory containing the best_mapping_<lang>2en.t7 UMWE mappers")
41 | parser.add_argument("--words_path", type=str, default="./words.txt", help="text file with one 'word - lang' entry per line")
42 |
43 | # "standard" for statistical normalization; "mapper-disc" for the linear mapper trained with a discriminator; "mapper-lossy" for the mapper trained with the smoothness loss
44 | parser.add_argument("--norm_option", type=str, default="standard", help="one of: standard, mapper-disc, mapper-lossy, none")
45 | parser.add_argument("--mapper_layer_ckp", type=str, default="./lossy_mapper/umwe2vae.ckpt", help="checkpoint for the mapper layer (mapper-* options)")
46 |
47 |
48 | parser.add_argument("--emb_bins_path", type=str, default="./model_checkpoints/cc_bins_128/", help="directory with the cc.<lang>.128.bin fastText models")
49 |
50 | parser.add_argument("--gif", type=str, default="True", help="'True' exports an interpolation GIF; anything else a PNG grid")
51 |
52 | parser.add_argument("--siamese_mapper", type=str, default="False", help="'True' passes embeddings through the siamese-trained MultilingualMapper first")
53 |
54 |
55 | args = parser.parse_args()
56 | with open(args.words_path, "r") as f:
57 | args.words = f.readlines()
58 | for word in range(len(args.words)):
59 |     args.words[word] = args.words[word].strip()
60 |
61 | args.latent_dim = 128
62 | args.emb_vector_dim = 128
63 |
64 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
65 |
66 |
67 | def randomSample_from_embeddings(batch_size, lang):
68 | with open(args.embd_random_samples+"sample_fasttext_embs_"+lang+"_128.pickle", "rb") as f:
69 | random_sample = pickle.load(f)
70 | if batch_size == -1:
71 | batch_size = len(list(random_sample.values()))
72 | fake_batch = random.sample(list(random_sample.values()), batch_size)
73 | fake_batch = torch.stack(fake_batch, dim=0)
74 | return fake_batch
75 |
76 | def load_vae_model(vae_checkpointPath):
77 | with open(args.config_path, 'r') as file:
78 | try:
79 | config = yaml.safe_load(file)
80 | except yaml.YAMLError as exc:
81 | print(exc)
82 |
83 | model = vae_models[config['model_params']['name']](**config['model_params'])
84 | experiment = VAEXperiment(model,
85 | config['exp_params'])
86 |
87 |
88 | checkpoint = torch.load(vae_checkpointPath, map_location=lambda storage, loc: storage)
89 | new_ckpoint = {}
90 | for k in checkpoint["state_dict"].keys():
91 | newKey = k.split("model.")[1]
92 | new_ckpoint[newKey] = checkpoint["state_dict"][k]
93 |
94 | model.load_state_dict(new_ckpoint)
95 | model.eval()
96 | return model
97 |
98 | def load_mapper(vae_model, mapper_type="mapper-disc"):
99 | if mapper_type == "mapper-disc":
100 | model = EmbeddingMapping(device, args.emb_vector_dim, args.latent_dim)
101 | model.load_state_dict(torch.load(args.mapper_layer_ckp))
102 | model.eval()
103 | elif mapper_type == "mapper-lossy":
104 | model = umwe2vae(vae_model, in_dim=128, out_dim=128)
105 | model.load_state_dict(torch.load(args.mapper_layer_ckp))
106 | model.eval()
107 | elif mapper_type == "siamese_mapper":
108 | model = MultilingualMapper(device, args.emb_vector_dim, args.latent_dim)
109 | model.load_state_dict(torch.load("./layerWithSiamese/multilingualMapper_checkpoint.pt"))
110 | model.eval()
111 | return model.to(device)
112 |
113 | def embed_words():
114 | embs = []
115 | lang = ""
116 | for word in args.words:
117 | if word.split(" - ")[1] != lang:
118 | lang = word.split(" - ")[1]
119 | ft = fasttext.load_model(args.emb_bins_path+"cc."+lang+".128.bin")
120 | ft.get_dimension()
121 | emb = torch.from_numpy(ft.get_word_vector(word.split(" - ")[0])).to(device)
122 | if lang != "en":
123 | multilingual_mapper = torch.from_numpy(torch.load(args.umwe_mappers_path+"best_mapping_"+lang+"2en.t7")).to(device)
124 | emb = torch.matmul(emb, multilingual_mapper)
125 | embs.append(emb)
126 | embs = torch.stack(embs, dim=0)
127 | return embs
128 | # numpy to tensor word embs
129 | # torch.from_numpy(helloworld["hello"]).to(device)
130 |
131 | # best_mapping = "./best_mapping_fr2en.t7"
132 | # Readed_t7 = torch.from_numpy(torch.load(best_mapping)).to(device)
133 | # mapped_hello = torch.matmul(hello, Readed_t7)
134 | # mapped_world = torch.matmul(world, Readed_t7)
135 |
136 |
137 | def normalize_embeddings(embeddings, vae_model):
138 | normalized_embs = []
139 | if args.norm_option == "standard":
140 | for word in range(embeddings.shape[0]):
141 | sample_embeds = randomSample_from_embeddings(-1, "en").to(device)
142 | std, mean = torch.std_mean(input=sample_embeds, dim=0, unbiased=True)
143 | normalized_embs.append((embeddings[word,:] - mean) / std)
144 | normalized_embs = torch.stack(normalized_embs, dim=0)
145 |
146 | elif args.norm_option == "mapper-disc" or args.norm_option == "mapper-lossy":
147 | mapperLayer = load_mapper(vae_model, args.norm_option)
148 | with torch.no_grad():
149 | normalized_embs = mapperLayer(embeddings).to(device)
150 | elif args.norm_option == "none":
151 | normalized_embs = embeddings
152 |
153 | return normalized_embs
154 |
155 | def interpolate_me_baby(emb1,emb2,n):
156 |     """
157 |     Linearly interpolates between two tensors.
158 |     n: number of interpolation steps; higher values give smoother transitions.
159 |     """
160 |
161 | shapes = emb1.shape
162 | emb1_flattened_n = emb1.flatten().cpu().numpy()
163 | emb2_flattened_n = emb2.flatten().cpu().numpy()
164 | f = interp1d(x=[0,1], y=np.vstack([emb1_flattened_n,emb2_flattened_n]),axis=0)
165 | y = f(np.linspace(0, 1, n))
166 | L = [torch.reshape(torch.from_numpy(kral).float(), shapes).to(device) for kral in y]
167 | return torch.stack(L, dim=0) # return shape [n, latent_dim]
168 |
169 | def add_text_to_image(tensir, word):  # "tensir" is a tongue-in-cheek spelling of "tensor"
170 | width, height = tensir.shape
171 |     img_pil = transforms.functional.to_pil_image(tensir)  # .resize((width, height + 30), 0) would add space below the image for the caption
172 | draw = ImageDraw.Draw(img_pil)
173 | fontsize = 40
174 |     font = ImageFont.truetype("./TimesNewRoman400.ttf", size=fontsize)  # the font has to be downloaded as a TrueType (.ttf) file
175 | wordSize = font.getsize(word)
176 |     # draw.text(xy=(width//3, height+15), text=word, fill="white", font=font, anchor="ms")  # anchoring only works with an explicit font; the text currently sits bottom-left rather than the intended bottom-center, which still looks fine
177 | draw.text(xy = (int(width/2)-int(wordSize[0]/2),height-wordSize[1]-5), text = word, fill = "black", font = font)
178 | processed_tensir = to_tensor(img_pil)
179 | return torch.reshape(processed_tensir, [width,height])
180 |
181 | def make_transparent(tensor_img):
182 | img = transforms.functional.to_pil_image(tensor_img)
183 | img = img.convert("RGBA")
184 | datas = img.getdata()
185 |
186 | newData = []
187 | for item in datas:
188 | if item[0] == 255 and item[1] == 255 and item[2] == 255:
189 | newData.append((255, 255, 255, 0))
190 | else:
191 | newData.append(item)
192 |
193 | img.putdata(newData)
194 | return to_tensor(img)
195 |
196 |
197 | def savegrid(ims, rows=None, cols=None, fill=True, showax=False):
198 |     if (rows is None) != (cols is None):  # exactly one of rows/cols was given
199 | raise ValueError("Set either both rows and cols or neither.")
200 |
201 | if rows is None:
202 | rows = len(ims)
203 | cols = 1
204 |
205 | gridspec_kw = {'wspace': 0, 'hspace': 0} if fill else {}
206 | fig,axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw)
207 |
208 | if fill:
209 | bleed = 0
210 | fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))
211 |
212 | for ax,im in zip(axarr.ravel(), ims):
213 | ax.imshow(im, cmap='gray')
214 | if not showax:
215 | ax.set_axis_off()
216 |
217 | kwargs = {'pad_inches': .01} if fill else {}
218 | fig.savefig(args.export_path+'logograms.png', **kwargs)
219 |
220 | def post_process(imgs):
221 | imgs = list(map(lambda x:resize(to_pil_image(x),256), imgs))
222 | imgs = list(map(lambda x:adjust_saturation(x,1.5), imgs))
223 | imgs = list(map(lambda x:adjust_gamma(x,1.5), imgs))
224 | imgs = list(map(lambda x:adjust_contrast(x,2), imgs))
225 | imgs = list(map(lambda x:to_tensor(x), imgs))
226 |     for x in imgs:  # binarize: push bright pixels to white, dark ones to black
227 |         x[x > 0.5] = 1
228 |         x[x < 0.4] = 0
229 |
230 |
231 | #save_image(imgs, 'img3.png')
232 | return imgs[0]
233 |
234 | def main():
235 | print("started...")
236 |
237 | vae_model = load_vae_model(args.vae_ckp_path).to(device)
238 |
239 | word_embeddings = embed_words()
240 |
241 | normalized_embeddings = normalize_embeddings(word_embeddings, vae_model)
242 |
243 | if args.siamese_mapper == "True":
244 | siam_mapper = load_mapper(vae_model, mapper_type="siamese_mapper")
245 | with torch.no_grad():
246 | normalized_embeddings = siam_mapper(normalized_embeddings).to(device)
247 | normalized_embeddings = normalize_embeddings(normalized_embeddings, vae_model)
248 |
249 |
250 | if args.gif == "True":
251 | n = 5
252 | interpolations = []
253 | for word in range(normalized_embeddings.shape[0]-1):
254 | w1 = normalized_embeddings[word,:128]
255 | w2 = normalized_embeddings[word+1,:128]
256 |
257 | intplt = interpolate_me_baby(w1, w2, n)
258 | interpolations.append(intplt)
259 | interpolations = torch.cat(interpolations, dim=0).to(device)
260 |
261 | images = []
262 | for word in range(interpolations.shape[0]):
263 | if word % n == 0:
264 | word_duration = 10
265 | textWord = args.words[int(word/n)] #.split(" - ")[0]
266 | else:
267 | word_duration = 1
268 | textWord = ""
269 |
270 | for duration in range(word_duration):
271 | logogram = vae_model.decode(interpolations[word,:128].float())
272 | # logogram = torch.cat([torch.zeros(1,1,64,64).float(), logogram.data.cpu()], dim=1).cpu()
273 | # images.append(logogram)
274 |                 processed = post_process(torch.reshape(logogram.data, [1,64,64]).cpu())
275 | processed = torch.reshape(processed, [256,256])
276 | processed = add_text_to_image(processed, textWord)
277 | images.append(processed)
278 | imageio.mimwrite(args.export_path+"helloToWorld.gif", images)
279 |
280 | else:
281 | images = []
282 | for word in range(normalized_embeddings.shape[0]):
283 | logogram = vae_model.decode(normalized_embeddings[word,:128].float())
284 |             processed = post_process(torch.reshape(logogram.data, [1,64,64]).cpu())
285 | processed = torch.reshape(processed, [256,256])
286 | processed = add_text_to_image(processed, args.words[word])
287 | images.append(processed)
288 | savegrid(images, cols=int(math.sqrt(len(args.words))), rows=int(len(args.words)/int(math.sqrt(len(args.words)))))
289 |
290 |
291 |
292 |
293 | # vutils.save_image(logogram.data,
294 | # args.export_path+args.words[word]+".png",
295 | # normalize=True,
296 | # nrow=12)
297 |
298 |
299 |
300 | print("ALL DONE!")
301 |
302 |
303 | # python logogram_language_generator.py --embd_random_samples ./randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle --vae_ckp_path ./model_checkpoints/final_model_checkpoint.ckpt --config_path ./configs/bbvae_CompleteRun.yaml --export_path ./outputs/ --umwe_mappers_path ./model_checkpoints/ --words_path words.txt --norm_option standard --mapper_layer_ckp ./vae_with_disc/mapper_checkpoint.pt --emb_bins_path ./model_checkpoints/cc_bins_128/ --gif False
304 |
305 |
306 | if __name__ == "__main__":
307 | main()
308 |
--------------------------------------------------------------------------------
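
`embed_words()` above splits every line of the words file on `" - "`: the left part is the word, the right part a fastText language code, and non-English embeddings are multiplied by the corresponding `best_mapping_<lang>2en.t7` matrix. A hypothetical `words.txt` (the word choices are illustrative), written from Python to show the exact format:

```python
# Each line must be "<word> - <lang>"; the lang code needs a cc.<lang>.128.bin
# model and, if not "en", a best_mapping_<lang>2en.t7 UMWE mapper.
lines = ["hello - en", "monde - fr", "mondo - it", "mundo - es"]
with open("words.txt", "w") as f:
    f.write("\n".join(lines))
```

Grouping consecutive words by language avoids reloading the fastText model, since the loop only switches models when the language changes.
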
/src/lossyMapper_train.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import numpy as np
4 | import torch
5 | from models import *
6 | from experiment import VAEXperiment
7 | import torch.backends.cudnn as cudnn
8 | from pytorch_lightning import Trainer
9 | from pytorch_lightning.logging import TestTubeLogger
10 |
11 | from torchvision.utils import save_image
12 | from umwe2vae import umwe2vae
13 | import pickle
14 | import random
15 | parser = argparse.ArgumentParser(description='Generic runner for VAE models')
16 | parser.add_argument('--config', '-c',
17 | dest="filename",
18 | metavar='FILE',
19 | help = 'path to the config file',
20 | default='configs/vae.yaml')
21 | parser.add_argument('--embd_random_samples', help='path to the embedding-sample pickle used by randomSample_from_embeddings', default='./randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle')
22 | args = parser.parse_args()
23 |
24 |
25 | def randomSample_from_embeddings(batch_size):
26 | with open(args.embd_random_samples, "rb") as f:
27 | random_sample = pickle.load(f)
28 |
29 | fake_batch = random.sample(list(random_sample.values()), batch_size)
30 | fake_batch = torch.stack(fake_batch, dim=0)
31 | return fake_batch
32 |
33 |
34 | with open(args.filename, 'r') as file:
35 | try:
36 | config = yaml.safe_load(file)
37 | except yaml.YAMLError as exc:
38 | print(exc)
39 |
40 | model = vae_models[config['model_params']['name']](**config['model_params'])
41 |
42 | checkpoint = torch.load("./model_checkpoints/final_model_checkpoint.ckpt")
43 | state_dict = {}
44 | for k in checkpoint['state_dict'].keys():
45 | state_dict[k[6:]] = checkpoint['state_dict'][k]
46 |
47 | model.load_state_dict(state_dict)
48 |
49 | ldr = [torch.randn((64,128)) for _ in range(10)]
50 |
51 | inp = torch.randn((5,128))
52 | u2v = umwe2vae(model, 128, 128)
53 | save_image([u2v(inp)[0],u2v(inp)[1],u2v(inp)[2],u2v(inp)[3],u2v(inp)[4]], 'img1.png')
54 | u2v.train(ldr)
55 |
56 | save_image([u2v(inp)[0],u2v(inp)[1],u2v(inp)[2],u2v(inp)[3],u2v(inp)[4]], 'img2.png')
57 |
58 | checkpoint = torch.load("./umwe2vae.ckpt")
59 | u2v.load_state_dict(checkpoint['model_state_dict'])
60 |
--------------------------------------------------------------------------------
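
The script above exercises the lossy mapper on `torch.randn` noise batches. A hedged sketch of feeding it real embeddings instead, assuming the bundled 128-dim English sample pickle:

```python
# Sketch: build 10 batches of 64 real fastText embeddings for u2v.train().
import pickle
import random
import torch

with open("./randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle", "rb") as f:
    samples = list(pickle.load(f).values())   # list of 128-dim tensors

ldr = [torch.stack(random.sample(samples, 64)) for _ in range(10)]
u2v.train(ldr)   # same call as above, now driven by embedding batches
```
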
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from .vanilla_vae import *
3 | from .gamma_vae import *
4 | from .beta_vae import *
5 | from .wae_mmd import *
6 | from .cvae import *
7 | from .hvae import *
8 | from .vampvae import *
9 | from .iwae import *
10 | from .dfcvae import *
11 | from .mssim_vae import MSSIMVAE
12 | from .fvae import *
13 | from .cat_vae import *
14 | from .joint_vae import *
15 | from .info_vae import *
16 | # from .twostage_vae import *
17 | from .lvae import LVAE
18 | from .logcosh_vae import *
19 | from .swae import *
20 | from .miwae import *
21 | from .vq_vae import *
22 | from .betatc_vae import *
23 | from .dip_vae import *
24 |
25 |
26 | # Aliases
27 | VAE = VanillaVAE
28 | GaussianVAE = VanillaVAE
29 | CVAE = ConditionalVAE
30 | GumbelVAE = CategoricalVAE
31 |
32 | vae_models = {'HVAE':HVAE,
33 | 'LVAE':LVAE,
34 | 'IWAE':IWAE,
35 | 'SWAE':SWAE,
36 | 'MIWAE':MIWAE,
37 | 'VQVAE':VQVAE,
38 | 'DFCVAE':DFCVAE,
39 | 'DIPVAE':DIPVAE,
40 | 'BetaVAE':BetaVAE,
41 | 'InfoVAE':InfoVAE,
42 | 'WAE_MMD':WAE_MMD,
43 | 'VampVAE': VampVAE,
44 | 'GammaVAE':GammaVAE,
45 | 'MSSIMVAE':MSSIMVAE,
46 | 'JointVAE':JointVAE,
47 | 'BetaTCVAE':BetaTCVAE,
48 | 'FactorVAE':FactorVAE,
49 | 'LogCoshVAE':LogCoshVAE,
50 | 'VanillaVAE':VanillaVAE,
51 | 'ConditionalVAE':ConditionalVAE,
52 | 'CategoricalVAE':CategoricalVAE}
53 |
--------------------------------------------------------------------------------
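
`vae_models` is the registry that `run.py` and the generator use to turn a YAML config's `model_params` into a model: the `name` key selects the class and the whole dict is splatted into its constructor (extra keys are absorbed by `**kwargs`). A minimal sketch with illustrative values:

```python
from models import vae_models

model_params = {"name": "BetaVAE", "in_channels": 1, "latent_dim": 128}
model = vae_models[model_params["name"]](**model_params)
```
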
/src/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/base.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/base.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/beta_vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/beta_vae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/betatc_vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/betatc_vae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/cat_vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/cat_vae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/cvae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/cvae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/dfcvae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/dfcvae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/dip_vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/dip_vae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/fvae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/fvae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/gamma_vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/gamma_vae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/hvae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/hvae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/info_vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/info_vae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/iwae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/iwae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/joint_vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/joint_vae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/logcosh_vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/logcosh_vae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/lvae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/lvae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/miwae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/miwae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/mssim_vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/mssim_vae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/swae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/swae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/types_.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/types_.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/vampvae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/vampvae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/vanilla_vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/vanilla_vae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/vq_vae.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/vq_vae.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/__pycache__/wae_mmd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/models/__pycache__/wae_mmd.cpython-36.pyc
--------------------------------------------------------------------------------
/src/models/base.py:
--------------------------------------------------------------------------------
1 | from .types_ import *
2 | from torch import nn
3 | from abc import abstractmethod
4 |
5 | class BaseVAE(nn.Module):
6 |
7 | def __init__(self) -> None:
8 | super(BaseVAE, self).__init__()
9 |
10 | def encode(self, input: Tensor) -> List[Tensor]:
11 | raise NotImplementedError
12 |
13 | def decode(self, input: Tensor) -> Any:
14 | raise NotImplementedError
15 |
16 | def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
17 |         raise NotImplementedError
18 |
19 | def generate(self, x: Tensor, **kwargs) -> Tensor:
20 | raise NotImplementedError
21 |
22 | @abstractmethod
23 | def forward(self, *inputs: Tensor) -> Tensor:
24 | pass
25 |
26 | @abstractmethod
27 | def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
28 | pass
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
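
`BaseVAE` is the contract every model in the registry implements: `forward` and `loss_function` are abstract, while `encode`/`decode`/`sample`/`generate` raise until overridden. A toy subclass sketch (not in the repo) just to show the shape of the interface:

```python
import torch
from torch import nn
from models.base import BaseVAE

class ToyVAE(BaseVAE):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.enc = nn.Linear(dim, dim)
        self.dec = nn.Linear(dim, dim)

    def forward(self, x):
        return self.dec(self.enc(x))

    def loss_function(self, x, recons):
        return torch.mean((x - recons) ** 2)
```
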
/src/models/beta_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class BetaVAE(BaseVAE):
9 |
10 | num_iter = 0 # Global static variable to keep track of iterations
11 |
12 | def __init__(self,
13 | in_channels: int,
14 | latent_dim: int,
15 | hidden_dims: List = None,
16 | beta: int = 4,
17 | gamma:float = 1000.,
18 | max_capacity: int = 25,
19 | Capacity_max_iter: int = 1e5,
20 | loss_type:str = 'B',
21 | **kwargs) -> None:
22 | super(BetaVAE, self).__init__()
23 |
24 | self.latent_dim = latent_dim
25 | self.beta = beta
26 | self.gamma = gamma
27 | self.loss_type = loss_type
28 | self.C_max = torch.Tensor([max_capacity])
29 | self.C_stop_iter = Capacity_max_iter
30 |
31 | modules = []
32 | if hidden_dims is None:
33 | hidden_dims = [32, 64, 128, 256, 512]
34 |
35 | # Build Encoder
36 | for h_dim in hidden_dims:
37 | modules.append(
38 | nn.Sequential(
39 | nn.Conv2d(in_channels, out_channels=h_dim,
40 | kernel_size= 3, stride= 2, padding = 1),
41 | nn.BatchNorm2d(h_dim),
42 | nn.LeakyReLU())
43 | )
44 | in_channels = h_dim
45 |
46 | self.encoder = nn.Sequential(*modules)
47 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
48 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
49 |
50 |
51 | # Build Decoder
52 | modules = []
53 |
54 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
55 |
56 | hidden_dims.reverse()
57 |
58 | for i in range(len(hidden_dims) - 1):
59 | modules.append(
60 | nn.Sequential(
61 | nn.ConvTranspose2d(hidden_dims[i],
62 | hidden_dims[i + 1],
63 | kernel_size=3,
64 | stride = 2,
65 | padding=1,
66 | output_padding=1),
67 | nn.BatchNorm2d(hidden_dims[i + 1]),
68 | nn.LeakyReLU())
69 | )
70 |
71 |
72 |
73 | self.decoder = nn.Sequential(*modules)
74 |
75 | self.final_layer = nn.Sequential(
76 | nn.ConvTranspose2d(hidden_dims[-1],
77 | hidden_dims[-1],
78 | kernel_size=3,
79 | stride=2,
80 | padding=1,
81 | output_padding=1),
82 | nn.BatchNorm2d(hidden_dims[-1]),
83 | nn.LeakyReLU(),
84 | nn.Conv2d(hidden_dims[-1], out_channels= 1,
85 | kernel_size= 3, padding= 1),
86 | nn.Tanh())
87 |
88 | def encode(self, input: Tensor) -> List[Tensor]:
89 | """
90 | Encodes the input by passing through the encoder network
91 | and returns the latent codes.
92 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
93 | :return: (Tensor) List of latent codes
94 | """
95 | result = self.encoder(input)
96 | result = torch.flatten(result, start_dim=1)
97 |
98 | # Split the result into mu and var components
99 | # of the latent Gaussian distribution
100 | mu = self.fc_mu(result)
101 | log_var = self.fc_var(result)
102 |
103 | return [mu, log_var]
104 |
105 | def decode(self, z: Tensor) -> Tensor:
106 | # mapping layer
107 | result = self.decoder_input(z)
108 | result = result.view(-1, 512, 2, 2)
109 | result = self.decoder(result)
110 | result = self.final_layer(result)
111 | return result
112 |
113 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
114 |         """
115 |         Reparameterization trick: z = mu + std * eps with eps ~ N(0, I).
116 |         A single sampled z per input is the usual one-sample Monte Carlo estimate of the expectation in the loss.
117 |         :param mu: (Tensor) Mean of the latent Gaussian
118 |         :param logvar: (Tensor) Log-variance of the latent Gaussian
119 |         :return: (Tensor) Sampled latent code
120 |         """
121 | std = torch.exp(0.5 * logvar)
122 | eps = torch.randn_like(std)
123 | return eps * std + mu
124 |
125 | def forward(self, input: Tensor, **kwargs) -> Tensor:
126 | mu, log_var = self.encode(input)
127 | z = self.reparameterize(mu, log_var)
128 | return [self.decode(z), input, mu, log_var]
129 |
130 | def loss_function(self,
131 | *args,
132 | **kwargs) -> dict:
133 | self.num_iter += 1
134 | recons = args[0]
135 | input = args[1]
136 | mu = args[2]
137 | log_var = args[3]
138 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
139 |
140 |         recons_loss = F.mse_loss(recons, input)
141 |
142 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
143 |
144 | if self.loss_type == 'H': # https://openreview.net/forum?id=Sy2fzU9gl
145 | loss = recons_loss + self.beta * kld_weight * kld_loss
146 | elif self.loss_type == 'B': # https://arxiv.org/pdf/1804.03599.pdf
147 | self.C_max = self.C_max.to(input.device)
148 | C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
149 | loss = recons_loss + self.gamma * kld_weight* (kld_loss - C).abs()
150 | else:
151 | raise ValueError('Undefined loss type.')
152 |
153 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}
154 |
155 | def sample(self,
156 | num_samples:int,
157 | current_device: int, **kwargs) -> Tensor:
158 | """
159 | Samples from the latent space and return the corresponding
160 | image space map.
161 | :param num_samples: (Int) Number of samples
162 | :param current_device: (Int) Device to run the model
163 | :return: (Tensor)
164 | """
165 | z = torch.randn(num_samples,
166 | self.latent_dim)
167 |
168 | z = z.to(current_device)
169 |
170 | samples = self.decode(z)
171 | return samples
172 |
173 | def generate(self, x: Tensor, **kwargs) -> Tensor:
174 | """
175 | Given an input image x, returns the reconstructed image
176 | :param x: (Tensor) [B x C x H x W]
177 | :return: (Tensor) [B x C x H x W]
178 | """
179 |
180 | return self.forward(x)[0]
--------------------------------------------------------------------------------
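
End to end, `forward` returns `[reconstruction, input, mu, log_var]` and `loss_function` consumes exactly that plus the KLD weight `M_N` ("Account for the minibatch samples from the dataset" above; something around batch_size / dataset_size is a plausible choice). A sketch of one step with the default 64x64 single-channel setup, the `M_N` value being illustrative:

```python
import torch
from models import BetaVAE

model = BetaVAE(in_channels=1, latent_dim=128, loss_type='B')
x = torch.randn(16, 1, 64, 64)        # fake batch of 64x64 grayscale images

recons, inp, mu, log_var = model(x)   # forward returns a 4-element list
losses = model.loss_function(recons, inp, mu, log_var, M_N=0.005)
losses['loss'].backward()
print(losses['Reconstruction_Loss'].item(), losses['KLD'].item())
```
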
/src/randomSamples_fromEmbeddings/sample_fasttext_embs_en.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_en.pickle
--------------------------------------------------------------------------------
/src/randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle
--------------------------------------------------------------------------------
/src/randomSamples_fromEmbeddings/sample_fasttext_embs_es_128.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_es_128.pickle
--------------------------------------------------------------------------------
/src/randomSamples_fromEmbeddings/sample_fasttext_embs_fr.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_fr.pickle
--------------------------------------------------------------------------------
/src/randomSamples_fromEmbeddings/sample_fasttext_embs_fr_128.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_fr_128.pickle
--------------------------------------------------------------------------------
/src/randomSamples_fromEmbeddings/sample_fasttext_embs_it_128.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/randomSamples_fromEmbeddings/sample_fasttext_embs_it_128.pickle
--------------------------------------------------------------------------------
/src/randomSamples_fromEmbeddings/sample_from_embeds.py:
--------------------------------------------------------------------------------
1 | import fasttext
2 | import random
3 | import torch
4 | import pickle
5 |
6 | langs = ["fr", "it", "es", "en"]
7 |
8 | for lang in langs:
9 | ft = fasttext.load_model("../model_checkpoints/cc_bins_128/cc."+lang+".128.bin")
10 | ft.get_dimension()
11 | size = 10000
12 | random_words = random.sample(ft.get_words(), size)
13 | samples = {}
14 | for i in range(size):
15 | samples[random_words[i]] = torch.from_numpy(ft.get_word_vector(random_words[i]))
16 | filename = "sample_fasttext_embs_"+lang+"_128.pickle"
17 | outfile = open(filename,'wb')
18 |
19 | pickle.dump(samples,outfile)
20 | outfile.close()
21 |
22 |
--------------------------------------------------------------------------------
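
Reading one of these pickles back mirrors `randomSample_from_embeddings` in the generator: each file is a dict from word to 128-dim tensor, so a batch is just a stacked random sample:

```python
import pickle
import random
import torch

with open("sample_fasttext_embs_en_128.pickle", "rb") as f:
    samples = pickle.load(f)          # dict: word -> 128-dim tensor

batch = torch.stack(random.sample(list(samples.values()), 64), dim=0)
print(batch.shape)                    # torch.Size([64, 128])
```
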
/src/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==0.6.0
2 | PyYAML==5.1.2
3 | tensorboard==2.1.0
4 | tensorboardX==1.6
5 | terminado==0.8.1
6 | test-tube==0.7.0
7 | torch==1.4.0
8 | torchfile==0.1.0
9 | torchnet==0.0.4
10 | torchsummary==1.5.1
11 | torchvision==0.5.0
12 |
13 |
--------------------------------------------------------------------------------
/src/run.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import numpy as np
4 |
5 | from models import *
6 | from experiment import VAEXperiment
7 | import torch.backends.cudnn as cudnn
8 | from pytorch_lightning import Trainer
9 | from pytorch_lightning.logging import TestTubeLogger
10 |
11 |
12 | parser = argparse.ArgumentParser(description='Generic runner for VAE models')
13 | parser.add_argument('--config', '-c',
14 | dest="filename",
15 | metavar='FILE',
16 | help = 'path to the config file',
17 | default='configs/vae.yaml')
18 |
19 | args = parser.parse_args()
20 | with open(args.filename, 'r') as file:
21 | try:
22 | config = yaml.safe_load(file)
23 | except yaml.YAMLError as exc:
24 | print(exc)
25 |
26 |
27 | tt_logger = TestTubeLogger(
28 | save_dir=config['logging_params']['save_dir'],
29 | name=config['logging_params']['name'],
30 | debug=False,
31 | create_git_tag=False,
32 | )
33 |
34 | # For reproducibility
35 | torch.manual_seed(config['logging_params']['manual_seed'])
36 | np.random.seed(config['logging_params']['manual_seed'])
37 | cudnn.deterministic = True
38 | cudnn.benchmark = False
39 |
40 | model = vae_models[config['model_params']['name']](**config['model_params'])
41 | experiment = VAEXperiment(model,
42 | config['exp_params'])
43 | # print(f"{tt_logger.save_dir}"+f"{tt_logger.name}"+"/final_model_checkpoint.ckpt")
44 | print(f"{tt_logger.name}")
45 |
46 | runner = Trainer(default_save_path=f"{tt_logger.save_dir}",
47 | min_nb_epochs=1,
48 | logger=tt_logger,
49 | log_save_interval=100,
50 | train_percent_check=1.,
51 | val_percent_check=1.,
52 | num_sanity_val_steps=5,
53 | early_stop_callback = False,
54 | **config['trainer_params'])
55 |
56 | print(f"======= Training {config['model_params']['name']} =======")
57 | runner.fit(experiment)
58 | runner.save_checkpoint(f"{tt_logger.save_dir}"+f"{tt_logger.name}"+"/final_model_checkpoint.ckpt")
59 |
60 |
61 |
--------------------------------------------------------------------------------
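
Training is driven entirely by the config, e.g. `python run.py -c configs/bbvae_CompleteRun.yaml`. Note that the default `configs/vae.yaml` is not among the bundled configs, so the `-c` flag is effectively required.
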
/src/test.py:
--------------------------------------------------------------------------------
1 | print("hello from nodejs with python")
2 |
3 | for i in range(1000):
4 | print(i)
5 |
--------------------------------------------------------------------------------
/src/train_mapper_and_discriminator.py:
--------------------------------------------------------------------------------
1 | from Omniglot import Omniglot
2 | import yaml
3 | import argparse
4 | import numpy as np
5 |
6 | from models import *
7 | from experiment import VAEXperiment
8 | import torch.backends.cudnn as cudnn
9 | from pytorch_lightning import Trainer
10 | from pytorch_lightning.logging import TestTubeLogger
11 | import torchvision.utils as vutils
12 | from torch.utils.data import DataLoader
13 | import torch.optim as optim
14 | from torchvision import transforms
15 |
16 | import pickle
17 | import random
18 |
19 |
20 |
21 | # python train_mapper_and_discriminator.py --embd_random_samples ./randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle --vae_ckp_path ./model_checkpoints/final_model_checkpoint.ckpt --config_path ./configs/bbvae_CompleteRun.yaml
22 |
23 | #python logogram_language_generator.py --embd_random_samples ./randomSamples_fromEmbeddings/sample_fasttext_embs_en_128.pickle
24 | #--vae_ckp_path ./model_checkpoints/final_model_checkpoint.ckpt --config_path ./configs/bbvae_CompleteRun.yaml
25 | # --export_path ./outputs/ --words_path words.txt --norm_option standard
26 | # --emb_bins_path ./model_checkpoints/cc_bins_128/
27 |
28 | parser = argparse.ArgumentParser(description='Mapper training for feeding embeddings to VAE')
29 | parser.add_argument('--embd_random_samples', help = 'path to the embedding samples', default='/content/drive/My Drive/vae/logogram-language-generator-master/sample_fasttext_embs.pickle')
30 | parser.add_argument("--vae_ckp_path", type=str, default="logs/BetaVAE_B_setup2_run2/final_model_checkpoint.ckpt", help="checkpoint path of vae")
31 | parser.add_argument("--config_path", type=str, default="./configs/bbvae_setup2.yaml", help="config")
32 | parser.add_argument("--export_path", type=str, default="./vae_with_disc/", help="export")
33 | parser.add_argument("--test_onHelloWorld", type=bool, default=False, help="sanity-check the mapper on a hello/world embedding pickle")  # note: argparse's type=bool treats any non-empty string as True
34 | parser.add_argument("--emb_vector_dim", type=int, default=300, help="fastText embedding dimensionality")
35 | parser.add_argument("--vae_latent_dim", type=int, default=128, help="VAE latent dimensionality")
36 | parser.add_argument("--mapper_numlayer", type=int, default=3, help="number of linear layers in EmbeddingMapping")
37 |
38 |
39 | args = parser.parse_args()
40 |
41 |
42 | class EmbeddingMapping(nn.Module):
43 | def __init__(self, device, embedding_vector_dim = 300, decoder_input_dim=128):
44 | super(EmbeddingMapping, self).__init__()
45 | self.device = device
46 | self.embedding_vector_dim = embedding_vector_dim
47 | self.decoder_input_dim = decoder_input_dim
48 | self.mapper_numlayer = args.mapper_numlayer
49 |
50 | self.linear_layers = []
51 | self.batch_norms = []
52 | for layer in range(0, self.mapper_numlayer-1):
53 | self.linear_layers.append(nn.Linear(embedding_vector_dim, embedding_vector_dim))
54 | self.batch_norms.append(nn.BatchNorm1d(embedding_vector_dim))
55 |
56 | # final layer
57 | self.linear_layers.append(nn.Linear(embedding_vector_dim, decoder_input_dim))
58 | self.batch_norms.append(nn.BatchNorm1d(decoder_input_dim))
59 |
60 |
61 | self.linear_layers = nn.ModuleList(self.linear_layers)
62 | self.batch_norms = nn.ModuleList(self.batch_norms)
63 |
64 |
65 | self.relu = nn.ReLU()
66 |
67 | def forward(self, embedding_vector):
68 | inp = embedding_vector
69 | for layer in range(self.mapper_numlayer):
70 | out = self.linear_layers[layer](inp)
71 | out = self.batch_norms[layer](out)
72 | out = self.relu(out)
73 | inp = out
74 | return out
75 |
76 | class Discriminator(nn.Module):
77 | def __init__(self, ngpu):
78 | super(Discriminator, self).__init__()
79 | self.ngpu = ngpu
80 | self.ndf = 64
81 | self.nc = 1
82 | self.main = nn.Sequential(
83 | # input is (nc) x 64 x 64
84 | nn.Conv2d(self.nc, self.ndf, 4, 2, 1, bias=False),
85 | nn.LeakyReLU(0.2, inplace=True),
86 | # state size. (self.ndf) x 32 x 32
87 | nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),
88 | nn.BatchNorm2d(self.ndf * 2),
89 | nn.LeakyReLU(0.2, inplace=True),
90 | # state size. (self.ndf*2) x 16 x 16
91 | nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),
92 | nn.BatchNorm2d(self.ndf * 4),
93 | nn.LeakyReLU(0.2, inplace=True),
94 | # state size. (self.ndf*4) x 8 x 8
95 | nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),
96 | nn.BatchNorm2d(self.ndf * 8),
97 | nn.LeakyReLU(0.2, inplace=True),
98 | # state size. (self.ndf*8) x 4 x 4
99 | nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
100 | nn.Sigmoid()
101 | )
102 |
103 | def forward(self, input):
104 | return self.main(input)
105 |
106 |
107 | def weights_init(m):
108 | classname = m.__class__.__name__
109 | if classname.find('Conv') != -1:
110 | nn.init.normal_(m.weight.data, 0.0, 0.02)
111 | elif classname.find('BatchNorm') != -1:
112 | nn.init.normal_(m.weight.data, 1.0, 0.02)
113 | nn.init.constant_(m.bias.data, 0)
114 |
115 |
116 | def get_dataLoader(batch_size):
117 | transform = data_transforms()
118 | dataset = Omniglot(split="train", transform=transform)
119 | num_train_imgs = len(dataset)
120 | return DataLoader(dataset,
121 | batch_size= batch_size,
122 | shuffle = True,
123 | drop_last=True)
124 |
125 |
126 | def data_transforms():
127 | SetRange = transforms.Lambda(lambda X: 2 * X - 1.)
128 | SetScale = transforms.Lambda(lambda X: X/X.sum(0).expand_as(X))
129 | transform = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()])
130 | return transform
131 |
132 |
133 | def randomSample_from_embeddings(batch_size):
134 | with open(args.embd_random_samples, "rb") as f:
135 | random_sample = pickle.load(f)
136 |
137 | fake_batch = random.sample(list(random_sample.values()), batch_size)
138 | fake_batch = torch.stack(fake_batch, dim=0)
139 | return fake_batch
140 |
141 |
142 |
143 |
144 |
145 | def trainer(vae_model, mapper, netD, batch_size, device):
146 | # Training Loop
147 | # Lists to keep track of progress
148 | img_list = []
149 | G_losses = []
150 | D_losses = []
151 | iters = 0
152 | num_epochs = 10
153 | dataloader = get_dataLoader(batch_size=batch_size)
154 |
155 | # Initialize BCELoss function
156 | criterion = nn.BCELoss()
157 | lr = 0.00001
158 | beta1 = 0.5
159 | # Setup Adam optimizers for both G and D
160 | optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
161 | optimizerG = optim.Adam(mapper.parameters(), lr=lr, betas=(beta1, 0.999))
162 |
163 |
164 |
165 | # Create batch of latent vectors that we will use to visualize
166 | # the progression of the generator
167 | fixed_noise = randomSample_from_embeddings(batch_size).to(device)
168 |
169 |
170 |
171 | # Establish convention for real and fake labels during training
172 | real_label = 1.
173 | fake_label = 0.
174 |
175 |
176 | print("Starting Training Loop...")
177 | # For each epoch
178 | for epoch in range(num_epochs):
179 | # For each batch in the dataloader
180 | for i, data in enumerate(dataloader, 0):
181 |
182 | ############################
183 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
184 | ###########################
185 | ## Train with all-real batch
186 | netD.zero_grad()
187 | # Format batch
188 | real_cpu = data[0].to(device)
189 | b_size = real_cpu.size(0)
190 | label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
191 | # Forward pass real batch through D
192 | output = netD(real_cpu).view(-1)
193 | # Calculate loss on all-real batch
194 | errD_real = criterion(output, label)
195 | # Calculate gradients for D in backward pass
196 | errD_real.backward()
197 | D_x = output.mean().item()
198 |
199 | ## Train with all-fake batch
200 | # Generate batch of latent vectors
201 | embeddings = randomSample_from_embeddings(batch_size).to(device)
202 | # Generate fake image batch with G
203 | fake = mapper(embeddings).to(device)
204 | fake = vae_model.decode(fake)
205 |
206 | label.fill_(fake_label)
207 | # Classify all fake batch with D
208 | output = netD(fake.detach()).view(-1)
209 | # Calculate D's loss on the all-fake batch
210 | errD_fake = criterion(output, label)
211 | # Calculate the gradients for this batch
212 | errD_fake.backward()
213 | D_G_z1 = output.mean().item()
214 | # Add the gradients from the all-real and all-fake batches
215 | errD = errD_real + errD_fake
216 | # Update D
217 | optimizerD.step()
218 |
219 | ############################
220 | # (2) Update G network: maximize log(D(G(z)))
221 | ###########################
222 | mapper.zero_grad()
223 | label.fill_(real_label) # fake labels are real for generator cost
224 | # Since we just updated D, perform another forward pass of all-fake batch through D
225 | output = netD(fake).view(-1)
226 | # Calculate G's loss based on this output
227 | errG = criterion(output, label)
228 | # Calculate gradients for G
229 | errG.backward()
230 | D_G_z2 = output.mean().item()
231 | # Update G
232 | optimizerG.step()
233 |
234 | # Output training stats
235 | if i % 50 == 0:
236 | print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
237 | % (epoch, num_epochs, i, len(dataloader),
238 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
239 |
240 | # Save Losses for plotting later
241 | G_losses.append(errG.item())
242 | D_losses.append(errD.item())
243 |
244 | # Check how the generator is doing by saving G's output on fixed_noise
245 | if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
246 | randomsamples = []
247 | with torch.no_grad():
248 | mapped = mapper(torch.randn(batch_size,args.emb_vector_dim).to(device))
249 | randomsamples.append(mapped.to(device))
250 |
251 | embed_samples = []
252 | one_embed = randomSample_from_embeddings(batch_size).to(device)
253 | with torch.no_grad():
254 | mapped_embed = mapper(one_embed).to(device)
255 | embed_samples.append(mapped_embed)
256 |
257 | randomsamples = torch.stack(randomsamples, dim=0)
258 | recons_randoms = vae_model.decode(randomsamples)
259 |
260 | embedsamples = torch.stack(embed_samples, dim=0)
261 | recons_embeds = vae_model.decode(embedsamples)
262 |
263 | vutils.save_image(recons_randoms.data,
264 | f"{args.export_path}test_random{i}.png",
265 | normalize=True,
266 | nrow=12)
267 | vutils.save_image(recons_embeds.data,
268 | f"{args.export_path}test_embed{i}.png",
269 | normalize=True,
270 | nrow=12)
271 |
272 | iters += 1
273 | return vae_model, mapper, netD
274 |
275 | def load_vae_model(vae_checkpointPath, config):
276 | model = vae_models[config['model_params']['name']](**config['model_params'])
277 | experiment = VAEXperiment(model,
278 | config['exp_params'])
279 |
280 |
281 | checkpoint = torch.load(vae_checkpointPath, map_location=lambda storage, loc: storage)
282 | new_ckpoint = {}
283 | for k in checkpoint["state_dict"].keys():
284 | newKey = k.split("model.")[1]
285 | new_ckpoint[newKey] = checkpoint["state_dict"][k]
286 |
287 | model.load_state_dict(new_ckpoint)
288 | model.eval()
289 | return model
290 |
291 |
292 | def main():
293 | print("on main")
294 | with open(args.config_path, 'r') as file:
295 | try:
296 | config = yaml.safe_load(file)
297 | except yaml.YAMLError as exc:
298 | print(exc)
299 |
300 | batch_size = 64
301 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
302 | print(device)
303 |
304 | vae_model = load_vae_model(args.vae_ckp_path, config).to(device)
305 |
306 | mapper = EmbeddingMapping(device, args.emb_vector_dim, args.vae_latent_dim).to(device)
307 | mapper.apply(weights_init)
308 |
309 |     netD = Discriminator(ngpu=1).to(device)  # ngpu is stored but unused by Discriminator
310 | netD.apply(weights_init)
311 | vae_model, mapper, netD = trainer(vae_model=vae_model, mapper=mapper, netD=netD, batch_size=batch_size, device=device)
312 |
313 | torch.save(mapper.state_dict(), args.export_path+"mapper_checkpoint.pt")
314 | torch.save(netD.state_dict(), args.export_path+"discriminator_checkpoint.pt")
315 |
316 |
317 |
318 | if args.test_onHelloWorld:
319 | with open("/content/drive/My Drive/vae/logogram-language-generator-master/fasttext_hello_world.pickle", "rb") as f:
320 | helloworld = pickle.load(f)
321 |
322 | hello = torch.from_numpy(helloworld["hello"]).to(device)
323 | world = torch.from_numpy(helloworld["world"]).to(device)
324 |
325 |
326 | helloworld = torch.stack([hello, world], dim=0).to(device)
327 |
328 | embed_samples = []
329 | with torch.no_grad():
330 | mapped_embed = mapper(helloworld).to(device)
331 | embed_samples.append(mapped_embed)
332 | embedsamples = torch.stack(embed_samples, dim=0)
333 | recons_embeds = vae_model.decode(embedsamples)
334 | vutils.save_image(recons_embeds.data,
335 | args.export_path+"helloWorld.png",
336 | normalize=True,
337 | nrow=12)
338 |
339 |
340 |
341 |
342 | print("ALL DONE!")
343 |
344 |
345 |
346 |
347 |
348 | if __name__ == "__main__":
349 | main()
350 |
--------------------------------------------------------------------------------
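
Once trained, the saved mapper checkpoint is what the generator loads for the `mapper-disc` norm option. A sketch of reusing it at inference time (mirrors `load_mapper` in `logogram_language_generator.py`; paths and dimensions are this script's defaults):

```python
import torch
from MapperLayer import EmbeddingMapping

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mapper = EmbeddingMapping(device, 300, 128)   # embedding dim, VAE latent dim
mapper.load_state_dict(torch.load("./vae_with_disc/mapper_checkpoint.pt", map_location=device))
mapper.eval()
```
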
/src/train_siamese.py:
--------------------------------------------------------------------------------
1 | from Omniglot_triplet import Omniglot_triplet
2 | from torch.utils.data import DataLoader
3 | from torchvision.utils import save_image
4 | from Siamese import Siamese
5 | import torch
6 |
7 | #omni = Omniglot_triplet()
8 | #dloader = DataLoader(omni,batch_size=16,shuffle=True)
9 |
10 | siam = Siamese().cuda()
11 | #siam.train_triplet(dloader)
12 | siam.load_state_dict(torch.load("siamese.ckpt"))
13 | siam.eval()
14 | omnieval = Omniglot_triplet(split="test")
15 | dloader = DataLoader(omnieval,batch_size=16,shuffle=True)
16 | siam.eval_triplet(dloader)
17 |
--------------------------------------------------------------------------------
/src/umwe2vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.optim as optim
4 | from torchvision.transforms.functional import adjust_contrast
5 |
6 | class umwe2vae(nn.Module):
7 | def __init__(self, vae_model, in_dim=300, out_dim=128):
8 | super(umwe2vae, self).__init__()
9 | self.vae_model = vae_model
10 | #self.fc = nn.Linear(in_dim, out_dim)
11 | self.mapper_numlayer = 3
12 |
13 | self.linear_layers = []
14 | self.batch_norms = []
15 | for layer in range(0, self.mapper_numlayer-1):
16 | self.linear_layers.append(nn.Linear(in_dim, in_dim))
17 | self.batch_norms.append(nn.BatchNorm1d(in_dim))
18 |
19 | # final layer
20 | self.linear_layers.append(nn.Linear(in_dim, out_dim))
21 | self.batch_norms.append(nn.BatchNorm1d(out_dim))
22 |
23 |
24 | self.linear_layers = nn.ModuleList(self.linear_layers)
25 | self.batch_norms = nn.ModuleList(self.batch_norms)
26 |
27 |
28 | self.relu = nn.ReLU()
29 |
30 |
31 |
32 |
33 | def forward(self, x):
34 | inp = x
35 | for layer in range(self.mapper_numlayer):
36 | out = self.linear_layers[layer](inp)
37 | out = self.batch_norms[layer](out)
38 | out = self.relu(out)
39 | inp = out
40 | return out
41 |
42 | # h = self.fc(x)
43 | # y = self.vae_model.decode(h)
44 | # return h
45 | # here used to live post-processing
46 | #out = torch.zeros(y.shape)
47 | #for i in range(y.shape[0]):
48 | # out[i] = adjust_contrast(y[i], contrast_factor=2.5)
49 | #return out
50 |
51 | def loss(self, x, alpha=1, beta=1):
52 | middle = x[:,:,1:-1,1:-1]
53 | ne = x[:,:,0:-2,0:-2]
54 | n = x[:,:,0:-2,1:-1]
55 | nw = x[:,:,0:-2,2:]
56 | e = x[:,:,1:-1,0:-2]
57 | w = x[:,:,1:-1,2:]
58 | se = x[:,:,2:,0:-2]
59 | s = x[:,:,2:,1:-1]
60 | sw = x[:,:,2:,2:]
61 |         # 8-neighbour total-variation-style smoothness term, minus a reward for pixels far from mid-gray (0.5)
62 | return alpha * torch.mean(sum([torch.abs(middle-ne),torch.abs(middle-n),torch.abs(middle-nw),torch.abs(middle-e),torch.abs(middle-w),torch.abs(middle-se),torch.abs(middle-s),torch.abs(middle-sw)]) / 8.) - beta * torch.mean(torch.abs(x-0.5))
63 |     # NOTE: this train(loader, ...) shadows nn.Module.train(mode=True)
64 | def train(self, loader, lr=0.001, epochs=5):
65 | optimizer = optim.Adam(self.parameters(), lr=lr)
66 |
67 | for epoch in range(epochs):
68 | total_loss = 0
69 | for _,inp in enumerate(loader):
70 | optimizer.zero_grad()
71 |
72 | out = self.forward(inp)
73 | out = self.vae_model.decode(out)
74 | loss = self.loss(out)
75 |
76 | loss.backward()
77 | optimizer.step()
78 |
79 |                 total_loss += loss.item()  # .item() avoids retaining the autograd graph across batches
80 |
81 | print("Epoch: %d Loss: %f" % (epoch+1, total_loss/len(loader)))
82 |
83 |
84 | torch.save(self.state_dict(), "umwe2vae.ckpt")
85 |
86 |
--------------------------------------------------------------------------------
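
The `loss` above is an 8-neighbour smoothness penalty (mean absolute difference between each interior pixel and its neighbours) minus a term rewarding pixels far from mid-gray, pushing decoded glyphs toward smooth, saturated strokes. A standalone restatement on toy tensors (not repo code) to show the effect:

```python
import torch

def smoothness(x):
    # Mean |pixel - neighbour| over the 8-neighbourhood, as in umwe2vae.loss.
    m = x[:, :, 1:-1, 1:-1]
    shifts = [x[:, :, i:i + m.shape[2], j:j + m.shape[3]]
              for i in (0, 1, 2) for j in (0, 1, 2) if (i, j) != (1, 1)]
    return torch.mean(sum(torch.abs(m - s) for s in shifts) / 8.)

flat = torch.zeros(1, 1, 8, 8)     # smooth and fully saturated
noisy = torch.rand(1, 1, 8, 8)     # rough and mostly mid-gray
print(smoothness(flat), smoothness(noisy))                          # low vs. high
print(torch.abs(flat - 0.5).mean(), torch.abs(noisy - 0.5).mean())  # high vs. ~0.25
```
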
/src/umwe_mappers/best_mapping_es2en.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/umwe_mappers/best_mapping_es2en.t7
--------------------------------------------------------------------------------
/src/umwe_mappers/best_mapping_fr2en.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/umwe_mappers/best_mapping_fr2en.t7
--------------------------------------------------------------------------------
/src/umwe_mappers/best_mapping_it2en.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/src/umwe_mappers/best_mapping_it2en.t7
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 |
3 |
4 | ## Utils to handle newer PyTorch Lightning changes from version 0.6
5 | ## ==================================================================================================== ##
6 |
7 |
8 | def data_loader(fn):
9 | """
10 | Decorator to handle the deprecation of data_loader from 0.7
11 | :param fn: User defined data loader function
12 | :return: A wrapper for the data_loader function
13 | """
14 |
15 | def func_wrapper(self):
16 | try: # Works for version 0.6.0
17 | return pl.data_loader(fn)(self)
18 |
19 | except: # Works for version > 0.6.0
20 | return fn(self)
21 |
22 | return func_wrapper
23 |
--------------------------------------------------------------------------------
/src/vae_with_norm.py:
--------------------------------------------------------------------------------
1 | from Omniglot import Omniglot
2 | import yaml
3 | import argparse
4 | import numpy as np
5 |
6 | from models import *
7 | from experiment import VAEXperiment
8 | import torch.backends.cudnn as cudnn
9 | from pytorch_lightning import Trainer
10 | from pytorch_lightning.logging import TestTubeLogger
11 | import torchvision.utils as vutils
12 | from torch.utils.data import DataLoader
13 | import torch.optim as optim
14 | from torchvision import transforms
15 |
16 | import pickle
17 | import random
18 |
19 | # import fasttext
20 | # import fasttext.util
21 |
22 |
23 |
24 |
25 | def get_dataLoader(batch_size):
26 | transform = data_transforms()
27 | dataset = Omniglot(split="train", transform=transform)
28 | num_train_imgs = len(dataset)
29 | return DataLoader(dataset,
30 | batch_size= batch_size,
31 | shuffle = True,
32 | drop_last=True)
33 |
34 |
35 | def data_transforms():
36 | SetRange = transforms.Lambda(lambda X: 2 * X - 1.)
37 | SetScale = transforms.Lambda(lambda X: X/X.sum(0).expand_as(X))
38 | transform = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()])
39 | return transform
40 |
41 |
42 | def randomSample_from_embeddings(batch_size, fr=""):
43 | with open("/content/drive/My Drive/vae/logogram-language-generator-master/sample_fasttext_embs"+fr+".pickle", "rb") as f:
44 | random_sample = pickle.load(f)
45 |
46 | fake_batch = random.sample(list(random_sample.values()), batch_size)
47 | fake_batch = torch.stack(fake_batch, dim=0)
48 | return fake_batch
49 |
50 |
51 |
52 |
53 |
54 | def trainer(vae_model, mapper, netD, batch_size, device):
55 | # Training Loop
56 | # Lists to keep track of progress
57 | img_list = []
58 | dataloader = get_dataLoader(batch_size=batch_size)
59 |
60 | # Create batch of latent vectors that we will use to visualize
61 | # the progression of the generator
62 | fixed_noise = randomSample_from_embeddings(batch_size).to(device)
63 |     # GAN-loop setup; the epoch count and optimizer settings below are assumed values
64 |     num_epochs, iters, real_label, fake_label = 5, 0, 1., 0.
65 |     criterion, G_losses, D_losses = torch.nn.BCELoss(), [], []
66 |     optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999)); optimizerG = optim.Adam(mapper.parameters(), lr=2e-4, betas=(0.5, 0.999))
67 | print("Starting Training Loop...")
68 | # For each epoch
69 | for epoch in range(num_epochs):
70 | # For each batch in the dataloader
71 | for i, data in enumerate(dataloader, 0):
72 |
73 | ############################
74 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
75 | ###########################
76 | ## Train with all-real batch
77 | netD.zero_grad()
78 | # Format batch
79 | real_cpu = data[0].to(device)
80 | b_size = real_cpu.size(0)
81 | label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
82 | # Forward pass real batch through D
83 | output = netD(real_cpu).view(-1)
84 | # Calculate loss on all-real batch
85 | errD_real = criterion(output, label)
86 | # Calculate gradients for D in backward pass
87 | errD_real.backward()
88 | D_x = output.mean().item()
89 |
90 | ## Train with all-fake batch
91 | # Generate batch of latent vectors
92 | embeddings = randomSample_from_embeddings(batch_size).to(device)
93 | # Generate fake image batch with G
94 | fake = mapper(embeddings).to(device)
95 | fake = vae_model.decode(fake)
96 |
97 | label.fill_(fake_label)
98 | # Classify all fake batch with D
99 | output = netD(fake.detach()).view(-1)
100 | # Calculate D's loss on the all-fake batch
101 | errD_fake = criterion(output, label)
102 | # Calculate the gradients for this batch
103 | errD_fake.backward()
104 | D_G_z1 = output.mean().item()
105 | # Add the gradients from the all-real and all-fake batches
106 | errD = errD_real + errD_fake
107 | # Update D
108 | optimizerD.step()
109 |
110 | ############################
111 | # (2) Update G network: maximize log(D(G(z)))
112 | ###########################
113 | mapper.zero_grad()
114 | label.fill_(real_label) # fake labels are real for generator cost
115 | # Since we just updated D, perform another forward pass of all-fake batch through D
116 | output = netD(fake).view(-1)
117 | # Calculate G's loss based on this output
118 | errG = criterion(output, label)
119 | # Calculate gradients for G
120 | errG.backward()
121 | D_G_z2 = output.mean().item()
122 | # Update G
123 | optimizerG.step()
124 |
125 | # Output training stats
126 | if i % 50 == 0:
127 | print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
128 | % (epoch, num_epochs, i, len(dataloader),
129 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
130 | randomsamples = []
131 | # for i in range(64):
132 |
133 | with torch.no_grad():
134 | mapped = mapper(torch.randn(64,300).to(device))
135 | randomsamples.append(mapped.to(device))
136 |
137 | embed_samples = []
138 | # for i in range(64):
139 | one_embed = randomSample_from_embeddings(64).to(device)
140 | with torch.no_grad():
141 | mapped_embed = mapper(one_embed).to(device)
142 | embed_samples.append(mapped_embed)
143 |
144 |
145 |
146 | randomsamples = torch.stack(randomsamples, dim=0)
147 | recons_randoms = vae_model.decode(randomsamples)
148 |
149 | embedsamples = torch.stack(embed_samples, dim=0)
150 | recons_embeds = vae_model.decode(embedsamples)
151 |
152 |
153 |
154 | vutils.save_image(recons_randoms.data,
155 | f"./vae_with_disc/test_random{i}.png",
156 | normalize=True,
157 | nrow=12)
158 | vutils.save_image(recons_embeds.data,
159 | f"./vae_with_disc/test_embed{i}.png",
160 | normalize=True,
161 | nrow=12)
162 |
163 | # Save Losses for plotting later
164 | G_losses.append(errG.item())
165 | D_losses.append(errD.item())
166 |
167 | # Check how the generator is doing by saving G's output on fixed_noise
168 | if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
169 | with torch.no_grad():
170 | fake = mapper(fixed_noise).detach().to(device)
171 | img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
172 | # vutils.save_image(img_list,
173 | # f"./vae_with_disc/samples_{i}.png",
174 | # normalize=True,
175 | # nrow=12)
176 |
177 | iters += 1
178 | return vae_model, mapper, netD
179 |
180 | def load_vae_model(vae_checkpointPath, config):
181 | model = vae_models[config['model_params']['name']](**config['model_params'])
182 | experiment = VAEXperiment(model,
183 | config['exp_params'])
184 |
185 |
186 | checkpoint = torch.load(vae_checkpointPath, map_location=lambda storage, loc: storage)
187 | new_ckpoint = {}
188 | for k in checkpoint["state_dict"].keys():
189 | newKey = k.split("model.")[1]
190 | new_ckpoint[newKey] = checkpoint["state_dict"][k]
191 |
192 | model.load_state_dict(new_ckpoint)
193 | model.eval()
194 | return model
195 |
196 |
197 | def main():
198 | print("on main")
199 |
200 |
201 | with open("./configs/bbvae_setup2.yaml", 'r') as file:
202 | try:
203 | config = yaml.safe_load(file)
204 | except yaml.YAMLError as exc:
205 | print(exc)
206 |
207 | vae_checkpointPath = "logs/BetaVAE_B_setup2_run2/final_model_checkpoint.ckpt"
208 | batch_size = 64
209 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
210 | print(device)
211 |
212 | vae_model = load_vae_model(vae_checkpointPath, config).to(device)
213 |
214 | # fixed_noise = randomSample_from_embeddings(batch_size).to(device)
215 | # fixed_noise = torch.randn()
216 |
217 | best_mapping = "./best_mapping_fr2en.t7"
218 | Readed_t7 = torch.from_numpy(torch.load(best_mapping)).to(device)
219 |
220 | with open("/content/drive/My Drive/vae/logogram-language-generator-master/fasttext_hello_world_fr.pickle", "rb") as f:
221 | helloworld = pickle.load(f)
222 |
223 | # helloworld = random.sample(list(helloworld.values()), batch_size)
224 | hello = torch.from_numpy(helloworld["hello"]).to(device)
225 | world = torch.from_numpy(helloworld["world"]).to(device)
226 | # helloworld = torch.stack([hello, world], dim=0).to(device)
227 |
228 | print(type(hello))
229 | print(type(Readed_t7))
230 | mapped_hello = torch.matmul(hello, Readed_t7)
231 | mapped_world = torch.matmul(world, Readed_t7)
232 |
233 |
234 |
235 |
236 | sample_embeds = randomSample_from_embeddings(batch_size, "_fr").to(device)
237 |
238 | std, mean = torch.std_mean(input=sample_embeds, dim=0, unbiased=True)
239 | # norm_func = transforms.Normalize(mean, std)
240 | ## norm here
241 | # normalized = norm_func(hello)
242 |
243 | normalized_hello = (mapped_hello - mean) / std
244 | normalized_world = (mapped_world - mean) / std
245 |
246 | recons_embeds = vae_model.decode(normalized_hello[:128])
247 | vutils.save_image(recons_embeds.data,
248 | "./vae_with_disc/hello_normalized2_fr.png",
249 | normalize=True,
250 | nrow=12)
251 | recons_embeds = vae_model.decode(normalized_world[:128])
252 | vutils.save_image(recons_embeds.data,
253 | "./vae_with_disc/World_normalized2_fr.png",
254 | normalize=True,
255 | nrow=12)
256 |
257 |
258 |
259 | print("ALL DONE!")
260 |
261 |
262 |
263 |
264 |
265 | if __name__ == "__main__":
266 | main()
--------------------------------------------------------------------------------
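The essential transformation in `main()` above is easy to lose in the plumbing: a pre-trained UMWE mapping matrix projects a foreign fastText vector into the shared (English) space via a matrix product, and the result is standardized with statistics computed from a sample of embeddings before decoding. A minimal sketch with placeholder tensors standing in for the real pickles and checkpoints:

```python
import torch

# Placeholders standing in for the artifacts loaded in main():
mapping = torch.randn(300, 300)  # in main(): torch.from_numpy(torch.load("best_mapping_fr2en.t7"))
word_vec = torch.randn(300)      # a fastText vector for one French word
sample = torch.randn(1000, 300)  # a sample of fastText embeddings in the same space

mapped = torch.matmul(word_vec, mapping)   # project fr -> en space
std, mean = torch.std_mean(sample, dim=0)  # per-dimension statistics of the sample
normalized = (mapped - mean) / std         # standardize before decoding
# vae_model.decode(normalized) would then render the corresponding logogram
```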
/umwe/README.md:
--------------------------------------------------------------------------------
1 | ## Unsupervised Multilingual Word Embeddings
2 | This repo contains the source code for our paper:
3 |
4 | [**Unsupervised Multilingual Word Embeddings**](http://aclweb.org/anthology/D18-1024)
5 |
6 | [Xilun Chen](http://www.cs.cornell.edu/~xlchen/),
7 | [Claire Cardie](http://www.cs.cornell.edu/home/cardie/)
8 |
9 | EMNLP 2018
10 |
11 | [paper](http://aclweb.org/anthology/D18-1024),
12 | [bibtex](http://aclweb.org/anthology/D18-1024.bib)
13 |
14 | 
15 |
16 | ## Dependencies
17 | * Python 3.6 with [NumPy](http://www.numpy.org/)/[SciPy](https://www.scipy.org/)
18 | * [PyTorch](http://pytorch.org/) 0.4
19 | * [Faiss](https://github.com/facebookresearch/faiss) (recommended) for fast nearest neighbor search (CPU or GPU).
20 |
21 | Faiss is *optional* for GPU users (though Faiss-GPU will greatly speed up nearest neighbor search) and *highly recommended* for CPU users. Faiss can be installed with `conda install faiss-cpu -c pytorch` or `conda install faiss-gpu -c pytorch`.
22 |
23 | ## Get evaluation datasets
24 | The evaluation datasets from [MUSE](https://github.com/facebookresearch/MUSE) can be downloaded by simply running (in data/):
25 |
26 | ```bash
27 | ./get_evaluation.sh
28 | ```
29 | *Note: Requires bash 4. The download of Europarl is disabled by default (it is slow); you can enable it [here](https://github.com/facebookresearch/MUSE/blob/master/data/get_evaluation.sh#L99-L100).*
30 |
31 | ## Get monolingual word embeddings
32 | For pre-trained monolingual word embeddings, we adopt the same [fastText Wikipedia embeddings](https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md) recommended by the MUSE authors.
33 |
34 | For example, you can download the English (en) embeddings this way:
35 | ```bash
36 | # English fastText Wikipedia embeddings
37 | curl -Lo data/fasttext-vectors/wiki.en.vec https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.vec
38 | ```
39 |
40 | ## Learn unsupervised multilingual word embeddings
41 | ```bash
42 | python unsupervised.py --src_langs de fr es it pt --tgt_lang en
43 | ```
44 |
45 | ## Learn supervised multilingual word embeddings
46 | This was not explored in the paper, but our MPSR method can in principle be applied to the supervised setting as well.
47 | No adversarial training will take place, and the Procrustes method in MUSE is replaced by our MPSR algorithm to learn multilingual embeddings.
48 | For example, the following command can be used to train (weakly) supervised UMWEs with identical character strings serving as supervision.
49 | ```bash
50 | python supervised.py --src_langs de fr es it pt --tgt_lang en --dico_train identical_char
51 | ```
52 |
53 | ## Evaluate cross-lingual embeddings (CPU|GPU)
54 | To save time, not all language pairs are evaluated during training.
55 | You may run the evaluation script to perform a full evaluation.
56 | The test results are saved in `evaluate.log` in the model folder.
57 | ```bash
58 | python evaluate.py --src_langs de fr es it pt --tgt_lang en --eval_pairs all --exp_id [your exp_id]
59 | ```
60 |
61 | ## Word embedding format
62 | By default, the aligned embeddings are exported to a text format at the end of experiments: `--export txt`. Exporting embeddings to a text file can take a while if you have a lot of embeddings. For a very fast export, you can set `--export pth` to export the embeddings in a PyTorch binary file, or simply disable the export (`--export ""`).
63 |
64 | When loading embeddings, the model can load:
65 | * PyTorch binary files previously generated by UMWE (.pth files)
66 | * fastText binary files previously generated by fastText (.bin files)
67 | * text files (one word embedding per line)
68 |
69 | The first two options are very fast and can load 1 million embeddings in a few seconds, while loading text files can take a while.
70 |
71 | ## License
72 |
73 | This work is developed based on [MUSE](https://github.com/facebookresearch/MUSE) by Alexis Conneau, Guillaume Lample, et al. from Facebook Inc., used under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/).
74 |
75 | This work is licensed under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) by Xilun Chen.
76 |
--------------------------------------------------------------------------------
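As a quick illustration of the export options described in the README above (under "Word embedding format"), combined with the training command it gives:

```bash
# export the aligned embeddings as PyTorch binaries instead of text (much faster)
python unsupervised.py --src_langs de fr es it pt --tgt_lang en --export pth

# or disable the export entirely
python unsupervised.py --src_langs de fr es it pt --tgt_lang en --export ""
```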
/umwe/data/get_evaluation.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | en_analogy='https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/word2vec/source-archive.zip'
9 | aws_path='https://s3.amazonaws.com/arrival'
10 | semeval_2017='http://alt.qcri.org/semeval2017/task2/data/uploads'
11 | europarl='http://www.statmt.org/europarl/v7'
12 |
13 | declare -A wordsim_lg
14 | wordsim_lg=(["en"]="EN_MC-30.txt EN_MTurk-287.txt EN_RG-65.txt EN_VERB-143.txt EN_WS-353-REL.txt EN_YP-130.txt EN_MEN-TR-3k.txt EN_MTurk-771.txt EN_RW-STANFORD.txt EN_SIMLEX-999.txt EN_WS-353-ALL.txt EN_WS-353-SIM.txt" ["es"]="ES_MC-30.txt ES_RG-65.txt ES_WS-353.txt" ["de"]="DE_GUR350.txt DE_GUR65.txt DE_SIMLEX-999.txt DE_WS-353.txt DE_ZG222.txt" ["fr"]="FR_RG-65.txt" ["it"]="IT_SIMLEX-999.txt IT_WS-353.txt")
15 |
16 | mkdir monolingual crosslingual
17 |
18 | ## English word analogy task
19 | curl -Lo source-archive.zip $en_analogy
20 | mkdir -p monolingual/en/
21 | unzip -p source-archive.zip word2vec/trunk/questions-words.txt > monolingual/en/questions-words.txt
22 | rm source-archive.zip
23 |
24 |
25 | ## Downloading en-{} or {}-en dictionaries
26 | lgs="af ar bg bn bs ca cs da de el en es et fa fi fr he hi hr hu id it ja ko lt lv mk ms nl no pl pt ro ru sk sl sq sv ta th tl tr uk vi zh"
27 | mkdir -p crosslingual/dictionaries/
28 | for lg in ${lgs}
29 | do
30 | for suffix in .txt .0-5000.txt .5000-6500.txt
31 | do
32 | fname=en-$lg$suffix
33 | curl -Lo crosslingual/dictionaries/$fname $aws_path/dictionaries/$fname
34 | fname=$lg-en$suffix
35 | curl -Lo crosslingual/dictionaries/$fname $aws_path/dictionaries/$fname
36 | done
37 | done
38 |
39 | ## Download European dictionaries
40 | for src_lg in de es fr it pt
41 | do
42 | for tgt_lg in de es fr it pt
43 | do
44 | if [ $src_lg != $tgt_lg ]
45 | then
46 | for suffix in .txt .0-5000.txt .5000-6500.txt
47 | do
48 | fname=$src_lg-$tgt_lg$suffix
49 | curl -Lo crosslingual/dictionaries/$fname $aws_path/dictionaries/european/$fname
50 | done
51 | fi
52 | done
53 | done
54 |
55 | ## Download Dinu et al. dictionaries
56 | for fname in OPUS_en_it_europarl_train_5K.txt OPUS_en_it_europarl_test.txt
57 | do
58 | echo $fname
59 | curl -Lo crosslingual/dictionaries/$fname $aws_path/dictionaries/$fname
60 | done
61 |
62 | ## Monolingual wordsim tasks
63 | for lang in "${!wordsim_lg[@]}"
64 | do
65 | echo $lang
66 | mkdir monolingual/$lang
67 | for wsim in ${wordsim_lg[$lang]}
68 | do
69 | echo $wsim
70 | curl -Lo monolingual/$lang/$wsim $aws_path/$lang/$wsim
71 | done
72 | done
73 |
74 | ## SemEval 2017 monolingual and cross-lingual wordsim tasks
75 | # 1) Task1: monolingual
76 | curl -Lo semeval2017-task2.zip $semeval_2017/semeval2017-task2.zip
77 | unzip semeval2017-task2.zip
78 |
79 | fdir='SemEval17-Task2/test/subtask1-monolingual'
80 | for lang in en es de fa it
81 | do
82 | mkdir -p monolingual/$lang
83 | uplang=`echo $lang | awk '{print toupper($0)}'`
84 | paste $fdir/data/$lang.test.data.txt $fdir/keys/$lang.test.gold.txt > monolingual/$lang/${uplang}_SEMEVAL17.txt
85 | done
86 |
87 | # 2) Task2: cross-lingual
88 | mkdir -p crosslingual/wordsim
89 | fdir='SemEval17-Task2/test/subtask2-crosslingual'
90 | for lg_pair in de-es de-fa de-it en-de en-es en-fa en-it es-fa es-it it-fa
91 | do
92 | echo $lg_pair
93 | paste $fdir/data/$lg_pair.test.data.txt $fdir/keys/$lg_pair.test.gold.txt > crosslingual/wordsim/$lg_pair-SEMEVAL17.txt
94 | done
95 | rm semeval2017-task2.zip
96 | rm -r SemEval17-Task2/
97 |
98 | ## Europarl for sentence retrieval
99 | # TODO: set to true to activate download of Europarl (slow)
100 | if false; then
101 | mkdir -p crosslingual/europarl
102 | # Tokenize EUROPARL with MOSES
103 | echo 'Cloning Moses github repository (for tokenization scripts)...'
104 | git clone https://github.com/moses-smt/mosesdecoder.git
105 | SCRIPTS=mosesdecoder/scripts
106 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
107 |
108 | for lg_pair in it-en # es-en etc
109 | do
110 | curl -Lo $lg_pair.tgz $europarl/$lg_pair.tgz
111 | tar -xvf it-en.tgz
112 | rm it-en.tgz
113 | lgs=(${lg_pair//-/ })
114 | for lg in ${lgs[0]} ${lgs[1]}
115 | do
116 | cat europarl-v7.$lg_pair.$lg | $TOKENIZER -threads 8 -l $lg -no-escape > euro.$lg.txt
117 | rm europarl-v7.$lg_pair.$lg
118 | done
119 |
120 | paste euro.${lgs[0]}.txt euro.${lgs[1]}.txt | shuf > euro.paste.txt
121 | rm euro.${lgs[0]}.txt euro.${lgs[1]}.txt
122 |
123 | cut -f1 euro.paste.txt > crosslingual/europarl/europarl-v7.$lg_pair.${lgs[0]}
124 | cut -f2 euro.paste.txt > crosslingual/europarl/europarl-v7.$lg_pair.${lgs[1]}
125 | rm euro.paste.txt
126 | done
127 |
128 | rm -rf mosesdecoder
129 | fi
130 |
--------------------------------------------------------------------------------
/umwe/evaluate.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright (c) 2017-present, Facebook, Inc.
2 | # Modified work Copyright (c) 2018, Xilun Chen
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 |
9 | # python evaluate.py --src_langs de fr es it pt --tgt_lang en --eval_pairs all
10 |
11 | import os
12 | import argparse
13 | from collections import OrderedDict
14 |
15 | from src.utils import bool_flag, initialize_exp
16 | from src.models import build_model
17 | from src.trainer import Trainer
18 | from src.evaluation import Evaluator
19 |
20 | # default path to embeddings if not otherwise specified
21 | EMB_DIR = 'data/fasttext-vectors/'
22 |
23 |
24 | # main
25 | parser = argparse.ArgumentParser(description='Evaluation')
26 | parser.add_argument("--verbose", type=int, default=2, help="Verbose level (2:debug, 1:info, 0:warning)")
27 | parser.add_argument("--exp_path", type=str, default="", help="Where to store experiment logs and models")
28 | parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
29 | parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
30 | # parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
31 | parser.add_argument("--device", type=str, default="cuda", help="Run on GPU or CPU")
32 | # data
33 | parser.add_argument("--src_langs", type=str, nargs='+', default=[], help="Source languages")
34 | parser.add_argument("--tgt_lang", type=str, default="", help="Target language")
35 | # evaluation
36 | parser.add_argument("--eval_pairs", type=str, nargs='+', default=[], help="Language pairs to evaluate. e.g. ['en-de', 'de-fr']")
37 | parser.add_argument("--dico_eval", type=str, default="default", help="Path to evaluation dictionary")
38 | parser.add_argument("--dict_suffix", type=str, default="5000-6500.txt", help="suffix to use for word translation (0-5000.txt or 5000-6500.txt or txt)")
39 | parser.add_argument("--semeval_ignore_oov", type=bool_flag, default=True, help="Whether to ignore OOV in SEMEVAL evaluation (the original authors used True)")
40 | # reload pre-trained embeddings
41 | parser.add_argument("--src_embs", type=str, nargs='+', default=[], help="Reload source embeddings (should be in the same order as in src_langs)")
42 | parser.add_argument("--tgt_emb", type=str, default="", help="Reload target embeddings")
43 | parser.add_argument("--max_vocab", type=int, default=200000, help="Maximum vocabulary size (-1 to disable)")
44 | parser.add_argument("--emb_dim", type=int, default=300, help="Embedding dimension")
45 | parser.add_argument("--normalize_embeddings", type=str, default="", help="Normalize embeddings before training")
46 |
47 |
48 | # parse parameters
49 | params = parser.parse_args()
50 |
51 | # post-processing options
52 | params.src_N = len(params.src_langs)
53 | params.all_langs = params.src_langs + [params.tgt_lang]
54 | # load default embeddings
55 | if len(params.src_embs) == 0:
56 | params.src_embs = []
57 | for lang in params.src_langs:
58 | params.src_embs.append(os.path.join(EMB_DIR, f'wiki.{lang}.vec'))
59 | if len(params.tgt_emb) == 0:
60 | params.tgt_emb = os.path.join(EMB_DIR, f'wiki.{params.tgt_lang}.vec')
61 | # expand 'all' in eval_pairs
62 | if 'all' in params.eval_pairs:
63 | params.eval_pairs = []
64 | for lang1 in params.all_langs:
65 | for lang2 in params.all_langs:
66 | if lang1 != lang2:
67 | params.eval_pairs.append(f'{lang1}-{lang2}')
68 |
69 | # check parameters
70 | assert len(params.src_langs) > 0, "source language undefined"
71 | assert all([os.path.isfile(emb) for emb in params.src_embs])
72 | assert not params.tgt_lang or os.path.isfile(params.tgt_emb)
73 | assert params.dico_eval == 'default' or os.path.isfile(params.dico_eval)
74 |
75 | # build logger / model / trainer / evaluator
76 | logger = initialize_exp(params, dump_params=False, log_name='evaluate.log')
77 | embs, mappings, _ = build_model(params, False)
78 | trainer = Trainer(embs, mappings, None, params)
79 | trainer.reload_best()
80 | evaluator = Evaluator(trainer)
81 |
82 | # run evaluations
83 | to_log = OrderedDict({'n_iter': 0})
84 | all_wt = []
85 | evaluator.monolingual_wordsim(to_log)
86 | for eval_pair in params.eval_pairs:
87 | parts = eval_pair.split('-')
88 | assert len(parts) == 2, 'Invalid format for evaluation pairs.'
89 | src_lang, tgt_lang = parts[0], parts[1]
90 | logger.info(f'Evaluating language pair: {src_lang} - {tgt_lang}')
91 | evaluator.crosslingual_wordsim(to_log, src_lang=src_lang, tgt_lang=tgt_lang)
92 | evaluator.word_translation(to_log, src_lang=src_lang, tgt_lang=tgt_lang)
93 | all_wt.append(to_log[f'{src_lang}-{tgt_lang}_precision_at_1-csls_knn_10'])
94 | evaluator.sent_translation(to_log, src_lang=src_lang, tgt_lang=tgt_lang)
95 |
96 | logger.info(f"Overall Word Translation Precision@1 over {len(all_wt)} language pairs: {sum(all_wt)/len(all_wt)}")
97 |
--------------------------------------------------------------------------------
/umwe/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selimseker/logogram-language-generator/a7c80eede2dc18f678a960b59e03a250374eece2/umwe/src/__init__.py
--------------------------------------------------------------------------------
/umwe/src/dico_builder.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright (c) 2017-present, Facebook, Inc.
2 | # Modified work Copyright (c) 2018, Xilun Chen
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 |
9 | from logging import getLogger
10 | import torch
11 |
12 | from .utils import get_nn_avg_dist
13 |
14 |
15 | logger = getLogger()
16 |
17 |
18 | def get_candidates(emb1, emb2, params):
19 | """
20 | Get best translation pairs candidates.
21 | """
22 | bs = 128
23 |
24 | all_scores = []
25 | all_targets = []
26 |
27 | # number of source words to consider
28 | n_src = emb1.size(0)
29 | if params.dico_max_rank > 0 and not params.dico_method.startswith('invsm_beta_'):
30 | n_src = params.dico_max_rank
31 |
32 | # nearest neighbors
33 | if params.dico_method == 'nn':
34 |
35 | # for every source word
36 | for i in range(0, n_src, bs):
37 |
38 | # compute target words scores
39 | scores = emb2.mm(emb1[i:min(n_src, i + bs)].transpose(0, 1)).transpose(0, 1)
40 | best_scores, best_targets = scores.topk(2, dim=1, largest=True, sorted=True)
41 |
42 | # update scores / potential targets
43 | all_scores.append(best_scores.cpu())
44 | all_targets.append(best_targets.cpu())
45 |
46 | all_scores = torch.cat(all_scores, 0)
47 | all_targets = torch.cat(all_targets, 0)
48 |
49 | # inverted softmax
50 | elif params.dico_method.startswith('invsm_beta_'):
51 |
52 | beta = float(params.dico_method[len('invsm_beta_'):])
53 |
54 | # for every target word
55 | for i in range(0, emb2.size(0), bs):
56 |
57 | # compute source words scores
58 | scores = emb1.mm(emb2[i:i + bs].transpose(0, 1))
59 | scores.mul_(beta).exp_()
60 | scores.div_(scores.sum(0, keepdim=True).expand_as(scores))
61 |
62 | best_scores, best_targets = scores.topk(2, dim=1, largest=True, sorted=True)
63 |
64 | # update scores / potential targets
65 | all_scores.append(best_scores.cpu())
66 | all_targets.append((best_targets + i).cpu())
67 |
68 | all_scores = torch.cat(all_scores, 1)
69 | all_targets = torch.cat(all_targets, 1)
70 |
71 | all_scores, best_targets = all_scores.topk(2, dim=1, largest=True, sorted=True)
72 | all_targets = all_targets.gather(1, best_targets)
73 |
74 | # contextual dissimilarity measure
75 | elif params.dico_method.startswith('csls_knn_'):
76 |
77 | knn = params.dico_method[len('csls_knn_'):]
78 | assert knn.isdigit()
79 | knn = int(knn)
80 |
81 | # average distances to k nearest neighbors
82 | average_dist1 = torch.from_numpy(get_nn_avg_dist(emb2, emb1, knn))
83 | average_dist2 = torch.from_numpy(get_nn_avg_dist(emb1, emb2, knn))
84 | average_dist1 = average_dist1.type_as(emb1)
85 | average_dist2 = average_dist2.type_as(emb2)
86 |
87 | # for every source word
88 | for i in range(0, n_src, bs):
89 |
90 | # compute target words scores
91 | scores = emb2.mm(emb1[i:min(n_src, i + bs)].transpose(0, 1)).transpose(0, 1)
92 | scores.mul_(2)
93 | scores.sub_(average_dist1[i:min(n_src, i + bs)][:, None] + average_dist2[None, :])
94 | best_scores, best_targets = scores.topk(2, dim=1, largest=True, sorted=True)
95 |
96 | # update scores / potential targets
97 | all_scores.append(best_scores.cpu())
98 | all_targets.append(best_targets.cpu())
99 |
100 | all_scores = torch.cat(all_scores, 0)
101 | all_targets = torch.cat(all_targets, 0)
102 |
103 | all_pairs = torch.cat([
104 | torch.arange(0, all_targets.size(0)).long().unsqueeze(1),
105 | all_targets[:, 0].unsqueeze(1)
106 | ], 1)
107 |
108 | # sanity check
109 | assert all_scores.size() == all_pairs.size() == (n_src, 2)
110 |
111 | # sort pairs by score confidence
112 | diff = all_scores[:, 0] - all_scores[:, 1]
113 | reordered = diff.sort(0, descending=True)[1]
114 | all_scores = all_scores[reordered]
115 | all_pairs = all_pairs[reordered]
116 |
117 | # max dico words rank
118 | if params.dico_max_rank > 0:
119 | selected = all_pairs.max(1)[0] <= params.dico_max_rank
120 | mask = selected.unsqueeze(1).expand_as(all_scores).clone()
121 | all_scores = all_scores.masked_select(mask).view(-1, 2)
122 | all_pairs = all_pairs.masked_select(mask).view(-1, 2)
123 | if len(all_pairs) == 0:
124 | return []
125 |
126 | # max dico size
127 | if params.dico_max_size > 0:
128 | all_scores = all_scores[:params.dico_max_size]
129 | all_pairs = all_pairs[:params.dico_max_size]
130 |
131 | # min dico size
132 | diff = all_scores[:, 0] - all_scores[:, 1]
133 | if params.dico_min_size > 0:
134 | diff[:params.dico_min_size] = 1e9
135 |
136 | # confidence threshold
137 | if params.dico_threshold > 0:
138 | mask = diff > params.dico_threshold
139 | logger.info("Selected %i / %i pairs above the confidence threshold." % (mask.sum(), diff.size(0)))
140 | mask = mask.unsqueeze(1).expand_as(all_pairs).clone()
141 | all_pairs = all_pairs.masked_select(mask).view(-1, 2)
142 |
143 | return all_pairs
144 |
145 |
146 | def build_dictionary(src_emb, tgt_emb, params, s2t_candidates=None, t2s_candidates=None):
147 | """
148 | Build a training dictionary given current embeddings / mapping.
149 | """
150 | logger.info("Building the train dictionary ...")
151 | s2t = 'S2T' in params.dico_build
152 | t2s = 'T2S' in params.dico_build
153 | assert s2t or t2s
154 |
155 | if s2t:
156 | if s2t_candidates is None:
157 | s2t_candidates = get_candidates(src_emb, tgt_emb, params)
158 | if t2s:
159 | if t2s_candidates is None:
160 | t2s_candidates = get_candidates(tgt_emb, src_emb, params)
161 | t2s_candidates = torch.cat([t2s_candidates[:, 1:], t2s_candidates[:, :1]], 1)
162 |
163 | if params.dico_build == 'S2T':
164 | dico = s2t_candidates
165 | elif params.dico_build == 'T2S':
166 | dico = t2s_candidates
167 | else:
168 | s2t_candidates = set([(a, b) for a, b in s2t_candidates.numpy()])
169 | t2s_candidates = set([(a, b) for a, b in t2s_candidates.numpy()])
170 | if params.dico_build == 'S2T|T2S':
171 | final_pairs = s2t_candidates | t2s_candidates
172 | else:
173 | assert params.dico_build == 'S2T&T2S'
174 | final_pairs = s2t_candidates & t2s_candidates
175 | if len(final_pairs) == 0:
176 | logger.warning("Empty intersection ...")
177 | return None
178 | dico = torch.LongTensor(list([[int(a), int(b)] for (a, b) in final_pairs]))
179 |
180 | if len(dico) == 0:
181 | return None
182 | logger.info('New train dictionary of %i pairs.' % dico.size(0))
183 | return dico.to(params.device)
184 |
--------------------------------------------------------------------------------
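The `csls_knn_` branch of `get_candidates` above implements CSLS scoring: the cosine similarity is doubled and penalized by each word's average similarity to its k nearest neighbors in the other language, which discourages hub words. A toy, non-batched restatement of that rule (not the memory-efficient implementation above), assuming unit-normalized embeddings:

```python
import torch

def csls_scores(emb1, emb2, knn=10):
    """Toy CSLS: score(x, y) = 2*cos(x, y) - r2(x) - r1(y), on unit-norm rows."""
    cos = emb1 @ emb2.t()                         # pairwise cosine similarities
    r1 = cos.topk(knn, dim=0).values.mean(dim=0)  # avg sim of each target word to its knn sources
    r2 = cos.topk(knn, dim=1).values.mean(dim=1)  # avg sim of each source word to its knn targets
    return 2 * cos - r2[:, None] - r1[None, :]

# usage: best translation candidate per source word
emb1 = torch.nn.functional.normalize(torch.randn(50, 300), dim=1)
emb2 = torch.nn.functional.normalize(torch.randn(60, 300), dim=1)
best = csls_scores(emb1, emb2).argmax(dim=1)
```

Here `r2` plays the role of `average_dist1` (per source word) and `r1` the role of `average_dist2` (per target word) in the batched code above.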
/umwe/src/dictionary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from logging import getLogger
9 |
10 |
11 | logger = getLogger()
12 |
13 |
14 | class Dictionary(object):
15 |
16 | def __init__(self, id2word, word2id, lang):
17 | assert len(id2word) == len(word2id)
18 | self.id2word = id2word
19 | self.word2id = word2id
20 | self.lang = lang
21 | self.check_valid()
22 |
23 | def __len__(self):
24 | """
25 | Returns the number of words in the dictionary.
26 | """
27 | return len(self.id2word)
28 |
29 | def __getitem__(self, i):
30 | """
31 | Returns the word of the specified index.
32 | """
33 | return self.id2word[i]
34 |
35 | def __contains__(self, w):
36 | """
37 | Returns whether a word is in the dictionary.
38 | """
39 | return w in self.word2id
40 |
41 | def __eq__(self, y):
42 | """
43 | Compare the dictionary with another one.
44 | """
45 | self.check_valid()
46 | y.check_valid()
47 | if len(self.id2word) != len(y):
48 | return False
49 | return self.lang == y.lang and all(self.id2word[i] == y[i] for i in range(len(y)))
50 |
51 | def check_valid(self):
52 | """
53 | Check that the dictionary is valid.
54 | """
55 | assert len(self.id2word) == len(self.word2id)
56 | for i in range(len(self.id2word)):
57 | assert self.word2id[self.id2word[i]] == i
58 |
59 | def index(self, word):
60 | """
61 | Returns the index of the specified word.
62 | """
63 | return self.word2id[word]
64 |
65 | def prune(self, max_vocab):
66 | """
67 | Limit the vocabulary size.
68 | """
69 | assert max_vocab >= 1
70 | self.id2word = {k: v for k, v in self.id2word.items() if k < max_vocab}
71 | self.word2id = {v: k for k, v in self.id2word.items()}
72 | self.check_valid()
73 |
--------------------------------------------------------------------------------
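A minimal usage sketch for the `Dictionary` class (toy data, assuming the umwe repo root is on `PYTHONPATH`); since fastText vocabularies are frequency-ordered, lower ids correspond to more frequent words:

```python
from src.dictionary import Dictionary

id2word = {0: "the", 1: "cat", 2: "sat"}
word2id = {w: i for i, w in id2word.items()}
dico = Dictionary(id2word, word2id, "en")

print(len(dico))          # 3
print(dico[1])            # "cat"
print("cat" in dico)      # True
dico.prune(max_vocab=2)   # keep only the 2 lowest (most frequent) ids
print(len(dico))          # 2
```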
/umwe/src/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .wordsim import get_wordsim_scores, get_crosslingual_wordsim_scores, get_wordanalogy_scores
2 | from .word_translation import get_word_translation_accuracy
3 | from .sent_translation import get_sent_translation_accuracy, load_europarl_data
4 | from .evaluator import Evaluator
5 |
--------------------------------------------------------------------------------
/umwe/src/evaluation/sent_translation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import os
9 | import io
10 | from logging import getLogger
11 | import numpy as np
12 | import torch
13 |
14 | from src.utils import bow_idf, get_nn_avg_dist
15 |
16 |
17 | EUROPARL_DIR = 'data/crosslingual/europarl'
18 |
19 |
20 | logger = getLogger()
21 |
22 |
23 | def load_europarl_data(lg1, lg2, n_max=1e10, lower=True):
24 | """
25 |     Load parallel sentence data from Europarl.
26 | """
27 | if not (os.path.isfile(os.path.join(EUROPARL_DIR, 'europarl-v7.%s-%s.%s' % (lg1, lg2, lg1))) or
28 | os.path.isfile(os.path.join(EUROPARL_DIR, 'europarl-v7.%s-%s.%s' % (lg2, lg1, lg1)))):
29 | return None
30 |
31 | if os.path.isfile(os.path.join(EUROPARL_DIR, 'europarl-v7.%s-%s.%s' % (lg2, lg1, lg1))):
32 | lg1, lg2 = lg2, lg1
33 |
34 | # load sentences
35 | data = {lg1: [], lg2: []}
36 | for lg in [lg1, lg2]:
37 | fname = os.path.join(EUROPARL_DIR, 'europarl-v7.%s-%s.%s' % (lg1, lg2, lg))
38 |
39 | with io.open(fname, 'r', encoding='utf-8') as f:
40 | for i, line in enumerate(f):
41 | if i >= n_max:
42 | break
43 | line = line.lower() if lower else line
44 | data[lg].append(line.rstrip().split())
45 |
46 | # get only unique sentences for each language
47 | assert len(data[lg1]) == len(data[lg2])
48 | data[lg1] = np.array(data[lg1])
49 | data[lg2] = np.array(data[lg2])
50 | data[lg1], indices = np.unique(data[lg1], return_index=True)
51 | data[lg2] = data[lg2][indices]
52 | data[lg2], indices = np.unique(data[lg2], return_index=True)
53 | data[lg1] = data[lg1][indices]
54 |
55 | # shuffle sentences
56 | rng = np.random.RandomState(1234)
57 | perm = rng.permutation(len(data[lg1]))
58 | data[lg1] = data[lg1][perm]
59 | data[lg2] = data[lg2][perm]
60 |
61 | logger.info("Loaded europarl %s-%s (%i sentences)." % (lg1, lg2, len(data[lg1])))
62 | return data
63 |
64 |
65 | def get_sent_translation_accuracy(data, lg1, word2id1, emb1, lg2, word2id2, emb2,
66 | n_keys, n_queries, method, idf):
67 |
68 | """
69 | Given parallel sentences from Europarl, evaluate the
70 | sentence translation accuracy using the precision@k.
71 | """
72 | # get word vectors dictionaries
73 | emb1 = emb1.cpu().detach().numpy()
74 | emb2 = emb2.cpu().detach().numpy()
75 | word_vec1 = dict([(w, emb1[word2id1[w]]) for w in word2id1])
76 | word_vec2 = dict([(w, emb2[word2id2[w]]) for w in word2id2])
77 | word_vect = {lg1: word_vec1, lg2: word_vec2}
78 | lg_keys = lg2
79 | lg_query = lg1
80 |
81 | # get n_keys pairs of sentences
82 | keys = data[lg_keys][:n_keys]
83 | keys = bow_idf(keys, word_vect[lg_keys], idf_dict=idf[lg_keys])
84 |
85 | # get n_queries query pairs from these n_keys pairs
86 | rng = np.random.RandomState(1234)
87 | idx_query = rng.choice(range(n_keys), size=n_queries, replace=False)
88 | queries = data[lg_query][idx_query]
89 | queries = bow_idf(queries, word_vect[lg_query], idf_dict=idf[lg_query])
90 |
91 | # normalize embeddings
92 | queries = torch.from_numpy(queries).float()
93 | queries = queries / queries.norm(2, 1, keepdim=True).expand_as(queries)
94 | keys = torch.from_numpy(keys).float()
95 | keys = keys / keys.norm(2, 1, keepdim=True).expand_as(keys)
96 |
97 | # nearest neighbors
98 | if method == 'nn':
99 | scores = keys.mm(queries.transpose(0, 1)).transpose(0, 1)
100 | scores = scores.cpu()
101 |
102 | # inverted softmax
103 | elif method.startswith('invsm_beta_'):
104 | beta = float(method[len('invsm_beta_'):])
105 | scores = keys.mm(queries.transpose(0, 1)).transpose(0, 1)
106 | scores.mul_(beta).exp_()
107 | scores.div_(scores.sum(0, keepdim=True).expand_as(scores))
108 | scores = scores.cpu()
109 |
110 | # contextual dissimilarity measure
111 | elif method.startswith('csls_knn_'):
112 | knn = method[len('csls_knn_'):]
113 | assert knn.isdigit()
114 | knn = int(knn)
115 | # average distances to k nearest neighbors
116 |
117 |
118 |
119 | average_dist_keys = torch.from_numpy(get_nn_avg_dist(queries, keys, knn))
120 | average_dist_queries = torch.from_numpy(get_nn_avg_dist(keys, queries, knn))
121 | # scores
122 | scores = keys.mm(queries.transpose(0, 1)).transpose(0, 1)
123 | scores.mul_(2)
124 | scores.sub_(average_dist_queries[:, None].float() + average_dist_keys[None, :].float())
125 | scores = scores.cpu()
126 |
127 | results = []
128 | top_matches = scores.topk(10, 1, True)[1]
129 | for k in [1, 5, 10]:
130 | top_k_matches = (top_matches[:, :k] == torch.from_numpy(idx_query)[:, None]).sum(1)
131 | precision_at_k = 100 * top_k_matches.float().numpy().mean()
132 | logger.info("%i queries (%s) - %s - Precision at k = %i: %f" %
133 | (len(top_k_matches), lg_query.upper(), method, k, precision_at_k))
134 | results.append(('sent-precision_at_%i' % k, precision_at_k))
135 |
136 | return results
137 |
--------------------------------------------------------------------------------
/umwe/src/evaluation/word_translation.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright (c) 2017-present, Facebook, Inc.
2 | # Modified work Copyright (c) 2018, Xilun Chen
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 |
9 | import os
10 | import io
11 | from logging import getLogger
12 | import numpy as np
13 | import torch
14 |
15 | from ..utils import get_nn_avg_dist
16 |
17 |
18 | DIC_EVAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'data', 'crosslingual', 'dictionaries')
19 |
20 |
21 | logger = getLogger()
22 |
23 |
24 | def load_identical_char_dico(word2id1, word2id2):
25 | """
26 | Build a dictionary of identical character strings.
27 | """
28 | pairs = [(w1, w1) for w1 in word2id1.keys() if w1 in word2id2]
29 | if len(pairs) == 0:
30 | raise Exception("No identical character strings were found. "
31 | "Please specify a dictionary.")
32 |
33 | logger.info("Found %i pairs of identical character strings." % len(pairs))
34 |
35 | # sort the dictionary by source word frequencies
36 | pairs = sorted(pairs, key=lambda x: word2id1[x[0]])
37 | dico = torch.LongTensor(len(pairs), 2)
38 | for i, (word1, word2) in enumerate(pairs):
39 | dico[i, 0] = word2id1[word1]
40 | dico[i, 1] = word2id2[word2]
41 |
42 | return dico
43 |
44 |
45 | def load_dictionary(path, word2id1, word2id2):
46 | """
47 | Return a torch tensor of size (n, 2) where n is the size of the
48 | loader dictionary, and sort it by source word frequency.
49 | """
50 | assert os.path.isfile(path)
51 |
52 | pairs = []
53 | not_found = 0
54 | not_found1 = 0
55 | not_found2 = 0
56 |
57 | with io.open(path, 'r', encoding='utf-8') as f:
58 | for _, line in enumerate(f):
59 | assert line == line.lower()
60 | word1, word2 = line.rstrip().split()
61 | if word1 in word2id1 and word2 in word2id2:
62 | pairs.append((word1, word2))
63 | else:
64 | not_found += 1
65 | not_found1 += int(word1 not in word2id1)
66 | not_found2 += int(word2 not in word2id2)
67 |
68 | logger.info("Found %i pairs of words in %s (%i unique). "
69 | "%i other pairs contained at least one unknown word "
70 | "(%i in lang1, %i in lang2)"
71 | % (len(pairs), path, len(set([x for x, _ in pairs])),
72 | not_found, not_found1, not_found2))
73 |
74 | # sort the dictionary by source word frequencies
75 | pairs = sorted(pairs, key=lambda x: word2id1[x[0]])
76 | dico = torch.LongTensor(len(pairs), 2)
77 | for i, (word1, word2) in enumerate(pairs):
78 | dico[i, 0] = word2id1[word1]
79 | dico[i, 1] = word2id2[word2]
80 |
81 | return dico
82 |
83 |
84 | def get_word_translation_accuracy(lang1, word2id1, emb1, lang2, word2id2, emb2, method, dico_eval):
85 | """
86 | Given source and target word embeddings, and a dictionary,
87 | evaluate the translation accuracy using the precision@k.
88 | """
89 | if dico_eval == 'default':
90 | path = os.path.join(DIC_EVAL_PATH, '%s-%s.5000-6500.txt' % (lang1, lang2))
91 | elif dico_eval == 'train':
92 | path = os.path.join(DIC_EVAL_PATH, '%s-%s.0-5000.txt' % (lang1, lang2))
93 | elif dico_eval == 'all':
94 | path = os.path.join(DIC_EVAL_PATH, '%s-%s.txt' % (lang1, lang2))
95 | else:
96 |         raise NotImplementedError(dico_eval)
97 | dico = load_dictionary(path, word2id1, word2id2).to(emb1.device)
98 |
99 | assert dico[:, 0].max() < emb1.size(0)
100 | assert dico[:, 1].max() < emb2.size(0)
101 |
102 | # normalize word embeddings
103 | emb1 = emb1 / emb1.norm(2, 1, keepdim=True).expand_as(emb1)
104 | emb2 = emb2 / emb2.norm(2, 1, keepdim=True).expand_as(emb2)
105 |
106 | # nearest neighbors
107 | if method == 'nn':
108 | query = emb1[dico[:, 0]]
109 | scores = query.mm(emb2.transpose(0, 1))
110 |
111 | # inverted softmax
112 | elif method.startswith('invsm_beta_'):
113 | beta = float(method[len('invsm_beta_'):])
114 | bs = 128
115 | word_scores = []
116 | for i in range(0, emb2.size(0), bs):
117 | scores = emb1.mm(emb2[i:i + bs].transpose(0, 1))
118 | scores.mul_(beta).exp_()
119 | scores.div_(scores.sum(0, keepdim=True).expand_as(scores))
120 | word_scores.append(scores.index_select(0, dico[:, 0]))
121 | scores = torch.cat(word_scores, 1)
122 |
123 | # contextual dissimilarity measure
124 | elif method.startswith('csls_knn_'):
125 | # average distances to k nearest neighbors
126 | knn = method[len('csls_knn_'):]
127 | assert knn.isdigit()
128 | knn = int(knn)
129 | average_dist1 = get_nn_avg_dist(emb2, emb1, knn)
130 | average_dist2 = get_nn_avg_dist(emb1, emb2, knn)
131 | average_dist1 = torch.from_numpy(average_dist1).type_as(emb1)
132 | average_dist2 = torch.from_numpy(average_dist2).type_as(emb2)
133 | # queries / scores
134 | query = emb1[dico[:, 0]]
135 | scores = query.mm(emb2.transpose(0, 1))
136 | scores.mul_(2)
137 | scores.sub_(average_dist1[dico[:, 0]][:, None])
138 | scores.sub_(average_dist2[None, :])
139 |
140 | else:
141 | raise Exception('Unknown method: "%s"' % method)
142 |
143 | results = []
144 | top_matches = scores.topk(10, 1, True)[1]
145 | for k in [1, 5, 10]:
146 | top_k_matches = top_matches[:, :k]
147 | _matching = (top_k_matches == dico[:, 1][:, None].expand_as(top_k_matches)).sum(1)
148 | # allow for multiple possible translations
149 | matching = {}
150 | for i, src_id in enumerate(dico[:, 0].cpu().numpy()):
151 | matching[src_id] = min(matching.get(src_id, 0) + _matching[i], 1)
152 | # evaluate precision@k
153 | precision_at_k = 100 * np.mean(list(matching.values()))
154 | logger.info("%i source words - %s - Precision at k = %i: %f" %
155 | (len(matching), method, k, precision_at_k))
156 | results.append(('precision_at_%i' % k, precision_at_k))
157 |
158 | return results
159 |
--------------------------------------------------------------------------------
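One subtlety in `get_word_translation_accuracy` above: a source word counts as correct at k if *any* of its gold translations appears among its top-k candidates, which is what the per-source `min(... + _matching[i], 1)` aggregation achieves. A toy restatement with hypothetical word ids:

```python
import numpy as np

def precision_at_k(top_k, gold_pairs):
    """top_k: dict src_id -> predicted tgt ids (length k);
    gold_pairs: (src_id, tgt_id) tuples, possibly several per source."""
    hit = {}
    for src, tgt in gold_pairs:
        hit[src] = max(hit.get(src, 0), int(tgt in top_k[src]))  # any gold hit counts
    return 100 * np.mean(list(hit.values()))

# source 0 has two acceptable translations and one is in its top-2,
# source 1 misses entirely -> precision@2 = 50.0
print(precision_at_k({0: [5, 7], 1: [2, 9]}, [(0, 7), (0, 8), (1, 3)]))
```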
/umwe/src/evaluation/wordsim.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright (c) 2017-present, Facebook, Inc.
2 | # Modified work Copyright (c) 2018, Xilun Chen
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 |
9 | import os
10 | import io
11 | from logging import getLogger
12 | import numpy as np
13 | import torch
14 | from scipy.stats import spearmanr
15 |
16 |
17 | MONOLINGUAL_EVAL_PATH = 'data/monolingual'
18 | SEMEVAL17_EVAL_PATH = 'data/crosslingual/wordsim'
19 |
20 |
21 | logger = getLogger()
22 |
23 |
24 | def get_word_pairs(path, lower=True):
25 | """
26 | Return a list of (word1, word2, score) tuples from a word similarity file.
27 | """
28 | assert os.path.isfile(path) and type(lower) is bool
29 | word_pairs = []
30 | with io.open(path, 'r', encoding='utf-8') as f:
31 | for line in f:
32 | line = line.rstrip()
33 | line = line.lower() if lower else line
34 | line = line.split()
35 | # ignore phrases, only consider words
36 | if len(line) != 3:
37 | assert len(line) > 3
38 | assert 'SEMEVAL17' in os.path.basename(path) or 'EN-IT_MWS353' in path
39 | continue
40 | word_pairs.append((line[0], line[1], float(line[2])))
41 | return word_pairs
42 |
43 |
44 | def get_word_id(word, word2id, lower):
45 | """
46 | Get a word ID.
47 | If the model does not use lowercase and the evaluation file is lowercased,
48 | we might be able to find an associated word.
49 | """
50 | assert type(lower) is bool
51 | word_id = word2id.get(word)
52 | if word_id is None and not lower:
53 | word_id = word2id.get(word.capitalize())
54 | if word_id is None and not lower:
55 | word_id = word2id.get(word.title())
56 | return word_id
57 |
58 |
59 | def get_spearman_rho(word2id1, embeddings1, path, lower,
60 | word2id2=None, embeddings2=None, ignore_oov=True):
61 | """
62 | Compute monolingual or cross-lingual word similarity score.
63 | """
64 | assert not ((word2id2 is None) ^ (embeddings2 is None))
65 | word2id2 = word2id1 if word2id2 is None else word2id2
66 | embeddings2 = embeddings1 if embeddings2 is None else embeddings2
67 | assert len(word2id1) == embeddings1.shape[0]
68 | assert len(word2id2) == embeddings2.shape[0]
69 | assert type(lower) is bool
70 | word_pairs = get_word_pairs(path)
71 | not_found = 0
72 | pred = []
73 | gold = []
74 | for word1, word2, similarity in word_pairs:
75 | id1 = get_word_id(word1, word2id1, lower)
76 | id2 = get_word_id(word2, word2id2, lower)
77 | if id1 is None or id2 is None:
78 | not_found += 1
79 | if not ignore_oov:
80 | gold.append(similarity)
81 | pred.append(0.5)
82 | continue
83 | u = embeddings1[id1]
84 | v = embeddings2[id2]
85 | score = u.dot(v) / (np.linalg.norm(u) * np.linalg.norm(v))
86 | gold.append(similarity)
87 | pred.append(score)
88 | return spearmanr(gold, pred).correlation, len(gold), not_found
89 |
90 |
91 | def get_wordsim_scores(language, word2id, embeddings, lower=True):
92 | """
93 | Return monolingual word similarity scores.
94 | """
95 | dirpath = os.path.join(MONOLINGUAL_EVAL_PATH, language)
96 | if not os.path.isdir(dirpath):
97 | return None
98 |
99 | scores = {}
100 | separator = "=" * (30 + 1 + 10 + 1 + 13 + 1 + 12)
101 | pattern = "%30s %10s %13s %12s"
102 | logger.info(separator)
103 | logger.info(pattern % ("Dataset", "Found", "Not found", "Rho"))
104 | logger.info(separator)
105 |
106 | for filename in list(os.listdir(dirpath)):
107 | if filename.startswith('%s_' % (language.upper())):
108 | filepath = os.path.join(dirpath, filename)
109 | coeff, found, not_found = get_spearman_rho(word2id, embeddings, filepath, lower)
110 | logger.info(pattern % (filename[:-4], str(found), str(not_found), "%.4f" % coeff))
111 | scores[filename[:-4]] = coeff
112 | logger.info(separator)
113 |
114 | return scores
115 |
116 |
117 | def get_wordanalogy_scores(language, word2id, embeddings, lower=True):
118 | """
119 | Return (english) word analogy score
120 | """
121 | dirpath = os.path.join(MONOLINGUAL_EVAL_PATH, language)
122 | if not os.path.isdir(dirpath) or language not in ["en"]:
123 | return None
124 |
125 | # normalize word embeddings
126 | embeddings = embeddings / np.sqrt((embeddings ** 2).sum(1))[:, None]
127 |
128 | # scores by category
129 | scores = {}
130 |
131 | word_ids = {}
132 | queries = {}
133 |
134 | with io.open(os.path.join(dirpath, 'questions-words.txt'), 'r', encoding='utf-8') as f:
135 | for line in f:
136 | # new line
137 | line = line.rstrip()
138 | if lower:
139 | line = line.lower()
140 |
141 | # new category
142 | if ":" in line:
143 | assert line[1] == ' '
144 | category = line[2:]
145 | assert category not in scores
146 | scores[category] = {'n_found': 0, 'n_not_found': 0, 'n_correct': 0}
147 | word_ids[category] = []
148 | queries[category] = []
149 | continue
150 |
151 | # get word IDs
152 | assert len(line.split()) == 4, line
153 | word1, word2, word3, word4 = line.split()
154 | word_id1 = get_word_id(word1, word2id, lower)
155 | word_id2 = get_word_id(word2, word2id, lower)
156 | word_id3 = get_word_id(word3, word2id, lower)
157 | word_id4 = get_word_id(word4, word2id, lower)
158 |
159 | # if at least one word is not found
160 | if any(x is None for x in [word_id1, word_id2, word_id3, word_id4]):
161 | scores[category]['n_not_found'] += 1
162 | continue
163 | else:
164 | scores[category]['n_found'] += 1
165 | word_ids[category].append([word_id1, word_id2, word_id3, word_id4])
166 | # generate query vector and get nearest neighbors
167 | query = embeddings[word_id1] - embeddings[word_id2] + embeddings[word_id4]
168 | query = query / np.linalg.norm(query)
169 |
170 | queries[category].append(query)
171 |
172 | # Compute score for each category
173 | for cat in queries:
174 | qs = torch.from_numpy(np.vstack(queries[cat]))
175 | keys = torch.from_numpy(embeddings.T)
176 | values = qs.mm(keys).cpu().numpy()
177 |
178 | # be sure we do not select input words
179 | for i, ws in enumerate(word_ids[cat]):
180 | for wid in [ws[0], ws[1], ws[3]]:
181 | values[i, wid] = -1e9
182 | scores[cat]['n_correct'] = np.sum(values.argmax(axis=1) == [ws[2] for ws in word_ids[cat]])
183 |
184 | # pretty print
185 | separator = "=" * (30 + 1 + 10 + 1 + 13 + 1 + 12)
186 | pattern = "%30s %10s %13s %12s"
187 | logger.info(separator)
188 | logger.info(pattern % ("Category", "Found", "Not found", "Accuracy"))
189 | logger.info(separator)
190 |
191 | # compute and log accuracies
192 | accuracies = {}
193 | for k in sorted(scores.keys()):
194 | v = scores[k]
195 | accuracies[k] = float(v['n_correct']) / max(v['n_found'], 1)
196 | logger.info(pattern % (k, str(v['n_found']), str(v['n_not_found']), "%.4f" % accuracies[k]))
197 | logger.info(separator)
198 |
199 | return accuracies
200 |
201 |
202 | def get_crosslingual_wordsim_scores(lang1, word2id1, embeddings1,
203 | lang2, word2id2, embeddings2,
204 | lower=True, ignore_oov=True):
205 | """
206 | Return cross-lingual word similarity scores.
207 | """
208 | f1 = os.path.join(SEMEVAL17_EVAL_PATH, '%s-%s-SEMEVAL17.txt' % (lang1, lang2))
209 | f2 = os.path.join(SEMEVAL17_EVAL_PATH, '%s-%s-SEMEVAL17.txt' % (lang2, lang1))
210 | if not (os.path.exists(f1) or os.path.exists(f2)):
211 | return None
212 |
213 | if os.path.exists(f1):
214 | coeff, found, not_found = get_spearman_rho(
215 | word2id1, embeddings1, f1,
216 | lower, word2id2, embeddings2, ignore_oov
217 | )
218 | elif os.path.exists(f2):
219 | coeff, found, not_found = get_spearman_rho(
220 | word2id2, embeddings2, f2,
221 | lower, word2id1, embeddings1, ignore_oov
222 | )
223 |
224 | scores = {}
225 | separator = "=" * (30 + 1 + 10 + 1 + 13 + 1 + 12)
226 | pattern = "%30s %10s %13s %12s"
227 | logger.info(separator)
228 | logger.info(pattern % ("Dataset", "Found", "Not found", "Rho"))
229 | logger.info(separator)
230 |
231 | task_name = '%s_%s_SEMEVAL17' % (lang1.upper(), lang2.upper())
232 | logger.info(pattern % (task_name, str(found), str(not_found), "%.4f" % coeff))
233 | scores[task_name] = coeff
234 | if not scores:
235 | return None
236 | logger.info(separator)
237 |
238 | return scores
239 |
--------------------------------------------------------------------------------
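The word-similarity evaluation above boils down to: predict a cosine similarity for each word pair, then report Spearman's rho against the human ratings. A self-contained toy example mirroring `get_spearman_rho` (the vectors and ratings are made up):

```python
import numpy as np
from scipy.stats import spearmanr

# hypothetical word pairs with human similarity ratings, plus toy 2-d "embeddings"
pairs = [("car", "auto", 9.0), ("car", "banana", 1.5), ("cat", "dog", 7.0)]
emb = {"car": np.array([1.0, 0.1]), "auto": np.array([0.9, 0.2]),
       "banana": np.array([0.0, 1.0]), "cat": np.array([0.5, 0.5]),
       "dog": np.array([0.4, 0.6])}

gold, pred = [], []
for w1, w2, rating in pairs:
    u, v = emb[w1], emb[w2]
    gold.append(rating)
    pred.append(u.dot(v) / (np.linalg.norm(u) * np.linalg.norm(v)))  # cosine similarity

print(spearmanr(gold, pred).correlation)  # rank correlation between model and humans
```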
/umwe/src/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import logging
9 | import time
10 | from datetime import timedelta
11 |
12 |
13 | class LogFormatter():
14 |
15 | def __init__(self):
16 | self.start_time = time.time()
17 |
18 | def format(self, record):
19 | elapsed_seconds = round(record.created - self.start_time)
20 |
21 | prefix = "%s - %s - %s" % (
22 | record.levelname,
23 | time.strftime('%x %X'),
24 | timedelta(seconds=elapsed_seconds)
25 | )
26 | message = record.getMessage()
27 | message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3))
28 | return "%s - %s" % (prefix, message)
29 |
30 |
31 | def create_logger(filepath, vb=2):
32 | """
33 | Create a logger.
34 | """
35 | # create log formatter
36 | log_formatter = LogFormatter()
37 |
38 | # create file handler and set level to debug
39 | file_handler = logging.FileHandler(filepath, "a")
40 | file_handler.setLevel(logging.DEBUG)
41 | file_handler.setFormatter(log_formatter)
42 |
43 | # create console handler and set level to info
44 | log_level = logging.DEBUG if vb == 2 else logging.INFO if vb == 1 else logging.WARNING
45 | console_handler = logging.StreamHandler()
46 | console_handler.setLevel(log_level)
47 | console_handler.setFormatter(log_formatter)
48 |
49 | # create logger and set level to debug
50 | logger = logging.getLogger()
51 | logger.handlers = []
52 | logger.setLevel(logging.DEBUG)
53 | logger.propagate = False
54 | logger.addHandler(file_handler)
55 | logger.addHandler(console_handler)
56 |
57 | # reset logger elapsed time
58 | def reset_time():
59 | log_formatter.start_time = time.time()
60 | logger.reset_time = reset_time
61 |
62 | return logger
63 |
--------------------------------------------------------------------------------
/umwe/src/models.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright (c) 2017-present, Facebook, Inc.
2 | # Modified work Copyright (c) 2018, Xilun Chen
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 |
9 | import torch
10 | from torch import nn
11 |
12 | from .utils import load_embeddings, normalize_embeddings
13 |
14 |
15 | class Discriminator(nn.Module):
16 |
17 | def __init__(self, params, lang):
18 | super(Discriminator, self).__init__()
19 |
20 | self.lang = lang
21 | self.emb_dim = params.emb_dim
22 | self.dis_layers = params.dis_layers
23 | self.dis_hid_dim = params.dis_hid_dim
24 | self.dis_dropout = params.dis_dropout
25 | self.dis_input_dropout = params.dis_input_dropout
26 |
27 | layers = [nn.Dropout(self.dis_input_dropout)]
28 | for i in range(self.dis_layers + 1):
29 | input_dim = self.emb_dim if i == 0 else self.dis_hid_dim
30 | output_dim = 1 if i == self.dis_layers else self.dis_hid_dim
31 | layers.append(nn.Linear(input_dim, output_dim))
32 | if i < self.dis_layers:
33 | layers.append(nn.LeakyReLU(0.2))
34 | layers.append(nn.Dropout(self.dis_dropout))
35 | layers.append(nn.Sigmoid())
36 | self.layers = nn.Sequential(*layers)
37 |
38 | def forward(self, x):
39 | assert x.dim() == 2 and x.size(1) == self.emb_dim
40 | return self.layers(x).view(-1)
41 |
42 |
43 | def build_model(params, with_dis):
44 | """
45 | Build all components of the model.
46 | """
47 | # source embeddings
48 | params.vocabs, _src_embs, embs = {}, {}, {}
49 | for i, lang in enumerate(params.src_langs):
50 | dico, emb = load_embeddings(params, lang, params.src_embs[i])
51 | params.vocabs[lang] = dico
52 | _src_embs[lang] = emb
53 | for i, lang in enumerate(params.src_langs):
54 | src_emb = nn.Embedding(len(params.vocabs[lang]), params.emb_dim, sparse=True)
55 | src_emb.weight.data.copy_(_src_embs[lang])
56 | embs[lang] = src_emb
57 |
58 | # target embeddings
59 | if params.tgt_lang:
60 | tgt_dico, _tgt_emb = load_embeddings(params, params.tgt_lang, params.tgt_emb)
61 | params.vocabs[params.tgt_lang] = tgt_dico
62 | tgt_emb = nn.Embedding(len(tgt_dico), params.emb_dim, sparse=True)
63 | tgt_emb.weight.data.copy_(_tgt_emb)
64 | embs[params.tgt_lang] = tgt_emb
65 | else:
66 | tgt_emb = None
67 |
68 | # mappings
69 | mappings = {lang: nn.Linear(params.emb_dim, params.emb_dim,
70 | bias=False) for lang in params.src_langs}
71 | # set tgt mapping to fixed identity matrix
72 | tgt_map = nn.Linear(params.emb_dim, params.emb_dim, bias=False)
73 | tgt_map.weight.data.copy_(torch.diag(torch.ones(params.emb_dim)))
74 | for p in tgt_map.parameters():
75 | p.requires_grad = False
76 | mappings[params.tgt_lang] = tgt_map
77 | if getattr(params, 'map_id_init', True):
78 | for mapping in mappings.values():
79 | mapping.weight.data.copy_(torch.diag(torch.ones(params.emb_dim)))
80 |
81 | # discriminators
82 | discriminators = {lang: Discriminator(params, lang)
83 | for lang in params.all_langs} if with_dis else None
84 |
85 | for lang in params.all_langs:
86 | embs[lang] = embs[lang].to(params.device)
87 | mappings[lang] = mappings[lang].to(params.device)
88 | if with_dis:
89 | discriminators[lang] = discriminators[lang].to(params.device)
90 |
91 | # normalize embeddings
92 | params.lang_mean = {}
93 | for lang, emb in embs.items():
94 | params.lang_mean[lang] = normalize_embeddings(emb.weight.detach(), params.normalize_embeddings)
95 |
96 | return embs, mappings, discriminators
97 |
--------------------------------------------------------------------------------
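Note how `build_model` pins the target language to the identity: its mapping is a bias-free linear layer whose weight is set to the identity matrix and frozen, so only the source-language mappings are ever trained. In isolation:

```python
import torch
from torch import nn

emb_dim = 300
tgt_map = nn.Linear(emb_dim, emb_dim, bias=False)
tgt_map.weight.data.copy_(torch.eye(emb_dim))  # same as torch.diag(torch.ones(emb_dim))
for p in tgt_map.parameters():
    p.requires_grad = False                    # the target mapping is never trained

x = torch.randn(4, emb_dim)
assert torch.allclose(tgt_map(x), x)           # target embeddings pass through unchanged
```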
/umwe/src/trainer.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright (c) 2017-present, Facebook, Inc.
2 | # Modified work Copyright (c) 2018, Xilun Chen
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 |
9 | import itertools
10 | import os
11 | from logging import getLogger
12 | import random
13 |
14 | import numpy as np
15 | import scipy
16 | import scipy.linalg
17 | import torch
18 | from torch.nn import functional as F
19 | from torch import optim
20 |
21 | from .utils import get_optimizer, load_embeddings, normalize_embeddings, export_embeddings
22 | from .utils import clip_parameters, apply_mapping
23 | from .dico_builder import build_dictionary
24 | from .evaluation.word_translation import DIC_EVAL_PATH, load_identical_char_dico, load_dictionary
25 |
26 |
27 | logger = getLogger()
28 |
29 |
30 | class Trainer(object):
31 |
32 | def __init__(self, embs, mappings, discriminators, params):
33 | """
34 | Initialize trainer script.
35 | """
36 | self.embs = embs
37 | self.vocabs = params.vocabs
38 | self.mappings = mappings
39 | self.discriminators = discriminators
40 | self.params = params
41 | self.dicos = {}
42 |
43 | # optimizers
44 | if hasattr(params, 'map_optimizer'):
45 | optim_fn, optim_params = get_optimizer(params.map_optimizer)
46 | self.map_optimizer = optim_fn(itertools.chain(*[m.parameters()
47 | for l,m in mappings.items() if l!=params.tgt_lang]), **optim_params)
48 | if hasattr(params, 'dis_optimizer'):
49 | optim_fn, optim_params = get_optimizer(params.dis_optimizer)
50 | self.dis_optimizer = optim_fn(itertools.chain(*[d.parameters()
51 | for d in discriminators.values()]), **optim_params)
52 | else:
53 | assert discriminators is None
54 | if hasattr(params, 'mpsr_optimizer'):
55 | optim_fn, optim_params = get_optimizer(params.mpsr_optimizer)
56 | self.mpsr_optimizer = optim_fn(itertools.chain(*[m.parameters()
57 | for l,m in self.mappings.items() if l!=self.params.tgt_lang]), **optim_params)
58 |
59 | # best validation score
60 | self.best_valid_metric = -1e12
61 |
62 | self.decrease_lr = False
63 | self.decrease_mpsr_lr = False
64 |
65 | def get_dis_xy(self, lang1, lang2, volatile):
66 | """
67 | Get discriminator input batch / output target.
68 | Encode from lang1, decode to lang2 and then discriminate
69 | """
70 | # select random word IDs
71 | bs = self.params.batch_size
72 | mf = self.params.dis_most_frequent
73 | assert mf <= min(map(len, self.vocabs.values()))
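    |         # sample word IDs only among the mf most frequent words; rare-word vectors are noisier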
74 | src_ids = torch.LongTensor(bs).random_(len(self.vocabs[lang1]) if mf == 0 else mf)
75 | tgt_ids = torch.LongTensor(bs).random_(len(self.vocabs[lang2]) if mf == 0 else mf)
76 | src_ids = src_ids.to(self.params.device)
77 | tgt_ids = tgt_ids.to(self.params.device)
78 |
79 | with torch.set_grad_enabled(not volatile):
80 | # get word embeddings
81 | src_emb = self.embs[lang1](src_ids).detach()
82 | tgt_emb = self.embs[lang2](tgt_ids).detach()
83 | # map
84 | src_emb = self.mappings[lang1](src_emb)
85 | # decode
86 | src_emb = F.linear(src_emb, self.mappings[lang2].weight.t())
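    |             # mappings stay near-orthogonal (see orthogonalize), so W2^T approximates W2^-1
    |             # and takes the shared-space vector into the lang2 embedding space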
87 |
88 | # input / target
89 | x = torch.cat([src_emb, tgt_emb], 0)
90 | y = torch.FloatTensor(2 * bs).zero_()
91 |             # smoothed labels: ~1 for mapped (lang1) samples, ~0 for real lang2 samples
92 | y[:bs] = 1 - self.params.dis_smooth
93 | y[bs:] = self.params.dis_smooth
94 | y = y.to(self.params.device)
95 |
96 | return x, y
97 |
98 | def get_mpsr_xy(self, lang1, lang2, volatile):
99 | """
100 | Get input batch / output target for MPSR.
101 | """
102 | # select random word IDs
103 | bs = self.params.batch_size
104 | dico = self.dicos[(lang1, lang2)]
105 | indices = torch.from_numpy(np.random.randint(0, len(dico), bs)).to(self.params.device)
106 | dico = dico.index_select(0, indices)
107 | src_ids = dico[:, 0].to(self.params.device)
108 | tgt_ids = dico[:, 1].to(self.params.device)
109 |
110 | with torch.set_grad_enabled(not volatile):
111 | # get word embeddings
112 | src_emb = self.embs[lang1](src_ids).detach()
113 | tgt_emb = self.embs[lang2](tgt_ids).detach()
114 | # map
115 | src_emb = self.mappings[lang1](src_emb)
116 | # decode
117 | src_emb = F.linear(src_emb, self.mappings[lang2].weight.t())
118 |
119 | return src_emb, tgt_emb
120 |
121 | def dis_step(self, stats):
122 | """
123 | Train the discriminator.
124 | """
125 | for disc in self.discriminators.values():
126 | disc.train()
127 |
128 | # loss
129 | loss = 0
130 | # for each target language
131 | for lang2 in self.params.all_langs:
132 |             # randomly select a source language
133 | lang1 = random.choice(self.params.all_langs)
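    |             # the lang2 discriminator must separate real lang2 embeddings from embeddings mapped in from lang1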
134 |
135 | x, y = self.get_dis_xy(lang1, lang2, volatile=True)
136 | preds = self.discriminators[lang2](x.detach())
137 | loss += F.binary_cross_entropy(preds, y)
138 |
139 | # check NaN
140 | if (loss != loss).any():
141 | logger.error("NaN detected (discriminator)")
142 | exit()
143 | stats['DIS_COSTS'].append(loss.item())
144 |
145 | # optim
146 | self.dis_optimizer.zero_grad()
147 | loss.backward()
148 | self.dis_optimizer.step()
149 |         for d in self.discriminators.values():
150 |             clip_parameters(d, self.params.dis_clip_weights)
151 |
152 | def mapping_step(self, stats):
153 | """
154 | Fooling discriminator training step.
155 | """
156 | if self.params.dis_lambda == 0:
157 | return 0
158 |
159 | for disc in self.discriminators.values():
160 | disc.eval()
161 |
162 | # loss
163 | loss = 0
164 | for lang1 in self.params.all_langs:
165 | lang2 = random.choice(self.params.all_langs)
166 |
167 | x, y = self.get_dis_xy(lang1, lang2, volatile=False)
168 | preds = self.discriminators[lang2](x)
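    |             # flip the smoothed labels: the mappings are trained to make mapped embeddings look real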
169 | loss += F.binary_cross_entropy(preds, 1 - y)
170 | loss = self.params.dis_lambda * loss
171 |
172 | # check NaN
173 | if (loss != loss).any():
174 | logger.error("NaN detected (fool discriminator)")
175 | exit()
176 |
177 | # optim
178 | self.map_optimizer.zero_grad()
179 | loss.backward()
180 | self.map_optimizer.step()
181 | self.orthogonalize()
182 |
183 | return len(self.params.all_langs) * self.params.batch_size
184 |
185 | def load_training_dico(self, dico_train):
186 | """
187 | Load training dictionary.
188 | """
189 | # load dicos for all lang pairs
190 | for i, lang1 in enumerate(self.params.all_langs):
191 | for j, lang2 in enumerate(self.params.all_langs):
192 | if lang1 == lang2:
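    |                 # a language paired with itself: identity dictionary over the dico_max_rank most frequent words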
193 | idx = torch.arange(self.params.dico_max_rank).long().view(self.params.dico_max_rank, 1)
194 | self.dicos[(lang1, lang2)] = torch.cat([idx, idx], dim=1).to(self.params.device)
195 | else:
196 | word2id1 = self.vocabs[lang1].word2id
197 | word2id2 = self.vocabs[lang2].word2id
198 |
199 | # identical character strings
200 | if dico_train == "identical_char":
201 | self.dicos[(lang1, lang2)] = load_identical_char_dico(word2id1, word2id2)
202 | # use one of the provided dictionary
203 | elif dico_train == "default":
204 | filename = '%s-%s.0-5000.txt' % (lang1, lang2)
205 | self.dicos[(lang1, lang2)] = load_dictionary(
206 | os.path.join(DIC_EVAL_PATH, filename),
207 | word2id1, word2id2
208 | )
209 | # TODO dictionary provided by the user
210 | else:
211 | # self.dicos[(lang1, lang2)] = load_dictionary(dico_train, word2id1, word2id2)
212 |                     raise NotImplementedError(dico_train)
213 | self.dicos[(lang1, lang2)] = self.dicos[(lang1, lang2)].to(self.params.device)
214 |
215 | def build_dictionary(self):
216 | """
217 | Build dictionaries from aligned embeddings.
218 | """
219 | # build dicos for all lang pairs
220 | for i, lang1 in enumerate(self.params.all_langs):
221 | for j, lang2 in enumerate(self.params.all_langs):
222 | if i < j:
223 | src_emb = self.embs[lang1].weight
224 | src_emb = apply_mapping(self.mappings[lang1], src_emb).detach()
225 | tgt_emb = self.embs[lang2].weight
226 | tgt_emb = apply_mapping(self.mappings[lang2], tgt_emb).detach()
227 | src_emb = src_emb / src_emb.norm(2, 1, keepdim=True).expand_as(src_emb)
228 | tgt_emb = tgt_emb / tgt_emb.norm(2, 1, keepdim=True).expand_as(tgt_emb)
229 | self.dicos[(lang1, lang2)] = build_dictionary(src_emb, tgt_emb, self.params)
230 | elif i > j:
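    |                     # the reverse pair was already built above: reuse it with the columns swapped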
231 | self.dicos[(lang1, lang2)] = self.dicos[(lang2, lang1)][:, [1,0]]
232 | else:
233 | idx = torch.arange(self.params.dico_max_rank).long().view(self.params.dico_max_rank, 1)
234 | self.dicos[(lang1, lang2)] = torch.cat([idx, idx], dim=1).to(self.params.device)
235 |
236 | def mpsr_step(self, stats):
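    |         """
    |         MPSR training step: minimize the MSE between dictionary-paired
    |         embeddings after encoding from lang1 and decoding into a random lang2.
    |         """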
237 | # loss
238 | loss = 0
239 | for lang1 in self.params.all_langs:
240 | lang2 = random.choice(self.params.all_langs)
241 |
242 | x, y = self.get_mpsr_xy(lang1, lang2, volatile=False)
243 | loss += F.mse_loss(x, y)
244 | # check NaN
245 | if (loss != loss).any():
246 |             logger.error("NaN detected (MPSR)")
247 | exit()
248 |
249 | stats['MPSR_COSTS'].append(loss.item())
250 | # optim
251 | self.mpsr_optimizer.zero_grad()
252 | loss.backward()
253 | self.mpsr_optimizer.step()
254 |
255 | if self.params.mpsr_orthogonalize:
256 | self.orthogonalize()
257 |
258 | return len(self.params.all_langs) * self.params.batch_size
259 |
260 | def orthogonalize(self):
261 | """
262 | Orthogonalize the mapping.
263 | """
264 | if self.params.map_beta > 0:
265 | for mapping in self.mappings.values():
266 | W = mapping.weight.detach()
267 | beta = self.params.map_beta
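    |                 # one step of the iterative update W <- (1 + beta) W - beta (W W^T) W,
    |                 # which pulls the mapping toward the orthogonal group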
268 | W.copy_((1 + beta) * W - beta * W.mm(W.transpose(0, 1).mm(W)))
269 |
270 | def update_lr(self, to_log, metric):
271 | """
272 | Update learning rate when using SGD.
273 | """
274 | if 'sgd' not in self.params.map_optimizer:
275 | return
276 | old_lr = self.map_optimizer.param_groups[0]['lr']
277 | new_lr = max(self.params.min_lr, old_lr * self.params.lr_decay)
278 | if new_lr < old_lr:
279 | logger.info("Decreasing learning rate: %.8f -> %.8f" % (old_lr, new_lr))
280 | self.map_optimizer.param_groups[0]['lr'] = new_lr
281 |
282 | if self.params.lr_shrink < 1 and to_log[metric] >= -1e7:
283 | if to_log[metric] < self.best_valid_metric:
284 | logger.info("Validation metric is smaller than the best: %.5f vs %.5f"
285 | % (to_log[metric], self.best_valid_metric))
286 | # decrease the learning rate, only if this is the
287 | # second time the validation metric decreases
288 | if self.decrease_lr:
289 | old_lr = self.map_optimizer.param_groups[0]['lr']
290 | self.map_optimizer.param_groups[0]['lr'] *= self.params.lr_shrink
291 | logger.info("Shrinking the learning rate: %.5f -> %.5f"
292 | % (old_lr, self.map_optimizer.param_groups[0]['lr']))
293 |                 self.decrease_lr = True
294 |
295 | def update_mpsr_lr(self, to_log, metric):
296 | """
297 | Update learning rate when using SGD.
298 | """
299 | if 'sgd' not in self.params.mpsr_optimizer:
300 | return
301 | old_lr = self.mpsr_optimizer.param_groups[0]['lr']
302 | new_lr = max(self.params.min_lr, old_lr * self.params.lr_decay)
303 | if new_lr < old_lr:
304 | logger.info("Decreasing learning rate: %.8f -> %.8f" % (old_lr, new_lr))
305 | self.mpsr_optimizer.param_groups[0]['lr'] = new_lr
306 |
307 | if self.params.lr_shrink < 1 and to_log[metric] >= -1e7:
308 | if to_log[metric] < self.best_valid_metric:
309 | logger.info("Validation metric is smaller than the best: %.5f vs %.5f"
310 | % (to_log[metric], self.best_valid_metric))
311 | # decrease the learning rate, only if this is the
312 | # second time the validation metric decreases
313 | if self.decrease_mpsr_lr:
314 | old_lr = self.mpsr_optimizer.param_groups[0]['lr']
315 | self.mpsr_optimizer.param_groups[0]['lr'] *= self.params.lr_shrink
316 | logger.info("Shrinking the learning rate: %.5f -> %.5f"
317 | % (old_lr, self.mpsr_optimizer.param_groups[0]['lr']))
318 |                 self.decrease_mpsr_lr = True
319 |
320 | def save_best(self, to_log, metric):
321 | """
322 | Save the best model for the given validation metric.
323 | """
324 | # best mapping for the given validation criterion
325 | if to_log[metric] > self.best_valid_metric:
326 | # new best mapping
327 | self.best_valid_metric = to_log[metric]
328 | logger.info('* Best value for "%s": %.5f' % (metric, to_log[metric]))
329 | # save the mapping
330 | tgt_lang = self.params.tgt_lang
331 | for src_lang in self.params.src_langs:
332 | W = self.mappings[src_lang].weight.detach().cpu().numpy()
333 | path = os.path.join(self.params.exp_path,
334 | f'best_mapping_{src_lang}2{tgt_lang}.t7')
335 |             logger.info(f'* Saving the {src_lang} to {tgt_lang} mapping to {path} ...')
336 | torch.save(W, path)
337 |
338 | def reload_best(self):
339 | """
340 | Reload the best mapping.
341 | """
342 | tgt_lang = self.params.tgt_lang
343 | for src_lang in self.params.src_langs:
344 | path = os.path.join(self.params.exp_path,
345 | f'best_mapping_{src_lang}2{tgt_lang}.t7')
346 | logger.info(f'* Reloading the best {src_lang} to {tgt_lang} model from {path} ...')
347 | # reload the model
348 | assert os.path.isfile(path)
349 | to_reload = torch.from_numpy(torch.load(path))
350 | W = self.mappings[src_lang].weight.detach()
351 | assert to_reload.size() == W.size()
352 | W.copy_(to_reload.type_as(W))
353 |
354 | def export(self):
355 | """
356 | Export embeddings.
357 | """
358 | params = self.params
359 | # load all embeddings
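    |         # training truncated each vocabulary to max_vocab words; reload the full vocabularies before exporting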
360 | logger.info("Reloading embeddings for mapping ...")
361 | params.vocabs[params.tgt_lang], tgt_emb = load_embeddings(params, params.tgt_lang,
362 | params.tgt_emb, full_vocab=True)
363 | normalize_embeddings(tgt_emb, params.normalize_embeddings,
364 | mean=params.lang_mean[params.tgt_lang])
365 | # export target embeddings
366 | export_embeddings(tgt_emb, self.params.tgt_lang, self.params)
367 | # export all source embeddings
368 | for i, src_lang in enumerate(self.params.src_langs):
369 | params.vocabs[src_lang], src_emb = load_embeddings(params, src_lang,
370 | params.src_embs[i], full_vocab=True)
371 | logger.info(f"Map {src_lang} embeddings to the target space ...")
372 | src_emb = apply_mapping(self.mappings[src_lang], src_emb)
373 | export_embeddings(src_emb, src_lang, self.params)
374 |
--------------------------------------------------------------------------------
/umwe/supervised.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright (c) 2017-present, Facebook, Inc.
2 | # Modified work Copyright (c) 2018, Xilun Chen
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 |
9 | import os
10 | import json
11 | import argparse
12 | from collections import OrderedDict
13 | import numpy as np
14 | import time
15 | import torch
16 |
17 | from src.utils import bool_flag, initialize_exp
18 | from src.models import build_model
19 | from src.trainer import Trainer
20 | from src.evaluation import Evaluator
21 |
22 |
23 | VALIDATION_METRIC_SUP = 'precision_at_1-csls_knn_10'
24 | VALIDATION_METRIC_UNSUP = 'mean_cosine-csls_knn_10-S2T-10000'
25 | # default path to embeddings if not otherwise specified
26 | EMB_DIR = 'data/fasttext-vectors/'
27 |
28 | # main
29 | parser = argparse.ArgumentParser(description='Supervised training')
30 | parser.add_argument("--seed", type=int, default=-1, help="Initialization seed")
31 | parser.add_argument("--verbose", type=int, default=2, help="Verbose level (2:debug, 1:info, 0:warning)")
32 | parser.add_argument("--exp_path", type=str, default="", help="Where to store experiment logs and models")
33 | parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
34 | parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
35 | # parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
36 | parser.add_argument("--device", type=str, default="cuda", help="Run on GPU or CPU")
37 | parser.add_argument("--export", type=str, default="txt", help="Export embeddings after training (txt / pth)")
38 |
39 | # data
40 | parser.add_argument("--src_langs", type=str, nargs='+', default=['de', 'es', 'fr', 'it', 'pt'], help="Source languages")
41 | parser.add_argument("--tgt_lang", type=str, default='es', help="Target language")
42 | parser.add_argument("--emb_dim", type=int, default=300, help="Embedding dimension")
43 | parser.add_argument("--max_vocab", type=int, default=200000, help="Maximum vocabulary size (-1 to disable)")
44 | # training refinement
45 | parser.add_argument("--n_refinement", type=int, default=5, help="Number of refinement iterations (0 to disable the refinement procedure)")
46 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
47 | parser.add_argument("--map_beta", type=float, default=0.001, help="Beta for orthogonalization")
48 | # MPSR parameters
49 | parser.add_argument("--mpsr_optimizer", type=str, default="adam", help="Multilingual Pseudo-Supervised Refinement optimizer")
50 | parser.add_argument("--mpsr_orthogonalize", type=bool_flag, default=True, help="During MPSR, whether to perform orthogonalization")
51 | parser.add_argument("--mpsr_n_steps", type=int, default=30000, help="Number of optimization steps for MPSR")
52 | # dictionary creation parameters (for refinement)
53 | parser.add_argument("--dico_train", type=str, default="default", help="Path to training dictionary (default or identical_char)")
54 | parser.add_argument("--dico_eval", type=str, default="default", help="Path to evaluation dictionary")
55 | parser.add_argument("--dico_method", type=str, default='csls_knn_10', help="Method used for dictionary generation (nn/invsm_beta_30/csls_knn_10)")
56 | parser.add_argument("--dico_build", type=str, default='S2T&T2S', help="S2T,T2S,S2T|T2S,S2T&T2S")
57 | parser.add_argument("--dico_threshold", type=float, default=0, help="Threshold confidence for dictionary generation")
58 | parser.add_argument("--dico_max_rank", type=int, default=10000, help="Maximum dictionary words rank (0 to disable)")
59 | parser.add_argument("--dico_min_size", type=int, default=0, help="Minimum generated dictionary size (0 to disable)")
60 | parser.add_argument("--dico_max_size", type=int, default=0, help="Maximum generated dictionary size (0 to disable)")
61 | parser.add_argument("--semeval_ignore_oov", type=bool_flag, default=True, help="Whether to ignore OOV in SEMEVAL evaluation (the original authors used True)")
62 | # reload pre-trained embeddings
63 | parser.add_argument("--src_embs", type=str, nargs='+', default=[], help="Reload source embeddings (should be in the same order as in src_langs)")
64 | parser.add_argument("--tgt_emb", type=str, default='', help="Reload target embeddings")
65 | parser.add_argument("--normalize_embeddings", type=str, default="", help="Normalize embeddings before training")
66 |
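    | # Example invocation (illustrative flags; embeddings default to data/fasttext-vectors/wiki.<lang>.vec):
    | #   python supervised.py --src_langs fr it --tgt_lang en --dico_train default --exp_name sup_run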
67 |
68 | # parse parameters
69 | params = parser.parse_args()
70 |
71 | # post-processing options
72 | params.src_N = len(params.src_langs)
73 | params.all_langs = params.src_langs + [params.tgt_lang]
74 | # load default embeddings if no embeddings specified
75 | if len(params.src_embs) == 0:
76 | params.src_embs = []
77 | for lang in params.src_langs:
78 | params.src_embs.append(os.path.join(EMB_DIR, f'wiki.{lang}.vec'))
79 | if len(params.tgt_emb) == 0:
80 | params.tgt_emb = os.path.join(EMB_DIR, f'wiki.{params.tgt_lang}.vec')
81 |
82 | # check parameters
83 | assert not params.device.lower().startswith('cuda') or torch.cuda.is_available()
84 | assert params.dico_train in ["identical_char", "default"] or os.path.isfile(params.dico_train)
85 | assert params.dico_build in ["S2T", "T2S", "S2T|T2S", "S2T&T2S"]
86 | assert params.dico_max_size == 0 or params.dico_max_size < params.dico_max_rank
87 | assert params.dico_max_size == 0 or params.dico_max_size > params.dico_min_size
88 | assert all([os.path.isfile(emb) for emb in params.src_embs])
89 | assert os.path.isfile(params.tgt_emb)
90 | assert params.dico_eval == 'default' or os.path.isfile(params.dico_eval)
91 | assert params.export in ["", "txt", "pth"]
92 |
93 | # build logger / model / trainer / evaluator
94 | logger = initialize_exp(params)
95 | # N+1 embeddings, N+1 mappings (target fixed to identity); no discriminators (with_dis=False)
96 | embs, mappings, discriminators = build_model(params, False)
97 | trainer = Trainer(embs, mappings, discriminators, params)
98 | evaluator = Evaluator(trainer)
99 |
100 | # load a training dictionary. if a dictionary path is not provided, use a default
101 | # one ("default") or create one based on identical character strings ("identical_char")
102 | trainer.load_training_dico(params.dico_train)
103 |
104 | # define the validation metric
105 | VALIDATION_METRIC = VALIDATION_METRIC_UNSUP if params.dico_train == 'identical_char' else VALIDATION_METRIC_SUP
106 | logger.info("Validation metric: %s" % VALIDATION_METRIC)
107 |
108 | """
109 | Learning loop for iterative (pseudo-)supervised refinement: dictionary induction + MPSR
110 | """
111 | for n_epoch in range(params.n_refinement + 1):
112 |
113 | logger.info('Starting iteration %i...' % n_epoch)
114 |
115 | # build a dictionary from aligned embeddings (unless
116 | # it is the first iteration and we use the init one)
117 | if n_epoch > 0 or not hasattr(trainer, 'dicos'):
118 | trainer.build_dictionary()
119 |
120 | # optimize MPSR
121 | tic = time.time()
122 | n_words_mpsr = 0
123 | stats = {'MPSR_COSTS': []}
124 | for n_iter in range(params.mpsr_n_steps):
125 | # mpsr training step
126 | n_words_mpsr += trainer.mpsr_step(stats)
127 | # log stats
128 | if n_iter % 500 == 0:
129 | stats_str = [('MPSR_COSTS', 'MPSR loss')]
130 | stats_log = ['%s: %.4f' % (v, np.mean(stats[k]))
131 | for k, v in stats_str if len(stats[k]) > 0]
132 | stats_log.append('%i samples/s' % int(n_words_mpsr / (time.time() - tic)))
133 | logger.info(('%06i - ' % n_iter) + ' - '.join(stats_log))
134 | # reset
135 | tic = time.time()
136 | n_words_mpsr = 0
137 | for k, _ in stats_str:
138 | del stats[k][:]
139 |
140 | # embeddings evaluation
141 | to_log = OrderedDict({'n_epoch': n_epoch})
142 | evaluator.all_eval(to_log)
143 |
144 | # JSON log / save best model / end of epoch
145 | logger.info("__log__:%s" % json.dumps(to_log))
146 | trainer.save_best(to_log, VALIDATION_METRIC)
147 | logger.info('End of iteration %i.\n\n' % n_epoch)
148 |
149 | # update the learning rate (effective only if using SGD for MPSR)
150 | trainer.update_mpsr_lr(to_log, VALIDATION_METRIC)
151 |
152 | # export embeddings
153 | if params.export:
154 | trainer.reload_best()
155 | trainer.export()
156 |
--------------------------------------------------------------------------------
/umwe/unsupervised.py:
--------------------------------------------------------------------------------
1 | # Original work Copyright (c) 2017-present, Facebook, Inc.
2 | # Modified work Copyright (c) 2018, Xilun Chen
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 |
9 | import os
10 | import time
11 | import json
12 | import argparse
13 | from collections import OrderedDict
14 | import numpy as np
15 | import torch
16 |
17 | from src.utils import bool_flag, initialize_exp
18 | from src.models import build_model
19 | from src.trainer import Trainer
20 | from src.evaluation import Evaluator
21 |
22 |
23 | VALIDATION_METRIC = 'mean_cosine-csls_knn_10-S2T-10000'
24 | # default path to embeddings if not otherwise specified
25 | EMB_DIR = 'data/fasttext-vectors/'
26 |
27 |
28 | # main
29 | parser = argparse.ArgumentParser(description='Unsupervised training')
30 | parser.add_argument("--seed", type=int, default=-1, help="Initialization seed")
31 | parser.add_argument("--verbose", type=int, default=2, help="Verbose level (2:debug, 1:info, 0:warning)")
32 | parser.add_argument("--exp_path", type=str, default="", help="Where to store experiment logs and models")
33 | parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
34 | parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
35 | # parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
36 | parser.add_argument("--device", type=str, default="cuda", help="Run on GPU or CPU")
37 | parser.add_argument("--export", type=str, default="txt", help="Export embeddings after training (txt / pth)")
38 | # data
39 | parser.add_argument("--src_langs", type=str, nargs='+', default=['de', 'es', 'fr', 'it', 'pt'], help="Source languages")
40 | parser.add_argument("--tgt_lang", type=str, default='en', help="Target language")
41 | parser.add_argument("--emb_dim", type=int, default=300, help="Embedding dimension")
42 | parser.add_argument("--max_vocab", type=int, default=200000, help="Maximum vocabulary size (-1 to disable)")
43 | # mapping
44 | parser.add_argument("--map_id_init", type=bool_flag, default=True, help="Initialize the mapping as an identity matrix")
45 | parser.add_argument("--map_beta", type=float, default=0.001, help="Beta for orthogonalization")
46 | # discriminator
47 | parser.add_argument("--dis_layers", type=int, default=2, help="Discriminator layers")
48 | parser.add_argument("--dis_hid_dim", type=int, default=2048, help="Discriminator hidden layer dimensions")
49 | parser.add_argument("--dis_dropout", type=float, default=0., help="Discriminator dropout")
50 | parser.add_argument("--dis_input_dropout", type=float, default=0.1, help="Discriminator input dropout")
51 | parser.add_argument("--dis_steps", type=int, default=5, help="Discriminator steps")
52 | parser.add_argument("--dis_lambda", type=float, default=1, help="Discriminator loss feedback coefficient")
53 | parser.add_argument("--dis_most_frequent", type=int, default=75000, help="Select embeddings of the k most frequent words for discrimination (0 to disable)")
54 | parser.add_argument("--dis_smooth", type=float, default=0.1, help="Discriminator smooth predictions")
55 | parser.add_argument("--dis_clip_weights", type=float, default=0, help="Clip discriminator weights (0 to disable)")
56 | # training adversarial
57 | parser.add_argument("--adversarial", type=bool_flag, default=True, help="Use adversarial training")
58 | parser.add_argument("--n_epochs", type=int, default=5, help="Number of epochs")
59 | parser.add_argument("--epoch_size", type=int, default=1000000, help="Iterations per epoch")
60 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
61 | parser.add_argument("--map_optimizer", type=str, default="sgd,lr=0.1", help="Mapping optimizer")
62 | parser.add_argument("--dis_optimizer", type=str, default="sgd,lr=0.1", help="Discriminator optimizer")
63 | parser.add_argument("--lr_decay", type=float, default=0.98, help="Learning rate decay (SGD only)")
64 | parser.add_argument("--min_lr", type=float, default=1e-6, help="Minimum learning rate (SGD only)")
65 | parser.add_argument("--lr_shrink", type=float, default=0.5, help="Shrink the learning rate if the validation metric decreases (1 to disable)")
66 | # training refinement
67 | parser.add_argument("--n_refinement", type=int, default=5, help="Number of refinement iterations (0 to disable the refinement procedure)")
68 | # MPSR parameters
69 | parser.add_argument("--mpsr_optimizer", type=str, default="adam", help="Multilingual Pseudo-Supervised Refinement optimizer")
70 | parser.add_argument("--mpsr_orthogonalize", type=bool_flag, default=True, help="During MPSR, whether to perform orthogonalization")
71 | parser.add_argument("--mpsr_n_steps", type=int, default=30000, help="Number of optimization steps for MPSR")
72 | # dictionary creation parameters (for refinement)
73 | # default uses .5000-6500.txt; train uses .0-5000.txt; all uses .txt
74 | parser.add_argument("--dico_eval", type=str, default="default", help="Path to evaluation dictionary")
75 | parser.add_argument("--dico_method", type=str, default='csls_knn_10', help="Method used for dictionary generation (nn/invsm_beta_30/csls_knn_10)")
76 | parser.add_argument("--dico_build", type=str, default='S2T&T2S', help="S2T,T2S,S2T|T2S,S2T&T2S")
77 | parser.add_argument("--dico_threshold", type=float, default=0, help="Threshold confidence for dictionary generation")
78 | parser.add_argument("--dico_max_rank", type=int, default=15000, help="Maximum dictionary words rank (0 to disable)")
79 | parser.add_argument("--dico_min_size", type=int, default=0, help="Minimum generated dictionary size (0 to disable)")
80 | parser.add_argument("--dico_max_size", type=int, default=0, help="Maximum generated dictionary size (0 to disable)")
81 | parser.add_argument("--semeval_ignore_oov", type=bool_flag, default=True, help="Whether to ignore OOV in SEMEVAL evaluation (the original authors used True)")
82 | # reload pre-trained embeddings
83 | parser.add_argument("--src_embs", type=str, nargs='+', default=[], help="Reload source embeddings (should be in the same order as in src_langs)")
84 | parser.add_argument("--tgt_emb", type=str, default="", help="Reload target embeddings")
85 | parser.add_argument("--normalize_embeddings", type=str, default="", help="Normalize embeddings before training")
86 |
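    | # Example invocation (illustrative flags; embeddings default to data/fasttext-vectors/wiki.<lang>.vec):
    | #   python unsupervised.py --src_langs es fr it --tgt_lang en --exp_name unsup_run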
87 |
88 | # parse parameters
89 | params = parser.parse_args()
90 |
91 | # post-processing options
92 | params.src_N = len(params.src_langs)
93 | params.all_langs = params.src_langs + [params.tgt_lang]
94 | # load default embeddings if no embeddings specified
95 | if len(params.src_embs) == 0:
96 | params.src_embs = []
97 | for lang in params.src_langs:
98 | params.src_embs.append(os.path.join(EMB_DIR, f'wiki.{lang}.vec'))
99 | if len(params.tgt_emb) == 0:
100 | params.tgt_emb = os.path.join(EMB_DIR, f'wiki.{params.tgt_lang}.vec')
101 |
102 | # check parameters
103 | assert not params.device.lower().startswith('cuda') or torch.cuda.is_available()
104 | assert 0 <= params.dis_dropout < 1
105 | assert 0 <= params.dis_input_dropout < 1
106 | assert 0 <= params.dis_smooth < 0.5
107 | assert params.dis_lambda > 0 and params.dis_steps > 0
108 | assert 0 < params.lr_shrink <= 1
109 | assert all([os.path.isfile(emb) for emb in params.src_embs])
110 | assert os.path.isfile(params.tgt_emb)
111 | assert params.dico_eval == 'default' or os.path.isfile(params.dico_eval)
112 | assert params.export in ["", "txt", "pth"]
113 |
114 | # build model / trainer / evaluator
115 | logger = initialize_exp(params)
116 | # N+1 embeddings, N+1 mappings (target fixed to identity), N+1 discriminators
117 | embs, mappings, discriminators = build_model(params, True)
118 | trainer = Trainer(embs, mappings, discriminators, params)
119 | evaluator = Evaluator(trainer)
120 |
121 |
122 | """
123 | Learning loop for Multilingual Adversarial Training
124 | """
125 | if params.adversarial:
126 | logger.info('----> MULTILINGUAL ADVERSARIAL TRAINING <----\n\n')
127 |
128 | # training loop
129 | for n_epoch in range(params.n_epochs):
130 |
131 | logger.info('Starting adversarial training epoch %i...' % n_epoch)
132 | tic = time.time()
133 | n_words_proc = 0
134 | stats = {'DIS_COSTS': []}
135 |
136 | for n_iter in range(0, params.epoch_size, params.batch_size):
137 |
138 | # discriminator training
139 | for _ in range(params.dis_steps):
140 | trainer.dis_step(stats)
141 |
142 | # mapping training (discriminator fooling)
143 | n_words_proc += trainer.mapping_step(stats)
144 |
145 | # log stats
146 | if n_iter % 500 == 0:
147 | stats_str = [('DIS_COSTS', 'Discriminator loss')]
148 | stats_log = ['%s: %.4f' % (v, np.mean(stats[k]))
149 | for k, v in stats_str if len(stats[k]) > 0]
150 | stats_log.append('%i samples/s' % int(n_words_proc / (time.time() - tic)))
151 | logger.info(('%06i - ' % n_iter) + ' - '.join(stats_log))
152 |
153 | # reset
154 | tic = time.time()
155 | n_words_proc = 0
156 | for k, _ in stats_str:
157 | del stats[k][:]
158 |
159 | # embeddings / discriminator evaluation
160 | to_log = OrderedDict({'n_epoch': n_epoch})
161 | evaluator.all_eval(to_log)
162 | evaluator.eval_all_dis(to_log)
163 |
164 | # JSON log / save best model / end of epoch
165 | logger.info("__log__:%s" % json.dumps(to_log))
166 | trainer.save_best(to_log, VALIDATION_METRIC)
167 | logger.info('End of epoch %i.\n\n' % n_epoch)
168 |
169 | # update the learning rate (stop if too small)
170 | trainer.update_lr(to_log, VALIDATION_METRIC)
171 | if trainer.map_optimizer.param_groups[0]['lr'] < params.min_lr:
172 |             logger.info('Learning rate < %.0e. BREAK.' % params.min_lr)
173 | break
174 |
175 |
176 | """
177 | Learning loop for Multilingual Pseudo-Supervised Refinement
178 | """
179 | if params.n_refinement > 0:
180 | # Get the best mapping according to VALIDATION_METRIC
181 | logger.info('----> MULTILINGUAL PSEUDO-SUPERVISED REFINEMENT <----\n\n')
182 | trainer.reload_best()
183 |
184 | # training loop
185 | for n_epoch in range(params.n_refinement):
186 |
187 | logger.info('Starting refinement iteration %i...' % n_epoch)
188 |
189 | # build a dictionary from aligned embeddings
190 | trainer.build_dictionary()
191 |
192 | # optimize MPSR
193 | tic = time.time()
194 | n_words_mpsr = 0
195 | stats = {'MPSR_COSTS': []}
196 | for n_iter in range(params.mpsr_n_steps):
197 | # mpsr training step
198 | n_words_mpsr += trainer.mpsr_step(stats)
199 | # log stats
200 | if n_iter % 500 == 0:
201 | stats_str = [('MPSR_COSTS', 'MPSR loss')]
202 | stats_log = ['%s: %.4f' % (v, np.mean(stats[k]))
203 | for k, v in stats_str if len(stats[k]) > 0]
204 | stats_log.append('%i samples/s' % int(n_words_mpsr / (time.time() - tic)))
205 | logger.info(('%06i - ' % n_iter) + ' - '.join(stats_log))
206 | # reset
207 | tic = time.time()
208 | n_words_mpsr = 0
209 | for k, _ in stats_str:
210 | del stats[k][:]
211 |
212 | # embeddings evaluation
213 | to_log = OrderedDict({'n_mpsr_epoch': n_epoch})
214 | evaluator.all_eval(to_log)
215 |
216 | # JSON log / save best model / end of epoch
217 | logger.info("__log__:%s" % json.dumps(to_log))
218 | trainer.save_best(to_log, VALIDATION_METRIC)
219 | logger.info('End of refinement iteration %i.\n\n' % n_epoch)
220 |
221 | # update the learning rate (effective only if using SGD for MPSR)
222 | trainer.update_mpsr_lr(to_log, VALIDATION_METRIC)
223 |
224 |
225 | # export embeddings
226 | if params.export:
227 | trainer.reload_best()
228 | trainer.export()
229 |
--------------------------------------------------------------------------------