├── README.md ├── lucene ├── IndexFiles.class ├── IndexFiles.java ├── SearchSimilarFiles.class ├── SearchSimilarFiles.java ├── lucene-6.3.0 │ └── libs │ │ ├── lucene-backward-codecs-6.3.0.jar │ │ ├── lucene-benchmark-6.3.0.jar │ │ ├── lucene-classification-6.3.0.jar │ │ ├── lucene-codecs-6.3.0.jar │ │ ├── lucene-core-6.3.0.jar │ │ ├── lucene-demo-6.3.0.jar │ │ ├── lucene-expressions-6.3.0.jar │ │ ├── lucene-facet-6.3.0.jar │ │ ├── lucene-grouping-6.3.0.jar │ │ ├── lucene-highlighter-6.3.0.jar │ │ ├── lucene-join-6.3.0.jar │ │ ├── lucene-memory-6.3.0.jar │ │ ├── lucene-misc-6.3.0.jar │ │ ├── lucene-queries-6.3.0.jar │ │ ├── lucene-queryparser-6.3.0.jar │ │ ├── lucene-replicator-6.3.0.jar │ │ ├── lucene-sandbox-6.3.0.jar │ │ ├── lucene-spatial-6.3.0.jar │ │ ├── lucene-spatial-extras-6.3.0.jar │ │ ├── lucene-spatial3d-6.3.0.jar │ │ ├── lucene-suggest-6.3.0.jar │ │ └── lucene-test-framework-6.3.0.jar ├── run_lucene.sh └── run_ques_lucene.sh └── src ├── data_generation ├── README.md ├── data_generator.py ├── helper.py ├── parse.py ├── post_ques_ans_generator.py └── run_data_generator.sh ├── embedding_generation ├── README.md ├── create_we_vocab.py ├── run_create_we_vocab.sh └── run_glove.sh ├── evaluation ├── evaluate_model_with_human_annotations.py └── run_evaluation.sh └── models ├── README.md ├── baseline_pa.py ├── baseline_pq.py ├── baseline_pqa.py ├── combine_pickle.py ├── evpi.py ├── load_data.py ├── lstm_helper.py ├── main.py ├── model_helper.py ├── run_combine_domains.sh ├── run_load_data.sh └── run_main.sh /README.md: -------------------------------------------------------------------------------- 1 | # Repository information 2 | 3 | This repository contains data and code for the paper below: 4 | 5 | 6 | Learning to Ask Good Questions: Ranking Clarification Questions using Neural Expected Value of Perfect Information
7 | Sudha Rao (raosudha@cs.umd.edu) and Hal Daumé III (hal@umiacs.umd.edu)
8 | Proceedings of the 2018 Annual Meeting of the Association for Computational Linguistics (ACL 2018) 9 | 10 | # Downloading data 11 | 12 | * Download the clarification questions dataset from Google Drive here: https://go.umd.edu/clarification_questions_dataset 13 | * cp clarification_questions_dataset/data ranking_clarification_questions/data 14 | 15 | * Download word embeddings trained on the stackexchange datadump here: https://go.umd.edu/stackexchange_embeddings 16 | * cp stackexchange_embeddings/embeddings ranking_clarification_questions/embeddings 17 | 18 | The above dataset contains clarification questions for these three stackexchange sites:<br/>
19 | 1. askubuntu.com 20 | 2. unix.stackexchange.com 21 | 3. superuser.com 22 | 23 | # Running model on data above 24 | 25 | To run models on a combination of the three sites above, check ranking_clarification_questions/src/models/README 26 | 27 | # Generating data for other sites 28 | 29 | To generate clarification questions for a different site of stackexchange, check ranking_clarification_questions/src/data_generation/README 30 | 31 | # Retrain stackexchange word embeddings 32 | 33 | To retrain word embeddings on a newer version of stackexchange datadump, check ranking_clarification_questions/src/embedding_generation/README 34 | 35 | # Contact information 36 | 37 | Please contact Sudha Rao (raosudha@cs.umd.edu) if you have any questions or any feedback. 38 | -------------------------------------------------------------------------------- /lucene/IndexFiles.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/IndexFiles.class -------------------------------------------------------------------------------- /lucene/IndexFiles.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | import java.io.BufferedReader; 18 | import java.io.IOException; 19 | import java.io.InputStream; 20 | import java.io.InputStreamReader; 21 | import java.nio.charset.StandardCharsets; 22 | import java.nio.file.FileVisitResult; 23 | import java.nio.file.Files; 24 | import java.nio.file.Path; 25 | import java.nio.file.Paths; 26 | import java.nio.file.SimpleFileVisitor; 27 | import java.nio.file.attribute.BasicFileAttributes; 28 | import java.util.Date; 29 | 30 | import org.apache.lucene.analysis.Analyzer; 31 | import org.apache.lucene.analysis.standard.StandardAnalyzer; 32 | import org.apache.lucene.document.LongPoint; 33 | import org.apache.lucene.document.Document; 34 | import org.apache.lucene.document.Field; 35 | import org.apache.lucene.document.StringField; 36 | import org.apache.lucene.document.TextField; 37 | import org.apache.lucene.index.IndexWriter; 38 | import org.apache.lucene.index.IndexWriterConfig.OpenMode; 39 | import org.apache.lucene.index.IndexWriterConfig; 40 | import org.apache.lucene.index.Term; 41 | import org.apache.lucene.store.Directory; 42 | import org.apache.lucene.store.FSDirectory; 43 | 44 | /** Index all text files under a directory. 45 | *

46 | * This is a command-line application demonstrating simple Lucene indexing. 47 | * Run it with no command-line arguments for usage information. 48 | */ 49 | public class IndexFiles { 50 | 51 | private IndexFiles() {} 52 | 53 | /** Index all text files under a directory. */ 54 | public static void main(String[] args) { 55 | String usage = "java org.apache.lucene.demo.IndexFiles" 56 | + " [-index INDEX_PATH] [-docs DOCS_PATH] [-update]\n\n" 57 | + "This indexes the documents in DOCS_PATH, creating a Lucene index" 58 | + "in INDEX_PATH that can be searched with SearchFiles"; 59 | String indexPath = "index"; 60 | String docsPath = null; 61 | boolean create = true; 62 | for(int i=0;iWriteLineDocTask. 141 | * 142 | * @param writer Writer to the index where the given file/dir info will be stored 143 | * @param path The file to index, or the directory to recurse into to find files to index 144 | * @throws IOException If there is a low-level I/O error 145 | */ 146 | static void indexDocs(final IndexWriter writer, Path path) throws IOException { 147 | if (Files.isDirectory(path)) { 148 | Files.walkFileTree(path, new SimpleFileVisitor() { 149 | @Override 150 | public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { 151 | try { 152 | indexDoc(writer, file, attrs.lastModifiedTime().toMillis()); 153 | } catch (IOException ignore) { 154 | // don't index files that can't be read. 155 | } 156 | return FileVisitResult.CONTINUE; 157 | } 158 | }); 159 | } else { 160 | indexDoc(writer, path, Files.getLastModifiedTime(path).toMillis()); 161 | } 162 | } 163 | 164 | /** Indexes a single document */ 165 | static void indexDoc(IndexWriter writer, Path file, long lastModified) throws IOException { 166 | try (InputStream stream = Files.newInputStream(file)) { 167 | // make a new, empty document 168 | Document doc = new Document(); 169 | 170 | // Add the path of the file as a field named "path". Use a 171 | // field that is indexed (i.e. searchable), but don't tokenize 172 | // the field into separate words and don't index term frequency 173 | // or positional information: 174 | Field pathField = new StringField("path", file.getFileName().toString(), Field.Store.YES); 175 | doc.add(pathField); 176 | 177 | // Add the last modified date of the file a field named "modified". 178 | // Use a LongPoint that is indexed (i.e. efficiently filterable with 179 | // PointRangeQuery). This indexes to milli-second resolution, which 180 | // is often too fine. You could instead create a number based on 181 | // year/month/day/hour/minutes/seconds, down the resolution you require. 182 | // For example the long value 2011021714 would mean 183 | // February 17, 2011, 2-3 PM. 184 | doc.add(new LongPoint("modified", lastModified)); 185 | 186 | // Add the contents of the file to a field named "contents". Specify a Reader, 187 | // so that the text of the file is tokenized and indexed, but not stored. 188 | // Note that FileReader expects the file to be in UTF-8 encoding. 189 | // If that's not the case searching for special characters will fail. 
190 | doc.add(new TextField("contents", new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8)))); 191 | 192 | if (writer.getConfig().getOpenMode() == OpenMode.CREATE) { 193 | // New index, so we just add the document (no old document can be there): 194 | System.out.println("adding " + file); 195 | writer.addDocument(doc); 196 | } else { 197 | // Existing index (an old copy of this document may have been indexed) so 198 | // we use updateDocument instead to replace the old one matching the exact 199 | // path, if present: 200 | System.out.println("updating " + file); 201 | writer.updateDocument(new Term("path", file.toString()), doc); 202 | } 203 | } 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /lucene/SearchSimilarFiles.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/SearchSimilarFiles.class -------------------------------------------------------------------------------- /lucene/SearchSimilarFiles.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | //package org.apache.lucene.demo; 18 | 19 | 20 | import java.io.BufferedReader; 21 | import java.io.IOException; 22 | import java.io.InputStream; 23 | import java.io.InputStreamReader; 24 | import java.io.PrintWriter; 25 | import java.lang.Math; 26 | import java.nio.charset.StandardCharsets; 27 | import java.nio.file.Files; 28 | import java.nio.file.FileVisitResult; 29 | import java.nio.file.Path; 30 | import java.nio.file.Paths; 31 | import java.nio.file.SimpleFileVisitor; 32 | import java.nio.file.attribute.BasicFileAttributes; 33 | import java.util.Date; 34 | import java.util.stream.Stream; 35 | 36 | import org.apache.lucene.analysis.Analyzer; 37 | import org.apache.lucene.analysis.standard.StandardAnalyzer; 38 | import org.apache.lucene.document.Document; 39 | import org.apache.lucene.document.TextField; 40 | import org.apache.lucene.index.DirectoryReader; 41 | import org.apache.lucene.index.IndexReader; 42 | import org.apache.lucene.queryparser.classic.QueryParser; 43 | import org.apache.lucene.queryparser.classic.ParseException; 44 | import org.apache.lucene.search.IndexSearcher; 45 | import org.apache.lucene.search.Query; 46 | import org.apache.lucene.search.ScoreDoc; 47 | import org.apache.lucene.search.TopDocs; 48 | import org.apache.lucene.store.FSDirectory; 49 | 50 | /** Simple command-line based search demo. 
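 * For each file under the -queries directory, its contents are issued as a query against
 * the index in -index; the query file's name followed by the names of up to 100 most similar
 * indexed documents is written as one line of -outputFile.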
*/ 51 | public class SearchSimilarFiles { 52 | 53 | private SearchSimilarFiles() {} 54 | 55 | /** Simple command-line based search demo. */ 56 | public static void main(String[] args) throws Exception, ParseException, IOException{ 57 | String usage = 58 | "Usage:\tjava org.apache.lucene.demo.SearchSimilarFiles [-index dir] [-queries file] [-outputFile string]\n\n"; 59 | if (args.length < 2 && ("-h".equals(args[0]) || "-help".equals(args[0]))) { 60 | System.out.println(usage); 61 | System.exit(0); 62 | } 63 | 64 | String index = "index"; 65 | String field = "contents"; 66 | String queriesPath = null; 67 | String queryString = null; 68 | String outputFile = null; 69 | 70 | for(int i = 0;i < args.length;i++) { 71 | if ("-index".equals(args[i])) { 72 | index = args[i+1]; 73 | i++; 74 | } else if ("-field".equals(args[i])) { 75 | field = args[i+1]; 76 | i++; 77 | } else if ("-queries".equals(args[i])) { 78 | queriesPath = args[i+1]; 79 | i++; 80 | } else if ("-outputFile".equals(args[i])) { 81 | outputFile = args[i+1]; 82 | i++; 83 | } 84 | } 85 | IndexReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(index))); 86 | IndexSearcher searcher = new IndexSearcher(reader); 87 | Analyzer analyzer = new StandardAnalyzer(); 88 | 89 | QueryParser parser = new QueryParser(field, analyzer); 90 | PrintWriter writer = new PrintWriter(outputFile, "UTF-8"); 91 | 92 | final Path queriesDir = Paths.get(queriesPath); 93 | if (Files.isDirectory(queriesDir)){ 94 | Files.walkFileTree(queriesDir, new SimpleFileVisitor() { 95 | @Override 96 | public FileVisitResult visitFile(Path queryFile, BasicFileAttributes attrs) throws IOException { 97 | similarDocs(queryFile, parser, searcher, writer); 98 | return FileVisitResult.CONTINUE; 99 | } 100 | }); 101 | //Stream queryFiles = Files.walk(queriesDir); 102 | //queryFiles.forEach(queryFile -> 103 | // similarDocs(queryFile, parser, searcher) 104 | //); 105 | } 106 | writer.close(); 107 | reader.close(); 108 | } 109 | 110 | static void similarDocs(Path file, QueryParser parser, IndexSearcher searcher, PrintWriter writer) { 111 | try (InputStream stream = Files.newInputStream(file)) { 112 | String filename = file.getFileName().toString(); 113 | filename = filename.substring(0, filename.lastIndexOf('.')); 114 | writer.print(filename + ' '); 115 | String content = readDocument(new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))); 116 | Query query = parser.parse(parser.escape(content)); 117 | 118 | TopDocs results = searcher.search(query, 100); 119 | ScoreDoc[] hits = results.scoreDocs; 120 | 121 | int numTotalHits = results.totalHits; 122 | //System.out.println(numTotalHits + " total matching documents"); 123 | 124 | for (int i = 0; i < Math.min(100, numTotalHits); i++) { 125 | Document doc = searcher.doc(hits[i].doc); 126 | String docname = doc.get("path"); 127 | docname = docname.substring(0, docname.lastIndexOf('.')); 128 | //System.out.println("doc="+docname+" score="+hits[i].score); 129 | writer.print(docname + ' '); 130 | } 131 | } catch (IOException e) { 132 | e.printStackTrace(); 133 | } catch (ParseException pe){ 134 | System.out.println(pe.getStackTrace()); 135 | System.out.println(pe.getMessage()); 136 | } finally { 137 | writer.println(); 138 | } 139 | } 140 | 141 | static String readDocument(BufferedReader br) throws IOException { 142 | StringBuilder sb = new StringBuilder(); 143 | String line = br.readLine(); 144 | 145 | while (line != null) { 146 | sb.append(line); 147 | sb.append(System.lineSeparator()); 148 | line = 
br.readLine(); 149 | } 150 | return sb.toString(); 151 | } 152 | } 153 | 154 | -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-backward-codecs-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-backward-codecs-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-benchmark-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-benchmark-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-classification-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-classification-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-codecs-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-codecs-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-core-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-core-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-demo-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-demo-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-expressions-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-expressions-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-facet-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-facet-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-grouping-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-grouping-6.3.0.jar 
-------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-highlighter-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-highlighter-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-join-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-join-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-memory-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-memory-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-misc-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-misc-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-queries-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-queries-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-queryparser-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-queryparser-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-replicator-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-replicator-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-sandbox-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-sandbox-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-spatial-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-spatial-6.3.0.jar -------------------------------------------------------------------------------- 
/lucene/lucene-6.3.0/libs/lucene-spatial-extras-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-spatial-extras-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-spatial3d-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-spatial3d-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-suggest-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-suggest-6.3.0.jar -------------------------------------------------------------------------------- /lucene/lucene-6.3.0/libs/lucene-test-framework-6.3.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-test-framework-6.3.0.jar -------------------------------------------------------------------------------- /lucene/run_lucene.sh: -------------------------------------------------------------------------------- 1 | SITE_DIR=$1 2 | /usr/lib/jvm/java-1.8.0-oracle-1.8.0.151-1jpp.5.el7.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' IndexFiles -docs $SITE_DIR/post_docs/ -index $SITE_DIR/post_doc_indices/ 3 | 4 | /usr/lib/jvm/java-1.8.0-oracle-1.8.0.151-1jpp.5.el7.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' SearchSimilarFiles -index $SITE_DIR/post_doc_indices/ -queries $SITE_DIR/post_docs/ -outputFile $SITE_DIR/lucene_similar_posts.txt 5 | -------------------------------------------------------------------------------- /lucene/run_ques_lucene.sh: -------------------------------------------------------------------------------- 1 | SITE_DIR=$1 2 | #/usr/lib/jvm/java-1.8.0-oracle.x86_64/bin/javac -cp 'lucene-6.3.0/libs/*' IndexFiles.java 3 | 4 | #/usr/lib/jvm/java-1.8.0-oracle.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' IndexFiles -docs $SITE_DIR/ques_docs/ -index $SITE_DIR/ques_doc_indices/ 5 | 6 | /usr/lib/jvm/java-1.8.0-oracle-1.8.0.151-1jpp.5.el7.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' IndexFiles -docs $SITE_DIR/ques_docs/ -index $SITE_DIR/ques_doc_indices/ 7 | 8 | #/usr/lib/jvm/java-1.8.0-oracle.x86_64/bin/javac -cp 'lucene-6.3.0/libs/*' SearchSimilarFiles.java 9 | 10 | #/usr/lib/jvm/java-1.8.0-oracle.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' SearchSimilarFiles -index $SITE_DIR/ques_doc_indices/ -queries $SITE_DIR/ques_docs/ -outputFile $SITE_DIR/lucene_similar_questions.txt 11 | 12 | /usr/lib/jvm/java-1.8.0-oracle-1.8.0.151-1jpp.5.el7.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' SearchSimilarFiles -index $SITE_DIR/ques_doc_indices/ -queries $SITE_DIR/ques_docs/ -outputFile $SITE_DIR/lucene_similar_questions.txt 13 | -------------------------------------------------------------------------------- /src/data_generation/README.md: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | 3 | * Install numpy, nltk, 
BeautifulSoup 4 | 5 | # NOTE 6 | 7 | * Data for the sites: askubuntu.com, unix.stackexchange.com & superuser.com can be found here: https://go.umd.edu/clarification_questions_dataset 8 | 9 | # Steps to generate data for any other site 10 | 11 | 1. Choose a sitename from the list of sites in https://archive.org/download/stackexchange. Let's say you chose 'academia.com' 12 | 2. Download the .7z file corresponding to the site i.e. academia.com.7z and unzip it under ranking_clarification_questions/stackexchange/ 13 | 3. Set "SITENAME=academia.com" in ranking_clarification_questions/src/data_generation/run_data_generator.sh 14 | 4. cd ranking_clarification_questions; sh src/data_generation/run_data_generator.sh 15 | 16 | This will create data/academia.com/post_data.tsv & data/academia.com/qa_data.tsv files 17 | -------------------------------------------------------------------------------- /src/data_generation/data_generator.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import argparse 3 | from parse import * 4 | from post_ques_ans_generator import * 5 | from helper import * 6 | import time 7 | import numpy as np 8 | import cPickle as p 9 | import pdb 10 | import random 11 | 12 | def get_similar_docs(lucene_similar_docs): 13 | lucene_similar_docs_file = open(lucene_similar_docs, 'r') 14 | similar_docs = {} 15 | for line in lucene_similar_docs_file.readlines(): 16 | parts = line.split() 17 | if len(parts) > 1: 18 | similar_docs[parts[0]] = parts[1:] 19 | else: 20 | similar_docs[parts[0]] = [] 21 | return similar_docs 22 | 23 | def generate_docs_for_lucene(post_ques_answers, posts, output_dir): 24 | for postId in post_ques_answers: 25 | f = open(os.path.join(output_dir, str(postId) + '.txt'), 'w') 26 | content = ' '.join(posts[postId].title).encode('utf-8') + ' ' + ' '.join(posts[postId].body).encode('utf-8') 27 | f.write(content) 28 | f.close() 29 | 30 | def create_tsv_files(post_data_tsv, qa_data_tsv, post_ques_answers, lucene_similar_posts): 31 | lucene_similar_posts = get_similar_docs(lucene_similar_posts) 32 | similar_posts = {} 33 | for line in lucene_similar_posts.readlines(): 34 | splits = line.strip('\n').split() 35 | if len(splits) < 11: 36 | continue 37 | postId = splits[0] 38 | similar_posts[postId] = splits[1:11] 39 | post_data_tsv_file = open(post_data_tsv, 'w') 40 | post_data_tsv_file.write('postid\ttitle\tpost\n') 41 | qa_data_tsv_file = open(qa_data_tsv, 'w') 42 | qa_data_tsv_file.write('postid\tq1\tq2\tq3\tq4\tq5\tq6\tq7\tq8\tq9\tq10\ta1\ta2\ta3\ta4\ta5\ta6\ta7\ta8\ta9\ta10\n') 43 | for postId in similar_posts: 44 | post_data_tsv_file.write('%s\t%s\t%s\n' % (postId, \ 45 | ' '.join(post_ques_answers[postId].post_title), \ 46 | ' '.join(post_ques_answers[postId].post))) 47 | line = postId 48 | for i in range(10): 49 | line += '\t%s' % ' '.join(post_ques_answers[similar_posts[postId][i]].question_comment) 50 | for i in range(10): 51 | line += '\t%s' % ' '.join(post_ques_answers[similar_posts[postId][i]].answer) 52 | line += '\n' 53 | qa_data_tsv_file.write(line) 54 | 55 | def main(args): 56 | start_time = time.time() 57 | print 'Parsing posts...' 58 | post_parser = PostParser(args.posts_xml) 59 | post_parser.parse() 60 | posts = post_parser.get_posts() 61 | print 'Size: ', len(posts) 62 | print 'Done! Time taken ', time.time() - start_time 63 | print 64 | 65 | start_time = time.time() 66 | print 'Parsing posthistories...' 
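    # Post edit histories are diffed downstream (see post_ques_ans_generator.py): text present in an
    # edited body (PostHistoryTypeId 5) but not in the initial body (PostHistoryTypeId 2) is treated as
    # the post author's answer to a clarification question asked in the comments.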
67 | posthistory_parser = PostHistoryParser(args.posthistory_xml) 68 | posthistory_parser.parse() 69 | posthistories = posthistory_parser.get_posthistories() 70 | print 'Size: ', len(posthistories) 71 | print 'Done! Time taken ', time.time() - start_time 72 | print 73 | 74 | start_time = time.time() 75 | print 'Parsing question comments...' 76 | comment_parser = CommentParser(args.comments_xml) 77 | comment_parser.parse_all_comments() 78 | question_comments = comment_parser.get_question_comments() 79 | all_comments = comment_parser.get_all_comments() 80 | print 'Size: ', len(question_comments) 81 | print 'Done! Time taken ', time.time() - start_time 82 | print 83 | 84 | start_time = time.time() 85 | print 'Loading vocab' 86 | vocab = p.load(open(args.vocab, 'rb')) 87 | print 'Done! Time taken ', time.time() - start_time 88 | print 89 | 90 | start_time = time.time() 91 | print 'Loading word_embeddings' 92 | word_embeddings = p.load(open(args.word_embeddings, 'rb')) 93 | word_embeddings = np.asarray(word_embeddings, dtype=np.float32) 94 | print 'Done! Time taken ', time.time() - start_time 95 | print 96 | 97 | start_time = time.time() 98 | print 'Generating post_ques_ans...' 99 | post_ques_ans_generator = PostQuesAnsGenerator() 100 | post_ques_answers = post_ques_ans_generator.generate(posts, question_comments, all_comments, posthistories, vocab, word_embeddings) 101 | print 'Size: ', len(post_ques_answers) 102 | print 'Done! Time taken ', time.time() - start_time 103 | print 104 | 105 | generate_docs_for_lucene(post_ques_answers, posts, args.lucene_docs_dir) 106 | os.system('cd %s && sh run_lucene.sh %s' % (args.lucene_dir, os.path.dirname(args.post_data_tsv))) 107 | 108 | create_tsv_files(args.post_data_tsv, args.qa_data_tsv, post_ques_answers, args.lucene_similar_posts) 109 | 110 | if __name__ == "__main__": 111 | argparser = argparse.ArgumentParser(sys.argv[0]) 112 | argparser.add_argument("--posts_xml", type = str) 113 | argparser.add_argument("--comments_xml", type = str) 114 | argparser.add_argument("--posthistory_xml", type = str) 115 | argparser.add_argument("--lucene_dir", type = str) 116 | argparser.add_argument("--lucene_docs_dir", type = str) 117 | argparser.add_argument("--lucene_similar_posts", type = str) 118 | argparser.add_argument("--word_embeddings", type = str) 119 | argparser.add_argument("--vocab", type = str) 120 | argparser.add_argument("--no_of_candidates", type = int, default = 10) 121 | argparser.add_argument("--site_name", type = str) 122 | argparser.add_argument("--post_data_tsv", type = str) 123 | argparser.add_argument("--qa_data_tsv", type = str) 124 | args = argparser.parse_args() 125 | print args 126 | print "" 127 | main(args) 128 | 129 | -------------------------------------------------------------------------------- /src/data_generation/helper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from nltk.tokenize import word_tokenize 3 | from nltk.tokenize import sent_tokenize 4 | import nltk 5 | nltk.download('punkt') 6 | import numpy as np 7 | from BeautifulSoup import BeautifulSoup 8 | import re 9 | 10 | def get_tokens(text): 11 | text = BeautifulSoup(text.encode('utf-8').decode('ascii', 'ignore')).text 12 | return word_tokenize(str(text).lower()) 13 | 14 | def get_indices(tokens, vocab): 15 | indices = np.zeros([len(tokens)], dtype=np.int32) 16 | UNK = "" 17 | for i, w in enumerate(tokens): 18 | try: 19 | indices[i] = vocab[w] 20 | except: 21 | indices[i] = vocab[UNK] 22 | return indices 23 | 24 | 
def get_similarity(a_indices, b_indices, word_embeddings): 25 | a_embeddings = [word_embeddings[idx] for idx in a_indices] 26 | b_embeddings = [word_embeddings[idx] for idx in b_indices] 27 | avg_a_embedding = np.mean(a_embeddings, axis=0) 28 | avg_b_embedding = np.mean(b_embeddings, axis=0) 29 | cosine_similarity = np.dot(avg_a_embedding, avg_b_embedding)/(np.linalg.norm(avg_a_embedding) * np.linalg.norm(avg_b_embedding)) 30 | return cosine_similarity 31 | 32 | def remove_urls(text): 33 | r = re.compile(r"(http://[^ ]+)") 34 | text = r.sub("", text) #remove urls so that ? is not identified in urls 35 | r = re.compile(r"(https://[^ ]+)") 36 | text = r.sub("", text) #remove urls so that ? is not identified in urls 37 | r = re.compile(r"(http : //[^ ]+)") 38 | text = r.sub("", text) #remove urls so that ? is not identified in urls 39 | r = re.compile(r"(https : //[^ ]+)") 40 | text = r.sub("", text) #remove urls so that ? is not identified in urls 41 | return text 42 | 43 | def is_too_short_or_long(tokens): 44 | text = ' '.join(tokens) 45 | r = re.compile('[^a-zA-Z ]+') 46 | text = r.sub('', text) 47 | tokens = text.split() 48 | if len(tokens) < 3 or len(tokens) > 100: 49 | return True 50 | return False 51 | -------------------------------------------------------------------------------- /src/data_generation/parse.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import cPickle as p 3 | import xml.etree.ElementTree as ET 4 | from collections import defaultdict 5 | import re 6 | import datetime 7 | import time 8 | import pdb 9 | from helper import * 10 | 11 | class Post: 12 | 13 | def __init__(self, title, body, sents, typeId, accepted_answerId, answer_count, owner_userId, creation_date, parentId, closed_date): 14 | self.title = title 15 | self.body = body 16 | self.sents = sents 17 | self.typeId = typeId 18 | self.accepted_answerId = accepted_answerId 19 | self.answer_count = answer_count 20 | self.owner_userId = owner_userId 21 | self.creation_date = creation_date 22 | self.parentId = parentId 23 | self.closed_date = closed_date 24 | 25 | class PostParser: 26 | 27 | def __init__(self, filename): 28 | self.filename = filename 29 | self.posts = dict() 30 | 31 | def parse(self): 32 | posts_tree = ET.parse(self.filename) 33 | for post in posts_tree.getroot(): 34 | postId = post.attrib['Id'] 35 | postTypeId = int(post.attrib['PostTypeId']) 36 | try: 37 | accepted_answerId = post.attrib['AcceptedAnswerId'] 38 | except: 39 | accepted_answerId = None #non-main posts & unanswered posts don't have accepted_answerId 40 | try: 41 | answer_count = int(post.attrib['AnswerCount']) 42 | except: 43 | answer_count = None #non-main posts don't have answer_count 44 | try: 45 | title = get_tokens(post.attrib['Title']) 46 | except: 47 | title = [] 48 | try: 49 | owner_userId = post.attrib['OwnerUserId'] 50 | except: 51 | owner_userId = None 52 | creation_date = datetime.datetime.strptime(post.attrib['CreationDate'].split('.')[0], "%Y-%m-%dT%H:%M:%S") 53 | try: 54 | closed_date = datetime.datetime.strptime(post.attrib['ClosedDate'].split('.')[0], "%Y-%m-%dT%H:%M:%S") 55 | except: 56 | closed_date = None 57 | if postTypeId == 2: 58 | parentId = post.attrib['ParentId'] 59 | else: 60 | parentId = None 61 | body = get_tokens(post.attrib['Body']) 62 | sent_tokens = get_sent_tokens(post.attrib['Body']) 63 | self.posts[postId] = Post(title, body, sent_tokens, postTypeId, accepted_answerId, answer_count, owner_userId, creation_date, parentId, closed_date) 64 | 65 | 
def get_posts(self): 66 | return self.posts 67 | 68 | class Comment: 69 | 70 | def __init__(self, text, creation_date, userId): 71 | self.text = text 72 | self.creation_date = creation_date 73 | self.userId = userId 74 | 75 | class QuestionComment: 76 | 77 | def __init__(self, text, creation_date, userId): 78 | self.text = text 79 | self.creation_date = creation_date 80 | self.userId = userId 81 | 82 | class CommentParser: 83 | 84 | def __init__(self, filename): 85 | self.filename = filename 86 | self.question_comments = defaultdict(list) 87 | self.question_comment = defaultdict(None) 88 | self.comment = defaultdict(None) 89 | self.all_comments = defaultdict(list) 90 | 91 | def domain_words(self): 92 | return ['duplicate', 'upvote', 'downvote', 'vote', 'related', 'upvoted', 'downvoted', 'edit'] 93 | 94 | def get_question(self, text): 95 | old_text = text 96 | text = remove_urls(text) 97 | if old_text != text: #ignore questions with urls 98 | return None 99 | tokens = get_tokens(text) 100 | lc_text = ' '.join(tokens) 101 | if 'have you' in lc_text or 'did you try' in lc_text or 'can you try' in lc_text or 'could you try' in lc_text: #ignore questions that indirectly provide answer 102 | return None 103 | if '?' in tokens: 104 | parts = " ".join(tokens).split('?') 105 | text = "" 106 | for i in range(len(parts)-1): 107 | text += parts[i]+ ' ?' 108 | words = text.split() 109 | if len(words) > 20: 110 | break 111 | if len(words) > 20: #ignore long comments 112 | return None 113 | for w in self.domain_words(): 114 | if w in words: 115 | return None 116 | if words[0] == '@': 117 | text = words[2:] 118 | else: 119 | text = words 120 | return text 121 | return None 122 | 123 | def get_comment_tokens(self, text): 124 | text = remove_urls(text) 125 | if text == '': 126 | return None 127 | tokens = get_tokens(text) 128 | if tokens == []: 129 | return None 130 | if tokens[0] == '@': 131 | tokens = tokens[2:] 132 | return tokens 133 | 134 | def parse_all_comments(self): 135 | comments_tree = ET.parse(self.filename) 136 | for comment in comments_tree.getroot(): 137 | postId = comment.attrib['PostId'] 138 | text = comment.attrib['Text'] 139 | try: 140 | userId = comment.attrib['UserId'] 141 | except: 142 | userId = None 143 | creation_date = datetime.datetime.strptime(comment.attrib['CreationDate'].split('.')[0], "%Y-%m-%dT%H:%M:%S") 144 | comment_tokens = self.get_comment_tokens(text) 145 | if not comment_tokens: 146 | continue 147 | curr_comment = Comment(comment_tokens, creation_date, userId) 148 | self.all_comments[postId].append(curr_comment) 149 | question = self.get_question(text) 150 | if question: 151 | question_comment = QuestionComment(question, creation_date, userId) 152 | self.question_comments[postId].append(question_comment) 153 | 154 | def get_question_comments(self): 155 | return self.question_comments 156 | 157 | def get_all_comments(self): 158 | return self.all_comments 159 | 160 | class PostHistory: 161 | def __init__(self): 162 | self.initial_post = None 163 | self.initial_post_sents = None 164 | self.edited_posts = [] 165 | self.edit_comments = [] 166 | self.edit_dates = [] 167 | 168 | class PostHistoryParser: 169 | 170 | def __init__(self, filename): 171 | self.filename = filename 172 | self.posthistories = defaultdict(PostHistory) 173 | 174 | def parse(self): 175 | posthistory_tree = ET.parse(self.filename) 176 | for posthistory in posthistory_tree.getroot(): 177 | posthistory_typeid = posthistory.attrib['PostHistoryTypeId'] 178 | postId = posthistory.attrib['PostId'] 179 | if 
posthistory_typeid == '2': 180 | self.posthistories[postId].initial_post = get_tokens(posthistory.attrib['Text']) 181 | self.posthistories[postId].initial_post_sents = get_sent_tokens(posthistory.attrib['Text']) 182 | elif posthistory_typeid == '5': 183 | self.posthistories[postId].edited_posts.append(get_tokens(posthistory.attrib['Text'])) 184 | self.posthistories[postId].edit_comments.append(get_tokens(posthistory.attrib['Comment'])) 185 | self.posthistories[postId].edit_dates.append(datetime.datetime.strptime(posthistory.attrib['CreationDate'].split('.')[0], "%Y-%m-%dT%H:%M:%S")) 186 | #format of date e.g.:"2008-09-06T08:07:10.730" We don't want .730 187 | for postId in self.posthistories.keys(): 188 | if not self.posthistories[postId].edited_posts: 189 | del self.posthistories[postId] 190 | 191 | def get_posthistories(self): 192 | return self.posthistories 193 | 194 | -------------------------------------------------------------------------------- /src/data_generation/post_ques_ans_generator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from helper import * 3 | from collections import defaultdict 4 | from difflib import SequenceMatcher 5 | import pdb 6 | 7 | class PostQuesAns: 8 | 9 | def __init__(self, post_title, post, post_sents, question_comment, answer): 10 | self.post_title = post_title 11 | self.post = post 12 | self.post_sents = post_sents 13 | self.question_comment = question_comment 14 | self.answer = answer 15 | 16 | class PostQuesAnsGenerator: 17 | 18 | def __init__(self): 19 | self.post_ques_ans_dict = defaultdict(PostQuesAns) 20 | 21 | def get_diff(self, initial, final): 22 | s = SequenceMatcher(None, initial, final) 23 | diff = None 24 | for tag, i1, i2, j1, j2 in s.get_opcodes(): 25 | if tag == 'insert': 26 | diff = final[j1:j2] 27 | if not diff: 28 | return None 29 | return diff 30 | 31 | def find_first_question(self, answer, question_comment_candidates, vocab, word_embeddings): 32 | first_question = None 33 | first_date = None 34 | for question_comment in question_comment_candidates: 35 | if first_question == None: 36 | first_question = question_comment 37 | first_date = question_comment.creation_date 38 | else: 39 | if question_comment.creation_date < first_date: 40 | first_question = question_comment 41 | first_date = question_comment.creation_date 42 | return first_question 43 | 44 | def find_answer_comment(self, all_comments, question_comment, post_userId): 45 | answer_comment, answer_comment_date = None, None 46 | for comment in all_comments: 47 | if comment.userId and comment.userId == post_userId: 48 | if comment.creation_date > question_comment.creation_date: 49 | if not answer_comment or (comment.creation_date < answer_comment_date): 50 | answer_comment = comment 51 | answer_comment_date = comment.creation_date 52 | return answer_comment 53 | 54 | def generate_using_comments(self, posts, question_comments, all_comments, vocab, word_embeddings): 55 | for postId in posts.keys(): 56 | if postId in self.post_ques_ans_dict.keys(): 57 | continue 58 | if posts[postId].typeId != 1: # is not a main post 59 | continue 60 | first_question = None 61 | first_date = None 62 | for question_comment in question_comments[postId]: 63 | if question_comment.userId and question_comment.userId == posts[postId].owner_userId: 64 | continue #Ignore comments by the original author of the post 65 | if first_question == None: 66 | first_question = question_comment 67 | first_date = question_comment.creation_date 68 | else: 69 | if 
question_comment.creation_date < first_date: 70 | first_question = question_comment 71 | first_date = question_comment.creation_date 72 | question = first_question 73 | if not question: 74 | continue 75 | answer_comment = self.find_answer_comment(all_comments[postId], question, posts[postId].owner_userId) 76 | if not answer_comment: 77 | continue 78 | answer = answer_comment 79 | self.post_ques_ans_dict[postId] = PostQuesAns(posts[postId].title, posts[postId].body, \ 80 | posts[postId].sents, question.text, answer.text) 81 | 82 | def generate(self, posts, question_comments, all_comments, posthistories, vocab, word_embeddings): 83 | for postId, posthistory in posthistories.iteritems(): 84 | if not posthistory.edited_posts: 85 | continue 86 | if posts[postId].typeId != 1: # is not a main post 87 | continue 88 | if not posthistory.initial_post: 89 | continue 90 | first_edit_date, first_question, first_answer = None, None, None 91 | for i in range(len(posthistory.edited_posts)): 92 | answer = self.get_diff(posthistory.initial_post, posthistory.edited_posts[i]) 93 | if not answer: 94 | continue 95 | else: 96 | answer = remove_urls(' '.join(answer)) 97 | answer = answer.split() 98 | if is_too_short_or_long(answer): 99 | continue 100 | question_comment_candidates = [] 101 | for comment in question_comments[postId]: 102 | if comment.userId and comment.userId == posts[postId].owner_userId: 103 | continue #Ignore comments by the original author of the post 104 | if comment.creation_date > posthistory.edit_dates[i]: 105 | continue #Ignore comments added after the edit 106 | else: 107 | question_comment_candidates.append(comment) 108 | question = self.find_first_question(answer, question_comment_candidates, vocab, word_embeddings) 109 | if not question: 110 | continue 111 | answer_comment = self.find_answer_comment(all_comments[postId], question, posts[postId].owner_userId) 112 | if answer_comment and 'edit' not in answer_comment.text: #prefer edit if comment points to the edit 113 | question_indices = get_indices(question.text, vocab) 114 | answer_indices = get_indices(answer, vocab) 115 | answer_comment_indices = get_indices(answer_comment.text, vocab) 116 | if get_similarity(question_indices, answer_comment_indices, word_embeddings) > get_similarity(question_indices, answer_indices, word_embeddings): 117 | answer = answer_comment.text 118 | if first_edit_date == None or posthistory.edit_dates[i] < first_edit_date: 119 | first_question, first_answer, first_edit_date = question, answer, posthistory.edit_dates[i] 120 | 121 | if not first_question: 122 | continue 123 | self.post_ques_ans_dict[postId] = PostQuesAns(posts[postId].title, posthistory.initial_post, \ 124 | posthistory.initial_post_sents, first_question.text, first_answer) 125 | 126 | self.generate_using_comments(posts, question_comments, all_comments, vocab, word_embeddings) 127 | return self.post_ques_ans_dict 128 | -------------------------------------------------------------------------------- /src/data_generation/run_data_generator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATADUMP_DIR=stackexchange #Directory containing xml files 4 | EMB_DIR=embeddings 5 | DATA_DIR=data 6 | SITE_NAME=askubuntu.com 7 | #SITE_NAME=unix.stackexchange.com 8 | #SITE_NAME=superuser.com 9 | 10 | SCRIPTS_DIR=ranking_clarification_questions/src/data_generation 11 | LUCENE_DIR=ranking_clarification_questions/lucene 12 | 13 | mkdir -p $DATA_DIR/$SITE_NAME 14 | 15 | rm -r 
$DATA_DIR/$SITE_NAME/post_docs 16 | rm -r $DATA_DIR/$SITE_NAME/post_doc_indices 17 | mkdir -p $DATA_DIR/$SITE_NAME/post_docs 18 | 19 | rm -r $DATA_DIR/$SITE_NAME/ques_docs 20 | rm -r $DATA_DIR/$SITE_NAME/ques_doc_indices 21 | mkdir -p $DATA_DIR/$SITE_NAME/ques_docs 22 | 23 | python $SCRIPTS_DIR/data_generator.py --posts_xml $DATADUMP_DIR/$SITE_NAME/Posts.xml \ 24 | --comments_xml $DATADUMP_DIR/$SITE_NAME/Comments.xml \ 25 | --posthistory_xml $DATADUMP_DIR/$SITE_NAME/PostHistory.xml \ 26 | --word_embeddings $EMB_DIR/word_embeddings.p \ 27 | --vocab $EMB_DIR/vocab.p \ 28 | --lucene_dir $LUCENE_DIR \ 29 | --lucene_docs_dir $DATA_DIR/$SITE_NAME/post_docs \ 30 | --lucene_similar_posts $DATA_DIR/$SITE_NAME/lucene_similar_posts.txt \ 31 | --site_name $SITE_NAME \ 32 | --post_data_tsv $DATA_DIR/$SITE_NAME/post_data.tsv 33 | --qa_data_tsv $DATA_DIR/$SITE_NAME/qa_data.tsv 34 | 35 | -------------------------------------------------------------------------------- /src/embedding_generation/README.md: -------------------------------------------------------------------------------- 1 | # Prerequisite 2 | 3 | * Download and compile GLoVE code from: https://nlp.stanford.edu/projects/glove/ 4 | 5 | # NOTE 6 | 7 | * Word embeddings pretrained on stackexchange datadump (version year 2017) can be found here: https://go.umd.edu/stackexchange_embeddings 8 | 9 | # Steps to retrain word embeddings 10 | 11 | 1. Download all domains of stackexchange from: https://archive.org/download/stackexchange 12 | 2. Extract text from all Posts.xml, Comments.xml and PostHistory.xml 13 | 3. Save the combined data under: stackexchange/stackexchange_datadump.txt 14 | 4. cd ranking_clarification_questions; sh src/embedding_generation/run_glove.sh 15 | 5. cd ranking_clarification_questions; sh src/embedding_generation/run_create_we_vocab.sh 16 | -------------------------------------------------------------------------------- /src/embedding_generation/create_we_vocab.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import cPickle as p 3 | 4 | if __name__ == "__main__": 5 | if len(sys.argv) < 4: 6 | print "usage: python create_we_vocab.py " 7 | sys.exit(0) 8 | word_vectors_file = open(sys.argv[1], 'r') 9 | word_embeddings = [] 10 | vocab = {} 11 | i = 0 12 | for line in word_vectors_file.readlines(): 13 | vals = line.rstrip().split(' ') 14 | vocab[vals[0]] = i 15 | word_embeddings.append(map(float, vals[1:])) 16 | i += 1 17 | p.dump(word_embeddings, open(sys.argv[2], 'wb')) 18 | p.dump(vocab, open(sys.argv[3], 'wb')) 19 | 20 | -------------------------------------------------------------------------------- /src/embedding_generation/run_create_we_vocab.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPTS_DIR=ranking_clarification_questions/src/embedding_generation 4 | EMB_DIR=ranking_clarification_questions/embeddings/ 5 | 6 | python $SCRIPTS_DIR/create_we_vocab.py $EMB_DIR/vectors.txt $EMB_DIR/word_embeddings.p $EMB_DIR/vocab.p 7 | 8 | -------------------------------------------------------------------------------- /src/embedding_generation/run_glove.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Makes programs, downloads sample data, trains a GloVe model, and then evaluates it. 
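# (The header comment above is carried over from the GloVe demo script; this adaptation assumes GloVe has
# already been compiled under GloVe-1.2/build and it trains only on the combined stackexchange corpus set
# in CORPUS below; nothing is downloaded or evaluated here.)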
4 | # One optional argument can specify the language used for eval script: matlab, octave or [default] python 5 | 6 | DATADIR=embeddings 7 | 8 | CORPUS=stackexchange/stackexchange_datadump.txt 9 | VOCAB_FILE=$DATADIR/vocab.txt 10 | COOCCURRENCE_FILE=$DATADIR/cooccurrence.bin 11 | COOCCURRENCE_SHUF_FILE=$DATADIR/cooccurrence.shuf.bin 12 | BUILDDIR=GloVe-1.2/build 13 | SAVE_FILE=$DATADIR/vectors 14 | VERBOSE=2 15 | MEMORY=4.0 16 | VOCAB_MIN_COUNT=100 17 | VECTOR_SIZE=200 18 | MAX_ITER=30 19 | WINDOW_SIZE=15 20 | BINARY=2 21 | NUM_THREADS=4 22 | X_MAX=10 23 | 24 | $BUILDDIR/vocab_count -min-count $VOCAB_MIN_COUNT -verbose $VERBOSE < $CORPUS > $VOCAB_FILE 25 | $BUILDDIR/cooccur -memory $MEMORY -vocab-file $VOCAB_FILE -verbose $VERBOSE -window-size $WINDOW_SIZE < $CORPUS > $COOCCURRENCE_FILE 26 | $BUILDDIR/shuffle -memory $MEMORY -verbose $VERBOSE < $COOCCURRENCE_FILE > $COOCCURRENCE_SHUF_FILE 27 | $BUILDDIR/glove -save-file $SAVE_FILE -eta 0.05 -threads $NUM_THREADS -input-file $COOCCURRENCE_SHUF_FILE -x-max $X_MAX -iter $MAX_ITER -vector-size $VECTOR_SIZE -binary $BINARY -vocab-file $VOCAB_FILE -verbose $VERBOSE 28 | -------------------------------------------------------------------------------- /src/evaluation/evaluate_model_with_human_annotations.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pdb 4 | import random 5 | from sklearn.metrics import average_precision_score 6 | import sys 7 | 8 | BAD_QUESTIONS="unix_56867 unix_136954 unix_160510 unix_138507".split() + \ 9 | "askubuntu_791945 askubuntu_91332 askubuntu_704807 askubuntu_628216 askubuntu_688172 askubuntu_727993".split() + \ 10 | "askubuntu_279488 askubuntu_624918 askubuntu_527314 askubuntu_182249 askubuntu_610081 askubuntu_613851 askubuntu_777774 askubuntu_624498".split() + \ 11 | "superuser_356658 superuser_121201 superuser_455589 superuser_38460 superuser_739955 superuser_931151".split() + \ 12 | "superuser_291105 superuser_627439 superuser_584013 superuser_399182 superuser_632675 superuser_706347".split() + \ 13 | "superuser_670748 superuser_369878 superuser_830279 superuser_927242 superuser_850786".split() 14 | 15 | def get_annotations(line): 16 | set_info, post_id, best, valids, confidence = line.split(',') 17 | annotator_name = set_info.split('_')[0] 18 | sitename = set_info.split('_')[1] 19 | best = int(best) 20 | valids = [int(v) for v in valids.split()] 21 | confidence = int(confidence) 22 | return post_id, annotator_name, sitename, best, valids, confidence 23 | 24 | def calculate_precision(model_pred_indices, best, valids): 25 | bp1, bp3, bp5 = 0., 0., 0. 26 | vp1, vp3, vp5 = 0., 0., 0. 
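    # Precision@k (k = 1, 3, 5): the fraction of the model's top-k ranked candidates that the
    # annotators marked as "best" (bp*) or as "valid" (vp*).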
27 | bp1 = len(set(model_pred_indices[:1]).intersection(set(best)))*1.0 28 | bp3 = len(set(model_pred_indices[:3]).intersection(set(best)))*1.0/3 29 | bp5 = len(set(model_pred_indices[:5]).intersection(set(best)))*1.0/5 30 | 31 | vp1 = len(set(model_pred_indices[:1]).intersection(set(valids)))*1.0 32 | vp3 = len(set(model_pred_indices[:3]).intersection(set(valids)))*1.0/3 33 | vp5 = len(set(model_pred_indices[:5]).intersection(set(valids)))*1.0/5 34 | return bp1, bp3, bp5, vp1, vp3, vp5 35 | 36 | def calculate_avg_precision(model_probs, best, valids): 37 | best_bool = [0]*10 38 | valids_bool = [0]*10 39 | for i in range(10): 40 | if i in best: 41 | best_bool[i] = 1 42 | if i in valids: 43 | valids_bool[i] = 1 44 | bap = average_precision_score(best_bool, model_probs) 45 | if 1 in valids_bool: 46 | vap = average_precision_score(valids_bool, model_probs) 47 | else: 48 | vap = 0. 49 | if 1 in best_bool[1:]: 50 | bap_on9 = average_precision_score(best_bool[1:], model_probs[1:]) 51 | else: 52 | bap_on9 = 0. 53 | if 1 in valids_bool[1:]: 54 | vap_on9 = average_precision_score(valids_bool[1:], model_probs[1:]) 55 | else: 56 | vap_on9 = 0. 57 | return bap, vap, bap_on9, vap_on9 58 | 59 | def get_pred_indices(model_predictions, asc=False): 60 | preds = np.array(model_predictions) 61 | pred_indices = np.argsort(preds) 62 | if not asc: 63 | pred_indices = pred_indices[::-1] #since ascending sort and we want descending 64 | return pred_indices 65 | 66 | def convert_to_probalitites(model_predictions): 67 | tot = sum(model_predictions) 68 | model_probs = [0]*10 69 | for i,v in enumerate(model_predictions): 70 | model_probs[i] = v*1./tot 71 | return model_probs 72 | 73 | def evaluate_model_on_org(model_predictions, asc=False): 74 | br1_tot, br3_tot, br5_tot = 0., 0., 0. 75 | for post_id in model_predictions: 76 | model_probs = convert_to_probalitites(model_predictions[post_id]) 77 | model_pred_indices = get_pred_indices(model_probs, asc) 78 | br1, br3, br5, vr1, vr3, vr5 = calculate_precision(model_pred_indices, [0], [0]) 79 | br1_tot += br1 80 | br3_tot += br3 81 | br5_tot += br5 82 | N = len(model_predictions) 83 | return br1_tot*100.0/N, br3_tot*100.0/N, br5_tot*100.0/N 84 | 85 | def read_human_annotations(human_annotations_filename): 86 | human_annotations_file = open(human_annotations_filename, 'r') 87 | annotations = {} 88 | for line in human_annotations_file.readlines(): 89 | line = line.strip('\n') 90 | splits = line.split('\t') 91 | post_id1, annotator_name1, sitename1, best1, valids1, confidence1 = get_annotations(splits[0]) 92 | post_id2, annotator_name2, sitename2, best2, valids2, confidence2 = get_annotations(splits[1]) 93 | assert(sitename1 == sitename2) 94 | assert(post_id1 == post_id2) 95 | post_id = sitename1+'_'+post_id1 96 | best_union = list(set([best1, best2])) 97 | valids_inter = list(set(valids1).intersection(set(valids2))) 98 | annotations[post_id] = (best_union, valids_inter) 99 | return annotations 100 | 101 | def evaluate_model(human_annotations_filename, model_predictions, asc=False): 102 | human_annotations_file = open(human_annotations_filename, 'r') 103 | br1_tot, br3_tot, br5_tot = 0., 0., 0. 104 | vr1_tot, vr3_tot, vr5_tot = 0., 0., 0. 105 | br1_on9_tot, br3_on9_tot, br5_on9_tot = 0., 0., 0. 106 | vr1_on9_tot, vr3_on9_tot, vr5_on9_tot = 0., 0., 0. 107 | bap_tot, vap_tot = 0., 0. 108 | bap_on9_tot, vap_on9_tot = 0., 0. 
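    # Totals suffixed _on9 are accumulated after dropping candidate index 0, which evaluate_model_on_org
    # treats as the question originally asked on the post; they score the ranking of the remaining 9
    # candidates only.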
109 | N = 0 110 | for line in human_annotations_file.readlines(): 111 | line = line.strip('\n') 112 | splits = line.split('\t') 113 | post_id1, annotator_name1, sitename1, best1, valids1, confidence1 = get_annotations(splits[0]) 114 | post_id2, annotator_name2, sitename2, best2, valids2, confidence2 = get_annotations(splits[1]) 115 | assert(sitename1 == sitename2) 116 | assert(post_id1 == post_id2) 117 | post_id = sitename1+'_'+post_id1 118 | if post_id in BAD_QUESTIONS: 119 | continue 120 | best_union = list(set([best1, best2])) 121 | valids_inter = list(set(valids1).intersection(set(valids2))) 122 | valids_union = list(set(valids1+valids2)) 123 | 124 | model_probs = convert_to_probalitites(model_predictions[post_id]) 125 | model_pred_indices = get_pred_indices(model_probs, asc) 126 | 127 | br1, br3, br5, vr1, vr3, vr5 = calculate_precision(model_pred_indices, best_union, valids_inter) 128 | br1_tot += br1 129 | br3_tot += br3 130 | br5_tot += br5 131 | vr1_tot += vr1 132 | vr3_tot += vr3 133 | vr5_tot += vr5 134 | bap, vap, bap_on9, vap_on9 = calculate_avg_precision(model_probs, best_union, valids_inter) 135 | bap_tot += bap 136 | vap_tot += vap 137 | bap_on9_tot += bap_on9 138 | vap_on9_tot += vap_on9 139 | 140 | model_pred_indices = np.delete(model_pred_indices, 0) 141 | 142 | br1_on9, br3_on9, br5_on9, vr1_on9, vr3_on9, vr5_on9 = calculate_precision(model_pred_indices, best_union, valids_inter) 143 | 144 | br1_on9_tot += br1_on9 145 | br3_on9_tot += br3_on9 146 | br5_on9_tot += br5_on9 147 | vr1_on9_tot += vr1_on9 148 | vr3_on9_tot += vr3_on9 149 | vr5_on9_tot += vr5_on9 150 | 151 | N += 1 152 | 153 | human_annotations_file.close() 154 | return br1_tot*100.0/N, br3_tot*100.0/N, br5_tot*100.0/N, vr1_tot*100.0/N, vr3_tot*100.0/N, vr5_tot*100.0/N, \ 155 | br1_on9_tot*100.0/N, br3_on9_tot*100.0/N, br5_on9_tot*100.0/N, vr1_on9_tot*100.0/N, vr3_on9_tot*100.0/N, vr5_on9_tot*100.0/N, \ 156 | bap_tot*100./N, vap_tot*100./N, bap_on9_tot*100./N, vap_on9_tot*100./N 157 | 158 | def read_model_predictions(model_predictions_file): 159 | model_predictions = {} 160 | for line in model_predictions_file.readlines(): 161 | splits = line.strip('\n').split() 162 | post_id = splits[0][1:-2] 163 | predictions = [float(val) for val in splits[1:]] 164 | model_predictions[post_id] = predictions 165 | return model_predictions 166 | 167 | def print_numbers(br1, br3, br5, vr1, vr3, vr5, br1_on9, br3_on9, br5_on9, vr1_on9, vr3_on9, vr5_on9, bmap, vmap, bmap_on9, vmap_on9, br1_org, br3_org, br5_org): 168 | print 'Best' 169 | print 'p@1 %.1f' % (br1) 170 | print 'p@3 %.1f' % (br3) 171 | print 'p@5 %.1f' % (br5) 172 | print 'map %.1f' % (bmap) 173 | print 174 | print 'Valid' 175 | print 'p@1 %.1f' % (vr1) 176 | print 'p@3 %.1f' % (vr3) 177 | print 'p@5 %.1f' % (vr5) 178 | print 'map %.1f' % (vmap) 179 | print 180 | print 'Best on 9' 181 | print 'p@1 %.1f' % (br1_on9) 182 | print 'p@3 %.1f' % (br3_on9) 183 | print 'p@5 %.1f' % (br5_on9) 184 | print 'map %.1f' % (bmap_on9) 185 | print 186 | print 'Valid on 9' 187 | print 'p@1 %.1f' % (vr1_on9) 188 | print 'p@3 %.1f' % (vr3_on9) 189 | print 'p@5 %.1f' % (vr5_on9) 190 | print 'map %.1f' % (vmap_on9) 191 | print 192 | print 'Original' 193 | print 'p@1 %.1f' % (br1_org) 194 | #print 'p@3 %.1f' % (br3_org) 195 | #print 'p@5 %.1f' % (br5_org) 196 | 197 | def main(args): 198 | model_predictions_file = open(args.model_predictions_filename, 'r') 199 | asc=False 200 | model_predictions = read_model_predictions(model_predictions_file) 201 | br1, br3, br5, vr1, vr3, vr5, \ 202 | 
br1_on9, br3_on9, br5_on9, \ 203 | vr1_on9, vr3_on9, vr5_on9, \ 204 | bmap, vmap, bmap_on9, vmap_on9 = evaluate_model(args.human_annotations_filename, model_predictions, asc) 205 | br1_org, br3_org, br5_org = evaluate_model_on_org(model_predictions) 206 | 207 | print_numbers(br1, br3, br5, vr1, vr3, vr5, \ 208 | br1_on9, br3_on9, br5_on9, \ 209 | vr1_on9, vr3_on9, vr5_on9, \ 210 | bmap, vmap, bmap_on9, vmap_on9, \ 211 | br1_org, br3_org, br5_org) 212 | 213 | if __name__ == '__main__': 214 | argparser = argparse.ArgumentParser(sys.argv[0]) 215 | argparser.add_argument("--human_annotations_filename", type = str) 216 | argparser.add_argument("--model_predictions_filename", type = str) 217 | args = argparser.parse_args() 218 | print args 219 | print "" 220 | main(args) 221 | 222 | -------------------------------------------------------------------------------- /src/evaluation/run_evaluation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_DIR=data 4 | #SITE_NAME=askubuntu.com 5 | #SITE_NAME=unix.stackexchange.com 6 | #SITE_NAME=superuser.com 7 | SITE_NAME=askubuntu_unix_superuser 8 | 9 | SCRIPTS_DIR=src/evaluation 10 | #MODEL=baseline_pq 11 | #MODEL=baseline_pa 12 | #MODEL=baseline_pqa 13 | MODEL=evpi 14 | 15 | python $SCRIPTS_DIR/evaluate_model_with_human_annotations.py \ 16 | --human_annotations_filename $DATA_DIR/$SITE_NAME/human_annotations \ 17 | --model_predictions_filename $DATA_DIR/$SITE_NAME/test_predictions_${MODEL}.out.epoch13 \ 18 | -------------------------------------------------------------------------------- /src/models/README.md: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | 3 | * Install lasagne: http://lasagne.readthedocs.io/en/latest/user/installation.html 4 | * Install numpy, scipy 5 | * Version information: 6 | 7 | Python 2.7.5 8 | 9 | Theano 0.9.0dev5 10 | 11 | Lasagne 0.2.dev1 12 | 13 | Cuda 8.0.44 14 | 15 | Cudnn 5.1 16 | 17 | # Loading data 18 | 19 | Load data from askubuntu.com 20 | 21 | * Set "SITE_NAME=askubuntu.com" in ranking_clarification_questions/src/models/run_load_data.sh 22 | * cd ranking_clarification_questions; sh src/models/run_load_data.sh 23 | 24 | Load data from unix.stackexchange.com 25 | 26 | * Set "SITE_NAME=unix.stackexchange.com" in ranking_clarification_questions/src/models/run_load_data.sh 27 | * cd ranking_clarification_questions; sh src/models/run_load_data.sh 28 | 29 | Load data from superuser.com 30 | 31 | * Set "SITE_NAME=superuser.com" in ranking_clarification_questions/src/models/run_load_data.sh 32 | * cd ranking_clarification_questions; sh src/models/run_load_data.sh 33 | 34 | Combine data from three domains 35 | 36 | * cd ranking_clarification_questions; sh src/models/run_combine_domains.sh 37 | * cat data/askubuntu.com/human_annotations data/unix.stackexchange.com/human_annotations data/superuser.com/human_annotations > askubuntu_unix_superuser/human_annotations 38 | 39 | # Running neural baselines on the combined data 40 | 41 | Neural(p,q) 42 | 43 | * Set "MODEL=baseline_pq" in ranking_clarification_questions/src/models/run_main.sh 44 | * cd ranking_clarification_questions; sh src/models/run_main.sh 45 | 46 | Neural(p,a) 47 | 48 | * Set "MODEL=baseline_pa" in ranking_clarification_questions/src/models/run_main.sh 49 | * cd ranking_clarification_questions; sh src/models/run_main.sh 50 | 51 | Neural(p,q,a) 52 | 53 | * Set "MODEL=baseline_pqa" in ranking_clarification_questions/src/models/run_main.sh 54 | * cd 
ranking_clarification_questions; sh src/models/run_main.sh 55 | 56 | # Runing EVPI model on the combined data 57 | 58 | * Set "MODEL=evpi" in ranking_clarification_questions/src/models/run_main.sh 59 | * cd ranking_clarification_questions; sh src/models/run_main.sh 60 | 61 | # Runing evaluation 62 | 63 | * cd ranking_clarification_questions; sh src/evaluation/run_evaluation.sh 64 | -------------------------------------------------------------------------------- /src/models/baseline_pa.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import theano, lasagne 4 | import numpy as np 5 | import cPickle as p 6 | import theano.tensor as T 7 | from collections import Counter 8 | import pdb 9 | import time 10 | import random, math 11 | DEPTH = 5 12 | from lstm_helper import * 13 | from model_helper import * 14 | 15 | def build(word_embeddings, len_voc, word_emb_dim, args, freeze=False): 16 | 17 | # input theano vars 18 | posts = T.imatrix() 19 | post_masks = T.fmatrix() 20 | ans_list = T.itensor3() 21 | ans_masks_list = T.ftensor3() 22 | labels = T.imatrix() 23 | N = args.no_of_candidates 24 | 25 | post_out, post_lstm_params = build_lstm(posts, post_masks, args.post_max_len, \ 26 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 27 | ans_out, ans_emb_out, ans_lstm_params = build_list_lstm(ans_list, ans_masks_list, N, args.ans_max_len, \ 28 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 29 | 30 | pa_preds = [None]*N 31 | post_ans = T.concatenate([post_out, ans_out[0]], axis=1) 32 | l_post_ans_in = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ans) 33 | l_post_ans_denses = [None]*DEPTH 34 | for k in range(DEPTH): 35 | if k == 0: 36 | l_post_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ans_in, num_units=args.hidden_dim,\ 37 | nonlinearity=lasagne.nonlinearities.rectify) 38 | else: 39 | l_post_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ans_denses[k-1], num_units=args.hidden_dim,\ 40 | nonlinearity=lasagne.nonlinearities.rectify) 41 | l_post_ans_dense = lasagne.layers.DenseLayer(l_post_ans_denses[-1], num_units=1,\ 42 | nonlinearity=lasagne.nonlinearities.sigmoid) 43 | pa_preds[0] = lasagne.layers.get_output(l_post_ans_dense) 44 | loss = T.sum(lasagne.objectives.binary_crossentropy(pa_preds[0], labels[:,0])) 45 | for i in range(1, N): 46 | post_ans = T.concatenate([post_out, ans_out[i]], axis=1) 47 | l_post_ans_in_ = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ans) 48 | for k in range(DEPTH): 49 | if k == 0: 50 | l_post_ans_dense_ = lasagne.layers.DenseLayer(l_post_ans_in_, num_units=args.hidden_dim,\ 51 | nonlinearity=lasagne.nonlinearities.rectify,\ 52 | W=l_post_ans_denses[k].W,\ 53 | b=l_post_ans_denses[k].b) 54 | else: 55 | l_post_ans_dense_ = lasagne.layers.DenseLayer(l_post_ans_dense_, num_units=args.hidden_dim,\ 56 | nonlinearity=lasagne.nonlinearities.rectify,\ 57 | W=l_post_ans_denses[k].W,\ 58 | b=l_post_ans_denses[k].b) 59 | l_post_ans_dense_ = lasagne.layers.DenseLayer(l_post_ans_dense_, num_units=1,\ 60 | nonlinearity=lasagne.nonlinearities.sigmoid) 61 | pa_preds[i] = lasagne.layers.get_output(l_post_ans_dense_) 62 | loss += T.sum(lasagne.objectives.binary_crossentropy(pa_preds[i], labels[:,i])) 63 | 64 | post_ans_dense_params = lasagne.layers.get_all_params(l_post_ans_dense, trainable=True) 65 | 66 | all_params = post_lstm_params + ans_lstm_params + post_ans_dense_params 67 
| #print 'Params in concat ', lasagne.layers.count_params(l_post_ans_dense) 68 | loss += args.rho * sum(T.sum(l ** 2) for l in all_params) 69 | 70 | updates = lasagne.updates.adam(loss, all_params, learning_rate=args.learning_rate) 71 | 72 | train_fn = theano.function([posts, post_masks, ans_list, ans_masks_list, labels], \ 73 | [loss] + pa_preds, updates=updates) 74 | test_fn = theano.function([posts, post_masks, ans_list, ans_masks_list, labels], \ 75 | [loss] + pa_preds,) 76 | return train_fn, test_fn 77 | 78 | def validate(val_fn, fold_name, epoch, fold, args, out_file=None): 79 | start = time.time() 80 | num_batches = 0 81 | cost = 0 82 | corr = 0 83 | mrr = 0 84 | total = 0 85 | _lambda = 0.5 86 | N = args.no_of_candidates 87 | recall = [0]*N 88 | 89 | if out_file: 90 | out_file_o = open(out_file+'.epoch%d' % epoch, 'w') 91 | out_file_o.close() 92 | posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, post_ids = fold 93 | labels = np.zeros((len(post_ids), N), dtype=np.int32) 94 | ranks = np.zeros((len(post_ids), N), dtype=np.int32) 95 | labels[:,0] = 1 96 | for j in range(N): 97 | ranks[:,j] = j 98 | ques_list, ques_masks_list, ans_list, ans_masks_list, labels, ranks = shuffle(ques_list, ques_masks_list, \ 99 | ans_list, ans_masks_list, labels, ranks) 100 | for p, pm, q, qm, a, am, l, r, ids in iterate_minibatches(posts, post_masks, ques_list, ques_masks_list, \ 101 | ans_list, ans_masks_list, labels, ranks, \ 102 | post_ids, args.batch_size, shuffle=False): 103 | a = np.transpose(a, (1, 0, 2)) 104 | am = np.transpose(am, (1, 0, 2)) 105 | 106 | pa_out = val_fn(p, pm, a, am, l) 107 | loss = pa_out[0] 108 | pa_preds = pa_out[1:] 109 | pa_preds = np.transpose(pa_preds, (1, 0, 2)) 110 | pa_preds = pa_preds[:,:,0] 111 | cost += loss 112 | for j in range(args.batch_size): 113 | preds = [0.0]*N 114 | for k in range(N): 115 | preds[k] = pa_preds[j][k] 116 | rank = get_rank(preds, l[j]) 117 | if rank == 1: 118 | corr += 1 119 | mrr += 1.0/rank 120 | for index in range(N): 121 | if rank <= index+1: 122 | recall[index] += 1 123 | total += 1 124 | if out_file: 125 | write_test_predictions(out_file, ids[j], preds, r[j], epoch) 126 | num_batches += 1 127 | 128 | lstring = '%s: epoch:%d, cost:%f, acc:%f, mrr:%f,time:%d' % \ 129 | (fold_name, epoch, cost*1.0/num_batches, corr*1.0/total, mrr*1.0/total, time.time()-start) 130 | 131 | recall = [round(curr_r*1.0/total, 3) for curr_r in recall] 132 | recall_str = '[' 133 | for r in recall: 134 | recall_str += '%.3f ' % r 135 | recall_str += ']\n' 136 | 137 | print lstring 138 | print recall 139 | 140 | def baseline_pa(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test): 141 | start = time.time() 142 | print 'compiling pq graph...' 143 | train_fn, test_fn, = build(word_embeddings, vocab_size, word_emb_dim, args, freeze=freeze) 144 | print 'done! 
Time taken: ', time.time()-start 145 | 146 | # train network 147 | for epoch in range(args.no_of_epochs): 148 | validate(train_fn, 'TRAIN', epoch, train, args) 149 | validate(test_fn, '\t TEST', epoch, test, args, args.test_predictions_output) 150 | print "\n" 151 | -------------------------------------------------------------------------------- /src/models/baseline_pq.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import theano, lasagne 4 | import numpy as np 5 | import cPickle as p 6 | import theano.tensor as T 7 | from collections import Counter 8 | import pdb 9 | import time 10 | import random, math 11 | DEPTH = 5 12 | from lstm_helper import * 13 | from model_helper import * 14 | 15 | def build(word_embeddings, len_voc, word_emb_dim, args, freeze=False): 16 | 17 | # input theano vars 18 | posts = T.imatrix() 19 | post_masks = T.fmatrix() 20 | ques_list = T.itensor3() 21 | ques_masks_list = T.ftensor3() 22 | labels = T.imatrix() 23 | N = args.no_of_candidates 24 | 25 | post_out, post_lstm_params = build_lstm(posts, post_masks, args.post_max_len, \ 26 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 27 | ques_out, ques_emb_out, ques_lstm_params = build_list_lstm(ques_list, ques_masks_list, N, args.ques_max_len, \ 28 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 29 | 30 | pq_preds = [None]*N 31 | post_ques = T.concatenate([post_out, ques_out[0]], axis=1) 32 | l_post_ques_in = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ques) 33 | l_post_ques_denses = [None]*DEPTH 34 | for k in range(DEPTH): 35 | if k == 0: 36 | l_post_ques_denses[k] = lasagne.layers.DenseLayer(l_post_ques_in, num_units=args.hidden_dim,\ 37 | nonlinearity=lasagne.nonlinearities.rectify) 38 | else: 39 | l_post_ques_denses[k] = lasagne.layers.DenseLayer(l_post_ques_denses[k-1], num_units=args.hidden_dim,\ 40 | nonlinearity=lasagne.nonlinearities.rectify) 41 | l_post_ques_dense = lasagne.layers.DenseLayer(l_post_ques_denses[-1], num_units=1,\ 42 | nonlinearity=lasagne.nonlinearities.sigmoid) 43 | pq_preds[0] = lasagne.layers.get_output(l_post_ques_dense) 44 | loss = T.sum(lasagne.objectives.binary_crossentropy(pq_preds[0], labels[:,0])) 45 | for i in range(1, N): 46 | post_ques = T.concatenate([post_out, ques_out[i]], axis=1) 47 | l_post_ques_in_ = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ques) 48 | for k in range(DEPTH): 49 | if k == 0: 50 | l_post_ques_dense_ = lasagne.layers.DenseLayer(l_post_ques_in_, num_units=args.hidden_dim,\ 51 | nonlinearity=lasagne.nonlinearities.rectify,\ 52 | W=l_post_ques_denses[k].W,\ 53 | b=l_post_ques_denses[k].b) 54 | else: 55 | l_post_ques_dense_ = lasagne.layers.DenseLayer(l_post_ques_dense_, num_units=args.hidden_dim,\ 56 | nonlinearity=lasagne.nonlinearities.rectify,\ 57 | W=l_post_ques_denses[k].W,\ 58 | b=l_post_ques_denses[k].b) 59 | l_post_ques_dense_ = lasagne.layers.DenseLayer(l_post_ques_dense_, num_units=1,\ 60 | nonlinearity=lasagne.nonlinearities.sigmoid) 61 | pq_preds[i] = lasagne.layers.get_output(l_post_ques_dense_) 62 | loss += T.sum(lasagne.objectives.binary_crossentropy(pq_preds[i], labels[:,i])) 63 | 64 | post_ques_dense_params = lasagne.layers.get_all_params(l_post_ques_dense, trainable=True) 65 | 66 | all_params = post_lstm_params + ques_lstm_params + post_ques_dense_params 67 | #print 'Params in concat ', lasagne.layers.count_params(l_post_ques_dense) 
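    # L2 regularization over all trainable parameters, scaled by args.rho, before the Adam updates are built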
68 | loss += args.rho * sum(T.sum(l ** 2) for l in all_params) 69 | 70 | updates = lasagne.updates.adam(loss, all_params, learning_rate=args.learning_rate) 71 | 72 | train_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, labels], \ 73 | [loss] + pq_preds, updates=updates) 74 | test_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, labels], \ 75 | [loss] + pq_preds,) 76 | return train_fn, test_fn 77 | 78 | def validate(val_fn, fold_name, epoch, fold, args, out_file=None): 79 | start = time.time() 80 | num_batches = 0 81 | cost = 0 82 | corr = 0 83 | mrr = 0 84 | total = 0 85 | _lambda = 0.5 86 | N = args.no_of_candidates 87 | recall = [0]*N 88 | 89 | if out_file: 90 | out_file_o = open(out_file+'.epoch%d' % epoch, 'w') 91 | out_file_o.close() 92 | posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, post_ids = fold 93 | labels = np.zeros((len(post_ids), N), dtype=np.int32) 94 | ranks = np.zeros((len(post_ids), N), dtype=np.int32) 95 | labels[:,0] = 1 96 | for j in range(N): 97 | ranks[:,j] = j 98 | ques_list, ques_masks_list, ans_list, ans_masks_list, labels, ranks = shuffle(ques_list, ques_masks_list, \ 99 | ans_list, ans_masks_list, labels, ranks) 100 | for p, pm, q, qm, a, am, l, r, ids in iterate_minibatches(posts, post_masks, ques_list, ques_masks_list, \ 101 | ans_list, ans_masks_list, labels, ranks, \ 102 | post_ids, args.batch_size, shuffle=False): 103 | q = np.transpose(q, (1, 0, 2)) 104 | qm = np.transpose(qm, (1, 0, 2)) 105 | 106 | pq_out = val_fn(p, pm, q, qm, l) 107 | loss = pq_out[0] 108 | pq_preds = pq_out[1:] 109 | pq_preds = np.transpose(pq_preds, (1, 0, 2)) 110 | pq_preds = pq_preds[:,:,0] 111 | cost += loss 112 | for j in range(args.batch_size): 113 | preds = [0.0]*N 114 | for k in range(N): 115 | preds[k] = pq_preds[j][k] 116 | rank = get_rank(preds, l[j]) 117 | if rank == 1: 118 | corr += 1 119 | mrr += 1.0/rank 120 | for index in range(N): 121 | if rank <= index+1: 122 | recall[index] += 1 123 | total += 1 124 | if out_file: 125 | write_test_predictions(out_file, ids[j], preds, r[j], epoch) 126 | num_batches += 1 127 | 128 | lstring = '%s: epoch:%d, cost:%f, acc:%f, mrr:%f,time:%d' % \ 129 | (fold_name, epoch, cost*1.0/num_batches, corr*1.0/total, mrr*1.0/total, time.time()-start) 130 | 131 | recall = [round(curr_r*1.0/total, 3) for curr_r in recall] 132 | recall_str = '[' 133 | for r in recall: 134 | recall_str += '%.3f ' % r 135 | recall_str += ']\n' 136 | 137 | print lstring 138 | print recall 139 | 140 | def baseline_pq(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test): 141 | start = time.time() 142 | print 'compiling pq graph...' 143 | train_fn, test_fn, = build(word_embeddings, vocab_size, word_emb_dim, args, freeze=freeze) 144 | print 'done! 
Time taken: ', time.time()-start 145 | 146 | # train network 147 | for epoch in range(args.no_of_epochs): 148 | validate(train_fn, 'TRAIN', epoch, train, args) 149 | validate(test_fn, '\t TEST', epoch, test, args, args.test_predictions_output) 150 | print "\n" 151 | -------------------------------------------------------------------------------- /src/models/baseline_pqa.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import theano, lasagne 3 | import numpy as np 4 | import cPickle as p 5 | import theano.tensor as T 6 | import pdb 7 | import time 8 | DEPTH = 5 9 | from lstm_helper import * 10 | from model_helper import * 11 | 12 | def build(word_embeddings, len_voc, word_emb_dim, args, freeze=False): 13 | 14 | # input theano vars 15 | posts = T.imatrix() 16 | post_masks = T.fmatrix() 17 | ques_list = T.itensor3() 18 | ques_masks_list = T.ftensor3() 19 | ans_list = T.itensor3() 20 | ans_masks_list = T.ftensor3() 21 | labels = T.imatrix() 22 | N = args.no_of_candidates 23 | 24 | post_out, post_lstm_params = build_lstm(posts, post_masks, args.post_max_len, \ 25 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 26 | ques_out, ques_emb_out, ques_lstm_params = build_list_lstm(ques_list, ques_masks_list, N, args.ques_max_len, \ 27 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 28 | ans_out, ans_emb_out, ans_lstm_params = build_list_lstm(ans_list, ans_masks_list, N, args.ans_max_len, \ 29 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 30 | 31 | pqa_preds = [None]*(N*N) 32 | post_ques_ans = T.concatenate([post_out, ques_out[0], ans_out[0]], axis=1) 33 | l_post_ques_ans_in = lasagne.layers.InputLayer(shape=(args.batch_size, 3*args.hidden_dim), input_var=post_ques_ans) 34 | l_post_ques_ans_denses = [None]*DEPTH 35 | for k in range(DEPTH): 36 | if k == 0: 37 | l_post_ques_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ques_ans_in, num_units=args.hidden_dim,\ 38 | nonlinearity=lasagne.nonlinearities.rectify) 39 | else: 40 | l_post_ques_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ques_ans_denses[k-1], num_units=args.hidden_dim,\ 41 | nonlinearity=lasagne.nonlinearities.rectify) 42 | l_post_ques_ans_dense = lasagne.layers.DenseLayer(l_post_ques_ans_denses[-1], num_units=1,\ 43 | nonlinearity=lasagne.nonlinearities.sigmoid) 44 | pqa_preds[0] = lasagne.layers.get_output(l_post_ques_ans_dense) 45 | loss = 0.0 46 | for i in range(N): 47 | for j in range(N): 48 | if i == 0 and j == 0: 49 | continue 50 | post_ques_ans = T.concatenate([post_out, ques_out[i], ans_out[j]], axis=1) 51 | l_post_ques_ans_in_ = lasagne.layers.InputLayer(shape=(args.batch_size, 3*args.hidden_dim), input_var=post_ques_ans) 52 | for k in range(DEPTH): 53 | if k == 0: 54 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_in_, num_units=args.hidden_dim,\ 55 | nonlinearity=lasagne.nonlinearities.rectify,\ 56 | W=l_post_ques_ans_denses[k].W,\ 57 | b=l_post_ques_ans_denses[k].b) 58 | else: 59 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_dense_, num_units=args.hidden_dim,\ 60 | nonlinearity=lasagne.nonlinearities.rectify,\ 61 | W=l_post_ques_ans_denses[k].W,\ 62 | b=l_post_ques_ans_denses[k].b) 63 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_dense_, num_units=1,\ 64 | nonlinearity=lasagne.nonlinearities.sigmoid) 65 | pqa_preds[i*N+j] = lasagne.layers.get_output(l_post_ques_ans_dense_) 66 | loss += 
T.mean(lasagne.objectives.binary_crossentropy(pqa_preds[i*N+i], labels[:,i])) 67 | 68 | squared_errors = [None]*(N*N) 69 | for i in range(N): 70 | for j in range(N): 71 | squared_errors[i*N+j] = lasagne.objectives.squared_error(ans_out[i], ans_out[j]) 72 | post_ques_ans_dense_params = lasagne.layers.get_all_params(l_post_ques_ans_dense, trainable=True) 73 | 74 | all_params = post_lstm_params + ques_lstm_params + ans_lstm_params + post_ques_ans_dense_params 75 | #print 'Params in concat ', lasagne.layers.count_params(l_post_ques_ans_dense) 76 | loss += args.rho * sum(T.sum(l ** 2) for l in all_params) 77 | 78 | updates = lasagne.updates.adam(loss, all_params, learning_rate=args.learning_rate) 79 | 80 | train_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, labels], \ 81 | [loss] + pqa_preds + squared_errors, updates=updates) 82 | test_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, labels], \ 83 | [loss] + pqa_preds + squared_errors,) 84 | return train_fn, test_fn 85 | 86 | def validate(val_fn, fold_name, epoch, fold, args, out_file=None): 87 | start = time.time() 88 | num_batches = 0 89 | cost = 0 90 | corr = 0 91 | mrr = 0 92 | total = 0 93 | _lambda = 0.5 94 | N = args.no_of_candidates 95 | recall = [0]*N 96 | batch_size = args.batch_size 97 | 98 | if out_file: 99 | out_file_o = open(out_file+'.epoch%d' % epoch, 'w') 100 | out_file_o.close() 101 | posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, post_ids = fold 102 | labels = np.zeros((len(post_ids), N), dtype=np.int32) 103 | ranks = np.zeros((len(post_ids), N), dtype=np.int32) 104 | labels[:,0] = 1 105 | for j in range(N): 106 | ranks[:,j] = j 107 | ques_list, ques_masks_list, ans_list, ans_masks_list, labels, ranks = shuffle(ques_list, ques_masks_list, \ 108 | ans_list, ans_masks_list, labels, ranks) 109 | for p, pm, q, qm, a, am, l, r, ids in iterate_minibatches(posts, post_masks, ques_list, ques_masks_list, \ 110 | ans_list, ans_masks_list, labels, ranks, \ 111 | post_ids, args.batch_size, shuffle=False): 112 | q = np.transpose(q, (1, 0, 2)) 113 | qm = np.transpose(qm, (1, 0, 2)) 114 | a = np.transpose(a, (1, 0, 2)) 115 | am = np.transpose(am, (1, 0, 2)) 116 | 117 | out = val_fn(p, pm, q, qm, a, am, l) 118 | loss = out[0] 119 | probs = out[1:1+N*N] 120 | errors = out[1+N*N:] 121 | probs = np.transpose(probs, (1, 0, 2)) 122 | probs = probs[:,:,0] 123 | errors = np.transpose(errors, (1, 0, 2)) 124 | errors = errors[:,:,0] 125 | cost += loss 126 | for j in range(batch_size): 127 | preds = [0.0]*N 128 | for k in range(N): 129 | preds[k] = probs[j][k*N+k] 130 | rank = get_rank(preds, l[j]) 131 | if rank == 1: 132 | corr += 1 133 | mrr += 1.0/rank 134 | for index in range(N): 135 | if rank <= index+1: 136 | recall[index] += 1 137 | total += 1 138 | if out_file: 139 | write_test_predictions(out_file, ids[j], preds, r[j], epoch) 140 | num_batches += 1 141 | 142 | lstring = '%s: epoch:%d, cost:%f, acc:%f, mrr:%f,time:%d' % \ 143 | (fold_name, epoch, cost*1.0/num_batches, corr*1.0/total, mrr*1.0/total, time.time()-start) 144 | 145 | recall = [round(curr_r*1.0/total, 3) for curr_r in recall] 146 | recall_str = '[' 147 | for r in recall: 148 | recall_str += '%.3f ' % r 149 | recall_str += ']\n' 150 | 151 | print lstring 152 | print recall 153 | 154 | def baseline_pqa(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test): 155 | start = time.time() 156 | print 'compiling pqa graph...' 
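    # compile the Theano train/test functions for the joint (post, question, answer) baseline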
157 | train_fn, test_fn, = build(word_embeddings, vocab_size, word_emb_dim, args, freeze=freeze) 158 | print 'done! Time taken: ', time.time()-start 159 | 160 | # train network 161 | for epoch in range(args.no_of_epochs): 162 | validate(train_fn, 'TRAIN', epoch, train, args) 163 | validate(test_fn, '\t TEST', epoch, test, args, args.test_predictions_output) 164 | print "\n" 165 | -------------------------------------------------------------------------------- /src/models/combine_pickle.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import cPickle as p 3 | 4 | if __name__ == "__main__": 5 | askubuntu = p.load(open(sys.argv[1], 'rb')) 6 | unix = p.load(open(sys.argv[2], 'rb')) 7 | superuser = p.load(open(sys.argv[3], 'rb')) 8 | combined = unix + superuser + askubuntu 9 | p.dump(combined, open(sys.argv[4], 'wb')) 10 | -------------------------------------------------------------------------------- /src/models/evpi.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import argparse 3 | import theano, lasagne 4 | import numpy as np 5 | import cPickle as p 6 | import theano.tensor as T 7 | from collections import Counter 8 | import pdb 9 | import time 10 | import random, math 11 | DEPTH = 3 12 | DEPTH_A = 2 13 | from lstm_helper import * 14 | from model_helper import * 15 | 16 | def cos_sim_fn(v1, v2): 17 | numerator = T.sum(v1*v2, axis=1) 18 | denominator = T.sqrt(T.sum(v1**2, axis=1) * T.sum(v2**2, axis=1)) 19 | val = numerator/denominator 20 | return T.gt(val,0) * val + T.le(val,0) * 0.001 21 | 22 | def custom_sim_fn(v1, v2): 23 | val = cos_sim_fn(v1, v2) 24 | val = val - 0.95 25 | return T.gt(val,0) * T.exp(val) 26 | 27 | def answer_model(post_out, ques_out, ques_emb_out, ans_out, ans_emb_out, labels, args): 28 | # Pr(a|p,q) 29 | N = args.no_of_candidates 30 | post_ques = T.concatenate([post_out, ques_out[0]], axis=1) 31 | hidden_dim = 200 32 | l_post_ques_in = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ques) 33 | l_post_ques_denses = [None]*DEPTH_A 34 | for k in range(DEPTH_A): 35 | if k == 0: 36 | l_post_ques_denses[k] = lasagne.layers.DenseLayer(l_post_ques_in, num_units=hidden_dim,\ 37 | nonlinearity=lasagne.nonlinearities.rectify) 38 | else: 39 | l_post_ques_denses[k] = lasagne.layers.DenseLayer(l_post_ques_denses[k-1], num_units=hidden_dim,\ 40 | nonlinearity=lasagne.nonlinearities.rectify) 41 | 42 | l_post_ques_dense = lasagne.layers.DenseLayer(l_post_ques_denses[-1], num_units=1,\ 43 | nonlinearity=lasagne.nonlinearities.sigmoid) 44 | post_ques_dense_params = lasagne.layers.get_all_params(l_post_ques_denses, trainable=True) 45 | #print 'Params in post_ques ', lasagne.layers.count_params(l_post_ques_denses) 46 | 47 | for i in range(1, N): 48 | post_ques = T.concatenate([post_out, ques_out[i]], axis=1) 49 | l_post_ques_in_ = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ques) 50 | for k in range(DEPTH_A): 51 | if k == 0: 52 | l_post_ques_dense_ = lasagne.layers.DenseLayer(l_post_ques_in_, num_units=hidden_dim,\ 53 | nonlinearity=lasagne.nonlinearities.rectify,\ 54 | W=l_post_ques_denses[k].W,\ 55 | b=l_post_ques_denses[k].b) 56 | else: 57 | l_post_ques_dense_ = lasagne.layers.DenseLayer(l_post_ques_dense_, num_units=hidden_dim,\ 58 | nonlinearity=lasagne.nonlinearities.rectify,\ 59 | W=l_post_ques_denses[k].W,\ 60 | b=l_post_ques_denses[k].b) 61 | l_post_ques_dense_ = 
lasagne.layers.DenseLayer(l_post_ques_dense_, num_units=1,\ 62 | nonlinearity=lasagne.nonlinearities.sigmoid) 63 | 64 | ques_sim = [None]*(N*N) 65 | pq_a_squared_errors = [None]*(N*N) 66 | for i in range(N): 67 | for j in range(N): 68 | ques_sim[i*N+j] = custom_sim_fn(ques_emb_out[i], ques_emb_out[j]) 69 | pq_a_squared_errors[i*N+j] = 1-cos_sim_fn(ques_emb_out[i], ans_emb_out[j]) 70 | 71 | pq_a_loss = 0.0 72 | for i in range(N): 73 | pq_a_loss += T.mean(labels[:,i] * pq_a_squared_errors[i*N+i]) 74 | for j in range(N): 75 | if i == j: 76 | continue 77 | pq_a_loss += T.mean(labels[:,i] * pq_a_squared_errors[i*N+j] * ques_sim[i*N+j]) 78 | 79 | return ques_sim, pq_a_squared_errors, pq_a_loss, post_ques_dense_params 80 | 81 | def utility_calculator(post_out, ques_out, ques_emb_out, ans_out, ques_sim, pq_a_squared_errors, labels, args): 82 | # U(p+a) 83 | N = args.no_of_candidates 84 | pqa_loss = 0.0 85 | pqa_preds = [None]*(N*N) 86 | post_ques_ans = T.concatenate([post_out, ques_out[0], ans_out[0]], axis=1) 87 | l_post_ques_ans_in = lasagne.layers.InputLayer(shape=(args.batch_size, 3*args.hidden_dim), input_var=post_ques_ans) 88 | l_post_ques_ans_denses = [None]*DEPTH 89 | for k in range(DEPTH): 90 | if k == 0: 91 | l_post_ques_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ques_ans_in, num_units=args.hidden_dim,\ 92 | nonlinearity=lasagne.nonlinearities.rectify) 93 | else: 94 | l_post_ques_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ques_ans_denses[k-1], num_units=args.hidden_dim,\ 95 | nonlinearity=lasagne.nonlinearities.rectify) 96 | l_post_ques_ans_dense = lasagne.layers.DenseLayer(l_post_ques_ans_denses[-1], num_units=1,\ 97 | nonlinearity=lasagne.nonlinearities.sigmoid) 98 | pqa_preds[0] = lasagne.layers.get_output(l_post_ques_ans_dense) 99 | 100 | pqa_loss += T.mean(lasagne.objectives.binary_crossentropy(pqa_preds[0], labels[:,0])) 101 | 102 | for i in range(N): 103 | for j in range(N): 104 | if i == 0 and j == 0: 105 | continue 106 | post_ques_ans = T.concatenate([post_out, ques_out[i], ans_out[j]], axis=1) 107 | l_post_ques_ans_in_ = lasagne.layers.InputLayer(shape=(args.batch_size, 3*args.hidden_dim), input_var=post_ques_ans) 108 | for k in range(DEPTH): 109 | if k == 0: 110 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_in_, num_units=args.hidden_dim,\ 111 | nonlinearity=lasagne.nonlinearities.rectify,\ 112 | W=l_post_ques_ans_denses[k].W,\ 113 | b=l_post_ques_ans_denses[k].b) 114 | else: 115 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_dense_, num_units=args.hidden_dim,\ 116 | nonlinearity=lasagne.nonlinearities.rectify,\ 117 | W=l_post_ques_ans_denses[k].W,\ 118 | b=l_post_ques_ans_denses[k].b) 119 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_dense_, num_units=1,\ 120 | nonlinearity=lasagne.nonlinearities.sigmoid) 121 | pqa_preds[i*N+j] = lasagne.layers.get_output(l_post_ques_ans_dense_) 122 | pqa_loss += T.mean(lasagne.objectives.binary_crossentropy(pqa_preds[i*N+i], labels[:,i])) 123 | post_ques_ans_dense_params = lasagne.layers.get_all_params(l_post_ques_ans_dense, trainable=True) 124 | #print 'Params in post_ques_ans ', lasagne.layers.count_params(l_post_ques_ans_dense) 125 | 126 | return pqa_loss, post_ques_ans_dense_params, pqa_preds 127 | 128 | def build(word_embeddings, len_voc, word_emb_dim, args, freeze=False): 129 | # input theano vars 130 | posts = T.imatrix() 131 | post_masks = T.fmatrix() 132 | ques_list = T.itensor3() 133 | ques_masks_list = T.ftensor3() 134 | ans_list = T.itensor3() 135 | 
ans_masks_list = T.ftensor3() 136 | labels = T.imatrix() 137 | N = args.no_of_candidates 138 | 139 | post_out, post_lstm_params = build_lstm(posts, post_masks, args.post_max_len, \ 140 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 141 | ques_out, ques_emb_out, ques_lstm_params = build_list_lstm(ques_list, ques_masks_list, N, args.ques_max_len, \ 142 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 143 | ans_out, ans_emb_out, ans_lstm_params = build_list_lstm(ans_list, ans_masks_list, N, args.ans_max_len, \ 144 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 145 | 146 | ques_sim, pq_a_squared_errors, pq_a_loss, post_ques_dense_params \ 147 | = answer_model(post_out, ques_out, ques_emb_out, ans_out, ans_emb_out, labels, args) 148 | 149 | all_params = post_lstm_params + ques_lstm_params + post_ques_dense_params 150 | 151 | post_out, post_lstm_params = build_lstm(posts, post_masks, args.post_max_len, \ 152 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 153 | ques_out, ques_emb_out, ques_lstm_params = build_list_lstm(ques_list, ques_masks_list, N, args.ques_max_len, \ 154 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 155 | ans_out, ans_emb_out, ans_lstm_params = build_list_lstm(ans_list, ans_masks_list, N, args.ans_max_len, \ 156 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size) 157 | 158 | pqa_loss, post_ques_ans_dense_params, pqa_preds = utility_calculator(post_out, ques_out, ques_emb_out, ans_out, \ 159 | ques_sim, pq_a_squared_errors, labels, args) 160 | 161 | all_params += post_lstm_params + ques_lstm_params + ans_lstm_params 162 | all_params += post_ques_ans_dense_params 163 | 164 | loss = pq_a_loss + pqa_loss 165 | loss += args.rho * sum(T.sum(l ** 2) for l in all_params) 166 | 167 | updates = lasagne.updates.adam(loss, all_params, learning_rate=args.learning_rate) 168 | 169 | train_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, labels], \ 170 | [loss, pq_a_loss, pqa_loss] + pq_a_squared_errors + ques_sim + pqa_preds, updates=updates) 171 | test_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, labels], \ 172 | [loss, pq_a_loss, pqa_loss] + pq_a_squared_errors + ques_sim + pqa_preds,) 173 | return train_fn, test_fn 174 | 175 | def validate(val_fn, fold_name, epoch, fold, args, out_file=None): 176 | start = time.time() 177 | num_batches = 0 178 | cost = 0 179 | pq_a_cost = 0 180 | utility_cost = 0 181 | corr = 0 182 | mrr = 0 183 | total = 0 184 | N = args.no_of_candidates 185 | recall = [0]*N 186 | 187 | if out_file: 188 | out_file_o = open(out_file+'.epoch%d' % epoch, 'w') 189 | out_file_o.close() 190 | posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, post_ids = fold 191 | labels = np.zeros((len(post_ids), N), dtype=np.int32) 192 | ranks = np.zeros((len(post_ids), N), dtype=np.int32) 193 | labels[:,0] = 1 194 | for j in range(N): 195 | ranks[:,j] = j 196 | ques_list, ques_masks_list, ans_list, ans_masks_list, labels, ranks = shuffle(ques_list, ques_masks_list, \ 197 | ans_list, ans_masks_list, labels, ranks) 198 | for p, pm, q, qm, a, am, l, r, ids in iterate_minibatches(posts, post_masks, ques_list, ques_masks_list, \ 199 | ans_list, ans_masks_list, labels, ranks, \ 200 | post_ids, args.batch_size, shuffle=False): 201 | q = np.transpose(q, (1, 0, 2)) 202 | qm = np.transpose(qm, (1, 0, 2)) 203 | a = 
np.transpose(a, (1, 0, 2)) 204 | am = np.transpose(am, (1, 0, 2)) 205 | 206 | out = val_fn(p, pm, q, qm, a, am, l) 207 | loss = out[0] 208 | pq_a_loss = out[1] 209 | pqa_loss = out[2] 210 | 211 | pq_a_errors = out[3: 3+(N*N)] 212 | pq_a_errors = np.transpose(pq_a_errors) 213 | 214 | ques_sim = out[3+(N*N): 3+(N*N)+(N*N)] 215 | ques_sim = np.transpose(ques_sim) 216 | 217 | pqa_preds = out[3+(N*N)+(N*N):] 218 | pqa_preds = np.array(pqa_preds)[:,:,0] 219 | pqa_preds = np.transpose(pqa_preds) 220 | 221 | cost += loss 222 | pq_a_cost += pq_a_loss 223 | utility_cost += pqa_loss 224 | 225 | for j in range(args.batch_size): 226 | preds = [0.0]*N 227 | for k in range(N): 228 | preds[k] = pqa_preds[j][k*N+k] 229 | for m in range(N): 230 | if m == k: 231 | continue 232 | preds[k] += 0.1*math.exp(pq_a_errors[j][k*N+m])*ques_sim[j][k*N+m] * pqa_preds[j][k*N+m] 233 | rank = get_rank(preds, l[j]) 234 | if rank == 1: 235 | corr += 1 236 | mrr += 1.0/rank 237 | for index in range(N): 238 | if rank <= index+1: 239 | recall[index] += 1 240 | total += 1 241 | if out_file: 242 | write_test_predictions(out_file, ids[j], preds, r[j], epoch) 243 | num_batches += 1 244 | 245 | lstring = '%s: epoch:%d, cost:%f, pq_a_cost:%f, utility_cost:%f, acc:%f, mrr:%f,time:%d\n' % \ 246 | (fold_name, epoch, cost*1.0/num_batches, pq_a_cost*1.0/num_batches, utility_cost*1.0/num_batches, \ 247 | corr*1.0/total, mrr*1.0/total, time.time()-start) 248 | recall = [round(curr_r*1.0/total, 3) for curr_r in recall] 249 | recall_str = '[' 250 | for r in recall: 251 | recall_str += '%.3f ' % r 252 | recall_str += ']\n' 253 | 254 | outfile = open(args.stdout_file, 'a') 255 | outfile.write(lstring+'\n') 256 | outfile.write(recall_str+'\n') 257 | outfile.close() 258 | 259 | print lstring 260 | print recall 261 | 262 | def evpi(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test): 263 | outfile = open(args.stdout_file, 'w') 264 | outfile.close() 265 | 266 | print 'Compiling graph...' 267 | start = time.time() 268 | train_fn, test_fn = build(word_embeddings, vocab_size, word_emb_dim, args, freeze=freeze) 269 | print 'done! 
Time taken: ', time.time() - start 270 | 271 | # train network 272 | for epoch in range(args.no_of_epochs): 273 | validate(train_fn, 'TRAIN', epoch, train, args) 274 | validate(test_fn, '\t TEST', epoch, test, args, args.test_predictions_output) 275 | print "\n" 276 | 277 | -------------------------------------------------------------------------------- /src/models/load_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import sys 5 | import cPickle as p 6 | import numpy as np 7 | 8 | def get_indices(tokens, vocab): 9 | indices = np.zeros([len(tokens)], dtype=np.int32) 10 | UNK = "" 11 | for i, w in enumerate(tokens): 12 | try: 13 | indices[i] = vocab[w] 14 | except: 15 | indices[i] = vocab[UNK] 16 | return indices 17 | 18 | def read_data(post_data_tsv, qa_data_tsv): 19 | posts = {} 20 | titles = {} 21 | ques_lists = {} 22 | ans_lists = {} 23 | with open(post_data_tsv, 'rb') as tsvfile: 24 | tsv_reader = csv.DictReader(tsvfile, delimiter='\t') 25 | for row in tsv_reader: 26 | post_id = row['postid'] 27 | titles[post_id] = row['title'] 28 | posts[post_id] = row['post'] 29 | with open(qa_data_tsv, 'rb') as tsvfile: 30 | tsv_reader = csv.DictReader(tsvfile, delimiter='\t') 31 | for row in tsv_reader: 32 | post_id = row['postid'] 33 | ques_lists[post_id] = [row['q1'], row['q2'], row['q3'], row['q4'], row['q5'], row['q6'], row['q7'], row['q8'], row['q9'], row['q10']] 34 | ans_lists[post_id] = [row['a1'], row['a2'], row['a3'], row['a4'], row['a5'], row['a6'], row['a7'], row['a8'], row['a9'], row['a10']] 35 | return posts, titles, ques_lists, ans_lists 36 | 37 | def read_ids(ids_file): 38 | ids = [curr_id.strip('\n') for curr_id in open(ids_file, 'r').readlines()] 39 | return ids 40 | 41 | def generate_neural_vectors(posts, titles, ques_lists, ans_lists, post_ids, vocab, N, split): 42 | post_vectors = [] 43 | ques_list_vectors = [] 44 | ans_list_vectors = [] 45 | for post_id in post_ids: 46 | post_vectors.append(get_indices(titles[post_id] + ' ' + posts[post_id], vocab)) 47 | ques_list_vector = [None]*N 48 | ans_list_vector = [None]*N 49 | for k in range(N): 50 | ques_list_vector[k] = get_indices(ques_lists[post_id][k], vocab) 51 | ans_list_vector[k] = get_indices(ans_lists[post_id][k], vocab) 52 | ques_list_vectors.append(ques_list_vector) 53 | ans_list_vectors.append(ans_list_vector) 54 | dirname = os.path.dirname(args.train_ids) 55 | p.dump(post_ids, open(os.path.join(dirname, 'post_ids_'+split+'.p'), 'wb')) 56 | p.dump(post_vectors, open(os.path.join(dirname, 'post_vectors_'+split+'.p'), 'wb')) 57 | p.dump(ques_list_vectors, open(os.path.join(dirname, 'ques_list_vectors_'+split+'.p'), 'wb')) 58 | p.dump(ans_list_vectors, open(os.path.join(dirname, 'ans_list_vectors_'+split+'.p'), 'wb')) 59 | 60 | def main(args): 61 | vocab = p.load(open(args.vocab, 'rb')) 62 | train_ids = read_ids(args.train_ids) 63 | tune_ids = read_ids(args.tune_ids) 64 | test_ids = read_ids(args.test_ids) 65 | N = args.no_of_candidates 66 | posts, titles, ques_lists, ans_lists = read_data(args.post_data_tsv, args.qa_data_tsv) 67 | generate_neural_vectors(posts, titles, ques_lists, ans_lists, train_ids, vocab, N, 'train') 68 | generate_neural_vectors(posts, titles, ques_lists, ans_lists, tune_ids, vocab, N, 'tune') 69 | generate_neural_vectors(posts, titles, ques_lists, ans_lists, test_ids, vocab, N, 'test') 70 | 71 | if __name__ == "__main__": 72 | argparser = argparse.ArgumentParser(sys.argv[0]) 73 | 
argparser.add_argument("--post_data_tsv", type = str) 74 | argparser.add_argument("--qa_data_tsv", type = str) 75 | argparser.add_argument("--train_ids", type = str) 76 | argparser.add_argument("--tune_ids", type = str) 77 | argparser.add_argument("--test_ids", type = str) 78 | argparser.add_argument("--vocab", type = str) 79 | argparser.add_argument("--no_of_candidates", type = int, default = 10) 80 | args = argparser.parse_args() 81 | print args 82 | print "" 83 | main(args) 84 | 85 | -------------------------------------------------------------------------------- /src/models/lstm_helper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import theano, lasagne 3 | import theano.tensor as T 4 | 5 | def build_list_lstm(content_list, content_masks_list, N, max_len, word_embeddings, word_emb_dim, hidden_dim, len_voc, batch_size): 6 | out = [None]*N 7 | emb_out = [None]*N 8 | l_in = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=content_list[0]) 9 | l_mask = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=content_masks_list[0]) 10 | l_emb = lasagne.layers.EmbeddingLayer(l_in, len_voc, word_emb_dim, W=word_embeddings) 11 | l_lstm = lasagne.layers.LSTMLayer(l_emb, hidden_dim, mask_input=l_mask, ) 12 | out[0] = lasagne.layers.get_output(l_lstm) 13 | out[0] = T.mean(out[0] * content_masks_list[0][:,:,None], axis=1) 14 | emb_out[0] = lasagne.layers.get_output(l_emb) 15 | emb_out[0] = T.mean(emb_out[0] * content_masks_list[0][:,:,None], axis=1) 16 | for i in range(1, N): 17 | l_in_ = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=content_list[i]) 18 | l_mask_ = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=content_masks_list[i]) 19 | l_emb_ = lasagne.layers.EmbeddingLayer(l_in_, len_voc, word_emb_dim, W=l_emb.W) 20 | l_lstm_ = lasagne.layers.LSTMLayer(l_emb_, hidden_dim, mask_input=l_mask_,\ 21 | ingate=lasagne.layers.Gate(W_in=l_lstm.W_in_to_ingate,\ 22 | W_hid=l_lstm.W_hid_to_ingate,\ 23 | b=l_lstm.b_ingate,\ 24 | nonlinearity=l_lstm.nonlinearity_ingate),\ 25 | outgate=lasagne.layers.Gate(W_in=l_lstm.W_in_to_outgate,\ 26 | W_hid=l_lstm.W_hid_to_outgate,\ 27 | b=l_lstm.b_outgate,\ 28 | nonlinearity=l_lstm.nonlinearity_outgate),\ 29 | forgetgate=lasagne.layers.Gate(W_in=l_lstm.W_in_to_forgetgate,\ 30 | W_hid=l_lstm.W_hid_to_forgetgate,\ 31 | b=l_lstm.b_forgetgate,\ 32 | nonlinearity=l_lstm.nonlinearity_forgetgate),\ 33 | cell=lasagne.layers.Gate(W_in=l_lstm.W_in_to_cell,\ 34 | W_hid=l_lstm.W_hid_to_cell,\ 35 | b=l_lstm.b_cell,\ 36 | nonlinearity=l_lstm.nonlinearity_cell),\ 37 | peepholes=False,\ 38 | ) 39 | out[i] = lasagne.layers.get_output(l_lstm_) 40 | out[i] = T.mean(out[i] * content_masks_list[i][:,:,None], axis=1) 41 | emb_out[i] = lasagne.layers.get_output(l_emb_) 42 | emb_out[i] = T.mean(emb_out[i] * content_masks_list[i][:,:,None], axis=1) 43 | l_emb.params[l_emb.W].remove('trainable') 44 | params = lasagne.layers.get_all_params(l_lstm, trainable=True) 45 | #print 'Params in lstm: ', (lasagne.layers.count_params(l_lstm)-lasagne.layers.count_params(l_emb)) 46 | return out, emb_out, params 47 | 48 | def build_lstm(posts, post_masks, max_len, word_embeddings, word_emb_dim, hidden_dim, len_voc, batch_size): 49 | 50 | l_in = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=posts) 51 | l_mask = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=post_masks) 52 | l_emb = lasagne.layers.EmbeddingLayer(l_in, len_voc, word_emb_dim, 
W=word_embeddings) 53 | l_lstm = lasagne.layers.LSTMLayer(l_emb, hidden_dim, mask_input=l_mask, ) 54 | out = lasagne.layers.get_output(l_lstm) 55 | out = T.mean(out * post_masks[:,:,None], axis=1) 56 | l_emb.params[l_emb.W].remove('trainable') 57 | params = lasagne.layers.get_all_params(l_lstm, trainable=True) 58 | #print 'Params in post_lstm: ', (lasagne.layers.count_params(l_lstm)-lasagne.layers.count_params(l_emb)) 59 | return out, params 60 | 61 | -------------------------------------------------------------------------------- /src/models/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import theano, lasagne 4 | import numpy as np 5 | import cPickle as p 6 | import theano.tensor as T 7 | from collections import Counter 8 | import pdb 9 | import time 10 | import random, math 11 | from baseline_pq import baseline_pq 12 | from baseline_pa import baseline_pa 13 | from baseline_pqa import baseline_pqa 14 | from evpi import evpi 15 | from model_helper import * 16 | 17 | def main(args): 18 | post_ids_train = p.load(open(args.post_ids_train, 'rb')) 19 | post_ids_train = np.array(post_ids_train) 20 | post_vectors_train = p.load(open(args.post_vectors_train, 'rb')) 21 | ques_list_vectors_train = p.load(open(args.ques_list_vectors_train, 'rb')) 22 | ans_list_vectors_train = p.load(open(args.ans_list_vectors_train, 'rb')) 23 | 24 | post_ids_test = p.load(open(args.post_ids_test, 'rb')) 25 | post_ids_test = np.array(post_ids_test) 26 | post_vectors_test = p.load(open(args.post_vectors_test, 'rb')) 27 | ques_list_vectors_test = p.load(open(args.ques_list_vectors_test, 'rb')) 28 | ans_list_vectors_test = p.load(open(args.ans_list_vectors_test, 'rb')) 29 | 30 | out_file = open(args.test_predictions_output, 'w') 31 | out_file.close() 32 | 33 | word_embeddings = p.load(open(args.word_embeddings, 'rb')) 34 | word_embeddings = np.asarray(word_embeddings, dtype=np.float32) 35 | vocab_size = len(word_embeddings) 36 | word_emb_dim = len(word_embeddings[0]) 37 | print 'word emb dim: ', word_emb_dim 38 | freeze = False 39 | N = args.no_of_candidates 40 | 41 | print 'vocab_size ', vocab_size, ', post_max_len ', args.post_max_len, ' ques_max_len ', args.ques_max_len, ' ans_max_len ', args.ans_max_len 42 | 43 | start = time.time() 44 | print 'generating data' 45 | train = generate_data(post_vectors_train, ques_list_vectors_train, ans_list_vectors_train, args) 46 | test = generate_data(post_vectors_test, ques_list_vectors_test, ans_list_vectors_test, args) 47 | train.append(post_ids_train) 48 | test.append(post_ids_test) 49 | 50 | print 'done! 
Time taken: ', time.time() - start 51 | 52 | print 'Size of training data: ', len(post_ids_train) 53 | print 'Size of test data: ', len(post_ids_test) 54 | 55 | if args.model == 'baseline_pq': 56 | baseline_pq(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test) 57 | elif args.model == 'baseline_pa': 58 | baseline_pa(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test) 59 | elif args.model == 'baseline_pqa': 60 | baseline_pqa(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test) 61 | elif args.model == 'evpi': 62 | evpi(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test) 63 | 64 | if __name__ == '__main__': 65 | argparser = argparse.ArgumentParser(sys.argv[0]) 66 | argparser.add_argument("--post_ids_train", type = str) 67 | argparser.add_argument("--post_vectors_train", type = str) 68 | argparser.add_argument("--ques_list_vectors_train", type = str) 69 | argparser.add_argument("--ans_list_vectors_train", type = str) 70 | argparser.add_argument("--post_ids_test", type = str) 71 | argparser.add_argument("--post_vectors_test", type = str) 72 | argparser.add_argument("--ques_list_vectors_test", type = str) 73 | argparser.add_argument("--ans_list_vectors_test", type = str) 74 | argparser.add_argument("--word_embeddings", type = str) 75 | argparser.add_argument("--batch_size", type = int, default = 256) 76 | argparser.add_argument("--no_of_epochs", type = int, default = 20) 77 | argparser.add_argument("--hidden_dim", type = int, default = 100) 78 | argparser.add_argument("--no_of_candidates", type = int, default = 10) 79 | argparser.add_argument("--learning_rate", type = float, default = 0.001) 80 | argparser.add_argument("--rho", type = float, default = 1e-5) 81 | argparser.add_argument("--post_max_len", type = int, default = 300) 82 | argparser.add_argument("--ques_max_len", type = int, default = 40) 83 | argparser.add_argument("--ans_max_len", type = int, default = 40) 84 | argparser.add_argument("--test_predictions_output", type = str) 85 | argparser.add_argument("--stdout_file", type = str) 86 | argparser.add_argument("--model", type = str) 87 | args = argparser.parse_args() 88 | print args 89 | print "" 90 | main(args) 91 | -------------------------------------------------------------------------------- /src/models/model_helper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import theano, lasagne 4 | import numpy as np 5 | import cPickle as p 6 | import theano.tensor as T 7 | from collections import Counter 8 | import pdb 9 | import time 10 | import random, math 11 | DEPTH = 5 12 | 13 | def get_data_masks(content, max_len): 14 | if len(content) > max_len: 15 | data = content[:max_len] 16 | data_mask = np.ones(max_len) 17 | else: 18 | data = np.concatenate((content, np.zeros(max_len-len(content))), axis=0) 19 | data_mask = np.concatenate((np.ones(len(content)), np.zeros(max_len-len(content))), axis=0) 20 | return data, data_mask 21 | 22 | def generate_data(posts, ques_list, ans_list, args): 23 | data_size = len(posts) 24 | data_posts = np.zeros((data_size, args.post_max_len), dtype=np.int32) 25 | data_post_masks = np.zeros((data_size, args.post_max_len), dtype=np.float32) 26 | 27 | N = args.no_of_candidates 28 | data_ques_list = np.zeros((data_size, N, args.ques_max_len), dtype=np.int32) 29 | data_ques_masks_list = np.zeros((data_size, N, args.ques_max_len), dtype=np.float32) 30 | 31 | data_ans_list = np.zeros((data_size, N, args.ans_max_len), 
dtype=np.int32) 32 | data_ans_masks_list = np.zeros((data_size, N, args.ans_max_len), dtype=np.float32) 33 | 34 | for i in range(data_size): 35 | data_posts[i], data_post_masks[i] = get_data_masks(posts[i], args.post_max_len) 36 | for j in range(N): 37 | data_ques_list[i][j], data_ques_masks_list[i][j] = get_data_masks(ques_list[i][j], args.ques_max_len) 38 | data_ans_list[i][j], data_ans_masks_list[i][j] = get_data_masks(ans_list[i][j], args.ans_max_len) 39 | 40 | return [data_posts, data_post_masks, data_ques_list, data_ques_masks_list, data_ans_list, data_ans_masks_list] 41 | 42 | def iterate_minibatches(posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, \ 43 | labels, ranks, post_ids, batch_size, shuffle=False): 44 | if shuffle: 45 | indices = np.arange(posts.shape[0]) 46 | np.random.shuffle(indices) 47 | for start_idx in range(0, posts.shape[0] - batch_size + 1, batch_size): 48 | if shuffle: 49 | excerpt = indices[start_idx:start_idx + batch_size] 50 | else: 51 | excerpt = slice(start_idx, start_idx + batch_size) 52 | yield posts[excerpt], post_masks[excerpt], ques_list[excerpt], ques_masks_list[excerpt], \ 53 | ans_list[excerpt], ans_masks_list[excerpt], labels[excerpt], ranks[excerpt], post_ids[excerpt] 54 | 55 | def get_rank(preds, labels): 56 | preds = np.array(preds) 57 | correct = np.where(labels==1)[0][0] 58 | sort_index_preds = np.argsort(preds) 59 | desc_sort_index_preds = sort_index_preds[::-1] #since ascending sort and we want descending 60 | rank = np.where(desc_sort_index_preds==correct)[0][0] 61 | return rank+1 62 | 63 | def shuffle(q, qm, a, am, l, r): 64 | shuffled_q = np.zeros(q.shape, dtype=np.int32) 65 | shuffled_qm = np.zeros(qm.shape, dtype=np.float32) 66 | shuffled_a = np.zeros(a.shape, dtype=np.int32) 67 | shuffled_am = np.zeros(am.shape, dtype=np.float32) 68 | shuffled_l = np.zeros(l.shape, dtype=np.int32) 69 | shuffled_r = np.zeros(r.shape, dtype=np.int32) 70 | 71 | for i in range(len(q)): 72 | indexes = range(len(q[i])) 73 | random.shuffle(indexes) 74 | for j, index in enumerate(indexes): 75 | shuffled_q[i][j] = q[i][index] 76 | shuffled_qm[i][j] = qm[i][index] 77 | shuffled_a[i][j] = a[i][index] 78 | shuffled_am[i][j] = am[i][index] 79 | shuffled_l[i][j] = l[i][index] 80 | shuffled_r[i][j] = r[i][index] 81 | 82 | return shuffled_q, shuffled_qm, shuffled_a, shuffled_am, shuffled_l, shuffled_r 83 | 84 | 85 | def write_test_predictions(out_file, postId, utilities, ranks, epoch): 86 | lstring = "[%s]: " % (postId) 87 | N = len(utilities) 88 | scores = [0]*N 89 | for i in range(N): 90 | scores[ranks[i]] = utilities[i] 91 | for i in range(N): 92 | lstring += "%f " % (scores[i]) 93 | out_file_o = open(out_file+'.epoch%d' % epoch, 'a') 94 | out_file_o.write(lstring + '\n') 95 | out_file_o.close() 96 | 97 | def get_annotations(line): 98 | set_info, post_id, best, valids, confidence = line.split(',') 99 | sitename = set_info.split('_')[1] 100 | best = [int(best)] 101 | valids = [int(v) for v in valids.split()] 102 | confidence = int(confidence) 103 | return post_id, sitename, best, valids, confidence 104 | 105 | def evaluate_using_human_annotations(args, preds): 106 | human_annotations_file = open(args.test_human_annotations, 'r') 107 | best_acc_on10 = 0 108 | best_acc_on9 = 0 109 | 110 | valid_acc_on10 = 0 111 | valid_acc_on9 = 0 112 | 113 | total = 0 114 | total_best_on9 = 0 115 | 116 | for line in human_annotations_file.readlines(): 117 | line = line.strip('\n') 118 | splits = line.split('\t') 119 | post_id, sitename, best, valids, confidence 
= get_annotations(splits[0]) 120 | if len(splits) > 1: 121 | post_id2, sitename2, best2, valids2, confidence2 = get_annotations(splits[1]) 122 | assert(sitename == sitename2) 123 | assert(post_id == post_id2) 124 | best += [best2] 125 | valids += valids2 126 | 127 | if best != 0: 128 | total_best_on9 += 1 129 | 130 | post_id = sitename+'_'+post_id 131 | pred = preds[post_id].index(max(preds[post_id])) 132 | 133 | if pred in best: 134 | best_acc_on10 += 1 135 | if best != 0: 136 | best_acc_on9 +=1 137 | 138 | if pred in valids: 139 | valid_acc_on10 += 1 140 | if pred != 0: 141 | valid_acc_on9 += 1 142 | 143 | total += 1 144 | 145 | print 146 | print '\t\tBest acc on 10: %.2f Valid acc on 10: %.2f' % (best_acc_on10*100.0/total, valid_acc_on10*100.0/total) 147 | print '\t\tBest acc on 9: %.2f Valid acc on 9: %.2f' % (best_acc_on9*100.0/total_best_on9, valid_acc_on9*100.0/total) 148 | 149 | -------------------------------------------------------------------------------- /src/models/run_combine_domains.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_DIR=data 4 | UBUNTU=askubuntu.com 5 | UNIX=unix.stackexchange.com 6 | SUPERUSER=superuser.com 7 | SCRIPTS_DIR=src/models 8 | SITE_NAME=askubuntu_unix_superuser 9 | 10 | mkdir -p $DATA_DIR/$SITE_NAME 11 | 12 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_vectors_train.p \ 13 | $DATA_DIR/$UNIX/post_vectors_train.p \ 14 | $DATA_DIR/$SUPERUSER/post_vectors_train.p \ 15 | $DATA_DIR/$SITE_NAME/post_vectors_train.p 16 | 17 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ques_list_vectors_train.p \ 18 | $DATA_DIR/$UNIX/ques_list_vectors_train.p \ 19 | $DATA_DIR/$SUPERUSER/ques_list_vectors_train.p \ 20 | $DATA_DIR/$SITE_NAME/ques_list_vectors_train.p 21 | 22 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ans_list_vectors_train.p \ 23 | $DATA_DIR/$UNIX/ans_list_vectors_train.p \ 24 | $DATA_DIR/$SUPERUSER/ans_list_vectors_train.p \ 25 | $DATA_DIR/$SITE_NAME/ans_list_vectors_train.p 26 | 27 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_ids_train.p \ 28 | $DATA_DIR/$UNIX/post_ids_train.p \ 29 | $DATA_DIR/$SUPERUSER/post_ids_train.p \ 30 | $DATA_DIR/$SITE_NAME/post_ids_train.p 31 | 32 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_vectors_tune.p \ 33 | $DATA_DIR/$UNIX/post_vectors_tune.p \ 34 | $DATA_DIR/$SUPERUSER/post_vectors_tune.p \ 35 | $DATA_DIR/$SITE_NAME/post_vectors_tune.p 36 | 37 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ques_list_vectors_tune.p \ 38 | $DATA_DIR/$UNIX/ques_list_vectors_tune.p \ 39 | $DATA_DIR/$SUPERUSER/ques_list_vectors_tune.p \ 40 | $DATA_DIR/$SITE_NAME/ques_list_vectors_tune.p 41 | 42 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ans_list_vectors_tune.p \ 43 | $DATA_DIR/$UNIX/ans_list_vectors_tune.p \ 44 | $DATA_DIR/$SUPERUSER/ans_list_vectors_tune.p \ 45 | $DATA_DIR/$SITE_NAME/ans_list_vectors_tune.p 46 | 47 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_ids_tune.p \ 48 | $DATA_DIR/$UNIX/post_ids_tune.p \ 49 | $DATA_DIR/$SUPERUSER/post_ids_tune.p \ 50 | $DATA_DIR/$SITE_NAME/post_ids_tune.p 51 | 52 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_vectors_test.p \ 53 | $DATA_DIR/$UNIX/post_vectors_test.p \ 54 | $DATA_DIR/$SUPERUSER/post_vectors_test.p \ 55 | $DATA_DIR/$SITE_NAME/post_vectors_test.p 56 | 57 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ques_list_vectors_test.p \ 58 | $DATA_DIR/$UNIX/ques_list_vectors_test.p \ 59 | 
$DATA_DIR/$SUPERUSER/ques_list_vectors_test.p \ 60 | $DATA_DIR/$SITE_NAME/ques_list_vectors_test.p 61 | 62 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ans_list_vectors_test.p \ 63 | $DATA_DIR/$UNIX/ans_list_vectors_test.p \ 64 | $DATA_DIR/$SUPERUSER/ans_list_vectors_test.p \ 65 | $DATA_DIR/$SITE_NAME/ans_list_vectors_test.p 66 | 67 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_ids_test.p \ 68 | $DATA_DIR/$UNIX/post_ids_test.p \ 69 | $DATA_DIR/$SUPERUSER/post_ids_test.p \ 70 | $DATA_DIR/$SITE_NAME/post_ids_test.p 71 | 72 | cat $DATA_DIR/$UBUNTU/human_annotations $DATA_DIR/$UNIX/human_annotations $DATA_DIR/$SUPERUSER/human_annotations > $DATA_DIR/$SITE_NAME/human_annotations 73 | -------------------------------------------------------------------------------- /src/models/run_load_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_DIR=data 4 | EMB_DIR=embeddings 5 | #SITE_NAME=askubuntu.com 6 | #SITE_NAME=unix.stackexchange.com 7 | SITE_NAME=superuser.com 8 | 9 | SCRIPTS_DIR=src/models 10 | 11 | python $SCRIPTS_DIR/load_data.py --post_data_tsv $DATA_DIR/$SITE_NAME/post_data.tsv \ 12 | --qa_data_tsv $DATA_DIR/$SITE_NAME/qa_data.tsv \ 13 | --train_ids $DATA_DIR/$SITE_NAME/train_ids \ 14 | --tune_ids $DATA_DIR/$SITE_NAME/tune_ids \ 15 | --test_ids $DATA_DIR/$SITE_NAME/test_ids \ 16 | --vocab $EMB_DIR/vocab.p 17 | -------------------------------------------------------------------------------- /src/models/run_main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_DIR=data 4 | EMB_DIR=embeddings 5 | #SITE_NAME=askubuntu.com 6 | #SITE_NAME=unix.stackexchange.com 7 | #SITE_NAME=superuser.com 8 | SITE_NAME=askubuntu_unix_superuser 9 | 10 | OUTPUT_DIR=output 11 | SCRIPTS_DIR=src/models 12 | #MODEL=baseline_pq 13 | #MODEL=baseline_pa 14 | #MODEL=baseline_pqa 15 | MODEL=evpi 16 | 17 | mkdir -p $OUTPUT_DIR 18 | 19 | source /fs/clip-amr/gpu_virtualenv/bin/activate 20 | module add cuda/8.0.44 21 | module add cudnn/v5.1 22 | 23 | THEANO_FLAGS=floatX=float32,device=gpu python $SCRIPTS_DIR/main.py \ 24 | --post_ids_train $DATA_DIR/$SITE_NAME/post_ids_train.p \ 25 | --post_vectors_train $DATA_DIR/$SITE_NAME/post_vectors_train.p \ 26 | --ques_list_vectors_train $DATA_DIR/$SITE_NAME/ques_list_vectors_train.p \ 27 | --ans_list_vectors_train $DATA_DIR/$SITE_NAME/ans_list_vectors_train.p \ 28 | --post_ids_test $DATA_DIR/$SITE_NAME/post_ids_test.p \ 29 | --post_vectors_test $DATA_DIR/$SITE_NAME/post_vectors_test.p \ 30 | --ques_list_vectors_test $DATA_DIR/$SITE_NAME/ques_list_vectors_test.p \ 31 | --ans_list_vectors_test $DATA_DIR/$SITE_NAME/ans_list_vectors_test.p \ 32 | --word_embeddings $EMB_DIR/word_embeddings.p \ 33 | --batch_size 128 --no_of_epochs 20 --no_of_candidates 10 \ 34 | --test_predictions_output $DATA_DIR/$SITE_NAME/test_predictions_${MODEL}.out \ 35 | --stdout_file $OUTPUT_DIR/${SITE_NAME}.${MODEL}.out \ 36 | --model $MODEL \ 37 | --------------------------------------------------------------------------------
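
The per-epoch prediction files named in run_main.sh and run_evaluation.sh are written by write_test_predictions in src/models/model_helper.py: each line has the form "[postid]: score_1 ... score_10", where the i-th score belongs to the candidate at its original position i (position 1 being the question actually asked on the post). The short Python 2 sketch below is illustrative only and not part of the released code; it assumes such a file already exists (the epoch suffix in the path is just an example) and mirrors read_model_predictions in src/evaluation/evaluate_model_with_human_annotations.py to recover the ranked candidate order per post.

    import numpy as np

    # sketch: parse a per-epoch prediction file and print each post's top-ranked candidates
    def read_predictions(path):
        predictions = {}
        with open(path) as f:
            for line in f:
                parts = line.strip().split()
                post_id = parts[0][1:-2]                  # strip the surrounding "[" and "]:"
                predictions[post_id] = [float(v) for v in parts[1:]]
        return predictions

    # example path; the epoch suffix depends on which epoch run_main.sh wrote out
    preds = read_predictions('data/askubuntu_unix_superuser/test_predictions_evpi.out.epoch13')
    for post_id, scores in preds.items():
        ranked = np.argsort(scores)[::-1]                 # candidate positions, highest score first
        print post_id, ranked[:5]

The evaluation script consumes exactly these files through its --model_predictions_filename argument, so the sketch can double as a quick sanity check of a prediction file before running run_evaluation.sh.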