├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── code
│   ├── filtering
│   │   ├── average_word_embedding.py
│   │   ├── filter_problem.py
│   │   ├── identity.py
│   │   ├── semantic_clustering.py
│   │   └── sent2vec.py
│   ├── main.py
│   └── utils
│       ├── config.py
│       ├── filtering_demo.ipynb
│       ├── utils.py
│       └── visualization.ipynb
├── docs
│   ├── 3d.png
│   ├── class_diagram.uml
│   ├── cluster_examples.png
│   ├── example_responses.png
│   ├── help.png
│   ├── high_entropy.png
│   ├── metrics_table.png
│   ├── other_datasets.png
│   ├── uml.png
│   └── visu.png
├── requirements.txt
└── setup.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
4 | # Custom for Visual Studio
5 | *.cs diff=csharp
6 |
7 | # Standard to msysgit
8 | *.doc diff=astextplain
9 | *.DOC diff=astextplain
10 | *.docx diff=astextplain
11 | *.DOCX diff=astextplain
12 | *.dot diff=astextplain
13 | *.DOT diff=astextplain
14 | *.pdf diff=astextplain
15 | *.PDF diff=astextplain
16 | *.rtf diff=astextplain
17 | *.RTF diff=astextplain
18 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask instance folder
57 | instance/
58 |
59 | # Scrapy stuff:
60 | .scrapy
61 |
62 | # Sphinx documentation
63 | docs/_build/
64 |
65 | # PyBuilder
66 | target/
67 |
68 | # IPython Notebook
69 | .ipynb_checkpoints
70 |
71 | # pyenv
72 | .python-version
73 |
74 | # celery beat schedule file
75 | celerybeat-schedule
76 |
77 | # dotenv
78 | .env
79 |
80 | # virtualenv
81 | venv/
82 | ENV/
83 |
84 | # Spyder project settings
85 | .spyderproject
86 |
87 | # Rope project settings
88 | .ropeproject
89 |
90 | # =========================
91 | # Operating System Files
92 | # =========================
93 |
94 | # OSX
95 | # =========================
96 |
97 | .DS_Store
98 | .AppleDouble
99 | .LSOverride
100 |
101 | # Thumbnails
102 | ._*
103 |
104 | # Files that might appear in the root of a volume
105 | .DocumentRevisions-V100
106 | .fseventsd
107 | .Spotlight-V100
108 | .TemporaryItems
109 | .Trashes
110 | .VolumeIcon.icns
111 |
112 | # Directories potentially created on remote AFP share
113 | .AppleDB
114 | .AppleDesktop
115 | Network Trash Folder
116 | Temporary Items
117 | .apdisk
118 |
119 | # Windows
120 | # =========================
121 |
122 | # Windows image file caches
123 | Thumbs.db
124 | ehthumbs.db
125 |
126 | # Folder config file
127 | Desktop.ini
128 |
129 | # Recycle Bin used on file shares
130 | $RECYCLE.BIN/
131 |
132 | # Windows Installer files
133 | *.cab
134 | *.msi
135 | *.msm
136 | *.msp
137 |
138 | # Windows shortcuts
139 | *.lnk
140 |
141 | *.txt
142 | # Large files and folders
143 | private/
144 | data/
145 | responses/
146 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Richard Krisztian Csaky
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NeuralChatbots-DataFiltering
2 | [Paper](https://www.aclweb.org/anthology/P19-1567) · [Poster](https://ricsinaruto.github.io/website/docs/acl_poster_h.pdf) · [Seq2seqChatbots](https://github.com/ricsinaruto/Seq2seqChatbots) · [dialog-eval](https://github.com/ricsinaruto/dialog-eval) · [Wiki](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/wiki) · [Blog post](https://medium.com/@richardcsaky/neural-chatbots-are-dumb-65b6b40e9bd4)
3 | A lightweight repo for filtering dialog data with entropy-based methods.
4 |
5 | The program **reads the dataset**, runs **clustering** if needed, computes the **entropy** of individual utterances, and then **removes high entropy** utterances based on the threshold, and **saves the filtered dataset** to the output directory. See the [paper](https://www.aclweb.org/anthology/P19-1567) or the [poster](https://ricsinaruto.github.io/website/docs/acl_poster_h.pdf) for more details.
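
The pipeline above can be illustrated with a minimal, self-contained sketch. It is not the repository's implementation: it assumes the simplest case (no clustering, entropy taken over the exact target strings seen for each source), and the names `utterance_entropies` and `filter_pairs` are hypothetical.

```python
import math
from collections import Counter, defaultdict


def utterance_entropies(pairs):
  """Entropy (in bits) of the target distribution of each source utterance."""
  targets_by_source = defaultdict(list)
  for source, target in pairs:
    targets_by_source[source].append(target)

  entropies = {}
  for source, targets in targets_by_source.items():
    total = len(targets)
    entropies[source] = -sum(
        (count / total) * math.log2(count / total)
        for count in Counter(targets).values())
  return entropies


def filter_pairs(pairs, threshold):
  """Keep only the pairs whose source entropy does not exceed the threshold."""
  entropies = utterance_entropies(pairs)
  return [(s, t) for s, t in pairs if entropies[s] <= threshold]


# Toy data (made up for illustration).
pairs = [
    ('how are you ?', 'fine .'),
    ('how are you ?', 'not bad .'),
    ('how are you ?', 'fine .'),
    ('what time is it ?', 'noon .'),
]
entropies = utterance_entropies(pairs)
filtered = filter_pairs(pairs, threshold=0.5)
```

Here the generic question has several different replies, so its entropy is high and its pairs are dropped, while the specific question survives.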
6 |
7 | ## Features
8 | :floppy_disk: Cluster and filter any dialog data that you provide, or use pre-downloaded datasets
9 | :rocket: Various parameters can be used to adjust the algorithm
10 | :ok_hand: Choose between different entropy computation methods
11 | :twisted_rightwards_arrows: Choose between different clustering and filtering types
12 | :movie_camera: Visualize clustering and filtering results
13 |
14 |
15 |
16 | ## Setup
17 | Run `setup.py`, which installs the required packages and walks you through downloading additional data:
18 | ```
19 | python setup.py
20 | ```
21 | You can download all trained models used in [this](https://www.aclweb.org/anthology/P19-1567) paper from [here](https://mega.nz/#!mI0iDCTI!qhKoBiQRY3rLg3K6nxAmd4ZMNEX4utFRvSby_0q2dwU). Each training contains two checkpoints: one at the validation loss minimum and one after 150 epochs. The folder structures of the data and the trainings match each other exactly.
22 | ## Usage
23 | The main file can be called from anywhere, but when specifying paths to directories you should give them from the root of the repository.
24 | ```
25 | python code/main.py -h
26 | ```
27 |
28 | For the complete **documentation** visit the [wiki](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/wiki).
29 |
30 | ### Cluster Type
31 | * [identity](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/identity.py): This method does essentially no clustering; the entropy of utterances is calculated based on the conditional probability of utterance pairs.
32 | * [avg-embedding](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/average_word_embedding.py): This clustering type uses average word embedding sentence representations as in [this paper](https://pdfs.semanticscholar.org/3fc9/7768dc0b36449ec377d6a4cad8827908d5b4.pdf).
33 | * [sent2vec](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/sent2vec.py): This clustering type should use [sent2vec](https://github.com/epfml/sent2vec) sentence embeddings, but currently uses any embeddings you provide to it.
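
As an illustration of the avg-embedding representation, the sketch below averages word vectors weighted by `a / (a + p(w))`, the same form of weighting that appears in `average_word_embedding.py` with `a = 0.001`. It is a simplified stand-in, not the repository code: the function name `smooth_average`, the toy two-dimensional vectors, and the use of plain Python lists instead of NumPy are assumptions made for the example.

```python
from collections import Counter


def smooth_average(sentence, word_vectors, word_freq, a=0.001):
  """Frequency-weighted average of the vectors of the in-vocabulary words."""
  dim = len(next(iter(word_vectors.values())))
  weighted = []
  for word in sentence.split():
    if word in word_vectors:
      # Frequent words get a smaller weight than rare ones.
      weight = a / (a + word_freq.get(word, 0.0))
      weighted.append([weight * x for x in word_vectors[word]])

  if not weighted:
    # No known words: fall back to the zero vector.
    return [0.0] * dim
  return [sum(column) / len(weighted) for column in zip(*weighted)]


# Toy vocabulary and corpus (made up for illustration).
word_vectors = {'hello': [1.0, 0.0], 'world': [0.0, 1.0]}
corpus = ['hello world', 'hello there']

counts = Counter(w for s in corpus for w in s.split() if w in word_vectors)
total = sum(counts.values())
word_freq = {w: c / total for w, c in counts.items()}

vec = smooth_average('hello world', word_vectors, word_freq)
unknown = smooth_average('something else', word_vectors, word_freq)
```

Because "hello" is twice as frequent as "world" in the toy corpus, its component of `vec` ends up smaller.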
34 |
35 | ### Filter Type
36 | * **source**: Filters utterance pairs in which the source utterance's entropy is above the threshold.
37 | * **target**: Filters utterance pairs in which the target utterance's entropy is above the threshold.
38 | * **both**: Filters utterance pairs in which either the source or target utterance's entropy is above the threshold.
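
The three filter types can be summarized as a single decision per utterance pair. The helper below is a hypothetical sketch (the name `keep_pair` does not exist in the repository), and it deliberately ignores the additional cluster-size and length conditions that the actual code applies.

```python
def keep_pair(source_entropy, target_entropy, threshold, filter_type):
  """Return True if the utterance pair should stay in the dataset."""
  if filter_type == 'source':
    return source_entropy <= threshold
  if filter_type == 'target':
    return target_entropy <= threshold
  if filter_type == 'both':
    # The pair is removed if either side is above the threshold.
    return source_entropy <= threshold and target_entropy <= threshold
  raise ValueError('Unknown filter type: ' + filter_type)


# A pair with a high entropy source and a low entropy target.
kept_by_source = keep_pair(1.2, 0.3, threshold=1.0, filter_type='source')
kept_by_target = keep_pair(1.2, 0.3, threshold=1.0, filter_type='target')
kept_by_both = keep_pair(1.2, 0.3, threshold=1.0, filter_type='both')
```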
39 |
40 | ### [Filtering Demo](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/utils/filtering_demo.ipynb)
41 | In this Jupyter notebook you can try out the identity filtering method, implemented in less than 40 lines; it filters DailyDialog in a couple of seconds (you only need to provide a sources file and a targets file). The second part of the notebook contains visualizations of entropy, frequency and sentence length.
42 |
43 |
44 | ### [Visualization](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/utils/visualization.ipynb)
45 | Visualize clustering and filtering results by running the [visualization](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/utils/visualization.ipynb) Jupyter notebook. The notebook is largely self-explanatory; you only have to provide the directory containing the clustering files.
46 |
47 |
48 |
49 | ## Results & Examples
50 | ### High Entropy Utterances and Clusters from [DailyDialog](https://arxiv.org/abs/1710.03957)
51 |
52 |
53 | A high entropy cluster found by sent2vec.
54 |
55 | ### [Transformer](https://arxiv.org/abs/1706.03762) Trained on [DailyDialog](https://arxiv.org/abs/1710.03957)
56 | For an explanation of the metrics please check [this repo](https://github.com/ricsinaruto/dialog-eval) or the [paper](https://arxiv.org/pdf/1905.05471.pdf).
57 |
58 |
59 |
60 | More examples can be found in the appendix of the [paper](https://arxiv.org/pdf/1905.05471.pdf).
61 |
62 | ### [Transformer](https://arxiv.org/abs/1706.03762) Trained on [Cornell](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html) and [Twitter](https://github.com/facebookresearch/ParlAI/tree/master/parlai/tasks/twitter)
63 | For an explanation of the metrics please check [this repo](https://github.com/ricsinaruto/dialog-eval) or the [paper](https://arxiv.org/pdf/1905.05471.pdf).
64 |
65 |
66 | ## Contributing
67 | ##### Check the [issues](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/issues) for some additions where help is appreciated. Any contributions are welcome :heart:
68 | ##### Please try to follow the code syntax style used in the repo (flake8, 2 spaces indent, 80 char lines, commenting a lot, etc.)
69 |
70 | **New clustering** methods can be added by subclassing the [FilterProblem](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/filter_problem.py#L48) class; see [Identity](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/identity.py) for a minimal example. Normally you only have to redefine the *clustering* function, which clusters the sentences.
71 |
72 | Loading and saving data is taken care of; use the [Cluster](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/filter_problem.py#L24) and [DataPoint](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/filter_problem.py#L9) objects. Use the *data_points* lists to get the sentences for your clustering algorithm, and use the *clusters* lists to save the results of your clustering. Both classes can be subclassed if you want to add extra data (like a vector) to your DataPoint and Cluster objects.
73 |
74 | Finally add your class to the dictionary in [main](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/main.py#L90), and to the command-line argument choices.
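
To make these steps concrete, here is a self-contained sketch of what a new clustering method has to do. `DataPoint` and `Cluster` below are reduced stand-ins for the classes in `filter_problem.py`, and `FirstWordClustering` is a made-up toy method (one cluster per distinct first word). In the repository the class would subclass `FilterProblem` instead of building the `data_points` and `clusters` dictionaries itself.

```python
class DataPoint:
  """Reduced stand-in for filter_problem.DataPoint."""
  def __init__(self, string, index):
    self.string = string
    self.index = index
    self.cluster_index = 0


class Cluster:
  """Reduced stand-in for filter_problem.Cluster."""
  def __init__(self, medoid):
    self.medoid = medoid
    self.elements = []
    self.targets = []
    self.index = 0


class FirstWordClustering:
  """Hypothetical toy method: sentences sharing a first word are clustered."""
  def __init__(self, sources, targets):
    self.data_points = {
        'Source': [DataPoint(s, i) for i, s in enumerate(sources)],
        'Target': [DataPoint(t, i) for i, t in enumerate(targets)]}
    self.clusters = {'Source': [], 'Target': []}

  def clustering(self, tag):
    rev_tag = 'Target' if tag == 'Source' else 'Source'
    by_key = {}
    pairs = zip(self.data_points[tag], self.data_points[rev_tag])
    for data_point, other_side in pairs:
      words = data_point.string.split()
      key = words[0] if words else ''

      # Create a new cluster the first time a key is seen.
      if key not in by_key:
        cluster = Cluster(data_point)
        cluster.index = len(self.clusters[tag])
        by_key[key] = cluster
        self.clusters[tag].append(cluster)

      # Record the membership on both the data point and the cluster.
      cluster = by_key[key]
      data_point.cluster_index = cluster.index
      cluster.elements.append(data_point)
      cluster.targets.append(other_side)


problem = FirstWordClustering(
    ['hi there', 'hi you', 'bye now'], ['hello', 'hey', 'see you'])
problem.clustering('Source')
problem.clustering('Target')
```

The important part is the bookkeeping: every data point gets a `cluster_index`, and every cluster stores both its elements and the utterances on the other side of each pair, since the entropy is computed from those.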
75 |
76 |
77 | ## Authors
78 | * **[Richard Csaky](https://ricsinaruto.github.io)** (If you need any help with running the code: ricsinaruto@hotmail.com)
79 | * **[Patrik Purgai](https://github.com/Mrpatekful)** (clustering part)
80 |
81 | ## License
82 | This project is licensed under the MIT License - see the [LICENSE](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/LICENSE) file for details.
83 | Please include a link to this repo if you use it in your work and consider citing the following paper:
84 | ```
85 | @inproceedings{Csaky:2019,
86 | title = "Improving Neural Conversational Models with Entropy-Based Data Filtering",
87 | author = "Cs{\'a}ky, Rich{\'a}rd and Purgai, Patrik and Recski, G{\'a}bor",
88 | booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
89 | month = jul,
90 | year = "2019",
91 | address = "Florence, Italy",
92 | publisher = "Association for Computational Linguistics",
93 | url = "https://www.aclweb.org/anthology/P19-1567",
94 | pages = "5650--5669",
95 | }
96 | ```
97 |
--------------------------------------------------------------------------------
/code/filtering/average_word_embedding.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import requests
5 | import zipfile
6 | from collections import Counter
7 | from clint.textui import progress
8 |
9 | from filtering import semantic_clustering
10 |
11 |
12 | class AverageWordEmbedding(semantic_clustering.SemanticClustering):
13 |   '''
14 |   Averaged word embeddings clustering method. The meaning vector of the
15 |   sentence is created by the weighted average of the word vectors.
16 |   '''
17 |
18 |   # Download data from fasttext.
19 |   def download_fasttext(self):
20 |     # Open the url and download the data with progress bars.
21 |     data_stream = requests.get('https://dl.fbaipublicfiles.com/fasttext/' +
22 |       'vectors-english/wiki-news-300d-1M.vec.zip', stream=True)
23 |     zipped_path = os.path.join(self.input_dir, 'fasttext.zip')
24 |
25 |     with open(zipped_path, 'wb') as file:
26 |       total_length = int(data_stream.headers.get('content-length'))
27 |       for chunk in progress.bar(data_stream.iter_content(chunk_size=1024),
28 |                                 expected_size=total_length / 1024 + 1):
29 |         if chunk:
30 |           file.write(chunk)
31 |           file.flush()
32 |
33 |     # Extract file.
34 |     zip_file = zipfile.ZipFile(zipped_path, 'r')
35 |     zip_file.extractall(self.input_dir)
36 |     zip_file.close()
37 |
38 |   # Generate a vocab from data files.
39 |   def get_vocab(self, vocab_path):
40 |     vocab = []
41 |
42 |     with open(vocab_path, 'w') as file:
43 |       for dp in self.data_points['Source']:
44 |         vocab.extend(dp.string.split())
45 |       file.write('\n'.join(
46 |         [w[0] for w in Counter(vocab).most_common(self.config.vocab_size)]))
47 |
48 |   # Download FastText word embeddings.
49 |   def get_fast_text_embeddings(self):
50 |     vocab_path = os.path.join(self.input_dir, 'vocab.txt')
51 |     if not os.path.exists(vocab_path):
52 |       print('No vocab file named \'vocab.txt\' found in ' + self.input_dir)
53 |       print('Building vocab from data.')
54 |       self.get_vocab(vocab_path)
55 |
56 |     fasttext_path = os.path.join(self.input_dir, 'wiki-news-300d-1M.vec')
57 |     if not os.path.exists(fasttext_path):
58 |       self.download_fasttext()
59 |
60 |     vocab = [line.strip('\n') for line in open(vocab_path)]
61 |     vocab_path = os.path.join(self.input_dir, 'vocab.npy')
62 |
63 |     # Save the vectors for words in the vocab.
64 |     with open(fasttext_path, errors='ignore') as in_file:
65 |       with open(vocab_path, 'w') as out_file:
66 |         vectors = {}
67 |         for line in in_file:
68 |           tokens = line.strip().split()
69 |           vectors[tokens[0]] = line
70 |
71 |         for word in vocab:
72 |           try:
73 |             out_file.write(vectors[word])
74 |           except KeyError:
75 |             pass
76 |
77 |   # Generate the sentence embeddings.
78 |   def generate_embeddings(self, tag, vector_path):
79 |     '''
80 |     Params:
81 |       :tag: Whether it's source or target data.
82 |       :vector_path: Path to save the sentence vectors.
83 |     '''
84 |     vocab = {}
85 |     vocab_path = os.path.join(self.input_dir, 'vocab.npy')
86 |     if not os.path.exists(vocab_path):
87 |       print('File containing word vectors not found in ' + self.input_dir)
88 |       print('The file should be named \'vocab.npy\'')
89 |       print('If you would like to use FastText embeddings press \'y\'')
90 |       if input() == 'y':
91 |         self.get_fast_text_embeddings()
92 |       else:
93 |         sys.exit()
94 |
95 |     # Get the word embeddings.
96 |     with open(vocab_path) as v:
97 |       for line in v:
98 |         tokens = line.strip().split()
99 |         vocab[tokens[0]] = [0, np.array(list(map(float, tokens[1:])))]
100 |
101 |     embedding_dim = vocab[list(vocab)[0]][1].shape[0]
102 |     unique_sentences = set()
103 |     word_count = 0
104 |
105 |     # Statistics of number of words.
106 |     for dp in self.data_points[tag]:
107 |       unique_sentences.add(dp.string)
108 |       for word in dp.string.split():
109 |         if vocab.get(word):
110 |           vocab[word][0] += 1
111 |           word_count += 1
112 |
113 |     meaning_vectors = []
114 |     sentences = unique_sentences if self.unique else [
115 |       s.string for s in self.data_points[tag]]
116 |     # Calculate smooth average embedding.
117 |     for s in sentences:
118 |       vectors = []
119 |       for word in s.split():
120 |         vector = vocab.get(word)
121 |         if vector:
122 |           vectors.append(vector[1] * 0.001 / (0.001 + vector[0] / word_count))
123 |
124 |       num_vecs = len(vectors)
125 |       if num_vecs:
126 |         meaning_vectors.append(np.sum(np.array(vectors), axis=0) / num_vecs)
127 |       else:
128 |         meaning_vectors.append(np.zeros(embedding_dim))
129 |
130 |     np.save(vector_path, np.array(meaning_vectors).reshape(-1, embedding_dim))
131 |
--------------------------------------------------------------------------------
/code/filtering/filter_problem.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | from collections import Counter
5 |
6 | from utils import utils
7 |
8 |
9 | class DataPoint:
10 |   '''
11 |   A simple class that handles a string example.
12 |   '''
13 |   def __init__(self, string, index):
14 |     '''
15 |     Params:
16 |       :string: String to be stored.
17 |       :index: Number of the line in the file from which this sentence was read.
18 |     '''
19 |     self.index = index
20 |     self.string = ' '.join(string.strip('\n').split())
21 |     self.cluster_index = 0
22 |
23 |
24 | class Cluster:
25 |   '''
26 |   A class to handle one cluster in the clustering problem.
27 |   '''
28 |   def __init__(self, medoid):
29 |     '''
30 |     Params:
31 |       :medoid: Center of the cluster: a DataPoint object.
32 |     '''
33 |     self.medoid = medoid
34 |     self.elements = []
35 |     self.targets = []
36 |     self.entropy = 0
37 |     self.index = 0
38 |
39 |   # Append an element to the list of elements in the cluster.
40 |   def add_element(self, element):
41 |     self.elements.append(element)
42 |
43 |   # Append an element to the list of targets in the cluster.
44 |   def add_target(self, target):
45 |     self.targets.append(target)
47 |
48 | class FilterProblem:
49 |   '''
50 |   An abstract class to handle different types of filtering.
51 |   '''
52 |
53 |   @property
54 |   def DataPointClass(self):
55 |     return DataPoint
56 |
57 |   @property
58 |   def ClusterClass(self):
59 |     return Cluster
60 |
61 |   def __init__(self, config):
62 |     '''
63 |     Params:
64 |       :config: Config object storing all arguments.
65 |     '''
66 |     self.config = config
67 |     self.tag = config.filter_split
68 |     self.threshold = config.threshold
69 |     self.max_avg_length = config.max_avg_length
70 |     self.max_medoid_length = config.max_medoid_length
71 |     self.min_cluster_size = config.min_cluster_size
72 |     self.unique = config.unique
73 |
74 |     self.project_path = os.path.join(
75 |       os.path.dirname(os.path.abspath(__file__)), '..', '..')
76 |     self.output_dir = os.path.join(self.project_path, config.output_dir)
77 |     self.input_dir = os.path.join(self.project_path, config.data_dir)
78 |     self.type = config.filter_type
79 |
80 |     self.clusters = {'Source': [], 'Target': []}
81 |     self.data_points = {'Source': [], 'Target': []}
82 |     self.num_clusters = {'Source': config.source_clusters,
83 |                          'Target': config.target_clusters}
84 |
85 |     self.build('Source')
86 |     self.build('Target')
87 |
88 |   # Build statistics and full files.
89 |   def build(self, tag):
90 |     '''
91 |     Params:
92 |       :tag: 'Source' or 'Target'.
93 |     '''
94 |     splits = ['train', 'dev', 'test']
95 |     try:
96 |       files = [open(os.path.join(self.input_dir, split + tag + '.txt')).read()
97 |                for split in splits]
98 |     except FileNotFoundError:
99 |       print('Data files not found in ' + self.input_dir)
100 |       print('The following files should be here:')
101 |       print('trainSource.txt\ntrainTarget.txt\ndevSource.txt\ndevTarget.txt')
102 |       print('testSource.txt\ntestTarget.txt')
103 |       sys.exit()
104 |
105 |     full_path = os.path.join(self.input_dir, 'full' + tag + '.txt')
106 |     if not os.path.exists(full_path):
107 |       open(full_path, 'w').write(''.join(files))
108 |
109 |     if tag == 'Source':
110 |       # Count newline-terminated lines; split('\n') would overcount by one.
111 |       self.line_counts = dict(zip(splits, (f.count('\n') for f in files)))
112 |
113 |   # Main method that will run all the functions to do the filtering.
114 |   def run(self):
115 |     # If we have already done the clustering, don't redo it.
116 |     source_data = os.path.join(self.output_dir,
117 |                                self.tag + 'Source_cluster_elements.txt')
118 |     target_data = os.path.join(self.output_dir,
119 |                                self.tag + 'Target_cluster_elements.txt')
120 |     if os.path.exists(source_data) and os.path.exists(target_data):
121 |       print('Cluster files are in ' + self.output_dir + ', filtering now.')
122 |       self.load_clusters(source_data, target_data)
123 |       self.filtering()
124 |
125 |     else:
126 |       print('No cluster files in ' + self.output_dir + ', clustering now.')
127 |       self.read_inputs('Source')
128 |       self.read_inputs('Target')
129 |       self.clustering('Source')
130 |       self.clustering('Target')
131 |       self.save_clusters('Source')
132 |       self.save_clusters('Target')
133 |       self.filtering()
134 |
135 |   # Read the data and make it ready for clustering.
136 |   def read_inputs(self, tag):
137 |     '''
138 |     Params:
139 |       :tag: 'Source' or 'Target'.
140 |     '''
141 |     file = open(os.path.join(self.input_dir, self.tag + tag + '.txt'))
142 |     for i, line in enumerate(file):
143 |       self.data_points[tag].append(self.DataPointClass(line, i))
144 |     file.close()
145 |
146 |     print('Finished reading ' + tag + ' data.')
147 |
148 |   # Load clusters from file.
149 |   def load_clusters(self, source_path, target_path):
150 |     '''
151 |     Params:
152 |       :source_path: Path to source cluster elements.
153 |       :target_path: Path to target cluster elements.
154 |     '''
155 |     source_clusters = {}
156 |     target_clusters = {}
157 |     source_data_points = {}
158 |     target_data_points = {}
159 |
160 |     with open(source_path, 'r') as source_file:
161 |       for line in source_file:
162 |         # If the data contains these special characters it won't work.
163 |         [source_index_center, source_target, _] = line.split('<=====>')
164 |         [source_index, source_center] = source_index_center.split(';')
165 |         [source, target] = source_target.split('=')
166 |
167 |         # Initialize the source and target utterances.
168 |         source_data_points[int(source_index)] = self.DataPointClass(
169 |           source, int(source_index))
170 |         target_data_points[int(source_index)] = self.DataPointClass(
171 |           target, int(source_index))
172 |
173 |         # If this is a new cluster add it to the list.
174 |         if source_clusters.get(source_center) is None:
175 |           center = self.DataPointClass(source_center, 0)
176 |           source_clusters[source_center] = self.ClusterClass(center)
177 |           source_clusters[source_center].index = len(source_clusters) - 1
178 |
179 |         # Add the elements to the cluster.
180 |         source_data_points[int(source_index)].cluster_index = \
181 |           source_clusters[source_center].index
182 |         source_clusters[source_center].add_element(
183 |           source_data_points[int(source_index)])
184 |         source_clusters[source_center].add_target(
185 |           target_data_points[int(source_index)])
186 |
187 |     with open(target_path, 'r') as target_file:
188 |       for line in target_file:
189 |         [target_index_center, target_source, _] = line.split('<=====>')
190 |         [target_index, target_center] = target_index_center.split(';')
191 |         [target, source] = target_source.split('=')
192 |
193 |         # All elements are already added at this point.
194 |         target_data_point = target_data_points[int(target_index)]
195 |         source_data_point = source_data_points[int(target_index)]
196 |
197 |         # If this is a new cluster add it to the list.
198 |         if target_clusters.get(target_center) is None:
199 |           center = self.DataPointClass(target_center, 0)
200 |           target_clusters[target_center] = self.ClusterClass(center)
201 |           target_clusters[target_center].index = len(target_clusters) - 1
202 |
203 |         # Add the elements to the cluster.
204 |         target_data_point.cluster_index = target_clusters[target_center].index
205 |         target_clusters[target_center].add_element(target_data_point)
206 |         target_clusters[target_center].add_target(source_data_point)
207 |
208 |     # Save the data correctly into self.clusters.
209 |     def id(cl):
210 |       return sorted(list(cl), key=lambda x: cl[x].index)
211 |     self.clusters['Source'] = [source_clusters[i] for i in id(source_clusters)]
212 |     self.clusters['Target'] = [target_clusters[i] for i in id(target_clusters)]
213 |
214 |   # Cluster sources or targets, should be implemented in subclass.
215 |   def clustering(self, tag):
216 |     raise NotImplementedError
217 |
218 |   # Return a list of indices, showing which clusters should be filtered out.
219 |   def get_filtered_indices(self, tag):
220 |     '''
221 |     Params:
222 |       :tag: Source or Target.
223 |     '''
224 |     indices = []
225 |     for num_cl, cluster in enumerate(self.clusters[tag]):
226 |       # Build a distribution for the current cluster, based on the targets.
227 |       distribution = Counter([t.cluster_index for t in cluster.targets])
228 |
229 |       num_elements = len(cluster.elements)
230 |       # Calculate entropy.
231 |       entropy = 0
232 |       for cl_index in distribution:
233 |         if num_elements > 1:
234 |           probability = distribution[cl_index] / num_elements
235 |           entropy += probability * math.log(probability, 2)
236 |       cluster.entropy = -entropy
237 |
238 |       avg_length = (
239 |         sum(len(sent.string.split()) for sent in cluster.elements) /
240 |         (num_elements if num_elements > 0 else 1))
241 |       medoid_length = len(cluster.medoid.string.split())
242 |
243 |       # Filter based on threshold.
244 |       if (cluster.entropy > self.threshold and
245 |           avg_length < self.max_avg_length and
246 |           medoid_length < self.max_medoid_length):
247 |         indices.append(num_cl)
248 |
249 |     print('Finished filtering ' + tag + ' data.')
250 |     return indices
251 |
252 |   # Do the filtering of the dataset.
253 |   def filtering(self):
254 |     # These are not needed anymore.
255 |     self.data_points['Source'].clear()
256 |     self.data_points['Target'].clear()
257 |
258 |     # Get the filtered indices for both sides.
259 |     source_indices = self.get_filtered_indices('Source')
260 |     target_indices = self.get_filtered_indices('Target')
261 |
262 |     file_dict = {}
263 |     # We have to open 6 files in this case.
264 |     if self.tag == 'full':
265 |       name_list = ['trainS', 'trainT', 'devS', 'devT', 'testS', 'testT']
266 |       file_dict = dict(zip(name_list, self.open_6_files()))
267 |     else:
268 |       file_dict[self.tag + 'S'] = open(
269 |         os.path.join(self.output_dir, self.tag + 'Source.txt'), 'w')
270 |       file_dict[self.tag + 'T'] = open(
271 |         os.path.join(self.output_dir, self.tag + 'Target.txt'), 'w')
272 |
273 |     # Handle all other cases and open files.
274 |     if self.type == 'source' or self.type == 'both':
275 |       file_dict['source_entropy'] = open(
276 |         os.path.join(self.output_dir,
277 |                      self.tag + 'Source_cluster_entropies.txt'), 'w')
278 |     if self.type == 'target' or self.type == 'both':
279 |       file_dict['target_entropy'] = open(
280 |         os.path.join(self.output_dir,
281 |                      self.tag + 'Target_cluster_entropies.txt'), 'w')
282 |
283 |     # Save data and close files.
284 |     self.save_filtered_data(source_indices, target_indices, file_dict)
285 |     utils.close_n_files(file_dict)
286 |
287 |   # Save the new filtered datasets.
288 |   def save_filtered_data(self, source_indices, target_indices, file_dict):
289 |     '''
290 |     Params:
291 |       :source_indices: Indices of source clusters that will be filtered.
292 |       :target_indices: Indices of target clusters that will be filtered.
293 |       :file_dict: Dictionary containing all the files that we want to write.
294 |     '''
295 |     # Function for writing filtered source or target data to file.
296 |     def save_dataset(tag):
297 |       for num_cl, cluster in enumerate(self.clusters[tag]):
298 |         # Write cluster entropies.
299 |         file_dict[tag.lower() + '_entropy'].write(
300 |           cluster.medoid.string + ';' +
301 |           str(cluster.entropy) + ';' +
302 |           str(len(cluster.elements)) + '\n')
303 |
304 |         # Check if a cluster is smaller than threshold.
305 |         cluster_too_small = len(cluster.elements) < self.min_cluster_size
306 |         indices = source_indices if tag == 'Source' else target_indices
307 |
308 |         # Make sure that in 'both' case this is only run once.
309 |         if ((tag == 'Source' or self.type != 'both') and
310 |             (num_cl not in indices or cluster_too_small)):
311 |           # Filter one side.
312 |           for num_el, element in enumerate(cluster.elements):
313 |             target_cl = cluster.targets[num_el].cluster_index
314 |             if self.type == 'both':
315 |               cluster_too_small = (
316 |                 len(self.clusters['Target'][target_cl].elements) <
317 |                 self.min_cluster_size)
318 |             # Check both sides in 'both' case.
319 |             if ((target_cl not in target_indices or cluster_too_small) or
320 |                 self.type != 'both'):
321 |
322 |               # Reverse if Target.
323 |               source = element.string + '\n'
324 |               target = cluster.targets[num_el].string + '\n'
325 |               source_string = source if tag == 'Source' else target
326 |               target_string = target if tag == 'Source' else source
327 |
328 |               # Separate the full case.
329 |               if self.tag == 'full':
330 |                 if element.index < self.line_counts['train']:
331 |                   file_dict['trainS'].write(source_string)
332 |                   file_dict['trainT'].write(target_string)
333 |                 elif element.index < (self.line_counts['train'] +
334 |                                       self.line_counts['dev']):
335 |                   file_dict['devS'].write(source_string)
336 |                   file_dict['devT'].write(target_string)
337 |                 else:
338 |                   file_dict['testS'].write(source_string)
339 |                   file_dict['testT'].write(target_string)
340 |               else:
341 |                 file_dict[self.tag + 'S'].write(source_string)
342 |                 file_dict[self.tag + 'T'].write(target_string)
343 |
344 |     # Write source entropies and data to file.
345 |     if self.type == 'source' or self.type == 'both':
346 |       save_dataset('Source')
347 |     # Write target entropies and data to file.
348 |     if self.type == 'target' or self.type == 'both':
349 |       save_dataset('Target')
350 |
351 |   # Save clusters and their elements to files.
352 |   def save_clusters(self, tag):
353 |     '''
354 |     Params:
355 |       :tag: Whether it's source or target data.
356 |     '''
357 |     output = open(
358 |       os.path.join(self.output_dir,
359 |                    self.tag + tag + '_cluster_elements.txt'), 'w')
360 |
361 |     medoid_counts = []
362 |     rev_tag = 'Target' if tag == 'Source' else 'Source'
363 |
364 |     for cluster in self.clusters[tag]:
365 |       medoid_counts.append((cluster.medoid.string, len(cluster.elements)))
366 |
367 |       # Save together the source and target medoids and elements.
368 |       for source, target in zip(cluster.elements, cluster.targets):
369 |         output.write(
370 |           str(source.index) + ';' +
371 |           cluster.medoid.string + '<=====>' +
372 |           source.string + '=' +
373 |           target.string + '<=====>' +
374 |           self.clusters[rev_tag][target.cluster_index].medoid.string + ':' +
375 |           str(target.cluster_index) + '\n')
376 |     output.close()
377 |
378 |     # Save the medoids and the count of their elements, in decreasing order.
379 |     output = open(os.path.join(self.output_dir,
380 |                                self.tag + tag + '_clusters.txt'), 'w')
381 |     medoids = sorted(medoid_counts, key=lambda count: count[1], reverse=True)
382 |
383 |     for medoid in medoids:
384 |       output.write(medoid[0] + ':' + str(medoid[1]) + '\n')
385 |     output.close()
386 |
387 |     if tag == 'Target':
388 |       print('Finished clustering, proceeding with filtering.')
389 |
390 |   # Open the 6 files.
391 |   def open_6_files(self):
392 |     trainS = open(os.path.join(self.output_dir, 'trainSource.txt'), 'w')
393 |     trainT = open(os.path.join(self.output_dir, 'trainTarget.txt'), 'w')
394 |     devS = open(os.path.join(self.output_dir, 'devSource.txt'), 'w')
395 |     devT = open(os.path.join(self.output_dir, 'devTarget.txt'), 'w')
396 |     testS = open(os.path.join(self.output_dir, 'testSource.txt'), 'w')
397 |     testT = open(os.path.join(self.output_dir, 'testTarget.txt'), 'w')
398 |
399 |     return [trainS, trainT, devS, devT, testS, testT]
400 |
--------------------------------------------------------------------------------
/code/filtering/identity.py:
--------------------------------------------------------------------------------
1 | from filtering.filter_problem import FilterProblem
2 |
3 |
4 | class Identity(FilterProblem):
5 |   '''
6 |   Group identical utterances, so entropy and filtering act on single utterances.
7 |   '''
8 |
9 |   # Do the clustering of sources and targets.
10 |   def clustering(self, tag):
11 |     '''
12 |     Params:
13 |       :tag: Whether it's source or target data.
14 |     '''
15 |     rev_tag = 'Target' if tag == 'Source' else 'Source'
16 |
17 |     clean_sents = [' '.join(dp.string.split()) for dp in self.data_points[tag]]
18 |     sentence_set = list(set(clean_sents))
19 |
20 |     # Build a hash for efficient string searching.
21 |     sentence_dict = {}
22 |     for data_point, clean_sentence in zip(self.data_points[tag], clean_sents):
23 |       if clean_sentence in sentence_dict:
24 |         sentence_dict[clean_sentence].append(data_point)
25 |       else:
26 |         sentence_dict[clean_sentence] = [data_point]
27 |
28 |     print(tag + ': ' + str(len(sentence_set)) + ' clusters')
29 |
30 |     # Loop through the clusters.
31 |     for i, sentence in enumerate(sentence_set):
32 |       cl = self.ClusterClass(self.DataPointClass(sentence, 10))
33 |       self.clusters[tag].append(cl)
34 |
35 |       # Loop through the data points associated with this sentence.
36 |       for data_point in sentence_dict[sentence]:
37 |         data_point.cluster_index = i
38 |         cl.add_element(data_point)
39 |         cl.add_target(self.data_points[rev_tag][data_point.index])
40 |
--------------------------------------------------------------------------------
/code/filtering/semantic_clustering.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from sklearn.neighbors import BallTree
4 | from sklearn.cluster import MeanShift
5 |
6 | from filtering import filter_problem
7 | from utils.config import Config
8 |
9 | if Config.use_faiss:
10 |   try:
11 |     from faiss import Kmeans
12 |     Config.faiss = True
13 |   except ImportError:
14 |     print('Failed to import faiss, using SKLearn clustering instead.')
15 |     Config.use_faiss = False
16 |
17 | if not Config.use_faiss:
18 |   from sklearn.cluster import KMeans
19 |
20 |
21 | class DataPoint(filter_problem.DataPoint):
22 | '''
23 | A simple class that handles a string example.
24 | '''
25 | def __init__(self, string, index, meaning_vector=None):
26 | '''
27 | Params:
28 | :string: String to be stored.
29 | :index: Number of the line in the file from which this sentence was read.
30 | :meaning_vector: Numpy embedding vector for the sentence.
31 | '''
32 | super().__init__(string, index)
33 | self.meaning_vector = meaning_vector
34 |
35 |
36 | class SemanticClustering(filter_problem.FilterProblem):
37 | '''
38 | Base class for the meaning-based (semantic vector representation) clustering.
39 | The source and target sentences are read into an extended DataPoint object,
40 | that also contains a 'meaning_vector' attribute. This attribute holds
41 | the semantic vector representation of the sentence, which will be used
42 | by the clustering logic.
43 | '''
44 |
45 | @property
46 | def DataPointClass(self):
47 | return DataPoint
48 |
49 | def __init__(self, *args, **kwargs):
50 | super().__init__(*args, **kwargs)
51 | self.unique_data = {"Source": [], "Target": []}
52 |
53 | def clustering(self, tag):
54 | '''
55 | Params:
56 | :tag: Whether it's source or target data.
57 | '''
58 | data_points = self.unique_data if self.unique else self.data_points
59 | centroids = self.calculate_centroids(tag)
60 |
61 | data_point_vectors = np.array(
62 | [data_point.meaning_vector for data_point in
63 | data_points[tag]]).reshape(
64 | -1, data_points[tag][0].meaning_vector.shape[-1])
65 |
66 | # Get the actual data point for each centroid.
67 | tree = BallTree(data_point_vectors)
68 | _, centroids = tree.query(centroids, k=1)
69 |
70 |     # Get the closest centroid for each data point, as a data point index.
71 |     tree = BallTree(data_point_vectors[np.array(centroids).reshape(-1)])
72 |     _, labels = tree.query(data_point_vectors, k=1)
73 |     labels = np.array(centroids).reshape(-1)[labels.reshape(-1)]
74 |
75 | # Build the list of clusters.
76 | clusters = {index: self.ClusterClass(data_points[tag][index]) for
77 | index in {labels[_index] for _index in range(len(labels))}}
78 | clusters = [(clusters[cluster_index], cluster_index) for cluster_index in
79 | sorted(list(clusters))]
80 |
81 | label_lookup = {c[1]: i for i, c in enumerate(clusters)}
82 | clusters = [c[0] for c in clusters]
83 |
84 | rev_tag = 'Target' if tag == 'Source' else 'Source'
85 |
86 | # Store the cluster index for each unique sentence.
87 | if self.unique:
88 | cluster_ind_dict = {}
89 | for data_point, cluster_index in zip(data_points[tag], labels):
90 | cluster_ind_dict[data_point.string] = label_lookup[cluster_index]
91 |
92 | # This is different for unique clustering.
93 | for i, data_point in enumerate(self.data_points[tag]):
94 | cl_index = cluster_ind_dict[data_point.string]
95 | data_point.cluster_index = cl_index
96 | clusters[cl_index].add_element(data_point)
97 | clusters[cl_index].add_target(self.data_points[rev_tag][i])
98 |
99 | # Assign the actual clusters.
100 | else:
101 | for dp, cl_index in zip(self.data_points[tag], labels):
102 | cl_index = label_lookup[cl_index]
103 | dp.cluster_index = cl_index
104 | clusters[cl_index].add_element(dp)
105 | clusters[cl_index].add_target(self.data_points[rev_tag][dp.index])
106 |
107 | self.clusters[tag] = clusters
108 |
109 | def read_inputs(self, tag):
110 | '''
111 | Called twice for source and target data. It should implement the
112 | logic of reading the data from Source and Target files into the
113 | data_points list. Each sentence should be wrapped into an appropriate
114 | subclass of the DataPoint class. Source.npy and Target.npy should contain
115 | sentence embeddings, if not they have to be generated in a subclass.
116 | These vectors in the .npy files have to correspond to the loaded sentences.
117 |
118 | Params:
119 | :tag: Whether it's source or target data.
120 | '''
121 | super().read_inputs(tag)
122 |
123 | vector_path = os.path.join(self.input_dir, self.tag + tag + '.npy')
124 | if not os.path.exists(vector_path):
125 | print('No sentence embeddings found in ' + self.input_dir)
126 | print('They should be named \'fullSource.npy\' and \'fullTarget.npy\',')
127 | print('where each line is a vector corresponding to')
128 | print('sentences in \'fullSource.txt\' and \'fullTarget.txt\'.')
129 | print('Building sentence representations of ' + self.config.cluster_type)
130 | self.generate_embeddings(tag, vector_path)
131 |
132 | # Add vectors to sentences.
133 | sent_vectors = np.load(vector_path)
134 | if not self.unique:
135 | for index, dp in enumerate(self.data_points[tag]):
136 | dp.meaning_vector = sent_vectors[index]
137 |     else:  # Unique data points get the vector of their first line.
138 |       first = {}
139 |       for line, dp in enumerate(self.data_points[tag]):
140 |         first.setdefault(dp.string, line)
141 |       self.unique_data[tag] = [self.DataPointClass(s, i, sent_vectors[n])
142 |                                for i, (s, n) in enumerate(first.items())]
143 |
144 |   # Has to be implemented by subclass to generate sentence embeddings.
145 |   def generate_embeddings(self, tag, vector_path):
146 |     raise NotImplementedError
147 |
148 |   # Cluster the data and return the centers.
149 |   def calculate_centroids(self, tag):
150 |     '''
151 |     Params:
152 |       :tag: Whether it's source or target data.
153 |     '''
154 |     data_points = self.unique_data if self.unique else self.data_points
155 |     matrix = np.stack([dp.meaning_vector for dp in data_points[tag]])
156 |
157 |     if self.config.clustering_method == 'kmeans':
158 |       # Kmeans with either the faiss or the sklearn implementation.
159 |       if self.config.use_faiss:
160 |         kmeans = Kmeans(matrix.shape[1], self.num_clusters[tag],
161 |                         niter=20, verbose=True)
162 |         kmeans.train(matrix.astype(np.float32))
163 |         centroids = kmeans.centroids
164 |       else:
165 |         kmeans = KMeans(n_clusters=self.num_clusters[tag],
166 |                         random_state=0).fit(matrix)
167 |         centroids = kmeans.cluster_centers_
168 |
169 |     else:
170 |       mean_shift = MeanShift(bandwidth=self.config.bandwidth, n_jobs=10)
171 |       mean_shift.fit(matrix)
172 |       centroids = mean_shift.cluster_centers_
173 |
174 |     return centroids
175 |
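
The `clustering` method above works in two steps: each centroid is first snapped to its nearest data point (the medoid), and every data point is then assigned to its nearest medoid. The following is a minimal sketch of that two-step logic only. It uses brute-force distances in plain Python in place of `BallTree`, and `nearest`, `snap_and_assign`, and the toy vectors are hypothetical names introduced for illustration, not part of the repository.

```python
import math


def nearest(point, candidates):
  # Index of the candidate closest to `point` (Euclidean distance).
  return min(range(len(candidates)),
             key=lambda i: math.dist(point, candidates[i]))


def snap_and_assign(vectors, centroids):
  # Step 1: replace every centroid by the index of its nearest data point.
  medoid_ids = [nearest(c, vectors) for c in centroids]
  medoids = [vectors[i] for i in medoid_ids]
  # Step 2: label every data point with the index of its nearest medoid,
  # expressed as an index into `vectors`.
  labels = [medoid_ids[nearest(v, medoids)] for v in vectors]
  return medoid_ids, labels


# Toy data: two tight groups of 2D points and two cluster centers.
vectors = [(0.0, 0.0), (0.1, 0.0), (5.0, 5.0), (5.2, 5.0)]
centroids = [(0.04, 0.0), (5.15, 5.0)]
medoid_ids, labels = snap_and_assign(vectors, centroids)
print(medoid_ids, labels)
```

Using an actual data point as the cluster center means every cluster can be described by a real sentence, which is what the cluster files written by `save_clusters` rely on.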
--------------------------------------------------------------------------------
/code/filtering/sent2vec.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from filtering import semantic_clustering
4 |
5 |
6 | class Sent2vec(semantic_clustering.SemanticClustering):
7 |   def generate_embeddings(self, tag, vector_path):
8 |     print('Currently sent2vec only works with provided sentence embeddings.')
9 |     print('Check github.com/epfml/sent2vec for getting sentence embeddings.')
10 |     print('Btw any kind of sentence embeddings can be used if they are in the')
11 |     print('required format, I recommend github.com/hanxiao/bert-as-service.')
12 |     sys.exit()
13 |
--------------------------------------------------------------------------------
/code/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from utils.config import Config
4 | from filtering.average_word_embedding import AverageWordEmbedding
5 | from filtering.identity import Identity
6 | from filtering.sent2vec import Sent2vec
7 |
8 |
9 | def main():
10 | config = Config()
11 | parser = argparse.ArgumentParser(
12 | description='Code for filtering methods in: arxiv.org/abs/1905.05471. ' +
13 | 'These arguments can also be set in config.py, ' +
14 | 'and will be saved to the output directory.')
15 | parser.add_argument('-d', '--data_dir', default=config.data_dir,
16 | help='Directory containing the dataset in these files:' +
17 | ' (trainSource.txt, trainTarget.txt, devSource.txt, ' +
18 | 'devTarget.txt, testSource.txt, testTarget.txt, ' +
19 | 'vocab.txt)',
20 | metavar='')
21 | parser.add_argument('-o', '--output_dir', default=config.output_dir,
22 | help='Save here the filtered data and any output',
23 | metavar='')
24 | parser.add_argument('-l', '--load_config', default=config.load_config,
25 | help='Path to load config from file, or leave empty ' +
26 | '(default: %(default)s)',
27 | metavar='')
28 | parser.add_argument('-fs', '--filter_split', default=config.filter_split,
29 | help='Data split to filter, \'full\' filters ' +
30 | 'all splits (choices: %(choices)s)',
31 | metavar='', choices=['full', 'train', 'dev', 'test'])
32 | parser.add_argument('-ct', '--cluster_type', default=config.cluster_type,
33 | help='Clustering method (choices: %(choices)s)',
34 | metavar='',
35 | choices=['identity', 'avg_embedding', 'sent2vec'])
36 | parser.add_argument('-sc', '--source_clusters',
37 | default=config.source_clusters,
38 | help='Number of source clusters in case of Kmeans',
39 | metavar='', type=int)
40 | parser.add_argument('-tc', '--target_clusters',
41 | default=config.target_clusters,
42 | help='Number of target clusters in case of Kmeans',
43 | metavar='', type=int)
44 |   parser.add_argument('-u', '--unique', default=config.unique,
45 |                       help='Whether to cluster only unique sentences ' +
46 |                            '(default: %(default)s)',
47 |                       metavar='', type=lambda s: s.lower() in ('true', '1'))
48 | parser.add_argument('-vs', '--vocab_size',
49 | default=config.vocab_size,
50 | help='Vocab size, only used if vocab file not given',
51 | metavar='', type=int)
52 | parser.add_argument('-ft', '--filter_type', default=config.filter_type,
53 | help='Filtering way (choices: %(choices)s)',
54 | metavar='', choices=['source', 'target', 'both'])
55 | parser.add_argument('-mins', '--min_cluster_size',
56 | default=config.min_cluster_size,
57 | help='Clusters with fewer elements won\'t get filtered' +
58 | ' (default: %(default)s)',
59 | metavar='', type=int)
60 |   parser.add_argument('-t', '--threshold', default=config.threshold,
61 |                       help='Entropy threshold (default: %(default)s)',
62 |                       metavar='', type=float)
63 | parser.add_argument('-cm', '--clustering_method',
64 | default=config.clustering_method,
65 | help='Mean shift recommended (choices: %(choices)s)',
66 | metavar='', choices=['kmeans', 'mean_shift'])
67 | parser.add_argument('-bw', '--bandwidth', default=config.bandwidth,
68 | help='Mean shift bandwidth (default: %(default)s)',
69 | metavar='', type=float)
70 |   parser.add_argument('-f', '--use_faiss', default=config.use_faiss,
71 |                       help='Whether to use faiss for GPU based clustering ' +
72 |                            '(default: %(default)s)',
73 |                       metavar='', type=lambda s: s.lower() in ('true', '1'))
74 |   parser.add_argument('-maxal', '--max_avg_length',
75 |                       default=config.max_avg_length,
76 |                       help='Clusters with higher average sentence length ' +
77 |                            'won\'t get filtered (default: %(default)s)',
78 |                       metavar='', type=int)
79 | parser.add_argument('-maxml', '--max_medoid_length',
80 | default=config.max_medoid_length,
81 | help='Clusters with longer medoids won\'t get filtered' +
82 | ' (default: %(default)s)',
83 | metavar='', type=int)
84 |
85 | parser.parse_args(namespace=config)
86 | if config.load_config:
87 | config.load()
88 | config.save()
89 |
90 | filter_problems = {
91 | 'identity': Identity,
92 | 'avg_embedding': AverageWordEmbedding,
93 | 'sent2vec': Sent2vec,
94 | }
95 |
96 | problem = filter_problems[config.cluster_type](config)
97 | problem.run()
98 |
99 |
100 | if __name__ == "__main__":
101 | main()
102 |
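
Boolean options such as `--unique` and `--use_faiss` receive their value as a string on the command line. argparse passes that raw string to the `type` callable, so a plain `type=bool` would turn any non-empty value, including `'False'`, into `True`. The sketch below contrasts the two behaviours; `str_to_bool` and the `--naive` / `--parsed` flags are hypothetical and exist only for this illustration.

```python
import argparse


def str_to_bool(value):
  # Hypothetical helper: map common spellings to a real boolean.
  if value.lower() in ('true', '1', 'yes'):
    return True
  if value.lower() in ('false', '0', 'no'):
    return False
  raise argparse.ArgumentTypeError('Expected a boolean, got: ' + value)


parser = argparse.ArgumentParser()
# bool('False') is True, because the string is not empty.
parser.add_argument('--naive', type=bool, default=False)
# The converter inspects the text instead of its truthiness.
parser.add_argument('--parsed', type=str_to_bool, default=False)

args = parser.parse_args(['--naive', 'False', '--parsed', 'False'])
print(args.naive, args.parsed)
```

An alternative design would be `action='store_true'`, which drops the value entirely, at the cost of changing how the flags are written on the command line.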
--------------------------------------------------------------------------------
/code/utils/config.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 |
4 |
5 | # These can also be set as arguments via the command line.
6 | class Config:
7 |   data_dir = 'data/DailyDialog/baseline'  # Directory containing dataset.
8 |   output_dir = 'data/DailyDialog/baseline/filtered_data/avg_embedding'
9 |   load_config = None
10 |   source_clusters = 0
11 |   target_clusters = 0
12 |   filter_split = 'full'  # Which data split to filter.
13 |   cluster_type = 'avg_embedding'
14 |   unique = False  # Whether to cluster only unique sentences.
15 |   vocab_size = 16384  # Only used for average word embeddings.
16 |   filter_type = 'both'
17 |   min_cluster_size = 2  # Clusters with fewer elements won't get filtered.
18 |   threshold = 1.1  # Entropy threshold for filtering.
19 |   clustering_method = 'mean_shift'  # Kmeans or mean_shift.
20 |   bandwidth = 0.7  # Mean shift bandwidth.
21 |   use_faiss = False  # Whether to use the library for GPU based clustering.
22 |   max_avg_length = 15  # Clusters with longer sentences won't get filtered.
23 |   max_medoid_length = 50  # Clusters with longer medoids won't get filtered.
24 |   project_path = os.path.join(
25 |     os.path.dirname(os.path.abspath(__file__)), '..', '..')
26 |
27 |   # Save this object to output dir.
28 |   def save(self):
29 |     out_dir = os.path.join(self.project_path, self.output_dir)
30 |     if not os.path.exists(out_dir):
31 |       os.makedirs(out_dir)
32 |
33 |     file = open(os.path.join(out_dir, 'config'), 'wb')
34 |     file.write(pickle.dumps(self.__dict__))
35 |     file.close()
36 |
37 |   # Load from output dir.
38 |   def load(self):
39 |     load_config = os.path.join(self.project_path, self.load_config, 'config')
40 |     file = open(load_config, 'rb')
41 |     self.__dict__ = pickle.loads(file.read())
42 |     file.close()
43 |
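
`save` pickles `self.__dict__`, which holds only the attributes set on the instance. Class-level defaults that were never assigned on the instance are not written to the file and are read from the class again after `load`. The sketch below shows this round trip with a cut-down, hypothetical `Config` and a temporary directory standing in for the output directory.

```python
import os
import pickle
import tempfile


class Config:
  # Class-level defaults, these do not appear in an instance's __dict__.
  threshold = 1.1
  cluster_type = 'avg_embedding'


config = Config()
config.threshold = 2.0  # Instance attribute, stored in __dict__.

# Save only the instance attributes, as Config.save does.
path = os.path.join(tempfile.mkdtemp(), 'config')
with open(path, 'wb') as file:
  file.write(pickle.dumps(config.__dict__))

# Restore them into a fresh object, as Config.load does.
restored = Config()
with open(path, 'rb') as file:
  restored.__dict__ = pickle.loads(file.read())

print(restored.__dict__)
print(restored.cluster_type)
```

One consequence is that a saved config picks up any later change to a class-level default it did not override.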
--------------------------------------------------------------------------------
/code/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Close the files (a dict of name -> file object) after writing.
2 | def close_n_files(files):
3 |   for file_name in files:
4 |     files[file_name].close()
5 |
--------------------------------------------------------------------------------
/code/utils/visualization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "data": {
10 | "application/javascript": [
11 | "IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
12 | " return false;\n",
13 | "}\n"
14 | ],
15 | "text/plain": [
16 | ""
17 | ]
18 | },
19 | "metadata": {},
20 | "output_type": "display_data"
21 | }
22 | ],
23 | "source": [
24 | "%%javascript\n",
25 | "IPython.OutputArea.prototype._should_scroll = function(lines) {\n",
26 | " return false;\n",
27 | "}"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {
34 | "colab": {
35 | "base_uri": "https://localhost:8080/",
36 | "height": 122
37 | },
38 | "colab_type": "code",
39 | "executionInfo": {
40 | "elapsed": 27098,
41 | "status": "ok",
42 | "timestamp": 1562670013345,
43 | "user": {
44 | "displayName": "Richard Csaky",
45 | "photoUrl": "https://lh6.googleusercontent.com/-wOnQ8NuCQqo/AAAAAAAAAAI/AAAAAAAAABQ/GcPUlAm-a98/s64/photo.jpg",
46 | "userId": "14062868679888781538"
47 | },
48 | "user_tz": -120
49 | },
50 | "id": "C_bdg07CY0fR",
51 | "outputId": "32f43021-8632-46ff-f5f0-4ca318d6cefe"
52 | },
53 | "outputs": [],
54 | "source": [
55 | "# Only run if you want to use google drive inside the notebook.\n",
56 | "from google.colab import drive\n",
57 | "drive.mount('/content/drive')"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {
63 | "colab_type": "text",
64 | "id": "-4T8h7psXxMI"
65 | },
66 | "source": [
67 | "# Parameters"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 2,
73 | "metadata": {
74 | "cellView": "both",
75 | "colab": {},
76 | "colab_type": "code",
77 | "id": "obzOs3Y1X6qm"
78 | },
79 | "outputs": [],
80 | "source": [
81 | "#@title Path to directory containing filtered files and config file\n",
82 | "DIR = \"data/DailyDialog/baseline/filtered_data/\" #@param {type:\"string\"}"
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {
88 | "colab_type": "text",
89 | "id": "AnLdGG7LSd1B"
90 | },
91 | "source": [
92 | "# Setup\n",
93 | "Run some setup code and define the functions that will be used later."
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 3,
99 | "metadata": {
100 | "colab": {
101 | "base_uri": "https://localhost:8080/",
102 | "height": 316
103 | },
104 | "colab_type": "code",
105 | "executionInfo": {
106 | "elapsed": 652,
107 | "status": "error",
108 | "timestamp": 1562671037191,
109 | "user": {
110 | "displayName": "Richard Csaky",
111 | "photoUrl": "https://lh6.googleusercontent.com/-wOnQ8NuCQqo/AAAAAAAAAAI/AAAAAAAAABQ/GcPUlAm-a98/s64/photo.jpg",
112 | "userId": "14062868679888781538"
113 | },
114 | "user_tz": -120
115 | },
116 | "id": "eoZ4rowjSrbu",
117 | "outputId": "0e84048e-681f-41b6-c33f-2d54c9981f7a"
118 | },
119 | "outputs": [],
120 | "source": [
121 | "%matplotlib inline\n",
122 | "\n",
123 | "import matplotlib.pyplot as plt\n",
124 | "import os\n",
125 | "import matplotlib.pyplot as plt\n",
126 | "import numpy as np\n",
127 | "import operator\n",
128 | "\n",
129 | "from config import Config\n",
130 | "\n",
131 | "\n",
132 | "# Load config file from specified directory\n",
133 | "Config.load_config = DIR\n",
134 | "config = Config()\n",
135 | "config.load()\n",
136 | "\n",
137 | "plt.rcParams.update({'font.size': 14})"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 4,
143 | "metadata": {
144 | "colab": {},
145 | "colab_type": "code",
146 | "id": "kejzW0VnSP0L"
147 | },
148 | "outputs": [],
149 | "source": [
150 | "# Visualization function for the clustering data.\n",
151 | "def _visualize(file, tag, fig_list):\n",
152 | " '''\n",
153 | " Params:\n",
154 | " :file: Clustering file, from which to read data.\n",
155 | " :tag: Can be 'Source' or 'Target'.\n",
156 | " :fig_list: A list containing the plots, which we will draw.\n",
157 | " '''\n",
158 | " sentence_entropy = []\n",
159 | " entropies_all = []\n",
160 | " entropies = []\n",
161 | " sentence_cl_size = []\n",
162 | " cl_sizes_all = []\n",
163 | " cl_sizes = []\n",
164 | " lengths = []\n",
165 | "\n",
166 | " for line in file:\n",
167 | " [sentence, entropy, cl_size] = line.split(';')\n",
168 | " entropy = float(entropy)\n",
169 | " cl_size = int(cl_size)\n",
170 | "\n",
171 | " # Populate the lists.\n",
172 | " sentence_entropy.append([sentence, entropy])\n",
173 | " entropies_all.extend([entropy] * cl_size)\n",
174 | " entropies.append(entropy)\n",
175 | " sentence_cl_size.append([sentence, cl_size])\n",
176 | " cl_sizes_all.extend([cl_size] * cl_size)\n",
177 | " cl_sizes.append(cl_size)\n",
178 | " lengths.append(len(sentence.split()))\n",
179 | "\n",
180 | " # Draw the plots, and set properties.\n",
181 | " fig_list[0].plot(sorted(entropies_all))\n",
182 | " fig_list[0].set_xlabel('Sentence no.')\n",
183 | " fig_list[0].set_ylabel('Entropy')\n",
184 | " #fig_list[0].axis([0, 90000, -0.2, 9])\n",
185 | "\n",
186 | " fig_list[1].plot(sorted(cl_sizes_all))\n",
187 | " fig_list[1].set_xlabel('Sentence no.')\n",
188 | " fig_list[1].set_ylabel('Cluster size')\n",
189 | " #fig_list[1].axis([0, 90000, -0.2, 500])\n",
190 | "\n",
191 | " fig_list[2].scatter(np.array(cl_sizes), np.array(entropies))\n",
192 | " fig_list[2].set_xlabel('Cluster size')\n",
193 | " fig_list[2].set_ylabel('Entropy')\n",
194 | " #fig_list[2].axis([0, 320, -0.2, 9])\n",
195 | "\n",
196 | " fig_list[3].scatter(np.array(lengths), np.array(entropies))\n",
197 | " fig_list[3].set_xlabel('No. of words in utterance')\n",
198 | " fig_list[3].set_ylabel('Entropy')\n",
199 | " #fig_list[3].axis([-0.2, 30, -0.2, 10])\n",
200 | "\n",
201 | " # Sort the sentence lists.\n",
202 | " sent_ent = sorted(sentence_entropy, key=operator.itemgetter(1), reverse=True)\n",
203 | " sent_cl = sorted(sentence_cl_size, key=operator.itemgetter(1), reverse=True)\n",
204 | " return sent_ent, sent_cl"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": 5,
210 | "metadata": {
211 | "colab": {},
212 | "colab_type": "code",
213 | "id": "jetWP3UaXH6q"
214 | },
215 | "outputs": [],
216 | "source": [
217 | "# Main function to visualize clustering/filtering results.\n",
218 | "def data_visualization():\n",
219 | " # Open the clustering files.\n",
220 | " source_cl_entropies = open(os.path.join(config.project_path, config.output_dir, 'fullSource_cluster_entropies.txt'))\n",
221 | " target_cl_entropies = open(os.path.join(config.project_path, config.output_dir, 'fullTarget_cluster_entropies.txt'))\n",
222 | "\n",
223 | " # Set up matplotlib.\n",
224 | " plt.close('all')\n",
225 | " fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6), (ax7, ax8)) = plt.subplots(nrows=4, ncols=2)\n",
226 | " fig.set_size_inches(13, 20)\n",
227 | "\n",
228 | " # Call the actual visualization function for source and target data.\n",
229 | " source_entropies, source_cl_sizes = _visualize(source_cl_entropies,\n",
230 | " 'Source',\n",
231 | " [ax1, ax3, ax5, ax7])\n",
232 | " target_entropies, target_cl_sizes = _visualize(target_cl_entropies,\n",
233 | " 'Target',\n",
234 | " [ax2, ax4, ax6, ax8])\n",
235 | " plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)\n",
236 | "\n",
237 | " source_cl_entropies.close()\n",
238 | " target_cl_entropies.close()"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 10,
244 | "metadata": {
245 | "colab": {},
246 | "colab_type": "code",
247 | "id": "hfgy52_TXMQl"
248 | },
249 | "outputs": [],
250 | "source": [
251 | "def print_clusters(tag, top_clusters):\n",
252 | " clusters = {}\n",
253 | " cluster_element_lengths = {}\n",
254 | "\n",
255 | " with open(os.path.join(config.project_path, config.output_dir, 'full{}_cluster_elements.txt'.format(tag))) as file:\n",
256 | " for line in file:\n",
257 | " [source, source_cl_target, target_cl] = line.split('<=====>')\n",
258 | "\n",
259 | " if tag == 'Source':\n",
260 | " source_cl = source.split(';')[1]\n",
261 | " target_cl_ind = target_cl.split(':')[1].strip('\\n')\n",
262 | " source = source_cl_target.split('=')[0]\n",
263 | " target = source_cl_target.split('=')[1]\n",
264 | " cluster_element_lengths[source_cl] = \\\n",
265 | " cluster_element_lengths.get(source_cl, 0) + len(source.split())\n",
266 | " clusters[source_cl] = [*clusters.get(source_cl, []), source]\n",
267 | "\n",
268 | " else:\n",
269 | " target_cl = source.split(';')[1]\n",
270 | " target = source_cl_target.split('=')[0]\n",
271 | " cluster_element_lengths[target_cl] = \\\n",
272 | " cluster_element_lengths.get(target_cl, 0) + len(target.split())\n",
273 | " clusters[target_cl] = [*clusters.get(target_cl, []), target]\n",
274 | "\n",
275 | " with open(os.path.join(config.project_path, config.output_dir, 'full{}_cluster_entropies.txt'.format(tag))) as file:\n",
276 | " entropies = {}\n",
277 | " for line in file:\n",
278 | " [medoid, entropy, size] = line.split(';')\n",
279 | " entropies[medoid] = float(entropy)\n",
280 | "\n",
281 | " num_removed = 0\n",
282 | " num_all = 0\n",
283 | " for medoid in cluster_element_lengths:\n",
284 | " num_all += len(clusters[medoid])\n",
285 | " if ((cluster_element_lengths[medoid] / len(clusters[medoid]) if\n",
286 | " len(clusters[medoid]) > 0 else 1) > 1000 or\n",
287 | " len(medoid.split()) > 1000 or\n",
288 | " entropies[medoid] < config.threshold):\n",
289 | " num_removed += len(clusters[medoid])\n",
290 | "\n",
291 | " #print(num_removed / num_all)\n",
292 | " for _, medoid in zip(range(top_clusters),\n",
293 | " sorted(list(clusters), key=lambda x: entropies[x],\n",
294 | " reverse=True)):\n",
295 | " print('=====================================================')\n",
296 | " #print('{}& {} & {} \\\\\\\\'.format(list(set(clusters[medoid]))[0], len(clusters[medoid]), str(entropies[medoid])[:4]))\n",
297 | " print('Center: {}'.format(medoid))\n",
298 | " print('Entropy: {}'.format(entropies[medoid]))\n",
299 | " print('Size: {}'.format(len(clusters[medoid])))\n",
300 | " if len(clusters[medoid]) < 1000:\n",
301 | " print('Elements: \\n{}\\n\\n'.format('\\n'.join(set(clusters[medoid]))))"
302 | ]
303 | },
304 | {
305 | "cell_type": "markdown",
306 | "metadata": {
307 | "colab_type": "text",
308 | "id": "hItN-dJISP0N"
309 | },
310 | "source": [
311 | "# Visualize the clustering.\n",
312 | "Graphs on the left are about source data, and on the right the target data.\n",
313 | "* First the entropy of all the utterances in the dataset is plotted.\n",
314 | "* Second each sentence's cluster's size is plotted (for all sentences in the dataset).\n",
315 | "* Third, the entropy of each cluster is plotted against its size.\n",
316 | "* Finally the relationship between the entropy of an utterance and the number of words in it is plotted."
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": 15,
322 | "metadata": {
323 | "colab": {},
324 | "colab_type": "code",
325 | "id": "EKGlzOI6SP0O",
326 | "outputId": "52151bf6-100e-4905-c798-669cd7fd3d67",
327 | "scrolled": true
328 | },
329 | "outputs": [
330 | {
331 | "data": {
332 | "image/png": "\n",
333 | "text/plain": [
334 | ""
335 | ]
336 | },
337 | "metadata": {
338 | "needs_background": "light"
339 | },
340 | "output_type": "display_data"
341 | }
342 | ],
343 | "source": [
344 | "data_visualization()"
345 | ]
346 | },
347 | {
348 | "cell_type": "markdown",
349 | "metadata": {
350 | "colab_type": "text",
351 | "collapsed": true,
352 | "id": "I3m_tPouSP0U"
353 | },
354 | "source": [
355 | "# Print some clusters\n",
356 | "Let's see the unique elements of the clusters with highest entropy."
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "execution_count": 13,
362 | "metadata": {
363 | "colab": {},
364 | "colab_type": "code",
365 | "id": "B-y1cyabSP0V",
366 | "outputId": "02f5b184-a729-4089-f788-64d1ef39933d",
367 | "scrolled": true
368 | },
369 | "outputs": [
370 | {
371 | "name": "stdout",
372 | "output_type": "stream",
373 | "text": [
374 | "=====================================================\n",
375 | "Center: so dick how about getting some coffee for tonight ?\n",
376 | "Entropy: 2.3204735011414868\n",
377 | "Size: 23\n",
378 | "Elements: \n",
379 | "i do not care the brand as long as it works well .\n",
380 | "no the steak was recommended but it is not very fresh .\n",
381 | "i have this list of stuff that i need and i only have half the dough .\n",
382 | "half the dough huh . well . how would you like to earn the other half ?\n",
383 | "would you care for a drink before you order ?\n",
384 | "yes i have many things to buy . i would like to choose the cleaning milk first .\n",
385 | "i want to buy the toothpaste the brand of jiajieshi .\n",
386 | "not today honey . do n't eat too much ice cream .\n",
387 | "sounds good . what about shampoo ? i would like to buy the product that prevents scurf .\n",
388 | "yes very . believe it or not it will cost you more than one hundred dollars .\n",
389 | "mom can i have some ice cream ?\n",
390 | "i want to buy more books .\n",
391 | "i just happen to have a question for you guys . why do the chinese cook the vegetables ? you see what i mean is that most vitamin are destroyed when heated .\n",
392 | "is that expensive ?\n",
393 | "i need a packet of cigarettes please .\n",
394 | "yes we charge 50 cents for water .\n",
395 | "mom can i have one more piece of cake ?\n",
396 | "where is your store located ?\n",
397 | "peter go and tidy up your toys now .\n",
398 | "i think i going to need some iced water too . is there an extra charge for that ?\n",
399 | "no sweat . it s a piece of cake .\n",
400 | "may i have a cookie ?\n",
401 | "that 's fine . could you give me some more napkins too ?\n",
402 | "\n",
403 | "\n",
404 | "=====================================================\n",
405 | "Center: come on you can at least try a little besides your cigarette .\n",
406 | "Entropy: 2.265710144725459\n",
407 | "Size: 26\n",
408 | "Elements: \n",
409 | "but i already tried my best .\n",
410 | "wow i ca n't thank you enough .\n",
411 | "thanks a lot . that 's the favor i was going to ask you for .\n",
412 | "i am wiped out . thank you .\n",
413 | "i 'm sorry marry is out right now .\n",
414 | "ok . but i 'm not familiar . i do n't know the beginning part .\n",
415 | "forgive you for what ?\n",
416 | "oh yeah ? are you joking ?\n",
417 | "16 divided by 2 . what 's the answer ?\n",
418 | "i m afraid not . i apologize .\n",
419 | "no ! are you kidding me ?\n",
420 | "jim could you do me a favor ?\n",
421 | "yes i believe they are . here are something that might interest you .\n",
422 | "i 'm afraid there 's been a mistake .\n",
423 | "how much does it cost ?\n",
424 | "i 'm mad ! i said now ! turn off the tv and do it now .\n",
425 | "never ! but thank you for inviting me .\n",
426 | "i 'm sorry sir . what seems to be the trouble ?\n",
427 | "let 's go !\n",
428 | "dad can you just tell me what it means ? i 'm too lazy .\n",
429 | "what time should i call back ?\n",
430 | "you 've got it . tell me a little bit about what you might be wanted .\n",
431 | "thanks for helping me out . i really appreciate it .\n",
432 | "thank you . maybe we can sing a song together . would you like to sing with me ?\n",
433 | "thanks a lot .\n",
434 | "i prefer glossy . how much ?\n",
435 | "\n",
436 | "\n",
437 | "=====================================================\n",
438 | "Center: look next time get yourself some comfy shoes . you re gonna come back again with me aren t you ?\n",
439 | "Entropy: 2.187126559186671\n",
440 | "Size: 67\n",
441 | "Elements: \n",
442 | "would you mind waiting a while ?\n",
443 | "dad . here comes another bus .\n",
444 | "i 'll put all that into the bag for you .\n",
445 | "peter wash your hands first and then have some dessert .\n",
446 | "when is a good time to catch him ?\n",
447 | "i do n't know about all of that but think about it their business gets free publicity !\n",
448 | "did somebody hit you ? or did you just fall ?\n",
449 | "so dick how about getting some coffee for tonight ?\n",
450 | "of course not ! i 'm telling the truth .\n",
451 | "it looks to be a nice day today .\n",
452 | "jessie i m afraid i can t come back home for dinner tonight .\n",
453 | "dad but i do n't want to share a room with peter . he snores every night .\n",
454 | "oh great ! it 's delicious . you see i am already putting on weight . there is one thing i do n't like however msg .\n",
455 | "do you mind if i smoke ?\n",
456 | "it runs every 15 minutes . you must have missed it .\n",
457 | "oh no get off it . it wasn t such a killer class . you just have to get into it . like they say no pain no gain .\n",
458 | "oh do n't let that worry you . if that were true china would n't have such a large population .\n",
459 | "oh i almost forgot . it 's my mum 's birthday saturday . i need to get her some more chanel . could you get me the 1 . 7 ounce bottle of chanel cologne ?\n",
460 | "mom . i have to go school shopping . there 's only one more week left .\n",
461 | "i swear i m going to kill you for this .\n",
462 | "well i want a fillet steak medium but my little girl does n't care for steak . could she have something else instead ?\n",
463 | "are things still going badly with your houseguest ?\n",
464 | "what 's wrong with msg ? it helps to bring out the taste of the food .\n",
465 | "mrs . smith has not checked in yet .\n",
466 | "which do you prefer color or black and white ?\n",
467 | "honey you can ask him to be quite . otherwise you may punish him and tell him to stand out of the room right ?\n",
468 | "what happened peter ? did you have a fight ?\n",
469 | "leo i really think you re beating around the bush with this guy . i know he used to be your best friend in college but i really think it s time to lay down the law .\n",
470 | "ok i 'm sorry it took so long .\n",
471 | "oh honey i 'm so sorry we do n't have enough space for you to have your own room .\n",
472 | "no he 's perfectly harmless . and he 's not afraid of strangers either . here hold him .\n",
473 | "the movie company has to pay them ?\n",
474 | "you can use this . it has special effect for keeping your face moisturized . it has this lotion as a gift attached .\n",
475 | "i used your computer . and i m afraid i ve erased your personal files accidentally .\n",
476 | "can you leave a message for her to call her office ?\n",
477 | "may i sit here ?\n",
478 | "i wonder how all the businesses in the area feel about that .\n",
479 | "he stepped out of the office for a little while .\n",
480 | "isn t he the best instructor ? i think he s so hot . wow ! i really feel energized don t you ?\n",
481 | "sure we are off singing road close to the bank .\n",
482 | "i 'm looking for an engagement ring for my girlfriend . i have an idea of what she likes but i want to surprise her with something special too .\n",
483 | "good evening . welcome to cherry 's . do you have a reservation ?\n",
484 | "getting worse . now he s eating me out of house and home . i ve tried talking to him but it all goes in one ear and out the other . he makes himself at home which is fine . but what really gets me is that yesterday he walked into the living room in the raw and i had company over ! that was the last straw .\n",
485 | "the phone has been ringing off the hook today .\n",
486 | "i am in no particular hurry .\n",
487 | "you are in luck . we just receive a shipment of several different styles of white purses .\n",
488 | "we ca n't go that way the road is blocked for the next few days .\n",
489 | "oh honey you made a mistake .\n",
490 | "i 'm looking for some blush . do you still have some in peach rose ?\n",
491 | "you stepped on my foot !\n",
492 | "ahahah ! what is that thing on your couch ! it just moved !\n",
493 | "sam you ve got to forgive me .\n",
494 | "mom just ten more minutes . the show is going to be over soon .\n",
495 | "they will be ready at noon tomorrow . each negative develops one print right ?\n",
496 | "now we 're going to draw an apple in your sketch book . what do we use ?\n",
497 | "i believe you have charged me twice for the same thing . look the figure of 6 . 5 dollar appears here then again here .\n",
498 | "go straight up zhongshan road and you will see our sign on your right after you pass the museum .\n",
499 | "excuse me i 've been waiting here for 10 minutes . do you know how often does no . 636 run ?\n",
500 | "kata ! you 've got a beautiful singing voice . you hit the high notes perfectly .\n",
501 | "can you connect me to mary . smith hotel room ?\n",
502 | "they must be popular again this season .\n",
503 | "sorry i ca n't sing the song .\n",
504 | "i can t believe it ! i have all my important personal documents stored in that computer . it s no laughing matter .\n",
505 | "you have a pet lizard ? somehow i never would have imagined that .\n",
506 | "oh ! sorry to hear that . this is quite unusual as we have steak from the market every day .\n",
507 | "not back home for dinner again ? that s the third time this week !\n",
508 | "come on you can at least try a little besides your cigarette .\n",
509 | "\n",
510 | "\n",
511 | "=====================================================\n",
512 | "Center: what s wrong with that ? cigarette is the thing i go crazy for .\n",
513 | "Entropy: 2.1571675430303348\n",
514 | "Size: 57\n",
515 | "Elements: \n",
516 | "can i help you sir what do you need ?\n",
517 | "yes sir . i 'll bring it over . have you decided what you 'd like sir ?\n",
518 | "are you sure ?\n",
519 | "yes we shall . what size do you like ?\n",
520 | "great i 'll take one .\n",
521 | "yes it is . and develop them as glossy as possible .\n",
522 | "sure just ask . what can i do for you ?\n",
523 | "dad how can we get to the zoo ?\n",
524 | "how about this one ? it is wellknown for the effect of removing scurf .\n",
525 | "where are you going ?\n",
526 | "i told you i m sorry . what can i do to make it up to you ?\n",
527 | "no mom . i did n't .\n",
528 | "all right . what is your type of skin ?\n",
529 | "what can i do for you ?\n",
530 | "oh ok where do i get off ?\n",
531 | "do i have to memorize it ?\n",
532 | "do you need money or what ?\n",
533 | "welcome . may i help you ?\n",
534 | "this is how you turn on the computer .\n",
535 | "when do you need them sir ?\n",
536 | "well how long will it be ?\n",
537 | "it is difficult .\n",
538 | "i like red .\n",
539 | "yes it is .\n",
540 | "dad can you help me ?\n",
541 | "depends ? depends on what ?\n",
542 | "of course sir no problem .\n",
543 | "let me see . it depends .\n",
544 | "no is it difficult ?\n",
545 | "sure where are you ?\n",
546 | "we can take a bus there .\n",
547 | "er . . . how about this one ?\n",
548 | "how many of you please ?\n",
549 | "so what ? it is not fresh and i 'm not happy about it .\n",
550 | "do i have to change ?\n",
551 | "well the 4 x 6 is fine .\n",
552 | "ok . let 's do it together .\n",
553 | "may i help you find something sir ?\n",
554 | "no . the museum is the terminal of this bus .\n",
555 | "dad may i have a room of my own ?\n",
556 | "dad i want to draw with crayons can i ?\n",
557 | "what for ?\n",
558 | "what kind of food do you like ?\n",
559 | "may i help you madam ?\n",
560 | "may i help you sir ?\n",
561 | "no we do n't .\n",
562 | "let 's see . from here you have to take the 278 bus .\n",
563 | "i see thank you . by the way is this the right bus for the museum ?\n",
564 | "how about this one ?\n",
565 | "do i have a choice ? uh . that 's a no . what can i do ?\n",
566 | "i do n't know how to do it .\n",
567 | "thank you for your compliment . but you are exaggerating . i think you are destined to be a singer . you have the best voice !\n",
568 | "i would like to have a roll developed .\n",
569 | "i think so .\n",
570 | "is it a newcomer ?\n",
571 | "yes i 'd like to . it 's my honor . let 's pick a song .\n",
572 | "yes . thank you .\n",
573 | "\n",
574 | "\n",
575 | "=====================================================\n",
576 | "Center: but your american ?\n",
577 | "Entropy: 2.0\n",
578 | "Size: 4\n",
579 | "Elements: \n",
580 | "have you heard about our special promotion this month ? if you purchase at least 18 dollar 50 cents in any elizabeth arden products you will receive this black poke with a sample of lipstick mascara and two shades of white shadow .\n",
581 | "we have all shapes sizes qualities and price ranges do you know about the four cs of picking a diamond ?\n",
582 | "wow that sounds like a bargain . i 'm running low on facial moisturizer and toner . could you ring those up for me too along with the blush ?\n",
583 | "well my price range is a 5000 dollars to 7000 dollars i 'm looking for a marquise cut on the wide band .\n",
584 | "\n",
585 | "\n",
586 | "=====================================================\n",
587 | "Center: oh no get off it . it wasn t such a killer class . you just have to get into it . like they say no pain no gain .\n",
588 | "Entropy: 1.9877733714879842\n",
589 | "Size: 13\n",
590 | "Elements: \n",
591 | "i 'd be glad to . do you need anything else ?\n",
592 | "dad how do you say this word ?\n",
593 | "coffee ? i don t honestly like that kind of stuff .\n",
594 | "certainly sir .\n",
595 | "you don t have to explain . suit yourself .\n",
596 | "i understand .\n",
597 | "i think that they get a pretty good payoff .\n",
598 | "sword say it sword .\n",
599 | "what s wrong ? didn t you think it was fun ? !\n",
600 | "is everything to your satisfaction ?\n",
601 | "sure . do you need anything else ?\n",
602 | "what does this word mean ?\n",
603 | "anything else ?\n",
604 | "\n",
605 | "\n",
606 | "=====================================================\n",
607 | "Center: oh do n't let that worry you . if that were true china would n't have such a large population .\n",
608 | "Entropy: 1.8220931167465637\n",
609 | "Size: 54\n",
610 | "Elements: \n",
611 | "my car has a problem starting . could you please take a look at it for me ?\n",
612 | "oh that 's right . they 're filming a movie up there are n't they ?\n",
613 | "are you going to the annual party ? i can give you a ride if you need one .\n",
614 | "if you are in a hurry you should take a taxi .\n",
615 | "and do you want the glossy or matted finish ?\n",
616 | "yes a roll of kodak film please .\n",
617 | "would you like to take a look at the menu sir ?\n",
618 | "hurry up will you ?\n",
619 | "i apologize . you have my word i ll spend some time with you on the weekend . i promise .\n",
620 | "the last one is black and white all the rest should need color .\n",
621 | "can you please tell me where you are located ?\n",
622 | "you pay when you pick them up . i do n't need a deposit for just one roll of film .\n",
623 | "could i have my bill please ?\n",
624 | "his name is grunt . come closer and i 'll properly introduce you .\n",
625 | "no he is at work now .\n",
626 | "did you think it was n't real ? that 's my pet lizard .\n",
627 | "i think so . are n't the four cs cut clarity carat and color .\n",
628 | "what s wrong with that ? cigarette is the thing i go crazy for .\n",
629 | "i 'm looking for a white purse as a gift . could you show what you have in stock ?\n",
630 | "can you send a cab to pick me up ?\n",
631 | "may i have his office phone number please ?\n",
632 | "you should get off at the first shi da stop .\n",
633 | "have you got change for a thousand ?\n",
634 | "when will she be back ?\n",
635 | "i 'm not sure . but i 'll get a table ready as fast as i can .\n",
636 | "no problem . do you need another film ?\n",
637 | "sorry i 'm late .\n",
638 | "how was your test ?\n",
639 | "no it 's quite simple . when you get on just ask the bus driver when to pay the fare and where you want to get off .\n",
640 | "fine . let 's get on . oh no judy ! get off the bus quickly !\n",
641 | "i m sorry . our company has just opened . there are always too many things to handle . you know that .\n",
642 | "no honey it 's easy if you know the way .\n",
643 | "excuse me i 'm a little lost . which bus do i take to get to shi da ?\n",
644 | "mom can i have more allowance ?\n",
645 | "does this bus go there ?\n",
646 | "never mind . you can follow me . i 'll sing the first part .\n",
647 | "i 'd like to talk to mr . white please ?\n",
648 | "oh yes that is a beautiful color . it has been very popular blush this season . i have two left .\n",
649 | "yes how do i get to your shop from chilin ?\n",
650 | "wow . this is nice . i 'll take this one . i guess if she does n't like it she can return it right ?\n",
651 | "no problem . do you want 3 x 5 or 4 x 6 ?\n",
652 | "good let 's go for a drive .\n",
653 | "i 'm sorry sir . do you wish to try something else ? that would be on the house of course .\n",
654 | "okay i 'll get the car out of the garage .\n",
655 | "it seems you got here at good time . do you have a bus schedule ?\n",
656 | "yes i do . you can buy a bus schedule in a news stand .\n",
657 | "let 's step in dad .\n",
658 | "mike 's mechanics . can i help you ?\n",
659 | "take a bus then . it will only cost you 5 dollars .\n",
660 | "can i take your order now or do you still want to look at the menu ?\n",
661 | "robert is not available at the moment .\n",
662 | "i hope they will come out well . when should i pick them up ?\n",
663 | "only 15 nt per section . oh look that is your bus .\n",
664 | "would you like a lift home ?\n",
665 | "\n",
666 | "\n",
667 | "=====================================================\n",
668 | "Center: a glass of qingdao beer .\n",
669 | "Entropy: 1.584962500721156\n",
670 | "Size: 3\n",
671 | "Elements: \n",
672 | "yes i would also like some sweet and sour sauce and pepper .\n",
673 | "certainly . how about spaghetti with clams and shrimps .\n",
674 | "do i owe you anything for the sauce pepper and napkins ?\n",
675 | "\n",
676 | "\n",
677 | "=====================================================\n",
678 | "Center: yes sir . i 'll bring it over . have you decided what you 'd like sir ?\n",
679 | "Entropy: 1.584962500721156\n",
680 | "Size: 3\n",
681 | "Elements: \n",
682 | "ok .\n",
683 | "the 4 x6 will be ok . thanks .\n",
684 | "ok thanks . . .\n",
685 | "\n",
686 | "\n",
687 | "=====================================================\n",
688 | "Center: i am wiped out . thank you .\n",
689 | "Entropy: 1.0\n",
690 | "Size: 2\n",
691 | "Elements: \n",
692 | "somebody please answer the phone .\n",
693 | "look next time get yourself some comfy shoes . you re gonna come back again with me aren t you ?\n",
694 | "\n",
695 | "\n"
696 | ]
697 | }
698 | ],
699 | "source": [
700 | "print_clusters(tag='Source', top_clusters=10)"
701 | ]
702 | },
703 | {
704 | "cell_type": "code",
705 | "execution_count": null,
706 | "metadata": {},
707 | "outputs": [],
708 | "source": []
709 | }
710 | ],
711 | "metadata": {
712 | "colab": {
713 | "collapsed_sections": [
714 | "hItN-dJISP0N",
715 | "I3m_tPouSP0U"
716 | ],
717 | "name": "visualization.ipynb",
718 | "provenance": [],
719 | "toc_visible": true,
720 | "version": "0.3.2"
721 | },
722 | "kernelspec": {
723 | "display_name": "Python 3",
724 | "language": "python",
725 | "name": "python3"
726 | },
727 | "language_info": {
728 | "codemirror_mode": {
729 | "name": "ipython",
730 | "version": 3
731 | },
732 | "file_extension": ".py",
733 | "mimetype": "text/x-python",
734 | "name": "python",
735 | "nbconvert_exporter": "python",
736 | "pygments_lexer": "ipython3",
737 | "version": "3.6.6"
738 | }
739 | },
740 | "nbformat": 4,
741 | "nbformat_minor": 1
742 | }
743 |
--------------------------------------------------------------------------------
/docs/3d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/3d.png
--------------------------------------------------------------------------------
/docs/cluster_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/cluster_examples.png
--------------------------------------------------------------------------------
/docs/example_responses.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/example_responses.png
--------------------------------------------------------------------------------
/docs/help.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/help.png
--------------------------------------------------------------------------------
/docs/high_entropy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/high_entropy.png
--------------------------------------------------------------------------------
/docs/metrics_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/metrics_table.png
--------------------------------------------------------------------------------
/docs/other_datasets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/other_datasets.png
--------------------------------------------------------------------------------
/docs/uml.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/uml.png
--------------------------------------------------------------------------------
/docs/visu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/visu.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn
2 | clint
3 | matplotlib
4 | numpy
5 | requests
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | print('Installing requirements...')
5 | os.system('pip install -r requirements.txt')
6 |
7 | import requests
8 | import zipfile
9 | from clint.textui import progress
10 |
11 |
12 | def download_data(url, zipped_path, extract):
13 |     # Open the url and download the data with progress bars.
14 |     data_stream = requests.get(url, stream=True)
15 |
16 |     with open(zipped_path, 'wb') as file:
17 |         total_length = int(data_stream.headers.get('content-length'))
18 |         for chunk in progress.bar(data_stream.iter_content(chunk_size=1024),
19 |                                   expected_size=total_length / 1024 + 1):
20 |             if chunk:
21 |                 file.write(chunk)
22 |                 file.flush()
23 |
24 |     # Extract file.
25 |     zip_file = zipfile.ZipFile(zipped_path, 'r')
26 |     zip_file.extractall(extract)
27 |     zip_file.close()
28 |
29 |
30 | print('Do you want to download all datasets used in the paper (116 MB)? (y/n)')
31 | if input() == 'y':
32 |     if not os.path.exists('data'):
33 |         os.mkdir('data')
34 |     download_data('https://ricsinaruto.github.io/website/docs/Twitter.zip', 'data/Twitter.zip', 'data')
35 |     download_data('https://ricsinaruto.github.io/website/docs/Cornell.zip', 'data/Cornell.zip', 'data')
36 |     download_data('https://ricsinaruto.github.io/website/docs/DailyDialog.zip', 'data/DailyDialog.zip', 'data')
37 |
38 | print('Do you want to download all generated responses on the test set by the different models (7 MB)? (y/n)')
39 | if input() == 'y':
40 |     download_data('https://ricsinaruto.github.io/website/docs/responses.zip', 'responses.zip', '')
41 |
--------------------------------------------------------------------------------