├── .gitignore ├── .vscode └── settings.json ├── README.md ├── config.py ├── datasets.py ├── eval_generation.ipynb ├── experiment.py ├── losses.py ├── models.py ├── preprocess.py ├── requirements.txt ├── settings.py ├── train.py ├── update_functions.py ├── utils.py └── vocab.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Linux template 3 | *~ 4 | 5 | # temporary files which can be created if a process still has a handle open of a deleted file 6 | .fuse_hidden* 7 | 8 | # KDE directory preferences 9 | .directory 10 | 11 | # Linux trash folder which might appear on any partition or disk 12 | .Trash-* 13 | 14 | # .nfs files are created when an open file is removed but is still being accessed 15 | .nfs* 16 | ### macOS template 17 | # General 18 | .DS_Store 19 | .AppleDouble 20 | .LSOverride 21 | 22 | # Icon must end with two \r 23 | Icon 24 | 25 | # Thumbnails 26 | ._* 27 | 28 | # Files that might appear in the root of a volume 29 | .DocumentRevisions-V100 30 | .fseventsd 31 | .Spotlight-V100 32 | .TemporaryItems 33 | .Trashes 34 | .VolumeIcon.icns 35 | .com.apple.timemachine.donotpresent 36 | 37 | # Directories potentially created on remote AFP share 38 | .AppleDB 39 | .AppleDesktop 40 | Network Trash Folder 41 | Temporary Items 42 | .apdisk 43 | ### JetBrains template 44 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 45 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 46 | 47 | # User-specific stuff 48 | .idea/**/workspace.xml 49 | .idea/**/tasks.xml 50 | .idea/**/usage.statistics.xml 51 | .idea/**/dictionaries 52 | .idea/**/shelf 53 | 54 | # Sensitive or high-churn files 55 | .idea/**/dataSources/ 56 | .idea/**/dataSources.ids 57 | .idea/**/dataSources.local.xml 58 | .idea/**/sqlDataSources.xml 59 | .idea/**/dynamic.xml 60 | .idea/**/uiDesigner.xml 61 | 
.idea/**/dbnavigator.xml 62 | 63 | # Gradle 64 | .idea/**/gradle.xml 65 | .idea/**/libraries 66 | 67 | # Gradle and Maven with auto-import 68 | # When using Gradle or Maven with auto-import, you should exclude module files, 69 | # since they will be recreated, and may cause churn. Uncomment if using 70 | # auto-import. 71 | # .idea/modules.xml 72 | # .idea/*.iml 73 | # .idea/modules 74 | 75 | # CMake 76 | cmake-build-*/ 77 | 78 | # Mongo Explorer plugin 79 | .idea/**/mongoSettings.xml 80 | 81 | # File-based project format 82 | *.iws 83 | 84 | # IntelliJ 85 | out/ 86 | 87 | # mpeltonen/sbt-idea plugin 88 | .idea_modules/ 89 | 90 | # JIRA plugin 91 | atlassian-ide-plugin.xml 92 | 93 | # Cursive Clojure plugin 94 | .idea/replstate.xml 95 | 96 | # Crashlytics plugin (for Android Studio and IntelliJ) 97 | com_crashlytics_export_strings.xml 98 | crashlytics.properties 99 | crashlytics-build.properties 100 | fabric.properties 101 | 102 | # Editor-based Rest Client 103 | .idea/httpRequests 104 | ### Python template 105 | # Byte-compiled / optimized / DLL files 106 | __pycache__/ 107 | *.py[cod] 108 | *$py.class 109 | 110 | # C extensions 111 | *.so 112 | 113 | # Distribution / packaging 114 | .Python 115 | build/ 116 | develop-eggs/ 117 | dist/ 118 | downloads/ 119 | eggs/ 120 | .eggs/ 121 | lib/ 122 | lib64/ 123 | parts/ 124 | sdist/ 125 | var/ 126 | wheels/ 127 | *.egg-info/ 128 | .installed.cfg 129 | *.egg 130 | MANIFEST 131 | 132 | # PyInstaller 133 | # Usually these files are written by a python script from a template 134 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
135 | *.manifest 136 | *.spec 137 | 138 | # Installer logs 139 | pip-log.txt 140 | pip-delete-this-directory.txt 141 | 142 | # Unit test / coverage reports 143 | htmlcov/ 144 | .tox/ 145 | .coverage 146 | .coverage.* 147 | .cache 148 | nosetests.xml 149 | coverage.xml 150 | *.cover 151 | .hypothesis/ 152 | .pytest_cache/ 153 | 154 | # Translations 155 | *.mo 156 | *.pot 157 | 158 | # Django stuff: 159 | *.log 160 | local_settings.py 161 | db.sqlite3 162 | 163 | # Flask stuff: 164 | instance/ 165 | .webassets-cache 166 | 167 | # Scrapy stuff: 168 | .scrapy 169 | 170 | # Sphinx documentation 171 | docs/_build/ 172 | 173 | # PyBuilder 174 | target/ 175 | 176 | # Jupyter Notebook 177 | .ipynb_checkpoints 178 | 179 | # pyenv 180 | .python-version 181 | 182 | # celery beat schedule file 183 | celerybeat-schedule 184 | 185 | # SageMath parsed files 186 | *.sage.py 187 | 188 | # Environments 189 | .env 190 | .venv 191 | env/ 192 | venv/ 193 | ENV/ 194 | env.bak/ 195 | venv.bak/ 196 | 197 | # Spyder project settings 198 | .spyderproject 199 | .spyproject 200 | 201 | # Rope project settings 202 | .ropeproject 203 | 204 | # mkdocs documentation 205 | /site 206 | 207 | # mypy 208 | .mypy_cache/ 209 | 210 | 211 | data/ 212 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/aromanov/virt_env/adversarial_decomposition/bin/python" 3 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Decomposition of Text Representation 2 | The code for the paper "Adversarial Decomposition of Text Representation", NAACL 2019 3 | https://arxiv.org/abs/1808.09042 4 | 5 | # Installation 6 | 7 | 1. Clone this repo: `https://github.com/text-machine-lab/adversarial_decomposition.git` 8 | 2. 
Install NumPy: `pip install numpy==1.16.3` 9 | 3. Install PyTorch v1.1.0: `pip install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl` (for python3.6, use `pip install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp36-cp36m-linux_x86_64.whl`) 10 | 4. Install dependencies: `pip install -r requirements.txt` 11 | 5. Download spacy models: `python -m spacy download en_core_web_lg` 12 | 13 | # Initial setup 14 | 15 | 1. Create dir `mkdir -p data/experiments` 16 | 2. Create dir `mkdir -p data/datasets` 17 | 3. Create dir `mkdir -p data/word_embeddings` 18 | 4. Download the Shakespeare data: `git clone https://github.com/cocoxu/Shakespeare.git data/datasets/shakespeare` 19 | 5. Download the Yelp data: `git clone https://github.com/shentianxiao/language-style-transfer.git data/datasets/yelp` 20 | 6. Download the pickled GloVe embeddings: `wget https://mednli.blob.core.windows.net/shared/word_embeddings/glove.840B.300d.pickled -O data/word_embeddings/glove.840B.300d.pickled` 21 | 7. Download the pickled fastText embeddings: `wget https://mednli.blob.core.windows.net/shared/word_embeddings/crawl-300d-2M.pickled -O data/word_embeddings/crawl-300d-2M.pickled` 22 | 23 | # Running the code 24 | 25 | Global constants are set in the file `settings.py`. In general, you don't need to change this file. 26 | Experiment parameters are set in the `config.py` file. 27 | 28 | First, run the preprocessing script: `python preprocess.py` 29 | This script will print the ID of the preprocessing experiment, for example `preprocess.buppgpnf`. Copy this ID and set the `preprocess_exp_id` parameter of the `TrainConfig` class (line 12 of `config.py`) accordingly. 30 | 31 | After you set the preprocess experiment id, run the training: `python train.py`. 32 | This script will also print the ID of the training experiment. You can paste it in the `eval_generation.ipynb` notebook to play with the model. 
33 | 34 | ## Changing the form and meaning 35 | The provided `eval_generation.ipynb` notebook shows how to use the model to swap the meaning and form vectors of the input sentences! 36 | 37 | 38 | # Citation 39 | If you find this code helpful, please consider citing our paper: 40 | 41 | *A. Romanov, A. Rumshisky, A. Rogers, D. Donahue, Adversarial decomposition of text representation, In Proceedings of NAACL 2019: Conference of the North American Chapter of the Association for Computational Linguistics, 2019* 42 | 43 | https://arxiv.org/abs/1808.09042 44 | 45 | ``` 46 | @inproceedings{romanov2019adversarial, 47 | title={Adversarial Decomposition of Text Representation}, 48 | author={Romanov, Alexey and Rumshisky, Anna and Rogers, Anna and Donahue, David}, 49 | booktitle={Proceedings of NAACL 2019: Conference of the North American Chapter of the Association for Computational Linguistics}, 50 | year={2019} 51 | } 52 | ``` 53 | 54 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from pathlib import Path 3 | 4 | from datasets import ShakespeareDatasetReader, YelpDatasetReader 5 | from models import Seq2Seq, Seq2SeqMeaningStyle, StyleClassifier 6 | from settings import SHAKESPEARE_DATASET_DIR, YELP_DATASET_DIR 7 | 8 | 9 | @dataclasses.dataclass 10 | class TrainConfig: 11 | model_class: type = Seq2SeqMeaningStyle 12 | preprocess_exp_id: str = 'preprocess.buppgpnf' # Shakespeare: xxx | Yelp: 2p089c54 13 | 14 | embedding_size: int = 300 15 | hidden_size: int = 256 16 | dropout: float = 0.2 17 | scheduled_sampling_ratio: float = 0.5 18 | pretrained_embeddings: bool = True 19 | trainable_embeddings: bool = False 20 | 21 | meaning_size: int = 128 22 | style_size: int = 128 23 | 24 | lr: float = 0.001 25 | weight_decay: float = 0.0000001 26 | grad_clipping: float = 5 27 | 28 | D_num_iterations: int = 10 29 | D_loss_multiplier: 
def _iter_stripped_lines(path):
    """Yield each line of the text file at *path* with surrounding whitespace removed.

    The file is decoded as UTF-8 explicitly so that reading does not depend on
    the locale of the machine running the preprocessing.
    """
    with open(path, encoding='utf-8') as f:
        for line in f:
            yield line.strip()


