├── Makefile ├── README.md ├── dependency ├── jieba.dict.utf8 ├── new_dictionary.sqlite ├── symbelTerms.txt └── trainingSet.txt.model ├── include ├── TC_process.h ├── combinition.h ├── cppMSGmodule.hpp ├── libClassifier.h ├── sqlite3.h └── titlebook.pb.h ├── pyscript ├── Leader_asynchronous.py ├── Leader_synchronous.py ├── Secretary_asynchronous.py └── pyMSGmodule.py ├── src ├── TC_process.cpp ├── combinition.cpp ├── cppMSGmodule.cpp ├── libClassifier.cpp ├── sqlite3.c └── titlebook.proto └── tool ├── Worker.cpp └── nlp_service_main.cpp /Makefile: -------------------------------------------------------------------------------- 1 | C=gcc 2 | CXX=g++ -std=c++11 3 | CFLAGS=-I./include 4 | LDFLAGS=-Wl,--no-as-needed -ldl -pthread 5 | LINK=g++ 6 | 7 | SRC=./src 8 | INCLUDE=./include 9 | BUILD_DIR=./build 10 | BIN_DIR=./build/bin 11 | TOOL_SRC=./tool 12 | PYSCRIPT=./pyscript 13 | ZMQ_INCLUDE=-I/usr/local/include 14 | ZMQ_LIB=-L/usr/local/lib -lzmq 15 | PROTOBUF_LIB=-L/usr/local/lib -lprotobuf -pthread 16 | 17 | MKDIR_P=mkdir -p 18 | .PHONY: all directories MSGmodule Worker Leader Leader_async clean # none of these targets produces a file with its own name 19 | 20 | all: directories $(BIN_DIR)/nlp_service_main Worker Leader Leader_async 21 | 22 | MSGmodule: $(BUILD_DIR)/titlebook.pb.o $(BUILD_DIR)/cppMSGmodule.o 23 | 24 | Worker: $(BIN_DIR)/Worker 25 | 26 | Leader: $(BIN_DIR)/Leader_synchronous.py $(BUILD_DIR)/pyMSGmodule.py $(BUILD_DIR)/titlebook_pb2.py 27 | 28 | Leader_async: $(BIN_DIR)/Leader_asynchronous.py $(BIN_DIR)/Secretary_asynchronous.py 29 | 30 | directories: $(BUILD_DIR) 31 | 32 | ${BUILD_DIR}: 33 | $(MKDIR_P) $(BUILD_DIR) 34 | $(MKDIR_P) $(BIN_DIR) 35 | 36 | $(BIN_DIR)/nlp_service_main: $(BUILD_DIR)/nlp_service_main.o $(BUILD_DIR)/combinition.o $(BUILD_DIR)/libClassifier.o $(BUILD_DIR)/sqlite3.o $(BUILD_DIR)/TC_process.o 37 | $(LINK) $(LDFLAGS) $^ -o $@ 38 | 39 | $(BUILD_DIR)/combinition.o: $(SRC)/combinition.cpp 40 | $(CXX) $^ $(CFLAGS) -c -o $@ 41 | 42 | $(BUILD_DIR)/libClassifier.o: $(SRC)/libClassifier.cpp 43 | $(CXX) $^ $(CFLAGS) -c -o $@ 44 | 45 | 
$(BUILD_DIR)/sqlite3.o: $(SRC)/sqlite3.c 46 | $(C) $^ $(CFLAGS) -c -o $@ 47 | 48 | $(BUILD_DIR)/TC_process.o: $(SRC)/TC_process.cpp 49 | $(CXX) $^ $(CFLAGS) -c -o $@ 50 | 51 | $(BUILD_DIR)/nlp_service_main.o: $(TOOL_SRC)/nlp_service_main.cpp 52 | $(CXX) $^ $(CFLAGS) -c -o $@ 53 | 54 | $(BUILD_DIR)/titlebook.pb.o: $(SRC)/titlebook.pb.cc 55 | $(CXX) $^ $(CFLAGS) -c -o $@ 56 | 57 | $(BUILD_DIR)/cppMSGmodule.o: $(SRC)/cppMSGmodule.cpp 58 | $(CXX) $^ $(CFLAGS) -c -o $@ 59 | 60 | $(BIN_DIR)/Worker: $(TOOL_SRC)/Worker.cpp $(BUILD_DIR)/titlebook.pb.o $(BUILD_DIR)/cppMSGmodule.o $(BUILD_DIR)/combinition.o $(BUILD_DIR)/libClassifier.o $(BUILD_DIR)/sqlite3.o $(BUILD_DIR)/TC_process.o 61 | $(CXX) $^ $(CFLAGS) $(ZMQ_INCLUDE) $(ZMQ_LIB) $(PROTOBUF_LIB) $(LDFLAGS) -o $@ 62 | 63 | $(BIN_DIR)/Leader_synchronous.py: $(PYSCRIPT)/Leader_synchronous.py 64 | cp $^ $@ 65 | 66 | $(BUILD_DIR)/pyMSGmodule.py: $(PYSCRIPT)/pyMSGmodule.py 67 | cp $^ $@ 68 | 69 | $(BIN_DIR)/Leader_asynchronous.py: $(PYSCRIPT)/Leader_asynchronous.py 70 | cp $^ $@ 71 | 72 | $(BIN_DIR)/Secretary_asynchronous.py: $(PYSCRIPT)/Secretary_asynchronous.py 73 | cp $^ $@ 74 | 75 | $(BUILD_DIR)/titlebook_pb2.py: $(SRC)/titlebook.pb.cc # titlebook_pb2.py is written to $(PYSCRIPT) as a by-product of the protoc rule for titlebook.pb.cc 76 | cp $(PYSCRIPT)/titlebook_pb2.py $@ 77 | 78 | $(SRC)/titlebook.pb.cc: $(SRC)/titlebook.proto 79 | cp $(SRC)/titlebook.proto ./ 80 | protoc titlebook.proto --cpp_out=./$(SRC) 81 | mv $(SRC)/titlebook.pb.h $(INCLUDE) 82 | protoc titlebook.proto --python_out=./ 83 | mv titlebook_pb2.py $(PYSCRIPT) 84 | rm -f titlebook.proto 85 | 86 | clean: 87 | rm -rf $(BUILD_DIR) 88 | rm -f $(INCLUDE)/titlebook.pb.h 89 | rm -f $(SRC)/titlebook.pb.cc 90 | rm -f $(PYSCRIPT)/titlebook_pb2.py 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # An NLP service powered by ZeroMQ 2 | 3 | NLP programs have to load many large mathematical models, so it is a good idea to run the NLP program separately from the web-service infrastructure.
4 | 5 | We call the C++ NLP module the **Worker** and the Python control module the **Leader**. The **Leader** assigns work to the **Worker** and collects its results. 6 | 7 | Although the processing loop between the **Leader** and the **Worker** is synchronous, you can use the **Secretary** helper Python script to make *asynchronous* calls to the **Worker**. 8 | 9 | 10 | ----------- 11 | 12 | ## Dependencies 13 | 14 | * protobuf and python-protobuf 15 | 16 | * ZeroMQ and pyzmq 17 | 18 | ----------- 19 | 20 | ## Deploying the service 21 | 22 | Build everything with `make` first, then run all of the commands below from the $(zeromq_nlp_service) directory, i.e. the repository root. 23 | 24 | ### 1. Start the Worker service 25 | 26 | ``` 27 | ./build/bin/Worker 28 | ``` 29 | 30 | ### 2. Run the synchronous Leader script 31 | 32 | ``` 33 | python ./build/bin/Leader_synchronous.py 34 | ``` 35 | 36 | ### 3. Run the asynchronous Leader script 37 | 38 | #### 3.1 Start the Secretary service first 39 | 40 | ``` 41 | python build/bin/Secretary_asynchronous.py 42 | ``` 43 | 44 | #### 3.2 Run the asynchronous Leader script 45 | 46 | ``` 47 | python build/bin/Leader_asynchronous.py 48 | ``` 49 | -------------------------------------------------------------------------------- /dependency/new_dictionary.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Serbipunk/zeromq_nlp_service/e1ab9ff5aa08dbfd6bf0c4fed3c43b9182b648e6/dependency/new_dictionary.sqlite -------------------------------------------------------------------------------- /dependency/symbelTerms.txt: -------------------------------------------------------------------------------- 1 |  | 2 | ‖‖ 3 | - 4 | / 5 | ?? 6 | - 7 | _ 8 | 】_ 9 | | 10 | ‖‖ 11 | ! 12 | ! 13 | ' 14 | ' 15 | ’ 16 | " 17 | " 18 | “ 19 | ” 20 | , 21 | , 22 | , 23 | 。 24 | . 25 | 、 26 | ; 27 | ; 28 | : 29 | : 30 | ? 31 | ?
32 | ( 33 | ) 34 | & 35 | % 36 | [ 37 | ] 38 | 【 39 | 】 40 | / 41 | \ 42 | | 43 | { 44 | } 45 | ( 46 | ) 47 | [ 48 | ] 49 | { 50 | } 51 | 52 | 53 | - 54 | — 55 | -------------------------------------------------------------------------------- /include/TC_process.h: -------------------------------------------------------------------------------- 1 | // #include "stdafx.h" 2 | 3 | #include 4 | #include 5 | //#include "../cppjieba/headers.h" 6 | #include "combinition.h" 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include "sqlite3.h" 13 | 14 | #include 15 | #include 16 | 17 | #include "libClassifier.h" 18 | 19 | 20 | #define N_WORD 10 21 | #define N_DIMEN 5 22 | 23 | #pragma once 24 | 25 | 26 | extern std::vector v_query; 27 | 28 | 29 | //// word segmentation (cut) and sqlite word-query helpers 30 | void cut(const CppJieba::SegmentInterface * seg, std::list & l_lines, std::vector *> & v_lines_words, int n_word); 31 | 32 | void cut(const CppJieba::SegmentInterface * seg, const char * const filePath, char * outfile, int n_word); 33 | 34 | void query_word(sqlite3 * conn, char * word); 35 | 36 | int sqlite3_exec_callback(void *data, int n_columns, char **col_values, char **col_names); 37 | 38 | 39 | // termFilter: term filter. (The original comment was mis-encoded, probably GBK Chinese, and could not be recovered.) Holds the terms in cantPassediterms; termIsPass() presumably returns false for those terms -- TODO confirm in TC_process.cpp 40 | class termFilter 41 | { 42 | public: 43 | std::vector cantPassediterms; 44 | 45 | public: 46 | termFilter( const char * loadPath ); 47 | termFilter(); 48 | ~termFilter(); 49 | 50 | termFilter & operator=(const termFilter & rhs); // NOTE(review): no matching user-declared copy constructor; copies fall back to the implicit one 51 | 52 | void appendFilter( char * loadPath ); 53 | 54 | bool termIsPass(char *); 55 | }; 56 | 57 | 58 | 59 | void cut(const CppJieba::SegmentInterface * seg, char ** p_text, int n_text, std::vector *> & v_lines_words, int n_word, termFilter & filter); 60 | 61 | void textCategorization_new(char ** p_text, int n_text, int * p_labels, char * outputPath=NULL); 62 | 63 | class CateTeller { 64 | private: 65 | CppJieba::MPSegment seg; 66 | sqlite3 * conn; 67 | struct svm_model * svmModel; 68 | termFilter filter; 69 | 70 | public: 71 | CateTeller(); 72 | ~CateTeller(); // NOTE(review): conn and svmModel are raw pointers and copying is not disabled; if this destructor releases them, copying a CateTeller would release them twice -- confirm in TC_process.cpp 73 | 74 | 
void tell(char ** p_text, int n_text, int * p_labels); 75 | }; 76 | -------------------------------------------------------------------------------- /include/combinition.h: -------------------------------------------------------------------------------- 1 | //#include "stdafx.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | //#include 27 | //#include 28 | #include 29 | //#include 30 | #include 31 | 32 | 33 | typedef unsigned int uint; 34 | 35 | #ifndef CPPJIEBA_HEADERS_H 36 | #define CPPJIEBA_HEADERS_H 37 | 38 | 39 | #endif 40 | 41 | 42 | #ifndef CPPCOMMON_HEADERS_H 43 | #define CPPCOMMON_HEADERS_H 44 | 45 | 46 | 47 | #endif 48 | 49 | 50 | //////argv_functs.h 51 | #ifndef CPPCOMMON_ARGV_FUNCTS_H 52 | #define CPPCOMMON_ARGV_FUNCTS_H 53 | 54 | namespace CPPCOMMON 55 | { 56 | using namespace std; 57 | bool getArgvMap(int argc, const char* const* argv, map& mpss); 58 | class ArgvContext 59 | { 60 | public : 61 | ArgvContext(int argc, const char* const * argv); 62 | ~ArgvContext(); 63 | public: 64 | string toString(); 65 | string operator [](uint i); 66 | string operator [](const string& key); 67 | public: 68 | bool isKeyExist(const string& key); 69 | private: 70 | vector _args; 71 | map _mpss; 72 | set _sset; 73 | }; 74 | 75 | using namespace std; 76 | template 77 | string vecToString(const vector& vec) 78 | { 79 | if(vec.empty()) 80 | { 81 | return "[]"; 82 | } 83 | stringstream ss; 84 | ss<<"["< 93 | bool isInVec(const vector& vec, const T& item) 94 | { 95 | typename vector::const_iterator it = find(vec.begin(), vec.end(), item); 96 | return it != vec.end(); 97 | } 98 | template 99 | void splitVec(const vector& vecSrc, vector< pair > >& outVec, const vector& patterns) 100 | { 101 | vector tmp; 102 | T pattern; 
size_t patternSize = patterns.size(); 104 | for(size_t i = 0; i < vecSrc.size(); i++) 105 | { 106 | size_t patternPos = patternSize; 107 | for(size_t j = 0; j < patternSize; j++) 108 | { 109 | if(patterns[j] == vecSrc[i]) 110 | { 111 | patternPos = j; 112 | break; 113 | } 114 | } 115 | if(patternPos != patternSize) 116 | { 117 | if(!tmp.empty()) 118 | { 119 | outVec.push_back(make_pair(pattern, tmp)); // let make_pair deduce: explicit template arguments make its C++11 parameters rvalue references, which the lvalues pattern/tmp cannot bind to 120 | tmp.clear(); 121 | } 122 | pattern = patterns[patternPos]; 123 | } 124 | else 125 | { 126 | tmp.push_back(vecSrc[i]); 127 | } 128 | } 129 | if(!tmp.empty()) 130 | { 131 | outVec.push_back(make_pair(pattern, tmp)); // deduced for the same C++11 reason as above 132 | } 133 | } 134 | 135 | template 136 | void splitVec(const vector& vecSrc, vector< vector >& outVec, const vector& patternVec) 137 | { 138 | vector tmp; 139 | for(size_t i = 0; i < vecSrc.size(); i++) 140 | { 141 | bool flag = false; 142 | for(size_t j = 0; j < patternVec.size(); j++) 143 | { 144 | if(patternVec[j] == vecSrc[i]) 145 | { 146 | flag = true; 147 | break; 148 | } 149 | } 150 | if(flag) 151 | { 152 | if(!tmp.empty()) 153 | { 154 | outVec.push_back(tmp); 155 | tmp.clear(); 156 | } 157 | } 158 | else 159 | { 160 | tmp.push_back(vecSrc[i]); 161 | } 162 | } 163 | if(!tmp.empty()) 164 | { 165 | outVec.push_back(tmp); 166 | } 167 | } 168 | 169 | } 170 | 171 | 172 | 173 | #endif 174 | 175 | //////config.h 176 | #ifndef CPPCOMMON_CONFIG_H 177 | #define CPPCOMMON_CONFIG_H 178 | 179 | namespace CPPCOMMON 180 | { 181 | using std::map; 182 | using std::string; 183 | using std::cout; 184 | using std::endl; 185 | using std::ifstream; 186 | class Config 187 | { 188 | public: 189 | Config(); 190 | ~Config(); 191 | bool init(const string& configFile); 192 | void display(); 193 | string getByKey(const string& key); 194 | private: 195 | string _stripComment(const string& line); 196 | map _map; 197 | bool _isInit; 198 | 199 | }; 200 | } 201 | 202 | namespace CPPCOMMON 203 | { 204 | extern Config gConfig; 205 | } 206 | 207 | #endif 208 | 209 | 
//////encodeing.h 210 | #ifndef CPPCOMMON_ENCODING_H 211 | #define CPPCOMMON_ENCODING_H 212 | namespace CPPCOMMON 213 | { 214 | using namespace std; 215 | 216 | //const char* const UTF8ENC = "utf-8"; 217 | //const char* const GBKENC = "gbk"; 218 | 219 | //class UnicodeEncoding 220 | //{ 221 | // private: 222 | // string _encoding; 223 | // vector _encVec; 224 | // public: 225 | // UnicodeEncoding(const string& enc); 226 | // ~UnicodeEncoding(); 227 | // public: 228 | // bool setEncoding(const string& enc); 229 | // string encode(const Unicode& unicode); 230 | // string encode(UnicodeConstIterator begin, UnicodeConstIterator end); 231 | // bool decode(const string& str, Unicode& unicode); 232 | // public: 233 | // size_t getWordLength(const string& str); 234 | //}; 235 | } 236 | 237 | #endif 238 | 239 | //////file_functs.h 240 | #ifndef CPPCOMMON_FILE_FUNCTS_H 241 | #define CPPCOMMON_FILE_FUNCTS_H 242 | namespace CPPCOMMON 243 | { 244 | using namespace std; 245 | bool checkFileExist(const string& filePath); 246 | bool createDir(const string& dirPath, bool p = true); 247 | bool checkDirExist(const string& dirPath); 248 | 249 | } 250 | 251 | #endif 252 | 253 | //////io_fincts.h 254 | #ifndef CPPCOMMON_IO_FUNCTS_H 255 | #define CPPCOMMON_IO_FUNCTS_H 256 | namespace CPPCOMMON 257 | { 258 | using namespace std; 259 | string loadFile2Str(const char * const filepath); 260 | } 261 | #endif 262 | 263 | //////logger.h 264 | #ifndef CPPCOMMON_LOGGER_H 265 | #define CPPCOMMON_LOGGER_H 266 | 267 | #define LL_DEBUG 0 268 | #define LL_INFO 1 269 | #define LL_WARN 2 270 | #define LL_ERROR 3 271 | #define LL_FATAL 4 272 | #define LEVEL_ARRAY_SIZE 5 273 | #define CSTR_BUFFER_SIZE 1024 274 | 275 | typedef unsigned int uint; 276 | 277 | //#define LogDebug(msg) Logger::Logging(LL_DEBUG, msg, __FILE__, __LINE__) 278 | //#define LogInfo(msg) Logger::Logging(LL_INFO, msg, __FILE__, __LINE__) 279 | //#define LogWarn(msg) Logger::Logging(LL_WARN, msg, __FILE__, __LINE__) 280 | //#define 
LogError(msg) Logger::Logging(LL_ERROR, msg, __FILE__, __LINE__) 281 | //#define LogFatal(msg) Logger::Logging(LL_FATAL, msg, __FILE__, __LINE__) 282 | 283 | #define LogDebug(fmt, ...) Logger::LoggingF(LL_DEBUG, __FILE__, __LINE__, fmt, ## __VA_ARGS__) 284 | #define LogInfo(fmt, ...) Logger::LoggingF(LL_INFO, __FILE__, __LINE__, fmt, ## __VA_ARGS__) 285 | #define LogWarn(fmt, ...) Logger::LoggingF(LL_WARN, __FILE__, __LINE__, fmt, ## __VA_ARGS__) 286 | #define LogError(fmt, ...) Logger::LoggingF(LL_ERROR, __FILE__, __LINE__, fmt, ## __VA_ARGS__) 287 | #define LogFatal(fmt, ...) Logger::LoggingF(LL_FATAL, __FILE__, __LINE__, fmt, ## __VA_ARGS__) 288 | 289 | 290 | namespace CPPCOMMON 291 | { 292 | using namespace std; 293 | class Logger 294 | { 295 | public: 296 | Logger(); 297 | ~Logger(); 298 | public: 299 | static bool Logging(uint level, const string& msg, const char* fileName, int lineNo); 300 | static bool Logging(uint level, const char * msg, const char* fileName, int lineNo); 301 | static bool LoggingF(uint level, const char* fileName, int lineNo, const string& fmt, ...); 302 | private: 303 | static char _cStrBuf[CSTR_BUFFER_SIZE]; 304 | static const char * _logLevel[LEVEL_ARRAY_SIZE]; 305 | static const char * _logFormat; 306 | static const char * _timeFormat; 307 | static time_t _timeNow; 308 | }; 309 | } 310 | 311 | #endif 312 | 313 | //////map_functs.h 314 | #ifndef CPPCOMMON_MAP_FUNCTS_H 315 | #define CPPCOMMON_MAP_FUNCTS_H 316 | 317 | namespace CPPCOMMON 318 | { 319 | using namespace std; 320 | 321 | template 322 | string setToString(const set& st) 323 | { 324 | if(st.empty()) 325 | { 326 | return "{}"; 327 | } 328 | stringstream ss; 329 | ss<<'{'; 330 | typename set::const_iterator it = st.begin(); 331 | ss<<*it; 332 | it++; 333 | while(it != st.end()) 334 | { 335 | ss<<", "<<*it; 336 | it++; 337 | } 338 | ss<<'}'; 339 | return ss.str(); 340 | } 341 | 342 | template 343 | string mapToString(const map& mp) 344 | { 345 | if(mp.empty()) 346 | { 347 | 
return "{}"; 348 | } 349 | stringstream ss; 350 | ss<<'{'; 351 | typename map::const_iterator it = mp.begin(); 352 | ss<first<<": "<second; 353 | it++; 354 | while(it != mp.end()) 355 | { 356 | ss<<", "<first<<": "<second; 357 | it++; 358 | } 359 | ss<<'}'; 360 | return ss.str(); 361 | } 362 | 363 | template 364 | string pairToString(const pair& p) 365 | { 366 | stringstream ss; 367 | ss< 372 | void printMap(const map& mp) 373 | { 374 | for(typename map::const_iterator it = mp.begin(); it != mp.end(); it++) 375 | { 376 | cout<first<<' '<second< 381 | vT getMap(const map& mp, const kT & key, const vT & defaultVal) 382 | { 383 | typename map::const_iterator it; 384 | it = mp.find(key); 385 | if(mp.end() == it) 386 | { 387 | return defaultVal; 388 | } 389 | return it->second; 390 | } 391 | 392 | } 393 | 394 | #endif 395 | 396 | //////typedef.h 397 | #ifndef CPPCOMMON_TYPEDEFS_H 398 | #define CPPCOMMON_TYPEDEFS_H 399 | 400 | namespace CPPCOMMON 401 | { 402 | typedef std::vector Unicode; 403 | typedef std::vector::const_iterator UnicodeConstIterator; 404 | } 405 | 406 | #endif 407 | 408 | 409 | //////vec_functs.h 410 | #ifndef CPPCOMMON_VEC_FUNCTS_H 411 | #define CPPCOMMON_VEC_FUNCTS_H 412 | 413 | #define FOR_VECTOR(vec, i) for(size_t i = 0; i < vec.size(); i++) 414 | 415 | #define PRINT_VECTOR(vec) FOR_VECTOR(vec, i)\ 416 | {\ 417 | cout<& source, const string& connector); 447 | vector splitStr(const string& source, const string& pattern = " \t\n"); 448 | bool splitStr(const string& source, vector& res, const string& pattern = " \t\n"); 449 | bool splitStrMultiPatterns( 450 | const string& strSrc, 451 | vector& outVec, 452 | const vector& patterns 453 | ); 454 | string upperStr(const string& str); 455 | string lowerStr(const string& str); 456 | string replaceStr(const string& strSrc, const string& oldStr, const string& newStr, int count = -1); 457 | string stripStr(const string& str, const string& patternstr = " \n\t"); 458 | std::string <rim(std::string &s) ; 459 | 
std::string &rtrim(std::string &s) ; 460 | std::string &trim(std::string &s) ; 461 | unsigned int countStrDistance(const string& A, const string& B); 462 | unsigned int countStrSimilarity(const string& A, const string& B); 463 | 464 | 465 | bool uniStrToVec(const string& str, Unicode& vec); 466 | string uniVecToStr(const Unicode& vec); 467 | 468 | inline uint16_t twocharToUint16(char high, char low); 469 | 470 | inline pair uint16ToChar2(uint16_t in); 471 | 472 | inline void printUnicode(const Unicode& unicode); 473 | 474 | inline bool strStartsWith(const string& str, const string& prefix); 475 | 476 | inline bool strEndsWith(const string& str, const string& suffix); 477 | 478 | } 479 | #endif 480 | 481 | //////globals.h 482 | #ifndef CPPJIEBA_GLOBALS_H 483 | #define CPPJIEBA_GLOBALS_H 484 | 485 | 486 | 487 | namespace CppJieba 488 | { 489 | 490 | using namespace std; 491 | //using std::tr1::unordered_map; 492 | using std::unordered_map; 493 | //using __gnu_cxx::hash_map; 494 | //using namespace stdext; 495 | //typedefs 496 | typedef std::vector::iterator VSI; 497 | typedef std::vector Unicode; 498 | typedef Unicode::const_iterator UniConIter; 499 | typedef unordered_map TrieNodeMap; 500 | typedef unordered_map EmitProbMap; 501 | 502 | const double MIN_DOUBLE = -3.14e+100; 503 | const double MAX_DOUBLE = 3.14e+100; 504 | enum CHAR_TYPE { CHWORD = 0, DIGIT_OR_LETTER = 1, OTHERS = 2}; 505 | } 506 | 507 | #endif 508 | 509 | 510 | //////ChineseFilter.h 511 | #ifndef CPPJIEBA_CHINESEFILTER_H 512 | #define CPPJIEBA_CHINESEFILTER_H 513 | 514 | 515 | namespace CppJieba 516 | { 517 | class ChFilterIterator; 518 | class ChineseFilter 519 | { 520 | public: 521 | typedef ChFilterIterator iterator; 522 | public: 523 | ChineseFilter(); 524 | ~ChineseFilter(); 525 | public: 526 | bool feed(const std::string& str); 527 | iterator begin(); 528 | iterator end(); 529 | private: 530 | Unicode _unico; 531 | private: 532 | //friend class ChFilterIterator; 533 | }; 534 | 535 | class 
ChFilterIterator 536 | { 537 | public: 538 | const Unicode * ptUnico; 539 | UniConIter begin; 540 | UniConIter end; 541 | CHAR_TYPE charType; 542 | ChFilterIterator& operator++(); 543 | ChFilterIterator operator++(int); 544 | bool operator==(const ChFilterIterator& iter); 545 | bool operator!=(const ChFilterIterator& iter); 546 | ChFilterIterator& operator=(const ChFilterIterator& iter); 547 | 548 | public: 549 | ChFilterIterator(const Unicode * ptu, UniConIter be, UniConIter en, CHAR_TYPE is):ptUnico(ptu), begin(be), end(en), charType(is){}; 550 | ChFilterIterator(const Unicode * ptu):ptUnico(ptu){*this = _get(ptUnico->begin());}; 551 | private: 552 | ChFilterIterator(); 553 | private: 554 | CHAR_TYPE _charType(uint16_t x)const; 555 | ChFilterIterator _get(UniConIter iter); 556 | 557 | }; 558 | } 559 | 560 | 561 | #endif 562 | 563 | 564 | //////SegmentInterface.h 565 | #ifndef CPPJIEBA_SEGMENTINTERFACE_H 566 | #define CPPJIEBA_SEGMENTINTERFACE_H 567 | 568 | namespace CppJieba 569 | { 570 | class SegmentInterface 571 | { 572 | //public: 573 | // virtual ~SegmentInterface(){}; 574 | public: 575 | virtual bool cut(Unicode::const_iterator begin , Unicode::const_iterator end, vector& res) const = 0; 576 | virtual bool cut(const string& str, vector& res) const = 0; 577 | }; 578 | } 579 | 580 | #endif 581 | 582 | //////SegmentBase.h 583 | #ifndef CPPJIEBA_SEGMENTBASE_H 584 | #define CPPJIEBA_SEGMENTBASE_H 585 | 586 | namespace CppJieba 587 | { 588 | using namespace CPPCOMMON; 589 | class SegmentBase: public SegmentInterface 590 | { 591 | public: 592 | SegmentBase(){_setInitFlag(false);}; 593 | virtual ~SegmentBase(){}; 594 | private: 595 | bool _isInited; 596 | protected: 597 | bool _getInitFlag()const{return _isInited;}; 598 | bool _setInitFlag(bool flag){return _isInited = flag;}; 599 | bool cut(const string& str, vector& res)const; 600 | bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const = 0; 601 | 602 | }; 603 | } 604 | 605 | 
#endif 606 | 607 | 608 | 609 | //////HMMSegment.h 610 | #ifndef CPPJIBEA_HMMSEGMENT_H 611 | #define CPPJIBEA_HMMSEGMENT_H 612 | 613 | 614 | 615 | namespace CppJieba 616 | { 617 | using namespace CPPCOMMON; 618 | class HMMSegment: public SegmentBase 619 | { 620 | public: 621 | /* 622 | * STATUS (per-character HMM states; presumably Begin/End/Middle/Single position within a word -- confirm against the model file): 623 | * 0:B, 1:E, 2:M, 3:S 624 | * */ 625 | enum {B = 0, E = 1, M = 2, S = 3, STATUS_SUM = 4}; 626 | private: 627 | char _statMap[STATUS_SUM]; 628 | double _startProb[STATUS_SUM]; 629 | double _transProb[STATUS_SUM][STATUS_SUM]; 630 | EmitProbMap _emitProbB; 631 | EmitProbMap _emitProbE; 632 | EmitProbMap _emitProbM; 633 | EmitProbMap _emitProbS; 634 | vector _emitProbVec; 635 | 636 | public: 637 | HMMSegment(); 638 | virtual ~HMMSegment(); 639 | public: 640 | bool init(const char* const modelPath); 641 | bool dispose(); 642 | public: 643 | bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const ; 644 | bool cut(const string& str, vector& res)const; 645 | bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const; 646 | //virtual bool cut(const string& str, vector& res)const; 647 | 648 | private: 649 | bool _viterbi(Unicode::const_iterator begin, Unicode::const_iterator end, vector& status)const; 650 | bool _loadModel(const char* const filePath); 651 | bool _getLine(ifstream& ifile, string& line); 652 | bool _loadEmitProb(const string& line, EmitProbMap& mp); 653 | bool _decodeOne(const string& str, uint16_t& res); 654 | double _getEmitProb(const EmitProbMap* ptMp, uint16_t key, double defVal)const ; 655 | 656 | 657 | }; 658 | } 659 | 660 | #endif 661 | 662 | 663 | 664 | 665 | 666 | 667 | 668 | 669 | 670 | //////TransCode.h 671 | #ifndef CPPJIEBA_TRANSCODE_H 672 | #define CPPJIEBA_TRANSCODE_H 673 | 674 | typedef unsigned int uint; 675 | 676 | namespace CppJieba 677 | { 678 | 679 | using namespace CPPCOMMON; 680 | namespace TransCode 681 | { 682 | inline bool decode(const string& str, vector& vec) 683 | { 684 | char
ch1, ch2; 685 | if(str.empty()) 686 | { 687 | return false; 688 | } 689 | vec.clear(); 690 | size_t siz = str.size(); 691 | for(uint i = 0;i < siz;) 692 | { 693 | if(!(str[i] & 0x80)) // 0xxxxxxx: 1-byte (ASCII) 694 | { 695 | vec.push_back(str[i]); 696 | i++; 697 | } 698 | else if ((unsigned char)str[i] <= 0xdf && i + 1 < siz) // 110xxxxx 10xxxxxx: 2-byte sequence. NOTE(review): stray continuation bytes 0x80-0xBF also take this branch instead of being rejected 699 | { 700 | ch1 = (str[i] >> 2) & 0x07; 701 | ch2 = (str[i+1] & 0x3f) | ((str[i] & 0x03) << 6 ); 702 | vec.push_back(twocharToUint16(ch1, ch2)); 703 | i += 2; 704 | } 705 | else if((unsigned char)str[i] <= 0xef && i + 2 < siz) // 1110xxxx 10xxxxxx 10xxxxxx: 3-byte sequence 706 | { 707 | ch1 = (str[i] << 4) | ((str[i+1] >> 2) & 0x0f ); 708 | ch2 = ((str[i+1]<<6) & 0xc0) | (str[i+2] & 0x3f); 709 | vec.push_back(twocharToUint16(ch1, ch2)); 710 | i += 3; 711 | } 712 | else 713 | { 714 | return false; // 4-byte sequences (lead byte > 0xEF) and truncated sequences are rejected; vec keeps the units decoded so far 715 | } 716 | } 717 | return true; 718 | } 719 | 720 | 721 | inline bool encode(vector::const_iterator begin, vector::const_iterator end, string& res) 722 | { 723 | if(begin >= end) // empty input: returns false and leaves res untouched 724 | { 725 | return false; 726 | } 727 | res.clear(); 728 | uint16_t ui; 729 | while(begin != end) 730 | { 731 | ui = *begin; 732 | if(ui <= 0x7f) 733 | { 734 | res += char(ui); 735 | } 736 | else if(ui <= 0x7ff) 737 | { 738 | res += char(((ui>>6) & 0x1f) | 0xc0); 739 | res += char((ui & 0x3f) | 0x80); 740 | } 741 | else 742 | { 743 | res += char(((ui >> 12) & 0x0f )| 0xe0); 744 | res += char(((ui>>6) & 0x3f )| 0x80 ); 745 | res += char((ui & 0x3f) | 0x80); 746 | } 747 | begin ++; 748 | } 749 | return true; 750 | } 751 | inline bool encode(const vector& sentence, string& res) 752 | { 753 | return encode(sentence.begin(), sentence.end(), res); 754 | } 755 | } 756 | } 757 | 758 | #endif 759 | 760 | 761 | 762 | 763 | //////structs.h 764 | #ifndef CPPJIEBA_STRUCTS_H 765 | #define CPPJIEBA_STRUCTS_H 766 | 767 | 768 | namespace CppJieba 769 | { 770 | 771 | struct TrieNodeInfo 772 | { 773 | //string word; 774 | //size_t wLen;// the word's len , not string.length(), 775 | Unicode word; 776 | size_t freq; 777 | string tag; 778 
| double logFreq; //logFreq = log(freq/sum(freq)); 779 | TrieNodeInfo():freq(0),logFreq(0.0) 780 | { 781 | } 782 | TrieNodeInfo(const TrieNodeInfo& nodeInfo):word(nodeInfo.word), freq(nodeInfo.freq), tag(nodeInfo.tag), logFreq(nodeInfo.logFreq) 783 | { 784 | } 785 | TrieNodeInfo(const Unicode& _word):word(_word),freq(0),logFreq(MIN_DOUBLE) 786 | { 787 | } 788 | string toString()const 789 | { 790 | string tmp; 791 | TransCode::encode(word, tmp); 792 | return string_format("{word:%s,freq:%zu, logFreq:%lf}", tmp.c_str(), freq, logFreq); // freq is size_t, so %zu (passing it for %d is undefined behavior) 793 | } 794 | }; 795 | 796 | typedef unordered_map DagType; 797 | struct SegmentChar 798 | { 799 | uint16_t uniCh; 800 | DagType dag; 801 | const TrieNodeInfo * pInfo; 802 | double weight; 803 | 804 | SegmentChar(uint16_t uni):uniCh(uni), pInfo(NULL), weight(0.0) 805 | { 806 | } 807 | 808 | /*const TrieNodeInfo* pInfo; 809 | double weight; 810 | SegmentChar(uint16_t unich, const TrieNodeInfo* p, double w):uniCh(unich), pInfo(p), weight(w) 811 | { 812 | }*/ 813 | }; 814 | /* 815 | struct SegmentContext 816 | { 817 | vector context; 818 | bool getDA 819 | };*/ 820 | typedef vector SegmentContext; 821 | 822 | 823 | struct KeyWordInfo: public TrieNodeInfo 824 | { 825 | double idf; 826 | double weight;// log(wLen+1)*logFreq; 827 | KeyWordInfo():idf(0.0),weight(0.0) 828 | { 829 | } 830 | KeyWordInfo(const Unicode& _word):TrieNodeInfo(_word),idf(0.0),weight(0.0) 831 | { 832 | } 833 | KeyWordInfo(const TrieNodeInfo& trieNodeInfo):TrieNodeInfo(trieNodeInfo),idf(0.0),weight(0.0) // initialize idf/weight like the other constructors instead of leaving them indeterminate 834 | { 835 | } 836 | inline string toString() const 837 | { 838 | string tmp; 839 | TransCode::encode(word, tmp); 840 | return string_format("{word:%s,weight:%lf, idf:%lf}", tmp.c_str(), weight, idf); 841 | } 842 | KeyWordInfo& operator = (const TrieNodeInfo& trieNodeInfo) 843 | { 844 | word = trieNodeInfo.word; 845 | freq = trieNodeInfo.freq; 846 | tag = trieNodeInfo.tag; 847 | logFreq = trieNodeInfo.logFreq; 848 | return *this; 849 | } 850 | }; 851 | 852 | inline string joinWordInfos(const 
vector& vec) 853 | { 854 | vector tmp; 855 | for(uint i = 0; i < vec.size(); i++) 856 | { 857 | tmp.push_back(vec[i].toString()); 858 | } 859 | return joinStr(tmp, ","); 860 | } 861 | } 862 | 863 | #endif 864 | 865 | 866 | //////Trie.h 867 | #ifndef CPPJIEBA_TRIE_H 868 | #define CPPJIEBA_TRIE_H 869 | 870 | namespace CppJieba 871 | { 872 | using namespace CPPCOMMON; 873 | struct TrieNode 874 | { 875 | TrieNodeMap hmap; 876 | bool isLeaf; 877 | uint nodeInfoVecPos; 878 | TrieNode() 879 | { 880 | isLeaf = false; 881 | nodeInfoVecPos = 0; 882 | } 883 | }; 884 | 885 | class Trie 886 | { 887 | 888 | private: 889 | TrieNode* _root; 890 | vector _nodeInfoVec; 891 | 892 | bool _initFlag; 893 | int64_t _freqSum; 894 | double _minLogFreq; 895 | 896 | public: 897 | Trie(); 898 | ~Trie(); // NOTE(review): _root is a raw TrieNode* and copying is not disabled; if this destructor deletes the nodes, copying a Trie (or an MPSegment that holds one) would free them twice -- confirm 899 | bool init(); 900 | bool loadDict(const char * const filePath); 901 | bool dispose(); 902 | 903 | private: 904 | void _setInitFlag(bool on){_initFlag = on;}; 905 | bool _getInitFlag()const{return _initFlag;}; 906 | 907 | public: 908 | const TrieNodeInfo* find(const string& str)const; 909 | const TrieNodeInfo* find(const Unicode& uintVec)const; 910 | const TrieNodeInfo* find(Unicode::const_iterator begin, Unicode::const_iterator end)const; 911 | bool find(const Unicode& unico, vector >& res)const; 912 | 913 | const TrieNodeInfo* findPrefix(const string& str)const; 914 | 915 | public: 916 | //double getWeight(const string& str); 917 | //double getWeight(const Unicode& uintVec); 918 | //double getWeight(Unicode::const_iterator begin, Unicode::const_iterator end); 919 | double getMinLogFreq()const{return _minLogFreq;}; 920 | 921 | //int64_t getTotalCount(){return _freqSum;}; 922 | 923 | bool insert(const TrieNodeInfo& nodeInfo); 924 | 925 | private: 926 | bool _trieInsert(const char * const filePath); 927 | bool _countWeight(); 928 | bool _deleteNode(TrieNode* node); 929 | 930 | }; 931 | } 932 | 933 | #endif 934 | 935 | //////MPSegment.h 936 | #ifndef CPPJIEBA_MPSEGMENT_H 937 | #define 
CPPJIEBA_MPSEGMENT_H 938 | 939 | 940 | 941 | namespace CppJieba 942 | { 943 | 944 | typedef vector SegmentContext; 945 | 946 | class MPSegment: public SegmentBase 947 | { 948 | private: 949 | Trie _trie; 950 | 951 | public: 952 | MPSegment(); 953 | virtual ~MPSegment(); 954 | public: 955 | bool init(const char* const filePath); 956 | bool dispose(); 957 | public: 958 | //bool cut(const string& str, vector& segWordInfos)const; 959 | bool cut(const string& str, vector& res)const; 960 | bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const; 961 | bool cut(const string& str, vector& segWordInfos)const; 962 | bool cut(Unicode::const_iterator begin , Unicode::const_iterator end, vector& segWordInfos)const; 963 | //virtual bool cut(const string& str, vector& res)const; 964 | 965 | private: 966 | bool _calcDAG(SegmentContext& segContext)const; 967 | bool _calcDP(SegmentContext& segContext)const; 968 | bool _cut(SegmentContext& segContext, vector& res)const; 969 | 970 | 971 | }; 972 | } 973 | 974 | #endif 975 | 976 | 977 | //////MixSegment.h 978 | #ifndef CPPJIEBA_MIXSEGMENT_H 979 | #define CPPJIEBA_MIXSEGMENT_H 980 | 981 | namespace CppJieba 982 | { 983 | class MixSegment: public SegmentBase 984 | { 985 | private: 986 | MPSegment _mpSeg; 987 | HMMSegment _hmmSeg; 988 | public: 989 | MixSegment(); 990 | virtual ~MixSegment(); 991 | public: 992 | bool init(const char* const _mpSegDict, const char* const _hmmSegDict); 993 | bool dispose(); 994 | public: 995 | //virtual bool cut(const string& str, vector& res) const; 996 | bool cut(const string& str, vector& res)const; 997 | bool cut(Unicode::const_iterator begin, Unicode::const_iterator end, vector& res)const; 998 | }; 999 | } 1000 | 1001 | #endif 1002 | 1003 | //////KeyWordExt.h 1004 | #ifndef CPPJIEBA_KEYWORDEXT_H 1005 | #define CPPJIEBA_KEYWORDEXT_H 1006 | 1007 | namespace CppJieba 1008 | { 1009 | 1010 | class KeyWordExt 1011 | { 1012 | private: 1013 | MPSegment _segment; 1014 | //vector
_priorSubWords; 1015 | set _stopWords; 1016 | public: 1017 | KeyWordExt(); 1018 | ~KeyWordExt(); 1019 | bool init(const char* const segDictFile); 1020 | bool dispose(); 1021 | bool loadStopWords(const char * const filePath); 1022 | private: 1023 | //bool _loadPriorSubWords(const char * const filePath); 1024 | 1025 | 1026 | public: 1027 | bool extract(const string& title, vector& keyWordInfos, uint topN); 1028 | bool extract(const vector& words, vector& keyWordInfos, uint topN); 1029 | private: 1030 | static bool _wordInfoCompare(const KeyWordInfo& a, const KeyWordInfo& b); 1031 | private: 1032 | bool _extract(vector& keyWordInfos, uint topN); 1033 | bool _extTopN(vector& wordInfos, uint topN); 1034 | private: 1035 | //sort by word len - idf 1036 | bool _sortWLIDF(vector& wordInfos); 1037 | private: 1038 | bool _filter(vector& ); 1039 | bool _filterDuplicate(vector& ); 1040 | bool _filterSingleWord(vector& ); 1041 | bool _filterSubstr(vector& ); 1042 | bool _filterStopWords(vector& ); 1043 | private: 1044 | inline bool _isSubIn(const vector& words, const Unicode& word)const 1045 | { 1046 | 1047 | for(uint j = 0; j < words.size(); j++) 1048 | { 1049 | if(word != words[j] && words[j].end() != search(words[j].begin(), words[j].end(), word.begin(), word.end())) 1050 | { 1051 | return true; 1052 | } 1053 | } 1054 | return false; 1055 | } 1056 | //bool _prioritizeSubWords(vector& wordInfos); 1057 | //bool _isContainSubWords(const string& word); 1058 | 1059 | }; 1060 | 1061 | } 1062 | 1063 | #endif 1064 | 1065 | 1066 | 1067 | 1068 | 1069 | -------------------------------------------------------------------------------- /include/cppMSGmodule.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | using namespace std; // NOTE(review): a using-directive in a header leaks namespace std into every file that includes it 5 | 6 | class CppMsgModule { 7 | public: 8 | static bool msgStrToPcharArray(char**& pchar_array, int& count_array, const string& msg_str); 9 | 10 | static bool idArrayToMsgStr(string& msg_str, 
const int* p_id, const int count_id); 11 | }; 12 | -------------------------------------------------------------------------------- /include/libClassifier.h: -------------------------------------------------------------------------------- 1 | // #include "stdafx.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | //libsvm code 15 | #ifndef _LIBSVM_H 16 | #define _LIBSVM_H 17 | 18 | #define LIBSVM_VERSION 312 19 | 20 | #ifdef __cplusplus 21 | extern "C" { 22 | #endif 23 | 24 | extern int libsvm_version; 25 | 26 | struct svm_node 27 | { 28 | int index; 29 | double value; 30 | }; 31 | 32 | struct svm_problem 33 | { 34 | int l; 35 | double *y; 36 | struct svm_node **x; 37 | }; 38 | 39 | enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR }; /* svm_type */ 40 | enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED }; /* kernel_type */ 41 | 42 | struct svm_parameter 43 | { 44 | int svm_type; 45 | int kernel_type; 46 | int degree; /* for poly */ 47 | double gamma; /* for poly/rbf/sigmoid */ 48 | double coef0; /* for poly/sigmoid */ 49 | 50 | /* these are for training only */ 51 | double cache_size; /* in MB */ 52 | double eps; /* stopping criteria */ 53 | double C; /* for C_SVC, EPSILON_SVR and NU_SVR */ 54 | int nr_weight; /* for C_SVC */ 55 | int *weight_label; /* for C_SVC */ 56 | double* weight; /* for C_SVC */ 57 | double nu; /* for NU_SVC, ONE_CLASS, and NU_SVR */ 58 | double p; /* for EPSILON_SVR */ 59 | int shrinking; /* use the shrinking heuristics */ 60 | int probability; /* do probability estimates */ 61 | }; 62 | 63 | // 64 | // svm_model 65 | // 66 | struct svm_model 67 | { 68 | struct svm_parameter param; /* parameter */ 69 | int nr_class; /* number of classes, = 2 in regression/one class svm */ 70 | int l; /* total #SV */ 71 | struct svm_node **SV; /* SVs (SV[l]) */ 72 | double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */ 73 | double *rho; /* 
constants in decision functions (rho[k*(k-1)/2]) */ 74 | double *probA; /* pariwise probability information */ 75 | double *probB; 76 | 77 | /* for classification only */ 78 | 79 | int *label; /* label of each class (label[k]) */ 80 | int *nSV; /* number of SVs for each class (nSV[k]) */ 81 | /* nSV[0] + nSV[1] + ... + nSV[k-1] = l */ 82 | /* XXX */ 83 | int free_sv; /* 1 if svm_model is created by svm_load_model*/ 84 | /* 0 if svm_model is created by svm_train */ 85 | }; 86 | 87 | struct svm_model *svm_train(const struct svm_problem *prob, const struct svm_parameter *param); 88 | void svm_cross_validation(const struct svm_problem *prob, const struct svm_parameter *param, int nr_fold, double *target); 89 | 90 | int svm_save_model(const char *model_file_name, const struct svm_model *model); 91 | struct svm_model *svm_load_model(const char *model_file_name); 92 | 93 | int svm_get_svm_type(const struct svm_model *model); 94 | int svm_get_nr_class(const struct svm_model *model); 95 | void svm_get_labels(const struct svm_model *model, int *label); 96 | double svm_get_svr_probability(const struct svm_model *model); 97 | 98 | double svm_predict_values(const struct svm_model *model, const struct svm_node *x, double* dec_values); 99 | double svm_predict(const struct svm_model *model, const struct svm_node *x); 100 | double svm_predict_probability(const struct svm_model *model, const struct svm_node *x, double* prob_estimates); 101 | 102 | void svm_free_model_content(struct svm_model *model_ptr); 103 | void svm_free_and_destroy_model(struct svm_model **model_ptr_ptr); 104 | void svm_destroy_param(struct svm_parameter *param); 105 | 106 | const char *svm_check_parameter(const struct svm_problem *prob, const struct svm_parameter *param); 107 | int svm_check_probability_model(const struct svm_model *model); 108 | 109 | void svm_set_print_string_function(void (*print_func)(const char *)); 110 | 111 | #ifdef __cplusplus 112 | } 113 | #endif 114 | 115 | #endif /* _LIBSVM_H */ 116 
--------------------------------------------------------------------------------
/include/titlebook.pb.h:
--------------------------------------------------------------------------------
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: titlebook.proto

