├── APSegmentor.cpp ├── APSegmentor.h ├── CMakeLists.txt ├── IntegratedLSTMSegmentor.cpp ├── IntegratedLSTMSegmentor.h ├── LSTMSegmentor.cpp ├── LSTMSegmentor.h ├── LinearSegmentor.cpp ├── LinearSegmentor.h ├── Options.h ├── README.md ├── basic ├── Action.h ├── Argument_helper.h ├── Instance.h ├── InstanceReader.h ├── InstanceWriter.h ├── Pipe.h ├── Reader.h ├── SegLookupTable.h ├── Utf.h └── Writer.h ├── cleanall.sh ├── feature ├── DenseFeature.h ├── DenseFeatureChar.h ├── DenseFeatureExtraction.h ├── DenseFeatureForward.h ├── Feature.h └── FeatureExtraction.h ├── model ├── IntegratedLSTMBeamSearcher.h ├── LSTMBeamSearcher.h └── LinearBeamSearcher.h ├── other-implementations ├── cpps │ ├── APSegmentor.cpp │ ├── APSegmentor.h │ ├── GRNNSegmentor.cpp │ ├── GRNNSegmentor.h │ ├── LSTMNASegmentor.cpp │ ├── LSTMNASegmentor.h │ ├── LSTMNBCSegmentor.cpp │ ├── LSTMNBCSegmentor.h │ ├── LSTMNCSegmentor.cpp │ ├── LSTMNCSegmentor.h │ ├── LSTMNUCSegmentor.cpp │ ├── LSTMNUCSegmentor.h │ ├── LSTMNWSegmentor.cpp │ ├── LSTMNWSegmentor.h │ ├── RNNSegmentor.cpp │ ├── RNNSegmentor.h │ ├── TNNSegmentor.cpp │ └── TNNSegmentor.h └── models │ ├── APBeamSearcher.h │ ├── GRNNBeamSearcher.h │ ├── LSTMNABeamSearcher.h │ ├── LSTMNBCBeamSearcher.h │ ├── LSTMNCBeamSearcher.h │ ├── LSTMNUCBeamSearcher.h │ ├── LSTMNWBeamSearcher.h │ ├── RNNBeamSearcher.h │ └── TNNBeamSearcher.h └── state ├── NeuralState.h └── State.h /APSegmentor.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.cpp 3 | * 4 | * Created on: Oct 23, 2015 5 | * Author: mszhang 6 | */ 7 | 8 | #include "APSegmentor.h" 9 | 10 | #include "Argument_helper.h" 11 | 12 | Segmentor::Segmentor() { 13 | // TODO Auto-generated constructor stub 14 | nullkey = "-null-"; 15 | unknownkey = "-unknown-"; 16 | paddingtag = "-padding-"; 17 | seperateKey = "#"; 18 | } 19 | 20 | Segmentor::~Segmentor() { 21 | // TODO Auto-generated destructor stub 22 | } 23 | 24 | // all linear features are 
extracted from positive examples 25 | int Segmentor::createAlphabet(const vector& vecInsts) { 26 | cout << "Creating Alphabet..." << endl; 27 | 28 | int numInstance = vecInsts.size(); 29 | 30 | hash_map word_stat; 31 | hash_map char_stat; 32 | hash_map bichar_stat; 33 | hash_map action_stat; 34 | hash_map feat_stat; 35 | 36 | assert(numInstance > 0); 37 | 38 | static Metric eval; 39 | static CStateItem state[m_classifier.MAX_SENTENCE_SIZE]; 40 | static Feature feat; 41 | static vector output; 42 | static CAction answer; 43 | static int actionNum; 44 | m_classifier.initAlphabet(); 45 | eval.reset(); 46 | for (numInstance = 0; numInstance < vecInsts.size(); numInstance++) { 47 | const Instance &instance = vecInsts[numInstance]; 48 | for (int idx = 0; idx < instance.wordsize(); idx++) { 49 | word_stat[normalize_to_lowerwithdigit(instance.words[idx])]++; 50 | } 51 | for (int idx = 0; idx < instance.charsize(); idx++) { 52 | char_stat[instance.chars[idx]]++; 53 | } 54 | for (int idx = 0; idx < instance.charsize() - 1; idx++) { 55 | bichar_stat[instance.chars[idx] + instance.chars[idx + 1]]++; 56 | } 57 | bichar_stat[instance.chars[instance.charsize() - 1] + m_classifier.fe.nullkey]++; 58 | bichar_stat[m_classifier.fe.nullkey + instance.chars[0]]++; 59 | actionNum = 0; 60 | state[actionNum].initSentence(&instance.chars); 61 | state[actionNum].clear(); 62 | 63 | while (!state[actionNum].IsTerminated()) { 64 | state[actionNum].getGoldAction(instance.words, answer); 65 | action_stat[answer.str()]++; 66 | 67 | m_classifier.extractFeature(state+actionNum, answer, feat); 68 | for (int idx = 0; idx < feat._strSparseFeat.size(); idx++) { 69 | feat_stat[feat._strSparseFeat[idx]]++; 70 | } 71 | state[actionNum].move(state+actionNum+1, answer); 72 | actionNum++; 73 | } 74 | 75 | if(actionNum-1 != instance.charsize()) { 76 | std::cout << "action number is not correct, please check" << std::endl; 77 | } 78 | state[actionNum].getSegResults(output); 79 | 80 | instance.evaluate(output, 
eval); 81 | 82 | if (!eval.bIdentical()) { 83 | std::cout << "error state conversion!" << std::endl; 84 | exit(0); 85 | } 86 | 87 | if ((numInstance + 1) % m_options.verboseIter == 0) { 88 | cout << numInstance + 1 << " "; 89 | if ((numInstance + 1) % (40 * m_options.verboseIter) == 0) 90 | cout << std::endl; 91 | cout.flush(); 92 | } 93 | if (m_options.maxInstance > 0 && numInstance == m_options.maxInstance) 94 | break; 95 | } 96 | 97 | m_classifier.addToActionAlphabet(action_stat); 98 | m_classifier.addToWordAlphabet(word_stat, m_options.wordEmbFineTune ? m_options.wordCutOff : 0); 99 | m_classifier.addToCharAlphabet(char_stat, m_options.charEmbFineTune ? m_options.charCutOff : 0); 100 | m_classifier.addToBiCharAlphabet(bichar_stat, m_options.bicharEmbFineTune ? m_options.bicharCutOff : 0); 101 | m_classifier.addToFeatureAlphabet(feat_stat, m_options.featCutOff); 102 | 103 | cout << numInstance << " " << endl; 104 | cout << "Action num: " << m_classifier.fe._actionAlphabet.size() << endl; 105 | cout << "Total word num: " << word_stat.size() << endl; 106 | cout << "Total char num: " << char_stat.size() << endl; 107 | cout << "Total bichar num: " << bichar_stat.size() << endl; 108 | cout << "Total feat num: " << feat_stat.size() << endl; 109 | 110 | cout << "Remain word num: " << m_classifier.fe._wordAlphabet.size() << endl; 111 | cout << "Remain char num: " << m_classifier.fe._charAlphabet.size() << endl; 112 | cout << "Remain bichar num: " << m_classifier.fe._bicharAlphabet.size() << endl; 113 | cout << "Remain feat num: " << m_classifier.fe._featAlphabet.size() << endl; 114 | 115 | //m_classifier.setFeatureCollectionState(false); 116 | 117 | return 0; 118 | } 119 | 120 | int Segmentor::addTestWordAlpha(const vector& vecInsts) { 121 | cout << "Add test Alphabet..." 
<< endl; 122 | 123 | hash_map word_stat; 124 | hash_map char_stat; 125 | hash_map bichar_stat; 126 | int numInstance; 127 | 128 | for (numInstance = 0; numInstance < vecInsts.size(); numInstance++) { 129 | const Instance &instance = vecInsts[numInstance]; 130 | 131 | for (int idx = 0; idx < instance.wordsize(); idx++) { 132 | word_stat[normalize_to_lowerwithdigit(instance.words[idx])]++; 133 | } 134 | for (int idx = 0; idx < instance.charsize(); idx++) { 135 | char_stat[instance.chars[idx]]++; 136 | } 137 | for (int idx = 0; idx < instance.charsize() - 1; idx++) { 138 | bichar_stat[instance.chars[idx] + instance.chars[idx + 1]]++; 139 | } 140 | bichar_stat[instance.chars[instance.charsize() - 1] + m_classifier.fe.nullkey]++; 141 | bichar_stat[m_classifier.fe.nullkey + instance.chars[0]]++; 142 | 143 | if ((numInstance + 1) % m_options.verboseIter == 0) { 144 | cout << numInstance + 1 << " "; 145 | if ((numInstance + 1) % (40 * m_options.verboseIter) == 0) 146 | cout << std::endl; 147 | cout.flush(); 148 | } 149 | if (m_options.maxInstance > 0 && numInstance == m_options.maxInstance) 150 | break; 151 | } 152 | 153 | if (!m_options.wordEmbFineTune) 154 | m_classifier.addToWordAlphabet(word_stat, 0); 155 | if (!m_options.charEmbFineTune) 156 | m_classifier.addToCharAlphabet(char_stat, 0); 157 | if (!m_options.bicharEmbFineTune) 158 | m_classifier.addToBiCharAlphabet(bichar_stat, 0); 159 | 160 | cout << "Remain word num: " << m_classifier.fe._wordAlphabet.size() << endl; 161 | cout << "Remain char num: " << m_classifier.fe._charAlphabet.size() << endl; 162 | cout << "Remain bichar num: " << m_classifier.fe._bicharAlphabet.size() << endl; 163 | 164 | return 0; 165 | } 166 | 167 | 168 | void Segmentor::getGoldActions(const vector& vecInsts, vector >& vecActions){ 169 | vecActions.clear(); 170 | 171 | static Metric eval; 172 | static CStateItem state[m_classifier.MAX_SENTENCE_SIZE]; 173 | static vector output; 174 | static CAction answer; 175 | eval.reset(); 176 | static 
int numInstance, actionNum; 177 | vecActions.resize(vecInsts.size()); 178 | for (numInstance = 0; numInstance < vecInsts.size(); numInstance++) { 179 | const Instance &instance = vecInsts[numInstance]; 180 | 181 | actionNum = 0; 182 | state[actionNum].initSentence(&instance.chars); 183 | state[actionNum].clear(); 184 | 185 | while (!state[actionNum].IsTerminated()) { 186 | state[actionNum].getGoldAction(instance.words, answer); 187 | vecActions[numInstance].push_back(answer); 188 | state[actionNum].move(state+actionNum+1, answer); 189 | actionNum++; 190 | } 191 | 192 | if(actionNum-1 != instance.charsize()) { 193 | std::cout << "action number is not correct, please check" << std::endl; 194 | } 195 | state[actionNum].getSegResults(output); 196 | 197 | instance.evaluate(output, eval); 198 | 199 | if (!eval.bIdentical()) { 200 | std::cout << "error state conversion!" << std::endl; 201 | exit(0); 202 | } 203 | 204 | if ((numInstance + 1) % m_options.verboseIter == 0) { 205 | cout << numInstance + 1 << " "; 206 | if ((numInstance + 1) % (40 * m_options.verboseIter) == 0) 207 | cout << std::endl; 208 | cout.flush(); 209 | } 210 | if (m_options.maxInstance > 0 && numInstance == m_options.maxInstance) 211 | break; 212 | } 213 | } 214 | 215 | void Segmentor::train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 216 | const string& wordEmbFile) { 217 | if (optionFile != "") 218 | m_options.load(optionFile); 219 | 220 | m_options.showOptions(); 221 | vector trainInsts, devInsts, testInsts; 222 | m_pipe.readInstances(trainFile, trainInsts, m_classifier.MAX_SENTENCE_SIZE, m_options.maxInstance); 223 | if (devFile != "") 224 | m_pipe.readInstances(devFile, devInsts, m_classifier.MAX_SENTENCE_SIZE, m_options.maxInstance); 225 | if (testFile != "") 226 | m_pipe.readInstances(testFile, testInsts, m_classifier.MAX_SENTENCE_SIZE, m_options.maxInstance); 227 | 228 | vector > 
otherInsts(m_options.testFiles.size()); 229 | for (int idx = 0; idx < m_options.testFiles.size(); idx++) { 230 | m_pipe.readInstances(m_options.testFiles[idx], otherInsts[idx], m_classifier.MAX_SENTENCE_SIZE, m_options.maxInstance); 231 | } 232 | 233 | createAlphabet(trainInsts); 234 | 235 | addTestWordAlpha(devInsts); 236 | addTestWordAlpha(testInsts); 237 | for (int idx = 0; idx < otherInsts.size(); idx++) { 238 | addTestWordAlpha(otherInsts[idx]); 239 | } 240 | 241 | 242 | m_classifier.init(); 243 | m_classifier.setDropValue(m_options.dropProb); 244 | 245 | vector > trainInstGoldactions; 246 | getGoldActions(trainInsts, trainInstGoldactions); 247 | double bestFmeasure = 0; 248 | 249 | int inputSize = trainInsts.size(); 250 | 251 | std::vector indexes; 252 | for (int i = 0; i < inputSize; ++i) 253 | indexes.push_back(i); 254 | 255 | static Metric eval, metric_dev, metric_test; 256 | 257 | int maxIter = m_options.maxIter * (inputSize / m_options.batchSize + 1); 258 | int oneIterMaxRound = (inputSize + m_options.batchSize -1) / m_options.batchSize; 259 | std::cout << "maxIter = " << maxIter << std::endl; 260 | int devNum = devInsts.size(), testNum = testInsts.size(); 261 | 262 | static vector > decodeInstResults; 263 | static vector curDecodeInst; 264 | static bool bCurIterBetter; 265 | static vector > subInstances; 266 | static vector > subInstGoldActions; 267 | 268 | //m_classifier.setAlphaIncreasing(true); 269 | for (int iter = 0; iter < maxIter; ++iter) { 270 | std::cout << "##### Iteration " << iter << std::endl; 271 | srand(iter); 272 | //random_shuffle(indexes.begin(), indexes.end()); 273 | std::cout << "random: " << indexes[0] << ", " << indexes[indexes.size() - 1] << std::endl; 274 | eval.reset(); 275 | for (int updateIter = 0; updateIter < oneIterMaxRound; updateIter++) { 276 | int start_pos = updateIter * m_options.batchSize; 277 | int end_pos = (updateIter + 1) * m_options.batchSize; 278 | if (end_pos > inputSize) 279 | end_pos = inputSize; 280 | 
subInstances.clear(); 281 | subInstGoldActions.clear(); 282 | for (int idy = start_pos; idy < end_pos; idy++) { 283 | subInstances.push_back(trainInsts[indexes[idy]].chars); 284 | subInstGoldActions.push_back(trainInstGoldactions[indexes[idy]]); 285 | } 286 | 287 | double cost = m_classifier.train(subInstances, subInstGoldActions); 288 | 289 | eval.overall_label_count += m_classifier._eval.overall_label_count; 290 | eval.correct_label_count += m_classifier._eval.correct_label_count; 291 | 292 | //if ((updateIter + 1) % (m_options.verboseIter*10) == 0) { 293 | //std::cout << "current: " << updateIter + 1 << ", Cost = " << cost << ", Correct(%) = " << eval.getAccuracy() << std::endl; 294 | //} 295 | m_classifier.updateParams(m_options.regParameter, m_options.adaAlpha, m_options.adaEps); 296 | } 297 | 298 | std::cout << "current: " << iter + 1 << ", Correct(%) = " << eval.getAccuracy() << std::endl; 299 | 300 | if (devNum > 0) { 301 | bCurIterBetter = false; 302 | if (!m_options.outBest.empty()) 303 | decodeInstResults.clear(); 304 | metric_dev.reset(); 305 | for (int idx = 0; idx < devInsts.size(); idx++) { 306 | predict(devInsts[idx], curDecodeInst); 307 | devInsts[idx].evaluate(curDecodeInst, metric_dev); 308 | if (!m_options.outBest.empty()) { 309 | decodeInstResults.push_back(curDecodeInst); 310 | } 311 | } 312 | std::cout << "dev:" << std::endl; 313 | metric_dev.print(); 314 | 315 | if (!m_options.outBest.empty() && metric_dev.getAccuracy() > bestFmeasure) { 316 | m_pipe.outputAllInstances(devFile + m_options.outBest, decodeInstResults); 317 | bCurIterBetter = true; 318 | } 319 | 320 | if (testNum > 0) { 321 | if (!m_options.outBest.empty()) 322 | decodeInstResults.clear(); 323 | metric_test.reset(); 324 | for (int idx = 0; idx < testInsts.size(); idx++) { 325 | predict(testInsts[idx], curDecodeInst); 326 | testInsts[idx].evaluate(curDecodeInst, metric_test); 327 | if (bCurIterBetter && !m_options.outBest.empty()) { 328 | 
decodeInstResults.push_back(curDecodeInst); 329 | } 330 | } 331 | std::cout << "test:" << std::endl; 332 | metric_test.print(); 333 | 334 | if (!m_options.outBest.empty() && bCurIterBetter) { 335 | m_pipe.outputAllInstances(testFile + m_options.outBest, decodeInstResults); 336 | } 337 | } 338 | 339 | for (int idx = 0; idx < otherInsts.size(); idx++) { 340 | std::cout << "processing " << m_options.testFiles[idx] << std::endl; 341 | if (!m_options.outBest.empty()) 342 | decodeInstResults.clear(); 343 | metric_test.reset(); 344 | for (int idy = 0; idy < otherInsts[idx].size(); idy++) { 345 | predict(otherInsts[idx][idy], curDecodeInst); 346 | otherInsts[idx][idy].evaluate(curDecodeInst, metric_test); 347 | if (bCurIterBetter && !m_options.outBest.empty()) { 348 | decodeInstResults.push_back(curDecodeInst); 349 | } 350 | } 351 | std::cout << "test:" << std::endl; 352 | metric_test.print(); 353 | 354 | if (!m_options.outBest.empty() && bCurIterBetter) { 355 | m_pipe.outputAllInstances(m_options.testFiles[idx] + m_options.outBest, decodeInstResults); 356 | } 357 | } 358 | 359 | 360 | if (m_options.saveIntermediate && metric_dev.getAccuracy() > bestFmeasure) { 361 | std::cout << "Exceeds best previous DIS of " << bestFmeasure << ". Saving model file.." 
<< std::endl; 362 | bestFmeasure = metric_dev.getAccuracy(); 363 | writeModelFile(modelFile); 364 | } 365 | } 366 | } 367 | } 368 | 369 | void Segmentor::predict(const Instance& input, vector& output) { 370 | m_classifier.decode(input.chars, output); 371 | } 372 | 373 | void Segmentor::test(const string& testFile, const string& outputFile, const string& modelFile) { 374 | loadModelFile(modelFile); 375 | vector testInsts; 376 | m_pipe.readInstances(testFile, testInsts, m_options.maxInstance); 377 | 378 | vector > testInstResults(testInsts.size()); 379 | Metric metric_test; 380 | metric_test.reset(); 381 | for (int idx = 0; idx < testInsts.size(); idx++) { 382 | vector result_labels; 383 | predict(testInsts[idx], testInstResults[idx]); 384 | testInsts[idx].evaluate(testInstResults[idx], metric_test); 385 | } 386 | std::cout << "test:" << std::endl; 387 | metric_test.print(); 388 | 389 | std::ofstream os(outputFile.c_str()); 390 | 391 | for (int idx = 0; idx < testInsts.size(); idx++) { 392 | for(int idy = 0; idy < testInstResults[idx].size(); idy++){ 393 | os << testInstResults[idx][idy] << " "; 394 | } 395 | os << std::endl; 396 | } 397 | os.close(); 398 | } 399 | 400 | /* 401 | void Segmentor::readWordEmbeddings(const string& inFile, NRMat& wordEmb) { 402 | static ifstream inf; 403 | if (inf.is_open()) { 404 | inf.close(); 405 | inf.clear(); 406 | } 407 | inf.open(inFile.c_str()); 408 | 409 | static string strLine, curWord; 410 | static int wordId; 411 | 412 | //find the first line, decide the wordDim; 413 | while (1) { 414 | if (!my_getline(inf, strLine)) { 415 | break; 416 | } 417 | if (!strLine.empty()) 418 | break; 419 | } 420 | 421 | int unknownId = m_wordAlphabet.from_string(unknownkey); 422 | 423 | static vector vecInfo; 424 | split_bychar(strLine, vecInfo, ' '); 425 | int wordDim = vecInfo.size() - 1; 426 | 427 | std::cout << "word embedding dim is " << wordDim << std::endl; 428 | m_options.wordEmbSize = wordDim; 429 | 430 | 
wordEmb.resize(m_wordAlphabet.size(), wordDim); 431 | wordEmb = 0.0; 432 | curWord = normalize_to_lowerwithdigit(vecInfo[0]); 433 | wordId = m_wordAlphabet.from_string(curWord); 434 | hash_set indexers; 435 | dtype sum[wordDim]; 436 | int count = 0; 437 | bool bHasUnknown = false; 438 | if (wordId >= 0) { 439 | count++; 440 | if (unknownId == wordId) 441 | bHasUnknown = true; 442 | indexers.insert(wordId); 443 | for (int idx = 0; idx < wordDim; idx++) { 444 | dtype curValue = atof(vecInfo[idx + 1].c_str()); 445 | sum[idx] = curValue; 446 | wordEmb[wordId][idx] = curValue; 447 | } 448 | 449 | } else { 450 | for (int idx = 0; idx < wordDim; idx++) { 451 | sum[idx] = 0.0; 452 | } 453 | } 454 | 455 | while (1) { 456 | if (!my_getline(inf, strLine)) { 457 | break; 458 | } 459 | if (strLine.empty()) 460 | continue; 461 | split_bychar(strLine, vecInfo, ' '); 462 | if (vecInfo.size() != wordDim + 1) { 463 | std::cout << "error embedding file" << std::endl; 464 | } 465 | curWord = normalize_to_lowerwithdigit(vecInfo[0]); 466 | wordId = m_wordAlphabet.from_string(curWord); 467 | if (wordId >= 0) { 468 | count++; 469 | if (unknownId == wordId) 470 | bHasUnknown = true; 471 | indexers.insert(wordId); 472 | 473 | for (int idx = 0; idx < wordDim; idx++) { 474 | dtype curValue = atof(vecInfo[idx + 1].c_str()); 475 | sum[idx] = curValue; 476 | wordEmb[wordId][idx] += curValue; 477 | } 478 | } 479 | 480 | } 481 | 482 | if (!bHasUnknown) { 483 | for (int idx = 0; idx < wordDim; idx++) { 484 | wordEmb[unknownId][idx] = sum[idx] / count; 485 | } 486 | count++; 487 | std::cout << unknownkey << " not found, using averaged value to initialize." 
<< std::endl; 488 | } 489 | 490 | int oovWords = 0; 491 | int totalWords = 0; 492 | for (int id = 0; id < m_wordAlphabet.size(); id++) { 493 | if (indexers.find(id) == indexers.end()) { 494 | oovWords++; 495 | for (int idx = 0; idx < wordDim; idx++) { 496 | wordEmb[id][idx] = wordEmb[unknownId][idx]; 497 | } 498 | } 499 | totalWords++; 500 | } 501 | 502 | std::cout << "OOV num is " << oovWords << ", total num is " << m_wordAlphabet.size() << ", embedding oov ratio is " << oovWords * 1.0 / m_wordAlphabet.size() 503 | << std::endl; 504 | 505 | } 506 | 507 | void Segmentor::readWordClusters(const string& inFile) { 508 | static ifstream inf; 509 | if (inf.is_open()) { 510 | inf.close(); 511 | inf.clear(); 512 | } 513 | inf.open(inFile.c_str()); 514 | 515 | static string strLine, curWord; 516 | static int wordId; 517 | 518 | //find the first line, decide the wordDim; 519 | while (1) { 520 | if (!my_getline(inf, strLine)) { 521 | break; 522 | } 523 | if (!strLine.empty()) 524 | break; 525 | } 526 | 527 | int unknownId = m_wordAlphabet.from_string(unknownkey); 528 | 529 | static vector vecInfo; 530 | split_bychar(strLine, vecInfo, ' '); 531 | int wordClusterNum = vecInfo.size() - 1; 532 | 533 | std::cout << "word cluster number is " << wordClusterNum << std::endl; 534 | 535 | m_wordClusters.resize(m_wordAlphabet.size(), wordClusterNum); 536 | m_wordClusters = "0"; 537 | curWord = normalize_to_lowerwithdigit(vecInfo[0]); 538 | wordId = m_wordAlphabet.from_string(curWord); 539 | hash_set indexers; 540 | int count = 0; 541 | bool bHasUnknown = false; 542 | if (wordId >= 0) { 543 | count++; 544 | if (unknownId == wordId) 545 | bHasUnknown = true; 546 | indexers.insert(wordId); 547 | for (int idx = 0; idx < wordClusterNum; idx++) { 548 | m_wordClusters[wordId][idx] = vecInfo[idx + 1]; 549 | } 550 | } 551 | 552 | while (1) { 553 | if (!my_getline(inf, strLine)) { 554 | break; 555 | } 556 | if (strLine.empty()) 557 | continue; 558 | split_bychar(strLine, vecInfo, ' '); 559 | if 
(vecInfo.size() != wordClusterNum + 1) { 560 | std::cout << "error embedding file" << std::endl; 561 | } 562 | curWord = normalize_to_lowerwithdigit(vecInfo[0]); 563 | wordId = m_wordAlphabet.from_string(curWord); 564 | if (wordId >= 0) { 565 | count++; 566 | if (unknownId == wordId) 567 | bHasUnknown = true; 568 | indexers.insert(wordId); 569 | for (int idx = 0; idx < wordClusterNum; idx++) { 570 | m_wordClusters[wordId][idx] = vecInfo[idx + 1]; 571 | } 572 | } 573 | 574 | } 575 | 576 | int oovWords = 0; 577 | int totalWords = 0; 578 | for (int id = 0; id < m_wordAlphabet.size(); id++) { 579 | if (indexers.find(id) == indexers.end()) { 580 | oovWords++; 581 | for (int idx = 0; idx < wordClusterNum; idx++) { 582 | m_wordClusters[id][idx] = m_wordClusters[unknownId][idx]; 583 | } 584 | } 585 | totalWords++; 586 | } 587 | 588 | std::cout << "OOV num is " << oovWords << ", total num is " << m_wordAlphabet.size() << ", cluster oov ratio is " << oovWords * 1.0 / m_wordAlphabet.size() 589 | << std::endl; 590 | 591 | } 592 | */ 593 | 594 | void Segmentor::loadModelFile(const string& inputModelFile) { 595 | 596 | } 597 | 598 | void Segmentor::writeModelFile(const string& outputModelFile) { 599 | 600 | } 601 | 602 | int main(int argc, char* argv[]) { 603 | std::string trainFile = "", devFile = "", testFile = "", modelFile = ""; 604 | std::string wordEmbFile = "", optionFile = ""; 605 | std::string outputFile = ""; 606 | bool bTrain = false; 607 | dsr::Argument_helper ah; 608 | 609 | ah.new_flag("l", "learn", "train or test", bTrain); 610 | ah.new_named_string("train", "trainCorpus", "named_string", "training corpus to train a model, must when training", trainFile); 611 | ah.new_named_string("dev", "devCorpus", "named_string", "development corpus to train a model, optional when training", devFile); 612 | ah.new_named_string("test", "testCorpus", "named_string", 613 | "testing corpus to train a model or input file to test a model, optional when training and must when 
testing", testFile); 614 | ah.new_named_string("model", "modelFile", "named_string", "model file, must when training and testing", modelFile); 615 | ah.new_named_string("word", "wordEmbFile", "named_string", "pretrained word embedding file to train a model, optional when training", wordEmbFile); 616 | ah.new_named_string("option", "optionFile", "named_string", "option file to train a model, optional when training", optionFile); 617 | ah.new_named_string("output", "outputFile", "named_string", "output file to test, must when testing", outputFile); 618 | 619 | ah.process(argc, argv); 620 | 621 | Segmentor segmentor; 622 | if (bTrain) { 623 | segmentor.train(trainFile, devFile, testFile, modelFile, optionFile, wordEmbFile); 624 | } else { 625 | segmentor.test(testFile, outputFile, modelFile); 626 | } 627 | 628 | //test(argv); 629 | //ah.write_values(std::cout); 630 | 631 | } 632 | -------------------------------------------------------------------------------- /APSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Mar 25, 2015 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "APBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | 22 | class Segmentor { 23 | public: 24 | std::string nullkey; 25 | std::string rootdepkey; 26 | std::string unknownkey; 27 | std::string paddingtag; 28 | std::string seperateKey; 29 | 30 | public: 31 | Segmentor(); 32 | virtual ~Segmentor(); 33 | 34 | public: 35 | 36 | #if USE_CUDA==1 37 | APBeamSearcher m_classifier; 38 | #else 39 | APBeamSearcher m_classifier; 40 | #endif 41 | 42 | Options m_options; 43 | 44 | Pipe m_pipe; 45 | 46 | public: 47 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 48 | 49 | void readWordClusters(const string& inFile); 50 | 51 | 
int createAlphabet(const vector& vecInsts); 52 | 53 | int addTestWordAlpha(const vector& vecInsts); 54 | 55 | public: 56 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 57 | const string& wordEmbFile); 58 | void predict(const Instance& input, vector& output); 59 | void test(const string& testFile, const string& outputFile, const string& modelFile); 60 | 61 | // static training 62 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 63 | 64 | 65 | public: 66 | 67 | void proceedOneStepForDecode(const Instance& inputTree, CStateItem& state, int& outlab); //may be merged with train in the future 68 | 69 | void writeModelFile(const string& outputModelFile); 70 | void loadModelFile(const string& inputModelFile); 71 | 72 | public: 73 | inline void getCandidateActions(const CStateItem &item, vector& actions) { 74 | 75 | } 76 | 77 | 78 | 79 | }; 80 | 81 | #endif /* SRC_PARSER_H_ */ 82 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(NNTransitionSegmentor) 2 | 3 | include_directories( 4 | basic/ 5 | feature/ 6 | model/ 7 | state/ 8 | # please specify where you put mshadow. 
9 | # for example, /users/mszhang/mshadow/ 10 | /opt/mshadow/ 11 | ../LibN3L/ 12 | ) 13 | 14 | add_definitions(-DUSE_CUDA=0) 15 | 16 | IF(CMAKE_BUILD_TYPE MATCHES Debug) 17 | SET( CMAKE_CXX_FLAGS "-w -msse3 -funroll-loops -O0" ) 18 | ELSE() 19 | SET( CMAKE_CXX_FLAGS "-w -msse3 -funroll-loops -O3" ) 20 | ENDIF() 21 | ####for openblas 22 | add_definitions(-DMSHADOW_USE_CUDA=0) 23 | add_definitions(-DMSHADOW_USE_CBLAS=1) 24 | add_definitions(-DMSHADOW_USE_MKL=0) 25 | 26 | SET( CMAKE_SHARED_LINKER_FLAGS "-lm -lopenblas") 27 | ####endfor openblas 28 | 29 | ####for cuda 30 | #add_definitions(-DMSHADOW_USE_CUDA=1) 31 | #add_definitions(-DMSHADOW_USE_CBLAS=1) 32 | #add_definitions(-DMSHADOW_USE_MKL=0) 33 | 34 | #SET( CMAKE_SHARED_LINKER_FLAGS "-lm -lcudart -lcublas -lcurand" ) 35 | #include_directories( 36 | # $(USE_CUDA_PATH)/include 37 | #) 38 | #LINK_DIRECTORIES($(USE_CUDA_PATH)/lib64) 39 | ####endfor cuda 40 | 41 | #add_subdirectory(basic) 42 | 43 | #aux_source_directory(. DIR_SRCS) 44 | 45 | 46 | add_executable(LSTMSegmentor LSTMSegmentor.cpp) 47 | add_executable(IntegratedLSTMSegmentor IntegratedLSTMSegmentor.cpp) 48 | add_executable(LinearSegmentor LinearSegmentor.cpp) 49 | 50 | 51 | 52 | target_link_libraries(LSTMSegmentor openblas) 53 | target_link_libraries(IntegratedLSTMSegmentor openblas) 54 | target_link_libraries(LinearSegmentor openblas) 55 | 56 | 57 | 58 | 59 | #add_executable(TNNSegmentor TNNSegmentor.cpp) 60 | #add_executable(GRNNSegmentor GRNNSegmentor.cpp) 61 | #add_executable(RNNSegmentor RNNSegmentor.cpp) 62 | #add_executable(LSTMNWSegmentor LSTMNWSegmentor.cpp) 63 | #add_executable(LSTMNASegmentor LSTMNASegmentor.cpp) 64 | #add_executable(LSTMNCSegmentor LSTMNCSegmentor.cpp) 65 | #add_executable(LSTMNUCSegmentor LSTMNUCSegmentor.cpp) 66 | #add_executable(LSTMNBCSegmentor LSTMNBCSegmentor.cpp) 67 | 68 | 69 | #target_link_libraries(TNNSegmentor openblas) 70 | #target_link_libraries(GRNNSegmentor openblas) 71 | #target_link_libraries(RNNSegmentor 
openblas) 72 | #target_link_libraries(LSTMNWSegmentor openblas) 73 | #target_link_libraries(LSTMNASegmentor openblas) 74 | #target_link_libraries(LSTMNCSegmentor openblas) 75 | #target_link_libraries(LSTMNUCSegmentor openblas) 76 | #target_link_libraries(LSTMNBCSegmentor openblas) 77 | 78 | 79 | -------------------------------------------------------------------------------- /IntegratedLSTMSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "model/IntegratedLSTMBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | class Segmentor { 22 | public: 23 | std::string nullkey; 24 | std::string rootdepkey; 25 | std::string unknownkey; 26 | std::string paddingtag; 27 | std::string seperateKey; 28 | 29 | public: 30 | Segmentor(); 31 | virtual ~Segmentor(); 32 | 33 | public: 34 | 35 | #if USE_CUDA==1 36 | IntegratedLSTMBeamSearcher m_classifier; 37 | #else 38 | IntegratedLSTMBeamSearcher m_classifier; 39 | #endif 40 | 41 | Options m_options; 42 | 43 | Pipe m_pipe; 44 | 45 | hash_map m_word_stat; 46 | 47 | public: 48 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 49 | 50 | void readWordClusters(const string& inFile); 51 | 52 | int createAlphabet(const vector& vecInsts); 53 | 54 | int addTestWordAlpha(const vector& vecInsts); 55 | 56 | int allWordAlphaEmb(const string& inFile, NRMat& emb); 57 | 58 | public: 59 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 60 | const string& wordEmbFile, const string& charEmbFile, const string& bicharEmbFile); 61 | void predict(const Instance& input, vector& output); 62 | void test(const string& testFile, const 
string& outputFile, const string& modelFile); 63 | 64 | // static training 65 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 66 | 67 | public: 68 | void readEmbeddings(Alphabet &alpha, const string& inFile, NRMat& emb); 69 | 70 | void writeModelFile(const string& outputModelFile); 71 | void loadModelFile(const string& inputModelFile); 72 | 73 | public: 74 | 75 | 76 | }; 77 | 78 | #endif /* SRC_PARSER_H_ */ 79 | -------------------------------------------------------------------------------- /LSTMSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "model/LSTMBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | class Segmentor { 22 | public: 23 | std::string nullkey; 24 | std::string rootdepkey; 25 | std::string unknownkey; 26 | std::string paddingtag; 27 | std::string seperateKey; 28 | 29 | public: 30 | Segmentor(); 31 | virtual ~Segmentor(); 32 | 33 | public: 34 | 35 | #if USE_CUDA==1 36 | LSTMBeamSearcher m_classifier; 37 | #else 38 | LSTMBeamSearcher m_classifier; 39 | #endif 40 | 41 | Options m_options; 42 | 43 | Pipe m_pipe; 44 | 45 | hash_map m_word_stat; 46 | 47 | public: 48 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 49 | 50 | void readWordClusters(const string& inFile); 51 | 52 | int createAlphabet(const vector& vecInsts); 53 | 54 | int addTestWordAlpha(const vector& vecInsts); 55 | 56 | int allWordAlphaEmb(const string& inFile, NRMat& emb); 57 | 58 | public: 59 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 60 | const string& wordEmbFile, const string& charEmbFile, const string& bicharEmbFile); 61 
| void predict(const Instance& input, vector& output); 62 | void test(const string& testFile, const string& outputFile, const string& modelFile); 63 | 64 | // static training 65 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 66 | 67 | public: 68 | void readEmbeddings(Alphabet &alpha, const string& inFile, NRMat& emb); 69 | 70 | void writeModelFile(const string& outputModelFile); 71 | void loadModelFile(const string& inputModelFile); 72 | 73 | public: 74 | 75 | 76 | }; 77 | 78 | #endif /* SRC_PARSER_H_ */ 79 | -------------------------------------------------------------------------------- /LinearSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "LinearBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | 22 | class Segmentor { 23 | public: 24 | std::string nullkey; 25 | std::string rootdepkey; 26 | std::string unknownkey; 27 | std::string paddingtag; 28 | std::string seperateKey; 29 | 30 | public: 31 | Segmentor(); 32 | virtual ~Segmentor(); 33 | 34 | public: 35 | 36 | #if USE_CUDA==1 37 | LinearBeamSearcher m_classifier; 38 | #else 39 | LinearBeamSearcher m_classifier; 40 | #endif 41 | 42 | Options m_options; 43 | 44 | Pipe m_pipe; 45 | 46 | public: 47 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 48 | 49 | void readWordClusters(const string& inFile); 50 | 51 | int createAlphabet(const vector& vecInsts); 52 | 53 | int addTestWordAlpha(const vector& vecInsts); 54 | 55 | public: 56 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 57 | const string& wordEmbFile); 58 | void predict(const Instance& input, vector& 
output); 59 | void test(const string& testFile, const string& outputFile, const string& modelFile); 60 | 61 | // static training 62 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 63 | 64 | 65 | public: 66 | 67 | void proceedOneStepForDecode(const Instance& inputTree, CStateItem& state, int& outlab); //may be merged with train in the future 68 | 69 | void writeModelFile(const string& outputModelFile); 70 | void loadModelFile(const string& inputModelFile); 71 | 72 | public: 73 | inline void getCandidateActions(const CStateItem &item, vector& actions) { 74 | 75 | } 76 | 77 | 78 | 79 | }; 80 | 81 | #endif /* SRC_PARSER_H_ */ 82 | -------------------------------------------------------------------------------- /Options.h: -------------------------------------------------------------------------------- 1 | #ifndef _PARSER_OPTIONS_ 2 | #define _PARSER_OPTIONS_ 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "N3L.h" 11 | 12 | using namespace std; 13 | 14 | class Options { // hyper-parameters: the constructor sets defaults, setOptions()/load() override them from key=value lines 15 | public: 16 | 17 | int wordCutOff; 18 | int featCutOff; 19 | int charCutOff; 20 | int bicharCutOff; 21 | dtype initRange; 22 | int maxIter; 23 | int batchSize; 24 | dtype adaEps; 25 | dtype adaAlpha; 26 | dtype regParameter; 27 | dtype dropProb; 28 | dtype delta; 29 | dtype clip; 30 | dtype oovRatio; 31 | 32 | int sepHiddenSize; 33 | int appHiddenSize; 34 | 35 | int wordEmbSize; 36 | int lengthEmbSize; 37 | int wordNgram; 38 | int wordHiddenSize; 39 | int wordRNNHiddenSize; 40 | bool wordEmbFineTune; 41 | 42 | int charEmbSize; 43 | int bicharEmbSize; 44 | int charcontext; 45 | int charHiddenSize; 46 | int charRNNHiddenSize; 47 | bool charEmbFineTune; 48 | bool bicharEmbFineTune; 49 | 50 | int actionEmbSize; 51 | int actionNgram; 52 | int actionHiddenSize; 53 | int actionRNNHiddenSize; 54 | 55 | int verboseIter; 56 | bool saveIntermediate; 57 | bool train; 58 | int maxInstance; 59 | vector testFiles; 60 | string outBest; 61 | 62 |
Options() { 63 | wordCutOff = 4; 64 | featCutOff = 0; 65 | charCutOff = 0; 66 | bicharCutOff = 0; 67 | initRange = 0.01; 68 | maxIter = 1000; 69 | batchSize = 1; 70 | adaEps = 1e-6; 71 | adaAlpha = 0.01; 72 | regParameter = 1e-8; 73 | dropProb = 0.0; 74 | delta = 0.2; 75 | clip = -1.0; 76 | oovRatio = 0.2; 77 | 78 | sepHiddenSize = 100; 79 | appHiddenSize = 80; 80 | 81 | wordEmbSize = 0; 82 | lengthEmbSize = 20; 83 | wordNgram = 2; 84 | wordHiddenSize = 150; 85 | wordRNNHiddenSize = 100; 86 | wordEmbFineTune = true; 87 | 88 | charEmbSize = 50; 89 | bicharEmbSize = 50; 90 | charcontext = 2; 91 | charHiddenSize = 150; 92 | charRNNHiddenSize = 100; 93 | charEmbFineTune = false; 94 | bicharEmbFineTune = false; 95 | 96 | actionEmbSize = 20; 97 | actionNgram = 2; 98 | actionHiddenSize = 30; 99 | actionRNNHiddenSize = 20; 100 | 101 | verboseIter = 100; 102 | saveIntermediate = true; 103 | train = false; 104 | maxInstance = -1; 105 | testFiles.clear(); 106 | outBest = ""; 107 | 108 | } 109 | 110 | virtual ~Options() { 111 | 112 | } 113 | 114 | void setOptions(const vector &vecOption) { 115 | int i = 0; 116 | for (; i < vecOption.size(); ++i) { 117 | pair pr; 118 | string2pair(vecOption[i], pr, '='); 119 | if (pr.first == "wordCutOff") 120 | wordCutOff = atoi(pr.second.c_str()); 121 | if (pr.first == "featCutOff") 122 | featCutOff = atoi(pr.second.c_str()); 123 | if (pr.first == "charCutOff") 124 | charCutOff = atoi(pr.second.c_str()); 125 | if (pr.first == "bicharCutOff") 126 | bicharCutOff = atoi(pr.second.c_str()); 127 | if (pr.first == "initRange") 128 | initRange = atof(pr.second.c_str()); 129 | if (pr.first == "maxIter") 130 | maxIter = atoi(pr.second.c_str()); 131 | if (pr.first == "batchSize") 132 | batchSize = atoi(pr.second.c_str()); 133 | if (pr.first == "adaEps") 134 | adaEps = atof(pr.second.c_str()); 135 | if (pr.first == "adaAlpha") 136 | adaAlpha = atof(pr.second.c_str()); 137 | if (pr.first == "regParameter") 138 | regParameter = atof(pr.second.c_str());
139 | if (pr.first == "dropProb") 140 | dropProb = atof(pr.second.c_str()); 141 | if (pr.first == "delta") 142 | delta = atof(pr.second.c_str()); 143 | if (pr.first == "clip") 144 | clip = atof(pr.second.c_str()); 145 | if (pr.first == "oovRatio") 146 | oovRatio = atof(pr.second.c_str()); 147 | 148 | if (pr.first == "sepHiddenSize") 149 | sepHiddenSize = atoi(pr.second.c_str()); 150 | if (pr.first == "appHiddenSize") 151 | appHiddenSize = atoi(pr.second.c_str()); 152 | 153 | if (pr.first == "wordEmbSize") 154 | wordEmbSize = atoi(pr.second.c_str()); 155 | if (pr.first == "lengthEmbSize") 156 | lengthEmbSize = atoi(pr.second.c_str()); 157 | if (pr.first == "wordNgram") 158 | wordNgram = atoi(pr.second.c_str()); 159 | if (pr.first == "wordHiddenSize") 160 | wordHiddenSize = atoi(pr.second.c_str()); 161 | if (pr.first == "wordRNNHiddenSize") 162 | wordRNNHiddenSize = atoi(pr.second.c_str()); 163 | if (pr.first == "wordEmbFineTune") 164 | wordEmbFineTune = (pr.second == "true") ? true : false; 165 | 166 | if (pr.first == "charEmbSize") 167 | charEmbSize = atoi(pr.second.c_str()); 168 | if (pr.first == "bicharEmbSize") 169 | bicharEmbSize = atoi(pr.second.c_str()); 170 | if (pr.first == "charcontext") 171 | charcontext = atoi(pr.second.c_str()); 172 | if (pr.first == "charHiddenSize") 173 | charHiddenSize = atoi(pr.second.c_str()); 174 | if (pr.first == "charRNNHiddenSize") 175 | charRNNHiddenSize = atoi(pr.second.c_str()); 176 | if (pr.first == "charEmbFineTune") 177 | charEmbFineTune = (pr.second == "true") ? true : false; 178 | if (pr.first == "bicharEmbFineTune") 179 | bicharEmbFineTune = (pr.second == "true") ?
true : false; 180 | 181 | if (pr.first == "actionEmbSize") 182 | actionEmbSize = atoi(pr.second.c_str()); 183 | if (pr.first == "actionNgram") 184 | actionNgram = atoi(pr.second.c_str()); 185 | if (pr.first == "actionHiddenSize") 186 | actionHiddenSize = atoi(pr.second.c_str()); 187 | if (pr.first == "actionRNNHiddenSize") 188 | actionRNNHiddenSize = atoi(pr.second.c_str()); 189 | 190 | if (pr.first == "verboseIter") 191 | verboseIter = atoi(pr.second.c_str()); 192 | if (pr.first == "train") 193 | train = (pr.second == "true") ? true : false; 194 | if (pr.first == "saveIntermediate") 195 | saveIntermediate = (pr.second == "true") ? true : false; 196 | if (pr.first == "maxInstance") 197 | maxInstance = atoi(pr.second.c_str()); 198 | if (pr.first == "testFile") 199 | testFiles.push_back(pr.second); 200 | if (pr.first == "outBest") 201 | outBest = pr.second; 202 | 203 | } 204 | } 205 | 206 | void showOptions() { 207 | std::cout << "wordCutOff = " << wordCutOff << std::endl; 208 | std::cout << "featCutOff = " << featCutOff << std::endl; 209 | std::cout << "charCutOff = " << charCutOff << std::endl; 210 | std::cout << "bicharCutOff = " << bicharCutOff << std::endl; 211 | std::cout << "initRange = " << initRange << std::endl; 212 | std::cout << "maxIter = " << maxIter << std::endl; 213 | std::cout << "batchSize = " << batchSize << std::endl; 214 | std::cout << "adaEps = " << adaEps << std::endl; 215 | std::cout << "adaAlpha = " << adaAlpha << std::endl; 216 | std::cout << "regParameter = " << regParameter << std::endl; 217 | std::cout << "dropProb = " << dropProb << std::endl; 218 | std::cout << "delta = " << delta << std::endl; 219 | std::cout << "clip = " << clip << std::endl; 220 | std::cout << "oovRatio = " << oovRatio << std::endl; 221 | 222 | std::cout << "sepHiddenSize = " << sepHiddenSize << std::endl; 223 | std::cout << "appHiddenSize = " << appHiddenSize << std::endl; 224 | 225 | std::cout << "wordEmbSize = " << wordEmbSize << std::endl; 226 | std::cout <<
"lengthEmbSize = " << lengthEmbSize << std::endl; 227 | std::cout << "wordNgram = " << wordNgram << std::endl; 228 | std::cout << "wordHiddenSize = " << wordHiddenSize << std::endl; 229 | std::cout << "wordRNNHiddenSize = " << wordRNNHiddenSize << std::endl; 230 | std::cout << "wordEmbFineTune = " << wordEmbFineTune << std::endl; 231 | 232 | std::cout << "charEmbSize = " << charEmbSize << std::endl; 233 | std::cout << "bicharEmbSize = " << bicharEmbSize << std::endl; 234 | std::cout << "charcontext = " << charcontext << std::endl; 235 | std::cout << "charHiddenSize = " << charHiddenSize << std::endl; 236 | std::cout << "charRNNHiddenSize = " << charRNNHiddenSize << std::endl; 237 | std::cout << "charEmbFineTune = " << charEmbFineTune << std::endl; 238 | std::cout << "bicharEmbFineTune = " << bicharEmbFineTune << std::endl; 239 | 240 | std::cout << "actionEmbSize = " << actionEmbSize << std::endl; 241 | std::cout << "actionNgram = " << actionNgram << std::endl; 242 | std::cout << "actionHiddenSize = " << actionHiddenSize << std::endl; 243 | std::cout << "actionRNNHiddenSize = " << actionRNNHiddenSize << std::endl; 244 | 245 | std::cout << "verboseIter = " << verboseIter << std::endl; 246 | std::cout << "saveIntermediate = " << saveIntermediate << std::endl; // label matches the key accepted by setOptions() 247 | std::cout << "train = " << train << std::endl; 248 | std::cout << "maxInstance = " << maxInstance << std::endl; 249 | for (int idx = 0; idx < testFiles.size(); idx++) { 250 | std::cout << "testFile = " << testFiles[idx] << std::endl; 251 | } 252 | std::cout << "outBest = " << outBest << std::endl; 253 | } 254 | 255 | void load(const std::string& infile) { // reads non-empty lines of infile and applies them via setOptions() 256 | ifstream inf; 257 | inf.open(infile.c_str()); if (!inf.is_open()) { std::cerr << "Options::load() open file err: " << infile << std::endl; return; } // previously a missing file was silently ignored 258 | vector vecLine; 259 | while (1) { 260 | string strLine; 261 | if (!my_getline(inf, strLine)) { 262 | break; 263 | } 264 | if (strLine.empty()) 265 | continue; 266 | vecLine.push_back(strLine); 267 | } 268 | inf.close(); 269 | setOptions(vecLine); 270 | } 271 | }; 272 | 273 | #endif 274 | 275 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | NNTransitionSegmentor 2 | ====== 3 | NNTransitionSegmentor is a package for word segmentation using neural networks, based on the [LibN3L](https://github.com/SUTDNLP/LibN3L) package. 4 | The current version is a re-implementation of the segmentor in ZPar. 5 | For an introduction to the transition-based framework with beam-search decoding, please see our ACL 2014 tutorial: [Syntactic Processing Using Global Discriminative Learning and Beam-Search Decoding](http://ir.hit.edu.cn/~mszhang/yue&meishan&ting[T8].pdf) 6 | 7 | Performance 8 | ====== 9 | Taking the averaged perceptron as an example (CTB6.0; please refer to [LibN3L: A lightweight Package for Neural NLP](https://github.com/SUTDNLP/LibN3L/blob/master/description\(expect%20for%20lrec2016\).pdf) for details): 10 | both ZPar and this package obtain an F-measure of about 95.08%; 11 | the normal sparse model with max-margin training reaches an F-measure of 95.24%. 12 | 13 | Compile 14 | ====== 15 | * Download the [LibN3L](https://github.com/SUTDNLP/LibN3L) library and compile it. 16 | * Open [CMakeLists.txt](CMakeLists.txt) and change "../LibN3L/" to the directory of your [LibN3L](https://github.com/SUTDNLP/LibN3L) package. 17 | 18 | `cmake .` 19 | `make` 20 | 21 | Input data format 22 | ====== 23 | One sentence per line, with words separated by spaces. 24 | 25 | Notice 26 | ====== 27 | * One can remove the length and keyChar embeddings from my implementation to reproduce the results of my ACL paper, 28 | because these two kinds of embeddings have little influence on the final performance; however, later experiments showed that keeping them can make the results more stable. 29 | * I will make the code more readable in the future. If you are interested in this framework, please contact me without hesitation.
30 | 31 | -------------------------------------------------------------------------------- /basic/Action.h: -------------------------------------------------------------------------------- 1 | /* 2 | * CAction.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef BASIC_CAction_H_ 9 | #define BASIC_CAction_H_ 10 | 11 | 12 | 13 | /*=============================================================== 14 | * 15 | * scored actions 16 | * 17 | *==============================================================*/ 18 | // for segmentation, there are only threee valid operations 19 | class CAction { 20 | public: 21 | enum CODE {NO_ACTION=0, SEP=1, APP=2, FIN=3, IDLE=4}; 22 | unsigned long _code; 23 | 24 | public: 25 | CAction() : _code(0){ 26 | } 27 | 28 | CAction(int code) : _code(code){ 29 | } 30 | 31 | CAction(const CAction &ac) : _code(ac._code){ 32 | } 33 | 34 | public: 35 | inline void clear() { _code=0; } 36 | 37 | inline void set(int code){ 38 | _code = code; 39 | } 40 | 41 | inline bool isNone() const { return _code==NO_ACTION; } 42 | inline bool isSeparate() const { return _code==SEP; } 43 | inline bool isAppend() const { return _code==APP; } 44 | inline bool isFinish() const { return _code==FIN; } 45 | inline bool isIdle() const { return _code>=IDLE; } 46 | 47 | public: 48 | inline std::string str() const { 49 | if (isNone()) { return "NONE"; } 50 | if (isIdle()) { return "IDLE"; } 51 | if (isSeparate()) { return "SEP"; } 52 | if (isAppend()) { return "APP"; } 53 | if (isFinish()) { return "FIN"; } 54 | return "IDLE"; 55 | } 56 | 57 | public: 58 | const unsigned long &code() const {return _code;} 59 | const unsigned long &hash() const {return _code;} 60 | bool operator == (const CAction &a1) const { return _code == a1._code; } 61 | bool operator != (const CAction &a1) const { return _code != a1._code; } 62 | bool operator < (const CAction &a1) const { return _code < a1._code; } 63 | bool operator > (const CAction &a1) const { return _code > 
a1._code; } 64 | 65 | }; 66 | 67 | 68 | inline std::istream & operator >> (std::istream &is, CAction &action) { 69 | std::string tmp; 70 | is >> tmp; 71 | 72 | if (tmp=="NONE") { 73 | action.clear(); 74 | } 75 | else if(tmp=="IDLE"){ 76 | action._code = CAction::IDLE; 77 | } 78 | else if(tmp=="SEP"){ 79 | action._code = CAction::SEP; 80 | } 81 | else if(tmp=="APP"){ 82 | action._code = CAction::APP; 83 | } 84 | else if(tmp=="FIN"){ 85 | action._code = CAction::FIN; 86 | } 87 | 88 | return is; 89 | } 90 | 91 | 92 | inline std::ostream & operator << (std::ostream &os, const CAction &action) { 93 | os << action.str(); 94 | return os; 95 | } 96 | 97 | 98 | #endif /* BASIC_CAction_H_ */ 99 | -------------------------------------------------------------------------------- /basic/Instance.h: -------------------------------------------------------------------------------- 1 | #ifndef _JST_INSTANCE_ 2 | #define _JST_INSTANCE_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "N3L.h" 9 | #include "Metric.h" 10 | 11 | using namespace std; 12 | 13 | class Instance { 14 | public: 15 | Instance() { 16 | } 17 | ~Instance() { 18 | } 19 | 20 | int wordsize() const { 21 | return words.size(); 22 | } 23 | 24 | int charsize() const { 25 | return chars.size(); 26 | } 27 | 28 | void clear() { 29 | words.clear(); 30 | chars.clear(); 31 | } 32 | 33 | void allocate(int length, int charLength) { 34 | clear(); 35 | words.resize(length); 36 | chars.resize(charLength); 37 | } 38 | 39 | void copyValuesFrom(const Instance& anInstance) { 40 | allocate(anInstance.wordsize(), anInstance.charsize()); 41 | for (int i = 0; i < anInstance.wordsize(); i++) { 42 | words[i] = anInstance.words[i]; 43 | } 44 | for (int i = 0; i < anInstance.charsize(); i++) { 45 | chars[i] = anInstance.chars[i]; 46 | } 47 | } 48 | 49 | 50 | void evaluate(const vector& resulted_segs, Metric& eval) const { 51 | hash_set golds; 52 | getSegIndexes(words, golds); 53 | 54 | hash_set preds; 55 | 
getSegIndexes(resulted_segs, preds); 56 | 57 | hash_set::iterator iter; 58 | eval.overall_label_count += golds.size(); 59 | eval.predicated_label_count += preds.size(); 60 | for (iter = preds.begin(); iter != preds.end(); iter++) { 61 | if (golds.find(*iter) != golds.end()) { 62 | eval.correct_label_count++; 63 | } 64 | } 65 | 66 | } 67 | 68 | void getSegIndexes(const vector& segs, hash_set& segIndexes) const{ 69 | segIndexes.clear(); 70 | int idx = 0, idy = 0; 71 | string curWord = ""; 72 | int beginId = 0; 73 | while(idx < chars.size() && idy < segs.size()){ 74 | curWord = curWord + chars[idx]; 75 | if(curWord.length() == segs[idy].length()){ 76 | stringstream ss; 77 | ss << "[" << beginId << "," << idx << "]"; 78 | segIndexes.insert(ss.str()); 79 | idy++; 80 | beginId = idx+1; 81 | curWord = ""; 82 | } 83 | idx++; 84 | } 85 | 86 | if(idx != chars.size() || idy != segs.size()){ 87 | std::cout << "error segs, please check" << std::endl; 88 | } 89 | } 90 | 91 | public: 92 | vector words; 93 | vector chars; 94 | }; 95 | 96 | #endif 97 | 98 | -------------------------------------------------------------------------------- /basic/InstanceReader.h: -------------------------------------------------------------------------------- 1 | #ifndef _CONLL_READER_ 2 | #define _CONLL_READER_ 3 | 4 | #include "Reader.h" 5 | #include "N3L.h" 6 | #include "Utf.h" 7 | #include 8 | 9 | using namespace std; 10 | /* 11 | this class reads conll-format data (10 columns, no srl-info) 12 | */ 13 | class InstanceReader: public Reader { 14 | public: 15 | InstanceReader() { 16 | } 17 | ~InstanceReader() { 18 | } 19 | 20 | Instance *getNext() { 21 | m_instance.clear(); 22 | string strLine; 23 | while (1) { 24 | if (!my_getline(m_inf, strLine)) { 25 | break; 26 | } 27 | if (!strLine.empty()) 28 | break; 29 | } 30 | 31 | vector wordInfo; 32 | split_bychar(strLine, wordInfo, ' '); 33 | 34 | string sentence = ""; 35 | for (int i = 0; i < wordInfo.size(); ++i) { 36 | sentence = sentence + 
wordInfo[i]; 37 | } 38 | 39 | vector charInfo; 40 | getCharactersFromUTF8String(sentence, charInfo); 41 | 42 | m_instance.allocate(wordInfo.size(), charInfo.size()); 43 | for (int i = 0; i < wordInfo.size(); ++i) { 44 | m_instance.words[i] = wordInfo[i]; 45 | } 46 | for (int i = 0; i < charInfo.size(); ++i) { 47 | m_instance.chars[i] = charInfo[i]; 48 | } 49 | 50 | return &m_instance; 51 | } 52 | }; 53 | 54 | #endif 55 | 56 | -------------------------------------------------------------------------------- /basic/InstanceWriter.h: -------------------------------------------------------------------------------- 1 | #ifndef _CONLL_WRITER_ 2 | #define _CONLL_WRITER_ 3 | 4 | #include "Writer.h" 5 | #include 6 | 7 | using namespace std; 8 | 9 | class InstanceWriter : public Writer 10 | { 11 | public: 12 | InstanceWriter(){} 13 | ~InstanceWriter(){} 14 | int write(const Instance *pInstance) 15 | { 16 | if (!m_outf.is_open()) return -1; 17 | 18 | for (int i = 0; i < pInstance->wordsize(); ++i) { 19 | m_outf << pInstance->words[i] << " "; 20 | } 21 | m_outf << endl; 22 | return 0; 23 | } 24 | 25 | int write(const vector &curWords) 26 | { 27 | if (!m_outf.is_open()) return -1; 28 | for (int i = 0; i < curWords.size(); ++i) { 29 | m_outf << curWords[i] << " "; 30 | } 31 | m_outf << endl; 32 | return 0; 33 | } 34 | }; 35 | 36 | #endif 37 | 38 | -------------------------------------------------------------------------------- /basic/Pipe.h: -------------------------------------------------------------------------------- 1 | #ifndef _JST_PIPE_ 2 | #define _JST_PIPE_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "Instance.h" 11 | #include "InstanceReader.h" 12 | #include "InstanceWriter.h" 13 | #include 14 | 15 | using namespace std; 16 | 17 | //#define MAX_BUFFER_SIZE 256 18 | 19 | class Pipe { 20 | public: 21 | Pipe() { 22 | m_jstReader = new InstanceReader(); 23 | m_jstWriter = new InstanceWriter(); 24 | } 25 | 26 | ~Pipe(void) 
{ 27 | if (m_jstReader) 28 | delete m_jstReader; 29 | if (m_jstWriter) 30 | delete m_jstWriter; 31 | } 32 | 33 | int initInputFile(const char *filename) { 34 | if (0 != m_jstReader->startReading(filename)) 35 | return -1; 36 | return 0; 37 | } 38 | 39 | void uninitInputFile() { 40 | if (m_jstWriter) 41 | m_jstReader->finishReading(); 42 | } 43 | 44 | int initOutputFile(const char *filename) { 45 | if (0 != m_jstWriter->startWriting(filename)) 46 | return -1; 47 | return 0; 48 | } 49 | 50 | void uninitOutputFile() { 51 | if (m_jstWriter) 52 | m_jstWriter->finishWriting(); 53 | } 54 | 55 | int outputAllInstances(const string& m_strOutFile, const vector >& vecInstances) { 56 | 57 | initOutputFile(m_strOutFile.c_str()); 58 | static int instNum; 59 | instNum = vecInstances.size(); 60 | for (int idx = 0; idx < instNum; idx++) { 61 | if (0 != m_jstWriter->write(vecInstances[idx])) 62 | return -1; 63 | } 64 | 65 | uninitOutputFile(); 66 | return 0; 67 | } 68 | 69 | int outputSingleInstance(const Instance& inst) { 70 | 71 | if (0 != m_jstWriter->write(&inst)) 72 | return -1; 73 | return 0; 74 | } 75 | 76 | Instance* nextInstance() { 77 | Instance *pInstance = m_jstReader->getNext(); 78 | if (!pInstance || pInstance->words.empty()) 79 | return 0; 80 | 81 | return pInstance; 82 | } 83 | 84 | void readInstances(const string& m_strInFile, vector& vecInstances, int max_sentence_size, int maxInstance = -1) { 85 | vecInstances.clear(); 86 | initInputFile(m_strInFile.c_str()); 87 | 88 | Instance *pInstance = nextInstance(); 89 | int numInstance = 0; 90 | 91 | while (pInstance) { 92 | 93 | if (pInstance->charsize() < max_sentence_size) { 94 | Instance trainInstance; 95 | trainInstance.copyValuesFrom(*pInstance); 96 | vecInstances.push_back(trainInstance); 97 | numInstance++; 98 | 99 | if (numInstance == maxInstance) { 100 | break; 101 | } 102 | } 103 | 104 | pInstance = nextInstance(); 105 | 106 | } 107 | 108 | uninitInputFile(); 109 | 110 | cout << endl; 111 | cout << "instance 
num: " << numInstance << endl; 112 | } 113 | 114 | protected: 115 | Reader *m_jstReader; 116 | Writer *m_jstWriter; 117 | 118 | }; 119 | 120 | #endif 121 | -------------------------------------------------------------------------------- /basic/Reader.h: -------------------------------------------------------------------------------- 1 | #ifndef _JST_READER_ 2 | #define _JST_READER_ 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | using namespace std; 9 | 10 | #include "Instance.h" 11 | 12 | class Reader 13 | { 14 | public: 15 | Reader() 16 | { 17 | } 18 | 19 | virtual ~Reader() 20 | { 21 | if (m_inf.is_open()) m_inf.close(); 22 | } 23 | int startReading(const char *filename) { 24 | if (m_inf.is_open()) { 25 | m_inf.close(); 26 | m_inf.clear(); 27 | } 28 | m_inf.open(filename); 29 | 30 | if (!m_inf.is_open()) { 31 | cout << "Reader::startReading() open file err: " << filename << endl; 32 | return -1; 33 | } 34 | 35 | return 0; 36 | } 37 | 38 | void finishReading() { 39 | if (m_inf.is_open()) { 40 | m_inf.close(); 41 | m_inf.clear(); 42 | } 43 | } 44 | virtual Instance *getNext() = 0; 45 | protected: 46 | ifstream m_inf; 47 | 48 | int m_numInstance; 49 | 50 | Instance m_instance; 51 | }; 52 | 53 | #endif 54 | 55 | -------------------------------------------------------------------------------- /basic/SegLookupTable.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SegLookupTable.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_SegLookupTable_H_ 9 | #define SRC_SegLookupTable_H_ 10 | #include "tensor.h" 11 | #include "Utiltensor.h" 12 | #include "MyLib.h" 13 | 14 | using namespace mshadow; 15 | using namespace mshadow::expr; 16 | using namespace mshadow::utils; 17 | 18 | // Weight updating process implemented without theory support, 19 | // but recently find an EMNLP 2015 paper "An Empirical Analysis of Optimization for Max-Margin NLP" 20 | // In all my papers that use adagrad 
for sparse features, I use it for parameter updating. 21 | 22 | template 23 | class SegLookupTable { 24 | 25 | public: 26 | 27 | hash_set _indexers; 28 | 29 | Tensor _E; 30 | Tensor _gradE; 31 | Tensor _eg2E; 32 | 33 | Tensor _ftE; 34 | 35 | bool _bFineTune; 36 | int _nDim; 37 | int _nVSize; 38 | 39 | int _max_update; 40 | NRVec _last_update; 41 | 42 | NRVec _freq; 43 | 44 | 45 | public: 46 | 47 | SegLookupTable() { 48 | _indexers.clear(); 49 | } 50 | 51 | 52 | inline void initial(const NRMat& wordEmb) { 53 | _nVSize = wordEmb.nrows(); 54 | _nDim = wordEmb.ncols(); 55 | 56 | _E = NewTensor(Shape2(_nVSize, _nDim), d_zero); 57 | _gradE = NewTensor(Shape2(_nVSize, _nDim), d_zero); 58 | _eg2E = NewTensor(Shape2(_nVSize, _nDim), d_zero); 59 | _ftE = NewTensor(Shape2(_nVSize, _nDim), d_one); 60 | assign(_E, wordEmb); 61 | for (int idx = 0; idx < _nVSize; idx++) { 62 | norm2one(_E, idx); 63 | } 64 | 65 | _bFineTune = true; 66 | 67 | _max_update = 0; 68 | _last_update.resize(_nVSize); 69 | _last_update = 0; 70 | 71 | _freq.resize(_nVSize); 72 | _freq = -1; 73 | } 74 | 75 | inline void setEmbFineTune(bool bFineTune) { 76 | _bFineTune = bFineTune; 77 | } 78 | 79 | inline void setFrequency(hash_map wordfreq) { 80 | static hash_map::iterator action_iter; 81 | for (action_iter = wordfreq.begin(); action_iter != wordfreq.end(); action_iter++) { 82 | _freq[action_iter->first] = action_iter->second; 83 | } 84 | } 85 | 86 | inline int getFrequency(int id){ 87 | return _freq[id]; 88 | } 89 | 90 | inline void release() { 91 | FreeSpace(&_E); 92 | FreeSpace(&_gradE); 93 | FreeSpace(&_eg2E); 94 | FreeSpace(&_ftE); 95 | _indexers.clear(); 96 | } 97 | 98 | virtual ~SegLookupTable() { 99 | // TODO Auto-generated destructor stub 100 | } 101 | 102 | inline dtype squarenormAll() { 103 | dtype result = 0; 104 | static hash_set::iterator it; 105 | for (int idx = 0; idx < _nDim; idx++) { 106 | for (it = _indexers.begin(); it != _indexers.end(); ++it) { 107 | result += _gradE[*it][idx] * 
_gradE[*it][idx]; 108 | } 109 | } 110 | 111 | 112 | return result; 113 | } 114 | 115 | inline void scaleGrad(dtype scale) { 116 | static hash_set::iterator it; 117 | for (int idx = 0; idx < _nDim; idx++) { 118 | for (it = _indexers.begin(); it != _indexers.end(); ++it) { 119 | _gradE[*it][idx] = _gradE[*it][idx] * scale; 120 | } 121 | } 122 | } 123 | 124 | inline bool bEmbFineTune() 125 | { 126 | return _bFineTune; 127 | } 128 | 129 | public: 130 | /* (1) unk is -1 when training 131 | * (2) if unk >= 0, then must be in test phase, if last_update equals zero, 132 | * denoting that never be trained, thus we regards it as unknown. 133 | */ 134 | void GetEmb(int id, Tensor y, int unk = -1) { 135 | updateSparseWeight(id); 136 | assert(y.size(0) == 1); 137 | y = 0.0; 138 | if(unk < 0 || _last_update[id] > 0){ 139 | y[0] += _E[id]; 140 | } 141 | else{ 142 | y[0] += _E[unk]; 143 | } 144 | } 145 | 146 | // loss is stopped at this layer, since the input is one-hold alike 147 | void EmbLoss(int id, Tensor ly) { 148 | if(!_bFineTune) return; 149 | //_gradE 150 | assert(ly.size(0) == 1); 151 | _gradE[id] += ly[0]; 152 | _indexers.insert(id); 153 | 154 | } 155 | 156 | 157 | void randomprint(int num) { 158 | static int _nVSize, _nDim; 159 | _nVSize = _E.size(0); 160 | _nDim = _E.size(1); 161 | int count = 0; 162 | while (count < num) { 163 | int idx = rand() % _nVSize; 164 | int idy = rand() % _nDim; 165 | 166 | std::cout << "_E[" << idx << "," << idy << "]=" << _E[idx][idy] << " "; 167 | 168 | count++; 169 | } 170 | 171 | std::cout << std::endl; 172 | } 173 | 174 | void updateAdaGrad(dtype regularizationWeight, dtype adaAlpha, dtype adaEps) { 175 | 176 | if(!_bFineTune) return; 177 | static hash_set::iterator it; 178 | _max_update++; 179 | 180 | Tensor sqrt_eg2E = NewTensor(Shape1(_E.size(1)), d_zero); 181 | 182 | for (it = _indexers.begin(); it != _indexers.end(); ++it) { 183 | int index = *it; 184 | _eg2E[index] = _eg2E[index] + _gradE[index] * _gradE[index]; 185 | sqrt_eg2E = 
F(_eg2E[index] + adaEps); 186 | _E[index] = (_E[index] * sqrt_eg2E - _gradE[index] * adaAlpha) / (adaAlpha * regularizationWeight + sqrt_eg2E); 187 | _ftE[index] = sqrt_eg2E / (adaAlpha * regularizationWeight + sqrt_eg2E); 188 | } 189 | 190 | FreeSpace(&sqrt_eg2E); 191 | 192 | clearGrad(); 193 | } 194 | 195 | void clearGrad() { 196 | static hash_set::iterator it; 197 | 198 | for (it = _indexers.begin(); it != _indexers.end(); ++it) { 199 | int index = *it; 200 | _gradE[index] = 0.0; 201 | } 202 | 203 | _indexers.clear(); 204 | 205 | } 206 | 207 | void updateSparseWeight(int wordId) { 208 | if(!_bFineTune) return; 209 | if (_last_update[wordId] < _max_update) { 210 | int times = _max_update - _last_update[wordId]; 211 | _E[wordId] = _E[wordId] * F(times * F(_ftE[wordId])); 212 | _last_update[wordId] = _max_update; 213 | } 214 | } 215 | 216 | void writeModel(LStream &outf) { 217 | SaveBinary(outf, _E); 218 | SaveBinary(outf, _gradE); 219 | SaveBinary(outf, _eg2E); 220 | SaveBinary(outf, _ftE); 221 | 222 | WriteBinary(outf, _bFineTune); 223 | WriteBinary(outf, _nDim); 224 | WriteBinary(outf, _nVSize); 225 | WriteBinary(outf, _max_update); 226 | WriteVector(outf, _last_update); 227 | } 228 | void loadModel(LStream &inf) { 229 | LoadBinary(inf, &_E, false); 230 | LoadBinary(inf, &_gradE, false); 231 | LoadBinary(inf, &_eg2E, false); 232 | LoadBinary(inf, &_ftE, false); 233 | 234 | ReadBinary(inf, _bFineTune); 235 | ReadBinary(inf, _nDim); 236 | ReadBinary(inf, _nVSize); 237 | ReadBinary(inf, _max_update); 238 | 239 | ReadVector(inf, _last_update); 240 | } 241 | 242 | }; 243 | 244 | #endif /* SRC_SegLookupTable_H_ */ 245 | -------------------------------------------------------------------------------- /basic/Utf.h: -------------------------------------------------------------------------------- 1 | // Copyright (C) University of Oxford 2010 2 | /**************************************************************** 3 | * * 4 | * utf.h - the utilities for unicode characters. 
* 5 | * * 6 | * Author: Yue Zhang * 7 | * * 8 | * Computing Laboratory, Oxford. 2007.6 * 9 | * * 10 | ****************************************************************/ 11 | 12 | #ifndef _UTILITY_UTF_H 13 | #define _UTILITY_UTF_H 14 | 15 | #include 16 | #include 17 | 18 | /*=============================================================== 19 | * 20 | * Unicode std::string and character utils 21 | * 22 | *==============================================================*/ 23 | 24 | /*--------------------------------------------------------------- 25 | * 26 | * getUTF8StringLength - get how many characters are in a UTF8 std::string 27 | * 28 | *--------------------------------------------------------------*/ 29 | 30 | inline 31 | unsigned long int getUTF8StringLength(const std::string &s) { 32 | unsigned long int retval = 0; 33 | unsigned long int idx = 0; 34 | while (idx < s.length()) { 35 | if ((s[idx] & 0x80) == 0) { 36 | ++idx; 37 | ++retval; 38 | } else if ((s[idx] & 0xE0) == 0xC0) { 39 | idx += 2; 40 | ++retval; 41 | } else if ((s[idx] & 0xF0) == 0xE0) { 42 | idx += 3; 43 | ++retval; 44 | } else { 45 | std::cerr << "Warning: " << "in utf.h getUTF8StringLength: std::string '" << s << "' not encoded in unicode utf-8" << std::endl; 46 | return retval; 47 | } 48 | } 49 | if (idx != s.length()) { 50 | std::cerr << "Warning: " << "in utf.h getUTF8StringLength: std::string '" << s << "' not encoded in unicode utf-8" << std::endl; 51 | return retval; 52 | } 53 | return retval; 54 | } 55 | 56 | /*---------------------------------------------------------------- 57 | * 58 | * getCharactersFromUTF8String - get the characters from 59 | * utf std::string. The characters from 60 | * this std::string are appended 61 | * to a given sentence. 
62 | * 63 | *----------------------------------------------------------------*/ 64 | 65 | inline int getCharactersFromUTF8String(const std::string &s, std::vector& sentence) { 66 | sentence.clear(); 67 | unsigned long int idx = 0; 68 | unsigned long int len = 0; 69 | while (idx < s.length()) { 70 | if ((s[idx] & 0x80) == 0) { 71 | sentence.push_back(s.substr(idx, 1)); 72 | ++len; 73 | ++idx; 74 | } else if ((s[idx] & 0xE0) == 0xC0) { 75 | sentence.push_back(s.substr(idx, 2)); 76 | ++len; 77 | idx += 2; 78 | } else if ((s[idx] & 0xF0) == 0xE0) { 79 | sentence.push_back(s.substr(idx, 3)); 80 | ++len; 81 | idx += 3; 82 | } else { 83 | std::cerr << "Warning: " << "in utf.h getCharactersFromUTF8String: std::string '" << s << "' not encoded in unicode utf-8" << std::endl; 84 | sentence.push_back("?"); 85 | ++len; 86 | ++idx; 87 | } 88 | } 89 | if (idx != s.length()) { 90 | std::cerr << "Warning: " << "in utf.h getCharactersFromUTF8String: std::string '" << s << "' not encoded in utf-8" << std::endl; 91 | return len; 92 | } 93 | 94 | return len; 95 | } 96 | 97 | /*---------------------------------------------------------------- 98 | * 99 | * getFirstCharFromUTF8String - get the first character from 100 | * utf std::string. 
//----------------------------------------------------------------*/

