├── .gitignore ├── LICENSE ├── README.md ├── cpp ├── CMake │ ├── FindEigen.cmake │ └── FindHalide.cmake ├── CMakeLists.txt ├── chrono.h ├── cimg │ └── CImg.h ├── histogramIO.h ├── inverseWasserstein.h ├── kernels.h ├── lbfgs │ ├── arithmetic_ansi.h │ ├── arithmetic_sse_double.h │ ├── arithmetic_sse_float.h │ ├── lbfgs.c │ └── lbfgs.h ├── loss.h ├── main_dictionary_learning.cpp ├── signArray.h └── sse_helpers.h └── data ├── imgheart2 ├── heart2_seq-00.png ├── heart2_seq-01.png ├── heart2_seq-02.png ├── heart2_seq-03.png ├── heart2_seq-04.png ├── heart2_seq-05.png ├── heart2_seq-06.png ├── heart2_seq-07.png ├── heart2_seq-08.png ├── heart2_seq-09.png ├── heart2_seq-10.png ├── heart2_seq-11.png └── heart2_seq-12.png └── mug_001_expr2 ├── img_0021.png ├── img_0023.png ├── img_0026.png ├── img_0028.png ├── img_0029.png ├── img_0032.png ├── img_0034.png ├── img_0036.png ├── img_0038.png ├── img_0039.png ├── img_0040.png ├── img_0041.png ├── img_0043.png ├── img_0044.png ├── img_0045.png ├── img_0046.png ├── img_0047.png ├── img_0050.png ├── img_0110.png └── img_0117.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.obj 6 | 7 | # Precompiled Headers 8 | *.gch 9 | *.pch 10 | 11 | # Compiled Dynamic libraries 12 | *.so 13 | *.dylib 14 | # *.dll 15 | 16 | # Fortran module files 17 | *.mod 18 | *.smod 19 | 20 | # Compiled Static libraries 21 | *.lai 22 | *.la 23 | *.a 24 | # *.lib 25 | 26 | # Executables 27 | *.exe 28 | *.out 29 | *.app 30 | 31 | # IDE files 32 | .tags* 33 | CMakeLists.txt.user 34 | 35 | # Platform-specific files 36 | *.DS_Store 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 
5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 
46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 
80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 
115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. 
If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Wasserstein Dictionary Learning [Schmitz et al. 2018] 2 | 3 | This repository contains the code for the following publication. Please credit this reference if you use it. 
4 | 5 | @article{schmitz_wasserstein_2018, 6 | title = {Wasserstein {Dictionary} {Learning}: {Optimal} {Transport}-based unsupervised non-linear dictionary learning}, 7 | shorttitle = {Wasserstein {Dictionary} {Learning}}, 8 | url = {https://hal.archives-ouvertes.fr/hal-01717943}, 9 | journal = {SIAM Journal on Imaging Sciences}, 10 | author = {Schmitz, Morgan A and Heitz, Matthieu and Bonneel, Nicolas and Ngolè Mboula, Fred Maurice and Coeurjolly, David and Cuturi, Marco and Peyré, Gabriel and Starck, Jean-Luc}, 11 | year = {2018}, 12 | keywords = {Dictionary Learning, Optimal Transport, Wasserstein barycenter}, 13 | } 14 | 15 | The full text is available on [HAL](https://hal.archives-ouvertes.fr/hal-01717943) and [arXiv](https://arxiv.org/abs/1708.01955). 16 | 17 | 18 | ### Configure, build and run 19 | 20 | There is a CMakeLists.txt for the project, so you can just create a build directory *outside the source*. 21 | 22 | $ mkdir build 23 | $ cd build 24 | $ ccmake ../inverseWasserstein/cpp/ 25 | 26 | In CMake, you can activate the different options and targets, then configure and generate a project file for your system (Makefile, VS project, etc.). 27 | 28 | Target : 29 | - `BUILD_APP_DICTIONARY_LEARNING` : Target for Dictionary Learning on images (2D histograms) 30 | 31 | The different options let you link different libraries in order to use different kernels (see below) 32 | - `WITH_HALIDE` 33 | - `WITH_EIGEN` 34 | - `WITH_AVX_SUPPORT` 35 | - `WITH_OPENMP` 36 | 37 | ### Kernels 38 | 39 | The kernel is the part of the algorithm that does the convolution in the Sinkhorn barycenter algorithm. We provide multiple kernels in `kernels.h`. 40 | Some of them use external libraries that are activated with the options mentioned above. 41 | You can change the kernel used by changing the `typedef XXX KernelType` in `main_dictionary_learning.cpp`. 
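As a rough illustration of the `typedef XXX KernelType` mechanism described above, the sketch below shows how a single type alias can select the convolution implementation at compile time. The class names `DenseGaussianKernel1D` and `TruncatedGaussianKernel1D`, their `convolve` method, and the 1D Gaussian form `exp(-d^2 / gamma)` are assumptions made for this example. They are not the classes defined in `kernels.h`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical kernel (not from kernels.h): dense 1D Gaussian convolution.
struct DenseGaussianKernel1D {
    std::size_t n;
    double gamma;
    DenseGaussianKernel1D(std::size_t size, double g) : n(size), gamma(g) {
        assert(size > 0 && g > 0.0);
    }
    // out[i] = sum_j exp(-(i - j)^2 / gamma) * u[j]
    std::vector<double> convolve(const std::vector<double>& u) const {
        assert(u.size() == n);
        std::vector<double> out(n, 0.0);
        for (std::size_t i = 0; i < n; ++i) {
            for (std::size_t j = 0; j < n; ++j) {
                const double d = static_cast<double>(i) - static_cast<double>(j);
                out[i] += std::exp(-d * d / gamma) * u[j];
            }
        }
        return out;
    }
};

// Hypothetical kernel (not from kernels.h): same operator, but only
// neighbours within `radius` bins contribute to each output value.
struct TruncatedGaussianKernel1D {
    std::size_t n;
    double gamma;
    std::size_t radius;
    TruncatedGaussianKernel1D(std::size_t size, double g, std::size_t r)
        : n(size), gamma(g), radius(r) {
        assert(size > 0 && g > 0.0);
    }
    std::vector<double> convolve(const std::vector<double>& u) const {
        assert(u.size() == n);
        std::vector<double> out(n, 0.0);
        for (std::size_t i = 0; i < n; ++i) {
            const std::size_t lo = (i > radius) ? i - radius : 0;
            const std::size_t hi = std::min(n - 1, i + radius);
            for (std::size_t j = lo; j <= hi; ++j) {
                const double d = static_cast<double>(i) - static_cast<double>(j);
                out[i] += std::exp(-d * d / gamma) * u[j];
            }
        }
        return out;
    }
};

// Changing this one line switches the kernel used by all code written
// against KernelType, as long as the chosen class exposes the same interface.
typedef DenseGaussianKernel1D KernelType;
```

Because both classes expose the same `convolve` signature, code written against `KernelType` does not need to change when the alias is pointed at another class. Only the constructor arguments may differ.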
42 | 43 | ### Histogram IO 44 | 45 | The file `histogramIO.h` contains helper functions to load many kinds of histograms (1D, 2D, 3D, nD). It also contains a set of classes to export histograms (in different formats), which are used during the optimization. 46 | The example given only uses the functions for 2D histograms, but you can easily adapt it to other types of histograms thanks to these helper functions. 47 | 48 | ### Halide 49 | 50 | [Halide](http://halide-lang.org/) is a language for image processing that uses JIT compilation, and that allows fast convolutions (~3x faster than with SSE). It is optional, but we recommend its use. You will need to set the CMake variable HALIDE_ROOT_DIR to the Halide folder you downloaded, in order to use it. 51 | 52 | ### Files 53 | 54 | The 4 important files of the algorithm are : 55 | `inverseWasserstein.h`: Core of the algorithm 56 | `kernels.h`: Classes that compute convolutions in different ways for the Sinkhorn barycenter algorithm. 57 | `loss.h`: Loss functions and their gradients 58 | `histogramIO.h`: Classes and functions for reading and writing histograms of different dimensions, in different formats. 59 | 60 | The file `main_dictionary_learning.cpp` is an example of how to use the class `inverseWasserstein.h::WassersteinRegression` for your application. 61 | 62 | The other files are : 63 | `chrono.h`: Measure execution time 64 | `signArray.h`: Needed if you are using log-domain kernels (see Extensions). 65 | `sse_helpers.h`: Needed if you are using SSE-based kernels. 66 | 67 | 68 | ### Examples 69 | 70 | For the `app_dictionary_learning` program, here are a few examples that run the algorithm on some toy images. 
71 | 72 | - Cardiac cycle : (~3 hours on a 16 core) 73 | `./app_dictionary_learning -i ../data/imgheart2 -o outputFiles -k 4 -l 2 -n 25 -s 100 -g 2 -x 200` 74 | 75 | - Wasserstein faces : ( ~15 hours on a 16 core) 76 | `./app_dictionary_learning -i ../data/mug_001_expr2 -o outputFiles -k 5 -l 4 -n 100 -s 100 -g 1 -a 3 -x 500 -m 100 --imComplement` 77 | 78 | - Wasserstein faces with warm restart : (~2.5 hour on a 16 core) 79 | `./app_dictionary_learning -i ../data/mug_001_expr2 -o outputFiles -k 5 -l 4 -n 5 -s 100 -g 1 -a 3 -x 500 -m 100 --imComplement --warmRestart` 80 | 81 | 82 | ### CLI parameters and options 83 | 84 | CLI parameters for `app_dictionary_learning` : 85 | 86 | Parameter / Option | Explanation | Default | Typical 87 | --------------------------|--------------------------------------------------------|---------|--------------- 88 | `-i ` | Input directory where to find the input histograms | - | - 89 | `-o ` | Output directory where to write results | - | - 90 | `[-k ]` | Number of atoms to find | 4 | [2-10] 91 | `[-l ]` | Loss function | 2 | [1,4] 92 | `[-n ]` | Number of Sinkhorn iterations | 25 | [20-500],[2,20]1 93 | `[-s ]` | Factor for scaling between weights and atom values | 100 | [10,1000] 94 | `[-g ]` | Entropic regularization parameter | 2 | [0.5,50] 95 | `[-a ]` | Value for gamma correction of input images | 1 | [1-5] 96 | `[-x ]` | Max number of optimization iterations | 200 | [50,1000] 97 | `[-m ]` | Saving frequency for intermediate results saving | 0 | [0,maxOptimIter] 98 | `[--deterministic]` | Generate random initial weights in a deterministic way | OFF | - 99 | `[--imComplement]` | Invert the values of input images | OFF | - 100 | `[--allowNegWeights]` | Do not force the barycentric weights to be positive. 
| OFF | - 101 | `[--warmRestart]` | Activate the warm restart technique | OFF | - 102 | 103 | 1 Range when using the warm restart 104 | 105 | ##### Additional information on parameters 106 | 107 | - `[-l ]` : 1 (Total Variation), 2 (Quadratic Loss), 3 (Wasserstein Loss), 4 (Kullback-Leibler Loss) 108 | - `[-s ]` : Should be tuned : too high a value will update the weights too much, and too low a value will update them too little. The values we used were between 10 and 1000, depending on the problem size. 109 | - `[-a ]` : This is useful when working with 2D images. If there is a close-to-zero background level, 110 | raising the image to the power $\alpha$ will stretch the range so that values close to zero get even closer to it. This prevents residual background mass from accumulating into a significant amount, which would give the impression that mass is created instead of transported. 111 | - `[--imComplement]` : This is useful to consider either black or white as the presence of mass. Without this option, white is considered as presence of mass. 112 | - `[-m ]`: Set to 0 to save only the final results. 113 | - `[-x ]` : Always represents the total number of iterations, even in the warm restart mode. 114 | 115 | 116 | ### Extensions 117 | 118 | ##### Log-domain stabilization 119 | 120 | The log-domain stabilization makes it possible to use arbitrarily small values of the regularization parameter $\gamma$. 121 | In order to use it, you need to uncomment the line `#define COMPUTE_BARYCENTER_LOG` at the top of the file `main_dictionary_learning.cpp`. 122 | 123 | ##### Warm restart 124 | 125 | The idea is that instead of a single L-BFGS run of 500 iterations, you restart a fresh L-BFGS every 10 iterations, and initialize the scaling vectors as the ones obtained at the end of the previous run. 
126 | As explained in the paper, this technique accumulates the Sinkhorn iterations as we accumulate L-BFGS runs, so it allows computing fewer Sinkhorn iterations for equivalent or better results, which leads to significant speed-ups. 127 | Be aware that accumulating too many Sinkhorn iterations can lead to numerical instabilities. If that happens, you can use the log-domain stabilization, which is slower, but the slowdown is compensated by the speed-up of the warm restart. 128 | For more details, please refer to our paper. 129 | The value of 10 optimization iterations per L-BFGS run is arbitrary and can be changed in the code (in `regress_both()` of `inverseWasserstein.h`), but it has shown good results for our experiments. 130 | 131 | 132 | 133 | ### License 134 | 135 | This program is free software: you can redistribute it and/or modify 136 | it under the terms of the GNU Lesser General Public License as published by 137 | the Free Software Foundation, either version 3 of the License, or 138 | (at your option) any later version. 139 | 140 | This program is distributed in the hope that it will be useful, 141 | but WITHOUT ANY WARRANTY; without even the implied warranty of 142 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 143 | GNU Lesser General Public License for more details. 144 | 145 | You should have received a copy of the GNU Lesser General Public License 146 | along with this program. If not, see <http://www.gnu.org/licenses/>. 147 | 148 | 149 | ### Contact 150 | 151 | matthieu.heitz@univ-lyon1.fr 152 | -------------------------------------------------------------------------------- /cpp/CMake/FindEigen.cmake: -------------------------------------------------------------------------------- 1 | # Redistribution and use in source and binary forms, with or without 2 | # modification, are permitted provided that the following conditions are met: 3 | # 4 | # * Redistributions of source code must retain the above copyright notice, 5 | # this list of conditions and the following disclaimer. 
6 | # * Redistributions in binary form must reproduce the above copyright notice, 7 | # this list of conditions and the following disclaimer in the documentation 8 | # and/or other materials provided with the distribution. 9 | # * Neither the name of Google Inc. nor the names of its contributors may be 10 | # used to endorse or promote products derived from this software without 11 | # specific prior written permission. 12 | # 13 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 14 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 15 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 16 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 17 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 18 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 19 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 20 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 21 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 22 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 23 | # POSSIBILITY OF SUCH DAMAGE. 24 | # 25 | # Author: alexs.mac@gmail.com (Alex Stewart) 26 | # 27 | # FindEigen.cmake - Find Eigen library, version >= 3. 28 | # 29 | # This module defines the following variables: 30 | # 31 | # EIGEN_FOUND: TRUE iff Eigen is found. 32 | # EIGEN_INCLUDE_DIRS: Include directories for Eigen. 
33 | # 34 | # EIGEN_VERSION: Extracted from Eigen/src/Core/util/Macros.h 35 | # EIGEN_WORLD_VERSION: Equal to 3 if EIGEN_VERSION = 3.2.0 36 | # EIGEN_MAJOR_VERSION: Equal to 2 if EIGEN_VERSION = 3.2.0 37 | # EIGEN_MINOR_VERSION: Equal to 0 if EIGEN_VERSION = 3.2.0 38 | # 39 | # The following variables control the behaviour of this module: 40 | # 41 | # EIGEN_INCLUDE_DIR_HINTS: List of additional directories in which to 42 | # search for eigen includes, e.g: /timbuktu/eigen3. 43 | # 44 | # The following variables are also defined by this module, but in line with 45 | # CMake recommended FindPackage() module style should NOT be referenced directly 46 | # by callers (use the plural variables detailed above instead). These variables 47 | # do however affect the behaviour of the module via FIND_[PATH/LIBRARY]() which 48 | # are NOT re-called (i.e. search for library is not repeated) if these variables 49 | # are set with valid values _in the CMake cache_. This means that if these 50 | # variables are set directly in the cache, either by the user in the CMake GUI, 51 | # or by the user passing -DVAR=VALUE directives to CMake when called (which 52 | # explicitly defines a cache variable), then they will be used verbatim, 53 | # bypassing the HINTS variables and other hard-coded search locations. 54 | # 55 | # EIGEN_INCLUDE_DIR: Include directory for Eigen, not including the 56 | # include directory of any dependencies. 57 | # Called if we failed to find Eigen or any of its required dependencies, 58 | # unsets all public (designed to be used externally) variables and reports 59 | # error message at priority depending upon [REQUIRED/QUIET/<NONE>] argument. 60 | macro(EIGEN_REPORT_NOT_FOUND REASON_MSG) 61 | unset(EIGEN_FOUND) 62 | unset(EIGEN_INCLUDE_DIRS) 63 | # Make results of search visible in the CMake GUI if Eigen has not 64 | # been found so that user does not have to toggle to advanced view. 
65 | mark_as_advanced(CLEAR EIGEN_INCLUDE_DIR) 66 | # Note _FIND_[REQUIRED/QUIETLY] variables defined by FindPackage() 67 | # use the camelcase library name, not uppercase. 68 | if (Eigen_FIND_QUIETLY) 69 | message(STATUS "Failed to find Eigen - " ${REASON_MSG} ${ARGN}) 70 | elseif (Eigen_FIND_REQUIRED) 71 | message(FATAL_ERROR "Failed to find Eigen - " ${REASON_MSG} ${ARGN}) 72 | else() 73 | # Neither QUIETLY nor REQUIRED, use no priority which emits a message 74 | # but continues configuration and allows generation. 75 | message("-- Failed to find Eigen - " ${REASON_MSG} ${ARGN}) 76 | endif () 77 | return() 78 | endmacro(EIGEN_REPORT_NOT_FOUND) 79 | # Search user-installed locations first, so that we prefer user installs 80 | # to system installs where both exist. 81 | # 82 | # TODO: Add standard Windows search locations for Eigen. 83 | list(APPEND EIGEN_CHECK_INCLUDE_DIRS 84 | /usr/local/include 85 | /usr/local/homebrew/include # Mac OS X 86 | /opt/local/var/macports/software # Mac OS X. 87 | /opt/local/include 88 | /usr/include) 89 | # Additional suffixes to try appending to each search path. 90 | list(APPEND EIGEN_CHECK_PATH_SUFFIXES 91 | eigen3 # Default root directory for Eigen. 92 | Eigen/include/eigen3 ) # Windows (for C:/Program Files prefix). 93 | # Search supplied hint directories first if supplied. 94 | find_path(EIGEN_INCLUDE_DIR 95 | NAMES Eigen/Core 96 | PATHS ${EIGEN_INCLUDE_DIR_HINTS} 97 | ${EIGEN_CHECK_INCLUDE_DIRS} 98 | PATH_SUFFIXES ${EIGEN_CHECK_PATH_SUFFIXES}) 99 | if (NOT EIGEN_INCLUDE_DIR OR 100 | NOT EXISTS ${EIGEN_INCLUDE_DIR}) 101 | eigen_report_not_found( 102 | "Could not find eigen3 include directory, set EIGEN_INCLUDE_DIR to " 103 | "path to eigen3 include directory, e.g. /usr/local/include/eigen3.") 104 | endif (NOT EIGEN_INCLUDE_DIR OR 105 | NOT EXISTS ${EIGEN_INCLUDE_DIR}) 106 | # Mark internally as found, then verify. EIGEN_REPORT_NOT_FOUND() unsets 107 | # if called. 
108 | set(EIGEN_FOUND TRUE) 109 | # Extract Eigen version from Eigen/src/Core/util/Macros.h 110 | if (EIGEN_INCLUDE_DIR) 111 | set(EIGEN_VERSION_FILE ${EIGEN_INCLUDE_DIR}/Eigen/src/Core/util/Macros.h) 112 | if (NOT EXISTS ${EIGEN_VERSION_FILE}) 113 | eigen_report_not_found( 114 | "Could not find file: ${EIGEN_VERSION_FILE} " 115 | "containing version information in Eigen install located at: " 116 | "${EIGEN_INCLUDE_DIR}.") 117 | else (NOT EXISTS ${EIGEN_VERSION_FILE}) 118 | file(READ ${EIGEN_VERSION_FILE} EIGEN_VERSION_FILE_CONTENTS) 119 | string(REGEX MATCH "#define EIGEN_WORLD_VERSION [0-9]+" 120 | EIGEN_WORLD_VERSION "${EIGEN_VERSION_FILE_CONTENTS}") 121 | string(REGEX REPLACE "#define EIGEN_WORLD_VERSION ([0-9]+)" "\\1" 122 | EIGEN_WORLD_VERSION "${EIGEN_WORLD_VERSION}") 123 | string(REGEX MATCH "#define EIGEN_MAJOR_VERSION [0-9]+" 124 | EIGEN_MAJOR_VERSION "${EIGEN_VERSION_FILE_CONTENTS}") 125 | string(REGEX REPLACE "#define EIGEN_MAJOR_VERSION ([0-9]+)" "\\1" 126 | EIGEN_MAJOR_VERSION "${EIGEN_MAJOR_VERSION}") 127 | string(REGEX MATCH "#define EIGEN_MINOR_VERSION [0-9]+" 128 | EIGEN_MINOR_VERSION "${EIGEN_VERSION_FILE_CONTENTS}") 129 | string(REGEX REPLACE "#define EIGEN_MINOR_VERSION ([0-9]+)" "\\1" 130 | EIGEN_MINOR_VERSION "${EIGEN_MINOR_VERSION}") 131 | # This is on a single line s/t CMake does not interpret it as a list of 132 | # elements and insert ';' separators which would result in 3.;2.;0 nonsense. 133 | set(EIGEN_VERSION "${EIGEN_WORLD_VERSION}.${EIGEN_MAJOR_VERSION}.${EIGEN_MINOR_VERSION}") 134 | endif (NOT EXISTS ${EIGEN_VERSION_FILE}) 135 | endif (EIGEN_INCLUDE_DIR) 136 | # Set standard CMake FindPackage variables if found. 137 | if (EIGEN_FOUND) 138 | set(EIGEN_INCLUDE_DIRS ${EIGEN_INCLUDE_DIR}) 139 | endif (EIGEN_FOUND) 140 | # Handle REQUIRED / QUIET optional arguments and version. 
141 | include(FindPackageHandleStandardArgs) 142 | find_package_handle_standard_args(Eigen 143 | REQUIRED_VARS EIGEN_INCLUDE_DIRS 144 | VERSION_VAR EIGEN_VERSION) 145 | # Only mark internal variables as advanced if we found Eigen, otherwise 146 | # leave it visible in the standard GUI for the user to set manually. 147 | if (EIGEN_FOUND) 148 | mark_as_advanced(FORCE EIGEN_INCLUDE_DIR) 149 | endif (EIGEN_FOUND) 150 | -------------------------------------------------------------------------------- /cpp/CMake/FindHalide.cmake: -------------------------------------------------------------------------------- 1 | # FindHalide.cmake 2 | # ... shamelessly based on FindJeMalloc.cmake 3 | 4 | find_path(HALIDE_ROOT_DIR 5 | NAMES include/Halide.h include/HalideRuntime.h 6 | ) 7 | 8 | find_library(HALIDE_LIBRARIES 9 | NAMES Halide 10 | HINTS ${HALIDE_ROOT_DIR}/lib 11 | ) 12 | 13 | find_path(HALIDE_INCLUDES 14 | NAMES Halide.h HalideRuntime.h 15 | HINTS ${HALIDE_ROOT_DIR}/include 16 | ) 17 | 18 | include(FindPackageHandleStandardArgs) 19 | find_package_handle_standard_args(Halide DEFAULT_MSG 20 | HALIDE_LIBRARIES 21 | HALIDE_INCLUDES 22 | ) 23 | 24 | set(HALIDE_LIBRARY_DIR ${HALIDE_LIBRARIES}) 25 | set(HALIDE_INCLUDE_DIR ${HALIDE_INCLUDES}) 26 | 27 | mark_as_advanced( 28 | HALIDE_ROOT_DIR 29 | HALIDE_LIBRARIES 30 | HALIDE_LIBRARY_DIR 31 | HALIDE_INCLUDES 32 | HALIDE_INCLUDE_DIR 33 | ) 34 | -------------------------------------------------------------------------------- /cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Author : Matthieu Heitz 2 | # Date : 03/11/2016 3 | 4 | cmake_minimum_required(VERSION 2.8) 5 | project(inverseWasserstein) 6 | 7 | 8 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") 9 | set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/CMake") 10 | 11 | ####################### 12 | ### Project options ### 13 | ####################### 14 | 15 | option(WITH_HALIDE "Use Halide to compute convolutions" 
OFF) 16 | option(WITH_EIGEN "Use Eigen to compute convolutions" OFF) 17 | option(WITH_AVX_SUPPORT "Use AVX to vectorize convolutions" ON) 18 | option(WITH_OPENMP "Use OpenMP" OFF) 19 | 20 | option(BUILD_APP_DICTIONARY_LEARNING "Build the Dictionary Learning application" ON) 21 | 22 | 23 | ############################ 24 | ### General dependencies ### 25 | ############################ 26 | 27 | IF(WITH_OPENMP) 28 | FIND_PACKAGE(OpenMP REQUIRED) 29 | IF(OPENMP_FOUND) 30 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 31 | SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 32 | ADD_DEFINITIONS("-DHAS_OPENMP ") 33 | message(STATUS "OpenMP found.") 34 | ELSE(OPENMP_FOUND) 35 | message(FATAL_ERROR "OpenMP support not available.") 36 | ENDIF(OPENMP_FOUND) 37 | ENDIF(WITH_OPENMP) 38 | 39 | 40 | set(INCLUDE_DIRS 41 | lbfgs/ 42 | ) 43 | 44 | set(ADDITIONAL_SOURCE_FILES 45 | ${ADDITIONAL_SOURCE_FILES} 46 | lbfgs/lbfgs.c) 47 | 48 | if(WITH_HALIDE) 49 | message(STATUS "WITH_HALIDE") 50 | find_package(Halide REQUIRED) 51 | add_definitions(-DHAS_HALIDE) 52 | include_directories(${HALIDE_INCLUDE_DIR}) 53 | set(ADDITIONAL_LINKER_FLAG ${ADDITIONAL_LINKER_FLAG} ${HALIDE_LIBRARY_DIR} -ldl -lncurses -lz) 54 | endif(WITH_HALIDE) 55 | 56 | if(WITH_EIGEN) 57 | message(STATUS "WITH_EIGEN") 58 | set(EIGEN_INCLUDE_DIR_HINTS "NO_HINT" CACHE STRING "Hints to Eigen include directory") 59 | find_package(Eigen REQUIRED) 60 | add_definitions(-DHAS_EIGEN) 61 | include_directories(${EIGEN_INCLUDE_DIRS}) 62 | endif(WITH_EIGEN) 63 | 64 | if(WITH_AVX_SUPPORT) 65 | add_definitions(-DAVX_SUPPORT) 66 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx -mtune=native") 67 | endif(WITH_AVX_SUPPORT) 68 | 69 | 70 | ######################################### 71 | ### Application-specific dependencies ### 72 | ######################################### 73 | 74 | # Applications that need CImg 75 | if(BUILD_APP_DICTIONARY_LEARNING) 76 | 77 | set(INCLUDE_DIRS 78 | ${INCLUDE_DIRS} 79 | cimg/ 80 | ) 81 
| set(ADDITIONAL_SOURCE_FILES 82 | ${ADDITIONAL_SOURCE_FILES} 83 | 84 | ) 85 | set(ADDITIONAL_LINKER_FLAG ${ADDITIONAL_LINKER_FLAG} -lpthread) 86 | # Set display library type : 0: no display, 1: X11 based, 2: Windows-GDI 87 | add_definitions(-Dcimg_display=0) 88 | endif(BUILD_APP_DICTIONARY_LEARNING) 89 | 90 | 91 | include_directories(${INCLUDE_DIRS}) 92 | 93 | ##################### 94 | ### Build targets ### 95 | ##################### 96 | 97 | if(BUILD_APP_DICTIONARY_LEARNING) 98 | set(SOURCE_FILES 99 | main_dictionary_learning.cpp 100 | ) 101 | add_executable(app_dictionary_learning ${SOURCE_FILES} ${ADDITIONAL_SOURCE_FILES}) 102 | target_link_libraries(app_dictionary_learning ${ADDITIONAL_LINKER_FLAG}) 103 | endif(BUILD_APP_DICTIONARY_LEARNING) 104 | -------------------------------------------------------------------------------- /cpp/chrono.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef _MSC_VER 4 | #include <windows.h> 5 | 6 | class PerfChrono 7 | { 8 | __int64 freq, t0; 9 | 10 | public: 11 | PerfChrono() { 12 | QueryPerformanceFrequency((LARGE_INTEGER*)&freq); 13 | } 14 | 15 | void Start() { 16 | QueryPerformanceCounter((LARGE_INTEGER*)&t0); 17 | } 18 | 19 | DWORD GetDiffMs(){ 20 | __int64 t1; 21 | QueryPerformanceCounter((LARGE_INTEGER*)&t1); 22 | return (DWORD)(((t1 - t0) * 1000) / freq); 23 | } 24 | 25 | DWORD GetDiffUs() { //micro sec 26 | __int64 t1; 27 | QueryPerformanceCounter((LARGE_INTEGER*)&t1); 28 | return (DWORD)(((t1 - t0) * 1000000) / freq); 29 | } 30 | 31 | DWORD GetDiffNs(){ 32 | __int64 t1; 33 | QueryPerformanceCounter((LARGE_INTEGER*)&t1); 34 | return (DWORD)(((t1 - t0) * 1000000000) / freq); 35 | } 36 | 37 | DWORD GetDiff(UINT unit){ 38 | __int64 t1; 39 | QueryPerformanceCounter((LARGE_INTEGER*)&t1); 40 | return (DWORD)(((t1 - t0) * unit) / freq); 41 | } 42 | 43 | DWORD GetFreq(){ 44 | return (DWORD)freq; 45 | } 46 | }; 47 | 48 | #else 49 | 50 | #include <sys/time.h> 51 | 52 | class PerfChrono 53 
| { 54 | 55 | public: 56 | double GetDiffMs(){return GetDiffAbsoluteMs();} 57 | 58 | double GetDiffAbsoluteMs(){ 59 | gettimeofday(&t2, NULL); 60 | double elapsedTime = (t2.tv_sec - t1.tv_sec) * 1000.0; // sec to ms 61 | elapsedTime += (t2.tv_usec - t1.tv_usec) / 1000.0; // us to ms 62 | return elapsedTime; 63 | } 64 | double GetDiffRelativeMs(){ 65 | gettimeofday(&t2, NULL); 66 | double elapsedTime = (t2.tv_sec - t3.tv_sec) * 1000.0; // sec to ms 67 | elapsedTime += (t2.tv_usec - t3.tv_usec) / 1000.0; // us to ms 68 | t3 = t2; 69 | return elapsedTime; 70 | } 71 | 72 | void Start() { 73 | gettimeofday(&t1, NULL); 74 | t3 = t1; 75 | } 76 | 77 | struct timeval t1, t2, t3; 78 | }; 79 | 80 | #endif 81 | -------------------------------------------------------------------------------- /cpp/histogramIO.h: -------------------------------------------------------------------------------- 1 | // File : histogramIO.h 2 | // 3 | // Description : 4 | // Functions and structures to manage the input/output of n-dimensional histograms 5 | // Histograms are also called pdfs for Probability Density Functions 6 | // 7 | 8 | // Functions specific to the dimension of the histogram. 
9 | 10 | #pragma once 11 | #include <vector> 12 | #include <string> 13 | #include <iostream> 14 | #include <fstream> 15 | #include <algorithm> 16 | #include <cassert> 17 | 18 | #include "cimg/CImg.h" 19 | 20 | #if defined(_WIN32) // Windows 21 | 22 | #elif defined(__linux__) || defined(__APPLE__) // Linux, OS X 23 | #include <glob.h> 24 | #endif 25 | 26 | 27 | double load_img_to_pdf(const char* filename, std::vector<double> &result, int &W, int &H); 28 | double load_2D_img_to_pdf(const char* filename, std::vector<double> &result, int &W, int &H); 29 | double load_rgb_img_to_pdf_grayscale_luma(const char* filename, std::vector<double> &result, int &W, int &H); 30 | double load_img_to_pdf_gamma_correction(const char* filename, std::vector<double> &result, int &W, int &H, double alpha, bool complement); 31 | double load_csv_to_pdf(const char* filename, std::vector<double> &result); 32 | void save_pdf(const char* filename, int W, int H, const std::vector<double> &val, double scaling); 33 | void save_pdf_gamma_correction(const char* filename, int W, int H, const std::vector<double> &val, double alpha, bool im_complement, double scaling); 34 | 35 | // Helper functions 36 | double lerp(double v0, double v1, double t); 37 | std::vector<double> quantile(const std::vector<double>& inData, const std::vector<double>& probs); 38 | std::vector<std::string> getFileListByPattern(const std::string& pat); 39 | std::vector<std::string> get_all_files_names_within_folder(std::wstring folder); 40 | 41 | // Abstract export class 42 | class ExportHistogramBase{ 43 | public: 44 | ExportHistogramBase(std::string outputFolderPath) : mOutputFolderPath(outputFolderPath) {} 45 | 46 | // Must return 0 if export is successful, other values otherwise 47 | virtual int exportHistogram(const std::vector<double> &results, std::string filenameNoExtension) const = 0 ; 48 | 49 | // Must return 0 if export is successful, other values otherwise 50 | virtual int exportHistogramsBatch(const std::vector<std::vector<double>> &results, std::vector<std::string> filenamesNoExtension) const = 0 ; 51 | 52 | //protected: 53 | std::string mOutputFolderPath; 54 | }; 55 | 56 | 57 | // Export a nD histogram to a CSV file 58 | class
ExportHistogramToCSV : public ExportHistogramBase { 59 | 60 | public: 61 | ExportHistogramToCSV(std::string outputFolderPath = "outputFiles") : ExportHistogramBase(outputFolderPath) 62 | {} 63 | 64 | int exportHistogram(const std::vector &results, std::string filenameNoExtension) const 65 | { 66 | std::ofstream ofs; 67 | std::string filename = mOutputFolderPath + "/" + filenameNoExtension + ".csv"; 68 | ofs.open(filename); 69 | if(!ofs.is_open()) 70 | { 71 | std::cerr<<"Unable to write file "<> &results, std::vector filenamesNoExtension) const 83 | { 84 | if(results.size() != filenamesNoExtension.size()) 85 | { 86 | std::cerr<<"Histogram vector and filename vector are of different sizes"<exportHistogram(results[i],filenamesNoExtension[i]); 94 | if(ret != 0) success = false; 95 | } 96 | return (success ? 0 : -1); 97 | } 98 | }; 99 | 100 | // Export a 2D histogram to a graylevel image 101 | class ExportHistogramToPNG : public ExportHistogramBase { 102 | 103 | public: 104 | int mWidth, mHeight; 105 | 106 | ExportHistogramToPNG(int width, int height, std::string outputFolderPath = "outputFiles") : 107 | mWidth(width), mHeight(height), ExportHistogramBase(outputFolderPath) {} 108 | 109 | int exportHistogram(const std::vector &results, std::string filenameNoExtension) const 110 | { 111 | std::string filename = mOutputFolderPath + "/" + filenameNoExtension + ".png"; 112 | assert(results.size() == mWidth*mHeight); 113 | 114 | std::vector probs = {0.95}; 115 | std::vector quantiles = quantile(results,probs); 116 | save_pdf(filename.c_str(), mWidth, mHeight, results, 255.0*probs[0]/quantiles[0]); 117 | return 0; 118 | } 119 | 120 | int exportHistogramsBatch(const std::vector> &results, std::vector filenamesNoExtension) const 121 | { 122 | if(results.size() != filenamesNoExtension.size()) 123 | { 124 | std::cerr<<"Histogram vector and filename vector are of different sizes"<exportHistogram(results[i],filenamesNoExtension[i]); 132 | if(ret != 0) success = false; 133 | } 134 | 
return (success ? 0 : -1); 135 | } 136 | }; 137 | 138 | // Export a 2D histogram to a graylevel image 139 | // Holds a vector of scaling values to apply to histogram when exporting them. 140 | class ExportHistogramToPNGWithScaling : public ExportHistogramBase { 141 | 142 | public: 143 | int mWidth, mHeight; 144 | std::vector mScalings; 145 | float mAlpha; // Value for the gamma-correction 146 | bool mComplement; // Take the image complement 147 | 148 | ExportHistogramToPNGWithScaling(int width, int height, std::vector& scalings, std::string outputFolderPath = "outputFiles", float alpha = 1, bool imComplement = false) : 149 | mWidth(width), mHeight(height), mScalings(scalings), mAlpha(alpha), mComplement(imComplement), ExportHistogramBase(outputFolderPath) {} 150 | 151 | int exportHistogram(const std::vector &results, std::string filenameNoExtension) const 152 | { 153 | std::string filename = mOutputFolderPath + "/" + filenameNoExtension + ".png"; 154 | assert(results.size() == mWidth*mHeight); 155 | 156 | std::vector probs = {0.95}; 157 | std::vector quantiles = quantile(results,probs); 158 | save_pdf_gamma_correction(filename.c_str(), mWidth, mHeight, results, mAlpha, mComplement, probs[0]/quantiles[0]); 159 | return 0; 160 | } 161 | 162 | int exportHistogramWithScaling(const std::vector &results, std::string filenameNoExtension, int i) const 163 | { 164 | std::string filename = mOutputFolderPath + "/" + filenameNoExtension + ".png"; 165 | assert(results.size() == mWidth*mHeight); 166 | if(i >= mScalings.size()) 167 | { 168 | std::cerr<<"Index "<> &results, std::vector filenamesNoExtension) const 181 | { 182 | if(results.size() != filenamesNoExtension.size()) 183 | { 184 | std::cerr<<"Histogram vector and filename vector are of different sizes"<exportHistogramWithScaling(results[i],filenamesNoExtension[i],i); 193 | if(ret != 0) success = false; 194 | } 195 | return (success ? 
0 : -1); 196 | } 197 | }; 198 | 199 | 200 | // Export a nD histogram to a raw binary file 201 | class ExportHistogramToRAW : public ExportHistogramBase { 202 | 203 | public: 204 | ExportHistogramToRAW(std::string outputFolderPath = "outputFiles") : ExportHistogramBase(outputFolderPath) 205 | {} 206 | 207 | int exportHistogram(const std::vector &results, std::string filenameNoExtension) const 208 | { 209 | std::ofstream ofs; 210 | std::string filename = mOutputFolderPath + "/" + filenameNoExtension + ".raw"; 211 | ofs.open(filename, std::ios::out | std::ios::binary); 212 | if(!ofs.is_open()) 213 | { 214 | std::cerr<<"Unable to write file "<(&results[0]), results.size()*sizeof(double)); 218 | ofs.close(); 219 | return 0; 220 | } 221 | 222 | int exportHistogramsBatch(const std::vector> &results, std::vector filenamesNoExtension) const 223 | { 224 | if(results.size() != filenamesNoExtension.size()) 225 | { 226 | std::cerr<<"Histogram vector and filename vector are of different sizes"<exportHistogram(results[i],filenamesNoExtension[i]); 234 | if(ret != 0) success = false; 235 | } 236 | return (success ? 
0 : -1); 237 | } 238 | }; 239 | 240 | 241 | 242 | 243 | double load_img_to_pdf(const char* filename, std::vector &result, int &W, int &H) { 244 | 245 | cimg_library::CImg cimg(filename); 246 | std::cout<<"loading "< &result, int &W, int &H) { 262 | 263 | std::string f(filename); 264 | double r = load_img_to_pdf(filename,result,W,H); 265 | if(W == 0 || H == 0) return -1; 266 | return r; 267 | } 268 | 269 | double load_rgb_img_to_pdf_grayscale_luma(const char* filename, std::vector &result, int &W, int &H) { 270 | 271 | cimg_library::CImg cimg(filename); 272 | printf("loading %s\n",filename); 273 | W = cimg.width(); 274 | H = cimg.height(); 275 | result.resize(W*H); 276 | double sum = 0; 277 | for (int i=0; i &result, int &W, int &H, double alpha, bool complement) { 288 | 289 | cimg_library::CImg cimg(filename); 290 | std::cout<<"loading "< &result) { 315 | 316 | std::ifstream ifs; 317 | ifs.open(filename); 318 | if(!ifs.is_open()) 319 | { 320 | std::cerr<<"Unable to read file "<>v; 327 | sum +=v; 328 | result.push_back(v); 329 | } 330 | ifs.close(); 331 | for(int i=0; i &val, double scaling) { 340 | 341 | std::vector deinterleaved(W*H * 3); 342 | for (int i = 0; i < W*H; i++) { 343 | deinterleaved[i] = std::min(255., std::max(0., val[i] * scaling)); 344 | deinterleaved[i + W*H] = deinterleaved[i]; 345 | deinterleaved[i + 2 * W*H] = deinterleaved[i]; 346 | } 347 | 348 | cimg_library::CImg cimg(&deinterleaved[0], W, H, 1, 3); 349 | cimg.save(filename); 350 | 351 | } 352 | 353 | void save_pdf_gamma_correction(const char* filename, int W, int H, const std::vector &val, double alpha, bool im_complement, double scaling) { 354 | 355 | std::vector deinterleaved(W*H * 3); 356 | if(im_complement) { 357 | for (int i = 0; i < W*H; i++) { 358 | deinterleaved[i] = std::min(255., std::max(0., 255. 
- std::pow(val[i]*scaling, 1.0/alpha) * 255.)); 359 | deinterleaved[i + W*H] = deinterleaved[i]; 360 | deinterleaved[i + 2 * W*H] = deinterleaved[i]; 361 | } 362 | } else { 363 | for (int i = 0; i < W*H; i++) { 364 | deinterleaved[i] = std::min(255., std::max(0., std::pow(val[i]*scaling, 1.0/alpha) * 255.)); 365 | deinterleaved[i + W*H] = deinterleaved[i]; 366 | deinterleaved[i + 2 * W*H] = deinterleaved[i]; 367 | } 368 | } 369 | cimg_library::CImg cimg(&deinterleaved[0], W, H, 1, 3); 370 | cimg.save(filename); 371 | } 372 | 373 | double lerp(double v0, double v1, double t) 374 | { 375 | return (1 - t)*v0 + t*v1; 376 | } 377 | 378 | // Taken from Yury : http://stackoverflow.com/a/37708864/4195725 379 | std::vector quantile(const std::vector& inData, const std::vector& probs) 380 | { 381 | if (inData.size() <= 2 || probs.empty()) 382 | { 383 | throw std::runtime_error("Invalid input"); 384 | } 385 | 386 | std::vector data = inData; 387 | std::sort(data.begin(), data.end()); 388 | std::vector quantiles; 389 | 390 | for (size_t i = 0; i < probs.size(); ++i) 391 | { 392 | double center = lerp(-0.5, data.size() - 0.5, probs[i]); 393 | 394 | size_t left = std::max(int64_t(std::floor(center)), int64_t(0)); 395 | size_t right = std::min(int64_t(std::ceil(center)), int64_t(data.size() - 1)); 396 | 397 | double dataLeft = data.at(left); 398 | double dataRight = data.at(right); 399 | 400 | double quantile = lerp(dataLeft, dataRight, center - left); 401 | 402 | quantiles.push_back(quantile); 403 | } 404 | 405 | return quantiles; 406 | } 407 | 408 | 409 | #if defined(_WIN32) 410 | 411 | static std::vector findFiles(const std::string& pat){ 412 | std::string filename; 413 | std::vector listOfFileNames; 414 | 415 | WIN32_FIND_DATAA findFileData; 416 | HANDLE myHandle = FindFirstFileA(pat.c_str(),&findFileData); 417 | 418 | if( myHandle != INVALID_HANDLE_VALUE) 419 | { 420 | do 421 | { 422 | filename = findFileData.cFileName; 423 | listOfFileNames.push_back( filename); 424 | } 
while (FindNextFileA(myHandle, &findFileData) != 0); 425 | } 426 | FindClose(myHandle); 427 | std::sort(listOfFileNames.begin(), listOfFileNames.end()); 428 | return listOfFileNames; 429 | } 430 | 431 | #elif defined(__linux__) || defined(__APPLE__) 432 | 433 | static std::vector glob(const std::string& pat){ 434 | glob_t glob_result; 435 | glob(pat.c_str(),GLOB_TILDE | GLOB_BRACE,NULL,&glob_result); 436 | std::vector ret; 437 | for(unsigned int i=0;i getFileListByPattern(const std::string& pat) 447 | { 448 | #if defined(_WIN32) 449 | return findFiles(pat); 450 | #elif defined(__linux__) || defined(__APPLE__) 451 | return glob(pat); 452 | #endif 453 | } 454 | 455 | 456 | std::vector get_all_files_names_within_folder(std::wstring folder) { 457 | #if defined(_WIN32) 458 | std::vector names; 459 | TCHAR search_path[200]; 460 | std::string strpath(folder.begin(), folder.end()); 461 | sprintf(search_path, "%s/*.*", strpath.c_str()); 462 | WIN32_FIND_DATA fd; 463 | HANDLE hFind = ::FindFirstFile(search_path, &fd); 464 | if (hFind != INVALID_HANDLE_VALUE) { 465 | do { 466 | // read all (real) files in current folder 467 | // , delete '!' read other 2 default folder . and .. 
468 | if (!(fd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) { 469 | std::string str(fd.cFileName); 470 | names.push_back(str); 471 | } 472 | } while (::FindNextFile(hFind, &fd)); 473 | ::FindClose(hFind); 474 | } 475 | return names; 476 | #elif defined(__unix__) || defined(__APPLE__) // Linux and OSX 477 | std::vector names; 478 | std::string folderPath(folder.begin(), folder.end()); 479 | DIR *dir; 480 | struct dirent *ent; 481 | std::cout<<"Reading in folder "<d_name,".") != 0 && std::strcmp(ent->d_name,"..") != 0) 486 | { 487 | names.push_back(ent->d_name); 488 | } 489 | } 490 | closedir (dir); 491 | } else { 492 | /* could not open directory */ 493 | perror (""); 494 | return names; 495 | } 496 | std::sort(names.begin(), names.end()); 497 | return names; 498 | #endif 499 | } 500 | -------------------------------------------------------------------------------- /cpp/inverseWasserstein.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "lbfgs.h" 4 | #include "kernels.h" 5 | #include "loss.h" 6 | #include "histogramIO.h" 7 | #include "signArray.h" 8 | #include "chrono.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | typedef std::complex ComplexDataType; 15 | 16 | enum gradient_type {GRADIENT_IMPLICIT, GRADIENT_SINKHORN, GRADIENT_NUMERIC}; 17 | 18 | template class Problem; 19 | template class WassersteinRegression; 20 | template class WassersteinBarycenter; 21 | 22 | template T dot(const T* u, const T* v, int N) { 23 | double res = 0.; 24 | for (int i=0; i 30 | class Problem { 31 | public: 32 | Problem(std::vector > &histograms, const std::vector &observed_histogram, const KernelType &pkernel, double* weights = NULL, int npdfs = 1) { 33 | pdfs = histograms; 34 | K = pdfs.size(); 35 | N = pkernel.N; 36 | observed_pdf = observed_histogram; 37 | kernel = pkernel; 38 | lambdas = new double[K]; 39 | num_pdfs = npdfs; 40 | if (weights) { set_weights(weights); }; 41 | 42 | } 43 
| Problem(const Problem& p) { // Kernel is passed by pointer ; lambdas are copied 44 | pdfs = p.pdfs; 45 | K = p.K; 46 | N = p.N; 47 | observed_pdf = p.observed_pdf; 48 | kernel = p.kernel; 49 | num_pdfs = p.num_pdfs; 50 | lambdas = new double[K]; 51 | set_weights(p.lambdas); 52 | } 53 | void normalize_values() { 54 | 55 | int P = num_pdfs; 56 | 57 | double s = 0; 58 | for (int i=0; i > pdfs; 90 | std::vector observed_pdf; 91 | KernelType kernel; 92 | double* lambdas; 93 | size_t N, K; 94 | int num_pdfs; 95 | }; 96 | 97 | 98 | 99 | template 100 | class WassersteinBarycenter { 101 | template friend class WassersteinRegression; 102 | 103 | public: 104 | WassersteinBarycenter(Problem* p, int n_bregman_iter) { 105 | K = p->pdfs.size(); 106 | N = p->kernel.N; 107 | problem = p; 108 | Niters = n_bregman_iter; 109 | } 110 | 111 | double get_plan(int id, int i, int j) { 112 | return a[id*N+i]*b[id*N+j]*problem->kernel(i, j); 113 | } 114 | 115 | void compute_barycenter(int id_basis = 0) { 116 | b.resize(K*N); 117 | std::fill(b.begin(), b.begin()+K*N, 1.0); 118 | compute_barycenter_no_scaling_init(id_basis); 119 | } 120 | 121 | // Bregman projections 122 | void compute_barycenter_no_scaling_init(int id_basis = 0) { 123 | 124 | barycenter.resize(N); 125 | a.resize(K*N); 126 | 127 | std::vector > convolution(K*N); 128 | 129 | problem->normalize_values(); 130 | // Bregman Projections 131 | for (int iter=0; iterkernel.convolve(&b[0], &convolution[0], K); 135 | for (int i=0; ipdfs[i][id_basis*N + j] / convolution[i*N+j]; 138 | } 139 | } 140 | problem->kernel.convolveAdjoint(&a[0], &convolution[0], K); 141 | 142 | geomMean(&convolution[0], &barycenter[0]); 143 | 144 | //#pragma omp parallel for 145 | for (int i=0; i > convolution(K*N); 164 | 165 | problem->normalize_values(); 166 | // Bregman Projections 167 | for (int iter=0; iterkernel.convolve(&b[0], &convolution[0], K); 171 | problem->kernel.log_convolve(&b[0], &convolution[0], K); 172 | for (int i=0; ipdfs[i][id_basis*N + j] 
/ convolution[i*N+j]; 175 | a[i*N+j] = log(problem->pdfs[i][id_basis*N + j]) - convolution[i*N+j]; 176 | } 177 | } 178 | // problem->kernel.convolveAdjoint(&a[0], &convolution[0], K); 179 | problem->kernel.log_convolveAdjoint(&a[0], &convolution[0], K); 180 | 181 | memset(&barycenter[0], 0, N*sizeof(barycenter[0])); 182 | for (int i=0; ilambdas[i]; 184 | for (int j=0; jpdfs[i][0])); 208 | } 209 | 210 | double curEntropy = entropy(&barycenter[0]); 211 | double sumBary = 0; for (int i=0; i maxEntropy + 1) { 213 | 214 | 215 | std::vector baryBeta(N); 216 | std::vector entrop(40); 217 | #pragma omp parallel for firstprivate(baryBeta) 218 | for (int i=0; i &a, const std::vector &b) const { 277 | double sum = 0; 278 | for (int j=0; jkernelMatrix[j*N+k]*a[k]*b[j]*(safe_log(a[k])+safe_log(b[j])-1.); 281 | } 282 | } 283 | return sum; 284 | } 285 | 286 | // Wasserstein cost of the transport plan given by the current barycenter 287 | double cost_W_bary() const { 288 | double sum = 0; 289 | for (int i=0; ilambdas[i]*cost_W(a[i], b[i]); 291 | } 292 | return sum; 293 | } 294 | 295 | // geometric mean used for Bregman projections 296 | void geomMean(double* convolved_a, double* result) const { 297 | 298 | memset(&result[0], 0, N*sizeof(result[0])); 299 | double sumLambdas = 0; 300 | for (int i=0; ilambdas[i]; 302 | } 303 | 304 | for (int i=0; ilambdas[i]; 306 | for (int j=0; j > a, b; 315 | std::vector barycenter; 316 | 317 | Problem *problem; 318 | size_t K, N; 319 | int Niters; 320 | }; 321 | 322 | 323 | template 324 | class WassersteinRegression { 325 | public: 326 | WassersteinRegression(Problem* p, 327 | int n_bregman_iterations, 328 | gradient_type gradient_method, 329 | const BaseLoss &lossFunctor, 330 | const ExportHistogramBase &exportFunctor, 331 | double scale_dictionary = 0.0, // Only for dual regression (regress_both). 
332 | bool export_atoms = true, // Only for dual regression (regress_both) 333 | bool export_fittings = true, // Only for dual regression (regress_both) 334 | bool export_only_final_solution = true, 335 | bool warm_restart = false) : 336 | exporter(exportFunctor), loss(lossFunctor) 337 | { 338 | K = p->pdfs.size(); 339 | N = p->kernel.N; 340 | iteration = 0; 341 | bary_computation = new WassersteinBarycenter(p, n_bregman_iterations); 342 | problem = p; 343 | firstCall = true; 344 | this->gradient_method = gradient_method; 345 | // Optimization is done in the log-domain (to keep positive values) 346 | exp_weight = true; 347 | // Exportation of pdfs 348 | scaleDictionary = scale_dictionary; 349 | exportAtoms = export_atoms; 350 | exportFittings = export_fittings; 351 | exportOnlyFinalSolution = export_only_final_solution; 352 | warmRestart = warm_restart; 353 | exportEveryMIter = 1; 354 | lbfgs_parameter_init(&lbfgs_param); 355 | } 356 | 357 | ~WassersteinRegression() { 358 | delete bary_computation; 359 | } 360 | 361 | 362 | void regress_both(double* solution) { 363 | double residual; 364 | firstCall = true; 365 | lbfgs_param.linesearch = LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE; // LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE; // LBFGS_LINESEARCH_BACKTRACKING_WOLFE; // LBFGS_LINESEARCH_BACKTRACKING_ARMIJO;// LBFGS_LINESEARCH_MORETHUENTE; 366 | lbfgs_param.epsilon = 1E-50; // Convergence test on accuracy 367 | lbfgs_param.max_linesearch = 20; // Max number of trials for the line search 368 | lbfgs_param.delta = 1E-50; // Convergence test on minimum rate of decrease 369 | if(warmRestart) 370 | lbfgs_param.max_iterations = 10; // Number of iterations when doing multiple small runs of LBFGS 371 | else 372 | lbfgs_param.max_iterations = wrTotalIteration; 373 | 374 | iteration = 0; 375 | 376 | int P = problem->num_pdfs; 377 | int n = K*P + K*N; 378 | double scaleDic = scaleDictionary; // Scale between dictionary and weight sizes 379 | 380 | if(warmRestart) 381 | { 382 | 
b_storage.resize(K*N*P); // Initialize the storage vector for warm restart 383 | std::fill(b_storage.begin(), b_storage.begin()+K*N*P, 1.0); 384 | b_temp.resize(K*N*P); 385 | b.resize(K*N*(this->bary_computation->Niters+1)); // b initialization 386 | std::fill(b.begin(), b.begin()+K*N, 1.0); 387 | } 388 | 389 | // Transfer the data to the log domain 390 | if (exp_weight) { 391 | for (int i=0; iwarmRestart) 419 | { 420 | // Store the last computed scalings 421 | std::copy(this->b_temp.begin(), this->b_temp.end(),this->b_storage.begin()); 422 | } 423 | } 424 | std::vector solution2(solution,solution+P*K+K*N); 425 | 426 | // Transfer the data back to the original domain 427 | if (exp_weight) { 428 | for (int i=0; i base_result(N); 467 | for (int i=0; i gradient(K*P+N*K); // dummy 487 | // Force exporting 488 | int M = exportEveryMIter; 489 | bool final = exportOnlyFinalSolution; 490 | exportOnlyFinalSolution = false; 491 | exportEveryMIter = 1; 492 | // Compute barycenter (and gradients but not used) 493 | evaluate_both(this,&solution2[0],&gradient[0],K*P+N*K,0); 494 | // Revert changes made to force exporting 495 | exportOnlyFinalSolution = final; 496 | exportEveryMIter = M; 497 | std::string dir = exporter.mOutputFolderPath + "/"; 498 | std::ostringstream ss; 499 | ss << "i-fitting_" << std::setfill('0') << std::setw(3) << iteration; 500 | std::string oldname(ss.str()); 501 | std::vector files = getFileListByPattern(dir + oldname + "*"); 502 | for (int i=0; iNiters), std::min(iteration*1, bary_computation->Niters)); 546 | const int n_iter_gradients = bary_computation->Niters; 547 | 548 | std::vector > a(K*N), u(K*N), convu(K*N); 549 | //std::vector > barycenter(N); 550 | 551 | if(!warmRestart) 552 | { 553 | b.resize(K*N*(n_iter_gradients+1)); 554 | std::fill(b.begin(), b.begin()+K*N, 1.0); 555 | } 556 | 557 | conv_b.resize(K*N*n_iter_gradients); 558 | phi.resize(K*N*(n_iter_gradients+1)); 559 | g.resize(N); 560 | r.resize(K*N); memset(&r[0], 0, K*N*sizeof(r[0])); 
561 | 562 | problem->normalize_values(); 563 | 564 | // Bregman Projections 565 | for (int iter=1; iter<=n_iter_gradients; iter++) { 566 | 567 | //#pragma omp parallel for firstprivate(a) 568 | problem->kernel.convolveAdjoint(&b[(iter-1)*K*N], &conv_b[(size_t)((iter-1)*K*N)], K); 569 | for (int i=0; ipdfs[i][0*N + j] / conv_b[offset+j]; 573 | } 574 | } 575 | problem->kernel.convolve(&a[0], &phi[iter*K*N], K); 576 | 577 | memset(&barycenter[0], 0, N*sizeof(barycenter[0])); 578 | for (int i=0; ilambdas[i]; 580 | for (int j=0; j > n(N), v(K*N,0.0), tmp(K*N), c(K*N), sumv(N); 597 | loss.gradient(&barycenter[0], &problem->observed_pdf[id*N], N, &g[0]); 598 | 599 | // gradient w.r.t dictionary 600 | memset(resultDic, 0, K*N*sizeof(resultDic[0])); 601 | 602 | /*if (iteration%2==0)*/ { 603 | memcpy(&n[0], &g[0], N*sizeof(double)); 604 | for (int sub_iter = n_iter_gradients; sub_iter>=1; sub_iter--) { 605 | memset(&sumv[0], 0, N*sizeof(sumv[0])); 606 | for (int i=0; ilambdas[i]*n[j]-v[i*N+j]) * b[sub_iter*K*N + i*N+j]; 609 | } 610 | } 611 | problem->kernel.convolve(&tmp[0], &c[0], K); 612 | for (int i=0; ipdfs[i][0*N + j]*c[i*N+j]/sqr(conv_b[(sub_iter-1)*K*N+i*N+j]); 616 | } 617 | } 618 | if(sub_iter==1) break; 619 | problem->kernel.convolveAdjoint(&tmp[0], &v[0], K); 620 | for (int i=0; i=1; sub_iter--) { 644 | 645 | //#pragma omp parallel for firstprivate(u, convu) 646 | for (int i=0; ilambdas[i]*g[j] - r[i*N+j])/phi[sub_iter*K*N + i*N+j]; 651 | } 652 | resultLa[i] += dotp; 653 | } 654 | if (sub_iter!=1) { 655 | problem->kernel.convolve(&u[0], &convu[0], K); 656 | for (int i=0; ipdfs[i][0*N +j] / sqr(conv_b[(sub_iter-1)*K*N + i*N + j]); 659 | } 660 | } 661 | problem->kernel.convolveAdjoint(&convu[0], &r[0], K); 662 | for (int i=0; iNiters; 681 | 682 | std::vector > a(K*N), u(K*N), convu(K*N); 683 | 684 | if(!warmRestart) 685 | { 686 | b.resize(K*N*(n_iter_gradients+1)); 687 | std::fill(b.begin(), b.begin()+K*N, 0.0); 688 | } 689 | 690 | 
conv_b.resize(K*N*n_iter_gradients); 691 | phi.resize(K*N*(n_iter_gradients+1)); 692 | g.resize(N); 693 | r.resize(K*N); memset(&r[0], 0, K*N*sizeof(r[0])); 694 | 695 | problem->normalize_values(); 696 | 697 | // Forward log, Backward log + sign array 698 | 699 | for (int iter=1; iter<=n_iter_gradients; iter++) { 700 | 701 | //#pragma omp parallel for firstprivate(a) 702 | problem->kernel.log_convolveAdjoint(&b[(iter-1)*K*N], &conv_b[(size_t)((iter-1)*K*N)], K); 703 | for (int i=0; ipdfs[i][0*N + j]) - conv_b[offset+j]; 707 | } 708 | } 709 | problem->kernel.log_convolve(&a[0], &phi[iter*K*N], K); 710 | 711 | memset(&barycenter[0], 0, N*sizeof(barycenter[0])); 712 | for (int i=0; ilambdas[i]; 714 | for (int j=0; j > n(N), v(K*N,0.0), tmp(K*N), c(K*N), sumv(N); 731 | 732 | loss.gradient(&barycenter[0], &problem->observed_pdf[id*N], N, &g[0]); 733 | 734 | // gradient w.r.t dictionary 735 | 736 | memset(resultDic, 0, K*N*sizeof(resultDic[0])); 737 | memcpy(&n[0], &g[0], N*sizeof(double)); 738 | 739 | unsigned char * signArray = new unsigned char[(K*N+7)/8]; 740 | 741 | 742 | for (int sub_iter = n_iter_gradients; sub_iter>=1; sub_iter--) { 743 | memset(&sumv[0], 0, N*sizeof(sumv[0])); 744 | for (int i=0; ilambdas[i]*n[j] + v[i*N+j], signArray, i*N+j) + b[sub_iter*K*N + i*N+j]; 747 | } 748 | } 749 | problem->kernel.log_convolve_signArray(&tmp[0], signArray, &c[0], K); 750 | for (int i=0; ipdfs[i][0*N + j]) + c[i*N+j] - 2.0*(conv_b[(sub_iter-1)*K*N+i*N+j]); 756 | } 757 | } 758 | if(sub_iter==1) break; 759 | problem->kernel.log_convolve_signArrayAdjoint(&tmp[0], signArray, &v[0], K); 760 | for (int i=0; i=1; sub_iter--) { 781 | 782 | //#pragma omp parallel for firstprivate(u, convu) 783 | for (int i=0; ilambdas[i]*g[j] + r[i*N+j], signArray, i*N+j) - phi[sub_iter*K*N + i*N+j]; 788 | } 789 | resultLa[i] += dotp; 790 | } 791 | 792 | if (sub_iter!=1) { 793 | problem->kernel.log_convolve_signArray(&u[0], signArray, &convu[0], K); 794 | for (int i=0; ipdfs[i][0*N +j]) - 
2.0*(conv_b[(sub_iter-1)*K*N + i*N + j]); 798 | } 799 | } 800 | problem->kernel.log_convolve_signArrayAdjoint(&convu[0], signArray, &r[0], K); 801 | for (int i=0; iNiters; 820 | 821 | std::vector > c_b, c_conv_b, c_phi, c_g, c_r, c_bary; 822 | // std::vector > c_b, c_conv_b, c_phi, c_g, c_r; 823 | 824 | std::vector > a(K*N), u(K*N), convu(K*N); 825 | // std::fill(b.begin(), b.begin()+K*N, 0.0); 826 | 827 | g.resize(N); 828 | c_b.resize(K*N*(bary_computation->Niters+1)); 829 | std::fill(c_b.begin(), c_b.begin()+K*N, 0.0); 830 | c_conv_b.resize(K*N*bary_computation->Niters); 831 | c_phi.resize(K*N*(bary_computation->Niters+1)); 832 | c_g.resize(N); 833 | c_r.resize(K*N); memset(&c_r[0], 0, K*N*sizeof(c_r[0])); 834 | c_bary.resize(N); 835 | float logepsilon = (float)log(EPSILON); 836 | 837 | problem->normalize_values(); 838 | 839 | // Bregman Projections 840 | for (int iter=1; iter<=n_iter_gradients; iter++) { 841 | 842 | //#pragma omp parallel for firstprivate(a) 843 | problem->kernel.log_convolveAdjoint(&c_b[(iter-1)*K*N], &c_conv_b[(size_t)((iter-1)*K*N)], K); 844 | for (int i=0; ipdfs[i][0*N + j]) - c_conv_b[offset+j]; 848 | // a[i*N + j] = problem->pdfs[i][0*N + j] / conv_b[offset+j]; 849 | } 850 | } 851 | problem->kernel.log_convolve(&a[0], &c_phi[iter*K*N], K); 852 | // problem->kernel.convolve(&a[0], &phi[iter*K*N], K); 853 | 854 | memset(&c_bary[0], 0, N*sizeof(c_bary[0])); 855 | for (int i=0; ilambdas[i]; 857 | for (int j=0; j > n(N), v(K*N,0.0), tmp(K*N), c(K*N), sumv(N); 882 | // std::vector > n(N); 883 | 884 | loss.gradient(&barycenter[0], &problem->observed_pdf[id*N], N, &g[0]); 885 | 886 | for (int j=0; j=1; sub_iter--) { 897 | memset(&sumv[0], 0, N*sizeof(sumv[0])); 898 | for (int i=0; ilambdas[i]*n[j] - v[i*N+j]) + c_b[sub_iter*K*N + i*N+j]; 901 | // tmp[i*N+j] = (problem->lambdas[i]*n[j]-v[i*N+j]) * b[sub_iter*K*N + i*N+j]; 902 | } 903 | } 904 | 905 | problem->kernel.log_convolve(&tmp[0], &c[0], K); 906 | // problem->kernel.convolve(&tmp[0], &c[0], 
K); 907 | for (int i=0; ipdfs[i][0*N + j])+c[i*N+j] - 2.0f*(c_conv_b[(sub_iter-1)*K*N+i*N+j]); 912 | // tmp[i*N+j] = -problem->pdfs[i][0*N + j]*c[i*N+j]/sqr(conv_b[(sub_iter-1)*K*N+i*N+j]); 913 | } 914 | } 915 | if(sub_iter==1) break; 916 | problem->kernel.log_convolveAdjoint(&tmp[0], &v[0], K); 917 | // problem->kernel.convolveAdjoint(&tmp[0], &v[0], K); 918 | 919 | ComplexDataType minus = std::log((ComplexDataType)(-1.0f)); 920 | for (int i=0; i=1; sub_iter--) { 942 | 943 | //#pragma omp parallel for firstprivate(u, convu) 944 | for (int i=0; ilambdas[i]*c_g[j] - c_r[i*N+j]) - c_phi[sub_iter*K*N + i*N+j]; 950 | // u[i*N+j] = (problem->lambdas[i]*g[j] - r[i*N+j])/phi[sub_iter*K*N + i*N+j]; 951 | } 952 | resultLa[i] += std::real(dotp); 953 | } 954 | 955 | if (sub_iter!=1) { 956 | problem->kernel.log_convolve(&u[0], &convu[0], K); 957 | // problem->kernel.convolve(&u[0], &convu[0], K); 958 | for (int i=0; ipdfs[i][0*N +j]) - 2.0f*(c_conv_b[(sub_iter-1)*K*N + i*N + j]); 961 | // convu[i*N+j] *= -problem->pdfs[i][0*N +j] / sqr(conv_b[(sub_iter-1)*K*N + i*N + j]); 962 | } 963 | } 964 | problem->kernel.log_convolveAdjoint(&convu[0], &c_r[0], K); 965 | // problem->kernel.convolveAdjoint(&convu[0], &r[0], K); 966 | 967 | ComplexDataType minus = std::log((ComplexDataType)(-1.0f)); 968 | for (int i=0; i* regression = (WassersteinRegression*)(instance); 989 | Problem* prob = regression->problem; 990 | WassersteinBarycenter* bary = regression->bary_computation; 991 | int N = prob->N; 992 | int K = prob->K; 993 | int P = prob->num_pdfs; 994 | double scaleDic = regression->scaleDictionary; 995 | assert(n==P*K+K*N); 996 | 997 | // Transfer the data back to the original domain 998 | double* variables2 = new double[n]; 999 | memcpy(variables2, variables, n*sizeof(double)); 1000 | if (regression->exp_weight) { 1001 | for (int i=0; i s(P, 0); 1014 | for (int i=0; i s2(K, 0.); 1024 | for (int i=0; ipdfs[i][j] = variables2[P*K+i*N+j]; 1037 | } 1038 | } 1039 | 1040 | std::vector 
barycenter(N); // To store one barycenter 1041 | std::vector curGradientDic(K*N); // To store the gradient on dictionary 1042 | memset(gradient, 0, (K*P+N*K) * sizeof(gradient[0])); // Set all values of gradient to 0 1043 | lbfgsfloatval_t lossVal = 0; 1044 | 1045 | // Containers for histogram and filename batches 1046 | std::vector > barycentersBatch(P); 1047 | std::vector filenamesBatch(P); 1048 | // Check condition for exporting the histograms 1049 | bool exportHists = regression->exportFittings && !regression->exportOnlyFinalSolution && (regression->iteration % regression->exportEveryMIter == 0); 1050 | 1051 | // For each histogram in the input dataset 1052 | for (int id =0; idset_weights(&variables2[id*K]); // Store weights 1054 | //prob->normalize_values(); 1055 | std::cout<<"|"<warmRestart) 1058 | { 1059 | // Use last scaling for initialization (warm restart) 1060 | std::copy(regression->b_storage.begin()+K*N*id, regression->b_storage.begin()+K*N*(id+1),regression->b.begin()); 1061 | } 1062 | 1063 | // Compute the barycenter and both gradients 1064 | #ifndef COMPUTE_BARYCENTER_LOG 1065 | regression->compute_gradient_both(&gradient[id*K], &curGradientDic[0], &barycenter[0], id); 1066 | #else 1067 | regression->compute_gradient_both_log_signArray(&gradient[id*K], &curGradientDic[0], &barycenter[0], id); 1068 | #endif 1069 | 1070 | // Save data (barycenter and filenames) for exporting in batch later 1071 | if(exportHists) 1072 | { 1073 | barycentersBatch[id] = barycenter; 1074 | std::ostringstream ss; 1075 | ss << "i-fitting_" << std::setfill('0'); 1076 | ss << std::setw(3) << regression->iteration << "_"; 1077 | ss << std::setw(3) << id; 1078 | filenamesBatch[id] = ss.str(); 1079 | } 1080 | 1081 | if(regression->warmRestart) 1082 | { 1083 | // Save last scaling b (for warm restart) 1084 | std::copy(regression->b.begin()+K*N*bary->Niters, regression->b.end(),regression->b_temp.begin()+K*N*id); 1085 | } 1086 | 1087 | // Compute the loss between the computed 
barycenter and the input histogram 1088 | // Total loss is the sum over all the input histograms 1089 | double currentLoss = regression->loss.loss(&barycenter[0], &prob->observed_pdf[id*N], prob->N); 1090 | // std::cout<<" "<exporter.exportHistogramsBatch(barycentersBatch,filenamesBatch); 1119 | } 1120 | 1121 | std::cout<exp_weight) { 1126 | for (int i=0; icompute_gradient_numeric(&test[0]); 1140 | std::cout<<"new iter-----------------"<* regression = (WassersteinRegression*)instance; 1151 | regression->iteration++; 1152 | if(regression->warmRestart) 1153 | printf("LBFGS Iteration %d, total iterations %d :\n", k,regression->iteration); 1154 | else 1155 | printf("Iteration %d:\n", k); 1156 | printf("time elapsed: %f (s)\n", regression->chrono.GetDiffMs()*0.001); 1157 | int K = regression->K; 1158 | int P = regression->problem->num_pdfs; 1159 | int N = regression->N; 1160 | 1161 | double scaleDic = regression->scaleDictionary; 1162 | 1163 | // Display fitting variables 1164 | std::cout<<"weights + 10 first values of first atom:"<exp_weight) { 1166 | for (int i=0; iexportAtoms && !(regression->exportOnlyFinalSolution) && (k % regression->exportEveryMIter == 0)) 1188 | { 1189 | std::vector base_result(N); 1190 | for (int i=0; iiteration << "_"; 1202 | ss << std::setw(3) << i; 1203 | std::string filename(ss.str()); 1204 | regression->exporter.exportHistogram(base_result,filename); 1205 | } 1206 | } 1207 | return 0; 1208 | } 1209 | 1210 | // Storage for warm restart 1211 | std::vector b_storage; 1212 | std::vector b_temp; 1213 | // Storage for intermediate results 1214 | std::vector > b, conv_b, phi, g, r; 1215 | 1216 | WassersteinBarycenter* bary_computation; 1217 | Problem* problem; 1218 | size_t K, N; 1219 | std::vector prev_solution; 1220 | gradient_type gradient_method; 1221 | PerfChrono chrono; 1222 | const BaseLoss &loss; 1223 | lbfgs_parameter_t lbfgs_param; 1224 | bool firstCall; 1225 | bool exp_weight; // Holds whether the weights are in log-domain 1226 | int 
iteration; 1227 | int exportEveryMIter; 1228 | // For histogram exportation 1229 | const ExportHistogramBase &exporter; 1230 | // For dual regression (regress_both) 1231 | double scaleDictionary; 1232 | bool exportAtoms; 1233 | bool exportFittings; 1234 | bool exportOnlyFinalSolution; // works for atoms and fittings 1235 | // For warm restart 1236 | bool warmRestart; 1237 | int wrTotalIteration; 1238 | }; 1239 | -------------------------------------------------------------------------------- /cpp/lbfgs/arithmetic_ansi.h: -------------------------------------------------------------------------------- 1 | /* 2 | * ANSI C implementation of vector operations. 3 | * 4 | * Copyright (c) 2007-2010 Naoaki Okazaki 5 | * All rights reserved. 6 | * 7 | * Permission is hereby granted, free of charge, to any person obtaining a copy 8 | * of this software and associated documentation files (the "Software"), to deal 9 | * in the Software without restriction, including without limitation the rights 10 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | * copies of the Software, and to permit persons to whom the Software is 12 | * furnished to do so, subject to the following conditions: 13 | * 14 | * The above copyright notice and this permission notice shall be included in 15 | * all copies or substantial portions of the Software. 16 | * 17 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | * THE SOFTWARE. 
24 | */ 25 | 26 | /* $Id$ */ 27 | 28 | #include <stdlib.h> 29 | #include <memory.h> 30 | 31 | #if LBFGS_FLOAT == 32 && LBFGS_IEEE_FLOAT 32 | #define fsigndiff(x, y) (((*(uint32_t*)(x)) ^ (*(uint32_t*)(y))) & 0x80000000U) 33 | #else 34 | #define fsigndiff(x, y) (*(x) * (*(y) / fabs(*(y))) < 0.) 35 | #endif/*LBFGS_IEEE_FLOAT*/ 36 | 37 | inline static void* vecalloc(size_t size) 38 | { 39 | void *memblock = malloc(size); 40 | if (memblock) { 41 | memset(memblock, 0, size); 42 | } 43 | return memblock; 44 | } 45 | 46 | inline static void vecfree(void *memblock) 47 | { 48 | free(memblock); 49 | } 50 | 51 | inline static void vecset(lbfgsfloatval_t *x, const lbfgsfloatval_t c, const int n) 52 | { 53 | int i; 54 | 55 | for (i = 0;i < n;++i) { 56 | x[i] = c; 57 | } 58 | } 59 | 60 | inline static void veccpy(lbfgsfloatval_t *y, const lbfgsfloatval_t *x, const int n) 61 | { 62 | int i; 63 | 64 | for (i = 0;i < n;++i) { 65 | y[i] = x[i]; 66 | } 67 | } 68 | 69 | inline static void vecncpy(lbfgsfloatval_t *y, const lbfgsfloatval_t *x, const int n) 70 | { 71 | int i; 72 | 73 | for (i = 0;i < n;++i) { 74 | y[i] = -x[i]; 75 | } 76 | } 77 | 78 | inline static void vecadd(lbfgsfloatval_t *y, const lbfgsfloatval_t *x, const lbfgsfloatval_t c, const int n) 79 | { 80 | int i; 81 | 82 | for (i = 0;i < n;++i) { 83 | y[i] += c * x[i]; 84 | } 85 | } 86 | 87 | inline static void vecdiff(lbfgsfloatval_t *z, const lbfgsfloatval_t *x, const lbfgsfloatval_t *y, const int n) 88 | { 89 | int i; 90 | 91 | for (i = 0;i < n;++i) { 92 | z[i] = x[i] - y[i]; 93 | } 94 | } 95 | 96 | inline static void vecscale(lbfgsfloatval_t *y, const lbfgsfloatval_t c, const int n) 97 | { 98 | int i; 99 | 100 | for (i = 0;i < n;++i) { 101 | y[i] *= c; 102 | } 103 | } 104 | 105 | inline static void vecmul(lbfgsfloatval_t *y, const lbfgsfloatval_t *x, const int n) 106 | { 107 | int i; 108 | 109 | for (i = 0;i < n;++i) { 110 | y[i] *= x[i]; 111 | } 112 | } 113 | 114 | inline static void vecdot(lbfgsfloatval_t* s, const lbfgsfloatval_t *x,
const lbfgsfloatval_t *y, const int n) 115 | { 116 | int i; 117 | *s = 0.; 118 | for (i = 0;i < n;++i) { 119 | *s += x[i] * y[i]; 120 | } 121 | } 122 | 123 | inline static void vec2norm(lbfgsfloatval_t* s, const lbfgsfloatval_t *x, const int n) 124 | { 125 | vecdot(s, x, x, n); 126 | *s = (lbfgsfloatval_t)sqrt(*s); 127 | } 128 | 129 | inline static void vec2norminv(lbfgsfloatval_t* s, const lbfgsfloatval_t *x, const int n) 130 | { 131 | vec2norm(s, x, n); 132 | *s = (lbfgsfloatval_t)(1.0 / *s); 133 | } 134 | -------------------------------------------------------------------------------- /cpp/lbfgs/arithmetic_sse_double.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SSE2 implementation of vector operations (64bit double). 3 | * 4 | * Copyright (c) 2007-2010 Naoaki Okazaki 5 | * All rights reserved. 6 | * 7 | * Permission is hereby granted, free of charge, to any person obtaining a copy 8 | * of this software and associated documentation files (the "Software"), to deal 9 | * in the Software without restriction, including without limitation the rights 10 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | * copies of the Software, and to permit persons to whom the Software is 12 | * furnished to do so, subject to the following conditions: 13 | * 14 | * The above copyright notice and this permission notice shall be included in 15 | * all copies or substantial portions of the Software. 16 | * 17 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 20 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | * THE SOFTWARE. 24 | */ 25 | 26 | /* $Id$ */ 27 | 28 | #include <stdlib.h> 29 | #ifndef __APPLE__ 30 | #include <malloc.h> 31 | #endif 32 | #include <memory.h> 33 | 34 | #if 1400 <= _MSC_VER 35 | #include <intrin.h> 36 | #endif/*1400 <= _MSC_VER*/ 37 | 38 | #if HAVE_EMMINTRIN_H 39 | #include <emmintrin.h> 40 | #endif/*HAVE_EMMINTRIN_H*/ 41 | 42 | inline static void* vecalloc(size_t size) 43 | { 44 | #if defined(_MSC_VER) 45 | void *memblock = _aligned_malloc(size, 16); 46 | #elif defined(__APPLE__) /* OS X always aligns on 16-byte boundaries */ 47 | void *memblock = malloc(size); 48 | #else 49 | void *memblock = NULL, *p = NULL; 50 | if (posix_memalign(&p, 16, size) == 0) { 51 | memblock = p; 52 | } 53 | #endif 54 | if (memblock != NULL) { 55 | memset(memblock, 0, size); 56 | } 57 | return memblock; 58 | } 59 | 60 | inline static void vecfree(void *memblock) 61 | { 62 | #ifdef _MSC_VER 63 | _aligned_free(memblock); 64 | #else 65 | free(memblock); 66 | #endif 67 | } 68 | 69 | #define fsigndiff(x, y) \ 70 | ((_mm_movemask_pd(_mm_set_pd(*(x), *(y))) + 1) & 0x002) 71 | 72 | #define vecset(x, c, n) \ 73 | { \ 74 | int i; \ 75 | __m128d XMM0 = _mm_set1_pd(c); \ 76 | for (i = 0;i < (n);i += 8) { \ 77 | _mm_store_pd((x)+i , XMM0); \ 78 | _mm_store_pd((x)+i+2, XMM0); \ 79 | _mm_store_pd((x)+i+4, XMM0); \ 80 | _mm_store_pd((x)+i+6, XMM0); \ 81 | } \ 82 | } 83 | 84 | #define veccpy(y, x, n) \ 85 | { \ 86 | int i; \ 87 | for (i = 0;i < (n);i += 8) { \ 88 | __m128d XMM0 = _mm_load_pd((x)+i ); \ 89 | __m128d XMM1 = _mm_load_pd((x)+i+2); \ 90 | __m128d XMM2 = _mm_load_pd((x)+i+4); \ 91 | __m128d XMM3 = _mm_load_pd((x)+i+6); \ 92 | _mm_store_pd((y)+i , XMM0); \ 93 | _mm_store_pd((y)+i+2, XMM1); \ 94 | _mm_store_pd((y)+i+4, XMM2); \ 95 | _mm_store_pd((y)+i+6, XMM3); \ 96 |
} \ 97 | } 98 | 99 | #define vecncpy(y, x, n) \ 100 | { \ 101 | int i; \ 102 | for (i = 0;i < (n);i += 8) { \ 103 | __m128d XMM0 = _mm_setzero_pd(); \ 104 | __m128d XMM1 = _mm_setzero_pd(); \ 105 | __m128d XMM2 = _mm_setzero_pd(); \ 106 | __m128d XMM3 = _mm_setzero_pd(); \ 107 | __m128d XMM4 = _mm_load_pd((x)+i ); \ 108 | __m128d XMM5 = _mm_load_pd((x)+i+2); \ 109 | __m128d XMM6 = _mm_load_pd((x)+i+4); \ 110 | __m128d XMM7 = _mm_load_pd((x)+i+6); \ 111 | XMM0 = _mm_sub_pd(XMM0, XMM4); \ 112 | XMM1 = _mm_sub_pd(XMM1, XMM5); \ 113 | XMM2 = _mm_sub_pd(XMM2, XMM6); \ 114 | XMM3 = _mm_sub_pd(XMM3, XMM7); \ 115 | _mm_store_pd((y)+i , XMM0); \ 116 | _mm_store_pd((y)+i+2, XMM1); \ 117 | _mm_store_pd((y)+i+4, XMM2); \ 118 | _mm_store_pd((y)+i+6, XMM3); \ 119 | } \ 120 | } 121 | 122 | #define vecadd(y, x, c, n) \ 123 | { \ 124 | int i; \ 125 | __m128d XMM7 = _mm_set1_pd(c); \ 126 | for (i = 0;i < (n);i += 4) { \ 127 | __m128d XMM0 = _mm_load_pd((x)+i ); \ 128 | __m128d XMM1 = _mm_load_pd((x)+i+2); \ 129 | __m128d XMM2 = _mm_load_pd((y)+i ); \ 130 | __m128d XMM3 = _mm_load_pd((y)+i+2); \ 131 | XMM0 = _mm_mul_pd(XMM0, XMM7); \ 132 | XMM1 = _mm_mul_pd(XMM1, XMM7); \ 133 | XMM2 = _mm_add_pd(XMM2, XMM0); \ 134 | XMM3 = _mm_add_pd(XMM3, XMM1); \ 135 | _mm_store_pd((y)+i , XMM2); \ 136 | _mm_store_pd((y)+i+2, XMM3); \ 137 | } \ 138 | } 139 | 140 | #define vecdiff(z, x, y, n) \ 141 | { \ 142 | int i; \ 143 | for (i = 0;i < (n);i += 8) { \ 144 | __m128d XMM0 = _mm_load_pd((x)+i ); \ 145 | __m128d XMM1 = _mm_load_pd((x)+i+2); \ 146 | __m128d XMM2 = _mm_load_pd((x)+i+4); \ 147 | __m128d XMM3 = _mm_load_pd((x)+i+6); \ 148 | __m128d XMM4 = _mm_load_pd((y)+i ); \ 149 | __m128d XMM5 = _mm_load_pd((y)+i+2); \ 150 | __m128d XMM6 = _mm_load_pd((y)+i+4); \ 151 | __m128d XMM7 = _mm_load_pd((y)+i+6); \ 152 | XMM0 = _mm_sub_pd(XMM0, XMM4); \ 153 | XMM1 = _mm_sub_pd(XMM1, XMM5); \ 154 | XMM2 = _mm_sub_pd(XMM2, XMM6); \ 155 | XMM3 = _mm_sub_pd(XMM3, XMM7); \ 156 | _mm_store_pd((z)+i , XMM0); \ 157 
| _mm_store_pd((z)+i+2, XMM1); \ 158 | _mm_store_pd((z)+i+4, XMM2); \ 159 | _mm_store_pd((z)+i+6, XMM3); \ 160 | } \ 161 | } 162 | 163 | #define vecscale(y, c, n) \ 164 | { \ 165 | int i; \ 166 | __m128d XMM7 = _mm_set1_pd(c); \ 167 | for (i = 0;i < (n);i += 4) { \ 168 | __m128d XMM0 = _mm_load_pd((y)+i ); \ 169 | __m128d XMM1 = _mm_load_pd((y)+i+2); \ 170 | XMM0 = _mm_mul_pd(XMM0, XMM7); \ 171 | XMM1 = _mm_mul_pd(XMM1, XMM7); \ 172 | _mm_store_pd((y)+i , XMM0); \ 173 | _mm_store_pd((y)+i+2, XMM1); \ 174 | } \ 175 | } 176 | 177 | #define vecmul(y, x, n) \ 178 | { \ 179 | int i; \ 180 | for (i = 0;i < (n);i += 8) { \ 181 | __m128d XMM0 = _mm_load_pd((x)+i ); \ 182 | __m128d XMM1 = _mm_load_pd((x)+i+2); \ 183 | __m128d XMM2 = _mm_load_pd((x)+i+4); \ 184 | __m128d XMM3 = _mm_load_pd((x)+i+6); \ 185 | __m128d XMM4 = _mm_load_pd((y)+i ); \ 186 | __m128d XMM5 = _mm_load_pd((y)+i+2); \ 187 | __m128d XMM6 = _mm_load_pd((y)+i+4); \ 188 | __m128d XMM7 = _mm_load_pd((y)+i+6); \ 189 | XMM4 = _mm_mul_pd(XMM4, XMM0); \ 190 | XMM5 = _mm_mul_pd(XMM5, XMM1); \ 191 | XMM6 = _mm_mul_pd(XMM6, XMM2); \ 192 | XMM7 = _mm_mul_pd(XMM7, XMM3); \ 193 | _mm_store_pd((y)+i , XMM4); \ 194 | _mm_store_pd((y)+i+2, XMM5); \ 195 | _mm_store_pd((y)+i+4, XMM6); \ 196 | _mm_store_pd((y)+i+6, XMM7); \ 197 | } \ 198 | } 199 | 200 | 201 | 202 | #if 3 <= __SSE__ || defined(__SSE3__) 203 | /* 204 | Horizontal add with haddps SSE3 instruction. The work register (rw) 205 | is unused. 206 | */ 207 | #define __horizontal_sum(r, rw) \ 208 | r = _mm_hadd_ps(r, r); \ 209 | r = _mm_hadd_ps(r, r); 210 | 211 | #else 212 | /* 213 | Horizontal add with SSE instruction. The work register (rw) is used. 
214 | */ 215 | #define __horizontal_sum(r, rw) \ 216 | rw = r; \ 217 | r = _mm_shuffle_ps(r, rw, _MM_SHUFFLE(1, 0, 3, 2)); \ 218 | r = _mm_add_ps(r, rw); \ 219 | rw = r; \ 220 | r = _mm_shuffle_ps(r, rw, _MM_SHUFFLE(2, 3, 0, 1)); \ 221 | r = _mm_add_ps(r, rw); 222 | 223 | #endif 224 | 225 | #define vecdot(s, x, y, n) \ 226 | { \ 227 | int i; \ 228 | __m128d XMM0 = _mm_setzero_pd(); \ 229 | __m128d XMM1 = _mm_setzero_pd(); \ 230 | __m128d XMM2, XMM3, XMM4, XMM5; \ 231 | for (i = 0;i < (n);i += 4) { \ 232 | XMM2 = _mm_load_pd((x)+i ); \ 233 | XMM3 = _mm_load_pd((x)+i+2); \ 234 | XMM4 = _mm_load_pd((y)+i ); \ 235 | XMM5 = _mm_load_pd((y)+i+2); \ 236 | XMM2 = _mm_mul_pd(XMM2, XMM4); \ 237 | XMM3 = _mm_mul_pd(XMM3, XMM5); \ 238 | XMM0 = _mm_add_pd(XMM0, XMM2); \ 239 | XMM1 = _mm_add_pd(XMM1, XMM3); \ 240 | } \ 241 | XMM0 = _mm_add_pd(XMM0, XMM1); \ 242 | XMM1 = _mm_shuffle_pd(XMM0, XMM0, _MM_SHUFFLE2(1, 1)); \ 243 | XMM0 = _mm_add_pd(XMM0, XMM1); \ 244 | _mm_store_sd((s), XMM0); \ 245 | } 246 | 247 | #define vec2norm(s, x, n) \ 248 | { \ 249 | int i; \ 250 | __m128d XMM0 = _mm_setzero_pd(); \ 251 | __m128d XMM1 = _mm_setzero_pd(); \ 252 | __m128d XMM2, XMM3, XMM4, XMM5; \ 253 | for (i = 0;i < (n);i += 4) { \ 254 | XMM2 = _mm_load_pd((x)+i ); \ 255 | XMM3 = _mm_load_pd((x)+i+2); \ 256 | XMM4 = XMM2; \ 257 | XMM5 = XMM3; \ 258 | XMM2 = _mm_mul_pd(XMM2, XMM4); \ 259 | XMM3 = _mm_mul_pd(XMM3, XMM5); \ 260 | XMM0 = _mm_add_pd(XMM0, XMM2); \ 261 | XMM1 = _mm_add_pd(XMM1, XMM3); \ 262 | } \ 263 | XMM0 = _mm_add_pd(XMM0, XMM1); \ 264 | XMM1 = _mm_shuffle_pd(XMM0, XMM0, _MM_SHUFFLE2(1, 1)); \ 265 | XMM0 = _mm_add_pd(XMM0, XMM1); \ 266 | XMM0 = _mm_sqrt_pd(XMM0); \ 267 | _mm_store_sd((s), XMM0); \ 268 | } 269 | 270 | 271 | #define vec2norminv(s, x, n) \ 272 | { \ 273 | int i; \ 274 | __m128d XMM0 = _mm_setzero_pd(); \ 275 | __m128d XMM1 = _mm_setzero_pd(); \ 276 | __m128d XMM2, XMM3, XMM4, XMM5; \ 277 | for (i = 0;i < (n);i += 4) { \ 278 | XMM2 = _mm_load_pd((x)+i ); \ 279 | XMM3 
= _mm_load_pd((x)+i+2); \ 280 | XMM4 = XMM2; \ 281 | XMM5 = XMM3; \ 282 | XMM2 = _mm_mul_pd(XMM2, XMM4); \ 283 | XMM3 = _mm_mul_pd(XMM3, XMM5); \ 284 | XMM0 = _mm_add_pd(XMM0, XMM2); \ 285 | XMM1 = _mm_add_pd(XMM1, XMM3); \ 286 | } \ 287 | XMM2 = _mm_set1_pd(1.0); \ 288 | XMM0 = _mm_add_pd(XMM0, XMM1); \ 289 | XMM1 = _mm_shuffle_pd(XMM0, XMM0, _MM_SHUFFLE2(1, 1)); \ 290 | XMM0 = _mm_add_pd(XMM0, XMM1); \ 291 | XMM0 = _mm_sqrt_pd(XMM0); \ 292 | XMM2 = _mm_div_pd(XMM2, XMM0); \ 293 | _mm_store_sd((s), XMM2); \ 294 | } 295 | -------------------------------------------------------------------------------- /cpp/lbfgs/arithmetic_sse_float.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SSE/SSE3 implementation of vector operations (32bit float). 3 | * 4 | * Copyright (c) 2007-2010 Naoaki Okazaki 5 | * All rights reserved. 6 | * 7 | * Permission is hereby granted, free of charge, to any person obtaining a copy 8 | * of this software and associated documentation files (the "Software"), to deal 9 | * in the Software without restriction, including without limitation the rights 10 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | * copies of the Software, and to permit persons to whom the Software is 12 | * furnished to do so, subject to the following conditions: 13 | * 14 | * The above copyright notice and this permission notice shall be included in 15 | * all copies or substantial portions of the Software. 16 | * 17 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 20 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | * THE SOFTWARE. 24 | */ 25 | 26 | /* $Id$ */ 27 | 28 | #include <stdlib.h> 29 | #ifndef __APPLE__ 30 | #include <malloc.h> 31 | #endif 32 | #include <memory.h> 33 | 34 | #if 1400 <= _MSC_VER 35 | #include <intrin.h> 36 | #endif/*_MSC_VER*/ 37 | 38 | #if HAVE_XMMINTRIN_H 39 | #include <xmmintrin.h> 40 | #endif/*HAVE_XMMINTRIN_H*/ 41 | 42 | #if LBFGS_FLOAT == 32 && LBFGS_IEEE_FLOAT 43 | #define fsigndiff(x, y) (((*(uint32_t*)(x)) ^ (*(uint32_t*)(y))) & 0x80000000U) 44 | #else 45 | #define fsigndiff(x, y) (*(x) * (*(y) / fabs(*(y))) < 0.) 46 | #endif/*LBFGS_IEEE_FLOAT*/ 47 | 48 | inline static void* vecalloc(size_t size) 49 | { 50 | #if defined(_MSC_VER) 51 | void *memblock = _aligned_malloc(size, 16); 52 | #elif defined(__APPLE__) /* OS X always aligns on 16-byte boundaries */ 53 | void *memblock = malloc(size); 54 | #else 55 | void *memblock = NULL, *p = NULL; 56 | if (posix_memalign(&p, 16, size) == 0) { 57 | memblock = p; 58 | } 59 | #endif 60 | if (memblock != NULL) { 61 | memset(memblock, 0, size); 62 | } 63 | return memblock; 64 | } 65 | #ifdef _MSC_VER 66 | inline static void vecfree(void *memblock) { _aligned_free(memblock); } 67 | #else 68 | inline static void vecfree(void *memblock) { free(memblock); } 69 | #endif 70 | 71 | #define vecset(x, c, n) \ 72 | { \ 73 | int i; \ 74 | __m128 XMM0 = _mm_set_ps1(c); \ 75 | for (i = 0;i < (n);i += 16) { \ 76 | _mm_store_ps((x)+i , XMM0); \ 77 | _mm_store_ps((x)+i+ 4, XMM0); \ 78 | _mm_store_ps((x)+i+ 8, XMM0); \ 79 | _mm_store_ps((x)+i+12, XMM0); \ 80 | } \ 81 | } 82 | 83 | #define veccpy(y, x, n) \ 84 | { \ 85 | int i; \ 86 | for (i = 0;i < (n);i += 16) { \ 87 | __m128 XMM0 = _mm_load_ps((x)+i ); \ 88 | __m128 XMM1 = _mm_load_ps((x)+i+ 4); \ 89 | __m128 XMM2 = _mm_load_ps((x)+i+ 8); \ 90 | __m128 XMM3 = _mm_load_ps((x)+i+12); \ 91 | _mm_store_ps((y)+i , XMM0); \ 92 | _mm_store_ps((y)+i+ 4, XMM1); \
93 | _mm_store_ps((y)+i+ 8, XMM2); \ 94 | _mm_store_ps((y)+i+12, XMM3); \ 95 | } \ 96 | } 97 | 98 | #define vecncpy(y, x, n) \ 99 | { \ 100 | int i; \ 101 | const uint32_t mask = 0x80000000; \ 102 | __m128 XMM4 = _mm_load_ps1((float*)&mask); \ 103 | for (i = 0;i < (n);i += 16) { \ 104 | __m128 XMM0 = _mm_load_ps((x)+i ); \ 105 | __m128 XMM1 = _mm_load_ps((x)+i+ 4); \ 106 | __m128 XMM2 = _mm_load_ps((x)+i+ 8); \ 107 | __m128 XMM3 = _mm_load_ps((x)+i+12); \ 108 | XMM0 = _mm_xor_ps(XMM0, XMM4); \ 109 | XMM1 = _mm_xor_ps(XMM1, XMM4); \ 110 | XMM2 = _mm_xor_ps(XMM2, XMM4); \ 111 | XMM3 = _mm_xor_ps(XMM3, XMM4); \ 112 | _mm_store_ps((y)+i , XMM0); \ 113 | _mm_store_ps((y)+i+ 4, XMM1); \ 114 | _mm_store_ps((y)+i+ 8, XMM2); \ 115 | _mm_store_ps((y)+i+12, XMM3); \ 116 | } \ 117 | } 118 | 119 | #define vecadd(y, x, c, n) \ 120 | { \ 121 | int i; \ 122 | __m128 XMM7 = _mm_set_ps1(c); \ 123 | for (i = 0;i < (n);i += 8) { \ 124 | __m128 XMM0 = _mm_load_ps((x)+i ); \ 125 | __m128 XMM1 = _mm_load_ps((x)+i+4); \ 126 | __m128 XMM2 = _mm_load_ps((y)+i ); \ 127 | __m128 XMM3 = _mm_load_ps((y)+i+4); \ 128 | XMM0 = _mm_mul_ps(XMM0, XMM7); \ 129 | XMM1 = _mm_mul_ps(XMM1, XMM7); \ 130 | XMM2 = _mm_add_ps(XMM2, XMM0); \ 131 | XMM3 = _mm_add_ps(XMM3, XMM1); \ 132 | _mm_store_ps((y)+i , XMM2); \ 133 | _mm_store_ps((y)+i+4, XMM3); \ 134 | } \ 135 | } 136 | 137 | #define vecdiff(z, x, y, n) \ 138 | { \ 139 | int i; \ 140 | for (i = 0;i < (n);i += 16) { \ 141 | __m128 XMM0 = _mm_load_ps((x)+i ); \ 142 | __m128 XMM1 = _mm_load_ps((x)+i+ 4); \ 143 | __m128 XMM2 = _mm_load_ps((x)+i+ 8); \ 144 | __m128 XMM3 = _mm_load_ps((x)+i+12); \ 145 | __m128 XMM4 = _mm_load_ps((y)+i ); \ 146 | __m128 XMM5 = _mm_load_ps((y)+i+ 4); \ 147 | __m128 XMM6 = _mm_load_ps((y)+i+ 8); \ 148 | __m128 XMM7 = _mm_load_ps((y)+i+12); \ 149 | XMM0 = _mm_sub_ps(XMM0, XMM4); \ 150 | XMM1 = _mm_sub_ps(XMM1, XMM5); \ 151 | XMM2 = _mm_sub_ps(XMM2, XMM6); \ 152 | XMM3 = _mm_sub_ps(XMM3, XMM7); \ 153 | _mm_store_ps((z)+i , XMM0); \ 
154 | _mm_store_ps((z)+i+ 4, XMM1); \ 155 | _mm_store_ps((z)+i+ 8, XMM2); \ 156 | _mm_store_ps((z)+i+12, XMM3); \ 157 | } \ 158 | } 159 | 160 | #define vecscale(y, c, n) \ 161 | { \ 162 | int i; \ 163 | __m128 XMM7 = _mm_set_ps1(c); \ 164 | for (i = 0;i < (n);i += 8) { \ 165 | __m128 XMM0 = _mm_load_ps((y)+i ); \ 166 | __m128 XMM1 = _mm_load_ps((y)+i+4); \ 167 | XMM0 = _mm_mul_ps(XMM0, XMM7); \ 168 | XMM1 = _mm_mul_ps(XMM1, XMM7); \ 169 | _mm_store_ps((y)+i , XMM0); \ 170 | _mm_store_ps((y)+i+4, XMM1); \ 171 | } \ 172 | } 173 | 174 | #define vecmul(y, x, n) \ 175 | { \ 176 | int i; \ 177 | for (i = 0;i < (n);i += 16) { \ 178 | __m128 XMM0 = _mm_load_ps((x)+i ); \ 179 | __m128 XMM1 = _mm_load_ps((x)+i+ 4); \ 180 | __m128 XMM2 = _mm_load_ps((x)+i+ 8); \ 181 | __m128 XMM3 = _mm_load_ps((x)+i+12); \ 182 | __m128 XMM4 = _mm_load_ps((y)+i ); \ 183 | __m128 XMM5 = _mm_load_ps((y)+i+ 4); \ 184 | __m128 XMM6 = _mm_load_ps((y)+i+ 8); \ 185 | __m128 XMM7 = _mm_load_ps((y)+i+12); \ 186 | XMM4 = _mm_mul_ps(XMM4, XMM0); \ 187 | XMM5 = _mm_mul_ps(XMM5, XMM1); \ 188 | XMM6 = _mm_mul_ps(XMM6, XMM2); \ 189 | XMM7 = _mm_mul_ps(XMM7, XMM3); \ 190 | _mm_store_ps((y)+i , XMM4); \ 191 | _mm_store_ps((y)+i+ 4, XMM5); \ 192 | _mm_store_ps((y)+i+ 8, XMM6); \ 193 | _mm_store_ps((y)+i+12, XMM7); \ 194 | } \ 195 | } 196 | 197 | 198 | 199 | #if 3 <= __SSE__ || defined(__SSE3__) 200 | /* 201 | Horizontal add with haddps SSE3 instruction. The work register (rw) 202 | is unused. 203 | */ 204 | #define __horizontal_sum(r, rw) \ 205 | r = _mm_hadd_ps(r, r); \ 206 | r = _mm_hadd_ps(r, r); 207 | 208 | #else 209 | /* 210 | Horizontal add with SSE instruction. The work register (rw) is used. 
211 | */ 212 | #define __horizontal_sum(r, rw) \ 213 | rw = r; \ 214 | r = _mm_shuffle_ps(r, rw, _MM_SHUFFLE(1, 0, 3, 2)); \ 215 | r = _mm_add_ps(r, rw); \ 216 | rw = r; \ 217 | r = _mm_shuffle_ps(r, rw, _MM_SHUFFLE(2, 3, 0, 1)); \ 218 | r = _mm_add_ps(r, rw); 219 | 220 | #endif 221 | 222 | #define vecdot(s, x, y, n) \ 223 | { \ 224 | int i; \ 225 | __m128 XMM0 = _mm_setzero_ps(); \ 226 | __m128 XMM1 = _mm_setzero_ps(); \ 227 | __m128 XMM2, XMM3, XMM4, XMM5; \ 228 | for (i = 0;i < (n);i += 8) { \ 229 | XMM2 = _mm_load_ps((x)+i ); \ 230 | XMM3 = _mm_load_ps((x)+i+4); \ 231 | XMM4 = _mm_load_ps((y)+i ); \ 232 | XMM5 = _mm_load_ps((y)+i+4); \ 233 | XMM2 = _mm_mul_ps(XMM2, XMM4); \ 234 | XMM3 = _mm_mul_ps(XMM3, XMM5); \ 235 | XMM0 = _mm_add_ps(XMM0, XMM2); \ 236 | XMM1 = _mm_add_ps(XMM1, XMM3); \ 237 | } \ 238 | XMM0 = _mm_add_ps(XMM0, XMM1); \ 239 | __horizontal_sum(XMM0, XMM1); \ 240 | _mm_store_ss((s), XMM0); \ 241 | } 242 | 243 | #define vec2norm(s, x, n) \ 244 | { \ 245 | int i; \ 246 | __m128 XMM0 = _mm_setzero_ps(); \ 247 | __m128 XMM1 = _mm_setzero_ps(); \ 248 | __m128 XMM2, XMM3; \ 249 | for (i = 0;i < (n);i += 8) { \ 250 | XMM2 = _mm_load_ps((x)+i ); \ 251 | XMM3 = _mm_load_ps((x)+i+4); \ 252 | XMM2 = _mm_mul_ps(XMM2, XMM2); \ 253 | XMM3 = _mm_mul_ps(XMM3, XMM3); \ 254 | XMM0 = _mm_add_ps(XMM0, XMM2); \ 255 | XMM1 = _mm_add_ps(XMM1, XMM3); \ 256 | } \ 257 | XMM0 = _mm_add_ps(XMM0, XMM1); \ 258 | __horizontal_sum(XMM0, XMM1); \ 259 | XMM2 = XMM0; \ 260 | XMM1 = _mm_rsqrt_ss(XMM0); \ 261 | XMM3 = XMM1; \ 262 | XMM1 = _mm_mul_ss(XMM1, XMM1); \ 263 | XMM1 = _mm_mul_ss(XMM1, XMM3); \ 264 | XMM1 = _mm_mul_ss(XMM1, XMM0); \ 265 | XMM1 = _mm_mul_ss(XMM1, _mm_set_ss(-0.5f)); \ 266 | XMM3 = _mm_mul_ss(XMM3, _mm_set_ss(1.5f)); \ 267 | XMM3 = _mm_add_ss(XMM3, XMM1); \ 268 | XMM3 = _mm_mul_ss(XMM3, XMM2); \ 269 | _mm_store_ss((s), XMM3); \ 270 | } 271 | 272 | #define vec2norminv(s, x, n) \ 273 | { \ 274 | int i; \ 275 | __m128 XMM0 = _mm_setzero_ps(); \ 276 | __m128 XMM1 
= _mm_setzero_ps(); \ 277 | __m128 XMM2, XMM3; \ 278 | for (i = 0;i < (n);i += 8) { \ 279 | XMM2 = _mm_load_ps((x)+i ); \ 280 | XMM3 = _mm_load_ps((x)+i+4); \ 281 | XMM2 = _mm_mul_ps(XMM2, XMM2); \ 282 | XMM3 = _mm_mul_ps(XMM3, XMM3); \ 283 | XMM0 = _mm_add_ps(XMM0, XMM2); \ 284 | XMM1 = _mm_add_ps(XMM1, XMM3); \ 285 | } \ 286 | XMM0 = _mm_add_ps(XMM0, XMM1); \ 287 | __horizontal_sum(XMM0, XMM1); \ 288 | XMM2 = XMM0; \ 289 | XMM1 = _mm_rsqrt_ss(XMM0); \ 290 | XMM3 = XMM1; \ 291 | XMM1 = _mm_mul_ss(XMM1, XMM1); \ 292 | XMM1 = _mm_mul_ss(XMM1, XMM3); \ 293 | XMM1 = _mm_mul_ss(XMM1, XMM0); \ 294 | XMM1 = _mm_mul_ss(XMM1, _mm_set_ss(-0.5f)); \ 295 | XMM3 = _mm_mul_ss(XMM3, _mm_set_ss(1.5f)); \ 296 | XMM3 = _mm_add_ss(XMM3, XMM1); \ 297 | _mm_store_ss((s), XMM3); \ 298 | } 299 | -------------------------------------------------------------------------------- /cpp/lbfgs/lbfgs.h: -------------------------------------------------------------------------------- 1 | /* 2 | * C library of Limited memory BFGS (L-BFGS). 3 | * 4 | * Copyright (c) 1990, Jorge Nocedal 5 | * Copyright (c) 2007-2010 Naoaki Okazaki 6 | * All rights reserved. 7 | * 8 | * Permission is hereby granted, free of charge, to any person obtaining a copy 9 | * of this software and associated documentation files (the "Software"), to deal 10 | * in the Software without restriction, including without limitation the rights 11 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | * copies of the Software, and to permit persons to whom the Software is 13 | * furnished to do so, subject to the following conditions: 14 | * 15 | * The above copyright notice and this permission notice shall be included in 16 | * all copies or substantial portions of the Software. 17 | * 18 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 21 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 24 | * THE SOFTWARE. 25 | */ 26 | 27 | /* $Id$ */ 28 | 29 | #ifndef __LBFGS_H__ 30 | #define __LBFGS_H__ 31 | 32 | #ifdef __cplusplus 33 | extern "C" { 34 | #endif/*__cplusplus*/ 35 | 36 | /* 37 | * The default precision of floating point values is 64bit (double). 38 | */ 39 | #ifndef LBFGS_FLOAT 40 | #define LBFGS_FLOAT 64 41 | #endif/*LBFGS_FLOAT*/ 42 | 43 | /* 44 | * Activate optimization routines for IEEE754 floating point values. 45 | */ 46 | #ifndef LBFGS_IEEE_FLOAT 47 | #define LBFGS_IEEE_FLOAT 1 48 | #endif/*LBFGS_IEEE_FLOAT*/ 49 | 50 | #if LBFGS_FLOAT == 32 51 | typedef float lbfgsfloatval_t; 52 | 53 | #elif LBFGS_FLOAT == 64 54 | typedef double lbfgsfloatval_t; 55 | 56 | #else 57 | #error "libLBFGS supports single (float; LBFGS_FLOAT = 32) or double (double; LBFGS_FLOAT=64) precision only." 58 | 59 | #endif 60 | 61 | 62 | /** 63 | * \addtogroup liblbfgs_api libLBFGS API 64 | * @{ 65 | * 66 | * The libLBFGS API. 67 | */ 68 | 69 | /** 70 | * Return values of lbfgs(). 71 | * 72 | * Roughly speaking, a negative value indicates an error. 73 | */ 74 | enum { 75 | /** L-BFGS reaches convergence. */ 76 | LBFGS_SUCCESS = 0, 77 | LBFGS_CONVERGENCE = 0, 78 | LBFGS_STOP, 79 | /** The initial variables already minimize the objective function. */ 80 | LBFGS_ALREADY_MINIMIZED, 81 | 82 | /** Unknown error. */ 83 | LBFGSERR_UNKNOWNERROR = -1024, 84 | /** Logic error. */ 85 | LBFGSERR_LOGICERROR, 86 | /** Insufficient memory. */ 87 | LBFGSERR_OUTOFMEMORY, 88 | /** The minimization process has been canceled. */ 89 | LBFGSERR_CANCELED, 90 | /** Invalid number of variables specified. */ 91 | LBFGSERR_INVALID_N, 92 | /** Invalid number of variables (for SSE) specified. 
*/ 93 | LBFGSERR_INVALID_N_SSE, 94 | /** The array x must be aligned to 16 (for SSE). */ 95 | LBFGSERR_INVALID_X_SSE, 96 | /** Invalid parameter lbfgs_parameter_t::epsilon specified. */ 97 | LBFGSERR_INVALID_EPSILON, 98 | /** Invalid parameter lbfgs_parameter_t::past specified. */ 99 | LBFGSERR_INVALID_TESTPERIOD, 100 | /** Invalid parameter lbfgs_parameter_t::delta specified. */ 101 | LBFGSERR_INVALID_DELTA, 102 | /** Invalid parameter lbfgs_parameter_t::linesearch specified. */ 103 | LBFGSERR_INVALID_LINESEARCH, 104 | /** Invalid parameter lbfgs_parameter_t::min_step specified. */ 105 | LBFGSERR_INVALID_MINSTEP, 106 | /** Invalid parameter lbfgs_parameter_t::max_step specified. */ 107 | LBFGSERR_INVALID_MAXSTEP, 108 | /** Invalid parameter lbfgs_parameter_t::ftol specified. */ 109 | LBFGSERR_INVALID_FTOL, 110 | /** Invalid parameter lbfgs_parameter_t::wolfe specified. */ 111 | LBFGSERR_INVALID_WOLFE, 112 | /** Invalid parameter lbfgs_parameter_t::gtol specified. */ 113 | LBFGSERR_INVALID_GTOL, 114 | /** Invalid parameter lbfgs_parameter_t::xtol specified. */ 115 | LBFGSERR_INVALID_XTOL, 116 | /** Invalid parameter lbfgs_parameter_t::max_linesearch specified. */ 117 | LBFGSERR_INVALID_MAXLINESEARCH, 118 | /** Invalid parameter lbfgs_parameter_t::orthantwise_c specified. */ 119 | LBFGSERR_INVALID_ORTHANTWISE, 120 | /** Invalid parameter lbfgs_parameter_t::orthantwise_start specified. */ 121 | LBFGSERR_INVALID_ORTHANTWISE_START, 122 | /** Invalid parameter lbfgs_parameter_t::orthantwise_end specified. */ 123 | LBFGSERR_INVALID_ORTHANTWISE_END, 124 | /** The line-search step went out of the interval of uncertainty. */ 125 | LBFGSERR_OUTOFINTERVAL, 126 | /** A logic error occurred; alternatively, the interval of uncertainty 127 | became too small. */ 128 | LBFGSERR_INCORRECT_TMINMAX, 129 | /** A rounding error occurred; alternatively, no line-search step 130 | satisfies the sufficient decrease and curvature conditions.
*/ 131 | LBFGSERR_ROUNDING_ERROR, 132 | /** The line-search step became smaller than lbfgs_parameter_t::min_step. */ 133 | LBFGSERR_MINIMUMSTEP, 134 | /** The line-search step became larger than lbfgs_parameter_t::max_step. */ 135 | LBFGSERR_MAXIMUMSTEP, 136 | /** The line-search routine reaches the maximum number of evaluations. */ 137 | LBFGSERR_MAXIMUMLINESEARCH, 138 | /** The algorithm routine reaches the maximum number of iterations. */ 139 | LBFGSERR_MAXIMUMITERATION, 140 | /** Relative width of the interval of uncertainty is at most 141 | lbfgs_parameter_t::xtol. */ 142 | LBFGSERR_WIDTHTOOSMALL, 143 | /** A logic error (negative line-search step) occurred. */ 144 | LBFGSERR_INVALIDPARAMETERS, 145 | /** The current search direction increases the objective function value. */ 146 | LBFGSERR_INCREASEGRADIENT, 147 | }; 148 | 149 | /** 150 | * Line search algorithms. 151 | */ 152 | enum { 153 | /** The default algorithm (MoreThuente method). */ 154 | LBFGS_LINESEARCH_DEFAULT = 0, 155 | /** MoreThuente method proposed by More and Thuente. */ 156 | LBFGS_LINESEARCH_MORETHUENTE = 0, 157 | /** 158 | * Backtracking method with the Armijo condition. 159 | * The backtracking method finds the step length such that it satisfies 160 | * the sufficient decrease (Armijo) condition, 161 | * - f(x + a * d) <= f(x) + lbfgs_parameter_t::ftol * a * g(x)^T d, 162 | * 163 | * where x is the current point, d is the current search direction, and 164 | * a is the step length. 165 | */ 166 | LBFGS_LINESEARCH_BACKTRACKING_ARMIJO = 1, 167 | /** The backtracking method with the default (regular Wolfe) condition. */ 168 | LBFGS_LINESEARCH_BACKTRACKING = 2, 169 | /** 170 | * Backtracking method with regular Wolfe condition.
171 | * The backtracking method finds the step length such that it satisfies 172 | * both the Armijo condition (LBFGS_LINESEARCH_BACKTRACKING_ARMIJO) 173 | * and the curvature condition, 174 | * - g(x + a * d)^T d >= lbfgs_parameter_t::wolfe * g(x)^T d, 175 | * 176 | * where x is the current point, d is the current search direction, and 177 | * a is the step length. 178 | */ 179 | LBFGS_LINESEARCH_BACKTRACKING_WOLFE = 2, 180 | /** 181 | * Backtracking method with strong Wolfe condition. 182 | * The backtracking method finds the step length such that it satisfies 183 | * both the Armijo condition (LBFGS_LINESEARCH_BACKTRACKING_ARMIJO) 184 | * and the following condition, 185 | * - |g(x + a * d)^T d| <= lbfgs_parameter_t::wolfe * |g(x)^T d|, 186 | * 187 | * where x is the current point, d is the current search direction, and 188 | * a is the step length. 189 | */ 190 | LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 3, 191 | }; 192 | 193 | /** 194 | * L-BFGS optimization parameters. 195 | * Call lbfgs_parameter_init() function to initialize parameters to the 196 | * default values. 197 | */ 198 | typedef struct { 199 | /** 200 | * The number of corrections to approximate the inverse hessian matrix. 201 | * The L-BFGS routine stores the computation results of previous \ref m 202 | * iterations to approximate the inverse hessian matrix of the current 203 | * iteration. This parameter controls the size of the limited memories 204 | * (corrections). The default value is \c 6. Values less than \c 3 are 205 | * not recommended. Large values will result in excessive computing time. 206 | */ 207 | int m; 208 | 209 | /** 210 | * Epsilon for convergence test. 211 | * This parameter determines the accuracy with which the solution is to 212 | * be found. A minimization terminates when 213 | * ||g|| < \ref epsilon * max(1, ||x||), 214 | * where ||.|| denotes the Euclidean (L2) norm. The default value is 215 | * \c 1e-5. 
216 | */ 217 | lbfgsfloatval_t epsilon; 218 | 219 | /** 220 | * Distance for delta-based convergence test. 221 | * This parameter determines the distance, in iterations, to compute 222 | * the rate of decrease of the objective function. If the value of this 223 | * parameter is zero, the library does not perform the delta-based 224 | * convergence test. The default value is \c 0. 225 | */ 226 | int past; 227 | 228 | /** 229 | * Delta for convergence test. 230 | * This parameter determines the minimum rate of decrease of the 231 | * objective function. The library stops iterations when the 232 | * following condition is met: 233 | * (f' - f) / f < \ref delta, 234 | * where f' is the objective value of \ref past iterations ago, and f is 235 | * the objective value of the current iteration. 236 | * The default value is \c 0. 237 | */ 238 | lbfgsfloatval_t delta; 239 | 240 | /** 241 | * The maximum number of iterations. 242 | * The lbfgs() function terminates an optimization process with 243 | * ::LBFGSERR_MAXIMUMITERATION status code when the iteration count 244 | * exceeds this parameter. Setting this parameter to zero continues an 245 | * optimization process until convergence or an error. The default value 246 | * is \c 0. 247 | */ 248 | int max_iterations; 249 | 250 | /** 251 | * The line search algorithm. 252 | * This parameter specifies a line search algorithm to be used by the 253 | * L-BFGS routine. 254 | */ 255 | int linesearch; 256 | 257 | /** 258 | * The maximum number of trials for the line search. 259 | * This parameter controls the number of function and gradient evaluations 260 | * per iteration for the line search routine. The default value is \c 20. 261 | */ 262 | int max_linesearch; 263 | 264 | /** 265 | * The minimum step of the line search routine. 266 | * The default value is \c 1e-20.
This value need not be modified unless 267 | * the exponents are too large for the machine being used, or unless the 268 | * problem is extremely badly scaled (in which case the exponents should 269 | * be increased). 270 | */ 271 | lbfgsfloatval_t min_step; 272 | 273 | /** 274 | * The maximum step of the line search. 275 | * The default value is \c 1e+20. This value need not be modified unless 276 | * the exponents are too large for the machine being used, or unless the 277 | * problem is extremely badly scaled (in which case the exponents should 278 | * be increased). 279 | */ 280 | lbfgsfloatval_t max_step; 281 | 282 | /** 283 | * A parameter to control the accuracy of the line search routine. 284 | * The default value is \c 1e-4. This parameter should be greater 285 | * than zero and smaller than \c 0.5. 286 | */ 287 | lbfgsfloatval_t ftol; 288 | 289 | /** 290 | * A coefficient for the Wolfe condition. 291 | * This parameter is valid only when the backtracking line-search 292 | * algorithm is used with the Wolfe condition, 293 | * ::LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE or 294 | * ::LBFGS_LINESEARCH_BACKTRACKING_WOLFE . 295 | * The default value is \c 0.9. This parameter should be greater than 296 | * the \ref ftol parameter and smaller than \c 1.0. 297 | */ 298 | lbfgsfloatval_t wolfe; 299 | 300 | /** 301 | * A parameter to control the accuracy of the line search routine. 302 | * The default value is \c 0.9. If the function and gradient 303 | * evaluations are inexpensive with respect to the cost of the 304 | * iteration (which is sometimes the case when solving very large 305 | * problems) it may be advantageous to set this parameter to a small 306 | * value. A typical small value is \c 0.1. This parameter should be 307 | * greater than the \ref ftol parameter (\c 1e-4) and smaller than 308 | * \c 1.0. 309 | */ 310 | lbfgsfloatval_t gtol; 311 | 312 | /** 313 | * The machine precision for floating-point values.
314 | * This parameter must be a positive value set by a client program to 315 | * estimate the machine precision. The line search routine will terminate 316 | * with the status code (::LBFGSERR_ROUNDING_ERROR) if the relative width 317 | * of the interval of uncertainty is less than this parameter. 318 | */ 319 | lbfgsfloatval_t xtol; 320 | 321 | /** 322 | * Coefficient for the L1 norm of variables. 323 | * This parameter should be set to zero for standard minimization 324 | * problems. Setting this parameter to a positive value activates 325 | * Orthant-Wise Limited-memory Quasi-Newton (OWL-QN) method, which 326 | * minimizes the objective function F(x) combined with the L1 norm |x| 327 | * of the variables, {F(x) + C |x|}. This parameter is the coefficient 328 | * for the |x|, i.e., C. As the L1 norm |x| is not differentiable at 329 | * zero, the library modifies function and gradient evaluations from 330 | * a client program suitably; a client program thus has only to return 331 | * the function value F(x) and gradients G(x) as usual. The default value 332 | * is zero. 333 | */ 334 | lbfgsfloatval_t orthantwise_c; 335 | 336 | /** 337 | * Start index for computing L1 norm of the variables. 338 | * This parameter is valid only for OWL-QN method 339 | * (i.e., \ref orthantwise_c != 0). This parameter b (0 <= b < N) 340 | * specifies the index number from which the library computes the 341 | * L1 norm of the variables x, 342 | * |x| := |x_{b}| + |x_{b+1}| + ... + |x_{N}| . 343 | * In other words, variables x_1, ..., x_{b-1} are not used for 344 | * computing the L1 norm. Setting b (0 < b < N), one can protect 345 | * variables, x_1, ..., x_{b-1} (e.g., a bias term of logistic 346 | * regression) from being regularized. The default value is zero. 347 | */ 348 | int orthantwise_start; 349 | 350 | /** 351 | * End index for computing L1 norm of the variables. 352 | * This parameter is valid only for OWL-QN method 353 | * (i.e., \ref orthantwise_c != 0).
This parameter e (0 < e <= N) 354 | * specifies the index number at which the library stops computing the 355 | * L1 norm of the variables x. 356 | */ 357 | int orthantwise_end; 358 | } lbfgs_parameter_t; 359 | 360 | 361 | /** 362 | * Callback interface to provide objective function and gradient evaluations. 363 | * 364 | * The lbfgs() function calls this function to obtain the values of objective 365 | * function and its gradients when needed. A client program must implement 366 | * this function to evaluate the values of the objective function and its 367 | * gradients, given current values of variables. 368 | * 369 | * @param instance The user data sent for lbfgs() function by the client. 370 | * @param x The current values of variables. 371 | * @param g The gradient vector. The callback function must compute 372 | * the gradient values for the current variables. 373 | * @param n The number of variables. 374 | * @param step The current step of the line search routine. 375 | * @retval lbfgsfloatval_t The value of the objective function for the current 376 | * variables. 377 | */ 378 | typedef lbfgsfloatval_t (*lbfgs_evaluate_t)( 379 | void *instance, 380 | const lbfgsfloatval_t *x, 381 | lbfgsfloatval_t *g, 382 | const int n, 383 | const lbfgsfloatval_t step 384 | ); 385 | 386 | /** 387 | * Callback interface to receive the progress of the optimization process. 388 | * 389 | * The lbfgs() function calls this function for each iteration. Implementing 390 | * this function, a client program can store or display the current progress 391 | * of the optimization process. 392 | * 393 | * @param instance The user data sent for lbfgs() function by the client. 394 | * @param x The current values of variables. 395 | * @param g The current gradient values of variables. 396 | * @param fx The current value of the objective function. 397 | * @param xnorm The Euclidean norm of the variables. 398 | * @param gnorm The Euclidean norm of the gradients.
399 | * @param step The line-search step used for this iteration. 400 | * @param n The number of variables. 401 | * @param k The iteration count. 402 | * @param ls The number of evaluations called for this iteration. 403 | * @retval int Zero to continue the optimization process. Returning a 404 | * non-zero value will cancel the optimization process. 405 | */ 406 | typedef int (*lbfgs_progress_t)( 407 | void *instance, 408 | const lbfgsfloatval_t *x, 409 | const lbfgsfloatval_t *g, 410 | const lbfgsfloatval_t fx, 411 | const lbfgsfloatval_t xnorm, 412 | const lbfgsfloatval_t gnorm, 413 | const lbfgsfloatval_t step, 414 | int n, 415 | int k, 416 | int ls 417 | ); 418 | 419 | /* 420 | A user must implement a function compatible with ::lbfgs_evaluate_t (evaluation 421 | callback) and pass the pointer to the callback function to lbfgs() arguments. 422 | Similarly, a user can implement a function compatible with ::lbfgs_progress_t 423 | (progress callback) to obtain the current progress (e.g., variables, function 424 | value, ||G||, etc) and to cancel the iteration process if necessary. 425 | Implementation of a progress callback is optional: a user can pass \c NULL if 426 | progress notification is not necessary. 427 | 428 | In addition, a user must preserve two requirements: 429 | - The number of variables must be multiples of 16 (this is not 4). 430 | - The memory block of variable array ::x must be aligned to 16. 431 | 432 | This algorithm terminates an optimization 433 | when: 434 | 435 | ||G|| < \epsilon \cdot \max(1, ||x||) . 436 | 437 | In this formula, ||.|| denotes the Euclidean norm. 438 | */ 439 | 440 | /** 441 | * Start a L-BFGS optimization. 442 | * 443 | * @param n The number of variables. 444 | * @param x The array of variables. A client program can set 445 | * default values for the optimization and receive the 446 | * optimization result through this array. 
This array 447 | * must be allocated by ::lbfgs_malloc function 448 | * for libLBFGS built with SSE/SSE2 optimization routine 449 | * enabled. The library built without SSE/SSE2 450 | * optimization does not have such a requirement. 451 | * @param ptr_fx The pointer to the variable that receives the final 452 | * value of the objective function for the variables. 453 | * This argument can be set to \c NULL if the final 454 | * value of the objective function is unnecessary. 455 | * @param proc_evaluate The callback function to provide function and 456 | * gradient evaluations given the current values of 457 | * variables. A client program must implement a 458 | * callback function compatible with \ref 459 | * lbfgs_evaluate_t and pass the pointer to the 460 | * callback function. 461 | * @param proc_progress The callback function to receive the progress 462 | * (the number of iterations, the current value of 463 | * the objective function) of the minimization 464 | * process. This argument can be set to \c NULL if 465 | * a progress report is unnecessary. 466 | * @param instance User data for the client program. The callback 467 | * functions will receive the value of this argument. 468 | * @param param The pointer to a structure representing parameters for 469 | * L-BFGS optimization. A client program can set this 470 | * parameter to \c NULL to use the default parameters. 471 | * Call lbfgs_parameter_init() function to fill a 472 | * structure with the default values. 473 | * @retval int The status code. This function returns zero if the 474 | * minimization process terminates without an error. A 475 | * non-zero value indicates an error. 476 | */ 477 | int lbfgs( 478 | int n, 479 | lbfgsfloatval_t *x, 480 | lbfgsfloatval_t *ptr_fx, 481 | lbfgs_evaluate_t proc_evaluate, 482 | lbfgs_progress_t proc_progress, 483 | void *instance, 484 | lbfgs_parameter_t *param 485 | ); 486 | 487 | /** 488 | * Initialize L-BFGS parameters to the default values.
489 | * 490 | * Call this function to fill a parameter structure with the default values 491 | * and overwrite parameter values if necessary. 492 | * 493 | * @param param The pointer to the parameter structure. 494 | */ 495 | void lbfgs_parameter_init(lbfgs_parameter_t *param); 496 | 497 | /** 498 | * Allocate an array for variables. 499 | * 500 | * This function allocates an array of variables for the convenience of 501 | * ::lbfgs function; the function has a requirement for a variable array 502 | * when libLBFGS is built with SSE/SSE2 optimization routines. A user does 503 | * not have to use this function for libLBFGS built without SSE/SSE2 504 | * optimization. 505 | * 506 | * @param n The number of variables. 507 | */ 508 | lbfgsfloatval_t* lbfgs_malloc(int n); 509 | 510 | /** 511 | * Free an array of variables. 512 | * 513 | * @param x The array of variables allocated by ::lbfgs_malloc 514 | * function. 515 | */ 516 | void lbfgs_free(lbfgsfloatval_t *x); 517 | 518 | 519 | /** 520 | * Get string description of an lbfgs() return code. 521 | * 522 | * @param err A value returned by lbfgs(). 523 | */ 524 | const char* lbfgs_strerror(int err); 525 | 526 | /** @} */ 527 | 528 | #ifdef __cplusplus 529 | } 530 | #endif/*__cplusplus*/ 531 | 532 | 533 | 534 | /** 535 | @mainpage libLBFGS: a library of Limited-memory Broyden-Fletcher-Goldfarb-Shanno (L-BFGS) 536 | 537 | @section intro Introduction 538 | 539 | This library is a C port of the implementation of Limited-memory 540 | Broyden-Fletcher-Goldfarb-Shanno (L-BFGS) method written by Jorge Nocedal. 541 | The original FORTRAN source code is available at: 542 | http://www.ece.northwestern.edu/~nocedal/lbfgs.html 543 | 544 | The L-BFGS method solves the unconstrained minimization problem, 545 | 546 |
547 |     minimize F(x), x = (x1, x2, ..., xN),
548 | 
549 | 550 | only if the objective function F(x) and its gradient G(x) are computable. The 551 | well-known Newton's method requires computation of the inverse of the hessian 552 | matrix of the objective function. However, the computational cost for the 553 | inverse hessian matrix is expensive especially when the objective function 554 | takes a large number of variables. The L-BFGS method iteratively finds a 555 | minimizer by approximating the inverse hessian matrix by information from last 556 | m iterations. This innovation saves the memory storage and computational time 557 | drastically for large-scaled problems. 558 | 559 | Among the various ports of L-BFGS, this library provides several features: 560 | - Optimization with L1-norm (Orthant-Wise Limited-memory Quasi-Newton 561 | (OWL-QN) method): 562 | In addition to standard minimization problems, the library can minimize 563 | a function F(x) combined with L1-norm |x| of the variables, 564 | {F(x) + C |x|}, where C is a constant scalar parameter. This feature is 565 | useful for estimating parameters of sparse log-linear models (e.g., 566 | logistic regression and maximum entropy) with L1-regularization (or 567 | Laplacian prior). 568 | - Clean C code: 569 | Unlike C codes generated automatically by f2c (Fortran 77 into C converter), 570 | this port includes changes based on my interpretations, improvements, 571 | optimizations, and clean-ups so that the ported code would be well-suited 572 | for a C code. In addition to comments inherited from the original code, 573 | a number of comments were added through my interpretations. 574 | - Callback interface: 575 | The library receives function and gradient values via a callback interface. 576 | The library also notifies the progress of the optimization by invoking a 577 | callback function. In the original implementation, a user had to set 578 | function and gradient values every time the function returns for obtaining 579 | updated values. 
580 | - Thread safe: 581 | The library is thread-safe, which is the secondary gain from the callback 582 | interface. 583 | - Cross platform. The source code can be compiled on Microsoft Visual 584 | Studio 2010, GNU C Compiler (gcc), etc. 585 | - Configurable precision: A user can choose single-precision (float) 586 | or double-precision (double) accuracy by changing ::LBFGS_FLOAT macro. 587 | - SSE/SSE2 optimization: 588 | This library includes SSE/SSE2 optimization (written in compiler intrinsics) 589 | for vector arithmetic operations on Intel/AMD processors. The library uses 590 | SSE for float values and SSE2 for double values. The SSE/SSE2 optimization 591 | routine is disabled by default. 592 | 593 | This library is used by: 594 | - CRFsuite: A fast implementation of Conditional Random Fields (CRFs) 595 | - Classias: A collection of machine-learning algorithms for classification 596 | - mlegp: an R package for maximum likelihood estimates for Gaussian processes 597 | - imaging2: the imaging2 class library 598 | - Algorithm::LBFGS - Perl extension for L-BFGS 599 | - YAP-LBFGS (an interface to call libLBFGS from YAP Prolog) 600 | 601 | @section download Download 602 | 603 | - Source code 604 | - GitHub repository 605 | 606 | libLBFGS is distributed under the term of the 607 | MIT license. 608 | 609 | @section changelog History 610 | - Version 1.10 (2010-12-22): 611 | - Fixed compiling errors on Mac OS X; this patch was kindly submitted by 612 | Nic Schraudolph. 613 | - Reduced compiling warnings on Mac OS X; this patch was kindly submitted 614 | by Tamas Nepusz. 615 | - Replaced memalign() with posix_memalign(). 616 | - Updated solution and project files for Microsoft Visual Studio 2010. 617 | - Version 1.9 (2010-01-29): 618 | - Fixed a mistake in checking the validity of the parameters "ftol" and 619 | "wolfe"; this was discovered by Kevin S. Van Horn. 
620 | - Version 1.8 (2009-07-13): 621 | - Accepted the patch submitted by Takashi Imamichi; 622 | the backtracking method now has three criteria for choosing the step 623 | length: 624 | - ::LBFGS_LINESEARCH_BACKTRACKING_ARMIJO: sufficient decrease (Armijo) 625 | condition only 626 | - ::LBFGS_LINESEARCH_BACKTRACKING_WOLFE: regular Wolfe condition 627 | (sufficient decrease condition + curvature condition) 628 | - ::LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE: strong Wolfe condition 629 | - Updated the documentation to explain the above three criteria. 630 | - Version 1.7 (2009-02-28): 631 | - Improved OWL-QN routines for stability. 632 | - Removed the support of OWL-QN method in MoreThuente algorithm because 633 | it accidentally fails in early stages of iterations for some objectives. 634 | Because of this change, the OWL-QN method must be used with the 635 | backtracking algorithm (::LBFGS_LINESEARCH_BACKTRACKING), or the 636 | library returns ::LBFGSERR_INVALID_LINESEARCH. 637 | - Renamed line search algorithms as follows: 638 | - ::LBFGS_LINESEARCH_BACKTRACKING: regular Wolfe condition. 639 | - ::LBFGS_LINESEARCH_BACKTRACKING_LOOSE: regular Wolfe condition. 640 | - ::LBFGS_LINESEARCH_BACKTRACKING_STRONG: strong Wolfe condition. 641 | - Source code clean-up. 642 | - Version 1.6 (2008-11-02): 643 | - Improved line-search algorithm with strong Wolfe condition, which was 644 | contributed by Takashi Imamichi. This routine is now default for 645 | ::LBFGS_LINESEARCH_BACKTRACKING. The previous line search algorithm 646 | with regular Wolfe condition is still available as 647 | ::LBFGS_LINESEARCH_BACKTRACKING_LOOSE. 648 | - Configurable stop index for L1-norm computation. A member variable 649 | ::lbfgs_parameter_t::orthantwise_end was added to specify the index 650 | number at which the library stops computing the L1 norm of the 651 | variables. This is useful to prevent some variables from being 652 | regularized by the OWL-QN method.
653 | - A sample program written in C++ (sample/sample.cpp). 654 | - Version 1.5 (2008-07-10): 655 | - Configurable starting index for L1-norm computation. A member variable 656 | ::lbfgs_parameter_t::orthantwise_start was added to specify the index 657 | number from which the library computes the L1 norm of the variables. 658 | This is useful to prevent some variables from being regularized by the 659 | OWL-QN method. 660 | - Fixed a zero-division error when the initial variables have already 661 | been a minimizer (reported by Takashi Imamichi). In this case, the 662 | library returns ::LBFGS_ALREADY_MINIMIZED status code. 663 | - Defined ::LBFGS_SUCCESS status code as zero; removed unused constants, 664 | LBFGSFALSE and LBFGSTRUE. 665 | - Fixed a compile error in an implicit down-cast. 666 | - Version 1.4 (2008-04-25): 667 | - Configurable line search algorithms. A member variable 668 | ::lbfgs_parameter_t::linesearch was added to choose either MoreThuente 669 | method (::LBFGS_LINESEARCH_MORETHUENTE) or backtracking algorithm 670 | (::LBFGS_LINESEARCH_BACKTRACKING). 671 | - Fixed a bug: the previous version did not compute pseudo-gradients 672 | properly in the line search routines for OWL-QN. This bug might quit 673 | an iteration process too early when the OWL-QN routine was activated 674 | (0 < ::lbfgs_parameter_t::orthantwise_c). 675 | - Configure script for POSIX environments. 676 | - SSE/SSE2 optimizations with GCC. 677 | - New functions ::lbfgs_malloc and ::lbfgs_free to use SSE/SSE2 routines 678 | transparently. It is unnecessary to use these functions for libLBFGS built 679 | without SSE/SSE2 routines; you can still use any memory allocators if 680 | SSE/SSE2 routines are disabled in libLBFGS. 681 | - Version 1.3 (2007-12-16): 682 | - An API change. An argument was added to lbfgs() function to receive the 683 | final value of the objective function. This argument can be set to 684 | \c NULL if the final value is unnecessary.
685 | - Fixed a null-pointer bug in the sample code (reported by Takashi Imamichi). 686 | - Added build scripts for Microsoft Visual Studio 2005 and GCC. 687 | - Added README file. 688 | - Version 1.2 (2007-12-13): 689 | - Fixed a serious bug in orthant-wise L-BFGS. 690 | An important variable was used without initialization. 691 | - Version 1.1 (2007-12-01): 692 | - Implemented orthant-wise L-BFGS. 693 | - Implemented lbfgs_parameter_init() function. 694 | - Fixed several bugs. 695 | - API documentation. 696 | - Version 1.0 (2007-09-20): 697 | - Initial release. 698 | 699 | @section api Documentation 700 | 701 | - @ref liblbfgs_api "libLBFGS API" 702 | 703 | @section sample Sample code 704 | 705 | @include sample.c 706 | 707 | @section ack Acknowledgements 708 | 709 | The L-BFGS algorithm is described in: 710 | - Jorge Nocedal. 711 | Updating Quasi-Newton Matrices with Limited Storage. 712 | Mathematics of Computation, Vol. 35, No. 151, pp. 773--782, 1980. 713 | - Dong C. Liu and Jorge Nocedal. 714 | On the limited memory BFGS method for large scale optimization. 715 | Mathematical Programming B, Vol. 45, No. 3, pp. 503-528, 1989. 716 | 717 | The line search algorithms used in this implementation are described in: 718 | - John E. Dennis and Robert B. Schnabel. 719 | Numerical Methods for Unconstrained Optimization and Nonlinear 720 | Equations, Englewood Cliffs, 1983. 721 | - Jorge J. More and David J. Thuente. 722 | Line search algorithm with guaranteed sufficient decrease. 723 | ACM Transactions on Mathematical Software (TOMS), Vol. 20, No. 3, 724 | pp. 286-307, 1994. 725 | 726 | This library also implements Orthant-Wise Limited-memory Quasi-Newton (OWL-QN) 727 | method presented in: 728 | - Galen Andrew and Jianfeng Gao. 729 | Scalable training of L1-regularized log-linear models. 730 | In Proceedings of the 24th International Conference on Machine 731 | Learning (ICML 2007), pp. 33-40, 2007. 
732 | 733 | Special thanks go to: 734 | - Yoshimasa Tsuruoka and Daisuke Okanohara for technical information about 735 | OWL-QN 736 | - Takashi Imamichi for the useful enhancements of the backtracking method 737 | - Kevin S. Van Horn, Nic Schraudolph, and Tamas Nepusz for bug fixes 738 | 739 | Finally I would like to thank the original author, Jorge Nocedal, who has been 740 | distributing the efficient and explanatory implementation under an open source 741 | licence. 742 | 743 | @section reference Reference 744 | 745 | - L-BFGS by Jorge Nocedal. 746 | - Orthant-Wise Limited-memory Quasi-Newton Optimizer for L1-regularized Objectives by Galen Andrew. 747 | - C port (via f2c) by Taku Kudo. 748 | - C#/C++/Delphi/VisualBasic6 port in ALGLIB. 749 | - Computational Crystallography Toolbox includes 750 | scitbx::lbfgs. 751 | */ 752 | 753 | #endif/*__LBFGS_H__*/ 754 | -------------------------------------------------------------------------------- /cpp/loss.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | template T safe_log(const T x) { 4 | return log(std::max(EPSILON, x)); 5 | } 6 | 7 | // Parent class for all losses 8 | template 9 | class BaseLoss { 10 | public: 11 | BaseLoss() {} 12 | BaseLoss(Kernel &ker) {} 13 | 14 | virtual double loss(const double* v1, const double* v2, int N) const = 0; 15 | virtual void gradient(const double* v1, const double* v2, int N, double* res) const = 0; 16 | }; 17 | 18 | 19 | template 20 | class QuadraticLoss : public BaseLoss { 21 | public: 22 | QuadraticLoss() {} 23 | QuadraticLoss(Kernel &ker) {} 24 | 25 | double loss(const double* v1, const double* v2, int N) const { 26 | double r = 0; 27 | for (int i=0; i 44 | class TVLoss : public BaseLoss { 45 | public: 46 | TVLoss() {} 47 | TVLoss(Kernel &ker) {} 48 | 49 | double loss(const double* v1, const double* v2, int N) const { 50 | double r = 0; 51 | for (int i=0; i0?0.5:-0.5); 63 | } 64 | 65 | } 66 | }; 67 | 68 | template 69 |
class KLLoss : public BaseLoss { 70 | public: 71 | KLLoss() {} 72 | KLLoss(Kernel &ker) {} 73 | 74 | double loss(const double* v1, const double* v2, int N) const { 75 | double r = 0; 76 | for (int i=0; i 93 | class WassersteinLoss : public BaseLoss { 94 | public: 95 | WassersteinLoss(Kernel &ker, int n_iter = 50) : kernel(&ker), num_iter(n_iter) {} 96 | 97 | Kernel * kernel; 98 | int num_iter; 99 | 100 | double loss(const double* v1, const double* v2, int N) const {return myWloss(v1,v2,N,num_iter);} 101 | 102 | double myWloss(const double* v1, const double* v2, int N, int n_iter) const { 103 | 104 | int Niters = n_iter; 105 | std::vector a(N, 1.), b(N, 1.), convolution(N); 106 | 107 | // Bregman Projections 108 | for (int iter=0; iterconvolveAdjoint(&b[0], &convolution[0], 1); 111 | for (int j=0; jconvolve(&a[0], &convolution[0], 1); 115 | for (int j=0; jconvolveAdjoint(&b[0], &convolution[0], 1); 122 | for (int j=0; jgamma*l; 127 | } 128 | 129 | void gradient(const double* v1, const double* v2, int N, double* res) const {myWgradient(v1,v2,N,res,num_iter);} 130 | 131 | void myWgradient(const double* v1, const double* v2, int N, double* res, int n_iter) const { 132 | 133 | int Niters = n_iter; 134 | std::vector a(N, 1.), b(N, 1.), convolution(N); 135 | 136 | // Bregman Projections 137 | for (int iter=0; iterconvolveAdjoint(&b[0], &convolution[0]); 140 | for (int j=0; jconvolve(&a[0], &convolution[0]); 144 | for (int j=0; jgamma*safe_log(a[i]); 151 | } 152 | } 153 | 154 | }; 155 | -------------------------------------------------------------------------------- /cpp/main_dictionary_learning.cpp: -------------------------------------------------------------------------------- 1 | // Uncomment for log-domain stabilization 2 | //#define COMPUTE_BARYCENTER_LOG 3 | 4 | #include "inverseWasserstein.h" 5 | #include "histogramIO.h" 6 | #include "cimg/CImg.h" 7 | #include "chrono.h" 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | 14 | void dict_learning_2Dimages(int 
argc, const char* argv[]) { 15 | 16 | std::string usage = "parameters are : " 17 | "-i " 18 | "-o " 19 | "[-k ] " 20 | "[-l ] " 21 | "[-n ] " 22 | "[-s ] " 23 | "[-g ] " 24 | "[-a ] " 25 | "[-x ] " 26 | "[-m ] " 27 | "[--deterministic] " 28 | "[--imComplement] " 29 | "[--allowNegWeights] " 30 | "[--warmRestart>]"; 31 | 32 | if (argc < 4 || std::string(argv[1]) != "-i" || std::string(argv[3]) != "-o") 33 | { 34 | std::cout << usage < files = get_all_files_names_within_folder(wfolder); 109 | int P = files.size(); 110 | 111 | std::vector normScalings; 112 | std::vector pdf; 113 | std::vector observedpdf; 114 | std::vector > bases(K); 115 | int W, H; 116 | for (int i=0; i=K) continue; 123 | bases[i].resize(W*H); 124 | double s=0; 125 | // Atom initialization 126 | for (int j=0; j variables(P*K+N*K); 143 | // Fill the weights 144 | for (int i=0; i KernelType; 173 | #endif 174 | #else 175 | #ifndef COMPUTE_BARYCENTER_LOG 176 | typedef Gaussian2DKernel KernelType; 177 | #else 178 | typedef LogSignArrayGaussian2DKernel KernelType; 179 | #endif 180 | #endif 181 | 182 | ExportHistogramToPNGWithScaling exporter(W,H,normScalings,outputFolder,alpha,imComplement); 183 | KernelType mkernel(gamma, W, H); 184 | 185 | // Utility class that stores all the data 186 | Problem problem(bases, observedpdf, mkernel, &variables[0], P); 187 | 188 | // Choose the loss 189 | BaseLoss * loss; 190 | switch(myLoss) 191 | { 192 | case 1: loss = new TVLoss(problem.kernel); break; 193 | case 2: loss = new QuadraticLoss(problem.kernel); break; 194 | case 3: loss = new WassersteinLoss(problem.kernel); break; 195 | case 4: loss = new KLLoss(problem.kernel); break; 196 | default : std::cerr<<"Loss index not recognized. 
Must be in {1,..,4}"< regression(&problem, Niter, GRADIENT_SINKHORN, *loss, exporter, 200 | scaleDictionary, exportAtoms, exportFittings, exportOnlyFinalSolution, warmRestart); 201 | regression.exportEveryMIter = exportEveryMIter; 202 | regression.exp_weight = !allowNegWeights; 203 | regression.wrTotalIteration = maxOptimIter; 204 | 205 | // Start regression 206 | regression.regress_both(&variables[0]); 207 | 208 | 209 | // Display final solution 210 | std::cout<<"solution (just weights+20): "< 4 | #include 5 | #include 6 | 7 | 8 | template 9 | void setSign(unsigned char *tab, const unsigned int& i, const T& val) 10 | { 11 | if(val >= 0) 12 | #pragma omp atomic 13 | tab[i/8] |= (0x01 << (i%8)); 14 | else 15 | #pragma omp atomic 16 | tab[i/8] &= ~(0x01 << (i%8)); 17 | } 18 | 19 | float getSign(unsigned char *tab, const unsigned int& i) 20 | { 21 | return (((tab[i / 8] >> (i % 8)) & 1)<<1) - 1; 22 | } 23 | 24 | template 25 | T myAbsLog(T val, unsigned char *tab, const unsigned int &i) 26 | { 27 | setSign(tab,i,val); 28 | return log(fabs(val)); 29 | } 30 | 31 | template 32 | T myAbsExp(T log_val, unsigned char *tab, const unsigned int &i) 33 | { 34 | return getSign(tab,i) * exp(log_val); 35 | } 36 | 37 | void displaySignArray(unsigned char *tab, unsigned int N) 38 | { 39 | for(unsigned int i=0; i(a[0])<(a[0])<(a[0])<(a[0])<(a[0])< 7 | #include 8 | #include 9 | #include 10 | #if defined(__linux__) && !defined(_ISOC11_SOURCE) 11 | #include 12 | #endif 13 | 14 | #include 15 | 16 | #ifndef _MSC_VER 17 | #include 18 | #include 19 | #endif 20 | 21 | #ifdef AVX_SUPPORT 22 | #define SSE_SUPPORT 23 | #endif 24 | 25 | static inline void * malloc_simd(const size_t size, const size_t alignment) 26 | { 27 | #if defined WIN32 // WIN32 28 | return _aligned_malloc(size, alignment); 29 | #elif defined __linux__ // Linux 30 | #if defined _ISOC11_SOURCE 31 | return aligned_alloc(alignment, size); 32 | #else 33 | return memalign(alignment, size); 34 | #endif 35 | #elif defined __MACH__ 
// Mac OS X
    // malloc() only guarantees 16-byte alignment here, which is too weak for 32-byte AVX loads
    void* ptr = 0;
    return posix_memalign(&ptr, alignment, size) == 0 ? ptr : 0;
#else // other (use valloc for page-aligned memory)
    return valloc(size);
#endif
}

static inline void free_simd(void* mem)
{
#if defined WIN32 // WIN32
    return _aligned_free(mem);
#elif defined __linux__ // Linux
    free(mem);
#elif defined __MACH__ // Mac OS X
    free(mem);
#else // other (use valloc for page-aligned memory)
    free(mem);
#endif
}


#if _MSC_VER
///grrr... I don't have it ! if your compiler has it and complains, just remove this approximation
static const __m256d ct_acos_1 = _mm256_set1_pd(1.43);
static const __m256d ct_acos_2 = _mm256_set1_pd(0.59);
static const __m256d ct_acos_3 = _mm256_set1_pd(1.65);
static const __m256d ct_acos_4 = _mm256_set1_pd(-1.41);
static const __m256d ct_acos_5 = _mm256_set1_pd(0.88);
static const __m256d ct_acos_6 = _mm256_set1_pd(-0.77);
static const __m256d ct_acos_7 = _mm256_set1_pd(8./3.);
static const __m256d ct_acos_8 = _mm256_set1_pd(-1./3.);
static const __m256d ct_acos_two = _mm256_set1_pd(2.);
static const __m256d ct_acos_mtwo = _mm256_set1_pd(-2.);
static const __m256d ct_acos_half = _mm256_set1_pd(0.5);
static const __m256d ct_acos_invsix = _mm256_set1_pd(1./6.);
static const __m256d ct_acos_eight = _mm256_set1_pd(8.);

static inline __m256d _mm256_acos_pd(__m256d x) {
    // _mm256_setr_pd takes its arguments in memory order, so lane i of the result is acos of lane i of x
    // (_mm256_set_pd takes them from the highest lane down and would reverse the lanes)
    return _mm256_setr_pd(acos(x.m256d_f64[0]), acos(x.m256d_f64[1]), acos(x.m256d_f64[2]), acos(x.m256d_f64[3]));
    // approximation below not precise enough

    __m256d a = _mm256_add_pd(ct_acos_1, _mm256_mul_pd(ct_acos_2, x));
    a = _mm256_mul_pd(_mm256_add_pd(a, _mm256_div_pd(_mm256_add_pd(ct_acos_two, _mm256_mul_pd(ct_acos_two, x)), a)), ct_acos_half);
    __m256d b = _mm256_add_pd(ct_acos_3, _mm256_mul_pd(ct_acos_4, x));
    b = _mm256_mul_pd(_mm256_add_pd(b, _mm256_div_pd(_mm256_add_pd(ct_acos_two,
        _mm256_mul_pd(ct_acos_mtwo, x)), b)), ct_acos_half);
    __m256d c = _mm256_add_pd(ct_acos_5, _mm256_mul_pd(ct_acos_6, x));
    c = _mm256_mul_pd(_mm256_add_pd(c, _mm256_div_pd(_mm256_sub_pd(ct_acos_two, a), c)), ct_acos_half);
    return _mm256_mul_pd(_mm256_sub_pd(_mm256_mul_pd(ct_acos_eight, _mm256_add_pd(c, _mm256_div_pd(_mm256_sub_pd(ct_acos_two, a), c))), _mm256_add_pd(b, _mm256_div_pd(_mm256_add_pd(ct_acos_two, _mm256_mul_pd(ct_acos_mtwo, x)), b))), ct_acos_invsix);
}

static const __m256d _ps256_exp_hi = _mm256_set1_pd(200.3762626647949f);
static const __m256d _ps256_exp_lo = _mm256_set1_pd(-200.3762626647949f);

static const __m256d _ps256_cephes_LOG2EF = _mm256_set1_pd(1.44269504088896341);
static const __m256d _ps256_cephes_exp_C1 = _mm256_set1_pd(0.693359375);
static const __m256d _ps256_cephes_exp_C2 = _mm256_set1_pd(-2.12194440e-4);

static const __m256d _ps256_cephes_exp_p0 = _mm256_set1_pd(1.9875691500E-4);
static const __m256d _ps256_cephes_exp_p1 = _mm256_set1_pd(1.3981999507E-3);
static const __m256d _ps256_cephes_exp_p2 = _mm256_set1_pd(8.3334519073E-3);
static const __m256d _ps256_cephes_exp_p3 = _mm256_set1_pd(4.1665795894E-2);
static const __m256d _ps256_cephes_exp_p4 = _mm256_set1_pd(1.6666665459E-1);
static const __m256d _ps256_cephes_exp_p5 = _mm256_set1_pd(5.0000001201E-1);
static const __m256d _ps256_exp_0p5 = _mm256_set1_pd(0.5);

static const __m128i _pi32_256_123 = _mm_set1_epi32(1023);


static inline __m256d _mm256_exp_pd(__m256d x) {

    // _mm256_setr_pd takes its arguments in memory order, so lane i of the result is exp of lane i of x
    // (_mm256_set_pd takes them from the highest lane down and would reverse the lanes)
    return _mm256_setr_pd(exp(x.m256d_f64[0]), exp(x.m256d_f64[1]), exp(x.m256d_f64[2]), exp(x.m256d_f64[3]));
    // approximation below not precise enough

    __m256d tmp = _mm256_setzero_pd(), fx;
    __m128i imm0;
    __m256d one = _mm256_set1_pd(1.);

    x = _mm256_min_pd(x, _ps256_exp_hi);
    x = _mm256_max_pd(x, _ps256_exp_lo);

    /* express exp(x)
       as exp(g + n*log(2)) */
    fx = _mm256_mul_pd(x, _ps256_cephes_LOG2EF);
    fx = _mm256_add_pd(fx, _ps256_exp_0p5);

    /* how to perform a floorf with SSE: just below */
    //imm0 = _mm256_cvttps_epi32(fx);
    //tmp = _mm256_cvtepi32_ps(imm0);

    tmp = _mm256_floor_pd(fx);

    /* if greater, subtract 1 */
    //v8sf mask = _mm256_cmpgt_ps(tmp, fx);
    __m256d mask = _mm256_cmp_pd(tmp, fx, _CMP_GT_OS);
    mask = _mm256_and_pd(mask, one);
    fx = _mm256_sub_pd(tmp, mask);

    tmp = _mm256_mul_pd(fx, _ps256_cephes_exp_C1);
    __m256d z = _mm256_mul_pd(fx, _ps256_cephes_exp_C2);
    x = _mm256_sub_pd(x, tmp);
    x = _mm256_sub_pd(x, z);

    z = _mm256_mul_pd(x, x);

    __m256d y = _ps256_cephes_exp_p0;
    y = _mm256_mul_pd(y, x);
    y = _mm256_add_pd(y, _ps256_cephes_exp_p1);
    y = _mm256_mul_pd(y, x);
    y = _mm256_add_pd(y, _ps256_cephes_exp_p2);
    y = _mm256_mul_pd(y, x);
    y = _mm256_add_pd(y, _ps256_cephes_exp_p3);
    y = _mm256_mul_pd(y, x);
    y = _mm256_add_pd(y, _ps256_cephes_exp_p4);
    y = _mm256_mul_pd(y, x);
    y = _mm256_add_pd(y, _ps256_cephes_exp_p5);
    y = _mm256_mul_pd(y, z);
    y = _mm256_add_pd(y, x);
    y = _mm256_add_pd(y, one);

    /* build 2^n */
    //__m128 d1 = _mm128_cvttps_epi32(fx);
    imm0 = _mm256_cvttpd_epi32(fx);
    // another two AVX2 instructions
    imm0 = _mm_add_epi32(imm0, _pi32_256_123);

    /*__m256i ik;
    ik.m256i_i64[0] = imm0.m128i_i32[0];
    ik.m256i_i64[1] = imm0.m128i_i32[1];
    ik.m256i_i64[2] = imm0.m128i_i32[2];
    ik.m256i_i64[3] = imm0.m128i_i32[3];

    ik = _mm256_slli_epi64(ik, 52);

    __m256d pow2n = _mm256_castsi256_pd(ik);
    y = _mm256_mul_pd(y, pow2n);*/

    __m256d pow2d; // my proc doesn't like _mm256_slli_epi64, my compiler doesn't have _mm_add_epi64, etc. etc. etc. !
    // grrrr
    unsigned long long ul1 = ((unsigned long long) imm0.m128i_u32[0]) << 52l;
    unsigned long long ul2 = ((unsigned long long) imm0.m128i_u32[1]) << 52l;
    unsigned long long ul3 = ((unsigned long long) imm0.m128i_u32[2]) << 52l;
    unsigned long long ul4 = ((unsigned long long) imm0.m128i_u32[3]) << 52l;
    pow2d.m256d_f64[0] = *((double*)&ul1);
    pow2d.m256d_f64[1] = *((double*)&ul2);
    pow2d.m256d_f64[2] = *((double*)&ul3);
    pow2d.m256d_f64[3] = *((double*)&ul4);

    /*long long int t = 1025ll<<52ll;
    double d = *((double*)&t);*/

    return _mm256_mul_pd(y, pow2d);

}


#else

// emulates Visual Studio's m256d_f64 fields, and missing SSE/AVX instructions
union m256d {
    struct { double m256d_f64[4]; };
    __m256d variable;
};

static inline __m256d _mm256_acos_pd(__m256d y) {
    m256d x; x.variable = y;
    // _mm256_setr_pd keeps lane order (lane i of the result is acos of lane i of y);
    // _mm256_set_pd takes its arguments from the highest lane down and would reverse the lanes
    return _mm256_setr_pd(acos(x.m256d_f64[0]), acos(x.m256d_f64[1]), acos(x.m256d_f64[2]), acos(x.m256d_f64[3]));
    // tried approximate acos, but not precise enough
}
static inline __m256d _mm256_exp_pd(__m256d y) {
    m256d x; x.variable = y;
    // same lane-order consideration as in _mm256_acos_pd above
    return _mm256_setr_pd(exp(x.m256d_f64[0]), exp(x.m256d_f64[1]), exp(x.m256d_f64[2]), exp(x.m256d_f64[3]));
    // tried approximate exp, but not precise enough
}

#endif

#ifdef SSE_SUPPORT

#ifdef AVX_SUPPORT
#define ALIGN 32
#define VECSIZEDOUBLE 4
#define VECSIZEFLOAT 8
//#define simd_dotp(x,y) _mm256_dp_pd(x,y)
#define simd_add(x,y) _mm256_add_pd(x,y)
#define simd_sub(x,y) _mm256_sub_pd(x,y)
#define simd_mul(x,y) _mm256_mul_pd(x,y)
#define simd_div(x,y) _mm256_div_pd(x,y)
#define simd_max(x,y) _mm256_max_pd(x,y)
#define simd_load(x) _mm256_load_pd(x)
#define simd_store(x,y) _mm256_store_pd(x,y)
#define simd_set1(x) _mm256_set1_pd(x)
#define
simd_or(x,y) _mm256_or_pd(x,y) 227 | #define simd_gt(x,y) _mm256_cmp_pd(x,y,_CMP_GT_OS) 228 | #define simd_and(x,y) _mm256_and_pd(x,y) 229 | #define simd_andnot(x,y) _mm256_andnot_pd(x,y) 230 | 231 | #define simd_dotp_f(x,y) _mm256_dp_ps(x,y) 232 | #define simd_add_f(x,y) _mm256_add_ps(x,y) 233 | #define simd_sub_f(x,y) _mm256_sub_ps(x,y) 234 | #define simd_mul_f(x,y) _mm256_mul_ps(x,y) 235 | #define simd_div_f(x,y) _mm256_div_ps(x,y) 236 | #define simd_max_f(x,y) _mm256_max_ps(x,y) 237 | #define simd_load_f(x) _mm256_load_ps(x) 238 | #define simd_store_f(x,y) _mm256_store_ps(x,y) 239 | #define simd_set1_f(x) _mm256_set1_ps(x) 240 | #define simd_or_f(x,y) _mm256_or_ps(x,y) 241 | #define simd_gt_f(x,y) _mm256_cmp_ps(x,y,_CMP_GT_OS) 242 | #define simd_and_f(x,y) _mm256_and_ps(x,y) 243 | #define simd_andnot_f(x,y) _mm256_andnot_ps(x,y) 244 | 245 | typedef __m256d simd_double; 246 | typedef __m256 simd_float; 247 | 248 | 249 | 250 | static inline double dotp16(const double* u, const double* v) { 251 | __m256d xy0 = _mm256_mul_pd(simd_load(u), simd_load(v)); 252 | __m256d xy1 = _mm256_mul_pd(simd_load(u+4), simd_load(v+4)); 253 | __m256d xy2 = _mm256_mul_pd(simd_load(u+8), simd_load(v+8)); 254 | __m256d xy3 = _mm256_mul_pd(simd_load(u+12), simd_load(v+12)); 255 | 256 | __m256d dotproduct = _mm256_add_pd(_mm256_add_pd(xy0, xy1), _mm256_add_pd(xy2, xy3)); 257 | #ifdef _MSC_VER 258 | __m256d s = _mm256_hadd_pd(dotproduct, dotproduct); 259 | return s.m256d_f64[0]+s.m256d_f64[2]; 260 | #else 261 | m256d s; 262 | s.variable = _mm256_hadd_pd(dotproduct, dotproduct); 263 | return s.m256d_f64[0]+s.m256d_f64[2]; 264 | #endif 265 | } 266 | 267 | static inline double dotp32(const double* u, const double* v) { 268 | __m256d xy0 = _mm256_mul_pd(simd_load(u), simd_load(v)); 269 | __m256d xy1 = _mm256_mul_pd(simd_load(u+4), simd_load(v+4)); 270 | __m256d xy2 = _mm256_mul_pd(simd_load(u+8), simd_load(v+8)); 271 | __m256d xy3 = _mm256_mul_pd(simd_load(u+12), simd_load(v+12)); 272 | 273 | 
__m256d xy4 = _mm256_mul_pd(simd_load(u+16), simd_load(v+16)); 274 | __m256d xy5 = _mm256_mul_pd(simd_load(u+20), simd_load(v+20)); 275 | __m256d xy6 = _mm256_mul_pd(simd_load(u+24), simd_load(v+24)); 276 | __m256d xy7 = _mm256_mul_pd(simd_load(u+28), simd_load(v+28)); 277 | 278 | __m256d dotproduct1 = _mm256_add_pd(_mm256_add_pd(xy0, xy1), _mm256_add_pd(xy2, xy3)); 279 | __m256d dotproduct2 = _mm256_add_pd(_mm256_add_pd(xy4, xy5), _mm256_add_pd(xy6, xy7)); 280 | __m256d dotproduct = _mm256_add_pd(dotproduct1, dotproduct2); 281 | #ifdef _MSC_VER 282 | __m256d s = _mm256_hadd_pd(dotproduct, dotproduct); 283 | return s.m256d_f64[0]+s.m256d_f64[2]; 284 | #else 285 | m256d s; 286 | s.variable = _mm256_hadd_pd(dotproduct, dotproduct); 287 | return s.m256d_f64[0]+s.m256d_f64[2]; 288 | #endif 289 | } 290 | 291 | static inline double dotp64(const double* u, const double* v) { 292 | __m256d xy0 = _mm256_mul_pd(simd_load(u), simd_load(v)); 293 | __m256d xy1 = _mm256_mul_pd(simd_load(u+4), simd_load(v+4)); 294 | __m256d xy2 = _mm256_mul_pd(simd_load(u+8), simd_load(v+8)); 295 | __m256d xy3 = _mm256_mul_pd(simd_load(u+12), simd_load(v+12)); 296 | 297 | __m256d xy4 = _mm256_mul_pd(simd_load(u+16), simd_load(v+16)); 298 | __m256d xy5 = _mm256_mul_pd(simd_load(u+20), simd_load(v+20)); 299 | __m256d xy6 = _mm256_mul_pd(simd_load(u+24), simd_load(v+24)); 300 | __m256d xy7 = _mm256_mul_pd(simd_load(u+28), simd_load(v+28)); 301 | 302 | __m256d xy8 = _mm256_mul_pd(simd_load(u+32), simd_load(v+32)); 303 | __m256d xy9 = _mm256_mul_pd(simd_load(u+36), simd_load(v+36)); 304 | __m256d xy10 = _mm256_mul_pd(simd_load(u+40), simd_load(v+40)); 305 | __m256d xy11 = _mm256_mul_pd(simd_load(u+44), simd_load(v+44)); 306 | 307 | __m256d xy12 = _mm256_mul_pd(simd_load(u+48), simd_load(v+48)); 308 | __m256d xy13 = _mm256_mul_pd(simd_load(u+52), simd_load(v+52)); 309 | __m256d xy14 = _mm256_mul_pd(simd_load(u+56), simd_load(v+56)); 310 | __m256d xy15 = _mm256_mul_pd(simd_load(u+60), simd_load(v+60)); 
311 | 312 | __m256d dotproduct1 = _mm256_add_pd(_mm256_add_pd(xy0, xy1), _mm256_add_pd(xy2, xy3)); 313 | __m256d dotproduct2 = _mm256_add_pd(_mm256_add_pd(xy4, xy5), _mm256_add_pd(xy6, xy7)); 314 | __m256d dotproduct3 = _mm256_add_pd(_mm256_add_pd(xy8, xy9), _mm256_add_pd(xy10, xy11)); 315 | __m256d dotproduct4 = _mm256_add_pd(_mm256_add_pd(xy12, xy13), _mm256_add_pd(xy14, xy15)); 316 | 317 | __m256d dotproductA = _mm256_add_pd(dotproduct1, dotproduct2); 318 | __m256d dotproductB = _mm256_add_pd(dotproduct3, dotproduct4); 319 | 320 | __m256d dotproduct = _mm256_add_pd(dotproductA, dotproductB); 321 | #ifdef _MSC_VER 322 | __m256d s = _mm256_hadd_pd(dotproduct, dotproduct); 323 | return s.m256d_f64[0]+s.m256d_f64[2]; 324 | #else 325 | m256d s; 326 | s.variable = _mm256_hadd_pd(dotproduct, dotproduct); 327 | return s.m256d_f64[0]+s.m256d_f64[2]; 328 | #endif 329 | } 330 | 331 | static inline double dotp128(const double* u, const double* v) { 332 | __m256d xy0 = _mm256_mul_pd(simd_load(u), simd_load(v)); 333 | __m256d xy1 = _mm256_mul_pd(simd_load(u+4), simd_load(v+4)); 334 | __m256d xy2 = _mm256_mul_pd(simd_load(u+8), simd_load(v+8)); 335 | __m256d xy3 = _mm256_mul_pd(simd_load(u+12), simd_load(v+12)); 336 | 337 | __m256d xy4 = _mm256_mul_pd(simd_load(u+16), simd_load(v+16)); 338 | __m256d xy5 = _mm256_mul_pd(simd_load(u+20), simd_load(v+20)); 339 | __m256d xy6 = _mm256_mul_pd(simd_load(u+24), simd_load(v+24)); 340 | __m256d xy7 = _mm256_mul_pd(simd_load(u+28), simd_load(v+28)); 341 | 342 | __m256d xy8 = _mm256_mul_pd(simd_load(u+32), simd_load(v+32)); 343 | __m256d xy9 = _mm256_mul_pd(simd_load(u+36), simd_load(v+36)); 344 | __m256d xy10 = _mm256_mul_pd(simd_load(u+40), simd_load(v+40)); 345 | __m256d xy11 = _mm256_mul_pd(simd_load(u+44), simd_load(v+44)); 346 | 347 | __m256d xy12 = _mm256_mul_pd(simd_load(u+48), simd_load(v+48)); 348 | __m256d xy13 = _mm256_mul_pd(simd_load(u+52), simd_load(v+52)); 349 | __m256d xy14 = _mm256_mul_pd(simd_load(u+56), 
simd_load(v+56)); 350 | __m256d xy15 = _mm256_mul_pd(simd_load(u+60), simd_load(v+60)); 351 | 352 | __m256d xy16 = _mm256_mul_pd(simd_load(u+64), simd_load(v+64)); 353 | __m256d xy17 = _mm256_mul_pd(simd_load(u+68), simd_load(v+68)); 354 | __m256d xy18 = _mm256_mul_pd(simd_load(u+72), simd_load(v+72)); 355 | __m256d xy19 = _mm256_mul_pd(simd_load(u+76), simd_load(v+76)); 356 | 357 | __m256d xy20 = _mm256_mul_pd(simd_load(u+80), simd_load(v+80)); 358 | __m256d xy21 = _mm256_mul_pd(simd_load(u+84), simd_load(v+84)); 359 | __m256d xy22 = _mm256_mul_pd(simd_load(u+88), simd_load(v+88)); 360 | __m256d xy23 = _mm256_mul_pd(simd_load(u+92), simd_load(v+92)); 361 | 362 | __m256d xy24 = _mm256_mul_pd(simd_load(u+96), simd_load(v+96)); 363 | __m256d xy25 = _mm256_mul_pd(simd_load(u+100), simd_load(v+100)); 364 | __m256d xy26 = _mm256_mul_pd(simd_load(u+104), simd_load(v+104)); 365 | __m256d xy27 = _mm256_mul_pd(simd_load(u+108), simd_load(v+108)); 366 | 367 | __m256d xy28 = _mm256_mul_pd(simd_load(u+112), simd_load(v+112)); 368 | __m256d xy29 = _mm256_mul_pd(simd_load(u+116), simd_load(v+116)); 369 | __m256d xy30 = _mm256_mul_pd(simd_load(u+120), simd_load(v+120)); 370 | __m256d xy31 = _mm256_mul_pd(simd_load(u+124), simd_load(v+124)); 371 | 372 | __m256d dotproduct1a = _mm256_add_pd(_mm256_add_pd(xy0, xy1), _mm256_add_pd(xy2, xy3)); 373 | __m256d dotproduct2a = _mm256_add_pd(_mm256_add_pd(xy4, xy5), _mm256_add_pd(xy6, xy7)); 374 | __m256d dotproduct3a = _mm256_add_pd(_mm256_add_pd(xy8, xy9), _mm256_add_pd(xy10, xy11)); 375 | __m256d dotproduct4a = _mm256_add_pd(_mm256_add_pd(xy12, xy13), _mm256_add_pd(xy14, xy15)); 376 | 377 | __m256d dotproduct1b = _mm256_add_pd(_mm256_add_pd(xy16, xy17), _mm256_add_pd(xy18, xy19)); 378 | __m256d dotproduct2b = _mm256_add_pd(_mm256_add_pd(xy20, xy21), _mm256_add_pd(xy22, xy23)); 379 | __m256d dotproduct3b = _mm256_add_pd(_mm256_add_pd(xy24, xy25), _mm256_add_pd(xy26, xy27)); 380 | __m256d dotproduct4b = _mm256_add_pd(_mm256_add_pd(xy28, 
xy29), _mm256_add_pd(xy30, xy31)); 381 | 382 | 383 | __m256d dotproductA = _mm256_add_pd(dotproduct1a, dotproduct2a); 384 | __m256d dotproductB = _mm256_add_pd(dotproduct3a, dotproduct4a); 385 | __m256d dotproductC = _mm256_add_pd(dotproduct1b, dotproduct2b); 386 | __m256d dotproductD = _mm256_add_pd(dotproduct3b, dotproduct4b); 387 | 388 | 389 | __m256d dotproduct1 = _mm256_add_pd(dotproductA, dotproductB); 390 | __m256d dotproduct2 = _mm256_add_pd(dotproductC, dotproductD); 391 | 392 | __m256d dotproduct = _mm256_add_pd(dotproduct1, dotproduct2); 393 | #ifdef _MSC_VER 394 | __m256d s = _mm256_hadd_pd(dotproduct, dotproduct); 395 | return s.m256d_f64[0]+s.m256d_f64[2]; 396 | #else 397 | m256d s; 398 | s.variable = _mm256_hadd_pd(dotproduct, dotproduct); 399 | return s.m256d_f64[0]+s.m256d_f64[2]; 400 | #endif 401 | } 402 | 403 | static inline double dotp_full(const double* u, const double* v, int n) { 404 | 405 | int k=0; 406 | const int max128 = n-128+1; 407 | const int max64 = n-64+1; 408 | const int max16 = n-16+1; 409 | double conv = 0; 410 | for (; k 468 | class aligned_allocator { 469 | public: 470 | 471 | typedef size_t size_type; 472 | typedef ptrdiff_t difference_type; 473 | typedef T* pointer; 474 | typedef const T* const_pointer; 475 | typedef T& reference; 476 | typedef const T& const_reference; 477 | typedef T value_type; 478 | 479 | 480 | template 481 | struct rebind { 482 | typedef aligned_allocator other; 483 | }; 484 | 485 | 486 | pointer address(reference value) const { 487 | return &value; 488 | }; 489 | 490 | const_pointer address(const_reference value) const { 491 | return &value; 492 | }; 493 | 494 | 495 | aligned_allocator() throw() { 496 | }; 497 | 498 | aligned_allocator(const aligned_allocator&) throw() { 499 | }; 500 | 501 | template 502 | aligned_allocator(const aligned_allocator&) throw() { 503 | }; 504 | 505 | ~aligned_allocator() throw() { 506 | }; 507 | 508 | //max capacity 509 | size_type max_size() const throw() { 510 | return 
std::numeric_limits::max(); 511 | }; 512 | 513 | 514 | pointer allocate(size_type num, const_pointer *hint = 0) { 515 | 516 | return (pointer)malloc_simd(num*sizeof(T), Alignment); 517 | }; 518 | 519 | 520 | void construct(pointer p, const T& value) { 521 | 522 | // memcpy( p, &value, sizeof T ); 523 | *p=value; 524 | // new ( (void *) p ) T ( value ); 525 | }; 526 | 527 | 528 | void destroy(pointer p) { 529 | 530 | p->~T(); 531 | }; 532 | 533 | 534 | void deallocate(pointer p, size_type num) { 535 | 536 | free_simd(p); 537 | }; 538 | }; 539 | -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-00.png -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-01.png -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-02.png -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-03.png 
-------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-04.png -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-05.png -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-06.png -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-07.png -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-08.png -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-09.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-09.png -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-10.png -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-11.png -------------------------------------------------------------------------------- /data/imgheart2/heart2_seq-12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/imgheart2/heart2_seq-12.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0021.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0023.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0023.png -------------------------------------------------------------------------------- 
/data/mug_001_expr2/img_0026.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0026.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0028.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0028.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0029.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0029.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0032.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0032.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0034.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0034.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0036.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0036.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0038.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0038.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0039.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0039.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0040.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0040.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0041.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0041.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0043.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0043.png -------------------------------------------------------------------------------- 
/data/mug_001_expr2/img_0044.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0044.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0045.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0045.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0046.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0046.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0047.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0047.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0050.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0110.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0110.png -------------------------------------------------------------------------------- /data/mug_001_expr2/img_0117.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthieuheitz/WassersteinDictionaryLearning/5cc59a486ec1f122403d324e8882bec810440624/data/mug_001_expr2/img_0117.png --------------------------------------------------------------------------------