// NOTE(review): the Makefile rule for $(SRC)/titlebook.pb.cc regenerates this
// header from src/titlebook.proto with protoc and moves it into include/.
// Any hand edit here is lost on regeneration; change the .proto instead.
// In this dump the `#include` targets and the `reinterpret_cast` template
// arguments were stripped.

#ifndef PROTOBUF_titlebook_2eproto__INCLUDED
#define PROTOBUF_titlebook_2eproto__INCLUDED

#include

#include

// Version handshake: the header refuses to compile against protobuf runtime
// headers that are incompatible with the protoc (2.5.0) that generated it.
#if GOOGLE_PROTOBUF_VERSION < 2005000
#error This file was generated by a newer version of protoc which is
#error incompatible with your Protocol Buffer headers. Please update
#error your headers.
#endif
#if 2005000 < GOOGLE_PROTOBUF_MIN_PROTOC_VERSION
#error This file was generated by an older version of protoc which is
#error incompatible with your Protocol Buffer headers. Please
#error regenerate this file with a newer version of protoc.
#endif

#include
#include
#include
#include
#include
// @@protoc_insertion_point(includes)

namespace tutorial {

// Internal implementation detail -- do not call these.
void protobuf_AddDesc_titlebook_2eproto();
void protobuf_AssignDesc_titlebook_2eproto();
void protobuf_ShutdownFile_titlebook_2eproto();

class TitleList;
class IdList;

// ===================================================================

// Message with one field: `repeated string title = 1`. The Python side fills
// it via pyMSGmodule.strListToMsgStr().
class TitleList : public ::google::protobuf::Message {
 public:
  TitleList();
  virtual ~TitleList();

