├── SNLI_disan ├── src │ ├── __init__.py │ ├── model │ │ ├── __init__.py │ │ ├── .DS_Store │ │ ├── model_disan.py │ │ ├── exp_emb_attn.py │ │ ├── exp_emb_mul_attn.py │ │ ├── exp_bi_lstm_mul_attn.py │ │ ├── exp_emb_self_mul_attn.py │ │ └── exp_emb_dir_mul_attn.py │ ├── utils │ │ ├── __init__.py │ │ ├── tree │ │ │ ├── __init__.py │ │ │ ├── .DS_Store │ │ │ ├── tree2parent.py │ │ │ ├── str_transform.py │ │ │ └── shift_reduce.py │ │ ├── .DS_Store │ │ ├── time_counter.py │ │ ├── record_log.py │ │ ├── file.py │ │ └── nlp.py │ ├── nn_utils │ │ ├── __init__.py │ │ ├── .DS_Store │ │ ├── rnn_cell.py │ │ ├── basic.py │ │ ├── rnn.py │ │ └── general.py │ ├── .DS_Store │ ├── perform_recorder.py │ ├── graph_handler.py │ └── evaluator.py ├── dataset │ ├── glove │ │ ├── .gitkeep │ │ └── .DS_Store │ ├── snli_1.0 │ │ └── .gitkeep │ └── .DS_Store ├── pretrained_model │ └── .gitkeep ├── result │ ├── processed_data │ │ └── .gitkeep │ └── .DS_Store ├── .DS_Store ├── .gitignore ├── README.md ├── configs.py └── snli_main.py ├── SST_disan ├── src │ ├── __init__.py │ ├── model │ │ ├── __init__.py │ │ ├── .DS_Store │ │ ├── model_disan.py │ │ ├── exp_emb_dir_mul_attn.py │ │ └── template.py │ ├── utils │ │ ├── __init__.py │ │ ├── tree │ │ │ ├── __init__.py │ │ │ ├── .DS_Store │ │ │ └── shift_reduce.py │ │ ├── .DS_Store │ │ ├── time_counter.py │ │ ├── record_log.py │ │ ├── file.py │ │ └── nlp.py │ ├── nn_utils │ │ ├── __init__.py │ │ ├── .DS_Store │ │ ├── rnn_cell.py │ │ ├── basic.py │ │ ├── rnn.py │ │ ├── general.py │ │ └── integration.py │ ├── .DS_Store │ ├── perform_recorder.py │ ├── graph_handler.py │ ├── analysis.py │ └── evaluator.py ├── dataset │ ├── glove │ │ ├── .gitkeep │ │ └── .DS_Store │ ├── stanfordSentimentTreebank │ │ └── .gitkeep │ └── .DS_Store ├── pretrained_model │ └── .gitkeep ├── result │ └── processed_data │ │ └── .gitkeep ├── .DS_Store ├── .gitignore ├── sst_log_analysis.py ├── README.md └── configs.py ├── ReSAN └── README.md ├── README.md ├── .gitignore └── Fast-DiSA └── README.md /SNLI_disan/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SST_disan/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SNLI_disan/dataset/glove/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SNLI_disan/src/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SNLI_disan/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SST_disan/dataset/glove/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SST_disan/src/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SST_disan/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /SNLI_disan/dataset/snli_1.0/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SNLI_disan/pretrained_model/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SNLI_disan/src/nn_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SNLI_disan/src/utils/tree/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SST_disan/pretrained_model/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SST_disan/src/nn_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SST_disan/src/utils/tree/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SNLI_disan/result/processed_data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SST_disan/result/processed_data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SST_disan/dataset/stanfordSentimentTreebank/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SNLI_disan/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SNLI_disan/.DS_Store -------------------------------------------------------------------------------- /SST_disan/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SST_disan/.DS_Store -------------------------------------------------------------------------------- /SST_disan/src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SST_disan/src/.DS_Store -------------------------------------------------------------------------------- /SNLI_disan/src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SNLI_disan/src/.DS_Store -------------------------------------------------------------------------------- /SNLI_disan/dataset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SNLI_disan/dataset/.DS_Store -------------------------------------------------------------------------------- /SNLI_disan/result/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SNLI_disan/result/.DS_Store -------------------------------------------------------------------------------- /SST_disan/dataset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SST_disan/dataset/.DS_Store -------------------------------------------------------------------------------- /SNLI_disan/src/model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SNLI_disan/src/model/.DS_Store -------------------------------------------------------------------------------- /SNLI_disan/src/utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SNLI_disan/src/utils/.DS_Store -------------------------------------------------------------------------------- /SST_disan/src/model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SST_disan/src/model/.DS_Store -------------------------------------------------------------------------------- /SST_disan/src/utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SST_disan/src/utils/.DS_Store -------------------------------------------------------------------------------- /SNLI_disan/src/nn_utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SNLI_disan/src/nn_utils/.DS_Store -------------------------------------------------------------------------------- /SST_disan/dataset/glove/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SST_disan/dataset/glove/.DS_Store -------------------------------------------------------------------------------- /SST_disan/src/nn_utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SST_disan/src/nn_utils/.DS_Store -------------------------------------------------------------------------------- /SNLI_disan/dataset/glove/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SNLI_disan/dataset/glove/.DS_Store -------------------------------------------------------------------------------- /SNLI_disan/src/utils/tree/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SNLI_disan/src/utils/tree/.DS_Store -------------------------------------------------------------------------------- /SST_disan/src/utils/tree/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taoshen58/DiSAN/HEAD/SST_disan/src/utils/tree/.DS_Store -------------------------------------------------------------------------------- /ReSAN/README.md: -------------------------------------------------------------------------------- 1 | # Reinforced Self-Attention (ReSA) 2 | 3 | Codes for ReSA are released at 
[here](https://github.com/taoshen58/ReSAN) 4 | 5 | 6 | -------------------------------------------------------------------------------- /SNLI_disan/src/utils/time_counter.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class TimeCounter(object): 5 | def __init__(self): 6 | self.data_round = 0 7 | self.epoch_time_list = [] 8 | self.batch_time_list = [] 9 | 10 | # run time 11 | self.start_time = None 12 | 13 | def add_start(self): 14 | self.start_time = time.time() 15 | 16 | def add_stop(self): 17 | assert self.start_time is not None 18 | self.batch_time_list.append(time.time() - self.start_time) 19 | self.start_time = None 20 | 21 | def update_data_round(self, data_round): 22 | if self.data_round == data_round: 23 | return None, None 24 | else: 25 | this_epoch_time = sum(self.batch_time_list) 26 | self.epoch_time_list.append(this_epoch_time) 27 | self.batch_time_list = [] 28 | self.data_round = data_round 29 | return this_epoch_time, \ 30 | 1.0 * sum(self.epoch_time_list)/len(self.epoch_time_list) if len(self.epoch_time_list) > 0 else 0 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /SST_disan/src/utils/time_counter.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class TimeCounter(object): 5 | def __init__(self): 6 | self.data_round = 0 7 | self.epoch_time_list = [] 8 | self.batch_time_list = [] 9 | 10 | # run time 11 | self.start_time = None 12 | 13 | def add_start(self): 14 | self.start_time = time.time() 15 | 16 | def add_stop(self): 17 | assert self.start_time is not None 18 | self.batch_time_list.append(time.time() - self.start_time) 19 | self.start_time = None 20 | 21 | def update_data_round(self, data_round): 22 | if self.data_round == data_round: 23 | return None, None 24 | else: 25 | this_epoch_time = sum(self.batch_time_list) 26 | self.epoch_time_list.append(this_epoch_time) 27 | self.batch_time_list = [] 28 | self.data_round = data_round 29 | return this_epoch_time, \ 30 | 1.0 * sum(self.epoch_time_list)/len(self.epoch_time_list) if len(self.epoch_time_list) > 0 else 0 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /SST_disan/src/nn_utils/rnn_cell.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.rnn import DropoutWrapper 3 | 4 | 5 | class SwitchableDropoutWrapper(DropoutWrapper): 6 | def __init__(self, cell, is_train, input_keep_prob=1.0, output_keep_prob=1.0, 7 | seed=None): 8 | super(SwitchableDropoutWrapper, self).__init__(cell, 9 | input_keep_prob=input_keep_prob, 10 | output_keep_prob=output_keep_prob, 11 | seed=seed) 12 | self.is_train = is_train 13 | 14 | def __call__(self, inputs, state, scope=None): 15 | outputs_do, new_state_do = super(SwitchableDropoutWrapper, self).__call__(inputs, state, scope=scope) 16 | tf.get_variable_scope().reuse_variables() 17 | outputs, new_state = self._cell(inputs, state, scope) 18 | outputs = tf.cond(self.is_train, lambda: outputs_do, lambda: outputs) 19 | if isinstance(state, tf.contrib.rnn.LSTMStateTuple): 20 | new_state = state.__class__(*[tf.cond(self.is_train, lambda: new_state_do_i, lambda: new_state_i) 21 | for new_state_do_i, new_state_i in zip(new_state_do, new_state)]) 22 | elif isinstance(state, 
tuple): 23 | new_state = state.__class__([tf.cond(self.is_train, lambda: new_state_do_i, lambda: new_state_i) 24 | for new_state_do_i, new_state_i in zip(new_state_do, new_state)]) 25 | else: 26 | new_state = tf.cond(self.is_train, lambda: new_state_do, lambda: new_state) 27 | return outputs, new_state 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /SNLI_disan/src/nn_utils/rnn_cell.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.rnn import DropoutWrapper 3 | 4 | 5 | class SwitchableDropoutWrapper(DropoutWrapper): 6 | def __init__(self, cell, is_train, input_keep_prob=1.0, output_keep_prob=1.0, 7 | seed=None): 8 | super(SwitchableDropoutWrapper, self).__init__(cell, 9 | input_keep_prob=input_keep_prob, 10 | output_keep_prob=output_keep_prob, 11 | seed=seed) 12 | self.is_train = is_train 13 | 14 | def __call__(self, inputs, state, scope=None): 15 | outputs_do, new_state_do = super(SwitchableDropoutWrapper, self).__call__(inputs, state, scope=scope) 16 | tf.get_variable_scope().reuse_variables() 17 | outputs, new_state = self._cell(inputs, state, scope) 18 | outputs = tf.cond(self.is_train, lambda: outputs_do, lambda: outputs) 19 | if isinstance(state, tf.contrib.rnn.LSTMStateTuple): 20 | new_state = state.__class__(*[tf.cond(self.is_train, lambda: new_state_do_i, lambda: new_state_i) 21 | for new_state_do_i, new_state_i in zip(new_state_do, new_state)]) 22 | elif isinstance(state, tuple): 23 | new_state = state.__class__([tf.cond(self.is_train, lambda: new_state_do_i, lambda: new_state_i) 24 | for new_state_do_i, new_state_i in zip(new_state_do, new_state)]) 25 | else: 26 | new_state = tf.cond(self.is_train, lambda: new_state_do, lambda: new_state) 27 | return outputs, new_state 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /SNLI_disan/src/utils/record_log.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | import os 3 | import time 4 | 5 | 6 | class RecordLog(object): 7 | def __init__(self,writeToFileInterval=20, fileName = 'log.txt'): 8 | self.writeToFileInterval = writeToFileInterval 9 | self.waitNumToFile = self.writeToFileInterval 10 | buildTime = '-'.join(time.asctime(time.localtime(time.time())).strip().split(' ')[1:-1]) 11 | buildTime = '-'.join(buildTime.split(':')) 12 | logFileName = buildTime#cfg.model_name[1:] + '_' + buildTime 13 | 14 | 15 | self.path = os.path.join(cfg.log_dir or cfg.standby_log_dir, logFileName+"_"+fileName) 16 | 17 | self.storage = [] 18 | 19 | 20 | def add(self, content = '-'*30, ifTime = False , ifPrint = True , ifSave = True): 21 | #timeStr = " ---" + str(time()) if ifTime else '' 22 | timeStr = " --- "+time.asctime(time.localtime(time.time())) if ifTime else '' 23 | logContent = content + timeStr 24 | if ifPrint: 25 | print(logContent) 26 | #check save 27 | if ifSave: 28 | self.storage.append(logContent) 29 | self.waitNumToFile -= 1 30 | if self.waitNumToFile == 0: 31 | self.waitNumToFile = self.writeToFileInterval 32 | self.writeToFile() 33 | self.storage = [] 34 | 35 | def writeToFile(self): 36 | 37 | with open(self.path, 'a', encoding='utf-8') as file: 38 | for ele in self.storage: 39 | file.write(ele + os.linesep) 40 | 41 | def done(self): 42 | self.add('Done') 43 | 44 | 45 | #def __del__(self): 46 | # with open(self.path, 'a', encoding='utf-8') as file: 47 | # for ele in self.storage: 48 
| # file.write(ele + os.linesep) 49 | 50 | _logger = RecordLog(20) -------------------------------------------------------------------------------- /SST_disan/src/utils/record_log.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | import codecs,os 3 | import time 4 | 5 | 6 | class RecordLog(object): 7 | def __init__(self,writeToFileInterval=20, fileName = 'log.txt'): 8 | self.writeToFileInterval = writeToFileInterval 9 | self.waitNumToFile = self.writeToFileInterval 10 | buildTime = '-'.join(time.asctime(time.localtime(time.time())).strip().split(' ')[1:-1]) 11 | buildTime = '-'.join(buildTime.split(':')) 12 | logFileName = buildTime#cfg.model_name[1:] + '_' + buildTime 13 | 14 | 15 | self.path = os.path.join(cfg.log_dir or cfg.standby_log_dir, logFileName+"_"+fileName) 16 | 17 | self.storage = [] 18 | 19 | 20 | def add(self, content = '-'*30, ifTime = False , ifPrint = True , ifSave = True): 21 | #timeStr = " ---" + str(time()) if ifTime else '' 22 | timeStr = " --- "+time.asctime(time.localtime(time.time())) if ifTime else '' 23 | logContent = content + timeStr 24 | if ifPrint: 25 | print(logContent) 26 | #check save 27 | if ifSave: 28 | self.storage.append(logContent) 29 | self.waitNumToFile -= 1 30 | if self.waitNumToFile == 0: 31 | self.waitNumToFile = self.writeToFileInterval 32 | self.writeToFile() 33 | self.storage = [] 34 | 35 | def writeToFile(self): 36 | 37 | with open(self.path, 'a', encoding='utf-8') as file: 38 | for ele in self.storage: 39 | file.write(ele + os.linesep) 40 | 41 | def done(self): 42 | self.add('Done') 43 | 44 | 45 | #def __del__(self): 46 | # with open(self.path, 'a', encoding='utf-8') as file: 47 | # for ele in self.storage: 48 | # file.write(ele + os.linesep) 49 | 50 | _logger = RecordLog(20) -------------------------------------------------------------------------------- /SNLI_disan/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # data file 10 | *.pickle 11 | *.json 12 | *.jsonl 13 | *.txt 14 | !.gitkeep 15 | 16 | 17 | # Distribution / packaging 18 | .Python 19 | env/ 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # dotenv 91 | .env 92 | 93 | # virtualenv 94 | .venv 95 | venv/ 96 | ENV/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .idea/ 111 | result/dict/ 112 | result/log/ 113 | result/model/ 114 | -------------------------------------------------------------------------------- /SST_disan/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # data file 10 | *.pickle 11 | *.json 12 | *.jsonl 13 | *.txt 14 | !.gitkeep 15 | 16 | 17 | # Distribution / packaging 18 | .Python 19 | env/ 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # dotenv 91 | .env 92 | 93 | # virtualenv 94 | .venv 95 | venv/ 96 | ENV/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .idea/ 111 | result/dict/ 112 | result/log/ 113 | result/model/ 114 | -------------------------------------------------------------------------------- /SST_disan/sst_log_analysis.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def do_analyse_sst(file_path, dev=True, delta=0, stop=None): 4 | results = [] 5 | with open(file_path, 'r', encoding='utf-8') as file: 6 | find_entry = False 7 | output = [0, 0., 0., 0., 0., 0., 0.] 
# xx, dev, test, 8 | for line in file: 9 | if not find_entry: 10 | if line.startswith('data round'): # get step 11 | output[0] = int(line.split(' ')[-4].split(':')[-1]) 12 | if stop is not None and output[0] > stop: break 13 | if line.startswith('==> for dev'): # dev 14 | output[1] = float(line.split(' ')[-1]) 15 | output[2] = float(line.split(' ')[-4][:-1]) 16 | output[3] = float(line.split(' ')[-6][:-1]) 17 | find_entry = True 18 | else: 19 | if line.startswith('~~> for test'): # test 20 | output[4] = float(line.split(' ')[-1]) 21 | output[5] = float(line.split(' ')[-4][:-1]) 22 | output[6] = float(line.split(' ')[-6][:-1]) 23 | results.append(output) 24 | find_entry = False 25 | output = [0, 0., 0., 0., 0., 0., 0.] 26 | 27 | # max step 28 | if len(results) > 0: 29 | print('max step:', results[-1][0]) 30 | 31 | # sort 32 | sort = 1 if dev else 4 33 | sort += delta 34 | 35 | output = list(sorted(results, key=lambda elem: elem[sort], reverse=(not delta == 2))) 36 | 37 | for elem in output[:20]: 38 | print('step: %d, dev_sent: %.4f, dev: %.4f, dev_loss: %.4f, ' 39 | 'test_sent: %.4f, test: %.4f, test_loss: %.4f' % 40 | (elem[0], elem[1], elem[2], elem[3],elem[4], elem[5],elem[6])) 41 | 42 | 43 | if __name__ == '__main__': 44 | 45 | file_path = '/Users/tshen/Desktop/tmp/file_transfer/sst/Jul-22-17-56-41_log.txt' 46 | dev = True 47 | delta = 0 48 | 49 | do_analyse_sst(file_path, dev, delta, None) 50 | 51 | 52 | -------------------------------------------------------------------------------- /SNLI_disan/src/perform_recorder.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | import tensorflow as tf 3 | import os 4 | 5 | 6 | class PerformRecoder(object): 7 | def __init__(self, top_limit_num=3): 8 | self.top_limit_num = top_limit_num 9 | self.dev_top_list = [] # list of tuple(step, dev_accu) 10 | self.saver = tf.train.Saver(max_to_keep=None) 11 | 12 | def update_top_list(self, global_step, dev_accu, sess): 13 | cur_ckpt_path = self.ckpt_file_path_generator(global_step) 14 | self.dev_top_list.append([global_step, dev_accu]) 15 | self.dev_top_list = list(sorted(self.dev_top_list, key=lambda elem: elem[1], reverse=True)) 16 | 17 | if len(self.dev_top_list) <= self.top_limit_num: 18 | self.create_ckpt_file(sess, cur_ckpt_path) 19 | return True, None 20 | elif len(self.dev_top_list) == self.top_limit_num + 1: 21 | out_state = self.dev_top_list[-1] 22 | self.dev_top_list = self.dev_top_list[:-1] 23 | if out_state[0] == global_step: 24 | return False, None 25 | else: # add and delete 26 | self.delete_ckpt_file(self.ckpt_file_path_generator(out_state[0])) 27 | self.create_ckpt_file(sess, cur_ckpt_path) 28 | return True, out_state[0] 29 | 30 | else: 31 | raise RuntimeError() 32 | 33 | def ckpt_file_path_generator(self, step): 34 | return os.path.join(cfg.ckpt_dir, 'top_result_saver_step_%d.ckpt' % step) 35 | 36 | def delete_ckpt_file(self, ckpt_file_path): 37 | if os.path.isfile(ckpt_file_path+'.meta'): 38 | os.remove(ckpt_file_path+'.meta') 39 | if os.path.isfile(ckpt_file_path+'.index'): 40 | os.remove(ckpt_file_path+'.index') 41 | if os.path.isfile(ckpt_file_path+'.data-00000-of-00001'): 42 | os.remove(ckpt_file_path+'.data-00000-of-00001') 43 | 44 | def create_ckpt_file(self, sess, ckpt_file_path): 45 | self.saver.save(sess, ckpt_file_path) 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /SST_disan/src/perform_recorder.py: 
-------------------------------------------------------------------------------- 1 | from configs import cfg 2 | import tensorflow as tf 3 | import os 4 | 5 | 6 | class PerformRecoder(object): 7 | def __init__(self, top_limit_num=3): 8 | self.top_limit_num = top_limit_num 9 | self.dev_top_list = [] # list of tuple(step, dev_accu) 10 | self.saver = tf.train.Saver(max_to_keep=None) 11 | 12 | def update_top_list(self, global_step, dev_accu, sess): 13 | cur_ckpt_path = self.ckpt_file_path_generator(global_step) 14 | self.dev_top_list.append([global_step, dev_accu]) 15 | self.dev_top_list = list(sorted(self.dev_top_list, key=lambda elem: elem[1], reverse=True)) 16 | 17 | if len(self.dev_top_list) <= self.top_limit_num: 18 | self.create_ckpt_file(sess, cur_ckpt_path) 19 | return True, None 20 | elif len(self.dev_top_list) == self.top_limit_num + 1: 21 | out_state = self.dev_top_list[-1] 22 | self.dev_top_list = self.dev_top_list[:-1] 23 | if out_state[0] == global_step: 24 | return False, None 25 | else: # add and delete 26 | self.delete_ckpt_file(self.ckpt_file_path_generator(out_state[0])) 27 | self.create_ckpt_file(sess, cur_ckpt_path) 28 | return True, out_state[0] 29 | 30 | else: 31 | raise RuntimeError() 32 | 33 | def ckpt_file_path_generator(self, step): 34 | return os.path.join(cfg.ckpt_dir, 'top_result_saver_step_%d.ckpt' % step) 35 | 36 | def delete_ckpt_file(self, ckpt_file_path): 37 | if os.path.isfile(ckpt_file_path+'.meta'): 38 | os.remove(ckpt_file_path+'.meta') 39 | if os.path.isfile(ckpt_file_path+'.index'): 40 | os.remove(ckpt_file_path+'.index') 41 | if os.path.isfile(ckpt_file_path+'.data-00000-of-00001'): 42 | os.remove(ckpt_file_path+'.data-00000-of-00001') 43 | 44 | def create_ckpt_file(self, sess, ckpt_file_path): 45 | self.saver.save(sess, ckpt_file_path) 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /SST_disan/src/model/model_disan.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import tensorflow as tf 4 | from src.model.template import ModelTemplate 5 | 6 | from src.nn_utils.integration_func import generate_embedding_mat 7 | from src.nn_utils.nn import linear 8 | from src.nn_utils.disan import disan 9 | 10 | 11 | class ModelDiSAN(ModelTemplate): 12 | def __init__(self, token_emb_mat, glove_emb_mat, tds, cds, tl, scope): 13 | super(ModelDiSAN, self).__init__(token_emb_mat, glove_emb_mat, tds, cds, tl, scope) 14 | self.update_tensor_add_ema_and_opt() 15 | 16 | def build_network(self): 17 | _logger.add() 18 | _logger.add('building %s neural network structure...' 
% cfg.network_type) 19 | tds, cds = self.tds, self.cds 20 | tl = self.tl 21 | tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh 22 | hn = self.hn 23 | bs, sl, ol, mc = self.bs, self.sl, self.ol, self.mc 24 | 25 | with tf.variable_scope('emb'): 26 | token_emb_mat = generate_embedding_mat(tds, tel, init_mat=self.token_emb_mat, 27 | extra_mat=self.glove_emb_mat, extra_trainable=self.finetune_emb, 28 | scope='gene_token_emb_mat') 29 | emb = tf.nn.embedding_lookup(token_emb_mat, self.token_seq) # bs,sl,tel 30 | self.tensor_dict['emb'] = emb 31 | 32 | rep = disan( 33 | emb, self.token_mask, 'DiSAN', cfg.dropout, 34 | self.is_train, cfg.wd, 'relu', tensor_dict=self.tensor_dict, name='') 35 | 36 | with tf.variable_scope('output'): 37 | pre_logits = tf.nn.relu(linear([rep], hn, True, scope='pre_logits_linear', 38 | wd=cfg.wd, input_keep_prob=cfg.dropout, 39 | is_train=self.is_train)) # bs, hn 40 | logits = linear([pre_logits], self.output_class, False, scope='get_output', 41 | wd=cfg.wd, input_keep_prob=cfg.dropout, is_train=self.is_train) # bs, 5 42 | _logger.done() 43 | return logits 44 | 45 | 46 | -------------------------------------------------------------------------------- /SST_disan/src/graph_handler.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import tensorflow as tf 4 | 5 | 6 | class GraphHandler(object): 7 | def __init__(self, model): 8 | self.model = model 9 | self.saver = tf.train.Saver(max_to_keep=3) 10 | self.writer = None 11 | 12 | def initialize(self, sess): 13 | sess.run(tf.global_variables_initializer()) 14 | if cfg.load_model or cfg.mode != 'train': 15 | self.restore(sess) 16 | if cfg.mode == 'train': 17 | self.writer = tf.summary.FileWriter(logdir=cfg.summary_dir, graph=tf.get_default_graph()) 18 | 19 | def add_summary(self, summary, global_step): 20 | _logger.add() 21 | _logger.add('saving summary...') 22 | self.writer.add_summary(summary, global_step) 23 | _logger.done() 24 | 25 | def add_summaries(self, summaries, global_step): 26 | for summary in summaries: 27 | self.add_summary(summary, global_step) 28 | 29 | def save(self, sess, global_step = None): 30 | _logger.add() 31 | _logger.add('saving model to %s'% cfg.ckpt_path) 32 | self.saver.save(sess, cfg.ckpt_path, global_step) 33 | _logger.done() 34 | 35 | def restore(self,sess): 36 | _logger.add() 37 | # print(cfg.ckpt_dir) 38 | 39 | if cfg.load_step is None: 40 | if cfg.load_path is None: 41 | _logger.add('trying to restore from dir %s' % cfg.ckpt_dir) 42 | latest_checkpoint_path = tf.train.latest_checkpoint(cfg.ckpt_dir) 43 | else: 44 | latest_checkpoint_path = cfg.load_path 45 | else: 46 | latest_checkpoint_path = cfg.ckpt_path+'-'+str(cfg.load_step) 47 | 48 | if latest_checkpoint_path is not None: 49 | _logger.add('trying to restore from ckpt file %s' % latest_checkpoint_path) 50 | try: 51 | self.saver.restore(sess, latest_checkpoint_path) 52 | _logger.add('successfully restored') 53 | except tf.errors.NotFoundError: 54 | _logger.add('failed to restore') 55 | if cfg.mode != 'train': raise FileNotFoundError('cannot find model file') 56 | else: 57 | _logger.add('No checkpoint file in dir %s '% cfg.ckpt_dir) 58 | if cfg.mode != 'train': raise FileNotFoundError('cannot find model file') 59 | 60 | _logger.done() 61 | 62 | 63 | -------------------------------------------------------------------------------- /SNLI_disan/src/graph_handler.py: 
-------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import tensorflow as tf 4 | 5 | 6 | class GraphHandler(object): 7 | def __init__(self, model): 8 | self.model = model 9 | self.saver = tf.train.Saver(max_to_keep=3) 10 | self.writer = None 11 | 12 | def initialize(self, sess): 13 | sess.run(tf.global_variables_initializer()) 14 | if cfg.load_model or cfg.mode != 'train': 15 | self.restore(sess) 16 | if cfg.mode == 'train': 17 | self.writer = tf.summary.FileWriter(logdir=cfg.summary_dir, graph=tf.get_default_graph()) 18 | 19 | def add_summary(self, summary, global_step): 20 | _logger.add() 21 | _logger.add('saving summary...') 22 | self.writer.add_summary(summary, global_step) 23 | _logger.done() 24 | 25 | def add_summaries(self, summaries, global_step): 26 | for summary in summaries: 27 | self.add_summary(summary, global_step) 28 | 29 | def save(self, sess, global_step = None): 30 | _logger.add() 31 | _logger.add('saving model to %s'% cfg.ckpt_path) 32 | self.saver.save(sess, cfg.ckpt_path, global_step) 33 | _logger.done() 34 | 35 | def restore(self,sess): 36 | _logger.add() 37 | # print(cfg.ckpt_dir) 38 | 39 | if cfg.load_step is None: 40 | if cfg.load_path is None: 41 | _logger.add('trying to restore from dir %s' % cfg.ckpt_dir) 42 | latest_checkpoint_path = tf.train.latest_checkpoint(cfg.ckpt_dir) 43 | else: 44 | latest_checkpoint_path = cfg.load_path 45 | else: 46 | latest_checkpoint_path = cfg.ckpt_path+'-'+str(cfg.load_step) 47 | 48 | if latest_checkpoint_path is not None: 49 | _logger.add('trying to restore from ckpt file %s' % latest_checkpoint_path) 50 | try: 51 | self.saver.restore(sess, latest_checkpoint_path) 52 | _logger.add('successfully restored') 53 | except tf.errors.NotFoundError: 54 | _logger.add('failed to restore') 55 | if cfg.mode != 'train': raise FileNotFoundError('cannot find model file') 56 | else: 57 | _logger.add('No checkpoint file in dir %s '% cfg.ckpt_dir) 58 | if cfg.mode != 'train': raise FileNotFoundError('cannot find model file') 59 | 60 | _logger.done() 61 | 62 | 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Directional Self-Attention Network 2 | * This repo contains the code for [DiSAN: Directional Self-Attention Network for RNN/CNN-free Language Understanding](https://arxiv.org/abs/1709.04696). 3 | * The code is implemented in Python under the TensorFlow 1.2 deep learning framework. 4 | * The leaderboard of Stanford Natural Language Inference is available [here](https://nlp.stanford.edu/projects/snli/). 5 | * Please contact [Tao Shen](mailto:Tao.Shen@student.uts.edu.au) or open an issue for questions/suggestions. 6 | 7 | 8 | **Cite this paper using BibTeX:** 9 | 10 | @inproceedings{shen2018disan, 11 | Author = {Shen, Tao and Zhou, Tianyi and Long, Guodong and Jiang, Jing and Pan, Shirui and Zhang, Chengqi}, 12 | Booktitle = {AAAI Conference on Artificial Intelligence}, 13 | Title = {DISAN: Directional self-attention network for rnn/cnn-free language understanding}, 14 | Year = {2018} 15 | } 16 | 17 | 18 | ## Overall Requirements 19 | * Python3 (verified on 3.5.2, or Anaconda3 4.2.0) 20 | * tensorflow>=1.2 21 | 22 | #### Python Packages: 23 | 24 | * numpy 25 | 26 | ------- 27 | ### This repo includes three parts as follows: 28 | 1. Directional Self-Attention Network standalone file -> file disan.py 29 | 2. DiSAN implementation for Stanford Natural Language Inference -> dir SNLI_disan 30 | 3. DiSAN implementation for Stanford Sentiment Classification -> dir SST_disan 31 | 32 | __The usage of *disan.py* is introduced below; for the SNLI and SST implementations, please enter the corresponding folder for further instructions.__ 33 | 34 | __Code for the other experiments in the paper (e.g. SICK, MPQA, CR) is under preparation.__ 35 | 36 | ------- 37 | ## Usage of disan.py 38 | 39 | ### Parameters: 40 | 41 | * param **rep\_tensor**: 3D tensorflow dense float tensor [batch\_size, max\_len, dim] 42 | * param **rep\_mask**: 2D tensorflow bool tensor as mask for rep\_tensor, [batch\_size, max\_len] 43 | * param **scope**: tensorflow variable scope 44 | * param **keep\_prob**: float, dropout keep probability 45 | * param **is\_train**: tensorflow bool scalar 46 | * param **wd**: if wd>0, add the related tensors to the tf collection "reg_vars" for later L2 weight decay 47 | * param **activation**: disan activation function [elu|relu|selu] 48 | * param **tensor\_dict**: a dict to record disan internal attention results (optional, for analysis) 49 | * param **name**: record name suffix (optional) 50 | 51 | ### Output: 52 | 2D tensorflow dense float tensor with shape [batch\_size, dim], i.e. the encoding result for each sentence. 53 | 
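### Example (sketch):

A minimal call sketch is shown below. The placeholder names and toy sizes are illustrative only, and it assumes *disan.py* has been placed on the Python path; the argument order mirrors the calls in `SNLI_disan/src/model/model_disan.py` and `SST_disan/src/model/model_disan.py`.

```python
import tensorflow as tf

from disan import disan  # assumes disan.py is on the Python path

# Toy sizes for illustration only.
max_len, vocab_size, emb_dim = 40, 10000, 300

# Token ids padded to max_len, plus each sentence's true length.
token_ids = tf.placeholder(tf.int32, [None, max_len], name='token_ids')
lengths = tf.placeholder(tf.int32, [None], name='lengths')
is_train = tf.placeholder(tf.bool, [], name='is_train')

# rep_tensor: [batch_size, max_len, dim] dense float tensor (here, word embeddings).
emb_mat = tf.get_variable('emb_mat', [vocab_size, emb_dim], tf.float32)
rep_tensor = tf.nn.embedding_lookup(emb_mat, token_ids)

# rep_mask: [batch_size, max_len] bool tensor marking the non-padding positions.
rep_mask = tf.sequence_mask(lengths, max_len)

# Sentence encoding of shape [batch_size, dim] as described under Output above.
sent_rep = disan(rep_tensor, rep_mask, 'DiSAN', 0.75, is_train, 0., 'elu',
                 tensor_dict={}, name='')
```

The returned encoding can then be fed into a classifier, as done in the `output` variable scopes of the SNLI and SST models.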
get leaf_node_index_seq for all_nodes 34 | def recursive_gene_leaf_indices(tree): 35 | if len(tree.children_nodes) > 0: # non-leaf 36 | tree.leaf_node_index_seq = [] 37 | for child_node in tree.children_nodes: 38 | tree.leaf_node_index_seq += recursive_gene_leaf_indices(child_node) 39 | return tree.leaf_node_index_seq 40 | else: 41 | tree.leaf_node_index_seq = [tree.node_index] 42 | return tree.leaf_node_index_seq 43 | recursive_gene_leaf_indices(tree_structure) 44 | 45 | # 3. get all node as list 46 | def recursive_get_all_nodes(tree): 47 | if len(tree.children_nodes) > 0: # non-leaf 48 | nodes = [tree] 49 | for child_node in tree.children_nodes: 50 | nodes += recursive_get_all_nodes(child_node) 51 | return nodes 52 | else: 53 | return [tree] 54 | all_nodes = recursive_get_all_nodes(tree_structure) 55 | 56 | # 4. sort for all nodes 57 | all_nodes_sorted = list(sorted(all_nodes, key=lambda node: node.node_index)) 58 | 59 | 60 | 61 | return tree_structure, all_nodes_sorted 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /SNLI_disan/src/model/model_disan.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import tensorflow as tf 4 | 5 | from src.model.model_template import ModelTemplate 6 | from src.nn_utils.nn import linear 7 | from src.nn_utils.integration_func import generate_embedding_mat 8 | from src.nn_utils.disan import disan 9 | 10 | 11 | class ModelDiSAN(ModelTemplate): 12 | def __init__(self, token_emb_mat, glove_emb_mat, tds, cds, tl, scope): 13 | super(ModelDiSAN, self).__init__(token_emb_mat, glove_emb_mat, tds, cds, tl, scope) 14 | self.update_tensor_add_ema_and_opt() 15 | 16 | def build_network(self): 17 | _logger.add() 18 | _logger.add('building %s neural network structure...' 
% cfg.network_type) 19 | tds, cds = self.tds, self.cds 20 | tl = self.tl 21 | tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh 22 | hn = self.hn 23 | bs, sl1, sl2 = self.bs, self.sl1, self.sl2 24 | 25 | with tf.variable_scope('emb'): 26 | token_emb_mat = generate_embedding_mat(tds, tel, init_mat=self.token_emb_mat, 27 | extra_mat=self.glove_emb_mat, extra_trainable=self.finetune_emb, 28 | scope='gene_token_emb_mat') 29 | s1_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent1_token) # bs,sl1,tel 30 | s2_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent2_token) # bs,sl2,tel 31 | self.tensor_dict['s1_emb'] = s1_emb 32 | self.tensor_dict['s2_emb'] = s2_emb 33 | 34 | with tf.variable_scope('sent_enc'): 35 | s1_rep = disan( 36 | s1_emb, self.sent1_token_mask, 'DiSAN', cfg.dropout, self.is_train, cfg.wd, 37 | 'elu', self.tensor_dict, 's1' 38 | ) 39 | self.tensor_dict['s1_rep'] = s1_rep 40 | 41 | tf.get_variable_scope().reuse_variables() 42 | 43 | s2_rep = disan( 44 | s2_emb, self.sent2_token_mask, 'DiSAN', cfg.dropout, self.is_train, cfg.wd, 45 | 'elu', self.tensor_dict, 's2' 46 | ) 47 | self.tensor_dict['s2_rep'] = s2_rep 48 | 49 | 50 | with tf.variable_scope('output'): 51 | out_rep = tf.concat([s1_rep, s2_rep, s1_rep - s2_rep, s1_rep * s2_rep], -1) 52 | pre_output = tf.nn.elu(linear([out_rep], hn, True, 0., scope= 'pre_output', squeeze=False, 53 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train)) 54 | logits = linear([pre_output], self.output_class, True, 0., scope= 'logits', squeeze=False, 55 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train) 56 | self.tensor_dict[logits] = logits 57 | return logits # logits 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /SNLI_disan/src/model/exp_emb_attn.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import tensorflow as tf 4 | 5 | from src.model.model_template import ModelTemplate 6 | from src.nn_utils.nn import linear 7 | from src.nn_utils.integration_func import generate_embedding_mat,\ 8 | traditional_attention 9 | 10 | 11 | class ModelExpEmbAttn(ModelTemplate): 12 | def __init__(self, token_emb_mat, glove_emb_mat, tds, cds, tl, scope): 13 | super(ModelExpEmbAttn, self).__init__(token_emb_mat, glove_emb_mat, tds, cds, tl, scope) 14 | self.update_tensor_add_ema_and_opt() 15 | 16 | def build_network(self): 17 | _logger.add() 18 | _logger.add('building %s neural network structure...' 
% cfg.network_type) 19 | tds, cds = self.tds, self.cds 20 | tl = self.tl 21 | tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh 22 | hn = self.hn 23 | bs, sl1, sl2 = self.bs, self.sl1, self.sl2 24 | 25 | with tf.variable_scope('emb'): 26 | token_emb_mat = generate_embedding_mat(tds, tel, init_mat=self.token_emb_mat, 27 | extra_mat=self.glove_emb_mat, extra_trainable=self.finetune_emb, 28 | scope='gene_token_emb_mat') 29 | s1_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent1_token) # bs,sl1,tel 30 | s2_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent2_token) # bs,sl2,tel 31 | self.tensor_dict['s1_emb'] = s1_emb 32 | self.tensor_dict['s2_emb'] = s2_emb 33 | 34 | with tf.variable_scope('sent_enc_attn'): 35 | s1_rep = traditional_attention( 36 | s1_emb, self.sent1_token_mask, 'traditional_attention', 37 | cfg.dropout, self.is_train, cfg.wd, 38 | tensor_dict=self.tensor_dict, name='s1_attn') 39 | tf.get_variable_scope().reuse_variables() 40 | s2_rep = traditional_attention( 41 | s2_emb, self.sent2_token_mask, 'traditional_attention', 42 | cfg.dropout, self.is_train, cfg.wd, 43 | tensor_dict=self.tensor_dict, name='s2_attn') 44 | 45 | self.tensor_dict['s1_rep'] = s1_rep 46 | self.tensor_dict['s2_rep'] = s2_rep 47 | 48 | with tf.variable_scope('output'): 49 | out_rep = tf.concat([s1_rep, s2_rep, s1_rep - s2_rep, s1_rep * s2_rep], -1) 50 | pre_output = tf.nn.elu(linear([out_rep], hn, True, 0., scope= 'pre_output', squeeze=False, 51 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train)) 52 | logits = linear([pre_output], self.output_class, True, 0., scope= 'logits', squeeze=False, 53 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train) 54 | self.tensor_dict[logits] = logits 55 | return logits # logits 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /SNLI_disan/src/model/exp_emb_mul_attn.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import tensorflow as tf 4 | 5 | from src.model.model_template import ModelTemplate 6 | from src.nn_utils.nn import linear 7 | from src.nn_utils.integration_func import multi_dimensional_attention, generate_embedding_mat 8 | 9 | class ModelExpEmbMulAttn(ModelTemplate): 10 | def __init__(self, token_emb_mat, glove_emb_mat, tds, cds, tl, scope): 11 | super(ModelExpEmbMulAttn, self).__init__(token_emb_mat, glove_emb_mat, tds, cds, tl, scope) 12 | self.update_tensor_add_ema_and_opt() 13 | 14 | def build_network(self): 15 | _logger.add() 16 | _logger.add('building %s neural network structure...' 
% cfg.network_type) 17 | tds, cds = self.tds, self.cds 18 | tl = self.tl 19 | tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh 20 | hn = self.hn 21 | bs, sl1, sl2 = self.bs, self.sl1, self.sl2 22 | 23 | with tf.variable_scope('emb'): 24 | token_emb_mat = generate_embedding_mat(tds, tel, init_mat=self.token_emb_mat, 25 | extra_mat=self.glove_emb_mat, extra_trainable=self.finetune_emb, 26 | scope='gene_token_emb_mat') 27 | s1_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent1_token) # bs,sl1,tel 28 | s2_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent2_token) # bs,sl2,tel 29 | self.tensor_dict['s1_emb'] = s1_emb 30 | self.tensor_dict['s2_emb'] = s2_emb 31 | 32 | with tf.variable_scope('sent_enc_attn'): 33 | s1_rep = multi_dimensional_attention( 34 | s1_emb, self.sent1_token_mask, 'multi_dimensional_attention', 35 | cfg.dropout, self.is_train, cfg.wd, 36 | tensor_dict=self.tensor_dict, name='s1_attn') 37 | tf.get_variable_scope().reuse_variables() 38 | s2_rep = multi_dimensional_attention( 39 | s2_emb, self.sent2_token_mask, 'multi_dimensional_attention', 40 | cfg.dropout, self.is_train, cfg.wd, 41 | tensor_dict=self.tensor_dict, name='s2_attn') 42 | 43 | self.tensor_dict['s1_rep'] = s1_rep 44 | self.tensor_dict['s2_rep'] = s2_rep 45 | 46 | with tf.variable_scope('output'): 47 | out_rep = tf.concat([s1_rep, s2_rep, s1_rep - s2_rep, s1_rep * s2_rep], -1) 48 | pre_output = tf.nn.elu(linear([out_rep], hn, True, 0., scope= 'pre_output', squeeze=False, 49 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train)) 50 | logits = linear([pre_output], self.output_class, True, 0., scope= 'logits', squeeze=False, 51 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train) 52 | self.tensor_dict[logits] = logits 53 | return logits # logits 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /SST_disan/src/model/exp_emb_dir_mul_attn.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import tensorflow as tf 4 | from src.model.template import ModelTemplate 5 | 6 | from src.nn_utils.integration_func import multi_dimensional_attention, generate_embedding_mat,\ 7 | directional_attention_with_dense 8 | from src.nn_utils.nn import linear 9 | 10 | 11 | class ModelExpEmbDirMulAttn(ModelTemplate): 12 | def __init__(self, token_emb_mat, glove_emb_mat, tds, cds, tl, scope): 13 | super(ModelExpEmbDirMulAttn, self).__init__(token_emb_mat, glove_emb_mat, tds, cds, tl, scope) 14 | self.update_tensor_add_ema_and_opt() 15 | 16 | def build_network(self): 17 | _logger.add() 18 | _logger.add('building %s neural network structure...' 
% cfg.network_type) 19 | tds, cds = self.tds, self.cds 20 | tl = self.tl 21 | tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh 22 | hn = self.hn 23 | bs, sl, ol, mc = self.bs, self.sl, self.ol, self.mc 24 | 25 | with tf.variable_scope('emb'): 26 | token_emb_mat = generate_embedding_mat(tds, tel, init_mat=self.token_emb_mat, 27 | extra_mat=self.glove_emb_mat, extra_trainable=self.finetune_emb, 28 | scope='gene_token_emb_mat') 29 | emb = tf.nn.embedding_lookup(token_emb_mat, self.token_seq) # bs,sl,tel 30 | self.tensor_dict['emb'] = emb 31 | 32 | with tf.variable_scope('ct_attn'): 33 | rep_fw = directional_attention_with_dense( 34 | emb, self.token_mask, 'forward', 'dir_attn_fw', 35 | cfg.dropout, self.is_train, cfg.wd, 'relu', 36 | tensor_dict=self.tensor_dict, name='fw_attn') 37 | rep_bw = directional_attention_with_dense( 38 | emb, self.token_mask, 'backward', 'dir_attn_bw', 39 | cfg.dropout, self.is_train, cfg.wd, 'relu', 40 | tensor_dict=self.tensor_dict, name='bw_attn') 41 | 42 | seq_rep = tf.concat([rep_fw, rep_bw], -1) 43 | 44 | with tf.variable_scope('sent_enc_attn'): 45 | rep = multi_dimensional_attention( 46 | seq_rep, self.token_mask, 'multi_dimensional_attention', 47 | cfg.dropout, self.is_train, cfg.wd, 'relu', 48 | tensor_dict=self.tensor_dict, name='attn') 49 | 50 | with tf.variable_scope('output'): 51 | pre_logits = tf.nn.relu(linear([rep], hn, True, scope='pre_logits_linear', 52 | wd=cfg.wd, input_keep_prob=cfg.dropout, 53 | is_train=self.is_train)) # bs, hn 54 | logits = linear([pre_logits], self.output_class, False, scope='get_output', 55 | wd=cfg.wd, input_keep_prob=cfg.dropout, is_train=self.is_train) # bs, 5 56 | _logger.done() 57 | return logits 58 | 59 | 60 | -------------------------------------------------------------------------------- /SNLI_disan/src/nn_utils/basic.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Tensorflow Implementation of the Scaled ELU function and Dropout 4 | ''' 5 | 6 | from __future__ import absolute_import, division, print_function 7 | import numbers 8 | from tensorflow.contrib import layers 9 | from tensorflow.python.framework import ops 10 | from tensorflow.python.framework import tensor_shape 11 | from tensorflow.python.framework import tensor_util 12 | from tensorflow.python.ops import math_ops 13 | from tensorflow.python.ops import random_ops 14 | from tensorflow.python.ops import array_ops 15 | from tensorflow.python.layers import utils 16 | import tensorflow as tf 17 | 18 | # (1) scale inputs to zero mean and unit variance 19 | 20 | 21 | # (2) use SELUs 22 | def selu(x): 23 | with ops.name_scope('elu') as scope: 24 | alpha = 1.6732632423543772848170429916717 25 | scale = 1.0507009873554804934193349852946 26 | return scale*tf.where(x>=0.0, x, alpha*tf.nn.elu(x)) 27 | 28 | 29 | # (3) initialize weights with stddev sqrt(1/n) 30 | # e.g. 
use: 31 | initializer = layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN') 32 | 33 | 34 | # (4) use this dropout 35 | def dropout_selu(x, rate, alpha= -1.7580993408473766, fixedPointMean=0.0, fixedPointVar=1.0, 36 | noise_shape=None, seed=None, name=None, training=False): 37 | """Dropout to a value with rescaling.""" 38 | 39 | def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name): 40 | keep_prob = 1.0 - rate 41 | x = ops.convert_to_tensor(x, name="x") 42 | if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1: 43 | raise ValueError("keep_prob must be a scalar tensor or a float in the " 44 | "range (0, 1], got %g" % keep_prob) 45 | keep_prob = ops.convert_to_tensor(keep_prob, dtype=x.dtype, name="keep_prob") 46 | keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) 47 | 48 | alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha") 49 | alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar()) 50 | 51 | if tensor_util.constant_value(keep_prob) == 1: 52 | return x 53 | 54 | noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x) 55 | random_tensor = keep_prob 56 | random_tensor += random_ops.random_uniform(noise_shape, seed=seed, dtype=x.dtype) 57 | binary_tensor = math_ops.floor(random_tensor) 58 | ret = x * binary_tensor + alpha * (1-binary_tensor) 59 | 60 | a = math_ops.sqrt(fixedPointVar / (keep_prob *((1-keep_prob) * math_ops.pow(alpha-fixedPointMean,2) + fixedPointVar))) 61 | 62 | b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha) 63 | ret = a * ret + b 64 | ret.set_shape(x.get_shape()) 65 | return ret 66 | 67 | with ops.name_scope(name, "dropout", [x]) as name: 68 | return utils.smart_cond(training, 69 | lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name), 70 | lambda: array_ops.identity(x)) -------------------------------------------------------------------------------- /SST_disan/src/nn_utils/basic.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Tensorflow Implementation of the Scaled ELU function and Dropout 4 | ''' 5 | 6 | from __future__ import absolute_import, division, print_function 7 | import numbers 8 | from tensorflow.contrib import layers 9 | from tensorflow.python.framework import ops 10 | from tensorflow.python.framework import tensor_shape 11 | from tensorflow.python.framework import tensor_util 12 | from tensorflow.python.ops import math_ops 13 | from tensorflow.python.ops import random_ops 14 | from tensorflow.python.ops import array_ops 15 | from tensorflow.python.layers import utils 16 | import tensorflow as tf 17 | 18 | # (1) scale inputs to zero mean and unit variance 19 | 20 | 21 | # (2) use SELUs 22 | def selu(x): 23 | with ops.name_scope('elu') as scope: 24 | alpha = 1.6732632423543772848170429916717 25 | scale = 1.0507009873554804934193349852946 26 | return scale*tf.where(x>=0.0, x, alpha*tf.nn.elu(x)) 27 | 28 | 29 | # (3) initialize weights with stddev sqrt(1/n) 30 | # e.g. 
use: 31 | initializer = layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN') 32 | 33 | 34 | # (4) use this dropout 35 | def dropout_selu(x, rate, alpha= -1.7580993408473766, fixedPointMean=0.0, fixedPointVar=1.0, 36 | noise_shape=None, seed=None, name=None, training=False): 37 | """Dropout to a value with rescaling.""" 38 | 39 | def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name): 40 | keep_prob = 1.0 - rate 41 | x = ops.convert_to_tensor(x, name="x") 42 | if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1: 43 | raise ValueError("keep_prob must be a scalar tensor or a float in the " 44 | "range (0, 1], got %g" % keep_prob) 45 | keep_prob = ops.convert_to_tensor(keep_prob, dtype=x.dtype, name="keep_prob") 46 | keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) 47 | 48 | alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha") 49 | alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar()) 50 | 51 | if tensor_util.constant_value(keep_prob) == 1: 52 | return x 53 | 54 | noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x) 55 | random_tensor = keep_prob 56 | random_tensor += random_ops.random_uniform(noise_shape, seed=seed, dtype=x.dtype) 57 | binary_tensor = math_ops.floor(random_tensor) 58 | ret = x * binary_tensor + alpha * (1-binary_tensor) 59 | 60 | a = math_ops.sqrt(fixedPointVar / (keep_prob *((1-keep_prob) * math_ops.pow(alpha-fixedPointMean,2) + fixedPointVar))) 61 | 62 | b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha) 63 | ret = a * ret + b 64 | ret.set_shape(x.get_shape()) 65 | return ret 66 | 67 | with ops.name_scope(name, "dropout", [x]) as name: 68 | return utils.smart_cond(training, 69 | lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name), 70 | lambda: array_ops.identity(x)) -------------------------------------------------------------------------------- /SNLI_disan/src/utils/tree/str_transform.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class TreeNode(object): 4 | def __init__(self, is_leaf, tag=None, token=None): 5 | self.tag = tag 6 | self.is_leaf = is_leaf 7 | self.token = token 8 | self.children_nodes = [] 9 | # for transformation 10 | self.parent_index = None 11 | self.node_index = None 12 | self.leaf_node_index_seq = [] 13 | 14 | 15 | def recursive_build_penn_format(seq): 16 | if seq[0] == '(' and seq[-1] == ')' and len(seq[1:-1]) > 2: 17 | node = TreeNode(False, tag=seq[1]) 18 | children_seqs = [] 19 | children_seq = [] 20 | counter = 0 21 | for token in seq[2:-1]: 22 | children_seq.append(token) 23 | if token == '(': counter += 1 24 | elif token == ')': counter -= 1 25 | if counter == 0: 26 | children_seqs.append(children_seq) 27 | children_seq = [] 28 | node.children_nodes = [recursive_build_penn_format(children_seq) for children_seq in children_seqs] 29 | return node 30 | else: 31 | new_seq = seq[1:-1] 32 | assert len(new_seq) == 2, seq 33 | node = TreeNode(True, tag=new_seq[0], token=new_seq[1]) 34 | return node 35 | 36 | 37 | def recursive_build_binary(seq): 38 | if seq[0] == '(' and seq[-1] == ')' and len(seq): 39 | node = TreeNode(is_leaf=False) 40 | children_seqs = [] 41 | children_seq = [] 42 | counter = 0 43 | for token in seq[1:-1]: 44 | children_seq.append(token) 45 | if token == '(': 46 | counter += 1 47 | elif token == ')': 48 | counter -= 1 49 | if counter == 0: 50 | children_seqs.append(children_seq) 51 | children_seq = [] 52 | node.children_nodes = 
[recursive_build_binary(children_seq) for children_seq in children_seqs] 53 | return node 54 | else: 55 | assert len(seq) == 1, seq 56 | node = TreeNode(is_leaf=True, token=seq[0]) 57 | return node 58 | 59 | 60 | def check_tree(tree, layer): 61 | if len(tree.children_nodes) > 0: 62 | now_str = '%snon_leaf: %s:%s, %s:%s\n' % \ 63 | ('\t'* layer, tree.tag, tree.token, tree.node_index, tree.parent_index) 64 | s = ''.join([check_tree(node, layer+1) for node in tree.children_nodes]) 65 | return now_str + s 66 | else: 67 | return '%sleaf: %s:%s, %s:%s\n' % ('\t'* layer, tree.tag, tree.token, tree.node_index, tree.parent_index) 68 | 69 | 70 | def tokenize_str_format_tree(tree_str): 71 | 72 | # 1. spilt by ' ' 73 | raw_token_list = tree_str.split(' ') 74 | # 2. split when find '(' or ')' 75 | token_list = [] 76 | for token in raw_token_list: 77 | new_token_list = [] 78 | idx_in_token = 0 79 | for idx_char, char in enumerate(token): 80 | if char == '(' or char == ')': 81 | if idx_char > idx_in_token: 82 | new_token_list.append(token[idx_in_token: idx_char]) 83 | new_token_list.append(char) 84 | idx_in_token = idx_char + 1 85 | if idx_in_token < len(token): 86 | new_token_list.append(token[idx_in_token:]) 87 | token_list += new_token_list 88 | return token_list 89 | 90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /SST_disan/src/nn_utils/rnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from src.nn_utils.general import flatten, reconstruct 3 | 4 | 5 | def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, 6 | dtype=None, parallel_iterations=None, swap_memory=False, 7 | time_major=False, scope=None): 8 | assert not time_major # TODO : to be implemented later! 9 | flat_inputs = flatten(inputs, 2) # [-1, J, d] 10 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64') 11 | 12 | flat_outputs, final_state = tf.nn.dynamic_rnn(cell, flat_inputs, sequence_length=flat_len, 13 | initial_state=initial_state, dtype=dtype, 14 | parallel_iterations=parallel_iterations, swap_memory=swap_memory, 15 | time_major=time_major, scope=scope) 16 | 17 | outputs = reconstruct(flat_outputs, inputs, 2) 18 | return outputs, final_state 19 | 20 | 21 | def bw_dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, 22 | dtype=None, parallel_iterations=None, swap_memory=False, 23 | time_major=False, scope=None): 24 | assert not time_major # TODO : to be implemented later! 
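    # Descriptive note on the backward pass implemented below: the inputs are
    # flattened to rank 3 by flatten(), reversed along the time axis
    # (tf.reverse_sequence when sequence_length is given, so that padding stays
    # in place), fed through a plain forward tf.nn.dynamic_rnn, and the outputs
    # are reversed back and reshaped to the original layout by reconstruct().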
25 | 26 | flat_inputs = flatten(inputs, 2) # [-1, J, d] 27 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64') 28 | 29 | flat_inputs = tf.reverse(flat_inputs, 1) if sequence_length is None \ 30 | else tf.reverse_sequence(flat_inputs, sequence_length, 1) 31 | flat_outputs, final_state = tf.nn.dynamic_rnn(cell, flat_inputs, sequence_length=flat_len, 32 | initial_state=initial_state, dtype=dtype, 33 | parallel_iterations=parallel_iterations, swap_memory=swap_memory, 34 | time_major=time_major, scope=scope) 35 | flat_outputs = tf.reverse(flat_outputs, 1) if sequence_length is None \ 36 | else tf.reverse_sequence(flat_outputs, sequence_length, 1) 37 | 38 | outputs = reconstruct(flat_outputs, inputs, 2) 39 | return outputs, final_state 40 | 41 | 42 | def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, 43 | initial_state_fw=None, initial_state_bw=None, 44 | dtype=None, parallel_iterations=None, 45 | swap_memory=False, time_major=False, scope=None): 46 | assert not time_major 47 | 48 | flat_inputs = flatten(inputs, 2) # [-1, J, d] 49 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64') 50 | 51 | (flat_fw_outputs, flat_bw_outputs), final_state = \ 52 | tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len, 53 | initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw, 54 | dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory, 55 | time_major=time_major, scope=scope) 56 | 57 | fw_outputs = reconstruct(flat_fw_outputs, inputs, 2) 58 | bw_outputs = reconstruct(flat_bw_outputs, inputs, 2) 59 | # FIXME : final state is not reshaped! 60 | return (fw_outputs, bw_outputs), final_state 61 | -------------------------------------------------------------------------------- /SNLI_disan/src/nn_utils/rnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from src.nn_utils.general import flatten, reconstruct 3 | 4 | 5 | def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, 6 | dtype=None, parallel_iterations=None, swap_memory=False, 7 | time_major=False, scope=None): 8 | assert not time_major # TODO : to be implemented later! 9 | flat_inputs = flatten(inputs, 2) # [-1, J, d] 10 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64') 11 | 12 | flat_outputs, final_state = tf.nn.dynamic_rnn(cell, flat_inputs, sequence_length=flat_len, 13 | initial_state=initial_state, dtype=dtype, 14 | parallel_iterations=parallel_iterations, swap_memory=swap_memory, 15 | time_major=time_major, scope=scope) 16 | 17 | outputs = reconstruct(flat_outputs, inputs, 2) 18 | return outputs, final_state 19 | 20 | 21 | def bw_dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, 22 | dtype=None, parallel_iterations=None, swap_memory=False, 23 | time_major=False, scope=None): 24 | assert not time_major # TODO : to be implemented later! 
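    # Descriptive note on the backward pass implemented below: the inputs are
    # flattened to rank 3 by flatten(), reversed along the time axis
    # (tf.reverse_sequence when sequence_length is given, so that padding stays
    # in place), fed through a plain forward tf.nn.dynamic_rnn, and the outputs
    # are reversed back and reshaped to the original layout by reconstruct().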
25 | 26 | flat_inputs = flatten(inputs, 2) # [-1, J, d] 27 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64') 28 | 29 | flat_inputs = tf.reverse(flat_inputs, 1) if sequence_length is None \ 30 | else tf.reverse_sequence(flat_inputs, sequence_length, 1) 31 | flat_outputs, final_state = tf.nn.dynamic_rnn(cell, flat_inputs, sequence_length=flat_len, 32 | initial_state=initial_state, dtype=dtype, 33 | parallel_iterations=parallel_iterations, swap_memory=swap_memory, 34 | time_major=time_major, scope=scope) 35 | flat_outputs = tf.reverse(flat_outputs, 1) if sequence_length is None \ 36 | else tf.reverse_sequence(flat_outputs, sequence_length, 1) 37 | 38 | outputs = reconstruct(flat_outputs, inputs, 2) 39 | return outputs, final_state 40 | 41 | 42 | def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, 43 | initial_state_fw=None, initial_state_bw=None, 44 | dtype=None, parallel_iterations=None, 45 | swap_memory=False, time_major=False, scope=None): 46 | assert not time_major 47 | 48 | flat_inputs = flatten(inputs, 2) # [-1, J, d] 49 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64') 50 | 51 | (flat_fw_outputs, flat_bw_outputs), final_state = \ 52 | tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len, 53 | initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw, 54 | dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory, 55 | time_major=time_major, scope=scope) 56 | 57 | fw_outputs = reconstruct(flat_fw_outputs, inputs, 2) 58 | bw_outputs = reconstruct(flat_bw_outputs, inputs, 2) 59 | # FIXME : final state is not reshaped! 60 | return (fw_outputs, bw_outputs), final_state 61 | -------------------------------------------------------------------------------- /SNLI_disan/src/model/exp_bi_lstm_mul_attn.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import tensorflow as tf 4 | 5 | from src.model.model_template import ModelTemplate 6 | from src.nn_utils.nn import linear 7 | from src.nn_utils.integration_func import multi_dimensional_attention, generate_embedding_mat,\ 8 | contextual_bi_rnn 9 | 10 | 11 | class ModelExpBiLSTMMulAttn(ModelTemplate): 12 | def __init__(self, token_emb_mat, glove_emb_mat, tds, cds, tl, scope): 13 | super(ModelExpBiLSTMMulAttn, self).__init__(token_emb_mat, glove_emb_mat, tds, cds, tl, scope) 14 | self.update_tensor_add_ema_and_opt() 15 | 16 | def build_network(self): 17 | _logger.add() 18 | _logger.add('building %s neural network structure...' 
% cfg.network_type) 19 | tds, cds = self.tds, self.cds 20 | tl = self.tl 21 | tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh 22 | hn = self.hn 23 | bs, sl1, sl2 = self.bs, self.sl1, self.sl2 24 | 25 | with tf.variable_scope('emb'): 26 | token_emb_mat = generate_embedding_mat(tds, tel, init_mat=self.token_emb_mat, 27 | extra_mat=self.glove_emb_mat, extra_trainable=self.finetune_emb, 28 | scope='gene_token_emb_mat') 29 | s1_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent1_token) # bs,sl1,tel 30 | s2_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent2_token) # bs,sl2,tel 31 | self.tensor_dict['s1_emb'] = s1_emb 32 | self.tensor_dict['s2_emb'] = s2_emb 33 | 34 | with tf.variable_scope('context_fusion'): 35 | s1_seq_rep = contextual_bi_rnn(s1_emb, self.sent1_token_mask, hn, 'lstm', False, cfg.wd, 36 | cfg.dropout, self.is_train, 'bi_lstm') 37 | tf.get_variable_scope().reuse_variables() 38 | s2_seq_rep = contextual_bi_rnn(s2_emb, self.sent2_token_mask, hn, 'lstm', False, cfg.wd, 39 | cfg.dropout, self.is_train, 'bi_lstm') 40 | 41 | self.tensor_dict['s1_seq_rep'] = s1_seq_rep 42 | self.tensor_dict['s2_seq_rep'] = s2_seq_rep 43 | 44 | with tf.variable_scope('sent_enc_attn'): 45 | s1_rep = multi_dimensional_attention( 46 | s1_seq_rep, self.sent1_token_mask, 'multi_dimensional_attention', 47 | cfg.dropout, self.is_train, cfg.wd, 'relu', 48 | tensor_dict=self.tensor_dict, name='s1_attn') 49 | tf.get_variable_scope().reuse_variables() 50 | s2_rep = multi_dimensional_attention( 51 | s2_seq_rep, self.sent2_token_mask, 'multi_dimensional_attention', 52 | cfg.dropout, self.is_train, cfg.wd, 'relu', 53 | tensor_dict=self.tensor_dict, name='s2_attn') 54 | 55 | self.tensor_dict['s1_rep'] = s1_rep 56 | self.tensor_dict['s2_rep'] = s2_rep 57 | 58 | with tf.variable_scope('output'): 59 | out_rep = tf.concat([s1_rep, s2_rep, s1_rep - s2_rep, s1_rep * s2_rep], -1) 60 | pre_output = tf.nn.relu(linear([out_rep], hn, True, 0., scope= 'pre_output', squeeze=False, 61 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train)) 62 | logits = linear([pre_output], self.output_class, True, 0., scope= 'logits', squeeze=False, 63 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train) 64 | self.tensor_dict[logits] = logits 65 | return logits # logits 66 | 67 | 68 | -------------------------------------------------------------------------------- /SST_disan/README.md: -------------------------------------------------------------------------------- 1 | # DiSAN Implementation for SST 2 | The introduction to Sentiment Classification task, please refer to [Stanford Sentiment Treebank (SST)](https://nlp.stanford.edu/sentiment/index.html) or paper. 
3 | 4 | ## Python Package Requirements 5 | * tqdm 6 | * nltk 7 | 8 | ## Dataset 9 | First of all, please clone the repo and enter the project folder: 10 | 11 | git clone https://github.com/forreview/DiSAN 12 | cd DiSAN/SST_disan 13 | 14 | Then download the data files into the **dataset/** dir: 15 | 16 | * [GloVe Pretrained word2vec](http://nlp.stanford.edu/data/glove.6B.zip) 17 | * [SST dataset](http://nlp.stanford.edu/~socherr/stanfordSentimentTreebank.zip) 18 | 19 | __Please check that the following files are in the corresponding folders before running with the default hyper-parameters:__ 20 | 21 | * dataset/glove/glove.6B.300d.txt 22 | * dataset/stanfordSentimentTreebank/datasetSentences.txt 23 | * dataset/stanfordSentimentTreebank/datasetSplit.txt 24 | * dataset/stanfordSentimentTreebank/dictionary.txt 25 | * dataset/stanfordSentimentTreebank/original\_rt\_snippets.txt 26 | * dataset/stanfordSentimentTreebank/sentiment_labels.txt 27 | * dataset/stanfordSentimentTreebank/SOStr.txt 28 | * dataset/stanfordSentimentTreebank/STree.txt 29 | 30 | ## 1. Run a Pre-trained Model to Verify the Result in the Paper 31 | 32 | ### 1.1 Download Pre-processed Dataset and Pre-trained Model 33 | The download URL is [here](https://drive.google.com/open?id=0B3Sd3TjOhd-JcnY4dkJOMFo0Ujg). **Please do not rename the files after downloading!** 34 | 35 | #### 1.1.1 Download Pre-processed Dataset 36 | * file name: *processed\_lw\_True\_ugut\_True\_gc\_6B\_wel\_300.pickle* 37 | * Download the file to *result/processed_data* 38 | 39 | #### 1.1.2 Download Pre-trained Model File 40 | 41 | * two files: *disan\_sst\_model.ckpt.data-00000-of-00001* and *disan\_sst\_model.ckpt.index* 42 | * Download the files to the folder *pretrained_model/* and pass their path to the running param `--load_path` 43 | 44 | 45 | ### 1.2 Run the code 46 | python sst_main.py --mode test --network_type disan --model_dir_suffix pretrained --gpu 0 --load_path pretrained_model/disan_sst_model.ckpt 47 | 48 | __notice:__ 49 | 50 | * Please specify the GPU index in the param `--gpu` to run the code on a specific GPU. If no GPU is available, the code will run on the CPU automatically. 51 | * If you don't have enough GPU memory, feel free to reduce `--test_batch_size`, whose default value is 128. 52 | * [For TensorFlow newcomers] The argument to `--load_path` does not need the *.ckpt.data* or *.ckpt.index* postfix, just *.ckpt*. 53 | 54 | ## 2. Train a Model 55 | After preparing the dataset, just run the command below; there is no need to download any pre-processed files: 56 | 57 | python sst_main.py --mode train --network_type disan --model_dir_suffix training --gpu 0 58 | 59 | __notice:__ 60 | 61 | * Every run of the code builds a folder *result/model/xx\_model\_name\_xxx/* whose name begins with the argument of `--model_dir_suffix`; it contains the running log in the *log* folder, TensorFlow summaries in the *summary* folder and the top-3 models in the *ckpt* folder. 62 | * Please specify the GPU index in the param `--gpu` to run the code on a specific GPU. If no GPU is available, the code will run on the CPU automatically. 63 | * If you don't have enough GPU memory, feel free to reduce `--test_batch_size` (default 128) and `--train_batch_size` (default 64); see the example command below. 64 | * After the raw dataset is processed, a pickle file is stored in *result/processed_data* and is reused by later runs, so the raw data does not have to be processed again. 65 | * The details of all parameters can be viewed in `configs.py`.
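For instance, to train with smaller batches (the batch-size values below are only an illustration; every flag is defined in `configs.py`):

    python sst_main.py --mode train --network_type disan --model_dir_suffix training --gpu 0 --train_batch_size 32 --test_batch_size 64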
66 | 67 | ### Test Your Trained Model 68 | The training takes about 4 hours on a single GTX1080Ti. At the end of training, the top-3 models, together with their step number, dev accuracy and test accuracy, will be displayed in the bash window. You can also check the TensorFlow checkpoint files in *result/model/xxxx/ckpt* and run in test mode as introduced in Sec. 1. 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /SNLI_disan/src/utils/file.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import numpy as np 4 | import json 5 | import pickle 6 | import os 7 | 8 | def load_squad_dataset(filePath): 9 | with open(filePath, 'r', encoding='utf-8') as data_file: 10 | line = data_file.readline() 11 | 12 | dataset = json.loads(line) 13 | 14 | return dataset['data'] 15 | 16 | 17 | 18 | def save_file(data, filePath, dataName = 'data', mode='pickle'): 19 | _logger.add() 20 | _logger.add('Saving %s to %s' % (dataName,filePath)) 21 | 22 | if mode == 'pickle': 23 | with open(filePath, 'wb') as f: 24 | pickle.dump(obj=data, 25 | file=f) 26 | elif mode == 'json': 27 | with open(filePath, 'w', encoding='utf-8') as f: 28 | json.dump(obj=data, 29 | fp=f) 30 | else: 31 | raise ValueError('Function save_file does not have mode %s' % mode) 32 | _logger.add('Done') 33 | 34 | 35 | def load_file(filePath, dataName = 'data', mode='pickle'): 36 | _logger.add() 37 | _logger.add('Trying to load %s from %s' % (dataName, filePath)) 38 | data = None 39 | ifLoad = False 40 | if os.path.isfile(filePath): 41 | _logger.add('Have found the file, loading...') 42 | 43 | if mode == 'pickle': 44 | with open(filePath, 'rb') as f: 45 | data = pickle.load(f) 46 | ifLoad = True 47 | elif mode == 'json': 48 | with open(filePath, 'r', encoding='utf-8') as f: 49 | data = json.load(f) 50 | ifLoad = True 51 | else: 52 | raise ValueError('Function load_file does not have mode %s' % mode) 53 | 54 | else: 55 | _logger.add('Have not found the file') 56 | _logger.add('Done') 57 | return (ifLoad,data) 58 | 59 | 60 | def load_glove(dim): 61 | _logger.add() 62 | _logger.add('loading glove from pre-trained file...') 63 | # if dim not in [50, 100, 200, 300]: 64 | # raise(ValueError, 'glove dim must be in [50, 100, 200, 300]') 65 | word2vec = {} 66 | with open(os.path.join(cfg.glove_dir, "glove.%s.%sd.txt" % (cfg.glove_corpus, str(dim))), encoding='utf-8') as f: 67 | for line in f: 68 | l = None 69 | try: 70 | l = line.strip(os.linesep).split(' ') 71 | vector = np.array(list(map(float, l[1:])), 72 | dtype=cfg.floatX) 73 | word2vec[l[0]] = vector 74 | 75 | assert vector.shape[0] == dim 76 | #print('right:', l) 77 | except ValueError: 78 | print('1.token_error-line:', line[:-1]) 79 | print('2.token_error-split:', l) 80 | anchor = 0 81 | except AssertionError: 82 | print('1.vec_error-line:', line[:-1]) 83 | print('2.vec_error-split:', l) 84 | 85 | _logger.add('Done') 86 | return word2vec 87 | 88 | 89 | def save_nn_model(modelFilePath, allParams, epoch): 90 | _logger.add() 91 | _logger.add('saving model file to %s' % modelFilePath) 92 | with open(modelFilePath,'wb') as f: 93 | pickle.dump(obj=[[param.get_value() for param in allParams ], 94 | epoch], 95 | file = f) 96 | _logger.add('Done') 97 | 98 | def load_nn_model(modelFilePath): 99 | _logger.add() 100 | _logger.add('try to load model file %s' % modelFilePath) 101 | allParamValues = None 102 | epoch = 1 103 |
isLoaded = False 104 | if os.path.isfile(modelFilePath): 105 | _logger.add('Have found model file, loading...') 106 | with open(modelFilePath, 'rb') as f: 107 | data = pickle.load(f) 108 | allParamValues = data[0] 109 | epoch = data[1] 110 | isLoaded = True 111 | 112 | else: 113 | _logger.add('Have not found model file') 114 | _logger.add('Done') 115 | return isLoaded, allParamValues, epoch 116 | 117 | 118 | -------------------------------------------------------------------------------- /SST_disan/src/utils/file.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import numpy as np 4 | import json 5 | import pickle 6 | import os 7 | 8 | def load_squad_dataset(filePath): 9 | with open(filePath, 'r', encoding='utf-8') as data_file: 10 | line = data_file.readline() 11 | 12 | dataset = json.loads(line) 13 | 14 | return dataset['data'] 15 | 16 | 17 | 18 | def save_file(data, filePath, dataName = 'data', mode='pickle'): 19 | _logger.add() 20 | _logger.add('Saving %s to %s' % (dataName,filePath)) 21 | 22 | if mode == 'pickle': 23 | with open(filePath, 'wb') as f: 24 | pickle.dump(obj=data, 25 | file=f) 26 | elif mode == 'json': 27 | with open(filePath, 'w', encoding='utf-8') as f: 28 | json.dump(obj=data, 29 | fp=f) 30 | else: 31 | raise(ValueError,'Function save_file does not have mode %s' % (mode)) 32 | _logger.add('Done') 33 | 34 | 35 | def load_file(filePath, dataName = 'data', mode='pickle'): 36 | _logger.add() 37 | _logger.add('Trying to load %s from %s' % (dataName, filePath)) 38 | data = None 39 | ifLoad = False 40 | if os.path.isfile(filePath): 41 | _logger.add('Have found the file, loading...') 42 | 43 | if mode == 'pickle': 44 | with open(filePath, 'rb') as f: 45 | data = pickle.load(f) 46 | ifLoad = True 47 | elif mode == 'json': 48 | with open(filePath, 'r', encoding='utf-8') as f: 49 | data = json.load(f) 50 | ifLoad = True 51 | else: 52 | raise (ValueError, 'Function save_file does not have mode %s' % (mode)) 53 | 54 | else: 55 | _logger.add('Have not found the file') 56 | _logger.add('Done') 57 | return (ifLoad,data) 58 | 59 | 60 | def load_glove(dim): 61 | _logger.add() 62 | _logger.add('loading glove from pre-trained file...') 63 | # if dim not in [50, 100, 200, 300]: 64 | # raise(ValueError, 'glove dim must be in [50, 100, 200, 300]') 65 | word2vec = {} 66 | with open(os.path.join(cfg.glove_dir, "glove.%s.%sd.txt" % (cfg.glove_corpus, str(dim))), encoding='utf-8') as f: 67 | for line in f: 68 | l = None 69 | try: 70 | l = line.strip(os.linesep).split(' ') 71 | vector = np.array(list(map(float, l[1:])), 72 | dtype=cfg.floatX) 73 | word2vec[l[0]] = vector 74 | 75 | assert vector.shape[0] == dim 76 | #print('right:', l) 77 | except ValueError: 78 | print('1.token_error-line:', line[:-1]) 79 | print('2.token_error-split:', l) 80 | anchor = 0 81 | except AssertionError: 82 | print('1.vec_error-line:', line[:-1]) 83 | print('2.vec_error-split:', l) 84 | 85 | _logger.add('Done') 86 | return word2vec 87 | 88 | 89 | def save_nn_model(modelFilePath, allParams, epoch): 90 | _logger.add() 91 | _logger.add('saving model file to %s' % modelFilePath) 92 | with open(modelFilePath,'wb') as f: 93 | pickle.dump(obj=[[param.get_value() for param in allParams ], 94 | epoch], 95 | file = f) 96 | _logger.add('Done') 97 | 98 | def load_nn_model(modelFilePath): 99 | _logger.add() 100 | _logger.add('try to load model file %s' % modelFilePath) 101 | allParamValues = None 102 | epoch = 1 103 | 
isLoaded = False 104 | if os.path.isfile(modelFilePath): 105 | _logger.add('Have found model file, loading...') 106 | with open(modelFilePath, 'rb') as f: 107 | data = pickle.load(f) 108 | allParamValues = data[0] 109 | epoch = data[1] 110 | isLoaded = True 111 | 112 | else: 113 | _logger.add('Have not found model file') 114 | _logger.add('Done') 115 | return isLoaded, allParamValues, epoch 116 | 117 | 118 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Linux template 3 | *~ 4 | 5 | # temporary files which can be created if a process still has a handle open of a deleted file 6 | .fuse_hidden* 7 | 8 | # KDE directory preferences 9 | .directory 10 | 11 | # Linux trash folder which might appear on any partition or disk 12 | .Trash-* 13 | 14 | # .nfs files are created when an open file is removed but is still being accessed 15 | .nfs* 16 | ### Python template 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | .hypothesis/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | .static_storage/ 72 | .media/ 73 | local_settings.py 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | ### macOS template 122 | # General 123 | .DS_Store 124 | .AppleDouble 125 | .LSOverride 126 | 127 | # Icon must end with two \r 128 | Icon 129 | 130 | # Thumbnails 131 | ._* 132 | 133 | # Files that might appear in the root of a volume 134 | .DocumentRevisions-V100 135 | .fseventsd 136 | .Spotlight-V100 137 | .TemporaryItems 138 | .Trashes 139 | .VolumeIcon.icns 140 | .com.apple.timemachine.donotpresent 141 | 142 | # Directories potentially created on remote AFP share 143 | .AppleDB 144 | .AppleDesktop 145 | Network Trash Folder 146 | Temporary Items 147 | .apdisk 148 | ### JetBrains template 149 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 150 | # Reference: 
https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 151 | 152 | # User-specific stuff: 153 | .idea/**/workspace.xml 154 | .idea/**/tasks.xml 155 | .idea/dictionaries 156 | 157 | # Sensitive or high-churn files: 158 | .idea/**/dataSources/ 159 | .idea/**/dataSources.ids 160 | .idea/**/dataSources.xml 161 | .idea/**/dataSources.local.xml 162 | .idea/**/sqlDataSources.xml 163 | .idea/**/dynamic.xml 164 | .idea/**/uiDesigner.xml 165 | 166 | # Gradle: 167 | .idea/**/gradle.xml 168 | .idea/**/libraries 169 | 170 | # CMake 171 | cmake-build-debug/ 172 | cmake-build-release/ 173 | 174 | # Mongo Explorer plugin: 175 | .idea/**/mongoSettings.xml 176 | 177 | ## File-based project format: 178 | *.iws 179 | 180 | ## Plugin-specific files: 181 | 182 | # IntelliJ 183 | out/ 184 | 185 | # mpeltonen/sbt-idea plugin 186 | .idea_modules/ 187 | 188 | # JIRA plugin 189 | atlassian-ide-plugin.xml 190 | 191 | # Cursive Clojure plugin 192 | .idea/replstate.xml 193 | 194 | # Crashlytics plugin (for Android Studio and IntelliJ) 195 | com_crashlytics_export_strings.xml 196 | crashlytics.properties 197 | crashlytics-build.properties 198 | fabric.properties 199 | 200 | .gitignore 201 | .idea/ 202 | Fast-DiSA/.DS_Store 203 | -------------------------------------------------------------------------------- /SNLI_disan/README.md: -------------------------------------------------------------------------------- 1 | # DiSAN Implementation for SNLI 2 | For an introduction to the NLI task, please refer to [Stanford Natural Language Inference (SNLI)](https://nlp.stanford.edu/projects/snli/) or the paper. 3 | 4 | ## Python Package Requirements 5 | * tqdm 6 | * nltk 7 | 8 | ## Dataset 9 | 10 | First of all, please clone the repo and enter the project folder: 11 | 12 | git clone https://github.com/taoshen58/DiSAN 13 | cd DiSAN/SNLI_disan 14 | 15 | Then download the data files into the **dataset/** dir: 16 | 17 | * [GloVe Pretrained word2vec](http://nlp.stanford.edu/data/glove.6B.zip) 18 | * [SNLI dataset](https://nlp.stanford.edu/projects/snli/snli_1.0.zip) 19 | 20 | __Please check that the following files are in the corresponding folders before running with the default hyper-parameters:__ 21 | 22 | * dataset/glove/glove.6B.300d.txt 23 | * dataset/snli\_1.0/snli\_1.0\_train.jsonl 24 | * dataset/snli\_1.0/snli\_1.0\_dev.jsonl 25 | * dataset/snli\_1.0/snli\_1.0\_test.jsonl 26 | 27 | ## 1. Run a Pre-trained Model to Verify the Result in the Paper 28 | 29 | ### 1.1 Download Pre-processed Dataset and Pre-trained Model 30 | The download URL is [here](https://drive.google.com/drive/folders/0B3Sd3TjOhd-JNjJNT2RoZU1NalU?usp=sharing). **Please do not rename the files after downloading!** 31 | 32 | #### 1.1.1 Download Pre-processed Dataset 33 | * file name: *processed\_lw\_True\_ugut\_True\_gc\_6B\_wel\_300\_slr\_0.97\_dcm\_no\_tree.pickle* 34 | * Download the file to *result/processed_data* 35 | 36 | #### 1.1.2 Download Pre-trained Model File 37 | 38 | * two files: *disan\_snli\_model.ckpt.data-00000-of-00001* and *disan\_snli\_model.ckpt.index* 39 | * Download the files to the folder *pretrained_model/* and pass their path to the running param `--load_path` 40 | 41 | 42 | ### 1.2 Run the code 43 | 44 | python snli_main.py --mode test --network_type disan --model_dir_suffix pretrained --gpu 0 --load_path pretrained_model/disan_snli_model.ckpt 45 | 46 | __notice:__ 47 | 48 | * Please specify the GPU index in the param `--gpu` to run the code on a specific GPU. If no GPU is available, the code will run on the CPU automatically.
49 | * The memory consumption is about 9GB with `--test_batch_size` set to 100; if your device has limited memory, please reduce `--test_batch_size`. 50 | * [For TensorFlow newcomers] The argument to `--load_path` does not need the *.ckpt.data* or *.ckpt.index* postfix, just *.ckpt*. 51 | 52 | ## 2. Train a Model 53 | After preparing the dataset, just run the command below; there is no need to download any pre-processed files: 54 | 55 | python snli_main.py --mode train --network_type disan --model_dir_suffix training --gpu 0 56 | 57 | __notice:__ 58 | 59 | * Every run of the code builds a folder *result/model/xx\_model\_name\_xxx/* whose name begins with the argument of `--model_dir_suffix`; it contains the running log in the *log* folder, TensorFlow summaries in the *summary* folder and the top-3 models in the *ckpt* folder. 60 | * Please specify the GPU index in the param `--gpu` to run the code on a specific GPU. If no GPU is available, the code will run on the CPU automatically. 61 | * The memory consumption is about 9GB with `--test_batch_size` set to 100; if your device has limited memory, please reduce `--test_batch_size`. In addition, with `--train_batch_size` set to 64 the minimum memory consumption is 5GB, and you can reduce that batch size further as well. 62 | * After the raw dataset is processed, a pickle file is stored in *result/processed_data* and is reused by later runs, so the raw data does not have to be processed again. 63 | * The details of all parameters can be viewed in `configs.py`. 64 | * The baseline neural networks reported in the paper are also provided, namely `exp_emb_attn`, `exp_emb_mul_attn`, `exp_bi_lstm_mul_attn` and `exp_emb_self_mul_attn` (in the same order as in the paper); pass one of these model names to `--network_type` to run the baseline experiments. 65 | 66 | ### Test Your Trained Model 67 | The training takes about 15 hours on a single GTX1080Ti. At the end of training, the top-3 models, together with their step number, dev accuracy and test accuracy, will be displayed in the bash window. You can also check the TensorFlow checkpoint files in *result/model/xxxx/ckpt* and run in test mode as introduced in Sec. 1. 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /SNLI_disan/src/model/exp_emb_self_mul_attn.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import tensorflow as tf 4 | 5 | from src.model.model_template import ModelTemplate 6 | from src.nn_utils.nn import linear 7 | from src.nn_utils.integration_func import multi_dimensional_attention, generate_embedding_mat,\ 8 | directional_attention_with_dense 9 | 10 | 11 | class ModelExpEmbSelfMulAttn(ModelTemplate): 12 | def __init__(self, token_emb_mat, glove_emb_mat, tds, cds, tl, scope): 13 | super(ModelExpEmbSelfMulAttn, self).__init__(token_emb_mat, glove_emb_mat, tds, cds, tl, scope) 14 | self.update_tensor_add_ema_and_opt() 15 | 16 | def build_network(self): 17 | _logger.add() 18 | _logger.add('building %s neural network structure...'
% cfg.network_type) 19 | tds, cds = self.tds, self.cds 20 | tl = self.tl 21 | tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh 22 | hn = self.hn 23 | bs, sl1, sl2 = self.bs, self.sl1, self.sl2 24 | 25 | with tf.variable_scope('emb'): 26 | token_emb_mat = generate_embedding_mat(tds, tel, init_mat=self.token_emb_mat, 27 | extra_mat=self.glove_emb_mat, extra_trainable=self.finetune_emb, 28 | scope='gene_token_emb_mat') 29 | s1_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent1_token) # bs,sl1,tel 30 | s2_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent2_token) # bs,sl2,tel 31 | self.tensor_dict['s1_emb'] = s1_emb 32 | self.tensor_dict['s2_emb'] = s2_emb 33 | 34 | with tf.variable_scope('ct_attn'): 35 | s1_1 = directional_attention_with_dense( 36 | s1_emb, self.sent1_token_mask, None, 'dir_attn_1', 37 | cfg.dropout, self.is_train, cfg.wd, 38 | tensor_dict=self.tensor_dict, name='s1_1_attn') 39 | s1_2 = directional_attention_with_dense( 40 | s1_emb, self.sent1_token_mask, None, 'dir_attn_2', 41 | cfg.dropout, self.is_train, cfg.wd, 42 | tensor_dict=self.tensor_dict, name='s1_2_attn') 43 | 44 | s1_seq_rep = tf.concat([s1_1, s1_2], -1) 45 | 46 | tf.get_variable_scope().reuse_variables() 47 | 48 | s2_1 = directional_attention_with_dense( 49 | s2_emb, self.sent2_token_mask, None, 'dir_attn_1', 50 | cfg.dropout, self.is_train, cfg.wd, 51 | tensor_dict=self.tensor_dict, name='s2_1_attn') 52 | s2_2 = directional_attention_with_dense( 53 | s2_emb, self.sent2_token_mask, None, 'dir_attn_2', 54 | cfg.dropout, self.is_train, cfg.wd, 55 | tensor_dict=self.tensor_dict, name='s2_2_attn') 56 | s2_seq_rep = tf.concat([s2_1, s2_2], -1) 57 | 58 | with tf.variable_scope('sent_enc_attn'): 59 | s1_rep = multi_dimensional_attention( 60 | s1_seq_rep, self.sent1_token_mask, 'multi_dimensional_attention', 61 | cfg.dropout, self.is_train, cfg.wd, 62 | tensor_dict=self.tensor_dict, name='s1_attn') 63 | tf.get_variable_scope().reuse_variables() 64 | s2_rep = multi_dimensional_attention( 65 | s2_seq_rep, self.sent2_token_mask, 'multi_dimensional_attention', 66 | cfg.dropout, self.is_train, cfg.wd, 67 | tensor_dict=self.tensor_dict, name='s2_attn') 68 | 69 | self.tensor_dict['s1_rep'] = s1_rep 70 | self.tensor_dict['s2_rep'] = s2_rep 71 | 72 | with tf.variable_scope('output'): 73 | out_rep = tf.concat([s1_rep, s2_rep, s1_rep - s2_rep, s1_rep * s2_rep], -1) 74 | pre_output = tf.nn.elu(linear([out_rep], hn, True, 0., scope= 'pre_output', squeeze=False, 75 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train)) 76 | logits = linear([pre_output], self.output_class, True, 0., scope= 'logits', squeeze=False, 77 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train) 78 | self.tensor_dict[logits] = logits 79 | return logits # logits 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /SNLI_disan/src/model/exp_emb_dir_mul_attn.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import tensorflow as tf 4 | 5 | from src.model.model_template import ModelTemplate 6 | from src.nn_utils.nn import linear 7 | from src.nn_utils.integration_func import multi_dimensional_attention, generate_embedding_mat,\ 8 | directional_attention_with_dense 9 | 10 | 11 | class ModelExpEmbDirMulAttn(ModelTemplate): 12 | def __init__(self, token_emb_mat, glove_emb_mat, tds, cds, tl, scope): 13 | super(ModelExpEmbDirMulAttn, 
self).__init__(token_emb_mat, glove_emb_mat, tds, cds, tl, scope) 14 | self.update_tensor_add_ema_and_opt() 15 | 16 | def build_network(self): 17 | _logger.add() 18 | _logger.add('building %s neural network structure...' % cfg.network_type) 19 | tds, cds = self.tds, self.cds 20 | tl = self.tl 21 | tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh 22 | hn = self.hn 23 | bs, sl1, sl2 = self.bs, self.sl1, self.sl2 24 | 25 | with tf.variable_scope('emb'): 26 | token_emb_mat = generate_embedding_mat(tds, tel, init_mat=self.token_emb_mat, 27 | extra_mat=self.glove_emb_mat, extra_trainable=self.finetune_emb, 28 | scope='gene_token_emb_mat') 29 | s1_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent1_token) # bs,sl1,tel 30 | s2_emb = tf.nn.embedding_lookup(token_emb_mat, self.sent2_token) # bs,sl2,tel 31 | self.tensor_dict['s1_emb'] = s1_emb 32 | self.tensor_dict['s2_emb'] = s2_emb 33 | 34 | with tf.variable_scope('ct_attn'): 35 | s1_fw = directional_attention_with_dense( 36 | s1_emb, self.sent1_token_mask, 'forward', 'dir_attn_fw', 37 | cfg.dropout, self.is_train, cfg.wd, 38 | tensor_dict=self.tensor_dict, name='s1_fw_attn') 39 | s1_bw = directional_attention_with_dense( 40 | s1_emb, self.sent1_token_mask, 'backward', 'dir_attn_bw', 41 | cfg.dropout, self.is_train, cfg.wd, 42 | tensor_dict=self.tensor_dict, name='s1_bw_attn') 43 | 44 | s1_seq_rep = tf.concat([s1_fw, s1_bw], -1) 45 | 46 | tf.get_variable_scope().reuse_variables() 47 | 48 | s2_fw = directional_attention_with_dense( 49 | s2_emb, self.sent2_token_mask, 'forward', 'dir_attn_fw', 50 | cfg.dropout, self.is_train, cfg.wd, 51 | tensor_dict=self.tensor_dict, name='s2_fw_attn') 52 | s2_bw = directional_attention_with_dense( 53 | s2_emb, self.sent2_token_mask, 'backward', 'dir_attn_bw', 54 | cfg.dropout, self.is_train, cfg.wd, 55 | tensor_dict=self.tensor_dict, name='s2_bw_attn') 56 | s2_seq_rep = tf.concat([s2_fw, s2_bw], -1) 57 | 58 | with tf.variable_scope('sent_enc_attn'): 59 | s1_rep = multi_dimensional_attention( 60 | s1_seq_rep, self.sent1_token_mask, 'multi_dimensional_attention', 61 | cfg.dropout, self.is_train, cfg.wd, 62 | tensor_dict=self.tensor_dict, name='s1_attn') 63 | tf.get_variable_scope().reuse_variables() 64 | s2_rep = multi_dimensional_attention( 65 | s2_seq_rep, self.sent2_token_mask, 'multi_dimensional_attention', 66 | cfg.dropout, self.is_train, cfg.wd, 67 | tensor_dict=self.tensor_dict, name='s2_attn') 68 | 69 | self.tensor_dict['s1_rep'] = s1_rep 70 | self.tensor_dict['s2_rep'] = s2_rep 71 | 72 | with tf.variable_scope('output'): 73 | out_rep = tf.concat([s1_rep, s2_rep, s1_rep - s2_rep, s1_rep * s2_rep], -1) 74 | pre_output = tf.nn.elu(linear([out_rep], hn, True, 0., scope= 'pre_output', squeeze=False, 75 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train)) 76 | logits = linear([pre_output], self.output_class, True, 0., scope= 'logits', squeeze=False, 77 | wd=cfg.wd, input_keep_prob=cfg.dropout,is_train=self.is_train) 78 | self.tensor_dict[logits] = logits 79 | return logits # logits 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /SNLI_disan/src/evaluator.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | class Evaluator(object): 8 | def __init__(self, model): 9 | self.model = model 10 | self.global_step = model.global_step 11 | 12 | ## ---- summary---- 13 
| self.build_summary() 14 | self.writer = tf.summary.FileWriter(cfg.summary_dir) 15 | 16 | def get_evaluation(self, sess, dataset_obj, global_step=None): 17 | _logger.add() 18 | _logger.add('getting evaluation result for %s' % dataset_obj.data_type) 19 | 20 | logits_list, loss_list, accu_list = [], [], [] 21 | for sample_batch, _, _, _ in dataset_obj.generate_batch_sample_iter(): 22 | feed_dict = self.model.get_feed_dict(sample_batch, 'dev') 23 | logits, loss, accu = sess.run([self.model.logits, 24 | self.model.loss, self.model.accuracy], feed_dict) 25 | logits_list.append(np.argmax(logits, -1)) 26 | loss_list.append(loss) 27 | accu_list.append(accu) 28 | 29 | logits_array = np.concatenate(logits_list, 0) 30 | loss_value = np.mean(loss_list) 31 | accu_array = np.concatenate(accu_list, 0) 32 | accu_value = np.mean(accu_array) 33 | 34 | # todo: analysis 35 | # analysis_save_dir = cfg.mkdir(cfg.answer_dir, 'gs_%d' % global_step or 0) 36 | # OutputAnalysis.do_analysis(dataset_obj, logits_array, accu_array, analysis_save_dir, 37 | # cfg.fine_grained) 38 | 39 | if global_step is not None: 40 | if dataset_obj.data_type == 'train': 41 | summary_feed_dict = { 42 | self.train_loss: loss_value, 43 | self.train_accuracy: accu_value, 44 | } 45 | summary = sess.run(self.train_summaries, summary_feed_dict) 46 | self.writer.add_summary(summary, global_step) 47 | elif dataset_obj.data_type == 'dev': 48 | summary_feed_dict = { 49 | self.dev_loss: loss_value, 50 | self.dev_accuracy: accu_value, 51 | } 52 | summary = sess.run(self.dev_summaries, summary_feed_dict) 53 | self.writer.add_summary(summary, global_step) 54 | else: 55 | summary_feed_dict = { 56 | self.test_loss: loss_value, 57 | self.test_accuracy: accu_value, 58 | } 59 | summary = sess.run(self.test_summaries, summary_feed_dict) 60 | self.writer.add_summary(summary, global_step) 61 | 62 | return loss_value, accu_value 63 | 64 | 65 | # --- internal use ------ 66 | def build_summary(self): 67 | with tf.name_scope('train_summaries'): 68 | self.train_loss = tf.placeholder(tf.float32, [], 'train_loss') 69 | self.train_accuracy = tf.placeholder(tf.float32, [], 'train_accuracy') 70 | tf.add_to_collection('train_summaries_collection', tf.summary.scalar('train_loss', self.train_loss)) 71 | tf.add_to_collection('train_summaries_collection', tf.summary.scalar('train_accuracy', self.train_accuracy)) 72 | self.train_summaries = tf.summary.merge_all('train_summaries_collection') 73 | 74 | with tf.name_scope('dev_summaries'): 75 | self.dev_loss = tf.placeholder(tf.float32, [], 'dev_loss') 76 | self.dev_accuracy = tf.placeholder(tf.float32, [], 'dev_accuracy') 77 | tf.add_to_collection('dev_summaries_collection', tf.summary.scalar('dev_loss',self.dev_loss)) 78 | tf.add_to_collection('dev_summaries_collection', tf.summary.scalar('dev_accuracy',self.dev_accuracy)) 79 | self.dev_summaries = tf.summary.merge_all('dev_summaries_collection') 80 | 81 | with tf.name_scope('test_summaries'): 82 | self.test_loss = tf.placeholder(tf.float32, [], 'test_loss') 83 | self.test_accuracy = tf.placeholder(tf.float32, [], 'test_accuracy') 84 | tf.add_to_collection('test_summaries_collection', tf.summary.scalar('test_loss',self.test_loss)) 85 | tf.add_to_collection('test_summaries_collection', tf.summary.scalar('test_accuracy',self.test_accuracy)) 86 | self.test_summaries = tf.summary.merge_all('test_summaries_collection') 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /SST_disan/src/analysis.py: 
-------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.file import load_file 3 | import nltk 4 | import csv, os 5 | import math 6 | 7 | 8 | class OutputAnalysis(object): 9 | def __init__(self): 10 | # todo: 11 | # 1. all sample classification distribution and their accuracy 12 | # 1. all sentence sample classification distribution and their accuracy 13 | # 3. output all sentence sample with label and prediction 14 | pass 15 | 16 | @staticmethod 17 | def do_analysis(dataset_obj, pred_arr,eval_arr, save_dir, fine_grained=True): 18 | out_class_num = 5 if fine_grained else 2 19 | 20 | save_dir = cfg.mkdir(save_dir, dataset_obj.data_type) 21 | sample_list = [] 22 | int_labels = [] 23 | for trees in dataset_obj.nn_data: 24 | for sample in trees: 25 | sample_list.append(sample) 26 | sentiment_float = sample['root_node']['sentiment_label'] 27 | sentiment_int = cfg.sentiment_float_to_int(sentiment_float, fine_grained) 28 | int_labels.append(sentiment_int) 29 | 30 | # check data 31 | assert len(sample_list) == pred_arr.shape[0] 32 | assert pred_arr.shape[0] == eval_arr.shape[0] 33 | 34 | # save csv 35 | with open(os.path.join(save_dir, 'sent_sample_res.csv'), 'w', newline='') as file: 36 | csv_writer = csv.writer(file) 37 | csv_writer.writerow(['sent', 'label', 'pred', 'delta']) 38 | for sample, pred_val, int_label in zip(sample_list, pred_arr, int_labels): 39 | if sample['is_sent']: 40 | sent = ' '.join(sample['root_node']['token_seq']) 41 | label = int_label 42 | pred = int(pred_val) 43 | delta = int(math.fabs(label - pred)) 44 | csv_writer.writerow([sent, label, pred, delta]) 45 | 46 | # statistics 47 | all_class_collect = [] 48 | all_class_right_collect = [] 49 | sent_class_collect = [] 50 | sent_class_right_collect = [] 51 | for sample, eval_val, int_label in zip(sample_list, eval_arr, int_labels): 52 | if sample['is_sent']: 53 | if float(eval_val) == 1.: 54 | sent_class_right_collect.append(int_label) 55 | sent_class_collect.append(int_label) 56 | if float(eval_val) == 1.: 57 | all_class_right_collect.append(int_label) 58 | all_class_collect.append(int_label) 59 | all_class_pdf = nltk.FreqDist(all_class_collect) 60 | all_class_right_pdf = nltk.FreqDist(all_class_right_collect) 61 | sent_class_pdf = nltk.FreqDist(sent_class_collect) 62 | sent_class_right_pdf = nltk.FreqDist(sent_class_right_collect) 63 | 64 | with open(os.path.join(save_dir, 'statistics.txt'), 'w') as file: 65 | file.write('class ,all_class, all_class_right, all_rate, sent_class, sent_class_right, sent_rate' + 66 | os.linesep) 67 | for i in range(out_class_num): 68 | all_class_num = all_class_pdf[i] 69 | all_class_right_num = all_class_right_pdf[i] 70 | sent_class_num = sent_class_pdf[i] 71 | sent_class_right_num = sent_class_right_pdf[i] 72 | file.write('%d, %d, %d, %.4f, %d, %d, %.4f'% 73 | (i, all_class_num, all_class_right_num, 1.0*all_class_right_num/all_class_num, 74 | sent_class_num, sent_class_right_num, 1.0*sent_class_right_num/sent_class_num)) 75 | file.write(os.linesep) 76 | 77 | # statistics.csv 78 | all_class_table = [[] for _ in range(out_class_num)] 79 | sent_class_table = [[] for _ in range(out_class_num)] 80 | for sample, pred_val, int_label in zip(sample_list, pred_arr, int_labels): 81 | if sample['is_sent']: 82 | sent_class_table[int_label].append(pred_val) 83 | all_class_table[int_label].append(pred_val) 84 | all_class_table = [nltk.FreqDist(pred_list) for pred_list in all_class_table] 85 | sent_class_table = [nltk.FreqDist(pred_list) 
for pred_list in sent_class_table] 86 | with open(os.path.join(save_dir, 'statistics.csv'), 'w') as file: 87 | csv_writer = csv.writer(file) 88 | for i in range(out_class_num): 89 | row = [all_class_table[i][j] for j in range(out_class_num)] 90 | csv_writer.writerow(row) 91 | csv_writer.writerow([]) 92 | for i in range(out_class_num): 93 | row = [sent_class_table[i][j] for j in range(out_class_num)] 94 | csv_writer.writerow(row) 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | class DatasetAnalysis(object): 104 | def __init__(self): 105 | 106 | pass 107 | 108 | @staticmethod 109 | def do_analysis(dataset_obj): 110 | # 1. all sample classification distribution 111 | # 2. all sentence sample classification distribution 112 | sample_num = dataset_obj.sample_num 113 | collect = [] 114 | sent_collect = [] 115 | for trees in dataset_obj.nn_data: 116 | for sample in trees: 117 | sentiment_float = sample['root_node']['sentiment_label'] 118 | sentiment_int = cfg.sentiment_float_to_int(sentiment_float) 119 | if sample['is_sent']: 120 | sent_collect.append(sentiment_int) 121 | collect.append(sentiment_int) 122 | all_pdf = nltk.FreqDist(collect) 123 | sent_pdf = nltk.FreqDist(sent_collect) 124 | print('sample_num:', sample_num) 125 | print('all') 126 | print(all_pdf.tabulate()) 127 | print('sent') 128 | print(sent_pdf.tabulate()) 129 | 130 | 131 | 132 | if __name__ == '__main__': 133 | ifLoad, data = load_file(cfg.processed_path, 'processed data', 'pickle') 134 | assert ifLoad 135 | train_data_obj = data['train_data_obj'] 136 | dev_data_obj = data['dev_data_obj'] 137 | test_data_obj = data['test_data_obj'] 138 | 139 | DatasetAnalysis.do_analysis(train_data_obj) 140 | 141 | 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /SNLI_disan/src/nn_utils/general.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from functools import reduce 3 | from operator import mul 4 | 5 | VERY_BIG_NUMBER = 1e30 6 | VERY_SMALL_NUMBER = 1e-30 7 | VERY_POSITIVE_NUMBER = VERY_BIG_NUMBER 8 | VERY_NEGATIVE_NUMBER = -VERY_BIG_NUMBER 9 | 10 | 11 | def get_last_state(rnn_out_put, mask): # correct 12 | ''' 13 | get_last_state of rnn output 14 | :param rnn_out_put: [d1,d2,dn-1,max_len,d] 15 | :param mask: [d1,d2,dn-1,max_len] 16 | :return: [d1,d2,dn-1,d] 17 | ''' 18 | rnn_out_put_flatten = flatten(rnn_out_put, 2)# [X, ml, d] 19 | mask_flatten = flatten(mask,1) # [X,ml] 20 | idxs = tf.reduce_sum(tf.cast(mask_flatten,tf.int32),-1) - 1 # [X] 21 | indices = tf.stack([tf.range(tf.shape(idxs)[0]), idxs], axis=-1) #[X] => [X,2] 22 | flatten_res = tf.expand_dims(tf.gather_nd(rnn_out_put_flatten, indices),-2 )# #[x,d]->[x,1,d] 23 | return tf.squeeze(reconstruct(flatten_res,rnn_out_put,2),-2) #[d1,d2,dn-1,1,d] ->[d1,d2,dn-1,d] 24 | 25 | 26 | def expand_tile(tensor,pattern,tile_num = None, scope=None): # todo: add more func 27 | with tf.name_scope(scope or 'expand_tile'): 28 | assert isinstance(pattern,(tuple,list)) 29 | assert isinstance(tile_num,(tuple,list)) or tile_num is None 30 | assert len(pattern) == len(tile_num) or tile_num is None 31 | idx_pattern = list([(dim, p) for dim, p in enumerate(pattern)]) 32 | for dim,p in idx_pattern: 33 | if p == 'x': 34 | tensor = tf.expand_dims(tensor,dim) 35 | return tf.tile(tensor,tile_num) if tile_num is not None else tensor 36 | 37 | 38 | def get_initializer(matrix): 39 | def _initializer(shape, dtype=None, partition_info=None, **kwargs): return matrix 40 | return 
_initializer 41 | 42 | 43 | def mask(val, mask, name=None): 44 | if name is None: 45 | name = 'mask' 46 | return tf.multiply(val, tf.cast(mask, 'float'), name=name) 47 | 48 | 49 | def mask_for_high_rank(val, val_mask, name=None): 50 | val_mask = tf.expand_dims(val_mask, -1) 51 | return tf.multiply(val, tf.cast(val_mask, tf.float32), name=name or 'mask_for_high_rank') 52 | 53 | 54 | def exp_mask(val, mask, name=None): 55 | """Give very negative number to unmasked elements in val. 56 | For example, [-3, -2, 10], [True, True, False] -> [-3, -2, -1e9]. 57 | Typically, this effectively masks in exponential space (e.g. softmax) 58 | Args: 59 | val: values to be masked 60 | mask: masking boolean tensor, same shape as tensor 61 | name: name for output tensor 62 | 63 | Returns: 64 | Same shape as val, where some elements are very small (exponentially zero) 65 | """ 66 | if name is None: 67 | name = "exp_mask" 68 | return tf.add(val, (1 - tf.cast(mask, 'float')) * VERY_NEGATIVE_NUMBER, name=name) 69 | 70 | 71 | def exp_mask_for_high_rank(val, val_mask, name=None): 72 | val_mask = tf.expand_dims(val_mask, -1) 73 | return tf.add(val, (1 - tf.cast(val_mask, tf.float32)) * VERY_NEGATIVE_NUMBER, 74 | name=name or 'exp_mask_for_high_rank') 75 | 76 | 77 | def flatten(tensor, keep): 78 | fixed_shape = tensor.get_shape().as_list() 79 | start = len(fixed_shape) - keep 80 | left = reduce(mul, [fixed_shape[i] or tf.shape(tensor)[i] for i in range(start)]) 81 | out_shape = [left] + [fixed_shape[i] or tf.shape(tensor)[i] for i in range(start, len(fixed_shape))] 82 | flat = tf.reshape(tensor, out_shape) 83 | return flat 84 | 85 | 86 | def reconstruct(tensor, ref, keep, dim_reduced_keep=None): 87 | dim_reduced_keep = dim_reduced_keep or keep 88 | 89 | ref_shape = ref.get_shape().as_list() # original shape 90 | tensor_shape = tensor.get_shape().as_list() # current shape 91 | ref_stop = len(ref_shape) - keep # flatten dims list 92 | tensor_start = len(tensor_shape) - dim_reduced_keep # start 93 | pre_shape = [ref_shape[i] or tf.shape(ref)[i] for i in range(ref_stop)] # 94 | keep_shape = [tensor_shape[i] or tf.shape(tensor)[i] for i in range(tensor_start, len(tensor_shape))] # 95 | # pre_shape = [tf.shape(ref)[i] for i in range(len(ref.get_shape().as_list()[:-keep]))] 96 | # keep_shape = tensor.get_shape().as_list()[-keep:] 97 | target_shape = pre_shape + keep_shape 98 | out = tf.reshape(tensor, target_shape) 99 | return out 100 | 101 | 102 | def add_wd(wd, scope=None): 103 | scope = scope or tf.get_variable_scope().name 104 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) 105 | counter = 0 106 | with tf.name_scope("weight_decay"): 107 | for var in variables: 108 | counter+=1 109 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, 110 | name="{}-wd".format('-'.join(str(var.op.name).split('/')))) 111 | tf.add_to_collection('losses', weight_decay) 112 | return counter 113 | 114 | 115 | def add_wd_without_bias(wd, scope=None): 116 | scope = scope or tf.get_variable_scope().name 117 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) 118 | counter = 0 119 | with tf.name_scope("weight_decay"): 120 | for var in variables: 121 | if len(var.get_shape().as_list()) <= 1: continue 122 | counter += 1 123 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, 124 | name="{}-wd".format('-'.join(str(var.op.name).split('/')))) 125 | tf.add_to_collection('losses', weight_decay) 126 | return counter 127 | 128 | 129 | def add_reg_without_bias(scope=None): 130 | scope = scope or 
tf.get_variable_scope().name 131 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) 132 | counter = 0 133 | for var in variables: 134 | if len(var.get_shape().as_list()) <= 1: continue 135 | tf.add_to_collection('reg_vars', var) 136 | counter += 1 137 | return counter 138 | 139 | 140 | def add_var_reg(var): 141 | tf.add_to_collection('reg_vars', var) 142 | 143 | 144 | def add_wd_for_var(var, wd): 145 | with tf.name_scope("weight_decay"): 146 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, 147 | name="{}-wd".format('-'.join(str(var.op.name).split('/')))) 148 | tf.add_to_collection('losses', weight_decay) 149 | 150 | -------------------------------------------------------------------------------- /SST_disan/src/nn_utils/general.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from functools import reduce 3 | from operator import mul 4 | 5 | VERY_BIG_NUMBER = 1e30 6 | VERY_SMALL_NUMBER = 1e-30 7 | VERY_POSITIVE_NUMBER = VERY_BIG_NUMBER 8 | VERY_NEGATIVE_NUMBER = -VERY_BIG_NUMBER 9 | 10 | 11 | def get_last_state(rnn_out_put, mask): # correct 12 | ''' 13 | get_last_state of rnn output 14 | :param rnn_out_put: [d1,d2,dn-1,max_len,d] 15 | :param mask: [d1,d2,dn-1,max_len] 16 | :return: [d1,d2,dn-1,d] 17 | ''' 18 | rnn_out_put_flatten = flatten(rnn_out_put, 2)# [X, ml, d] 19 | mask_flatten = flatten(mask,1) # [X,ml] 20 | idxs = tf.reduce_sum(tf.cast(mask_flatten,tf.int32),-1) - 1 # [X] 21 | indices = tf.stack([tf.range(tf.shape(idxs)[0]), idxs], axis=-1) #[X] => [X,2] 22 | flatten_res = tf.expand_dims(tf.gather_nd(rnn_out_put_flatten, indices),-2 )# #[x,d]->[x,1,d] 23 | return tf.squeeze(reconstruct(flatten_res,rnn_out_put,2),-2) #[d1,d2,dn-1,1,d] ->[d1,d2,dn-1,d] 24 | 25 | 26 | def expand_tile(tensor,pattern,tile_num = None, scope=None): # todo: add more func 27 | with tf.name_scope(scope or 'expand_tile'): 28 | assert isinstance(pattern,(tuple,list)) 29 | assert isinstance(tile_num,(tuple,list)) or tile_num is None 30 | assert len(pattern) == len(tile_num) or tile_num is None 31 | idx_pattern = list([(dim, p) for dim, p in enumerate(pattern)]) 32 | for dim,p in idx_pattern: 33 | if p == 'x': 34 | tensor = tf.expand_dims(tensor,dim) 35 | return tf.tile(tensor,tile_num) if tile_num is not None else tensor 36 | 37 | 38 | def get_initializer(matrix): 39 | def _initializer(shape, dtype=None, partition_info=None, **kwargs): return matrix 40 | return _initializer 41 | 42 | 43 | def mask(val, mask, name=None): 44 | if name is None: 45 | name = 'mask' 46 | return tf.multiply(val, tf.cast(mask, 'float'), name=name) 47 | 48 | 49 | def mask_for_high_rank(val, val_mask, name=None): 50 | val_mask = tf.expand_dims(val_mask, -1) 51 | return tf.multiply(val, tf.cast(val_mask, tf.float32), name=name or 'mask_for_high_rank') 52 | 53 | 54 | def exp_mask(val, mask, name=None): 55 | """Give very negative number to unmasked elements in val. 56 | For example, [-3, -2, 10], [True, True, False] -> [-3, -2, -1e9]. 57 | Typically, this effectively masks in exponential space (e.g. 
softmax) 58 | Args: 59 | val: values to be masked 60 | mask: masking boolean tensor, same shape as tensor 61 | name: name for output tensor 62 | 63 | Returns: 64 | Same shape as val, where some elements are very small (exponentially zero) 65 | """ 66 | if name is None: 67 | name = "exp_mask" 68 | return tf.add(val, (1 - tf.cast(mask, 'float')) * VERY_NEGATIVE_NUMBER, name=name) 69 | 70 | 71 | def exp_mask_for_high_rank(val, val_mask, name=None): 72 | val_mask = tf.expand_dims(val_mask, -1) 73 | return tf.add(val, (1 - tf.cast(val_mask, tf.float32)) * VERY_NEGATIVE_NUMBER, 74 | name=name or 'exp_mask_for_high_rank') 75 | 76 | 77 | def flatten(tensor, keep): 78 | fixed_shape = tensor.get_shape().as_list() 79 | start = len(fixed_shape) - keep 80 | left = reduce(mul, [fixed_shape[i] or tf.shape(tensor)[i] for i in range(start)]) 81 | out_shape = [left] + [fixed_shape[i] or tf.shape(tensor)[i] for i in range(start, len(fixed_shape))] 82 | flat = tf.reshape(tensor, out_shape) 83 | return flat 84 | 85 | 86 | def reconstruct(tensor, ref, keep, dim_reduced_keep=None): 87 | dim_reduced_keep = dim_reduced_keep or keep 88 | 89 | ref_shape = ref.get_shape().as_list() # original shape 90 | tensor_shape = tensor.get_shape().as_list() # current shape 91 | ref_stop = len(ref_shape) - keep # flatten dims list 92 | tensor_start = len(tensor_shape) - dim_reduced_keep # start 93 | pre_shape = [ref_shape[i] or tf.shape(ref)[i] for i in range(ref_stop)] # 94 | keep_shape = [tensor_shape[i] or tf.shape(tensor)[i] for i in range(tensor_start, len(tensor_shape))] # 95 | # pre_shape = [tf.shape(ref)[i] for i in range(len(ref.get_shape().as_list()[:-keep]))] 96 | # keep_shape = tensor.get_shape().as_list()[-keep:] 97 | target_shape = pre_shape + keep_shape 98 | out = tf.reshape(tensor, target_shape) 99 | return out 100 | 101 | 102 | def add_wd(wd, scope=None): 103 | scope = scope or tf.get_variable_scope().name 104 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) 105 | counter = 0 106 | with tf.name_scope("weight_decay"): 107 | for var in variables: 108 | counter+=1 109 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, 110 | name="{}-wd".format('-'.join(str(var.op.name).split('/')))) 111 | tf.add_to_collection('losses', weight_decay) 112 | return counter 113 | 114 | 115 | def add_wd_without_bias(wd, scope=None): 116 | scope = scope or tf.get_variable_scope().name 117 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) 118 | counter = 0 119 | with tf.name_scope("weight_decay"): 120 | for var in variables: 121 | if len(var.get_shape().as_list()) <= 1: continue 122 | counter += 1 123 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, 124 | name="{}-wd".format('-'.join(str(var.op.name).split('/')))) 125 | tf.add_to_collection('losses', weight_decay) 126 | return counter 127 | 128 | 129 | def add_reg_without_bias(scope=None): 130 | scope = scope or tf.get_variable_scope().name 131 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) 132 | counter = 0 133 | for var in variables: 134 | if len(var.get_shape().as_list()) <= 1: continue 135 | tf.add_to_collection('reg_vars', var) 136 | counter += 1 137 | return counter 138 | 139 | 140 | def add_var_reg(var): 141 | tf.add_to_collection('reg_vars', var) 142 | 143 | 144 | def add_wd_for_var(var, wd): 145 | with tf.name_scope("weight_decay"): 146 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, 147 | name="{}-wd".format('-'.join(str(var.op.name).split('/')))) 148 | 
tf.add_to_collection('losses', weight_decay) 149 | 150 | -------------------------------------------------------------------------------- /SNLI_disan/src/utils/tree/shift_reduce.py: -------------------------------------------------------------------------------- 1 | 2 | def shift_reduce_constituency_forest(node_and_parent_idx_pairs): 3 | 4 | def get_idx_node_parent_pair(node_1base_idx): 5 | for idx_input,(node_idx,parent_idx) in enumerate(node_and_parent_idx_pairs): 6 | if node_idx == node_1base_idx: 7 | return idx_input,node_idx,parent_idx, 8 | raise RuntimeError( 'cannot find the node %d in node_and_parent_idx_pairs %s'\ 9 | %(node_1base_idx,str(node_and_parent_idx_pairs))) 10 | 11 | 12 | 13 | node_num = len(node_and_parent_idx_pairs) 14 | root_node_num = sum([1 for _,p in node_and_parent_idx_pairs if p == 0]) 15 | shifted = [0] * node_num 16 | children = [] 17 | parents = [] 18 | used = [] # 0 for un-used, 1 for used 19 | op_stack = [] # to restore the operation as the output 20 | 21 | while True: 22 | # check enough 0 23 | now_root_num = sum([1 for p in parents if p == 0]) 24 | if now_root_num == root_node_num: 25 | break 26 | 27 | # check whether reduce 28 | do_reduce = False 29 | try: 30 | last_idx = parents[-1] 31 | now_children_num = sum([1 for u, p_idx in zip(used, parents) if u == 0 and p_idx == last_idx]) 32 | fact_children_num = sum([1 for _, p_idx in node_and_parent_idx_pairs if p_idx == last_idx]) 33 | if now_children_num == fact_children_num: do_reduce = True 34 | #if last_idx == 0:do_reduce = False#FIXME :???WHY 35 | except IndexError: 36 | pass # len(parents) == 0 37 | # reduce or shift 38 | if do_reduce: 39 | reduce_idx = parents[-1] # reduce for reduce_idx node 40 | reduce_parent_idx = get_idx_node_parent_pair(reduce_idx)[2] # its parent node index 41 | # mark used and collect reduce_idxs 42 | reduce_idxs = [] 43 | for idx_n,p in enumerate(parents): # idx_n: index in stack mat; p:corresponding parent node idx 44 | if used[idx_n] == 0 and p == reduce_idx: 45 | used[idx_n] = 1 46 | reduce_idxs.append(idx_n) 47 | 48 | shifted[get_idx_node_parent_pair(reduce_idx)[0]] = 1 49 | children.append(reduce_idx) 50 | parents.append(reduce_parent_idx) 51 | used.append(0) 52 | op_stack.append((2, reduce_idx, reduce_idxs)) 53 | 54 | 55 | else: # do shift 56 | # get pointer 57 | pointer = 0 58 | for idx_s in range(node_num): 59 | if shifted[idx_s] == 0: 60 | pointer = idx_s 61 | shifted[idx_s] = 1 62 | break 63 | children.append(node_and_parent_idx_pairs[pointer][0]) 64 | parents.append(node_and_parent_idx_pairs[pointer][1]) 65 | used.append(0) 66 | # op_stack.append(-child_and_parent_idx[pointer][0]) 67 | op_stack.append((1, node_and_parent_idx_pairs[pointer][0], [])) 68 | assert len(op_stack) == len(node_and_parent_idx_pairs) 69 | return op_stack 70 | 71 | 72 | def shift_reduce_constitucy(parent_idx_seq): 73 | ''' 74 | This file is implemented to solve: tree sequence to constituency transition sequence: 75 | 76 | input is a list of father node idx of constituency tree (1 based) 77 | output is a list of element which is a one of 78 | [ 79 | 1. 1 for reduce and 2 for shift 80 | 2. 1 based node index 81 | 2. 
a list of 0-based index to reduce 82 | ] 83 | @ author: xx 84 | @ Email: xx@xxx 85 | ''' 86 | node_num = len(parent_idx_seq) # all node number in the parsing tree 87 | child_and_parent_idx = [(child_idx+1, parent_idx) # list of (child_idx, parent_idx) pair 88 | for child_idx,parent_idx in enumerate(parent_idx_seq)] 89 | # runtime variable: 90 | shifted = [0] * node_num # 0 for shifted and 0 for un-shifted 91 | children = [] 92 | parents = [] 93 | used = [] # 0 for un-used, 1 for used 94 | op_stack = [] # to restore the operation as the output 95 | 96 | while True: 97 | # check whether reduce 98 | do_reduce = False 99 | try: 100 | last_idx = parents[-1] 101 | # count 102 | count=sum([1 for u,p_idx in zip(used,parents) if u==0 and p_idx==last_idx ]) 103 | # check if count satisfy the number of 'last_idx''s children num 104 | children_num = sum([1 for _,p_idx in child_and_parent_idx if p_idx==last_idx]) 105 | if count == children_num: 106 | do_reduce = True 107 | except IndexError: 108 | pass # len(parents) == 0 109 | # reduce or shift 110 | if do_reduce: 111 | reduce_idx = parents[-1] 112 | reduce_parent_idx = child_and_parent_idx[reduce_idx-1][1] 113 | # mark used 114 | reduce_idxs = [] 115 | for idx_n,p in enumerate(parents): 116 | if used[idx_n] == 0 and p == reduce_idx: 117 | used[idx_n] = 1 118 | reduce_idxs.append(idx_n) 119 | 120 | shifted[reduce_idx-1] = 1 121 | children.append(reduce_idx) 122 | parents.append(reduce_parent_idx) 123 | used.append(0) 124 | op_stack.append((2,reduce_idx,reduce_idxs)) 125 | if reduce_parent_idx == 0: 126 | break 127 | else: # do shift 128 | # get pointer 129 | pointer = 0 130 | for idx_s in range(node_num): 131 | if shifted[idx_s] == 0: 132 | pointer = idx_s 133 | shifted[idx_s] = 1 134 | break 135 | 136 | children.append(child_and_parent_idx[pointer][0]) 137 | parents.append(child_and_parent_idx[pointer][1]) 138 | used.append(0) 139 | #op_stack.append(-child_and_parent_idx[pointer][0]) 140 | op_stack.append((1,child_and_parent_idx[pointer][0],[] )) 141 | 142 | return op_stack 143 | 144 | 145 | 146 | 147 | if __name__ == '__main__': 148 | seq_str = '19 19 22 23 24 25 26 27 27 28 30 31 31 32 33 34 35 35 20 21 0 20 22 23 24 25 28 29 26 29 30 21 32 33 34' 149 | seq = list(map(int,seq_str.split(' '))) 150 | 151 | print(shift_reduce_constitucy(seq)) 152 | 153 | 154 | -------------------------------------------------------------------------------- /SST_disan/src/utils/tree/shift_reduce.py: -------------------------------------------------------------------------------- 1 | 2 | def shift_reduce_constituency_forest(node_and_parent_idx_pairs): 3 | 4 | def get_idx_node_parent_pair(node_1base_idx): 5 | for idx_input,(node_idx,parent_idx) in enumerate(node_and_parent_idx_pairs): 6 | if node_idx == node_1base_idx: 7 | return idx_input,node_idx,parent_idx, 8 | raise RuntimeError( 'cannot find the node %d in node_and_parent_idx_pairs %s'\ 9 | %(node_1base_idx,str(node_and_parent_idx_pairs))) 10 | 11 | 12 | 13 | node_num = len(node_and_parent_idx_pairs) 14 | root_node_num = sum([1 for _,p in node_and_parent_idx_pairs if p == 0]) 15 | shifted = [0] * node_num 16 | children = [] 17 | parents = [] 18 | used = [] # 0 for un-used, 1 for used 19 | op_stack = [] # to restore the operation as the output 20 | 21 | while True: 22 | # check enough 0 23 | now_root_num = sum([1 for p in parents if p == 0]) 24 | if now_root_num == root_node_num: 25 | break 26 | 27 | # check whether reduce 28 | do_reduce = False 29 | try: 30 | last_idx = parents[-1] 31 | now_children_num = sum([1 for 
u, p_idx in zip(used, parents) if u == 0 and p_idx == last_idx]) 32 | fact_children_num = sum([1 for _, p_idx in node_and_parent_idx_pairs if p_idx == last_idx]) 33 | if now_children_num == fact_children_num: do_reduce = True 34 | #if last_idx == 0:do_reduce = False#FIXME :???WHY 35 | except IndexError: 36 | pass # len(parents) == 0 37 | # reduce or shift 38 | if do_reduce: 39 | reduce_idx = parents[-1] # reduce for reduce_idx node 40 | reduce_parent_idx = get_idx_node_parent_pair(reduce_idx)[2] # its parent node index 41 | # mark used and collect reduce_idxs 42 | reduce_idxs = [] 43 | for idx_n,p in enumerate(parents): # idx_n: index in stack mat; p:corresponding parent node idx 44 | if used[idx_n] == 0 and p == reduce_idx: 45 | used[idx_n] = 1 46 | reduce_idxs.append(idx_n) 47 | 48 | shifted[get_idx_node_parent_pair(reduce_idx)[0]] = 1 49 | children.append(reduce_idx) 50 | parents.append(reduce_parent_idx) 51 | used.append(0) 52 | op_stack.append((2, reduce_idx, reduce_idxs)) 53 | 54 | 55 | else: # do shift 56 | # get pointer 57 | pointer = 0 58 | for idx_s in range(node_num): 59 | if shifted[idx_s] == 0: 60 | pointer = idx_s 61 | shifted[idx_s] = 1 62 | break 63 | children.append(node_and_parent_idx_pairs[pointer][0]) 64 | parents.append(node_and_parent_idx_pairs[pointer][1]) 65 | used.append(0) 66 | # op_stack.append(-child_and_parent_idx[pointer][0]) 67 | op_stack.append((1, node_and_parent_idx_pairs[pointer][0], [])) 68 | assert len(op_stack) == len(node_and_parent_idx_pairs) 69 | return op_stack 70 | 71 | 72 | def shift_reduce_constitucy(parent_idx_seq): 73 | ''' 74 | This file is implemented to solve: tree sequence to constituency transition sequence: 75 | 76 | input is a list of father node idx of constituency tree (1 based) 77 | output is a list of element which is a one of 78 | [ 79 | 1. 1 for reduce and 2 for shift 80 | 2. 1 based node index 81 | 2. 
a list of 0-based index to reduce 82 | ] 83 | @ author: xx 84 | @ Email: xx@xxx 85 | ''' 86 | node_num = len(parent_idx_seq) # all node number in the parsing tree 87 | child_and_parent_idx = [(child_idx+1, parent_idx) # list of (child_idx, parent_idx) pair 88 | for child_idx,parent_idx in enumerate(parent_idx_seq)] 89 | # runtime variable: 90 | shifted = [0] * node_num # 0 for shifted and 0 for un-shifted 91 | children = [] 92 | parents = [] 93 | used = [] # 0 for un-used, 1 for used 94 | op_stack = [] # to restore the operation as the output 95 | 96 | while True: 97 | # check whether reduce 98 | do_reduce = False 99 | try: 100 | last_idx = parents[-1] 101 | # count 102 | count=sum([1 for u,p_idx in zip(used,parents) if u==0 and p_idx==last_idx ]) 103 | # check if count satisfy the number of 'last_idx''s children num 104 | children_num = sum([1 for _,p_idx in child_and_parent_idx if p_idx==last_idx]) 105 | if count == children_num: 106 | do_reduce = True 107 | except IndexError: 108 | pass # len(parents) == 0 109 | # reduce or shift 110 | if do_reduce: 111 | reduce_idx = parents[-1] 112 | reduce_parent_idx = child_and_parent_idx[reduce_idx-1][1] 113 | # mark used 114 | reduce_idxs = [] 115 | for idx_n,p in enumerate(parents): 116 | if used[idx_n] == 0 and p == reduce_idx: 117 | used[idx_n] = 1 118 | reduce_idxs.append(idx_n) 119 | 120 | shifted[reduce_idx-1] = 1 121 | children.append(reduce_idx) 122 | parents.append(reduce_parent_idx) 123 | used.append(0) 124 | op_stack.append((2,reduce_idx,reduce_idxs)) 125 | if reduce_parent_idx == 0: 126 | break 127 | else: # do shift 128 | # get pointer 129 | pointer = 0 130 | for idx_s in range(node_num): 131 | if shifted[idx_s] == 0: 132 | pointer = idx_s 133 | shifted[idx_s] = 1 134 | break 135 | 136 | children.append(child_and_parent_idx[pointer][0]) 137 | parents.append(child_and_parent_idx[pointer][1]) 138 | used.append(0) 139 | #op_stack.append(-child_and_parent_idx[pointer][0]) 140 | op_stack.append((1,child_and_parent_idx[pointer][0],[] )) 141 | 142 | return op_stack 143 | 144 | 145 | 146 | 147 | if __name__ == '__main__': 148 | seq_str = '19 19 22 23 24 25 26 27 27 28 30 31 31 32 33 34 35 35 20 21 0 20 22 23 24 25 28 29 26 29 30 21 32 33 34' 149 | seq = list(map(int,seq_str.split(' '))) 150 | 151 | print(shift_reduce_constitucy(seq)) 152 | 153 | 154 | -------------------------------------------------------------------------------- /Fast-DiSA/README.md: -------------------------------------------------------------------------------- 1 | # Fast Directional Self-Attention (Fast-DiSA) Mechanism 2 | 3 | This is the Tensorflow implementation for **Fast-DiSA** and **Stacking Fast-DiSA** that is a time-efficient and memory-friendly multi-dim token2token self-attention mechanism for context fusion. Fast-DiSA can be regarded as a alternative to RNNs and multi-head self-attention. 4 | 5 | The details of this model are elaborated in [this paper](https://arxiv.org/abs/1805.00912). 6 | 7 | # How to Use 8 | just download [*fast_disa.py*](https://github.com/taoshen58/DiSAN/tree/master/Fast-DiSA/fast_disa.py) and add the line below to your model script file: 9 | 10 | from fast_disa import fast_directional_self_attention, mask_ft_generation, stacking_fast_directional_self_attention 11 | 12 | 13 | Then, follow the API below. 14 | 15 | 16 | ## API 17 | 18 | ### For `fast_directional_self_attention` 19 | 20 | The general API for Fast Self-Attention Attention mechanism for context fusion. 
21 | :param rep_tensor: tf.float32-[batch_size,seq_len,channels], input sequence tensor;
22 | :param rep_mask: tf.bool-[batch_size,seq_len], mask to indicate padding or not for "rep_tensor";
23 | :param hn: int32-[], hidden unit number for this attention module;
24 | :param head_num: int32-[]; multi-head number; if "use_direction" is set to True, this must be an even number,
25 | i.e., half of the heads for the forward direction and the remaining ones for the backward direction;
26 | :param is_train: tf.bool-[]; this arg must be a TensorFlow Placeholder or Tensor. This is useful if you build
27 | a graph for both training and testing: you can create a Placeholder to indicate training (True) or testing (False)
28 | and pass the Placeholder into this method;
29 | :param attn_keep_prob: float-[], the value must be in [0.0, 1.0]; this keep probability is for attention dropout;
30 | :param dense_keep_prob: float-[], the value must be in [0.0, 1.0]; this keep probability is for dense-layer dropout;
31 | :param wd: float-[], if you use L2-reg, set this value to be greater than 0., which causes the
32 | trainable parameters (without biases) to be added to a tensorflow collection named "reg_vars";
33 | :param use_direction: bool-[], for mask generation, use forward and backward direction masks or not;
34 | :param attn_self: bool-[], for mask generation, include attention over self or not;
35 | :param use_fusion_gate: bool-[], use a fusion gate to dynamically combine the attention results with the input or not;
36 | :param final_mask_ft: None/tf.float-[head_num,batch_size,seq_len,seq_len], the value is either 0 (disabled) or
37 | 1 (enabled); set it to None if you only use a single layer of this method; use the *mask_ft_generation* method
38 | to generate one and pass it into this method if you want to stack this module, which saves computation resources;
39 | :param dot_activation_name: str-[], "exp" or "sigmoid", the activation function name for the dot-product
40 | self-attention logits;
41 | :param use_input_for_attn: bool-[], if True, use *rep_tensor* to compute the dot-product and s2t multi-dim self-attn
42 | alignment scores; if False, use a tensor obtained by applying a dense layer to the *rep_tensor*, which adds
43 | non-linearity to this layer;
44 | :param add_layer_for_multi: bool-[], if True, add a dense layer with the activation function "activation_func_name"
45 | to calculate the s2t multi-dim self-attention alignment score;
46 | :param activation_func_name: str-[], activation function name; commonly used: "relu", "elu", "selu";
47 | :param apply_act_for_v: bool-[], whether to apply the non-linearity activation function ("activation_func_name") to
48 | the value map (same as the value map in multi-head attention);
49 |
50 | :param input_hn: None/int32-[], if not None, add an extra dense layer (unit num is "input_hn") with
51 | activation function ("activation_func_name") before the attention, without consideration of multi-head;
52 | :param output_hn: None/int32-[], if not None, add an extra dense layer (unit num is "output_hn") with
53 | activation function ("activation_func_name") after the attention, without consideration of multi-head;
54 | :param accelerate: bool-[], for model acceleration, some matrix multiplications are optimized and combined when this
55 | is set to True, which may affect dropout-sensitive models or tasks.
56 | :param merge_var: bool-[], because batch matmul is used to parallelize the multi-head attention, if True, the
57 | trainable variables are declared and defined together; otherwise, they are defined separately and then combined.
58 | :param scope: None/str-[], variable scope name.
59 | :return: tf.float32-[batch_size, sequence_length, out_hn]; if output_hn is not None, then out_hn = "output_hn",
60 | otherwise out_hn = "hn"
61 |
62 | ### For `stacking_fast_directional_self_attention`
63 | Stacked Fast-DiSA.
64 | :param rep_tensor: same as that in Fast-DiSA;
65 | :param rep_mask: same as that in Fast-DiSA;
66 | :param hn: same as that in Fast-DiSA;
67 | :param head_num: same as that in Fast-DiSA;
68 | :param is_train: same as that in Fast-DiSA;
69 | :param residual_keep_prob: float-[], dropout keep probability for the residual connection;
70 | :param attn_keep_prob: same as that in Fast-DiSA;
71 | :param dense_keep_prob: same as that in Fast-DiSA;
72 | :param wd: same as that in Fast-DiSA;
73 | :param use_direction: same as that in Fast-DiSA;
74 | :param attn_self: same as that in Fast-DiSA;
75 | :param activation_func_name: same as that in Fast-DiSA;
76 | :param dot_activation_name: same as that in Fast-DiSA;
77 | :param layer_num: int-[], the number of layers stacked;
78 | :param scope: str-[], scope name
79 | :return: tf.float32-[batch_size, sequence_length, hn]
80 |
81 |
82 |
83 | ## Hyper-Parameters Suggestion
84 |
85 | * param `accelerate` should be set to `False` if the model or task is dropout-sensitive.
86 | * param `use_direction` should be set to `False` when the input is not order-related.
87 | * The hyper-parameter choices for both the single layer and the stacked layers are detailed in the table below. These choices aim at the fastest and simplest Fast-DiSA model for stacking. Note that the suggested hyper-parameters below are only for `fast_directional_self_attention`; they are already applied inside `stacking_fast_directional_self_attention`. A minimal usage sketch with these settings is given at the end of this README.
88 |
89 | | Hyper-Params | For Single Layer | For Stacked Layer |
90 | | --- | --- | --- |
91 | | use_direction | True | True |
92 | | attn_self | False | False (Depends on task) |
93 | | use_fusion_gate | True | False |
94 | | final_mask_ft | None | invoke `mask_ft_generation` |
95 | | dot_activation_name | 'sigmoid' | 'exp' |
96 | | use_input_for_attn | False | True |
97 | | add_layer_for_multi | True | False |
98 | | activation_func_name | as you prefer | as you prefer |
99 | | apply_act_for_v | True | False |
100 | | input_hn | None | None |
101 | | output_hn | None | Set to `hn` for residual connection |
102 | | accelerate | False | True |
103 | | merge_var | False | True |
104 |
105 | ## TODO List
106 | 1. ~~release the single layer version of Fast-DiSA~~
107 | 2. ~~release the stacking version of Fast-DiSA (like the deep model in the Transformer)~~
108 | 3. Project code for some NLP tasks
109 |
110 | ## Contact Information
111 | Email: [tao.shen@student.uts.edu.au](mailto:tao.shen@student.uts.edu.au)
112 | Feel free to contact me or open an issue if you have any question or encounter any bug!
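For reference, here is a minimal usage sketch of `fast_directional_self_attention` with the single-layer settings from the table above. The shapes, hidden sizes, head number, and dropout/weight-decay values are illustrative assumptions, not prescribed by this repository:

    import tensorflow as tf
    from fast_disa import fast_directional_self_attention

    # Illustrative inputs: a batch of 300-d token embeddings plus a boolean padding mask.
    rep_tensor = tf.placeholder(tf.float32, [None, None, 300], 'rep_tensor')  # [batch_size, seq_len, channels]
    rep_mask = tf.placeholder(tf.bool, [None, None], 'rep_mask')              # [batch_size, seq_len]
    is_train = tf.placeholder(tf.bool, [], 'is_train')

    # Single-layer hyper-parameters suggested above; head_num is even because use_direction=True.
    context_rep = fast_directional_self_attention(
        rep_tensor, rep_mask, hn=300, head_num=8, is_train=is_train,
        attn_keep_prob=1.0, dense_keep_prob=0.9, wd=5e-5,
        use_direction=True, attn_self=False, use_fusion_gate=True,
        final_mask_ft=None, dot_activation_name='sigmoid',
        use_input_for_attn=False, add_layer_for_multi=True,
        activation_func_name='elu', apply_act_for_v=True,
        input_hn=None, output_hn=None, accelerate=False, merge_var=False,
        scope='fast_disa')  # -> [batch_size, seq_len, hn]

The stacked variant, `stacking_fast_directional_self_attention`, already applies the right-hand column of the table internally, as noted above.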
113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /SST_disan/src/evaluator.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import numpy as np 4 | import tensorflow as tf 5 | from src.analysis import OutputAnalysis 6 | import os, shutil 7 | 8 | class Evaluator(object): 9 | def __init__(self, model): 10 | self.model = model 11 | self.global_step = model.global_step 12 | 13 | ## ---- summary---- 14 | self.build_summary() 15 | self.writer = tf.summary.FileWriter(cfg.summary_dir) 16 | 17 | # --- external use --- 18 | def get_evaluation(self, sess, dataset_obj, global_step=None): 19 | _logger.add() 20 | _logger.add('getting evaluation result for %s' % dataset_obj.data_type) 21 | 22 | logits_list, loss_list, accu_list = [], [], [] 23 | is_sent_list = [] 24 | for sample_batch, _, _, _ in dataset_obj.generate_batch_sample_iter(): 25 | feed_dict = self.model.get_feed_dict(sample_batch, 'dev') 26 | logits, loss, accu = sess.run([self.model.logits, 27 | self.model.loss, self.model.accuracy], feed_dict) 28 | logits_list.append(np.argmax(logits, -1)) 29 | loss_list.append(loss) 30 | accu_list.append(accu) 31 | is_sent_list += [sample['is_sent'] for sample in sample_batch] 32 | logits_array = np.concatenate(logits_list, 0) 33 | loss_value = np.mean(loss_list) 34 | accu_array = np.concatenate(accu_list, 0) 35 | accu_value =np.mean(accu_array) 36 | sent_accu_list = [] 37 | for idx, is_sent in enumerate(is_sent_list): 38 | if is_sent: 39 | sent_accu_list.append(accu_array[idx]) 40 | sent_accu_value = np.mean(sent_accu_list) 41 | 42 | # analysis 43 | # analysis_save_dir = cfg.mkdir(cfg.answer_dir,'gs_%s'%global_step or 'test') 44 | # OutputAnalysis.do_analysis(dataset_obj, logits_array, accu_array, analysis_save_dir, 45 | # cfg.fine_grained) 46 | 47 | # add summary 48 | if global_step is not None: 49 | if dataset_obj.data_type == 'train': 50 | summary_feed_dict = { 51 | self.train_loss: loss_value, 52 | self.train_accuracy: accu_value, 53 | self.train_sent_accuracy: sent_accu_value, 54 | } 55 | summary = sess.run(self.train_summaries, summary_feed_dict) 56 | self.writer.add_summary(summary, global_step) 57 | elif dataset_obj.data_type == 'dev': 58 | summary_feed_dict = { 59 | self.dev_loss: loss_value, 60 | self.dev_accuracy: accu_value, 61 | self.dev_sent_accuracy: sent_accu_value, 62 | } 63 | summary = sess.run(self.dev_summaries, summary_feed_dict) 64 | self.writer.add_summary(summary, global_step) 65 | else: 66 | summary_feed_dict = { 67 | self.test_loss: loss_value, 68 | self.test_accuracy: accu_value, 69 | self.test_sent_accuracy: sent_accu_value, 70 | } 71 | summary = sess.run(self.test_summaries, summary_feed_dict) 72 | self.writer.add_summary(summary, global_step) 73 | return loss_value, accu_value, sent_accu_value 74 | 75 | def get_evaluation_file_output(self, sess, dataset_obj, global_step, deleted_step): 76 | _logger.add() 77 | _logger.add('get evaluation file output for %s' % dataset_obj.data_type) 78 | # delete old file 79 | if deleted_step is not None: 80 | delete_name = 'gs_%d' % deleted_step 81 | delete_path = os.path.join(cfg.answer_dir, delete_name) 82 | if os.path.exists(delete_path): 83 | shutil.rmtree(delete_path) 84 | _logger.add() 85 | _logger.add('getting evaluation result for %s' % dataset_obj.data_type) 86 | 87 | logits_list, loss_list, accu_list = [], [], [] 88 | is_sent_list = [] 89 | for sample_batch, _, _, 
_ in dataset_obj.generate_batch_sample_iter(): 90 | feed_dict = self.model.get_feed_dict(sample_batch, 'dev') 91 | logits, loss, accu = sess.run([self.model.logits, 92 | self.model.loss, self.model.accuracy], feed_dict) 93 | logits_list.append(np.argmax(logits, -1)) 94 | loss_list.append(loss) 95 | accu_list.append(accu) 96 | is_sent_list += [sample['is_sent'] for sample in sample_batch] 97 | logits_array = np.concatenate(logits_list, 0) 98 | loss_value = np.mean(loss_list) 99 | accu_array = np.concatenate(accu_list, 0) 100 | accu_value = np.mean(accu_array) 101 | sent_accu_list = [] 102 | for idx, is_sent in enumerate(is_sent_list): 103 | if is_sent: 104 | sent_accu_list.append(accu_array[idx]) 105 | sent_accu_value = np.mean(sent_accu_list) 106 | 107 | # analysis 108 | analysis_save_dir = cfg.mkdir(cfg.answer_dir,'gs_%s'%global_step or 'test') 109 | OutputAnalysis.do_analysis(dataset_obj, logits_array, accu_array, analysis_save_dir, 110 | cfg.fine_grained) 111 | 112 | 113 | # --- internal use ------ 114 | def build_summary(self): 115 | with tf.name_scope('train_summaries'): 116 | self.train_loss = tf.placeholder(tf.float32, [], 'train_loss') 117 | self.train_accuracy = tf.placeholder(tf.float32, [], 'train_accuracy') 118 | self.train_sent_accuracy = tf.placeholder(tf.float32, [], 'train_sent_accuracy') 119 | tf.add_to_collection('train_summaries_collection', tf.summary.scalar('train_loss', self.train_loss)) 120 | tf.add_to_collection('train_summaries_collection', tf.summary.scalar('train_accuracy', self.train_accuracy)) 121 | tf.add_to_collection('train_summaries_collection', tf.summary.scalar('train_sent_accuracy', 122 | self.train_sent_accuracy)) 123 | self.train_summaries = tf.summary.merge_all('train_summaries_collection') 124 | 125 | with tf.name_scope('dev_summaries'): 126 | self.dev_loss = tf.placeholder(tf.float32, [], 'dev_loss') 127 | self.dev_accuracy = tf.placeholder(tf.float32, [], 'dev_accuracy') 128 | self.dev_sent_accuracy = tf.placeholder(tf.float32, [], 'dev_sent_accuracy') 129 | tf.add_to_collection('dev_summaries_collection', tf.summary.scalar('dev_loss',self.dev_loss)) 130 | tf.add_to_collection('dev_summaries_collection', tf.summary.scalar('dev_accuracy',self.dev_accuracy)) 131 | tf.add_to_collection('dev_summaries_collection', tf.summary.scalar('dev_sent_accuracy', 132 | self.dev_sent_accuracy)) 133 | self.dev_summaries = tf.summary.merge_all('dev_summaries_collection') 134 | 135 | with tf.name_scope('test_summaries'): 136 | self.test_loss = tf.placeholder(tf.float32, [], 'test_loss') 137 | self.test_accuracy = tf.placeholder(tf.float32, [], 'test_accuracy') 138 | self.test_sent_accuracy = tf.placeholder(tf.float32, [], 'test_sent_accuracy') 139 | tf.add_to_collection('test_summaries_collection', tf.summary.scalar('test_loss',self.test_loss)) 140 | tf.add_to_collection('test_summaries_collection', tf.summary.scalar('test_accuracy',self.test_accuracy)) 141 | tf.add_to_collection('test_summaries_collection', tf.summary.scalar('test_sent_accuracy', 142 | self.test_sent_accuracy)) 143 | self.test_summaries = tf.summary.merge_all('test_summaries_collection') 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /SST_disan/src/nn_utils/integration.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from src.nn_utils.rnn import dynamic_rnn, bidirectional_dynamic_rnn 3 | from src.nn_utils.rnn_cell import SwitchableDropoutWrapper 4 | from src.nn_utils.general 
import get_last_state, add_reg_without_bias 5 | from src.nn_utils.nn import highway_network,multi_conv1d 6 | 7 | 8 | def contextual_bi_rnn(tensor_rep, mask_rep, hn, cell_type, only_final=False, 9 | wd=0., keep_prob=1.,is_train=None, scope=None): 10 | """ 11 | fusing contextual information using bi-direction rnn 12 | :param tensor_rep: [..., sl, vec] 13 | :param mask_rep: [..., sl] 14 | :param hn: 15 | :param cell_type: 'gru', 'lstm', basic_lstm' and 'basic_rnn' 16 | :param only_final: True or False 17 | :param wd: 18 | :param keep_prob: 19 | :param is_train: 20 | :param scope: 21 | :return: 22 | """ 23 | with tf.variable_scope(scope or 'contextual_bi_rnn'): # correct 24 | reuse = None if not tf.get_variable_scope().reuse else True 25 | #print(reuse) 26 | if cell_type == 'gru': 27 | cell_fw = tf.contrib.rnn.GRUCell(hn, reuse=reuse) 28 | cell_bw = tf.contrib.rnn.GRUCell(hn, reuse=reuse) 29 | elif cell_type == 'lstm': 30 | cell_fw = tf.contrib.rnn.LSTMCell(hn, reuse=reuse) 31 | cell_bw = tf.contrib.rnn.LSTMCell(hn, reuse=reuse) 32 | elif cell_type == 'basic_lstm': 33 | cell_fw = tf.contrib.rnn.BasicLSTMCell(hn, reuse=reuse) 34 | cell_bw = tf.contrib.rnn.BasicLSTMCell(hn, reuse=reuse) 35 | elif cell_type == 'basic_rnn': 36 | cell_fw = tf.contrib.rnn.BasicRNNCell(hn, reuse=reuse) 37 | cell_bw = tf.contrib.rnn.BasicRNNCell(hn, reuse=reuse) 38 | else: 39 | raise AttributeError('no cell type \'%s\'' % cell_type) 40 | cell_dp_fw = SwitchableDropoutWrapper(cell_fw,is_train,keep_prob) 41 | cell_dp_bw = SwitchableDropoutWrapper(cell_bw,is_train,keep_prob) 42 | 43 | tensor_len = tf.reduce_sum(tf.cast(mask_rep, tf.int32), -1) # [bs] 44 | 45 | (outputs_fw, output_bw), _=bidirectional_dynamic_rnn( 46 | cell_dp_fw, cell_dp_bw, tensor_rep, tensor_len, 47 | dtype=tf.float32) 48 | rnn_outputs = tf.concat([outputs_fw,output_bw],-1) # [...,sl,2hn] 49 | 50 | if wd > 0: 51 | add_reg_without_bias() 52 | if not only_final: 53 | return rnn_outputs # [....,sl, 2hn] 54 | else: 55 | return get_last_state(rnn_outputs, mask_rep) # [...., 2hn] 56 | 57 | 58 | def one_direction_rnn(tensor_rep, mask_rep, hn, cell_type, only_final=False, 59 | wd=0., keep_prob=1.,is_train=None, is_forward = True,scope=None): 60 | assert not is_forward # todo: waiting to be implemented 61 | with tf.variable_scope(scope or '%s_rnn' % 'forward' if is_forward else 'backward'): 62 | reuse = None if not tf.get_variable_scope().reuse else True 63 | # print(reuse) 64 | if cell_type == 'gru': 65 | cell = tf.contrib.rnn.GRUCell(hn, reuse=reuse) 66 | elif cell_type == 'lstm': 67 | cell = tf.contrib.rnn.LSTMCell(hn, reuse=reuse) 68 | elif cell_type == 'basic_lstm': 69 | cell = tf.contrib.rnn.BasicLSTMCell(hn, reuse=reuse) 70 | elif cell_type == 'basic_rnn': 71 | cell = tf.contrib.rnn.BasicRNNCell(hn, reuse=reuse) 72 | else: 73 | raise AttributeError('no cell type \'%s\'' % cell_type) 74 | cell_dp = SwitchableDropoutWrapper(cell, is_train, keep_prob) 75 | 76 | tensor_len = tf.reduce_sum(tf.cast(mask_rep, tf.int32), -1) # [bs] 77 | 78 | rnn_outputs, _ = dynamic_rnn( 79 | cell_dp, tensor_rep, tensor_len, 80 | dtype=tf.float32) 81 | 82 | if wd > 0: 83 | add_reg_without_bias() 84 | if not only_final: 85 | return rnn_outputs # [....,sl, 2hn] 86 | else: 87 | return get_last_state(rnn_outputs, mask_rep) # [...., 2hn] 88 | 89 | 90 | def generate_embedding_mat(dict_size, emb_len, init_mat=None, extra_mat=None, 91 | extra_trainable=False, scope=None): 92 | """ 93 | generate embedding matrix for looking up 94 | :param dict_size: indices 0 and 1 corresponding to empty 
and unknown token 95 | :param emb_len: 96 | :param init_mat: init mat matching for [dict_size, emb_len] 97 | :param extra_mat: extra tensor [extra_dict_size, emb_len] 98 | :param extra_trainable: 99 | :param scope: 100 | :return: if extra_mat is None, return[dict_size+extra_dict_size,emb_len], else [dict_size,emb_len] 101 | """ 102 | with tf.variable_scope(scope or 'gene_emb_mat'): 103 | emb_mat_ept_and_unk = tf.constant(value=0, dtype=tf.float32, shape=[2, emb_len]) 104 | if init_mat is None: 105 | emb_mat_other = tf.get_variable('emb_mat',[dict_size - 2, emb_len], tf.float32) 106 | else: 107 | emb_mat_other = tf.get_variable("emb_mat",[dict_size - 2, emb_len], tf.float32, 108 | initializer=tf.constant_initializer(init_mat[2:], dtype=tf.float32, 109 | verify_shape=True)) 110 | emb_mat = tf.concat([emb_mat_ept_and_unk, emb_mat_other], 0) 111 | 112 | if extra_mat is not None: 113 | if extra_trainable: 114 | extra_mat_var = tf.get_variable("extra_emb_mat",extra_mat.shape, tf.float32, 115 | initializer=tf.constant_initializer(extra_mat, 116 | dtype=tf.float32, 117 | verify_shape=True)) 118 | return tf.concat([emb_mat, extra_mat_var], 0) 119 | else: 120 | #with tf.device('/cpu:0'): 121 | extra_mat_con = tf.constant(extra_mat, dtype=tf.float32) 122 | return tf.concat([emb_mat, extra_mat_con], 0) 123 | else: 124 | return emb_mat 125 | 126 | 127 | def token_and_char_emb(if_token_emb=True, context_token=None, tds=None, tel=None, 128 | token_emb_mat=None, glove_emb_mat=None, 129 | if_char_emb=True, context_char=None, cds=None, cel=None, 130 | cos=None, ocd=None, fh=None, use_highway=True,highway_layer_num=None, 131 | wd=0., keep_prob=1., is_train=None): 132 | with tf.variable_scope('token_and_char_emb'): 133 | if if_token_emb: 134 | with tf.variable_scope('token_emb'): 135 | token_emb_mat = generate_embedding_mat(tds, tel, init_mat=token_emb_mat, 136 | extra_mat=glove_emb_mat, 137 | scope='gene_token_emb_mat') 138 | 139 | c_token_emb = tf.nn.embedding_lookup(token_emb_mat, context_token) # bs,sl,tel 140 | 141 | if if_char_emb: 142 | with tf.variable_scope('char_emb'): 143 | char_emb_mat = generate_embedding_mat(cds, cel, scope='gene_char_emb_mat') 144 | c_char_lu_emb = tf.nn.embedding_lookup(char_emb_mat, context_char) # bs,sl,tl,cel 145 | 146 | assert sum(ocd) == cos and len(ocd) == len(fh) 147 | 148 | with tf.variable_scope('conv'): 149 | c_char_emb = multi_conv1d(c_char_lu_emb, ocd, fh, "VALID", 150 | is_train, keep_prob, scope="xx") # bs,sl,cocn 151 | if if_token_emb and if_char_emb: 152 | c_emb = tf.concat([c_token_emb, c_char_emb], -1) # bs,sl,cocn+tel 153 | elif if_token_emb: 154 | c_emb = c_token_emb 155 | elif if_char_emb: 156 | c_emb = c_char_emb 157 | else: 158 | raise AttributeError('No embedding!') 159 | 160 | if use_highway: 161 | with tf.variable_scope('highway'): 162 | c_emb = highway_network(c_emb, highway_layer_num, True, wd=wd, 163 | input_keep_prob=keep_prob,is_train=is_train) 164 | return c_emb 165 | 166 | 167 | def generate_feature_emb_for_c_and_q(feature_dict_size, feature_emb_len, 168 | feature_name , c_feature, q_feature=None, scope=None): 169 | with tf.variable_scope(scope or '%s_feature_emb' % feature_name): 170 | emb_mat = generate_embedding_mat(feature_dict_size, feature_emb_len, scope='emb_mat') 171 | c_feature_emb = tf.nn.embedding_lookup(emb_mat, c_feature) 172 | if q_feature is not None: 173 | q_feature_emb = tf.nn.embedding_lookup(emb_mat, q_feature) 174 | else: 175 | q_feature_emb = None 176 | return c_feature_emb, q_feature_emb 177 | 
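As a quick illustration of how the embedding and context-fusion helpers in this file fit together, here is a minimal sketch. It assumes the project layout makes `src.nn_utils.integration` importable (as in the repository's own imports); the vocabulary size, embedding length, hyper-parameter values, and the random stand-in for the pre-trained GloVe block are illustrative assumptions only:

    import numpy as np
    import tensorflow as tf
    from src.nn_utils.integration import generate_embedding_mat, contextual_bi_rnn

    token_ids = tf.placeholder(tf.int32, [None, None], 'token_ids')  # [bs, sl]; ids 0/1 mean empty/unknown
    token_mask = tf.cast(token_ids, tf.bool)                         # non-zero ids are real tokens
    is_train = tf.placeholder(tf.bool, [], 'is_train')

    # Trainable dictionary of 5000 tokens (ids 0 and 1 reserved) with a frozen extra block
    # appended at the end; the random matrix here only stands in for real GloVe vectors.
    glove_block = np.random.randn(1000, 300).astype('float32')
    emb_mat = generate_embedding_mat(5000, 300, extra_mat=glove_block,
                                     extra_trainable=False, scope='gene_token_emb_mat')
    token_emb = tf.nn.embedding_lookup(emb_mat, token_ids)           # [bs, sl, 300]

    # Bi-directional GRU context fusion; returns [bs, sl, 2*hn] (or [bs, 2*hn] with only_final=True).
    context_rep = contextual_bi_rnn(token_emb, token_mask, hn=150, cell_type='gru',
                                    wd=1e-4, keep_prob=0.8, is_train=is_train,
                                    scope='context_fusion')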
-------------------------------------------------------------------------------- /SNLI_disan/configs.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import argparse 3 | import os 4 | from os.path import join 5 | from src.utils.time_counter import TimeCounter 6 | 7 | 8 | class Configs(object): 9 | def __init__(self): 10 | self.project_dir = os.getcwd() 11 | self.dataset_dir = join(self.project_dir, 'dataset') 12 | 13 | # ------parsing input arguments"-------- 14 | parser = argparse.ArgumentParser() 15 | parser.register('type', 'bool', (lambda x: x.lower() in ("yes", "true", "t", "1"))) 16 | 17 | # @ ----- control ---- 18 | parser.add_argument('--debug', type='bool', default=False, help='whether run as debug mode') 19 | parser.add_argument('--mode', type=str, default='train', help='train, dev or test') 20 | parser.add_argument('--network_type', type=str, default='test', help='network type') 21 | parser.add_argument('--log_period', type=int, default=2000, help='save tf summary period') 22 | parser.add_argument('--save_period', type=int, default=3000, help='abandoned') 23 | parser.add_argument('--eval_period', type=int, default=500, help='evaluation period') 24 | parser.add_argument('--gpu', type=int, default=0, help='employed gpu index') 25 | parser.add_argument('--gpu_mem', type=float, default=0.96, help='gpu memory ratio to employ') 26 | parser.add_argument('--model_dir_suffix', type=str, default='', help='model folder name suffix') 27 | parser.add_argument('--swap_memory', type='bool', default=False, help='abandoned') 28 | parser.add_argument('--load_model', type='bool', default=False, help='do not use') 29 | parser.add_argument('--load_step', type=int, default=None, help='do not use') 30 | parser.add_argument('--load_path', type=str, default=None, help='specify which pre-trianed model to be load') 31 | 32 | # @ ----------training ------ 33 | parser.add_argument('--max_epoch', type=int, default=100, help='max epoch number') 34 | parser.add_argument('--num_steps', type=int, default=400000, help='max steps num') 35 | parser.add_argument('--train_batch_size', type=int, default=64, help='Train Batch Size') 36 | parser.add_argument('--test_batch_size', type=int, default=100, help='Test Batch Size') 37 | parser.add_argument('--optimizer', type=str, default='adadelta', help='choose an optimizer[adadelta|adam]') 38 | parser.add_argument('--learning_rate', type=float, default=0.5, help='Init Learning rate') 39 | parser.add_argument('--dy_lr', type='bool', default=False, help='if decay lr during training') 40 | parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate decay') 41 | parser.add_argument('--dropout', type=float, default=0.75, help='dropout keep prob') 42 | parser.add_argument('--wd', type=float, default=5e-5, help='weight decay factor/l2 decay factor') 43 | parser.add_argument('--var_decay', type=float, default=0.999, help='Learning rate') # ema 44 | parser.add_argument('--decay', type=float, default=0.9, help='summary decay') # ema 45 | 46 | 47 | # @ ----- Text Processing ---- 48 | parser.add_argument('--word_embedding_length', type=int, default=300, help='word embedding length') 49 | parser.add_argument('--glove_corpus', type=str, default='6B', help='choose glove corpus to employ') 50 | parser.add_argument('--use_glove_unk_token', type='bool', default=True, help='') 51 | parser.add_argument('--lower_word', type='bool', default=True, help='') 52 | parser.add_argument('--data_clip_method', type=str, 
default='no_tree', 53 | help='for space-efficiency[no_tree|]no_redundancy') 54 | parser.add_argument('--sent_len_rate', type=float, default=0.97, help='delete too long sentences') 55 | 56 | # @ ------neural network----- 57 | parser.add_argument('--use_char_emb', type='bool', default=False, help='abandoned') 58 | parser.add_argument('--use_token_emb', type='bool', default=True, help='abandoned') 59 | parser.add_argument('--char_embedding_length', type=int, default=8, help='(abandoned)') 60 | parser.add_argument('--char_out_size', type=int, default=150, help='(abandoned)') 61 | parser.add_argument('--out_channel_dims', type=str, default='50,50,50', help='(abandoned)') 62 | parser.add_argument('--filter_heights', type=str, default='1,3,5', help='(abandoned)') 63 | parser.add_argument('--highway_layer_num', type=int, default=2, help='highway layer number(abandoned)') 64 | 65 | parser.add_argument('--hidden_units_num', type=int, default=300, help='Hidden units number of Neural Network') 66 | parser.add_argument('--tree_hn', type=int, default=100, help='(abandoned)') 67 | 68 | parser.add_argument('--fine_tune', type='bool', default=False, help='(abandoned, keep False)') # ema 69 | 70 | # # emb_opt_direct_attn 71 | parser.add_argument('--batch_norm', type='bool', default=False, help='(abandoned, keep False)') 72 | parser.add_argument('--activation', type=str, default='relu', help='(abandoned') 73 | 74 | parser.set_defaults(shuffle=True) 75 | self.args = parser.parse_args() 76 | 77 | ## ---- to member variables ----- 78 | for key, value in self.args.__dict__.items(): 79 | if key not in ['test', 'shuffle']: 80 | exec('self.%s = self.args.%s' % (key, key)) 81 | 82 | # ------- name -------- 83 | self.train_data_name = 'snli_1.0_train.jsonl' 84 | self.dev_data_name = 'snli_1.0_dev.jsonl' 85 | self.test_data_name = 'snli_1.0_test.jsonl' 86 | 87 | self.processed_name = 'processed' + self.get_params_str(['lower_word', 'use_glove_unk_token', 88 | 'glove_corpus', 'word_embedding_length', 89 | 'sent_len_rate', 90 | 'data_clip_method']) + '.pickle' 91 | self.dict_name = 'dicts' + self.get_params_str(['lower_word', 'use_glove_unk_token', 92 | ]) 93 | 94 | if not self.network_type == 'test': 95 | params_name_list = ['network_type', 'dropout', 'glove_corpus', 96 | 'word_embedding_length', 'fine_tune', 'char_out_size', 'sent_len_rate', 97 | 'hidden_units_num', 'wd', 'optimizer', 'learning_rate', 'dy_lr', 'lr_decay'] 98 | self.model_name = self.get_params_str(params_name_list) 99 | else: 100 | self.model_name = self.network_type 101 | self.model_ckpt_name = 'modelfile.ckpt' 102 | 103 | # ---------- dir ------------- 104 | self.data_dir = join(self.dataset_dir, 'snli_1.0') 105 | self.glove_dir = join(self.dataset_dir, 'glove') 106 | self.result_dir = self.mkdir(self.project_dir, 'result') 107 | self.standby_log_dir = self.mkdir(self.result_dir, 'log') 108 | self.dict_dir = self.mkdir(self.result_dir, 'dict') 109 | self.processed_dir = self.mkdir(self.result_dir, 'processed_data') 110 | 111 | self.log_dir = None 112 | self.all_model_dir = self.mkdir(self.result_dir, 'model') 113 | self.model_dir = self.mkdir(self.all_model_dir, self.model_dir_suffix + self.model_name) 114 | self.log_dir = self.mkdir(self.model_dir, 'log_files') 115 | self.summary_dir = self.mkdir(self.model_dir, 'summary') 116 | self.ckpt_dir = self.mkdir(self.model_dir, 'ckpt') 117 | self.answer_dir = self.mkdir(self.model_dir, 'answer') 118 | 119 | # -------- path -------- 120 | self.train_data_path = join(self.data_dir, self.train_data_name) 121 
| self.dev_data_path = join(self.data_dir, self.dev_data_name) 122 | self.test_data_path = join(self.data_dir, self.test_data_name) 123 | 124 | self.processed_path = join(self.processed_dir, self.processed_name) 125 | self.dict_path = join(self.dict_dir, self.dict_name) 126 | self.ckpt_path = join(self.ckpt_dir, self.model_ckpt_name) 127 | 128 | self.extre_dict_path = join(self.dict_dir, 129 | 'extra_dict'+self.get_params_str(['data_clip_method'])+'.json') 130 | 131 | # dtype 132 | self.floatX = 'float32' 133 | self.intX = 'int32' 134 | os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu) 135 | self.time_counter = TimeCounter() 136 | 137 | def get_params_str(self, params): 138 | def abbreviation(name): 139 | words = name.strip().split('_') 140 | abb = '' 141 | for word in words: 142 | abb += word[0] 143 | return abb 144 | 145 | abbreviations = map(abbreviation, params) 146 | model_params_str = '' 147 | for paramsStr, abb in zip(params, abbreviations): 148 | model_params_str += '_' + abb + '_' + str(eval('self.args.' + paramsStr)) 149 | return model_params_str 150 | 151 | def mkdir(self, *args): 152 | dirPath = join(*args) 153 | if not os.path.exists(dirPath): 154 | os.makedirs(dirPath) 155 | return dirPath 156 | 157 | def get_file_name_from_path(self, path): 158 | assert isinstance(path, str) 159 | fileName = '.'.join((path.split('/')[-1]).split('.')[:-1]) 160 | return fileName 161 | 162 | 163 | cfg = Configs() 164 | -------------------------------------------------------------------------------- /SST_disan/configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os.path import join 4 | from src.utils.time_counter import TimeCounter 5 | 6 | class Configs(object): 7 | def __init__(self): 8 | self.project_dir = os.getcwd() 9 | self.dataset_dir = join(self.project_dir, 'dataset') 10 | 11 | # ------parsing input arguments"-------- 12 | parser = argparse.ArgumentParser() 13 | parser.register('type', 'bool', (lambda x: x.lower() in ('True', "yes", "true", "t", "1"))) 14 | 15 | # @ ----- control ---- 16 | parser.add_argument('--debug', type='bool', default=False, help='whether run as debug mode') 17 | parser.add_argument('--mode', type=str, default='train', help='train, dev or test') 18 | parser.add_argument('--network_type', type=str, default='test', help='network type') 19 | parser.add_argument('--fine_grained', type='bool', default=True, help='5 classes or 2 classes [True|False]') 20 | parser.add_argument('--only_sentence', type='bool', default=False, help='sentence or phrase level') 21 | parser.add_argument('--data_imbalance', type='bool', default=True, help='balance data distribution') 22 | parser.add_argument('--log_period', type=int, default=500, help='save tf summary period') ### change for running 23 | parser.add_argument('--save_period', type=int, default=3000, help='abandoned') 24 | parser.add_argument('--eval_period', type=int, default=1000, help='evaluation period') ### change for running 25 | parser.add_argument('--gpu', type=int, default=3, help='employed gpu index') 26 | parser.add_argument('--gpu_mem', type=float, default=0.96, help='gpu memory ratio to employ') 27 | parser.add_argument('--model_dir_suffix', type=str, default='', help='model folder name suffix') 28 | parser.add_argument('--swap_memory', type='bool', default=False, help='abandoned') 29 | parser.add_argument('--load_model', type='bool', default=False, help='do not use') 30 | parser.add_argument('--load_step', type=int, default=None, 
help='do not use') 31 | parser.add_argument('--load_path', type=str, default=None, help='specify which pre-trianed model to be load') 32 | 33 | # @ ----------training ------ 34 | parser.add_argument('--max_epoch', type=int, default=200, help='max epoch number') 35 | parser.add_argument('--num_steps', type=int, default=120000, help='max steps num') 36 | parser.add_argument('--train_batch_size', type=int, default=64, help='Train Batch Size') 37 | parser.add_argument('--test_batch_size', type=int, default=128, help='Test Batch Size') 38 | parser.add_argument('--optimizer', type=str, default='adadelta', help='choose an optimizer[adadelta|adam]') 39 | parser.add_argument('--learning_rate', type=float, default=0.5, help='Init Learning rate') 40 | parser.add_argument('--wd', type=float, default=1e-4, help='weight decay factor/l2 decay factor') 41 | parser.add_argument('--var_decay', type=float, default=0.999, help='Learning rate') # ema 42 | parser.add_argument('--decay', type=float, default=0.9, help='summary decay') # ema 43 | 44 | # @ ----- Text Processing ---- 45 | parser.add_argument('--word_embedding_length', type=int, default=300, help='word embedding length') 46 | parser.add_argument('--glove_corpus', type=str, default='6B', help='choose glove corpus to employ') 47 | parser.add_argument('--use_glove_unk_token', type='bool', default=True, help='') 48 | parser.add_argument('--lower_word', type='bool', default=True, help='') 49 | 50 | # @ ------neural network----- 51 | parser.add_argument('--use_char_emb', type='bool', default=False, help='abandoned') 52 | parser.add_argument('--use_token_emb', type='bool', default=True, help='abandoned') 53 | parser.add_argument('--char_embedding_length', type=int, default=8, help='abandoned') 54 | parser.add_argument('--char_out_size', type=int, default=150, help='abandoned') 55 | parser.add_argument('--out_channel_dims', type=str, default='50,50,50', help='abandoned') 56 | parser.add_argument('--filter_heights', type=str, default='1,3,5', help='abandoned') 57 | parser.add_argument('--highway_layer_num', type=int, default=2, help='highway layer number(abandoned)') 58 | 59 | parser.add_argument('--dropout', type=float, default=0.7, help='dropout keep prob') 60 | parser.add_argument('--hidden_units_num', type=int, default=300, help='Hidden units number of Neural Network') 61 | parser.add_argument('--fine_tune', type='bool', default=False, help='(abandoned, keep False)') # ema 62 | 63 | parser.set_defaults(shuffle=True) 64 | self.args = parser.parse_args() 65 | 66 | ## ---- to member variables ----- 67 | for key, value in self.args.__dict__.items(): 68 | if key not in ['test', 'shuffle']: 69 | exec('self.%s = self.args.%s' % (key, key)) 70 | 71 | # ------- name -------- 72 | self.processed_name = 'processed' + self.get_params_str(['lower_word', 'use_glove_unk_token', 73 | 'glove_corpus', 'word_embedding_length']) + '.pickle' 74 | self.dict_name = 'dicts' + self.get_params_str(['lower_word', 'use_glove_unk_token', 75 | ]) 76 | 77 | if not self.network_type == 'test': 78 | params_name_list = ['network_type', 'fine_grained', 'data_imbalance', 79 | 'only_sentence','dropout', 'word_embedding_length', 80 | 'char_out_size', 'hidden_units_num', 'learning_rate', 81 | 'wd', 'optimizer'] 82 | if self.network_type.startswith('baseline'): 83 | params_name_list.append('tree_hn') 84 | params_name_list.append('shift_reduce_method') 85 | params_name_list.append('') 86 | if self.network_type.startswith('emb_direct_attn') or \ 87 | 
self.network_type.startswith('emb_interact_attn'): 88 | params_name_list.append('method_index') 89 | params_name_list.append('use_bi') 90 | self.model_name = self.get_params_str(params_name_list) 91 | 92 | else: 93 | self.model_name = self.network_type 94 | self.model_ckpt_name = 'modelfile.ckpt' 95 | 96 | 97 | 98 | # ---------- dir ------------- 99 | self.data_dir = join(self.dataset_dir, 'stanfordSentimentTreebank') 100 | self.glove_dir = join(self.dataset_dir, 'glove') 101 | self.result_dir = self.mkdir(self.project_dir, 'result') 102 | self.standby_log_dir = self.mkdir(self.result_dir, 'log') 103 | self.dict_dir = self.mkdir(self.result_dir, 'dict') 104 | self.processed_dir = self.mkdir(self.result_dir, 'processed_data') 105 | 106 | self.log_dir = None 107 | self.all_model_dir = self.mkdir(self.result_dir, 'model') 108 | self.model_dir = self.mkdir(self.all_model_dir, self.model_dir_suffix + self.model_name) 109 | self.log_dir = self.mkdir(self.model_dir, 'log_files') 110 | self.summary_dir = self.mkdir(self.model_dir, 'summary') 111 | self.ckpt_dir = self.mkdir(self.model_dir, 'ckpt') 112 | self.answer_dir = self.mkdir(self.model_dir, 'answer') 113 | 114 | # -------- path -------- 115 | self.processed_path = join(self.processed_dir, self.processed_name) 116 | self.dict_path = join(self.dict_dir, self.dict_name) 117 | self.ckpt_path = join(self.ckpt_dir, self.model_ckpt_name) 118 | 119 | self.extre_dict_path = join(self.dict_dir, 'extra_dict.json') 120 | 121 | # dtype 122 | self.floatX = 'float32' 123 | self.intX = 'int32' 124 | os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu) 125 | self.time_counter = TimeCounter() 126 | 127 | def get_params_str(self, params): 128 | def abbreviation(name): 129 | words = name.strip().split('_') 130 | abb = '' 131 | for word in words: 132 | abb += word[0] 133 | return abb 134 | 135 | abbreviations = map(abbreviation, params) 136 | model_params_str = '' 137 | for paramsStr, abb in zip(params, abbreviations): 138 | model_params_str += '_' + abb + '_' + str(eval('self.args.' 
+ paramsStr)) 139 | return model_params_str 140 | 141 | def mkdir(self, *args): 142 | dirPath = join(*args) 143 | if not os.path.exists(dirPath): 144 | os.makedirs(dirPath) 145 | return dirPath 146 | 147 | def get_file_name_from_path(self, path): 148 | assert isinstance(path, str) 149 | fileName = '.'.join((path.split('/')[-1]).split('.')[:-1]) 150 | return fileName 151 | 152 | def sentiment_float_to_int(self, sentiment_float, fine_grained=None): 153 | fine_grained = None or self.fine_grained 154 | if fine_grained: 155 | if sentiment_float <= 0.2: 156 | sentiment_int = 0 157 | elif sentiment_float <= 0.4: 158 | sentiment_int = 1 159 | elif sentiment_float <= 0.6: 160 | sentiment_int = 2 161 | elif sentiment_float <= 0.8: 162 | sentiment_int = 3 163 | else: 164 | sentiment_int = 4 165 | else: 166 | if sentiment_float < 0.5: 167 | sentiment_int = 0 168 | else: 169 | sentiment_int = 1 170 | return sentiment_int 171 | 172 | cfg = Configs() 173 | -------------------------------------------------------------------------------- /SNLI_disan/src/utils/nlp.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import os 3 | import numpy as np 4 | import string 5 | import re 6 | from collections import Counter 7 | import nltk 8 | 9 | # --------------------2d spans ------------------- 10 | # read : span for each token -> char level 11 | def get_2d_spans(text, tokenss): 12 | spanss = [] 13 | cur_idx = 0 14 | for tokens in tokenss: 15 | spans = [] 16 | for token in tokens: 17 | if text.find(token, cur_idx) < 0: 18 | print(tokens) 19 | print("{} {} {}".format(token, cur_idx, text)) 20 | raise Exception() 21 | cur_idx = text.find(token, cur_idx) 22 | spans.append((cur_idx, cur_idx + len(token))) 23 | cur_idx += len(token) 24 | spanss.append(spans) 25 | return spanss 26 | 27 | 28 | # read 29 | def get_word_span(context, wordss, start, stop): 30 | spanss = get_2d_spans(context, wordss) # [[(start,end),...],...] -> char level 31 | idxs = [] 32 | for sent_idx, spans in enumerate(spanss): 33 | for word_idx, span in enumerate(spans): 34 | if not (stop <= span[0] or start >= span[1]): 35 | idxs.append((sent_idx, word_idx)) 36 | 37 | assert len(idxs) > 0, "{} {} {} {}".format(context, spanss, start, stop) 38 | return idxs[0], (idxs[-1][0], idxs[-1][1] + 1) # (sent_start, token_start) --> (sent_stop, token_stop+1) 39 | 40 | 41 | def get_word_idx(context, wordss, idx): 42 | spanss = get_2d_spans(context, wordss) # [[(start,end),...],...] -> char level 43 | return spanss[idx[0]][idx[1]][0] 44 | 45 | # ----------------- 1d span----------------------- 46 | 47 | def get_1d_spans(text, token_seq): 48 | spans = [] 49 | curIdx = 0 50 | for token in token_seq: 51 | token = token.replace('\xa0',' ') 52 | findRes = text.find(token,curIdx) 53 | if findRes < 0: 54 | raise RuntimeError('{} {} {}'.format(token,curIdx,text)) 55 | curIdx = findRes 56 | spans.append((curIdx, curIdx+len(token))) 57 | curIdx += len(token) 58 | return spans 59 | 60 | 61 | def get_word_idxs_1d(context, token_seq, char_start_idx, char_end_idx): 62 | """ 63 | 0 based 64 | :param context: 65 | :param token_seq: 66 | :param char_start_idx: 67 | :param char_end_idx: 68 | :return: 0-based token index sequence in the tokenized context. 
69 | """ 70 | spans = get_1d_spans(context,token_seq) 71 | idxs = [] 72 | for wordIdx, span in enumerate(spans): 73 | if not (char_end_idx <= span[0] or char_start_idx >= span[1]): 74 | idxs.append(wordIdx) 75 | assert len(idxs) > 0, "{} {} {} {}".format(context, token_seq, char_start_idx, char_end_idx) 76 | return idxs 77 | 78 | 79 | def get_start_and_end_char_idx_for_word_idx_1d(context, token_seq, word_idx_seq): 80 | ''' 81 | 0 based 82 | :param context: 83 | :param token_seq: 84 | :param word_idx_seq: 85 | :return: 86 | ''' 87 | spans = get_1d_spans(context, token_seq) 88 | correct_spans = [span for idx,span in enumerate(spans) if idx in word_idx_seq] 89 | 90 | return correct_spans[0][0],correct_spans[-1][-1] 91 | 92 | 93 | # ----------------- for node target idx ----------------------- 94 | def calculate_idx_seq_f1_score(input_idx_seq, label_idx_seq, recall_factor=1.): 95 | assert len(input_idx_seq) > 0 and len(label_idx_seq)>0 96 | # recall 97 | recall_counter = sum(1 for label_idx in label_idx_seq if label_idx in input_idx_seq) 98 | precision_counter = sum(1 for input_idx in input_idx_seq if input_idx in label_idx_seq) 99 | 100 | recall = 1.0*recall_counter/ len(label_idx_seq) 101 | precision = 1.0*precision_counter / len(input_idx_seq) 102 | 103 | recall = recall/recall_factor 104 | 105 | if recall + precision <= 0.: 106 | return 0. 107 | else: 108 | return 2.*recall*precision / (recall + precision) 109 | 110 | 111 | def get_best_node_idx(node_and_leaf_pair, answer_token_idx_seq, recall_factor=1.): 112 | """ 113 | all index in this function is 1 bases 114 | :param node_and_leaves_pair: 115 | :param answer_token_idx_seq: 116 | :return: 117 | """ 118 | f1_scores = [] 119 | for node_idx, leaf_idx_seq in node_and_leaf_pair: 120 | f1_scores.append(calculate_idx_seq_f1_score(leaf_idx_seq,answer_token_idx_seq, 121 | recall_factor)) 122 | max_idx = np.argmax(f1_scores) 123 | return node_and_leaf_pair[max_idx][0] 124 | 125 | # ------------------ calculate text f1------------------- 126 | 127 | def normalize_answer(s): 128 | """Lower text and remove punctuation, articles and extra whitespace.""" 129 | def remove_articles(text): 130 | return re.sub(r'\b(a|an|the)\b', ' ', text) 131 | 132 | def white_space_fix(text): 133 | return ' '.join(text.split()) 134 | 135 | def remove_punc(text): 136 | exclude = set(string.punctuation) 137 | return ''.join(ch for ch in text if ch not in exclude) 138 | 139 | def lower(text): 140 | return text.lower() 141 | 142 | def tokenize(text): 143 | return ' '.join(nltk.word_tokenize(text)) 144 | 145 | return white_space_fix(remove_articles(remove_punc(lower(tokenize(s))))) 146 | 147 | def f1_score(prediction, ground_truth): 148 | prediction_tokens = normalize_answer(prediction).split() 149 | ground_truth_tokens = normalize_answer(ground_truth).split() 150 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 151 | num_same = sum(common.values()) 152 | if num_same == 0: 153 | return 0 154 | precision = 1.0 * num_same / len(prediction_tokens) 155 | recall = 1.0 * num_same / len(ground_truth_tokens) 156 | f1 = (2 * precision * recall) / (precision + recall) 157 | return f1 158 | 159 | 160 | def exact_match_score(prediction, ground_truth): 161 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 162 | 163 | 164 | def check_rebuild_quality(prediction,ground_truth): 165 | em = exact_match_score(prediction,ground_truth) 166 | f1 = f1_score(prediction, ground_truth) 167 | return em,f1 168 | 169 | 170 | def dynamic_length(lengthList, 
ratio, add=None, security = True, fileName=None): 171 | ratio = float(ratio) 172 | if add is not None: 173 | ratio += add 174 | ratio = ratio if ratio < 1 else 1 175 | if security: 176 | ratio = ratio if ratio < 0.99 else 0.99 177 | def calculate_dynamic_len(pdf ,ratio_ = ratio): 178 | cdf = [] 179 | previous = 0 180 | # accumulate 181 | for len ,freq in pdf: 182 | previous += freq 183 | cdf.append((len, previous)) 184 | # calculate 185 | for len ,accu in cdf: 186 | if 1.0 * accu/ previous >= ratio_: # satisfy the condition 187 | return len, cdf[-1][0] 188 | # max 189 | return cdf[-1][0], cdf[-1][0] 190 | 191 | pdf = dict(nltk.FreqDist(lengthList)) 192 | pdf = sorted(pdf.items(), key=lambda d: d[0]) 193 | 194 | if fileName is not None: 195 | with open(fileName, 'w') as f: 196 | for len, freq in pdf: 197 | f.write('%d\t%d' % (len, freq)) 198 | f.write(os.linesep) 199 | 200 | return calculate_dynamic_len(pdf, ratio) 201 | 202 | 203 | def dynamic_keep(collect,ratio,fileName=None): 204 | 205 | pdf = dict(nltk.FreqDist(collect)) 206 | pdf = sorted(pdf.items(), key=lambda d: d[1],reverse=True) 207 | 208 | cdf = [] 209 | previous = 0 210 | # accumulate 211 | for token, freq in pdf: 212 | previous += freq 213 | cdf.append((token, previous)) 214 | # calculate 215 | for idx, (token, accu) in enumerate(cdf): 216 | keepAnchor = idx 217 | if 1.0 * accu / previous >= ratio: # satisfy the condition 218 | break 219 | 220 | tokenList=[] 221 | for idx, (token, freq) in enumerate(pdf): 222 | if idx > keepAnchor: break 223 | tokenList.append(token) 224 | 225 | 226 | if fileName is not None: 227 | with open(fileName, 'w') as f: 228 | for idx, (token, freq) in enumerate(pdf): 229 | f.write('%d\t%d' % (token, freq)) 230 | f.write(os.linesep) 231 | 232 | if idx == keepAnchor: 233 | print(os.linesep*20) 234 | 235 | return tokenList 236 | 237 | 238 | def gene_question_explicit_class_tag(question_token): 239 | classes = ['what', 'how', 'who', 'when', 'which', 'where', 'why', 'whom', 'whose', 240 | ['am', 'is', 'are', 'was', 'were']] 241 | question_token = [token.lower() for token in question_token] 242 | 243 | for idx_c, cls in enumerate(classes): 244 | if not isinstance(cls, list): 245 | if cls in question_token: 246 | return idx_c 247 | else: 248 | for ccls in cls: 249 | if ccls in question_token: 250 | return idx_c 251 | return len(classes) 252 | 253 | 254 | def gene_token_freq_info(context_token, question_token): 255 | def look_up_dict(t_dict, t): 256 | try: 257 | return t_dict[t] 258 | except KeyError: 259 | return 0 260 | context_token_dict = dict(nltk.FreqDist(context_token)) 261 | question_token_dict = dict(nltk.FreqDist(question_token)) 262 | 263 | # context tokens in context and question dicts 264 | context_tf = [] 265 | for token in context_token: 266 | context_tf.append((look_up_dict(context_token_dict, token), look_up_dict(question_token_dict, token))) 267 | 268 | # question tokens in context and question dicts 269 | question_tf = [] 270 | for token in context_token: 271 | question_tf.append((look_up_dict(context_token_dict, token), look_up_dict(question_token_dict, token))) 272 | 273 | return {'context':context_tf, 'question':question_tf} 274 | 275 | 276 | -------------------------------------------------------------------------------- /SST_disan/src/utils/nlp.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import os 3 | import numpy as np 4 | import string 5 | import re 6 | from collections import Counter 7 | import nltk 8 | 9 | # 
--------------------2d spans ------------------- 10 | # read : span for each token -> char level 11 | def get_2d_spans(text, tokenss): 12 | spanss = [] 13 | cur_idx = 0 14 | for tokens in tokenss: 15 | spans = [] 16 | for token in tokens: 17 | if text.find(token, cur_idx) < 0: 18 | print(tokens) 19 | print("{} {} {}".format(token, cur_idx, text)) 20 | raise Exception() 21 | cur_idx = text.find(token, cur_idx) 22 | spans.append((cur_idx, cur_idx + len(token))) 23 | cur_idx += len(token) 24 | spanss.append(spans) 25 | return spanss 26 | 27 | 28 | # read 29 | def get_word_span(context, wordss, start, stop): 30 | spanss = get_2d_spans(context, wordss) # [[(start,end),...],...] -> char level 31 | idxs = [] 32 | for sent_idx, spans in enumerate(spanss): 33 | for word_idx, span in enumerate(spans): 34 | if not (stop <= span[0] or start >= span[1]): 35 | idxs.append((sent_idx, word_idx)) 36 | 37 | assert len(idxs) > 0, "{} {} {} {}".format(context, spanss, start, stop) 38 | return idxs[0], (idxs[-1][0], idxs[-1][1] + 1) # (sent_start, token_start) --> (sent_stop, token_stop+1) 39 | 40 | 41 | def get_word_idx(context, wordss, idx): 42 | spanss = get_2d_spans(context, wordss) # [[(start,end),...],...] -> char level 43 | return spanss[idx[0]][idx[1]][0] 44 | 45 | # ----------------- 1d span----------------------- 46 | 47 | def get_1d_spans(text, token_seq): 48 | spans = [] 49 | curIdx = 0 50 | for token in token_seq: 51 | token = token.replace('\xa0',' ') 52 | findRes = text.find(token,curIdx) 53 | if findRes < 0: 54 | raise RuntimeError('{} {} {}'.format(token,curIdx,text)) 55 | curIdx = findRes 56 | spans.append((curIdx, curIdx+len(token))) 57 | curIdx += len(token) 58 | return spans 59 | 60 | 61 | def get_word_idxs_1d(context, token_seq, char_start_idx, char_end_idx): 62 | """ 63 | 0 based 64 | :param context: 65 | :param token_seq: 66 | :param char_start_idx: 67 | :param char_end_idx: 68 | :return: 0-based token index sequence in the tokenized context. 69 | """ 70 | spans = get_1d_spans(context,token_seq) 71 | idxs = [] 72 | for wordIdx, span in enumerate(spans): 73 | if not (char_end_idx <= span[0] or char_start_idx >= span[1]): 74 | idxs.append(wordIdx) 75 | assert len(idxs) > 0, "{} {} {} {}".format(context, token_seq, char_start_idx, char_end_idx) 76 | return idxs 77 | 78 | 79 | def get_start_and_end_char_idx_for_word_idx_1d(context, token_seq, word_idx_seq): 80 | ''' 81 | 0 based 82 | :param context: 83 | :param token_seq: 84 | :param word_idx_seq: 85 | :return: 86 | ''' 87 | spans = get_1d_spans(context, token_seq) 88 | correct_spans = [span for idx,span in enumerate(spans) if idx in word_idx_seq] 89 | 90 | return correct_spans[0][0],correct_spans[-1][-1] 91 | 92 | 93 | # ----------------- for node target idx ----------------------- 94 | def calculate_idx_seq_f1_score(input_idx_seq, label_idx_seq, recall_factor=1.): 95 | assert len(input_idx_seq) > 0 and len(label_idx_seq)>0 96 | # recall 97 | recall_counter = sum(1 for label_idx in label_idx_seq if label_idx in input_idx_seq) 98 | precision_counter = sum(1 for input_idx in input_idx_seq if input_idx in label_idx_seq) 99 | 100 | recall = 1.0*recall_counter/ len(label_idx_seq) 101 | precision = 1.0*precision_counter / len(input_idx_seq) 102 | 103 | recall = recall/recall_factor 104 | 105 | if recall + precision <= 0.: 106 | return 0. 
107 | else: 108 | return 2.*recall*precision / (recall + precision) 109 | 110 | 111 | def get_best_node_idx(node_and_leaf_pair, answer_token_idx_seq, recall_factor=1.): 112 | """ 113 | all index in this function is 1 bases 114 | :param node_and_leaves_pair: 115 | :param answer_token_idx_seq: 116 | :return: 117 | """ 118 | f1_scores = [] 119 | for node_idx, leaf_idx_seq in node_and_leaf_pair: 120 | f1_scores.append(calculate_idx_seq_f1_score(leaf_idx_seq,answer_token_idx_seq, 121 | recall_factor)) 122 | max_idx = np.argmax(f1_scores) 123 | return node_and_leaf_pair[max_idx][0] 124 | 125 | # ------------------ calculate text f1------------------- 126 | 127 | def normalize_answer(s): 128 | """Lower text and remove punctuation, articles and extra whitespace.""" 129 | def remove_articles(text): 130 | return re.sub(r'\b(a|an|the)\b', ' ', text) 131 | 132 | def white_space_fix(text): 133 | return ' '.join(text.split()) 134 | 135 | def remove_punc(text): 136 | exclude = set(string.punctuation) 137 | return ''.join(ch for ch in text if ch not in exclude) 138 | 139 | def lower(text): 140 | return text.lower() 141 | 142 | def tokenize(text): 143 | return ' '.join(nltk.word_tokenize(text)) 144 | 145 | return white_space_fix(remove_articles(remove_punc(lower(tokenize(s))))) 146 | 147 | def f1_score(prediction, ground_truth): 148 | prediction_tokens = normalize_answer(prediction).split() 149 | ground_truth_tokens = normalize_answer(ground_truth).split() 150 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 151 | num_same = sum(common.values()) 152 | if num_same == 0: 153 | return 0 154 | precision = 1.0 * num_same / len(prediction_tokens) 155 | recall = 1.0 * num_same / len(ground_truth_tokens) 156 | f1 = (2 * precision * recall) / (precision + recall) 157 | return f1 158 | 159 | 160 | def exact_match_score(prediction, ground_truth): 161 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 162 | 163 | 164 | def check_rebuild_quality(prediction,ground_truth): 165 | em = exact_match_score(prediction,ground_truth) 166 | f1 = f1_score(prediction, ground_truth) 167 | return em,f1 168 | 169 | 170 | def dynamic_length(lengthList, ratio, add=None, security = True, fileName=None): 171 | ratio = float(ratio) 172 | if add is not None: 173 | ratio += add 174 | ratio = ratio if ratio < 1 else 1 175 | if security: 176 | ratio = ratio if ratio < 0.99 else 0.99 177 | def calculate_dynamic_len(pdf ,ratio_ = ratio): 178 | cdf = [] 179 | previous = 0 180 | # accumulate 181 | for len ,freq in pdf: 182 | previous += freq 183 | cdf.append((len, previous)) 184 | # calculate 185 | for len ,accu in cdf: 186 | if 1.0 * accu/ previous >= ratio_: # satisfy the condition 187 | return len, cdf[-1][0] 188 | # max 189 | return cdf[-1][0], cdf[-1][0] 190 | 191 | pdf = dict(nltk.FreqDist(lengthList)) 192 | pdf = sorted(pdf.items(), key=lambda d: d[0]) 193 | 194 | if fileName is not None: 195 | with open(fileName, 'w') as f: 196 | for len, freq in pdf: 197 | f.write('%d\t%d' % (len, freq)) 198 | f.write(os.linesep) 199 | 200 | return calculate_dynamic_len(pdf, ratio) 201 | 202 | 203 | def dynamic_keep(collect,ratio,fileName=None): 204 | 205 | pdf = dict(nltk.FreqDist(collect)) 206 | pdf = sorted(pdf.items(), key=lambda d: d[1],reverse=True) 207 | 208 | cdf = [] 209 | previous = 0 210 | # accumulate 211 | for token, freq in pdf: 212 | previous += freq 213 | cdf.append((token, previous)) 214 | # calculate 215 | for idx, (token, accu) in enumerate(cdf): 216 | keepAnchor = idx 217 | if 1.0 * accu / 
previous >= ratio: # satisfy the condition 218 | break 219 | 220 | tokenList=[] 221 | for idx, (token, freq) in enumerate(pdf): 222 | if idx > keepAnchor: break 223 | tokenList.append(token) 224 | 225 | 226 | if fileName is not None: 227 | with open(fileName, 'w') as f: 228 | for idx, (token, freq) in enumerate(pdf): 229 | f.write('%s\t%d' % (token, freq)) 230 | f.write(os.linesep) 231 | 232 | if idx == keepAnchor: 233 | print(os.linesep*20) 234 | 235 | return tokenList 236 | 237 | 238 | def gene_question_explicit_class_tag(question_token): 239 | classes = ['what', 'how', 'who', 'when', 'which', 'where', 'why', 'whom', 'whose', 240 | ['am', 'is', 'are', 'was', 'were']] 241 | question_token = [token.lower() for token in question_token] 242 | 243 | for idx_c, cls in enumerate(classes): 244 | if not isinstance(cls, list): 245 | if cls in question_token: 246 | return idx_c 247 | else: 248 | for ccls in cls: 249 | if ccls in question_token: 250 | return idx_c 251 | return len(classes) 252 | 253 | 254 | def gene_token_freq_info(context_token, question_token): 255 | def look_up_dict(t_dict, t): 256 | try: 257 | return t_dict[t] 258 | except KeyError: 259 | return 0 260 | context_token_dict = dict(nltk.FreqDist(context_token)) 261 | question_token_dict = dict(nltk.FreqDist(question_token)) 262 | 263 | # context tokens in context and question dicts 264 | context_tf = [] 265 | for token in context_token: 266 | context_tf.append((look_up_dict(context_token_dict, token), look_up_dict(question_token_dict, token))) 267 | 268 | # question tokens in context and question dicts 269 | question_tf = [] 270 | for token in question_token: 271 | question_tf.append((look_up_dict(context_token_dict, token), look_up_dict(question_token_dict, token))) 272 | 273 | return {'context':context_tf, 'question':question_tf} 274 | 275 | 276 | -------------------------------------------------------------------------------- /SNLI_disan/snli_main.py: -------------------------------------------------------------------------------- 1 | import math 2 | import tensorflow as tf 3 | 4 | from configs import cfg 5 | from src.dataset import Dataset 6 | from src.evaluator import Evaluator 7 | from src.graph_handler import GraphHandler 8 | from src.perform_recorder import PerformRecoder 9 | from src.utils.file import load_file, save_file 10 | from src.utils.record_log import _logger 11 | 12 | # choose model 13 | network_type = cfg.network_type 14 | if network_type == 'exp_bi_lstm_mul_attn': # check, running 15 | from src.model.exp_bi_lstm_mul_attn import ModelExpBiLSTMMulAttn as Model 16 | elif network_type == 'exp_emb_attn': # check, done 17 | from src.model.exp_emb_attn import ModelExpEmbAttn as Model 18 | elif network_type == 'exp_emb_mul_attn': # check, done 19 | from src.model.exp_emb_mul_attn import ModelExpEmbMulAttn as Model 20 | elif network_type == 'exp_emb_self_mul_attn': 21 | from src.model.exp_emb_self_mul_attn import ModelExpEmbSelfMulAttn as Model 22 | elif network_type == 'exp_emb_dir_mul_attn': 23 | from src.model.exp_emb_dir_mul_attn import ModelExpEmbDirMulAttn as Model 24 | elif network_type == 'disan': 25 | from src.model.model_disan import ModelDiSAN as Model 26 | 27 | model_type_set = ['exp_bi_lstm_mul_attn', 'exp_emb_attn', 'exp_emb_mul_attn', 28 | 'exp_emb_self_mul_attn', 'exp_emb_dir_mul_attn', 'disan'] 29 | 30 | 31 | def train(): 32 | output_model_params() 33 | loadFile = True 34 | ifLoad, data = False, None 35 | if loadFile: 36 | ifLoad, data = load_file(cfg.processed_path, 'processed data', 'pickle') 37 | if
not ifLoad or not loadFile: 38 | train_data_obj = Dataset(cfg.train_data_path, 'train') 39 | dev_data_obj = Dataset(cfg.dev_data_path, 'dev', dicts=train_data_obj.dicts) 40 | test_data_obj = Dataset(cfg.test_data_path, 'test', dicts=train_data_obj.dicts) 41 | 42 | save_file({'train_data_obj': train_data_obj, 'dev_data_obj': dev_data_obj, 'test_data_obj': test_data_obj}, 43 | cfg.processed_path) 44 | 45 | train_data_obj.save_dict(cfg.dict_path) 46 | else: 47 | train_data_obj = data['train_data_obj'] 48 | dev_data_obj = data['dev_data_obj'] 49 | test_data_obj = data['test_data_obj'] 50 | 51 | train_data_obj.filter_data() 52 | dev_data_obj.filter_data() 53 | test_data_obj.filter_data() 54 | 55 | emb_mat_token, emb_mat_glove = train_data_obj.emb_mat_token, train_data_obj.emb_mat_glove 56 | 57 | with tf.variable_scope(network_type) as scope: 58 | if network_type in model_type_set: 59 | model = Model(emb_mat_token, emb_mat_glove, len(train_data_obj.dicts['token']), 60 | len(train_data_obj.dicts['char']), train_data_obj.max_lens['token'], scope.name) 61 | graphHandler = GraphHandler(model) 62 | evaluator = Evaluator(model) 63 | performRecoder = PerformRecoder(3) 64 | 65 | if cfg.gpu_mem < 1.: 66 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=cfg.gpu_mem, 67 | allow_growth=True) 68 | else: 69 | gpu_options = tf.GPUOptions() 70 | graph_config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) 71 | # graph_config.gpu_options.allow_growth = True 72 | sess = tf.Session(config=graph_config) 73 | graphHandler.initialize(sess) 74 | 75 | # begin training 76 | steps_per_epoch = int(math.ceil(1.0 * train_data_obj.sample_num / cfg.train_batch_size)) 77 | num_steps = cfg.num_steps or steps_per_epoch * cfg.max_epoch 78 | 79 | global_step = 0 80 | 81 | for sample_batch, batch_num, data_round, idx_b in train_data_obj.generate_batch_sample_iter(num_steps): 82 | global_step = sess.run(model.global_step) + 1 83 | if_get_summary = global_step % (cfg.log_period or steps_per_epoch) == 0 84 | loss, summary, train_op = model.step(sess, sample_batch, get_summary=if_get_summary) 85 | if global_step % 100 == 0 or global_step == 1: 86 | _logger.add('data round: %d: %d/%d, global step:%d -- loss: %.4f' % 87 | (data_round, idx_b, batch_num, global_step, loss)) 88 | 89 | if if_get_summary: 90 | graphHandler.add_summary(summary, global_step) 91 | 92 | # Occasional evaluation 93 | if global_step > int(cfg.num_steps - 100000) and (global_step % (cfg.eval_period or steps_per_epoch) == 0): 94 | # ---- dev ---- 95 | dev_loss, dev_accu = evaluator.get_evaluation( 96 | sess, dev_data_obj, global_step 97 | ) 98 | _logger.add('==> for dev, loss: %.4f, accuracy: %.4f' % 99 | (dev_loss, dev_accu)) 100 | # ---- test ---- 101 | test_loss, test_accu = evaluator.get_evaluation( 102 | sess, test_data_obj, global_step 103 | ) 104 | _logger.add('~~> for test, loss: %.4f, accuracy: %.4f' % 105 | (test_loss, test_accu)) 106 | 107 | model.update_learning_rate(dev_loss, cfg.lr_decay) 108 | is_in_top, deleted_step = performRecoder.update_top_list(global_step, dev_accu, sess) 109 | 110 | this_epoch_time, mean_epoch_time = cfg.time_counter.update_data_round(data_round) 111 | if this_epoch_time is not None and mean_epoch_time is not None: 112 | _logger.add('##> this epoch time: %f, mean epoch time: %f' % (this_epoch_time, mean_epoch_time)) 113 | 114 | do_analyse_snli(_logger.path) 115 | 116 | 117 | def test(): 118 | 119 | assert cfg.load_path is not None 120 | output_model_params() 121 | loadFile = True 122 | ifLoad,
data = False, None 123 | if loadFile: 124 | ifLoad, data = load_file(cfg.processed_path, 'processed data', 'pickle') 125 | if not ifLoad or not loadFile: 126 | raise RuntimeError('cannot find pre-processed dataset') 127 | else: 128 | train_data_obj = data['train_data_obj'] 129 | dev_data_obj = data['dev_data_obj'] 130 | test_data_obj = data['test_data_obj'] 131 | 132 | train_data_obj.filter_data('test') 133 | dev_data_obj.filter_data('test') 134 | test_data_obj.filter_data('test') 135 | 136 | emb_mat_token, emb_mat_glove = train_data_obj.emb_mat_token, train_data_obj.emb_mat_glove 137 | 138 | with tf.variable_scope(network_type) as scope: 139 | if network_type in model_type_set: 140 | model = Model(emb_mat_token, emb_mat_glove, len(train_data_obj.dicts['token']), 141 | len(train_data_obj.dicts['char']), train_data_obj.max_lens['token'], scope.name) 142 | graphHandler = GraphHandler(model) 143 | evaluator = Evaluator(model) 144 | 145 | if cfg.gpu_mem < 1.: 146 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=cfg.gpu_mem, 147 | allow_growth=True) 148 | else: 149 | gpu_options = tf.GPUOptions() 150 | graph_config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) 151 | sess = tf.Session(config=graph_config) 152 | graphHandler.initialize(sess) 153 | 154 | # ---- dev ---- 155 | dev_loss, dev_accu = evaluator.get_evaluation( 156 | sess, dev_data_obj, None 157 | ) 158 | _logger.add('==> for dev, loss: %.4f, accuracy: %.4f' % 159 | (dev_loss, dev_accu)) 160 | # ---- test ---- 161 | test_loss, test_accu = evaluator.get_evaluation( 162 | sess, test_data_obj, None 163 | ) 164 | _logger.add('~~> for test, loss: %.4f, accuracy: %.4f' % 165 | (test_loss, test_accu)) 166 | 167 | train_loss, train_accu = evaluator.get_evaluation( 168 | sess, train_data_obj, None 169 | ) 170 | _logger.add('--> for train, loss: %.4f, accuracy: %.4f' % 171 | (train_loss, train_accu)) 172 | 173 | 174 | def main(_): 175 | if cfg.mode == 'train': 176 | train() 177 | elif cfg.mode == 'test': 178 | test() 179 | else: 180 | raise RuntimeError('no running mode named as %s' % cfg.mode) 181 | 182 | 183 | def output_model_params(): 184 | _logger.add() 185 | _logger.add('==>model_title: ' + cfg.model_name[1:]) 186 | _logger.add() 187 | for key,value in cfg.args.__dict__.items(): 188 | if key not in ['test','shuffle']: 189 | _logger.add('%s: %s' % (key, value)) 190 | 191 | 192 | def do_analyse_snli(file_path, dev=True, use_loss=False, stop=None): 193 | results = [] 194 | with open(file_path, 'r', encoding='utf-8') as file: 195 | find_entry = False 196 | output = [0, 0., 0., 0., 0.] # [step, dev_accu, dev_loss, test_accu, test_loss] 197 | for line in file: 198 | if not find_entry: 199 | if line.startswith('data round'): # get step 200 | output[0] = int(line.split(' ')[-4].split(':')[-1]) 201 | if stop is not None and output[0] > stop: break 202 | if line.startswith('==> for dev'): # dev 203 | output[1] = float(line.split(' ')[-1]) 204 | output[2] = float(line.split(' ')[-3][:-1]) 205 | find_entry = True 206 | else: 207 | if line.startswith('~~> for test'): # test 208 | output[3] = float(line.split(' ')[-1]) 209 | output[4] = float(line.split(' ')[-3][:-1]) 210 | results.append(output) 211 | find_entry = False 212 | output = [0, 0., 0., 0., 0.]
213 | 214 | # max step 215 | if len(results) > 0: 216 | print('max step:', results[-1][0]) 217 | 218 | # sort 219 | sort = 1 if dev else 3 220 | if use_loss: sort += 1 221 | output = list(sorted(results, key=lambda elem: elem[sort], reverse=not use_loss)) 222 | 223 | for elem in output[:3]: 224 | print('step: %d, dev: %.4f, dev_loss: %.4f, test: %.4f, test_loss: %.4f' % 225 | (elem[0], elem[1], elem[2], elem[3],elem[4])) 226 | 227 | 228 | 229 | if __name__ == '__main__': 230 | tf.app.run() 231 | 232 | 233 | 234 | -------------------------------------------------------------------------------- /SST_disan/src/model/template.py: -------------------------------------------------------------------------------- 1 | from configs import cfg 2 | from src.utils.record_log import _logger 3 | import tensorflow as tf 4 | import numpy as np 5 | from abc import ABCMeta, abstractmethod 6 | 7 | 8 | class ModelTemplate(metaclass=ABCMeta): 9 | def __init__(self, token_emb_mat, glove_emb_mat, tds, cds, tl, scope): 10 | self.scope = scope 11 | self.global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, 12 | initializer=tf.constant_initializer(0), trainable=False) 13 | self.token_emb_mat, self.glove_emb_mat = token_emb_mat, glove_emb_mat 14 | 15 | # ---- place holder ----- 16 | self.token_seq = tf.placeholder(tf.int32, [None, None], name='token_seq') 17 | self.char_seq = tf.placeholder(tf.int32, [None, None, tl], name='context_char') 18 | 19 | self.op_list = tf.placeholder(tf.int32, [None, None], name='op_lists') # bs,sol 20 | self.reduce_mat = tf.placeholder(tf.int32, [None, None, None], name='reduce_mats') # [bs,sol,mc] 21 | 22 | self.sentiment_label = tf.placeholder(tf.int32, [None], name='sentiment_label') # bs 23 | self.is_train = tf.placeholder(tf.bool, [], name='is_train') 24 | 25 | 26 | # ----------- parameters ------------- 27 | self.tds, self.cds = tds, cds 28 | self.tl = tl 29 | self.tel = cfg.word_embedding_length 30 | self.cel = cfg.char_embedding_length 31 | self.cos = cfg.char_out_size 32 | self.ocd = list(map(int, cfg.out_channel_dims.split(','))) 33 | self.fh = list(map(int, cfg.filter_heights.split(','))) 34 | self.hn = cfg.hidden_units_num 35 | self.finetune_emb = cfg.fine_tune 36 | 37 | self.output_class = 5 if cfg.fine_grained else 2 38 | 39 | self.bs = tf.shape(self.token_seq)[0] 40 | self.sl = tf.shape(self.token_seq)[1] 41 | self.ol = tf.shape(self.op_list)[1] 42 | self.mc = tf.shape(self.reduce_mat)[2] 43 | 44 | # ------------ other --------- 45 | self.token_mask = tf.cast(self.token_seq, tf.bool) 46 | self.char_mask = tf.cast(self.char_seq, tf.bool) 47 | self.token_len = tf.reduce_sum(tf.cast(self.token_mask, tf.int32), -1) 48 | self.char_len = tf.reduce_sum(tf.cast(self.char_mask, tf.int32), -1) 49 | 50 | self.stack_mask = tf.not_equal(self.op_list, tf.zeros_like(self.op_list)) 51 | 52 | self.tensor_dict = {} 53 | 54 | # ------ start ------ 55 | self.logits = None 56 | self.loss = None 57 | self.accuracy = None 58 | self.var_ema = None 59 | self.ema = None 60 | self.summary = None 61 | self.opt = None 62 | self.train_op = None 63 | 64 | @abstractmethod 65 | def build_network(self): 66 | pass 67 | 68 | def build_loss(self): 69 | # weight_decay 70 | with tf.name_scope("weight_decay"): 71 | for var in set(tf.get_collection('reg_vars', self.scope)): 72 | weight_decay = tf.multiply(tf.nn.l2_loss(var), cfg.wd, 73 | name="{}-wd".format('-'.join(str(var.op.name).split('/')))) 74 | tf.add_to_collection('losses', weight_decay) 75 | reg_vars = tf.get_collection('losses', 
self.scope) 76 | trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope) 77 | _logger.add('regularization var num: %d' % len(reg_vars)) 78 | _logger.add('trainable var num: %d' % len(trainable_vars)) 79 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits( 80 | labels=self.sentiment_label, 81 | logits=self.logits 82 | ) 83 | tf.add_to_collection('losses', tf.reduce_mean(losses, name='xentropy_loss_mean')) 84 | loss = tf.add_n(tf.get_collection('losses', self.scope), name='loss') 85 | tf.summary.scalar(loss.op.name, loss) 86 | tf.add_to_collection('ema/scalar', loss) 87 | return loss 88 | 89 | def build_accuracy(self): 90 | correct = tf.equal( 91 | tf.cast(tf.argmax(self.logits, -1), tf.int32), 92 | self.sentiment_label 93 | ) # [bs] 94 | return tf.cast(correct, tf.float32) 95 | 96 | def update_tensor_add_ema_and_opt(self): 97 | self.logits = self.build_network() 98 | self.loss = self.build_loss() 99 | self.accuracy = self.build_accuracy() 100 | 101 | # ------------ema------------- 102 | if True: 103 | self.var_ema = tf.train.ExponentialMovingAverage(cfg.var_decay) 104 | self.build_var_ema() 105 | 106 | if cfg.mode == 'train': 107 | self.ema = tf.train.ExponentialMovingAverage(cfg.decay) 108 | self.build_ema() 109 | self.summary = tf.summary.merge_all() 110 | 111 | # ---------- optimization --------- 112 | if cfg.optimizer.lower() == 'adadelta': 113 | assert cfg.learning_rate > 0.1 and cfg.learning_rate < 1. 114 | self.opt = tf.train.AdadeltaOptimizer(cfg.learning_rate) 115 | elif cfg.optimizer.lower() == 'adam': 116 | assert cfg.learning_rate < 0.1 117 | self.opt = tf.train.AdamOptimizer(cfg.learning_rate) 118 | elif cfg.optimizer.lower() == 'rmsprop': 119 | assert cfg.learning_rate < 0.1 120 | self.opt = tf.train.RMSPropOptimizer(cfg.learning_rate) 121 | else: 122 | raise AttributeError('no optimizer named as \'%s\'' % cfg.optimizer) 123 | 124 | 125 | self.train_op = self.opt.minimize(self.loss, self.global_step, 126 | var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)) 127 | 128 | def build_var_ema(self): 129 | ema_op = self.var_ema.apply(tf.trainable_variables(),) 130 | with tf.control_dependencies([ema_op]): 131 | self.loss = tf.identity(self.loss) 132 | 133 | def build_ema(self): 134 | tensors = tf.get_collection("ema/scalar", scope=self.scope) + \ 135 | tf.get_collection("ema/vector", scope=self.scope) 136 | ema_op = self.ema.apply(tensors) 137 | for var in tf.get_collection("ema/scalar", scope=self.scope): 138 | ema_var = self.ema.average(var) 139 | tf.summary.scalar(ema_var.op.name, ema_var) 140 | for var in tf.get_collection("ema/vector", scope=self.scope): 141 | ema_var = self.ema.average(var) 142 | tf.summary.histogram(ema_var.op.name, ema_var) 143 | 144 | with tf.control_dependencies([ema_op]): 145 | self.loss = tf.identity(self.loss) 146 | 147 | def get_feed_dict(self, sample_batch, data_type='train'): 148 | # max lens 149 | sl, ol, mc = 0, 0, 0 150 | for sample in sample_batch: 151 | sl = max(sl, len(sample['root_node']['token_seq'])) 152 | ol = max(ol, len(sample['shift_reduce_info']['op_list'])) 153 | for reduce_list in sample['shift_reduce_info']['reduce_mat']: 154 | mc = max(mc, len(reduce_list)) 155 | 156 | assert mc == 0 or mc == 2, mc 157 | 158 | # token and char 159 | token_seq_b = [] 160 | char_seq_b = [] 161 | for sample in sample_batch: 162 | token_seq = np.zeros([sl], cfg.intX) 163 | char_seq = np.zeros([sl, self.tl], cfg.intX) 164 | 165 | for idx_t,(token, char_seq_v) in 
enumerate(zip(sample['root_node']['token_seq_digital'], 166 | sample['root_node']['char_seq_digital'])): 167 | token_seq[idx_t] = token 168 | for idx_c, char in enumerate(char_seq_v): 169 | if idx_c >= self.tl: break 170 | char_seq[idx_t, idx_c] = char 171 | token_seq_b.append(token_seq) 172 | char_seq_b.append(char_seq) 173 | token_seq_b = np.stack(token_seq_b) 174 | char_seq_b = np.stack(char_seq_b) 175 | 176 | # tree 177 | op_list_b = [] 178 | reduce_mat_b = [] 179 | for sample in sample_batch: 180 | op_list = np.zeros([ol], cfg.intX) 181 | reduce_mat = np.zeros([ol, mc], cfg.intX) 182 | 183 | for idx_o, (op, reduce_list) in enumerate(zip(sample['shift_reduce_info']['op_list'], 184 | sample['shift_reduce_info']['reduce_mat'])): 185 | op_list[idx_o] = op 186 | for idx_m, red in enumerate(reduce_list): 187 | reduce_mat[idx_o, idx_m] = red 188 | op_list_b.append(op_list) 189 | reduce_mat_b.append(reduce_mat) 190 | op_list_b = np.stack(op_list_b) 191 | reduce_mat_b = np.stack(reduce_mat_b) 192 | 193 | # label 194 | sentiment_label_b = [] 195 | for sample in sample_batch: 196 | sentiment_float = sample['root_node']['sentiment_label'] 197 | sentiment_int = cfg.sentiment_float_to_int(sentiment_float) 198 | sentiment_label_b.append(sentiment_int) 199 | sentiment_label_b = np.stack(sentiment_label_b).astype(cfg.intX) 200 | 201 | feed_dict = {self.token_seq: token_seq_b, self.char_seq: char_seq_b, 202 | self.op_list: op_list_b, self.reduce_mat: reduce_mat_b, 203 | self.sentiment_label: sentiment_label_b, 204 | self.is_train: True if data_type == 'train' else False} 205 | return feed_dict 206 | 207 | def step(self, sess, batch_samples, get_summary=False): 208 | assert isinstance(sess, tf.Session) 209 | feed_dict = self.get_feed_dict(batch_samples, 'train') 210 | cfg.time_counter.add_start() 211 | if get_summary: 212 | loss, summary, train_op = sess.run([self.loss, self.summary, self.train_op], feed_dict=feed_dict) 213 | 214 | else: 215 | loss, train_op = sess.run([self.loss, self.train_op], feed_dict=feed_dict) 216 | summary = None 217 | cfg.time_counter.add_stop() 218 | return loss, summary, train_op --------------------------------------------------------------------------------
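ModelTemplate.get_feed_dict above pads every batch to the in-batch maximum lengths before feeding the placeholders. The following standalone sketch is not repository code and only assumes NumPy; it illustrates the same zero-padding scheme, where token id 0 doubles as the pad id so that the mask can be recovered with a plain cast to bool, exactly as template.py does with tf.cast(self.token_seq, tf.bool).

# Standalone illustration (not repository code) of the zero-padding used in get_feed_dict.
import numpy as np

batch_token_ids = [[4, 9, 2], [7, 3, 3, 8, 1]]          # two tokenized sentences as id lists
sl = max(len(seq) for seq in batch_token_ids)           # in-batch maximum token length

token_seq_b = np.zeros([len(batch_token_ids), sl], dtype=np.int32)
for row, seq in enumerate(batch_token_ids):
    token_seq_b[row, :len(seq)] = seq                   # copy ids, leave the rest as pad id 0

print(token_seq_b)
# [[4 9 2 0 0]
#  [7 3 3 8 1]]
print(token_seq_b.astype(bool).sum(axis=1))             # recovered token lengths: [3 5]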
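The span and answer-matching helpers in src/utils/nlp.py follow the usual SQuAD-style recipe: map character offsets to token indices with get_1d_spans/get_word_idxs_1d, and compare normalized answer strings with exact match and token-overlap F1. Below is a minimal usage sketch; it is illustrative only and assumes the package is importable as src.utils.nlp (project directory on PYTHONPATH) and that NLTK's 'punkt' tokenizer data is available for nltk.word_tokenize.

# Usage sketch (not repository code) for the 1d-span and EM/F1 helpers.
from src.utils.nlp import get_1d_spans, get_word_idxs_1d, check_rebuild_quality

context = "The quick brown fox jumps over the lazy dog."
tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "."]

# Character span (start, end) of every token in the raw context string.
spans = get_1d_spans(context, tokens)             # [(0, 3), (4, 9), (10, 15), ...]

# 0-based indices of the tokens overlapping characters 4..19 ("quick brown fox").
print(get_word_idxs_1d(context, tokens, 4, 19))   # -> [1, 2, 3]

# Exact match and token-overlap F1 between a rebuilt answer and the ground truth;
# normalize_answer lower-cases, strips punctuation and articles, and re-tokenizes first.
em, f1 = check_rebuild_quality("the quick brown fox", "a quick brown fox jumps")
print(em, f1)                                     # -> False, ~0.857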
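dynamic_length and dynamic_keep (defined in both copies of nlp.py) pick a truncation length and a vocabulary by cumulative-frequency coverage rather than by hard limits. A small illustrative sketch under the same import assumption as above; the printed values depend on the input distribution and are shown only as an example.

# Usage sketch (not repository code) for the coverage-based preprocessing helpers.
from src.utils.nlp import dynamic_length, dynamic_keep

sent_lengths = [5, 7, 7, 8, 12, 40]     # token counts of the training sentences
all_tokens = ["the", "the", "movie", "movie", "is", "great", "boring"]

# Smallest length whose cumulative frequency covers 80% of the sentences,
# together with the true maximum (security=True caps the ratio at 0.99).
print(dynamic_length(sent_lengths, 0.8))   # -> (12, 40)

# Most frequent tokens that together cover about 70% of all token occurrences,
# e.g. for building a compact token dictionary.
print(dynamic_keep(all_tokens, 0.7))       # -> e.g. ['the', 'movie', 'is']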