├── .gitignore ├── LICENSE ├── LMCurriculumLearning-2009 ├── 1marginloss.png ├── LMArchitecture.png ├── README.md └── lm.py ├── README.md └── WordTranslationWithoutParallelData ├── README.md ├── data └── ref-enfr.dict ├── img ├── Losses-Validation.png ├── all-lp.png ├── koen.png ├── overview.png └── validation-precision.png ├── out ├── dict-aren.txt.gz ├── dict-csen.txt.gz ├── dict-deen.txt.gz ├── dict-enar.txt.gz ├── dict-encs.txt.gz ├── dict-ende.txt.gz ├── dict-enen.txt.gz ├── dict-enes.txt.gz ├── dict-enfr.txt.gz ├── dict-enit.txt.gz ├── dict-enja.txt.gz ├── dict-enko.txt.gz ├── dict-enpt.txt.gz ├── dict-enru.txt.gz ├── dict-enzh.txt.gz ├── dict-esen.txt.gz ├── dict-fren.txt.gz ├── dict-iten.txt.gz ├── dict-jaen.txt.gz ├── dict-koen.txt.gz ├── dict-pten.txt.gz ├── dict-ruen.txt.gz └── dict-zhen.txt.gz ├── sample.md ├── scripts ├── README.md ├── launch.pl └── run-cos.sh └── src ├── net.py ├── train.lua └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Jean Senellart 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LMCurriculumLearning-2009/1marginloss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/LMCurriculumLearning-2009/1marginloss.png -------------------------------------------------------------------------------- /LMCurriculumLearning-2009/LMArchitecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/LMCurriculumLearning-2009/LMArchitecture.png -------------------------------------------------------------------------------- /LMCurriculumLearning-2009/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This program implements the language model from the paper [**Curriculum Learning** [*Bengio, 2009*]](#references) using TensorFlow. 4 | 5 | 6 | # Prerequisite 7 | 8 | You need Python 3 and a working installation of TensorFlow 2. 9 | 10 | You also need the `configargparse` module (for configuration files), the `tensorflow_datasets` module (for the token encoder) and `tqdm` (for progress bars). 11 | 12 | ``` 13 | pip install configargparse tensorflow-datasets tqdm 14 | ``` 15 | 16 | # Usage 17 | 18 | Data is expected to be space-pretokenized.
The vocabulary can either be precalculated - in that case it must be in a file with one token per line, and the token list must be sorted by decreasing frequency (this order is important, since curriculum learning filters out windows containing the rarest tokens). 19 | If not provided, the vocabulary is computed on the training data and sorted by the number of sentences each token appears in (this is very slow on large training data). 20 | 21 | The model architecture is defined by the following lines (shown here with the default `--window_size 5`, `--embedding_size 50` and `--dense_size 100`): 22 | 23 | ```python 24 | main_input = K.layers.Input(shape=(5), dtype='int32', name='main_input') 25 | embedding = K.layers.Embedding(VOC_SIZE+1, 50)(main_input) 26 | o1 = K.layers.Reshape((250,))(embedding) 27 | o2 = K.layers.Dense(100, activation='tanh')(o1) 28 | predictions = K.layers.Dense(1)(o2) 29 | ``` 30 | 31 | ![Network architecture](LMArchitecture.png) 32 | 33 | The code uses TensorFlow checkpoints and TensorBoard summaries to track progress. 34 | 35 | You can use `-C config_file` to save the current parameters to a configuration file, and reuse this configuration file with `-c config_file`. 36 | 37 | A training is launched with `python lm.py -c config_file` or `python lm.py [options]`. 38 | 39 | See `python lm.py -h` for more information. 40 | 41 | The loss function is the **max-margin loss**, calculated as: 42 | 43 | ![1-margin loss](1marginloss.png) 44 | 45 | Given a 4-word context, this is implemented by replacing the fifth word with a random word, and taking as loss the hinge max(0, 1 - s(actual) + s(random)), where s(actual) is the score of the actual window and s(random) the score of the window with the random word. During training, each batch contains every actual window immediately followed by its fake window, and the loss is calculated by a dedicated loss function: 46 | 47 | ```python 48 | def loss_fnc(y_true, y_pred): 49 | positive = y_pred[0::2] 50 | negative = y_pred[1::2] 51 | loss = tf.maximum(0., 1.
- positive + negative) 52 | loss = tf.reshape(tf.tile(tf.reshape(loss,[tf.size(loss),1]),[1,2]),[2*tf.size(loss)]) 53 | return loss 54 | ``` 55 | 56 | *logrank* is calculated on the first `--test_size N` windows of the test file, as the average over these windows of the logarithm of their rank. The rank of a given window is the rank of the actual 4+1 example among all the 4+other examples obtained by replacing the last word with every vocabulary entry. Ideally the rank would be 1, meaning that given the 4-word context, the language model always ranks the actual example first compared to all the other possible words in the same context. This is however clearly not always possible, since a 4-word context is small and many other words can fit. When trained on the 2009 English Wikipedia corpus (638 million 5-word windows), the paper reports a best log rank of 2.68, i.e. an equivalent rank of \~14. 57 | 58 | Curriculum learning is controlled by `--curriculum_steps S`, the number of steps needed to reach the full vocabulary. The default of 1 step means that the full vocabulary is used immediately (i.e. without curriculum learning); the paper uses 4 steps. Each step lasts `--curriculum_examples N` examples - this count is computed (as in the paper) before filtering out the windows with out-of-vocabulary tokens, so each successive step includes more usable examples. Alternatively, `--curriculum_vocabs V` can be repeated to give the explicit vocabulary size used at each step. 59 | 60 | # References 61 | 62 | _BENGIO, Yoshua, LOURADOUR, Jérôme, COLLOBERT, Ronan_, et al. **Curriculum learning**. In: Proceedings of the 26th annual international conference on machine learning. 2009. p. 41-48.
63 | -------------------------------------------------------------------------------- /LMCurriculumLearning-2009/lm.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from pathlib import Path 3 | import sys 4 | import random 5 | import math 6 | from tqdm import tqdm 7 | 8 | import configargparse 9 | 10 | parser = configargparse.ArgParser(config_file_parser_class=configargparse.DefaultConfigFileParser) 11 | parser.add('-c', '--config', is_config_file=True, help='config file path') 12 | parser.add('-C', '--create_config', help='create config file') 13 | parser.add("--train", help="training data", required=True) 14 | parser.add("--test", help="testing data", required=True) 15 | parser.add("--include_unk", "-u", action="store_true") 16 | parser.add("--test_size", type=int, default=1000) 17 | parser.add("--buffer_size", type=int, default=500000) 18 | parser.add("--batch_size", type=int, default=512) 19 | parser.add("--test_batch_size", type=int, default=32) 20 | parser.add("--steps_per_epoch", type=int, default=4096) 21 | parser.add("--vocab", "-v", type=int, default=20000) 22 | parser.add("--vocab_file", help="vocabulary file", default=None) 23 | parser.add("--verbose", type=int, default=1) 24 | parser.add("--model", default="./lmmodel") 25 | parser.add("--sgd_learning_rate", default=0.01, type=float) 26 | parser.add("--sgd_decay", default=1e-6, type=float) 27 | parser.add("--sgd_momentum", default=0.9, type=float) 28 | parser.add("--curriculum_steps", default=1, type=int) 29 | parser.add("--curriculum_examples", default=100000000, type=int) 30 | parser.add("--curriculum_vocabs", type=int, action="append") 31 | parser.add("--window_size", type=int, default=5) 32 | parser.add("--embedding_size", type=int, default=50) 33 | parser.add("--dense_size", type=int, default=100) 34 | 35 | args = parser.parse_args() 36 | 37 | # Display all of values - and where they are coming from 38 | 
print(parser.format_values()) 39 | 40 | WINDOW_SIZE = args.window_size 41 | EMBEDDING_SIZE = args.embedding_size 42 | DENSE_SIZE = args.dense_size 43 | 44 | if args.create_config: 45 | options = {} 46 | for attr, value in args.__dict__.items(): 47 | if attr != "config" and attr != "create_config" and value is not None: 48 | options[attr] = value 49 | file_name = args.create_config 50 | content = configargparse.DefaultConfigFileParser().serialize(options) 51 | Path(file_name).write_text(content) 52 | print("configuration saved to file: %s" % file_name) 53 | sys.exit(0) 54 | 55 | import tensorflow as tf 56 | import tensorflow.keras as K 57 | import tensorflow_datasets as tfds 58 | 59 | # Implement simple SpaceTokenizer - built-in tokenizer in tf filter-out 60 | # non alphanumeric tokens 61 | # see https://www.tensorflow.org/datasets/api_docs/python/tfds/features/text/TokenTextEncoder 62 | class SpaceTokenizer(object): 63 | def tokenize(self, s): 64 | toks = [] 65 | toks.extend(tf.compat.as_text(s).split(' ')) 66 | toks = [t for t in toks if t] 67 | return toks 68 | 69 | tokenizer = SpaceTokenizer() 70 | 71 | train_dataset = tf.data.TextLineDataset(args.train) 72 | test_dataset = tf.data.TextLineDataset(args.test) 73 | 74 | if not args.vocab_file: 75 | VOC_SIZE = args.vocab 76 | 77 | print("Read Corpus - prepare sorted vocabulary") 78 | freq = Counter() 79 | for text_tensor in train_dataset: 80 | some_tokens = tokenizer.tokenize(text_tensor.numpy()) 81 | sentence_vocabulary_set = set(some_tokens) 82 | for v in sentence_vocabulary_set: 83 | freq[v] += 1 84 | 85 | vocab = [k for (k,v) in freq.most_common()] 86 | else: 87 | print("Read Vocab file - assuming it is sorted by decreasing frequency") 88 | vocab = [] 89 | with open(args.vocab_file) as f: 90 | for l in f: 91 | vocab.append(l.strip()) 92 | 93 | VOC_SIZE = min(len(vocab), args.vocab) 94 | print("Total Vocab Size=", len(vocab), "Actual vocab size=", VOC_SIZE) 95 | vocab = vocab[:VOC_SIZE] 96 | 97 | if 
args.include_unk: 98 | VOC_SIZE += 1 99 | 100 | #let us define a tensor with all vocabs - that we will use in test 101 | #to find rank of a given context-word score 102 | vocabs = tf.cast(tf.constant(range(1,VOC_SIZE+1)), 103 | dtype=tf.int64) 104 | 105 | encoder = tfds.features.text.TokenTextEncoder(vocab, tokenizer=tokenizer) 106 | 107 | def encode(text_tensor): 108 | encoded_text = encoder.encode(text_tensor.numpy()) 109 | return encoded_text, True 110 | 111 | def encode_map_fn(text): 112 | encoded_text, _ = tf.py_function(encode, 113 | inp=[text], 114 | Tout=(tf.int64, tf.bool)) 115 | encoded_text.set_shape([None]) 116 | return encoded_text 117 | 118 | competence = tf.Variable(1.0) 119 | nexamples = tf.Variable(0, dtype=tf.int64) 120 | 121 | def window_train(token_list): 122 | windows = [] 123 | labels = [] 124 | for i in range(len(token_list)-WINDOW_SIZE+1): 125 | max_window = tf.math.reduce_max(token_list[i:i+WINDOW_SIZE]) 126 | if max_window <= tf.cast(competence*VOC_SIZE, tf.int64): 127 | windows.append(token_list[i:i+WINDOW_SIZE]) 128 | labels.append(0) 129 | fake = [] 130 | for j in range(WINDOW_SIZE-1): 131 | fake.append(token_list[i+j]) 132 | fake.append(tf.random.uniform([], 133 | minval=1, 134 | maxval=tf.cast(competence*VOC_SIZE,tf.dtypes.int64), 135 | dtype=tf.dtypes.int64)) 136 | windows.append(fake) 137 | labels.append(1) 138 | nexamples.assign_add(tf.cast(len(token_list)-WINDOW_SIZE+1, tf.int64)) 139 | return windows, labels 140 | 141 | def window_map_train_fn(token_list): 142 | windows, labels = tf.py_function(window_train, 143 | inp=[token_list], 144 | Tout=(tf.int64, tf.float32)) 145 | 146 | windows.set_shape([None,WINDOW_SIZE]) 147 | labels.set_shape([None]) 148 | 149 | return windows, labels 150 | 151 | def window_test(token_list): 152 | windows = [] 153 | labels = [] 154 | for i in range(len(token_list)-WINDOW_SIZE+1): 155 | max_window = tf.math.reduce_max(token_list[i:i+WINDOW_SIZE]) 156 | if max_window <= tf.cast(VOC_SIZE, tf.int64): 157 
| windows.append(token_list[i:i+WINDOW_SIZE]) 158 | labels.append(0) 159 | return windows, labels 160 | 161 | def window_map_test_fn(token_list): 162 | windows, _ = tf.py_function(window_test, 163 | inp=[token_list], 164 | Tout=(tf.int64, tf.float32)) 165 | 166 | windows.set_shape([None,WINDOW_SIZE]) 167 | 168 | return windows 169 | 170 | def filter_nonempty_xy(ds): 171 | return ds.filter(lambda x, y: len(x) > 0) 172 | 173 | def filter_nonempty_x(ds): 174 | return ds.filter(lambda x: len(x) > 0) 175 | 176 | # Train Data preparation 177 | encoded_train_data = train_dataset.map(encode_map_fn) 178 | train_data = encoded_train_data.shuffle(args.buffer_size) 179 | # apply window operator, and remove empty list 180 | train_data = train_data.map(window_map_train_fn).apply(filter_nonempty_xy) 181 | train_data = train_data.flat_map(lambda x,y: tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(x), tf.data.Dataset.from_tensor_slices(y)))) 182 | train_data = train_data.batch(args.batch_size*2) 183 | 184 | # Test Data preparation 185 | test_data = test_dataset.map(encode_map_fn) 186 | # apply window operator, and remove empty list 187 | test_data = test_data.map(window_map_test_fn).apply(filter_nonempty_x) 188 | test_data = test_data.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x)) 189 | test_data = test_data.take(args.test_size) 190 | test_data = test_data.batch(args.test_batch_size) 191 | 192 | # Definition of the model and loss function 193 | main_input = K.layers.Input(shape=(WINDOW_SIZE), dtype='int32', name='main_input') 194 | embedding = K.layers.Embedding(VOC_SIZE+1, EMBEDDING_SIZE)(main_input) 195 | o1 = K.layers.Reshape((WINDOW_SIZE*EMBEDDING_SIZE,))(embedding) 196 | o2 = K.layers.Dense(DENSE_SIZE, activation='tanh')(o1) 197 | predictions = K.layers.Dense(1)(o2) 198 | 199 | def loss_fnc(y_true, y_pred): 200 | positive = y_pred[0::2] 201 | negative = y_pred[1::2] 202 | loss = tf.maximum(0., 1. 
- positive + negative) 203 | loss = tf.reshape(tf.tile(tf.reshape(loss,[tf.size(loss),1]),[1,2]),[2*tf.size(loss)]) 204 | return loss 205 | 206 | model = K.Model(inputs=main_input, outputs=predictions) 207 | 208 | print(model.summary()) 209 | 210 | # Train the model 211 | sgd = K.optimizers.SGD(lr=args.sgd_learning_rate, 212 | decay=args.sgd_decay, 213 | momentum=args.sgd_momentum, 214 | nesterov=True) 215 | model.compile(sgd, loss=loss_fnc) 216 | 217 | ckpt = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64), 218 | optimizer=sgd, 219 | net=model, 220 | nexamples=nexamples) 221 | manager = tf.train.CheckpointManager(ckpt, args.model, max_to_keep=3) 222 | 223 | ckpt.restore(manager.latest_checkpoint) 224 | if manager.latest_checkpoint: 225 | print("Restored from {}".format(manager.latest_checkpoint)) 226 | print("Total number of examples=", nexamples.numpy()) 227 | else: 228 | print("Initializing from scratch.") 229 | 230 | summary_writer = tf.summary.create_file_writer(args.model) 231 | 232 | while True: 233 | curriculum_step = int(nexamples.numpy()*1.0/args.curriculum_examples) 234 | if args.curriculum_vocabs: 235 | if curriculum_step >= len(args.curriculum_vocabs): 236 | curriculum_step = len(args.curriculum_vocabs)-1 237 | rate = args.curriculum_vocabs[curriculum_step]*1.0/VOC_SIZE 238 | else: 239 | rate = (curriculum_step+1.0)/args.curriculum_steps 240 | competence.assign(tf.cast(tf.math.minimum(rate, 1.0), tf.float32)) 241 | 242 | # Train 243 | if ckpt.save_counter != 0: 244 | h = model.fit(train_data, steps_per_epoch=args.steps_per_epoch, verbose=args.verbose) 245 | loss = h.history["loss"][0] 246 | ckpt.step.assign_add(args.steps_per_epoch) 247 | else: 248 | loss = None 249 | first = False 250 | 251 | # Eval 252 | test_name = "%s/test_%d.out" % (args.model, ckpt.save_counter) 253 | print("Evaluating model => ", test_name) 254 | with open(test_name, "w") as ftest: 255 | sum_logrank = 0 256 | for test in tqdm(test_data, 257 | unit="batch", 258 | 
ncols=80, 259 | total=math.ceil(args.test_size*1.0/args.test_batch_size)): 260 | t_expanded = tf.concat([tf.reshape(tf.tile(test[:,:WINDOW_SIZE-1],[1,VOC_SIZE]), 261 | [-1, WINDOW_SIZE-1]), 262 | tf.reshape(tf.tile(vocabs,[test.shape[0]]), 263 | [-1, 1])], 264 | axis=1) 265 | out = model.predict(t_expanded) 266 | out = tf.reshape(out, [-1, VOC_SIZE]) 267 | for ib in range(test.shape[0]): 268 | w = test[ib][WINDOW_SIZE-1]-1 269 | windows_s = "["+" ".join([vocab[t-1] for t in test[ib][:-1]])+ "]..."+vocab[w] 270 | out_w = out[ib][w] 271 | sorted_out = tf.sort(tf.reshape(out[ib],[VOC_SIZE]), direction='DESCENDING') 272 | rank = tf.where(tf.equal(sorted_out,out_w))[:,0][0]+1 273 | best10 =" ".join([vocab[tf.where(tf.equal(out[ib], sorted_out[idx]))[:,0][0].numpy()]+"/"+ 274 | str(sorted_out[idx].numpy()) 275 | for idx in range(10)]) 276 | ftest.write("\t".join((windows_s, str(rank.numpy()), 277 | str(tf.math.log(tf.cast(rank,dtype=tf.float32)).numpy()), best10))+"\n") 278 | sum_logrank += tf.math.log(tf.cast(rank,dtype=tf.float32)) 279 | ftest.write("======\n%f\n" % (sum_logrank.numpy()/args.test_size)) 280 | 281 | print(ckpt.step.numpy()*args.batch_size, "==>", "total windows", nexamples.numpy(), "competence", competence.numpy(), "loss", loss, "logrank", sum_logrank.numpy()/args.test_size) 282 | with summary_writer.as_default(): 283 | tf.summary.scalar('logrank', sum_logrank/args.test_size, step=ckpt.step*args.batch_size) 284 | tf.summary.scalar('competence', competence, step=ckpt.step*args.batch_size) 285 | tf.summary.scalar('total windows', nexamples, step=ckpt.step*args.batch_size) 286 | if loss is not None: 287 | tf.summary.scalar('train_loss', loss, step=ckpt.step*args.batch_size) 288 | manager.save() 289 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # papers 2 | This repo is containing notes and implementation for cherry-picked 
publications of my particular interest. 3 | 4 | My notes and interpretation are fully personal and might not reflect the authors' ideas. 5 | The code is available as-is and meant for me to better understand the papers and check the results - if you want to take it and/or improve it, you are welcome! 6 | -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/README.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | [Conneau, A., Lample, G., Ranzato, M. A., Denoyer, L., & Jégou, H. (2017). *Word Translation Without Parallel Data*.](https://arxiv.org/pdf/1710.04087.pdf) 3 | 4 | As a first step towards fully unsupervised machine translation, this paper demonstrates how to build a bilingual dictionary between two languages without using any parallel corpora, by aligning monolingual word embedding spaces in a fully unsupervised way. 5 | 6 | This work builds on [Mikolov, 2013], which shows that it is possible to exploit the similarities of monolingual embedding spaces to build a mapping between monolingual embeddings. Using a supervised approach, [Mikolov, 2013] shows that a simple linear mapping *W* gives good results for this mapping. In addition, [Xing, 2015] shows that such a mapping can be improved by enforcing an orthogonality constraint on *W*. 7 | 8 | Based on that starting point, the work described in the paper shows that it is possible to learn and then refine such a mapping with a fully unsupervised approach using adversarial learning. It is based on the following process (figure extracted from the paper): 9 | 10 | ![Overview](./img/overview.png) 11 | 12 | - (A) As a starting point, we have two independently trained word embeddings (they can come from different domains or different languages). The goal is to train a mapping between these word embeddings. The mapping can be as simple as a linear mapping, but could be more complicated (any neural network).
13 | - (B) A first version of this mapping is obtained using adversarial learning following [Goodfellow, 2014] - where the generator is the projection mapping, and the discriminator tries to distinguish projected source embeddings from native target embeddings 14 | - (C) Based on this first mapping, a refinement method finding the optimal *W* is applied, based on the Procrustes solution [Schönemann, 1966]; this process may be iterated several times 15 | - (D) Using a new distance metric (CSLS), which deals with the *hubness* issue, source-target pairs are extracted 16 | 17 | Key references: 18 | 19 | * [Mikolov, T., Le, Q., & Sutskever, I. (2013) *Exploiting similarities among languages for machine translation*.](https://arxiv.org/abs/1309.4168) 20 | * [Xing, C., Wang, D., Liu, C., & Lin, Y. (2015). *Normalized Word Embedding and Orthogonal Transform for Bilingual Word Translation*.](http://anthology.aclweb.org/N/N15/N15-1104.pdf) 21 | * [Schönemann, P. H. (1966). *A generalized solution of the orthogonal Procrustes problem*](https://link.springer.com/article/10.1007/BF02289451) 22 | 23 | Major follow-up: 24 | 25 | * [Lample, G., Denoyer, L., & Ranzato, M. (2017) *Unsupervised Machine Translation Using Monolingual Corpora Only*.](https://arxiv.org/pdf/1711.00043.pdf) 26 | 27 | ## Step-by-Step analysis and Implementation notes 28 | 29 | The full process is implemented in Python with PyTorch (`src/train.py` and `src/net.py`); a Lua Torch version is also provided in `src/train.lua`. 30 | 31 | ### The models 32 | * The generator is a simple linear mapping defined in the `Generator` class of `net.py`. Any other model could be implemented (but the `orthogonalityUpdate` method would then need to be adapted). 33 | * The class also implements the `orthogonalityUpdate` method (section 3.3), which is called after each gradient update. We found beta=0.0001 to work better than the beta=0.01 proposed in the paper (this value is also close to the one proposed in [Cisse, 2017]).
Also, *W* is initialized as a random orthogonal matrix using `scipy.stats.special_ortho_group`. 34 | * The discriminator is a 2-layer network implemented in `net.py`, as defined in 3.1. It has a single output unit activated by a sigmoid; its value is the probability of the input being a true target embedding. Layers are initialized with uniform values in [-0.1,0.1] 35 | 36 | ``` 37 | Sequential ( 38 | (0): Dropout (p = 0.1) 39 | (1): Linear (300 -> 2048) 40 | (2): LeakyReLU (0.01) 41 | (3): Linear (2048 -> 1) 42 | (4): Sigmoid () 43 | ) 44 | ``` 45 | * The loss function is binary cross-entropy (`BCELoss`) 46 | 47 | ### Extracting source-target pairs 48 | * Fast nearest-neighbor search is implemented using Facebook's FAISS library [Johnson 2017] through its Python binding 49 | * To calculate Cross-Domain Similarity Local Scaling (CSLS) - as defined in section 2.3 - the value \( r_S(y_t) \) is pre-calculated for the full target dictionary. The results of FAISS `search` are rescored using the CSLS metric. For simplicity, I use the negated CSLS value, so that lower is better 50 | * !!note!! - I only implemented one-way CSLS. In 3.4, it is also mentioned that, for refinement, mutual nearest neighbors are also used to restrict the list of candidates. 51 | 52 | ### Evaluation 53 | An internal EN-FR dictionary is provided and can be used for evaluation with the `--evalDict` option. Note: this is not the dictionary used by the authors, but it gives the same type of results. 54 | 55 | The provided dictionary has multiple meanings for each single word, so for the precision calculation I give credit if one hypothesis matches any of the available meanings. 56 | 57 | Note that the dictionary does not contain inflected forms, so the score is underestimated: a proposed translation may be correct but inflected, while the dictionary only lists its lemma. Also, source words missing from the evaluation dictionary are not taken into account.
This could be improved using a dictionary of inflected forms and/or stemming, but my goal was just to validate the approach. 58 | 59 | For instance, the first entry below, which is a correctly hypothesized translation of the word `men`, is ignored since `men` is not part of the reference dictionary; the second is counted as wrong because `peut` would normally count in P@5 but is an inflected form of `pouvoir`, so it is not validated by the dictionary. 60 | 61 | ``` 62 | men hommes eux les enfin , alors que ainsi même pourtant 63 | can éventuellement eventuellement faut peut sinon exemple sachant ou bien effectivement 64 | ``` 65 | 66 | *It will be interesting to check the content of the authors' dictionary used for the evaluation*. 67 | 68 | ### Adversarial Training 69 | For adversarial training, the process described in [Goodfellow, 2014] has been implemented: 70 | * first, k steps of discriminator update using *m=batchSize* projected source examples and *m=batchSize* native target examples. As suggested, I use label smoothing when calculating the loss on the native targets. 71 | * then, one mini-batch update of the generator is done, propagating the gradient from the discriminator with the inverted loss function. 72 | 73 | ## Unsupervised Model Selection and decay/early stopping strategy 74 | * the average COS (optionally CSLS) over the first 10k words is used as the validation criterion (3.5) 75 | * at each epoch, `--decayRate` is applied to the learning rate 76 | * if the validation score goes up with a relative increase larger than the `--halfDecayThreshold` parameter for 2 epochs, the learning rate is divided by 2 77 | * when the learning rate reaches 1/20 of the initial learning rate, training stops 78 | 79 | ### Refinement 80 | The refinement procedure uses a generated dictionary of 10000 entries (the anchors) and an SVD decomposition to compute the optimal value of *W*.
81 | 82 | ## Using the script 83 | 84 | ### Dependencies 85 | 86 | * Python 3+ 87 | * `pytorch` 88 | * `scipy` 89 | * `progressbar2` 90 | * `FAISS` (with the python binding - see [here](https://github.com/facebookresearch/faiss)) 91 | 92 | A GPU can be used, if available, with the `--gpuid ID` option of the script. 93 | 94 | A Lua Torch version is also provided - it includes the adversarial training and the (non-optimized) nearest-neighbor extraction. 95 | 96 | ### Running it 97 | * Get the fastText word embeddings 98 | 99 | ``` 100 | wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.vec 101 | wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.fr.vec 102 | ``` 103 | 104 | * Train a model: 105 | 106 | ``` 107 | python -u src/train.py wiki.en.vec wiki.fr.vec --vocSize 50000 --gpuid 0 --nEpochs 40 108 | ``` 109 | 110 | Note that after the first run, the word embeddings are saved in binary format, so that subsequent runs can load these binary files directly: 111 | 112 | ``` 113 | python -u src/train.py wiki.en.vec_50000.bin wiki.fr.vec_50000.bin --vocSize 50000 --gpuid 0 --nEpochs 40 114 | ``` 115 | 116 | Also, to speed up computation, \( r_S(y_t) \) (formula 6) is computed at the beginning of the script and saved in a file (e.g. `wiki.fr.vec_200000.bin_rs_knn10`) for further runs (it depends only on the vocabulary size and the knn value).
117 | 118 | Most of the parameters of the process can be set on the command line as follows: 119 | 120 | ``` 121 | WORD TRANSLATION WITHOUT PARALLEL DATA 122 | 123 | positional arguments: 124 | srcemb source word embedding 125 | tgtemb target word embedding 126 | 127 | optional arguments: 128 | -h, --help show this help message and exit 129 | --seed SEED initial random seed 130 | --vocSize VOCSIZE vocabulary size 131 | --dim DIM embedding size 132 | --hidden HIDDEN discriminator hidden layer size [3.1] 133 | --discDropout DISCDROPOUT 134 | discriminator dropout [3.1] 135 | --smoothing SMOOTHING 136 | label smoothing value [3.1] 137 | --samplingRange SAMPLINGRANGE 138 | sampling range on vocabulary for adversarial training 139 | [3.2] 140 | --beta BETA orthogonality adjustment parameter (equation 7) 141 | --k K #iteration of discriminator training for each 142 | iteration 143 | --batchSize BATCHSIZE 144 | batch size 145 | --learningRate LEARNINGRATE 146 | learning rate 147 | --decayRate DECAYRATE 148 | decay rate 149 | --nEpochs NEPOCHS number of epochs 150 | --halfDecayThreshold HALFDECAYTHRESHOLD 151 | if valid relative increase > this value for 2 epochs, 152 | half the LR 153 | --knn KNN number of neighbors to extract 154 | --skipRefinment 155 | --distance {CSLS,NN} distance to use NN or CSLS [2.3] 156 | --load LOAD load parameters of generator 157 | --save SAVE save parameters of generator 158 | --evalDict EVALDICT dictionary for evaluation 159 | --gpuid GPUID 160 | ``` 161 | 162 | 163 | The options `--load file` and `--save file` can be used to reload and save the generator state. 164 | 165 | With `--save FILE`, the saved files will be `FILE_adversarial.t7`, corresponding to the best solution found during adversarial training, then `FILE_refinement.t7` after the refinement procedure. 166 | 167 | ## Some results 168 | 169 | * The first 10k English-French entries are provided [here](./sample.md) - corresponding to a P@5 score of 62.24 (vs.
P@1 77.8 in the paper without refinement, but using the authors' own dictionary). The hyper-parameters are the same as in the paper (except for the batch size) 170 | * Unsupervised model selection: as shown in Figure 2 of the paper, the highest precision (evaluated on the reference dictionary) also corresponds to the minimal average CSLS score over the first 10k entries - here at epoch 59. The empty lines show the score and precision for a training run without the half-decay strategy. 171 | ![Unsupervised Validation vs. Precision](./img/validation-precision.png) 172 | * For the same run, the following graph shows the evolution of the generator loss, of the discriminator loss, and of the (unsupervised) validation score. It is interesting to see that the discriminator loss decreases continuously until epoch 43, then struggles to hold its position while the generator loss decreases. In this run, a continuous decay rate of 0.99 was used, with and without halving of the learning rate (plain and empty lines respectively). Also, the learning rates of the generator and of the discriminator are the same. It could be interesting to investigate other strategies. 173 | ![Losses evolution vs. Validation](./img/Losses-Validation.png) 174 | * I tried more challenging language pairs like Korean-English - the shapes of the curves are really different. First, the generator does not seem to fool the discriminator *at all*. Also, after a few epochs of discriminator training, the generator loss rises drastically, as if the discriminator wins the battle. In parallel, the validation curve does not show any clear sign of reaching a minimum. I did not have any Korean-English reference dictionary to evaluate the results, but browsing through the generated dictionary does not show any good mapping. In the paper, scores for more distant languages like Russian and Chinese are also comparatively very low. Is the mapping for these languages possible at all?
Would more experiments with other hyper-parameters help? 175 | ![Same for Korean-English](./img/koen.png) 176 | * The following table shows the final validation score (COS) for a variety of language pairs. Across European language pairs (except cs and de>es), we can see that the validation score is around 0.3, and a quick review of the corresponding dictionaries shows good quality. For some language pairs, however, like de>es, en>cs, cs>es and cs>fr, the validation score is around 0.45 and the corresponding dictionaries are poor. For these language pairs, training was run with multiple seeds - and in each case, we can see that the adversarial training does not manage to produce an accurate generator (the discriminator clearly dominates). For ja, ko and zh, the results are very poor (according to the validation score) - and ja>* and zh>* show some very odd convergence (almost all points are mapped to a small subset of the target embedding space, which gives an apparently good validation score). Results for ru and ar are intermediate. The corresponding dictionaries (centered on English) with the 10-best candidates are provided in the `out` directory. 177 | ![all language pairs](./img/all-lp.png) 178 | 179 | ## Personal Comments and Discussion 180 | * Even without the refinement implementation, the results are as good as promised by the paper: without using any explicit bilingual knowledge, the proposed approach proves it is possible to build a relatively accurate word translation table. We could argue that since the word embeddings have been trained on Wikipedia (see [fastText](https://github.com/facebookresearch/fastText)), there is some implicit _aligned knowledge_ that is necessarily reflected in the embeddings - leading to these results. Still, it is a wonder that multilingual word embeddings can be aligned like this, and with such a simple transformation. It would be interesting to test with other embeddings built on different sources of data.
The authors' experiments showing that, even within the same language, Wikipedia-trained and Gigaword-trained word embeddings cannot be aligned so easily also confirm this intuition. So this method is probably best suited to comparable corpora. 181 | 182 | There are some limitations to the extraction: 183 | 184 | * The mapping cannot really deal with polysemy, since all meanings of a source word will necessarily be in the same "neighborhood". This is not a problem of the approach, but the very nature of these word embeddings, which force multiple meanings to share a single representation. It could be interesting to see what would happen with adaptive skip-gram word vectors... 185 | * Also, beyond polysemy and contextual mappings, what is really missing is the notion of multi-word expressions, which are critical from a human perspective and for building a translation lexicon. 186 | 187 | Regarding the potential usage of this work, it is clearly not directly usable to build a human-ready word translation table: for most language pairs, such resources are already available with higher quality and without all the limitations mentioned; on the other hand, for rarer language pairs, the quality is far lower and would be challenged very quickly by any traditional (human) resource-building process. 188 | 189 | However, the findings and implications of this work are a huge step forward in better understanding cross-language word embeddings and knowledge representation, and this work is clearly the _appetizer_ for the main course - also published by the Facebook research team - *Unsupervised Machine Translation Using Monolingual Corpora Only*... 190 | 191 | Stay tuned...
192 | -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/img/Losses-Validation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/img/Losses-Validation.png -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/img/all-lp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/img/all-lp.png -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/img/koen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/img/koen.png -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/img/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/img/overview.png -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/img/validation-precision.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/img/validation-precision.png -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-aren.txt.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-aren.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-csen.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-csen.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-deen.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-deen.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-enar.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-enar.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-encs.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-encs.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-ende.txt.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-ende.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-enen.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-enen.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-enes.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-enes.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-enfr.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-enfr.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-enit.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-enit.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-enja.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-enja.txt.gz 
-------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-enko.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-enko.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-enpt.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-enpt.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-enru.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-enru.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-enzh.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-enzh.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-esen.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-esen.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-fren.txt.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-fren.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-iten.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-iten.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-jaen.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-jaen.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-koen.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-koen.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-pten.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-pten.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-ruen.txt.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-ruen.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/out/dict-zhen.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsenellart/papers/d5525cba87ff40f1ee1aab66a4b3fcda9e9e1148/WordTranslationWithoutParallelData/out/dict-zhen.txt.gz -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/scripts/README.md: -------------------------------------------------------------------------------- 1 | Utility scripts to automate extraction for all language pairs. Each lp is run with 2 different seeds. 2 | -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/scripts/launch.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | 3 | use strict; 4 | 5 | 6 | my @todo; 7 | 8 | foreach my $src (qw(it pt ko ru en es de fr ja)) { 9 | foreach my $tgt (qw(it pt ko ru en es de fr ja)) { 10 | push @todo,[$src,$tgt, 123] 11 | } 12 | } 13 | foreach my $src (qw(it pt ko ru en es de fr ja)) { 14 | foreach my $tgt (qw(it pt ko ru en es de fr ja)) { 15 | push @todo,[$src,$tgt, 768] 16 | } 17 | } 18 | 19 | 20 | sub countslot { 21 | my @r=split(/\n/,`nvidia-smi | grep python`); 22 | return $#r+1; 23 | } 24 | 25 | sub dolaunch { 26 | my ($cmd)=@_; 27 | print("RUN CMD: $cmd\n"); 28 | system($cmd); 29 | } 30 | 31 | while(my $lp=shift @todo) { 32 | while(countslot()>=9) { sleep(30); } 33 | my @lp=@{$lp}; 34 | print "************* LAUNCH $lp[0]-$lp[1]-$lp[2]\n"; 35 | dolaunch("scripts/run-cos.sh $lp[0] $lp[1] $lp[2] 2> /dev/null &"); 36 | sleep(30); 37 | } 38 | -------------------------------------------------------------------------------- 
/WordTranslationWithoutParallelData/scripts/run-cos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=/home/devling/projects/faiss 4 | export LD_LIBRARY_PATH=/opt/OpenBLAS/lib:/usr/local/cuda-8.0/lib64/ 5 | 6 | src=$1 7 | tgt=$2 8 | seed=$3 9 | 10 | echo "LAUNCH python3.5 -u traincos.py wiki.${src}.vec_200000.bin wiki.${tgt}.vec_200000.bin --vocSize 200000 --gpuid 0 --nEpoch 200 --refinementIterations 6 --dump_output out/dictcos-${src}${tgt}-seed${seed}.txt --save out/Wcos-${src}${tgt}-seed${seed} --seed $seed > out/logcos-${src}${tgt}-seed${seed}.txt" 11 | 12 | python3.5 -u traincos.py wiki.${src}.vec_200000.bin wiki.${tgt}.vec_200000.bin --vocSize 200000 --gpuid 0 --nEpoch 200 --refinementIterations 6 --dump_output out/dictcos-${src}${tgt}-seed${seed}.txt --save out/Wcos-${src}${tgt}-seed${seed} --seed $seed > out/logcos-${src}${tgt}-seed${seed}.txt 13 | -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/src/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | from scipy.stats import special_ortho_group 5 | 6 | 7 | class Generator(nn.Module): 8 | def __init__(self, args): 9 | super(Generator, self).__init__() 10 | 11 | print("* Generator Model Initialization") 12 | self.net = nn.Linear(args.dim, args.dim, False) 13 | 14 | if args.load: 15 | print(" * Load parameters from file: "+args.load) 16 | self.net.load_state_dict(torch.load(args.load)) 17 | 18 | W = special_ortho_group.rvs(args.dim) 19 | list(self.net.parameters())[0].data.copy_(torch.from_numpy(W)) 20 | 21 | print(self.net) 22 | 23 | def load(self, filename): 24 | self.net.load_state_dict(torch.load(filename)) 25 | 26 | def forward(self, x): 27 | return self.net.forward(x) 28 | 29 | def save(self, filename): 30 | 
torch.save(self.net.state_dict(), filename) 31 | 32 | def set(self, W): 33 | list(self.net.parameters())[0].data.copy_(W) 34 | 35 | def orthogonalityUpdate(self, beta): 36 | W = list(self.net.parameters())[0].data 37 | # update to keep W orthogonal 38 | W = (1+beta) * W 39 | W.addmm(-beta, torch.mm(W, W.t()), W) 40 | self.set(W) 41 | 42 | 43 | class Discriminator(nn.Module): 44 | def __init__(self, args): 45 | super(Discriminator, self).__init__() 46 | 47 | print("* Discriminator model initialiation") 48 | # the discriminator - outputs log probability of input to be source or target 49 | self.net = nn.Sequential( 50 | nn.Dropout(args.discDropout), 51 | nn.Linear(args.dim, args.hidden, True), 52 | nn.LeakyReLU(), 53 | nn.Linear(args.hidden, 1), 54 | nn.Sigmoid() 55 | ) 56 | for param in self.net.parameters(): 57 | param.data.uniform_(-0.1,0.1) 58 | print(self.net) 59 | 60 | def forward(self, x): 61 | return self.net.forward(x) 62 | -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/src/train.lua: -------------------------------------------------------------------------------- 1 | local torch = require('torch') 2 | local cutorch 3 | local nn = require('nn') 4 | require('nngraph') 5 | 6 | local cmd = torch.CmdLine() 7 | cmd:text() 8 | cmd:text() 9 | cmd:text('WORD TRANSLATION WITHOUT PARALLEL DATA') 10 | cmd:text() 11 | cmd:text('Options') 12 | cmd:option('-seed',123,'initial random seed') 13 | cmd:option('-gpuid',0,'use cuda') 14 | cmd:option('-srcemb','','path to source embedding') 15 | cmd:option('-tgtemb','','path to target embedding') 16 | cmd:option('-vocsize', 200000, 'vocabulary size') 17 | cmd:text() 18 | 19 | local opt = cmd:parse(arg) 20 | 21 | if opt.gpuid > 0 then 22 | print('loading CUDA') 23 | cutorch = require('cutorch') 24 | require('cunn') 25 | cutorch.setDevice(opt.gpuid) 26 | end 27 | 28 | local dim = 300 29 | -- hiiden layer size of discriminator 30 | local hidden = 2048 31 | -- 
batch size 32 | local batchSize = 64 33 | -- discriminator dropout 34 | local disc_dropout = 0.1 35 | -- smoothing label parameter 36 | local smoothing = 0.2 37 | -- number of iterations for discriminator training 38 | local k = 1 39 | -- learning rate and decay 40 | local learningRate = 0.1 41 | local decay = 0.99 42 | 43 | local beta = 0.01 44 | 45 | torch.manualSeed(opt.seed) 46 | 47 | local function find_knn(t, v, topk) 48 | local d = torch.Tensor(v:size(1)) 49 | if opt.gpuid > 0 then d = d:cuda() end 50 | for i = 1, v:size(1) do 51 | d[i] = -t:dot(v[i])/v[i]:norm() 52 | end 53 | return torch.topk(d, topk) 54 | end 55 | 56 | -- read embedding from text file or from .t7 57 | local function read_embed(filename) 58 | print("read embedding from "..filename.." - opt.vocsize ="..opt.vocsize) 59 | if filename:sub(-3) == ".t7" then 60 | local voc, weights = table.unpack(torch.load(filename)) 61 | assert(#voc == opt.vocsize) 62 | assert(weights:size(2) == dim) 63 | return voc, weights 64 | else 65 | local f = io.open(filename, "r") 66 | local header = f:read() 67 | local splitHeader = header:split(' ') 68 | assert(#splitHeader==2, "incorrect file format - header should be '#vocab dim'") 69 | local numWords = tonumber(splitHeader[1]) 70 | local embeddingSize = tonumber(splitHeader[2]) 71 | assert(numWords>=opt.vocsize, "opt.vocsize larger than vocabulary in embedding") 72 | assert(embeddingSize==dim, "embedding size does not match dim") 73 | local weights = torch.Tensor(opt.vocsize, embeddingSize) 74 | local voc = {} 75 | for i=1, opt.vocsize do 76 | local line = f:read() 77 | local splitLine = line:split(' ') 78 | assert(#splitLine == dim+1, "incorrect embedding format") 79 | table.insert(voc,splitLine[1]) 80 | for j = 2, #splitLine do 81 | weights[i][j-1] = tonumber(splitLine[j]) 82 | end 83 | end 84 | torch.save(filename.."_"..opt.vocsize..".t7", { voc, weights }) 85 | print(" * saved to "..filename.."_"..opt.vocsize..".t7") 86 | return voc, weights 87 | end 88 | end 89 |
90 | local svoc, semb = read_embed(opt.srcemb) 91 | local tvoc, temb = read_embed(opt.tgtemb) 92 | 93 | -- the generator - input source embedding, output projection in target embedding 94 | -- no bias 95 | local x = nn.Identity()() 96 | local generator = nn.gModule({x},{nn.Linear(dim, dim, false)(x)}) 97 | 98 | -- the discriminator - outputs the probability (sigmoid) of the input being source or target 99 | x = nn.Identity()() 100 | local h1 = nn.Linear(dim, hidden)(nn.Dropout(disc_dropout)(x)) 101 | local o = nn.Linear(hidden,1)(nn.LeakyReLU()(h1)) 102 | local discriminator = nn.gModule({x},{nn.Sigmoid()(o)}) 103 | 104 | -- use cross entropy 105 | local criterion = nn.BCECriterion() 106 | 107 | local zeroClass = torch.Tensor(batchSize):fill(0) 108 | local oneClass = torch.Tensor(batchSize):fill(1) 109 | local smoothedOneClass = torch.Tensor(batchSize):fill(1-smoothing) 110 | 111 | if opt.gpuid > 0 then 112 | generator = generator:cuda() 113 | discriminator = discriminator:cuda() 114 | zeroClass = zeroClass:cuda() 115 | oneClass = oneClass:cuda() 116 | smoothedOneClass = smoothedOneClass:cuda() 117 | criterion = criterion:cuda() 118 | semb = semb:cuda() 119 | temb = temb:cuda() 120 | end 121 | 122 | local W = generator:getParameters():view(dim,dim) -- view (reshape copies) so in-place updates of W reach the generator weights 123 | 124 | for iter = 1, 100 do 125 | local genLoss = 0 126 | local discLoss = 0 127 | 128 | for _ = 1, opt.vocsize/batchSize do 129 | for _ = 1, k do 130 | local bsrcIdx = (torch.rand(batchSize)*opt.vocsize+1):long() 131 | local btgtIdx = (torch.rand(batchSize)*opt.vocsize+1):long() 132 | local batch_src = semb:index(1, bsrcIdx) 133 | local batch_tgt = temb:index(1, btgtIdx) 134 | 135 | -- projection of source in target 136 | local projectedSrc = generator:forward(batch_src) 137 | 138 | discriminator:zeroGradParameters() 139 | 140 | -- calculate loss for batch src projected in target 141 | local discProjSrc = discriminator:forward(projectedSrc) 142 | discLoss = discLoss + criterion:forward(discProjSrc, zeroClass) 143 |
discriminator:backward(projectedSrc, criterion:backward(discProjSrc, zeroClass)) 144 | 145 | -- loss for tgt classified with smoothed label 146 | local discTgt = discriminator:forward(batch_tgt) 147 | discLoss = discLoss + criterion:forward(discTgt, smoothedOneClass) 148 | discriminator:backward(batch_tgt, criterion:backward(discTgt, smoothedOneClass)) 149 | 150 | discriminator:updateParameters(learningRate) 151 | 152 | end 153 | 154 | local bsrcIdx = (torch.rand(batchSize)*opt.vocsize+1):long() 155 | local batch_src = semb:index(1, bsrcIdx) 156 | 157 | if opt.gpuid > 0 then 158 | batch_src = batch_src:cuda() 159 | end 160 | 161 | -- calculate loss for batch src projected in target 162 | local projectedSrc = generator:forward(batch_src) 163 | local discProjSrc = discriminator:forward(projectedSrc) 164 | 165 | genLoss = genLoss + criterion:forward(discProjSrc, oneClass) 166 | generator:zeroGradParameters() 167 | local gradGen = discriminator:backward(projectedSrc, criterion:backward(discProjSrc, oneClass)) 168 | generator:backward(batch_src, gradGen) 169 | generator:updateParameters(learningRate) 170 | 171 | -- keep W orthogonal, in place: W <- (1+beta) W - beta (W W^T) W (no rebinding, no aliased addmm) 172 | local WWtW = torch.mm(torch.mm(W, W:t()), W) 173 | W:mul(1+beta) 174 | W:add(-beta, WWtW) 175 | end 176 | 177 | print('--- ',iter,'genLoss='..genLoss*batchSize/opt.vocsize, 'discLoss='..discLoss*batchSize/opt.vocsize/k, 178 | 'learningRate='..learningRate) 179 | 180 | learningRate = learningRate * decay 181 | 182 | end 183 | 184 | 185 | for i = 1, 10000 do 186 | local projSEmb = generator:forward(semb[i]) 187 | local y, idx = find_knn(projSEmb, temb, 10) 188 | print('* '..svoc[i]) 189 | for j = 1, idx:size(1) do 190 | print(' '..tvoc[idx[j]], y[j]) 191 | end 192 | end 193 | -------------------------------------------------------------------------------- /WordTranslationWithoutParallelData/src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import
Variable 3 | import torch.nn as nn 4 | from scipy.spatial.distance import cosine 5 | import progressbar 6 | from net import Generator, Discriminator 7 | from os import path 8 | import numpy as np 9 | 10 | import argparse 11 | import math 12 | 13 | parser = argparse.ArgumentParser(description='Word Translation Without Parallel Data') 14 | parser.add_argument('srcemb', nargs=1, type=str, help='source word embedding') 15 | parser.add_argument('tgtemb', nargs=1, type=str, help='target word embedding') 16 | parser.add_argument('--seed', type=int, default=123, help='initial random seed') 17 | parser.add_argument('--vocSize', type=int, default=200000, help='vocabulary size') 18 | parser.add_argument('--dim', default=300, type=int, help='embedding size') 19 | parser.add_argument('--hidden', default=2048, type=int, help='discriminator hidden layer size [3.1]') 20 | parser.add_argument('--discDropout', default=0.1, type=float, help='discriminator dropout [3.1]') 21 | parser.add_argument('--smoothing', default=0.2, type=float, help='label smoothing value [3.1]') 22 | parser.add_argument('--samplingRange', default=50000, type=int, help='sampling range on vocabulary for adversarial training [3.2]') 23 | parser.add_argument('--beta', default=0.0001, type=float, help='orthogonality adjustment parameter (equation 7)') 24 | parser.add_argument('--k', default=1, type=int, help='#iteration of discriminator training for each iteration') 25 | parser.add_argument('--batchSize', default=64, type=int, help='batch size') 26 | parser.add_argument('--learningRate', default=0.1, type=float, help='learning rate') 27 | parser.add_argument('--decayRate', default=0.99, type=float, help='decay rate') 28 | parser.add_argument('--nEpochs', default=100, type=int, help='number of epochs') 29 | parser.add_argument('--halfDecayThreshold', default=0.03, type=float, help='if valid relative increase > this value for 2 epochs, half the LR') 30 | parser.add_argument('--halfDecayDelay', default=8, type=int, 
help='if no progress in this period, half the LR') 31 | parser.add_argument('--knn', default=10, type=int, help='number of neighbors to extract') 32 | parser.add_argument('--refinementIterations', default=1, type=int, help='number of iteration of refinement') 33 | parser.add_argument('--distance', type=str, default='CSLS', help='distance to use NN or CSLS [2.3]', choices=['CSLS', 'NN']) 34 | parser.add_argument('--validDistance', type=str, default='COS', help='validation distance', choices=['CSLS', 'COS']) 35 | parser.add_argument('--load', type=str, help='load parameters of generator') 36 | parser.add_argument('--save', type=str, help='save parameters of generator', required=True) 37 | parser.add_argument('--dump_output', type=str, help='dump the complete mapped dictionary') 38 | parser.add_argument('--evalDict', type=str, help='dictionary for evaluation') 39 | parser.add_argument('--gpuid', default=-1, type=int) 40 | 41 | args = parser.parse_args() 42 | 43 | torch.manual_seed(args.seed) 44 | np.random.seed(args.seed) 45 | 46 | if args.gpuid >= 0: 47 | # allocate dummy tensor to check GPU is ok 48 | with torch.cuda.device(args.gpuid): 49 | torch.Tensor(1).cuda() 50 | 51 | print("* params: ", args) 52 | 53 | # ------------------------------------------------------- 54 | # READ DICTIONARY 55 | 56 | evalDict = {} 57 | def read_dict(filename): 58 | with open(filename) as f: 59 | for line in f: 60 | lineSplit = line.strip().split("\t") 61 | assert len(lineSplit)==2, "invalid format in dictionary" 62 | if not lineSplit[0] in evalDict: 63 | evalDict[lineSplit[0]] = [lineSplit[1]] 64 | else: 65 | evalDict[lineSplit[0]].append(lineSplit[1]) 66 | 67 | # check an entry meaning and returns @1, @5, @10 68 | def eval_entry(src, tgts): 69 | if not src in evalDict: 70 | return 71 | meanings = evalDict[src] 72 | for i in range(min(len(tgts), 10)): 73 | if tgts[i] in meanings: 74 | if i == 0: return (1, 1, 1) 75 | if i < 5: return (0, 1, 1) 76 | return (0, 0, 1) 77 | return (0, 0, 
0) 78 | 79 | def eval_dictionary(d): 80 | s = [0, 0, 0] 81 | c = 0 82 | for k in d.keys(): 83 | score = eval_entry(k, d[k]) 84 | if score: 85 | c += 1 86 | s = [x+y for x,y in zip(s,score)] 87 | s = [ int(x/c*10000.)/100 for x in s ] 88 | return s 89 | 90 | if args.evalDict: 91 | print("* read "+args.evalDict+" dictionary for evaluation") 92 | read_dict(args.evalDict) 93 | print(" => ", len(evalDict.keys()), "entries") 94 | 95 | # ------------------------------------------------------- 96 | # READ EMBEDDING 97 | 98 | def read_embed(filename): 99 | print("* read embedding from "+filename+" - args.vocSize="+str(args.vocSize)) 100 | if filename[-4:] == '.bin': 101 | emb = torch.load(filename) 102 | return emb[0], emb[1] 103 | else: 104 | with open(filename) as f: 105 | header = f.readline().strip() 106 | headerSplit = header.split(" ") 107 | numWords = int(headerSplit[0]) 108 | embeddingSize = int(headerSplit[1]) 109 | assert len(headerSplit)==2, "incorrect file format - header should be '#vocab dim'" 110 | assert numWords>=args.vocSize, "args.vocSize larger than vocabulary in embedding" 111 | assert embeddingSize == args.dim, "embedding size does not match dim" 112 | weights = torch.Tensor(args.vocSize, embeddingSize) 113 | voc = [] 114 | vocDict = dict() 115 | i = 0 116 | bar = progressbar.ProgressBar(max_value=args.vocSize) 117 | while i != args.vocSize: 118 | line = f.readline().strip() 119 | splitLine = line.split(" ") 120 | if len(splitLine) == args.dim + 1: 121 | token = splitLine[0] 122 | if token in vocDict: 123 | print('*** duplicate key in word embedding: '+token) 124 | else: 125 | vocDict[token] = i 126 | voc.append(token) 127 | for j in range(1, args.dim): 128 | weights[i][j-1] = float(splitLine[j]) 129 | bar.update(i) 130 | i = i + 1 131 | torch.save([voc, weights], filename+"_"+str(args.vocSize)+".bin") 132 | print(" * saved to "+filename+"_"+str(args.vocSize)+".bin") 133 | return voc, weights 134 | 135 | svoc, semb = read_embed(args.srcemb[0]) 136 | 
tvoc, temb = read_embed(args.tgtemb[0]) 137 | 138 | # ------------------------------------------------------- 139 | # PAIR MATCHING 140 | 141 | # initialize index using FAISS 142 | 143 | import faiss 144 | print("* indexing target vocabulary with FAISS") 145 | # index the target embedding 146 | index = faiss.IndexFlatL2(args.dim) 147 | index.add(temb.numpy()) 148 | 149 | # given a tensor or a batch of tensor returns distance and index to closes target neighbours 150 | def NN(v): 151 | cv = v 152 | if v.dim() == 1: 153 | cv.resize_(1, cv.shape[0]) 154 | D, I=index.search(cv.numpy(), args.knn) 155 | return D, I, D 156 | 157 | # calculate rs on the full vocabulary or load it from file 158 | rs = None 159 | rsfile = args.tgtemb[0]+'_rs_knn'+str(args.knn) 160 | if path.isfile(rsfile): 161 | print("* read rs file from: "+rsfile) 162 | rs = torch.load(rsfile) 163 | else: 164 | print("* preparing rs file (on vocabulary size/knn) - it will take a little while - but will get serialized for next iterations") 165 | bar = progressbar.ProgressBar() 166 | rs = torch.Tensor(args.vocSize) 167 | for istep in bar(range(0, args.vocSize, 500)): 168 | istepplus = min(istep+500, args.vocSize) 169 | Ds, Is, Cs = NN(temb[istep:istepplus]) 170 | for i in range(istep, istepplus): 171 | rs[i] = 0 172 | for l in range(args.knn): 173 | rs[i] += cosine(temb[i].numpy(), temb[Is[i-istep][l]].numpy()) 174 | rs[i] /= args.knn 175 | print("* save rs file to: "+rsfile) 176 | torch.save(rs, rsfile) 177 | 178 | def CSLS(v): 179 | # get nearest neighbors and return adjusted cos distance 180 | D, I, COS = NN(v) 181 | COS = np.copy(D) 182 | for idx in range(v.shape[0]): 183 | rt = 0 184 | for j in range(args.knn): 185 | COS[idx][j] = cosine(v[idx].numpy(), temb[I[idx][j]].numpy()) 186 | rt += COS[idx][j] 187 | rt /= args.knn 188 | for j in range(args.knn): 189 | D[idx][j] = 2*COS[idx][j]-rs[I[idx][j]]-rt 190 | return D, I, COS 191 | 192 | def find_matches(v, distance): 193 | if distance == 'NN': 194 | 
return NN(v) 195 | return CSLS(v) 196 | 197 | def get_dictionary(n, distance): 198 | # get the first n source vocab - and project in target embedding, find their mappings 199 | srcSubset = semb[0:n] 200 | if args.gpuid>=0: 201 | with torch.cuda.device(args.gpuid): 202 | srcSubset = srcSubset.cuda() 203 | 204 | proj = generator(Variable(srcSubset, requires_grad = False)).data.cpu() 205 | 206 | D, I, COS = find_matches(proj, distance) 207 | 208 | d = {} 209 | dID = {} 210 | 211 | validationScore = 0 212 | 213 | for i in range(0, n): 214 | distance = D[i].tolist() 215 | idx = list(range(args.knn)) 216 | idx.sort(key=distance.__getitem__) 217 | if args.validDistance=='COS': 218 | validationScore += COS[i][idx[0]] 219 | else: 220 | validationScore += distance[idx[0]] 221 | dID[i] = [I[i][idx[j]] for j in range(args.knn)] 222 | d[svoc[i]] = [tvoc[I[i][idx[j]]] for j in range(args.knn)] 223 | 224 | return d, validationScore/n, dID 225 | 226 | # ------------------------------------------------------- 227 | # MODEL BUILDING 228 | 229 | discriminator = Discriminator(args) 230 | generator = Generator(args) 231 | 232 | print("* Loss Initialization") 233 | loss_fn = nn.BCELoss() 234 | print(loss_fn) 235 | 236 | zeroClass = Variable(torch.Tensor(args.batchSize).fill_(0), requires_grad = False) 237 | oneClass = Variable(torch.Tensor(args.batchSize).fill_(1), requires_grad = False) 238 | smoothedOneClass = Variable(torch.Tensor(args.batchSize).fill_(1-args.smoothing), requires_grad = False) 239 | 240 | if args.gpuid>=0: 241 | with torch.cuda.device(args.gpuid): 242 | generator = generator.cuda() 243 | discriminator = discriminator.cuda() 244 | zeroClass = zeroClass.cuda() 245 | oneClass = oneClass.cuda() 246 | smoothedOneClass = smoothedOneClass.cuda() 247 | 248 | learningRate = args.learningRate 249 | 250 | # ------------------------------------------------------- 251 | # TRAINING 252 | 253 | if args.nEpochs>0: 254 | print("* Start Training") 255 | valids = [] 256 | optimalScore 
= 10000000 257 | stopCondition = False 258 | it = 1 259 | while it <= args.nEpochs and not stopCondition: 260 | genLoss = 0 261 | discLoss = 0 262 | print(" * Epoch", it) 263 | bar = progressbar.ProgressBar() 264 | N = min(args.samplingRange, args.vocSize) 265 | for i in bar(range(0, math.ceil(N/args.batchSize))): 266 | for j in range(0, args.k): 267 | bsrcIdx = torch.min((torch.rand(args.batchSize)*N).long(), torch.LongTensor([N-1])) 268 | btgtIdx = torch.min((torch.rand(args.batchSize)*N).long(), torch.LongTensor([N-1])) 269 | batch_src = Variable(torch.index_select(semb, 0, bsrcIdx)) 270 | batch_tgt = Variable(torch.index_select(temb, 0, btgtIdx)) 271 | if args.gpuid>=0: 272 | with torch.cuda.device(args.gpuid): 273 | batch_src = batch_src.cuda() 274 | batch_tgt = batch_tgt.cuda() 275 | 276 | # projection of source in target 277 | projectedSrc = generator(batch_src) 278 | 279 | discriminator.zero_grad() 280 | 281 | # calculate loss for batch src projected in target 282 | discProjSrc = discriminator(projectedSrc).squeeze() 283 | loss = loss_fn(discProjSrc, zeroClass) 284 | discLoss = discLoss + loss.data[0] 285 | loss.backward() 286 | 287 | # loss for tgt classified with smoothed label 288 | discTgt = discriminator(batch_tgt).squeeze() 289 | loss = loss_fn(discTgt, smoothedOneClass) 290 | discLoss = discLoss + loss.data[0] 291 | loss.backward() 292 | 293 | for param in discriminator.parameters(): 294 | param.data -= learningRate * param.grad.data 295 | 296 | bsrcIdx = torch.min((torch.rand(args.batchSize)*N).long(), torch.LongTensor([N-1])) 297 | batch_src = Variable(torch.index_select(semb, 0, bsrcIdx)) 298 | if args.gpuid>=0: 299 | with torch.cuda.device(args.gpuid): 300 | batch_src = batch_src.cuda() 301 | 302 | # calculate loss for batch src projected in target 303 | projectedSrc = generator(batch_src) 304 | discProjSrc = discriminator(projectedSrc).squeeze() 305 | 306 | generator.zero_grad() 307 | loss = loss_fn(discProjSrc, oneClass) 308 | genLoss = genLoss 
+ loss.data[0] 309 | loss.backward() 310 | 311 | for param in generator.parameters(): 312 | param.data -= learningRate * param.grad.data 313 | 314 | generator.orthogonalityUpdate(args.beta) 315 | 316 | evalScore = 'n/a' 317 | d, validationScore, dID = get_dictionary(10000, args.distance) 318 | 319 | if evalDict: 320 | evalScore = eval_dictionary(d) 321 | 322 | print(' * --- ',it,'genLoss=',genLoss*args.batchSize/N, 'discLoss=', discLoss*args.batchSize/N/args.k, 323 | 'learningRate=', learningRate, 'valid=', validationScore, 'eval=', evalScore) 324 | 325 | valids.append(validationScore) 326 | 327 | if validationScore < optimalScore: 328 | generator.save(args.save+"_adversarial.t7") 329 | optimalScore = validationScore 330 | optimalScoreIt = it 331 | optimalScoreLR = learningRate 332 | print(' => saved as optimal W') 333 | 334 | # if validationScore increases than halfDecayThreshold % above optimal score 335 | # or no progress for args.haldDecayDelay epochs - come back to optimal and half decay 336 | if ((validationScore-optimalScore)/abs(optimalScore) > args.halfDecayThreshold 337 | or it - optimalScoreIt > args.halfDecayDelay): 338 | generator.load(args.save+"_adversarial.t7") 339 | it = optimalScoreIt 340 | print(' ***** HALF DECAY - go back to iteration ', it) 341 | learningRate = optimalScoreLR / 2 342 | optimalScoreLR = learningRate 343 | else: 344 | learningRate = learningRate * args.decayRate 345 | 346 | it += 1 347 | # stop completely when learningRate is not more than 20 initial learning rate 348 | stopCondition = learningRate < args.learningRate / 20 349 | 350 | # ------------------------------------------------------- 351 | # EXTRACT 10000 first entries and calculate W using Procrustes solution 352 | 353 | print('* reloading best saved') 354 | generator.load(args.save+"_adversarial.t7") 355 | 356 | if args.refinementIterations > 0: 357 | print('* Start Refining') 358 | 359 | evalScore = 'n/a' 360 | d, v, dID = get_dictionary(10000, args.distance) 361 | 
362 | if evalDict: 363 | evalScore = eval_dictionary(d) 364 | 365 | print(' - CSLS score before refinement', v, evalScore) 366 | for itref in range(args.refinementIterations): 367 | ne = len(d.keys()) 368 | X = np.zeros((ne, args.dim)) 369 | Y = np.zeros((ne, args.dim)) 370 | idx = 0 371 | for k in dID.keys(): 372 | X[idx] = semb[k].numpy() 373 | Y[idx] = temb[dID[k][0]].numpy() 374 | idx = idx + 1 375 | A = np.matmul(Y.transpose(), X) 376 | U, s, V = np.linalg.svd(A, full_matrices=True) 377 | WP = np.matmul(U, V) 378 | generator.set(torch.from_numpy(WP)) 379 | d, v, dID = get_dictionary(10000, args.distance) 380 | 381 | evalScore = 'n/a' 382 | if evalDict: 383 | evalScore = eval_dictionary(d) 384 | 385 | print(' - CSLS score after refinement iteration', v, evalScore) 386 | 387 | generator.save(args.save+"_refinement.t7") 388 | 389 | # ------------------------------------------------------- 390 | # GET RESULTS 391 | 392 | if args.dump_output: 393 | with open(args.dump_output, 'w') as fd: 394 | d, v, dID = get_dictionary(args.vocSize, args.distance) 395 | for k in d.keys(): 396 | fd.write(k+"\t"+"\t".join(d[k])+"\n") 397 | --------------------------------------------------------------------------------