├── .gitignore ├── LICENSE.txt ├── README.md ├── cmdlines ├── airlines_ae.sh ├── airlines_ae_mvc.sh ├── airlines_mvc.sh ├── airlines_none.sh ├── airlines_qt.sh ├── airlines_qt_mvc.sh ├── ubuntu_ae.sh ├── ubuntu_ae_mvc.sh ├── ubuntu_mvc.sh ├── ubuntu_none.sh ├── ubuntu_qt.sh └── ubuntu_qt_mvc.sh ├── data ├── LICENSE ├── README.md ├── airlines_500onlyb.csv.bz2 ├── airlines_processed.csv.bz2 ├── airlines_raw.csv.bz2 ├── askubuntu_processed.csv.bz2 └── askubuntu_raw.csv.bz2 ├── datasets ├── askubuntu_preprocess.py ├── labeled_unlabeled_merger.py ├── readme.md ├── twitter_airlines_raw_merger2.py └── twitter_dataset_preprocess.py ├── images ├── avkmeans_graph.png └── example_dialogs.png ├── labeled_unlabeled_merger.py ├── metrics ├── __init__.py └── cluster_metrics.py ├── model ├── __init__.py ├── decoder.py ├── encoder.py ├── multiview_encoders.py └── utils.py ├── preprocessing ├── __init__.py ├── askubuntu.py └── twitter_airlines.py ├── pretrain.py ├── proc_data.py ├── requirements.txt ├── run_mvsc.py ├── samplers.py ├── setup.cfg ├── tests ├── test_multiview_encoders.py └── test_samplers.py ├── train.py ├── train_pca.py └── train_qt.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 ASAPP Inc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Code and data for the paper ["Dialog Intent Induction with Deep Multi-View Clustering"](https://arxiv.org/abs/1908.11487), Hugh Perkins and Yi Yang, 2019, to appear in EMNLP 2019. 2 | 3 | Data is available in the sub-directory [data](data), with a specific [LICENSE](data/LICENSE) file. 4 | 5 | # Dialog Intent Induction 6 | 7 | Dialog intent induction aims to automatically discover dialog intents from human-human conversations. The problem is largely overlooked in prior academic work, which has created a huge gap between academia and industry. 8 | 9 | In particular, **academic dialog** datasets such as ATIS and MultiWoZ assume dialog intents are given; they also focus on simple dialog intents like `BookRestaurant` or `BookHotel`. However, many complex dialog intents emerge in **industrial settings** that are hard to predefine, and the dialog intents also undergo dynamic changes. 10 | 11 | ## Deep multi-view clustering 12 | 13 | In this work, we propose to tackle this problem using multi-view clustering. Consider the following example dialogs: 14 | 15 | 16 | 17 | The user query utterances (query view) are lexically and syntactically dissimilar. 
However, the solution trajectories (content view) are similar. 18 | 19 | ### Alternating-view k-means (AV-KMeans) 20 | 21 | We propose a novel method for joint representation learning and multi-view clustering: alternating-view k-means (AV-KMeans). 22 | 23 | 24 | 25 | We perform clustering on view 1 and project the assignments to view 2 for classification. The encoders are fixed during clustering and updated during classification. 26 | 27 | ## Experiments 28 | 29 | We construct a new dataset to evaluate this new intent induction task: Twitter Airlines Customer Support (TwACS). 30 | 31 | We compare three competitive clustering methods: k-means, [Multi-View Spectral Clustering (MVSC)](https://github.com/mariceli3/multiview), and `AV-Kmeans` (ours). We experiment with three approaches to parameter initialization: PCA for k-means and MVSC; autoencoders; and [quick thoughts](https://arxiv.org/pdf/1803.02893.pdf). 32 | 33 | The F1 scores are presented below: 34 | 35 | |Algo | PCA/None | autoencoders | quick thoughts | 36 | |------|----------|--------------|----------------| 37 | |k-means| 28.2 | 29.5 | 42.1| 38 | |MVSC| 27.8 | 31.3 | 40 | 39 | |**AV-Kmeans (ours)** | **35.4** | **38.9** | **46.2** | 40 | 41 | 42 | # Usage 43 | 44 | ## Pre-requisites 45 | 46 | - decompress the `.bz2` files in the `data` folder 47 | - download http://nlp.stanford.edu/data/glove.840B.300d.zip, and unzip `glove.840B.300d.txt` into the `data` folder 48 | 49 | ## To run AV-Kmeans 50 | 51 | - run one of: 52 | ``` 53 | # no pre-training 54 | cmdlines/airlines_mvc.sh 55 | 56 | # ae pre-training 57 | cmdlines/airlines_ae.sh 58 | cmdlines/airlines_ae_mvc.sh 59 | 60 | # qt pre-training 61 | cmdlines/airlines_qt.sh 62 | cmdlines/airlines_qt_mvc.sh 63 | ``` 64 | - to train on askubuntu, replace `airlines` with `ubuntu` in the above command-lines 65 | 66 | ## To run k-means baseline 67 | 68 | - for qt pretraining, run: 69 | ``` 70 | PYTHONPATH=. 
python train_qt.py --data-path data/airlines_processed.csv --pre-epoch 10 --view1-col first_utterance --view2-col context --scenarios view1 71 | ``` 72 | - to train on askubuntu, replace `airlines` with `askubuntu` in the above command-line, and remove the `--*-col` command-line options 73 | -------------------------------------------------------------------------------- /cmdlines/airlines_ae.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --pre-epoch 20 --pre-model ae --num-epochs 0 \ 4 | --data-path data/airlines_processed.csv \ 5 | --view1-col first_utterance --view2-col context --label-col tag \ 6 | --save-model-path data/airlines_ae.pth 7 | -------------------------------------------------------------------------------- /cmdlines/airlines_ae_mvc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ ! -f data/airlines_ae.pth ]]; then { 4 | echo Please train airlines ae first 5 | } fi 6 | 7 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/airlines_processed.csv \ 8 | --view1-col first_utterance --view2-col context --label-col tag \ 9 | --model-path data/airlines_ae.pth 10 | -------------------------------------------------------------------------------- /cmdlines/airlines_mvc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/airlines_processed.csv \ 4 | --view1-col first_utterance --view2-col context --label-col tag 5 | -------------------------------------------------------------------------------- /cmdlines/airlines_none.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --pre-epoch 0 --num-epochs 0 --data-path data/airlines_processed.csv \ 4 | --view1-col first_utterance --view2-col context --label-col tag 5 | 
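The alternating-view k-means procedure described in the README — cluster on one view, project the assignments onto the other view as classification targets, then alternate — can be illustrated with a deliberately simplified toy. The following numpy-only sketch is an assumption-laden stand-in, not the repository's `train.py`: the learned encoders are replaced by fixed toy embeddings, and the classification step is collapsed to nearest-centroid fitting.

```python
import numpy as np

def kmeans(X, k, iters=10):
    # plain Lloyd's algorithm with a naive "first k points" init
    # (adequate for this toy; the real code pre-trains encoders instead)
    centroids = X[:k].copy()
    for _ in range(iters):
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        assign = dists.argmin(1)
        for j in range(k):
            if (assign == j).any():
                centroids[j] = X[assign == j].mean(0)
    return assign

def av_kmeans(view1, view2, k, rounds=4):
    # alternate between the views: assignments computed on one view are
    # projected onto the other view, which is then fit to those assignments
    assign = kmeans(view1, k)
    for r in range(rounds):
        other = view2 if r % 2 == 0 else view1
        # "classification" step collapsed to nearest-centroid fitting;
        # assumes every cluster stays non-empty (true for this toy data)
        centroids = np.stack([other[assign == j].mean(0) for j in range(k)])
        dists = ((other[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        assign = dists.argmin(1)
    return assign
```

When the two views share the same underlying grouping, the alternating steps reinforce one another — the intuition behind using the content view to refine clusters in the query view.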
-------------------------------------------------------------------------------- /cmdlines/airlines_qt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --pre-epoch 10 --pre-model qt --num-epochs 0 --data-path data/airlines_processed.csv \ 4 | --view1-col first_utterance --view2-col context --label-col tag \ 5 | --save-model-path data/airlines_qt.pth 6 | -------------------------------------------------------------------------------- /cmdlines/airlines_qt_mvc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ ! -f data/airlines_qt.pth ]]; then { 4 | echo Please train airlines qt first 5 | } fi 6 | 7 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/airlines_processed.csv \ 8 | --view1-col first_utterance --view2-col context --label-col tag \ 9 | --model-path data/airlines_qt.pth 10 | -------------------------------------------------------------------------------- /cmdlines/ubuntu_ae.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --pre-epoch 20 --pre-model ae --num-epochs 0 \ 4 | --data-path data/askubuntu_processed.csv \ 5 | --save-model-path data/ubuntu_ae.pth 6 | -------------------------------------------------------------------------------- /cmdlines/ubuntu_ae_mvc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ ! 
-f data/ubuntu_ae.pth ]]; then { 4 | echo Please train askubuntu ae first 5 | } fi 6 | 7 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/askubuntu_processed.csv \ 8 | --model-path data/ubuntu_ae.pth 9 | -------------------------------------------------------------------------------- /cmdlines/ubuntu_mvc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/askubuntu_processed.csv 4 | -------------------------------------------------------------------------------- /cmdlines/ubuntu_none.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --pre-epoch 0 --num-epochs 0 --data-path data/askubuntu_processed.csv 4 | -------------------------------------------------------------------------------- /cmdlines/ubuntu_qt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --pre-epoch 10 --pre-model qt --num-epochs 0 --data-path data/askubuntu_processed.csv \ 4 | --save-model-path data/ubuntu_qt.pth 5 | -------------------------------------------------------------------------------- /cmdlines/ubuntu_qt_mvc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ ! 
-f data/ubuntu_qt.pth ]]; then { 4 | echo Please train askubuntu qt first 5 | } fi 6 | 7 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/askubuntu_processed.csv \ 8 | --model-path data/ubuntu_qt.pth 9 | -------------------------------------------------------------------------------- /data/LICENSE: -------------------------------------------------------------------------------- 1 | There are four dataset files in this folder: 2 | 3 | - airlines_raw.csv.bz2 and airlines_processed.csv.bz2 are derived from the Kaggle 'Customer Support on Twitter' dataset, 4 | https://www.kaggle.com/thoughtvector/customer-support-on-twitter, 5 | which is available under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0) license, 6 | https://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | - askubuntu_raw.csv.bz2 and askubuntu_processed.csv.bz2 are derived from the Stack Exchange Data Dump, https://archive.org/details/stackexchange, 9 | which is provided under a Creative Commons Attribution-ShareAlike 3.0 Unported (CC BY-SA 3.0) license, https://creativecommons.org/licenses/by-sa/3.0/ 10 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Data 2 | ---- 3 | 4 | Training vs test data 5 | --------------------- 6 | 7 | We train on all data, without labels. We use the labels only to evaluate the resulting clusters. 8 | 9 | Twitter Airlines Customer Support 10 | --------------------------------- 11 | 12 | The data is available in two versions: 13 | 14 | - raw: minimal redaction (company and customer twitter ids), no preprocessing: [airlines_raw.csv.bz2](airlines_raw.csv.bz2) 15 | - redacted and preprocessed: [airlines_processed.csv.bz2](airlines_processed.csv.bz2) 16 | 17 | We sampled 500 examples and annotated them. 8 examples were rejected because they were not in English, leaving 492 labeled examples. 
The remaining examples were labeled `UNK`. 18 | 19 | AskUbuntu 20 | --------- 21 | 22 | - raw, no preprocessing: [askubuntu_raw.csv.bz2](askubuntu_raw.csv.bz2) 23 | - preprocessed: [askubuntu_processed.csv.bz2](askubuntu_processed.csv.bz2) 24 | -------------------------------------------------------------------------------- /data/airlines_500onlyb.csv.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/data/airlines_500onlyb.csv.bz2 -------------------------------------------------------------------------------- /data/airlines_processed.csv.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/data/airlines_processed.csv.bz2 -------------------------------------------------------------------------------- /data/airlines_raw.csv.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/data/airlines_raw.csv.bz2 -------------------------------------------------------------------------------- /data/askubuntu_processed.csv.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/data/askubuntu_processed.csv.bz2 -------------------------------------------------------------------------------- /data/askubuntu_raw.csv.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/data/askubuntu_raw.csv.bz2 
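Since only the ~500 annotated airline examples carry a real tag and every other row is labeled `UNK`, evaluation code has to separate the two subsets. A minimal sketch of that split follows; the `tag`/`first_utterance`/`context` column names are taken from the cmdline scripts elsewhere in this repo, and the sample rows are invented stand-ins rather than real data:

```python
import csv
import io

# tiny inline stand-in for the decompressed airlines_processed.csv;
# the real file uses the same tag/first_utterance/context columns
sample = """tag,first_utterance,context
baggage,<cust> my bag never arrived,<cust> my bag never arrived <company> so sorry to hear that
UNK,<cust> flight delayed again?,<cust> flight delayed again? <company> we apologize
"""

def split_labeled(file_obj):
    # partition rows into labeled examples and UNK (unlabeled) examples
    labeled, unlabeled = [], []
    for row in csv.DictReader(file_obj):
        (unlabeled if row['tag'] == 'UNK' else labeled).append(row)
    return labeled, unlabeled

labeled, unlabeled = split_labeled(io.StringIO(sample))
```

Training uses both lists (labels are ignored); only `labeled` feeds the cluster-quality metrics.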
-------------------------------------------------------------------------------- /datasets/askubuntu_preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | given askubuntu links, find the distribution of clusters formed of connected subgraphs 3 | 4 | needs networkx installed (conda install networkx) 5 | 6 | conda install -y networkx beautifulsoup4 lxml 7 | python -c 'import nltk; nltk.download("stopwords")' 8 | 9 | """ 10 | import random 11 | import argparse 12 | import csv 13 | from collections import defaultdict 14 | import time 15 | from os.path import expanduser as expand 16 | from xml.dom import minidom 17 | 18 | import networkx as nx 19 | import torch 20 | import numpy as np 21 | 22 | from preprocessing.askubuntu import Preprocessor 23 | 24 | 25 | class NullPreprocessor(object): 26 | def __call__(self, text): 27 | text = text.replace('\n', ' ').replace('\r', ' ') 28 | return True, text 29 | 30 | 31 | def index_posts(args): 32 | """ 33 | got through Posts.xml, and find for each post id, the id of the answer. store both in a csv file 34 | 35 | example row: 36 | 45 | """ 46 | in_f = open(expand(args.in_posts), 'r') 47 | out_f = open(expand(args.out_posts_index), 'w') 48 | dict_writer = csv.DictWriter(out_f, fieldnames=['question_id', 'answer_id']) 49 | dict_writer.writeheader() 50 | last_print = time.time() 51 | for n, row in enumerate(in_f): 52 | row = row.strip() 53 | if not row.startswith('= args.in_max_posts: 63 | print('reached max rows => breaking') 64 | break 65 | if time.time() - last_print >= 3.0: 66 | print(n) 67 | last_print = time.time() 68 | 69 | 70 | def get_clusters(in_dupes_file, in_max_dupes, max_clusters): 71 | """ 72 | we're reading in the list of pairs of dupes. These are a in format of two post ids per line, like: 73 | 615465 8653 74 | 833376 377050 75 | 30585 120621 76 | 178532 152184 77 | 69455 68850 78 | 79 | When we read these in, we have no idea whether these posts have answers etc. 
We just read in all 80 | the pairs of post ids. We are then going to add these post ids to a graph (each post id forms a node), 81 | and the pairs of post ids become connectsion in the graph. We then form all connected components. 82 | 83 | We sort the connected components by size (reverse order), and take the top num_test_clusters components 84 | We just ignore the other components 85 | """ 86 | f_in = open(expand(in_dupes_file), 'r') 87 | csv_reader = csv.reader(f_in, delimiter=' ') 88 | G = nx.Graph() 89 | for n, row in enumerate(csv_reader): 90 | left = int(row[0]) 91 | right = int(row[1]) 92 | G.add_edge(left, right) 93 | if in_max_dupes is not None and n >= in_max_dupes: 94 | print('reached max tao rows => break') 95 | break 96 | print('num nodes', len(G)) 97 | print('num clusters', len(list(nx.connected_components(G)))) 98 | count_by_size = defaultdict(int) 99 | clusters_by_size = defaultdict(list) 100 | for i, cluster in enumerate(nx.connected_components(G)): 101 | size = len(cluster) 102 | count_by_size[size] += 1 103 | clusters_by_size[size].append(cluster) 104 | print('count by size:') 105 | clusters = [] 106 | top_clusters = [] 107 | for size, count in sorted(count_by_size.items(), reverse=True): 108 | for cluster in clusters_by_size[size]: 109 | clusters.append(cluster) 110 | if len(top_clusters) < max_clusters: 111 | top_clusters.append(cluster) 112 | 113 | print('len(clusters)', len(clusters)) 114 | top_cluster_post_ids = [id for cluster in top_clusters for id in cluster] 115 | 116 | post2cluster = {} 117 | for cluster_id, cluster in enumerate(clusters): 118 | for post in cluster: 119 | post2cluster[post] = cluster_id 120 | 121 | return top_clusters, top_cluster_post_ids, post2cluster 122 | 123 | 124 | def read_posts_index(in_posts_index): 125 | with open(expand(args.in_posts_index), 'r') as f: 126 | dict_reader = csv.DictReader(f) 127 | index_rows = list(dict_reader) 128 | index_rows = [{'question_id': int(row['question_id']), 'answer_id': 
int(row['answer_id'])} for row in index_rows] 129 | post2answer = {row['question_id']: row['answer_id'] for row in index_rows} 130 | answer2post = {row['answer_id']: row['question_id'] for row in index_rows} 131 | print('loaded index') 132 | return post2answer, answer2post 133 | 134 | 135 | def load_posts(answer2post, post2answer, in_posts): 136 | # load in all the posts from Posts.xml 137 | post_by_id = defaultdict(dict) 138 | posts_f = open(expand(in_posts), 'r') 139 | last_print = time.time() 140 | for n, row in enumerate(posts_f): 141 | row = row.strip() 142 | if not row.startswith(' 161 | """ 162 | dom = minidom.parseString(row) 163 | node = dom.firstChild 164 | att = node.attributes 165 | assert att['PostTypeId'].value == '1' 166 | post_by_id[row_id]['question_title'] = att['Title'].value 167 | post_by_id[row_id]['question_body'] = att['Body'].value 168 | elif row_id in answer2post: 169 | dom = minidom.parseString(row) 170 | node = dom.firstChild 171 | att = node.attributes 172 | assert att['PostTypeId'].value == '2' 173 | post_id = answer2post[row_id] 174 | post_by_id[post_id]['answer_body'] = att['Body'].value 175 | if time.time() - last_print >= 3.0: 176 | print(len(post_by_id)) 177 | last_print = time.time() 178 | if args.in_max_posts is not None and n > args.in_max_posts: 179 | print('reached in_max_posts => terminating') 180 | break 181 | print('loaded info from Posts.xml') 182 | return post_by_id 183 | 184 | 185 | def create_labeled(args): 186 | torch.manual_seed(args.seed) 187 | np.random.seed(args.seed) 188 | random.seed(args.seed) 189 | 190 | out_f = open(expand(args.out_labeled), 'w') # open now to check we can 191 | 192 | clusters, post_ids, post2cluster = get_clusters( 193 | in_dupes_file=args.in_dupes, in_max_dupes=args.in_max_dupes, max_clusters=args.max_clusters) 194 | print('cluster sizes:', [len(cluster) for cluster in clusters]) 195 | print('len(clusters)', len(clusters)) 196 | print('len(post_ids) from dupes graph', len(post_ids)) 197 | 
198 | print('removing post ids which dont have answers...') 199 | post2answer, answer2post = read_posts_index(in_posts_index=args.in_posts_index) 200 | post_ids = [id for id in post_ids if id in post2answer] 201 | print('len(post_ids) after removing no answer', len(post_ids)) 202 | new_clusters = [] 203 | for cluster in clusters: 204 | cluster = [id for id in cluster if id in post2answer] 205 | new_clusters.append(cluster) 206 | clusters = new_clusters 207 | print('len clusters after removing no answer', [len(cluster) for cluster in clusters]) 208 | 209 | post_ids_set = set(post_ids) 210 | print('len(post_ids_set)', len(post_ids_set)) 211 | answer_ids = [post2answer[id] for id in post_ids] 212 | print('len(answer_ids)', len(answer_ids)) 213 | 214 | preprocessor = Preprocessor(max_len=args.max_len) if not args.no_preprocess else NullPreprocessor() 215 | post_by_id = load_posts(answer2post=answer2post, post2answer=post2answer, in_posts=args.in_posts) 216 | 217 | count_by_state = defaultdict(int) 218 | n = 0 219 | dict_writer = csv.DictWriter(out_f, fieldnames=[ 220 | 'id', 'cluster_id', 'question_title', 'question_body', 'answer_body']) 221 | dict_writer.writeheader() 222 | for post_id in post_ids: 223 | if post_id not in post_by_id: 224 | count_by_state['not in post_by_id'] += 1 225 | continue 226 | post = post_by_id[post_id] 227 | if 'answer_body' not in post or 'question_body' not in post: 228 | count_by_state['no body, or no answer'] += 1 229 | continue 230 | count_by_state['ok'] += 1 231 | cluster_id = post2cluster[post_id] 232 | row = { 233 | 'id': post_id, 234 | 'cluster_id': cluster_id, 235 | 'question_title': preprocessor(post['question_title'])[1], 236 | 'question_body': preprocessor(post['question_body'])[1], 237 | 'answer_body': preprocessor(post['answer_body'])[1] 238 | } 239 | dict_writer.writerow(row) 240 | n += 1 241 | print(count_by_state) 242 | print('rows written', n) 243 | 244 | 245 | def create_unlabeled(args): 246 | """ 247 | this is going to 
do: 248 | - take a question (specific type in Posts.xml) 249 | - match it with the accepted answer (using the index) 250 | - write these out 251 | """ 252 | torch.manual_seed(args.seed) 253 | np.random.seed(args.seed) 254 | random.seed(args.seed) 255 | 256 | out_f = open(expand(args.out_unlabeled), 'w') # open now, to check we can... 257 | 258 | post2answer, answer2post = read_posts_index(in_posts_index=args.in_posts_index) 259 | print('loaded index') 260 | print('posts in index', len(post2answer)) 261 | 262 | preprocessor = Preprocessor(max_len=args.max_len) if not args.no_preprocess else NullPreprocessor() 263 | post_by_id = load_posts(post2answer=post2answer, answer2post=answer2post, in_posts=args.in_posts) 264 | print('loaded all posts, len(post_by_id)', len(post_by_id)) 265 | 266 | count_by_state = defaultdict(int) 267 | n = 0 268 | dict_writer = csv.DictWriter(out_f, fieldnames=['id', 'question_title', 'question_body', 'answer_body']) 269 | dict_writer.writeheader() 270 | last_print = time.time() 271 | for post_id, info in post_by_id.items(): 272 | if 'answer_body' not in info or 'question_body' not in info: 273 | count_by_state['no body, or no answer'] += 1 274 | continue 275 | count_by_state['ok'] += 1 276 | dict_writer.writerow({ 277 | 'id': post_id, 278 | 'question_title': preprocessor(info['question_title'])[1], 279 | 'question_body': preprocessor(info['question_body'])[1], 280 | 'answer_body': preprocessor(info['answer_body'])[1] 281 | }) 282 | if time.time() - last_print >= 10: 283 | print('written', n) 284 | last_print = time.time() 285 | n += 1 286 | 287 | print(count_by_state) 288 | 289 | 290 | if __name__ == '__main__': 291 | root_parser = argparse.ArgumentParser() 292 | parsers = root_parser.add_subparsers() 293 | 294 | parser = parsers.add_parser('index-posts') 295 | parser.add_argument('--in-posts-dir', type=str, default='~/data/askubuntu.com') 296 | parser.add_argument('--in-posts', type=str, default='{in_posts_dir}/Posts.xml') 297 | 
parser.add_argument('--out-posts-index', type=str, default='{in_posts_dir}/Posts_index.csv') 298 | parser.add_argument('--in-max-posts', type=int, help='for dev/debugging mostly') 299 | parser.set_defaults(func=index_posts) 300 | 301 | parser = parsers.add_parser('create-labeled') 302 | parser.add_argument('--seed', type=int, default=123) 303 | parser.add_argument('--in-dupes-dir', type=str, default='~/data/askubuntu') 304 | parser.add_argument('--in-dupes', type=str, default='{in_dupes_dir}/full.pos.txt') 305 | parser.add_argument('--max-clusters', type=int, default=20) 306 | parser.add_argument('--in-posts-dir', type=str, default='~/data/askubuntu.com') 307 | parser.add_argument('--in-posts', type=str, default='{in_posts_dir}/Posts.xml') 308 | parser.add_argument('--in-posts-index', type=str, default='{in_posts_dir}/Posts_index.csv') 309 | parser.add_argument('--in-max-dupes', type=int, help='for dev/debugging mostly') 310 | parser.add_argument('--in-max-posts', type=int, help='for dev/debugging mostly') 311 | parser.add_argument('--out-labeled', type=str, required=True) 312 | parser.add_argument('--no-preprocess', action='store_true') 313 | parser.add_argument('--max-len', type=int, default=500, help='set to 0 to disable') 314 | parser.set_defaults(func=create_labeled) 315 | 316 | parser = parsers.add_parser('create-unlabeled') 317 | parser.add_argument('--seed', type=int, default=123) 318 | parser.add_argument('--in-posts-dir', type=str, default='~/data/askubuntu.com') 319 | parser.add_argument('--in-dupes-dir', type=str, default='~/data/askubuntu') 320 | parser.add_argument('--in-posts', type=str, default='{in_posts_dir}/Posts.xml') 321 | parser.add_argument('--in-posts-index', type=str, default='{in_posts_dir}/Posts_index.csv') 322 | parser.add_argument('--in-max-posts', type=int, help='for dev/debugging mostly') 323 | parser.add_argument('--out-unlabeled', type=str, required=True) 324 | parser.add_argument('--max-len', type=int, default=500, help='set to 0 
to disable') 325 | parser.add_argument('--no-preprocess', action='store_true') 326 | parser.set_defaults(func=create_unlabeled) 327 | 328 | args = root_parser.parse_args() 329 | for k in [ 330 | 'in_dupes', 'out_labeled', 'in_posts', 'out_posts_index', 'in_posts_index', 331 | 'out_unlabeled']: 332 | if k in args.__dict__: 333 | args.__dict__[k] = args.__dict__[k].format(**args.__dict__) 334 | 335 | func = args.func 336 | del args.__dict__['func'] 337 | func(args) 338 | -------------------------------------------------------------------------------- /datasets/labeled_unlabeled_merger.py: -------------------------------------------------------------------------------- 1 | """ 2 | given a labeled and an unlabeled datafile, creates a single file that contains the data 3 | from both. the labels for the unlabeled data file will be UNK 4 | """ 5 | import argparse 6 | from collections import OrderedDict 7 | import csv 8 | 9 | 10 | def get_col_by_role(role_string): 11 | col_by_role = {} 12 | for s in role_string.split(','): 13 | split_s = s.split('=') 14 | col_by_role[split_s[0]] = split_s[1] 15 | print('col_by_role', col_by_role) 16 | return col_by_role 17 | 18 | 19 | def run(labeled_files, unlabeled_files, out_file, unlabeled_columns, labeled_columns, no_add_cust_tokens, add_columns, column_order): 20 | add_columns = add_columns.split(',') 21 | out_f = open(out_file, 'w') 22 | dict_writer = csv.DictWriter(out_f, fieldnames=column_order.split(',')) 23 | dict_writer.writeheader() 24 | 25 | labeled_col_by_role = get_col_by_role(labeled_columns) 26 | unlabeled_col_by_role = get_col_by_role(unlabeled_columns) 27 | 28 | labeled_view1 = labeled_col_by_role['view1'] 29 | labeled_label = labeled_col_by_role['label'] 30 | 31 | unlabeled_view1 = unlabeled_col_by_role['view1'] 32 | unlabeled_view2 = unlabeled_col_by_role['view2'] 33 | 34 | unlabeled_by_first_utterance = OrderedDict() 35 | for filename in unlabeled_files: 36 | print(filename) 37 | with open(filename, 'r') as in_f: 
dict_reader = csv.DictReader(in_f) 39 | for row in dict_reader: 40 | view1 = row[unlabeled_view1] 41 | if not no_add_cust_tokens and not view1.startswith('' 43 | out_row = { 44 | 'view1': view1, 45 | 'view2': row[unlabeled_view2], 46 | 'label': 'UNK' 47 | } 48 | for k in add_columns: 49 | out_row[k] = row[k] 50 | unlabeled_by_first_utterance[view1] = out_row 51 | print('loaded unlabeled') 52 | 53 | for filename in labeled_files: 54 | print(filename) 55 | with open(filename, 'r') as in_f: 56 | dict_reader = csv.DictReader(in_f) 57 | for row in dict_reader: 58 | view1 = row[labeled_view1] 59 | if view1 not in unlabeled_by_first_utterance: 60 | print('warning: not found in unlabelled', view1) 61 | continue 62 | out_row = unlabeled_by_first_utterance[view1] 63 | out_row['label'] = row[labeled_label] 64 | for k in add_columns: 65 | assert out_row[k] == row[k] 66 | 67 | for row in unlabeled_by_first_utterance.values(): 68 | dict_writer.writerow(row) 69 | 70 | out_f.close() 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--labeled-files', type=str, nargs='+', required=True) 76 | parser.add_argument('--unlabeled-files', type=str, nargs='+', required=True) 77 | parser.add_argument('--out-file', type=str, required=True) 78 | parser.add_argument('--unlabeled-columns', type=str, default='view1=first_utterance,view2=context') 79 | parser.add_argument('--labeled-columns', type=str, default='view1=text,label=tag') 80 | parser.add_argument('--no-add-cust-tokens', action='store_true') 81 | parser.add_argument('--add-columns', type=str, default='id,question_body') 82 | parser.add_argument('--column-order', type=str, default='id,label,view1,view2,question_body') 83 | args = parser.parse_args() 84 | run(**args.__dict__) 85 | -------------------------------------------------------------------------------- /datasets/readme.md: -------------------------------------------------------------------------------- 1 | To generate the 
data from raw data: 2 | 3 | For askubuntu: 4 | 5 | ``` 6 | export PYTHONPATH=. 7 | mkdir ~/data/askubuntu_induction 8 | python datasets/askubuntu_preprocess.py create-labeled --out-labeled ~/data/askubuntu_induction/preproc_labeled.csv 9 | python datasets/askubuntu_preprocess.py create-unlabeled --out-unlabeled ~/data/askubuntu_induction/preproc_unlabeled.csv 10 | python datasets/labeled_unlabeled_merger.py --labeled-files ~/data/askubuntu_induction/preproc_labeled.csv --unlabeled-files ~/data/askubuntu_induction/preproc_unlabeled.csv --out-file ~/data/askubuntu_induction/preproc_merged.csv --labeled-columns view1=question_title,label=cluster_id --unlabeled-columns view1=question_title,view2=answer_body --no-add-cust-tokens 11 | ``` 12 | 13 | For Twitter customer support: 14 | ``` 15 | (cd data; bunzip2 airlines_500onlyb.csv.bz2) 16 | python datasets/twitter_dataset_preprocess.py --no-preprocess --out-examples ~/data/twittersupport/airlines_raw.csv 17 | python datasets/twitter_airlines_raw_merger2.py 18 | ``` 19 | -------------------------------------------------------------------------------- /datasets/twitter_airlines_raw_merger2.py: -------------------------------------------------------------------------------- 1 | """ 2 | takes contents of airlines_raw.csv, and of data/airlines_mergedb.csv, and creates 3 | airlines_500_merged.csv, using the raw tweets from airlines_raw.csv, and the tags from airlines_500_mergedb.csv 4 | """ 5 | import argparse, os, time, csv, json 6 | from os import path 7 | from os.path import join, expanduser as expand 8 | 9 | def run(args): 10 | with open(expand(args.airlines_raw_file), 'r') as f: 11 | dict_reader = csv.DictReader(f) 12 | raw_rows = list(dict_reader) 13 | raw_row_by_tweet_id = {int(row['first_tweet_id']): row for i, row in enumerate(raw_rows)} 14 | 15 | f_in = open(expand(args.airlines_merged_file), 'r') 16 | f_out = open(expand(args.out_file), 'w') 17 | dict_reader = csv.DictReader(f_in) 18 | dict_writer = 
csv.DictWriter(f_out, fieldnames=['first_tweet_id', 'tag', 'first_utterance', 'context']) 19 | dict_writer.writeheader() 20 | for i, old_merged_row in enumerate(dict_reader): 21 | if old_merged_row['first_tweet_id'] == '': 22 | continue 23 | tweet_id = int(old_merged_row['first_tweet_id']) 24 | raw_row = raw_row_by_tweet_id[tweet_id] 25 | raw_row['tag'] = old_merged_row['tag'] 26 | dict_writer.writerow(raw_row) 27 | # so, we accidentally had a blank line at the end of the non-raw dataset. add that in here... 28 | dict_writer.writerow({'first_tweet_id': '', 'tag': 'UNK', 'first_utterance': '', 'context': ''}) 29 | f_out.close() 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--airlines-raw-file', type=str, default='~/data/twittersupport/airlines_raw.csv') 34 | parser.add_argument('--airlines-merged-file', type=str, default='data/airlines_500_mergedb.csv') 35 | parser.add_argument('--out-file', type=str, default='~/data/twittersupport/airlines_raw_merged.csv') 36 | args = parser.parse_args() 37 | run(args) 38 | -------------------------------------------------------------------------------- /datasets/twitter_dataset_preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | preprocessing twitter dataset 3 | """ 4 | import time 5 | import os 6 | from os import path 7 | from os.path import join, expanduser as expand 8 | import json 9 | import re 10 | import csv 11 | from collections import defaultdict 12 | import argparse 13 | 14 | import nltk 15 | 16 | from preprocessing import twitter_airlines 17 | 18 | 19 | all_companies = 'AppleSupport,Amazon,Uber_Support,SpotifyCares,Delta,Tesco,AmericanAir,TMobileHelp,comcastcares,British_Airways,SouthwestAir,VirginTrains,Ask_Spectrum,XboxSupport,sprintcare,hulu_support,AskPlayStation,ChipotleTweets,UPSHelp,AskTarget'.split(',') 20 | 21 | """ 22 | example rows: 23 | 24 | 
tweet_id,author_id,inbound,created_at,text,response_tweet_id,in_response_to_tweet_id 25 | 1,sprintcare,False,Tue Oct 31 22:10:47 +0000 2017,@115712 I understand. I would like to assist you. We would need to get you into a private secured link to further assist.,2,3 26 | 2,115712,True,Tue Oct 31 22:11:45 +0000 2017,@sprintcare and how do you propose we do that,,1 27 | """ 28 | 29 | class NullPreprocessor(object): 30 | def __call__(self, text, info_dict=None): 31 | is_valid = True 32 | if 'DM' in text: 33 | is_valid = False 34 | if 'https://t.co' in text: 35 | is_valid = False 36 | 37 | text = text.replace('\n', ' ').replace('\r', ' ') 38 | 39 | target_company = info_dict['target_company'] 40 | cust_twitter_id = info_dict['cust_twitter_id'] 41 | text = text.replace('@' + target_company, ' __company__ ').replace('@' + cust_twitter_id, ' __cust__ ') 42 | 43 | return is_valid, text 44 | 45 | def run_for_company(in_csv_file, in_max_rows, examples_writer, target_company, no_preprocess): 46 | f_in = open(expand(in_csv_file), 'r') 47 | node_by_id = {} 48 | start_node_ids = [] 49 | dict_reader = csv.DictReader(f_in) 50 | next_by_prev = {} 51 | for row in dict_reader: 52 | id = row['tweet_id'] 53 | prev = row['in_response_to_tweet_id'] 54 | next_by_prev[prev] = id 55 | if prev == '' and ('@' + target_company) in row['text']: 56 | start_node_ids.append(id) 57 | node_by_id[id] = row 58 | if in_max_rows is not None and len(node_by_id) >= in_max_rows: 59 | print('reached max rows', in_max_rows, '=> breaking') 60 | break 61 | print('len(node_by_id)', len(node_by_id)) 62 | print('len(start_node_ids)', len(start_node_ids)) 63 | count_by_status = defaultdict(int) 64 | count_by_count = defaultdict(int) 65 | 66 | preprocessor = twitter_airlines.Preprocessor() if not no_preprocess else NullPreprocessor() 67 | 68 | for i, start_node_id in enumerate(start_node_ids): 69 | conversation_texts = [] 70 | first_utterance = None 71 | is_valid = True 72 | node = node_by_id[start_node_id] 73 | 
cust_twitter_id = node['author_id'] 74 | while True: 75 | text = node['text'].replace('\n', ' ') 76 | if node['inbound'] == 'True': 77 | start_tok = ' skipping conversation') 110 | break 111 | node = node_by_id[response_id] 112 | print(count_by_status) 113 | 114 | def run_aggregated(in_csv_file, in_max_rows, out_examples, target_companies, no_preprocess): 115 | with open(expand(out_examples), 'w') as f_examples: 116 | examples_writer = csv.DictWriter(f_examples, fieldnames=['first_tweet_id', 'first_utterance', 'context']) 117 | examples_writer.writeheader() 118 | for company in target_companies: 119 | print(company) 120 | run_for_company(in_csv_file, in_max_rows, examples_writer, company, no_preprocess) 121 | 122 | def run_by_company(in_csv_file, in_max_rows, out_examples_templ, target_companies, no_preprocess): 123 | for company in target_companies: 124 | print(company) 125 | out_examples = out_examples_templ.format(company=company) 126 | with open(expand(out_examples), 'w') as f_examples: 127 | examples_writer = csv.DictWriter(f_examples, fieldnames=['first_tweet_id', 'first_utterance', 'context']) 128 | examples_writer.writeheader() 129 | run_for_company(in_csv_file, in_max_rows, examples_writer, company, no_preprocess) 130 | 131 | if __name__ == '__main__': 132 | root_parser = argparse.ArgumentParser() 133 | parsers = root_parser.add_subparsers() 134 | 135 | parser = parsers.add_parser('run-aggregated') 136 | parser.add_argument('--in-csv-file', type=str, default='~/data/twittersupport/twcs.csv') 137 | parser.add_argument('--in-max-rows', type=int, help='for dev/debugging mostly') 138 | parser.add_argument('--out-examples', type=str, required=True) 139 | parser.add_argument('--target-companies', type=str, default='Delta,AmericanAir,British_Airways,SouthwestAir') 140 | parser.add_argument('--no-preprocess', action='store_true') 141 | parser.set_defaults(func=run_aggregated) 142 | 143 | parser = parsers.add_parser('run-by-company') 144 | 
parser.add_argument('--in-csv-file', type=str, default='~/data/twittersupport/twcs.csv') 145 | parser.add_argument('--in-max-rows', type=int, help='for dev/debugging mostly') 146 | parser.add_argument('--out-examples-templ', type=str, default='~/data/twittersupport/{company}.csv') 147 | parser.add_argument('--target-companies', type=str, default=','.join(all_companies)) 148 | parser.add_argument('--no-preprocess', action='store_true') 149 | parser.set_defaults(func=run_by_company) 150 | 151 | args = root_parser.parse_args() 152 | 153 | if args.func == run_aggregated: 154 | args.out_examples = args.out_examples.format(**args.__dict__) 155 | args.target_companies = args.target_companies.split(',') 156 | elif args.func == run_by_company: 157 | args.target_companies = args.target_companies.split(',') 158 | 159 | func = args.func 160 | del args.__dict__['func'] 161 | func(**args.__dict__) 162 | -------------------------------------------------------------------------------- /images/avkmeans_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/images/avkmeans_graph.png -------------------------------------------------------------------------------- /images/example_dialogs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/images/example_dialogs.png -------------------------------------------------------------------------------- /labeled_unlabeled_merger.py: -------------------------------------------------------------------------------- 1 | """ 2 | given a labeled and unlabeled datafile, creates a single file, that contains the data 3 | from both. 
the labels for the unlabeled data file will be UNK 4 | """ 5 | import argparse 6 | from collections import OrderedDict 7 | import csv 8 | 9 | 10 | def get_col_by_role(role_string): 11 | col_by_role = {} 12 | for s in role_string.split(','): 13 | split_s = s.split('=') 14 | col_by_role[split_s[0]] = split_s[1] 15 | print('col_by_role', col_by_role) 16 | return col_by_role 17 | 18 | 19 | def run(labeled_files, unlabeled_files, out_file, unlabeled_columns, labeled_columns, no_add_cust_tokens, add_columns, column_order): 20 | add_columns = add_columns.split(',') 21 | out_f = open(out_file, 'w') 22 | dict_writer = csv.DictWriter(out_f, fieldnames=column_order.split(',')) 23 | dict_writer.writeheader() 24 | 25 | labeled_col_by_role = get_col_by_role(labeled_columns) 26 | unlabeled_col_by_role = get_col_by_role(unlabeled_columns) 27 | 28 | labeled_view1 = labeled_col_by_role['view1'] 29 | labeled_label = labeled_col_by_role['label'] 30 | 31 | unlabeled_view1 = unlabeled_col_by_role['view1'] 32 | unlabeled_view2 = unlabeled_col_by_role['view2'] 33 | 34 | unlabeled_by_first_utterance = OrderedDict() 35 | for filename in unlabeled_files: 36 | print(filename) 37 | with open(filename, 'r') as in_f: 38 | dict_reader = csv.DictReader(in_f) 39 | for row in dict_reader: 40 | view1 = row[unlabeled_view1] 41 | if not no_add_cust_tokens and not view1.startswith('' 43 | out_row = { 44 | 'view1': view1, 45 | 'view2': row[unlabeled_view2], 46 | 'label': 'UNK' 47 | } 48 | for k in add_columns: 49 | out_row[k] = row[k] 50 | unlabeled_by_first_utterance[view1] = out_row 51 | print('loaded unlabeled') 52 | 53 | for filename in labeled_files: 54 | print(filename) 55 | with open(filename, 'r') as in_f: 56 | dict_reader = csv.DictReader(in_f) 57 | for row in dict_reader: 58 | view1 = row[labeled_view1] 59 | if view1 not in unlabeled_by_first_utterance: 60 | print('warning: not found in unlabelled', view1) 61 | continue 62 | out_row = unlabeled_by_first_utterance[view1] 63 | out_row['label'] = 
row[labeled_label] 64 | for k in add_columns: 65 | assert out_row[k] == row[k] 66 | 67 | for row in unlabeled_by_first_utterance.values(): 68 | dict_writer.writerow(row) 69 | 70 | out_f.close() 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--labeled-files', type=str, nargs='+', required=True) 76 | parser.add_argument('--unlabeled-files', type=str, nargs='+', required=True) 77 | parser.add_argument('--out-file', type=str, required=True) 78 | parser.add_argument('--unlabeled-columns', type=str, default='view1=first_utterance,view2=context') 79 | parser.add_argument('--labeled-columns', type=str, default='view1=text,label=tag') 80 | parser.add_argument('--no-add-cust-tokens', action='store_true') 81 | parser.add_argument('--add-columns', type=str, default='id,question_body') 82 | parser.add_argument('--column-order', type=str, default='id,label,view1,view2,question_body') 83 | args = parser.parse_args() 84 | run(**args.__dict__) 85 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/metrics/__init__.py -------------------------------------------------------------------------------- /metrics/cluster_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | cluster metrics: precision, recall, f1 3 | """ 4 | from collections import Counter 5 | 6 | import scipy.optimize 7 | import scipy.sparse 8 | import torch 9 | 10 | def calc_precision(gnd_assignments, pred_assignments): 11 | """ 12 | gnd_assignments and pred_assignments should be 1-d torch tensors of longs, 13 | containing the cluster assignment of each example 14 | 15 | assumes that cluster assignments are 0-based, and no 'holes' 16 | """ 17 | precision_sum = 0 18 | assert len(gnd_assignments.size()) == 1 19 | assert 
len(pred_assignments.size()) == 1 20 | assert pred_assignments.size(0) == gnd_assignments.size(0) 21 | N = gnd_assignments.size(0) 22 | K_gnd = gnd_assignments.max().item() + 1 23 | K_pred = pred_assignments.max().item() + 1 24 | for k_pred in range(K_pred): 25 | mask = pred_assignments == k_pred 26 | gnd = gnd_assignments[mask.nonzero().long().view(-1)] 27 | max_intersect = 0 28 | for k_gnd in range(K_gnd): 29 | intersect = (gnd == k_gnd).long().sum().item() 30 | max_intersect = max(max_intersect, intersect) 31 | precision_sum += max_intersect 32 | precision = precision_sum / N 33 | return precision 34 | 35 | 36 | def calc_recall(gnd_assignments, pred_assignments): 37 | """ 38 | basically the reverse of calc_precision 39 | 40 | so, we can just call calc_precision with the arguments swapped 41 | """ 42 | return calc_precision(gnd_assignments=pred_assignments, pred_assignments=gnd_assignments) 43 | 44 | 45 | def calc_f1(gnd_assignments, pred_assignments): 46 | prec = calc_precision(gnd_assignments=gnd_assignments, pred_assignments=pred_assignments) 47 | recall = calc_recall(gnd_assignments=gnd_assignments, pred_assignments=pred_assignments) 48 | f1 = 2 * (prec * recall) / (prec + recall) 49 | return f1 50 | 51 | 52 | def calc_prec_rec_f1(gnd_assignments, pred_assignments): 53 | prec = calc_precision(gnd_assignments=gnd_assignments, pred_assignments=pred_assignments) 54 | recall = calc_recall(gnd_assignments=gnd_assignments, pred_assignments=pred_assignments) 55 | f1 = 2 * (prec * recall) / (prec + recall) 56 | return prec, recall, f1 57 | 58 | 59 | def calc_ACC(pred, gnd): 60 | assert len(pred.size()) == 1 61 | assert len(gnd.size()) == 1 62 | N = pred.size(0) 63 | assert N == gnd.size(0) 64 | # build the (gnd x pred) contingency matrix M from the co-occurrence counts 65 | counts = Counter(list(zip(gnd.tolist(), pred.tolist()))) 66 | keys = torch.LongTensor(list(counts.keys())) 67 | values = torch.LongTensor(list(counts.values())) 68 | M = scipy.sparse.csr_matrix((values.numpy(), (keys[:, 
1].numpy()))) 69 | M = M.todense() 70 | 71 | row_ind, col_ind = scipy.optimize.linear_sum_assignment(-M) 72 | cost = M[row_ind, col_ind].sum().item() 73 | ACC = cost / N 74 | 75 | return ACC 76 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/model/__init__.py -------------------------------------------------------------------------------- /model/decoder.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import torch.nn as nn 3 | 4 | 5 | class Decoder(nn.Module): 6 | def __init__(self, latent_z_size, word_emb_size, word_vocab_size, decoder_rnn_size, 7 | decoder_num_layers, dropout=0.5): 8 | super(Decoder, self).__init__() 9 | self.latent_z_size = latent_z_size 10 | self.word_vocab_size = word_vocab_size 11 | self.decoder_rnn_size = decoder_rnn_size 12 | self.dropout = dropout 13 | self.rnn = nn.LSTM(input_size=latent_z_size + word_emb_size, 14 | hidden_size=decoder_rnn_size, 15 | num_layers=decoder_num_layers, 16 | dropout=dropout, 17 | batch_first=True) 18 | self.fc = nn.Linear(decoder_rnn_size, word_vocab_size) 19 | 20 | def forward(self, decoder_input, latent_z): 21 | """ 22 | :param decoder_input: tensor with shape of [batch_size, seq_len, emb_size] 23 | :param latent_z: sequence context with shape of [batch_size, latent_z_size] 24 | :return: unnormalized logits of sentence word distribution probabilities 25 | with shape of [batch_size, seq_len, word_vocab_size] 26 | 27 | TODO: add padding support 28 | """ 29 | 30 | [batch_size, seq_len, _] = decoder_input.size() 31 | # decoder rnn is conditioned on context via additional bias = W_cond * z to every input 32 | # token 33 | latent_z = t.cat([latent_z] * seq_len, 1).view(batch_size, seq_len, self.latent_z_size) 
34 | decoder_input = t.cat([decoder_input, latent_z], 2) 35 | rnn_out, _ = self.rnn(decoder_input) 36 | rnn_out = rnn_out.contiguous().view(-1, self.decoder_rnn_size) 37 | result = self.fc(rnn_out) 38 | result = result.view(batch_size, seq_len, self.word_vocab_size) 39 | return result 40 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Encoder(nn.Module): 6 | def __init__(self, word_emb_size, encoder_rnn_size, encoder_num_layers, dropout=0.5): 7 | super(Encoder, self).__init__() 8 | 9 | self.encoder_rnn_size = encoder_rnn_size 10 | self.encoder_num_layers = encoder_num_layers 11 | # bidirectional LSTM over the input word embeddings 12 | self.rnn = nn.LSTM(input_size=word_emb_size, 13 | hidden_size=encoder_rnn_size, 14 | num_layers=encoder_num_layers, 15 | dropout=dropout, 16 | batch_first=True, 17 | bidirectional=True) 18 | 19 | def forward(self, encoder_input, lengths): 20 | """ 21 | :param encoder_input: [batch_size, seq_len, emb_size] tensor 22 | :return: context of input sentences with shape of [batch_size, 2 * encoder_rnn_size] 23 | """ 24 | lengths, perm_idx = lengths.sort(0, descending=True) 25 | encoder_input = encoder_input[perm_idx] 26 | [batch_size, seq_len, _] = encoder_input.size() 27 | packed_words = torch.nn.utils.rnn.pack_padded_sequence( 28 | encoder_input, lengths, True) 29 | # Unfold rnn with zero initial state and get its final state from the last layer 30 | rnn_out, (_, final_state) = self.rnn(packed_words, None) 31 | final_state = final_state.view( 32 | self.encoder_num_layers, 2, batch_size, self.encoder_rnn_size)[-1] 33 | h_1, h_2 = final_state[0], final_state[1] 34 | final_state = torch.cat([h_1, h_2], 1) 35 | _, unperm_idx = perm_idx.sort(0) 36 | final_state = final_state[unperm_idx] 37 | return final_state 38 | 
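The sort-pack-unsort pattern in `Encoder.forward` above (sort the batch by descending length for `pack_padded_sequence`, then invert `perm_idx` to restore the original batch order) relies on the fact that argsorting a permutation yields its inverse. A minimal standalone sketch of that trick, independent of PyTorch:

```python
# Sketch of the sort/unsort trick used in Encoder.forward: sort items by
# descending length (as pack_padded_sequence requires), then invert the
# permutation to put the results back into the original batch order.
lengths = [3, 7, 5, 2]

# perm_idx: indices that sort lengths in descending order
perm_idx = sorted(range(len(lengths)), key=lambda i: -lengths[i])
sorted_lengths = [lengths[i] for i in perm_idx]

# unperm_idx: argsort of perm_idx, i.e. the inverse permutation
# (the same role as `perm_idx.sort(0)` in the torch code)
unperm_idx = sorted(range(len(perm_idx)), key=lambda i: perm_idx[i])
restored = [sorted_lengths[i] for i in unperm_idx]

assert sorted_lengths == sorted(lengths, reverse=True)
assert restored == lengths  # original order recovered
```

The same inverse-permutation identity is what `_, unperm_idx = perm_idx.sort(0)` computes in the encoders in this repo.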
-------------------------------------------------------------------------------- /model/multiview_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch as t 3 | from torch import nn 4 | import numpy as np 5 | 6 | from model.utils import pad_sentences, pad_paragraphs 7 | import train 8 | 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 | 12 | class MultiviewEncoders(nn.Module): 13 | 14 | def __init__( 15 | self, vocab_size, num_layers, embedding_size, lstm_hidden_size, word_dropout, dropout, 16 | start_idx=2, end_idx=3, pad_idx=0): 17 | super().__init__() 18 | self.pad_idx = pad_idx 19 | self.start_idx = start_idx # for RNN autoencoder training 20 | self.end_idx = end_idx # for RNN autoencoder training 21 | self.num_layers = num_layers 22 | self.embedding_size = embedding_size 23 | self.lstm_hidden_size = lstm_hidden_size 24 | self.word_dropout = nn.Dropout(word_dropout) 25 | self.dropout = dropout 26 | self.vocab_size = vocab_size 27 | self.crit = nn.CrossEntropyLoss() 28 | 29 | self.embedder = nn.Embedding(vocab_size, embedding_size) 30 | 31 | def create_rnn(embedding_size, bidirectional=True): 32 | return nn.LSTM( 33 | embedding_size, 34 | lstm_hidden_size, 35 | dropout=dropout, 36 | num_layers=num_layers, 37 | bidirectional=bidirectional, 38 | batch_first=True 39 | ) 40 | 41 | self.view1_word_rnn = create_rnn(embedding_size) 42 | self.view2_word_rnn = create_rnn(embedding_size) 43 | self.view2_sent_rnn = create_rnn(2*lstm_hidden_size) 44 | 45 | self.ae_decoder = create_rnn(embedding_size + 2 * lstm_hidden_size, bidirectional=False) 46 | self.qt_context = create_rnn(embedding_size) 47 | 48 | self.fc = nn.Linear(lstm_hidden_size, vocab_size) 49 | 50 | def get_encoder(self, encoder): 51 | return { 52 | 'v1': self.view1_word_rnn, 53 | 'v2': self.view2_word_rnn, 54 | 'v2sent': self.view2_sent_rnn, 55 | 'ae_decoder': self.ae_decoder, 56 | 'qt': self.qt_context 57 
| }[encoder] 58 | 59 | @classmethod 60 | def construct_from_embeddings( 61 | cls, embeddings, num_layers, embedding_size, lstm_hidden_size, word_dropout, dropout, 62 | vocab_size, start_idx=2, end_idx=3, pad_idx=0): 63 | model = cls( 64 | num_layers=num_layers, 65 | embedding_size=embedding_size, 66 | lstm_hidden_size=lstm_hidden_size, 67 | word_dropout=word_dropout, 68 | dropout=dropout, 69 | start_idx=start_idx, 70 | end_idx=end_idx, 71 | pad_idx=pad_idx, 72 | vocab_size=vocab_size 73 | ) 74 | model.embedder = nn.Embedding.from_pretrained(embeddings, freeze=False) 75 | return model 76 | 77 | def decode(self, decoder_input, latent_z): 78 | """ 79 | decode state into word indices 80 | 81 | :param decoder_input: list of lists of indices 82 | :param latent_z: sequence context with shape of [batch_size, latent_z_size] 83 | 84 | :return: unnormalized logits of sentence word distribution probabilities 85 | with shape of [batch_size, seq_len, word_vocab_size] 86 | 87 | """ 88 | # this pad_sentences call will add a start_idx token at the start of each sequence 89 | # so, we are feeding in the previous token each time, in order to generate the 90 | # next token 91 | padded, lengths = pad_sentences(decoder_input, pad_idx=self.pad_idx, lpad=self.start_idx) 92 | embeddings = self.embedder(padded) 93 | embeddings = self.word_dropout(embeddings) 94 | [batch_size, seq_len, _] = embeddings.size() 95 | # decoder rnn is conditioned on context via additional bias = W_cond * z 96 | # to every input token 97 | latent_z = t.cat([latent_z] * seq_len, 1).view(batch_size, seq_len, -1) 98 | embeddings = t.cat([embeddings, latent_z], 2) 99 | rnn = self.ae_decoder 100 | rnn_out, _ = rnn(embeddings) 101 | rnn_out = rnn_out.contiguous().view(batch_size * seq_len, self.lstm_hidden_size) 102 | result = self.fc(rnn_out) 103 | result = result.view(batch_size, seq_len, self.vocab_size) 104 | return result 105 | 106 | def forward(self, input, encoder): 107 | """ 108 | Encode an input into a vector 
representation 109 | params: 110 | input : word indices 111 | encoder: one of [v1|v2|v2sent|ae_decoder|qt] (the keys accepted by get_encoder) 112 | """ 113 | if encoder == 'v2': 114 | return self.hierarchical_forward(input) 115 | 116 | batch_size = len(input) 117 | padded, lengths = pad_sentences(input, pad_idx=self.pad_idx) 118 | embeddings = self.embedder(padded) 119 | embeddings = self.word_dropout(embeddings) 120 | lengths, perm_idx = lengths.sort(0, descending=True) 121 | embeddings = embeddings[perm_idx] 122 | packed = torch.nn.utils.rnn.pack_padded_sequence(embeddings, lengths, batch_first=True) 123 | rnn = self.get_encoder(encoder) 124 | _, (_, final_state) = rnn(packed, None) 125 | _, unperm_idx = perm_idx.sort(0) 126 | final_state = final_state[:, unperm_idx] 127 | final_state = final_state.view(self.num_layers, 2, batch_size, self.lstm_hidden_size)[-1] \ 128 | .transpose(0, 1).contiguous() \ 129 | .view(batch_size, 2 * self.lstm_hidden_size) 130 | return final_state 131 | 132 | def hierarchical_forward(self, input): 133 | batch_size = len(input) 134 | padded, word_lens, sent_lens, max_sent_len = pad_paragraphs(input, pad_idx=self.pad_idx) 135 | embeddings = self.embedder(padded) 136 | embeddings = self.word_dropout(embeddings) 137 | word_lens, perm_idx = word_lens.sort(0, descending=True) 138 | embeddings = embeddings[perm_idx] 139 | packed = torch.nn.utils.rnn.pack_padded_sequence( 140 | embeddings, word_lens, batch_first=True) 141 | _, (_, final_word_state) = self.view2_word_rnn(packed, None) 142 | _, unperm_idx = perm_idx.sort(0) 143 | final_word_state = final_word_state[:, unperm_idx] 144 | final_word_state = final_word_state.view( 145 | self.num_layers, 2, batch_size*max_sent_len, self.lstm_hidden_size)[-1] \ 146 | .transpose(0, 1).contiguous() \ 147 | .view(batch_size, max_sent_len, 2 * self.lstm_hidden_size) 148 | 149 | sent_lens, sent_perm_idx = sent_lens.sort(0, descending=True) 150 | sent_embeddings = final_word_state[sent_perm_idx] 151 | sent_packed = 
torch.nn.utils.rnn.pack_padded_sequence(sent_embeddings, sent_lens, batch_first=True) 152 | _, (_, final_sent_state) = self.view2_sent_rnn(sent_packed, None) 153 | _, sent_unperm_idx = sent_perm_idx.sort(0) 154 | final_sent_state = final_sent_state[:, sent_unperm_idx] 155 | final_sent_state = final_sent_state.view( 156 | self.num_layers, 2, batch_size, self.lstm_hidden_size)[-1] \ 157 | .transpose(0, 1).contiguous() \ 158 | .view(batch_size, 2 * self.lstm_hidden_size) 159 | return final_sent_state 160 | 161 | def qt_loss(self, target_view_state, input_view_state): 162 | """ 163 | pick out the correct example in the target_view, based on the corresponding input_view 164 | """ 165 | scores = input_view_state @ target_view_state.transpose(0, 1) 166 | batch_size = scores.size(0) 167 | targets = torch.from_numpy(np.arange(batch_size, dtype=np.int64)) 168 | targets = targets.to(scores.device) 169 | loss = self.crit(scores, targets) 170 | _, argmax = scores.max(dim=-1) 171 | examples_correct = (argmax == targets) 172 | acc = examples_correct.float().mean().item() 173 | return loss, acc 174 | 175 | def reconst_loss(self, gnd_utts, reconst): 176 | """ 177 | gnd_utts is a list of lists of indices (the outer list should be a minibatch) 178 | reconst is a tensor with the logits from a decoder [batchsize][seqlen][vocabsize] 179 | (should not have passed through softmax) 180 | 181 | reconst should be one token longer than the inputs in gnd_utts. 
the additional 182 | token to be predicted is the end_idx token 183 | """ 184 | batch_size, seq_len, vocab_size = reconst.size() 185 | loss = 0 186 | # this pad_sentences call will add token self.end_idx at the end of each sequence 187 | padded, lengths = pad_sentences(gnd_utts, pad_idx=self.pad_idx, rpad=self.end_idx) 188 | batch_size = len(lengths) 189 | crit = nn.CrossEntropyLoss() 190 | reconst_flat = reconst.view(batch_size * seq_len, vocab_size) 191 | padded_flat = padded.view(batch_size * seq_len) 192 | loss += crit(reconst_flat, padded_flat) 193 | _, argmax = reconst.max(dim=-1) 194 | correct = (argmax == padded) 195 | acc = correct.float().mean().item() 196 | return loss, acc 197 | 198 | 199 | def from_embeddings(glove_path, id_to_token, token_to_id): 200 | vocab_size = len(token_to_id) 201 | 202 | # Load pre-trained GloVe vectors 203 | pretrained = {} 204 | word_emb_size = 0 205 | print('loading glove') 206 | for line in open(glove_path): 207 | parts = line.strip().split() 208 | if len(parts) % 100 != 1: 209 | continue 210 | word = parts[0] 211 | if word not in token_to_id: 212 | continue 213 | vector = [float(v) for v in parts[1:]] 214 | pretrained[word] = vector 215 | word_emb_size = len(vector) 216 | pretrained_list = [] 217 | scale = np.sqrt(3.0 / word_emb_size) 218 | print('loading oov') 219 | for word in token_to_id: 220 | # apply lower() because all GloVe vectors are for lowercase words 221 | if word.lower() in pretrained: 222 | pretrained_list.append(np.array(pretrained[word.lower()])) 223 | else: 224 | random_vector = np.random.uniform(-scale, scale, [word_emb_size]) 225 | pretrained_list.append(random_vector) 226 | 227 | print('instantiating model') 228 | model = MultiviewEncoders.construct_from_embeddings( 229 | embeddings=torch.FloatTensor(pretrained_list), 230 | num_layers=train.LSTM_LAYER, 231 | embedding_size=word_emb_size, 232 | lstm_hidden_size=train.LSTM_HIDDEN, 233 | word_dropout=train.WORD_DROPOUT_RATE, 234 | 
dropout=train.DROPOUT_RATE, 235 | vocab_size=vocab_size 236 | ) 237 | model.to(device) 238 | return id_to_token, token_to_id, vocab_size, word_emb_size, model 239 | 240 | 241 | def load_model(model_path): 242 | with open(model_path, 'rb') as f: 243 | state = torch.load(f) 244 | 245 | id_to_token = state['id_to_token'] 246 | word_emb_size = state['word_emb_size'] 247 | 248 | token_to_id = {token: id for id, token in enumerate(id_to_token)} 249 | vocab_size = len(id_to_token) 250 | 251 | mvc_encoder = MultiviewEncoders( 252 | num_layers=train.LSTM_LAYER, 253 | embedding_size=word_emb_size, 254 | lstm_hidden_size=train.LSTM_HIDDEN, 255 | word_dropout=train.WORD_DROPOUT_RATE, 256 | dropout=train.DROPOUT_RATE, 257 | vocab_size=vocab_size 258 | ) 259 | mvc_encoder.to(device) 260 | mvc_encoder.load_state_dict(state['model_state']) 261 | return id_to_token, token_to_id, vocab_size, word_emb_size, mvc_encoder 262 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 5 | 6 | 7 | def pad_sentences( 8 | sents, lpad=None, rpad=None, reverse=False, pad_idx=0, 9 | max_sent_len=100): 10 | sentences = [] 11 | max_len = 0 12 | for i in range(len(sents)): 13 | if len(sents[i]) > max_sent_len: 14 | sentences.append(sents[i][:max_sent_len]) 15 | else: 16 | sentences.append(sents[i]) 17 | max_len = max(max_len, len(sentences[i])) 18 | if reverse: 19 | sentences = [sentence[::-1] for sentence in sentences] 20 | if lpad is not None: 21 | sentences = [[lpad] + sentence for sentence in sentences] 22 | max_len += 1 23 | if rpad is not None: 24 | sentences = [sentence + [rpad] for sentence in sentences] 25 | max_len += 1 26 | lengths = [] 27 | for i in range(len(sentences)): 28 | lengths.append(len(sentences[i])) 29 | sentences[i] = sentences[i] + [pad_idx]*(max_len - 
len(sentences[i])) 30 | return (torch.LongTensor(sentences).to(device), 31 | torch.LongTensor(lengths).to(device)) 32 | 33 | 34 | def pad_paragraphs(paras, pad_idx=0): 35 | sentences, lengths = [], [] 36 | max_len = 0 37 | for para in paras: 38 | max_len = max(max_len, len(para)) 39 | for para in paras: 40 | for sent in para: 41 | sentences.append(sent[:]) 42 | lengths.append(len(para)) 43 | for i in range(max_len - len(para)): 44 | sentences.append([pad_idx]) 45 | ret_sents, sent_lens = pad_sentences(sentences, pad_idx=pad_idx) 46 | return ret_sents, sent_lens, torch.LongTensor(lengths).to(device), max_len 47 | 48 | 49 | def euclidean_metric(a, b): 50 | n = a.shape[0] 51 | m = b.shape[0] 52 | a = a.unsqueeze(1).expand(n, m, -1) 53 | b = b.unsqueeze(0).expand(n, m, -1) 54 | logits = -((a - b)**2).sum(dim=2) 55 | return logits 56 | -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/preprocessing/__init__.py -------------------------------------------------------------------------------- /preprocessing/askubuntu.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from nltk.tokenize import word_tokenize 4 | from nltk.corpus import stopwords 5 | from bs4 import BeautifulSoup 6 | 7 | 8 | REMOVE_STOP = False 9 | QUESTION_WORDS = ["what", "when", "where", "why", "how", "who"] 10 | STOPWORDS = set(stopwords.words("english")) 11 | for w in QUESTION_WORDS: 12 | STOPWORDS.remove(w) 13 | 14 | # replace URL with a special token *=URL=* 15 | URL_PATTERN = r'(?:(?:https?|ftp)://)(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+' 16 | 17 | URLREG = re.compile(URL_PATTERN) 18 | 19 | 20 | def inner_preprocess_text_from_tao(text, max_len): 21 | # remove non-ascii 
chars 22 | text = text.encode("utf8").decode("ascii", "ignore") 23 | 24 | # remove URLs 25 | text = URLREG.sub('*=URL=*', text) 26 | 27 | text = text.casefold() 28 | 29 | # tokenize, filter and truncate 30 | words = word_tokenize(text) 31 | if REMOVE_STOP: 32 | words = [x for x in words if x not in STOPWORDS] 33 | if max_len > 0: 34 | words = words[:max_len] 35 | return u' '.join(words) 36 | 37 | 38 | def preprocess_from_tao(text, max_len): 39 | """ 40 | adapted from Tao's code for http://aclweb.org/anthology/D18-1131 41 | """ 42 | body_soup = BeautifulSoup(text, "lxml") 43 | 44 | # remove code 45 | [x.extract() for x in body_soup.findAll('code')] 46 | 47 | # also remove pre 48 | [x.extract() for x in body_soup.findAll('pre')] 49 | 50 | # remove "Possible Duplicate" section 51 | blk = body_soup.blockquote 52 | if blk and blk.text.strip().startswith("Possible Duplicate:"): 53 | blk.decompose() 54 | body_cleaned = inner_preprocess_text_from_tao(body_soup.text, max_len=max_len) 55 | assert "Possible Duplicate:" not in body_cleaned 56 | 57 | assert "\n" not in body_cleaned 58 | return body_cleaned 59 | 60 | 61 | class Preprocessor(object): 62 | def __init__(self, max_len): 63 | self.max_len = max_len 64 | 65 | def __call__(self, text): 66 | return True, preprocess_from_tao(text, max_len=self.max_len) 67 | -------------------------------------------------------------------------------- /preprocessing/twitter_airlines.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import nltk 4 | 5 | 6 | class Preprocessor(object): 7 | def __init__(self): 8 | self.twitter_at_re = re.compile('@[a-zA-Z1-90]+') 9 | self.initials_re = re.compile('\^[A-Z][A-Z]([^A-Za-z]|$)') 10 | self.star_initials_re = re.compile('\*[A-Z]+$') 11 | self.money_re = re.compile('\$[1-90]+([-.][1-90]+)?') 12 | self.number_re = re.compile('[1-90]+([-.: ()][1-90]+)?') 13 | self.airport_code_re = re.compile('([^a-zA-Z])[A-Z][A-Z][A-Z]([^a-zA-Z])') 14 | 
self.tag_re = re.compile('([^a-zA-Z])#[a-zA-Z]+') 15 | self.url_re = re.compile('http[s]?:(//)?(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+') # from mleng data_preprocess.py 16 | 17 | def __call__(self, text, info_dict=None): 18 | is_valid = True 19 | if 'DM' in text: 20 | is_valid = False 21 | if 'https://t.co' in text: 22 | is_valid = False 23 | 24 | target_company = info_dict['target_company'] 25 | cust_twitter_id = info_dict['cust_twitter_id'] 26 | text = text.replace('\n', ' ').encode('ascii', 'ignore').decode('ascii') 27 | text = text.replace('@' + target_company, ' __company__ ').replace('@' + cust_twitter_id, ' __cust__ ') 28 | text = text.replace('&lt;', '<').replace('&gt;', '>').replace('&amp;', '&') 29 | text = self.twitter_at_re.sub('__twitter_at__', text) 30 | text = self.initials_re.sub('__initials__', text) 31 | text = self.star_initials_re.sub('__initials__', text) 32 | text = self.money_re.sub(' __money__ ', text) 33 | text = self.number_re.sub(' __num__ ', text) 34 | text = self.airport_code_re.sub('\g<1> __airport_code__ \g<2>', text) 35 | text = self.tag_re.sub('\g<1> __twitter_tag__ ', text) 36 | text = text.casefold() 37 | text = self.url_re.sub(' __url__ ', text) 38 | text = ' '.join(nltk.word_tokenize(text)) 39 | 40 | return is_valid, text 41 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | import train as train_mod 5 | 6 | 7 | def pretrain_qt(dataset, perm_idx, expressions, train=True): 8 | """ 9 | for each pair of utterances: 10 | - encodes first utterance using 'v1' encoder 11 | - encodes second utterance using 'qt_context' encoder 12 | uses negative sampling loss between these two embeddings, relative to the 13 | other second utterances in the batch 14 | """ 15 | model, optimizer = expressions 16 | 17 | utts = [] 18 | qt_ex = [] 19 | for idx in
perm_idx: 20 | v1, v2 = dataset[idx][0] 21 | conversation = [v1] + v2 22 | for n, utt in enumerate(conversation): 23 | utts.append(utt) 24 | if n > 0: 25 | num_utt = len(utts) 26 | ex = (num_utt - 2, num_utt - 1) 27 | qt_ex.append(ex) 28 | qt_ex = np.random.permutation(qt_ex) 29 | 30 | total_loss, total_acc = 0., 0. 31 | n_batch = (len(qt_ex) + train_mod.BATCH_SIZE - 1) // train_mod.BATCH_SIZE 32 | for i in range(n_batch): 33 | qt_ex_batch = qt_ex[i*train_mod.BATCH_SIZE:(i+1)*train_mod.BATCH_SIZE] 34 | 35 | v1_idxes, v2_idxes = list(zip(*[(ex[0].item(), ex[1].item()) for ex in qt_ex_batch])) 36 | v1_utts = [utts[idx] for idx in v1_idxes] 37 | v2_utts = [utts[idx] for idx in v2_idxes] 38 | 39 | v1_state = model(v1_utts, encoder='v1') 40 | v2_state = model(v2_utts, encoder='qt') 41 | 42 | loss, acc = model.qt_loss(v2_state, v1_state) 43 | total_loss += loss.item() 44 | total_acc += acc * len(qt_ex_batch) 45 | if train: 46 | optimizer.zero_grad() 47 | loss.backward() 48 | optimizer.step() 49 | return total_loss, total_acc / len(qt_ex) 50 | 51 | 52 | def after_pretrain_qt(model): 53 | model.view2_word_rnn = copy.deepcopy(model.view1_word_rnn) 54 | 55 | 56 | def pretrain_ae(dataset, perm_idx, expressions, train=True): 57 | """ 58 | uses v1 encoder to encode all utterances in both view1 and view2 59 | to utterance-level embeddings 60 | uses 'ae_decoder' rnn from model to decode these embeddings 61 | (works at utterance level) 62 | """ 63 | model, optimizer = expressions 64 | 65 | utterances = [] 66 | for idx in perm_idx: 67 | v1, v2 = dataset[idx][0] 68 | conversation = [v1] + v2 69 | utterances += conversation 70 | utterances = np.random.permutation(utterances) 71 | 72 | total_loss, total_acc = 0., 0. 
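The `model.qt_loss(v2_state, v1_state)` call in `pretrain_qt` above computes a negative-sampling objective in which the other utterances in the batch serve as negative candidates. A minimal numpy sketch of that objective, assuming a dot-product scorer (`qt_loss_sketch` is illustrative only; the repo's actual loss lives in `model/multiview_encoders.py` and may differ in detail):

```python
import numpy as np

def qt_loss_sketch(context, candidates):
    # scores[i, j] = dot(context_i, candidate_j); the diagonal holds the
    # true (adjacent-utterance) pairs, every off-diagonal entry acts as an
    # in-batch negative, so the loss is a row-wise softmax cross-entropy
    scores = context @ candidates.T
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    n = context.shape[0]
    loss = -log_probs[np.arange(n), np.arange(n)].mean()
    acc = float((scores.argmax(axis=1) == np.arange(n)).mean())
    return loss, acc
```

With well-separated embeddings the diagonal dominates, so the accuracy approaches 1 and the loss approaches 0.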
73 | n_batch = (len(utterances) + train_mod.AE_BATCH_SIZE - 1) // train_mod.AE_BATCH_SIZE 74 | for i in range(n_batch): 75 | utt_batch = utterances[i*train_mod.AE_BATCH_SIZE:(i+1)*train_mod.AE_BATCH_SIZE] 76 | enc_state = model(utt_batch, encoder='v1') 77 | reconst = model.decode(decoder_input=utt_batch, latent_z=enc_state) 78 | loss, acc = model.reconst_loss(utt_batch, reconst) 79 | 80 | total_loss += loss.item() 81 | total_acc += acc * len(utt_batch) 82 | if train: 83 | optimizer.zero_grad() 84 | loss.backward() 85 | optimizer.step() 86 | total_acc = total_acc / len(utterances) 87 | return total_loss, total_acc 88 | 89 | 90 | def after_pretrain_ae(model): 91 | # we'll use the view1 encoder for both view 1 and view 2 92 | model.view2_word_rnn = copy.deepcopy(model.view1_word_rnn) 93 | -------------------------------------------------------------------------------- /proc_data.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from torch.utils.data import Dataset 3 | from nltk.tokenize import word_tokenize, sent_tokenize 4 | import numpy as np 5 | np.random.seed(0) 6 | 7 | PAD = "__PAD__" 8 | UNK = "__UNK__" 9 | START = "__START__" 10 | END = "__END__" 11 | 12 | 13 | class Dataset(Dataset): 14 | def __init__(self, fname, view1_col='view1_col', view2_col='view2_col', label_col='cluster_id', 15 | tokenized=True, max_sent=10, train_ratio=.9): 16 | """ 17 | Args: 18 | fname: str, training data file 19 | view1_col: str, the column corresponding to view 1 input 20 | view2_col: str, the column corresponding to view 2 input 21 | label_col: str, the column corresponding to label 22 | """ 23 | 24 | def tokens_to_idices(tokens): 25 | token_idices = [] 26 | for token in tokens: 27 | if token not in token_to_id: 28 | token_to_id[token] = len(token_to_id) 29 | id_to_token.append(token) 30 | token_idices.append(token_to_id[token]) 31 | return token_idices 32 | 33 | id_to_token = [PAD, UNK, START, END] 34 | token_to_id = {PAD: 0, 
UNK: 1, START: 2, END: 3} 35 | id_to_label = [UNK] 36 | label_to_id = {UNK: 0} 37 | data = [] 38 | labels = [] 39 | v1_utts = [] # needed for displaying cluster samples 40 | self.trn_idx, self.tst_idx = [], [] 41 | self.trn_idx_no_unk = [] 42 | with open(fname, 'r') as csvfile: 43 | reader = csv.DictReader(csvfile) 44 | for row in reader: 45 | view1_text, view2_text = row[view1_col], row[view2_col] 46 | label = row[label_col] 47 | if 'UNK' == label: 48 | label = UNK 49 | if ' <' not in view2_text: 50 | view2_sents = sent_tokenize(view2_text) 51 | else: 52 | view2_sents = view2_text.split("> <") 53 | for i in range(len(view2_sents) - 1): 54 | view2_sents[i] = view2_sents[i] + '>' 55 | view2_sents[i+1] = '<' + view2_sents[i + 1] 56 | v1_utts.append(view1_text) 57 | if not tokenized: 58 | v1_tokens = word_tokenize(view1_text.lower()) 59 | v2_tokens = [word_tokenize(sent.lower()) for sent in view2_sents] 60 | else: 61 | v1_tokens = view1_text.lower().split() 62 | v2_tokens = [sent.lower().split() for sent in view2_sents] 63 | v2_tokens = v2_tokens[:max_sent] 64 | 65 | v1_token_idices = tokens_to_idices(v1_tokens) 66 | v2_token_idices = [tokens_to_idices(tokens) for tokens in v2_tokens] 67 | v2_token_idices = [idices for idices in v2_token_idices if len(idices) > 0] 68 | if len(v1_token_idices) == 0 or len(v2_token_idices) == 0: 69 | continue 70 | if label not in label_to_id: 71 | label_to_id[label] = len(label_to_id) 72 | id_to_label.append(label) 73 | data.append((v1_token_idices, v2_token_idices)) 74 | labels.append(label_to_id[label]) 75 | if label == UNK and np.random.random_sample() < .1: 76 | self.tst_idx.append(len(data)-1) 77 | else: 78 | self.trn_idx.append(len(data)-1) 79 | if label != UNK: 80 | self.trn_idx_no_unk.append(len(data) - 1) 81 | 82 | self.v1_utts = v1_utts 83 | self.id_to_token = id_to_token 84 | self.token_to_id = token_to_id 85 | self.id_to_label = id_to_label 86 | self.label_to_id = label_to_id 87 | self.data = data 88 | self.labels = labels 89 | 90 | def __len__(self): 91 | return len(self.data) 92 | 93 | def __getitem__(self, i): 94 | return self.data[i],
self.labels[i] 95 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # we used nltk 3.4.5 for our experiments, 2 | # but that is showing a security vulnerability, 3 | # so changing this to 3.6.4 4 | # nltk==3.4.5 5 | nltk==3.6.4 6 | numpy==1.17.3 7 | scikit-learn==0.21.3 8 | scipy==1.3.1 9 | torch==1.3.0 10 | -------------------------------------------------------------------------------- /run_mvsc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Given a pre-trained model, run mvsc on it, and print scores vs gold standard 3 | 4 | we'll use view1 encoder to encode each of view1 and view2, and then pass that through the MVSC algorithm 5 | 6 | this should probably be folded into run_clustering.py (originally kind of forked from 7 | run_clustering.py, and combined with some things from train_pca.py and train.py) 8 | """ 9 | import time 10 | import random 11 | import datetime 12 | import argparse 13 | 14 | import sklearn.cluster 15 | import numpy as np 16 | import torch 17 | 18 | from metrics import cluster_metrics 19 | from model import multiview_encoders 20 | from proc_data import Dataset 21 | 22 | try: 23 | import multiview 24 | except Exception as e: 25 | print('please install https://github.com/mariceli3/multiview') 26 | print('eg pip install git+https://github.com/mariceli3/multiview') 27 | raise e 28 | 29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 30 | 31 | BATCH_SIZE = 32 32 | 33 | 34 | def transform(dataset, perm_idx, model, view): 35 | """ 36 | for view1 utterance, simply encode using view1 encoder 37 | for view 2 utterances: 38 | - encode each utterance, using view 1 encoder, to get utterance embeddings 39 | - take average of utterance embeddings to form view 2 embedding 40 | """ 41 | model.eval() 42 | latent_zs, golds = [], [] 43 | n_batch = (len(perm_idx)
+ BATCH_SIZE - 1) // BATCH_SIZE 44 | for i in range(n_batch): 45 | indices = perm_idx[i*BATCH_SIZE:(i+1)*BATCH_SIZE] 46 | v1_batch, v2_batch = list(zip(*[dataset[idx][0] for idx in indices])) 47 | golds += [dataset[idx][1] for idx in indices] 48 | if view == 'v1': 49 | latent_z = model(v1_batch, encoder='v1') 50 | elif view == 'v2': 51 | latent_z_l = [model(conv, encoder='v1').mean(dim=0) for conv in v2_batch] 52 | latent_z = torch.stack(latent_z_l) 53 | latent_zs.append(latent_z.cpu().data.numpy()) 54 | latent_zs = np.concatenate(latent_zs) 55 | return latent_zs, golds 56 | 57 | 58 | def run( 59 | ref, model_path, num_clusters, num_cluster_samples, seed, 60 | out_cluster_samples_file_hier, 61 | max_examples, out_cluster_samples_file, 62 | data_path, view1_col, view2_col, label_col, 63 | sampling_strategy, mvsc_no_unk): 64 | torch.manual_seed(seed) 65 | np.random.seed(seed) 66 | random.seed(seed) 67 | 68 | id_to_token, token_to_id, vocab_size, word_emb_size, mvc_encoder = \ 69 | multiview_encoders.load_model(model_path) 70 | print('loaded model') 71 | 72 | print('loading dataset') 73 | dataset = Dataset(data_path, view1_col=view1_col, view2_col=view2_col, label_col=label_col) 74 | n_cluster = len(dataset.id_to_label) - 1 75 | print("loaded dataset, num of class = %d" % n_cluster) 76 | 77 | idxes = dataset.trn_idx_no_unk if mvsc_no_unk else dataset.trn_idx 78 | trn_idx = [x.item() for x in np.random.permutation(idxes)] 79 | if max_examples is not None: 80 | trn_idx = trn_idx[:max_examples] 81 | 82 | num_clusters = n_cluster if num_clusters is None else num_clusters 83 | print('clustering over num clusters', num_clusters) 84 | 85 | mvsc = multiview.mvsc.MVSC( 86 | k=num_clusters 87 | ) 88 | latent_z1s, golds = transform(dataset, trn_idx, mvc_encoder, view='v1') 89 | latent_z2s, _ = transform(dataset, trn_idx, mvc_encoder, view='v2') 90 | print('running mvsc', end='', flush=True) 91 | start = time.time() 92 | preds, eivalues, eivectors, sigmas = mvsc.fit_transform( 93 |
[latent_z1s, latent_z2s], [False] * 2 94 | ) 95 | print('...done') 96 | mvsc_time = time.time() - start 97 | print('time taken %.3f' % mvsc_time) 98 | 99 | lgolds, lpreds = [], [] 100 | for g, p in zip(golds, list(preds)): 101 | if g > 0: 102 | lgolds.append(g) 103 | lpreds.append(p) 104 | prec, rec, f1 = cluster_metrics.calc_prec_rec_f1( 105 | gnd_assignments=torch.LongTensor(lgolds).to(device), 106 | pred_assignments=torch.LongTensor(lpreds).to(device)) 107 | acc = cluster_metrics.calc_ACC( 108 | torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device)) 109 | silhouette = sklearn.metrics.silhouette_score(latent_z1s, preds, metric='euclidean') 110 | davies_bouldin = sklearn.metrics.davies_bouldin_score(latent_z1s, preds) 111 | print(f'{datetime.datetime.now()} pretrain: eval prec={prec:.4f} rec={rec:.4f} f1={f1:.4f} ' 112 | f'acc={acc:.4f} sil={silhouette:.4f}, db={davies_bouldin:.4f}') 113 | 114 | 115 | if __name__ == '__main__': 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('--seed', type=int, default=123) 118 | parser.add_argument('--max-examples', type=int, 119 | help='since we might not want to cluster entire dataset?') 120 | parser.add_argument('--mvsc-no-unk', action='store_true', 121 | help='only feed non-unk data to MVSC (to avoid oom)') 122 | parser.add_argument('--ref', type=str, required=True) 123 | parser.add_argument('--model-path', type=str, required=True) 124 | parser.add_argument('--data-path', type=str, default='./data/airlines_500_merged.csv') 125 | parser.add_argument('--view1-col', type=str, default='view1_col') 126 | parser.add_argument('--view2-col', type=str, default='view2_col') 127 | parser.add_argument('--label-col', type=str, default='cluster_id') 128 | parser.add_argument('--num-clusters', type=int, help='defaults to number of supervised classes') 129 | parser.add_argument('--num-cluster-samples', type=int, default=10) 130 | parser.add_argument('--sampling-strategy', type=str, 131 | 
choices=['uniform', 'nearest'], default='nearest') 132 | parser.add_argument('--out-cluster-samples-file-hier', type=str, 133 | default='tmp/cluster_samples_hier_{ref}.txt') 134 | parser.add_argument('--out-cluster-samples-file', type=str, 135 | default='tmp/cluster_samples_{ref}.txt') 136 | args = parser.parse_args() 137 | args.out_cluster_samples_file = args.out_cluster_samples_file.format(**args.__dict__) 138 | args.out_cluster_samples_file_hier = args.out_cluster_samples_file_hier.format(**args.__dict__) 139 | run(**args.__dict__) 140 | -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | np.random.seed(0) 4 | 5 | 6 | class CategoriesSampler(): 7 | def __init__(self, labels, n_batch, n_cls, n_ins): 8 | """ 9 | Args: 10 | labels: size=(dataset_size), label indices of instances in a data set 11 | n_batch: int, number of batches for episode training 12 | n_cls: int, number of sampled classes 13 | n_ins: int, number of instances considered for a class 14 | 15 | conceptually, this is for prototypical sampling: 16 | - for each training step, we sample 'n_ways' classes (in the paper), which is 'n_cls' here 17 | - we draw 'n_shot' examples of each class, ie n_ins here 18 | - these will be encoded, and averaged, to get the prototypes 19 | - and we draw 'n_query' query examples, of each class, which is also 'n_ins' here 20 | - these will be encoded, and then used to generate the prototype loss, relative to the 21 | prototypes 22 | 23 | __iter__ returns a generator, which will yield n_batch sets of training data, one set of 24 | training data per yield 25 | 26 | """ 27 | if not isinstance(labels, list): 28 | labels = labels.tolist() 29 | 30 | self.n_batch = n_batch 31 | self.n_cls = n_cls 32 | self.n_ins = n_ins 33 | 34 | self.classes = list(set(labels)) 35 | labels = np.array(labels) 36 | self.cls_indices = {} 37
| for c in self.classes: 38 | indices = np.argwhere(labels == c).reshape(-1) 39 | self.cls_indices[c] = indices 40 | 41 | def __len__(self): 42 | return self.n_batch 43 | 44 | def __iter__(self): 45 | for _ in range(self.n_batch): 46 | batch = [] 47 | classes = np.random.permutation(self.classes)[:self.n_cls] 48 | for c in classes: 49 | indices = self.cls_indices[c] 50 | while len(indices) < self.n_ins: 51 | indices = np.concatenate((indices, indices)) 52 | batch.append(np.random.permutation(indices)[:self.n_ins]) 53 | batch = np.stack(batch).flatten('F') 54 | yield batch 55 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max_line_length = 100 3 | -------------------------------------------------------------------------------- /tests/test_multiview_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from model.multiview_encoders import MultiviewEncoders 5 | 6 | 7 | def test_ae_decoder(): 8 | dropout = word_dropout = 0 9 | embedding_size = 15 10 | num_layers = 1 11 | vocab_size = 13 12 | lstm_hidden_size = 11 13 | 14 | batch_size = 5 15 | seq_len = 6 16 | 17 | print('vocab_size', vocab_size) 18 | print('batch_size', batch_size) 19 | print('seq_len', seq_len) 20 | print('lstm_hidden_size', lstm_hidden_size) 21 | 22 | _input_idxes = torch.from_numpy(np.random.choice( 23 | vocab_size - 1, (batch_size, seq_len), replace=True)) + 1 24 | print('_input_idxes', _input_idxes) 25 | print('_input_idxes.size()', _input_idxes.size()) 26 | input_idxes = [] 27 | for b in range(batch_size): 28 | idxes = [_input_idxes[b][i].item() for i in range(seq_len)] 29 | input_idxes.append(idxes) 30 | latent_z = torch.rand((batch_size, lstm_hidden_size * 2)) 31 | print('latent_z.size()', latent_z.size()) 32 | print('input_idxes', input_idxes) 33 | 34 | encoders = 
MultiviewEncoders(vocab_size, num_layers, embedding_size, lstm_hidden_size, 35 | word_dropout, dropout) 36 | with torch.no_grad(): 37 | logits = encoders.decode(decoder_input=input_idxes, latent_z=latent_z) 38 | print('logits.size()', logits.size()) 39 | 40 | for n in range(batch_size): 41 | _input_idxes = input_idxes[n:n + 1] 42 | _latent_z = latent_z[n: n + 1] 43 | with torch.no_grad(): 44 | _logits = encoders.decode(decoder_input=_input_idxes, latent_z=_latent_z) 45 | diff = (logits[n] - _logits).abs().max().item() 46 | assert diff < 1e-7 47 | 48 | 49 | def test_reconst_loss(): 50 | dropout = word_dropout = 0 51 | embedding_size = 15 52 | num_layers = 1 53 | lstm_hidden_size = 11 54 | 55 | vocab_size = 13 56 | batch_size = 5 57 | seq_len = 6 58 | 59 | print('vocab_size', vocab_size) 60 | print('batch_size', batch_size) 61 | print('seq_len', seq_len) 62 | print('lstm_hidden_size', lstm_hidden_size) 63 | 64 | _input_idxes = torch.from_numpy(np.random.choice( 65 | vocab_size - 1, (batch_size, seq_len), replace=True)) + 1 66 | print('_input_idxes', _input_idxes) 67 | print('_input_idxes.size()', _input_idxes.size()) 68 | input_idxes = [] 69 | for b in range(batch_size): 70 | idxes = [_input_idxes[b][i].item() for i in range(seq_len)] 71 | input_idxes.append(idxes) 72 | 73 | encoders = MultiviewEncoders(vocab_size, num_layers, embedding_size, lstm_hidden_size, 74 | word_dropout, dropout) 75 | 76 | probs = torch.zeros(batch_size, seq_len + 1, vocab_size) 77 | probs[:, seq_len, encoders.end_idx] = 1 78 | for b in range(batch_size): 79 | for i, idx in enumerate(input_idxes[b]): 80 | probs[b, i, idx] = 1 81 | logits = probs.log() 82 | print('logits.sum(dim=-1)', logits.sum(dim=-1)) 83 | print('logits.min(dim=-1)', logits.min(dim=-1)[0]) 84 | print('logits.max(dim=-1)', logits.max(dim=-1)[0]) 85 | _, logits_max = logits.max(dim=-1) 86 | print('logits_max', logits_max) 87 | assert (logits_max[:, :seq_len] == _input_idxes).all() 88 | 89 | loss, acc = 
encoders.reconst_loss(input_idxes, logits) 90 | loss = loss.item() 91 | print('loss', loss, 'acc', acc) 92 | assert acc == 1.0 93 | assert loss == 0.0 94 | -------------------------------------------------------------------------------- /tests/test_samplers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | 4 | import samplers 5 | 6 | 7 | def test_categories_sampler(): 8 | N = 50 9 | n_batch = 5 10 | n_cls = 3 11 | n_ins = 4 12 | 13 | np.random.seed(123) 14 | labels = np.random.choice(5, N, replace=True) 15 | print('labels', labels) 16 | 17 | sampler = samplers.CategoriesSampler(labels, n_batch, n_cls, n_ins) 18 | for b, batch in enumerate(sampler): 19 | print('batch', batch) 20 | classes = [] 21 | count_by_class = defaultdict(int) 22 | for i in batch: 23 | classes.append(labels[i]) 24 | count_by_class[labels[i]] += 1 25 | print('classes', classes) 26 | print('count_by_class', count_by_class) 27 | assert len(count_by_class) == n_cls 28 | for v in count_by_class.values(): 29 | assert v == n_ins 30 | assert b == n_batch - 1 31 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import copy 4 | import numpy as np 5 | from os.path import expanduser as expand 6 | import sklearn.cluster 7 | import sklearn.metrics 8 | import warnings 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from proc_data import Dataset 14 | from model import multiview_encoders 15 | from metrics import cluster_metrics 16 | from samplers import CategoriesSampler 17 | from model import utils 18 | import pretrain 19 | 20 | 21 | warnings.filterwarnings(action='ignore', category=RuntimeWarning) 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | torch.manual_seed(0) 24 | 25 | LSTM_LAYER = 1 
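The episodic loop in `do_pass` below averages support embeddings into per-class prototypes and scores queries against them with `utils.euclidean_metric`. A minimal numpy sketch of that scoring step (the helper names `euclidean_logits` and `episode_logits` are hypothetical; the real code operates on torch tensors):

```python
import numpy as np

def euclidean_logits(queries, protos):
    # mirrors model/utils.euclidean_metric: a logit is the negative squared
    # Euclidean distance between a query embedding and a class prototype
    diff = queries[:, None, :] - protos[None, :, :]
    return -(diff ** 2).sum(axis=2)

def episode_logits(data_shot, data_query, shot, way):
    # as in do_pass: the support block is reshaped to (shot, way, dim) and
    # averaged over the shot axis to produce one prototype per class
    protos = data_shot.reshape(shot, way, -1).mean(axis=0)
    return euclidean_logits(data_query, protos)
```

Each query's predicted class is then simply the argmax over the `way` prototype logits.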
26 | LSTM_HIDDEN = 300 27 | WORD_DROPOUT_RATE = 0. 28 | DROPOUT_RATE = 0. 29 | BATCH_SIZE = 32 30 | AE_BATCH_SIZE = 8 31 | LEARNING_RATE = 0.001 32 | 33 | 34 | def do_pass(batches, shot, way, query, expressions, encoder): 35 | model, optimizer = expressions 36 | model.train() 37 | for i, batch in enumerate(batches, 1): 38 | data = [x[{'v1': 0, 'v2': 1}[encoder]] for x, _ in batch] 39 | p = shot * way 40 | data_shot, data_query = data[:p], data[p:] 41 | 42 | proto = model(data_shot, encoder=encoder) 43 | proto = proto.reshape(shot, way, -1).mean(dim=0) 44 | 45 | # ignore original labels, reassign labels from 0 46 | label = torch.arange(way).repeat(query) 47 | label = label.type(torch.LongTensor).to(device) 48 | 49 | logits = utils.euclidean_metric(model(data_query, encoder=encoder), proto) 50 | loss = F.cross_entropy(logits, label) 51 | optimizer.zero_grad() 52 | 53 | loss.backward() 54 | optimizer.step() 55 | return loss.item() 56 | 57 | 58 | def transform(dataset, perm_idx, model, encoder): 59 | model.eval() 60 | latent_zs, golds = [], [] 61 | n_batch = (len(perm_idx) + BATCH_SIZE - 1) // BATCH_SIZE 62 | for i in range(n_batch): 63 | indices = perm_idx[i*BATCH_SIZE:(i+1)*BATCH_SIZE] 64 | v1_batch, v2_batch = list(zip(*[dataset[idx][0] for idx in indices])) 65 | golds += [dataset[idx][1] for idx in indices] 66 | latent_z = model({'v1': v1_batch, 'v2': v2_batch}[encoder], encoder=encoder) 67 | latent_zs.append(latent_z.cpu().data.numpy()) 68 | latent_zs = np.concatenate(latent_zs) 69 | return latent_zs, golds 70 | 71 | 72 | def calc_centroids(latent_zs, assignments, n_cluster): 73 | centroids = [] 74 | for i in range(n_cluster): 75 | idx = np.where(assignments == i)[0] 76 | mean = np.mean(latent_zs[idx], 0) 77 | centroids.append(mean) 78 | return np.stack(centroids) 79 | 80 | 81 | def run_one_side(model, optimizer, preds_left, pt_batch, way, shot, query, n_cluster, dataset, 82 | right_encoder_side): 83 | """ 84 | encoder_side should be 'v1' or 'v2'. 
It should match whichever view is 'right' here. 85 | """ 86 | loss = 0 87 | 88 | sampler = CategoriesSampler(preds_left, pt_batch, way, shot + query) 89 | train_batches = [[dataset[dataset.trn_idx[idx]] for idx in indices] for indices in sampler] 90 | loss += do_pass(train_batches, shot, way, query, (model, optimizer), encoder=right_encoder_side) 91 | 92 | z_right, _ = transform(dataset, dataset.trn_idx, model, encoder=right_encoder_side) 93 | centroids = calc_centroids(z_right, preds_left, n_cluster) 94 | kmeans = sklearn.cluster.KMeans( 95 | n_clusters=n_cluster, init=centroids, max_iter=10, verbose=0) 96 | preds_right = kmeans.fit_predict(z_right) 97 | 98 | tst_z_right, _ = transform(dataset, dataset.tst_idx, model, encoder=right_encoder_side) 99 | tst_preds_right = kmeans.predict(tst_z_right) 100 | 101 | return loss, preds_right, tst_preds_right 102 | 103 | 104 | def main(): 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('--data-path', type=str, default='./data/airlines_processed.csv') 107 | parser.add_argument('--glove-path', type=str, default='./data/glove.840B.300d.txt') 108 | parser.add_argument('--pre-model', type=str, choices=['ae', 'qt'], default='qt') 109 | parser.add_argument('--pre-epoch', type=int, default=0) 110 | parser.add_argument('--pt-batch', type=int, default=100) 111 | parser.add_argument('--model-path', type=str, help='path of pretrained model to load') 112 | parser.add_argument('--way', type=int, default=5) 113 | parser.add_argument('--num-epochs', type=int, default=100) 114 | parser.add_argument('--seed', type=int, default=0) 115 | 116 | parser.add_argument('--save-model-path', type=str) 117 | 118 | parser.add_argument('--view1-col', type=str, default='view1') 119 | parser.add_argument('--view2-col', type=str, default='view2') 120 | parser.add_argument('--label-col', type=str, default='label') 121 | args = parser.parse_args() 122 | 123 | np.random.seed(args.seed) 124 | 125 | print('loading dataset') 126 | dataset = 
Dataset( 127 | args.data_path, view1_col=args.view1_col, view2_col=args.view2_col, 128 | label_col=args.label_col) 129 | n_cluster = len(dataset.id_to_label) - 1 130 | print("num of class = %d" % n_cluster) 131 | 132 | if args.model_path is not None: 133 | id_to_token, token_to_id, vocab_size, word_emb_size, model = multiview_encoders.load_model( 134 | args.model_path) 135 | print('loaded model') 136 | else: 137 | id_to_token, token_to_id, vocab_size, word_emb_size, model = \ 138 | multiview_encoders.from_embeddings( 139 | args.glove_path, dataset.id_to_token, dataset.token_to_id) 140 | print('created randomly initialized model') 141 | print('vocab_size', vocab_size) 142 | 143 | optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 144 | expressions = (model, optimizer) 145 | 146 | pre_acc, pre_state, pre_state_epoch = 0., None, None 147 | pretrain_method = { 148 | 'ae': pretrain.pretrain_ae, 149 | 'qt': pretrain.pretrain_qt, 150 | }[args.pre_model] 151 | for epoch in range(1, args.pre_epoch + 1): 152 | model.train() 153 | perm_idx = np.random.permutation(dataset.trn_idx) 154 | trn_loss, _ = pretrain_method(dataset, perm_idx, expressions, train=True) 155 | model.eval() 156 | _, tst_acc = pretrain_method(dataset, dataset.tst_idx, expressions, train=False) 157 | if tst_acc > pre_acc: 158 | pre_state = copy.deepcopy(model.state_dict()) 159 | pre_acc = tst_acc 160 | pre_state_epoch = epoch 161 | print('{} epoch {}, train_loss={:.4f} test_acc={:.4f}'.format( 162 | datetime.datetime.now(), epoch, trn_loss, tst_acc)) 163 | 164 | if args.pre_epoch > 0: 165 | # load best state 166 | model.load_state_dict(pre_state) 167 | print(f'loaded best state from epoch {pre_state_epoch}') 168 | 169 | # deepcopy pretrained views into v1 and/or view2 170 | { 171 | 'ae': pretrain.after_pretrain_ae, 172 | 'qt': pretrain.after_pretrain_qt, 173 | }[args.pre_model](model) 174 | 175 | # reinitialize optimizer 176 | optimizer = torch.optim.Adam(model.parameters(),
lr=LEARNING_RATE) 177 | expressions = (model, optimizer) 178 | print('applied post-pretraining') 179 | 180 | kmeans = sklearn.cluster.KMeans(n_clusters=n_cluster, max_iter=300, verbose=0, random_state=0) 181 | z_v1, golds = transform(dataset, dataset.trn_idx, model, encoder='v1') 182 | preds_v1 = kmeans.fit_predict(z_v1) 183 | 184 | lgolds, lpreds = [], [] 185 | for g, p in zip(golds, list(preds_v1)): 186 | if g > 0: 187 | lgolds.append(g) 188 | lpreds.append(p) 189 | prec, rec, f1 = cluster_metrics.calc_prec_rec_f1( 190 | gnd_assignments=torch.LongTensor(lgolds).to(device), 191 | pred_assignments=torch.LongTensor(lpreds).to(device)) 192 | acc = cluster_metrics.calc_ACC( 193 | torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device)) 194 | 195 | print(f'{datetime.datetime.now()} pretrain: test prec={prec:.4f} rec={rec:.4f} ' 196 | f'f1={f1:.4f} acc={acc:.4f}') 197 | 198 | shot, way, query = 5, args.way, 15 199 | 200 | preds_v2 = None 201 | best_epoch, best_model, best_dev_f1 = None, None, None 202 | for epoch in range(1, args.num_epochs + 1): 203 | trn_loss = 0. 
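`run_one_side`, called twice per epoch below, transfers cluster structure across views: centroids are computed from the *other* view's assignments (`calc_centroids`) and passed as `init` to `sklearn.cluster.KMeans` on the current view. A small numpy sketch of that warm-start step, with hypothetical helper names standing in for the repo's torch/sklearn pipeline:

```python
import numpy as np

def centroids_from_other_view(z, assignments, n_cluster):
    # calc_centroids in miniature: the mean embedding of each cluster,
    # used to warm-start k-means on the current view
    return np.stack([z[assignments == c].mean(axis=0) for c in range(n_cluster)])

def assign_to_centroids(z, centroids):
    # one hard-assignment step: each point goes to its nearest centroid
    d = ((z[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return d.argmin(axis=1)
```

Seeding k-means with the other view's centroids is what keeps the two views' cluster indices aligned across the alternating updates.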
204 | 
205 |         _loss, preds_v2, tst_preds_v2 = run_one_side(
206 |             model=model, optimizer=optimizer, preds_left=preds_v1,
207 |             pt_batch=args.pt_batch, way=way, shot=shot, query=query, n_cluster=n_cluster,
208 |             dataset=dataset, right_encoder_side='v2')
209 |         trn_loss += _loss
210 | 
211 |         _loss, preds_v1, tst_preds_v1 = run_one_side(
212 |             model=model, optimizer=optimizer, preds_left=preds_v2,
213 |             pt_batch=args.pt_batch, way=way, shot=shot, query=query, n_cluster=n_cluster,
214 |             dataset=dataset, right_encoder_side='v1')
215 |         trn_loss += _loss
216 | 
217 |         dev_f1 = cluster_metrics.calc_f1(gnd_assignments=torch.LongTensor(tst_preds_v1).to(device),
218 |                                          pred_assignments=torch.LongTensor(tst_preds_v2).to(device))
219 |         dev_acc = cluster_metrics.calc_ACC(
220 |             torch.LongTensor(tst_preds_v2).to(device), torch.LongTensor(tst_preds_v1).to(device))
221 | 
222 |         print('dev view 1 vs view 2: f1={:.4f} acc={:.4f}'.format(dev_f1, dev_acc))
223 | 
224 |         if best_dev_f1 is None or dev_f1 > best_dev_f1:
225 |             print('new best epoch', epoch)
226 |             best_epoch = epoch
227 |             best_dev_f1 = dev_f1
228 |             best_model = copy.deepcopy(model.state_dict())
229 |             best_preds_v1 = preds_v1.copy()
230 |             best_preds_v2 = preds_v2.copy()
231 | 
232 |         lgolds, lpreds = [], []
233 |         for g, p in zip(golds, list(preds_v1)):
234 |             if g > 0:
235 |                 lgolds.append(g)
236 |                 lpreds.append(p)
237 |         prec, rec, f1 = cluster_metrics.calc_prec_rec_f1(
238 |             gnd_assignments=torch.LongTensor(lgolds).to(device),
239 |             pred_assignments=torch.LongTensor(lpreds).to(device))
240 |         acc = cluster_metrics.calc_ACC(
241 |             torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device))
242 | 
243 |         print(f'{datetime.datetime.now()} epoch {epoch}, test prec={prec:.4f} rec={rec:.4f} '
244 |               f'f1={f1:.4f} acc={acc:.4f}')
245 | 
246 |     print('restoring model for best dev epoch', best_epoch)
247 |     model.load_state_dict(best_model)
248 |     preds_v1, preds_v2 = best_preds_v1, best_preds_v2
249 | 
250 |     lgolds, lpreds = [], []
251 |     for g, p in zip(golds, list(preds_v1)):
252 |         if g > 0:
253 |             lgolds.append(g)
254 |             lpreds.append(p)
255 |     prec, rec, f1 = cluster_metrics.calc_prec_rec_f1(
256 |         gnd_assignments=torch.LongTensor(lgolds).to(device),
257 |         pred_assignments=torch.LongTensor(lpreds).to(device))
258 |     acc = cluster_metrics.calc_ACC(
259 |         torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device))
260 |     print(f'{datetime.datetime.now()} test prec={prec:.4f} rec={rec:.4f} f1={f1:.4f} acc={acc:.4f}')
261 | 
262 |     if args.save_model_path is not None:
263 |         preds_v1 = torch.from_numpy(preds_v1)
264 |         if preds_v2 is not None:
265 |             preds_v2 = torch.from_numpy(preds_v2)
266 |         state = {
267 |             'model_state': model.state_dict(),
268 |             'id_to_token': dataset.id_to_token,
269 |             'word_emb_size': word_emb_size,
270 |             'v1_assignments': preds_v1,
271 |             'v2_assignments': preds_v2
272 |         }
273 |         with open(expand(args.save_model_path), 'wb') as f:
274 |             torch.save(state, f)
275 |         print('saved model to', args.save_model_path)
276 | 
277 | 
278 | if __name__ == '__main__':
279 |     main()
280 | 
--------------------------------------------------------------------------------
/train_pca.py:
--------------------------------------------------------------------------------
1 | """
2 | train PCA baseline
3 | 
4 | forked from train.py, with the MVC code stripped out and replaced by BoW + PCA
5 | 
6 | if using mvsc, requires installation of:
7 | - https://github.com/mariceli3/multiview
8 | """
9 | import argparse
10 | import datetime
11 | import numpy as np
12 | import sklearn.cluster
13 | import warnings
14 | import time
15 | from sklearn.feature_extraction.text import TfidfVectorizer
16 | from sklearn.decomposition import TruncatedSVD
17 | 
18 | import torch
19 | 
20 | from proc_data import Dataset
21 | from metrics import cluster_metrics
22 | 
23 | 
24 | warnings.filterwarnings(action='ignore', category=RuntimeWarning)
25 | np.random.seed(0)
26 | torch.manual_seed(0)
27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 | 
29 | 
30 | def main():
31 |     parser = argparse.ArgumentParser()
32 |     parser.add_argument('--data-path', type=str, default='./data/airlines_processed.csv')
33 |     parser.add_argument('--model', type=str, choices=[
34 |         'view1pca', 'view2pca', 'wholeconvpca', 'mvsc'], default='view1pca')
35 |     parser.add_argument('--pca-dims', type=int, default=600)
36 | 
37 |     parser.add_argument('--no-idf', action='store_true')
38 |     parser.add_argument('--mvsc-no-unk', action='store_true',
39 |                         help='only feed non-unk data to MVSC (to avoid oom)')
40 | 
41 |     parser.add_argument('--view1-col', type=str, default='view1')
42 |     parser.add_argument('--view2-col', type=str, default='view2')
43 |     parser.add_argument('--label-col', type=str, default='tag')
44 |     args = parser.parse_args()
45 |     run(**args.__dict__)
46 | 
47 | 
48 | def run(data_path, model, pca_dims, view1_col, view2_col, label_col, no_idf, mvsc_no_unk):
49 |     print('loading dataset')
50 |     dataset = Dataset(data_path, view1_col=view1_col, view2_col=view2_col, label_col=label_col)
51 |     n_cluster = len(dataset.id_to_label) - 1
52 |     print("num of class = %d" % n_cluster)
53 | 
54 |     vocab_size = len(dataset.token_to_id)
55 |     print('vocab_size', vocab_size)
56 | 
57 |     if model == 'mvsc':
58 |         try:
59 |             import multiview
60 |         except ImportError:
61 |             print('please install https://github.com/mariceli3/multiview')
62 |             return
63 |         print('imported multiview ok')
64 | 
65 |     def run_pca(features):
66 |         print('fitting tfidf vectorizer', flush=True, end='')
67 |         vectorizer = TfidfVectorizer(token_pattern='\\d+', ngram_range=(1, 1), analyzer='word',
68 |                                      min_df=0.0, max_df=1.0, use_idf=not no_idf)
69 |         X = vectorizer.fit_transform(features)
70 |         print(' ... done')
71 |         print('X.shape', X.shape)
72 | 
73 |         print('running pca', flush=True, end='')
74 |         pca = TruncatedSVD(n_components=pca_dims)
75 |         X2 = pca.fit_transform(X)
76 |         print(' ... done')
77 |         return X2
78 | 
79 |     golds = [dataset[idx][1] for idx in dataset.trn_idx]
80 | 
81 |     if model in ['view1pca', 'view2pca', 'wholeconvpca']:
82 |         if model == 'view1pca':
83 |             utts = [dataset[idx][0][0] for idx in dataset.trn_idx]
84 |             utts = [' '.join([str(idx) for idx in utt]) for utt in utts]
85 |         elif model == 'view2pca':
86 |             convs = [dataset[idx][0][1] for idx in dataset.trn_idx]
87 |             utts = [[tok for utt in conv for tok in utt] for conv in convs]
88 |             utts = [' '.join([str(idx) for idx in utt]) for utt in utts]
89 |         elif model == 'wholeconvpca':
90 |             v1 = [dataset[idx][0][0] for idx in dataset.trn_idx]
91 |             convs = [dataset[idx][0][1] for idx in dataset.trn_idx]
92 |             v2 = [[tok for utt in conv for tok in utt] for conv in convs]
93 |             utts = []
94 |             for n in range(len(v1)):
95 |                 utts.append(v1[n] + v2[n])
96 |             utts = [' '.join([str(idx) for idx in utt]) for utt in utts]
97 | 
98 |         X2 = run_pca(utts)
99 | 
100 |         print('running kmeans', flush=True, end='')
101 |         kmeans = sklearn.cluster.KMeans(
102 |             n_clusters=n_cluster, max_iter=300, verbose=0, random_state=0)
103 |         preds = kmeans.fit_predict(X2)
104 |         print(' ... done')
105 |     elif model == 'mvsc':
106 |         mvsc = multiview.mvsc.MVSC(
107 |             k=n_cluster
108 |         )
109 |         idxes = dataset.trn_idx_no_unk if mvsc_no_unk else dataset.trn_idx
110 |         v1 = [dataset[idx][0][0] for idx in idxes]
111 |         convs = [dataset[idx][0][1] for idx in idxes]
112 |         v2 = [[tok for utt in conv for tok in utt] for conv in convs]
113 |         v1 = [' '.join([str(idx) for idx in utt]) for utt in v1]
114 |         v2 = [' '.join([str(idx) for idx in utt]) for utt in v2]
115 |         v1_pca = run_pca(v1)
116 |         v2_pca = run_pca(v2)
117 |         print('running mvsc', end='', flush=True)
118 |         start = time.time()
119 |         preds, eivalues, eivectors, sigmas = mvsc.fit_transform(
120 |             [v1_pca, v2_pca], [False] * 2
121 |         )
122 |         print('...done')
123 |         mvsc_time = time.time() - start
124 |         print('time taken %.3f' % mvsc_time)
125 | 
126 |     lgolds, lpreds = [], []
127 |     for g, p in zip(golds, list(preds)):
128 |         if g > 0:
129 |             lgolds.append(g)
130 |             lpreds.append(p)
131 |     prec, rec, f1 = cluster_metrics.calc_prec_rec_f1(
132 |         gnd_assignments=torch.LongTensor(lgolds).to(device),
133 |         pred_assignments=torch.LongTensor(lpreds).to(device))
134 |     acc = cluster_metrics.calc_ACC(
135 |         torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device))
136 | 
137 |     print(f'{datetime.datetime.now()} eval f1={f1:.4f} prec={prec:.4f} rec={rec:.4f} acc={acc:.4f}')
138 | 
139 |     return prec, rec, f1, acc
140 | 
141 | 
142 | if __name__ == '__main__':
143 |     main()
144 | 
--------------------------------------------------------------------------------
/train_qt.py:
--------------------------------------------------------------------------------
1 | """
2 | pretrain using qt, to get a neural representation, then run kmeans on various
3 | combinations of the resulting representations
4 | 
5 | this was forked initially from train.py, then modified
6 | """
7 | import argparse
8 | import datetime
9 | import copy
10 | import time
11 | import numpy as np
12 | import sklearn.cluster
13 | import warnings
14 | import torch
15 | from torch import autograd
16 | 
17 | from proc_data import Dataset
18 | from model.multiview_encoders import MultiviewEncoders
19 | from metrics import cluster_metrics
20 | import pretrain
21 | 
22 | 
23 | warnings.filterwarnings(action='ignore', category=RuntimeWarning)
24 | 
25 | torch.manual_seed(0)
26 | np.random.seed(0)
27 | 
28 | LSTM_LAYER = 1
29 | LSTM_HIDDEN = 300
30 | WORD_DROPOUT_RATE = 0.
31 | DROPOUT_RATE = 0.
32 | BATCH_SIZE = 32
33 | LEARNING_RATE = 0.001
34 | 
35 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
36 | 
37 | 
38 | def transform(data, model):
39 |     model.eval()
40 |     latent_zs = []
41 |     n_batch = (len(data) + BATCH_SIZE - 1) // BATCH_SIZE
42 |     for i in range(n_batch):
43 |         data_batch = data[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
44 |         with autograd.no_grad():
45 |             latent_z = model(data_batch, encoder='v1')
46 |         latent_zs.append(latent_z.cpu().data.numpy())
47 |     latent_zs = np.concatenate(latent_zs)
48 |     return latent_zs
49 | 
50 | 
51 | def calc_prec_rec_f1_acc(preds, golds):
52 |     lgolds, lpreds = [], []
53 |     for g, p in zip(golds, list(preds)):
54 |         if g > 0:
55 |             lgolds.append(g)
56 |             lpreds.append(p)
57 |     prec, rec, f1 = cluster_metrics.calc_prec_rec_f1(
58 |         gnd_assignments=torch.LongTensor(lgolds).to(device),
59 |         pred_assignments=torch.LongTensor(lpreds).to(device))
60 |     acc = cluster_metrics.calc_ACC(
61 |         torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device))
62 |     return prec, rec, f1, acc
63 | 
64 | 
65 | def main():
66 |     parser = argparse.ArgumentParser()
67 |     parser.add_argument('--data-path', type=str, default='./data/airlines_processed.csv')
68 |     parser.add_argument('--glove-path', type=str, default='./data/glove.840B.300d.txt')
69 |     parser.add_argument('--pre-epoch', type=int, default=5)
70 |     parser.add_argument('--pt-batch', type=int, default=100)
71 |     parser.add_argument('--scenarios', type=str, default='view1,view2,concatviews,wholeconv',
72 |                         help='comma-separated, from [view1|view2|concatviews|wholeconv|mvsc]')
73 |     parser.add_argument('--mvsc-no-unk', action='store_true',
74 |                         help='only feed non-unk data to MVSC (to avoid oom)')
75 | 
76 |     parser.add_argument('--view1-col', type=str, default='view1')
77 |     parser.add_argument('--view2-col', type=str, default='view2')
78 |     parser.add_argument('--label-col', type=str, default='tag')
79 |     args = parser.parse_args()
80 | 
81 |     print('loading dataset')
82 |     dataset = Dataset(args.data_path, view1_col=args.view1_col, view2_col=args.view2_col,
83 |                       label_col=args.label_col)
84 |     n_cluster = len(dataset.id_to_label) - 1
85 |     print("num of class = %d" % n_cluster)
86 | 
87 |     id_to_token, token_to_id = dataset.id_to_token, dataset.token_to_id
88 |     vocab_size = len(dataset.token_to_id)
89 |     print('vocab_size', vocab_size)
90 | 
91 |     # Load pre-trained GloVe vectors
92 |     pretrained = {}
93 |     word_emb_size = 0
94 |     print('loading glove')
95 |     for line in open(args.glove_path):
96 |         parts = line.strip().split()
97 |         if len(parts) % 100 != 1:
98 |             continue
99 |         word = parts[0]
100 |         if word not in token_to_id:
101 |             continue
102 |         vector = [float(v) for v in parts[1:]]
103 |         pretrained[word] = vector
104 |         word_emb_size = len(vector)
105 |     pretrained_list = []
106 |     scale = np.sqrt(3.0 / word_emb_size)
107 |     print('loading oov')
108 |     for word in id_to_token:
109 |         # apply lower() because all GloVe vectors are for lowercase words
110 |         if word.lower() in pretrained:
111 |             pretrained_list.append(np.array(pretrained[word.lower()]))
112 |         else:
113 |             random_vector = np.random.uniform(-scale, scale, [word_emb_size])
114 |             pretrained_list.append(random_vector)
115 | 
116 |     model = MultiviewEncoders.from_embeddings(
117 |         embeddings=torch.FloatTensor(pretrained_list),
118 |         num_layers=LSTM_LAYER,
119 |         embedding_size=word_emb_size,
120 |         lstm_hidden_size=LSTM_HIDDEN,
121 |         word_dropout=WORD_DROPOUT_RATE,
122 |         dropout=DROPOUT_RATE,
123 |         vocab_size=vocab_size
124 |     )
125 |     model.to(device)
126 |     optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
127 | 
128 |     expressions = (model, optimizer)
129 |     pre_acc, pre_state = 0., None
130 |     pretrain_method = pretrain.pretrain_qt
131 |     for epoch in range(1, args.pre_epoch + 1):
132 |         model.train()
133 |         perm_idx = np.random.permutation(dataset.trn_idx)
134 |         trn_loss, _ = pretrain_method(dataset, perm_idx, expressions, train=True)
135 |         model.eval()
136 |         _, tst_acc = pretrain_method(dataset, dataset.tst_idx, expressions, train=False)
137 |         if tst_acc > pre_acc:
138 |             pre_state = copy.deepcopy(model.state_dict())
139 |             pre_acc = tst_acc
140 |         print(f'{datetime.datetime.now()} epoch {epoch}, train_loss={trn_loss:.4f} '
141 |               f'test_acc={tst_acc:.4f}')
142 | 
143 |     if args.pre_epoch > 0:
144 |         # load best state
145 |         model.load_state_dict(pre_state)
146 | 
147 |     # deepcopy pretrained views into v1 and/or view2
148 |     pretrain.after_pretrain_qt(model)
149 | 
150 |     kmeans = sklearn.cluster.KMeans(n_clusters=n_cluster, max_iter=300, verbose=0, random_state=0)
151 | 
152 |     golds = [dataset[idx][1] for idx in dataset.trn_idx]
153 |     for rep in args.scenarios.split(','):
154 |         if rep == 'view1':
155 |             data = [dataset[idx][0][0] for idx in dataset.trn_idx]
156 |             encoded = transform(data=data, model=model)
157 |             preds = kmeans.fit_predict(encoded)
158 |         elif rep == 'view2':
159 |             data = [dataset[idx][0][1] for idx in dataset.trn_idx]
160 |             encoded = []
161 |             for conv in data:
162 |                 encoded_conv = transform(data=conv, model=model)
163 |                 encoded_conv = torch.from_numpy(encoded_conv)
164 |                 encoded_conv = encoded_conv.mean(dim=0)
165 |                 encoded.append(encoded_conv)
166 |             encoded = torch.stack(encoded, dim=0)
167 |             # print('encoded.size()', encoded.size())
168 |             encoded = encoded.numpy()
169 |             preds = kmeans.fit_predict(encoded)
170 |         elif rep == 'concatviews':
171 |             v1_data = [dataset[idx][0][0] for idx in dataset.trn_idx]
172 |             v1_encoded = torch.from_numpy(transform(data=v1_data, model=model))
173 | 
174 |             v2_data = [dataset[idx][0][1] for idx in dataset.trn_idx]
175 |             v2_encoded = []
176 |             for conv in v2_data:
177 |                 encoded_conv = transform(data=conv, model=model)
178 |                 encoded_conv = torch.from_numpy(encoded_conv)
179 |                 encoded_conv = encoded_conv.mean(dim=0)
180 |                 v2_encoded.append(encoded_conv)
181 |             v2_encoded = torch.stack(v2_encoded, dim=0)
182 |             concatview = torch.cat([v1_encoded, v2_encoded], dim=-1)
183 |             print('concatview.size()', concatview.size())
184 |             encoded = concatview.numpy()
185 |             preds = kmeans.fit_predict(encoded)
186 |         elif rep == 'wholeconv':
187 |             encoded = []
188 |             for idx in dataset.trn_idx:
189 |                 v1 = dataset[idx][0][0]
190 |                 v2 = dataset[idx][0][1]
191 |                 conv = [v1] + v2
192 |                 encoded_conv = transform(data=conv, model=model)
193 |                 encoded_conv = torch.from_numpy(encoded_conv)
194 |                 encoded_conv = encoded_conv.mean(dim=0)
195 |                 encoded.append(encoded_conv)
196 |             encoded = torch.stack(encoded, dim=0)
197 |             print('encoded.size()', encoded.size())
198 |             encoded = encoded.numpy()
199 |             preds = kmeans.fit_predict(encoded)
200 |         elif rep == 'mvsc':
201 |             try:
202 |                 import multiview
203 |             except ImportError:
204 |                 print('please install https://github.com/mariceli3/multiview')
205 |                 return
206 |             print('imported multiview ok')
207 | 
208 |             idxes = dataset.trn_idx_no_unk if args.mvsc_no_unk else dataset.trn_idx
209 |             v1_data = [dataset[idx][0][0] for idx in idxes]
210 |             v1_encoded = torch.from_numpy(transform(data=v1_data, model=model))
211 | 
212 |             v2_data = [dataset[idx][0][1] for idx in idxes]
213 |             v2_encoded = []
214 |             for conv in v2_data:
215 |                 encoded_conv = transform(data=conv, model=model)
216 |                 encoded_conv = torch.from_numpy(encoded_conv)
217 |                 encoded_conv = encoded_conv.mean(dim=0)
218 |                 v2_encoded.append(encoded_conv)
219 |             v2_encoded = torch.stack(v2_encoded, dim=0)
220 | 
221 |             mvsc = multiview.mvsc.MVSC(
222 |                 k=n_cluster
223 |             )
224 |             print('running mvsc', end='', flush=True)
225 |             start = time.time()
226 |             preds, eivalues, eivectors, sigmas = mvsc.fit_transform(
227 |                 [v1_encoded, v2_encoded], [False] * 2
228 |             )
229 |             print('...done')
230 |             mvsc_time = time.time() - start
231 |             print('time taken %.3f' % mvsc_time)
232 |         else:
233 |             raise Exception('unimplemented rep', rep)
234 | 
235 |         prec, rec, f1, acc = calc_prec_rec_f1_acc(preds, golds)
236 |         print(f'{datetime.datetime.now()} {rep}: eval prec={prec:.4f} rec={rec:.4f} f1={f1:.4f} '
237 |               f'acc={acc:.4f}')
238 | 
239 | 
240 | if __name__ == '__main__':
241 |     main()
242 | 
--------------------------------------------------------------------------------
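Editor's note (not part of the original sources): the BoW + PCA + k-means pipeline that `train_pca.py` applies to real conversations can be exercised end-to-end on toy data. The token-id strings and cluster count below are invented for illustration only, but they follow the `' '.join(str(idx) for idx in utt)` format that `run_pca` expects.

```python
import sklearn.cluster
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD

# Each "document" is a space-separated string of numeric token ids, as produced
# by the join step in train_pca.py. These four toy docs form two obvious groups.
docs = ['1 2 3', '1 2 4', '7 8 9', '7 8 10']

# token_pattern='\\d+' makes every digit run a token, so numeric ids survive
# the vectorizer's default word tokenization.
X = TfidfVectorizer(token_pattern='\\d+').fit_transform(docs)

# TruncatedSVD plays the role of PCA on the sparse tf-idf matrix
# (it avoids densifying, which matters at real vocabulary sizes).
X2 = TruncatedSVD(n_components=2).fit_transform(X)

preds = sklearn.cluster.KMeans(n_clusters=2, random_state=0, n_init=10).fit_predict(X2)
print(preds)  # docs sharing tokens land in the same cluster
```

`train_pca.py` does the same thing with `pca_dims=600` components and `n_cluster` derived from the dataset's label set.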
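A second illustrative note: all three scripts score only the labelled subset of the training conversations, skipping rows whose gold label is 0 (the "no label" id) before computing precision, recall, F1, and accuracy. A minimal sketch of that filtering step, with made-up label lists:

```python
# Gold label 0 marks an unlabelled conversation; it is excluded from scoring,
# mirroring the `if g > 0` loops in train.py, train_pca.py, and train_qt.py.
golds = [0, 2, 1, 0, 2]   # hypothetical gold cluster ids (0 = unlabelled)
preds = [5, 1, 0, 3, 1]   # hypothetical k-means assignments

lgolds, lpreds = [], []
for g, p in zip(golds, preds):
    if g > 0:
        lgolds.append(g)
        lpreds.append(p)

print(lgolds, lpreds)  # [2, 1, 2] [1, 0, 1]
```

Only the filtered pairs are handed to `cluster_metrics.calc_prec_rec_f1` and `calc_ACC`, so unlabelled rows influence clustering but never the reported metrics.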