├── README.md
├── lucene
│   ├── IndexFiles.class
│   ├── IndexFiles.java
│   ├── SearchSimilarFiles.class
│   ├── SearchSimilarFiles.java
│   ├── lucene-6.3.0
│   │   └── libs
│   │       ├── lucene-backward-codecs-6.3.0.jar
│   │       ├── lucene-benchmark-6.3.0.jar
│   │       ├── lucene-classification-6.3.0.jar
│   │       ├── lucene-codecs-6.3.0.jar
│   │       ├── lucene-core-6.3.0.jar
│   │       ├── lucene-demo-6.3.0.jar
│   │       ├── lucene-expressions-6.3.0.jar
│   │       ├── lucene-facet-6.3.0.jar
│   │       ├── lucene-grouping-6.3.0.jar
│   │       ├── lucene-highlighter-6.3.0.jar
│   │       ├── lucene-join-6.3.0.jar
│   │       ├── lucene-memory-6.3.0.jar
│   │       ├── lucene-misc-6.3.0.jar
│   │       ├── lucene-queries-6.3.0.jar
│   │       ├── lucene-queryparser-6.3.0.jar
│   │       ├── lucene-replicator-6.3.0.jar
│   │       ├── lucene-sandbox-6.3.0.jar
│   │       ├── lucene-spatial-6.3.0.jar
│   │       ├── lucene-spatial-extras-6.3.0.jar
│   │       ├── lucene-spatial3d-6.3.0.jar
│   │       ├── lucene-suggest-6.3.0.jar
│   │       └── lucene-test-framework-6.3.0.jar
│   ├── run_lucene.sh
│   └── run_ques_lucene.sh
└── src
    ├── data_generation
    │   ├── README.md
    │   ├── data_generator.py
    │   ├── helper.py
    │   ├── parse.py
    │   ├── post_ques_ans_generator.py
    │   └── run_data_generator.sh
    ├── embedding_generation
    │   ├── README.md
    │   ├── create_we_vocab.py
    │   ├── run_create_we_vocab.sh
    │   └── run_glove.sh
    ├── evaluation
    │   ├── evaluate_model_with_human_annotations.py
    │   └── run_evaluation.sh
    └── models
        ├── README.md
        ├── baseline_pa.py
        ├── baseline_pq.py
        ├── baseline_pqa.py
        ├── combine_pickle.py
        ├── evpi.py
        ├── load_data.py
        ├── lstm_helper.py
        ├── main.py
        ├── model_helper.py
        ├── run_combine_domains.sh
        ├── run_load_data.sh
        └── run_main.sh
/README.md:
--------------------------------------------------------------------------------
1 | # Repository information
2 |
3 | This repository contains data and code for the paper below:
4 |
5 |
6 | Learning to Ask Good Questions: Ranking Clarification Questions using Neural Expected Value of Perfect Information
7 | Sudha Rao (raosudha@cs.umd.edu) and Hal Daumé III (hal@umiacs.umd.edu)
8 | Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (ACL 2018)
9 |
10 | # Downloading data
11 |
12 | * Download the clarification questions dataset from google drive here: https://go.umd.edu/clarification_questions_dataset
13 | * cp -r clarification_questions_dataset/data ranking_clarification_questions/data
14 |
15 | * Download word embeddings trained on stackexchange datadump here: https://go.umd.edu/stackexchange_embeddings
16 | * cp -r stackexchange_embeddings/embeddings ranking_clarification_questions/embeddings
17 |
18 | The above dataset contains clarification questions for these three sites of stackexchange:
19 | 1. askubuntu.com
20 | 2. unix.stackexchange.com
21 | 3. superuser.com
22 |
23 | # Running model on data above
24 |
25 | To run models on a combination of the three sites above, check ranking_clarification_questions/src/models/README
26 |
27 | # Generating data for other sites
28 |
29 | To generate clarification questions for a different site of stackexchange, check ranking_clarification_questions/src/data_generation/README
30 |
31 | # Retrain stackexchange word embeddings
32 |
33 | To retrain word embeddings on a newer version of stackexchange datadump, check ranking_clarification_questions/src/embedding_generation/README
34 |
35 | # Contact information
36 |
37 | Please contact Sudha Rao (raosudha@cs.umd.edu) if you have any questions or any feedback.
38 |
--------------------------------------------------------------------------------
/lucene/IndexFiles.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/IndexFiles.class
--------------------------------------------------------------------------------
/lucene/IndexFiles.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | import java.io.BufferedReader;
18 | import java.io.IOException;
19 | import java.io.InputStream;
20 | import java.io.InputStreamReader;
21 | import java.nio.charset.StandardCharsets;
22 | import java.nio.file.FileVisitResult;
23 | import java.nio.file.Files;
24 | import java.nio.file.Path;
25 | import java.nio.file.Paths;
26 | import java.nio.file.SimpleFileVisitor;
27 | import java.nio.file.attribute.BasicFileAttributes;
28 | import java.util.Date;
29 |
30 | import org.apache.lucene.analysis.Analyzer;
31 | import org.apache.lucene.analysis.standard.StandardAnalyzer;
32 | import org.apache.lucene.document.LongPoint;
33 | import org.apache.lucene.document.Document;
34 | import org.apache.lucene.document.Field;
35 | import org.apache.lucene.document.StringField;
36 | import org.apache.lucene.document.TextField;
37 | import org.apache.lucene.index.IndexWriter;
38 | import org.apache.lucene.index.IndexWriterConfig.OpenMode;
39 | import org.apache.lucene.index.IndexWriterConfig;
40 | import org.apache.lucene.index.Term;
41 | import org.apache.lucene.store.Directory;
42 | import org.apache.lucene.store.FSDirectory;
43 |
44 | /** Index all text files under a directory.
45 | *
46 | * This is a command-line application demonstrating simple Lucene indexing.
47 | * Run it with no command-line arguments for usage information.
48 | */
49 | public class IndexFiles {
50 |
51 | private IndexFiles() {}
52 |
53 | /** Index all text files under a directory. */
54 | public static void main(String[] args) {
55 | String usage = "java org.apache.lucene.demo.IndexFiles"
56 | + " [-index INDEX_PATH] [-docs DOCS_PATH] [-update]\n\n"
57 | + "This indexes the documents in DOCS_PATH, creating a Lucene index"
58 | + "in INDEX_PATH that can be searched with SearchFiles";
59 | String indexPath = "index";
60 | String docsPath = null;
61 | boolean create = true;
62 | for(int i=0;i<args.length;i++) {
   | [... lines 63-140 of IndexFiles.java truncated in this dump ...]
141 | *
142 | * @param writer Writer to the index where the given file/dir info will be stored
143 | * @param path The file to index, or the directory to recurse into to find files to index
144 | * @throws IOException If there is a low-level I/O error
145 | */
146 | static void indexDocs(final IndexWriter writer, Path path) throws IOException {
147 | if (Files.isDirectory(path)) {
148 | Files.walkFileTree(path, new SimpleFileVisitor<Path>() {
149 | @Override
150 | public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
151 | try {
152 | indexDoc(writer, file, attrs.lastModifiedTime().toMillis());
153 | } catch (IOException ignore) {
154 | // don't index files that can't be read.
155 | }
156 | return FileVisitResult.CONTINUE;
157 | }
158 | });
159 | } else {
160 | indexDoc(writer, path, Files.getLastModifiedTime(path).toMillis());
161 | }
162 | }
163 |
164 | /** Indexes a single document */
165 | static void indexDoc(IndexWriter writer, Path file, long lastModified) throws IOException {
166 | try (InputStream stream = Files.newInputStream(file)) {
167 | // make a new, empty document
168 | Document doc = new Document();
169 |
170 | // Add the path of the file as a field named "path". Use a
171 | // field that is indexed (i.e. searchable), but don't tokenize
172 | // the field into separate words and don't index term frequency
173 | // or positional information:
174 | Field pathField = new StringField("path", file.getFileName().toString(), Field.Store.YES);
175 | doc.add(pathField);
176 |
177 | // Add the last modified date of the file a field named "modified".
178 | // Use a LongPoint that is indexed (i.e. efficiently filterable with
179 | // PointRangeQuery). This indexes to milli-second resolution, which
180 | // is often too fine. You could instead create a number based on
181 | // year/month/day/hour/minutes/seconds, down the resolution you require.
182 | // For example the long value 2011021714 would mean
183 | // February 17, 2011, 2-3 PM.
184 | doc.add(new LongPoint("modified", lastModified));
185 |
186 | // Add the contents of the file to a field named "contents". Specify a Reader,
187 | // so that the text of the file is tokenized and indexed, but not stored.
188 | // Note that FileReader expects the file to be in UTF-8 encoding.
189 | // If that's not the case searching for special characters will fail.
190 | doc.add(new TextField("contents", new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))));
191 |
192 | if (writer.getConfig().getOpenMode() == OpenMode.CREATE) {
193 | // New index, so we just add the document (no old document can be there):
194 | System.out.println("adding " + file);
195 | writer.addDocument(doc);
196 | } else {
197 | // Existing index (an old copy of this document may have been indexed) so
198 | // we use updateDocument instead to replace the old one matching the exact
199 | // path, if present:
200 | System.out.println("updating " + file);
201 | writer.updateDocument(new Term("path", file.toString()), doc);
202 | }
203 | }
204 | }
205 | }
206 |
--------------------------------------------------------------------------------
/lucene/SearchSimilarFiles.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/SearchSimilarFiles.class
--------------------------------------------------------------------------------
/lucene/SearchSimilarFiles.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | //package org.apache.lucene.demo;
18 |
19 |
20 | import java.io.BufferedReader;
21 | import java.io.IOException;
22 | import java.io.InputStream;
23 | import java.io.InputStreamReader;
24 | import java.io.PrintWriter;
25 | import java.lang.Math;
26 | import java.nio.charset.StandardCharsets;
27 | import java.nio.file.Files;
28 | import java.nio.file.FileVisitResult;
29 | import java.nio.file.Path;
30 | import java.nio.file.Paths;
31 | import java.nio.file.SimpleFileVisitor;
32 | import java.nio.file.attribute.BasicFileAttributes;
33 | import java.util.Date;
34 | import java.util.stream.Stream;
35 |
36 | import org.apache.lucene.analysis.Analyzer;
37 | import org.apache.lucene.analysis.standard.StandardAnalyzer;
38 | import org.apache.lucene.document.Document;
39 | import org.apache.lucene.document.TextField;
40 | import org.apache.lucene.index.DirectoryReader;
41 | import org.apache.lucene.index.IndexReader;
42 | import org.apache.lucene.queryparser.classic.QueryParser;
43 | import org.apache.lucene.queryparser.classic.ParseException;
44 | import org.apache.lucene.search.IndexSearcher;
45 | import org.apache.lucene.search.Query;
46 | import org.apache.lucene.search.ScoreDoc;
47 | import org.apache.lucene.search.TopDocs;
48 | import org.apache.lucene.store.FSDirectory;
49 |
50 | /** Simple command-line based search demo. */
51 | public class SearchSimilarFiles {
52 |
53 | private SearchSimilarFiles() {}
54 |
55 | /** Simple command-line based search demo. */
56 | public static void main(String[] args) throws Exception, ParseException, IOException{
57 | String usage =
58 | "Usage:\tjava org.apache.lucene.demo.SearchSimilarFiles [-index dir] [-queries file] [-outputFile string]\n\n";
59 | if (args.length < 2 || "-h".equals(args[0]) || "-help".equals(args[0])) {
60 | System.out.println(usage);
61 | System.exit(0);
62 | }
63 |
64 | String index = "index";
65 | String field = "contents";
66 | String queriesPath = null;
67 | String queryString = null;
68 | String outputFile = null;
69 |
70 | for(int i = 0;i < args.length;i++) {
71 | if ("-index".equals(args[i])) {
72 | index = args[i+1];
73 | i++;
74 | } else if ("-field".equals(args[i])) {
75 | field = args[i+1];
76 | i++;
77 | } else if ("-queries".equals(args[i])) {
78 | queriesPath = args[i+1];
79 | i++;
80 | } else if ("-outputFile".equals(args[i])) {
81 | outputFile = args[i+1];
82 | i++;
83 | }
84 | }
85 | IndexReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(index)));
86 | IndexSearcher searcher = new IndexSearcher(reader);
87 | Analyzer analyzer = new StandardAnalyzer();
88 |
89 | QueryParser parser = new QueryParser(field, analyzer);
90 | PrintWriter writer = new PrintWriter(outputFile, "UTF-8");
91 |
92 | final Path queriesDir = Paths.get(queriesPath);
93 | if (Files.isDirectory(queriesDir)){
94 | Files.walkFileTree(queriesDir, new SimpleFileVisitor<Path>() {
95 | @Override
96 | public FileVisitResult visitFile(Path queryFile, BasicFileAttributes attrs) throws IOException {
97 | similarDocs(queryFile, parser, searcher, writer);
98 | return FileVisitResult.CONTINUE;
99 | }
100 | });
101 | //Stream<Path> queryFiles = Files.walk(queriesDir);
102 | //queryFiles.forEach(queryFile ->
103 | // similarDocs(queryFile, parser, searcher)
104 | //);
105 | }
106 | writer.close();
107 | reader.close();
108 | }
109 |
110 | static void similarDocs(Path file, QueryParser parser, IndexSearcher searcher, PrintWriter writer) {
111 | try (InputStream stream = Files.newInputStream(file)) {
112 | String filename = file.getFileName().toString();
113 | filename = filename.substring(0, filename.lastIndexOf('.'));
114 | writer.print(filename + ' ');
115 | String content = readDocument(new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8)));
116 | Query query = parser.parse(parser.escape(content));
117 |
118 | TopDocs results = searcher.search(query, 100);
119 | ScoreDoc[] hits = results.scoreDocs;
120 |
121 | int numTotalHits = results.totalHits;
122 | //System.out.println(numTotalHits + " total matching documents");
123 |
124 | for (int i = 0; i < Math.min(100, numTotalHits); i++) {
125 | Document doc = searcher.doc(hits[i].doc);
126 | String docname = doc.get("path");
127 | docname = docname.substring(0, docname.lastIndexOf('.'));
128 | //System.out.println("doc="+docname+" score="+hits[i].score);
129 | writer.print(docname + ' ');
130 | }
131 | } catch (IOException e) {
132 | e.printStackTrace();
133 | } catch (ParseException pe){
134 | System.out.println(pe.getStackTrace());
135 | System.out.println(pe.getMessage());
136 | } finally {
137 | writer.println();
138 | }
139 | }
140 |
141 | static String readDocument(BufferedReader br) throws IOException {
142 | StringBuilder sb = new StringBuilder();
143 | String line = br.readLine();
144 |
145 | while (line != null) {
146 | sb.append(line);
147 | sb.append(System.lineSeparator());
148 | line = br.readLine();
149 | }
150 | return sb.toString();
151 | }
152 | }
153 |
154 |
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-backward-codecs-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-backward-codecs-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-benchmark-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-benchmark-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-classification-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-classification-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-codecs-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-codecs-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-core-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-core-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-demo-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-demo-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-expressions-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-expressions-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-facet-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-facet-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-grouping-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-grouping-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-highlighter-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-highlighter-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-join-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-join-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-memory-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-memory-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-misc-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-misc-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-queries-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-queries-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-queryparser-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-queryparser-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-replicator-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-replicator-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-sandbox-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-sandbox-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-spatial-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-spatial-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-spatial-extras-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-spatial-extras-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-spatial3d-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-spatial3d-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-suggest-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-suggest-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/lucene-6.3.0/libs/lucene-test-framework-6.3.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/ranking_clarification_questions/4054f93173e32e5deab0a59eb3a3814d37d75086/lucene/lucene-6.3.0/libs/lucene-test-framework-6.3.0.jar
--------------------------------------------------------------------------------
/lucene/run_lucene.sh:
--------------------------------------------------------------------------------
1 | SITE_DIR=$1
2 | /usr/lib/jvm/java-1.8.0-oracle-1.8.0.151-1jpp.5.el7.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' IndexFiles -docs $SITE_DIR/post_docs/ -index $SITE_DIR/post_doc_indices/
3 |
4 | /usr/lib/jvm/java-1.8.0-oracle-1.8.0.151-1jpp.5.el7.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' SearchSimilarFiles -index $SITE_DIR/post_doc_indices/ -queries $SITE_DIR/post_docs/ -outputFile $SITE_DIR/lucene_similar_posts.txt
5 |
--------------------------------------------------------------------------------
/lucene/run_ques_lucene.sh:
--------------------------------------------------------------------------------
1 | SITE_DIR=$1
2 | #/usr/lib/jvm/java-1.8.0-oracle.x86_64/bin/javac -cp 'lucene-6.3.0/libs/*' IndexFiles.java
3 |
4 | #/usr/lib/jvm/java-1.8.0-oracle.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' IndexFiles -docs $SITE_DIR/ques_docs/ -index $SITE_DIR/ques_doc_indices/
5 |
6 | /usr/lib/jvm/java-1.8.0-oracle-1.8.0.151-1jpp.5.el7.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' IndexFiles -docs $SITE_DIR/ques_docs/ -index $SITE_DIR/ques_doc_indices/
7 |
8 | #/usr/lib/jvm/java-1.8.0-oracle.x86_64/bin/javac -cp 'lucene-6.3.0/libs/*' SearchSimilarFiles.java
9 |
10 | #/usr/lib/jvm/java-1.8.0-oracle.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' SearchSimilarFiles -index $SITE_DIR/ques_doc_indices/ -queries $SITE_DIR/ques_docs/ -outputFile $SITE_DIR/lucene_similar_questions.txt
11 |
12 | /usr/lib/jvm/java-1.8.0-oracle-1.8.0.151-1jpp.5.el7.x86_64/bin/java -cp '.:lucene-6.3.0/libs/*' SearchSimilarFiles -index $SITE_DIR/ques_doc_indices/ -queries $SITE_DIR/ques_docs/ -outputFile $SITE_DIR/lucene_similar_questions.txt
13 |
--------------------------------------------------------------------------------
/src/data_generation/README.md:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 |
3 | * Install numpy, nltk, BeautifulSoup
4 |
5 | # NOTE
6 |
7 | * Data for the sites: askubuntu.com, unix.stackexchange.com & superuser.com can be found here: https://go.umd.edu/clarification_questions_dataset
8 |
9 | # Steps to generate data for any other site
10 |
11 | 1. Choose a site name from the list of sites at https://archive.org/download/stackexchange. Let's say you chose 'academia.stackexchange.com'
12 | 2. Download the .7z file corresponding to the site, i.e. academia.stackexchange.com.7z, and unzip it into ranking_clarification_questions/stackexchange/academia.stackexchange.com/ (see the example below)
13 | 3. Set "SITE_NAME=academia.stackexchange.com" in ranking_clarification_questions/src/data_generation/run_data_generator.sh
14 | 4. cd ranking_clarification_questions; sh src/data_generation/run_data_generator.sh
15 | 
16 | This will create the data/academia.stackexchange.com/post_data.tsv & data/academia.stackexchange.com/qa_data.tsv files
17 |
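18 | For example, for the hypothetical site above (a sketch only; adjust the paths and the 7z invocation to your setup):
19 | 
20 | ```bash
21 | cd ranking_clarification_questions
22 | mkdir -p stackexchange
23 | # download academia.stackexchange.com.7z from https://archive.org/download/stackexchange, then:
24 | 7z x academia.stackexchange.com.7z -ostackexchange/academia.stackexchange.com
25 | # set SITE_NAME=academia.stackexchange.com in src/data_generation/run_data_generator.sh, then:
26 | sh src/data_generation/run_data_generator.sh
27 | ```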
--------------------------------------------------------------------------------
/src/data_generation/data_generator.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import argparse
3 | from parse import *
4 | from post_ques_ans_generator import *
5 | from helper import *
6 | import time
7 | import numpy as np
8 | import cPickle as p
9 | import pdb
10 | import random
11 |
12 | def get_similar_docs(lucene_similar_docs):
13 | lucene_similar_docs_file = open(lucene_similar_docs, 'r')
14 | similar_docs = {}
15 | for line in lucene_similar_docs_file.readlines():
16 | parts = line.split()
17 | if len(parts) > 1:
18 | similar_docs[parts[0]] = parts[1:]
19 | else:
20 | similar_docs[parts[0]] = []
21 | return similar_docs
22 |
23 | def generate_docs_for_lucene(post_ques_answers, posts, output_dir):
24 | for postId in post_ques_answers:
25 | f = open(os.path.join(output_dir, str(postId) + '.txt'), 'w')
26 | content = ' '.join(posts[postId].title).encode('utf-8') + ' ' + ' '.join(posts[postId].body).encode('utf-8')
27 | f.write(content)
28 | f.close()
29 |
30 | def create_tsv_files(post_data_tsv, qa_data_tsv, post_ques_answers, lucene_similar_posts):
31 | lucene_similar_posts = get_similar_docs(lucene_similar_posts)  # dict: postId -> ranked list of similar postIds
32 | similar_posts = {}
33 | for postId, similar_ids in lucene_similar_posts.iteritems():
34 | # need the top 10 lucene-similar posts as candidates; skip posts with fewer
35 | if len(similar_ids) < 10:
36 | continue
37 | similar_posts[postId] = similar_ids[:10]
38 | 
39 | post_data_tsv_file = open(post_data_tsv, 'w')
40 | post_data_tsv_file.write('postid\ttitle\tpost\n')
41 | qa_data_tsv_file = open(qa_data_tsv, 'w')
42 | qa_data_tsv_file.write('postid\tq1\tq2\tq3\tq4\tq5\tq6\tq7\tq8\tq9\tq10\ta1\ta2\ta3\ta4\ta5\ta6\ta7\ta8\ta9\ta10\n')
43 | for postId in similar_posts:
44 | post_data_tsv_file.write('%s\t%s\t%s\n' % (postId, \
45 | ' '.join(post_ques_answers[postId].post_title), \
46 | ' '.join(post_ques_answers[postId].post)))
47 | line = postId
48 | for i in range(10):
49 | line += '\t%s' % ' '.join(post_ques_answers[similar_posts[postId][i]].question_comment)
50 | for i in range(10):
51 | line += '\t%s' % ' '.join(post_ques_answers[similar_posts[postId][i]].answer)
52 | line += '\n'
53 | qa_data_tsv_file.write(line)
54 |
55 | def main(args):
56 | start_time = time.time()
57 | print 'Parsing posts...'
58 | post_parser = PostParser(args.posts_xml)
59 | post_parser.parse()
60 | posts = post_parser.get_posts()
61 | print 'Size: ', len(posts)
62 | print 'Done! Time taken ', time.time() - start_time
63 | print
64 |
65 | start_time = time.time()
66 | print 'Parsing posthistories...'
67 | posthistory_parser = PostHistoryParser(args.posthistory_xml)
68 | posthistory_parser.parse()
69 | posthistories = posthistory_parser.get_posthistories()
70 | print 'Size: ', len(posthistories)
71 | print 'Done! Time taken ', time.time() - start_time
72 | print
73 |
74 | start_time = time.time()
75 | print 'Parsing question comments...'
76 | comment_parser = CommentParser(args.comments_xml)
77 | comment_parser.parse_all_comments()
78 | question_comments = comment_parser.get_question_comments()
79 | all_comments = comment_parser.get_all_comments()
80 | print 'Size: ', len(question_comments)
81 | print 'Done! Time taken ', time.time() - start_time
82 | print
83 |
84 | start_time = time.time()
85 | print 'Loading vocab'
86 | vocab = p.load(open(args.vocab, 'rb'))
87 | print 'Done! Time taken ', time.time() - start_time
88 | print
89 |
90 | start_time = time.time()
91 | print 'Loading word_embeddings'
92 | word_embeddings = p.load(open(args.word_embeddings, 'rb'))
93 | word_embeddings = np.asarray(word_embeddings, dtype=np.float32)
94 | print 'Done! Time taken ', time.time() - start_time
95 | print
96 |
97 | start_time = time.time()
98 | print 'Generating post_ques_ans...'
99 | post_ques_ans_generator = PostQuesAnsGenerator()
100 | post_ques_answers = post_ques_ans_generator.generate(posts, question_comments, all_comments, posthistories, vocab, word_embeddings)
101 | print 'Size: ', len(post_ques_answers)
102 | print 'Done! Time taken ', time.time() - start_time
103 | print
104 |
105 | generate_docs_for_lucene(post_ques_answers, posts, args.lucene_docs_dir)
106 | os.system('cd %s && sh run_lucene.sh %s' % (args.lucene_dir, os.path.dirname(args.post_data_tsv)))
107 |
108 | create_tsv_files(args.post_data_tsv, args.qa_data_tsv, post_ques_answers, args.lucene_similar_posts)
109 |
110 | if __name__ == "__main__":
111 | argparser = argparse.ArgumentParser(sys.argv[0])
112 | argparser.add_argument("--posts_xml", type = str)
113 | argparser.add_argument("--comments_xml", type = str)
114 | argparser.add_argument("--posthistory_xml", type = str)
115 | argparser.add_argument("--lucene_dir", type = str)
116 | argparser.add_argument("--lucene_docs_dir", type = str)
117 | argparser.add_argument("--lucene_similar_posts", type = str)
118 | argparser.add_argument("--word_embeddings", type = str)
119 | argparser.add_argument("--vocab", type = str)
120 | argparser.add_argument("--no_of_candidates", type = int, default = 10)
121 | argparser.add_argument("--site_name", type = str)
122 | argparser.add_argument("--post_data_tsv", type = str)
123 | argparser.add_argument("--qa_data_tsv", type = str)
124 | args = argparser.parse_args()
125 | print args
126 | print ""
127 | main(args)
128 |
129 |
--------------------------------------------------------------------------------
/src/data_generation/helper.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from nltk.tokenize import word_tokenize
3 | from nltk.tokenize import sent_tokenize
4 | import nltk
5 | nltk.download('punkt')
6 | import numpy as np
7 | from BeautifulSoup import BeautifulSoup
8 | import re
9 |
10 | def get_tokens(text):
11 | text = BeautifulSoup(text.encode('utf-8').decode('ascii', 'ignore')).text
12 | return word_tokenize(str(text).lower())
13 |
14 | def get_indices(tokens, vocab):
15 | indices = np.zeros([len(tokens)], dtype=np.int32)
16 | UNK = "<unk>"
17 | for i, w in enumerate(tokens):
18 | try:
19 | indices[i] = vocab[w]
20 | except:
21 | indices[i] = vocab[UNK]
22 | return indices
23 |
24 | def get_similarity(a_indices, b_indices, word_embeddings):
25 | a_embeddings = [word_embeddings[idx] for idx in a_indices]
26 | b_embeddings = [word_embeddings[idx] for idx in b_indices]
27 | avg_a_embedding = np.mean(a_embeddings, axis=0)
28 | avg_b_embedding = np.mean(b_embeddings, axis=0)
29 | cosine_similarity = np.dot(avg_a_embedding, avg_b_embedding)/(np.linalg.norm(avg_a_embedding) * np.linalg.norm(avg_b_embedding))
30 | return cosine_similarity
31 |
32 | def remove_urls(text):
33 | r = re.compile(r"(http://[^ ]+)")
34 | text = r.sub("", text) #remove urls so that ? is not identified in urls
35 | r = re.compile(r"(https://[^ ]+)")
36 | text = r.sub("", text) #remove urls so that ? is not identified in urls
37 | r = re.compile(r"(http : //[^ ]+)")
38 | text = r.sub("", text) #remove urls so that ? is not identified in urls
39 | r = re.compile(r"(https : //[^ ]+)")
40 | text = r.sub("", text) #remove urls so that ? is not identified in urls
41 | return text
42 |
43 | def is_too_short_or_long(tokens):
44 | text = ' '.join(tokens)
45 | r = re.compile('[^a-zA-Z ]+')
46 | text = r.sub('', text)
47 | tokens = text.split()
48 | if len(tokens) < 3 or len(tokens) > 100:
49 | return True
50 | return False
51 |
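52 | # NOTE: parse.py calls get_sent_tokens(), which is missing from this copy of helper.py.
53 | # A minimal sketch of what it presumably does (a sentence-level analogue of get_tokens,
54 | # using the sent_tokenize imported above); the original implementation may differ.
55 | def get_sent_tokens(text):
56 |     text = BeautifulSoup(text.encode('utf-8').decode('ascii', 'ignore')).text
57 |     return [word_tokenize(sent.lower()) for sent in sent_tokenize(str(text))]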
--------------------------------------------------------------------------------
/src/data_generation/parse.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import cPickle as p
3 | import xml.etree.ElementTree as ET
4 | from collections import defaultdict
5 | import re
6 | import datetime
7 | import time
8 | import pdb
9 | from helper import *
10 |
11 | class Post:
12 |
13 | def __init__(self, title, body, sents, typeId, accepted_answerId, answer_count, owner_userId, creation_date, parentId, closed_date):
14 | self.title = title
15 | self.body = body
16 | self.sents = sents
17 | self.typeId = typeId
18 | self.accepted_answerId = accepted_answerId
19 | self.answer_count = answer_count
20 | self.owner_userId = owner_userId
21 | self.creation_date = creation_date
22 | self.parentId = parentId
23 | self.closed_date = closed_date
24 |
25 | class PostParser:
26 |
27 | def __init__(self, filename):
28 | self.filename = filename
29 | self.posts = dict()
30 |
31 | def parse(self):
32 | posts_tree = ET.parse(self.filename)
33 | for post in posts_tree.getroot():
34 | postId = post.attrib['Id']
35 | postTypeId = int(post.attrib['PostTypeId'])
36 | try:
37 | accepted_answerId = post.attrib['AcceptedAnswerId']
38 | except:
39 | accepted_answerId = None #non-main posts & unanswered posts don't have accepted_answerId
40 | try:
41 | answer_count = int(post.attrib['AnswerCount'])
42 | except:
43 | answer_count = None #non-main posts don't have answer_count
44 | try:
45 | title = get_tokens(post.attrib['Title'])
46 | except:
47 | title = []
48 | try:
49 | owner_userId = post.attrib['OwnerUserId']
50 | except:
51 | owner_userId = None
52 | creation_date = datetime.datetime.strptime(post.attrib['CreationDate'].split('.')[0], "%Y-%m-%dT%H:%M:%S")
53 | try:
54 | closed_date = datetime.datetime.strptime(post.attrib['ClosedDate'].split('.')[0], "%Y-%m-%dT%H:%M:%S")
55 | except:
56 | closed_date = None
57 | if postTypeId == 2:
58 | parentId = post.attrib['ParentId']
59 | else:
60 | parentId = None
61 | body = get_tokens(post.attrib['Body'])
62 | sent_tokens = get_sent_tokens(post.attrib['Body'])
63 | self.posts[postId] = Post(title, body, sent_tokens, postTypeId, accepted_answerId, answer_count, owner_userId, creation_date, parentId, closed_date)
64 |
65 | def get_posts(self):
66 | return self.posts
67 |
68 | class Comment:
69 |
70 | def __init__(self, text, creation_date, userId):
71 | self.text = text
72 | self.creation_date = creation_date
73 | self.userId = userId
74 |
75 | class QuestionComment:
76 |
77 | def __init__(self, text, creation_date, userId):
78 | self.text = text
79 | self.creation_date = creation_date
80 | self.userId = userId
81 |
82 | class CommentParser:
83 |
84 | def __init__(self, filename):
85 | self.filename = filename
86 | self.question_comments = defaultdict(list)
87 | self.question_comment = defaultdict(None)
88 | self.comment = defaultdict(None)
89 | self.all_comments = defaultdict(list)
90 |
91 | def domain_words(self):
92 | return ['duplicate', 'upvote', 'downvote', 'vote', 'related', 'upvoted', 'downvoted', 'edit']
93 |
94 | def get_question(self, text):
95 | old_text = text
96 | text = remove_urls(text)
97 | if old_text != text: #ignore questions with urls
98 | return None
99 | tokens = get_tokens(text)
100 | lc_text = ' '.join(tokens)
101 | if 'have you' in lc_text or 'did you try' in lc_text or 'can you try' in lc_text or 'could you try' in lc_text: #ignore questions that indirectly provide answer
102 | return None
103 | if '?' in tokens:
104 | parts = " ".join(tokens).split('?')
105 | text = ""
106 | for i in range(len(parts)-1):
107 | text += parts[i]+ ' ?'
108 | words = text.split()
109 | if len(words) > 20:
110 | break
111 | if len(words) > 20: #ignore long comments
112 | return None
113 | for w in self.domain_words():
114 | if w in words:
115 | return None
116 | if words[0] == '@':
117 | text = words[2:]
118 | else:
119 | text = words
120 | return text
121 | return None
122 |
123 | def get_comment_tokens(self, text):
124 | text = remove_urls(text)
125 | if text == '':
126 | return None
127 | tokens = get_tokens(text)
128 | if tokens == []:
129 | return None
130 | if tokens[0] == '@':
131 | tokens = tokens[2:]
132 | return tokens
133 |
134 | def parse_all_comments(self):
135 | comments_tree = ET.parse(self.filename)
136 | for comment in comments_tree.getroot():
137 | postId = comment.attrib['PostId']
138 | text = comment.attrib['Text']
139 | try:
140 | userId = comment.attrib['UserId']
141 | except:
142 | userId = None
143 | creation_date = datetime.datetime.strptime(comment.attrib['CreationDate'].split('.')[0], "%Y-%m-%dT%H:%M:%S")
144 | comment_tokens = self.get_comment_tokens(text)
145 | if not comment_tokens:
146 | continue
147 | curr_comment = Comment(comment_tokens, creation_date, userId)
148 | self.all_comments[postId].append(curr_comment)
149 | question = self.get_question(text)
150 | if question:
151 | question_comment = QuestionComment(question, creation_date, userId)
152 | self.question_comments[postId].append(question_comment)
153 |
154 | def get_question_comments(self):
155 | return self.question_comments
156 |
157 | def get_all_comments(self):
158 | return self.all_comments
159 |
160 | class PostHistory:
161 | def __init__(self):
162 | self.initial_post = None
163 | self.initial_post_sents = None
164 | self.edited_posts = []
165 | self.edit_comments = []
166 | self.edit_dates = []
167 |
168 | class PostHistoryParser:
169 |
170 | def __init__(self, filename):
171 | self.filename = filename
172 | self.posthistories = defaultdict(PostHistory)
173 |
174 | def parse(self):
175 | posthistory_tree = ET.parse(self.filename)
176 | for posthistory in posthistory_tree.getroot():
177 | posthistory_typeid = posthistory.attrib['PostHistoryTypeId']
178 | postId = posthistory.attrib['PostId']
179 | if posthistory_typeid == '2':
180 | self.posthistories[postId].initial_post = get_tokens(posthistory.attrib['Text'])
181 | self.posthistories[postId].initial_post_sents = get_sent_tokens(posthistory.attrib['Text'])
182 | elif posthistory_typeid == '5':
183 | self.posthistories[postId].edited_posts.append(get_tokens(posthistory.attrib['Text']))
184 | self.posthistories[postId].edit_comments.append(get_tokens(posthistory.attrib['Comment']))
185 | self.posthistories[postId].edit_dates.append(datetime.datetime.strptime(posthistory.attrib['CreationDate'].split('.')[0], "%Y-%m-%dT%H:%M:%S"))
186 | #format of date e.g.:"2008-09-06T08:07:10.730" We don't want .730
187 | for postId in self.posthistories.keys():
188 | if not self.posthistories[postId].edited_posts:
189 | del self.posthistories[postId]
190 |
191 | def get_posthistories(self):
192 | return self.posthistories
193 |
194 |
--------------------------------------------------------------------------------
/src/data_generation/post_ques_ans_generator.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from helper import *
3 | from collections import defaultdict
4 | from difflib import SequenceMatcher
5 | import pdb
6 |
7 | class PostQuesAns:
8 |
9 | def __init__(self, post_title, post, post_sents, question_comment, answer):
10 | self.post_title = post_title
11 | self.post = post
12 | self.post_sents = post_sents
13 | self.question_comment = question_comment
14 | self.answer = answer
15 |
16 | class PostQuesAnsGenerator:
17 |
18 | def __init__(self):
19 | self.post_ques_ans_dict = defaultdict(PostQuesAns)
20 |
21 | def get_diff(self, initial, final):
22 | s = SequenceMatcher(None, initial, final)
23 | diff = None
24 | for tag, i1, i2, j1, j2 in s.get_opcodes():
25 | if tag == 'insert':
26 | diff = final[j1:j2]
27 | if not diff:
28 | return None
29 | return diff
30 |
31 | def find_first_question(self, answer, question_comment_candidates, vocab, word_embeddings):
32 | first_question = None
33 | first_date = None
34 | for question_comment in question_comment_candidates:
35 | if first_question == None:
36 | first_question = question_comment
37 | first_date = question_comment.creation_date
38 | else:
39 | if question_comment.creation_date < first_date:
40 | first_question = question_comment
41 | first_date = question_comment.creation_date
42 | return first_question
43 |
44 | def find_answer_comment(self, all_comments, question_comment, post_userId):
45 | answer_comment, answer_comment_date = None, None
46 | for comment in all_comments:
47 | if comment.userId and comment.userId == post_userId:
48 | if comment.creation_date > question_comment.creation_date:
49 | if not answer_comment or (comment.creation_date < answer_comment_date):
50 | answer_comment = comment
51 | answer_comment_date = comment.creation_date
52 | return answer_comment
53 |
54 | def generate_using_comments(self, posts, question_comments, all_comments, vocab, word_embeddings):
55 | for postId in posts.keys():
56 | if postId in self.post_ques_ans_dict.keys():
57 | continue
58 | if posts[postId].typeId != 1: # is not a main post
59 | continue
60 | first_question = None
61 | first_date = None
62 | for question_comment in question_comments[postId]:
63 | if question_comment.userId and question_comment.userId == posts[postId].owner_userId:
64 | continue #Ignore comments by the original author of the post
65 | if first_question == None:
66 | first_question = question_comment
67 | first_date = question_comment.creation_date
68 | else:
69 | if question_comment.creation_date < first_date:
70 | first_question = question_comment
71 | first_date = question_comment.creation_date
72 | question = first_question
73 | if not question:
74 | continue
75 | answer_comment = self.find_answer_comment(all_comments[postId], question, posts[postId].owner_userId)
76 | if not answer_comment:
77 | continue
78 | answer = answer_comment
79 | self.post_ques_ans_dict[postId] = PostQuesAns(posts[postId].title, posts[postId].body, \
80 | posts[postId].sents, question.text, answer.text)
81 |
82 | def generate(self, posts, question_comments, all_comments, posthistories, vocab, word_embeddings):
83 | for postId, posthistory in posthistories.iteritems():
84 | if not posthistory.edited_posts:
85 | continue
86 | if posts[postId].typeId != 1: # is not a main post
87 | continue
88 | if not posthistory.initial_post:
89 | continue
90 | first_edit_date, first_question, first_answer = None, None, None
91 | for i in range(len(posthistory.edited_posts)):
92 | answer = self.get_diff(posthistory.initial_post, posthistory.edited_posts[i])
93 | if not answer:
94 | continue
95 | else:
96 | answer = remove_urls(' '.join(answer))
97 | answer = answer.split()
98 | if is_too_short_or_long(answer):
99 | continue
100 | question_comment_candidates = []
101 | for comment in question_comments[postId]:
102 | if comment.userId and comment.userId == posts[postId].owner_userId:
103 | continue #Ignore comments by the original author of the post
104 | if comment.creation_date > posthistory.edit_dates[i]:
105 | continue #Ignore comments added after the edit
106 | else:
107 | question_comment_candidates.append(comment)
108 | question = self.find_first_question(answer, question_comment_candidates, vocab, word_embeddings)
109 | if not question:
110 | continue
111 | answer_comment = self.find_answer_comment(all_comments[postId], question, posts[postId].owner_userId)
112 | if answer_comment and 'edit' not in answer_comment.text: #prefer edit if comment points to the edit
113 | question_indices = get_indices(question.text, vocab)
114 | answer_indices = get_indices(answer, vocab)
115 | answer_comment_indices = get_indices(answer_comment.text, vocab)
116 | if get_similarity(question_indices, answer_comment_indices, word_embeddings) > get_similarity(question_indices, answer_indices, word_embeddings):
117 | answer = answer_comment.text
118 | if first_edit_date == None or posthistory.edit_dates[i] < first_edit_date:
119 | first_question, first_answer, first_edit_date = question, answer, posthistory.edit_dates[i]
120 |
121 | if not first_question:
122 | continue
123 | self.post_ques_ans_dict[postId] = PostQuesAns(posts[postId].title, posthistory.initial_post, \
124 | posthistory.initial_post_sents, first_question.text, first_answer)
125 |
126 | self.generate_using_comments(posts, question_comments, all_comments, vocab, word_embeddings)
127 | return self.post_ques_ans_dict
128 |
--------------------------------------------------------------------------------
/src/data_generation/run_data_generator.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATADUMP_DIR=stackexchange #Directory containing xml files
4 | EMB_DIR=embeddings
5 | DATA_DIR=data
6 | SITE_NAME=askubuntu.com
7 | #SITE_NAME=unix.stackexchange.com
8 | #SITE_NAME=superuser.com
9 |
10 | SCRIPTS_DIR=ranking_clarification_questions/src/data_generation
11 | LUCENE_DIR=ranking_clarification_questions/lucene
12 |
13 | mkdir -p $DATA_DIR/$SITE_NAME
14 |
15 | rm -r $DATA_DIR/$SITE_NAME/post_docs
16 | rm -r $DATA_DIR/$SITE_NAME/post_doc_indices
17 | mkdir -p $DATA_DIR/$SITE_NAME/post_docs
18 |
19 | rm -r $DATA_DIR/$SITE_NAME/ques_docs
20 | rm -r $DATA_DIR/$SITE_NAME/ques_doc_indices
21 | mkdir -p $DATA_DIR/$SITE_NAME/ques_docs
22 |
23 | python $SCRIPTS_DIR/data_generator.py --posts_xml $DATADUMP_DIR/$SITE_NAME/Posts.xml \
24 | --comments_xml $DATADUMP_DIR/$SITE_NAME/Comments.xml \
25 | --posthistory_xml $DATADUMP_DIR/$SITE_NAME/PostHistory.xml \
26 | --word_embeddings $EMB_DIR/word_embeddings.p \
27 | --vocab $EMB_DIR/vocab.p \
28 | --lucene_dir $LUCENE_DIR \
29 | --lucene_docs_dir $DATA_DIR/$SITE_NAME/post_docs \
30 | --lucene_similar_posts $DATA_DIR/$SITE_NAME/lucene_similar_posts.txt \
31 | --site_name $SITE_NAME \
32 | --post_data_tsv $DATA_DIR/$SITE_NAME/post_data.tsv \
33 | --qa_data_tsv $DATA_DIR/$SITE_NAME/qa_data.tsv
34 |
35 |
--------------------------------------------------------------------------------
/src/embedding_generation/README.md:
--------------------------------------------------------------------------------
1 | # Prerequisite
2 |
3 | * Download and compile the GloVe code from: https://nlp.stanford.edu/projects/glove/
4 |
5 | # NOTE
6 |
7 | * Word embeddings pretrained on the stackexchange datadump (2017 version) can be found here: https://go.umd.edu/stackexchange_embeddings
8 |
9 | # Steps to retrain word embeddings
10 |
11 | 1. Download all domains of stackexchange from: https://archive.org/download/stackexchange
12 | 2. Extract the text from every site's Posts.xml, Comments.xml and PostHistory.xml (see the sketch below)
13 | 3. Save the combined data under: stackexchange/stackexchange_datadump.txt
14 | 4. cd ranking_clarification_questions; sh src/embedding_generation/run_glove.sh
15 | 5. cd ranking_clarification_questions; sh src/embedding_generation/run_create_we_vocab.sh
16 |
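17 | Step 2 is not scripted in this repository. Below is a minimal Python 2 sketch of one way to do it, reusing get_tokens from src/data_generation/helper.py (which needs the numpy/nltk/BeautifulSoup prerequisites of src/data_generation/README.md); the directory layout and the stackexchange/stackexchange_datadump.txt output path are the ones run_glove.sh expects:
18 | 
19 | ```python
20 | import glob, os, sys
21 | import xml.etree.ElementTree as ET
22 | 
23 | sys.path.append('src/data_generation')
24 | from helper import get_tokens  # same HTML stripping + tokenization used for the task data
25 | 
26 | out = open('stackexchange/stackexchange_datadump.txt', 'w')
27 | for site_dir in glob.glob('stackexchange/*/'):
28 |     for name in ['Posts.xml', 'Comments.xml', 'PostHistory.xml']:
29 |         xml_file = os.path.join(site_dir, name)
30 |         if not os.path.isfile(xml_file):
31 |             continue
32 |         # stream the dump row by row; posts carry their text in 'Body',
33 |         # comments and post histories in 'Text'
34 |         for _, row in ET.iterparse(xml_file):
35 |             if row.tag != 'row':
36 |                 continue
37 |             text = row.attrib.get('Body', '') or row.attrib.get('Text', '')
38 |             if text:
39 |                 out.write(' '.join(get_tokens(text)) + '\n')
40 |             row.clear()  # free processed rows to keep memory in check on large dumps
41 | out.close()
42 | ```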
--------------------------------------------------------------------------------
/src/embedding_generation/create_we_vocab.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import cPickle as p
3 |
4 | if __name__ == "__main__":
5 | if len(sys.argv) < 4:
6 | print "usage: python create_we_vocab.py <word_vectors_file> <word_embeddings_pickle> <vocab_pickle>"
7 | sys.exit(0)
8 | word_vectors_file = open(sys.argv[1], 'r')
9 | word_embeddings = []
10 | vocab = {}
11 | i = 0
12 | for line in word_vectors_file.readlines():
13 | vals = line.rstrip().split(' ')
14 | vocab[vals[0]] = i
15 | word_embeddings.append(map(float, vals[1:]))
16 | i += 1
17 | p.dump(word_embeddings, open(sys.argv[2], 'wb'))
18 | p.dump(vocab, open(sys.argv[3], 'wb'))
19 |
20 |
--------------------------------------------------------------------------------
/src/embedding_generation/run_create_we_vocab.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SCRIPTS_DIR=ranking_clarification_questions/src/embedding_generation
4 | EMB_DIR=ranking_clarification_questions/embeddings/
5 |
6 | python $SCRIPTS_DIR/create_we_vocab.py $EMB_DIR/vectors.txt $EMB_DIR/word_embeddings.p $EMB_DIR/vocab.p
7 |
8 |
--------------------------------------------------------------------------------
/src/embedding_generation/run_glove.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Trains a GloVe model on the stackexchange datadump (adapted from the GloVe demo.sh).
4 | # Assumes the GloVe binaries have already been compiled under GloVe-1.2/build (see the README).
5 |
6 | DATADIR=embeddings
7 |
8 | CORPUS=stackexchange/stackexchange_datadump.txt
9 | VOCAB_FILE=$DATADIR/vocab.txt
10 | COOCCURRENCE_FILE=$DATADIR/cooccurrence.bin
11 | COOCCURRENCE_SHUF_FILE=$DATADIR/cooccurrence.shuf.bin
12 | BUILDDIR=GloVe-1.2/build
13 | SAVE_FILE=$DATADIR/vectors
14 | VERBOSE=2
15 | MEMORY=4.0
16 | VOCAB_MIN_COUNT=100
17 | VECTOR_SIZE=200
18 | MAX_ITER=30
19 | WINDOW_SIZE=15
20 | BINARY=2
21 | NUM_THREADS=4
22 | X_MAX=10
23 |
24 | $BUILDDIR/vocab_count -min-count $VOCAB_MIN_COUNT -verbose $VERBOSE < $CORPUS > $VOCAB_FILE
25 | $BUILDDIR/cooccur -memory $MEMORY -vocab-file $VOCAB_FILE -verbose $VERBOSE -window-size $WINDOW_SIZE < $CORPUS > $COOCCURRENCE_FILE
26 | $BUILDDIR/shuffle -memory $MEMORY -verbose $VERBOSE < $COOCCURRENCE_FILE > $COOCCURRENCE_SHUF_FILE
27 | $BUILDDIR/glove -save-file $SAVE_FILE -eta 0.05 -threads $NUM_THREADS -input-file $COOCCURRENCE_SHUF_FILE -x-max $X_MAX -iter $MAX_ITER -vector-size $VECTOR_SIZE -binary $BINARY -vocab-file $VOCAB_FILE -verbose $VERBOSE
28 |
--------------------------------------------------------------------------------
/src/evaluation/evaluate_model_with_human_annotations.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import pdb
4 | import random
5 | from sklearn.metrics import average_precision_score
6 | import sys
7 |
8 | BAD_QUESTIONS="unix_56867 unix_136954 unix_160510 unix_138507".split() + \
9 | "askubuntu_791945 askubuntu_91332 askubuntu_704807 askubuntu_628216 askubuntu_688172 askubuntu_727993".split() + \
10 | "askubuntu_279488 askubuntu_624918 askubuntu_527314 askubuntu_182249 askubuntu_610081 askubuntu_613851 askubuntu_777774 askubuntu_624498".split() + \
11 | "superuser_356658 superuser_121201 superuser_455589 superuser_38460 superuser_739955 superuser_931151".split() + \
12 | "superuser_291105 superuser_627439 superuser_584013 superuser_399182 superuser_632675 superuser_706347".split() + \
13 | "superuser_670748 superuser_369878 superuser_830279 superuser_927242 superuser_850786".split()
14 |
15 | def get_annotations(line):
16 | set_info, post_id, best, valids, confidence = line.split(',')
17 | annotator_name = set_info.split('_')[0]
18 | sitename = set_info.split('_')[1]
19 | best = int(best)
20 | valids = [int(v) for v in valids.split()]
21 | confidence = int(confidence)
22 | return post_id, annotator_name, sitename, best, valids, confidence
23 |
24 | def calculate_precision(model_pred_indices, best, valids):
25 | bp1, bp3, bp5 = 0., 0., 0.
26 | vp1, vp3, vp5 = 0., 0., 0.
27 | bp1 = len(set(model_pred_indices[:1]).intersection(set(best)))*1.0
28 | bp3 = len(set(model_pred_indices[:3]).intersection(set(best)))*1.0/3
29 | bp5 = len(set(model_pred_indices[:5]).intersection(set(best)))*1.0/5
30 |
31 | vp1 = len(set(model_pred_indices[:1]).intersection(set(valids)))*1.0
32 | vp3 = len(set(model_pred_indices[:3]).intersection(set(valids)))*1.0/3
33 | vp5 = len(set(model_pred_indices[:5]).intersection(set(valids)))*1.0/5
34 | return bp1, bp3, bp5, vp1, vp3, vp5
35 |
36 | def calculate_avg_precision(model_probs, best, valids):
37 | best_bool = [0]*10
38 | valids_bool = [0]*10
39 | for i in range(10):
40 | if i in best:
41 | best_bool[i] = 1
42 | if i in valids:
43 | valids_bool[i] = 1
44 | bap = average_precision_score(best_bool, model_probs)
45 | if 1 in valids_bool:
46 | vap = average_precision_score(valids_bool, model_probs)
47 | else:
48 | vap = 0.
49 | if 1 in best_bool[1:]:
50 | bap_on9 = average_precision_score(best_bool[1:], model_probs[1:])
51 | else:
52 | bap_on9 = 0.
53 | if 1 in valids_bool[1:]:
54 | vap_on9 = average_precision_score(valids_bool[1:], model_probs[1:])
55 | else:
56 | vap_on9 = 0.
57 | return bap, vap, bap_on9, vap_on9
58 |
59 | def get_pred_indices(model_predictions, asc=False):
60 | preds = np.array(model_predictions)
61 | pred_indices = np.argsort(preds)
62 | if not asc:
63 | pred_indices = pred_indices[::-1] #since ascending sort and we want descending
64 | return pred_indices
65 |
66 | def convert_to_probalitites(model_predictions):
67 | tot = sum(model_predictions)
68 | model_probs = [0]*10
69 | for i,v in enumerate(model_predictions):
70 | model_probs[i] = v*1./tot
71 | return model_probs
72 |
73 | def evaluate_model_on_org(model_predictions, asc=False):
74 | br1_tot, br3_tot, br5_tot = 0., 0., 0.
75 | for post_id in model_predictions:
76 | model_probs = convert_to_probalitites(model_predictions[post_id])
77 | model_pred_indices = get_pred_indices(model_probs, asc)
78 | br1, br3, br5, vr1, vr3, vr5 = calculate_precision(model_pred_indices, [0], [0])
79 | br1_tot += br1
80 | br3_tot += br3
81 | br5_tot += br5
82 | N = len(model_predictions)
83 | return br1_tot*100.0/N, br3_tot*100.0/N, br5_tot*100.0/N
84 |
85 | def read_human_annotations(human_annotations_filename):
86 | human_annotations_file = open(human_annotations_filename, 'r')
87 | annotations = {}
88 | for line in human_annotations_file.readlines():
89 | line = line.strip('\n')
90 | splits = line.split('\t')
91 | post_id1, annotator_name1, sitename1, best1, valids1, confidence1 = get_annotations(splits[0])
92 | post_id2, annotator_name2, sitename2, best2, valids2, confidence2 = get_annotations(splits[1])
93 | assert(sitename1 == sitename2)
94 | assert(post_id1 == post_id2)
95 | post_id = sitename1+'_'+post_id1
96 | best_union = list(set([best1, best2]))
97 | valids_inter = list(set(valids1).intersection(set(valids2)))
98 | annotations[post_id] = (best_union, valids_inter)
99 | return annotations
100 |
101 | def evaluate_model(human_annotations_filename, model_predictions, asc=False):
102 | human_annotations_file = open(human_annotations_filename, 'r')
103 | br1_tot, br3_tot, br5_tot = 0., 0., 0.
104 | vr1_tot, vr3_tot, vr5_tot = 0., 0., 0.
105 | br1_on9_tot, br3_on9_tot, br5_on9_tot = 0., 0., 0.
106 | vr1_on9_tot, vr3_on9_tot, vr5_on9_tot = 0., 0., 0.
107 | bap_tot, vap_tot = 0., 0.
108 | bap_on9_tot, vap_on9_tot = 0., 0.
109 | N = 0
110 | for line in human_annotations_file.readlines():
111 | line = line.strip('\n')
112 | splits = line.split('\t')
113 | post_id1, annotator_name1, sitename1, best1, valids1, confidence1 = get_annotations(splits[0])
114 | post_id2, annotator_name2, sitename2, best2, valids2, confidence2 = get_annotations(splits[1])
115 | assert(sitename1 == sitename2)
116 | assert(post_id1 == post_id2)
117 | post_id = sitename1+'_'+post_id1
118 | if post_id in BAD_QUESTIONS:
119 | continue
120 | best_union = list(set([best1, best2]))
121 | valids_inter = list(set(valids1).intersection(set(valids2)))
122 | valids_union = list(set(valids1+valids2))
123 |
124 | 		model_probs = convert_to_probabilities(model_predictions[post_id])
125 | model_pred_indices = get_pred_indices(model_probs, asc)
126 |
127 | br1, br3, br5, vr1, vr3, vr5 = calculate_precision(model_pred_indices, best_union, valids_inter)
128 | br1_tot += br1
129 | br3_tot += br3
130 | br5_tot += br5
131 | vr1_tot += vr1
132 | vr3_tot += vr3
133 | vr5_tot += vr5
134 | bap, vap, bap_on9, vap_on9 = calculate_avg_precision(model_probs, best_union, valids_inter)
135 | bap_tot += bap
136 | vap_tot += vap
137 | bap_on9_tot += bap_on9
138 | vap_on9_tot += vap_on9
139 |
140 | model_pred_indices = np.delete(model_pred_indices, 0)
141 |
142 | br1_on9, br3_on9, br5_on9, vr1_on9, vr3_on9, vr5_on9 = calculate_precision(model_pred_indices, best_union, valids_inter)
143 |
144 | br1_on9_tot += br1_on9
145 | br3_on9_tot += br3_on9
146 | br5_on9_tot += br5_on9
147 | vr1_on9_tot += vr1_on9
148 | vr3_on9_tot += vr3_on9
149 | vr5_on9_tot += vr5_on9
150 |
151 | N += 1
152 |
153 | human_annotations_file.close()
154 | return br1_tot*100.0/N, br3_tot*100.0/N, br5_tot*100.0/N, vr1_tot*100.0/N, vr3_tot*100.0/N, vr5_tot*100.0/N, \
155 | br1_on9_tot*100.0/N, br3_on9_tot*100.0/N, br5_on9_tot*100.0/N, vr1_on9_tot*100.0/N, vr3_on9_tot*100.0/N, vr5_on9_tot*100.0/N, \
156 | bap_tot*100./N, vap_tot*100./N, bap_on9_tot*100./N, vap_on9_tot*100./N
157 |
158 | def read_model_predictions(model_predictions_file):
159 | model_predictions = {}
160 | for line in model_predictions_file.readlines():
161 | splits = line.strip('\n').split()
162 | post_id = splits[0][1:-2]
163 | predictions = [float(val) for val in splits[1:]]
164 | model_predictions[post_id] = predictions
165 | return model_predictions
166 |
167 | def print_numbers(br1, br3, br5, vr1, vr3, vr5, br1_on9, br3_on9, br5_on9, vr1_on9, vr3_on9, vr5_on9, bmap, vmap, bmap_on9, vmap_on9, br1_org, br3_org, br5_org):
168 | print 'Best'
169 | print 'p@1 %.1f' % (br1)
170 | print 'p@3 %.1f' % (br3)
171 | print 'p@5 %.1f' % (br5)
172 | print 'map %.1f' % (bmap)
173 | print
174 | print 'Valid'
175 | print 'p@1 %.1f' % (vr1)
176 | print 'p@3 %.1f' % (vr3)
177 | print 'p@5 %.1f' % (vr5)
178 | print 'map %.1f' % (vmap)
179 | print
180 | print 'Best on 9'
181 | print 'p@1 %.1f' % (br1_on9)
182 | print 'p@3 %.1f' % (br3_on9)
183 | print 'p@5 %.1f' % (br5_on9)
184 | print 'map %.1f' % (bmap_on9)
185 | print
186 | print 'Valid on 9'
187 | print 'p@1 %.1f' % (vr1_on9)
188 | print 'p@3 %.1f' % (vr3_on9)
189 | print 'p@5 %.1f' % (vr5_on9)
190 | print 'map %.1f' % (vmap_on9)
191 | print
192 | print 'Original'
193 | print 'p@1 %.1f' % (br1_org)
194 | #print 'p@3 %.1f' % (br3_org)
195 | #print 'p@5 %.1f' % (br5_org)
196 |
197 | def main(args):
198 | model_predictions_file = open(args.model_predictions_filename, 'r')
199 | asc=False
200 | model_predictions = read_model_predictions(model_predictions_file)
201 | br1, br3, br5, vr1, vr3, vr5, \
202 | br1_on9, br3_on9, br5_on9, \
203 | vr1_on9, vr3_on9, vr5_on9, \
204 | bmap, vmap, bmap_on9, vmap_on9 = evaluate_model(args.human_annotations_filename, model_predictions, asc)
205 | br1_org, br3_org, br5_org = evaluate_model_on_org(model_predictions)
206 |
207 | print_numbers(br1, br3, br5, vr1, vr3, vr5, \
208 | br1_on9, br3_on9, br5_on9, \
209 | vr1_on9, vr3_on9, vr5_on9, \
210 | bmap, vmap, bmap_on9, vmap_on9, \
211 | br1_org, br3_org, br5_org)
212 |
213 | if __name__ == '__main__':
214 | argparser = argparse.ArgumentParser(sys.argv[0])
215 | argparser.add_argument("--human_annotations_filename", type = str)
216 | argparser.add_argument("--model_predictions_filename", type = str)
217 | args = argparser.parse_args()
218 | print args
219 | print ""
220 | main(args)
221 |
222 |
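223 | # Illustrative input formats (the values below are hypothetical, inferred from the
224 | # parsing code above and from write_test_predictions in model_helper.py):
225 | #
226 | # A line of --model_predictions_filename holds one post id in brackets followed by
227 | # the ten candidate scores in their original (unshuffled) order, e.g.
228 | #   [askubuntu_217502]: 0.61 0.12 0.08 0.33 0.05 0.27 0.19 0.02 0.44 0.11
229 | #
230 | # A line of --human_annotations_filename holds the two annotators' judgements
231 | # separated by a tab; each judgement is "<setinfo>,<postid>,<best>,<valids>,<confidence>",
232 | # where <setinfo> looks like "<annotator>_<sitename>_...", <best> is the index of the
233 | # best candidate and <valids> is a space-separated list of valid indices, e.g.
234 | #   ann1_askubuntu_set1,217502,0,0 3 7,4<TAB>ann2_askubuntu_set1,217502,3,0 3,5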
--------------------------------------------------------------------------------
/src/evaluation/run_evaluation.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_DIR=data
4 | #SITE_NAME=askubuntu.com
5 | #SITE_NAME=unix.stackexchange.com
6 | #SITE_NAME=superuser.com
7 | SITE_NAME=askubuntu_unix_superuser
8 |
9 | SCRIPTS_DIR=src/evaluation
10 | #MODEL=baseline_pq
11 | #MODEL=baseline_pa
12 | #MODEL=baseline_pqa
13 | MODEL=evpi
14 |
15 | python $SCRIPTS_DIR/evaluate_model_with_human_annotations.py \
16 | --human_annotations_filename $DATA_DIR/$SITE_NAME/human_annotations \
17 | 				--model_predictions_filename $DATA_DIR/$SITE_NAME/test_predictions_${MODEL}.out.epoch13
18 |
--------------------------------------------------------------------------------
/src/models/README.md:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 |
3 | * Install lasagne: http://lasagne.readthedocs.io/en/latest/user/installation.html
4 | * Install numpy, scipy
5 | * Version information:
6 |
7 | Python 2.7.5
8 |
9 | Theano 0.9.0dev5
10 |
11 | Lasagne 0.2.dev1
12 |
13 | Cuda 8.0.44
14 |
15 | Cudnn 5.1
16 |
17 | # Loading data
18 |
19 | Load data from askubuntu.com
20 |
21 | * Set "SITE_NAME=askubuntu.com" in ranking_clarification_questions/src/models/run_load_data.sh
22 | * cd ranking_clarification_questions; sh src/models/run_load_data.sh
23 |
24 | Load data from unix.stackexchange.com
25 |
26 | * Set "SITE_NAME=unix.stackexchange.com" in ranking_clarification_questions/src/models/run_load_data.sh
27 | * cd ranking_clarification_questions; sh src/models/run_load_data.sh
28 |
29 | Load data from superuser.com
30 |
31 | * Set "SITE_NAME=superuser.com" in ranking_clarification_questions/src/models/run_load_data.sh
32 | * cd ranking_clarification_questions; sh src/models/run_load_data.sh
33 |
34 | Combine data from three domains
35 |
36 | * cd ranking_clarification_questions; sh src/models/run_combine_domains.sh
37 | * cat data/askubuntu.com/human_annotations data/unix.stackexchange.com/human_annotations data/superuser.com/human_annotations > data/askubuntu_unix_superuser/human_annotations
38 |
39 | # Running neural baselines on the combined data
40 |
41 | Neural(p,q)
42 |
43 | * Set "MODEL=baseline_pq" in ranking_clarification_questions/src/models/run_main.sh
44 | * cd ranking_clarification_questions; sh src/models/run_main.sh
45 |
46 | Neural(p,a)
47 |
48 | * Set "MODEL=baseline_pa" in ranking_clarification_questions/src/models/run_main.sh
49 | * cd ranking_clarification_questions; sh src/models/run_main.sh
50 |
51 | Neural(p,q,a)
52 |
53 | * Set "MODEL=baseline_pqa" in ranking_clarification_questions/src/models/run_main.sh
54 | * cd ranking_clarification_questions; sh src/models/run_main.sh
55 |
56 | # Running the EVPI model on the combined data
57 |
58 | * Set "MODEL=evpi" in ranking_clarification_questions/src/models/run_main.sh
59 | * cd ranking_clarification_questions; sh src/models/run_main.sh
60 |
61 | # Running evaluation
62 |
63 | * cd ranking_clarification_questions; sh src/evaluation/run_evaluation.sh
64 |
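65 | # Running main.py directly
66 | 
67 | run_main.sh is the reference entry point; as a rough sketch (the paths below are
68 | placeholders rather than the exact values used in the script), an equivalent
69 | direct invocation of the EVPI model on the combined data looks like:
70 | 
71 |     python src/models/main.py \
72 |         --post_ids_train data/askubuntu_unix_superuser/post_ids_train.p \
73 |         --post_vectors_train data/askubuntu_unix_superuser/post_vectors_train.p \
74 |         --ques_list_vectors_train data/askubuntu_unix_superuser/ques_list_vectors_train.p \
75 |         --ans_list_vectors_train data/askubuntu_unix_superuser/ans_list_vectors_train.p \
76 |         --post_ids_test data/askubuntu_unix_superuser/post_ids_test.p \
77 |         --post_vectors_test data/askubuntu_unix_superuser/post_vectors_test.p \
78 |         --ques_list_vectors_test data/askubuntu_unix_superuser/ques_list_vectors_test.p \
79 |         --ans_list_vectors_test data/askubuntu_unix_superuser/ans_list_vectors_test.p \
80 |         --word_embeddings <path/to/word_embeddings.p> \
81 |         --test_predictions_output data/askubuntu_unix_superuser/test_predictions_evpi.out \
82 |         --stdout_file data/askubuntu_unix_superuser/evpi.stdout \
83 |         --model evpi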
--------------------------------------------------------------------------------
/src/models/baseline_pa.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import theano, lasagne
4 | import numpy as np
5 | import cPickle as p
6 | import theano.tensor as T
7 | from collections import Counter
8 | import pdb
9 | import time
10 | import random, math
11 | DEPTH = 5
12 | from lstm_helper import *
13 | from model_helper import *
14 |
15 | def build(word_embeddings, len_voc, word_emb_dim, args, freeze=False):
16 |
17 | # input theano vars
18 | posts = T.imatrix()
19 | post_masks = T.fmatrix()
20 | ans_list = T.itensor3()
21 | ans_masks_list = T.ftensor3()
22 | labels = T.imatrix()
23 | N = args.no_of_candidates
24 |
25 | post_out, post_lstm_params = build_lstm(posts, post_masks, args.post_max_len, \
26 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
27 | ans_out, ans_emb_out, ans_lstm_params = build_list_lstm(ans_list, ans_masks_list, N, args.ans_max_len, \
28 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
29 |
30 | pa_preds = [None]*N
31 | post_ans = T.concatenate([post_out, ans_out[0]], axis=1)
32 | l_post_ans_in = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ans)
33 | l_post_ans_denses = [None]*DEPTH
34 | for k in range(DEPTH):
35 | if k == 0:
36 | l_post_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ans_in, num_units=args.hidden_dim,\
37 | nonlinearity=lasagne.nonlinearities.rectify)
38 | else:
39 | l_post_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ans_denses[k-1], num_units=args.hidden_dim,\
40 | nonlinearity=lasagne.nonlinearities.rectify)
41 | l_post_ans_dense = lasagne.layers.DenseLayer(l_post_ans_denses[-1], num_units=1,\
42 | nonlinearity=lasagne.nonlinearities.sigmoid)
43 | pa_preds[0] = lasagne.layers.get_output(l_post_ans_dense)
44 | loss = T.sum(lasagne.objectives.binary_crossentropy(pa_preds[0], labels[:,0]))
45 | for i in range(1, N):
46 | post_ans = T.concatenate([post_out, ans_out[i]], axis=1)
47 | l_post_ans_in_ = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ans)
48 | for k in range(DEPTH):
49 | if k == 0:
50 | l_post_ans_dense_ = lasagne.layers.DenseLayer(l_post_ans_in_, num_units=args.hidden_dim,\
51 | nonlinearity=lasagne.nonlinearities.rectify,\
52 | W=l_post_ans_denses[k].W,\
53 | b=l_post_ans_denses[k].b)
54 | else:
55 | l_post_ans_dense_ = lasagne.layers.DenseLayer(l_post_ans_dense_, num_units=args.hidden_dim,\
56 | nonlinearity=lasagne.nonlinearities.rectify,\
57 | W=l_post_ans_denses[k].W,\
58 | b=l_post_ans_denses[k].b)
59 | l_post_ans_dense_ = lasagne.layers.DenseLayer(l_post_ans_dense_, num_units=1,\
60 | nonlinearity=lasagne.nonlinearities.sigmoid)
61 | pa_preds[i] = lasagne.layers.get_output(l_post_ans_dense_)
62 | loss += T.sum(lasagne.objectives.binary_crossentropy(pa_preds[i], labels[:,i]))
63 |
64 | post_ans_dense_params = lasagne.layers.get_all_params(l_post_ans_dense, trainable=True)
65 |
66 | all_params = post_lstm_params + ans_lstm_params + post_ans_dense_params
67 | #print 'Params in concat ', lasagne.layers.count_params(l_post_ans_dense)
68 | loss += args.rho * sum(T.sum(l ** 2) for l in all_params)
69 |
70 | updates = lasagne.updates.adam(loss, all_params, learning_rate=args.learning_rate)
71 |
72 | train_fn = theano.function([posts, post_masks, ans_list, ans_masks_list, labels], \
73 | [loss] + pa_preds, updates=updates)
74 | test_fn = theano.function([posts, post_masks, ans_list, ans_masks_list, labels], \
75 | [loss] + pa_preds,)
76 | return train_fn, test_fn
77 |
78 | def validate(val_fn, fold_name, epoch, fold, args, out_file=None):
79 | start = time.time()
80 | num_batches = 0
81 | cost = 0
82 | corr = 0
83 | mrr = 0
84 | total = 0
85 | _lambda = 0.5
86 | N = args.no_of_candidates
87 | recall = [0]*N
88 |
89 | if out_file:
90 | out_file_o = open(out_file+'.epoch%d' % epoch, 'w')
91 | out_file_o.close()
92 | posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, post_ids = fold
93 | labels = np.zeros((len(post_ids), N), dtype=np.int32)
94 | ranks = np.zeros((len(post_ids), N), dtype=np.int32)
95 | labels[:,0] = 1
96 | for j in range(N):
97 | ranks[:,j] = j
98 | ques_list, ques_masks_list, ans_list, ans_masks_list, labels, ranks = shuffle(ques_list, ques_masks_list, \
99 | ans_list, ans_masks_list, labels, ranks)
100 | for p, pm, q, qm, a, am, l, r, ids in iterate_minibatches(posts, post_masks, ques_list, ques_masks_list, \
101 | ans_list, ans_masks_list, labels, ranks, \
102 | post_ids, args.batch_size, shuffle=False):
103 | a = np.transpose(a, (1, 0, 2))
104 | am = np.transpose(am, (1, 0, 2))
105 |
106 | pa_out = val_fn(p, pm, a, am, l)
107 | loss = pa_out[0]
108 | pa_preds = pa_out[1:]
109 | pa_preds = np.transpose(pa_preds, (1, 0, 2))
110 | pa_preds = pa_preds[:,:,0]
111 | cost += loss
112 | for j in range(args.batch_size):
113 | preds = [0.0]*N
114 | for k in range(N):
115 | preds[k] = pa_preds[j][k]
116 | rank = get_rank(preds, l[j])
117 | if rank == 1:
118 | corr += 1
119 | mrr += 1.0/rank
120 | for index in range(N):
121 | if rank <= index+1:
122 | recall[index] += 1
123 | total += 1
124 | if out_file:
125 | write_test_predictions(out_file, ids[j], preds, r[j], epoch)
126 | num_batches += 1
127 |
128 | lstring = '%s: epoch:%d, cost:%f, acc:%f, mrr:%f,time:%d' % \
129 | (fold_name, epoch, cost*1.0/num_batches, corr*1.0/total, mrr*1.0/total, time.time()-start)
130 |
131 | recall = [round(curr_r*1.0/total, 3) for curr_r in recall]
132 | recall_str = '['
133 | for r in recall:
134 | recall_str += '%.3f ' % r
135 | recall_str += ']\n'
136 |
137 | print lstring
138 | print recall
139 |
140 | def baseline_pa(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test):
141 | start = time.time()
142 | 	print 'compiling pa graph...'
143 | train_fn, test_fn, = build(word_embeddings, vocab_size, word_emb_dim, args, freeze=freeze)
144 | print 'done! Time taken: ', time.time()-start
145 |
146 | # train network
147 | for epoch in range(args.no_of_epochs):
148 | validate(train_fn, 'TRAIN', epoch, train, args)
149 | validate(test_fn, '\t TEST', epoch, test, args, args.test_predictions_output)
150 | print "\n"
151 |
--------------------------------------------------------------------------------
/src/models/baseline_pq.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import theano, lasagne
4 | import numpy as np
5 | import cPickle as p
6 | import theano.tensor as T
7 | from collections import Counter
8 | import pdb
9 | import time
10 | import random, math
11 | DEPTH = 5
12 | from lstm_helper import *
13 | from model_helper import *
14 |
15 | def build(word_embeddings, len_voc, word_emb_dim, args, freeze=False):
16 |
17 | # input theano vars
18 | posts = T.imatrix()
19 | post_masks = T.fmatrix()
20 | ques_list = T.itensor3()
21 | ques_masks_list = T.ftensor3()
22 | labels = T.imatrix()
23 | N = args.no_of_candidates
24 |
25 | post_out, post_lstm_params = build_lstm(posts, post_masks, args.post_max_len, \
26 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
27 | ques_out, ques_emb_out, ques_lstm_params = build_list_lstm(ques_list, ques_masks_list, N, args.ques_max_len, \
28 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
29 |
30 | pq_preds = [None]*N
31 | post_ques = T.concatenate([post_out, ques_out[0]], axis=1)
32 | l_post_ques_in = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ques)
33 | l_post_ques_denses = [None]*DEPTH
34 | for k in range(DEPTH):
35 | if k == 0:
36 | l_post_ques_denses[k] = lasagne.layers.DenseLayer(l_post_ques_in, num_units=args.hidden_dim,\
37 | nonlinearity=lasagne.nonlinearities.rectify)
38 | else:
39 | l_post_ques_denses[k] = lasagne.layers.DenseLayer(l_post_ques_denses[k-1], num_units=args.hidden_dim,\
40 | nonlinearity=lasagne.nonlinearities.rectify)
41 | l_post_ques_dense = lasagne.layers.DenseLayer(l_post_ques_denses[-1], num_units=1,\
42 | nonlinearity=lasagne.nonlinearities.sigmoid)
43 | pq_preds[0] = lasagne.layers.get_output(l_post_ques_dense)
44 | loss = T.sum(lasagne.objectives.binary_crossentropy(pq_preds[0], labels[:,0]))
45 | for i in range(1, N):
46 | post_ques = T.concatenate([post_out, ques_out[i]], axis=1)
47 | l_post_ques_in_ = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ques)
48 | for k in range(DEPTH):
49 | if k == 0:
50 | l_post_ques_dense_ = lasagne.layers.DenseLayer(l_post_ques_in_, num_units=args.hidden_dim,\
51 | nonlinearity=lasagne.nonlinearities.rectify,\
52 | W=l_post_ques_denses[k].W,\
53 | b=l_post_ques_denses[k].b)
54 | else:
55 | l_post_ques_dense_ = lasagne.layers.DenseLayer(l_post_ques_dense_, num_units=args.hidden_dim,\
56 | nonlinearity=lasagne.nonlinearities.rectify,\
57 | W=l_post_ques_denses[k].W,\
58 | b=l_post_ques_denses[k].b)
59 | l_post_ques_dense_ = lasagne.layers.DenseLayer(l_post_ques_dense_, num_units=1,\
60 | nonlinearity=lasagne.nonlinearities.sigmoid)
61 | pq_preds[i] = lasagne.layers.get_output(l_post_ques_dense_)
62 | loss += T.sum(lasagne.objectives.binary_crossentropy(pq_preds[i], labels[:,i]))
63 |
64 | post_ques_dense_params = lasagne.layers.get_all_params(l_post_ques_dense, trainable=True)
65 |
66 | all_params = post_lstm_params + ques_lstm_params + post_ques_dense_params
67 | #print 'Params in concat ', lasagne.layers.count_params(l_post_ques_dense)
68 | loss += args.rho * sum(T.sum(l ** 2) for l in all_params)
69 |
70 | updates = lasagne.updates.adam(loss, all_params, learning_rate=args.learning_rate)
71 |
72 | train_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, labels], \
73 | [loss] + pq_preds, updates=updates)
74 | test_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, labels], \
75 | [loss] + pq_preds,)
76 | return train_fn, test_fn
77 |
78 | def validate(val_fn, fold_name, epoch, fold, args, out_file=None):
79 | start = time.time()
80 | num_batches = 0
81 | cost = 0
82 | corr = 0
83 | mrr = 0
84 | total = 0
85 | _lambda = 0.5
86 | N = args.no_of_candidates
87 | recall = [0]*N
88 |
89 | if out_file:
90 | out_file_o = open(out_file+'.epoch%d' % epoch, 'w')
91 | out_file_o.close()
92 | posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, post_ids = fold
93 | labels = np.zeros((len(post_ids), N), dtype=np.int32)
94 | ranks = np.zeros((len(post_ids), N), dtype=np.int32)
95 | labels[:,0] = 1
96 | for j in range(N):
97 | ranks[:,j] = j
98 | ques_list, ques_masks_list, ans_list, ans_masks_list, labels, ranks = shuffle(ques_list, ques_masks_list, \
99 | ans_list, ans_masks_list, labels, ranks)
100 | for p, pm, q, qm, a, am, l, r, ids in iterate_minibatches(posts, post_masks, ques_list, ques_masks_list, \
101 | ans_list, ans_masks_list, labels, ranks, \
102 | post_ids, args.batch_size, shuffle=False):
103 | q = np.transpose(q, (1, 0, 2))
104 | qm = np.transpose(qm, (1, 0, 2))
105 |
106 | pq_out = val_fn(p, pm, q, qm, l)
107 | loss = pq_out[0]
108 | pq_preds = pq_out[1:]
109 | pq_preds = np.transpose(pq_preds, (1, 0, 2))
110 | pq_preds = pq_preds[:,:,0]
111 | cost += loss
112 | for j in range(args.batch_size):
113 | preds = [0.0]*N
114 | for k in range(N):
115 | preds[k] = pq_preds[j][k]
116 | rank = get_rank(preds, l[j])
117 | if rank == 1:
118 | corr += 1
119 | mrr += 1.0/rank
120 | for index in range(N):
121 | if rank <= index+1:
122 | recall[index] += 1
123 | total += 1
124 | if out_file:
125 | write_test_predictions(out_file, ids[j], preds, r[j], epoch)
126 | num_batches += 1
127 |
128 | lstring = '%s: epoch:%d, cost:%f, acc:%f, mrr:%f,time:%d' % \
129 | (fold_name, epoch, cost*1.0/num_batches, corr*1.0/total, mrr*1.0/total, time.time()-start)
130 |
131 | recall = [round(curr_r*1.0/total, 3) for curr_r in recall]
132 | recall_str = '['
133 | for r in recall:
134 | recall_str += '%.3f ' % r
135 | recall_str += ']\n'
136 |
137 | print lstring
138 | print recall
139 |
140 | def baseline_pq(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test):
141 | start = time.time()
142 | print 'compiling pq graph...'
143 | train_fn, test_fn, = build(word_embeddings, vocab_size, word_emb_dim, args, freeze=freeze)
144 | print 'done! Time taken: ', time.time()-start
145 |
146 | # train network
147 | for epoch in range(args.no_of_epochs):
148 | validate(train_fn, 'TRAIN', epoch, train, args)
149 | validate(test_fn, '\t TEST', epoch, test, args, args.test_predictions_output)
150 | print "\n"
151 |
--------------------------------------------------------------------------------
/src/models/baseline_pqa.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import theano, lasagne
3 | import numpy as np
4 | import cPickle as p
5 | import theano.tensor as T
6 | import pdb
7 | import time
8 | DEPTH = 5
9 | from lstm_helper import *
10 | from model_helper import *
11 |
12 | def build(word_embeddings, len_voc, word_emb_dim, args, freeze=False):
13 |
14 | # input theano vars
15 | posts = T.imatrix()
16 | post_masks = T.fmatrix()
17 | ques_list = T.itensor3()
18 | ques_masks_list = T.ftensor3()
19 | ans_list = T.itensor3()
20 | ans_masks_list = T.ftensor3()
21 | labels = T.imatrix()
22 | N = args.no_of_candidates
23 |
24 | post_out, post_lstm_params = build_lstm(posts, post_masks, args.post_max_len, \
25 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
26 | ques_out, ques_emb_out, ques_lstm_params = build_list_lstm(ques_list, ques_masks_list, N, args.ques_max_len, \
27 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
28 | ans_out, ans_emb_out, ans_lstm_params = build_list_lstm(ans_list, ans_masks_list, N, args.ans_max_len, \
29 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
30 |
31 | pqa_preds = [None]*(N*N)
32 | post_ques_ans = T.concatenate([post_out, ques_out[0], ans_out[0]], axis=1)
33 | l_post_ques_ans_in = lasagne.layers.InputLayer(shape=(args.batch_size, 3*args.hidden_dim), input_var=post_ques_ans)
34 | l_post_ques_ans_denses = [None]*DEPTH
35 | for k in range(DEPTH):
36 | if k == 0:
37 | l_post_ques_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ques_ans_in, num_units=args.hidden_dim,\
38 | nonlinearity=lasagne.nonlinearities.rectify)
39 | else:
40 | l_post_ques_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ques_ans_denses[k-1], num_units=args.hidden_dim,\
41 | nonlinearity=lasagne.nonlinearities.rectify)
42 | l_post_ques_ans_dense = lasagne.layers.DenseLayer(l_post_ques_ans_denses[-1], num_units=1,\
43 | nonlinearity=lasagne.nonlinearities.sigmoid)
44 | pqa_preds[0] = lasagne.layers.get_output(l_post_ques_ans_dense)
45 | loss = 0.0
46 | for i in range(N):
47 | for j in range(N):
48 | if i == 0 and j == 0:
49 | continue
50 | post_ques_ans = T.concatenate([post_out, ques_out[i], ans_out[j]], axis=1)
51 | l_post_ques_ans_in_ = lasagne.layers.InputLayer(shape=(args.batch_size, 3*args.hidden_dim), input_var=post_ques_ans)
52 | for k in range(DEPTH):
53 | if k == 0:
54 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_in_, num_units=args.hidden_dim,\
55 | nonlinearity=lasagne.nonlinearities.rectify,\
56 | W=l_post_ques_ans_denses[k].W,\
57 | b=l_post_ques_ans_denses[k].b)
58 | else:
59 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_dense_, num_units=args.hidden_dim,\
60 | nonlinearity=lasagne.nonlinearities.rectify,\
61 | W=l_post_ques_ans_denses[k].W,\
62 | b=l_post_ques_ans_denses[k].b)
63 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_dense_, num_units=1,\
64 | nonlinearity=lasagne.nonlinearities.sigmoid)
65 | pqa_preds[i*N+j] = lasagne.layers.get_output(l_post_ques_ans_dense_)
66 | loss += T.mean(lasagne.objectives.binary_crossentropy(pqa_preds[i*N+i], labels[:,i]))
67 |
68 | squared_errors = [None]*(N*N)
69 | for i in range(N):
70 | for j in range(N):
71 | squared_errors[i*N+j] = lasagne.objectives.squared_error(ans_out[i], ans_out[j])
72 | post_ques_ans_dense_params = lasagne.layers.get_all_params(l_post_ques_ans_dense, trainable=True)
73 |
74 | all_params = post_lstm_params + ques_lstm_params + ans_lstm_params + post_ques_ans_dense_params
75 | #print 'Params in concat ', lasagne.layers.count_params(l_post_ques_ans_dense)
76 | loss += args.rho * sum(T.sum(l ** 2) for l in all_params)
77 |
78 | updates = lasagne.updates.adam(loss, all_params, learning_rate=args.learning_rate)
79 |
80 | train_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, labels], \
81 | [loss] + pqa_preds + squared_errors, updates=updates)
82 | test_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, labels], \
83 | [loss] + pqa_preds + squared_errors,)
84 | return train_fn, test_fn
85 |
86 | def validate(val_fn, fold_name, epoch, fold, args, out_file=None):
87 | start = time.time()
88 | num_batches = 0
89 | cost = 0
90 | corr = 0
91 | mrr = 0
92 | total = 0
93 | _lambda = 0.5
94 | N = args.no_of_candidates
95 | recall = [0]*N
96 | batch_size = args.batch_size
97 |
98 | if out_file:
99 | out_file_o = open(out_file+'.epoch%d' % epoch, 'w')
100 | out_file_o.close()
101 | posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, post_ids = fold
102 | labels = np.zeros((len(post_ids), N), dtype=np.int32)
103 | ranks = np.zeros((len(post_ids), N), dtype=np.int32)
104 | labels[:,0] = 1
105 | for j in range(N):
106 | ranks[:,j] = j
107 | ques_list, ques_masks_list, ans_list, ans_masks_list, labels, ranks = shuffle(ques_list, ques_masks_list, \
108 | ans_list, ans_masks_list, labels, ranks)
109 | for p, pm, q, qm, a, am, l, r, ids in iterate_minibatches(posts, post_masks, ques_list, ques_masks_list, \
110 | ans_list, ans_masks_list, labels, ranks, \
111 | post_ids, args.batch_size, shuffle=False):
112 | q = np.transpose(q, (1, 0, 2))
113 | qm = np.transpose(qm, (1, 0, 2))
114 | a = np.transpose(a, (1, 0, 2))
115 | am = np.transpose(am, (1, 0, 2))
116 |
117 | out = val_fn(p, pm, q, qm, a, am, l)
118 | loss = out[0]
119 | probs = out[1:1+N*N]
120 | errors = out[1+N*N:]
121 | probs = np.transpose(probs, (1, 0, 2))
122 | probs = probs[:,:,0]
123 | errors = np.transpose(errors, (1, 0, 2))
124 | errors = errors[:,:,0]
125 | cost += loss
126 | for j in range(batch_size):
127 | preds = [0.0]*N
128 | for k in range(N):
129 | preds[k] = probs[j][k*N+k]
130 | rank = get_rank(preds, l[j])
131 | if rank == 1:
132 | corr += 1
133 | mrr += 1.0/rank
134 | for index in range(N):
135 | if rank <= index+1:
136 | recall[index] += 1
137 | total += 1
138 | if out_file:
139 | write_test_predictions(out_file, ids[j], preds, r[j], epoch)
140 | num_batches += 1
141 |
142 | lstring = '%s: epoch:%d, cost:%f, acc:%f, mrr:%f,time:%d' % \
143 | (fold_name, epoch, cost*1.0/num_batches, corr*1.0/total, mrr*1.0/total, time.time()-start)
144 |
145 | recall = [round(curr_r*1.0/total, 3) for curr_r in recall]
146 | recall_str = '['
147 | for r in recall:
148 | recall_str += '%.3f ' % r
149 | recall_str += ']\n'
150 |
151 | print lstring
152 | print recall
153 |
154 | def baseline_pqa(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test):
155 | start = time.time()
156 | print 'compiling pqa graph...'
157 | train_fn, test_fn, = build(word_embeddings, vocab_size, word_emb_dim, args, freeze=freeze)
158 | print 'done! Time taken: ', time.time()-start
159 |
160 | # train network
161 | for epoch in range(args.no_of_epochs):
162 | validate(train_fn, 'TRAIN', epoch, train, args)
163 | validate(test_fn, '\t TEST', epoch, test, args, args.test_predictions_output)
164 | print "\n"
165 |
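166 | # Note on ranking: build() scores every (question i, answer j) pair, but
167 | # validate() ranks the candidate questions using the diagonal entries only,
168 | # i.e. preds[k] = probs[k*N+k], the utility of question k paired with its own
169 | # answer. The off-diagonal outputs and the squared_errors between answer
170 | # encodings are returned by the theano functions but not used in this ranking.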
--------------------------------------------------------------------------------
/src/models/combine_pickle.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import cPickle as p
3 |
4 | if __name__ == "__main__":
5 | askubuntu = p.load(open(sys.argv[1], 'rb'))
6 | unix = p.load(open(sys.argv[2], 'rb'))
7 | superuser = p.load(open(sys.argv[3], 'rb'))
8 | combined = unix + superuser + askubuntu
9 | p.dump(combined, open(sys.argv[4], 'wb'))
10 |
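11 | # Usage (invoked repeatedly by run_combine_domains.sh):
12 | #   python combine_pickle.py ASKUBUNTU.p UNIX.p SUPERUSER.p COMBINED.p
13 | # Each input is a pickled list; the output is their concatenation
14 | # (unix + superuser + askubuntu, in that order).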
--------------------------------------------------------------------------------
/src/models/evpi.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import argparse
3 | import theano, lasagne
4 | import numpy as np
5 | import cPickle as p
6 | import theano.tensor as T
7 | from collections import Counter
8 | import pdb
9 | import time
10 | import random, math
11 | DEPTH = 3
12 | DEPTH_A = 2
13 | from lstm_helper import *
14 | from model_helper import *
15 |
16 | def cos_sim_fn(v1, v2):
17 | numerator = T.sum(v1*v2, axis=1)
18 | denominator = T.sqrt(T.sum(v1**2, axis=1) * T.sum(v2**2, axis=1))
19 | val = numerator/denominator
20 | return T.gt(val,0) * val + T.le(val,0) * 0.001
21 |
22 | def custom_sim_fn(v1, v2):
23 | val = cos_sim_fn(v1, v2)
24 | val = val - 0.95
25 | return T.gt(val,0) * T.exp(val)
26 |
27 | def answer_model(post_out, ques_out, ques_emb_out, ans_out, ans_emb_out, labels, args):
28 | # Pr(a|p,q)
29 | N = args.no_of_candidates
30 | post_ques = T.concatenate([post_out, ques_out[0]], axis=1)
31 | hidden_dim = 200
32 | l_post_ques_in = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ques)
33 | l_post_ques_denses = [None]*DEPTH_A
34 | for k in range(DEPTH_A):
35 | if k == 0:
36 | l_post_ques_denses[k] = lasagne.layers.DenseLayer(l_post_ques_in, num_units=hidden_dim,\
37 | nonlinearity=lasagne.nonlinearities.rectify)
38 | else:
39 | l_post_ques_denses[k] = lasagne.layers.DenseLayer(l_post_ques_denses[k-1], num_units=hidden_dim,\
40 | nonlinearity=lasagne.nonlinearities.rectify)
41 |
42 | l_post_ques_dense = lasagne.layers.DenseLayer(l_post_ques_denses[-1], num_units=1,\
43 | nonlinearity=lasagne.nonlinearities.sigmoid)
44 | post_ques_dense_params = lasagne.layers.get_all_params(l_post_ques_denses, trainable=True)
45 | #print 'Params in post_ques ', lasagne.layers.count_params(l_post_ques_denses)
46 |
47 | for i in range(1, N):
48 | post_ques = T.concatenate([post_out, ques_out[i]], axis=1)
49 | l_post_ques_in_ = lasagne.layers.InputLayer(shape=(args.batch_size, 2*args.hidden_dim), input_var=post_ques)
50 | for k in range(DEPTH_A):
51 | if k == 0:
52 | l_post_ques_dense_ = lasagne.layers.DenseLayer(l_post_ques_in_, num_units=hidden_dim,\
53 | nonlinearity=lasagne.nonlinearities.rectify,\
54 | W=l_post_ques_denses[k].W,\
55 | b=l_post_ques_denses[k].b)
56 | else:
57 | l_post_ques_dense_ = lasagne.layers.DenseLayer(l_post_ques_dense_, num_units=hidden_dim,\
58 | nonlinearity=lasagne.nonlinearities.rectify,\
59 | W=l_post_ques_denses[k].W,\
60 | b=l_post_ques_denses[k].b)
61 | l_post_ques_dense_ = lasagne.layers.DenseLayer(l_post_ques_dense_, num_units=1,\
62 | nonlinearity=lasagne.nonlinearities.sigmoid)
63 |
64 | ques_sim = [None]*(N*N)
65 | pq_a_squared_errors = [None]*(N*N)
66 | for i in range(N):
67 | for j in range(N):
68 | ques_sim[i*N+j] = custom_sim_fn(ques_emb_out[i], ques_emb_out[j])
69 | pq_a_squared_errors[i*N+j] = 1-cos_sim_fn(ques_emb_out[i], ans_emb_out[j])
70 |
71 | pq_a_loss = 0.0
72 | for i in range(N):
73 | pq_a_loss += T.mean(labels[:,i] * pq_a_squared_errors[i*N+i])
74 | for j in range(N):
75 | if i == j:
76 | continue
77 | pq_a_loss += T.mean(labels[:,i] * pq_a_squared_errors[i*N+j] * ques_sim[i*N+j])
78 |
79 | return ques_sim, pq_a_squared_errors, pq_a_loss, post_ques_dense_params
80 |
81 | def utility_calculator(post_out, ques_out, ques_emb_out, ans_out, ques_sim, pq_a_squared_errors, labels, args):
82 | # U(p+a)
83 | N = args.no_of_candidates
84 | pqa_loss = 0.0
85 | pqa_preds = [None]*(N*N)
86 | post_ques_ans = T.concatenate([post_out, ques_out[0], ans_out[0]], axis=1)
87 | l_post_ques_ans_in = lasagne.layers.InputLayer(shape=(args.batch_size, 3*args.hidden_dim), input_var=post_ques_ans)
88 | l_post_ques_ans_denses = [None]*DEPTH
89 | for k in range(DEPTH):
90 | if k == 0:
91 | l_post_ques_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ques_ans_in, num_units=args.hidden_dim,\
92 | nonlinearity=lasagne.nonlinearities.rectify)
93 | else:
94 | l_post_ques_ans_denses[k] = lasagne.layers.DenseLayer(l_post_ques_ans_denses[k-1], num_units=args.hidden_dim,\
95 | nonlinearity=lasagne.nonlinearities.rectify)
96 | l_post_ques_ans_dense = lasagne.layers.DenseLayer(l_post_ques_ans_denses[-1], num_units=1,\
97 | nonlinearity=lasagne.nonlinearities.sigmoid)
98 | pqa_preds[0] = lasagne.layers.get_output(l_post_ques_ans_dense)
99 |
100 | pqa_loss += T.mean(lasagne.objectives.binary_crossentropy(pqa_preds[0], labels[:,0]))
101 |
102 | for i in range(N):
103 | for j in range(N):
104 | if i == 0 and j == 0:
105 | continue
106 | post_ques_ans = T.concatenate([post_out, ques_out[i], ans_out[j]], axis=1)
107 | l_post_ques_ans_in_ = lasagne.layers.InputLayer(shape=(args.batch_size, 3*args.hidden_dim), input_var=post_ques_ans)
108 | for k in range(DEPTH):
109 | if k == 0:
110 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_in_, num_units=args.hidden_dim,\
111 | nonlinearity=lasagne.nonlinearities.rectify,\
112 | W=l_post_ques_ans_denses[k].W,\
113 | b=l_post_ques_ans_denses[k].b)
114 | else:
115 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_dense_, num_units=args.hidden_dim,\
116 | nonlinearity=lasagne.nonlinearities.rectify,\
117 | W=l_post_ques_ans_denses[k].W,\
118 | b=l_post_ques_ans_denses[k].b)
119 | l_post_ques_ans_dense_ = lasagne.layers.DenseLayer(l_post_ques_ans_dense_, num_units=1,\
120 | nonlinearity=lasagne.nonlinearities.sigmoid)
121 | pqa_preds[i*N+j] = lasagne.layers.get_output(l_post_ques_ans_dense_)
122 | pqa_loss += T.mean(lasagne.objectives.binary_crossentropy(pqa_preds[i*N+i], labels[:,i]))
123 | post_ques_ans_dense_params = lasagne.layers.get_all_params(l_post_ques_ans_dense, trainable=True)
124 | #print 'Params in post_ques_ans ', lasagne.layers.count_params(l_post_ques_ans_dense)
125 |
126 | return pqa_loss, post_ques_ans_dense_params, pqa_preds
127 |
128 | def build(word_embeddings, len_voc, word_emb_dim, args, freeze=False):
129 | # input theano vars
130 | posts = T.imatrix()
131 | post_masks = T.fmatrix()
132 | ques_list = T.itensor3()
133 | ques_masks_list = T.ftensor3()
134 | ans_list = T.itensor3()
135 | ans_masks_list = T.ftensor3()
136 | labels = T.imatrix()
137 | N = args.no_of_candidates
138 |
139 | post_out, post_lstm_params = build_lstm(posts, post_masks, args.post_max_len, \
140 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
141 | ques_out, ques_emb_out, ques_lstm_params = build_list_lstm(ques_list, ques_masks_list, N, args.ques_max_len, \
142 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
143 | ans_out, ans_emb_out, ans_lstm_params = build_list_lstm(ans_list, ans_masks_list, N, args.ans_max_len, \
144 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
145 |
146 | ques_sim, pq_a_squared_errors, pq_a_loss, post_ques_dense_params \
147 | = answer_model(post_out, ques_out, ques_emb_out, ans_out, ans_emb_out, labels, args)
148 |
149 | all_params = post_lstm_params + ques_lstm_params + post_ques_dense_params
150 |
151 | post_out, post_lstm_params = build_lstm(posts, post_masks, args.post_max_len, \
152 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
153 | ques_out, ques_emb_out, ques_lstm_params = build_list_lstm(ques_list, ques_masks_list, N, args.ques_max_len, \
154 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
155 | ans_out, ans_emb_out, ans_lstm_params = build_list_lstm(ans_list, ans_masks_list, N, args.ans_max_len, \
156 | word_embeddings, word_emb_dim, args.hidden_dim, len_voc, args.batch_size)
157 |
158 | pqa_loss, post_ques_ans_dense_params, pqa_preds = utility_calculator(post_out, ques_out, ques_emb_out, ans_out, \
159 | ques_sim, pq_a_squared_errors, labels, args)
160 |
161 | all_params += post_lstm_params + ques_lstm_params + ans_lstm_params
162 | all_params += post_ques_ans_dense_params
163 |
164 | loss = pq_a_loss + pqa_loss
165 | loss += args.rho * sum(T.sum(l ** 2) for l in all_params)
166 |
167 | updates = lasagne.updates.adam(loss, all_params, learning_rate=args.learning_rate)
168 |
169 | train_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, labels], \
170 | [loss, pq_a_loss, pqa_loss] + pq_a_squared_errors + ques_sim + pqa_preds, updates=updates)
171 | test_fn = theano.function([posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, labels], \
172 | [loss, pq_a_loss, pqa_loss] + pq_a_squared_errors + ques_sim + pqa_preds,)
173 | return train_fn, test_fn
174 |
175 | def validate(val_fn, fold_name, epoch, fold, args, out_file=None):
176 | start = time.time()
177 | num_batches = 0
178 | cost = 0
179 | pq_a_cost = 0
180 | utility_cost = 0
181 | corr = 0
182 | mrr = 0
183 | total = 0
184 | N = args.no_of_candidates
185 | recall = [0]*N
186 |
187 | if out_file:
188 | out_file_o = open(out_file+'.epoch%d' % epoch, 'w')
189 | out_file_o.close()
190 | posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, post_ids = fold
191 | labels = np.zeros((len(post_ids), N), dtype=np.int32)
192 | ranks = np.zeros((len(post_ids), N), dtype=np.int32)
193 | labels[:,0] = 1
194 | for j in range(N):
195 | ranks[:,j] = j
196 | ques_list, ques_masks_list, ans_list, ans_masks_list, labels, ranks = shuffle(ques_list, ques_masks_list, \
197 | ans_list, ans_masks_list, labels, ranks)
198 | for p, pm, q, qm, a, am, l, r, ids in iterate_minibatches(posts, post_masks, ques_list, ques_masks_list, \
199 | ans_list, ans_masks_list, labels, ranks, \
200 | post_ids, args.batch_size, shuffle=False):
201 | q = np.transpose(q, (1, 0, 2))
202 | qm = np.transpose(qm, (1, 0, 2))
203 | a = np.transpose(a, (1, 0, 2))
204 | am = np.transpose(am, (1, 0, 2))
205 |
206 | out = val_fn(p, pm, q, qm, a, am, l)
207 | loss = out[0]
208 | pq_a_loss = out[1]
209 | pqa_loss = out[2]
210 |
211 | pq_a_errors = out[3: 3+(N*N)]
212 | pq_a_errors = np.transpose(pq_a_errors)
213 |
214 | ques_sim = out[3+(N*N): 3+(N*N)+(N*N)]
215 | ques_sim = np.transpose(ques_sim)
216 |
217 | pqa_preds = out[3+(N*N)+(N*N):]
218 | pqa_preds = np.array(pqa_preds)[:,:,0]
219 | pqa_preds = np.transpose(pqa_preds)
220 |
221 | cost += loss
222 | pq_a_cost += pq_a_loss
223 | utility_cost += pqa_loss
224 |
225 | for j in range(args.batch_size):
226 | preds = [0.0]*N
227 | for k in range(N):
228 | preds[k] = pqa_preds[j][k*N+k]
229 | for m in range(N):
230 | if m == k:
231 | continue
232 | preds[k] += 0.1*math.exp(pq_a_errors[j][k*N+m])*ques_sim[j][k*N+m] * pqa_preds[j][k*N+m]
233 | rank = get_rank(preds, l[j])
234 | if rank == 1:
235 | corr += 1
236 | mrr += 1.0/rank
237 | for index in range(N):
238 | if rank <= index+1:
239 | recall[index] += 1
240 | total += 1
241 | if out_file:
242 | write_test_predictions(out_file, ids[j], preds, r[j], epoch)
243 | num_batches += 1
244 |
245 | lstring = '%s: epoch:%d, cost:%f, pq_a_cost:%f, utility_cost:%f, acc:%f, mrr:%f,time:%d\n' % \
246 | (fold_name, epoch, cost*1.0/num_batches, pq_a_cost*1.0/num_batches, utility_cost*1.0/num_batches, \
247 | corr*1.0/total, mrr*1.0/total, time.time()-start)
248 | recall = [round(curr_r*1.0/total, 3) for curr_r in recall]
249 | recall_str = '['
250 | for r in recall:
251 | recall_str += '%.3f ' % r
252 | recall_str += ']\n'
253 |
254 | outfile = open(args.stdout_file, 'a')
255 | outfile.write(lstring+'\n')
256 | outfile.write(recall_str+'\n')
257 | outfile.close()
258 |
259 | print lstring
260 | print recall
261 |
262 | def evpi(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test):
263 | outfile = open(args.stdout_file, 'w')
264 | outfile.close()
265 |
266 | print 'Compiling graph...'
267 | start = time.time()
268 | train_fn, test_fn = build(word_embeddings, vocab_size, word_emb_dim, args, freeze=freeze)
269 | print 'done! Time taken: ', time.time() - start
270 |
271 | # train network
272 | for epoch in range(args.no_of_epochs):
273 | validate(train_fn, 'TRAIN', epoch, train, args)
274 | validate(test_fn, '\t TEST', epoch, test, args, args.test_predictions_output)
275 | print "\n"
276 |
277 |
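278 | # A minimal restatement of the scoring rule used in validate() above, kept here
279 | # only as documentation (it is not called by the training code). For one example,
280 | # pqa_preds, pq_a_errors and ques_sim are flat arrays of length N*N indexed as [i*N+j].
281 | def evpi_scores(pqa_preds, pq_a_errors, ques_sim, N, weight=0.1):
282 | 	scores = [0.0]*N
283 | 	for k in range(N):
284 | 		# utility of question k paired with its own candidate answer ...
285 | 		scores[k] = pqa_preds[k*N+k]
286 | 		for m in range(N):
287 | 			if m == k:
288 | 				continue
289 | 			# ... plus its utility under the other answers, weighted by how
290 | 			# plausible those answers are for question k (similarity * exp(error))
291 | 			scores[k] += weight * math.exp(pq_a_errors[k*N+m]) * ques_sim[k*N+m] * pqa_preds[k*N+m]
292 | 	return scores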
--------------------------------------------------------------------------------
/src/models/load_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | import sys
5 | import cPickle as p
6 | import numpy as np
7 |
8 | def get_indices(tokens, vocab):
9 | indices = np.zeros([len(tokens)], dtype=np.int32)
10 | UNK = ""
11 | for i, w in enumerate(tokens):
12 | try:
13 | indices[i] = vocab[w]
14 | except:
15 | indices[i] = vocab[UNK]
16 | return indices
17 |
18 | def read_data(post_data_tsv, qa_data_tsv):
19 | posts = {}
20 | titles = {}
21 | ques_lists = {}
22 | ans_lists = {}
23 | with open(post_data_tsv, 'rb') as tsvfile:
24 | tsv_reader = csv.DictReader(tsvfile, delimiter='\t')
25 | for row in tsv_reader:
26 | post_id = row['postid']
27 | titles[post_id] = row['title']
28 | posts[post_id] = row['post']
29 | with open(qa_data_tsv, 'rb') as tsvfile:
30 | tsv_reader = csv.DictReader(tsvfile, delimiter='\t')
31 | for row in tsv_reader:
32 | post_id = row['postid']
33 | ques_lists[post_id] = [row['q1'], row['q2'], row['q3'], row['q4'], row['q5'], row['q6'], row['q7'], row['q8'], row['q9'], row['q10']]
34 | ans_lists[post_id] = [row['a1'], row['a2'], row['a3'], row['a4'], row['a5'], row['a6'], row['a7'], row['a8'], row['a9'], row['a10']]
35 | return posts, titles, ques_lists, ans_lists
36 |
37 | def read_ids(ids_file):
38 | ids = [curr_id.strip('\n') for curr_id in open(ids_file, 'r').readlines()]
39 | return ids
40 |
41 | def generate_neural_vectors(posts, titles, ques_lists, ans_lists, post_ids, vocab, N, split):
42 | post_vectors = []
43 | ques_list_vectors = []
44 | ans_list_vectors = []
45 | for post_id in post_ids:
46 | 		post_vectors.append(get_indices((titles[post_id] + ' ' + posts[post_id]).split(), vocab)) # tokenize: get_indices expects a sequence of tokens, not a raw string
47 | ques_list_vector = [None]*N
48 | ans_list_vector = [None]*N
49 | for k in range(N):
50 | 			ques_list_vector[k] = get_indices(ques_lists[post_id][k].split(), vocab)
51 | 			ans_list_vector[k] = get_indices(ans_lists[post_id][k].split(), vocab)
52 | ques_list_vectors.append(ques_list_vector)
53 | ans_list_vectors.append(ans_list_vector)
54 | dirname = os.path.dirname(args.train_ids)
55 | p.dump(post_ids, open(os.path.join(dirname, 'post_ids_'+split+'.p'), 'wb'))
56 | p.dump(post_vectors, open(os.path.join(dirname, 'post_vectors_'+split+'.p'), 'wb'))
57 | p.dump(ques_list_vectors, open(os.path.join(dirname, 'ques_list_vectors_'+split+'.p'), 'wb'))
58 | p.dump(ans_list_vectors, open(os.path.join(dirname, 'ans_list_vectors_'+split+'.p'), 'wb'))
59 |
60 | def main(args):
61 | vocab = p.load(open(args.vocab, 'rb'))
62 | train_ids = read_ids(args.train_ids)
63 | tune_ids = read_ids(args.tune_ids)
64 | test_ids = read_ids(args.test_ids)
65 | N = args.no_of_candidates
66 | posts, titles, ques_lists, ans_lists = read_data(args.post_data_tsv, args.qa_data_tsv)
67 | generate_neural_vectors(posts, titles, ques_lists, ans_lists, train_ids, vocab, N, 'train')
68 | generate_neural_vectors(posts, titles, ques_lists, ans_lists, tune_ids, vocab, N, 'tune')
69 | generate_neural_vectors(posts, titles, ques_lists, ans_lists, test_ids, vocab, N, 'test')
70 |
71 | if __name__ == "__main__":
72 | argparser = argparse.ArgumentParser(sys.argv[0])
73 | argparser.add_argument("--post_data_tsv", type = str)
74 | argparser.add_argument("--qa_data_tsv", type = str)
75 | argparser.add_argument("--train_ids", type = str)
76 | argparser.add_argument("--tune_ids", type = str)
77 | argparser.add_argument("--test_ids", type = str)
78 | argparser.add_argument("--vocab", type = str)
79 | argparser.add_argument("--no_of_candidates", type = int, default = 10)
80 | args = argparser.parse_args()
81 | print args
82 | print ""
83 | main(args)
84 |
85 |
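86 | # Input format notes: post_data_tsv is expected to have the columns
87 | # 'postid', 'title', 'post'; qa_data_tsv has 'postid', 'q1'..'q10', 'a1'..'a10'
88 | # (see read_data above). Out-of-vocabulary tokens are mapped to vocab[""] in
89 | # get_indices, so the pickled vocab is assumed to contain an empty-string entry
90 | # acting as the UNK index.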
--------------------------------------------------------------------------------
/src/models/lstm_helper.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import theano, lasagne
3 | import theano.tensor as T
4 |
5 | def build_list_lstm(content_list, content_masks_list, N, max_len, word_embeddings, word_emb_dim, hidden_dim, len_voc, batch_size):
6 | out = [None]*N
7 | emb_out = [None]*N
8 | l_in = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=content_list[0])
9 | l_mask = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=content_masks_list[0])
10 | l_emb = lasagne.layers.EmbeddingLayer(l_in, len_voc, word_emb_dim, W=word_embeddings)
11 | l_lstm = lasagne.layers.LSTMLayer(l_emb, hidden_dim, mask_input=l_mask, )
12 | out[0] = lasagne.layers.get_output(l_lstm)
13 | out[0] = T.mean(out[0] * content_masks_list[0][:,:,None], axis=1)
14 | emb_out[0] = lasagne.layers.get_output(l_emb)
15 | emb_out[0] = T.mean(emb_out[0] * content_masks_list[0][:,:,None], axis=1)
16 | for i in range(1, N):
17 | l_in_ = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=content_list[i])
18 | l_mask_ = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=content_masks_list[i])
19 | l_emb_ = lasagne.layers.EmbeddingLayer(l_in_, len_voc, word_emb_dim, W=l_emb.W)
20 | l_lstm_ = lasagne.layers.LSTMLayer(l_emb_, hidden_dim, mask_input=l_mask_,\
21 | ingate=lasagne.layers.Gate(W_in=l_lstm.W_in_to_ingate,\
22 | W_hid=l_lstm.W_hid_to_ingate,\
23 | b=l_lstm.b_ingate,\
24 | nonlinearity=l_lstm.nonlinearity_ingate),\
25 | outgate=lasagne.layers.Gate(W_in=l_lstm.W_in_to_outgate,\
26 | W_hid=l_lstm.W_hid_to_outgate,\
27 | b=l_lstm.b_outgate,\
28 | nonlinearity=l_lstm.nonlinearity_outgate),\
29 | forgetgate=lasagne.layers.Gate(W_in=l_lstm.W_in_to_forgetgate,\
30 | W_hid=l_lstm.W_hid_to_forgetgate,\
31 | b=l_lstm.b_forgetgate,\
32 | nonlinearity=l_lstm.nonlinearity_forgetgate),\
33 | cell=lasagne.layers.Gate(W_in=l_lstm.W_in_to_cell,\
34 | W_hid=l_lstm.W_hid_to_cell,\
35 | b=l_lstm.b_cell,\
36 | nonlinearity=l_lstm.nonlinearity_cell),\
37 | peepholes=False,\
38 | )
39 | out[i] = lasagne.layers.get_output(l_lstm_)
40 | out[i] = T.mean(out[i] * content_masks_list[i][:,:,None], axis=1)
41 | emb_out[i] = lasagne.layers.get_output(l_emb_)
42 | emb_out[i] = T.mean(emb_out[i] * content_masks_list[i][:,:,None], axis=1)
43 | l_emb.params[l_emb.W].remove('trainable')
44 | params = lasagne.layers.get_all_params(l_lstm, trainable=True)
45 | #print 'Params in lstm: ', (lasagne.layers.count_params(l_lstm)-lasagne.layers.count_params(l_emb))
46 | return out, emb_out, params
47 |
48 | def build_lstm(posts, post_masks, max_len, word_embeddings, word_emb_dim, hidden_dim, len_voc, batch_size):
49 |
50 | l_in = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=posts)
51 | l_mask = lasagne.layers.InputLayer(shape=(batch_size, max_len), input_var=post_masks)
52 | l_emb = lasagne.layers.EmbeddingLayer(l_in, len_voc, word_emb_dim, W=word_embeddings)
53 | l_lstm = lasagne.layers.LSTMLayer(l_emb, hidden_dim, mask_input=l_mask, )
54 | out = lasagne.layers.get_output(l_lstm)
55 | out = T.mean(out * post_masks[:,:,None], axis=1)
56 | l_emb.params[l_emb.W].remove('trainable')
57 | params = lasagne.layers.get_all_params(l_lstm, trainable=True)
58 | #print 'Params in post_lstm: ', (lasagne.layers.count_params(l_lstm)-lasagne.layers.count_params(l_emb))
59 | return out, params
60 |
61 |
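62 | # Note: build_list_lstm reuses the embedding matrix and the in/out/forget/cell
63 | # gate weights of the first candidate's LSTM for candidates 1..N-1, so the N
64 | # candidate encoders effectively share parameters; only the first LSTM's
65 | # parameters are returned for training. In both builders the output is the
66 | # masked mean over the padded time axis, and the embedding weights are frozen.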
--------------------------------------------------------------------------------
/src/models/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import theano, lasagne
4 | import numpy as np
5 | import cPickle as p
6 | import theano.tensor as T
7 | from collections import Counter
8 | import pdb
9 | import time
10 | import random, math
11 | from baseline_pq import baseline_pq
12 | from baseline_pa import baseline_pa
13 | from baseline_pqa import baseline_pqa
14 | from evpi import evpi
15 | from model_helper import *
16 |
17 | def main(args):
18 | post_ids_train = p.load(open(args.post_ids_train, 'rb'))
19 | post_ids_train = np.array(post_ids_train)
20 | post_vectors_train = p.load(open(args.post_vectors_train, 'rb'))
21 | ques_list_vectors_train = p.load(open(args.ques_list_vectors_train, 'rb'))
22 | ans_list_vectors_train = p.load(open(args.ans_list_vectors_train, 'rb'))
23 |
24 | post_ids_test = p.load(open(args.post_ids_test, 'rb'))
25 | post_ids_test = np.array(post_ids_test)
26 | post_vectors_test = p.load(open(args.post_vectors_test, 'rb'))
27 | ques_list_vectors_test = p.load(open(args.ques_list_vectors_test, 'rb'))
28 | ans_list_vectors_test = p.load(open(args.ans_list_vectors_test, 'rb'))
29 |
30 | out_file = open(args.test_predictions_output, 'w')
31 | out_file.close()
32 |
33 | word_embeddings = p.load(open(args.word_embeddings, 'rb'))
34 | word_embeddings = np.asarray(word_embeddings, dtype=np.float32)
35 | vocab_size = len(word_embeddings)
36 | word_emb_dim = len(word_embeddings[0])
37 | print 'word emb dim: ', word_emb_dim
38 | freeze = False
39 | N = args.no_of_candidates
40 |
41 | print 'vocab_size ', vocab_size, ', post_max_len ', args.post_max_len, ' ques_max_len ', args.ques_max_len, ' ans_max_len ', args.ans_max_len
42 |
43 | start = time.time()
44 | print 'generating data'
45 | train = generate_data(post_vectors_train, ques_list_vectors_train, ans_list_vectors_train, args)
46 | test = generate_data(post_vectors_test, ques_list_vectors_test, ans_list_vectors_test, args)
47 | train.append(post_ids_train)
48 | test.append(post_ids_test)
49 |
50 | print 'done! Time taken: ', time.time() - start
51 |
52 | print 'Size of training data: ', len(post_ids_train)
53 | print 'Size of test data: ', len(post_ids_test)
54 |
55 | if args.model == 'baseline_pq':
56 | baseline_pq(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test)
57 | elif args.model == 'baseline_pa':
58 | baseline_pa(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test)
59 | elif args.model == 'baseline_pqa':
60 | baseline_pqa(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test)
61 | elif args.model == 'evpi':
62 | evpi(word_embeddings, vocab_size, word_emb_dim, freeze, args, train, test)
63 |
64 | if __name__ == '__main__':
65 | argparser = argparse.ArgumentParser(sys.argv[0])
66 | argparser.add_argument("--post_ids_train", type = str)
67 | argparser.add_argument("--post_vectors_train", type = str)
68 | argparser.add_argument("--ques_list_vectors_train", type = str)
69 | argparser.add_argument("--ans_list_vectors_train", type = str)
70 | argparser.add_argument("--post_ids_test", type = str)
71 | argparser.add_argument("--post_vectors_test", type = str)
72 | argparser.add_argument("--ques_list_vectors_test", type = str)
73 | argparser.add_argument("--ans_list_vectors_test", type = str)
74 | argparser.add_argument("--word_embeddings", type = str)
75 | argparser.add_argument("--batch_size", type = int, default = 256)
76 | argparser.add_argument("--no_of_epochs", type = int, default = 20)
77 | argparser.add_argument("--hidden_dim", type = int, default = 100)
78 | argparser.add_argument("--no_of_candidates", type = int, default = 10)
79 | argparser.add_argument("--learning_rate", type = float, default = 0.001)
80 | argparser.add_argument("--rho", type = float, default = 1e-5)
81 | argparser.add_argument("--post_max_len", type = int, default = 300)
82 | argparser.add_argument("--ques_max_len", type = int, default = 40)
83 | argparser.add_argument("--ans_max_len", type = int, default = 40)
84 | argparser.add_argument("--test_predictions_output", type = str)
85 | argparser.add_argument("--stdout_file", type = str)
86 | argparser.add_argument("--model", type = str)
87 | args = argparser.parse_args()
88 | print args
89 | print ""
90 | main(args)
91 |
--------------------------------------------------------------------------------
/src/models/model_helper.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import theano, lasagne
4 | import numpy as np
5 | import cPickle as p
6 | import theano.tensor as T
7 | from collections import Counter
8 | import pdb
9 | import time
10 | import random, math
11 | DEPTH = 5
12 |
13 | def get_data_masks(content, max_len):
14 | if len(content) > max_len:
15 | data = content[:max_len]
16 | data_mask = np.ones(max_len)
17 | else:
18 | data = np.concatenate((content, np.zeros(max_len-len(content))), axis=0)
19 | data_mask = np.concatenate((np.ones(len(content)), np.zeros(max_len-len(content))), axis=0)
20 | return data, data_mask
21 |
22 | def generate_data(posts, ques_list, ans_list, args):
23 | data_size = len(posts)
24 | data_posts = np.zeros((data_size, args.post_max_len), dtype=np.int32)
25 | data_post_masks = np.zeros((data_size, args.post_max_len), dtype=np.float32)
26 |
27 | N = args.no_of_candidates
28 | data_ques_list = np.zeros((data_size, N, args.ques_max_len), dtype=np.int32)
29 | data_ques_masks_list = np.zeros((data_size, N, args.ques_max_len), dtype=np.float32)
30 |
31 | data_ans_list = np.zeros((data_size, N, args.ans_max_len), dtype=np.int32)
32 | data_ans_masks_list = np.zeros((data_size, N, args.ans_max_len), dtype=np.float32)
33 |
34 | for i in range(data_size):
35 | data_posts[i], data_post_masks[i] = get_data_masks(posts[i], args.post_max_len)
36 | for j in range(N):
37 | data_ques_list[i][j], data_ques_masks_list[i][j] = get_data_masks(ques_list[i][j], args.ques_max_len)
38 | data_ans_list[i][j], data_ans_masks_list[i][j] = get_data_masks(ans_list[i][j], args.ans_max_len)
39 |
40 | return [data_posts, data_post_masks, data_ques_list, data_ques_masks_list, data_ans_list, data_ans_masks_list]
41 |
42 | def iterate_minibatches(posts, post_masks, ques_list, ques_masks_list, ans_list, ans_masks_list, \
43 | labels, ranks, post_ids, batch_size, shuffle=False):
44 | if shuffle:
45 | indices = np.arange(posts.shape[0])
46 | np.random.shuffle(indices)
47 | for start_idx in range(0, posts.shape[0] - batch_size + 1, batch_size):
48 | if shuffle:
49 | excerpt = indices[start_idx:start_idx + batch_size]
50 | else:
51 | excerpt = slice(start_idx, start_idx + batch_size)
52 | yield posts[excerpt], post_masks[excerpt], ques_list[excerpt], ques_masks_list[excerpt], \
53 | ans_list[excerpt], ans_masks_list[excerpt], labels[excerpt], ranks[excerpt], post_ids[excerpt]
54 |
55 | def get_rank(preds, labels):
56 | preds = np.array(preds)
57 | correct = np.where(labels==1)[0][0]
58 | sort_index_preds = np.argsort(preds)
59 | desc_sort_index_preds = sort_index_preds[::-1] #since ascending sort and we want descending
60 | rank = np.where(desc_sort_index_preds==correct)[0][0]
61 | return rank+1
62 |
63 | def shuffle(q, qm, a, am, l, r):
64 | shuffled_q = np.zeros(q.shape, dtype=np.int32)
65 | shuffled_qm = np.zeros(qm.shape, dtype=np.float32)
66 | shuffled_a = np.zeros(a.shape, dtype=np.int32)
67 | shuffled_am = np.zeros(am.shape, dtype=np.float32)
68 | shuffled_l = np.zeros(l.shape, dtype=np.int32)
69 | shuffled_r = np.zeros(r.shape, dtype=np.int32)
70 |
71 | for i in range(len(q)):
72 | indexes = range(len(q[i]))
73 | random.shuffle(indexes)
74 | for j, index in enumerate(indexes):
75 | shuffled_q[i][j] = q[i][index]
76 | shuffled_qm[i][j] = qm[i][index]
77 | shuffled_a[i][j] = a[i][index]
78 | shuffled_am[i][j] = am[i][index]
79 | shuffled_l[i][j] = l[i][index]
80 | shuffled_r[i][j] = r[i][index]
81 |
82 | return shuffled_q, shuffled_qm, shuffled_a, shuffled_am, shuffled_l, shuffled_r
83 |
84 |
85 | def write_test_predictions(out_file, postId, utilities, ranks, epoch):
86 | lstring = "[%s]: " % (postId)
87 | N = len(utilities)
88 | scores = [0]*N
89 | for i in range(N):
90 | scores[ranks[i]] = utilities[i]
91 | for i in range(N):
92 | lstring += "%f " % (scores[i])
93 | out_file_o = open(out_file+'.epoch%d' % epoch, 'a')
94 | out_file_o.write(lstring + '\n')
95 | out_file_o.close()
96 |
97 | def get_annotations(line):
98 | set_info, post_id, best, valids, confidence = line.split(',')
99 | sitename = set_info.split('_')[1]
100 | best = [int(best)]
101 | valids = [int(v) for v in valids.split()]
102 | confidence = int(confidence)
103 | return post_id, sitename, best, valids, confidence
104 |
105 | def evaluate_using_human_annotations(args, preds):
106 | human_annotations_file = open(args.test_human_annotations, 'r')
107 | best_acc_on10 = 0
108 | best_acc_on9 = 0
109 |
110 | valid_acc_on10 = 0
111 | valid_acc_on9 = 0
112 |
113 | total = 0
114 | total_best_on9 = 0
115 |
116 | for line in human_annotations_file.readlines():
117 | line = line.strip('\n')
118 | splits = line.split('\t')
119 | post_id, sitename, best, valids, confidence = get_annotations(splits[0])
120 | if len(splits) > 1:
121 | post_id2, sitename2, best2, valids2, confidence2 = get_annotations(splits[1])
122 | assert(sitename == sitename2)
123 | assert(post_id == post_id2)
124 | 			best += best2  # best2 is already a list; concatenate rather than nest it
125 | valids += valids2
126 |
127 | 		if 0 not in best:  # the original question (index 0) was not marked best
128 | total_best_on9 += 1
129 |
130 | post_id = sitename+'_'+post_id
131 | pred = preds[post_id].index(max(preds[post_id]))
132 |
133 | if pred in best:
134 | best_acc_on10 += 1
135 | 			if 0 not in best:
136 | 				best_acc_on9 += 1
137 |
138 | if pred in valids:
139 | valid_acc_on10 += 1
140 | if pred != 0:
141 | valid_acc_on9 += 1
142 |
143 | total += 1
144 |
145 | print
146 | print '\t\tBest acc on 10: %.2f Valid acc on 10: %.2f' % (best_acc_on10*100.0/total, valid_acc_on10*100.0/total)
147 | print '\t\tBest acc on 9: %.2f Valid acc on 9: %.2f' % (best_acc_on9*100.0/total_best_on9, valid_acc_on9*100.0/total)
148 |
149 |
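The ranking logic in get_rank above can be illustrated with a tiny, self-contained example (the scores below are made up for illustration): the gold question is the candidate whose label is 1, and the returned value is its 1-based position when candidates are sorted by predicted utility, best first.

```python
import numpy as np

# Hypothetical scores for 4 candidate questions; candidate 2 is the gold one (label == 1).
preds = np.array([0.2, 0.1, 0.7, 0.4])
labels = np.array([0, 0, 1, 0])

correct = np.where(labels == 1)[0][0]    # index of the gold candidate -> 2
desc_order = np.argsort(preds)[::-1]     # candidate indices, best score first -> [2, 3, 0, 1]
rank = np.where(desc_order == correct)[0][0] + 1
print(rank)                              # 1
```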
--------------------------------------------------------------------------------
/src/models/run_combine_domains.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_DIR=data
4 | UBUNTU=askubuntu.com
5 | UNIX=unix.stackexchange.com
6 | SUPERUSER=superuser.com
7 | SCRIPTS_DIR=src/models
8 | SITE_NAME=askubuntu_unix_superuser
9 |
10 | mkdir -p $DATA_DIR/$SITE_NAME
11 |
12 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_vectors_train.p \
13 | $DATA_DIR/$UNIX/post_vectors_train.p \
14 | $DATA_DIR/$SUPERUSER/post_vectors_train.p \
15 | $DATA_DIR/$SITE_NAME/post_vectors_train.p
16 |
17 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ques_list_vectors_train.p \
18 | $DATA_DIR/$UNIX/ques_list_vectors_train.p \
19 | $DATA_DIR/$SUPERUSER/ques_list_vectors_train.p \
20 | $DATA_DIR/$SITE_NAME/ques_list_vectors_train.p
21 |
22 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ans_list_vectors_train.p \
23 | $DATA_DIR/$UNIX/ans_list_vectors_train.p \
24 | $DATA_DIR/$SUPERUSER/ans_list_vectors_train.p \
25 | $DATA_DIR/$SITE_NAME/ans_list_vectors_train.p
26 |
27 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_ids_train.p \
28 | $DATA_DIR/$UNIX/post_ids_train.p \
29 | $DATA_DIR/$SUPERUSER/post_ids_train.p \
30 | $DATA_DIR/$SITE_NAME/post_ids_train.p
31 |
32 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_vectors_tune.p \
33 | $DATA_DIR/$UNIX/post_vectors_tune.p \
34 | $DATA_DIR/$SUPERUSER/post_vectors_tune.p \
35 | $DATA_DIR/$SITE_NAME/post_vectors_tune.p
36 |
37 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ques_list_vectors_tune.p \
38 | $DATA_DIR/$UNIX/ques_list_vectors_tune.p \
39 | $DATA_DIR/$SUPERUSER/ques_list_vectors_tune.p \
40 | $DATA_DIR/$SITE_NAME/ques_list_vectors_tune.p
41 |
42 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ans_list_vectors_tune.p \
43 | $DATA_DIR/$UNIX/ans_list_vectors_tune.p \
44 | $DATA_DIR/$SUPERUSER/ans_list_vectors_tune.p \
45 | $DATA_DIR/$SITE_NAME/ans_list_vectors_tune.p
46 |
47 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_ids_tune.p \
48 | $DATA_DIR/$UNIX/post_ids_tune.p \
49 | $DATA_DIR/$SUPERUSER/post_ids_tune.p \
50 | $DATA_DIR/$SITE_NAME/post_ids_tune.p
51 |
52 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_vectors_test.p \
53 | $DATA_DIR/$UNIX/post_vectors_test.p \
54 | $DATA_DIR/$SUPERUSER/post_vectors_test.p \
55 | $DATA_DIR/$SITE_NAME/post_vectors_test.p
56 |
57 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ques_list_vectors_test.p \
58 | $DATA_DIR/$UNIX/ques_list_vectors_test.p \
59 | $DATA_DIR/$SUPERUSER/ques_list_vectors_test.p \
60 | $DATA_DIR/$SITE_NAME/ques_list_vectors_test.p
61 |
62 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/ans_list_vectors_test.p \
63 | $DATA_DIR/$UNIX/ans_list_vectors_test.p \
64 | $DATA_DIR/$SUPERUSER/ans_list_vectors_test.p \
65 | $DATA_DIR/$SITE_NAME/ans_list_vectors_test.p
66 |
67 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/post_ids_test.p \
68 | $DATA_DIR/$UNIX/post_ids_test.p \
69 | $DATA_DIR/$SUPERUSER/post_ids_test.p \
70 | $DATA_DIR/$SITE_NAME/post_ids_test.p
71 |
72 | cat $DATA_DIR/$UBUNTU/human_annotations $DATA_DIR/$UNIX/human_annotations $DATA_DIR/$SUPERUSER/human_annotations > $DATA_DIR/$SITE_NAME/human_annotations
73 |
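combine_pickle.py itself is not reproduced in this listing; as a rough sketch, a helper invoked the way the script above does could simply concatenate the Python lists stored in the input pickles and write the result to the last path on the command line. The code below is an illustration under that assumption, not the repository's actual implementation.

```python
import pickle
import sys

def combine(input_paths, output_path):
    # Load each input pickle (assumed to hold a Python list) and concatenate them.
    combined = []
    for path in input_paths:
        with open(path, 'rb') as f:
            combined += pickle.load(f)
    with open(output_path, 'wb') as f:
        pickle.dump(combined, f)

if __name__ == '__main__':
    # Usage: python combine_pickle.py in1.p in2.p in3.p out.p
    combine(sys.argv[1:-1], sys.argv[-1])
```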
--------------------------------------------------------------------------------
/src/models/run_load_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_DIR=data
4 | EMB_DIR=embeddings
5 | #SITE_NAME=askubuntu.com
6 | #SITE_NAME=unix.stackexchange.com
7 | SITE_NAME=superuser.com
8 |
9 | SCRIPTS_DIR=src/models
10 |
11 | python $SCRIPTS_DIR/load_data.py --post_data_tsv $DATA_DIR/$SITE_NAME/post_data.tsv \
12 | --qa_data_tsv $DATA_DIR/$SITE_NAME/qa_data.tsv \
13 | --train_ids $DATA_DIR/$SITE_NAME/train_ids \
14 | --tune_ids $DATA_DIR/$SITE_NAME/tune_ids \
15 | --test_ids $DATA_DIR/$SITE_NAME/test_ids \
16 | --vocab $EMB_DIR/vocab.p
17 |
--------------------------------------------------------------------------------
/src/models/run_main.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_DIR=data
4 | EMB_DIR=embeddings
5 | #SITE_NAME=askubuntu.com
6 | #SITE_NAME=unix.stackexchange.com
7 | #SITE_NAME=superuser.com
8 | SITE_NAME=askubuntu_unix_superuser
9 |
10 | OUTPUT_DIR=output
11 | SCRIPTS_DIR=src/models
12 | #MODEL=baseline_pq
13 | #MODEL=baseline_pa
14 | #MODEL=baseline_pqa
15 | MODEL=evpi
16 |
17 | mkdir -p $OUTPUT_DIR
18 |
19 | source /fs/clip-amr/gpu_virtualenv/bin/activate
20 | module add cuda/8.0.44
21 | module add cudnn/v5.1
22 |
23 | THEANO_FLAGS=floatX=float32,device=gpu python $SCRIPTS_DIR/main.py \
24 | --post_ids_train $DATA_DIR/$SITE_NAME/post_ids_train.p \
25 | --post_vectors_train $DATA_DIR/$SITE_NAME/post_vectors_train.p \
26 | --ques_list_vectors_train $DATA_DIR/$SITE_NAME/ques_list_vectors_train.p \
27 | --ans_list_vectors_train $DATA_DIR/$SITE_NAME/ans_list_vectors_train.p \
28 | --post_ids_test $DATA_DIR/$SITE_NAME/post_ids_test.p \
29 | --post_vectors_test $DATA_DIR/$SITE_NAME/post_vectors_test.p \
30 | --ques_list_vectors_test $DATA_DIR/$SITE_NAME/ques_list_vectors_test.p \
31 | --ans_list_vectors_test $DATA_DIR/$SITE_NAME/ans_list_vectors_test.p \
32 | --word_embeddings $EMB_DIR/word_embeddings.p \
33 | --batch_size 128 --no_of_epochs 20 --no_of_candidates 10 \
34 | --test_predictions_output $DATA_DIR/$SITE_NAME/test_predictions_${MODEL}.out \
35 | --stdout_file $OUTPUT_DIR/${SITE_NAME}.${MODEL}.out \
36 | 					--model $MODEL
37 |
--------------------------------------------------------------------------------