├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── code ├── filtering │ ├── average_word_embedding.py │ ├── filter_problem.py │ ├── identity.py │ ├── semantic_clustering.py │ └── sent2vec.py ├── main.py └── utils │ ├── config.py │ ├── filtering_demo.ipynb │ ├── utils.py │ └── visualization.ipynb ├── docs ├── 3d.png ├── class_diagram.uml ├── cluster_examples.png ├── example_responses.png ├── help.png ├── high_entropy.png ├── metrics_table.png ├── other_datasets.png ├── uml.png └── visu.png ├── requirements.txt └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | 4 | # Custom for Visual Studio 5 | *.cs diff=csharp 6 | 7 | # Standard to msysgit 8 | *.doc diff=astextplain 9 | *.DOC diff=astextplain 10 | *.docx diff=astextplain 11 | *.DOCX diff=astextplain 12 | *.dot diff=astextplain 13 | *.DOT diff=astextplain 14 | *.pdf diff=astextplain 15 | *.PDF diff=astextplain 16 | *.rtf diff=astextplain 17 | *.RTF diff=astextplain 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask instance folder 57 | instance/ 58 | 59 | # Scrapy stuff: 60 | .scrapy 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | 65 | # PyBuilder 66 | target/ 67 | 68 | # IPython Notebook 69 | .ipynb_checkpoints 70 | 71 | # pyenv 72 | .python-version 73 | 74 | # celery beat schedule file 75 | celerybeat-schedule 76 | 77 | # dotenv 78 | .env 79 | 80 | # virtualenv 81 | venv/ 82 | ENV/ 83 | 84 | # Spyder project settings 85 | .spyderproject 86 | 87 | # Rope project settings 88 | .ropeproject 89 | 90 | # ========================= 91 | # Operating System Files 92 | # ========================= 93 | 94 | # OSX 95 | # ========================= 96 | 97 | .DS_Store 98 | .AppleDouble 99 | .LSOverride 100 | 101 | # Thumbnails 102 | ._* 103 | 104 | # Files that might appear in the root of a volume 105 | .DocumentRevisions-V100 106 | .fseventsd 107 | .Spotlight-V100 108 | .TemporaryItems 109 | .Trashes 110 | .VolumeIcon.icns 111 | 112 | # Directories potentially created on remote AFP share 113 | .AppleDB 114 | .AppleDesktop 115 | Network Trash Folder 116 | Temporary Items 117 | .apdisk 118 | 119 | # Windows 120 | # ========================= 121 | 122 | # Windows image file caches 123 | Thumbs.db 124 | ehthumbs.db 125 | 126 | # Folder config file 127 | Desktop.ini 128 | 129 | # Recycle Bin used on file shares 130 | $RECYCLE.BIN/ 131 | 132 | # Windows Installer files 133 | *.cab 134 | *.msi 135 | *.msm 136 | *.msp 137 | 138 | # Windows shortcuts 139 | *.lnk 140 | 141 | *.txt 142 | # Large files and folders 143 | private/ 144 | data/ 145 | responses/ 146 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Richard Krisztian Csaky 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeuralChatbots-DataFiltering · [![twitter](https://img.shields.io/twitter/url/https/shields.io.svg?style=social)](https://ctt.ac/E_jP6) 2 | [![Paper](https://img.shields.io/badge/Presented%20at-ACL%202019-yellow.svg)](https://www.aclweb.org/anthology/P19-1567) [![Poster](https://img.shields.io/badge/The-Poster-yellow.svg)](https://ricsinaruto.github.io/website/docs/acl_poster_h.pdf) [![Code1](https://img.shields.io/badge/code-chatbot%20training-green.svg)](https://github.com/ricsinaruto/Seq2seqChatbots) [![Code2](https://img.shields.io/badge/code-evaluation-green.svg)](https://github.com/ricsinaruto/dialog-eval) [![documentation](https://img.shields.io/badge/documentation-on%20wiki-red.svg)](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/wiki) [![blog](https://img.shields.io/badge/Blog-post-black.svg)](https://medium.com/@richardcsaky/neural-chatbots-are-dumb-65b6b40e9bd4) 3 | A lightweight repo for filtering dialog data with entropy-based methods. 4 | 5 | The program **reads the dataset**, runs **clustering** if needed, computes the **entropy** of individual utterances, and then **removes high entropy** utterances based on the threshold, and **saves the filtered dataset** to the output directory. See the [paper](https://www.aclweb.org/anthology/P19-1567) or the [poster](https://ricsinaruto.github.io/website/docs/acl_poster_h.pdf) for more details. 
6 | 7 | ## Features 8 | :floppy_disk:   Cluster and filter any dialog data that you provide, or use pre-downloaded datasets 9 | :rocket:   Various parameters can be used to adjust the algorithm 10 | :ok_hand:    Choose between different entropy computation methods 11 | :twisted_rightwards_arrows:   Choose between different clustering and filtering types 12 | :movie_camera:   Visualize clustering and filtering results 13 | 14 | 15 | 16 | ## Setup 17 | Run setup.py which installs required packages and steps you through downloading additional data: 18 | ``` 19 | python setup.py 20 | ``` 21 | You can download all trained models used in [this](https://www.aclweb.org/anthology/P19-1567) paper from [here](https://mega.nz/#!mI0iDCTI!qhKoBiQRY3rLg3K6nxAmd4ZMNEX4utFRvSby_0q2dwU). Each training contains two checkpoints, one for the validation loss minimum and another after 150 epochs. The data and the trainings folder structure match each other exactly. 22 | ## Usage 23 | The main file can be called from anywhere, but when specifying paths to directories you should give them from the root of the repository. 24 | ``` 25 | python code/main.py -h 26 | ``` 27 | 28 | For the complete **documentation** visit the [wiki](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/wiki). 29 | 30 | ### Cluster Type 31 | * [identity](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/identity.py): In this method there is basically no clustering, the entropy of utterances is calculated based on the conditional probability of utterance pairs. 32 | * [avg-embedding](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/average_word_embedding.py): This clustering type uses average word embedding sentence representations as in [this paper](https://pdfs.semanticscholar.org/3fc9/7768dc0b36449ec377d6a4cad8827908d5b4.pdf). 
33 | * [sent2vec](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/sent2vec.py): This clustering type should use [sent2vec](https://github.com/epfml/sent2vec) sentence embeddings, but currently uses any embeddings you provide to it. 34 | 35 | ### Filter Type 36 | * **source**: Filters utterance pairs in which the source utterance's entropy is above the threshold. 37 | * **target**: Filters utterance pairs in which the target utterance's entropy is above the threshold. 38 | * **both**: Filters utterance pairs in which either the source or target utterance's entropy is above the threshold. 39 | 40 | ### [Filtering Demo](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/utils/filtering_demo.ipynb) 41 | In this jupyter notebook you can easily try out the identity filtering method implemented in less than 40 lines, and it filters DailyDialog in a couple of seconds (you only need to provide a sources and targets file). In the second part of the notebook there are some cool visualizations for entropy, frequency and sentence length. 42 | 43 | 44 | ### [Visualization](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/utils/visualization.ipynb) 45 | Visualize clustering and filtering results by running the [visualization](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/utils/visualization.ipynb) jupyter notebook. The notebook is pretty self-explanatory, you just have to provide the directory containing the clustering files. 46 | 47 | 48 | 49 | ## Results & Examples 50 | ### High Entropy Utterances and Clusters from [DailyDialog](https://arxiv.org/abs/1710.03957) 51 | 52 | 53 | A high entropy cluster found by sent2vec. 
54 | 55 | ### [Transformer](https://arxiv.org/abs/1706.03762) Trained on [DailyDialog](https://arxiv.org/abs/1710.03957) 56 | For an explanation of the metrics please check [this repo](https://github.com/ricsinaruto/dialog-eval) or the [paper](https://arxiv.org/pdf/1905.05471.pdf). 57 | 58 | 59 | 60 | More examples can be found in the appendix of the [paper](https://arxiv.org/pdf/1905.05471.pdf). 61 | 62 | ### [Transformer](https://arxiv.org/abs/1706.03762) Trained on [Cornell](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html) and [Twitter](https://github.com/facebookresearch/ParlAI/tree/master/parlai/tasks/twitter) 63 | For an explanation of the metrics please check [this repo](https://github.com/ricsinaruto/dialog-eval) or the [paper](https://arxiv.org/pdf/1905.05471.pdf). 64 | 65 | 66 | ## Contributing 67 | ##### Check the [issues](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/issues) for some additions where help is appreciated. Any contributions are welcome :heart: 68 | ##### Please try to follow the code syntax style used in the repo (flake8, 2 spaces indent, 80 char lines, commenting a lot, etc.) 69 | 70 | **New clustering** methods can be added, by subclassing the [FilterProblem](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/filter_problem.py#L48) class, check [Identity](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/identity.py) for a minimal example. Normally you only have to redefine the *clustering* function, which does the clustering of sentences. 71 | 72 | Loading and saving data is taken care of, and you should use the [Cluster](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/filter_problem.py#L24) and [DataPoint](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/filtering/filter_problem.py#L9) objects. 
Use the *data_points* list to get the sentences for your clustering algorithm, and use the *clusters* list to save the results of your clustering. These can also be subclassed if you want to add extra data to your DataPoint and Cluster objects (like a vector). 73 | 74 | Finally, add your class to the dictionary in [main](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/code/main.py#L90), and to the command-line argument choices. 75 | 76 | 77 | ## Authors 78 | * **[Richard Csaky](https://ricsinaruto.github.io)** (If you need any help with running the code: ricsinaruto@hotmail.com) 79 | * **[Patrik Purgai](https://github.com/Mrpatekful)** (clustering part) 80 | 81 | ## License 82 | This project is licensed under the MIT License - see the [LICENSE](https://github.com/ricsinaruto/NeuralChatbots-DataFiltering/blob/master/LICENSE) file for details. 83 | Please include a link to this repo if you use it in your work and consider citing the following paper: 84 | ``` 85 | @inproceedings{Csaky:2019, 86 | title = "Improving Neural Conversational Models with Entropy-Based Data Filtering", 87 | author = "Cs{\'a}ky, Rich{\'a}rd and Purgai, Patrik and Recski, G{\'a}bor", 88 | booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics", 89 | month = jul, 90 | year = "2019", 91 | address = "Florence, Italy", 92 | publisher = "Association for Computational Linguistics", 93 | url = "https://www.aclweb.org/anthology/P19-1567", 94 | pages = "5650--5669", 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /code/filtering/average_word_embedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import requests 5 | import zipfile 6 | from collections import Counter 7 | from clint.textui import progress 8 | 9 | from filtering import semantic_clustering 10 | 11 | 12 | class
# Extract the FastText vectors of the vocab words into 'vocab.npy'.
def get_fast_text_embeddings(self):
  '''
  Make sure that a vocab file ('vocab.txt') and the FastText vectors
  ('wiki-news-300d-1M.vec') exist in the input directory, building the vocab
  or downloading the vectors if they are missing. Then write the vector line
  of every vocab word that has one to 'vocab.npy', in vocab order. Despite
  its extension, 'vocab.npy' is a plain text file: each line is copied
  verbatim from the FastText file (a word followed by its vector values).
  '''
  vocab_path = os.path.join(self.input_dir, 'vocab.txt')
  if not os.path.exists(vocab_path):
    print('No vocab file named \'vocab.txt\' found in ' + self.input_dir)
    print('Building vocab from data.')
    self.get_vocab(vocab_path)

  fasttext_path = os.path.join(self.input_dir, 'wiki-news-300d-1M.vec')
  if not os.path.exists(fasttext_path):
    self.download_fasttext()

  # Read the vocab inside a with block, so the file handle is not leaked.
  with open(vocab_path) as vocab_file:
    vocab = [line.strip('\n') for line in vocab_file]
  wanted = set(vocab)
  vectors_path = os.path.join(self.input_dir, 'vocab.npy')

  # Keep only the lines of vocab words, instead of holding every line of the
  # (very large) FastText file in memory.
  vectors = {}
  with open(fasttext_path, errors='ignore') as in_file:
    for line_number, line in enumerate(in_file):
      tokens = line.split()
      # Blank lines carry no vector.
      if not tokens:
        continue
      # The first line of a .vec file is a '<word count> <dimension>' header,
      # not a word vector, so it must never be saved as one.
      if line_number == 0 and len(tokens) == 2:
        continue
      if tokens[0] in wanted:
        vectors[tokens[0]] = line

  # Save the vectors for words in the vocab; words without a vector are
  # skipped.
  with open(vectors_path, 'w') as out_file:
    for word in vocab:
      if word in vectors:
        out_file.write(vectors[word])
class DataPoint:
  '''
  Holds one utterance together with the line number it was read from.
  '''
  def __init__(self, string, index):
    '''
    Params:
      :string: Text of the utterance, stored with normalized whitespace.
      :index: Number of the line in the file from which this sentence was read.
    '''
    # Drop the line ending, then collapse every whitespace run to one space.
    words = string.strip('\n').split()
    self.string = ' '.join(words)
    self.index = index
    # Index of the cluster containing this data point; starts at 0.
    self.cluster_index = 0
# Build statistics and full files.
def build(self, tag):
  '''
  Read the train, dev and test splits of one side of the data, write their
  concatenation to 'full<tag>.txt' in the input directory if that file does
  not exist yet, and, for the source side, store the number of lines of each
  split in self.line_counts. Exits the program if a split file is missing.

  Params:
    :tag: 'Source' or 'Target'.
  '''
  splits = ['train', 'dev', 'test']
  files = []
  try:
    for split in splits:
      # Use a with block, so the handles are closed right after reading.
      with open(os.path.join(self.input_dir, split + tag + '.txt')) as file:
        files.append(file.read())
  except FileNotFoundError:
    print('Data files not found in ' + self.input_dir)
    print('The following files should be here:')
    print('trainSource.txt\ntrainTarget.txt\ndevSource.txt\ndevTarget.txt')
    print('testSource.txt\ntestTarget.txt')
    sys.exit()

  # Terminate every non-empty split with a newline. Otherwise the last line
  # of a split would be merged with the first line of the next one in the
  # concatenated file, shifting all following line indices.
  files = [text + '\n' if text and not text.endswith('\n') else text
           for text in files]

  full_path = os.path.join(self.input_dir, 'full' + tag + '.txt')
  if not os.path.exists(full_path):
    with open(full_path, 'w') as full_file:
      full_file.write(''.join(files))

  if tag == 'Source':
    # Every line now ends with '\n', so counting newlines gives the exact
    # number of lines. Splitting on '\n' would count one extra (empty) line
    # for a file ending with a newline, and move the split boundaries.
    self.line_counts = {split: text.count('\n')
                        for split, text in zip(splits, files)}
# Load clusters from file.
def load_clusters(self, source_path, target_path):
  '''
  Rebuild self.clusters from two cluster element files. Each line must have
  the form:
    <index>;<medoid><=====><element>=<paired element><=====><rest>
  The text after the second '<=====>' is ignored. The data itself must not
  contain '<=====>' or '=', otherwise the line can't be split unambiguously
  and a ValueError is raised.

  Params:
    :source_path: Path to source cluster elements.
    :target_path: Path to target cluster elements.
  '''
  source_clusters = {}
  target_clusters = {}
  source_data_points = {}
  target_data_points = {}

  with open(source_path, 'r') as source_file:
    for line in source_file:
      [source_index_center, source_target, _] = line.split('<=====>')
      # Only the first ';' separates the index from the medoid, so medoids
      # containing ';' are kept intact instead of breaking the unpacking.
      [source_index, source_center] = source_index_center.split(';', 1)
      [source, target] = source_target.split('=')
      index = int(source_index)

      # Initialize the source and target utterances.
      source_data_points[index] = self.DataPointClass(source, index)
      target_data_points[index] = self.DataPointClass(target, index)

      # If this is a new cluster add it to the list.
      if source_clusters.get(source_center) is None:
        center = self.DataPointClass(source_center, 0)
        source_clusters[source_center] = self.ClusterClass(center)
        source_clusters[source_center].index = len(source_clusters) - 1

      # Add the elements to the cluster.
      cluster = source_clusters[source_center]
      source_data_points[index].cluster_index = cluster.index
      cluster.add_element(source_data_points[index])
      cluster.add_target(target_data_points[index])

  with open(target_path, 'r') as target_file:
    for line in target_file:
      [target_index_center, _, _] = line.split('<=====>')
      [target_index, target_center] = target_index_center.split(';', 1)
      index = int(target_index)

      # All elements are already added at this point, the utterances are
      # looked up by index instead of being parsed again.
      target_data_point = target_data_points[index]
      source_data_point = source_data_points[index]

      # If this is a new cluster add it to the list.
      if target_clusters.get(target_center) is None:
        center = self.DataPointClass(target_center, 0)
        target_clusters[target_center] = self.ClusterClass(center)
        target_clusters[target_center].index = len(target_clusters) - 1

      # Add the elements to the cluster.
      cluster = target_clusters[target_center]
      target_data_point.cluster_index = cluster.index
      cluster.add_element(target_data_point)
      cluster.add_target(source_data_point)

  # Save the clusters ordered by their index, so that the position of a
  # cluster in the list matches the cluster_index of its data points.
  def by_index(clusters):
    return sorted(clusters.values(), key=lambda cluster: cluster.index)
  self.clusters['Source'] = by_index(source_clusters)
  self.clusters['Target'] = by_index(target_clusters)
# Do the filtering of the dataset.
def filtering(self):
  '''
  Compute which clusters are filtered out on both sides, open the output
  files needed for the current split (self.tag) and filter type (self.type),
  and write the filtered dataset and the cluster entropies into them.
  '''
  # These are not needed anymore.
  self.data_points['Source'].clear()
  self.data_points['Target'].clear()

  # Get the filtered indices for both sides.
  source_indices = self.get_filtered_indices('Source')
  target_indices = self.get_filtered_indices('Target')

  file_dict = {}
  try:
    # We have to open 6 files in this case.
    if self.tag == 'full':
      name_list = ['trainS', 'trainT', 'devS', 'devT', 'testS', 'testT']
      file_dict = dict(zip(name_list, self.open_6_files()))
    else:
      file_dict[self.tag + 'S'] = open(
        os.path.join(self.output_dir, self.tag + 'Source.txt'), 'w')
      file_dict[self.tag + 'T'] = open(
        os.path.join(self.output_dir, self.tag + 'Target.txt'), 'w')

    # Handle all other cases and open files.
    if self.type == 'source' or self.type == 'both':
      file_dict['source_entropy'] = open(
        os.path.join(self.output_dir,
                     self.tag + 'Source_cluster_entropies.txt'), 'w')
    if self.type == 'target' or self.type == 'both':
      file_dict['target_entropy'] = open(
        os.path.join(self.output_dir,
                     self.tag + 'Target_cluster_entropies.txt'), 'w')

    # Save data.
    self.save_filtered_data(source_indices, target_indices, file_dict)
  finally:
    # Release the files that were opened so far, even if opening a later
    # file or saving the data failed.
    # NOTE(review): assumes utils.close_n_files closes every file in the
    # dict and accepts an empty dict - confirm against utils.
    utils.close_n_files(file_dict)
319 | if ((target_cl not in target_indices or cluster_too_small) or 320 | self.type != 'both'): 321 | 322 | # Reverse if Target. 323 | source = element.string + '\n' 324 | target = cluster.targets[num_el].string + '\n' 325 | source_string = source if tag == 'Source' else target 326 | target_string = target if tag == 'Source' else source 327 | 328 | # Separate the full case. 329 | if self.tag == 'full': 330 | if element.index < self.line_counts['train']: 331 | file_dict['trainS'].write(source_string) 332 | file_dict['trainT'].write(target_string) 333 | elif element.index < (self.line_counts['train'] + 334 | self.line_counts['dev']): 335 | file_dict['devS'].write(source_string) 336 | file_dict['devT'].write(target_string) 337 | else: 338 | file_dict['testS'].write(source_string) 339 | file_dict['testT'].write(target_string) 340 | else: 341 | file_dict[self.tag + 'S'].write(source_string) 342 | file_dict[self.tag + 'T'].write(target_string) 343 | 344 | # Write source entropies and data to file. 345 | if self.type == 'source' or self.type == 'both': 346 | save_dataset('Source') 347 | # Write target entropies and data to file. 348 | if self.type == 'target' or self.type == 'both': 349 | save_dataset('Target') 350 | 351 | # Save clusters and their elements to files. 352 | def save_clusters(self, tag): 353 | ''' 354 | Params: 355 | :tag: Whether it's source or target data. 356 | ''' 357 | output = open( 358 | os.path.join(self.output_dir, 359 | self.tag + tag + '_cluster_elements.txt'), 'w') 360 | 361 | medoid_counts = [] 362 | rev_tag = 'Target' if tag == 'Source' else 'Source' 363 | 364 | for cluster in self.clusters[tag]: 365 | medoid_counts.append((cluster.medoid.string, len(cluster.elements))) 366 | 367 | # Save together the source and target medoids and elements. 
class Identity(FilterProblem):
  '''
  Calculate entropy based on and filter individual utterances.

  Every distinct sentence (after whitespace normalization) forms its own
  cluster, so no embeddings or clustering algorithm are needed.
  '''

  # Do the clustering of sources and targets.
  def clustering(self, tag):
    '''
    Group identical sentences of one side into clusters and append them to
    self.clusters[tag].

    Params:
      :tag: Whether it's source or target data.
    '''
    rev_tag = 'Target' if tag == 'Source' else 'Source'

    # Build a hash for efficient string searching: normalized sentence ->
    # data points having that sentence. A dict keeps insertion order, so
    # clusters are numbered by first occurrence and the numbering is the same
    # on every run (a set of strings iterates in a hash-dependent order).
    sentence_dict = {}
    for data_point in self.data_points[tag]:
      clean_sentence = ' '.join(data_point.string.split())
      sentence_dict.setdefault(clean_sentence, []).append(data_point)

    print(tag + ': ' + str(len(sentence_dict)) + ' clusters')

    # Loop through the clusters.
    for i, (sentence, members) in enumerate(sentence_dict.items()):
      # The medoid is a synthetic data point holding the normalized sentence.
      # NOTE(review): the index 10 is a placeholder carried over unchanged;
      # confirm nothing reads the medoid's index.
      cl = self.ClusterClass(self.DataPointClass(sentence, 10))
      self.clusters[tag].append(cl)

      # Loop through the data points associated with this sentence.
      for data_point in members:
        data_point.cluster_index = i
        cl.add_element(data_point)
        # The paired sentence of the other side sits at the same line index.
        cl.add_target(self.data_points[rev_tag][data_point.index])
def clustering(self, tag):
  '''
  Cluster the sentences of one side around medoids and store the result in
  self.clusters[tag].

  The centroids from calculate_centroids() are snapped to their nearest real
  data point (the medoid), every data point is then assigned to its nearest
  medoid, and one cluster object is built per medoid that received at least
  one data point.

  Params:
    :tag: Whether it's source or target data.
  '''
  data_points = self.unique_data if self.unique else self.data_points
  centroids = self.calculate_centroids(tag)

  data_point_vectors = np.array(
    [data_point.meaning_vector for data_point in
     data_points[tag]]).reshape(
       -1, data_points[tag][0].meaning_vector.shape[-1])

  # Get the actual data point for each centroid.
  tree = BallTree(data_point_vectors)
  _, medoid_indices = tree.query(centroids, k=1)
  medoid_indices = np.array(medoid_indices).reshape(-1)

  # Get the closest centroid for each data point. The returned labels are
  # positions in medoid_indices, NOT indices into data_points.
  tree = BallTree(data_point_vectors[medoid_indices])
  _, labels = tree.query(data_point_vectors, k=1)
  labels = labels.reshape(-1).tolist()

  # Build the list of clusters, one per label that is actually used, in
  # increasing label order. The label has to be mapped back through
  # medoid_indices to get the data point that is the cluster's medoid.
  used_labels = sorted(set(labels))
  label_lookup = {label: i for i, label in enumerate(used_labels)}
  clusters = [self.ClusterClass(data_points[tag][medoid_indices[label]])
              for label in used_labels]

  rev_tag = 'Target' if tag == 'Source' else 'Source'

  # Store the cluster index for each unique sentence.
  if self.unique:
    cluster_ind_dict = {}
    for data_point, label in zip(data_points[tag], labels):
      cluster_ind_dict[data_point.string] = label_lookup[label]

    # This is different for unique clustering: every original data point
    # inherits the cluster of the unique sentence with the same string.
    for i, data_point in enumerate(self.data_points[tag]):
      cl_index = cluster_ind_dict[data_point.string]
      data_point.cluster_index = cl_index
      clusters[cl_index].add_element(data_point)
      clusters[cl_index].add_target(self.data_points[rev_tag][i])

  # Assign the actual clusters.
  else:
    for dp, label in zip(self.data_points[tag], labels):
      cl_index = label_lookup[label]
      dp.cluster_index = cl_index
      clusters[cl_index].add_element(dp)
      clusters[cl_index].add_target(self.data_points[rev_tag][dp.index])

  self.clusters[tag] = clusters
# Has to be implemented by subclass to generate sentence embeddings.
def generate_embeddings(self, tag, vector_path=None):
  '''
  Generate the sentence embeddings for one side and save them to disk.
  read_inputs() calls this with (tag, vector_path) when the .npy file is
  missing and loads vector_path right afterwards, so the signature accepts
  both arguments; otherwise a subclass that forgot to override this would
  fail with a TypeError instead of NotImplementedError.

  Params:
    :tag: Whether it's source or target data.
    :vector_path: Path of the .npy file the embeddings should be saved to.
  '''
  raise NotImplementedError
# Converter for boolean command line flags.
def str2bool(value):
  '''
  Parse a command line string into a bool. argparse's type=bool would call
  bool() on the raw string, which is True for every non-empty string,
  including 'False'.

  Params:
    :value: String given on the command line (or an actual bool).

  Raises argparse.ArgumentTypeError if the string is not a known boolean.
  '''
  if isinstance(value, bool):
    return value

  lowered = value.strip().lower()
  if lowered in ('true', 't', 'yes', 'y', '1'):
    return True
  # The empty string stays False, as it was with type=bool.
  if lowered in ('false', 'f', 'no', 'n', '0', ''):
    return False
  raise argparse.ArgumentTypeError(
    'Boolean value expected, got: ' + repr(value))


def main():
  '''
  Parse the command line arguments into a Config object (the defaults come
  from config.py), optionally load a saved config, save the config to the
  output directory and run the selected filtering method.
  '''
  config = Config()
  parser = argparse.ArgumentParser(
    description='Code for filtering methods in: arxiv.org/abs/1905.05471. ' +
                'These arguments can also be set in config.py, ' +
                'and will be saved to the output directory.')
  parser.add_argument('-d', '--data_dir', default=config.data_dir,
                      help='Directory containing the dataset in these files:' +
                      ' (trainSource.txt, trainTarget.txt, devSource.txt, ' +
                      'devTarget.txt, testSource.txt, testTarget.txt, ' +
                      'vocab.txt)',
                      metavar='')
  parser.add_argument('-o', '--output_dir', default=config.output_dir,
                      help='Save here the filtered data and any output',
                      metavar='')
  parser.add_argument('-l', '--load_config', default=config.load_config,
                      help='Path to load config from file, or leave empty ' +
                      '(default: %(default)s)',
                      metavar='')
  parser.add_argument('-fs', '--filter_split', default=config.filter_split,
                      help='Data split to filter, \'full\' filters ' +
                      'all splits (choices: %(choices)s)',
                      metavar='', choices=['full', 'train', 'dev', 'test'])
  parser.add_argument('-ct', '--cluster_type', default=config.cluster_type,
                      help='Clustering method (choices: %(choices)s)',
                      metavar='',
                      choices=['identity', 'avg_embedding', 'sent2vec'])
  parser.add_argument('-sc', '--source_clusters',
                      default=config.source_clusters,
                      help='Number of source clusters in case of Kmeans',
                      metavar='', type=int)
  parser.add_argument('-tc', '--target_clusters',
                      default=config.target_clusters,
                      help='Number of target clusters in case of Kmeans',
                      metavar='', type=int)
  parser.add_argument('-u', '--unique', default=config.unique,
                      help='Whether to cluster only unique sentences ' +
                      '(default: %(default)s)',
                      metavar='', type=str2bool)
  parser.add_argument('-vs', '--vocab_size',
                      default=config.vocab_size,
                      help='Vocab size, only used if vocab file not given',
                      metavar='', type=int)
  parser.add_argument('-ft', '--filter_type', default=config.filter_type,
                      help='Filtering way (choices: %(choices)s)',
                      metavar='', choices=['source', 'target', 'both'])
  parser.add_argument('-mins', '--min_cluster_size',
                      default=config.min_cluster_size,
                      help='Clusters with fewer elements won\'t get filtered' +
                      ' (default: %(default)s)',
                      metavar='', type=int)
  # The entropy threshold is fractional (default 1.1), so parse it as float.
  parser.add_argument('-t', '--threshold', default=config.threshold,
                      help='Entropy threshold (default: %(default)s)',
                      metavar='', type=float)
  parser.add_argument('-cm', '--clustering_method',
                      default=config.clustering_method,
                      help='Mean shift recommended (choices: %(choices)s)',
                      metavar='', choices=['kmeans', 'mean_shift'])
  parser.add_argument('-bw', '--bandwidth', default=config.bandwidth,
                      help='Mean shift bandwidth (default: %(default)s)',
                      metavar='', type=float)
  parser.add_argument('-f', '--use_faiss', default=config.use_faiss,
                      help='Whether to use faiss for GPU based clustering ' +
                      '(default: %(default)s)',
                      metavar='', type=str2bool)
  parser.add_argument('-maxal', '--max_avg_length',
                      default=config.max_avg_length,
                      help='Clusters with higher average sentence length' +
                      ' won\'t get filtered (default: %(default)s)',
                      metavar='', type=int)
  parser.add_argument('-maxml', '--max_medoid_length',
                      default=config.max_medoid_length,
                      help='Clusters with longer medoids won\'t get filtered' +
                      ' (default: %(default)s)',
                      metavar='', type=int)

  # The parsed values are written directly onto the config object.
  parser.parse_args(namespace=config)
  if config.load_config:
    config.load()
  config.save()

  filter_problems = {
    'identity': Identity,
    'avg_embedding': AverageWordEmbedding,
    'sent2vec': Sent2vec,
  }

  problem = filter_problems[config.cluster_type](config)
  problem.run()
# These can also be set as arguments via the command line.
class Config:
  '''
  Default settings of a filtering run. main.py overwrites them with the
  parsed command line arguments, and the resulting values are pickled into
  the output directory so that a run can be reloaded later.
  '''
  data_dir = 'data/DailyDialog/baseline'  # Directory containing dataset.
  output_dir = 'data/DailyDialog/baseline/filtered_data/avg_embedding'
  load_config = None
  source_clusters = 0
  target_clusters = 0
  filter_split = 'full'  # Which data split to filter.
  cluster_type = 'avg_embedding'
  unique = False  # Whether to cluster only unique sentences.
  vocab_size = 16384  # Only used for average word embeddings.
  filter_type = 'both'
  min_cluster_size = 2  # Clusters with fewer elements won't get filtered.
  threshold = 1.1  # Entropy threshold for filtering.
  clustering_method = 'mean_shift'  # Kmeans or mean_shift.
  bandwidth = 0.7  # Mean shift bandwidth.
  use_faiss = False  # Whether to use the library for GPU based clustering.
  max_avg_length = 15  # Clusters with longer sentences won't get filtered.
  max_medoid_length = 50  # Clusters with longer medoids won't get filtered.
  project_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), '..', '..')

  # Save this object to output dir.
  def save(self):
    '''
    Pickle the instance attributes into <project_path>/<output_dir>/config,
    creating the directory if needed. Only attributes set on the instance
    are stored; untouched class-level defaults are not part of __dict__.
    '''
    out_dir = os.path.join(self.project_path, self.output_dir)
    # exist_ok avoids the race between checking and creating the directory.
    os.makedirs(out_dir, exist_ok=True)

    # The with-block closes the file even if pickling fails.
    with open(os.path.join(out_dir, 'config'), 'wb') as handle:
      pickle.dump(self.__dict__, handle)

  # Load from output dir.
  def load(self):
    '''
    Replace the instance attributes with the dictionary pickled in
    <project_path>/<load_config>/config.
    '''
    load_config = os.path.join(self.project_path, self.load_config, 'config')
    # NOTE(review): unpickling can execute arbitrary code, only load config
    # files that come from a trusted source.
    with open(load_config, 'rb') as handle:
      self.__dict__ = pickle.load(handle)