  TitleList(const TitleList& from);

  inline TitleList& operator=(const TitleList& from) {
    CopyFrom(from);
    return *this;
  }

  inline const ::google::protobuf::UnknownFieldSet& unknown_fields() const {
    return _unknown_fields_;
  }

  inline ::google::protobuf::UnknownFieldSet* mutable_unknown_fields() {
    return &_unknown_fields_;
  }

  static const ::google::protobuf::Descriptor* descriptor();
  static const TitleList& default_instance();

  void Swap(TitleList* other);

  // implements Message ----------------------------------------------

  TitleList* New() const;
  void CopyFrom(const ::google::protobuf::Message& from);
  void MergeFrom(const ::google::protobuf::Message& from);
  void CopyFrom(const TitleList& from);
  void MergeFrom(const TitleList& from);
  void Clear();
  bool IsInitialized() const;

  int ByteSize() const;
  bool MergePartialFromCodedStream(
      ::google::protobuf::io::CodedInputStream* input);
  void SerializeWithCachedSizes(
      ::google::protobuf::io::CodedOutputStream* output) const;
  ::google::protobuf::uint8* SerializeWithCachedSizesToArray(::google::protobuf::uint8* output) const;
  int GetCachedSize() const { return _cached_size_; }
  private:
  void SharedCtor();
  void SharedDtor();
  void SetCachedSize(int size) const;
  public:

  ::google::protobuf::Metadata GetMetadata() const;

  // nested types ----------------------------------------------------

  // accessors -------------------------------------------------------

  // repeated string title = 1;
  inline int title_size() const;
  inline void clear_title();
  static const int kTitleFieldNumber = 1;
  inline const ::std::string& title(int index) const;
  inline ::std::string* mutable_title(int index);
  inline void set_title(int index, const ::std::string& value);
  inline void set_title(int index, const char* value);
  inline void set_title(int index, const char* value, size_t size);
  inline ::std::string* add_title();
  inline void add_title(const ::std::string& value);
  inline void add_title(const char* value);
  inline void add_title(const char* value, size_t size);
  inline const ::google::protobuf::RepeatedPtrField< ::std::string>& title() const;
  inline ::google::protobuf::RepeatedPtrField< ::std::string>* mutable_title();

  // @@protoc_insertion_point(class_scope:tutorial.TitleList)
 private:

  ::google::protobuf::UnknownFieldSet _unknown_fields_;

  ::google::protobuf::RepeatedPtrField< ::std::string> title_;

  mutable int _cached_size_;
  ::google::protobuf::uint32 _has_bits_[(1 + 31) / 32];

  friend void protobuf_AddDesc_titlebook_2eproto();
  friend void protobuf_AssignDesc_titlebook_2eproto();
  friend void protobuf_ShutdownFile_titlebook_2eproto();

  void InitAsDefaultInstance();
  static TitleList* default_instance_;
};
// -------------------------------------------------------------------

// Message with one field: `repeated int32 id = 1`. The Python side decodes
// it via pyMSGmodule.msgStrToIdList().
class IdList : public ::google::protobuf::Message {
 public:
  IdList();
  virtual ~IdList();

  IdList(const IdList& from);

  inline IdList& operator=(const IdList& from) {
    CopyFrom(from);
    return *this;
  }

  inline const ::google::protobuf::UnknownFieldSet& unknown_fields() const {
    return _unknown_fields_;
  }

  inline ::google::protobuf::UnknownFieldSet* mutable_unknown_fields() {
    return &_unknown_fields_;
  }

  static const ::google::protobuf::Descriptor* descriptor();
  static const IdList& default_instance();

  void Swap(IdList* other);

  // implements Message ----------------------------------------------

  IdList* New() const;
  void CopyFrom(const ::google::protobuf::Message& from);
  void MergeFrom(const ::google::protobuf::Message& from);
  void CopyFrom(const IdList& from);
  void MergeFrom(const IdList& from);
  void Clear();
  bool IsInitialized() const;

  int ByteSize() const;
  bool MergePartialFromCodedStream(
      ::google::protobuf::io::CodedInputStream* input);
  void SerializeWithCachedSizes(
      ::google::protobuf::io::CodedOutputStream* output) const;
  ::google::protobuf::uint8* SerializeWithCachedSizesToArray(::google::protobuf::uint8* output) const;
  int GetCachedSize() const { return _cached_size_; }
  private:
  void SharedCtor();
  void SharedDtor();
  void SetCachedSize(int size) const;
  public:

  ::google::protobuf::Metadata GetMetadata() const;

  // nested types ----------------------------------------------------

  // accessors -------------------------------------------------------

  // repeated int32 id = 1;
  inline int id_size() const;
  inline void clear_id();
  static const int kIdFieldNumber = 1;
  inline ::google::protobuf::int32 id(int index) const;
  inline void set_id(int index, ::google::protobuf::int32 value);
  inline void add_id(::google::protobuf::int32 value);
  inline const ::google::protobuf::RepeatedField< ::google::protobuf::int32 >&
      id() const;
  inline ::google::protobuf::RepeatedField< ::google::protobuf::int32 >*
      mutable_id();

  // @@protoc_insertion_point(class_scope:tutorial.IdList)
 private:

  ::google::protobuf::UnknownFieldSet _unknown_fields_;

  ::google::protobuf::RepeatedField< ::google::protobuf::int32 > id_;

  mutable int _cached_size_;
  ::google::protobuf::uint32 _has_bits_[(1 + 31) / 32];

  friend void protobuf_AddDesc_titlebook_2eproto();
  friend void protobuf_AssignDesc_titlebook_2eproto();
  friend void protobuf_ShutdownFile_titlebook_2eproto();