// Returns the first UTF-8 character of s (its 1- to 4-byte leading sequence),
// "" for an empty string, or "?" (with a warning on std::cerr) when s does
// not start with a valid lead byte. Asserts that the whole first sequence is
// present in s.
inline std::string getFirstCharFromUTF8String(const std::string &s) {
  if (s.empty())
    return "";
  const unsigned char lead = static_cast<unsigned char>(s[0]);
  std::string::size_type width = 0;
  if ((lead & 0x80) == 0) {
    width = 1;
  } else if ((lead & 0xE0) == 0xC0) {
    width = 2;
  } else if ((lead & 0xF0) == 0xE0) {
    width = 3;
  } else if ((lead & 0xF8) == 0xF0) {
    width = 4;  // supplementary-plane character, e.g. emoji
  } else {
    std::cerr << "Warning: " << "in utf.h getFirstCharFromUTF8String: std::string '" << s << "' not encoded in unicode utf-8" << std::endl;
    return "?";
  }
  assert(s.length() >= width);
  return s.substr(0, width);
}

/*----------------------------------------------------------------
 *
 * getLastCharFromUTF8String - get the last character from
 *                             utf std::string.
 *
 * Walks s from the front one UTF-8 sequence (1 to 4 bytes) at a
 * time and returns the last one. Returns "" for an empty string,
 * and "?" (with a warning on std::cerr) when an invalid lead byte
 * is met or the final sequence is truncated.
 *
 *----------------------------------------------------------------*/
