├── util
│   ├── __init__.py
│   ├── save_tool.py
│   ├── dataset_util.py
│   ├── mnli.py
│   └── data_loader.py
├── saved_model
│   └── trained_model_will_be_saved_in_here.txt
├── setup.sh
├── config.py
├── .gitignore
├── README.md
├── model
│   └── res_encoder.py
└── torch_util.py

/util/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/saved_model/trained_model_will_be_saved_in_here.txt:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | 
 3 | # Add the repo root to PYTHONPATH so `util`, `model`, etc. are importable
 4 | export DIR_TMP="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
 5 | export PYTHONPATH=$DIR_TMP
 6 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | ROOT_DIR = os.path.dirname(os.path.realpath(__file__))
 4 | 
 5 | DATA_ROOT = os.path.join(ROOT_DIR, 'data')
 6 | EMBD_FILE = os.path.join(ROOT_DIR, 'data/saved_embd.pt')
 7 | 
 8 | 
 9 | if __name__ == '__main__':
10 |     print(EMBD_FILE)
11 | 
12 | # /home/easonnie/projects/publiced_code/ResEncoder/saved_model/12-04-23:22:31_[600,600,600]-3stack-bilstm-maxout-residual-1-relu-seed(12)-dr(0.1)-mlpd(800)
--------------------------------------------------------------------------------
/util/save_tool.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import config
 3 | from datetime import datetime
 4 | 
 5 | 
 6 | def gen_prefix(name, date):
 7 |     file_path = os.path.join(config.ROOT_DIR, 'saved_model', '_'.join((date, name)))
 8 |     return file_path
 9 | 
10 | 
11 | def logging2file(file_path, log_type, info, file_name=None):
12 |     if not os.path.exists(file_path):
13 |         os.mkdir(file_path)
14 |     if log_type == 'message':
15 |         with open(os.path.join(file_path, 'message.txt'), 'a+') as f:
16 |             f.write(info)
17 |             f.flush()
18 |     elif log_type == 'log':
19 |         with open(os.path.join(file_path, 'log.txt'), 'a+') as f:
20 |             f.write(info)
21 |             f.flush()
22 |     elif log_type == 'log_snli':
23 |         with open(os.path.join(file_path, 'log_snli.txt'), 'a+') as f:
24 |             f.write(info)
25 |             f.flush()
26 |     elif log_type == 'code':
27 |         # copies the source file itself into the run directory; 'info' is unused here
28 |         with open(os.path.join(file_path, 'code.pys'), 'a+') as f, open(file_name) as it:
29 |             f.write(it.read())
30 |             f.flush()
31 | 
32 | if __name__ == '__main__':
33 |     date_now = datetime.now().strftime("%m-%d-%H:%M:%S")
34 |     log_file_path = gen_prefix('conv_model', date_now)
35 | 
36 |     logging2file(log_file_path, 'message', 'something.')
37 |     logging2file(log_file_path, 'code', None, __file__)
38 |     logging2file(log_file_path, 'log', 'something.')
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Created by .ignore support plugin (hsz.mobi)
 2 | ### Python template
 3 | # Byte-compiled / optimized / DLL files
 4 | __pycache__/
 5 | *.py[cod]
 6 | *$py.class
 7 | saved_model/
 8 | 
 9 | # C extensions
10 | *.so
11 | 
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 | 
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 | 
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 | 
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | 
52 | # Translations
53 | *.mo
54 | *.pot
55 | 
56 | # Django stuff:
57 | *.log
58 | .static_storage/
59 | .media/
60 | local_settings.py
61 | 
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 | 
66 | # Scrapy stuff:
67 | .scrapy
68 | 
69 | # Sphinx documentation
70 | docs/_build/
71 | 
72 | # PyBuilder
73 | target/
74 | 
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 | 
78 | # pyenv
79 | .python-version
80 | 
81 | # celery beat schedule file
82 | celerybeat-schedule
83 | 
84 | # SageMath parsed files
85 | *.sage.py
86 | 
87 | # Environments
88 | .env
89 | .venv
90 | env/
91 | venv/
92 | ENV/
93 | env.bak/
94 | venv.bak/
95 | 
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 | 
100 | # Rope project settings
101 | .ropeproject
102 | 
103 | # mkdocs documentation
104 | /site
105 | 
106 | # mypy
107 | .mypy_cache/
108 | 
--------------------------------------------------------------------------------
/util/dataset_util.py:
--------------------------------------------------------------------------------
 1 | from collections import Counter, OrderedDict
 2 | 
 3 | from torchtext import data
 4 | import torchtext
 5 | import os
 6 | import torch
 7 | import six
 8 | import config
 9 | 
10 | 
11 | class SST_UTF8(data.Dataset):
12 |     def __init__(self, path, text_field, label_field, subtrees=False,
13 |                  fine_grained=False, binary=True, **kwargs):
14 |         """Create an SST dataset instance given a path and fields.
15 | 
16 |         Arguments:
17 |             path: Path to the data file.
18 |             text_field: The field that will be used for text data.
19 |             label_field: The field that will be used for label data.
20 |             subtrees: Whether to include sentiment-labeled subtrees as
21 |                 separate examples. fine_grained / binary control the label
22 |                 granularity and neutral filtering; other kwargs: data.Dataset.
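
        Example (illustrative sketch, assuming SST train/dev/test.txt tree
        files are available under ./data/SST):

            TEXT = data.Field()
            LABEL = data.Field(sequential=False)
            train, dev, test = SST_UTF8.splits(TEXT, LABEL, root='./data')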
23 | """ 24 | 25 | fields = [('text', text_field), ('label', label_field)] 26 | 27 | def get_label_str(label): 28 | pre = 'very ' if fine_grained else '' 29 | return {'0': pre + 'negative', '1': 'negative', '2': 'neutral', 30 | '3': 'positive', '4': pre + 'positive', None: None}[label] 31 | 32 | label_field.preprocessing = data.Pipeline(get_label_str) 33 | 34 | if binary: 35 | filter_pred = lambda ex: ex.label != 'neutral' 36 | else: 37 | filter_pred = None 38 | 39 | with open(os.path.expanduser(path), encoding='utf-8') as f: 40 | if subtrees: 41 | examples = [ex for line in f for ex in 42 | data.Example.fromtree(line, fields, True)] 43 | else: 44 | examples = [data.Example.fromtree(line, fields) for line in f] 45 | super(SST_UTF8, self).__init__(examples=examples, fields=fields, filter_pred=filter_pred, **kwargs) 46 | 47 | @classmethod 48 | def splits(cls, text_field, label_field, root='.', 49 | train='/train.txt', validation='/dev.txt', test='/test.txt', 50 | train_subtrees=False, **kwargs): 51 | path = os.path.join(root, 'SST') 52 | 53 | train_data = None if train is None else cls( 54 | path + train, text_field, label_field, subtrees=train_subtrees, 55 | **kwargs) 56 | val_data = None if validation is None else cls( 57 | path + validation, text_field, label_field, **kwargs) 58 | test_data = None if test is None else cls( 59 | path + test, text_field, label_field, **kwargs) 60 | return tuple(d for d in (train_data, val_data, test_data) 61 | if d is not None) 62 | 63 | 64 | class TabularUTF8Dataset(data.Dataset): 65 | def __init__(self, path, format, fields, **kwargs): 66 | 67 | make_example = { 68 | 'json': data.Example.fromJSON, 'dict': data.Example.fromdict, 69 | 'tsv': data.Example.fromTSV, 'csv': data.Example.fromCSV}[format.lower()] 70 | 71 | with open(os.path.expanduser(path), encoding='utf-8') as f: 72 | examples = [ 73 | make_example(line.decode('utf-8') if six.PY2 else line, fields) 74 | for line in f] 75 | 76 | if make_example in (data.Example.fromdict, data.Example.fromJSON): 77 | fields, field_dict = [], fields 78 | for field in field_dict.values(): 79 | if isinstance(field, list): 80 | fields.extend(field) 81 | else: 82 | fields.append(field) 83 | 84 | super(TabularUTF8Dataset, self).__init__(examples, fields, **kwargs) 85 | 86 | 87 | class SSUField(data.Field): 88 | def __init__(self, tokenize=data.get_tokenizer('spacy'), eos_token='', include_lengths=True): 89 | super(SSUField, self).__init__(tokenize=tokenize, 90 | eos_token=eos_token, 91 | include_lengths=include_lengths) 92 | 93 | def merge_vocab(self, *args, **kwargs): 94 | counter = Counter() 95 | sources = [] 96 | for arg in args: 97 | # print(arg) 98 | if isinstance(arg, torchtext.data.Dataset): 99 | sources += [getattr(arg, name) for name, field in 100 | arg.fields.items() 101 | if field is self or 102 | isinstance(field, SSUField)] 103 | else: 104 | sources.append(arg) 105 | for data in sources: 106 | for x in data: 107 | if not self.sequential: 108 | x = [x] 109 | counter.update(x) 110 | specials = list(OrderedDict.fromkeys( 111 | tok for tok in [self.pad_token, self.init_token, self.eos_token] 112 | if tok is not None)) 113 | self.vocab = torchtext.data.Vocab(counter, specials=specials, **kwargs) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Residual Connected Sentence Encoder 2 | This is a repo for Residual-connected sentence encoder for NLI. 
3 | Paper: [https://arxiv.org/abs/1708.02312](https://arxiv.org/abs/1708.02312)
4 | If you use this code as part of published research, please cite the following paper.
5 | 
6 | ```
7 | @article{nie2017shortcut,
8 |   title={Shortcut-stacked sentence encoders for multi-domain inference},
9 |   author={Nie, Yixin and Bansal, Mohit},
10 |   journal={arXiv preprint arXiv:1708.02312},
11 |   year={2017}
12 | }
13 | ```
14 | 
15 | Follow the instructions below to run the experiments.
16 | 
17 | 1. Download the additional `data.zip` file, unzip it, and place it at the root directory of this repo.
18 | Download link for the `data.zip` file: [*DropBox Link*](https://www.dropbox.com/sh/kq81vmcmwktlyji/AADRVQRh9MdcXTkTQct7QlQFa?dl=0)
19 | 
20 | 2. This repo is based on an old version of `torchtext`; the latest version of `torchtext` is not backward-compatible.
21 | We provide a link to download the old `torchtext` that should be used with this repo. Link: [*old_torchtext*](https://www.dropbox.com/sh/n8ipkm1ng8f6d5u/AADg4KhwQMwz4xFkVJafgUMma?dl=0)
22 | 
23 | 3. Install the requirements:
24 | ```
25 | python 3.6
26 | 
27 | torchtext  # the version downloaded in step 2, or the latest torchtext if you fix the SNLI path problem
28 | pytorch == 0.2.0
29 | fire
30 | tqdm
31 | numpy
32 | spacy
33 | ```
34 | 
35 | To install torchtext, run the following command in the downloaded `torchtext_backward_compatible` directory (from step 2) using the Python interpreter of your environment:
36 | ```
37 | python setup.py install
38 | ```
39 | 
40 | To fully install spacy, run the following commands.
41 | ```
42 | pip install -U spacy
43 | python -m spacy download en
44 | ```
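45 | 
46 | A quick way to sanity-check the environment before training (a hypothetical one-liner, not part of the original instructions):
47 | ```
48 | python -c "import torch, torchtext, spacy; print(torch.__version__)"
49 | ```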
50 | 
51 | Optionally, you can match the `pip freeze` output below to set up the same experiment environment.
52 | ```
53 | certifi==2017.11.5
54 | chardet==3.0.4
55 | cymem==1.31.2
56 | cytoolz==0.8.2
57 | dill==0.2.7.1
58 | en-core-web-sm==2.0.0
59 | fire==0.1.2
60 | ftfy==4.4.3
61 | html5lib==1.0.1
62 | idna==2.6
63 | msgpack-numpy==0.4.1
64 | msgpack-python==0.5.1
65 | murmurhash==0.28.0
66 | numpy==1.14.0
67 | pathlib==1.0.1
68 | plac==0.9.6
69 | preshed==1.0.0
70 | PyYAML==3.12
71 | regex==2017.4.5
72 | requests==2.18.4
73 | six==1.11.0
74 | spacy==2.0.5
75 | termcolor==1.1.0
76 | thinc==6.10.2
77 | toolz==0.9.0
78 | torch==0.2.0.post3
79 | torchtext==0.1.1
80 | tqdm==4.19.5
81 | ujson==1.35
82 | urllib3==1.22
83 | wcwidth==0.1.7
84 | webencodings==0.5.1
85 | wrapt==1.10.11
86 | ```
87 | 
88 | 4. There is a directory called `saved_model` at the root of this repo:
89 | this directory is used for saving the models that produce the best dev results.
90 | 
91 | Before running the experiments, make sure the structure of this repo looks like the following.
92 | ```
93 | .
94 | ├── config.py
95 | ├── data
96 | │   ├── multinli_0.9
97 | │   │   ├── multinli_0.9_dev_matched.jsonl
98 | │   │   ├── multinli_0.9_dev_mismatched.jsonl
99 | │   │   ├── multinli_0.9_test_matched_unlabeled.jsonl
100 | │   │   ├── multinli_0.9_test_mismatched_unlabeled.jsonl
101 | │   │   └── multinli_0.9_train.jsonl
102 | │   ├── saved_embd.pt
103 | │   └── snli_1.0
104 | │       ├── README.txt
105 | │       ├── snli_1.0_dev.jsonl
106 | │       ├── snli_1.0_dev.txt
107 | │       ├── snli_1.0_test.jsonl
108 | │       ├── snli_1.0_test.txt
109 | │       ├── snli_1.0_train.jsonl
110 | │       └── snli_1.0_train.txt
111 | ├── model
112 | │   └── res_encoder.py
113 | ├── saved_model
114 | │   └── trained_model_will_be_saved_in_here.txt
115 | ├── setup.sh
116 | ├── torch_util.py
117 | └── util
118 |     ├── data_loader.py
119 |     ├── dataset_util.py
120 |     ├── __init__.py
121 |     ├── mnli.py
122 |     └── save_tool.py
123 | ```
124 | 
125 | 5. Start training by running the commands below in the root directory.
126 | ```
127 | source setup.sh
128 | python model/res_encoder.py train_snli
129 | ```
130 | 
131 | 6. After training completes, the script creates a folder in the `saved_model` directory.
132 | The parameters of the model are saved in that folder; the path will look like:
133 | ```
134 | $DIR_TMP/saved_model/(TIME_STAMP)_[600,600,600]-3stack-bilstm-maxout-residual/saved_params/(YOUR_MODEL_WITH_DEV_RESULT)
135 | ```
136 | Remember to change the bracketed parts to the actual file names on your computer.
137 | 
138 | 7. You can now evaluate the model by running the commands below.
139 | ```
140 | python model/res_encoder.py eval (PATH_OF_YOUR_MODEL) dev   # evaluation on the dev set
141 | python model/res_encoder.py eval (PATH_OF_YOUR_MODEL) test  # evaluation on the test set
142 | ```
143 | 
144 | **Pretrained Model:**
145 | We also provide a link to download the [*pretrained model*](https://www.dropbox.com/s/raa29iwpkv2xldh/pretrained_model_dev%2887.00%29?dl=0).
146 | After downloading the pretrained model, you can run the commands in step 7 for evaluation; note that you need to keep the default model hyper-parameters so that PyTorch can load the pretrained weights.
--------------------------------------------------------------------------------
/util/mnli.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | from util.dataset_util import SSUField
 4 | 
 5 | from collections import OrderedDict, Counter
 6 | 
 7 | from torchtext import data, vocab
 8 | from torchtext import datasets
 9 | from torchtext.data import Dataset
10 | from torchtext.vocab import Vocab
11 | from util import data_loader
12 | import config
13 | import torch
14 | 
15 | 
16 | class RParsedTextLField(data.Field):
17 |     def __init__(self, eos_token='', lower=False, include_lengths=True):
18 |         super(RParsedTextLField, self).__init__(
19 |             eos_token=eos_token, lower=lower, include_lengths=True,
20 |             preprocessing=lambda parse: [t for t in parse if t not in ('(', ')')],
21 |             postprocessing=lambda parse, _, __: [
22 |                 list(reversed(p)) for p in parse])
23 | 
24 | 
25 | class ParsedTextLField(data.Field):
26 |     def __init__(self, eos_token='', lower=False, include_lengths=True):
27 |         super(ParsedTextLField, self).__init__(
28 |             eos_token=eos_token, lower=lower, include_lengths=True,
29 |             preprocessing=lambda parse: [t for t in parse if t not in ('(', ')')])
30 | 
31 |     def merge_vocab(self, *args, **kwargs):
32 |         """Construct the Vocab object for this field from one or more datasets.
33 | 
34 |         Arguments:
35 |             Positional arguments: Dataset objects or other iterable data
36 |                 sources from which to construct the Vocab object that
37 |                 represents the set of possible values for this field. If
38 |                 a Dataset object is provided, all columns corresponding
39 |                 to this field are used; individual columns can also be
40 |                 provided directly.
41 |             Remaining keyword arguments: Passed to the constructor of Vocab.
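
        Example (illustrative sketch): build one shared vocabulary over all
        SNLI and MultiNLI splits so that every dataset indexes into the same
        embedding matrix:

            text_field.merge_vocab(snli_train, snli_dev, snli_test, mnli_train)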
42 |         """
43 |         counter = Counter()
44 |         sources = []
45 |         for arg in args:
46 |             # print(arg)
47 |             if isinstance(arg, Dataset):
48 |                 sources += [getattr(arg, name) for name, field in
49 |                             arg.fields.items()
50 |                             if field is self or
51 |                             isinstance(field, SSUField)]
52 |             else:
53 |                 sources.append(arg)
54 |         for src in sources:  # renamed from 'data', which shadowed the torchtext.data import
55 |             for x in src:
56 |                 if not self.sequential:
57 |                     x = [x]
58 |                 counter.update(x)
59 |         specials = list(OrderedDict.fromkeys(
60 |             tok for tok in [self.pad_token, self.init_token, self.eos_token]
61 |             if tok is not None))
62 |         self.vocab = Vocab(counter, specials=specials, **kwargs)
63 | 
64 |     def plugin_new_words(self, new_vocab):
65 |         for word, i in new_vocab.stoi.items():  # append words missing from the current vocab
66 |             if word in self.vocab.stoi:
67 |                 continue
68 |             else:
69 |                 self.vocab.itos.append(word)
70 |                 self.vocab.stoi[word] = len(self.vocab.itos) - 1
71 | 
72 | 
73 | class SSTTextLField(data.Field):
74 |     def __init__(self, tokenize=data.get_tokenizer('spacy'), eos_token='', lower=False, include_lengths=True):
75 |         super(SSTTextLField, self).__init__(
76 |             tokenize=tokenize,
77 |             eos_token=eos_token, lower=lower, include_lengths=include_lengths)
78 | 
79 | 
80 | class MNLI(data.ZipDataset, data.TabularDataset):
81 |     # url = 'http://nlp.stanford.edu/projects/snli/snli_1.0.zip'
82 |     filename = 'multinli_0.9.zip'
83 |     dirname = 'multinli_0.9'
84 | 
85 |     @staticmethod
86 |     def sort_key(ex):
87 |         return data.interleave_keys(
88 |             len(ex.premise), len(ex.hypothesis))
89 | 
90 |     @classmethod
91 |     def splits(cls, text_field, label_field, parse_field=None, genre_field=None, root='.',
92 |                train=None, validation=None, test=None):
93 |         """Create dataset objects for splits of the MultiNLI dataset.
94 |         This is the most flexible way to use the dataset.
95 |         Arguments:
96 |             text_field: The field that will be used for premise and hypothesis
97 |                 data.
98 |             label_field: The field that will be used for label data.
99 |             parse_field: The field that will be used for shift-reduce parser
100 |                 transitions, or None to not include them.
101 |             genre_field: The field that will be used for the genre tag of
102 |                 each example.
103 |             root: The root directory that the dataset's zip archive will be
104 |                 expanded into; the data files are read from its multinli_0.9
105 |                 subdirectory.
106 |             train: The filename of the train data, or None to skip it.
107 |             validation: The filename of the validation data, or None to skip it.
108 |             test: The filename of the test data, or None to skip it.
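
        Example (sketch, mirroring the call in util.data_loader):

            train, dev_m, dev_um = MNLI.splits(
                text_field, label_field, parse_field, genre_field,
                root='data', train='train.jsonl',
                validation='dev_matched.jsonl', test='dev_mismatched.jsonl')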
109 | """ 110 | path = cls.download_or_unzip(root) 111 | if parse_field is None: 112 | return super(MNLI, cls).splits( 113 | os.path.join(path, 'multinli_0.9_'), train, validation, test, 114 | format='json', fields={'sentence1': ('premise', text_field), 115 | 'sentence2': ('hypothesis', text_field), 116 | 'gold_label': ('label', label_field)}, 117 | filter_pred=lambda ex: ex.label != '-') 118 | return super(MNLI, cls).splits( 119 | os.path.join(path, 'multinli_0.9_'), train, validation, test, 120 | format='json', fields={'sentence1_binary_parse': 121 | [('premise', text_field), 122 | ('premise_transitions', parse_field)], 123 | 'sentence2_binary_parse': 124 | [('hypothesis', text_field), 125 | ('hypothesis_transitions', parse_field)], 126 | 'gold_label': ('label', label_field), 127 | 'genre': ('genre', genre_field)}, 128 | filter_pred=lambda ex: ex.label != '-') 129 | 130 | 131 | if __name__ == "__main__": 132 | pass -------------------------------------------------------------------------------- /util/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import config 3 | from torchtext import data, vocab 4 | from torchtext import datasets 5 | from util.mnli import MNLI 6 | import numpy as np 7 | import itertools 8 | from torch.autograd import Variable 9 | 10 | 11 | class RParsedTextLField(data.Field): 12 | def __init__(self, eos_token='', lower=False, include_lengths=True): 13 | super(RParsedTextLField, self).__init__( 14 | eos_token=eos_token, lower=lower, include_lengths=True, preprocessing=lambda parse: [ 15 | t for t in parse if t not in ('(', ')')], 16 | postprocessing=lambda parse, _, __: [ 17 | list(reversed(p)) for p in parse]) 18 | 19 | 20 | class ParsedTextLField(data.Field): 21 | def __init__(self, eos_token='', lower=False, include_lengths=True): 22 | super(ParsedTextLField, self).__init__( 23 | eos_token=eos_token, lower=lower, include_lengths=True, preprocessing=lambda parse: [ 24 | t for t in parse if t not in ('(', ')')]) 25 | 26 | def plugin_new_words(self, new_vocab): 27 | for word, i in new_vocab.stoi.items(): 28 | if word in self.vocab.stoi: 29 | continue 30 | else: 31 | self.vocab.itos.append(word) 32 | self.vocab.stoi[word] = len(self.vocab.itos) - 1 33 | 34 | 35 | def load_new_embedding(embd_file=config.DATA_ROOT + "/saved_embd_new.pt"): 36 | embd = torch.load(embd_file) 37 | return embd 38 | 39 | 40 | def load_data_sm(data_root, embd_file, reseversed=True, batch_sizes=(32, 32, 32, 32, 32), device=-1, shuffle=False): 41 | if reseversed: 42 | testl_field = RParsedTextLField() 43 | else: 44 | testl_field = ParsedTextLField() 45 | 46 | transitions_field = datasets.snli.ShiftReduceField() 47 | y_field = data.Field(sequential=False) 48 | g_field = data.Field(sequential=False) 49 | 50 | train_size, dev_size, test_size, m_dev_size, m_test_size = batch_sizes 51 | 52 | snli_train, snli_dev, snli_test = datasets.SNLI.splits(testl_field, y_field, transitions_field, root=data_root) 53 | 54 | mnli_train, mnli_dev_m, mnli_dev_um = MNLI.splits(testl_field, y_field, transitions_field, g_field, root=data_root, 55 | train='train.jsonl', 56 | validation='dev_matched.jsonl', 57 | test='dev_mismatched.jsonl') 58 | 59 | mnli_test_m, mnli_test_um = MNLI.splits(testl_field, y_field, transitions_field, g_field, root=data_root, 60 | train=None, 61 | validation='test_matched_unlabeled.jsonl', 62 | test='test_mismatched_unlabeled.jsonl') 63 | 64 | testl_field.build_vocab(snli_train, snli_dev, snli_test, 65 | mnli_train, 
mnli_dev_m, mnli_dev_um, mnli_test_m, mnli_test_um)
66 | 
67 |     g_field.build_vocab(mnli_train, mnli_dev_m, mnli_dev_um, mnli_test_m, mnli_test_um)
68 |     y_field.build_vocab(snli_train)
69 |     print('Important:', y_field.vocab.itos)
70 |     testl_field.vocab.vectors = torch.load(embd_file)
71 | 
72 |     snli_train_iter, snli_dev_iter, snli_test_iter = data.Iterator.splits(
73 |         (snli_train, snli_dev, snli_test), batch_sizes=batch_sizes, device=device, shuffle=False, sort=False)
74 | 
75 |     mnli_train_iter, mnli_dev_m_iter, mnli_dev_um_iter, mnli_test_m_iter, mnli_test_um_iter = data.Iterator.splits(
76 |         (mnli_train, mnli_dev_m, mnli_dev_um, mnli_test_m, mnli_test_um),
77 |         batch_sizes=(train_size, m_dev_size, m_test_size, m_dev_size, m_test_size),
78 |         device=device, shuffle=shuffle, sort=False)
79 | 
80 |     return (snli_train_iter, snli_dev_iter, snli_test_iter), (mnli_train_iter, mnli_dev_m_iter, mnli_dev_um_iter, mnli_test_m_iter, mnli_test_um_iter), testl_field.vocab.vectors
81 | 
82 | 
83 | def load_data_embd_vocab_snli(data_root, embd_file, reseversed=True, batch_sizes=(32, 32, 32, 32, 32), device=-1):
84 |     if reseversed:
85 |         testl_field = RParsedTextLField()
86 |     else:
87 |         testl_field = ParsedTextLField()
88 | 
89 |     transitions_field = datasets.snli.ShiftReduceField()
90 |     y_field = data.Field(sequential=False)
91 |     g_field = data.Field(sequential=False)
92 | 
93 |     train_size, dev_size, test_size, m_dev_size, m_test_size = batch_sizes
94 | 
95 |     snli_train, snli_dev, snli_test = datasets.SNLI.splits(testl_field, y_field, transitions_field, root=data_root)
96 | 
97 |     mnli_train, mnli_dev_m, mnli_dev_um = MNLI.splits(testl_field, y_field, transitions_field, g_field, root=data_root,
98 |                                                       train='train.jsonl',
99 |                                                       validation='dev_matched.jsonl',
100 |                                                       test='dev_mismatched.jsonl')
101 | 
102 |     mnli_test_m, mnli_test_um = MNLI.splits(testl_field, y_field, transitions_field, g_field, root=data_root,
103 |                                             train=None,
104 |                                             validation='test_matched_unlabeled.jsonl',
105 |                                             test='test_mismatched_unlabeled.jsonl')
106 | 
107 |     testl_field.build_vocab(snli_train, snli_dev, snli_test,
108 |                             mnli_train, mnli_dev_m, mnli_dev_um, mnli_test_m, mnli_test_um)
109 | 
110 |     testl_field.vocab.vectors = torch.load(embd_file)
111 | 
112 |     return testl_field.vocab.vectors, testl_field.vocab, testl_field
113 | 
114 | 
115 | def combine_two_set(set_1, set_2, rate=(1, 1), seed=0):
116 |     np.random.seed(seed)
117 |     len_1 = len(set_1)
118 |     len_2 = len(set_2)
119 |     # print(len_1, len_2)
120 |     p1, p2 = rate
121 |     c_1 = np.random.choice([0, 1], len_1, p=[1 - p1, p1])
122 |     c_2 = np.random.choice([0, 1], len_2, p=[1 - p2, p2])
123 |     iter_1 = itertools.compress(iter(set_1), c_1)
124 |     iter_2 = itertools.compress(iter(set_2), c_2)
125 |     for it in itertools.chain(iter_1, iter_2):
126 |         yield it
127 | 
128 | 
129 | if __name__ == '__main__':
130 |     snli, mnli, embd = load_data_sm(config.DATA_ROOT, config.EMBD_FILE, reseversed=False,
131 |                                     batch_sizes=(32, 32, 32, 32, 32))  # load_data_sm unpacks five batch sizes; a 3-tuple would raise a ValueError
132 | 
133 |     s_train, s_dev, s_test = snli
134 |     m_train, m_dev_m, m_dev_um, m_test_m, m_test_um = mnli
135 | 
136 |     train = combine_two_set(s_train, m_train, rate=[0.15, 1])
137 | 
138 |     print(len(list(train)))
139 |     print(len(m_train))
140 |     print(len(s_train))
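141 | 
142 |     # Note (added comment): combine_two_set above is what mixes the two corpora
143 |     # for multi-domain training; rate=[0.15, 1] keeps roughly 15% of SNLI and
144 |     # all of MultiNLI on each pass, sampled deterministically from the seed.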
--------------------------------------------------------------------------------
/model/res_encoder.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | from torch import optim
 4 | from torch.autograd import Variable
 5 | import torch_util
 6 | from tqdm import tqdm
 7 | import util.save_tool as save_tool
 8 | import os
 9 | from datetime import datetime
10 | 
11 | import util.data_loader as data_loader
12 | import config
13 | import fire
14 | 
15 | 
16 | def model_eval(model, data_iter, criterion, pred=False):
17 |     model.eval()
18 |     data_iter.init_epoch()
19 |     n_correct = loss = 0
20 |     total_size = 0
21 | 
22 |     if not pred:
23 |         for batch_idx, batch in enumerate(data_iter):
24 | 
25 |             s1, s1_l = batch.premise
26 |             s2, s2_l = batch.hypothesis
27 |             y = batch.label.data - 1
28 | 
29 |             out = model(s1, s1_l - 1, s2, s2_l - 1)  # l - 1 drops the trailing eos token from the lengths
30 |             n_correct += (torch.max(out, 1)[1].view(batch.label.size()).data == y).sum()
31 | 
32 |             loss += criterion(out, batch.label - 1).data[0] * batch.batch_size
33 |             total_size += batch.batch_size
34 | 
35 |         avg_acc = 100. * n_correct / total_size
36 |         avg_loss = loss / total_size
37 | 
38 |         return avg_acc, avg_loss
39 |     else:
40 |         pred_list = []
41 |         for batch_idx, batch in enumerate(data_iter):
42 | 
43 |             s1, s1_l = batch.premise
44 |             s2, s2_l = batch.hypothesis
45 | 
46 |             out = model(s1, s1_l - 1, s2, s2_l - 1)
47 |             pred_list.append(torch.max(out, 1)[1].view(batch.label.size()).data)
48 | 
49 |         return torch.cat(pred_list, dim=0)
50 | 
51 | 
52 | class ResEncoder(nn.Module):
53 |     def __init__(self, h_size=[600, 600, 600], v_size=10, d=300, mlp_d=800, dropout_r=0.1, max_l=60, k=3, n_layers=1):
54 |         super(ResEncoder, self).__init__()
55 |         self.Embd = nn.Embedding(v_size, d)
56 | 
57 |         self.lstm = nn.LSTM(input_size=d, hidden_size=h_size[0],
58 |                             num_layers=1, bidirectional=True)
59 | 
60 |         self.lstm_1 = nn.LSTM(input_size=(d + h_size[0] * 2), hidden_size=h_size[1],
61 |                               num_layers=1, bidirectional=True)
62 | 
63 |         self.lstm_2 = nn.LSTM(input_size=(d + h_size[0] * 2), hidden_size=h_size[2],
64 |                               num_layers=1, bidirectional=True)
65 | 
66 |         self.max_l = max_l
67 |         self.h_size = h_size
68 |         self.k = k
69 | 
70 |         self.mlp_1 = nn.Linear(h_size[2] * 2 * 4, mlp_d)
71 |         self.mlp_2 = nn.Linear(mlp_d, mlp_d)
72 |         self.sm = nn.Linear(mlp_d, 3)
73 | 
74 |         if n_layers == 1:
75 |             self.classifier = nn.Sequential(*[self.mlp_1, nn.ReLU(), nn.Dropout(dropout_r),
76 |                                               self.sm])
77 |         elif n_layers == 2:
78 |             self.classifier = nn.Sequential(*[self.mlp_1, nn.ReLU(), nn.Dropout(dropout_r),
79 |                                               self.mlp_2, nn.ReLU(), nn.Dropout(dropout_r),
80 |                                               self.sm])
81 |         else:
82 |             raise ValueError("n_layers must be 1 or 2")  # previously only printed an error, leaving self.classifier unset
83 | 
84 |     def count_params(self):
85 |         total_c = 0
86 |         for param in self.parameters():
87 |             if len(param.size()) == 2:  # counts 2-D weight matrices only; bias vectors are skipped
88 |                 d1, d2 = param.size()[0], param.size()[1]
89 |                 total_c += d1 * d2
90 |         print("Total count:", total_c)
91 | 
92 |     def display(self):
93 |         for param in self.parameters():
94 |             print(param.data.size())
95 | 
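    # Shape sketch for forward() below (added comment; time-major layout, T = time, B = batch):
    #   s1 / s2:        [T, B]          token indices
    #   p_s1 / p_s2:    [T, B, 300]     word embeddings
    #   layer-k biLSTM: [T, B, 2*h_k]   input = [embedding; (sum of) previous layer outputs]
    #   features:       [B, 8 * h_3]    [u; v; |u - v|; u * v] of the max-pooled sentence vectors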
96 |     def forward(self, s1, l1, s2, l2):
97 |         if self.max_l:
98 |             l1 = l1.clamp(max=self.max_l)
99 |             l2 = l2.clamp(max=self.max_l)
100 |             if s1.size(0) > self.max_l:
101 |                 s1 = s1[:self.max_l, :]
102 |             if s2.size(0) > self.max_l:
103 |                 s2 = s2[:self.max_l, :]
104 | 
105 |         p_s1 = self.Embd(s1)
106 |         p_s2 = self.Embd(s2)
107 | 
108 |         s1_layer1_out = torch_util.auto_rnn_bilstm(self.lstm, p_s1, l1)
109 |         s2_layer1_out = torch_util.auto_rnn_bilstm(self.lstm, p_s2, l2)
110 | 
111 |         # Length truncate
112 |         len1 = s1_layer1_out.size(0)
113 |         len2 = s2_layer1_out.size(0)
114 |         p_s1 = p_s1[:len1, :, :]
115 |         p_s2 = p_s2[:len2, :, :]
116 | 
117 |         # Shortcut (highway-style) connections: word embeddings are fed into every layer
118 |         s1_layer2_in = torch.cat([p_s1, s1_layer1_out], dim=2)
119 |         s2_layer2_in = torch.cat([p_s2, s2_layer1_out], dim=2)
120 | 
121 |         s1_layer2_out = torch_util.auto_rnn_bilstm(self.lstm_1, s1_layer2_in, l1)
122 |         s2_layer2_out = torch_util.auto_rnn_bilstm(self.lstm_1, s2_layer2_in, l2)
123 | 
124 |         s1_layer3_in = torch.cat([p_s1, s1_layer1_out + s1_layer2_out], dim=2)
125 |         s2_layer3_in = torch.cat([p_s2, s2_layer1_out + s2_layer2_out], dim=2)
126 | 
127 |         s1_layer3_out = torch_util.auto_rnn_bilstm(self.lstm_2, s1_layer3_in, l1)
128 |         s2_layer3_out = torch_util.auto_rnn_bilstm(self.lstm_2, s2_layer3_in, l2)
129 | 
130 |         s1_layer3_maxout = torch_util.max_along_time(s1_layer3_out, l1)
131 |         s2_layer3_maxout = torch_util.max_along_time(s2_layer3_out, l2)
132 | 
133 |         # Only use the last layer
134 |         features = torch.cat([s1_layer3_maxout, s2_layer3_maxout,
135 |                               torch.abs(s1_layer3_maxout - s2_layer3_maxout),
136 |                               s1_layer3_maxout * s2_layer3_maxout],
137 |                              dim=1)
138 | 
139 |         out = self.classifier(features)
140 |         return out
141 | 
142 | 
143 | def train_snli():
144 |     seed = 12
145 |     rate = 0.1
146 |     n_layers = 1
147 |     mlp_d = 800
148 |     torch.manual_seed(seed)
149 |     torch.cuda.manual_seed(seed)
150 | 
151 |     snli_d, mnli_d, embd = data_loader.load_data_sm(
152 |         config.DATA_ROOT, config.EMBD_FILE, reseversed=False, batch_sizes=(32, 200, 200, 32, 32), device=0)
153 | 
154 |     s_train, s_dev, s_test = snli_d
155 | 
156 |     s_train.repeat = False
157 | 
158 |     model = ResEncoder(mlp_d=mlp_d, dropout_r=rate, n_layers=n_layers)
159 |     model.Embd.weight.data = embd
160 |     model.display()
161 | 
162 |     if torch.cuda.is_available():
163 |         # model.cuda() also moves the embedding weights assigned above
164 |         model.cuda()
165 | 
166 |     start_lr = 2e-4
167 | 
168 |     optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=start_lr)
169 |     criterion = nn.CrossEntropyLoss()
170 | 
171 |     date_now = datetime.now().strftime("%m-%d-%H:%M:%S")
172 |     name = '[600,600,600]-3stack-bilstm-maxout-residual-{}-relu-seed({})-dr({})-mlpd({})'.format(n_layers, seed, rate, mlp_d)
173 |     file_path = save_tool.gen_prefix(name, date_now)
174 | 
175 |     # Save a copy of this training script into the run directory so each
176 |     # run's exact code is logged alongside its checkpoints.
177 |     # (Original note: "Attention!!! Modify this to save to log file.")
178 | 
179 | 
180 |     save_tool.logging2file(file_path, 'code', None, __file__)
181 | 
182 |     iterations = 0
183 |     best_dev = -1
184 | 
185 |     param_file_prefix = "{}/{}".format(file_path, "saved_params_snli")
186 |     if not os.path.exists(os.path.join(config.ROOT_DIR, param_file_prefix)):
187 |         os.mkdir(os.path.join(config.ROOT_DIR, param_file_prefix))
188 | 
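    # Training schedule (added comment): three passes over SNLI; the learning
    # rate starts at 2e-4 and is halved every two epochs (lr = start_lr / 2**(i // 2)),
    # and every new best dev score within an epoch is checkpointed below.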
189 |     for i in range(3):
190 |         s_train.init_epoch()
191 | 
192 |         train_iter, dev_iter = s_train, s_dev
193 |         train_iter.repeat = False
194 | 
195 |         if i != 0:
196 |             SAVE_PATH = os.path.join(config.ROOT_DIR, file_path, 'm_{}_snli_e'.format(i - 1))
197 |             model.load_state_dict(torch.load(SAVE_PATH))
198 | 
199 |         start_perf = model_eval(model, dev_iter, criterion)
200 |         i_decay = i // 2
201 |         lr = start_lr / (2 ** i_decay)
202 | 
203 |         epoch_start_info = "epoch:{}, learning_rate:{}, start_performance:{}/{}\n".format(i, lr, *start_perf)
204 |         print(epoch_start_info)
205 |         save_tool.logging2file(file_path, 'log_snli', epoch_start_info)
206 | 
207 |         for batch_idx, batch in tqdm(enumerate(train_iter)):
208 |             iterations += 1
209 |             model.train()
210 | 
211 |             s1, s1_l = batch.premise
212 |             s2, s2_l = batch.hypothesis
213 |             y = batch.label - 1
214 | 
215 |             out = model(s1, (s1_l - 1), s2, (s2_l - 1))
216 |             loss = criterion(out, y)
217 | 
218 |             optimizer.zero_grad()
219 | 
220 |             for pg in optimizer.param_groups:
221 |                 pg['lr'] = lr
222 | 
223 |             loss.backward()
224 |             optimizer.step()
225 | 
226 |             if i == 0:
227 |                 mod = 9000
228 |             elif i == 1:
229 |                 mod = 1000
230 |             else:
231 |                 mod = 100
232 | 
233 |             if (1 + batch_idx) % mod == 0:
234 |                 model.max_l = 150
235 | 
236 |                 dev_score, dev_loss = model_eval(model, dev_iter, criterion)
237 |                 print('SNLI:dev:{}/{}'.format(dev_score, dev_loss), end='\n')
238 | 
239 |                 model.max_l = 60
240 | 
241 |                 if best_dev < dev_score:
242 | 
243 |                     best_dev = dev_score
244 | 
245 |                     now = datetime.now().strftime("%m-%d-%H:%M:%S")
246 |                     log_info = "{}\t{}\tdev:{}/{}\t{}\n".format(i, iterations, dev_score, dev_loss, now)  # previously the format string dropped 'now' and the newline
247 |                     save_tool.logging2file(file_path, "log_snli", log_info)
248 | 
249 |                     save_path = os.path.join(config.ROOT_DIR, param_file_prefix,
250 |                                              'e({})_dev({})'.format(i, dev_score))
251 | 
252 |                     torch.save(model.state_dict(), save_path)
253 | 
254 |         SAVE_PATH = os.path.join(config.ROOT_DIR, file_path, 'm_{}_snli_e'.format(i))
255 |         torch.save(model.state_dict(), SAVE_PATH)
256 | 
257 | 
258 | def eval(model_path, mode='dev'):
259 |     snli_d, mnli_d, embd = data_loader.load_data_sm(
260 |         config.DATA_ROOT, config.EMBD_FILE, reseversed=False, batch_sizes=(32, 200, 200, 32, 32), device=0)
261 | 
262 |     s_train, s_dev, s_test = snli_d
263 | 
264 |     rate = 0.1
265 |     n_layers = 1
266 |     mlp_d = 800
267 | 
268 |     model = ResEncoder(mlp_d=mlp_d, dropout_r=rate, n_layers=n_layers)
269 |     model.Embd.weight.data = embd
270 |     # model.display()
271 | 
272 |     if torch.cuda.is_available():
273 |         # model.cuda() also moves the embedding weights assigned above
274 |         model.cuda()
275 | 
276 |     criterion = nn.CrossEntropyLoss()
277 | 
278 |     if mode == 'dev':
279 |         d_iter = s_dev
280 |     else:
281 |         d_iter = s_test
282 | 
283 |     SAVE_PATH = model_path
284 |     model.load_state_dict(torch.load(SAVE_PATH))
285 |     score, loss = model_eval(model, d_iter, criterion)
286 |     print("{} score/loss:{}/{}".format(mode, score, loss))
287 | 
288 | if __name__ == '__main__':
289 |     # train_snli()
290 |     # eval(model_path="/home/easonnie/projects/ResEncoder/saved_model/12-04-23:22:31_[600,600,600]-3stack-bilstm-maxout-residual-1-relu-seed(12)-dr(0.1)-mlpd(800)/saved_params_snli/e(2)_dev(87.00467384677911)", mode='dev')
291 |     # eval('test')
292 | 
293 |     # 
fire.Fire() 294 | fire.Fire() 295 | 296 | 297 | 298 | -------------------------------------------------------------------------------- /torch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | 7 | 8 | def pad_1d(t, pad_l): 9 | l = t.size(0) 10 | if l >= pad_l: 11 | return t 12 | else: 13 | pad_seq = Variable(t.data.new(pad_l - l, *t.size()[1:]).zero_()) 14 | return torch.cat([t, pad_seq], dim=0) 15 | 16 | 17 | def pad(t, length, batch_first=False): 18 | """ 19 | Padding the sequence to a fixed length. 20 | :param t: [B * T * D] or [B * T] if batch_first else [T * B * D] or [T * B] 21 | :param length: [B] 22 | :param batch_first: 23 | :return: 24 | """ 25 | if batch_first: 26 | # [B * T * D] 27 | if length <= t.size(1): 28 | return t 29 | else: 30 | batch_size = t.size(0) 31 | pad_seq = Variable(t.data.new(batch_size, length - t.size(1), *t.size()[2:]).zero_()) 32 | # [B * T * D] 33 | return torch.cat([t, pad_seq], dim=1) 34 | else: 35 | # [T * B * D] 36 | if length <= t.size(0): 37 | return t 38 | else: 39 | return torch.cat([t, Variable(t.data.new(length - t.size(0), *t.size()[1:]).zero_())]) 40 | 41 | 42 | def batch_first2time_first(inputs): 43 | """ 44 | Convert input from batch_first to time_first: 45 | [B * T * D] -> [T * B * D] 46 | 47 | :param inputs: 48 | :return: 49 | """ 50 | return torch.transpose(inputs, 0, 1) 51 | 52 | 53 | def time_first2batch_first(inputs): 54 | """ 55 | Convert input from batch_first to time_first: 56 | [T * B * D] -> [B * T * D] 57 | 58 | :param inputs: 59 | :return: 60 | """ 61 | 62 | return torch.transpose(inputs, 0, 1) 63 | 64 | 65 | def get_state_shape(rnn: nn.RNN, batch_size, bidirectional=False): 66 | """ 67 | Return the state shape of a given RNN. This is helpful when you want to create a init state for RNN. 68 | 69 | Example: 70 | h0 = Variable(src_seq_p.data.new(*get_state_shape(self.encoder, 3, bidirectional)).zero_()) 71 | 72 | :param rnn: 73 | :param batch_size: 74 | :param bidirectional: 75 | :return: 76 | """ 77 | if bidirectional: 78 | return rnn.num_layers * 2, batch_size, rnn.hidden_size 79 | else: 80 | return rnn.num_layers, batch_size, rnn.hidden_size 81 | 82 | 83 | def pack_list_sequence(inputs, l, batch_first=False): 84 | """ 85 | Pack a batch of Tensor into one Tensor. 
86 |     :param inputs: a list of [T_i * D] tensors, one per batch element
87 |     :param l: [B] sequence lengths
88 |     :return: [T_max * B * D] ([B * T_max * D] if batch_first)
89 |     """
90 |     batch_list = []
91 |     max_l = max(list(l))
92 |     batch_size = len(inputs)
93 | 
94 |     for b_i in range(batch_size):
95 |         batch_list.append(pad(inputs[b_i], max_l))
96 |     pack_batch_list = torch.stack(batch_list, dim=1) if not batch_first \
97 |         else torch.stack(batch_list, dim=0)
98 |     return pack_batch_list
99 | 
100 | 
101 | def pack_for_rnn_seq(inputs, lengths, batch_first=False):
102 |     """
103 |     :param inputs: [T * B * D]
104 |     :param lengths: [B]
105 |     :return:
106 |     """
107 |     if not batch_first:
108 |         _, sorted_indices = lengths.sort()
109 |         '''
110 |         Reverse to decreasing order
111 |         '''
112 |         r_index = reversed(list(sorted_indices))
113 | 
114 |         s_inputs_list = []
115 |         lengths_list = []
116 |         reverse_indices = np.zeros(lengths.size(0), dtype=np.int64)
117 | 
118 |         for j, i in enumerate(r_index):
119 |             s_inputs_list.append(inputs[:, i, :].unsqueeze(1))
120 |             lengths_list.append(lengths[i])
121 |             reverse_indices[i] = j
122 | 
123 |         reverse_indices = list(reverse_indices)
124 | 
125 |         s_inputs = torch.cat(s_inputs_list, 1)
126 |         packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, lengths_list)
127 | 
128 |         return packed_seq, reverse_indices
129 | 
130 |     else:
131 |         _, sorted_indices = lengths.sort()
132 |         '''
133 |         Reverse to decreasing order
134 |         '''
135 |         r_index = reversed(list(sorted_indices))
136 | 
137 |         s_inputs_list = []
138 |         lengths_list = []
139 |         reverse_indices = np.zeros(lengths.size(0), dtype=np.int64)
140 | 
141 |         for j, i in enumerate(r_index):
142 |             s_inputs_list.append(inputs[i, :, :])
143 |             lengths_list.append(lengths[i])
144 |             reverse_indices[i] = j
145 | 
146 |         reverse_indices = list(reverse_indices)
147 | 
148 |         s_inputs = torch.stack(s_inputs_list, dim=0)
149 |         # print(s_inputs)
150 |         # print(lengths_list)
151 |         packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, lengths_list, batch_first=batch_first)
152 | 
153 |         return packed_seq, reverse_indices
154 | 
155 | 
156 | def unpack_from_rnn_seq(packed_seq, reverse_indices, batch_first=False):
157 |     unpacked_seq, _ = nn.utils.rnn.pad_packed_sequence(packed_seq, batch_first=batch_first)
158 |     s_inputs_list = []
159 | 
160 |     if not batch_first:
161 |         for i in reverse_indices:
162 |             s_inputs_list.append(unpacked_seq[:, i, :].unsqueeze(1))
163 |         return torch.cat(s_inputs_list, 1)
164 |     else:
165 |         for i in reverse_indices:
166 |             s_inputs_list.append(unpacked_seq[i, :, :].unsqueeze(0))
167 |         return torch.cat(s_inputs_list, 0)
168 | 
169 | 
170 | def auto_rnn(rnn: nn.RNN, seqs, lengths, batch_first=True):
171 | 
172 |     batch_size = seqs.size(0) if batch_first else seqs.size(1)
173 |     state_shape = get_state_shape(rnn, batch_size, rnn.bidirectional)
174 | 
175 |     h0 = c0 = Variable(seqs.data.new(*state_shape).zero_())
176 | 
177 |     packed_pinputs, r_index = pack_for_rnn_seq(seqs, lengths, batch_first)
178 |     output, (hn, cn) = rnn(packed_pinputs, (h0, c0))
179 |     output = unpack_from_rnn_seq(output, r_index, batch_first)
180 | 
181 |     return output
182 | 
183 | 
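# Usage sketch for the pack/unpack helpers above (added comment): they let a
# batch with ragged lengths run through cuDNN RNNs; pack_for_rnn_seq sorts by
# decreasing length for pack_padded_sequence, and unpack_from_rnn_seq uses the
# returned reverse_indices to restore the original batch order:
#
#   packed, r_index = pack_for_rnn_seq(x, lengths)   # x: [T, B, D]
#   out, _ = lstm(packed, (h0, c0))
#   out = unpack_from_rnn_seq(out, r_index)          # [T, B, 2 * hidden]
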
184 | def pack_seqence_for_linear(inputs, lengths, batch_first=True):
185 |     """
186 |     :param inputs: [B * T * D] if batch_first
187 |     :param lengths: [B]
188 |     :param batch_first:
189 |     :return: [sum(lengths) * D]: all real timesteps concatenated along
190 |         dim 0 with the padded positions dropped, so that a Linear layer
191 |         only processes real tokens (see unpack_seqence_for_linear).
192 |     """
193 |     batch_list = []
194 |     if batch_first:
195 |         for i, l in enumerate(lengths):
196 |             batch_list.append(inputs[i, :l])
197 |         packed_sequence = torch.cat(batch_list, 0)
198 |         # if chuck:
199 |         #     return list(torch.chunk(packed_sequence, chuck, dim=0))
200 |         # else:
201 |         return packed_sequence
202 | 
203 |     else:
204 |         raise NotImplementedError()  # was 'raise NotImplemented()', which is not callable
205 | 
206 | 
207 | def chucked_forward(inputs, net, chuck=None):
208 |     if not chuck:
209 |         return net(inputs)
210 |     else:
211 |         output_list = [net(piece) for piece in torch.chunk(inputs, chuck, dim=0)]  # loop variable renamed; it previously shadowed 'chuck'
212 |         return torch.cat(output_list, dim=0)
213 | 
214 | 
215 | def unpack_seqence_for_linear(inputs, lengths, batch_first=True):
216 |     batch_list = []
217 |     max_l = max(lengths)
218 | 
219 |     if not isinstance(inputs, list):
220 |         inputs = [inputs]
221 |     inputs = torch.cat(inputs)
222 | 
223 |     if batch_first:
224 |         start = 0
225 |         for l in lengths:
226 |             end = start + l
227 |             batch_list.append(pad_1d(inputs[start:end], max_l))
228 |             start = end
229 |         return torch.stack(batch_list)
230 |     else:
231 |         raise NotImplementedError()  # was 'raise NotImplemented()', which is not callable
232 | 
233 | 
234 | def auto_rnn_bilstm(lstm: nn.LSTM, seqs, lengths):
235 | 
236 |     batch_size = seqs.size(1)
237 | 
238 |     state_shape = lstm.num_layers * 2, batch_size, lstm.hidden_size
239 | 
240 |     h0 = c0 = Variable(seqs.data.new(*state_shape).zero_())
241 | 
242 |     packed_pinputs, r_index = pack_for_rnn_seq(seqs, lengths)
243 | 
244 |     output, (hn, cn) = lstm(packed_pinputs, (h0, c0))
245 | 
246 |     output = unpack_from_rnn_seq(output, r_index)
247 | 
248 |     return output
249 | 
250 | 
251 | def auto_rnn_bigru(gru: nn.GRU, seqs, lengths):
252 | 
253 |     batch_size = seqs.size(1)
254 | 
255 |     state_shape = gru.num_layers * 2, batch_size, gru.hidden_size
256 | 
257 |     h0 = Variable(seqs.data.new(*state_shape).zero_())
258 | 
259 |     packed_pinputs, r_index = pack_for_rnn_seq(seqs, lengths)
260 | 
261 |     output, hn = gru(packed_pinputs, h0)
262 | 
263 |     output = unpack_from_rnn_seq(output, r_index)
264 | 
265 |     return output
266 | 
267 | 
268 | def select_last(inputs, lengths, hidden_size):
269 |     """
270 |     :param inputs: [T * B * D] D = 2 * hidden_size
271 |     :param lengths: [B]
272 |     :param hidden_size: dimension
273 |     :return: [B * D]
274 |     """
275 |     batch_size = inputs.size(1)
276 |     batch_out_list = []
277 |     for b in range(batch_size):
278 |         batch_out_list.append(torch.cat((inputs[lengths[b] - 1, b, :hidden_size],
279 |                                          inputs[0, b, hidden_size:])
280 |                                         )
281 |                               )
282 | 
283 |     out = torch.stack(batch_out_list)
284 |     return out
285 | 
286 | 
287 | def matching_matrix(s1, s2):
288 |     """
289 |     :param s1: [T * B * D]
290 |     :param s2: [T * B * D]
291 |     :return: [B * T * T]
292 |     """
293 |     b_s1 = s1.transpose(0, 1)  # [B * T * D]
294 |     b_s2 = s2.transpose(0, 1)  # [B * T * D]
295 | 
296 |     matrix = torch.bmm(b_s1, b_s2.transpose(1, 2))
297 |     return matrix
298 | 
299 | 
300 | def sequence_matrix_cross_alignment(s1, s2, l1, l2, matrix):
301 |     """
302 |     :param s1: [T * B * D]
303 |     :param s2: [T * B * D]
304 |     :param l1: [B]
305 |     :param l2: [B]
306 |     :param matrix: [B * T * T]
307 |     :return:
308 |     """
309 | 
310 |     b_s1 = s1.transpose(0, 1)  # [B * T * D]
311 |     b_s2 = s2.transpose(0, 1)  # [B * T * D]
312 |     dim = b_s1.size(2)
313 |     batch_size = b_s1.size(0)
314 | 
315 |     b_wsum_a2to1_list = []
316 |     b_wsum_a1to2_list = []
317 | 
318 |     for b in range(batch_size):
319 |         b_m = matrix[b]
320 |         ba_s1 = b_s1[b]
321 |         ba_s2 = b_s2[b]
322 | 
323 |         # # align s2 to s1  _a2to1
324 |         exp_b_m_a2to1 = F.softmax(b_m[:l1[b],
:l2[b]]) # [t1 * t2] 325 | 326 | b_weight_a2to1 = exp_b_m_a2to1.unsqueeze(2).expand(l1[b], l2[b], dim) # [t1 * t2] -> [t1 * t2 * dim] 327 | b_seq_a2to1 = ba_s2[:l2[b]].unsqueeze(0).expand(l1[b], l2[b], dim) # [t2 * dim] -> [t1 * t2 * dim] 328 | 329 | b_wsum_a2to1 = torch.sum(b_weight_a2to1 * b_seq_a2to1, dim=1).squeeze(1) # [t1 * d] 330 | b_wsum_a2to1_list.append(b_wsum_a2to1) 331 | 332 | # # align s1 to s2 _a1to2 333 | exp_b_m_a1to2 = F.softmax(b_m[:l1[b], :l2[b]].transpose(0, 1)) # [t2 * t1] 334 | 335 | b_weight_a1to2 = exp_b_m_a1to2.unsqueeze(2).expand(l2[b], l1[b], dim) # [t2 * t1] -> [t2 * t1 * d] 336 | b_seq_a1to2 = ba_s1[:l1[b]].unsqueeze(0).expand(l2[b], l1[b], dim) # [t1 * dim] -> [t2 * t1 * dim] 337 | 338 | b_wsum_a1to2 = torch.sum(b_weight_a1to2 * b_seq_a1to2, dim=1).squeeze(1) # [t2 * d] 339 | b_wsum_a1to2_list.append(b_wsum_a1to2) 340 | # print(b_wsum_a1to2) 341 | 342 | align_s2_to_s1 = pack_list_sequence(b_wsum_a2to1_list, l1) 343 | align_s1_to_s2 = pack_list_sequence(b_wsum_a1to2_list, l2) 344 | 345 | return align_s2_to_s1, align_s1_to_s2 346 | 347 | 348 | def channel_weighted_sum(s, w, l, sharpen=None): 349 | batch_size = w.size(1) 350 | result_list = [] 351 | for b_i in range(batch_size): 352 | if sharpen: 353 | b_w = w[:l[b_i], b_i, :] * sharpen 354 | else: 355 | b_w = w[:l[b_i], b_i, :] 356 | b_s = s[:l[b_i], b_i, :] # T, D 357 | soft_b_w = F.softmax(b_w.transpose(0, 1)).transpose(0, 1) 358 | # print(soft_b_w) 359 | # print('soft:', ) 360 | # print(soft_b_w) 361 | result_list.append(torch.sum(soft_b_w * b_s, dim=0)) # [T, D] -> [1, D] 362 | return torch.cat(result_list, dim=0) 363 | 364 | 365 | def topk_weighted_sum(s, w, k, l): 366 | batch_size = w.size(1) 367 | result_list = [] 368 | for b_i in range(batch_size): 369 | # print(w.size()) 370 | # print(l[b_i]) 371 | b_w = w[:l[b_i], b_i, :] 372 | b_s = s[:l[b_i], b_i, :] # T, D 373 | if l[b_i] == 1: 374 | b_topk, b_topk_indices = b_s.max(dim=0) 375 | elif l[b_i] < k: 376 | b_topk, b_topk_indices = torch.topk(b_s, l[b_i], dim=0) 377 | else: 378 | b_topk, b_topk_indices = torch.topk(b_s, k, dim=0) 379 | 380 | b_topk_w = torch.gather(b_w, 0, b_topk_indices) 381 | soft_b_topk_w = F.softmax(b_topk_w.transpose(0, 1)).transpose(0, 1) 382 | result_list.append(torch.sum(soft_b_topk_w * b_topk, dim=0)) 383 | return torch.cat(result_list, dim=0) 384 | 385 | 386 | def topk_dp_weighted_sum(s, w, l): 387 | batch_size = w.size(1) 388 | result_list = [] 389 | for b_i in range(batch_size): 390 | # print(w.size()) 391 | # print(l[b_i]) 392 | 393 | # Dynamic pooling 394 | 395 | k = (int(l[b_i] - 1) // 10) + 1 396 | 397 | b_w = w[:l[b_i], b_i, :] 398 | b_s = s[:l[b_i], b_i, :] # T, D 399 | if l[b_i] == 1: 400 | b_topk, b_topk_indices = b_s.max(dim=0) 401 | elif l[b_i] < k: 402 | b_topk, b_topk_indices = torch.topk(b_s, l[b_i], dim=0) 403 | else: 404 | b_topk, b_topk_indices = torch.topk(b_s, k, dim=0) 405 | 406 | b_topk_w = torch.gather(b_w, 0, b_topk_indices) 407 | soft_b_topk_w = F.softmax(b_topk_w.transpose(0, 1)).transpose(0, 1) 408 | result_list.append(torch.sum(soft_b_topk_w * b_topk, dim=0)) 409 | return torch.cat(result_list, dim=0) 410 | 411 | 412 | def pack_to_matching_matrix(s1, s2, cat_only=[False, False]): 413 | t1 = s1.size(0) 414 | t2 = s2.size(0) 415 | batch_size = s1.size(1) 416 | d = s1.size(2) 417 | 418 | expanded_p_s1 = s1.expand(t2, t1, batch_size, d) 419 | 420 | expanded_p_s2 = s2.view(t2, 1, batch_size, d) 421 | expanded_p_s2 = expanded_p_s2.expand(t2, t1, batch_size, d) 422 | 423 | if not cat_only[0] and not 
cat_only[1]: 424 | matrix = torch.cat((expanded_p_s1, expanded_p_s2), dim=3) 425 | elif not cat_only[0] and cat_only[1]: 426 | matrix = torch.cat((expanded_p_s1, expanded_p_s2, expanded_p_s1 * expanded_p_s2), dim=3) 427 | else: 428 | matrix = torch.cat((expanded_p_s1, 429 | expanded_p_s2, 430 | torch.abs(expanded_p_s1 - expanded_p_s2), 431 | expanded_p_s1 * expanded_p_s2), dim=3) 432 | 433 | # matrix = torch.cat((expanded_p_s1, 434 | # expanded_p_s2), dim=3) 435 | 436 | return matrix 437 | 438 | 439 | def max_matching(gram_matrix, l1, l2): 440 | batch_size = gram_matrix.size(2) 441 | dim = gram_matrix.size(3) 442 | in_d = dim // 4 443 | 444 | t1_seq = [] 445 | t2_seq = [] 446 | for b_i in range(batch_size): 447 | b_m = gram_matrix[:l2[b_i], :l1[b_i], b_i, :] 448 | max_t1_a, _ = torch.max(b_m, dim=0) 449 | max_t2_a, _ = torch.max(b_m, dim=1) 450 | 451 | t1_seq.append(max_t1_a.view(l1[b_i], -1)) # [T1, B, 4D] 452 | t2_seq.append(max_t2_a.view(l2[b_i], -1)) # [T2, B, 4D] 453 | 454 | s1_seq = pack_list_sequence(t1_seq, l1) 455 | s2_seq = pack_list_sequence(t2_seq, l2) 456 | filp_l = [s2_seq[:, :, in_d:in_d * 2], s2_seq[:, :, :in_d], s2_seq[:, :, in_d * 2:]] 457 | s2_seq = torch.cat(filp_l, dim=2) 458 | 459 | return s1_seq, s2_seq 460 | 461 | 462 | def max_over_grammatrix(inputs, l1, l2): 463 | """ 464 | :param inputs: [T2 * T1 * B * D] 465 | :param l1: 466 | :param l2: 467 | :return: 468 | """ 469 | batch_size = inputs.size(2) 470 | max_out_list = [] 471 | for b in range(batch_size): 472 | b_gram_matrix = inputs[:l2[b], :l1[b], b, :] 473 | dim = b_gram_matrix.size(-1) 474 | 475 | b_max, _ = torch.max(b_gram_matrix.contiguous().view(-1, dim), dim=0) 476 | 477 | max_out_list.append(b_max) 478 | 479 | max_out = torch.cat(max_out_list, dim=0) 480 | return max_out 481 | 482 | 483 | def comparing_conv(matrices, l1, l2, conv_filter: nn.Linear, k_size, dropout=None, 484 | padding=True, list_in=False): 485 | """ 486 | :param conv_filter: [k * k * input_d] 487 | :param k_size: 488 | :param dropout: 489 | :return: 490 | """ 491 | k = k_size 492 | 493 | if list_in is False: 494 | batch_size = matrices.size(2) 495 | windows = [] 496 | for b in range(batch_size): 497 | b_matrix = matrices[:l2[b], :l1[b], b, :] 498 | 499 | if not padding: 500 | if l2[b] - k + 1 <= 0 or l1[b] - k + 1 <= 0: 501 | raise Exception('Kernel size error k={0}, matrix=({1},{2})'.format(k, l2[b], l1[b])) 502 | 503 | for i in range(l2[b] - k + 1): 504 | for j in range(l1[b] - k + 1): 505 | window = b_matrix[i:i + k, j:j + k, :] 506 | window_d = window.size(-1) 507 | windows.append(window.contiguous().view(k * k * window_d)) 508 | else: 509 | ch_d = b_matrix.size(-1) 510 | padding_n = (k - 1) // 2 511 | row_pad = Variable(torch.zeros(padding_n, l1[b], ch_d)) 512 | 513 | if torch.cuda.is_available(): 514 | row_pad = row_pad.cuda() 515 | # print(b_matrix) 516 | # print(row_pad) 517 | after_row_pad = torch.cat([row_pad, b_matrix, row_pad], dim=0) 518 | col_pad = Variable(torch.zeros(l2[b] + 2 * padding_n, padding_n, ch_d)) 519 | if torch.cuda.is_available(): 520 | col_pad = col_pad.cuda() 521 | after_col_pad = torch.cat([col_pad, after_row_pad, col_pad], dim=1) 522 | 523 | for i in range(padding_n, padding_n + l2[b]): 524 | for j in range(padding_n, padding_n + l1[b]): 525 | i_start = i - padding_n 526 | j_start = j - padding_n 527 | window = after_col_pad[i_start:i_start + k, j_start:j_start + k, :] 528 | windows.append(window.contiguous().view(k * k * ch_d)) 529 | 530 | windows = torch.stack(windows) 531 | else: 532 | batch_size = 
len(matrices)
533 |         windows = []
534 |         for b in range(batch_size):
535 |             b_matrix = matrices[b]
536 |             b_l2 = b_matrix.size(0)
537 |             b_l1 = b_matrix.size(1)
538 | 
539 |             if not padding:
540 |                 if l1 is not None and l2 is not None and (l2[b] != b_l2 or l1[b] != b_l1):
541 |                     raise Exception('Possible input matrices size error!')
542 | 
543 |                 if b_l2 - k + 1 <= 0 or b_l1 - k + 1 <= 0:
544 |                     raise Exception('Kernel size error k={0}, matrix=({1},{2})'.format(k, l2[b], l1[b]))
545 | 
546 |                 for i in range(b_l2 - k + 1):
547 |                     for j in range(b_l1 - k + 1):
548 |                         window = b_matrix[i:i + k, j:j + k, :]
549 |                         window_d = window.size(-1)
550 |                         windows.append(window.contiguous().view(k * k * window_d))
551 |             else:
552 |                 if l1 is not None and l2 is not None and (l2[b] != b_l2 or l1[b] != b_l1):
553 |                     raise Exception('Possible input matrices size error!')
554 | 
555 |                 ch_d = b_matrix.size(-1)
556 |                 padding_n = (k - 1) // 2
557 |                 row_pad = Variable(torch.zeros(padding_n, b_l1, ch_d))
558 |                 if torch.cuda.is_available():
559 |                     row_pad = row_pad.cuda()
560 |                 after_row_pad = torch.cat([row_pad, b_matrix, row_pad], dim=0)
561 |                 col_pad = Variable(torch.zeros(b_l2 + 2 * padding_n, padding_n, ch_d))
562 |                 if torch.cuda.is_available():
563 |                     col_pad = col_pad.cuda()
564 |                 after_col_pad = torch.cat([col_pad, after_row_pad, col_pad], dim=1)
565 | 
566 |                 for i in range(padding_n, padding_n + b_l2):
567 |                     for j in range(padding_n, padding_n + b_l1):
568 |                         i_start = i - padding_n
569 |                         j_start = j - padding_n
570 |                         window = after_col_pad[i_start:i_start + k, j_start:j_start + k, :]
571 |                         windows.append(window.contiguous().view(k * k * ch_d))
572 | 
573 |         windows = torch.stack(windows)
574 | 
575 |     if dropout:
576 |         windows = dropout(windows)  # assign the result; the original discarded it, so dropout was never applied
577 | 
578 |     # print(windows)
579 | 
580 |     out_windows = conv_filter(windows)
581 |     a, b = torch.chunk(out_windows, 2, dim=1)
582 |     out = a * F.sigmoid(b)
583 | 
584 |     out_list = []
585 |     max_out_list = []
586 |     i = 0
587 |     for b in range(batch_size):
588 | 
589 |         if not padding:
590 |             c_l2 = l2[b] - k + 1
591 |             c_l1 = l1[b] - k + 1
592 |         else:
593 |             c_l2 = l2[b]
594 |             c_l1 = l1[b]
595 | 
596 |         b_end = i + c_l2 * c_l1
597 |         b_matrix = out[i:b_end, :]
598 | 
599 |         max_out, _ = b_matrix.max(dim=0)
600 |         max_out_list.append(max_out.squeeze())
601 | 
602 |         dim = b_matrix.size(-1)
603 |         out_list.append(b_matrix.view(c_l2, c_l1, dim))
604 |         i = b_end
605 | 
606 |     max_out = torch.stack(max_out_list)
607 |     # for out in out_list:
608 |     #     max_out = torch.max(out.view(1, -1))
609 | 
610 |     return out_list, max_out
611 | 
612 | 
613 | def max_along_time(inputs, lengths, list_in=False, batch_first=False):
614 |     """
615 |     :param inputs: [T * B * D]
616 |     :param lengths: [B]
617 |     :return: [B * D] max_along_time
618 |     """
619 |     ls = list(lengths)
620 | 
621 |     if not batch_first:
622 |         if not list_in:
623 |             b_seq_max_list = []
624 |             for i, l in enumerate(ls):
625 |                 seq_i = inputs[:l, i, :]
626 |                 seq_i_max, _ = seq_i.max(dim=0)
627 |                 seq_i_max = seq_i_max.squeeze()
628 |                 b_seq_max_list.append(seq_i_max)
629 | 
630 |             return torch.stack(b_seq_max_list)
631 |         else:
632 |             b_seq_max_list = []
633 |             for i, l in enumerate(ls):
634 |                 seq_i = inputs[i]
635 |                 seq_i_max, _ = seq_i.max(dim=0)
636 |                 seq_i_max = seq_i_max.squeeze()
637 |                 b_seq_max_list.append(seq_i_max)
638 | 
639 |             return torch.stack(b_seq_max_list)
640 |     else:
641 |         b_seq_max_list = []
642 |         for i, l in enumerate(ls):
643 |             seq_i = inputs[i, :l, :]
644 |             seq_i_max, _ = seq_i.max(dim=0)
645 |             seq_i_max = seq_i_max.squeeze()
646 |             b_seq_max_list.append(seq_i_max)
647 | 
648 | 
return torch.stack(b_seq_max_list) 649 | 650 | 651 | def topk_along_time(inputs, k, lengths): 652 | """ 653 | :param inputs: [T * B * D] 654 | :param lengths: [B] 655 | :return: [B * D] max_along_time 656 | """ 657 | ls = list(lengths) 658 | d = inputs.size(-1) 659 | pad_z = Variable(inputs.data.new(1, d).zero_()) 660 | 661 | b_seq_max_list = [] 662 | for i, l in enumerate(ls): 663 | seq_i = inputs[:l, i, :] 664 | if l == 1: 665 | seq_i = torch.cat((seq_i, pad_z), dim=0) 666 | seq_i_topk, _ = torch.topk(seq_i, k, dim=0) 667 | b_seq_max_list.append(seq_i_topk.view(1, -1)) 668 | 669 | return torch.cat(b_seq_max_list) 670 | 671 | 672 | def topk_avg_along_time(inputs, k, lengths, list_in=False): 673 | ls = list(lengths) 674 | 675 | b_seq_max_list = [] 676 | for i, l in enumerate(ls): 677 | seq_i = inputs[:l, i, :] if not list_in else inputs[i] 678 | if l == 1: 679 | seq_i_topk_avg, _ = seq_i.max(dim=0) 680 | elif k > l: 681 | seq_i_topk, _ = torch.topk(seq_i, l, dim=0) 682 | seq_i_topk_avg = torch.sum(seq_i_topk, dim=0) / l 683 | else: 684 | seq_i_topk, _ = torch.topk(seq_i, k, dim=0) 685 | seq_i_topk_avg = torch.sum(seq_i_topk, dim=0) / k 686 | b_seq_max_list.append(seq_i_topk_avg) 687 | 688 | return torch.cat(b_seq_max_list) 689 | 690 | 691 | def comparing_conv_m(inputs, l1, l2, conv_layer: nn.Conv2d, mask_2d): 692 | batch_size = inputs.size(0) 693 | unit_d = conv_layer.out_channels // 2 694 | conv_out = conv_layer(inputs) 695 | 696 | a, b = torch.chunk(conv_out, 2, dim=1) 697 | gated_conv_out = a * F.sigmoid(b) * mask_2d[:, :unit_d, :, :] 698 | 699 | max_out_list = [] 700 | for b_i in range(batch_size): 701 | b_conv_out = gated_conv_out[b_i, :, :l2[b_i], :l1[b_i]] 702 | max_out, _ = torch.max(b_conv_out.contiguous().view(unit_d, -1), dim=1) 703 | # print(b_conv_out.size()) 704 | max_out_list.append(max_out.squeeze(1)) 705 | max_out = torch.stack(max_out_list) 706 | 707 | return gated_conv_out, max_out 708 | 709 | 710 | def text_conv1d(inputs, l1, conv_filter: nn.Linear, k_size, dropout=None, list_in=False, 711 | gate_way=True): 712 | """ 713 | :param inputs: [T * B * D] 714 | :param l1: [B] 715 | :param conv_filter: [k * D_in, D_out * 2] 716 | :param k_size: 717 | :param dropout: 718 | :param padding: 719 | :param list_in: 720 | :return: 721 | """ 722 | k = k_size 723 | batch_size = l1.size(0) 724 | d_in = inputs.size(2) if not list_in else inputs[0].size(1) 725 | unit_d = conv_filter.out_features // 2 726 | pad_n = (k - 1) // 2 727 | 728 | zeros_padding = Variable(inputs[0].data.new(pad_n, d_in).zero_()) 729 | 730 | batch_list = [] 731 | input_list = [] 732 | for b_i in range(batch_size): 733 | masked_in = inputs[:l1[b_i], b_i, :] if not list_in else inputs[b_i] 734 | if gate_way: 735 | input_list.append(masked_in) 736 | 737 | b_inputs = torch.cat([zeros_padding, masked_in, zeros_padding], dim=0) 738 | for i in range(l1[b_i]): 739 | # print(b_inputs[i:i+k]) 740 | batch_list.append(b_inputs[i:i+k].view(k * d_in)) 741 | 742 | batch_in = torch.stack(batch_list, dim=0) 743 | a, b = torch.chunk(conv_filter(batch_in), 2, 1) 744 | out = a * F.sigmoid(b) 745 | 746 | out_list = [] 747 | start = 0 748 | for b_i in range(batch_size): 749 | if gate_way: 750 | out_list.append(torch.cat((input_list[b_i], out[start:start + l1[b_i]]), dim=1)) 751 | else: 752 | out_list.append(out[start:start + l1[b_i]]) 753 | 754 | start = start + l1[b_i] 755 | 756 | # max_out_list = [] 757 | # for b_i in range(batch_size): 758 | # max_out, _ = torch.max(out_list[b_i], dim=0) 759 | # max_out_list.append(max_out) 760 | # 
max_out = torch.cat(max_out_list, 0) 761 | # 762 | # print(out_list) 763 | 764 | return out_list 765 | 766 | # Test something --------------------------------------------------------------------------------