# Close n files to write the processed data into.
def close_n_files(files):
  '''
  Params:
    :files: Dictionary whose values are the open file objects to close.
  '''
  for handle in files.values():
    handle.close()
"code", 77 | "id": "obzOs3Y1X6qm" 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "#@title Path to directory containing filtered files and config file\n", 82 | "DIR = \"data/DailyDialog/baseline/filtered_data/\" #@param {type:\"string\"}" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": { 88 | "colab_type": "text", 89 | "id": "AnLdGG7LSd1B" 90 | }, 91 | "source": [ 92 | "# Setup\n", 93 | "Run some setup code and define the functions that will be used later." 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 3, 99 | "metadata": { 100 | "colab": { 101 | "base_uri": "https://localhost:8080/", 102 | "height": 316 103 | }, 104 | "colab_type": "code", 105 | "executionInfo": { 106 | "elapsed": 652, 107 | "status": "error", 108 | "timestamp": 1562671037191, 109 | "user": { 110 | "displayName": "Richard Csaky", 111 | "photoUrl": "https://lh6.googleusercontent.com/-wOnQ8NuCQqo/AAAAAAAAAAI/AAAAAAAAABQ/GcPUlAm-a98/s64/photo.jpg", 112 | "userId": "14062868679888781538" 113 | }, 114 | "user_tz": -120 115 | }, 116 | "id": "eoZ4rowjSrbu", 117 | "outputId": "0e84048e-681f-41b6-c33f-2d54c9981f7a" 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "%matplotlib inline\n", 122 | "\n", 123 | "import matplotlib.pyplot as plt\n", 124 | "import os\n", 125 | "import matplotlib.pyplot as plt\n", 126 | "import numpy as np\n", 127 | "import operator\n", 128 | "\n", 129 | "from config import Config\n", 130 | "\n", 131 | "\n", 132 | "# Load config file from specified directory\n", 133 | "Config.load_config = DIR\n", 134 | "config = Config()\n", 135 | "config.load()\n", 136 | "\n", 137 | "plt.rcParams.update({'font.size': 14})" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 4, 143 | "metadata": { 144 | "colab": {}, 145 | "colab_type": "code", 146 | "id": "kejzW0VnSP0L" 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "# Visualization function for the clustering data.\n", 151 | "def _visualize(file, tag, 
fig_list):\n", 152 | " '''\n", 153 | " Params:\n", 154 | " :file: Clustering file, from which to read data.\n", 155 | " :tag: Can be 'Source' or 'Target'.\n", 156 | " :fig_list: A list containing the plots, which we will draw.\n", 157 | " '''\n", 158 | " sentence_entropy = []\n", 159 | " entropies_all = []\n", 160 | " entropies = []\n", 161 | " sentence_cl_size = []\n", 162 | " cl_sizes_all = []\n", 163 | " cl_sizes = []\n", 164 | " lengths = []\n", 165 | "\n", 166 | " for line in file:\n", 167 | " [sentence, entropy, cl_size] = line.split(';')\n", 168 | " entropy = float(entropy)\n", 169 | " cl_size = int(cl_size)\n", 170 | "\n", 171 | " # Populate the lists.\n", 172 | " sentence_entropy.append([sentence, entropy])\n", 173 | " entropies_all.extend([entropy] * cl_size)\n", 174 | " entropies.append(entropy)\n", 175 | " sentence_cl_size.append([sentence, cl_size])\n", 176 | " cl_sizes_all.extend([cl_size] * cl_size)\n", 177 | " cl_sizes.append(cl_size)\n", 178 | " lengths.append(len(sentence.split()))\n", 179 | "\n", 180 | " # Draw the plots, and set properties.\n", 181 | " fig_list[0].plot(sorted(entropies_all))\n", 182 | " fig_list[0].set_xlabel('Sentence no.')\n", 183 | " fig_list[0].set_ylabel('Entropy')\n", 184 | " #fig_list[0].axis([0, 90000, -0.2, 9])\n", 185 | "\n", 186 | " fig_list[1].plot(sorted(cl_sizes_all))\n", 187 | " fig_list[1].set_xlabel('Sentence no.')\n", 188 | " fig_list[1].set_ylabel('Cluster size')\n", 189 | " #fig_list[1].axis([0, 90000, -0.2, 500])\n", 190 | "\n", 191 | " fig_list[2].scatter(np.array(cl_sizes), np.array(entropies))\n", 192 | " fig_list[2].set_xlabel('Cluster size')\n", 193 | " fig_list[2].set_ylabel('Entropy')\n", 194 | " #fig_list[2].axis([0, 320, -0.2, 9])\n", 195 | "\n", 196 | " fig_list[3].scatter(np.array(lengths), np.array(entropies))\n", 197 | " fig_list[3].set_xlabel('No. 
of words in utterance')\n", 198 | " fig_list[3].set_ylabel('Entropy')\n", 199 | " #fig_list[3].axis([-0.2, 30, -0.2, 10])\n", 200 | "\n", 201 | " # Sort the sentence lists.\n", 202 | " sent_ent = sorted(sentence_entropy, key=operator.itemgetter(1), reverse=True)\n", 203 | " sent_cl = sorted(sentence_cl_size, key=operator.itemgetter(1), reverse=True)\n", 204 | " return sent_ent, sent_cl" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 5, 210 | "metadata": { 211 | "colab": {}, 212 | "colab_type": "code", 213 | "id": "jetWP3UaXH6q" 214 | }, 215 | "outputs": [], 216 | "source": [ 217 | "# Main function to visualize clustering/filtering results.\n", 218 | "def data_visualization():\n", 219 | " # Open the clustering files.\n", 220 | " source_cl_entropies = open(os.path.join(config.project_path, config.output_dir, 'fullSource_cluster_entropies.txt'))\n", 221 | " target_cl_entropies = open(os.path.join(config.project_path, config.output_dir, 'fullTarget_cluster_entropies.txt'))\n", 222 | "\n", 223 | " # Set up matplotlib.\n", 224 | " plt.close('all')\n", 225 | " fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6), (ax7, ax8)) = plt.subplots(nrows=4, ncols=2)\n", 226 | " fig.set_size_inches(13, 20)\n", 227 | "\n", 228 | " # Call the actual visualization function for source and target data.\n", 229 | " source_entropies, source_cl_sizes = _visualize(source_cl_entropies,\n", 230 | " 'Source',\n", 231 | " [ax1, ax3, ax5, ax7])\n", 232 | " target_entropies, target_cl_sizes = _visualize(target_cl_entropies,\n", 233 | " 'Target',\n", 234 | " [ax2, ax4, ax6, ax8])\n", 235 | " plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)\n", 236 | "\n", 237 | " source_cl_entropies.close()\n", 238 | " target_cl_entropies.close()" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 10, 244 | "metadata": { 245 | "colab": {}, 246 | "colab_type": "code", 247 | "id": "hfgy52_TXMQl" 248 | }, 249 | "outputs": [], 250 | "source": [ 251 | "def print_clusters(tag, 
top_clusters):\n", 252 | " clusters = {}\n", 253 | " cluster_element_lengths = {}\n", 254 | "\n", 255 | " with open(os.path.join(config.project_path, config.output_dir, 'full{}_cluster_elements.txt'.format(tag))) as file:\n", 256 | " for line in file:\n", 257 | " [source, source_cl_target, target_cl] = line.split('<=====>')\n", 258 | "\n", 259 | " if tag == 'Source':\n", 260 | " source_cl = source.split(';')[1]\n", 261 | " target_cl_ind = target_cl.split(':')[1].strip('\\n')\n", 262 | " source = source_cl_target.split('=')[0]\n", 263 | " target = source_cl_target.split('=')[1]\n", 264 | " cluster_element_lengths[source_cl] = \\\n", 265 | " cluster_element_lengths.get(source_cl, 0) + len(source.split())\n", 266 | " clusters[source_cl] = [*clusters.get(source_cl, []), source]\n", 267 | "\n", 268 | " else:\n", 269 | " target_cl = source.split(';')[1]\n", 270 | " target = source_cl_target.split('=')[0]\n", 271 | " cluster_element_lengths[target_cl] = \\\n", 272 | " cluster_element_lengths.get(target_cl, 0) + len(target.split())\n", 273 | " clusters[target_cl] = [*clusters.get(target_cl, []), target]\n", 274 | "\n", 275 | " with open(os.path.join(config.project_path, config.output_dir, 'full{}_cluster_entropies.txt'.format(tag))) as file:\n", 276 | " entropies = {}\n", 277 | " for line in file:\n", 278 | " [medoid, entropy, size] = line.split(';')\n", 279 | " entropies[medoid] = float(entropy)\n", 280 | "\n", 281 | " num_removed = 0\n", 282 | " num_all = 0\n", 283 | " for medoid in cluster_element_lengths:\n", 284 | " num_all += len(clusters[medoid])\n", 285 | " if ((cluster_element_lengths[medoid] / len(clusters[medoid]) if\n", 286 | " len(clusters[medoid]) > 0 else 1) > 1000 or\n", 287 | " len(medoid.split()) > 1000 or\n", 288 | " entropies[medoid] < config.threshold):\n", 289 | " num_removed += len(clusters[medoid])\n", 290 | "\n", 291 | " #print(num_removed / num_all)\n", 292 | " for _, medoid in zip(range(top_clusters),\n", 293 | " sorted(list(clusters), key=lambda 
x: entropies[x],\n", 294 | " reverse=True)):\n", 295 | " print('=====================================================')\n", 296 | " #print('{}& {} & {} \\\\\\\\'.format(list(set(clusters[medoid]))[0], len(clusters[medoid]), str(entropies[medoid])[:4]))\n", 297 | " print('Center: {}'.format(medoid))\n", 298 | " print('Entropy: {}'.format(entropies[medoid]))\n", 299 | " print('Size: {}'.format(len(clusters[medoid])))\n", 300 | " if len(clusters[medoid]) < 1000:\n", 301 | " print('Elements: \\n{}\\n\\n'.format('\\n'.join(set(clusters[medoid]))))" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": { 307 | "colab_type": "text", 308 | "id": "hItN-dJISP0N" 309 | }, 310 | "source": [ 311 | "# Visualize the clustering.\n", 312 | "Graphs on the left are about source data, and on the right the target data.\n", 313 | "* First the entropy of all the utterances in the dataset is plotted.\n", 314 | "* Second each sentence's cluster's size is plotted (for all sentences in the dataset).\n", 315 | "* Third the entropy and cluster size of all clusters is plotted\n", 316 | "* Finally the relationship between the entropy of an utterance and the number of words in it is plotted." 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 15, 322 | "metadata": { 323 | "colab": {}, 324 | "colab_type": "code", 325 | "id": "EKGlzOI6SP0O", 326 | "outputId": "52151bf6-100e-4905-c798-669cd7fd3d67", 327 | "scrolled": true 328 | }, 329 | "outputs": [ 330 | { 331 | "data": { 332 | "image/png": "\n", 333 | "text/plain": [ 334 | "
" 335 | ] 336 | }, 337 | "metadata": { 338 | "needs_background": "light" 339 | }, 340 | "output_type": "display_data" 341 | } 342 | ], 343 | "source": [ 344 | "data_visualization()" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": { 350 | "colab_type": "text", 351 | "collapsed": true, 352 | "id": "I3m_tPouSP0U" 353 | }, 354 | "source": [ 355 | "# Print some clusters\n", 356 | "Let's see the unique elements of the clusters with highest entropy." 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 13, 362 | "metadata": { 363 | "colab": {}, 364 | "colab_type": "code", 365 | "id": "B-y1cyabSP0V", 366 | "outputId": "02f5b184-a729-4089-f788-64d1ef39933d", 367 | "scrolled": true 368 | }, 369 | "outputs": [ 370 | { 371 | "name": "stdout", 372 | "output_type": "stream", 373 | "text": [ 374 | "=====================================================\n", 375 | "Center: so dick how about getting some coffee for tonight ?\n", 376 | "Entropy: 2.3204735011414868\n", 377 | "Size: 23\n", 378 | "Elements: \n", 379 | "i do not care the brand as long as it works well .\n", 380 | "no the steak was recommended but it is not very fresh .\n", 381 | "i have this list of stuff that i need and i only have half the dough .\n", 382 | "half the dough huh . well . how would you like to earn the other half ?\n", 383 | "would you care for a drink before you order ?\n", 384 | "yes i have many things to buy . i would like to choose the cleaning milk first .\n", 385 | "i want to buy the toothpaste the brand of jiajieshi .\n", 386 | "not today honey . do n't eat too much ice cream .\n", 387 | "sounds good . what about shampoo ? i would like to buy the product that prevents scurf .\n", 388 | "yes very . believe it or not it will cost you more than one hundred dollars .\n", 389 | "mom can i have some ice cream ?\n", 390 | "i want to buy more books .\n", 391 | "i just happen to have a question for you guys . why do the chinese cook the vegetables ? 
you see what i mean is that most vitamin are destroyed when heated .\n", 392 | "is that expensive ?\n", 393 | "i need a packet of cigarettes please .\n", 394 | "yes we charge 50 cents for water .\n", 395 | "mom can i have one more piece of cake ?\n", 396 | "where is your store located ?\n", 397 | "peter go and tidy up your toys now .\n", 398 | "i think i going to need some iced water too . is there an extra charge for that ?\n", 399 | "no sweat . it s a piece of cake .\n", 400 | "may i have a cookie ?\n", 401 | "that 's fine . could you give me some more napkins too ?\n", 402 | "\n", 403 | "\n", 404 | "=====================================================\n", 405 | "Center: come on you can at least try a little besides your cigarette .\n", 406 | "Entropy: 2.265710144725459\n", 407 | "Size: 26\n", 408 | "Elements: \n", 409 | "but i already tried my best .\n", 410 | "wow i ca n't thank you enough .\n", 411 | "thanks a lot . that 's the favor i was going to ask you for .\n", 412 | "i am wiped out . thank you .\n", 413 | "i 'm sorry marry is out right now .\n", 414 | "ok . but i 'm not familiar . i do n't know the beginning part .\n", 415 | "forgive you for what ?\n", 416 | "oh yeah ? are you joking ?\n", 417 | "16 divided by 2 . what 's the answer ?\n", 418 | "i m afraid not . i apologize .\n", 419 | "no ! are you kidding me ?\n", 420 | "jim could you do me a favor ?\n", 421 | "yes i believe they are . here are something that might interest you .\n", 422 | "i 'm afraid there 's been a mistake .\n", 423 | "how much does it cost ?\n", 424 | "i 'm mad ! i said now ! turn off the tv and do it now .\n", 425 | "never ! but thank you for inviting me .\n", 426 | "i 'm sorry sir . what seems to be the trouble ?\n", 427 | "let 's go !\n", 428 | "dad can you just tell me what it means ? i 'm too lazy .\n", 429 | "what time should i call back ?\n", 430 | "you 've got it . tell me a little bit about what you might be wanted .\n", 431 | "thanks for helping me out . 
i really appreciate it .\n", 432 | "thank you . maybe we can sing a song together . would you like to sing with me ?\n", 433 | "thanks a lot .\n", 434 | "i prefer glossy . how much ?\n", 435 | "\n", 436 | "\n", 437 | "=====================================================\n", 438 | "Center: look next time get yourself some comfy shoes . you re gonna come back again with me aren t you ?\n", 439 | "Entropy: 2.187126559186671\n", 440 | "Size: 67\n", 441 | "Elements: \n", 442 | "would you mind waiting a while ?\n", 443 | "dad . here comes another bus .\n", 444 | "i 'll put all that into the bag for you .\n", 445 | "peter wash your hands first and then have some dessert .\n", 446 | "when is a good time to catch him ?\n", 447 | "i do n't know about all of that but think about it their business gets free publicity !\n", 448 | "did somebody hit you ? or did you just fall ?\n", 449 | "so dick how about getting some coffee for tonight ?\n", 450 | "of course not ! i 'm telling the truth .\n", 451 | "it looks to be a nice day today .\n", 452 | "jessie i m afraid i can t come back home for dinner tonight .\n", 453 | "dad but i do n't want to share a room with peter . he snores every night .\n", 454 | "oh great ! it 's delicious . you see i am already putting on weight . there is one thing i do n't like however msg .\n", 455 | "do you mind if i smoke ?\n", 456 | "it runs every 15 minutes . you must have missed it .\n", 457 | "oh no get off it . it wasn t such a killer class . you just have to get into it . like they say no pain no gain .\n", 458 | "oh do n't let that worry you . if that were true china would n't have such a large population .\n", 459 | "oh i almost forgot . it 's my mum 's birthday saturday . i need to get her some more chanel . could you get me the 1 . 7 ounce bottle of chanel cologne ?\n", 460 | "mom . i have to go school shopping . 
there 's only one more week left .\n", 461 | "i swear i m going to kill you for this .\n", 462 | "well i want a fillet steak medium but my little girl does n't care for steak . could she have something else instead ?\n", 463 | "are things still going badly with your houseguest ?\n", 464 | "what 's wrong with msg ? it helps to bring out the taste of the food .\n", 465 | "mrs . smith has not checked in yet .\n", 466 | "which do you prefer color or black and white ?\n", 467 | "honey you can ask him to be quite . otherwise you may punish him and tell him to stand out of the room right ?\n", 468 | "what happened peter ? did you have a fight ?\n", 469 | "leo i really think you re beating around the bush with this guy . i know he used to be your best friend in college but i really think it s time to lay down the law .\n", 470 | "ok i 'm sorry it took so long .\n", 471 | "oh honey i 'm so sorry we do n't have enough space for you to have your own room .\n", 472 | "no he 's perfectly harmless . and he 's not afraid of strangers either . here hold him .\n", 473 | "the movie company has to pay them ?\n", 474 | "you can use this . it has special effect for keeping your face moisturized . it has this lotion as a gift attached .\n", 475 | "i used your computer . and i m afraid i ve erased your personal files accidentally .\n", 476 | "can you leave a message for her to call her office ?\n", 477 | "may i sit here ?\n", 478 | "i wonder how all the businesses in the area feel about that .\n", 479 | "he stepped out of the office for a little while .\n", 480 | "isn t he the best instructor ? i think he s so hot . wow ! i really feel energized don t you ?\n", 481 | "sure we are off singing road close to the bank .\n", 482 | "i 'm looking for an engagement ring for my girlfriend . i have an idea of what she likes but i want to surprise her with something special too .\n", 483 | "good evening . welcome to cherry 's . do you have a reservation ?\n", 484 | "getting worse . 
now he s eating me out of house and home . i ve tried talking to him but it all goes in one ear and out the other . he makes himself at home which is fine . but what really gets me is that yesterday he walked into the living room in the raw and i had company over ! that was the last straw .\n", 485 | "the phone has been ringing off the hook today .\n", 486 | "i am in no particular hurry .\n", 487 | "you are in luck . we just receive a shipment of several different styles of white purses .\n", 488 | "we ca n't go that way the road is blocked for the next few days .\n", 489 | "oh honey you made a mistake .\n", 490 | "i 'm looking for some blush . do you still have some in peach rose ?\n", 491 | "you stepped on my foot !\n", 492 | "ahahah ! what is that thing on your couch ! it just moved !\n", 493 | "sam you ve got to forgive me .\n", 494 | "mom just ten more minutes . the show is going to be over soon .\n", 495 | "they will be ready at noon tomorrow . each negative develops one print right ?\n", 496 | "now we 're going to draw an apple in your sketch book . what do we use ?\n", 497 | "i believe you have charged me twice for the same thing . look the figure of 6 . 5 dollar appears here then again here .\n", 498 | "go straight up zhongshan road and you will see our sign on your right after you pass the museum .\n", 499 | "excuse me i 've been waiting here for 10 minutes . do you know how often does no . 636 run ?\n", 500 | "kata ! you 've got a beautiful singing voice . you hit the high notes perfectly .\n", 501 | "can you connect me to mary . smith hotel room ?\n", 502 | "they must be popular again this season .\n", 503 | "sorry i ca n't sing the song .\n", 504 | "i can t believe it ! i have all my important personal documents stored in that computer . it s no laughing matter .\n", 505 | "you have a pet lizard ? somehow i never would have imagined that .\n", 506 | "oh ! sorry to hear that . 
this is quite unusual as we have steak from the market every day .\n", 507 | "not back home for dinner again ? that s the third time this week !\n", 508 | "come on you can at least try a little besides your cigarette .\n", 509 | "\n", 510 | "\n", 511 | "=====================================================\n", 512 | "Center: what s wrong with that ? cigarette is the thing i go crazy for .\n", 513 | "Entropy: 2.1571675430303348\n", 514 | "Size: 57\n", 515 | "Elements: \n", 516 | "can i help you sir what do you need ?\n", 517 | "yes sir . i 'll bring it over . have you decided what you 'd like sir ?\n", 518 | "are you sure ?\n", 519 | "yes we shall . what size do you like ?\n", 520 | "great i 'll take one .\n", 521 | "yes it is . and develop them as glossy as possible .\n", 522 | "sure just ask . what can i do for you ?\n", 523 | "dad how can we get to the zoo ?\n", 524 | "how about this one ? it is wellknown for the effect of removing scurf .\n", 525 | "where are you going ?\n", 526 | "i told you i m sorry . what can i do to make it up to you ?\n", 527 | "no mom . i did n't .\n", 528 | "all right . what is your type of skin ?\n", 529 | "what can i do for you ?\n", 530 | "oh ok where do i get off ?\n", 531 | "do i have to memorize it ?\n", 532 | "do you need money or what ?\n", 533 | "welcome . may i help you ?\n", 534 | "this is how you turn on the computer .\n", 535 | "when do you need them sir ?\n", 536 | "well how long will it be ?\n", 537 | "it is difficult .\n", 538 | "i like red .\n", 539 | "yes it is .\n", 540 | "dad can you help me ?\n", 541 | "depends ? depends on what ?\n", 542 | "of course sir no problem .\n", 543 | "let me see . it depends .\n", 544 | "no is it difficult ?\n", 545 | "sure where are you ?\n", 546 | "we can take a bus there .\n", 547 | "er . . . how about this one ?\n", 548 | "how many of you please ?\n", 549 | "so what ? 
it is not fresh and i 'm not happy about it .\n", 550 | "do i have to change ?\n", 551 | "well the 4 x 6 is fine .\n", 552 | "ok . let 's do it together .\n", 553 | "may i help you find something sir ?\n", 554 | "no . the museum is the terminal of this bus .\n", 555 | "dad may i have a room of my own ?\n", 556 | "dad i want to draw with crayons can i ?\n", 557 | "what for ?\n", 558 | "what kind of food do you like ?\n", 559 | "may i help you madam ?\n", 560 | "may i help you sir ?\n", 561 | "no we do n't .\n", 562 | "let 's see . from here you have to take the 278 bus .\n", 563 | "i see thank you . by the way is this the right bus for the museum ?\n", 564 | "how about this one ?\n", 565 | "do i have a choice ? uh . that 's a no . what can i do ?\n", 566 | "i do n't know how to do it .\n", 567 | "thank you for your compliment . but you are exaggerating . i think you are destined to be a singer . you have the best voice !\n", 568 | "i would like to have a roll developed .\n", 569 | "i think so .\n", 570 | "is it a newcomer ?\n", 571 | "yes i 'd like to . it 's my honor . let 's pick a song .\n", 572 | "yes . thank you .\n", 573 | "\n", 574 | "\n", 575 | "=====================================================\n", 576 | "Center: but your american ?\n", 577 | "Entropy: 2.0\n", 578 | "Size: 4\n", 579 | "Elements: \n", 580 | "have you heard about our special promotion this month ? if you purchase at least 18 dollar 50 cents in any elizabeth arden products you will receive this black poke with a sample of lipstick mascara and two shades of white shadow .\n", 581 | "we have all shapes sizes qualities and price ranges do you know about the four cs of picking a diamond ?\n", 582 | "wow that sounds like a bargain . i 'm running low on facial moisturizer and toner . 
could you ring those up for me too along with the blush ?\n", 583 | "well my price range is a 5000 dollars to 7000 dollars i 'm looking for a marquise cut on the wide band .\n", 584 | "\n", 585 | "\n", 586 | "=====================================================\n", 587 | "Center: oh no get off it . it wasn t such a killer class . you just have to get into it . like they say no pain no gain .\n", 588 | "Entropy: 1.9877733714879842\n", 589 | "Size: 13\n", 590 | "Elements: \n", 591 | "i 'd be glad to . do you need anything else ?\n", 592 | "dad how do you say this word ?\n", 593 | "coffee ? i don t honestly like that kind of stuff .\n", 594 | "certainly sir .\n", 595 | "you don t have to explain . suit yourself .\n", 596 | "i understand .\n", 597 | "i think that they get a pretty good payoff .\n", 598 | "sword say it sword .\n", 599 | "what s wrong ? didn t you think it was fun ? !\n", 600 | "is everything to your satisfaction ?\n", 601 | "sure . do you need anything else ?\n", 602 | "what does this word mean ?\n", 603 | "anything else ?\n", 604 | "\n", 605 | "\n", 606 | "=====================================================\n", 607 | "Center: oh do n't let that worry you . if that were true china would n't have such a large population .\n", 608 | "Entropy: 1.8220931167465637\n", 609 | "Size: 54\n", 610 | "Elements: \n", 611 | "my car has a problem starting . could you please take a look at it for me ?\n", 612 | "oh that 's right . they 're filming a movie up there are n't they ?\n", 613 | "are you going to the annual party ? i can give you a ride if you need one .\n", 614 | "if you are in a hurry you should take a taxi .\n", 615 | "and do you want the glossy or matted finish ?\n", 616 | "yes a roll of kodak film please .\n", 617 | "would you like to take a look at the menu sir ?\n", 618 | "hurry up will you ?\n", 619 | "i apologize . you have my word i ll spend some time with you on the weekend . 
i promise .\n", 620 | "the last one is black and white all the rest should need color .\n", 621 | "can you please tell me where you are located ?\n", 622 | "you pay when you pick them up . i do n't need a deposit for just one roll of film .\n", 623 | "could i have my bill please ?\n", 624 | "his name is grunt . come closer and i 'll properly introduce you .\n", 625 | "no he is at work now .\n", 626 | "did you think it was n't real ? that 's my pet lizard .\n", 627 | "i think so . are n't the four cs cut clarity carat and color .\n", 628 | "what s wrong with that ? cigarette is the thing i go crazy for .\n", 629 | "i 'm looking for a white purse as a gift . could you show what you have in stock ?\n", 630 | "can you send a cab to pick me up ?\n", 631 | "may i have his office phone number please ?\n", 632 | "you should get off at the first shi da stop .\n", 633 | "have you got change for a thousand ?\n", 634 | "when will she be back ?\n", 635 | "i 'm not sure . but i 'll get a table ready as fast as i can .\n", 636 | "no problem . do you need another film ?\n", 637 | "sorry i 'm late .\n", 638 | "how was your test ?\n", 639 | "no it 's quite simple . when you get on just ask the bus driver when to pay the fare and where you want to get off .\n", 640 | "fine . let 's get on . oh no judy ! get off the bus quickly !\n", 641 | "i m sorry . our company has just opened . there are always too many things to handle . you know that .\n", 642 | "no honey it 's easy if you know the way .\n", 643 | "excuse me i 'm a little lost . which bus do i take to get to shi da ?\n", 644 | "mom can i have more allowance ?\n", 645 | "does this bus go there ?\n", 646 | "never mind . you can follow me . i 'll sing the first part .\n", 647 | "i 'd like to talk to mr . white please ?\n", 648 | "oh yes that is a beautiful color . it has been very popular blush this season . i have two left .\n", 649 | "yes how do i get to your shop from chilin ?\n", 650 | "wow . this is nice . 
i 'll take this one . i guess if she does n't like it she can return it right ?\n", 651 | "no problem . do you want 3 x 5 or 4 x 6 ?\n", 652 | "good let 's go for a drive .\n", 653 | "i 'm sorry sir . do you wish to try something else ? that would be on the house of course .\n", 654 | "okay i 'll get the car out of the garage .\n", 655 | "it seems you got here at good time . do you have a bus schedule ?\n", 656 | "yes i do . you can buy a bus schedule in a news stand .\n", 657 | "let 's step in dad .\n", 658 | "mike 's mechanics . can i help you ?\n", 659 | "take a bus then . it will only cost you 5 dollars .\n", 660 | "can i take your order now or do you still want to look at the menu ?\n", 661 | "robert is not available at the moment .\n", 662 | "i hope they will come out well . when should i pick them up ?\n", 663 | "only 15 nt per section . oh look that is your bus .\n", 664 | "would you like a lift home ?\n", 665 | "\n", 666 | "\n", 667 | "=====================================================\n", 668 | "Center: a glass of qingdao beer .\n", 669 | "Entropy: 1.584962500721156\n", 670 | "Size: 3\n", 671 | "Elements: \n", 672 | "yes i would also like some sweet and sour sauce and pepper .\n", 673 | "certainly . how about spaghetti with clams and shrimps .\n", 674 | "do i owe you anything for the sauce pepper and napkins ?\n", 675 | "\n", 676 | "\n", 677 | "=====================================================\n", 678 | "Center: yes sir . i 'll bring it over . have you decided what you 'd like sir ?\n", 679 | "Entropy: 1.584962500721156\n", 680 | "Size: 3\n", 681 | "Elements: \n", 682 | "ok .\n", 683 | "the 4 x6 will be ok . thanks .\n", 684 | "ok thanks . . .\n", 685 | "\n", 686 | "\n", 687 | "=====================================================\n", 688 | "Center: i am wiped out . thank you .\n", 689 | "Entropy: 1.0\n", 690 | "Size: 2\n", 691 | "Elements: \n", 692 | "somebody please answer the phone .\n", 693 | "look next time get yourself some comfy shoes . 
you re gonna come back again with me aren t you ?\n", 694 | "\n", 695 | "\n" 696 | ] 697 | } 698 | ], 699 | "source": [ 700 | "print_clusters(tag='Source', top_clusters=10)" 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": null, 706 | "metadata": {}, 707 | "outputs": [], 708 | "source": [] 709 | } 710 | ], 711 | "metadata": { 712 | "colab": { 713 | "collapsed_sections": [ 714 | "hItN-dJISP0N", 715 | "I3m_tPouSP0U" 716 | ], 717 | "name": "visualization.ipynb", 718 | "provenance": [], 719 | "toc_visible": true, 720 | "version": "0.3.2" 721 | }, 722 | "kernelspec": { 723 | "display_name": "Python 3", 724 | "language": "python", 725 | "name": "python3" 726 | }, 727 | "language_info": { 728 | "codemirror_mode": { 729 | "name": "ipython", 730 | "version": 3 731 | }, 732 | "file_extension": ".py", 733 | "mimetype": "text/x-python", 734 | "name": "python", 735 | "nbconvert_exporter": "python", 736 | "pygments_lexer": "ipython3", 737 | "version": "3.6.6" 738 | } 739 | }, 740 | "nbformat": 4, 741 | "nbformat_minor": 1 742 | } 743 | -------------------------------------------------------------------------------- /docs/3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/3d.png -------------------------------------------------------------------------------- /docs/cluster_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/cluster_examples.png -------------------------------------------------------------------------------- /docs/example_responses.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/example_responses.png -------------------------------------------------------------------------------- /docs/help.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/help.png -------------------------------------------------------------------------------- /docs/high_entropy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/high_entropy.png -------------------------------------------------------------------------------- /docs/metrics_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/metrics_table.png -------------------------------------------------------------------------------- /docs/other_datasets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/other_datasets.png -------------------------------------------------------------------------------- /docs/uml.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricsinaruto/NeuralChatbots-DataFiltering/2b8a2e848b089ff782ef4c3426af407be3c6f68b/docs/uml.png -------------------------------------------------------------------------------- /docs/visu.png: -------------------------------------------------------------------------------- 
"""Install the project requirements and optionally download its data.

Run as a script, this file installs the packages listed in
requirements.txt and then asks whether to download the datasets and the
generated model responses. Importing it has no side effects.
"""
import os
import subprocess
import sys
import zipfile

CHUNK_SIZE = 1024  # Bytes requested per streamed chunk.
DOCS_URL = 'https://ricsinaruto.github.io/website/docs/'
DATASETS = ('Twitter', 'Cornell', 'DailyDialog')


def install_requirements():
    """Install requirements.txt with the pip of the running interpreter."""
    print('Installing requirements...')
    # `python -m pip` targets this interpreter; a bare `pip` found on PATH
    # may belong to a different Python installation.
    status = subprocess.call(
        [sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'])
    if status != 0:
        # Stay best-effort: the needed packages may already be installed,
        # so report the failure instead of aborting.
        print('Warning: pip exited with status {}.'.format(status))


def download_data(url, zipped_path, extract):
    """Download a zip archive with a progress bar and extract it.

    Args:
        url: Address of the zip archive to download.
        zipped_path: Local path the downloaded archive is written to.
        extract: Directory the archive is extracted into ('' extracts
            relative to the current working directory).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        zipfile.BadZipFile: If the downloaded file is not a zip archive.
    """
    # Imported here rather than at module level because these packages may
    # only become available once install_requirements() has run.
    import requests
    from clint.textui import progress

    # Open the url and download the data with progress bars.
    with requests.get(url, stream=True) as data_stream:
        # Fail now instead of saving an error page as a ".zip" file.
        data_stream.raise_for_status()

        chunks = data_stream.iter_content(chunk_size=CHUNK_SIZE)
        total_length = data_stream.headers.get('content-length')
        if total_length is not None:
            # The bar needs a size up front; without a Content-Length
            # header the download simply proceeds without a bar.
            chunks = progress.bar(
                chunks, expected_size=int(total_length) // CHUNK_SIZE + 1)

        with open(zipped_path, 'wb') as file:
            for chunk in chunks:
                if chunk:  # Skip keep-alive chunks, which are empty.
                    file.write(chunk)

    # Extract file.
    with zipfile.ZipFile(zipped_path, 'r') as zip_file:
        zip_file.extractall(extract)


def confirm(question):
    """Print a yes/no question and return True when the answer is 'y'."""
    print(question)
    return input().strip().lower() == 'y'


def main():
    """Install requirements, then download whatever the user asks for."""
    install_requirements()

    if confirm('Do you want to download all datasets used in the paper '
               '(116 MB)? (y/n)'):
        os.makedirs('data', exist_ok=True)
        for name in DATASETS:
            archive = name + '.zip'
            download_data(DOCS_URL + archive,
                          os.path.join('data', archive),
                          'data')

    if confirm('Do you want to download all generated responses on the '
               'test set by the different models (7 MB)? (y/n)'):
        download_data(DOCS_URL + 'responses.zip', 'responses.zip', '')


if __name__ == '__main__':
    main()