inline std::string getLastCharFromUTF8String(const std::string &s) {
  if (s.empty())
    return "";
  std::string::size_type idx = 0;
  std::string retval;
  while (idx < s.length()) {
    const unsigned char lead = static_cast<unsigned char>(s[idx]);
    std::string::size_type width = 0;
    if ((lead & 0x80) == 0) {
      width = 1;
    } else if ((lead & 0xE0) == 0xC0) {
      width = 2;
    } else if ((lead & 0xF0) == 0xE0) {
      width = 3;
    } else if ((lead & 0xF8) == 0xF0) {
      width = 4;
    } else {
      std::cerr << "Warning: " << "in utf.h getLastCharFromUTF8String: std::string '" << s << "' not encoded in unicode utf-8" << std::endl;
      return "?";
    }
    retval = s.substr(idx, width);
    idx += width;
  }
  // The last sequence ran past the end of the string.
  if (idx != s.length()) {
    std::cerr << "Warning: " << "in utf.h getLastCharFromUTF8String: std::string '" << s << "' not encoded in unicode utf-8" << std::endl;
    return "?";
  }
  return retval;
}

//----------------------------------------------------------------
//
// isOneUTF8Character - whether a std::string is one utf8 character
//
*----------------------------------------------------------------*/ 160 | /* NOTE(review): isOneUTF8Character returns true only when s is exactly one 1-, 2- or 3-byte UTF-8 sequence. 4-byte sequences are rejected by the size() > 3 check, and when s[0] matches none of the three lead-byte tests (a continuation byte 0x80-0xBF or a byte 0xF0-0xFF) control reaches the closing brace without a return statement, which is undefined behavior for a bool function; it needs a final return false. */ 161 | inline bool isOneUTF8Character(const std::string &s) { 162 | if (s == "") 163 | return false; // is no utf character 164 | if (s.size() > 3) 165 | return false; // is more than one utf character 166 | if ((s[0] & 0x80) == 0) { 167 | return s.size() == 1; 168 | } else if ((s[0] & 0xE0) == 0xC0) { 169 | return s.size() == 2; 170 | } else if ((s[0] & 0xF0) == 0xE0) { 171 | return s.size() == 3; 172 | } 173 | } 174 | /* getUTF8CharType: strings of at most 2 bytes are typed "digit" or "eng" when they occur as a substring of the digit or eng list, otherwise "default" (so "12" is "digit" but "21" is "default"); longer strings are "unitype". NOTE(review): find() is compared with -1, which only works because std::string::npos converts from -1; the empty string is found at position 0 and is therefore typed "digit". */ 175 | inline std::string getUTF8CharType(const std::string &s) { 176 | std::string digit = "0123456789"; 177 | std::string eng = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; 178 | std::string defaultType = "default"; 179 | if (s.length() <= 2) { 180 | if (digit.find(s) != -1) { 181 | defaultType = "digit"; 182 | } else if (eng.find(s) != -1) { 183 | defaultType = "eng"; 184 | } 185 | } else { 186 | defaultType = "unitype"; 187 | } 188 | 189 | return defaultType; 190 | } 191 | /* wordtype: builds a type signature with one letter per character of s: "u" for a multi-byte UTF-8 character, "d" for a single-byte digit, "e" or "E" for a single-byte lower- or upper-case letter (per isalpha/islower); any other single-byte character adds nothing. */ 192 | inline std::string wordtype(const std::string &s) { 193 | std::vector chars; 194 | getCharactersFromUTF8String(s, chars); 195 | std::string type = ""; 196 | for (int i = 0; i < chars.size(); i++) { 197 | if (chars[i].length() > 1) { 198 | type = type + "u"; 199 | } else if (isdigit(chars[i][0])) { 200 | type = type + "d"; 201 | } else if (isalpha(chars[i][0])) { 202 | if (islower(chars[i][0])) 203 | type = type + "e"; 204 | else 205 | type = type + "E"; 206 | } 207 | } 208 | return type; 209 | } 210 | /* normalize_to_lowerwithdigit: rebuilds s character by character, replacing each single-byte digit with '0' and lower-casing single-byte letters that islower() rejects (by adding 'a'-'A'); multi-byte characters and all other characters are copied unchanged. Bytes that getCharactersFromUTF8String cannot decode come back as "?". */ 211 | inline std::string normalize_to_lowerwithdigit(const std::string& s) 212 | { 213 | std::vector chars; 214 | getCharactersFromUTF8String(s, chars); 215 | std::string lowcase = ""; 216 | for (int i = 0; i < chars.size(); i++) { 217 | if (chars[i].length() > 1) { 218 | lowcase = lowcase + chars[i]; 219 | } else if (isdigit(chars[i][0])) { 220 | lowcase = lowcase + "0"; 221 | } else if (isalpha(chars[i][0])) { 222 | if (islower(chars[i][0])) 223 | { 224 | lowcase = lowcase + chars[i][0];
225 | } 226 | else 227 | { 228 | char temp = chars[i][0] + 'a'-'A'; 229 | lowcase = lowcase + temp; 230 | } 231 | } 232 | else 233 | { 234 | lowcase = lowcase + chars[i]; 235 | } 236 | } 237 | return lowcase; 238 | } 239 | 240 | #endif 241 | -------------------------------------------------------------------------------- /basic/Writer.h: -------------------------------------------------------------------------------- 1 | #ifndef _JST_WRITER_ 2 | #define _JST_WRITER_ 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | using namespace std; 9 | 10 | #include "Instance.h" 11 | /* Writer: abstract base for output writers. startWriting() opens m_outf and returns 0, or prints an error and returns -1 when the file cannot be opened; finishWriting() and the destructor close the stream; subclasses implement both pure virtual write() overloads. NOTE(review): the error message spells the class name as "Writerr". */ 12 | class Writer 13 | { 14 | public: 15 | Writer() 16 | { 17 | } 18 | virtual ~Writer() 19 | { 20 | if (m_outf.is_open()) m_outf.close(); 21 | } 22 | 23 | inline int startWriting(const char *filename) { 24 | m_outf.open(filename); 25 | if (!m_outf) { 26 | cout << "Writerr::startWriting() open file err: " << filename << endl; 27 | return -1; 28 | } 29 | return 0; 30 | } 31 | 32 | inline void finishWriting() { 33 | m_outf.close(); 34 | } 35 | 36 | virtual int write(const Instance *pInstance) = 0; 37 | virtual int write(const vector &curWords) = 0; 38 | protected: 39 | ofstream m_outf; 40 | }; 41 | 42 | #endif 43 | 44 | -------------------------------------------------------------------------------- /cleanall.sh: -------------------------------------------------------------------------------- 1 | make clean 2 | rm -rf CMakeFiles cmake_install.cmake CMakeCache.txt Makefile 3 | rm -rf */CMakeFiles */cmake_install.cmake */Makefile 4 | -------------------------------------------------------------------------------- /feature/DenseFeature.h: -------------------------------------------------------------------------------- 1 | /* 2 | * DenseFeature.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mason 6 | */ 7 | 8 | #ifndef FEATURE_DENSEFEATURE_H_ 9 | #define FEATURE_DENSEFEATURE_H_ 10 | 11 | #include "N3L.h" 12 | 13 | template 14 | class DenseFeature { 15 | public: 16 | //all state inter dependent
features 17 | vector > _wordIds, _keyCharIds, _lengthIds, _actionIds; 18 | vector > _wordPrime, _wordPrimeLoss; 19 | vector > _allwordPrime; 20 | vector > _keyCharPrime, _keyCharPrimeLoss; 21 | vector > _lengthPrime, _lengthPrimeLoss; 22 | vector > _wordRep, _wordRepLoss; 23 | vector > _allwordRep, _allwordRepLoss; 24 | vector > _keyCharRep, _keyCharRepLoss; 25 | vector > _lengthRep, _lengthRepLoss; 26 | vector > _wordUnitRep, _wordUnitRepLoss, _wordUnitRepMask; 27 | vector > _wordHidden, _wordHiddenLoss; 28 | 29 | vector > > _wordRNNHiddenBuf; 30 | vector > _wordRNNHidden, _wordRNNHiddenLoss; //lstm 31 | 32 | 33 | vector > _actionPrime, _actionPrimeLoss; 34 | vector > _actionRep, _actionRepLoss, _actionRepMask; 35 | vector > _actionHidden, _actionHiddenLoss; 36 | 37 | vector > > _actionRNNHiddenBuf; 38 | vector > _actionRNNHidden, _actionRNNHiddenLoss; //lstm 39 | 40 | int _steps; 41 | int _wordnum; 42 | int _buffer; 43 | 44 | public: 45 | DenseFeature() { 46 | _steps = 0; 47 | _wordnum = 0; 48 | _buffer = 0; 49 | } 50 | 51 | ~DenseFeature() { 52 | clear(); 53 | } 54 | /* init(): releases any previous allocation via clear() and stores the sizes. Only when wordnum > 0 does it allocate the per-word id vectors and NewTensor(..., d_zero) tensors; only when steps > 0 does it allocate the per-action-step ones. When buffer > 0 it also allocates `buffer` extra copies of the word and action RNN hidden tensors. */ 55 | public: 56 | inline void init(int wordnum, int steps, int wordDim, int allwordDim, int charDim, int lengthDim, int wordNgram, int wordUnitDim, int wordHiddenDim, int wordRNNDim, 57 | int actionDim, int actionNgram, int actionHiddenDim, int actionRNNDim, int buffer = 0) { 58 | 59 | clear(); 60 | _steps = steps; 61 | _wordnum = wordnum; 62 | _buffer = buffer; 63 | 64 | if (wordnum > 0) { 65 | _wordIds.resize(wordnum); 66 | _keyCharIds.resize(wordnum); 67 | _lengthIds.resize(wordnum); 68 | 69 | _wordPrime.resize(wordnum); 70 | _allwordPrime.resize(wordnum); 71 | _wordRep.resize(wordnum); 72 | _allwordRep.resize(wordnum); 73 | _keyCharPrime.resize(wordnum); 74 | _keyCharRep.resize(wordnum); 75 | _lengthPrime.resize(wordnum); 76 | _lengthRep.resize(wordnum); 77 | _wordUnitRep.resize(wordnum); 78 | _wordHidden.resize(wordnum); 79 | if(_buffer > 0){ 80 | _wordRNNHiddenBuf.resize(_buffer); 81 | for(int
idk = 0; idk < _buffer; idk++){ 82 | _wordRNNHiddenBuf[idk].resize(wordnum); 83 | } 84 | } 85 | _wordRNNHidden.resize(wordnum); 86 | 87 | _wordPrimeLoss.resize(wordnum); 88 | _wordRepLoss.resize(wordnum); 89 | _allwordRepLoss.resize(wordnum); 90 | _keyCharPrimeLoss.resize(wordnum); 91 | _keyCharRepLoss.resize(wordnum); 92 | _lengthPrimeLoss.resize(wordnum); 93 | _lengthRepLoss.resize(wordnum); 94 | _wordUnitRepLoss.resize(wordnum); 95 | _wordUnitRepMask.resize(wordnum); 96 | _wordHiddenLoss.resize(wordnum); 97 | _wordRNNHiddenLoss.resize(wordnum); 98 | 99 | for (int idx = 0; idx < wordnum; idx++) { 100 | _wordIds[idx].resize(wordNgram); 101 | _keyCharIds[idx].resize(2 * wordNgram + 1); 102 | _lengthIds[idx].resize(wordNgram); 103 | _wordPrime[idx] = NewTensor(Shape3(wordNgram, 1, wordDim), d_zero); 104 | _allwordPrime[idx] = NewTensor(Shape3(wordNgram, 1, allwordDim), d_zero); 105 | _wordRep[idx] = NewTensor(Shape2(1, wordNgram * wordDim), d_zero); 106 | _allwordRep[idx] = NewTensor(Shape2(1, wordNgram * allwordDim), d_zero); 107 | _keyCharPrime[idx] = NewTensor(Shape3(2 * wordNgram + 1, 1, charDim), d_zero); 108 | _keyCharRep[idx] = NewTensor(Shape2(1, (2 * wordNgram + 1) * charDim), d_zero); 109 | _lengthPrime[idx] = NewTensor(Shape3(wordNgram, 1, lengthDim), d_zero); 110 | _lengthRep[idx] = NewTensor(Shape2(1, wordNgram * lengthDim), d_zero); 111 | _wordUnitRep[idx] = NewTensor(Shape2(1, wordUnitDim), d_zero); 112 | _wordHidden[idx] = NewTensor(Shape2(1, wordHiddenDim), d_zero); 113 | for(int idk = 0; idk < _buffer; idk++){ 114 | _wordRNNHiddenBuf[idk][idx] = NewTensor(Shape2(1, wordRNNDim), d_zero); 115 | } 116 | _wordRNNHidden[idx] = NewTensor(Shape2(1, wordRNNDim), d_zero); 117 | 118 | _wordPrimeLoss[idx] = NewTensor(Shape3(wordNgram, 1, wordDim), d_zero); 119 | _wordRepLoss[idx] = NewTensor(Shape2(1, wordNgram * wordDim), d_zero); 120 | _allwordRepLoss[idx] = NewTensor(Shape2(1, wordNgram * allwordDim), d_zero); 121 | _keyCharPrimeLoss[idx] =
NewTensor(Shape3(2 * wordNgram + 1, 1, charDim), d_zero); 122 | _keyCharRepLoss[idx] = NewTensor(Shape2(1, (2 * wordNgram + 1) * charDim), d_zero); 123 | _lengthPrimeLoss[idx] = NewTensor(Shape3(wordNgram, 1, lengthDim), d_zero); 124 | _lengthRepLoss[idx] = NewTensor(Shape2(1, wordNgram * lengthDim), d_zero); 125 | _wordUnitRepLoss[idx] = NewTensor(Shape2(1, wordUnitDim), d_zero); 126 | _wordUnitRepMask[idx] = NewTensor(Shape2(1, wordUnitDim), d_zero); 127 | _wordHiddenLoss[idx] = NewTensor(Shape2(1, wordHiddenDim), d_zero); 128 | _wordRNNHiddenLoss[idx] = NewTensor(Shape2(1, wordRNNDim), d_zero); 129 | } 130 | } 131 | 132 | if (steps > 0) { 133 | _actionIds.resize(steps); 134 | _actionPrime.resize(steps); 135 | _actionRep.resize(steps); 136 | _actionHidden.resize(steps); 137 | if(_buffer > 0){ 138 | _actionRNNHiddenBuf.resize(_buffer); 139 | for(int idk = 0; idk < _buffer; idk++){ 140 | _actionRNNHiddenBuf[idk].resize(steps); 141 | } 142 | } 143 | _actionRNNHidden.resize(steps); 144 | 145 | _actionPrimeLoss.resize(steps); 146 | _actionRepLoss.resize(steps); 147 | _actionRepMask.resize(steps); 148 | _actionHiddenLoss.resize(steps); 149 | _actionRNNHiddenLoss.resize(steps); 150 | for (int idx = 0; idx < steps; idx++) { 151 | _actionIds[idx].resize(actionNgram); 152 | _actionPrime[idx] = NewTensor(Shape3(actionNgram, 1, actionDim), d_zero); 153 | _actionRep[idx] = NewTensor(Shape2(1, actionNgram * actionDim), d_zero); 154 | _actionHidden[idx] = NewTensor(Shape2(1, actionHiddenDim), d_zero); 155 | for(int idk = 0; idk < _buffer; idk++){ 156 | _actionRNNHiddenBuf[idk][idx] = NewTensor(Shape2(1, actionRNNDim), d_zero); 157 | } 158 | _actionRNNHidden[idx] = NewTensor(Shape2(1, actionRNNDim), d_zero); 159 | 160 | _actionPrimeLoss[idx] = NewTensor(Shape3(actionNgram, 1, actionDim), d_zero); 161 | _actionRepLoss[idx] = NewTensor(Shape2(1, actionNgram * actionDim), d_zero); 162 | _actionRepMask[idx] = NewTensor(Shape2(1, actionNgram * actionDim), d_zero); 163 |
_actionHiddenLoss[idx] = NewTensor(Shape2(1, actionHiddenDim), d_zero); 164 | _actionRNNHiddenLoss[idx] = NewTensor(Shape2(1, actionRNNDim), d_zero); 165 | } 166 | } 167 | 168 | } 169 | /* clear(): frees every tensor allocated by init() and resets _wordnum, _steps and _buffer to 0; the destructor also calls it. NOTE(review): the loops that call _wordRNNHiddenBuf[idk].clear() and _actionRNNHiddenBuf[idk].clear() for idk < _buffer are not guarded by _wordnum > 0 or _steps > 0, yet init() only resizes those outer vectors inside its wordnum > 0 and steps > 0 branches. After e.g. init(0, steps, ..., buffer > 0) they index an empty vector (undefined behavior). */ 170 | inline void clear() { 171 | for (int idx = 0; idx < _wordnum; idx++) { 172 | _wordIds[idx].clear(); 173 | _keyCharIds[idx].clear(); 174 | _lengthIds[idx].clear(); 175 | FreeSpace(&(_wordPrime[idx])); 176 | FreeSpace(&(_allwordPrime[idx])); 177 | FreeSpace(&(_wordRep[idx])); 178 | FreeSpace(&(_allwordRep[idx])); 179 | FreeSpace(&(_keyCharPrime[idx])); 180 | FreeSpace(&(_keyCharRep[idx])); 181 | FreeSpace(&(_lengthPrime[idx])); 182 | FreeSpace(&(_lengthRep[idx])); 183 | FreeSpace(&(_wordUnitRep[idx])); 184 | FreeSpace(&(_wordHidden[idx])); 185 | for(int idk = 0; idk < _buffer; idk++){ 186 | FreeSpace(&(_wordRNNHiddenBuf[idk][idx])); 187 | } 188 | FreeSpace(&(_wordRNNHidden[idx])); 189 | 190 | FreeSpace(&(_wordPrimeLoss[idx])); 191 | FreeSpace(&(_wordRepLoss[idx])); 192 | FreeSpace(&(_allwordRepLoss[idx])); 193 | FreeSpace(&(_keyCharPrimeLoss[idx])); 194 | FreeSpace(&(_keyCharRepLoss[idx])); 195 | FreeSpace(&(_lengthPrimeLoss[idx])); 196 | FreeSpace(&(_lengthRepLoss[idx])); 197 | FreeSpace(&(_wordUnitRepLoss[idx])); 198 | FreeSpace(&(_wordUnitRepMask[idx])); 199 | FreeSpace(&(_wordHiddenLoss[idx])); 200 | FreeSpace(&(_wordRNNHiddenLoss[idx])); 201 | } 202 | _wordIds.clear(); 203 | _keyCharIds.clear(); 204 | _lengthIds.clear(); 205 | _wordPrime.clear(); 206 | _allwordPrime.clear(); 207 | _wordRep.clear(); 208 | _allwordRep.clear(); 209 | _keyCharPrime.clear(); 210 | _keyCharRep.clear(); 211 | _lengthPrime.clear(); 212 | _lengthRep.clear(); 213 | _wordUnitRep.clear(); 214 | _wordHidden.clear(); 215 | for(int idk = 0; idk < _buffer; idk++){ 216 | _wordRNNHiddenBuf[idk].clear(); 217 | } 218 | _wordRNNHiddenBuf.clear(); 219 | _wordRNNHidden.clear(); 220 | 221 | _wordPrimeLoss.clear(); 222 | _wordRepLoss.clear(); 223 | _allwordRepLoss.clear(); 224 | _keyCharPrimeLoss.clear(); 225 |
_keyCharRepLoss.clear(); 226 | _lengthPrimeLoss.clear(); 227 | _lengthRepLoss.clear(); 228 | _wordUnitRepLoss.clear(); 229 | _wordUnitRepMask.clear(); 230 | _wordHiddenLoss.clear(); 231 | _wordRNNHiddenLoss.clear(); 232 | 233 | for (int idx = 0; idx < _steps; idx++) { 234 | _actionIds[idx].clear(); 235 | FreeSpace(&(_actionPrime[idx])); 236 | FreeSpace(&(_actionRep[idx])); 237 | FreeSpace(&(_actionHidden[idx])); 238 | for(int idk = 0; idk < _buffer; idk++){ 239 | FreeSpace(&(_actionRNNHiddenBuf[idk][idx])); 240 | } 241 | FreeSpace(&(_actionRNNHidden[idx])); 242 | 243 | FreeSpace(&(_actionPrimeLoss[idx])); 244 | FreeSpace(&(_actionRepLoss[idx])); 245 | FreeSpace(&(_actionRepMask[idx])); 246 | FreeSpace(&(_actionHiddenLoss[idx])); 247 | FreeSpace(&(_actionRNNHiddenLoss[idx])); 248 | } 249 | _actionIds.clear(); 250 | _actionPrime.clear(); 251 | _actionRep.clear(); 252 | _actionHidden.clear(); 253 | for(int idk = 0; idk < _buffer; idk++){ 254 | _actionRNNHiddenBuf[idk].clear(); 255 | } 256 | _actionRNNHiddenBuf.clear(); 257 | _actionRNNHidden.clear(); 258 | 259 | _actionPrimeLoss.clear(); 260 | _actionRepLoss.clear(); 261 | _actionRepMask.clear(); 262 | _actionHiddenLoss.clear(); 263 | _actionRNNHiddenLoss.clear(); 264 | 265 | _wordnum = 0; 266 | _steps = 0; 267 | _buffer = 0; 268 | } 269 | 270 | }; 271 | 272 | #endif /* FEATURE_DENSEFEATURE_H_ */ 273 | -------------------------------------------------------------------------------- /feature/DenseFeatureChar.h: -------------------------------------------------------------------------------- 1 | /* 2 | * DenseFeatureChar.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mason 6 | */ 7 | 8 | #ifndef FEATURE_DENSEFEATURECHAR_H_ 9 | #define FEATURE_DENSEFEATURECHAR_H_ 10 | 11 | #include "N3L.h" 12 | 13 | template 14 | class DenseFeatureChar { 15 | public: 16 | //all state inter dependent features 17 | vector _charIds, _bicharIds; 18 | Tensor _charprime, _bicharprime; 19 | Tensor _charpre, _charpreMask; 20 | Tensor
_charInput, _charHidden; 21 | vector > _charLeftRNNHiddenBuf, _charRightRNNHiddenBuf; 22 | Tensor _charLeftRNNHidden, _charRightRNNHidden; 23 | Tensor _charRNNHiddenDummy; 24 | 25 | Tensor _charprime_Loss, _bicharprime_Loss, _charpre_Loss; 26 | Tensor _charInput_Loss, _charHidden_Loss; 27 | Tensor _charLeftRNNHidden_Loss, _charRightRNNHidden_Loss; 28 | Tensor _charRNNHiddenDummy_Loss; 29 | 30 | bool _bTrain; 31 | int _charnum; 32 | int _buffer; 33 | 34 | public: 35 | DenseFeatureChar() { 36 | _bTrain = false; 37 | _charnum = 0; 38 | _buffer = 0; 39 | } 40 | 41 | ~DenseFeatureChar() { 42 | clear(); 43 | } 44 | /* init(): releases any previous allocation via clear(). When charnum > 0 it then allocates the per-character char and bichar embedding tensors, the window input tensor, the hidden tensor, and the left and right RNN hidden tensors, plus `buffer` extra left and right RNN copies. The loss and mask tensors are allocated only when bTrain is true. */ 45 | public: 46 | inline void init(int charnum, int charDim, int bicharDim, int charcontext, int charHiddenDim, int charRNNHiddenDim, int buffer = 0, bool bTrain = false) { 47 | clear(); 48 | 49 | _charnum = charnum; 50 | _bTrain = bTrain; 51 | _buffer = buffer; 52 | 53 | if (_charnum > 0) { 54 | int charwindow = 2 * charcontext + 1; 55 | int charRepresentDim = (charDim + bicharDim) * charwindow; 56 | 57 | _charIds.resize(charnum); 58 | _bicharIds.resize(charnum); 59 | _charprime = NewTensor(Shape3(_charnum, 1, charDim), d_zero); 60 | _bicharprime = NewTensor(Shape3(_charnum, 1, bicharDim), d_zero); 61 | _charpre = NewTensor(Shape3(_charnum, 1, charDim + bicharDim), d_zero); 62 | _charInput = NewTensor(Shape3(_charnum, 1, charRepresentDim), d_zero); 63 | _charHidden = NewTensor(Shape3(_charnum, 1, charHiddenDim), d_zero); 64 | if (_buffer > 0) { 65 | _charLeftRNNHiddenBuf.resize(_buffer); 66 | _charRightRNNHiddenBuf.resize(_buffer); 67 | for (int idk = 0; idk < _buffer; idk++) { 68 | _charLeftRNNHiddenBuf[idk] = NewTensor(Shape3(_charnum, 1, charRNNHiddenDim), d_zero); 69 | _charRightRNNHiddenBuf[idk] = NewTensor(Shape3(_charnum, 1, charRNNHiddenDim), d_zero); 70 | } 71 | } 72 | _charLeftRNNHidden = NewTensor(Shape3(_charnum, 1, charRNNHiddenDim), d_zero); 73 | _charRightRNNHidden = NewTensor(Shape3(_charnum, 1, charRNNHiddenDim), d_zero); 74 |
_charRNNHiddenDummy = NewTensor(Shape2(1, charRNNHiddenDim), d_zero); 75 | 76 | if (_bTrain) { 77 | _charpreMask = NewTensor(Shape3(_charnum, 1, charDim + bicharDim), d_zero); 78 | _charprime_Loss = NewTensor(Shape3(_charnum, 1, charDim), d_zero); 79 | _bicharprime_Loss = NewTensor(Shape3(_charnum, 1, bicharDim), d_zero); 80 | _charpre_Loss = NewTensor(Shape3(_charnum, 1, charDim + bicharDim), d_zero); 81 | _charInput_Loss = NewTensor(Shape3(_charnum, 1, charRepresentDim), d_zero); 82 | _charHidden_Loss = NewTensor(Shape3(_charnum, 1, charHiddenDim), d_zero); 83 | _charLeftRNNHidden_Loss = NewTensor(Shape3(_charnum, 1, charRNNHiddenDim), d_zero); 84 | _charRightRNNHidden_Loss = NewTensor(Shape3(_charnum, 1, charRNNHiddenDim), d_zero); 85 | _charRNNHiddenDummy_Loss = NewTensor(Shape2(1, charRNNHiddenDim), d_zero); 86 | } 87 | } 88 | 89 | } 90 | /* clear(): frees exactly what init() allocated: nothing when _charnum is 0, and the loss and mask tensors only when _bTrain is set. It then resets _bTrain, _charnum and _buffer; the destructor also calls it. */ 91 | inline void clear() { 92 | if (_charnum > 0) { 93 | _charIds.clear(); 94 | _bicharIds.clear(); 95 | 96 | FreeSpace(&_charprime); 97 | FreeSpace(&_bicharprime); 98 | FreeSpace(&_charpre); 99 | FreeSpace(&_charInput); 100 | FreeSpace(&_charHidden); 101 | if (_buffer > 0) { 102 | for (int idk = 0; idk < _buffer; idk++) { 103 | FreeSpace(&(_charLeftRNNHiddenBuf[idk])); 104 | FreeSpace(&(_charRightRNNHiddenBuf[idk])); 105 | } 106 | _charLeftRNNHiddenBuf.clear(); 107 | _charRightRNNHiddenBuf.clear(); 108 | } 109 | FreeSpace(&_charLeftRNNHidden); 110 | FreeSpace(&_charRightRNNHidden); 111 | FreeSpace(&_charRNNHiddenDummy); 112 | 113 | if (_bTrain) { 114 | FreeSpace(&_charprime_Loss); 115 | FreeSpace(&_bicharprime_Loss); 116 | FreeSpace(&_charpreMask); 117 | FreeSpace(&_charpre_Loss); 118 | FreeSpace(&_charInput_Loss); 119 | FreeSpace(&_charHidden_Loss); 120 | FreeSpace(&_charLeftRNNHidden_Loss); 121 | FreeSpace(&_charRightRNNHidden_Loss); 122 | FreeSpace(&_charRNNHiddenDummy_Loss); 123 | } 124 | 125 | } 126 | 127 | _bTrain = false; 128 | _charnum = 0; 129 | _buffer = 0; 130 | } 131 | 132 | }; 133 | 134 | #endif /*
FEATURE_DENSEFEATURECHAR_H_ */ 135 | -------------------------------------------------------------------------------- /feature/DenseFeatureExtraction.h: -------------------------------------------------------------------------------- 1 | /* 2 | * FeatureExtraction.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef BASIC_FEATUREEXTRACTION_H_ 9 | #define BASIC_FEATUREEXTRACTION_H_ 10 | #include "N3L.h" 11 | #include "Action.h" 12 | #include "NeuralState.h" 13 | #include "Utf.h" 14 | 15 | template 16 | class FeatureExtraction { 17 | public: 18 | std::string nullkey; 19 | std::string rootdepkey; 20 | std::string unknownkey; 21 | std::string paddingtag; 22 | std::string seperateKey; 23 | 24 | public: 25 | Alphabet _featAlphabet; 26 | Alphabet _wordAlphabet; 27 | Alphabet _allwordAlphabet; 28 | Alphabet _charAlphabet; 29 | Alphabet _bicharAlphabet; 30 | Alphabet _actionAlphabet; 31 | 32 | public: 33 | bool _bStringFeat; // string-formated features or digit-formated features 34 | 35 | public: 36 | FeatureExtraction() { 37 | nullkey = "-null-"; 38 | unknownkey = "-unknown-"; 39 | paddingtag = "-padding-"; 40 | seperateKey = "#"; 41 | 42 | _bStringFeat = true; 43 | } 44 | 45 | FeatureExtraction(bool bCollecting) { 46 | nullkey = "-null-"; 47 | unknownkey = "-unknown-"; 48 | paddingtag = "-padding-"; 49 | seperateKey = "#"; 50 | 51 | _bStringFeat = bCollecting; 52 | } 53 | 54 | public: 55 | 56 | inline void setFeatureFormat(bool bStringFeat) { 57 | _bStringFeat = bStringFeat; 58 | } 59 | 60 | inline void setAlphaIncreasing(bool alphaIncreasing) { 61 | if (alphaIncreasing) { 62 | _featAlphabet.set_fixed_flag(false); 63 | _wordAlphabet.set_fixed_flag(false); 64 | _allwordAlphabet.set_fixed_flag(false); 65 | _charAlphabet.set_fixed_flag(false); 66 | _bicharAlphabet.set_fixed_flag(false); 67 | _actionAlphabet.set_fixed_flag(false); 68 | } else { 69 | _featAlphabet.set_fixed_flag(true); 70 | _wordAlphabet.set_fixed_flag(true); 71 | 
_allwordAlphabet.set_fixed_flag(true); 72 | _charAlphabet.set_fixed_flag(true); 73 | _bicharAlphabet.set_fixed_flag(true); 74 | _actionAlphabet.set_fixed_flag(true); 75 | } 76 | } 77 | 78 | inline void setFeatAlphaIncreasing(bool alphaIncreasing) { 79 | if (alphaIncreasing) { 80 | _featAlphabet.set_fixed_flag(false); 81 | } else { 82 | _featAlphabet.set_fixed_flag(true); 83 | } 84 | } 85 | 86 | inline int getCharAlphaId(const std::string & oneChar) { 87 | return _charAlphabet[oneChar]; 88 | } 89 | 90 | inline int getBiCharAlphaId(const std::string & twoChar) { 91 | return _bicharAlphabet[twoChar]; 92 | } 93 | 94 | void extractFeature(const CStateItem * curState, const CAction& nextAC, Feature& feat, int wordNgram = 0, int actionNgram = 0) { 95 | feat.clear(); 96 | feat.setFeatureFormat(_bStringFeat); 97 | if (nextAC._code == CAction::APP) { 98 | extractFeatureApp(curState, feat, wordNgram, actionNgram); 99 | } else if (nextAC._code == CAction::SEP) { 100 | extractFeatureSep(curState, feat, wordNgram, actionNgram); 101 | } else if (nextAC._code == CAction::FIN) { 102 | extractFeatureFinish(curState, feat, wordNgram, actionNgram); 103 | } else { 104 | 105 | } 106 | 107 | static int featId, unknownID; 108 | if (!_bStringFeat) { 109 | for (int idx = 0; idx < feat._strSparseFeat.size(); idx++) { 110 | featId = _featAlphabet[feat._strSparseFeat[idx]]; 111 | if (featId >= 0) 112 | feat._nSparseFeat.push_back(featId); 113 | } 114 | feat._strSparseFeat.clear(); 115 | 116 | if (wordNgram > 0) { 117 | feat._nWordFeat.resize(wordNgram); 118 | unknownID = _wordAlphabet[unknownkey]; 119 | for (int idx = 0; idx < wordNgram; idx++) { 120 | featId = idx < feat._strWordFeat.size() ? 
_wordAlphabet[normalize_to_lowerwithdigit(feat._strWordFeat[idx])] : _wordAlphabet[nullkey]; 121 | if (featId < 0) 122 | featId = unknownID; 123 | feat._nWordFeat[idx] = featId; 124 | } 125 | 126 | feat._nAllWordFeat.resize(wordNgram); 127 | unknownID = _allwordAlphabet[unknownkey]; 128 | for (int idx = 0; idx < wordNgram; idx++) { 129 | featId = idx < feat._strWordFeat.size() ? _allwordAlphabet[feat._strWordFeat[idx]] : _allwordAlphabet[nullkey]; 130 | if (featId < 0) 131 | featId = unknownID; 132 | feat._nAllWordFeat[idx] = featId; 133 | } 134 | 135 | feat._strWordFeat.clear(); 136 | 137 | for(int idx = feat._nWordLengths.size(); idx < wordNgram; idx++){ 138 | feat._nWordLengths.push_back(0); 139 | } 140 | 141 | feat._nKeyChars.resize(2 * wordNgram + 1); 142 | for (int idx = 0; idx < 2 * wordNgram + 1; idx++) { 143 | featId = idx < feat._strKeyChars.size() ? _charAlphabet[feat._strKeyChars[idx]] : _charAlphabet[nullkey]; 144 | if (featId < 0) 145 | featId = unknownID; 146 | feat._nKeyChars[idx] = featId; 147 | } 148 | 149 | } 150 | 151 | if (actionNgram > 0) { 152 | feat._nActionFeat.resize(actionNgram); 153 | for (int idx = 0; idx < actionNgram; idx++) { 154 | featId = idx < feat._strActionFeat.size() ? 
_actionAlphabet[feat._strActionFeat[idx]] : _actionAlphabet[nullkey]; 155 | if(featId == -1) featId = 0; //noAction 156 | feat._nActionFeat[idx] = featId; 157 | } 158 | feat._strActionFeat.clear(); 159 | } 160 | } 161 | } 162 | 163 | void addToFeatureAlphabet(hash_map feat_stat, int featCutOff = 0) { 164 | _featAlphabet.set_fixed_flag(false); 165 | hash_map::iterator feat_iter; 166 | for (feat_iter = feat_stat.begin(); feat_iter != feat_stat.end(); feat_iter++) { 167 | if (feat_iter->second > featCutOff) { 168 | _featAlphabet.from_string(feat_iter->first); 169 | } 170 | } 171 | _featAlphabet.set_fixed_flag(true); 172 | } 173 | 174 | void addToWordAlphabet(hash_map word_stat, int wordCutOff = 0) { 175 | _wordAlphabet.set_fixed_flag(false); 176 | hash_map::iterator word_iter; 177 | for (word_iter = word_stat.begin(); word_iter != word_stat.end(); word_iter++) { 178 | if (word_iter->second > wordCutOff) { 179 | _wordAlphabet.from_string(word_iter->first); 180 | } 181 | } 182 | _wordAlphabet.set_fixed_flag(true); 183 | } 184 | 185 | void addToAllWordAlphabet(hash_map allword_stat, int allwordCutOff = 0) { 186 | _allwordAlphabet.set_fixed_flag(false); 187 | hash_map::iterator allword_iter; 188 | for (allword_iter = allword_stat.begin(); allword_iter != allword_stat.end(); allword_iter++) { 189 | if (allword_iter->second > allwordCutOff) { 190 | _allwordAlphabet.from_string(allword_iter->first); 191 | } 192 | } 193 | _allwordAlphabet.set_fixed_flag(true); 194 | } 195 | 196 | void addToCharAlphabet(hash_map char_stat, int charCutOff = 0) { 197 | _charAlphabet.set_fixed_flag(false); 198 | hash_map::iterator char_iter; 199 | for (char_iter = char_stat.begin(); char_iter != char_stat.end(); char_iter++) { 200 | if (char_iter->second > charCutOff) { 201 | _charAlphabet.from_string(char_iter->first); 202 | } 203 | } 204 | _charAlphabet.set_fixed_flag(true); 205 | } 206 | 207 | void addToBiCharAlphabet(hash_map bichar_stat, int bicharCutOff = 0) { 208 | 
_bicharAlphabet.set_fixed_flag(false); 209 | hash_map::iterator bichar_iter; 210 | for (bichar_iter = bichar_stat.begin(); bichar_iter != bichar_stat.end(); bichar_iter++) { 211 | if (bichar_iter->second > bicharCutOff) { 212 | _bicharAlphabet.from_string(bichar_iter->first); 213 | } 214 | } 215 | _bicharAlphabet.set_fixed_flag(true); 216 | } 217 | 218 | void addToActionAlphabet(hash_map action_stat) { 219 | _actionAlphabet.set_fixed_flag(false); 220 | hash_map::iterator action_iter; 221 | for (action_iter = action_stat.begin(); action_iter != action_stat.end(); action_iter++) { 222 | _actionAlphabet.from_string(action_iter->first); 223 | } 224 | _actionAlphabet.set_fixed_flag(true); 225 | } 226 | 227 | void initAlphabet() { 228 | //alphabet initialization 229 | _featAlphabet.clear(); 230 | _featAlphabet.set_fixed_flag(true); 231 | 232 | _wordAlphabet.clear(); 233 | _wordAlphabet.from_string(nullkey); 234 | _wordAlphabet.from_string(unknownkey); 235 | _wordAlphabet.set_fixed_flag(true); 236 | 237 | _allwordAlphabet.clear(); 238 | _allwordAlphabet.from_string(nullkey); 239 | _allwordAlphabet.from_string(unknownkey); 240 | _allwordAlphabet.set_fixed_flag(true); 241 | 242 | _charAlphabet.clear(); 243 | _charAlphabet.from_string(nullkey); 244 | _charAlphabet.from_string(unknownkey); 245 | _charAlphabet.set_fixed_flag(true); 246 | 247 | _bicharAlphabet.clear(); 248 | _bicharAlphabet.from_string(nullkey); 249 | _bicharAlphabet.from_string(unknownkey); 250 | _bicharAlphabet.set_fixed_flag(true); 251 | 252 | _actionAlphabet.clear(); 253 | _actionAlphabet.from_string(nullkey); 254 | _actionAlphabet.set_fixed_flag(true); 255 | 256 | _bStringFeat = true; 257 | } 258 | 259 | void loadAlphabet() { 260 | _bStringFeat = false; 261 | } 262 | 263 | protected: 264 | void extractFeatureApp(const CStateItem * curState, Feature& feat, int wordNgram = 0, int actionNgram = 0) { 265 | string curWord = curState->_lastWordEnd == -1 ? 
nullkey : curState->_strlastWord; 266 | const std::vector * pCharacters = curState->_pCharacters; 267 | string nextChar = pCharacters->at(curState->_nextPosition); 268 | string curWordLastChar = curState->_lastWordEnd == -1 ? nullkey : pCharacters->at(curState->_lastWordEnd); 269 | string curWordLast2Char = curState->_lastWordEnd < 1 ? nullkey : pCharacters->at(curState->_lastWordEnd - 1); 270 | string curWordFirstChar = curState->_lastWordEnd == -1 ? nullkey : pCharacters->at(curState->_lastWordStart); 271 | 272 | string curWordLastCharType = curState->_lastWordEnd == -1 ? nullkey : wordtype(curWordLastChar); 273 | string curWordLast2CharType = curState->_lastWordEnd < 1 ? nullkey : wordtype(curWordLast2Char); 274 | string nextCharType = wordtype(nextChar); 275 | 276 | if (curState->_nextPosition != curState->_lastWordEnd + 1) { 277 | std::cout << "position error" << std::endl; 278 | } 279 | 280 | string strFeat = ""; 281 | strFeat = "F01" + seperateKey + curWordLastChar + seperateKey + nextChar; 282 | feat._strSparseFeat.push_back(strFeat); 283 | 284 | strFeat = "F02" + seperateKey + curWordFirstChar + seperateKey + nextChar; 285 | feat._strSparseFeat.push_back(strFeat); 286 | 287 | /* 288 | strFeat = "F03" + seperateKey + nextCharType; 289 | feat._strSparseFeat.push_back(strFeat); 290 | 291 | strFeat = "F04" + seperateKey + curWordLastCharType + nextCharType; 292 | feat._strSparseFeat.push_back(strFeat); 293 | 294 | */ 295 | strFeat = "F05" + seperateKey + curWordLast2CharType + curWordLastCharType + nextCharType; 296 | feat._strSparseFeat.push_back(strFeat); 297 | 298 | int ngram = 0; 299 | 300 | if (ngram < actionNgram) { 301 | feat._strActionFeat.push_back(CAction(CAction::APP).str()); 302 | ngram++; 303 | const CStateItem* theState = curState; 304 | while (theState != NULL && ngram < actionNgram) { 305 | feat._strActionFeat.push_back(theState->_lastAction.str()); 306 | theState = theState->_prevState; 307 | ngram++; 308 | } 309 | } 310 | 311 | } 312 | 313 | 
void extractFeatureSep(const CStateItem * curState, Feature& feat, int wordNgram = 0, int actionNgram = 0) { 314 | string curWord = curState->_lastWordEnd == -1 ? nullkey : curState->_strlastWord; 315 | const std::vector * pCharacters = curState->_pCharacters; 316 | string nextChar = pCharacters->at(curState->_nextPosition); 317 | string curWordLastChar = curState->_lastWordEnd == -1 ? nullkey : pCharacters->at(curState->_lastWordEnd); 318 | string curWordLast2Char = curState->_lastWordEnd < 1 ? nullkey : pCharacters->at(curState->_lastWordEnd - 1); 319 | string curWordFirstChar = curState->_lastWordEnd == -1 ? nullkey : pCharacters->at(curState->_lastWordStart); 320 | 321 | string curWordLastCharType = curState->_lastWordEnd == -1 ? nullkey : wordtype(curWordLastChar); 322 | string curWordLast2CharType = curState->_lastWordEnd < 1 ? nullkey : wordtype(curWordLast2Char); 323 | string nextCharType = wordtype(nextChar); 324 | 325 | int length = curState->_lastWordEnd - curState->_lastWordStart + 1; 326 | if (length > 5) 327 | length = 5; 328 | stringstream curss; 329 | curss << length; 330 | string strCurWordLen = curss.str(); 331 | 332 | if (curState->_nextPosition != curState->_lastWordEnd + 1) { 333 | std::cout << "position error" << std::endl; 334 | } 335 | 336 | const CStateItem * preStackState = curState->_prevStackState; 337 | string pre1Word = preStackState == 0 || preStackState->_lastWordEnd == -1 ? nullkey : preStackState->_strlastWord; 338 | string pre1WordLastChar = preStackState == 0 || preStackState->_lastWordEnd == -1 ? nullkey : pCharacters->at(preStackState->_lastWordEnd); 339 | string pre1WordFirstChar = preStackState == 0 || preStackState->_lastWordEnd == -1 ? nullkey : pCharacters->at(preStackState->_lastWordStart); 340 | 341 | length = preStackState == 0 ? 
0 : preStackState->_lastWordEnd - preStackState->_lastWordStart + 1; 342 | if (length > 5) 343 | length = 5; 344 | stringstream press; 345 | press << length; 346 | string strPreWordLen = press.str(); 347 | 348 | string strFeat = ""; 349 | strFeat = "F11" + seperateKey + curWordLastChar + seperateKey + nextChar; 350 | feat._strSparseFeat.push_back(strFeat); 351 | 352 | strFeat = "F12" + seperateKey + curWordFirstChar + seperateKey + nextChar; 353 | feat._strSparseFeat.push_back(strFeat); 354 | 355 | /* 356 | strFeat = "F13" + seperateKey + nextCharType; 357 | feat._strSparseFeat.push_back(strFeat); 358 | 359 | strFeat = "F14" + seperateKey + curWordLastCharType + nextCharType; 360 | feat._strSparseFeat.push_back(strFeat); 361 | */ 362 | strFeat = "F15" + seperateKey + curWordLast2CharType + curWordLastCharType + nextCharType; 363 | feat._strSparseFeat.push_back(strFeat); 364 | 365 | strFeat = "F16" + seperateKey + curWord + seperateKey + nextChar; 366 | feat._strSparseFeat.push_back(strFeat); 367 | 368 | strFeat = "F17" + seperateKey + curWord; 369 | feat._strSparseFeat.push_back(strFeat); 370 | 371 | strFeat = "F18" + seperateKey + curWord + seperateKey + pre1Word; 372 | feat._strSparseFeat.push_back(strFeat); 373 | 374 | strFeat = "F19" + seperateKey + curWord + seperateKey + pre1WordLastChar; 375 | feat._strSparseFeat.push_back(strFeat); 376 | 377 | strFeat = "F20" + seperateKey + curWord + seperateKey + pre1WordFirstChar; 378 | feat._strSparseFeat.push_back(strFeat); 379 | 380 | strFeat = "F21" + seperateKey + curWordLastChar + seperateKey + pre1WordLastChar; 381 | feat._strSparseFeat.push_back(strFeat); 382 | 383 | strFeat = "F22" + seperateKey + curWordFirstChar + seperateKey + curWordLastChar; 384 | feat._strSparseFeat.push_back(strFeat); 385 | 386 | strFeat = "F23" + seperateKey + curWord + seperateKey + strPreWordLen; 387 | feat._strSparseFeat.push_back(strFeat); 388 | 389 | strFeat = "F24" + seperateKey + pre1Word + seperateKey + strCurWordLen; 390 | 
feat._strSparseFeat.push_back(strFeat); 391 | 392 | strFeat = "F25" + seperateKey + pre1Word + seperateKey + curWordLastChar; 393 | feat._strSparseFeat.push_back(strFeat); 394 | 395 | if (curState->_lastWordStart != -1) { 396 | for (int idx = curState->_lastWordStart; idx < curState->_lastWordEnd; idx++) { 397 | strFeat = "F26" + seperateKey + pCharacters->at(idx) + seperateKey + curWordLastChar; 398 | feat._strSparseFeat.push_back(strFeat); 399 | } 400 | } 401 | 402 | if (curState->_lastWordEnd == curState->_lastWordStart && curState->_lastWordStart != -1) { 403 | strFeat = "F27" + seperateKey + curWord; 404 | feat._strSparseFeat.push_back(strFeat); 405 | } 406 | 407 | strFeat = "F28" + seperateKey + curWordFirstChar + seperateKey + strCurWordLen; 408 | feat._strSparseFeat.push_back(strFeat); 409 | 410 | strFeat = "F29" + seperateKey + curWordLastChar + seperateKey + strCurWordLen; 411 | feat._strSparseFeat.push_back(strFeat); 412 | 413 | int ngram = 0; 414 | 415 | if (ngram < actionNgram) { 416 | feat._strActionFeat.push_back(CAction(CAction::SEP).str()); 417 | ngram++; 418 | const CStateItem* theState = curState; 419 | while (theState != NULL && ngram < actionNgram) { 420 | feat._strActionFeat.push_back(theState->_lastAction.str()); 421 | theState = theState->_prevState; 422 | ngram++; 423 | } 424 | } 425 | 426 | ngram = 0; 427 | 428 | feat._strKeyChars.push_back(nextChar); 429 | if (ngram < wordNgram) { 430 | feat._strWordFeat.push_back(curWord); 431 | feat._strKeyChars.push_back(curWordLastChar); 432 | feat._strKeyChars.push_back(curWordFirstChar); 433 | length = curState->_lastWordEnd == -1 ? 
0 : curState->_lastWordEnd - curState->_lastWordStart + 1; 434 | if(length > 5) length = 5; 435 | feat._nWordLengths.push_back(length); 436 | ngram++; 437 | const CStateItem* theState = preStackState; 438 | while (theState != NULL && theState->_lastWordEnd >= 0 && ngram < actionNgram) { 439 | feat._strWordFeat.push_back(theState->_strlastWord); 440 | feat._strKeyChars.push_back(pCharacters->at(theState->_lastWordEnd)); 441 | feat._strKeyChars.push_back(pCharacters->at(theState->_lastWordEnd)); 442 | length = theState->_lastWordEnd - theState->_lastWordStart + 1; 443 | if(length > 5) length = 5; 444 | feat._nWordLengths.push_back(length); 445 | theState = theState->_prevStackState; 446 | ngram++; 447 | } 448 | } 449 | 450 | } 451 | 452 | void extractFeatureFinish(const CStateItem * curState, Feature& feat, int wordNgram = 0, int actionNgram = 0) { 453 | string curWord = curState->_lastWordEnd == -1 ? nullkey : curState->_strlastWord; 454 | const std::vector * pCharacters = curState->_pCharacters; 455 | //string nextChar = pCharacters->at(curState->_nextPosition); 456 | string curWordLastChar = curState->_lastWordEnd == -1 ? nullkey : pCharacters->at(curState->_lastWordEnd); 457 | string curWordLast2Char = curState->_lastWordEnd < 1 ? nullkey : pCharacters->at(curState->_lastWordEnd - 1); 458 | string curWordFirstChar = curState->_lastWordEnd == -1 ? nullkey : pCharacters->at(curState->_lastWordStart); 459 | 460 | int length = curState->_lastWordEnd - curState->_lastWordStart + 1; 461 | if (length > 5) 462 | length = 5; 463 | stringstream curss; 464 | curss << length; 465 | string strCurWordLen = curss.str(); 466 | 467 | if (curState->_nextPosition != curState->_lastWordEnd + 1) { 468 | std::cout << "position error" << std::endl; 469 | } 470 | 471 | const CStateItem * preStackState = curState->_prevStackState; 472 | string pre1Word = preStackState == 0 || preStackState->_lastWordEnd == -1 ? 
nullkey : preStackState->_strlastWord; 473 | string pre1WordLastChar = preStackState == 0 || preStackState->_lastWordEnd == -1 ? nullkey : pCharacters->at(preStackState->_lastWordEnd); 474 | string pre1WordFirstChar = preStackState == 0 || preStackState->_lastWordEnd == -1 ? nullkey : pCharacters->at(preStackState->_lastWordStart); 475 | 476 | length = preStackState == 0 ? 0 : preStackState->_lastWordEnd - preStackState->_lastWordStart + 1; 477 | if (length > 5) 478 | length = 5; 479 | stringstream press; 480 | press << length; 481 | string strPreWordLen = press.str(); 482 | 483 | string strFeat = ""; 484 | /*strFeat = "F11" + seperateKey + curWordLastChar + seperateKey + nullkey; 485 | feat._strSparseFeat.push_back(strFeat); 486 | 487 | strFeat = "F12" + seperateKey + curWordFirstChar + seperateKey + nullkey; 488 | feat._strSparseFeat.push_back(strFeat); 489 | 490 | strFeat = "F13" + seperateKey + nextCharType; 491 | feat._strSparseFeat.push_back(strFeat); 492 | 493 | strFeat = "F14" + seperateKey + curWordLastCharType + nextCharType; 494 | feat._strSparseFeat.push_back(strFeat); 495 | 496 | strFeat = "F15" + seperateKey + curWordLast2CharType + curWordLastCharType + nextCharType; 497 | feat._strSparseFeat.push_back(strFeat); 498 | */ 499 | 500 | strFeat = "F16" + seperateKey + curWord + seperateKey + nullkey; 501 | feat._strSparseFeat.push_back(strFeat); 502 | 503 | strFeat = "F17" + seperateKey + curWord; 504 | feat._strSparseFeat.push_back(strFeat); 505 | 506 | strFeat = "F18" + seperateKey + curWord + seperateKey + pre1Word; 507 | feat._strSparseFeat.push_back(strFeat); 508 | 509 | strFeat = "F19" + seperateKey + curWord + seperateKey + pre1WordLastChar; 510 | feat._strSparseFeat.push_back(strFeat); 511 | 512 | strFeat = "F20" + seperateKey + curWord + seperateKey + pre1WordFirstChar; 513 | feat._strSparseFeat.push_back(strFeat); 514 | 515 | strFeat = "F21" + seperateKey + curWordLastChar + seperateKey + pre1WordLastChar; 516 | 
feat._strSparseFeat.push_back(strFeat); 517 | 518 | strFeat = "F22" + seperateKey + curWordFirstChar + seperateKey + curWordLastChar; 519 | feat._strSparseFeat.push_back(strFeat); 520 | 521 | strFeat = "F23" + seperateKey + curWord + seperateKey + strPreWordLen; 522 | feat._strSparseFeat.push_back(strFeat); 523 | 524 | strFeat = "F24" + seperateKey + pre1Word + seperateKey + strCurWordLen; 525 | feat._strSparseFeat.push_back(strFeat); 526 | 527 | strFeat = "F25" + seperateKey + pre1Word + seperateKey + curWordLastChar; 528 | feat._strSparseFeat.push_back(strFeat); 529 | 530 | if (curState->_lastWordStart != -1) { 531 | for (int idx = curState->_lastWordStart; idx < curState->_lastWordEnd; idx++) { 532 | strFeat = "F26" + seperateKey + pCharacters->at(idx) + seperateKey + curWordLastChar; 533 | feat._strSparseFeat.push_back(strFeat); 534 | } 535 | } 536 | 537 | if (curState->_lastWordEnd == curState->_lastWordStart && curState->_lastWordStart != -1) { 538 | strFeat = "F27" + seperateKey + curWord; 539 | feat._strSparseFeat.push_back(strFeat); 540 | } 541 | 542 | strFeat = "F28" + seperateKey + curWordFirstChar + seperateKey + strCurWordLen; 543 | feat._strSparseFeat.push_back(strFeat); 544 | 545 | strFeat = "F29" + seperateKey + curWordLastChar + seperateKey + strCurWordLen; 546 | feat._strSparseFeat.push_back(strFeat); 547 | 548 | int ngram = 0; 549 | 550 | if (ngram < actionNgram) { 551 | feat._strActionFeat.push_back(CAction(CAction::FIN).str()); 552 | ngram++; 553 | const CStateItem* theState = curState; 554 | while (theState != NULL && ngram < actionNgram) { 555 | feat._strActionFeat.push_back(theState->_lastAction.str()); 556 | theState = theState->_prevState; 557 | ngram++; 558 | } 559 | } 560 | 561 | ngram = 0; 562 | 563 | feat._strKeyChars.push_back(nullkey); 564 | if (ngram < wordNgram) { 565 | feat._strWordFeat.push_back(curWord); 566 | feat._strKeyChars.push_back(curWordLastChar); 567 | feat._strKeyChars.push_back(curWordFirstChar); 568 | length = 
curState->_lastWordEnd == -1 ? 0 : curState->_lastWordEnd - curState->_lastWordStart + 1; 569 | if(length > 5) length = 5; 570 | feat._nWordLengths.push_back(length); 571 | ngram++; 572 | const CStateItem* theState = preStackState; 573 | while (theState != NULL && theState->_lastWordEnd >= 0 && ngram < actionNgram) { 574 | feat._strWordFeat.push_back(theState->_strlastWord); 575 | feat._strKeyChars.push_back(pCharacters->at(theState->_lastWordEnd)); 576 | feat._strKeyChars.push_back(pCharacters->at(theState->_lastWordEnd)); 577 | length = theState->_lastWordEnd - theState->_lastWordStart + 1; 578 | if(length > 5) length = 5; 579 | feat._nWordLengths.push_back(length); 580 | theState = theState->_prevStackState; 581 | ngram++; 582 | } 583 | } 584 | 585 | } 586 | 587 | }; 588 | 589 | #endif /* BASIC_FEATUREEXTRACTION_H_ */ 590 | -------------------------------------------------------------------------------- /feature/DenseFeatureForward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * DenseFeatureForward.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mason 6 | */ 7 | 8 | #ifndef FEATURE_DENSEFEATUREFORWARD_H_ 9 | #define FEATURE_DENSEFEATUREFORWARD_H_ 10 | 11 | #include "N3L.h" 12 | template 13 | class DenseFeatureForward { 14 | public: 15 | //state inter dependent features 16 | //word 17 | Tensor _wordPrime; 18 | Tensor _allwordPrime; 19 | Tensor _keyCharPrime; 20 | Tensor _lengthPrime; 21 | Tensor _wordRep, _allwordRep, _keyCharRep, _lengthRep; 22 | Tensor _wordUnitRep, _wordUnitRepMask; 23 | Tensor _wordHidden; 24 | vector > _wordRNNHiddenBuf; 25 | Tensor _wordRNNHidden; //lstm 26 | 27 | //action 28 | Tensor _actionPrime; 29 | Tensor _actionRep, _actionRepMask; 30 | Tensor _actionHidden; 31 | vector > _actionRNNHiddenBuf; 32 | Tensor _actionRNNHidden; //lstm 33 | 34 | //state inter independent features 35 | Tensor _sepInHidden, _sepInHiddenLoss; //sep in 36 | Tensor _sepOutHidden, _sepOutHiddenLoss; //sep out 
37 | Tensor _appInHidden, _appInHiddenLoss; //app in 38 | Tensor _appOutHidden, _appOutHiddenLoss; //app out 39 | 40 | bool _bAllocated; 41 | bool _bTrain; 42 | int _buffer; 43 | 44 | int _wordDim, _allwordDim, _lengthDim, _charDim, _wordNgram, _wordHiddenDim, _wordRNNDim, _wordUnitDim; 45 | int _actionDim, _actionNgram, _actionPreDim, _actionHiddenDim, _actionRNNDim; 46 | int _sepInHiddenDim, _appInHiddenDim, _sepOutHiddenDim, _appOutHiddenDim; 47 | 48 | public: 49 | DenseFeatureForward() { 50 | _bAllocated = false; 51 | _bTrain = false; 52 | _buffer = 0; 53 | 54 | _wordDim = 0; 55 | _allwordDim = 0; 56 | _charDim = 0; 57 | _lengthDim = 0; 58 | _wordNgram = 0; 59 | _wordUnitDim = 0; 60 | _wordHiddenDim = 0; 61 | _wordRNNDim = 0; 62 | 63 | 64 | _actionDim = 0; 65 | _actionNgram = 0; 66 | _actionPreDim = 0; 67 | _actionHiddenDim = 0; 68 | _actionRNNDim = 0; 69 | 70 | _sepInHiddenDim = 0; 71 | _appInHiddenDim = 0; 72 | _sepOutHiddenDim = 0; 73 | _appOutHiddenDim = 0; 74 | } 75 | 76 | ~DenseFeatureForward() { 77 | clear(); 78 | } 79 | 80 | public: 81 | inline void init(int wordDim, int allwordDim, int charDim, int lengthDim, int wordNgram, int wordHiddenDim, int wordRNNDim, int actionDim, int actionNgram, 82 | int actionHiddenDim, int actionRNNDim, int sepInHiddenDim, int appInHiddenDim, int sepOutHiddenDim, int appOutHiddenDim, int buffer = 0, bool bTrain = false) { 83 | clear(); 84 | _buffer = buffer; 85 | 86 | _wordDim = wordDim; 87 | _allwordDim = allwordDim; 88 | _charDim = charDim; 89 | _lengthDim = lengthDim; 90 | _wordNgram = wordNgram; 91 | _wordUnitDim = wordNgram * wordDim + wordNgram * allwordDim + (2 * wordNgram + 1) * charDim + wordNgram * lengthDim; 92 | _wordHiddenDim = wordHiddenDim; 93 | _wordRNNDim = wordRNNDim; 94 | 95 | _wordPrime = NewTensor(Shape3(_wordNgram, 1, _wordDim), d_zero); 96 | _allwordPrime = NewTensor(Shape3(_wordNgram, 1, _allwordDim), d_zero); 97 | _wordRep = NewTensor(Shape2(1, _wordNgram * _wordDim), d_zero); 98 | _allwordRep = 
NewTensor(Shape2(1, _wordNgram * _allwordDim), d_zero); 99 | _keyCharPrime = NewTensor(Shape3(2 * _wordNgram + 1, 1, _charDim), d_zero); 100 | _keyCharRep = NewTensor(Shape2(1, (2 * _wordNgram + 1) * _charDim), d_zero); 101 | _lengthPrime = NewTensor(Shape3(_wordNgram, 1, _lengthDim), d_zero); 102 | _lengthRep = NewTensor(Shape2(1, _wordNgram * _lengthDim), d_zero); 103 | _wordUnitRep = NewTensor(Shape2(1, _wordUnitDim), d_zero); 104 | _wordHidden = NewTensor(Shape2(1, _wordHiddenDim), d_zero); 105 | if (_buffer > 0) { 106 | _wordRNNHiddenBuf.resize(_buffer); 107 | for (int idk = 0; idk < _buffer; idk++) { 108 | _wordRNNHiddenBuf[idk] = NewTensor(Shape2(1, _wordRNNDim), d_zero); 109 | } 110 | } 111 | _wordRNNHidden = NewTensor(Shape2(1, _wordRNNDim), d_zero); 112 | 113 | _actionDim = actionDim; 114 | _actionNgram = actionNgram; 115 | _actionPreDim = actionNgram * actionDim; 116 | _actionHiddenDim = actionHiddenDim; 117 | _actionRNNDim = actionRNNDim; 118 | _actionPrime = NewTensor(Shape3(_actionNgram, 1, _actionDim), d_zero); 119 | _actionRep = NewTensor(Shape2(1, _actionPreDim), d_zero); 120 | _actionHidden = NewTensor(Shape2(1, _actionHiddenDim), d_zero); 121 | if (_buffer > 0) { 122 | _actionRNNHiddenBuf.resize(_buffer); 123 | for (int idk = 0; idk < _buffer; idk++) { 124 | _actionRNNHiddenBuf[idk] = NewTensor(Shape2(1, _actionRNNDim), d_zero); 125 | } 126 | } 127 | _actionRNNHidden = NewTensor(Shape2(1, _actionRNNDim), d_zero); 128 | 129 | _sepInHiddenDim = sepInHiddenDim; 130 | _appInHiddenDim = appInHiddenDim; 131 | _sepOutHiddenDim = sepOutHiddenDim; 132 | _appOutHiddenDim = appOutHiddenDim; 133 | _sepInHidden = NewTensor(Shape2(1, _sepInHiddenDim), d_zero); 134 | _appInHidden = NewTensor(Shape2(1, _appInHiddenDim), d_zero); 135 | _sepOutHidden = NewTensor(Shape2(1, _sepOutHiddenDim), d_zero); 136 | _appOutHidden = NewTensor(Shape2(1, _appOutHiddenDim), d_zero); 137 | 138 | if (bTrain) { 139 | _bTrain = bTrain; 140 | 141 | _wordUnitRepMask = 
NewTensor(Shape2(1, _wordUnitDim), d_zero); 142 | _actionRepMask = NewTensor(Shape2(1, _actionPreDim), d_zero); 143 | 144 | _sepInHiddenLoss = NewTensor(Shape2(1, _sepInHiddenDim), d_zero); 145 | _appInHiddenLoss = NewTensor(Shape2(1, _appInHiddenDim), d_zero); 146 | _sepOutHiddenLoss = NewTensor(Shape2(1, _sepOutHiddenDim), d_zero); 147 | _appOutHiddenLoss = NewTensor(Shape2(1, _appOutHiddenDim), d_zero); 148 | 149 | } 150 | 151 | _bAllocated = true; 152 | } 153 | 154 | inline void clear() { 155 | if (_bAllocated) { 156 | _wordDim = 0; 157 | _allwordDim = 0; 158 | _charDim = 0; 159 | _lengthDim = 0; 160 | _wordNgram = 0; 161 | _wordUnitDim = 0; 162 | _wordHiddenDim = 0; 163 | _wordRNNDim = 0; 164 | 165 | FreeSpace(&_wordPrime); 166 | FreeSpace(&_allwordPrime); 167 | FreeSpace(&_wordRep); 168 | FreeSpace(&_allwordRep); 169 | FreeSpace(&_keyCharPrime); 170 | FreeSpace(&_keyCharRep); 171 | FreeSpace(&_lengthPrime); 172 | FreeSpace(&_lengthRep); 173 | FreeSpace(&_wordUnitRep); 174 | FreeSpace(&_wordHidden); 175 | FreeSpace(&_wordRNNHidden); 176 | if (_buffer > 0) { 177 | for (int idk = 0; idk < _buffer; idk++) { 178 | FreeSpace(&(_wordRNNHiddenBuf[idk])); 179 | } 180 | _wordRNNHiddenBuf.clear(); 181 | } 182 | 183 | _actionDim = 0; 184 | _actionNgram = 0; 185 | _actionPreDim = 0; 186 | _actionHiddenDim = 0; 187 | _actionRNNDim = 0; 188 | FreeSpace(&_actionPrime); 189 | FreeSpace(&_actionRep); 190 | FreeSpace(&_actionHidden); 191 | FreeSpace(&_actionRNNHidden); 192 | if (_buffer > 0) { 193 | for (int idk = 0; idk < _buffer; idk++) { 194 | FreeSpace(&(_actionRNNHiddenBuf[idk])); 195 | } 196 | _actionRNNHiddenBuf.clear(); 197 | } 198 | 199 | _sepInHiddenDim = 0; 200 | _appInHiddenDim = 0; 201 | _sepOutHiddenDim = 0; 202 | _appOutHiddenDim = 0; 203 | FreeSpace(&_sepInHidden); 204 | FreeSpace(&_appInHidden); 205 | FreeSpace(&_sepOutHidden); 206 | FreeSpace(&_appOutHidden); 207 | 208 | if (_bTrain) { 209 | FreeSpace(&_wordUnitRepMask); 210 | FreeSpace(&_actionRepMask); 211 | 
212 | FreeSpace(&_sepInHiddenLoss); 213 | FreeSpace(&_appInHiddenLoss); 214 | FreeSpace(&_sepOutHiddenLoss); 215 | FreeSpace(&_appOutHiddenLoss); 216 | } 217 | 218 | _bAllocated = false; 219 | _bTrain = false; 220 | _buffer = 0; 221 | } 222 | } 223 | 224 | inline void copy(const DenseFeatureForward& other) { 225 | if (other._bAllocated) { 226 | if (_bAllocated) { 227 | if (_wordDim != other._wordDim || _allwordDim != other._allwordDim || _charDim != other._charDim || _lengthDim != other._lengthDim 228 | || _wordNgram != other._wordNgram || _wordHiddenDim != other._wordHiddenDim || _wordRNNDim != other._wordRNNDim 229 | || _actionDim != other._actionDim || _actionNgram != other._actionNgram || _actionHiddenDim != other._actionHiddenDim 230 | || _actionRNNDim != other._actionRNNDim || _sepInHiddenDim != other._sepInHiddenDim || _appInHiddenDim != other._appInHiddenDim 231 | || _sepOutHiddenDim != other._sepOutHiddenDim || _appOutHiddenDim != other._appOutHiddenDim) { 232 | std::cout << "please check, error allocatation somewhere" << std::endl; 233 | return; 234 | } 235 | } else { 236 | init(other._wordDim, other._allwordDim, other._charDim, other._lengthDim, other._wordNgram, other._wordHiddenDim, other._wordRNNDim, other._actionDim, other._actionNgram, other._actionHiddenDim, 237 | other._actionRNNDim, other._sepInHiddenDim, other._appInHiddenDim, other._sepOutHiddenDim, other._appOutHiddenDim, other._buffer, other._bTrain); 238 | } 239 | Copy(_wordPrime, other._wordPrime); 240 | Copy(_allwordPrime, other._allwordPrime); 241 | Copy(_wordRep, other._wordRep); 242 | Copy(_allwordRep, other._allwordRep); 243 | Copy(_keyCharPrime, other._keyCharPrime); 244 | Copy(_keyCharRep, other._keyCharRep); 245 | Copy(_lengthPrime, other._lengthPrime); 246 | Copy(_lengthRep, other._lengthRep); 247 | Copy(_wordUnitRep, other._wordUnitRep); 248 | Copy(_wordHidden, other._wordHidden); 249 | for(int idk = 0; idk < _buffer; idk++){ 250 | Copy(_wordRNNHiddenBuf[idk], 
other._wordRNNHiddenBuf[idk]); 251 | } 252 | Copy(_wordRNNHidden, other._wordRNNHidden); 253 | 254 | Copy(_actionPrime, other._actionPrime); 255 | Copy(_actionRep, other._actionRep); 256 | Copy(_actionHidden, other._actionHidden); 257 | for(int idk = 0; idk < _buffer; idk++){ 258 | Copy(_actionRNNHiddenBuf[idk], other._actionRNNHiddenBuf[idk]); 259 | } 260 | Copy(_actionRNNHidden, other._actionRNNHidden); 261 | 262 | Copy(_sepInHidden, other._sepInHidden); 263 | Copy(_appInHidden, other._appInHidden); 264 | Copy(_sepOutHidden, other._sepOutHidden); 265 | Copy(_appOutHidden, other._appOutHidden); 266 | 267 | if (other._bTrain) { 268 | Copy(_wordUnitRepMask, other._wordUnitRepMask); 269 | Copy(_actionRepMask, other._actionRepMask); 270 | 271 | Copy(_sepInHiddenLoss, other._sepInHiddenLoss); 272 | Copy(_appInHiddenLoss, other._appInHiddenLoss); 273 | Copy(_sepOutHiddenLoss, other._sepOutHiddenLoss); 274 | Copy(_appOutHiddenLoss, other._appOutHiddenLoss); 275 | } 276 | } else { 277 | clear(); 278 | } 279 | 280 | _bAllocated = other._bAllocated; 281 | _bTrain = other._bTrain; 282 | _buffer = other._buffer; 283 | } 284 | 285 | /* 286 | inline DenseFeatureForward& operator=(const DenseFeatureForward &rhs) { 287 | // Check for self-assignment! 288 | if (this == &rhs) // Same object? 289 | return *this; // Yes, so skip assignment, and just return *this. 
290 | copy(rhs); 291 | return *this; 292 | } 293 | */ 294 | }; 295 | 296 | #endif /* FEATURE_DENSEFEATUREFORWARD_H_ */ 297 | -------------------------------------------------------------------------------- /feature/Feature.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Feature.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_FEATURE_H_ 9 | #define SRC_FEATURE_H_ 10 | 11 | #include 12 | 13 | using namespace std; 14 | 15 | class Feature { 16 | 17 | public: 18 | vector _nSparseFeat; 19 | vector _nWordFeat; 20 | vector _nActionFeat; 21 | vector _nKeyChars; 22 | 23 | 24 | vector _nWordLengths; 25 | vector _nAllWordFeat; 26 | 27 | vector _strSparseFeat; 28 | vector _strWordFeat; 29 | vector _strActionFeat; 30 | vector _strKeyChars; 31 | 32 | bool _bStringFeat; 33 | 34 | public: 35 | Feature() { 36 | _bStringFeat = false; 37 | clear(); 38 | } 39 | 40 | Feature(bool bCollecting) { 41 | _bStringFeat = bCollecting; 42 | clear(); 43 | } 44 | 45 | ~Feature() { 46 | clear(); 47 | } 48 | 49 | void setFeatureFormat(bool bStringFeat) { 50 | _bStringFeat = bStringFeat; 51 | } 52 | 53 | void copy(const Feature& other) { 54 | clear(); 55 | if (other._bStringFeat) { 56 | for (int idx = 0; idx < other._strSparseFeat.size(); idx++) { 57 | _strSparseFeat.push_back(other._strSparseFeat[idx]); 58 | } 59 | for (int idx = 0; idx < other._strWordFeat.size(); idx++) { 60 | _strWordFeat.push_back(other._strWordFeat[idx]); 61 | } 62 | for (int idx = 0; idx < other._strActionFeat.size(); idx++) { 63 | _strActionFeat.push_back(other._strActionFeat[idx]); 64 | } 65 | for (int idx = 0; idx < other._strKeyChars.size(); idx++) { 66 | _strKeyChars.push_back(other._strKeyChars[idx]); 67 | } 68 | for (int idx = 0; idx < other._nWordLengths.size(); idx++) { 69 | _nWordLengths.push_back(other._nWordLengths[idx]); 70 | } 71 | for (int idx = 0; idx < other._nAllWordFeat.size(); idx++) { 72 | 
_nAllWordFeat.push_back(other._nAllWordFeat[idx]); 73 | } 74 | } else { 75 | for (int idx = 0; idx < other._nSparseFeat.size(); idx++) { 76 | _nSparseFeat.push_back(other._nSparseFeat[idx]); 77 | } 78 | for (int idx = 0; idx < other._nWordFeat.size(); idx++) { 79 | _nWordFeat.push_back(other._nWordFeat[idx]); 80 | } 81 | for (int idx = 0; idx < other._nActionFeat.size(); idx++) { 82 | _nActionFeat.push_back(other._nActionFeat[idx]); 83 | } 84 | for (int idx = 0; idx < other._nKeyChars.size(); idx++) { 85 | _nKeyChars.push_back(other._nKeyChars[idx]); 86 | } 87 | for (int idx = 0; idx < other._nWordLengths.size(); idx++) { 88 | _nWordLengths.push_back(other._nWordLengths[idx]); 89 | } 90 | for (int idx = 0; idx < other._nAllWordFeat.size(); idx++) { 91 | _nAllWordFeat.push_back(other._nAllWordFeat[idx]); 92 | } 93 | } 94 | } 95 | 96 | /* 97 | inline Feature& operator=(const Feature &rhs) { 98 | // Check for self-assignment! 99 | if (this == &rhs) // Same object? 100 | return *this; // Yes, so skip assignment, and just return *this. 
101 | copy(rhs); 102 | return *this; 103 | } 104 | */ 105 | 106 | void clear() { 107 | _nSparseFeat.clear(); 108 | _nWordFeat.clear(); 109 | _nActionFeat.clear(); 110 | _nKeyChars.clear(); 111 | _nWordLengths.clear(); 112 | _nAllWordFeat.clear(); 113 | 114 | _strSparseFeat.clear(); 115 | _strWordFeat.clear(); 116 | _strActionFeat.clear(); 117 | _strKeyChars.clear(); 118 | } 119 | }; 120 | 121 | #endif /* SRC_FEATURE_H_ */ 122 | -------------------------------------------------------------------------------- /feature/FeatureExtraction.h: -------------------------------------------------------------------------------- 1 | /* 2 | * FeatureExtraction.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef BASIC_FEATUREEXTRACTION_H_ 9 | #define BASIC_FEATUREEXTRACTION_H_ 10 | #include "N3L.h" 11 | #include "Action.h" 12 | #include "State.h" 13 | #include "Utf.h" 14 | 15 | class FeatureExtraction { 16 | public: 17 | std::string nullkey; 18 | std::string rootdepkey; 19 | std::string unknownkey; 20 | std::string paddingtag; 21 | std::string seperateKey; 22 | 23 | public: 24 | Alphabet _featAlphabet; 25 | Alphabet _wordAlphabet; 26 | Alphabet _allwordAlphabet; 27 | Alphabet _charAlphabet; 28 | Alphabet _bicharAlphabet; 29 | Alphabet _actionAlphabet; 30 | 31 | public: 32 | bool _bStringFeat; // string-formated features or digit-formated features 33 | 34 | public: 35 | FeatureExtraction() { 36 | nullkey = "-null-"; 37 | unknownkey = "-unknown-"; 38 | paddingtag = "-padding-"; 39 | seperateKey = "#"; 40 | 41 | _bStringFeat = true; 42 | } 43 | 44 | FeatureExtraction(bool bCollecting) { 45 | nullkey = "-null-"; 46 | unknownkey = "-unknown-"; 47 | paddingtag = "-padding-"; 48 | seperateKey = "#"; 49 | 50 | _bStringFeat = bCollecting; 51 | } 52 | 53 | public: 54 | 55 | inline void setFeatureFormat(bool bStringFeat) { 56 | _bStringFeat = bStringFeat; 57 | } 58 | 59 | inline void setAlphaIncreasing(bool alphaIncreasing) { 60 | if (alphaIncreasing) { 61 | 
_featAlphabet.set_fixed_flag(false); 62 | _wordAlphabet.set_fixed_flag(false); 63 | _allwordAlphabet.set_fixed_flag(false); 64 | _charAlphabet.set_fixed_flag(false); 65 | _bicharAlphabet.set_fixed_flag(false); 66 | _actionAlphabet.set_fixed_flag(false); 67 | } else { 68 | _featAlphabet.set_fixed_flag(true); 69 | _wordAlphabet.set_fixed_flag(true); 70 | _allwordAlphabet.set_fixed_flag(true); 71 | _charAlphabet.set_fixed_flag(true); 72 | _bicharAlphabet.set_fixed_flag(true); 73 | _actionAlphabet.set_fixed_flag(true); 74 | } 75 | } 76 | 77 | inline void setFeatAlphaIncreasing(bool alphaIncreasing) { 78 | if (alphaIncreasing) { 79 | _featAlphabet.set_fixed_flag(false); 80 | } else { 81 | _featAlphabet.set_fixed_flag(true); 82 | } 83 | } 84 | 85 | inline int getCharAlphaId(const std::string & oneChar) { 86 | return _charAlphabet[oneChar]; 87 | } 88 | 89 | inline int getBiCharAlphaId(const std::string & twoChar) { 90 | return _bicharAlphabet[twoChar]; 91 | } 92 | 93 | void extractFeature(const CStateItem * curState, const CAction& nextAC, Feature& feat, int wordNgram = 0, int actionNgram = 0) { 94 | feat.clear(); 95 | feat.setFeatureFormat(_bStringFeat); 96 | if (nextAC._code == CAction::APP) { 97 | extractFeatureApp(curState, feat, wordNgram, actionNgram); 98 | } else if (nextAC._code == CAction::SEP) { 99 | extractFeatureSep(curState, feat, wordNgram, actionNgram); 100 | } else if (nextAC._code == CAction::FIN) { 101 | extractFeatureFinish(curState, feat, wordNgram, actionNgram); 102 | } else { 103 | 104 | } 105 | 106 | static int featId, unknownID; 107 | if (!_bStringFeat) { 108 | for (int idx = 0; idx < feat._strSparseFeat.size(); idx++) { 109 | featId = _featAlphabet[feat._strSparseFeat[idx]]; 110 | if (featId >= 0) 111 | feat._nSparseFeat.push_back(featId); 112 | } 113 | feat._strSparseFeat.clear(); 114 | 115 | if (wordNgram > 0) { 116 | feat._nWordFeat.resize(wordNgram); 117 | unknownID = _wordAlphabet[unknownkey]; 118 | for (int idx = 0; idx < wordNgram; idx++) { 
119 | featId = idx < feat._strWordFeat.size() ? _wordAlphabet[feat._strWordFeat[idx]] : _wordAlphabet[nullkey]; 120 | if (featId < 0) 121 | featId = unknownID; 122 | feat._nWordFeat[idx] = featId; 123 | } 124 | 125 | feat._nAllWordFeat.resize(wordNgram); 126 | unknownID = _allwordAlphabet[unknownkey]; 127 | for (int idx = 0; idx < wordNgram; idx++) { 128 | featId = idx < feat._strWordFeat.size() ? _allwordAlphabet[feat._strWordFeat[idx]] : _allwordAlphabet[nullkey]; 129 | if (featId < 0) 130 | featId = unknownID; 131 | feat._nAllWordFeat[idx] = featId; 132 | } 133 | 134 | feat._strWordFeat.clear(); 135 | } 136 | 137 | if (actionNgram > 0) { 138 | feat._nActionFeat.resize(actionNgram); 139 | for (int idx = 0; idx < actionNgram; idx++) { 140 | featId = idx < feat._strActionFeat.size() ? _actionAlphabet[feat._strActionFeat[idx]] : _actionAlphabet[nullkey]; 141 | feat._nActionFeat[idx] = featId; 142 | } 143 | feat._strActionFeat.clear(); 144 | } 145 | } 146 | } 147 | 148 | void addToFeatureAlphabet(hash_map feat_stat, int featCutOff = 0) { 149 | _featAlphabet.set_fixed_flag(false); 150 | hash_map::iterator feat_iter; 151 | for (feat_iter = feat_stat.begin(); feat_iter != feat_stat.end(); feat_iter++) { 152 | if (feat_iter->second > featCutOff) { 153 | _featAlphabet.from_string(feat_iter->first); 154 | } 155 | } 156 | _featAlphabet.set_fixed_flag(true); 157 | } 158 | 159 | void addToWordAlphabet(hash_map word_stat, int wordCutOff = 0) { 160 | _wordAlphabet.set_fixed_flag(false); 161 | hash_map::iterator word_iter; 162 | for (word_iter = word_stat.begin(); word_iter != word_stat.end(); word_iter++) { 163 | if (word_iter->second > wordCutOff) { 164 | _wordAlphabet.from_string(word_iter->first); 165 | } 166 | } 167 | _wordAlphabet.set_fixed_flag(true); 168 | } 169 | 170 | 171 | void addToAllWordAlphabet(hash_map allword_stat, int allwordCutOff = 0) { 172 | _allwordAlphabet.set_fixed_flag(false); 173 | hash_map::iterator allword_iter; 174 | for (allword_iter = 
allword_stat.begin(); allword_iter != allword_stat.end(); allword_iter++) { 175 | if (allword_iter->second > allwordCutOff) { 176 | _allwordAlphabet.from_string(allword_iter->first); 177 | } 178 | } 179 | _allwordAlphabet.set_fixed_flag(true); 180 | } 181 | 182 | void addToCharAlphabet(hash_map char_stat, int charCutOff = 0) { 183 | _charAlphabet.set_fixed_flag(false); 184 | hash_map::iterator char_iter; 185 | for (char_iter = char_stat.begin(); char_iter != char_stat.end(); char_iter++) { 186 | if (char_iter->second > charCutOff) { 187 | _charAlphabet.from_string(char_iter->first); 188 | } 189 | } 190 | _charAlphabet.set_fixed_flag(true); 191 | } 192 | 193 | void addToBiCharAlphabet(hash_map bichar_stat, int bicharCutOff = 0) { 194 | _bicharAlphabet.set_fixed_flag(false); 195 | hash_map::iterator bichar_iter; 196 | for (bichar_iter = bichar_stat.begin(); bichar_iter != bichar_stat.end(); bichar_iter++) { 197 | if (bichar_iter->second > bicharCutOff) { 198 | _bicharAlphabet.from_string(bichar_iter->first); 199 | } 200 | } 201 | _bicharAlphabet.set_fixed_flag(true); 202 | } 203 | 204 | void addToActionAlphabet(hash_map action_stat) { 205 | _actionAlphabet.set_fixed_flag(false); 206 | hash_map::iterator action_iter; 207 | for (action_iter = action_stat.begin(); action_iter != action_stat.end(); action_iter++) { 208 | _actionAlphabet.from_string(action_iter->first); 209 | } 210 | _actionAlphabet.set_fixed_flag(true); 211 | } 212 | 213 | void initAlphabet() { 214 | //alphabet initialization 215 | _featAlphabet.clear(); 216 | _featAlphabet.set_fixed_flag(true); 217 | 218 | _wordAlphabet.clear(); 219 | _wordAlphabet.from_string(nullkey); 220 | _wordAlphabet.from_string(unknownkey); 221 | _wordAlphabet.set_fixed_flag(true); 222 | 223 | _allwordAlphabet.clear(); 224 | _allwordAlphabet.from_string(nullkey); 225 | _allwordAlphabet.from_string(unknownkey); 226 | _allwordAlphabet.set_fixed_flag(true); 227 | 228 | _charAlphabet.clear(); 229 | _charAlphabet.from_string(nullkey); 
230 | _charAlphabet.from_string(unknownkey); 231 | _charAlphabet.set_fixed_flag(true); 232 | 233 | _bicharAlphabet.clear(); 234 | _bicharAlphabet.from_string(nullkey); 235 | _bicharAlphabet.from_string(unknownkey); 236 | _bicharAlphabet.set_fixed_flag(true); 237 | 238 | _actionAlphabet.clear(); 239 | _actionAlphabet.from_string(nullkey); 240 | _actionAlphabet.set_fixed_flag(true); 241 | 242 | _bStringFeat = true; 243 | } 244 | 245 | void loadAlphabet() { 246 | _bStringFeat = false; 247 | } 248 | 249 | protected: 250 | void extractFeatureApp(const CStateItem * curState, Feature& feat, int wordNgram = 0, int actionNgram = 0) { 251 | string curWord = curState->_lastWordEnd == -1 ? nullkey : curState->_strlastWord; 252 | const std::vector * pCharacters = curState->_pCharacters; 253 | string nextChar = pCharacters->at(curState->_nextPosition); 254 | string curWordLastChar = curState->_lastWordEnd == -1 ? nullkey : pCharacters->at(curState->_lastWordEnd); 255 | string curWordLast2Char = curState->_lastWordEnd < 1 ? nullkey : pCharacters->at(curState->_lastWordEnd - 1); 256 | string curWordFirstChar = curState->_lastWordEnd == -1 ? nullkey : pCharacters->at(curState->_lastWordStart); 257 | 258 | string curWordLastCharType = curState->_lastWordEnd == -1 ? nullkey : wordtype(curWordLastChar); 259 | string curWordLast2CharType = curState->_lastWordEnd < 1 ? 
nullkey : wordtype(curWordLast2Char); 260 | string nextCharType = wordtype(nextChar); 261 | 262 | if (curState->_nextPosition != curState->_lastWordEnd + 1) { 263 | std::cout << "position error" << std::endl; 264 | } 265 | 266 | string strFeat = ""; 267 | strFeat = "F01" + seperateKey + curWordLastChar + seperateKey + nextChar; 268 | feat._strSparseFeat.push_back(strFeat); 269 | 270 | strFeat = "F02" + seperateKey + curWordFirstChar + seperateKey + nextChar; 271 | feat._strSparseFeat.push_back(strFeat); 272 | 273 | /* 274 | strFeat = "F03" + seperateKey + nextCharType; 275 | feat._strSparseFeat.push_back(strFeat); 276 | 277 | strFeat = "F04" + seperateKey + curWordLastCharType + nextCharType; 278 | feat._strSparseFeat.push_back(strFeat); 279 | 280 | */ 281 | strFeat = "F05" + seperateKey + curWordLast2CharType + curWordLastCharType + nextCharType; 282 | feat._strSparseFeat.push_back(strFeat); 283 | 284 | int ngram = 0; 285 | 286 | if (ngram < actionNgram) { 287 | feat._strActionFeat.push_back(CAction(CAction::APP).str()); 288 | ngram++; 289 | const CStateItem* theState = curState; 290 | while (theState != NULL && ngram < actionNgram) { 291 | feat._strActionFeat.push_back(theState->_lastAction.str()); 292 | theState = theState->_prevState; 293 | ngram++; 294 | } 295 | } 296 | 297 | } 298 | 299 | void extractFeatureSep(const CStateItem * curState, Feature& feat, int wordNgram = 0, int actionNgram = 0) { 300 | string curWord = curState->_lastWordEnd == -1 ? nullkey : curState->_strlastWord; 301 | const std::vector * pCharacters = curState->_pCharacters; 302 | string nextChar = pCharacters->at(curState->_nextPosition); 303 | string curWordLastChar = curState->_lastWordEnd == -1 ? nullkey : pCharacters->at(curState->_lastWordEnd); 304 | string curWordLast2Char = curState->_lastWordEnd < 1 ? nullkey : pCharacters->at(curState->_lastWordEnd - 1); 305 | string curWordFirstChar = curState->_lastWordEnd == -1 ? 
nullkey : pCharacters->at(curState->_lastWordStart); 306 | 307 | string curWordLastCharType = curState->_lastWordEnd == -1 ? nullkey : wordtype(curWordLastChar); 308 | string curWordLast2CharType = curState->_lastWordEnd < 1 ? nullkey : wordtype(curWordLast2Char); 309 | string nextCharType = wordtype(nextChar); 310 | 311 | int length = curState->_lastWordEnd - curState->_lastWordStart + 1; 312 | if (length > 5) 313 | length = 5; 314 | stringstream curss; 315 | curss << length; 316 | string strCurWordLen = curss.str(); 317 | 318 | if (curState->_nextPosition != curState->_lastWordEnd + 1) { 319 | std::cout << "position error" << std::endl; 320 | } 321 | 322 | const CStateItem * preStackState = curState->_prevStackState; 323 | string pre1Word = preStackState == 0 || preStackState->_lastWordEnd == -1 ? nullkey : preStackState->_strlastWord; 324 | string pre1WordLastChar = preStackState == 0 || preStackState->_lastWordEnd == -1 ? nullkey : pCharacters->at(preStackState->_lastWordEnd); 325 | string pre1WordFirstChar = preStackState == 0 || preStackState->_lastWordEnd == -1 ? nullkey : pCharacters->at(preStackState->_lastWordStart); 326 | 327 | length = preStackState == 0 ? 
0 : preStackState->_lastWordEnd - preStackState->_lastWordStart + 1; 328 | if (length > 5) 329 | length = 5; 330 | stringstream press; 331 | press << length; 332 | string strPreWordLen = press.str(); 333 | 334 | string strFeat = ""; 335 | strFeat = "F11" + seperateKey + curWordLastChar + seperateKey + nextChar; 336 | feat._strSparseFeat.push_back(strFeat); 337 | 338 | strFeat = "F12" + seperateKey + curWordFirstChar + seperateKey + nextChar; 339 | feat._strSparseFeat.push_back(strFeat); 340 | 341 | /* 342 | strFeat = "F13" + seperateKey + nextCharType; 343 | feat._strSparseFeat.push_back(strFeat); 344 | 345 | strFeat = "F14" + seperateKey + curWordLastCharType + nextCharType; 346 | feat._strSparseFeat.push_back(strFeat); 347 | */ 348 | strFeat = "F15" + seperateKey + curWordLast2CharType + curWordLastCharType + nextCharType; 349 | feat._strSparseFeat.push_back(strFeat); 350 | 351 | strFeat = "F16" + seperateKey + curWord + seperateKey + nextChar; 352 | feat._strSparseFeat.push_back(strFeat); 353 | 354 | strFeat = "F17" + seperateKey + curWord; 355 | feat._strSparseFeat.push_back(strFeat); 356 | 357 | strFeat = "F18" + seperateKey + curWord + seperateKey + pre1Word; 358 | feat._strSparseFeat.push_back(strFeat); 359 | 360 | strFeat = "F19" + seperateKey + curWord + seperateKey + pre1WordLastChar; 361 | feat._strSparseFeat.push_back(strFeat); 362 | 363 | strFeat = "F20" + seperateKey + curWord + seperateKey + pre1WordFirstChar; 364 | feat._strSparseFeat.push_back(strFeat); 365 | 366 | strFeat = "F21" + seperateKey + curWordLastChar + seperateKey + pre1WordLastChar; 367 | feat._strSparseFeat.push_back(strFeat); 368 | 369 | strFeat = "F22" + seperateKey + curWordFirstChar + seperateKey + curWordLastChar; 370 | feat._strSparseFeat.push_back(strFeat); 371 | 372 | strFeat = "F23" + seperateKey + curWord + seperateKey + strPreWordLen; 373 | feat._strSparseFeat.push_back(strFeat); 374 | 375 | strFeat = "F24" + seperateKey + pre1Word + seperateKey + strCurWordLen; 376 | 
feat._strSparseFeat.push_back(strFeat); 377 | 378 | strFeat = "F25" + seperateKey + pre1Word + seperateKey + curWordLastChar; 379 | feat._strSparseFeat.push_back(strFeat); 380 | 381 | if (curState->_lastWordStart != -1) { 382 | for (int idx = curState->_lastWordStart; idx < curState->_lastWordEnd; idx++) { 383 | strFeat = "F26" + seperateKey + pCharacters->at(idx) + seperateKey + curWordLastChar; 384 | feat._strSparseFeat.push_back(strFeat); 385 | } 386 | } 387 | 388 | if (curState->_lastWordEnd == curState->_lastWordStart && curState->_lastWordStart != -1) { 389 | strFeat = "F27" + seperateKey + curWord; 390 | feat._strSparseFeat.push_back(strFeat); 391 | } 392 | 393 | strFeat = "F28" + seperateKey + curWordFirstChar + seperateKey + strCurWordLen; 394 | feat._strSparseFeat.push_back(strFeat); 395 | 396 | strFeat = "F29" + seperateKey + curWordLastChar + seperateKey + strCurWordLen; 397 | feat._strSparseFeat.push_back(strFeat); 398 | 399 | int ngram = 0; 400 | 401 | if (ngram < actionNgram) { 402 | feat._strActionFeat.push_back(CAction(CAction::SEP).str()); 403 | ngram++; 404 | const CStateItem* theState = curState; 405 | while (theState != NULL && ngram < actionNgram) { 406 | feat._strActionFeat.push_back(theState->_lastAction.str()); 407 | theState = theState->_prevState; 408 | ngram++; 409 | } 410 | } 411 | 412 | ngram = 0; 413 | 414 | if (ngram < wordNgram) { 415 | feat._strWordFeat.push_back(curWord); 416 | ngram++; 417 | const CStateItem* theState = preStackState; 418 | while (theState != NULL && theState->_lastWordEnd >= 0 && ngram < actionNgram) { 419 | feat._strWordFeat.push_back(theState->_strlastWord); 420 | theState = theState->_prevStackState; 421 | ngram++; 422 | } 423 | } 424 | 425 | } 426 | 427 | void extractFeatureFinish(const CStateItem * curState, Feature& feat, int wordNgram = 0, int actionNgram = 0) { 428 | string curWord = curState->_lastWordEnd == -1 ? 
nullkey : curState->_strlastWord; 429 | const std::vector * pCharacters = curState->_pCharacters; 430 | //string nextChar = pCharacters->at(curState->_nextPosition); 431 | string curWordLastChar = curState->_lastWordEnd == -1 ? nullkey : pCharacters->at(curState->_lastWordEnd); 432 | string curWordLast2Char = curState->_lastWordEnd < 1 ? nullkey : pCharacters->at(curState->_lastWordEnd - 1); 433 | string curWordFirstChar = curState->_lastWordEnd == -1 ? nullkey : pCharacters->at(curState->_lastWordStart); 434 | 435 | int length = curState->_lastWordEnd - curState->_lastWordStart + 1; 436 | if (length > 5) 437 | length = 5; 438 | stringstream curss; 439 | curss << length; 440 | string strCurWordLen = curss.str(); 441 | 442 | if (curState->_nextPosition != curState->_lastWordEnd + 1) { 443 | std::cout << "position error" << std::endl; 444 | } 445 | 446 | const CStateItem * preStackState = curState->_prevStackState; 447 | string pre1Word = preStackState == 0 || preStackState->_lastWordEnd == -1 ? nullkey : preStackState->_strlastWord; 448 | string pre1WordLastChar = preStackState == 0 || preStackState->_lastWordEnd == -1 ? nullkey : pCharacters->at(preStackState->_lastWordEnd); 449 | string pre1WordFirstChar = preStackState == 0 || preStackState->_lastWordEnd == -1 ? nullkey : pCharacters->at(preStackState->_lastWordStart); 450 | 451 | length = preStackState == 0 ? 
0 : preStackState->_lastWordEnd - preStackState->_lastWordStart + 1; 452 | if (length > 5) 453 | length = 5; 454 | stringstream press; 455 | press << length; 456 | string strPreWordLen = press.str(); 457 | 458 | string strFeat = ""; 459 | /*strFeat = "F11" + seperateKey + curWordLastChar + seperateKey + nullkey; 460 | feat._strSparseFeat.push_back(strFeat); 461 | 462 | strFeat = "F12" + seperateKey + curWordFirstChar + seperateKey + nullkey; 463 | feat._strSparseFeat.push_back(strFeat); 464 | 465 | strFeat = "F13" + seperateKey + nextCharType; 466 | feat._strSparseFeat.push_back(strFeat); 467 | 468 | strFeat = "F14" + seperateKey + curWordLastCharType + nextCharType; 469 | feat._strSparseFeat.push_back(strFeat); 470 | 471 | strFeat = "F15" + seperateKey + curWordLast2CharType + curWordLastCharType + nextCharType; 472 | feat._strSparseFeat.push_back(strFeat); 473 | */ 474 | 475 | strFeat = "F16" + seperateKey + curWord + seperateKey + nullkey; 476 | feat._strSparseFeat.push_back(strFeat); 477 | 478 | strFeat = "F17" + seperateKey + curWord; 479 | feat._strSparseFeat.push_back(strFeat); 480 | 481 | strFeat = "F18" + seperateKey + curWord + seperateKey + pre1Word; 482 | feat._strSparseFeat.push_back(strFeat); 483 | 484 | strFeat = "F19" + seperateKey + curWord + seperateKey + pre1WordLastChar; 485 | feat._strSparseFeat.push_back(strFeat); 486 | 487 | strFeat = "F20" + seperateKey + curWord + seperateKey + pre1WordFirstChar; 488 | feat._strSparseFeat.push_back(strFeat); 489 | 490 | strFeat = "F21" + seperateKey + curWordLastChar + seperateKey + pre1WordLastChar; 491 | feat._strSparseFeat.push_back(strFeat); 492 | 493 | strFeat = "F22" + seperateKey + curWordFirstChar + seperateKey + curWordLastChar; 494 | feat._strSparseFeat.push_back(strFeat); 495 | 496 | strFeat = "F23" + seperateKey + curWord + seperateKey + strPreWordLen; 497 | feat._strSparseFeat.push_back(strFeat); 498 | 499 | strFeat = "F24" + seperateKey + pre1Word + seperateKey + strCurWordLen; 500 | 
feat._strSparseFeat.push_back(strFeat); 501 | 502 | strFeat = "F25" + seperateKey + pre1Word + seperateKey + curWordLastChar; 503 | feat._strSparseFeat.push_back(strFeat); 504 | 505 | if (curState->_lastWordStart != -1) { 506 | for (int idx = curState->_lastWordStart; idx < curState->_lastWordEnd; idx++) { 507 | strFeat = "F26" + seperateKey + pCharacters->at(idx) + seperateKey + curWordLastChar; 508 | feat._strSparseFeat.push_back(strFeat); 509 | } 510 | } 511 | 512 | if (curState->_lastWordEnd == curState->_lastWordStart && curState->_lastWordStart != -1) { 513 | strFeat = "F27" + seperateKey + curWord; 514 | feat._strSparseFeat.push_back(strFeat); 515 | } 516 | 517 | strFeat = "F28" + seperateKey + curWordFirstChar + seperateKey + strCurWordLen; 518 | feat._strSparseFeat.push_back(strFeat); 519 | 520 | strFeat = "F29" + seperateKey + curWordLastChar + seperateKey + strCurWordLen; 521 | feat._strSparseFeat.push_back(strFeat); 522 | 523 | int ngram = 0; 524 | 525 | if (ngram < actionNgram) { 526 | feat._strActionFeat.push_back(CAction(CAction::FIN).str()); 527 | ngram++; 528 | const CStateItem* theState = curState; 529 | while (theState != NULL && ngram < actionNgram) { 530 | feat._strActionFeat.push_back(theState->_lastAction.str()); 531 | theState = theState->_prevState; 532 | ngram++; 533 | } 534 | } 535 | 536 | ngram = 0; 537 | 538 | if (ngram < wordNgram) { 539 | feat._strWordFeat.push_back(curWord); 540 | ngram++; 541 | const CStateItem* theState = preStackState; 542 | while (theState != NULL && theState->_lastWordEnd >= 0 && ngram < actionNgram) { 543 | feat._strWordFeat.push_back(theState->_strlastWord); 544 | theState = theState->_prevStackState; 545 | ngram++; 546 | } 547 | } 548 | 549 | } 550 | 551 | }; 552 | 553 | #endif /* BASIC_FEATUREEXTRACTION_H_ */ 554 | -------------------------------------------------------------------------------- /model/LinearBeamSearcher.h: -------------------------------------------------------------------------------- 1 | 
/* 2 | * LinearBeamSearcher.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_LinearBeamSearcher_H_ 9 | #define SRC_LinearBeamSearcher_H_ 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | #include "Feature.h" 16 | #include "FeatureExtraction.h" 17 | #include "N3L.h" 18 | #include "State.h" 19 | #include "Action.h" 20 | 21 | using namespace nr; 22 | using namespace std; 23 | using namespace mshadow; 24 | using namespace mshadow::expr; 25 | using namespace mshadow::utils; 26 | 27 | //re-implementation of Yue and Clark ACL (2007) 28 | template 29 | class LinearBeamSearcher { 30 | public: 31 | LinearBeamSearcher() { 32 | _dropOut = 0.5; 33 | _delta = 0.2; 34 | } 35 | ~LinearBeamSearcher() { 36 | } 37 | 38 | public: 39 | SparseUniLayer1O _splayer_output; 40 | 41 | FeatureExtraction fe; 42 | 43 | int _linearfeatSize; 44 | 45 | Metric _eval; 46 | 47 | dtype _dropOut; 48 | 49 | dtype _delta; 50 | 51 | enum { 52 | BEAM_SIZE = 16, MAX_SENTENCE_SIZE = 512 53 | }; 54 | 55 | public: 56 | 57 | inline void addToFeatureAlphabet(hash_map feat_stat, int featCutOff = 0) { 58 | fe.addToFeatureAlphabet(feat_stat, featCutOff); 59 | } 60 | 61 | inline void addToWordAlphabet(hash_map word_stat, int wordCutOff = 0) { 62 | fe.addToWordAlphabet(word_stat, wordCutOff); 63 | } 64 | 65 | inline void addToCharAlphabet(hash_map char_stat, int charCutOff = 0) { 66 | fe.addToCharAlphabet(char_stat, charCutOff); 67 | } 68 | 69 | inline void addToBiCharAlphabet(hash_map bichar_stat, int bicharCutOff = 0) { 70 | fe.addToBiCharAlphabet(bichar_stat, bicharCutOff); 71 | } 72 | 73 | inline void addToActionAlphabet(hash_map action_stat) { 74 | fe.addToActionAlphabet(action_stat); 75 | } 76 | 77 | inline void setAlphaIncreasing(bool bAlphaIncreasing) { 78 | fe.setAlphaIncreasing(bAlphaIncreasing); 79 | } 80 | 81 | inline void initAlphabet() { 82 | fe.initAlphabet(); 83 | } 84 | 85 | inline void loadAlphabet() { 86 | fe.loadAlphabet(); 87 | } 88 | 89 | inline void 
extractFeature(const CStateItem * curState, const CAction& nextAC, Feature& feat) { 90 | fe.extractFeature(curState, nextAC, feat); 91 | } 92 | 93 | public: 94 | 95 | inline void init() { 96 | _linearfeatSize = 3*fe._featAlphabet.size(); 97 | 98 | _splayer_output.initial(_linearfeatSize, 10); 99 | } 100 | 101 | inline void release() { 102 | _splayer_output.release(); 103 | } 104 | 105 | dtype train(const std::vector >& sentences, const vector >& goldACs) { 106 | fe.setFeatureFormat(false); 107 | setAlphaIncreasing(true); 108 | _eval.reset(); 109 | dtype cost = 0.0; 110 | for (int idx = 0; idx < sentences.size(); idx++) { 111 | cost += trainOneExample(sentences[idx], goldACs[idx], sentences.size()); 112 | } 113 | 114 | return cost; 115 | } 116 | 117 | // scores do not accumulate together...., big bug, refine it tomorrow or at thursday. 118 | dtype trainOneExample(const std::vector& sentence, const vector& goldAC, int num) { 119 | if (sentence.size() >= MAX_SENTENCE_SIZE) 120 | return 0.0; 121 | static CStateItem lattice[(MAX_SENTENCE_SIZE + 1) * (BEAM_SIZE + 1)]; 122 | static CStateItem * lattice_index[MAX_SENTENCE_SIZE + 1]; 123 | 124 | int length = sentence.size(); 125 | dtype cost = 0.0; 126 | dtype score = 0.0; 127 | 128 | const static CStateItem *pGenerator; 129 | const static CStateItem *pBestGen; 130 | static CStateItem *correctState; 131 | 132 | bool bCorrect; // used in learning for early update 133 | int index, tmp_i, tmp_j; 134 | CAction correct_action; 135 | bool correct_action_scored; 136 | std::vector actions; // actions to apply for a candidate 137 | static NRHeap beam(BEAM_SIZE); 138 | static CScoredStateAction scored_action; // used rank actions 139 | static CScoredStateAction scored_correct_action; 140 | 141 | lattice_index[0] = lattice; 142 | lattice_index[1] = lattice + 1; 143 | lattice_index[0]->clear(); 144 | lattice_index[0]->initSentence(&sentence); 145 | 146 | index = 0; 147 | 148 | correctState = lattice_index[0]; 149 | 150 | while (true) { 
151 | ++index; 152 | lattice_index[index + 1] = lattice_index[index]; 153 | beam.clear(); 154 | pBestGen = 0; 155 | correct_action = goldAC[index - 1]; 156 | bCorrect = false; 157 | correct_action_scored = false; 158 | 159 | //std::cout << "check beam start" << std::endl; 160 | for (pGenerator = lattice_index[index - 1]; pGenerator != lattice_index[index]; ++pGenerator) { 161 | //std::cout << "new" << std::endl; 162 | //std::cout << pGenerator->str() << std::endl; 163 | pGenerator->getCandidateActions(actions); 164 | for (tmp_j = 0; tmp_j < actions.size(); ++tmp_j) { 165 | scored_action.action = actions[tmp_j]; 166 | scored_action.item = pGenerator; 167 | fe.extractFeature(pGenerator, actions[tmp_j], scored_action.feat); 168 | _splayer_output.ComputeForwardScore(scored_action.feat._nSparseFeat, scored_action.score); 169 | //std::cout << "add start, action = " << actions[tmp_j] << ", cur ac score = " << scored_action.score << ", orgin score = " << pGenerator->_score << std::endl;; 170 | scored_action.score += pGenerator->_score; 171 | if(actions[tmp_j] != correct_action){ 172 | scored_action.score += _delta; 173 | } 174 | 175 | beam.add_elem(scored_action); 176 | 177 | //std::cout << "new scored_action : " << scored_action.score << ", action = " << scored_action.action << ", state = " << scored_action.item->str() << std::endl; 178 | //for (int tmp_k = 0; tmp_k < beam.elemsize(); ++tmp_k) { 179 | // std::cout << tmp_k << ": " << beam[tmp_k].score << ", action = " << beam[tmp_k].action << ", state = " << beam[tmp_k].item->str() << std::endl; 180 | //} 181 | 182 | if (pGenerator == correctState && actions[tmp_j] == correct_action) { 183 | scored_correct_action = scored_action; 184 | correct_action_scored = true; 185 | //std::cout << "add gold finish" << std::endl; 186 | } else { 187 | //std::cout << "add finish" << std::endl; 188 | } 189 | 190 | } 191 | } 192 | 193 | //std::cout << "check beam finish" << std::endl; 194 | 195 | if (beam.elemsize() == 0) { 196 | 
std::cout << "error" << std::endl; 197 | for (int idx = 0; idx < sentence.size(); idx++) { 198 | std::cout << sentence[idx] << std::endl; 199 | } 200 | std::cout << "" << std::endl; 201 | return -1; 202 | } 203 | 204 | //std::cout << "check beam start" << std::endl; 205 | for (tmp_j = 0; tmp_j < beam.elemsize(); ++tmp_j) { // insert from 206 | pGenerator = beam[tmp_j].item; 207 | pGenerator->move(lattice_index[index + 1], beam[tmp_j].action); 208 | lattice_index[index + 1]->_score = beam[tmp_j].score; 209 | lattice_index[index + 1]->_curFeat.copy(beam[tmp_j].feat); 210 | 211 | //std::cout << tmp_j << ": " << beam[tmp_j].score << std::endl; 212 | 213 | if (pBestGen == 0 || lattice_index[index + 1]->_score > pBestGen->_score) { 214 | pBestGen = lattice_index[index + 1]; 215 | } 216 | if (pGenerator == correctState && beam[tmp_j].action == correct_action) { 217 | correctState = lattice_index[index + 1]; 218 | bCorrect = true; 219 | } 220 | 221 | ++lattice_index[index + 1]; 222 | } 223 | //std::cout << "check beam finish" << std::endl; 224 | 225 | if (pBestGen->IsTerminated()) 226 | break; // while 227 | 228 | // update items if correct item jump out of the agenda 229 | 230 | if (!bCorrect) { 231 | // note that if bCorrect == true then the correct state has 232 | // already been updated, and the new value is one of the new states 233 | // among the newly produced from lattice[index+1]. 
234 | correctState->move(lattice_index[index + 1], correct_action); 235 | correctState = lattice_index[index + 1]; 236 | lattice_index[index + 1]->_score = scored_correct_action.score; 237 | lattice_index[index + 1]->_curFeat.copy(scored_correct_action.feat); 238 | 239 | ++lattice_index[index + 1]; 240 | assert(correct_action_scored); // scored_correct_act valid 241 | //TRACE(index << " updated"); 242 | //std::cout << index << " updated" << std::endl; 243 | 244 | cost = backPropagationStates(pBestGen, correctState, 1.0/num, -1.0/num); 245 | if (cost < 0) { 246 | std::cout << "strange ..." << std::endl; 247 | } 248 | _eval.correct_label_count += index; 249 | _eval.overall_label_count += length + 1; 250 | return cost; 251 | } 252 | 253 | } 254 | 255 | // make sure that the correct item is stack top finally 256 | if (pBestGen != correctState) { 257 | if (!bCorrect) { 258 | correctState->move(lattice_index[index + 1], correct_action); 259 | correctState = lattice_index[index + 1]; 260 | lattice_index[index + 1]->_score = scored_correct_action.score; 261 | lattice_index[index + 1]->_curFeat.copy(scored_correct_action.feat); 262 | assert(correct_action_scored); // scored_correct_act valid 263 | } 264 | 265 | //std::cout << "best:" << pBestGen->str() << std::endl; 266 | //std::cout << "gold:" << correctState->str() << std::endl; 267 | 268 | cost = backPropagationStates(pBestGen, correctState, 1.0/num, -1.0/num); 269 | if (cost < 0) { 270 | std::cout << "strange ..." 
<< std::endl; 271 | } 272 | _eval.correct_label_count += length; 273 | _eval.overall_label_count += length + 1; 274 | } else { 275 | _eval.correct_label_count += length + 1; 276 | _eval.overall_label_count += length + 1; 277 | } 278 | 279 | return cost; 280 | } 281 | 282 | dtype backPropagationStates(const CStateItem *pPredState, const CStateItem *pGoldState, dtype predLoss, dtype goldLoss) { 283 | if (pPredState == pGoldState) 284 | return 0.0; 285 | 286 | if(pPredState->_nextPosition != pGoldState->_nextPosition){ 287 | std::cout << "state align error" << std::endl; 288 | } 289 | dtype delta = 0.0; 290 | dtype predscore, goldscore; 291 | _splayer_output.ComputeForwardScore(pPredState->_curFeat._nSparseFeat, predscore); 292 | _splayer_output.ComputeForwardScore(pGoldState->_curFeat._nSparseFeat, goldscore); 293 | 294 | delta = predscore - goldscore; 295 | if(pPredState->_lastAction != pGoldState->_lastAction){ 296 | delta += _delta; 297 | } 298 | 299 | _splayer_output.ComputeBackwardLoss(pPredState->_curFeat._nSparseFeat, predLoss); 300 | _splayer_output.ComputeBackwardLoss(pGoldState->_curFeat._nSparseFeat, goldLoss); 301 | 302 | //currently we use a uniform loss 303 | delta += backPropagationStates(pPredState->_prevState, pGoldState->_prevState, predLoss, goldLoss); 304 | 305 | dtype compare_delta = pPredState->_score - pGoldState->_score; 306 | if (abs(delta - compare_delta) > 0.01) { 307 | std::cout << "delta=" << delta << "\t, compare_delta=" << compare_delta << std::endl; 308 | } 309 | 310 | return delta; 311 | } 312 | 313 | bool decode(const std::vector& sentence, std::vector& words) { 314 | setAlphaIncreasing(false); 315 | if (sentence.size() >= MAX_SENTENCE_SIZE) 316 | return false; 317 | static CStateItem lattice[(MAX_SENTENCE_SIZE + 1) * (BEAM_SIZE + 1)]; 318 | static CStateItem *lattice_index[MAX_SENTENCE_SIZE + 1]; 319 | 320 | int length = sentence.size(); 321 | dtype cost = 0.0; 322 | dtype score = 0.0; 323 | 324 | const static CStateItem 
*pGenerator; 325 | const static CStateItem *pBestGen; 326 | 327 | int index, tmp_i, tmp_j; 328 | std::vector actions; // actions to apply for a candidate 329 | static NRHeap beam(BEAM_SIZE); 330 | static CScoredStateAction scored_action; // used rank actions 331 | static Feature feat; 332 | 333 | lattice_index[0] = lattice; 334 | lattice_index[1] = lattice + 1; 335 | lattice_index[0]->clear(); 336 | lattice_index[0]->initSentence(&sentence); 337 | 338 | index = 0; 339 | 340 | while (true) { 341 | ++index; 342 | lattice_index[index + 1] = lattice_index[index]; 343 | beam.clear(); 344 | pBestGen = 0; 345 | 346 | //std::cout << index << std::endl; 347 | for (pGenerator = lattice_index[index - 1]; pGenerator != lattice_index[index]; ++pGenerator) { 348 | pGenerator->getCandidateActions(actions); 349 | for (tmp_j = 0; tmp_j < actions.size(); ++tmp_j) { 350 | scored_action.action = actions[tmp_j]; 351 | scored_action.item = pGenerator; 352 | fe.extractFeature(pGenerator, actions[tmp_j], scored_action.feat); 353 | _splayer_output.ComputeForwardScore(scored_action.feat._nSparseFeat, scored_action.score); 354 | scored_action.score += pGenerator->_score; 355 | beam.add_elem(scored_action); 356 | } 357 | 358 | } 359 | 360 | if (beam.elemsize() == 0) { 361 | std::cout << "error" << std::endl; 362 | for (int idx = 0; idx < sentence.size(); idx++) { 363 | std::cout << sentence[idx] << std::endl; 364 | } 365 | std::cout << "" << std::endl; 366 | return false; 367 | } 368 | 369 | for (tmp_j = 0; tmp_j < beam.elemsize(); ++tmp_j) { // insert from 370 | pGenerator = beam[tmp_j].item; 371 | pGenerator->move(lattice_index[index + 1], beam[tmp_j].action); 372 | lattice_index[index + 1]->_score = beam[tmp_j].score; 373 | 374 | if (pBestGen == 0 || lattice_index[index + 1]->_score > pBestGen->_score) { 375 | pBestGen = lattice_index[index + 1]; 376 | } 377 | 378 | ++lattice_index[index + 1]; 379 | } 380 | 381 | if (pBestGen->IsTerminated()) 382 | break; // while 383 | 384 | } 385 | 
pBestGen->getSegResults(words); 386 | 387 | return true; 388 | } 389 | 390 | void updateParams(dtype nnRegular, dtype adaAlpha, dtype adaEps) { 391 | _splayer_output.updateAdaGrad(nnRegular, adaAlpha, adaEps); 392 | } 393 | 394 | void writeModel(); 395 | 396 | void loadModel(); 397 | 398 | public: 399 | 400 | inline void resetEval() { 401 | _eval.reset(); 402 | } 403 | 404 | inline void setDropValue(dtype dropOut) { 405 | _dropOut = dropOut; 406 | } 407 | 408 | }; 409 | 410 | #endif /* SRC_LinearBeamSearcher_H_ */ 411 | -------------------------------------------------------------------------------- /other-implementations/cpps/APSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "APBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | 22 | class Segmentor { 23 | public: 24 | std::string nullkey; 25 | std::string rootdepkey; 26 | std::string unknownkey; 27 | std::string paddingtag; 28 | std::string seperateKey; 29 | 30 | public: 31 | Segmentor(); 32 | virtual ~Segmentor(); 33 | 34 | public: 35 | 36 | #if USE_CUDA==1 37 | APBeamSearcher m_classifier; 38 | #else 39 | APBeamSearcher m_classifier; 40 | #endif 41 | 42 | Options m_options; 43 | 44 | Pipe m_pipe; 45 | 46 | public: 47 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 48 | 49 | void readWordClusters(const string& inFile); 50 | 51 | int createAlphabet(const vector& vecInsts); 52 | 53 | int addTestWordAlpha(const vector& vecInsts); 54 | 55 | public: 56 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 57 | const string& wordEmbFile); 58 | void predict(const Instance& input, vector& 
output); 59 | void test(const string& testFile, const string& outputFile, const string& modelFile); 60 | 61 | // static training 62 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 63 | 64 | 65 | public: 66 | 67 | void proceedOneStepForDecode(const Instance& inputTree, CStateItem& state, int& outlab); //may be merged with train in the future 68 | 69 | void writeModelFile(const string& outputModelFile); 70 | void loadModelFile(const string& inputModelFile); 71 | 72 | public: 73 | inline void getCandidateActions(const CStateItem &item, vector& actions) { 74 | 75 | } 76 | 77 | 78 | 79 | }; 80 | 81 | #endif /* SRC_PARSER_H_ */ 82 | -------------------------------------------------------------------------------- /other-implementations/cpps/GRNNSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "model/GRNNBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | class Segmentor { 22 | public: 23 | std::string nullkey; 24 | std::string rootdepkey; 25 | std::string unknownkey; 26 | std::string paddingtag; 27 | std::string seperateKey; 28 | 29 | public: 30 | Segmentor(); 31 | virtual ~Segmentor(); 32 | 33 | public: 34 | 35 | #if USE_CUDA==1 36 | GRNNBeamSearcher m_classifier; 37 | #else 38 | GRNNBeamSearcher m_classifier; 39 | #endif 40 | 41 | Options m_options; 42 | 43 | Pipe m_pipe; 44 | 45 | hash_map m_word_stat; 46 | 47 | public: 48 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 49 | 50 | void readWordClusters(const string& inFile); 51 | 52 | int createAlphabet(const vector& vecInsts); 53 | 54 | int addTestWordAlpha(const vector& vecInsts); 55 | 56 | int allWordAlphaEmb(const string& inFile, NRMat& emb); 57 
| 58 | public: 59 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 60 | const string& wordEmbFile, const string& charEmbFile, const string& bicharEmbFile); 61 | void predict(const Instance& input, vector& output); 62 | void test(const string& testFile, const string& outputFile, const string& modelFile); 63 | 64 | // static training 65 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 66 | 67 | public: 68 | void readEmbeddings(Alphabet &alpha, const string& inFile, NRMat& emb); 69 | 70 | void writeModelFile(const string& outputModelFile); 71 | void loadModelFile(const string& inputModelFile); 72 | 73 | public: 74 | 75 | 76 | }; 77 | 78 | #endif /* SRC_PARSER_H_ */ 79 | -------------------------------------------------------------------------------- /other-implementations/cpps/LSTMNASegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "model/LSTMNABeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | class Segmentor { 22 | public: 23 | std::string nullkey; 24 | std::string rootdepkey; 25 | std::string unknownkey; 26 | std::string paddingtag; 27 | std::string seperateKey; 28 | 29 | public: 30 | Segmentor(); 31 | virtual ~Segmentor(); 32 | 33 | public: 34 | 35 | #if USE_CUDA==1 36 | LSTMNABeamSearcher m_classifier; 37 | #else 38 | LSTMNABeamSearcher m_classifier; 39 | #endif 40 | 41 | Options m_options; 42 | 43 | Pipe m_pipe; 44 | 45 | hash_map m_word_stat; 46 | 47 | public: 48 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 49 | 50 | void readWordClusters(const string& inFile); 51 | 52 | int createAlphabet(const vector& 
vecInsts); 53 | 54 | int addTestWordAlpha(const vector& vecInsts); 55 | 56 | int allWordAlphaEmb(const string& inFile, NRMat& emb); 57 | 58 | public: 59 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 60 | const string& wordEmbFile, const string& charEmbFile, const string& bicharEmbFile); 61 | void predict(const Instance& input, vector& output); 62 | void test(const string& testFile, const string& outputFile, const string& modelFile); 63 | 64 | // static training 65 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 66 | 67 | public: 68 | void readEmbeddings(Alphabet &alpha, const string& inFile, NRMat& emb); 69 | 70 | void writeModelFile(const string& outputModelFile); 71 | void loadModelFile(const string& inputModelFile); 72 | 73 | public: 74 | 75 | 76 | }; 77 | 78 | #endif /* SRC_PARSER_H_ */ 79 | -------------------------------------------------------------------------------- /other-implementations/cpps/LSTMNBCSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "model/LSTMNBCBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | class Segmentor { 22 | public: 23 | std::string nullkey; 24 | std::string rootdepkey; 25 | std::string unknownkey; 26 | std::string paddingtag; 27 | std::string seperateKey; 28 | 29 | public: 30 | Segmentor(); 31 | virtual ~Segmentor(); 32 | 33 | public: 34 | 35 | #if USE_CUDA==1 36 | LSTMNBCBeamSearcher m_classifier; 37 | #else 38 | LSTMNBCBeamSearcher m_classifier; 39 | #endif 40 | 41 | Options m_options; 42 | 43 | Pipe m_pipe; 44 | 45 | hash_map m_word_stat; 46 | 47 | public: 48 | void 
readWordEmbeddings(const string& inFile, NRMat& wordEmb); 49 | 50 | void readWordClusters(const string& inFile); 51 | 52 | int createAlphabet(const vector& vecInsts); 53 | 54 | int addTestWordAlpha(const vector& vecInsts); 55 | 56 | int allWordAlphaEmb(const string& inFile, NRMat& emb); 57 | 58 | public: 59 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 60 | const string& wordEmbFile, const string& charEmbFile, const string& bicharEmbFile); 61 | void predict(const Instance& input, vector& output); 62 | void test(const string& testFile, const string& outputFile, const string& modelFile); 63 | 64 | // static training 65 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 66 | 67 | public: 68 | void readEmbeddings(Alphabet &alpha, const string& inFile, NRMat& emb); 69 | 70 | void writeModelFile(const string& outputModelFile); 71 | void loadModelFile(const string& inputModelFile); 72 | 73 | public: 74 | 75 | 76 | }; 77 | 78 | #endif /* SRC_PARSER_H_ */ 79 | -------------------------------------------------------------------------------- /other-implementations/cpps/LSTMNCSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "model/LSTMNCBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | class Segmentor { 22 | public: 23 | std::string nullkey; 24 | std::string rootdepkey; 25 | std::string unknownkey; 26 | std::string paddingtag; 27 | std::string seperateKey; 28 | 29 | public: 30 | Segmentor(); 31 | virtual ~Segmentor(); 32 | 33 | public: 34 | 35 | #if USE_CUDA==1 36 | LSTMNCBeamSearcher m_classifier; 37 | #else 38 | LSTMNCBeamSearcher 
m_classifier; 39 | #endif 40 | 41 | Options m_options; 42 | 43 | Pipe m_pipe; 44 | 45 | hash_map m_word_stat; 46 | 47 | public: 48 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 49 | 50 | void readWordClusters(const string& inFile); 51 | 52 | int createAlphabet(const vector& vecInsts); 53 | 54 | int addTestWordAlpha(const vector& vecInsts); 55 | 56 | int allWordAlphaEmb(const string& inFile, NRMat& emb); 57 | 58 | public: 59 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 60 | const string& wordEmbFile, const string& charEmbFile, const string& bicharEmbFile); 61 | void predict(const Instance& input, vector& output); 62 | void test(const string& testFile, const string& outputFile, const string& modelFile); 63 | 64 | // static training 65 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 66 | 67 | public: 68 | void readEmbeddings(Alphabet &alpha, const string& inFile, NRMat& emb); 69 | 70 | void writeModelFile(const string& outputModelFile); 71 | void loadModelFile(const string& inputModelFile); 72 | 73 | public: 74 | 75 | 76 | }; 77 | 78 | #endif /* SRC_PARSER_H_ */ 79 | -------------------------------------------------------------------------------- /other-implementations/cpps/LSTMNUCSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "model/LSTMNUCBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | class Segmentor { 22 | public: 23 | std::string nullkey; 24 | std::string rootdepkey; 25 | std::string unknownkey; 26 | std::string paddingtag; 27 | std::string seperateKey; 28 | 29 | public: 30 | Segmentor(); 31 | 
virtual ~Segmentor(); 32 | 33 | public: 34 | 35 | #if USE_CUDA==1 36 | LSTMNUCBeamSearcher m_classifier; 37 | #else 38 | LSTMNUCBeamSearcher m_classifier; 39 | #endif 40 | 41 | Options m_options; 42 | 43 | Pipe m_pipe; 44 | 45 | hash_map m_word_stat; 46 | 47 | public: 48 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 49 | 50 | void readWordClusters(const string& inFile); 51 | 52 | int createAlphabet(const vector& vecInsts); 53 | 54 | int addTestWordAlpha(const vector& vecInsts); 55 | 56 | int allWordAlphaEmb(const string& inFile, NRMat& emb); 57 | 58 | public: 59 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 60 | const string& wordEmbFile, const string& charEmbFile, const string& bicharEmbFile); 61 | void predict(const Instance& input, vector& output); 62 | void test(const string& testFile, const string& outputFile, const string& modelFile); 63 | 64 | // static training 65 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 66 | 67 | public: 68 | void readEmbeddings(Alphabet &alpha, const string& inFile, NRMat& emb); 69 | 70 | void writeModelFile(const string& outputModelFile); 71 | void loadModelFile(const string& inputModelFile); 72 | 73 | public: 74 | 75 | 76 | }; 77 | 78 | #endif /* SRC_PARSER_H_ */ 79 | -------------------------------------------------------------------------------- /other-implementations/cpps/LSTMNWSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "model/LSTMNWBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | class Segmentor { 22 | public: 23 | std::string nullkey; 24 | std::string 
rootdepkey; 25 | std::string unknownkey; 26 | std::string paddingtag; 27 | std::string seperateKey; 28 | 29 | public: 30 | Segmentor(); 31 | virtual ~Segmentor(); 32 | 33 | public: 34 | 35 | #if USE_CUDA==1 36 | LSTMNWBeamSearcher m_classifier; 37 | #else 38 | LSTMNWBeamSearcher m_classifier; 39 | #endif 40 | 41 | Options m_options; 42 | 43 | Pipe m_pipe; 44 | 45 | hash_map m_word_stat; 46 | 47 | public: 48 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 49 | 50 | void readWordClusters(const string& inFile); 51 | 52 | int createAlphabet(const vector& vecInsts); 53 | 54 | int addTestWordAlpha(const vector& vecInsts); 55 | 56 | int allWordAlphaEmb(const string& inFile, NRMat& emb); 57 | 58 | public: 59 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 60 | const string& wordEmbFile, const string& charEmbFile, const string& bicharEmbFile); 61 | void predict(const Instance& input, vector& output); 62 | void test(const string& testFile, const string& outputFile, const string& modelFile); 63 | 64 | // static training 65 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 66 | 67 | public: 68 | void readEmbeddings(Alphabet &alpha, const string& inFile, NRMat& emb); 69 | 70 | void writeModelFile(const string& outputModelFile); 71 | void loadModelFile(const string& inputModelFile); 72 | 73 | public: 74 | 75 | 76 | }; 77 | 78 | #endif /* SRC_PARSER_H_ */ 79 | -------------------------------------------------------------------------------- /other-implementations/cpps/RNNSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 | 13 | #include "model/RNNBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | 
using namespace nr; 19 | using namespace std; 20 | 21 | class Segmentor { 22 | public: 23 | std::string nullkey; 24 | std::string rootdepkey; 25 | std::string unknownkey; 26 | std::string paddingtag; 27 | std::string seperateKey; 28 | 29 | public: 30 | Segmentor(); 31 | virtual ~Segmentor(); 32 | 33 | public: 34 | 35 | #if USE_CUDA==1 36 | RNNBeamSearcher m_classifier; 37 | #else 38 | RNNBeamSearcher m_classifier; 39 | #endif 40 | 41 | Options m_options; 42 | 43 | Pipe m_pipe; 44 | 45 | hash_map m_word_stat; 46 | 47 | public: 48 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 49 | 50 | void readWordClusters(const string& inFile); 51 | 52 | int createAlphabet(const vector& vecInsts); 53 | 54 | int addTestWordAlpha(const vector& vecInsts); 55 | 56 | int allWordAlphaEmb(const string& inFile, NRMat& emb); 57 | 58 | public: 59 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 60 | const string& wordEmbFile, const string& charEmbFile, const string& bicharEmbFile); 61 | void predict(const Instance& input, vector& output); 62 | void test(const string& testFile, const string& outputFile, const string& modelFile); 63 | 64 | // static training 65 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 66 | 67 | public: 68 | void readEmbeddings(Alphabet &alpha, const string& inFile, NRMat& emb); 69 | 70 | void writeModelFile(const string& outputModelFile); 71 | void loadModelFile(const string& inputModelFile); 72 | 73 | public: 74 | 75 | 76 | }; 77 | 78 | #endif /* SRC_PARSER_H_ */ 79 | -------------------------------------------------------------------------------- /other-implementations/cpps/TNNSegmentor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Segmentor.h 3 | * 4 | * Created on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_PARSER_H_ 9 | #define SRC_PARSER_H_ 10 | 11 | #include "N3L.h" 12 
| 13 | #include "model/TNNBeamSearcher.h" 14 | #include "Options.h" 15 | #include "Pipe.h" 16 | #include "Utf.h" 17 | 18 | using namespace nr; 19 | using namespace std; 20 | 21 | class Segmentor { 22 | public: 23 | std::string nullkey; 24 | std::string rootdepkey; 25 | std::string unknownkey; 26 | std::string paddingtag; 27 | std::string seperateKey; 28 | 29 | public: 30 | Segmentor(); 31 | virtual ~Segmentor(); 32 | 33 | public: 34 | 35 | #if USE_CUDA==1 36 | TNNBeamSearcher m_classifier; 37 | #else 38 | TNNBeamSearcher m_classifier; 39 | #endif 40 | 41 | Options m_options; 42 | 43 | Pipe m_pipe; 44 | 45 | hash_map m_word_stat; 46 | 47 | public: 48 | void readWordEmbeddings(const string& inFile, NRMat& wordEmb); 49 | 50 | void readWordClusters(const string& inFile); 51 | 52 | int createAlphabet(const vector& vecInsts); 53 | 54 | int addTestWordAlpha(const vector& vecInsts); 55 | 56 | int allWordAlphaEmb(const string& inFile, NRMat& emb); 57 | 58 | public: 59 | void train(const string& trainFile, const string& devFile, const string& testFile, const string& modelFile, const string& optionFile, 60 | const string& wordEmbFile, const string& charEmbFile, const string& bicharEmbFile); 61 | void predict(const Instance& input, vector& output); 62 | void test(const string& testFile, const string& outputFile, const string& modelFile); 63 | 64 | // static training 65 | void getGoldActions(const vector& vecInsts, vector >& vecActions); 66 | 67 | public: 68 | void readEmbeddings(Alphabet &alpha, const string& inFile, NRMat& emb); 69 | 70 | void writeModelFile(const string& outputModelFile); 71 | void loadModelFile(const string& inputModelFile); 72 | 73 | public: 74 | 75 | 76 | }; 77 | 78 | #endif /* SRC_PARSER_H_ */ 79 | -------------------------------------------------------------------------------- /other-implementations/models/APBeamSearcher.h: -------------------------------------------------------------------------------- 1 | /* 2 | * APBeamSearcher.h 3 | * 4 | * Created 
on: Jan 25, 2016 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SRC_APBeamSearcher_H_ 9 | #define SRC_APBeamSearcher_H_ 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | #include "Feature.h" 16 | #include "FeatureExtraction.h" 17 | #include "N3L.h" 18 | #include "State.h" 19 | #include "Action.h" 20 | 21 | using namespace nr; 22 | using namespace std; 23 | using namespace mshadow; 24 | using namespace mshadow::expr; 25 | using namespace mshadow::utils; 26 | 27 | //re-implementation of Yue and Clark ACL (2007) 28 | template 29 | class APBeamSearcher { 30 | 31 | public: 32 | APBeamSearcher() { 33 | _dropOut = 0.5; 34 | } 35 | ~APBeamSearcher() { 36 | } 37 | 38 | public: 39 | AvgPerceptron1O _splayer_output; 40 | 41 | FeatureExtraction fe; 42 | 43 | int _linearfeatSize; 44 | 45 | Metric _eval; 46 | 47 | dtype _dropOut; 48 | 49 | enum { 50 | BEAM_SIZE = 16, MAX_SENTENCE_SIZE = 512 51 | }; 52 | 53 | public: 54 | 55 | inline void addToFeatureAlphabet(hash_map feat_stat, int featCutOff = 0) { 56 | fe.addToFeatureAlphabet(feat_stat, featCutOff); 57 | } 58 | 59 | inline void addToWordAlphabet(hash_map word_stat, int wordCutOff = 0) { 60 | fe.addToWordAlphabet(word_stat, wordCutOff); 61 | } 62 | 63 | inline void addToCharAlphabet(hash_map char_stat, int charCutOff = 0) { 64 | fe.addToCharAlphabet(char_stat, charCutOff); 65 | } 66 | 67 | inline void addToBiCharAlphabet(hash_map bichar_stat, int bicharCutOff = 0) { 68 | fe.addToBiCharAlphabet(bichar_stat, bicharCutOff); 69 | } 70 | 71 | inline void addToActionAlphabet(hash_map action_stat) { 72 | fe.addToActionAlphabet(action_stat); 73 | } 74 | 75 | inline void setAlphaIncreasing(bool bAlphaIncreasing) { 76 | fe.setAlphaIncreasing(bAlphaIncreasing); 77 | } 78 | 79 | inline void initAlphabet() { 80 | fe.initAlphabet(); 81 | } 82 | 83 | inline void loadAlphabet() { 84 | fe.loadAlphabet(); 85 | } 86 | 87 | inline void extractFeature(const CStateItem * curState, const CAction& nextAC, Feature& feat) { 88 | 
fe.extractFeature(curState, nextAC, feat); 89 | } 90 | 91 | public: 92 | 93 | inline void init() { 94 | _linearfeatSize = 3 * fe._featAlphabet.size(); 95 | 96 | _splayer_output.initial(_linearfeatSize, 10); 97 | } 98 | 99 | inline void release() { 100 | _splayer_output.release(); 101 | } 102 | 103 | dtype train(const std::vector >& sentences, const vector >& goldACs) { 104 | fe.setFeatureFormat(false); 105 | setAlphaIncreasing(true); 106 | _eval.reset(); 107 | dtype cost = 0.0; 108 | for (int idx = 0; idx < sentences.size(); idx++) { 109 | cost += trainOneExample(sentences[idx], goldACs[idx], sentences.size()); 110 | } 111 | 112 | return cost; 113 | } 114 | 115 | // scores do not accumulate together...., big bug, refine it tomorrow or at thursday. 116 | dtype trainOneExample(const std::vector& sentence, const vector& goldAC, int num) { 117 | if (sentence.size() >= MAX_SENTENCE_SIZE) 118 | return 0.0; 119 | static CStateItem lattice[(MAX_SENTENCE_SIZE + 1) * (BEAM_SIZE + 1)]; 120 | static CStateItem * lattice_index[MAX_SENTENCE_SIZE + 1]; 121 | 122 | int length = sentence.size(); 123 | dtype cost = 0.0; 124 | dtype score = 0.0; 125 | 126 | const static CStateItem *pGenerator; 127 | const static CStateItem *pBestGen; 128 | static CStateItem *correctState; 129 | 130 | bool bCorrect; // used in learning for early update 131 | int index, tmp_i, tmp_j; 132 | CAction correct_action; 133 | bool correct_action_scored; 134 | std::vector actions; // actions to apply for a candidate 135 | static NRHeap beam(BEAM_SIZE); 136 | static CScoredStateAction scored_action; // used rank actions 137 | static CScoredStateAction scored_correct_action; 138 | 139 | lattice_index[0] = lattice; 140 | lattice_index[1] = lattice + 1; 141 | lattice_index[0]->clear(); 142 | lattice_index[0]->initSentence(&sentence); 143 | 144 | index = 0; 145 | 146 | correctState = lattice_index[0]; 147 | 148 | while (true) { 149 | ++index; 150 | lattice_index[index + 1] = lattice_index[index]; 151 | 
beam.clear(); 152 | pBestGen = 0; 153 | correct_action = goldAC[index - 1]; 154 | bCorrect = false; 155 | correct_action_scored = false; 156 | 157 | //std::cout << "check beam start" << std::endl; 158 | for (pGenerator = lattice_index[index - 1]; pGenerator != lattice_index[index]; ++pGenerator) { 159 | //std::cout << "new" << std::endl; 160 | //std::cout << pGenerator->str() << std::endl; 161 | pGenerator->getCandidateActions(actions); 162 | for (tmp_j = 0; tmp_j < actions.size(); ++tmp_j) { 163 | scored_action.action = actions[tmp_j]; 164 | scored_action.item = pGenerator; 165 | fe.extractFeature(pGenerator, actions[tmp_j], scored_action.feat); 166 | _splayer_output.ComputeForwardScore(scored_action.feat._nSparseFeat, scored_action.score, true); 167 | //std::cout << "add start, action = " << actions[tmp_j] << ", cur ac score = " << scored_action.score << ", orgin score = " << pGenerator->_score << std::endl;; 168 | scored_action.score += pGenerator->_score; 169 | beam.add_elem(scored_action); 170 | 171 | //std::cout << "new scored_action : " << scored_action.score << ", action = " << scored_action.action << ", state = " << scored_action.item->str() << std::endl; 172 | //for (int tmp_k = 0; tmp_k < beam.elemsize(); ++tmp_k) { 173 | // std::cout << tmp_k << ": " << beam[tmp_k].score << ", action = " << beam[tmp_k].action << ", state = " << beam[tmp_k].item->str() << std::endl; 174 | //} 175 | 176 | if (pGenerator == correctState && actions[tmp_j] == correct_action) { 177 | scored_correct_action = scored_action; 178 | correct_action_scored = true; 179 | //std::cout << "add gold finish" << std::endl; 180 | } else { 181 | //std::cout << "add finish" << std::endl; 182 | } 183 | 184 | } 185 | } 186 | 187 | //std::cout << "check beam finish" << std::endl; 188 | 189 | if (beam.elemsize() == 0) { 190 | std::cout << "error" << std::endl; 191 | for (int idx = 0; idx < sentence.size(); idx++) { 192 | std::cout << sentence[idx] << std::endl; 193 | } 194 | std::cout << "" << 
std::endl; 195 | return -1; 196 | } 197 | 198 | //std::cout << "check beam start" << std::endl; 199 | for (tmp_j = 0; tmp_j < beam.elemsize(); ++tmp_j) { // insert from 200 | pGenerator = beam[tmp_j].item; 201 | pGenerator->move(lattice_index[index + 1], beam[tmp_j].action); 202 | lattice_index[index + 1]->_score = beam[tmp_j].score; 203 | lattice_index[index + 1]->_curFeat.copy(beam[tmp_j].feat); 204 | 205 | //std::cout << tmp_j << ": " << beam[tmp_j].score << std::endl; 206 | 207 | if (pBestGen == 0 || lattice_index[index + 1]->_score > pBestGen->_score) { 208 | pBestGen = lattice_index[index + 1]; 209 | } 210 | if (pGenerator == correctState && beam[tmp_j].action == correct_action) { 211 | correctState = lattice_index[index + 1]; 212 | bCorrect = true; 213 | } 214 | 215 | ++lattice_index[index + 1]; 216 | } 217 | //std::cout << "check beam finish" << std::endl; 218 | 219 | if (pBestGen->IsTerminated()) 220 | break; // while 221 | 222 | // update items if correct item jump out of the agenda 223 | 224 | if (!bCorrect) { 225 | // note that if bCorrect == true then the correct state has 226 | // already been updated, and the new value is one of the new states 227 | // among the newly produced from lattice[index+1]. 228 | correctState->move(lattice_index[index + 1], correct_action); 229 | correctState = lattice_index[index + 1]; 230 | lattice_index[index + 1]->_score = scored_correct_action.score; 231 | lattice_index[index + 1]->_curFeat.copy(scored_correct_action.feat); 232 | 233 | ++lattice_index[index + 1]; 234 | assert(correct_action_scored); // scored_correct_act valid 235 | //TRACE(index << " updated"); 236 | //std::cout << index << " updated" << std::endl; 237 | 238 | cost = backPropagationStates(pBestGen, correctState, 1.0/num, -1.0/num); 239 | if (cost < 0) { 240 | std::cout << "strange ..." 
<< std::endl; 241 | } 242 | _eval.correct_label_count += index; 243 | _eval.overall_label_count += length + 1; 244 | return cost; 245 | } 246 | 247 | } 248 | 249 | // make sure that the correct item is stack top finally 250 | if (pBestGen != correctState) { 251 | if (!bCorrect) { 252 | correctState->move(lattice_index[index + 1], correct_action); 253 | correctState = lattice_index[index + 1]; 254 | lattice_index[index + 1]->_score = scored_correct_action.score; 255 | lattice_index[index + 1]->_curFeat.copy(scored_correct_action.feat); 256 | assert(correct_action_scored); // scored_correct_act valid 257 | } 258 | 259 | //std::cout << "best:" << pBestGen->str() << std::endl; 260 | //std::cout << "gold:" << correctState->str() << std::endl; 261 | 262 | cost = backPropagationStates(pBestGen, correctState, 1.0/num, -1.0/num); 263 | if (cost < 0) { 264 | std::cout << "strange ..." << std::endl; 265 | } 266 | _eval.correct_label_count += length; 267 | _eval.overall_label_count += length + 1; 268 | } else { 269 | _eval.correct_label_count += length + 1; 270 | _eval.overall_label_count += length + 1; 271 | } 272 | 273 | return cost; 274 | } 275 | 276 | dtype backPropagationStates(const CStateItem *pPredState, const CStateItem *pGoldState, dtype predLoss, dtype goldLoss) { 277 | if (pPredState == pGoldState) 278 | return 0.0; 279 | 280 | if(pPredState->_nextPosition != pGoldState->_nextPosition){ 281 | std::cout << "state align error" << std::endl; 282 | } 283 | dtype delta = 0.0; 284 | dtype predscore, goldscore; 285 | _splayer_output.ComputeForwardScore(pPredState->_curFeat._nSparseFeat, predscore, true); 286 | _splayer_output.ComputeForwardScore(pGoldState->_curFeat._nSparseFeat, goldscore, true); 287 | 288 | delta = predscore - goldscore; 289 | 290 | _splayer_output.ComputeBackwardLoss(pPredState->_curFeat._nSparseFeat, predLoss); 291 | _splayer_output.ComputeBackwardLoss(pGoldState->_curFeat._nSparseFeat, goldLoss); 292 | 293 | //currently we use a uniform loss 294 | 
delta += backPropagationStates(pPredState->_prevState, pGoldState->_prevState, predLoss, goldLoss); 295 | 296 | dtype compare_delta = pPredState->_score - pGoldState->_score; 297 | if (abs(delta - compare_delta) > 0.01) { 298 | std::cout << "delta=" << delta << "\t, compare_delta=" << compare_delta << std::endl; 299 | } 300 | 301 | return delta; 302 | } 303 | 304 | bool decode(const std::vector& sentence, std::vector& words) { 305 | setAlphaIncreasing(false); 306 | if (sentence.size() >= MAX_SENTENCE_SIZE) 307 | return false; 308 | static CStateItem lattice[(MAX_SENTENCE_SIZE + 1) * (BEAM_SIZE + 1)]; 309 | static CStateItem *lattice_index[MAX_SENTENCE_SIZE + 1]; 310 | 311 | int length = sentence.size(); 312 | dtype cost = 0.0; 313 | dtype score = 0.0; 314 | 315 | const static CStateItem *pGenerator; 316 | const static CStateItem *pBestGen; 317 | 318 | int index, tmp_i, tmp_j; 319 | std::vector actions; // actions to apply for a candidate 320 | static NRHeap beam(BEAM_SIZE); 321 | static CScoredStateAction scored_action; // used rank actions 322 | static Feature feat; 323 | 324 | lattice_index[0] = lattice; 325 | lattice_index[1] = lattice + 1; 326 | lattice_index[0]->clear(); 327 | lattice_index[0]->initSentence(&sentence); 328 | 329 | index = 0; 330 | 331 | while (true) { 332 | ++index; 333 | lattice_index[index + 1] = lattice_index[index]; 334 | beam.clear(); 335 | pBestGen = 0; 336 | 337 | //std::cout << index << std::endl; 338 | for (pGenerator = lattice_index[index - 1]; pGenerator != lattice_index[index]; ++pGenerator) { 339 | pGenerator->getCandidateActions(actions); 340 | for (tmp_j = 0; tmp_j < actions.size(); ++tmp_j) { 341 | scored_action.action = actions[tmp_j]; 342 | scored_action.item = pGenerator; 343 | fe.extractFeature(pGenerator, actions[tmp_j], scored_action.feat); 344 | _splayer_output.ComputeForwardScore(scored_action.feat._nSparseFeat, scored_action.score); 345 | scored_action.score += pGenerator->_score; 346 | beam.add_elem(scored_action); 
347 | } 348 | 349 | } 350 | 351 | if (beam.elemsize() == 0) { 352 | std::cout << "error" << std::endl; 353 | for (int idx = 0; idx < sentence.size(); idx++) { 354 | std::cout << sentence[idx] << std::endl; 355 | } 356 | std::cout << "" << std::endl; 357 | return false; 358 | } 359 | 360 | for (tmp_j = 0; tmp_j < beam.elemsize(); ++tmp_j) { // insert from 361 | pGenerator = beam[tmp_j].item; 362 | pGenerator->move(lattice_index[index + 1], beam[tmp_j].action); 363 | lattice_index[index + 1]->_score = beam[tmp_j].score; 364 | 365 | if (pBestGen == 0 || lattice_index[index + 1]->_score > pBestGen->_score) { 366 | pBestGen = lattice_index[index + 1]; 367 | } 368 | 369 | ++lattice_index[index + 1]; 370 | } 371 | 372 | if (pBestGen->IsTerminated()) 373 | break; // while 374 | 375 | } 376 | pBestGen->getSegResults(words); 377 | 378 | return true; 379 | } 380 | 381 | void updateParams(dtype nnRegular, dtype adaAlpha, dtype adaEps) { 382 | _splayer_output.updateAdaGrad(nnRegular, adaAlpha, adaEps); 383 | } 384 | 385 | void writeModel(); 386 | 387 | void loadModel(); 388 | 389 | public: 390 | 391 | inline void resetEval() { 392 | _eval.reset(); 393 | } 394 | 395 | inline void setDropValue(dtype dropOut) { 396 | _dropOut = dropOut; 397 | } 398 | 399 | }; 400 | 401 | #endif /* SRC_APBeamSearcher_H_ */ 402 | -------------------------------------------------------------------------------- /state/NeuralState.h: -------------------------------------------------------------------------------- 1 | /* 2 | * State.h 3 | * 4 | * Created on: Oct 1, 2015 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SEG_NEURALSTATE_H_ 9 | #define SEG_NEURALSTATE_H_ 10 | 11 | #include "DenseFeatureForward.h" 12 | #include "Feature.h" 13 | #include "Action.h" 14 | 15 | template 16 | class CStateItem { 17 | public: 18 | std::string _strlastWord; 19 | int _lastWordStart; 20 | int _lastWordEnd; 21 | const CStateItem *_prevStackState; 22 | const CStateItem *_prevSepState; 23 | const CStateItem *_prevState; 24 | 
int _nextPosition; 25 | 26 | const std::vector *_pCharacters; 27 | int _characterSize; 28 | 29 | CAction _lastAction; 30 | Feature _curFeat; 31 | DenseFeatureForward _nnfeat; 32 | dtype _score; 33 | int _wordnum; 34 | 35 | public: 36 | CStateItem() { 37 | _strlastWord = ""; 38 | _lastWordStart = -1; 39 | _lastWordEnd = -1; 40 | _prevStackState = 0; 41 | _prevSepState = 0; 42 | _prevState = 0; 43 | _nextPosition = 0; 44 | _pCharacters = 0; 45 | _characterSize = 0; 46 | _lastAction.clear(); 47 | _curFeat.clear(); 48 | _nnfeat.clear(); 49 | _score = 0.0; 50 | _wordnum = 0; 51 | } 52 | 53 | CStateItem(const std::vector* pCharacters) { 54 | _strlastWord = ""; 55 | _lastWordStart = -1; 56 | _lastWordEnd = -1; 57 | _prevStackState = 0; 58 | _prevSepState = 0; 59 | _prevState = 0; 60 | _nextPosition = 0; 61 | _pCharacters = pCharacters; 62 | _characterSize = pCharacters->size(); 63 | _lastAction.clear(); 64 | _curFeat.clear(); 65 | _nnfeat.clear(); 66 | _score = 0.0; 67 | _wordnum = 0; 68 | } 69 | 70 | virtual ~CStateItem(){ 71 | clear(); 72 | } 73 | 74 | void initSentence(const std::vector* pCharacters) { 75 | _pCharacters = pCharacters; 76 | _characterSize = pCharacters->size(); 77 | } 78 | 79 | void clear() { 80 | _strlastWord = ""; 81 | _lastWordStart = -1; 82 | _lastWordEnd = -1; 83 | _prevStackState = 0; 84 | _prevSepState = 0; 85 | _prevState = 0; 86 | _nextPosition = 0; 87 | _lastAction.clear(); 88 | _curFeat.clear(); 89 | _nnfeat.clear(); 90 | _score = 0.0; 91 | _wordnum = 0; 92 | } 93 | 94 | void copyState(const CStateItem* from) { 95 | _strlastWord = from->_strlastWord; 96 | _lastWordStart = from->_lastWordStart; 97 | _lastWordEnd = from->_lastWordEnd; 98 | _prevStackState = from->_prevStackState; 99 | _prevSepState = from->_prevSepState; 100 | _prevState = from->_prevState; 101 | _nextPosition = from->_nextPosition; 102 | _pCharacters = from->_pCharacters; 103 | _characterSize = from->_characterSize; 104 | _lastAction = from->_lastAction; 105 | 
_curFeat.copy(from->_curFeat); 106 | _nnfeat.copy(from->_nnfeat); 107 | _score = from->_score; 108 | _wordnum = from->_wordnum; 109 | } 110 | 111 | const CStateItem* getPrevStackState() const{ 112 | return _prevStackState; 113 | } 114 | 115 | const CStateItem* getPrevSepState() const{ 116 | return _prevSepState; 117 | } 118 | 119 | const CStateItem* getPrevState() const{ 120 | return _prevState; 121 | } 122 | 123 | std::string getLastWord() { 124 | return _strlastWord; 125 | } 126 | 127 | public: 128 | //only assign context 129 | void separate(CStateItem* next) const{ 130 | if (_nextPosition >= _characterSize) { 131 | std::cout << "separate error" << std::endl; 132 | return; 133 | } 134 | next->_strlastWord = (*_pCharacters)[_nextPosition]; 135 | next->_lastWordStart = _nextPosition; 136 | next->_lastWordEnd = _nextPosition; 137 | next->_prevStackState = this; 138 | next->_prevSepState = next; 139 | next->_prevState = this; 140 | next->_nextPosition = _nextPosition + 1; 141 | next->_pCharacters = _pCharacters; 142 | next->_characterSize = _characterSize; 143 | next->_wordnum = _wordnum + 1; 144 | next->_lastAction.set(CAction::SEP); 145 | } 146 | 147 | //only assign context 148 | void finish(CStateItem* next) const{ 149 | if (_nextPosition != _characterSize) { 150 | std::cout << "finish error" << std::endl; 151 | return; 152 | } 153 | next->_strlastWord = _strlastWord; 154 | next->_lastWordStart = _lastWordStart; 155 | next->_lastWordEnd = _lastWordEnd; 156 | next->_prevStackState = _prevStackState; 157 | next->_prevSepState = next; 158 | next->_prevState = this; 159 | next->_nextPosition = _nextPosition + 1; 160 | next->_pCharacters = _pCharacters; 161 | next->_characterSize = _characterSize; 162 | next->_wordnum = _wordnum + 1; 163 | next->_lastAction.set(CAction::FIN); 164 | } 165 | 166 | //only assign context 167 | void append(CStateItem* next) const{ 168 | if (_nextPosition >= _characterSize) { 169 | std::cout << "append error" << std::endl; 170 | return; 171 
| } 172 | next->_strlastWord = _strlastWord + (*_pCharacters)[_nextPosition]; 173 | next->_lastWordStart = _lastWordStart; 174 | next->_lastWordEnd = _nextPosition; 175 | next->_prevStackState = _prevStackState; 176 | next->_prevSepState = _prevSepState; 177 | next->_prevState = this; 178 | next->_nextPosition = _nextPosition + 1; 179 | next->_pCharacters = _pCharacters; 180 | next->_characterSize = _characterSize; 181 | next->_wordnum = _wordnum; 182 | next->_lastAction.set(CAction::APP); 183 | } 184 | 185 | void move(CStateItem* next, const CAction& ac) const{ 186 | if (ac.isAppend()) { 187 | append(next); 188 | } else if (ac.isSeparate()) { 189 | separate(next); 190 | } else if (ac.isFinish()) { 191 | finish(next); 192 | } else { 193 | std::cout << "error action" << std::endl; 194 | } 195 | } 196 | 197 | bool IsTerminated() const { 198 | if (_lastAction.isFinish()) 199 | return true; 200 | return false; 201 | } 202 | 203 | //partial results 204 | void getSegResults(std::vector& words) const { 205 | words.clear(); 206 | words.insert(words.begin(), _strlastWord); 207 | const CStateItem *prevStackState = _prevStackState; 208 | while (prevStackState != 0 && prevStackState->_wordnum > 0) { 209 | words.insert(words.begin(), prevStackState->_strlastWord); 210 | prevStackState = prevStackState->_prevStackState; 211 | } 212 | } 213 | 214 | 215 | void getGoldAction(const std::vector& segments, CAction& ac) const { 216 | if (_nextPosition == _characterSize) { 217 | ac.set(CAction::FIN); 218 | return; 219 | } 220 | if (_nextPosition == 0) { 221 | ac.set(CAction::SEP); 222 | return; 223 | } 224 | 225 | if (_nextPosition > 0 && _nextPosition < _characterSize) { 226 | // should have a check here to see whether the words are match, but I did not do it here 227 | if (_strlastWord.length() == segments[_wordnum-1].length()) { 228 | ac.set(CAction::SEP); 229 | return; 230 | } else { 231 | ac.set(CAction::APP); 232 | return; 233 | } 234 | } 235 | 236 | ac.set(CAction::NO_ACTION); 
237 | return; 238 | } 239 | 240 | // we did not judge whether history actions are match with current state. 241 | void getGoldAction(const CStateItem* goldState, CAction& ac) const{ 242 | if (_nextPosition > goldState->_nextPosition || _nextPosition < 0) { 243 | ac.set(CAction::NO_ACTION); 244 | return; 245 | } 246 | const CStateItem *prevState = goldState->_prevState; 247 | CAction curAction = goldState->_lastAction; 248 | while (_nextPosition < prevState->_nextPosition) { 249 | curAction = prevState->_lastAction; 250 | prevState = prevState->_prevState; 251 | } 252 | return ac.set(curAction._code); 253 | } 254 | 255 | void getCandidateActions(vector & actions) const{ 256 | actions.clear(); 257 | static CAction ac; 258 | if(_nextPosition == 0){ 259 | ac.set(CAction::SEP); 260 | actions.push_back(ac); 261 | } 262 | else if(_nextPosition == _characterSize){ 263 | ac.set(CAction::FIN); 264 | actions.push_back(ac); 265 | } 266 | else if(_nextPosition > 0 && _nextPosition < _characterSize){ 267 | ac.set(CAction::SEP); 268 | actions.push_back(ac); 269 | ac.set(CAction::APP); 270 | actions.push_back(ac); 271 | } 272 | else{ 273 | 274 | } 275 | 276 | } 277 | 278 | inline std::string str() const{ 279 | stringstream curoutstr; 280 | 281 | curoutstr << "score: " << _score << " "; 282 | curoutstr << "seg:"; 283 | std::vector words; 284 | getSegResults(words); 285 | for(int idx = 0; idx < words.size(); idx++){ 286 | curoutstr << " " << words[idx]; 287 | } 288 | 289 | return curoutstr.str(); 290 | } 291 | 292 | }; 293 | 294 | template 295 | class CScoredStateAction { 296 | public: 297 | CAction action; 298 | const CStateItem *item; 299 | dtype score; 300 | Feature feat; 301 | DenseFeatureForward nnfeat; 302 | 303 | public: 304 | CScoredStateAction() : 305 | item(0), action(0), score(0) { 306 | feat.setFeatureFormat(false); 307 | feat.clear(); 308 | nnfeat.clear(); 309 | } 310 | 311 | /* 312 | ~CScoredStateAction(){ 313 | clear(); 314 | } 315 | */ 316 | 317 | public: 318 | /* 
319 | inline void clear(){ 320 | item = 0; 321 | action = 0; 322 | score = 0; 323 | feat.setFeatureFormat(false); 324 | feat.clear(); 325 | nnfeat.clear(); 326 | } 327 | */ 328 | 329 | inline CScoredStateAction& operator=(const CScoredStateAction &rhs) { 330 | // Check for self-assignment! 331 | if (this == &rhs) // Same object? 332 | return *this; // Yes, so skip assignment, and just return *this. 333 | 334 | item = rhs.item; 335 | action.set(rhs.action._code); 336 | score = rhs.score; 337 | feat.copy(rhs.feat); 338 | nnfeat.copy(rhs.nnfeat); 339 | 340 | return *this; 341 | } 342 | 343 | 344 | public: 345 | bool operator <(const CScoredStateAction &a1) const { 346 | return score < a1.score; 347 | } 348 | bool operator >(const CScoredStateAction &a1) const { 349 | return score > a1.score; 350 | } 351 | bool operator <=(const CScoredStateAction &a1) const { 352 | return score <= a1.score; 353 | } 354 | bool operator >=(const CScoredStateAction &a1) const { 355 | return score >= a1.score; 356 | } 357 | 358 | 359 | }; 360 | 361 | template 362 | class CScoredStateAction_Compare { 363 | public: 364 | int operator()(const CScoredStateAction &o1, const CScoredStateAction &o2) const { 365 | 366 | if (o1.score < o2.score) 367 | return -1; 368 | else if (o1.score > o2.score) 369 | return 1; 370 | else 371 | return 0; 372 | } 373 | }; 374 | 375 | 376 | #endif /* SEG_NEURALSTATE_H_ */ 377 | -------------------------------------------------------------------------------- /state/State.h: -------------------------------------------------------------------------------- 1 | /* 2 | * State.h 3 | * 4 | * Created on: Oct 1, 2015 5 | * Author: mszhang 6 | */ 7 | 8 | #ifndef SEG_STATE_H_ 9 | #define SEG_STATE_H_ 10 | 11 | #include "Feature.h" 12 | #include "Action.h" 13 | 14 | class CStateItem { 15 | public: 16 | std::string _strlastWord; 17 | int _lastWordStart; 18 | int _lastWordEnd; 19 | const CStateItem *_prevStackState; 20 | const CStateItem *_prevState; 21 | int _nextPosition; 22 
| 23 | const std::vector *_pCharacters; 24 | int _characterSize; 25 | 26 | CAction _lastAction; 27 | Feature _curFeat; 28 | dtype _score; 29 | int _wordnum; 30 | 31 | public: 32 | CStateItem() { 33 | _strlastWord = ""; 34 | _lastWordStart = -1; 35 | _lastWordEnd = -1; 36 | _prevStackState = 0; 37 | _prevState = 0; 38 | _nextPosition = 0; 39 | _pCharacters = 0; 40 | _characterSize = 0; 41 | _lastAction.clear(); 42 | _curFeat.clear(); 43 | _score = 0.0; 44 | _wordnum = 0; 45 | } 46 | 47 | CStateItem(const std::vector* pCharacters) { 48 | _strlastWord = ""; 49 | _lastWordStart = -1; 50 | _lastWordEnd = -1; 51 | _prevStackState = 0; 52 | _prevState = 0; 53 | _nextPosition = 0; 54 | _pCharacters = pCharacters; 55 | _characterSize = pCharacters->size(); 56 | _lastAction.clear(); 57 | _curFeat.clear(); 58 | _score = 0.0; 59 | _wordnum = 0; 60 | } 61 | 62 | virtual ~CStateItem(){ 63 | clear(); 64 | } 65 | 66 | void initSentence(const std::vector* pCharacters) { 67 | _pCharacters = pCharacters; 68 | _characterSize = pCharacters->size(); 69 | } 70 | 71 | void clear() { 72 | _strlastWord = ""; 73 | _lastWordStart = -1; 74 | _lastWordEnd = -1; 75 | _prevStackState = 0; 76 | _prevState = 0; 77 | _nextPosition = 0; 78 | _lastAction.clear(); 79 | _curFeat.clear(); 80 | _score = 0.0; 81 | _wordnum = 0; 82 | } 83 | 84 | void copyState(const CStateItem* from) { 85 | _strlastWord = from->_strlastWord; 86 | _lastWordStart = from->_lastWordStart; 87 | _lastWordEnd = from->_lastWordEnd; 88 | _prevStackState = from->_prevStackState; 89 | _prevState = from->_prevState; 90 | _nextPosition = from->_nextPosition; 91 | _pCharacters = from->_pCharacters; 92 | _characterSize = from->_characterSize; 93 | _lastAction = from->_lastAction; 94 | _curFeat.copy(from->_curFeat); 95 | _score = from->_score; 96 | _wordnum = from->_wordnum; 97 | } 98 | 99 | const CStateItem* getPrevStackState() const{ 100 | return _prevStackState; 101 | } 102 | 103 | const CStateItem* getPrevState() const{ 104 | return 
_prevState; 105 | } 106 | 107 | std::string getLastWord() { 108 | return _strlastWord; 109 | } 110 | 111 | public: 112 | //only assign context 113 | void separate(CStateItem* next) const{ 114 | if (_nextPosition >= _characterSize) { 115 | std::cout << "separate error" << std::endl; 116 | return; 117 | } 118 | next->_strlastWord = (*_pCharacters)[_nextPosition]; 119 | next->_lastWordStart = _nextPosition; 120 | next->_lastWordEnd = _nextPosition; 121 | next->_prevStackState = this; 122 | next->_prevState = this; 123 | next->_nextPosition = _nextPosition + 1; 124 | next->_pCharacters = _pCharacters; 125 | next->_characterSize = _characterSize; 126 | next->_wordnum = _wordnum + 1; 127 | next->_lastAction.set(CAction::SEP); 128 | } 129 | 130 | //only assign context 131 | void finish(CStateItem* next) const{ 132 | if (_nextPosition != _characterSize) { 133 | std::cout << "finish error" << std::endl; 134 | return; 135 | } 136 | next->_strlastWord = _strlastWord; 137 | next->_lastWordStart = _lastWordStart; 138 | next->_lastWordEnd = _lastWordEnd; 139 | next->_prevStackState = _prevStackState; 140 | next->_prevState = this; 141 | next->_nextPosition = _nextPosition + 1; 142 | next->_pCharacters = _pCharacters; 143 | next->_characterSize = _characterSize; 144 | next->_wordnum = _wordnum + 1; 145 | next->_lastAction.set(CAction::FIN); 146 | } 147 | 148 | //only assign context 149 | void append(CStateItem* next) const{ 150 | if (_nextPosition >= _characterSize) { 151 | std::cout << "append error" << std::endl; 152 | return; 153 | } 154 | next->_strlastWord = _strlastWord + (*_pCharacters)[_nextPosition]; 155 | next->_lastWordStart = _lastWordStart; 156 | next->_lastWordEnd = _nextPosition; 157 | next->_prevStackState = _prevStackState; 158 | next->_prevState = this; 159 | next->_nextPosition = _nextPosition + 1; 160 | next->_pCharacters = _pCharacters; 161 | next->_characterSize = _characterSize; 162 | next->_wordnum = _wordnum; 163 | next->_lastAction.set(CAction::APP); 164 
| } 165 | 166 | void move(CStateItem* next, const CAction& ac) const{ 167 | if (ac.isAppend()) { 168 | append(next); 169 | } else if (ac.isSeparate()) { 170 | separate(next); 171 | } else if (ac.isFinish()) { 172 | finish(next); 173 | } else { 174 | std::cout << "error action" << std::endl; 175 | } 176 | } 177 | 178 | bool IsTerminated() const { 179 | if (_lastAction.isFinish()) 180 | return true; 181 | return false; 182 | } 183 | 184 | //partial results 185 | void getSegResults(std::vector& words) const { 186 | words.clear(); 187 | words.insert(words.begin(), _strlastWord); 188 | const CStateItem *prevStackState = _prevStackState; 189 | while (prevStackState != 0 && prevStackState->_wordnum > 0) { 190 | words.insert(words.begin(), prevStackState->_strlastWord); 191 | prevStackState = prevStackState->_prevStackState; 192 | } 193 | } 194 | 195 | 196 | void getGoldAction(const std::vector& segments, CAction& ac) const { 197 | if (_nextPosition == _characterSize) { 198 | ac.set(CAction::FIN); 199 | return; 200 | } 201 | if (_nextPosition == 0) { 202 | ac.set(CAction::SEP); 203 | return; 204 | } 205 | 206 | if (_nextPosition > 0 && _nextPosition < _characterSize) { 207 | // should have a check here to see whether the words are match, but I did not do it here 208 | if (_strlastWord.length() == segments[_wordnum-1].length()) { 209 | ac.set(CAction::SEP); 210 | return; 211 | } else { 212 | ac.set(CAction::APP); 213 | return; 214 | } 215 | } 216 | 217 | ac.set(CAction::NO_ACTION); 218 | return; 219 | } 220 | 221 | // we did not judge whether history actions are match with current state. 
222 | void getGoldAction(const CStateItem* goldState, CAction& ac) const{ 223 | if (_nextPosition > goldState->_nextPosition || _nextPosition < 0) { 224 | ac.set(CAction::NO_ACTION); 225 | return; 226 | } 227 | const CStateItem *prevState = goldState->_prevState; 228 | CAction curAction = goldState->_lastAction; 229 | while (_nextPosition < prevState->_nextPosition) { 230 | curAction = prevState->_lastAction; 231 | prevState = prevState->_prevState; 232 | } 233 | return ac.set(curAction._code); 234 | } 235 | 236 | void getCandidateActions(vector & actions) const{ 237 | actions.clear(); 238 | static CAction ac; 239 | if(_nextPosition == 0){ 240 | ac.set(CAction::SEP); 241 | actions.push_back(ac); 242 | } 243 | else if(_nextPosition == _characterSize){ 244 | ac.set(CAction::FIN); 245 | actions.push_back(ac); 246 | } 247 | else if(_nextPosition > 0 && _nextPosition < _characterSize){ 248 | ac.set(CAction::SEP); 249 | actions.push_back(ac); 250 | ac.set(CAction::APP); 251 | actions.push_back(ac); 252 | } 253 | else{ 254 | 255 | } 256 | 257 | } 258 | 259 | inline std::string str() const{ 260 | stringstream curoutstr; 261 | 262 | curoutstr << "score: " << _score << " "; 263 | curoutstr << "seg:"; 264 | std::vector words; 265 | getSegResults(words); 266 | for(int idx = 0; idx < words.size(); idx++){ 267 | curoutstr << " " << words[idx]; 268 | } 269 | 270 | return curoutstr.str(); 271 | } 272 | 273 | }; 274 | 275 | 276 | class CScoredStateAction { 277 | public: 278 | CAction action; 279 | const CStateItem *item; 280 | dtype score; 281 | Feature feat; 282 | 283 | public: 284 | CScoredStateAction() : 285 | item(0), action(-1), score(0) { 286 | feat.setFeatureFormat(false); 287 | feat.clear(); 288 | } 289 | 290 | 291 | public: 292 | bool operator <(const CScoredStateAction &a1) const { 293 | return score < a1.score; 294 | } 295 | bool operator >(const CScoredStateAction &a1) const { 296 | return score > a1.score; 297 | } 298 | bool operator <=(const CScoredStateAction &a1) 
const { 299 | return score <= a1.score; 300 | } 301 | bool operator >=(const CScoredStateAction &a1) const { 302 | return score >= a1.score; 303 | } 304 | 305 | 306 | }; 307 | 308 | class CScoredStateAction_Compare { 309 | public: 310 | int operator()(const CScoredStateAction &o1, const CScoredStateAction &o2) const { 311 | 312 | if (o1.score < o2.score) 313 | return -1; 314 | else if (o1.score > o2.score) 315 | return 1; 316 | else 317 | return 0; 318 | } 319 | }; 320 | 321 | 322 | #endif /* SEG_STATE_H_ */ 323 | --------------------------------------------------------------------------------