├── README.md └── src ├── main └── Main.java ├── model ├── Model.java └── Task.java └── myUtils ├── Dataset.java ├── Dictionary.java ├── MyMath.java └── Text.java /README.md: -------------------------------------------------------------------------------- 1 | # Neural-BoN 2 | This code implements the Neural-BoN model proposed in AAAI-17 paper: [Bofang Li, Tao Liu, Zhe Zhao, Puwei Wang and Xiaoyong Du - **Neural Bag-of-Ngrams**]. 3 | 4 | ## Code 5 | This code has been tested on Windows, but it should work on Linux, OSX or any other operating system without any changes (Thanks to Java). 6 | 7 | All parameters you may need to change are in the top lines of src/main/Main.java. 8 | 9 | To train Neural-BoN on an unlabeled corpus, you can just specify the corpus in the Main function of src/main/Main.java and run it. This will generate the learned n-gram representations in the results folder. 10 | 11 | To train Neural-BoN on a labeled corpus, you can specify the corpus and implement a getXXXDataset function in src/myUtils/Dataset.java or change the getIMDBDataset function in src/myUtils/Dataset.java for the specific format. 12 | 13 | ## More Information 14 | 15 | This code is implemented upon the [DV-ngram code](https://github.com/libofang/DV-ngram) of our ICLR 2016 workshop paper [Bofang Li, Tao Liu, Xiaoyong Du, Deyuan Zhang and Zhe Zhao - **Learning Document Embedding by Predicting N-grams for Sentiment Classification of Long Movie Reviews**](http://arxiv.org/abs/1512.08183). We also recommend you to try that model for simpler generation of document vectors. 16 | 17 | Again, we thank Grégoire Mesnil et al. for their implementation of Paragraph Vector. Both their code and [ICLR 2015 workshop paper](http://arxiv.org/abs/1412.5335) have influenced us a lot. 
18 | 19 | -------------------------------------------------------------------------------- /src/main/Main.java: -------------------------------------------------------------------------------- 1 | package main; 2 | 3 | import java.util.*; 4 | 5 | import model.Model; 6 | import myUtils.*; 7 | import myUtils.Dictionary; 8 | 9 | public class Main { 10 | // all the parameters are here 11 | public static int ngram = 1; 12 | public static float lr = 0.025f; // learning rate //0.1 89.93420.0 13 | public static float rr = 0.000f; // regulize rate //not used 14 | public static int negSize = 5; // negative sampling size 15 | public static int iter = 10; // iteration 16 | public static int batchSize = 100; 17 | public static int ws = 5; // CGNR window size 18 | public static int n = 500; // word embeddings dimension 19 | public static String info = "l"; // w, d or l // w stands for CGNR. d stands 20 | // for TGNR. l stands for LGNR. 21 | public static boolean useUniqueWordList = false; 22 | public static String negType = "i"; // i or o. 
negative sampling type 23 | public static boolean preLogistic = false; 24 | public static boolean saveWordVectors = true; 25 | public static double subSampleRate = 0; 26 | public static double dropRate = 0; 27 | public static String addDataType = ""; // news4 28 | public static double subRate = 1; 29 | public static double pow = 1; 30 | public static int testId = new Random().nextInt(10000); 31 | 32 | public static class DatasetTask { 33 | public String folderName; 34 | public String type; 35 | 36 | public DatasetTask(String folderName, String type) { 37 | this.folderName = folderName; 38 | this.type = type; 39 | } 40 | } 41 | 42 | public static void train(List taskList) { 43 | 44 | for (int index_task = 0; index_task < taskList.size(); index_task++) { 45 | DatasetTask task = taskList.get(index_task); 46 | 47 | // load dataset 48 | Dataset dataset = null; 49 | dataset = new Dataset(task.folderName, task.type, ngram, addDataType); 50 | 51 | // initialize model 52 | Model model = new Model(task.folderName, dataset, lr, rr, negSize, iter, batchSize, ws, n, info); 53 | 54 | System.out.println(index_task + "||" + index_task + "||" + index_task + "||" + index_task + "||"); 55 | // train model 56 | model.train(); 57 | 58 | { 59 | long startTime = System.currentTimeMillis(); 60 | 61 | // save model 62 | String baseTmpSaveFolder = "./results/" + task.folderName + ngram + "/" + info + "_" + addDataType 63 | + "_" + n + "_r" + testId + "/"; 64 | String tmpSaveFolder = baseTmpSaveFolder + "fold" + index_task + "/"; 65 | if (saveWordVectors) 66 | model.saveWordVectors(tmpSaveFolder); 67 | 68 | System.out.println("||" + "|" + "time:" + (System.currentTimeMillis() - startTime)); 69 | } 70 | } 71 | System.out.println(); 72 | System.out.println(info); 73 | System.out.println(); 74 | } 75 | 76 | public static void main(String args[]) { 77 | 78 | { 79 | List taskList = new ArrayList(); 80 | /** 81 | * choose one of following line to add a task. 
82 | * 83 | * "unlabel" indicate unlabeled corpus and suitable only for CGNR 84 | * and TGNR. 85 | * 86 | * "imdb" indicate imdb format. You can implement an getXXXDataset function 87 | * in src/myUtils/Dataset.java our change the getIMDBDataset 88 | * function in src/myUtils/Dataset.java for the specific format. 89 | */ 90 | // taskList.add(new DatasetTask("books.txt", "unlabel")); 91 | // taskList.add(new DatasetTask("imdb.txt", "imdb")); 92 | 93 | train(taskList); 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/model/Model.java: -------------------------------------------------------------------------------- 1 | package model; 2 | 3 | import java.awt.event.TextListener; 4 | import java.io.File; 5 | import java.io.FileWriter; 6 | import java.io.IOException; 7 | import java.util.*; 8 | import java.util.Map.Entry; 9 | 10 | import main.Main; 11 | import myUtils.Dataset; 12 | import myUtils.Dictionary; 13 | import myUtils.MyMath; 14 | import myUtils.Text; 15 | 16 | public class Model { 17 | public String folderName; 18 | private Dataset dataset; 19 | 20 | private float lr; // learning rate 21 | private float original_lr; 22 | private float rr; // regulize rate 23 | private int negSize; 24 | private int iter; 25 | private int batchSize; 26 | private int ws; 27 | private int n; 28 | private String info; // infomation used 29 | 30 | private TrainThread ttList[]; 31 | 32 | private float WV[][][]; 33 | private float TV[][]; 34 | private float LV[][]; 35 | private double bestAcDev = 0; 36 | private double ac = 0; 37 | 38 | private static Random random = new Random(); 39 | 40 | public Model(String folderName, Dataset dataset, float lr, float rr, int negSize, int iter, int batchSize, int ws, 41 | int n, String info) { 42 | this.folderName = folderName; 43 | this.dataset = dataset; 44 | this.lr = lr; 45 | this.original_lr = lr; 46 | this.rr = rr; 47 | this.negSize = negSize; 48 | this.iter = iter; 49 | 
this.batchSize = batchSize; 50 | this.ws = ws; 51 | this.n = n; 52 | this.info = info; 53 | ttList = new TrainThread[30]; 54 | 55 | // initialize vectors 56 | System.out.println("initializing model:vocabSize=" + dataset.dic.uniqueWordSize() + " textSize=" 57 | + dataset.textList.size()); 58 | if (info.contains("w")) 59 | WV = new float[dataset.dic.uniqueWordSize()][2][]; 60 | else 61 | WV = new float[dataset.dic.uniqueWordSize()][1][]; 62 | for (int i = 0; i < WV.length; i++) { 63 | if (dataset.dic.wordIdCountList.get(i) <= 1) 64 | continue; 65 | 66 | if (i % 100000 == 0) 67 | System.out.println("init vectors :" + (i * 100 / dataset.dic.uniqueWordSize()) + "%"); 68 | for (int bi = 0; bi < WV[i].length; bi++) { 69 | 70 | WV[i][bi] = new float[n]; 71 | 72 | for (int j = 0; j < WV[i][bi].length; j++) 73 | WV[i][bi][j] = (random.nextFloat() - 0.5f) / n; 74 | // for (int j = 0; j < WV[i][bi].length; j++) 75 | // WV[i][bi][j] = 0; 76 | // WV[i][bi][random.nextInt(n)] = 1; 77 | // WV[i][bi][random.nextInt(n)] = 1; 78 | 79 | } 80 | } 81 | 82 | // one-hot representation 83 | // for (int i = 0; i < WV.length; i++) { 84 | // if (i % 100000 == 0) 85 | // System.out.println(i * 100 / dataset.dic.size()); 86 | // if (dataset.dic.wordIdCountList.get(i) >= 2) { 87 | // for (int bi = 0; bi < WV[i].length; bi++) { 88 | // WV[i][bi] = new float[n]; 89 | // for (int j = 0; j < WV[i][bi].length; j++) 90 | // WV[i][bi][j] = 0; 91 | // WV[i][bi][i % n] = 1; 92 | // } 93 | // } 94 | // } 95 | 96 | TV = new float[dataset.textList.size()][n]; 97 | for (int i = 0; i < TV.length; i++) { 98 | for (int j = 0; j < TV[i].length; j++) 99 | TV[i][j] = (random.nextFloat() - 0.5f) / n; 100 | } 101 | LV = new float[20][n]; 102 | for (int i = 0; i < LV.length; i++) { 103 | for (int j = 0; j < LV[i].length; j++) 104 | LV[i][j] = (random.nextFloat() - 0.5f) / n; 105 | } 106 | System.out.println("initializing finished"); 107 | } 108 | 109 | public boolean isSub(int id) { 110 | double subsample = 
Main.subRate * dataset.dic.totalWordSize; 111 | if (dataset.dic.wordIdCountList.get(id) >= subsample 112 | && random.nextDouble() < (1 - Math.sqrt(subsample / dataset.dic.wordIdCountList.get(id)))) 113 | return true; 114 | return false; 115 | } 116 | 117 | public static void train(float a[], float b[], boolean mask[], int target, double lr, double rr) { 118 | if (a == null || b == null) 119 | return; 120 | 121 | float aa[] = null; 122 | float bb[] = null; 123 | if (Main.preLogistic) { 124 | aa = new float[a.length]; 125 | bb = new float[b.length]; 126 | } 127 | 128 | float y = 0; 129 | for (int i = 0; i < a.length; i++) 130 | if (mask == null || mask[i]) { 131 | if (Main.preLogistic) { 132 | aa[i] = MyMath.tanh(a[i]); 133 | bb[i] = MyMath.tanh(b[i]); 134 | y += aa[i] * bb[i]; 135 | } else { 136 | y += a[i] * b[i]; 137 | } 138 | } 139 | y = MyMath.logistic(y); 140 | for (int i = 0; i < a.length; i++) { 141 | if (mask == null || mask[i]) { 142 | if (Main.preLogistic) { 143 | a[i] += -(y - target) * bb[i] * (aa[i] * (1 - aa[i] * aa[i])) * lr - rr * a[i] * lr; 144 | b[i] += -(y - target) * aa[i] * (bb[i] * (1 - bb[i] * bb[i])) * lr - rr * b[i] * lr; 145 | } else { 146 | float wv = a[i]; 147 | a[i] += -(y - target) * b[i] * lr - rr * a[i] * lr; 148 | b[i] += -(y - target) * wv * lr - rr * b[i] * lr; 149 | } 150 | } 151 | } 152 | } 153 | 154 | public class TrainThread extends Thread { 155 | 156 | public List taskSubList; 157 | public float lr; 158 | public float rr; 159 | 160 | public TrainThread(List taskSubList, float lr, float rr) { 161 | this.taskSubList = taskSubList; 162 | this.lr = lr; 163 | this.rr = rr; 164 | } 165 | 166 | private void runTask(Task task, boolean[] mask) { 167 | if (task.type == 0) { 168 | 169 | if (isSub(task.b)) 170 | return; 171 | 172 | train(TV[task.a], WV[task.b][0], mask, 1, lr, rr); 173 | for (int neg = 0; neg < negSize; neg++) { 174 | if (Main.negType.contains("i")) { 175 | int c = dataset.dic.getRandomWord(); 176 | train(TV[task.a], 
WV[c][0], mask, 0, lr, rr); 177 | } 178 | if (Main.negType.contains("o")) { 179 | while (true) { 180 | int c = random.nextInt(TV.length); 181 | if (c != task.a) { 182 | train(TV[c], WV[task.b][0], mask, 0, lr, rr); 183 | break; 184 | } 185 | } 186 | } 187 | } 188 | } 189 | if (task.type == 1) { 190 | 191 | if (isSub(task.b)) 192 | return; 193 | 194 | train(WV[task.a][1], WV[task.b][0], mask, 1, lr, rr); 195 | for (int neg = 0; neg < negSize; neg++) { 196 | if (Main.negType.contains("i")) { 197 | train(WV[task.a][1], WV[dataset.dic.getRandomWord()][0], mask, 0, lr, rr); 198 | } 199 | if (Main.negType.contains("o")) { 200 | while (true) { 201 | int c = random.nextInt(WV.length); 202 | if (c != task.a) { 203 | train(WV[c][1], WV[task.b][0], mask, 0, lr, rr); 204 | break; 205 | } 206 | } 207 | } 208 | } 209 | } 210 | if (task.type == 2) { 211 | 212 | if (isSub(task.b)) 213 | return; 214 | 215 | train(LV[task.a], WV[task.b][0], mask, 1, lr, rr); 216 | for (int neg = 0; neg < negSize; neg++) { 217 | if (Main.negType.contains("i")) { 218 | train(LV[task.a], WV[dataset.dic.getRandomWord()][0], mask, 0, lr, rr); 219 | } 220 | } 221 | if (Main.negType.contains("o")) { 222 | while (true) { 223 | int c = dataset.labelSet.toArray(new Integer[dataset.labelSet.size()])[random 224 | .nextInt(dataset.labelSet.size())]; 225 | // System.out.println(c); 226 | if (c != task.a) { 227 | train(LV[c], WV[task.b][0], mask, 0, lr, rr); 228 | break; 229 | } 230 | } 231 | } 232 | } 233 | } 234 | 235 | private void runSmallTask(Task task, boolean[] mask) { 236 | Text text = dataset.textList.get(task.a); 237 | if (task.type == 0) { 238 | for (int b : text.getIds(Main.useUniqueWordList)) { 239 | task.b = b; 240 | train(TV[task.a], WV[task.b][0], mask, 1, lr, rr); 241 | for (int neg = 0; neg < negSize; neg++) { 242 | while (true) { 243 | int c = random.nextInt(TV.length); 244 | if (c != task.a) { 245 | train(TV[c], WV[task.b][0], mask, 0, lr, rr); 246 | break; 247 | } 248 | } 249 | } 250 | } 251 | 
} 252 | if (task.type == 1) { 253 | for (Text.Pair p : text.getIdPairList(ws, dataset.dic)) { 254 | task.a = p.a; 255 | task.b = p.b; 256 | train(WV[task.a][1], WV[task.b][0], mask, 1, lr, rr); 257 | for (int neg = 0; neg < negSize; neg++) { 258 | while (true) { 259 | int c = dataset.dic.getRandomWord(); 260 | if (c != task.a) { 261 | train(WV[c][1], WV[task.b][0], mask, 0, lr, rr); 262 | break; 263 | } 264 | } 265 | } 266 | } 267 | } 268 | if (task.type == 2) { 269 | for (int b : text.getIds(Main.useUniqueWordList)) { 270 | task.a = text.label; 271 | task.b = b; 272 | 273 | train(LV[task.a], WV[task.b][0], mask, 1, lr, rr); 274 | for (int neg = 0; neg < negSize; neg++) { 275 | while (true) { 276 | int c = dataset.textList.get(random.nextInt(dataset.textList.size())).label; 277 | if (c != task.a) { 278 | train(LV[c], WV[task.b][0], mask, 0, lr, rr); 279 | break; 280 | } 281 | } 282 | } 283 | } 284 | } 285 | } 286 | 287 | public void run() { 288 | Random random = new Random(); 289 | 290 | boolean mask[] = null; 291 | mask = new boolean[n]; 292 | for (int i = 0; i < mask.length; i++) { 293 | if (i < mask.length * Main.dropRate) 294 | mask[i] = false; 295 | else 296 | mask[i] = true; 297 | } 298 | for (Task task : taskSubList) { 299 | 300 | if (task.b == -1) { 301 | runSmallTask(task.copy(), mask); 302 | } else { 303 | runTask(task, mask); 304 | } 305 | } 306 | } 307 | } 308 | 309 | public List getTaskList(int portion, int N) { 310 | System.out.print("assign task " + portion + "/" + N); 311 | List taskList = new ArrayList(); 312 | double avgWordCount = 1.0 * dataset.dic.totalWordSize / dataset.dic.uniqueWordSize(); 313 | for (int i = 0; i < dataset.textList.size(); i++) { 314 | if (i % N != portion) 315 | continue; 316 | if (i % 5000 == 0) 317 | System.out.print("|" + (i * 100 / dataset.textList.size()) + "%"); 318 | Text text = dataset.textList.get(i); 319 | // d 320 | if (info.contains("d")) 321 | for (int j : text.getIds(Main.useUniqueWordList)) { 322 | if 
(WV[j][0] == null) 323 | continue; 324 | for (int t = 0; t < avgWordCount / Math.pow(avgWordCount, Main.pow); t++) 325 | if (random.nextDouble() < Math.pow(dataset.dic.wordIdCountList.get(j), Main.pow) 326 | / dataset.dic.wordIdCountList.get(j)) 327 | taskList.add(new Task(i, j, (short) 0)); 328 | } 329 | // w 330 | if (info.contains("w")) 331 | for (Text.Pair p : text.getIdPairList(ws, dataset.dic)) { 332 | if (WV[p.a][0] == null) 333 | continue; 334 | if (WV[p.b][0] == null) 335 | continue; 336 | if (random.nextDouble() < 1.0) 337 | taskList.add(new Task(p.a, p.b, (short) 1)); 338 | } 339 | // l 340 | if (info.contains("l")) 341 | if (!text.type.equals("test") && text.label != -1) 342 | for (int j : text.getIds(Main.useUniqueWordList)) { 343 | if (WV[j][0] == null) 344 | continue; 345 | for (int t = 0; t < avgWordCount / Math.pow(avgWordCount, Main.pow); t++) 346 | if (random.nextDouble() < Math.pow(dataset.dic.wordIdCountList.get(j), Main.pow) 347 | / dataset.dic.wordIdCountList.get(j)) 348 | taskList.add(new Task(text.label, j, (short) 2)); 349 | } 350 | } 351 | System.out.print(" total " + taskList.size() + " tasks"); 352 | System.out.print(" over"); 353 | return taskList; 354 | } 355 | 356 | public List getSmallTaskList() { 357 | List taskList = new ArrayList(); 358 | for (int i = 0; i < dataset.textList.size(); i++) { 359 | Text text = dataset.textList.get(i); 360 | // d 361 | if (info.contains("d")) 362 | taskList.add(new Task(i, -1, (short) 0)); 363 | // w 364 | if (info.contains("w")) 365 | taskList.add(new Task(i, -1, (short) 1)); 366 | // l 367 | if (info.contains("l")) 368 | if (text.type.equals("train") && text.label != -1) 369 | taskList.add(new Task(i, -1, (short) 2)); 370 | } 371 | return taskList; 372 | } 373 | 374 | public void runTaskList(List taskList) { 375 | Collections.shuffle(taskList); 376 | int p = 0; 377 | while (true) { 378 | boolean over = false; 379 | for (int i = 0; i < ttList.length; i++) { 380 | if (ttList[i] == null || 
!ttList[i].isAlive()) { 381 | if (p < taskList.size()) { 382 | int s = p; 383 | int e = p + batchSize; 384 | if (taskList.size() < e) 385 | e = taskList.size(); 386 | ttList[i] = new TrainThread(taskList.subList(s, e), lr, rr); 387 | ttList[i].start(); 388 | p += batchSize; 389 | } else { 390 | over = true; 391 | break; 392 | } 393 | } else { 394 | } 395 | } 396 | if (over) 397 | break; 398 | } 399 | } 400 | 401 | public void waitTrainThreadOver() { 402 | while (true) { 403 | boolean over = true; 404 | for (int i = 0; i < ttList.length; i++) 405 | if (ttList[i] != null && ttList[i].isAlive()) 406 | over = false; 407 | if (over) 408 | break; 409 | } 410 | } 411 | 412 | public void saveWordVectors(String vectorFolder) { 413 | new File(vectorFolder).mkdirs(); 414 | int wordBi = 0; 415 | try { 416 | FileWriter fw = new FileWriter(vectorFolder + "WV" + n + ".txt"); 417 | for (Entry entry : dataset.dic.wordIdMap.entrySet()) { 418 | if (WV[entry.getValue()][wordBi] == null) 419 | continue; 420 | float v[] = WV[entry.getValue()][wordBi]; 421 | fw.write(entry.getKey() + " "); 422 | for (int j = 0; j < v.length; j++) 423 | if (v[j] != 0) 424 | fw.write(v[j] + " "); 425 | fw.write("\n"); 426 | } 427 | fw.close(); 428 | } catch (Exception e) { 429 | e.printStackTrace(); 430 | } 431 | } 432 | 433 | public void train() { 434 | { 435 | 436 | for (int epoch = 0; epoch < iter; epoch++) { 437 | 438 | lr = original_lr * (1 - epoch / iter); 439 | System.out.print("traning " + epoch + "/" + iter); 440 | long startTime = System.currentTimeMillis(); 441 | 442 | int N = 1; 443 | for (int portion = 0; portion < N; portion++) { 444 | List taskList = getTaskList(portion, N); 445 | runTaskList(taskList); 446 | waitTrainThreadOver(); 447 | } 448 | 449 | System.out.print("train over||"); 450 | System.out.println("||" + "|" + "time:" + (System.currentTimeMillis() - startTime)); 451 | } 452 | waitTrainThreadOver(); 453 | } 454 | } 455 | 456 | private float predictLabel(int i, int label) { 457 | 
Text text = dataset.textList.get(i); 458 | float p = 0; 459 | for (int id : text.getIds(Main.useUniqueWordList)) { 460 | p += MyMath.logistic(WV[id][0], LV[label]); 461 | } 462 | p /= text.getIds(Main.useUniqueWordList).size(); 463 | return p; 464 | } 465 | 466 | private int predictLabel(int i) { 467 | Text text = dataset.textList.get(i); 468 | float max = 0; 469 | int label = -1; 470 | for (int l : dataset.labelSet) { 471 | float p = predictLabel(i, l); 472 | if (p > max) { 473 | max = p; 474 | label = l; 475 | } 476 | } 477 | return label; 478 | } 479 | 480 | public double labelClassify() { 481 | double correct = 0; 482 | double total = 0; 483 | for (int i = 0; i < dataset.textList.size(); i++) { 484 | Text text = dataset.textList.get(i); 485 | if (text.type.equals("test")) { 486 | int label = predictLabel(i); 487 | if (label == text.label) 488 | correct++; 489 | total++; 490 | } 491 | } 492 | return correct / total * 100; 493 | } 494 | 495 | public void resume(Model bak) { // word vector only 496 | for (Entry entry : dataset.dic.wordIdMap.entrySet()) { 497 | WV[entry.getValue()] = bak.WV[bak.dataset.dic.wordIdMap.get(entry.getKey())]; 498 | } 499 | } 500 | } 501 | -------------------------------------------------------------------------------- /src/model/Task.java: -------------------------------------------------------------------------------- 1 | package model; 2 | 3 | public class Task { 4 | int a; 5 | int b; 6 | short type; 7 | 8 | public Task(int a, int b, short type) { 9 | this.a = a; 10 | this.b = b; 11 | this.type = type; 12 | } 13 | 14 | public Task copy() { 15 | return new Task(this.a, this.b, this.type); 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/myUtils/Dataset.java: -------------------------------------------------------------------------------- 1 | package myUtils; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.File; 5 | import java.io.FileReader; 6 | import java.util.*; 7 | 8 
| import main.Main; 9 | import myUtils.*; 10 | 11 | public class Dataset { 12 | public List textList; 13 | public Dictionary dic = new Dictionary(); 14 | public Set labelSet = new HashSet(); 15 | 16 | // dataset for training 17 | private List getTrainDataset(String folder, int ngram, int portion) { 18 | List d = new ArrayList(); 19 | try { 20 | File file = new File(folder); 21 | BufferedReader reader = new BufferedReader(new FileReader(file)); 22 | String line = null; 23 | { 24 | int index = 0; 25 | while ((line = reader.readLine()) != null) { 26 | if (index++ % 10000 == 0) { 27 | System.out.print("."); 28 | } 29 | if (index % portion != 0) 30 | continue; 31 | String content = line; 32 | Text text = new Text(content, "dev", -1, ngram, dic); 33 | d.add(text); 34 | } 35 | } 36 | reader.close(); 37 | } catch (Exception e) { 38 | e.printStackTrace(); 39 | } 40 | return d; 41 | } 42 | 43 | private List getImdbDataset(String filePath, int ngram, boolean addData) { 44 | List d = new ArrayList(); 45 | try { 46 | File file = new File(filePath); 47 | BufferedReader reader = new BufferedReader(new FileReader(file)); 48 | String line = null; 49 | int readDataCount = 0; 50 | while ((line = reader.readLine()) != null) { 51 | String key = "train"; 52 | int index = Integer.parseInt(line.split(" ")[0].substring(2)); 53 | if (25000 <= index && index < 50000) 54 | key = "test"; 55 | int label = -1; 56 | if (index < 12500) 57 | label = 1; 58 | if (12500 <= index && index < 25000) 59 | label = 0; 60 | if (25000 <= index && index < 25000 + 12500) 61 | label = 1; 62 | if (25000 + 12500 <= index && index < 25000 + 25000) 63 | label = 0; 64 | if (addData) { 65 | if (label != -1) 66 | continue; 67 | } else { 68 | if (label == -1) 69 | continue; 70 | } 71 | String content = line.substring(line.split(" ")[0].length()).trim(); 72 | // System.out.println(content); 73 | Text text = new Text(content, key, label, ngram, dic); 74 | d.add(text); 75 | if (readDataCount++ % 1000 == 0) { 76 | 
System.out.print("."); 77 | } 78 | } 79 | reader.close(); 80 | } catch (Exception e) { 81 | e.printStackTrace(); 82 | } 83 | return d; 84 | } 85 | 86 | public Dataset(String fileName, String type, int ngram, String addDataType) { 87 | System.out.print("reading dataset:" + fileName); 88 | textList = new ArrayList(); 89 | if (type.equals("unlabel")) 90 | textList = getTrainDataset("./data/" + fileName, ngram, 1); 91 | if (type.equals("imdb")) 92 | textList = getImdbDataset("./data/" + fileName, ngram, false); 93 | 94 | // some additional unlabeled data 95 | if (addDataType.contains("news")) { 96 | String folder = "./data/unlabeled/news.txt"; 97 | List adtl = getTrainDataset(folder, ngram, Integer.parseInt(addDataType.split("news")[1])); 98 | for (Text t : adtl) { 99 | t.type = "dev"; 100 | } 101 | textList.addAll(adtl); 102 | } 103 | if (addDataType.contains("wb")) { 104 | String folder = "./data/unlabeled/wb.txt"; 105 | List adtl = getTrainDataset(folder, ngram, Integer.parseInt(addDataType.split("wb")[1])); 106 | for (Text t : adtl) { 107 | t.type = "dev"; 108 | } 109 | textList.addAll(adtl); 110 | } 111 | if (addDataType.contains("books")) { 112 | { 113 | String folder = "./data/unlabeled/books_large_p1.txt"; 114 | List adtl = getTrainDataset(folder, ngram, Integer.parseInt(addDataType.split("books")[1])); 115 | for (Text t : adtl) { 116 | t.type = "dev"; 117 | } 118 | textList.addAll(adtl); 119 | } 120 | } 121 | if (addDataType.contains("sick")) { 122 | String folder = "./data/unlabeled/sick.txt"; 123 | List adtl = getTrainDataset(folder, ngram, Integer.parseInt(addDataType.split("sick")[1])); 124 | for (Text t : adtl) { 125 | t.type = "dev"; 126 | } 127 | textList.addAll(adtl); 128 | } 129 | if (addDataType.contains("imdbd")) { 130 | String path = "./data/" + "imdb.txt"; 131 | List adtl = getImdbDataset(path, ngram, true); 132 | for (Text t : adtl) { 133 | t.type = "dev"; 134 | } 135 | textList.addAll(adtl); 136 | } 137 | 138 | // set dictionary random word rate 
139 | dic.setRandomFactor(Main.pow); 140 | 141 | System.out.println(); 142 | showDetail(); 143 | System.out.println("reading finished"); 144 | } 145 | 146 | public List getTextList() { 147 | return textList; 148 | } 149 | 150 | public void showDetail() { 151 | System.out.println("text size = " + textList.size()); 152 | System.out.println("vocab size = " + dic.uniqueWordSize()); 153 | System.out.println("total vocab size = " + dic.totalWordSize); 154 | double length = 0; 155 | for (Text text : textList) { 156 | length += text.getIds(false).size(); 157 | } 158 | length /= textList.size(); 159 | System.out.println("avg length = " + length); 160 | double l1 = 0; 161 | double l2 = 0; 162 | for (Text text : textList) { 163 | double t = text.getIds(false).size() - length; 164 | l1 += Math.abs(t); 165 | l2 += t * t; 166 | } 167 | l1 /= textList.size(); 168 | l2 /= textList.size(); 169 | l2 = Math.sqrt(l2); 170 | System.out.println("l1 = " + l1); 171 | System.out.println("l2 = " + l2); 172 | } 173 | 174 | } 175 | -------------------------------------------------------------------------------- /src/myUtils/Dictionary.java: -------------------------------------------------------------------------------- 1 | package myUtils; 2 | 3 | import java.util.*; 4 | 5 | public class Dictionary { 6 | public static Random random = new Random(); 7 | 8 | public Map wordIdMap; 9 | public List wordIdCountList; 10 | public Map idWordMap; 11 | public double wordIdSumList[]; 12 | public int totalWordSize = 0; // how many unique word 13 | private int uniqueWordSize = 0; 14 | 15 | public Dictionary() { 16 | wordIdMap = new HashMap(); 17 | idWordMap = new HashMap(); 18 | wordIdCountList = new ArrayList(); 19 | } 20 | 21 | public int uniqueWordSize() { 22 | return uniqueWordSize; 23 | } 24 | 25 | public int addWord(String w) { 26 | totalWordSize++; 27 | if (!wordIdMap.containsKey(w)) { 28 | wordIdMap.put(w, wordIdCountList.size()); 29 | idWordMap.put(wordIdCountList.size(), w); 30 | 
wordIdCountList.add(0); 31 | } 32 | int index = wordIdMap.get(w); 33 | wordIdCountList.set(index, wordIdCountList.get(index) + 1); 34 | return index; 35 | } 36 | 37 | public void setRandomFactor(double factor) { 38 | wordIdSumList = new double[wordIdCountList.size()]; 39 | wordIdSumList[0] = (Math.pow(wordIdCountList.get(0), factor)); 40 | for (int t = 1; t < wordIdCountList.size(); t++) { 41 | wordIdSumList[t] = (Math.pow(wordIdCountList.get(t), factor) + wordIdSumList[t - 1]); 42 | } 43 | uniqueWordSize = wordIdMap.size(); 44 | // wordIdCountList.clear(); 45 | // wordIdCountList = null; 46 | // wordIdMap.clear(); 47 | // wordIdMap = null; 48 | } 49 | 50 | public int getRandomWord() { 51 | 52 | double i = random.nextDouble() * (wordIdSumList[wordIdSumList.length - 1]); 53 | int l = 0, r = wordIdSumList.length - 1; 54 | while (l != r) { 55 | int m = (l + r) / 2; 56 | 57 | if ((m == 0 || wordIdSumList[m - 1] < i) && i <= wordIdSumList[m]) 58 | return m; 59 | if (i <= wordIdSumList[m]) 60 | r = m; 61 | else 62 | l = m + 1; 63 | } 64 | // System.out.println(wordIdSumList[l - 1] + "|" + i + "|" + 65 | // wordIdSumList[l]); 66 | return l; 67 | // return allWord.get(random.nextInt(allWord.size())); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/myUtils/MyMath.java: -------------------------------------------------------------------------------- 1 | package myUtils; 2 | 3 | import main.Main; 4 | 5 | public class MyMath { 6 | public static void zero(float a[]) { 7 | for (int i = 0; i < a.length; i++) 8 | a[i] = 0; 9 | } 10 | 11 | public static void addSelf(float a[], float b[]) { 12 | if (a == null) 13 | return; 14 | for (int i = 0; i < a.length; i++) 15 | a[i] += b[i]; 16 | } 17 | public static float logistic(float a[], float b[]){ 18 | float y = 0; 19 | for (int i = 0; i < a.length; i++) 20 | y += a[i] * b[i]; 21 | return logistic(y); 22 | } 23 | public static float logistic(float a){ 24 | return (float) (1.0 / (1 
+ Math.exp(-a))); 25 | } 26 | public static float tanh(float a){ 27 | return (float) ((Math.exp(a) - Math.exp(-a)) / (Math.exp(a) + Math.exp(-a))); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/myUtils/Text.java: -------------------------------------------------------------------------------- 1 | package myUtils; 2 | 3 | import java.util.*; 4 | 5 | import main.Main; 6 | 7 | public class Text { 8 | private int wordIds[][]; 9 | public String type; // train or test 10 | public int label; 11 | int ngram; 12 | private static Random random = new Random(); 13 | public String s; 14 | 15 | public Text(String s, String type, int label, int ngram, Dictionary dic) { 16 | s = s.trim(); 17 | s = s.replaceAll(" +", " "); 18 | this.s = s; 19 | this.type = type; 20 | this.label = label; 21 | this.ngram = ngram; 22 | String[] tokens = s.split(" "); 23 | 24 | wordIds = new int[ngram][]; 25 | for (int gram = 0; gram < ngram; gram++) { 26 | int tl = tokens.length - gram; 27 | if (tl < 0) 28 | tl = 0; 29 | wordIds[gram] = new int[tl]; 30 | for (int i = 0; i < tokens.length - gram; i++) { 31 | String w = ""; 32 | for (int j = 0; j <= gram; j++) 33 | w += tokens[i + j] + "_"; 34 | // System.out.println(w); 35 | 36 | int index = -1; 37 | index = dic.addWord(w); 38 | wordIds[gram][i] = index; 39 | } 40 | } 41 | } 42 | 43 | public List getIds(boolean unique) { 44 | 45 | List ids = new ArrayList(); 46 | for (int gram = 0; gram < ngram; gram++) { 47 | for (int i = 0; i < wordIds[gram].length; i++) { 48 | ids.add(wordIds[gram][i]); 49 | } 50 | } 51 | if (unique) { 52 | Set idSet = new HashSet(); 53 | idSet.addAll(ids); 54 | ids.clear(); 55 | ids.addAll(idSet); 56 | } 57 | return ids; 58 | } 59 | 60 | public List getIdPairList(int ws, Dictionary dic) { 61 | List pairList = new ArrayList(); 62 | for (int gram = 0; gram < ngram; gram++) { 63 | for (int i = 0; i < wordIds[gram].length; i++) { 64 | if (Main.subSampleRate != 0) { 65 | int cn 
= dic.wordIdCountList.get(wordIds[gram][i]); 66 | if (Math.sqrt(Main.subSampleRate * dic.totalWordSize / cn) < random.nextDouble()) 67 | continue; 68 | } 69 | int l = i - ws; 70 | int r = i + ws; 71 | if (gram == 1) {// bigram 72 | l = i - ws; 73 | r = i + ws + 1; 74 | } 75 | if (gram == 2) { // trigram 76 | l = i - ws; 77 | r = i + ws + 2; 78 | } 79 | if (l < 0) 80 | l = 0; 81 | if (r >= wordIds[0].length) 82 | r = wordIds[0].length - 1; 83 | for (int g = 0; g < ngram; g++) 84 | for (int t = l; t <= r - g; t++) 85 | pairList.add(new Pair(wordIds[gram][i], wordIds[g][t])); 86 | } 87 | } 88 | return pairList; 89 | } 90 | 91 | public static class Pair { 92 | public int a; 93 | public int b; 94 | 95 | public Pair(int a, int b) { 96 | this.a = a; 97 | this.b = b; 98 | } 99 | } 100 | } 101 | --------------------------------------------------------------------------------