├── .gitattributes ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── MANIFEST.in ├── README.md ├── cmake ├── FindNumPy.cmake ├── compiler.cmake ├── pario.cmake ├── sol.cmake ├── summary.cmake ├── test.cmake ├── tools.cmake └── util.cmake ├── data ├── a1a ├── a1a.t ├── a1a.tmp ├── mnist.scale.res └── pcmac.res ├── experiments ├── README.md ├── experiment.py ├── fig.py ├── liblinear.py ├── opts │ ├── a1a.py │ ├── mnist.py │ ├── news.py │ ├── rcv1.py │ ├── url.py │ └── webspam.py └── vw.py ├── external ├── cmdline │ └── cmdline.h └── json │ ├── json-forwards.h │ ├── json.h │ └── jsoncpp.cpp ├── format.sh ├── include └── sol │ ├── c_api.h │ ├── loss │ ├── bool_loss.h │ ├── hinge_loss.h │ ├── logistic_loss.h │ ├── loss.h │ └── square_loss.h │ ├── math │ ├── expression.h │ ├── matrix.h │ ├── matrix_expression.h │ ├── matrix_storage.h │ ├── operator.h │ ├── shape.h │ ├── sparse_vector.h │ └── vector.h │ ├── model │ ├── model.h │ ├── olm │ │ ├── ada_fobos.h │ │ ├── ada_rda.h │ │ ├── alma2.h │ │ ├── arow.h │ │ ├── cw.h │ │ ├── eccw.h │ │ ├── fofs.h │ │ ├── ogd.h │ │ ├── pa.h │ │ ├── perceptron.h │ │ ├── rda.h │ │ └── sop.h │ ├── online_linear_model.h │ ├── online_model.h │ └── regularizer.h │ ├── pario │ ├── binary_reader.h │ ├── binary_writer.h │ ├── compress.h │ ├── csr_matrix_reader.h │ ├── csv_reader.h │ ├── csv_writer.h │ ├── data_iter.h │ ├── data_point.h │ ├── data_read_task.h │ ├── data_reader.h │ ├── data_writer.h │ ├── file_reader.h │ ├── file_writer.h │ ├── mini_batch.h │ ├── numeric_parser.h │ ├── numpy_reader.h │ ├── svm_reader.h │ └── svm_writer.h │ ├── sol.h │ ├── tools.h │ └── util │ ├── block_queue.h │ ├── error_code.h │ ├── heap.h │ ├── monitor.h │ ├── mutex.h │ ├── platform_win32.h │ ├── platform_xnix.h │ ├── reflector.h │ ├── str_util.h │ ├── thread.h │ ├── thread_task.h │ ├── types.h │ └── util.h ├── ofs ├── README.md ├── fgm.py ├── fig.py ├── fs.py ├── liblinear.py ├── mrmr.py └── opts │ ├── aut.py │ ├── basehock.py │ ├── ccat.py │ ├── pcmac.py │ ├── 
rcv1.py │ ├── real-sim.py │ ├── relathe.py │ ├── synthetic_100k.py │ ├── synthetic_1m.py │ ├── synthetic_200k.py │ ├── url.py │ └── voc2007.py ├── python ├── __init__.py ├── cv.py ├── dataset.py ├── pysol.pxd ├── pysol.pyx ├── sol_test.py └── sol_train.py ├── setup.py ├── src └── sol │ ├── c_api.cc │ ├── loss │ ├── bool_loss.cc │ ├── hinge_loss.cc │ ├── logistic_loss.cc │ ├── loss.cc │ └── square_loss.cc │ ├── model │ ├── model.cc │ ├── olm │ │ ├── ada_fobos.cc │ │ ├── ada_rda.cc │ │ ├── alma2.cc │ │ ├── arow.cc │ │ ├── cw.cc │ │ ├── eccw.cc │ │ ├── fofs.cc │ │ ├── ogd.cc │ │ ├── pa.cc │ │ ├── perceptron.cc │ │ ├── rda.cc │ │ └── sop.cc │ ├── online_linear_model.cc │ ├── online_model.cc │ └── regularizer.cc │ ├── pario │ ├── binary_reader.cc │ ├── binary_writer.cc │ ├── csr_matrix_reader.cc │ ├── csv_reader.cc │ ├── csv_writer.cc │ ├── data_iter.cc │ ├── data_point.cc │ ├── data_read_task.cc │ ├── data_reader.cc │ ├── data_writer.cc │ ├── file_reader.cc │ ├── file_writer.cc │ ├── numpy_reader.cc │ ├── svm_reader.cc │ └── svm_writer.cc │ ├── tools.cc │ └── util │ └── reflector.cc ├── test ├── model │ └── test_model.cc ├── pario │ ├── test_binary.cc │ ├── test_compress.cc │ ├── test_csv.cc │ ├── test_data_iter.cc │ ├── test_data_point.cc │ ├── test_file_reader.cc │ ├── test_file_writer.cc │ └── test_svm.cc └── util │ └── test_matrix.cc └── tools ├── analyze.cc ├── concat.cc ├── converter.cc ├── lsol_c.cc ├── shuffle.cc ├── sol_test.cc ├── sol_train.cc └── split.cc /.gitattributes: -------------------------------------------------------------------------------- 1 | *.cc linguist-language=C++ 2 | *.py linguist-language=C++ 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.obj 6 | 7 | # Precompiled Headers 8 | *.gch 9 | *.pch 10 | 11 | # Compiled Dynamic libraries 12 | *.so 13 | 
*.dylib 14 | *.dll 15 | 16 | # Fortran module files 17 | *.mod 18 | 19 | # Compiled Static libraries 20 | *.lai 21 | *.la 22 | *.a 23 | *.lib 24 | 25 | # Executables 26 | *.exe 27 | *.out 28 | *.app 29 | 30 | #python tmp files 31 | *.pyc 32 | 33 | #temp or log file 34 | *.log 35 | *.temp 36 | *.tmp 37 | *.bak 38 | *.pkl 39 | *.sln 40 | *.txt 41 | #never ignore CMake build scripts, which *.txt above would otherwise match 42 | !CMakeLists.txt 43 | *.pdf 44 | *.png 45 | *.tar 46 | *.gz 47 | python/*.c 48 | 49 | #folders 50 | .vs/ 51 | build/ 52 | install/ 53 | tmp/ 54 | temp/ 55 | dist/ 56 | cache/ 57 | data/ 58 | *.wiki/ 59 | *.egg-info/ 60 | /build_cyg 61 | /pyenv 62 | /python/pysol.cpp 63 | /log 64 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.12) 2 | 3 | project(SOL) 4 | 5 | set(EXECUTABLE_OUTPUT_PATH ${CMAKE_BINARY_DIR}/bin) 6 | set(LIBRARY_OUTPUT_PATH ${CMAKE_BINARY_DIR}/bin) 7 | set(ARCHIVE_OUTPUT_PATH ${CMAKE_BINARY_DIR}/lib) # NOTE(review): CMake never reads ARCHIVE_OUTPUT_PATH (CMAKE_ARCHIVE_OUTPUT_DIRECTORY is the real variable), so this line has no effect; archive/import libs presumably follow LIBRARY_OUTPUT_PATH into bin -- confirm packaging does not depend on that before switching 8 | set(CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) 9 | 10 | if (NOT PREFIX) 11 | set(CMAKE_INSTALL_PREFIX ${CMAKE_SOURCE_DIR}/dist) 12 | else() 13 | set(CMAKE_INSTALL_PREFIX ${PREFIX}) 14 | endif() 15 | 16 | include(compiler) 17 | 18 | include_directories( 19 | ${PROJECT_SOURCE_DIR}/include 20 | ${PROJECT_SOURCE_DIR}/external 21 | ) 22 | 23 | #include(util) 24 | #include(pario) 25 | include(sol) 26 | 27 | include(tools) 28 | include(test) 29 | 30 | include(summary) 31 | 32 | install(TARGETS ${TARGET_LIBS} 33 | RUNTIME DESTINATION bin 34 | LIBRARY DESTINATION bin 35 | ARCHIVE DESTINATION lib 36 | ) 37 | 38 | install(DIRECTORY include 39 | DESTINATION . 
40 | ) 41 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Include license file 2 | include LICENSE 3 | 4 | #include python extension file 5 | include python/pysol.pyx 6 | 7 | #Include include files 8 | recursive-include include * 9 | recursive-include external/json *.h 10 | -------------------------------------------------------------------------------- /cmake/FindNumPy.cmake: -------------------------------------------------------------------------------- 1 | # Find the Python NumPy package 2 | # PYTHON_NUMPY_INCLUDE_DIR 3 | # NUMPY_FOUND (and PYTHON_NUMPY_FOUND) 4 | # will be set by this script 5 | 6 | cmake_minimum_required(VERSION 2.6) 7 | 8 | if(NOT PYTHON_EXECUTABLE) 9 | if(NumPy_FIND_QUIETLY) 10 | find_package(PythonInterp QUIET) 11 | else() 12 | find_package(PythonInterp) 13 | set(_numpy_out 1) 14 | endif() 15 | endif() 16 | 17 | if (PYTHON_EXECUTABLE) 18 | #find python 19 | find_package(PythonLibs REQUIRED) 20 | # write a python script that finds the numpy path 21 | file(WRITE ${PROJECT_BINARY_DIR}/FindNumpyPath.py 22 | "try: import numpy; print(numpy.get_include())\nexcept:pass\n") 23 | 24 | # execute the find script (execute_process replaces the deprecated exec_program; strip the newline print() emits so it does not end up in the HINTS path) 25 | execute_process(COMMAND "${PYTHON_EXECUTABLE}" "FindNumpyPath.py" 26 | WORKING_DIRECTORY ${PROJECT_BINARY_DIR} 27 | OUTPUT_VARIABLE NUMPY_PATH OUTPUT_STRIP_TRAILING_WHITESPACE) 28 | elseif(_numpy_out) 29 | message(STATUS "Python executable not found.") 30 | endif(PYTHON_EXECUTABLE) 31 | 32 | find_path(PYTHON_NUMPY_INCLUDE_DIR numpy/arrayobject.h 33 | HINTS "${NUMPY_PATH}" "${PYTHON_INCLUDE_PATH}") 34 | 35 | if(PYTHON_NUMPY_INCLUDE_DIR) 36 | set(PYTHON_NUMPY_FOUND 1 CACHE INTERNAL "Python numpy found") 37 | endif(PYTHON_NUMPY_INCLUDE_DIR) 38 | 39 | include(FindPackageHandleStandardArgs) 40 | find_package_handle_standard_args(NumPy DEFAULT_MSG PYTHON_NUMPY_INCLUDE_DIR) 41 | -------------------------------------------------------------------------------- /cmake/pario.cmake: 
-------------------------------------------------------------------------------- 1 | file(GLOB pario_headers 2 | "${PROJECT_SOURCE_DIR}/include/sol/pario/*.h" 3 | "${PROJECT_SOURCE_DIR}/include/sol/pario/*.hpp" 4 | ) 5 | 6 | file(GLOB pario_src 7 | "${PROJECT_SOURCE_DIR}/src/sol/pario/*.cpp" 8 | "${PROJECT_SOURCE_DIR}/src/sol/pario/*.cc" 9 | ) 10 | 11 | source_group("Header Files" FILES ${pario_headers}) 12 | source_group("Source Files" FILES ${pario_src}) 13 | 14 | 15 | add_library(sol_pario SHARED ${pario_headers} ${pario_src}) 16 | target_link_libraries(sol_pario sol_util) # NOTE(review): sol_util is only created by cmake/util.cmake, and CMakeLists.txt has include(util)/include(pario) commented out, so this legacy module is not built 17 | list(APPEND TARGET_LIBS sol_pario) 18 | 19 | execute_process(COMMAND ${CMAKE_COMMAND} -E copy_directory 20 | ${PROJECT_SOURCE_DIR}/data 21 | ${CMAKE_BINARY_DIR}/data) # copies data/ into the build tree when cmake runs, not on each build 22 | -------------------------------------------------------------------------------- /cmake/sol.cmake: -------------------------------------------------------------------------------- 1 | set(src_dirs util math pario loss model model/olm) # sub-directories of include/sol and src/sol that are globbed into libsol 2 | foreach (src_dir ${src_dirs}) 3 | file(GLOB ${src_dir}_headers 4 | "${PROJECT_SOURCE_DIR}/include/sol/${src_dir}/*.h" 5 | "${PROJECT_SOURCE_DIR}/include/sol/${src_dir}/*.hpp" 6 | ) 7 | 8 | file(GLOB ${src_dir}_src 9 | "${PROJECT_SOURCE_DIR}/src/sol/${src_dir}/*.cpp" 10 | "${PROJECT_SOURCE_DIR}/src/sol/${src_dir}/*.cc" 11 | ) 12 | 13 | STRING(REGEX REPLACE "/" "\\\\" win_src_dir ${src_dir}) # "model/olm" -> "model\olm" so IDE source groups nest 14 | source_group("Header Files\\${win_src_dir}" FILES ${${src_dir}_headers}) 15 | source_group("Source Files\\${win_src_dir}" FILES ${${src_dir}_src}) 16 | list(APPEND sol_list ${${src_dir}_headers} ${${src_dir}_src}) 17 | endforeach() 18 | 19 | file(GLOB json_files 20 | "${PROJECT_SOURCE_DIR}/external/json/*.h" 21 | "${PROJECT_SOURCE_DIR}/external/json/*.cpp" 22 | ) 23 | list(APPEND sol_list ${json_files}) # the bundled jsoncpp sources are compiled into libsol itself 24 | 25 | add_library(sol SHARED ${sol_list} 26 | ${PROJECT_SOURCE_DIR}/include/sol/sol.h 27 | ${PROJECT_SOURCE_DIR}/include/sol/c_api.h 28 | ${PROJECT_SOURCE_DIR}/include/sol/tools.h 29 | 
${PROJECT_SOURCE_DIR}/src/sol/c_api.cc 30 | ${PROJECT_SOURCE_DIR}/src/sol/tools.cc 31 | ) 32 | target_link_libraries(sol ${LINK_LIBS}) 33 | list(APPEND TARGET_LIBS sol) 34 | 35 | execute_process(COMMAND ${CMAKE_COMMAND} -E copy_directory 36 | ${PROJECT_SOURCE_DIR}/data 37 | ${CMAKE_BINARY_DIR}/data) 38 | -------------------------------------------------------------------------------- /cmake/summary.cmake: -------------------------------------------------------------------------------- 1 | foreach(target_lib ${TARGET_LIBS}) 2 | set_target_properties(${target_lib} PROPERTIES COMPILE_DEFINITIONS "SOL_EXPORTS;JSON_DLL_BUILD") 3 | list(APPEND TARGETS ${target_lib}) 4 | endforeach() 5 | 6 | # ========================== build platform ========================== 7 | message(STATUS "") 8 | message(STATUS " Platform:") 9 | if(CMAKE_CROSSCOMPILING) 10 | message(STATUS " Target:" ${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_VERSION} ${CMAKE_SYSTEM_PROCESSOR}) 11 | endif() 12 | message(STATUS " CMake:" ${CMAKE_VERSION}) 13 | message(STATUS " CMake generator:" ${CMAKE_GENERATOR}) 14 | 15 | if(NOT CMAKE_GENERATOR MATCHES "Xcode" AND NOT CMAKE_CONFIGURATION_TYPES) # pass the variable name (not its expansion) so if() dereferences it exactly once; multi-config generators have no single CMAKE_BUILD_TYPE to report 16 | message(STATUS " Configuration:" ${CMAKE_BUILD_TYPE}) 17 | endif() 18 | 19 | message(STATUS "") 20 | message(STATUS " C:") 21 | message(STATUS " C Compiler:" ${CMAKE_C_COMPILER}) 22 | message(STATUS " C flags: " ${CMAKE_C_FLAGS}) 23 | message(STATUS " C flags (Release):" ${CMAKE_C_FLAGS_RELEASE}) 24 | message(STATUS " C flags (Debug):" ${CMAKE_C_FLAGS_DEBUG}) 25 | message(STATUS "") 26 | message(STATUS " C++:") 27 | message(STATUS " C++ Compiler:" ${CMAKE_CXX_COMPILER}) 28 | message(STATUS " C++ flags: " ${CMAKE_CXX_FLAGS}) 29 | message(STATUS " C++ flags (Release):" ${CMAKE_CXX_FLAGS_RELEASE}) 30 | message(STATUS " C++ flags (Debug):" ${CMAKE_CXX_FLAGS_DEBUG}) 31 | if(WIN32) 32 | message(STATUS " Linker flags (Release):" ${CMAKE_SHARED_LINKER_FLAGS_RELEASE}) 33 | message(STATUS " Linker flags (Debug):" ${CMAKE_SHARED_LINKER_FLAGS_DEBUG}) 34 | endif() 35 
| 36 | #message(STATUS " Linked Libraries" ${LINKED_LIBS}) 37 | -------------------------------------------------------------------------------- /cmake/test.cmake: -------------------------------------------------------------------------------- 1 | set(src_dirs util pario model) # each test/<dir>/*.cc (or *.cpp) becomes its own executable named after the file 2 | foreach (src_dir ${src_dirs}) 3 | file(GLOB ${src_dir}_src 4 | "${PROJECT_SOURCE_DIR}/test/${src_dir}/*.cpp" 5 | "${PROJECT_SOURCE_DIR}/test/${src_dir}/*.cc" 6 | ) 7 | 8 | foreach(test_src ${${src_dir}_src}) 9 | get_filename_component(tgt_name ${test_src} NAME_WE) 10 | add_executable(${tgt_name} ${test_src}) 11 | target_link_libraries(${tgt_name} sol ${LINK_LIBS}) 12 | SET_PROPERTY(TARGET ${tgt_name} PROPERTY FOLDER "test/${src_dir}") 13 | list(APPEND test_targets ${tgt_name}) # NOTE(review): no enable_testing()/add_test() is visible, so these binaries are built but never registered with ctest 14 | endforeach() 15 | endforeach() 16 | -------------------------------------------------------------------------------- /cmake/tools.cmake: -------------------------------------------------------------------------------- 1 | set(TOOLS_DIR ${PROJECT_SOURCE_DIR}/tools) # every tools/*.cc (or *.cpp) is built as a standalone executable linked against sol 2 | 3 | file(GLOB tool_list 4 | "${TOOLS_DIR}/*.cpp" 5 | "${TOOLS_DIR}/*.cc" 6 | ) 7 | 8 | foreach(tool_src ${tool_list}) 9 | get_filename_component(tgt_name ${tool_src} NAME_WE) 10 | add_executable(${tgt_name} ${tool_src}) 11 | target_link_libraries(${tgt_name} sol ${LINK_LIBS}) 12 | SET_PROPERTY(TARGET ${tgt_name} PROPERTY FOLDER "tools") 13 | list(APPEND tools_targets ${tgt_name}) 14 | endforeach() 15 | 16 | install(TARGETS ${tools_targets} 17 | RUNTIME DESTINATION bin 18 | LIBRARY DESTINATION bin 19 | ARCHIVE DESTINATION lib 20 | ) 21 | -------------------------------------------------------------------------------- /cmake/util.cmake: -------------------------------------------------------------------------------- 1 | file(GLOB util_headers 2 | "${PROJECT_SOURCE_DIR}/include/sol/util/*.h" 3 | "${PROJECT_SOURCE_DIR}/include/sol/util/*.hpp" 4 | "${PROJECT_SOURCE_DIR}/external/json/*.h" 5 | ) 6 | 7 | file(GLOB util_src 8 | 
"${PROJECT_SOURCE_DIR}/src/sol/util/*.cc" 9 | "${PROJECT_SOURCE_DIR}/src/sol/util/*.cpp" 10 | "${PROJECT_SOURCE_DIR}/external/json/*.cpp" 11 | ) 12 | 13 | source_group("Header Files" FILES ${util_headers}) 14 | source_group("Source Files" FILES ${util_src}) 15 | 16 | file(GLOB math_headers 17 | "${PROJECT_SOURCE_DIR}/include/sol/math/*.h" 18 | "${PROJECT_SOURCE_DIR}/include/sol/math/*.hpp" 19 | ) 20 | source_group("Header Files\\math" FILES ${math_headers}) 21 | 22 | add_library(sol_util SHARED ${util_headers} ${util_src} ${math_headers}) 23 | list(APPEND TARGET_LIBS sol_util) 24 | -------------------------------------------------------------------------------- /data/mnist.scale.res: -------------------------------------------------------------------------------- 1 | Perceptron 0.1449 8696 2 | OGD 0.1076 9812 3 | PA 0.1339 22482 4 | PA1 0.1339 22482 5 | PA2 0.1335 22492 -------------------------------------------------------------------------------- /data/pcmac.res: -------------------------------------------------------------------------------- 1 | perceptron 0.1313 128 2 | alma2 0.0985 226 3 | ogd 0.0769 975 4 | pa 0.0677 700 5 | pa1 0.0687 704 6 | pa2 0.0687 808 7 | eccw 0.0687 733 8 | arow 0.0697 904 9 | ada-fobos 0.0821 975. 10 | ada-rda 0.0831 975 -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | Experiments for comparison of different online learning algorithms 2 | ================================================================ 3 | 4 | The python scripts in this folder compare different online 5 | learning algorithms with each other, as well as with VW and LIBLINEAR. 6 | 7 | 8 | For example, to compare on the a1a dataset, you can simply run: 9 | 10 | python experiment.py a1a ../data/a1a ../data/a1a.t 11 | 12 | 13 | The algorithms to be compared are defined in "opts/a1a.py". 
By default, we include 14 | the comparison with VW and LIBLINEAR. You should have both packages properly 15 | installed in your system. If **NOT**, you can simply remove the configuration 16 | for VW and LIBLINEAR from the configuration scripts. 17 | -------------------------------------------------------------------------------- /experiments/fig.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | # AUTHOR: 3 | # FILE: fig.py 4 | # ROLE: TODO (some explanation) 5 | # CREATED: 2015-05-16 23:37:47 6 | # MODIFIED: 2015-05-16 23:37:47 7 | 8 | import matplotlib.pyplot as plt 9 | import logging 10 | 11 | plt.rc('pdf', fonttype=42) 12 | 13 | 14 | def plot(xs, ys, 15 | x_label, 16 | y_label, 17 | legends, 18 | output_path, 19 | line_width=3, 20 | marker_size=12, 21 | xlim=None, 22 | ylim=None, 23 | xtickers=None, 24 | logx=False, 25 | logy=False, 26 | clip_on=False, 27 | fontsize=18, 28 | legend_cols=2, 29 | legend_order=201, 30 | legend_loc='best', 31 | bbox_to_anchor=None, 32 | draw_legend=True): 33 | 34 | color_list = ['r', 'm', 'k', 'b', (0.12, 0.56, 1), (0.58, 0.66, 0.2), (0.48, 0.41, 0.93), 35 | (0, 0.75, 0.75)] 36 | marker_list = ['s', 'h', '*', u'o', 'd', '^', 'v', '<', '>'] 37 | #line_styles=['-','--'] 38 | 39 | c_ind = 0 40 | m_ind = 0 41 | fig = plt.figure() 42 | ax = fig.add_subplot(1, 1, 1) 43 | lines = [] 44 | if logx is True and logy is True: 45 | plot_handler = ax.loglog 46 | elif logx is True and logy is False: 47 | plot_handler = ax.semilogx 48 | elif logx is False and logy is True: 49 | plot_handler = ax.semilogy 50 | else: 51 | plot_handler = ax.plot 52 | 53 | for i in xrange(len(xs)): 54 | zorder = 200 - i 55 | color = color_list[c_ind % len(color_list)] 56 | marker = marker_list[m_ind % len(marker_list)] 57 | c_ind += 1 58 | m_ind += 1 59 | if xlim != None: 60 | x_values = [] 61 | y_values = [] 62 | for k in xrange(len(xs[i])): 63 | if xs[i][k] >= xlim[0] and xs[i][k] <= xlim[1]: 64 | 
x_values.append(xs[i][k]) 65 | y_values.append(ys[i][k]) 66 | else: 67 | x_values = xs[i] 68 | y_values = ys[i] 69 | line, = plot_handler(x_values, y_values, 70 | color=color, 71 | marker=marker, 72 | linestyle='-', 73 | clip_on=clip_on, 74 | markersize=marker_size, 75 | linewidth=line_width, 76 | fillstyle='full', 77 | zorder=zorder) 78 | lines.append(line) 79 | 80 | if xtickers != None: 81 | ax.set_xticks(xtickers) 82 | if xlim != None: 83 | ax.set_xlim(xlim) 84 | if ylim != None: 85 | ax.set_ylim(ylim) 86 | 87 | ax.grid() 88 | if draw_legend: 89 | l = ax.legend(lines,legends,loc=legend_loc,ncol=legend_cols) 90 | l.set_zorder(legend_order) 91 | if bbox_to_anchor != None: 92 | l.set_bbox_to_anchor(bbox_to_anchor) 93 | 94 | plt.xlabel(x_label,fontsize=fontsize) 95 | plt.ylabel(y_label,fontsize=fontsize) 96 | #plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0)) 97 | 98 | plt.savefig(output_path,bbox_inches='tight') 99 | logging.info('figure saved to %s' %(output_path)) 100 | #plt.show() 101 | -------------------------------------------------------------------------------- /experiments/opts/a1a.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | ################################################################################# 3 | # File Name : a1a.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-11-17 18:26] 6 | # Last Modified : [2017-03-17 19:14] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 8, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | 19 | ol_opts = {} 20 | ol_opts['ada-fobos'] = { 21 | 'cv':{'eta':const_eta_search, 'delta':delta_search} 22 | } 23 | ol_opts['ada-rda'] = { 24 | 'cv':{'eta':const_eta_search, 'delta':delta_search} 25 | } 26 | ol_opts['alma2'] = { 27 | 'cv':{'alpha': np.linspace(0.1, 1, 10)} 28 | } 29 | ol_opts['arow'] = { 30 | 'cv':{'r':r_search} 31 | } 32 | ol_opts['cw'] = { 33 | 'cv':{'a': np.logspace(-4, 0, 5, base=2), 'phi':np.linspace(0, 2, 9)} 34 | } 35 | ol_opts['eccw'] = { 36 | 'cv':{'a': np.logspace(-4, 0, 5, base=2), 'phi':np.linspace(0, 2, 9)} 37 | } 38 | ol_opts['ogd'] = { 39 | 'cv':{'eta':eta_search} 40 | } 41 | ol_opts['pa'] = {} 42 | ol_opts['pa1'] = { 43 | 'cv':{'C':np.logspace(-4, 4, 9, base=2)} 44 | } 45 | ol_opts['pa2'] = { 46 | 'cv':{'C':np.logspace(-4, 4, 9, base=2)} 47 | } 48 | ol_opts['perceptron'] = {} 49 | ol_opts['sop'] = { 50 | 'cv':{'a':np.logspace(-4, 4, 9, base=2)} 51 | } 52 | ol_opts['rda'] = {} 53 | ol_opts['erda-l1'] = {} 54 | 55 | for k,v in ol_opts.iteritems(): 56 | if 'params' not in v: 57 | v['params'] = {} 58 | v['params']['step_show'] = 200 59 | 60 | ol_opts['liblinear'] = { 61 | 'cv': {'C': np.logspace(-5,7,13, base=2)} 62 | } 63 | #ol_opts['vw'] = { 64 | # 'cv':{'learning_rate':np.logspace(-4,7,12,base=2)} 65 | #} 66 | 67 | sol_opts = {} 68 | sol_opts['stg'] = { 69 | 
'params':{'k':10}, 70 | 'cv':{'eta':eta_search}, 71 | 'lambda': np.logspace(-3,-0.5,5,base=10) 72 | } 73 | sol_opts['fobos-l1'] = { 74 | 'cv':{'eta':eta_search}, 75 | 'lambda': np.logspace(-3,-0.5,5,base=10) 76 | } 77 | sol_opts['rda-l1'] = { 78 | 'lambda': np.logspace(-3,-0.5,5,base=10) 79 | } 80 | 81 | sol_opts['erda-l1'] = { 82 | 'params':{'rou':0.001}, 83 | 'lambda': np.logspace(-3,-1,5,base=10) 84 | } 85 | sol_opts['ada-fobos-l1'] = { 86 | 'cv':{'eta':const_eta_search, 'delta':delta_search}, 87 | 'lambda': np.logspace(-3,-0.5,5,base=10) 88 | } 89 | sol_opts['ada-rda-l1'] = { 90 | 'cv':{'eta':const_eta_search, 'delta':delta_search}, 91 | 'lambda': np.logspace(-3,-1,5,base=10) 92 | } 93 | sol_opts['liblinear'] = { 94 | 'lambda':np.logspace(-4,7,12, base=2) 95 | } 96 | 97 | #sol_opts['vw'] = { 98 | # 'cv':'vw', 99 | # 'lambda':np.logspace(-5,-2,4, base=10) 100 | #} 101 | -------------------------------------------------------------------------------- /experiments/opts/mnist.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | ################################################################################# 3 | # File Name : mnist.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-12-07 18:03] 6 | # Last Modified : [2016-12-07 18:05] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | 12 | const_eta_search = np.logspace(-5, 5, 11, base=2) 13 | eta_search = np.logspace(-2, 8, 11, base=2) 14 | delta_search = np.logspace(-5, 5,11, base=2) 15 | r_search = np.logspace(-5, 5, 11, base=2) 16 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 17 | norm_search = ['L2', 'None'] 18 | 19 | ol_opts = {} 20 | ol_opts['ada-fobos'] = { 21 | 'cv':{'eta':const_eta_search, 'delta':delta_search} 22 | } 23 | ol_opts['ada-rda'] = { 24 | 'cv':{'eta':const_eta_search, 'delta':delta_search} 25 | } 26 | ol_opts['alma2'] = { 27 | 'cv':{'alpha': np.linspace(0.1, 1, 10)} 28 | } 29 | ol_opts['arow'] = { 30 | 'cv':{'r':r_search} 31 | } 32 | ol_opts['cw'] = { 33 | 'cv':{'a': np.logspace(-4, 0, 5, base=2), 'phi':np.linspace(0, 2, 9)} 34 | } 35 | ol_opts['eccw'] = { 36 | 'cv':{'a': np.logspace(-4, 0, 5, base=2), 'phi':np.linspace(0, 2, 9)} 37 | } 38 | ol_opts['ogd'] = { 39 | 'cv':{'eta':eta_search} 40 | } 41 | ol_opts['pa'] = {} 42 | ol_opts['pa1'] = { 43 | 'cv':{'C':np.logspace(-4, 4, 9, base=2)} 44 | } 45 | ol_opts['pa2'] = { 46 | 'cv':{'C':np.logspace(-4, 4, 9, base=2)} 47 | } 48 | ol_opts['perceptron'] = {} 49 | ol_opts['sop'] = { 50 | 'cv':{'a':np.logspace(-4, 4, 9, base=2)} 51 | } 52 | ol_opts['rda'] = {} 53 | ol_opts['erda-l1'] = {} 54 | 55 | for k,v in ol_opts.iteritems(): 56 | if 'params' not in v: 57 | v['params'] = {} 58 | v['params']['step_show'] = 5000 59 | 60 | #ol_opts['liblinear'] = { 61 | # 'cv': {'C': np.logspace(-5,7,13, base=2)} 62 | #} 63 | #ol_opts['vw'] = { 64 | # 'cv':{'learning_rate':np.logspace(-4,7,12,base=2)} 65 | #} 66 | 67 | sol_opts = {} 68 | sol_opts['stg'] = { 69 | 
'params':{'k':10}, 70 | 'cv':{'eta':eta_search}, 71 | 'lambda': np.logspace(-4,-1,5,base=10) 72 | } 73 | sol_opts['fobos-l1'] = { 74 | 'cv':{'eta':eta_search}, 75 | 'lambda': np.logspace(-4,-1,5,base=10) 76 | } 77 | sol_opts['rda-l1'] = { 78 | 'lambda': np.logspace(-4,-1,5,base=10) 79 | } 80 | 81 | sol_opts['erda-l1'] = { 82 | 'params':{'rou':0.001}, 83 | 'lambda': np.logspace(-4,-1,5,base=10) 84 | } 85 | sol_opts['ada-fobos-l1'] = { 86 | 'cv':{'eta':const_eta_search, 'delta':delta_search}, 87 | 'lambda': np.logspace(-4,-1,5,base=10) 88 | } 89 | sol_opts['ada-rda-l1'] = { 90 | 'cv':{'eta':const_eta_search, 'delta':delta_search}, 91 | 'lambda': np.logspace(-4,-1,5,base=10) 92 | } 93 | #sol_opts['liblinear'] = { 94 | # 'lambda':np.logspace(-4,8,13, base=2) 95 | #} 96 | 97 | #sol_opts['vw'] = { 98 | # 'cv':'vw', 99 | # 'lambda':np.logspace(-6,-2,10, base=10) 100 | #} 101 | -------------------------------------------------------------------------------- /experiments/opts/news.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | ################################################################################# 3 | # File Name : synthetic_100k.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2016-12-06 16:35] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | norm_search = ['L2', 'None'] 19 | 20 | dim = 1355191 21 | fs_num = (np.array([0.005,0.05, 0.1,0.2,0.3,0.4,0.5]) * dim).astype(np.int) 22 | 23 | fs_opts = collections.OrderedDict() 24 | 25 | fs_opts['SOFS'] = { 26 | 'params':{'norm':'L2'}, 27 | 'cv':{'r': r_search}, 28 | 'lambda': fs_num 29 | } 30 | fs_opts['PET'] = { 31 | 'params':{'power_t':'0', 'norm':'L2'}, 32 | 'cv':{'eta':const_eta_search}, 33 | 'lambda': fs_num 34 | } 35 | 36 | fs_opts['liblinear'] = { 37 | 'lambda': [5000,10000,20000,40000,80000,160000] 38 | } 39 | 40 | fs_opts['FGM'] = { 41 | 'lambda': fs_num 42 | } 43 | draw_opts = { 44 | 'accu':{ 45 | }, 46 | 'time': { 47 | 'logy': True, 48 | 'legend_loc':'center', 49 | 'bbox_to_anchor':(0.7,0.65), 50 | #'legend_order':100, 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /experiments/opts/rcv1.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | ################################################################################# 3 | # File Name : synthetic_100k.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2017-03-17 19:01] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | norm_search = ['L2', 'None'] 19 | 20 | ol_opts = {} 21 | ol_opts['ada-fobos'] = { 22 | 'cv':{'eta':const_eta_search, 'delta':delta_search} 23 | } 24 | ol_opts['ada-rda'] = { 25 | 'cv':{'eta':const_eta_search, 'delta':delta_search} 26 | } 27 | ol_opts['alma2'] = { 28 | 'cv':{'alpha': np.linspace(0.1, 1, 10)} 29 | } 30 | ol_opts['arow'] = { 31 | 'cv':{'r':r_search} 32 | } 33 | ol_opts['cw'] = { 34 | 'cv':{'a': np.logspace(-4, 0, 5, base=2), 'phi':np.linspace(0, 2, 9)} 35 | } 36 | ol_opts['eccw'] = { 37 | 'cv':{'a': np.logspace(-4, 0, 5, base=2), 'phi':np.linspace(0, 2, 9)} 38 | } 39 | ol_opts['ogd'] = { 40 | 'cv':{'eta':eta_search} 41 | } 42 | ol_opts['pa'] = {} 43 | ol_opts['pa1'] = { 44 | 'cv':{'C':np.logspace(-4, 4, 9, base=2)} 45 | } 46 | ol_opts['pa2'] = { 47 | 'cv':{'C':np.logspace(-4, 4, 9, base=2)} 48 | } 49 | ol_opts['perceptron'] = {} 50 | ol_opts['sop'] = { 51 | 'cv':{'a':np.logspace(-4, 4, 9, base=2)} 52 | } 53 | ol_opts['rda'] = {} 54 | ol_opts['erda-l1'] = {} 55 | 56 | for k,v in ol_opts.iteritems(): 57 | if 'params' not in v: 58 | v['params'] = {} 59 | v['params']['step_show'] = 50000 60 | 61 | ol_opts['liblinear'] = { 62 | 'cv': {'C': np.logspace(-5,7,13, base=2)} 63 | } 64 | ol_opts['vw'] = { 65 | 'cv':{'learning_rate':np.logspace(-4,7,12,base=2)} 66 | } 67 | 68 | sol_opts = {} 69 | 
sol_opts['stg'] = { 70 | 'params':{'k':10}, 71 | 'cv':{'eta':eta_search}, 72 | 'lambda': np.logspace(-6,-1,10,base=10) 73 | } 74 | #sol_opts['fobos-l1'] = { 75 | # 'cv':{'eta':eta_search}, 76 | # 'lambda': np.logspace(-6,-1,10,base=10) 77 | #} 78 | #sol_opts['rda-l1'] = { 79 | # 'lambda': np.logspace(-6,-1,10,base=10) 80 | #} 81 | # 82 | #sol_opts['erda-l1'] = { 83 | # 'params':{'rou':0.001}, 84 | # 'lambda': np.logspace(-6,-1,10,base=10) 85 | #} 86 | #sol_opts['ada-fobos-l1'] = { 87 | # 'cv':{'eta':const_eta_search, 'delta':delta_search}, 88 | # 'lambda': np.logspace(-6,-1,10,base=10) 89 | #} 90 | #sol_opts['ada-rda-l1'] = { 91 | # 'cv':{'eta':const_eta_search, 'delta':delta_search}, 92 | # 'lambda': np.logspace(-7,-2,10,base=10) 93 | #} 94 | #sol_opts['liblinear'] = { 95 | # 'lambda':np.logspace(-5,7,13, base=2) 96 | #} 97 | # 98 | #sol_opts['vw'] = { 99 | # 'cv':'vw', 100 | # 'lambda':np.logspace(-6,-2,10, base=10) 101 | #} 102 | -------------------------------------------------------------------------------- /experiments/opts/url.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | ################################################################################# 3 | # File Name : synthetic_100k.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2017-03-17 19:00] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | norm_search = ['L2', 'None'] 19 | 20 | ol_opts = {} 21 | ol_opts['ada-fobos'] = { 22 | 'cv':{'eta':const_eta_search, 'delta':delta_search} 23 | } 24 | ol_opts['ada-rda'] = { 25 | 'cv':{'eta':const_eta_search, 'delta':delta_search} 26 | } 27 | ol_opts['alma2'] = { 28 | 'cv':{'alpha': np.linspace(0.1, 1, 10)} 29 | } 30 | ol_opts['arow'] = { 31 | 'cv':{'r':r_search} 32 | } 33 | ol_opts['cw'] = { 34 | 'cv':{'a': np.logspace(-4, 0, 5, base=2), 'phi':np.linspace(0, 2, 9)} 35 | } 36 | ol_opts['eccw'] = { 37 | 'cv':{'a': np.logspace(-4, 0, 5, base=2), 'phi':np.linspace(0, 2, 9)} 38 | } 39 | ol_opts['ogd'] = { 40 | 'cv':{'eta':eta_search} 41 | } 42 | ol_opts['pa'] = {} 43 | ol_opts['pa1'] = { 44 | 'cv':{'C':np.logspace(-4, 4, 9, base=2)} 45 | } 46 | ol_opts['pa2'] = { 47 | 'cv':{'C':np.logspace(-4, 4, 9, base=2)} 48 | } 49 | ol_opts['perceptron'] = {} 50 | ol_opts['sop'] = { 51 | 'cv':{'a':np.logspace(-4, 4, 9, base=2)} 52 | } 53 | ol_opts['rda'] = {} 54 | ol_opts['erda-l1'] = {} 55 | 56 | for k,v in ol_opts.iteritems(): 57 | if 'params' not in v: 58 | v['params'] = {} 59 | v['params']['step_show'] = 50000 60 | 61 | #ol_opts['liblinear'] = { 62 | # 'cv': {'C': np.logspace(-5,7,13, base=2)} 63 | #} 64 | #ol_opts['vw'] = { 65 | # 'cv':{'learning_rate':np.logspace(-4,7,12,base=2)} 66 | #} 67 | 68 | sol_opts = 
{} 69 | sol_opts['stg'] = { 70 | 'params':{'k':10}, 71 | 'cv':{'eta':eta_search}, 72 | 'lambda': np.logspace(-6,-1,10,base=10) 73 | } 74 | sol_opts['fobos-l1'] = { 75 | 'cv':{'eta':eta_search}, 76 | 'lambda': np.logspace(-6,-1,10,base=10) 77 | } 78 | sol_opts['rda-l1'] = { 79 | 'lambda': np.logspace(-6,-1,10,base=10) 80 | } 81 | 82 | sol_opts['erda-l1'] = { 83 | 'params':{'rou':0.001}, 84 | 'lambda': np.logspace(-6,-1,10,base=10) 85 | } 86 | sol_opts['ada-fobos-l1'] = { 87 | 'cv':{'eta':const_eta_search, 'delta':delta_search}, 88 | 'lambda': np.logspace(-6,-1,10,base=10) 89 | } 90 | sol_opts['ada-rda-l1'] = { 91 | 'cv':{'eta':const_eta_search, 'delta':delta_search}, 92 | 'lambda': np.logspace(-7,-2,10,base=10) 93 | } 94 | #sol_opts['liblinear'] = { 95 | # 'lambda':np.logspace(-5,7,13, base=2) 96 | #} 97 | # 98 | #sol_opts['vw'] = { 99 | # 'cv':'vw', 100 | # 'lambda':np.logspace(-6,-2,10, base=10) 101 | #} 102 | -------------------------------------------------------------------------------- /experiments/opts/webspam.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | ################################################################################# 3 | # File Name : webspam.py 4 | # Created By : yuewu 5 | # Description : 6 | ################################################################################# 7 | import numpy as np 8 | 9 | ol_opts = {} 10 | ol_opts['ada-fobos'] = { 11 | 'cv':{'eta':const_eta_search, 'delta':delta_search} 12 | } 13 | ol_opts['ada-rda'] = { 14 | 'cv':{'eta':const_eta_search, 'delta':delta_search} 15 | } 16 | ol_opts['alma2'] = { 17 | 'cv':{'alpha': np.linspace(0.1, 1, 10)} 18 | } 19 | ol_opts['arow'] = { 20 | 'cv':{'r':r_search} 21 | } 22 | ol_opts['cw'] = { 23 | 'cv':{'a': np.logspace(-4, 0, 5, base=2), 'phi':np.linspace(0, 2, 9)} 24 | } 25 | ol_opts['eccw'] = { 26 | 'cv':{'a': np.logspace(-4, 0, 5, base=2), 'phi':np.linspace(0, 2, 9)} 27 | } 28 | ol_opts['ogd'] = { 29 | 'cv':{'eta':eta_search} 30 | } 31 | ol_opts['pa'] = {} 32 | ol_opts['pa1'] = { 33 | 'cv':{'C':np.logspace(-4, 4, 9, base=2)} 34 | } 35 | ol_opts['pa2'] = { 36 | 'cv':{'C':np.logspace(-4, 4, 9, base=2)} 37 | } 38 | ol_opts['perceptron'] = {} 39 | ol_opts['sop'] = { 40 | 'cv':{'a':np.logspace(-4, 4, 9, base=2)} 41 | } 42 | ol_opts['rda'] = {} 43 | ol_opts['erda-l1'] = {} 44 | 45 | for k,v in ol_opts.iteritems(): 46 | if 'params' not in v: 47 | v['params'] = {} 48 | v['params']['step_show'] = 20000 49 | 50 | #ol_opts['liblinear'] = { 51 | # 'cv': {'C': np.logspace(-5,7,13, base=2)} 52 | #} 53 | #ol_opts['vw'] = { 54 | # 'cv':{'learning_rate':np.logspace(-4,7,12,base=2)} 55 | #} 56 | 57 | sol_opts = {} 58 | sol_opts['stg'] = { 59 | 'params':{'k':10}, 60 | 'cv':{'eta':eta_search}, 61 | 'lambda': np.logspace(-6,-1,10,base=10) 62 | } 63 | sol_opts['fobos-l1'] = { 64 | 'cv':{'eta':eta_search}, 65 | 'lambda': np.logspace(-6,-1,10,base=10) 66 | } 67 | sol_opts['rda-l1'] = { 68 | 'lambda': np.logspace(-6,-1,10,base=10) 69 | } 70 | 71 | sol_opts['erda-l1'] = { 72 | 'params':{'rou':0.001}, 73 | 'lambda': 
np.logspace(-6,-1,10,base=10) 74 | } 75 | sol_opts['ada-fobos-l1'] = { 76 | 'cv':{'eta':const_eta_search, 'delta':delta_search}, 77 | 'lambda': np.logspace(-6,-1,10,base=10) 78 | } 79 | sol_opts['ada-rda-l1'] = { 80 | 'cv':{'eta':const_eta_search, 'delta':delta_search}, 81 | 'lambda': np.logspace(-7,-2,10,base=10) 82 | } 83 | #sol_opts['liblinear'] = { 84 | # 'lambda':np.logspace(-5,7,13, base=2) 85 | #} 86 | # 87 | #sol_opts['vw'] = { 88 | # 'cv':'vw', 89 | # 'lambda':np.logspace(-6,-2,10, base=10) 90 | #} 91 | -------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | for ext in *.h *.cc 3 | do 4 | src_list=$(find . -name $ext) 5 | for src in ${src_list} 6 | do 7 | echo 'process' $src 8 | vim -c ":ClangFormat" -c ":x" $src 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /include/sol/loss/bool_loss.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : bool_loss.h 3 | * Created By : yuewu 4 | * Description : loss with yes or no 5 | **********************************************************************************/ 6 | #ifndef SOL_LOSS_BOOL_LOSS_H__ 7 | #define SOL_LOSS_BOOL_LOSS_H__ 8 | 9 | #include 10 | 11 | namespace sol { 12 | namespace loss { 13 | 14 | class SOL_EXPORTS BoolLoss : public Loss { 15 | public: 16 | BoolLoss() : Loss(Type::BC | Type::BOOL) {} 17 | 18 | public: 19 | virtual float loss(const pario::DataPoint& dp, float* predict, 20 | label_t predict_label, int cls_num); 21 | 22 | virtual float gradient(const pario::DataPoint& dp, float* predict, 23 | label_t predict_label, float* gradient, int cls_num); 24 | 25 | }; // class BoolLoss 26 | 27 | class SOL_EXPORTS MaxScoreBoolLoss : public Loss { 28 | public: 29 | 
MaxScoreBoolLoss() : Loss(Type::MC | Type::BOOL) {} 30 | 31 | public: 32 | virtual float loss(const pario::DataPoint& dp, float* predict, 33 | label_t predict_label, int cls_num); 34 | 35 | virtual float gradient(const pario::DataPoint& dp, float* predict, 36 | label_t predict_label, float* gradient, int cls_num); 37 | }; 38 | 39 | class SOL_EXPORTS UniformBoolLoss : public Loss { 40 | public: 41 | UniformBoolLoss() : Loss(Type::MC | Type::BOOL) {} 42 | 43 | public: 44 | virtual float loss(const pario::DataPoint& dp, float* predict, 45 | label_t predict_label, int cls_num); 46 | 47 | virtual float gradient(const pario::DataPoint& dp, float* predict, 48 | label_t predict_label, float* gradient, int cls_num); 49 | }; 50 | 51 | } // namespace loss 52 | } // namespace sol 53 | #endif 54 | -------------------------------------------------------------------------------- /include/sol/loss/hinge_loss.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : hinge_loss.h 3 | * Created By : yuewu 4 | * Description : hinge loss 5 | **********************************************************************************/ 6 | #ifndef SOL_LOSS_HINGE_LOSS_H__ 7 | #define SOL_LOSS_HINGE_LOSS_H__ 8 | 9 | #include 10 | #include 11 | 12 | namespace sol { 13 | namespace loss { 14 | 15 | class SOL_EXPORTS HingeBase : public Loss { 16 | public: 17 | HingeBase(int type); 18 | 19 | public: 20 | float margin() { return margin_; } 21 | void set_margin(float val) { this->margin_ = val; } 22 | void set_margin(const std::function& margin_handler) { 24 | this->margin_handler_ = margin_handler; 25 | } 26 | 27 | protected: 28 | float margin_; 29 | std::function 30 | margin_handler_; 31 | }; 32 | 33 | class SOL_EXPORTS HingeLoss : public HingeBase { 34 | public: 35 | HingeLoss() : HingeBase(Type::BC) {} 36 | 37 | public: 38 | virtual float loss(const pario::DataPoint& dp, float* 
predict, 39 | label_t predict_label, int cls_num); 40 | 41 | virtual float gradient(const pario::DataPoint& dp, float* predict, 42 | label_t predict_label, float* gradient, int cls_num); 43 | 44 | }; // class HingeLoss 45 | 46 | class SOL_EXPORTS MaxScoreHingeLoss : public HingeBase { 47 | public: 48 | MaxScoreHingeLoss() : HingeBase(Type::MC) {} 49 | 50 | public: 51 | virtual float loss(const pario::DataPoint& dp, float* predict, 52 | label_t predict_label, int cls_num); 53 | 54 | virtual float gradient(const pario::DataPoint& dp, float* predict, 55 | label_t predict_label, float* gradient, int cls_num); 56 | }; 57 | 58 | class SOL_EXPORTS UniformHingeLoss : public HingeBase { 59 | public: 60 | UniformHingeLoss() : HingeBase(Type::MC) {} 61 | 62 | public: 63 | virtual float loss(const pario::DataPoint& dp, float* predict, 64 | label_t predict_label, int cls_num); 65 | 66 | virtual float gradient(const pario::DataPoint& dp, float* predict, 67 | label_t predict_label, float* gradient, int cls_num); 68 | }; 69 | 70 | } // namespace loss 71 | } // namespace sol 72 | #endif 73 | -------------------------------------------------------------------------------- /include/sol/loss/logistic_loss.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : logistic_loss.h 3 | * Created By : yuewu 4 | * Description : logistic loss 5 | **********************************************************************************/ 6 | #ifndef SOL_LOSS_LOGISTIC_LOSS_H__ 7 | #define SOL_LOSS_LOGISTIC_LOSS_H__ 8 | 9 | #include 10 | #include 11 | 12 | namespace sol { 13 | namespace loss { 14 | 15 | class SOL_EXPORTS LogisticLoss : public Loss { 16 | public: 17 | LogisticLoss() : Loss(Type::BC) {} 18 | 19 | public: 20 | virtual float loss(const pario::DataPoint& dp, float* predict, 21 | label_t predict_label, int cls_num); 22 | 23 | virtual float gradient(const 
pario::DataPoint& 
dp, float* predict, 24 | label_t predict_label, float* gradient, int cls_num); 25 | 26 | }; // class LogisticLoss 27 | 28 | class SOL_EXPORTS MaxScoreLogisticLoss : public Loss { 29 | public: 30 | MaxScoreLogisticLoss() : Loss(Type::MC) {} 31 | 32 | public: 33 | virtual float loss(const pario::DataPoint& dp, float* predict, 34 | label_t predict_label, int cls_num); 35 | 36 | virtual float gradient(const pario::DataPoint& dp, float* predict, 37 | label_t predict_label, float* gradient, int cls_num); 38 | }; 39 | 40 | class SOL_EXPORTS UniformLogisticLoss : public Loss { 41 | public: 42 | UniformLogisticLoss() : Loss(Type::MC) {} 43 | 44 | public: 45 | virtual float loss(const pario::DataPoint& dp, float* predict, 46 | label_t predict_label, int cls_num); 47 | 48 | virtual float gradient(const pario::DataPoint& dp, float* predict, 49 | label_t predict_label, float* gradient, int cls_num); 50 | }; 51 | 52 | } // namespace loss 53 | } // namespace sol 54 | #endif 55 | -------------------------------------------------------------------------------- /include/sol/loss/loss.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : loss.h 3 | * Created By : yuewu 4 | * Creation Date : [2016-02-14 23:10] 5 | * Last Modified : [2016-02-18 23:33] 6 | * Description : base class for loss functions 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_LOSS_LOSS_H__ 10 | #define SOL_LOSS_LOSS_H__ 11 | 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | 18 | namespace sol { 19 | namespace loss { 20 | 21 | class SOL_EXPORTS Loss { 22 | DeclareReflectorBase(Loss); 23 | 24 | public: 25 | enum Type { 26 | // loss function for regression 27 | RG = 1, 28 | // loss function for binary classification 29 | BC = 2, 30 | // loss function for multi-class classification 31 | MC = 4, 32 | // 
hinge-based loss function 33 | BOOL = 8, 34 | // bool-based loss function 35 | HINGE = 16, 36 | }; 37 | 38 | inline static char Sign(float x) { return x >= 0.f ? 1 : -1; } 39 | 40 | public: 41 | Loss(int type) : type_(type) {} 42 | virtual ~Loss() {} 43 | 44 | int type() const { return this->type_; } 45 | 46 | public: 47 | /// \brief calculate loss according to the label and predictions 48 | /// 49 | /// \param dp data instance 50 | /// \param predict prediction on each class 51 | /// \param predict_label predicted label 52 | /// \param cls_num number of classes 53 | /// 54 | /// \return loss of the prediction 55 | virtual float loss(const pario::DataPoint& dp, float* predict, 56 | label_t predict_label, int cls_num) = 0; 57 | 58 | /// \brief calculate the gradients according to the label and predictions 59 | /// 60 | /// \param dp data instance 61 | /// \param predict prediction on each class 62 | /// \param predict_label predicted label 63 | /// \param gradient resulted gradient on each class (without x) 64 | /// \param cls_num number of classes 65 | /// 66 | /// \return loss of the prediction 67 | virtual float gradient(const pario::DataPoint& dp, float* predict, 68 | label_t predict_label, float* gradient, 69 | int cls_num) = 0; 70 | 71 | public: 72 | const std::string& name() const { return name_; } 73 | void set_name(const std::string& name) { this->name_ = name; } 74 | 75 | protected: 76 | /// \brief indicating it's a binary or multi-class loss 77 | int type_; 78 | std::string name_; 79 | }; 80 | 81 | #define RegisterLoss(type, name, descr) \ 82 | type* type##_##CreateNewInstance() { \ 83 | type* ins = new type(); \ 84 | ins->set_name(name); \ 85 | return ins; \ 86 | } \ 87 | ClassInfo __kClassInfo_##type##__(std::string(name) + "_loss", \ 88 | (void*)(type##_##CreateNewInstance), \ 89 | descr); 90 | 91 | } // namespace loss 92 | } // namespace sol 93 | 94 | #endif 95 | -------------------------------------------------------------------------------- 
/include/sol/loss/square_loss.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : square_loss.h 3 | * Created By : yuewu 4 | * Description : square loss (regression) 5 | **********************************************************************************/ 6 | #ifndef SOL_LOSS_SQUARE_LOSS_H__ 7 | #define SOL_LOSS_SQUARE_LOSS_H__ 8 | 9 | #include 10 | #include 11 | 12 | namespace sol { 13 | namespace loss { 14 | 15 | class SOL_EXPORTS SquareLoss : public Loss { 16 | public: 17 | SquareLoss() : Loss(Type::RG) {} 18 | 19 | public: 20 | virtual float loss(const pario::DataPoint& dp, float* predict, 21 | label_t predict_label, int cls_num); 22 | 23 | virtual float gradient(const pario::DataPoint& dp, float* predict, 24 | label_t predict_label, float* gradient, int cls_num); 25 | 26 | }; // class SquareLoss 27 | } // namespace loss 28 | } // namespace sol 29 | #endif 30 | -------------------------------------------------------------------------------- /include/sol/math/matrix_storage.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : matrix_storage.h 3 | * Created By : yuewu 4 | * Description : grow-only, zero-initialized element buffer backing matrices 5 | **********************************************************************************/ 6 | 7 | #ifndef SOL_MATH_MATRIX_STORAGE_H__ 8 | #define SOL_MATH_MATRIX_STORAGE_H__ 9 | 10 | #include 11 | 12 | namespace sol { 13 | namespace math { 14 | 15 | /// \brief storage structure of matrix 16 | /// 17 | /// \tparam DType Data Element Type 18 | template 19 | class MatrixStorage { 20 | public: 21 | MatrixStorage() : begin_(nullptr), size_(0) {} 22 | 23 | ~MatrixStorage() { DeleteArray(this->begin_); } 24 | 25 | public: 26 | /// \brief grow the storage to new_size elements (existing data kept, new slots zero-filled; a smaller new_size is a no-op) 27 | /// 28 | /// \param new_size Specified size of elements to be allocated 29 | void 
resize(size_t new_size) { 30 | if (new_size > this->size_) { 31 | DType* new_begin = new DType[new_size]; 32 | std::memset(new_begin, 0, sizeof(DType) * new_size); 33 | // copy old data; memcpy must never receive a null source, even for 0 bytes 34 | if (this->begin_ != nullptr) std::memcpy(new_begin, this->begin_, sizeof(DType) * this->size()); 35 | DeleteArray(this->begin_); 36 | this->begin_ = new_begin; 37 | this->size_ = new_size; 38 | } 39 | } 40 | 41 | DISABLE_COPY_AND_ASSIGN(MatrixStorage); 42 | 43 | public: 44 | inline const DType* begin() const { return this->begin_; } 45 | inline DType* begin() { return this->begin_; } 46 | 47 | inline const DType* end() const { return this->begin_ + this->size_; } 48 | inline DType* end() { return this->begin_ + this->size_; } 49 | 50 | inline size_t size() const { return this->size_; } 51 | 52 | protected: 53 | // point to the first element 54 | DType* begin_; 55 | // capacity of the array 56 | size_t size_; 57 | }; 58 | 59 | } // namespace math 60 | } // namespace sol 61 | 62 | #endif 63 | -------------------------------------------------------------------------------- /include/sol/math/shape.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : shape.h 3 | * Created By : yuewu 4 | * Description : shape definition of matrices 5 | **********************************************************************************/ 6 | #ifndef SOL_MATH_SHAPE_H__ 7 | #define SOL_MATH_SHAPE_H__ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | 17 | namespace sol { 18 | namespace math { 19 | 20 | template 21 | class Shape { 22 | public: 23 | /// \brief empty constructor 24 | Shape() { memset(this->shape_, 0, sizeof(size_t) * kDim); } 25 | 26 | Shape(const std::initializer_list& shape) { 27 | if (shape.size() != kDim) { 28 | std::ostringstream oss; 29 | oss << "dimension of input shape (" << shape.size() 
30 | << ") is incompatible with the expected (" << 
kDim << ")"; 31 | throw std::invalid_argument(oss.str()); 32 | } 33 | size_t* sp = this->shape_; 34 | for (const size_t& s : shape) *sp++ = s; 35 | } 36 | 37 | /// \brief copy constructor 38 | Shape(const Shape& shape) { 39 | memcpy(this->shape_, shape.shape_, sizeof(size_t) * kDim); 40 | } 41 | 42 | Shape& operator=(const Shape& shape) { 43 | memcpy(this->shape_, shape.shape_, sizeof(size_t) * kDim); 44 | return *this; 45 | } 46 | 47 | public: 48 | inline Shape<2> FlatTo2D() const { 49 | Shape<2> s; 50 | s[1] = this->shape_[kDim - 1]; 51 | s[0] = 1; 52 | for (int i = 0; i < kDim - 1; ++i) { 53 | s[0] *= this->shape_[i]; 54 | } 55 | return s; 56 | } 57 | 58 | public: 59 | inline int dim() const { return kDim; } 60 | 61 | inline size_t& operator[](int idx) { return shape_[idx]; } 62 | inline const size_t& operator[](int idx) const { return shape_[idx]; } 63 | 64 | inline size_t size(int start = 0, int end = kDim) const { 65 | size_t sz = 1; 66 | for (int i = start; i < end; ++i) { 67 | sz *= this->shape_[i]; 68 | } 69 | return sz; 70 | } 71 | 72 | inline size_t offset(int dim) const { 73 | size_t sz = 1; 74 | for (int i = dim + 1; i < kDim; ++i) { 75 | sz *= this->shape_[i]; 76 | } 77 | return sz; 78 | } 79 | 80 | inline bool operator==(const Shape& shape) const { 81 | for (int i = 0; i < kDim; ++i) { 82 | if (shape.shape_[i] != this->shape_[i]) return false; 83 | } 84 | return true; 85 | } 86 | inline bool operator!=(const Shape& shape) const { 87 | return !(*this == shape); 88 | } 89 | 90 | template 91 | friend std::ostream& operator<<(std::ostream& os, const Shape& s); 92 | 93 | inline std::string shape_string() const { 94 | std::ostringstream oss; 95 | oss << *this; 96 | return oss.str(); 97 | } 98 | 99 | protected: 100 | size_t shape_[kDim]; 101 | }; 102 | 103 | template 104 | std::ostream& operator<<(std::ostream& os, const Shape& s) { 105 | os << "shape: "; 106 | os << s.shape_[0]; 107 | for (int i = 1; i < kDim; ++i) { 108 | os << "," << s.shape_[i]; 109 | } 
110 | return os; 111 | } 112 | 113 | } // namespace math 114 | } // namespace sol 115 | 116 | #endif 117 | -------------------------------------------------------------------------------- /include/sol/model/olm/ada_fobos.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : ada_fobos.h 3 | * Created By : yuewu 4 | * Description : Adaptive Subgradient FOBOS 5 | **********************************************************************************/ 6 | 7 | #ifndef SOL_MODEL_OLM_ADA_FOBOS_H__ 8 | #define SOL_MODEL_OLM_ADA_FOBOS_H__ 9 | 10 | #include 11 | #include 12 | 13 | namespace sol { 14 | namespace model { 15 | 16 | class AdaFOBOS : public OnlineLinearModel { 17 | public: 18 | AdaFOBOS(int class_num); 19 | virtual ~AdaFOBOS(); 20 | 21 | virtual void SetParameter(const std::string& name, const std::string& value); 22 | 23 | protected: 24 | virtual void Update(const pario::DataPoint& dp, const float* predict, 25 | float loss); 26 | virtual void update_dim(index_t dim); 27 | 28 | virtual void GetModelInfo(Json::Value& root) const; 29 | virtual void GetModelParam(std::ostream& os) const; 30 | virtual int SetModelParam(std::istream& is); 31 | 32 | protected: 33 | float delta_; 34 | math::Vector* H_; 35 | 36 | }; // class AdaFOBOS 37 | 38 | /// \brief AdaFOBOS with l1 regularization 39 | class AdaFOBOS_L1 : public AdaFOBOS { 40 | public: 41 | AdaFOBOS_L1(int class_num); 42 | 43 | virtual void EndTrain(); 44 | 45 | protected: 46 | virtual label_t TrainPredict(const pario::DataPoint& dp, float* predicts); 47 | 48 | protected: 49 | LazyOnlineL1Regularizer l1_; 50 | }; 51 | 52 | class AdaFOBOS_OFS: public AdaFOBOS { 53 | public: 54 | AdaFOBOS_OFS(int class_num); 55 | virtual ~AdaFOBOS_OFS(); 56 | 57 | virtual void SetParameter(const std::string& name, const std::string& value); 58 | virtual void BeginTrain(); 59 | 60 | protected: 61 | virtual 
void Update(const pario::DataPoint& dp, const float* predict, 62 | float loss); 63 | virtual void update_dim(index_t dim); 64 | 65 | protected: 66 | math::Vector* H_sum_; 67 | OnlineRegularizer l0_; 68 | MinHeap min_heap_; 69 | }; 70 | 71 | } // namespace model 72 | } // namespace sol 73 | #endif 74 | -------------------------------------------------------------------------------- /include/sol/model/olm/ada_rda.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : ada_rda.h 3 | * Created By : yuewu 4 | * Description : Adaptive Subgradient RDA 5 | **********************************************************************************/ 6 | 7 | #ifndef SOL_MODEL_OLM_ADA_RDA_H__ 8 | #define SOL_MODEL_OLM_ADA_RDA_H__ 9 | 10 | #include 11 | #include 12 | 13 | namespace sol { 14 | namespace model { 15 | 16 | class AdaRDA : public OnlineLinearModel { 17 | public: 18 | AdaRDA(int class_num); 19 | virtual ~AdaRDA(); 20 | 21 | virtual void SetParameter(const std::string& name, const std::string& value); 22 | virtual void EndTrain(); 23 | 24 | protected: 25 | virtual void Update(const pario::DataPoint& dp, const float* predict, 26 | float loss); 27 | virtual void update_dim(index_t dim); 28 | 29 | virtual void GetModelInfo(Json::Value& root) const; 30 | virtual void GetModelParam(std::ostream& os) const; 31 | virtual int SetModelParam(std::istream& is); 32 | 33 | protected: 34 | float delta_; 35 | math::Vector* H_; 36 | // sum of gradients 37 | math::Vector* ut_; 38 | 39 | }; // class AdaRDA 40 | 41 | /// \brief AdaRDA with l1 regularization 42 | class AdaRDA_L1 : public AdaRDA { 43 | public: 44 | AdaRDA_L1(int class_num); 45 | 46 | virtual void EndTrain(); 47 | 48 | protected: 49 | virtual label_t TrainPredict(const pario::DataPoint& dp, float* predicts); 50 | 51 | protected: 52 | OnlineL1Regularizer l1_; 53 | }; 54 | 55 | class AdaRDA_OFS: public 
AdaRDA { 56 | public: 57 | AdaRDA_OFS(int class_num); 58 | virtual ~AdaRDA_OFS(); 59 | 60 | virtual void SetParameter(const std::string& name, const std::string& value); 61 | virtual void BeginTrain(); 62 | virtual void EndTrain(); 63 | 64 | protected: 65 | virtual void Update(const pario::DataPoint& dp, const float* predict, 66 | float loss); 67 | virtual void update_dim(index_t dim); 68 | 69 | protected: 70 | math::Vector* H_sum_; 71 | OnlineRegularizer l0_; 72 | MinHeap min_heap_; 73 | }; 74 | } // namespace model 75 | } // namespace sol 76 | #endif 77 | -------------------------------------------------------------------------------- /include/sol/model/olm/alma2.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : alma2.h 3 | * Created By : yuewu 4 | * Description : Approximate Large Margin Algorithm with norm 2 5 | **********************************************************************************/ 6 | #ifndef SOL_MODEL_OLM_ALMA2_H__ 7 | #define SOL_MODEL_OLM_ALMA2_H__ 8 | 9 | #include 10 | #include 11 | 12 | namespace sol { 13 | namespace model { 14 | 15 | class ALMA2 : public OnlineLinearModel { 16 | public: 17 | ALMA2(int class_num); 18 | 19 | virtual void SetParameter(const std::string& name, const std::string& value); 20 | 21 | public: 22 | virtual void BeginTrain(); 23 | 24 | protected: 25 | virtual void Update(const pario::DataPoint& dp, const float* predict, 26 | float loss); 27 | 28 | virtual void GetModelInfo(Json::Value& root) const; 29 | 30 | protected: 31 | loss::HingeBase* hinge_base_; 32 | int p_; 33 | float alpha_; 34 | float C_; 35 | float B_; 36 | // sqrt(p_ - 1) 37 | float square_p1_; 38 | int k_; 39 | }; // class ALMA2 40 | 41 | } // namespace model 42 | } // namespace sol 43 | #endif 44 | -------------------------------------------------------------------------------- /include/sol/model/olm/arow.h: 
-------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : arow.h 3 | * Created By : yuewu 4 | * Description : Adaptive Regularization of Weight Vectors 5 | **********************************************************************************/ 6 | 7 | #ifndef SOL_MODEL_OLM_AROW_H__ 8 | #define SOL_MODEL_OLM_AROW_H__ 9 | 10 | #include 11 | #include 12 | 13 | namespace sol { 14 | namespace model { 15 | 16 | class AROW : public OnlineLinearModel { 17 | public: 18 | AROW(int class_num); 19 | virtual ~AROW(); 20 | 21 | virtual void SetParameter(const std::string& name, const std::string& value); 22 | 23 | protected: 24 | virtual void Update(const pario::DataPoint& dp, const float* predict, 25 | float loss); 26 | virtual void update_dim(index_t dim); 27 | 28 | virtual void GetModelInfo(Json::Value& root) const; 29 | virtual void GetModelParam(std::ostream& os) const; 30 | virtual int SetModelParam(std::istream& is); 31 | 32 | protected: 33 | inline math::Vector& Sigma(int cls_id) { 34 | return this->Sigmas_[cls_id]; 35 | } 36 | 37 | protected: 38 | float r_; 39 | math::Vector* Sigmas_; 40 | 41 | }; // class AROW 42 | 43 | class SOFS : public AROW { 44 | public: 45 | SOFS(int class_num); 46 | virtual ~SOFS(); 47 | 48 | virtual void SetParameter(const std::string& name, const std::string& value); 49 | virtual void BeginTrain(); 50 | 51 | protected: 52 | virtual void Update(const pario::DataPoint& dp, const float* predict, 53 | float loss); 54 | virtual void update_dim(index_t dim); 55 | 56 | protected: 57 | math::Vector* Sigma_sum_; 58 | OnlineRegularizer l0_; 59 | MaxHeap max_heap_; 60 | }; 61 | 62 | } // namespace model 63 | } // namespace sol 64 | #endif 65 | -------------------------------------------------------------------------------- /include/sol/model/olm/cw.h: -------------------------------------------------------------------------------- 1 
| /********************************************************************************* 2 | * File Name : cw.h 3 | * Created By : yuewu 4 | * Description : Confidence Weighted Online Learning 5 | **********************************************************************************/ 6 | 7 | #ifndef SOL_MODEL_OLM_CW_H__ 8 | #define SOL_MODEL_OLM_CW_H__ 9 | 10 | #include 11 | #include 12 | 13 | namespace sol { 14 | namespace model { 15 | 16 | class CW : public OnlineLinearModel { 17 | public: 18 | CW(int class_num); 19 | virtual ~CW(); 20 | 21 | virtual void SetParameter(const std::string& name, const std::string& value); 22 | 23 | virtual void BeginTrain(); 24 | 25 | protected: 26 | virtual void Update(const pario::DataPoint& dp, const float* predict, 27 | float loss); 28 | virtual void update_dim(index_t dim); 29 | 30 | virtual void GetModelInfo(Json::Value& root) const; 31 | virtual void GetModelParam(std::ostream& os) const; 32 | virtual int SetModelParam(std::istream& is); 33 | 34 | protected: 35 | loss::HingeBase* hinge_base_; 36 | // initial variance 37 | float a_; 38 | // inverse normal distribution threshold 39 | float phi_; 40 | // x'Sigma x 41 | float Vi_; 42 | math::Vector* Sigmas_; 43 | 44 | }; // class CW 45 | 46 | } // namespace model 47 | } // namespace sol 48 | #endif 49 | -------------------------------------------------------------------------------- /include/sol/model/olm/eccw.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : eccw.h 3 | * Created By : yuewu 4 | * Description : Exact Convex Confidence Weighted Online Learning 5 | **********************************************************************************/ 6 | #ifndef SOL_MODEL_OLM_ECCW_H__ 7 | #define SOL_MODEL_OLM_ECCW_H__ 8 | 9 | #include 10 | #include 11 | 12 | namespace sol { 13 | namespace model { 14 | 15 | class ECCW : public OnlineLinearModel { 16 | 
public: 17 | ECCW(int class_num); 18 | virtual ~ECCW(); 19 | 20 | virtual void SetParameter(const std::string& name, const std::string& value); 21 | 22 | virtual void BeginTrain(); 23 | 24 | protected: 25 | virtual void Update(const pario::DataPoint& dp, const float* predict, 26 | float loss); 27 | virtual void update_dim(index_t dim); 28 | 29 | protected: 30 | void set_phi(float phi); 31 | 32 | protected: 33 | virtual void GetModelInfo(Json::Value& root) const; 34 | virtual void GetModelParam(std::ostream& os) const; 35 | virtual int SetModelParam(std::istream& is); 36 | 37 | protected: 38 | loss::HingeBase* hinge_base_; 39 | // initial variance 40 | float a_; 41 | // inverse normal distribution threshold 42 | float phi_; 43 | // x'Sigma x 44 | float vi_; 45 | float psi_; 46 | float xi_; 47 | math::Vector* Sigmas_; 48 | 49 | }; // class ECCW 50 | } // namespace model 51 | } // namespace sol 52 | #endif 53 | -------------------------------------------------------------------------------- /include/sol/model/olm/fofs.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : fofs.h 3 | * Created By : yuewu 4 | * Description : First Order Online Feature Selection 5 | **********************************************************************************/ 6 | 7 | #ifndef SOL_MODEL_OLM_FOFS_H__ 8 | #define SOL_MODEL_OLM_FOFS_H__ 9 | 10 | #include 11 | #include 12 | 13 | namespace sol { 14 | namespace model { 15 | 16 | class FOFS : public OnlineLinearModel { 17 | public: 18 | FOFS(int class_num); 19 | virtual ~FOFS(); 20 | 21 | virtual void SetParameter(const std::string& name, const std::string& value); 22 | virtual void BeginTrain(); 23 | 24 | protected: 25 | virtual void Update(const pario::DataPoint& dp, const float* predict, 26 | float loss); 27 | virtual void update_dim(index_t dim); 28 | 29 | virtual void GetModelInfo(Json::Value& root) 
const; 30 | 31 | protected: 32 | float lambda_; 33 | index_t B_; 34 | math::Vector abs_weights_; 35 | MinHeap min_heap_; 36 | 37 | float norm_coeff_; 38 | float momentum_; 39 | }; 40 | 41 | } // namespace model 42 | } // namespace sol 43 | 44 | #endif 45 | -------------------------------------------------------------------------------- /include/sol/model/olm/ogd.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : ogd.h 3 | * Created By : yuewu 4 | * Creation Date : [2016-02-18 21:33] 5 | * Last Modified : [2016-10-11 22:49] 6 | * Description : Online Gradient Descent 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_MODEL_OLM_OGD_H__ 10 | #define SOL_MODEL_OLM_OGD_H__ 11 | 12 | #include 13 | #include 14 | 15 | namespace sol { 16 | namespace model { 17 | 18 | class OGD : public OnlineLinearModel { 19 | public: 20 | OGD(int class_num); 21 | 22 | virtual void SetParameter(const std::string& name, const std::string& value); 23 | 24 | protected: 25 | virtual void Update(const pario::DataPoint& dp, const float* predict, 26 | float loss); 27 | virtual void GetModelInfo(Json::Value& root) const; 28 | 29 | protected: 30 | void set_power_t(float power_t); 31 | 32 | protected: 33 | // power_t of the decreasing coefficient of learning rate 34 | float power_t_; 35 | // initial learning rate 36 | float eta0_; 37 | 38 | float (*pow_)(int iter, float power_t); 39 | }; // class OGD 40 | 41 | /// \brief Sparse online learning via Truncated Gradient 42 | class STG : public OGD { 43 | public: 44 | STG(int class_num); 45 | 46 | virtual void SetParameter(const std::string& name, const std::string& value); 47 | virtual void BeginTrain(); 48 | virtual void EndTrain(); 49 | 50 | protected: 51 | virtual label_t TrainPredict(const pario::DataPoint& dp, float* predicts); 52 | void update_dim(index_t 
dim); 53 | 54 | virtual void GetModelInfo(Json::Value& root) const; 55 | 56 | protected: 57 | // truncate every k steps 58 | int k_; 59 | OnlineL1Regularizer l1_; 60 | math::Vector last_trunc_time_; 61 | }; // class STG 62 | 63 | /// \brief Forward Backward Splitting with (lazy) L1 regularization 64 | class FOBOS_L1 : public OGD { 65 | public: 66 | FOBOS_L1(int class_num); 67 | 68 | virtual void EndTrain(); 69 | 70 | protected: 71 | virtual label_t TrainPredict(const pario::DataPoint& dp, float* predicts); 72 | 73 | protected: 74 | LazyOnlineL1Regularizer l1_; 75 | }; // class FOBOS_L1 76 | 77 | /// \brief Perceptron with Truncation 78 | class PET : public OGD { 79 | public: 80 | PET(int class_num); 81 | virtual ~PET(); 82 | 83 | virtual void SetParameter(const std::string& name, const std::string& value); 84 | virtual void BeginTrain(); 85 | 86 | protected: 87 | virtual void Update(const pario::DataPoint& dp, const float* predict, 88 | float loss); 89 | virtual void update_dim(index_t dim); 90 | 91 | protected: 92 | math::Vector abs_weights_; 93 | OnlineRegularizer l0_; 94 | MinHeap min_heap_; 95 | }; 96 | 97 | } // namespace model 98 | } // namespace sol 99 | #endif 100 | -------------------------------------------------------------------------------- /include/sol/model/olm/pa.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : pa.h 3 | * Created By : yuewu 4 | * Description : Online Passive Aggressive Algorithms 5 | **********************************************************************************/ 6 | #ifndef SOL_MODEL_OLM_PA_H__ 7 | #define SOL_MODEL_OLM_PA_H__ 8 | 9 | #include 10 | 11 | namespace sol { 12 | namespace model { 13 | 14 | class PA : public OnlineLinearModel { 15 | public: 16 | PA(int class_num); 17 | 18 | protected: 19 | virtual void Update(const pario::DataPoint& dp, const float* predict, 20 | float loss); 21 | 22 | 
protected: 23 | // the 
coeffient difference between binary and multiclass classification 24 | float eta_coeff_; 25 | }; // class PA 26 | 27 | class PAI : public PA { 28 | public: 29 | PAI(int class_num) : PA(class_num), C_(1.f) {} 30 | 31 | virtual void SetParameter(const std::string& name, const std::string& value); 32 | 33 | protected: 34 | virtual void Update(const pario::DataPoint& dp, const float* predict, 35 | float loss); 36 | virtual void GetModelInfo(Json::Value& root) const; 37 | 38 | protected: 39 | float C_; 40 | }; // class PAI 41 | 42 | class PAII : public PA { 43 | public: 44 | PAII(int class_num) : PA(class_num), C_(1.f) {} 45 | 46 | virtual void SetParameter(const std::string& name, const std::string& value); 47 | 48 | protected: 49 | virtual void Update(const pario::DataPoint& dp, const float* predict, 50 | float loss); 51 | virtual void GetModelInfo(Json::Value& root) const; 52 | 53 | protected: 54 | float C_; 55 | }; // class PAII 56 | 57 | } // namespace model 58 | } // namespace sol 59 | #endif 60 | -------------------------------------------------------------------------------- /include/sol/model/olm/perceptron.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : perceptron.h 3 | * Created By : yuewu 4 | * Description : Perceptron algorithm 5 | **********************************************************************************/ 6 | #ifndef SOL_MODEL_OLM_PERCEPTRON_H__ 7 | #define SOL_MODEL_OLM_PERCEPTRON_H__ 8 | 9 | #include 10 | 11 | namespace sol { 12 | namespace model { 13 | 14 | class Perceptron : public OnlineLinearModel { 15 | public: 16 | Perceptron(int class_num); 17 | 18 | virtual void SetParameter(const std::string& name, const std::string& value); 19 | 20 | protected: 21 | virtual void Update(const pario::DataPoint& dp, const float* predict, 22 | float loss); 23 | }; // class Perceptron 24 | 25 | } // namespace model 26 | 
} // namespace sol 27 | #endif 28 | -------------------------------------------------------------------------------- /include/sol/model/olm/rda.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : rda.h 3 | * Created By : yuewu 4 | * Description : Regularized Dual Averaging 5 | **********************************************************************************/ 6 | 7 | #ifndef SOL_MODEL_OLM_RDA_H__ 8 | #define SOL_MODEL_OLM_RDA_H__ 9 | 10 | #include 11 | 12 | namespace sol { 13 | namespace model { 14 | 15 | /// \brief RDA with `l2^2` as the proximal function 16 | class RDA : public OnlineLinearModel { 17 | public: 18 | RDA(int class_num); 19 | virtual ~RDA(); 20 | 21 | virtual void SetParameter(const std::string& name, const std::string& value); 22 | virtual void EndTrain(); 23 | 24 | protected: 25 | virtual label_t TrainPredict(const pario::DataPoint& dp, float* predicts); 26 | virtual void Update(const pario::DataPoint& dp, const float* predict, 27 | float loss); 28 | virtual void update_dim(index_t dim); 29 | 30 | virtual void GetModelInfo(Json::Value& root) const; 31 | virtual void GetModelParam(std::ostream& os) const; 32 | virtual int SetModelParam(std::istream& is); 33 | 34 | protected: 35 | float sigma_; 36 | // sum of gradients 37 | math::Vector* ut_; 38 | }; // class RDA 39 | 40 | /// \brief RDA with l1 regularization 41 | class RDA_L1 : public RDA { 42 | public: 43 | RDA_L1(int class_num); 44 | 45 | virtual void EndTrain(); 46 | 47 | protected: 48 | virtual label_t TrainPredict(const pario::DataPoint& dp, float* predicts); 49 | 50 | protected: 51 | OnlineL1Regularizer l1_; 52 | }; 53 | 54 | /// \brief Enhanced RDA with l1 regularization 55 | class ERDA_L1 : public RDA { 56 | public: 57 | ERDA_L1(int class_num); 58 | 59 | virtual void SetParameter(const std::string& name, const std::string& value); 60 | virtual void 
EndTrain(); 61 | 62 | protected: 63 | virtual label_t TrainPredict(const pario::DataPoint& dp, float* predicts); 64 | 65 | virtual void GetModelInfo(Json::Value& root) const; 66 | 67 | protected: 68 | float rou_; 69 | OnlineL1Regularizer l1_; 70 | }; 71 | 72 | } // namespace model 73 | } // namespace sol 74 | #endif 75 | -------------------------------------------------------------------------------- /include/sol/model/olm/sop.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : sop.h 3 | * Created By : yuewu 4 | * Description : second order perceptron 5 | **********************************************************************************/ 6 | 7 | #ifndef SOL_MODEL_OLM_SOP_H__ 8 | #define SOL_MODEL_OLM_SOP_H__ 9 | 10 | #include 11 | 12 | namespace sol { 13 | namespace model { 14 | 15 | class SOP : public OnlineLinearModel { 16 | public: 17 | SOP(int class_num); 18 | virtual ~SOP(); 19 | 20 | virtual void SetParameter(const std::string& name, const std::string& value); 21 | virtual void EndTrain(); 22 | 23 | protected: 24 | virtual label_t TrainPredict(const pario::DataPoint& dp, float* predicts); 25 | 26 | protected: 27 | virtual void Update(const pario::DataPoint& dp, const float* predict, 28 | float loss); 29 | virtual void update_dim(index_t dim); 30 | 31 | virtual void GetModelInfo(Json::Value& root) const; 32 | virtual void GetModelParam(std::ostream& os) const; 33 | virtual int SetModelParam(std::istream& is); 34 | 35 | protected: 36 | math::Vector& v(int cls_id) { return this->v_[cls_id]; } 37 | const math::Vector& v(int cls_id) const { return this->v_[cls_id]; } 38 | 39 | protected: 40 | float a_; 41 | math::Vector* v_; 42 | math::Vector X_; 43 | 44 | }; // class SOP 45 | 46 | } // namespace model 47 | } // namespace sol 48 | #endif 49 | -------------------------------------------------------------------------------- 
/include/sol/model/online_linear_model.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : online_linear_model.h 3 | * Created By : yuewu 4 | * Creation Date : [2016-02-18 16:38] 5 | * Last Modified : [2016-11-03 16:27] 6 | * Description : online linear model 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_MODEL_ONLINE_LINEAR_MODEL_H__ 10 | #define SOL_MODEL_ONLINE_LINEAR_MODEL_H__ 11 | 12 | #include 13 | #include 14 | 15 | namespace sol { 16 | namespace model { 17 | 18 | class OnlineLinearModel : public OnlineModel { 19 | public: 20 | OnlineLinearModel(int class_num); 21 | virtual ~OnlineLinearModel(); 22 | 23 | public: 24 | virtual void BeginTrain() { OnlineModel::BeginTrain(); } 25 | 26 | virtual void EndTrain() { 27 | if (this->regularizer_ != nullptr) { 28 | for (int c = 0; c < this->clf_num_; ++c) { 29 | this->regularizer_->FinalizeRegularization(w(c)); 30 | } 31 | } 32 | OnlineModel::EndTrain(); 33 | } 34 | virtual label_t Predict(const pario::DataPoint& dp, float* predicts); 35 | 36 | virtual label_t Iterate(const pario::DataPoint& dp, float* predicts); 37 | 38 | protected: 39 | /// \brief update model 40 | /// 41 | /// \param dp training instance 42 | /// \param predict predicted values 43 | /// \param loss prediction loss 44 | virtual void Update(const pario::DataPoint& dp, const float* predict, 45 | float loss) = 0; 46 | virtual void update_dim(index_t dim); 47 | 48 | virtual label_t TrainPredict(const pario::DataPoint& dp, float* predicts); 49 | 50 | public: 51 | virtual float model_sparsity(); 52 | 53 | protected: 54 | virtual void GetModelParam(std::ostream& os) const; 55 | 56 | virtual int SetModelParam(std::istream& is); 57 | 58 | public: 59 | const math::Vector& w(int cls_id) const { 60 | return this->weights_[cls_id]; 61 | } 62 | math::Vector& w(int 
cls_id) { return this->weights_[cls_id]; } 63 | 64 | inline real_t g(int cls_id) const { return this->gradients_[cls_id]; } 65 | inline real_t& g(int cls_id) { return this->gradients_[cls_id]; } 66 | 67 | private: 68 | // the first element is zero 69 | math::Vector* weights_; 70 | // gradients for each class 71 | real_t* gradients_; 72 | }; // class OnlineLinearModel 73 | } // namespace model 74 | } // namespace sol 75 | 76 | #endif 77 | -------------------------------------------------------------------------------- /include/sol/model/regularizer.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : sparse_model.h 3 | * Created By : yuewu 4 | * Description : base class for sparse models 5 | **********************************************************************************/ 6 | 7 | #ifndef SOL_MODEL_SPARSE_MODEL_H__ 8 | #define SOL_MODEL_SPARSE_MODEL_H__ 9 | 10 | #include 11 | #include 12 | 13 | namespace sol { 14 | namespace model { 15 | 16 | class SOL_EXPORTS Regularizer { 17 | public: 18 | Regularizer() : lambda_(0) {}; 19 | /// \brief set model parameters 20 | /// 21 | /// \param name name of the parameter 22 | /// \param value value of the parameter in string 23 | /// 24 | /// \return status code, 0 if successfully 25 | virtual int SetParameter(const std::string &name, const std::string &value); 26 | 27 | /// \brief finalize the regularization on weights 28 | /// 29 | /// \param w weight vector 30 | 31 | virtual void FinalizeRegularization(math::Vector &w) {} 32 | 33 | /// \brief Get Regularizer Information 34 | /// 35 | /// \param root root node of saver 36 | /// info 37 | virtual void GetRegularizerInfo(Json::Value &root) const; 38 | 39 | public: 40 | inline real_t lambda() const { return this->lambda_; } 41 | 42 | protected: 43 | // regularization weight 44 | real_t lambda_; 45 | }; 46 | 47 | class SOL_EXPORTS 
OnlineRegularizer : public Regularizer { 48 | public: 49 | virtual void BeginIterate(const pario::DataPoint &dp) {} 50 | virtual void EndIterate(const pario::DataPoint &dp, int cur_iter_num) {} 51 | }; 52 | 53 | class SOL_EXPORTS OnlineL1Regularizer : public OnlineRegularizer { 54 | public: 55 | OnlineL1Regularizer(); 56 | 57 | virtual int SetParameter(const std::string &name, const std::string &value); 58 | 59 | virtual void FinalizeRegularization(math::Vector &w); 60 | 61 | protected: 62 | real_t sparse_thresh_; 63 | }; 64 | 65 | class SOL_EXPORTS LazyOnlineL1Regularizer : public OnlineL1Regularizer { 66 | public: 67 | LazyOnlineL1Regularizer(); 68 | 69 | virtual int SetParameter(const std::string &name, const std::string &value); 70 | 71 | virtual void BeginIterate(const pario::DataPoint &dp); 72 | virtual void EndIterate(const pario::DataPoint &dp, int cur_iter_num); 73 | 74 | public: 75 | inline const math::Vector &last_update_time() const { 76 | return this->last_update_time_; 77 | }; 78 | void set_initial_t(real_t t0) { this->initial_t_ = t0; } 79 | 80 | protected: 81 | real_t initial_t_; 82 | // record the last update time of each dimension 83 | math::Vector last_update_time_; 84 | }; 85 | 86 | } // namespace model 87 | } // namespace sol 88 | 89 | #endif 90 | -------------------------------------------------------------------------------- /include/sol/pario/binary_reader.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : binary_reader.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-13 20:28] 5 | * Last Modified : [2015-11-13 20:52] 6 | * Description : binary format data reader 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_BINARY_READER_H__ 10 | #define SOL_PARIO_BINARY_READER_H__ 11 | 12 | #include 13 | #include 14 | 15 | namespace sol { 16 | 
namespace pario { 17 | 18 | class SOL_EXPORTS BinaryReader : public DataFileReader { 19 | public: 20 | /// \brief Open a new file 21 | /// 22 | /// \param path Path to the file, '-' when if use stdin 23 | /// \param mode open mode, "r" or "rb" 24 | /// 25 | /// \return Status code, Status_OK if succeed 26 | virtual int Open(const std::string& path, const char* mode = "rb"); 27 | 28 | public: 29 | /// \brief Read next data point 30 | /// 31 | /// \param dst_data Destination data point 32 | /// 33 | /// \return Status code, Status_OK if everything ok, Status_EndOfFile if 34 | /// read to file end 35 | virtual int Next(DataPoint& dst_data); 36 | 37 | private: 38 | // compressed codes of indexes 39 | math::Vector comp_codes_; 40 | }; // class BinaryReader 41 | 42 | } // namespace pario 43 | } // namespace sol 44 | 45 | #endif 46 | -------------------------------------------------------------------------------- /include/sol/pario/binary_writer.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : binary_writer.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-14 15:32] 5 | * Last Modified : [2015-11-14 15:33] 6 | * Description : binary format data writer 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_BINARY_READER_H__ 10 | #define SOL_PARIO_BINARY_READER_H__ 11 | 12 | #include 13 | #include 14 | 15 | namespace sol { 16 | namespace pario { 17 | 18 | class SOL_EXPORTS BinaryWriter : public DataWriter { 19 | public: 20 | /// \brief Open a new file 21 | /// 22 | /// \param path Path to the file, '-' when if use stdin 23 | /// \param mode open mode, "wb" 24 | /// 25 | /// \return Status code, Status_OK if succeed 26 | virtual int Open(const std::string& path, const char* mode = "wb"); 27 | 28 | public: 29 | /// \brief Write a new data into the file 30 | /// 31 | /// 
\param data Data to be saved 32 | /// 33 | /// \return Status code, Status_OK if succeed 34 | virtual int Write(const DataPoint& data); 35 | 36 | private: 37 | // compressed codes of indexes 38 | math::Vector comp_codes_; 39 | }; // class BinaryWriter 40 | 41 | } // namespace pario 42 | } // namespace sol 43 | 44 | #endif 45 | -------------------------------------------------------------------------------- /include/sol/pario/compress.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : compress.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-13 21:14] 5 | * Last Modified : [2015-11-14 17:39] 6 | * Description : compressor for binary data format 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_COMPRESS_H__ 10 | #define SOL_PARIO_COMPRESS_H__ 11 | 12 | #if defined(_MSC_VER) && defined(_DEBUG) 13 | #include 14 | #endif 15 | 16 | #include 17 | #include 18 | 19 | namespace sol { 20 | namespace pario { 21 | 22 | // encode an unsigned int with run length encoding 23 | // if encode signed int, first map it to unsigned with ZigZag Encoding 24 | inline void run_len_encode(math::Vector& codes, uint64_t i) { 25 | // store an int 7 bits at a time. 26 | while (i >= 128) { 27 | codes.push_back((i & 127) | 128); 28 | i = i >> 7; 29 | } 30 | codes.push_back((i & 127)); 31 | } 32 | 33 | inline const char* run_len_decode( 34 | const char* p, uint64_t& i) { // read an int 7 bits at a time. 
35 | size_t count = 0; 36 | while (*p & 128) i = i | ((*(p++) & 127) << 7 * count++); 37 | i = i | (*(p++) << 7 * count); 38 | return p; 39 | } 40 | 41 | /** 42 | * comp : compress the index list, note that the indexes must be sorted from 43 | * small to big 44 | * Note: the function will not erase codes by iteself 45 | * 46 | * @Param indexes: indexes to be encoded 47 | * @Param codes: ouput codes 48 | */ 49 | template ::type* = nullptr> 50 | inline void comp_index(const math::Vector& indexes, 51 | math::Vector& codes) { 52 | T last = 0; 53 | size_t feat_num = indexes.size(); 54 | for (size_t i = 0; i < feat_num; i++) { 55 | run_len_encode(codes, indexes[i] - last); 56 | last = indexes[i]; 57 | } 58 | } 59 | 60 | /** 61 | * decomp_index : de-compress the codes to indexes 62 | * 63 | * @Param codes: input codes 64 | * @Param indexes: output indexes 65 | */ 66 | template ::type* = nullptr> 67 | inline void decomp_index(const math::Vector& codes, 68 | math::Vector& indexes) { 69 | // size_t sz = indexes.size(); 70 | indexes.clear(); 71 | uint64_t last = 0; 72 | uint64_t index = 0; 73 | 74 | const char* p = codes.begin(); 75 | while (p < codes.end()) { 76 | index = 0; 77 | p = run_len_decode(p, index); 78 | index += last; 79 | last = index; 80 | indexes.push_back(T(index)); 81 | } 82 | #if defined(_MSC_VER) && defined(_DEBUG) 83 | assert(p == codes.end()); 84 | #endif 85 | } 86 | 87 | } // namespace pario 88 | } // namespace sol 89 | #endif 90 | -------------------------------------------------------------------------------- /include/sol/pario/csr_matrix_reader.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : csr_matrix_reader.h 3 | * Created By : yuewu 4 | * Description : reader for csr_matrix 5 | **********************************************************************************/ 6 | 7 | #ifdef HAS_NUMPY_DEV 8 | #ifndef 
SOL_PARIO_CSR_MATRIX_READER_H__ 9 | #define SOL_PARIO_CSR_MATRIX_READER_H__ 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | namespace sol { 16 | namespace pario { 17 | 18 | class SOL_EXPORTS CsrMatrixReader : public DataReader { 19 | public: 20 | CsrMatrixReader() 21 | : is_good_(true), 22 | indices_(nullptr), 23 | indptr_(nullptr), 24 | features_(nullptr), 25 | Y_(nullptr), 26 | n_samples_(0), 27 | x_idx_(0) {} 28 | virtual ~CsrMatrixReader() {} 29 | 30 | public: 31 | virtual int Open(const std::string& path, const char* mode = "r"); 32 | virtual void Close() {} 33 | 34 | virtual bool Good() { return is_good_; } 35 | 36 | virtual void Rewind(); 37 | 38 | public: 39 | virtual int Next(DataPoint& dst_data); 40 | 41 | public: 42 | static std::string GeneratePath(int* indices, int* indptr, double* features, 43 | double* y, int n_samples); 44 | static int ParsePath(const std::string& path, int*& indices, int*& indptr, 45 | double*& features, double*& y, int& n_samples); 46 | 47 | protected: 48 | bool is_good_; 49 | 50 | int* indices_; 51 | int* indptr_; 52 | double* features_; 53 | double* Y_; 54 | int n_samples_; 55 | int x_idx_; 56 | }; 57 | 58 | } // namespace pario 59 | } // namespace sol 60 | 61 | #endif 62 | #endif 63 | -------------------------------------------------------------------------------- /include/sol/pario/csv_reader.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : csv_reader.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-13 19:35] 5 | * Last Modified : [2015-11-13 20:49] 6 | * Description : reader for csv data format 7 | **********************************************************************************/ 8 | #ifndef SOL_PARIO_CSV_READER_H__ 9 | #define SOL_PARIO_CSV_READER_H__ 10 | 11 | #include 12 | 13 | namespace sol { 14 | namespace pario { 15 | 16 | class SOL_EXPORTS CSVReader : public 
DataFileReader { 17 | public: 18 | CSVReader(); 19 | 20 | public: 21 | /// \brief Open a new file 22 | /// 23 | /// \param path Path to the file, '-' when if use stdin 24 | /// \param mode open mode, "r" or "rb" 25 | /// 26 | /// \return Status code, Status_OK if succeed 27 | virtual int Open(const std::string& path, const char* mode = "r"); 28 | 29 | /// \brief Rewind the dataset to the beginning of the file 30 | virtual void Rewind(); 31 | 32 | public: 33 | /// \brief Read next data point 34 | /// 35 | /// \param dst_data Destination data point 36 | /// 37 | /// \return Status code, Status_OK if everything ok, Status_EndOfFile if 38 | /// read to file end 39 | virtual int Next(DataPoint& dst_data); 40 | 41 | private: 42 | /// \brief load the head info of csv data 43 | /// 44 | /// \return Status code, Status_OK if everything ok 45 | int LoadFeatDim(); 46 | 47 | protected: 48 | // dimension of data 49 | index_t feat_dim_; 50 | }; // class CSVReader 51 | 52 | } // namespace pario 53 | } // namespace sol 54 | 55 | #endif 56 | -------------------------------------------------------------------------------- /include/sol/pario/csv_writer.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : csv_writer.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-14 15:09] 5 | * Last Modified : [2015-11-14 15:46] 6 | * Description : writer of csv format data 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_CSV_WRITER_H__ 10 | #define SOL_PARIO_CSV_WRITER_H__ 11 | 12 | #include 13 | 14 | namespace sol { 15 | namespace pario { 16 | 17 | class SOL_EXPORTS CSVWriter : public DataWriter { 18 | public: 19 | /// \brief Write a new data into the file 20 | /// 21 | /// \param data Data to be saved 22 | /// 23 | /// \return Status code, Status_OK if succeed 24 | virtual int Write(const 
DataPoint& data); 25 | 26 | /// \brief Set extra information for the output format, for example header 27 | /// of csv 28 | /// 29 | /// \param extra_info extra info 30 | /// 31 | /// \return Status code, Status_OK if succeed 32 | virtual int SetExtraInfo(const char* extra_info); 33 | 34 | protected: 35 | // dimension of data 36 | index_t feat_dim_; 37 | 38 | }; // class CSVWriter 39 | 40 | } // namespace pario 41 | } // namespace sol 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /include/sol/pario/data_iter.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : data_iter.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-12-03 14:51] 5 | * Last Modified : [2016-02-12 18:10] 6 | * Description : Data Iterator 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_DATA_ITER_H__ 10 | #define SOL_PARIO_DATA_ITER_H__ 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | namespace sol { 23 | namespace pario { 24 | 25 | class SOL_EXPORTS DataIter { 26 | public: 27 | /// \brief Create a new Data Iterator 28 | /// 29 | /// \param batch_size size of minibatch 30 | /// \param batch_num number of mini-batches in buffer 31 | DataIter(int batch_size = 256, int batch_num = 2); 32 | virtual ~DataIter(); 33 | 34 | public: 35 | /// \brief Load a new data 36 | /// 37 | /// \param path data file path 38 | /// \param dtype data type (svm, bin, csv, etc.) 
39 | /// 40 | /// \return 41 | int AddReader(const std::string& path, const std::string& dtype, 42 | int pass_num = 1); 43 | 44 | /// \brief get the next mini-batch 45 | /// 46 | /// \param prev_batch previously used mini-batch for recycle 47 | /// 48 | /// \return 49 | virtual MiniBatch* Next(MiniBatch* prev_batch = nullptr); 50 | 51 | protected: 52 | // mini-batch size 53 | int batch_size_; 54 | // factory to store not used mini batches 55 | BlockQueue mini_batch_factory_; 56 | // mini-batch number in buffer 57 | BlockQueue mini_batch_buf_; 58 | // data reader threads 59 | std::vector> readers_; 60 | // index of running reader 61 | int running_reader_idx_; 62 | }; // class DataIter 63 | } // namespace pario 64 | } // namespace sol 65 | 66 | #endif 67 | -------------------------------------------------------------------------------- /include/sol/pario/data_point.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : data_point.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-29 11:16] 5 | * Last Modified : [2016-02-18 23:55] 6 | * Description : Data Point Structure 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_DATA_POINT_H__ 10 | #define SOL_PARIO_DATA_POINT_H__ 11 | 12 | #include 13 | #include 14 | 15 | namespace sol { 16 | namespace pario { 17 | 18 | /** 19 | * \brief one label, one feature index vector, and one feature value vector 20 | */ 21 | class SOL_EXPORTS DataPoint { 22 | public: 23 | /// \brief Create an empty data point 24 | DataPoint(); 25 | 26 | ~DataPoint() {} 27 | 28 | /// \brief Clone the current point to destination point 29 | /// 30 | /// \param dst_pt destination point 31 | void Clone(DataPoint& dst_pt) const; 32 | 33 | /// \brief Clone a new point from this 34 | /// 35 | /// \return new data point 36 | DataPoint Clone() const; 37 | 38 | 
public: 39 | /// \brief add new feature into the data point, mostly used when loading 40 | // data 41 | /// 42 | /// \param index index of the feature 43 | /// \param feat value of the feature 44 | void AddNewFeat(index_t index, real_t feat); 45 | 46 | inline void Reserve(size_t new_size) { this->data_.reserve(new_size); } 47 | inline void Resize(size_t new_size) { this->data_.resize(new_size); } 48 | /// \brief clear the label, indexes, and features 49 | void Clear(); 50 | 51 | /// \brief Check if the indexes are sorted from small to large 52 | /// 53 | /// \return true of sorted, false otherwise 54 | bool IsSorted() const; 55 | 56 | /// \brief Sort the features so that indexes are from small to large 57 | void Sort(); 58 | 59 | public: 60 | inline const math::SVector& data() const { return data_; } 61 | inline math::SVector& data() { return data_; } 62 | 63 | inline const math::Vector& indexes() const { 64 | return this->data_.indexes(); 65 | } 66 | inline math::Vector& indexes() { return this->data_.indexes(); } 67 | 68 | inline const math::Vector& features() const { 69 | return this->data_.values(); 70 | } 71 | inline math::Vector& features() { return this->data_.values(); } 72 | 73 | inline index_t index(size_t index) const { return this->data_.index(index); } 74 | inline index_t& index(size_t index) { return this->data_.index(index); } 75 | 76 | inline real_t feature(size_t index) const { return this->data_.value(index); } 77 | inline real_t& feature(size_t index) { return this->data_.value(index); } 78 | 79 | inline label_t label() const { return this->label_; } 80 | inline void set_label(label_t label) { this->label_ = label; } 81 | index_t dim() const { return this->data_.dim(); } 82 | inline size_t size() const { return this->data_.size(); } 83 | 84 | protected: 85 | math::SVector data_; 86 | label_t label_; 87 | }; // class DataPoint 88 | 89 | } // namespace pario 90 | } // namespace sol 91 | 92 | #endif 93 | 
-------------------------------------------------------------------------------- /include/sol/pario/data_read_task.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : data_read_task.h 3 | * Created By : yuewu 4 | * Creation Date : [2016-02-12 16:15] 5 | * Last Modified : [2016-02-12 18:07] 6 | * Description : 7 | **********************************************************************************/ 8 | #ifndef SOL_PARIO_DATA_READ_TASK_H__ 9 | #define SOL_PARIO_DATA_READ_TASK_H__ 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace sol { 18 | namespace pario { 19 | 20 | /// \brief Thread task to read data 21 | class DataReadTask : public ThreadTask { 22 | public: 23 | /// \brief Initialize the Data Read Task 24 | /// 25 | /// \param path data file path 26 | /// \param dtype data type (svm, bin, csv, etc.) 27 | /// \param mini_batch_factory factory of empty mini batch 28 | /// \param mini_batch_buf place to store the loaded mini batched 29 | /// \param pass_num number of passes to read the data 30 | DataReadTask(const std::string& path, const std::string& dtype, 31 | BlockQueue& mini_batch_factory, 32 | BlockQueue& mini_batch_buf, int pass_num); 33 | 34 | public: 35 | inline bool Good() { return this->reader_ != nullptr; } 36 | 37 | protected: 38 | virtual void run(); 39 | 40 | private: 41 | std::unique_ptr reader_; 42 | BlockQueue& mini_batch_factory_; 43 | BlockQueue& mini_batch_buf_; 44 | int pass_num_; 45 | }; 46 | 47 | } // namespace pario 48 | } // namespace sol 49 | #endif 50 | -------------------------------------------------------------------------------- /include/sol/pario/data_writer.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : data_writer.h 
3 | * Created By : yuewu 4 | * Creation Date : [2015-11-14 14:46] 5 | * Last Modified : [2015-11-14 15:14] 6 | * Description : Interface for data writer (svm,binary, etc.) 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_DATA_WRITER_H__ 10 | #define SOL_PARIO_DATA_WRITER_H__ 11 | 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | namespace sol { 20 | namespace pario { 21 | 22 | class SOL_EXPORTS DataWriter { 23 | DeclareReflectorBase(DataWriter); 24 | 25 | public: 26 | DataWriter(); 27 | virtual ~DataWriter(); 28 | 29 | public: 30 | /// \brief Open a new file 31 | /// 32 | /// \param path Path to the file, '-' when if use stdout 33 | /// \param mode open mode, "w" or "wb" 34 | /// 35 | /// \return Status code, Status_OK if succeed 36 | virtual int Open(const std::string& path, const char* mode = "w"); 37 | 38 | /// \brief Close the reader 39 | virtual void Close() { this->file_writer_.Close(); } 40 | 41 | /// \brief Check the status of the data handler 42 | /// 43 | /// \return True if everything is ok 44 | virtual bool Good() { return this->is_good_ && this->file_writer_.Good(); } 45 | 46 | public: 47 | /// \brief Write a new data into the file 48 | /// 49 | /// \param data Data to be saved 50 | /// 51 | /// \return Status code, Status_OK if succeed 52 | virtual int Write(const DataPoint& data) = 0; 53 | 54 | /// \brief Set extra information for the output format, for example header 55 | /// of csv 56 | /// 57 | /// \param extra_info extra info 58 | /// 59 | /// \return Status code, Status_OK if succeed 60 | virtual int SetExtraInfo(const char* extra_info) { 61 | return Status_OK; 62 | }; 63 | 64 | protected: 65 | FileWriter file_writer_; 66 | /// \brief flag to denote whether any parse error occurs 67 | bool is_good_; 68 | /// \brief path to the opened file 69 | std::string file_path_; 70 | 71 | public: 72 | const std::string& file_path() const { return file_path_; 
} 73 | }; 74 | /// \brief Define a creator for `type` and register it under name + "_writer" 75 | #define RegisterDataWriter(type, name, descr) \ 76 | type* type##_##CreateNewInstance() { return new type(); } \ 77 | ClassInfo __kClassInfo_##type##__(std::string(name) + "_writer", \ 78 | (void*)(type##_##CreateNewInstance), \ 79 | descr); 80 | } // namespace pario 81 | } // namespace sol 82 | 83 | #endif 84 | -------------------------------------------------------------------------------- /include/sol/pario/file_reader.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : file_reader.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-16 23:31] 5 | * Last Modified : [2015-11-12 18:01] 6 | * Description : basic file reader 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_FILE_READER_H__ 10 | #define SOL_PARIO_FILE_READER_H__ 11 | 12 | #include 13 | #include 14 | 15 | namespace sol { 16 | namespace pario { 17 | /// \brief Reads raw bytes or text lines from a file, or from stdin when the path is '-' 18 | class SOL_EXPORTS FileReader { 19 | enum ReadMode { 20 | kUnknown = 0, 21 | kText = 1, 22 | kBinary = 2 23 | }; 24 | 25 | public: 26 | FileReader(); 27 | FileReader(const char* path, const char* mode); 28 | ~FileReader(); 29 | 30 | public: 31 | /** 32 | * \brief open a file to read 33 | * 34 | * \param path Path to the file, set to '-' if read from stdin 35 | * \param mode 'r' or 'rb' 36 | * 37 | * \return Status code, Status_OK if succeed 38 | */ 39 | int Open(const char* path, const char* mode); 40 | 41 | /** 42 | * \brief Close the file 43 | */ 44 | void Close(); 45 | 46 | /** 47 | * \brief Rewind the file reader to the beginning of the file 48 | */ 49 | void Rewind(); 50 | 51 | /** 52 | * \brief Test if the file reader is good 53 | * 54 | * \return true if good 55 | */ 56 | bool Good(); 57 | 58 | public: 59 | /** 
60 | * \brief Read the data from file with specified length 61 | * 62 | * \param length Length of data in size of
char to be read 63 | * \param dst Destination buffer to store the data 64 | * 65 | * \return Status code, Status_OK if succeed 66 | */ 67 | int Read(char* dst, size_t length); 68 | 69 | /** 70 | * \brief Read a line from the file 71 | * 72 | * \param dst Destination buffer to store the data 73 | * \param dst_len length of the destination buffer, note `dst` may be 74 | * reallocated to store one line 75 | * 76 | * \return Status code, Status_OK if succeed 77 | */ 78 | int ReadLine(char*& dst, int& dst_len); 79 | 80 | private: 81 | FILE* file_; 82 | ReadMode mode_; 83 | }; // class FileReader 84 | 85 | } // namespace pario 86 | } // namespace sol 87 | #endif  // SOL_PARIO_FILE_READER_H__ 88 | -------------------------------------------------------------------------------- /include/sol/pario/file_writer.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : file_writer.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-17 10:51] 5 | * Last Modified : [2015-11-14 15:25] 6 | * Description : basic file writer 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_FILE_WRITER_H__ 10 | #define SOL_PARIO_FILE_WRITER_H__ 11 | 12 | #include 13 | 14 | #include 15 | 16 | namespace sol { 17 | namespace pario { 18 | /// \brief Writes raw bytes or printf-formatted text to a file, or to stdout when the path is '-' 19 | class SOL_EXPORTS FileWriter { 20 | public: 21 | FileWriter(); 22 | FileWriter(const char* path, const char* mode); 23 | ~FileWriter(); 24 | 25 | public: 26 | /** 27 | * \brief open a file to write 28 | * 29 | * \param path Path to the file, set to '-' if write to stdout 30 | * \param mode 'w', 'wb', 'w+', 'w+b', 'a', 'ab', 'a+', or 'a+b' 31 | * 32 | * \return Status code, Status_OK if succeed 33 | */ 34 | int Open(const char* path, const char* mode); 35 | 36 | /** 37 | * \brief Close the file 38 | */ 39 | void Close(); 40 | 41 | /** 42 | * 
\brief Test if the file writer is good 43 | * 44 | * \return
true if good 45 | */ 46 | bool Good(); 47 | 48 | public: 49 | /** 50 | * \brief Write the data with specified length to file 51 | * 52 | * \param src_buf source buffer holding the data to write 53 | * \param length Length of data in size of char to be written 54 | * 55 | * \return Status code, Status_OK if succeed 56 | */ 57 | int Write(char* src_buf, size_t length); 58 | 59 | /// \brief Wrapper for fprintf 60 | /// 61 | /// \param format format string 62 | /// \param ... Formatted data 63 | /// 64 | /// \return Status code, Status_OK if succeed 65 | int Printf(const char* format, ...); 66 | 67 | private: 68 | FILE* file_; 69 | }; // class FileWriter 70 | 71 | } // namespace pario 72 | } // namespace sol 73 | 74 | #endif  // SOL_PARIO_FILE_WRITER_H__ 75 | -------------------------------------------------------------------------------- /include/sol/pario/mini_batch.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : mini_batch.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-12-03 16:59] 5 | * Last Modified : [2016-02-12 18:19] 6 | * Description : mini batch 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_MINI_BATCH_H__ 10 | #define SOL_PARIO_MINI_BATCH_H__ 11 | 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | namespace sol { 18 | namespace pario { 19 | /// \brief Owns an array of capacity() DataPoint objects; size() returns data_num 20 | class SOL_EXPORTS MiniBatch { 21 | public: 22 | MiniBatch(int batch_size = 0) 23 | : data_num(0), points_(nullptr), capacity_(batch_size) { 24 | this->points_ = new DataPoint[this->capacity_]; 25 | } 26 | ~MiniBatch() { 27 | if (this->points_ != nullptr) { 28 | delete[] this->points_; 29 | } 30 | } 31 | 32 | public: 33 | inline const DataPoint* points() const { return this->points_; } 34 | inline int size() const { return this->data_num; } 35 | inline int capacity() const { return this->capacity_; } 36 | inline const DataPoint&
operator[](size_t index) const { 37 | return this->points_[index]; 38 | } 39 | inline DataPoint& operator[](size_t index) { return this->points_[index]; } 40 | /// \brief value returned by size(); public and unchecked, presumably the number of filled slots (TODO confirm) 41 | int data_num; 42 | // NOTE(review): points_ is released with delete[] in the destructor and no copy constructor is declared, so copying a MiniBatch would delete the array twice 43 | private: 44 | DataPoint* points_; 45 | int capacity_; 46 | }; 47 | 48 | } // namespace pario 49 | } // namespace sol 50 | #endif  // SOL_PARIO_MINI_BATCH_H__ 51 | -------------------------------------------------------------------------------- /include/sol/pario/numpy_reader.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : numpy_reader.h 3 | * Created By : yuewu 4 | * Description : reader for numpy array 5 | **********************************************************************************/ 6 | 7 | #ifdef HAS_NUMPY_DEV 8 | #ifndef SOL_PARIO_NUMPY_READER_H__ 9 | #define SOL_PARIO_NUMPY_READER_H__ 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | namespace sol { 16 | namespace pario { 17 | /// \brief DataReader over in-memory double arrays; GeneratePath/ParsePath presumably encode the array pointers and shape in the path string (TODO confirm) 18 | class SOL_EXPORTS NumpyReader : public DataReader { 19 | public: 20 | NumpyReader() 21 | : is_good_(true), 22 | X_(nullptr), 23 | Y_(nullptr), 24 | n_samples_(0), 25 | n_features_(0), 26 | stride_(0), 27 | x_idx_(0) {} 28 | virtual ~NumpyReader() {} 29 | 30 | public: 31 | virtual int Open(const std::string& path, const char* mode = "r"); 32 | virtual void Close() {} 33 | 34 | virtual bool Good() { return is_good_; } 35 | 36 | virtual void Rewind(); 37 | 38 | public: 39 | virtual int Next(DataPoint& dst_data); 40 | 41 | public: 42 | static std::string GeneratePath(double* x, double* y, int rows, int cols, 43 | int stride); 44 | 
static int ParsePath(const std::string& path, double*& x, double*& y, 45 | int& rows, int& cols, int& stride); 46 | 47 | protected: 48 | bool is_good_; 49 | double* X_; 50 | double* Y_; 51 | int n_samples_; 52 | int n_features_; 53 | int stride_; 54 | int x_idx_; 55 | }; 56 | 57 | } // namespace pario 58 | } // namespace sol 59 | 60 | #endif  // SOL_PARIO_NUMPY_READER_H__ 61 | #endif  // HAS_NUMPY_DEV
62 | -------------------------------------------------------------------------------- /include/sol/pario/svm_reader.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : svm_reader.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-11 22:15] 5 | * Last Modified : [2015-11-13 20:43] 6 | * Description : reader of libsvm format data 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_SVM_READER_H__ 10 | #define SOL_PARIO_SVM_READER_H__ 11 | 12 | #include 13 | 14 | namespace sol { 15 | namespace pario { 16 | 17 | class SOL_EXPORTS SVMReader : public DataFileReader { 18 | public: 19 | /// \brief Read next data point 20 | /// 21 | /// \param dst_data Destination data point 22 | /// 23 | /// \return Status code, Status_OK if everything ok, Status_EndOfFile if 24 | /// the end of the file is reached 25 | virtual int Next(DataPoint& dst_data); 26 | }; // class SVMReader 27 | 28 | } // namespace pario 29 | } // namespace sol 30 | 31 | #endif  // SOL_PARIO_SVM_READER_H__ 32 | -------------------------------------------------------------------------------- /include/sol/pario/svm_writer.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : svm_writer.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-14 15:03] 5 | * Last Modified : [2015-11-14 15:04] 6 | * Description : writer of libsvm format data 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_PARIO_SVM_WRITER_H__ 10 | #define SOL_PARIO_SVM_WRITER_H__ 11 | 12 | #include 13 | 14 | namespace sol { 15 | namespace pario { 16 | 17 | class SOL_EXPORTS SVMWriter : public DataWriter { 18 | public: 19 | /// \brief Write a data point to the file 20 | /// 21 | /// \param data
Data to be saved 22 | /// 23 | /// \return Status code, Status_OK if succeed 24 | virtual int Write(const DataPoint& data); 25 | 26 | }; // class SVMWriter 27 | 28 | } // namespace pario 29 | } // namespace sol 30 | 31 | #endif  // SOL_PARIO_SVM_WRITER_H__ 32 | -------------------------------------------------------------------------------- /include/sol/sol.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : include/sol/sol.h 3 | * Created By : yuewu 4 | * Description : aggregate header that includes the main sol headers 5 | **********************************************************************************/ 6 | 7 | #ifndef SOL_SOL_H__ 8 | #define SOL_SOL_H__ 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include 20 | 21 | #endif  // SOL_SOL_H__ 22 | -------------------------------------------------------------------------------- /include/sol/tools.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : tools.h 3 | * Created By : yuewu 4 | * Description : data tools for sol: analyze, convert, shuffle and split 5 | **********************************************************************************/ 6 | #ifndef SOL_TOOLS_H__ 7 | #define SOL_TOOLS_H__ 8 | 9 | #include 10 | #include "sol/util/types.h" 11 | 12 | namespace sol { 13 | SOL_EXPORTS int analyze(const std::string& src_path, 14 | const std::string& src_type, 15 | const std::string& output_path); 16 | 17 | SOL_EXPORTS int convert(const std::string& src_path, 18 | const std::string& src_type, 19 | const std::string& dst_path, 20 | const std::string& dst_type, 21 | bool binaryize=false, 22 | float binaryize_thresh=0); 23 | 24 | SOL_EXPORTS int shuffle(const std::string& src_path, 25 | const std::string& src_type, 26 | const std::string& output_path, 27 | const std::string& output_type); 28 | 
29 | SOL_EXPORTS int
split(const std::string& src_path, const std::string& src_type, 30 | int fold_num, const std::string& output_prefix, 31 | const std::string& dst_type, bool shuffle); 32 | } // namespace sol 33 | #endif  // SOL_TOOLS_H__ 34 | -------------------------------------------------------------------------------- /include/sol/util/error_code.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : ../util/error_code.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-16 23:44] 5 | * Last Modified : [2015-12-03 14:49] 6 | * Description : status codes returned by sol functions 7 | **********************************************************************************/ 8 | #ifndef SOL_UTIL_ERROR_CODE_H__ 9 | #define SOL_UTIL_ERROR_CODE_H__ 10 | 11 | namespace sol { 12 | 13 | static const int Status_OK = 0; 14 | static const int Status_Error = 1; 15 | static const int Status_IO_Error = 2; 16 | static const int Status_EndOfFile = 3; 17 | static const int Status_Invalid_Argument = 4; 18 | static const int Status_Invalid_Format = 5; 19 | 20 | } // namespace sol 21 | 22 | #endif // SOL_UTIL_ERROR_CODE_H__ 23 | -------------------------------------------------------------------------------- /include/sol/util/monitor.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : monitor.h 3 | * Created By : yuewu 4 | * Creation Date : [2016-05-14 15:18] 5 | * Last Modified : [2016-05-17 02:30] 6 | * Description : monitor (mutex + condition variable) with std, win32 and pthread backends 7 | **********************************************************************************/ 8 | 9 | #ifndef CXX_SELF_CUSTOMIZED_MONITOR_H__ 10 | #define CXX_SELF_CUSTOMIZED_MONITOR_H__ 11 | 12 | #include 13 | 14 | #if USE_STD_THREAD 15 | #include 16 | #endif 17 | 18 | #include 19 | 20 | namespace sol { 21 | 22 | #if USE_STD_THREAD 23 | class Monitor { 24 | public: 25 | 
Monitor()
: owned_mutex_(new Mutex()) { 26 | this->mutex_ = this->owned_mutex_.get(); 27 | } 28 | Monitor(Mutex& mutex) { this->mutex_ = &mutex; } 29 | 30 | void lock() { this->mutex_->lock(); } 31 | void unlock() { this->mutex_->unlock(); } 32 | // the caller must already hold the lock; it is held again when wait() returns 33 | void wait() { 34 | std::unique_lock lock(this->mutex_->mutex(), std::adopt_lock); 35 | this->cv_.wait(lock); 36 | lock.release(); // hand ownership back to the caller without unlocking 37 | } 38 | void notify() { this->cv_.notify_one(); } 39 | void notify_all() { this->cv_.notify_all(); } 40 | 41 | protected: 42 | std::unique_ptr owned_mutex_; 43 | std::condition_variable cv_; 44 | Mutex* mutex_; 45 | }; 46 | 47 | #elif USE_WIN_THREAD 48 | class Monitor { 49 | public: 50 | Monitor() : owned_mutex_(new Mutex()) { 51 | this->mutex_ = this->owned_mutex_.get(); 52 | InitializeConditionVariable(&cv_); 53 | } 54 | Monitor(Mutex& mutex) { 55 | this->mutex_ = &mutex; 56 | 57 | InitializeConditionVariable(&cv_); 58 | } 59 | 60 | void lock() { this->mutex_->lock(); } 61 | void unlock() { this->mutex_->unlock(); } 62 | // the caller must already hold the lock; it is held again when wait() returns 63 | void wait() { 64 | SleepConditionVariableCS(&cv_, &(this->mutex_->mutex()), INFINITE); 65 | } 66 | void notify() { WakeConditionVariable(&cv_); } 67 | void notify_all() { WakeAllConditionVariable(&cv_); } 68 | 69 | protected: 70 | std::unique_ptr owned_mutex_; 71 | CONDITION_VARIABLE cv_; 72 | Mutex* mutex_; 73 | }; 74 | 75 | #elif USE_PTHREAD 76 | 77 | class Monitor { 78 | public: 79 | Monitor() : owned_mutex_(new Mutex()) { 80 | this->mutex_ = this->owned_mutex_.get(); 81 | pthread_cond_init(&cv_, nullptr); 82 | } 83 | Monitor(Mutex& mutex) { 84 | this->mutex_ = &mutex; 85 | pthread_cond_init(&cv_, nullptr); 86 | } 87 | ~Monitor() { pthread_cond_destroy(&cv_); } // release the condition variable created in the constructor 88 | void lock() { this->mutex_->lock(); } 89 | void unlock() { 
this->mutex_->unlock(); } 90 | // the caller must already hold the lock; it is held again when wait() returns 91 | void wait() { pthread_cond_wait(&cv_, &(this->mutex_->mutex())); } 92 | void notify() { pthread_cond_signal(&cv_); } 93
| void notify_all() { pthread_cond_broadcast(&cv_); } 94 | 95 | protected: 96 | std::unique_ptr owned_mutex_; 97 | pthread_cond_t cv_; 98 | Mutex* mutex_; 99 | }; 100 | 101 | #endif  // USE_STD_THREAD / USE_WIN_THREAD / USE_PTHREAD 102 | 103 | } // namespace sol 104 | 105 | #endif  // CXX_SELF_CUSTOMIZED_MONITOR_H__ 106 | -------------------------------------------------------------------------------- /include/sol/util/mutex.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : mutex.h 3 | * Created By : yuewu 4 | * Creation Date : [2016-05-14 15:09] 5 | * Last Modified : [2016-05-14 23:29] 6 | * Description : mutex wrapper with std, win32 and pthread backends 7 | **********************************************************************************/ 8 | 9 | #ifndef CXX_SELF_CUSTOMIZED_MUTEX_H__ 10 | #define CXX_SELF_CUSTOMIZED_MUTEX_H__ 11 | 12 | #if USE_WIN_THREAD 13 | #include 14 | #elif USE_PTHREAD 15 | #include 16 | #else 17 | #include 18 | #ifndef USE_STD_THREAD 19 | #define USE_STD_THREAD 1 20 | #endif 21 | #endif 22 | 23 | namespace sol { 24 | 25 | #if USE_STD_THREAD 26 | class Mutex { 27 | public: 28 | Mutex() {} 29 | 30 | void lock() { this->mutex_.lock(); } 31 | void unlock() { this->mutex_.unlock(); } 32 | 33 | std::mutex& mutex() { return mutex_; } 34 | 35 | protected: 36 | std::mutex mutex_; 37 | }; 38 | 39 | #elif USE_WIN_THREAD 40 | class Mutex { 41 | public: 42 | Mutex() { InitializeCriticalSection(&mutex_); } 43 | ~Mutex() { DeleteCriticalSection(&mutex_); } 44 | Mutex(const Mutex&) = delete; Mutex& operator=(const Mutex&) = delete; // owns a CRITICAL_SECTION; a copy would delete it twice 45 | void lock() { EnterCriticalSection(&mutex_); } 46 | void unlock() { LeaveCriticalSection(&mutex_); } 47 | 48 | CRITICAL_SECTION& mutex() { return mutex_; } 49 | 50 | protected: 51 | CRITICAL_SECTION mutex_; 52 | }; 53 | 54 | #elif USE_PTHREAD 55 | 56 | class Mutex { 57 | public: 58 | 
Mutex() { pthread_mutex_init(&mutex_, nullptr); } 59 | ~Mutex() { pthread_mutex_destroy(&mutex_); } // release the mutex created in the constructor 60 | void lock() { pthread_mutex_lock(&mutex_); } 61 | void unlock() { pthread_mutex_unlock(&mutex_); } 62 | 63 | pthread_mutex_t& mutex() { return
mutex_; } 64 | Mutex(const Mutex&) = delete; Mutex& operator=(const Mutex&) = delete; // owns a pthread mutex; a copy would destroy it twice 65 | protected: 66 | pthread_mutex_t mutex_; 67 | }; 68 | 69 | #endif  // USE_STD_THREAD / USE_WIN_THREAD / USE_PTHREAD 70 | 71 | } // namespace sol 72 | 73 | #endif  // CXX_SELF_CUSTOMIZED_MUTEX_H__ 74 | -------------------------------------------------------------------------------- /include/sol/util/platform_win32.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : platform_win32.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-23 13:33] 5 | * Last Modified : [2015-12-03 14:49] 6 | * Description : platform specific functions for windows 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_UTIL_PLATFORM_WIN32_H__ 10 | #define SOL_UTIL_PLATFORM_WIN32_H__ 11 | 12 | #include 13 | 14 | namespace sol { 15 | /// \brief Open a file with fopen_s; returns nullptr if fopen_s reports an error 16 | inline FILE* open_file(const char* path, const char* mode) { 17 | FILE* file; 18 | errno_t ret = fopen_s(&file, path, mode); 19 | if (ret != 0) { 20 | return nullptr; 21 | } 22 | return file; 23 | } 24 | 25 | } // namespace sol 26 | 27 | #endif // SOL_UTIL_PLATFORM_WIN32_H__ 28 | -------------------------------------------------------------------------------- /include/sol/util/platform_xnix.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : platform_xnix.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-23 13:37] 5 | * Last Modified : [2015-12-03 14:49] 6 | * Description : platform specific functions for linux/unix 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_UTIL_PLATFORM_XNIX_H__ 10 | #define SOL_UTIL_PLATFORM_XNIX_H__ 11 | // NOTE(review): no 
include here; FILE/fopen must already be declared by the including header (TODO confirm) 12 | namespace sol { 13 | /// \brief Open a file with fopen; returns nullptr on failure 14 | inline FILE* open_file(const char* path, const char* mode) { 15 | return fopen(path, mode); 16 | } 17 | 18 | } // namespace sol 19 | 20 | #endif  // SOL_UTIL_PLATFORM_XNIX_H__ 21 |
-------------------------------------------------------------------------------- /include/sol/util/reflector.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | > File Name: reflector.h 3 | > Copyright (C) 2014 Yue Wu 4 | > Created Time: 2014/5/12 Monday 16:21:00 5 | > Functions: C++ reflector 6 | ************************************************************************/ 7 | #ifndef CXX_SELF_CUSTOMIZED_RELFECTOR_H__ 8 | #define CXX_SELF_CUSTOMIZED_RELFECTOR_H__ 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | namespace sol { 18 | 19 | /// \brief Information of a class: its name, creator function (stored as void*), and description 20 | class SOL_EXPORTS ClassInfo { 21 | public: 22 | ClassInfo(const std::string& name, void* func, const std::string& descr = ""); 23 | 24 | public: 25 | const std::string& name() const { return this->name_; } 26 | void* create_func() const { return this->create_func_; } 27 | const std::string& descr() const { return this->descr_; } 28 | 29 | private: 30 | std::string name_; 31 | void* create_func_; 32 | std::string descr_; 33 | }; 34 | 35 | /// \brief Registry of ClassInfo entries, looked up by class name 36 | class SOL_EXPORTS ClassFactory { 37 | public: 38 | typedef std::map ClsInfoMapType; 39 | 40 | /// \brief Register a new class in the class dictionary (presumably the 41 | /// map returned by ClassInfoMap()) 42 | /// 43 | /// \param class_info Information about the class, including: 44 | /// 1. Name of the class 45 | /// 2. CreateFunction pointer 46 | /// 3.
Description of the class (optional) 47 | static void Register(ClassInfo* class_info); 48 | 49 | static ClsInfoMapType& ClassInfoMap(); 50 | }; 51 | 52 | /// \brief Look up the creator function registered under a class name 53 | /// 54 | /// \tparam ClsType Type of the class; ReturnType is the type the stored creator is cast to 55 | /// \param cls_name Name of the class; it is lower-cased before the lookup 56 | /// \note Returns the creator function itself, not a created instance 57 | /// 58 | /// \return The creator cast to ReturnType, or ReturnType(nullptr) (after a message on stderr) if no class matches 59 | template 61 | ReturnType CreateObject(const std::string& cls_name) { 62 | auto& cls_info_map = ClassFactory::ClassInfoMap(); // bind by reference: avoid copying the whole registry on every call 63 | const std::string& cls_name2 = lower(cls_name); 64 | auto iter = cls_info_map.find(cls_name2); 65 | if (iter != cls_info_map.end()) { 66 | return (ReturnType((iter->second)->create_func())); 67 | } 68 | fprintf(stderr, "no class named %s\n", cls_name.c_str()); 69 | return ReturnType(nullptr); 70 | } 71 | 72 | #define UniqueClassName(name, suffix) name##suffix 73 | 74 | #define DeclareReflectorBase(type, ...) \ 75 | public: \ 76 | typedef type* (*CreateFunction)(__VA_ARGS__); \ 77 | \ 78 | static type* Create(const std::string& cls_name, ##__VA_ARGS__); \ 79 | \ 80 | private: 81 | #define RegisterClassReflector(type, name, descr) \ 82 | type* type##_##CreateNewInstance() { return new type(); } \ 83 | ClassInfo __kClassInfo_##type##__(name, (void*)(type##_##CreateNewInstance), \ 84 | descr); 85 | } // namespace sol 86 | 87 | #endif // CXX_SELF_CUSTOMIZED_RELFECTOR_H__ 88 | -------------------------------------------------------------------------------- /include/sol/util/str_util.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : str_util.h 3 | * Created By : yuewu 4 | * Creation Date : [2016-02-18 15:55] 5 | * Last Modified : [2016-02-18 23:12] 6 | * Description 
: string related operations 7 |
**********************************************************************************/ 8 | #ifndef SOL_UTIL_STR_UTIL_H__ 9 | #define SOL_UTIL_STR_UTIL_H__ 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace sol { 18 | /// \brief Split str on delim; runs of delim are collapsed, so no empty fields are produced 19 | inline std::vector split(const std::string& str, 20 | char delim = '\t') { 21 | std::vector res; 22 | auto e = str.end(); 23 | auto i = str.begin(); 24 | while (i != e) { 25 | i = std::find_if_not(i, e, [delim](char c) { return c == delim; }); 26 | if (i == e) break; 27 | auto j = std::find_if(i, e, [delim](char c) { return c == delim; }); 28 | res.push_back(std::string(i, j)); 29 | i = j; 30 | } 31 | return res; 32 | } 33 | /// \brief Remove leading and trailing spaces/tabs; interior whitespace is kept 34 | inline std::string strip(const std::string& str) { 35 | const char* kBlank = " \t"; 36 | size_t first = str.find_first_not_of(kBlank); 37 | if (first == std::string::npos) return std::string(); 38 | size_t last = str.find_last_not_of(kBlank); 39 | return str.substr(first, last - first + 1); 40 | } 41 | /// \brief Return a lower-cased copy of str (per-character tolower) 42 | inline std::string lower(const std::string& str) { 43 | std::string res = str; 44 | for (char& c : res) { 45 | c = (char)tolower((unsigned char)c); // tolower's argument must be representable as unsigned char 46 | } 47 | return res; 48 | } 49 | 50 | } // namespace sol 51 | 52 | #endif  // SOL_UTIL_STR_UTIL_H__ 53 | -------------------------------------------------------------------------------- /include/sol/util/thread.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : ../util/thread.h 3 | * Created By : yuewu 4 | * Creation Date : [2016-05-14 00:22] 5 | * Last Modified : [2016-05-14 23:34] 6 | * Description : thread wrapper with std, win32 and pthread backends 7 | 
**********************************************************************************/ 8 | 9 | #ifndef CXX_SELF_CUSTOMIZED_THREAD_H__ 10 | #define CXX_SELF_CUSTOMIZED_THREAD_H__ 11 | 12 | #include 13 | 14 | #if USE_STD_THREAD 15 | #include 16 | #endif 17 | 18 | namespace sol { 19 | 20 | typedef void (*ThreadFuncType)(void*); 21 | 22 | #if USE_STD_THREAD 23 |
class Thread { 24 | public: 25 | Thread(ThreadFuncType start_routine, void* param) { 26 | this->thread_.reset(new std::thread(start_routine, param)); 27 | } 28 | // NOTE(review): call join() before destruction; destroying a still-joinable std::thread calls std::terminate 29 | void join() { this->thread_->join(); } 30 | 31 | protected: 32 | std::unique_ptr thread_; 33 | }; 34 | 35 | #elif USE_WIN_THREAD 36 | 37 | class Thread { 38 | public: 39 | Thread(ThreadFuncType start_routine, void* param) 40 | : thread_func_(start_routine), thread_func_param_(param) { 41 | this->thread_ = 42 | CreateThread(nullptr, 0, Thread::ThreadProxy, this, 0, nullptr); 43 | } 44 | ~Thread() { 45 | TerminateThread(this->thread_, 0); // NOTE(review): forcibly kills the thread if it is still running; join() first 46 | CloseHandle(this->thread_); 47 | } 48 | 49 | void join() { WaitForSingleObject(this->thread_, INFINITE); } 50 | 51 | public: 52 | static DWORD WINAPI ThreadProxy(LPVOID param) { 53 | Thread* instance = (Thread*)(param); 54 | instance->thread_func_(instance->thread_func_param_); 55 | return 0; 56 | } 57 | 58 | protected: 59 | HANDLE thread_; 60 | ThreadFuncType thread_func_; 61 | void* thread_func_param_; 62 | }; 63 | 64 | #elif USE_PTHREAD 65 | 66 | class Thread { 67 | public: 68 | Thread(ThreadFuncType start_routine, void* param) 69 | : thread_func_(start_routine), thread_func_param_(param) { 70 | pthread_create(&thread_, nullptr, Thread::ThreadProxy, this); 71 | } 72 | // NOTE(review): no destructor; a thread that is never join()ed is not detached either, so its resources are not reclaimed 73 | void join() { pthread_join(this->thread_, nullptr); } 74 | 75 | static void* ThreadProxy(void* param) { 76 | Thread* instance = (Thread*)(param); 77 | instance->thread_func_(instance->thread_func_param_); 78 | return nullptr; 79 | } 80 | 81 | protected: 82 | pthread_t thread_; 83 | 84 | ThreadFuncType thread_func_; 85 | void* thread_func_param_; 86 | }; 87 | 88 | #endif  // USE_STD_THREAD / USE_WIN_THREAD / USE_PTHREAD 89 | 90 | } // 
namespace sol 91 | #endif  // CXX_SELF_CUSTOMIZED_THREAD_H__ 92 | -------------------------------------------------------------------------------- /include/sol/util/thread_task.h: -------------------------------------------------------------------------------- 1 | /*********************************************************************************
2 | * File Name : thread_task.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-12-03 14:15] 5 | * Last Modified : [2016-05-14 00:21] 6 | * Description : Task processed in a separate thread 7 | **********************************************************************************/ 8 | 9 | #ifndef SHENTU_UTIL_THREAD_TASK_H__ 10 | #define SHENTU_UTIL_THREAD_TASK_H__ 11 | 12 | #include 13 | #include 14 | 15 | namespace sol { 16 | 17 | /// \brief Base class for a task whose run() executes in a separate thread 18 | class ThreadTask { 19 | public: 20 | void Start() { 21 | if (this->thread_ == nullptr) { 22 | this->thread_.reset(new Thread(ThreadTask::InternalEntry, this)); 23 | } 24 | } 25 | void Join() { // thread_ is kept afterwards, so a later Start() does nothing 26 | if (this->thread_) { 27 | this->thread_->join(); 28 | } 29 | } 30 | virtual ~ThreadTask() {} // polymorphic base: allow deleting derived tasks through a ThreadTask* 31 | protected: 32 | virtual void run() = 0; 33 | static void InternalEntry(void* task) { ((ThreadTask*)task)->run(); } 34 | 35 | protected: 36 | std::unique_ptr thread_; 37 | }; // class ThreadTask 38 | 39 | } // namespace sol 40 | #endif  // SHENTU_UTIL_THREAD_TASK_H__ 41 | -------------------------------------------------------------------------------- /include/sol/util/types.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : types.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-29 11:23] 5 | * Last Modified : [2015-12-03 14:49] 6 | * Description : Global configurations: export macro and feature/index/label types 7 | **********************************************************************************/ 8 | 9 | #ifndef SOL_UTIL_GLOBAL_CONFIG_H__ 10 | #define SOL_UTIL_GLOBAL_CONFIG_H__ 11 | 12 | #include 13 | #include 14 | 15 | namespace sol { 16 | 17 | #ifndef SOL_EMBED_PACKAGE 18 | #if (defined WIN32 || defined _WIN32 || defined WINCE) 19 | #ifdef SOL_EXPORTS 20 | #undef SOL_EXPORTS 21 | #define 
SOL_EXPORTS __declspec(dllexport) 22 | #else 23 | #define SOL_EXPORTS __declspec(dllimport) 24 | #endif 25 | #else 26 | #undef SOL_EXPORTS 27 | #define
SOL_EXPORTS 28 | #endif 29 | #else 30 | #undef SOL_EXPORTS 31 | #define SOL_EXPORTS 32 | #endif 33 | // default element types; each can be overridden by defining the macro before this header is included 34 | #ifndef FeatType 35 | #define FeatType float 36 | #endif 37 | 38 | #ifndef IndexType 39 | #define IndexType uint32_t 40 | #endif 41 | 42 | #ifndef LabelType 43 | #define LabelType int32_t 44 | #endif 45 | 46 | /// \brief only float or double type are allowed for features 47 | template 48 | struct feat_type_traits { 49 | typedef typename std::enable_if::value || 50 | std::is_same::value, 51 | T>::type type; 52 | }; 53 | 54 | typedef feat_type_traits::type real_t; 55 | 56 | /// \brief only uint16_t, uint32_t or uint64_t type are allowed for feature indices 57 | template 58 | struct index_type_traits { 59 | typedef typename std::enable_if::value || 60 | std::is_same::value || 61 | std::is_same::value, 62 | T>::type type; 63 | }; 64 | 65 | typedef index_type_traits::type index_t; 66 | static const index_t invalid_index = static_cast(-1); // all bits set: the largest index_t value 67 | 68 | /// \brief only char, short, int32, int64 are allowed for labels 69 | template 70 | struct label_type_traits { 71 | typedef typename std::enable_if< 72 | std::is_same::value || std::is_same::value || 73 | std::is_same::value || std::is_same::value, 74 | T>::type type; 75 | }; 76 | 77 | typedef label_type_traits::type label_t; 78 | 79 | } // namespace sol 80 | 81 | #endif  // SOL_UTIL_GLOBAL_CONFIG_H__ 82 | -------------------------------------------------------------------------------- /include/sol/util/util.h: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : util.h 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-17 11:06] 5 | * Last Modified : [2016-02-12 17:51] 6 | * Description : utility macros and platform specific functions 7 | **********************************************************************************/ 8 | 9 | #ifndef 
SOL_UTIL_PLATFORM_H__ 10 | #define SOL_UTIL_PLATFORM_H__ 11 | 12 | #include 13 |
#include 14 | #include 15 | #include 16 | #include 17 | 18 | /// \brief declaration of functions 19 | namespace sol { 20 | 21 | #define DeletePointer(p) \ 22 | if ((p) != nullptr) { \ 23 | delete (p); \ 24 | (p) = nullptr; \ 25 | } 26 | 27 | #define DeleteArray(p) \ 28 | if ((p) != nullptr) { \ 29 | delete[](p); \ 30 | (p) = nullptr; \ 31 | } 32 | 33 | #define DISABLE_COPY_AND_ASSIGN(classname) \ 34 | private: \ 35 | classname(const classname&); \ 36 | classname& operator=(const classname&); 37 | 38 | // check if the argument is valid and throw exception otherwise 39 | #define Check(condition) \ 40 | if ((condition) == false) { \ 41 | std::ostringstream oss; \ 42 | oss << "Check " << #condition << " failed at line " << __LINE__ << " of " \ 43 | << __FILE__; \ 44 | throw std::invalid_argument(oss.str()); \ 45 | } 46 | 47 | /// \brief Open file wrapper, windows use fopen_s for safety 48 | /// 49 | /// \param path Path to file 50 | /// \param mode Open mode 51 | /// 52 | /// \return FILE pointer 53 | inline FILE* open_file(const char* path, const char* mode); 54 | 55 | /// \brief delete a file 56 | /// 57 | /// \param path File path 58 | /// \param is_force Whether prompt when file not exist 59 | inline void delete_file(const char* path, bool is_force = false) { 60 | if (remove(path) != 0 && is_force == false) { 61 | fprintf(stderr, "warnning, remove file %s failed!\n", path); 62 | } 63 | } 64 | 65 | /// \brief get current time, in seconds 66 | /// 67 | /// \return seconds 68 | inline double get_current_time() { 69 | return std::chrono::duration_cast( 70 | std::chrono::system_clock::now().time_since_epoch()).count() * 71 | 0.001; 72 | } 73 | 74 | } // namespace sol 75 | 76 | #if _WIN32 77 | #include "sol/util/platform_win32.h" 78 | #else 79 | #include "sol/util/platform_xnix.h" 80 | #endif 81 | 82 | #endif 83 | -------------------------------------------------------------------------------- /ofs/README.md: 
-------------------------------------------------------------------------------- 1 | Experimental Scripts for Large-scale Online Feature Selection for Ultra-high Dimensional Sparse Data 2 | ================================================================ 3 | 4 | The python scripts in this folder are for the following paper: 5 | 6 | Yue Wu, Steven C.H. Hoi, Tao Mei, and Nenghai Yu. 2017. Large-scale Online Feature Selection 7 | for Ultra-high Dimensional Sparse Data. ACM Transactions on Knowledge Discovery from Data. 8 | 9 | 10 | # Installation 11 | 12 | 1. Install the SOL python scripts. Refer to the SOL documentation for details. 13 | 14 | 2. Some external packages you may need to install for full set of experiments: 15 | 16 | + [fast-mRMR](https://github.com/sramirez/fast-mRMR) 17 | 18 | 1. Comile the cpu version for general mRMR algorithm. 19 | 20 | 2. If you have a Nvidia GPU card, compile the gpu version. 21 | 22 | 3. Compile the data-reader in utils. **Change the arguments of the main 23 | function as follows, so that the output file is specified manually.** 24 | 25 | line 77: ofstream outputFile(argv[2], ios::out | ios::binary); 26 | 27 | 4. Add the **fast-mrmr**, **gpu-mrmr**, and **mrmr-reader** to the 28 | system path. 29 | 30 | + [FGM](http://www.tanmingkui.com/fgm.html) 31 | 32 | Compile the code and add the executable **FGM** to the system path. 33 | 34 | # Experiments 35 | 36 | The configuarions for the datasets are in the **"opts"** folder. 37 | For example, to compare performance on the 'aut' dataset, you can simpy run: 38 | 39 | python fs.py aut /data/sol/aut/aut_train /data/sol/aut/aut_test 40 | 41 | The results and figures will be saved to the folder **"cache/aut"**. 
42 | -------------------------------------------------------------------------------- /ofs/fig.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | # AUTHOR: 3 | # FILE: fig.py 4 | # ROLE: TODO (some explanation) 5 | # CREATED: 2015-05-16 23:37:47 6 | # MODIFIED: 2015-05-16 23:37:47 7 | 8 | import matplotlib.pyplot as plt 9 | import logging 10 | 11 | plt.rc('pdf', fonttype=42) 12 | 13 | 14 | def plot(xs, ys, 15 | x_label, 16 | y_label, 17 | legends, 18 | output_path, 19 | line_width=3, 20 | marker_size=12, 21 | xlim=None, 22 | ylim=None, 23 | xtickers=None, 24 | logx=False, 25 | logy=False, 26 | clip_on=False, 27 | fontsize=18, 28 | legend_cols=2, 29 | legend_order=201, 30 | legend_loc='best', 31 | bbox_to_anchor=None, 32 | draw_legend=True): 33 | 34 | color_list = ['r', 'm', 'k', 'b', (0.12, 0.56, 1), (0.58, 0.66, 0.2), (0.48, 0.41, 0.93), 35 | (0, 0.75, 0.75)] 36 | marker_list = ['s', 'h', '*', u'o', 'd', '^', 'v', '<', '>'] 37 | #line_styles=['-','--'] 38 | 39 | c_ind = 0 40 | m_ind = 0 41 | fig = plt.figure() 42 | ax = fig.add_subplot(1, 1, 1) 43 | lines = [] 44 | if logx is True and logy is True: 45 | plot_handler = ax.loglog 46 | elif logx is True and logy is False: 47 | plot_handler = ax.semilogx 48 | elif logx is False and logy is True: 49 | plot_handler = ax.semilogy 50 | else: 51 | plot_handler = ax.plot 52 | 53 | for i in xrange(len(xs)): 54 | zorder = 200 - i 55 | color = color_list[c_ind % len(color_list)] 56 | marker = marker_list[m_ind % len(marker_list)] 57 | c_ind += 1 58 | m_ind += 1 59 | if xlim != None: 60 | x_values = [] 61 | y_values = [] 62 | for k in xrange(len(xs[i])): 63 | if xs[i][k] >= xlim[0] and xs[i][k] <= xlim[1]: 64 | x_values.append(xs[i][k]) 65 | y_values.append(ys[i][k]) 66 | else: 67 | x_values = xs[i] 68 | y_values = ys[i] 69 | line, = plot_handler(x_values, y_values, 70 | color=color, 71 | marker=marker, 72 | linestyle='-', 73 | clip_on=clip_on, 74 | 
markersize=marker_size, 75 | linewidth=line_width, 76 | fillstyle='full', 77 | zorder=zorder) 78 | lines.append(line) 79 | 80 | if xtickers != None: 81 | ax.set_xticks(xtickers) 82 | if xlim != None: 83 | ax.set_xlim(xlim) 84 | if ylim != None: 85 | ax.set_ylim(ylim) 86 | 87 | ax.grid() 88 | if draw_legend: 89 | l = ax.legend(lines,legends,loc=legend_loc,ncol=legend_cols) 90 | l.set_zorder(legend_order) 91 | if bbox_to_anchor != None: 92 | l.set_bbox_to_anchor(bbox_to_anchor) 93 | 94 | plt.xlabel(x_label,fontsize=fontsize) 95 | plt.ylabel(y_label,fontsize=fontsize) 96 | #plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0)) 97 | 98 | plt.savefig(output_path,bbox_inches='tight') 99 | logging.info('figure saved to %s' %(output_path)) 100 | #plt.show() 101 | -------------------------------------------------------------------------------- /ofs/opts/aut.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | ################################################################################# 3 | # File Name : aut.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2017-03-14 10:31] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | norm_search = ['L2', 'None'] 19 | 20 | dim = 20072 21 | fs_num = (np.array([0.025,0.05, 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.95]) * dim).astype(np.int) 22 | 23 | fs_opts = collections.OrderedDict() 24 | 25 | fs_opts['SOFS'] = { 26 | 'cv':{'r': r_search, 'norm':norm_search}, 27 | 'lambda': fs_num 28 | } 29 | fs_opts['PET'] = { 30 | 'params':{'power_t':'0'}, 31 | 
'cv':{'eta':eta_search, 'norm':norm_search}, 32 | 'lambda': fs_num 33 | } 34 | fs_opts['FOFS'] = { 35 | 'cv':{'eta': const_eta_search, 'lambda': delta_ofs_search, 'norm':norm_search}, 36 | 'lambda': fs_num 37 | } 38 | fs_opts['FGM'] = { 39 | 'lambda': fs_num 40 | } 41 | fs_opts['liblinear'] = { 42 | 'lambda': [0.015625,0.03125,0.0625,0.125,0.25,0.5,1,2,512,1024,2048,4096,8192,16384] 43 | } 44 | fs_opts['mRMR'] = { 45 | 'params':{'binary_thresh':0.5}, 46 | 'lambda': fs_num 47 | } 48 | fs_opts['GPU-mRMR'] = { 49 | 'params':{'binary_thresh':0.5}, 50 | 'lambda': fs_num 51 | } 52 | 53 | draw_opts = { 54 | 'accu':{ 55 | 'clip_on':True, 56 | 'ylim':[0.93,0.985], 57 | 'legend_loc':'lower right', 58 | 'bbox_to_anchor':(1,0.15), 59 | }, 60 | 'time': { 61 | 'logy': True, 62 | #'xlim':[0,25000], 63 | 'legend_loc':'center', 64 | 'bbox_to_anchor':(0.7,0.65), 65 | 'legend_order':100, 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /ofs/opts/basehock.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | ################################################################################# 3 | # File Name : synthetic_100k.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2016-12-05 19:42] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | norm_search = ['L2', 'None'] 19 | 20 | dim = 4862 21 | fs_num = (np.array([0.05,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.95]) * dim).astype(np.int) 22 | 23 | fs_opts = collections.OrderedDict() 24 | 25 | fs_opts['SOFS'] = { 26 | 'cv':{'r': r_search, 'norm':norm_search}, 27 | 'lambda': fs_num 28 | } 29 | fs_opts['PET'] = { 30 | 'params':{'power_t':'0'}, 31 | 'cv':{'eta':eta_search, 'norm':norm_search}, 32 | 'lambda': fs_num 33 | } 34 | fs_opts['FOFS'] = { 35 | 'cv':{'eta': const_eta_search, 'lambda': delta_ofs_search, 'norm':norm_search}, 36 | 'lambda': fs_num 37 | } 38 | fs_opts['FGM'] = { 39 | 'lambda': fs_num 40 | } 41 | fs_opts['liblinear'] = { 42 | 'lambda': [0.015625,0.0625,0.25,128,512,1024,2048,4096,9182] 43 | } 44 | fs_opts['mRMR'] = { 45 | 'lambda': fs_num 46 | } 47 | fs_opts['GPU-mRMR'] = { 48 | 'lambda': fs_num 49 | } 50 | 51 | draw_opts = { 52 | 'accu':{}, 53 | 'time': { 54 | 'logy': True, 55 | 'bbox_to_anchor':(0.4,0.3), 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /ofs/opts/ccat.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | ################################################################################# 3 | # File Name : synthetic_100k.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2016-12-05 19:53] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | norm_search = ['L2', 'None'] 19 | 20 | dim = 47236 21 | fs_num = (np.array([0.01, 0.025,0.05,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.95]) * dim).astype(np.int) 22 | 23 | fs_opts = collections.OrderedDict() 24 | 25 | fs_opts['SOFS'] = { 26 | 'cv':{'r': r_search, 'norm':norm_search}, 27 | 'lambda': fs_num 28 | } 29 | fs_opts['PET'] = { 30 | 'params':{'power_t':'0'}, 31 | 'cv':{'eta':eta_search, 'norm':norm_search}, 32 | 'lambda': fs_num 33 | } 34 | fs_opts['FOFS'] = { 35 | 'cv':{'eta': const_eta_search, 'lambda': delta_ofs_search, 'norm':norm_search}, 36 | 'lambda': fs_num 37 | } 38 | fs_opts['FGM'] = { 39 | 'lambda': fs_num 40 | } 41 | fs_opts['liblinear'] = { 42 | 'lambda': [ 0.015625,0.03125,0.0625,0.125, 0.5,2,128,512,2048,4096,16384,131072,262144,524288] 43 | } 44 | fs_opts['mRMR'] = { 45 | 'params':{'binary_thresh':0.5}, 46 | 'lambda': fs_num 47 | } 48 | fs_opts['GPU-mRMR'] = { 49 | 'params':{'binary_thresh':0.5}, 50 | 'lambda': fs_num 51 | } 52 | 53 | draw_opts = { 54 | 'accu':{ 55 | 'ylim':[0.82, 0.94], 56 | 'xlim':[0, 25000], 57 | 'clip_on':True, 58 | 'legend_loc':'lower right', 59 | 'bbox_to_anchor':(1,0), 60 | }, 61 | 'time': { 62 | 'logy': True, 63 | 'xlim':[0,25000], 64 | 'clip_on':True, 65 | 'legend_loc':'center', 66 | 'bbox_to_anchor':(0.7,0.68), 67 | } 68 | } 69 | 
-------------------------------------------------------------------------------- /ofs/opts/pcmac.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | ################################################################################# 3 | # File Name : synthetic_100k.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2017-03-14 10:31] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | norm_search = ['L2', 'None'] 19 | 20 | dim = 7510 21 | fs_num = (np.array([0.025,0.05,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.95]) * dim).astype(np.int) 22 | 23 | fs_opts = collections.OrderedDict() 24 | 25 | fs_opts['SOFS'] = { 26 | 'cv':{'r': r_search, 'norm':norm_search}, 27 | 'lambda': fs_num 28 | } 29 | fs_opts['PET'] = { 30 | 'params':{'power_t':'0'}, 31 | 'cv':{'eta':eta_search, 'norm':norm_search}, 32 | 'lambda': fs_num 33 | } 34 | fs_opts['FOFS'] = { 35 | 'cv':{'eta': const_eta_search, 'lambda': delta_ofs_search, 'norm':norm_search}, 36 | 'lambda': fs_num 37 | } 38 | fs_opts['FGM'] = { 39 | 'lambda': fs_num 40 | } 41 | fs_opts['liblinear'] = { 42 | 'lambda': [0.0625,0.25,128,2048,4096,8192,16384] 43 | } 44 | fs_opts['mRMR'] = { 45 | 'params':{'binary_thresh':0.5}, 46 | 'lambda': fs_num 47 | } 48 | fs_opts['GPU-mRMR'] = { 49 | 'params':{'binary_thresh':0.5}, 50 | 'lambda': fs_num 51 | } 52 | draw_opts = { 53 | 'accu':{ 54 | 'ylim':[0.92,0.98], 55 | 'clip_on':True, 56 | 'legend_loc':'lower right', 57 | }, 58 | 'time': { 59 | 'logy': True, 60 | 'legend_order':100, 61 | 'ylim':[0.01, 20000] 62 | } 63 | } 64 | 
-------------------------------------------------------------------------------- /ofs/opts/real-sim.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | ################################################################################# 3 | # File Name : synthetic_100k.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2016-12-05 19:52] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | norm_search = ['L2', 'None'] 19 | 20 | dim = 20958 21 | fs_num = (np.array([0.025,0.05, 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.95]) * dim).astype(np.int) 22 | 23 | fs_opts = collections.OrderedDict() 24 | 25 | fs_opts['SOFS'] = { 26 | 'cv':{'r': r_search, 'norm':norm_search}, 27 | 'lambda': fs_num 28 | } 29 | fs_opts['PET'] = { 30 | 'params':{'power_t':'0'}, 31 | 'cv':{'eta':eta_search, 'norm':norm_search}, 32 | 'lambda': fs_num 33 | } 34 | fs_opts['FOFS'] = { 35 | 'cv':{'eta': const_eta_search, 'lambda': delta_ofs_search, 'norm':norm_search}, 36 | 'lambda': fs_num 37 | } 38 | fs_opts['FGM'] = { 39 | 'lambda': fs_num 40 | } 41 | fs_opts['liblinear'] = { 42 | 'lambda': [0.015625,0.03125,0.0625,0.125,0.25,0.5,1,2, 43 | 64,512,1024,2048,8192,16384,131072] 44 | } 45 | fs_opts['mRMR'] = { 46 | 'params':{'binary_thresh':0.5}, 47 | 'lambda': fs_num 48 | } 49 | fs_opts['GPU-mRMR'] = { 50 | 'params':{'binary_thresh':0.5, 'device_id':0}, 51 | 'lambda': fs_num 52 | } 53 | 54 | draw_opts = { 55 | 'accu':{ 56 | 'ylim':[0.9, 0.98], 57 | 'clip_on':True, 58 | 'legend_loc':'lower right', 59 | 'bbox_to_anchor':(1,0.1), 60 | }, 
61 | 'time': { 62 | 'logy': True, 63 | 'legend_loc':'center', 64 | 'bbox_to_anchor':(0.7,0.65), 65 | 'legend_order':100, 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /ofs/opts/relathe.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | ################################################################################# 3 | # File Name : synthetic_100k.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2016-12-05 19:45] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | norm_search = ['L2', 'None'] 19 | 20 | dim = 4322 21 | fs_num = (np.array([0.05, 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.95,0.99]) * dim).astype(np.int) 22 | 23 | fs_opts = collections.OrderedDict() 24 | 25 | fs_opts['SOFS'] = { 26 | 'cv':{'r': r_search, 'norm':norm_search}, 27 | 'lambda': fs_num 28 | } 29 | fs_opts['PET'] = { 30 | 'params':{'power_t':'0'}, 31 | 'cv':{'eta':eta_search, 'norm':norm_search}, 32 | 'lambda': fs_num 33 | } 34 | fs_opts['FOFS'] = { 35 | 'cv':{'eta': const_eta_search, 'lambda': delta_ofs_search, 'norm':norm_search}, 36 | 'lambda': fs_num 37 | } 38 | fs_opts['FGM'] = { 39 | 'lambda': fs_num 40 | } 41 | fs_opts['liblinear'] = { 42 | 'lambda': [0.015625,0.03125,0.0625,0.125,64,128,512,1024,2048,4096,9182] 43 | } 44 | fs_opts['mRMR'] = { 45 | 'lambda': fs_num 46 | } 47 | fs_opts['GPU-mRMR'] = { 48 | 'lambda': fs_num 49 | } 50 | 51 | draw_opts = { 52 | 'accu':{ 53 | 'ylim':[0.73,0.9], 54 | }, 55 | 'time': { 56 | 'logy': True, 57 | } 58 | } 59 | 
-------------------------------------------------------------------------------- /ofs/opts/synthetic_100k.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | ################################################################################# 3 | # File Name : synthetic_100k.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2016-12-06 21:34] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | 19 | fs_num = [50,60,70,80,90,100, 120,140,160,180,200] 20 | 21 | fs_opts = collections.OrderedDict() 22 | 23 | fs_opts['SOFS'] = { 24 | 'params':{'norm':'L2'}, 25 | 'cv':{'r': r_search}, 26 | 'lambda': fs_num 27 | } 28 | fs_opts['PET'] = { 29 | 'params':{'norm':'L2', 'power_t':'0'}, 30 | 'cv':{'eta':eta_search}, 31 | 'lambda': fs_num 32 | } 33 | fs_opts['FOFS'] = { 34 | 'params':{'norm':'L2'}, 35 | 'cv':{'eta': const_eta_search, 'lambda': delta_ofs_search}, 36 | 'lambda': fs_num 37 | } 38 | fs_opts['FGM'] = { 39 | 'lambda': fs_num 40 | } 41 | 42 | fs_opts['liblinear'] = { 43 | 'lambda': [0.0001,0.00015, 0.0002, 0.00025,0.0005,0.001,0.01, 0.018, 0.02, 0.022, 0.023, 0.024, 0.025] 44 | } 45 | 46 | fs_opts['mRMR'] = { 47 | 'params':{'binary_thresh':0}, 48 | 'lambda': fs_num 49 | } 50 | fs_opts['GPU-mRMR'] = { 51 | 'params':{'binary_thresh':0}, 52 | 'lambda': fs_num 53 | } 54 | #fs_opts['AROW'] = { 55 | # 'params':{'norm':'L2'}, 56 | # 'cv':{'r': r_search}, 57 | # 'lambda': [-1] 58 | #} 59 | #fs_opts['OGD'] = { 60 | # 'params':{'norm':'L2'}, 61 | # 'cv':{'eta':eta_search}, 62 | # 'lambda': [-1] 63 | #} 64 | 
65 | draw_opts = { 66 | 'accu':{ 67 | 'xlim':[49, 200], 68 | 'ylim':[0.7, 1] 69 | }, 70 | 'time': { 71 | 'logy': True, 72 | 'xlim':[49, 200], 73 | 'ylim':[2, 1000], 74 | 'legend_order':100, 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /ofs/opts/synthetic_1m.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | ################################################################################# 3 | # File Name : synthetic_100k.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2016-12-04 20:45] 7 | # Description : 8 | ################################################################################# 9 | import numpy as np 10 | import collections 11 | 12 | const_eta_search = np.logspace(-5, 5, 11, base=2) 13 | eta_search = np.logspace(-2, 8, 11, base=2) 14 | delta_search = np.logspace(-5, 5,11, base=2) 15 | r_search = np.logspace(-5, 5, 11, base=2) 16 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 17 | 18 | fs_num = [500] 19 | 20 | fs_opts = collections.OrderedDict() 21 | 22 | fs_opts['SOFS'] = { 23 | 'params':{'norm':'L2'}, 24 | 'cv':{'r': r_search}, 25 | 'lambda': fs_num 26 | } 27 | fs_opts['AROW'] = { 28 | 'params':{'norm':'L2'}, 29 | 'cv':{'r': r_search}, 30 | 'lambda': [-1] 31 | } 32 | fs_opts['OGD'] = { 33 | 'params':{'norm':'L2'}, 34 | 'cv':{'eta':eta_search}, 35 | 'lambda': [-1] 36 | } 37 | -------------------------------------------------------------------------------- /ofs/opts/synthetic_200k.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | ################################################################################# 3 | # File Name : synthetic_100k.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-10-25 11:21] 6 | # Last Modified : [2016-12-07 10:41] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | 19 | fs_num = [150,160,170,180,190,200, 220,240,260,280,300] 20 | 21 | fs_opts = collections.OrderedDict() 22 | 23 | fs_opts['SOFS'] = { 24 | 'params':{'norm':'L2'}, 25 | 'cv':{'r': r_search}, 26 | 'lambda': fs_num 27 | } 28 | fs_opts['PET'] = { 29 | 'params':{'norm':'L2', 'power_t':'0'}, 30 | 'cv':{'eta':eta_search}, 31 | 'lambda': fs_num 32 | } 33 | fs_opts['FOFS'] = { 34 | 'params':{'norm':'L2'}, 35 | 'cv':{'eta': const_eta_search, 'lambda': delta_ofs_search}, 36 | 'lambda': fs_num 37 | } 38 | fs_opts['FGM'] = { 39 | 'lambda': fs_num 40 | } 41 | fs_opts['liblinear'] = { 42 | 'lambda': [0.0002,0.0003,0.0004,0.0008,0.01,0.015,0.016,0.017,0.018,0.019,0.02] 43 | } 44 | fs_opts['mRMR'] = { 45 | 'params':{'binary_thresh':0}, 46 | 'lambda': fs_num 47 | } 48 | fs_opts['GPU-mRMR'] = { 49 | 'params':{'binary_thresh':0}, 50 | 'lambda': fs_num 51 | } 52 | 53 | #fs_opts['AROW'] = { 54 | # 'params':{'norm':'L2'}, 55 | # 'cv':{'r': r_search}, 56 | # 'lambda': [-1] 57 | #} 58 | #fs_opts['OGD'] = { 59 | # 'params':{'norm':'L2'}, 60 | # 'cv':{'eta':eta_search}, 61 | # 'lambda': [-1] 62 | #} 63 | 64 | draw_opts = { 65 | 'accu':{ 66 | 'xlim':[149, 300], 67 | #'ylim':[0.73, 1] 68 | }, 69 | 'time': { 70 | 'logy': True, 71 | 'xlim':[149, 300], 72 | 'ylim':[4, 3000] 73 | } 74 | } 75 | 
-------------------------------------------------------------------------------- /ofs/opts/voc2007.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | ################################################################################# 3 | # File Name : voc2007.py 4 | # Created By : yuewu 5 | # Creation Date : [2016-11-29 12:35] 6 | # Last Modified : [2016-12-06 22:07] 7 | # Description : 8 | ################################################################################# 9 | 10 | import numpy as np 11 | import collections 12 | 13 | const_eta_search = np.logspace(-5, 5, 11, base=2) 14 | eta_search = np.logspace(-2, 8, 11, base=2) 15 | delta_search = np.logspace(-5, 5,11, base=2) 16 | r_search = np.logspace(-5, 5, 11, base=2) 17 | delta_ofs_search = np.logspace(-5, 5, 11, base=2) / 100.0 18 | norm_search = ['L2', 'None'] 19 | 20 | dim = 8192 21 | passes = 1 22 | fs_num = (np.array([0.01, 0.05,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]) * dim).astype(np.int) 23 | #fs_num = (np.array([0.1,0.2,0.3,0.4,0.5,0.6]) * dim).astype(np.int) 24 | 25 | fs_opts = collections.OrderedDict() 26 | 27 | fs_opts['SOFS'] = { 28 | 'params':{'norm':'L2'}, 29 | 'cv':{'r': r_search}, 30 | 'lambda': fs_num 31 | } 32 | fs_opts['PET'] = { 33 | 'params':{'power_t':0, 'norm':'L2'}, 34 | 'cv':{'eta':const_eta_search}, 35 | 'lambda': fs_num 36 | } 37 | fs_opts['FOFS'] = { 38 | 'params':{'norm':'L2'}, 39 | 'cv':{'eta': const_eta_search, 'lambda': delta_ofs_search}, 40 | 'lambda': fs_num 41 | } 42 | fs_opts['mRMR'] = { 43 | 'params': { 44 | 'binary_thresh':0.5, 45 | 'ol_model_params': {'power_t':0, 'norm':'L2'}, 46 | }, 47 | 'lambda': fs_num 48 | } 49 | fs_opts['GPU-mRMR'] = { 50 | 'params': { 51 | 'binary_thresh':0.5, 52 | 'ol_model_params': {'power_t':0, 'norm':'L2'}, 53 | }, 54 | 'lambda': fs_num 55 | } 56 | # 57 | # 'liblinear': { 58 | # 'lambda': [0.015625,0.03125,0.0625,0.125, 0.5,2,128,512,2048,4096,16384,131072,262144,524288] 59 | 60 | 
draw_opts = { 61 | 'accu':{ 62 | 'ylim':[0.88,0.91], 63 | 'xlim':[500,8000], 64 | 'clip_on':True, 65 | 'legend_loc':'lower right', 66 | }, 67 | 'time': { 68 | 'logy': True, 69 | 'ylim': [3,2000], 70 | 'xlim':[500,8000], 71 | 'legend_order':100, 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /python/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | ################################################################################# 3 | # File Name : ../python/__init__.py 4 | # Created By : yuewu 5 | # Description : 6 | ################################################################################# 7 | 8 | from pysol import SOL 9 | -------------------------------------------------------------------------------- /python/pysol.pxd: -------------------------------------------------------------------------------- 1 | cimport numpy as np 2 | 3 | cdef extern from "sol/c_api.h": 4 | void* sol_CreateDataIter(int batch_size, int buf_size) 5 | void sol_ReleaseDataIter(void** data_iter) 6 | int sol_LoadData(void* data_iter, const char* path, const char* format, int pass_num) 7 | void* sol_CreateModel(const char* name, int class_num) 8 | void* sol_RestoreModel(const char* model_path) 9 | int sol_SaveModel(void* model, const char* model_path) 10 | void sol_ReleaseModel(void** model) 11 | int sol_SetModelParameter(void* model, const char* param_name, const char* param_val) 12 | ctypedef void (*get_parameter_callback)(void* user_context, const char* param_name, const char* param_val) 13 | int sol_GetModelParameters(void* model, get_parameter_callback callback, void* user_context) 14 | float sol_Train(void* model, void* data_iter) 15 | float sol_Test(void* model, void* data_iter, const char* output_path) 16 | ctypedef void (*sol_predict_callback)(void* user_context, double label, double predict, int cls_num, float* scores) 17 | int sol_Predict(void* model, 
void* data_iter, sol_predict_callback callback, void* user_context) 18 | float sol_model_sparsity(void* model) 19 | ctypedef void (*inspect_iterate_callback)(void* user_context, long long data_num, long long iter_num, 20 | long long update_num, double err_rate) 21 | void sol_InspectOnlineIteration(void* model, inspect_iterate_callback callback, void* user_context) 22 | int sol_loadArray(void* data_iter, char* X, char* Y, np.npy_intp* dims, np.npy_intp* strides, int pass_num) 23 | int sol_loadCsrMatrix(void* data_iter, char* indices, char* indptr, char* features, char* Y, int n_samples, int pass_num) 24 | int sol_analyze_data(const char* data_path, const char* data_type, const char* output_path) 25 | int sol_convert_data(const char* src_path, const char* src_type, const char* dst_path, const char* dst_type, bint binarize, float binarize_thresh) 26 | int sol_shuffle_data(const char* src_path, const char* src_type, const char* dst_path, const char* dst_type) 27 | int sol_split_data(const char* src_path, const char* src_type, int fold, const char* output_prefix, const char* dst_type, bint shuffle) 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | ################################################################################# 3 | # File Name : setup.py 4 | # Created By : yuewu 5 | # Description : 6 | ################################################################################# 7 | 8 | try: 9 | from setuptools import setup, Extension 10 | except ImportError: 11 | from distutils.core import setup, Extension 12 | 13 | from Cython.Build import cythonize 14 | 15 | import sys 16 | import os 17 | 18 | try: 19 | from pypandoc import convert 20 | 21 | def read_md(fpath): 22 | return convert(fpath, 'rst') 23 | 24 | except ImportError: 25 | print("warning: pypandoc module not found, DONOT convert Markdown to RST") 26 | 27 | def read_md(fpath): 28 | with open(fpath, 'r') as fp: 29 | return fp.read() 30 | 31 | sys.path.append("python") 32 | 33 | 34 | def get_source_files(root_dir): 35 | src_files = [] 36 | for pathname in os.listdir(root_dir): 37 | path = os.path.join(root_dir, pathname) 38 | if os.path.isfile(path): 39 | ext = os.path.splitext(path)[1] 40 | if ext in ['.cc', '.cpp', '.c']: 41 | src_files.append(path) 42 | elif os.path.isdir(path): 43 | src_files = src_files + get_source_files(path) 44 | return src_files 45 | def get_include_dirs(): 46 | import numpy as np 47 | return [np.get_include(), "include", "external"] 48 | 49 | if os.name == 'nt': 50 | extra_flags = ['/wd4251','/wd4275', '/EHsc','-DSOL_EMBED_PACKAGE'] 51 | dependencies = [] 52 | else: 53 | extra_flags = ['-std=c++11','-pthread'] 54 | dependencies = [ 55 | "numpy >= 1.7.0", 56 | "scipy >= 0.13.0", 57 | "scikit-learn >= 0.18.1", 58 | "matplotlib >= 1.5.1" 59 | ] 60 | 61 | 62 | ext_modules = [ 63 | Extension( 64 | "pysol", 65 | sources=["python/pysol.pyx"] + get_source_files('src/sol') + 66 | get_source_files('external/json'), 67 | language='c++', 68 | include_dirs=get_include_dirs(), 69 | extra_compile_args=['-DHAS_NUMPY_DEV', '-DUSE_STD_THREAD'] + extra_flags) 70 | ] 71 | 72 | 73 | setup( 74 | name='sol', 75 | 
version='1.1.0', 76 | description='Library for Scalable Online Learning', 77 | #long_description=read_md('README.md'), 78 | author='Yue Wu, Steven C.H. Hoi', 79 | author_email='yuewu@outlook.com', 80 | maintainer='Yue Wu', 81 | maintainer_email='yuewu@outlook.com', 82 | url='http://sol.stevenhoi.org', 83 | license='Apache 2.0', 84 | keywords='Scalable Online Learning', 85 | packages=['sol'], 86 | package_dir={'sol': 'python'}, 87 | entry_points = { 88 | 'console_scripts':[ 89 | 'sol_train=sol.sol_train:main', 90 | 'sol_test=sol.sol_test:main', 91 | ], 92 | }, 93 | ext_modules=cythonize(ext_modules), 94 | install_requires=dependencies 95 | ) 96 | -------------------------------------------------------------------------------- /src/sol/loss/bool_loss.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : bool_loss.cc 3 | * Created By : yuewu 4 | * Description : bool loss with yes or no 5 | **********************************************************************************/ 6 | 7 | #include "sol/loss/bool_loss.h" 8 | 9 | #include 10 | 11 | namespace sol { 12 | namespace loss { 13 | 14 | float BoolLoss::loss(const pario::DataPoint& dp, float* predict, 15 | label_t predict_label, int cls_num) { 16 | return predict_label == dp.label() ? 0.f : 1.f; 17 | } 18 | 19 | float BoolLoss::gradient(const pario::DataPoint& dp, float* predict, 20 | label_t predict_label, float* gradient, int cls_num) { 21 | label_t label = dp.label(); 22 | float loss = predict_label == label ? 0.f : 1.f; 23 | if (loss > 0) { 24 | *gradient = (float)(-label); 25 | } else { 26 | *gradient = 0; 27 | } 28 | return loss; 29 | } 30 | 31 | RegisterLoss(BoolLoss, "bool", "Bool Loss"); 32 | 33 | float MaxScoreBoolLoss::loss(const pario::DataPoint& dp, float* predict, 34 | label_t predict_label, int cls_num) { 35 | return predict_label == dp.label() ? 
0.f : 1.f; 36 | } 37 | 38 | float MaxScoreBoolLoss::gradient(const pario::DataPoint& dp, float* predict, 39 | label_t predict_label, float* gradient, 40 | int cls_num) { 41 | label_t label = dp.label(); 42 | float loss = predict_label == label ? 0.f : 1.f; 43 | 44 | if (loss > 0) { 45 | for (int i = 0; i < cls_num; ++i) { 46 | gradient[i] = 0; 47 | } 48 | gradient[predict_label] = 1; 49 | gradient[label] = -1; 50 | } 51 | return loss; 52 | } 53 | 54 | RegisterLoss(MaxScoreBoolLoss, "maxscore-bool", "Max-Score Bool Loss"); 55 | 56 | float UniformBoolLoss::loss(const pario::DataPoint& dp, float* predict, 57 | label_t predict_label, int cls_num) { 58 | return predict_label == dp.label() ? 0.f : 1.f; 59 | } 60 | 61 | float UniformBoolLoss::gradient(const pario::DataPoint& dp, float* predict, 62 | label_t predict_label, float* gradient, 63 | int cls_num) { 64 | label_t label = dp.label(); 65 | float loss = predict_label == label ? 0.f : 1.f; 66 | 67 | if (loss > 0) { 68 | size_t false_num = std::count_if( 69 | predict, predict + cls_num, 70 | [&predict, &label](float val) { return val >= predict[label]; }); 71 | false_num -= 1; 72 | float alpha = 1.f / false_num; 73 | for (int i = 0; i < cls_num; ++i) { 74 | if (predict[i] >= predict[label]) 75 | gradient[i] = alpha; 76 | else 77 | gradient[i] = 0; 78 | } 79 | gradient[label] = -1; 80 | } 81 | return loss; 82 | } 83 | 84 | RegisterLoss(UniformBoolLoss, "uniform-bool", "Uniform Bool Loss"); 85 | 86 | } // namespace loss 87 | } // namespace sol 88 | -------------------------------------------------------------------------------- /src/sol/loss/loss.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : loss.cc 3 | * Created By : yuewu 4 | * Creation Date : [2016-02-14 23:23] 5 | * Last Modified : [2016-02-14 23:24] 6 | * Description : base class for loss functions 7 | 
**********************************************************************************/ 8 | 9 | #include "sol/loss/loss.h" 10 | 11 | namespace sol { 12 | namespace loss { 13 | 14 | Loss* Loss::Create(const std::string& type) { 15 | auto create_func = CreateObject(std::string(type) + "_loss"); 16 | return create_func == nullptr ? nullptr : create_func(); 17 | } 18 | 19 | } // namespace loss 20 | } // namespace sol 21 | -------------------------------------------------------------------------------- /src/sol/loss/square_loss.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : square_loss.cc 3 | * Created By : yuewu 4 | * Creation Date : [2016-02-14 23:44] 5 | * Last Modified : [2016-02-15 00:01] 6 | * Description : square loss 7 | **********************************************************************************/ 8 | 9 | #include "sol/loss/square_loss.h" 10 | 11 | #include 12 | #include 13 | 14 | namespace sol { 15 | namespace loss { 16 | 17 | float SquareLoss::loss(const pario::DataPoint& dp, float* predict, 18 | label_t predict_label, int cls_num) { 19 | return (*predict - dp.label()) * (*predict - dp.label()) * 0.5f; 20 | } 21 | 22 | float SquareLoss::gradient(const pario::DataPoint& dp, float* predict, 23 | label_t predict_label, float* gradient, 24 | int cls_num) { 25 | *gradient = *predict - float(dp.label()); 26 | return *gradient * *gradient * 0.5f; 27 | } 28 | 29 | RegisterLoss(SquareLoss, "square", "Square Loss"); 30 | 31 | } // namespace loss 32 | } // namespace sol 33 | -------------------------------------------------------------------------------- /src/sol/model/olm/pa.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : pa.cc 3 | * Created By : yuewu 4 | * Description : 5 | 
**********************************************************************************/ 6 | 7 | #include "sol/model/olm/pa.h" 8 | #include 9 | 10 | using namespace std; 11 | using namespace sol::math::expr; 12 | 13 | namespace sol { 14 | 15 | namespace model { 16 | 17 | PA::PA(int class_num) : OnlineLinearModel(class_num) { 18 | this->eta_coeff_ = class_num == 2 ? 1.f : 2.f; 19 | } 20 | void PA::Update(const pario::DataPoint& dp, const float*, float loss) { 21 | const auto& x = dp.data(); 22 | this->eta_ = loss / (this->eta_coeff_ * Norm2(x)); 23 | 24 | for (int c = 0; c < this->clf_num_; ++c) { 25 | if (g(c) == 0) continue; 26 | w(c) -= eta_ * g(c) * x; 27 | // update bias 28 | w(c)[0] -= bias_eta() * g(c); 29 | } 30 | } 31 | 32 | RegisterModel(PA, "pa", "Online Passive Aggressive"); 33 | 34 | void PAI::SetParameter(const std::string& name, const std::string& value) { 35 | if (name == "C") { 36 | this->C_ = stof(value); 37 | } else { 38 | PA::SetParameter(name, value); 39 | } 40 | } 41 | 42 | void PAI::Update(const pario::DataPoint& dp, const float*, float loss) { 43 | const auto& x = dp.data(); 44 | this->eta_ = (std::min)(this->C_, loss / (eta_coeff_ * Norm2(x))); 45 | 46 | for (int c = 0; c < this->clf_num_; ++c) { 47 | if (g(c) == 0) continue; 48 | w(c) -= eta_ * g(c) * x; 49 | // update bias 50 | w(c)[0] -= bias_eta() * g(c); 51 | } 52 | } 53 | void PAI::GetModelInfo(Json::Value& root) const { 54 | PA::GetModelInfo(root); 55 | root["online"]["C"] = this->C_; 56 | } 57 | 58 | RegisterModel(PAI, "pa1", "Online Passive Aggressive-1"); 59 | 60 | void PAII::SetParameter(const std::string& name, const std::string& value) { 61 | if (name == "C") { 62 | this->C_ = stof(value); 63 | } else { 64 | PA::SetParameter(name, value); 65 | } 66 | } 67 | 68 | void PAII::Update(const pario::DataPoint& dp, const float*, float loss) { 69 | const auto& x = dp.data(); 70 | this->eta_ = loss / (eta_coeff_ * Norm2(x) + 0.5f / C_); 71 | for (int c = 0; c < this->clf_num_; ++c) { 72 | if 
(g(c) == 0) continue; 73 | w(c) -= eta_ * g(c) * x; 74 | // update bias 75 | w(c)[0] -= bias_eta() * g(c); 76 | } 77 | } 78 | 79 | void PAII::GetModelInfo(Json::Value& root) const { 80 | PA::GetModelInfo(root); 81 | root["online"]["C"] = this->C_; 82 | } 83 | RegisterModel(PAII, "pa2", "Online Passive Aggressive-2"); 84 | 85 | } // namespace model 86 | } // namespace sol 87 | -------------------------------------------------------------------------------- /src/sol/model/olm/perceptron.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : perceptron.cc 3 | * Created By : yuewu 4 | * Description : Perceptron Algorithm 5 | **********************************************************************************/ 6 | 7 | #include "sol/model/olm/perceptron.h" 8 | #include "sol/loss/bool_loss.h" 9 | 10 | using namespace std; 11 | using namespace sol; 12 | 13 | namespace sol { 14 | 15 | namespace model { 16 | Perceptron::Perceptron(int class_num) : OnlineLinearModel(class_num) { 17 | // loss 18 | if (class_num == 2) { 19 | this->SetParameter("loss", "bool"); 20 | } else { 21 | this->SetParameter("loss", "maxscore-bool"); 22 | } 23 | } 24 | 25 | void Perceptron::SetParameter(const std::string& name, 26 | const std::string& value) { 27 | if (name == "loss") { 28 | OnlineLinearModel::SetParameter(name, value); 29 | if ((this->loss_->type() & loss::Loss::Type::BOOL) == 0) { 30 | throw invalid_argument("only bool-based loss functions are allowed"); 31 | } 32 | } else { 33 | OnlineLinearModel::SetParameter(name, value); 34 | } 35 | } 36 | 37 | void Perceptron::Update(const pario::DataPoint& dp, const float*, float) { 38 | const auto& x = dp.data(); 39 | this->eta_ = 1.f; 40 | 41 | for (int c = 0; c < this->clf_num_; ++c) { 42 | if (g(c) == 0) continue; 43 | w(c) -= g(c) * x; 44 | // update bias 45 | w(c)[0] -= bias_eta() * g(c); 46 | } 47 | } 48 | 49 
| RegisterModel(Perceptron, "perceptron", "perceptron algorithm"); 50 | 51 | } // namespace model 52 | } // namespace sol 53 | -------------------------------------------------------------------------------- /src/sol/model/regularizer.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : regularizer.cc 3 | * Created By : yuewu 4 | * Description : model regularizers 5 | **********************************************************************************/ 6 | 7 | #include "sol/model/regularizer.h" 8 | #include "sol/util/error_code.h" 9 | 10 | #include 11 | 12 | using namespace std; 13 | using namespace sol::math::expr; 14 | 15 | namespace sol { 16 | namespace model { 17 | 18 | int Regularizer::SetParameter(const std::string &name, 19 | const std::string &value) { 20 | int ret = Status_OK; 21 | if (name == "lambda") { 22 | this->lambda_ = stof(value); 23 | } else { 24 | ret = Status_Invalid_Argument; 25 | } 26 | return ret; 27 | } 28 | void Regularizer::GetRegularizerInfo(Json::Value &root) const { 29 | root["regularizer"]["lambda"] = this->lambda_; 30 | } 31 | 32 | OnlineL1Regularizer::OnlineL1Regularizer() : sparse_thresh_(1e-6f) {} 33 | 34 | int OnlineL1Regularizer::SetParameter(const std::string &name, 35 | const std::string &value) { 36 | if (name == "sparse_thresh") { 37 | this->sparse_thresh_ = stof(value); 38 | } else { 39 | return OnlineRegularizer::SetParameter(name, value); 40 | } 41 | return Status_OK; 42 | } 43 | 44 | void OnlineL1Regularizer::FinalizeRegularization(math::Vector &w) { 45 | float thresh = this->sparse_thresh_; 46 | w.slice_op([thresh](real_t &val) { 47 | if (val < thresh && val > -thresh) val = 0; 48 | }); 49 | } 50 | 51 | LazyOnlineL1Regularizer::LazyOnlineL1Regularizer() : initial_t_(0) { 52 | this->last_update_time_.resize(1); 53 | this->last_update_time_ = 0; 54 | } 55 | 56 | int 
LazyOnlineL1Regularizer::SetParameter(const std::string &name, 57 | const std::string &value) { 58 | if (name == "t0") { 59 | this->initial_t_ = stof(value); 60 | this->last_update_time_ = this->initial_t_; 61 | } else { 62 | return OnlineL1Regularizer::SetParameter(name, value); 63 | } 64 | return Status_OK; 65 | } 66 | 67 | void LazyOnlineL1Regularizer::BeginIterate(const pario::DataPoint &dp) { 68 | // update dim 69 | size_t d = this->last_update_time_.dim(); 70 | if (dp.dim() > d) { 71 | this->last_update_time_.resize(dp.dim()); 72 | real_t t0 = this->initial_t_; 73 | this->last_update_time_.slice_op([t0](float &val) { val = t0; }, d); 74 | } 75 | } 76 | 77 | void LazyOnlineL1Regularizer::EndIterate(const pario::DataPoint &dp, 78 | int cur_iter_num) { 79 | // update last update time 80 | const auto &x = dp.data(); 81 | auto &last_update_time = this->last_update_time_; 82 | real_t time_stamp = real_t(cur_iter_num - 1); 83 | x.indexes().slice_op([&last_update_time, time_stamp](const index_t &idx) { 84 | last_update_time[idx] = time_stamp; 85 | }); 86 | this->last_update_time_[0] = time_stamp; 87 | } 88 | 89 | } // namespace model 90 | } // namespace sol 91 | -------------------------------------------------------------------------------- /src/sol/pario/binary_reader.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : binary_reader.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-13 20:31] 5 | * Last Modified : [2015-11-14 17:16] 6 | * Description : binary format data reader 7 | **********************************************************************************/ 8 | 9 | #include "sol/pario/binary_reader.h" 10 | 11 | #include 12 | 13 | #include "sol/pario/compress.h" 14 | #include "sol/util/error_code.h" 15 | 16 | namespace sol { 17 | namespace pario { 18 | 19 | int BinaryReader::Open(const std::string& path, const char* 
mode) { 20 | return DataFileReader::Open(path, "rb"); 21 | } 22 | 23 | int BinaryReader::Next(DataPoint& dst_data) { 24 | dst_data.Clear(); 25 | label_t label; 26 | int ret = this->file_reader_.Read((char*)&label, sizeof(label_t)); 27 | if (ret != Status_OK) return ret; 28 | dst_data.set_label(label); 29 | 30 | size_t feat_num; 31 | ret = this->file_reader_.Read((char*)&feat_num, sizeof(feat_num)); 32 | if (ret != Status_OK) { 33 | fprintf(stderr, "load feature number failed!\n"); 34 | this->is_good_ = false; 35 | return false; 36 | } 37 | if (feat_num > 0) { 38 | size_t code_len = 0; 39 | ret = this->file_reader_.Read((char*)&code_len, sizeof(size_t)); 40 | if (ret != Status_OK) { 41 | fprintf(stderr, "read coded index length failed!\n"); 42 | return Status_Invalid_Format; 43 | } 44 | this->comp_codes_.resize(code_len); 45 | ret = this->file_reader_.Read(this->comp_codes_.begin(), code_len); 46 | if (ret != Status_OK) { 47 | fprintf(stderr, "read coded index failed!\n"); 48 | return Status_Invalid_Format; 49 | } 50 | dst_data.Resize(feat_num); 51 | decomp_index(this->comp_codes_, dst_data.indexes()); 52 | if (dst_data.indexes().size() != feat_num) { 53 | fprintf(stderr, "decoded index number is not correct!\n"); 54 | return Status_Invalid_Format; 55 | } 56 | 57 | ret = this->file_reader_.Read((char*)dst_data.features().begin(), 58 | sizeof(real_t) * feat_num); 59 | if (ret != Status_OK) { 60 | fprintf(stderr, "load features failed!\n"); 61 | return Status_Invalid_Format; 62 | } 63 | } 64 | return ret; 65 | } 66 | 67 | RegisterDataReader(BinaryReader, "bin", "binary format data reader"); 68 | 69 | } // namespace pario 70 | } // namespace sol 71 | -------------------------------------------------------------------------------- /src/sol/pario/binary_writer.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : binary_writer.cc 3 | * 
Created By : yuewu 4 | * Creation Date : [2015-11-14 15:34] 5 | * Last Modified : [2015-11-14 15:49] 6 | * Description : binary format data writer 7 | **********************************************************************************/ 8 | #include "sol/pario/binary_writer.h" 9 | 10 | #include 11 | 12 | #include "sol/pario/compress.h" 13 | #include "sol/util/error_code.h" 14 | 15 | namespace sol { 16 | namespace pario { 17 | 18 | int BinaryWriter::Open(const std::string& path, const char* mode) { 19 | return DataWriter::Open(path, "wb"); 20 | } 21 | 22 | int BinaryWriter::Write(const DataPoint& data) { 23 | label_t label = data.label(); 24 | this->file_writer_.Write((char*)&label, sizeof(label)); 25 | size_t feat_num = data.indexes().size(); 26 | 27 | this->file_writer_.Write((char*)&feat_num, sizeof(feat_num)); 28 | if (feat_num > 0) { 29 | this->comp_codes_.clear(); 30 | comp_index(data.indexes(), this->comp_codes_); 31 | size_t code_len = this->comp_codes_.size(); 32 | this->file_writer_.Write((char*)&(code_len), sizeof(code_len)); 33 | this->file_writer_.Write(this->comp_codes_.begin(), code_len); 34 | this->file_writer_.Write((char*)(data.features().begin()), 35 | sizeof(real_t) * feat_num); 36 | } 37 | return Status_OK; 38 | } 39 | 40 | RegisterDataWriter(BinaryWriter, "bin", "binary format data writer"); 41 | 42 | } // namespace pario 43 | } // namespace sol 44 | -------------------------------------------------------------------------------- /src/sol/pario/csv_reader.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : /home/yuewu/work/sol/src/sol/pario/csv_reader.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-13 19:39] 5 | * Last Modified : [2016-02-12 21:23] 6 | * Description : 7 | **********************************************************************************/ 8 | #include "sol/pario/csv_reader.h" 9 | 10 | 
#include 11 | 12 | #include "sol/pario/numeric_parser.h" 13 | 14 | namespace sol { 15 | namespace pario { 16 | 17 | CSVReader::CSVReader() : DataFileReader() { this->feat_dim_ = 0; } 18 | 19 | int CSVReader::Open(const std::string& path, const char* mode) { 20 | int ret = DataFileReader::Open(path); 21 | if (ret == Status_OK) { 22 | ret = this->LoadFeatDim(); 23 | } 24 | this->is_good_ = ret == Status_OK ? true : false; 25 | return ret; 26 | } 27 | 28 | void CSVReader::Rewind() { 29 | DataFileReader::Rewind(); 30 | // read the first line for csv 31 | this->file_reader_.ReadLine(this->read_buf_, this->read_buf_size_); 32 | } 33 | 34 | int CSVReader::Next(DataPoint& dst_data) { 35 | int ret = this->file_reader_.ReadLine(this->read_buf_, this->read_buf_size_); 36 | if (ret != Status_OK) return ret; 37 | 38 | char* iter = this->read_buf_, *endptr = nullptr; 39 | if (*iter == '\0') { 40 | fprintf(stderr, "incorrect line\n"); 41 | return Status_Invalid_Format; 42 | } 43 | 44 | dst_data.Clear(); 45 | // 1. parse label 46 | dst_data.set_label(label_t(NumericParser::ParseInt(iter, endptr))); 47 | if (endptr == iter) { 48 | fprintf(stderr, "parse label failed.\n"); 49 | this->is_good_ = false; 50 | return Status_Invalid_Format; 51 | } 52 | iter = endptr; 53 | 54 | // 2. 
parse features 55 | dst_data.Reserve(this->feat_dim_); 56 | index_t index = 1; 57 | while (*iter != '\0') { 58 | if (*iter != ',') { 59 | fprintf(stderr, "incorrect input file (%s)!\n", iter); 60 | this->is_good_ = false; 61 | return Status_Invalid_Format; 62 | } 63 | ++iter; 64 | 65 | real_t feat = NumericParser::ParseFloat(iter, endptr); 66 | if (endptr == iter) { 67 | fprintf(stderr, "parse feature value (%s) failed!\n", iter); 68 | this->is_good_ = false; 69 | return Status_Invalid_Format; 70 | } 71 | iter = endptr; 72 | 73 | if (feat != 0) { 74 | dst_data.AddNewFeat(index, feat); 75 | } 76 | ++index; 77 | } 78 | 79 | return ret; 80 | } 81 | 82 | int CSVReader::LoadFeatDim() { 83 | int ret = this->file_reader_.ReadLine(this->read_buf_, this->read_buf_size_); 84 | if (ret != Status_OK) return ret; 85 | char* p = this->read_buf_; 86 | this->feat_dim_ = 0; 87 | while (*p != '\0') { 88 | if (*p++ == ',') ++this->feat_dim_; 89 | } 90 | ++this->feat_dim_; 91 | return ret; 92 | } 93 | 94 | RegisterDataReader(CSVReader, "csv", "csv format data reader"); 95 | 96 | } // namespace pario 97 | } // namespace sol 98 | -------------------------------------------------------------------------------- /src/sol/pario/csv_writer.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : csv_writer.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-14 15:10] 5 | * Last Modified : [2015-11-14 15:46] 6 | * Description : writer of csv format data 7 | **********************************************************************************/ 8 | 9 | #include "sol/pario/csv_writer.h" 10 | 11 | #include 12 | 13 | namespace sol { 14 | namespace pario { 15 | 16 | int CSVWriter::Write(const DataPoint& data) { 17 | size_t feat_num = data.indexes().size(); 18 | this->file_writer_.Printf("%d", data.label()); 19 | 20 | size_t i = 0; 21 | index_t j = 1; 22 | for (; i 
< feat_num && j < this->feat_dim_; ++j) { 23 | if (data.index(i) == j) { 24 | this->file_writer_.Printf(",%g", data.feature(i++)); 25 | } else { 26 | this->file_writer_.Printf(",0"); 27 | } 28 | } 29 | for (; j < this->feat_dim_; ++j) this->file_writer_.Printf(",0"); 30 | 31 | this->file_writer_.Printf("\n"); 32 | return Status_OK; 33 | } 34 | 35 | int CSVWriter::SetExtraInfo(const char* extra_info) { 36 | if (this->Good() == false) { 37 | return Status_IO_Error; 38 | } 39 | this->feat_dim_ = *((index_t*)(extra_info)); 40 | std::ostringstream oss; 41 | oss << "class"; 42 | // the index starts from 1 43 | for (index_t i = 1; i < this->feat_dim_; ++i) { 44 | oss << ",v" << i; 45 | } 46 | this->file_writer_.Printf("%s\n", oss.str().c_str()); 47 | return Status_OK; 48 | } 49 | 50 | RegisterDataWriter(CSVWriter, "csv", "csv format data writer"); 51 | 52 | } // namespace pario 53 | } // namespace sol 54 | -------------------------------------------------------------------------------- /src/sol/pario/data_point.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : data_point.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-29 13:12] 5 | * Last Modified : [2015-11-13 21:01] 6 | * Description : Data Point Structure 7 | **********************************************************************************/ 8 | 9 | #include "sol/pario/data_point.h" 10 | 11 | #include 12 | 13 | namespace sol { 14 | namespace pario { 15 | 16 | DataPoint::DataPoint() : label_(0) {} 17 | 18 | void DataPoint::Clone(DataPoint &dst_pt) const { 19 | dst_pt.label_ = this->label_; 20 | dst_pt.Reserve(this->size()); 21 | dst_pt.Resize(this->size()); 22 | memcpy(dst_pt.indexes().begin(), this->indexes().begin(), 23 | this->size() * sizeof(index_t)); 24 | memcpy(dst_pt.features().begin(), this->features().begin(), 25 | this->size() * sizeof(real_t)); 26 | } 27 | 28 
| DataPoint DataPoint::Clone() const { 29 | DataPoint dst_pt; 30 | this->Clone(dst_pt); 31 | return dst_pt; 32 | } 33 | 34 | void DataPoint::AddNewFeat(index_t index, real_t feat) { 35 | this->data_.push_back(index, feat); 36 | } 37 | 38 | void DataPoint::Clear() { 39 | this->data_.clear(); 40 | this->label_ = 0; 41 | } 42 | 43 | bool DataPoint::IsSorted() const { 44 | for (auto iter = this->indexes().begin() + 1; iter < this->indexes().end(); 45 | ++iter) { 46 | if (*iter <= *(iter - 1)) return false; 47 | } 48 | return true; 49 | } 50 | 51 | template 52 | void QuickSort(T1 *a, T2 *b, size_t low, size_t high) { // from small to great 53 | size_t i = low; 54 | size_t j = high; 55 | T1 temp = a[low]; // select the first element as the indicator 56 | T2 temp_ind = b[low]; 57 | 58 | while (i < j) { 59 | while ((i < j) && (temp < a[j])) { // scan right side 60 | j--; 61 | } 62 | if (i < j) { 63 | a[i] = a[j]; 64 | b[i] = b[j]; 65 | i++; 66 | } 67 | 68 | while (i < j && (a[i] < temp)) { // scan left side 69 | i++; 70 | } 71 | if (i < j) { 72 | a[j] = a[i]; 73 | b[j] = b[i]; 74 | j--; 75 | } 76 | } 77 | a[i] = temp; 78 | b[i] = temp_ind; 79 | 80 | if (low < i) { 81 | QuickSort(a, b, low, i - 1); // sort left subset 82 | } 83 | if (i < high) { 84 | QuickSort(a, b, j + 1, high); // sort right subset 85 | } 86 | } 87 | 88 | void DataPoint::Sort() { 89 | if (this->IsSorted() == false) { 90 | QuickSort(this->indexes().begin(), this->features().begin(), 0, 91 | this->indexes().size() - 1); 92 | } 93 | } 94 | 95 | } // namespace pario 96 | } // namespace sol 97 | -------------------------------------------------------------------------------- /src/sol/pario/data_read_task.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : data_read_task.cc 3 | * Created By : yuewu 4 | * Creation Date : [2016-02-12 15:18] 5 | * Last Modified : [2016-03-09 
19:24] 6 | * Description : 7 | **********************************************************************************/ 8 | 9 | #include "sol/pario/data_read_task.h" 10 | #include "sol/util/error_code.h" 11 | 12 | namespace sol { 13 | namespace pario { 14 | DataReadTask::DataReadTask(const std::string& path, const std::string& dtype, 15 | BlockQueue& mini_batch_factory, 16 | BlockQueue& mini_batch_buf, int pass_num) 17 | : mini_batch_factory_(mini_batch_factory), 18 | mini_batch_buf_(mini_batch_buf), 19 | pass_num_(pass_num) { 20 | DataReader* reader = DataReader::Create(dtype); 21 | if (reader != nullptr) { 22 | if (reader->Open(path) != Status_OK) { 23 | delete reader; 24 | reader = nullptr; 25 | } 26 | } 27 | this->reader_.reset(reader); 28 | } 29 | 30 | void DataReadTask::run() { 31 | int status = Status_OK; 32 | DataReader* reader = this->reader_.get(); 33 | while (status == Status_OK && this->pass_num_ > 0) { 34 | MiniBatch* mini_batch = this->mini_batch_factory_.Dequeue(); 35 | if (mini_batch == nullptr) { // exit signal 36 | this->mini_batch_factory_.Enqueue(nullptr); 37 | break; 38 | } 39 | mini_batch->data_num = 0; 40 | while (mini_batch->data_num < mini_batch->capacity() && 41 | status == Status_OK) { 42 | status = reader->Next((*mini_batch)[mini_batch->data_num]); 43 | if (status == Status_OK) { 44 | ++mini_batch->data_num; 45 | continue; 46 | } else if (status == Status_EndOfFile) { 47 | --this->pass_num_; 48 | reader->Rewind(); 49 | status = Status_OK; 50 | break; 51 | } else 52 | break; 53 | } 54 | this->mini_batch_buf_.Enqueue(mini_batch); 55 | } 56 | reader->Close(); 57 | this->mini_batch_buf_.Enqueue(nullptr); 58 | } 59 | 60 | } // namespace pario 61 | } // namespace sol 62 | -------------------------------------------------------------------------------- /src/sol/pario/data_reader.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 
| * File Name : data_reader.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-11 21:38] 5 | * Last Modified : [2015-11-13 21:32] 6 | * Description : Interface for data reader (svm,binary, etc.) 7 | **********************************************************************************/ 8 | 9 | #include "sol/pario/data_reader.h" 10 | 11 | #include 12 | 13 | #include "sol/util/error_code.h" 14 | 15 | using namespace std; 16 | 17 | namespace sol { 18 | namespace pario { 19 | 20 | DataReader* DataReader::Create(const std::string& type) { 21 | auto create_func = CreateObject(std::string(type) + "_reader"); 22 | return create_func == nullptr ? nullptr : create_func(); 23 | } 24 | 25 | DataReader::DataReader() {} 26 | DataReader::~DataReader() {} 27 | 28 | DataFileReader::DataFileReader() { 29 | this->read_buf_size_ = 4096; 30 | this->read_buf_ = (char*)malloc(this->read_buf_size_ * sizeof(char)); 31 | this->is_good_ = true; 32 | } 33 | 34 | DataFileReader::~DataFileReader() { 35 | this->Close(); 36 | if (this->read_buf_ != nullptr) { 37 | free(this->read_buf_); 38 | } 39 | } 40 | 41 | int DataFileReader::Open(const string& path, const char* mode) { 42 | this->Close(); 43 | this->file_path_ = path; 44 | int ret = this->file_reader_.Open(path.c_str(), mode); 45 | this->is_good_ = ret == Status_OK ? true : false; 46 | 47 | return ret; 48 | } 49 | 50 | } // namespace pario 51 | } // namespace sol 52 | -------------------------------------------------------------------------------- /src/sol/pario/data_writer.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : data_writer.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-14 15:01] 5 | * Last Modified : [2015-11-14 15:03] 6 | * Description : Interface for data writer (svm,binary, etc.) 
7 | **********************************************************************************/ 8 | 9 | #include "sol/pario/data_writer.h" 10 | 11 | #include 12 | 13 | #include "sol/util/error_code.h" 14 | 15 | using namespace std; 16 | 17 | namespace sol { 18 | namespace pario { 19 | 20 | DataWriter* DataWriter::Create(const std::string& type) { 21 | auto create_func = CreateObject(std::string(type) + "_writer"); 22 | return create_func == nullptr ? nullptr : create_func(); 23 | } 24 | 25 | DataWriter::DataWriter() { this->is_good_ = true; } 26 | 27 | DataWriter::~DataWriter() { this->Close(); } 28 | 29 | int DataWriter::Open(const string& path, const char* mode) { 30 | this->Close(); 31 | this->file_path_ = path; 32 | int ret = this->file_writer_.Open(path.c_str(), mode); 33 | this->is_good_ = ret == Status_OK ? true : false; 34 | 35 | return ret; 36 | } 37 | 38 | } // namespace pario 39 | } // namespace sol 40 | -------------------------------------------------------------------------------- /src/sol/pario/file_writer.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : file_writer.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-17 11:02] 5 | * Last Modified : [2015-11-14 15:49] 6 | * Description : basic file writer 7 | **********************************************************************************/ 8 | 9 | #include "sol/pario/file_writer.h" 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "sol/util/util.h" 16 | #include "sol/util/error_code.h" 17 | 18 | using namespace std; 19 | 20 | namespace sol { 21 | namespace pario { 22 | 23 | FileWriter::FileWriter() : file_(nullptr) {} 24 | FileWriter::FileWriter(const char* path, const char* mode) : file_(nullptr) { 25 | this->Open(path, mode); 26 | } 27 | 28 | FileWriter::~FileWriter() { this->Close(); } 29 | 30 | int FileWriter::Open(const char* path, const 
char* mode) { 31 | this->Close(); 32 | // open file 33 | if (strcmp(path, "-") == 0) { 34 | this->file_ = stdout; 35 | path = "stdout"; 36 | } else { 37 | this->file_ = open_file(path, mode); 38 | } 39 | 40 | if (this->file_ == nullptr || this->Good() == false) { 41 | this->Close(); 42 | fprintf(stderr, "Error: open file (%s) failed.\n", path); 43 | return Status_IO_Error; 44 | } 45 | 46 | return Status_OK; 47 | } 48 | 49 | void FileWriter::Close() { 50 | if (this->file_ != nullptr && this->file_ != stdout) { 51 | fclose(this->file_); 52 | } 53 | this->file_ = nullptr; 54 | } 55 | 56 | bool FileWriter::Good() { 57 | // we do not need to handle eof here, when eof is set, ferror still returns 58 | // 0 59 | return this->file_ != nullptr && ferror(this->file_) == 0; 60 | } 61 | 62 | int FileWriter::Write(char* src_buf, size_t length) { 63 | size_t write_len = fwrite(src_buf, 1, length, this->file_); 64 | if (write_len == length) { 65 | return Status_OK; 66 | } else { 67 | cerr << "Error " << Status_IO_Error << ": only " << write_len 68 | << " bytes are written while " << length << " bytes are specified.\n"; 69 | return Status_IO_Error; 70 | } 71 | } 72 | 73 | int FileWriter::Printf(const char* format, ...) 
{ 74 | va_list argptr; 75 | va_start(argptr, format); 76 | 77 | int ret = Status_OK; 78 | if ((ret = vfprintf(this->file_, format, argptr)) < 0) { 79 | fprintf(stderr, "vfprintf failed in %s:%d\n", __FILE__, __LINE__); 80 | ret = Status_IO_Error; 81 | } 82 | va_end(argptr); 83 | return ret; 84 | } 85 | 86 | } // namespace pario 87 | } // namespace sol 88 | -------------------------------------------------------------------------------- /src/sol/pario/numpy_reader.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : numpy_reader.cc 3 | * Created By : yuewu 4 | * Description : reader for numpy arrays addressed by a "x;y;rows;cols;stride" path string 5 | **********************************************************************************/ 6 | #ifdef HAS_NUMPY_DEV 7 | 8 | #include "sol/pario/numpy_reader.h" 9 | 10 | #include 11 | #include 12 | 13 | #include "sol/util/str_util.h" 14 | 15 | using namespace std; 16 | 17 | namespace sol { 18 | namespace pario { 19 | // Decodes the address string in path (mode is unused) and resets the row cursor. 20 | int NumpyReader::Open(const std::string& path, const char* mode) { 21 | int ret = this->ParsePath(path, this->X_, this->Y_, this->n_samples_, 22 | this->n_features_, this->stride_); 23 | this->is_good_ = (ret == Status_OK);
24 | this->x_idx_ = 0; 25 | return ret; 26 | } 27 | 28 | void NumpyReader::Rewind() { this->x_idx_ = 0; } 29 | // Encodes the feature/label pointers and matrix geometry as "x;y;rows;cols;stride" (decoded again by ParsePath). 30 | std::string NumpyReader::GeneratePath(double* x, double* y, int rows, int cols, 31 | int stride) { 32 | ostringstream path; 33 | path << (long long)(x) << ";" << (long long)(y) << ";" << rows << ";" << cols 34 | << ";" << stride; 35 | return path.str(); 36 | } 37 | // Parses src (whitespace stripped) as a base-10 integer and casts it to T; non-numeric text, values beyond the long long range, or trailing characters yield Status_IO_Error. 38 | template 39 | int ParseAddr(T& dst, const std::string& src) { 40 | string::size_type sz = 0; 41 | try { 42 | string tmp = strip(src); 43 | dst = (T)(stoll(tmp, &sz)); 44 | if (tmp[sz] != '\0') { 45 | return Status_IO_Error; 46 | } 47 | } 48 | catch (logic_error&) { // stoll throws invalid_argument (no digits) or out_of_range (overflow); both derive from logic_error 49 | return Status_IO_Error; 50 | } 51 | return Status_OK; 52 | } 53 | // Splits path into the five ';'-separated fields written by GeneratePath and parses each one. 54 | int NumpyReader::ParsePath(const std::string& path, double*& x, double*& y, 55 | int& rows, int& cols, int& stride) { 56 | const vector& parts = split(path, ';'); 57 | if (parts.size() != 5) { 58 | fprintf(stderr, "invalid address (%s) for numpy reader\n", path.c_str()); 59 | return Status_IO_Error; 60 | } 61 | int ret = Status_OK; 62 | ret = ParseAddr(x, parts[0]); 63 | if (ret != Status_OK) return ret; 64 | ret = ParseAddr(y, parts[1]); 65 | if (ret != Status_OK) return ret; 66 | ret = ParseAddr(rows, parts[2]); 67 | if (ret != Status_OK) return ret; 68 | ret = ParseAddr(cols, parts[3]); 69 | if (ret != Status_OK) return ret; 70 | ret = ParseAddr(stride, parts[4]); 71 | return ret; 72 | } 73 | // Emits row x_idx_ as a sparse DataPoint: zero entries are skipped, feature indexes are 1-based, and the label comes from Y_ when it is set. 74 | int NumpyReader::Next(DataPoint& dst_data) { 75 | if (this->x_idx_ == this->n_samples_) 76 | return this->n_samples_ > 0 ? Status_EndOfFile : Status_IO_Error; 77 | 78 | dst_data.Clear(); 79 | // 1. parse label 80 | if (this->Y_ != nullptr) { 81 | dst_data.set_label(label_t(this->Y_[this->x_idx_])); 82 | } 83 | 84 | // 2.
parse features 85 | double* ptr = (double*)((char*)this->X_ + (long long)this->x_idx_ * this->stride_); // stride_ is a byte offset between rows; widen before multiplying so large arrays cannot overflow int 86 | for (int j = 0; j < this->n_features_; ++j, ++ptr) { 87 | if (*ptr != 0) { 88 | dst_data.AddNewFeat(j + 1, static_cast(*ptr)); 89 | } 90 | } 91 | ++this->x_idx_; 92 | return Status_OK; 93 | } 94 | 95 | RegisterDataReader(NumpyReader, "numpy", "numpy array data reader"); 96 | 97 | } // namespace pario 98 | } // namespace sol 99 | 100 | #endif 101 | -------------------------------------------------------------------------------- /src/sol/pario/svm_reader.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : svm_reader.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-11 22:26] 5 | * Last Modified : [2015-11-14 17:16] 6 | * Description : reader of libsvm format data 7 | **********************************************************************************/ 8 | 9 | #include "sol/pario/svm_reader.h" 10 | 11 | #include 12 | 13 | #include "sol/pario/numeric_parser.h" 14 | 15 | namespace sol { 16 | namespace pario { 17 | // Parses one "label index:value index:value ..." line into dst_data, then calls dst_data.Sort(). 18 | int SVMReader::Next(DataPoint &dst_data) { 19 | int ret = this->file_reader_.ReadLine(this->read_buf_, this->read_buf_size_); 20 | if (ret != Status_OK) return ret; 21 | 22 | char *iter = this->read_buf_, *endptr = nullptr; 23 | if (*iter == '\0') { 24 | fprintf(stderr, "incorrect line\n"); 25 | return Status_Invalid_Format; 26 | } 27 | 28 | dst_data.Clear(); 29 | // 1. parse label 30 | dst_data.set_label(label_t(NumericParser::ParseInt(iter, endptr))); 31 | if (endptr == iter) { 32 | fprintf(stderr, "parse label failed.\n"); 33 | this->is_good_ = false; 34 | return Status_Invalid_Format; 35 | } 36 | iter = endptr; 37 | 38 | // 2.
parse features 39 | while (*iter != '\0') { 40 | index_t index = (index_t)(NumericParser::ParseUint(iter, endptr)); 41 | if (endptr == iter) { 42 | // parse index failed 43 | fprintf(stderr, "parse index value (%s) failed!\n", iter); 44 | this->is_good_ = false; 45 | return Status_Invalid_Format; 46 | } 47 | iter = endptr; 48 | if (*iter != ':') { 49 | fprintf(stderr, "incorrect input file (%s)!\n", iter); 50 | this->is_good_ = false; 51 | return Status_Invalid_Format; 52 | } 53 | ++iter; 54 | 55 | real_t feat = NumericParser::ParseFloat(iter, endptr); 56 | if (endptr == iter) { 57 | fprintf(stderr, "parse feature value (%s) failed!\n", iter); 58 | this->is_good_ = false; 59 | return Status_Invalid_Format; 60 | } 61 | iter = endptr; 62 | 63 | dst_data.AddNewFeat(index, feat); 64 | } 65 | dst_data.Sort(); 66 | 67 | return ret; 68 | } 69 | 70 | RegisterDataReader(SVMReader, "svm", "libsvm format data reader"); 71 | 72 | } // namespace pario 73 | } // namespace sol 74 | -------------------------------------------------------------------------------- /src/sol/pario/svm_writer.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : svm_writer.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-14 15:04] 5 | * Last Modified : [2015-11-14 16:17] 6 | * Description : writer of libsvm format data 7 | **********************************************************************************/ 8 | 9 | #include "sol/pario/svm_writer.h" 10 | 11 | #include 12 | 13 | namespace sol { 14 | namespace pario { 15 | // Writes data as one "label index:value ..." line; the first Printf that returns Status_IO_Error aborts the write with that status. 16 | int SVMWriter::Write(const DataPoint &data) { 17 | size_t feat_num = data.indexes().size(); 18 | if (this->file_writer_.Printf("%d", data.label()) == Status_IO_Error) return Status_IO_Error; 19 | for (size_t i = 0; i < feat_num; ++i) { 20 | if (this->file_writer_.Printf(" %d:%g", data.index(i), data.feature(i)) == Status_IO_Error) return Status_IO_Error; 21 | } 22 | if (this->file_writer_.Printf("\n") == Status_IO_Error) return Status_IO_Error; 23 | return Status_OK; 24 | } 25
| 26 | RegisterDataWriter(SVMWriter, "svm", "libsvm format data writer"); 27 | 28 | } // namespace pario 29 | } // namespace sol 30 | -------------------------------------------------------------------------------- /src/sol/util/reflector.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : reflector.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-23 15:09] 5 | * Last Modified : [2016-05-15 01:42] 6 | * Description : registry mapping lower-cased class names to ClassInfo entries 7 | **********************************************************************************/ 8 | 9 | #include "sol/util/reflector.h" 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | using namespace std; 18 | 19 | namespace sol { 20 | // A ClassInfo (name, factory function, description) registers itself with ClassFactory as soon as it is constructed. 21 | ClassInfo::ClassInfo(const std::string& name, void* func, 22 | const std::string& descr) 23 | : name_(name), create_func_(func), descr_(descr) { 24 | ClassFactory::Register(this); 25 | } 26 | // Stores class_info under its lower-cased name; registering a duplicate name is fatal (exit(1)). 27 | void ClassFactory::Register(ClassInfo* class_info) { 28 | ClsInfoMapType& cls_info_map = ClassInfoMap(); 29 | const string& cls_name = lower(class_info->name()); 30 | if (cls_info_map.find(cls_name) == cls_info_map.end()) { 31 | cls_info_map[cls_name] = class_info; 32 | } else { 33 | fprintf(stderr, "%s already exists!\n", class_info->name().c_str()); 34 | exit(1); 35 | } 36 | } 37 | // Function-local static: constructed on first use, so it is valid even when Register runs during static initialization. 38 | ClassFactory::ClsInfoMapType& ClassFactory::ClassInfoMap() { 39 | static ClsInfoMapType class_info_map; 40 | return class_info_map; 41 | } 42 | 43 | } // namespace sol 44 | -------------------------------------------------------------------------------- /test/pario/test_compress.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : test_compress.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-11-14 17:21] 5 | * Last
Modified : [2015-11-14 17:48] 6 | * Description : test compression 7 | **********************************************************************************/ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "sol/pario/compress.h" 15 | 16 | using namespace sol::pario; 17 | using namespace sol::math; 18 | using namespace std; 19 | // Sorts N random values of type T, round-trips them through comp_index/decomp_index, and returns 0 if the decoded array matches, -1 otherwise. 20 | template 21 | int test() { 22 | srand(static_cast(time(0))); 23 | Vector arr; 24 | int N = 1000; 25 | for (int i = 0; i < N; ++i) { 26 | arr.push_back(T(rand())); 27 | } 28 | std::sort(arr.begin(), arr.end()); 29 | cout << "example array: " << endl; 30 | for (int i = 0; i < (N < 10 ? N : 10); ++i) { 31 | cout << arr[i] << " "; 32 | } 33 | cout << endl; 34 | 35 | Vector codes; 36 | comp_index(arr, codes); 37 | 38 | cout << "size of original array :" << sizeof(T) * N << " bytes" << endl; 39 | cout << "size of codes array :" << sizeof(char) * codes.size() << " bytes" 40 | << endl; 41 | 42 | cout << "check decompress..." << endl; 43 | 44 | Vector arr2; 45 | decomp_index(codes, arr2); 46 | if (arr.size() != arr2.size()) { 47 | cerr << "size of decompressed array " << arr2.size() << " not equal to " 48 | << N << endl; 49 | return -1; 50 | } 51 | for (int i = 0; i < N; ++i) { 52 | if (arr[i] != arr2[i]) { 53 | cerr << i << "-th element not the same (" << arr[i] << " vs " << arr2[i] 54 | << ")" << endl; 55 | return -1; 56 | } 57 | } 58 | cout << "check decompress succeed" << endl; 59 | return 0; 60 | } 61 | 62 | int main() { 63 | // check memory leak in VC++ 64 | #if defined(_MSC_VER) && defined(_DEBUG) 65 | int tmpFlag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); 66 | tmpFlag |= _CRTDBG_LEAK_CHECK_DF; 67 | _CrtSetDbgFlag(tmpFlag); 68 | //_CrtSetBreakAlloc(368); 69 | #endif 70 | 71 | cout << "check uint16_t" << endl; 72 | int ret = test(); 73 | if (ret != 0) return ret; 74 | 75 | cout << "check uint32_t" << endl; 76 | ret = test(); 77 | if (ret != 0) return ret; 78 | 79 | cout << "check uint64_t" << endl; 80 | ret = test();
81 | if (ret != 0) return ret; 82 | return 0; 83 | } 84 | -------------------------------------------------------------------------------- /test/pario/test_data_iter.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : test_data_iter.cc 3 | * Created By : yuewu 4 | * Creation Date : [2016-02-12 18:11] 5 | * Last Modified : [2016-02-12 23:49] 6 | * Description : read a data set through DataIter and print every mini-batch 7 | **********************************************************************************/ 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "sol/pario/data_iter.h" 14 | 15 | using namespace sol; 16 | using namespace sol::pario; 17 | using namespace std; 18 | // Arguments: [path dtype] (both or neither); defaults to data/a1a in svm format. 19 | int main(int argc, char** argv) { 20 | // check memory leak in VC++ 21 | #if defined(_MSC_VER) && defined(_DEBUG) 22 | int tmpFlag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); 23 | tmpFlag |= _CRTDBG_LEAK_CHECK_DF; 24 | _CrtSetDbgFlag(tmpFlag); 25 | //_CrtSetBreakAlloc(368); 26 | #endif 27 | 28 | string path = "data/a1a"; 29 | string dtype = "svm"; 30 | if (argc == 3) { 31 | path = argv[1]; 32 | dtype = argv[2]; 33 | } 34 | 35 | DataIter iter; 36 | int ret = iter.AddReader(path, dtype); 37 | if (ret != Status_OK) return ret; // fail instead of silently iterating over nothing 38 | MiniBatch* mb = nullptr; 39 | while (true) { 40 | mb = iter.Next(mb); 41 | if (mb == nullptr) break; 42 | 43 | fprintf(stdout, "mini-batch size: %d\n", mb->size()); 44 | for (int i = 0; i < mb->size(); ++i) { 45 | DataPoint& dp = (*mb)[i]; 46 | fprintf(stdout, "%d", dp.label()); 47 | for (size_t d = 0; d < dp.size(); ++d) { 48 | fprintf(stdout, " %d:%f", dp.index(d), dp.feature(d)); 49 | } 50 | fprintf(stdout, "\n"); 51 | } 52 | } 53 | return 0; 54 | } 55 | -------------------------------------------------------------------------------- /test/pario/test_data_point.cc: -------------------------------------------------------------------------------- 1 | 
/********************************************************************************* 2 | * File Name : test_data_point.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-29 13:52] 5 | * Last Modified : [2015-11-12 18:11] 6 | * Description : add random (index, value) pairs to a DataPoint and print it before and after Sort() 7 | **********************************************************************************/ 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | using namespace sol; 16 | using namespace sol::pario; 17 | using namespace std; 18 | 19 | int main() { 20 | // check memory leak in VC++ 21 | #if defined(_MSC_VER) && defined(_DEBUG) 22 | int tmpFlag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); 23 | tmpFlag |= _CRTDBG_LEAK_CHECK_DF; 24 | _CrtSetDbgFlag(tmpFlag); 25 | //_CrtSetBreakAlloc(368); 26 | #endif 27 | 28 | size_t N = 4; 29 | 30 | DataPoint pt; 31 | 32 | for (size_t i = 0; i < N; ++i) { 33 | index_t idx = rand() % N + 1; // index in [1, N]; rand() is never seeded here and duplicate indexes are possible 34 | real_t feat = (real_t)(rand() % N); 35 | cout << "add new feat: " << idx << ": " << feat << endl; 36 | pt.AddNewFeat(idx, feat); 37 | } 38 | cout << "original data point" << endl; 39 | for (size_t i = 0; i < pt.size(); ++i) { 40 | cout << pt.indexes()[i] << ": " << pt.features()[i] << endl; 41 | } 42 | 43 | cout << "sort" << endl; 44 | pt.Sort(); 45 | for (size_t i = 0; i < pt.size(); ++i) { 46 | cout << pt.indexes()[i] << ": " << pt.features()[i] << endl; 47 | } 48 | 49 | return 0; 50 | } 51 | -------------------------------------------------------------------------------- /test/pario/test_file_reader.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : test_file_reader.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-17 00:26] 5 | * Last Modified : [2015-10-21 11:43] 6 | * Description : test file reader 7 | **********************************************************************************/ 8 | #include 9 | #include 10 | #include 11 | #include 12
| 13 | #include "sol/pario/file_reader.h" 14 | #include "sol/util/error_code.h" 15 | 16 | using namespace sol; 17 | using namespace sol::pario; 18 | using namespace std; 19 | // Reads the file 10 times line by line, then 10 times in file_len/2 chunks, rewinding in between; returns 0 on success, -1 on failure. 20 | int main(int argc, char** args) { 21 | // check memory leak in VC++ 22 | #if defined(_MSC_VER) && defined(_DEBUG) 23 | int tmpFlag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); 24 | tmpFlag |= _CRTDBG_LEAK_CHECK_DF; 25 | _CrtSetDbgFlag(tmpFlag); 26 | //_CrtSetBreakAlloc(368); 27 | #endif 28 | 29 | FileReader reader; 30 | string path = "data/a1a"; 31 | if (argc > 1) { 32 | path = args[1]; 33 | } 34 | 35 | reader.Open(path.c_str(), "r"); 36 | if (reader.Good() == false) { 37 | cerr << "open file (" << path << ") failed\n"; 38 | return -1; 39 | } 40 | cout << "test readline\n"; 41 | int buf_len = 1024; 42 | char* buf = (char*)malloc(buf_len); // malloc, not new[]: buf is resized with realloc() and released with free() below 43 | size_t file_len = 0; 44 | for (int i = 0; i < 10; ++i) { 45 | cerr << "\tread round " << i << "\t"; 46 | file_len = 0; 47 | while (reader.ReadLine(buf, buf_len) == Status_OK) { 48 | file_len += strlen(buf); 49 | } 50 | if (reader.Good()) { 51 | cout << file_len << " bytes read\n"; 52 | reader.Rewind(); 53 | } 54 | } 55 | int status = 0; 56 | if (reader.Good() == false) { 57 | status = -1; 58 | } else { 59 | reader.Rewind(); 60 | } 61 | if (status == 0) { 62 | cerr << "test read, file length: " << file_len << "\n"; 63 | char* new_buf = (char*)realloc(buf, file_len + 1); if (new_buf == nullptr) { free(buf); return -1; } buf = new_buf; // +1 keeps the request non-zero: realloc(p, 0) may free p and return NULL 64 | for (int i = 0; i < 10; ++i) { 65 | cerr << "\tread round " << i << "\t"; 66 | while (reader.Read(buf, file_len / 2) == Status_OK) { 67 | } 68 | if (reader.Good()) { 69 | cerr << file_len << " bytes read\n"; 70 | reader.Rewind(); 71 | } 72 | } 73 | if (reader.Good() == false) { 74 | status = -1; 75 | } else { 76 | reader.Rewind(); 77 | } 78 | } 79 | 80 | free(buf); 81 | cerr << "program exited with code " << status << "\n"; 82 | return status; 83 | } 84 | -------------------------------------------------------------------------------- /test/pario/test_file_writer.cc: 
-------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : test_file_writer.cc 3 | * Created By : yuewu 4 | * Creation Date : [2015-10-17 11:13] 5 | * Last Modified : [2015-10-17 11:25] 6 | * Description : test file writer 7 | **********************************************************************************/ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "sol/pario/file_writer.h" 15 | #include "sol/pario/file_reader.h" 16 | #include "sol/util/error_code.h" 17 | 18 | using namespace sol; 19 | using namespace sol::pario; 20 | using namespace std; 21 | // Copies in_path to out_path line by line through FileReader/FileWriter; returns 0 on success, -1 on failure. 22 | int main(int argc, char** args) { 23 | // check memory leak in VC++ 24 | #if defined(_MSC_VER) && defined(_DEBUG) 25 | int tmpFlag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); 26 | tmpFlag |= _CRTDBG_LEAK_CHECK_DF; 27 | _CrtSetDbgFlag(tmpFlag); 28 | //_CrtSetBreakAlloc(368); 29 | #endif 30 | 31 | FileWriter writer; 32 | string out_path = "data/a1a.out"; 33 | if (argc > 1) { 34 | out_path = args[1]; 35 | } 36 | 37 | string in_path = "data/a1a"; 38 | if (argc > 2) { 39 | in_path = args[2]; 40 | } 41 | 42 | FileReader reader(in_path.c_str(), "r"); 43 | if (reader.Good() == false) { 44 | fprintf(stderr, "open file (%s) failed\n", in_path.c_str()); 45 | return -1; 46 | } 47 | 48 | writer.Open(out_path.c_str(), "w"); 49 | if (writer.Good() == false) { 50 | fprintf(stderr, "open file (%s) failed\n", out_path.c_str()); 51 | return -1; 52 | } 53 | int buf_len = 1024; 54 | char* buf = new char[buf_len]; // NOTE(review): if FileReader::ReadLine grows buf with realloc(), this pair must become malloc/free -- confirm 55 | size_t file_len = 0; 56 | while (reader.ReadLine(buf, buf_len) == Status_OK) { 57 | size_t line_len = strlen(buf); file_len += line_len; 58 | writer.Write(buf, line_len); 59 | } 60 | int status = 0; 61 | if (writer.Good() == false || reader.Good() == false) { 62 | status = -1; 63 | } else { 64 | cerr << file_len << " bytes read and written\n"; 65 | } 66 | 67 | delete[] buf; 68 | 
fprintf(stderr, "program exited with code %d\n", status); 69 | return status; 70 | } 71 | -------------------------------------------------------------------------------- /tools/analyze.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : analyze.cc 3 | * Created By : yuewu 4 | * Description : analyze data 5 | **********************************************************************************/ 6 | #include 7 | #include 8 | 9 | using namespace sol; 10 | using namespace std; 11 | // Options: -i input -s input_type [-o output, default "-"]; the work is delegated to analyze(). 12 | int main(int argc, char** argv) { 13 | // check memory leak in VC++ 14 | #if defined(_MSC_VER) && defined(_DEBUG) 15 | int tmpFlag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); 16 | tmpFlag |= _CRTDBG_LEAK_CHECK_DF; 17 | _CrtSetDbgFlag(tmpFlag); 18 | //_CrtSetBreakAlloc(231); 19 | #endif 20 | 21 | cmdline::parser parser; 22 | parser.add("input", 'i', "input data path", true); 23 | parser.add("input_type", 's', "input data type", true); 24 | parser.add("output", 'o', "output data path", false, "", "-"); 25 | 26 | parser.parse_check(argc, argv); 27 | 28 | return analyze(parser.get("input"), parser.get("input_type"), 29 | parser.get("output")); 30 | } 31 | -------------------------------------------------------------------------------- /tools/concat.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : concat.cc 3 | * Created By : yuewu 4 | * Description : concatenate datasets 5 | **********************************************************************************/ 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | using namespace sol; 18 | using namespace sol::pario; 19 | using namespace std; 20 | // Concatenates the ';'-separated inputs into one output; for csv output a first pass finds the largest feature dimension. 21 | int main(int argc, char** argv) { 22 | //
check memory leak in VC++ 23 | #if defined(_MSC_VER) && defined(_DEBUG) 24 | int tmpFlag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); 25 | tmpFlag |= _CRTDBG_LEAK_CHECK_DF; 26 | _CrtSetDbgFlag(tmpFlag); 27 | //_CrtSetBreakAlloc(231); 28 | #endif 29 | 30 | cmdline::parser parser; 31 | parser.add("input", 'i', "input data paths, separated by ';'", true); 32 | parser.add("input_type", 's', "input data type", true); 33 | parser.add("output", 'o', "output data path", true); 34 | parser.add("output_type", 'd', "output data type"); 35 | 36 | parser.parse_check(argc, argv); 37 | 38 | string src_path = parser.get("input"); 39 | string src_type = parser.get("input_type"); 40 | string dst_path = parser.get("output"); 41 | string dst_type = parser.get("output_type"); 42 | 43 | DataWriter* writer = DataWriter::Create(dst_type); 44 | if (writer == nullptr) { 45 | return Status_Invalid_Argument; 46 | } 47 | int ret = writer->Open(dst_path); 48 | if (ret != Status_OK) { 49 | delete writer; 50 | return ret; 51 | } 52 | 53 | DataIter iter; 54 | MiniBatch* mb = nullptr; 55 | const vector& input_list = split(src_path, ';'); 56 | 57 | if (dst_type == "csv") { 58 | cout << "analyzing feature dimension\n"; 59 | index_t feat_dim = 0; 60 | for (const string& input_path : input_list) { 61 | ret = iter.AddReader(input_path, src_type); 62 | if (ret != Status_OK) { writer->Close(); delete writer; return ret; } // release the writer on every early exit 63 | } 64 | while (true) { 65 | mb = iter.Next(mb); 66 | if (mb == nullptr) break; 67 | for (int i = 0; i < mb->size(); ++i) { 68 | DataPoint& dp = (*mb)[i]; 69 | if (feat_dim < dp.dim()) feat_dim = dp.dim(); 70 | } 71 | } 72 | cout << "total dimension: " << feat_dim << "\n"; 73 | writer->SetExtraInfo((char*)(&feat_dim)); 74 | } 75 | 76 | size_t data_num = 0; 77 | size_t print_thresh = 10000; 78 | for (const string& input_path : input_list) { 79 | ret = iter.AddReader(input_path, src_type); 80 | if (ret != Status_OK) { writer->Close(); delete writer; return ret; } // release the writer on every early exit 81 | } 82 | while (true) { 83 | mb = iter.Next(mb); 84 | if (mb == nullptr) break; 85 | data_num +=
mb->size(); 86 | for (int i = 0; i < mb->size(); ++i) { 87 | writer->Write((*mb)[i]); 88 | } 89 | 90 | if (data_num > print_thresh) { 91 | cout << data_num << " examples concatenated\r"; 92 | print_thresh += 10000; 93 | } 94 | } 95 | 96 | writer->Close(); 97 | delete writer; 98 | cout << data_num << " examples concatenated to " << dst_path << "\n"; 99 | return ret; 100 | } 101 | -------------------------------------------------------------------------------- /tools/converter.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : converter.cc 3 | * Created By : yuewu 4 | * Creation Date : [2016-02-12 18:29] 5 | * Last Modified : [2016-11-19 18:23] 6 | * Description : convert data formats 7 | **********************************************************************************/ 8 | 9 | #include 10 | #include 11 | 12 | using namespace sol; 13 | using namespace std; 14 | // Converts input data to another format; passing -b enables binarization with the given threshold. 15 | int main(int argc, char** argv) { 16 | // check memory leak in VC++ 17 | #if defined(_MSC_VER) && defined(_DEBUG) 18 | int tmpFlag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); 19 | tmpFlag |= _CRTDBG_LEAK_CHECK_DF; 20 | _CrtSetDbgFlag(tmpFlag); 21 | //_CrtSetBreakAlloc(231); 22 | #endif 23 | 24 | cmdline::parser parser; 25 | parser.add("input", 'i', "input data path", true); 26 | parser.add("input_type", 's', "input data type", true); 27 | parser.add("output", 'o', "output data path", true); 28 | parser.add("output_type", 'd', "output data type", true); 29 | parser.add("binary_thresh", 'b', "threshold to binarize the values", false); 30 | 31 | parser.parse_check(argc, argv); 32 | 33 | bool binarize = false; 34 | float binary_threshold = 0; 35 | if (parser.exist("binary_thresh")) { 36 | binarize = true; 37 | binary_threshold = parser.get("binary_thresh"); 38 | } 39 | return convert(parser.get("input"), parser.get("input_type"), 40 | parser.get("output"), 41 | 
parser.get("output_type"), 42 | binarize, binary_threshold); 43 | } 44 | -------------------------------------------------------------------------------- /tools/shuffle.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : shuffle.cc 3 | * Created By : yuewu 4 | * Description : shuffle file 5 | **********************************************************************************/ 6 | 7 | #include 8 | #include 9 | 10 | using namespace sol; 11 | using namespace std; 12 | // Options: -i input -s input_type [-o output, default "-"] [-d output_type]; the work is delegated to shuffle(). 13 | int main(int argc, char** argv) { 14 | // check memory leak in VC++ 15 | #if defined(_MSC_VER) && defined(_DEBUG) 16 | int tmpFlag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); 17 | tmpFlag |= _CRTDBG_LEAK_CHECK_DF; 18 | _CrtSetDbgFlag(tmpFlag); 19 | //_CrtSetBreakAlloc(231); 20 | #endif 21 | 22 | cmdline::parser parser; 23 | parser.add("input", 'i', "input data path", true); 24 | parser.add("input_type", 's', "input data type", true); 25 | parser.add("output", 'o', "output data path", false, "", "-"); 26 | parser.add("output_type", 'd', "output data type", false, "", ""); 27 | 28 | parser.parse_check(argc, argv); 29 | 30 | return shuffle(parser.get("input"), parser.get("input_type"), 31 | parser.get("output"), 32 | parser.get("output_type")); 33 | } 34 | -------------------------------------------------------------------------------- /tools/split.cc: -------------------------------------------------------------------------------- 1 | /********************************************************************************* 2 | * File Name : split.cc 3 | * Created By : yuewu 4 | * Description : split file into folds 5 | **********************************************************************************/ 6 | 7 | #include 8 | #include 9 | 10 | using namespace sol; 11 | using namespace std; 12 | // Splits the input into 'fold' parts written under output_prefix, optionally shuffling first (-r); the work is delegated to split(). 13 | int main(int argc, char** argv) { 14 | // check memory leak in VC++ 15 | #if
defined(_MSC_VER) && defined(_DEBUG) 16 | int tmpFlag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); 17 | tmpFlag |= _CRTDBG_LEAK_CHECK_DF; 18 | _CrtSetDbgFlag(tmpFlag); 19 | //_CrtSetBreakAlloc(231); 20 | #endif 21 | 22 | cmdline::parser parser; 23 | parser.add("input", 'i', "input data path", true); 24 | parser.add("input_type", 's', "input data type", true); 25 | parser.add("fold", 'n', "split number", true); 26 | parser.add("output_prefix", 'o', "output prefix", true); 27 | parser.add("output_type", 'd', "output data type"); 28 | parser.add("shuffle", 'r', "shuffle the input file"); 29 | 30 | parser.parse_check(argc, argv); 31 | 32 | return split(parser.get("input"), parser.get("input_type"), 33 | parser.get("fold"), parser.get("output_prefix"), 34 | parser.get("output_type"), 35 | parser.exist("shuffle")); 36 | } 37 | --------------------------------------------------------------------------------