├── .gitignore ├── README.md ├── data └── .gitkeep ├── requirements.txt ├── src ├── ShortTextTopic │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── gsdmm │ │ ├── __init__.py │ │ └── mgp.py │ ├── requirements.txt │ ├── setup.py │ └── test │ │ ├── __init__.py │ │ └── test_gsdmm.py ├── __init__.py ├── evaluation │ ├── __init__.py │ ├── difficulty │ │ ├── __init__.py │ │ ├── difficulty_cases.py │ │ └── difficulty_metrics.py │ ├── evaluate.py │ └── metrics │ │ ├── __init__.py │ │ ├── dice.py │ │ ├── jaccard.py │ │ ├── js_div.py │ │ └── lexical_similarity.py ├── experiments │ ├── bert.py │ └── tbert.py ├── loaders │ ├── MSRP │ │ ├── __init__.py │ │ └── build.py │ ├── PAWS │ │ ├── __init__.py │ │ └── build.py │ ├── Quora │ │ ├── __init__.py │ │ └── build.py │ ├── Semeval │ │ ├── __init__.py │ │ ├── agents.py │ │ ├── build.py │ │ └── helper.py │ ├── __init__.py │ ├── augment_data.py │ ├── build_data.py │ └── load_data.py ├── logs │ ├── tf_event_logs.py │ └── training_logs.py ├── models │ ├── __init__.py │ ├── base_model_bert.py │ ├── forward │ │ ├── __init__.py │ │ ├── bert.py │ │ └── bert_simple_topic.py │ ├── helpers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── bert.py │ │ ├── minibatching.py │ │ └── training_regimes.py │ ├── save_load.py │ ├── tf_helpers.py │ └── topic_baseline.py ├── preprocessing │ ├── Preprocessor.py │ ├── __init__.py │ └── bert_tokenization.py └── topic_model │ ├── __init__.py │ ├── gsdmm.py │ ├── lda.py │ ├── topic_eval.py │ ├── topic_loader.py │ ├── topic_mass_trainer.py │ ├── topic_predictor.py │ ├── topic_trainer.py │ └── topic_visualiser.py └── tBERT.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled source # 2 | ################### 3 | *.com 4 | *.class 5 | *.dll 6 | *.exe 7 | *.o 8 | *.so 9 | 10 | # Packages # 11 | ############ 12 | # it's better to unpack these files and commit the raw source 13 | # git has its own built in compression methods 14 | *.7z 15 | *.dmg 16 | *.gz 17 | *.iso 18 | *.jar 19 | *.rar 20 | *.tar 21 | *.zip 22 | .python-version 23 | 24 | # Logs and databases # 25 | ###################### 26 | *.log 27 | *.sql 28 | *.sqlite 29 | 30 | # Data folder # 31 | ############### 32 | data/* 33 | !data/.gitkeep 34 | 35 | # OS generated files # 36 | ###################### 37 | .DS_Store 38 | .DS_Store? 39 | ._* 40 | .Spotlight-V100 41 | .Trashes 42 | ehthumbs.db 43 | Thumbs.db 44 | 45 | # Compiled python files # 46 | ######################### 47 | *.pyc 48 | __pycache__/ 49 | .ipynb_checkpoints 50 | .idea/ 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tBERT 2 | 3 | ![Alt text](tBERT.jpg?raw=true "tBERT model") 4 | 5 | This repository provides code for the paper "tBERT: Topic Models and BERT Joining Forces for Semantic Similarity Detection" (https://www.aclweb.org/anthology/2020.acl-main.630/). 
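
tBERT concatenates the pooled BERT representation of a sentence pair with a topic distribution (LDA Mallet or GSDMM) for each of the two sentences, and passes the result through a hidden layer into a binary softmax. The sketch below illustrates this combination step; the tensor shapes match the example output in the Usage section (768 pooled BERT dimensions plus two 80-topic vectors give a 928-dimensional combined vector), but the function name, the ReLU activation and the use of `tf.layers.dense` are illustrative only. The actual graph is defined in `src/models/forward/bert_simple_topic.py`.

```python
import tensorflow as tf  # tested with TensorFlow 1.11 (see Requirements)

def combine_bert_and_topics(pooled_bert, topics_s1, topics_s2, num_topics=80, n_classes=2):
    # pooled_bert: (batch, 768) pooled BERT output for the sentence pair
    # topics_s1 / topics_s2: (batch, num_topics) topic distribution of each sentence
    combined = tf.concat([pooled_bert, topics_s1, topics_s2], axis=-1)   # (batch, 768 + 2*num_topics)
    hidden = tf.layers.dense(combined, (768 + 2 * num_topics) // 2,      # (batch, 464) for 80 topics
                             activation=tf.nn.relu)                      # activation choice is illustrative
    logits = tf.layers.dense(hidden, n_classes)                          # (batch, 2)
    return logits
```
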
6 | 7 | ## Setup 8 | 9 | 10 | ### Download pretrained BERT 11 | 12 | - Create cache folder in home directory: 13 | ``` 14 | cd ~ 15 | mkdir tf-hub-cache 16 | cd tf-hub-cache 17 | ``` 18 | - Download pretrained BERT model and unzip: 19 | ``` 20 | wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip 21 | unzip uncased_L-12_H-768_A-12.zip 22 | ``` 23 | 24 | ### Download preprocessed data 25 | 26 | - Go to the tBERT repository: 27 | ``` 28 | cd /path/to/tBERT/ 29 | ``` 30 | - Download topic models and original datasets from dropbox: 31 | ``` 32 | wget "https://www.dropbox.com/s/6p26mkwv62677zt/original_data.tar.gz" 33 | ``` 34 | - Uncompress original_data.tar.gz: 35 | ``` 36 | tar zxvf original_data.tar.gz & 37 | ``` 38 | - Your tBERT directory should now have the following content (the Semeval folder is empty because data will be automatically downloaded by the script): 39 | ``` 40 | . 41 | ├── data 42 | │   ├── cache 43 | │   ├── logs 44 | │   ├── models 45 | │   ├── MSRP 46 | │   │   └── MSRParaphraseCorpus 47 | │   │   ├── msr-para-test.tsv 48 | │   │   ├── msr-para-train.tsv 49 | │   │   └── msr-para-val.tsv 50 | │   ├── Quora 51 | │   │   └── Quora_question_pair_partition 52 | │   │   ├── dev.tsv 53 | │   │   ├── test.tsv 54 | │   │   └── train.tsv 55 | │   ├── Semeval 56 | │   └── topic_models 57 | │   ├── basic 58 | │   │   ├── MSRP_alpha1_80 59 | │   │   ├── Quora_alpha1_90 60 | │   │   ├── Semeval_alpha10_70 61 | │   │   ├── Semeval_alpha10_80 62 | │   │   └── Semeval_alpha50_70 63 | │   ├── basic_gsdmm 64 | │   │   ├── MSRP_alpha0.1_80 65 | │   │   ├── Quora_alpha0.1_90 66 | │   │   ├── Semeval_alpha0.1_70 67 | │   │   └── Semeval_alpha0.1_80 68 | │   └── mallet-2.0.8.zip 69 | └── src 70 | ├── evaluation 71 | │   ├── difficulty 72 | │   ├── metrics 73 | ├── experiments 74 | ├── loaders 75 | │   ├── MSRP 76 | │   ├── PAWS 77 | │   ├── Quora 78 | │   └── Semeval 79 | ├── logs 80 | ├── models 81 | │   ├── forward 82 | │   ├── helpers 83 | ├── preprocessing 84 | ├── ShortTextTopic 85 | │   ├── gsdmm 86 | │   └── test 87 | └── topic_model 88 | ``` 89 | 90 | ### Requirements 91 | 92 | - This code has been tested with Python 3.6 and Tensorflow 1.11. 93 | - Install the required Python packages as defined in requirements.txt: 94 | ``` 95 | pip install -r requirements.txt 96 | ``` 97 | 98 | ## Usage 99 | 100 | - You can try out if everything works by training a model on a small portion of the data (you can play around with different model options by changing the opt dictionary). Please make sure you are in the top tBERT directory when executing the following commands (`ls` should show `data data.tar.gz README.md requirements.txt src` as output): 101 | ``` 102 | python src/models/base_model_bert.py 103 | ``` 104 | - This should produce the following output: 105 | ``` 106 | ['m_train_B', 'm_dev_B', 'm_test_B'] 107 | ['data/MSRP/m_train_B.txt', 'data/MSRP/m_dev_B.txt', 'data/MSRP/m_test_B.txt'] 108 | data/cache/m_train_B.pickle 109 | Loading cached input for m_train_B 110 | data/cache/m_dev_B.pickle 111 | Loading cached input for m_dev_B 112 | data/cache/m_test_B.pickle 113 | Loading cached input for m_test_B 114 | Mapping words to BERT ids... 115 | Finished word id mapping. 116 | Done. 
117 | {'topic_type': 'ldamallet', 'load_ids': True, 'topic': 'doc', 'minibatch_size': 10, 'seed': 1, 'max_m': 10, 'bert_large': False, 'num_topics': 80, 'num_epochs': 1, 'model': 'bert_simple_topic', 'max_length': 'minimum', 'simple_padding': True, 'padding': False, 'bert_update': True, 'L2': 0, 'dropout': 0.1, 'bert_cased': False, 'speedup_new_layers': False, 'unk_topic': 'zero', 'stopping_criterion': 'F1', 'tasks': ['B'], 'learning_rate': 0.3, 'hidden_layer': 1, 'gpu': -1, 'optimizer': 'Adadelta', 'datapath': 'data/', 'unflat_topics': False, 'sparse_labels': True, 'freeze_thaw_tune': False, 'dataset': 'MSRP', 'topic_alpha': 1, 'predict_every_epoch': False, 'unk_sub': False, 'subsets': ['train', 'dev', 'test'], 'topic_update': True} 118 | Topic scope: doc 119 | input ids shape: (?, ?) 120 | Loading pretrained model from https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1 121 | --- 122 | Model: tBERT 123 | --- 124 | D_T1 shape: (?, 80) 125 | D_T2 shape: (?, 80) 126 | pooled BERT shape: (?, 768) 127 | combined shape: (?, 928) 128 | hidden 1 shape: (?, 464) 129 | output layer shape: (?, 2) 130 | reading logs... 131 | No file found at data/logs/test.json. Creating new log. 132 | get new id 133 | Model 0 134 | 2020-07-03 19:02:31.946114: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA 135 | 2020-07-03 19:02:37.224331: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1411] Found device 0 with properties: 136 | name: Tesla K80 major: 3 minor: 7 memoryClockRate(GHz): 0.8235 137 | pciBusID: 6a05:00:00.0 138 | totalMemory: 11.17GiB freeMemory: 11.11GiB 139 | 2020-07-03 19:02:37.224390: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1490] Adding visible gpu devices: 0 140 | logfile: test.json 141 | Finetune... 142 | Epoch 1 143 | Dev F1 after epoch 1: 0.75 144 | data/models/model_0/model_epoch1.ckpt 145 | Maximum number of epochs reached during early stopping. 146 | Finished training. 147 | Load best model from epoch 1 148 | reading logs... 149 | Finished training after 0.55 min 150 | Dev F1: 0.75 151 | Test F1: 0.75 152 | reading logs... 153 | Wrote predictions for model_0. 
154 | ``` 155 | - The model will be saved under data/models/model_0/ and the training log is available under data/logs/test.json 156 | - You can also run an experiment on the complete dataset and alter different commandline flags, e.g.: 157 | ``` 158 | python src/experiments/tbert.py -dataset MSRP -layers 1 -topic doc -topic_type ldamallet -learning_rate 5e-5 --early_stopping -seed 3 -gpu 0 159 | ``` 160 | - This should give you the following output: 161 | ``` 162 | Starting experiment 1 of 1 163 | tbert_1_seed_early_stopping.json 164 | {'dropout': 0.1, 'model': 'bert_simple_topic', 'bert_cased': False, 'max_m': None, 'tasks': ['B'], 'padding': False, 'dataset': 'MSRP', 165 | 'L2': 0, 'subsets': ['train', 'dev', 'test'], 'unk_sub': False, 'hidden_layer': 1, 'datapath': 'data/', 'predict_every_epoch': False, 166 | 'num_epochs': 3, 'simple_padding': True, 'patience': 2, 'speedup_new_layers': False, 'minibatch_size': 32, 'max_length': 'minimum', 'lo 167 | ad_ids': True, 'topic_type': 'ldamallet', 'unk_topic': 'uniform', 'topic_update': False, 'sparse_labels': True, 'num_topics': 80, 'topi 168 | c_alpha': 1, 'seed': 1, 'gpu': 0, 'stopping_criterion': 'F1', 'bert_update': True, 'learning_rate': 3e-05, 'optimizer': 'Adam', 'topic' 169 | : 'doc'} 170 | ['m_train_B', 'm_dev_B', 'm_test_B'] 171 | ['data/MSRP/m_train_B.txt', 'data/MSRP/m_dev_B.txt', 'data/MSRP/m_test_B.txt'] 172 | data/cache/m_train_B.pickle 173 | Loading cached input for m_train_B 174 | data/cache/m_dev_B.pickle 175 | Loading cached input for m_dev_B 176 | data/cache/m_test_B.pickle 177 | Loading cached input for m_test_B 178 | Mapping words to BERT ids... 179 | Finished word id mapping. 180 | Done. 181 | {'dropout': 0.1, 'model': 'bert_simple_topic', 'bert_cased': False, 'max_m': None, 'tasks': ['B'], 'padding': False, 'dataset': 'MSRP', 182 | 'unflat_topics': False, 'L2': 0, 'subsets': ['train', 'dev', 'test'], 'unk_sub': False, 'hidden_layer': 1, 'datapath': 'data/', 'bert_ 183 | large': False, 'predict_every_epoch': False, 'num_epochs': 3, 'simple_padding': True, 'patience': 2, 'speedup_new_layers': False, 'mini 184 | batch_size': 32, 'max_length': 'minimum', 'load_ids': True, 'topic_type': 'ldamallet', 'unk_topic': 'uniform', 'topic_update': False, ' 185 | sparse_labels': True, 'num_topics': 80, 'topic_alpha': 1, 'seed': 1, 'gpu': 0, 'stopping_criterion': 'F1', 'bert_update': True, 'learni 186 | ng_rate': 3e-05, 'optimizer': 'Adam', 'topic': 'doc'} 187 | Running on GPU: 0 188 | Topic scope: doc 189 | input ids shape: (?, ?) 190 | Loading pretrained model from https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1 191 | --- 192 | Model: tBERT 193 | --- 194 | D_T1 shape: (?, 80) 195 | D_T2 shape: (?, 80) 196 | pooled BERT shape: (?, 768) 197 | combined shape: (?, 928) 198 | hidden 1 shape: (?, 464) 199 | output layer shape: (?, 2) 200 | reading logs... 
201 | get new id 202 | Model 1 203 | 2020-07-03 19:51:34.629485: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow bi 204 | nary was not compiled to use: AVX2 FMA 205 | 2020-07-03 19:51:39.501180: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1411] Found device 0 with properties: 206 | name: Tesla K80 major: 3 minor: 7 memoryClockRate(GHz): 0.8235 207 | pciBusID: 6a05:00:00.0 208 | totalMemory: 11.17GiB freeMemory: 11.11GiB 209 | 2020-07-03 19:51:39.501233: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1490] Adding visible gpu devices: 0 210 | 2020-07-03 19:51:39.769183: I tensorflow/core/common_runtime/gpu/gpu_device.cc:971] Device interconnect StreamExecutor with strength 1 edge matrix: 211 | 2020-07-03 19:51:39.769242: I tensorflow/core/common_runtime/gpu/gpu_device.cc:977] 0 212 | 2020-07-03 19:51:39.769263: I tensorflow/core/common_runtime/gpu/gpu_device.cc:990] 0: N 213 | 2020-07-03 19:51:39.769368: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1103] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10761 MB memory) -> physical GPU (device: 0, name: Tesla K80, pci bus id: 6a05:00:00.0, compute capability: 3.7) 214 | logfile: tbert_1_seed_early_stopping.json 215 | Finetune... 216 | Epoch 1 217 | Dev F1 after epoch 1: 0.8917378783226013 218 | data/models/model_1/model_epoch1.ckpt 219 | Epoch 2 220 | Dev F1 after epoch 2: 0.9058663249015808 221 | data/models/model_1/model_epoch2.ckpt 222 | Epoch 3 223 | Dev F1 after epoch 3: 0.8959999680519104 224 | Maximum number of epochs reached during early stopping. 225 | Finished training. 226 | Load best model from epoch 2 227 | reading logs... 228 | Finished training after 10.74 min 229 | Dev F1: 0.9059 230 | Test F1: 0.8841 231 | reading logs... 232 | Wrote predictions for model_1. 
233 | ``` 234 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/data/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.1.10 2 | adal==1.2.0 3 | antlr4-python3-runtime==4.7.1 4 | applicationinsights==0.11.7 5 | appnope==0.1.0 6 | argcomplete==1.9.4 7 | asn1crypto==0.24.0 8 | astor==0.7.1 9 | attrs==17.4.0 10 | backports.tempfile==1.0 11 | backports.weakref==1.0.post1 12 | bcrypt==3.1.5 13 | bleach==1.5.0 14 | boto==2.48.0 15 | boto3==1.5.35 16 | botocore==1.8.49 17 | bz2file==0.98 18 | certifi==2018.1.18 19 | cffi==1.11.5 20 | chardet==3.0.4 21 | colorama==0.4.1 22 | contextlib2==0.5.5 23 | cryptography==2.4.2 24 | cycler==0.10.0 25 | cymem==2.0.2 26 | cytoolz==0.9.0.1 27 | decorator==4.2.1 28 | dill==0.2.8.2 29 | diskcache==3.0.6 30 | docker==3.6.0 31 | docker-pycreds==0.4.0 32 | docutils==0.14 33 | entrypoints==0.2.3 34 | funcy==1.11 35 | future==0.16.0 36 | gast==0.2.2 37 | gensim==3.6.0 38 | gitdb2==2.0.5 39 | GitPython==2.1.11 40 | googleapis-common-protos==1.5.3 41 | grpcio==1.12.1 42 | h5py==2.7.1 43 | html5lib==0.9999999 44 | humanfriendly==4.17 45 | hyperopt==0.1 46 | idna==2.6 47 | imageio==2.4.1 48 | ipykernel==4.8.2 49 | ipython==6.2.1 50 | ipython-genutils==0.2.0 51 | ipywidgets==7.1.2 52 | isodate==0.6.0 53 | jedi==0.11.1 54 | Jinja2==2.10 55 | jmespath==0.9.3 56 | joblib==0.12.5 57 | jsonpickle==1.0 58 | jsonschema==2.6.0 59 | jupyter==1.0.0 60 | jupyter-client==5.2.2 61 | jupyter-console==5.2.0 62 | jupyter-core==4.4.0 63 | Keras==2.1.6 64 | Keras-Applications==1.0.8 65 | Keras-Preprocessing==1.1.0 66 | knack==0.5.1 67 | liac-arff==2.3.1 68 | Markdown==2.6.11 69 | MarkupSafe==1.0 70 | matplotlib==2.1.2 71 | mistune==0.8.3 72 | mlxtend==0.15.0.0 73 | msgpack==0.5.6 74 | msgpack-numpy==0.4.3.2 75 | msrest==0.6.2 76 | msrestazure==0.5.1 77 | murmurhash==1.0.1 78 | nbconvert==5.3.1 79 | nbformat==4.4.0 80 | ndg-httpsclient==0.5.1 81 | networkx==1.11 82 | nltk==3.2.5 83 | nose==1.3.7 84 | notebook==5.4.0 85 | numexpr==2.6.8 86 | numpy==1.15.3 87 | numpy-indexed==0.3.5 88 | oauthlib==2.1.0 89 | overrides==1.9 90 | pandas==0.22.0 91 | pandocfilters==1.4.2 92 | paramiko==2.4.2 93 | parso==0.1.1 94 | pathspec==0.5.9 95 | patsy==0.5.1 96 | pexpect==4.4.0 97 | pickleshare==0.7.4 98 | Pillow==5.3.0 99 | pkg-resources==0.0.0 100 | plac==0.9.6 101 | pluggy==0.6.0 102 | portalocker==1.2.1 103 | preshed==2.0.1 104 | progressbar33==2.4 105 | prompt-toolkit==1.0.15 106 | protobuf==3.7.0 107 | ptyprocess==0.5.2 108 | py==1.5.2 109 | pyasn1==0.4.4 110 | pycparser==2.19 111 | Pygments==2.2.0 112 | PyJWT==1.7.1 113 | pyLDAvis==2.1.2 114 | pymongo==3.6.0 115 | PyNaCl==1.3.0 116 | pyOpenSSL==18.0.0 117 | pyparsing==2.2.0 118 | pytest==3.4.1 119 | python-dateutil==2.7.5 120 | pytz==2018.3 121 | PyYAML==3.13 122 | pyzmq==17.0.0 123 | qtconsole==4.3.1 124 | regex==2018.1.10 125 | requests==2.21.0 126 | requests-oauthlib==1.0.0 127 | ruamel.yaml==0.15.51 128 | s3transfer==0.1.13 129 | scikit-learn==0.19.1 130 | scikit-optimize==0.5.2 131 | scipy==1.0.0 132 | seaborn==0.9.0 133 | SecretStorage==2.3.1 134 | Send2Trash==1.5.0 135 | simplegeneric==0.8.1 136 | six==1.12.0 137 | sklearn==0.0 138 | smart-open==1.5.6 139 | smmap2==2.0.5 140 | 
spacy==2.0.16 141 | statsmodels==0.9.0 142 | tabulate==0.8.2 143 | tensorboard==1.11.0 144 | tensorflow-gpu==1.11.0 145 | tensorflow-hub==0.3.0 146 | termcolor==1.1.0 147 | terminado==0.8.1 148 | testpath==0.3.1 149 | thinc==6.12.0 150 | toolz==0.9.0 151 | tornado==4.5.3 152 | tqdm==4.28.1 153 | traitlets==4.3.2 154 | ujson==1.35 155 | urllib3==1.23 156 | wcwidth==0.1.7 157 | websocket-client==0.54.0 158 | Werkzeug==0.14.1 159 | widgetsnbextension==3.1.4 160 | wrapt==1.10.11 161 | -------------------------------------------------------------------------------- /src/ShortTextTopic/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | ### Eclipse template 3 | 4 | .metadata 5 | bin/ 6 | tmp/ 7 | *.tmp 8 | *.bak 9 | *.swp 10 | *~.nib 11 | local.properties 12 | .settings/ 13 | .loadpath 14 | .recommenders 15 | 16 | # Eclipse Core 17 | .project 18 | 19 | # External tool builders 20 | .externalToolBuilders/ 21 | 22 | # Locally stored "Eclipse launch configurations" 23 | *.launch 24 | 25 | # PyDev specific (Python IDE for Eclipse) 26 | *.pydevproject 27 | 28 | # CDT-specific (C/C++ Development Tooling) 29 | .cproject 30 | 31 | # JDT-specific (Eclipse Java Development Tools) 32 | .classpath 33 | 34 | # Java annotation processor (APT) 35 | .factorypath 36 | 37 | # PDT-specific (PHP Development Tools) 38 | .buildpath 39 | 40 | # sbteclipse plugin 41 | .target 42 | 43 | # Tern plugin 44 | .tern-project 45 | 46 | # TeXlipse plugin 47 | .texlipse 48 | 49 | # STS (Spring Tool Suite) 50 | .springBeans 51 | 52 | # Code Recommenders 53 | .recommenders/ 54 | ### VirtualEnv template 55 | # Virtualenv 56 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 57 | .Python 58 | [Bb]in 59 | [Ii]nclude 60 | [Ll]ib 61 | [Ll]ib64 62 | [Ll]ocal 63 | [Ss]cripts 64 | pyvenv.cfg 65 | .venv 66 | pip-selfcheck.json 67 | ### JetBrains template 68 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 69 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 70 | 71 | # User-specific stuff: 72 | .idea/workspace.xml 73 | .idea/tasks.xml 74 | 75 | # Sensitive or high-churn files: 76 | .idea/dataSources/ 77 | .idea/dataSources.ids 78 | .idea/dataSources.xml 79 | .idea/dataSources.local.xml 80 | .idea/sqlDataSources.xml 81 | .idea/dynamic.xml 82 | .idea/uiDesigner.xml 83 | 84 | # Gradle: 85 | .idea/gradle.xml 86 | .idea/libraries 87 | 88 | # Mongo Explorer plugin: 89 | .idea/mongoSettings.xml 90 | 91 | ## File-based project format: 92 | *.iws 93 | 94 | ## Plugin-specific files: 95 | 96 | # IntelliJ 97 | /out/ 98 | 99 | # mpeltonen/sbt-idea plugin 100 | .idea_modules/ 101 | 102 | # JIRA plugin 103 | atlassian-ide-plugin.xml 104 | 105 | # Crashlytics plugin (for Android Studio and IntelliJ) 106 | com_crashlytics_export_strings.xml 107 | crashlytics.properties 108 | crashlytics-build.properties 109 | fabric.properties 110 | ### Python template 111 | # Byte-compiled / optimized / DLL files 112 | __pycache__/ 113 | *.py[cod] 114 | *$py.class 115 | 116 | # C extensions 117 | *.so 118 | 119 | # Distribution / packaging 120 | .Python 121 | env/ 122 | build/ 123 | develop-eggs/ 124 | dist/ 125 | downloads/ 126 | eggs/ 127 | .eggs/ 128 | lib/ 129 | lib64/ 130 | parts/ 131 | sdist/ 132 | var/ 133 | *.egg-info/ 134 | .installed.cfg 135 | *.egg 136 | 137 | # PyInstaller 138 | # Usually these files are written by a python script from a template 139 | # before PyInstaller builds the exe, so as to inject date/other 
infos into it. 140 | *.manifest 141 | *.spec 142 | 143 | # Installer logs 144 | pip-log.txt 145 | pip-delete-this-directory.txt 146 | 147 | # Unit test / coverage reports 148 | htmlcov/ 149 | .tox/ 150 | .coverage 151 | .coverage.* 152 | .cache 153 | nosetests.xml 154 | coverage.xml 155 | *,cover 156 | .hypothesis/ 157 | 158 | # Translations 159 | *.mo 160 | *.pot 161 | 162 | # Django stuff: 163 | *.log 164 | local_settings.py 165 | 166 | # Flask stuff: 167 | instance/ 168 | .webassets-cache 169 | 170 | # Scrapy stuff: 171 | .scrapy 172 | 173 | # Sphinx documentation 174 | docs/_build/ 175 | 176 | # PyBuilder 177 | target/ 178 | 179 | # Jupyter Notebook 180 | .ipynb_checkpoints 181 | 182 | # pyenv 183 | .python-version 184 | 185 | # celery beat schedule file 186 | celerybeat-schedule 187 | 188 | # dotenv 189 | .env 190 | 191 | # virtualenv 192 | .venv/ 193 | venv/ 194 | ENV/ 195 | 196 | # Spyder project settings 197 | .spyderproject 198 | 199 | # Rope project settings 200 | .ropeproject 201 | ### JetBrains template 202 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 203 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 204 | 205 | # User-specific stuff: 206 | .idea/workspace.xml 207 | .idea/tasks.xml 208 | 209 | # Sensitive or high-churn files: 210 | .idea/dataSources/ 211 | .idea/dataSources.ids 212 | .idea/dataSources.xml 213 | .idea/dataSources.local.xml 214 | .idea/sqlDataSources.xml 215 | .idea/dynamic.xml 216 | .idea/uiDesigner.xml 217 | 218 | # Gradle: 219 | .idea/gradle.xml 220 | .idea/libraries 221 | 222 | # Mongo Explorer plugin: 223 | .idea/mongoSettings.xml 224 | 225 | ## File-based project format: 226 | *.iws 227 | 228 | ## Plugin-specific files: 229 | 230 | # IntelliJ 231 | /out/ 232 | 233 | # mpeltonen/sbt-idea plugin 234 | .idea_modules/ 235 | 236 | # JIRA plugin 237 | atlassian-ide-plugin.xml 238 | 239 | # Crashlytics plugin (for Android Studio and IntelliJ) 240 | com_crashlytics_export_strings.xml 241 | crashlytics.properties 242 | crashlytics-build.properties 243 | fabric.properties 244 | ### Example user template template 245 | ### Example user template 246 | 247 | # IntelliJ project files 248 | .idea 249 | *.iml 250 | out 251 | gen 252 | -------------------------------------------------------------------------------- /src/ShortTextTopic/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ryan Walker 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/ShortTextTopic/README.md: -------------------------------------------------------------------------------- 1 | # GSDMM: Short text clustering 2 | 3 | This project implements the Gibbs sampling algorithm for a Dirichlet Mixture Model of [Yin and Wang 2014](https://pdfs.semanticscholar.org/058a/d0815ce350f0e7538e00868c762be78fe5ef.pdf) for the 4 | clustering of short text documents. 5 | Some advantages of this algorithm: 6 | - It requires only an upper bound `K` on the number of clusters 7 | - With good parameter selection, the model converges quickly 8 | - Space efficient and scalable 9 | 10 | This project is an easy-to-read reference implementation of GSDMM -- I don't plan to maintain it unless there is demand. I am however actively maintaining the much faster Rust version of GSDMM [here](https://github.com/rwalk/gsdmm-rust). 11 | 12 | ## The Movie Group Process 13 | In their paper, the authors introduce a simple conceptual model for explaining the GSDMM called the Movie Group Process. 14 | 15 | Imagine a professor is leading a film class. At the start of the class, the students 16 | are randomly assigned to `K` tables. Before class begins, the students make lists of 17 | their favorite films. The professor repeatedly reads the class roll. Each time a student's name is called, 18 | the student must select a new table satisfying one or both of the following conditions: 19 | 20 | - The new table has more students than the current table. 21 | - The new table has students with similar lists of favorite movies. 22 | 23 | By following these steps consistently, we might expect that the students eventually arrive at an "optimal" table configuration. 24 | 25 | ## Usage 26 | To use a Movie Group Process to cluster short texts, first initialize a [MovieGroupProcess](gsdmm/mgp.py): 27 | ```python 28 | from gsdmm import MovieGroupProcess 29 | mgp = MovieGroupProcess(K=8, alpha=0.1, beta=0.1, n_iters=30) 30 | ``` 31 | It's important to always choose `K` to be larger than the number of clusters you expect to exist in your data, as the algorithm 32 | can never return more than `K` clusters. 33 | 34 | To fit the model, pass the documents together with the total vocabulary size (the `fit` method in this copy requires it explicitly): 35 | ```python 36 | y = mgp.fit(docs, vocab_size) 37 | ``` 38 | Each doc in `docs` must be the list of unique tokens found in that short text document. This implementation does not support 39 | counting tokens with multiplicity (which generally has little value in short text documents).
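
Putting it together, a minimal end-to-end sketch (the toy documents are illustrative; `MovieGroupProcess`, `fit` and `choose_best_label` are defined in [gsdmm/mgp.py](gsdmm/mgp.py)):
```python
from gsdmm import MovieGroupProcess

# each document is represented by the list of unique tokens it contains
docs = [
    "where the red dog lives".split(),
    "blue cat eats mice".split(),
    "monkeys live in trees".split(),
]

# total vocabulary size over the whole corpus, required by fit()
vocab_size = len({token for doc in docs for token in doc})

mgp = MovieGroupProcess(K=8, alpha=0.1, beta=0.1, n_iters=30)
labels = mgp.fit(docs, vocab_size)                    # one cluster label per document
cluster, confidence = mgp.choose_best_label(docs[0])  # score a single document
```
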
40 | -------------------------------------------------------------------------------- /src/ShortTextTopic/gsdmm/__init__.py: -------------------------------------------------------------------------------- 1 | from .mgp import MovieGroupProcess -------------------------------------------------------------------------------- /src/ShortTextTopic/gsdmm/mgp.py: -------------------------------------------------------------------------------- 1 | from numpy.random import multinomial 2 | from numpy import log, exp 3 | from numpy import argmax 4 | import json 5 | 6 | class MovieGroupProcess: 7 | def __init__(self, K=8, alpha=0.1, beta=0.1, n_iters=30): 8 | ''' 9 | A MovieGroupProcess is a conceptual model introduced by Yin and Wang 2014 to 10 | describe their Gibbs sampling algorithm for a Dirichlet Mixture Model for the 11 | clustering of short text documents. 12 | Reference: http://dbgroup.cs.tsinghua.edu.cn/wangjy/papers/KDD14-GSDMM.pdf 13 | 14 | Imagine a professor is leading a film class. At the start of the class, the students 15 | are randomly assigned to K tables. Before class begins, the students make lists of 16 | their favorite films. The teacher reads the roll n_iters times. When 17 | a student is called, the student must select a new table satisfying either: 18 | 1) The new table has more students than the current table. 19 | OR 20 | 2) The new table has students with similar lists of favorite movies. 21 | 22 | :param K: int 23 | Upper bound on the number of possible clusters. Typically many fewer clusters are actually populated. 24 | :param alpha: float between 0 and 1 25 | Alpha controls the probability that a student will join a table that is currently empty. 26 | When alpha is 0, no one will join an empty table. 27 | :param beta: float between 0 and 1 28 | Beta controls the student's affinity for other students with similar interests. A low beta means 29 | that students desire to sit with students of similar interests.
A high beta means they are less 30 | concerned with affinity and are more influenced by the popularity of a table 31 | :param n_iters: 32 | ''' 33 | self.K = K 34 | self.alpha = alpha 35 | self.beta = beta 36 | self.n_iters = n_iters 37 | 38 | # slots for computed variables 39 | self.number_docs = None 40 | self.vocab_size = None 41 | self.cluster_doc_count = [0 for _ in range(K)] 42 | self.cluster_word_count = [0 for _ in range(K)] 43 | self.cluster_word_distribution = [{} for i in range(K)] 44 | 45 | @staticmethod 46 | def from_data(K, alpha, beta, D, vocab_size, cluster_doc_count, cluster_word_count, cluster_word_distribution): 47 | ''' 48 | Reconstitute a MovieGroupProcess from previously fit data 49 | :param K: 50 | :param alpha: 51 | :param beta: 52 | :param D: 53 | :param vocab_size: 54 | :param cluster_doc_count: 55 | :param cluster_word_count: 56 | :param cluster_word_distribution: 57 | :return: 58 | ''' 59 | mgp = MovieGroupProcess(K, alpha, beta, n_iters=30) 60 | mgp.number_docs = D 61 | mgp.vocab_size = vocab_size 62 | mgp.cluster_doc_count = cluster_doc_count 63 | mgp.cluster_word_count = cluster_word_count 64 | mgp.cluster_word_distribution = cluster_word_distribution 65 | return mgp 66 | 67 | @staticmethod 68 | def _sample(p): 69 | ''' 70 | Sample with probability vector p from a multinomial distribution 71 | :param p: list 72 | List of probabilities representing probability vector for the multinomial distribution 73 | :return: int 74 | index of randomly selected output 75 | ''' 76 | return [i for i, entry in enumerate(multinomial(1, p)) if entry != 0][0] 77 | 78 | def fit(self, docs, vocab_size): 79 | ''' 80 | Cluster the input documents 81 | :param docs: list of list 82 | list of lists containing the unique token set of each document 83 | :param V: total vocabulary size for each document 84 | :return: list of length len(doc) 85 | cluster label for each document 86 | ''' 87 | alpha, beta, K, n_iters, V = self.alpha, self.beta, self.K, self.n_iters, vocab_size 88 | 89 | D = len(docs) 90 | self.number_docs = D 91 | self.vocab_size = vocab_size 92 | 93 | # unpack to easy var names 94 | m_z, n_z, n_z_w = self.cluster_doc_count, self.cluster_word_count, self.cluster_word_distribution 95 | cluster_count = K 96 | d_z = [None for i in range(len(docs))] 97 | 98 | # initialize the clusters 99 | for i, doc in enumerate(docs): 100 | 101 | # choose a random initial cluster for the doc 102 | z = self._sample([1.0 / K for _ in range(K)]) 103 | d_z[i] = z 104 | m_z[z] += 1 105 | n_z[z] += len(doc) 106 | 107 | for word in doc: 108 | if word not in n_z_w[z]: 109 | n_z_w[z][word] = 0 110 | n_z_w[z][word] += 1 111 | 112 | for _iter in range(n_iters): 113 | total_transfers = 0 114 | 115 | for i, doc in enumerate(docs): 116 | 117 | # remove the doc from it's current cluster 118 | z_old = d_z[i] 119 | 120 | m_z[z_old] -= 1 121 | n_z[z_old] -= len(doc) 122 | 123 | for word in doc: 124 | n_z_w[z_old][word] -= 1 125 | 126 | # compact dictionary to save space 127 | if n_z_w[z_old][word] == 0: 128 | del n_z_w[z_old][word] 129 | 130 | # draw sample from distribution to find new cluster 131 | p = self.score(doc) 132 | z_new = self._sample(p) 133 | 134 | # transfer doc to the new cluster 135 | if z_new != z_old: 136 | total_transfers += 1 137 | 138 | d_z[i] = z_new 139 | m_z[z_new] += 1 140 | n_z[z_new] += len(doc) 141 | 142 | for word in doc: 143 | if word not in n_z_w[z_new]: 144 | n_z_w[z_new][word] = 0 145 | n_z_w[z_new][word] += 1 146 | 147 | cluster_count_new = sum([1 for v in m_z if v > 0]) 148 | 
print("In stage %d: transferred %d clusters with %d clusters populated" % ( 149 | _iter, total_transfers, cluster_count_new)) 150 | if total_transfers == 0 and cluster_count_new == cluster_count and _iter>25: 151 | print("Converged. Breaking out.") 152 | break 153 | cluster_count = cluster_count_new 154 | self.cluster_word_distribution = n_z_w 155 | return d_z 156 | 157 | def score(self, doc): 158 | ''' 159 | Score a document 160 | 161 | Implements formula (3) of Yin and Wang 2014. 162 | http://dbgroup.cs.tsinghua.edu.cn/wangjy/papers/KDD14-GSDMM.pdf 163 | 164 | :param doc: list[str]: The doc token stream 165 | :return: list[float]: A length K probability vector where each component represents 166 | the probability of the document appearing in a particular cluster 167 | ''' 168 | alpha, beta, K, V, D = self.alpha, self.beta, self.K, self.vocab_size, self.number_docs 169 | m_z, n_z, n_z_w = self.cluster_doc_count, self.cluster_word_count, self.cluster_word_distribution 170 | 171 | p = [0 for _ in range(K)] 172 | 173 | # We break the formula into the following pieces 174 | # p = N1*N2/(D1*D2) = exp(lN1 - lD1 + lN2 - lD2) 175 | # lN1 = log(m_z[z] + alpha) 176 | # lN2 = log(D - 1 + K*alpha) 177 | # lN2 = log(product(n_z_w[w] + beta)) = sum(log(n_z_w[w] + beta)) 178 | # lD2 = log(product(n_z[d] + V*beta + i -1)) = sum(log(n_z[d] + V*beta + i -1)) 179 | 180 | lD1 = log(D - 1 + K * alpha) 181 | doc_size = len(doc) 182 | for label in range(K): 183 | lN1 = log(m_z[label] + alpha) 184 | lN2 = 0 185 | lD2 = 0 186 | for word in doc: 187 | lN2 += log(n_z_w[label].get(word, 0) + beta) 188 | for j in range(1, doc_size +1): 189 | lD2 += log(n_z[label] + V * beta + j - 1) 190 | p[label] = exp(lN1 - lD1 + lN2 - lD2) 191 | 192 | # normalize the probability vector 193 | pnorm = sum(p) 194 | pnorm = pnorm if pnorm>0 else 1 195 | return [pp/pnorm for pp in p] 196 | 197 | def choose_best_label(self, doc): 198 | ''' 199 | Choose the highest probability label for the input document 200 | :param doc: list[str]: The doc token stream 201 | :return: 202 | ''' 203 | p = self.score(doc) 204 | return argmax(p),max(p) 205 | -------------------------------------------------------------------------------- /src/ShortTextTopic/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy -------------------------------------------------------------------------------- /src/ShortTextTopic/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | VERSION=0.1 4 | INSTALL_REQUIRES = [ 5 | 'numpy' 6 | ] 7 | 8 | setup( 9 | name='gsdmm', 10 | packages=['gsdmm'], 11 | version=0.1, 12 | url='https://www.github.com/rwalk/gsdmm', 13 | author='Ryan Walker', 14 | author_email='ryan@ryanwalker.us', 15 | description='GSDMM: Short text clustering ', 16 | license='MIT', 17 | install_requires=INSTALL_REQUIRES 18 | ) 19 | -------------------------------------------------------------------------------- /src/ShortTextTopic/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/ShortTextTopic/test/__init__.py -------------------------------------------------------------------------------- /src/ShortTextTopic/test/test_gsdmm.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | from gsdmm.mgp import MovieGroupProcess 3 | import 
numpy 4 | 5 | class TestGSDMM(TestCase): 6 | '''This class tests the Panel data structures needed to support the RSK model''' 7 | 8 | def setUp(self): 9 | numpy.random.seed(47) 10 | 11 | def tearDown(self): 12 | numpy.random.seed(None) 13 | 14 | def compute_V(self, texts): 15 | V = set() 16 | for text in texts: 17 | for word in text: 18 | V.add(word) 19 | return len(V) 20 | 21 | def test_grades(self): 22 | 23 | grades = list(map(list, [ 24 | "A", 25 | "A", 26 | "A", 27 | "B", 28 | "B", 29 | "B", 30 | "B", 31 | "C", 32 | "C", 33 | "C", 34 | "C", 35 | "C", 36 | "C", 37 | "C", 38 | "C", 39 | "C", 40 | "C", 41 | "D", 42 | "D", 43 | "F", 44 | "F", 45 | "P", 46 | "W" 47 | ])) 48 | 49 | grades = grades + grades + grades + grades + grades 50 | mgp = MovieGroupProcess(K=100, n_iters=100, alpha=0.001, beta=0.01) 51 | y = mgp.fit(grades, self.compute_V(grades)) 52 | self.assertEqual(len(set(y)), 7) 53 | for words in mgp.cluster_word_distribution: 54 | self.assertTrue(len(words) in {0,1}, "More than one grade ended up in a cluster!") 55 | 56 | def test_short_text(self): 57 | # there is no perfect segmentation of this text data: 58 | texts = [ 59 | "where the red dog lives", 60 | "red dog lives in the house", 61 | "blue cat eats mice", 62 | "monkeys hate cat but love trees", 63 | "green cat eats mice", 64 | "orange elephant never forgets", 65 | "orange elephant must forget", 66 | "monkeys eat banana", 67 | "monkeys live in trees", 68 | "elephant", 69 | "cat", 70 | "dog", 71 | "monkeys" 72 | ] 73 | 74 | texts = [text.split() for text in texts] 75 | V = self.compute_V(texts) 76 | mgp = MovieGroupProcess(K=30, n_iters=100, alpha=0.2, beta=0.01) 77 | y = mgp.fit(texts, V) 78 | self.assertTrue(len(set(y))<10) 79 | self.assertTrue(len(set(y))>3) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/__init__.py -------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/evaluation/__init__.py -------------------------------------------------------------------------------- /src/evaluation/difficulty/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/evaluation/difficulty/__init__.py -------------------------------------------------------------------------------- /src/evaluation/difficulty/difficulty_metrics.py: -------------------------------------------------------------------------------- 1 | def count_tp_tn_fp_fn(split_df,case=None): 2 | if case is None: 3 | tp = len(split_df.loc[(split_df['gold_label']==1) & (split_df['pred_label']==1)]) 4 | tn = len(split_df.loc[(split_df['gold_label']==0) & (split_df['pred_label']==0)]) 5 | fp = len(split_df.loc[(split_df['gold_label']==0) & (split_df['pred_label']==1)]) 6 | fn = len(split_df.loc[(split_df['gold_label']==1) & (split_df['pred_label']==0)]) 7 | else: 8 | tp = len(split_df.loc[(split_df['difficulty'].str.contains(case)) & (split_df['gold_label']==1) & (split_df['pred_label']==1)]) 9 | tn = len(split_df.loc[(split_df['difficulty'].str.contains(case)) & (split_df['gold_label']==0) & 
(split_df['pred_label']==0)]) 10 | fp = len(split_df.loc[(split_df['difficulty'].str.contains(case)) & (split_df['gold_label']==0) & (split_df['pred_label']==1)]) 11 | fn = len(split_df.loc[(split_df['difficulty'].str.contains(case)) & (split_df['gold_label']==1) & (split_df['pred_label']==0)]) 12 | return [tp,tn,fp,fn] 13 | 14 | def safe_div(x,y): 15 | if y == 0: 16 | return 0 17 | return x / y 18 | def calculate_accuracy(tp,tn,fp,fn): 19 | return round(safe_div((tp+tn), (tp+tn+fp+fn)),3) 20 | def calculate_precision(tp,tn,fp,fn): 21 | return round(safe_div(tp, (tp+fp)),3) 22 | def calculate_recall(tp,tn,fp,fn): 23 | return round(safe_div(tp, (tp+fn)),3) 24 | def calculate_true_neg_rate(tp,tn,fp,fn): 25 | return round(safe_div(tn , (tn+fp)),3) 26 | def calculate_f1(precision,recall): 27 | return round(safe_div(2 * precision * recall, (precision + recall)), 3) 28 | 29 | 30 | if __name__ == '__main__': 31 | from src.evaluation.metrics.lexical_similarity import LexicalSimilarity 32 | from src.evaluation.difficulty.difficulty_cases import annotate_difficulty_case,load_subset_pred_overlap 33 | 34 | 35 | VM_path = False 36 | opt = {'dataset': 'Semeval', 'datapath': 'data/', 37 | 'tasks': ['B'],'n_gram_embd':False, 38 | 'subsets': ['train_large','dev', 'test'], 39 | 'load_ids': False, 'cache':True,'id':1155, 40 | 'simple_padding': True, 'padding': True} 41 | metric = 'js-div' 42 | LexSim = LexicalSimilarity() 43 | overlapping = LexSim.get_metric(metric,opt['dataset'],opt['tasks'][0],'test') 44 | test_df = load_subset_pred_overlap(opt, 'test', overlapping, VM_path) 45 | test_df 46 | split_df = annotate_difficulty_case(test_df,metric,split_by='median') 47 | count_tp_tn_fp_fn(split_df) 48 | [tp,tn,fp,fn] = count_tp_tn_fp_fn(split_df) 49 | print(calculate_accuracy(tp,tn,fp,fn)) 50 | print(calculate_precision(tp,tn,fp,fn)) 51 | print(calculate_recall(tp,tn,fp,fn)) 52 | print(calculate_f1(calculate_precision(tp,tn,fp,fn),calculate_recall(tp,tn,fp,fn))) -------------------------------------------------------------------------------- /src/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | from src.models.save_load import get_model_dir 2 | import csv 3 | import numpy as np 4 | import tensorflow as tf 5 | import pandas as pd 6 | from src.loaders.load_data import load_data 7 | 8 | def get_confidence_scores(Z3,normalised=False): 9 | if normalised: 10 | # normalise logits to get probablities?? 11 | Z3 = tf.nn.softmax(Z3) 12 | conf_score = tf.gather(Z3,1,axis=1,name='conf_score') # is equal to: Z3[:,1] 13 | return conf_score 14 | 15 | def save_eval_metrics(metrics, opt, data_split='test',dict_key='score'): 16 | # dev_metrics = [acc, prec, rec, f_1, ma_p] 17 | metric_names = ['Accuracy', 'Precision', 'Recall', 'F1', 'MAP'] 18 | if dict_key not in opt: 19 | opt[dict_key] = {} 20 | for i,eval_score in enumerate(metrics): 21 | metric = metric_names[i] 22 | if metric not in opt[dict_key]: 23 | opt[dict_key][metric] = {} 24 | if eval_score is None: 25 | opt[dict_key][metric][data_split] = eval_score 26 | else: 27 | opt[dict_key][metric][data_split] = round(float(eval_score), 4) # prevent problem when writing log file 28 | return opt 29 | 30 | def output_predictions(query_ids,doc_ids,Z3,Y,subset,opt): 31 | ''' 32 | Writes an output files with system predictions to be evaluated by official Semeval scorer. 
33 | :param query_ids: list of question ids 34 | :param doc_ids: list of document ids 35 | :param Z3: numpy array with ranking scores (m,) 36 | :param Y: numpy array with True / False (m,) 37 | :param opt: 38 | :param subset: string indicating which subset of the data ('train','dev','test) 39 | :return: 40 | ''' 41 | if 'PAWS' in subset: 42 | subset = subset.replace('PAWS','p') 43 | outfile = get_model_dir(opt)+'subtask'+''.join(opt['tasks'])+'.'+subset+'.pred' 44 | with open(outfile, 'w') as f: 45 | file_writer = csv.writer(f,delimiter='\t') 46 | # print(Y) 47 | label = [str(e==1).lower() for e in Y] 48 | for i in range(len(query_ids)): 49 | file_writer.writerow([query_ids[i],doc_ids[i],0,Z3[i],label[i]]) 50 | 51 | def read_predictions(opt,subset='dev',VM_path=True): 52 | ''' 53 | Reads prediction file from model directory, extracts pair id, prediction score and predicted label. 54 | :param opt: option log 55 | :param subset: ['train','dev','test'] 56 | :param VM_path: was prediction file transferred from VM? 57 | :return: pandas dataframe 58 | ''' 59 | if type(opt['id'])==str: 60 | if opt['dataset']=='Semeval': 61 | outfile = get_model_dir(opt,VM_copy=VM_path)+'subtask_'+''.join(opt['tasks'])+'_'+subset+'.txt' 62 | else: 63 | outfile = get_model_dir(opt, VM_copy=VM_path) + 'subtask' + ''.join(opt['tasks']) + '.' + subset + '.pred' 64 | else: 65 | outfile = get_model_dir(opt,VM_copy=VM_path)+'subtask'+''.join(opt['tasks'])+'.'+subset+'.pred' 66 | print(outfile) 67 | predictions = [] 68 | with open(outfile, 'r') as f: 69 | file_reader = csv.reader(f,delimiter='\t') 70 | for id1,id2,_,score,pred_label in file_reader: 71 | pairid = id1+'-'+id2 72 | if pred_label == 'true': 73 | pred_label=1 74 | elif pred_label == 'false': 75 | pred_label=0 76 | else: 77 | raise ValueError("Output labels should be 'true' or 'false', but are {}.".format(pred_label)) 78 | predictions.append([pairid,score,pred_label]) 79 | cols = ['pair_id','score','pred_label'] 80 | prediction_df = pd.DataFrame.from_records(predictions,columns=cols) 81 | return prediction_df 82 | 83 | def read_original_data(opt, subset='dev'): 84 | ''' 85 | Reads original labelled dev file from data directory, extracts get pair_id, gold_label and sentences. 
86 | :param opt: option log 87 | :param subset: ['train','dev','test'] 88 | :return: pandas dataframe 89 | ''' 90 | # adjust filenames in case of increased training data 91 | if 'train_large' in opt['subsets']: 92 | print('adjusting names') 93 | if subset=='dev': 94 | subset='test2016' 95 | elif subset=='test': 96 | subset='test2017' 97 | # adjust loading options: 98 | opt['subsets'] = [subset] # only specific subset 99 | opt['load_ids'] = True # with labels 100 | # print(opt) 101 | data_dict = load_data(opt,numerical=False) 102 | ID1 = data_dict['ID1'][0] # unlist, as we are only dealing with one subset 103 | ID2 = data_dict['ID2'][0] 104 | R1 = data_dict['R1'][0] 105 | R2 = data_dict['R2'][0] 106 | L = data_dict['L'][0] 107 | # extract get pair_id, gold_label, sentences 108 | labeled_data = [] 109 | for i in range(len(L)): 110 | pair_id = ID1[i]+'-'+ID2[i] 111 | gold_label = L[i] 112 | s1 = R1[i] 113 | s2 = R2[i] 114 | labeled_data.append([pair_id,gold_label,s1,s2]) 115 | # turn into pandas dataframe 116 | cols = ['pair_id','gold_label','s1','s2'] 117 | label_df = pd.DataFrame.from_records(labeled_data,columns=cols) 118 | return label_df -------------------------------------------------------------------------------- /src/evaluation/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/evaluation/metrics/__init__.py -------------------------------------------------------------------------------- /src/evaluation/metrics/dice.py: -------------------------------------------------------------------------------- 1 | def dice_similarity(list1, list2, print_intersection=False): 2 | ''' 3 | Calculates Dice Coefficient between two lists of words 4 | ''' 5 | set1 = set(list1) 6 | set2 = set(list2) 7 | intersection = len(list(set1.intersection(list2))) 8 | if print_intersection: 9 | print(list(set1.intersection(list2))) 10 | # union = (len(list1) + len(list2)) - intersection 11 | return 2*float(intersection / (len(set1)+len(set2))) 12 | 13 | 14 | def calculate_dice_sim(R1,R2): 15 | ''' 16 | Calculates Dice Coefficient for each sentence pair and returns nested list with similarities in each subset. 17 | :param R1: raw text 1 18 | :param R2: raw text 2 19 | :return : nested list with dice coefficent 20 | ''' 21 | subset_overlap = [] 22 | for n in range(len(R1)): 23 | sim_per_pair = [] 24 | for i in range(len(R1[n])): 25 | s1 = R1[n][i] 26 | s2 = R2[n][i] 27 | sim = dice_similarity(s1, s2) 28 | sim_per_pair.append(sim) 29 | subset_overlap.append(sim_per_pair) 30 | return subset_overlap 31 | 32 | 33 | if __name__ == '__main__': 34 | l1 = ['this', 'is', 'great'] 35 | l2 = ['this', 'not'] 36 | 37 | dice_similarity(l1, l2, True) 38 | -------------------------------------------------------------------------------- /src/evaluation/metrics/jaccard.py: -------------------------------------------------------------------------------- 1 | def jaccard_similarity(list1, list2, print_intersection=False): 2 | ''' 3 | Calculates Jaccard Index between two lists of words 4 | ''' 5 | intersection = len(list(set(list1).intersection(list2))) 6 | if print_intersection: 7 | print(list(set(list1).intersection(list2))) 8 | union = (len(list1) + len(list2)) - intersection 9 | return float(intersection / union) 10 | 11 | 12 | def calculate_jaccard_index(R1,R2): 13 | ''' 14 | Calculates Jaccard Index for each sentence pair and returns nested list with similarities in each subset. 
15 | :param R1: raw text 1 16 | :param R2: raw text 2 17 | :return : nested list with jaccard index 18 | ''' 19 | subset_overlap = [] 20 | 21 | for n in range(len(R1)): 22 | sim_per_pair = [] 23 | for i in range(len(R1[n])): 24 | s1 = R1[n][i] 25 | s2 = R2[n][i] 26 | sim = jaccard_similarity(s1,s2) 27 | sim_per_pair.append(sim) 28 | subset_overlap.append(sim_per_pair) 29 | return subset_overlap 30 | 31 | if __name__=='__main__': 32 | l1 = ['this','is','great'] 33 | l2 = ['this','not'] 34 | print(jaccard_similarity(l1, l2, True)) -------------------------------------------------------------------------------- /src/evaluation/metrics/js_div.py: -------------------------------------------------------------------------------- 1 | import scipy 2 | import numpy as np 3 | from src.loaders.load_data import load_data 4 | import re, math, collections 5 | from nltk.corpus import stopwords 6 | 7 | def js(p, q, print_dist=False): 8 | # _p = p / norm(p, ord=1) 9 | # _q = q / norm(q, ord=1) 10 | _p = p/p.sum() 11 | _q = q/q.sum() 12 | _m = (_p + _q) / 2 13 | if print_dist: 14 | print('p\n{}'.format(_p)) 15 | print('q\n{}'.format(_q)) 16 | print('m\n{}'.format(_m)) 17 | return (scipy.stats.entropy(_p, _m,base=2) + scipy.stats.entropy(_q, _m,base=2)) / 2 18 | 19 | # split into tokens and count individual tokens (dict) 20 | def tokenize(_str,stopword_filter=False): 21 | # return _str.split(' ') 22 | # stopwords = ['and', 'for', 'if', 'the', 'then', 'be', 'is', 'are', 'will', 'in', 'it', 'to', 'that'] 23 | tokens = collections.defaultdict(lambda: 0.) 24 | if stopword_filter: 25 | stop_words = set(stopwords.words('english')) 26 | for m in _str: 27 | # m = m.group(1).lower() 28 | # if len(m) < 2: continue 29 | if len(m) < 2 or m in stop_words: 30 | pass 31 | else: 32 | tokens[m] += 1 33 | else: 34 | for m in _str: 35 | # m = m.group(1).lower() 36 | # if len(m) < 2: continue 37 | # if len(m) < 2: 38 | # pass 39 | # else: 40 | tokens[m] += 1 41 | return tokens 42 | 43 | # get tokens not contained in each of the distributions, add with count 0 44 | def add_missing_tokens(dict1,dict2): 45 | missing_from_dict1 = set(dict2.keys()).difference(set(dict1.keys())) # elements in dict2, but not dict1 46 | # print(missing_from_dict1) 47 | for v in missing_from_dict1: 48 | dict1[v]=0 49 | missing_from_dict2 = set(dict1.keys()).difference(set(dict2.keys())) # elements in dict1, but not dict2 50 | # print(missing_from_dict2) 51 | for v in missing_from_dict2: 52 | dict2[v]=0 53 | return dict1,dict2 54 | 55 | # order dict by keys alphabetically and turn into list 56 | def to_list(d1_complete,d2_complete): 57 | keys = sorted(list(d1_complete.keys())) 58 | # print(keys) 59 | l1 = [d1_complete[k] for k in keys] 60 | l2 = [d2_complete[k] for k in keys] 61 | return l1,l2 62 | 63 | # put everything together in one function 64 | def js_divergence(d1, d2,print_m=False): 65 | ''' 66 | Calculates JS divergence between two documents (str) 67 | ''' 68 | d1_complete,d2_complete = add_missing_tokens(tokenize(d1),tokenize(d2)) 69 | if print_m: 70 | print(d1_complete) 71 | print(d2_complete) 72 | l1,l2 = to_list(d1_complete,d2_complete) 73 | return js(np.array(l1), np.array(l2),print_m) 74 | 75 | def calculate_js_div(R1,R2): 76 | ''' 77 | Loads dataset defined by opt, calculates Jensen-Shannon divergence for each sentence pair and returns nested list with Jaccard similarities in each subset. 
78 | :param opt: option dictionary to load dataset 79 | :return : nested list with overlap ratios 80 | ''' 81 | subset_overlap = [] 82 | for n in range(len(R1)): 83 | sim_per_pair = [] 84 | for i in range(len(R1[n])): 85 | s1 = R1[n][i] 86 | s2 = R2[n][i] 87 | try: 88 | sim = js_divergence(s1,s2) # change here 89 | except TypeError: 90 | print(s1) 91 | print(s2) 92 | print('---') 93 | sim_per_pair.append(sim) 94 | subset_overlap.append(sim_per_pair) 95 | return subset_overlap 96 | 97 | if __name__=='__main__': 98 | 99 | # js(np.array([0.1, 0.9]), np.array([1.0, 0.0]), True) 100 | # 101 | # d1 = """Many research publications want you to use BibTeX, which better 102 | # organizes the whole process. Suppose for concreteness your source 103 | # file is x.tex. Basically, you create a file x.bib containing the 104 | # bibliography, and run bibtex on that file.""" 105 | # d2 = """In this case you must supply both a \left and a \right because the 106 | # delimiter height are made to match whatever is contained between the 107 | # two commands. But, the \left doesn't have to be an actual 'left 108 | # delimiter', that is you can use '\left)' if there were some reason 109 | # to do it.""" 110 | # d1_complete, d2_complete = add_missing_tokens(tokenize(d1), tokenize(d2)) 111 | # print(d1_complete) 112 | # l1, l2 = to_list(d1_complete, d2_complete) 113 | # js(np.array(l1), np.array(l2), True) 114 | # 115 | # js_divergence(d1,d2) 116 | 117 | opt = {'dataset': 'MSRP', 'datapath': 'data/', 118 | 'tasks': ['B'], 'n_gram_embd': False, 119 | 'subsets': ['train', 'dev', 'test'], 'simple_padding': True, 'padding': True, 120 | 'model': 'basic_cnn', 'load_ids': False, 'cache': True} 121 | js = calculate_js_div(opt) -------------------------------------------------------------------------------- /src/experiments/bert.py: -------------------------------------------------------------------------------- 1 | from src.loaders.load_data import load_data 2 | from src.models.base_model_bert import model,test_opt 3 | import argparse 4 | 5 | # run bert with different learning rates on a certain dataset 6 | # example usage: python src/experiments/bert_baseline.py -dataset MSRP -learning_rate 5e-05 -gpu 0 --debug 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.register("type", "bool", lambda v: v.lower() == "true") 10 | parser.add_argument('-dataset', action="store", dest="dataset", type=str, default='MSRP') 11 | parser.add_argument('-learning_rate', action="store", dest="learning_rate", type=str, default='3e-5') # learning_rates = [5e-5, 2e-5, 4e-5, 3e-5] 12 | parser.add_argument("--debug",type="bool",nargs="?",const=True,default=False,help="Try to use small number of examples for troubleshooting") 13 | parser.add_argument("--speedup_new_layers",type="bool",nargs="?",const=True,default=False,help="Use 100 times higher learning rate for new layers.") 14 | parser.add_argument('-gpu', action="store", dest="gpu", type=int, default=-1) 15 | parser.add_argument('-seed', action="store", dest="seed", type=str, default='fixed') 16 | parser.add_argument("--early_stopping",type="bool",nargs="?",const=True,default=False) 17 | parser.add_argument("--train_longer",type="bool",nargs="?",const=True,default=False,help="Train for 9 epochs") 18 | 19 | FLAGS, unparsed = parser.parse_known_args() 20 | 21 | # sanity check command line arguments 22 | if len(unparsed)>0: 23 | parser.print_help() 24 | raise ValueError('Unidentified command line arguments passed: {}\n'.format(str(unparsed))) 25 | 26 | dataset = FLAGS.dataset 27 | assert 
dataset in ['MSRP', 'Semeval_A', 'Semeval_B', 'Semeval_C', 'Quora'] # 'MSRP','Semeval', 28 | 29 | tasks = [] 30 | 31 | # setting model options based on flags 32 | 33 | stopping_criterion = None #'F1' 34 | patience = None 35 | batch_size = 32 # standard minibatch size 36 | if 'Semeval' in dataset: 37 | dataset, task = dataset.split('_') 38 | subsets = ['train_large', 'test2016', 'test2017'] 39 | if task in ['A','C']: 40 | batch_size = 16 # need smaller minibatch to fit on GPU due to long sentences 41 | else: 42 | task = 'B' 43 | subsets = ['train', 'dev', 'test'] 44 | 45 | if FLAGS.debug: 46 | max_m = 100 47 | else: 48 | max_m = None 49 | load_ids = True 50 | max_length = 'minimum' 51 | ngrams = False 52 | pad = False 53 | if FLAGS.train_longer: 54 | epochs = 9 55 | predict_every_epoch = True 56 | else: 57 | epochs = 3 58 | predict_every_epoch = False 59 | 60 | if FLAGS.early_stopping: 61 | patience = 2 62 | stopping_criterion = 'F1' 63 | 64 | try: 65 | seed = int(FLAGS.seed) 66 | except: 67 | seed = None 68 | 69 | opt = {'dataset': dataset, 'datapath': 'data/', 70 | 'model': 'bert','bert_update':True,'bert_cased':False, 71 | 'tasks': [task], 72 | 'subsets': subsets,'seed':seed, 73 | 'minibatch_size': batch_size, 'L2': 0, 74 | 'max_m': max_m, 'load_ids': True, 75 | 'unk_sub': False, 'padding': False, 'simple_padding': True, 76 | 'learning_rate': float(FLAGS.learning_rate), 77 | 'num_epochs': epochs, # 'patience':20 78 | 'sparse_labels': True, 'max_length': 'minimum', 79 | 'optimizer': 'Adam', 'dropout':0.1, 80 | 'gpu': FLAGS.gpu, 81 | 'speedup_new_layers':FLAGS.speedup_new_layers, 82 | 'predict_every_epoch':predict_every_epoch, 83 | 'stopping_criterion':stopping_criterion, 'patience':patience 84 | } 85 | tasks.append(opt) 86 | 87 | if FLAGS.train_longer: 88 | log = 'bert_{}_seed_train_longer.json'.format(str(seed)) 89 | elif FLAGS.speedup_new_layers: 90 | log = 'bert_{}_seed_speedup_new_layers.json'.format(str(seed)) 91 | else: 92 | log = 'bert_{}_seed.json'.format(str(seed)) 93 | if FLAGS.early_stopping: 94 | log = log.replace('.json', '_early_stopping.json') 95 | 96 | print(log) 97 | 98 | if __name__ == '__main__': 99 | 100 | for i,opt in enumerate(tasks): 101 | print('Starting experiment {} of {}'.format(i+1,len(tasks))) 102 | print(opt) 103 | test_opt(opt) 104 | data = load_data(opt, cache=True, write_vocab=False) 105 | if FLAGS.debug: 106 | # print(data['']) 107 | print(data['E1'][0].shape) 108 | print(data['E1'][1].shape) 109 | print(data['E1'][2].shape) 110 | 111 | print(data['E1_mask'][0]) 112 | print(data['E1_seg'][0]) 113 | opt = model(data, opt, logfile=log, print_dim=True) 114 | -------------------------------------------------------------------------------- /src/experiments/tbert.py: -------------------------------------------------------------------------------- 1 | from src.loaders.load_data import load_data 2 | from src.models.base_model_bert import model,test_opt 3 | import argparse 4 | 5 | # run tbert with different learning rates on a certain dataset 6 | # example usage: python src/experiments/tbert.py -learning_rate 5e-05 -gpu 0 -topic_type ldamallet -topic word -dataset MSRP --debug 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.register("type", "bool", lambda v: v.lower() == "true") 10 | parser.add_argument('-dataset', action="store", dest="dataset", type=str, default='MSRP') 11 | parser.add_argument('-learning_rate', action="store", dest="learning_rate", type=str, default='3e-5') # learning_rates = [5e-5, 2e-5, 4e-5, 3e-5] 12 | parser.add_argument('-layers', 
action="store", dest="hidden_layers", type=str, default='0') 13 | parser.add_argument('-topic', action="store", dest="topic", type=str, default='word') 14 | parser.add_argument('-gpu', action="store", dest="gpu", type=int, default=-1) 15 | parser.add_argument("--speedup_new_layers",type="bool",nargs="?",const=True,default=False,help="Use 100 times higher learning rate for new layers.") 16 | parser.add_argument("--debug",type="bool",nargs="?",const=True,default=False,help="Try to use small number of examples for troubleshooting") 17 | parser.add_argument("--train_longer",type="bool",nargs="?",const=True,default=False,help="Train for 9 epochs") 18 | parser.add_argument("--early_stopping",type="bool",nargs="?",const=True,default=False) 19 | parser.add_argument("--unk_topic_zero",type="bool",nargs="?",const=True,default=False) 20 | parser.add_argument('-seed', action="store", dest="seed", type=str, default='fixed') 21 | parser.add_argument('-topic_type', action="store", dest="topic_type", type=str, default='ldamallet') 22 | 23 | FLAGS, unparsed = parser.parse_known_args() 24 | 25 | # sanity check command line arguments 26 | if len(unparsed)>0: 27 | parser.print_help() 28 | raise ValueError('Unidentified command line arguments passed: {}\n'.format(str(unparsed))) 29 | 30 | # setting model options based on flags 31 | 32 | dataset = FLAGS.dataset 33 | assert dataset in ['MSRP','Semeval_A','Semeval_B','Semeval_C','Quora'] 34 | 35 | hidden_layers = [int(h) for h in FLAGS.hidden_layers.split(',')] 36 | for h in hidden_layers: 37 | assert h in [0,1,2] 38 | 39 | topics = FLAGS.topic.split(',') 40 | for t in topics: 41 | assert t in ['word','doc'] 42 | 43 | priority = [] 44 | todo = [] 45 | last = [] 46 | 47 | stopping_criterion = None #'F1' 48 | patience = None 49 | batch_size = 32 # standard minibatch size 50 | if 'Semeval' in dataset: 51 | dataset, task = dataset.split('_') 52 | subsets = ['train_large', 'test2016', 'test2017'] 53 | if task in ['A']: 54 | batch_size = 16 # need smaller minibatch to fit on GPU due to long sentences 55 | num_topics = 70 56 | if FLAGS.topic_type=='gsdmm': 57 | alpha = 0.1 58 | else: 59 | alpha = 50 60 | elif task == 'B': 61 | num_topics = 80 62 | if FLAGS.topic_type=='gsdmm': 63 | alpha = 0.1 64 | else: 65 | alpha = 10 66 | elif task == 'C': 67 | batch_size = 16 # need smaller minibatch to fit on GPU due to long sentences 68 | num_topics = 70 69 | if FLAGS.topic_type=='gsdmm': 70 | alpha = 0.1 71 | else: 72 | alpha = 10 73 | else: 74 | task = 'B' 75 | if dataset== 'Quora': 76 | subsets = ['train', 'dev', 'test'] 77 | num_topics = 90 78 | if FLAGS.topic_type=='gsdmm': 79 | alpha = 0.1 80 | else: 81 | alpha = 1 82 | task = 'B' 83 | else: 84 | subsets = ['train', 'dev', 'test'] # MSRP 85 | num_topics = 80 86 | if FLAGS.topic_type=='gsdmm': 87 | alpha = 0.1 88 | else: 89 | alpha = 1 90 | task = 'B' 91 | 92 | if FLAGS.debug: 93 | max_m = 100 94 | else: 95 | max_m = None 96 | 97 | if FLAGS.train_longer: 98 | epochs = 9 99 | predict_every_epoch = True 100 | else: 101 | epochs = 3 102 | predict_every_epoch = False 103 | 104 | if FLAGS.early_stopping: 105 | patience = 2 106 | stopping_criterion = 'F1' 107 | 108 | try: 109 | seed = int(FLAGS.seed) 110 | except: 111 | seed = None 112 | 113 | if FLAGS.unk_topic_zero: 114 | unk_topic = 'zero' 115 | else: 116 | unk_topic = 'uniform' 117 | for topic_scope in topics: 118 | 119 | for hidden_layer in hidden_layers: 120 | 121 | opt = {'dataset': dataset, 'datapath': 'data/', 122 | 'model': 
'bert_simple_topic','bert_update':True,'bert_cased':False, 123 | 'tasks': [task], 124 | 'subsets': subsets,'seed':seed, 125 | 'minibatch_size': batch_size, 'L2': 0, 126 | 'max_m': max_m, 'load_ids': True, 127 | 'topic':topic_scope,'topic_update':False, 128 | 'num_topics':num_topics, 'topic_alpha':alpha, 129 | 'unk_topic': unk_topic, 'topic_type':FLAGS.topic_type, 130 | 'unk_sub': False, 'padding': False, 'simple_padding': True, 131 | 'learning_rate': float(FLAGS.learning_rate), 132 | 'num_epochs': epochs, 'hidden_layer':hidden_layer, 133 | 'sparse_labels': True, 'max_length': 'minimum', 134 | 'optimizer': 'Adam', 'dropout':0.1, 135 | 'gpu': FLAGS.gpu, 136 | 'speedup_new_layers':FLAGS.speedup_new_layers, 137 | 'predict_every_epoch': predict_every_epoch, 138 | 'stopping_criterion':stopping_criterion, 'patience':patience 139 | } 140 | todo.append(opt) 141 | 142 | tasks = todo 143 | 144 | if __name__ == '__main__': 145 | 146 | for i,opt in enumerate(tasks): 147 | print('Starting experiment {} of {}'.format(i+1,len(tasks))) 148 | l_rate = str(opt['learning_rate']).replace('-0','-') 149 | if FLAGS.speedup_new_layers: 150 | log = 'tbert_{}_seed_speedup_new_layers.json'.format(str(seed)) 151 | # elif FLAGS.freeze_thaw_tune: 152 | # log = 'tbert_{}_seed_freeze_thaw_tune.json'.format(str(seed)) 153 | elif FLAGS.train_longer: 154 | log = 'tbert_{}_seed_train_longer.json'.format(str(seed)) 155 | else: 156 | log = 'tbert_{}_seed.json'.format(str(seed)) 157 | if FLAGS.early_stopping: 158 | log = log.replace('.json','_early_stopping.json') 159 | 160 | print(log) 161 | print(opt) 162 | test_opt(opt) 163 | data = load_data(opt, cache=True, write_vocab=False) 164 | if FLAGS.debug: 165 | # print(data['']) 166 | print(data['E1'][0].shape) 167 | print(data['E1'][1].shape) 168 | print(data['E1'][2].shape) 169 | 170 | print(data['E1_mask'][0]) 171 | print(data['E1_seg'][0]) 172 | opt = model(data, opt, logfile=log, print_dim=True) 173 | -------------------------------------------------------------------------------- /src/loaders/MSRP/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/loaders/MSRP/__init__.py -------------------------------------------------------------------------------- /src/loaders/MSRP/build.py: -------------------------------------------------------------------------------- 1 | import src.loaders.build_data as build_data 2 | import os 3 | from src.loaders.Semeval.helper import Loader 4 | import csv 5 | import random 6 | 7 | files = ['m_train','m_dev','m_test'] 8 | 9 | 10 | def reformat_split(outpath, dtype, inpath): 11 | print('reformatting:' + inpath) 12 | # Quality #1 ID #2 ID #1 String #2 String 13 | with open(os.path.join(outpath, dtype+'_B.txt'), 'w', encoding='utf-8') as output_file: 14 | with open(inpath, newline='', encoding='utf-8') as f: 15 | csv_reader = csv.reader(f, delimiter='\t',quoting=csv.QUOTE_NONE,quotechar='|') 16 | next(csv_reader) # skip header 17 | for line in csv_reader: 18 | #print(line) 19 | labelB, id1, id2, doc1, doc2 = line 20 | output_file.write(id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 21 | 22 | 23 | def build(opt): 24 | dpath = os.path.join(opt['datapath'], opt['dataset']) 25 | embpath = os.path.join(opt['datapath'], 'embeddings') 26 | logpath = os.path.join(opt['datapath'], 'logs') 27 | modelpath = os.path.join(opt['datapath'], 'models') 28 | version = None 29 | 30 | if not build_data.built(dpath, 
version_string=version): 31 | print('[building data: ' + dpath + ']') 32 | if build_data.built(dpath): 33 | # An older version exists, so remove these outdated files. 34 | build_data.remove_dir(dpath) 35 | build_data.make_dir(dpath) 36 | build_data.make_dir(embpath) 37 | build_data.make_dir(logpath) 38 | build_data.make_dir(modelpath) 39 | 40 | # Download the data. 41 | # Use data from Kaggle 42 | # raise NotImplemented('Quora data not implemented yet') 43 | 44 | # '/Users/nicole/code/CQA/data/Quora/quora_duplicate_questions.tsv' 45 | # fnames = ['quora_duplicate_questions.tsv'] 46 | 47 | # urls = ['https://zhiguowang.github.io' + fnames[0]] 48 | 49 | dpext = os.path.join(dpath, 'MSRParaphraseCorpus') 50 | # build_data.make_dir(dpext) 51 | 52 | # for fname, url in zip(fnames,urls): 53 | # build_data.download(url, dpext, fname) 54 | # build_data.untar(dpext, fname) # should be able to handle zip 55 | 56 | reformat_split(dpath, files[0], os.path.join(dpext, 'msr-para-train.tsv')) 57 | reformat_split(dpath, files[1], os.path.join(dpext, 'msr-para-val.tsv')) 58 | reformat_split(dpath, files[2], os.path.join(dpext, 'msr-para-test.tsv')) 59 | 60 | # reformat(dpath, files[1], os.path.join(dpext, 'test.csv')) 61 | 62 | # Mark the data as built. 63 | build_data.mark_done(dpath, version_string=version) 64 | -------------------------------------------------------------------------------- /src/loaders/PAWS/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/loaders/PAWS/__init__.py -------------------------------------------------------------------------------- /src/loaders/PAWS/build.py: -------------------------------------------------------------------------------- 1 | import src.loaders.build_data as build_data 2 | import os 3 | from src.loaders.Semeval.helper import Loader 4 | import csv 5 | import random 6 | 7 | files = ['p_train','p_test'] 8 | 9 | 10 | def reformat_original(outpath, inpath): 11 | random.seed(1) 12 | print('reformatting:' + inpath) 13 | # randomly split in train and test set 14 | with open(os.path.join(outpath, 'quora_train_B.txt'), 'w', encoding='utf-8') as output_file: 15 | with open(inpath, newline='', encoding='utf-8') as f: 16 | csv_reader = csv.reader(f, delimiter='\t') 17 | next(csv_reader) # skip header 18 | for pair_id, id1, id2, doc1, doc2, labelB in csv_reader: 19 | # print(pair_id + '\t' + id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 20 | # break 21 | output_file.write(id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 22 | # # randomly split in train and test set 23 | # with open(os.path.join(outpath, 'train' + '_B.txt'), 'w', encoding='utf-8') as training_file: 24 | # with open(os.path.join(outpath, 'test' + '_B.txt'), 'w', encoding='utf-8') as testing_file: 25 | # with open(inpath, newline='', encoding='utf-8') as f: 26 | # data = f.read().split('\n') 27 | # data = data[1:] # skip header 28 | # random.shuffle(data) # shuffle 29 | # # csv_reader = csv.reader(data, delimiter='\t') 30 | # train_examples = round(len(data) * 0.99) 31 | # i = 0 32 | # for line in data: 33 | # try: 34 | # pair_id, id1, id2, doc1, doc2, labelB = line.split('\t') 35 | # # print(pair_id + '\t' + id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 36 | # # break 37 | # if i < train_examples: 38 | # training_file.write(id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 39 | # else: 
40 | # testing_file.write(id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 41 | # except ValueError: 42 | # print(line) 43 | # i+=1 44 | 45 | def reformat_split(outpath, dtype, inpath): 46 | print('reformatting:' + inpath) 47 | # reformat Wang's split 48 | with open(os.path.join(outpath, dtype+'_B.txt'), 'w', encoding='utf-8') as output_file: 49 | with open(inpath, newline='', encoding='utf-8') as f: 50 | csv_reader = csv.reader(f, delimiter='\t') 51 | next(csv_reader) # skip header 52 | for pair_id, doc1, doc2, labelB in csv_reader: 53 | id1 = pair_id+'-1' 54 | id2 = pair_id+'-2' 55 | output_file.write(id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 56 | 57 | 58 | def build(opt): 59 | dpath = os.path.join(opt['datapath'], opt['dataset']) 60 | embpath = os.path.join(opt['datapath'], 'embeddings') 61 | logpath = os.path.join(opt['datapath'], 'logs') 62 | modelpath = os.path.join(opt['datapath'], 'models') 63 | version = None 64 | 65 | if not build_data.built(dpath, version_string=version): 66 | print('[building data: ' + dpath + ']') 67 | if build_data.built(dpath): 68 | # An older version exists, so remove these outdated files. 69 | build_data.remove_dir(dpath) 70 | build_data.make_dir(dpath) 71 | build_data.make_dir(embpath) 72 | build_data.make_dir(logpath) 73 | build_data.make_dir(modelpath) 74 | 75 | # Download the data. 76 | # Use data from Kaggle 77 | # raise NotImplemented('Quora data not implemented yet') 78 | 79 | # '/Users/nicole/code/CQA/data/Quora/quora_duplicate_questions.tsv' 80 | # fnames = ['quora_duplicate_questions.tsv'] 81 | 82 | # urls = ['https://zhiguowang.github.io' + fnames[0]] 83 | 84 | dpext = os.path.join(dpath, 'data') 85 | # build_data.make_dir(dpext) 86 | 87 | # for fname, url in zip(fnames,urls): 88 | # build_data.download(url, dpext, fname) 89 | # build_data.untar(dpext, fname) # should be able to handle zip 90 | 91 | reformat_split(dpath, files[0], os.path.join(dpext, 'train.tsv')) 92 | reformat_split(dpath, files[1], os.path.join(dpext, 'dev_and_test.tsv')) 93 | # reformat_split(dpath, files[2], os.path.join(dpext, 'test.tsv')) 94 | 95 | # reformat(dpath, files[1], os.path.join(dpext, 'test.csv')) 96 | 97 | # Mark the data as built. 
98 | build_data.mark_done(dpath, version_string=version) -------------------------------------------------------------------------------- /src/loaders/Quora/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/loaders/Quora/__init__.py -------------------------------------------------------------------------------- /src/loaders/Quora/build.py: -------------------------------------------------------------------------------- 1 | import src.loaders.build_data as build_data 2 | import os 3 | from src.loaders.Semeval.helper import Loader 4 | import csv 5 | import random 6 | 7 | files = ['q_train','q_dev','q_test'] 8 | 9 | 10 | def reformat_original(outpath, inpath): 11 | random.seed(1) 12 | print('reformatting:' + inpath) 13 | # randomly split in train and test set 14 | with open(os.path.join(outpath, 'quora_train_B.txt'), 'w', encoding='utf-8') as output_file: 15 | with open(inpath, newline='', encoding='utf-8') as f: 16 | csv_reader = csv.reader(f, delimiter='\t') 17 | next(csv_reader) # skip header 18 | for pair_id, id1, id2, doc1, doc2, labelB in csv_reader: 19 | # print(pair_id + '\t' + id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 20 | # break 21 | output_file.write(id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 22 | # # randomly split in train and test set 23 | # with open(os.path.join(outpath, 'train' + '_B.txt'), 'w', encoding='utf-8') as training_file: 24 | # with open(os.path.join(outpath, 'test' + '_B.txt'), 'w', encoding='utf-8') as testing_file: 25 | # with open(inpath, newline='', encoding='utf-8') as f: 26 | # data = f.read().split('\n') 27 | # data = data[1:] # skip header 28 | # random.shuffle(data) # shuffle 29 | # # csv_reader = csv.reader(data, delimiter='\t') 30 | # train_examples = round(len(data) * 0.99) 31 | # i = 0 32 | # for line in data: 33 | # try: 34 | # pair_id, id1, id2, doc1, doc2, labelB = line.split('\t') 35 | # # print(pair_id + '\t' + id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 36 | # # break 37 | # if i < train_examples: 38 | # training_file.write(id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 39 | # else: 40 | # testing_file.write(id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 41 | # except ValueError: 42 | # print(line) 43 | # i+=1 44 | 45 | def reformat_split(outpath, dtype, inpath): 46 | print('reformatting:' + inpath) 47 | # reformat Wang's split 48 | with open(os.path.join(outpath, dtype+'_B.txt'), 'w', encoding='utf-8') as output_file: 49 | with open(inpath, newline='', encoding='utf-8') as f: 50 | csv_reader = csv.reader(f, delimiter='\t') 51 | next(csv_reader) # skip header 52 | for labelB, doc1, doc2, pair_id in csv_reader: 53 | id1 = pair_id+'-1' 54 | id2 = pair_id+'-2' 55 | output_file.write(id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 56 | 57 | 58 | def build(opt): 59 | dpath = os.path.join(opt['datapath'], opt['dataset']) 60 | embpath = os.path.join(opt['datapath'], 'embeddings') 61 | logpath = os.path.join(opt['datapath'], 'logs') 62 | modelpath = os.path.join(opt['datapath'], 'models') 63 | version = None 64 | 65 | if not build_data.built(dpath, version_string=version): 66 | print('[building data: ' + dpath + ']') 67 | if build_data.built(dpath): 68 | # An older version exists, so remove these outdated files. 
69 | build_data.remove_dir(dpath) 70 | build_data.make_dir(dpath) 71 | build_data.make_dir(embpath) 72 | build_data.make_dir(logpath) 73 | build_data.make_dir(modelpath) 74 | 75 | # Download the data. 76 | # Use data from Kaggle 77 | # raise NotImplemented('Quora data not implemented yet') 78 | 79 | # '/Users/nicole/code/CQA/data/Quora/quora_duplicate_questions.tsv' 80 | # fnames = ['quora_duplicate_questions.tsv'] 81 | 82 | # urls = ['https://zhiguowang.github.io' + fnames[0]] 83 | 84 | dpext = os.path.join(dpath, 'Quora_question_pair_partition') 85 | # build_data.make_dir(dpext) 86 | 87 | # for fname, url in zip(fnames,urls): 88 | # build_data.download(url, dpext, fname) 89 | # build_data.untar(dpext, fname) # should be able to handle zip 90 | 91 | reformat_split(dpath, files[0], os.path.join(dpext, 'train.tsv')) 92 | reformat_split(dpath, files[1], os.path.join(dpext, 'dev.tsv')) 93 | reformat_split(dpath, files[2], os.path.join(dpext, 'test.tsv')) 94 | 95 | # reformat(dpath, files[1], os.path.join(dpext, 'test.csv')) 96 | 97 | # Mark the data as built. 98 | build_data.mark_done(dpath, version_string=version) -------------------------------------------------------------------------------- /src/loaders/Semeval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. An additional grant 5 | # of patent rights can be found in the PATENTS file in the same directory. -------------------------------------------------------------------------------- /src/loaders/Semeval/agents.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. An additional grant 5 | # of patent rights can be found in the PATENTS file in the same directory. 6 | 7 | from parlai.core.fbdialog_teacher import FbDialogTeacher 8 | from .build import build 9 | 10 | import copy 11 | import os 12 | 13 | 14 | def _path(opt, filtered): 15 | # Build the data if it doesn't exist. 16 | build(opt) 17 | dt = opt['datatype'].split(':')[0] 18 | return os.path.join(opt['datapath'], 'WikiQA', dt + filtered + '.txt') 19 | 20 | 21 | class FilteredTeacher(FbDialogTeacher): 22 | def __init__(self, opt, shared=None): 23 | opt = copy.deepcopy(opt) 24 | opt['datafile'] = _path(opt, '-filtered') 25 | super().__init__(opt, shared) 26 | 27 | 28 | class UnfilteredTeacher(FbDialogTeacher): 29 | def __init__(self, opt, shared=None): 30 | opt = copy.deepcopy(opt) 31 | opt['datafile'] = _path(opt, '') 32 | super().__init__(opt, shared) 33 | 34 | 35 | class DefaultTeacher(FilteredTeacher): 36 | pass 37 | -------------------------------------------------------------------------------- /src/loaders/Semeval/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. An additional grant 5 | # of patent rights can be found in the PATENTS file in the same directory. 6 | # Download and build the data if it does not exist. 
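# Output format note: reformat() below writes each pair as one tab-separated line
# of the form id1 <TAB> id2 <TAB> doc1 <TAB> doc2 <TAB> binary label, e.g.
# (ids and text purely illustrative): Q268 <TAB> Q268_R5 <TAB> original question ... <TAB> related question ... <TAB> 1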
7 | 8 | import src.loaders.build_data as build_data 9 | import os 10 | from src.loaders.Semeval.helper import Loader 11 | 12 | 13 | files = ['train','train2','dev','test2016','test2017'] 14 | 15 | def reformat(outpath, dtype, inpath,concat=True): 16 | print('reformatting:' + dtype) 17 | 18 | with open(os.path.join(outpath, dtype + '_A.txt'), 'w',encoding='utf-8') as Aout: 19 | with open(os.path.join(outpath, dtype + '_B.txt'), 'w',encoding='utf-8') as Bout: 20 | with open(os.path.join(outpath, dtype + '_C.txt'), 'w',encoding='utf-8') as Cout: 21 | questions,taska_exclude = Loader.loadXMLQuestions([inpath]) 22 | # questions = OrderedDict(sorted(questions.items())) 23 | 24 | for k in sorted(questions.keys()): 25 | # print(k) 26 | # original question id, subject and question 27 | id1 = questions[k]['id'] 28 | if concat: 29 | doc1 = questions[k]['subject'] + ' ' + questions[k]['question'] 30 | else: 31 | doc1 = questions[k]['question'] 32 | 33 | for r in questions[k]['related'].keys(): 34 | # print(r) 35 | # related question id, subject and question 36 | id2 = questions[k]['related'][r]['id'] 37 | if concat: 38 | doc2 = questions[k]['related'][r]['subject'] + ' ' + questions[k]['related'][r]['question'] 39 | else: 40 | doc2 = questions[k]['related'][r]['question'] 41 | labelB = questions[k]['related'][r]['B-label'] 42 | # encode labels 43 | if labelB in ['Relevant','PerfectMatch']: 44 | labelB = '1' 45 | elif labelB in ['Irrelevant']: 46 | labelB = '0' 47 | else: 48 | raise ValueError('Annotation {} for example {} not defined!'.format((labelB,id2))) 49 | Bout.write(id1 + '\t' + id2 + '\t' + doc1 + '\t' + doc2 + '\t' + labelB + '\n') 50 | 51 | for c in questions[k]['related'][r]['comments'].keys(): 52 | # print(c) 53 | # comment id and comment 54 | id3 = questions[k]['related'][r]['comments'][c]['id'] 55 | doc3 = questions[k]['related'][r]['comments'][c]['comment'] 56 | labelA = questions[k]['related'][r]['comments'][c]['A-label'] 57 | labelC = questions[k]['related'][r]['comments'][c]['C-label'] 58 | # encode labels 59 | if labelA in ['Good']: 60 | labelA = '1' 61 | elif labelA in ['Bad','PotentiallyUseful']: 62 | labelA = '0' 63 | else: 64 | raise ValueError('Annotation {} for example {} not defined!'.format((labelA,id3))) 65 | if labelC in ['Good']: 66 | labelC = '1' 67 | elif labelC in ['Bad','PotentiallyUseful']: 68 | labelC = '0' 69 | else: 70 | raise ValueError('Annotation {} for example {} not defined!'.format((labelC,id3))) 71 | if r not in taska_exclude: 72 | Aout.write(id2 + '\t' + id3 + '\t' + doc2 + '\t' + doc3 + '\t' + labelA + '\n') 73 | Cout.write(id1 + '\t' + id3 + '\t' + doc1 + '\t' + doc3 + '\t' + labelC + '\n') 74 | 75 | # output format: 76 | # 77 | # dict(question_id => dict( 78 | # question 79 | # id 80 | # subject 81 | # comments = {} 82 | # related = dict(related_id => dict( 83 | # question 84 | # id 85 | # subject 86 | # relevance 87 | # comments = dict(comment_id => dict( 88 | # comment 89 | # date 90 | # id 91 | # username 92 | # ) 93 | # ) 94 | # ) 95 | 96 | 97 | def build(opt): 98 | dpath = os.path.join(opt['datapath'], opt['dataset']) 99 | embpath = os.path.join(opt['datapath'], 'embeddings') 100 | logpath = os.path.join(opt['datapath'], 'logs') 101 | modelpath = os.path.join(opt['datapath'], 'models') 102 | version = None 103 | 104 | if not build_data.built(dpath, version_string=version): 105 | print('[building data: ' + dpath + ']') 106 | if build_data.built(dpath): 107 | # An older version exists, so remove these outdated files. 
108 | build_data.remove_dir(dpath) 109 | build_data.make_dir(dpath) 110 | build_data.make_dir(embpath) 111 | build_data.make_dir(logpath) 112 | build_data.make_dir(modelpath) 113 | 114 | # Download the data. 115 | fnames = ['semeval2016-task3-cqa-ql-traindev-v3.2.zip','semeval2016_task3_test.zip','semeval2017_task3_test.zip'] 116 | urls = ['http://alt.qcri.org/semeval2016/task3/data/uploads/' + fnames[0], 117 | 'http://alt.qcri.org/semeval2016/task3/data/uploads/' + fnames[1], 118 | 'http://alt.qcri.org/semeval2017/task3/data/uploads/' + fnames[2]] 119 | 120 | dpext = os.path.join(dpath, 'Semeval2017') 121 | build_data.make_dir(dpext) 122 | 123 | for fname, url in zip(fnames,urls): 124 | build_data.download(url, dpext, fname) 125 | build_data.untar(dpext, fname) # should be able to handle zip 126 | 127 | reformat(dpath, files[0], 128 | os.path.join(dpext, 'v3.2/train/SemEval2016-Task3-CQA-QL-train-part1.xml')) 129 | reformat(dpath, files[1], 130 | os.path.join(dpext, 'v3.2/train/SemEval2016-Task3-CQA-QL-train-part2.xml')) 131 | reformat(dpath, files[2], 132 | os.path.join(dpext, 'v3.2/dev/SemEval2016-Task3-CQA-QL-dev.xml')) 133 | reformat(dpath, files[3], 134 | os.path.join(dpext, 'SemEval2016_task3_test/English/SemEval2016-Task3-CQA-QL-test.xml')) 135 | reformat(dpath, files[4], 136 | os.path.join(dpext, 'SemEval2017_task3_test/English/SemEval2017-task3-English-test.xml')) 137 | 138 | # Mark the data as built. 139 | build_data.mark_done(dpath, version_string=version) 140 | -------------------------------------------------------------------------------- /src/loaders/Semeval/helper.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/whiskeyromeo/CommunityQuestionAnswering 2 | 3 | import xml.etree.ElementTree as ElementTree 4 | import sys 5 | from collections import OrderedDict 6 | 7 | 8 | # preexisting 9 | def getargvalue(name, required): 10 | output = False 11 | for arg in sys.argv: 12 | if arg[2:len(name)+2] == name: 13 | output = arg[3+len(name):] 14 | if required and not output: 15 | raise Exception("Required argument " + name + " not found in sys.argv") 16 | return output 17 | 18 | def argvalueexists(name): 19 | output = False 20 | for arg in sys.argv: 21 | if arg[2:len(name)+2] == name: 22 | output = True 23 | return output 24 | 25 | 26 | class Loader: 27 | 28 | @staticmethod 29 | def getfilenames(): 30 | if not argvalueexists("questionfiles"): 31 | output = Loader.defaultfilenames() 32 | else: 33 | output = getargvalue("questionfiles", True).split(",") 34 | return output 35 | 36 | 37 | @staticmethod 38 | def defaultfilenames(): 39 | filePaths = [ 40 | # train 2016 41 | 'data/semeval2017/v3.2/train/SemEval2016-Task3-CQA-QL-train-part1.xml', 42 | 'data/semeval2017/v3.2/train/SemEval2016-Task3-CQA-QL-train-part2.xml', 43 | # train 2015 44 | # 'data/semeval2017/v3.2/train-more-for-subtaskA-from-2015/SemEval2015-Task3-CQA-QL-train-reformatted-excluding-2016-questions-cleansed.xml', 45 | # dev 2015 46 | # 'data/semeval2017/v3.2/train-more-for-subtaskA-from-2015/SemEval2015-Task3-CQA-QL-dev-reformatted-excluding-2016-questions-cleansed.xml', 47 | # dev 2016 48 | 'data/semeval2017/v3.2/dev/SemEval2016-Task3-CQA-QL-dev.xml', 49 | # test 2015 50 | # 'data/semeval2017/v3.2/train-more-for-subtaskA-from-2015/SemEval2015-Task3-CQA-QL-test-reformatted-excluding-2016-questions-cleansed.xml', 51 | # test 2016 52 | 'data/semeval2017/v3.2/test/SemEval2016-Task3-CQA-QL-test.xml', 53 | # test 2017 54 | 
'data/semeval2017/v3.2/test/SemEval2017-task3-English-test.xml' 55 | ] 56 | return filePaths 57 | 58 | 59 | @staticmethod 60 | def loadXMLQuestions(filenames): 61 | output = {} 62 | for filePath in filenames: 63 | print("\nParsing %s" % filePath) 64 | fileoutput,excludelist = Loader.parseTask3TrainingData(filePath) 65 | print(" Got %s primary questions" % len(fileoutput)) 66 | if not len(fileoutput): 67 | raise Exception("Failed to load any entries from " + filePath) 68 | isTraining = "train" in filePath 69 | for q in fileoutput: 70 | fileoutput[q]['isTraining'] = isTraining 71 | for r in fileoutput[q]['related']: 72 | fileoutput[q]['related'][r]['isTraining'] = isTraining 73 | output.update(fileoutput) 74 | print("\nTotal of %s entries" % len(output)) 75 | return output,excludelist 76 | 77 | @staticmethod 78 | def parseTask3TrainingData(filepath): 79 | tree = ElementTree.parse(filepath) 80 | root = tree.getroot() 81 | OrgQuestions = OrderedDict() 82 | exclude_task_a = [] 83 | for OrgQuestion in root.iter('OrgQuestion'): 84 | OrgQuestionOutput = {} 85 | OrgQuestionOutput['id'] = OrgQuestion.attrib['ORGQ_ID'] 86 | OrgQuestionOutput['subject'] = OrgQuestion.find('OrgQSubject').text 87 | OrgQuestionOutput['question'] = OrgQuestion.find('OrgQBody').text 88 | OrgQuestionOutput['comments'] = {} 89 | OrgQuestionOutput['related'] = OrderedDict() 90 | OrgQuestionOutput['featureVector'] = [] 91 | hasDuplicate = OrgQuestion.find('Thread').get('SubtaskA_Skip_Because_Same_As_RelQuestion_ID',None) 92 | if hasDuplicate is not None: 93 | exclude_task_a.append(OrgQuestion.find('Thread').attrib['THREAD_SEQUENCE']) 94 | # SubtaskA_Skip_Because_Same_As_RelQuestion_ID 95 | if OrgQuestionOutput['id'] not in OrgQuestions: 96 | OrgQuestions[OrgQuestionOutput['id']] = OrgQuestionOutput 97 | for RelQuestion in OrgQuestion.iter('RelQuestion'): 98 | RelQuestionOutput = {} 99 | RelQuestionOutput['id'] = RelQuestion.attrib['RELQ_ID'] 100 | RelQuestionOutput['subject'] = RelQuestion.find('RelQSubject').text 101 | RelQuestionOutput['question'] = RelQuestion.find('RelQBody').text 102 | RelQuestionOutput['B-label'] = RelQuestion.attrib['RELQ_RELEVANCE2ORGQ'] 103 | RelQuestionOutput['givenRank'] = RelQuestion.attrib['RELQ_RANKING_ORDER'] 104 | RelQuestionOutput['comments'] = OrderedDict() 105 | RelQuestionOutput['featureVector'] = [] 106 | for RelComment in OrgQuestion.iter('RelComment'): 107 | RelCommentOutput = {} 108 | RelCommentOutput['id'] = RelComment.attrib['RELC_ID'] 109 | RelCommentOutput['date'] = RelComment.attrib['RELC_DATE'] 110 | RelCommentOutput['username'] = RelComment.attrib['RELC_USERNAME'] 111 | RelCommentOutput['comment'] = RelComment.find('RelCText').text 112 | RelCommentOutput['C-label'] = RelComment.attrib['RELC_RELEVANCE2ORGQ'] 113 | RelCommentOutput['A-label'] = RelComment.attrib['RELC_RELEVANCE2RELQ'] 114 | RelQuestionOutput['comments'][RelCommentOutput['id']] = RelCommentOutput 115 | # if RelQuestionOutput['question'] != None: 116 | if RelQuestionOutput['question'] == None: 117 | RelQuestionOutput['question'] = "" 118 | OrgQuestions[OrgQuestionOutput['id']]['related'][RelQuestionOutput['id']] = RelQuestionOutput 119 | # else: 120 | # print("Warning: skipping empty question " + RelQuestionOutput['id']) 121 | return OrgQuestions, exclude_task_a 122 | 123 | # output format: 124 | # 125 | # dict(question_id => dict( 126 | # question 127 | # id 128 | # subject 129 | # comments = {} 130 | # related = dict(related_id => dict( 131 | # question 132 | # id 133 | # subject 134 | # relevance 135 | # comments 
= dict(comment_id => dict( 136 | # comment 137 | # date 138 | # id 139 | # username 140 | # ) 141 | # ) 142 | # ) 143 | 144 | 145 | -------------------------------------------------------------------------------- /src/loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/loaders/__init__.py -------------------------------------------------------------------------------- /src/loaders/augment_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def create_large_train(): 4 | ''' 5 | Create large training set for each task based on Deriu 2017 6 | ''' 7 | from src.loaders.load_data import get_filepath, load_file 8 | print('Creating large training set...') 9 | for t in ['A','B','C']: 10 | opt = {'dataset': 'Semeval', 'datapath': 'data/', 11 | 'tasks': [t], 12 | 'subsets': ['train','train2','dev']} 13 | files = get_filepath(opt) 14 | outfile = os.path.join(opt['datapath'],opt['dataset'],'train_large_'+t+'.txt') 15 | large_train = [] 16 | for f in files: 17 | with open(f,encoding='utf-8') as infile: 18 | for l in infile: 19 | large_train.append(l) 20 | with open(outfile,'w',encoding='utf-8') as out: 21 | for l in large_train: 22 | out.writelines(l) 23 | print('Done.') 24 | 25 | def double_task_training_data(): 26 | ''' 27 | Double existing data by switching side of questions to mitigate data scarcity for task B 28 | ''' 29 | from src.loaders.load_data import get_filepath, load_file 30 | print('Creating augmented training files for tasks') 31 | subsets = ['train', 'train_large'] 32 | for t in ['A','B','C']: 33 | for s in subsets: 34 | opt = {'dataset': 'Semeval', 'datapath': 'data/', 35 | 'tasks': [t], 36 | 'subsets': [s]} 37 | f = get_filepath(opt)[0] 38 | id1,id2,s1,s2,l = load_file(f,True) 39 | id1_double = id1 + id2 40 | id2_double = id2 + id1 41 | s1_double = s1 + s2 42 | s2_double = s2 + s1 43 | l_double = list(l) + list(l) 44 | assert len(id1_double)==len(id2_double)==len(s1_double)==len(s2_double)==len(l_double) 45 | outfile = os.path.join(os.path.join(opt['datapath'],opt['dataset'])+'/'+s + '_double' + '_'+t+'.txt') 46 | print(outfile) 47 | with open(outfile,'w',encoding='utf-8') as out: 48 | for i in range(len(id1_double)): 49 | out.writelines(id1_double[i]+'\t'+id2_double[i]+'\t'+s1_double[i]+'\t'+s2_double[i]+'\t'+str(l_double[i])+'\n') 50 | 51 | 52 | def augment_task_b_with_(): 53 | ''' 54 | Double existing data by switching side of questions to mitigate data scarcity for task B 55 | ''' 56 | from src.loaders.load_data import get_filepath, load_file 57 | print('Creating augmented training files for Task B') 58 | subsets = ['train', 'train_large'] 59 | for s in subsets: 60 | opt = {'dataset': 'Semeval', 'datapath': 'data/', 61 | 'tasks': ['B'], 62 | 'subsets': [s]} 63 | f = get_filepath(opt)[0] 64 | id1,id2,s1,s2,l = load_file(f,True) 65 | id1_double = id1 + id2 66 | id2_double = id2 + id1 67 | s1_double = s1 + s2 68 | s2_double = s2 + s1 69 | l_double = list(l) + list(l) 70 | assert len(id1_double)==len(id2_double)==len(s1_double)==len(s2_double)==len(l_double) 71 | outfile = os.path.join(f.split('_B.txt')[0] + '_double' + '_B.txt') 72 | with open(outfile,'w',encoding='utf-8') as out: 73 | for i in range(len(id1_double)): 74 | out.writelines(id1_double[i]+'\t'+id2_double[i]+'\t'+s1_double[i]+'\t'+s2_double[i]+'\t'+str(l_double[i])+'\n') 75 | 76 | 
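# Usage sketch (illustrative only; assumes the Semeval pair files have already been
# built under data/Semeval/ by src/loaders/Semeval/build.py):
#
#   from src.loaders.augment_data import create_large_train, double_task_training_data
#   create_large_train()              # writes train_large_{A,B,C}.txt
#   double_task_training_data()       # writes {train,train_large}_double_{A,B,C}.txt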
-------------------------------------------------------------------------------- /src/loaders/build_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. An additional grant 5 | # of patent rights can be found in the PATENTS file in the same directory. 6 | """ 7 | Utilities for downloading and building data. 8 | These can be replaced if your particular file system does not support them. 9 | """ 10 | 11 | import time 12 | import datetime 13 | import os 14 | import requests 15 | import shutil 16 | 17 | 18 | def built(path, version_string=None): 19 | """Checks if '.built' flag has been set for that task. 20 | If a version_string is provided, this has to match, or the version 21 | is regarded as not built. 22 | """ 23 | if version_string: 24 | fname = os.path.join(path, '.built') 25 | if not os.path.isfile(fname): 26 | return False 27 | else: 28 | with open(fname, 'r') as read: 29 | text = read.read().split('\n') 30 | return (len(text) > 1 and text[1] == version_string) 31 | else: 32 | return os.path.isfile(os.path.join(path, '.built')) 33 | 34 | def mark_done(path, version_string=None): 35 | """Marks the path as done by adding a '.built' file with the current 36 | timestamp plus a version description string if specified. 37 | """ 38 | with open(os.path.join(path, '.built'), 'w') as write: 39 | write.write(str(datetime.datetime.today())) 40 | if version_string: 41 | write.write('\n' + version_string) 42 | 43 | 44 | def log_progress(curr, total, width=40): 45 | """Displays a bar showing the current progress.""" 46 | done = min(curr * width // total, width) 47 | remain = width - done 48 | progress = '[{}{}] {} / {}'.format( 49 | ''.join(['|'] * done), 50 | ''.join(['.'] * remain), 51 | curr, 52 | total 53 | ) 54 | print(progress, end='\r') 55 | 56 | 57 | def download(url, path, fname, redownload=False): 58 | """Downloads file using `requests`. 
If ``redownload`` is set to false, then 59 | will not download tar file again if it is present (default ``True``).""" 60 | outfile = os.path.join(path, fname) 61 | download = not os.path.isfile(outfile) or redownload 62 | 63 | retry = 5 64 | exp_backoff = [2 ** r for r in reversed(range(retry))] 65 | while download and retry >= 0: 66 | resume_file = outfile + '.part' 67 | resume = os.path.isfile(resume_file) 68 | if resume: 69 | resume_pos = os.path.getsize(resume_file) 70 | mode = 'ab' 71 | else: 72 | resume_pos = 0 73 | mode = 'wb' 74 | response = None 75 | 76 | with requests.Session() as session: 77 | try: 78 | header = {'Range': 'bytes=%d-' % resume_pos, 79 | 'Accept-Encoding': 'identity'} if resume else {} 80 | response = session.get(url, stream=True, timeout=5, headers=header) 81 | 82 | # negative reply could be 'none' or just missing 83 | if resume and response.headers.get('Accept-Ranges', 'none') == 'none': 84 | resume_pos = 0 85 | mode = 'wb' 86 | 87 | CHUNK_SIZE = 32768 88 | total_size = int(response.headers.get('Content-Length', -1)) 89 | # server returns remaining size if resuming, so adjust total 90 | total_size += resume_pos 91 | done = resume_pos 92 | 93 | with open(resume_file, mode) as f: 94 | for chunk in response.iter_content(CHUNK_SIZE): 95 | if chunk: # filter out keep-alive new chunks 96 | f.write(chunk) 97 | if total_size > 0: 98 | done += len(chunk) 99 | if total_size < done: 100 | # don't freak out if content-length was too small 101 | total_size = done 102 | log_progress(done, total_size) 103 | break 104 | except requests.exceptions.ConnectionError: 105 | retry -= 1 106 | print(''.join([' '] * 60), end='\r') 107 | if retry >= 0: 108 | print('Connection error, retrying. (%d retries left)' % retry) 109 | time.sleep(exp_backoff[retry]) 110 | else: 111 | print('Retried too many times, stopped retrying.') 112 | finally: 113 | if response: 114 | response.close() 115 | if retry < 0: 116 | raise RuntimeWarning('Connection broken too many times. Stopped retrying.') 117 | 118 | if download and retry > 0: 119 | print() 120 | if done < total_size: 121 | raise RuntimeWarning('Received less data than specified in ' + 122 | 'Content-Length header for ' + url + '.' + 123 | ' There may be a download problem.') 124 | move(resume_file, outfile) 125 | 126 | 127 | def make_dir(path): 128 | """Makes the directory and any nonexistent parent directories.""" 129 | os.makedirs(path, exist_ok=True) 130 | 131 | 132 | def move(path1, path2): 133 | """Renames the given file.""" 134 | shutil.move(path1, path2) 135 | 136 | 137 | def remove_dir(path): 138 | """Removes the given directory, if it exists.""" 139 | shutil.rmtree(path, ignore_errors=True) 140 | 141 | 142 | def untar(path, fname, deleteTar=True): 143 | """Unpacks the given archive file to the same directory, then (by default) 144 | deletes the archive file. 
145 | """ 146 | print('unpacking ' + fname) 147 | fullpath = os.path.join(path, fname) 148 | shutil.unpack_archive(fullpath, path) 149 | if deleteTar: 150 | os.remove(fullpath) 151 | 152 | 153 | def _get_confirm_token(response): 154 | for key, value in response.cookies.items(): 155 | if key.startswith('download_warning'): 156 | return value 157 | return None 158 | 159 | 160 | def download_from_google_drive(gd_id, destination): 161 | """Uses the requests package to download a file from Google Drive.""" 162 | URL = 'https://docs.google.com/uc?export=download' 163 | 164 | with requests.Session() as session: 165 | response = session.get(URL, params={'id': gd_id}, stream=True) 166 | token = _get_confirm_token(response) 167 | 168 | if token: 169 | response.close() 170 | params = {'id': gd_id, 'confirm': token} 171 | response = session.get(URL, params=params, stream=True) 172 | 173 | CHUNK_SIZE = 32768 174 | with open(destination, 'wb') as f: 175 | for chunk in response.iter_content(CHUNK_SIZE): 176 | if chunk: # filter out keep-alive new chunks 177 | f.write(chunk) 178 | response.close() 179 | -------------------------------------------------------------------------------- /src/logs/tf_event_logs.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | 4 | def read_previous_performance_from_tfevent(model_id,epoch,split='dev',metric='Accuracy'): 5 | # get name of file in TF log dir 6 | tf_log_dir = 'data/models/model_{}/{}/'.format(model_id,split) 7 | eventfile = os.listdir(tf_log_dir) 8 | assert len(eventfile)==1 9 | for summary in tf.train.summary_iterator(tf_log_dir+eventfile[0]): 10 | # find epoch 11 | if summary.step==epoch: 12 | for value in summary.summary.value: 13 | # find metric 14 | if value.tag == 'evaluation_metrics/{}'.format(metric): 15 | return value.simple_value 16 | 17 | if __name__ == '__main__': 18 | 19 | for epoch in range(5,45,5): 20 | print('epoch {}: {}'.format(epoch,read_previous_performance_from_tfevent(3545, epoch, split='dev',metric='Accuracy'))) 21 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/forward/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/models/forward/__init__.py -------------------------------------------------------------------------------- /src/models/forward/bert.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from src.models.tf_helpers import maybe_print 3 | 4 | def forward_propagation(input_dict, classes, hidden_layer=0,reduction_factor=2, dropout=0,seed_list=[], print_dim=False): 5 | """ 6 | Defines forward pass for BERT model 7 | 8 | Returns: logits 9 | """ 10 | if print_dim: 11 | print('---') 12 | print('Model: BERT') 13 | print('---') 14 | 15 | bert = input_dict['E1'] 16 | # bert has 2 keys: sequence_output which is output embedding for each token and pooled_output which is output embedding for the entire sequence. 
17 | 
18 | with tf.name_scope('bert_rep'):
19 | # pooled output (containing extra dense layer)
20 | bert_rep = bert['pooled_output'] # pooled output over entire sequence
21 | maybe_print([bert_rep], ['pooled BERT'], print_dim)
22 | 
23 | # C vector from last layer corresponding to CLS token
24 | # bert_rep = bert['sequence_output'][:, 0, :] # shape (batch, BERT_hidden)
25 | # maybe_print([bert_rep], ['BERT C vector'], print_dim)
26 | 
27 | # add dropout
28 | bert_rep = tf.layers.dropout(inputs=bert_rep, rate=dropout, seed=seed_list.pop(0))
29 | 
30 | if hidden_layer>0:
31 | raise ValueError("BERT baseline doesn't use additional hidden layers.")
32 | 
33 | with tf.name_scope('output_layer'):
34 | hidden_size = bert_rep.shape[-1].value
35 | output_weights = tf.get_variable(
36 | "output_weights", [classes, hidden_size],
37 | initializer=tf.truncated_normal_initializer(stddev=0.02, seed=seed_list.pop(0)))
38 | output_bias = tf.get_variable(
39 | "output_bias", [classes], initializer=tf.zeros_initializer())
40 | logits = tf.matmul(bert_rep, output_weights, transpose_b=True)
41 | logits = tf.nn.bias_add(logits, output_bias)
42 | 
43 | # Softmax Layer
44 | maybe_print([logits], ['output layer'], print_dim)
45 | 
46 | output = {'logits':logits}
47 | 
48 | return output
-------------------------------------------------------------------------------- /src/models/forward/bert_simple_topic.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from src.models.tf_helpers import maybe_print
3 | 
4 | def forward_propagation(input_dict, classes, hidden_layer=0, reduction_factor=2, dropout=0, seed_list=[], print_dim=False):
5 | """
6 | Defines forward pass for tBERT model
7 | 
8 | Returns: logits
9 | """
10 | if print_dim:
11 | print('---')
12 | print('Model: tBERT')
13 | print('---')
14 | 
15 | # word topics
16 | if input_dict['D_T1'] is None:
17 | W_T1 = input_dict['W_T1'] # (batch, sent_len_1, num_topics)
18 | W_T2 = input_dict['W_T2'] # (batch, sent_len_2, num_topics)
19 | maybe_print([W_T1, W_T2], ['W_T1', 'W_T2'], print_dim)
20 | # compute mean of word topics
21 | D_T1 = tf.reduce_mean(W_T1,axis=1) # (batch, num_topics)
22 | D_T2 = tf.reduce_mean(W_T2,axis=1) # (batch, num_topics)
23 | 
24 | # document topics
25 | elif input_dict['W_T1'] is None:
26 | D_T1 = input_dict['D_T1'] # (batch, num_topics)
27 | D_T2 = input_dict['D_T2'] # (batch, num_topics)
28 | 
29 | else:
30 | raise ValueError('Word or document topics need to be provided for bert_simple_topic.')
31 | maybe_print([D_T1,D_T2], ['D_T1','D_T2'], print_dim)
32 | 
33 | # bert representation
34 | bert = input_dict['E1']
35 | # bert has 2 keys: sequence_output which is output embedding for each token and pooled_output which is output embedding for the entire sequence.
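# Dimensionality note (num_topics is dataset-dependent, e.g. 80 for MSRP): the pooled
# BERT-base vector (768) is concatenated further below with the two document-topic
# distributions, giving 768 + num_topics + num_topics dimensions
# (e.g. 768 + 80 + 80 = 928) before the optional hidden layers and the output layer.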
36 | with tf.name_scope('bert_rep'): 37 | # pooled output (containing extra dense layer) 38 | bert_rep = bert['pooled_output'] # pooled output over entire sequence 39 | maybe_print([bert_rep], ['pooled BERT'], print_dim) 40 | 41 | # C vector from last layer corresponding to CLS token 42 | # bert_rep = bert['sequence_output'][:, 0, :] # shape (batch, BERT_hidden) 43 | # maybe_print([bert_rep], ['BERT C vector'], print_dim) 44 | bert_rep = tf.layers.dropout(inputs=bert_rep, rate=dropout, seed=seed_list.pop(0)) 45 | 46 | # combine BERT with document topics 47 | combined = tf.concat([bert_rep, D_T1, D_T2], -1) 48 | maybe_print([combined], ['combined'], print_dim) 49 | 50 | if hidden_layer>0: 51 | with tf.name_scope('hidden_1'): 52 | hidden_size = combined.shape[-1].value/reduction_factor 53 | combined = tf.layers.dense( 54 | combined, 55 | hidden_size, 56 | activation=tf.tanh, 57 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.02, seed=seed_list.pop(0))) 58 | maybe_print([combined], ['hidden 1'], print_dim) 59 | combined = tf.layers.dropout(inputs=combined, rate=dropout, seed=seed_list.pop(0)) 60 | 61 | if hidden_layer>1: 62 | with tf.name_scope('hidden_2'): 63 | hidden_size = combined.shape[-1].value/reduction_factor 64 | combined = tf.layers.dense( 65 | combined, 66 | hidden_size, 67 | activation=tf.tanh, 68 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.02, seed=seed_list.pop(0))) 69 | maybe_print([combined], ['hidden 2'], print_dim) 70 | combined = tf.layers.dropout(inputs=combined, rate=dropout, seed=seed_list.pop(0)) 71 | 72 | if hidden_layer>2: 73 | raise ValueError('Only 2 hidden layers supported.') 74 | 75 | with tf.name_scope('output_layer'): 76 | hidden_size = combined.shape[-1].value 77 | output_weights = tf.get_variable( 78 | "output_weights", [classes, hidden_size], 79 | initializer=tf.truncated_normal_initializer(stddev=0.02,seed=seed_list.pop(0))) 80 | output_bias = tf.get_variable( 81 | "output_bias", [classes], initializer=tf.zeros_initializer()) 82 | logits = tf.matmul(combined, output_weights, transpose_b=True) 83 | logits = tf.nn.bias_add(logits, output_bias) 84 | maybe_print([logits], ['output layer'], print_dim) 85 | 86 | output = {'logits':logits} 87 | 88 | return output -------------------------------------------------------------------------------- /src/models/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/models/helpers/__init__.py -------------------------------------------------------------------------------- /src/models/helpers/base.py: -------------------------------------------------------------------------------- 1 | from src.logs.training_logs import get_git_sha 2 | 3 | def add_git_version(opt): 4 | ''' 5 | Add git SHA to option dictionary to keep track of code version during experiment 6 | :return: updated opt 7 | ''' 8 | opt['git'] = get_git_sha() 9 | return opt 10 | 11 | def skip_MAP(X): 12 | """ 13 | Returns False if X contains 10 * x examples (necessary for computing MAP at 10), True otherwise 14 | :param X: 15 | :return: 16 | """ 17 | if X.shape[0] % 10 == 0: 18 | return False 19 | else: 20 | return True 21 | 22 | def extract_data(data_dict, topic_scope, extra_test): 23 | ''' 24 | Extracts relevant data from data_dict for base_model 25 | :param data_dict: dictionary with preprocessed data as provided by load_data() 26 | :param topic_scope: '','word','doc', or 'word+doc' 27 | 
:param extra_test: boolean, evaluate on extra test set (as with PAWS) 28 | :return: train_dict, dev_dict, test_dict, test_dict_extra 29 | ''' 30 | 31 | D_T1_train, D_T2_train, D_T1_dev, D_T2_dev, D_T1_test, D_T2_test, D_T1_test_extra, D_T2_test_extra = [None] * 8 32 | W_T1_train, W_T2_train, W_T1_dev, W_T2_dev, W_T1_test, W_T2_test, W_T1_test_extra, W_T2_test_extra = [None] * 8 33 | W_T_train, W_T_dev, W_T_test, W_T_test_extra, = [None] * 4 34 | 35 | if extra_test: 36 | X1_train, X1_dev, X1_test, X1_test_extra = data_dict['E1'] 37 | X1_mask_train, X1_mask_dev, X1_mask_test, X1_mask_test_extra = data_dict['E1_mask'] 38 | X1_seg_train, X1_seg_dev, X1_seg_test, X1_seg_test_extra = data_dict['E1_seg'] 39 | Y_train, Y_dev, Y_test, Y_test_extra = data_dict['L'] 40 | else: 41 | X1_train, X1_dev, X1_test = data_dict['E1'] 42 | X1_mask_train, X1_mask_dev, X1_mask_test = data_dict['E1_mask'] 43 | X1_seg_train, X1_seg_dev, X1_seg_test = data_dict['E1_seg'] 44 | Y_train, Y_dev, Y_test = data_dict['L'] 45 | 46 | # assign topics if required 47 | if 'doc' in topic_scope: 48 | if extra_test: 49 | D_T1_train, D_T1_dev, D_T1_test, D_T1_test_extra = data_dict['D_T1'] 50 | D_T2_train, D_T2_dev, D_T2_test, D_T2_test_extra = data_dict['D_T2'] 51 | else: 52 | D_T1_train, D_T1_dev, D_T1_test = data_dict['D_T1'] 53 | D_T2_train, D_T2_dev, D_T2_test = data_dict['D_T2'] 54 | if 'word' in topic_scope: 55 | if extra_test: 56 | W_T1_train, W_T1_dev, W_T1_test, W_T1_test_extra = data_dict['W_T1'] 57 | W_T2_train, W_T2_dev, W_T2_test, W_T2_test_extra = data_dict['W_T2'] 58 | else: 59 | W_T1_train, W_T1_dev, W_T1_test = data_dict['W_T1'] 60 | W_T2_train, W_T2_dev, W_T2_test = data_dict['W_T2'] 61 | # word_topic_matrix = data['word_topics']['topic_matrix'] 62 | 63 | train_dict = {'E1': X1_train, 'E1_mask': X1_mask_train,'E1_seg': X1_seg_train, 'D_T1': D_T1_train, 'D_T2': D_T2_train, 'W_T1': W_T1_train, 64 | 'W_T2': W_T2_train, 'W_T': W_T_train, 'Y': Y_train} 65 | dev_dict = {'E1': X1_dev, 'E1_mask': X1_mask_dev,'E1_seg': X1_seg_dev, 'D_T1': D_T1_dev, 'D_T2': D_T2_dev, 'W_T1': W_T1_dev, 66 | 'W_T2': W_T2_dev, 'W_T': W_T_dev, 'Y': Y_dev} 67 | test_dict = {'E1': X1_test, 'E1_mask': X1_mask_test,'E1_seg': X1_seg_test, 'D_T1': D_T1_test, 'D_T2': D_T2_test, 'W_T1': W_T1_test, 68 | 'W_T2': W_T2_test,'W_T': W_T_test, 'Y': Y_test} 69 | 70 | # shapes of E1, E1_mask and E1_seg always need to match 71 | assert train_dict['E1'].shape == train_dict['E1_mask'].shape == train_dict['E1_seg'].shape 72 | assert dev_dict['E1'].shape == dev_dict['E1_mask'].shape == dev_dict['E1_seg'].shape 73 | assert test_dict['E1'].shape == test_dict['E1_mask'].shape == test_dict['E1_seg'].shape 74 | 75 | if extra_test: 76 | test_dict_extra = {'E1': X1_test_extra, 'E1_mask': X1_mask_test_extra,'E1_seg': X1_seg_test_extra, 'D_T1': D_T1_test_extra, 'D_T2': D_T2_test_extra, 77 | 'W_T1': W_T1_test_extra, 'W_T2': W_T2_test_extra, 'W_T': W_T_test_extra, 'Y': Y_test_extra} 78 | assert test_dict_extra['E1'].shape == test_dict_extra['E1_mask'].shape == test_dict_extra['E1_seg'].shape 79 | else: 80 | test_dict_extra = None 81 | if len(train_dict['Y'].shape)>1: 82 | assert (train_dict['Y'].shape[1] == dev_dict['Y'].shape[1] == test_dict['Y'].shape[1]), \ 83 | 'Inconsistent input dimensions of labels' 84 | 85 | return train_dict, dev_dict, test_dict, test_dict_extra -------------------------------------------------------------------------------- /src/models/helpers/bert.py: -------------------------------------------------------------------------------- 1 | import 
tensorflow as tf 2 | from src.preprocessing import bert_tokenization 3 | import numpy as np 4 | 5 | def get_bert_version(cased,large): 6 | if large: 7 | size = '24_H-1024_A-16' 8 | else: 9 | size = '12_H-768_A-12' 10 | if cased: 11 | BERT_version = 'cased_L-{}'.format(size) 12 | else: 13 | BERT_version = 'uncased_L-{}'.format(size) 14 | return BERT_version 15 | 16 | def create_tokenizer(vocab_file, do_lower_case=False): 17 | return bert_tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) 18 | 19 | def convert_sentence_pairs_to_features(S1, S2, tokenizer, max_seq_len): 20 | ids = [i for i in range(len(S1))] 21 | L = [None for i in range(len(S1))] 22 | examples = generate_examples(ids, S1, S2, L) 23 | features = convert_examples_to_features(examples, [None,0, 1], max_seq_len, tokenizer) 24 | input_ids = [] 25 | input_mask = [] 26 | segment_ids = [] 27 | for feature in features: 28 | input_ids.append(feature.input_ids) 29 | input_mask.append(feature.input_mask) 30 | segment_ids.append(feature.segment_ids) 31 | return np.array(input_ids), np.array(input_mask), np.array(segment_ids) 32 | 33 | 34 | class InputExample(object): 35 | """A single training/test example for simple sequence classification.""" 36 | 37 | def __init__(self, guid, text_a, text_b=None, label=None): 38 | """Constructs a InputExample. 39 | Args: 40 | guid: Unique id for the example. 41 | text_a: string. The untokenized text of the first sequence. For single 42 | sequence tasks, only this sequence must be specified. 43 | text_b: (Optional) string. The untokenized text of the second sequence. 44 | Only must be specified for sequence pair tasks. 45 | label: (Optional) string. The label of the example. This should be 46 | specified for train and dev examples, but not for test examples. 
47 | """ 48 | self.guid = guid 49 | self.text_a = text_a 50 | self.text_b = text_b 51 | self.label = label 52 | 53 | def generate_examples(ids,S1,S2,L): 54 | ''' 55 | Create example objects from ids, sentence1s, sentence2s and labels 56 | :param ids: list of ids 57 | :param S1: list of sentences 58 | :param S2: list of sentences 59 | :param L: list of labels 60 | :return: list of example objects 61 | ''' 62 | # toy example 63 | # ids = [0,1] 64 | # S1 = ['this is a test', 'absd'] 65 | # S2 = ['this is another test', 'absd dsd'] 66 | # L = [1,0] 67 | # generate_examples(ids,S1,S2,L) 68 | examples = [] 69 | for i,s1,s2,l in zip(ids,S1,S2,L): 70 | example = InputExample(i, s1, s2, l) 71 | examples.append(example) 72 | return examples 73 | 74 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 75 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 76 | 77 | features = [] 78 | for (ex_index, example) in enumerate(examples): 79 | if ex_index % 100000 == 0: 80 | tf.logging.info("Converting example %d of %d" % (ex_index, len(examples))) 81 | 82 | feature = convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer) 83 | 84 | features.append(feature) 85 | return features 86 | 87 | def convert_single_example(ex_index, example, label_list, max_seq_length,tokenizer): 88 | """Converts a single `InputExample` into a single `InputFeatures`.""" 89 | 90 | if isinstance(example, PaddingInputExample): 91 | return InputFeatures( 92 | input_ids=[0] * max_seq_length, 93 | input_mask=[0] * max_seq_length, 94 | segment_ids=[0] * max_seq_length, 95 | label_id=0, 96 | is_real_example=False) 97 | 98 | label_map = {} 99 | for (i, label) in enumerate(label_list): 100 | label_map[label] = i 101 | 102 | tokens_a = tokenizer.tokenize(example.text_a) 103 | tokens_b = None 104 | if example.text_b: 105 | tokens_b = tokenizer.tokenize(example.text_b) 106 | 107 | if tokens_b: 108 | # Modifies `tokens_a` and `tokens_b` in place so that the total 109 | # length is less than the specified length. 110 | # Account for [CLS], [SEP], [SEP] with "- 3" 111 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 112 | else: 113 | # Account for [CLS] and [SEP] with "- 2" 114 | if len(tokens_a) > max_seq_length - 2: 115 | tokens_a = tokens_a[0:(max_seq_length - 2)] 116 | 117 | # The convention in BERT is: 118 | # (a) For sequence pairs: 119 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 120 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 121 | # (b) For single sequences: 122 | # tokens: [CLS] the dog is hairy . [SEP] 123 | # type_ids: 0 0 0 0 0 0 0 124 | # 125 | # Where "type_ids" are used to indicate whether this is the first 126 | # sequence or the second sequence. The embedding vectors for `type=0` and 127 | # `type=1` were learned during pre-training and are added to the wordpiece 128 | # embedding vector (and position vector). This is not *strictly* necessary 129 | # since the [SEP] token unambiguously separates the sequences, but it makes 130 | # it easier for the model to learn the concept of sequences. 131 | # 132 | # For classification tasks, the first vector (corresponding to [CLS]) is 133 | # used as the "sentence vector". Note that this only makes sense because 134 | # the entire model is fine-tuned. 
135 | tokens = [] 136 | segment_ids = [] 137 | tokens.append("[CLS]") 138 | segment_ids.append(0) 139 | for token in tokens_a: 140 | tokens.append(token) 141 | segment_ids.append(0) 142 | tokens.append("[SEP]") 143 | segment_ids.append(0) 144 | 145 | if tokens_b: 146 | for token in tokens_b: 147 | tokens.append(token) 148 | segment_ids.append(1) 149 | tokens.append("[SEP]") 150 | segment_ids.append(1) 151 | 152 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 153 | 154 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 155 | # tokens are attended to. 156 | input_mask = [1] * len(input_ids) 157 | 158 | # Zero-pad up to the sequence length. 159 | while len(input_ids) < max_seq_length: 160 | input_ids.append(0) 161 | input_mask.append(0) 162 | segment_ids.append(0) 163 | 164 | assert len(input_ids) == max_seq_length 165 | assert len(input_mask) == max_seq_length 166 | assert len(segment_ids) == max_seq_length 167 | 168 | label_id = label_map[example.label] 169 | if ex_index < 5: 170 | tf.logging.info("*** Example ***") 171 | tf.logging.info("guid: %s" % (example.guid)) 172 | tf.logging.info("tokens: %s" % " ".join( 173 | [bert_tokenization.printable_text(x) for x in tokens])) 174 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 175 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 176 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 177 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 178 | 179 | feature = InputFeatures( 180 | input_ids=input_ids, 181 | input_mask=input_mask, 182 | segment_ids=segment_ids, 183 | label_id=label_id, 184 | is_real_example=True) 185 | return feature 186 | 187 | class InputFeatures(object): 188 | """A single set of features of data.""" 189 | 190 | def __init__(self, 191 | input_ids, 192 | input_mask, 193 | segment_ids, 194 | label_id, 195 | is_real_example=True): 196 | self.input_ids = input_ids 197 | self.input_mask = input_mask 198 | self.segment_ids = segment_ids 199 | self.label_id = label_id 200 | self.is_real_example = is_real_example 201 | 202 | 203 | class PaddingInputExample(object): 204 | """Fake example so the num input examples is a multiple of the batch size. 205 | When running eval/predict on the TPU, we need to pad the number of examples 206 | to be a multiple of the batch size, because the TPU requires a fixed batch 207 | size. The alternative is to drop the last batch, which is bad because it means 208 | the entire output data won't be generated. 209 | We use this class instead of `None` because treating `None` as padding 210 | battches could cause silent errors. 211 | """ 212 | 213 | 214 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 215 | """Truncates a sequence pair in place to the maximum length.""" 216 | 217 | # This is a simple heuristic which will always truncate the longer sequence 218 | # one token at a time. This makes more sense than truncating an equal percent 219 | # of tokens from each, since if one sequence is very short then each token 220 | # that's truncated likely contains more information than a longer sequence. 
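    # Worked example with assumed toy values: max_length = 5,
    # tokens_a = [a1, a2, a3, a4] (len 4), tokens_b = [b1, b2, b3] (len 3).
    # Total is 7 > 5 and tokens_a is longer, so pop a4 (a: 3 tokens). Total is 6 > 5 and
    # the lists are now equal, so pop b3 (b: 2 tokens). Total is 5 <= 5, stop.
    # Result: a keeps 3 tokens, b keeps 2 - the longer sequence loses tokens first, one at a time.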
221 | while True: 222 | total_length = len(tokens_a) + len(tokens_b) 223 | if total_length <= max_length: 224 | break 225 | if len(tokens_a) > len(tokens_b): 226 | tokens_a.pop() 227 | else: 228 | tokens_b.pop() 229 | 230 | 231 | if __name__ == '__main__': 232 | # single toy example 233 | example = InputExample(231, 'This is a test', 'Here is another sentence', 0) 234 | BERT_version = 'cased_L-12_H-768_A-12' 235 | tokenizer = create_tokenizer('/Users/nicole/tf-hub-cache/{}/vocab.txt'.format(BERT_version), do_lower_case=True) 236 | max_seq_length = 20 237 | feature = convert_single_example(0,example, [0,1], max_seq_length, tokenizer) 238 | feature.label_id 239 | 240 | # multiple examples 241 | ids = [0,1] 242 | S1 = ['this is a test', 'absd'] 243 | S2 = ['this is another test', 'absd dsd'] 244 | L = [1,0] 245 | examples = generate_examples(ids,S1,S2,L) 246 | features = convert_examples_to_features(examples, [0,1], max_seq_length, tokenizer) -------------------------------------------------------------------------------- /src/models/helpers/training_regimes.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def standard_training_regime(optimizer_choice, cost, learning_rate, epsilon, rho): 4 | # normal setting with only one learning rate and optimizer for all variables 5 | with tf.name_scope('train'): 6 | if optimizer_choice == 'Adam': 7 | train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) 8 | elif optimizer_choice == 'Adadelta': 9 | train_step = tf.train.AdadeltaOptimizer(learning_rate=learning_rate, epsilon=epsilon, rho=rho).minimize(cost) 10 | else: 11 | raise NotImplementedError() 12 | return train_step 13 | 14 | def layer_specific_regime(optimizer_choice, cost, learning_rate_old_layers, learning_rate_new_layers, epsilon, rho): 15 | ''' 16 | Using layer-specific learning rates for BERT vs. newer layers which can be changed during training (e.g. 
freeze --> unfreeze) 17 | :param optimizer_choice: Adam or Adadelta 18 | :param cost: cost tensor 19 | :param learning_rate_old_layers: placeholder 20 | :param learning_rate_new_layers: placeholder 21 | :param epsilon: 22 | :param rho: 23 | :return: update op (combining bert optimizer and new layer optimizer) 24 | ''' 25 | # based on https://stackoverflow.com/questions/34945554/how-to-set-layer-wise-learning-rate-in-tensorflow 26 | trainable_vars = tf.trainable_variables() # huge list of trainable model variables 27 | # separate existing and new variables based on name (not position as previously) 28 | bert_vars = [] 29 | new_vars = [] 30 | for t in trainable_vars: 31 | if t.name.startswith('bert_lookup/'): 32 | bert_vars.append(t) 33 | else: 34 | new_vars.append(t) 35 | # create optimizers with different learning rates 36 | with tf.name_scope('train'): 37 | if optimizer_choice == 'Adam': 38 | old_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_old_layers, name='old_optimizer') 39 | new_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_new_layers, name='new_optimizer') 40 | elif optimizer_choice == 'Adadelta': 41 | old_optimizer = tf.train.AdadeltaOptimizer(learning_rate=learning_rate_old_layers, epsilon=epsilon, rho=rho,name='old_optimizer') 42 | new_optimizer = tf.train.AdadeltaOptimizer(learning_rate=learning_rate_new_layers, epsilon=epsilon, rho=rho,name='new_optimizer') 43 | else: 44 | raise NotImplementedError() 45 | # only compute gradients once 46 | grads = tf.gradients(cost, bert_vars + new_vars) 47 | # separate gradients from pretrained and new layers 48 | bert_grads = grads[:len(bert_vars)] 49 | new_grads = grads[len(bert_vars):] 50 | # apply optimisers to respective variables and gradients 51 | train_step_bert = old_optimizer.apply_gradients(zip(bert_grads, bert_vars)) 52 | train_step_new = new_optimizer.apply_gradients(zip(new_grads, new_vars)) 53 | # combine to one operation 54 | train_step = tf.group(train_step_bert, train_step_new) 55 | return train_step 56 | -------------------------------------------------------------------------------- /src/models/save_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import shutil 4 | # http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ 5 | 6 | def get_model_dir(opt,VM_copy=False): 7 | try: 8 | if type(opt['id'])==str: 9 | model_folder = opt['datapath'] + 'baseline_models/SemEval2017_task3_submissions_and_scores 3/' + opt['id'] + '/' 10 | else: 11 | if VM_copy: 12 | model_folder = opt['datapath'] + 'VM_models/model_{}/'.format(opt['id']) 13 | else: 14 | model_folder = opt['datapath'] + 'models/model_{}/'.format(opt['id']) 15 | except KeyError: 16 | raise KeyError('"id" and "datapath" in opt dictionary necessary for saving or loading model.') 17 | return model_folder 18 | 19 | def load_model(opt, saver, sess, epoch): 20 | # for early stopping 21 | model_path = get_model_dir(opt) + 'model_epoch{}.ckpt'.format(epoch) 22 | saver.restore(sess,model_path) 23 | 24 | # def load_predictions(model_id): 25 | # opt = {'datapath': 'data/','id':model_id} 26 | # model_dir = get_model_dir(opt,VM_copy=True) 27 | # pred_files = [f for f in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, f)) and f.endswith('.pred')] 28 | # read_predictions(opt, subset='dev', VM_path=True) 29 | # pass 30 | 31 | def load_model_and_graph(opt, sess, epoch, vm_path=True): 32 | # for model inspection 33 | 
model_path = get_model_dir(opt,VM_copy=vm_path) + 'model_epoch{}.ckpt'.format(epoch)
34 |     new_saver = tf.train.import_meta_graph(model_path+'.meta') # load graph
35 |     new_saver.restore(sess,model_path) # restore weights
36 | 
37 | def run_restored_tensor(tensornames, left, right, sess=None, label=None, w_tl=None, w_tr=None, d_tl=None, d_tr=None, t_l=None, t_r=None):
38 |     '''
39 |     Runs graph on specified tensor(s) with supplied word ids for left and right sentence for inspection
40 |     :param tensornames: names of tensors to restore
41 |     :param left: numpy array of word ids for left sentence
42 |     :param right: numpy array of word ids for right sentence
43 |     :param sess: session object
44 |     :return: values of evaluated tensors
45 |     '''
46 |     graph = tf.get_default_graph()
47 |     # build feed_dict based on provided input
48 |     XL = graph.get_tensor_by_name("XL:0")
49 |     XR = graph.get_tensor_by_name("XR:0")
50 |     feed_dict = {XL: left, XR: right}
51 |     if label is not None:
52 |         Y = graph.get_tensor_by_name("labels:0")
53 |         feed_dict[Y]=label
54 |     if d_tl is not None:
55 |         D_TL = graph.get_tensor_by_name("D_TL:0")
56 |         D_TR = graph.get_tensor_by_name("D_TR:0")
57 |         feed_dict[D_TL]=d_tl
58 |         feed_dict[D_TR]=d_tr
59 |     if w_tl is not None:
60 |         W_TL = graph.get_tensor_by_name("W_TL:0")
61 |         W_TR = graph.get_tensor_by_name("W_TR:0")
62 |         feed_dict[W_TL]=w_tl
63 |         feed_dict[W_TR]=w_tr
64 |     if t_l is not None:
65 |         T_L = graph.get_tensor_by_name("TL:0")
66 |         T_R = graph.get_tensor_by_name("TR:0")
67 |         feed_dict[T_L]=t_l
68 |         feed_dict[T_R]=t_r
69 |     # restore tensors
70 |     ops_to_restore = []
71 |     for tensor in tensornames:
72 |         ops_to_restore.append(graph.get_tensor_by_name(tensor+":0"))
73 |     return sess.run(ops_to_restore, feed_dict)
74 | 
75 | def run_restored_bert_tensor(tensornames, word_ids, mask_ids, seg_ids, sess=None, label=None, w_tl=None, w_tr=None, d_tl=None, d_tr=None):
76 |     '''
77 |     Runs graph on specified tensor(s) with supplied BERT inputs for the sentence pair for inspection
78 |     :param tensornames: names of tensors to restore
79 |     :param word_ids: numpy array of BERT input ids for the sentence pair
80 |     :param mask_ids, seg_ids: numpy arrays of BERT input masks and segment ids
81 |     :param sess: session object
82 |     :return: values of evaluated tensors
83 |     '''
84 |     graph = tf.get_default_graph()
85 |     # build feed_dict based on provided input
86 |     X = graph.get_tensor_by_name("Placeholder:0")
87 |     X_mask = graph.get_tensor_by_name("Placeholder_1:0")
88 |     X_seg = graph.get_tensor_by_name("Placeholder_2:0")
89 | 
90 |     feed_dict = {X: word_ids, X_mask: mask_ids, X_seg: seg_ids}
91 |     if label is not None:
92 |         Y = graph.get_tensor_by_name("labels:0")
93 |         feed_dict[Y]=label
94 |     if d_tl is not None:
95 |         D_TL = graph.get_tensor_by_name("D_TL:0")
96 |         D_TR = graph.get_tensor_by_name("D_TR:0")
97 |         feed_dict[D_TL]=d_tl
98 |         feed_dict[D_TR]=d_tr
99 |     if w_tl is not None:
100 |         W_TL = graph.get_tensor_by_name("W_TL:0")
101 |         W_TR = graph.get_tensor_by_name("W_TR:0")
102 |         feed_dict[W_TL]=w_tl
103 |         feed_dict[W_TR]=w_tr
104 |     # restore tensors
105 |     ops_to_restore = []
106 |     for tensor in tensornames:
107 |         ops_to_restore.append(graph.get_tensor_by_name(tensor+":0"))
108 |     return sess.run(ops_to_restore, feed_dict)
109 | 
110 | def create_saver():
111 |     return tf.train.Saver(max_to_keep=1)
112 | 
113 | def create_model_folder(opt):
114 |     folder = get_model_dir(opt)
115 |     if os.path.exists(folder):
116 |         raise FileExistsError('{} already exists. 
Please delete.'.format(folder)) 117 | else: 118 | os.mkdir(folder) 119 | 120 | def save_model(opt, saver, sess, epoch): 121 | model_path = get_model_dir(opt) + 'model_epoch{}.ckpt'.format(epoch) 122 | print(model_path) 123 | saver.save(sess, model_path) 124 | 125 | def delete_all_checkpoints_but_best(opt,best_epoch): 126 | # list all files in model dir 127 | model_dir = get_model_dir(opt) 128 | # list all checkpoints but best 129 | best_model = 'model_epoch{}.ckpt'.format(best_epoch) 130 | to_delete = [f for f in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, f)) and f.startswith('model_epoch') and not f.startswith(best_model)] 131 | if len(to_delete)>0: 132 | print('Deleting the following checkpoint files:') 133 | for f in to_delete: 134 | file_path = os.path.join(model_dir, f) 135 | print(file_path) 136 | # delete 137 | os.remove(file_path) 138 | 139 | def delete_model_dir(opt): 140 | model_dir = get_model_dir(opt) 141 | shutil.rmtree(model_dir) #ignore_errors=True 142 | -------------------------------------------------------------------------------- /src/models/tf_helpers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def maybe_print(elements, names, test_print): 5 | if test_print: 6 | for e, n in zip(elements, names): 7 | print(n + " shape: " + str(e.get_shape())) 8 | 9 | def compute_vocabulary_size(files): 10 | """ 11 | Counts number of distinct vocabulary indices 12 | :param files: only X, not Y 13 | :return: size of vocabulary 14 | """ 15 | vocabulary = set() 16 | for f in files: 17 | for row in f: 18 | for integer in row: 19 | if integer not in vocabulary: 20 | vocabulary.add(integer) 21 | return max(vocabulary)+1 22 | 23 | def create_placeholders(sentence_lengths, classes, bicnn=False, sparse=True, bert=False): 24 | """ 25 | Creates the placeholders for the tensorflow session. 26 | 27 | Arguments: 28 | sentence length -- scalar, width of sentence matrix 29 | classes -- scalar, number of classes 30 | 31 | Returns: 32 | X -- placeholder for the data input, of shape [None, sentence_length] and dtype "float" 33 | Y -- placeholder for the input labels, of shape [None, classes] and dtype "float" 34 | """ 35 | sentence_length = sentence_lengths[0] 36 | if sparse: 37 | Y = tf.placeholder(tf.int64, [None, ], name='labels') 38 | else: 39 | Y = tf.placeholder(tf.int32, [None, classes], name='labels') 40 | if bert: 41 | # BERT Placeholders # no names!! 42 | X1 = tf.placeholder(dtype=tf.int32, shape=[None, None]) # input ids 43 | X1_mask = tf.placeholder(dtype=tf.int32, shape=[None, None]) # input masks 44 | X1_seg = tf.placeholder(dtype=tf.int32, shape=[None, None]) # segment ids 45 | return X1,X1_mask,X1_seg,Y 46 | else: 47 | X = tf.placeholder(tf.int32, [None, sentence_length], name='XL') 48 | if bicnn: 49 | sentence_length2 = sentence_lengths[1] 50 | X2 = tf.placeholder(tf.int32, [None, sentence_length2], name='XR') 51 | return X, X2, Y 52 | else: 53 | return X, Y 54 | 55 | def create_text_placeholders(sentence_lengths): 56 | T1 = tf.placeholder(tf.string, [None, sentence_lengths[0]], name='TL') 57 | T2 = tf.placeholder(tf.string, [None, sentence_lengths[1]], name='TR') 58 | return T1,T2 59 | 60 | def create_word_topic_placeholders(sentence_lengths): 61 | """ 62 | Creates the placeholders for the tensorflow session. 
63 | 64 | Arguments: 65 | :param sentence lengths: scalar, width of sentence matrix 66 | :param num_topics: number of topics for topic model 67 | :param dim: dimensions of word topics, should be 2, 3 or None 68 | 69 | Returns: 70 | X -- placeholder for the data input, of shape [None, sentence_length] and dtype "float" 71 | Y -- placeholder for the input labels, of shape [None, classes] and dtype "float" 72 | """ 73 | T1 = tf.placeholder(tf.int32, [None, sentence_lengths[0]], name='W_TL') 74 | T2 = tf.placeholder(tf.int32, [None, sentence_lengths[1]], name='W_TR') 75 | return T1,T2 76 | 77 | def create_word_topic_placeholder(sentence_length): 78 | """ 79 | Creates the placeholders for the tensorflow session. 80 | 81 | Arguments: 82 | :param sentence lengths: scalar, width of sentence matrix 83 | :param num_topics: number of topics for topic model 84 | :param dim: dimensions of word topics, should be 2, 3 or None 85 | 86 | Returns: 87 | X -- placeholder for the data input, of shape [None, sentence_length] and dtype "float" 88 | Y -- placeholder for the input labels, of shape [None, classes] and dtype "float" 89 | """ 90 | WT = tf.placeholder(tf.int32, [None, sentence_length], name='W_T') 91 | return WT 92 | 93 | def create_doc_topic_placeholders(num_topics): 94 | """ 95 | Creates the placeholders for the tensorflow session. 96 | 97 | Arguments: 98 | sentence length -- scalar, width of sentence matrix 99 | classes -- scalar, number of classes 100 | 101 | Returns: 102 | X -- placeholder for the data input, of shape [None, sentence_length] and dtype "float" 103 | Y -- placeholder for the input labels, of shape [None, classes] and dtype "float" 104 | """ 105 | T1 = tf.placeholder(tf.float32, [None, num_topics], name='D_TL') 106 | T2 = tf.placeholder(tf.float32, [None, num_topics], name='D_TR') 107 | return T1,T2 108 | 109 | def create_embedding(vocab_size, embedding_dim,name='embedding'): 110 | with tf.name_scope(name): 111 | embedding_matrix = tf.Variable( 112 | tf.random_uniform([vocab_size, embedding_dim], -1.0, 1.0), 113 | name="W", trainable=True) 114 | return embedding_matrix 115 | 116 | def create_embedding_placeholder(doc_vocab_size, embedding_dim): 117 | # use this or load data directly as variable? 
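    # A minimal usage sketch (the session and the variable names below are assumptions,
    # not part of the repo): the placeholder is fed once to copy a pretrained matrix in.
    #   embd_ph = create_embedding_placeholder(doc_vocab_size, embedding_dim)
    #   W = initialise_pretrained_embedding(doc_vocab_size, embedding_dim, embd_ph, trainable=False)
    #   sess.run(W, feed_dict={embd_ph: pretrained_matrix})  # assigns the numpy matrix to W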
118 | embedding_placeholder = tf.placeholder(tf.float32, [doc_vocab_size, embedding_dim], name='embd_placeholder') 119 | return embedding_placeholder 120 | 121 | def initialise_pretrained_embedding(doc_vocab_size, embedding_dim, embedding_placeholder, name='embedding',trainable=True): 122 | with tf.name_scope(name): 123 | if trainable: 124 | print('init pretrained embds') 125 | embedding_matrix = tf.Variable(embedding_placeholder, trainable=True, name="W",dtype=tf.float32) 126 | else: 127 | W = tf.Variable(tf.constant(0.0, shape=[doc_vocab_size, embedding_dim]), trainable=False, name="W") 128 | embedding_matrix = W.assign(embedding_placeholder) 129 | return embedding_matrix 130 | 131 | def lookup_embedding(X, embedding_matrix,expand=True,transpose=True,name='embedding_lookup'): 132 | ''' 133 | Looks up embeddings based on word ids 134 | :param X: word id matrix with shape (m, sentence_length) 135 | :param embedding_matrix: embedding matrix with shape (vocab_size, embedding_dim) 136 | :param expand: add dimension to embedded matrix or not 137 | :param transpose: switch dimensions of embedding matrix or not 138 | :param name: name used in TF graph 139 | :return: embedded_matrix 140 | ''' 141 | embedded_matrix = tf.nn.embedding_lookup(embedding_matrix, X, name=name) # dim [m, sentence_length, embedding_dim] 142 | if transpose: 143 | embedded_matrix = tf.transpose(embedded_matrix, perm=[0, 2, 1]) # dim [m, embedding_dim, sentence_length] 144 | if expand: 145 | embedded_matrix = tf.expand_dims(embedded_matrix, -1) # dim [m, embedding_dim, sentence_length, 1] 146 | return embedded_matrix 147 | 148 | def compute_cost(logits, Y, loss_fn='cross_entropy', name='main_cost'): 149 | """ 150 | Computes the cost 151 | 152 | Arguments: 153 | logits -- output of forward propagation (output of the last LINEAR unit of shape (batch, classes) 154 | Y -- "true" labels vector of shape (batch,) 155 | 156 | Returns: 157 | cost - Tensor of the cost function 158 | """ 159 | # multi class classification (binary classification as special case) 160 | with tf.name_scope(name): 161 | if loss_fn=='cross_entropy': 162 | # maybe_print([logits,Y], ['logits','Y'], True) 163 | cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=Y),name='cost') 164 | elif loss_fn=='bert': 165 | # from https://github.com/google-research/bert/blob/master/run_classifier_with_tfhub.py 166 | # probabilities = tf.nn.softmax(logits, axis=-1) 167 | log_probs = tf.nn.log_softmax(logits, axis=-1) 168 | one_hot_labels = tf.one_hot(Y, depth=2, dtype=tf.float32) 169 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 170 | cost = tf.reduce_mean(per_example_loss) 171 | else: 172 | raise NotImplemented() 173 | return cost 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /src/models/topic_baseline.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import traceback 4 | import numpy as np 5 | import sys 6 | import argparse 7 | 8 | np.random.seed(1) 9 | 10 | from sklearn.metrics import accuracy_score,f1_score 11 | from src.loaders.load_data import load_data 12 | from src.logs.training_logs import write_log_entry, start_timer, end_timer 13 | from src.models.save_load import create_model_folder 14 | from src.evaluation.evaluate import output_predictions 15 | from src.evaluation.evaluate import save_eval_metrics 16 | from src.models.helpers.base import add_git_version,extract_data 17 | import 
warnings 18 | from src.evaluation.metrics.js_div import js 19 | warnings.filterwarnings("ignore", category=DeprecationWarning, module="tensorflow") 20 | # warnings.filterwarnings('error') 21 | 22 | 23 | """ 24 | Implements Base Model with early stopping and auxiliary loss 25 | """ 26 | 27 | predefined_opts = [ 28 | # loader 29 | 'dataset', 'datapath', 'tasks', 'subsets', 'w2v_limit', 'max_m', 'max_length', 'n_gram_embd', 'padding', 'unk_sub', 30 | 'simple_padding','lemmatize', 31 | # topic models 32 | 'num_topics', 'topic_type', 'topic','topic_alpha','unflat_topics','unk_topic','stem', 33 | # model 34 | 'model', 'load_ids', 'embedding_dim', 'embd_update','pretrained_embeddings','model','threshold', 35 | # auxiliary loss 36 | 'git'] 37 | 38 | def test_opt(opt): 39 | ''' 40 | Test if opt contains unused or wrong 41 | :param opt: 42 | :return: 43 | ''' 44 | for k in opt.keys(): 45 | assert k in predefined_opts, '{} not in accepted options'.format(k) 46 | 47 | def model(data_dict, opt, logfile=None, print_dim=False): 48 | """ 49 | Implements affinity CNN in Tensorflow: 50 | 51 | Arguments: 52 | data -- train, dev, test data 53 | 54 | 55 | 56 | opt -- option log, contains learning_rate, num_epochs, minibatch_size, ... 57 | logfile -- path of file to save opt and results 58 | print_dim -- print dimensions for debugging purposes 59 | 60 | Returns: 61 | opt -- updated option log 62 | parameters -- trained parameters of model 63 | """ 64 | 65 | ##### 66 | # Read options, set defaults and update log 67 | ##### 68 | 69 | try: 70 | # check input options 71 | print(opt) 72 | test_opt(opt) 73 | if opt.get('git',None) is None: 74 | add_git_version(opt) # keep track of git SHA 75 | 76 | # assign variables 77 | opt['model'] = opt.get('model','topic_baseline') 78 | assert opt['model'] == 'topic_baseline' 79 | # topic models 80 | topic_scope = opt['topic'] = opt.get('topic', '') 81 | assert topic_scope in ['','word','doc','word+doc'] 82 | if opt['model'] in ['topic_bi_cnn','topic_affinity_cnn','topic_separate_affinity_cnn']: 83 | assert topic_scope in ['word+doc','doc','word'] 84 | num_topics = opt['num_topics'] = opt.get('num_topics', None) 85 | topic_type = opt['topic_type'] = opt.get('topic_type', 'ldamallet') 86 | threshold = opt['threshold'] = opt.get('threshold', 0.5) 87 | print('threshold: {}'.format(threshold)) 88 | if not topic_scope == '': 89 | assert 'topic' in opt['model'] 90 | assert num_topics > 1 91 | assert topic_type in ['LDA','ldamallet','gsdmm'] 92 | opt['topic_alpha'] = opt.get('topic_alpha',50) 93 | else: 94 | assert num_topics is None 95 | assert topic_type is None 96 | if opt['dataset']=='Quora' and opt['subsets']==['train','dev','test','p_test']: 97 | extra_test = True 98 | else: 99 | extra_test = False 100 | 101 | ##### 102 | # unpack data and assign to model variables 103 | ##### 104 | 105 | embd = data_dict.get('embd', None) # (565852, 200) 106 | if 'word' in topic_scope: 107 | topic_embd = data_dict['word_topics'].get('topic_matrix', None) #topic_emb.shape 108 | 109 | # assign word ids 110 | if extra_test: 111 | ID1_train, ID1_dev, ID1_test, ID1_test_extra = data_dict['ID1'] 112 | ID2_train, ID2_dev, ID2_test, ID2_test_extra = data_dict['ID2'] 113 | else: 114 | ID1_train, ID1_dev, ID1_test = data_dict['ID1'] 115 | ID2_train, ID2_dev, ID2_test = data_dict['ID2'] 116 | train_dict,dev_dict,test_dict,test_dict_extra = extract_data(data_dict, topic_scope, extra_test) 117 | 118 | start_time, opt = start_timer(opt, logfile) 119 | # if not logfile is None: 120 | # print('logfile: 
{}'.format(logfile))
121 |         # create_model_folder(opt)
122 |         # model_dir = get_model_dir(opt)
123 | 
124 |         def get_mean_word_topics(w_topic_ids, topic_matrix):
125 |             T = []
126 |             for i in range(len(w_topic_ids)): # loop through sentences
127 |                 s_dist = []
128 |                 # mean of word topics (todo: ignore non-topic words or not?)
129 |                 for w in w_topic_ids[i]: # loop through words
130 |                     if not w==0:
131 |                         w_dist = topic_matrix[w]
132 |                         s_dist.append(w_dist)
133 |                 if len(s_dist)==0:
134 |                     # no word topic vector
135 |                     s_dist=topic_matrix[0]
136 |                 elif len(s_dist)==1:
137 |                     # only one word topic vector
138 |                     s_dist = s_dist[0]
139 |                     pass
140 |                 else:
141 |                     # multiple word topic vectors
142 |                     s_dist = np.array(s_dist).mean(axis=0)
143 |                 T.append(s_dist)
144 |             T = np.array(T)
145 |             assert(len(T.shape)==2)
146 |             return T
147 | 
148 |         # Predict + evaluate on dev and test set
149 |         def predict(split_dict,topic_scope,threshold):
150 |             # extract topics
151 |             if topic_scope == 'word':
152 |                 # get topic ids from dict
153 |                 T1_ids = split_dict['W_T1']
154 |                 T2_ids = split_dict['W_T2']
155 |                 # print(T1_ids.shape)
156 |                 # print(T2_ids.shape)
157 |                 topic_matrix = data_dict['word_topics']['topic_matrix']
158 |                 # lookup actual distributions, reduce dim through mean across word topics
159 |                 T1 = get_mean_word_topics(T1_ids, topic_matrix) # shape (m,num_topic)
160 |                 T2 = get_mean_word_topics(T2_ids, topic_matrix)
161 |             elif topic_scope == 'doc':
162 |                 T1 = split_dict['D_T1'] # shape (m,num_topic)
163 |                 T2 = split_dict['D_T2']
164 |             elif topic_scope == 'word+doc':
165 |                 # concatenation of word and doc topics not implemented for this baseline
166 |                 raise NotImplementedError()
167 |             # extract labels
168 |             L = split_dict['Y']
169 |             predictions = []
170 |             scores = []
171 | 
172 |             # calculate JSD between topic distributions
173 |             print(T1.shape)
174 |             print(T2.shape)
175 |             for tl,tr in zip(T1,T2):
176 |                 # divergence = distance.jensenshannon(tl,tr) # why does it give nans? 
zero div 177 | divergence = js(tl,tr) 178 | if math.isnan(divergence): 179 | Warning('JS divergence is NAN') 180 | print(tl) 181 | print(tr) 182 | print(tl==tr) 183 | if divergence>threshold: 184 | prediction = 0 185 | else: 186 | prediction = 1 187 | scores.append(divergence) 188 | predictions.append(prediction) 189 | predictions = np.array(predictions) 190 | 191 | acc = accuracy_score(L,predictions) 192 | f_1 = f1_score(L,predictions) 193 | # todo 194 | prec = 0 195 | rec = 0 196 | ma_p = 0 197 | metrics = [acc, prec, rec, f_1, ma_p] 198 | return scores, predictions, metrics 199 | 200 | train_scores, train_pred, train_metrics = predict(train_dict,topic_scope,threshold) 201 | # print(train_scores) 202 | # print(train_pred) 203 | def get_mean_jsd(scores): 204 | mean_jsd = sum(scores) / len(scores) 205 | return mean_jsd 206 | 207 | dev_scores, dev_pred, dev_metrics = predict(dev_dict,topic_scope,threshold) 208 | opt = save_eval_metrics(dev_metrics, opt, 'dev') 209 | test_scores, test_pred, test_metrics = predict(test_dict,topic_scope,threshold) 210 | 211 | # save mean jsd for splits 212 | opt['mean_jsd_train'] = get_mean_jsd(train_scores) 213 | opt['mean_jsd_dev'] = get_mean_jsd(dev_scores) 214 | opt['mean_jsd_test'] = get_mean_jsd(test_scores) 215 | 216 | print('mean div scores on train set: {}'.format(opt['mean_jsd_train'])) 217 | print('mean div scores on dev set: {}'.format(opt['mean_jsd_dev'])) 218 | 219 | opt = save_eval_metrics(test_metrics, opt, 'test') 220 | opt = end_timer(opt, start_time, logfile) 221 | 222 | if extra_test: 223 | test_scores_extra, test_pred_extra, test_metrics_extra = predict(test_dict_extra,topic_scope,threshold) 224 | opt = save_eval_metrics(test_metrics_extra, opt, 'PAWS') 225 | 226 | if print_dim: 227 | stopping_criterion = 'Accuracy' 228 | print('Dev {}: {}'.format(stopping_criterion, opt['score'][stopping_criterion]['dev'])) 229 | print('Test {}: {}'.format(stopping_criterion, opt['score'][stopping_criterion]['test'])) 230 | 231 | if not logfile is None: 232 | 233 | # # prevent problem with long floats when writing log file 234 | # for k,v in opt.items(): 235 | # if type(v)==float: 236 | # opt[k] = round(float(v), 6) 237 | # save log 238 | print('logfile: {}'.format(logfile)) 239 | create_model_folder(opt) 240 | write_log_entry(opt, 'data/logs/' + logfile) 241 | # writer.add_graph(sess.graph) 242 | 243 | # write predictions to file for scorer 244 | output_predictions(ID1_train, ID2_train, train_scores, train_pred, 'train', opt) 245 | output_predictions(ID1_dev, ID2_dev, dev_scores, dev_pred, 'dev', opt) 246 | output_predictions(ID1_test, ID2_test, test_scores, test_pred, 'test', opt) 247 | if extra_test: 248 | output_predictions(ID1_test_extra, ID2_test_extra, test_scores_extra, test_pred_extra, 'PAWS_test', opt) 249 | print('Wrote predictions for model_{}.'.format(opt['id'])) 250 | 251 | # todo: compute non-obvious F1 and save in opt 252 | 253 | except Exception as e: 254 | print("Error: {0}".format(e.__doc__)) 255 | traceback.print_exc(file=sys.stdout) 256 | opt['status'] = 'Error' 257 | write_log_entry(opt, 'data/logs/' + logfile) 258 | 259 | # print('==============') 260 | 261 | return opt 262 | 263 | 264 | if __name__ == '__main__': 265 | 266 | parser = argparse.ArgumentParser() 267 | parser.register("type", "bool", lambda v: v.lower() == "true") 268 | parser.add_argument( 269 | "--debug", 270 | type="bool", 271 | nargs="?", 272 | const=True, 273 | default=False, 274 | help="Use debugger to track down bad values during training. 
" 275 | "Mutually exclusive with the --tensorboard_debug_address flag.") 276 | parser.add_argument('-gpu', action="store", dest="gpu", type=int, default=-1) 277 | FLAGS, unparsed = parser.parse_known_args() 278 | 279 | unks = ['zero','zero','uniform','uniform'] 280 | stems = [True,False,True,False] 281 | 282 | for unk_topic,stem in zip(unks,stems): 283 | opt = {'embd_update': False, 284 | 'load_ids': True, 'max_length': 'minimum', 'max_m': None, #todo: fix max_m=None error 285 | 'model': 'topic_baseline', 'n_gram_embd': False, 286 | 'pretrained_embeddings': None, 287 | 'subsets': ['train', 'dev', 'test'], 288 | # 'subsets': ['train_large', 'test2016', 'test2017'], 289 | 'topic': 'word', 'topic_type': 'ldamallet', 290 | 291 | 'padding': False, 'simple_padding': True, 292 | 'max_length': 'minimum', 'unk_sub': False, 293 | 'lemmatize': False, 294 | 295 | 'datapath': 'data/', 'dataset': 'Quora', 'tasks': ['B'], 296 | 'topic_alpha': 1, 'num_topics': 90, 'threshold': 0.090, 297 | 'unk_topic':unk_topic, 'stem':stem} 298 | 299 | # print(opt) 300 | data_dict = load_data(opt, cache=True, write_vocab=False) 301 | opt = model(data_dict, opt, logfile='specific_topic_settings.json', print_dim=True) 302 | 303 | # todo: zeros - no stem 304 | # todo: zeros - stem 305 | # todo: uniform - stem 306 | 307 | # data_dict['embd'] 308 | 309 | -------------------------------------------------------------------------------- /src/preprocessing/Preprocessor.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/whiskeyromeo/CommunityQuestionAnswering 2 | 3 | import nltk 4 | import re 5 | import numpy as np 6 | import os 7 | from pathlib import Path 8 | from src.models.helpers.bert import convert_sentence_pairs_to_features,create_tokenizer,get_bert_version 9 | # from src.models.helpers.bert import convert_sentence_pairs_to_features,create_tokenizer 10 | 11 | def get_homedir(): 12 | ''' 13 | Returns user homedir across different platforms 14 | :return: 15 | ''' 16 | return str(Path.home()) 17 | 18 | def reduce_embd_id_len(E1,tasks, cutoff=100): 19 | ''' 20 | Reduces numpy array dimensions for word embedding ids to save computational resources, e.g. 
from (2930, 200) to (2930, 100) 21 | :param E1: document represented as numpy array with word ids of shape (m,sent_len) 22 | :param cutoff: sentence length after cutoff 23 | :return: shortened E1 24 | ''' 25 | if len(tasks) > 1: 26 | raise NotImplementedError('Not implemented minimum length with multiple tasks yet.') 27 | # cut length of questions to 100 tokens, leave answers as is 28 | # only select questions 29 | E1_short = [] 30 | for sub in E1: 31 | # reduce length (drop last 100 elements in array) 32 | d = np.delete(sub, np.s_[cutoff:], 1) 33 | E1_short.append(d) 34 | assert E1_short[-1].shape == (E1[-1].shape[0], 100) 35 | return E1_short 36 | 37 | # same as convert_to_one_hot().T 38 | def get_onehot_encoding(labels): 39 | classes = 2 40 | test = labels.reshape(labels.size, ) 41 | onehotL = np.zeros((test.size, classes)) 42 | onehotL[np.arange(test.size), test] = 1 43 | onehotL[np.arange(test.size), test] = 1 44 | return onehotL 45 | 46 | class Preprocessor: 47 | vocab_processor = None 48 | 49 | @staticmethod 50 | def basic_pipeline(sentences): 51 | # process text 52 | print("Preprocessor: replace urls and images") 53 | sentences = Preprocessor.replaceImagesURLs(sentences) 54 | print("Preprocessor: to lower case") 55 | sentences = Preprocessor.toLowerCase(sentences) 56 | print("Preprocessor: split sentence into words") 57 | sentences = Preprocessor.tokenize_tweet(sentences) 58 | print("Preprocessor: remove quotes") 59 | sentences = Preprocessor.removeQuotes(sentences) 60 | return sentences 61 | 62 | @staticmethod 63 | def replaceImagesURLs(sentences): 64 | out = [] 65 | # URL_tokens = ['','','URLTOK'] # 'URLTOK' or '' 66 | # IMG_tokens = ['','IMG'] 67 | URL_token = '' 68 | IMG_token = '' 69 | 70 | for s in sentences: 71 | s = re.sub(r'(http://)?www.*?(\s|$)', URL_token+'\\2', s) # URL containing www 72 | s = re.sub(r'http://.*?(\s|$)', URL_token+'\\1', s) # URL starting with http 73 | s = re.sub(r'\w+?@.+?\\.com.*',URL_token,s) #email 74 | s = re.sub(r'\[img.*?\]',IMG_token,s) # image 75 | s = re.sub(r'< ?img.*?>', IMG_token, s) 76 | out.append(s) 77 | return out 78 | 79 | @staticmethod 80 | def removeQuotes(sentences): 81 | ''' 82 | Remove punctuation from list of strings 83 | :param sentences: list with tokenised sentences 84 | :return: list 85 | ''' 86 | out = [] 87 | for s in sentences: 88 | out.append([w for w in s if not re.match(r"['`\"]+",w)]) 89 | # # Twitter embeddings retain punctuation and use the following special tokens: 90 | # # , , , , 91 | # # s = re.sub(r'[^\w\s]', ' ', s) 92 | # s = re.sub(r'[^a-zA-Z0-9_<>?.,]', ' ', s) 93 | # s = re.sub(r'[\s+]', ' ', s) 94 | # s = re.sub(r' +', ' ', s) # prevent too much whitespace 95 | # s = s.lstrip().rstrip() 96 | # out.append(s) 97 | return out 98 | 99 | @staticmethod 100 | def stopwordsList(): 101 | stopwords = nltk.corpus.stopwords.words('english') 102 | stopwords.append('...') 103 | stopwords.append('___') 104 | stopwords.append('') 105 | stopwords.append('') 106 | stopwords.append('') 107 | stopwords.append('') 108 | stopwords.append("can't") 109 | stopwords.append("i've") 110 | stopwords.append("i'll") 111 | stopwords.append("i'm") 112 | stopwords.append("that's") 113 | stopwords.append("n't") 114 | stopwords.append('rrb') 115 | stopwords.append('lrb') 116 | return stopwords 117 | 118 | @staticmethod 119 | def removeStopwords(question): 120 | stopwords = Preprocessor.stopwordsList() 121 | return [i for i in question if i not in stopwords] 122 | 123 | @staticmethod 124 | def removeShortLongWords(sentence): 125 | return [w 
for w in sentence if len(w)>2 and len(w)<200] 126 | 127 | @staticmethod 128 | def tokenize_simple(iterator): 129 | return [sentence.split(' ') for sentence in iterator] 130 | 131 | @staticmethod 132 | def tokenize_nltk(iterator): 133 | return [nltk.word_tokenize(sentence) for sentence in iterator] 134 | 135 | @staticmethod 136 | def tokenize_tweet(iterator,strip=True): 137 | # tknzr = nltk.tokenize.TweetTokenizer(strip_handles=True, reduce_len=True) 138 | tknzr = nltk.tokenize.TweetTokenizer(strip_handles=True, reduce_len=True) 139 | result = [tknzr.tokenize(sentence) for sentence in iterator] 140 | if strip: 141 | result = [[w.replace(" ", "") for w in s] for s in result] 142 | return result 143 | 144 | @staticmethod 145 | def toLowerCase(sentences): 146 | out = [] 147 | special_tokens = ['UNK','',''] 148 | for s in Preprocessor.tokenize_tweet(sentences): 149 | sent =[] 150 | # split sentences in tokens and lowercase except for special tokens 151 | for w in s: 152 | if w in special_tokens: 153 | sent.append(w) 154 | else: 155 | sent.append(w.lower()) 156 | out.append(' '.join(sent)) 157 | return out 158 | 159 | @staticmethod 160 | def max_document_length(sentences,tokenizer): 161 | sentences = tokenizer(sentences) 162 | return max([len(x) for x in sentences]) # tokenised length of sentence! 163 | 164 | @staticmethod 165 | def pad_sentences(sentences, max_length,pad_token='',tokenized=False): 166 | ''' 167 | Manually pad sentences with pad_token (to avoid the same representation for and ) 168 | :param sentences: 169 | :param tokenizer: 170 | :param max_length: 171 | :param pad_token: 172 | :return: 173 | ''' 174 | if tokenized: 175 | tokenized = sentences 176 | return [(s + [pad_token] * (max_length - len(s))) for s in tokenized] 177 | else: 178 | tokenized = Preprocessor.tokenize_tweet(sentences) 179 | return [' '.join(s + [pad_token] * (max_length - len(s))) for s in tokenized] 180 | 181 | @staticmethod 182 | def reduce_sentence_len(r_tok,max_len): 183 | ''' 184 | Reduce length of tokenised sentence 185 | :param r_tok: nested list consisting of tokenised sentences e.g. [['w1','w2'],['w3']] 186 | :param max_len: maximum length of sentence 187 | :return: nested list consisting of tokenised sentences, none longer than max_len 188 | ''' 189 | return [s if len(s) <= max_len else s[:max_len] for s in r_tok] 190 | 191 | @staticmethod 192 | def map_topics_to_id(r_tok,word2id_dict,s_max_len,opt): 193 | r_red = Preprocessor.reduce_sentence_len(r_tok, s_max_len) 194 | r_pad = Preprocessor.pad_sentences(r_red, s_max_len, pad_token='UNK', tokenized=True) 195 | mapped_sentences = [] 196 | for s in r_pad: 197 | ids = [word2id_dict[lemma] if lemma in word2id_dict.keys() else 0 for lemma in s] # todo:fix 0 for UNK 198 | assert len(ids)==s_max_len, 'id len for {} should be {}, but is {}'.format(s,s_max_len,len(ids)) 199 | mapped_sentences.append(np.array(ids)) 200 | return np.array(mapped_sentences) 201 | 202 | @staticmethod 203 | def map_files_to_bert_ids(T1, T2, max_length, calculate_mapping_rate=False, bert_cased=False, bert_large=False): 204 | ''' 205 | Split raw text into tokens and map to embedding ids for all subsets 206 | :param T1: nested list with tokenized sentences in each subset e.g. [R1_train,R1_dev,R1_test] 207 | :param T1: nested list with tokenized sentences in each subset e.g. 
[R2_train,R2_dev,R2_test] 208 | :param max_length: number of tokens in longest sentence, int 209 | :param pretrained_embedding: use mapping from existing embeddings?, boolean 210 | :param padding_tokens: padding tokens to use, should be [''] or ['',''] 211 | :return: {'E1':E1,'E2':E2, 'mapping_rates':mapping_rates or None} 212 | ''' 213 | mapping_rates = [] #todo: fix mapping rates 214 | 215 | # set unused to None rather than [] 216 | E1 = [] 217 | E1_mask = [] 218 | E1_seg = [] 219 | # use new_bert preprocessing code to encode sentence pairs 220 | for S1,S2 in zip(T1,T2): # look through subsets 221 | BERT_version = get_bert_version(bert_cased, bert_large) 222 | if bert_cased: 223 | lower=False 224 | else: 225 | lower=True 226 | tokenizer = create_tokenizer('{}/tf-hub-cache/{}/vocab.txt'.format(get_homedir(),BERT_version), 227 | do_lower_case=lower) 228 | Preprocessor.word2id = tokenizer.vocab # dict(zip(vocabulary, range(len(vocabulary)))) 229 | Preprocessor.id2word = {v: k for k, v in Preprocessor.word2id.items()} 230 | # S1 = [' '.join(s) for s in S1] # don't use tokenized version 231 | # S2 = [' '.join(s) for s in S2] # don't use tokenized version 232 | input_ids_vals, input_mask_vals, segment_ids_vals = convert_sentence_pairs_to_features(S1,S2, tokenizer,max_seq_len=max_length) # double length due to 2 sentences 233 | assert input_ids_vals.shape == input_mask_vals.shape == segment_ids_vals.shape 234 | E1.append(input_ids_vals) 235 | E1_mask.append(input_mask_vals) 236 | E1_seg.append(segment_ids_vals) 237 | 238 | if not calculate_mapping_rate: 239 | mapping_rates = None 240 | return {'E1':E1,'E1_mask':E1_mask,'E1_seg':E1_seg,'E2':None, 'mapping_rates':mapping_rates, 'word2id':Preprocessor.word2id,'id2word':Preprocessor.id2word} -------------------------------------------------------------------------------- /src/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/preprocessing/__init__.py -------------------------------------------------------------------------------- /src/topic_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/src/topic_model/__init__.py -------------------------------------------------------------------------------- /src/topic_model/gsdmm.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gensim 4 | import numpy as np 5 | import pandas as pd 6 | import pickle 7 | 8 | from src.ShortTextTopic.gsdmm import MovieGroupProcess 9 | from src.topic_model.topic_loader import get_topic_model_folder, get_topic_parent_folder,get_topic_root_folder 10 | 11 | # ----------------------------------------------- 12 | # save and load functions for gsdmm topic model 13 | # ----------------------------------------------- 14 | 15 | def train_save_gsdmm_model(processed_texts,id2word, opt): 16 | ''' 17 | Trains and saves Gibbs Sampling Dirichlet Multinomial Mixture model (topic model for short text) 18 | :param processed_texts: preprocessed texts 19 | :param opt: option dictionary 20 | :return: trained topic model 21 | ''' 22 | alpha = opt.get('topic_alpha', 0.1) 23 | # It's important to always choose K to be larger than the number of clusters you expect exist in your data, as the algorithm can never return more than K clusters. 
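    # A minimal sketch of the MovieGroupProcess API used below (toy documents, values assumed):
    #   mgp = MovieGroupProcess(K=8, alpha=0.1, beta=0.1, n_iters=30)
    #   docs = [['stock', 'price', 'drop'], ['rain', 'today'], ['stock', 'market']]
    #   labels = mgp.fit(docs, n_terms=6)  # one cluster label per document, at most K distinct clusters
    #   dist = mgp.score(docs[0])          # probability distribution over the K clusters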
24 | gsdmm_model = MovieGroupProcess(K=opt['num_topics'], alpha=alpha, beta=alpha, n_iters=30) 25 | # corpus, id2word, data_lemmatized, texts = lda_preprocess(data, id2word=None, print_steps=True, lemmatize=False) 26 | vocab = set(x for doc in processed_texts for x in doc) 27 | n_terms = len(vocab) 28 | y = gsdmm_model.fit(processed_texts, n_terms) 29 | # save additional information for later 30 | gsdmm_model.id2word = id2word 31 | gsdmm_model.num_topics = opt['num_topics'] 32 | save_gsdmm_model(gsdmm_model, opt) 33 | return gsdmm_model 34 | 35 | def save_gsdmm_model(gsdmm_model, opt): 36 | root_folder = get_topic_root_folder(opt) 37 | if not os.path.exists(root_folder): 38 | os.mkdir(root_folder) 39 | model_folder = get_topic_parent_folder(opt) 40 | if not os.path.exists(model_folder): 41 | os.mkdir(model_folder) 42 | model_path = get_topic_model_folder(opt) 43 | if not os.path.exists(model_path): 44 | os.mkdir(model_path) 45 | print('Saving Topic Model to {}...'.format(model_path)) 46 | with open(model_path + 'gsdmm.model', 'wb') as f: 47 | pickle.dump(gsdmm_model, f) 48 | f.close() 49 | write_keywords(gsdmm_model, opt) 50 | 51 | def load_gsdmm_model(opt): 52 | model_path = os.path.join(get_topic_model_folder(opt), 'gsdmm.model') 53 | print('Loading Topic Model from {}...'.format(model_path)) 54 | filehandler = open(model_path, 'rb') 55 | gsdmm_model = pickle.load(filehandler) 56 | print('Done.') 57 | return gsdmm_model 58 | 59 | # ----------------------------------------------- 60 | # functions to infer topic distribution 61 | # ----------------------------------------------- 62 | 63 | def infer_gsdmm_topics(gsdmm_model, texts): 64 | ''' 65 | Predicts topic distribution for tokenized sentences 66 | :param gsdmm_model: 67 | :param texts: nested list of tokenized sentences 68 | :return: 69 | ''' 70 | assert type(texts)==list 71 | assert type(texts[0])==list 72 | assert type(texts[0][0])==str 73 | dist_over_topic = [gsdmm_model.score(t) for t in texts] # probablility distribution over topics 74 | global_topics = extract_topic_from_gsdmm_prediction(dist_over_topic) 75 | return global_topics 76 | 77 | # ----------------------------------------------- 78 | # functions to explore topic distribution 79 | # ----------------------------------------------- 80 | 81 | def extract_topic_from_gsdmm_prediction(dist_over_topic): 82 | ''' 83 | Extracts topic vectors from prediction 84 | :param dist_over_topic: nested topic distribution object 85 | :return: topic array with (examples, num_topics) 86 | ''' 87 | global_topics = np.array(dist_over_topic) 88 | return global_topics 89 | 90 | def top_words(cluster_word_distribution, top_cluster, values): 91 | for cluster in top_cluster: 92 | sort_dicts =sorted(cluster_word_distribution[cluster].items(), key=lambda k: k[1], reverse=True)[:values] 93 | print('Cluster %s : %s'%(cluster,sort_dicts)) 94 | print(' — — — — — — — — — ') 95 | 96 | def write_keywords(gsdmm_model, opt, cutoff=20): 97 | clusters = [i for i in range(len(gsdmm_model.cluster_doc_count))] 98 | importance = [doc_count/sum(gsdmm_model.cluster_doc_count) for doc_count in gsdmm_model.cluster_doc_count] 99 | topickey_path = os.path.join(get_topic_model_folder(opt), 'topickeys.txt') 100 | with open(topickey_path, 'w', encoding="utf-8") as outfile: 101 | for i in clusters: 102 | sort_dicts = sorted(gsdmm_model.cluster_word_distribution[i].items(), key=lambda k: k[1], reverse=True)[:cutoff] 103 | keywords = ' '.join([w for w,c in sort_dicts]) 104 | 
outfile.writelines('{}\t{}\t{}\n'.format(i,importance[i],keywords)) 105 | 106 | # def show_topic_dist(topic_dist, lda_model): 107 | # ''' 108 | # Print topic distribution with labels and scores 109 | # :param topic_dist: 110 | # :param lda_model: 111 | # :return: 112 | # ''' 113 | # total = 0 114 | # props = [] 115 | # topics = [] 116 | # for topic_num, prop_topic in enumerate(topic_dist): 117 | # # [(self.id2word[id], value) for id, value in self.get_topic_terms(topicid, topn)] 118 | # wp = lda_model.show_topic(topic_num) 119 | # topic_keywords = ", ".join([word for word, prop in wp]) 120 | # # print(topic_keywords + ' {}'.format(prop_topic)) 121 | # topics.append(topic_keywords) 122 | # props.append(prop_topic) 123 | # total += prop_topic 124 | # if 1 - total > 0.01: 125 | # raise ValueError('Should add up to 1, but are {}'.format(total)) 126 | # sent_topics_df = pd.DataFrame(props, topics) 127 | # return sent_topics_df.sort_values(0, ascending=False) 128 | 129 | if __name__ == '__main__': 130 | from src.topic_model.topic_trainer import load_sentences_for_topic_model,lda_preprocess 131 | # Import Dataset 132 | gsdmm_opt = {'dataset': 'Semeval', 'datapath': 'data/', 'num_topics': 10, 'topic_type': 'gsdmm', 133 | 'tasks': ['A','B','C'], 'n_gram_embd': False, 'numerical': False, 'topic_alpha': 0.1, 134 | 'subsets': ['train_large'], 'cache': True, 'stem':True, 135 | 'max_m':100} 136 | # load input data 137 | data = load_sentences_for_topic_model(gsdmm_opt) 138 | # preprocess input for LDA 139 | 140 | corpus, id2word, processed_texts = lda_preprocess(data, id2word=None, delete_stopwords=True, print_steps=True) 141 | # # train model 142 | # gsdmm_model = train_save_gsdmm_model(processed_texts, id2word, gsdmm_opt) 143 | gsdmm_model = load_gsdmm_model(gsdmm_opt) 144 | # gsdmm_predictions = infer_gsdmm_topics(gsdmm_model, processed_texts) 145 | # todo: compare format with lda 146 | 147 | 148 | 149 | 150 | # # # # evaluate 151 | # results = evaluate_topic_model(lda_model, corpus, processed_texts, id2word) 152 | # from src.topic_model.topic_predictor import infer_and_write_word_topics 153 | # infer_and_write_word_topics(opt) -------------------------------------------------------------------------------- /src/topic_model/lda.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gensim 4 | import numpy as np 5 | 6 | from gensim.models import CoherenceModel 7 | from src.topic_model.topic_loader import get_topic_model_folder, get_topic_parent_folder 8 | 9 | # ----------------------------------------------- 10 | # save and load functions for lda topic models 11 | # ----------------------------------------------- 12 | 13 | def train_save_lda_model(corpus, id2word, opt): 14 | ''' 15 | Trains and saves original LDA model from Gensim implementation 16 | :param corpus: preprocessed corpus 17 | :param id2word: id2word from preprocessing step 18 | :param opt: option dictionary 19 | :return: trained topic model 20 | ''' 21 | lda_model = gensim.models.ldamodel.LdaModel(corpus=corpus, 22 | id2word=id2word, 23 | num_topics=opt['num_topics'], 24 | random_state=100, 25 | update_every=1, 26 | chunksize=100, 27 | passes=10, 28 | alpha='auto', 29 | per_word_topics=True) 30 | # save 31 | save_topic_model(lda_model, opt, 'LDA') 32 | return lda_model 33 | 34 | 35 | def train_ldamallet_topic_model(corpus, id2word, opt): 36 | ''' 37 | Trains and saves LDA model from ldamallet implementation (tends to be better) 38 | :param corpus: preprocessed corpus 39 | :param 
id2word: id2word from preprocessing step 40 | :param opt: option dictionary 41 | :return: trained topic model 42 | ''' 43 | # Download File: http://mallet.cs.umass.edu/dist/mallet-2.0.8.zip 44 | mallet_path = os.path.join(opt['datapath'], 'topic_models', 'mallet-2.0.8', 'bin', 'mallet') 45 | prefix = get_topic_model_folder(opt) # saves model here, e.g.: 'data/topic_models/Semeval_15/ldamallet' 46 | parent_folder = get_topic_parent_folder(opt) # e.g.: 'data/topic_models/Semeval_15' 47 | alpha = opt.get('topic_alpha', 50) 48 | if not os.path.exists(parent_folder): 49 | os.mkdir(parent_folder) 50 | os.mkdir(prefix) 51 | lda_model = gensim.models.wrappers.LdaMallet(mallet_path, corpus=corpus, num_topics=opt['num_topics'], 52 | id2word=id2word, prefix=prefix, alpha=alpha) # try to change alpha 53 | # Save model 54 | save_topic_model(lda_model, opt) 55 | return lda_model 56 | 57 | 58 | def save_topic_model(topic_model, opt): 59 | model_path = os.path.join(get_topic_model_folder(opt), 'lda_model') 60 | print('Saving Topic Model to {}...'.format(model_path)) 61 | topic_model.save(model_path) 62 | 63 | 64 | def load_lda_model(opt): 65 | model_path = os.path.join(get_topic_model_folder(opt), 'lda_model') 66 | print('Loading Topic Model from {}...'.format(model_path)) 67 | lda_model = gensim.models.LdaModel.load(model_path) 68 | # update path in case topic model was trained with old code on VM (starting with /data-disk/...) 69 | lda_model.mallet_path = 'data/topic_models/mallet-2.0.8/bin/mallet' 70 | print('Done.') 71 | return lda_model 72 | 73 | # ----------------------------------------------- 74 | # functions to infer topic distribution 75 | # ----------------------------------------------- 76 | 77 | def infer_lda_topics(lda_model, new_corpus): 78 | assert type(new_corpus) == list # of sentences 79 | assert type(new_corpus[0]) == list # of word id/count tuples 80 | assert type(new_corpus[0][0]) == tuple 81 | return lda_model[new_corpus] 82 | 83 | # ----------------------------------------------- 84 | # functions to explore topic distribution 85 | # ----------------------------------------------- 86 | 87 | def extract_topic_from_lda_prediction(dist_over_topic, num_topics): 88 | ''' 89 | Extracts topic vectors from prediction 90 | :param dist_over_topic: nested topic distribution object 91 | :param num_topics: number of topics 92 | :return: topic array with (examples, num_topics) 93 | ''' 94 | # iterate through nested representation and extract topic distribution as single vector with length=num_topics for each document 95 | topic_array = [] 96 | for j, example in enumerate(dist_over_topic): 97 | i = 0 98 | topics_per_doc = [] 99 | for topic_num, prop_topic in example[0]: 100 | # print(i) 101 | # print(topic_num) 102 | while not i == topic_num: 103 | topics_per_doc.append(0) # fill in 'missing' topics with probabilites < threshold as P=0 104 | # print('missing') 105 | i = i + 1 106 | topics_per_doc.append(prop_topic) 107 | i = i + 1 108 | while len(topics_per_doc) < num_topics: 109 | topics_per_doc.append(0) # fill in last 'missing' topics 110 | topic_array.append(np.array(topics_per_doc)) 111 | global_topics = np.array(topic_array) 112 | # sanity check 113 | if not ((len(global_topics.shape) == 2) and (global_topics.shape[1] == num_topics)): 114 | print('Inconsistent topic vector length detected:') 115 | i = 0 116 | for dist, example in zip(global_topics, dist_over_topic): 117 | if len(dist) != num_topics: 118 | print('{}th example with length {}: {}'.format(i, len(dist), dist)) 119 | print('from: 
{}'.format(example[0])) 120 | print('--') 121 | i += 1 122 | return global_topics 123 | 124 | def extract_topic_from_ldamallet_prediction(dist_over_topic): 125 | ''' 126 | Extracts topic vectors from prediction 127 | :param dist_over_topic: nested topic distribution object 128 | :return: topic array with (examples, num_topics) 129 | ''' 130 | global_topics = np.array([np.array([prop_topic for j, (topic_num, prop_topic) in enumerate(example)]) for example in 131 | dist_over_topic]) # doesn't work with missing topics 132 | return global_topics 133 | 134 | def evaluate_lda_topic_model(lda_model, corpus, processed_texts, id2word): 135 | ''' 136 | Performs intrinsic evaluation by computing perplexity and coherence 137 | :param lda_model: 138 | :param corpus: 139 | :param processed_texts: 140 | :param id2word: 141 | :return: dictionary with perplexity and coherence 142 | ''' 143 | try: 144 | model_perplexity = lda_model.log_perplexity(corpus) 145 | print('\nPerplexity: ', model_perplexity) # a measure of how good the model is. lower the better. 146 | results = {'perplexity': model_perplexity} 147 | except AttributeError: 148 | results = {} 149 | coherence_model_lda = CoherenceModel(model=lda_model, texts=processed_texts, dictionary=id2word, coherence='c_v') 150 | coherence_lda = coherence_model_lda.get_coherence() # the higher the better 151 | print('\nCoherence Score: ', coherence_lda) 152 | results['coherence'] = coherence_lda 153 | return results 154 | if __name__ == '__main__': 155 | from src.topic_model.topic_trainer import load_sentences_for_topic_model,lda_preprocess 156 | # Import Dataset 157 | opt = {'dataset': 'Semeval', 'datapath': 'data/', 'num_topics': 10, 'topic_type': 'ldamallet', 158 | 'tasks': ['A','B','C'], 'n_gram_embd': False, 'numerical': False, 159 | 'subsets': ['train_large'], 'cache': True, 'stem':False, 160 | 'max_m':100} 161 | # load input data 162 | data = load_sentences_for_topic_model(opt) 163 | # preprocess input for LDA 164 | corpus, id2word, processed_texts = lda_preprocess(data, id2word=None, delete_stopwords=True, print_steps=True) 165 | lda_model = load_lda_model(opt) 166 | # dist_over_topic = infer_lda_topics(lda_model, corpus) 167 | # global_topics = extract_topic_from_ldamallet_prediction(dist_over_topic) 168 | -------------------------------------------------------------------------------- /src/topic_model/topic_eval.py: -------------------------------------------------------------------------------- 1 | from src.topic_model.topic_trainer import load_sentences_for_topic_model,lda_preprocess,load_topic_model 2 | from src.topic_model.lda import evaluate_lda_topic_model 3 | import os 4 | from src.topic_model.topic_loader import get_topic_root_folder 5 | import argparse 6 | 7 | def evaluate_topic_model(topic_model, corpus, processed_texts, id2word, opt): 8 | if opt['topic_type'] in ['LDA','ldamallet']: 9 | results = evaluate_lda_topic_model(topic_model, corpus, processed_texts, id2word) 10 | elif opt['topic_type'] =='gsdmm': 11 | raise NotImplementedError() 12 | write_topic_model_log(opt, results) 13 | return results 14 | 15 | def write_topic_model_log(opt,results): 16 | log_path = os.path.join(get_topic_root_folder(opt), 'eval_log.txt') 17 | if not os.path.exists(log_path): 18 | with open(log_path, 'a') as outfile: 19 | outfile.writelines('{}\t{}\t{}\t{}\n'.format('num_topics', 'topic_alpha','coherence','perplexity')) # todo: add metric for gsdmm evaluation? 
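    # For illustration (numbers are made up): eval_log.txt accumulates one tab-separated
    # row per evaluated model, with 'N/A' where the model does not expose log_perplexity:
    #   num_topics    topic_alpha    coherence    perplexity
    #   70            10             0.48         N/A
    #   80            1              0.51         N/A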
20 | with open(log_path, 'a') as outfile: 21 | outfile.writelines('{}\t{}\t{}\t{}\n'.format(opt['num_topics'],opt.get('topic_alpha',50),results.get('coherence','N/A'),results.get('perplexity','N/A'))) 22 | print('Wrote eval log: {}'.format(log_path)) 23 | 24 | if __name__ == '__main__': 25 | 26 | # get command line arguments 27 | parser = argparse.ArgumentParser() 28 | parser.register("type", "bool", lambda v: v.lower() == "true") 29 | parser.add_argument("--stem",type="bool",nargs="?",const=True,default=False) 30 | parser.add_argument('-dataset', action="store", dest="dataset", type=str, default='Quora') 31 | parser.add_argument('-topic_type', action="store", dest="topic_type", type=str, default='ldamallet') 32 | FLAGS, unparsed = parser.parse_known_args() 33 | 34 | # Import Dataset 35 | if FLAGS.dataset in ['Quora','MSRP']: 36 | subsets = ['train'] # , 'dev', 'test' 37 | tasks = ['B'] 38 | elif FLAGS.dataset == 'Semeval': 39 | subsets = ['train_large'] # , 'test2016', 'test2017' 40 | tasks = ['A', 'B', 'C'] 41 | opt = {'dataset': FLAGS.dataset, 'datapath': 'data/', 'topic_type': FLAGS.topic_type, 42 | 'tasks': tasks, 'n_gram_embd': False, 'numerical': False, 43 | 'subsets': subsets, 'cache': True, 'stem': FLAGS.stem} 44 | 45 | # load input data 46 | train_data = load_sentences_for_topic_model(opt) 47 | # preprocess input for LDA 48 | stem = opt.get('stem', False) 49 | lemmatize = opt.get('lemmatize', False) 50 | corpus, id2word, processed_texts = lda_preprocess(train_data, id2word=None, delete_stopwords=True, print_steps=True) 51 | for alpha in [0.1]: # ,10,1,0.1 52 | for topics in [t*10 for t in range(1,11)]: 53 | # load topic model, evaluate and save in log 54 | opt['num_topics'] = topics 55 | opt['topic_alpha'] = alpha 56 | topic_model = load_topic_model(opt) 57 | results = evaluate_topic_model(topic_model, corpus, processed_texts, topic_model.id2word, opt) 58 | -------------------------------------------------------------------------------- /src/topic_model/topic_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | # ----------------------------------------------- 5 | # get functions for consistent model dirs 6 | # ----------------------------------------------- 7 | 8 | def get_alpha_str(opt): 9 | return '_alpha{}'.format(opt.get('topic_alpha',50)) 10 | def get_topic_root_folder(opt): 11 | preprocessing = 'basic' 12 | if opt['topic_type']=='gsdmm': 13 | return os.path.join(opt['datapath'], 'topic_models', preprocessing+'_gsdmm', '') 14 | else: 15 | return os.path.join(opt['datapath'], 'topic_models', preprocessing,'') 16 | def get_topic_parent_folder(opt): 17 | alpha_str = get_alpha_str(opt) 18 | return os.path.join(get_topic_root_folder(opt),opt['dataset']+alpha_str+'_'+str(opt['num_topics']),'') 19 | def get_topic_model_folder(opt): 20 | return os.path.join(get_topic_parent_folder(opt), opt['topic_type'],'') 21 | def get_topic_pred_folder(opt): 22 | return os.path.join(get_topic_model_folder(opt), 'predictions','') 23 | 24 | # ----------------------------------------------- 25 | # prediction loaders 26 | # ----------------------------------------------- 27 | 28 | def load_document_topics(opt,recover_topic_peaks,max_m=None): 29 | ''' 30 | Loads inferred topic distribution for each sentence in dataset. Dependent on data subset (e.g. Semeval A vs. B). 
31 | :param opt: option dict 32 | :return: document topics in list corresponding to subsets {'D_T1':T1,'D_T2':T2} 33 | ''' 34 | # set model paths 35 | filepaths1 = [] 36 | filepaths2 = [] 37 | topic_model_folder = get_topic_pred_folder(opt) 38 | task = opt.get('tasks')[0] 39 | subsets = opt.get('subsets') 40 | for s in subsets: # train, dev, test 41 | filepaths1.append(os.path.join(topic_model_folder,task+'_'+s+'_1.npy')) 42 | filepaths2.append(os.path.join(topic_model_folder,task+'_'+s+'_2.npy')) 43 | # load 44 | T1 = [np.load(f) for f in filepaths1] 45 | T2 = [np.load(f) for f in filepaths2] 46 | 47 | if recover_topic_peaks: 48 | for split in range(len(T1)): 49 | for line in range(len(T1[split])): 50 | T1[split][line] = unflatten_topic(T1[split][line]) 51 | T2[split][line] = unflatten_topic(T2[split][line]) 52 | 53 | # reduce number of examples if necessary 54 | # max_examples = opt.get('max_m',None) 55 | if not max_m is None: 56 | T1 = [t[:max_m] for t in T1] 57 | T2 = [t[:max_m] for t in T2] 58 | return {'D_T1':T1,'D_T2':T2} 59 | 60 | def unflatten_topic(topic_vector): 61 | # unflatten topic distribution 62 | min_val = topic_vector.min() 63 | for j, topic in enumerate(topic_vector): 64 | if topic == min_val: 65 | topic_vector[j] = 0 66 | return topic_vector 67 | 68 | def load_word_topics(opt, add_unk = True,recover_topic_peaks=False): 69 | ''' 70 | Reads word topic vector and dictionary from file 71 | :param opt: option dictionary containing settings for topic model 72 | :return: word_topic_dict,id_word_dict 73 | ''' 74 | complete_word_topic_dict = {} 75 | id_word_dict = {} 76 | word_id_dict = {} 77 | count = 0 78 | topic_matrix = [] 79 | num_topics = opt.get('num_topics',None) 80 | unk_topic = opt.get('unk_topic','uniform') 81 | word_topic_file = os.path.join(get_topic_pred_folder(opt),'word_topics.log') 82 | # todo:train topic model 83 | print("Reading word_topic vector from {}".format(word_topic_file)) 84 | if add_unk: 85 | # add line for UNK word topics 86 | word = '' 87 | if unk_topic=='zero': 88 | topic_vector = np.array([0.0]*num_topics) 89 | elif unk_topic=='uniform': 90 | assert not recover_topic_peaks, "Do not use unk_topic='uniform' and 'unflat_topics'=True' together. As it will result in flattened non-topics, but unflattened topics." 
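# Worked example (hypothetical numbers) of the interaction guarded by the assert above:
# unflatten_topic() zeroes every entry that equals the vector minimum, e.g.
#     [0.05, 0.05, 0.85, 0.05]  ->  [0.0, 0.0, 0.85, 0.0]
# The uniform UNK vector built below never passes through unflatten_topic, so with
# recover_topic_peaks=True real words would get peaked vectors while UNK stayed flat.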
91 | topic_vector = np.array([1/num_topics] * num_topics) 92 | else: 93 | raise ValueError 94 | wordid = 0 95 | id_word_dict[wordid] = word 96 | word_id_dict[word] = wordid 97 | complete_word_topic_dict[word] = topic_vector 98 | topic_matrix.append(topic_vector) 99 | # read other word topics 100 | with open(word_topic_file, 'r', encoding='utf-8') as infile: 101 | for line in infile: 102 | count += 1 103 | if count > 2: 104 | ldata = line.rstrip().split(' ') # \xa in string caused trouble 105 | if add_unk: 106 | wordid = int(ldata[0])+1 107 | else: 108 | wordid = int(ldata[0]) 109 | word = ldata[1] 110 | id_word_dict[wordid] = word 111 | word_id_dict[word] = wordid 112 | # print(ldata[2:]) 113 | topic_vector = np.array([float(s.replace('[','').replace(']','')) for s in ldata[2:]]) 114 | assert len(topic_vector)==num_topics 115 | if recover_topic_peaks: 116 | topic_vector = unflatten_topic(topic_vector) 117 | complete_word_topic_dict[word] = topic_vector 118 | topic_matrix.append(topic_vector) 119 | topic_matrix = np.array(topic_matrix) 120 | print('word topic embedding dim: {}'.format(topic_matrix.shape)) 121 | assert len(topic_matrix.shape)==2 122 | return {'complete_topic_dict':complete_word_topic_dict,'topic_dict':id_word_dict,'word_id_dict':word_id_dict,'topic_matrix':topic_matrix} 123 | 124 | if __name__ == '__main__': 125 | 126 | # Example usage 127 | 128 | # load topic predictions for dataset 129 | opt = {'dataset': 'Semeval', 'datapath': 'data/', 130 | 'tasks': ['A'], 131 | 'subsets': ['train_large','test2016','test2017'], 132 | 'model': 'basic_cnn', 'load_ids':True, 133 | 'num_topics':20, 'topic_type':'ldamallet','unk_topic':'zero'} 134 | 135 | doc_topics = load_document_topics(opt) 136 | word_topics = load_word_topics(opt) 137 | 138 | word_topics['topic_dict'] 139 | word_topics['topic_matrix'] 140 | 141 | word_topics['topic_dict'][0] 142 | word_topics['topic_matrix'][0] 143 | 144 | word_topics['complete_topic_dict']['trends'] 145 | -------------------------------------------------------------------------------- /src/topic_model/topic_mass_trainer.py: -------------------------------------------------------------------------------- 1 | from src.topic_model.topic_trainer import lda_preprocess,load_sentences_for_topic_model,train_topic_model 2 | from src.topic_model.topic_predictor import infer_and_write_word_topics,infer_and_write_document_topics 3 | from src.topic_model.topic_loader import get_topic_root_folder,get_alpha_str 4 | import os 5 | from os import path 6 | 7 | import argparse 8 | 9 | # trains a number of topic models with different number of topics for specific dataset 10 | # usage example: python src/topic_model/topic_mass_trainer.py -dataset Quora -min 10 -max 100 -alpha 50 11 | 12 | 13 | def check_existing_topic_models(opt): 14 | topic_folder = get_topic_root_folder(opt) 15 | if path.exists(topic_folder): 16 | print('Looking for existing topic models for {} in {}'.format(opt['dataset'],topic_folder)) 17 | prefix = opt['dataset']+get_alpha_str(opt)+'_' 18 | existing = [int(f.split('_')[-1]) for f in os.listdir(topic_folder) if f.startswith(prefix)] 19 | print('Found {}.'.format(existing)) 20 | return existing 21 | else: 22 | return [] 23 | 24 | # read logs 25 | def train_many_topic_models(opt,minimum_topics,maximum_topics,existing_models): 26 | # load input data 27 | data = load_sentences_for_topic_model(opt) 28 | # preprocess input for LDA 29 | corpus, id2word, processed_texts = lda_preprocess(data, id2word=None, delete_stopwords=True, print_steps=False) 30 | 31 | for n 
in range(minimum_topics, maximum_topics, 10): 32 | if n not in existing_models: 33 | print('===') 34 | print('Training topic model with {} topics '.format(n)) 35 | # train model 36 | opt['num_topics'] = n 37 | topic_model = train_topic_model(corpus, id2word, processed_texts, opt) 38 | # predict 39 | for t in tasks: 40 | opt['tasks'] = [t] 41 | opt['subsets'] = subsets 42 | infer_and_write_document_topics(opt, topic_model, id2word) # different for each subset and task 43 | infer_and_write_word_topics(opt, topic_model, id2word) # same for each subset and task 44 | print('===') 45 | 46 | if __name__ == '__main__': 47 | 48 | parser = argparse.ArgumentParser() 49 | parser.register("type", "bool", lambda v: v.lower() == "true") 50 | parser.add_argument( 51 | "--debug", 52 | type="bool", 53 | nargs="?", 54 | const=True, 55 | default=False, 56 | help="Use to run with tiny portion of actual data. ") 57 | parser.add_argument('-max', action="store", dest="max", type=int, default=None) 58 | parser.add_argument('-min', action="store", dest="min", type=int, default=None) 59 | parser.add_argument('-alpha', action="store", dest="alpha", type=float, default=None) 60 | parser.add_argument('-dataset', action="store", dest="dataset", type=str, default='Quora') 61 | parser.add_argument('-topic_type', action="store", dest="topic_type", type=str, default='ldamallet') 62 | 63 | FLAGS, unparsed = parser.parse_known_args() 64 | 65 | print('Dataset: {}'.format(FLAGS.dataset)) 66 | print('Minimum: {}'.format(FLAGS.min)) 67 | print('Maximum: {}'.format(FLAGS.max)) 68 | if (FLAGS.alpha).is_integer(): 69 | FLAGS.alpha = int(FLAGS.alpha) 70 | print('Alpha: {}'.format(FLAGS.alpha)) 71 | 72 | datasets = [FLAGS.dataset] 73 | 74 | for d in datasets: 75 | print('Start topic model training for {}'.format(d)) 76 | # construct opt 77 | if d == 'Semeval': 78 | tasks = ['A', 'B', 'C'] 79 | subsets = ['train_large', 'test2016', 'test2017'] 80 | else: 81 | tasks = ['B'] 82 | subsets = ['train', 'dev', 'test'] 83 | 84 | opt = {'dataset': d, 'datapath': 'data/', 'num_topics': None, 'topic_type': FLAGS.topic_type, 85 | 'subsets': [subsets[0]], 'cache': True} 86 | if not FLAGS.alpha is None: 87 | opt['topic_alpha'] = FLAGS.alpha 88 | if FLAGS.debug: 89 | opt['max_m']=1000 90 | print('--- DEBUG MODE ---') 91 | 92 | existing = check_existing_topic_models(opt) 93 | minimum_topics = FLAGS.min 94 | maximum_topics = FLAGS.max+10 95 | 96 | if len([i for i in range(minimum_topics, maximum_topics, 10) if not i in existing])>0: 97 | print('Preparing to train the following topic models:') 98 | print([i for i in range(minimum_topics, maximum_topics, 10) if not i in existing]) 99 | train_many_topic_models(opt,minimum_topics,maximum_topics,existing) 100 | else: 101 | print('All topic models already exist for specified number of topics.') -------------------------------------------------------------------------------- /src/topic_model/topic_predictor.py: -------------------------------------------------------------------------------- 1 | # Enable logging for gensim - optional 2 | import logging 3 | import os 4 | import re 5 | import argparse 6 | 7 | import numpy as np 8 | 9 | from src.loaders.load_data import load_data 10 | from src.topic_model.topic_loader import get_topic_model_folder, get_topic_pred_folder 11 | from src.topic_model.topic_trainer import infer_topic_dist,lda_preprocess,load_topic_model 12 | 13 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.ERROR) 14 | 15 | # NLTK Stop words 16 | from 
nltk.corpus import stopwords 17 | stop_words = stopwords.words('english') 18 | 19 | ''' 20 | High level functions to apply based on topic model functionality from src.preprocessing.topic_models to infer, save and 21 | load topic distribution for data sets 22 | ''' 23 | 24 | # ----------------------------------------------- 25 | # infer, save, load topic distribution functions for dataset 26 | # ----------------------------------------------- 27 | 28 | # -------- document topic model -------- 29 | 30 | def infer_and_write_document_topics(opt, topic_model=None, id2word=None): 31 | ''' 32 | Infer global topic distribution for all documents in data splits (e.g. train, dev, test) mentioned in opt['subsets'] 33 | and save as for both documents separately as numpy arrays. 34 | :param opt: option dictionary 35 | :return: Nothing 36 | ''' 37 | subsets = opt['subsets'] 38 | # try to load topic model 39 | if topic_model is None or id2word is None: 40 | topic_model = load_topic_model(opt) 41 | id2word = topic_model.id2word 42 | 43 | assert len(opt['tasks'])==1 44 | task = opt.get('tasks')[0] 45 | topic_dist_list = [] 46 | 47 | for subset in subsets: # train, dev, test 48 | # load data split 49 | opt['subsets']=[subset] 50 | data_dict = load_data(opt, numerical=False) 51 | 52 | # use tokenized sentences, not raw strings (to prevent topic inference tokenisation bug which resulted in mapping the whole sentence to UNK -->same doc topic for every sentence) 53 | sent_1 = data_dict['T1'][0] # undo nested list since we are only dealing with one subset at a time 54 | sent_2 = data_dict['T2'][0] 55 | 56 | # preprocess and infer topics 57 | new_corpus_1, _, new_processed_texts_1 = lda_preprocess(sent_1, id2word=id2word, delete_stopwords=True, 58 | print_steps=False) 59 | topic_dist_1 = infer_topic_dist(new_corpus_1,new_processed_texts_1, topic_model, opt['topic_type']) 60 | 61 | # preprocess and infer topics 62 | new_corpus_2, _, new_processed_texts_2 = lda_preprocess(sent_2, id2word=id2word, delete_stopwords=True, 63 | print_steps=False) 64 | topic_dist_2 = infer_topic_dist(new_corpus_2,new_processed_texts_2, topic_model, opt['topic_type']) 65 | 66 | # sanity check 67 | assert(len(sent_1)==len(sent_2)==topic_dist_1.shape[0]==topic_dist_2.shape[0]) # same number of examples 68 | assert(topic_dist_1.shape[1]==topic_dist_2.shape[1]) # same number of topics 69 | 70 | # set model path 71 | topic_model_folder = get_topic_pred_folder(opt) 72 | # make folder if not existing 73 | if not os.path.exists(topic_model_folder): 74 | os.mkdir(topic_model_folder) 75 | topic_dist_path = os.path.join(topic_model_folder,task+'_'+subset) 76 | 77 | # save as separate numpy arrays 78 | np.save(topic_dist_path+'_1', topic_dist_1) 79 | np.save(topic_dist_path+'_2', topic_dist_2) 80 | topic_dist_list.extend([topic_dist_1,topic_dist_2]) 81 | return topic_dist_list 82 | 83 | # -------- word topic model -------- 84 | 85 | def infer_and_write_word_topics(opt, topic_model=None, id2word=None, max_vocab=None): 86 | ''' 87 | Loads a topic model and writes topic predictions for each word in dictionary to file. Independent of data subset (e.g. Semeval A = B) due to shared dictionary. 
88 | :param opt: option dictionary containing settings for topic model 89 | :return: void 90 | ''' 91 | # try to load topic model 92 | if topic_model is None or id2word is None: 93 | topic_model = load_topic_model(opt) 94 | print('Infering and writing word topic distribution ...') 95 | if max_vocab is None: 96 | vocab_len = len(topic_model.id2word.keys()) 97 | if not id2word is None: 98 | assert len(topic_model.id2word.keys()) == len(id2word.keys()) 99 | else: 100 | vocab_len = max_vocab 101 | # create one word documents in bag of words format for topic model 102 | print(vocab_len) 103 | new_corpus =[[(i, 1)] for i in range(vocab_len)] 104 | # use [[word1],[word2],...] to prevent gsdmm from splitting them up into individual characters (e.g. ['w', 'o', 'r', 'd', '1']) 105 | new_processed_texts = [[topic_model.id2word[i]] for i in range(vocab_len)] 106 | print(new_processed_texts[:10]) 107 | # get topic distribution for each word in topic model dictionary 108 | # dist_over_topic = lda_model[new_corpus] 109 | # word_topics = extract_topics_from_prediction(dist_over_topic, opt['topic_type'], lda_model.num_topics) 110 | 111 | word_topics = infer_topic_dist(new_corpus,new_processed_texts, topic_model, opt['topic_type']) 112 | # if opt['topic_type'] in ['LDA','ldamallet']: 113 | # dist_over_topic = infer_lda_topics(topic_model, new_corpus) 114 | # # extract topic vectors from prediction 115 | # global_topics = extract_topics_from_prediction(dist_over_topic, type, topic_model.num_topics) 116 | # elif opt['topic_type'] == 'gsdmm': 117 | # global_topics = infer_gsdmm_topics(topic_model, new_processed_texts) 118 | 119 | topic_model_folder = get_topic_pred_folder(opt) 120 | # make folder if not existing 121 | if not os.path.exists(topic_model_folder): 122 | os.mkdir(topic_model_folder) 123 | word_topic_file = os.path.join(topic_model_folder,'word_topics.log') 124 | with open(word_topic_file, 'w', encoding='utf-8') as outfile: 125 | # write to file 126 | model_path = os.path.join(get_topic_model_folder(opt), 'lda_model') 127 | outfile.writelines('Loading Topic Model from {}\n'.format(model_path)) 128 | outfile.writelines('{}\n'.format(vocab_len)) 129 | for i in range(vocab_len): 130 | # for i,(k,w) in enumerate(topic_model.id2word.items()): 131 | w = topic_model.id2word.id2token[i] 132 | if max_vocab is None or i 0: 148 | # multiple sentences 149 | human_format = [[(id2word[id], freq) for id, freq in cp] for cp in corpus] 150 | else: 151 | # one sentence 152 | human_format = [(id2word[id], freq) for id, freq in corpus] 153 | return human_format 154 | 155 | 156 | def load_sentences_for_topic_model(opt): 157 | ''' 158 | Loads dataset 159 | :param opt: 160 | :return: 161 | ''' 162 | # for s in opt['subsets']: 163 | # assert ('train' in s) # only use for training data 164 | # load dataset 165 | from src.loaders.load_data import load_data 166 | data_dict = load_data(opt, numerical=False) 167 | R1 = data_dict['T1'] 168 | R2 = data_dict['T2'] 169 | 170 | # select sentences from dataset (avoid duplication in Semeval) 171 | if opt['dataset'] == 'Semeval': 172 | assert opt['tasks'] == ['A', 'B', 'C'] 173 | # combine data from all subtasks (A,B,C) 174 | Asent_1 = [sent for i, sent in enumerate(R1[0]) if i % 10 == 0] # only once 175 | Asent_2 = R2[0] 176 | Bsent_1 = [sent for i, sent in enumerate(R1[1]) if i % 10 == 0] # only once 177 | Bsent_2 = R2[1] 178 | # Csent_1 = [sent for i,sent in enumerate(dataset[12]) if i%100==0] # same as Bsent_! 
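# Note: in the Semeval cQA data the first sentence of each pair is repeated for all of its
# candidate pairs (10 per original question in subtasks A/B, 100 in subtask C), so the
# i % 10 == 0 / i % 100 == 0 filters above keep each original question only once in the
# topic-model training corpus.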
179 |         Csent_2 = R2[2]
180 |         sentences = Asent_1 + Asent_2 + Bsent_1 + Bsent_2 + Csent_2
181 |         print(len(sentences))
182 |     else:
183 |         sentences = [s for s in R1[0]] + [s for s in R2[0]]
184 |     return sentences
185 | 
186 | 
187 | def train_one_topic_model_on_data(opt):
188 |     '''
189 |     Reads data for the topic model, preprocesses the input and trains a single topic model according to the opt specification (the model is saved by train_topic_model).
190 |     :param opt: option dictionary
191 |     :return: trained topic model
192 |     '''
193 |     # load input data
194 |     data = load_sentences_for_topic_model(opt)
195 |     # preprocess input for LDA
196 |     corpus, id2word, processed_texts = lda_preprocess(data, id2word=None, delete_stopwords=True,
197 |                                                       print_steps=False)
198 |     # train model
199 |     topic_model = train_topic_model(corpus, id2word, processed_texts, opt) # saves model automatically
200 |     return topic_model
201 | 
202 | def load_topic_model(opt):
203 |     if opt['topic_type'] in ['LDA', 'ldamallet']:
204 |         topic_model = load_lda_model(opt)
205 |     elif opt['topic_type'] == 'gsdmm':
206 |         topic_model = load_gsdmm_model(opt)
207 |     return topic_model
208 | 
209 | def read_topic_model_log(opt):
210 |     '''
211 | 
212 |     :return: settings for best topic model ('num_topics','topic_type')
213 |     '''
214 |     raise NotImplementedError()
215 |     log_path = os.path.join(get_topic_root_folder(opt), 'eval_log.txt')
216 | 
217 | 
218 | if __name__ == '__main__':
219 |     # Import Dataset
220 |     opt = {'dataset': 'MSRP', 'datapath': 'data/','topic_type': 'gsdmm',
221 |            'tasks': ['B',], 'n_gram_embd': False, 'numerical': False,
222 |            'num_topics': 10,'topic_alpha':0.1,
223 |            'subsets': ['train'], 'cache': True}
224 |     # load input data
225 |     train_data = load_sentences_for_topic_model(opt)
226 |     # preprocess input for LDA
227 |     corpus, id2word, processed_texts = lda_preprocess(train_data, id2word=None, delete_stopwords=True, print_steps=True)
228 |     # # train model
229 |     # topic_model = train_topic_model(corpus, id2word, processed_texts, opt)
230 |     topic_model = load_topic_model(opt)
231 |     topic_model.id2word
232 | 
233 | 
--------------------------------------------------------------------------------
/src/topic_model/topic_visualiser.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import os
3 | 
4 | import pandas as pd
5 | import pyLDAvis
6 | import pyLDAvis.gensim # don't skip this
7 | 
8 | from src.topic_model.topic_loader import get_topic_model_folder
9 | import matplotlib.pyplot as plt
10 | 
11 | 
12 | def visualise_topics(opt,lda_model=None, corpus=None, id2word=None):
13 |     '''
14 |     Generic topic visualisation function that calls specific visualisation based on topic model type
15 |     :param opt:
16 |     :param lda_model:
17 |     :param corpus:
18 |     :param id2word:
19 |     :return:
20 |     '''
21 |     print('Visualising {} topic model...'.format(opt['topic_type']))
22 |     if opt['topic_type'] == 'ldamallet':
23 |         return visualise_ldamallet_topics(opt['dataset'],opt.get('topic_alpha',50),opt['num_topics'])
24 |     elif opt['topic_type'] == 'LDA':
25 |         return visualise_lda_topics(lda_model, corpus, id2word)
26 |     else:
27 |         raise ValueError('Topic model type not supported. 
Choose "ldamallet" or "LDA".') 28 | 29 | 30 | def visualise_lda_topics(lda_model, corpus, id2word): 31 | ''' 32 | Visualizes the topics for Gensim's LDA implementation 33 | :param lda_model: 34 | :param corpus: 35 | :param id2word: 36 | :return: visualisation 37 | ''' 38 | pyLDAvis.enable_notebook() 39 | vis = pyLDAvis.gensim.prepare(lda_model, corpus, id2word) 40 | return vis 41 | 42 | def visualise_ldamallet_topics(dataset,alpha,num_topic): 43 | ''' 44 | Extracts relevant information form ldamallet's LDA model and visualizes the topics with Gensim's LDA visualisation 45 | :return: visualisation 46 | ''' 47 | ldamallet_dir = 'data/topic_models/basic/{}_alpha{}_{}/ldamallet'.format(dataset,alpha,num_topic) # e.g. Semeval_alpha50_20 48 | convertedLDAmallet = convertLDAmallet(dataDir=ldamallet_dir, filename='state.mallet.gz') 49 | pyLDAvis.enable_notebook() 50 | vis = pyLDAvis.prepare(**convertedLDAmallet) 51 | # pyLDAvis.display(vis) 52 | return vis 53 | 54 | # from http://jeriwieringa.com/2018/07/17/pyLDAviz-and-Mallet/#comment-4018495276 55 | def convertLDAmallet(dataDir='data/topic_models/SemevalA/', filename='state.mallet.gz'): 56 | def extract_params(statefile): 57 | """Extract the alpha and beta values from the statefile. 58 | 59 | Args: 60 | statefile (str): Path to statefile produced by MALLET. 61 | Returns: 62 | tuple: alpha (list), beta 63 | """ 64 | with gzip.open(statefile, 'r') as state: 65 | params = [x.decode('utf8').strip() for x in state.readlines()[1:3]] 66 | return (list(params[0].split(":")[1].split(" ")), float(params[1].split(":")[1])) 67 | 68 | def state_to_df(statefile): 69 | """Transform state file into pandas dataframe. 70 | The MALLET statefile is tab-separated, and the first two rows contain the alpha and beta hypterparamters. 71 | 72 | Args: 73 | statefile (str): Path to statefile produced by MALLET. 74 | Returns: 75 | datframe: topic assignment for each token in each document of the model 76 | """ 77 | return pd.read_csv(statefile, 78 | compression='gzip', 79 | sep=' ', 80 | skiprows=[1, 2] 81 | ) 82 | 83 | params = extract_params(os.path.join(dataDir, filename)) 84 | alpha = [float(x) for x in params[0][1:]] 85 | beta = params[1] 86 | # print("{}, {}".format(alpha, beta)) 87 | 88 | df = state_to_df(os.path.join(dataDir, filename)) 89 | df['type'] = df.type.astype(str) 90 | # df[:10] 91 | 92 | # Get document lengths from statefile 93 | docs = df.groupby('#doc')['type'].count().reset_index(name='doc_length') 94 | # docs[:10] 95 | 96 | # Get vocab and term frequencies from statefile 97 | vocab = df['type'].value_counts().reset_index() 98 | vocab.columns = ['type', 'term_freq'] 99 | vocab = vocab.sort_values(by='type', ascending=True) 100 | # vocab[:10] 101 | 102 | # Topic-term matrix from state file 103 | # https://ldavis.cpsievert.me/reviews/reviews.html 104 | import sklearn.preprocessing 105 | def pivot_and_smooth(df, smooth_value, rows_variable, cols_variable, values_variable): 106 | """ 107 | Turns the pandas dataframe into a data matrix. 108 | Args: 109 | df (dataframe): aggregated dataframe 110 | smooth_value (float): value to add to the matrix to account for the priors 111 | rows_variable (str): name of dataframe column to use as the rows in the matrix 112 | cols_variable (str): name of dataframe column to use as the columns in the matrix 113 | values_variable(str): name of the dataframe column to use as the values in the matrix 114 | Returns: 115 | dataframe: pandas matrix that has been normalized on the rows. 
116 | """ 117 | matrix = df.pivot(index=rows_variable, columns=cols_variable, values=values_variable).fillna(value=0) 118 | matrix = matrix.values + smooth_value 119 | normed = sklearn.preprocessing.normalize(matrix, norm='l1', axis=1) 120 | return pd.DataFrame(normed) 121 | 122 | phi_df = df.groupby(['topic', 'type'])['type'].count().reset_index(name='token_count') 123 | phi_df = phi_df.sort_values(by='type', ascending=True) 124 | # phi_df[:10] 125 | phi = pivot_and_smooth(phi_df, beta, 'topic', 'type', 'token_count') 126 | # phi[:10] 127 | theta_df = df.groupby(['#doc', 'topic'])['topic'].count().reset_index(name='topic_count') 128 | # theta_df[:10] 129 | theta = pivot_and_smooth(theta_df, alpha, '#doc', 'topic', 'topic_count') 130 | data = {'topic_term_dists': phi, 131 | 'doc_topic_dists': theta, 132 | 'doc_lengths': list(docs['doc_length']), 133 | 'vocab': list(vocab['type']), 134 | 'term_frequency': list(vocab['term_freq']) 135 | } 136 | return data 137 | 138 | # Visualisation for topic prediction 139 | def read_topic_key_table(opt): 140 | keyfile = get_topic_model_folder(opt) + 'topickeys.txt' 141 | topic_table = pd.read_csv(keyfile, sep='\t', header=None, names=['topic','ratio','keywords']) 142 | return reformat_topic_key_table(topic_table) 143 | 144 | def reformat_topic_key_table(topic_table): 145 | topic_table['keywords'] = ['T{}: {}'.format(i, k) for i, k in zip(topic_table['topic'], topic_table['keywords'])] 146 | return topic_table 147 | 148 | def visualise_topic_pred(sentence,topic_vector,topic_table,figsize=None): 149 | ''' 150 | Produces horizontal bar plot for predicted topics 151 | :param sentence: str 152 | :param topic_vector: vector with length=num_topics 153 | :param topic_table: table with topic key words 154 | :return: plot 155 | ''' 156 | print(sentence) 157 | print(topic_vector) 158 | topic_table['pred'] = topic_vector 159 | return topic_table.plot.barh('keywords','pred',figsize=figsize) 160 | 161 | def shorten_topic_keywords(topic_table,n_keywords=None): 162 | topic_keywords = [t for t in topic_table['keywords']] 163 | if not n_keywords is None: 164 | topic_keywords = [' '.join(t.split(' ')[:n_keywords + 1]) for t in topic_keywords] 165 | return topic_keywords 166 | 167 | def visualise_doc_topic_pair_pred(sentence1, sentence2, topic_vector1,topic_vector2,topic_table,figsize=None,print_vectors=False,n_keywords=None,alpha=None): 168 | ''' 169 | Produces horizontal bar plot for predicted topics 170 | :param sentence: str 171 | :param topic_vector: vector with length=num_topics 172 | :param topic_table: table with topic key words 173 | :return: plot 174 | ''' 175 | if print_vectors: 176 | print('doc1: {} {}'.format(sentence1,topic_vector1)) 177 | print('doc2: {} {}'.format(sentence2,topic_vector2)) 178 | else: 179 | print('doc1: {}'.format(sentence1)) 180 | print('doc2: {}'.format(sentence2 )) 181 | plt.figure(figsize=figsize) 182 | topic_keywords = shorten_topic_keywords(topic_table,n_keywords) 183 | title = 'doc1: {}\ndoc2: {}'.format(sentence1, sentence2) 184 | if not alpha is None: 185 | title = '{}\nalpha: {}'.format(title,alpha) 186 | plt.title(title) 187 | plt.barh(topic_keywords, topic_vector1, alpha=0.1, label='doc1', color='r') 188 | plt.barh(topic_keywords, topic_vector2, alpha=0.1, label='doc2', color='b') 189 | plt.legend(loc='best') 190 | plt.show() 191 | 192 | def visualise_word_topic_pair_pred(sentence1, sentence2, topic_vector1,topic_vector2,topic_table,figsize=None,n_keywords=None,alpha=None): 193 | ''' 194 | Produces horizontal bar plot for 
predicted topics 195 | :param sentence: str 196 | :param topic_vector: vector with length=num_topics 197 | :param topic_table: table with topic key words 198 | :return: plot 199 | ''' 200 | plt.figure(figsize=figsize) 201 | print('doc1: {}'.format(sentence1)) 202 | print('doc2: {}'.format(sentence2)) 203 | topic_keywords = shorten_topic_keywords(topic_table,n_keywords) 204 | title = 'doc1: {}\ndoc2: {}'.format(sentence1, sentence2) 205 | if not alpha is None: 206 | title = '{}\nalpha: {}'.format(title,alpha) 207 | plt.title(title) 208 | for i in range(len(topic_vector1)): 209 | if i > 0: 210 | label1 = None 211 | label2 = None 212 | else: 213 | label1 = 'doc1' 214 | label2 = 'doc2' 215 | plt.barh(topic_keywords, topic_vector1[i], alpha=0.1, label=label1, color='r') 216 | plt.barh(topic_keywords, topic_vector2[i], alpha=0.1, label=label2, color='b') 217 | plt.legend(loc='best') 218 | plt.show() 219 | 220 | if __name__ == '__main__': 221 | 222 | from src.loaders.load_data import load_data 223 | 224 | opt = {'dataset': 'Quora', 'datapath': 'data/', 225 | 'tasks': ['B'], 226 | 'subsets': ['train', 227 | 'dev', 228 | 'test'], 229 | 'max_m':100, 230 | 'num_topics': 50, 'topic_type': 'ldamallet','topic':'word+doc', 231 | 'topic_alpha':10#,'unk_topic':'zero',#'unflat_topics':True #,'topic_update':True 232 | } 233 | data_dict = load_data(opt) 234 | 235 | 236 | sentence1 = data_dict['R1'][0][0] 237 | sentence2 = data_dict['R2'][0][0] 238 | topic_vector1 = data_dict['D_T1'][0][0] 239 | topic_vector2 = data_dict['D_T2'][0][0] 240 | topic_table = data_dict['topic_keys'] 241 | visualise_doc_topic_pair_pred(sentence1, sentence2, topic_vector1, topic_vector2, topic_table, figsize=(5,10),n_keywords=None, alpha=opt['topic_alpha']) -------------------------------------------------------------------------------- /tBERT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuningxi/tBERT/604e3f728629c5dd0fe1f201464440edbaf82e79/tBERT.jpg --------------------------------------------------------------------------------
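Taken together, the files under `src/topic_model/` form a small pipeline: train an LDA/Mallet (or GSDMM) model on the deduplicated sentence corpus, write per-document and per-word topic predictions under `data/topic_models/`, and later load those predictions as tBERT inputs or for visualisation. The sketch below strings these steps together using only functions defined above; the concrete `opt` values (MSRP, 80 topics, alpha 1, ldamallet) are illustrative choices mirroring the folder names in the README rather than settings required by the code, and it assumes Mallet 2.0.8 has been unzipped to `data/topic_models/mallet-2.0.8/`.
```
from src.topic_model.topic_trainer import load_sentences_for_topic_model, lda_preprocess, train_topic_model
from src.topic_model.topic_predictor import infer_and_write_document_topics, infer_and_write_word_topics
from src.topic_model.topic_loader import load_document_topics, load_word_topics

# option dictionary in the style of the __main__ blocks above (illustrative values)
opt = {'dataset': 'MSRP', 'datapath': 'data/', 'topic_type': 'ldamallet',
       'tasks': ['B'], 'subsets': ['train'], 'cache': True,
       'n_gram_embd': False, 'numerical': False,
       'num_topics': 80, 'topic_alpha': 1, 'unk_topic': 'uniform'}

# 1) train a topic model on the deduplicated sentence corpus (saved automatically)
sentences = load_sentences_for_topic_model(opt)
corpus, id2word, processed_texts = lda_preprocess(sentences, id2word=None,
                                                  delete_stopwords=True, print_steps=False)
topic_model = train_topic_model(corpus, id2word, processed_texts, opt)

# 2) infer and save document- and word-level topic distributions for all splits
opt['subsets'] = ['train', 'dev', 'test']
infer_and_write_document_topics(opt, topic_model, id2word)
infer_and_write_word_topics(opt, topic_model, id2word)

# 3) later, load the saved predictions for model training or visualisation
# (infer_and_write_document_topics overwrites opt['subsets'] while looping, so reset it first)
opt['subsets'] = ['train', 'dev', 'test']
doc_topics = load_document_topics(opt, recover_topic_peaks=False)
word_topics = load_word_topics(opt, add_unk=True)
```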