class SentenceStyleDatasetReader(object):
    """Base reader that turns raw ``(sentence, style)`` pairs into token lists.

    Subclasses implement :meth:`_read` to yield ``(sentence_text, style_label)``
    pairs; :meth:`read` cleans, tokenizes and length-filters them.
    """

    # Applied in this exact order by `clean_sentence`; the order matters because
    # the patterns overlap (e.g. "n't" has to be handled as a unit).
    _REPLACEMENTS = (
        ('\r', ' '),
        ('\n', ' '),
        ("'m", ' am'),
        ("'ve", ' have'),
        ("n't", ' not'),
        ("'re", ' are'),
        ("'d", ' would'),
        ("'ll", ' will'),
    )

    def __init__(self, min_len, max_len, lowercase, *args, **kwargs):
        """Store the length limits and build the spaCy pipeline.

        Args:
            min_len: sentences must have MORE than this many tokens to be kept.
            max_len: maximum sentence length including the END token that is
                appended later, so at most ``max_len - 1`` tokens are kept.
            lowercase: if true, tokens are lower-cased.
            *args, **kwargs: accepted and ignored.
        """
        self.min_len = min_len
        self.max_len = max_len
        self.lowercase = lowercase

        # The listed pipeline components are switched off; a sentencizer is
        # added on top of what remains.
        disable = ['vectors', 'textcat', 'tagger', 'parser', 'ner']
        self.spacy = spacy.load('en_core_web_lg', disable=disable)
        self.spacy.add_pipe(self.spacy.create_pipe('sentencizer'))

    @abc.abstractmethod
    def _read(self, data_path):
        """Yield ``(sentence_text, style_label)`` pairs found under *data_path*.

        Raises:
            NotImplementedError: always, in this base class. The class is not
                an ``abc.ABC``, so the decorator alone does not enforce the
                override; failing loudly here replaces the confusing
                "NoneType is not iterable" error that `read` used to produce.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement _read(data_path)'
        )

    def clean_sentence(self, sentence):
        """Return *sentence* with line breaks removed and common contractions expanded."""
        sentence_cleaned = sentence
        for pattern, replacement in self._REPLACEMENTS:
            sentence_cleaned = sentence_cleaned.replace(pattern, replacement)

        return sentence_cleaned

    def preprocess_sentence(self, sentence):
        """Convert an iterable of spaCy-like tokens into a list of strings.

        Whitespace tokens are dropped, tokens are lower-cased when
        ``self.lowercase`` is set, and the result is truncated to
        ``max_len - 1`` items to leave room for the END token.
        """
        tokens = [
            token.lower_ if self.lowercase else token.text
            for token in sentence
            if not token.is_space
        ]

        # cut to max len -1 for the END token
        return tokens[:self.max_len - 1]

    def read(self, data_path):
        """Read, clean, tokenize and filter every sentence under *data_path*.

        Returns:
            A list of ``{'sentence': [tokens], 'style': label}`` dicts holding
            only sentences with strictly more than ``min_len`` tokens.
        """
        samples = []

        for sentence, style in self._read(data_path):
            tokens = self.preprocess_sentence(self.spacy(self.clean_sentence(sentence)))

            if len(tokens) > self.min_len:
                samples.append(dict(sentence=tokens, style=style))

        return samples


class ShakespeareDatasetReader(SentenceStyleDatasetReader):
    """Reader for a directory of ``*original.snt.aligned`` / ``*modern.snt.aligned`` files."""

    def _read(self, data_path):
        """Yield ``(sentence, style)`` for every line of every file in *data_path*.

        The style is ``'original'`` or ``'modern'`` according to the file name
        suffix, and ``'-'`` for any other file.
        """
        # Sorted so that the sample order does not depend on the filesystem.
        for file in sorted(data_path.iterdir()):
            # Sub-directories cannot be opened as text files; skip them.
            if not file.is_file():
                continue

            file_style = '-'
            if file.name.endswith('original.snt.aligned'):
                file_style = 'original'
            if file.name.endswith('modern.snt.aligned'):
                file_style = 'modern'

            for sentence in _iter_stripped_lines(file):
                yield sentence, file_style


class YelpDatasetReader(SentenceStyleDatasetReader):
    """Reader for the ``sentiment.{train,dev}.{0,1}`` Yelp files."""

    def clean_sentence(self, sentence):
        """Apply the base cleaning, then spell out the ``_num_`` placeholder as ``number``."""
        sentence = super().clean_sentence(sentence)

        return sentence.replace("_num_", 'number')

    def _read(self, data_path):
        """Yield ``(sentence, style)`` from the train and dev files under *data_path*.

        Files ending in ``0`` are labelled ``'negative'`` and files ending in
        ``1`` are labelled ``'positive'``. Only the train and dev files are
        read.
        """
        files = [
            data_path.joinpath('sentiment.train.0'),
            data_path.joinpath('sentiment.train.1'),
            data_path.joinpath('sentiment.dev.0'),
            data_path.joinpath('sentiment.dev.1'),
        ]
        style_by_suffix = {'0': 'negative', '1': 'positive'}

        for file in files:
            file_style = style_by_suffix.get(file.name[-1:], '-')

            for sentence in _iter_stripped_lines(file):
                yield sentence, file_style


class SentenceStyleDataset(torch.utils.data.Dataset):
    """Torch dataset of encoded sentences and their style labels.

    NOTE: the constructor encodes every instance IN PLACE, adding
    ``'sentence_enc'`` and ``'style_enc'`` keys to the dicts it was given.
    """

    def __init__(self, instances, vocab, style_vocab):
        """Encode *instances* using *vocab* (tokens) and *style_vocab* (labels)."""
        self.instances = instances
        self.vocab = vocab
        self.style_vocab = style_vocab

        # +1 for the END token; `default=0` keeps an empty dataset constructible.
        self.max_len = max((len(inst['sentence']) for inst in instances), default=0) + 1

        for inst in self.instances:
            inst.update(self.encode_instance(inst))

    def pad_sentence(self, sentence):
        """Return a new list: *sentence* + END token + PAD tokens up to ``max_len``."""
        # add end token
        sentence = sentence + [Vocab.END_TOKEN, ]

        # pad (a non-positive count adds nothing, so long sentences are left as they are)
        sentence = sentence + [Vocab.PAD_TOKEN, ] * (self.max_len - len(sentence))

        return sentence

    def encode_instance(self, instance):
        """Return ``{'sentence_enc': int64 array, 'style_enc': style id}`` for *instance*."""
        sentence, style = instance['sentence'], instance['style']

        sentence = self.pad_sentence(sentence)
        # np.int64 instead of np.long: the alias was removed in NumPy 1.24, and
        # 64-bit integers are what torch expects for index tensors.
        sentence_enc = np.array(
            [self.vocab.get(t, Vocab.UNK_TOKEN) for t in sentence],
            dtype=np.int64,
        )

        style_enc = self.style_vocab[style]

        return dict(
            sentence_enc=sentence_enc,
            style_enc=style_enc,
        )

    def __getitem__(self, index):
        """Return the encoded view ``{'sentence': ..., 'style': ...}`` of instance *index*."""
        inst = self.instances[index]

        return {
            'sentence': inst['sentence_enc'],
            'style': inst['style_enc'],
        }

    def __len__(self):
        """Return the number of instances."""
        return len(self.instances)


class MeaningEmbeddingSentenceStyleDataset(SentenceStyleDataset):
    """`SentenceStyleDataset` that also stores a "meaning embedding" per sentence.

    The meaning embedding is the average word vector of the tokens that remain
    after dropping the ones with the largest absolute value on the configured
    style dimensions.
    """

    def __init__(self, W_emb, style_dimensions, style_tokens_proportion, *args, **kwargs):
        """Build the base dataset, then precompute ``'meaning_embedding'`` for every instance.

        Args:
            W_emb: embedding matrix indexed by vocabulary id.
                # NOTE(review): assumed to be a 2-D NumPy array - confirm.
            style_dimensions: indices of the embedding dimensions treated as style.
            style_tokens_proportion: fraction of a sentence's tokens (rounded
                up) that are treated as style tokens and excluded.
            *args, **kwargs: forwarded to `SentenceStyleDataset`.
        """
        super().__init__(*args, **kwargs)

        self.W_emb = W_emb
        self.style_dimensions = style_dimensions
        self.style_tokens_proportion = style_tokens_proportion

        for inst in self.instances:
            inst['meaning_embedding'] = self.calc_meaning_embedding(inst, W_emb)

    def calc_meaning_embedding(self, instance, W_emb):
        """Return the mean embedding of the non-style tokens of *instance*.

        A zero vector is returned when there are no tokens, or when every
        token counts as a style token. The previous code raised IndexError in
        the first case and produced a NaN vector (0/0) in the second.
        """
        special_tokens = {Vocab.END_TOKEN, Vocab.PAD_TOKEN, Vocab.UNK_TOKEN}
        tokens = [t for t in instance['sentence'] if t not in special_tokens]

        nb_tokens = len(tokens)
        if nb_tokens == 0:
            return np.zeros(np.shape(W_emb)[-1], dtype=np.asarray(W_emb).dtype)

        nb_style_tokens = int(np.ceil(nb_tokens * self.style_tokens_proportion))

        sentence_embedding = np.array([W_emb[self.vocab[t]] for t in tokens])
        # Rank tokens by their strongest activation on the style dimensions, strongest first.
        style_strength = np.abs(sentence_embedding[:, self.style_dimensions]).max(axis=-1)
        sorted_by_style_dim_idx = np.argsort(-style_strength)
        meaning_idx = sorted_by_style_dim_idx[nb_style_tokens:]

        # Clamp to 1 so that "no meaning tokens left" gives zeros instead of 0/0.
        nb_meaning_tokens = max(nb_tokens - nb_style_tokens, 1)
        meaning_embedding = np.sum(sentence_embedding[meaning_idx], axis=0) / nb_meaning_tokens

        return meaning_embedding

    def __getitem__(self, index):
        """Return the base item extended with its precomputed ``'meaning_embedding'``."""
        inst = super().__getitem__(index)

        inst['meaning_embedding'] = self.instances[index]['meaning_embedding']

        return inst
{ 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "# Load everything" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "exp = Experiment.load(EXPERIMENTS_DIR, exp_id)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "data": { 55 | "text/plain": [ 56 | "TrainConfig(model_class=, preprocess_exp_id='preprocess.buppgpnf', embedding_size=300, hidden_size=256, dropout=0.2, scheduled_sampling_ratio=0.5, pretrained_embeddings=True, trainable_embeddings=False, meaning_size=128, style_size=128, lr=0.001, weight_decay=1e-07, grad_clipping=5, D_num_iterations=10, D_loss_multiplier=1, P_loss_multiplier=10, P_bow_loss_multiplier=1, use_discriminator=True, use_predictor=False, use_predictor_bow=True, use_motivator=True, use_gauss=False, num_epochs=500, batch_size=1024, best_loss='loss')" 57 | ] 58 | }, 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "output_type": "execute_result" 62 | } 63 | ], 64 | "source": [ 65 | "exp.config" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 5, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "Dataset: 453655, val: 10000, test: 10000\n", 78 | "Vocab: 9419, style vocab: 2\n", 79 | "W_emb: (9419, 300)\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "preprocess_exp = Experiment.load(EXPERIMENTS_DIR, exp.config.preprocess_exp_id)\n", 85 | "dataset_train, dataset_val, dataset_test, vocab, style_vocab, W_emb = load_dataset(preprocess_exp)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 6, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "dataset_reader = create_dataset_reader(preprocess_exp.config)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 7, 100 | "metadata": {}, 101 | "outputs": [], 102 | 
"source": [ 103 | "model = create_model(exp.config, vocab, style_vocab, dataset_train.max_len, W_emb)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 8, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "load_weights(model, exp.experiment_dir.joinpath('best.th'))" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 9, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "model = model.eval()" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "## Predict" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 10, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "def create_inputs(instances):\n", 138 | " if not isinstance(instances, list):\n", 139 | " instances = [instances,]\n", 140 | " \n", 141 | " if not isinstance(instances[0], dict):\n", 142 | " sentences = [\n", 143 | " dataset_reader.preprocess_sentence(dataset_reader.spacy( dataset_reader.clean_sentence(sent)))\n", 144 | " for sent in instances\n", 145 | " ]\n", 146 | " \n", 147 | " style = list(style_vocab.token2id.keys())[0]\n", 148 | " instances = [\n", 149 | " {\n", 150 | " 'sentence': sent,\n", 151 | " 'style': style,\n", 152 | " }\n", 153 | " for sent in sentences\n", 154 | " ]\n", 155 | " \n", 156 | " for inst in instances:\n", 157 | " inst_encoded = dataset_train.encode_instance(inst)\n", 158 | " inst.update(inst_encoded) \n", 159 | " \n", 160 | " \n", 161 | " instances = [\n", 162 | " {\n", 163 | " 'sentence': inst['sentence_enc'],\n", 164 | " 'style': inst['style_enc'],\n", 165 | " } \n", 166 | " for inst in instances\n", 167 | " ]\n", 168 | " \n", 169 | " instances = default_collate(instances)\n", 170 | " instances = to_device(instances) \n", 171 | " \n", 172 | " return instances" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 11, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 
| "def get_sentences(outputs):\n", 182 | " predicted_indices = outputs[\"predictions\"]\n", 183 | " end_idx = vocab[Vocab.END_TOKEN]\n", 184 | " \n", 185 | " if not isinstance(predicted_indices, np.ndarray):\n", 186 | " predicted_indices = predicted_indices.detach().cpu().numpy()\n", 187 | "\n", 188 | " all_predicted_tokens = []\n", 189 | " for indices in predicted_indices:\n", 190 | " indices = list(indices)\n", 191 | "\n", 192 | " # Collect indices till the first end_symbol\n", 193 | " if end_idx in indices:\n", 194 | " indices = indices[:indices.index(end_idx)]\n", 195 | "\n", 196 | " predicted_tokens = [vocab.id2token[x] for x in indices]\n", 197 | " all_predicted_tokens.append(predicted_tokens)\n", 198 | " \n", 199 | " return all_predicted_tokens" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 12, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "sentence = ' '.join(dataset_val.instances[1]['sentence'])" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 13, 214 | "metadata": {}, 215 | "outputs": [ 216 | { 217 | "data": { 218 | "text/plain": [ 219 | "'they are really good people .'" 220 | ] 221 | }, 222 | "execution_count": 13, 223 | "metadata": {}, 224 | "output_type": "execute_result" 225 | } 226 | ], 227 | "source": [ 228 | "sentence" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 14, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "inputs = create_inputs(sentence)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 15, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "outputs = model(inputs)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 16, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "sentences = get_sentences(outputs)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 17, 261 | "metadata": {}, 262 | "outputs": 
[ 263 | { 264 | "data": { 265 | "text/plain": [ 266 | "'they are really good people .'" 267 | ] 268 | }, 269 | "execution_count": 17, 270 | "metadata": {}, 271 | "output_type": "execute_result" 272 | } 273 | ], 274 | "source": [ 275 | "' '.join(sentences[0])" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "### Swap style" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 18, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "possible_styles = list(style_vocab.token2id.keys()) #['negative', 'positive']" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 19, 297 | "metadata": {}, 298 | "outputs": [ 299 | { 300 | "data": { 301 | "text/plain": [ 302 | "['negative', 'positive']" 303 | ] 304 | }, 305 | "execution_count": 19, 306 | "metadata": {}, 307 | "output_type": "execute_result" 308 | } 309 | ], 310 | "source": [ 311 | "possible_styles" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 20, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "sentences0 = [s for s in dataset_val.instances if s['style'] == possible_styles[0]]\n", 321 | "sentences1 = [s for s in dataset_val.instances if s['style'] == possible_styles[1]]" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 52, 327 | "metadata": {}, 328 | "outputs": [ 329 | { 330 | "name": "stdout", 331 | "output_type": "stream", 332 | "text": [ 333 | "3239 if i could give negative stars i certainly would for this place .\n", 334 | "2083 if it was not broke , why in the world would you fix it ?\n", 335 | "3874 the rice had hard things in it .\n", 336 | "1569 quite possibly the worst experience of my life .\n", 337 | "3584 however our little one ordered buttered noodles and was pleased as punch .\n" 338 | ] 339 | } 340 | ], 341 | "source": [ 342 | "for i in np.random.choice(np.arange(len(sentences0)), 5):\n", 343 | " print(i, ' 
'.join(sentences0[i]['sentence']))" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 22, 349 | "metadata": {}, 350 | "outputs": [ 351 | { 352 | "name": "stdout", 353 | "output_type": "stream", 354 | "text": [ 355 | "4935 which is awesome !\n", 356 | "1561 it is also a good place to go just for dessert .\n", 357 | "1208 they are amazing , truly , could not be happier .\n", 358 | "1347 i had the tamales and they were the best i have ever had !\n", 359 | "3450 the capistrami is the best thing ever .\n" 360 | ] 361 | } 362 | ], 363 | "source": [ 364 | "for i in np.random.choice(np.arange(len(sentences1)), 5):\n", 365 | " print(i, ' '.join(sentences1[i]['sentence']))" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "metadata": {}, 371 | "source": [ 372 | "#### Swap" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 53, 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "target0 = 3874 # np.random.choice(np.arange(len(sentences0)))\n", 382 | "target1 = 4935 # np.random.choice(np.arange(len(sentences0)))" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 54, 388 | "metadata": {}, 389 | "outputs": [ 390 | { 391 | "name": "stdout", 392 | "output_type": "stream", 393 | "text": [ 394 | "the rice had hard things in it .\n" 395 | ] 396 | } 397 | ], 398 | "source": [ 399 | "print(' '.join(sentences0[target0]['sentence']))" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 55, 405 | "metadata": {}, 406 | "outputs": [ 407 | { 408 | "name": "stdout", 409 | "output_type": "stream", 410 | "text": [ 411 | "which is awesome !\n" 412 | ] 413 | } 414 | ], 415 | "source": [ 416 | "print(' '.join(sentences1[target1]['sentence']))" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 56, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "inputs = create_inputs([\n", 426 | " sentences0[target0],\n", 427 | " 
sentences1[target1],\n", 428 | "])" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": 57, 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [ 437 | "z_hidden = model(inputs)" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 58, 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "data": { 447 | "text/plain": [ 448 | "torch.Size([2, 128])" 449 | ] 450 | }, 451 | "execution_count": 58, 452 | "metadata": {}, 453 | "output_type": "execute_result" 454 | } 455 | ], 456 | "source": [ 457 | "z_hidden['style_hidden'].shape" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 59, 463 | "metadata": {}, 464 | "outputs": [ 465 | { 466 | "data": { 467 | "text/plain": [ 468 | "torch.Size([2, 128])" 469 | ] 470 | }, 471 | "execution_count": 59, 472 | "metadata": {}, 473 | "output_type": "execute_result" 474 | } 475 | ], 476 | "source": [ 477 | "z_hidden['meaning_hidden'].shape" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": 60, 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "original_decoded = model.decode(z_hidden)" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": 61, 492 | "metadata": {}, 493 | "outputs": [], 494 | "source": [ 495 | "original_sentences = get_sentences(original_decoded)" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": 62, 501 | "metadata": {}, 502 | "outputs": [ 503 | { 504 | "name": "stdout", 505 | "output_type": "stream", 506 | "text": [ 507 | "the rice had hard things in it .\n", 508 | "which is awesome !\n" 509 | ] 510 | } 511 | ], 512 | "source": [ 513 | "print(' '.join(original_sentences[0]))\n", 514 | "print(' '.join(original_sentences[1]))" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 63, 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [ 523 | "z_hidden_swapped = {\n", 524 | " 'meaning_hidden': torch.stack([\n", 525 | 
" z_hidden['meaning_hidden'][0].clone(),\n", 526 | " z_hidden['meaning_hidden'][1].clone(), \n", 527 | " ], dim=0),\n", 528 | " 'style_hidden': torch.stack([\n", 529 | " z_hidden['style_hidden'][1].clone(),\n", 530 | " z_hidden['style_hidden'][0].clone(), \n", 531 | " ], dim=0),\n", 532 | "}" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 64, 538 | "metadata": {}, 539 | "outputs": [], 540 | "source": [ 541 | "swaped_decoded = model.decode(z_hidden_swapped)" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": 65, 547 | "metadata": {}, 548 | "outputs": [], 549 | "source": [ 550 | "swaped_sentences = get_sentences(swaped_decoded)" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": 66, 556 | "metadata": { 557 | "scrolled": true 558 | }, 559 | "outputs": [ 560 | { 561 | "name": "stdout", 562 | "output_type": "stream", 563 | "text": [ 564 | "the rice had hard things in it .\n", 565 | "which is awesome !\n", 566 | "\n", 567 | "plus is really hard to it .\n", 568 | "the rice was awesome .\n" 569 | ] 570 | } 571 | ], 572 | "source": [ 573 | "print(' '.join(original_sentences[0]))\n", 574 | "print(' '.join(original_sentences[1]))\n", 575 | "print()\n", 576 | "print(' '.join(swaped_sentences[0]))\n", 577 | "print(' '.join(swaped_sentences[1]))" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": null, 583 | "metadata": {}, 584 | "outputs": [], 585 | "source": [] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": null, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [] 593 | } 594 | ], 595 | "metadata": { 596 | "kernelspec": { 597 | "display_name": "Python 3", 598 | "language": "python", 599 | "name": "python3" 600 | }, 601 | "language_info": { 602 | "codemirror_mode": { 603 | "name": "ipython", 604 | "version": 3 605 | }, 606 | "file_extension": ".py", 607 | "mimetype": "text/x-python", 608 | "name": "python", 609 | "nbconvert_exporter": 
class Experiment(object):
    """A directory on disk that holds one run's pickled config and artifacts.

    Use it as a context manager to CREATE a new experiment directory, or call
    :meth:`load` to re-open an existing one::

        with Experiment(EXPERIMENTS_DIR, config, prefix='train') as exp:
            print(exp.experiment_id)
    """

    _CONFIG_FILENAME = 'config.pkl'

    def __init__(self, experiments_dir, config, prefix=None):
        """Remember the settings; nothing is created on disk until ``__enter__``.

        Args:
            experiments_dir: parent directory that holds all experiments.
            config: object stored as the experiment's configuration.
            prefix: optional name prefix for the experiment directory; a
                trailing ``'.'`` is appended to separate it from the random
                suffix generated by ``tempfile.mkdtemp``.
        """
        self.config = config
        self.experiments_dir = experiments_dir
        self.prefix = None if prefix is None else f'{prefix}.'

        # Both are filled in by `__enter__` (new experiment) or `load` (existing one).
        self.experiment_dir = None
        self.experiment_id = None

    def __enter__(self):
        """Create a uniquely named experiment directory and save the config into it."""
        self.experiment_dir = Path(tempfile.mkdtemp(dir=self.experiments_dir, prefix=self.prefix))
        self.experiment_id = self.experiment_dir.name

        # Dispatch through `self` so that subclasses can change how the config is stored.
        self._save_config(self.config, self.experiment_dir)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Do nothing: the directory is kept and exceptions are not suppressed."""
        return None

    @classmethod
    def load(cls, experiments_dir, experiment_id):
        """Re-open the existing experiment *experiment_id* under *experiments_dir*.

        Args:
            experiments_dir: parent directory, as a ``Path`` or a string.
            experiment_id: name of the experiment directory.

        Returns:
            An instance of ``cls`` whose ``config`` was read from the
            experiment directory.
        """
        experiment_dir = Path(experiments_dir).joinpath(experiment_id)

        config = cls._load_config(experiment_dir)

        exp = cls(experiments_dir, config)
        exp.experiment_dir = experiment_dir
        exp.experiment_id = exp.experiment_dir.name

        return exp

    @classmethod
    def _save_config(cls, config, experiment_dir):
        """Pickle *config* into the config file inside *experiment_dir*."""
        filename = experiment_dir.joinpath(cls._CONFIG_FILENAME)
        save_pickle(config, filename)

    @classmethod
    def _load_config(cls, experiment_dir):
        """Return the config stored inside *experiment_dir*.

        NOTE(review): `load_pickle` presumably unpickles the file, which can
        execute arbitrary code - only load experiment directories you trust.
        """
        filename = experiment_dir.joinpath(cls._CONFIG_FILENAME)
        return load_pickle(filename)