  void InitAsDefaultInstance();
  static IdList* default_instance_;
};
// ===================================================================


// ===================================================================

// TitleList
// Inline accessors: thin forwards to the RepeatedPtrField `title_`.

// repeated string title = 1;
inline int TitleList::title_size() const {
  return title_.size();
}
inline void TitleList::clear_title() {
  title_.Clear();
}
inline const ::std::string& TitleList::title(int index) const {
  return title_.Get(index);
}
inline ::std::string* TitleList::mutable_title(int index) {
  return title_.Mutable(index);
}
inline void TitleList::set_title(int index, const ::std::string& value) {
  title_.Mutable(index)->assign(value);
}
inline void TitleList::set_title(int index, const char* value) {
  title_.Mutable(index)->assign(value);
}
inline void TitleList::set_title(int index, const char* value, size_t size) {
  title_.Mutable(index)->assign(
    reinterpret_cast(value), size);
}
inline ::std::string* TitleList::add_title() {
  return title_.Add();
}
inline void TitleList::add_title(const ::std::string& value) {
  title_.Add()->assign(value);
}
inline void TitleList::add_title(const char* value) {
  title_.Add()->assign(value);
}
inline void TitleList::add_title(const char* value, size_t size) {
  title_.Add()->assign(reinterpret_cast(value), size);
}
inline const ::google::protobuf::RepeatedPtrField< ::std::string>& 256 | TitleList::title() const { 257 | return title_; 258 | } 259 | inline ::google::protobuf::RepeatedPtrField< ::std::string>* 260 | TitleList::mutable_title() { 261 | return &title_; 262 | } 263 | 264 | // ------------------------------------------------------------------- 265 | 266 | // IdList 267 | 268 | // repeated int32 id = 1; 269 | inline int IdList::id_size() const { 270 | return id_.size(); 271 | } 272 | inline void IdList::clear_id() { 273 | id_.Clear(); 274 | } 275 | inline ::google::protobuf::int32 IdList::id(int index) const { 276 | return id_.Get(index); 277 | } 278 | inline void IdList::set_id(int index, ::google::protobuf::int32 value) { 279 | id_.Set(index, value); 280 | } 281 | inline void IdList::add_id(::google::protobuf::int32 value) { 282 | id_.Add(value); 283 | } 284 | inline const ::google::protobuf::RepeatedField< ::google::protobuf::int32 >& 285 | IdList::id() const { 286 | return id_; 287 | } 288 | inline ::google::protobuf::RepeatedField< ::google::protobuf::int32 >* 289 | IdList::mutable_id() { 290 | return &id_; 291 | } 292 | 293 | 294 | // @@protoc_insertion_point(namespace_scope) 295 | 296 | } // namespace tutorial 297 | 298 | #ifndef SWIG 299 | namespace google { 300 | namespace protobuf { 301 | 302 | 303 | } // namespace google 304 | } // namespace protobuf 305 | #endif // SWIG 306 | 307 | // @@protoc_insertion_point(global_scope) 308 | 309 | #endif // PROTOBUF_titlebook_2eproto__INCLUDED 310 | -------------------------------------------------------------------------------- /pyscript/Leader_asynchronous.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #coding:utf-8 3 | 4 | import zmq 5 | import sys 6 | sys.path.insert(0, 'build') 7 | import pyMSGmodule as MSG 8 | 9 | 10 | # Worker_port = 5555 11 | Secretary_port = 8964 12 | 13 | 14 | if __name__=='__main__': 15 | 16 | # Client is created with a socket 
type “zmq.REQ” 17 | Leader_Req_context = zmq.Context() 18 | print "Connecting to server..." 19 | Leader_Req_socket = Leader_Req_context.socket(zmq.REQ) 20 | Leader_Req_socket.connect("tcp://localhost:%d" % Secretary_port) 21 | 22 | str_list=["教育部考试中心托福网考网上报名", 23 | "皇马6-4马竞登顶欧冠", 24 | "evernote 安装最新版本后,个别笔记本无法同步?", 25 | "ios私有api 能修改运营商名称吗?", 26 | "提前博弈A股纳入MSCI"]*10 27 | 28 | # serialize str_list into msg_str 29 | msg_str='' 30 | try: 31 | # msg_str=strListToMsgStr(str_list+[1]) 32 | msg_str=MSG.strListToMsgStr(str_list) 33 | print msg_str 34 | except Exception as error: 35 | print error 36 | exit() 37 | 38 | # #send request to CPP Server 39 | uni_id = '12345' 40 | Leader_Req_socket.send(uni_id+' '+msg_str) 41 | 42 | # #get the reply from CPP Server 43 | # function is blocked here until having reply 44 | secretary_msg = Leader_Req_socket.recv() 45 | print '\n\n', secretary_msg, '\n' 46 | -------------------------------------------------------------------------------- /pyscript/Leader_synchronous.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #coding:utf-8 3 | 4 | import zmq 5 | import sys 6 | sys.path.insert(0, 'build') 7 | import pyMSGmodule as MSG 8 | 9 | 10 | Worker_port = 5555 11 | 12 | 13 | if __name__=='__main__': 14 | 15 | # Client is created with a socket type “zmq.REQ” 16 | Leader_Req_context = zmq.Context() 17 | print "Connecting to server..." 
18 | Leader_Req_socket = Leader_Req_context.socket(zmq.REQ) 19 | Leader_Req_socket.connect("tcp://localhost:%d" % Worker_port) 20 | 21 | str_list=["教育部考试中心托福网考网上报名", 22 | "皇马6-4马竞登顶欧冠", 23 | "evernote 安装最新版本后,个别笔记本无法同步?", 24 | "ios私有api 能修改运营商名称吗?", 25 | "提前博弈A股纳入MSCI"]*10 26 | 27 | # serialize str_list into msg_str 28 | msg_str='' 29 | try: 30 | # msg_str=strListToMsgStr(str_list+[1]) 31 | msg_str=MSG.strListToMsgStr(str_list) 32 | print msg_str 33 | except Exception as error: 34 | print error 35 | exit() 36 | 37 | # #send request to CPP Server 38 | Leader_Req_socket.send (msg_str) 39 | 40 | # #get the reply from CPP Server 41 | # function is blocked here until having reply 42 | message = Leader_Req_socket.recv() 43 | print '\n\n', message, '\n' 44 | id_list = MSG.msgStrToIdList(message) 45 | print id_list 46 | -------------------------------------------------------------------------------- /pyscript/Secretary_asynchronous.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #coding:utf-8 3 | 4 | import zmq 5 | import time 6 | import sys 7 | from Queue import Queue 8 | from threading import Thread 9 | 10 | import sys 11 | sys.path.insert(0, 'build') 12 | import pyMSGmodule as MSG 13 | 14 | 15 | Leader_port = 8964 16 | Worker_port = 5555 17 | 18 | 19 | msg_queue = Queue() 20 | 21 | # ( A **Secretary** will say 'OK' immediately, when **Leader** assigns a task. ) 22 | # An important task of **Secretary** is his scheduling of **Worker** 23 | # when the **Worker** is available && there is some task to do (in message_queue), 24 | # the **Worker** will get his task by **Secretary** 25 | # 26 | # secretaryThread() will implement the preceding described works. 
27 | # 28 | # But waiting for the **Worker** finish this 29 | def secretaryThread(m_queue, flag): 30 | while True: 31 | if m_queue.qsize() > 0: 32 | uni_id, real_msg = m_queue.get() 33 | 34 | # give **Worker** his task 35 | # thread program will be blocked here 36 | # print 'dododododdo' 37 | # time.sleep(10) 38 | Worker_Req_context = zmq.Context() 39 | Worker_Req_socket = Worker_Req_context.socket(zmq.REQ) 40 | Worker_Req_socket.connect("tcp://localhost:%d" % Worker_port) 41 | 42 | leader_msg_str = real_msg[:] 43 | Worker_Req_socket.send(leader_msg_str) 44 | worker_msg = Worker_Req_socket.recv() 45 | # print worker_msg 46 | id_list = MSG.msgStrToIdList(worker_msg) 47 | print uni_id, '--', id_list 48 | 49 | else: 50 | time.sleep(1) 51 | 52 | # set Secretary as Deamon Thread 53 | Secretary = Thread(target=secretaryThread, args=(msg_queue, 1)) 54 | Secretary.setDaemon(True) 55 | Secretary.start() 56 | 57 | Leader_rep_context = zmq.Context() 58 | Leader_rep_socket = Leader_rep_context.socket(zmq.REP) 59 | Leader_rep_socket.bind("tcp://*:%d" % Leader_port) 60 | 61 | # wait for the Leader's command && say OK. 62 | while True: 63 | message = Leader_rep_socket.recv() 64 | uni_id = message[:message.find(' ')] 65 | real_msg = message[message.find(' ')+1:] 66 | 67 | msg_queue.put([uni_id, real_msg]) 68 | 69 | print 'Received requiest', message 70 | Leader_rep_socket.send('OK, I will do it later.') 71 | -------------------------------------------------------------------------------- /pyscript/pyMSGmodule.py: -------------------------------------------------------------------------------- 1 | # ! 
/usr/bin/python 2 | # coding:utf-8 3 | 4 | import titlebook_pb2 5 | 6 | def strListToMsgStr(str_list): 7 | if not isinstance(str_list, list): 8 | raise ValueError(str(str_list)+' is not list') 9 | 10 | title_book=titlebook_pb2.TitleList() 11 | for str_item in str_list: 12 | if not isinstance(str_item, str): 13 | raise ValueError(str(str_item)+' is not str') 14 | title_book.title.append(str_item) 15 | 16 | return title_book.SerializeToString() 17 | 18 | 19 | def msgStrToIdList(msg_str): 20 | id_book=titlebook_pb2.IdList() 21 | id_book.ParseFromString(msg_str) 22 | id_list=[] 23 | 24 | for id_num in id_book.id: 25 | id_list.append(id_num) 26 | 27 | return id_list 28 | 29 | 30 | if __name__=='__main__': 31 | 32 | str_list=["教育部考试中心托福网考网上报名", 33 | "皇马6-4马竞登顶欧冠", 34 | "evernote 安装最新版本后,个别笔记本无法同步?", 35 | "ios私有api 能修改运营商名称吗?", 36 | "提前博弈A股纳入MSCI"] 37 | 38 | try: 39 | # msg_str=strListToMsgStr(str_list+[1]) 40 | msg_str=strListToMsgStr(str_list) 41 | print msg_str 42 | except Exception as error: 43 | print error 44 | 45 | print '*'*10 46 | 47 | ifile=open('idbook.protomodel', 'r') 48 | id_str=ifile.read() 49 | ifile.close() 50 | 51 | try: 52 | id_list=msgStrToIdList(id_str) 53 | print id_list 54 | except Exception as error: 55 | print error 56 | -------------------------------------------------------------------------------- /src/TC_process.cpp: -------------------------------------------------------------------------------- 1 | // #include "stdafx.h" 2 | 3 | #include "TC_process.h" 4 | 5 | #include 6 | 7 | 8 | std::vector v_query; 9 | 10 | 11 | int sqlite3_exec_callback(void *data, int n_columns, char **col_values, char **col_names) 12 | { 13 | 14 | for (int i = 1; i < n_columns; i++) 15 | { 16 | //printf("%s\t", col_values[i]); 17 | v_query.push_back( atoi(col_values[i]) ); 18 | } 19 | //printf("\n"); 20 | 21 | return 0; 22 | } 23 | 24 | 25 | void query_word(sqlite3 * conn, char * word) 26 | { 27 | v_query.clear(); //��ռ������ 28 | 29 | char * err_msg = NULL; //������ʾ 
30 | char sql[1000]; 31 | 32 | sprintf(sql, "select * from [keyword] where word=\'%s\';", word); 33 | 34 | sqlite3_exec(conn, sql, &sqlite3_exec_callback, 0, &err_msg); 35 | } 36 | 37 | //// 38 | void cut(const CppJieba::SegmentInterface * seg, const char * const filePath, char * outfile, int n_word) 39 | { 40 | std::ofstream ofile(outfile, std::ios::out); 41 | ofile << n_word << "\n"; 42 | 43 | std::ifstream ifile(filePath); 44 | 45 | std::vector res; 46 | std::string line; 47 | while(getline(ifile, line)) 48 | { 49 | if(!line.empty()) //�������س��� ����һ�������У�һ����Ҫ��Dz��� 50 | { 51 | if(line.size() < 2) //���س��У����ߵ����У����ţ� 52 | { 53 | continue; 54 | } 55 | res.clear(); 56 | seg->cut(line, res); 57 | 58 | //���д��� 59 | if(res.size() < n_word) //������NULL 60 | { 61 | int n_NULL = n_word - res.size(); 62 | for(int i = 0; i < n_NULL; i++) 63 | { 64 | ofile << "NULL" << "\n"; 65 | } 66 | for(int i = 0; i < res.size(); i++) 67 | { 68 | ofile << res.at(i).c_str(); 69 | ofile << "\n"; 70 | } 71 | } 72 | else 73 | { 74 | for(int i = res.size()-n_word; i < res.size(); i++) 75 | { 76 | ofile << res.at(i).c_str(); 77 | ofile << "\n"; 78 | } 79 | } 80 | 81 | 82 | } 83 | } 84 | 85 | res.clear(); 86 | 87 | ifile.close(); 88 | ofile.close(); 89 | } 90 | 91 | 92 | void cut(const CppJieba::SegmentInterface * seg, std::list & l_lines, std::vector *> & v_lines_words, int n_word) 93 | { 94 | std::list::iterator i_l_lines = l_lines.begin(); 95 | while( i_l_lines != l_lines.end() ) 96 | { 97 | std::string line( (char *)(*i_l_lines) ); 98 | std::vector res0; 99 | 100 | std::vector res; 101 | 102 | 103 | //���д��� 104 | seg->cut(line, res0); 105 | 106 | //�ִʺ���й��˴��� 107 | for(int i = 0; i < res0.size(); i++) 108 | { 109 | bool isINVALID = false; 110 | if( strcmp( res0.at(i).c_str(), "\t" ) == 0 ) 111 | { 112 | isINVALID = true; 113 | } 114 | //else if( strcmp( res0.at(i).c_str(), " " ) == 0 ) 115 | //{ 116 | // isINVALID = true; 117 | //} 118 | 119 | if(isINVALID == false) 
120 | { 121 | res.push_back( res0.at(i) ); 122 | } 123 | } 124 | 125 | 126 | std::vector * p_procRes = new std::vector; 127 | v_lines_words.push_back( p_procRes ); 128 | 129 | 130 | if(res.size() < n_word) //��������NULL 131 | { 132 | int n_NULL = n_word - res.size(); 133 | for(int i = 0; i < n_NULL; i++) 134 | { 135 | //ofile << "NULL" << "\n"; 136 | char * newWord = new char [5]; 137 | strcpy(newWord, "NULL"); 138 | p_procRes->push_back( newWord ); 139 | } 140 | for(int i = 0; i < res.size(); i++) 141 | { 142 | //ofile << res.at(i).c_str(); 143 | //ofile << "\n"; 144 | int len = strlen( res.at(i).c_str() ); 145 | char * newWord = new char [len+1]; 146 | strcpy(newWord, res.at(i).c_str()); 147 | p_procRes->push_back( newWord ); 148 | } 149 | } 150 | else//�����㹻 151 | { 152 | for(int i = res.size()-n_word; i < res.size(); i++) 153 | { 154 | //ofile << res.at(i).c_str(); 155 | //ofile << "\n"; 156 | int len = strlen( res.at(i).c_str() ); 157 | char * newWord = new char [len+1]; 158 | strcpy(newWord, res.at(i).c_str()); 159 | p_procRes->push_back( newWord ); 160 | } 161 | } 162 | 163 | 164 | //�ͷ� 165 | res0.clear(); 166 | res.clear(); 167 | 168 | i_l_lines++; 169 | } 170 | 171 | } 172 | 173 | void cut(const CppJieba::SegmentInterface * seg, char ** p_text, int n_text, std::vector *> & v_lines_words, int n_word) 174 | { 175 | //std::list::iterator i_l_lines = l_lines.begin(); 176 | //while( i_l_lines != l_lines.end() ) 177 | for(int i = 0; i < n_text; i++) 178 | { 179 | //std::string line( (char *)(*i_l_lines) ); 180 | std::string line( (char *)( p_text[i] ) ); 181 | std::vector res0; 182 | 183 | std::vector res; 184 | 185 | 186 | //���д��� 187 | seg->cut(line, res0); 188 | 189 | //�ִʺ���й��˴��� 190 | for(int i = 0; i < res0.size(); i++) 191 | { 192 | bool isINVALID = false; 193 | if( strcmp( res0.at(i).c_str(), "\t" ) == 0 ) 194 | { 195 | isINVALID = true; 196 | } 197 | //else if( strcmp( res0.at(i).c_str(), " " ) == 0 ) 198 | //{ 199 | // isINVALID = true; 200 | 
//} 201 | 202 | if(isINVALID == false) 203 | { 204 | res.push_back( res0.at(i) ); 205 | } 206 | } 207 | 208 | 209 | std::vector * p_procRes = new std::vector; 210 | v_lines_words.push_back( p_procRes ); 211 | 212 | 213 | if(res.size() < n_word) //��������NULL 214 | { 215 | int n_NULL = n_word - res.size(); 216 | for(int i = 0; i < n_NULL; i++) 217 | { 218 | //ofile << "NULL" << "\n"; 219 | char * newWord = new char [5]; 220 | strcpy(newWord, "NULL"); 221 | p_procRes->push_back( newWord ); 222 | } 223 | for(int i = 0; i < res.size(); i++) 224 | { 225 | //ofile << res.at(i).c_str(); 226 | //ofile << "\n"; 227 | int len = strlen( res.at(i).c_str() ); 228 | char * newWord = new char [len+1]; 229 | strcpy(newWord, res.at(i).c_str()); 230 | p_procRes->push_back( newWord ); 231 | } 232 | } 233 | else//�����㹻 234 | { 235 | for(int i = res.size()-n_word; i < res.size(); i++) 236 | { 237 | //ofile << res.at(i).c_str(); 238 | //ofile << "\n"; 239 | int len = strlen( res.at(i).c_str() ); 240 | char * newWord = new char [len+1]; 241 | strcpy(newWord, res.at(i).c_str()); 242 | p_procRes->push_back( newWord ); 243 | } 244 | } 245 | 246 | 247 | //�ͷ� 248 | res0.clear(); 249 | res.clear(); 250 | 251 | //i_l_lines++; 252 | } 253 | 254 | } 255 | 256 | 257 | 258 | //�ļ����������ʵ�� 259 | termFilter::termFilter( const char * loadPath ) 260 | { 261 | std::ifstream ifile; 262 | ifile.open( loadPath, std::ios::in ); 263 | 264 | char line[100]; 265 | while( !ifile.eof() ) 266 | { 267 | ifile.getline( line, 100 ); 268 | 269 | if( strlen(line) > 0 ) 270 | { 271 | char * term = new char [strlen( line ) + 1]; 272 | strcpy(term, line); 273 | 274 | cantPassediterms.push_back(term); 275 | } 276 | 277 | //std::cout << line; 278 | } 279 | } 280 | 281 | void termFilter::appendFilter( char * loadPath ) 282 | { 283 | std::ifstream ifile; 284 | ifile.open( loadPath, std::ios::in ); 285 | 286 | char line[100]; 287 | while( !ifile.eof() ) 288 | { 289 | ifile.getline( line, 100 ); 290 | 291 | if( 
strlen(line) > 0 ) 292 | { 293 | char * term = new char [strlen( line ) + 1]; 294 | strcpy(term, line); 295 | 296 | cantPassediterms.push_back(term); 297 | } 298 | 299 | //std::cout << line; 300 | } 301 | } 302 | 303 | termFilter::~termFilter() 304 | { 305 | for(int i = 0; i *> & v_lines_words, int n_word, termFilter & filter) 345 | { 346 | //std::list::iterator i_l_lines = l_lines.begin(); 347 | //while( i_l_lines != l_lines.end() ) 348 | for(int i = 0; i < n_text; i++) 349 | { 350 | //std::string line( (char *)(*i_l_lines) ); 351 | std::string line( (char *)( p_text[i] ) ); 352 | std::vector res0; 353 | 354 | std::vector res; 355 | 356 | 357 | //���д��� 358 | seg->cut(line, res0); 359 | 360 | //�ִʺ���й��˴��� 361 | for(int i = 0; i < res0.size(); i++) 362 | { 363 | /* 364 | bool isINVALID = false; 365 | if( strcmp( res0.at(i).c_str(), "\t" ) == 0 ) 366 | { 367 | isINVALID = true; 368 | } 369 | //else if( strcmp( res0.at(i).c_str(), " " ) == 0 ) 370 | //{ 371 | // isINVALID = true; 372 | //} 373 | 374 | if(isINVALID == false) 375 | { 376 | res.push_back( res0.at(i) ); 377 | }*/ 378 | if( filter.termIsPass( (char*)(res0.at(i).c_str()) ) ) 379 | { 380 | res.push_back( res0.at(i) ); 381 | } 382 | } 383 | 384 | 385 | std::vector * p_procRes = new std::vector; 386 | v_lines_words.push_back( p_procRes ); 387 | 388 | 389 | if(res.size() < n_word) //��������NULL 390 | { 391 | int n_NULL = n_word - res.size(); 392 | for(int i = 0; i < n_NULL; i++) 393 | { 394 | //ofile << "NULL" << "\n"; 395 | char * newWord = new char [5]; 396 | strcpy(newWord, "NULL"); 397 | p_procRes->push_back( newWord ); 398 | } 399 | for(int i = 0; i < res.size(); i++) 400 | { 401 | //ofile << res.at(i).c_str(); 402 | //ofile << "\n"; 403 | int len = strlen( res.at(i).c_str() ); 404 | char * newWord = new char [len+1]; 405 | strcpy(newWord, res.at(i).c_str()); 406 | p_procRes->push_back( newWord ); 407 | } 408 | } 409 | else//�����㹻 410 | { 411 | for(int i = res.size()-n_word; i < res.size(); i++) 412 
| { 413 | //ofile << res.at(i).c_str(); 414 | //ofile << "\n"; 415 | int len = strlen( res.at(i).c_str() ); 416 | char * newWord = new char [len+1]; 417 | strcpy(newWord, res.at(i).c_str()); 418 | p_procRes->push_back( newWord ); 419 | } 420 | } 421 | 422 | 423 | //�ͷ� 424 | res0.clear(); 425 | res.clear(); 426 | 427 | //i_l_lines++; 428 | } 429 | 430 | } 431 | 432 | 433 | 434 | //void textCategorization(std::list& l_text, std::list& l_labels) 435 | //����˵����p_text Ϊn_text���ַ�ָ�룻 p_labelsΪn_text����ǩ��int��ָ�� 436 | void textCategorization_new(char ** p_text, int n_text, int * p_labels, char * outputPath) 437 | { 438 | int n_word = N_WORD; 439 | 440 | std::vector< std::vector * > v_lines_words; //�����еĴ�����б� 441 | std::vector< std::vector * > v_class_tf; //ÿ������ÿ���г��ִ�Ƶ���б� 442 | std::vector< std::vector * > v_featureVector; //�������� 443 | 444 | ////���õĹ��� 445 | //�ָ���seg 446 | CppJieba::MPSegment seg; 447 | //��ʼ�� 448 | //bool init_res = seg.init("C:\\languageData_new\\jieba.dict.utf8"); 449 | bool init_res = seg.init("dependency/jieba.dict.utf8"); 450 | ////��ʼ����ѯ����sqlite3�� 451 | sqlite3 * conn = NULL; //������ݿ� 452 | // char * err_msg = NULL; //����ʧ�ܵ�ԭ�� 453 | //����ݿ⣬�������� 454 | //if( SQLITE_OK != sqlite3_open("C:\\languageData_new\\new_dictionary.sqlite", &conn) ) 455 | if( SQLITE_OK != sqlite3_open("dependency/new_dictionary.sqlite", &conn) ) 456 | { 457 | printf("can't open the database."); 458 | exit(-1); 459 | } 460 | //������ 461 | //struct svm_model * svmModel = svm_load_model("C:\\languageData_new\\trainingSet.txt.model"); 462 | struct svm_model * svmModel = svm_load_model("dependency/trainingSet.txt.model"); 463 | //��������� 464 | //termFilter filter("C:\\languageData_new\\symbelTerms.txt"); 465 | termFilter filter("dependency/symbelTerms.txt"); 466 | 467 | 468 | //cut( �ָ����� �����ļ��� ����ļ��� ���ֵĴ� ) 469 | //cut(&seg, "title_utf8.txt", "title_res_utf8.txt", n_word); //�ֵ�ִ� 470 | 471 | //cut(&seg, l_text, 
v_lines_words, n_word); 472 | cut(&seg, p_text, n_text, v_lines_words, n_word, filter); 473 | 474 | 475 | /* 476 | //����ִʽ�� 477 | std::ofstream ofile1; 478 | ofile1.open("split_res.txt", std::ios::out); 479 | //std::vector< std::vector * > v_lines_words; 480 | for(int iSen = 0; iSen < v_lines_words.size(); iSen++) 481 | { 482 | std::vector * pSen = v_lines_words.at(iSen); 483 | for(int iWord = 0; iWord < pSen->size(); iWord++) 484 | { 485 | ofile1 << pSen->at(iWord) << "\n"; 486 | } 487 | } 488 | ofile1.close(); 489 | */ 490 | 491 | 492 | 493 | 494 | //��ÿ���ʣ��ڴʵ���Ѱ����Ӧ��Ƶ�� 495 | for(int iSen = 0; iSen < v_lines_words.size(); iSen++) 496 | { 497 | std::vector * pSen_fp = new std::vector; 498 | v_class_tf.push_back(pSen_fp); 499 | 500 | //std::cout << "querying no." << iSen+1 << "\n"; 501 | char message_t[50]; 502 | //sprintf(message_t, "\rquerying no.%d", iSen+1); 503 | sprintf(message_t, "\r�����%d��", iSen+1); 504 | std::cout << message_t; 505 | 506 | for( int iWord = 0; iWord < (*v_lines_words.at(iSen)).size(); iWord++ ) //ÿ����10���� 507 | { 508 | char * pWord = (*v_lines_words.at(iSen)).at(iWord); 509 | 510 | //����ݿ������Ƶ���ѯ 511 | query_word(conn, pWord); 512 | 513 | if( v_query.size() != N_DIMEN ) //������������ 514 | { 515 | for(int i = 0; i < N_DIMEN; i++) 516 | { 517 | pSen_fp->push_back( 0 ); 518 | } 519 | } 520 | else //����û�г������� 521 | { 522 | for(int i = 0; i < N_DIMEN; i++) 523 | { 524 | pSen_fp->push_back( v_query.at(i) ); 525 | } 526 | } 527 | } 528 | } 529 | 530 | 531 | /* 532 | //���������� 533 | //std::vector< std::vector * > v_class_tf; //ÿ������ÿ���г��ִ�Ƶ���б� 534 | std::ofstream ofile; 535 | ofile.open( "frequency.txt", std::ios::out ); 536 | for(int iSen = 0; iSen < v_class_tf.size(); iSen++) 537 | { 538 | std::vector * pSen_fp = v_class_tf.at( iSen ); 539 | for( int iTF = 0; iTF < pSen_fp->size(); iTF++ ) 540 | { 541 | ofile << pSen_fp->at(iTF) << "\t"; 542 | } 543 | ofile << "\n"; 544 | } 545 | ofile.close(); 546 | */ 547 | 548 
| 549 | //������������������� 550 | double class_n[N_DIMEN] = {13186.0, 133915.0, 29844.0, 14694.0, 235245}; //�ʵ���ÿ��ĸ��� 551 | 552 | //std::vector< std::vector * > v_class_tf; //ÿ������ÿ���г��ִ�Ƶ���б� 553 | for(int iSen = 0; iSen < v_class_tf.size(); iSen++) 554 | { 555 | std::vector * fp_thisSen = v_class_tf.at(iSen); //����������� 556 | 557 | std::vector * pFeatureV = new std::vector; //�������������м���������� 558 | v_featureVector.push_back(pFeatureV); 559 | 560 | for(int iWord = 0; iWord < 10; iWord++) //ÿ�乲10���� 561 | { 562 | double norm_fp[N_DIMEN]; //�����һ����TF 563 | double max_nfp = 0.0; //���TF 564 | double sum_nfp = 0.0; //TF֮�� 565 | 566 | for(int i = 0; i < N_DIMEN; i++) 567 | { 568 | norm_fp[i] = 10000.0 * (double)(fp_thisSen->at( N_DIMEN*iWord + i ))/class_n[i]; 569 | sum_nfp += norm_fp[i]; 570 | if(max_nfp < norm_fp[i]) 571 | { 572 | max_nfp = norm_fp[i]; 573 | } 574 | } 575 | 576 | max_nfp /= 10.0; //�dz���10�󾭹�sigmoid���� 577 | double f1 = 2.0/( 1.0 + exp(-0.10986*max_nfp) ) - 1.0; 578 | 579 | if(sum_nfp != 0) 580 | { 581 | pFeatureV->push_back( f1 ); 582 | for(int i = 0; i < N_DIMEN; i++) 583 | { 584 | pFeatureV->push_back( norm_fp[i]/sum_nfp ); 585 | } 586 | } 587 | else 588 | { 589 | pFeatureV->push_back( f1 ); 590 | for(int i = 0; i < N_DIMEN; i++) 591 | { 592 | pFeatureV->push_back( 0.0 ); 593 | } 594 | } 595 | } 596 | } 597 | 598 | 599 | /* 600 | //����������� 601 | //std::vector< std::vector * > v_featureVector; //�������� 602 | std::ofstream ofile2; 603 | ofile2.open("featureVector.txt", std::ios::app); 604 | for(int iVec = 0; iVec < v_featureVector.size(); iVec++) 605 | { 606 | std::vector * v_Vector = v_featureVector.at(iVec); 607 | for(int iCell = 0; iCell < v_Vector->size(); iCell++) 608 | { 609 | ofile2 << iCell+1 << ":" << v_Vector->at(iCell) << " "; 610 | } 611 | ofile2 << "\n"; 612 | } 613 | ofile2.close(); 614 | */ 615 | 616 | 617 | ////svm������~ 618 | 619 | for(int iFV = 0; iFV < v_featureVector.size(); iFV++) //����ÿ����¼ 
620 | { 621 | std::vector * pFV = v_featureVector.at(iFV); 622 | struct svm_node * svmData = (struct svm_node *)malloc( (50+1)*sizeof(struct svm_node) ); 623 | for(int i = 0; i < 50; i++) 624 | { 625 | svmData[i].index = i+1; 626 | svmData[i].value = pFV->at(i); 627 | } 628 | svmData[50].index = -1; 629 | 630 | int label = svm_predict(svmModel, svmData); 631 | 632 | //l_labels.push_back(label); 633 | p_labels[iFV] = label; 634 | 635 | //std::cout << iFV+1 << " : " << label << "\n"; 636 | 637 | free(svmData); 638 | } 639 | 640 | 641 | //���Ԥ����� 642 | if( outputPath != NULL ) 643 | { 644 | std::ofstream ofile3; 645 | ofile3.open( outputPath, std::ios::out ); 646 | for(int i = 0; i < n_text; i++) 647 | { 648 | ofile3 << p_labels[i] << "\n"; 649 | } 650 | ofile3.close(); 651 | } 652 | 653 | 654 | //���� 655 | 656 | 657 | /////�ͷŷִ��� 658 | seg.dispose(); 659 | //�ر�sqlite3���� 660 | if( SQLITE_OK != sqlite3_close(conn) ) 661 | { 662 | printf("can't close the database: %s/n", sqlite3_errmsg(conn)); 663 | exit(-1); 664 | } 665 | //�ͷŷ����� 666 | free(svmModel); 667 | //�ͷ�ȡ���б�v_lines_words 668 | for( int iSen = 0; iSen < v_lines_words.size(); iSen++ ) 669 | { 670 | std::vector * p_vSen = v_lines_words.at(iSen); 671 | for(int iWord = 0; iWord < p_vSen->size(); iWord++) 672 | { 673 | delete [] (char*)(p_vSen->at(iWord)); 674 | } 675 | p_vSen->clear(); 676 | } 677 | v_lines_words.clear(); 678 | //�ͷŲ�ѯ��Ƶ�б� 679 | for( int iSen = 0; iSen < v_class_tf.size(); iSen++ ) 680 | { 681 | v_class_tf.at(iSen)->clear(); 682 | } 683 | v_class_tf.clear(); 684 | //�ͷ������������� 685 | for( int iVec = 0; iVec < v_featureVector.size(); iVec++ ) 686 | { 687 | v_featureVector.at(iVec)->clear(); 688 | } 689 | v_featureVector.clear(); 690 | 691 | v_query.clear(); 692 | 693 | //return p_labels; 694 | } 695 | 696 | CateTeller::CateTeller() { 697 | // load Chinese word segment tools 698 | bool init_res = seg.init("dependency/jieba.dict.utf8"); 699 | 700 | // word vector data 701 | conn = 
NULL; 702 | if( SQLITE_OK != sqlite3_open("dependency/new_dictionary.sqlite", &conn) ) { 703 | printf("can't open the database."); 704 | exit(-1); 705 | } 706 | 707 | // svm model 708 | svmModel = svm_load_model("dependency/trainingSet.txt.model"); 709 | 710 | // filter 711 | filter = termFilter("dependency/symbelTerms.txt"); 712 | } 713 | 714 | CateTeller::~CateTeller() { 715 | /////�ͷŷִ��� 716 | seg.dispose(); 717 | //�ر�sqlite3���� 718 | if( SQLITE_OK != sqlite3_close(conn) ) { 719 | printf("can't close the database: %s/n", sqlite3_errmsg(conn)); 720 | exit(-1); 721 | } 722 | //�ͷŷ����� 723 | free(svmModel); 724 | } 725 | 726 | void CateTeller::tell(char ** p_text, int n_text, int * p_labels) { 727 | int n_word = N_WORD; 728 | 729 | std::vector< std::vector * > v_lines_words; //�����еĴ�����б� 730 | std::vector< std::vector * > v_class_tf; //ÿ������ÿ���г��ִ�Ƶ���б� 731 | std::vector< std::vector * > v_featureVector; //�������� 732 | 733 | cut(&seg, p_text, n_text, v_lines_words, n_word, filter); 734 | 735 | //��ÿ���ʣ��ڴʵ���Ѱ����Ӧ��Ƶ�� 736 | for(int iSen = 0; iSen < v_lines_words.size(); iSen++) 737 | { 738 | std::vector * pSen_fp = new std::vector; 739 | v_class_tf.push_back(pSen_fp); 740 | 741 | //std::cout << "querying no." 
<< iSen+1 << "\n"; 742 | char message_t[50]; 743 | //sprintf(message_t, "\rquerying no.%d", iSen+1); 744 | sprintf(message_t, "\r�����%d��", iSen+1); 745 | // std::cout << message_t; 746 | 747 | for( int iWord = 0; iWord < (*v_lines_words.at(iSen)).size(); iWord++ ) //ÿ����10���� 748 | { 749 | char * pWord = (*v_lines_words.at(iSen)).at(iWord); 750 | 751 | //����ݿ������Ƶ���ѯ 752 | query_word(conn, pWord); 753 | 754 | if( v_query.size() != N_DIMEN ) //������������ 755 | { 756 | for(int i = 0; i < N_DIMEN; i++) 757 | { 758 | pSen_fp->push_back( 0 ); 759 | } 760 | } 761 | else //����û�г������� 762 | { 763 | for(int i = 0; i < N_DIMEN; i++) 764 | { 765 | pSen_fp->push_back( v_query.at(i) ); 766 | } 767 | } 768 | } 769 | } 770 | 771 | //������������������� 772 | double class_n[N_DIMEN] = {13186.0, 133915.0, 29844.0, 14694.0, 235245}; //�ʵ���ÿ��ĸ��� 773 | 774 | //std::vector< std::vector * > v_class_tf; //ÿ������ÿ���г��ִ�Ƶ���б� 775 | for(int iSen = 0; iSen < v_class_tf.size(); iSen++) 776 | { 777 | std::vector * fp_thisSen = v_class_tf.at(iSen); //����������� 778 | 779 | std::vector * pFeatureV = new std::vector; //�������������м���������� 780 | v_featureVector.push_back(pFeatureV); 781 | 782 | for(int iWord = 0; iWord < 10; iWord++) //ÿ�乲10���� 783 | { 784 | double norm_fp[N_DIMEN]; //�����һ����TF 785 | double max_nfp = 0.0; //���TF 786 | double sum_nfp = 0.0; //TF֮�� 787 | 788 | for(int i = 0; i < N_DIMEN; i++) 789 | { 790 | norm_fp[i] = 10000.0 * (double)(fp_thisSen->at( N_DIMEN*iWord + i ))/class_n[i]; 791 | sum_nfp += norm_fp[i]; 792 | if(max_nfp < norm_fp[i]) 793 | { 794 | max_nfp = norm_fp[i]; 795 | } 796 | } 797 | 798 | max_nfp /= 10.0; //�dz���10�󾭹�sigmoid���� 799 | double f1 = 2.0/( 1.0 + exp(-0.10986*max_nfp) ) - 1.0; 800 | 801 | if(sum_nfp != 0) 802 | { 803 | pFeatureV->push_back( f1 ); 804 | for(int i = 0; i < N_DIMEN; i++) 805 | { 806 | pFeatureV->push_back( norm_fp[i]/sum_nfp ); 807 | } 808 | } 809 | else 810 | { 811 | pFeatureV->push_back( f1 ); 812 | 
for(int i = 0; i < N_DIMEN; i++) 813 | { 814 | pFeatureV->push_back( 0.0 ); 815 | } 816 | } 817 | } 818 | } 819 | 820 | 821 | ////svm������~ 822 | 823 | for(int iFV = 0; iFV < v_featureVector.size(); iFV++) //����ÿ����¼ 824 | { 825 | std::vector * pFV = v_featureVector.at(iFV); 826 | struct svm_node * svmData = (struct svm_node *)malloc( (50+1)*sizeof(struct svm_node) ); 827 | for(int i = 0; i < 50; i++) 828 | { 829 | svmData[i].index = i+1; 830 | svmData[i].value = pFV->at(i); 831 | } 832 | svmData[50].index = -1; 833 | 834 | int label = svm_predict(svmModel, svmData); 835 | 836 | //l_labels.push_back(label); 837 | p_labels[iFV] = label; 838 | 839 | //std::cout << iFV+1 << " : " << label << "\n"; 840 | 841 | free(svmData); 842 | } 843 | 844 | //// do some cleanup 845 | 846 | //�ͷ�ȡ���б�v_lines_words 847 | for( int iSen = 0; iSen < v_lines_words.size(); iSen++ ) 848 | { 849 | std::vector * p_vSen = v_lines_words.at(iSen); 850 | for(int iWord = 0; iWord < p_vSen->size(); iWord++) 851 | { 852 | delete [] (char*)(p_vSen->at(iWord)); 853 | } 854 | p_vSen->clear(); 855 | } 856 | v_lines_words.clear(); 857 | //�ͷŲ�ѯ��Ƶ�б� 858 | for( int iSen = 0; iSen < v_class_tf.size(); iSen++ ) 859 | { 860 | v_class_tf.at(iSen)->clear(); 861 | } 862 | v_class_tf.clear(); 863 | //�ͷ������������� 864 | for( int iVec = 0; iVec < v_featureVector.size(); iVec++ ) 865 | { 866 | v_featureVector.at(iVec)->clear(); 867 | } 868 | v_featureVector.clear(); 869 | 870 | v_query.clear(); 871 | 872 | } 873 | -------------------------------------------------------------------------------- /src/combinition.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Serbipunk/zeromq_nlp_service/e1ab9ff5aa08dbfd6bf0c4fed3c43b9182b648e6/src/combinition.cpp -------------------------------------------------------------------------------- /src/cppMSGmodule.cpp: -------------------------------------------------------------------------------- 1 | 
#include "cppMSGmodule.hpp" 2 | 3 | #include 4 | #include 5 | #include "titlebook.pb.h" 6 | #include 7 | 8 | using namespace std; 9 | 10 | bool CppMsgModule::msgStrToPcharArray(char**& pchar_array, int& count_array, const string& msg_str) { 11 | 12 | // Verify that the version of the library that we linked against is 13 | // compatible with the version of the headers we compiled against. 14 | GOOGLE_PROTOBUF_VERIFY_VERSION; 15 | 16 | tutorial::TitleList title_list; 17 | // fstream input("titlebook.protomodel", ios::in); 18 | // if (!title_list.ParseFromIstream(&input)) { 19 | if (!title_list.ParseFromString(msg_str)) { 20 | cerr << "Failed to parse title book." << endl; 21 | return false; 22 | } 23 | 24 | pchar_array = new char* [title_list.title_size()]; 25 | count_array = title_list.title_size(); 26 | 27 | for(int i=0; i static inline T min(T x,T y) { return (x static inline T max(T x,T y) { return (x>y)?x:y; } 13 | #endif 14 | template static inline void swap(T& x, T& y) { T t=x; x=y; y=t; } 15 | template static inline void clone(T*& dst, S* src, int n) 16 | { 17 | dst = new T[n]; 18 | memcpy((void *)dst,(void *)src,sizeof(T)*n); 19 | } 20 | static inline double powi(double base, int times) 21 | { 22 | double tmp = base, ret = 1.0; 23 | 24 | for(int t=times; t>0; t/=2) 25 | { 26 | if(t%2==1) ret*=tmp; 27 | tmp = tmp * tmp; 28 | } 29 | return ret; 30 | } 31 | #define INF HUGE_VAL 32 | #define TAU 1e-12 33 | #define Malloc(type,n) (type *)malloc((n)*sizeof(type)) 34 | 35 | static void print_string_stdout(const char *s) 36 | { 37 | fputs(s,stdout); 38 | fflush(stdout); 39 | } 40 | static void (*svm_print_string) (const char *) = &print_string_stdout; 41 | #if 1 42 | static void info(const char *fmt,...) 43 | { 44 | char buf[BUFSIZ]; 45 | va_list ap; 46 | va_start(ap,fmt); 47 | vsprintf(buf,fmt,ap); 48 | va_end(ap); 49 | (*svm_print_string)(buf); 50 | } 51 | #else 52 | static void info(const char *fmt,...) 
{} 53 | #endif 54 | 55 | // 56 | // Kernel Cache 57 | // 58 | // l is the number of total data items 59 | // size is the cache size limit in bytes 60 | // 61 | class Cache 62 | { 63 | public: 64 | Cache(int l,long int size); 65 | ~Cache(); 66 | 67 | // request data [0,len) 68 | // return some position p where [p,len) need to be filled 69 | // (p >= len if nothing needs to be filled) 70 | int get_data(const int index, Qfloat **data, int len); 71 | void swap_index(int i, int j); 72 | private: 73 | int l; 74 | long int size; 75 | struct head_t 76 | { 77 | head_t *prev, *next; // a circular list 78 | Qfloat *data; 79 | int len; // data[0,len) is cached in this entry 80 | }; 81 | 82 | head_t *head; 83 | head_t lru_head; 84 | void lru_delete(head_t *h); 85 | void lru_insert(head_t *h); 86 | }; 87 | 88 | Cache::Cache(int l_,long int size_):l(l_),size(size_) 89 | { 90 | head = (head_t *)calloc(l,sizeof(head_t)); // initialized to 0 91 | size /= sizeof(Qfloat); 92 | size -= l * sizeof(head_t) / sizeof(Qfloat); 93 | size = max(size, 2 * (long int) l); // cache must be large enough for two columns 94 | lru_head.next = lru_head.prev = &lru_head; 95 | } 96 | 97 | Cache::~Cache() 98 | { 99 | for(head_t *h = lru_head.next; h != &lru_head; h=h->next) 100 | free(h->data); 101 | free(head); 102 | } 103 | 104 | void Cache::lru_delete(head_t *h) 105 | { 106 | // delete from current location 107 | h->prev->next = h->next; 108 | h->next->prev = h->prev; 109 | } 110 | 111 | void Cache::lru_insert(head_t *h) 112 | { 113 | // insert to last position 114 | h->next = &lru_head; 115 | h->prev = lru_head.prev; 116 | h->prev->next = h; 117 | h->next->prev = h; 118 | } 119 | 120 | int Cache::get_data(const int index, Qfloat **data, int len) 121 | { 122 | head_t *h = &head[index]; 123 | if(h->len) lru_delete(h); 124 | int more = len - h->len; 125 | 126 | if(more > 0) 127 | { 128 | // free old space 129 | while(size < more) 130 | { 131 | head_t *old = lru_head.next; 132 | lru_delete(old); 133 | 
free(old->data); 134 | size += old->len; 135 | old->data = 0; 136 | old->len = 0; 137 | } 138 | 139 | // allocate new space 140 | h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len); 141 | size -= more; 142 | swap(h->len,len); 143 | } 144 | 145 | lru_insert(h); 146 | *data = h->data; 147 | return len; 148 | } 149 | 150 | void Cache::swap_index(int i, int j) 151 | { 152 | if(i==j) return; 153 | 154 | if(head[i].len) lru_delete(&head[i]); 155 | if(head[j].len) lru_delete(&head[j]); 156 | swap(head[i].data,head[j].data); 157 | swap(head[i].len,head[j].len); 158 | if(head[i].len) lru_insert(&head[i]); 159 | if(head[j].len) lru_insert(&head[j]); 160 | 161 | if(i>j) swap(i,j); 162 | for(head_t *h = lru_head.next; h!=&lru_head; h=h->next) 163 | { 164 | if(h->len > i) 165 | { 166 | if(h->len > j) 167 | swap(h->data[i],h->data[j]); 168 | else 169 | { 170 | // give up 171 | lru_delete(h); 172 | free(h->data); 173 | size += h->len; 174 | h->data = 0; 175 | h->len = 0; 176 | } 177 | } 178 | } 179 | } 180 | 181 | // 182 | // Kernel evaluation 183 | // 184 | // the static method k_function is for doing single kernel evaluation 185 | // the constructor of Kernel prepares to calculate the l*l kernel matrix 186 | // the member function get_Q is for getting one column from the Q Matrix 187 | // 188 | class QMatrix { 189 | public: 190 | virtual Qfloat *get_Q(int column, int len) const = 0; 191 | virtual double *get_QD() const = 0; 192 | virtual void swap_index(int i, int j) const = 0; 193 | virtual ~QMatrix() {} 194 | }; 195 | 196 | class Kernel: public QMatrix { 197 | public: 198 | Kernel(int l, svm_node * const * x, const svm_parameter& param); 199 | virtual ~Kernel(); 200 | 201 | static double k_function(const svm_node *x, const svm_node *y, 202 | const svm_parameter& param); 203 | virtual Qfloat *get_Q(int column, int len) const = 0; 204 | virtual double *get_QD() const = 0; 205 | virtual void swap_index(int i, int j) const // no so const... 
206 | { 207 | swap(x[i],x[j]); 208 | if(x_square) swap(x_square[i],x_square[j]); 209 | } 210 | protected: 211 | 212 | double (Kernel::*kernel_function)(int i, int j) const; 213 | 214 | private: 215 | const svm_node **x; 216 | double *x_square; 217 | 218 | // svm_parameter 219 | const int kernel_type; 220 | const int degree; 221 | const double gamma; 222 | const double coef0; 223 | 224 | static double dot(const svm_node *px, const svm_node *py); 225 | double kernel_linear(int i, int j) const 226 | { 227 | return dot(x[i],x[j]); 228 | } 229 | double kernel_poly(int i, int j) const 230 | { 231 | return powi(gamma*dot(x[i],x[j])+coef0,degree); 232 | } 233 | double kernel_rbf(int i, int j) const 234 | { 235 | return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j]))); 236 | } 237 | double kernel_sigmoid(int i, int j) const 238 | { 239 | return tanh(gamma*dot(x[i],x[j])+coef0); 240 | } 241 | double kernel_precomputed(int i, int j) const 242 | { 243 | return x[i][(int)(x[j][0].value)].value; 244 | } 245 | }; 246 | 247 | Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param) 248 | :kernel_type(param.kernel_type), degree(param.degree), 249 | gamma(param.gamma), coef0(param.coef0) 250 | { 251 | switch(kernel_type) 252 | { 253 | case LINEAR: 254 | kernel_function = &Kernel::kernel_linear; 255 | break; 256 | case POLY: 257 | kernel_function = &Kernel::kernel_poly; 258 | break; 259 | case RBF: 260 | kernel_function = &Kernel::kernel_rbf; 261 | break; 262 | case SIGMOID: 263 | kernel_function = &Kernel::kernel_sigmoid; 264 | break; 265 | case PRECOMPUTED: 266 | kernel_function = &Kernel::kernel_precomputed; 267 | break; 268 | } 269 | 270 | clone(x,x_,l); 271 | 272 | if(kernel_type == RBF) 273 | { 274 | x_square = new double[l]; 275 | for(int i=0;iindex != -1 && py->index != -1) 292 | { 293 | if(px->index == py->index) 294 | { 295 | sum += px->value * py->value; 296 | ++px; 297 | ++py; 298 | } 299 | else 300 | { 301 | if(px->index > py->index) 302 | ++py; 303 
| else 304 | ++px; 305 | } 306 | } 307 | return sum; 308 | } 309 | 310 | double Kernel::k_function(const svm_node *x, const svm_node *y, 311 | const svm_parameter& param) 312 | { 313 | switch(param.kernel_type) 314 | { 315 | case LINEAR: 316 | return dot(x,y); 317 | case POLY: 318 | return powi(param.gamma*dot(x,y)+param.coef0,param.degree); 319 | case RBF: 320 | { 321 | double sum = 0; 322 | while(x->index != -1 && y->index !=-1) 323 | { 324 | if(x->index == y->index) 325 | { 326 | double d = x->value - y->value; 327 | sum += d*d; 328 | ++x; 329 | ++y; 330 | } 331 | else 332 | { 333 | if(x->index > y->index) 334 | { 335 | sum += y->value * y->value; 336 | ++y; 337 | } 338 | else 339 | { 340 | sum += x->value * x->value; 341 | ++x; 342 | } 343 | } 344 | } 345 | 346 | while(x->index != -1) 347 | { 348 | sum += x->value * x->value; 349 | ++x; 350 | } 351 | 352 | while(y->index != -1) 353 | { 354 | sum += y->value * y->value; 355 | ++y; 356 | } 357 | 358 | return exp(-param.gamma*sum); 359 | } 360 | case SIGMOID: 361 | return tanh(param.gamma*dot(x,y)+param.coef0); 362 | case PRECOMPUTED: //x: test (validation), y: SV 363 | return x[(int)(y->value)].value; 364 | default: 365 | return 0; // Unreachable 366 | } 367 | } 368 | 369 | // An SMO algorithm in Fan et al., JMLR 6(2005), p. 
1889--1918 370 | // Solves: 371 | // 372 | // min 0.5(\alpha^T Q \alpha) + p^T \alpha 373 | // 374 | // y^T \alpha = \delta 375 | // y_i = +1 or -1 376 | // 0 <= alpha_i <= Cp for y_i = 1 377 | // 0 <= alpha_i <= Cn for y_i = -1 378 | // 379 | // Given: 380 | // 381 | // Q, p, y, Cp, Cn, and an initial feasible point \alpha 382 | // l is the size of vectors and matrices 383 | // eps is the stopping tolerance 384 | // 385 | // solution will be put in \alpha, objective value will be put in obj 386 | // 387 | class Solver { 388 | public: 389 | Solver() {}; 390 | virtual ~Solver() {}; 391 | 392 | struct SolutionInfo { 393 | double obj; 394 | double rho; 395 | double upper_bound_p; 396 | double upper_bound_n; 397 | double r; // for Solver_NU 398 | }; 399 | 400 | void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, 401 | double *alpha_, double Cp, double Cn, double eps, 402 | SolutionInfo* si, int shrinking); 403 | protected: 404 | int active_size; 405 | schar *y; 406 | double *G; // gradient of objective function 407 | enum { LOWER_BOUND, UPPER_BOUND, FREE }; 408 | char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE 409 | double *alpha; 410 | const QMatrix *Q; 411 | const double *QD; 412 | double eps; 413 | double Cp,Cn; 414 | double *p; 415 | int *active_set; 416 | double *G_bar; // gradient, if we treat free variables as 0 417 | int l; 418 | bool unshrink; // XXX 419 | 420 | double get_C(int i) 421 | { 422 | return (y[i] > 0)? 
Cp : Cn; 423 | } 424 | void update_alpha_status(int i) 425 | { 426 | if(alpha[i] >= get_C(i)) 427 | alpha_status[i] = UPPER_BOUND; 428 | else if(alpha[i] <= 0) 429 | alpha_status[i] = LOWER_BOUND; 430 | else alpha_status[i] = FREE; 431 | } 432 | bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; } 433 | bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; } 434 | bool is_free(int i) { return alpha_status[i] == FREE; } 435 | void swap_index(int i, int j); 436 | void reconstruct_gradient(); 437 | virtual int select_working_set(int &i, int &j); 438 | virtual double calculate_rho(); 439 | virtual void do_shrinking(); 440 | private: 441 | bool be_shrunk(int i, double Gmax1, double Gmax2); 442 | }; 443 | 444 | void Solver::swap_index(int i, int j) 445 | { 446 | Q->swap_index(i,j); 447 | swap(y[i],y[j]); 448 | swap(G[i],G[j]); 449 | swap(alpha_status[i],alpha_status[j]); 450 | swap(alpha[i],alpha[j]); 451 | swap(p[i],p[j]); 452 | swap(active_set[i],active_set[j]); 453 | swap(G_bar[i],G_bar[j]); 454 | } 455 | 456 | void Solver::reconstruct_gradient() 457 | { 458 | // reconstruct inactive elements of G from G_bar and free variables 459 | 460 | if(active_size == l) return; 461 | 462 | int i,j; 463 | int nr_free = 0; 464 | 465 | for(j=active_size;j 2*active_size*(l-active_size)) 476 | { 477 | for(i=active_size;iget_Q(i,active_size); 480 | for(j=0;jget_Q(i,l); 491 | double alpha_i = alpha[i]; 492 | for(j=active_size;jl = l; 503 | this->Q = &Q; 504 | QD=Q.get_QD(); 505 | clone(p, p_,l); 506 | clone(y, y_,l); 507 | clone(alpha,alpha_,l); 508 | this->Cp = Cp; 509 | this->Cn = Cn; 510 | this->eps = eps; 511 | unshrink = false; 512 | 513 | // initialize alpha_status 514 | { 515 | alpha_status = new char[l]; 516 | for(int i=0;iINT_MAX/100 ? 
INT_MAX : 100*l); 556 | int counter = min(l,1000)+1; 557 | 558 | while(iter < max_iter) 559 | { 560 | // show progress and do shrinking 561 | 562 | if(--counter == 0) 563 | { 564 | counter = min(l,1000); 565 | if(shrinking) do_shrinking(); 566 | info("."); 567 | } 568 | 569 | int i,j; 570 | if(select_working_set(i,j)!=0) 571 | { 572 | // reconstruct the whole gradient 573 | reconstruct_gradient(); 574 | // reset active set size and check 575 | active_size = l; 576 | info("*"); 577 | if(select_working_set(i,j)!=0) 578 | break; 579 | else 580 | counter = 1; // do shrinking next iteration 581 | } 582 | 583 | ++iter; 584 | 585 | // update alpha[i] and alpha[j], handle bounds carefully 586 | 587 | const Qfloat *Q_i = Q.get_Q(i,active_size); 588 | const Qfloat *Q_j = Q.get_Q(j,active_size); 589 | 590 | double C_i = get_C(i); 591 | double C_j = get_C(j); 592 | 593 | double old_alpha_i = alpha[i]; 594 | double old_alpha_j = alpha[j]; 595 | 596 | if(y[i]!=y[j]) 597 | { 598 | double quad_coef = QD[i]+QD[j]+2*Q_i[j]; 599 | if (quad_coef <= 0) 600 | quad_coef = TAU; 601 | double delta = (-G[i]-G[j])/quad_coef; 602 | double diff = alpha[i] - alpha[j]; 603 | alpha[i] += delta; 604 | alpha[j] += delta; 605 | 606 | if(diff > 0) 607 | { 608 | if(alpha[j] < 0) 609 | { 610 | alpha[j] = 0; 611 | alpha[i] = diff; 612 | } 613 | } 614 | else 615 | { 616 | if(alpha[i] < 0) 617 | { 618 | alpha[i] = 0; 619 | alpha[j] = -diff; 620 | } 621 | } 622 | if(diff > C_i - C_j) 623 | { 624 | if(alpha[i] > C_i) 625 | { 626 | alpha[i] = C_i; 627 | alpha[j] = C_i - diff; 628 | } 629 | } 630 | else 631 | { 632 | if(alpha[j] > C_j) 633 | { 634 | alpha[j] = C_j; 635 | alpha[i] = C_j + diff; 636 | } 637 | } 638 | } 639 | else 640 | { 641 | double quad_coef = QD[i]+QD[j]-2*Q_i[j]; 642 | if (quad_coef <= 0) 643 | quad_coef = TAU; 644 | double delta = (G[i]-G[j])/quad_coef; 645 | double sum = alpha[i] + alpha[j]; 646 | alpha[i] -= delta; 647 | alpha[j] += delta; 648 | 649 | if(sum > C_i) 650 | { 651 | 
if(alpha[i] > C_i) 652 | { 653 | alpha[i] = C_i; 654 | alpha[j] = sum - C_i; 655 | } 656 | } 657 | else 658 | { 659 | if(alpha[j] < 0) 660 | { 661 | alpha[j] = 0; 662 | alpha[i] = sum; 663 | } 664 | } 665 | if(sum > C_j) 666 | { 667 | if(alpha[j] > C_j) 668 | { 669 | alpha[j] = C_j; 670 | alpha[i] = sum - C_j; 671 | } 672 | } 673 | else 674 | { 675 | if(alpha[i] < 0) 676 | { 677 | alpha[i] = 0; 678 | alpha[j] = sum; 679 | } 680 | } 681 | } 682 | 683 | // update G 684 | 685 | double delta_alpha_i = alpha[i] - old_alpha_i; 686 | double delta_alpha_j = alpha[j] - old_alpha_j; 687 | 688 | for(int k=0;k= max_iter) 726 | { 727 | if(active_size < l) 728 | { 729 | // reconstruct the whole gradient to calculate objective value 730 | reconstruct_gradient(); 731 | active_size = l; 732 | info("*"); 733 | } 734 | info("\nWARNING: reaching max number of iterations"); 735 | } 736 | 737 | // calculate rho 738 | 739 | si->rho = calculate_rho(); 740 | 741 | // calculate objective value 742 | { 743 | double v = 0; 744 | int i; 745 | for(i=0;iobj = v/2; 749 | } 750 | 751 | // put back the solution 752 | { 753 | for(int i=0;iupper_bound_p = Cp; 766 | si->upper_bound_n = Cn; 767 | 768 | info("\noptimization finished, #iter = %d\n",iter); 769 | 770 | delete[] p; 771 | delete[] y; 772 | delete[] alpha; 773 | delete[] alpha_status; 774 | delete[] active_set; 775 | delete[] G; 776 | delete[] G_bar; 777 | } 778 | 779 | // return 1 if already optimal, return 0 otherwise 780 | int Solver::select_working_set(int &out_i, int &out_j) 781 | { 782 | // return i,j such that 783 | // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha) 784 | // j: minimizes the decrease of obj value 785 | // (if quadratic coefficeint <= 0, replace it with tau) 786 | // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) 787 | 788 | double Gmax = -INF; 789 | double Gmax2 = -INF; 790 | int Gmax_idx = -1; 791 | int Gmin_idx = -1; 792 | double obj_diff_min = INF; 793 | 794 | for(int t=0;t= Gmax) 799 | { 800 | Gmax = -G[t]; 
801 | Gmax_idx = t; 802 | } 803 | } 804 | else 805 | { 806 | if(!is_lower_bound(t)) 807 | if(G[t] >= Gmax) 808 | { 809 | Gmax = G[t]; 810 | Gmax_idx = t; 811 | } 812 | } 813 | 814 | int i = Gmax_idx; 815 | const Qfloat *Q_i = NULL; 816 | if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1 817 | Q_i = Q->get_Q(i,active_size); 818 | 819 | for(int j=0;j= Gmax2) 827 | Gmax2 = G[j]; 828 | if (grad_diff > 0) 829 | { 830 | double obj_diff; 831 | double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j]; 832 | if (quad_coef > 0) 833 | obj_diff = -(grad_diff*grad_diff)/quad_coef; 834 | else 835 | obj_diff = -(grad_diff*grad_diff)/TAU; 836 | 837 | if (obj_diff <= obj_diff_min) 838 | { 839 | Gmin_idx=j; 840 | obj_diff_min = obj_diff; 841 | } 842 | } 843 | } 844 | } 845 | else 846 | { 847 | if (!is_upper_bound(j)) 848 | { 849 | double grad_diff= Gmax-G[j]; 850 | if (-G[j] >= Gmax2) 851 | Gmax2 = -G[j]; 852 | if (grad_diff > 0) 853 | { 854 | double obj_diff; 855 | double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j]; 856 | if (quad_coef > 0) 857 | obj_diff = -(grad_diff*grad_diff)/quad_coef; 858 | else 859 | obj_diff = -(grad_diff*grad_diff)/TAU; 860 | 861 | if (obj_diff <= obj_diff_min) 862 | { 863 | Gmin_idx=j; 864 | obj_diff_min = obj_diff; 865 | } 866 | } 867 | } 868 | } 869 | } 870 | 871 | if(Gmax+Gmax2 < eps) 872 | return 1; 873 | 874 | out_i = Gmax_idx; 875 | out_j = Gmin_idx; 876 | return 0; 877 | } 878 | 879 | bool Solver::be_shrunk(int i, double Gmax1, double Gmax2) 880 | { 881 | if(is_upper_bound(i)) 882 | { 883 | if(y[i]==+1) 884 | return(-G[i] > Gmax1); 885 | else 886 | return(-G[i] > Gmax2); 887 | } 888 | else if(is_lower_bound(i)) 889 | { 890 | if(y[i]==+1) 891 | return(G[i] > Gmax2); 892 | else 893 | return(G[i] > Gmax1); 894 | } 895 | else 896 | return(false); 897 | } 898 | 899 | void Solver::do_shrinking() 900 | { 901 | int i; 902 | double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) } 903 | double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) 
} 904 | 905 | // find maximal violating pair first 906 | for(i=0;i= Gmax1) 913 | Gmax1 = -G[i]; 914 | } 915 | if(!is_lower_bound(i)) 916 | { 917 | if(G[i] >= Gmax2) 918 | Gmax2 = G[i]; 919 | } 920 | } 921 | else 922 | { 923 | if(!is_upper_bound(i)) 924 | { 925 | if(-G[i] >= Gmax2) 926 | Gmax2 = -G[i]; 927 | } 928 | if(!is_lower_bound(i)) 929 | { 930 | if(G[i] >= Gmax1) 931 | Gmax1 = G[i]; 932 | } 933 | } 934 | } 935 | 936 | if(unshrink == false && Gmax1 + Gmax2 <= eps*10) 937 | { 938 | unshrink = true; 939 | reconstruct_gradient(); 940 | active_size = l; 941 | info("*"); 942 | } 943 | 944 | for(i=0;i i) 949 | { 950 | if (!be_shrunk(active_size, Gmax1, Gmax2)) 951 | { 952 | swap_index(i,active_size); 953 | break; 954 | } 955 | active_size--; 956 | } 957 | } 958 | } 959 | 960 | double Solver::calculate_rho() 961 | { 962 | double r; 963 | int nr_free = 0; 964 | double ub = INF, lb = -INF, sum_free = 0; 965 | for(int i=0;i0) 991 | r = sum_free/nr_free; 992 | else 993 | r = (ub+lb)/2; 994 | 995 | return r; 996 | } 997 | 998 | // 999 | // Solver for nu-svm classification and regression 1000 | // 1001 | // additional constraint: e^T \alpha = constant 1002 | // 1003 | class Solver_NU : public Solver 1004 | { 1005 | public: 1006 | Solver_NU() {} 1007 | void Solve(int l, const QMatrix& Q, const double *p, const schar *y, 1008 | double *alpha, double Cp, double Cn, double eps, 1009 | SolutionInfo* si, int shrinking) 1010 | { 1011 | this->si = si; 1012 | Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking); 1013 | } 1014 | private: 1015 | SolutionInfo *si; 1016 | int select_working_set(int &i, int &j); 1017 | double calculate_rho(); 1018 | bool be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4); 1019 | void do_shrinking(); 1020 | }; 1021 | 1022 | // return 1 if already optimal, return 0 otherwise 1023 | int Solver_NU::select_working_set(int &out_i, int &out_j) 1024 | { 1025 | // return i,j such that y_i = y_j and 1026 | // i: maximizes -y_i * grad(f)_i, i 
in I_up(\alpha) 1027 | // j: minimizes the decrease of obj value 1028 | // (if quadratic coefficeint <= 0, replace it with tau) 1029 | // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) 1030 | 1031 | double Gmaxp = -INF; 1032 | double Gmaxp2 = -INF; 1033 | int Gmaxp_idx = -1; 1034 | 1035 | double Gmaxn = -INF; 1036 | double Gmaxn2 = -INF; 1037 | int Gmaxn_idx = -1; 1038 | 1039 | int Gmin_idx = -1; 1040 | double obj_diff_min = INF; 1041 | 1042 | for(int t=0;t= Gmaxp) 1047 | { 1048 | Gmaxp = -G[t]; 1049 | Gmaxp_idx = t; 1050 | } 1051 | } 1052 | else 1053 | { 1054 | if(!is_lower_bound(t)) 1055 | if(G[t] >= Gmaxn) 1056 | { 1057 | Gmaxn = G[t]; 1058 | Gmaxn_idx = t; 1059 | } 1060 | } 1061 | 1062 | int ip = Gmaxp_idx; 1063 | int in = Gmaxn_idx; 1064 | const Qfloat *Q_ip = NULL; 1065 | const Qfloat *Q_in = NULL; 1066 | if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1 1067 | Q_ip = Q->get_Q(ip,active_size); 1068 | if(in != -1) 1069 | Q_in = Q->get_Q(in,active_size); 1070 | 1071 | for(int j=0;j= Gmaxp2) 1079 | Gmaxp2 = G[j]; 1080 | if (grad_diff > 0) 1081 | { 1082 | double obj_diff; 1083 | double quad_coef = QD[ip]+QD[j]-2*Q_ip[j]; 1084 | if (quad_coef > 0) 1085 | obj_diff = -(grad_diff*grad_diff)/quad_coef; 1086 | else 1087 | obj_diff = -(grad_diff*grad_diff)/TAU; 1088 | 1089 | if (obj_diff <= obj_diff_min) 1090 | { 1091 | Gmin_idx=j; 1092 | obj_diff_min = obj_diff; 1093 | } 1094 | } 1095 | } 1096 | } 1097 | else 1098 | { 1099 | if (!is_upper_bound(j)) 1100 | { 1101 | double grad_diff=Gmaxn-G[j]; 1102 | if (-G[j] >= Gmaxn2) 1103 | Gmaxn2 = -G[j]; 1104 | if (grad_diff > 0) 1105 | { 1106 | double obj_diff; 1107 | double quad_coef = QD[in]+QD[j]-2*Q_in[j]; 1108 | if (quad_coef > 0) 1109 | obj_diff = -(grad_diff*grad_diff)/quad_coef; 1110 | else 1111 | obj_diff = -(grad_diff*grad_diff)/TAU; 1112 | 1113 | if (obj_diff <= obj_diff_min) 1114 | { 1115 | Gmin_idx=j; 1116 | obj_diff_min = obj_diff; 1117 | } 1118 | } 1119 | } 1120 | } 1121 | } 1122 | 1123 | 
if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps) 1124 | return 1; 1125 | 1126 | if (y[Gmin_idx] == +1) 1127 | out_i = Gmaxp_idx; 1128 | else 1129 | out_i = Gmaxn_idx; 1130 | out_j = Gmin_idx; 1131 | 1132 | return 0; 1133 | } 1134 | 1135 | bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4) 1136 | { 1137 | if(is_upper_bound(i)) 1138 | { 1139 | if(y[i]==+1) 1140 | return(-G[i] > Gmax1); 1141 | else 1142 | return(-G[i] > Gmax4); 1143 | } 1144 | else if(is_lower_bound(i)) 1145 | { 1146 | if(y[i]==+1) 1147 | return(G[i] > Gmax2); 1148 | else 1149 | return(G[i] > Gmax3); 1150 | } 1151 | else 1152 | return(false); 1153 | } 1154 | 1155 | void Solver_NU::do_shrinking() 1156 | { 1157 | double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) } 1158 | double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) } 1159 | double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) } 1160 | double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) } 1161 | 1162 | // find maximal violating pair first 1163 | int i; 1164 | for(i=0;i Gmax1) Gmax1 = -G[i]; 1171 | } 1172 | else if(-G[i] > Gmax4) Gmax4 = -G[i]; 1173 | } 1174 | if(!is_lower_bound(i)) 1175 | { 1176 | if(y[i]==+1) 1177 | { 1178 | if(G[i] > Gmax2) Gmax2 = G[i]; 1179 | } 1180 | else if(G[i] > Gmax3) Gmax3 = G[i]; 1181 | } 1182 | } 1183 | 1184 | if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) 1185 | { 1186 | unshrink = true; 1187 | reconstruct_gradient(); 1188 | active_size = l; 1189 | } 1190 | 1191 | for(i=0;i i) 1196 | { 1197 | if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4)) 1198 | { 1199 | swap_index(i,active_size); 1200 | break; 1201 | } 1202 | active_size--; 1203 | } 1204 | } 1205 | } 1206 | 1207 | double Solver_NU::calculate_rho() 1208 | { 1209 | int nr_free1 = 0,nr_free2 = 0; 1210 | double ub1 = INF, ub2 = INF; 1211 | double lb1 = -INF, lb2 = -INF; 1212 | double sum_free1 = 0, sum_free2 = 0; 
1213 | 1214 | for(int i=0;i 0) 1244 | r1 = sum_free1/nr_free1; 1245 | else 1246 | r1 = (ub1+lb1)/2; 1247 | 1248 | if(nr_free2 > 0) 1249 | r2 = sum_free2/nr_free2; 1250 | else 1251 | r2 = (ub2+lb2)/2; 1252 | 1253 | si->r = (r1+r2)/2; 1254 | return (r1-r2)/2; 1255 | } 1256 | 1257 | // 1258 | // Q matrices for various formulations 1259 | // 1260 | class SVC_Q: public Kernel 1261 | { 1262 | public: 1263 | SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_) 1264 | :Kernel(prob.l, prob.x, param) 1265 | { 1266 | clone(y,y_,prob.l); 1267 | cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); 1268 | QD = new double[prob.l]; 1269 | for(int i=0;i*kernel_function)(i,i); 1271 | } 1272 | 1273 | Qfloat *get_Q(int i, int len) const 1274 | { 1275 | Qfloat *data; 1276 | int start, j; 1277 | if((start = cache->get_data(i,&data,len)) < len) 1278 | { 1279 | for(j=start;j*kernel_function)(i,j)); 1281 | } 1282 | return data; 1283 | } 1284 | 1285 | double *get_QD() const 1286 | { 1287 | return QD; 1288 | } 1289 | 1290 | void swap_index(int i, int j) const 1291 | { 1292 | cache->swap_index(i,j); 1293 | Kernel::swap_index(i,j); 1294 | swap(y[i],y[j]); 1295 | swap(QD[i],QD[j]); 1296 | } 1297 | 1298 | ~SVC_Q() 1299 | { 1300 | delete[] y; 1301 | delete cache; 1302 | delete[] QD; 1303 | } 1304 | private: 1305 | schar *y; 1306 | Cache *cache; 1307 | double *QD; 1308 | }; 1309 | 1310 | class ONE_CLASS_Q: public Kernel 1311 | { 1312 | public: 1313 | ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param) 1314 | :Kernel(prob.l, prob.x, param) 1315 | { 1316 | cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20))); 1317 | QD = new double[prob.l]; 1318 | for(int i=0;i*kernel_function)(i,i); 1320 | } 1321 | 1322 | Qfloat *get_Q(int i, int len) const 1323 | { 1324 | Qfloat *data; 1325 | int start, j; 1326 | if((start = cache->get_data(i,&data,len)) < len) 1327 | { 1328 | for(j=start;j*kernel_function)(i,j); 1330 | } 1331 | return data; 1332 | } 1333 | 
1334 | double *get_QD() const 1335 | { 1336 | return QD; 1337 | } 1338 | 1339 | void swap_index(int i, int j) const 1340 | { 1341 | cache->swap_index(i,j); 1342 | Kernel::swap_index(i,j); 1343 | swap(QD[i],QD[j]); 1344 | } 1345 | 1346 | ~ONE_CLASS_Q() 1347 | { 1348 | delete cache; 1349 | delete[] QD; 1350 | } 1351 | private: 1352 | Cache *cache; 1353 | double *QD; 1354 | }; 1355 | 1356 | class SVR_Q: public Kernel 1357 | { 1358 | public: 1359 | SVR_Q(const svm_problem& prob, const svm_parameter& param) 1360 | :Kernel(prob.l, prob.x, param) 1361 | { 1362 | l = prob.l; 1363 | cache = new Cache(l,(long int)(param.cache_size*(1<<20))); 1364 | QD = new double[2*l]; 1365 | sign = new schar[2*l]; 1366 | index = new int[2*l]; 1367 | for(int k=0;k*kernel_function)(k,k); 1374 | QD[k+l] = QD[k]; 1375 | } 1376 | buffer[0] = new Qfloat[2*l]; 1377 | buffer[1] = new Qfloat[2*l]; 1378 | next_buffer = 0; 1379 | } 1380 | 1381 | void swap_index(int i, int j) const 1382 | { 1383 | swap(sign[i],sign[j]); 1384 | swap(index[i],index[j]); 1385 | swap(QD[i],QD[j]); 1386 | } 1387 | 1388 | Qfloat *get_Q(int i, int len) const 1389 | { 1390 | Qfloat *data; 1391 | int j, real_i = index[i]; 1392 | if(cache->get_data(real_i,&data,l) < l) 1393 | { 1394 | for(j=0;j*kernel_function)(real_i,j); 1396 | } 1397 | 1398 | // reorder and copy 1399 | Qfloat *buf = buffer[next_buffer]; 1400 | next_buffer = 1 - next_buffer; 1401 | schar si = sign[i]; 1402 | for(j=0;jl; 1439 | double *minus_ones = new double[l]; 1440 | schar *y = new schar[l]; 1441 | 1442 | int i; 1443 | 1444 | for(i=0;iy[i] > 0) y[i] = +1; else y[i] = -1; 1449 | } 1450 | 1451 | Solver s; 1452 | s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y, 1453 | alpha, Cp, Cn, param->eps, si, param->shrinking); 1454 | 1455 | double sum_alpha=0; 1456 | for(i=0;il)); 1461 | 1462 | for(i=0;il; 1475 | double nu = param->nu; 1476 | 1477 | schar *y = new schar[l]; 1478 | 1479 | for(i=0;iy[i]>0) 1481 | y[i] = +1; 1482 | else 1483 | y[i] = -1; 1484 | 1485 | 
double sum_pos = nu*l/2; 1486 | double sum_neg = nu*l/2; 1487 | 1488 | for(i=0;ieps, si, param->shrinking); 1508 | double r = si->r; 1509 | 1510 | info("C = %f\n",1/r); 1511 | 1512 | for(i=0;irho /= r; 1516 | si->obj /= (r*r); 1517 | si->upper_bound_p = 1/r; 1518 | si->upper_bound_n = 1/r; 1519 | 1520 | delete[] y; 1521 | delete[] zeros; 1522 | } 1523 | 1524 | static void solve_one_class( 1525 | const svm_problem *prob, const svm_parameter *param, 1526 | double *alpha, Solver::SolutionInfo* si) 1527 | { 1528 | int l = prob->l; 1529 | double *zeros = new double[l]; 1530 | schar *ones = new schar[l]; 1531 | int i; 1532 | 1533 | int n = (int)(param->nu*prob->l); // # of alpha's at upper bound 1534 | 1535 | for(i=0;il) 1538 | alpha[n] = param->nu * prob->l - n; 1539 | for(i=n+1;ieps, si, param->shrinking); 1551 | 1552 | delete[] zeros; 1553 | delete[] ones; 1554 | } 1555 | 1556 | static void solve_epsilon_svr( 1557 | const svm_problem *prob, const svm_parameter *param, 1558 | double *alpha, Solver::SolutionInfo* si) 1559 | { 1560 | int l = prob->l; 1561 | double *alpha2 = new double[2*l]; 1562 | double *linear_term = new double[2*l]; 1563 | schar *y = new schar[2*l]; 1564 | int i; 1565 | 1566 | for(i=0;ip - prob->y[i]; 1570 | y[i] = 1; 1571 | 1572 | alpha2[i+l] = 0; 1573 | linear_term[i+l] = param->p + prob->y[i]; 1574 | y[i+l] = -1; 1575 | } 1576 | 1577 | Solver s; 1578 | s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y, 1579 | alpha2, param->C, param->C, param->eps, si, param->shrinking); 1580 | 1581 | double sum_alpha = 0; 1582 | for(i=0;iC*l)); 1588 | 1589 | delete[] alpha2; 1590 | delete[] linear_term; 1591 | delete[] y; 1592 | } 1593 | 1594 | static void solve_nu_svr( 1595 | const svm_problem *prob, const svm_parameter *param, 1596 | double *alpha, Solver::SolutionInfo* si) 1597 | { 1598 | int l = prob->l; 1599 | double C = param->C; 1600 | double *alpha2 = new double[2*l]; 1601 | double *linear_term = new double[2*l]; 1602 | schar *y = new schar[2*l]; 1603 | int 
i; 1604 | 1605 | double sum = C * param->nu * l / 2; 1606 | for(i=0;iy[i]; 1612 | y[i] = 1; 1613 | 1614 | linear_term[i+l] = prob->y[i]; 1615 | y[i+l] = -1; 1616 | } 1617 | 1618 | Solver_NU s; 1619 | s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y, 1620 | alpha2, C, C, param->eps, si, param->shrinking); 1621 | 1622 | info("epsilon = %f\n",-si->r); 1623 | 1624 | for(i=0;il); 1646 | Solver::SolutionInfo si; 1647 | switch(param->svm_type) 1648 | { 1649 | case C_SVC: 1650 | solve_c_svc(prob,param,alpha,&si,Cp,Cn); 1651 | break; 1652 | case NU_SVC: 1653 | solve_nu_svc(prob,param,alpha,&si); 1654 | break; 1655 | case ONE_CLASS: 1656 | solve_one_class(prob,param,alpha,&si); 1657 | break; 1658 | case EPSILON_SVR: 1659 | solve_epsilon_svr(prob,param,alpha,&si); 1660 | break; 1661 | case NU_SVR: 1662 | solve_nu_svr(prob,param,alpha,&si); 1663 | break; 1664 | } 1665 | 1666 | info("obj = %f, rho = %f\n",si.obj,si.rho); 1667 | 1668 | // output SVs 1669 | 1670 | int nSV = 0; 1671 | int nBSV = 0; 1672 | for(int i=0;il;i++) 1673 | { 1674 | if(fabs(alpha[i]) > 0) 1675 | { 1676 | ++nSV; 1677 | if(prob->y[i] > 0) 1678 | { 1679 | if(fabs(alpha[i]) >= si.upper_bound_p) 1680 | ++nBSV; 1681 | } 1682 | else 1683 | { 1684 | if(fabs(alpha[i]) >= si.upper_bound_n) 1685 | ++nBSV; 1686 | } 1687 | } 1688 | } 1689 | 1690 | info("nSV = %d, nBSV = %d\n",nSV,nBSV); 1691 | 1692 | decision_function f; 1693 | f.alpha = alpha; 1694 | f.rho = si.rho; 1695 | return f; 1696 | } 1697 | 1698 | // Platt's binary SVM Probablistic Output: an improvement from Lin et al. 
1699 | static void sigmoid_train( 1700 | int l, const double *dec_values, const double *labels, 1701 | double& A, double& B) 1702 | { 1703 | double prior1=0, prior0 = 0; 1704 | int i; 1705 | 1706 | for (i=0;i 0) prior1+=1; 1708 | else prior0+=1; 1709 | 1710 | int max_iter=100; // Maximal number of iterations 1711 | double min_step=1e-10; // Minimal step taken in line search 1712 | double sigma=1e-12; // For numerically strict PD of Hessian 1713 | double eps=1e-5; 1714 | double hiTarget=(prior1+1.0)/(prior1+2.0); 1715 | double loTarget=1/(prior0+2.0); 1716 | double *t=Malloc(double,l); 1717 | double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize; 1718 | double newA,newB,newf,d1,d2; 1719 | int iter; 1720 | 1721 | // Initial Point and Initial Fun Value 1722 | A=0.0; B=log((prior0+1.0)/(prior1+1.0)); 1723 | double fval = 0.0; 1724 | 1725 | for (i=0;i0) t[i]=hiTarget; 1728 | else t[i]=loTarget; 1729 | fApB = dec_values[i]*A+B; 1730 | if (fApB>=0) 1731 | fval += t[i]*fApB + log(1+exp(-fApB)); 1732 | else 1733 | fval += (t[i] - 1)*fApB +log(1+exp(fApB)); 1734 | } 1735 | for (iter=0;iter= 0) 1745 | { 1746 | p=exp(-fApB)/(1.0+exp(-fApB)); 1747 | q=1.0/(1.0+exp(-fApB)); 1748 | } 1749 | else 1750 | { 1751 | p=1.0/(1.0+exp(fApB)); 1752 | q=exp(fApB)/(1.0+exp(fApB)); 1753 | } 1754 | d2=p*q; 1755 | h11+=dec_values[i]*dec_values[i]*d2; 1756 | h22+=d2; 1757 | h21+=dec_values[i]*d2; 1758 | d1=t[i]-p; 1759 | g1+=dec_values[i]*d1; 1760 | g2+=d1; 1761 | } 1762 | 1763 | // Stopping Criteria 1764 | if (fabs(g1)= min_step) 1776 | { 1777 | newA = A + stepsize * dA; 1778 | newB = B + stepsize * dB; 1779 | 1780 | // New function value 1781 | newf = 0.0; 1782 | for (i=0;i= 0) 1786 | newf += t[i]*fApB + log(1+exp(-fApB)); 1787 | else 1788 | newf += (t[i] - 1)*fApB +log(1+exp(fApB)); 1789 | } 1790 | // Check sufficient decrease 1791 | if (newf=max_iter) 1808 | info("Reaching maximal iterations in two-class probability estimates\n"); 1809 | free(t); 1810 | } 1811 | 1812 | static double 
sigmoid_predict(double decision_value, double A, double B) 1813 | { 1814 | double fApB = decision_value*A+B; 1815 | // 1-p used later; avoid catastrophic cancellation 1816 | if (fApB >= 0) 1817 | return exp(-fApB)/(1.0+exp(-fApB)); 1818 | else 1819 | return 1.0/(1+exp(fApB)) ; 1820 | } 1821 | 1822 | // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng 1823 | static void multiclass_probability(int k, double **r, double *p) 1824 | { 1825 | int t,j; 1826 | int iter = 0, max_iter=max(100,k); 1827 | double **Q=Malloc(double *,k); 1828 | double *Qp=Malloc(double,k); 1829 | double pQp, eps=0.005/k; 1830 | 1831 | for (t=0;tmax_error) 1863 | max_error=error; 1864 | } 1865 | if (max_error=max_iter) 1880 | info("Exceeds max_iter in multiclass_prob\n"); 1881 | for(t=0;tl); 1894 | double *dec_values = Malloc(double,prob->l); 1895 | 1896 | // random shuffle 1897 | for(i=0;il;i++) perm[i]=i; 1898 | for(i=0;il;i++) 1899 | { 1900 | int j = i+rand()%(prob->l-i); 1901 | swap(perm[i],perm[j]); 1902 | } 1903 | for(i=0;il/nr_fold; 1906 | int end = (i+1)*prob->l/nr_fold; 1907 | int j,k; 1908 | struct svm_problem subprob; 1909 | 1910 | subprob.l = prob->l-(end-begin); 1911 | subprob.x = Malloc(struct svm_node*,subprob.l); 1912 | subprob.y = Malloc(double,subprob.l); 1913 | 1914 | k=0; 1915 | for(j=0;jx[perm[j]]; 1918 | subprob.y[k] = prob->y[perm[j]]; 1919 | ++k; 1920 | } 1921 | for(j=end;jl;j++) 1922 | { 1923 | subprob.x[k] = prob->x[perm[j]]; 1924 | subprob.y[k] = prob->y[perm[j]]; 1925 | ++k; 1926 | } 1927 | int p_count=0,n_count=0; 1928 | for(j=0;j0) 1930 | p_count++; 1931 | else 1932 | n_count++; 1933 | 1934 | if(p_count==0 && n_count==0) 1935 | for(j=begin;j 0 && n_count == 0) 1938 | for(j=begin;j 0) 1941 | for(j=begin;jx[perm[j]],&(dec_values[perm[j]])); 1959 | // ensure +1 -1 order; reason not using CV subroutine 1960 | dec_values[perm[j]] *= submodel->label[0]; 1961 | } 1962 | svm_free_and_destroy_model(&submodel); 1963 | svm_destroy_param(&subparam); 1964 | } 1965 | 
free(subprob.x); 1966 | free(subprob.y); 1967 | } 1968 | sigmoid_train(prob->l,dec_values,prob->y,probA,probB); 1969 | free(dec_values); 1970 | free(perm); 1971 | } 1972 | 1973 | // Return parameter of a Laplace distribution 1974 | static double svm_svr_probability( 1975 | const svm_problem *prob, const svm_parameter *param) 1976 | { 1977 | int i; 1978 | int nr_fold = 5; 1979 | double *ymv = Malloc(double,prob->l); 1980 | double mae = 0; 1981 | 1982 | svm_parameter newparam = *param; 1983 | newparam.probability = 0; 1984 | svm_cross_validation(prob,&newparam,nr_fold,ymv); 1985 | for(i=0;il;i++) 1986 | { 1987 | ymv[i]=prob->y[i]-ymv[i]; 1988 | mae += fabs(ymv[i]); 1989 | } 1990 | mae /= prob->l; 1991 | double std=sqrt(2*mae*mae); 1992 | int count=0; 1993 | mae=0; 1994 | for(i=0;il;i++) 1995 | if (fabs(ymv[i]) > 5*std) 1996 | count=count+1; 1997 | else 1998 | mae+=fabs(ymv[i]); 1999 | mae /= (prob->l-count); 2000 | info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae); 2001 | free(ymv); 2002 | return mae; 2003 | } 2004 | 2005 | 2006 | // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data 2007 | // perm, length l, must be allocated before calling this subroutine 2008 | static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm) 2009 | { 2010 | int l = prob->l; 2011 | int max_nr_class = 16; 2012 | int nr_class = 0; 2013 | int *label = Malloc(int,max_nr_class); 2014 | int *count = Malloc(int,max_nr_class); 2015 | int *data_label = Malloc(int,l); 2016 | int i; 2017 | 2018 | for(i=0;iy[i]; 2021 | int j; 2022 | for(j=0;jparam = *param; 2072 | model->free_sv = 0; // XXX 2073 | 2074 | if(param->svm_type == ONE_CLASS || 2075 | param->svm_type == EPSILON_SVR || 2076 | param->svm_type == NU_SVR) 2077 | { 2078 | // regression or one-class-svm 2079 | model->nr_class = 
2; 2080 | model->label = NULL; 2081 | model->nSV = NULL; 2082 | model->probA = NULL; model->probB = NULL; 2083 | model->sv_coef = Malloc(double *,1); 2084 | 2085 | if(param->probability && 2086 | (param->svm_type == EPSILON_SVR || 2087 | param->svm_type == NU_SVR)) 2088 | { 2089 | model->probA = Malloc(double,1); 2090 | model->probA[0] = svm_svr_probability(prob,param); 2091 | } 2092 | 2093 | decision_function f = svm_train_one(prob,param,0,0); 2094 | model->rho = Malloc(double,1); 2095 | model->rho[0] = f.rho; 2096 | 2097 | int nSV = 0; 2098 | int i; 2099 | for(i=0;il;i++) 2100 | if(fabs(f.alpha[i]) > 0) ++nSV; 2101 | model->l = nSV; 2102 | model->SV = Malloc(svm_node *,nSV); 2103 | model->sv_coef[0] = Malloc(double,nSV); 2104 | int j = 0; 2105 | for(i=0;il;i++) 2106 | if(fabs(f.alpha[i]) > 0) 2107 | { 2108 | model->SV[j] = prob->x[i]; 2109 | model->sv_coef[0][j] = f.alpha[i]; 2110 | ++j; 2111 | } 2112 | 2113 | free(f.alpha); 2114 | } 2115 | else 2116 | { 2117 | // classification 2118 | int l = prob->l; 2119 | int nr_class; 2120 | int *label = NULL; 2121 | int *start = NULL; 2122 | int *count = NULL; 2123 | int *perm = Malloc(int,l); 2124 | 2125 | // group training data of the same class 2126 | svm_group_classes(prob,&nr_class,&label,&start,&count,perm); 2127 | if(nr_class == 1) 2128 | info("WARNING: training data in only one class. 
See README for details.\n"); 2129 | 2130 | svm_node **x = Malloc(svm_node *,l); 2131 | int i; 2132 | for(i=0;ix[perm[i]]; 2134 | 2135 | // calculate weighted C 2136 | 2137 | double *weighted_C = Malloc(double, nr_class); 2138 | for(i=0;iC; 2140 | for(i=0;inr_weight;i++) 2141 | { 2142 | int j; 2143 | for(j=0;jweight_label[i] == label[j]) 2145 | break; 2146 | if(j == nr_class) 2147 | fprintf(stderr,"WARNING: class label %d specified in weight is not found\n", param->weight_label[i]); 2148 | else 2149 | weighted_C[j] *= param->weight[i]; 2150 | } 2151 | 2152 | // train k*(k-1)/2 models 2153 | 2154 | bool *nonzero = Malloc(bool,l); 2155 | for(i=0;iprobability) 2161 | { 2162 | probA=Malloc(double,nr_class*(nr_class-1)/2); 2163 | probB=Malloc(double,nr_class*(nr_class-1)/2); 2164 | } 2165 | 2166 | int p = 0; 2167 | for(i=0;iprobability) 2189 | svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]); 2190 | 2191 | f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]); 2192 | for(k=0;k 0) 2194 | nonzero[si+k] = true; 2195 | for(k=0;k 0) 2197 | nonzero[sj+k] = true; 2198 | free(sub_prob.x); 2199 | free(sub_prob.y); 2200 | ++p; 2201 | } 2202 | 2203 | // build output 2204 | 2205 | model->nr_class = nr_class; 2206 | 2207 | model->label = Malloc(int,nr_class); 2208 | for(i=0;ilabel[i] = label[i]; 2210 | 2211 | model->rho = Malloc(double,nr_class*(nr_class-1)/2); 2212 | for(i=0;irho[i] = f[i].rho; 2214 | 2215 | if(param->probability) 2216 | { 2217 | model->probA = Malloc(double,nr_class*(nr_class-1)/2); 2218 | model->probB = Malloc(double,nr_class*(nr_class-1)/2); 2219 | for(i=0;iprobA[i] = probA[i]; 2222 | model->probB[i] = probB[i]; 2223 | } 2224 | } 2225 | else 2226 | { 2227 | model->probA=NULL; 2228 | model->probB=NULL; 2229 | } 2230 | 2231 | int total_sv = 0; 2232 | int *nz_count = Malloc(int,nr_class); 2233 | model->nSV = Malloc(int,nr_class); 2234 | for(i=0;inSV[i] = nSV; 2244 | nz_count[i] = nSV; 2245 | } 2246 | 2247 | 
info("Total nSV = %d\n",total_sv); 2248 | 2249 | model->l = total_sv; 2250 | model->SV = Malloc(svm_node *,total_sv); 2251 | p = 0; 2252 | for(i=0;iSV[p++] = x[i]; 2254 | 2255 | int *nz_start = Malloc(int,nr_class); 2256 | nz_start[0] = 0; 2257 | for(i=1;isv_coef = Malloc(double *,nr_class-1); 2261 | for(i=0;isv_coef[i] = Malloc(double,total_sv); 2263 | 2264 | p = 0; 2265 | for(i=0;isv_coef[j-1][q++] = f[p].alpha[k]; 2282 | q = nz_start[j]; 2283 | for(k=0;ksv_coef[i][q++] = f[p].alpha[ci+k]; 2286 | ++p; 2287 | } 2288 | 2289 | free(label); 2290 | free(probA); 2291 | free(probB); 2292 | free(count); 2293 | free(perm); 2294 | free(start); 2295 | free(x); 2296 | free(weighted_C); 2297 | free(nonzero); 2298 | for(i=0;il; 2313 | int *perm = Malloc(int,l); 2314 | int nr_class; 2315 | 2316 | // stratified cv may not give leave-one-out rate 2317 | // Each class to l folds -> some folds may have zero elements 2318 | if((param->svm_type == C_SVC || 2319 | param->svm_type == NU_SVC) && nr_fold < l) 2320 | { 2321 | int *start = NULL; 2322 | int *label = NULL; 2323 | int *count = NULL; 2324 | svm_group_classes(prob,&nr_class,&label,&start,&count,perm); 2325 | 2326 | // random shuffle and then data grouped by fold using the array perm 2327 | int *fold_count = Malloc(int,nr_fold); 2328 | int c; 2329 | int *index = Malloc(int,l); 2330 | for(i=0;ix[perm[j]]; 2394 | subprob.y[k] = prob->y[perm[j]]; 2395 | ++k; 2396 | } 2397 | for(j=end;jx[perm[j]]; 2400 | subprob.y[k] = prob->y[perm[j]]; 2401 | ++k; 2402 | } 2403 | struct svm_model *submodel = svm_train(&subprob,param); 2404 | if(param->probability && 2405 | (param->svm_type == C_SVC || param->svm_type == NU_SVC)) 2406 | { 2407 | double *prob_estimates=Malloc(double,svm_get_nr_class(submodel)); 2408 | for(j=begin;jx[perm[j]],prob_estimates); 2410 | free(prob_estimates); 2411 | } 2412 | else 2413 | for(j=begin;jx[perm[j]]); 2415 | svm_free_and_destroy_model(&submodel); 2416 | free(subprob.x); 2417 | free(subprob.y); 2418 | } 2419 | 
free(fold_start); 2420 | free(perm); 2421 | } 2422 | 2423 | 2424 | int svm_get_svm_type(const svm_model *model) 2425 | { 2426 | return model->param.svm_type; 2427 | } 2428 | 2429 | int svm_get_nr_class(const svm_model *model) 2430 | { 2431 | return model->nr_class; 2432 | } 2433 | 2434 | void svm_get_labels(const svm_model *model, int* label) 2435 | { 2436 | if (model->label != NULL) 2437 | for(int i=0;inr_class;i++) 2438 | label[i] = model->label[i]; 2439 | } 2440 | 2441 | double svm_get_svr_probability(const svm_model *model) 2442 | { 2443 | if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) && 2444 | model->probA!=NULL) 2445 | return model->probA[0]; 2446 | else 2447 | { 2448 | fprintf(stderr,"Model doesn't contain information for SVR probability inference\n"); 2449 | return 0; 2450 | } 2451 | } 2452 | 2453 | double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values) 2454 | { 2455 | int i; 2456 | if(model->param.svm_type == ONE_CLASS || 2457 | model->param.svm_type == EPSILON_SVR || 2458 | model->param.svm_type == NU_SVR) 2459 | { 2460 | double *sv_coef = model->sv_coef[0]; 2461 | double sum = 0; 2462 | for(i=0;il;i++) 2463 | sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param); 2464 | sum -= model->rho[0]; 2465 | *dec_values = sum; 2466 | 2467 | if(model->param.svm_type == ONE_CLASS) 2468 | return (sum>0)?1:-1; 2469 | else 2470 | return sum; 2471 | } 2472 | else 2473 | { 2474 | int nr_class = model->nr_class; 2475 | int l = model->l; 2476 | 2477 | double *kvalue = Malloc(double,l); 2478 | for(i=0;iSV[i],model->param); 2480 | 2481 | int *start = Malloc(int,nr_class); 2482 | start[0] = 0; 2483 | for(i=1;inSV[i-1]; 2485 | 2486 | int *vote = Malloc(int,nr_class); 2487 | for(i=0;inSV[i]; 2498 | int cj = model->nSV[j]; 2499 | 2500 | int k; 2501 | double *coef1 = model->sv_coef[j-1]; 2502 | double *coef2 = model->sv_coef[i]; 2503 | for(k=0;krho[p]; 2508 | dec_values[p] = sum; 2509 | 2510 | 
if(dec_values[p] > 0) 2511 | ++vote[i]; 2512 | else 2513 | ++vote[j]; 2514 | p++; 2515 | } 2516 | 2517 | int vote_max_idx = 0; 2518 | for(i=1;i vote[vote_max_idx]) 2520 | vote_max_idx = i; 2521 | 2522 | free(kvalue); 2523 | free(start); 2524 | free(vote); 2525 | return model->label[vote_max_idx]; 2526 | } 2527 | } 2528 | 2529 | double svm_predict(const svm_model *model, const svm_node *x) 2530 | { 2531 | int nr_class = model->nr_class; 2532 | double *dec_values; 2533 | if(model->param.svm_type == ONE_CLASS || 2534 | model->param.svm_type == EPSILON_SVR || 2535 | model->param.svm_type == NU_SVR) 2536 | dec_values = Malloc(double, 1); 2537 | else 2538 | dec_values = Malloc(double, nr_class*(nr_class-1)/2); 2539 | double pred_result = svm_predict_values(model, x, dec_values); 2540 | free(dec_values); 2541 | return pred_result; 2542 | } 2543 | 2544 | double svm_predict_probability( 2545 | const svm_model *model, const svm_node *x, double *prob_estimates) 2546 | { 2547 | if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && 2548 | model->probA!=NULL && model->probB!=NULL) 2549 | { 2550 | int i; 2551 | int nr_class = model->nr_class; 2552 | double *dec_values = Malloc(double, nr_class*(nr_class-1)/2); 2553 | svm_predict_values(model, x, dec_values); 2554 | 2555 | double min_prob=1e-7; 2556 | double **pairwise_prob=Malloc(double *,nr_class); 2557 | for(i=0;iprobA[k],model->probB[k]),min_prob),1-min_prob); 2564 | pairwise_prob[j][i]=1-pairwise_prob[i][j]; 2565 | k++; 2566 | } 2567 | multiclass_probability(nr_class,pairwise_prob,prob_estimates); 2568 | 2569 | int prob_max_idx = 0; 2570 | for(i=1;i prob_estimates[prob_max_idx]) 2572 | prob_max_idx = i; 2573 | for(i=0;ilabel[prob_max_idx]; 2578 | } 2579 | else 2580 | return svm_predict(model, x); 2581 | } 2582 | 2583 | static const char *svm_type_table[] = 2584 | { 2585 | "c_svc","nu_svc","one_class","epsilon_svr","nu_svr",NULL 2586 | }; 2587 | 2588 | static const char *kernel_type_table[]= 2589 | { 2590 
| "linear","polynomial","rbf","sigmoid","precomputed",NULL 2591 | }; 2592 | 2593 | int svm_save_model(const char *model_file_name, const svm_model *model) 2594 | { 2595 | FILE *fp = fopen(model_file_name,"w"); 2596 | if(fp==NULL) return -1; 2597 | 2598 | //char *old_locale = _strdup(setlocale(LC_ALL, NULL)); 2599 | char *old_locale = strdup(setlocale(LC_ALL, NULL)); 2600 | setlocale(LC_ALL, "C"); 2601 | 2602 | const svm_parameter& param = model->param; 2603 | 2604 | fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]); 2605 | fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]); 2606 | 2607 | if(param.kernel_type == POLY) 2608 | fprintf(fp,"degree %d\n", param.degree); 2609 | 2610 | if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID) 2611 | fprintf(fp,"gamma %g\n", param.gamma); 2612 | 2613 | if(param.kernel_type == POLY || param.kernel_type == SIGMOID) 2614 | fprintf(fp,"coef0 %g\n", param.coef0); 2615 | 2616 | int nr_class = model->nr_class; 2617 | int l = model->l; 2618 | fprintf(fp, "nr_class %d\n", nr_class); 2619 | fprintf(fp, "total_sv %d\n",l); 2620 | 2621 | { 2622 | fprintf(fp, "rho"); 2623 | for(int i=0;irho[i]); 2625 | fprintf(fp, "\n"); 2626 | } 2627 | 2628 | if(model->label) 2629 | { 2630 | fprintf(fp, "label"); 2631 | for(int i=0;ilabel[i]); 2633 | fprintf(fp, "\n"); 2634 | } 2635 | 2636 | if(model->probA) // regression has probA only 2637 | { 2638 | fprintf(fp, "probA"); 2639 | for(int i=0;iprobA[i]); 2641 | fprintf(fp, "\n"); 2642 | } 2643 | if(model->probB) 2644 | { 2645 | fprintf(fp, "probB"); 2646 | for(int i=0;iprobB[i]); 2648 | fprintf(fp, "\n"); 2649 | } 2650 | 2651 | if(model->nSV) 2652 | { 2653 | fprintf(fp, "nr_sv"); 2654 | for(int i=0;inSV[i]); 2656 | fprintf(fp, "\n"); 2657 | } 2658 | 2659 | fprintf(fp, "SV\n"); 2660 | const double * const *sv_coef = model->sv_coef; 2661 | const svm_node * const *SV = model->SV; 2662 | 2663 | for(int i=0;ivalue)); 2672 | else 2673 | 
while(p->index != -1) 2674 | { 2675 | fprintf(fp,"%d:%.8g ",p->index,p->value); 2676 | p++; 2677 | } 2678 | fprintf(fp, "\n"); 2679 | } 2680 | 2681 | setlocale(LC_ALL, old_locale); 2682 | free(old_locale); 2683 | 2684 | 2685 | if (ferror(fp) != 0 || fclose(fp) != 0) return -1; 2686 | else return 0; 2687 | } 2688 | 2689 | static char *line = NULL; 2690 | static int max_line_len; 2691 | 2692 | static char* readline(FILE *input) 2693 | { 2694 | int len; 2695 | 2696 | if(fgets(line,max_line_len,input) == NULL) 2697 | return NULL; 2698 | 2699 | while(strrchr(line,'\n') == NULL) 2700 | { 2701 | max_line_len *= 2; 2702 | line = (char *) realloc(line,max_line_len); 2703 | len = (int) strlen(line); 2704 | if(fgets(line+len,max_line_len-len,input) == NULL) 2705 | break; 2706 | } 2707 | return line; 2708 | } 2709 | 2710 | svm_model *svm_load_model(const char *model_file_name) 2711 | { 2712 | FILE *fp = fopen(model_file_name,"rb"); 2713 | if(fp==NULL) return NULL; 2714 | 2715 | //char *old_locale = _strdup(setlocale(LC_ALL, NULL)); 2716 | char *old_locale = strdup(setlocale(LC_ALL, NULL)); 2717 | setlocale(LC_ALL, "C"); 2718 | 2719 | // read parameters 2720 | 2721 | svm_model *model = Malloc(svm_model,1); 2722 | svm_parameter& param = model->param; 2723 | model->rho = NULL; 2724 | model->probA = NULL; 2725 | model->probB = NULL; 2726 | model->label = NULL; 2727 | model->nSV = NULL; 2728 | 2729 | char cmd[81]; 2730 | while(1) 2731 | { 2732 | fscanf(fp,"%80s",cmd); 2733 | 2734 | if(strcmp(cmd,"svm_type")==0) 2735 | { 2736 | fscanf(fp,"%80s",cmd); 2737 | int i; 2738 | for(i=0;svm_type_table[i];i++) 2739 | { 2740 | if(strcmp(svm_type_table[i],cmd)==0) 2741 | { 2742 | param.svm_type=i; 2743 | break; 2744 | } 2745 | } 2746 | if(svm_type_table[i] == NULL) 2747 | { 2748 | fprintf(stderr,"unknown svm type.\n"); 2749 | 2750 | setlocale(LC_ALL, old_locale); 2751 | free(old_locale); 2752 | free(model->rho); 2753 | free(model->label); 2754 | free(model->nSV); 2755 | free(model); 2756 | 
return NULL; 2757 | } 2758 | } 2759 | else if(strcmp(cmd,"kernel_type")==0) 2760 | { 2761 | fscanf(fp,"%80s",cmd); 2762 | int i; 2763 | for(i=0;kernel_type_table[i];i++) 2764 | { 2765 | if(strcmp(kernel_type_table[i],cmd)==0) 2766 | { 2767 | param.kernel_type=i; 2768 | break; 2769 | } 2770 | } 2771 | if(kernel_type_table[i] == NULL) 2772 | { 2773 | fprintf(stderr,"unknown kernel function.\n"); 2774 | 2775 | setlocale(LC_ALL, old_locale); 2776 | free(old_locale); 2777 | free(model->rho); 2778 | free(model->label); 2779 | free(model->nSV); 2780 | free(model); 2781 | return NULL; 2782 | } 2783 | } 2784 | else if(strcmp(cmd,"degree")==0) 2785 | fscanf(fp,"%d",¶m.degree); 2786 | else if(strcmp(cmd,"gamma")==0) 2787 | fscanf(fp,"%lf",¶m.gamma); 2788 | else if(strcmp(cmd,"coef0")==0) 2789 | fscanf(fp,"%lf",¶m.coef0); 2790 | else if(strcmp(cmd,"nr_class")==0) 2791 | fscanf(fp,"%d",&model->nr_class); 2792 | else if(strcmp(cmd,"total_sv")==0) 2793 | fscanf(fp,"%d",&model->l); 2794 | else if(strcmp(cmd,"rho")==0) 2795 | { 2796 | int n = model->nr_class * (model->nr_class-1)/2; 2797 | model->rho = Malloc(double,n); 2798 | for(int i=0;irho[i]); 2800 | } 2801 | else if(strcmp(cmd,"label")==0) 2802 | { 2803 | int n = model->nr_class; 2804 | model->label = Malloc(int,n); 2805 | for(int i=0;ilabel[i]); 2807 | } 2808 | else if(strcmp(cmd,"probA")==0) 2809 | { 2810 | int n = model->nr_class * (model->nr_class-1)/2; 2811 | model->probA = Malloc(double,n); 2812 | for(int i=0;iprobA[i]); 2814 | } 2815 | else if(strcmp(cmd,"probB")==0) 2816 | { 2817 | int n = model->nr_class * (model->nr_class-1)/2; 2818 | model->probB = Malloc(double,n); 2819 | for(int i=0;iprobB[i]); 2821 | } 2822 | else if(strcmp(cmd,"nr_sv")==0) 2823 | { 2824 | int n = model->nr_class; 2825 | model->nSV = Malloc(int,n); 2826 | for(int i=0;inSV[i]); 2828 | } 2829 | else if(strcmp(cmd,"SV")==0) 2830 | { 2831 | while(1) 2832 | { 2833 | int c = getc(fp); 2834 | if(c==EOF || c=='\n') break; 2835 | } 2836 | break; 2837 | } 
2838 | else 2839 | { 2840 | fprintf(stderr,"unknown text in model file: [%s]\n",cmd); 2841 | 2842 | setlocale(LC_ALL, old_locale); 2843 | free(old_locale); 2844 | free(model->rho); 2845 | free(model->label); 2846 | free(model->nSV); 2847 | free(model); 2848 | return NULL; 2849 | } 2850 | } 2851 | 2852 | // read sv_coef and SV 2853 | 2854 | int elements = 0; 2855 | long pos = ftell(fp); 2856 | 2857 | max_line_len = 1024; 2858 | line = Malloc(char,max_line_len); 2859 | char *p,*endptr,*idx,*val; 2860 | 2861 | while(readline(fp)!=NULL) 2862 | { 2863 | p = strtok(line,":"); 2864 | while(1) 2865 | { 2866 | p = strtok(NULL,":"); 2867 | if(p == NULL) 2868 | break; 2869 | ++elements; 2870 | } 2871 | } 2872 | elements += model->l; 2873 | 2874 | fseek(fp,pos,SEEK_SET); 2875 | 2876 | int m = model->nr_class - 1; 2877 | int l = model->l; 2878 | model->sv_coef = Malloc(double *,m); 2879 | int i; 2880 | for(i=0;isv_coef[i] = Malloc(double,l); 2882 | model->SV = Malloc(svm_node*,l); 2883 | svm_node *x_space = NULL; 2884 | if(l>0) x_space = Malloc(svm_node,elements); 2885 | 2886 | int j=0; 2887 | for(i=0;iSV[i] = &x_space[j]; 2891 | 2892 | p = strtok(line, " \t"); 2893 | model->sv_coef[0][i] = strtod(p,&endptr); 2894 | for(int k=1;ksv_coef[k][i] = strtod(p,&endptr); 2898 | } 2899 | 2900 | while(1) 2901 | { 2902 | idx = strtok(NULL, ":"); 2903 | val = strtok(NULL, " \t"); 2904 | 2905 | if(val == NULL) 2906 | break; 2907 | x_space[j].index = (int) strtol(idx,&endptr,10); 2908 | x_space[j].value = strtod(val,&endptr); 2909 | 2910 | ++j; 2911 | } 2912 | x_space[j++].index = -1; 2913 | } 2914 | free(line); 2915 | 2916 | setlocale(LC_ALL, old_locale); 2917 | free(old_locale); 2918 | 2919 | if (ferror(fp) != 0 || fclose(fp) != 0) 2920 | return NULL; 2921 | 2922 | model->free_sv = 1; // XXX 2923 | return model; 2924 | } 2925 | 2926 | void svm_free_model_content(svm_model* model_ptr) 2927 | { 2928 | if(model_ptr->free_sv && model_ptr->l > 0 && model_ptr->SV != NULL) 2929 | free((void 
*)(model_ptr->SV[0])); 2930 | if(model_ptr->sv_coef) 2931 | { 2932 | for(int i=0;inr_class-1;i++) 2933 | free(model_ptr->sv_coef[i]); 2934 | } 2935 | 2936 | free(model_ptr->SV); 2937 | model_ptr->SV = NULL; 2938 | 2939 | free(model_ptr->sv_coef); 2940 | model_ptr->sv_coef = NULL; 2941 | 2942 | free(model_ptr->rho); 2943 | model_ptr->rho = NULL; 2944 | 2945 | free(model_ptr->label); 2946 | model_ptr->label= NULL; 2947 | 2948 | free(model_ptr->probA); 2949 | model_ptr->probA = NULL; 2950 | 2951 | free(model_ptr->probB); 2952 | model_ptr->probB= NULL; 2953 | 2954 | free(model_ptr->nSV); 2955 | model_ptr->nSV = NULL; 2956 | } 2957 | 2958 | void svm_free_and_destroy_model(svm_model** model_ptr_ptr) 2959 | { 2960 | if(model_ptr_ptr != NULL && *model_ptr_ptr != NULL) 2961 | { 2962 | svm_free_model_content(*model_ptr_ptr); 2963 | free(*model_ptr_ptr); 2964 | *model_ptr_ptr = NULL; 2965 | } 2966 | } 2967 | 2968 | void svm_destroy_param(svm_parameter* param) 2969 | { 2970 | free(param->weight_label); 2971 | free(param->weight); 2972 | } 2973 | 2974 | const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param) 2975 | { 2976 | // svm_type 2977 | 2978 | int svm_type = param->svm_type; 2979 | if(svm_type != C_SVC && 2980 | svm_type != NU_SVC && 2981 | svm_type != ONE_CLASS && 2982 | svm_type != EPSILON_SVR && 2983 | svm_type != NU_SVR) 2984 | return "unknown svm type"; 2985 | 2986 | // kernel_type, degree 2987 | 2988 | int kernel_type = param->kernel_type; 2989 | if(kernel_type != LINEAR && 2990 | kernel_type != POLY && 2991 | kernel_type != RBF && 2992 | kernel_type != SIGMOID && 2993 | kernel_type != PRECOMPUTED) 2994 | return "unknown kernel type"; 2995 | 2996 | if(param->gamma < 0) 2997 | return "gamma < 0"; 2998 | 2999 | if(param->degree < 0) 3000 | return "degree of polynomial kernel < 0"; 3001 | 3002 | // cache_size,eps,C,nu,p,shrinking 3003 | 3004 | if(param->cache_size <= 0) 3005 | return "cache_size <= 0"; 3006 | 3007 | if(param->eps <= 0) 3008 
| return "eps <= 0"; 3009 | 3010 | if(svm_type == C_SVC || 3011 | svm_type == EPSILON_SVR || 3012 | svm_type == NU_SVR) 3013 | if(param->C <= 0) 3014 | return "C <= 0"; 3015 | 3016 | if(svm_type == NU_SVC || 3017 | svm_type == ONE_CLASS || 3018 | svm_type == NU_SVR) 3019 | if(param->nu <= 0 || param->nu > 1) 3020 | return "nu <= 0 or nu > 1"; 3021 | 3022 | if(svm_type == EPSILON_SVR) 3023 | if(param->p < 0) 3024 | return "p < 0"; 3025 | 3026 | if(param->shrinking != 0 && 3027 | param->shrinking != 1) 3028 | return "shrinking != 0 and shrinking != 1"; 3029 | 3030 | if(param->probability != 0 && 3031 | param->probability != 1) 3032 | return "probability != 0 and probability != 1"; 3033 | 3034 | if(param->probability == 1 && 3035 | svm_type == ONE_CLASS) 3036 | return "one-class SVM probability output not supported yet"; 3037 | 3038 | 3039 | // check whether nu-svc is feasible 3040 | 3041 | if(svm_type == NU_SVC) 3042 | { 3043 | int l = prob->l; 3044 | int max_nr_class = 16; 3045 | int nr_class = 0; 3046 | int *label = Malloc(int,max_nr_class); 3047 | int *count = Malloc(int,max_nr_class); 3048 | 3049 | int i; 3050 | for(i=0;iy[i]; 3053 | int j; 3054 | for(j=0;jnu*(n1+n2)/2 > min(n1,n2)) 3082 | { 3083 | free(label); 3084 | free(count); 3085 | return "specified nu is infeasible"; 3086 | } 3087 | } 3088 | } 3089 | free(label); 3090 | free(count); 3091 | } 3092 | 3093 | return NULL; 3094 | } 3095 | 3096 | int svm_check_probability_model(const svm_model *model) 3097 | { 3098 | return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && 3099 | model->probA!=NULL && model->probB!=NULL) || 3100 | ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) && 3101 | model->probA!=NULL); 3102 | } 3103 | 3104 | void svm_set_print_string_function(void (*print_func)(const char *)) 3105 | { 3106 | if(print_func == NULL) 3107 | svm_print_string = &print_string_stdout; 3108 | else 3109 | svm_print_string = print_func; 3110 | } 3111 | 3112 | 3113 | 
//finished libsvm code 3114 | 3115 | //svm surrounding 3116 | int max_nr_attr = 64; 3117 | 3118 | int predict_probability=0; 3119 | -------------------------------------------------------------------------------- /src/titlebook.proto: -------------------------------------------------------------------------------- 1 | package tutorial; 2 | 3 | message TitleList { 4 | repeated string title = 1; 5 | } 6 | 7 | message IdList { 8 | repeated int32 id = 1; 9 | } -------------------------------------------------------------------------------- /tool/Worker.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Hello World server in C++ 3 | // Binds REP socket to tcp://*:5555 4 | // Expects "Hello" from client, replies with "World" 5 | // 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "cppMSGmodule.hpp" 12 | 13 | #include "TC_process.h" 14 | 15 | #include 16 | #include "titlebook.pb.h" 17 | 18 | using namespace std; 19 | 20 | int main () { 21 | // Prepare our context and socket 22 | zmq::context_t context (1); 23 | zmq::socket_t socket (context, ZMQ_REP); 24 | socket.bind ("tcp://*:5555"); 25 | 26 | CateTeller ct; 27 | 28 | while(true) { // that's the reason of why this server program is synchronous 29 | 30 | //// get Request from Python Client 31 | zmq::message_t request; 32 | // Wait for next request from client 33 | socket.recv (&request); 34 | // std::cout << "Received Request" << std::endl; 35 | std::string req_str = std::string (static_cast(request.data()), request.size()); 36 | std::cout << "Received:\n " << req_str << std::endl; 37 | 38 | // parse string to char* array 39 | char** pca; 40 | int count; 41 | if( CppMsgModule::msgStrToPcharArray(pca, count, req_str) ) { 42 | for(int i=0; i 2 | #include 3 | #include "TC_process.h" 4 | 5 | int main(int argc, char** argv) { 6 | // char * text = "教育部考试中心托福网考网上报名"; 7 | // char* text = "皇马6-4马竞登顶欧冠"; 8 | char* text = "evernote 安装最新版本后,个别笔记本无法同步?"; 9 | // 
char* text = "ios私有api 能修改运营商名称吗?"; 10 | // char* text = "提前博弈A股纳入MSCI"; 11 | char** p_texts = &text; 12 | int label = 0; 13 | 14 | CateTeller ct; 15 | 16 | for(int i=0; i<1000000; ++i) { 17 | 18 | ct.tell(p_texts, 1, &label); 19 | std::cout << std::endl << label << std::endl; 20 | 21 | } 22 | 23 | return 0; 24 | } 25 | --------------------------------------------------------------------------------