├── .gitignore
├── LICENSE.txt
├── README.md
├── cmdlines
│   ├── airlines_ae.sh
│   ├── airlines_ae_mvc.sh
│   ├── airlines_mvc.sh
│   ├── airlines_none.sh
│   ├── airlines_qt.sh
│   ├── airlines_qt_mvc.sh
│   ├── ubuntu_ae.sh
│   ├── ubuntu_ae_mvc.sh
│   ├── ubuntu_mvc.sh
│   ├── ubuntu_none.sh
│   ├── ubuntu_qt.sh
│   └── ubuntu_qt_mvc.sh
├── data
│   ├── LICENSE
│   ├── README.md
│   ├── airlines_500onlyb.csv.bz2
│   ├── airlines_processed.csv.bz2
│   ├── airlines_raw.csv.bz2
│   ├── askubuntu_processed.csv.bz2
│   └── askubuntu_raw.csv.bz2
├── datasets
│   ├── askubuntu_preprocess.py
│   ├── labeled_unlabeled_merger.py
│   ├── readme.md
│   ├── twitter_airlines_raw_merger2.py
│   └── twitter_dataset_preprocess.py
├── images
│   ├── avkmeans_graph.png
│   └── example_dialogs.png
├── labeled_unlabeled_merger.py
├── metrics
│   ├── __init__.py
│   └── cluster_metrics.py
├── model
│   ├── __init__.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── multiview_encoders.py
│   └── utils.py
├── preprocessing
│   ├── __init__.py
│   ├── askubuntu.py
│   └── twitter_airlines.py
├── pretrain.py
├── proc_data.py
├── requirements.txt
├── run_mvsc.py
├── samplers.py
├── setup.cfg
├── tests
│   ├── test_multiview_encoders.py
│   └── test_samplers.py
├── train.py
├── train_pca.py
└── train_qt.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 ASAPP Inc
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Code and data for the paper ["Dialog Intent Induction with Deep Multi-View Clustering"](https://arxiv.org/abs/1908.11487), Hugh Perkins and Yi Yang, 2019, to appear in EMNLP 2019.
2 |
3 | Data is available in the sub-directory [data](data), with a specific [LICENSE](data/LICENSE) file.
4 |
5 | # Dialog Intent Induction
6 |
7 | Dialog intent induction aims at automatically discovering dialog intents from human-human conversations. The problem has been largely overlooked in prior academic work, creating a huge gap between academia and industry.
8 |
9 | In particular, **academic dialog** datasets such as ATIS and MultiWoZ assume the dialog intents are given, and they focus on simple intents like `BookRestaurant` or `BookHotel`. In **industrial settings**, however, many complex dialog intents emerge that are hard to predefine, and the set of intents changes dynamically.
10 |
11 | ## Deep multi-view clustering
12 |
13 | In this work, we propose to tackle this problem using multi-view clustering. Consider the following example dialogs:
14 |
15 |
16 | ![Example dialogs](images/example_dialogs.png)
17 | The user query utterances (query view) are lexically and syntactically dissimilar. However, the solution trajectories (content view) are similar.
18 |
19 | ### Alternating-view k-means (AV-KMeans)
20 |
21 | We propose a novel method for joint representation learning and multi-view clustering: alternating-view k-means (AV-KMeans).
22 |
23 |
24 | ![Alternating-view k-means](images/avkmeans_graph.png)
25 | We perform clustering on view 1 and project the assignments to view 2 for classification: the encoders are held fixed during clustering and updated during classification. The two views then swap roles and the procedure alternates; see the sketch below.
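
A minimal sketch of the alternation (illustrative only: the linear encoders, classification head, and optimizer below are invented for brevity; the actual model and training loop are in [train.py](train.py)):

```
# toy sketch of AV-KMeans: cluster one view, train the other view's
# encoder to predict those cluster assignments, then swap the views
import torch
from torch import nn

def kmeans(x, k, iters=10):
    # plain k-means on the encoded vectors; returns hard assignments
    centroids = x[torch.randperm(len(x))[:k]]
    for _ in range(iters):
        assign = torch.cdist(x, centroids).argmin(dim=1)
        for j in range(k):
            if (assign == j).any():
                centroids[j] = x[assign == j].mean(dim=0)
    return assign

def avkmeans_step(enc_a, enc_b, head_b, xa, xb, k, opt):
    with torch.no_grad():              # encoders fixed while clustering
        assign = kmeans(enc_a(xa), k)
    logits = head_b(enc_b(xb))         # view B classifies view A's clusters
    loss = nn.functional.cross_entropy(logits, assign)
    opt.zero_grad(); loss.backward(); opt.step()

d, h, k = 32, 16, 5
enc1, enc2 = nn.Linear(d, h), nn.Linear(d, h)
head1, head2 = nn.Linear(h, k), nn.Linear(h, k)
x1, x2 = torch.randn(200, d), torch.randn(200, d)   # paired view-1 / view-2 inputs
params = [p for m in (enc1, enc2, head1, head2) for p in m.parameters()]
opt = torch.optim.Adam(params)
for epoch in range(10):
    avkmeans_step(enc1, enc2, head2, x1, x2, k, opt)  # cluster view 1, train view 2
    avkmeans_step(enc2, enc1, head1, x2, x1, k, opt)  # cluster view 2, train view 1
```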
26 |
27 | ## Experiments
28 |
29 | We construct a new dataset for this intent induction task: Twitter Airlines Customer Support (TwACS).
30 |
31 | We compare three clustering algorithms: k-means, [Multi-View Spectral Clustering (MVSC)](https://github.com/mariceli3/multiview), and `AV-KMeans` (ours). We experiment with three approaches to parameter initialization: PCA (for k-means and MVSC), autoencoders, and [quick thoughts](https://arxiv.org/pdf/1803.02893.pdf).
32 |
33 | The F1 scores are presented below:
34 |
35 | | Algorithm | PCA/none | autoencoders | quick thoughts |
36 | |-----------|----------|--------------|----------------|
37 | | k-means | 28.2 | 29.5 | 42.1 |
38 | | MVSC | 27.8 | 31.3 | 40.0 |
39 | | **AV-KMeans (ours)** | **35.4** | **38.9** | **46.2** |
40 |
41 |
42 | # Usage
43 |
44 | ## Pre-requisites
45 |
46 | - decompress the `.bz2` files in the `data` folder
47 | - download http://nlp.stanford.edu/data/glove.840B.300d.zip, and unzip `glove.840B.300d.txt` into the `data` folder
48 |
49 | ## To run AV-KMeans
50 |
51 | - run one of:
52 | ```
53 | # no pre-training
54 | cmdlines/airlines_mvc.sh
55 |
56 | # ae pre-training
57 | cmdlines/airlines_ae.sh
58 | cmdlines/airlines_ae_mvc.sh
59 |
60 | # qt pre-training
61 | cmdlines/airlines_qt.sh
62 | cmdlines/airlines_qt_mvc.sh
63 | ```
64 | - to train on askubuntu, replace `airlines` with `ubuntu` in the above command-lines
65 |
66 | ## To run k-means baseline
67 |
68 | - for qt pretraining, run:
69 | ```
70 | PYTHONPATH=. python train_qt.py --data-path data/airlines_processed.csv --pre-epoch 10 --view1-col first_utterance --view2-col context --scenarios view1
71 | ```
72 | - to train on askubuntu, replace `airlines` with `askubuntu` in the above command-line, and remove the `--*-col` command-line options
73 |
--------------------------------------------------------------------------------
/cmdlines/airlines_ae.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python train.py --pre-epoch 20 --pre-model ae --num-epochs 0 \
4 | --data-path data/airlines_processed.csv \
5 | --view1-col first_utterance --view2-col context --label-col tag \
6 | --save-model-path data/airlines_ae.pth
7 |
--------------------------------------------------------------------------------
/cmdlines/airlines_ae_mvc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [[ ! -f data/airlines_ae.pth ]]; then
4 |     echo 'Please train airlines ae first'; exit 1
5 | fi
6 |
7 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/airlines_processed.csv \
8 | --view1-col first_utterance --view2-col context --label-col tag \
9 | --model-path data/airlines_ae.pth
10 |
--------------------------------------------------------------------------------
/cmdlines/airlines_mvc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/airlines_processed.csv \
4 | --view1-col first_utterance --view2-col context --label-col tag
5 |
--------------------------------------------------------------------------------
/cmdlines/airlines_none.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python train.py --pre-epoch 0 --num-epochs 0 --data-path data/airlines_processed.csv \
4 | --view1-col first_utterance --view2-col context --label-col tag
5 |
--------------------------------------------------------------------------------
/cmdlines/airlines_qt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python train.py --pre-epoch 10 --pre-model qt --num-epochs 0 --data-path data/airlines_processed.csv \
4 | --view1-col first_utterance --view2-col context --label-col tag \
5 | --save-model-path data/airlines_qt.pth
6 |
--------------------------------------------------------------------------------
/cmdlines/airlines_qt_mvc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [[ ! -f data/airlines_qt.pth ]]; then
4 |     echo 'Please train airlines qt first'; exit 1
5 | fi
6 |
7 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/airlines_processed.csv \
8 | --view1-col first_utterance --view2-col context --label-col tag \
9 | --model-path data/airlines_qt.pth
10 |
--------------------------------------------------------------------------------
/cmdlines/ubuntu_ae.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python train.py --pre-epoch 20 --pre-model ae --num-epochs 0 \
4 | --data-path data/askubuntu_processed.csv \
5 | --save-model-path data/ubuntu_ae.pth
6 |
--------------------------------------------------------------------------------
/cmdlines/ubuntu_ae_mvc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [[ ! -f data/ubuntu_ae.pth ]]; then
4 |     echo 'Please train askubuntu ae first'; exit 1
5 | fi
6 |
7 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/askubuntu_processed.csv \
8 | --model-path data/ubuntu_ae.pth
9 |
--------------------------------------------------------------------------------
/cmdlines/ubuntu_mvc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/askubuntu_processed.csv
4 |
--------------------------------------------------------------------------------
/cmdlines/ubuntu_none.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python train.py --pre-epoch 0 --num-epochs 0 --data-path data/askubuntu_processed.csv
4 |
--------------------------------------------------------------------------------
/cmdlines/ubuntu_qt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python train.py --pre-epoch 10 --pre-model qt --num-epochs 0 --data-path data/askubuntu_processed.csv \
4 | --save-model-path data/ubuntu_qt.pth
5 |
--------------------------------------------------------------------------------
/cmdlines/ubuntu_qt_mvc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [[ ! -f data/ubuntu_qt.pth ]]; then
4 |     echo 'Please train askubuntu qt first'; exit 1
5 | fi
6 |
7 | python train.py --pre-epoch 0 --num-epochs 50 --data-path data/askubuntu_processed.csv \
8 | --model-path data/ubuntu_qt.pth
9 |
--------------------------------------------------------------------------------
/data/LICENSE:
--------------------------------------------------------------------------------
1 | There are four dataset files in this folder:
2 |
3 | - airlines_raw.csv.bz2 and airlines_processed.csv.bz2 are derived from Kaggle 'Customer Support on Twitter' dataset,
4 | https://www.kaggle.com/thoughtvector/customer-support-on-twitter,
5 | which is available under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0) license,
6 | https://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | - askubuntu_raw.csv.bz2 and askubuntu_processed.csv.bz2 are derived from Stack Exchange Data Dump, https://archive.org/details/stackexchange,
9 |   which is provided under a Creative Commons Attribution-ShareAlike 3.0 Unported (CC BY-SA 3.0) license, https://creativecommons.org/licenses/by-sa/3.0/
10 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | Data
2 | ----
3 |
4 | Training vs test data
5 | ---------------------
6 |
7 | We train on all of the data, without labels; the labels are used only to evaluate the resulting clusters.
8 |
9 | Twitter Airlines Customer Support
10 | ---------------------------------
11 |
12 | The data is available in two versions:
13 |
14 | - raw: minimal redaction (company and customer twitter id), no preprocessing: [airlines_raw.csv.bz2](airlines_raw.csv.bz2)
15 | - redacted, and preprocessed: [airlines_processed.csv.bz2](airlines_processed.csv.bz2)
16 |
17 | We sampled 500 examples and annotated them. 8 examples were rejected because they were not in English, leaving 492 labeled examples. The remaining examples were labeled `UNK`.
18 |
19 | AskUbuntu
20 | ---------
21 |
22 | - raw, no preprocessing: [askubuntu_raw.csv.bz2](askubuntu_raw.csv.bz2)
23 | - preprocessed: [askubuntu_processed.csv.bz2](askubuntu_processed.csv.bz2)
24 |
--------------------------------------------------------------------------------
/data/airlines_500onlyb.csv.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/data/airlines_500onlyb.csv.bz2
--------------------------------------------------------------------------------
/data/airlines_processed.csv.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/data/airlines_processed.csv.bz2
--------------------------------------------------------------------------------
/data/airlines_raw.csv.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/data/airlines_raw.csv.bz2
--------------------------------------------------------------------------------
/data/askubuntu_processed.csv.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/data/askubuntu_processed.csv.bz2
--------------------------------------------------------------------------------
/data/askubuntu_raw.csv.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/data/askubuntu_raw.csv.bz2
--------------------------------------------------------------------------------
/datasets/askubuntu_preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | given askubuntu links, find the distribution of clusters formed of connected subgraphs
3 |
4 | needs networkx installed (conda install networkx)
5 |
6 | conda install -y networkx beautifulsoup4 lxml
7 | python -c 'import nltk; nltk.download("stopwords")'
8 |
9 | """
10 | import random
11 | import argparse
12 | import csv
13 | from collections import defaultdict
14 | import time
15 | from os.path import expanduser as expand
16 | from xml.dom import minidom
17 |
18 | import networkx as nx
19 | import torch
20 | import numpy as np
21 |
22 | from preprocessing.askubuntu import Preprocessor
23 |
24 |
25 | class NullPreprocessor(object):
26 | def __call__(self, text):
27 | text = text.replace('\n', ' ').replace('\r', ' ')
28 | return True, text
29 |
30 |
31 | def index_posts(args):
32 | """
33 |     go through Posts.xml and, for each question post id, find the id of the
34 |     accepted answer. store both in a csv file
35 |
36 |     example row (attribute values elided):
37 |
38 |         <row Id="..." PostTypeId="1" AcceptedAnswerId="..."
39 |              Title="..." Body="..." ... />
40 |
41 |     question rows have PostTypeId="1"; a question with an accepted answer
42 |     carries an AcceptedAnswerId attribute pointing at the answer post, and
43 |     answer rows have PostTypeId="2" with a ParentId pointing back at the
44 |     question
45 |     """
46 | in_f = open(expand(args.in_posts), 'r')
47 | out_f = open(expand(args.out_posts_index), 'w')
48 | dict_writer = csv.DictWriter(out_f, fieldnames=['question_id', 'answer_id'])
49 | dict_writer.writeheader()
50 | last_print = time.time()
51 | for n, row in enumerate(in_f):
52 | row = row.strip()
53 |         if not row.startswith('<row'):
54 |             continue
55 |         dom = minidom.parseString(row)
56 |         att = dom.firstChild.attributes
57 |         # keep only questions (PostTypeId 1) that have an accepted answer
58 |         if att['PostTypeId'].value == '1' and 'AcceptedAnswerId' in att.keys():
59 |             dict_writer.writerow({
60 |                 'question_id': att['Id'].value,
61 |                 'answer_id': att['AcceptedAnswerId'].value})
62 |         if args.in_max_posts is not None and n >= args.in_max_posts:
63 | print('reached max rows => breaking')
64 | break
65 | if time.time() - last_print >= 3.0:
66 | print(n)
67 | last_print = time.time()
68 |
69 |
70 | def get_clusters(in_dupes_file, in_max_dupes, max_clusters):
71 | """
72 |     we're reading in the list of pairs of dupes. These are in a format of two post ids per line, like:
73 | 615465 8653
74 | 833376 377050
75 | 30585 120621
76 | 178532 152184
77 | 69455 68850
78 |
79 |     When we read these in, we have no idea whether these posts have answers etc. We just read in all
80 |     the pairs of post ids. We then add these post ids to a graph (each post id forms a node),
81 |     and the pairs of post ids become connections in the graph. We then form all connected components.
82 |
83 |     We sort the connected components by size (in reverse order), and take the top max_clusters components.
84 |     We just ignore the other components.
85 | """
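    # e.g. the dupe pairs (1, 2), (2, 3) and (7, 8) yield the connected
    # components {1, 2, 3} and {7, 8}; each component is one candidate cluster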
86 | f_in = open(expand(in_dupes_file), 'r')
87 | csv_reader = csv.reader(f_in, delimiter=' ')
88 | G = nx.Graph()
89 | for n, row in enumerate(csv_reader):
90 | left = int(row[0])
91 | right = int(row[1])
92 | G.add_edge(left, right)
93 | if in_max_dupes is not None and n >= in_max_dupes:
94 |             print('reached max dupes rows => break')
95 | break
96 | print('num nodes', len(G))
97 | print('num clusters', len(list(nx.connected_components(G))))
98 | count_by_size = defaultdict(int)
99 | clusters_by_size = defaultdict(list)
100 | for i, cluster in enumerate(nx.connected_components(G)):
101 | size = len(cluster)
102 | count_by_size[size] += 1
103 | clusters_by_size[size].append(cluster)
104 | print('count by size:')
105 | clusters = []
106 | top_clusters = []
107 | for size, count in sorted(count_by_size.items(), reverse=True):
108 | for cluster in clusters_by_size[size]:
109 | clusters.append(cluster)
110 | if len(top_clusters) < max_clusters:
111 | top_clusters.append(cluster)
112 |
113 | print('len(clusters)', len(clusters))
114 | top_cluster_post_ids = [id for cluster in top_clusters for id in cluster]
115 |
116 | post2cluster = {}
117 | for cluster_id, cluster in enumerate(clusters):
118 | for post in cluster:
119 | post2cluster[post] = cluster_id
120 |
121 | return top_clusters, top_cluster_post_ids, post2cluster
122 |
123 |
124 | def read_posts_index(in_posts_index):
125 |     with open(expand(in_posts_index), 'r') as f:
126 | dict_reader = csv.DictReader(f)
127 | index_rows = list(dict_reader)
128 | index_rows = [{'question_id': int(row['question_id']), 'answer_id': int(row['answer_id'])} for row in index_rows]
129 | post2answer = {row['question_id']: row['answer_id'] for row in index_rows}
130 | answer2post = {row['answer_id']: row['question_id'] for row in index_rows}
131 | print('loaded index')
132 | return post2answer, answer2post
133 |
134 |
135 | def load_posts(answer2post, post2answer, in_posts):
136 | # load in all the posts from Posts.xml
137 | post_by_id = defaultdict(dict)
138 | posts_f = open(expand(in_posts), 'r')
139 | last_print = time.time()
140 | for n, row in enumerate(posts_f):
141 | row = row.strip()
142 |         if not row.startswith('<row'):
143 |             continue
144 |         # (assumed reconstruction) pull the post id out of the raw xml row
145 |         # cheaply, before deciding whether to do a full parse
146 |         row_id = int(row.split(' Id="', 1)[1].split('"', 1)[0])
147 |         if row_id in post2answer:
148 |             """
149 |             example question row (attribute values elided):
150 |
151 |             <row Id="..." PostTypeId="1" AcceptedAnswerId="..."
152 |                  Title="..." Body="..." ... />
153 |
154 |             example answer row:
155 |
156 |             <row Id="..." PostTypeId="2" ParentId="..." Body="..." ... />
157 |
158 |             we keep the question title and body for questions, and the
159 |             body for accepted answers
160 |
161 |             """
162 | dom = minidom.parseString(row)
163 | node = dom.firstChild
164 | att = node.attributes
165 | assert att['PostTypeId'].value == '1'
166 | post_by_id[row_id]['question_title'] = att['Title'].value
167 | post_by_id[row_id]['question_body'] = att['Body'].value
168 | elif row_id in answer2post:
169 | dom = minidom.parseString(row)
170 | node = dom.firstChild
171 | att = node.attributes
172 | assert att['PostTypeId'].value == '2'
173 | post_id = answer2post[row_id]
174 | post_by_id[post_id]['answer_body'] = att['Body'].value
175 | if time.time() - last_print >= 3.0:
176 | print(len(post_by_id))
177 | last_print = time.time()
178 | if args.in_max_posts is not None and n > args.in_max_posts:
179 | print('reached in_max_posts => terminating')
180 | break
181 | print('loaded info from Posts.xml')
182 | return post_by_id
183 |
184 |
185 | def create_labeled(args):
186 | torch.manual_seed(args.seed)
187 | np.random.seed(args.seed)
188 | random.seed(args.seed)
189 |
190 | out_f = open(expand(args.out_labeled), 'w') # open now to check we can
191 |
192 | clusters, post_ids, post2cluster = get_clusters(
193 | in_dupes_file=args.in_dupes, in_max_dupes=args.in_max_dupes, max_clusters=args.max_clusters)
194 | print('cluster sizes:', [len(cluster) for cluster in clusters])
195 | print('len(clusters)', len(clusters))
196 | print('len(post_ids) from dupes graph', len(post_ids))
197 |
198 | print('removing post ids which dont have answers...')
199 | post2answer, answer2post = read_posts_index(in_posts_index=args.in_posts_index)
200 | post_ids = [id for id in post_ids if id in post2answer]
201 | print('len(post_ids) after removing no answer', len(post_ids))
202 | new_clusters = []
203 | for cluster in clusters:
204 | cluster = [id for id in cluster if id in post2answer]
205 | new_clusters.append(cluster)
206 | clusters = new_clusters
207 | print('len clusters after removing no answer', [len(cluster) for cluster in clusters])
208 |
209 | post_ids_set = set(post_ids)
210 | print('len(post_ids_set)', len(post_ids_set))
211 | answer_ids = [post2answer[id] for id in post_ids]
212 | print('len(answer_ids)', len(answer_ids))
213 |
214 | preprocessor = Preprocessor(max_len=args.max_len) if not args.no_preprocess else NullPreprocessor()
215 | post_by_id = load_posts(answer2post=answer2post, post2answer=post2answer, in_posts=args.in_posts)
216 |
217 | count_by_state = defaultdict(int)
218 | n = 0
219 | dict_writer = csv.DictWriter(out_f, fieldnames=[
220 | 'id', 'cluster_id', 'question_title', 'question_body', 'answer_body'])
221 | dict_writer.writeheader()
222 | for post_id in post_ids:
223 | if post_id not in post_by_id:
224 | count_by_state['not in post_by_id'] += 1
225 | continue
226 | post = post_by_id[post_id]
227 | if 'answer_body' not in post or 'question_body' not in post:
228 | count_by_state['no body, or no answer'] += 1
229 | continue
230 | count_by_state['ok'] += 1
231 | cluster_id = post2cluster[post_id]
232 | row = {
233 | 'id': post_id,
234 | 'cluster_id': cluster_id,
235 | 'question_title': preprocessor(post['question_title'])[1],
236 | 'question_body': preprocessor(post['question_body'])[1],
237 | 'answer_body': preprocessor(post['answer_body'])[1]
238 | }
239 | dict_writer.writerow(row)
240 | n += 1
241 | print(count_by_state)
242 | print('rows written', n)
243 |
244 |
245 | def create_unlabeled(args):
246 | """
247 | this is going to do:
248 | - take a question (specific type in Posts.xml)
249 | - match it with the accepted answer (using the index)
250 | - write these out
251 | """
252 | torch.manual_seed(args.seed)
253 | np.random.seed(args.seed)
254 | random.seed(args.seed)
255 |
256 | out_f = open(expand(args.out_unlabeled), 'w') # open now, to check we can...
257 |
258 | post2answer, answer2post = read_posts_index(in_posts_index=args.in_posts_index)
259 | print('loaded index')
260 | print('posts in index', len(post2answer))
261 |
262 | preprocessor = Preprocessor(max_len=args.max_len) if not args.no_preprocess else NullPreprocessor()
263 | post_by_id = load_posts(post2answer=post2answer, answer2post=answer2post, in_posts=args.in_posts)
264 | print('loaded all posts, len(post_by_id)', len(post_by_id))
265 |
266 | count_by_state = defaultdict(int)
267 | n = 0
268 | dict_writer = csv.DictWriter(out_f, fieldnames=['id', 'question_title', 'question_body', 'answer_body'])
269 | dict_writer.writeheader()
270 | last_print = time.time()
271 | for post_id, info in post_by_id.items():
272 | if 'answer_body' not in info or 'question_body' not in info:
273 | count_by_state['no body, or no answer'] += 1
274 | continue
275 | count_by_state['ok'] += 1
276 | dict_writer.writerow({
277 | 'id': post_id,
278 | 'question_title': preprocessor(info['question_title'])[1],
279 | 'question_body': preprocessor(info['question_body'])[1],
280 | 'answer_body': preprocessor(info['answer_body'])[1]
281 | })
282 | if time.time() - last_print >= 10:
283 | print('written', n)
284 | last_print = time.time()
285 | n += 1
286 |
287 | print(count_by_state)
288 |
289 |
290 | if __name__ == '__main__':
291 | root_parser = argparse.ArgumentParser()
292 | parsers = root_parser.add_subparsers()
293 |
294 | parser = parsers.add_parser('index-posts')
295 | parser.add_argument('--in-posts-dir', type=str, default='~/data/askubuntu.com')
296 | parser.add_argument('--in-posts', type=str, default='{in_posts_dir}/Posts.xml')
297 | parser.add_argument('--out-posts-index', type=str, default='{in_posts_dir}/Posts_index.csv')
298 | parser.add_argument('--in-max-posts', type=int, help='for dev/debugging mostly')
299 | parser.set_defaults(func=index_posts)
300 |
301 | parser = parsers.add_parser('create-labeled')
302 | parser.add_argument('--seed', type=int, default=123)
303 | parser.add_argument('--in-dupes-dir', type=str, default='~/data/askubuntu')
304 | parser.add_argument('--in-dupes', type=str, default='{in_dupes_dir}/full.pos.txt')
305 | parser.add_argument('--max-clusters', type=int, default=20)
306 | parser.add_argument('--in-posts-dir', type=str, default='~/data/askubuntu.com')
307 | parser.add_argument('--in-posts', type=str, default='{in_posts_dir}/Posts.xml')
308 | parser.add_argument('--in-posts-index', type=str, default='{in_posts_dir}/Posts_index.csv')
309 | parser.add_argument('--in-max-dupes', type=int, help='for dev/debugging mostly')
310 | parser.add_argument('--in-max-posts', type=int, help='for dev/debugging mostly')
311 | parser.add_argument('--out-labeled', type=str, required=True)
312 | parser.add_argument('--no-preprocess', action='store_true')
313 | parser.add_argument('--max-len', type=int, default=500, help='set to 0 to disable')
314 | parser.set_defaults(func=create_labeled)
315 |
316 | parser = parsers.add_parser('create-unlabeled')
317 | parser.add_argument('--seed', type=int, default=123)
318 | parser.add_argument('--in-posts-dir', type=str, default='~/data/askubuntu.com')
319 | parser.add_argument('--in-dupes-dir', type=str, default='~/data/askubuntu')
320 | parser.add_argument('--in-posts', type=str, default='{in_posts_dir}/Posts.xml')
321 | parser.add_argument('--in-posts-index', type=str, default='{in_posts_dir}/Posts_index.csv')
322 | parser.add_argument('--in-max-posts', type=int, help='for dev/debugging mostly')
323 | parser.add_argument('--out-unlabeled', type=str, required=True)
324 | parser.add_argument('--max-len', type=int, default=500, help='set to 0 to disable')
325 | parser.add_argument('--no-preprocess', action='store_true')
326 | parser.set_defaults(func=create_unlabeled)
327 |
328 | args = root_parser.parse_args()
329 | for k in [
330 | 'in_dupes', 'out_labeled', 'in_posts', 'out_posts_index', 'in_posts_index',
331 | 'out_unlabeled']:
332 | if k in args.__dict__:
333 | args.__dict__[k] = args.__dict__[k].format(**args.__dict__)
334 |
335 | func = args.func
336 | del args.__dict__['func']
337 | func(args)
338 |
--------------------------------------------------------------------------------
/datasets/labeled_unlabeled_merger.py:
--------------------------------------------------------------------------------
1 | """
2 | given a labeled and unlabeled datafile, creates a single file, that contains the data
3 | from both. the labels for the unlabeled data file will be UNK
4 | """
5 | import argparse
6 | from collections import OrderedDict
7 | import csv
8 |
9 |
10 | def get_col_by_role(role_string):
11 | col_by_role = {}
12 | for s in role_string.split(','):
13 | split_s = s.split('=')
14 | col_by_role[split_s[0]] = split_s[1]
15 | print('col_by_role', col_by_role)
16 | return col_by_role
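# e.g. get_col_by_role('view1=first_utterance,view2=context')
#   -> {'view1': 'first_utterance', 'view2': 'context'}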
17 |
18 |
19 | def run(labeled_files, unlabeled_files, out_file, unlabeled_columns, labeled_columns, no_add_cust_tokens, add_columns, column_order):
20 | add_columns = add_columns.split(',')
21 | out_f = open(out_file, 'w')
22 | dict_writer = csv.DictWriter(out_f, fieldnames=column_order.split(','))
23 | dict_writer.writeheader()
24 |
25 | labeled_col_by_role = get_col_by_role(labeled_columns)
26 | unlabeled_col_by_role = get_col_by_role(unlabeled_columns)
27 |
28 | labeled_view1 = labeled_col_by_role['view1']
29 | labeled_label = labeled_col_by_role['label']
30 |
31 | unlabeled_view1 = unlabeled_col_by_role['view1']
32 | unlabeled_view2 = unlabeled_col_by_role['view2']
33 |
34 | unlabeled_by_first_utterance = OrderedDict()
35 | for filename in unlabeled_files:
36 | print(filename)
37 | with open(filename, 'r') as in_f:
38 | dict_reader = csv.DictReader(in_f)
39 | for row in dict_reader:
40 | view1 = row[unlabeled_view1]
41 |                 if not no_add_cust_tokens and not view1.startswith('<cust>'):
42 |                     view1 = '<cust> ' + view1
43 | out_row = {
44 | 'view1': view1,
45 | 'view2': row[unlabeled_view2],
46 | 'label': 'UNK'
47 | }
48 | for k in add_columns:
49 | out_row[k] = row[k]
50 | unlabeled_by_first_utterance[view1] = out_row
51 | print('loaded unlabeled')
52 |
53 | for filename in labeled_files:
54 | print(filename)
55 | with open(filename, 'r') as in_f:
56 | dict_reader = csv.DictReader(in_f)
57 | for row in dict_reader:
58 | view1 = row[labeled_view1]
59 | if view1 not in unlabeled_by_first_utterance:
60 | print('warning: not found in unlabelled', view1)
61 | continue
62 | out_row = unlabeled_by_first_utterance[view1]
63 | out_row['label'] = row[labeled_label]
64 | for k in add_columns:
65 | assert out_row[k] == row[k]
66 |
67 | for row in unlabeled_by_first_utterance.values():
68 | dict_writer.writerow(row)
69 |
70 | out_f.close()
71 |
72 |
73 | if __name__ == '__main__':
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument('--labeled-files', type=str, nargs='+', required=True)
76 | parser.add_argument('--unlabeled-files', type=str, nargs='+', required=True)
77 | parser.add_argument('--out-file', type=str, required=True)
78 | parser.add_argument('--unlabeled-columns', type=str, default='view1=first_utterance,view2=context')
79 | parser.add_argument('--labeled-columns', type=str, default='view1=text,label=tag')
80 | parser.add_argument('--no-add-cust-tokens', action='store_true')
81 | parser.add_argument('--add-columns', type=str, default='id,question_body')
82 | parser.add_argument('--column-order', type=str, default='id,label,view1,view2,question_body')
83 | args = parser.parse_args()
84 | run(**args.__dict__)
85 |
--------------------------------------------------------------------------------
/datasets/readme.md:
--------------------------------------------------------------------------------
1 | To generate the data from raw data:
2 |
3 | For askubuntu:
4 |
5 | ```
6 | export PYTHONPATH=.
7 | mkdir ~/data/askubuntu_induction
8 | python datasets/askubuntu_preprocess.py create-labeled --out-labeled ~/data/askubuntu_induction/preproc_labeled.csv
9 | python datasets/askubuntu_preprocess.py create-unlabeled --out-unlabeled ~/data/askubuntu_induction/preproc_unlabeled.csv
10 | python datasets/labeled_unlabeled_merger.py --labeled-files ~/data/askubuntu_induction/preproc_labeled.csv --unlabeled-files ~/data/askubuntu_induction/preproc_unlabeled.csv --out-file ~/data/askubuntu_induction/preproc_merged.csv --labeled-columns view1=question_title,label=cluster_id --unlabeled-columns view1=question_title,view2=answer_body --no-add-cust-tokens
11 | ```
12 |
13 | For Twitter customer support:
14 | ```
15 | (cd data; bunzip2 airlines_500onlyb.csv.bz2)
16 | python datasets/twitter_dataset_preprocess.py run-aggregated --no-preprocess --out-examples ~/data/twittersupport/airlines_raw.csv
17 | python datasets/twitter_airlines_raw_merger2.py
18 | ```
19 |
--------------------------------------------------------------------------------
/datasets/twitter_airlines_raw_merger2.py:
--------------------------------------------------------------------------------
1 | """
2 | takes contents of airlines_raw.csv, and of data/airlines_mergedb.csv, and creates
3 | airlines_500_merged.csv, using the raw tweets from airlines_raw.csv, and the tags from airlines_500_mergedb.csv
4 | """
5 | import argparse, os, time, csv, json
6 | from os import path
7 | from os.path import join, expanduser as expand
8 |
9 | def run(args):
10 | with open(expand(args.airlines_raw_file), 'r') as f:
11 | dict_reader = csv.DictReader(f)
12 | raw_rows = list(dict_reader)
13 |         raw_row_by_tweet_id = {int(row['first_tweet_id']): row for row in raw_rows}
14 |
15 | f_in = open(expand(args.airlines_merged_file), 'r')
16 | f_out = open(expand(args.out_file), 'w')
17 | dict_reader = csv.DictReader(f_in)
18 | dict_writer = csv.DictWriter(f_out, fieldnames=['first_tweet_id', 'tag', 'first_utterance', 'context'])
19 | dict_writer.writeheader()
20 | for i, old_merged_row in enumerate(dict_reader):
21 | if old_merged_row['first_tweet_id'] == '':
22 | continue
23 | tweet_id = int(old_merged_row['first_tweet_id'])
24 | raw_row = raw_row_by_tweet_id[tweet_id]
25 | raw_row['tag'] = old_merged_row['tag']
26 | dict_writer.writerow(raw_row)
27 | # so, we accidentally had a blank line at the end of the non-raw dataset. add that in here...
28 | dict_writer.writerow({'first_tweet_id': '', 'tag': 'UNK', 'first_utterance': '', 'context': ''})
29 | f_out.close()
30 |
31 | if __name__ == '__main__':
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--airlines-raw-file', type=str, default='~/data/twittersupport/airlines_raw.csv')
34 | parser.add_argument('--airlines-merged-file', type=str, default='data/airlines_500_mergedb.csv')
35 | parser.add_argument('--out-file', type=str, default='~/data/twittersupport/airlines_raw_merged.csv')
36 | args = parser.parse_args()
37 | run(args)
38 |
--------------------------------------------------------------------------------
/datasets/twitter_dataset_preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | preprocessing twitter dataset
3 | """
4 | import time
5 | import os
6 | from os import path
7 | from os.path import join, expanduser as expand
8 | import json
9 | import re
10 | import csv
11 | from collections import defaultdict
12 | import argparse
13 |
14 | import nltk
15 |
16 | from preprocessing import twitter_airlines
17 |
18 |
19 | all_companies = 'AppleSupport,Amazon,Uber_Support,SpotifyCares,Delta,Tesco,AmericanAir,TMobileHelp,comcastcares,British_Airways,SouthwestAir,VirginTrains,Ask_Spectrum,XboxSupport,sprintcare,hulu_support,AskPlayStation,ChipotleTweets,UPSHelp,AskTarget'.split(',')
20 |
21 | """
22 | example rows:
23 |
24 | tweet_id,author_id,inbound,created_at,text,response_tweet_id,in_response_to_tweet_id
25 | 1,sprintcare,False,Tue Oct 31 22:10:47 +0000 2017,@115712 I understand. I would like to assist you. We would need to get you into a private secured link to further assist.,2,3
26 | 2,115712,True,Tue Oct 31 22:11:45 +0000 2017,@sprintcare and how do you propose we do that,,1
27 | """
28 |
29 | class NullPreprocessor(object):
30 | def __call__(self, text, info_dict=None):
31 | is_valid = True
32 | if 'DM' in text:
33 | is_valid = False
34 | if 'https://t.co' in text:
35 | is_valid = False
36 |
37 | text = text.replace('\n', ' ').replace('\r', ' ')
38 |
39 | target_company = info_dict['target_company']
40 | cust_twitter_id = info_dict['cust_twitter_id']
41 | text = text.replace('@' + target_company, ' __company__ ').replace('@' + cust_twitter_id, ' __cust__ ')
42 |
43 | return is_valid, text
44 |
45 | def run_for_company(in_csv_file, in_max_rows, examples_writer, target_company, no_preprocess):
46 | f_in = open(expand(in_csv_file), 'r')
47 | node_by_id = {}
48 | start_node_ids = []
49 | dict_reader = csv.DictReader(f_in)
50 | next_by_prev = {}
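    # next_by_prev maps each tweet id to the id of the tweet that replies to
    # it, so a conversation can be walked forwards from its first tweet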
51 | for row in dict_reader:
52 | id = row['tweet_id']
53 | prev = row['in_response_to_tweet_id']
54 | next_by_prev[prev] = id
55 | if prev == '' and ('@' + target_company) in row['text']:
56 | start_node_ids.append(id)
57 | node_by_id[id] = row
58 | if in_max_rows is not None and len(node_by_id) >= in_max_rows:
59 | print('reached max rows', in_max_rows, '=> breaking')
60 | break
61 | print('len(node_by_id)', len(node_by_id))
62 | print('len(start_node_ids)', len(start_node_ids))
63 | count_by_status = defaultdict(int)
64 | count_by_count = defaultdict(int)
65 |
66 | preprocessor = twitter_airlines.Preprocessor() if not no_preprocess else NullPreprocessor()
67 |
68 | for i, start_node_id in enumerate(start_node_ids):
69 | conversation_texts = []
70 | first_utterance = None
71 | is_valid = True
72 | node = node_by_id[start_node_id]
73 | cust_twitter_id = node['author_id']
74 | while True:
75 | text = node['text'].replace('\n', ' ')
76 |             if node['inbound'] == 'True':
77 |                 start_tok = '<cust>'
78 |             else:
79 |                 start_tok = '<company>'
80 |             # (assumed reconstruction) the preprocessor validates the tweet
81 |             # and masks the company / customer twitter handles
82 |             _valid, text = preprocessor(text, info_dict={
83 |                 'target_company': target_company,
84 |                 'cust_twitter_id': cust_twitter_id})
85 |             if not _valid:
86 |                 is_valid = False
87 |             conversation_texts.append(start_tok + ' ' + text)
88 |             if first_utterance is None and node['inbound'] == 'True':
89 |                 first_utterance = text
90 |             response_id = next_by_prev.get(node['tweet_id'])
91 |             if response_id is None:
92 |                 # end of the conversation: write it out if it was valid
93 |                 if is_valid and first_utterance is not None:
94 |                     examples_writer.writerow({
95 |                         'first_tweet_id': start_node_id,
96 |                         'first_utterance': first_utterance,
97 |                         'context': ' '.join(conversation_texts)})
98 |                     count_by_status['ok'] += 1
99 |                 else:
100 |                     count_by_status['invalid'] += 1
101 |                 count_by_count[len(conversation_texts)] += 1
102 |                 break
103 |             if response_id not in node_by_id:
104 |                 # the reply fell outside the rows we loaded (e.g. when
105 |                 # --in-max-rows truncated the csv)
106 |                 count_by_status['response not found'] += 1
107 |                 print('response', response_id,
108 |                       'not found',
109 |                       '=> skipping conversation')
110 |                 break
111 | node = node_by_id[response_id]
112 | print(count_by_status)
113 |
114 | def run_aggregated(in_csv_file, in_max_rows, out_examples, target_companies, no_preprocess):
115 | with open(expand(out_examples), 'w') as f_examples:
116 | examples_writer = csv.DictWriter(f_examples, fieldnames=['first_tweet_id', 'first_utterance', 'context'])
117 | examples_writer.writeheader()
118 | for company in target_companies:
119 | print(company)
120 | run_for_company(in_csv_file, in_max_rows, examples_writer, company, no_preprocess)
121 |
122 | def run_by_company(in_csv_file, in_max_rows, out_examples_templ, target_companies, no_preprocess):
123 | for company in target_companies:
124 | print(company)
125 | out_examples = out_examples_templ.format(company=company)
126 | with open(expand(out_examples), 'w') as f_examples:
127 | examples_writer = csv.DictWriter(f_examples, fieldnames=['first_tweet_id', 'first_utterance', 'context'])
128 | examples_writer.writeheader()
129 | run_for_company(in_csv_file, in_max_rows, examples_writer, company, no_preprocess)
130 |
131 | if __name__ == '__main__':
132 | root_parser = argparse.ArgumentParser()
133 | parsers = root_parser.add_subparsers()
134 |
135 | parser = parsers.add_parser('run-aggregated')
136 | parser.add_argument('--in-csv-file', type=str, default='~/data/twittersupport/twcs.csv')
137 | parser.add_argument('--in-max-rows', type=int, help='for dev/debugging mostly')
138 | parser.add_argument('--out-examples', type=str, required=True)
139 | parser.add_argument('--target-companies', type=str, default='Delta,AmericanAir,British_Airways,SouthwestAir')
140 | parser.add_argument('--no-preprocess', action='store_true')
141 | parser.set_defaults(func=run_aggregated)
142 |
143 | parser = parsers.add_parser('run-by-company')
144 | parser.add_argument('--in-csv-file', type=str, default='~/data/twittersupport/twcs.csv')
145 | parser.add_argument('--in-max-rows', type=int, help='for dev/debugging mostly')
146 | parser.add_argument('--out-examples-templ', type=str, default='~/data/twittersupport/{company}.csv')
147 | parser.add_argument('--target-companies', type=str, default=','.join(all_companies))
148 | parser.add_argument('--no-preprocess', action='store_true')
149 | parser.set_defaults(func=run_by_company)
150 |
151 | args = root_parser.parse_args()
152 |
153 | if args.func == run_aggregated:
154 | args.out_examples = args.out_examples.format(**args.__dict__)
155 | args.target_companies = args.target_companies.split(',')
156 | elif args.func == run_by_company:
157 | args.target_companies = args.target_companies.split(',')
158 |
159 | func = args.func
160 | del args.__dict__['func']
161 | func(**args.__dict__)
162 |
--------------------------------------------------------------------------------
/images/avkmeans_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/images/avkmeans_graph.png
--------------------------------------------------------------------------------
/images/example_dialogs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/images/example_dialogs.png
--------------------------------------------------------------------------------
/labeled_unlabeled_merger.py:
--------------------------------------------------------------------------------
1 | """
2 | given a labeled and unlabeled datafile, creates a single file, that contains the data
3 | from both. the labels for the unlabeled data file will be UNK
4 | """
5 | import argparse
6 | from collections import OrderedDict
7 | import csv
8 |
9 |
10 | def get_col_by_role(role_string):
11 | col_by_role = {}
12 | for s in role_string.split(','):
13 | split_s = s.split('=')
14 | col_by_role[split_s[0]] = split_s[1]
15 | print('col_by_role', col_by_role)
16 | return col_by_role
17 |
18 |
19 | def run(labeled_files, unlabeled_files, out_file, unlabeled_columns, labeled_columns, no_add_cust_tokens, add_columns, column_order):
20 | add_columns = add_columns.split(',')
21 | out_f = open(out_file, 'w')
22 | dict_writer = csv.DictWriter(out_f, fieldnames=column_order.split(','))
23 | dict_writer.writeheader()
24 |
25 | labeled_col_by_role = get_col_by_role(labeled_columns)
26 | unlabeled_col_by_role = get_col_by_role(unlabeled_columns)
27 |
28 | labeled_view1 = labeled_col_by_role['view1']
29 | labeled_label = labeled_col_by_role['label']
30 |
31 | unlabeled_view1 = unlabeled_col_by_role['view1']
32 | unlabeled_view2 = unlabeled_col_by_role['view2']
33 |
34 | unlabeled_by_first_utterance = OrderedDict()
35 | for filename in unlabeled_files:
36 | print(filename)
37 | with open(filename, 'r') as in_f:
38 | dict_reader = csv.DictReader(in_f)
39 | for row in dict_reader:
40 | view1 = row[unlabeled_view1]
41 |                 if not no_add_cust_tokens and not view1.startswith('<cust>'):
42 |                     view1 = '<cust> ' + view1
43 | out_row = {
44 | 'view1': view1,
45 | 'view2': row[unlabeled_view2],
46 | 'label': 'UNK'
47 | }
48 | for k in add_columns:
49 | out_row[k] = row[k]
50 | unlabeled_by_first_utterance[view1] = out_row
51 | print('loaded unlabeled')
52 |
53 | for filename in labeled_files:
54 | print(filename)
55 | with open(filename, 'r') as in_f:
56 | dict_reader = csv.DictReader(in_f)
57 | for row in dict_reader:
58 | view1 = row[labeled_view1]
59 | if view1 not in unlabeled_by_first_utterance:
60 | print('warning: not found in unlabelled', view1)
61 | continue
62 | out_row = unlabeled_by_first_utterance[view1]
63 | out_row['label'] = row[labeled_label]
64 | for k in add_columns:
65 | assert out_row[k] == row[k]
66 |
67 | for row in unlabeled_by_first_utterance.values():
68 | dict_writer.writerow(row)
69 |
70 | out_f.close()
71 |
72 |
73 | if __name__ == '__main__':
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument('--labeled-files', type=str, nargs='+', required=True)
76 | parser.add_argument('--unlabeled-files', type=str, nargs='+', required=True)
77 | parser.add_argument('--out-file', type=str, required=True)
78 | parser.add_argument('--unlabeled-columns', type=str, default='view1=first_utterance,view2=context')
79 | parser.add_argument('--labeled-columns', type=str, default='view1=text,label=tag')
80 | parser.add_argument('--no-add-cust-tokens', action='store_true')
81 | parser.add_argument('--add-columns', type=str, default='id,question_body')
82 | parser.add_argument('--column-order', type=str, default='id,label,view1,view2,question_body')
83 | args = parser.parse_args()
84 | run(**args.__dict__)
85 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/metrics/__init__.py
--------------------------------------------------------------------------------
/metrics/cluster_metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | cluster metrics: precision, recall, )f1)
3 | """
4 | from collections import Counter
5 |
6 | import scipy.optimize
7 | import torch
8 |
9 |
10 | def calc_precision(gnd_assignments, pred_assignments):
11 | """
12 | gnd_clusters should be a torch tensor of longs, containing
13 | the assignment to each cluster
14 |
15 | assumes that cluster assignments are 0-based, and no 'holes'
16 | """
17 | precision_sum = 0
18 | assert len(gnd_assignments.size()) == 1
19 | assert len(pred_assignments.size()) == 1
20 | assert pred_assignments.size(0) == gnd_assignments.size(0)
21 | N = gnd_assignments.size(0)
22 | K_gnd = gnd_assignments.max().item() + 1
23 | K_pred = pred_assignments.max().item() + 1
24 | for k_pred in range(K_pred):
25 | mask = pred_assignments == k_pred
26 | gnd = gnd_assignments[mask.nonzero().long().view(-1)]
27 | max_intersect = 0
28 | for k_gnd in range(K_gnd):
29 | intersect = (gnd == k_gnd).long().sum().item()
30 | max_intersect = max(max_intersect, intersect)
31 | precision_sum += max_intersect
32 | precision = precision_sum / N
33 | return precision
34 |
35 |
36 | def calc_recall(gnd_assignments, pred_assignments):
37 | """
38 | basically the reverse of calc_precision
39 |
40 | so, we can just call calc_precision in reverse :P
41 | """
42 | return calc_precision(gnd_assignments=pred_assignments, pred_assignments=gnd_assignments)
43 |
44 |
45 | def calc_f1(gnd_assignments, pred_assignments):
46 | prec = calc_precision(gnd_assignments=gnd_assignments, pred_assignments=pred_assignments)
47 | recall = calc_recall(gnd_assignments=gnd_assignments, pred_assignments=pred_assignments)
48 | f1 = 2 * (prec * recall) / (prec + recall)
49 | return f1
50 |
51 |
52 | def calc_prec_rec_f1(gnd_assignments, pred_assignments):
53 | prec = calc_precision(gnd_assignments=gnd_assignments, pred_assignments=pred_assignments)
54 | recall = calc_recall(gnd_assignments=gnd_assignments, pred_assignments=pred_assignments)
55 | f1 = 2 * (prec * recall) / (prec + recall)
56 | return prec, recall, f1
57 |
58 |
59 | def calc_ACC(pred, gnd):
60 | assert len(pred.size()) == 1
61 | assert len(gnd.size()) == 1
62 | N = pred.size(0)
63 | assert N == gnd.size(0)
64 |     # build the contingency matrix M[gnd, pred] from the pair counts
65 | counts = Counter(list(zip(gnd.tolist(), pred.tolist())))
66 | keys = torch.LongTensor(list(counts.keys()))
67 | values = torch.LongTensor(list(counts.values()))
68 | M = scipy.sparse.csr_matrix((values.numpy(), (keys[:, 0].numpy(), keys[:, 1].numpy())))
69 | M = M.todense()
70 |
71 | row_ind, col_ind = scipy.optimize.linear_sum_assignment(-M)
72 | cost = M[row_ind, col_ind].sum().item()
73 | ACC = cost / N
74 |
75 | return ACC
76 |
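# usage sketch (toy labels): each predicted cluster is matched to its best
# ground-truth cluster, so a relabelled-but-identical clustering scores 1.0:
#
#   gnd = torch.LongTensor([0, 0, 1, 1])
#   pred = torch.LongTensor([1, 1, 0, 0])
#   calc_prec_rec_f1(gnd, pred)  # -> (1.0, 1.0, 1.0)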
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/model/__init__.py
--------------------------------------------------------------------------------
/model/decoder.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | import torch.nn as nn
3 |
4 |
5 | class Decoder(nn.Module):
6 | def __init__(self, latent_z_size, word_emb_size, word_vocab_size, decoder_rnn_size,
7 | decoder_num_layers, dropout=0.5):
8 | super(Decoder, self).__init__()
9 | self.latent_z_size = latent_z_size
10 | self.word_vocab_size = word_vocab_size
11 | self.decoder_rnn_size = decoder_rnn_size
12 | self.dropout = dropout
13 | self.rnn = nn.LSTM(input_size=latent_z_size + word_emb_size,
14 | hidden_size=decoder_rnn_size,
15 | num_layers=decoder_num_layers,
16 | dropout=dropout,
17 | batch_first=True)
18 | self.fc = nn.Linear(decoder_rnn_size, word_vocab_size)
19 |
20 | def forward(self, decoder_input, latent_z):
21 | """
22 | :param decoder_input: tensor with shape of [batch_size, seq_len, emb_size]
23 | :param latent_z: sequence context with shape of [batch_size, latent_z_size]
24 |         :return: unnormalized logits of sentence words distribution probabilities
25 | with shape of [batch_size, seq_len, word_vocab_size]
26 |
27 | TODO: add padding support
28 | """
29 |
30 | [batch_size, seq_len, _] = decoder_input.size()
31 | # decoder rnn is conditioned on context via additional bias = W_cond * z to every input
32 | # token
33 | latent_z = t.cat([latent_z] * seq_len, 1).view(batch_size, seq_len, self.latent_z_size)
34 | decoder_input = t.cat([decoder_input, latent_z], 2)
35 | rnn_out, _ = self.rnn(decoder_input)
36 | rnn_out = rnn_out.contiguous().view(-1, self.decoder_rnn_size)
37 | result = self.fc(rnn_out)
38 | result = result.view(batch_size, seq_len, self.word_vocab_size)
39 | return result
40 |
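# usage sketch (hypothetical sizes): the latent code is tiled across time and
# concatenated onto every input embedding before the LSTM
#
#   dec = Decoder(latent_z_size=64, word_emb_size=300, word_vocab_size=1000,
#                 decoder_rnn_size=256, decoder_num_layers=2)
#   logits = dec(t.randn(4, 7, 300), t.randn(4, 64))  # -> [4, 7, 1000]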
--------------------------------------------------------------------------------
/model/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Encoder(nn.Module):
6 | def __init__(self, word_emb_size, encoder_rnn_size, encoder_num_layers, dropout=0.5):
7 | super(Encoder, self).__init__()
8 |
9 | self.encoder_rnn_size = encoder_rnn_size
10 | self.encoder_num_layers = encoder_num_layers
11 | # Create input dropout parameter
12 | self.rnn = nn.LSTM(input_size=word_emb_size,
13 | hidden_size=encoder_rnn_size,
14 | num_layers=encoder_num_layers,
15 | dropout=dropout,
16 | batch_first=True,
17 | bidirectional=True)
18 |
19 | def forward(self, encoder_input, lengths):
20 | """
21 | :param encoder_input: [batch_size, seq_len, emb_size] tensor
22 |         :return: context of input sentences with shape of [batch_size, 2 * encoder_rnn_size]
23 | """
24 | lengths, perm_idx = lengths.sort(0, descending=True)
25 | encoder_input = encoder_input[perm_idx]
26 | [batch_size, seq_len, _] = encoder_input.size()
27 | packed_words = torch.nn.utils.rnn.pack_padded_sequence(
28 | encoder_input, lengths, True)
29 |         # Unfold rnn with zero initial state and take the final cell state from the last layer
30 | rnn_out, (_, final_state) = self.rnn(packed_words, None)
31 | final_state = final_state.view(
32 | self.encoder_num_layers, 2, batch_size, self.encoder_rnn_size)[-1]
33 | h_1, h_2 = final_state[0], final_state[1]
34 | final_state = torch.cat([h_1, h_2], 1)
35 | _, unperm_idx = perm_idx.sort(0)
36 | final_state = final_state[unperm_idx]
37 | return final_state
38 |
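# usage sketch (hypothetical sizes): the LSTM is bidirectional, so the final
# states of the two directions are concatenated
#
#   enc = Encoder(word_emb_size=300, encoder_rnn_size=256, encoder_num_layers=2)
#   out = enc(torch.randn(4, 7, 300), torch.LongTensor([7, 5, 3, 2]))
#   out.size()  # -> [4, 512]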
--------------------------------------------------------------------------------
/model/multiview_encoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch as t
3 | from torch import nn
4 | import numpy as np
5 |
6 | from model.utils import pad_sentences, pad_paragraphs
7 | import train
8 |
9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10 |
11 |
12 | class MultiviewEncoders(nn.Module):
13 |
14 | def __init__(
15 | self, vocab_size, num_layers, embedding_size, lstm_hidden_size, word_dropout, dropout,
16 | start_idx=2, end_idx=3, pad_idx=0):
17 | super().__init__()
18 | self.pad_idx = pad_idx
19 | self.start_idx = start_idx # for RNN autoencoder training
20 | self.end_idx = end_idx # for RNN autoencoder training
21 | self.num_layers = num_layers
22 | self.embedding_size = embedding_size
23 | self.lstm_hidden_size = lstm_hidden_size
24 | self.word_dropout = nn.Dropout(word_dropout)
25 | self.dropout = dropout
26 | self.vocab_size = vocab_size
27 | self.crit = nn.CrossEntropyLoss()
28 |
29 | self.embedder = nn.Embedding(vocab_size, embedding_size)
30 |
31 | def create_rnn(embedding_size, bidirectional=True):
32 | return nn.LSTM(
33 | embedding_size,
34 | lstm_hidden_size,
35 | dropout=dropout,
36 | num_layers=num_layers,
37 | bidirectional=bidirectional,
38 | batch_first=True
39 | )
40 |
41 | self.view1_word_rnn = create_rnn(embedding_size)
42 | self.view2_word_rnn = create_rnn(embedding_size)
43 | self.view2_sent_rnn = create_rnn(2*lstm_hidden_size)
44 |
45 | self.ae_decoder = create_rnn(embedding_size + 2 * lstm_hidden_size, bidirectional=False)
46 | self.qt_context = create_rnn(embedding_size)
47 |
48 | self.fc = nn.Linear(lstm_hidden_size, vocab_size)
49 |
50 | def get_encoder(self, encoder):
51 | return {
52 | 'v1': self.view1_word_rnn,
53 | 'v2': self.view2_word_rnn,
54 | 'v2sent': self.view2_sent_rnn,
55 | 'ae_decoder': self.ae_decoder,
56 | 'qt': self.qt_context
57 | }[encoder]
58 |
59 | @classmethod
60 | def construct_from_embeddings(
61 | cls, embeddings, num_layers, embedding_size, lstm_hidden_size, word_dropout, dropout,
62 | vocab_size, start_idx=2, end_idx=3, pad_idx=0):
63 | model = cls(
64 | num_layers=num_layers,
65 | embedding_size=embedding_size,
66 | lstm_hidden_size=lstm_hidden_size,
67 | word_dropout=word_dropout,
68 | dropout=dropout,
69 | start_idx=start_idx,
70 | end_idx=end_idx,
71 | pad_idx=pad_idx,
72 | vocab_size=vocab_size
73 | )
74 | model.embedder = nn.Embedding.from_pretrained(embeddings, freeze=False)
75 | return model
76 |
77 | def decode(self, decoder_input, latent_z):
78 | """
79 | decode state into word indices
80 |
81 | :param decoder_input: list of lists of indices
82 | :param latent_z: sequence context with shape of [batch_size, latent_z_size]
83 |
84 |         :return: unnormalized logits of sentence words distribution probabilities
85 | with shape of [batch_size, seq_len, word_vocab_size]
86 |
87 | """
88 |         # this pad_sentences call will add a start_idx token at the start of each sequence
89 | # so, we are feeding in the previous token each time, in order to generate the
90 | # next token
91 | padded, lengths = pad_sentences(decoder_input, pad_idx=self.pad_idx, lpad=self.start_idx)
92 | embeddings = self.embedder(padded)
93 | embeddings = self.word_dropout(embeddings)
94 | [batch_size, seq_len, _] = embeddings.size()
95 | # decoder rnn is conditioned on context via additional bias = W_cond * z
96 | # to every input token
97 | latent_z = t.cat([latent_z] * seq_len, 1).view(batch_size, seq_len, -1)
98 | embeddings = t.cat([embeddings, latent_z], 2)
99 | rnn = self.ae_decoder
100 | rnn_out, _ = rnn(embeddings)
101 | rnn_out = rnn_out.contiguous().view(batch_size * seq_len, self.lstm_hidden_size)
102 | result = self.fc(rnn_out)
103 | result = result.view(batch_size, seq_len, self.vocab_size)
104 | return result
105 |
106 | def forward(self, input, encoder):
107 | """
108 | Encode an input into a vector representation
109 | params:
110 | input : word indices
111 |             encoder: one of [v1|v2|v2sent|ae_decoder|qt] (see get_encoder)
112 | """
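        # for the flat views: sequences are sorted by length for
        # pack_padded_sequence, run through the selected rnn, then unsorted to
        # restore the input order (view 2 uses the hierarchical path below)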
113 | if encoder == 'v2':
114 | return self.hierarchical_forward(input)
115 |
116 | batch_size = len(input)
117 | padded, lengths = pad_sentences(input, pad_idx=self.pad_idx)
118 | embeddings = self.embedder(padded)
119 | embeddings = self.word_dropout(embeddings)
120 | lengths, perm_idx = lengths.sort(0, descending=True)
121 | embeddings = embeddings[perm_idx]
122 | packed = torch.nn.utils.rnn.pack_padded_sequence(embeddings, lengths, batch_first=True)
123 | rnn = self.get_encoder(encoder)
124 | _, (_, final_state) = rnn(packed, None)
125 | _, unperm_idx = perm_idx.sort(0)
126 | final_state = final_state[:, unperm_idx]
127 | final_state = final_state.view(self.num_layers, 2, batch_size, self.lstm_hidden_size)[-1] \
128 | .transpose(0, 1).contiguous() \
129 | .view(batch_size, 2 * self.lstm_hidden_size)
130 | return final_state
131 |
132 | def hierarchical_forward(self, input):
133 | batch_size = len(input)
134 | padded, word_lens, sent_lens, max_sent_len = pad_paragraphs(input, pad_idx=self.pad_idx)
135 | embeddings = self.embedder(padded)
136 | embeddings = self.word_dropout(embeddings)
137 | word_lens, perm_idx = word_lens.sort(0, descending=True)
138 | embeddings = embeddings[perm_idx]
139 | packed = torch.nn.utils.rnn.pack_padded_sequence(
140 | embeddings, word_lens, batch_first=True)
141 | _, (_, final_word_state) = self.view2_word_rnn(packed, None)
142 | _, unperm_idx = perm_idx.sort(0)
143 | final_word_state = final_word_state[:, unperm_idx]
144 | final_word_state = final_word_state.view(
145 | self.num_layers, 2, batch_size*max_sent_len, self.lstm_hidden_size)[-1] \
146 | .transpose(0, 1).contiguous() \
147 | .view(batch_size, max_sent_len, 2 * self.lstm_hidden_size)
148 |
149 | sent_lens, sent_perm_idx = sent_lens.sort(0, descending=True)
150 | sent_embeddings = final_word_state[sent_perm_idx]
151 | sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_embeddings, sent_lens, batch_first=True)
152 | _, (_, final_sent_state) = self.view2_sent_rnn(sent_packed, None)
153 | _, sent_unperm_idx = sent_perm_idx.sort(0)
154 | final_sent_state = final_sent_state[:, sent_unperm_idx]
155 | final_sent_state = final_sent_state.view(
156 | self.num_layers, 2, batch_size, self.lstm_hidden_size)[-1] \
157 | .transpose(0, 1).contiguous() \
158 | .view(batch_size, 2 * self.lstm_hidden_size)
159 | return final_sent_state
160 |
161 | def qt_loss(self, target_view_state, input_view_state):
162 | """
163 | pick out the correct example in the target_view, based on the corresponding input_view
164 | """
165 | scores = input_view_state @ target_view_state.transpose(0, 1)
166 | batch_size = scores.size(0)
167 | targets = torch.from_numpy(np.arange(batch_size, dtype=np.int64))
168 | targets = targets.to(scores.device)
169 | loss = self.crit(scores, targets)
170 | _, argmax = scores.max(dim=-1)
171 | examples_correct = (argmax == targets)
172 | acc = examples_correct.float().mean().item()
173 | return loss, acc
174 |
175 | def reconst_loss(self, gnd_utts, reconst):
176 | """
177 | gnd_utts is a list of lists of indices (the outer list should be a minibatch)
178 | reconst is a tensor with the logits from a decoder [batchsize][seqlen][vocabsize]
179 | (should not have passed through softmax)
180 |
181 | reconst should be one token longer than the inputs in gnd_utts. the additional
182 | token to be predicted is the end_idx token
183 | """
184 | batch_size, seq_len, vocab_size = reconst.size()
185 | loss = 0
186 | # this pad_sentences call will add token self.end_idx at the end of each sequence
187 | padded, lengths = pad_sentences(gnd_utts, pad_idx=self.pad_idx, rpad=self.end_idx)
188 | batch_size = len(lengths)
189 | crit = nn.CrossEntropyLoss()
190 | reconst_flat = reconst.view(batch_size * seq_len, vocab_size)
191 | padded_flat = padded.view(batch_size * seq_len)
192 | loss += crit(reconst_flat, padded_flat)
193 | _, argmax = reconst.max(dim=-1)
194 | correct = (argmax == padded)
195 | acc = correct.float().mean().item()
196 | return loss, acc
197 |
198 |
199 | def from_embeddings(glove_path, id_to_token, token_to_id):
200 | vocab_size = len(token_to_id)
201 |
202 | # Load pre-trained GloVe vectors
203 | pretrained = {}
204 | word_emb_size = 0
205 | print('loading glove')
206 | for line in open(glove_path):
207 | parts = line.strip().split()
208 |         if len(parts) % 100 != 1:  # expect word + 100/200/300-dim vector; skip other lines
209 | continue
210 | word = parts[0]
211 | if word not in token_to_id:
212 | continue
213 | vector = [float(v) for v in parts[1:]]
214 | pretrained[word] = vector
215 | word_emb_size = len(vector)
216 | pretrained_list = []
217 | scale = np.sqrt(3.0 / word_emb_size)
218 | print('loading oov')
219 | for word in token_to_id:
220 | # apply lower() because all GloVe vectors are for lowercase words
221 | if word.lower() in pretrained:
222 | pretrained_list.append(np.array(pretrained[word.lower()]))
223 | else:
224 | random_vector = np.random.uniform(-scale, scale, [word_emb_size])
225 | pretrained_list.append(random_vector)
226 |
227 | print('instantiating model')
228 | model = MultiviewEncoders.construct_from_embeddings(
229 | embeddings=torch.FloatTensor(pretrained_list),
230 | num_layers=train.LSTM_LAYER,
231 | embedding_size=word_emb_size,
232 | lstm_hidden_size=train.LSTM_HIDDEN,
233 | word_dropout=train.WORD_DROPOUT_RATE,
234 | dropout=train.DROPOUT_RATE,
235 | vocab_size=vocab_size
236 | )
237 | model.to(device)
238 | return id_to_token, token_to_id, vocab_size, word_emb_size, model
239 |
240 |
241 | def load_model(model_path):
242 | with open(model_path, 'rb') as f:
243 | state = torch.load(f)
244 |
245 | id_to_token = state['id_to_token']
246 | word_emb_size = state['word_emb_size']
247 |
248 | token_to_id = {token: id for id, token in enumerate(id_to_token)}
249 | vocab_size = len(id_to_token)
250 |
251 | mvc_encoder = MultiviewEncoders(
252 | num_layers=train.LSTM_LAYER,
253 | embedding_size=word_emb_size,
254 | lstm_hidden_size=train.LSTM_HIDDEN,
255 | word_dropout=train.WORD_DROPOUT_RATE,
256 | dropout=train.DROPOUT_RATE,
257 | vocab_size=vocab_size
258 | )
259 | mvc_encoder.to(device)
260 | mvc_encoder.load_state_dict(state['model_state'])
261 | return id_to_token, token_to_id, vocab_size, word_emb_size, mvc_encoder
262 |
--------------------------------------------------------------------------------
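
A note on qt_loss above: it is a quick-thoughts-style in-batch negative sampling loss. Each input-view embedding is scored against every target-view embedding in the minibatch, and the gold label is the diagonal, i.e. each example's own partner. A minimal self-contained sketch (toy shapes, random tensors):

import torch
import torch.nn as nn

batch_size, hidden = 4, 8
input_view = torch.randn(batch_size, hidden)   # e.g. encoded utterances t
target_view = torch.randn(batch_size, hidden)  # e.g. encoded utterances t+1

scores = input_view @ target_view.transpose(0, 1)  # [batch, batch] similarities
targets = torch.arange(batch_size)                 # matching pairs sit on the diagonal
loss = nn.CrossEntropyLoss()(scores, targets)      # other rows act as negatives
acc = (scores.argmax(dim=-1) == targets).float().mean().item()
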
/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5 |
6 |
7 | def pad_sentences(
8 | sents, lpad=None, rpad=None, reverse=False, pad_idx=0,
9 | max_sent_len=100):
10 | sentences = []
11 | max_len = 0
12 | for i in range(len(sents)):
13 | if len(sents[i]) > max_sent_len:
14 | sentences.append(sents[i][:max_sent_len])
15 | else:
16 | sentences.append(sents[i])
17 | max_len = max(max_len, len(sentences[i]))
18 | if reverse:
19 | sentences = [sentence[::-1] for sentence in sentences]
20 | if lpad is not None:
21 | sentences = [[lpad] + sentence for sentence in sentences]
22 | max_len += 1
23 | if rpad is not None:
24 | sentences = [sentence + [rpad] for sentence in sentences]
25 | max_len += 1
26 | lengths = []
27 | for i in range(len(sentences)):
28 | lengths.append(len(sentences[i]))
29 | sentences[i] = sentences[i] + [pad_idx]*(max_len - len(sentences[i]))
30 | return (torch.LongTensor(sentences).to(device),
31 | torch.LongTensor(lengths).to(device))
32 |
33 |
34 | def pad_paragraphs(paras, pad_idx=0):
35 | sentences, lengths = [], []
36 | max_len = 0
37 | for para in paras:
38 | max_len = max(max_len, len(para))
39 | for para in paras:
40 | for sent in para:
41 | sentences.append(sent[:])
42 | lengths.append(len(para))
43 | for i in range(max_len - len(para)):
44 | sentences.append([pad_idx])
45 | ret_sents, sent_lens = pad_sentences(sentences, pad_idx=pad_idx)
46 | return ret_sents, sent_lens, torch.LongTensor(lengths).to(device), max_len
47 |
48 |
49 | def euclidean_metric(a, b):
50 | n = a.shape[0]
51 | m = b.shape[0]
52 | a = a.unsqueeze(1).expand(n, m, -1)
53 | b = b.unsqueeze(0).expand(n, m, -1)
54 | logits = -((a - b)**2).sum(dim=2)
55 | return logits
56 |
--------------------------------------------------------------------------------
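
Two quick sketches of the helpers above, with made-up toy inputs: pad_sentences right-pads to the longest sequence (optionally adding boundary tokens), and euclidean_metric returns pairwise negative squared distances, used as logits against class prototypes:

import torch
from model.utils import pad_sentences, euclidean_metric

padded, lengths = pad_sentences([[5, 6], [7]], pad_idx=0, rpad=3)
# padded  -> tensor([[5, 6, 3], [7, 3, 0]]); lengths -> tensor([3, 2])

queries = torch.tensor([[0., 0.], [1., 1.]])  # n=2 query embeddings
protos = torch.tensor([[0., 0.], [3., 4.]])   # m=2 class prototypes
print(euclidean_metric(queries, protos))
# tensor([[  0., -25.], [ -2., -13.]]) -- larger logit = closer prototype
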
/preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/dialog-intent-induction/6396f3153b0fda7e170b1df6b68e969b5e4eb16e/preprocessing/__init__.py
--------------------------------------------------------------------------------
/preprocessing/askubuntu.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from nltk.tokenize import word_tokenize
4 | from nltk.corpus import stopwords
5 | from bs4 import BeautifulSoup
6 |
7 |
8 | REMOVE_STOP = False
9 | QUESTION_WORDS = ["what", "when", "where", "why", "how", "who"]
10 | STOPWORDS = set(stopwords.words("english"))
11 | for w in QUESTION_WORDS:
12 | STOPWORDS.remove(w)
13 |
14 | # replace URL with a special token *=URL=*
15 | URL_PATTERN = r'(?:(?:https?|ftp)://)(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
16 |
17 | URLREG = re.compile(URL_PATTERN)
18 |
19 |
20 | def inner_preprocess_text_from_tao(text, max_len):
21 | # remove non-ascii chars
22 | text = text.encode("utf8").decode("ascii", "ignore")
23 |
24 | # remove URLs
25 | text = URLREG.sub('*=URL=*', text)
26 |
27 | text = text.casefold()
28 |
29 | # tokenize, filter and truncate
30 | words = word_tokenize(text)
31 | if REMOVE_STOP:
32 | words = [x for x in words if x not in STOPWORDS]
33 | if max_len > 0:
34 | words = words[:max_len]
35 | return u' '.join(words)
36 |
37 |
38 | def preprocess_from_tao(text, max_len):
39 | """
40 | adapted from Tao's code for http://aclweb.org/anthology/D18-1131
41 | """
42 | body_soup = BeautifulSoup(text, "lxml")
43 |
44 | # remove code
45 | [x.extract() for x in body_soup.findAll('code')]
46 |
47 | # also remove pre
48 | [x.extract() for x in body_soup.findAll('pre')]
49 |
50 | # remove "Possible Duplicate" section
51 | blk = body_soup.blockquote
52 | if blk and blk.text.strip().startswith("Possible Duplicate:"):
53 | blk.decompose()
54 | body_cleaned = inner_preprocess_text_from_tao(body_soup.text, max_len=max_len)
55 | assert "Possible Duplicate:" not in body_cleaned
56 |
57 | assert "\n" not in body_cleaned
58 | return body_cleaned
59 |
60 |
61 | class Preprocessor(object):
62 | def __init__(self, max_len):
63 | self.max_len = max_len
64 |
65 | def __call__(self, text):
66 | return True, preprocess_from_tao(text, max_len=self.max_len)
67 |
--------------------------------------------------------------------------------
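
Usage sketch for the preprocessor above (assumes the nltk 'punkt' and 'stopwords' corpora are downloaded and the lxml parser is installed; the HTML snippet is made up):

from preprocessing.askubuntu import Preprocessor

pre = Preprocessor(max_len=100)
ok, text = pre("<p>How do I install <code>sudo apt-get install vim</code>?</p>")
print(ok, text)  # True 'how do i install ?' -- the <code> block is stripped
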
/preprocessing/twitter_airlines.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import nltk
4 |
5 |
6 | class Preprocessor(object):
7 | def __init__(self):
 8 |         self.twitter_at_re = re.compile(r'@[a-zA-Z0-9]+')
 9 |         self.initials_re = re.compile(r'\^[A-Z][A-Z]([^A-Za-z]|$)')
10 |         self.star_initials_re = re.compile(r'\*[A-Z]+$')
11 |         self.money_re = re.compile(r'\$[0-9]+([-.][0-9]+)?')
12 |         self.number_re = re.compile(r'[0-9]+([-.: ()][0-9]+)?')
13 |         self.airport_code_re = re.compile(r'([^a-zA-Z])[A-Z][A-Z][A-Z]([^a-zA-Z])')
14 |         self.tag_re = re.compile(r'([^a-zA-Z])#[a-zA-Z]+')
15 |         self.url_re = re.compile(r'http[s]?:(//)?(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')  # from mleng data_preprocess.py
16 |
17 | def __call__(self, text, info_dict=None):
18 | is_valid = True
19 | if 'DM' in text:
20 | is_valid = False
21 | if 'https://t.co' in text:
22 | is_valid = False
23 |
24 | target_company = info_dict['target_company']
25 | cust_twitter_id = info_dict['cust_twitter_id']
26 | text = text.replace('\n', ' ').encode('ascii', 'ignore').decode('ascii')
27 | text = text.replace('@' + target_company, ' __company__ ').replace('@' + cust_twitter_id, ' __cust__ ')
28 |         text = text.replace('&lt;', '<').replace('&gt;', '>').replace('&amp;', '&')
29 | text = self.twitter_at_re.sub('__twitter_at__', text)
30 | text = self.initials_re.sub('__initials__', text)
31 | text = self.star_initials_re.sub('__initials__', text)
32 | text = self.money_re.sub(' __money__ ', text)
33 | text = self.number_re.sub(' __num__ ', text)
34 |         text = self.airport_code_re.sub(r'\g<1> __airport_code__ \g<2>', text)
35 |         text = self.tag_re.sub(r'\g<1> __twitter_tag__ ', text)
36 | text = text.casefold()
37 | text = self.url_re.sub(' __url__ ', text)
38 | text = ' '.join(nltk.word_tokenize(text))
39 |
40 | return is_valid, text
41 |
--------------------------------------------------------------------------------
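
Usage sketch with a made-up tweet and handles; the exact tokenization depends on nltk, but the placeholder substitutions come out roughly like this:

from preprocessing.twitter_airlines import Preprocessor

pre = Preprocessor()
is_valid, text = pre(
    "@AcmeAir my flight to JFK cost $300 ^AB",
    info_dict={'target_company': 'AcmeAir', 'cust_twitter_id': 'jane_doe'})
print(is_valid)  # True (no 'DM' request, no t.co link)
print(text)      # roughly: '__company__ my flight to __airport_code__ cost __money__ __initials__'
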
/pretrain.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 |
4 | import train as train_mod
5 |
6 |
7 | def pretrain_qt(dataset, perm_idx, expressions, train=True):
8 | """
9 | for each pair of utterances:
10 | - encodes first utterance using 'v1' encoder
11 | - encodes second utterance using 'qt_context' encoder
12 | uses negative sampling loss between these two embeddings, relative to the
13 | other second utterances in the batch
14 | """
15 | model, optimizer = expressions
16 |
17 | utts = []
18 | qt_ex = []
19 | for idx in perm_idx:
20 | v1, v2 = dataset[idx][0]
21 | conversation = [v1] + v2
22 | for n, utt in enumerate(conversation):
23 | utts.append(utt)
24 | if n > 0:
25 | num_utt = len(utts)
26 | ex = (num_utt - 2, num_utt - 1)
27 | qt_ex.append(ex)
28 | qt_ex = np.random.permutation(qt_ex)
29 |
30 | total_loss, total_acc = 0., 0.
31 | n_batch = (len(qt_ex) + train_mod.BATCH_SIZE - 1) // train_mod.BATCH_SIZE
32 | for i in range(n_batch):
33 | qt_ex_batch = qt_ex[i*train_mod.BATCH_SIZE:(i+1)*train_mod.BATCH_SIZE]
34 |
35 | v1_idxes, v2_idxes = list(zip(*[(ex[0].item(), ex[1].item()) for ex in qt_ex_batch]))
36 | v1_utts = [utts[idx] for idx in v1_idxes]
37 | v2_utts = [utts[idx] for idx in v2_idxes]
38 |
39 | v1_state = model(v1_utts, encoder='v1')
40 | v2_state = model(v2_utts, encoder='qt')
41 |
42 | loss, acc = model.qt_loss(v2_state, v1_state)
43 | total_loss += loss.item()
44 | total_acc += acc * len(qt_ex_batch)
45 | if train:
46 | optimizer.zero_grad()
47 | loss.backward()
48 | optimizer.step()
49 | return total_loss, total_acc / len(qt_ex)
50 |
51 |
52 | def after_pretrain_qt(model):
53 | model.view2_word_rnn = copy.deepcopy(model.view1_word_rnn)
54 |
55 |
56 | def pretrain_ae(dataset, perm_idx, expressions, train=True):
57 | """
58 | uses v1 encoder to encode all utterances in both view1 and view2
59 | to utterance-level embeddings
60 | uses 'ae_decoder' rnn from model to decode these embeddings
61 | (works at utterance level)
62 | """
63 | model, optimizer = expressions
64 |
65 | utterances = []
66 | for idx in perm_idx:
67 | v1, v2 = dataset[idx][0]
68 | conversation = [v1] + v2
69 | utterances += conversation
70 | utterances = np.random.permutation(utterances)
71 |
72 | total_loss, total_acc = 0., 0.
73 | n_batch = (len(utterances) + train_mod.AE_BATCH_SIZE - 1) // train_mod.AE_BATCH_SIZE
74 | for i in range(n_batch):
75 | utt_batch = utterances[i*train_mod.AE_BATCH_SIZE:(i+1)*train_mod.AE_BATCH_SIZE]
76 | enc_state = model(utt_batch, encoder='v1')
77 | reconst = model.decode(decoder_input=utt_batch, latent_z=enc_state)
78 | loss, acc = model.reconst_loss(utt_batch, reconst)
79 |
80 | total_loss += loss.item()
81 | total_acc += acc * len(utt_batch)
82 | if train:
83 | optimizer.zero_grad()
84 | loss.backward()
85 | optimizer.step()
86 | total_acc = total_acc / len(utterances)
87 | return total_loss, total_acc
88 |
89 |
90 | def after_pretrain_ae(model):
91 | # we'll use the view1 encoder for both view 1 and view 2
92 | model.view2_word_rnn = copy.deepcopy(model.view1_word_rnn)
93 |
--------------------------------------------------------------------------------
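
To make the pairing in pretrain_qt concrete: each conversation contributes (previous, next) utterance index pairs, which are then batched and matched by qt_loss. A standalone sketch of the pair construction:

conversation = ['u0', 'u1', 'u2']  # the view-1 utterance followed by the view-2 utterances
utts, qt_ex = [], []
for n, utt in enumerate(conversation):
    utts.append(utt)
    if n > 0:
        qt_ex.append((len(utts) - 2, len(utts) - 1))
print(qt_ex)  # [(0, 1), (1, 2)] -- each utterance is paired with its successor
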
/proc_data.py:
--------------------------------------------------------------------------------
1 | import csv
2 | from torch.utils.data import Dataset
3 | from nltk.tokenize import word_tokenize, sent_tokenize
4 | import numpy as np
5 | np.random.seed(0)
6 |
7 | PAD = "__PAD__"
8 | UNK = "__UNK__"
9 | START = "__START__"
10 | END = "__END__"
11 |
12 |
13 | class Dataset(Dataset):
14 | def __init__(self, fname, view1_col='view1_col', view2_col='view2_col', label_col='cluster_id',
15 | tokenized=True, max_sent=10, train_ratio=.9):
16 | """
17 | Args:
18 | fname: str, training data file
19 | view1_col: str, the column corresponding to view 1 input
20 | view2_col: str, the column corresponding to view 2 input
21 | label_col: str, the column corresponding to label
22 | """
23 |
24 | def tokens_to_idices(tokens):
25 | token_idices = []
26 | for token in tokens:
27 | if token not in token_to_id:
28 | token_to_id[token] = len(token_to_id)
29 | id_to_token.append(token)
30 | token_idices.append(token_to_id[token])
31 | return token_idices
32 |
33 | id_to_token = [PAD, UNK, START, END]
34 | token_to_id = {PAD: 0, UNK: 1, START: 2, END: 3}
35 | id_to_label = [UNK]
36 | label_to_id = {UNK: 0}
37 | data = []
38 | labels = []
39 | v1_utts = [] # needed for displaying cluster samples
40 | self.trn_idx, self.tst_idx = [], []
41 | self.trn_idx_no_unk = []
42 | with open(fname, 'r') as csvfile:
43 | reader = csv.DictReader(csvfile)
44 | for row in reader:
45 | view1_text, view2_text = row[view1_col], row[view2_col]
46 | label = row[label_col]
47 | if 'UNK' == label:
48 | label = UNK
49 |                 if ' <' in view2_text:
50 |                     view2_sents = view2_text.split("> <")
51 |                 else:
52 |                     view2_sents = sent_tokenize(view2_text)
53 | for i in range(len(view2_sents) - 1):
54 | view2_sents[i] = view2_sents[i] + '>'
55 | view2_sents[i+1] = '<' + view2_sents[i + 1]
56 | v1_utts.append(view1_text)
57 | if not tokenized:
58 | v1_tokens = word_tokenize(view1_text.lower())
59 | v2_tokens = [word_tokenize(sent.lower()) for sent in view2_sents]
60 | else:
61 | v1_tokens = view1_text.lower().split()
62 | v2_tokens = [sent.lower().split() for sent in view2_sents]
63 | v2_tokens = v2_tokens[:max_sent]
64 |
65 | v1_token_idices = tokens_to_idices(v1_tokens)
66 | v2_token_idices = [tokens_to_idices(tokens) for tokens in v2_tokens]
67 | v2_token_idices = [idices for idices in v2_token_idices if len(idices) > 0]
68 | if len(v1_token_idices) == 0 or len(v2_token_idices) == 0:
69 | continue
70 | if label not in label_to_id:
71 | label_to_id[label] = len(label_to_id)
72 | id_to_label.append(label)
73 | data.append((v1_token_idices, v2_token_idices))
74 | labels.append(label_to_id[label])
75 | if label == UNK and np.random.random_sample() < .1:
76 | self.tst_idx.append(len(data)-1)
77 | else:
78 | self.trn_idx.append(len(data)-1)
79 | if label != UNK:
80 | self.trn_idx_no_unk.append(len(data) - 1)
81 |
82 | self.v1_utts = v1_utts
83 | self.id_to_token = id_to_token
84 | self.token_to_id = token_to_id
85 | self.id_to_label = id_to_label
86 | self.label_to_id = label_to_id
87 | self.data = data
88 | self.labels = labels
89 |
90 | def __len__(self):
91 | return len(self.data)
92 |
93 | def __getitem__(self, i):
94 | return self.data[i], self.labels[i]
95 |
--------------------------------------------------------------------------------
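
A minimal sketch of the CSV layout Dataset consumes (hypothetical rows; the column names are the ones train.py passes in, and 'UNK'-labeled rows form the unlabeled pool, roughly 10% of which are held out as tst_idx):

import csv
import io

raw = io.StringIO(
    "view1,view2,label\n"
    "how do i get a refund,<hi there> <let me check>,refund\n"
    "my bag is lost,<sorry to hear> <filing a report>,UNK\n")
for row in csv.DictReader(raw):
    print(row['view1'], '->', row['label'])
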
/requirements.txt:
--------------------------------------------------------------------------------
1 | # we used nltk 3.4.5 for our experiments,
2 | # but that version has a reported security vulnerability,
3 | # so this pin was bumped to 3.6.4
4 | # nltk==3.4.5
5 | nltk==3.6.4
6 | numpy==1.17.3
7 | scikit-learn==0.21.3
8 | scipy==1.3.1
9 | torch==1.3.0
10 |
--------------------------------------------------------------------------------
/run_mvsc.py:
--------------------------------------------------------------------------------
1 | """
2 | Given a pre-trained model, run mvsc on it, and print scores vs gold standard
3 |
4 | we'll use the view1 encoder to encode each of view1 and view2, and then pass both through the MVSC algorithm
5 |
6 | this should probably be folded into run_clustering.py (it was originally forked from
7 | run_clustering.py, and combines some pieces of train_pca.py and train.py)
8 | """
9 | import time
10 | import random
11 | import datetime
12 | import argparse
13 |
14 | import sklearn.cluster, sklearn.metrics  # metrics needed for silhouette / davies_bouldin below
15 | import numpy as np
16 | import torch
17 |
18 | from metrics import cluster_metrics
19 | from model import multiview_encoders
20 | from proc_data import Dataset
21 |
22 | try:
23 | import multiview
24 | except Exception as e:
25 | print('please install https://github.com/mariceli3/multiview')
26 | print('eg pip install git+https://github.com/mariceli3/multiview')
27 | raise e
28 |
29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30 |
31 | BATCH_SIZE = 32
32 |
33 |
34 | def transform(dataset, perm_idx, model, view):
35 | """
36 | for view1 utterance, simply encode using view1 encoder
37 | for view 2 utterances:
38 | - encode each utterance, using view 1 encoder, to get utterance embeddings
39 | - take average of utterance embeddings to form view 2 embedding
40 | """
41 | model.eval()
42 | latent_zs, golds = [], []
43 | n_batch = (len(perm_idx) + BATCH_SIZE - 1) // BATCH_SIZE
44 | for i in range(n_batch):
45 | indices = perm_idx[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
46 | v1_batch, v2_batch = list(zip(*[dataset[idx][0] for idx in indices]))
47 | golds += [dataset[idx][1] for idx in indices]
48 | if view == 'v1':
49 | latent_z = model(v1_batch, encoder='v1')
50 | elif view == 'v2':
51 | latent_z_l = [model(conv, encoder='v1').mean(dim=0) for conv in v2_batch]
52 | latent_z = torch.stack(latent_z_l)
53 | latent_zs.append(latent_z.cpu().data.numpy())
54 | latent_zs = np.concatenate(latent_zs)
55 | return latent_zs, golds
56 |
57 |
58 | def run(
59 | ref, model_path, num_clusters, num_cluster_samples, seed,
60 | out_cluster_samples_file_hier,
61 | max_examples, out_cluster_samples_file,
62 | data_path, view1_col, view2_col, label_col,
63 | sampling_strategy, mvsc_no_unk):
64 | torch.manual_seed(seed)
65 | np.random.seed(seed)
66 | random.seed(seed)
67 |
68 | id_to_token, token_to_id, vocab_size, word_emb_size, mvc_encoder = \
69 | multiview_encoders.load_model(model_path)
70 | print('loaded model')
71 |
72 | print('loading dataset')
73 | dataset = Dataset(data_path, view1_col=view1_col, view2_col=view2_col, label_col=label_col)
74 | n_cluster = len(dataset.id_to_label) - 1
75 | print("loaded dataset, num of class = %d" % n_cluster)
76 |
77 | idxes = dataset.trn_idx_no_unk if mvsc_no_unk else dataset.trn_idx
78 | trn_idx = [x.item() for x in np.random.permutation(idxes)]
79 | if max_examples is not None:
80 | trn_idx = trn_idx[:max_examples]
81 |
82 | num_clusters = n_cluster if num_clusters is None else num_clusters
83 | print('clustering over num clusters', num_clusters)
84 |
85 | mvsc = multiview.mvsc.MVSC(
86 |         k=num_clusters
87 | )
88 | latent_z1s, golds = transform(dataset, trn_idx, mvc_encoder, view='v1')
89 | latent_z2s, _ = transform(dataset, trn_idx, mvc_encoder, view='v2')
90 | print('running mvsc', end='', flush=True)
91 | start = time.time()
92 | preds, eivalues, eivectors, sigmas = mvsc.fit_transform(
93 | [latent_z1s, latent_z2s], [False] * 2
94 | )
95 | print('...done')
96 | mvsc_time = time.time() - start
97 | print('time taken %.3f' % mvsc_time)
98 |
99 | lgolds, lpreds = [], []
100 | for g, p in zip(golds, list(preds)):
101 | if g > 0:
102 | lgolds.append(g)
103 | lpreds.append(p)
104 | prec, rec, f1 = cluster_metrics.calc_prec_rec_f1(
105 | gnd_assignments=torch.LongTensor(lgolds).to(device),
106 | pred_assignments=torch.LongTensor(lpreds).to(device))
107 | acc = cluster_metrics.calc_ACC(
108 | torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device))
109 | silhouette = sklearn.metrics.silhouette_score(latent_z1s, preds, metric='euclidean')
110 | davies_bouldin = sklearn.metrics.davies_bouldin_score(latent_z1s, preds)
111 | print(f'{datetime.datetime.now()} pretrain: eval prec={prec:.4f} rec={rec:.4f} f1={f1:.4f} '
112 | f'acc={acc:.4f} sil={silhouette:.4f}, db={davies_bouldin:.4f}')
113 |
114 |
115 | if __name__ == '__main__':
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument('--seed', type=int, default=123)
118 | parser.add_argument('--max-examples', type=int,
119 |                         help='cluster at most this many examples (the full set can be slow)')
120 | parser.add_argument('--mvsc-no-unk', action='store_true',
121 | help='only feed non-unk data to MVSC (to avoid oom)')
122 | parser.add_argument('--ref', type=str, required=True)
123 | parser.add_argument('--model-path', type=str, required=True)
124 | parser.add_argument('--data-path', type=str, default='./data/airlines_500_merged.csv')
125 | parser.add_argument('--view1-col', type=str, default='view1_col')
126 | parser.add_argument('--view2-col', type=str, default='view2_col')
127 | parser.add_argument('--label-col', type=str, default='cluster_id')
128 | parser.add_argument('--num-clusters', type=int, help='defaults to number of supervised classes')
129 | parser.add_argument('--num-cluster-samples', type=int, default=10)
130 | parser.add_argument('--sampling-strategy', type=str,
131 | choices=['uniform', 'nearest'], default='nearest')
132 | parser.add_argument('--out-cluster-samples-file-hier', type=str,
133 | default='tmp/cluster_samples_hier_{ref}.txt')
134 | parser.add_argument('--out-cluster-samples-file', type=str,
135 | default='tmp/cluster_samples_{ref}.txt')
136 | args = parser.parse_args()
137 | args.out_cluster_samples_file = args.out_cluster_samples_file.format(**args.__dict__)
138 | args.out_cluster_samples_file_hier = args.out_cluster_samples_file_hier.format(**args.__dict__)
139 | run(**args.__dict__)
140 |
--------------------------------------------------------------------------------
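
The view-2 branch of transform() above averages per-utterance encodings into a single conversation embedding; a toy stand-in for that step:

import torch

conv_batch = [torch.randn(3, 600), torch.randn(5, 600)]  # per-utterance encodings
latent_z = torch.stack([conv.mean(dim=0) for conv in conv_batch])
print(latent_z.shape)  # torch.Size([2, 600]) -- one embedding per conversation
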
/samplers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | np.random.seed(0)
4 |
5 |
6 | class CategoriesSampler():
7 | def __init__(self, labels, n_batch, n_cls, n_ins):
8 | """
9 | Args:
10 | labels: size=(dataset_size), label indices of instances in a data set
11 |             n_batch: int, number of batches for episodic training
12 | n_cls: int, number of sampled classes
13 | n_ins: int, number of instances considered for a class
14 |
15 | conceptually, this is for prototypical sampling:
16 | - for each training step, we sample 'n_ways' classes (in the paper), which is 'n_cls' here
17 | - we draw 'n_shot' examples of each class, ie n_ins here
18 | - these will be encoded, and averaged, to get the prototypes
19 | - and we draw 'n_query' query examples, of each class, which is also 'n_ins' here
20 | - these will be encoded, and then used to generate the prototype loss, relative to the
21 | prototypes
22 |
23 |         __iter__ returns a generator, which yields n_batch sets of training
24 |         data, one set (i.e. one episode's worth of indices) per yield
25 |
26 | """
27 | if not isinstance(labels, list):
28 | labels = labels.tolist()
29 |
30 | self.n_batch = n_batch
31 | self.n_cls = n_cls
32 | self.n_ins = n_ins
33 |
34 | self.classes = list(set(labels))
35 | labels = np.array(labels)
36 | self.cls_indices = {}
37 | for c in self.classes:
38 | indices = np.argwhere(labels == c).reshape(-1)
39 | self.cls_indices[c] = indices
40 |
41 | def __len__(self):
42 | return self.n_batch
43 |
44 | def __iter__(self):
45 | for _ in range(self.n_batch):
46 | batch = []
47 | classes = np.random.permutation(self.classes)[:self.n_cls]
48 | for c in classes:
49 | indices = self.cls_indices[c]
50 | while len(indices) < self.n_ins:
51 | indices = np.concatenate((indices, indices))
52 | batch.append(np.random.permutation(indices)[:self.n_ins])
53 | batch = np.stack(batch).flatten('F')
54 | yield batch
55 |
--------------------------------------------------------------------------------
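
A quick illustration of one sampled episode: the column-major ('F') flatten interleaves classes, so the first n_cls entries hold one instance per class, the next n_cls the second instance, and so on (the indices below are one possible random draw):

import numpy as np
from samplers import CategoriesSampler

labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]
sampler = CategoriesSampler(labels, n_batch=1, n_cls=2, n_ins=3)
batch = next(iter(sampler))
print(batch)  # e.g. [3 6 4 7 5 8] -- two classes, instances interleaved
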
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max_line_length = 100
3 |
--------------------------------------------------------------------------------
/tests/test_multiview_encoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from model.multiview_encoders import MultiviewEncoders
5 |
6 |
7 | def test_ae_decoder():
8 | dropout = word_dropout = 0
9 | embedding_size = 15
10 | num_layers = 1
11 | vocab_size = 13
12 | lstm_hidden_size = 11
13 |
14 | batch_size = 5
15 | seq_len = 6
16 |
17 | print('vocab_size', vocab_size)
18 | print('batch_size', batch_size)
19 | print('seq_len', seq_len)
20 | print('lstm_hidden_size', lstm_hidden_size)
21 |
22 | _input_idxes = torch.from_numpy(np.random.choice(
23 | vocab_size - 1, (batch_size, seq_len), replace=True)) + 1
24 | print('_input_idxes', _input_idxes)
25 | print('_input_idxes.size()', _input_idxes.size())
26 | input_idxes = []
27 | for b in range(batch_size):
28 | idxes = [_input_idxes[b][i].item() for i in range(seq_len)]
29 | input_idxes.append(idxes)
30 | latent_z = torch.rand((batch_size, lstm_hidden_size * 2))
31 | print('latent_z.size()', latent_z.size())
32 | print('input_idxes', input_idxes)
33 |
34 | encoders = MultiviewEncoders(vocab_size, num_layers, embedding_size, lstm_hidden_size,
35 | word_dropout, dropout)
36 | with torch.no_grad():
37 | logits = encoders.decode(decoder_input=input_idxes, latent_z=latent_z)
38 | print('logits.size()', logits.size())
39 |
40 | for n in range(batch_size):
41 | _input_idxes = input_idxes[n:n + 1]
42 | _latent_z = latent_z[n: n + 1]
43 | with torch.no_grad():
44 | _logits = encoders.decode(decoder_input=_input_idxes, latent_z=_latent_z)
45 | diff = (logits[n] - _logits).abs().max().item()
46 | assert diff < 1e-7
47 |
48 |
49 | def test_reconst_loss():
50 | dropout = word_dropout = 0
51 | embedding_size = 15
52 | num_layers = 1
53 | lstm_hidden_size = 11
54 |
55 | vocab_size = 13
56 | batch_size = 5
57 | seq_len = 6
58 |
59 | print('vocab_size', vocab_size)
60 | print('batch_size', batch_size)
61 | print('seq_len', seq_len)
62 | print('lstm_hidden_size', lstm_hidden_size)
63 |
64 | _input_idxes = torch.from_numpy(np.random.choice(
65 | vocab_size - 1, (batch_size, seq_len), replace=True)) + 1
66 | print('_input_idxes', _input_idxes)
67 | print('_input_idxes.size()', _input_idxes.size())
68 | input_idxes = []
69 | for b in range(batch_size):
70 | idxes = [_input_idxes[b][i].item() for i in range(seq_len)]
71 | input_idxes.append(idxes)
72 |
73 | encoders = MultiviewEncoders(vocab_size, num_layers, embedding_size, lstm_hidden_size,
74 | word_dropout, dropout)
75 |
76 | probs = torch.zeros(batch_size, seq_len + 1, vocab_size)
77 | probs[:, seq_len, encoders.end_idx] = 1
78 | for b in range(batch_size):
79 | for i, idx in enumerate(input_idxes[b]):
80 | probs[b, i, idx] = 1
81 | logits = probs.log()
82 | print('logits.sum(dim=-1)', logits.sum(dim=-1))
83 | print('logits.min(dim=-1)', logits.min(dim=-1)[0])
84 | print('logits.max(dim=-1)', logits.max(dim=-1)[0])
85 | _, logits_max = logits.max(dim=-1)
86 | print('logits_max', logits_max)
87 | assert (logits_max[:, :seq_len] == _input_idxes).all()
88 |
89 | loss, acc = encoders.reconst_loss(input_idxes, logits)
90 | loss = loss.item()
91 | print('loss', loss, 'acc', acc)
92 | assert acc == 1.0
93 | assert loss == 0.0
94 |
--------------------------------------------------------------------------------
/tests/test_samplers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import defaultdict
3 |
4 | import samplers
5 |
6 |
7 | def test_categories_sampler():
8 | N = 50
9 | n_batch = 5
10 | n_cls = 3
11 | n_ins = 4
12 |
13 | np.random.seed(123)
14 | labels = np.random.choice(5, N, replace=True)
15 | print('labels', labels)
16 |
17 | sampler = samplers.CategoriesSampler(labels, n_batch, n_cls, n_ins)
18 | for b, batch in enumerate(sampler):
19 | print('batch', batch)
20 | classes = []
21 | count_by_class = defaultdict(int)
22 | for i in batch:
23 | classes.append(labels[i])
24 | count_by_class[labels[i]] += 1
25 | print('classes', classes)
26 | print('count_by_class', count_by_class)
27 | assert len(count_by_class) == n_cls
28 | for v in count_by_class.values():
29 | assert v == n_ins
30 | assert b == n_batch - 1
31 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import copy
4 | import numpy as np
5 | from os.path import expanduser as expand
6 | import sklearn.cluster
7 | import sklearn.metrics
8 | import warnings
9 |
10 | import torch
11 | import torch.nn.functional as F
12 |
13 | from proc_data import Dataset
14 | from model import multiview_encoders
15 | from metrics import cluster_metrics
16 | from samplers import CategoriesSampler
17 | from model import utils
18 | import pretrain
19 |
20 |
21 | warnings.filterwarnings(action='ignore', category=RuntimeWarning)
22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23 | torch.manual_seed(0)
24 |
25 | LSTM_LAYER = 1
26 | LSTM_HIDDEN = 300
27 | WORD_DROPOUT_RATE = 0.
28 | DROPOUT_RATE = 0.
29 | BATCH_SIZE = 32
30 | AE_BATCH_SIZE = 8
31 | LEARNING_RATE = 0.001
32 |
33 |
34 | def do_pass(batches, shot, way, query, expressions, encoder):
35 | model, optimizer = expressions
36 | model.train()
37 | for i, batch in enumerate(batches, 1):
38 | data = [x[{'v1': 0, 'v2': 1}[encoder]] for x, _ in batch]
39 | p = shot * way
40 | data_shot, data_query = data[:p], data[p:]
41 |
42 | proto = model(data_shot, encoder=encoder)
43 | proto = proto.reshape(shot, way, -1).mean(dim=0)
44 |
45 | # ignore original labels, reassign labels from 0
46 | label = torch.arange(way).repeat(query)
47 | label = label.type(torch.LongTensor).to(device)
48 |
49 | logits = utils.euclidean_metric(model(data_query, encoder=encoder), proto)
50 | loss = F.cross_entropy(logits, label)
51 | optimizer.zero_grad()
52 |
53 | loss.backward()
54 | optimizer.step()
55 | return loss.item()
56 |
57 |
58 | def transform(dataset, perm_idx, model, encoder):
59 | model.eval()
60 | latent_zs, golds = [], []
61 | n_batch = (len(perm_idx) + BATCH_SIZE - 1) // BATCH_SIZE
62 | for i in range(n_batch):
63 | indices = perm_idx[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
64 | v1_batch, v2_batch = list(zip(*[dataset[idx][0] for idx in indices]))
65 | golds += [dataset[idx][1] for idx in indices]
66 | latent_z = model({'v1': v1_batch, 'v2': v2_batch}[encoder], encoder=encoder)
67 | latent_zs.append(latent_z.cpu().data.numpy())
68 | latent_zs = np.concatenate(latent_zs)
69 | return latent_zs, golds
70 |
71 |
72 | def calc_centroids(latent_zs, assignments, n_cluster):
73 | centroids = []
74 | for i in range(n_cluster):
75 | idx = np.where(assignments == i)[0]
76 | mean = np.mean(latent_zs[idx], 0)
77 | centroids.append(mean)
78 | return np.stack(centroids)
79 |
80 |
81 | def run_one_side(model, optimizer, preds_left, pt_batch, way, shot, query, n_cluster, dataset,
82 | right_encoder_side):
83 | """
84 |     right_encoder_side is 'v1' or 'v2': the view trained this half-step, using preds_left from the other view.
85 | """
86 | loss = 0
87 |
88 | sampler = CategoriesSampler(preds_left, pt_batch, way, shot + query)
89 | train_batches = [[dataset[dataset.trn_idx[idx]] for idx in indices] for indices in sampler]
90 | loss += do_pass(train_batches, shot, way, query, (model, optimizer), encoder=right_encoder_side)
91 |
92 | z_right, _ = transform(dataset, dataset.trn_idx, model, encoder=right_encoder_side)
93 | centroids = calc_centroids(z_right, preds_left, n_cluster)
94 | kmeans = sklearn.cluster.KMeans(
95 | n_clusters=n_cluster, init=centroids, max_iter=10, verbose=0)
96 | preds_right = kmeans.fit_predict(z_right)
97 |
98 | tst_z_right, _ = transform(dataset, dataset.tst_idx, model, encoder=right_encoder_side)
99 | tst_preds_right = kmeans.predict(tst_z_right)
100 |
101 | return loss, preds_right, tst_preds_right
102 |
103 |
104 | def main():
105 | parser = argparse.ArgumentParser()
106 | parser.add_argument('--data-path', type=str, default='./data/airlines_processed.csv')
107 | parser.add_argument('--glove-path', type=str, default='./data/glove.840B.300d.txt')
108 | parser.add_argument('--pre-model', type=str, choices=['ae', 'qt'], default='qt')
109 | parser.add_argument('--pre-epoch', type=int, default=0)
110 | parser.add_argument('--pt-batch', type=int, default=100)
111 | parser.add_argument('--model-path', type=str, help='path of pretrained model to load')
112 | parser.add_argument('--way', type=int, default=5)
113 | parser.add_argument('--num-epochs', type=int, default=100)
114 | parser.add_argument('--seed', type=int, default=0)
115 |
116 | parser.add_argument('--save-model-path', type=str)
117 |
118 | parser.add_argument('--view1-col', type=str, default='view1')
119 | parser.add_argument('--view2-col', type=str, default='view2')
120 | parser.add_argument('--label-col', type=str, default='label')
121 | args = parser.parse_args()
122 |
123 | np.random.seed(args.seed)
124 |
125 | print('loading dataset')
126 | dataset = Dataset(
127 | args.data_path, view1_col=args.view1_col, view2_col=args.view2_col,
128 | label_col=args.label_col)
129 | n_cluster = len(dataset.id_to_label) - 1
130 | print("num of class = %d" % n_cluster)
131 |
132 | if args.model_path is not None:
133 | id_to_token, token_to_id, vocab_size, word_emb_size, model = multiview_encoders.load_model(
134 | args.model_path)
135 | print('loaded model')
136 | else:
137 | id_to_token, token_to_id, vocab_size, word_emb_size, model = \
138 | multiview_encoders.from_embeddings(
139 | args.glove_path, dataset.id_to_token, dataset.token_to_id)
140 | print('created randomly initialized model')
141 | print('vocab_size', vocab_size)
142 |
143 | optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
144 | expressions = (model, optimizer)
145 |
146 | pre_acc, pre_state, pre_state_epoch = 0., None, None
147 | pretrain_method = {
148 | 'ae': pretrain.pretrain_ae,
149 | 'qt': pretrain.pretrain_qt,
150 | }[args.pre_model]
151 | for epoch in range(1, args.pre_epoch + 1):
152 | model.train()
153 | perm_idx = np.random.permutation(dataset.trn_idx)
154 | trn_loss, _ = pretrain_method(dataset, perm_idx, expressions, train=True)
155 | model.eval()
156 | _, tst_acc = pretrain_method(dataset, dataset.tst_idx, expressions, train=False)
157 | if tst_acc > pre_acc:
158 | pre_state = copy.deepcopy(model.state_dict())
159 | pre_acc = tst_acc
160 | pre_state_epoch = epoch
161 | print('{} epoch {}, train_loss={:.4f} test_acc={:.4f}'.format(
162 | datetime.datetime.now(), epoch, trn_loss, tst_acc))
163 |
164 | if args.pre_epoch > 0:
165 | # load best state
166 | model.load_state_dict(pre_state)
167 | print(f'loaded best state from epoch {pre_state_epoch}')
168 |
169 |     # copy the pretrained view-1 encoder weights into the view-2 encoder
170 | {
171 | 'ae': pretrain.after_pretrain_ae,
172 | 'qt': pretrain.after_pretrain_qt,
173 | }[args.pre_model](model)
174 |
175 |     # reinitialize the optimizer
176 | optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
177 | expressions = (model, optimizer)
178 | print('applied post-pretraining')
179 |
180 | kmeans = sklearn.cluster.KMeans(n_clusters=n_cluster, max_iter=300, verbose=0, random_state=0)
181 | z_v1, golds = transform(dataset, dataset.trn_idx, model, encoder='v1')
182 | preds_v1 = kmeans.fit_predict(z_v1)
183 |
184 | lgolds, lpreds = [], []
185 | for g, p in zip(golds, list(preds_v1)):
186 | if g > 0:
187 | lgolds.append(g)
188 | lpreds.append(p)
189 | prec, rec, f1 = cluster_metrics.calc_prec_rec_f1(
190 | gnd_assignments=torch.LongTensor(lgolds).to(device),
191 | pred_assignments=torch.LongTensor(lpreds).to(device))
192 | acc = cluster_metrics.calc_ACC(
193 | torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device))
194 |
195 | print(f'{datetime.datetime.now()} pretrain: test prec={prec:.4f} rec={rec:.4f} '
196 | f'f1={f1:.4f} acc={acc:.4f}')
197 |
198 | shot, way, query = 5, args.way, 15
199 |
200 | preds_v2 = None
201 | best_epoch, best_model, best_dev_f1 = None, None, None
202 | for epoch in range(1, args.num_epochs + 1):
203 | trn_loss = 0.
204 |
205 | _loss, preds_v2, tst_preds_v2 = run_one_side(
206 | model=model, optimizer=optimizer, preds_left=preds_v1,
207 | pt_batch=args.pt_batch, way=way, shot=shot, query=query, n_cluster=n_cluster,
208 | dataset=dataset, right_encoder_side='v2')
209 | trn_loss += _loss
210 |
211 | _loss, preds_v1, tst_preds_v1 = run_one_side(
212 | model=model, optimizer=optimizer, preds_left=preds_v2,
213 | pt_batch=args.pt_batch, way=way, shot=shot, query=query, n_cluster=n_cluster,
214 | dataset=dataset, right_encoder_side='v1')
215 | trn_loss += _loss
216 |
217 | dev_f1 = cluster_metrics.calc_f1(gnd_assignments=torch.LongTensor(tst_preds_v1).to(device),
218 | pred_assignments=torch.LongTensor(tst_preds_v2).to(device))
219 | dev_acc = cluster_metrics.calc_ACC(
220 | torch.LongTensor(tst_preds_v2).to(device), torch.LongTensor(tst_preds_v1).to(device))
221 |
222 | print('dev view 1 vs view 2: f1={:.4f} acc={:.4f}'.format(dev_f1, dev_acc))
223 |
224 | if best_dev_f1 is None or dev_f1 > best_dev_f1:
225 | print('new best epoch', epoch)
226 | best_epoch = epoch
227 | best_dev_f1 = dev_f1
228 | best_model = copy.deepcopy(model.state_dict())
229 | best_preds_v1 = preds_v1.copy()
230 | best_preds_v2 = preds_v2.copy()
231 |
232 | lgolds, lpreds = [], []
233 | for g, p in zip(golds, list(preds_v1)):
234 | if g > 0:
235 | lgolds.append(g)
236 | lpreds.append(p)
237 | prec, rec, f1 = cluster_metrics.calc_prec_rec_f1(
238 | gnd_assignments=torch.LongTensor(lgolds).to(device),
239 | pred_assignments=torch.LongTensor(lpreds).to(device))
240 | acc = cluster_metrics.calc_ACC(
241 | torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device))
242 |
243 | print(f'{datetime.datetime.now()} epoch {epoch}, test prec={prec:.4f} rec={rec:.4f} '
244 | f'f1={f1:.4f} acc={acc:.4f}')
245 |
246 | print('restoring model for best dev epoch', best_epoch)
247 | model.load_state_dict(best_model)
248 | preds_v1, preds_v2 = best_preds_v1, best_preds_v2
249 |
250 | lgolds, lpreds = [], []
251 | for g, p in zip(golds, list(preds_v1)):
252 | if g > 0:
253 | lgolds.append(g)
254 | lpreds.append(p)
255 | prec, rec, f1 = cluster_metrics.calc_prec_rec_f1(
256 | gnd_assignments=torch.LongTensor(lgolds).to(device),
257 | pred_assignments=torch.LongTensor(lpreds).to(device))
258 | acc = cluster_metrics.calc_ACC(
259 | torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device))
260 | print(f'{datetime.datetime.now()} test prec={prec:.4f} rec={rec:.4f} f1={f1:.4f} acc={acc:.4f}')
261 |
262 | if args.save_model_path is not None:
263 | preds_v1 = torch.from_numpy(preds_v1)
264 | if preds_v2 is not None:
265 | preds_v2 = torch.from_numpy(preds_v2)
266 | state = {
267 | 'model_state': model.state_dict(),
268 | 'id_to_token': dataset.id_to_token,
269 | 'word_emb_size': word_emb_size,
270 | 'v1_assignments': preds_v1,
271 | 'v2_assignments': preds_v2
272 | }
273 | with open(expand(args.save_model_path), 'wb') as f:
274 | torch.save(state, f)
275 | print('saved model to ', args.save_model_path)
276 |
277 |
278 | if __name__ == '__main__':
279 | main()
280 |
--------------------------------------------------------------------------------
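
The hand-off inside run_one_side is the core of the alternating-view loop: centroids of the 'right' view's embeddings are computed from the 'left' view's cluster assignments, then used to initialise k-means on the right view. A self-contained toy version of that step:

import numpy as np
import sklearn.cluster

z_right = np.array([[0., 0.], [.1, 0.], [5., 5.], [5., 5.1]])  # right-view embeddings
preds_left = np.array([0, 0, 1, 1])  # assignments carried over from the left view
centroids = np.stack([z_right[preds_left == i].mean(0) for i in range(2)])
kmeans = sklearn.cluster.KMeans(n_clusters=2, init=centroids, n_init=1, max_iter=10)
print(kmeans.fit_predict(z_right))  # [0 0 1 1]
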
/train_pca.py:
--------------------------------------------------------------------------------
1 | """
2 | train PCA baseline
3 |
4 | forked from train.py, with the MVC training stripped out and replaced by BoW + PCA
5 |
6 | if using mvsc, needs installation of:
7 | - https://github.com/mariceli3/multiview
8 | """
9 | import argparse
10 | import datetime
11 | import numpy as np
12 | import sklearn.cluster
13 | import warnings
14 | import time
15 | from sklearn.feature_extraction.text import TfidfVectorizer
16 | from sklearn.decomposition import TruncatedSVD
17 |
18 | import torch
19 |
20 | from proc_data import Dataset
21 | from metrics import cluster_metrics
22 |
23 |
24 | warnings.filterwarnings(action='ignore', category=RuntimeWarning)
25 | np.random.seed(0)
26 | torch.manual_seed(0)
27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 |
29 |
30 | def main():
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--data-path', type=str, default='./data/airlines_processed.csv')
33 | parser.add_argument('--model', type=str, choices=[
34 | 'view1pca', 'view2pca', 'wholeconvpca', 'mvsc'], default='view1pca')
35 | parser.add_argument('--pca-dims', type=int, default=600)
36 |
37 | parser.add_argument('--no-idf', action='store_true')
38 | parser.add_argument('--mvsc-no-unk', action='store_true',
39 | help='only feed non-unk data to MVSC (to avoid oom)')
40 |
41 | parser.add_argument('--view1-col', type=str, default='view1')
42 | parser.add_argument('--view2-col', type=str, default='view2')
43 | parser.add_argument('--label-col', type=str, default='tag')
44 | args = parser.parse_args()
45 | run(**args.__dict__)
46 |
47 |
48 | def run(data_path, model, pca_dims, view1_col, view2_col, label_col, no_idf, mvsc_no_unk):
49 | print('loading dataset')
50 | dataset = Dataset(data_path, view1_col=view1_col, view2_col=view2_col, label_col=label_col)
51 | n_cluster = len(dataset.id_to_label) - 1
52 | print("num of class = %d" % n_cluster)
53 |
54 | vocab_size = len(dataset.token_to_id)
55 | print('vocab_size', vocab_size)
56 |
57 | if model == 'mvsc':
58 | try:
59 | import multiview
60 | except Exception:
61 | print('please install https://github.com/mariceli3/multiview')
62 | return
63 | print('imported multiview ok')
64 |
65 | def run_pca(features):
66 | print('fitting tfidf vectorizer', flush=True, end='')
67 | vectorizer = TfidfVectorizer(token_pattern='\\d+', ngram_range=(1, 1), analyzer='word',
68 | min_df=0.0, max_df=1.0, use_idf=not no_idf)
69 | X = vectorizer.fit_transform(features)
70 | print(' ... done')
71 | print('X.shape', X.shape)
72 |
73 | print('running pca', flush=True, end='')
74 | pca = TruncatedSVD(n_components=pca_dims)
75 | X2 = pca.fit_transform(X)
76 | print(' ... done')
77 | return X2
78 |
79 | golds = [dataset[idx][1] for idx in dataset.trn_idx]
80 |
81 | if model in ['view1pca', 'view2pca', 'wholeconvpca']:
82 | if model == 'view1pca':
83 | utts = [dataset[idx][0][0] for idx in dataset.trn_idx]
84 | utts = [' '.join([str(idx) for idx in utt]) for utt in utts]
85 | elif model == 'view2pca':
86 | convs = [dataset[idx][0][1] for idx in dataset.trn_idx]
87 | utts = [[tok for utt in conv for tok in utt] for conv in convs]
88 | utts = [' '.join([str(idx) for idx in utt]) for utt in utts]
89 | elif model == 'wholeconvpca':
90 | v1 = [dataset[idx][0][0] for idx in dataset.trn_idx]
91 | convs = [dataset[idx][0][1] for idx in dataset.trn_idx]
92 | v2 = [[tok for utt in conv for tok in utt] for conv in convs]
93 | utts = []
94 | for n in range(len(v1)):
95 | utts.append(v1[n] + v2[n])
96 | utts = [' '.join([str(idx) for idx in utt]) for utt in utts]
97 |
98 | X2 = run_pca(utts)
99 |
100 | print('running kmeans', flush=True, end='')
101 | kmeans = sklearn.cluster.KMeans(
102 | n_clusters=n_cluster, max_iter=300, verbose=0, random_state=0)
103 | preds = kmeans.fit_predict(X2)
104 | print(' ... done')
105 | elif model == 'mvsc':
106 | mvsc = multiview.mvsc.MVSC(
107 | k=n_cluster
108 | )
109 | idxes = dataset.trn_idx_no_unk if mvsc_no_unk else dataset.trn_idx
110 | v1 = [dataset[idx][0][0] for idx in idxes]
111 | convs = [dataset[idx][0][1] for idx in idxes]
112 | v2 = [[tok for utt in conv for tok in utt] for conv in convs]
113 | v1 = [' '.join([str(idx) for idx in utt]) for utt in v1]
114 | v2 = [' '.join([str(idx) for idx in utt]) for utt in v2]
115 | v1_pca = run_pca(v1)
116 | v2_pca = run_pca(v2)
117 | print('running mvsc', end='', flush=True)
118 | start = time.time()
119 | preds, eivalues, eivectors, sigmas = mvsc.fit_transform(
120 | [v1_pca, v2_pca], [False] * 2
121 | )
122 | print('...done')
123 | mvsc_time = time.time() - start
124 | print('time taken %.3f' % mvsc_time)
125 |
126 | lgolds, lpreds = [], []
127 | for g, p in zip(golds, list(preds)):
128 | if g > 0:
129 | lgolds.append(g)
130 | lpreds.append(p)
131 | prec, rec, f1 = cluster_metrics.calc_prec_rec_f1(
132 | gnd_assignments=torch.LongTensor(lgolds).to(device),
133 | pred_assignments=torch.LongTensor(lpreds).to(device))
134 | acc = cluster_metrics.calc_ACC(
135 | torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device))
136 |
137 | print(f'{datetime.datetime.now()} eval f1={f1:.4f} prec={prec:.4f} rec={rec:.4f} acc={acc:.4f}')
138 |
139 | return prec, rec, f1, acc
140 |
141 |
142 | if __name__ == '__main__':
143 | main()
144 |
--------------------------------------------------------------------------------
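
Note the token_pattern='\d+' in run_pca above: the "documents" fed to the vectorizer are space-joined integer token ids rather than raw words, so the pattern must match digit runs (the default pattern would behave differently on single-character tokens). A toy check:

from sklearn.feature_extraction.text import TfidfVectorizer

docs = ['5 6 6 9', '5 9']  # each document is a space-joined list of token ids
vec = TfidfVectorizer(token_pattern='\\d+', min_df=0.0, max_df=1.0)
X = vec.fit_transform(docs)
print(sorted(vec.vocabulary_))  # ['5', '6', '9'] -- ids treated as the vocabulary
print(X.shape)                  # (2, 3)
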
/train_qt.py:
--------------------------------------------------------------------------------
1 | """
2 | pretrain using qt to get neural representations, then run kmeans on various
3 | combinations of the resulting representations
4 |
5 | this was initially forked from train.py, then modified
6 | """
7 | import argparse
8 | import datetime
9 | import copy
10 | import time
11 | import numpy as np
12 | import sklearn.cluster
13 | import warnings
14 | import torch
15 | from torch import autograd
16 |
17 | from proc_data import Dataset
18 | from model.multiview_encoders import MultiviewEncoders
19 | from metrics import cluster_metrics
20 | import pretrain
21 |
22 |
23 | warnings.filterwarnings(action='ignore', category=RuntimeWarning)
24 |
25 | torch.manual_seed(0)
26 | np.random.seed(0)
27 |
28 | LSTM_LAYER = 1
29 | LSTM_HIDDEN = 300
30 | WORD_DROPOUT_RATE = 0.
31 | DROPOUT_RATE = 0.
32 | BATCH_SIZE = 32
33 | LEARNING_RATE = 0.001
34 |
35 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
36 |
37 |
38 | def transform(data, model):
39 | model.eval()
40 | latent_zs = []
41 | n_batch = (len(data) + BATCH_SIZE - 1) // BATCH_SIZE
42 | for i in range(n_batch):
43 | data_batch = data[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
44 | with autograd.no_grad():
45 | latent_z = model(data_batch, encoder='v1')
46 | latent_zs.append(latent_z.cpu().data.numpy())
47 | latent_zs = np.concatenate(latent_zs)
48 | return latent_zs
49 |
50 |
51 | def calc_prec_rec_f1_acc(preds, golds):
52 | lgolds, lpreds = [], []
53 | for g, p in zip(golds, list(preds)):
54 | if g > 0:
55 | lgolds.append(g)
56 | lpreds.append(p)
57 | prec, rec, f1 = cluster_metrics.calc_prec_rec_f1(
58 | gnd_assignments=torch.LongTensor(lgolds).to(device),
59 | pred_assignments=torch.LongTensor(lpreds).to(device))
60 | acc = cluster_metrics.calc_ACC(
61 | torch.LongTensor(lpreds).to(device), torch.LongTensor(lgolds).to(device))
62 | return prec, rec, f1, acc
63 |
64 |
65 | def main():
66 | parser = argparse.ArgumentParser()
67 | parser.add_argument('--data-path', type=str, default='./data/airlines_processed.csv')
68 | parser.add_argument('--glove-path', type=str, default='./data/glove.840B.300d.txt')
69 | parser.add_argument('--pre-epoch', type=int, default=5)
70 | parser.add_argument('--pt-batch', type=int, default=100)
71 | parser.add_argument('--scenarios', type=str, default='view1,view2,concatviews,wholeconv',
72 | help='comma-separated, from [view1|view2|concatviews|wholeconv|mvsc]')
73 | parser.add_argument('--mvsc-no-unk', action='store_true',
74 | help='only feed non-unk data to MVSC (to avoid oom)')
75 |
76 | parser.add_argument('--view1-col', type=str, default='view1')
77 | parser.add_argument('--view2-col', type=str, default='view2')
78 | parser.add_argument('--label-col', type=str, default='tag')
79 | args = parser.parse_args()
80 |
81 | print('loading dataset')
82 | dataset = Dataset(args.data_path, view1_col=args.view1_col, view2_col=args.view2_col,
83 | label_col=args.label_col)
84 | n_cluster = len(dataset.id_to_label) - 1
85 | print("num of class = %d" % n_cluster)
86 |
87 | id_to_token, token_to_id = dataset.id_to_token, dataset.token_to_id
88 | vocab_size = len(dataset.token_to_id)
89 | print('vocab_size', vocab_size)
90 |
91 | # Load pre-trained GloVe vectors
92 | pretrained = {}
93 | word_emb_size = 0
94 | print('loading glove')
95 | for line in open(args.glove_path):
96 | parts = line.strip().split()
97 |         if len(parts) % 100 != 1:  # expect word + 100/200/300-dim vector; skip other lines
98 | continue
99 | word = parts[0]
100 | if word not in token_to_id:
101 | continue
102 | vector = [float(v) for v in parts[1:]]
103 | pretrained[word] = vector
104 | word_emb_size = len(vector)
105 | pretrained_list = []
106 | scale = np.sqrt(3.0 / word_emb_size)
107 | print('loading oov')
108 | for word in id_to_token:
109 | # apply lower() because all GloVe vectors are for lowercase words
110 | if word.lower() in pretrained:
111 | pretrained_list.append(np.array(pretrained[word.lower()]))
112 | else:
113 | random_vector = np.random.uniform(-scale, scale, [word_emb_size])
114 | pretrained_list.append(random_vector)
115 |
116 | model = MultiviewEncoders.from_embeddings(
117 | embeddings=torch.FloatTensor(pretrained_list),
118 | num_layers=LSTM_LAYER,
119 | embedding_size=word_emb_size,
120 | lstm_hidden_size=LSTM_HIDDEN,
121 | word_dropout=WORD_DROPOUT_RATE,
122 | dropout=DROPOUT_RATE,
123 | vocab_size=vocab_size
124 | )
125 | model.to(device)
126 | optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
127 |
128 | expressions = (model, optimizer)
129 | pre_acc, pre_state = 0., None
130 | pretrain_method = pretrain.pretrain_qt
131 | for epoch in range(1, args.pre_epoch + 1):
132 | model.train()
133 | perm_idx = np.random.permutation(dataset.trn_idx)
134 | trn_loss, _ = pretrain_method(dataset, perm_idx, expressions, train=True)
135 | model.eval()
136 | _, tst_acc = pretrain_method(dataset, dataset.tst_idx, expressions, train=False)
137 | if tst_acc > pre_acc:
138 | pre_state = copy.deepcopy(model.state_dict())
139 | pre_acc = tst_acc
140 | print(f'{datetime.datetime.now()} epoch {epoch}, train_loss={trn_loss:.4f} '
141 | f'test_acc={tst_acc:.4f}')
142 |
143 | if args.pre_epoch > 0:
144 | # load best state
145 | model.load_state_dict(pre_state)
146 |
147 |     # copy the pretrained view-1 encoder weights into the view-2 encoder
148 | pretrain.after_pretrain_qt(model)
149 |
150 | kmeans = sklearn.cluster.KMeans(n_clusters=n_cluster, max_iter=300, verbose=0, random_state=0)
151 |
152 | golds = [dataset[idx][1] for idx in dataset.trn_idx]
153 | for rep in args.scenarios.split(','):
154 | if rep == 'view1':
155 | data = [dataset[idx][0][0] for idx in dataset.trn_idx]
156 | encoded = transform(data=data, model=model)
157 | preds = kmeans.fit_predict(encoded)
158 | elif rep == 'view2':
159 | data = [dataset[idx][0][1] for idx in dataset.trn_idx]
160 | encoded = []
161 | for conv in data:
162 | encoded_conv = transform(data=conv, model=model)
163 | encoded_conv = torch.from_numpy(encoded_conv)
164 | encoded_conv = encoded_conv.mean(dim=0)
165 | encoded.append(encoded_conv)
166 | encoded = torch.stack(encoded, dim=0)
167 | # print('encoded.size()', encoded.size())
168 | encoded = encoded.numpy()
169 | preds = kmeans.fit_predict(encoded)
170 | elif rep == 'concatviews':
171 | v1_data = [dataset[idx][0][0] for idx in dataset.trn_idx]
172 | v1_encoded = torch.from_numpy(transform(data=v1_data, model=model))
173 |
174 | v2_data = [dataset[idx][0][1] for idx in dataset.trn_idx]
175 | v2_encoded = []
176 | for conv in v2_data:
177 | encoded_conv = transform(data=conv, model=model)
178 | encoded_conv = torch.from_numpy(encoded_conv)
179 | encoded_conv = encoded_conv.mean(dim=0)
180 | v2_encoded.append(encoded_conv)
181 | v2_encoded = torch.stack(v2_encoded, dim=0)
182 | concatview = torch.cat([v1_encoded, v2_encoded], dim=-1)
183 | print('concatview.size()', concatview.size())
184 | encoded = concatview.numpy()
185 | preds = kmeans.fit_predict(encoded)
186 | elif rep == 'wholeconv':
187 | encoded = []
188 | for idx in dataset.trn_idx:
189 | v1 = dataset[idx][0][0]
190 | v2 = dataset[idx][0][1]
191 | conv = [v1] + v2
192 | encoded_conv = transform(data=conv, model=model)
193 | encoded_conv = torch.from_numpy(encoded_conv)
194 | encoded_conv = encoded_conv.mean(dim=0)
195 | encoded.append(encoded_conv)
196 | encoded = torch.stack(encoded, dim=0)
197 | print('encoded.size()', encoded.size())
198 | encoded = encoded.numpy()
199 | preds = kmeans.fit_predict(encoded)
200 | elif rep == 'mvsc':
201 | try:
202 | import multiview
203 | except Exception:
204 | print('please install https://github.com/mariceli3/multiview')
205 | return
206 | print('imported multiview ok')
207 |
208 |             idxes = dataset.trn_idx_no_unk if args.mvsc_no_unk else dataset.trn_idx
209 |             v1_data = [dataset[idx][0][0] for idx in idxes]
210 | v1_encoded = torch.from_numpy(transform(data=v1_data, model=model))
211 |
212 |             v2_data = [dataset[idx][0][1] for idx in idxes]
213 | v2_encoded = []
214 | for conv in v2_data:
215 | encoded_conv = transform(data=conv, model=model)
216 | encoded_conv = torch.from_numpy(encoded_conv)
217 | encoded_conv = encoded_conv.mean(dim=0)
218 | v2_encoded.append(encoded_conv)
219 | v2_encoded = torch.stack(v2_encoded, dim=0)
220 |
221 | mvsc = multiview.mvsc.MVSC(
222 | k=n_cluster
223 | )
224 | print('running mvsc', end='', flush=True)
225 | start = time.time()
226 | preds, eivalues, eivectors, sigmas = mvsc.fit_transform(
227 | [v1_encoded, v2_encoded], [False] * 2
228 | )
229 | print('...done')
230 | mvsc_time = time.time() - start
231 | print('time taken %.3f' % mvsc_time)
232 | else:
233 | raise Exception('unimplemented rep', rep)
234 |
235 | prec, rec, f1, acc = calc_prec_rec_f1_acc(preds, golds)
236 | print(f'{datetime.datetime.now()} {rep}: eval prec={prec:.4f} rec={rec:.4f} f1={f1:.4f} '
237 | f'acc={acc:.4f}')
238 |
239 |
240 | if __name__ == '__main__':
241 | main()
242 |
--------------------------------------------------------------------------------