# --------------------------------------------------------------------------------
# /losses.py:
# --------------------------------------------------------------------------------
"""Loss modules used by the seq2seq / meaning-style models."""
import torch
import torch.nn.functional as F


def _mean_or_zero(values):
    """Return the mean of *values*, or a zero scalar when *values* is empty.

    ``torch.mean`` of an empty tensor is NaN; summing an empty tensor instead
    yields 0 with the same dtype/device and keeps the autograd graph intact.
    """
    if values.numel() == 0:
        return values.sum()
    return values.mean()


class SequenceReconstructionLoss(torch.nn.Module):
    """Cross-entropy between predicted token logits and target token ids.

    Args:
        ignore_index: target value that is excluded from the loss (forwarded
            to ``torch.nn.CrossEntropyLoss``).
    """

    def __init__(self, ignore_index=-100):
        super().__init__()

        self.xent_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)

    def _calc_sent_xent(self, outputs, targets):
        """Flatten per-step logits/targets if needed and apply cross-entropy.

        When ``outputs`` has more than two dimensions, ``targets`` is flattened
        to 1-D and ``outputs`` to ``(targets.numel(), -1)`` so that every
        position is scored as an independent classification.
        """
        if outputs.dim() > 2:
            # `reshape` instead of `view`: identical result for contiguous
            # inputs, but also accepts non-contiguous ones (e.g. transposed or
            # sliced tensors), where `view` raises a RuntimeError.
            targets = targets.reshape(-1)
            outputs = outputs.reshape(targets.size(0), -1)

        return self.xent_loss(outputs, targets)

    def forward(self, outputs, targets):
        """Return the scalar reconstruction loss for ``outputs`` vs ``targets``."""
        return self._calc_sent_xent(outputs, targets)


