├── README.md ├── cma ├── LICENSE ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── bbobbenchmarks.cpython-39.pyc │ ├── constraints_handler.cpython-39.pyc │ ├── evolution_strategy.cpython-39.pyc │ ├── fitness_functions.cpython-39.pyc │ ├── fitness_transformations.cpython-39.pyc │ ├── interfaces.cpython-39.pyc │ ├── logger.cpython-39.pyc │ ├── optimization_tools.cpython-39.pyc │ ├── purecma.cpython-39.pyc │ ├── recombination_weights.cpython-39.pyc │ ├── restricted_gaussian_sampler.cpython-39.pyc │ ├── s.cpython-39.pyc │ ├── sampler.cpython-39.pyc │ ├── sigma_adaptation.cpython-39.pyc │ └── transformations.cpython-39.pyc ├── bbobbenchmarks.py ├── constraints_handler.py ├── evolution_strategy.py ├── fitness_functions.py ├── fitness_models.py ├── fitness_transformations.py ├── interfaces.py ├── logger.py ├── optimization_tools.py ├── purecma.py ├── recombination_weights.py ├── restricted_gaussian_sampler.py ├── s.py ├── sampler.py ├── sigma_adaptation.py ├── test.py ├── transformations.py ├── utilities │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── math.cpython-39.pyc │ │ ├── python3for2.cpython-39.pyc │ │ └── utils.cpython-39.pyc │ ├── math.py │ ├── python3for2.py │ └── utils.py └── wrapper.py ├── embeding_distribution.py ├── figures ├── case.png ├── cma.gif └── framework.png ├── infer_inversion.py ├── initialize_inversion.py ├── train_inversion.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Gradient-Free Textual Inversion 2 | 3 | Gradient-free textual inversion for personalized text-to-image generation. 4 | We use the gradient-free evolution strategy from [OpenAI](https://openai.com/blog/evolution-strategies/) to optimize the pseudo-word embeddings. 5 | Our implementation is fully compatible with [diffusers](https://github.com/huggingface/diffusers) and the Stable Diffusion model. 6 | 7 |

8 | cma process 9 |
10 | 11 | Evolution process for textual embeddings. 12 | 13 |

14 | 15 | 16 | ## What does this repo do? 17 | 18 | Current personalized text-to-image approaches, which learn to bind a unique identifier to specific subjects or styles from a few given images, usually incorporate a special word and tune its embedding parameters through gradient descent. 19 | It is natural to ask whether the textual inversion can be optimized by accessing only the model's inference, since requiring only forward computation retains the benefits of efficient computation and safe deployment. 20 | 21 | To this end, we introduce a gradient-free framework to optimize the continuous textual inversion in personalized text-to-image generation. 22 | Specifically, we first initialize the textual inversion with non-parametric cross-attention so that it lies within the latent embedding space. 23 | Then, instead of optimizing in the original high-dimensional embedding space, which is intractable for derivative-free optimization, we perform optimization in a decomposed subspace with (i) PCA and (ii) prior normalization through an *iterative* evolution strategy. 24 | 25 |

26 | gradient-free textual inversion framework 27 |
28 | 29 | Overview of the proposed gradient-free textual inversion framework. 30 | 31 |

32 | 33 | 34 | 35 | ## Cases 36 | 37 | Some examples generated by standard textual inversion and by gradient-free inversion, both based on the Stable Diffusion model. 38 | 39 | 40 |

41 | figures/cases 42 |
43 | 44 | Examples of personalized text-to-image generation. 45 | 46 |

47 | 48 | 49 | ## Process 50 | 51 | 52 | To initialize the textual inversion automatically with cross-attention, run: 53 | ``` 54 | python initialize_inversion.py 55 | ``` 56 | 57 | Then, to iteratively optimize the textual inversion with the gradient-free evolution strategy, run: 58 | ``` 59 | python train_inversion.py 60 | ``` 61 | 62 | Finally, with the trained textual inversion, you can generate personalized images with the ```infer_inversion.py``` script. 63 | 64 | 65 | 66 | ## Acknowledgements 67 | 68 | This repository is based on [diffusers](https://github.com/huggingface/diffusers) and its [textual inversion](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) example script. Thanks for their clear code. 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /cma/LICENSE: -------------------------------------------------------------------------------- 1 | The BSD 3-Clause License 2 | Copyright (c) 2014 Inria 3 | Author: Nikolaus Hansen, 2008- 4 | Author: Petr Baudis, 2014 5 | Author: Youhei Akimoto, 2016- 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions 9 | are met: 10 | 11 | 1. Redistributions of source code must retain the above copyright and 12 | authors notice, this list of conditions and the following disclaimer. 13 | 14 | 2. Redistributions in binary form must reproduce the above copyright 15 | and authors notice, this list of conditions and the following 16 | disclaimer in the documentation and/or other materials provided with 17 | the distribution. 18 | 19 | 3. Neither the name of the copyright holder nor the names of its 20 | contributors nor the authors names may be used to endorse or promote 21 | products derived from this software without specific prior written 22 | permission.
23 | 24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 27 | AUTHORS OR CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES 28 | OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 29 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 30 | DEALINGS IN THE SOFTWARE. 31 | -------------------------------------------------------------------------------- /cma/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Package `cma` implements the CMA-ES (Covariance Matrix Adaptation 3 | Evolution Strategy). 4 | 5 | CMA-ES is a stochastic optimizer for robust non-linear non-convex 6 | derivative- and function-value-free numerical optimization. 7 | 8 | This implementation can be used with Python versions >= 2.7, namely, 9 | it was tested with 2.7, 3.5, 3.6, 3.7, 3.8. 10 | 11 | CMA-ES searches for a minimizer (a solution x in :math:`R^n`) of an 12 | objective function f (cost function), such that f(x) is minimal. 13 | Regarding f, only a passably reliable ranking of the candidate 14 | solutions in each iteration is necessary. Neither the function values 15 | itself, nor the gradient of f need to be available or do matter (like 16 | in the downhill simplex Nelder-Mead algorithm). Some termination 17 | criteria however depend on actual f-values. 18 | 19 | The `cma` module provides two independent implementations of the 20 | CMA-ES algorithm in the classes `cma.CMAEvolutionStrategy` and 21 | `cma.purecma.CMAES`. 22 | 23 | In each implementation two interfaces are provided: 24 | 25 | - functions `fmin2` and `purecma.fmin`: 26 | run a complete minimization of the passed objective function with 27 | CMA-ES. `fmin` also provides optional restarts and noise handling. 
28 | 29 | - class `CMAEvolutionStrategy` and `purecma.CMAES`: 30 | allow for minimization such that the control of the iteration 31 | loop remains with the user. 32 | 33 | The `cma` package root provides shortcuts to these and other classes and 34 | functions. 35 | 36 | Used external packages are `numpy` (only `purecma` does not depend on 37 | `numpy`) and `matplotlib.pyplot` (for `plot` etc., optional but highly 38 | recommended). 39 | 40 | Install 41 | ======= 42 | To use the module, the folder ``cma`` only needs to be visible in the 43 | python path, e.g. in the current working directory. 44 | 45 | To install the module from pipy, type:: 46 | 47 | pip install cma 48 | 49 | from the command line. 50 | 51 | To install the module from a ``cma`` folder:: 52 | 53 | pip install -e cma 54 | 55 | To upgrade the currently installed version use additionally the ``-U`` 56 | option. 57 | 58 | Testing 59 | ======= 60 | From the system shell:: 61 | 62 | python -m cma.test -h 63 | python -m cma.test 64 | python -c "import cma.test; cma.test.main()" # the same 65 | 66 | or from any (i)python shell:: 67 | 68 | import cma.test 69 | cma.test.main() 70 | 71 | should run without complaints in about between 20 and 100 seconds. 72 | 73 | Example 74 | ======= 75 | From a python shell:: 76 | 77 | import cma 78 | help(cma) # "this" help message, use cma? 
in ipython 79 | help(cma.fmin) 80 | help(cma.CMAEvolutionStrategy) 81 | help(cma.CMAOptions) 82 | cma.CMAOptions('tol') # display 'tolerance' termination options 83 | cma.CMAOptions('verb') # display verbosity options 84 | res = cma.fmin(cma.ff.tablet, 15 * [1], 1) 85 | es = cma.CMAEvolutionStrategy(15 * [1], 1).optimize(cma.ff.tablet) 86 | help(es.result) 87 | res[0], es.result[0] # best evaluated solution 88 | res[5], es.result[5] # mean solution, presumably better with noise 89 | 90 | :See also: `fmin` (), `CMAOptions`, `CMAEvolutionStrategy` 91 | 92 | :Author: Nikolaus Hansen, 2008- 93 | :Author: Petr Baudis, 2014 94 | :Author: Youhei Akimoto, 2017- 95 | 96 | :License: BSD 3-Clause, see LICENSE file. 97 | 98 | """ 99 | 100 | # How to create a html documentation file: 101 | # pydoctor --docformat=restructuredtext --make-html cma 102 | # old: 103 | # pydoc -w cma # edit the header (remove local pointers) 104 | # epydoc cma.py # comes close to javadoc but does not find the 105 | # # links of function references etc 106 | # doxygen needs @package cma as first line in the module docstring 107 | # some things like class attributes are not interpreted correctly 108 | # sphinx: doc style of doc.python.org, could not make it work (yet) 109 | # __docformat__ = "reStructuredText" # this hides some comments entirely? 110 | 111 | from __future__ import absolute_import # now local imports must use . 112 | # big difference between PY2 and PY3: 113 | from __future__ import division 114 | from __future__ import print_function 115 | # only necessary for python 2.5 (not supported) and not in heavy use 116 | from __future__ import with_statement 117 | ___author__ = "Nikolaus Hansen and Petr Baudis and Youhei Akimoto" 118 | __license__ = "BSD 3-clause" 119 | 120 | import warnings as _warnings 121 | 122 | # __package__ = 'cma' 123 | from . 
import purecma 124 | try: 125 | import numpy 126 | del numpy 127 | except ImportError: 128 | _warnings.warn('Only `cma.purecma` has been imported. Install `numpy` ("pip' 129 | ' install numpy") if you want to import the entire `cma`' 130 | ' package.') 131 | else: 132 | from . import (constraints_handler, evolution_strategy, fitness_functions, 133 | fitness_transformations, interfaces, optimization_tools, 134 | sampler, sigma_adaptation, transformations, utilities, 135 | ) 136 | # from . import test # gives a warning with python -m cma.test (since Python 3.5.3?) 137 | test = 'type "import cma.test" to access the `test` module of `cma`' 138 | from . import s 139 | from .fitness_functions import ff 140 | from .fitness_transformations import GlueArguments, ScaleCoordinates 141 | from .evolution_strategy import fmin, fmin2, fmin_con, fmin_con2, CMAEvolutionStrategy, CMAOptions 142 | from .logger import disp, plot, CMADataLogger 143 | from .optimization_tools import NoiseHandler 144 | from .constraints_handler import BoundPenalty, BoundTransform, ConstrainedFitnessAL 145 | from .evolution_strategy import cma_default_options_ 146 | 147 | del division, print_function, absolute_import, with_statement #, unicode_literals 148 | 149 | # fcts = ff # historical reasons only, replace cma.fcts with cma.ff first 150 | 151 | __version__ = "3.2.2" 152 | # $Source$ # according to PEP 8 style guides, but what is it good for? 
153 | # $Id: __init__.py 4432 2020-05-28 18:39:09Z hansen $ 154 | # bash $: svn propset svn:keywords 'Date Revision Id' __init__.py 155 | -------------------------------------------------------------------------------- /cma/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/bbobbenchmarks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/bbobbenchmarks.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/constraints_handler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/constraints_handler.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/evolution_strategy.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/evolution_strategy.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/fitness_functions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/fitness_functions.cpython-39.pyc 
-------------------------------------------------------------------------------- /cma/__pycache__/fitness_transformations.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/fitness_transformations.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/interfaces.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/interfaces.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/logger.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/optimization_tools.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/optimization_tools.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/purecma.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/purecma.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/recombination_weights.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/recombination_weights.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/restricted_gaussian_sampler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/restricted_gaussian_sampler.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/s.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/s.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/sampler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/sampler.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/sigma_adaptation.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/sigma_adaptation.cpython-39.pyc -------------------------------------------------------------------------------- /cma/__pycache__/transformations.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/__pycache__/transformations.cpython-39.pyc -------------------------------------------------------------------------------- /cma/fitness_functions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """versatile container for test objective functions. 3 | 4 | For the time being this is probably best used like:: 5 | 6 | from cma.fitness_functions import ff 7 | 8 | """ 9 | 10 | from __future__ import (absolute_import, division, print_function, 11 | ) # unicode_literals, with_statement) 12 | # from __future__ import collections.MutableMapping 13 | # does not exist in future, otherwise Python 2.5 would work, since 0.91.01 14 | from .utilities.python3for2 import range 15 | 16 | import os, sys 17 | 18 | import numpy as np 19 | # arange, cos, size, eye, inf, dot, floor, outer, zeros, linalg.eigh, 20 | # sort, argsort, random, ones,... 21 | from numpy import array, dot, isscalar, sum # sum is not needed 22 | # from numpy import inf, exp, log, isfinite 23 | # to access the built-in sum fct: ``__builtins__.sum`` or ``del sum`` 24 | # removes the imported sum and recovers the shadowed build-in 25 | try: np.median([1,2,3,2]) # fails currently in pypy, also sigma_vec.scaling 26 | except AttributeError: 27 | def _median(x): 28 | x = sorted(x) 29 | if len(x) % 2: 30 | return x[len(x) // 2] 31 | return (x[len(x) // 2 - 1] + x[len(x) // 2]) / 2 32 | np.median = _median 33 | from .transformations import Rotation 34 | from .utilities import utils 35 | from .utilities.utils import rglen 36 | del (division, print_function, absolute_import, 37 | ) # unicode_literals, with_statement) 38 | 39 | # $Source$ # according to PEP 8 style guides, but what is it good for? 
40 | # $Id: fitness_functions.py 4150 2015-03-20 13:53:36Z hansen $ 41 | # bash $: svn propset svn:keywords 'Date Revision Id' fitness_functions.py 42 | 43 | try: 44 | from . import bbobbenchmarks 45 | BBOB = bbobbenchmarks # for backwards compatibility 46 | except ImportError: 47 | BBOB = """Call:: 48 | cma.ff.fetch_bbob_fcts() 49 | to download and extract `bbobbenchmarks.py` and thereby setting 50 | cma.ff.BBOB to these benchmarks; then, e.g., `F12 = cma.ff.BBOB.F12()` 51 | returns an instance of F12 Bent Cigar. 52 | 53 | CAVEAT: in the downloaded `bbobbenchmarks.py` file in L987 54 | ``np.negative(idx)`` needs to be replaced by ``~idx``. 55 | """ 56 | from .fitness_transformations import rotate #, ComposedFunction, Function 57 | 58 | def elli(x, cond=1e6): 59 | """unbound test function, needed to test multiprocessor, as long 60 | as the other test functions are defined within a class and 61 | only accessable via the class instance""" 62 | return sum(cond**(np.arange(len(x)) / (len(x) - 1 + 1e-9)) * np.asarray(x)**2) 63 | def sphere(x): 64 | return sum(np.asarray(x)**2) 65 | 66 | def _iqr(x): 67 | x = sorted(x) 68 | i1 = int(len(x) / 4) 69 | i3 = int(3*len(x) / 4) 70 | return x[i3] - x[i1] 71 | 72 | class FitnessFunctions(object): # TODO: this class is not necessary anymore? But some effort is needed to change it 73 | """collection of objective functions. 74 | 75 | """ 76 | evaluations = 0 # number of calls or any other practical use 77 | def __init__(self): 78 | """""" 79 | @property # avoid pickle and multiprocessing error 80 | def BBOB(self): 81 | return bbobbenchmarks 82 | try: BBOB.__doc__ = bbobbenchmarks.__doc__ 83 | except: pass # in Python 2 __doc__ is readonly 84 | 85 | def rot(self, x, fun, rot=1, args=()): 86 | """returns ``fun(rotation(x), *args)``, ie. 
`fun` applied to a rotated argument""" 87 | if len(np.shape(array(x))) > 1: # parallelized 88 | res = [] 89 | for x in x: 90 | res.append(self.rot(x, fun, rot, args)) 91 | return res 92 | 93 | if rot: 94 | return fun(rotate(x, *args)) 95 | else: 96 | return fun(x) 97 | def somenan(self, x, fun, p=0.1): 98 | """returns sometimes np.NaN, otherwise fun(x)""" 99 | if np.random.rand(1) < p: 100 | return np.NaN 101 | else: 102 | return fun(x) 103 | 104 | def epslow(self, fun, eps=1e-7, Neff=lambda x: int(len(x)**0.5)): 105 | return lambda x: fun(x[:Neff(x)]) + eps * np.mean(x**2) 106 | 107 | def rand(self, x): 108 | """Random test objective function""" 109 | return np.random.random(1)[0] 110 | def linear(self, x): 111 | return -x[0] 112 | def lineard(self, x): 113 | if 1 < 3 and any(array(x) < 0): 114 | return np.nan 115 | if 1 < 3 and sum([(10 + i) * x[i] for i in rglen(x)]) > 50e3: 116 | return np.nan 117 | return -sum(x) 118 | def sphere(self, x): 119 | """Sphere (squared norm) test objective function""" 120 | # return np.random.rand(1)[0]**0 * sum(x**2) + 1 * np.random.rand(1)[0] 121 | return sum((x + 0)**2) 122 | def subspace_sphere(self, x, visible_ratio=1/2): 123 | """ 124 | """ 125 | # here we could use an init function, that is this would 126 | # preferably be a class 127 | m = int(visible_ratio * len(x) + 1) 128 | x = np.asarray(x)[np.random.permutation(len(x))[:m]] 129 | return sum(x**2) 130 | def pnorm(self, x, p=0.5): 131 | return sum(np.abs(x)**p)**(1./p) 132 | def grad_sphere(self, x, *args): 133 | return 2*array(x, copy=False) 134 | def grad_to_one(self, x, *args): 135 | return array(x, copy=False) - 1 136 | def sphere_pos(self, x): 137 | """Sphere (squared norm) test objective function""" 138 | # return np.random.rand(1)[0]**0 * sum(x**2) + 1 * np.random.rand(1)[0] 139 | c = 0.0 140 | if x[0] < c: 141 | return np.nan 142 | return -c**2 + sum((x + 0)**2) 143 | def spherewithoneconstraint(self, x): 144 | return sum((x + 0)**2) if x[0] > 1 else np.nan 145 | 
def elliwithoneconstraint(self, x, idx=[-1]): 146 | return self.ellirot(x) if all(array(x)[idx] > 1) else np.nan 147 | 148 | def spherewithnconstraints(self, x): 149 | return sum((x + 0)**2) if all(array(x) > 1) else np.nan 150 | # zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz 151 | def noisysphere(self, x, noise=2.10e-9, cond=1.0, noise_offset=0.10): 152 | """noise=10 does not work with default popsize, ``cma.NoiseHandler(dimension, 1e7)`` helps""" 153 | return self.elli(x, cond=cond) * np.exp(0 + noise * np.random.randn() / len(x)) + noise_offset * np.random.rand() 154 | def spherew(self, x): 155 | """Sphere (squared norm) with sum x_i = 1 test objective function""" 156 | # return np.random.rand(1)[0]**0 * sum(x**2) + 1 * np.random.rand(1)[0] 157 | # s = sum(abs(x)) 158 | # return sum((x/s+0)**2) - 1/len(x) 159 | # return sum((x/s)**2) - 1/len(x) 160 | return -0.01 * x[0] + abs(x[0])**-2 * sum(x[1:]**2) 161 | def epslowsphere(self, x, eps=1e-7, Neff=lambda x: int(len(x)**0.5)): 162 | """TODO: define as wrapper""" 163 | return np.mean(x[:Neff(x)]**2) + eps * np.mean(x**2) 164 | def partsphere(self, x): 165 | """Sphere (squared norm) test objective function""" 166 | self.evaluations += 1 167 | # return np.random.rand(1)[0]**0 * sum(x**2) + 1 * np.random.rand(1)[0] 168 | dim = len(x) 169 | x = array([x[i % dim] for i in range(2 * dim)]) 170 | N = 8 171 | i = self.evaluations % dim 172 | # f = sum(x[i:i + N]**2) 173 | f = sum(x[np.random.randint(dim, size=N)]**2) 174 | return f 175 | def sectorsphere(self, x): 176 | """asymmetric Sphere (squared norm) test objective function""" 177 | return sum(x**2) + (1e6 - 1) * sum(x[x < 0]**2) 178 | def cornersphere(self, x): 179 | """Sphere (squared norm) test objective function constraint to the corner""" 180 | nconstr = len(x) - 0 181 | if any(x[:nconstr] < 1): 182 | return np.NaN 183 | return sum(x**2) - nconstr 184 | def cornerelli(self, x): 185 | """ """ 186 | if any(x < 1): 187 | return np.NaN 188 | return self.elli(x) - 
self.elli(np.ones(len(x))) 189 | def cornerellirot(self, x): 190 | """ """ 191 | if any(x < 1): 192 | return np.NaN 193 | return self.ellirot(x) 194 | def normalSkew(self, f): 195 | N = np.random.randn(1)[0]**2 196 | if N < 1: 197 | N = f * N # diminish blow up lower part 198 | return N 199 | def noiseC(self, x, func=sphere, fac=10, expon=0.8): 200 | f = func(self, x) 201 | N = np.random.randn(1)[0] / np.random.randn(1)[0] 202 | return max(1e-19, f + (float(fac) / len(x)) * f**expon * N) 203 | def noise(self, x, func=sphere, fac=10, expon=1): 204 | f = func(self, x) 205 | # R = np.random.randn(1)[0] 206 | R = np.log10(f) + expon * abs(10 - np.log10(f)) * np.random.rand(1)[0] 207 | # sig = float(fac)/float(len(x)) 208 | # R = log(f) + 0.5*log(f) * random.randn(1)[0] 209 | # return max(1e-19, f + sig * (f**np.log10(f)) * np.exp(R)) 210 | # return max(1e-19, f * np.exp(sig * N / f**expon)) 211 | # return max(1e-19, f * normalSkew(f**expon)**sig) 212 | return f + 10**R # == f + f**(1+0.5*RN) 213 | def cigar(self, x, rot=0, cond=1e6, noise=0): 214 | """Cigar test objective function""" 215 | if rot: 216 | x = rotate(x) 217 | x = [x] if isscalar(x[0]) else x # scalar into list 218 | f = [(x[0]**2 + cond * sum(x[1:]**2)) * np.exp(noise * np.random.randn(1)[0] / len(x)) for x in x] 219 | return f if len(f) > 1 else f[0] # 1-element-list into scalar 220 | def grad_cigar(self, x, *args): 221 | grad = 2 * 1e6 * np.array(x) 222 | grad[0] /= 1e6 223 | return grad 224 | def diagonal_cigar(self, x, cond=1e6): 225 | axis = np.ones(len(x)) / len(x)**0.5 226 | proj = dot(axis, x) * axis 227 | s = sum(proj**2) 228 | s += cond * sum((x - proj)**2) 229 | return s 230 | def tablet(self, x, cond=1e6, rot=0): 231 | """Tablet test objective function""" 232 | x = np.asarray(x) 233 | if rot and rot is not ff.tablet: 234 | x = rotate(x) 235 | x = [x] if isscalar(x[0]) else x # scalar into list 236 | f = [cond * x[0]**2 + sum(x[1:]**2) for x in x] 237 | return f if len(f) > 1 else f[0] # 
1-element-list into scalar 238 | def grad_tablet(self, x, *args): 239 | grad = 2 * np.array(x) 240 | grad[0] *= 1e6 241 | return grad 242 | def cigtab(self, y): 243 | """Cigtab test objective function""" 244 | X = [y] if isscalar(y[0]) else y 245 | f = [1e-4 * x[0]**2 + 1e4 * x[1]**2 + sum(x[2:]**2) for x in X] 246 | return f if len(f) > 1 else f[0] 247 | def cigtab2(self, x, condition=1e8, n_axes=None): 248 | """cigtab with 1 + 5% long and short axes. 249 | 250 | `n_axes: int`, if > 0, sets the number of long as well as short 251 | axes to `n_axes`, respectively. 252 | """ 253 | m = n_axes or 1 + len(x) // 20 254 | x = np.asarray(x) 255 | f = sum(x[m:-m]**2) 256 | f += condition**0.5 * sum(x[:m]**2) 257 | f += condition**-0.5 * sum(x[-m:]**2) 258 | return f 259 | def twoaxes(self, y): 260 | """Cigtab test objective function""" 261 | X = [y] if isscalar(y[0]) else y 262 | N2 = len(X[0]) // 2 263 | f = [1e6 * sum(x[0:N2]**2) + sum(x[N2:]**2) for x in X] 264 | return f if len(f) > 1 else f[0] 265 | def ellirot(self, x): 266 | return ff.elli(array(x), 1) 267 | def hyperelli(self, x): 268 | N = len(x) 269 | return sum((np.arange(1, N + 1) * x)**2) 270 | def halfelli(self, x): 271 | l = len(x) // 2 272 | felli = self.elli(x[:l]) 273 | return felli + 1e-8 * sum(x[l:]**2) 274 | def elli(self, x, rot=0, xoffset=0, cond=1e6, actuator_noise=0.0, both=False): 275 | """Ellipsoid test objective function""" 276 | x = np.asarray(x) 277 | if not isscalar(x[0]): # parallel evaluation 278 | return [self.elli(xi, rot) for xi in x] # could save 20% overall 279 | if rot: 280 | x = rotate(x) 281 | N = len(x) 282 | if actuator_noise: 283 | x = x + actuator_noise * np.random.randn(N) 284 | 285 | ftrue = sum(cond**(np.arange(N) / (N - 1.)) * (x + xoffset)**2) \ 286 | if N > 1 else (x + xoffset)**2 287 | 288 | alpha = 0.49 + 1. 
/ N 289 | beta = 1 290 | felli = np.random.rand(1)[0]**beta * ftrue * \ 291 | max(1, (10.**9 / (ftrue + 1e-99))**(alpha * np.random.rand(1)[0])) 292 | # felli = ftrue + 1*np.random.randn(1)[0] / (1e-30 + 293 | # np.abs(np.random.randn(1)[0]))**0 294 | if both: 295 | return (felli, ftrue) 296 | else: 297 | # return felli # possibly noisy value 298 | return ftrue # + np.random.randn() 299 | def grad_elli(self, x, *args): 300 | cond = 1e6 301 | N = len(x) 302 | return 2 * cond**(np.arange(N) / (N - 1.)) * array(x, copy=False) 303 | def fun_as_arg(self, x, *args): 304 | """``fun_as_arg(x, fun, *more_args)`` calls ``fun(x, *more_args)``. 305 | 306 | Use case:: 307 | 308 | fmin(cma.fun_as_arg, args=(fun,), gradf=grad_numerical) 309 | 310 | calls fun_as_args(x, args) and grad_numerical(x, fun, args=args) 311 | 312 | """ 313 | fun = args[0] 314 | more_args = args[1:] if len(args) > 1 else () 315 | return fun(x, *more_args) 316 | def grad_numerical(self, x, func, epsilon=None): 317 | """symmetric gradient""" 318 | eps = 1e-8 * (1 + abs(x)) if epsilon is None else epsilon 319 | grad = np.zeros(len(x)) 320 | ei = np.zeros(len(x)) # float is 1.6 times faster than int 321 | for i in rglen(x): 322 | ei[i] = eps[i] 323 | grad[i] = (func(x + ei) - func(x - ei)) / (2*eps[i]) 324 | ei[i] = 0 325 | return grad 326 | def elliconstraint(self, x, cfac=1e8, tough=True, cond=1e6): 327 | """ellipsoid test objective function with "constraints" """ 328 | N = len(x) 329 | f = sum(cond**(np.arange(N)[-1::-1] / (N - 1)) * x**2) 330 | cvals = (x[0] + 1, 331 | x[0] + 1 + 100 * x[1], 332 | x[0] + 1 - 100 * x[1]) 333 | if tough: 334 | f += cfac * sum(max(0, c) for c in cvals) 335 | else: 336 | f += cfac * sum(max(0, c + 1e-3)**2 for c in cvals) 337 | return f 338 | def rosen(self, x, alpha=1e2): 339 | """Rosenbrock test objective function""" 340 | x = [x] if isscalar(x[0]) else x # scalar into list 341 | x = np.asarray(x) 342 | f = [sum(alpha * (x[:-1]**2 - x[1:])**2 + (1. 
- x[:-1])**2) for x in x] 343 | return f if len(f) > 1 else f[0] # 1-element-list into scalar 344 | def grad_rosen(self, x, *args): 345 | N = len(x) 346 | grad = np.zeros(N) 347 | grad[0] = 2 * (x[0] - 1) + 200 * (x[1] - x[0]**2) * -2 * x[0] 348 | i = np.arange(1, N - 1) 349 | grad[i] = 2 * (x[i] - 1) - 400 * (x[i+1] - x[i]**2) * x[i] + 200 * (x[i] - x[i-1]**2) 350 | grad[N-1] = 200 * (x[N-1] - x[N-2]**2) 351 | return grad 352 | def rosen_chained(self, x, alpha=1e2): 353 | x = [x] if isscalar(x[0]) else x # scalar into list 354 | f = [(1. - x[0])**2 + sum(alpha * (x[:-1]**2 - x[1:])**2) for x in x] 355 | return f if len(f) > 1 else f[0] # 1-element-list into scalar 356 | 357 | def diffpow(self, x, rot=0): 358 | """Diffpow test objective function""" 359 | N = len(x) 360 | if rot: 361 | x = rotate(x) 362 | return sum(np.abs(x)**(2. + 4.*np.arange(N) / (N - 1.)))**0.5 363 | def rosenelli(self, x): 364 | N = len(x) 365 | Nhalf = int((N + 1) / 2) 366 | return self.rosen(x[:Nhalf]) + self.elli(x[Nhalf:], cond=1) 367 | def ridge(self, x, expo=2): 368 | x = [x] if isscalar(x[0]) else x # scalar into list 369 | f = [x[0] + 100 * np.sum(x[1:]**2)**(expo / 2.) for x in x] 370 | return f if len(f) > 1 else f[0] # 1-element-list into scalar 371 | def ridgecircle(self, x, expo=0.5): 372 | """a difficult sharp ridge type function. 373 | 374 | A modified implementation of HG Beyers `happycat`. 375 | """ 376 | a = len(x) 377 | s = sum(x**2) 378 | return ((s - a)**2)**(expo / 2) + s / a + sum(x) / a 379 | def happycat(self, x, alpha=1. / 8): 380 | """a difficult sharp ridge type function. 381 | 382 | Proposed by HG Beyer. 
383 | """ 384 | s = sum(x**2) 385 | return ((s - len(x))**2)**alpha + (s / 2 + sum(x)) / len(x) + 0.5 386 | def flat(self, x): 387 | return 1 388 | return 1 if np.random.rand(1) < 0.9 else 1.1 389 | return np.random.randint(1, 30) 390 | def branin(self, x): 391 | # in [0,15]**2 392 | y = x[1] 393 | x = x[0] + 5 394 | return (y - 5.1 * x**2 / 4 / np.pi**2 + 5 * x / np.pi - 6)**2 + 10 * (1 - 1 / 8 / np.pi) * np.cos(x) + 10 - 0.397887357729738160000 395 | def goldsteinprice(self, x): 396 | x1 = x[0] 397 | x2 = x[1] 398 | return (1 + (x1 + x2 + 1)**2 * (19 - 14 * x1 + 3 * x1**2 - 14 * x2 + 6 * x1 * x2 + 3 * x2**2)) * ( 399 | 30 + (2 * x1 - 3 * x2)**2 * (18 - 32 * x1 + 12 * x1**2 + 48 * x2 - 36 * x1 * x2 + 27 * x2**2)) - 3 400 | def griewank(self, x): 401 | # was in [-600 600] 402 | x = (600. / 5) * x 403 | return 1 - np.prod(np.cos(x / np.sqrt(1. + np.arange(len(x))))) + sum(x**2) / 4e3 404 | def levy(self, x): 405 | """a rather benign multimodal function. 406 | 407 | xopt == ones, fopt == 0.0 408 | """ 409 | w = 1 + (np.asarray(x) - 1) / 4 410 | del x 411 | f = np.sin(np.pi * w[0])**2 412 | f += (w[-1] - 1)**2 * (1 + np.sin(2 * np.pi * w[-1])**2) 413 | w = w[1:-1] 414 | return f + sum((w - 1)**2 * (1 + 10 * np.sin(np.pi * w + 1)**2)) 415 | def rastrigin(self, x): 416 | """Rastrigin test objective function""" 417 | if not isscalar(x[0]): 418 | N = len(x[0]) 419 | return [10 * N + sum(xi**2 - 10 * np.cos(2 * np.pi * xi)) for xi in x] 420 | # return 10*N + sum(x**2 - 10*np.cos(2*np.pi*x), axis=1) 421 | N = len(x) 422 | return 10 * N + sum(x**2 - 10 * np.cos(2 * np.pi * x)) 423 | def schaffer(self, x): 424 | """ Schaffer function x0 in [-100..100]""" 425 | N = len(x) 426 | s = x[0:N - 1]**2 + x[1:N]**2 427 | return sum(s**0.25 * (np.sin(50 * s**0.1)**2 + 1)) 428 | 429 | def schwefelelli(self, x): 430 | s = 0 431 | f = 0 432 | for i in rglen(x): 433 | s += x[i] 434 | f += s**2 435 | return f 436 | def schwefelmult(self, x, pen_fac=1e4): 437 | """multimodal Schwefel 
function with domain -500..500""" 438 | y = [x] if isscalar(x[0]) else x 439 | N = len(y[0]) 440 | f = array([418.9829 * N - 1.27275661e-5 * N - sum(x * np.sin(np.abs(x)**0.5)) 441 | + pen_fac * sum((abs(x) > 500) * (abs(x) - 500)**2) for x in y]) 442 | return f if len(f) > 1 else f[0] 443 | def schwefel2_22(self, x): 444 | """Schwefel 2.22 function""" 445 | return sum(np.abs(x)) + np.prod(np.abs(x)) 446 | def optprob(self, x): 447 | n = np.arange(len(x)) + 1 448 | f = n * x * (1 - x)**(n - 1) 449 | return sum(1 - f) 450 | def lincon(self, x, theta=0.01): 451 | """ridge like linear function with one linear constraint""" 452 | if x[0] < 0: 453 | return np.NaN 454 | return theta * x[1] + x[0] 455 | def rosen_nesterov(self, x, rho=100): 456 | """needs exponential number of steps in a non-increasing 457 | f-sequence. 458 | 459 | x_0 = (-1,1,...,1) 460 | See Jarre (2011) "On Nesterov's Smooth Chebyshev-Rosenbrock 461 | Function" 462 | 463 | """ 464 | f = 0.25 * (x[0] - 1)**2 465 | f += rho * sum((x[1:] - 2 * x[:-1]**2 + 1)**2) 466 | return f 467 | def powel_singular(self, x): 468 | # ((8 * np.sin(7 * (x[i] - 0.9)**2)**2 ) + (6 * np.sin())) 469 | res = np.sum((x[i - 1] + 10 * x[i])**2 + 5 * (x[i + 1] - x[i + 2])**2 + 470 | (x[i] - 2 * x[i + 1])**4 + 10 * (x[i - 1] - x[i + 2])**4 471 | for i in range(1, len(x) - 2)) 472 | return 1 + res 473 | def styblinski_tang(self, x): 474 | """in [-5, 5] 475 | found also in Lazar and Jarre 2016, optimum in f(-2.903534...)=0 476 | """ 477 | # x_opt = N * [-2.90353402], seems to have essentially 478 | # (only) 2**N local optima 479 | return (39.1661657037714171054273576010019 * len(x))**1 + \ 480 | sum(x**4 - 16 * x**2 + 5 * x) / 2 481 | 482 | def trid(self, x): 483 | return sum((x-1)**2) - sum(x[:-1] * x[1:]) 484 | 485 | def bukin(self, x): 486 | """Bukin function from Wikipedia, generalized simplistically from 2-D. 
487 | 488 | http://en.wikipedia.org/wiki/Test_functions_for_optimization""" 489 | s = 0 490 | for k in range((1+len(x)) // 2): 491 | z = x[2 * k] 492 | y = x[min((2*k + 1, len(x)-1))] 493 | s += 100 * np.abs(y - 0.01 * z**2)**0.5 + 0.01 * np.abs(z + 10) 494 | return s 495 | 496 | def xinsheyang2(self, x, termination_friendly=True): 497 | """a multimodal function which is rather unsolvable in larger dimension. 498 | 499 | >>> import functools 500 | >>> import numpy as np 501 | >>> import cma 502 | >>> f = functools.partial(cma.ff.xinsheyang2, termination_friendly=False) 503 | >>> X = [(i * [0] + (4 - i) * [1.24]) for i in range(5)] 504 | >>> for x in X: print(x) 505 | [1.24, 1.24, 1.24, 1.24] 506 | [0, 1.24, 1.24, 1.24] 507 | [0, 0, 1.24, 1.24] 508 | [0, 0, 0, 1.24] 509 | [0, 0, 0, 0] 510 | >>> ' '.join(['{:.3}'.format(f(x)) for x in X]) # [np.round(f(x), 3) for x in X] 511 | '0.091 0.186 0.336 0.456 0.0' 512 | 513 | One needs to solve a trinary deceptive function where f-value (to 514 | be minimized) is monotonuously decreasing with increasing distance 515 | to the global optimum >= 1. That is, the global optimum is 516 | surrounded by 3^n - 1 local optima that have the better values the 517 | further they are away from the global optimum. 518 | 519 | Conclusion: it is a rather suspicious sign if an algorithm finds the global 520 | optimum of this function in larger dimension. 521 | 522 | See also http://benchmarkfcns.xyz/benchmarkfcns/xinsheyangn2fcn.html 523 | """ 524 | x = np.asarray(x) 525 | val = np.sum(np.abs(x)) * np.exp(-np.sum(np.sin(np.square(x)))) # np.mean under the exponential makes the function much easier 526 | if termination_friendly and val < 1: 527 | val **= 1. / len(x) 528 | return val 529 | 530 | def _fetch_bbob_fcts(self): 531 | """Fetch GECCO BBOB 2009 functions from WWW and set as `self.BBOB`. 532 | 533 | Side effects in the current folder: two files are added and folder 534 | "._tmp_" is removed. 
535 | """ 536 | # fetch http://coco.lri.fr/downloads/download15.02/bbobpproc.tar.gz 537 | bbob_version, fname = '15.03', 'bbobpproc.tar.gz' 538 | url = 'http://coco.lri.fr/downloads/download'+bbob_version+'/'+fname # 3MB 539 | print('downloading %s ...' % url, end=''); sys.stdout.flush() 540 | utils.download_file(url) 541 | print(' done downloading') 542 | print('extracting bbobbenchmarks.py'); sys.stdout.flush() 543 | utils.extract_targz(url.split(os.path.sep)[-1], 544 | os.path.join('bbob.v' + bbob_version, 545 | 'python', 'bbobbenchmarks.py')) 546 | print('importing bbobbenchmarks.py and setting as BBOB attribute') 547 | import bbobbenchmarks 548 | self.BBOB = bbobbenchmarks 549 | print("BBOB set and ready to go. Example: `f11 = cma.FF.BBOB.F11()`") 550 | 551 | def fetch_bbob_fcts(self): 552 | """Fetch GECCO BBOB 2009 functions from WWW and set as `self.BBOB`. 553 | """ 554 | url = "http://coco.gforge.inria.fr/python/bbobbenchmarks.py" 555 | 556 | ff = FitnessFunctions() 557 | -------------------------------------------------------------------------------- /cma/fitness_transformations.py: -------------------------------------------------------------------------------- 1 | """Wrapper for objective functions like noise, rotation, gluing args 2 | """ 3 | from __future__ import absolute_import, division, print_function #, unicode_literals, with_statement 4 | import warnings 5 | from functools import partial 6 | import numpy as np 7 | import time 8 | from .utilities import utils 9 | from .transformations import ConstRandnShift, Rotation 10 | from .constraints_handler import BoundTransform 11 | from .optimization_tools import EvalParallel2 # for backwards compatibility 12 | from .utilities.python3for2 import range 13 | del absolute_import, division, print_function #, unicode_literals, with_statement 14 | 15 | rotate = Rotation() 16 | 17 | class Function(object): 18 | """a declarative base class, indicating that a derived class instance 19 | "is" a (fitness/objective) 
function. 20 | 21 | A `callable` passed to `__init__` is called as the fitness 22 | `Function`, otherwise the `_eval` method is called, if defined in the 23 | derived class, when the `Function` instance is called. If the input 24 | argument is a matrix or a list of vectors, the method is called for 25 | each vector individually like 26 | ``_eval(numpy.asarray(vector)) for vector in matrix``. 27 | 28 | >>> import cma 29 | >>> from cma.fitness_transformations import Function 30 | >>> f = Function(cma.ff.rosen) 31 | >>> assert f.evaluations == 0 32 | >>> assert f([2, 3]) == cma.ff.rosen([2, 3]) 33 | >>> assert f.evaluations == 1 34 | >>> assert f([[1], [2]]) == [cma.ff.rosen([1]), cma.ff.rosen([2])] 35 | >>> assert f.evaluations == 3 36 | >>> class Fsphere(Function): 37 | ... def _eval(self, x): 38 | ... return sum(x**2) 39 | >>> fsphere = Fsphere() 40 | >>> assert fsphere.evaluations == 0 41 | >>> assert fsphere([2, 3]) == 4 + 9 and fsphere([[2], [3]]) == [4, 9] 42 | >>> assert fsphere.evaluations == 3 43 | >>> Fsphere.__init__ = lambda self: None # overwrites Function.__init__ 44 | >>> assert Fsphere()([2]) == 4 # which is perfectly fine to do 45 | 46 | Details: 47 | 48 | - When called, a class instance calls either the function passed to 49 | `__init__` or, if none was given, tries to call any of the 50 | `function_names_to_evaluate_first_found`, first come first serve. 51 | By default, ``function_names_to_evaluate_first_found == ["_eval"]``. 52 | 53 | - This class cannot move to module `fitness_functions`, because the 54 | latter uses `fitness_transformations.rotate`. 55 | 56 | """ 57 | _function_names_to_evaluate_first_found = ["_eval"] 58 | @property 59 | def function_names_to_evaluate_first_found(self): 60 | """attributes which are searched for to be called if no function 61 | was given to `__init__`. 62 | 63 | The first present match is used. 
64 | """ # % str(Function._function_names_to_evaluate_first_found) 65 | return Function._function_names_to_evaluate_first_found 66 | 67 | def __init__(self, fitness_function=None): 68 | """allows to define the fitness_function to be called, doesn't 69 | need to be ever called 70 | """ 71 | Function.initialize(self, fitness_function) 72 | def initialize(self, fitness_function): 73 | """initialization of `Function` 74 | """ 75 | self.__callable = fitness_function # this naming prevents interference with a derived class variable 76 | self.evaluations = 0 77 | self.ftarget = -np.inf 78 | self.target_hit_at = 0 # evaluation counter when target was first hit 79 | self.__initialized = True 80 | 81 | def __call__(self, *args, **kwargs): 82 | # late initialization if necessary 83 | try: 84 | if not self.__initialized: 85 | raise AttributeError 86 | except AttributeError: 87 | Function.initialize(self, None) 88 | # find the "right" callable 89 | callable_ = self.__callable 90 | if callable_ is None: 91 | for name in self.function_names_to_evaluate_first_found: 92 | try: 93 | callable_ = getattr(self, name) 94 | break 95 | except AttributeError: 96 | pass 97 | # call with each vector 98 | if callable_ is not None: 99 | X, list_revert = utils.as_vector_list(args[0]) 100 | self.evaluations += len(X) 101 | F = [callable_(np.asarray(x), *args[1:], **kwargs) for x in X] 102 | if not self.target_hit_at and any(np.asarray(F) <= self.ftarget): 103 | self.target_hit_at = self.evaluations - len(X) + 1 + list(np.asarray(F) <= self.ftarget).index(True) 104 | return list_revert(F) 105 | else: 106 | self.evaluations += 1 # somewhat bound to fail 107 | 108 | class ComposedFunction(Function, list): 109 | """compose an arbitrary number of functions. 110 | 111 | A class instance is a list of functions. Calling the instance executes 112 | the composition of these functions (evaluating from right to left as 113 | in math notation). 
Functions can be added to or removed from the list 114 | at any time with the obvious effect. To remain consistent (if needed), 115 | the ``list_of_inverses`` attribute must be updated respectively. 116 | 117 | >>> import numpy as np 118 | >>> from cma.fitness_transformations import ComposedFunction 119 | >>> f1, f2, f3, f4 = lambda x: 2*x, lambda x: x**2, lambda x: 3*x, lambda x: x**3 120 | >>> f = ComposedFunction([f1, f2, f3, f4]) 121 | >>> assert isinstance(f, list) and isinstance(f, ComposedFunction) 122 | >>> assert f[0] == f1 # how I love Python indexing 123 | >>> assert all(f(x) == f1(f2(f3(f4(x)))) for x in np.random.rand(10)) 124 | >>> assert f4 == f.pop() 125 | >>> assert len(f) == 3 126 | >>> f.insert(1, f4) 127 | >>> f.append(f4) 128 | >>> assert all(f(x) == f1(f4(f2(f3(f4(x))))) for x in range(5)) 129 | 130 | A more specific example: 131 | 132 | >>> from cma.fitness_transformations import ComposedFunction 133 | >>> from cma.constraints_handler import BoundTransform 134 | >>> from cma import ff 135 | >>> f = ComposedFunction([ff.elli, 136 | ... BoundTransform([[0], [1]]).transform]) 137 | >>> assert max(f([2, 3]), f([1, 1])) <= ff.elli([1, 1]) 138 | 139 | Details: 140 | 141 | - This class can serve as basis for a more transparent 142 | alternative to a ``scaling_of_variables`` CMA option or for any 143 | necessary transformation of the fitness/objective function 144 | (genotype-phenotype transformation). 145 | 146 | - The parallelizing call with a list of solutions of the `Function` 147 | class is not inherited. The inheritence from `Function` is rather 148 | declarative than funtional and could be omitted. 149 | 150 | """ 151 | def __init__(self, list_of_functions, list_of_inverses=None): 152 | """Caveat: to remain consistent, the ``list_of_inverses`` must be 153 | updated explicitly, if the list of function was updated after 154 | initialization. 
155 | """ 156 | list.__init__(self, list_of_functions) 157 | Function.__init__(self) 158 | self.list_of_inverses = list_of_inverses 159 | 160 | def __call__(self, x, *args, **kwargs): 161 | Function.__call__(self, x, *args, **kwargs) # for the possible side effects only 162 | for i in range(-1, -len(self) - 1, -1): 163 | x = self[i](x, *args, **kwargs) 164 | return x 165 | 166 | def inverse(self, x, *args, **kwargs): 167 | """evaluate the composition of inverses on ``x``. 168 | 169 | Return `None`, if no list was provided. 170 | """ 171 | if self.list_of_inverses is None: 172 | utils.print_warning("inverses were not given") 173 | return 174 | for i in range(len(self.list_of_inverses)): 175 | x = self.list_of_inverses[i](x, *args, **kwargs) 176 | return x 177 | 178 | class StackFunction(Function): 179 | """a function that returns ``f1(x[:n1]) + f2(x[n1:])``. 180 | 181 | >>> import functools 182 | >>> import numpy as np 183 | >>> import cma 184 | >>> def elli48(x): 185 | ... return 1e-4 * functools.partial(cma.ff.elli, cond=1e8)(x) 186 | >>> fcigtab = cma.fitness_transformations.StackFunction( 187 | ... elli48, cma.ff.sphere, 2) 188 | >>> x = [1, 2, 3, 4] 189 | >>> assert np.isclose(fcigtab(x), cma.ff.cigtab(np.asarray(x))) 190 | 191 | """ 192 | def __init__(self, f1, f2, n1): 193 | self.f1 = f1 194 | self.f2 = f2 195 | self.n1 = n1 196 | def _eval(self, x, *args, **kwargs): 197 | return self.f1(x[:self.n1], *args, **kwargs) + self.f2(x[self.n1:], *args, **kwargs) 198 | 199 | class GlueArguments(Function): 200 | """deprecated, use `functools.partial` or 201 | `cma.fitness_transformations.partial` instead, which has the same 202 | functionality and interface. 203 | 204 | from a `callable` return a `callable` with arguments attached. 205 | 206 | 207 | An ellipsoid function with condition number ``1e4`` is created by 208 | ``felli1e4 = cma.s.ft.GlueArguments(cma.ff.elli, cond=1e4)``. 
209 | 210 | >>> import cma 211 | >>> f = cma.fitness_transformations.GlueArguments(cma.ff.elli, 212 | ... cond=1e1) 213 | >>> assert f([1, 2]) == 1**2 + 1e1 * 2**2 214 | 215 | """ 216 | def __init__(self, fitness_function, *args, **kwargs): 217 | """define function, ``args``, and ``kwargs``. 218 | 219 | ``args`` are appended to arguments passed in the call, ``kwargs`` 220 | are updated with keyword arguments passed in the call. 221 | """ 222 | Function.__init__(self, fitness_function) 223 | self.fitness_function = fitness_function # never used 224 | self.args = args 225 | self.kwargs = kwargs 226 | def __call__(self, x, *args, **kwargs): 227 | """call function with at least one additional argument and 228 | attached args and kwargs. 229 | """ 230 | joined_kwargs = dict(self.kwargs) 231 | joined_kwargs.update(kwargs) 232 | x = np.asarray(x) 233 | return Function.__call__(self, x, *(args + self.args), 234 | **joined_kwargs) 235 | 236 | class FBoundTransform(ComposedFunction): 237 | """shortcut for ``ComposedFunction([f, BoundTransform(bounds).transform])``, 238 | see also below. 239 | 240 | Maps the argument into bounded or half-bounded (feasible) domain 241 | before evaluating ``f``. 242 | 243 | Example with lower bound at 0, which becomes the image of -0.05 in 244 | `BoundTransform.transform`: 245 | 246 | >>> import cma, numpy as np 247 | >>> f = cma.fitness_transformations.FBoundTransform(cma.ff.elli, 248 | ... [[0], None]) 249 | >>> assert all(f[1](np.random.randn(200)) >= 0) 250 | >>> assert all(f[1]([-0.05, -0.05]) == 0) 251 | >>> assert f([-0.05, -0.05]) == 0 252 | 253 | A slightly more verbose version to implement the lower bound at zero 254 | in the very same way: 255 | 256 | >>> import cma 257 | >>> felli_in_bound = cma.s.ft.ComposedFunction( 258 | ... 
[cma.ff.elli, cma.BoundTransform([[0], None]).transform]) 259 | 260 | """ 261 | def __init__(self, fitness_function, bounds): 262 | """`bounds[0]` are lower bounds, `bounds[1]` are upper bounds 263 | """ 264 | self.bound_tf = BoundTransform(bounds) # not strictly necessary 265 | ComposedFunction.__init__(self, 266 | [fitness_function, self.bound_tf.transform]) 267 | 268 | class Rotated(ComposedFunction): 269 | """return a rotated version of a function for testing purpose. 270 | 271 | This class is a convenience shortcut for the litte more verbose 272 | composition of a function with a rotation: 273 | 274 | >>> import cma 275 | >>> from cma import fitness_transformations as ft 276 | >>> f1 = ft.Rotated(cma.ff.elli) 277 | >>> f2 = ft.ComposedFunction([cma.ff.elli, ft.Rotation()]) 278 | >>> assert f1([2]) == f2([2]) # same rotation only in 1-D 279 | >>> assert f1([1, 2]) != f2([1, 2]) 280 | 281 | """ 282 | def __init__(self, f, rotate=None, seed=None): 283 | """optional argument ``rotate(x)`` must return a (stable) rotation 284 | of ``x``. 285 | """ 286 | if rotate is None: 287 | rotate = Rotation(seed=seed) 288 | ComposedFunction.__init__(self, [f, rotate]) 289 | 290 | class Shifted(ComposedFunction): 291 | """compose a function with a shift in x-space. 292 | 293 | >>> import cma 294 | >>> f = cma.s.ft.Shifted(cma.ff.elli) 295 | 296 | Details: this class solely provides as default second argument to 297 | `ComposedFunction`, namely a random shift in search space. 298 | ``shift=lambda x: x`` would provide "no shift", ``None`` 299 | expands to ``cma.transformations.ConstRandnShift()``. 
300 | """ 301 | def __init__(self, f, shift=None): 302 | """``shift(x)`` must return a (stable) shift of x""" 303 | if shift is None: 304 | shift = ConstRandnShift() 305 | ComposedFunction.__init__(self, [f, shift]) 306 | 307 | class ScaleCoordinates(ComposedFunction): 308 | """compose a (fitness) function with a scaling for each variable 309 | (more concisely, a coordinate-wise affine transformation). 310 | 311 | After ``fun2 = cma.ScaleCoordinates(fun, multipliers, zero)``, we have 312 | ``fun2(x) == fun(multipliers * (x - zero))``, where the size of 313 | `multipliers` and `zero` is adapated to the size of `x`, in case by 314 | recycling their last entry. 315 | 316 | >>> import numpy as np 317 | >>> import cma 318 | >>> f = cma.ScaleCoordinates(cma.ff.sphere, [100, 1]) 319 | >>> assert f[0] == cma.ff.sphere # first element of f-composition 320 | >>> assert f(range(1, 6)) == 100**2 + sum([x**2 for x in range(2, 6)]) 321 | >>> assert f([2.1]) == 210**2 == f(2.1) 322 | >>> assert f(20 * [1]) == 100**2 + 19 323 | >>> assert np.all(f.inverse(f.scale_and_offset([1, 2, 3, 4])) == 324 | ... np.asarray([1, 2, 3, 4])) 325 | >>> f = cma.ScaleCoordinates(f, [-2, 7], [2, 3, 4]) # last is recycled 326 | >>> f([5, 6]) == sum(x**2 for x in [100 * -2 * (5 - 2), 7 * (6 - 3)]) 327 | True 328 | 329 | """ 330 | def __init__(self, fitness_function, multipliers=None, zero=None): 331 | """ 332 | :param fitness_function: a `callable` object 333 | :param multipliers: coordinate-wise multipliers. 334 | :param zero: defines a new zero in preimage space, that is, 335 | calling the `ScaleCoordinates` instance returns 336 | ``fitness_function(multipliers * (x - zero))``. 337 | 338 | For both arguments, ``multipliers`` and ``zero``, to fit 339 | the length of the given input, superfluous trailing 340 | elements are ignored and the last element is recycled 341 | if needed. 
342 | """ 343 | ComposedFunction.__init__(self, 344 | [fitness_function, self.scale_and_offset]) 345 | self.multiplier = multipliers 346 | if self.multiplier is not None: 347 | self.multiplier = np.asarray(self.multiplier, dtype=float) 348 | self.zero = zero 349 | if zero is not None: 350 | self.zero = np.asarray(zero, dtype=float) 351 | 352 | def scale_and_offset(self, x): 353 | x = np.asarray(x) 354 | r = lambda vec: utils.recycled(vec, as_=x) 355 | if self.zero is not None and self.multiplier is not None: 356 | x = r(self.multiplier) * (x - r(self.zero)) 357 | elif self.zero is not None: 358 | x = x - r(self.zero) 359 | elif self.multiplier is not None: 360 | x = r(self.multiplier) * x 361 | return x 362 | 363 | def inverse(self, x): 364 | """inverse of coordinate-wise affine transformation 365 | ``y / multipliers + zero`` 366 | """ 367 | x = np.asarray(x) 368 | r = lambda vec: utils.recycled(vec, as_=x) 369 | if self.zero is not None and self.multiplier is not None: 370 | x = x / r(self.multiplier) + r(self.zero) 371 | elif self.zero is not None: 372 | x = x + r(self.zero) 373 | elif self.multiplier is not None: 374 | x = x / r(self.multiplier) 375 | return x 376 | 377 | class FixVariables(ComposedFunction): 378 | """Insert variables with given values, thereby reducing the 379 | dimensionality of the resulting composed function. 380 | 381 | The constructor takes ``index_value_pairs``, a `dict` or `list` of 382 | pairs, as input and returns a function with smaller preimage space 383 | than input function ``f``. 384 | 385 | Fixing variable 3 and 5 works like 386 | 387 | >>> from cma.fitness_transformations import FixVariables 388 | >>> index_value_pairs = [[2, 0.2], [4, 0.4]] 389 | >>> fun = FixVariables(cma.ff.elli, index_value_pairs) 390 | >>> fun[1](4 * [1]) == [ 1., 1., 0.2, 1., 0.4, 1.] 
391 | True 392 | 393 | Or starting from a given current solution in the larger space from 394 | which we pick the fixed values: 395 | 396 | >>> from cma.fitness_transformations import FixVariables 397 | >>> current_solution = [0.1 * i for i in range(5)] 398 | >>> fixed_indices = [2, 4] 399 | >>> index_value_pairs = [[i, current_solution[i]] # fix these 400 | ... for i in fixed_indices] 401 | >>> fun = FixVariables(cma.ff.elli, index_value_pairs) 402 | >>> fun[1](4 * [1]) == [ 1., 1., 0.2, 1., 0.4, 1.] 403 | True 404 | >>> assert (current_solution == # list with same values 405 | ... fun.transform(fun.insert_variables(current_solution))) 406 | >>> assert (current_solution == # list with same values 407 | ... fun.insert_variables(fun.transform(current_solution))) 408 | 409 | Details: this might replace the ``fixed_variables`` option in 410 | `CMAOptions` in future, but hasn't been thoroughly tested yet. 411 | 412 | Supersedes `ExpandSolution`. 413 | 414 | """ 415 | def __init__(self, f, index_value_pairs): 416 | """return `f` with reduced dimensionality. 417 | 418 | ``index_value_pairs``: 419 | variables 420 | """ 421 | # super(FixVariables, self).__init__( 422 | ComposedFunction.__init__(self, [f, self.insert_variables]) 423 | self.index_value_pairs = dict(index_value_pairs) 424 | def transform(self, x): 425 | """transform `x` such that it could be used as argument to `self`. 426 | 427 | Return a list or array, usually dismissing some elements of 428 | `x`. ``fun.transform`` is the inverse of 429 | ``fun.insert_variables == fun[1]``, that is 430 | ``np.all(x == fun.transform(fun.insert_variables(x))) is True``. 
431 | """ 432 | res = [x[i] for i in range(len(x)) 433 | if i not in self.index_value_pairs] 434 | return res if isinstance(x, list) else np.asarray(res) 435 | def insert_variables(self, x): 436 | """return `x` with inserted fixed values""" 437 | if len(self.index_value_pairs) == 0: 438 | return x 439 | y = list(x) 440 | for i in sorted(self.index_value_pairs): 441 | y.insert(i, self.index_value_pairs[i]) 442 | if not isinstance(x, list): 443 | y = np.asarray(y) # doubles the necessary time 444 | return y 445 | 446 | class Expensify(Function): 447 | """Add waiting time to each evaluation, to simulate "expensive" 448 | behavior""" 449 | def __init__(self, callable_, time=1): 450 | """add time in seconds""" 451 | Function.__init__(self) # callable_ could go here 452 | self.time = time 453 | self.callable = callable_ 454 | def __call__(self, *args, **kwargs): 455 | time.sleep(self.time) 456 | Function.__call__(self, *args, **kwargs) 457 | return self.callable(*args, **kwargs) 458 | 459 | class SomeNaNFitness(Function): 460 | """transform ``fitness_function`` to return sometimes ``NaN``""" 461 | def __init__(self, fitness_function, probability_of_nan=0.1): 462 | Function.__init__(self) 463 | self.fitness_function = fitness_function 464 | self.p = probability_of_nan 465 | def __call__(self, x, *args): 466 | Function.__call__(self, x, *args) 467 | if np.random.rand(1) <= self.p: 468 | return np.NaN 469 | else: 470 | return self.fitness_function(x, *args) 471 | 472 | class NoisyFitness(Function): 473 | """apply noise via ``f += rel_noise(dim) * f + abs_noise(dim)``""" 474 | def __init__(self, fitness_function, 475 | rel_noise=lambda dim: 1.1 * np.random.randn() / dim, 476 | abs_noise=lambda dim: 1.1 * np.random.randn()): 477 | """attach relative and absolution noise to ``fitness_function``. 478 | 479 | Relative noise is by default computed using the length of the 480 | input argument to ``fitness_function``. Both noise functions take 481 | ``dimension`` as input. 
482 | 483 | >>> import cma 484 | >>> from cma.fitness_transformations import NoisyFitness 485 | >>> fn = NoisyFitness(cma.ff.elli) 486 | >>> assert fn([1, 2]) != cma.ff.elli([1, 2]) 487 | >>> assert fn.evaluations == 1 488 | 489 | """ 490 | Function.__init__(self, fitness_function) 491 | self.rel_noise = rel_noise 492 | self.abs_noise = abs_noise 493 | 494 | def __call__(self, x, *args): 495 | f = Function.__call__(self, x, *args) 496 | if self.rel_noise: 497 | f += f * self.rel_noise(len(x)) 498 | assert np.isscalar(f) 499 | if self.abs_noise: 500 | f += self.abs_noise(len(x)) 501 | return f 502 | 503 | class IntegerMixedFunction(ComposedFunction): 504 | """compose fitness function with some integer variables. 505 | 506 | >>> import cma 507 | >>> f = cma.s.ft.IntegerMixedFunction(cma.ff.elli, [0, 3, 6]) 508 | >>> assert f([0.2, 2]) == f([0.4, 2]) != f([1.2, 2]) 509 | 510 | It is advisable to set minstd of integer variables to 511 | ``1 / (2 * len(integer_variable_indices) + 1)``, in which case in 512 | an independent model at least 33% (1 integer variable) -> 39% (many 513 | integer variables) of the solutions should have an integer mutation 514 | on average. Option ``integer_variables`` of `cma.CMAOptions` 515 | implements this simple measure. 
516 | """ 517 | def __init__(self, function, integer_variable_indices, copy_arg=True): 518 | ComposedFunction.__init__(self, [function, self._flatten]) 519 | self.integer_variable_indices = integer_variable_indices 520 | self.copy_arg = copy_arg 521 | def _flatten(self, x): 522 | x = np.array(x, copy=self.copy_arg) 523 | for i in sorted(self.integer_variable_indices): 524 | if i < -len(x): 525 | continue 526 | if i >= len(x): 527 | break 528 | x[i] = np.floor(x[i]) 529 | return x 530 | -------------------------------------------------------------------------------- /cma/interfaces.py: -------------------------------------------------------------------------------- 1 | """Very few interface defining base class definitions""" 2 | from __future__ import absolute_import, division, print_function #, unicode_literals 3 | import warnings 4 | try: from .optimization_tools import EvalParallel2 5 | except: EvalParallel2 = None 6 | del absolute_import, division, print_function #, unicode_literals 7 | 8 | class EvalParallel: 9 | """allow construct ``with EvalParallel(fun) as eval_all:``""" 10 | def __init__(self, fun, *args, **kwargs): 11 | self.fun = fun 12 | def __call__(self, X, args=()): 13 | return [self.fun(x, *args) for x in X] 14 | def __enter__(self): return self 15 | def __exit__(self, *args, **kwargs): pass 16 | 17 | class OOOptimizer(object): 18 | """abstract base class for an Object Oriented Optimizer interface. 19 | 20 | Relevant methods are `__init__`, `ask`, `tell`, `optimize` and `stop`, 21 | and property `result`. Only `optimize` is fully implemented in this 22 | base class. 23 | 24 | Examples 25 | -------- 26 | All examples minimize the function `elli`, the output is not shown. 27 | (A preferred environment to execute all examples is ``ipython``.) 
28 | 29 | First we need:: 30 | 31 | # CMAEvolutionStrategy derives from the OOOptimizer class 32 | from cma import CMAEvolutionStrategy 33 | from cma.fitness_functions import elli 34 | 35 | The shortest example uses the inherited method 36 | `OOOptimizer.optimize`:: 37 | 38 | es = CMAEvolutionStrategy(8 * [0.1], 0.5).optimize(elli) 39 | 40 | The input parameters to `CMAEvolutionStrategy` are specific to this 41 | inherited class. The remaining functionality is based on interface 42 | defined by `OOOptimizer`. We might have a look at the result:: 43 | 44 | print(es.result[0]) # best solution and 45 | print(es.result[1]) # its function value 46 | 47 | Virtually the same example can be written with an explicit loop 48 | instead of using `optimize`. This gives the necessary insight into 49 | the `OOOptimizer` class interface and entire control over the 50 | iteration loop:: 51 | 52 | # a new CMAEvolutionStrategy instance 53 | optim = CMAEvolutionStrategy(9 * [0.5], 0.3) 54 | 55 | # this loop resembles optimize() 56 | while not optim.stop(): # iterate 57 | X = optim.ask() # get candidate solutions 58 | f = [elli(x) for x in X] # evaluate solutions 59 | # in case do something else that needs to be done 60 | optim.tell(X, f) # do all the real "update" work 61 | optim.disp(20) # display info every 20th iteration 62 | optim.logger.add() # log another "data line", non-standard 63 | 64 | # final output 65 | print('termination by', optim.stop()) 66 | print('best f-value =', optim.result[1]) 67 | print('best solution =', optim.result[0]) 68 | optim.logger.plot() # if matplotlib is available 69 | 70 | Details 71 | ------- 72 | Most of the work is done in the methods `tell` or `ask`. The property 73 | `result` provides more useful output. 
74 | 75 | """ 76 | def __init__(self, xstart, *more_mandatory_args, **optional_kwargs): 77 | """``xstart`` is a mandatory argument""" 78 | self.xstart = xstart 79 | self.more_mandatory_args = more_mandatory_args 80 | self.optional_kwargs = optional_kwargs 81 | self.initialize() 82 | def initialize(self): 83 | """(re-)set to the initial state""" 84 | raise NotImplementedError('method initialize() must be implemented in derived class') 85 | self.countiter = 0 86 | self.xcurrent = [xi for xi in self.xstart] 87 | def ask(self, **optional_kwargs): 88 | """abstract method, AKA "get" or "sample_distribution", deliver 89 | new candidate solution(s), a list of "vectors" 90 | """ 91 | raise NotImplementedError('method ask() must be implemented in derived class') 92 | def tell(self, solutions, function_values): 93 | """abstract method, AKA "update", pass f-values and prepare for 94 | next iteration 95 | """ 96 | self.countiter += 1 97 | raise NotImplementedError('method tell() must be implemented in derived class') 98 | def stop(self): 99 | """abstract method, return satisfied termination conditions in a 100 | dictionary like ``{'termination reason': value, ...}`` or ``{}``. 101 | 102 | For example ``{'tolfun': 1e-12}``, or the empty dictionary ``{}``. 103 | 104 | TODO: this should rather be a property!? Unfortunately, a change 105 | would break backwards compatibility. 106 | """ 107 | raise NotImplementedError('method stop() is not implemented') 108 | def disp(self, modulo=None): 109 | """abstract method, display some iteration info when 110 | ``self.iteration_counter % modulo < 1``, using a reasonable 111 | default for `modulo` if ``modulo is None``. 112 | """ 113 | @property 114 | def result(self): 115 | """abstract property, contain ``(x, f(x), ...)``, that is, the 116 | minimizer, its function value, ... 
117 | """ 118 | raise NotImplementedError('result property is not implemented') 119 | return [self.xcurrent] 120 | 121 | def optimize(self, objective_fct, 122 | maxfun=None, iterations=None, min_iterations=1, 123 | args=(), 124 | verb_disp=None, 125 | callback=None, 126 | n_jobs=0, 127 | **kwargs): 128 | """find minimizer of ``objective_fct``. 129 | 130 | CAVEAT: the return value for `optimize` has changed to ``self``, 131 | allowing for a call like:: 132 | 133 | solver = OOOptimizer(x0).optimize(f) 134 | 135 | and investigate the state of the solver. 136 | 137 | Arguments 138 | --------- 139 | 140 | ``objective_fct``: f(x: array_like) -> float 141 | function be to minimized 142 | ``maxfun``: number 143 | maximal number of function evaluations 144 | ``iterations``: number 145 | number of (maximal) iterations, while ``not self.stop()``, 146 | it can be useful to conduct only one iteration at a time. 147 | ``min_iterations``: number 148 | minimal number of iterations, even if ``not self.stop()`` 149 | ``args``: sequence_like 150 | arguments passed to ``objective_fct`` 151 | ``verb_disp``: number 152 | print to screen every ``verb_disp`` iteration, if `None` 153 | the value from ``self.logger`` is "inherited", if 154 | available. 155 | ``callback``: callable or list of callables 156 | callback function called like ``callback(self)`` or 157 | a list of call back functions called in the same way. If 158 | available, ``self.logger.add`` is added to this list. 159 | TODO: currently there is no way to prevent this other than 160 | changing the code of `_prepare_callback_list`. 161 | ``n_jobs=0``: number of processes to be acquired for 162 | multiprocessing to parallelize calls to `objective_fct`. 163 | Must be >1 to expect any speed-up or `None` or `-1`, which 164 | both default to the number of available CPUs. The default 165 | ``n_jobs=0`` avoids the use of multiprocessing altogether. 166 | 167 | ``return self``, that is, the `OOOptimizer` instance. 
168 | 169 | Example 170 | ------- 171 | >>> import cma 172 | >>> es = cma.CMAEvolutionStrategy(7 * [0.1], 0.1 173 | ... ).optimize(cma.ff.rosen, verb_disp=100) 174 | ... #doctest: +ELLIPSIS 175 | (4_w,9)-aCMA-ES (mu_w=2.8,w_1=49%) in dimension 7 (seed=...) 176 | Iterat #Fevals function value axis ratio sigma ... 177 | 1 9 ... 178 | 2 18 ... 179 | 3 27 ... 180 | 100 900 ... 181 | >>> cma.s.Mh.vequals_approximately(es.result[0], 7 * [1], 1e-5) 182 | True 183 | 184 | """ 185 | if kwargs: 186 | message = "ignoring unkown argument%s %s in OOOptimizer.optimize" % ( 187 | 's' if len(kwargs) > 1 else '', str(kwargs)) 188 | warnings.warn( 189 | message) # warnings.simplefilter('ignore', lineno=186) suppresses this warning 190 | 191 | if iterations is not None and min_iterations > iterations: 192 | warnings.warn("doing min_iterations = %d > %d = iterations" 193 | % (min_iterations, iterations)) 194 | iterations = min_iterations 195 | callback = self._prepare_callback_list(callback) 196 | 197 | citer, cevals = 0, 0 198 | with (EvalParallel2 or EvalParallel)(objective_fct, 199 | None if n_jobs == -1 else n_jobs) as eval_all: 200 | while not self.stop() or citer < min_iterations: 201 | if (maxfun and cevals >= maxfun) or ( 202 | iterations and citer >= iterations): 203 | return self 204 | citer += 1 205 | 206 | X = self.ask() # deliver candidate solutions 207 | # fitvals = [objective_fct(x, *args) for x in X] 208 | fitvals = eval_all(X, args=args) 209 | cevals += len(fitvals) 210 | self.tell(X, fitvals) # all the work is done here 211 | for f in callback: 212 | f(self) 213 | self.disp(verb_disp) # disp does nothing if not overwritten 214 | 215 | # final output 216 | self._force_final_logging() 217 | 218 | if verb_disp: # do not print by default to allow silent verbosity 219 | self.disp(1) 220 | print('termination by', self.stop()) 221 | print('best f-value =', self.result[1]) 222 | print('solution =', self.result[0]) 223 | 224 | return self 225 | 226 | def 
_prepare_callback_list(self, callback): # helper function 227 | """return a list of callbacks including ``self.logger.add``. 228 | 229 | ``callback`` can be a `callable` or a `list` (or iterable) of 230 | callables. Otherwise a `ValueError` exception is raised. 231 | """ 232 | if callback is None: 233 | callback = [] 234 | if callable(callback): 235 | callback = [callback] 236 | try: 237 | callback = list(callback) + [self.logger.add] 238 | except AttributeError: 239 | pass 240 | try: 241 | for c in callback: 242 | if not callable(c): 243 | raise ValueError("""callback argument %s is not 244 | callable""" % str(c)) 245 | except TypeError: 246 | raise ValueError("""callback argument must be a `callable` or 247 | an iterable (e.g. a list) of callables, after some 248 | processing it was %s""" % str(callback)) 249 | return callback 250 | 251 | def _force_final_logging(self): # helper function 252 | """try force the logger to log NOW""" 253 | try: 254 | if not self.logger: 255 | return 256 | except AttributeError: 257 | return 258 | # the idea: modulo == 0 means never log, 1 or True means log now 259 | try: 260 | modulo = bool(self.logger.modulo) 261 | except AttributeError: 262 | modulo = True # could also be named force 263 | try: 264 | self.logger.add(self, modulo=modulo) 265 | except AttributeError: 266 | pass 267 | except TypeError: 268 | try: 269 | self.logger.add(self) 270 | except Exception as e: 271 | print(' The final call of the logger in' 272 | ' OOOptimizer._force_final_logging from' 273 | ' OOOptimizer.optimize did not succeed: %s' 274 | % str(e)) 275 | 276 | class StatisticalModelSamplerWithZeroMeanBaseClass(object): 277 | """yet versatile base class to replace a sampler namely in 278 | `CMAEvolutionStrategy` 279 | """ 280 | def __init__(self, std_vec, **kwargs): 281 | """pass the vector of initial standard deviations or dimension of 282 | the underlying sample space. 
283 | 284 | Ideally catch the case when `std_vec` is a scalar and then 285 | interpreted as dimension. 286 | """ 287 | try: 288 | dimension = len(std_vec) 289 | except TypeError: # std_vec has no len 290 | dimension = std_vec 291 | std_vec = dimension * [1] 292 | raise NotImplementedError 293 | 294 | def sample(self, number, update=None): 295 | """return list of i.i.d. samples. 296 | 297 | :param number: is the number of samples. 298 | :param update: controls a possibly lazy update of the sampler. 299 | """ 300 | raise NotImplementedError 301 | 302 | def update(self, vectors, weights): 303 | """``vectors`` is a list of samples, ``weights`` a corrsponding 304 | list of learning rates 305 | """ 306 | raise NotImplementedError 307 | 308 | def parameters(self, mueff=None, lam=None): 309 | """return `dict` with (default) parameters, e.g., `c1` and `cmu`. 310 | 311 | :See also: `RecombinationWeights`""" 312 | if (hasattr(self, '_mueff') and hasattr(self, '_lam') and 313 | (mueff == self._mueff or mueff is None) and 314 | (lam == self._lam or lam is None)): 315 | return self._parameters 316 | self._mueff = mueff 317 | lower_lam = 6 # for setting c1 318 | if lam is None: 319 | lam = lower_lam 320 | self._lam = lam 321 | # todo: put here rather generic formula with degrees of freedom 322 | # todo: replace these base class computations with the appropriate 323 | c1 = min((1, lam / lower_lam)) * 2 / ((self.dimension + 1.3)**2.0 + mueff) 324 | alpha = 2 325 | self._parameters = dict( 326 | c1=c1, 327 | cmu=min((1 - c1, 328 | # or alpha * (mueff - 0.9) with relative min and 329 | # max value of about 1: 0.4, 1.75: 1.5 330 | alpha * (0.25 + mueff - 2 + 1 / mueff) / 331 | ((self.dimension + 2)**2 + alpha * mueff / 2))) 332 | ) 333 | return self._parameters 334 | 335 | def norm(self, x): 336 | """return Mahalanobis norm of `x` w.r.t. 
the statistical model""" 337 | return sum(self.transform_inverse(x)**2)**0.5 338 | @property 339 | def condition_number(self): 340 | raise NotImplementedError 341 | @property 342 | def covariance_matrix(self): 343 | raise NotImplementedError 344 | @property 345 | def variances(self): 346 | """vector of coordinate-wise (marginal) variances""" 347 | raise NotImplementedError 348 | 349 | def transform(self, x): 350 | """transform ``x`` as implied from the distribution parameters""" 351 | raise NotImplementedError 352 | 353 | def transform_inverse(self, x): 354 | raise NotImplementedError 355 | 356 | def to_linear_transformation_inverse(self, reset=False): 357 | """return inverse of associated linear transformation""" 358 | raise NotImplementedError 359 | 360 | def to_linear_transformation(self, reset=False): 361 | """return associated linear transformation""" 362 | raise NotImplementedError 363 | 364 | def inverse_hessian_scalar_correction(self, mean, X, f): 365 | """return scalar correction ``alpha`` such that ``X`` and ``f`` 366 | fit to ``f(x) = (x-mean) (alpha * C)**-1 (x-mean)`` 367 | """ 368 | raise NotImplementedError 369 | 370 | def __imul__(self, factor): 371 | raise NotImplementedError 372 | 373 | class BaseDataLogger(object): 374 | """abstract base class for a data logger that can be used with an 375 | `OOOptimizer`. 376 | 377 | Details: attribute `modulo` is used in `OOOptimizer.optimize`. 
378 | """ 379 | 380 | def __init__(self): 381 | self.optim = None 382 | """object instance to be logging data from""" 383 | self._data = None 384 | """`dict` of logged data""" 385 | self.filename = "_BaseDataLogger_datadict.py" 386 | """file to save to or load from unless specified otherwise""" 387 | 388 | def register(self, optim, *args, **kwargs): 389 | """register an optimizer ``optim``, only needed if method `add` is 390 | called without passing the ``optim`` argument 391 | """ 392 | self.optim = optim 393 | return self 394 | 395 | def add(self, optim=None, more_data=None, **kwargs): 396 | """abstract method, add a "data point" from the state of ``optim`` 397 | into the logger. 398 | 399 | The argument ``optim`` can be omitted if ``optim`` was 400 | ``register`` ()-ed before, acts like an event handler 401 | """ 402 | raise NotImplementedError 403 | 404 | def disp(self, *args, **kwargs): 405 | """abstract method, display some data trace""" 406 | print('method BaseDataLogger.disp() not implemented, to be done in subclass ' + str(type(self))) 407 | 408 | def plot(self, *args, **kwargs): 409 | """abstract method, plot data""" 410 | print('method BaseDataLogger.plot() is not implemented, to be done in subclass ' + str(type(self))) 411 | 412 | def save(self, name=None): 413 | """save data to file `name` or `self.filename`""" 414 | with open(name or self.filename, 'w') as f: 415 | f.write(repr(self._data)) 416 | 417 | def load(self, name=None): 418 | """load data from file `name` or `self.filename`""" 419 | from ast import literal_eval 420 | with open(name or self.filename, 'r') as f: 421 | self._data = literal_eval(f.read()) 422 | return self 423 | @property 424 | def data(self): 425 | """logged data in a dictionary""" 426 | return self._data 427 | -------------------------------------------------------------------------------- /cma/recombination_weights.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 
2 | """`RecombinationWeights` is a list of recombination weights for the CMA-ES. 3 | 4 | The most delicate part is the correct setting of negative weights depending 5 | on learning rates to prevent negative definite matrices when using the 6 | weights in the covariance matrix update. 7 | 8 | The dependency chain is 9 | 10 | lambda -> weights -> mueff -> c1, cmu -> negative weights 11 | 12 | """ 13 | # https://gist.github.com/nikohansen/3eb4ef0790ff49276a7be3cdb46d84e9 14 | from __future__ import division, print_function 15 | import math 16 | 17 | class RecombinationWeights(list): 18 | """a list of decreasing (recombination) weight values. 19 | 20 | To be used in the update of the covariance matrix C in CMA-ES as 21 | ``w_i``:: 22 | 23 | C <- (1 - c1 - cmu * sum w_i) C + c1 ... + cmu sum w_i y_i y_i^T 24 | 25 | After calling `finalize_negative_weights`, the weights 26 | ``w_i`` let ``1 - c1 - cmu * sum w_i = 1`` and guaranty positive 27 | definiteness of C if ``y_i^T C^-1 y_i <= dimension`` for all 28 | ``w_i < 0``. 29 | 30 | Class attributes/properties: 31 | 32 | - ``lambda_``: number of weights, alias for ``len(self)`` 33 | - ``mu``: number of strictly positive weights, i.e. 34 | ``sum([wi > 0 for wi in self])`` 35 | - ``mueff``: variance effective number of positive weights, i.e. 36 | ``1 / sum([self[i]**2 for i in range(self.mu)])`` where 37 | ``1 == sum([self[i] for i in range(self.mu)])**2`` 38 | - `mueffminus`: variance effective number of negative weights 39 | - `positive_weights`: `np.array` of the strictly positive weights 40 | - ``finalized``: `True` if class instance is ready to use 41 | 42 | Class methods not inherited from `list`: 43 | 44 | - `finalize_negative_weights`: main method 45 | - `zero_negative_weights`: set negative weights to zero, leads to 46 | ``finalized`` to be `True`. 
47 | - `set_attributes_from_weights`: useful when weight values are 48 | "manually" changed, removed or inserted 49 | - `asarray`: alias for ``np.asarray(self)`` 50 | - `do_asserts`: check consistency of weight values, passes also when 51 | not yet ``finalized`` 52 | 53 | Usage: 54 | 55 | >>> # from recombination_weights import RecombinationWeights 56 | >>> from cma.recombination_weights import RecombinationWeights 57 | >>> dimension, popsize = 5, 7 58 | >>> weights = RecombinationWeights(popsize) 59 | >>> c1 = 2. / (dimension + 1)**2 # caveat: __future___ division 60 | >>> cmu = weights.mueff / (weights.mueff + dimension**2) 61 | >>> weights.finalize_negative_weights(dimension, c1, cmu) 62 | >>> print('weights = [%s]' % ', '.join("%.2f" % w for w in weights)) 63 | weights = [0.59, 0.29, 0.12, 0.00, -0.31, -0.57, -0.79] 64 | >>> print("sum=%.2f, c1+cmu*sum=%.2f" % (sum(weights), 65 | ... c1 + cmu * sum(weights))) 66 | sum=-0.67, c1+cmu*sum=0.00 67 | >>> print('mueff=%.1f, mueffminus=%.1f, mueffall=%.1f' % ( 68 | ... weights.mueff, 69 | ... weights.mueffminus, 70 | ... sum(abs(w) for w in weights)**2 / 71 | ... sum(w**2 for w in weights))) 72 | mueff=2.3, mueffminus=2.7, mueffall=4.8 73 | >>> weights = RecombinationWeights(popsize) 74 | >>> print("sum=%.2f, mu=%d, sumpos=%.2f, sumneg=%.2f" % ( 75 | ... sum(weights), 76 | ... weights.mu, 77 | ... sum(weights[:weights.mu]), 78 | ... sum(weights[weights.mu:]))) 79 | sum=0.00, mu=3, sumpos=1.00, sumneg=-1.00 80 | >>> print('weights = [%s]' % ', '.join("%.2f" % w for w in weights)) 81 | weights = [0.59, 0.29, 0.12, 0.00, -0.19, -0.34, -0.47] 82 | >>> weights = RecombinationWeights(21) 83 | >>> weights.finalize_negative_weights(3, 0.081, 0.28) 84 | >>> weights.insert(weights.mu, 0) # add zero weight in the middle 85 | >>> weights = weights.set_attributes_from_weights() # change lambda_ 86 | >>> assert weights.lambda_ == 22 87 | >>> print("sum=%.2f, mu=%d, sumpos=%.2f" % 88 | ... 
(sum(weights), weights.mu, sum(weights[:weights.mu]))) 89 | sum=0.24, mu=10, sumpos=1.00 90 | >>> print('weights = [%s]%%' % ', '.join(["%.1f" % (100*weights[i]) 91 | ... for i in range(0, 22, 5)])) 92 | weights = [27.0, 6.8, 0.0, -6.1, -11.7]% 93 | >>> weights.zero_negative_weights() # doctest:+ELLIPSIS 94 | [0.270... 95 | >>> "%.2f, %.2f" % (sum(weights), sum(weights[weights.mu:])) 96 | '1.00, 0.00' 97 | >>> mu = int(weights.mu / 2) 98 | >>> for i in range(len(weights)): 99 | ... weights[i] = 1. / mu if i < mu else 0 100 | >>> weights = weights.set_attributes_from_weights() 101 | >>> 5 * "%.1f " % (sum(w for w in weights if w > 0), 102 | ... sum(w for w in weights if w < 0), 103 | ... weights.mu, 104 | ... weights.mueff, 105 | ... weights.mueffminus) 106 | '1.0 0.0 5.0 5.0 0.0 ' 107 | 108 | The optimal weights on the sphere and other functions are closer 109 | to exponent 0.75: 110 | 111 | >>> for expo, w in [(expo, RecombinationWeights(5, exponent=expo)) 112 | ... for expo in [1, 0.9, 0.8, 0.7, 0.6, 0.5]]: 113 | ... print(7 * "%.2f " % tuple([expo, w.mueff] + w)) 114 | 1.00 1.65 0.73 0.27 0.00 -0.36 -0.64 115 | 0.90 1.70 0.71 0.29 0.00 -0.37 -0.63 116 | 0.80 1.75 0.69 0.31 0.00 -0.39 -0.61 117 | 0.70 1.80 0.67 0.33 0.00 -0.40 -0.60 118 | 0.60 1.84 0.65 0.35 0.00 -0.41 -0.59 119 | 0.50 1.89 0.62 0.38 0.00 -0.43 -0.57 120 | 121 | >>> for lam in [8, 8**2, 8**3, 8**4]: 122 | ... if lam == 8: 123 | ... print(" lam expo mueff w[i] / w[i](1)") 124 | ... print(" /mu(1) 1 2 3 4 5 6 7 8") 125 | ... w1 = RecombinationWeights(lam, exponent=1) 126 | ... for expo, w in [(expo, RecombinationWeights(lam, exponent=expo)) 127 | ... for expo in [1, 0.8, 0.6]]: 128 | ... 
print('%4d ' % lam + 10 * "%.2f " % tuple([expo, w.mueff / w1.mueff] + [w[i] / w1[i] for i in range(8)])) 129 | lam expo mueff w[i] / w[i](1) 130 | /mu(1) 1 2 3 4 5 6 7 8 131 | 8 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 132 | 8 0.80 1.11 0.90 1.02 1.17 1.50 1.30 1.07 0.98 0.93 133 | 8 0.60 1.24 0.80 1.02 1.35 2.21 1.68 1.13 0.95 0.85 134 | 64 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 135 | 64 0.80 1.17 0.82 0.86 0.88 0.91 0.93 0.95 0.97 0.98 136 | 64 0.60 1.36 0.65 0.72 0.76 0.80 0.84 0.87 0.91 0.94 137 | 512 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 138 | 512 0.80 1.20 0.76 0.78 0.79 0.80 0.81 0.82 0.83 0.83 139 | 512 0.60 1.42 0.56 0.59 0.61 0.63 0.64 0.65 0.67 0.68 140 | 4096 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 141 | 4096 0.80 1.21 0.71 0.73 0.74 0.74 0.75 0.75 0.76 0.76 142 | 4096 0.60 1.44 0.50 0.52 0.53 0.54 0.55 0.55 0.56 0.56 143 | 144 | Reference: Hansen 2016, arXiv:1604.00772. 145 | """ 146 | def __init__(self, len_, exponent=1): 147 | """return recombination weights `list`, post condition is 148 | ``sum(self) == 0 and sum(self.positive_weights) == 1``. 149 | 150 | Positive and negative weights sum to 1 and -1, respectively. 151 | The number of positive weights, ``self.mu``, is about 152 | ``len_/2``. Weights are strictly decreasing. 153 | 154 | `finalize_negative_weights` (...) or `zero_negative_weights` () 155 | should be called to finalize the negative weights. 156 | 157 | :param `len_`: AKA ``lambda`` is the number of weights, see 158 | attribute `lambda_` which is an alias for ``len(self)``. 159 | Alternatively, a list of "raw" weights can be provided. 160 | 161 | """ 162 | weights = len_ 163 | self.exponent = exponent # for the record 164 | if exponent is None: 165 | self.exponent = 1 # shall become 4/5 or 3/4? 
166 | try: 167 | len_ = len(weights) 168 | except TypeError: 169 | try: # iterator without len 170 | len_ = len(list(weights)) 171 | except TypeError: # create from scratch 172 | def signed_power(x, expo): 173 | if expo == 1: return x 174 | s = (x != 0) * (-1 if x < 0 else 1) 175 | return s * math.fabs(x)**expo 176 | weights = [signed_power(math.log((len_ + 1) / 2.) - math.log(i), self.exponent) 177 | for i in range(1, len_ + 1)] # raw shape 178 | if len_ < 2: 179 | raise ValueError("number of weights must be >=2, was %d" 180 | % (len_)) 181 | self.debug = False 182 | 183 | # self[:] = weights # should do, or 184 | # super(RecombinationWeights, self).__init__(weights) 185 | list.__init__(self, weights) 186 | 187 | self.set_attributes_from_weights(do_asserts=False) 188 | sum_neg = sum(self[self.mu:]) 189 | if sum_neg != 0: 190 | for i in range(self.mu, len(self)): 191 | self[i] /= -sum_neg 192 | self.do_asserts() 193 | self.finalized = False 194 | 195 | def set_attributes_from_weights(self, weights=None, do_asserts=True): 196 | """make the class attribute values consistent with weights, in 197 | case after (re-)setting the weights from input parameter ``weights``, 198 | post condition is also ``sum(self.postive_weights) == 1``. 199 | 200 | This method allows to set or change the weight list manually, 201 | e.g. like ``weights[:] = new_list`` or using the `pop`, 202 | `insert` etc. generic `list` methods to change the list. 203 | Currently, weights must be non-increasing and the first weight 204 | must be strictly positive and the last weight not larger than 205 | zero. Then all ``weights`` are normalized such that the 206 | positive weights sum to one. 
207 | """ 208 | if weights is not None: 209 | if not weights[0] > 0: 210 | raise ValueError( 211 | "the first weight must be >0 but was %f" % weights[0]) 212 | if weights[-1] > 0: 213 | raise ValueError( 214 | "the last weight must be <=0 but was %f" % 215 | weights[-1]) 216 | self[:] = weights 217 | weights = self 218 | assert all(weights[i] >= weights[i+1] 219 | for i in range(len(weights) - 1)) 220 | self.mu = sum(w > 0 for w in weights) 221 | spos = sum(weights[:self.mu]) 222 | assert spos > 0 223 | for i in range(len(self)): 224 | self[i] /= spos 225 | # variance-effectiveness of sum^mu w_i x_i 226 | self.mueff = 1**2 / sum(w**2 for w in 227 | weights[:self.mu]) 228 | sneg = sum(weights[self.mu:]) 229 | assert (sneg - sum(w for w in weights if w < 0))**2 < 1e-11 230 | not do_asserts or self.do_asserts() 231 | return self 232 | 233 | def finalize_negative_weights(self, dimension, c1, cmu, pos_def=True): 234 | """finalize negative weights using ``dimension`` and learning 235 | rates ``c1`` and ``cmu``. 236 | 237 | This is a rather intricate method which makes this class 238 | useful. The negative weights are scaled to achieve 239 | in this order: 240 | 241 | 1. zero decay, i.e. ``c1 + cmu * sum w == 0``, 242 | 2. a learning rate respecting mueff, i.e. ``sum |w|^- / sum |w|^+ 243 | <= 1 + 2 * self.mueffminus / (self.mueff + 2)``, 244 | 3. if `pos_def` guaranty positive definiteness when sum w^+ = 1 245 | and all negative input vectors used later have at most their 246 | dimension as squared Mahalanobis norm. This is accomplished by 247 | guarantying ``(dimension-1) * cmu * sum |w|^- < 1 - c1 - cmu`` 248 | via setting ``sum |w|^- <= (1 - c1 -cmu) / dimension / cmu``. 249 | 250 | The latter two conditions do not change the weights with default 251 | population size. 252 | 253 | Details: 254 | 255 | - To guaranty 3., the input vectors associated to negative 256 | weights must obey ||.||^2 <= dimension in Mahalanobis norm. 
257 | - The third argument, ``cmu``, usually depends on the 258 | (raw) weights, in particular it depends on ``self.mueff``. 259 | For this reason the calling syntax 260 | ``weights = RecombinationWeights(...).finalize_negative_weights(...)`` 261 | is not supported. 262 | 263 | """ 264 | if dimension <= 0: 265 | raise ValueError("dimension must be larger than zero, was " + 266 | str(dimension)) 267 | self._c1 = c1 # for the record 268 | self._cmu = cmu 269 | 270 | if self[-1] < 0: 271 | if cmu > 0: 272 | if c1 > 10 * cmu: 273 | print("""WARNING: c1/cmu = %f/%f seems to assume a 274 | too large value for negative weights setting""" 275 | % (c1, cmu)) 276 | self._negative_weights_set_sum(1 + c1 / cmu) 277 | if pos_def: 278 | self._negative_weights_limit_sum((1 - c1 - cmu) / cmu 279 | / dimension) 280 | self._negative_weights_limit_sum(1 + 2 * self.mueffminus 281 | / (self.mueff + 2)) 282 | self.do_asserts() 283 | self.finalized = True 284 | 285 | if self.debug: 286 | print("sum w = %.2f (final)" % sum(self)) 287 | 288 | def zero_negative_weights(self): 289 | """finalize by setting all negative weights to zero""" 290 | for k in range(len(self)): 291 | self[k] *= 0 if self[k] < 0 else 1 292 | self.finalized = True 293 | return self 294 | 295 | def _negative_weights_set_sum(self, value): 296 | """set sum of negative weights to ``-abs(value)`` 297 | 298 | Precondition: the last weight must no be greater than zero. 299 | 300 | Details: if no negative weight exists, all zero weights with index 301 | lambda / 2 or greater become uniformely negative. 
302 | """ 303 | weights = self # simpler to change to data attribute and nicer to read 304 | value = abs(value) # simplify code, prevent erroneous assertion error 305 | assert weights[self.mu] <= 0 306 | if not weights[-1] < 0: 307 | # breaks if mu == lambda 308 | # we could also just return here 309 | # return 310 | istart = max((self.mu, int(self.lambda_ / 2))) 311 | for i in range(istart, self.lambda_): 312 | weights[i] = -value / (self.lambda_ - istart) 313 | factor = abs(value / sum(weights[self.mu:])) 314 | for i in range(self.mu, self.lambda_): 315 | weights[i] *= factor 316 | assert 1 - value - 1e-5 < sum(weights) < 1 - value + 1e-5 317 | if self.debug: 318 | print("sum w = %.2f, sum w^- = %.2f" % 319 | (sum(weights), -sum(weights[self.mu:]))) 320 | 321 | def _negative_weights_limit_sum(self, value): 322 | """lower bound the sum of negative weights to ``-abs(value)``. 323 | """ 324 | weights = self # simpler to change to data attribute and nicer to read 325 | value = abs(value) # simplify code, prevent erroneous assertion error 326 | if sum(weights[self.mu:]) >= -value: # nothing to limit 327 | return # needed when sum is zero 328 | assert weights[-1] < 0 and weights[self.mu] <= 0 329 | factor = abs(value / sum(weights[self.mu:])) 330 | if factor < 1: 331 | for i in range(self.mu, self.lambda_): 332 | weights[i] *= factor 333 | if self.debug: 334 | print("sum w = %.2f (with correction %.2f)" % 335 | (sum(weights), value)) 336 | assert sum(weights) + 1e-5 >= 1 - value 337 | 338 | def do_asserts(self): 339 | """assert consistency. 
340 | 341 | Assert: 342 | 343 | - attribute values of ``lambda_, mu, mueff, mueffminus`` 344 | - value of first and last weight 345 | - monotonicity of weights 346 | - sum of positive weights to be one 347 | 348 | """ 349 | weights = self 350 | assert 1 >= weights[0] > 0 351 | assert weights[-1] <= 0 352 | assert len(weights) == self.lambda_ 353 | assert all(weights[i] >= weights[i+1] 354 | for i in range(len(weights) - 1)) # monotony 355 | assert self.mu > 0 # needed for next assert 356 | assert weights[self.mu-1] > 0 >= weights[self.mu] 357 | assert 0.999 < sum(w for w in weights[:self.mu]) < 1.001 358 | assert (self.mueff / 1.001 < 359 | sum(weights[:self.mu])**2 / sum(w**2 for w in weights[:self.mu]) < 360 | 1.001 * self.mueff) 361 | assert (self.mueffminus == 0 == sum(weights[self.mu:]) or 362 | self.mueffminus / 1.001 < 363 | sum(weights[self.mu:])**2 / sum(w**2 for w in weights[self.mu:]) < 364 | 1.001 * self.mueffminus) 365 | 366 | @property 367 | def lambda_(self): 368 | """alias for ``len(self)``""" 369 | return len(self) 370 | @property 371 | def mueffminus(self): 372 | weights = self 373 | sneg = sum(weights[self.mu:]) 374 | assert (sneg - sum(w for w in weights if w < 0))**2 < 1e-11 375 | return (0 if sneg == 0 else 376 | sneg**2 / sum(w**2 for w in weights[self.mu:])) 377 | @property 378 | def positive_weights(self): 379 | """all (strictly) positive weights as ``np.array``. 380 | 381 | Useful to implement recombination for the new mean vector. 
382 | """ 383 | try: 384 | from numpy import asarray 385 | return asarray(self[:self.mu]) 386 | except: 387 | return self[:self.mu] 388 | @property 389 | def asarray(self): 390 | """return weights as numpy array""" 391 | from numpy import asarray 392 | return asarray(self) 393 | -------------------------------------------------------------------------------- /cma/s.py: -------------------------------------------------------------------------------- 1 | """versatile shortcuts for quick typing in an (i)python shell or even 2 | ``from cma.s import *`` in interactive sessions. 3 | 4 | Provides various aliases from within the `cma` package, to be reached like 5 | ``cma.s....`` 6 | 7 | Don't use for stable code. 8 | """ 9 | import warnings as _warnings 10 | try: from matplotlib import pyplot as _pyplot # like this it doesn't show up in the interface 11 | except ImportError: 12 | _pyplot = None 13 | _warnings.warn('Could not import matplotlib.pyplot, therefore' 14 | ' ``cma.plot()`` etc. is not available') 15 | # from . import fitness_functions as ff 16 | from . import evolution_strategy as es 17 | from . import fitness_transformations as ft 18 | from . import transformations as tf 19 | from . 
import constraints_handler as ch 20 | from .utilities import utils 21 | from .evolution_strategy import CMAEvolutionStrategy as CMAES 22 | from .utilities.utils import pprint 23 | from .utilities.math import Mh 24 | # from .fitness_functions import elli as felli 25 | 26 | if _pyplot: 27 | def figshow(): 28 | """`pyplot.show` to make a plotted figure show up""" 29 | # is_interactive = matplotlib.is_interactive() 30 | 31 | _pyplot.ion() 32 | _pyplot.show() 33 | # if we call now matplotlib.interactive(True), the console is 34 | # blocked 35 | figsave = _pyplot.savefig 36 | 37 | if 11 < 3: 38 | try: 39 | from matplotlib.pyplot import savefig as figsave, close as figclose, ion as figion 40 | except: 41 | figsave, figclose, figshow = 3 * ['not available'] 42 | _warnings.warn('Could not import matplotlib.pyplot, therefore ``cma.plot()``' 43 | ' etc. is not available') 44 | -------------------------------------------------------------------------------- /cma/sigma_adaptation.py: -------------------------------------------------------------------------------- 1 | """step-size adaptation classes, currently tightly linked to CMA, 2 | because `hsig` is computed in the base class 3 | """ 4 | from __future__ import absolute_import, division, print_function #, unicode_literals, with_statement 5 | import numpy as np 6 | from numpy import square as _square, sqrt as _sqrt 7 | from .utilities import utils 8 | from .utilities.math import Mh 9 | def _norm(x): return np.sqrt(np.sum(np.square(x))) 10 | del absolute_import, division, print_function #, unicode_literals, with_statement 11 | 12 | class CMAAdaptSigmaBase(object): 13 | """step-size adaptation base class, implement `hsig` (for stalling 14 | distribution update) functionality via an isotropic evolution path. 15 | 16 | Details: `hsig` or `_update_ps` must be called before the sampling 17 | distribution is changed. `_update_ps` depends heavily on 18 | `cma.CMAEvolutionStrategy`. 
19 | """ 20 | def __init__(self, *args, **kwargs): 21 | self.is_initialized_base = False 22 | self._ps_updated_iteration = -1 23 | self.delta = 1 24 | "cumulated effect of adaptation" 25 | def initialize_base(self, es): 26 | """set parameters and state variable based on dimension, 27 | mueff and possibly further options. 28 | 29 | """ 30 | ## meta_parameters.cs_exponent == 1.0 31 | b = 1.0 32 | ## meta_parameters.cs_multiplier == 1.0 33 | self.cs = 1.0 * (es.sp.weights.mueff + 2)**b / (es.N**b + (es.sp.weights.mueff + 3)**b) 34 | self.ps = np.zeros(es.N) 35 | self.is_initialized_base = True 36 | return self 37 | def _update_ps(self, es): 38 | """update the isotropic evolution path. 39 | 40 | Using ``es`` attributes ``mean``, ``mean_old``, ``sigma``, 41 | ``sigma_vec``, ``sp.weights.mueff``, ``cp.cmean`` and 42 | ``sm.transform_inverse``. 43 | 44 | :type es: CMAEvolutionStrategy 45 | """ 46 | if not self.is_initialized_base: 47 | self.initialize_base(es) 48 | if self._ps_updated_iteration == es.countiter: 49 | return 50 | try: 51 | if es.countiter <= es.sm.itereigenupdated: 52 | # es.B and es.D must/should be those from the last iteration 53 | utils.print_warning('distribution transformation (B and D) have been updated before ps could be computed', 54 | '_update_ps', 'CMAAdaptSigmaBase', verbose=es.opts['verbose']) 55 | except AttributeError: 56 | pass 57 | z = es.sm.transform_inverse((es.mean - es.mean_old) / es.sigma_vec.scaling) 58 | # assert Mh.vequals_approximately(z, np.dot(es.B, (1. 
/ es.D) * 59 | # np.dot(es.B.T, (es.mean - es.mean_old) / es.sigma_vec.scaling))) 60 | z *= es.sp.weights.mueff**0.5 / es.sigma / es.sp.cmean 61 | self.ps = (1 - self.cs) * self.ps + (self.cs * (2 - self.cs))**0.5 * z 62 | self._ps_updated_iteration = es.countiter 63 | def hsig(self, es): 64 | """return "OK-signal" for rank-one update, `True` (OK) or `False` 65 | (stall rank-one update), based on the length of an evolution path 66 | 67 | """ 68 | self._update_ps(es) 69 | if self.ps is None: 70 | return True 71 | squared_sum = np.sum(self.ps**2) / (1 - (1 - self.cs)**(2 * es.countiter)) 72 | # correction with self.countiter seems not necessary, 73 | # as pc also starts with zero 74 | return squared_sum / es.N - 1 < 1 + 4. / (es.N + 1) 75 | 76 | def update2(self, es, **kwargs): 77 | """return sigma change factor and update self.delta. 78 | 79 | ``self.delta == sigma/sigma0`` accumulates all past changes 80 | starting from `1.0`. 81 | 82 | Unlike `update`, `update2` is not supposed to change attributes 83 | in `es`, specifically it should not change `es.sigma`. 84 | """ 85 | self._update_ps(es) 86 | raise NotImplementedError('must be implemented in a derived class') 87 | 88 | def update(self, es, **kwargs): 89 | """update ``es.sigma`` 90 | 91 | :param es: `CMAEvolutionStrategy` class instance 92 | :param kwargs: whatever else is needed to update ``es.sigma``, 93 | which should be none. 94 | """ 95 | self._update_ps(es) 96 | raise NotImplementedError('must be implemented in a derived class') 97 | def check_consistency(self, es): 98 | """make consistency checks with a `CMAEvolutionStrategy` instance 99 | as input 100 | """ 101 | class CMAAdaptSigmaNone(CMAAdaptSigmaBase): 102 | """constant step-size sigma""" 103 | def update(self, es, **kwargs): 104 | """no update, ``es.sigma`` remains constant. 105 | """ 106 | pass 107 | class CMAAdaptSigmaDistanceProportional(CMAAdaptSigmaBase): 108 | """artificial setting of ``sigma`` for test purposes, e.g. 
109 | to simulate optimal progress rates. 110 | 111 | """ 112 | def __init__(self, coefficient=1.2, **kwargs): 113 | """pass multiplier for normalized step-size""" 114 | super(CMAAdaptSigmaDistanceProportional, self).__init__() # base class provides method hsig() 115 | self.coefficient = coefficient 116 | self.is_initialized = True 117 | def update(self, es, **kwargs): 118 | """need attributes ``N``, ``sp.weights.mueff``, ``mean``, 119 | ``sp.cmean`` of input parameter ``es`` 120 | """ 121 | es.sigma = self.coefficient * es.sp.weights.mueff * _norm(es.mean) / es.N / es.sp.cmean 122 | class CMAAdaptSigmaCSA(CMAAdaptSigmaBase): 123 | """CSA cumulative step-size adaptation AKA path length control. 124 | 125 | As of 2017, CSA is considered as the default step-size control method 126 | within CMA-ES. 127 | """ 128 | def __init__(self, **kwargs): 129 | """postpone initialization to a method call where dimension and mueff should be known. 130 | 131 | """ 132 | self.is_initialized = False 133 | self.delta = 1 134 | def initialize(self, es): 135 | """set parameters and state variable based on dimension, 136 | mueff and possibly further options. 137 | 138 | """ 139 | self.disregard_length_setting = True if es.opts['CSA_disregard_length'] else False 140 | if es.opts['CSA_clip_length_value'] is not None: 141 | try: 142 | if len(es.opts['CSA_clip_length_value']) == 0: 143 | es.opts['CSA_clip_length_value'] = [-np.Inf, np.Inf] 144 | elif len(es.opts['CSA_clip_length_value']) == 1: 145 | es.opts['CSA_clip_length_value'] = [-np.Inf, es.opts['CSA_clip_length_value'][0]] 146 | elif len(es.opts['CSA_clip_length_value']) == 2: 147 | es.opts['CSA_clip_length_value'] = np.sort(es.opts['CSA_clip_length_value']) 148 | else: 149 | raise ValueError('option CSA_clip_length_value should be a number of len(.) in [1,2]') 150 | except TypeError: # len(...) 
failed 151 | es.opts['CSA_clip_length_value'] = [-np.Inf, es.opts['CSA_clip_length_value']] 152 | es.opts['CSA_clip_length_value'] = list(np.sort(es.opts['CSA_clip_length_value'])) 153 | if es.opts['CSA_clip_length_value'][0] > 0 or es.opts['CSA_clip_length_value'][1] < 0: 154 | raise ValueError('option CSA_clip_length_value must be a single positive or a negative and a positive number') 155 | ## meta_parameters.cs_exponent == 1.0 156 | b = 1.0 157 | ## meta_parameters.cs_multiplier == 1.0 158 | self.cs = 1.0 * (es.sp.weights.mueff + 2)**b / (es.N**b + (es.sp.weights.mueff + 3)**b) 159 | 160 | self.damps = es.opts['CSA_dampfac'] * (0.5 + 161 | 0.5 * min([1, (es.sp.lam_mirr / (0.159 * es.sp.popsize) - 1)**2])**1 + 162 | 2 * max([0, ((es.sp.weights.mueff - 1) / (es.N + 1))**es.opts['CSA_damp_mueff_exponent'] - 1]) + 163 | self.cs 164 | ) 165 | self.max_delta_log_sigma = 1 # in symmetric use (strict lower bound is -cs/damps anyway) 166 | 167 | if self.disregard_length_setting: 168 | es.opts['CSA_clip_length_value'] = [0, 0] 169 | ## meta_parameters.cs_exponent == 1.0 170 | b = 1.0 * 0.5 171 | ## meta_parameters.cs_multiplier == 1.0 172 | self.cs = 1.0 * (es.sp.weights.mueff + 1)**b / (es.N**b + 2 * es.sp.weights.mueff**b) 173 | self.damps = es.opts['CSA_dampfac'] * 1 # * (1.1 - 1/(es.N+1)**0.5) 174 | if es.opts['verbose'] > 1: 175 | print('CMAAdaptSigmaCSA Parameters: ') 176 | for k, v in self.__dict__.items(): 177 | print(' ', k, ':', v) 178 | self.ps = np.zeros(es.N) 179 | self._ps_updated_iteration = -1 180 | self.is_initialized = True 181 | def _update_ps(self, es): 182 | """update path with isotropic delta mean, possibly clipped. 183 | 184 | From input argument `es`, the attributes isotropic_mean_shift, 185 | opts['CSA_clip_length_value'], and N are used. 
186 | opts['CSA_clip_length_value'] can be a single value, the upper 187 | bound parameter, such that:: 188 | 189 | max_len = sqrt(N) + opts['CSA_clip_length_value'] * N / (N+2) 190 | 191 | or a list with lower and upper bound parameters. 192 | """ 193 | if not self.is_initialized: 194 | self.initialize(es) 195 | if self._ps_updated_iteration == es.countiter: 196 | return 197 | z = es.isotropic_mean_shift 198 | if es.opts['CSA_clip_length_value'] is not None: 199 | vals = es.opts['CSA_clip_length_value'] 200 | try: len(vals) 201 | except TypeError: vals = [-np.inf, vals] 202 | if vals[0] > 0 or vals[1] < 0: 203 | raise ValueError( 204 | """value(s) for option 'CSA_clip_length_value' = %s 205 | not allowed""" % str(es.opts['CSA_clip_length_value'])) 206 | min_len = es.N**0.5 + vals[0] * es.N / (es.N + 2) 207 | max_len = es.N**0.5 + vals[1] * es.N / (es.N + 2) 208 | act_len = _norm(z) 209 | new_len = Mh.minmax(act_len, min_len, max_len) 210 | if new_len != act_len: 211 | z *= new_len / act_len 212 | # z *= (es.N / sum(z**2))**0.5 # ==> sum(z**2) == es.N 213 | # z *= es.const.chiN / sum(z**2)**0.5 214 | self.ps = (1 - self.cs) * self.ps + _sqrt(self.cs * (2 - self.cs)) * z 215 | self._ps_updated_iteration = es.countiter 216 | def update2(self, es, **kwargs): 217 | """call ``self._update_ps(es)`` and update self.delta. 218 | 219 | Return change factor of self.delta. 220 | 221 | From input `es`, either attribute N or const.chiN is used. 222 | """ 223 | self._update_ps(es) # caveat: if es.B or es.D are already updated and ps is not, this goes wrong! 
224 | p = self.ps 225 | if 'pc for ps' in es.opts['vv']: 226 | # was: es.D**-1 * np.dot(es.B.T, es.pc) 227 | p = es.sm.transform_inverse(es.pc) 228 | if es.opts['CSA_squared']: 229 | s = (sum(_square(p)) / es.N - 1) / 2 230 | # sum(self.ps**2) / es.N has mean 1 and std sqrt(2/N) and is skewed 231 | # divided by 2 to have the derivative d/dx (x**2 / N - 1) for x**2=N equal to 1 232 | else: 233 | s = _norm(p) / es.const.chiN - 1 234 | s *= self.cs / self.damps 235 | s_clipped = Mh.minmax(s, -self.max_delta_log_sigma, self.max_delta_log_sigma) 236 | # "error" handling 237 | if s_clipped != s: 238 | utils.print_warning('sigma change np.exp(' + str(s) + ') = ' + str(np.exp(s)) + 239 | ' clipped to np.exp(+-' + str(self.max_delta_log_sigma) + ')', 240 | 'update', 241 | 'CMAAdaptSigmaCSA', 242 | es.countiter, es.opts['verbose']) 243 | self.delta *= np.exp(s_clipped) 244 | return np.exp(s_clipped) 245 | def update(self, es, **kwargs): 246 | """call ``self._update_ps(es)`` and update ``es.sigma``. 247 | 248 | Legacy method replaced by `update2`. 
249 | """ 250 | es.sigma *= self.update2(es, **kwargs) 251 | if 11 < 3: 252 | # derandomized MSR = natural gradient descent using mean(z**2) instead of mu*mean(z)**2 253 | fit = kwargs['fit'] # == es.fit 254 | slengths = np.array([sum(z**2) for z in es.arz[fit.idx[:es.sp.weights.mu]]]) 255 | # print lengths[0::int(es.sp.weights.mu/5)] 256 | es.sigma *= np.exp(np.dot(es.sp.weights, slengths / es.N - 1))**(2 / (es.N + 1)) 257 | if 11 < 3: 258 | es.more_to_write.append(10**((sum(self.ps**2) / es.N / 2 - 1 / 2 if es.opts['CSA_squared'] else _norm(self.ps) / es.const.chiN - 1))) 259 | es.more_to_write.append(10**(-3.5 + sum(self.ps**2) / es.N / 2 - _norm(self.ps) / es.const.chiN)) 260 | # es.more_to_write.append(10**(-3 + sum(es.arz[es.fit.idx[0]]**2) / es.N)) 261 | 262 | class CMAAdaptSigmaMedianImprovement(CMAAdaptSigmaBase): 263 | """Compares median fitness to the 27%tile fitness of the 264 | previous iteration, see Ait ElHara et al, GECCO 2013. 265 | 266 | >>> import cma 267 | >>> es = cma.CMAEvolutionStrategy(3 * [1], 1, 268 | ... {'AdaptSigma':cma.sigma_adaptation.CMAAdaptSigmaMedianImprovement, 269 | ... 
'verbose': -9}) 270 | >>> assert es.optimize(cma.ff.elli).result[1] < 1e-9 271 | >>> assert es.result[2] < 2000 272 | 273 | """ 274 | def __init__(self, **kwargs): 275 | CMAAdaptSigmaBase.__init__(self) # base class provides method hsig() 276 | # super(CMAAdaptSigmaMedianImprovement, self).__init__() 277 | def initialize(self, es): 278 | """late initialization using attributes ``N`` and ``popsize``""" 279 | r = es.sp.weights.mueff / es.popsize 280 | self.index_to_compare = 0.5 * (r**0.5 + 2.0 * (1 - r**0.5) / np.log(es.N + 9)**2) * (es.popsize) # TODO 281 | self.index_to_compare = 0.30 * es.popsize # TODO 282 | self.damp = 2 - 2 / es.N # sign-rule: 2 283 | self.c = 0.3 # sign-rule needs <= 0.3 284 | self.s = 0 # averaged statistics, usually between -1 and +1 285 | def update(self, es, **kwargs): 286 | if es.countiter < 2: 287 | self.initialize(es) 288 | self.fit = es.fit.fit 289 | else: 290 | ft1, ft2 = self.fit[int(self.index_to_compare)], self.fit[int(np.ceil(self.index_to_compare))] 291 | ftt1, ftt2 = es.fit.fit[(es.popsize - 1) // 2], es.fit.fit[int(np.ceil((es.popsize - 1) / 2))] 292 | pt2 = self.index_to_compare - int(self.index_to_compare) 293 | # ptt2 = (es.popsize - 1) / 2 - (es.popsize - 1) // 2 # not in use 294 | s = 0 295 | if 1 < 3: 296 | s += pt2 * sum(es.fit.fit <= self.fit[int(np.ceil(self.index_to_compare))]) 297 | s += (1 - pt2) * sum(es.fit.fit < self.fit[int(self.index_to_compare)]) 298 | s -= es.popsize / 2. 299 | s *= 2. 
/ es.popsize # the range was popsize, is 2 300 | elif 11 < 3: # compare ft with median of ftt 301 | s += self.index_to_compare - sum(self.fit <= es.fit.fit[es.popsize // 2]) 302 | s *= 2 / es.popsize # the range was popsize, is 2 303 | else: # compare ftt j-index of ft 304 | s += (1 - pt2) * np.sign(ft1 - ftt1) 305 | s += pt2 * np.sign(ft2 - ftt1) 306 | self.s = (1 - self.c) * self.s + self.c * s 307 | es.sigma *= np.exp(self.s / self.damp) 308 | # es.more_to_write.append(10**(self.s)) 309 | 310 | #es.more_to_write.append(10**((2 / es.popsize) * (sum(es.fit.fit < self.fit[int(self.index_to_compare)]) - (es.popsize + 1) / 2))) 311 | # # es.more_to_write.append(10**(self.index_to_compare - sum(self.fit <= es.fit.fit[es.popsize // 2]))) 312 | # # es.more_to_write.append(10**(np.sign(self.fit[int(self.index_to_compare)] - es.fit.fit[es.popsize // 2]))) 313 | if 11 < 3: 314 | import scipy.stats.stats as stats 315 | zkendall = stats.kendalltau(list(es.fit.fit) + list(self.fit), 316 | len(es.fit.fit) * [0] + len(self.fit) * [1])[0] 317 | es.more_to_write.append(10**zkendall) 318 | self.fit = es.fit.fit 319 | class CMAAdaptSigmaTPA(CMAAdaptSigmaBase): 320 | """two point adaptation for step-size sigma. 321 | 322 | Relies on a specific sampling of the first two offspring, whose 323 | objective function value ranks are used to decide on the step-size 324 | change, see `update` for the specifics. 325 | 326 | Example 327 | ======= 328 | 329 | >>> import cma 330 | >>> cma.CMAOptions('adapt').pprint() # doctest: +ELLIPSIS 331 | AdaptSigma='True... 332 | >>> es = cma.CMAEvolutionStrategy(10 * [0.2], 0.1, 333 | ... {'AdaptSigma': cma.sigma_adaptation.CMAAdaptSigmaTPA, 334 | ... 'ftarget': 1e-8}) # doctest: +ELLIPSIS 335 | (5_w,10)-aCMA-ES (mu_w=3.2,w_1=45%) in dimension 10 (seed=... 336 | >>> es.optimize(cma.ff.rosen) # doctest: +ELLIPSIS 337 | Iter... 
338 | >>> assert 'ftarget' in es.stop() 339 | >>> assert es.result[1] <= 1e-8 # should coincide with the above 340 | >>> assert es.result[2] < 6500 # typically < 5500 341 | 342 | References: loosely based on Hansen 2008, CMA-ES with Two-Point 343 | Step-Size Adaptation, more tightly based on Hansen et al. 2014, 344 | How to Assess Step-Size Adaptation Mechanisms in Randomized Search. 345 | 346 | """ 347 | def __init__(self, dimension=None, opts=None, **kwargs): 348 | super(CMAAdaptSigmaTPA, self).__init__() # base class provides method hsig() 349 | # CMAAdaptSigmaBase.__init__(self) 350 | self.initialized = False 351 | self.dimension = dimension 352 | self.opts = opts 353 | def initialize(self, N=None, opts=None): 354 | """late initialization. 355 | 356 | :param N: is used for the (minor) dependency on dimension, 357 | :param opts: is used for hacking 358 | """ 359 | if self.initialized is True: 360 | return self 361 | self.initialized = False 362 | if N is None: 363 | N = self.dimension 364 | if opts is None: 365 | opts = self.opts 366 | try: 367 | damp_fac = opts['CSA_dampfac'] # should be renamed to sigma_adapt_dampfac or something 368 | except (TypeError, KeyError): 369 | damp_fac = 1 370 | 371 | self.sp = utils.BlancClass() # just a container to have sp.name instead of sp['name'] to access parameters 372 | try: 373 | self.sp.damp = damp_fac * eval('N')**0.5 # (1) why do we need 10 <-> np.exp(1/10) == 1.1? 2 should be fine!? 374 | self.sp.damp = damp_fac * (4 - 3.6/eval('N')**0.5) # (2) should become new default!? 
375 | self.sp.damp = damp_fac * eval('N')**0.25 376 | self.sp.damp = 0.7 + np.log(eval('N')) # between 2 and 9 very close to N**1/2, for N=7 equal to (1) and (2) 377 | # self.sp.damp = 100 378 | except: 379 | self.sp.damp = 4 # or 1 + np.log(10) 380 | self.initialized = 1/2 381 | try: 382 | self.sp.damp = opts['vv']['TPA_damp'] 383 | print('damp set to %d' % self.sp.damp) 384 | except (KeyError, TypeError): 385 | pass 386 | 387 | self.sp.dampup = 0.5**0.0 * 1.0 * self.sp.damp # 0.5 fails to converge on the Rastrigin function 388 | self.sp.dampdown = 2.0**0.0 * self.sp.damp 389 | if self.sp.dampup != self.sp.dampdown: 390 | print('TPA damping is asymmetric') 391 | self.sp.c = 0.3 # rank difference is asymetric and therefore the switch from increase to decrease takes too long 392 | self.sp.z_exponent = 0.5 # sign(z) * abs(z)**z_exponent, 0.5 seems better with larger popsize, 1 was default 393 | self.sp.sigma_fac = 1.0 # (obsolete) 0.5 feels better, but no evidence whether it is 394 | self.sp.relative_to_delta_mean = True # (obsolete) 395 | self.s = 0 # the state/summation variable 396 | self.last = None 397 | if not self.initialized: 398 | self.initialized = True 399 | return self 400 | def update(self, es, function_values, **kwargs): 401 | """the first and second value in ``function_values`` 402 | must reflect two mirrored solutions. 403 | 404 | Mirrored solutions must have been sampled 405 | in direction / in opposite direction of 406 | the previous mean shift, respectively. 407 | """ 408 | # On the linear function, the two mirrored samples lead 409 | # to a sharp increase of the condition of the covariance matrix, 410 | # unless we have negative weights (which we have now by default). 411 | # Otherwise they should not be used to update the covariance 412 | # matrix, if the step-size inreases quickly. 
413 | if self.initialized is not True: # try again 414 | self.initialize(es.N, es.opts) 415 | if self.initialized is not True: 416 | utils.print_warning("dimension not known, damping set to 4", 417 | 'update', 'CMAAdaptSigmaTPA') 418 | self.initialized = True 419 | if 1 < 3: 420 | f_vals = np.asarray(function_values) 421 | z = np.sum(f_vals < f_vals[1]) - np.sum(f_vals < f_vals[0]) 422 | z /= len(f_vals) - 1 # z in [-1, 1] 423 | elif 1 < 3: 424 | # use the ranking difference of the mirrors for adaptation 425 | # damp = 5 should be fine 426 | z = np.nonzero(es.fit.idx == 1)[0][0] - np.nonzero(es.fit.idx == 0)[0][0] 427 | z /= es.popsize - 1 # z in [-1, 1] 428 | self.s = (1 - self.sp.c) * self.s + self.sp.c * np.sign(z) * np.abs(z)**self.sp.z_exponent 429 | if self.s > 0: 430 | es.sigma *= np.exp(self.s / self.sp.dampup) 431 | else: 432 | es.sigma *= np.exp(self.s / self.sp.dampdown) 433 | #es.more_to_write.append(10**z) 434 | 435 | def check_consistency(self, es): 436 | assert isinstance(es.adapt_sigma, CMAAdaptSigmaTPA) 437 | if es.countiter > 3: 438 | dm = es.mean[0] - es.mean_old[0] 439 | dx0 = es.pop[0][0] - es.mean_old[0] 440 | dx1 = es.pop[1][0] - es.mean_old[0] 441 | for i in np.random.randint(1, es.N, 1): 442 | if dx0 * dx1 * (es.pop[0][i] - es.mean_old[i]) * ( 443 | es.pop[1][i] - es.mean_old[i]): 444 | dmi_div_dx0i = (es.mean[i] - es.mean_old[i]) \ 445 | / (es.pop[0][i] - es.mean_old[i]) 446 | dmi_div_dx1i = (es.mean[i] - es.mean_old[i]) \ 447 | / (es.pop[1][i] - es.mean_old[i]) 448 | if not Mh.equals_approximately( 449 | dmi_div_dx0i, dm / dx0, 1e-4) or \ 450 | not Mh.equals_approximately( 451 | dmi_div_dx1i, dm / dx1, 1e-4): 452 | utils.print_warning( 453 | 'TPA: apparent inconsistency with mirrored' 454 | ' samples, where dmi_div_dx0i, dm/dx0=%f, %f' 455 | ' and dmi_div_dx1i, dm/dx1=%f, %f' % ( 456 | dmi_div_dx0i, dm/dx0, dmi_div_dx1i, dm/dx1), 457 | 'check_consistency', 458 | 'CMAAdaptSigmaTPA', es.countiter) 459 | else: 460 | utils.print_warning('zero 
delta encountered in TPA which' + 461 | ' \nshould be very rare and might be a bug' + 462 | ' (sigma=%f)' % es.sigma, 463 | 'check_consistency', 'CMAAdaptSigmaTPA', 464 | es.countiter) 465 | 466 | -------------------------------------------------------------------------------- /cma/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """test module of `cma` package. 3 | 4 | Usage:: 5 | 6 | python -m cma.test -h # print this docstring 7 | python -m cma.test # doctest all (listed) files 8 | python -m cma.test list # list files to be doctested 9 | python -m cma.test interfaces.py [file2 [file3 [...]]] # doctest only these 10 | 11 | or possibly by executing this file as a script:: 12 | 13 | python cma/test.py # same options as above work 14 | cma/test.py # the same 15 | 16 | or equivalently by passing Python code:: 17 | 18 | python -c "import cma.test; cma.test.main()" # doctest all (listed) files 19 | python -c "import cma.test; cma.test.main('list')" # show files in doctest list 20 | python -c "import cma.test; cma.test.main('interfaces.py [file2 [file3 [...]]]')" 21 | python -c "import cma.test; help(cma.test)" # print this docstring 22 | 23 | File(name)s are interpreted within the package. Without a filename 24 | argument, all files from attribute `files_for_doc_test` are tested. 
25 | """ 26 | 27 | # (note to self) for testing: 28 | # pyflakes cma.py # finds bugs by static analysis 29 | # pychecker --limit 60 cma.py # also executes, all 60 warnings checked 30 | # or python ~/Downloads/pychecker-0.8.19/pychecker/checker.py cma.py 31 | # python -3 -m cma 2> out2to3warnings.txt # produces no warnings from here 32 | 33 | from __future__ import (absolute_import, division, print_function, 34 | ) # unicode_literals) 35 | import os, sys 36 | import doctest 37 | del absolute_import, division, print_function #, unicode_literals 38 | 39 | files_for_doctest = ['bbobbenchmarks.py', 40 | 'constraints_handler.py', 41 | 'evolution_strategy.py', 42 | 'fitness_functions.py', 43 | 'fitness_models.py', 44 | 'fitness_transformations.py', 45 | 'interfaces.py', 46 | 'logger.py', 47 | 'optimization_tools.py', 48 | 'purecma.py', 49 | 'recombination_weights.py', 50 | 'restricted_gaussian_sampler.py', 51 | 'sampler.py', 52 | 'sigma_adaptation.py', 53 | 'test.py', 54 | 'transformations.py', 55 | os.path.join('utilities', 'math.py'), 56 | os.path.join('utilities', 'utils.py'), 57 | ] 58 | _files_written = ['_saved-cma-object.pkl', 59 | 'outcmaesaxlen.dat', 60 | 'outcmaesaxlencorr.dat', 61 | 'outcmaesfit.dat', 62 | 'outcmaesstddev.dat', 63 | 'outcmaesxmean.dat', 64 | 'outcmaesxrecentbest.dat', 65 | ] 66 | """files written by the doc tests and hence, in case, to be deleted""" 67 | 68 | PY2 = sys.version_info[0] == 2 69 | def _clean_up(folder, start_matches, protected): 70 | """(permanently) remove entries in ``folder`` which begin with any of 71 | ``start_matches``, where ``""`` matches any string, and which are not 72 | in ``protected``. 73 | 74 | CAVEAT: use with care, as with ``"", ""`` as second and third 75 | arguments this could delete all files in ``folder``. 
76 | """ 77 | if not os.path.isdir(folder): 78 | return 79 | if not protected and "" in start_matches: 80 | raise ValueError( 81 | '''_clean_up(folder, [..., "", ...], []) is not permitted as it 82 | resembles "rm *"''') 83 | protected = protected + ["/"] 84 | for file_ in os.listdir(folder): 85 | if any(file_.startswith(s) for s in start_matches) \ 86 | and not any(file_.startswith(p) for p in protected): 87 | os.remove(os.path.join(folder, file_)) 88 | def is_str(var): # copy from utils to avoid relative import 89 | """`bytes` (in Python 3) also fit the bill""" 90 | if PY2: 91 | types_ = (str, unicode) 92 | else: 93 | types_ = (str, bytes) 94 | return any(isinstance(var, type_) for type_ in types_) 95 | 96 | def various_doctests(): 97 | """various doc tests. 98 | 99 | This function describes test cases and might in future become 100 | helpful as an experimental tutorial as well. The main testing feature 101 | at the moment is by doctest with ``cma.test.main()`` in a Python shell 102 | or by ``python -m cma.test`` in a system shell. 103 | 104 | A simple first overall test: 105 | 106 | >>> import cma 107 | >>> res = cma.fmin(cma.ff.elli, 3*[1], 1, 108 | ... {'CMA_diagonal':2, 'seed':1, 'verbose':-9}) 109 | >>> assert res[1] < 1e-6 110 | >>> assert res[2] < 2000 111 | 112 | Testing `args` argument: 113 | 114 | >>> def maxcorr(m): 115 | ... val = 0 116 | ... for i in range(len(m)): 117 | ... for j in range(i + 1, len(m)): 118 | ... val = max((val, abs(m[i, j]))) 119 | ... return val 120 | >>> x, es = cma.fmin2(cma.ff.elli, [1, 0, 0], 0.5, {'verbose':-9}, args=[True]) # rotated 121 | >>> assert maxcorr(es.sm.correlation_matrix) > 0.9, es.sm.correlation_matrix 122 | >>> es = cma.CMAEvolutionStrategy([1, 0, 0], 0.5, 123 | ... 
{'verbose':-9}).optimize(cma.ff.elli, args=[1]) 124 | >>> assert maxcorr(es.sm.correlation_matrix) > 0.9, es.sm.correlation_matrix 125 | 126 | Testing output file consistency with diagonal option: 127 | 128 | >>> import cma 129 | >>> for val in (0, True, 2, 3): 130 | ... _ = cma.fmin(cma.ff.sphere, 3 * [1], 1, 131 | ... {'verb_disp':0, 'CMA_diagonal':val, 'maxiter':5}) 132 | ... _ = cma.CMADataLogger().load() 133 | 134 | Test on the Rosenbrock function with 3 restarts. The first trial only 135 | finds the local optimum, which happens in about 20% of the cases. 136 | 137 | >>> import cma 138 | >>> res = cma.fmin(cma.ff.rosen, 4 * [-1], 0.01, 139 | ... options={'ftarget':1e-6, 140 | ... 'verb_time':0, 'verb_disp':500, 141 | ... 'seed':3}, 142 | ... restarts=3) 143 | ... # doctest: +ELLIPSIS 144 | (4_w,8)-aCMA-ES (mu_w=2.6,w_1=52%) in dimension 4 (seed=3,...) 145 | Iterat #Fevals ... 146 | >>> assert res[1] <= 1e-6 147 | 148 | Notice the different termination conditions. Termination on the target 149 | function value ftarget prevents further restarts. 150 | 151 | Test of scaling_of_variables option 152 | 153 | >>> import cma 154 | >>> opts = cma.CMAOptions() 155 | >>> opts['seed'] = 4567 156 | >>> opts['verb_disp'] = 0 157 | >>> opts['CMA_const_trace'] = True 158 | >>> # rescaling of third variable: for searching in roughly 159 | >>> # x0 plus/minus 1e3*sigma0 (instead of plus/minus sigma0) 160 | >>> opts['scaling_of_variables'] = [1, 1, 1e3, 1] 161 | >>> res = cma.fmin(cma.ff.rosen, 4 * [0.1], 0.1, opts) 162 | >>> assert res[1] < 1e-9 163 | >>> es = res[-2] 164 | >>> es.result_pretty() # doctest: +ELLIPSIS 165 | termination on tolfun=1e-11 166 | final/bestever f-value = ... 167 | 168 | The printed std deviations reflect the actual value in the 169 | parameters of the function (not the one in the internal 170 | representation which can be different). 171 | 172 | Test of CMA_stds scaling option. 
173 | 174 | >>> import cma 175 | >>> opts = cma.CMAOptions() 176 | >>> s = 5 * [1] 177 | >>> s[0] = 1e3 178 | >>> opts.set('CMA_stds', s) #doctest: +ELLIPSIS 179 | {'... 180 | >>> opts.set('verb_disp', 0) #doctest: +ELLIPSIS 181 | {'... 182 | >>> res = cma.fmin(cma.ff.cigar, 5 * [0.1], 0.1, opts) 183 | >>> assert res[1] < 1800 184 | 185 | Testing combination of ``fixed_variables`` and ``CMA_stds`` options. 186 | 187 | >>> import cma 188 | >>> options = { 189 | ... 'fixed_variables':{1:2.345}, 190 | ... 'CMA_stds': 4 * [1], 191 | ... 'minstd': 3 * [1]} 192 | >>> es = cma.CMAEvolutionStrategy(4 * [1], 1, options) #doctest: +ELLIPSIS 193 | (3_w,7)-aCMA-ES (mu_w=2.3,w_1=58%) in dimension 3 (seed=... 194 | 195 | Test of elitism: 196 | 197 | >>> import cma 198 | >>> res = cma.fmin(cma.ff.rastrigin, 10 * [0.1], 2, 199 | ... {'CMA_elitist':'initial', 'ftarget':1e-3, 'verbose':-9}) 200 | >>> assert 'ftarget' in res[7] 201 | 202 | Test CMA_on option and similar: 203 | 204 | >>> import cma 205 | >>> res = cma.fmin(cma.ff.sphere, 4 * [1], 2, 206 | ... {'CMA_on':False, 'ftarget':1e-8, 'verbose':-9}) 207 | >>> assert 'ftarget' in res[7] and res[2] < 1e3 208 | >>> res = cma.fmin(cma.ff.sphere, 3 * [1], 2, 209 | ... {'CMA_rankone':0, 'CMA_rankmu':0, 'ftarget':1e-8, 210 | ... 'verbose':-9}) 211 | >>> assert 'ftarget' in res[7] and res[2] < 1e3 212 | >>> res = cma.fmin(cma.ff.sphere, 2 * [1], 2, 213 | ... {'CMA_rankone':0, 'ftarget':1e-8, 'verbose':-9}) 214 | >>> assert 'ftarget' in res[7] and res[2] < 1e3 215 | >>> res = cma.fmin(cma.ff.sphere, 2 * [1], 2, 216 | ... {'CMA_rankmu':0, 'ftarget':1e-8, 'verbose':-9}) 217 | >>> assert 'ftarget' in res[7] and res[2] < 1e3 218 | 219 | Check rotational invariance: 220 | 221 | >>> import cma 222 | >>> felli = cma.s.ft.Shifted(cma.ff.elli) 223 | >>> frot = cma.s.ft.Rotated(felli) 224 | >>> res_elli = cma.CMAEvolutionStrategy(3 * [1], 1, 225 | ... {'ftarget': 1e-8}).optimize(felli).result 226 | ... #doctest: +ELLIPSIS 227 | (3_w,7)-... 
228 | >>> res_rot = cma.CMAEvolutionStrategy(3 * [1], 1, 229 | ... {'ftarget': 1e-8}).optimize(frot).result 230 | ... #doctest: +ELLIPSIS 231 | (3_w,7)-... 232 | >>> assert res_rot[3] < 2 * res_elli[3] 233 | 234 | Both condition alleviation transformations are applied during this 235 | test, first in iteration 62, second in iteration 257: 236 | 237 | >>> import cma 238 | >>> ftabletrot = cma.fitness_transformations.Rotated(cma.ff.tablet, seed=10) 239 | >>> es = cma.CMAEvolutionStrategy(4 * [1], 1, { 240 | ... 'tolconditioncov':False, 241 | ... 'seed': 8, 242 | ... 'CMA_mirrors': 0, 243 | ... 'ftarget': 1e-9, 244 | ... }) # doctest:+ELLIPSIS 245 | (4_w... 246 | >>> while not es.stop() and es.countiter < 82: 247 | ... X = es.ask() 248 | ... es.tell(X, [cma.ff.elli(x, cond=1e22) for x in X]) # doctest:+ELLIPSIS 249 | NOTE ...iteration=81... 250 | >>> while not es.stop(): 251 | ... X = es.ask() 252 | ... es.tell(X, [ftabletrot(x) for x in X]) # doctest:+ELLIPSIS 253 | >>> assert es.countiter <= 344 and 'ftarget' in es.stop(), ( 254 | ... "transformation bug in alleviate_condition?", 255 | ... es.countiter, es.stop()) 256 | 257 | Integer handling: 258 | 259 | >>> import warnings 260 | >>> idx = [0, 1, 5, -1] 261 | >>> f = cma.s.ft.IntegerMixedFunction(cma.ff.elli, idx) 262 | >>> with warnings.catch_warnings(record=True) as warns: 263 | ... es = cma.CMAEvolutionStrategy(4 * [5], 10, dict( 264 | ... ftarget=1e-9, seed=5, 265 | ... integer_variables=idx 266 | ... )) # doctest:+ELLIPSIS 267 | (4_w,8)-... 268 | >>> warns[0].message # doctest:+ELLIPSIS 269 | UserWarning('integer index 5 not in range of dimension 4 ()'... 270 | >>> es.optimize(f) # doctest:+ELLIPSIS 271 | Iterat #Fevals function value ... 272 | >>> assert 'ftarget' in es.stop() and es.result[3] < 1800 273 | 274 | Parallel objective: 275 | 276 | >>> def parallel_sphere(X): return [cma.ff.sphere(x) for x in X] 277 | >>> x, es = cma.fmin2(cma.ff.sphere, 3 * [0], 0.1, { 278 | ... 
'verbose': -9, 'eval_final_mean': True, 'CMA_elitist': 'initial'}, 279 | ... parallel_objective=parallel_sphere) 280 | >>> assert es.result[1] < 1e-9 281 | >>> x, es = cma.fmin2(None, 3 * [0], 0.1, { 282 | ... 'verbose': -9, 'eval_final_mean': True, 'CMA_elitist': 'initial'}, 283 | ... parallel_objective=parallel_sphere) 284 | >>> assert es.result[1] < 1e-9 285 | 286 | Some sort of interactive control via an options file: 287 | 288 | >>> es = cma.CMAEvolutionStrategy(4 * [2], 1, dict( 289 | ... signals_filename='cma_signals.in', 290 | ... verbose=-9)) 291 | >>> s = es.stop() 292 | >>> es = es.optimize(cma.ff.sphere) 293 | 294 | Test of huge lambda: 295 | 296 | >>> es = cma.CMAEvolutionStrategy(3 * [0.91], 1, { 297 | ... 'verbose': -9, 298 | ... 'popsize': 200, 299 | ... 'ftarget': 1e-8 }) 300 | >>> es = es.optimize(cma.ff.tablet) 301 | >>> if es.result.evaluations > 5000: print(es.result.evalutions, es.result) 302 | 303 | For VD- and VkD-CMA, see `cma.restricted_gaussian_sampler`. 304 | 305 | >>> import sys 306 | >>> import cma 307 | >>> assert cma.interfaces.EvalParallel2 is not None 308 | >>> try: 309 | ... with warnings.catch_warnings(record=True) as warn: 310 | ... with cma.optimization_tools.EvalParallel2(cma.ff.elli) as eval_all: 311 | ... res = eval_all([[1,2], [3,4]]) 312 | ... except: 313 | ... assert sys.version[0] == '2' 314 | 315 | """ 316 | 317 | def doctest_files(file_list=files_for_doctest, **kwargs): 318 | """doctest all (listed) files of the `cma` package. 319 | 320 | Details: accepts ``verbose`` and all other keyword arguments that 321 | `doctest.testfile` would accept, while negative ``verbose`` values 322 | are passed as 0. 
323 | """ 324 | # print("__name__ is", __name__, sys.modules[__name__]) 325 | # print(__package__) 326 | if not isinstance(file_list, list) and is_str(file_list): 327 | file_list = [file_list] 328 | verbosity_here = kwargs.get('verbose', 0) 329 | if verbosity_here < 0: 330 | kwargs['verbose'] = 0 331 | failures = 0 332 | for file_ in file_list: 333 | file_ = file_.strip().strip(os.path.sep) 334 | if file_.startswith('cma' + os.path.sep): 335 | file_ = file_[4:] 336 | if verbosity_here >= 0: 337 | print('doctesting %s ...' % file_, 338 | ' ' * (max(len(_file) for _file in file_list) - 339 | len(file_)), 340 | end="") # does not work in Python 2.5 341 | sys.stdout.flush() 342 | protected_files = os.listdir('.') 343 | report = doctest.testfile(file_, 344 | package=__package__, # 'cma', # sys.modules[__name__], 345 | **kwargs) 346 | _clean_up('.', _files_written, protected_files) 347 | failures += report[0] 348 | if verbosity_here >= 0: 349 | print(report) 350 | return failures 351 | 352 | def get_version(): 353 | try: 354 | with open(__file__[:-7] + '__init__.py', 'r') as f: 355 | for line in f.readlines(): 356 | if line.startswith('__version__'): 357 | return line[15:].split()[0] 358 | except: 359 | return "" 360 | print(__file__) 361 | raise 362 | 363 | def main(*args, **kwargs): 364 | """test the `cma` package. 365 | 366 | The first argument can be '-h' or '--help' or 'list' to list all 367 | files to be tested. Otherwise, arguments can be file(name)s to be 368 | tested, where names are interpreted relative to the package root 369 | and a leading 'cma' + path separator is ignored. 370 | 371 | By default all files are tested. 
372 | 373 | :See also: ``python -c "import cma.test; help(cma.test)"`` 374 | """ 375 | if len(args) > 0: 376 | if args[0].startswith(('-h', '--h')): 377 | print(__doc__) 378 | exit(0) 379 | elif args[0].startswith('list'): 380 | for file_ in files_for_doctest: 381 | print(file_) 382 | exit(0) 383 | else: 384 | v = get_version() 385 | print("doctesting `cma` package%s by calling `doctest_files`:" 386 | % ((" (v%s)" % v) if v else "")) 387 | return doctest_files(args if args else files_for_doctest, **kwargs) 388 | 389 | if __name__ == "__main__": 390 | exit(main(*sys.argv[1:]) > 0) # 0 if failures == 0 else 1 -------------------------------------------------------------------------------- /cma/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | """various unspecific utilities""" -------------------------------------------------------------------------------- /cma/utilities/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/utilities/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /cma/utilities/__pycache__/math.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/utilities/__pycache__/math.cpython-39.pyc -------------------------------------------------------------------------------- /cma/utilities/__pycache__/python3for2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/utilities/__pycache__/python3for2.cpython-39.pyc 
-------------------------------------------------------------------------------- /cma/utilities/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/utilities/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /cma/utilities/math.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ various math utilities, notably `eig` and a collection of simple 3 | functions in `Mh` 4 | """ 5 | from __future__ import absolute_import, division, print_function #, unicode_literals 6 | # from future.builtins.disabled import * # don't use any function which could lead to different results in Python 2 vs 3 7 | import warnings as _warnings 8 | import numpy as np 9 | from .python3for2 import range 10 | del absolute_import, division, print_function #, unicode_literals 11 | 12 | def _sqrt_len(x): # makes randhss option pickable 13 | return len(x)**0.5 14 | 15 | def randhss(n, dim, norm_=_sqrt_len, randn=np.random.randn): 16 | """`n` iid `dim`-dimensional vectors with length ``norm_(vector)``. 17 | 18 | The vectors are uniformly distributed on a hypersphere surface. 19 | 20 | CMA-ES diverges with popsize 100 in 15-D without option 21 | 'CSA_clip_length_value': [0,0]. 22 | 23 | >>> from cma.utilities.math import randhss 24 | >>> dim = 3 25 | >>> assert dim - 1e-7 < sum(randhss(1, dim)[0]**2) < dim + 1e-7 26 | 27 | """ 28 | arv = randn(n, dim) 29 | for v in arv: 30 | v *= norm_(v) / np.sum(v**2)**0.5 31 | return arv 32 | 33 | def randhss_mixin(n, dim, norm_=_sqrt_len, 34 | c=lambda d: 1. / d, randn=np.random.randn): 35 | """`n` iid vectors uniformly distributed on the hypersphere surface with 36 | mixing in of normal distribution, which can be beneficial in smaller 37 | dimension. 
38 | """ 39 | arv = randhss(n, dim, norm_, randn) 40 | c = min((1, c(dim))) 41 | if c > 0: 42 | if c > 1: # can never happen 43 | raise ValueError("c(dim)=%f should be <=1" % c) 44 | for v in arv: 45 | v *= (1 - c**2)**0.5 # has 2 / c longer time horizon than 1 - c 46 | v += c * randn(1, dim)[0] # c is sqrt(2/c) times smaller than sqrt(c * (2 - c)) 47 | return arv 48 | 49 | def to_correlation_matrix(c): 50 | """change C in place into a correlation matrix, AKA whitening""" 51 | for i in range(c.shape[0]): 52 | fac = c[i, i]**0.5 53 | c[:, i] /= fac 54 | c[i, :] /= fac 55 | c = (c + c.T) / 2.0 56 | assert np.allclose(np.diag(c), 1) 57 | return c 58 | 59 | _warnings.filterwarnings( # one message for each value (which is given in the message) 60 | 'once', message="using exponential smoothing with .* rolling average") 61 | def moving_average(x, w=7): 62 | """rolling average without biasing boundary effects. 63 | 64 | The first entries give the average over all first 65 | values (until the window width is reached). 66 | 67 | If `w` is not an integer, expontential smoothing with weights 68 | proportionate to ``(1 - 1/w)**i`` summing to one is executed, thereby 69 | putting about 1 - exp(-1) ≈ 0.63 of the weight sum on the last `w` 70 | entries. 71 | 72 | Details: the average is mainly based on `np.convolve`, whereas 73 | exponential smoothing is for the time being numerically inefficient and 74 | scales quadratically with the length of `x`. 75 | """ 76 | if w == 1: 77 | return x 78 | elif isinstance(w, int): # interpret as window width 79 | w = min((w, len(x))) # window width 80 | return np.hstack([[np.mean(x[:i]) for i in range(1, w)], 81 | np.convolve(x, w * [1 / w], mode='valid')]) 82 | else: # exponential smoothing 83 | if w == int(w): 84 | _warnings.warn("using exponential smoothing with time" 85 | " horizon {}. 
\nUse `int` type to get the" 86 | " rolling average.".format(w)) 87 | v = 1 - 1 / w 88 | return np.asarray([sum([v**j * x[i-j] for j in range(i + 1)]) 89 | / sum([v**j for j in range(i + 1)]) 90 | for i in range(len(x))]) 91 | 92 | def Hessian(f, x0, eps=1e-6): 93 | """Hessian estimate for `f` at `x0`""" 94 | if eps is None: 95 | eps = 1e-6 # restore default 96 | x0 = np.asarray(x0) 97 | e = np.eye(len(x0)) 98 | H = 0 * e 99 | for i in range(len(x0)): 100 | ei = eps * e[i] 101 | for j in range(i+1): 102 | ej = eps * e[j] 103 | H[i,j] = (f(x0 + ei + ej) - f(x0 + ei) - f(x0 + ej) + f(x0)) / eps**2 104 | H[j, i] = H[i, j] 105 | return H 106 | 107 | def geometric_sd(vals, **kwargs): 108 | """return geometric standard deviation of `vals`. 109 | 110 | The gsd is invariant under linear scaling and independent 111 | of the choice of the log-exp base. 112 | 113 | ``kwargs`` are passed to `np.std`, in particular `ddof`. 114 | """ 115 | return np.exp(np.std(np.log(vals), **kwargs)) 116 | 117 | # ____________________________________________________________ 118 | # ____________________________________________________________ 119 | # 120 | # C and B are arrays rather than matrices, because they are 121 | # addressed via B[i][j], matrices can only be addressed via B[i,j] 122 | 123 | # tred2(N, B, diagD, offdiag); 124 | # tql2(N, diagD, offdiag, B); 125 | 126 | 127 | # Symmetric Householder reduction to tridiagonal form, translated from JAMA package. 
128 | def eig(C): 129 | """eigendecomposition of a symmetric matrix, much slower than 130 | `numpy.linalg.eigh`, return ``(EVals, Basis)``, the eigenvalues 131 | and an orthonormal basis of the corresponding eigenvectors, where 132 | 133 | ``Basis[i]`` 134 | the i-th row of ``Basis`` 135 | columns of ``Basis``, ``[Basis[j][i] for j in range(len(Basis))]`` 136 | the i-th eigenvector with eigenvalue ``EVals[i]`` 137 | 138 | """ 139 | 140 | # class eig(object): 141 | # def __call__(self, C): 142 | 143 | # Householder transformation of a symmetric matrix V into tridiagonal form. 144 | # -> n : dimension 145 | # -> V : symmetric nxn-matrix 146 | # <- V : orthogonal transformation matrix: 147 | # tridiag matrix == V * V_in * V^t 148 | # <- d : diagonal 149 | # <- e[0..n-1] : off diagonal (elements 1..n-1) 150 | 151 | # Symmetric tridiagonal QL algorithm, iterative 152 | # Computes the eigensystem from a tridiagonal matrix in roughtly 3N^3 operations 153 | # -> n : Dimension. 154 | # -> d : Diagonale of tridiagonal matrix. 155 | # -> e[1..n-1] : off-diagonal, output from Householder 156 | # -> V : matrix output von Householder 157 | # <- d : eigenvalues 158 | # <- e : garbage? 159 | # <- V : basis of eigenvectors, according to d 160 | 161 | 162 | # tred2(N, B, diagD, offdiag); B=C on input 163 | # tql2(N, diagD, offdiag, B); 164 | 165 | # private void tred2 (int n, double V[][], double d[], double e[]) { 166 | def tred2 (n, V, d, e): 167 | # This is derived from the Algol procedures tred2 by 168 | # Bowdler, Martin, Reinsch, and Wilkinson, Handbook for 169 | # Auto. Comp., Vol.ii-Linear Algebra, and the corresponding 170 | # Fortran subroutine in EISPACK. 171 | 172 | num_opt = False # factor 1.5 in 30-D 173 | 174 | for j in range(n): 175 | d[j] = V[n - 1][j] # d is output argument 176 | 177 | # Householder reduction to tridiagonal form. 178 | 179 | for i in range(n - 1, 0, -1): 180 | # Scale to avoid under/overflow. 
181 | h = 0.0 182 | if not num_opt: 183 | scale = 0.0 184 | for k in range(i): 185 | scale = scale + abs(d[k]) 186 | else: 187 | scale = sum(abs(d[0:i])) 188 | 189 | if scale == 0.0: 190 | e[i] = d[i - 1] 191 | for j in range(i): 192 | d[j] = V[i - 1][j] 193 | V[i][j] = 0.0 194 | V[j][i] = 0.0 195 | else: 196 | 197 | # Generate Householder vector. 198 | if not num_opt: 199 | for k in range(i): 200 | d[k] /= scale 201 | h += d[k] * d[k] 202 | else: 203 | d[:i] /= scale 204 | h = np.dot(d[:i], d[:i]) 205 | 206 | f = d[i - 1] 207 | g = h**0.5 208 | 209 | if f > 0: 210 | g = -g 211 | 212 | e[i] = scale * g 213 | h = h - f * g 214 | d[i - 1] = f - g 215 | if not num_opt: 216 | for j in range(i): 217 | e[j] = 0.0 218 | else: 219 | e[:i] = 0.0 220 | 221 | # Apply similarity transformation to remaining columns. 222 | 223 | for j in range(i): 224 | f = d[j] 225 | V[j][i] = f 226 | g = e[j] + V[j][j] * f 227 | if not num_opt: 228 | for k in range(j + 1, i): 229 | g += V[k][j] * d[k] 230 | e[k] += V[k][j] * f 231 | e[j] = g 232 | else: 233 | e[j + 1:i] += V.T[j][j + 1:i] * f 234 | e[j] = g + np.dot(V.T[j][j + 1:i], d[j + 1:i]) 235 | 236 | f = 0.0 237 | if not num_opt: 238 | for j in range(i): 239 | e[j] /= h 240 | f += e[j] * d[j] 241 | else: 242 | e[:i] /= h 243 | f += np.dot(e[:i], d[:i]) 244 | 245 | hh = f / (h + h) 246 | if not num_opt: 247 | for j in range(i): 248 | e[j] -= hh * d[j] 249 | else: 250 | e[:i] -= hh * d[:i] 251 | 252 | for j in range(i): 253 | f = d[j] 254 | g = e[j] 255 | if not num_opt: 256 | for k in range(j, i): 257 | V[k][j] -= (f * e[k] + g * d[k]) 258 | else: 259 | V.T[j][j:i] -= (f * e[j:i] + g * d[j:i]) 260 | 261 | d[j] = V[i - 1][j] 262 | V[i][j] = 0.0 263 | 264 | d[i] = h 265 | # end for i-- 266 | 267 | # Accumulate transformations. 
268 | 269 | for i in range(n - 1): 270 | V[n - 1][i] = V[i][i] 271 | V[i][i] = 1.0 272 | h = d[i + 1] 273 | if h != 0.0: 274 | if not num_opt: 275 | for k in range(i + 1): 276 | d[k] = V[k][i + 1] / h 277 | else: 278 | d[:i + 1] = V.T[i + 1][:i + 1] / h 279 | 280 | for j in range(i + 1): 281 | if not num_opt: 282 | g = 0.0 283 | for k in range(i + 1): 284 | g += V[k][i + 1] * V[k][j] 285 | for k in range(i + 1): 286 | V[k][j] -= g * d[k] 287 | else: 288 | g = np.dot(V.T[i + 1][0:i + 1], V.T[j][0:i + 1]) 289 | V.T[j][:i + 1] -= g * d[:i + 1] 290 | 291 | if not num_opt: 292 | for k in range(i + 1): 293 | V[k][i + 1] = 0.0 294 | else: 295 | V.T[i + 1][:i + 1] = 0.0 296 | 297 | 298 | if not num_opt: 299 | for j in range(n): 300 | d[j] = V[n - 1][j] 301 | V[n - 1][j] = 0.0 302 | else: 303 | d[:n] = V[n - 1][:n] 304 | V[n - 1][:n] = 0.0 305 | 306 | V[n - 1][n - 1] = 1.0 307 | e[0] = 0.0 308 | 309 | 310 | # Symmetric tridiagonal QL algorithm, taken from JAMA package. 311 | # private void tql2 (int n, double d[], double e[], double V[][]) { 312 | # needs roughly 3N^3 operations 313 | def tql2 (n, d, e, V): 314 | 315 | # This is derived from the Algol procedures tql2, by 316 | # Bowdler, Martin, Reinsch, and Wilkinson, Handbook for 317 | # Auto. Comp., Vol.ii-Linear Algebra, and the corresponding 318 | # Fortran subroutine in EISPACK. 319 | 320 | num_opt = False # using vectors from numpy makes it faster 321 | 322 | if not num_opt: 323 | for i in range(1, n): # (int i = 1; i < n; i++): 324 | e[i - 1] = e[i] 325 | else: 326 | e[0:n - 1] = e[1:n] 327 | e[n - 1] = 0.0 328 | 329 | f = 0.0 330 | tst1 = 0.0 331 | eps = 2.0**-52.0 332 | for l in range(n): # (int l = 0; l < n; l++) { 333 | 334 | # Find small subdiagonal element 335 | 336 | tst1 = max(tst1, abs(d[l]) + abs(e[l])) 337 | m = l 338 | while m < n: 339 | if abs(e[m]) <= eps * tst1: 340 | break 341 | m += 1 342 | 343 | # If m == l, d[l] is an eigenvalue, 344 | # otherwise, iterate. 
345 | 346 | if m > l: 347 | iiter = 0 348 | while 1: # do { 349 | iiter += 1 # (Could check iteration count here.) 350 | 351 | # Compute implicit shift 352 | 353 | g = d[l] 354 | p = (d[l + 1] - g) / (2.0 * e[l]) 355 | r = (p**2 + 1)**0.5 # hypot(p,1.0) 356 | if p < 0: 357 | r = -r 358 | 359 | d[l] = e[l] / (p + r) 360 | d[l + 1] = e[l] * (p + r) 361 | dl1 = d[l + 1] 362 | h = g - d[l] 363 | if not num_opt: 364 | for i in range(l + 2, n): 365 | d[i] -= h 366 | else: 367 | d[l + 2:n] -= h 368 | 369 | f = f + h 370 | 371 | # Implicit QL transformation. 372 | 373 | p = d[m] 374 | c = 1.0 375 | c2 = c 376 | c3 = c 377 | el1 = e[l + 1] 378 | s = 0.0 379 | s2 = 0.0 380 | 381 | # hh = V.T[0].copy() # only with num_opt 382 | for i in range(m - 1, l - 1, -1): # (int i = m-1; i >= l; i--) { 383 | c3 = c2 384 | c2 = c 385 | s2 = s 386 | g = c * e[i] 387 | h = c * p 388 | r = (p**2 + e[i]**2)**0.5 # hypot(p,e[i]) 389 | e[i + 1] = s * r 390 | s = e[i] / r 391 | c = p / r 392 | p = c * d[i] - s * g 393 | d[i + 1] = h + s * (c * g + s * d[i]) 394 | 395 | # Accumulate transformation. 396 | 397 | if not num_opt: # overall factor 3 in 30-D 398 | for k in range(n): # (int k = 0; k < n; k++) { 399 | h = V[k][i + 1] 400 | V[k][i + 1] = s * V[k][i] + c * h 401 | V[k][i] = c * V[k][i] - s * h 402 | else: # about 20% faster in 10-D 403 | hh = V.T[i + 1].copy() 404 | # hh[:] = V.T[i+1][:] 405 | V.T[i + 1] = s * V.T[i] + c * hh 406 | V.T[i] = c * V.T[i] - s * hh 407 | # V.T[i] *= c 408 | # V.T[i] -= s * hh 409 | 410 | p = -s * s2 * c3 * el1 * e[l] / dl1 411 | e[l] = s * p 412 | d[l] = c * p 413 | 414 | # Check for convergence. 415 | if abs(e[l]) <= eps * tst1: 416 | break 417 | # } while (Math.abs(e[l]) > eps*tst1); 418 | 419 | d[l] = d[l] + f 420 | e[l] = 0.0 421 | 422 | 423 | # Sort eigenvalues and corresponding vectors. 
424 | if 11 < 3: 425 | for i in range(n - 1): # (int i = 0; i < n-1; i++) { 426 | k = i 427 | p = d[i] 428 | for j in range(i + 1, n): # (int j = i+1; j < n; j++) { 429 | if d[j] < p: # NH find smallest k>i 430 | k = j 431 | p = d[j] 432 | 433 | if k != i: 434 | d[k] = d[i] # swap k and i 435 | d[i] = p 436 | for j in range(n): # (int j = 0; j < n; j++) { 437 | p = V[j][i] 438 | V[j][i] = V[j][k] 439 | V[j][k] = p 440 | # tql2 441 | 442 | N = len(C[0]) 443 | if 11 < 3: 444 | V = np.array([x[:] for x in C]) # copy each "row" 445 | N = V[0].size 446 | d = np.zeros(N) 447 | e = np.zeros(N) 448 | else: 449 | V = [[x[i] for i in range(N)] for x in C] # copy each "row" 450 | d = N * [0.] 451 | e = N * [0.] 452 | 453 | tred2(N, V, d, e) 454 | tql2(N, d, e, V) 455 | return np.array(d), np.array(V) 456 | 457 | class MathHelperFunctions(object): 458 | """static convenience math helper functions, if the function name 459 | is preceded with an "a", a numpy array is returned 460 | 461 | TODO: there is probably no good reason why this should be a class and not a 462 | module. 
463 | 464 | """ 465 | @staticmethod 466 | def aclamp(x, upper): 467 | return -MathHelperFunctions.apos(-x, -upper) 468 | @staticmethod 469 | def equals_approximately(a, b, eps=1e-12): 470 | if a < 0: 471 | a, b = -1 * a, -1 * b 472 | return (a - eps < b < a + eps) or ((1 - eps) * a < b < (1 + eps) * a) 473 | @staticmethod 474 | def vequals_approximately(a, b, eps=1e-12): 475 | a, b = np.array(a), np.array(b) 476 | idx = np.nonzero(a < 0)[0] # find 477 | if len(idx): 478 | a[idx], b[idx] = -1 * a[idx], -1 * b[idx] 479 | return (np.all(a - eps < b) and np.all(b < a + eps) 480 | ) or (np.all((1 - eps) * a < b) and np.all(b < (1 + eps) * a)) 481 | @staticmethod 482 | def expms(A, eig=np.linalg.eigh): 483 | """matrix exponential for a symmetric matrix""" 484 | # TODO: check that this works reliably for low rank matrices 485 | # first: symmetrize A 486 | D, B = eig(A) 487 | return np.dot(B, (np.exp(D) * B).T) 488 | @staticmethod 489 | def amax(vec, vec_or_scalar): 490 | return np.array(MathHelperFunctions.max(vec, vec_or_scalar)) 491 | @staticmethod 492 | def max(vec, vec_or_scalar): 493 | b = vec_or_scalar 494 | if np.isscalar(b): 495 | m = [max(x, b) for x in vec] 496 | else: 497 | m = [max(vec[i], b[i]) for i in range(len((vec)))] 498 | return m 499 | @staticmethod 500 | def minmax(val, min_val, max_val): 501 | assert min_val <= max_val 502 | return min((max_val, max((val, min_val)))) 503 | @staticmethod 504 | def aminmax(val, min_val, max_val): 505 | return np.array([min((max_val, max((v, min_val)))) for v in val]) 506 | @staticmethod 507 | def amin(vec_or_scalar, vec_or_scalar2): 508 | return np.array(MathHelperFunctions.min(vec_or_scalar, vec_or_scalar2)) 509 | @staticmethod 510 | def min(a, b): 511 | iss = np.isscalar 512 | if iss(a) and iss(b): 513 | return min(a, b) 514 | if iss(a): 515 | a, b = b, a 516 | # now only b can be still a scalar 517 | if iss(b): 518 | return [min(x, b) for x in a] 519 | else: # two non-scalars must have the same length 520 | return 
[min(a[i], b[i]) for i in range(len((a)))] 521 | @staticmethod 522 | def norm(vec, expo=2): 523 | return sum(vec**expo)**(1 / expo) 524 | @staticmethod 525 | def apos(x, lower=0): 526 | """clips argument (scalar or array) from below at lower""" 527 | if lower == 0: 528 | return (x > 0) * x 529 | else: 530 | return lower + (x > lower) * (x - lower) 531 | 532 | @staticmethod 533 | def apenalty_quadlin(x, lower=0, upper=None): 534 | """Huber-like smooth penality which starts at lower. 535 | 536 | The penalty is zero below lower and affine linear above upper. 537 | 538 | Return:: 539 | 540 | 0, if x <= lower 541 | quadratic in x, if lower <= x <= upper 542 | affine linear in x with slope upper - lower, if x >= upper 543 | 544 | `upper` defaults to ``lower + 1``. 545 | 546 | """ 547 | if upper is None: 548 | upper = np.asarray(lower) + 1 549 | z = np.asarray(x) - lower 550 | del x # assert that x is not used anymore accidentally 551 | u = np.asarray(upper) - lower 552 | return (z > 0) * ((z <= u) * (z ** 2 / 2) + (z > u) * u * (z - u / 2)) 553 | 554 | @staticmethod 555 | def prctile(data, p_vals=[0, 25, 50, 75, 100], sorted_=False): 556 | """``prctile(data, 50)`` returns the median, but p_vals can 557 | also be a sequence. 558 | 559 | Provides for small samples or extremes IMHO better values than 560 | matplotlib.mlab.prctile or np.percentile, however also slower. 561 | 562 | """ 563 | ps = [p_vals] if np.isscalar(p_vals) else p_vals 564 | 565 | if not sorted_: 566 | data = sorted(data) 567 | n = len(data) 568 | d = [] 569 | for p in ps: 570 | fi = p * n / 100 - 0.5 571 | if fi <= 0: # maybe extrapolate? 
572 | d.append(data[0]) 573 | elif fi >= n - 1: 574 | d.append(data[-1]) 575 | else: 576 | i = int(fi) 577 | d.append((i + 1 - fi) * data[i] + (fi - i) * data[i + 1]) 578 | return d[0] if np.isscalar(p_vals) else d 579 | @staticmethod 580 | def iqr(data, percentile_function=np.percentile): # MathHelperFunctions.prctile 581 | """interquartile range""" 582 | q25, q75 = percentile_function(data, [25, 75]) 583 | return np.asarray(q75) - np.asarray(q25) 584 | @staticmethod 585 | def interdecilerange(data, percentile_function=np.percentile): 586 | """return 10% to 90% range width""" 587 | q10, q90 = percentile_function(data, [10, 90]) 588 | return np.asarray(q90) - np.asarray(q10) 589 | @staticmethod 590 | def logit10(x, lower=0, upper=1): 591 | """map [lower, upper] -> R such that 592 | 593 | :: 594 | 595 | upper - 10^-x -> x, and 596 | lower + 10^-x -> -x 597 | 598 | for large enough x. By default, simplifies close to `log10(x / (1 - x))`. 599 | 600 | >>> from cma.utilities.math import Mh 601 | >>> l, u = -1, 2 602 | >>> print(Mh.logit10([l+0.01, 0.5, u-0.01], l, u)) 603 | [-1.9949189 0. 
1.9949189] 604 | 605 | """ 606 | x = np.asarray(x) 607 | z = (x - lower) / (upper - lower) # between 0 and 1 608 | return np.log10((x - lower)**(1-z) / (upper - x)**z) 609 | return (1 - z) * np.log10(x - lower) - z * np.log10(upper - x) 610 | @staticmethod 611 | def sround(nb): # TODO: to be vectorized 612 | """return stochastic round: int(nb) + (rand() 1000: 618 | n = np.random.randn() / np.random.randn() 619 | return n / 25 620 | @staticmethod 621 | def standard_finite_cauchy(size=1): 622 | try: 623 | l = len(size) 624 | except TypeError: 625 | l = 0 626 | 627 | if l == 0: 628 | return np.array([MathHelperFunctions.cauchy_with_variance_one() for _i in range(size)]) 629 | elif l == 1: 630 | return np.array([MathHelperFunctions.cauchy_with_variance_one() for _i in range(size[0])]) 631 | elif l == 2: 632 | return np.array([[MathHelperFunctions.cauchy_with_variance_one() for _i in range(size[1])] 633 | for _j in range(size[0])]) 634 | else: 635 | raise ValueError('len(size) cannot be larger than two') 636 | 637 | Mh = MathHelperFunctions 638 | -------------------------------------------------------------------------------- /cma/utilities/python3for2.py: -------------------------------------------------------------------------------- 1 | """to execute Python 3 code in Python 2. 2 | 3 | redefines builtin `range` and `input` functions and `abc` either via `collections` 4 | or `collections.abc` if available. 
5 | """ 6 | import sys 7 | import collections as _collections 8 | 9 | range = range # to allow (trivial) explicit import also in Python 3 10 | input = input 11 | 12 | if sys.version[0] == '2': # in python 2 13 | range = xrange # clean way: from builtins import range 14 | input = raw_input # in py2, input(x) == eval(raw_input(x)) 15 | abc = _collections # never used 16 | 17 | # only for testing, because `future` may not be installed 18 | # from future.builtins import * 19 | if 11 < 3: # newint does produce an error on some installations 20 | try: 21 | from future.builtins.disabled import * # rather not necessary if tested also in Python 3 22 | from future.builtins import ( 23 | bytes, dict, int, list, object, range, 24 | str, ascii, chr, hex, input, next, oct, open, 25 | pow, round, super, filter, map, zip 26 | ) 27 | from builtins import ( # not list and object, by default builtins don't exist in Python 2 28 | ascii, bytes, chr, dict, filter, hex, input, 29 | int, map, next, oct, open, pow, range, round, 30 | str, super, zip) 31 | except ImportError: 32 | pass 33 | else: 34 | try: 35 | abc = _collections.abc 36 | except AttributeError: 37 | abc = _collections 38 | -------------------------------------------------------------------------------- /cma/wrapper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | '''Interface wrappers for the `cma` module. 3 | 4 | The `SkoptCMAoptimizer` wrapper interfaces an optimizer aligned with 5 | `skopt.optimizer`. 6 | ''' 7 | # built-in 8 | import pdb 9 | import copy 10 | import inspect 11 | import tempfile 12 | import os 13 | import warnings 14 | 15 | # external 16 | import numpy as np 17 | import cma # caveat: does not import necessarily the code of this root folder? 
18 | 19 | try: import skopt 20 | except ImportError: warnings.warn('install `skopt` ("pip install scikit-optimize") ' 21 | 'to use `SkoptCMAoptimizer`') 22 | else: 23 | def SkoptCMAoptimizer( 24 | func, dimensions, n_calls, verbose=False, callback=(), x0=None, n_jobs=1, 25 | sigma0=.5, normalize=True, 26 | ): 27 | ''' 28 | Optimizer based on CMA-ES algorithm. 29 | This is essentially a wrapper function for the cma library function 30 | to align the interface with skopt library. 31 | 32 | Args: 33 | func (callable): function to optimize 34 | dimensions: list of tuples like ``4 * [(-1., 1.)]`` for defining the domain. 35 | n_calls: the maximal number of ask/tell iterations (each iteration evaluates a whole population). 36 | verbose: if this func should be verbose 37 | callback: the list of callback functions, each called with the intermediate result after every iteration. 38 | n_jobs: number of cores to run different calls to `func` in parallel. 39 | x0: initial values 40 | if None, random point will be sampled 41 | sigma0: initial standard deviation relative to domain width 42 | normalize: whether optimization domain should be normalized 43 | 44 | Returns: 45 | `res` skopt.OptimizeResult object 46 | The optimization result returned as a dict object. 47 | Important attributes are: 48 | - `x` [list]: location of the minimum. 49 | - `fun` [float]: function value at the minimum. 50 | - `x_iters` [list of lists]: location of function evaluation for each 51 | iteration. 52 | - `func_vals` [array]: function value for each iteration. 53 | - `space` [skopt.space.Space]: the optimization space.
54 | 55 | Example:: 56 | 57 | import cma.wrapper 58 | res = cma.wrapper.SkoptCMAoptimizer(lambda x: sum([xi**2 for xi in x]), 59 | 2 * [(-1.,1.)], 55) 60 | res['cma_es'].logger.plot() 61 | 62 | ''' 63 | specs = { 64 | 'args': copy.copy(inspect.currentframe().f_locals), 65 | 'function': inspect.currentframe().f_code.co_name, 66 | } 67 | 68 | if normalize: dimensions = list(map(lambda x: skopt.space.check_dimension(x, 'normalize'), dimensions)) 69 | space = skopt.space.Space(dimensions) 70 | if x0 is None: x0 = space.transform(space.rvs())[0] 71 | else: x0 = space.transform([x0])[0] 72 | 73 | tempdir = tempfile.mkdtemp() 74 | xi, yi = [], [] 75 | options = { 76 | 'bounds': np.array(space.transformed_bounds).transpose().tolist(), 77 | 'verb_filenameprefix': tempdir, 78 | } 79 | 80 | def delete_tempdir(self, *args, **kargs): 81 | os.removedirs(tempdir) 82 | return 83 | 84 | model = cma.CMAEvolutionStrategy(x0, sigma0, options) 85 | model.logger.__del__ = delete_tempdir 86 | switch = { -1: None, # use number of available CPUs 87 | 1: 0, # avoid using multiprocessor for just one CPU 88 | } 89 | with cma.optimization_tools.EvalParallel2(func, 90 | number_of_processes=switch.get(n_jobs, n_jobs)) as parallel_func: 91 | for _i in range(n_calls): 92 | if model.stop(): break 93 | new_xi = model.ask() 94 | new_xi_denorm = space.inverse_transform(np.array(new_xi)) 95 | # new_yi = [func(x) for x in new_xi_denorm] 96 | new_yi = parallel_func(new_xi_denorm) 97 | 98 | model.tell(new_xi, new_yi) 99 | model.logger.add() 100 | if verbose: model.disp() 101 | 102 | xi += new_xi_denorm 103 | yi += new_yi 104 | results = skopt.utils.create_result(xi, yi) 105 | for f in callback: f(results) 106 | 107 | results = skopt.utils.create_result(xi, yi, space) 108 | model.logger.load() 109 | results.cma_es = model 110 | results.cma_logger = model.logger 111 | results.specs = specs 112 | return results 113 | -------------------------------------------------------------------------------- 
/embeding_distribution.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot the textual embedding and projection w_p Q distribution in stable diffusion. 3 | """ 4 | 5 | import os 6 | import argparse 7 | 8 | from transformers import CLIPTextModel, CLIPTokenizer 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from sklearn.decomposition import PCA 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--model_path", default='./ckpt', type=str) 16 | parser.add_argument("--model_dim", default=768, type=int) 17 | parser.add_argument("--lamda", default=5, type=int) 18 | args = parser.parse_args() 19 | 20 | 21 | text_encoder = CLIPTextModel.from_pretrained( 22 | os.path.join(args.model_path, "text_encoder") 23 | ) 24 | embedding = text_encoder.get_input_embeddings().weight.clone().cpu() 25 | print(embedding.size()) 26 | embedding = embedding.detach().cpu().numpy() 27 | mu_hat = np.mean(embedding.reshape(-1)) 28 | std_hat = np.std(embedding.reshape(-1)) 29 | print(mu_hat, std_hat) 30 | number = embedding.reshape(-1).shape[0] 31 | normal = np.random.normal(loc=0, scale=1 / args.model_dim * args.lamda, size = number) 32 | sampling = np.random.normal(loc=0, scale=std_hat * args.lamda, size = number) 33 | 34 | 35 | 36 | pca = PCA(n_components=args.model_dim) 37 | pca.fit(embedding) 38 | pca = pca.components_.reshape(-1) 39 | 40 | # initialize the Q with norm(0, 0.5) 41 | cma = np.random.normal(loc=0, scale=0.5, size = number) 42 | 43 | # projection distribution with W_p Q 44 | normal = cma * normal 45 | sampling = cma * sampling 46 | 47 | cma_pca = np.random.normal(loc=0, scale=0.5, size = pca.shape[0]) 48 | pca = cma_pca * pca 49 | 50 | 51 | kwargs = dict(alpha=0.5, bins=100, density=True, stacked=True) 52 | embedding = embedding.reshape(-1) 53 | 54 | plt.hist(embedding, **kwargs, color='g', label='Textual Embedding') 55 | plt.hist(normal, **kwargs, color='r', label='Random 
Norm') 56 | plt.hist(pca, **kwargs, color='black', label='PCA') 57 | plt.hist(sampling, **kwargs, color='b', label='Prior Norm') 58 | 59 | plt.gca().set(ylabel='Frequency') 60 | plt.xlim(-0.1,0.1) 61 | plt.legend() 62 | plt.show() 63 | 64 | 65 | if __name__ == '__main__': 66 | main() -------------------------------------------------------------------------------- /figures/case.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/figures/case.png -------------------------------------------------------------------------------- /figures/cma.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/figures/cma.gif -------------------------------------------------------------------------------- /figures/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/figures/framework.png -------------------------------------------------------------------------------- /infer_inversion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from torch import autocast 5 | 6 | import PIL 7 | from PIL import Image 8 | 9 | from diffusers import StableDiffusionPipeline 10 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 11 | 12 | 13 | def image_grid(imgs, rows, cols): 14 | assert len(imgs) == rows*cols 15 | 16 | w, h = imgs[0].size 17 | grid = Image.new('RGB', size=(cols*w, rows*h)) 18 | grid_w, grid_h = grid.size 19 | 20 | for i, img in enumerate(imgs): 21 | grid.paste(img, box=(i%cols*w, i//cols*h)) 22 | return grid 23 | 24 | 25 | def 
load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None): 26 | loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") 27 | 28 | # separate token and the embeds 29 | trained_token = list(loaded_learned_embeds.keys())[0] 30 | embeds = loaded_learned_embeds[trained_token] 31 | 32 | # cast to dtype of text_encoder 33 | dtype = text_encoder.get_input_embeddings().weight.dtype 34 | embeds.to(dtype) 35 | 36 | # add the token in tokenizer 37 | token = token if token is not None else trained_token 38 | num_added_tokens = tokenizer.add_tokens(token) 39 | if num_added_tokens == 0: 40 | raise ValueError(f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer.") 41 | 42 | # resize the token embeddings 43 | text_encoder.resize_token_embeddings(len(tokenizer)) 44 | 45 | # get the id for the token and assign the embeds 46 | token_id = tokenizer.convert_tokens_to_ids(token) 47 | text_encoder.get_input_embeddings().weight.data[token_id] = embeds 48 | 49 | 50 | 51 | def main(): 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("--model_path", default='./ckpt', type=str) 54 | parser.add_argument("--inversion_path", default='./save/learned_embeds.bin' , type=str) 55 | parser.add_argument("--prompt", default='city under the sun, painting, in a style of ' , type=str) 56 | args = parser.parse_args() 57 | 58 | model_path = args.model_path 59 | learned_embeds_path = args.inversion_path 60 | device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") 61 | 62 | tokenizer = CLIPTokenizer.from_pretrained( 63 | os.path.join(model_path, 'tokenizer') 64 | ) 65 | text_encoder = CLIPTextModel.from_pretrained( 66 | os.path.join(model_path, 'text_encoder') 67 | ) 68 | load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer) 69 | 70 | pipe = StableDiffusionPipeline.from_pretrained( 71 | model_path, 72 | text_encoder=text_encoder, 73 | 
tokenizer=tokenizer, 74 | ).to(device) 75 | 76 | prompt = args.prompt 77 | 78 | num_samples = 2 #@param {type:"number"} 79 | num_rows = 2 #@param {type:"number"} 80 | 81 | all_images = [] 82 | for _ in range(num_rows): 83 | with autocast("cuda"): 84 | images = pipe(prompt, num_images_per_prompt=num_samples, num_inference_steps=50, guidance_scale=7.5).images 85 | all_images.extend(images) 86 | 87 | grid = image_grid(all_images, num_samples, num_rows) 88 | grid.save('./1.png') 89 | 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /initialize_inversion.py: -------------------------------------------------------------------------------- 1 | """ 2 | automatically initialize the textual inversion with CLIP and no-parameter cross-attention 3 | """ 4 | 5 | import torch 6 | import os 7 | import argparse 8 | 9 | from PIL import Image 10 | import torch.nn.functional as F 11 | from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor, CLIPTextModel 12 | from utils import imagenet_template, automatic_subjective_classnames 13 | 14 | 15 | def embedding_generate(model, tokenizer, text_encoder, classnames, templates, device): 16 | """ 17 | pre-caculate the template sentence, token embeddings 18 | """ 19 | with torch.no_grad(): 20 | sentence_weights = [] 21 | token_weights = [] 22 | token_embedding_table = text_encoder.get_input_embeddings().weight.data 23 | for classname in classnames: 24 | texts = [template(classname) for template in templates] # format with class 25 | texts = tokenizer(texts, padding="max_length", max_length=77, truncation=True, return_tensors="pt") # tokenize 26 | texts = texts['input_ids'].to(device) 27 | class_embeddings = model.get_text_features(texts) 28 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 29 | class_embedding /= class_embedding.norm() 30 | sentence_weights.append(class_embedding) 31 | 32 | token_ids = 
tokenizer.encode(classname,add_special_tokens=False) 33 | token_embedding_list = [] 34 | for token_id in token_ids: 35 | token_embedding_list.append(token_embedding_table[token_id]) 36 | token_weights.append(torch.mean(torch.stack(token_embedding_list), dim=0)) 37 | 38 | sentence_weights = torch.stack(sentence_weights, dim=1).to(device) 39 | token_weights = torch.stack(token_weights, dim=0).to(device) 40 | return sentence_weights, token_weights 41 | 42 | 43 | 44 | def image_condition_embed_initialize(image_feature_list, sentence_embeddings, token_embeddings): 45 | """ 46 | no-parameter cross-attention: query: image, key: sentence, value: token 47 | """ 48 | inversion_emb_list = [] 49 | for image_features in image_feature_list: 50 | cross_attention = image_features @ sentence_embeddings 51 | attention_probs = F.softmax(cross_attention, dim=-1) 52 | inversion_emb = torch.matmul(attention_probs, token_embeddings) 53 | inversion_emb_list.append(inversion_emb) 54 | 55 | final_inversion = torch.mean(torch.stack(inversion_emb_list), dim=0) 56 | final_inversion = final_inversion / final_inversion.norm() 57 | return final_inversion 58 | 59 | 60 | 61 | def main(): 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument("--save_path", default='./save', type=str) 64 | parser.add_argument("--data_path", default='./cat', type=str) 65 | args = parser.parse_args() 66 | 67 | save_path = args.save_path 68 | data_path = args.data_path 69 | device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") 70 | tokenizer = CLIPTokenizer.from_pretrained('./clip') 71 | model = CLIPModel.from_pretrained('./clip') 72 | text_encoder = CLIPTextModel.from_pretrained('./clip') 73 | processor = CLIPProcessor.from_pretrained('./clip') 74 | 75 | sentence_embeddings, token_embeddings = embedding_generate(model, 76 | tokenizer, 77 | text_encoder, 78 | automatic_subjective_classnames, 79 | imagenet_template, 80 | device) 81 | print('sentence embedding size: ', 
sentence_embeddings.size(), ' token embedding size: ', token_embeddings.size()) 82 | 83 | image_feature_list = [] 84 | name_list = os.listdir(data_path) 85 | for name in name_list: 86 | image_path = os.path.join(data_path, name) 87 | image = Image.open(image_path) 88 | inputs = processor(images=image, return_tensors="pt") 89 | image_features = model.get_image_features(**inputs) 90 | image_features = F.normalize(image_features, dim=-1) 91 | image_feature_list.append(image_features) 92 | print('image size: ', len(image_feature_list)) 93 | 94 | inversion_emb = image_condition_embed_initialize(image_feature_list, sentence_embeddings, token_embeddings) 95 | 96 | inversion_emb_dict = {"initialize": inversion_emb.detach().cpu()} 97 | torch.save(inversion_emb_dict, os.path.join(save_path, 'initialize_emb.bin')) 98 | 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /train_inversion.py: -------------------------------------------------------------------------------- 1 | import cma 2 | import argparse 3 | import torch 4 | import os 5 | import numpy as np 6 | import copy 7 | from sklearn.decomposition import PCA 8 | 9 | from diffusers import StableDiffusionPipeline, DDPMScheduler 10 | from transformers import CLIPTextModel, CLIPTokenizer 11 | 12 | import torch.nn.functional as F 13 | from utils import TextualInversionDataset 14 | from tqdm import tqdm 15 | 16 | 17 | class GradientFreePipeline: 18 | def __init__(self, model_path, args, init_text_inversion=None, ): 19 | self.tokenizer = CLIPTokenizer.from_pretrained( 20 | os.path.join(model_path, 'tokenizer') 21 | ) 22 | self.text_encoder = CLIPTextModel.from_pretrained( 23 | os.path.join(model_path, 'text_encoder') 24 | ) 25 | self.pipe = StableDiffusionPipeline.from_pretrained( 26 | model_path, 27 | text_encoder=self.text_encoder, 28 | tokenizer=self.tokenizer, 29 | ).to(args.device) 30 | 31 | if args.projection_modeling == 
'prior_normal': 32 | self.linear = torch.nn.Linear(args.intrinsic_dim, args.model_dim, bias=False).to(args.device) 33 | embedding = self.text_encoder.get_input_embeddings().weight.clone().cpu() 34 | mu_hat = np.mean(embedding.reshape(-1).detach().cpu().numpy()) 35 | std_hat = np.std(embedding.reshape(-1).detach().cpu().numpy()) 36 | mu = 0.0 37 | std = args.alpha * std_hat / (np.sqrt(args.intrinsic_dim) * args.sigma) 38 | 39 | # incorporate temperature factor 40 | # temp = intrinsic_dim - std_hat * std_hat 41 | # mu = mu_hat / temp 42 | # std = std_hat / np.sqrt(temp) 43 | print('[Embedding] mu: {} | std: {} [RandProj] mu: {} | std: {}'.format(mu_hat, std_hat, mu, std)) 44 | for p in self.linear.parameters(): 45 | torch.nn.init.normal_(p, mu, std) 46 | 47 | elif args.projection_modeling == 'pca': 48 | embedding = self.text_encoder.get_input_embeddings().weight.clone().cpu() 49 | embedding = embedding.detach().cpu().numpy() # (49408, 768) 50 | 51 | self.pca_model = PCA(n_components=args.intrinsic_dim) 52 | self.pca_model.fit(embedding) 53 | 54 | 55 | # Add the placeholder token in tokenizer 56 | num_added_tokens = self.tokenizer.add_tokens(args.placeholder_token) 57 | if num_added_tokens == 0: 58 | raise ValueError( 59 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" 60 | " `placeholder_token` that is not already in the tokenizer." 
61 | ) 62 | # Convert the initializer_token, placeholder_token to ids 63 | token_ids = self.tokenizer.encode(args.initializer_token, add_special_tokens=False) 64 | 65 | initializer_token_id = token_ids[0] 66 | placeholder_token_id = self.tokenizer.convert_tokens_to_ids(args.placeholder_token) 67 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 68 | self.text_encoder.resize_token_embeddings(len(self.tokenizer)) 69 | # Initialise the newly added placeholder token with the embeddings of the initializer token 70 | token_embeds = self.text_encoder.get_input_embeddings().weight.data 71 | token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] 72 | 73 | print('convert text inversion: ', args.placeholder_token, 'in id: ', str(placeholder_token_id)) 74 | self.placeholder_token_id = placeholder_token_id 75 | self.placeholder_token = args.placeholder_token 76 | self.num_call = 0 77 | 78 | train_dataset = TextualInversionDataset( 79 | data_root=args.train_data_dir, 80 | tokenizer=self.tokenizer, 81 | size=args.resolution, 82 | placeholder_token=args.placeholder_token, 83 | repeats=args.repeats, 84 | learnable_property=args.learnable_property, 85 | center_crop=args.center_crop, 86 | set="train", 87 | ) 88 | self.dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.repeats, shuffle=True) 89 | self.batch_size = args.repeats 90 | self.device = args.device 91 | print('load data length: ', len(self.dataloader)) 92 | 93 | # optimize incremental elements or original inversion 94 | if init_text_inversion is not None: 95 | self.init_text_inversion = init_text_inversion.to(args.device) 96 | else: 97 | self.init_text_inversion = token_embeds[initializer_token_id].to(args.device) 98 | 99 | self.args = args 100 | self.best_inversion = None 101 | 102 | def eval(self, inversion_embedding): 103 | self.num_call += 1 104 | pe_list = [] 105 | if isinstance(inversion_embedding, list): # multiple queries 106 | for pe in 
inversion_embedding: 107 | if self.args.projection_modeling == 'prior_normal': 108 | z = torch.tensor(pe).type(torch.float32).to(self.device) # z 109 | with torch.no_grad(): 110 | z = self.linear(z) # W_p Q 111 | if self.init_text_inversion is not None: 112 | z = z + self.init_text_inversion # W_p Q + p_0 113 | elif self.args.projection_modeling == 'pca': 114 | z = self.pca_model.inverse_transform(pe) # project the original text embedding space 115 | z = torch.tensor(z).type(torch.float32).to(self.device) 116 | if self.init_text_inversion is not None: 117 | z = z + self.init_text_inversion 118 | pe_list.append(z) 119 | 120 | elif isinstance(inversion_embedding, np.ndarray): # single query or None 121 | if self.args.projection_modeling == 'prior_normal': 122 | inversion_embedding = torch.tensor(inversion_embedding).type(torch.float32).to(self.device) # z 123 | with torch.no_grad(): 124 | inversion_embedding = self.linear(inversion_embedding) # W_p Q 125 | elif self.args.projection_modeling == 'pca': 126 | inversion_embedding = self.pca_model.inverse_transform(inversion_embedding) 127 | inversion_embedding = torch.tensor(inversion_embedding).type(torch.float32).to(self.device) 128 | if self.init_text_inversion is not None: 129 | inversion_embedding = inversion_embedding + self.init_text_inversion # W_p Q + p_0 130 | pe_list.append(inversion_embedding) 131 | else: 132 | raise ValueError( 133 | f'[Inversion Embedding] Only support [list, numpy.ndarray], got `{type(inversion_embedding)}` instead.' 
134 | ) 135 | 136 | loss_list = [] 137 | print('begin to calculate loss') 138 | 139 | # fixed time step for fair evaluation 140 | noise_scheduler = DDPMScheduler.from_config('./ckpt/scheduler') 141 | timesteps = torch.randint( 142 | 0, noise_scheduler.config.num_train_timesteps, (self.batch_size,), device=self.device 143 | ).long() 144 | 145 | best_loss = 1000 146 | best_inversion = None 147 | 148 | for pe in tqdm(pe_list): 149 | token_embeds = self.text_encoder.get_input_embeddings().weight.data 150 | pe.to(self.text_encoder.get_input_embeddings().weight.dtype) 151 | token_embeds[self.placeholder_token_id] = pe 152 | loss = calculate_mse_loss(self.pipe, self.dataloader, self.device, noise_scheduler, timesteps) 153 | if loss < best_loss: 154 | best_loss = loss 155 | best_inversion = pe 156 | loss_list.append(loss) 157 | 158 | # update total point 159 | self.best_inversion = best_inversion 160 | 161 | return loss_list 162 | 163 | 164 | def save(self, output_path): 165 | learned_embeds_dict = {self.placeholder_token: self.best_inversion.detach().cpu()} 166 | torch.save(learned_embeds_dict, os.path.join(output_path, "learned_embeds.bin")) 167 | 168 | 169 | 170 | def calculate_mse_loss(image_generator, dataloader, device, noise_scheduler, timesteps): 171 | # print(image_generator.text_encoder.get_input_embeddings().weight.data[49408]) 172 | 173 | loss_cum = .0 174 | with torch.no_grad(): 175 | for batch in dataloader: 176 | # Convert images to latent space 177 | latents = image_generator.vae.encode(batch["pixel_values"].to(device)).latent_dist.sample().detach() 178 | latents = latents * 0.18215 179 | 180 | # Sample noise that we'll add to the latents 181 | noise = torch.randn(latents.shape).to(latents.device) 182 | # Sample a random timestep for each image 183 | 184 | # Add noise to the latents according to the noise magnitude at each timestep 185 | # (this is the forward diffusion process) 186 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 187 
| 188 | # Get the text embedding for conditioning 189 | encoder_hidden_states = image_generator.text_encoder(batch["input_ids"].to(device))[0] 190 | 191 | # Predict the noise residual 192 | noise_pred = image_generator.unet(noisy_latents, timesteps, encoder_hidden_states).sample 193 | 194 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 195 | loss_cum += loss.item() 196 | 197 | return loss_cum / len(dataloader) 198 | 199 | 200 | 201 | 202 | 203 | def main(): 204 | parser = argparse.ArgumentParser() 205 | 206 | parser.add_argument("--intrinsic_dim", default=256, type=int) 207 | parser.add_argument("--k_shot", default=16, type=int) 208 | parser.add_argument("--batch_size", default=32, type=int) 209 | parser.add_argument("--budget", default=5000, type=int) # number of iterations 210 | parser.add_argument("--popsize", default=20, type=int) # number of candidates 211 | parser.add_argument("--bound", default=0, type=int) 212 | parser.add_argument("--sigma", default=1, type=float) 213 | parser.add_argument("--alpha", default=1, type=float) 214 | parser.add_argument("--print_every", default=50, type=int) 215 | parser.add_argument("--eval_every", default=100, type=int) 216 | parser.add_argument("--alg", default='CMA', type=str) # support other advanced evelution strategy 217 | parser.add_argument("--projection_modeling", default='pca', type=str) # decomposition method {'pca', 'prior_norm'} 218 | parser.add_argument("--model_dim", default=768, type=int) # dim of textual inversion 219 | parser.add_argument("--inversion_initialize", default='./save/initialize_emb.bin', type=str) # dim of textual inversion 220 | parser.add_argument("--seed", default=2023, type=int) 221 | parser.add_argument("--loss_type", default='noise', type=str) 222 | parser.add_argument("--cat_or_add", default='add', type=str) 223 | parser.add_argument("--device", default= torch.device("cuda:2" if torch.cuda.is_available() else "cpu")) 224 | parser.add_argument("--parallel", 
default=False, type=bool, help='Whether to allow parallel evaluation') 225 | 226 | parser.add_argument( 227 | "--placeholder_token", 228 | type=str, 229 | default='', 230 | help="A token to use as a placeholder for the concept.", 231 | ) 232 | parser.add_argument( 233 | "--initializer_token", 234 | type=str, 235 | default='painting', 236 | help="A token to use as initializer word." 237 | ) 238 | parser.add_argument( 239 | "--inference_framework", 240 | default='pt', 241 | type=str, 242 | help='''Which inference framework to use. 243 | Currently supports `pt` and `ort`, standing for pytorch and Microsoft onnxruntime respectively''' 244 | ) 245 | parser.add_argument( 246 | "--onnx_model_path", 247 | default=None, 248 | type=str, 249 | help='Path to your onnx model.' 250 | ) 251 | parser.add_argument( 252 | "--train_data_dir", 253 | type=str, 254 | default='./data', 255 | help="A folder containing the training data of instance images.", 256 | ) 257 | parser.add_argument( 258 | "--learnable_property", 259 | type=str, 260 | default="style", 261 | help="Choose between 'object' and 'style'" 262 | ) 263 | parser.add_argument( 264 | "--resolution", 265 | type=int, 266 | default=512, 267 | help=( 268 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 269 | " resolution" 270 | ), 271 | ) 272 | parser.add_argument( 273 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 274 | ) 275 | parser.add_argument("--repeats", type=int, default=5, help="How many times to repeat the training data.") 276 | 277 | args = parser.parse_args() 278 | 279 | cma_opts = { 280 | 'seed': args.seed, 281 | 'popsize': args.popsize, 282 | 'maxiter': args.budget if args.parallel else args.budget // args.popsize, 283 | 'verbose': -1, 284 | } 285 | 286 | if args.bound > 0: 287 | cma_opts['bounds'] = [-1 * args.bound, 1 * args.bound] 288 | 289 | if args.inversion_initialize is not None: 290 | 
print('initialize textual inversion') 291 | init_text_inversion = torch.load(args.inversion_initialize, map_location="cpu")["initialize"] 292 | else: 293 | init_text_inversion = None 294 | 295 | pipeline = GradientFreePipeline(model_path='./ckpt', args=args, init_text_inversion=init_text_inversion) 296 | 297 | es = cma.CMAEvolutionStrategy(args.intrinsic_dim * [0], args.sigma, inopts=cma_opts) 298 | 299 | while not es.stop(): 300 | solutions = es.ask() # (popsize, intrinsic_dim) 301 | fitnesses = pipeline.eval(solutions) 302 | print(fitnesses) # loss for each point 303 | es.tell(solutions, fitnesses) 304 | pipeline.save('./save') 305 | 306 | 307 | if __name__ == "__main__": 308 | main() 309 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import numpy as np 4 | from torchvision import transforms 5 | from PIL import Image 6 | import random 7 | import PIL 8 | import torch 9 | 10 | 11 | imagenet_templates_small = [ 12 | "a photo of a {}", 13 | "a rendering of a {}", 14 | "a cropped photo of the {}", 15 | "the photo of a {}", 16 | "a photo of a clean {}", 17 | "a photo of a dirty {}", 18 | "a dark photo of the {}", 19 | "a photo of my {}", 20 | "a photo of the cool {}", 21 | "a close-up photo of a {}", 22 | "a bright photo of the {}", 23 | "a cropped photo of a {}", 24 | "a photo of the {}", 25 | "a good photo of the {}", 26 | "a photo of one {}", 27 | "a close-up photo of the {}", 28 | "a rendition of the {}", 29 | "a photo of the clean {}", 30 | "a rendition of a {}", 31 | "a photo of a nice {}", 32 | "a good photo of a {}", 33 | "a photo of the nice {}", 34 | "a photo of the small {}", 35 | "a photo of the weird {}", 36 | "a photo of the large {}", 37 | "a photo of a cool {}", 38 | "a photo of a small {}", 39 | ] 40 | 41 | 42 | imagenet_style_templates_small = [ 43 | "a painting in 
the style of {}", 44 | "a rendering in the style of {}", 45 | "a cropped painting in the style of {}", 46 | "the painting in the style of {}", 47 | "a clean painting in the style of {}", 48 | "a dirty painting in the style of {}", 49 | "a dark painting in the style of {}", 50 | "a picture in the style of {}", 51 | "a cool painting in the style of {}", 52 | "a close-up painting in the style of {}", 53 | "a bright painting in the style of {}", 54 | "a cropped painting in the style of {}", 55 | "a good painting in the style of {}", 56 | "a close-up painting in the style of {}", 57 | "a rendition in the style of {}", 58 | "a nice painting in the style of {}", 59 | "a small painting in the style of {}", 60 | "a weird painting in the style of {}", 61 | "a large painting in the style of {}", 62 | ] 63 | 64 | 65 | 66 | automatic_subjective_classnames = [ 67 | "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 68 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 69 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 70 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 71 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 72 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 73 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 74 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 75 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 76 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 77 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 78 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 79 | "green mamba", "sea snake", 
"Saharan horned viper", "eastern diamondback rattlesnake", 80 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 81 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 82 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 83 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 84 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 85 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 86 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 87 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 88 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 89 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 90 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 91 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 92 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 93 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 94 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 95 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 96 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 97 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 98 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 99 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 100 | "Norfolk Terrier", "Norwich 
Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 101 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 102 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 103 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 104 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 105 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 106 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 107 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 108 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 109 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 110 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 111 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 112 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 113 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 114 | "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 115 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 116 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 117 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 118 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 119 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 120 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 121 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 122 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 123 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 124 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 125 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 126 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 127 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 128 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 129 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 130 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 131 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 132 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 133 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 134 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 135 | "langur", "black-and-white 
colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 136 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 137 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 138 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 139 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 140 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 141 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 142 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 143 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 144 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 145 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 146 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 147 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 148 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 149 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 150 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 151 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 152 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 153 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 154 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 155 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 156 | "coffee 
mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 157 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 158 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 159 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 160 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 161 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 162 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 163 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 164 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 165 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 166 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 167 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 168 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 169 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 170 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 171 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 172 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 173 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 174 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 175 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 176 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger 
bag", 177 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 178 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 179 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 180 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 181 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 182 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 183 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 184 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 185 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 186 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 187 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 188 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 189 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 190 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 191 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 192 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 193 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 194 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 195 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 196 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 197 | "rugby ball", "ruler measuring stick", "sneaker", "safe", 
"safety pin", "salt shaker", "sandal", 198 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 199 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 200 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 201 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 202 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 203 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 204 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 205 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 206 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 207 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 208 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 209 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 210 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 211 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 212 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 213 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 214 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 215 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 216 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 217 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 218 | "hair wig", "window screen", 
"window shade", "Windsor tie", "wine bottle", "airplane wing", 219 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 220 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 221 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 222 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 223 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 224 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 225 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", 226 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 227 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 228 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 229 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 230 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 231 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper", 232 | ] 233 | 234 | 235 | # refer to: https://github.com/samiramunir/Classifying-Painting-Art-Style-With-Deep-Learning 236 | automatic_style_classnames = [ 237 | "realism", 238 | "photorealism", 239 | "expressionism", 240 | "impressionism", 241 | "abstract", 242 | "surrealism", 243 | "pop art", 244 | "oil", 245 | "watercolor", 246 | "acrylic", 247 | "gouache", 248 | "pastel", 249 | "encaustic", 250 | "fresco", 251 | "spray paint", 252 | "digital", 253 | "history", 254 | "portrait", 255 | "genre", 256 | "landscape" 257 | ] 258 | 259 | 260 | imagenet_template = [ 261 | lambda c: f'a bad photo of a {c}.', 262 | lambda c: f'a photo of many 
{c}.', 263 | lambda c: f'a sculpture of a {c}.', 264 | lambda c: f'a photo of the hard to see {c}.', 265 | lambda c: f'a low resolution photo of the {c}.', 266 | lambda c: f'a rendering of a {c}.', 267 | lambda c: f'graffiti of a {c}.', 268 | lambda c: f'a bad photo of the {c}.', 269 | lambda c: f'a cropped photo of the {c}.', 270 | lambda c: f'a tattoo of a {c}.', 271 | lambda c: f'the embroidered {c}.', 272 | lambda c: f'a photo of a hard to see {c}.', 273 | lambda c: f'a bright photo of a {c}.', 274 | lambda c: f'a photo of a clean {c}.', 275 | lambda c: f'a photo of a dirty {c}.', 276 | lambda c: f'a dark photo of the {c}.', 277 | lambda c: f'a drawing of a {c}.', 278 | lambda c: f'a photo of my {c}.', 279 | lambda c: f'the plastic {c}.', 280 | lambda c: f'a photo of the cool {c}.', 281 | lambda c: f'a close-up photo of a {c}.', 282 | lambda c: f'a black and white photo of the {c}.', 283 | lambda c: f'a painting of the {c}.', 284 | lambda c: f'a painting of a {c}.', 285 | lambda c: f'a pixelated photo of the {c}.', 286 | lambda c: f'a sculpture of the {c}.', 287 | lambda c: f'a bright photo of the {c}.', 288 | lambda c: f'a cropped photo of a {c}.', 289 | lambda c: f'a plastic {c}.', 290 | lambda c: f'a photo of the dirty {c}.', 291 | lambda c: f'a jpeg corrupted photo of a {c}.', 292 | lambda c: f'a blurry photo of the {c}.', 293 | lambda c: f'a photo of the {c}.', 294 | lambda c: f'a good photo of the {c}.', 295 | lambda c: f'a rendering of the {c}.', 296 | lambda c: f'a {c} in a video game.', 297 | lambda c: f'a photo of one {c}.', 298 | lambda c: f'a doodle of a {c}.', 299 | lambda c: f'a close-up photo of the {c}.', 300 | lambda c: f'a photo of a {c}.', 301 | lambda c: f'the origami {c}.', 302 | lambda c: f'the {c} in a video game.', 303 | lambda c: f'a sketch of a {c}.', 304 | lambda c: f'a doodle of the {c}.', 305 | lambda c: f'a origami {c}.', 306 | lambda c: f'a low resolution photo of a {c}.', 307 | lambda c: f'the toy {c}.', 308 | lambda c: f'a 
rendition of the {c}.', 309 | lambda c: f'a photo of the clean {c}.', 310 | lambda c: f'a photo of a large {c}.', 311 | lambda c: f'a rendition of a {c}.', 312 | lambda c: f'a photo of a nice {c}.', 313 | lambda c: f'a photo of a weird {c}.', 314 | lambda c: f'a blurry photo of a {c}.', 315 | lambda c: f'a cartoon {c}.', 316 | lambda c: f'art of a {c}.', 317 | lambda c: f'a sketch of the {c}.', 318 | lambda c: f'a embroidered {c}.', 319 | lambda c: f'a pixelated photo of a {c}.', 320 | lambda c: f'itap of the {c}.', 321 | lambda c: f'a jpeg corrupted photo of the {c}.', 322 | lambda c: f'a good photo of a {c}.', 323 | lambda c: f'a plushie {c}.', 324 | lambda c: f'a photo of the nice {c}.', 325 | lambda c: f'a photo of the small {c}.', 326 | lambda c: f'a photo of the weird {c}.', 327 | lambda c: f'the cartoon {c}.', 328 | lambda c: f'art of the {c}.', 329 | lambda c: f'a drawing of the {c}.', 330 | lambda c: f'a photo of the large {c}.', 331 | lambda c: f'a black and white photo of a {c}.', 332 | lambda c: f'the plushie {c}.', 333 | lambda c: f'a dark photo of a {c}.', 334 | lambda c: f'itap of a {c}.', 335 | lambda c: f'graffiti of the {c}.', 336 | lambda c: f'a toy {c}.', 337 | lambda c: f'itap of my {c}.', 338 | lambda c: f'a photo of a cool {c}.', 339 | lambda c: f'a photo of a small {c}.', 340 | lambda c: f'a tattoo of the {c}.', 341 | ] 342 | 343 | 344 | 345 | class TextualInversionDataset(Dataset): 346 | def __init__( 347 | self, 348 | data_root, 349 | tokenizer, 350 | learnable_property="style", # [object, style] 351 | size=512, 352 | repeats=1, 353 | interpolation="bicubic", 354 | flip_p=0.5, 355 | set="train", 356 | placeholder_token="*", 357 | center_crop=False, 358 | ): 359 | self.data_root = data_root 360 | self.tokenizer = tokenizer 361 | self.learnable_property = learnable_property 362 | self.size = size 363 | self.placeholder_token = placeholder_token 364 | self.center_crop = center_crop 365 | self.flip_p = flip_p 366 | 367 | self.image_paths = 
[os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 368 | 369 | self.num_images = len(self.image_paths) 370 | self._length = self.num_images 371 | 372 | if set == "train": 373 | self._length = self.num_images * repeats 374 | 375 | self.interpolation = { 376 | "linear": PIL.Image.LINEAR, 377 | "bilinear": PIL.Image.BILINEAR, 378 | "bicubic": PIL.Image.BICUBIC, 379 | "lanczos": PIL.Image.LANCZOS, 380 | }[interpolation] 381 | 382 | self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small 383 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) 384 | 385 | def __len__(self): 386 | return self._length 387 | 388 | def __getitem__(self, i): 389 | example = {} 390 | try: 391 | image = Image.open(self.image_paths[i % self.num_images]) 392 | except: 393 | image = Image.open(self.image_paths[(i+1) % self.num_images]) 394 | if not image.mode == "RGB": 395 | image = image.convert("RGB") 396 | 397 | placeholder_string = self.placeholder_token 398 | text = random.choice(self.templates).format(placeholder_string) 399 | 400 | example["input_ids"] = self.tokenizer( 401 | text, 402 | padding="max_length", 403 | truncation=True, 404 | max_length=self.tokenizer.model_max_length, 405 | return_tensors="pt", 406 | ).input_ids[0] 407 | 408 | # default to score-sde preprocessing 409 | img = np.array(image).astype(np.uint8) 410 | 411 | if self.center_crop: 412 | crop = min(img.shape[0], img.shape[1]) 413 | h, w, = ( 414 | img.shape[0], 415 | img.shape[1], 416 | ) 417 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] 418 | 419 | image = Image.fromarray(img) 420 | image = image.resize((self.size, self.size), resample=self.interpolation) 421 | 422 | image = self.flip_transform(image) 423 | image = np.array(image).astype(np.uint8) 424 | image = (image / 127.5 - 1.0).astype(np.float32) 425 | 426 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 
1) 427 | return example 428 | 429 | --------------------------------------------------------------------------------