├── .gitignore ├── Makefile.am ├── data ├── sketches_svg.zip.sha1 └── map_id_label.txt ├── autogen.sh ├── src ├── svg.cpp ├── Makefile.am ├── util.cpp ├── types.h ├── util.h ├── conv.h ├── vocab.cpp ├── svg.h ├── classify.cpp ├── kmeans.h ├── cats.cpp ├── cross.cpp ├── features.h ├── io.h ├── svm.h └── gui.cpp ├── util ├── get-data ├── data-fold ├── run-vocab ├── run-classify ├── run-cats └── run-cross ├── LICENSE ├── configure.ac ├── m4 └── ax_cxx_compile_stdcxx_11.m4 └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /Makefile.am: -------------------------------------------------------------------------------- 1 | SUBDIRS = src 2 | -------------------------------------------------------------------------------- /data/sketches_svg.zip.sha1: -------------------------------------------------------------------------------- 1 | 0f11a412b2be919109fe125103a04a010703341a sketches_svg.zip 2 | -------------------------------------------------------------------------------- /autogen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | aclocal -I m4 \ 4 | && automake --add-missing \ 5 | && autoconf 6 | -------------------------------------------------------------------------------- /src/svg.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // A static class for automatically initializing GLib 4 | struct glib_initializer { 5 | glib_initializer() { 6 | g_type_init(); 7 | } 8 | }; 9 | 10 | const glib_initializer glib_init; 11 | 12 | -------------------------------------------------------------------------------- /util/get-data: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | ZIP='sketches_svg.zip' 6 | URL="http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/$ZIP" 7 | 8 | SHA1="$ZIP.sha1" 9 | 10 | pushd data 11 | wget "$URL" 12 | sha1sum -c "$SHA1" 13 | unzip -q "$ZIP" 14 | rm "$ZIP" 15 | popd 16 | -------------------------------------------------------------------------------- /util/data-fold: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | NFOLDS=8 6 | 7 | op='==' 8 | n="$1" 9 | 10 | # Handle folds negated with `~'. 11 | [[ "$n" == ~* ]] && op='!=' 12 | n="${n##\~}" 13 | 14 | if [ "$n" -ge "$NFOLDS" ] 15 | then 16 | echo "${0##*/}: Fold out of range: $n" 1>&2 17 | exit 1 18 | fi 19 | 20 | find data/svg/ -type f -name '*.svg' \ 21 | | sort -n -t '/' -k 4 \ 22 | | awk "(NR - 1) % $NFOLDS $op $n { print; }" 23 | -------------------------------------------------------------------------------- /src/Makefile.am: -------------------------------------------------------------------------------- 1 | noinst_PROGRAMS = cats classify cross gui vocab 2 | 3 | AM_CXXFLAGS = $(CAIRO_CFLAGS) $(FFTW_CFLAGS) $(GLIB_CFLAGS) $(GTKMM_CFLAGS) $(LIBRSVG_CFLAGS) $(OPENMP_CXXFLAGS) 4 | AM_LDFLAGS = $(CAIRO_LIBS) $(FFTW_LIBS) $(GLIB_LIBS) $(GTKMM_LIBS) $(LIBRSVG_LIBS) 5 | 6 | cats_SOURCES = cats.cpp svg.cpp util.cpp 7 | classify_SOURCES = classify.cpp svg.cpp util.cpp 8 | cross_SOURCES = cross.cpp svg.cpp util.cpp 9 | gui_SOURCES = gui.cpp util.cpp 10 | vocab_SOURCES = vocab.cpp svg.cpp util.cpp 11 | -------------------------------------------------------------------------------- /util/run-vocab: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | # Default arguments 6 | n=1000000 7 | vocab='data/vocab.out' 8 | 9 | # Process the command-line arguments. 10 | while [ $# -gt 0 ] 11 | do 12 | case "$1" in 13 | -n) 14 | n="$2" 15 | shift 16 | ;; 17 | -*) 18 | echo "${0##*/}: Unrecognized option: \`$1'" 1>&2 19 | exit 1 20 | ;; 21 | *) 22 | break 23 | ;; 24 | esac 25 | done 26 | 27 | [ $# -gt 0 ] && vocab="$1" 28 | 29 | find data/svg/ -type f -name '*.svg' | 30 | build/src/vocab \ 31 | -n "$n" \ 32 | "$vocab" 33 | -------------------------------------------------------------------------------- /src/util.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "util.h" 4 | 5 | dlib::matrix< float, 3, 3 > sobel_x_init() { 6 | dlib::matrix< float, 3, 3 > m; 7 | m = 8 | -1.f, 0.f, 1.f, 9 | -2.f, 0.f, 2.f, 10 | -1.f, 0.f, 1.f; 11 | return m; 12 | } 13 | 14 | const dlib::matrix< float, 3, 3 > sobel_x = sobel_x_init(); 15 | 16 | dlib::matrix< float, 3, 3 > sobel_y_init() { 17 | dlib::matrix< float, 3, 3 > m; 18 | m = 19 | -1.f, -2.f, -1.f, 20 | 0.f, 0.f, 0.f, 21 | 1.f, 2.f, 1.f; 22 | return m; 23 | } 24 | 25 | const dlib::matrix< float, 3, 3 > sobel_y = sobel_y_init(); 26 | 27 | -------------------------------------------------------------------------------- /util/run-classify: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | # Default arguments 6 | vocab='data/vocab.out' 7 | map='data/map_id_label.txt' 8 | classifier='ova' 9 | cats='data/cats.out' 10 | fold='0' 11 | 12 | # Process the command-line arguments. 13 | while [ $# -gt 0 ] 14 | do 15 | case "$1" in 16 | -v) 17 | vocab="$2" 18 | shift 19 | ;; 20 | -m) 21 | map="$2" 22 | shift 23 | ;; 24 | -c) 25 | classifier="$2" 26 | shift 27 | ;; 28 | --fold) 29 | fold="$2" 30 | shift 31 | ;; 32 | -*) 33 | echo "${0##*/}: Unrecognized option: \`$1'" 1>&2 34 | exit 1 35 | ;; 36 | *) 37 | break 38 | ;; 39 | esac 40 | shift 41 | done 42 | 43 | [ $# -gt 0 ] && cats="$1" 44 | 45 | util/data-fold "$fold" | 46 | build/src/classify \ 47 | -v "$vocab" \ 48 | -m "$map" \ 49 | -c "$classifier" \ 50 | "$cats" 51 | -------------------------------------------------------------------------------- /util/run-cats: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | # Default arguments 6 | vocab='data/vocab.out' 7 | map='data/map_id_label.txt' 8 | classifier='ova' 9 | gamma='17.8' 10 | C='3.2' 11 | cats='data/cats.out' 12 | fold='~0' 13 | 14 | # Process the command-line arguments. 15 | while [ $# -gt 0 ] 16 | do 17 | case "$1" in 18 | -v) 19 | vocab="$2" 20 | shift 21 | ;; 22 | -m) 23 | map="$2" 24 | shift 25 | ;; 26 | -c) 27 | classifier="$2" 28 | shift 29 | ;; 30 | -g) 31 | gamma="$2" 32 | shift 33 | ;; 34 | -C) 35 | C="$2" 36 | shift 37 | ;; 38 | --fold) 39 | fold="$2" 40 | shift 41 | ;; 42 | -*) 43 | echo "${0##*/}: Unrecognized option: \`$1'" 1>&2 44 | exit 1 45 | ;; 46 | *) 47 | break 48 | ;; 49 | esac 50 | shift 51 | done 52 | 53 | [ $# -gt 0 ] && cats="$1" 54 | 55 | util/data-fold "$fold" | 56 | build/src/cats \ 57 | -v "$vocab" \ 58 | -m "$map" \ 59 | -c "$classifier" \ 60 | -g "$gamma" \ 61 | -C "$C" \ 62 | "$cats" 63 | -------------------------------------------------------------------------------- /util/run-cross: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | # Default arguments 6 | folds=8 7 | vocab='data/vocab.out' 8 | map='data/map_id_label.txt' 9 | classifier='ova' 10 | gamma='17.8' 11 | C='3.2' 12 | conf='data/conf.out' 13 | 14 | # Process the command-line arguments. 15 | while [ $# -gt 0 ] 16 | do 17 | case "$1" in 18 | -f) 19 | folds="$2" 20 | shift 21 | ;; 22 | -v) 23 | vocab="$2" 24 | shift 25 | ;; 26 | -m) 27 | map="$2" 28 | shift 29 | ;; 30 | -c) 31 | classifier="$2" 32 | shift 33 | ;; 34 | -g) 35 | gamma="$2" 36 | shift 37 | ;; 38 | -C) 39 | C="$2" 40 | shift 41 | ;; 42 | -*) 43 | echo "${0##*/}: Unrecognized option: \`$1'" 1>&2 44 | exit 1 45 | ;; 46 | *) 47 | break 48 | ;; 49 | esac 50 | shift 51 | done 52 | 53 | [ $# -gt 0 ] && conf="$1" 54 | 55 | find data/svg/ -type f -name '*.svg' | 56 | build/src/cross \ 57 | -f "$folds" \ 58 | -v "$vocab" \ 59 | -m "$map" \ 60 | -c "$classifier" \ 61 | -g "$gamma" \ 62 | -C "$C" \ 63 | "$conf" 64 | -------------------------------------------------------------------------------- /src/types.h: -------------------------------------------------------------------------------- 1 | #ifndef TYPES_H 2 | #define TYPES_H 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "features.h" 11 | #include "svm.h" 12 | 13 | // Preprocessing 14 | typedef feature_desc_extractor< float, 256 > feature_desc_extractor_type; 15 | typedef feature_desc_extractor_type::image_type image_type; 16 | typedef feature_desc_extractor_type::desc_type feature_desc_type; 17 | typedef std::vector< feature_desc_type > vocab_type; 18 | typedef dlib::matrix< float, 500, 1 > feature_hist_type; 19 | 20 | // Classification 21 | typedef dlib::radial_basis_kernel< feature_hist_type > kernel_type; 22 | typedef dlib::svm_c_trainer< kernel_type > trainer_type; 23 | 24 | typedef one_vs_all_trainer2< dlib::any_trainer< feature_hist_type, float >, 25 | int, true > ova_trainer_type; 26 | typedef dlib::one_vs_all_decision_function< ova_trainer_type, 27 | dlib::decision_function< kernel_type > > ova_df_type; 28 | 29 | typedef one_vs_one_trainer2< dlib::any_trainer< feature_hist_type, float >, 30 | int, true > ovo_trainer_type; 31 | typedef dlib::one_vs_one_decision_function< ovo_trainer_type, 32 | dlib::decision_function< kernel_type > > ovo_df_type; 33 | 34 | typedef dlib::type_safe_union< ova_df_type, ovo_df_type > df_type; 35 | 36 | #endif 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013 Anthony DeRossi 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | * Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above 13 | copyright notice, this list of conditions and the following 14 | disclaimer in the documentation and/or other materials provided 15 | with the distribution. 16 | 17 | * Neither the name of the author nor the names of other 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | -------------------------------------------------------------------------------- /configure.ac: -------------------------------------------------------------------------------- 1 | AC_INIT([sketchrec], 0.1) 2 | AC_PREREQ(2.69) 3 | 4 | AC_LANG([C++]) 5 | 6 | AM_INIT_AUTOMAKE([1.12 foreign no-define nostdinc]) 7 | 8 | : ${CXXFLAGS="-pedantic -Wall -g -O3 -funroll-loops"} 9 | 10 | # Check for programs. 11 | AC_PROG_CXX 12 | AC_PROG_CXXCPP 13 | PKG_PROG_PKG_CONFIG 14 | 15 | # Check for libraries. 16 | PKG_CHECK_MODULES(CAIRO, [cairo >= 1.10]) 17 | PKG_CHECK_MODULES(FFTW, [fftw3f >= 3.0]) 18 | PKG_CHECK_MODULES(GLIB, [glib-2.0 >= 2.0]) 19 | PKG_CHECK_MODULES(GTKMM, [gtkmm-2.4 >= 2.24]) 20 | PKG_CHECK_MODULES(LIBRSVG, [librsvg-2.0 >= 2.0]) 21 | 22 | # AX_LIB_DLIB([MIN-VERSION],[ACTION-IF-SUCCESS],[ACTION-IF-FAILURE]) 23 | # ------------------------------------------------------------------ 24 | # Check for dlib with at least major version MIN-VERSION. On success, set 25 | # HAVE_DLIB and execute ACTION-IF-SUCCESS, otherwise execute 26 | # ACTION-IF-FAILURE. 27 | AC_DEFUN([AX_LIB_DLIB], [ 28 | AC_CACHE_CHECK([for dlib], [ax_cv_have_dlib], [ 29 | ax_cv_have_dlib=no 30 | AC_LANG_PUSH([C++]) 31 | AC_PREPROC_IFELSE([ 32 | AC_LANG_PROGRAM([ 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | #if !defined(DLIB_MAJOR_VERSION) || DLIB_MAJOR_VERSION < $1 39 | #error 40 | #endif 41 | ]) 42 | ], [ax_cv_have_dlib=yes], []) 43 | AC_LANG_POP 44 | ]) 45 | if test "$ax_cv_have_dlib" = yes; then 46 | AC_DEFINE([HAVE_DLIB], [1], [Define if dlib >= $1 is present.]) 47 | $2 48 | else 49 | $3 50 | fi 51 | ]) 52 | 53 | AX_LIB_DLIB([18], , [AC_MSG_FAILURE([dlib 18 or newer is required])]) 54 | 55 | # Check for compiler characteristics. 56 | AC_HEADER_ASSERT 57 | AC_OPENMP 58 | AX_CXX_COMPILE_STDCXX_11([noext], [mandatory]) 59 | 60 | AC_CONFIG_FILES([Makefile src/Makefile]) 61 | 62 | AC_OUTPUT 63 | -------------------------------------------------------------------------------- /src/util.h: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_H 2 | #define UTIL_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | // 3x3 Sobel filter kernels 11 | extern const dlib::matrix< float, 3, 3 > sobel_x, sobel_y; 12 | 13 | // Normalize a vector using the L1-norm. 14 | template< class Exp > 15 | const dlib::matrix_op< dlib::op_normalize< Exp > > l1_normalize( 16 | const dlib::matrix_exp< Exp > &m) { 17 | typedef dlib::op_normalize< Exp > op; 18 | 19 | typename Exp::type s = sum(m); 20 | if (s != 0.) 21 | s = 1. / s; 22 | 23 | return dlib::matrix_op< op >(op(m.ref(), s)); 24 | } 25 | 26 | // Convert cartesian x- and y-magnitude images to radial magnitude and 27 | // orientation images. 28 | template< class T, long NR1, long NC1, long NR2, long NC2, long NR3, long NC3, 29 | long NR4, long NC4> 30 | void cart2polar(const dlib::matrix< T, NR1, NC1 > &x, 31 | const dlib::matrix< T, NR2, NC2 > &y, dlib::matrix< T, NR3, NC3 > &r, 32 | dlib::matrix< T, NR4, NC4 > &theta) { 33 | const long rows = x.nr(); 34 | const long cols = x.nc(); 35 | assert(y.nr() == rows); 36 | assert(y.nc() == cols); 37 | r.set_size(rows, cols); 38 | theta.set_size(rows, cols); 39 | 40 | for (long j = 0; j < rows; ++j) { 41 | for (long i = 0; i < cols; ++i) { 42 | r(j, i) = std::hypot(x(j, i), y(j, i)); 43 | theta(j, i) = std::atan2(y(j, i), x(j, i)); 44 | } 45 | } 46 | } 47 | 48 | // A stream sampling algorithm for choosing n elements from a stream uniformly 49 | // at random 50 | template< class T > 51 | struct stream_sample { 52 | typedef typename std::vector< T >::size_type size_type; 53 | 54 | stream_sample(size_type n_) : n(n_), i(0) { 55 | samples.reserve(n); 56 | } 57 | 58 | template< class Generator > 59 | void push_back(Generator &g, const T& x) { 60 | if (i < n) { 61 | // Select the first n elements. 62 | samples.push_back(x); 63 | } 64 | else { 65 | // Select the remaining elements with probability n / (i + 1). 66 | std::uniform_int_distribution< typename std::vector< T >::size_type > 67 | uniform_i(0, i); 68 | if (uniform_i(g) < n) { 69 | std::uniform_int_distribution< typename std::vector< T >::size_type > 70 | uniform_n(0, n - 1); 71 | samples[uniform_n(g)] = x; 72 | } 73 | } 74 | ++i; 75 | } 76 | 77 | const std::vector< T > &get() const { 78 | return samples; 79 | } 80 | 81 | private: 82 | const size_type n; 83 | size_type i; 84 | std::vector< T > samples; 85 | }; 86 | 87 | #endif 88 | -------------------------------------------------------------------------------- /src/conv.h: -------------------------------------------------------------------------------- 1 | #ifndef CONV_H 2 | #define CONV_H 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | template< class T > 11 | struct fftw_helper; 12 | 13 | template<> 14 | struct fftw_helper< float > { 15 | typedef fftwf_plan plan_type; 16 | 17 | static inline plan_type plan_dft_r2c_2d(int n0, int n1, float *in, 18 | fftwf_complex *out, unsigned flags) { 19 | return fftwf_plan_dft_r2c_2d(n0, n1, in, out, flags); 20 | } 21 | 22 | static inline plan_type plan_dft_c2r_2d(int n0, int n1, fftwf_complex *in, 23 | float *out, unsigned flags) { 24 | return fftwf_plan_dft_c2r_2d(n0, n1, in, out, flags); 25 | } 26 | 27 | static inline void execute_dft_r2c(const plan_type p, float *in, 28 | fftwf_complex *out) { 29 | fftwf_execute_dft_r2c(p, in, out); 30 | } 31 | 32 | static inline void execute_dft_c2r(const plan_type p, fftwf_complex *in, 33 | float *out) { 34 | fftwf_execute_dft_c2r(p, in, out); 35 | } 36 | }; 37 | 38 | // A class for FFT-based convolution with a fixed kernel 39 | template< class T, long NR, long NC, bool Verbose = false > 40 | struct conv_fft { 41 | conv_fft(const dlib::matrix< T, NR, NC > &h) { 42 | dlib::matrix< T, NR, NC > x; 43 | dlib::matrix< std::complex< T >, NR, NC_Comp > xf; 44 | 45 | plan = fftw_helper< T >::plan_dft_r2c_2d(NR, NC, &x(0, 0), 46 | reinterpret_cast< T (*)[2] >(&xf(0, 0)), FFTW_PATIENT); 47 | inv_plan = fftw_helper< T >::plan_dft_c2r_2d(NR, NC, 48 | reinterpret_cast< T (*)[2] >(&xf(0, 0)), &x(0, 0), FFTW_PATIENT); 49 | 50 | if (Verbose) { 51 | std::cout << "FFT plan:\n"; 52 | fftwf_print_plan(plan); 53 | std::cout << '\n'; 54 | 55 | std::cout << "Inverse FFT plan:\n"; 56 | fftwf_print_plan(inv_plan); 57 | std::cout << '\n'; 58 | } 59 | 60 | // Store the transformed kernel. 61 | fftw_helper< T >::execute_dft_r2c(plan, const_cast< T * >(&h(0, 0)), 62 | reinterpret_cast< T (*)[2] >(&hf(0, 0))); 63 | hf /= NR * NC; 64 | } 65 | 66 | ~conv_fft() { 67 | fftwf_destroy_plan(plan); 68 | fftwf_destroy_plan(inv_plan); 69 | } 70 | 71 | void operator()(dlib::matrix< T, NR, NC > &x) const { 72 | dlib::matrix< std::complex< T >, NR, NC_Comp > xf; 73 | 74 | fftw_helper< T >::execute_dft_r2c(plan, const_cast< T * >(&x(0, 0)), 75 | reinterpret_cast< T (*)[2] >(&xf(0, 0))); 76 | xf = pointwise_multiply(xf, hf); 77 | 78 | fftw_helper< T >::execute_dft_c2r(inv_plan, 79 | reinterpret_cast< T (*)[2] >(&xf(0, 0)), &x(0, 0)); 80 | } 81 | 82 | private: 83 | static const long NC_Comp = NC / 2 + 1; 84 | 85 | typename fftw_helper< T >::plan_type plan, inv_plan; 86 | 87 | dlib::matrix< std::complex< T >, NR, NC_Comp > hf; 88 | }; 89 | 90 | #endif 91 | -------------------------------------------------------------------------------- /src/vocab.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | #include "features.h" 11 | #include "io.h" 12 | #include "kmeans.h" 13 | #include "svg.h" 14 | #include "types.h" 15 | #include "util.h" 16 | 17 | int main(int argc, char *argv[]) { 18 | typedef stream_sample< feature_desc_type > stream_sample_type; 19 | 20 | // Process the command-line arguments. 21 | typename stream_sample_type::size_type n = 1000000; 22 | const char *vocab_path = "vocab.out"; 23 | 24 | { 25 | int i; 26 | for (i = 1; i < argc; ++i) { 27 | if (!strcmp(argv[i], "-h")) { 28 | goto usage; 29 | } 30 | else if (!strcmp(argv[i], "-n")) { 31 | std::istringstream ss(argv[++i]); 32 | if (!(ss >> n)) 33 | goto usage; 34 | } 35 | else { 36 | break; 37 | } 38 | } 39 | 40 | if (i < argc) 41 | vocab_path = argv[i++]; 42 | 43 | if (i != argc) 44 | goto usage; 45 | } 46 | 47 | { 48 | // Extract features for all input files. 49 | std::vector< std::string > paths; 50 | std::string path; 51 | while (std::getline(std::cin, path)) 52 | paths.push_back(path); 53 | 54 | // Select a fixed number of random descriptors. 55 | std::random_device rd; 56 | std::mt19937 gen(rd()); 57 | stream_sample_type samples(n); 58 | 59 | #pragma omp parallel for schedule(dynamic) 60 | for (typename std::vector< std::string >::size_type i = 0; 61 | i < paths.size(); ++i) { 62 | const std::string &path = paths[i]; 63 | 64 | #pragma omp critical 65 | { 66 | std::cout << "Extracting features for " << path << " (" << i + 1 67 | << '/' << paths.size() << ")...\n"; 68 | } 69 | 70 | image_type image; 71 | load_svg(path.c_str(), image); 72 | image = 1. - image; 73 | 74 | std::vector< feature_desc_type > descs; 75 | extract_descriptors(image, descs); 76 | 77 | #pragma omp critical 78 | { 79 | for (const auto &desc : descs) 80 | samples.push_back(gen, desc); 81 | } 82 | } 83 | 84 | std::cout << "Got " << samples.get().size() << " descriptors\n"; 85 | 86 | // Generate a vocabulary for this data set. 87 | std::cout << "Clustering...\n"; 88 | 89 | static const long center_count = feature_hist_type::NR; 90 | std::cout << "Picking " << center_count << " initial centers...\n"; 91 | vocab_type vocab; 92 | kmeanspp< float >(gen, samples.get(), center_count, vocab); 93 | 94 | kmeans< float, feature_desc_type, true >(samples.get(), vocab); 95 | 96 | // Save the vocabulary. 97 | std::cout << "Saving vocabulary...\n"; 98 | { 99 | std::ofstream fs(vocab_path, std::ios::binary); 100 | serialize2(vocab, fs); 101 | } 102 | } 103 | 104 | return 0; 105 | 106 | usage: 107 | std::cerr << "Usage: " << argv[0] << " [-n sample-count] [vocab-file]\n"; 108 | return 1; 109 | } 110 | 111 | -------------------------------------------------------------------------------- /src/svg.h: -------------------------------------------------------------------------------- 1 | #ifndef SVG_H 2 | #define SVG_H 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | // An error while loading an image 12 | struct image_error : std::exception { 13 | virtual ~image_error() noexcept { 14 | } 15 | 16 | virtual const char *what() const noexcept { 17 | return "image error"; 18 | } 19 | }; 20 | 21 | // Destruction policies for Cairo objects 22 | // Ideally these would be wrapped in classes, but they aren't used enough. 23 | template< class T > struct cairo_delete {}; 24 | 25 | template<> struct cairo_delete< cairo_t > { 26 | void operator()(cairo_t *p) { 27 | cairo_destroy(p); 28 | } 29 | }; 30 | 31 | template<> struct cairo_delete< cairo_surface_t > { 32 | void operator()(cairo_surface_t *p) { 33 | cairo_surface_destroy(p); 34 | } 35 | }; 36 | 37 | // A destruction policy for GLib objects 38 | template< class T > struct glib_delete { 39 | void operator()(gpointer p) { 40 | g_object_unref(p); 41 | } 42 | }; 43 | 44 | // Load an SVG file, storing the rasterized image in a square matrix. 45 | template< class T, long N > 46 | void load_svg(const char *file, dlib::matrix< T, N, N > &image) { 47 | std::unique_ptr< cairo_surface_t, cairo_delete< cairo_surface_t > > 48 | surface(cairo_image_surface_create(CAIRO_FORMAT_RGB24, N, N)); 49 | if (!surface) 50 | throw image_error(); 51 | 52 | std::unique_ptr< cairo_t, cairo_delete< cairo_t > > 53 | cr(cairo_create(surface.get())); 54 | if (cairo_status(cr.get()) != CAIRO_STATUS_SUCCESS) 55 | throw image_error(); 56 | 57 | // Clear the buffer to white. 58 | cairo_set_source_rgb(cr.get(), 1., 1., 1.); 59 | cairo_paint(cr.get()); 60 | 61 | { 62 | GError *error; 63 | std::unique_ptr< RsvgHandle, glib_delete< RsvgHandle > > 64 | svg(rsvg_handle_new_from_file(file, &error)); 65 | if (!svg) 66 | throw image_error(); 67 | 68 | RsvgDimensionData dims; 69 | rsvg_handle_get_dimensions(svg.get(), &dims); 70 | 71 | // Loaded images must be square. 72 | if (dims.width != dims.height) 73 | throw image_error(); 74 | 75 | const double scale = static_cast< double >(N) / dims.width; 76 | cairo_scale(cr.get(), scale, scale); 77 | 78 | gboolean res; 79 | #pragma omp critical 80 | { 81 | res = rsvg_handle_render_cairo(svg.get(), cr.get()); 82 | } 83 | if (!res) 84 | throw image_error(); 85 | } 86 | 87 | cairo_surface_flush(surface.get()); 88 | 89 | // Store the image in the matrix. 90 | { 91 | const unsigned char *p = cairo_image_surface_get_data(surface.get()); 92 | const int stride = cairo_image_surface_get_stride(surface.get()); 93 | 94 | const unsigned char *q = p; 95 | for (long j = 0; j < N; ++j, q = p += stride) { 96 | for (long i = 0; i < N; ++i, q += 4) { 97 | // Convert the image to grayscale. 98 | image(j, i) = (.299 * q[0] + .587 * q[1] + .114 * q[2]) / 255.; 99 | } 100 | } 101 | } 102 | } 103 | 104 | #endif 105 | -------------------------------------------------------------------------------- /src/classify.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "features.h" 11 | #include "io.h" 12 | #include "svg.h" 13 | #include "svm.h" 14 | #include "types.h" 15 | 16 | int main(int argc, char *argv[]) { 17 | // Process the command-line arguments. 18 | const char *vocab_path = "vocab.out"; 19 | const char *map_path = "map_id_label.txt"; 20 | const char *cats_path = "cats.out"; 21 | bool ova = true; 22 | 23 | { 24 | int i; 25 | for (i = 1; i < argc; ++i) { 26 | if (!strcmp(argv[i], "-h")) { 27 | goto usage; 28 | } 29 | else if (!strcmp(argv[i], "-v")) { 30 | vocab_path = argv[++i]; 31 | } 32 | else if (!strcmp(argv[i], "-m")) { 33 | map_path = argv[++i]; 34 | } 35 | else if (!strcmp(argv[i], "-c")) { 36 | ++i; 37 | if (!strcmp(argv[i], "ova")) { 38 | ova = true; 39 | } 40 | else if (!strcmp(argv[i], "ovo")) { 41 | ova = false; 42 | } 43 | else { 44 | std::cerr << argv[0] << ": Unsupported classifier: `" << argv[i] 45 | << "'\n"; 46 | goto err; 47 | } 48 | } 49 | else { 50 | break; 51 | } 52 | } 53 | 54 | if (i < argc) 55 | cats_path = argv[i++]; 56 | 57 | if (i != argc) 58 | goto usage; 59 | 60 | if (!vocab_path || !map_path || !cats_path) 61 | goto usage; 62 | } 63 | 64 | { 65 | // Load the vocabulary. 66 | std::cout << "Loading vocabulary...\n"; 67 | vocab_type vocab; 68 | { 69 | std::ifstream fs(vocab_path, std::ios::binary); 70 | deserialize2(vocab, fs); 71 | } 72 | 73 | // Load the category map. 74 | std::cout << "Loading category map...\n"; 75 | std::map< int, std::string > cat_map; 76 | { 77 | std::ifstream fs(map_path); 78 | for (std::string line; std::getline(fs, line);) { 79 | std::istringstream ss(line); 80 | int i; 81 | std::string label; 82 | ss >> i; 83 | ss.get(); // ',' 84 | std::getline(ss, label); 85 | cat_map[i] = label; 86 | } 87 | } 88 | 89 | // Load the category classifier. 90 | std::cout << "Loading classifier...\n"; 91 | df_type df; 92 | { 93 | std::ifstream fs(cats_path, std::ios::binary); 94 | if (ova) 95 | deserialize2(df.get< ova_df_type >(), fs); 96 | else 97 | deserialize2(df.get< ovo_df_type >(), fs); 98 | } 99 | 100 | // Extract features for all input files. 101 | std::vector< std::string > paths; 102 | std::string path; 103 | while (std::getline(std::cin, path)) 104 | paths.push_back(path); 105 | 106 | #pragma omp parallel for schedule(dynamic) 107 | for (typename std::vector< std::string >::size_type i = 0; 108 | i < paths.size(); ++i) { 109 | const std::string &path = paths[i]; 110 | 111 | // Extract the features. 112 | image_type image; 113 | load_svg(path.c_str(), image); 114 | image = 1. - image; 115 | 116 | std::vector< feature_desc_type > descs; 117 | extract_descriptors(image, descs); 118 | 119 | feature_hist_type hist; 120 | feature_hist(descs, vocab, hist); 121 | 122 | const int cat = ova ? df.get< ova_df_type >()(hist) : 123 | df.get< ovo_df_type >()(hist); 124 | 125 | assert(cat); 126 | 127 | #pragma omp critical 128 | { 129 | std::cout << path << ' ' << cat_map[cat] << '\n'; 130 | } 131 | } 132 | } 133 | 134 | return 0; 135 | 136 | usage: 137 | std::cerr << "Usage: " << argv[0] 138 | << " [-v vocab-file] [-m map-file] [-c classifier] [cats-file]\n"; 139 | err: 140 | return 1; 141 | } 142 | 143 | -------------------------------------------------------------------------------- /data/map_id_label.txt: -------------------------------------------------------------------------------- 1 | 1,airplane 2 | 2,alarm clock 3 | 3,angel 4 | 4,ant 5 | 5,apple 6 | 6,arm 7 | 7,armchair 8 | 8,ashtray 9 | 9,axe 10 | 10,backpack 11 | 11,banana 12 | 12,barn 13 | 13,baseball bat 14 | 14,basket 15 | 15,bathtub 16 | 16,bear (animal) 17 | 17,bed 18 | 18,bee 19 | 19,beer-mug 20 | 20,bell 21 | 21,bench 22 | 22,bicycle 23 | 23,binoculars 24 | 24,blimp 25 | 25,book 26 | 26,bookshelf 27 | 27,boomerang 28 | 28,bottle opener 29 | 29,bowl 30 | 30,brain 31 | 31,bread 32 | 32,bridge 33 | 33,bulldozer 34 | 34,bus 35 | 35,bush 36 | 36,butterfly 37 | 37,cabinet 38 | 38,cactus 39 | 39,cake 40 | 40,calculator 41 | 41,camel 42 | 42,camera 43 | 43,candle 44 | 44,cannon 45 | 45,canoe 46 | 46,car (sedan) 47 | 47,carrot 48 | 48,castle 49 | 49,cat 50 | 50,cell phone 51 | 51,chair 52 | 52,chandelier 53 | 53,church 54 | 54,cigarette 55 | 55,cloud 56 | 56,comb 57 | 57,computer monitor 58 | 58,computer-mouse 59 | 59,couch 60 | 60,cow 61 | 61,crab 62 | 62,crane (machine) 63 | 63,crocodile 64 | 64,crown 65 | 65,cup 66 | 66,diamond 67 | 67,dog 68 | 68,dolphin 69 | 69,donut 70 | 70,door 71 | 71,door handle 72 | 72,dragon 73 | 73,duck 74 | 74,ear 75 | 75,elephant 76 | 76,envelope 77 | 77,eye 78 | 78,eyeglasses 79 | 79,face 80 | 80,fan 81 | 81,feather 82 | 82,fire hydrant 83 | 83,fish 84 | 84,flashlight 85 | 85,floor lamp 86 | 86,flower with stem 87 | 87,flying bird 88 | 88,flying saucer 89 | 89,foot 90 | 90,fork 91 | 91,frog 92 | 92,frying-pan 93 | 93,giraffe 94 | 94,grapes 95 | 95,grenade 96 | 96,guitar 97 | 97,hamburger 98 | 98,hammer 99 | 99,hand 100 | 100,harp 101 | 101,hat 102 | 102,head 103 | 103,head-phones 104 | 104,hedgehog 105 | 105,helicopter 106 | 106,helmet 107 | 107,horse 108 | 108,hot air balloon 109 | 109,hot-dog 110 | 110,hourglass 111 | 111,house 112 | 112,human-skeleton 113 | 113,ice-cream-cone 114 | 114,ipod 115 | 115,kangaroo 116 | 116,key 117 | 117,keyboard 118 | 118,knife 119 | 119,ladder 120 | 120,laptop 121 | 121,leaf 122 | 122,lightbulb 123 | 123,lighter 124 | 124,lion 125 | 125,lobster 126 | 126,loudspeaker 127 | 127,mailbox 128 | 128,megaphone 129 | 129,mermaid 130 | 130,microphone 131 | 131,microscope 132 | 132,monkey 133 | 133,moon 134 | 134,mosquito 135 | 135,motorbike 136 | 136,mouse (animal) 137 | 137,mouth 138 | 138,mug 139 | 139,mushroom 140 | 140,nose 141 | 141,octopus 142 | 142,owl 143 | 143,palm tree 144 | 144,panda 145 | 145,paper clip 146 | 146,parachute 147 | 147,parking meter 148 | 148,parrot 149 | 149,pear 150 | 150,pen 151 | 151,penguin 152 | 152,person sitting 153 | 153,person walking 154 | 154,piano 155 | 155,pickup truck 156 | 156,pig 157 | 157,pigeon 158 | 158,pineapple 159 | 159,pipe (for smoking) 160 | 160,pizza 161 | 161,potted plant 162 | 162,power outlet 163 | 163,present 164 | 164,pretzel 165 | 165,pumpkin 166 | 166,purse 167 | 167,rabbit 168 | 168,race car 169 | 169,radio 170 | 170,rainbow 171 | 171,revolver 172 | 172,rifle 173 | 173,rollerblades 174 | 174,rooster 175 | 175,sailboat 176 | 176,santa claus 177 | 177,satellite 178 | 178,satellite dish 179 | 179,saxophone 180 | 180,scissors 181 | 181,scorpion 182 | 182,screwdriver 183 | 183,sea turtle 184 | 184,seagull 185 | 185,shark 186 | 186,sheep 187 | 187,ship 188 | 188,shoe 189 | 189,shovel 190 | 190,skateboard 191 | 191,skull 192 | 192,skyscraper 193 | 193,snail 194 | 194,snake 195 | 195,snowboard 196 | 196,snowman 197 | 197,socks 198 | 198,space shuttle 199 | 199,speed-boat 200 | 200,spider 201 | 201,sponge bob 202 | 202,spoon 203 | 203,squirrel 204 | 204,standing bird 205 | 205,stapler 206 | 206,strawberry 207 | 207,streetlight 208 | 208,submarine 209 | 209,suitcase 210 | 210,sun 211 | 211,suv 212 | 212,swan 213 | 213,sword 214 | 214,syringe 215 | 215,t-shirt 216 | 216,table 217 | 217,tablelamp 218 | 218,teacup 219 | 219,teapot 220 | 220,teddy-bear 221 | 221,telephone 222 | 222,tennis-racket 223 | 223,tent 224 | 224,tiger 225 | 225,tire 226 | 226,toilet 227 | 227,tomato 228 | 228,tooth 229 | 229,toothbrush 230 | 230,tractor 231 | 231,traffic light 232 | 232,train 233 | 233,tree 234 | 234,trombone 235 | 235,trousers 236 | 236,truck 237 | 237,trumpet 238 | 238,tv 239 | 239,umbrella 240 | 240,van 241 | 241,vase 242 | 242,violin 243 | 243,walkie talkie 244 | 244,wheel 245 | 245,wheelbarrow 246 | 246,windmill 247 | 247,wine-bottle 248 | 248,wineglass 249 | 249,wrist-watch 250 | 250,zebra -------------------------------------------------------------------------------- /src/kmeans.h: -------------------------------------------------------------------------------- 1 | #ifndef KMEANS_H 2 | #define KMEANS_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | // An implementation of the k-means++ cluster center initialization algorithm 13 | // by Arthur and Vassilvitskii 14 | template< class DistT, class T, class Generator > 15 | void kmeanspp(Generator &g, const std::vector< T > &samples, 16 | typename std::vector< T >::size_type k, std::vector< T > ¢ers) { 17 | typedef std::vector< T > vector_type; 18 | 19 | assert(samples.size() > 0 && k > 0); 20 | 21 | // Initialize storage for the centers and the minimum distance to a center 22 | // from each sample point. 23 | std::vector< DistT > min_distances(samples.size(), 24 | std::numeric_limits< DistT >::max()); 25 | centers.clear(); 26 | centers.reserve(k); 27 | 28 | // Pick the first center uniformly at random. 29 | { 30 | std::uniform_int_distribution< typename vector_type::size_type > 31 | uniform_sample_index(0, samples.size() - 1); 32 | centers.push_back(samples[uniform_sample_index(g)]); 33 | } 34 | 35 | // Pick the remaining centers. 36 | for (typename vector_type::size_type i = 0; i < k - 1; ++i) { 37 | // Update the minimum distance from each sample to a center, taking into 38 | // account them most recently added center. 39 | #pragma omp parallel for 40 | for (typename vector_type::size_type j = 0; j < samples.size(); ++j) { 41 | const DistT dist = dlib::length_squared(centers[i] - samples[j]); 42 | if (dist < min_distances[j]) 43 | min_distances[j] = dist; 44 | } 45 | 46 | // Pick the next center at random using a probability distribution 47 | // weighted by distance squared. 48 | std::discrete_distribution< typename vector_type::size_type > 49 | weighted_sample_index(min_distances.begin(), min_distances.end()); 50 | centers.push_back(samples[weighted_sample_index(g)]); 51 | } 52 | } 53 | 54 | // An implementation of k-means clustering 55 | template< class DistT, class T, bool Verbose = false > 56 | void kmeans(const std::vector< T > &samples, std::vector< T > ¢ers, 57 | unsigned int max_iter = 1000) { 58 | typedef std::vector< T > vector_type; 59 | 60 | assert(samples.size() > 0 && centers.size() > 0); 61 | 62 | // A zero sample for calculating the centroid 63 | T zero = dlib::zeros_matrix(centers[0]); 64 | 65 | // Initialize storage for the number of samples for each center and the 66 | // center associated with each sample. 67 | std::vector< typename vector_type::size_type > assignments(samples.size()); 68 | std::vector< typename vector_type::size_type > center_element_count; 69 | 70 | if (Verbose) { 71 | std::cout << "Running k-means..."; 72 | std::cout.flush(); 73 | } 74 | 75 | unsigned int iter = 0; 76 | bool centers_changed = true; 77 | while (centers_changed && iter < max_iter) { 78 | ++iter; 79 | centers_changed = false; 80 | 81 | if (Verbose) { 82 | std::cout << ' ' << iter << "..."; 83 | std::cout.flush(); 84 | } 85 | 86 | // Determine which center each sample is closest to. 87 | #pragma omp parallel for 88 | for (typename vector_type::size_type i = 0; i < samples.size(); ++i) { 89 | DistT min_dist = std::numeric_limits< DistT >::max(); 90 | typename vector_type::size_type min_center = 0; 91 | 92 | for (typename vector_type::size_type j = 0; j < centers.size(); ++j) { 93 | const DistT dist = dlib::length_squared(centers[j] - samples[i]); 94 | if (dist < min_dist) { 95 | min_dist = dist; 96 | min_center = j; 97 | } 98 | } 99 | 100 | if (assignments[i] != min_center) { 101 | centers_changed = true; 102 | assignments[i] = min_center; 103 | } 104 | } 105 | 106 | // Update the cluster centers. 107 | centers.assign(centers.size(), zero); 108 | center_element_count.assign(centers.size(), 0); 109 | 110 | for (typename vector_type::size_type i = 0; i < samples.size(); ++i) { 111 | const typename vector_type::size_type &assignment = assignments[i]; 112 | centers[assignment] += samples[i]; 113 | ++center_element_count[assignment]; 114 | } 115 | 116 | for (typename vector_type::size_type i = 0; i < centers.size(); ++i) { 117 | if (center_element_count[i]) 118 | centers[i] /= center_element_count[i]; 119 | } 120 | } 121 | 122 | if (Verbose) 123 | std::cout << " done\n"; 124 | } 125 | 126 | #endif 127 | -------------------------------------------------------------------------------- /m4/ax_cxx_compile_stdcxx_11.m4: -------------------------------------------------------------------------------- 1 | # ============================================================================ 2 | # http://www.gnu.org/software/autoconf-archive/ax_cxx_compile_stdcxx_11.html 3 | # ============================================================================ 4 | # 5 | # SYNOPSIS 6 | # 7 | # AX_CXX_COMPILE_STDCXX_11([ext|noext],[mandatory|optional]) 8 | # 9 | # DESCRIPTION 10 | # 11 | # Check for baseline language coverage in the compiler for the C++11 12 | # standard; if necessary, add switches to CXXFLAGS to enable support. 13 | # 14 | # The first argument, if specified, indicates whether you insist on an 15 | # extended mode (e.g. -std=gnu++11) or a strict conformance mode (e.g. 16 | # -std=c++11). If neither is specified, you get whatever works, with 17 | # preference for an extended mode. 18 | # 19 | # The second argument, if specified 'mandatory' or if left unspecified, 20 | # indicates that baseline C++11 support is required and that the macro 21 | # should error out if no mode with that support is found. If specified 22 | # 'optional', then configuration proceeds regardless, after defining 23 | # HAVE_CXX11 if and only if a supporting mode is found. 24 | # 25 | # LICENSE 26 | # 27 | # Copyright (c) 2008 Benjamin Kosnik 28 | # Copyright (c) 2012 Zack Weinberg 29 | # Copyright (c) 2013 Roy Stogner 30 | # 31 | # Copying and distribution of this file, with or without modification, are 32 | # permitted in any medium without royalty provided the copyright notice 33 | # and this notice are preserved. This file is offered as-is, without any 34 | # warranty. 35 | 36 | #serial 3 37 | 38 | m4_define([_AX_CXX_COMPILE_STDCXX_11_testbody], [ 39 | template 40 | struct check 41 | { 42 | static_assert(sizeof(int) <= sizeof(T), "not big enough"); 43 | }; 44 | 45 | typedef check> right_angle_brackets; 46 | 47 | int a; 48 | decltype(a) b; 49 | 50 | typedef check check_type; 51 | check_type c; 52 | check_type&& cr = static_cast(c); 53 | 54 | auto d = a; 55 | ]) 56 | 57 | AC_DEFUN([AX_CXX_COMPILE_STDCXX_11], [dnl 58 | m4_if([$1], [], [], 59 | [$1], [ext], [], 60 | [$1], [noext], [], 61 | [m4_fatal([invalid argument `$1' to AX_CXX_COMPILE_STDCXX_11])])dnl 62 | m4_if([$2], [], [ax_cxx_compile_cxx11_required=true], 63 | [$2], [mandatory], [ax_cxx_compile_cxx11_required=true], 64 | [$2], [optional], [ax_cxx_compile_cxx11_required=false], 65 | [m4_fatal([invalid second argument `$2' to AX_CXX_COMPILE_STDCXX_11])])dnl 66 | AC_LANG_PUSH([C++])dnl 67 | ac_success=no 68 | AC_CACHE_CHECK(whether $CXX supports C++11 features by default, 69 | ax_cv_cxx_compile_cxx11, 70 | [AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_11_testbody])], 71 | [ax_cv_cxx_compile_cxx11=yes], 72 | [ax_cv_cxx_compile_cxx11=no])]) 73 | if test x$ax_cv_cxx_compile_cxx11 = xyes; then 74 | ac_success=yes 75 | fi 76 | 77 | m4_if([$1], [noext], [], [dnl 78 | if test x$ac_success = xno; then 79 | for switch in -std=gnu++11 -std=gnu++0x; do 80 | cachevar=AS_TR_SH([ax_cv_cxx_compile_cxx11_$switch]) 81 | AC_CACHE_CHECK(whether $CXX supports C++11 features with $switch, 82 | $cachevar, 83 | [ac_save_CXXFLAGS="$CXXFLAGS" 84 | CXXFLAGS="$CXXFLAGS $switch" 85 | AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_11_testbody])], 86 | [eval $cachevar=yes], 87 | [eval $cachevar=no]) 88 | CXXFLAGS="$ac_save_CXXFLAGS"]) 89 | if eval test x\$$cachevar = xyes; then 90 | CXXFLAGS="$CXXFLAGS $switch" 91 | ac_success=yes 92 | break 93 | fi 94 | done 95 | fi]) 96 | 97 | m4_if([$1], [ext], [], [dnl 98 | if test x$ac_success = xno; then 99 | for switch in -std=c++11 -std=c++0x; do 100 | cachevar=AS_TR_SH([ax_cv_cxx_compile_cxx11_$switch]) 101 | AC_CACHE_CHECK(whether $CXX supports C++11 features with $switch, 102 | $cachevar, 103 | [ac_save_CXXFLAGS="$CXXFLAGS" 104 | CXXFLAGS="$CXXFLAGS $switch" 105 | AC_COMPILE_IFELSE([AC_LANG_SOURCE([_AX_CXX_COMPILE_STDCXX_11_testbody])], 106 | [eval $cachevar=yes], 107 | [eval $cachevar=no]) 108 | CXXFLAGS="$ac_save_CXXFLAGS"]) 109 | if eval test x\$$cachevar = xyes; then 110 | CXXFLAGS="$CXXFLAGS $switch" 111 | ac_success=yes 112 | break 113 | fi 114 | done 115 | fi]) 116 | AC_LANG_POP([C++]) 117 | if test x$ax_cxx_compile_cxx11_required = xtrue; then 118 | if test x$ac_success = xno; then 119 | AC_MSG_ERROR([*** A compiler with support for C++11 language features is required.]) 120 | fi 121 | else 122 | if test x$ac_success = xno; then 123 | HAVE_CXX11=0 124 | AC_MSG_NOTICE([No compiler with C++11 support was found]) 125 | else 126 | HAVE_CXX11=1 127 | AC_DEFINE(HAVE_CXX11,1, 128 | [define if the compiler supports basic C++11 syntax]) 129 | fi 130 | 131 | AC_SUBST(HAVE_CXX11) 132 | fi 133 | ]) 134 | -------------------------------------------------------------------------------- /src/cats.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "features.h" 12 | #include "io.h" 13 | #include "svm.h" 14 | #include "svg.h" 15 | #include "types.h" 16 | 17 | int main(int argc, char *argv[]) { 18 | // Process the command-line arguments. 19 | const char *vocab_path = "vocab.out"; 20 | const char *map_path = "map_id_label.txt"; 21 | const char *cats_path = "cats.out"; 22 | bool ova = true; 23 | typename kernel_type::scalar_type gamma = 17.8; 24 | typename kernel_type::scalar_type c = 3.2; 25 | 26 | { 27 | int i; 28 | for (i = 1; i < argc; ++i) { 29 | if (!strcmp(argv[i], "-h")) { 30 | goto usage; 31 | } 32 | else if (!strcmp(argv[i], "-v")) { 33 | vocab_path = argv[++i]; 34 | } 35 | else if (!strcmp(argv[i], "-m")) { 36 | map_path = argv[++i]; 37 | } 38 | else if (!strcmp(argv[i], "-c")) { 39 | ++i; 40 | if (!strcmp(argv[i], "ova")) { 41 | ova = true; 42 | } 43 | else if (!strcmp(argv[i], "ovo")) { 44 | ova = false; 45 | } 46 | else { 47 | std::cerr << argv[0] << ": Unsupported classifier: `" << argv[i] 48 | << "'\n"; 49 | goto err; 50 | } 51 | } 52 | else if (!strcmp(argv[i], "-g")) { 53 | std::istringstream ss(argv[++i]); 54 | if (!(ss >> gamma)) 55 | goto usage; 56 | } 57 | else if (!strcmp(argv[i], "-C")) { 58 | std::istringstream ss(argv[++i]); 59 | if (!(ss >> c)) 60 | goto usage; 61 | } 62 | else { 63 | break; 64 | } 65 | } 66 | 67 | if (i < argc) 68 | cats_path = argv[i++]; 69 | 70 | if (i != argc) 71 | goto usage; 72 | 73 | if (!vocab_path || !map_path) 74 | goto usage; 75 | } 76 | 77 | { 78 | // Load the vocabulary. 79 | std::cout << "Loading vocabulary...\n"; 80 | vocab_type vocab; 81 | { 82 | std::ifstream fs(vocab_path, std::ios::binary); 83 | deserialize2(vocab, fs); 84 | } 85 | 86 | // Load the category map. 87 | std::cout << "Loading category map...\n"; 88 | std::map< std::string, int > cat_map; 89 | { 90 | std::ifstream fs(map_path); 91 | for (std::string line; std::getline(fs, line);) { 92 | std::istringstream ss(line); 93 | int i; 94 | std::string label; 95 | ss >> i; 96 | ss.get(); // ',' 97 | std::getline(ss, label); 98 | cat_map[label] = i; 99 | } 100 | fs.close(); 101 | } 102 | 103 | // Extract features for all input files. 104 | std::vector< feature_hist_type > samples; 105 | std::vector< int > labels; 106 | 107 | std::vector< std::string > paths; 108 | std::string path; 109 | while (std::getline(std::cin, path)) 110 | paths.push_back(path); 111 | 112 | #pragma omp parallel for schedule(dynamic) 113 | for (typename std::vector< std::string >::size_type i = 0; 114 | i < paths.size(); ++i) { 115 | const std::string &path = paths[i]; 116 | 117 | #pragma omp critical 118 | { 119 | std::cout << "Extracting features for " << path << " (" << i + 1 120 | << '/' << paths.size() << ")...\n"; 121 | } 122 | 123 | // Get the category from the directory name. 124 | const std::size_t dir_end = path.rfind('/'); 125 | std::size_t dir_begin = path.rfind('/', dir_end - 1); 126 | if (dir_begin == std::string::npos) 127 | dir_begin = 0; 128 | else 129 | ++dir_begin; 130 | 131 | const std::string dir = path.substr(dir_begin, dir_end - dir_begin); 132 | const int cat = cat_map[dir]; 133 | 134 | assert(cat); 135 | 136 | // Extract the features. 137 | image_type image; 138 | load_svg(path.c_str(), image); 139 | image = 1. - image; 140 | 141 | std::vector< feature_desc_type > descs; 142 | extract_descriptors(image, descs); 143 | 144 | feature_hist_type hist; 145 | feature_hist(descs, vocab, hist); 146 | 147 | // Store the category label and feature histogram. 148 | #pragma omp critical 149 | { 150 | samples.push_back(hist); 151 | labels.push_back(cat); 152 | } 153 | } 154 | 155 | // Train a multi-class classifier. 156 | trainer_type rbf_trainer; 157 | rbf_trainer.set_kernel(kernel_type(gamma)); 158 | rbf_trainer.set_c(c); 159 | 160 | df_type df; 161 | if (ova) { 162 | std::cout << "Training one-vs-all classifier...\n"; 163 | df.get< ova_df_type >() = 164 | ova_trainer_type(rbf_trainer).train(samples, labels); 165 | } 166 | else { 167 | std::cout << "Training one-vs-one classifier...\n"; 168 | df.get< ovo_df_type >() = 169 | ovo_trainer_type(rbf_trainer).train(samples, labels); 170 | } 171 | 172 | // Save the classifier. 173 | std::cout << "Saving classifier...\n"; 174 | { 175 | std::ofstream fs(cats_path, std::ios::binary); 176 | if (ova) 177 | serialize2(df.get< ova_df_type >(), fs); 178 | else 179 | serialize2(df.get< ovo_df_type >(), fs); 180 | } 181 | } 182 | 183 | return 0; 184 | 185 | usage: 186 | std::cerr << "Usage: " << argv[0] << " [-v vocab-file] [-m map-file]" 187 | " [-c classifier] [-g gamma] [-C C] [cats-file]\n"; 188 | err: 189 | return 1; 190 | } 191 | 192 | -------------------------------------------------------------------------------- /src/cross.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "features.h" 12 | #include "io.h" 13 | #include "svm.h" 14 | #include "svg.h" 15 | #include "types.h" 16 | 17 | int main(int argc, char *argv[]) { 18 | // Process the command-line arguments. 19 | long folds = 8; 20 | const char *vocab_path = "vocab.out"; 21 | const char *map_path = "map_id_label.txt"; 22 | const char *conf_path = "conf.out"; 23 | bool ova = true; 24 | typename kernel_type::scalar_type gamma = 17.8; 25 | typename kernel_type::scalar_type c = 3.2; 26 | 27 | { 28 | int i; 29 | for (i = 1; i < argc; ++i) { 30 | if (!strcmp(argv[i], "-h")) { 31 | goto usage; 32 | } 33 | else if (!strcmp(argv[i], "-f")) { 34 | std::istringstream ss(argv[++i]); 35 | if (!(ss >> folds)) 36 | goto usage; 37 | } 38 | else if (!strcmp(argv[i], "-v")) { 39 | vocab_path = argv[++i]; 40 | } 41 | else if (!strcmp(argv[i], "-m")) { 42 | map_path = argv[++i]; 43 | } 44 | else if (!strcmp(argv[i], "-c")) { 45 | ++i; 46 | if (!strcmp(argv[i], "ova")) { 47 | ova = true; 48 | } 49 | else if (!strcmp(argv[i], "ovo")) { 50 | ova = false; 51 | } 52 | else { 53 | std::cerr << argv[0] << ": Unsupported classifier: `" << argv[i] 54 | << "'\n"; 55 | goto err; 56 | } 57 | } 58 | else if (!strcmp(argv[i], "-g")) { 59 | std::istringstream ss(argv[++i]); 60 | if (!(ss >> gamma)) 61 | goto usage; 62 | } 63 | else if (!strcmp(argv[i], "-C")) { 64 | std::istringstream ss(argv[++i]); 65 | if (!(ss >> c)) 66 | goto usage; 67 | } 68 | else { 69 | break; 70 | } 71 | } 72 | 73 | if (i < argc) 74 | conf_path = argv[i++]; 75 | 76 | if (i != argc) 77 | goto usage; 78 | 79 | if (!vocab_path || !map_path) 80 | goto usage; 81 | } 82 | 83 | { 84 | // Load the vocabulary. 85 | std::cout << "Loading vocabulary...\n"; 86 | vocab_type vocab; 87 | { 88 | std::ifstream fs(vocab_path, std::ios::binary); 89 | deserialize2(vocab, fs); 90 | } 91 | 92 | // Load the category map. 93 | std::cout << "Loading category map...\n"; 94 | std::map< std::string, int > cat_map; 95 | { 96 | std::ifstream fs(map_path); 97 | for (std::string line; std::getline(fs, line);) { 98 | std::istringstream ss(line); 99 | int i; 100 | std::string label; 101 | ss >> i; 102 | ss.get(); // ',' 103 | std::getline(ss, label); 104 | cat_map[label] = i; 105 | } 106 | fs.close(); 107 | } 108 | 109 | // Extract features for all input files. 110 | std::vector< feature_hist_type > samples; 111 | std::vector< int > labels; 112 | 113 | std::vector< std::string > paths; 114 | std::string path; 115 | while (std::getline(std::cin, path)) 116 | paths.push_back(path); 117 | 118 | #pragma omp parallel for schedule(dynamic) 119 | for (typename std::vector< std::string >::size_type i = 0; 120 | i < paths.size(); ++i) { 121 | const std::string &path = paths[i]; 122 | 123 | #pragma omp critical 124 | { 125 | std::cout << "Extracting features for " << path << " (" << i + 1 126 | << '/' << paths.size() << ")...\n"; 127 | } 128 | 129 | // Get the category from the directory name. 130 | const std::size_t dir_end = path.rfind('/'); 131 | std::size_t dir_begin = path.rfind('/', dir_end - 1); 132 | if (dir_begin == std::string::npos) 133 | dir_begin = 0; 134 | else 135 | ++dir_begin; 136 | 137 | const std::string dir = path.substr(dir_begin, dir_end - dir_begin); 138 | const int cat = cat_map[dir]; 139 | 140 | assert(cat); 141 | 142 | // Extract the features. 143 | image_type image; 144 | load_svg(path.c_str(), image); 145 | image = 1. - image; 146 | 147 | std::vector< feature_desc_type > descs; 148 | extract_descriptors(image, descs); 149 | 150 | feature_hist_type hist; 151 | feature_hist(descs, vocab, hist); 152 | 153 | // Store the category label and feature histogram. 154 | #pragma omp critical 155 | { 156 | samples.push_back(hist); 157 | labels.push_back(cat); 158 | } 159 | } 160 | 161 | // Train a multi-class classifier. 162 | trainer_type rbf_trainer; 163 | rbf_trainer.set_kernel(kernel_type(gamma)); 164 | rbf_trainer.set_c(c); 165 | 166 | dlib::matrix< double > conf; 167 | if (ova) { 168 | std::cout << "Cross-validating one-vs-all classifier using " << folds 169 | << " folds...\n"; 170 | conf = cross_validate_multiclass_trainer2< ova_trainer_type, 171 | feature_hist_type, int, true >(ova_trainer_type(rbf_trainer), samples, 172 | labels, folds); 173 | } 174 | else { 175 | std::cout << "Cross-validating one-vs-one classifier using " << folds 176 | << " folds...\n"; 177 | conf = cross_validate_multiclass_trainer2< ovo_trainer_type, 178 | feature_hist_type, int, true >(ovo_trainer_type(rbf_trainer), samples, 179 | labels, folds); 180 | } 181 | 182 | const std::vector< int > distinct_labels = 183 | dlib::select_all_distinct_labels(labels); 184 | 185 | // Save the confusion matrix. 186 | std::cout << "Saving confusion matrix...\n"; 187 | { 188 | std::ofstream fs(conf_path, std::ios::binary); 189 | serialize2(distinct_labels, fs); 190 | serialize2(conf, fs); 191 | } 192 | } 193 | 194 | return 0; 195 | 196 | usage: 197 | std::cerr << "Usage: " << argv[0] << " [-f folds] [-v vocab-file]" 198 | " [-m map-file] [-c classifier] [-g gamma] [-C C] [conf-file]\n"; 199 | err: 200 | return 1; 201 | } 202 | 203 | -------------------------------------------------------------------------------- /src/features.h: -------------------------------------------------------------------------------- 1 | #ifndef FEATURES_H 2 | #define FEATURES_H 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include "conv.h" 10 | #include "util.h" 11 | 12 | // Bin the gradient magnitudes by orientation into orientational response 13 | // images. 14 | template< class T, long NR, long NC > 15 | void orient_responses(const dlib::matrix< T, NR, NC > &g, 16 | const dlib::matrix< T, NR, NC > &o, unsigned int bin_count, 17 | std::vector< dlib::matrix< T, NR, NC > > &os) { 18 | const T w = M_PI / bin_count; // The width of each bin 19 | os.assign(bin_count, zeros_matrix(g)); 20 | 21 | for (unsigned int k = 0; k < bin_count; ++k) { 22 | const T o0 = (k - .5) * w; // The lower bound for this bin 23 | const T oc = k * w; // The center of this bin 24 | const T o1 = (k + .5) * w; // The upper bound for this bin 25 | 26 | T oji; 27 | for (long j = 0; j < NR; ++j) { 28 | for (long i = 0; i < NC; ++i) { 29 | oji = o(j, i); 30 | if (oji >= o0 && oji < o1) { 31 | const T t = std::abs(oc - oji) / w; 32 | os[k](j, i) += g(j, i) * (1. - t); 33 | os[((oji < oc) ? k + bin_count - 1 : k + 1) % bin_count](j, i) += 34 | g(j, i) * t; 35 | } 36 | } 37 | } 38 | } 39 | } 40 | 41 | // A class for extracting feature descriptors from a grayscale image 42 | template< class T, long N > 43 | struct feature_desc_extractor { 44 | static const unsigned int orient_bin_count = 4; 45 | static const unsigned int spatial_bin_count = 4; // In each dimension 46 | static const unsigned int spatial_bin_size = 47 | (N * 0.35355339) / spatial_bin_count; // 12.5% area 48 | static const unsigned int feature_grid_size = 28; 49 | 50 | typedef dlib::matrix< T, N, N > image_type; 51 | typedef dlib::matrix< T, orient_bin_count * 52 | spatial_bin_count * spatial_bin_count, 1 > desc_type; 53 | 54 | static void extract(const image_type &image, 55 | std::vector< desc_type > &descs) { 56 | descs.clear(); 57 | descs.reserve(feature_grid_size * feature_grid_size); 58 | 59 | // Compute the gradient. 60 | const image_type gx = conv_same(image, sobel_x); 61 | const image_type gy = conv_same(image, sobel_y); 62 | 63 | // Compute the magnitude and orientation of the gradient. 64 | image_type g, o; 65 | cart2polar(gx, gy, g, o); 66 | 67 | // Limit the orientation range to [0, pi). 68 | for (long j = 0; j < N; ++j) { 69 | for (long i = 0; i < N; ++i) { 70 | if (o(j, i) < 0) 71 | o(j, i) += M_PI; 72 | else if (o(j, i) >= M_PI) 73 | o(j, i) -= M_PI; 74 | } 75 | } 76 | 77 | // Generate orientational response images. 78 | std::vector< image_type > os; 79 | orient_responses(g, o, orient_bin_count, os); 80 | 81 | // Convolve each orientational response image with a 2D tent function to 82 | // accelerate interpolation. 83 | for (unsigned int i = 0; i < orient_bin_count; ++i) { 84 | conv_tent(os[i]); 85 | // Account for slightly negative responses introduced by the FFT. 86 | os[i] = abs(os[i]); 87 | } 88 | 89 | // Extract feature descriptors on a regular grid. Orientational response 90 | // values are binned into a spatial grid centered at each grid point. 91 | static const unsigned int dg = N / feature_grid_size; 92 | 93 | for (unsigned int v = dg / 2; v < N; v += dg) { 94 | for (unsigned int u = dg / 2; u < N; u += dg) { 95 | desc_type d; 96 | 97 | for (unsigned int i = 0; i < orient_bin_count; ++i) { 98 | for (unsigned int t = 0; t < spatial_bin_count; ++t) { 99 | // The vertical bin center 100 | const int ct = spatial_bin_size / 2 + spatial_bin_size * 101 | (t - spatial_bin_count / 2); 102 | 103 | for (unsigned int s = 0; s < spatial_bin_count; ++s) { 104 | // The horizontal bin center 105 | const int cs = spatial_bin_size / 2 + spatial_bin_size * 106 | (s - spatial_bin_count / 2); 107 | 108 | const int y = v + ct; 109 | const int x = u + cs; 110 | d((i * spatial_bin_count + t) * spatial_bin_count + s) = 111 | (0 <= y && y < N && 0 <= x && x < N) ? os[i](y, x) : 0; 112 | } 113 | } 114 | } 115 | 116 | // Normalize the feature descriptor before adding it to the array. 117 | descs.push_back(normalize(d)); 118 | } 119 | } 120 | } 121 | 122 | private: 123 | // A 2D tent function kernel for bilinear interpolation 124 | static image_type tent_kernel_init() { 125 | const unsigned int tent_size = 2 * spatial_bin_size + 1; 126 | 127 | image_type m; 128 | m = 0; 129 | for (unsigned int j = 0; j < tent_size; ++j) { 130 | const unsigned int xj = 131 | spatial_bin_size - std::abs(j - spatial_bin_size); 132 | for (unsigned int i = 0; i < tent_size; ++i) 133 | m(j, i) = xj * (spatial_bin_size - std::abs(i - spatial_bin_size)); 134 | } 135 | return m; 136 | } 137 | 138 | static const conv_fft< T, N, N > conv_tent; 139 | }; 140 | 141 | template< class T, long N > 142 | const conv_fft< T, N, N > feature_desc_extractor< T, N >::conv_tent( 143 | feature_desc_extractor::tent_kernel_init()); 144 | 145 | // A helper function for inferring the image type for feature extraction 146 | template< class T, long N > 147 | void extract_descriptors(const dlib::matrix< T, N, N > &S, 148 | std::vector< typename feature_desc_extractor< T, N >::desc_type > &D) { 149 | feature_desc_extractor< T, N >::extract(S, D); 150 | } 151 | 152 | // Quantize a feature descriptor for a vocabulary using the Gaussian distance 153 | // to each word. 154 | template< class T, long N, long V > 155 | void quantize_desc(const dlib::matrix< T, N, 1 > &desc, 156 | const std::vector< dlib::matrix< T, N, 1 > > &vocab, 157 | dlib::matrix< T, V, 1 > &q) { 158 | static const T sigma = .1; // Sigma for Gaussian distance 159 | typedef std::vector< dlib::matrix< T, N, 1 > > vocab_type; 160 | 161 | assert(vocab.size() == V); 162 | 163 | for (typename vocab_type::size_type i = 0; i < vocab.size(); ++i) { 164 | const dlib::matrix< T, N, 1 > diff = desc - vocab[i]; 165 | q(i) = std::exp(-dot(diff, diff) / (2 * sigma * sigma)); 166 | } 167 | } 168 | 169 | // Generate a feature histogram for a set of feature descriptors and a 170 | // vocabulary. 171 | template< class T, long N, long V > 172 | void feature_hist(const std::vector< dlib::matrix< T, N, 1 > > &descs, 173 | const std::vector< dlib::matrix< T, N, 1 > > &vocab, 174 | dlib::matrix< T, V, 1 > &hist) { 175 | assert(vocab.size() == V && V > 0); 176 | 177 | hist = 0; 178 | 179 | for (const auto &desc : descs) { 180 | dlib::matrix< T, V, 1 > q; 181 | quantize_desc(desc, vocab, q); 182 | 183 | // Normalize the feature distance before accumulating. 184 | hist += l1_normalize(q); 185 | } 186 | 187 | hist /= V; 188 | } 189 | 190 | #endif 191 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Human Sketch Recognition 2 | 3 | This project implements a human sketch recognition algorithm based on the 4 | paper: 5 | 6 | Mathias Eitz, James Hays, and Marc Alexa. *[How Do Humans Sketch Objects?] 7 | [1]* ACM Trans. Graph. (Proc. SIGGRAPH), 31(4):44:1-10, July 2012. 8 | 9 | [1]: http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/ 10 | 11 | The sketch dataset is not stored in this repository and must be downloaded 12 | separately. See the setup instructions below for details. 13 | 14 | ## Setup 15 | 16 | ### Dependencies 17 | 18 | GCC 4.6.3 or newer with support for [OpenMP] [2], autoconf 2.69, and automake 19 | 1.12 are required to compile the code. 20 | 21 | [2]: http://openmp.org/ 22 | 23 | The following libraries must also be installed before compiling: 24 | 25 | * [cairo] [3] (>= 1.10) 26 | 27 | * [dlib] [4] (>= 18.0) 28 | 29 | * [fftw] [5] (>= 3.0, configured with single precision support) 30 | 31 | * [gtkmm] [6] (>= 2.24) 32 | 33 | * [librsvg] [7] (>= 2.0) 34 | 35 | [3]: http://cairographics.org/ 36 | [4]: http://dlib.net/ 37 | [5]: http://fftw.org/ 38 | [6]: http://gtkmm.org/ 39 | [7]: https://live.gnome.org/LibRsvg 40 | 41 | ### Compiling 42 | 43 | In the root directory of the project, run `autogen.sh` to set up the 44 | configuration scripts. 45 | 46 | $ ./autogen.sh 47 | 48 | Create the build directory and switch to it, then run `configure` and `make` 49 | to compile. 50 | 51 | $ mkdir build 52 | $ cd build 53 | $ ../configure 54 | $ make 55 | 56 | To compile without optimizations (for debugging), configure with: 57 | 58 | $ ../configure CXXFLAGS='-O0' 59 | 60 | The compiled programs in the build directory mirror the source directory 61 | structure. 62 | 63 | ### Data 64 | 65 | Run `util/get-data` from the root directory to automatically download the 66 | sketch dataset into the data directory. The dataset can also be downloaded 67 | manually from the following link: 68 | 69 | [Sketch dataset (SVG)] [8] (zip, ~50 MB) 70 | 71 | [8]: http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/sketches_svg.zip 72 | 73 | ## Running 74 | 75 | The easiest way to train a classifier and classify sketch data is to run the 76 | provided utility scripts from the project root directory. These scripts make 77 | some assumptions about the project layout: 78 | 79 | * The build directory is in the project root and is called `build` 80 | 81 | * The SVG dataset is in `data/svg` and has one subdirectory for each 82 | category 83 | 84 | * Each category subdirectory contains only sketch images and has no nested 85 | subdirectories 86 | 87 | Each script accepts the same command-line arguments as the associated program 88 | and runs with the same default arguments (with one exception: default paths 89 | are prefixed with `data/`). For detailed information about these arguments, 90 | see **Programs**. 91 | 92 | To generate the visual vocabulary for the entire sketch dataset, run: 93 | 94 | $ util/run-vocab [-n sample-count] 95 | 96 | By default, this script runs with 1,000,000 features selected at random from 97 | the dataset. This usually takes between 60 and 90 minutes (400-600 iterations 98 | of k-means clustering) to complete on a 2.2 GHz Core i7 with 8 threads. 99 | 100 | The dataset is organized into 8 folds to aid in selecting subsets of the data. 101 | Each fold is assigned an index (0-7), and folds can be negated by prepending 102 | a ~. The identifier ~4, for example, refers to the contents of folds 0-3 and 103 | 5-7. 104 | 105 | To train a classifier on a subset of the data, run: 106 | 107 | $ util/run-cats [--fold fold-id] [-c classifier] [cats-file] 108 | 109 | By default, this script operates on folds 1-7 (~0) and trains a one-vs-all 110 | classifier. Training a one-vs-one classifier on this dataset requires nearly 111 | 8 GB of memory. 112 | 113 | To classify a subset of the data, run: 114 | 115 | $ util/run-classify [--fold fold-id] [-c classifier] [cats-file] 116 | 117 | By default, this script operates on fold 0 and expects a one-vs-all 118 | classifier. 119 | 120 | ### Demo 121 | 122 | The default settings will create a working classifier trained on 7/8 (87.5%) 123 | of the dataset. To train, after compiling and obtaining the dataset, run: 124 | 125 | $ util/run-vocab 126 | $ util/run-cats 127 | 128 | This will create the files `data/vocab.out`, the visual vocabulary, and 129 | `data/cats.out`, the one-vs-all classifier. 130 | 131 | To classify the un-trained portion of the dataset, run: 132 | 133 | $ util/run-classify 134 | 135 | ### Programs 136 | 137 | These programs can be found under the build directory after running `make`. 138 | Arguments in brackets are optional and will assume default values when 139 | omitted. 140 | 141 | * `vocab [-n sample-count] [vocab-file]` 142 | 143 | Generate a visual vocabulary for the images specified on standard input, 144 | one path per line. Feature descriptors are extracted from each file. 145 | `sample-count` (default: 1,000,000) random descriptors are selected from 146 | this dataset and clustered into 500 visual words. The resulting 147 | vocabulary is written to `vocab-file` (default: `vocab.out`). 148 | 149 | * `cats [-v vocab-file] [-m map-file] [-c classifier] [-g gamma] [-C C] [cats-file]` 150 | 151 | Train a classifier with the images specified on standard input, one path 152 | per line. The name of the subdirectory containing each image is used as 153 | the category label. Feature histograms are generated from each image for 154 | training using `vocab-file` (default: `vocab.out`). The mapping between 155 | category labels and numeric identifiers is read from `map-file` (default: 156 | `map_id_label.txt`). Two types of classifiers are currently supported, 157 | one-vs-all (`ova`) and one-vs-one (`ovo`). `classifier` (default: `ova`) 158 | must be one of these two values. `gamma` and `C` are the SVM parameters 159 | (default: 17.8 and 3.2 respectively). The resulting classifier is written 160 | to `cats-file` (default: `cats.out`). 161 | 162 | * `classify [-v vocab-file] [-m map-file] [-c classifier] [cats-file]` 163 | 164 | Run a classifier on each image specified on standard input, one path per 165 | line. Each path and its predicted category is written to standard output. 166 | The default values for each argument are the same as above. The same 167 | classifier type must be selected for both training and classification, 168 | since this information is currently not stored with the classifier. 169 | 170 | * `cross [-f folds] [-v vocab-file] [-m map-file] [-c classifier] [-g gamma] [-C C] [conf-file]` 171 | 172 | Run cross-validation using the given number of folds, writing the 173 | confusion matrix to `conf-file` (default: `conf.out`). 174 | 175 | * `gui [-v vocab-file] [-m map-file] [-c classifier] [cats-file]` 176 | 177 | Run a GUI that classifies user sketches in real time. The command-line 178 | arguments this program accepts are the same as above. 179 | 180 | ## License 181 | 182 | The files in this project are released under the BSD-3 license unless stated 183 | otherwise. See the file `LICENSE` for details. 184 | -------------------------------------------------------------------------------- /src/io.h: -------------------------------------------------------------------------------- 1 | #ifndef IO_H 2 | #define IO_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | // An error while serializing or deserializing 18 | struct serialization_error : std::exception { 19 | virtual ~serialization_error() noexcept { 20 | } 21 | 22 | virtual const char *what() const noexcept { 23 | return "serialization error"; 24 | } 25 | }; 26 | 27 | // Check whether a type is a multi-class decision function. 28 | template< class T > 29 | struct is_multiclass_df : std::false_type {}; 30 | 31 | template< class T, class DF1 > 32 | struct is_multiclass_df< dlib::one_vs_all_decision_function< T, DF1 > > 33 | : std::true_type {}; 34 | 35 | template< class T, class DF1 > 36 | struct is_multiclass_df< dlib::one_vs_one_decision_function< T, DF1 > > 37 | : std::true_type {}; 38 | 39 | // Template prototypes 40 | 41 | template< class T1, class T2 > 42 | void serialize2(const std::pair< T1, T2 > &x, std::ostream &s); 43 | 44 | template< class T1, class T2 > 45 | void deserialize2(std::pair< T1, T2 > &x, std::istream &s); 46 | 47 | template< class CharT > 48 | void serialize2(const std::basic_string< CharT > &x, std::ostream &s); 49 | 50 | template< class CharT > 51 | void deserialize2(std::basic_string< CharT > &x, std::istream &s); 52 | 53 | template< class T > 54 | void serialize2(const std::vector< T > &xs, std::ostream &s); 55 | 56 | template< class T > 57 | void deserialize2(std::vector< T > &xs, std::istream &s); 58 | 59 | template< class Key, class T > 60 | void serialize2(const std::multimap< Key, T > &xs, std::ostream &s); 61 | 62 | template< class Key, class T > 63 | void deserialize2(std::multimap< Key, T > &xs, std::istream &s); 64 | 65 | template< class T, long NR, long NC > 66 | void serialize2(const dlib::matrix< T, NR, NC > &x, std::ostream &s); 67 | 68 | template< class T, long NR, long NC > 69 | void deserialize2(dlib::matrix< T, NR, NC > &x, std::istream &s); 70 | 71 | template< class T > 72 | void serialize2(const dlib::unordered_pair< T > &xs, std::ostream &s); 73 | 74 | template< class T > 75 | void deserialize2(dlib::unordered_pair< T > &xs, std::istream &s); 76 | 77 | template< class T > 78 | void serialize2(const dlib::radial_basis_kernel< T > &x, std::ostream &s); 79 | 80 | template< class T > 81 | void deserialize2(dlib::radial_basis_kernel< T > &x, std::istream &s); 82 | 83 | template< class K > 84 | void serialize2(const dlib::decision_function< K > &x, std::ostream &s); 85 | 86 | template< class K > 87 | void deserialize2(dlib::decision_function< K > &x, std::istream &s); 88 | 89 | template< template< class... > class DF, class T, class DF1, class... DFS > 90 | typename std::enable_if< is_multiclass_df< DF< T, DF1, DFS... > >::value >::type 91 | serialize2(const DF< T, DF1, DFS... > &x, std::ostream &s); 92 | 93 | template< template< class... > class DF, class T, class DF1, class... DFS > 94 | typename std::enable_if< is_multiclass_df< DF< T, DF1, DFS... > >::value >::type 95 | deserialize2(DF< T, DF1, DFS... > &x, std::istream &s); 96 | 97 | // Arithmetic types 98 | 99 | template< class T > 100 | typename std::enable_if< std::is_arithmetic< T >::value >::type 101 | serialize2(const T &x, std::ostream &s) { 102 | assert(!std::is_floating_point< T >::value || !std::isnan(x)); 103 | if (!s.write(reinterpret_cast< const char * >(&x), sizeof(T))) 104 | throw serialization_error(); 105 | } 106 | 107 | template< class T > 108 | typename std::enable_if< std::is_arithmetic< T >::value >::type 109 | deserialize2(T &x, std::istream &s) { 110 | if (!s.read(reinterpret_cast< char * >(&x), sizeof(T))) 111 | throw serialization_error(); 112 | assert(!std::is_floating_point< T >::value || !std::isnan(x)); 113 | } 114 | 115 | // std::pair 116 | 117 | template< class T1, class T2 > 118 | void serialize2(const std::pair< T1, T2 > &x, std::ostream &s) { 119 | serialize2(x.first, s); 120 | serialize2(x.second, s); 121 | } 122 | 123 | template< class T1, class T2 > 124 | void deserialize2(std::pair< T1, T2 > &x, std::istream &s) { 125 | deserialize2(x.first, s); 126 | deserialize2(x.second, s); 127 | } 128 | 129 | // std::basic_string 130 | 131 | template< class CharT > 132 | void serialize2(const std::basic_string< CharT > &x, std::ostream &s) { 133 | const auto size = x.size(); 134 | serialize2(size, s); 135 | if (!s.write(&x[0], sizeof(CharT) * size)) 136 | throw serialization_error(); 137 | } 138 | 139 | template< class CharT > 140 | void deserialize2(std::basic_string< CharT > &x, std::istream &s) { 141 | typename std::basic_string< CharT >::size_type size; 142 | std::vector< CharT > data; 143 | deserialize2(size, s); 144 | data.resize(size); 145 | if (!s.read(&data[0], sizeof(CharT) * size)) 146 | throw serialization_error(); 147 | x = std::basic_string< CharT >(data.begin(), data.end()); 148 | } 149 | 150 | // std::vector 151 | 152 | template< class T > 153 | void serialize2(const std::vector< T > &xs, std::ostream &s) { 154 | const auto size = xs.size(); 155 | serialize2(size, s); 156 | for (const auto &x : xs) 157 | serialize2(x, s); 158 | } 159 | 160 | template< class T > 161 | void deserialize2(std::vector< T > &xs, std::istream &s) { 162 | typename std::vector< T >::size_type size; 163 | deserialize2(size, s); 164 | xs.resize(size); 165 | for (auto &x : xs) 166 | deserialize2(x, s); 167 | } 168 | 169 | // std::multimap 170 | 171 | template< class Key, class T > 172 | void serialize2(const std::multimap< Key, T > &xs, std::ostream &s) { 173 | const auto size = xs.size(); 174 | serialize2(size, s); 175 | for (auto &pair : xs) 176 | serialize2(pair, s); 177 | } 178 | 179 | template< class Key, class T > 180 | void deserialize2(std::multimap< Key, T > &xs, std::istream &s) { 181 | typename std::multimap< Key, T >::size_type size; 182 | deserialize2(size, s); 183 | for (typename std::multimap< Key, T >::size_type i = 0; i < size; ++i) { 184 | std::pair< Key, T > pair; 185 | deserialize2(pair, s); 186 | xs.insert(pair); 187 | } 188 | } 189 | 190 | // dlib::matrix 191 | 192 | template< class T, long NR, long NC > 193 | void serialize2(const dlib::matrix< T, NR, NC > &x, std::ostream &s) { 194 | const long rows = x.nr(); 195 | const long cols = x.nc(); 196 | serialize2(rows, s); 197 | serialize2(cols, s); 198 | for (long j = 0; j < rows; ++j) { 199 | for (long i = 0; i < cols; ++i) 200 | serialize2(x(j, i), s); 201 | } 202 | } 203 | 204 | template< class T, long NR, long NC > 205 | void deserialize2(dlib::matrix< T, NR, NC > &x, std::istream &s) { 206 | long rows, cols; 207 | deserialize2(rows, s); 208 | deserialize2(cols, s); 209 | x.set_size(rows, cols); 210 | for (long j = 0; j < rows; ++j) { 211 | for (long i = 0; i < cols; ++i) 212 | deserialize2(x(j, i), s); 213 | } 214 | } 215 | 216 | // dlib::unordered_pair 217 | 218 | template< class T > 219 | void serialize2(const dlib::unordered_pair< T > &xs, std::ostream &s) { 220 | serialize2(xs.first, s); 221 | serialize2(xs.second, s); 222 | } 223 | 224 | template< class T > 225 | void deserialize2(dlib::unordered_pair< T > &xs, std::istream &s) { 226 | deserialize2(const_cast< T & >(xs.first), s); 227 | deserialize2(const_cast< T & >(xs.second), s); 228 | } 229 | 230 | // dlib::radial_basis_kernel 231 | 232 | template< class T > 233 | void serialize2(const dlib::radial_basis_kernel< T > &x, std::ostream &s) { 234 | serialize2(x.gamma, s); 235 | } 236 | 237 | template< class T > 238 | void deserialize2(dlib::radial_basis_kernel< T > &x, std::istream &s) { 239 | typedef dlib::radial_basis_kernel< T > kernel_type; 240 | deserialize2(const_cast< typename kernel_type::scalar_type & >(x.gamma), s); 241 | } 242 | 243 | // dlib::decision_function 244 | 245 | template< class K > 246 | void serialize2(const dlib::decision_function< K > &x, std::ostream &s) { 247 | serialize2(x.alpha, s); 248 | serialize2(x.b, s); 249 | serialize2(x.kernel_function, s); 250 | serialize2(x.basis_vectors, s); 251 | } 252 | 253 | template< class K > 254 | void deserialize2(dlib::decision_function< K > &x, std::istream &s) { 255 | deserialize2(x.alpha, s); 256 | deserialize2(x.b, s); 257 | deserialize2(x.kernel_function, s); 258 | deserialize2(x.basis_vectors, s); 259 | } 260 | 261 | // dlib::one_vs_all_decision_function 262 | // dlib::one_vs_one_decision_function 263 | 264 | template< template< class... > class DF, class T, class DF1, class... DFS > 265 | typename std::enable_if< is_multiclass_df< DF< T, DF1, DFS... > >::value >::type 266 | serialize2(const DF< T, DF1, DFS... > &x, std::ostream &s) { 267 | const auto &dfs = x.get_binary_decision_functions(); 268 | const auto size = dfs.size(); 269 | serialize2(size, s); 270 | for (const auto &pair : dfs) { 271 | serialize2(std::make_pair(pair.first, 272 | dlib::any_cast< DF1 >(pair.second)), s); 273 | } 274 | } 275 | 276 | template< template< class... > class DF, class T, class DF1, class... DFS > 277 | typename std::enable_if< is_multiclass_df< DF< T, DF1, DFS... > >::value >::type 278 | deserialize2(DF< T, DF1, DFS... > &x, std::istream &s) { 279 | typedef DF< T, DF1, DFS... > df_type; 280 | typedef typename df_type::binary_function_table binary_function_table; 281 | typename binary_function_table::size_type size; 282 | binary_function_table dfs; 283 | deserialize2(size, s); 284 | for (typename binary_function_table::size_type i = 0; i < size; ++i) { 285 | std::pair< typename binary_function_table::key_type, DF1 > pair; 286 | deserialize2(pair, s); 287 | dfs.insert(pair); 288 | } 289 | x = df_type(dfs); 290 | } 291 | 292 | #endif 293 | -------------------------------------------------------------------------------- /src/svm.h: -------------------------------------------------------------------------------- 1 | #ifndef SVM_H 2 | #define SVM_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | // A trainer for one-vs-all multi-class classifiers 14 | template< class AnyTrainer, class LabelT, bool Verbose = false > 15 | struct one_vs_all_trainer2 { 16 | typedef LabelT label_type; 17 | typedef typename AnyTrainer::sample_type sample_type; 18 | typedef typename AnyTrainer::scalar_type scalar_type; 19 | typedef typename AnyTrainer::mem_manager_type mem_manager_type; 20 | typedef dlib::one_vs_all_decision_function< one_vs_all_trainer2 > 21 | trained_function_type; 22 | 23 | one_vs_all_trainer2(const AnyTrainer &trainer_) : trainer(trainer_) { 24 | } 25 | 26 | trained_function_type train(const std::vector< sample_type > &samples, 27 | const std::vector< label_type > &labels) const { 28 | assert(dlib::is_learning_problem(samples, labels)); 29 | 30 | const std::vector< label_type > distinct_labels = 31 | dlib::select_all_distinct_labels(labels); 32 | typename trained_function_type::binary_function_table dfs; 33 | 34 | #pragma omp parallel 35 | { 36 | std::vector< scalar_type > set_labels; 37 | 38 | #pragma omp for schedule(dynamic) 39 | for (typename std::vector< label_type >::size_type i = 0; 40 | i < distinct_labels.size(); ++i) { 41 | const label_type &label = distinct_labels[i]; 42 | set_labels.clear(); 43 | 44 | // Set up the one-vs-all training set. 45 | for (typename std::vector< sample_type >::size_type k = 0; 46 | k < samples.size(); ++k) 47 | set_labels.push_back((labels[k] == label) ? 1 : -1); 48 | 49 | if (Verbose) { 50 | #pragma omp critical 51 | { 52 | std::cout << "Training classifier " << i + 1 << '/' 53 | << distinct_labels.size() << "...\n"; 54 | } 55 | } 56 | 57 | // Train the classifier. 58 | typename trained_function_type::binary_function_table::mapped_type 59 | df = trainer.train(samples, set_labels); 60 | 61 | #pragma omp critical 62 | { 63 | dfs[label] = df; 64 | } 65 | } 66 | } 67 | 68 | return trained_function_type(dfs); 69 | } 70 | 71 | private: 72 | AnyTrainer trainer; 73 | }; 74 | 75 | // A trainer for one-vs-one multi-class classifiers 76 | template< class AnyTrainer, class LabelT, bool Verbose = false > 77 | struct one_vs_one_trainer2 { 78 | typedef LabelT label_type; 79 | typedef typename AnyTrainer::sample_type sample_type; 80 | typedef typename AnyTrainer::scalar_type scalar_type; 81 | typedef typename AnyTrainer::mem_manager_type mem_manager_type; 82 | typedef dlib::one_vs_one_decision_function< one_vs_one_trainer2 > 83 | trained_function_type; 84 | 85 | one_vs_one_trainer2(const AnyTrainer &trainer_) : trainer(trainer_) { 86 | } 87 | 88 | trained_function_type train(const std::vector< sample_type > &samples, 89 | const std::vector< label_type > &labels) const { 90 | assert(dlib::is_learning_problem(samples, labels)); 91 | 92 | const std::vector< label_type > distinct_labels = 93 | dlib::select_all_distinct_labels(labels); 94 | typename trained_function_type::binary_function_table dfs; 95 | 96 | typename std::vector< label_type >::size_type n = 0; 97 | #pragma omp parallel 98 | { 99 | std::vector< sample_type > set_samples; 100 | std::vector< scalar_type > set_labels; 101 | 102 | #pragma omp for schedule(dynamic) 103 | for (typename std::vector< label_type >::size_type i = 0; 104 | i < distinct_labels.size(); ++i) { 105 | const label_type &label1 = distinct_labels[i]; 106 | for (typename std::vector< label_type >::size_type j = i + 1; 107 | j < distinct_labels.size(); ++j) { 108 | const label_type &label2 = distinct_labels[j]; 109 | const dlib::unordered_pair< label_type > pair(label1, label2); 110 | 111 | // Set up the one-vs-one training set. 112 | set_samples.clear(); 113 | set_labels.clear(); 114 | for (typename std::vector< sample_type >::size_type k = 0; 115 | k < samples.size(); ++k) { 116 | if (labels[k] == pair.first) { 117 | set_samples.push_back(samples[k]); 118 | set_labels.push_back(1); 119 | } 120 | else if (labels[k] == pair.second) { 121 | set_samples.push_back(samples[k]); 122 | set_labels.push_back(-1); 123 | } 124 | } 125 | 126 | if (Verbose) { 127 | #pragma omp critical 128 | { 129 | std::cout << "Training classifier " << n + 1 << '/' 130 | << distinct_labels.size() * (distinct_labels.size() - 1) / 2 131 | << "...\n"; 132 | ++n; 133 | } 134 | } 135 | 136 | // Train the classifier. 137 | typename trained_function_type::binary_function_table::mapped_type 138 | df = trainer.train(set_samples, set_labels); 139 | 140 | #pragma omp critical 141 | { 142 | dfs[pair] = df; 143 | } 144 | } 145 | } 146 | } 147 | 148 | return trained_function_type(dfs); 149 | } 150 | 151 | private: 152 | AnyTrainer trainer; 153 | }; 154 | 155 | // Run a multi-class decision function on a test set, returning the confusion 156 | // matrix. 157 | template< class DF, class SampleT, class LabelT, bool Verbose = false > 158 | const dlib::matrix< double > test_multiclass_decision_function2( 159 | const DF &df, const std::vector< SampleT > &test_samples, 160 | const std::vector< LabelT > &test_labels) { 161 | typedef std::map< LabelT, typename std::vector< LabelT >::size_type > 162 | label_count_map_type; 163 | typedef typename DF::mem_manager_type mem_manager_type; 164 | 165 | assert(is_learning_problem(test_samples, test_labels)); 166 | 167 | const std::vector< LabelT > &labels = df.get_labels(); 168 | 169 | label_count_map_type label_offsets; 170 | for (typename std::vector< LabelT >::size_type i = 0; 171 | i < labels.size(); ++i) 172 | label_offsets[labels[i]] = i; 173 | 174 | dlib::matrix< double, 0, 0, mem_manager_type > conf(labels.size(), 175 | labels.size()); 176 | conf = 0; 177 | 178 | typename std::vector< LabelT >::size_type n = 0; 179 | #pragma omp parallel for 180 | for (typename std::vector< SampleT >::size_type i = 0; 181 | i < test_samples.size(); ++i) { 182 | const auto it = label_offsets.find(test_labels[i]); 183 | assert(it != label_offsets.end()); 184 | 185 | if (Verbose) { 186 | #pragma omp critical 187 | { 188 | std::cout << "Classifying sample " << n + 1 << '/' 189 | << test_samples.size() << "...\n"; 190 | ++n; 191 | } 192 | } 193 | 194 | const auto pred_offset = label_offsets.find(df(test_samples[i]))->second; 195 | 196 | #pragma omp critical 197 | { 198 | ++conf(it->second, pred_offset); 199 | } 200 | } 201 | 202 | return conf; 203 | } 204 | 205 | // Cross-validation for multi-class classifiers 206 | template< class Trainer, class SampleT, class LabelT, bool Verbose = false > 207 | const dlib::matrix< double > cross_validate_multiclass_trainer2( 208 | const Trainer &trainer, const std::vector< SampleT > &samples, 209 | const std::vector< LabelT > &labels, const unsigned long folds) { 210 | typedef std::map< LabelT, typename std::vector< LabelT >::size_type > 211 | label_count_map_type; 212 | typedef typename Trainer::mem_manager_type mem_manager_type; 213 | 214 | assert(is_learning_problem(samples, labels) && 1 < folds && 215 | folds <= samples.size()); 216 | 217 | const std::vector< LabelT > distinct_labels = 218 | dlib::select_all_distinct_labels(labels); 219 | 220 | // Count the occurrences of each label. 221 | label_count_map_type label_counts; 222 | for (const auto &label : labels) 223 | ++label_counts[label]; 224 | 225 | // Determine the sizes of the test and the training sets. 226 | label_count_map_type test_sizes, train_sizes; 227 | for (const auto &pair : label_counts) { 228 | const typename label_count_map_type::mapped_type test_size = 229 | pair.second / folds; 230 | if (!test_size) { 231 | std::ostringstream ss; 232 | ss << "In cross_validate_multiclass_trainer2(), the number of folds" 233 | " was larger than the number of elements in one of the training" 234 | " classes.\n folds: " << folds << "\n size of class: " 235 | << pair.second << '\n'; 236 | throw dlib::cross_validation_error(ss.str()); 237 | } 238 | 239 | test_sizes[pair.first] = test_size; 240 | train_sizes[pair.first] = pair.second - test_size; 241 | } 242 | 243 | dlib::matrix< double, 0, 0, mem_manager_type > conf(labels.size(), 244 | labels.size()); 245 | conf = 0; 246 | 247 | label_count_map_type next_offsets; 248 | 249 | std::vector< SampleT > test_samples, train_samples; 250 | std::vector< LabelT > test_labels, train_labels; 251 | 252 | // Train and test with each fold configuration. 253 | for (unsigned long i = 0; i < folds; ++i) { 254 | test_samples.clear(); 255 | train_samples.clear(); 256 | test_labels.clear(); 257 | train_labels.clear(); 258 | 259 | // Load the test samples. 260 | for (const auto &label : distinct_labels) { 261 | const auto test_size = test_sizes[label]; 262 | 263 | unsigned long &next_offset = next_offsets[label]; 264 | unsigned long size = 0; 265 | while (size < test_size) { 266 | if (labels[next_offset] == label) { 267 | test_samples.push_back(samples[next_offset]); 268 | test_labels.push_back(label); 269 | ++size; 270 | } 271 | 272 | next_offset = (next_offset + 1) % samples.size(); 273 | } 274 | } 275 | 276 | // Load the training samples. 277 | for (const auto &label : distinct_labels) { 278 | const auto train_size = train_sizes[label]; 279 | 280 | unsigned long &next_offset = next_offsets[label]; 281 | unsigned long size = 0; 282 | while (size < train_size) { 283 | if (labels[next_offset] == label) { 284 | train_samples.push_back(samples[next_offset]); 285 | train_labels.push_back(label); 286 | ++size; 287 | } 288 | 289 | next_offset = (next_offset + 1) % samples.size(); 290 | } 291 | } 292 | 293 | if (Verbose) { 294 | std::cout << "Running cross-validation on fold " << i + 1 << '/' 295 | << folds << "...\n"; 296 | } 297 | 298 | conf += test_multiclass_decision_function2< 299 | typename Trainer::trained_function_type, SampleT, LabelT, Verbose >( 300 | trainer.train(train_samples, train_labels), test_samples, test_labels); 301 | } 302 | 303 | return conf; 304 | } 305 | 306 | #endif 307 | -------------------------------------------------------------------------------- /src/gui.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include "features.h" 23 | #include "io.h" 24 | #include "svm.h" 25 | #include "types.h" 26 | 27 | struct point { 28 | double x, y; 29 | }; 30 | 31 | typedef std::vector< point > path_type; 32 | 33 | // The width of a line as a fraction of the size of the image. 34 | const double line_width = 0.00375; 35 | 36 | const int sketch_timeout = 500; // ms 37 | const int sketch_min_size = 256; // px 38 | 39 | class SketchArea : public Gtk::DrawingArea 40 | { 41 | public: 42 | SketchArea() { 43 | set_size_request(sketch_min_size, sketch_min_size); 44 | add_events(Gdk::BUTTON_PRESS_MASK | Gdk::EXPOSURE_MASK | 45 | Gdk::POINTER_MOTION_MASK | Gdk::POINTER_MOTION_HINT_MASK); 46 | } 47 | 48 | virtual ~SketchArea() { 49 | } 50 | 51 | // Draw the current sketch to a matrix, converting it to grayscale. 52 | template< class T, long NR, long NC > 53 | void draw(dlib::matrix< T, NR, NC > &image) const { 54 | Cairo::RefPtr< Cairo::ImageSurface > surface = 55 | Cairo::ImageSurface::create(Cairo::FORMAT_RGB24, NC, NR); 56 | Cairo::RefPtr< Cairo::Context > cr = Cairo::Context::create(surface); 57 | 58 | cr->set_source_rgb(1.0, 1.0, 1.0); 59 | cr->paint(); 60 | 61 | cr->scale(NC, NR); 62 | draw(cr); 63 | 64 | cr->show_page(); 65 | 66 | { 67 | const unsigned char *p = surface->get_data(); 68 | const int stride = surface->get_stride(); 69 | 70 | const unsigned char *q = p; 71 | for (long j = 0; j < NR; ++j, q = p += stride) { 72 | for (long i = 0; i < NC; ++i, q += 4) { 73 | // Convert the image to grayscale. 74 | image(j, i) = (.299 * q[0] + .587 * q[1] + .114 * q[2]) / 255.; 75 | } 76 | } 77 | } 78 | } 79 | 80 | // Scale and center the image to fit the canvas. 81 | void scale() { 82 | if (paths.empty()) 83 | return; 84 | 85 | double x0, x1, y0, y1; 86 | x0 = y0 = 1.0; 87 | x1 = y1 = 0.0; 88 | for (const auto &path : paths) { 89 | for (const point &p : path) { 90 | x0 = std::min(x0, p.x); 91 | x1 = std::max(x1, p.x); 92 | y0 = std::min(y0, p.y); 93 | y1 = std::max(y1, p.y); 94 | } 95 | } 96 | 97 | const double dx = x1 - x0; 98 | const double dy = y1 - y0; 99 | const double dmax = std::max(dx, dy); 100 | 101 | for (auto &path : paths) { 102 | for (point &p : path) { 103 | p.x = 0.5 + (p.x - x0 - dx / 2.0) * 0.8 / dmax; 104 | p.y = 0.5 + (p.y - y0 - dy / 2.0) * 0.8 / dmax; 105 | } 106 | } 107 | 108 | signal_update.emit(); 109 | invalidate(); 110 | } 111 | 112 | // Clear the sketch, removing all paths. 113 | void clear() { 114 | paths.clear(); 115 | signal_update.emit(); 116 | invalidate(); 117 | } 118 | 119 | sigc::signal< void > signal_update; 120 | 121 | protected: 122 | // Invalidate the entire drawing area. 123 | void invalidate() { 124 | Gtk::Allocation allocation = get_allocation(); 125 | const int width = allocation.get_width(); 126 | const int height = allocation.get_height(); 127 | Gdk::Rectangle rect(0, 0, width, height); 128 | get_window()->invalidate_rect(rect, false); 129 | } 130 | 131 | // Draw all paths to a Cairo context. 132 | void draw(const Cairo::RefPtr< Cairo::Context > &cr) const { 133 | cr->set_source_rgb(0.0, 0.0, 0.0); 134 | cr->set_line_width(line_width); 135 | cr->set_line_cap(Cairo::LINE_CAP_ROUND); 136 | cr->set_line_join(Cairo::LINE_JOIN_ROUND); 137 | 138 | for (const auto &path : paths) { 139 | if (path.empty()) 140 | continue; 141 | 142 | path_type::const_iterator it = path.begin(); 143 | cr->move_to(it->x, it->y); 144 | for (; it < path.end(); ++it) 145 | cr->line_to(it->x, it->y); 146 | cr->stroke(); 147 | } 148 | } 149 | 150 | virtual bool on_button_press_event(GdkEventButton *event) { 151 | if (event->button == 1) { 152 | // Create a new path starting at the current cursor position. 153 | if (paths.empty() || !get_path().empty()) { 154 | paths.push_back(path_type()); 155 | add_point(event->x, event->y); 156 | } 157 | } 158 | 159 | return true; 160 | } 161 | 162 | virtual bool on_expose_event(GdkEventExpose *) { 163 | Cairo::RefPtr< Cairo::Context > cr = get_window()->create_cairo_context(); 164 | 165 | Gtk::Allocation allocation = get_allocation(); 166 | const int width = allocation.get_width(); 167 | const int height = allocation.get_height(); 168 | 169 | cr->set_source_rgb(1.0, 1.0, 1.0); 170 | cr->paint(); 171 | 172 | cr->scale(width, height); 173 | draw(cr); 174 | 175 | cr->show_page(); 176 | 177 | return true; 178 | } 179 | 180 | virtual bool on_motion_notify_event(GdkEventMotion *event) { 181 | int x, y; 182 | Gdk::ModifierType state; 183 | 184 | if (event->is_hint) { 185 | get_window()->get_pointer(x, y, state); 186 | } 187 | else { 188 | x = static_cast< int >(event->x); 189 | y = static_cast< int >(event->y); 190 | state = static_cast< Gdk::ModifierType >(event->state); 191 | } 192 | 193 | if (!paths.empty() && state & Gdk::BUTTON1_MASK) { 194 | // Add a point to the current path if the first button is pressed. 195 | add_point(x, y); 196 | } 197 | 198 | return true; 199 | } 200 | 201 | // Return the current path. 202 | path_type &get_path() { 203 | assert(!paths.empty()); 204 | return paths.back(); 205 | } 206 | 207 | // Add a point (in widget coordinates) to the current path. 208 | void add_point(int x, int y) { 209 | Gtk::Allocation allocation = get_allocation(); 210 | const int width = allocation.get_width(); 211 | const int height = allocation.get_height(); 212 | 213 | point p; 214 | double x0, x1, y0, y1; 215 | p.x = x0 = x1 = static_cast< double >(x) / width; 216 | p.y = y0 = y1 = static_cast< double >(y) / height; 217 | 218 | path_type &path = get_path(); 219 | if (!path.empty()) { 220 | const point &pp = path.back(); 221 | x0 = std::min(x0, pp.x); 222 | x1 = std::max(x1, pp.x); 223 | y0 = std::min(y0, pp.y); 224 | y1 = std::max(y1, pp.y); 225 | } 226 | 227 | path.push_back(p); 228 | 229 | Gdk::Rectangle rect(std::floor((x0 - line_width) * width), 230 | std::floor((y0 - line_width) * height), 231 | std::ceil((x1 - x0 + 2 * line_width) * width), 232 | std::ceil((y1 - y0 + 2 * line_width) * height)); 233 | 234 | signal_update.emit(); 235 | get_window()->invalidate_rect(rect, false); 236 | } 237 | 238 | std::vector< path_type > paths; 239 | }; 240 | 241 | class MainWindow : public Gtk::Window 242 | { 243 | public: 244 | MainWindow(const vocab_type *vocab_, const std::map< int, std::string > 245 | *cat_map_, bool ova_, df_type *df_) : hbox(true, 10), 246 | vocab(vocab_), cat_map(cat_map_), ova(ova_), df(df_) { 247 | // Set up the window. 248 | set_title("Sketch recognition"); 249 | set_size_request(800, 400); 250 | 251 | add(vbox); 252 | 253 | Glib::RefPtr< Gtk::ActionGroup > action_group = 254 | Gtk::ActionGroup::create(); 255 | 256 | action_group->add(Gtk::Action::create("New", Gtk::Stock::NEW, "_New", 257 | "Create a new sketch"), sigc::mem_fun(*this, &MainWindow::on_new)); 258 | action_group->add(Gtk::Action::create("ScaleToFit", Gtk::Stock::ZOOM_FIT, 259 | "_Scale to Fit", "Scale the current sketch to fit the canvas"), 260 | sigc::mem_fun(*this, &MainWindow::on_scale)); 261 | action_group->add(Gtk::Action::create("Quit", Gtk::Stock::QUIT, "_Quit"), 262 | sigc::mem_fun(*this, &MainWindow::on_quit)); 263 | 264 | Glib::RefPtr< Gtk::UIManager > ui_manager = Gtk::UIManager::create(); 265 | ui_manager->insert_action_group(action_group); 266 | add_accel_group(ui_manager->get_accel_group()); 267 | 268 | Glib::ustring ui_info = 269 | "" 270 | " " 271 | " " 272 | " " 273 | " " 274 | " " 275 | ""; 276 | ui_manager->add_ui_from_string(ui_info); 277 | 278 | vbox.pack_start(*ui_manager->get_widget("/ToolBar"), Gtk::PACK_SHRINK); 279 | 280 | hbox.set_border_width(10); 281 | vbox.pack_start(hbox); 282 | 283 | sketch_frame.set_shadow_type(Gtk::SHADOW_NONE); 284 | hbox.pack_start(sketch_frame); 285 | 286 | sketch.signal_update.connect(sigc::mem_fun(*this, 287 | &MainWindow::on_sketch_update)); 288 | sketch_frame.add(sketch); 289 | 290 | cat_label.set_text("Draw in the box to begin."); 291 | cat_label.set_justify(Gtk::JUSTIFY_CENTER); 292 | cat_label.set_line_wrap(); 293 | hbox.pack_start(cat_label); 294 | 295 | show_all(); 296 | } 297 | 298 | virtual ~MainWindow() { 299 | } 300 | 301 | protected: 302 | virtual void on_new() { 303 | sketch.clear(); 304 | } 305 | 306 | virtual void on_scale() { 307 | sketch.scale(); 308 | } 309 | 310 | virtual void on_quit() { 311 | hide(); 312 | } 313 | 314 | virtual void on_sketch_update() { 315 | if (sketch_timer_conn.connected()) 316 | sketch_timer_conn.disconnect(); 317 | 318 | sketch_timer_conn = Glib::signal_timeout().connect(sigc::mem_fun(*this, 319 | &MainWindow::on_sketch_timeout), sketch_timeout); 320 | } 321 | 322 | virtual bool on_sketch_timeout() { 323 | image_type image; 324 | sketch.draw(image); 325 | image = 1. - image; 326 | 327 | std::vector< feature_desc_type > descs; 328 | extract_descriptors(image, descs); 329 | 330 | feature_hist_type hist; 331 | feature_hist(descs, *vocab, hist); 332 | 333 | const int cat = ova ? df->get< ova_df_type >()(hist) : 334 | df->get< ovo_df_type >()(hist); 335 | 336 | const auto it = cat_map->find(cat); 337 | assert(it != cat_map->end()); 338 | 339 | std::ostringstream ss; 340 | ss << "" << it->second << ""; 341 | cat_label.set_markup(ss.str()); 342 | 343 | return false; 344 | } 345 | 346 | sigc::connection sketch_timer_conn; 347 | 348 | Gtk::VBox vbox; 349 | Gtk::HBox hbox; 350 | Gtk::AspectFrame sketch_frame; 351 | SketchArea sketch; 352 | Gtk::Label cat_label; 353 | 354 | const vocab_type *vocab; 355 | const std::map< int, std::string > *cat_map; 356 | bool ova; 357 | df_type *df; 358 | }; 359 | 360 | int main(int argc, char* argv[]) 361 | { 362 | Gtk::Main app(argc, argv); 363 | 364 | // Process the command-line arguments. 365 | const char *vocab_path = "vocab.out"; 366 | const char *map_path = "map_id_label.txt"; 367 | const char *cats_path = "cats.out"; 368 | bool ova = true; 369 | 370 | { 371 | int i; 372 | for (i = 1; i < argc; ++i) { 373 | if (!strcmp(argv[i], "-h")) { 374 | goto usage; 375 | } 376 | else if (!strcmp(argv[i], "-v")) { 377 | vocab_path = argv[++i]; 378 | } 379 | else if (!strcmp(argv[i], "-m")) { 380 | map_path = argv[++i]; 381 | } 382 | else if (!strcmp(argv[i], "-c")) { 383 | ++i; 384 | if (!strcmp(argv[i], "ova")) { 385 | ova = true; 386 | } 387 | else if (!strcmp(argv[i], "ovo")) { 388 | ova = false; 389 | } 390 | else { 391 | std::cerr << argv[0] << ": Unsupported classifier: `" << argv[i] 392 | << "'\n"; 393 | goto err; 394 | } 395 | } 396 | else { 397 | break; 398 | } 399 | } 400 | 401 | if (i < argc) 402 | cats_path = argv[i++]; 403 | 404 | if (i != argc) 405 | goto usage; 406 | 407 | if (!vocab_path || !map_path || !cats_path) 408 | goto usage; 409 | } 410 | 411 | { 412 | // Load the vocabulary. 413 | vocab_type vocab; 414 | { 415 | std::ifstream fs(vocab_path, std::ios::binary); 416 | deserialize2(vocab, fs); 417 | } 418 | 419 | // Load the category map. 420 | std::map< int, std::string > cat_map; 421 | { 422 | std::ifstream fs(map_path); 423 | for (std::string line; std::getline(fs, line);) { 424 | std::istringstream ss(line); 425 | int i; 426 | std::string label; 427 | ss >> i; 428 | ss.get(); // ',' 429 | std::getline(ss, label); 430 | cat_map[i] = label; 431 | } 432 | } 433 | 434 | // Load the category classifier. 435 | df_type df; 436 | { 437 | std::ifstream fs(cats_path, std::ios::binary); 438 | if (ova) 439 | deserialize2(df.get< ova_df_type >(), fs); 440 | else 441 | deserialize2(df.get< ovo_df_type >(), fs); 442 | } 443 | 444 | MainWindow win(&vocab, &cat_map, ova, &df); 445 | app.run(win); 446 | } 447 | 448 | return 0; 449 | 450 | usage: 451 | std::cerr << "Usage: " << argv[0] 452 | << " [-v vocab-file] [-m map-file] [-c classifier] [cats-file]\n"; 453 | 454 | err: 455 | return 1; 456 | } 457 | --------------------------------------------------------------------------------