class StyleEntropyLoss(torch.nn.Module):
    """Negative binary entropy of ``sigmoid(logits)``.

    Minimising this value maximises the entropy of the element-wise sigmoid
    probabilities.
    """

    def __init__(self):
        super().__init__()

        # Keeps the arguments of `log` strictly positive.
        self.epsilon = 1e-07

    def forward(self, logits):
        """Return the mean of ``p*log(p) + (1-p)*log(1-p)`` over all elements."""
        probs = torch.sigmoid(logits)
        neg_entropy = (
            probs * torch.log(probs + self.epsilon)
            + (1 - probs) * torch.log(1 - probs + self.epsilon)
        )
        neg_entropy = torch.mean(neg_entropy, dim=-1)

        # No `-1 *` as we are going to add it to the loss
        return torch.mean(neg_entropy)


class MeaningZeroLoss(torch.nn.Module):
    """Mean squared magnitude of the prediction (pushes it towards zero)."""

    def __init__(self):
        super().__init__()

    def forward(self, predicted):
        """Return ``mean(predicted ** 2)``."""
        return torch.mean(predicted ** 2)


class LSGANDiscriminatorLoss(torch.nn.Module):
    """Least Squares GAN discriminator loss.

    Logits of samples with ``styles == 0`` are regressed towards 1 and logits
    of samples with ``styles == 1`` towards 0.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, styles):
        """Return ``0.5 * (mean((l0 - 1)^2) + mean(l1^2))``.

        Args:
            logits: discriminator outputs, indexable by the boolean masks
                ``styles == 0`` / ``styles == 1``.
            styles: style labels aligned with ``logits``.

        A style that does not occur in the batch contributes 0 to the loss
        instead of turning the whole loss into NaN.
        """
        logits_zero = logits[styles == 0]
        logits_one = logits[styles == 1]

        loss_zero = _mean_or_zero((logits_zero - 1) ** 2)
        loss_one = _mean_or_zero(logits_one ** 2)

        return 0.5 * (loss_zero + loss_one)
-------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from losses import SequenceReconstructionLoss, StyleEntropyLoss, MeaningZeroLoss 6 | from utils import get_sequences_lengths, to_device 7 | 8 | 9 | class LSTMEncoder(torch.nn.Module): 10 | def __init__(self, input_size, hidden_size, dropout, num_layers=1, bidirectional=False, return_sequence=False): 11 | super().__init__() 12 | 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.num_layers = num_layers 16 | self.bidirectional = bidirectional 17 | self.return_sequence = return_sequence 18 | 19 | self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True) 20 | 21 | def zero_state(self, batch_size): 22 | # The axes semantics are (num_layers, batch_size, hidden_dim) 23 | nb_layers = self.num_layers if not self.bidirectional else self.nb_layers * 2 24 | state_shape = (nb_layers, batch_size, self.hidden_size) 25 | 26 | # shape: (num_layers, batch_size, hidden_dim) 27 | h = to_device(torch.zeros(*state_shape)) 28 | 29 | # shape: (num_layers, batch_size, hidden_dim) 30 | c = torch.zeros_like(h) 31 | 32 | return h, c 33 | 34 | def forward(self, inputs, lengths): 35 | batch_size = inputs.shape[0] 36 | 37 | # shape: (num_layers, batch_size, hidden_dim) 38 | h, c = self.zero_state(batch_size) 39 | 40 | lengths_sorted, inputs_sorted_idx = lengths.sort(descending=True) 41 | inputs_sorted = inputs[inputs_sorted_idx] 42 | 43 | # pack sequences 44 | packed = torch.nn.utils.rnn.pack_padded_sequence(inputs_sorted, lengths_sorted.detach(), batch_first=True) 45 | 46 | # shape: (batch_size, sequence_len, hidden_dim) 47 | outputs, (h, c) = self.lstm(packed, (h, c)) 48 | 49 | # concatenate if bidirectional 50 | # shape: (batch_size, hidden_dim) 51 | h = torch.cat([x for x in 
h], dim=-1) 52 | 53 | # unpack sequences 54 | outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) 55 | 56 | _, inputs_unsorted_idx = inputs_sorted_idx.sort(descending=False) 57 | outputs = outputs[inputs_unsorted_idx] 58 | h = h[inputs_unsorted_idx] 59 | 60 | if self.return_sequence: 61 | return outputs 62 | else: 63 | return h 64 | 65 | 66 | class Squeeze(torch.nn.Module): 67 | def __init__(self, dim=-1): 68 | super().__init__() 69 | 70 | self.dim = dim 71 | 72 | def forward(self, inputs): 73 | inputs = inputs.squeeze(self.dim) 74 | return inputs 75 | 76 | 77 | class SpaceTransformer(torch.nn.Module): 78 | def __init__(self, input_size, output_size, dropout): 79 | super().__init__() 80 | 81 | self.input_size = input_size 82 | self.output_size = output_size 83 | self.dropout = dropout 84 | 85 | self.fc = torch.nn.Sequential( 86 | torch.nn.Linear(input_size, output_size), 87 | torch.nn.Dropout(dropout), 88 | # torch.nn.ELU(), 89 | torch.nn.Hardtanh(-10, 10), 90 | ) 91 | 92 | def forward(self, inputs): 93 | outputs = self.fc(inputs) 94 | return outputs 95 | 96 | 97 | class Discriminator(torch.nn.Module): 98 | def __init__(self, input_size, hidden_size, output_size, dropout): 99 | super().__init__() 100 | 101 | self.input_size = input_size 102 | self.hidden_size = hidden_size 103 | self.output_size = output_size 104 | self.dropout = dropout 105 | 106 | self.classifier = torch.nn.Sequential( 107 | torch.nn.Linear(input_size, hidden_size), 108 | torch.nn.Dropout(dropout), 109 | torch.nn.ELU(), 110 | torch.nn.Linear(hidden_size, output_size), 111 | ) 112 | 113 | def forward(self, inputs): 114 | outputs = self.classifier(inputs) 115 | return outputs 116 | 117 | 118 | class Seq2Seq(torch.nn.Module): 119 | def __init__(self, vocab_size, embedding_size, hidden_size, dropout, max_len, scheduled_sampling_ratio, 120 | start_index, end_index, pad_index, trainable_embeddings, W_emb=None, **kwargs): 121 | super().__init__() 122 | 123 | self.vocab_size = 
vocab_size 124 | self.embedding_size = embedding_size 125 | self.hidden_size = hidden_size 126 | self.max_len = max_len 127 | self.dropout = dropout 128 | self.scheduled_sampling_ratio = scheduled_sampling_ratio 129 | self.trainable_embeddings = trainable_embeddings 130 | 131 | self.start_index = start_index 132 | self.end_index = end_index 133 | self.pad_index = pad_index 134 | 135 | self.embedding = torch.nn.Embedding(vocab_size, embedding_size, padding_idx=pad_index) 136 | if W_emb is not None: 137 | self.embedding.weight.data.copy_(torch.from_numpy(W_emb)) 138 | if not trainable_embeddings: 139 | self.embedding.weight.requires_grad = False 140 | 141 | self.encoder = LSTMEncoder(embedding_size, hidden_size, dropout) 142 | self.decoder_cell = torch.nn.LSTMCell(embedding_size, hidden_size) 143 | self.output_projection = torch.nn.Linear(hidden_size, vocab_size) 144 | 145 | self._xent_loss = SequenceReconstructionLoss(ignore_index=pad_index) 146 | 147 | def encode(self, inputs): 148 | # shape: (batch_size, sequence_len) 149 | sentence = inputs['sentence'] 150 | 151 | # shape: (batch_size, ) 152 | lengths = get_sequences_lengths(sentence) 153 | 154 | # shape: (batch_size, sequence_len, embedding_size) 155 | sentence_emb = self.embedding(sentence) 156 | 157 | # shape: (batch_size, hidden_size) 158 | decoder_hidden = self.encoder(sentence_emb, lengths) 159 | 160 | output_dict = { 161 | 'decoder_hidden': decoder_hidden 162 | } 163 | 164 | return output_dict 165 | 166 | def decode(self, state, targets=None): 167 | # shape: (batch_size, hidden_size) 168 | decoder_hidden = state['decoder_hidden'] 169 | decoder_cell = torch.zeros_like(decoder_hidden) 170 | 171 | batch_size = decoder_hidden.size(0) 172 | 173 | if targets is not None: 174 | num_decoding_steps = targets.size(1) 175 | else: 176 | num_decoding_steps = self.max_len 177 | 178 | # shape: (batch_size, ) 179 | last_predictions = decoder_hidden.new_full((batch_size,), fill_value=self.start_index).long() 180 | # shape: 
(batch_size, sequence_len, vocab_size) 181 | step_logits = [] 182 | # shape: (batch_size, sequence_len, ) 183 | step_predictions = [] 184 | 185 | for timestep in range(num_decoding_steps): 186 | # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio during training. 187 | # shape: (batch_size,) 188 | decoder_input = last_predictions 189 | if timestep > 0 and self.training and torch.rand(1).item() > self.scheduled_sampling_ratio: 190 | decoder_input = targets[:, timestep - 1] 191 | 192 | # shape: (batch_size, embedding_size) 193 | decoder_input = self.embedding(decoder_input) 194 | 195 | # shape: (batch_size, hidden_size) 196 | decoder_hidden, decoder_cell = self.decoder_cell(decoder_input, (decoder_hidden, decoder_cell)) 197 | 198 | # shape: (batch_size, vocab_size) 199 | output_projection = self.output_projection(decoder_hidden) 200 | 201 | # list of tensors, shape: (batch_size, 1, vocab_size) 202 | step_logits.append(output_projection.unsqueeze(1)) 203 | 204 | # shape (predicted_classes): (batch_size,) 205 | last_predictions = torch.argmax(output_projection, 1) 206 | 207 | # list of tensors, shape: (batch_size, 1) 208 | step_predictions.append(last_predictions.unsqueeze(1)) 209 | 210 | # shape: (batch_size, max_len, vocab_size) 211 | logits = torch.cat(step_logits, 1) 212 | # shape: (batch_size, max_len) 213 | predictions = torch.cat(step_predictions, 1) 214 | 215 | state.update({ 216 | "logits": logits, 217 | "predictions": predictions, 218 | }) 219 | 220 | return state 221 | 222 | def calc_loss(self, output_dict, inputs): 223 | # shape: (batch_size, sequence_len) 224 | targets = inputs['sentence'] 225 | # shape: (batch_size, sequence_len, vocab_size) 226 | logits = output_dict['logits'] 227 | 228 | loss = self._xent_loss(logits, targets) 229 | 230 | output_dict['loss'] = loss 231 | 232 | return output_dict 233 | 234 | def forward(self, inputs): 235 | state = self.encode(inputs) 236 | output_dict = self.decode(state, inputs['sentence']) 
237 | 238 | output_dict = self.calc_loss(output_dict, inputs) 239 | 240 | return output_dict 241 | 242 | 243 | class Seq2SeqMeaningStyle(Seq2Seq): 244 | def __init__(self, meaning_size, style_size, nb_styles, *args, **kwargs): 245 | super().__init__(*args, **kwargs) 246 | 247 | self.meaning_size = meaning_size 248 | self.style_size = style_size 249 | self.nb_styles = nb_styles 250 | 251 | self.hidden_meaning = SpaceTransformer(self.hidden_size, self.meaning_size, self.dropout) 252 | self.hidden_style = SpaceTransformer(self.hidden_size, self.meaning_size, self.dropout) 253 | self.meaning_style_hidden = SpaceTransformer(meaning_size + style_size, self.hidden_size, self.dropout) 254 | 255 | # D - discriminator: discriminates the style of a sentence 256 | self.D_meaning = Discriminator(meaning_size, self.hidden_size, nb_styles, self.dropout) 257 | self.D_style = Discriminator(style_size, self.hidden_size, nb_styles, self.dropout) 258 | 259 | # P - predictor: predicts the meaning of a sentence (word embeddings) 260 | self.P_meaning = Discriminator(meaning_size, self.hidden_size, self.embedding_size, self.dropout) 261 | self.P_style = Discriminator(style_size, self.hidden_size, self.embedding_size, self.dropout) 262 | 263 | # P_bow - predictor_bow: predicts the meaning of a sentence (BoW) 264 | self.P_bow_meaning = Discriminator(meaning_size, self.hidden_size, self.vocab_size, self.dropout) 265 | self.P_bow_style = Discriminator(style_size, self.hidden_size, self.vocab_size, self.dropout) 266 | 267 | # Discriminator for gaussian z 268 | self.D_hidden = Discriminator(self.hidden_size, self.hidden_size, 2, self.dropout) 269 | 270 | self._D_loss = torch.nn.CrossEntropyLoss() 271 | self._D_adv_loss = StyleEntropyLoss() 272 | 273 | self._P_loss = torch.nn.MSELoss() 274 | self._P_adv_loss = MeaningZeroLoss() 275 | 276 | self._P_bow_loss = torch.nn.BCEWithLogitsLoss() 277 | self._P_bow_adv_loss = StyleEntropyLoss() 278 | 279 | def encode(self, inputs): 280 | state = 
super().encode(inputs) 281 | 282 | # shape: (batch_size, hidden_size) 283 | decoder_hidden = state['decoder_hidden'] 284 | 285 | # shape: (batch_size, hidden_size) 286 | meaning_hidden = self.hidden_meaning(decoder_hidden) 287 | 288 | # shape: (batch_size, hidden_size) 289 | style_hidden = self.hidden_style(decoder_hidden) 290 | 291 | state['meaning_hidden'] = meaning_hidden 292 | state['style_hidden'] = style_hidden 293 | 294 | return state 295 | 296 | def combine_meaning_style(self, state): 297 | # shape: (batch_size, hidden_size * 2) 298 | decoder_hidden = torch.cat([state['meaning_hidden'], state['style_hidden']], dim=-1) 299 | 300 | # shape: (batch_size, hidden_size) 301 | decoder_hidden = self.meaning_style_hidden(decoder_hidden) 302 | 303 | state['decoder_hidden'] = decoder_hidden 304 | 305 | return state 306 | 307 | def decode(self, state, targets=None): 308 | state = self.combine_meaning_style(state) 309 | 310 | output_dict = super().decode(state, targets) 311 | return output_dict 312 | 313 | def calc_discriminator_loss(self, output_dict, inputs): 314 | output_dict['loss_D_meaning'] = self._D_loss(output_dict['D_meaning_logits'], inputs['style']) 315 | output_dict['loss_D_style'] = self._D_loss(output_dict['D_style_logits'], inputs['style']) 316 | 317 | if 'meaning_embedding' in inputs: 318 | output_dict['loss_P_meaning'] = self._P_loss(output_dict['P_meaning'], inputs['meaning_embedding']) 319 | output_dict['loss_P_style'] = self._P_loss(output_dict['P_style'], inputs['meaning_embedding']) 320 | 321 | if 'meaning_bow' in inputs: 322 | output_dict['loss_P_bow_meaning'] = self._P_bow_loss(output_dict['P_bow_meaning'], inputs['meaning_bow']) 323 | output_dict['loss_P_bow_style'] = self._P_bow_loss(output_dict['P_bow_style'], inputs['meaning_bow']) 324 | 325 | return output_dict 326 | 327 | def calc_discriminator_adv_loss(self, output_dict, inputs): 328 | output_dict['loss_D_adv_meaning'] = self._D_adv_loss(output_dict['D_meaning_logits']) 329 | 
output_dict['loss_D_adv_style'] = self._D_loss(output_dict['D_style_logits'], inputs['style']) 330 | 331 | if 'meaning_embedding' in inputs: 332 | output_dict['loss_P_adv_meaning'] = self._P_loss(output_dict['P_meaning'], inputs['meaning_embedding']) 333 | output_dict['loss_P_adv_style'] = self._P_adv_loss(output_dict['P_style']) 334 | 335 | if 'meaning_bow' in inputs: 336 | output_dict['loss_P_bow_adv_meaning'] = self._P_bow_loss( 337 | output_dict['P_bow_meaning'], inputs['meaning_bow']) 338 | output_dict['loss_P_bow_adv_style'] = self._P_bow_adv_loss(output_dict['P_bow_style']) 339 | 340 | return output_dict 341 | 342 | def discriminate(self, output_dict, inputs, adversarial=False): 343 | output_dict['D_meaning_logits'] = self.D_meaning(output_dict['meaning_hidden']) 344 | output_dict['D_style_logits'] = self.D_style(output_dict['style_hidden']) 345 | 346 | if 'meaning_embedding' in inputs: 347 | output_dict['P_meaning'] = self.P_meaning(output_dict['meaning_hidden']) 348 | output_dict['P_style'] = self.P_style(output_dict['style_hidden']) 349 | 350 | if 'meaning_bow' in inputs: 351 | output_dict['P_bow_meaning'] = self.P_bow_meaning(output_dict['meaning_hidden']) 352 | output_dict['P_bow_style'] = self.P_bow_style(output_dict['style_hidden']) 353 | 354 | # calc loss 355 | if not adversarial: 356 | output_dict = self.calc_discriminator_loss(output_dict, inputs) 357 | else: 358 | output_dict = self.calc_discriminator_adv_loss(output_dict, inputs) 359 | 360 | return output_dict 361 | 362 | 363 | class StyleClassifier(torch.nn.Module): 364 | def __init__(self, vocab_size, embedding_size, hidden_size, dropout, trainable_embeddings, pad_index, nb_styles, 365 | W_emb=None, **kwargs): 366 | super().__init__() 367 | 368 | self.vocab_size = vocab_size 369 | self.embedding_size = embedding_size 370 | self.hidden_size = hidden_size 371 | self.dropout = dropout 372 | self.trainable_embeddings = trainable_embeddings 373 | self.nb_styles = nb_styles 374 | 375 | self.embedding 
= torch.nn.Embedding(vocab_size, embedding_size, padding_idx=pad_index) 376 | if W_emb is not None: 377 | self.embedding.weight.data.copy_(torch.from_numpy(W_emb)) 378 | if not trainable_embeddings: 379 | self.embedding.weight.requires_grad = False 380 | 381 | self.encoder = LSTMEncoder(embedding_size, hidden_size, dropout) 382 | self.classifier = torch.nn.Sequential( 383 | torch.nn.Linear(hidden_size, hidden_size), 384 | torch.nn.Dropout(dropout), 385 | torch.nn.ELU(), 386 | torch.nn.Linear(hidden_size, nb_styles), 387 | ) 388 | 389 | self._xent_loss = torch.nn.CrossEntropyLoss() 390 | 391 | def encode(self, inputs): 392 | # shape: (batch_size, sequence_len) 393 | sentence = inputs['sentence'] 394 | 395 | # shape: (batch_size, ) 396 | lengths = get_sequences_lengths(sentence) 397 | 398 | # shape: (batch_size, sequence_len, embedding_size) 399 | sentence_emb = self.embedding(sentence) 400 | 401 | # shape: (batch_size, hidden_size) 402 | decoder_hidden = self.encoder(sentence_emb, lengths) 403 | 404 | output_dict = { 405 | 'decoder_hidden': decoder_hidden 406 | } 407 | 408 | return output_dict 409 | 410 | def classify(self, state): 411 | # shape: (batch_size, hidden_size) 412 | hidden = state['decoder_hidden'] 413 | 414 | # shape: (batch_size, nb_classes) 415 | logits = self.classifier(hidden) 416 | predictions = torch.argmax(logits, 1) 417 | 418 | state.update({ 419 | "logits": logits, 420 | "predictions": predictions, 421 | }) 422 | 423 | return state 424 | 425 | def calc_loss(self, output_dict, inputs): 426 | # shape: (batch_size, sequence_len) 427 | targets = inputs['style'] 428 | # shape: (batch_size, sequence_len, vocab_size) 429 | logits = output_dict['logits'] 430 | 431 | loss = self._xent_loss(logits, targets) 432 | 433 | output_dict['loss'] = loss 434 | 435 | return output_dict 436 | 437 | def forward(self, inputs): 438 | state = self.encode(inputs) 439 | output_dict = self.classify(state) 440 | 441 | output_dict = self.calc_loss(output_dict, inputs) 442 | 
443 | return output_dict 444 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | from sklearn.model_selection import train_test_split 4 | 5 | from config import PreprocessConfig 6 | from datasets import MeaningEmbeddingSentenceStyleDataset 7 | from experiment import Experiment 8 | from settings import EXPERIMENTS_DIR 9 | from utils import save_pickle, load_pickle, load_embeddings, create_embeddings_matrix, extract_word_embeddings_style_dimensions 10 | from vocab import Vocab 11 | 12 | 13 | def save_dataset(exp, dataset_train, dataset_val, dataset_test, vocab, style_vocab, W_emb): 14 | save_pickle((dataset_train, dataset_val, dataset_test), exp.experiment_dir.joinpath('datasets.pkl')) 15 | save_pickle((vocab, style_vocab), exp.experiment_dir.joinpath('vocabs.pkl')) 16 | save_pickle(W_emb, exp.experiment_dir.joinpath('W_emb.pkl')) 17 | 18 | print(f'Saved: {exp.experiment_dir}') 19 | 20 | 21 | def load_dataset(exp): 22 | dataset_train, dataset_val, dataset_test = load_pickle(exp.experiment_dir.joinpath('datasets.pkl')) 23 | vocab, style_vocab = load_pickle(exp.experiment_dir.joinpath('vocabs.pkl')) 24 | W_emb = load_pickle(exp.experiment_dir.joinpath('W_emb.pkl')) 25 | 26 | print(f'Dataset: {len(dataset_train)}, val: {len(dataset_val)}, test: {len(dataset_test)}') 27 | print(f'Vocab: {len(vocab)}, style vocab: {len(style_vocab)}') 28 | print(f'W_emb: {W_emb.shape}') 29 | 30 | return dataset_train, dataset_val, dataset_test, vocab, style_vocab, W_emb 31 | 32 | 33 | def create_dataset_reader(cfg): 34 | dataset_reader_class = cfg.dataset_reader_class 35 | 36 | dataset_reader_params = dataclasses.asdict(cfg) 37 | dataset_reader = dataset_reader_class(**dataset_reader_params) 38 | 39 | return dataset_reader 40 | 41 | 42 | def create_vocab(instances): 43 | vocab = Vocab([Vocab.PAD_TOKEN, Vocab.START_TOKEN, 
Vocab.END_TOKEN, Vocab.UNK_TOKEN, ]) 44 | vocab.add_documents([inst['sentence'] for inst in instances]) 45 | 46 | style_vocab = Vocab() 47 | style_vocab.add_document([inst['style'] for inst in instances]) 48 | 49 | return vocab, style_vocab 50 | 51 | 52 | def create_splits(cfg, instances): 53 | if cfg.test_size != 0: 54 | instances_train_val, instances_test = train_test_split(instances, test_size=cfg.test_size, random_state=42) 55 | else: 56 | instances_test = [] 57 | instances_train_val = instances 58 | 59 | if cfg.val_size != 0: 60 | instances_train, instances_val = train_test_split(instances_train_val, test_size=cfg.val_size, random_state=0) 61 | else: 62 | instances_train = [] 63 | instances_val = [] 64 | 65 | return instances_train, instances_val, instances_test 66 | 67 | 68 | def main(cfg): 69 | with Experiment(EXPERIMENTS_DIR, cfg, prefix='preprocess') as exp: 70 | print(f'Experiment started: {exp.experiment_id}') 71 | 72 | # read instances 73 | dataset_reader = create_dataset_reader(exp.config) 74 | print(f'Dataset reader: {dataset_reader.__class__.__name__}') 75 | 76 | instances = dataset_reader.read(exp.config.data_path) 77 | print(f'Instances: {len(instances)}') 78 | 79 | # create vocabularies 80 | vocab, style_vocab = create_vocab(instances) 81 | print(f'Vocab: {len(vocab)}, style vocab: {style_vocab}') 82 | 83 | if exp.config.max_vocab_size != 0: 84 | vocab.prune_vocab(exp.config.max_vocab_size) 85 | 86 | # create splits 87 | instances_train, instances_val, instances_test = create_splits(exp.config, instances) 88 | print(f'Train: {len(instances_train)}, val: {len(instances_val)}, test: {len(instances_test)}') 89 | 90 | # create embeddings 91 | word_embeddings = load_embeddings(cfg) 92 | W_emb = create_embeddings_matrix(word_embeddings, vocab) 93 | 94 | # extract style dimensions 95 | style_dimensions = extract_word_embeddings_style_dimensions(cfg, instances_train, vocab, style_vocab, W_emb) 96 | 97 | # create datasets 98 | dataset_train = 
MeaningEmbeddingSentenceStyleDataset( 99 | W_emb, style_dimensions, exp.config.style_tokens_proportion, 100 | instances_train, vocab, style_vocab 101 | ) 102 | dataset_val = MeaningEmbeddingSentenceStyleDataset( 103 | W_emb, style_dimensions, exp.config.style_tokens_proportion, 104 | instances_val, vocab, style_vocab 105 | ) 106 | dataset_test = MeaningEmbeddingSentenceStyleDataset( 107 | W_emb, style_dimensions, exp.config.style_tokens_proportion, 108 | instances_test, vocab, style_vocab 109 | ) 110 | 111 | save_dataset(exp, dataset_train, dataset_val, dataset_test, vocab, style_vocab, W_emb) 112 | 113 | print(f'Experiment finished: {exp.experiment_id}') 114 | 115 | 116 | if __name__ == '__main__': 117 | main(PreprocessConfig()) 118 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | astroid==2.2.5 2 | attrs==19.1.0 3 | autopep8==1.4.4 4 | backcall==0.1.0 5 | bleach==3.1.0 6 | blis==0.2.4 7 | certifi==2019.3.9 8 | chardet==3.0.4 9 | cymem==2.0.2 10 | decorator==4.4.0 11 | defusedxml==0.6.0 12 | en-core-web-lg==2.1.0 13 | entrypoints==0.3 14 | idna==2.8 15 | ipykernel==5.1.1 16 | ipython==7.5.0 17 | ipython-genutils==0.2.0 18 | ipywidgets==7.4.2 19 | isort==4.3.20 20 | jedi==0.13.3 21 | Jinja2==2.10.1 22 | joblib==0.13.2 23 | jsonschema==3.0.1 24 | jupyter==1.0.0 25 | jupyter-client==5.2.4 26 | jupyter-console==6.0.0 27 | jupyter-core==4.4.0 28 | lazy-object-proxy==1.4.1 29 | MarkupSafe==1.1.1 30 | mccabe==0.6.1 31 | mistune==0.8.4 32 | murmurhash==1.0.2 33 | nbconvert==5.5.0 34 | nbformat==4.4.0 35 | nltk==3.4.3 36 | notebook==5.7.8 37 | numpy==1.16.3 38 | pandocfilters==1.4.2 39 | parso==0.4.0 40 | pexpect==4.7.0 41 | pickleshare==0.7.5 42 | plac==0.9.6 43 | preshed==2.0.1 44 | prometheus-client==0.7.0 45 | prompt-toolkit==2.0.9 46 | protobuf==3.8.0 47 | ptyprocess==0.6.0 48 | pycodestyle==2.5.0 49 | Pygments==2.4.2 50 | 
pylint==2.3.1 51 | pyrsistent==0.15.2 52 | python-dateutil==2.8.0 53 | pytorch-ignite==0.2.0 54 | pyzmq==18.0.1 55 | qtconsole==4.5.1 56 | requests==2.22.0 57 | scikit-learn==0.21.2 58 | scipy==1.3.0 59 | Send2Trash==1.5.0 60 | six==1.12.0 61 | spacy==2.1.4 62 | srsly==0.0.6 63 | tensorboardX==1.7 64 | terminado==0.8.2 65 | testpath==0.4.2 66 | thinc==7.0.4 67 | tornado==6.0.2 68 | tqdm==4.32.1 69 | traitlets==4.3.2 70 | typed-ast==1.4.0 71 | urllib3==1.25.3 72 | wasabi==0.2.2 73 | wcwidth==0.1.7 74 | webencodings==0.5.1 75 | widgetsnbextension==3.4.2 76 | wrapt==1.11.1 77 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | DATA_DIR = Path('./data/') 4 | EXPERIMENTS_DIR = DATA_DIR.joinpath('experiments/') 5 | 6 | SHAKESPEARE_DATASET_DIR = DATA_DIR.joinpath('datasets/shakespeare/data/align/plays/merged/') 7 | YELP_DATASET_DIR = DATA_DIR.joinpath('datasets/yelp/data/yelp') 8 | 9 | WORD_EMBEDDINGS_FILENAMES = dict( 10 | glove=DATA_DIR.joinpath('word_embeddings/glove.840B.300d.pickled'), 11 | fast_text=DATA_DIR.joinpath('word_embeddings/crawl-300d-2M.pickled'), 12 | ) 13 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import itertools 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch.utils.data 7 | from ignite.engine import Engine, Events 8 | from ignite.metrics import Metric 9 | from tensorboardX import SummaryWriter 10 | 11 | from config import TrainConfig 12 | from experiment import Experiment 13 | from models import Seq2Seq, Seq2SeqMeaningStyle, StyleClassifier 14 | from preprocess import load_dataset 15 | from settings import EXPERIMENTS_DIR 16 | from update_functions import Seq2SeqUpdateFunction, 
StyleClassifierUpdateFunction, Seq2SeqMeaningStyleUpdateFunction 17 | from utils import to_device, save_weights, init_weights 18 | from vocab import Vocab 19 | 20 | 21 | 22 | def create_model(cfg, vocab, style_vocab, max_len, W_emb=None): 23 | model_class = cfg.model_class 24 | model_params = dataclasses.asdict(cfg) 25 | model_params.update(dict( 26 | max_len=max_len, 27 | vocab_size=len(vocab), 28 | start_index=vocab[Vocab.START_TOKEN], 29 | end_index=vocab[Vocab.END_TOKEN], 30 | pad_index=vocab[Vocab.PAD_TOKEN], 31 | nb_styles=len(style_vocab), 32 | )) 33 | 34 | if cfg.pretrained_embeddings: 35 | model_params.update(dict( 36 | W_emb=W_emb, 37 | )) 38 | 39 | model = model_class(**model_params) 40 | 41 | init_weights(model) 42 | 43 | model = to_device(model) 44 | 45 | return model 46 | 47 | 48 | def create_update_function(cfg, model): 49 | update_function_class = None 50 | if isinstance(model, Seq2Seq): 51 | update_function_class = Seq2SeqUpdateFunction 52 | if isinstance(model, Seq2SeqMeaningStyle): 53 | update_function_class = Seq2SeqMeaningStyleUpdateFunction 54 | if isinstance(model, StyleClassifier): 55 | update_function_class = StyleClassifierUpdateFunction 56 | 57 | update_function_params = dataclasses.asdict(cfg) 58 | update_function_params.update(dict( 59 | model=model, 60 | )) 61 | 62 | update_function_train = update_function_class(train=True, **update_function_params) 63 | update_function_eval = update_function_class(train=False, **update_function_params) 64 | 65 | return update_function_train, update_function_eval 66 | 67 | 68 | class LossAggregatorMetric(Metric): 69 | def __init__(self, *args, **kwargs): 70 | self.total_losses = defaultdict(float) 71 | self.num_updates = defaultdict(int) 72 | super().__init__(*args, **kwargs) 73 | 74 | def reset(self): 75 | self.total_losses = defaultdict(float) 76 | self.num_updates = defaultdict(int) 77 | 78 | def update(self, output): 79 | for name, val in output.items(): 80 | self.total_losses[name] += float(val) 
81 | self.num_updates[name] += 1 82 | 83 | def compute(self): 84 | losses = {name: val / self.num_updates[name] for name, val in self.total_losses.items()} 85 | 86 | return losses 87 | 88 | 89 | def log_progress(epoch, iteration, losses, mode='train', tensorboard_writer=None, use_iteration=False): 90 | if not use_iteration: 91 | losses_str = [ 92 | f'{name}: {val:.3f}' 93 | for name, val in losses.items() 94 | ] 95 | losses_str = ' | '.join(losses_str) 96 | 97 | epoch_str = f'Epoch [{epoch}|{iteration}] {mode}' 98 | 99 | print(f'{epoch_str:<25}{losses_str}') 100 | 101 | for name, val in losses.items(): 102 | tensorboard_writer.add_scalar(f'{mode}/{name}', val, epoch if not use_iteration else iteration) 103 | 104 | 105 | def main(cfg): 106 | with Experiment(EXPERIMENTS_DIR, cfg, prefix='train') as exp: 107 | print(f'Experiment started: {exp.experiment_id}') 108 | 109 | preprocess_exp = Experiment.load(EXPERIMENTS_DIR, exp.config.preprocess_exp_id) 110 | dataset_train, dataset_val, dataset_test, vocab, style_vocab, W_emb = load_dataset(preprocess_exp) 111 | 112 | data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=exp.config.batch_size, shuffle=True) 113 | data_loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=exp.config.batch_size, shuffle=False) 114 | print(f'Data loader: {len(data_loader_train)}, {len(data_loader_val)}') 115 | 116 | model = create_model(exp.config, vocab, style_vocab, dataset_train.max_len, W_emb) 117 | 118 | update_function_train, update_function_eval = create_update_function(exp.config, model) 119 | 120 | trainer = Engine(update_function_train) 121 | evaluator = Engine(update_function_eval) 122 | 123 | metrics = {'loss': LossAggregatorMetric(), } 124 | for metric_name, metric in metrics.items(): 125 | metric.attach(evaluator, metric_name) 126 | 127 | best_loss = np.inf 128 | 129 | @trainer.on(Events.ITERATION_COMPLETED) 130 | def log_training_iter(engine): 131 | losses_train = engine.state.output 132 | 
log_progress(trainer.state.epoch, trainer.state.iteration, losses_train, 'train', tensorboard_writer, True) 133 | 134 | @trainer.on(Events.EPOCH_COMPLETED) 135 | def log_training_results(engine): 136 | nonlocal best_loss 137 | 138 | # evaluator.run(data_loader_train) 139 | # losses_train = evaluator.state.metrics['loss'] 140 | 141 | evaluator.run(data_loader_val) 142 | losses_val = evaluator.state.metrics['loss'] 143 | 144 | # log_progress(trainer.state.epoch, trainer.state.iteration, losses_train, 'train', tensorboard_writer) 145 | log_progress(trainer.state.epoch, trainer.state.iteration, losses_val, 'val', tensorboard_writer) 146 | 147 | if losses_val[exp.config.best_loss] < best_loss: 148 | best_loss = losses_val[exp.config.best_loss] 149 | save_weights(model, exp.experiment_dir.joinpath('best.th')) 150 | 151 | tensorboard_dir = exp.experiment_dir.joinpath('log') 152 | tensorboard_writer = SummaryWriter(str(tensorboard_dir)) 153 | 154 | trainer.run(data_loader_train, max_epochs=exp.config.num_epochs) 155 | 156 | print(f'Experiment finished: {exp.experiment_id}') 157 | 158 | 159 | if __name__ == '__main__': 160 | main(TrainConfig()) 161 | -------------------------------------------------------------------------------- /update_functions.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import itertools 3 | 4 | import torch 5 | 6 | from utils import to_device 7 | 8 | 9 | class UpdateFunction(object): 10 | def __init__(self, model, lr=0.001, weight_decay=0.0000001, grad_clipping=0, training=True, **kwargs): 11 | self.model = model 12 | self.grad_clipping = grad_clipping 13 | self.lr = lr 14 | self.weight_decay = weight_decay 15 | self.training = training 16 | 17 | def get_parameters(self, *modules): 18 | parameters = itertools.chain.from_iterable(m.parameters() for m in modules) 19 | return parameters 20 | 21 | @abc.abstractmethod 22 | def step(self, engine, batch): 23 | pass 24 | 25 | def __call__(self, engine, 
batch): 26 | if self.training: 27 | self.model.train() 28 | else: 29 | self.model.eval() 30 | 31 | batch = to_device(batch) 32 | 33 | losses = self.step(engine, batch) 34 | 35 | return losses 36 | 37 | 38 | class Seq2SeqUpdateFunction(UpdateFunction): 39 | def __init__(self, *args, **kwargs): 40 | super().__init__(*args, **kwargs) 41 | 42 | self.params = self.get_parameters(self.model) 43 | self.optimizer = torch.optim.Adam(self.params, lr=self.lr, weight_decay=self.weight_decay, amsgrad=True) 44 | 45 | def step(self, engine, batch): 46 | self.optimizer.zero_grad() 47 | 48 | output_dict = self.model(batch) 49 | 50 | loss = output_dict["loss"] 51 | 52 | if self.training: 53 | loss.backward() 54 | 55 | if self.grad_clipping != 0: 56 | torch.nn.utils.clip_grad_norm_(self.params, self.grad_clipping) 57 | 58 | self.optimizer.step() 59 | 60 | output_dict = { 61 | 'loss': loss, 62 | } 63 | return output_dict 64 | 65 | 66 | class Seq2SeqMeaningStyleUpdateFunction(UpdateFunction): 67 | def __init__(self, D_num_iterations, D_loss_multiplier, P_loss_multiplier, P_bow_loss_multiplier, 68 | use_discriminator, use_predictor, use_predictor_bow, use_motivator, use_gauss, 69 | *args, **kwargs): 70 | super().__init__(*args, **kwargs) 71 | 72 | self.D_num_iterations = D_num_iterations 73 | self.D_loss_multiplier = D_loss_multiplier 74 | self.P_loss_multiplier = P_loss_multiplier 75 | self.P_bow_loss_multiplier = P_bow_loss_multiplier 76 | self.use_discriminator = use_discriminator 77 | self.use_predictor = use_predictor 78 | self.use_predictor_bow = use_predictor_bow 79 | self.use_motivator = use_motivator 80 | self.use_gauss = use_gauss 81 | 82 | self.params_encoder_decoder = self.get_parameters( 83 | # encoder 84 | self.model.embedding, self.model.encoder, 85 | self.model.hidden_meaning, self.model.hidden_style, 86 | 87 | # decoder 88 | self.model.decoder_cell, self.model.output_projection, 89 | self.model.meaning_style_hidden 90 | ) 91 | 92 | params_D = [] 93 | if 
use_discriminator: 94 | params_D.append(self.model.D_meaning) 95 | if use_motivator: 96 | params_D.append(self.model.D_style) 97 | 98 | if use_predictor: 99 | params_D.append(self.model.P_style) 100 | if use_motivator: 101 | params_D.append(self.model.P_meaning) 102 | 103 | if use_predictor_bow: 104 | params_D.append(self.model.P_bow_style) 105 | if use_motivator: 106 | params_D.append(self.model.P_bow_meaning) 107 | 108 | if use_gauss: 109 | params_D.append(self.model.D_hidden) 110 | 111 | self.params_D = self.get_parameters(*params_D) 112 | 113 | self.optimizer_encoder_decoder = torch.optim.Adam( 114 | self.params_encoder_decoder, lr=self.lr, weight_decay=self.weight_decay, amsgrad=True 115 | ) 116 | self.optimizer_D = torch.optim.Adam( 117 | self.params_D, lr=self.lr, weight_decay=self.weight_decay, amsgrad=True 118 | ) 119 | 120 | def step(self, engine, batch): 121 | losses_output_dict = {} 122 | 123 | # discriminator 124 | state = self.model.encode(batch) 125 | output_dict_D = self.model.discriminate(state, batch) 126 | 127 | if self.use_discriminator: 128 | loss_D_meaning = output_dict_D['loss_D_meaning'] 129 | losses_output_dict['loss_D_meaning'] = float(loss_D_meaning.item()) 130 | 131 | if self.use_motivator: 132 | loss_D_style = output_dict_D['loss_D_style'] 133 | losses_output_dict['loss_D_style'] = float(loss_D_style.item()) 134 | else: 135 | loss_D_style = 0 136 | else: 137 | loss_D_meaning = 0 138 | loss_D_style = 0 139 | 140 | if 'meaning_embedding' in batch and self.use_predictor: 141 | loss_P_style = output_dict_D['loss_P_style'] 142 | losses_output_dict['loss_P_style'] = float(loss_P_style.item()) 143 | 144 | if self.use_motivator: 145 | loss_P_meaning = output_dict_D['loss_P_meaning'] 146 | losses_output_dict['loss_P_meaning'] = float(loss_P_meaning.item()) 147 | else: 148 | loss_P_meaning = 0 149 | else: 150 | loss_P_style = 0 151 | loss_P_meaning = 0 152 | 153 | if 'meaning_bow' in batch and self.use_predictor_bow: 154 | loss_P_bow_style = 
output_dict_D['loss_P_bow_style'] 155 | losses_output_dict['loss_P_bow_style'] = float(loss_P_bow_style.item()) 156 | 157 | if self.use_motivator: 158 | loss_P_bow_meaning = output_dict_D['loss_P_bow_meaning'] 159 | losses_output_dict['loss_P_bow_meaning'] = float(loss_P_bow_meaning.item()) 160 | else: 161 | loss_P_bow_meaning = 0 162 | else: 163 | loss_P_bow_style = 0 164 | loss_P_bow_meaning = 0 165 | 166 | if self.use_gauss: 167 | output_dict_D = self.model.combine_meaning_style(output_dict_D) 168 | G_real = torch.randn_like(output_dict_D['decoder_hidden']) 169 | G_fake = output_dict_D['decoder_hidden'] 170 | G_labels = torch.cat([torch.ones_like(batch['style']), torch.zeros_like(batch['style'])], dim=0) 171 | G_inputs = torch.cat([G_real, G_fake], dim=0) 172 | 173 | G_logits = self.model.D_hidden(G_inputs) 174 | loss_D_hidden = self.model._D_loss(G_logits, G_labels) 175 | losses_output_dict['loss_D_hidden'] = float(loss_D_hidden.item()) 176 | else: 177 | loss_D_hidden = 0 178 | 179 | loss_D_total = loss_D_meaning + loss_D_style \ 180 | + loss_P_meaning + loss_P_style \ 181 | + loss_P_bow_meaning + loss_P_bow_style \ 182 | + loss_D_hidden 183 | 184 | if self.training: 185 | loss_D_total.backward() 186 | 187 | if self.grad_clipping != 0: 188 | torch.nn.utils.clip_grad_norm_(self.params_D, self.grad_clipping) 189 | 190 | self.optimizer_D.step() 191 | self.model.zero_grad() 192 | 193 | # encoder-decoder 194 | if not self.training or engine.state.iteration % self.D_num_iterations == 0: 195 | output_dict = self.model(batch) 196 | output_dict = self.model.discriminate(output_dict, batch, adversarial=True) 197 | 198 | loss = output_dict['loss'] 199 | losses_output_dict['loss'] = float(loss.item()) 200 | 201 | loss_D_adv_meaning = output_dict['loss_D_adv_meaning'] 202 | losses_output_dict['loss_D_adv_meaning'] = float(loss_D_adv_meaning.item()) 203 | if self.use_motivator: 204 | loss_D_adv_style = output_dict['loss_D_adv_style'] 205 | 
losses_output_dict['loss_D_adv_style'] = float(loss_D_adv_style.item()) 206 | else: 207 | loss_D_adv_style = 0 208 | 209 | if 'meaning_embedding' in batch and self.use_predictor: 210 | loss_P_adv_style = output_dict['loss_P_adv_style'] 211 | losses_output_dict['loss_P_adv_style'] = float(loss_P_adv_style.item()) 212 | 213 | if self.use_motivator: 214 | loss_P_adv_meaning = output_dict['loss_P_adv_meaning'] 215 | losses_output_dict['loss_P_adv_meaning'] = float(loss_P_adv_meaning.item()) 216 | else: 217 | loss_P_adv_meaning = 0 218 | else: 219 | loss_P_adv_style = 0 220 | loss_P_adv_meaning = 0 221 | 222 | if 'meaning_bow' in batch and self.use_predictor_bow: 223 | loss_P_bow_adv_style = output_dict['loss_P_bow_adv_style'] 224 | losses_output_dict['loss_P_bow_adv_style'] = float(loss_P_bow_adv_style.item()) 225 | 226 | if self.use_motivator: 227 | loss_P_bow_adv_meaning = output_dict['loss_P_bow_adv_meaning'] 228 | losses_output_dict['loss_P_bow_adv_meaning'] = float(loss_P_bow_adv_meaning.item()) 229 | else: 230 | loss_P_bow_adv_meaning = 0 231 | else: 232 | loss_P_bow_adv_style = 0 233 | loss_P_bow_adv_meaning = 0 234 | 235 | if self.use_gauss: 236 | G_logits = self.model.D_hidden(output_dict['decoder_hidden']) 237 | loss_D_adv_hidden = self.model._D_adv_loss(G_logits) 238 | losses_output_dict['loss_D_adv_hidden'] = float(loss_D_adv_hidden.item()) 239 | else: 240 | loss_D_adv_hidden = 0 241 | 242 | loss_total = loss 243 | if loss_D_meaning <= 0.35: 244 | loss_total += self.D_loss_multiplier * loss_D_adv_meaning 245 | if loss_D_style <= 0.35: 246 | loss_total += self.D_loss_multiplier * loss_D_adv_style 247 | 248 | if loss_P_style < 2.5e-3: 249 | loss_total += self.P_loss_multiplier * loss_P_adv_style 250 | if loss_P_meaning < 2.5e-3: 251 | loss_total += self.P_loss_multiplier * loss_P_adv_meaning 252 | 253 | if loss_P_bow_meaning <= 0.35: 254 | loss_total += self.P_bow_loss_multiplier * loss_P_bow_adv_meaning 255 | if loss_P_bow_style <= 0.35: 256 | loss_total += 
def save_json(obj, filename):
    """Serialize ``obj`` as JSON into ``filename``.

    The file is opened in text mode: ``json.dump`` writes ``str`` chunks, which
    a binary-mode handle rejects with ``TypeError``.

    Args:
        obj: JSON-serializable object.
        filename: Path of the file to (over)write.
    """
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(obj, f)
def init_weights(modules):
    """Re-initialise Linear and Conv1d/Conv2d layers found in ``modules``.

    Linear layers get their parameters reset, then Xavier-normal weights and a
    bias drawn from a normal distribution (mean 0, std 0.01). Conv1d/Conv2d
    layers get Xavier-normal weights and a zeroed bias. Layers created without
    a bias are left without one.

    When a ``torch.nn.Module`` is passed, everything yielded by its
    ``modules()`` is visited; the direct children of ``Sequential`` and
    ``ModuleList`` containers are additionally visited through a recursive
    call, so such layers may be initialised more than once.

    Args:
        modules: A ``torch.nn.Module`` or an iterable of modules.
    """
    if isinstance(modules, torch.nn.Module):
        modules = modules.modules()

    for m in modules:
        if isinstance(m, torch.nn.Sequential):
            init_weights(list(m))

        if isinstance(m, torch.nn.ModuleList):
            init_weights(list(m))

        if isinstance(m, torch.nn.Linear):
            m.reset_parameters()
            torch.nn.init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                m.bias.data.normal_(0, 0.01)

        if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
            torch.nn.init.xavier_normal_(m.weight.data)
            # bias is None when the layer was built with bias=False
            if m.bias is not None:
                m.bias.data.zero_()
def create_embeddings_matrix(word_embeddings, vocab):
    """Build a ``(len(vocab), embedding_size)`` float32 matrix for ``vocab``.

    Row ``i`` holds the vector for ``vocab.id2token[i]``:

    * the PAD token gets zeros; START, END and UNK get one uniform random
      vector each in [-0.3, 0.3);
    * tokens present in ``word_embeddings`` get their stored vector;
    * every other token gets a fresh uniform random vector in [-0.3, 0.3) and
      is counted as unknown.

    The number of unknown tokens is printed.

    Args:
        word_embeddings: Mapping from token to a vector exposing ``shape``.
        vocab: Object exposing ``__len__`` and an ``id2token`` mapping.

    Returns:
        numpy.ndarray: The embedding matrix.
    """
    # The embedding width is read from the first stored vector.
    first_token = list(word_embeddings.keys())[0]
    dim = word_embeddings[first_token].shape[0]

    # Random vectors are drawn in START, END, UNK order.
    special_vectors = {}
    for special in (Vocab.START_TOKEN, Vocab.END_TOKEN, Vocab.UNK_TOKEN):
        special_vectors[special] = np.random.uniform(-0.3, 0.3, (dim,))
    special_vectors[Vocab.PAD_TOKEN] = np.zeros((dim,))

    matrix = np.zeros((len(vocab), dim), dtype=np.float32)
    missing = 0
    for idx, token in vocab.id2token.items():
        if token in special_vectors:
            row = special_vectors[token]
        elif token in word_embeddings:
            row = word_embeddings[token]
        else:
            row = np.random.uniform(-0.3, 0.3, dim)
            missing += 1
        matrix[idx] = row

    print(f'Nb unk: {missing}')

    return matrix
class Vocab(object):
    """Token <-> id mapping with per-token occurrence counts.

    Ids are assigned in first-seen order, starting at 0. Special tokens passed
    to the constructor are added first and are never removed by pruning.

    Args:
        special_tokens: Optional iterable of tokens to register up front.
    """

    # NOTE(review): all four constants are empty strings in this copy of the
    # file, which makes them indistinguishable from each other — possibly
    # angle-bracket token names lost during extraction; confirm against the
    # repository before relying on them.
    END_TOKEN = ''
    START_TOKEN = ''
    PAD_TOKEN = ''
    UNK_TOKEN = ''

    def __init__(self, special_tokens=None):
        super().__init__()

        self.special_tokens = special_tokens

        self.token2id = {}
        self.id2token = {}

        self.token_counts = Counter()

        if self.special_tokens is not None:
            self.add_document(self.special_tokens)

    def add_document(self, document, rebuild=True):
        """Count every token of ``document`` and give new tokens the next id.

        Args:
            document: Iterable of tokens.
            rebuild: Whether to refresh ``id2token`` afterwards; pass ``False``
                when adding many documents and rebuild once at the end.
        """
        for token in document:
            self.token_counts[token] += 1

            if token not in self.token2id:
                self.token2id[token] = len(self.token2id)

        if rebuild:
            self._rebuild_id2token()

    def add_documents(self, documents):
        """Add each document of ``documents``, rebuilding ``id2token`` once."""
        for doc in documents:
            self.add_document(doc, rebuild=False)

        self._rebuild_id2token()

    def prune_vocab(self, max_size):
        """Keep only the ``max_size`` most frequent tokens plus special tokens.

        Removed tokens are dropped from both ``token_counts`` and ``token2id``.
        The surviving tokens are renumbered contiguously: special tokens first
        (in their given order), then the rest in their previous id order.
        Counts of surviving tokens are left untouched. The size before and
        after pruning is printed.

        Args:
            max_size: Number of most frequent tokens to keep (special tokens
                are kept in addition if they are not among them).
        """
        nb_tokens_before = len(self.token2id)

        # A vocabulary built without special tokens stores None here.
        special_tokens = list(self.special_tokens) if self.special_tokens is not None else []

        tokens_to_keep = {t for t, _ in self.token_counts.most_common(max_size)}
        tokens_to_keep.update(special_tokens)

        for token in self.token2id:
            if token not in tokens_to_keep:
                self.token_counts.pop(token, None)

        # Renumber from scratch so the ids stay contiguous after removal.
        previous_order = list(self.token2id)
        self.token2id = {}
        for token in special_tokens:
            if token not in self.token2id:
                self.token2id[token] = len(self.token2id)
        for token in previous_order:
            if token in tokens_to_keep and token not in self.token2id:
                self.token2id[token] = len(self.token2id)

        self._rebuild_id2token()

        nb_tokens_after = len(self.token2id)

        print(f'Vocab pruned: {nb_tokens_before} -> {nb_tokens_after}')

    def _rebuild_id2token(self):
        """Recompute the inverse mapping ``id2token`` from ``token2id``."""
        self.id2token = {i: t for t, i in self.token2id.items()}

    def get(self, item, default=None):
        """Return the id of ``item`` or ``default`` if it is not known."""
        return self.token2id.get(item, default)

    def __getitem__(self, item):
        return self.token2id[item]

    def __contains__(self, item):
        return item in self.token2id

    def __len__(self):
        return len(self.token2id)

    def __str__(self):
        return f'Vocab: {len(self)} tokens'