├── README.md
├── cma
│   ├── LICENSE
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   ├── bbobbenchmarks.cpython-39.pyc
│   │   ├── constraints_handler.cpython-39.pyc
│   │   ├── evolution_strategy.cpython-39.pyc
│   │   ├── fitness_functions.cpython-39.pyc
│   │   ├── fitness_transformations.cpython-39.pyc
│   │   ├── interfaces.cpython-39.pyc
│   │   ├── logger.cpython-39.pyc
│   │   ├── optimization_tools.cpython-39.pyc
│   │   ├── purecma.cpython-39.pyc
│   │   ├── recombination_weights.cpython-39.pyc
│   │   ├── restricted_gaussian_sampler.cpython-39.pyc
│   │   ├── s.cpython-39.pyc
│   │   ├── sampler.cpython-39.pyc
│   │   ├── sigma_adaptation.cpython-39.pyc
│   │   └── transformations.cpython-39.pyc
│   ├── bbobbenchmarks.py
│   ├── constraints_handler.py
│   ├── evolution_strategy.py
│   ├── fitness_functions.py
│   ├── fitness_models.py
│   ├── fitness_transformations.py
│   ├── interfaces.py
│   ├── logger.py
│   ├── optimization_tools.py
│   ├── purecma.py
│   ├── recombination_weights.py
│   ├── restricted_gaussian_sampler.py
│   ├── s.py
│   ├── sampler.py
│   ├── sigma_adaptation.py
│   ├── test.py
│   ├── transformations.py
│   ├── utilities
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── math.cpython-39.pyc
│   │   │   ├── python3for2.cpython-39.pyc
│   │   │   └── utils.cpython-39.pyc
│   │   ├── math.py
│   │   ├── python3for2.py
│   │   └── utils.py
│   └── wrapper.py
├── embeding_distribution.py
├── figures
│   ├── case.png
│   ├── cma.gif
│   └── framework.png
├── infer_inversion.py
├── initialize_inversion.py
├── train_inversion.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Gradient-Free Textual Inversion
2 |
3 | Gradient-free textual inversion for personalized text-to-image generation.
4 | We use the evolution strategy described by [OpenAI](https://openai.com/blog/evolution-strategies/) to optimize the pseudo-word embeddings without gradients.
5 | Our implementation is fully compatible with [diffusers](https://github.com/huggingface/diffusers) and Stable Diffusion models.
6 |
11 | Figure: Evolution process for textual embeddings.
15 |
16 | ## What does this repo do?
17 |
18 | Current personalized text-to-image approaches, which learn to bind a unique identifier to specific subjects or styles from a few given images, usually introduce a special word and tune its embedding parameters through gradient descent.
19 | A natural question is whether textual inversions can be optimized by accessing only model inference. Requiring only forward computation to determine the textual inversion retains the benefits of efficient computation and safe deployment.
20 |
21 | To this end, we introduce a gradient-free framework to optimize the continuous textual inversion in personalized text-to-image generation.
22 | Specifically, we first initialize the textual inversion with non-parametric cross-attention, so that the starting point lies in the latent embedding space.
23 | Then, instead of optimizing in the original high-dimensional embedding space, which is intractable for derivative-free optimization, we perform the optimization in a decomposed subspace obtained with (i) PCA and (ii) prior normalization, using an *iterative* evolution strategy.
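The subspace search described above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the repository's code: the helper names (`pca_subspace`, `to_embedding`, `evolution_strategy`) are hypothetical, the token embeddings are random stand-ins for a real text encoder's vocabulary, `toy_loss` replaces the actual objective (which would require forward passes of the diffusion model), and both the prior-normalization step and the cross-attention initialization are omitted. The ES follows the OpenAI recipe of mirrored Gaussian perturbations, with standardized losses as a simple stand-in for rank-based fitness shaping (the repository also vendors the `cma` package, so its actual optimizer may be CMA-ES).

```python
import numpy as np


def pca_subspace(token_embeddings, k):
    """Mean and top-k principal directions of a (num_tokens, dim) embedding matrix."""
    mean = token_embeddings.mean(axis=0)
    # Rows of vt are orthonormal principal directions, ordered by explained variance.
    _, _, vt = np.linalg.svd(token_embeddings - mean, full_matrices=False)
    return mean, vt[:k]


def to_embedding(z, mean, basis):
    """Map subspace coordinates z, shape (k,), to a full embedding, shape (dim,)."""
    return mean + z @ basis


def evolution_strategy(loss, z0, sigma=0.1, lr=0.01, popsize=20, iters=300, seed=0):
    """Minimize `loss` using only loss evaluations (no backpropagation).

    OpenAI-style ES: the search gradient is estimated from mirrored Gaussian
    perturbations whose losses are standardized (a simple fitness shaping).
    """
    rng = np.random.default_rng(seed)
    z = np.asarray(z0, dtype=float).copy()
    for _ in range(iters):
        half = rng.standard_normal((popsize // 2, z.size))
        noise = np.concatenate([half, -half])  # mirrored (antithetic) samples
        losses = np.array([loss(z + sigma * eps) for eps in noise])
        scores = (losses - losses.mean()) / (losses.std() + 1e-12)
        grad_estimate = noise.T @ scores / (len(noise) * sigma)
        z -= lr * grad_estimate  # step against the estimated gradient
    return z


# Hypothetical stand-ins: random "token embeddings" and a target vector.
# In the real pipeline the loss would come from diffusion-model forward passes.
rng = np.random.default_rng(42)
vocab = rng.standard_normal((500, 32)) @ rng.standard_normal((32, 32))
mean, basis = pca_subspace(vocab, k=8)
target = to_embedding(rng.standard_normal(8), mean, basis)


def toy_loss(z):
    return float(np.sum((to_embedding(z, mean, basis) - target) ** 2))


z0 = np.zeros(8)
z_best = evolution_strategy(toy_loss, z0)
print(f"loss before: {toy_loss(z0):.4f}, after: {toy_loss(z_best):.4f}")
```

Because the rows of `basis` are orthonormal, a step of a given length in `z` moves the embedding by the same length inside the subspace, so the ES searches 8 instead of 32 dimensions without rescaling the loss landscape there.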
24 |
29 | Figure: Overview of the proposed gradient-free textual inversion framework.
34 |
35 | ## Cases
36 |
37 | Some examples generated by standard textual inversion and by gradient-free inversion, both based on the Stable Diffusion model.
38 |
44 | Figure: Cases for personalized text-to-image generation.
48 |
49 | ## Process
50 |
51 |
52 | To initialize the textual inversion automatically with cross-attention, run:
53 | ```
54 | python initialize_inversion.py
55 | ```
56 |
57 | Then, to iteratively optimize the textual inversion with the gradient-free evolution strategy, run:
58 | ```
59 | python train_inversion.py
60 | ```
61 |
62 | Finally, with the trained textual inversion, you can generate personalized images with the `infer_inversion.py` script.
63 |
64 |
65 |
66 | ## Acknowledgements
67 |
68 | This repository is based on [diffusers](https://github.com/huggingface/diffusers) and its [textual inversion](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) script. We thank the authors for their clear code.
69 |
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
/cma/LICENSE:
--------------------------------------------------------------------------------
1 | The BSD 3-Clause License
2 | Copyright (c) 2014 Inria
3 | Author: Nikolaus Hansen, 2008-
4 | Author: Petr Baudis, 2014
5 | Author: Youhei Akimoto, 2016-
6 |
7 | Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions
9 | are met:
10 |
11 | 1. Redistributions of source code must retain the above copyright and
12 | authors notice, this list of conditions and the following disclaimer.
13 |
14 | 2. Redistributions in binary form must reproduce the above copyright
15 | and authors notice, this list of conditions and the following
16 | disclaimer in the documentation and/or other materials provided with
17 | the distribution.
18 |
19 | 3. Neither the name of the copyright holder nor the names of its
20 | contributors nor the authors names may be used to endorse or promote
21 | products derived from this software without specific prior written
22 | permission.
23 |
24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27 | AUTHORS OR CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
28 | OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
29 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
30 | DEALINGS IN THE SOFTWARE.
31 |
--------------------------------------------------------------------------------
/cma/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Package `cma` implements the CMA-ES (Covariance Matrix Adaptation
3 | Evolution Strategy).
4 |
5 | CMA-ES is a stochastic optimizer for robust non-linear non-convex
6 | derivative- and function-value-free numerical optimization.
7 |
8 | This implementation can be used with Python versions >= 2.7, namely,
9 | it was tested with 2.7, 3.5, 3.6, 3.7, 3.8.
10 |
11 | CMA-ES searches for a minimizer (a solution x in :math:`R^n`) of an
12 | objective function f (cost function), such that f(x) is minimal.
13 | Regarding f, only a passably reliable ranking of the candidate
14 | solutions in each iteration is necessary. Neither the function values
15 | themselves nor the gradient of f need to be available or matter (like
16 | in the downhill simplex Nelder-Mead algorithm). Some termination
17 | criteria, however, depend on actual f-values.
18 |
19 | The `cma` module provides two independent implementations of the
20 | CMA-ES algorithm in the classes `cma.CMAEvolutionStrategy` and
21 | `cma.purecma.CMAES`.
22 |
23 | In each implementation two interfaces are provided:
24 |
25 | - functions `fmin2` and `purecma.fmin`:
26 |     run a complete minimization of the passed objective function with
27 |     CMA-ES. `fmin` also provides optional restarts and noise handling.
28 |
29 | - class `CMAEvolutionStrategy` and `purecma.CMAES`:
30 |     allow for minimization such that the control of the iteration
31 |     loop remains with the user.
32 |
33 | The `cma` package root provides shortcuts to these and other classes and
34 | functions.
35 |
36 | Used external packages are `numpy` (only `purecma` does not depend on
37 | `numpy`) and `matplotlib.pyplot` (for `plot` etc., optional but highly
38 | recommended).
39 |
40 | Install
41 | =======
42 | To use the module, the folder ``cma`` only needs to be visible in the
43 | python path, e.g. in the current working directory.
44 |
45 | To install the module from PyPI, type::
46 |
47 |     pip install cma
48 |
49 | from the command line.
50 |
51 | To install the module from a ``cma`` folder::
52 |
53 |     pip install -e cma
54 |
55 | To upgrade the currently installed version, additionally use the ``-U``
56 | option.
57 |
58 | Testing
59 | =======
60 | From the system shell::
61 |
62 |     python -m cma.test -h
63 |     python -m cma.test
64 |     python -c "import cma.test; cma.test.main()"  # the same
65 |
66 | or from any (i)python shell::
67 |
68 |     import cma.test
69 |     cma.test.main()
70 |
71 | should run without complaints in roughly 20 to 100 seconds.
72 |
73 | Example
74 | =======
75 | From a python shell::
76 |
77 |     import cma
78 |     help(cma)  # "this" help message, use cma? in ipython
79 |     help(cma.fmin)
80 |     help(cma.CMAEvolutionStrategy)
81 |     help(cma.CMAOptions)
82 |     cma.CMAOptions('tol')  # display 'tolerance' termination options
83 |     cma.CMAOptions('verb')  # display verbosity options
84 |     res = cma.fmin(cma.ff.tablet, 15 * [1], 1)
85 |     es = cma.CMAEvolutionStrategy(15 * [1], 1).optimize(cma.ff.tablet)
86 |     help(es.result)
87 |     res[0], es.result[0]  # best evaluated solution
88 |     res[5], es.result[5]  # mean solution, presumably better with noise
89 |
90 | :See also: `fmin` (), `CMAOptions`, `CMAEvolutionStrategy`
91 |
92 | :Author: Nikolaus Hansen, 2008-
93 | :Author: Petr Baudis, 2014
94 | :Author: Youhei Akimoto, 2017-
95 |
96 | :License: BSD 3-Clause, see LICENSE file.
97 |
98 | """
99 |
100 | # How to create a html documentation file:
101 | # pydoctor --docformat=restructuredtext --make-html cma
102 | # old:
103 | # pydoc -w cma # edit the header (remove local pointers)
104 | # epydoc cma.py # comes close to javadoc but does not find the
105 | # # links of function references etc
106 | # doxygen needs @package cma as first line in the module docstring
107 | # some things like class attributes are not interpreted correctly
108 | # sphinx: doc style of doc.python.org, could not make it work (yet)
109 | # __docformat__ = "reStructuredText" # this hides some comments entirely?
110 |
111 | from __future__ import absolute_import # now local imports must use .
112 | # big difference between PY2 and PY3:
113 | from __future__ import division
114 | from __future__ import print_function
115 | # only necessary for python 2.5 (not supported) and not in heavy use
116 | from __future__ import with_statement
117 | __author__ = "Nikolaus Hansen and Petr Baudis and Youhei Akimoto"
118 | __license__ = "BSD 3-clause"
119 |
120 | import warnings as _warnings
121 |
122 | # __package__ = 'cma'
123 | from . import purecma
124 | try:
125 |     import numpy
126 |     del numpy
127 | except ImportError:
128 |     _warnings.warn('Only `cma.purecma` has been imported. Install `numpy` ("pip'
129 |                    ' install numpy") if you want to import the entire `cma`'
130 |                    ' package.')
131 | else:
132 |     from . import (constraints_handler, evolution_strategy, fitness_functions,
133 |                    fitness_transformations, interfaces, optimization_tools,
134 |                    sampler, sigma_adaptation, transformations, utilities,
135 |                    )
136 |     # from . import test  # gives a warning with python -m cma.test (since Python 3.5.3?)
137 |     test = 'type "import cma.test" to access the `test` module of `cma`'
138 |     from . import s
139 |     from .fitness_functions import ff
140 |     from .fitness_transformations import GlueArguments, ScaleCoordinates
141 |     from .evolution_strategy import fmin, fmin2, fmin_con, fmin_con2, CMAEvolutionStrategy, CMAOptions
142 |     from .logger import disp, plot, CMADataLogger
143 |     from .optimization_tools import NoiseHandler
144 |     from .constraints_handler import BoundPenalty, BoundTransform, ConstrainedFitnessAL
145 |     from .evolution_strategy import cma_default_options_
146 |
147 | del division, print_function, absolute_import, with_statement #, unicode_literals
148 |
149 | # fcts = ff # historical reasons only, replace cma.fcts with cma.ff first
150 |
151 | __version__ = "3.2.2"
152 | # $Source$ # according to PEP 8 style guides, but what is it good for?
153 | # $Id: __init__.py 4432 2020-05-28 18:39:09Z hansen $
154 | # bash $: svn propset svn:keywords 'Date Revision Id' __init__.py
155 |
--------------------------------------------------------------------------------
/cma/fitness_functions.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """versatile container for test objective functions.
3 |
4 | For the time being this is probably best used like::
5 |
6 |     from cma.fitness_functions import ff
7 |
8 | """
9 |
10 | from __future__ import (absolute_import, division, print_function,
11 |                         )  # unicode_literals, with_statement)
12 | # from __future__ import collections.MutableMapping
13 | # does not exist in future, otherwise Python 2.5 would work, since 0.91.01
14 | from .utilities.python3for2 import range
15 |
16 | import os, sys
17 |
18 | import numpy as np
19 | # arange, cos, size, eye, inf, dot, floor, outer, zeros, linalg.eigh,
20 | # sort, argsort, random, ones,...
21 | from numpy import array, dot, isscalar, sum  # sum is not needed
22 | # from numpy import inf, exp, log, isfinite
23 | # to access the built-in sum fct:  ``__builtins__.sum`` or ``del sum``
24 | # removes the imported sum and recovers the shadowed built-in
25 | try: np.median([1,2,3,2])  # fails currently in pypy, also sigma_vec.scaling
26 | except AttributeError:
27 |     def _median(x):
28 |         x = sorted(x)
29 |         if len(x) % 2:
30 |             return x[len(x) // 2]
31 |         return (x[len(x) // 2 - 1] + x[len(x) // 2]) / 2
32 |     np.median = _median
33 | from .transformations import Rotation
34 | from .utilities import utils
35 | from .utilities.utils import rglen
36 | del (division, print_function, absolute_import,
37 |      )  # unicode_literals, with_statement)
38 |
39 | # $Source$  # according to PEP 8 style guides, but what is it good for?
40 | # $Id: fitness_functions.py 4150 2015-03-20 13:53:36Z hansen $
41 | # bash $: svn propset svn:keywords 'Date Revision Id' fitness_functions.py
42 |
43 | try:
44 |     from . import bbobbenchmarks
45 |     BBOB = bbobbenchmarks  # for backwards compatibility
46 | except ImportError:
47 |     BBOB = """Call::
48 |         cma.ff.fetch_bbob_fcts()
49 |     to download and extract `bbobbenchmarks.py` and thereby setting
50 |     cma.ff.BBOB to these benchmarks; then, e.g., `F12 = cma.ff.BBOB.F12()`
51 |     returns an instance of F12 Bent Cigar.
52 |
53 |     CAVEAT: in the downloaded `bbobbenchmarks.py` file in L987
54 |     ``np.negative(idx)`` needs to be replaced by ``~idx``.
55 |     """
56 | from .fitness_transformations import rotate #, ComposedFunction, Function
57 |
58 | def elli(x, cond=1e6):
59 |     """unbound test function, needed to test multiprocessor, as long
60 |     as the other test functions are defined within a class and
61 |     only accessible via the class instance"""
62 |     return sum(cond**(np.arange(len(x)) / (len(x) - 1 + 1e-9)) * np.asarray(x)**2)
63 | def sphere(x):
64 |     return sum(np.asarray(x)**2)
65 |
66 | def _iqr(x):
67 |     x = sorted(x)
68 |     i1 = int(len(x) / 4)
69 |     i3 = int(3*len(x) / 4)
70 |     return x[i3] - x[i1]
71 |
72 | class FitnessFunctions(object):  # TODO: this class is not necessary anymore? But some effort is needed to change it
73 |     """collection of objective functions.
74 |
75 |     """
76 |     evaluations = 0  # number of calls or any other practical use
77 |     def __init__(self):
78 |         """"""
79 |     @property  # avoid pickle and multiprocessing error
80 |     def BBOB(self):
81 |         return bbobbenchmarks
82 |     try: BBOB.__doc__ = bbobbenchmarks.__doc__
83 |     except: pass  # in Python 2 __doc__ is readonly
84 |
85 |     def rot(self, x, fun, rot=1, args=()):
86 |         """returns ``fun(rotation(x), *args)``, ie. `fun` applied to a rotated argument"""
87 |         if len(np.shape(array(x))) > 1:  # parallelized
88 |             res = []
89 |             for x in x:
90 |                 res.append(self.rot(x, fun, rot, args))
91 |             return res
92 |
93 |         if rot:
94 |             return fun(rotate(x, *args))
95 |         else:
96 |             return fun(x)
97 |     def somenan(self, x, fun, p=0.1):
98 |         """returns sometimes np.nan, otherwise fun(x)"""
99 |         if np.random.rand(1) < p:
100 |             return np.nan
101 |         else:
102 |             return fun(x)
103 |
104 |     def epslow(self, fun, eps=1e-7, Neff=lambda x: int(len(x)**0.5)):
105 |         return lambda x: fun(x[:Neff(x)]) + eps * np.mean(x**2)
106 |
107 |     def rand(self, x):
108 |         """Random test objective function"""
109 |         return np.random.random(1)[0]
110 |     def linear(self, x):
111 |         return -x[0]
112 |     def lineard(self, x):
113 |         if 1 < 3 and any(array(x) < 0):
114 |             return np.nan
115 |         if 1 < 3 and sum([(10 + i) * x[i] for i in rglen(x)]) > 50e3:
116 |             return np.nan
117 |         return -sum(x)
118 |     def sphere(self, x):
119 |         """Sphere (squared norm) test objective function"""
120 |         # return np.random.rand(1)[0]**0 * sum(x**2) + 1 * np.random.rand(1)[0]
121 |         return sum((x + 0)**2)
122 |     def subspace_sphere(self, x, visible_ratio=1/2):
123 |         """Sphere function evaluated on a random subset of about
124 |         ``visible_ratio * len(x)`` coordinates, redrawn at each call"""
125 |         # here we could use an init function, that is this would
126 |         # preferably be a class
127 |         m = int(visible_ratio * len(x) + 1)
128 |         x = np.asarray(x)[np.random.permutation(len(x))[:m]]
129 |         return sum(x**2)
130 |     def pnorm(self, x, p=0.5):
131 |         return sum(np.abs(x)**p)**(1./p)
132 |     def grad_sphere(self, x, *args):
133 |         return 2 * np.asarray(x)
134 |     def grad_to_one(self, x, *args):
135 |         return np.asarray(x) - 1
136 |     def sphere_pos(self, x):
137 |         """Sphere (squared norm) test objective function"""
138 |         # return np.random.rand(1)[0]**0 * sum(x**2) + 1 * np.random.rand(1)[0]
139 |         c = 0.0
140 |         if x[0] < c:
141 |             return np.nan
142 |         return -c**2 + sum((x + 0)**2)
143 |     def spherewithoneconstraint(self, x):
144 |         return sum((x + 0)**2) if x[0] > 1 else np.nan
145 |     def elliwithoneconstraint(self, x, idx=[-1]):
146 |         return self.ellirot(x) if all(array(x)[idx] > 1) else np.nan
147 |
148 |     def spherewithnconstraints(self, x):
149 |         return sum((x + 0)**2) if all(array(x) > 1) else np.nan
150 |     # zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz
151 |     def noisysphere(self, x, noise=2.10e-9, cond=1.0, noise_offset=0.10):
152 |         """noise=10 does not work with default popsize, ``cma.NoiseHandler(dimension, 1e7)`` helps"""
153 |         return self.elli(x, cond=cond) * np.exp(0 + noise * np.random.randn() / len(x)) + noise_offset * np.random.rand()
154 |     def spherew(self, x):
155 |         """Sphere (squared norm) with sum x_i = 1 test objective function"""
156 |         # return np.random.rand(1)[0]**0 * sum(x**2) + 1 * np.random.rand(1)[0]
157 |         # s = sum(abs(x))
158 |         # return sum((x/s+0)**2) - 1/len(x)
159 |         # return sum((x/s)**2) - 1/len(x)
160 |         return -0.01 * x[0] + abs(x[0])**-2 * sum(x[1:]**2)
161 |     def epslowsphere(self, x, eps=1e-7, Neff=lambda x: int(len(x)**0.5)):
162 |         """TODO: define as wrapper"""
163 |         return np.mean(x[:Neff(x)]**2) + eps * np.mean(x**2)
164 |     def partsphere(self, x):
165 |         """Sphere (squared norm) test objective function"""
166 |         self.evaluations += 1
167 |         # return np.random.rand(1)[0]**0 * sum(x**2) + 1 * np.random.rand(1)[0]
168 |         dim = len(x)
169 |         x = array([x[i % dim] for i in range(2 * dim)])
170 |         N = 8
171 |         i = self.evaluations % dim
172 |         # f = sum(x[i:i + N]**2)
173 |         f = sum(x[np.random.randint(dim, size=N)]**2)
174 |         return f
175 |     def sectorsphere(self, x):
176 |         """asymmetric Sphere (squared norm) test objective function"""
177 |         return sum(x**2) + (1e6 - 1) * sum(x[x < 0]**2)
178 |     def cornersphere(self, x):
179 |         """Sphere (squared norm) test objective function constrained to the corner"""
180 |         nconstr = len(x) - 0
181 |         if any(x[:nconstr] < 1):
182 |             return np.nan
183 |         return sum(x**2) - nconstr
184 |     def cornerelli(self, x):
185 |         """ellipsoid constrained to x_i >= 1 (NaN otherwise), shifted to be zero at x == ones"""
186 |         if any(x < 1):
187 |             return np.nan
188 |         return self.elli(x) - self.elli(np.ones(len(x)))
189 |     def cornerellirot(self, x):
190 |         """rotated ellipsoid constrained to x_i >= 1 (NaN otherwise)"""
191 |         if any(x < 1):
192 |             return np.nan
193 |         return self.ellirot(x)
194 |     def normalSkew(self, f):
195 |         N = np.random.randn(1)[0]**2
196 |         if N < 1:
197 |             N = f * N  # diminish blow up lower part
198 |         return N
199 |     def noiseC(self, x, func=sphere, fac=10, expon=0.8):
200 |         f = func(self, x)
201 |         N = np.random.randn(1)[0] / np.random.randn(1)[0]
202 |         return max(1e-19, f + (float(fac) / len(x)) * f**expon * N)
203 |     def noise(self, x, func=sphere, fac=10, expon=1):
204 |         f = func(self, x)
205 |         # R = np.random.randn(1)[0]
206 |         R = np.log10(f) + expon * abs(10 - np.log10(f)) * np.random.rand(1)[0]
207 |         # sig = float(fac)/float(len(x))
208 |         # R = log(f) + 0.5*log(f) * random.randn(1)[0]
209 |         # return max(1e-19, f + sig * (f**np.log10(f)) * np.exp(R))
210 |         # return max(1e-19, f * np.exp(sig * N / f**expon))
211 |         # return max(1e-19, f * normalSkew(f**expon)**sig)
212 |         return f + 10**R  # == f + f**(1+0.5*RN)
213 |     def cigar(self, x, rot=0, cond=1e6, noise=0):
214 |         """Cigar test objective function"""
215 |         if rot:
216 |             x = rotate(x)
217 |         x = [x] if isscalar(x[0]) else x  # scalar into list
218 |         f = [(x[0]**2 + cond * sum(x[1:]**2)) * np.exp(noise * np.random.randn(1)[0] / len(x)) for x in x]
219 |         return f if len(f) > 1 else f[0]  # 1-element-list into scalar
220 |     def grad_cigar(self, x, *args):
221 |         grad = 2 * 1e6 * np.array(x)
222 |         grad[0] /= 1e6
223 |         return grad
224 |     def diagonal_cigar(self, x, cond=1e6):
225 |         axis = np.ones(len(x)) / len(x)**0.5
226 |         proj = dot(axis, x) * axis
227 |         s = sum(proj**2)
228 |         s += cond * sum((x - proj)**2)
229 |         return s
230 |     def tablet(self, x, cond=1e6, rot=0):
231 |         """Tablet test objective function"""
232 |         x = np.asarray(x)
233 |         if rot and rot is not ff.tablet:
234 |             x = rotate(x)
235 |         x = [x] if isscalar(x[0]) else x  # scalar into list
236 |         f = [cond * x[0]**2 + sum(x[1:]**2) for x in x]
237 |         return f if len(f) > 1 else f[0]  # 1-element-list into scalar
238 |     def grad_tablet(self, x, *args):
239 |         grad = 2 * np.array(x)
240 |         grad[0] *= 1e6
241 |         return grad
242 |     def cigtab(self, y):
243 |         """Cigtab test objective function"""
244 |         X = [y] if isscalar(y[0]) else y
245 |         f = [1e-4 * x[0]**2 + 1e4 * x[1]**2 + sum(x[2:]**2) for x in X]
246 |         return f if len(f) > 1 else f[0]
247 |     def cigtab2(self, x, condition=1e8, n_axes=None):
248 |         """cigtab with 1 + 5% long and short axes.
249 |
250 |         `n_axes: int`, if > 0, sets the number of long as well as short
251 |         axes to `n_axes`, respectively.
252 |         """
253 |         m = n_axes or 1 + len(x) // 20
254 |         x = np.asarray(x)
255 |         f = sum(x[m:-m]**2)
256 |         f += condition**0.5 * sum(x[:m]**2)
257 |         f += condition**-0.5 * sum(x[-m:]**2)
258 |         return f
259 |     def twoaxes(self, y):
260 |         """Two-axes test objective function: the squares of the first half of the coordinates are weighted by 1e6"""
261 |         X = [y] if isscalar(y[0]) else y
262 |         N2 = len(X[0]) // 2
263 |         f = [1e6 * sum(x[0:N2]**2) + sum(x[N2:]**2) for x in X]
264 |         return f if len(f) > 1 else f[0]
265 |     def ellirot(self, x):
266 |         return ff.elli(array(x), 1)
267 |     def hyperelli(self, x):
268 |         N = len(x)
269 |         return sum((np.arange(1, N + 1) * x)**2)
270 |     def halfelli(self, x):
271 |         l = len(x) // 2
272 |         felli = self.elli(x[:l])
273 |         return felli + 1e-8 * sum(x[l:]**2)
274 |     def elli(self, x, rot=0, xoffset=0, cond=1e6, actuator_noise=0.0, both=False):
275 |         """Ellipsoid test objective function"""
276 |         x = np.asarray(x)
277 |         if not isscalar(x[0]):  # parallel evaluation
278 |             return [self.elli(xi, rot) for xi in x]  # could save 20% overall
279 |         if rot:
280 |             x = rotate(x)
281 |         N = len(x)
282 |         if actuator_noise:
283 |             x = x + actuator_noise * np.random.randn(N)
284 |
285 |         ftrue = sum(cond**(np.arange(N) / (N - 1.)) * (x + xoffset)**2) \
286 |                 if N > 1 else (x + xoffset)**2
287 |
288 |         alpha = 0.49 + 1. / N
289 |         beta = 1
290 |         felli = np.random.rand(1)[0]**beta * ftrue * \
291 |                 max(1, (10.**9 / (ftrue + 1e-99))**(alpha * np.random.rand(1)[0]))
292 |         # felli = ftrue + 1*np.random.randn(1)[0] / (1e-30 +
293 |         #                                           np.abs(np.random.randn(1)[0]))**0
294 |         if both:
295 |             return (felli, ftrue)
296 |         else:
297 |             # return felli  # possibly noisy value
298 |             return ftrue  # + np.random.randn()
299 |     def grad_elli(self, x, *args):
300 |         cond = 1e6
301 |         N = len(x)
302 |         return 2 * cond**(np.arange(N) / (N - 1.)) * np.asarray(x)
303 |     def fun_as_arg(self, x, *args):
304 |         """``fun_as_arg(x, fun, *more_args)`` calls ``fun(x, *more_args)``.
305 |
306 |         Use case::
307 |
308 |             fmin(cma.ff.fun_as_arg, args=(fun,), gradf=grad_numerical)
309 |
310 |         calls fun_as_arg(x, *args) and grad_numerical(x, fun, args=args)
311 |
312 |         """
313 |         fun = args[0]
314 |         more_args = args[1:] if len(args) > 1 else ()
315 |         return fun(x, *more_args)
316 |     def grad_numerical(self, x, func, epsilon=None):
317 |         """symmetric (central-difference) numerical gradient of `func` at `x`"""
318 |         eps = 1e-8 * (1 + np.abs(x)) if epsilon is None else epsilon * np.ones(len(x))
319 |         grad = np.zeros(len(x))
320 |         ei = np.zeros(len(x))  # float is 1.6 times faster than int
321 |         for i in rglen(x):
322 |             ei[i] = eps[i]
323 |             grad[i] = (func(x + ei) - func(x - ei)) / (2*eps[i])
324 |             ei[i] = 0
325 |         return grad
326 |     def elliconstraint(self, x, cfac=1e8, tough=True, cond=1e6):
327 |         """ellipsoid test objective function with "constraints" """
328 |         N = len(x)
329 |         f = sum(cond**(np.arange(N)[-1::-1] / (N - 1)) * x**2)
330 |         cvals = (x[0] + 1,
331 |                  x[0] + 1 + 100 * x[1],
332 |                  x[0] + 1 - 100 * x[1])
333 |         if tough:
334 |             f += cfac * sum(max(0, c) for c in cvals)
335 |         else:
336 |             f += cfac * sum(max(0, c + 1e-3)**2 for c in cvals)
337 |         return f
338 |     def rosen(self, x, alpha=1e2):
339 |         """Rosenbrock test objective function"""
340 |         x = [x] if isscalar(x[0]) else x  # scalar into list
341 |         x = np.asarray(x)
342 |         f = [sum(alpha * (x[:-1]**2 - x[1:])**2 + (1. - x[:-1])**2) for x in x]
343 |         return f if len(f) > 1 else f[0]  # 1-element-list into scalar
344 |     def grad_rosen(self, x, *args):
345 |         N = len(x)
346 |         grad = np.zeros(N)
347 |         grad[0] = 2 * (x[0] - 1) + 200 * (x[1] - x[0]**2) * -2 * x[0]
348 |         i = np.arange(1, N - 1)
349 |         grad[i] = 2 * (x[i] - 1) - 400 * (x[i+1] - x[i]**2) * x[i] + 200 * (x[i] - x[i-1]**2)
350 |         grad[N-1] = 200 * (x[N-1] - x[N-2]**2)
351 |         return grad
352 |     def rosen_chained(self, x, alpha=1e2):
353 |         x = [x] if isscalar(x[0]) else x  # scalar into list
354 |         f = [(1. - x[0])**2 + sum(alpha * (x[:-1]**2 - x[1:])**2) for x in x]
355 |         return f if len(f) > 1 else f[0]  # 1-element-list into scalar
356 |
357 | def diffpow(self, x, rot=0):
358 | """Diffpow test objective function"""
359 | N = len(x)
360 | if rot:
361 | x = rotate(x)
362 | return sum(np.abs(x)**(2. + 4.*np.arange(N) / (N - 1.)))**0.5
363 | def rosenelli(self, x):
364 | N = len(x)
365 | Nhalf = int((N + 1) / 2)
366 | return self.rosen(x[:Nhalf]) + self.elli(x[Nhalf:], cond=1)
367 | def ridge(self, x, expo=2):
368 | x = [x] if isscalar(x[0]) else x # scalar into list
369 | f = [x[0] + 100 * np.sum(x[1:]**2)**(expo / 2.) for x in x]
370 | return f if len(f) > 1 else f[0] # 1-element-list into scalar
371 | def ridgecircle(self, x, expo=0.5):
372 | """a difficult sharp ridge type function.
373 |
374 | A modified implementation of HG Beyer's `happycat`.
375 | """
376 | a = len(x)
377 | s = sum(x**2)
378 | return ((s - a)**2)**(expo / 2) + s / a + sum(x) / a
379 | def happycat(self, x, alpha=1. / 8):
380 | """a difficult sharp ridge type function.
381 |
382 | Proposed by HG Beyer.
383 | """
384 | s = sum(x**2)
385 | return ((s - len(x))**2)**alpha + (s / 2 + sum(x)) / len(x) + 0.5
386 | def flat(self, x):
387 | return 1
388 | # unreachable alternatives: return 1 if np.random.rand(1) < 0.9 else 1.1
389 | # return np.random.randint(1, 30)
390 | def branin(self, x):
391 | # in [0,15]**2
392 | y = x[1]
393 | x = x[0] + 5
394 | return (y - 5.1 * x**2 / 4 / np.pi**2 + 5 * x / np.pi - 6)**2 + 10 * (1 - 1 / 8 / np.pi) * np.cos(x) + 10 - 0.397887357729738160000
395 | def goldsteinprice(self, x):
396 | x1 = x[0]
397 | x2 = x[1]
398 | return (1 + (x1 + x2 + 1)**2 * (19 - 14 * x1 + 3 * x1**2 - 14 * x2 + 6 * x1 * x2 + 3 * x2**2)) * (
399 | 30 + (2 * x1 - 3 * x2)**2 * (18 - 32 * x1 + 12 * x1**2 + 48 * x2 - 36 * x1 * x2 + 27 * x2**2)) - 3
400 | def griewank(self, x):
401 | # was in [-600 600]
402 | x = (600. / 5) * x
403 | return 1 - np.prod(np.cos(x / np.sqrt(1. + np.arange(len(x))))) + sum(x**2) / 4e3
404 | def levy(self, x):
405 | """a rather benign multimodal function.
406 |
407 | xopt == ones, fopt == 0.0
408 | """
409 | w = 1 + (np.asarray(x) - 1) / 4
410 | del x
411 | f = np.sin(np.pi * w[0])**2
412 | f += (w[-1] - 1)**2 * (1 + np.sin(2 * np.pi * w[-1])**2)
413 | w = w[1:-1]
414 | return f + sum((w - 1)**2 * (1 + 10 * np.sin(np.pi * w + 1)**2))
415 | def rastrigin(self, x):
416 | """Rastrigin test objective function"""
417 | if not isscalar(x[0]):
418 | N = len(x[0])
419 | return [10 * N + sum(xi**2 - 10 * np.cos(2 * np.pi * xi)) for xi in x]
420 | # return 10*N + sum(x**2 - 10*np.cos(2*np.pi*x), axis=1)
421 | N = len(x)
422 | return 10 * N + sum(x**2 - 10 * np.cos(2 * np.pi * x))
423 | def schaffer(self, x):
424 | """ Schaffer function x0 in [-100..100]"""
425 | N = len(x)
426 | s = x[0:N - 1]**2 + x[1:N]**2
427 | return sum(s**0.25 * (np.sin(50 * s**0.1)**2 + 1))
428 |
429 | def schwefelelli(self, x):
430 | s = 0
431 | f = 0
432 | for i in rglen(x):
433 | s += x[i]
434 | f += s**2
435 | return f
436 | def schwefelmult(self, x, pen_fac=1e4):
437 | """multimodal Schwefel function with domain -500..500"""
438 | y = [x] if isscalar(x[0]) else x
439 | N = len(y[0])
440 | f = array([418.9829 * N - 1.27275661e-5 * N - sum(x * np.sin(np.abs(x)**0.5))
441 | + pen_fac * sum((abs(x) > 500) * (abs(x) - 500)**2) for x in y])
442 | return f if len(f) > 1 else f[0]
443 | def schwefel2_22(self, x):
444 | """Schwefel 2.22 function"""
445 | return sum(np.abs(x)) + np.prod(np.abs(x))
446 | def optprob(self, x):
447 | n = np.arange(len(x)) + 1
448 | f = n * x * (1 - x)**(n - 1)
449 | return sum(1 - f)
450 | def lincon(self, x, theta=0.01):
451 | """ridge like linear function with one linear constraint"""
452 | if x[0] < 0:
453 | return np.NaN
454 | return theta * x[1] + x[0]
455 | def rosen_nesterov(self, x, rho=100):
456 | """needs exponential number of steps in a non-increasing
457 | f-sequence.
458 |
459 | x_0 = (-1,1,...,1)
460 | See Jarre (2011) "On Nesterov's Smooth Chebyshev-Rosenbrock
461 | Function"
462 |
463 | """
464 | f = 0.25 * (x[0] - 1)**2
465 | f += rho * sum((x[1:] - 2 * x[:-1]**2 + 1)**2)
466 | return f
467 | def powel_singular(self, x):
468 | # ((8 * np.sin(7 * (x[i] - 0.9)**2)**2 ) + (6 * np.sin()))
469 | res = np.sum((x[i - 1] + 10 * x[i])**2 + 5 * (x[i + 1] - x[i + 2])**2 +
470 | (x[i] - 2 * x[i + 1])**4 + 10 * (x[i - 1] - x[i + 2])**4
471 | for i in range(1, len(x) - 2))
472 | return 1 + res
473 | def styblinski_tang(self, x):
474 | """in [-5, 5]
475 | found also in Lazar and Jarre 2016, optimum in f(-2.903534...)=0
476 | """
477 | # x_opt = N * [-2.90353402], seems to have essentially
478 | # (only) 2**N local optima
479 | return (39.1661657037714171054273576010019 * len(x))**1 + \
480 | sum(x**4 - 16 * x**2 + 5 * x) / 2
481 |
482 | def trid(self, x):
483 | return sum((x-1)**2) - sum(x[:-1] * x[1:])
484 |
485 | def bukin(self, x):
486 | """Bukin function from Wikipedia, generalized simplistically from 2-D.
487 |
488 | http://en.wikipedia.org/wiki/Test_functions_for_optimization"""
489 | s = 0
490 | for k in range((1+len(x)) // 2):
491 | z = x[2 * k]
492 | y = x[min((2*k + 1, len(x)-1))]
493 | s += 100 * np.abs(y - 0.01 * z**2)**0.5 + 0.01 * np.abs(z + 10)
494 | return s
495 |
496 | def xinsheyang2(self, x, termination_friendly=True):
497 | """a multimodal function which is rather unsolvable in larger dimension.
498 |
499 | >>> import functools
500 | >>> import numpy as np
501 | >>> import cma
502 | >>> f = functools.partial(cma.ff.xinsheyang2, termination_friendly=False)
503 | >>> X = [(i * [0] + (4 - i) * [1.24]) for i in range(5)]
504 | >>> for x in X: print(x)
505 | [1.24, 1.24, 1.24, 1.24]
506 | [0, 1.24, 1.24, 1.24]
507 | [0, 0, 1.24, 1.24]
508 | [0, 0, 0, 1.24]
509 | [0, 0, 0, 0]
510 | >>> ' '.join(['{:.3}'.format(f(x)) for x in X]) # [np.round(f(x), 3) for x in X]
511 | '0.091 0.186 0.336 0.456 0.0'
512 |
513 | One needs to solve a trinary deceptive function where the f-value (to
514 | be minimized) decreases monotonically with increasing distance (>= 1)
515 | from the global optimum. That is, the global optimum is
516 | surrounded by 3^n - 1 local optima whose values are better the
517 | further away they are from the global optimum.
518 |
519 | Conclusion: it is a rather suspicious sign if an algorithm finds the global
520 | optimum of this function in larger dimension.
521 |
522 | See also http://benchmarkfcns.xyz/benchmarkfcns/xinsheyangn2fcn.html
523 | """
524 | x = np.asarray(x)
525 | val = np.sum(np.abs(x)) * np.exp(-np.sum(np.sin(np.square(x)))) # np.mean under the exponential makes the function much easier
526 | if termination_friendly and val < 1:
527 | val **= 1. / len(x)
528 | return val
529 |
530 | def _fetch_bbob_fcts(self):
531 | """Fetch GECCO BBOB 2009 functions from WWW and set as `self.BBOB`.
532 |
533 | Side effects in the current folder: two files are added and folder
534 | "._tmp_" is removed.
535 | """
536 | # fetch http://coco.lri.fr/downloads/download15.02/bbobpproc.tar.gz
537 | bbob_version, fname = '15.03', 'bbobpproc.tar.gz'
538 | url = 'http://coco.lri.fr/downloads/download'+bbob_version+'/'+fname # 3MB
539 | print('downloading %s ...' % url, end=''); sys.stdout.flush()
540 | utils.download_file(url)
541 | print(' done downloading')
542 | print('extracting bbobbenchmarks.py'); sys.stdout.flush()
543 | utils.extract_targz(url.split('/')[-1],  # URLs always use '/', unlike os.path.sep on Windows
544 | os.path.join('bbob.v' + bbob_version,
545 | 'python', 'bbobbenchmarks.py'))
546 | print('importing bbobbenchmarks.py and setting as BBOB attribute')
547 | import bbobbenchmarks
548 | self.BBOB = bbobbenchmarks
549 | print("BBOB set and ready to go. Example: `f11 = cma.FF.BBOB.F11()`")
550 |
551 | def fetch_bbob_fcts(self):
552 | """Fetch GECCO BBOB 2009 functions from WWW and set as `self.BBOB`.
553 | Currently a stub: only ``url`` is defined, nothing is fetched or set."""
554 | url = "http://coco.gforge.inria.fr/python/bbobbenchmarks.py"
555 |
556 | ff = FitnessFunctions()
557 |
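Because `grad_rosen` hard-codes the Rosenbrock gradient for `alpha=1e2`, while `grad_numerical` computes a symmetric difference quotient with step `1e-8 * (1 + |x_i|)`, the two can be checked against each other. The following is a minimal standalone sketch assuming only NumPy: it re-implements `rosen`, `grad_rosen`, and `grad_numerical` as plain functions (not the `FitnessFunctions` methods, and without importing `cma`), and it assumes a scalar `epsilon` should be broadcast to every coordinate.

```python
import numpy as np


def rosen(x, alpha=100.0):
    """Rosenbrock function for a single vector (same formula as `rosen` above)."""
    x = np.asarray(x, dtype=float)
    return float(np.sum(alpha * (x[:-1]**2 - x[1:])**2 + (1.0 - x[:-1])**2))


def grad_rosen(x):
    """Analytic gradient for alpha=100, mirroring `grad_rosen` above."""
    x = np.asarray(x, dtype=float)
    N = len(x)
    grad = np.zeros(N)
    grad[0] = 2 * (x[0] - 1) + 200 * (x[1] - x[0]**2) * -2 * x[0]
    i = np.arange(1, N - 1)  # interior coordinates get two coupling terms
    grad[i] = 2 * (x[i] - 1) - 400 * (x[i+1] - x[i]**2) * x[i] + 200 * (x[i] - x[i-1]**2)
    grad[N-1] = 200 * (x[N-1] - x[N-2]**2)
    return grad


def grad_numerical(x, func, epsilon=None):
    """Central differences; a scalar epsilon is assumed to apply to all coordinates."""
    x = np.asarray(x, dtype=float)
    eps = 1e-8 * (1 + np.abs(x)) if epsilon is None else epsilon * np.ones(len(x))
    grad = np.zeros(len(x))
    ei = np.zeros(len(x))
    for i in range(len(x)):
        ei[i] = eps[i]  # perturb only coordinate i
        grad[i] = (func(x + ei) - func(x - ei)) / (2 * eps[i])
        ei[i] = 0
    return grad


x = np.array([-1.2, 1.0, 0.5, 2.0])
g_exact = grad_rosen(x)
g_fd = grad_numerical(x, rosen)
print(np.allclose(g_exact, g_fd, rtol=1e-5, atol=1e-3))  # True
```

With a step of about `1e-8`, rounding error in the function values dominates the difference quotient, so agreement is expected only to a few significant digits, which is why the comparison uses loose tolerances.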
--------------------------------------------------------------------------------
/cma/fitness_transformations.py:
--------------------------------------------------------------------------------
1 | """Wrapper for objective functions like noise, rotation, gluing args
2 | """
3 | from __future__ import absolute_import, division, print_function #, unicode_literals, with_statement
4 | import warnings
5 | from functools import partial
6 | import numpy as np
7 | import time
8 | from .utilities import utils
9 | from .transformations import ConstRandnShift, Rotation
10 | from .constraints_handler import BoundTransform
11 | from .optimization_tools import EvalParallel2 # for backwards compatibility
12 | from .utilities.python3for2 import range
13 | del absolute_import, division, print_function #, unicode_literals, with_statement
14 |
15 | rotate = Rotation()
16 |
17 | class Function(object):
18 | """a declarative base class, indicating that a derived class instance
19 | "is" a (fitness/objective) function.
20 |
21 | A `callable` passed to `__init__` is called as the fitness
22 | `Function`, otherwise the `_eval` method is called, if defined in the
23 | derived class, when the `Function` instance is called. If the input
24 | argument is a matrix or a list of vectors, the method is called for
25 | each vector individually like
26 | ``_eval(numpy.asarray(vector)) for vector in matrix``.
27 |
28 | >>> import cma
29 | >>> from cma.fitness_transformations import Function
30 | >>> f = Function(cma.ff.rosen)
31 | >>> assert f.evaluations == 0
32 | >>> assert f([2, 3]) == cma.ff.rosen([2, 3])
33 | >>> assert f.evaluations == 1
34 | >>> assert f([[1], [2]]) == [cma.ff.rosen([1]), cma.ff.rosen([2])]
35 | >>> assert f.evaluations == 3
36 | >>> class Fsphere(Function):
37 | ... def _eval(self, x):
38 | ... return sum(x**2)
39 | >>> fsphere = Fsphere()
40 | >>> assert fsphere.evaluations == 0
41 | >>> assert fsphere([2, 3]) == 4 + 9 and fsphere([[2], [3]]) == [4, 9]
42 | >>> assert fsphere.evaluations == 3
43 | >>> Fsphere.__init__ = lambda self: None # overwrites Function.__init__
44 | >>> assert Fsphere()([2]) == 4 # which is perfectly fine to do
45 |
46 | Details:
47 |
48 | - When called, a class instance calls either the function passed to
49 | `__init__` or, if none was given, tries to call any of the
50 | `function_names_to_evaluate_first_found`, first come, first served.
51 | By default, ``function_names_to_evaluate_first_found == ["_eval"]``.
52 |
53 | - This class cannot move to module `fitness_functions`, because the
54 | latter uses `fitness_transformations.rotate`.
55 |
56 | """
57 | _function_names_to_evaluate_first_found = ["_eval"]
58 | @property
59 | def function_names_to_evaluate_first_found(self):
60 | """attributes which are searched for to be called if no function
61 | was given to `__init__`.
62 |
63 | The first present match is used.
64 | """ # % str(Function._function_names_to_evaluate_first_found)
65 | return Function._function_names_to_evaluate_first_found
66 |
67 | def __init__(self, fitness_function=None):
68 | """optionally define the fitness function to be called; calling
69 | `__init__` is not required, because `__call__` initializes late if needed
70 | """
71 | Function.initialize(self, fitness_function)
72 | def initialize(self, fitness_function):
73 | """initialization of `Function`
74 | """
75 | self.__callable = fitness_function # this naming prevents interference with a derived class variable
76 | self.evaluations = 0
77 | self.ftarget = -np.inf
78 | self.target_hit_at = 0 # evaluation counter when target was first hit
79 | self.__initialized = True
80 |
81 | def __call__(self, *args, **kwargs):
82 | # late initialization if necessary
83 | try:
84 | if not self.__initialized:
85 | raise AttributeError
86 | except AttributeError:
87 | Function.initialize(self, None)
88 | # find the "right" callable
89 | callable_ = self.__callable
90 | if callable_ is None:
91 | for name in self.function_names_to_evaluate_first_found:
92 | try:
93 | callable_ = getattr(self, name)
94 | break
95 | except AttributeError:
96 | pass
97 | # call with each vector
98 | if callable_ is not None:
99 | X, list_revert = utils.as_vector_list(args[0])
100 | self.evaluations += len(X)
101 | F = [callable_(np.asarray(x), *args[1:], **kwargs) for x in X]
102 | if not self.target_hit_at and any(np.asarray(F) <= self.ftarget):
103 | self.target_hit_at = self.evaluations - len(X) + 1 + list(np.asarray(F) <= self.ftarget).index(True)
104 | return list_revert(F)
105 | else:
106 | self.evaluations += 1 # somewhat bound to fail
107 |
108 | class ComposedFunction(Function, list):
109 | """compose an arbitrary number of functions.
110 |
111 | A class instance is a list of functions. Calling the instance executes
112 | the composition of these functions (evaluating from right to left as
113 | in math notation). Functions can be added to or removed from the list
114 | at any time with the obvious effect. To remain consistent (if needed),
115 | the ``list_of_inverses`` attribute must be updated accordingly.
116 |
117 | >>> import numpy as np
118 | >>> from cma.fitness_transformations import ComposedFunction
119 | >>> f1, f2, f3, f4 = lambda x: 2*x, lambda x: x**2, lambda x: 3*x, lambda x: x**3
120 | >>> f = ComposedFunction([f1, f2, f3, f4])
121 | >>> assert isinstance(f, list) and isinstance(f, ComposedFunction)
122 | >>> assert f[0] == f1 # how I love Python indexing
123 | >>> assert all(f(x) == f1(f2(f3(f4(x)))) for x in np.random.rand(10))
124 | >>> assert f4 == f.pop()
125 | >>> assert len(f) == 3
126 | >>> f.insert(1, f4)
127 | >>> f.append(f4)
128 | >>> assert all(f(x) == f1(f4(f2(f3(f4(x))))) for x in range(5))
129 |
130 | A more specific example:
131 |
132 | >>> from cma.fitness_transformations import ComposedFunction
133 | >>> from cma.constraints_handler import BoundTransform
134 | >>> from cma import ff
135 | >>> f = ComposedFunction([ff.elli,
136 | ... BoundTransform([[0], [1]]).transform])
137 | >>> assert max(f([2, 3]), f([1, 1])) <= ff.elli([1, 1])
138 |
139 | Details:
140 |
141 | - This class can serve as basis for a more transparent
142 | alternative to a ``scaling_of_variables`` CMA option or for any
143 | necessary transformation of the fitness/objective function
144 | (genotype-phenotype transformation).
145 |
146 | - The parallelizing call with a list of solutions of the `Function`
147 | class is not inherited. The inheritance from `Function` is
148 | declarative rather than functional and could be omitted.
149 |
150 | """
151 | def __init__(self, list_of_functions, list_of_inverses=None):
152 | """Caveat: to remain consistent, the ``list_of_inverses`` must be
153 | updated explicitly if the list of functions is updated after
154 | initialization.
155 | """
156 | list.__init__(self, list_of_functions)
157 | Function.__init__(self)
158 | self.list_of_inverses = list_of_inverses
159 |
160 | def __call__(self, x, *args, **kwargs):
161 | Function.__call__(self, x, *args, **kwargs) # for the possible side effects only
162 | for i in range(-1, -len(self) - 1, -1):
163 | x = self[i](x, *args, **kwargs)
164 | return x
165 |
166 | def inverse(self, x, *args, **kwargs):
167 | """evaluate the composition of inverses on ``x``.
168 |
169 | Return `None`, if no list was provided.
170 | """
171 | if self.list_of_inverses is None:
172 | utils.print_warning("inverses were not given")
173 | return
174 | for i in range(len(self.list_of_inverses)):
175 | x = self.list_of_inverses[i](x, *args, **kwargs)
176 | return x
177 |
178 | class StackFunction(Function):
179 | """a function that returns ``f1(x[:n1]) + f2(x[n1:])``.
180 |
181 | >>> import functools
182 | >>> import numpy as np
183 | >>> import cma
184 | >>> def elli48(x):
185 | ... return 1e-4 * functools.partial(cma.ff.elli, cond=1e8)(x)
186 | >>> fcigtab = cma.fitness_transformations.StackFunction(
187 | ... elli48, cma.ff.sphere, 2)
188 | >>> x = [1, 2, 3, 4]
189 | >>> assert np.isclose(fcigtab(x), cma.ff.cigtab(np.asarray(x)))
190 |
191 | """
192 | def __init__(self, f1, f2, n1):
193 | self.f1 = f1
194 | self.f2 = f2
195 | self.n1 = n1
196 | def _eval(self, x, *args, **kwargs):
197 | return self.f1(x[:self.n1], *args, **kwargs) + self.f2(x[self.n1:], *args, **kwargs)
198 |
199 | class GlueArguments(Function):
200 | """deprecated, use `functools.partial` or
201 | `cma.fitness_transformations.partial` instead, which has the same
202 | functionality and interface.
203 |
204 | from a `callable` return a `callable` with arguments attached.
205 |
206 |
207 | An ellipsoid function with condition number ``1e4`` is created by
208 | ``felli1e4 = cma.s.ft.GlueArguments(cma.ff.elli, cond=1e4)``.
209 |
210 | >>> import cma
211 | >>> f = cma.fitness_transformations.GlueArguments(cma.ff.elli,
212 | ... cond=1e1)
213 | >>> assert f([1, 2]) == 1**2 + 1e1 * 2**2
214 |
215 | """
216 | def __init__(self, fitness_function, *args, **kwargs):
217 | """define function, ``args``, and ``kwargs``.
218 |
219 | ``args`` are appended to arguments passed in the call, ``kwargs``
220 | are updated with keyword arguments passed in the call.
221 | """
222 | Function.__init__(self, fitness_function)
223 | self.fitness_function = fitness_function # never used
224 | self.args = args
225 | self.kwargs = kwargs
226 | def __call__(self, x, *args, **kwargs):
227 | """call function with at least one additional argument and
228 | attached args and kwargs.
229 | """
230 | joined_kwargs = dict(self.kwargs)
231 | joined_kwargs.update(kwargs)
232 | x = np.asarray(x)
233 | return Function.__call__(self, x, *(args + self.args),
234 | **joined_kwargs)
235 |
236 | class FBoundTransform(ComposedFunction):
237 | """shortcut for ``ComposedFunction([f, BoundTransform(bounds).transform])``,
238 | see also below.
239 |
240 | Maps the argument into bounded or half-bounded (feasible) domain
241 | before evaluating ``f``.
242 |
243 | Example with lower bound at 0, which becomes the image of -0.05 in
244 | `BoundTransform.transform`:
245 |
246 | >>> import cma, numpy as np
247 | >>> f = cma.fitness_transformations.FBoundTransform(cma.ff.elli,
248 | ... [[0], None])
249 | >>> assert all(f[1](np.random.randn(200)) >= 0)
250 | >>> assert all(f[1]([-0.05, -0.05]) == 0)
251 | >>> assert f([-0.05, -0.05]) == 0
252 |
253 | A slightly more verbose version to implement the lower bound at zero
254 | in the very same way:
255 |
256 | >>> import cma
257 | >>> felli_in_bound = cma.s.ft.ComposedFunction(
258 | ... [cma.ff.elli, cma.BoundTransform([[0], None]).transform])
259 |
260 | """
261 | def __init__(self, fitness_function, bounds):
262 | """`bounds[0]` are lower bounds, `bounds[1]` are upper bounds
263 | """
264 | self.bound_tf = BoundTransform(bounds) # not strictly necessary
265 | ComposedFunction.__init__(self,
266 | [fitness_function, self.bound_tf.transform])
267 |
268 | class Rotated(ComposedFunction):
269 | """return a rotated version of a function for testing purpose.
270 |
271 | This class is a convenience shortcut for the slightly more verbose
272 | composition of a function with a rotation:
273 |
274 | >>> import cma
275 | >>> from cma import fitness_transformations as ft
276 | >>> f1 = ft.Rotated(cma.ff.elli)
277 | >>> f2 = ft.ComposedFunction([cma.ff.elli, ft.Rotation()])
278 | >>> assert f1([2]) == f2([2]) # same rotation only in 1-D
279 | >>> assert f1([1, 2]) != f2([1, 2])
280 |
281 | """
282 | def __init__(self, f, rotate=None, seed=None):
283 | """optional argument ``rotate(x)`` must return a (stable) rotation
284 | of ``x``.
285 | """
286 | if rotate is None:
287 | rotate = Rotation(seed=seed)
288 | ComposedFunction.__init__(self, [f, rotate])
289 |
290 | class Shifted(ComposedFunction):
291 | """compose a function with a shift in x-space.
292 |
293 | >>> import cma
294 | >>> f = cma.s.ft.Shifted(cma.ff.elli)
295 |
296 | Details: this class merely provides a default second argument to
297 | `ComposedFunction`, namely a random shift in search space.
298 | ``shift=lambda x: x`` would provide "no shift", ``None``
299 | expands to ``cma.transformations.ConstRandnShift()``.
300 | """
301 | def __init__(self, f, shift=None):
302 | """``shift(x)`` must return a (stable) shift of x"""
303 | if shift is None:
304 | shift = ConstRandnShift()
305 | ComposedFunction.__init__(self, [f, shift])
306 |
307 | class ScaleCoordinates(ComposedFunction):
308 | """compose a (fitness) function with a scaling for each variable
309 | (more concisely, a coordinate-wise affine transformation).
310 |
311 | After ``fun2 = cma.ScaleCoordinates(fun, multipliers, zero)``, we have
312 | ``fun2(x) == fun(multipliers * (x - zero))``, where the size of
313 | `multipliers` and `zero` is adapted to the size of `x`, if necessary by
314 | recycling their last entry.
315 |
316 | >>> import numpy as np
317 | >>> import cma
318 | >>> f = cma.ScaleCoordinates(cma.ff.sphere, [100, 1])
319 | >>> assert f[0] == cma.ff.sphere # first element of f-composition
320 | >>> assert f(range(1, 6)) == 100**2 + sum([x**2 for x in range(2, 6)])
321 | >>> assert f([2.1]) == 210**2 == f(2.1)
322 | >>> assert f(20 * [1]) == 100**2 + 19
323 | >>> assert np.all(f.inverse(f.scale_and_offset([1, 2, 3, 4])) ==
324 | ... np.asarray([1, 2, 3, 4]))
325 | >>> f = cma.ScaleCoordinates(f, [-2, 7], [2, 3, 4]) # last is recycled
326 | >>> f([5, 6]) == sum(x**2 for x in [100 * -2 * (5 - 2), 7 * (6 - 3)])
327 | True
328 |
329 | """
330 | def __init__(self, fitness_function, multipliers=None, zero=None):
331 | """
332 | :param fitness_function: a `callable` object
333 | :param multipliers: coordinate-wise multipliers.
334 | :param zero: defines a new zero in preimage space, that is,
335 | calling the `ScaleCoordinates` instance returns
336 | ``fitness_function(multipliers * (x - zero))``.
337 |
338 | For both arguments, ``multipliers`` and ``zero``, to fit
339 | the length of the given input, superfluous trailing
340 | elements are ignored and the last element is recycled
341 | if needed.
342 | """
343 | ComposedFunction.__init__(self,
344 | [fitness_function, self.scale_and_offset])
345 | self.multiplier = multipliers
346 | if self.multiplier is not None:
347 | self.multiplier = np.asarray(self.multiplier, dtype=float)
348 | self.zero = zero
349 | if zero is not None:
350 | self.zero = np.asarray(zero, dtype=float)
351 |
352 | def scale_and_offset(self, x):
353 | x = np.asarray(x)
354 | r = lambda vec: utils.recycled(vec, as_=x)
355 | if self.zero is not None and self.multiplier is not None:
356 | x = r(self.multiplier) * (x - r(self.zero))
357 | elif self.zero is not None:
358 | x = x - r(self.zero)
359 | elif self.multiplier is not None:
360 | x = r(self.multiplier) * x
361 | return x
362 |
363 | def inverse(self, x):
364 | """inverse of coordinate-wise affine transformation
365 | ``y / multipliers + zero``
366 | """
367 | x = np.asarray(x)
368 | r = lambda vec: utils.recycled(vec, as_=x)
369 | if self.zero is not None and self.multiplier is not None:
370 | x = x / r(self.multiplier) + r(self.zero)
371 | elif self.zero is not None:
372 | x = x + r(self.zero)
373 | elif self.multiplier is not None:
374 | x = x / r(self.multiplier)
375 | return x
376 |
377 | class FixVariables(ComposedFunction):
378 | """Insert variables with given values, thereby reducing the
379 | dimensionality of the resulting composed function.
380 |
381 | The constructor takes ``index_value_pairs``, a `dict` or `list` of
382 | pairs, as input and returns a function with smaller preimage space
383 | than input function ``f``.
384 |
385 | Fixing variables 3 and 5 (indices 2 and 4) works like
386 |
387 | >>> import cma; from cma.fitness_transformations import FixVariables
388 | >>> index_value_pairs = [[2, 0.2], [4, 0.4]]
389 | >>> fun = FixVariables(cma.ff.elli, index_value_pairs)
390 | >>> fun[1](4 * [1]) == [ 1., 1., 0.2, 1., 0.4, 1.]
391 | True
392 |
393 | Or starting from a given current solution in the larger space from
394 | which we pick the fixed values:
395 |
396 | >>> from cma.fitness_transformations import FixVariables
397 | >>> current_solution = [0.1 * i for i in range(5)]
398 | >>> fixed_indices = [2, 4]
399 | >>> index_value_pairs = [[i, current_solution[i]] # fix these
400 | ... for i in fixed_indices]
401 | >>> fun = FixVariables(cma.ff.elli, index_value_pairs)
402 | >>> fun[1](4 * [1]) == [ 1., 1., 0.2, 1., 0.4, 1.]
403 | True
404 | >>> assert (current_solution == # list with same values
405 | ... fun.transform(fun.insert_variables(current_solution)))
406 | >>> assert (current_solution == # list with same values
407 | ... fun.insert_variables(fun.transform(current_solution)))
408 |
409 | Details: this might replace the ``fixed_variables`` option in
410 | `CMAOptions` in future, but hasn't been thoroughly tested yet.
411 |
412 | Supersedes `ExpandSolution`.
413 |
414 | """
415 | def __init__(self, f, index_value_pairs):
416 | """return `f` with reduced dimensionality.
417 |
418 | ``index_value_pairs``: a `dict` or `list` of ``(index, value)`` pairs;
419 | the variable at position ``index`` is fixed to ``value``.
420 | """
421 | # super(FixVariables, self).__init__(
422 | ComposedFunction.__init__(self, [f, self.insert_variables])
423 | self.index_value_pairs = dict(index_value_pairs)
424 | def transform(self, x):
425 | """transform `x` such that it could be used as argument to `self`.
426 |
427 | Return a list or array, usually dismissing some elements of
428 | `x`. ``fun.transform`` is the inverse of
429 | ``fun.insert_variables == fun[1]``, that is
430 | ``np.all(x == fun.transform(fun.insert_variables(x))) is True``.
431 | """
432 | res = [x[i] for i in range(len(x))
433 | if i not in self.index_value_pairs]
434 | return res if isinstance(x, list) else np.asarray(res)
435 | def insert_variables(self, x):
436 | """return `x` with inserted fixed values"""
437 | if len(self.index_value_pairs) == 0:
438 | return x
439 | y = list(x)
440 | for i in sorted(self.index_value_pairs):
441 | y.insert(i, self.index_value_pairs[i])
442 | if not isinstance(x, list):
443 | y = np.asarray(y) # doubles the necessary time
444 | return y
445 |
446 | class Expensify(Function):
447 | """Add waiting time to each evaluation, to simulate "expensive"
448 | behavior"""
449 | def __init__(self, callable_, time=1):
450 | """add time in seconds"""
451 | Function.__init__(self) # callable_ could go here
452 | self.time = time
453 | self.callable = callable_
454 | def __call__(self, *args, **kwargs):
455 | time.sleep(self.time)
456 | Function.__call__(self, *args, **kwargs)
457 | return self.callable(*args, **kwargs)
458 |
459 | class SomeNaNFitness(Function):
460 | """transform ``fitness_function`` to return sometimes ``NaN``"""
461 | def __init__(self, fitness_function, probability_of_nan=0.1):
462 | Function.__init__(self)
463 | self.fitness_function = fitness_function
464 | self.p = probability_of_nan
465 | def __call__(self, x, *args):
466 | Function.__call__(self, x, *args)
467 | if np.random.rand(1) <= self.p:
468 | return np.nan  # np.NaN was removed in NumPy 2.0
469 | else:
470 | return self.fitness_function(x, *args)
471 |
472 | class NoisyFitness(Function):
473 | """apply noise via ``f += rel_noise(dim) * f + abs_noise(dim)``"""
474 | def __init__(self, fitness_function,
475 | rel_noise=lambda dim: 1.1 * np.random.randn() / dim,
476 | abs_noise=lambda dim: 1.1 * np.random.randn()):
477 | """attach relative and absolute noise to ``fitness_function``.
478 |
479 | Relative noise is by default computed using the length of the
480 | input argument to ``fitness_function``. Both noise functions take
481 | ``dimension`` as input.
482 |
483 | >>> import cma
484 | >>> from cma.fitness_transformations import NoisyFitness
485 | >>> fn = NoisyFitness(cma.ff.elli)
486 | >>> assert fn([1, 2]) != cma.ff.elli([1, 2])
487 | >>> assert fn.evaluations == 1
488 |
489 | """
490 | Function.__init__(self, fitness_function)
491 | self.rel_noise = rel_noise
492 | self.abs_noise = abs_noise
493 |
494 | def __call__(self, x, *args):
495 | f = Function.__call__(self, x, *args)
496 | if self.rel_noise:
497 | f += f * self.rel_noise(len(x))
498 | assert np.isscalar(f)
499 | if self.abs_noise:
500 | f += self.abs_noise(len(x))
501 | return f
502 |
503 | class IntegerMixedFunction(ComposedFunction):
504 | """compose fitness function with some integer variables.
505 |
506 | >>> import cma
507 | >>> f = cma.s.ft.IntegerMixedFunction(cma.ff.elli, [0, 3, 6])
508 | >>> assert f([0.2, 2]) == f([0.4, 2]) != f([1.2, 2])
509 |
510 | It is advisable to set minstd of integer variables to
511 | ``1 / (2 * len(integer_variable_indices) + 1)``, in which case in
512 | an independent model at least 33% (1 integer variable) -> 39% (many
513 | integer variables) of the solutions should have an integer mutation
514 | on average. Option ``integer_variables`` of `cma.CMAOptions`
515 | implements this simple measure.
516 | """
517 | def __init__(self, function, integer_variable_indices, copy_arg=True):
518 | ComposedFunction.__init__(self, [function, self._flatten])
519 | self.integer_variable_indices = integer_variable_indices
520 | self.copy_arg = copy_arg
521 | def _flatten(self, x):
522 | x = np.array(x, copy=self.copy_arg)
523 | for i in sorted(self.integer_variable_indices):
524 | if i < -len(x):
525 | continue
526 | if i >= len(x):
527 | break
528 | x[i] = np.floor(x[i])
529 | return x
530 |
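`ComposedFunction` evaluates its list right to left, `ScaleCoordinates` computes ``multipliers * (x - zero)`` with trailing entries recycled, and `FixVariables` re-inserts fixed values before calling ``f``. The sketch below reproduces these three ideas with plain closures so the mechanics are visible in isolation. It assumes only NumPy; the names `compose`, `recycled`, `make_scaler`, and `make_fixer` are hypothetical stand-ins, not this module's API, and the padding rule in `recycled` is an assumption modeled on the `ScaleCoordinates` docstring. Evaluation counting from `Function` is deliberately omitted.

```python
import numpy as np
from functools import reduce


def compose(*functions):
    """Return x -> f1(f2(...fn(x)...)), evaluated right to left."""
    return lambda x: reduce(lambda acc, f: f(acc), reversed(functions), x)


def recycled(vec, n):
    """Hypothetical helper: cut `vec` to length n or pad it with its last entry."""
    vec = np.asarray(vec, dtype=float)
    if len(vec) >= n:
        return vec[:n]
    return np.concatenate([vec, np.full(n - len(vec), vec[-1])])


def make_scaler(multipliers, zero):
    """Return the map x -> multipliers * (x - zero) and its inverse."""
    def forward(x):
        x = np.asarray(x, dtype=float)
        return recycled(multipliers, len(x)) * (x - recycled(zero, len(x)))

    def inverse(y):
        y = np.asarray(y, dtype=float)
        return y / recycled(multipliers, len(y)) + recycled(zero, len(y))

    return forward, inverse


def make_fixer(index_value_pairs):
    """Return a map that inserts fixed values at the given (0-based) indices."""
    pairs = dict(index_value_pairs)

    def insert(x):
        y = list(x)
        for i in sorted(pairs):  # ascending order keeps later indices valid
            y.insert(i, pairs[i])
        return np.asarray(y, dtype=float)

    return insert


def sphere(x):
    return float(np.sum(np.asarray(x, dtype=float) ** 2))


scale, unscale = make_scaler([100, 1], [0])
f = compose(sphere, scale)        # analogous to ScaleCoordinates(sphere, [100, 1])
print(f([1, 2, 3]))               # 10013.0, that is 100**2 + 2**2 + 3**2
insert = make_fixer([[2, 0.2], [4, 0.4]])
print(insert([1, 1, 1, 1]))       # the 4-D input becomes 6-D
```

Closures keep the sketch short; the classes above instead subclass `list`, which lets callers inspect and modify the composed functions (``f[0]``, ``f.pop()``) after construction.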
--------------------------------------------------------------------------------
/cma/interfaces.py:
--------------------------------------------------------------------------------
1 | """Very few interface defining base class definitions"""
2 | from __future__ import absolute_import, division, print_function #, unicode_literals
3 | import warnings
4 | try: from .optimization_tools import EvalParallel2
5 | except ImportError: EvalParallel2 = None
6 | del absolute_import, division, print_function #, unicode_literals
7 |
8 | class EvalParallel:
9 | """allow the construct ``with EvalParallel(fun) as eval_all:``"""
10 | def __init__(self, fun, *args, **kwargs):
11 | self.fun = fun
12 | def __call__(self, X, args=()):
13 | return [self.fun(x, *args) for x in X]
14 | def __enter__(self): return self
15 | def __exit__(self, *args, **kwargs): pass
16 |
17 | class OOOptimizer(object):
18 | """abstract base class for an Object Oriented Optimizer interface.
19 |
20 | Relevant methods are `__init__`, `ask`, `tell`, `optimize` and `stop`,
21 | and property `result`. Only `optimize` is fully implemented in this
22 | base class.
23 |
24 | Examples
25 | --------
26 | All examples minimize the function `elli`; the output is not shown.
27 | (A preferred environment to execute all examples is ``ipython``.)
28 |
29 | First we need::
30 |
31 | # CMAEvolutionStrategy derives from the OOOptimizer class
32 | from cma import CMAEvolutionStrategy
33 | from cma.fitness_functions import elli
34 |
35 | The shortest example uses the inherited method
36 | `OOOptimizer.optimize`::
37 |
38 | es = CMAEvolutionStrategy(8 * [0.1], 0.5).optimize(elli)
39 |
40 | The input parameters to `CMAEvolutionStrategy` are specific to this
41 | inherited class. The remaining functionality is based on the interface
42 | defined by `OOOptimizer`. We might have a look at the result::
43 |
44 | print(es.result[0]) # best solution and
45 | print(es.result[1]) # its function value
46 |
47 | Virtually the same example can be written with an explicit loop
48 | instead of using `optimize`. This gives the necessary insight into
49 | the `OOOptimizer` class interface and entire control over the
50 | iteration loop::
51 |
52 | # a new CMAEvolutionStrategy instance
53 | optim = CMAEvolutionStrategy(9 * [0.5], 0.3)
54 |
55 | # this loop resembles optimize()
56 | while not optim.stop(): # iterate
57 | X = optim.ask() # get candidate solutions
58 | f = [elli(x) for x in X] # evaluate solutions
59 | # in case do something else that needs to be done
60 | optim.tell(X, f) # do all the real "update" work
61 | optim.disp(20) # display info every 20th iteration
62 | optim.logger.add() # log another "data line", non-standard
63 |
64 | # final output
65 | print('termination by', optim.stop())
66 | print('best f-value =', optim.result[1])
67 | print('best solution =', optim.result[0])
68 | optim.logger.plot() # if matplotlib is available
69 |
70 | Details
71 | -------
72 | Most of the work is done in the methods `tell` or `ask`. The property
73 | `result` provides more useful output.
74 |
75 | """
76 | def __init__(self, xstart, *more_mandatory_args, **optional_kwargs):
77 | """``xstart`` is a mandatory argument"""
78 | self.xstart = xstart
79 | self.more_mandatory_args = more_mandatory_args
80 | self.optional_kwargs = optional_kwargs
81 | self.initialize()
82 | def initialize(self):
83 | """(re-)set to the initial state"""
84 | raise NotImplementedError('method initialize() must be implemented in derived class')
85 | self.countiter = 0
86 | self.xcurrent = [xi for xi in self.xstart]
87 | def ask(self, **optional_kwargs):
88 | """abstract method, AKA "get" or "sample_distribution", deliver
89 | new candidate solution(s), a list of "vectors"
90 | """
91 | raise NotImplementedError('method ask() must be implemented in derived class')
92 | def tell(self, solutions, function_values):
93 | """abstract method, AKA "update", pass f-values and prepare for
94 | next iteration
95 | """
96 | self.countiter += 1
97 | raise NotImplementedError('method tell() must be implemented in derived class')
98 | def stop(self):
99 | """abstract method, return satisfied termination conditions in a
100 | dictionary like ``{'termination reason': value, ...}`` or ``{}``.
101 |
102 | For example ``{'tolfun': 1e-12}``, or the empty dictionary ``{}``.
103 |
104 | TODO: this should rather be a property!? Unfortunately, a change
105 | would break backwards compatibility.
106 | """
107 | raise NotImplementedError('method stop() is not implemented')
108 | def disp(self, modulo=None):
109 | """abstract method, display some iteration info when
110 | ``self.iteration_counter % modulo < 1``, using a reasonable
111 | default for `modulo` if ``modulo is None``.
112 | """
113 | @property
114 | def result(self):
115 | """abstract property, contain ``(x, f(x), ...)``, that is, the
116 | minimizer, its function value, ...
117 | """
118 | raise NotImplementedError('result property is not implemented')
119 | return [self.xcurrent]
120 |
121 | def optimize(self, objective_fct,
122 | maxfun=None, iterations=None, min_iterations=1,
123 | args=(),
124 | verb_disp=None,
125 | callback=None,
126 | n_jobs=0,
127 | **kwargs):
128 | """find minimizer of ``objective_fct``.
129 |
130 | CAVEAT: the return value for `optimize` has changed to ``self``,
131 | allowing for a call like::
132 |
133 | solver = OOOptimizer(x0).optimize(f)
134 |
135 | and investigate the state of the solver.
136 |
137 | Arguments
138 | ---------
139 |
140 | ``objective_fct``: f(x: array_like) -> float
141 | function to be minimized
142 | ``maxfun``: number
143 | maximal number of function evaluations
144 | ``iterations``: number
145 | maximal number of iterations while ``not self.stop()``;
146 | it can be useful to conduct only one iteration at a time.
147 | ``min_iterations``: number
148 | minimal number of iterations, done even if ``self.stop()`` is satisfied
149 | ``args``: sequence_like
150 | arguments passed to ``objective_fct``
151 | ``verb_disp``: number
152 | print to screen every ``verb_disp``-th iteration; if `None`,
153 | the value from ``self.logger`` is "inherited", if
154 | available.
155 | ``callback``: callable or list of callables
156 | callback function called like ``callback(self)`` or
157 | a list of callback functions called in the same way. If
158 | available, ``self.logger.add`` is added to this list.
159 | TODO: currently there is no way to prevent this other than
160 | changing the code of `_prepare_callback_list`.
161 | ``n_jobs=0``: number of processes to be acquired for
162 | multiprocessing to parallelize calls to `objective_fct`.
163 | Must be >1 to expect any speed-up, or `None` or `-1`, which
164 | both default to the number of available CPUs. The default
165 | ``n_jobs=0`` avoids the use of multiprocessing altogether.
166 |
167 | ``return self``, that is, the `OOOptimizer` instance.
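The loop control in `optimize` can be traced without any objective function. The sketch below is a hypothetical reimplementation of only the counters, where a stand-in for ``stop()`` becomes truthy after `stop_at` iterations and every iteration evaluates `popsize` solutions. It shows that `maxfun` is checked before each iteration, so the budget can be exceeded by up to ``popsize - 1`` evaluations, and that `min_iterations` overrides a satisfied ``stop()``. Unlike `optimize`, it does not emit a warning when `min_iterations` exceeds `iterations`.

```python
def simulate_optimize_loop(stop_at, popsize, maxfun=None, iterations=None,
                           min_iterations=1):
    # Hypothetical trace of the counters in OOOptimizer.optimize;
    # stop() is modeled as "at least stop_at iterations done".
    if iterations is not None and min_iterations > iterations:
        iterations = min_iterations
    citer, cevals = 0, 0
    # corresponds to: while not self.stop() or citer < min_iterations
    while citer < stop_at or citer < min_iterations:
        if (maxfun and cevals >= maxfun) or (iterations and citer >= iterations):
            break
        citer += 1
        cevals += popsize  # one ask/tell with popsize evaluations
    return citer, cevals


print(simulate_optimize_loop(stop_at=100, popsize=9, maxfun=50))  # (6, 54)
print(simulate_optimize_loop(stop_at=100, popsize=9, iterations=3))  # (3, 27)
print(simulate_optimize_loop(stop_at=0, popsize=9, min_iterations=2))  # (2, 18)
```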
168 |
169 | Example
170 | -------
171 | >>> import cma
172 | >>> es = cma.CMAEvolutionStrategy(7 * [0.1], 0.1
173 | ... ).optimize(cma.ff.rosen, verb_disp=100)
174 | ... #doctest: +ELLIPSIS
175 | (4_w,9)-aCMA-ES (mu_w=2.8,w_1=49%) in dimension 7 (seed=...)
176 | Iterat #Fevals function value axis ratio sigma ...
177 | 1 9 ...
178 | 2 18 ...
179 | 3 27 ...
180 | 100 900 ...
181 | >>> cma.s.Mh.vequals_approximately(es.result[0], 7 * [1], 1e-5)
182 | True
183 |
184 | """
185 | if kwargs:
186 | message = "ignoring unknown argument%s %s in OOOptimizer.optimize" % (
187 | 's' if len(kwargs) > 1 else '', str(kwargs))
188 | warnings.warn(
189 | message) # warnings.simplefilter('ignore', lineno=186) suppresses this warning
190 |
191 | if iterations is not None and min_iterations > iterations:
192 | warnings.warn("doing min_iterations = %d > %d = iterations"
193 | % (min_iterations, iterations))
194 | iterations = min_iterations
195 | callback = self._prepare_callback_list(callback)
196 |
197 | citer, cevals = 0, 0
198 | with (EvalParallel2 or EvalParallel)(objective_fct,
199 | None if n_jobs == -1 else n_jobs) as eval_all:
200 | while not self.stop() or citer < min_iterations:
201 | if (maxfun and cevals >= maxfun) or (
202 | iterations and citer >= iterations):
203 | return self
204 | citer += 1
205 |
206 | X = self.ask() # deliver candidate solutions
207 | # fitvals = [objective_fct(x, *args) for x in X]
208 | fitvals = eval_all(X, args=args)
209 | cevals += len(fitvals)
210 | self.tell(X, fitvals) # all the work is done here
211 | for f in callback:
212 | f(self)
213 | self.disp(verb_disp) # disp does nothing if not overridden
214 |
215 | # final output
216 | self._force_final_logging()
217 |
218 | if verb_disp: # do not print by default to allow silent verbosity
219 | self.disp(1)
220 | print('termination by', self.stop())
221 | print('best f-value =', self.result[1])
222 | print('solution =', self.result[0])
223 |
224 | return self
225 |
226 | def _prepare_callback_list(self, callback): # helper function
227 | """return a list of callbacks including ``self.logger.add``.
228 |
229 | ``callback`` can be a `callable` or a `list` (or iterable) of
230 | callables. Otherwise a `ValueError` exception is raised.
231 | """
232 | if callback is None:
233 | callback = []
234 | if callable(callback):
235 | callback = [callback]
236 | try:
237 | callback = list(callback) + [self.logger.add]
238 | except AttributeError:
239 | pass
240 | try:
241 | for c in callback:
242 | if not callable(c):
243 | raise ValueError("callback argument %s is not callable"
244 | % str(c))
245 | except TypeError:
246 | raise ValueError("callback argument must be a `callable` or an"
247 | " iterable (e.g. a list) of callables, after some"
248 | " processing it was %s" % str(callback))
249 | return callback
250 |
251 | def _force_final_logging(self): # helper function
252 | """try to force the logger to log now"""
253 | try:
254 | if not self.logger:
255 | return
256 | except AttributeError:
257 | return
258 | # the idea: modulo == 0 means never log, 1 or True means log now
259 | try:
260 | modulo = bool(self.logger.modulo)
261 | except AttributeError:
262 | modulo = True # could also be named force
263 | try:
264 | self.logger.add(self, modulo=modulo)
265 | except AttributeError:
266 | pass
267 | except TypeError:
268 | try:
269 | self.logger.add(self)
270 | except Exception as e:
271 | print(' The final call of the logger in'
272 | ' OOOptimizer._force_final_logging from'
273 | ' OOOptimizer.optimize did not succeed: %s'
274 | % str(e))
275 |
276 | class StatisticalModelSamplerWithZeroMeanBaseClass(object):
277 | """versatile base class to replace a sampler, namely the one in
278 | `CMAEvolutionStrategy`
279 | """
280 | def __init__(self, std_vec, **kwargs):
281 | """pass the vector of initial standard deviations or the dimension of
282 | the underlying sample space.
283 |
284 | Ideally, catch the case when `std_vec` is a scalar, which is then
285 | interpreted as the dimension.
286 | """
287 | try:
288 | dimension = len(std_vec)
289 | except TypeError: # std_vec has no len
290 | dimension = std_vec
291 | std_vec = dimension * [1]
292 | raise NotImplementedError
293 |
294 | def sample(self, number, update=None):
295 | """return list of i.i.d. samples.
296 |
297 | :param number: is the number of samples.
298 | :param update: controls a possibly lazy update of the sampler.
299 | """
300 | raise NotImplementedError
301 |
302 | def update(self, vectors, weights):
303 | """``vectors`` is a list of samples, ``weights`` a corresponding
304 | list of learning rates
305 | """
306 | raise NotImplementedError
307 |
308 | def parameters(self, mueff=None, lam=None):
309 | """return `dict` with (default) parameters, e.g., `c1` and `cmu`.
310 |
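As a worked example of the base-class defaults computed below, the following standalone function repeats the `c1` and `cmu` formulas of this method. The function name is hypothetical, and unlike the method it requires `mueff` to be given. For ``lam < 6`` the rank-one learning rate `c1` is scaled down proportionally.

```python
def base_learning_rates(dimension, mueff, lam=None):
    # same formulas as StatisticalModelSamplerWithZeroMeanBaseClass.parameters
    lower_lam = 6  # for setting c1
    if lam is None:
        lam = lower_lam
    c1 = min(1, lam / lower_lam) * 2 / ((dimension + 1.3) ** 2 + mueff)
    alpha = 2
    cmu = min(1 - c1,
              alpha * (0.25 + mueff - 2 + 1 / mueff) /
              ((dimension + 2) ** 2 + alpha * mueff / 2))
    return dict(c1=c1, cmu=cmu)


p = base_learning_rates(10, 3.0, lam=10)
print('c1=%.4f, cmu=%.4f' % (p['c1'], p['cmu']))  # c1=0.0153, cmu=0.0215
```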
311 | :See also: `RecombinationWeights`"""
312 | if (hasattr(self, '_mueff') and hasattr(self, '_lam') and
313 | (mueff == self._mueff or mueff is None) and
314 | (lam == self._lam or lam is None)):
315 | return self._parameters
316 | self._mueff = mueff
317 | lower_lam = 6 # for setting c1
318 | if lam is None:
319 | lam = lower_lam
320 | self._lam = lam
321 | # todo: put here rather generic formula with degrees of freedom
322 | # todo: replace these base class computations with the appropriate
323 | c1 = min((1, lam / lower_lam)) * 2 / ((self.dimension + 1.3)**2.0 + mueff)
324 | alpha = 2
325 | self._parameters = dict(
326 | c1=c1,
327 | cmu=min((1 - c1,
328 | # or alpha * (mueff - 0.9) with relative min and
329 | # max value of about 1: 0.4, 1.75: 1.5
330 | alpha * (0.25 + mueff - 2 + 1 / mueff) /
331 | ((self.dimension + 2)**2 + alpha * mueff / 2)))
332 | )
333 | return self._parameters
334 |
335 | def norm(self, x):
336 | """return Mahalanobis norm of `x` w.r.t. the statistical model"""
337 | return sum(self.transform_inverse(x)**2)**0.5
338 | @property
339 | def condition_number(self):
340 | raise NotImplementedError
341 | @property
342 | def covariance_matrix(self):
343 | raise NotImplementedError
344 | @property
345 | def variances(self):
346 | """vector of coordinate-wise (marginal) variances"""
347 | raise NotImplementedError
348 |
349 | def transform(self, x):
350 | """transform ``x`` as implied from the distribution parameters"""
351 | raise NotImplementedError
352 |
353 | def transform_inverse(self, x):
354 | raise NotImplementedError
355 |
356 | def to_linear_transformation_inverse(self, reset=False):
357 | """return inverse of associated linear transformation"""
358 | raise NotImplementedError
359 |
360 | def to_linear_transformation(self, reset=False):
361 | """return associated linear transformation"""
362 | raise NotImplementedError
363 |
364 | def inverse_hessian_scalar_correction(self, mean, X, f):
365 | """return scalar correction ``alpha`` such that ``X`` and ``f``
366 | fit to ``f(x) = (x-mean) (alpha * C)**-1 (x-mean)``
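The base class leaves the estimator open. One possible choice, shown here as a hypothetical sketch and not as the `cma` implementation, is a least-squares fit of ``f ~ q / alpha`` with ``q = (x-mean)^T C^-1 (x-mean)``. The model has no constant offset, so ``f`` is assumed to be zero at ``mean``.

```python
import numpy as np


def scalar_correction(mean, X, f, C):
    # Hypothetical least-squares estimate of alpha in
    # f(x) = (x - mean)^T (alpha * C)^-1 (x - mean) = q(x) / alpha
    D = np.asarray(X, dtype=float) - np.asarray(mean, dtype=float)
    q = np.einsum('ij,jk,ik->i', D, np.linalg.inv(C), D)  # q(x) per row
    f = np.asarray(f, dtype=float)
    beta = q.dot(f) / q.dot(q)  # minimizes sum (f - beta * q)**2
    return 1.0 / beta


rng = np.random.default_rng(3)
C = np.diag([1.0, 4.0])
mean = np.array([0.5, -1.0])
X = mean + rng.standard_normal((20, 2))
D = X - mean
f_vals = np.einsum('ij,jk,ik->i', D, np.linalg.inv(2.5 * C), D)  # alpha = 2.5
print(round(scalar_correction(mean, X, f_vals, C), 6))  # 2.5
```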
367 | """
368 | raise NotImplementedError
369 |
370 | def __imul__(self, factor):
371 | raise NotImplementedError
372 |
373 | class BaseDataLogger(object):
374 | """abstract base class for a data logger that can be used with an
375 | `OOOptimizer`.
376 |
377 | Details: attribute `modulo` is used in `OOOptimizer.optimize`.
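A concrete logger mainly has to implement `add`. The sketch below is a hypothetical minimal logger following this pattern, written standalone so it runs without `cma`. It records the iteration counter and the best f-value of any object with `countiter` and `result` attributes (an assumption about the optimizer), and it reuses the `repr`/`ast.literal_eval` round trip that `save` and `load` rely on.

```python
import ast
import os
import tempfile


class MinimalLogger:
    # Hypothetical logger following the BaseDataLogger pattern
    def __init__(self):
        self.optim = None
        self._data = {'iteration': [], 'fbest': []}

    def register(self, optim):
        self.optim = optim
        return self

    def add(self, optim=None, more_data=None, **kwargs):
        # signature mirrors BaseDataLogger.add; fall back to the registered optim
        if optim is None:
            optim = self.optim
        self._data['iteration'].append(optim.countiter)
        self._data['fbest'].append(optim.result[1])

    def save(self, name):
        with open(name, 'w') as f:
            f.write(repr(self._data))

    def load(self, name):
        with open(name, 'r') as f:
            self._data = ast.literal_eval(f.read())
        return self


class FakeOptimizer:
    # stand-in exposing only the two attributes the logger reads
    def __init__(self):
        self.countiter, self.result = 0, (None, 10.0)


opt = FakeOptimizer()
logger = MinimalLogger().register(opt)
for _ in range(3):
    opt.countiter += 1
    opt.result = (None, opt.result[1] / 2)
    logger.add()  # no argument needed after register

path = os.path.join(tempfile.mkdtemp(), 'log.py')
logger.save(path)
print(MinimalLogger().load(path)._data)
# {'iteration': [1, 2, 3], 'fbest': [5.0, 2.5, 1.25]}
```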
378 | """
379 |
380 | def __init__(self):
381 | self.optim = None
382 | """object instance to be logging data from"""
383 | self._data = None
384 | """`dict` of logged data"""
385 | self.filename = "_BaseDataLogger_datadict.py"
386 | """file to save to or load from unless specified otherwise"""
387 |
388 | def register(self, optim, *args, **kwargs):
389 | """register an optimizer ``optim``, only needed if method `add` is
390 | called without passing the ``optim`` argument
391 | """
392 | self.optim = optim
393 | return self
394 |
395 | def add(self, optim=None, more_data=None, **kwargs):
396 | """abstract method, add a "data point" from the state of ``optim``
397 | into the logger.
398 |
399 | The argument ``optim`` can be omitted if ``optim`` was
400 | registered with `register` before; acts like an event handler
401 | """
402 | raise NotImplementedError
403 |
404 | def disp(self, *args, **kwargs):
405 | """abstract method, display some data trace"""
406 | print('method BaseDataLogger.disp() not implemented, to be done in subclass ' + str(type(self)))
407 |
408 | def plot(self, *args, **kwargs):
409 | """abstract method, plot data"""
410 | print('method BaseDataLogger.plot() is not implemented, to be done in subclass ' + str(type(self)))
411 |
412 | def save(self, name=None):
413 | """save data to file `name` or `self.filename`"""
414 | with open(name or self.filename, 'w') as f:
415 | f.write(repr(self._data))
416 |
417 | def load(self, name=None):
418 | """load data from file `name` or `self.filename`"""
419 | from ast import literal_eval
420 | with open(name or self.filename, 'r') as f:
421 | self._data = literal_eval(f.read())
422 | return self
423 | @property
424 | def data(self):
425 | """logged data in a dictionary"""
426 | return self._data
427 |
--------------------------------------------------------------------------------
/cma/recombination_weights.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """`RecombinationWeights` is a list of recombination weights for the CMA-ES.
3 |
4 | The most delicate part is the correct setting of negative weights depending
5 | on learning rates to prevent negative definite matrices when using the
6 | weights in the covariance matrix update.
7 |
8 | The dependency chain is
9 |
10 | lambda -> weights -> mueff -> c1, cmu -> negative weights
11 |
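The chain can be traced by hand for the default weights. The sketch below is a standalone, simplified re-derivation for illustration only. The function names are hypothetical, it assumes ``cmu > 0`` and at least one negative weight, and the `c1` and `cmu` formulas are the ones used in the `RecombinationWeights` doctest, not necessarily those of `CMAEvolutionStrategy`. It builds the raw logarithmic weights, computes `mueff`, and then scales the negative weights to the smallest of the three targets described in `RecombinationWeights.finalize_negative_weights`.

```python
import math


def default_weights(lam):
    # raw shape log((lam + 1) / 2) - log(i), i = 1..lam, then positive
    # weights are scaled to sum to 1 and negative weights to sum to -1
    raw = [math.log((lam + 1) / 2) - math.log(i) for i in range(1, lam + 1)]
    mu = sum(w > 0 for w in raw)
    spos, sneg = sum(raw[:mu]), -sum(raw[mu:])
    return [w / spos for w in raw[:mu]] + [w / sneg for w in raw[mu:]], mu


def finalize_negative(weights, mu, dimension, c1, cmu):
    # lambda -> weights -> mueff -> c1, cmu -> negative weights
    mueff = 1 / sum(w ** 2 for w in weights[:mu])
    neg = weights[mu:]
    sneg = -sum(neg)
    mueffminus = sneg ** 2 / sum(w ** 2 for w in neg)
    target = min(1 + c1 / cmu,                        # zero decay
                 1 + 2 * mueffminus / (mueff + 2),    # respect mueffminus
                 (1 - c1 - cmu) / cmu / dimension)    # positive definiteness
    return weights[:mu] + [w * target / sneg for w in neg]


dimension, lam = 5, 7
w, mu = default_weights(lam)
mueff = 1 / sum(wi ** 2 for wi in w[:mu])
c1 = 2 / (dimension + 1) ** 2
cmu = mueff / (mueff + dimension ** 2)
w = finalize_negative(w, mu, dimension, c1, cmu)
print(', '.join('%.2f' % wi for wi in w))
# 0.59, 0.29, 0.12, 0.00, -0.31, -0.57, -0.79  (as in the class doctest)
```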
12 | """
13 | # https://gist.github.com/nikohansen/3eb4ef0790ff49276a7be3cdb46d84e9
14 | from __future__ import division, print_function
15 | import math
16 |
17 | class RecombinationWeights(list):
18 | """a list of decreasing (recombination) weight values.
19 |
20 | To be used in the update of the covariance matrix C in CMA-ES as
21 | ``w_i``::
22 |
23 | C <- (1 - c1 - cmu * sum w_i) C + c1 ... + cmu sum w_i y_i y_i^T
24 |
25 | After calling `finalize_negative_weights`, the weights
26 | ``w_i`` let ``1 - c1 - cmu * sum w_i = 1`` and guarantee positive
27 | definiteness of C if ``y_i^T C^-1 y_i <= dimension`` for all
28 | ``w_i < 0``.
29 |
30 | Class attributes/properties:
31 |
32 | - ``lambda_``: number of weights, alias for ``len(self)``
33 | - ``mu``: number of strictly positive weights, i.e.
34 | ``sum([wi > 0 for wi in self])``
35 | - ``mueff``: variance effective number of positive weights, i.e.
36 | ``1 / sum([self[i]**2 for i in range(self.mu)])`` where
37 | ``1 == sum([self[i] for i in range(self.mu)])**2``
38 | - `mueffminus`: variance effective number of negative weights
39 | - `positive_weights`: `np.array` of the strictly positive weights
40 | - ``finalized``: `True` if class instance is ready to use
41 |
42 | Class methods not inherited from `list`:
43 |
44 | - `finalize_negative_weights`: main method
45 | - `zero_negative_weights`: set negative weights to zero and set
46 | ``finalized`` to `True`.
47 | - `set_attributes_from_weights`: useful when weight values are
48 | "manually" changed, removed or inserted
49 | - `asarray`: alias for ``np.asarray(self)``
50 | - `do_asserts`: check consistency of weight values, passes also when
51 | not yet ``finalized``
52 |
53 | Usage:
54 |
55 | >>> # from recombination_weights import RecombinationWeights
56 | >>> from cma.recombination_weights import RecombinationWeights
57 | >>> dimension, popsize = 5, 7
58 | >>> weights = RecombinationWeights(popsize)
59 | >>> c1 = 2. / (dimension + 1)**2 # caveat: __future__ division
60 | >>> cmu = weights.mueff / (weights.mueff + dimension**2)
61 | >>> weights.finalize_negative_weights(dimension, c1, cmu)
62 | >>> print('weights = [%s]' % ', '.join("%.2f" % w for w in weights))
63 | weights = [0.59, 0.29, 0.12, 0.00, -0.31, -0.57, -0.79]
64 | >>> print("sum=%.2f, c1+cmu*sum=%.2f" % (sum(weights),
65 | ... c1 + cmu * sum(weights)))
66 | sum=-0.67, c1+cmu*sum=0.00
67 | >>> print('mueff=%.1f, mueffminus=%.1f, mueffall=%.1f' % (
68 | ... weights.mueff,
69 | ... weights.mueffminus,
70 | ... sum(abs(w) for w in weights)**2 /
71 | ... sum(w**2 for w in weights)))
72 | mueff=2.3, mueffminus=2.7, mueffall=4.8
73 | >>> weights = RecombinationWeights(popsize)
74 | >>> print("sum=%.2f, mu=%d, sumpos=%.2f, sumneg=%.2f" % (
75 | ... sum(weights),
76 | ... weights.mu,
77 | ... sum(weights[:weights.mu]),
78 | ... sum(weights[weights.mu:])))
79 | sum=0.00, mu=3, sumpos=1.00, sumneg=-1.00
80 | >>> print('weights = [%s]' % ', '.join("%.2f" % w for w in weights))
81 | weights = [0.59, 0.29, 0.12, 0.00, -0.19, -0.34, -0.47]
82 | >>> weights = RecombinationWeights(21)
83 | >>> weights.finalize_negative_weights(3, 0.081, 0.28)
84 | >>> weights.insert(weights.mu, 0) # add zero weight in the middle
85 | >>> weights = weights.set_attributes_from_weights() # change lambda_
86 | >>> assert weights.lambda_ == 22
87 | >>> print("sum=%.2f, mu=%d, sumpos=%.2f" %
88 | ... (sum(weights), weights.mu, sum(weights[:weights.mu])))
89 | sum=0.24, mu=10, sumpos=1.00
90 | >>> print('weights = [%s]%%' % ', '.join(["%.1f" % (100*weights[i])
91 | ... for i in range(0, 22, 5)]))
92 | weights = [27.0, 6.8, 0.0, -6.1, -11.7]%
93 | >>> weights.zero_negative_weights() # doctest:+ELLIPSIS
94 | [0.270...
95 | >>> "%.2f, %.2f" % (sum(weights), sum(weights[weights.mu:]))
96 | '1.00, 0.00'
97 | >>> mu = int(weights.mu / 2)
98 | >>> for i in range(len(weights)):
99 | ... weights[i] = 1. / mu if i < mu else 0
100 | >>> weights = weights.set_attributes_from_weights()
101 | >>> 5 * "%.1f " % (sum(w for w in weights if w > 0),
102 | ... sum(w for w in weights if w < 0),
103 | ... weights.mu,
104 | ... weights.mueff,
105 | ... weights.mueffminus)
106 | '1.0 0.0 5.0 5.0 0.0 '
107 |
108 | The optimal weights on the sphere and other functions are closer
109 | to exponent 0.75:
110 |
111 | >>> for expo, w in [(expo, RecombinationWeights(5, exponent=expo))
112 | ... for expo in [1, 0.9, 0.8, 0.7, 0.6, 0.5]]:
113 | ... print(7 * "%.2f " % tuple([expo, w.mueff] + w))
114 | 1.00 1.65 0.73 0.27 0.00 -0.36 -0.64
115 | 0.90 1.70 0.71 0.29 0.00 -0.37 -0.63
116 | 0.80 1.75 0.69 0.31 0.00 -0.39 -0.61
117 | 0.70 1.80 0.67 0.33 0.00 -0.40 -0.60
118 | 0.60 1.84 0.65 0.35 0.00 -0.41 -0.59
119 | 0.50 1.89 0.62 0.38 0.00 -0.43 -0.57
120 |
121 | >>> for lam in [8, 8**2, 8**3, 8**4]:
122 | ... if lam == 8:
123 | ... print(" lam expo mueff w[i] / w[i](1)")
124 | ... print(" /mu(1) 1 2 3 4 5 6 7 8")
125 | ... w1 = RecombinationWeights(lam, exponent=1)
126 | ... for expo, w in [(expo, RecombinationWeights(lam, exponent=expo))
127 | ... for expo in [1, 0.8, 0.6]]:
128 | ... print('%4d ' % lam + 10 * "%.2f " % tuple([expo, w.mueff / w1.mueff] + [w[i] / w1[i] for i in range(8)]))
129 | lam expo mueff w[i] / w[i](1)
130 | /mu(1) 1 2 3 4 5 6 7 8
131 | 8 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00
132 | 8 0.80 1.11 0.90 1.02 1.17 1.50 1.30 1.07 0.98 0.93
133 | 8 0.60 1.24 0.80 1.02 1.35 2.21 1.68 1.13 0.95 0.85
134 | 64 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00
135 | 64 0.80 1.17 0.82 0.86 0.88 0.91 0.93 0.95 0.97 0.98
136 | 64 0.60 1.36 0.65 0.72 0.76 0.80 0.84 0.87 0.91 0.94
137 | 512 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00
138 | 512 0.80 1.20 0.76 0.78 0.79 0.80 0.81 0.82 0.83 0.83
139 | 512 0.60 1.42 0.56 0.59 0.61 0.63 0.64 0.65 0.67 0.68
140 | 4096 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00 1.00
141 | 4096 0.80 1.21 0.71 0.73 0.74 0.74 0.75 0.75 0.76 0.76
142 | 4096 0.60 1.44 0.50 0.52 0.53 0.54 0.55 0.55 0.56 0.56
143 |
144 | Reference: Hansen 2016, arXiv:1604.00772.
145 | """
146 | def __init__(self, len_, exponent=1):
147 | """return recombination weights `list`, post condition is
148 | ``sum(self) == 0 and sum(self.positive_weights) == 1``.
149 |
150 | Positive and negative weights sum to 1 and -1, respectively.
151 | The number of positive weights, ``self.mu``, is about
152 | ``len_/2``. Weights are strictly decreasing.
153 |
154 | `finalize_negative_weights` (...) or `zero_negative_weights` ()
155 | should be called to finalize the negative weights.
156 |
157 | :param `len_`: AKA ``lambda`` is the number of weights, see
158 | attribute `lambda_` which is an alias for ``len(self)``.
159 | Alternatively, a list of "raw" weights can be provided.
160 |
161 | """
162 | weights = len_
163 | self.exponent = exponent # for the record
164 | if exponent is None:
165 | self.exponent = 1 # shall become 4/5 or 3/4?
166 | try:
167 | len_ = len(weights)
168 | except TypeError:
169 | try: # iterator without len
170 | len_ = len(list(weights))
171 | except TypeError: # create from scratch
172 | def signed_power(x, expo):
173 | if expo == 1: return x
174 | s = (x != 0) * (-1 if x < 0 else 1)
175 | return s * math.fabs(x)**expo
176 | weights = [signed_power(math.log((len_ + 1) / 2.) - math.log(i), self.exponent)
177 | for i in range(1, len_ + 1)] # raw shape
178 | if len_ < 2:
179 | raise ValueError("number of weights must be >=2, was %d"
180 | % (len_))
181 | self.debug = False
182 |
183 | # self[:] = weights # should do, or
184 | # super(RecombinationWeights, self).__init__(weights)
185 | list.__init__(self, weights)
186 |
187 | self.set_attributes_from_weights(do_asserts=False)
188 | sum_neg = sum(self[self.mu:])
189 | if sum_neg != 0:
190 | for i in range(self.mu, len(self)):
191 | self[i] /= -sum_neg
192 | self.do_asserts()
193 | self.finalized = False
194 |
195 | def set_attributes_from_weights(self, weights=None, do_asserts=True):
196 | """make the class attribute values consistent with the weights, after
197 | (re-)setting the weights from the input parameter ``weights`` if given;
198 | the post condition is also ``sum(self.positive_weights) == 1``.
199 |
200 | This method allows setting or changing the weight list manually,
201 | e.g. via ``weights[:] = new_list`` or the generic `list` methods
202 | `pop`, `insert`, etc.
203 | Currently, weights must be non-increasing and the first weight
204 | must be strictly positive and the last weight not larger than
205 | zero. Then all ``weights`` are normalized such that the
206 | positive weights sum to one.
207 | """
208 | if weights is not None:
209 | if not weights[0] > 0:
210 | raise ValueError(
211 | "the first weight must be >0 but was %f" % weights[0])
212 | if weights[-1] > 0:
213 | raise ValueError(
214 | "the last weight must be <=0 but was %f" %
215 | weights[-1])
216 | self[:] = weights
217 | weights = self
218 | assert all(weights[i] >= weights[i+1]
219 | for i in range(len(weights) - 1))
220 | self.mu = sum(w > 0 for w in weights)
221 | spos = sum(weights[:self.mu])
222 | assert spos > 0
223 | for i in range(len(self)):
224 | self[i] /= spos
225 | # variance-effectiveness of sum^mu w_i x_i
226 | self.mueff = 1**2 / sum(w**2 for w in
227 | weights[:self.mu])
228 | sneg = sum(weights[self.mu:])
229 | assert (sneg - sum(w for w in weights if w < 0))**2 < 1e-11
230 | not do_asserts or self.do_asserts()
231 | return self
232 |
233 | def finalize_negative_weights(self, dimension, c1, cmu, pos_def=True):
234 | """finalize negative weights using ``dimension`` and learning
235 | rates ``c1`` and ``cmu``.
236 |
237 | This is a rather intricate method which makes this class
238 | useful. The negative weights are scaled to achieve,
239 | in this order:
240 |
241 | 1. zero decay, i.e. ``c1 + cmu * sum w == 0``,
242 | 2. a learning rate respecting mueff, i.e. ``sum |w|^- / sum |w|^+
243 | <= 1 + 2 * self.mueffminus / (self.mueff + 2)``,
244 | 3. if `pos_def`, guarantee positive definiteness when sum w^+ = 1
245 | and all negative input vectors used later have at most their
246 | dimension as squared Mahalanobis norm. This is accomplished by
247 | guaranteeing ``(dimension-1) * cmu * sum |w|^- < 1 - c1 - cmu``
248 | via setting ``sum |w|^- <= (1 - c1 - cmu) / dimension / cmu``.
249 |
250 | The latter two conditions do not change the weights with default
251 | population size.
252 |
253 | Details:
254 |
255 | - To guarantee 3., the input vectors associated with negative
256 | weights must obey ||.||^2 <= dimension in Mahalanobis norm.
257 | - The third argument, ``cmu``, usually depends on the
258 | (raw) weights, in particular it depends on ``self.mueff``.
259 | For this reason the calling syntax
260 | ``weights = RecombinationWeights(...).finalize_negative_weights(...)``
261 | is not supported.
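Condition 3 can be checked numerically in its worst case. In this hypothetical sketch (not part of `cma`), ``C`` is the identity, all vectors with negative weights point in the same direction with squared Mahalanobis norm equal to ``dimension``, and the rank-one and positive rank-mu terms are dropped because they only add positive semi-definite matrices. The smallest eigenvalue then equals ``1 - c1 - cmu - (dimension - 1) * cmu * sum |w|^-``, which is still positive at the cap ``(1 - c1 - cmu) / dimension / cmu`` used here.

```python
import numpy as np


def worst_case_min_eigenvalue(dimension, c1, cmu, sum_neg):
    # C = identity; every negative-weight vector y is sqrt(dimension) * e_1,
    # i.e. its squared Mahalanobis norm equals dimension (the allowed maximum).
    # Positive rank-one/rank-mu terms are omitted: they only add PSD matrices.
    y = np.zeros(dimension)
    y[0] = np.sqrt(dimension)
    sum_w = 1 - sum_neg  # positive weights sum to 1
    C_new = ((1 - c1 - cmu * sum_w) * np.eye(dimension)
             - cmu * sum_neg * np.outer(y, y))
    return np.linalg.eigvalsh(C_new).min()


dimension, c1, cmu = 5, 0.05, 0.1
cap = (1 - c1 - cmu) / dimension / cmu  # 1.7, the limit used by this method
print(worst_case_min_eigenvalue(dimension, c1, cmu, cap))      # about 0.17
print(worst_case_min_eigenvalue(dimension, c1, cmu, 2.2) < 0)  # True
```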
262 |
263 | """
264 | if dimension <= 0:
265 | raise ValueError("dimension must be larger than zero, was " +
266 | str(dimension))
267 | self._c1 = c1 # for the record
268 | self._cmu = cmu
269 |
270 | if self[-1] < 0:
271 | if cmu > 0:
272 | if c1 > 10 * cmu:
273 | print("""WARNING: c1/cmu = %f/%f seems to assume a
274 | too large value for negative weights setting"""
275 | % (c1, cmu))
276 | self._negative_weights_set_sum(1 + c1 / cmu)
277 | if pos_def:
278 | self._negative_weights_limit_sum((1 - c1 - cmu) / cmu
279 | / dimension)
280 | self._negative_weights_limit_sum(1 + 2 * self.mueffminus
281 | / (self.mueff + 2))
282 | self.do_asserts()
283 | self.finalized = True
284 |
285 | if self.debug:
286 | print("sum w = %.2f (final)" % sum(self))
287 |
288 | def zero_negative_weights(self):
289 | """finalize by setting all negative weights to zero"""
290 | for k in range(len(self)):
291 | self[k] *= 0 if self[k] < 0 else 1
292 | self.finalized = True
293 | return self
294 |
295 | def _negative_weights_set_sum(self, value):
296 | """set sum of negative weights to ``-abs(value)``
297 |
298 | Precondition: the last weight must not be greater than zero.
299 |
300 | Details: if no negative weight exists, all zero weights with index
301 | lambda / 2 or greater become uniformly negative.
302 | """
303 | weights = self # simpler to change to data attribute and nicer to read
304 | value = abs(value) # simplify code, prevent erroneous assertion error
305 | assert weights[self.mu] <= 0
306 | if not weights[-1] < 0:
307 | # breaks if mu == lambda
308 | # we could also just return here
309 | # return
310 | istart = max((self.mu, int(self.lambda_ / 2)))
311 | for i in range(istart, self.lambda_):
312 | weights[i] = -value / (self.lambda_ - istart)
313 | factor = abs(value / sum(weights[self.mu:]))
314 | for i in range(self.mu, self.lambda_):
315 | weights[i] *= factor
316 | assert 1 - value - 1e-5 < sum(weights) < 1 - value + 1e-5
317 | if self.debug:
318 | print("sum w = %.2f, sum w^- = %.2f" %
319 | (sum(weights), -sum(weights[self.mu:])))
320 |
321 | def _negative_weights_limit_sum(self, value):
322 | """lower bound the sum of negative weights to ``-abs(value)``.
323 | """
324 | weights = self # simpler to change to data attribute and nicer to read
325 | value = abs(value) # simplify code, prevent erroneous assertion error
326 | if sum(weights[self.mu:]) >= -value: # nothing to limit
327 | return # needed when sum is zero
328 | assert weights[-1] < 0 and weights[self.mu] <= 0
329 | factor = abs(value / sum(weights[self.mu:]))
330 | if factor < 1:
331 | for i in range(self.mu, self.lambda_):
332 | weights[i] *= factor
333 | if self.debug:
334 | print("sum w = %.2f (with correction %.2f)" %
335 | (sum(weights), value))
336 | assert sum(weights) + 1e-5 >= 1 - value
337 |
338 | def do_asserts(self):
339 | """assert consistency.
340 |
341 | Assert:
342 |
343 | - attribute values of ``lambda_, mu, mueff, mueffminus``
344 | - value of first and last weight
345 | - monotonicity of weights
346 | - sum of positive weights to be one
347 |
348 | """
349 | weights = self
350 | assert 1 >= weights[0] > 0
351 | assert weights[-1] <= 0
352 | assert len(weights) == self.lambda_
353 | assert all(weights[i] >= weights[i+1]
354 | for i in range(len(weights) - 1)) # monotonicity
355 | assert self.mu > 0 # needed for next assert
356 | assert weights[self.mu-1] > 0 >= weights[self.mu]
357 | assert 0.999 < sum(w for w in weights[:self.mu]) < 1.001
358 | assert (self.mueff / 1.001 <
359 | sum(weights[:self.mu])**2 / sum(w**2 for w in weights[:self.mu]) <
360 | 1.001 * self.mueff)
361 | assert (self.mueffminus == 0 == sum(weights[self.mu:]) or
362 | self.mueffminus / 1.001 <
363 | sum(weights[self.mu:])**2 / sum(w**2 for w in weights[self.mu:]) <
364 | 1.001 * self.mueffminus)
365 |
366 | @property
367 | def lambda_(self):
368 | """alias for ``len(self)``"""
369 | return len(self)
370 | @property
371 | def mueffminus(self):
372 | weights = self
373 | sneg = sum(weights[self.mu:])
374 | assert (sneg - sum(w for w in weights if w < 0))**2 < 1e-11
375 | return (0 if sneg == 0 else
376 | sneg**2 / sum(w**2 for w in weights[self.mu:]))
377 | @property
378 | def positive_weights(self):
379 | """all (strictly) positive weights as ``np.array``.
380 |
381 | Useful to implement recombination for the new mean vector.
382 | """
383 | try:
384 | from numpy import asarray
385 | return asarray(self[:self.mu])
386 | except ImportError:
387 | return self[:self.mu]
388 | @property
389 | def asarray(self):
390 | """return weights as numpy array"""
391 | from numpy import asarray
392 | return asarray(self)
393 |
--------------------------------------------------------------------------------
/cma/s.py:
--------------------------------------------------------------------------------
1 | """versatile shortcuts for quick typing in an (i)python shell or even
2 | ``from cma.s import *`` in interactive sessions.
3 |
4 | Provides various aliases from within the `cma` package, to be reached like
5 | ``cma.s....``
6 |
7 | Don't use for stable code.
8 | """
9 | import warnings as _warnings
10 | try: from matplotlib import pyplot as _pyplot # like this it doesn't show up in the interface
11 | except ImportError:
12 | _pyplot = None
13 | _warnings.warn('Could not import matplotlib.pyplot, therefore'
14 | ' ``cma.plot()`` etc. is not available')
15 | # from . import fitness_functions as ff
16 | from . import evolution_strategy as es
17 | from . import fitness_transformations as ft
18 | from . import transformations as tf
19 | from . import constraints_handler as ch
20 | from .utilities import utils
21 | from .evolution_strategy import CMAEvolutionStrategy as CMAES
22 | from .utilities.utils import pprint
23 | from .utilities.math import Mh
24 | # from .fitness_functions import elli as felli
25 |
26 | if _pyplot:
27 | def figshow():
28 | """`pyplot.show` to make a plotted figure show up"""
29 | # is_interactive = matplotlib.is_interactive()
30 |
31 | _pyplot.ion()
32 | _pyplot.show()
33 | # if we call now matplotlib.interactive(True), the console is
34 | # blocked
35 | figsave = _pyplot.savefig
36 |
37 | if 11 < 3:
38 | try:
39 | from matplotlib.pyplot import savefig as figsave, close as figclose, ion as figion
40 | except:
41 | figsave, figclose, figshow = 3 * ['not available']
42 | _warnings.warn('Could not import matplotlib.pyplot, therefore ``cma.plot()``'
43 | ' etc. is not available')
44 |
--------------------------------------------------------------------------------
/cma/sigma_adaptation.py:
--------------------------------------------------------------------------------
1 | """step-size adaptation classes, currently tightly linked to CMA,
2 | because `hsig` is computed in the base class
3 | """
4 | from __future__ import absolute_import, division, print_function #, unicode_literals, with_statement
5 | import numpy as np
6 | from numpy import square as _square, sqrt as _sqrt
7 | from .utilities import utils
8 | from .utilities.math import Mh
9 | def _norm(x): return np.sqrt(np.sum(np.square(x)))
10 | del absolute_import, division, print_function #, unicode_literals, with_statement
11 |
12 | class CMAAdaptSigmaBase(object):
13 | """step-size adaptation base class, implements `hsig` (for stalling the
14 | distribution update) functionality via an isotropic evolution path.
15 |
16 | Details: `hsig` or `_update_ps` must be called before the sampling
17 | distribution is changed. `_update_ps` depends heavily on
18 | `cma.CMAEvolutionStrategy`.
19 | """
20 | def __init__(self, *args, **kwargs):
21 | self.is_initialized_base = False
22 | self._ps_updated_iteration = -1
23 | self.delta = 1
24 | "cumulated effect of adaptation"
25 | def initialize_base(self, es):
26 | """set parameters and state variables based on dimension,
27 | mueff and possibly further options.
28 |
29 | """
30 | ## meta_parameters.cs_exponent == 1.0
31 | b = 1.0
32 | ## meta_parameters.cs_multiplier == 1.0
33 | self.cs = 1.0 * (es.sp.weights.mueff + 2)**b / (es.N**b + (es.sp.weights.mueff + 3)**b)
34 | self.ps = np.zeros(es.N)
35 | self.is_initialized_base = True
36 | return self
37 | def _update_ps(self, es):
38 | """update the isotropic evolution path.
39 |
40 | Uses the ``es`` attributes ``mean``, ``mean_old``, ``sigma``,
41 | ``sigma_vec``, ``sp.weights.mueff``, ``sp.cmean`` and
42 | ``sm.transform_inverse``.
43 |
44 | :type es: CMAEvolutionStrategy
45 | """
46 | if not self.is_initialized_base:
47 | self.initialize_base(es)
48 | if self._ps_updated_iteration == es.countiter:
49 | return
50 | try:
51 | if es.countiter <= es.sm.itereigenupdated:
52 | # es.B and es.D must/should be those from the last iteration
53 | utils.print_warning('distribution transformation (B and D) has been updated before ps could be computed',
54 | '_update_ps', 'CMAAdaptSigmaBase', verbose=es.opts['verbose'])
55 | except AttributeError:
56 | pass
57 | z = es.sm.transform_inverse((es.mean - es.mean_old) / es.sigma_vec.scaling)
58 | # assert Mh.vequals_approximately(z, np.dot(es.B, (1. / es.D) *
59 | # np.dot(es.B.T, (es.mean - es.mean_old) / es.sigma_vec.scaling)))
60 | z *= es.sp.weights.mueff**0.5 / es.sigma / es.sp.cmean
61 | self.ps = (1 - self.cs) * self.ps + (self.cs * (2 - self.cs))**0.5 * z
62 | self._ps_updated_iteration = es.countiter
63 | def hsig(self, es):
64 | """return "OK-signal" for rank-one update, `True` (OK) or `False`
65 | (stall rank-one update), based on the length of an evolution path
66 |
67 | """
68 | self._update_ps(es)
69 | if self.ps is None:
70 | return True
71 | squared_sum = np.sum(self.ps**2) / (1 - (1 - self.cs)**(2 * es.countiter))
72 | # correction with self.countiter seems not necessary,
73 | # as pc also starts with zero
74 | return squared_sum / es.N - 1 < 1 + 4. / (es.N + 1)
75 |
76 | def update2(self, es, **kwargs):
77 | """return sigma change factor and update self.delta.
78 |
79 | ``self.delta == sigma/sigma0`` accumulates all past changes
80 | starting from `1.0`.
81 |
82 | Unlike `update`, `update2` is not supposed to change attributes
83 | in `es`, specifically it should not change `es.sigma`.
84 | """
85 | self._update_ps(es)
86 | raise NotImplementedError('must be implemented in a derived class')
87 |
88 | def update(self, es, **kwargs):
89 | """update ``es.sigma``
90 |
91 | :param es: `CMAEvolutionStrategy` class instance
92 | :param kwargs: whatever else is needed to update ``es.sigma``,
93 | which should be none.
94 | """
95 | self._update_ps(es)
96 | raise NotImplementedError('must be implemented in a derived class')
97 | def check_consistency(self, es):
98 | """make consistency checks with a `CMAEvolutionStrategy` instance
99 | as input
100 | """
101 | class CMAAdaptSigmaNone(CMAAdaptSigmaBase):
102 | """constant step-size sigma"""
103 | def update(self, es, **kwargs):
104 | """no update, ``es.sigma`` remains constant.
105 | """
106 | pass
107 | class CMAAdaptSigmaDistanceProportional(CMAAdaptSigmaBase):
108 | """artificial setting of ``sigma`` for test purposes, e.g.
109 | to simulate optimal progress rates.
110 |
111 | """
112 | def __init__(self, coefficient=1.2, **kwargs):
113 | """pass multiplier for normalized step-size"""
114 | super(CMAAdaptSigmaDistanceProportional, self).__init__() # base class provides method hsig()
115 | self.coefficient = coefficient
116 | self.is_initialized = True
117 | def update(self, es, **kwargs):
118 | """need attributes ``N``, ``sp.weights.mueff``, ``mean``,
119 | ``sp.cmean`` of input parameter ``es``
120 | """
121 | es.sigma = self.coefficient * es.sp.weights.mueff * _norm(es.mean) / es.N / es.sp.cmean
122 | class CMAAdaptSigmaCSA(CMAAdaptSigmaBase):
123 | """CSA cumulative step-size adaptation AKA path length control.
124 |
125 | As of 2017, CSA is considered as the default step-size control method
126 | within CMA-ES.
127 | """
128 | def __init__(self, **kwargs):
129 | """postpone initialization to a method call where dimension and mueff should be known.
130 |
131 | """
132 | self.is_initialized = False
133 | self.delta = 1
134 | def initialize(self, es):
135 | """set parameters and state variable based on dimension,
136 | mueff and possibly further options.
137 |
138 | """
139 | self.disregard_length_setting = True if es.opts['CSA_disregard_length'] else False
140 | if es.opts['CSA_clip_length_value'] is not None:
141 | try:
142 | if len(es.opts['CSA_clip_length_value']) == 0:
143 | es.opts['CSA_clip_length_value'] = [-np.inf, np.inf]
144 | elif len(es.opts['CSA_clip_length_value']) == 1:
145 | es.opts['CSA_clip_length_value'] = [-np.inf, es.opts['CSA_clip_length_value'][0]]
146 | elif len(es.opts['CSA_clip_length_value']) == 2:
147 | es.opts['CSA_clip_length_value'] = np.sort(es.opts['CSA_clip_length_value'])
148 | else:
149 | raise ValueError('option CSA_clip_length_value should be a number or a sequence of length at most 2')
150 | except TypeError: # len(...) failed
151 | es.opts['CSA_clip_length_value'] = [-np.inf, es.opts['CSA_clip_length_value']]
152 | es.opts['CSA_clip_length_value'] = list(np.sort(es.opts['CSA_clip_length_value']))
153 | if es.opts['CSA_clip_length_value'][0] > 0 or es.opts['CSA_clip_length_value'][1] < 0:
154 | raise ValueError('option CSA_clip_length_value must be a single positive or a negative and a positive number')
155 | ## meta_parameters.cs_exponent == 1.0
156 | b = 1.0
157 | ## meta_parameters.cs_multiplier == 1.0
158 | self.cs = 1.0 * (es.sp.weights.mueff + 2)**b / (es.N**b + (es.sp.weights.mueff + 3)**b)
159 |
160 | self.damps = es.opts['CSA_dampfac'] * (0.5 +
161 | 0.5 * min([1, (es.sp.lam_mirr / (0.159 * es.sp.popsize) - 1)**2])**1 +
162 | 2 * max([0, ((es.sp.weights.mueff - 1) / (es.N + 1))**es.opts['CSA_damp_mueff_exponent'] - 1]) +
163 | self.cs
164 | )
165 | self.max_delta_log_sigma = 1 # in symmetric use (strict lower bound is -cs/damps anyway)
166 |
167 | if self.disregard_length_setting:
168 | es.opts['CSA_clip_length_value'] = [0, 0]
169 | ## meta_parameters.cs_exponent == 1.0
170 | b = 1.0 * 0.5
171 | ## meta_parameters.cs_multiplier == 1.0
172 | self.cs = 1.0 * (es.sp.weights.mueff + 1)**b / (es.N**b + 2 * es.sp.weights.mueff**b)
173 | self.damps = es.opts['CSA_dampfac'] * 1 # * (1.1 - 1/(es.N+1)**0.5)
174 | if es.opts['verbose'] > 1:
175 | print('CMAAdaptSigmaCSA Parameters: ')
176 | for k, v in self.__dict__.items():
177 | print(' ', k, ':', v)
178 | self.ps = np.zeros(es.N)
179 | self._ps_updated_iteration = -1
180 | self.is_initialized = True
181 | def _update_ps(self, es):
182 | """update path with isotropic delta mean, possibly clipped.
183 |
184 | From input argument `es`, the attributes isotropic_mean_shift,
185 | opts['CSA_clip_length_value'], and N are used.
186 | opts['CSA_clip_length_value'] can be a single value, the upper
187 | bound parameter, such that::
188 |
189 | max_len = sqrt(N) + opts['CSA_clip_length_value'] * N / (N+2)
190 |
191 | or a list with lower and upper bound parameters.
192 | """
193 | if not self.is_initialized:
194 | self.initialize(es)
195 | if self._ps_updated_iteration == es.countiter:
196 | return
197 | z = es.isotropic_mean_shift
198 | if es.opts['CSA_clip_length_value'] is not None:
199 | vals = es.opts['CSA_clip_length_value']
200 | try: len(vals)
201 | except TypeError: vals = [-np.inf, vals]
202 | if vals[0] > 0 or vals[1] < 0:
203 | raise ValueError(
204 | """value(s) for option 'CSA_clip_length_value' = %s
205 | not allowed""" % str(es.opts['CSA_clip_length_value']))
206 | min_len = es.N**0.5 + vals[0] * es.N / (es.N + 2)
207 | max_len = es.N**0.5 + vals[1] * es.N / (es.N + 2)
208 | act_len = _norm(z)
209 | new_len = Mh.minmax(act_len, min_len, max_len)
210 | if new_len != act_len:
211 | z *= new_len / act_len
212 | # z *= (es.N / sum(z**2))**0.5 # ==> sum(z**2) == es.N
213 | # z *= es.const.chiN / sum(z**2)**0.5
214 | self.ps = (1 - self.cs) * self.ps + _sqrt(self.cs * (2 - self.cs)) * z
215 | self._ps_updated_iteration = es.countiter
216 | def update2(self, es, **kwargs):
217 | """call ``self._update_ps(es)`` and update self.delta.
218 |
219 | Return change factor of self.delta.
220 |
221 | From input `es`, either attribute N or const.chiN is used.
222 | """
223 | self._update_ps(es) # caveat: if es.B or es.D are already updated and ps is not, this goes wrong!
224 | p = self.ps
225 | if 'pc for ps' in es.opts['vv']:
226 | # was: es.D**-1 * np.dot(es.B.T, es.pc)
227 | p = es.sm.transform_inverse(es.pc)
228 | if es.opts['CSA_squared']:
229 | s = (sum(_square(p)) / es.N - 1) / 2
230 | # sum(self.ps**2) / es.N has mean 1 and std sqrt(2/N) and is skewed
231 | # divided by 2 to have the derivative d/dx (x**2 / N - 1) for x**2=N equal to 1
232 | else:
233 | s = _norm(p) / es.const.chiN - 1
234 | s *= self.cs / self.damps
235 | s_clipped = Mh.minmax(s, -self.max_delta_log_sigma, self.max_delta_log_sigma)
236 | # "error" handling
237 | if s_clipped != s:
238 | utils.print_warning('sigma change np.exp(' + str(s) + ') = ' + str(np.exp(s)) +
239 | ' clipped to np.exp(+-' + str(self.max_delta_log_sigma) + ')',
240 | 'update',
241 | 'CMAAdaptSigmaCSA',
242 | es.countiter, es.opts['verbose'])
243 | self.delta *= np.exp(s_clipped)
244 | return np.exp(s_clipped)
245 | def update(self, es, **kwargs):
246 | """call ``self._update_ps(es)`` and update ``es.sigma``.
247 |
248 | Legacy method replaced by `update2`.
249 | """
250 | es.sigma *= self.update2(es, **kwargs)
251 | if 11 < 3:
252 | # derandomized MSR = natural gradient descent using mean(z**2) instead of mu*mean(z)**2
253 | fit = kwargs['fit'] # == es.fit
254 | slengths = np.array([sum(z**2) for z in es.arz[fit.idx[:es.sp.weights.mu]]])
255 | # print lengths[0::int(es.sp.weights.mu/5)]
256 | es.sigma *= np.exp(np.dot(es.sp.weights, slengths / es.N - 1))**(2 / (es.N + 1))
257 | if 11 < 3:
258 | es.more_to_write.append(10**((sum(self.ps**2) / es.N / 2 - 1 / 2 if es.opts['CSA_squared'] else _norm(self.ps) / es.const.chiN - 1)))
259 | es.more_to_write.append(10**(-3.5 + sum(self.ps**2) / es.N / 2 - _norm(self.ps) / es.const.chiN))
260 | # es.more_to_write.append(10**(-3 + sum(es.arz[es.fit.idx[0]]**2) / es.N))
261 |
262 | class CMAAdaptSigmaMedianImprovement(CMAAdaptSigmaBase):
263 | """Compares median fitness to the 27%tile fitness of the
264 | previous iteration, see Ait ElHara et al, GECCO 2013.
265 |
266 | >>> import cma
267 | >>> es = cma.CMAEvolutionStrategy(3 * [1], 1,
268 | ... {'AdaptSigma':cma.sigma_adaptation.CMAAdaptSigmaMedianImprovement,
269 | ... 'verbose': -9})
270 | >>> assert es.optimize(cma.ff.elli).result[1] < 1e-9
271 | >>> assert es.result[2] < 2000
272 |
273 | """
274 | def __init__(self, **kwargs):
275 | CMAAdaptSigmaBase.__init__(self) # base class provides method hsig()
276 | # super(CMAAdaptSigmaMedianImprovement, self).__init__()
277 | def initialize(self, es):
278 | """late initialization using attributes ``N`` and ``popsize``"""
279 | r = es.sp.weights.mueff / es.popsize
280 | self.index_to_compare = 0.5 * (r**0.5 + 2.0 * (1 - r**0.5) / np.log(es.N + 9)**2) * (es.popsize) # TODO
281 | self.index_to_compare = 0.30 * es.popsize # TODO
282 | self.damp = 2 - 2 / es.N # sign-rule: 2
283 | self.c = 0.3 # sign-rule needs <= 0.3
284 | self.s = 0 # averaged statistics, usually between -1 and +1
285 | def update(self, es, **kwargs):
286 | if es.countiter < 2:
287 | self.initialize(es)
288 | self.fit = es.fit.fit
289 | else:
290 | ft1, ft2 = self.fit[int(self.index_to_compare)], self.fit[int(np.ceil(self.index_to_compare))]
291 | ftt1, ftt2 = es.fit.fit[(es.popsize - 1) // 2], es.fit.fit[int(np.ceil((es.popsize - 1) / 2))]
292 | pt2 = self.index_to_compare - int(self.index_to_compare)
293 | # ptt2 = (es.popsize - 1) / 2 - (es.popsize - 1) // 2 # not in use
294 | s = 0
295 | if 1 < 3:
296 | s += pt2 * sum(es.fit.fit <= self.fit[int(np.ceil(self.index_to_compare))])
297 | s += (1 - pt2) * sum(es.fit.fit < self.fit[int(self.index_to_compare)])
298 | s -= es.popsize / 2.
299 | s *= 2. / es.popsize # the range was popsize, is 2
300 | elif 11 < 3: # compare ft with median of ftt
301 | s += self.index_to_compare - sum(self.fit <= es.fit.fit[es.popsize // 2])
302 | s *= 2 / es.popsize # the range was popsize, is 2
303 | else: # compare ftt j-index of ft
304 | s += (1 - pt2) * np.sign(ft1 - ftt1)
305 | s += pt2 * np.sign(ft2 - ftt1)
306 | self.s = (1 - self.c) * self.s + self.c * s
307 | es.sigma *= np.exp(self.s / self.damp)
308 | # es.more_to_write.append(10**(self.s))
309 |
310 | #es.more_to_write.append(10**((2 / es.popsize) * (sum(es.fit.fit < self.fit[int(self.index_to_compare)]) - (es.popsize + 1) / 2)))
311 | # # es.more_to_write.append(10**(self.index_to_compare - sum(self.fit <= es.fit.fit[es.popsize // 2])))
312 | # # es.more_to_write.append(10**(np.sign(self.fit[int(self.index_to_compare)] - es.fit.fit[es.popsize // 2])))
313 | if 11 < 3:
314 | from scipy import stats
315 | zkendall = stats.kendalltau(list(es.fit.fit) + list(self.fit),
316 | len(es.fit.fit) * [0] + len(self.fit) * [1])[0]
317 | es.more_to_write.append(10**zkendall)
318 | self.fit = es.fit.fit
319 | class CMAAdaptSigmaTPA(CMAAdaptSigmaBase):
320 | """two point adaptation for step-size sigma.
321 |
322 | Relies on a specific sampling of the first two offspring, whose
323 | objective function value ranks are used to decide on the step-size
324 | change, see `update` for the specifics.
325 |
326 | Example
327 | =======
328 |
329 | >>> import cma
330 | >>> cma.CMAOptions('adapt').pprint() # doctest: +ELLIPSIS
331 | AdaptSigma='True...
332 | >>> es = cma.CMAEvolutionStrategy(10 * [0.2], 0.1,
333 | ... {'AdaptSigma': cma.sigma_adaptation.CMAAdaptSigmaTPA,
334 | ... 'ftarget': 1e-8}) # doctest: +ELLIPSIS
335 | (5_w,10)-aCMA-ES (mu_w=3.2,w_1=45%) in dimension 10 (seed=...
336 | >>> es.optimize(cma.ff.rosen) # doctest: +ELLIPSIS
337 | Iter...
338 | >>> assert 'ftarget' in es.stop()
339 | >>> assert es.result[1] <= 1e-8 # should coincide with the above
340 | >>> assert es.result[2] < 6500 # typically < 5500
341 |
342 | References: loosely based on Hansen 2008, CMA-ES with Two-Point
343 | Step-Size Adaptation, more tightly based on Hansen et al. 2014,
344 | How to Assess Step-Size Adaptation Mechanisms in Randomized Search.
345 |
346 | """
347 | def __init__(self, dimension=None, opts=None, **kwargs):
348 | super(CMAAdaptSigmaTPA, self).__init__() # base class provides method hsig()
349 | # CMAAdaptSigmaBase.__init__(self)
350 | self.initialized = False
351 | self.dimension = dimension
352 | self.opts = opts
353 | def initialize(self, N=None, opts=None):
354 | """late initialization.
355 |
356 | :param N: is used for the (minor) dependency on dimension,
357 | :param opts: is used for hacking
358 | """
359 | if self.initialized is True:
360 | return self
361 | self.initialized = False
362 | if N is None:
363 | N = self.dimension
364 | if opts is None:
365 | opts = self.opts
366 | try:
367 | damp_fac = opts['CSA_dampfac'] # should be renamed to sigma_adapt_dampfac or something
368 | except (TypeError, KeyError):
369 | damp_fac = 1
370 |
371 | self.sp = utils.BlancClass() # just a container to have sp.name instead of sp['name'] to access parameters
372 | try:
373 | self.sp.damp = damp_fac * eval('N')**0.5 # (1) why do we need 10 <-> np.exp(1/10) == 1.1? 2 should be fine!?
374 | self.sp.damp = damp_fac * (4 - 3.6/eval('N')**0.5) # (2) should become new default!?
375 | self.sp.damp = damp_fac * eval('N')**0.25
376 | self.sp.damp = 0.7 + np.log(eval('N')) # for N between 2 and 9 very close to N**0.5, for N=7 equal to (1) and (2)
377 | # self.sp.damp = 100
378 | except:
379 | self.sp.damp = 4 # or 1 + np.log(10)
380 | self.initialized = 1/2
381 | try:
382 | self.sp.damp = opts['vv']['TPA_damp']
383 | print('damp set to %g' % self.sp.damp)
384 | except (KeyError, TypeError):
385 | pass
386 |
387 | self.sp.dampup = 0.5**0.0 * 1.0 * self.sp.damp # 0.5 fails to converge on the Rastrigin function
388 | self.sp.dampdown = 2.0**0.0 * self.sp.damp
389 | if self.sp.dampup != self.sp.dampdown:
390 | print('TPA damping is asymmetric')
391 | self.sp.c = 0.3 # rank difference is asymmetric and therefore the switch from increase to decrease takes too long
392 | self.sp.z_exponent = 0.5 # sign(z) * abs(z)**z_exponent, 0.5 seems better with larger popsize, 1 was default
393 | self.sp.sigma_fac = 1.0 # (obsolete) 0.5 feels better, but no evidence whether it is
394 | self.sp.relative_to_delta_mean = True # (obsolete)
395 | self.s = 0 # the state/summation variable
396 | self.last = None
397 | if not self.initialized:
398 | self.initialized = True
399 | return self
400 | def update(self, es, function_values, **kwargs):
401 | """the first and second value in ``function_values``
402 | must reflect two mirrored solutions.
403 |
404 | Mirrored solutions must have been sampled
405 | in direction / in opposite direction of
406 | the previous mean shift, respectively.
407 | """
408 | # On the linear function, the two mirrored samples lead
409 | # to a sharp increase of the condition of the covariance matrix,
410 | # unless we have negative weights (which we have now by default).
411 | # Otherwise they should not be used to update the covariance
412 | # matrix, if the step-size increases quickly.
413 | if self.initialized is not True: # try again
414 | self.initialize(es.N, es.opts)
415 | if self.initialized is not True:
416 | utils.print_warning("dimension not known, damping set to 4",
417 | 'update', 'CMAAdaptSigmaTPA')
418 | self.initialized = True
419 | if 1 < 3:
420 | f_vals = np.asarray(function_values)
421 | z = np.sum(f_vals < f_vals[1]) - np.sum(f_vals < f_vals[0])
422 | z /= len(f_vals) - 1 # z in [-1, 1]
423 | elif 1 < 3:
424 | # use the ranking difference of the mirrors for adaptation
425 | # damp = 5 should be fine
426 | z = np.nonzero(es.fit.idx == 1)[0][0] - np.nonzero(es.fit.idx == 0)[0][0]
427 | z /= es.popsize - 1 # z in [-1, 1]
428 | self.s = (1 - self.sp.c) * self.s + self.sp.c * np.sign(z) * np.abs(z)**self.sp.z_exponent
429 | if self.s > 0:
430 | es.sigma *= np.exp(self.s / self.sp.dampup)
431 | else:
432 | es.sigma *= np.exp(self.s / self.sp.dampdown)
433 | #es.more_to_write.append(10**z)
434 |
435 | def check_consistency(self, es):
436 | assert isinstance(es.adapt_sigma, CMAAdaptSigmaTPA)
437 | if es.countiter > 3:
438 | dm = es.mean[0] - es.mean_old[0]
439 | dx0 = es.pop[0][0] - es.mean_old[0]
440 | dx1 = es.pop[1][0] - es.mean_old[0]
441 | for i in np.random.randint(1, es.N, 1):
442 | if dx0 * dx1 * (es.pop[0][i] - es.mean_old[i]) * (
443 | es.pop[1][i] - es.mean_old[i]):
444 | dmi_div_dx0i = (es.mean[i] - es.mean_old[i]) \
445 | / (es.pop[0][i] - es.mean_old[i])
446 | dmi_div_dx1i = (es.mean[i] - es.mean_old[i]) \
447 | / (es.pop[1][i] - es.mean_old[i])
448 | if not Mh.equals_approximately(
449 | dmi_div_dx0i, dm / dx0, 1e-4) or \
450 | not Mh.equals_approximately(
451 | dmi_div_dx1i, dm / dx1, 1e-4):
452 | utils.print_warning(
453 | 'TPA: apparent inconsistency with mirrored'
454 | ' samples, where dmi_div_dx0i, dm/dx0=%f, %f'
455 | ' and dmi_div_dx1i, dm/dx1=%f, %f' % (
456 | dmi_div_dx0i, dm/dx0, dmi_div_dx1i, dm/dx1),
457 | 'check_consistency',
458 | 'CMAAdaptSigmaTPA', es.countiter)
459 | else:
460 | utils.print_warning('zero delta encountered in TPA which' +
461 | ' \nshould be very rare and might be a bug' +
462 | ' (sigma=%f)' % es.sigma,
463 | 'check_consistency', 'CMAAdaptSigmaTPA',
464 | es.countiter)
465 |
466 |
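The default branch of `CMAAdaptSigmaCSA` (with `CSA_squared` off, `CSA_disregard_length` off and no length clipping) reduces to three formulas. The learning rate is `cs = (mueff + 2) / (N + mueff + 3)`. The damping is `damps`. The log step-size change is `(|ps| / chiN - 1) * cs / damps`, clipped to `[-1, 1]`. The snippet below is a minimal, dependency-free sketch of that branch and is not the `cma` implementation. The function names are hypothetical. The sketch assumes there are no mirrored samples, and it takes `CSA_dampfac=1` and `CSA_damp_mueff_exponent=0.5` as defaults, which is our assumption. For `chiN` it uses the usual approximation of E||N(0, I)||.

```python
import math


def csa_parameters(n, mueff, dampfac=1.0, damp_mueff_exponent=0.5):
    """Return (cs, damps) following CMAAdaptSigmaCSA.initialize, default branch.

    Assumes no mirrored samples (lam_mirr == 0), which makes the first
    damping term 0.5 + 0.5 * min(1, (0 - 1)**2) == 1.
    """
    cs = (mueff + 2) / (n + mueff + 3)
    mirror_term = 1.0
    mueff_term = 2 * max(0.0, ((mueff - 1) / (n + 1)) ** damp_mueff_exponent - 1)
    damps = dampfac * (mirror_term + mueff_term + cs)
    return cs, damps


def expected_norm(n):
    """Common approximation of E||N(0, I_n)||, the chi_N constant."""
    return math.sqrt(n) * (1 - 1 / (4 * n) + 1 / (21 * n ** 2))


def csa_step(ps, z, cs, damps, max_delta_log_sigma=1.0):
    """Cumulate the isotropic path and return (new_ps, sigma change factor).

    ``z`` plays the role of ``es.isotropic_mean_shift`` (no length clipping).
    """
    c = math.sqrt(cs * (2 - cs))
    ps = [(1 - cs) * p + c * zi for p, zi in zip(ps, z)]
    path_length = math.sqrt(sum(p * p for p in ps))
    s = (path_length / expected_norm(len(ps)) - 1) * cs / damps
    # symmetric clipping of the log step-size change, as in update2
    s = min(max(s, -max_delta_log_sigma), max_delta_log_sigma)
    return ps, math.exp(s)
```

Take `n = 10` and `mueff = 3.2` as an example. A zero mean shift leaves the path at zero and gives a factor of `exp(-cs / damps)`, so sigma shrinks. A consistent shift `z = [2.0] * 10` pushes the path length above `chiN` and gives a factor above 1. The multiplicative, path-length-based form is why the class can return the factor from `update2` and keep the product in `self.delta`.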
--------------------------------------------------------------------------------
/cma/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """test module of `cma` package.
3 |
4 | Usage::
5 |
6 | python -m cma.test -h # print this docstring
7 | python -m cma.test # doctest all (listed) files
8 | python -m cma.test list # list files to be doctested
9 | python -m cma.test interfaces.py [file2 [file3 [...]]] # doctest only these
10 |
11 | or possibly by executing this file as a script::
12 |
13 | python cma/test.py # same options as above work
14 | cma/test.py # the same
15 |
16 | or equivalently by passing Python code::
17 |
18 | python -c "import cma.test; cma.test.main()" # doctest all (listed) files
19 | python -c "import cma.test; cma.test.main('list')" # show files in doctest list
20 | python -c "import cma.test; cma.test.main('interfaces.py [file2 [file3 [...]]]')"
21 | python -c "import cma.test; help(cma.test)" # print this docstring
22 |
23 | File(name)s are interpreted within the package. Without a filename
24 | argument, all files from attribute `files_for_doctest` are tested.
25 | """
26 |
27 | # (note to self) for testing:
28 | # pyflakes cma.py # finds bugs by static analysis
29 | # pychecker --limit 60 cma.py # also executes, all 60 warnings checked
30 | # or python ~/Downloads/pychecker-0.8.19/pychecker/checker.py cma.py
31 | # python -3 -m cma 2> out2to3warnings.txt # produces no warnings from here
32 |
33 | from __future__ import (absolute_import, division, print_function,
34 | ) # unicode_literals)
35 | import os, sys
36 | import doctest
37 | del absolute_import, division, print_function #, unicode_literals
38 |
39 | files_for_doctest = ['bbobbenchmarks.py',
40 | 'constraints_handler.py',
41 | 'evolution_strategy.py',
42 | 'fitness_functions.py',
43 | 'fitness_models.py',
44 | 'fitness_transformations.py',
45 | 'interfaces.py',
46 | 'logger.py',
47 | 'optimization_tools.py',
48 | 'purecma.py',
49 | 'recombination_weights.py',
50 | 'restricted_gaussian_sampler.py',
51 | 'sampler.py',
52 | 'sigma_adaptation.py',
53 | 'test.py',
54 | 'transformations.py',
55 | os.path.join('utilities', 'math.py'),
56 | os.path.join('utilities', 'utils.py'),
57 | ]
58 | _files_written = ['_saved-cma-object.pkl',
59 | 'outcmaesaxlen.dat',
60 | 'outcmaesaxlencorr.dat',
61 | 'outcmaesfit.dat',
62 | 'outcmaesstddev.dat',
63 | 'outcmaesxmean.dat',
64 | 'outcmaesxrecentbest.dat',
65 | ]
66 | """files written by the doc tests and hence, in case, to be deleted"""
67 |
68 | PY2 = sys.version_info[0] == 2
69 | def _clean_up(folder, start_matches, protected):
70 | """(permanently) remove entries in ``folder`` which begin with any of
71 | ``start_matches``, where ``""`` matches any string, and which are not
72 | in ``protected``.
73 |
74 | CAVEAT: use with care, as with ``"", ""`` as second and third
75 | arguments this could delete all files in ``folder``.
76 | """
77 | if not os.path.isdir(folder):
78 | return
79 | if not protected and "" in start_matches:
80 | raise ValueError(
81 | '''_clean_up(folder, [..., "", ...], []) is not permitted as it
82 | resembles "rm *"''')
83 | protected = protected + ["/"]
84 | for file_ in os.listdir(folder):
85 | if any(file_.startswith(s) for s in start_matches) \
86 | and not any(file_.startswith(p) for p in protected):
87 | os.remove(os.path.join(folder, file_))
88 | def is_str(var): # copy from utils to avoid relative import
89 | """`bytes` (in Python 3) also fit the bill"""
90 | if PY2:
91 | types_ = (str, unicode)
92 | else:
93 | types_ = (str, bytes)
94 | return any(isinstance(var, type_) for type_ in types_)
95 |
96 | def various_doctests():
97 | """various doc tests.
98 |
99 | This function describes test cases and might in the future become
100 | helpful as an experimental tutorial as well. The main testing feature
101 | at the moment is by doctest with ``cma.test.main()`` in a Python shell
102 | or by ``python -m cma.test`` in a system shell.
103 |
104 | A simple first overall test:
105 |
106 | >>> import cma
107 | >>> res = cma.fmin(cma.ff.elli, 3*[1], 1,
108 | ... {'CMA_diagonal':2, 'seed':1, 'verbose':-9})
109 | >>> assert res[1] < 1e-6
110 | >>> assert res[2] < 2000
111 |
112 | Testing `args` argument:
113 |
114 | >>> def maxcorr(m):
115 | ... val = 0
116 | ... for i in range(len(m)):
117 | ... for j in range(i + 1, len(m)):
118 | ... val = max((val, abs(m[i, j])))
119 | ... return val
120 | >>> x, es = cma.fmin2(cma.ff.elli, [1, 0, 0], 0.5, {'verbose':-9}, args=[True]) # rotated
121 | >>> assert maxcorr(es.sm.correlation_matrix) > 0.9, es.sm.correlation_matrix
122 | >>> es = cma.CMAEvolutionStrategy([1, 0, 0], 0.5,
123 | ... {'verbose':-9}).optimize(cma.ff.elli, args=[1])
124 | >>> assert maxcorr(es.sm.correlation_matrix) > 0.9, es.sm.correlation_matrix
125 |
126 | Testing output file consistency with diagonal option:
127 |
128 | >>> import cma
129 | >>> for val in (0, True, 2, 3):
130 | ... _ = cma.fmin(cma.ff.sphere, 3 * [1], 1,
131 | ... {'verb_disp':0, 'CMA_diagonal':val, 'maxiter':5})
132 | ... _ = cma.CMADataLogger().load()
133 |
134 | Test on the Rosenbrock function with 3 restarts. The first trial only
135 | finds the local optimum, which happens in about 20% of the cases.
136 |
137 | >>> import cma
138 | >>> res = cma.fmin(cma.ff.rosen, 4 * [-1], 0.01,
139 | ... options={'ftarget':1e-6,
140 | ... 'verb_time':0, 'verb_disp':500,
141 | ... 'seed':3},
142 | ... restarts=3)
143 | ... # doctest: +ELLIPSIS
144 | (4_w,8)-aCMA-ES (mu_w=2.6,w_1=52%) in dimension 4 (seed=3,...)
145 | Iterat #Fevals ...
146 | >>> assert res[1] <= 1e-6
147 |
148 | Notice the different termination conditions. Termination on the target
149 | function value ftarget prevents further restarts.
150 |
151 | Test of scaling_of_variables option
152 |
153 | >>> import cma
154 | >>> opts = cma.CMAOptions()
155 | >>> opts['seed'] = 4567
156 | >>> opts['verb_disp'] = 0
157 | >>> opts['CMA_const_trace'] = True
158 | >>> # rescaling of third variable: for searching in roughly
159 | >>> # x0 plus/minus 1e3*sigma0 (instead of plus/minus sigma0)
160 | >>> opts['scaling_of_variables'] = [1, 1, 1e3, 1]
161 | >>> res = cma.fmin(cma.ff.rosen, 4 * [0.1], 0.1, opts)
162 | >>> assert res[1] < 1e-9
163 | >>> es = res[-2]
164 | >>> es.result_pretty() # doctest: +ELLIPSIS
165 | termination on tolfun=1e-11
166 | final/bestever f-value = ...
167 |
168 | The printed std deviations reflect the actual value in the
169 | parameters of the function (not the one in the internal
170 | representation which can be different).
171 |
172 | Test of CMA_stds scaling option.
173 |
174 | >>> import cma
175 | >>> opts = cma.CMAOptions()
176 | >>> s = 5 * [1]
177 | >>> s[0] = 1e3
178 | >>> opts.set('CMA_stds', s) #doctest: +ELLIPSIS
179 | {'...
180 | >>> opts.set('verb_disp', 0) #doctest: +ELLIPSIS
181 | {'...
182 | >>> res = cma.fmin(cma.ff.cigar, 5 * [0.1], 0.1, opts)
183 | >>> assert res[1] < 1800
184 |
185 | Testing combination of ``fixed_variables`` and ``CMA_stds`` options.
186 |
187 | >>> import cma
188 | >>> options = {
189 | ... 'fixed_variables':{1:2.345},
190 | ... 'CMA_stds': 4 * [1],
191 | ... 'minstd': 3 * [1]}
192 | >>> es = cma.CMAEvolutionStrategy(4 * [1], 1, options) #doctest: +ELLIPSIS
193 | (3_w,7)-aCMA-ES (mu_w=2.3,w_1=58%) in dimension 3 (seed=...
194 |
195 | Test of elitism:
196 |
197 | >>> import cma
198 | >>> res = cma.fmin(cma.ff.rastrigin, 10 * [0.1], 2,
199 | ... {'CMA_elitist':'initial', 'ftarget':1e-3, 'verbose':-9})
200 | >>> assert 'ftarget' in res[7]
201 |
202 | Test CMA_on option and similar:
203 |
204 | >>> import cma
205 | >>> res = cma.fmin(cma.ff.sphere, 4 * [1], 2,
206 | ... {'CMA_on':False, 'ftarget':1e-8, 'verbose':-9})
207 | >>> assert 'ftarget' in res[7] and res[2] < 1e3
208 | >>> res = cma.fmin(cma.ff.sphere, 3 * [1], 2,
209 | ... {'CMA_rankone':0, 'CMA_rankmu':0, 'ftarget':1e-8,
210 | ... 'verbose':-9})
211 | >>> assert 'ftarget' in res[7] and res[2] < 1e3
212 | >>> res = cma.fmin(cma.ff.sphere, 2 * [1], 2,
213 | ... {'CMA_rankone':0, 'ftarget':1e-8, 'verbose':-9})
214 | >>> assert 'ftarget' in res[7] and res[2] < 1e3
215 | >>> res = cma.fmin(cma.ff.sphere, 2 * [1], 2,
216 | ... {'CMA_rankmu':0, 'ftarget':1e-8, 'verbose':-9})
217 | >>> assert 'ftarget' in res[7] and res[2] < 1e3
218 |
219 | Check rotational invariance:
220 |
221 | >>> import cma
222 | >>> felli = cma.s.ft.Shifted(cma.ff.elli)
223 | >>> frot = cma.s.ft.Rotated(felli)
224 | >>> res_elli = cma.CMAEvolutionStrategy(3 * [1], 1,
225 | ... {'ftarget': 1e-8}).optimize(felli).result
226 | ... #doctest: +ELLIPSIS
227 | (3_w,7)-...
228 | >>> res_rot = cma.CMAEvolutionStrategy(3 * [1], 1,
229 | ... {'ftarget': 1e-8}).optimize(frot).result
230 | ... #doctest: +ELLIPSIS
231 | (3_w,7)-...
232 | >>> assert res_rot[3] < 2 * res_elli[3]
233 |
234 | Both condition alleviation transformations are applied during this
235 | test, first in iteration 62, second in iteration 257:
236 |
237 | >>> import cma
238 | >>> ftabletrot = cma.fitness_transformations.Rotated(cma.ff.tablet, seed=10)
239 | >>> es = cma.CMAEvolutionStrategy(4 * [1], 1, {
240 | ... 'tolconditioncov':False,
241 | ... 'seed': 8,
242 | ... 'CMA_mirrors': 0,
243 | ... 'ftarget': 1e-9,
244 | ... }) # doctest:+ELLIPSIS
245 | (4_w...
246 | >>> while not es.stop() and es.countiter < 82:
247 | ... X = es.ask()
248 | ... es.tell(X, [cma.ff.elli(x, cond=1e22) for x in X]) # doctest:+ELLIPSIS
249 | NOTE ...iteration=81...
250 | >>> while not es.stop():
251 | ... X = es.ask()
252 | ... es.tell(X, [ftabletrot(x) for x in X]) # doctest:+ELLIPSIS
253 | >>> assert es.countiter <= 344 and 'ftarget' in es.stop(), (
254 | ... "transformation bug in alleviate_condition?",
255 | ... es.countiter, es.stop())
256 |
257 | Integer handling:
258 |
259 | >>> import warnings
260 | >>> idx = [0, 1, 5, -1]
261 | >>> f = cma.s.ft.IntegerMixedFunction(cma.ff.elli, idx)
262 | >>> with warnings.catch_warnings(record=True) as warns:
263 | ... es = cma.CMAEvolutionStrategy(4 * [5], 10, dict(
264 | ... ftarget=1e-9, seed=5,
265 | ... integer_variables=idx
266 | ... )) # doctest:+ELLIPSIS
267 | (4_w,8)-...
268 | >>> warns[0].message # doctest:+ELLIPSIS
269 | UserWarning('integer index 5 not in range of dimension 4 ()'...
270 | >>> es.optimize(f) # doctest:+ELLIPSIS
271 | Iterat #Fevals function value ...
272 | >>> assert 'ftarget' in es.stop() and es.result[3] < 1800
273 |
274 | Parallel objective:
275 |
276 | >>> def parallel_sphere(X): return [cma.ff.sphere(x) for x in X]
277 | >>> x, es = cma.fmin2(cma.ff.sphere, 3 * [0], 0.1, {
278 | ... 'verbose': -9, 'eval_final_mean': True, 'CMA_elitist': 'initial'},
279 | ... parallel_objective=parallel_sphere)
280 | >>> assert es.result[1] < 1e-9
281 | >>> x, es = cma.fmin2(None, 3 * [0], 0.1, {
282 | ... 'verbose': -9, 'eval_final_mean': True, 'CMA_elitist': 'initial'},
283 | ... parallel_objective=parallel_sphere)
284 | >>> assert es.result[1] < 1e-9
285 |
286 | Some sort of interactive control via an options file:
287 |
288 | >>> es = cma.CMAEvolutionStrategy(4 * [2], 1, dict(
289 | ... signals_filename='cma_signals.in',
290 | ... verbose=-9))
291 | >>> s = es.stop()
292 | >>> es = es.optimize(cma.ff.sphere)
293 |
294 | Test of huge lambda:
295 |
296 | >>> es = cma.CMAEvolutionStrategy(3 * [0.91], 1, {
297 | ... 'verbose': -9,
298 | ... 'popsize': 200,
299 | ... 'ftarget': 1e-8 })
300 | >>> es = es.optimize(cma.ff.tablet)
301 | >>> if es.result.evaluations > 5000: print(es.result.evaluations, es.result)
302 |
303 | For VD- and VkD-CMA, see `cma.restricted_gaussian_sampler`.
304 |
305 | >>> import sys
306 | >>> import cma
307 | >>> assert cma.interfaces.EvalParallel2 is not None
308 | >>> try:
309 | ... with warnings.catch_warnings(record=True) as warn:
310 | ... with cma.optimization_tools.EvalParallel2(cma.ff.elli) as eval_all:
311 | ... res = eval_all([[1,2], [3,4]])
312 | ... except:
313 | ... assert sys.version[0] == '2'
314 |
315 | """
316 |
317 | def doctest_files(file_list=files_for_doctest, **kwargs):
318 | """doctest all (listed) files of the `cma` package.
319 |
320 | Details: accepts ``verbose`` and all other keyword arguments that
321 | `doctest.testfile` would accept, while negative ``verbose`` values
322 | are passed as 0.
323 | """
324 | # print("__name__ is", __name__, sys.modules[__name__])
325 | # print(__package__)
326 | if not isinstance(file_list, list) and is_str(file_list):
327 | file_list = [file_list]
328 | verbosity_here = kwargs.get('verbose', 0)
329 | if verbosity_here < 0:
330 | kwargs['verbose'] = 0
331 | failures = 0
332 | for file_ in file_list:
333 | file_ = file_.strip().strip(os.path.sep)
334 | if file_.startswith('cma' + os.path.sep):
335 | file_ = file_[4:]
336 | if verbosity_here >= 0:
337 | print('doctesting %s ...' % file_,
338 | ' ' * (max(len(_file) for _file in file_list) -
339 | len(file_)),
340 | end="") # does not work in Python 2.5
341 | sys.stdout.flush()
342 | protected_files = os.listdir('.')
343 | report = doctest.testfile(file_,
344 | package=__package__, # 'cma', # sys.modules[__name__],
345 | **kwargs)
346 | _clean_up('.', _files_written, protected_files)
347 | failures += report[0]
348 | if verbosity_here >= 0:
349 | print(report)
350 | return failures
351 |
352 | def get_version():
353 | try:
354 | with open(__file__[:-7] + '__init__.py', 'r') as f:
355 | for line in f.readlines():
356 | if line.startswith('__version__'):
357 | return line[15:].split()[0]
358 | except Exception: # e.g. file missing or unreadable
359 | return ""
362 |
363 | def main(*args, **kwargs):
364 | """test the `cma` package.
365 |
366 | The first argument can be '-h' or '--help' to print the module docstring,
367 | or 'list' to list all files to be tested. Otherwise, arguments can be file(name)s to be
368 | tested, where names are interpreted relative to the package root
369 | and a leading 'cma' + path separator is ignored.
370 |
371 | By default all files are tested.
372 |
373 | :See also: ``python -c "import cma.test; help(cma.test)"``
374 | """
375 | if len(args) > 0:
376 | if args[0].startswith(('-h', '--h')):
377 | print(__doc__)
378 | exit(0)
379 | elif args[0].startswith('list'):
380 | for file_ in files_for_doctest:
381 | print(file_)
382 | exit(0)
383 | else:
384 | v = get_version()
385 | print("doctesting `cma` package%s by calling `doctest_files`:"
386 | % ((" (v%s)" % v) if v else ""))
387 | return doctest_files(args if args else files_for_doctest, **kwargs)
388 |
389 | if __name__ == "__main__":
390 | exit(main(*sys.argv[1:]) > 0) # 0 if failures == 0 else 1
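`doctest_files` above essentially loops over `doctest.testfile` and sums the failure counts it reports. A minimal standalone sketch of that loop, assuming plain file-system paths (hence `module_relative=False`, unlike the package-relative lookup via `package=__package__` used above); the helper name `run_doctest_files` is hypothetical:

```python
import doctest
import os
import tempfile


def run_doctest_files(paths, verbose=False):
    """Run doctest.testfile on each path and return the summed failure count."""
    failures = 0
    for path in paths:
        # module_relative=False: treat `path` as an OS path rather than
        # a path relative to the calling module or package
        results = doctest.testfile(path, module_relative=False, verbose=verbose)
        failures += results.failed  # TestResults(failed, attempted, ...)
    return failures


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        good = os.path.join(tmp, "good.txt")
        with open(good, "w") as f:
            f.write(">>> 1 + 1\n2\n")
        print(run_doctest_files([good]))  # no failures expected
```

As in `doctest_files`, a nonzero return value can be turned into a process exit status, e.g. `exit(run_doctest_files(paths) > 0)`.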
--------------------------------------------------------------------------------
/cma/utilities/__init__.py:
--------------------------------------------------------------------------------
1 | """various unspecific utilities"""
--------------------------------------------------------------------------------
/cma/utilities/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/utilities/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/cma/utilities/__pycache__/math.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/utilities/__pycache__/math.cpython-39.pyc
--------------------------------------------------------------------------------
/cma/utilities/__pycache__/python3for2.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/utilities/__pycache__/python3for2.cpython-39.pyc
--------------------------------------------------------------------------------
/cma/utilities/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/cma/utilities/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/cma/utilities/math.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """ various math utilities, notably `eig` and a collection of simple
3 | functions in `Mh`
4 | """
5 | from __future__ import absolute_import, division, print_function #, unicode_literals
6 | # from future.builtins.disabled import * # don't use any function which could lead to different results in Python 2 vs 3
7 | import warnings as _warnings
8 | import numpy as np
9 | from .python3for2 import range
10 | del absolute_import, division, print_function #, unicode_literals
11 |
12 | def _sqrt_len(x): # makes randhss option pickable
13 | return len(x)**0.5
14 |
15 | def randhss(n, dim, norm_=_sqrt_len, randn=np.random.randn):
16 | """`n` iid `dim`-dimensional vectors with length ``norm_(vector)``.
17 |
18 | The vectors are uniformly distributed on a hypersphere surface.
19 |
20 | CMA-ES diverges with popsize 100 in 15-D without option
21 | 'CSA_clip_length_value': [0,0].
22 |
23 | >>> from cma.utilities.math import randhss
24 | >>> dim = 3
25 | >>> assert dim - 1e-7 < sum(randhss(1, dim)[0]**2) < dim + 1e-7
26 |
27 | """
28 | arv = randn(n, dim)
29 | for v in arv:
30 | v *= norm_(v) / np.sum(v**2)**0.5
31 | return arv
32 |
33 | def randhss_mixin(n, dim, norm_=_sqrt_len,
34 | c=lambda d: 1. / d, randn=np.random.randn):
35 | """`n` iid vectors uniformly distributed on the hypersphere surface with
36 | a normally distributed component mixed in, which can be beneficial in
37 | smaller dimension.
38 | """
39 | arv = randhss(n, dim, norm_, randn)
40 | c = min((1, c(dim)))
41 | if c > 0:
42 | if c > 1: # can never happen
43 | raise ValueError("c(dim)=%f should be <=1" % c)
44 | for v in arv:
45 | v *= (1 - c**2)**0.5 # has 2 / c longer time horizon than 1 - c
46 | v += c * randn(1, dim)[0] # c is sqrt(2/c) times smaller than sqrt(c * (2 - c))
47 | return arv
48 |
49 | def to_correlation_matrix(c):
50 | """scale C in place to a correlation matrix (unit diagonal) and return a symmetrized copy"""
51 | for i in range(c.shape[0]):
52 | fac = c[i, i]**0.5
53 | c[:, i] /= fac
54 | c[i, :] /= fac
55 | c = (c + c.T) / 2.0
56 | assert np.allclose(np.diag(c), 1)
57 | return c
58 |
59 | _warnings.filterwarnings( # one message for each value (which is given in the message)
60 | 'once', message="using exponential smoothing with .* rolling average")
61 | def moving_average(x, w=7):
62 | """rolling average without biasing boundary effects.
63 |
64 | The first entries give the average over all first
65 | values (until the window width is reached).
66 |
67 | If `w` is not an integer, exponential smoothing with weights
68 | proportional to ``(1 - 1/w)**i`` and summing to one is applied, thereby
69 | putting about 1 - exp(-1) ≈ 0.63 of the weight sum on the last `w`
70 | entries.
71 |
72 | Details: the average is mainly based on `np.convolve`, whereas
73 | exponential smoothing is for the time being numerically inefficient and
74 | scales quadratically with the length of `x`.
75 | """
76 | if w == 1:
77 | return x
78 | elif isinstance(w, int): # interpret as window width
79 | w = min((w, len(x))) # window width
80 | return np.hstack([[np.mean(x[:i]) for i in range(1, w)],
81 | np.convolve(x, w * [1 / w], mode='valid')])
82 | else: # exponential smoothing
83 | if w == int(w):
84 | _warnings.warn("using exponential smoothing with time"
85 | " horizon {}. \nUse `int` type to get the"
86 | " rolling average.".format(w))
87 | v = 1 - 1 / w
88 | return np.asarray([sum([v**j * x[i-j] for j in range(i + 1)])
89 | / sum([v**j for j in range(i + 1)])
90 | for i in range(len(x))])
91 |
92 | def Hessian(f, x0, eps=1e-6):
93 | """forward finite-difference Hessian estimate for `f` at `x0` with step `eps`"""
94 | if eps is None:
95 | eps = 1e-6 # restore default
96 | x0 = np.asarray(x0)
97 | e = np.eye(len(x0))
98 | H = 0 * e
99 | for i in range(len(x0)):
100 | ei = eps * e[i]
101 | for j in range(i+1):
102 | ej = eps * e[j]
103 | H[i,j] = (f(x0 + ei + ej) - f(x0 + ei) - f(x0 + ej) + f(x0)) / eps**2
104 | H[j, i] = H[i, j]
105 | return H
106 |
107 | def geometric_sd(vals, **kwargs):
108 | """return geometric standard deviation of `vals`.
109 |
110 | The gsd is invariant under linear scaling and independent
111 | of the choice of the log-exp base.
112 |
113 | ``kwargs`` are passed to `np.std`, in particular `ddof`.
114 | """
115 | return np.exp(np.std(np.log(vals), **kwargs))
116 |
117 | # ____________________________________________________________
118 | # ____________________________________________________________
119 | #
120 | # C and B are arrays rather than matrices, because they are
121 | # addressed via B[i][j], matrices can only be addressed via B[i,j]
122 |
123 | # tred2(N, B, diagD, offdiag);
124 | # tql2(N, diagD, offdiag, B);
125 |
126 |
127 | # Symmetric Householder reduction to tridiagonal form, translated from JAMA package.
128 | def eig(C):
129 | """eigendecomposition of a symmetric matrix, much slower than
130 | `numpy.linalg.eigh`, return ``(EVals, Basis)``, the eigenvalues
131 | and an orthonormal basis of the corresponding eigenvectors, where
132 |
133 | ``Basis[i]``
134 | the i-th row of ``Basis``
135 | columns of ``Basis``, ``[Basis[j][i] for j in range(len(Basis))]``
136 | the i-th eigenvector with eigenvalue ``EVals[i]``
137 |
138 | """
139 |
140 | # class eig(object):
141 | # def __call__(self, C):
142 |
143 | # Householder transformation of a symmetric matrix V into tridiagonal form.
144 | # -> n : dimension
145 | # -> V : symmetric nxn-matrix
146 | # <- V : orthogonal transformation matrix:
147 | # tridiag matrix == V * V_in * V^t
148 | # <- d : diagonal
149 | # <- e[0..n-1] : off diagonal (elements 1..n-1)
150 |
151 | # Symmetric tridiagonal QL algorithm, iterative
152 | # Computes the eigensystem from a tridiagonal matrix in roughly 3N^3 operations
153 | # -> n : Dimension.
154 | # -> d : Diagonal of tridiagonal matrix.
155 | # -> e[1..n-1] : off-diagonal, output from Householder
156 | # -> V : matrix output from Householder
157 | # <- d : eigenvalues
158 | # <- e : garbage?
159 | # <- V : basis of eigenvectors, according to d
160 |
161 |
162 | # tred2(N, B, diagD, offdiag); B=C on input
163 | # tql2(N, diagD, offdiag, B);
164 |
165 | # private void tred2 (int n, double V[][], double d[], double e[]) {
166 | def tred2 (n, V, d, e):
167 | # This is derived from the Algol procedures tred2 by
168 | # Bowdler, Martin, Reinsch, and Wilkinson, Handbook for
169 | # Auto. Comp., Vol.ii-Linear Algebra, and the corresponding
170 | # Fortran subroutine in EISPACK.
171 |
172 | num_opt = False # factor 1.5 in 30-D
173 |
174 | for j in range(n):
175 | d[j] = V[n - 1][j] # d is output argument
176 |
177 | # Householder reduction to tridiagonal form.
178 |
179 | for i in range(n - 1, 0, -1):
180 | # Scale to avoid under/overflow.
181 | h = 0.0
182 | if not num_opt:
183 | scale = 0.0
184 | for k in range(i):
185 | scale = scale + abs(d[k])
186 | else:
187 | scale = sum(abs(d[0:i]))
188 |
189 | if scale == 0.0:
190 | e[i] = d[i - 1]
191 | for j in range(i):
192 | d[j] = V[i - 1][j]
193 | V[i][j] = 0.0
194 | V[j][i] = 0.0
195 | else:
196 |
197 | # Generate Householder vector.
198 | if not num_opt:
199 | for k in range(i):
200 | d[k] /= scale
201 | h += d[k] * d[k]
202 | else:
203 | d[:i] /= scale
204 | h = np.dot(d[:i], d[:i])
205 |
206 | f = d[i - 1]
207 | g = h**0.5
208 |
209 | if f > 0:
210 | g = -g
211 |
212 | e[i] = scale * g
213 | h = h - f * g
214 | d[i - 1] = f - g
215 | if not num_opt:
216 | for j in range(i):
217 | e[j] = 0.0
218 | else:
219 | e[:i] = 0.0
220 |
221 | # Apply similarity transformation to remaining columns.
222 |
223 | for j in range(i):
224 | f = d[j]
225 | V[j][i] = f
226 | g = e[j] + V[j][j] * f
227 | if not num_opt:
228 | for k in range(j + 1, i):
229 | g += V[k][j] * d[k]
230 | e[k] += V[k][j] * f
231 | e[j] = g
232 | else:
233 | e[j + 1:i] += V.T[j][j + 1:i] * f
234 | e[j] = g + np.dot(V.T[j][j + 1:i], d[j + 1:i])
235 |
236 | f = 0.0
237 | if not num_opt:
238 | for j in range(i):
239 | e[j] /= h
240 | f += e[j] * d[j]
241 | else:
242 | e[:i] /= h
243 | f += np.dot(e[:i], d[:i])
244 |
245 | hh = f / (h + h)
246 | if not num_opt:
247 | for j in range(i):
248 | e[j] -= hh * d[j]
249 | else:
250 | e[:i] -= hh * d[:i]
251 |
252 | for j in range(i):
253 | f = d[j]
254 | g = e[j]
255 | if not num_opt:
256 | for k in range(j, i):
257 | V[k][j] -= (f * e[k] + g * d[k])
258 | else:
259 | V.T[j][j:i] -= (f * e[j:i] + g * d[j:i])
260 |
261 | d[j] = V[i - 1][j]
262 | V[i][j] = 0.0
263 |
264 | d[i] = h
265 | # end for i--
266 |
267 | # Accumulate transformations.
268 |
269 | for i in range(n - 1):
270 | V[n - 1][i] = V[i][i]
271 | V[i][i] = 1.0
272 | h = d[i + 1]
273 | if h != 0.0:
274 | if not num_opt:
275 | for k in range(i + 1):
276 | d[k] = V[k][i + 1] / h
277 | else:
278 | d[:i + 1] = V.T[i + 1][:i + 1] / h
279 |
280 | for j in range(i + 1):
281 | if not num_opt:
282 | g = 0.0
283 | for k in range(i + 1):
284 | g += V[k][i + 1] * V[k][j]
285 | for k in range(i + 1):
286 | V[k][j] -= g * d[k]
287 | else:
288 | g = np.dot(V.T[i + 1][0:i + 1], V.T[j][0:i + 1])
289 | V.T[j][:i + 1] -= g * d[:i + 1]
290 |
291 | if not num_opt:
292 | for k in range(i + 1):
293 | V[k][i + 1] = 0.0
294 | else:
295 | V.T[i + 1][:i + 1] = 0.0
296 |
297 |
298 | if not num_opt:
299 | for j in range(n):
300 | d[j] = V[n - 1][j]
301 | V[n - 1][j] = 0.0
302 | else:
303 | d[:n] = V[n - 1][:n]
304 | V[n - 1][:n] = 0.0
305 |
306 | V[n - 1][n - 1] = 1.0
307 | e[0] = 0.0
308 |
309 |
310 | # Symmetric tridiagonal QL algorithm, taken from JAMA package.
311 | # private void tql2 (int n, double d[], double e[], double V[][]) {
312 | # needs roughly 3N^3 operations
313 | def tql2 (n, d, e, V):
314 |
315 | # This is derived from the Algol procedures tql2, by
316 | # Bowdler, Martin, Reinsch, and Wilkinson, Handbook for
317 | # Auto. Comp., Vol.ii-Linear Algebra, and the corresponding
318 | # Fortran subroutine in EISPACK.
319 |
320 | num_opt = False # using vectors from numpy makes it faster
321 |
322 | if not num_opt:
323 | for i in range(1, n): # (int i = 1; i < n; i++):
324 | e[i - 1] = e[i]
325 | else:
326 | e[0:n - 1] = e[1:n]
327 | e[n - 1] = 0.0
328 |
329 | f = 0.0
330 | tst1 = 0.0
331 | eps = 2.0**-52.0
332 | for l in range(n): # (int l = 0; l < n; l++) {
333 |
334 | # Find small subdiagonal element
335 |
336 | tst1 = max(tst1, abs(d[l]) + abs(e[l]))
337 | m = l
338 | while m < n:
339 | if abs(e[m]) <= eps * tst1:
340 | break
341 | m += 1
342 |
343 | # If m == l, d[l] is an eigenvalue,
344 | # otherwise, iterate.
345 |
346 | if m > l:
347 | iiter = 0
348 | while 1: # do {
349 | iiter += 1 # (Could check iteration count here.)
350 |
351 | # Compute implicit shift
352 |
353 | g = d[l]
354 | p = (d[l + 1] - g) / (2.0 * e[l])
355 | r = (p**2 + 1)**0.5 # hypot(p,1.0)
356 | if p < 0:
357 | r = -r
358 |
359 | d[l] = e[l] / (p + r)
360 | d[l + 1] = e[l] * (p + r)
361 | dl1 = d[l + 1]
362 | h = g - d[l]
363 | if not num_opt:
364 | for i in range(l + 2, n):
365 | d[i] -= h
366 | else:
367 | d[l + 2:n] -= h
368 |
369 | f = f + h
370 |
371 | # Implicit QL transformation.
372 |
373 | p = d[m]
374 | c = 1.0
375 | c2 = c
376 | c3 = c
377 | el1 = e[l + 1]
378 | s = 0.0
379 | s2 = 0.0
380 |
381 | # hh = V.T[0].copy() # only with num_opt
382 | for i in range(m - 1, l - 1, -1): # (int i = m-1; i >= l; i--) {
383 | c3 = c2
384 | c2 = c
385 | s2 = s
386 | g = c * e[i]
387 | h = c * p
388 | r = (p**2 + e[i]**2)**0.5 # hypot(p,e[i])
389 | e[i + 1] = s * r
390 | s = e[i] / r
391 | c = p / r
392 | p = c * d[i] - s * g
393 | d[i + 1] = h + s * (c * g + s * d[i])
394 |
395 | # Accumulate transformation.
396 |
397 | if not num_opt: # overall factor 3 in 30-D
398 | for k in range(n): # (int k = 0; k < n; k++) {
399 | h = V[k][i + 1]
400 | V[k][i + 1] = s * V[k][i] + c * h
401 | V[k][i] = c * V[k][i] - s * h
402 | else: # about 20% faster in 10-D
403 | hh = V.T[i + 1].copy()
404 | # hh[:] = V.T[i+1][:]
405 | V.T[i + 1] = s * V.T[i] + c * hh
406 | V.T[i] = c * V.T[i] - s * hh
407 | # V.T[i] *= c
408 | # V.T[i] -= s * hh
409 |
410 | p = -s * s2 * c3 * el1 * e[l] / dl1
411 | e[l] = s * p
412 | d[l] = c * p
413 |
414 | # Check for convergence.
415 | if abs(e[l]) <= eps * tst1:
416 | break
417 | # } while (Math.abs(e[l]) > eps*tst1);
418 |
419 | d[l] = d[l] + f
420 | e[l] = 0.0
421 |
422 |
423 | # Sort eigenvalues and corresponding vectors.
424 | if 11 < 3:
425 | for i in range(n - 1): # (int i = 0; i < n-1; i++) {
426 | k = i
427 | p = d[i]
428 | for j in range(i + 1, n): # (int j = i+1; j < n; j++) {
429 | if d[j] < p: # NH find smallest k>i
430 | k = j
431 | p = d[j]
432 |
433 | if k != i:
434 | d[k] = d[i] # swap k and i
435 | d[i] = p
436 | for j in range(n): # (int j = 0; j < n; j++) {
437 | p = V[j][i]
438 | V[j][i] = V[j][k]
439 | V[j][k] = p
440 | # tql2
441 |
442 | N = len(C[0])
443 | if 11 < 3:
444 | V = np.array([x[:] for x in C]) # copy each "row"
445 | N = V[0].size
446 | d = np.zeros(N)
447 | e = np.zeros(N)
448 | else:
449 | V = [[x[i] for i in range(N)] for x in C] # copy each "row"
450 | d = N * [0.]
451 | e = N * [0.]
452 |
453 | tred2(N, V, d, e)
454 | tql2(N, d, e, V)
455 | return np.array(d), np.array(V)
456 |
457 | class MathHelperFunctions(object):
458 | """static convenience math helper functions; if the function name
459 | is prefixed with an "a", a numpy array is returned
460 |
461 | TODO: there is probably no good reason why this should be a class and not a
462 | module.
463 |
464 | """
465 | @staticmethod
466 | def aclamp(x, upper):
467 | return -MathHelperFunctions.apos(-x, -upper)
468 | @staticmethod
469 | def equals_approximately(a, b, eps=1e-12):
470 | if a < 0:
471 | a, b = -1 * a, -1 * b
472 | return (a - eps < b < a + eps) or ((1 - eps) * a < b < (1 + eps) * a)
473 | @staticmethod
474 | def vequals_approximately(a, b, eps=1e-12):
475 | a, b = np.array(a), np.array(b)
476 | idx = np.nonzero(a < 0)[0] # find
477 | if len(idx):
478 | a[idx], b[idx] = -1 * a[idx], -1 * b[idx]
479 | return (np.all(a - eps < b) and np.all(b < a + eps)
480 | ) or (np.all((1 - eps) * a < b) and np.all(b < (1 + eps) * a))
481 | @staticmethod
482 | def expms(A, eig=np.linalg.eigh):
483 | """matrix exponential for a symmetric matrix"""
484 | # TODO: check that this works reliably for low rank matrices
485 | # TODO: symmetrize A first (currently A is assumed to be symmetric)
486 | D, B = eig(A)
487 | return np.dot(B, (np.exp(D) * B).T)
488 | @staticmethod
489 | def amax(vec, vec_or_scalar):
490 | return np.array(MathHelperFunctions.max(vec, vec_or_scalar))
491 | @staticmethod
492 | def max(vec, vec_or_scalar):
493 | b = vec_or_scalar
494 | if np.isscalar(b):
495 | m = [max(x, b) for x in vec]
496 | else:
497 | m = [max(vec[i], b[i]) for i in range(len((vec)))]
498 | return m
499 | @staticmethod
500 | def minmax(val, min_val, max_val):
501 | assert min_val <= max_val
502 | return min((max_val, max((val, min_val))))
503 | @staticmethod
504 | def aminmax(val, min_val, max_val):
505 | return np.array([min((max_val, max((v, min_val)))) for v in val])
506 | @staticmethod
507 | def amin(vec_or_scalar, vec_or_scalar2):
508 | return np.array(MathHelperFunctions.min(vec_or_scalar, vec_or_scalar2))
509 | @staticmethod
510 | def min(a, b):
511 | iss = np.isscalar
512 | if iss(a) and iss(b):
513 | return min(a, b)
514 | if iss(a):
515 | a, b = b, a
516 | # now only b can be still a scalar
517 | if iss(b):
518 | return [min(x, b) for x in a]
519 | else: # two non-scalars must have the same length
520 | return [min(a[i], b[i]) for i in range(len((a)))]
521 | @staticmethod
522 | def norm(vec, expo=2):
523 | return sum(np.abs(vec)**expo)**(1 / expo)
524 | @staticmethod
525 | def apos(x, lower=0):
526 | """clips argument (scalar or array) from below at lower"""
527 | if lower == 0:
528 | return (x > 0) * x
529 | else:
530 | return lower + (x > lower) * (x - lower)
531 |
532 | @staticmethod
533 | def apenalty_quadlin(x, lower=0, upper=None):
534 | """Huber-like smooth penalty which starts at `lower`.
535 |
536 | The penalty is zero below lower and affine linear above upper.
537 |
538 | Return::
539 |
540 | 0, if x <= lower
541 | quadratic in x, if lower <= x <= upper
542 | affine linear in x with slope upper - lower, if x >= upper
543 |
544 | `upper` defaults to ``lower + 1``.
545 |
546 | """
547 | if upper is None:
548 | upper = np.asarray(lower) + 1
549 | z = np.asarray(x) - lower
550 | del x # assert that x is not used anymore accidentally
551 | u = np.asarray(upper) - lower
552 | return (z > 0) * ((z <= u) * (z ** 2 / 2) + (z > u) * u * (z - u / 2))
553 |
554 | @staticmethod
555 | def prctile(data, p_vals=[0, 25, 50, 75, 100], sorted_=False):
556 | """``prctile(data, 50)`` returns the median, but p_vals can
557 | also be a sequence.
558 |
559 | Provides for small samples or extremes IMHO better values than
560 | matplotlib.mlab.prctile or np.percentile, however also slower.
561 |
562 | """
563 | ps = [p_vals] if np.isscalar(p_vals) else p_vals
564 |
565 | if not sorted_:
566 | data = sorted(data)
567 | n = len(data)
568 | d = []
569 | for p in ps:
570 | fi = p * n / 100 - 0.5
571 | if fi <= 0: # maybe extrapolate?
572 | d.append(data[0])
573 | elif fi >= n - 1:
574 | d.append(data[-1])
575 | else:
576 | i = int(fi)
577 | d.append((i + 1 - fi) * data[i] + (fi - i) * data[i + 1])
578 | return d[0] if np.isscalar(p_vals) else d
579 | @staticmethod
580 | def iqr(data, percentile_function=np.percentile): # MathHelperFunctions.prctile
581 | """interquartile range"""
582 | q25, q75 = percentile_function(data, [25, 75])
583 | return np.asarray(q75) - np.asarray(q25)
584 | @staticmethod
585 | def interdecilerange(data, percentile_function=np.percentile):
586 | """return 10% to 90% range width"""
587 | q10, q90 = percentile_function(data, [10, 90])
588 | return np.asarray(q90) - np.asarray(q10)
589 | @staticmethod
590 | def logit10(x, lower=0, upper=1):
591 | """map [lower, upper] -> R such that
592 |
593 | ::
594 |
595 | upper - 10^-x -> x, and
596 | lower + 10^-x -> -x
597 |
598 | for large enough x. By default, simplifies close to `log10(x / (1 - x))`.
599 |
600 | >>> from cma.utilities.math import Mh
601 | >>> l, u = -1, 2
602 | >>> print(Mh.logit10([l+0.01, 0.5, u-0.01], l, u))
603 | [-1.9949189 0. 1.9949189]
604 |
605 | """
606 | x = np.asarray(x)
607 | z = (x - lower) / (upper - lower) # between 0 and 1
608 | return np.log10((x - lower)**(1-z) / (upper - x)**z)
609 | # equivalent form: (1 - z) * np.log10(x - lower) - z * np.log10(upper - x)
610 | @staticmethod
611 | def sround(nb): # TODO: to be vectorized
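612 | """return stochastic round: int(nb) + (rand() < remainder(nb))"""
613 | return nb // 1 + (np.random.rand(1)[0] < (nb % 1))
614 | @staticmethod
615 | def cauchy_with_variance_one():
616 | n = np.random.randn() / np.random.randn()
617 | while abs(n) > 1000: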
618 | n = np.random.randn() / np.random.randn()
619 | return n / 25
620 | @staticmethod
621 | def standard_finite_cauchy(size=1):
622 | try:
623 | l = len(size)
624 | except TypeError:
625 | l = 0
626 |
627 | if l == 0:
628 | return np.array([MathHelperFunctions.cauchy_with_variance_one() for _i in range(size)])
629 | elif l == 1:
630 | return np.array([MathHelperFunctions.cauchy_with_variance_one() for _i in range(size[0])])
631 | elif l == 2:
632 | return np.array([[MathHelperFunctions.cauchy_with_variance_one() for _i in range(size[1])]
633 | for _j in range(size[0])])
634 | else:
635 | raise ValueError('len(size) cannot be larger than two')
636 |
637 | Mh = MathHelperFunctions
638 |
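`randhss` relies on the rotational symmetry of the standard normal distribution: normalizing iid Gaussian vectors gives directions uniformly distributed on the sphere, which are then rescaled to length `norm_(v)`, by default sqrt(dim). A vectorized NumPy sketch of the same idea; the name `uniform_on_sphere` and the `rng`/`radius` arguments are illustrative assumptions, not part of `cma`:

```python
import numpy as np


def uniform_on_sphere(n, dim, radius=None, rng=None):
    """Return an (n, dim) array of iid points uniform on a sphere surface.

    radius defaults to sqrt(dim), the length `randhss` uses by default.
    """
    rng = np.random.default_rng() if rng is None else rng
    radius = np.sqrt(dim) if radius is None else radius
    z = rng.standard_normal((n, dim))               # isotropic Gaussian samples
    z /= np.linalg.norm(z, axis=1, keepdims=True)   # project onto the unit sphere
    return radius * z


x = uniform_on_sphere(5, 3, rng=np.random.default_rng(1))
print(np.sum(x**2, axis=1))  # each squared length equals 3 up to rounding
```

Normalizing row-wise with `keepdims=True` avoids the per-vector Python loop of `randhss` while producing the same distribution.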
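The docstring of `moving_average` notes that its exponential-smoothing branch scales quadratically with `len(x)`. The same normalized weights ``(1 - 1/w)**j`` can be accumulated with a linear-time recursion on numerator and denominator. A sketch under that reading of the code; the helper name `exp_smoothing` is hypothetical:

```python
import numpy as np


def exp_smoothing(x, w):
    """Normalized exponential smoothing in O(len(x)).

    Entry i equals sum_j v**j * x[i-j] / sum_j v**j for j = 0..i, with
    v = 1 - 1/w, the quantity the quadratic loop in `moving_average`
    computes for non-integer `w`.
    """
    v = 1 - 1 / w
    out = np.empty(len(x))
    num = den = 0.0
    for i, xi in enumerate(x):
        num = v * num + xi   # discounted sum of values
        den = v * den + 1.0  # discounted sum of weights
        out[i] = num / den
    return out
```

By induction, after step i `num` holds the discounted sum of `x[0..i]` and `den` the matching sum of weights, so each output costs O(1).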
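`Hessian` uses the forward-difference formula H_ij ~ (f(x + eps e_i + eps e_j) - f(x + eps e_i) - f(x + eps e_j) + f(x)) / eps**2, whose truncation error is of order eps; with the default eps = 1e-6, the division by eps**2 also amplifies rounding errors. A central-difference variant is one common alternative. The sketch below is an illustrative swap-in, not what `cma` implements, and its default step 1e-4 is an assumption:

```python
import numpy as np


def hessian_central(f, x0, eps=1e-4):
    """Central finite-difference Hessian estimate of f at x0 (error O(eps**2))."""
    x0 = np.asarray(x0, dtype=float)
    n = len(x0)
    H = np.zeros((n, n))
    for i in range(n):
        ei = np.zeros(n)
        ei[i] = eps
        for j in range(i + 1):
            ej = np.zeros(n)
            ej[j] = eps
            H[i, j] = (f(x0 + ei + ej) - f(x0 + ei - ej)
                       - f(x0 - ei + ej) + f(x0 - ei - ej)) / (4 * eps**2)
            H[j, i] = H[i, j]  # the Hessian is symmetric
    return H
```

For a quadratic `f` the central formula is exact up to rounding, which makes it a convenient sanity check.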
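With z = x - lower and u = upper - lower, `apenalty_quadlin` returns 0 for z <= 0, z**2 / 2 for 0 < z <= u, and u * (z - u / 2) for z > u. Both pieces meet at z = u with value u**2 / 2 and slope u, so the penalty is continuously differentiable, as in the Huber loss. A sketch with `np.where`, intended to match the boolean-mask arithmetic of the original; the name `penalty_quadlin` is hypothetical:

```python
import numpy as np


def penalty_quadlin(x, lower=0.0, upper=None):
    """Zero below lower, quadratic on [lower, upper], linear above with slope upper - lower."""
    if upper is None:
        upper = np.asarray(lower) + 1
    z = np.asarray(x, dtype=float) - lower
    u = np.asarray(upper, dtype=float) - lower
    quad = z**2 / 2         # quadratic piece
    lin = u * (z - u / 2)   # affine piece, tangent to the quadratic at z = u
    return np.where(z <= 0, 0.0, np.where(z <= u, quad, lin))
```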
--------------------------------------------------------------------------------
/cma/utilities/python3for2.py:
--------------------------------------------------------------------------------
1 | """to execute Python 3 code in Python 2.
2 |
3 | redefines builtin `range` and `input` functions and `abc` either via `collections`
4 | or `collections.abc` if available.
5 | """
6 | import sys
7 | import collections as _collections
8 |
9 | range = range # to allow (trivial) explicit import also in Python 3
10 | input = input
11 |
12 | if sys.version[0] == '2': # in python 2
13 | range = xrange # clean way: from builtins import range
14 | input = raw_input # in py2, input(x) == eval(raw_input(x))
15 | abc = _collections # never used
16 |
17 | # only for testing, because `future` may not be installed
18 | # from future.builtins import *
19 | if 11 < 3: # newint does produce an error on some installations
20 | try:
21 | from future.builtins.disabled import * # rather not necessary if tested also in Python 3
22 | from future.builtins import (
23 | bytes, dict, int, list, object, range,
24 | str, ascii, chr, hex, input, next, oct, open,
25 | pow, round, super, filter, map, zip
26 | )
27 | from builtins import ( # not list and object, by default builtins don't exist in Python 2
28 | ascii, bytes, chr, dict, filter, hex, input,
29 | int, map, next, oct, open, pow, range, round,
30 | str, super, zip)
31 | except ImportError:
32 | pass
33 | else:
34 | try:
35 | abc = _collections.abc
36 | except AttributeError:
37 | abc = _collections
38 |
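On Python 3 only the `collections.abc` branch matters: `collections.abc` has existed since Python 3.3, and the ABC aliases directly in `collections` were removed in Python 3.10. A minimal sketch of the same fallback pattern written as a guarded import rather than an attribute lookup; the helper `is_sequence_not_str` is hypothetical and only shows how `abc` would be used:

```python
try:  # Python 3.3+
    import collections.abc as abc
except ImportError:  # very old Pythons: the ABCs live directly in `collections`
    import collections as abc


def is_sequence_not_str(obj):
    """Illustrative helper: True for list/tuple-like objects, False for strings."""
    return isinstance(obj, abc.Sequence) and not isinstance(obj, str)


print(is_sequence_not_str([1, 2]), is_sequence_not_str("ab"))  # True False
```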
--------------------------------------------------------------------------------
/cma/wrapper.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''Interface wrappers for the `cma` module.
3 |
4 | The `SkoptCMAoptimizer` wrapper interfaces an optimizer aligned with
5 | `skopt.optimizer`.
6 | '''
7 | # built-in
8 | import pdb
9 | import copy
10 | import inspect
11 | import tempfile
12 | import os
13 | import warnings
14 |
15 | # external
16 | import numpy as np
17 | import cma # caveat: does not import necessarily the code of this root folder?
18 |
19 | try: import skopt
20 | except ImportError: warnings.warn('install `skopt` ("pip install scikit-optimize") '
21 | 'to use `SkoptCMAoptimizer`')
22 | else:
23 | def SkoptCMAoptimizer(
24 | func, dimensions, n_calls, verbose=False, callback=(), x0=None, n_jobs=1,
25 | sigma0=.5, normalize=True,
26 | ):
27 | '''
28 | Optimizer based on the CMA-ES algorithm.
29 | This is essentially a wrapper function around the `cma` library
30 | that aligns its interface with the skopt library.
31 |
32 | Args:
33 | func (callable): function to optimize
34 | dimensions: list of tuples like ``4 * [(-1., 1.)]`` for defining the domain.
35 | n_calls: the number of samples.
36 | verbose: whether to display progress via ``model.disp()`` after each iteration.
37 | callback: the list of callback functions.
38 | n_jobs: number of cores to run different calls to `func` in parallel.
39 | x0: initial values;
40 | if None, a random point is sampled
41 | sigma0: initial standard deviation relative to domain width
42 | normalize: whether optimization domain should be normalized
43 |
44 | Returns:
45 | `res` skopt.OptimizeResult object
46 | The optimization result returned as a dict object.
47 | Important attributes are:
48 | - `x` [list]: location of the minimum.
49 | - `fun` [float]: function value at the minimum.
50 | - `x_iters` [list of lists]: location of function evaluation for each
51 | iteration.
52 | - `func_vals` [array]: function value for each iteration.
53 | - `space` [skopt.space.Space]: the optimization space.
54 |
55 | Example::
56 |
57 | import cma.wrapper
58 | res = cma.wrapper.SkoptCMAoptimizer(lambda x: sum([xi**2 for xi in x]),
59 | 2 * [(-1.,1.)], 55)
60 | res['cma_es'].logger.plot()
61 |
62 | '''
63 | specs = {
64 | 'args': copy.copy(inspect.currentframe().f_locals),
65 | 'function': inspect.currentframe().f_code.co_name,
66 | }
67 |
68 | if normalize: dimensions = list(map(lambda x: skopt.space.check_dimension(x, 'normalize'), dimensions))
69 | space = skopt.space.Space(dimensions)
70 | if x0 is None: x0 = space.transform(space.rvs())[0]
71 | else: x0 = space.transform([x0])[0]
72 |
73 | tempdir = tempfile.mkdtemp()
74 | xi, yi = [], []
75 | options = {
76 | 'bounds': np.array(space.transformed_bounds).transpose().tolist(),
77 | 'verb_filenameprefix': tempdir,
78 | }
79 |
80 | def delete_tempdir(self, *args, **kargs):
81 | os.removedirs(tempdir)
82 | return
83 |
84 | model = cma.CMAEvolutionStrategy(x0, sigma0, options)
85 | model.logger.__del__ = delete_tempdir
86 | switch = { -1: None, # use number of available CPUs
87 | 1: 0, # avoid using multiprocessor for just one CPU
88 | }
89 | with cma.optimization_tools.EvalParallel2(func,
90 | number_of_processes=switch.get(n_jobs, n_jobs)) as parallel_func:
91 | for _i in range(n_calls):
92 | if model.stop(): break
93 | new_xi = model.ask()
94 | new_xi_denorm = space.inverse_transform(np.array(new_xi))
95 | # new_yi = [func(x) for x in new_xi_denorm]
96 | new_yi = parallel_func(new_xi_denorm)
97 |
98 | model.tell(new_xi, new_yi)
99 | model.logger.add()
100 | if verbose: model.disp()
101 |
102 | xi += new_xi_denorm
103 | yi += new_yi
104 | results = skopt.utils.create_result(xi, yi)
105 | for f in callback: f(results)
106 |
107 | results = skopt.utils.create_result(xi, yi, space)
108 | model.logger.load()
109 | results.cma_es = model
110 | results.cma_logger = model.logger
111 | results.specs = specs
112 | return results
113 |
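`SkoptCMAoptimizer` optimizes in a normalized space (each dimension mapped to [0, 1]), asks for candidates there, maps them back with `space.inverse_transform` before calling `func`, and tells the results to the strategy. The sketch below mimics that loop with NumPy only. The box-to-unit-cube maps assume plain real-valued bounds, and `toy_es_minimize` is a hypothetical, deliberately simple (mu, lambda)-ES with fixed step-size decay standing in for CMA-ES. It is NOT CMA-ES and has no covariance adaptation:

```python
import numpy as np


def to_unit(x, lower, upper):
    """Map points from the box [lower, upper] to the unit cube [0, 1]^n."""
    return (np.asarray(x, dtype=float) - lower) / (upper - lower)


def from_unit(y, lower, upper):
    """Inverse of to_unit: map unit-cube points back to the original box."""
    return lower + np.asarray(y, dtype=float) * (upper - lower)


def toy_es_minimize(func, lower, upper, n_iter=50, popsize=10, sigma0=0.3, seed=0):
    """Hypothetical stand-in for the CMA-ES ask/tell loop (NOT CMA-ES)."""
    rng = np.random.default_rng(seed)
    lower, upper = np.asarray(lower, float), np.asarray(upper, float)
    mean = rng.uniform(0, 1, len(lower))  # random start, as with x0=None
    sigma = sigma0
    best_x, best_f = None, np.inf
    for _ in range(n_iter):
        # "ask": sample candidates in the unit cube, clipped to the bounds
        cands = np.clip(mean + sigma * rng.standard_normal((popsize, len(mean))), 0, 1)
        # evaluate in the original (denormalized) space
        fvals = np.array([func(from_unit(y, lower, upper)) for y in cands])
        order = np.argsort(fvals)
        if fvals[order[0]] < best_f:
            best_f, best_x = fvals[order[0]], from_unit(cands[order[0]], lower, upper)
        # "tell": recombine the better half and shrink the step size
        mean = cands[order[: popsize // 2]].mean(axis=0)
        sigma *= 0.95
    return best_x, best_f
```

Normalizing keeps a single `sigma0` meaningful across dimensions with very different widths, which is presumably why the wrapper defaults to `normalize=True` and interprets `sigma0` relative to the domain width.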
--------------------------------------------------------------------------------
/embeding_distribution.py:
--------------------------------------------------------------------------------
1 | """
2 | Plot the distributions of the textual (token) embeddings and of the projection W_p Q in Stable Diffusion.
3 | """
4 |
5 | import os
6 | import argparse
7 |
8 | from transformers import CLIPTextModel, CLIPTokenizer
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from sklearn.decomposition import PCA
12 |
13 | def main():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--model_path", default='./ckpt', type=str)
16 | parser.add_argument("--model_dim", default=768, type=int)
17 | parser.add_argument("--lamda", default=5, type=int)
18 | args = parser.parse_args()
19 |
20 |
21 | text_encoder = CLIPTextModel.from_pretrained(
22 | os.path.join(args.model_path, "text_encoder")
23 | )
24 | embedding = text_encoder.get_input_embeddings().weight.clone().cpu()
25 | print(embedding.size())
26 | embedding = embedding.detach().cpu().numpy()
27 | mu_hat = np.mean(embedding.reshape(-1))
28 | std_hat = np.std(embedding.reshape(-1))
29 | print(mu_hat, std_hat)
30 | number = embedding.reshape(-1).shape[0]
31 | normal = np.random.normal(loc=0, scale=1 / args.model_dim * args.lamda, size = number)
32 | sampling = np.random.normal(loc=0, scale=std_hat * args.lamda, size = number)
33 |
34 |
35 |
36 | pca = PCA(n_components=args.model_dim)
37 | pca.fit(embedding)
38 | pca = pca.components_.reshape(-1)
39 |
40 | # initialize Q with samples from N(0, 0.5^2), i.e. standard deviation 0.5
41 | cma = np.random.normal(loc=0, scale=0.5, size = number)
42 |
43 | # projection distribution with W_p Q
44 | normal = cma * normal
45 | sampling = cma * sampling
46 |
47 | cma_pca = np.random.normal(loc=0, scale=0.5, size = pca.shape[0])
48 | pca = cma_pca * pca
49 |
50 |
51 | kwargs = dict(alpha=0.5, bins=100, density=True, stacked=True)
52 | embedding = embedding.reshape(-1)
53 |
54 | plt.hist(embedding, **kwargs, color='g', label='Textual Embedding')
55 | plt.hist(normal, **kwargs, color='r', label='Random Norm')
56 | plt.hist(pca, **kwargs, color='black', label='PCA')
57 | plt.hist(sampling, **kwargs, color='b', label='Prior Norm')
58 |
59 | plt.gca().set(ylabel='Density') # the histograms use density=True
60 | plt.xlim(-0.1,0.1)
61 | plt.legend()
62 | plt.show()
63 |
64 |
65 | if __name__ == '__main__':
66 | main()
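`embeding_distribution.py` uses `sklearn.decomposition.PCA` only for its `components_`. Those principal axes are the right singular vectors of the column-centered data matrix, ordered by decreasing singular value (an implementation may flip the sign of individual axes). A NumPy-only sketch; the small random matrix stands in for the token-embedding table from `get_input_embeddings().weight`, and its shape and scale are arbitrary assumptions:

```python
import numpy as np


def pca_components(data, n_components):
    """Return (components, singular_values) for the column-centered data.

    components has shape (n_components, n_features); rows are orthonormal
    principal axes, ordered by decreasing singular value.
    """
    centered = data - data.mean(axis=0)
    # full_matrices=False gives vt of shape (min(n_samples, n_features), n_features)
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    return vt[:n_components], s[:n_components]


rng = np.random.default_rng(0)
emb = rng.normal(scale=0.02, size=(500, 16))  # stand-in for an embedding table
comps, svals = pca_components(emb, 16)
print(comps.shape)  # (16, 16)
```

The variance of the data projected onto the k-th axis equals `svals[k]**2 / (n_samples - 1)`, which is what PCA reports as explained variance.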
--------------------------------------------------------------------------------
/figures/case.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/figures/case.png
--------------------------------------------------------------------------------
/figures/cma.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/figures/cma.gif
--------------------------------------------------------------------------------
/figures/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feizc/Gradient-Free-Textual-Inversion/80ac300e52c47009e5ec467fe27cc08a88ba369c/figures/framework.png
--------------------------------------------------------------------------------
/infer_inversion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from torch import autocast
5 |
6 | import PIL
7 | from PIL import Image
8 |
9 | from diffusers import StableDiffusionPipeline
10 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
11 |
12 |
13 | def image_grid(imgs, rows, cols):
14 | assert len(imgs) == rows*cols
15 |
16 | w, h = imgs[0].size
17 | grid = Image.new('RGB', size=(cols*w, rows*h))
18 | grid_w, grid_h = grid.size
19 |
20 | for i, img in enumerate(imgs):
21 | grid.paste(img, box=(i%cols*w, i//cols*h))
22 | return grid
23 |
24 |
25 | def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
26 | loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
27 |
28 | # separate token and the embeds
29 | trained_token = list(loaded_learned_embeds.keys())[0]
30 | embeds = loaded_learned_embeds[trained_token]
31 |
32 | # cast to dtype of text_encoder
33 | dtype = text_encoder.get_input_embeddings().weight.dtype
34 | embeds = embeds.to(dtype) # .to() is not in-place; keep the returned tensor
35 |
36 | # add the token in tokenizer
37 | token = token if token is not None else trained_token
38 | num_added_tokens = tokenizer.add_tokens(token)
39 | if num_added_tokens == 0:
40 | raise ValueError(f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer.")
41 |
42 | # resize the token embeddings
43 | text_encoder.resize_token_embeddings(len(tokenizer))
44 |
45 | # get the id for the token and assign the embeds
46 | token_id = tokenizer.convert_tokens_to_ids(token)
47 | text_encoder.get_input_embeddings().weight.data[token_id] = embeds
48 |
49 |
50 |
51 | def main():
52 | parser = argparse.ArgumentParser()
53 | parser.add_argument("--model_path", default='./ckpt', type=str)
54 | parser.add_argument("--inversion_path", default='./save/learned_embeds.bin' , type=str)
55 | parser.add_argument("--prompt", default='city under the sun, painting, in a style of ' , type=str)
56 | args = parser.parse_args()
57 |
58 | model_path = args.model_path
59 | learned_embeds_path = args.inversion_path
60 | device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
61 |
62 | tokenizer = CLIPTokenizer.from_pretrained(
63 | os.path.join(model_path, 'tokenizer')
64 | )
65 | text_encoder = CLIPTextModel.from_pretrained(
66 | os.path.join(model_path, 'text_encoder')
67 | )
68 | load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer)
69 |
70 | pipe = StableDiffusionPipeline.from_pretrained(
71 | model_path,
72 | text_encoder=text_encoder,
73 | tokenizer=tokenizer,
74 | ).to(device)
75 |
76 | prompt = args.prompt
77 |
78 | num_samples = 2 #@param {type:"number"}
79 | num_rows = 2 #@param {type:"number"}
80 |
81 | all_images = []
82 | for _ in range(num_rows):
83 | with autocast("cuda"):
84 | images = pipe(prompt, num_images_per_prompt=num_samples, num_inference_steps=50, guidance_scale=7.5).images
85 | all_images.extend(images)
86 |
87 | grid = image_grid(all_images, rows=num_rows, cols=num_samples) # one row per pipeline call, num_samples images each
88 | grid.save('./1.png')
89 |
90 |
91 |
92 | if __name__ == '__main__':
93 | main()
94 |
--------------------------------------------------------------------------------
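`load_learned_embed_in_clip` performs three steps: register the new token with the tokenizer, grow the text encoder's embedding table to the new vocabulary size, and copy the learned vector into the row for the new token id. The toy NumPy sketch below mirrors those steps, with a plain dict standing in for the tokenizer vocabulary and an array standing in for the embedding table; `add_learned_token` is a hypothetical helper, not the transformers API. It also illustrates why the dtype cast has to be assigned: like `Tensor.to`, `ndarray.astype` returns a new array instead of converting in place.

```python
import numpy as np


def add_learned_token(vocab, table, token, vector):
    """Toy version of add_tokens -> resize_token_embeddings -> row assignment.

    vocab:  dict mapping token string -> id (ids assumed to be 0..len-1)
    table:  (vocab_size, dim) embedding matrix
    vector: learned embedding of length dim
    Returns a new (vocab, table); the inputs are left unchanged.
    """
    if token in vocab:
        raise ValueError(f"The vocabulary already contains the token {token!r}.")
    new_vocab = dict(vocab)
    new_vocab[token] = len(vocab)  # a new token gets the next free id

    # "Resize": append one row. Zeros are used here for simplicity; the real
    # resize_token_embeddings initializes new rows differently.
    new_row = np.zeros((1, table.shape[1]), dtype=table.dtype)
    new_table = np.concatenate([table, new_row], axis=0)

    # astype() returns a new array, so its result must be kept
    # (the analogue of `embeds = embeds.to(dtype)`).
    vector = np.asarray(vector).astype(new_table.dtype)
    new_table[new_vocab[token]] = vector
    return new_vocab, new_table


vocab = {"a": 0, "painting": 1}
table = np.ones((2, 4), dtype=np.float16)
learned = np.full(4, 0.25, dtype=np.float64)
vocab2, table2 = add_learned_token(vocab, table, "<style>", learned)
print(vocab2["<style>"], table2.shape, table2.dtype)  # 2 (3, 4) float16
```

As in the script, adding a token that already exists is rejected rather than silently overwriting an existing embedding.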
/initialize_inversion.py:
--------------------------------------------------------------------------------
1 | """
2 | automatically initialize the textual-inversion embedding with CLIP and parameter-free cross-attention
3 | """
4 |
5 | import torch
6 | import os
7 | import argparse
8 |
9 | from PIL import Image
10 | import torch.nn.functional as F
11 | from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor, CLIPTextModel
12 | from utils import imagenet_template, automatic_subjective_classnames
13 |
14 |
15 | def embedding_generate(model, tokenizer, text_encoder, classnames, templates, device):
16 | """
17 | pre-calculate, for every class name, the template-averaged CLIP sentence feature and the mean token embedding
18 | """
19 | with torch.no_grad():
20 | sentence_weights = []
21 | token_weights = []
22 | token_embedding_table = text_encoder.get_input_embeddings().weight.data
23 | for classname in classnames:
24 | texts = [template(classname) for template in templates] # format with class
25 | texts = tokenizer(texts, padding="max_length", max_length=77, truncation=True, return_tensors="pt") # tokenize
26 | texts = texts['input_ids'].to(device)
27 | class_embeddings = model.get_text_features(texts)
28 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
29 | class_embedding /= class_embedding.norm()
30 | sentence_weights.append(class_embedding)
31 |
32 | token_ids = tokenizer.encode(classname,add_special_tokens=False)
33 | token_embedding_list = []
34 | for token_id in token_ids:
35 | token_embedding_list.append(token_embedding_table[token_id])
36 | token_weights.append(torch.mean(torch.stack(token_embedding_list), dim=0))
37 |
38 | sentence_weights = torch.stack(sentence_weights, dim=1).to(device)
39 | token_weights = torch.stack(token_weights, dim=0).to(device)
40 | return sentence_weights, token_weights
41 |
42 |
43 |
44 | def image_condition_embed_initialize(image_feature_list, sentence_embeddings, token_embeddings):
45 | """
46 | parameter-free cross-attention (no learned projections): query = image features, key = sentence features, value = token embeddings
47 | """
48 | inversion_emb_list = []
49 | for image_features in image_feature_list:
50 | cross_attention = image_features @ sentence_embeddings
51 | attention_probs = F.softmax(cross_attention, dim=-1)
52 | inversion_emb = torch.matmul(attention_probs, token_embeddings)
53 | inversion_emb_list.append(inversion_emb)
54 |
55 | final_inversion = torch.mean(torch.stack(inversion_emb_list), dim=0)
56 | final_inversion = final_inversion / final_inversion.norm()
57 | return final_inversion
58 |
59 |
60 |
61 | def main():
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument("--save_path", default='./save', type=str)
64 | parser.add_argument("--data_path", default='./cat', type=str)
65 | args = parser.parse_args()
66 |
67 | save_path = args.save_path
68 | data_path = args.data_path
69 | device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
70 | tokenizer = CLIPTokenizer.from_pretrained('./clip')
71 | model = CLIPModel.from_pretrained('./clip').to(device) # text and image inputs are moved to `device`, so the model must be too
72 | text_encoder = CLIPTextModel.from_pretrained('./clip') # only its input-embedding table is read
73 | processor = CLIPProcessor.from_pretrained('./clip')
74 |
75 | sentence_embeddings, token_embeddings = embedding_generate(model,
76 | tokenizer,
77 | text_encoder,
78 | automatic_subjective_classnames,
79 | imagenet_template,
80 | device)
81 | print('sentence embedding size: ', sentence_embeddings.size(), ' token embedding size: ', token_embeddings.size())
82 |
83 | image_feature_list = []
84 | name_list = os.listdir(data_path)
85 | for name in name_list:
86 | image_path = os.path.join(data_path, name)
87 | image = Image.open(image_path).convert("RGB")
88 | inputs = processor(images=image, return_tensors="pt").to(device)
89 | image_features = model.get_image_features(**inputs)
90 | image_features = F.normalize(image_features, dim=-1)
91 | image_feature_list.append(image_features)
92 | print('image size: ', len(image_feature_list))
93 |
94 | inversion_emb = image_condition_embed_initialize(image_feature_list, sentence_embeddings, token_embeddings)
95 |
96 | inversion_emb_dict = {"initialize": inversion_emb.detach().cpu()}
97 | torch.save(inversion_emb_dict, os.path.join(save_path, 'initialize_emb.bin'))
98 |
99 |
100 |
101 | if __name__ == "__main__":
102 | main()
103 |
--------------------------------------------------------------------------------
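`image_condition_embed_initialize` is a single attention step without learned projections. Each CLIP image feature is the query, the template-averaged sentence feature of each class name is a key, and the mean token embedding of that class name is the value. The initial embedding is therefore a convex combination of class-name token embeddings weighted by image-text similarity, averaged over the images and rescaled to unit norm. The NumPy sketch below reproduces that computation on toy arrays; the `scale` argument is a hypothetical addition (1.0 reproduces the script). Because the script applies the softmax to raw cosine similarities, which lie in [-1, 1], the weights over hundreds of class names stay close to uniform. CLIP's own zero-shot classifier multiplies similarities by a learned logit scale before the softmax, and a larger `scale` here has the same sharpening effect.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention_init(image_feats, sentence_emb, token_emb, scale=1.0):
    """Parameter-free cross-attention initialization.

    image_feats:  (N, D) L2-normalized image features, one row per image (queries)
    sentence_emb: (D, C) L2-normalized text feature per class, one column per class (keys)
    token_emb:    (C, E) mean token embedding per class name (values)
    scale:        hypothetical logit scale; 1.0 matches initialize_inversion.py
    Returns a unit-norm (E,) vector (the script keeps a leading dim of 1).
    """
    attn = softmax(scale * (image_feats @ sentence_emb), axis=-1)  # (N, C)
    per_image = attn @ token_emb                                  # (N, E)
    init = per_image.mean(axis=0)                                 # average over images
    return init / np.linalg.norm(init)


rng = np.random.default_rng(0)
D, C, E = 8, 3, 5
S = rng.normal(size=(D, C))
S /= np.linalg.norm(S, axis=0, keepdims=True)       # unit-norm columns
T = np.eye(C, E)                                    # class c's value is unit vector c
img = S[:, [1]].T                                   # (1, D): an image matching class 1 exactly
flat = cross_attention_init(img, S, T)              # weights stay fairly even
sharp = cross_attention_init(img, S, T, scale=100.0)  # weight concentrates on class 1
```

With `scale=1.0` the result keeps visible contributions from every class even when the image matches one class exactly; with a large scale it approaches that class's token embedding.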
/train_inversion.py:
--------------------------------------------------------------------------------
1 | import cma
2 | import argparse
3 | import torch
4 | import os
5 | import numpy as np
6 | import copy
7 | from sklearn.decomposition import PCA
8 |
9 | from diffusers import StableDiffusionPipeline, DDPMScheduler
10 | from transformers import CLIPTextModel, CLIPTokenizer
11 |
12 | import torch.nn.functional as F
13 | from utils import TextualInversionDataset
14 | from tqdm import tqdm
15 |
16 |
17 | class GradientFreePipeline:
18 | def __init__(self, model_path, args, init_text_inversion=None, ):
19 | self.tokenizer = CLIPTokenizer.from_pretrained(
20 | os.path.join(model_path, 'tokenizer')
21 | )
22 | self.text_encoder = CLIPTextModel.from_pretrained(
23 | os.path.join(model_path, 'text_encoder')
24 | )
25 | self.pipe = StableDiffusionPipeline.from_pretrained(
26 | model_path,
27 | text_encoder=self.text_encoder,
28 | tokenizer=self.tokenizer,
29 | ).to(args.device)
30 |
31 | if args.projection_modeling == 'prior_normal':
32 | self.linear = torch.nn.Linear(args.intrinsic_dim, args.model_dim, bias=False).to(args.device)
33 | embedding = self.text_encoder.get_input_embeddings().weight.clone().cpu()
34 | mu_hat = np.mean(embedding.reshape(-1).detach().cpu().numpy())
35 | std_hat = np.std(embedding.reshape(-1).detach().cpu().numpy())
36 | mu = 0.0
37 | std = args.alpha * std_hat / (np.sqrt(args.intrinsic_dim) * args.sigma)
38 |
39 | # incorporate temperature factor
40 | # temp = intrinsic_dim - std_hat * std_hat
41 | # mu = mu_hat / temp
42 | # std = std_hat / np.sqrt(temp)
43 | print('[Embedding] mu: {} | std: {} [RandProj] mu: {} | std: {}'.format(mu_hat, std_hat, mu, std))
44 | for p in self.linear.parameters():
45 | torch.nn.init.normal_(p, mu, std)
46 |
47 | elif args.projection_modeling == 'pca':
48 | embedding = self.text_encoder.get_input_embeddings().weight.clone().cpu()
49 | embedding = embedding.detach().cpu().numpy() # (49408, 768)
50 |
51 | self.pca_model = PCA(n_components=args.intrinsic_dim)
52 | self.pca_model.fit(embedding)
53 |
54 |
55 | # Add the placeholder token in tokenizer
56 | num_added_tokens = self.tokenizer.add_tokens(args.placeholder_token)
57 | if num_added_tokens == 0:
58 | raise ValueError(
59 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
60 | " `placeholder_token` that is not already in the tokenizer."
61 | )
62 | # Convert the initializer_token, placeholder_token to ids
63 | token_ids = self.tokenizer.encode(args.initializer_token, add_special_tokens=False)
64 |
65 | initializer_token_id = token_ids[0]
66 | placeholder_token_id = self.tokenizer.convert_tokens_to_ids(args.placeholder_token)
67 | # Resize the token embeddings as we are adding new special tokens to the tokenizer
68 | self.text_encoder.resize_token_embeddings(len(self.tokenizer))
69 | # Initialise the newly added placeholder token with the embeddings of the initializer token
70 | token_embeds = self.text_encoder.get_input_embeddings().weight.data
71 | token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
72 |
73 | print('convert text inversion: ', args.placeholder_token, 'in id: ', str(placeholder_token_id))
74 | self.placeholder_token_id = placeholder_token_id
75 | self.placeholder_token = args.placeholder_token
76 | self.num_call = 0
77 |
78 | train_dataset = TextualInversionDataset(
79 | data_root=args.train_data_dir,
80 | tokenizer=self.tokenizer,
81 | size=args.resolution,
82 | placeholder_token=args.placeholder_token,
83 | repeats=args.repeats,
84 | learnable_property=args.learnable_property,
85 | center_crop=args.center_crop,
86 | set="train",
87 | )
88 | self.dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.repeats, shuffle=True)
89 | self.batch_size = args.repeats
90 | self.device = args.device
91 | print('load data length: ', len(self.dataloader))
92 |
93 | # candidates are offsets added to this base embedding: the CLIP-based initialization if given, else the initializer token's embedding
94 | if init_text_inversion is not None:
95 | self.init_text_inversion = init_text_inversion.to(args.device)
96 | else:
97 | self.init_text_inversion = token_embeds[initializer_token_id].to(args.device)
98 |
99 | self.args = args
100 | self.best_inversion = None
101 |
102 | def eval(self, inversion_embedding):
103 | self.num_call += 1
104 | pe_list = []
105 | if isinstance(inversion_embedding, list): # multiple queries
106 | for pe in inversion_embedding:
107 | if self.args.projection_modeling == 'prior_normal':
108 | z = torch.tensor(pe).type(torch.float32).to(self.device) # z
109 | with torch.no_grad():
110 | z = self.linear(z) # W_p Q
111 | if self.init_text_inversion is not None:
112 | z = z + self.init_text_inversion # W_p Q + p_0
113 | elif self.args.projection_modeling == 'pca':
114 | z = self.pca_model.inverse_transform(pe.reshape(1, -1))[0] # map back to the text embedding space (PCA expects a 2-D input)
115 | z = torch.tensor(z).type(torch.float32).to(self.device)
116 | if self.init_text_inversion is not None:
117 | z = z + self.init_text_inversion
118 | pe_list.append(z)
119 |
120 | elif isinstance(inversion_embedding, np.ndarray): # single query
121 | if self.args.projection_modeling == 'prior_normal':
122 | inversion_embedding = torch.tensor(inversion_embedding).type(torch.float32).to(self.device) # z
123 | with torch.no_grad():
124 | inversion_embedding = self.linear(inversion_embedding) # W_p Q
125 | elif self.args.projection_modeling == 'pca':
126 | inversion_embedding = self.pca_model.inverse_transform(inversion_embedding.reshape(1, -1))[0] # PCA expects a 2-D input
127 | inversion_embedding = torch.tensor(inversion_embedding).type(torch.float32).to(self.device)
128 | if self.init_text_inversion is not None:
129 | inversion_embedding = inversion_embedding + self.init_text_inversion # W_p Q + p_0
130 | pe_list.append(inversion_embedding)
131 | else:
132 | raise ValueError(
133 | f'[Inversion Embedding] Only support [list, numpy.ndarray], got `{type(inversion_embedding)}` instead.'
134 | )
135 |
136 | loss_list = []
137 | print('begin to calculate loss')
138 |
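139 | # sample timesteps once per call so every candidate in this population is scored at the same timesteps (noise and VAE latents are still resampled per candidate)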
140 | noise_scheduler = DDPMScheduler.from_config('./ckpt/scheduler')
141 | timesteps = torch.randint(
142 | 0, noise_scheduler.config.num_train_timesteps, (self.batch_size,), device=self.device
143 | ).long()
144 |
145 | best_loss = float('inf')
146 | best_inversion = None
147 |
148 | for pe in tqdm(pe_list):
149 | token_embeds = self.text_encoder.get_input_embeddings().weight.data
150 | pe = pe.to(self.text_encoder.get_input_embeddings().weight.dtype) # .to() is not in-place; keep the returned tensor
151 | token_embeds[self.placeholder_token_id] = pe
152 | loss = calculate_mse_loss(self.pipe, self.dataloader, self.device, noise_scheduler, timesteps)
153 | if loss < best_loss:
154 | best_loss = loss
155 | best_inversion = pe
156 | loss_list.append(loss)
157 |
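158 | # keep this generation's best candidate; timesteps are resampled on every call, so losses from different generations are not directly comparable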
159 | self.best_inversion = best_inversion
160 |
161 | return loss_list
162 |
163 |
164 | def save(self, output_path):
165 | learned_embeds_dict = {self.placeholder_token: self.best_inversion.detach().cpu()}
166 | torch.save(learned_embeds_dict, os.path.join(output_path, "learned_embeds.bin"))
167 |
168 |
169 |
170 | def calculate_mse_loss(image_generator, dataloader, device, noise_scheduler, timesteps):
171 | # print(image_generator.text_encoder.get_input_embeddings().weight.data[49408])
172 |
173 | loss_cum = .0
174 | with torch.no_grad():
175 | for batch in dataloader:
176 | # Convert images to latent space
177 | latents = image_generator.vae.encode(batch["pixel_values"].to(device)).latent_dist.sample().detach()
178 | latents = latents * 0.18215
179 |
180 | # Sample noise that we'll add to the latents
181 | noise = torch.randn_like(latents)
182 | # timesteps are sampled once per eval() call and shared by all candidates
183 |
184 | # Add noise to the latents according to the noise magnitude at each timestep
185 | # (this is the forward diffusion process)
186 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
187 |
188 | # Get the text embedding for conditioning
189 | encoder_hidden_states = image_generator.text_encoder(batch["input_ids"].to(device))[0]
190 |
191 | # Predict the noise residual
192 | noise_pred = image_generator.unet(noisy_latents, timesteps, encoder_hidden_states).sample
193 |
194 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
195 | loss_cum += loss.item()
196 |
197 | return loss_cum / len(dataloader)
198 |
199 |
200 |
201 |
202 |
203 | def main():
204 | parser = argparse.ArgumentParser()
205 |
206 | parser.add_argument("--intrinsic_dim", default=256, type=int)
207 | parser.add_argument("--k_shot", default=16, type=int)
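208 | parser.add_argument("--batch_size", default=32, type=int) # unused; the DataLoader batch size is --repeats
209 | parser.add_argument("--budget", default=5000, type=int) # total candidate evaluations; maxiter = budget // popsize unless --parallel
210 | parser.add_argument("--popsize", default=20, type=int) # candidates sampled per CMA-ES generation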
211 | parser.add_argument("--bound", default=0, type=int)
212 | parser.add_argument("--sigma", default=1, type=float)
213 | parser.add_argument("--alpha", default=1, type=float)
214 | parser.add_argument("--print_every", default=50, type=int)
215 | parser.add_argument("--eval_every", default=100, type=int)
216 | parser.add_argument("--alg", default='CMA', type=str) # currently unused (CMA-ES is always used); reserved for other evolution strategies
217 | parser.add_argument("--projection_modeling", default='pca', type=str) # projection method: 'pca' or 'prior_normal'
218 | parser.add_argument("--model_dim", default=768, type=int) # dim of textual inversion
219 | parser.add_argument("--inversion_initialize", default='./save/initialize_emb.bin', type=str) # path to the embedding saved by initialize_inversion.py
220 | parser.add_argument("--seed", default=2023, type=int)
221 | parser.add_argument("--loss_type", default='noise', type=str)
222 | parser.add_argument("--cat_or_add", default='add', type=str)
223 | parser.add_argument("--device", default= torch.device("cuda:2" if torch.cuda.is_available() else "cpu"))
224 | parser.add_argument("--parallel", action="store_true", help='Whether to allow parallel evaluation')
225 |
226 | parser.add_argument(
227 | "--placeholder_token",
228 | type=str,
229 | default='',
230 | help="A token to use as a placeholder for the concept.",
231 | )
232 | parser.add_argument(
233 | "--initializer_token",
234 | type=str,
235 | default='painting',
236 | help="A token to use as initializer word."
237 | )
238 | parser.add_argument(
239 | "--inference_framework",
240 | default='pt',
241 | type=str,
242 | help='''Which inference framework to use.
243 | Currently supports `pt` and `ort`, standing for pytorch and Microsoft onnxruntime respectively'''
244 | )
245 | parser.add_argument(
246 | "--onnx_model_path",
247 | default=None,
248 | type=str,
249 | help='Path to your onnx model.'
250 | )
251 | parser.add_argument(
252 | "--train_data_dir",
253 | type=str,
254 | default='./data',
255 | help="A folder containing the training data of instance images.",
256 | )
257 | parser.add_argument(
258 | "--learnable_property",
259 | type=str,
260 | default="style",
261 | help="Choose between 'object' and 'style'"
262 | )
263 | parser.add_argument(
264 | "--resolution",
265 | type=int,
266 | default=512,
267 | help=(
268 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
269 | " resolution"
270 | ),
271 | )
272 | parser.add_argument(
273 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
274 | )
275 | parser.add_argument("--repeats", type=int, default=5, help="How many times to repeat the training data.")
276 |
277 | args = parser.parse_args()
278 |
279 | cma_opts = {
280 | 'seed': args.seed,
281 | 'popsize': args.popsize,
282 | 'maxiter': args.budget if args.parallel else args.budget // args.popsize,
283 | 'verbose': -1,
284 | }
285 |
286 | if args.bound > 0:
287 | cma_opts['bounds'] = [-1 * args.bound, 1 * args.bound]
288 |
289 | if args.inversion_initialize is not None:
290 | print('initialize textual inversion')
291 | init_text_inversion = torch.load(args.inversion_initialize, map_location="cpu")["initialize"]
292 | else:
293 | init_text_inversion = None
294 |
295 | pipeline = GradientFreePipeline(model_path='./ckpt', args=args, init_text_inversion=init_text_inversion)
296 |
297 | es = cma.CMAEvolutionStrategy(args.intrinsic_dim * [0], args.sigma, inopts=cma_opts)
298 |
299 | while not es.stop():
300 | solutions = es.ask() # (popsize, intrinsic_dim)
301 | fitnesses = pipeline.eval(solutions)
302 | print(fitnesses) # loss for each point
303 | es.tell(solutions, fitnesses)
304 | pipeline.save('./save')
305 |
306 |
307 | if __name__ == "__main__":
308 | main()
309 |
--------------------------------------------------------------------------------
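In the `prior_normal` mode of `train_inversion.py`, CMA-ES searches a low-dimensional vector z (length `intrinsic_dim`, started at 0 with step size `sigma`), and each candidate is mapped to the text-embedding space as W z + p0, where W is the weight of the bias-free `torch.nn.Linear` and p0 is the initial embedding. The standard deviation set in `GradientFreePipeline.__init__` makes the projected offset roughly as spread out as the token-embedding entries: if z_j ~ N(0, sigma^2) and W_ij ~ N(0, s^2) are independent, then Var[(W z)_i] = d * s^2 * sigma^2, and with s = alpha * std_hat / (sqrt(d) * sigma) this equals (alpha * std_hat)^2. Below is a NumPy Monte-Carlo check of that argument on synthetic data; all names and sizes are illustrative assumptions.

```python
import numpy as np


def prior_normal_std(token_table, intrinsic_dim, sigma, alpha=1.0):
    """Entry std for the random projection W, as computed in GradientFreePipeline.__init__."""
    std_hat = np.std(token_table)
    return alpha * std_hat / (np.sqrt(intrinsic_dim) * sigma)


def to_embedding(z, W, p0):
    """Map a search-space point z of shape (d,) to the embedding space: W z + p0."""
    return W @ z + p0


rng = np.random.default_rng(0)
model_dim, d, sigma, alpha = 64, 32, 1.0, 1.0
token_table = rng.normal(0.0, 0.02, size=(1000, model_dim))  # stand-in for CLIP token embeddings
std_hat = np.std(token_table)

s = prior_normal_std(token_table, d, sigma, alpha)
W = rng.normal(0.0, s, size=(model_dim, d))   # same layout as nn.Linear(d, model_dim).weight
p0 = np.zeros(model_dim)

# Initial CMA-ES candidates are approximately N(0, sigma^2 I) in the d-dim search space.
Z = rng.normal(0.0, sigma, size=(5000, d))
offsets = np.stack([to_embedding(z, W, p0) for z in Z])  # (5000, model_dim)
ratio = offsets.std() / (alpha * std_hat)
print(round(ratio, 2))  # close to 1
```

In the `pca` mode the mapping is instead `PCA.inverse_transform(z) + p0`, and scikit-learn's `inverse_transform` also adds the PCA mean (the average token embedding), so z = 0 maps to that mean plus p0 rather than to p0 alone.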
/utils.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import os
3 | import numpy as np
4 | from torchvision import transforms
5 | from PIL import Image
6 | import random
7 | import PIL
8 | import torch
9 |
10 |
11 | imagenet_templates_small = [
12 | "a photo of a {}",
13 | "a rendering of a {}",
14 | "a cropped photo of the {}",
15 | "the photo of a {}",
16 | "a photo of a clean {}",
17 | "a photo of a dirty {}",
18 | "a dark photo of the {}",
19 | "a photo of my {}",
20 | "a photo of the cool {}",
21 | "a close-up photo of a {}",
22 | "a bright photo of the {}",
23 | "a cropped photo of a {}",
24 | "a photo of the {}",
25 | "a good photo of the {}",
26 | "a photo of one {}",
27 | "a close-up photo of the {}",
28 | "a rendition of the {}",
29 | "a photo of the clean {}",
30 | "a rendition of a {}",
31 | "a photo of a nice {}",
32 | "a good photo of a {}",
33 | "a photo of the nice {}",
34 | "a photo of the small {}",
35 | "a photo of the weird {}",
36 | "a photo of the large {}",
37 | "a photo of a cool {}",
38 | "a photo of a small {}",
39 | ]
40 |
41 |
42 | imagenet_style_templates_small = [
43 | "a painting in the style of {}",
44 | "a rendering in the style of {}",
45 | "a cropped painting in the style of {}",
46 | "the painting in the style of {}",
47 | "a clean painting in the style of {}",
48 | "a dirty painting in the style of {}",
49 | "a dark painting in the style of {}",
50 | "a picture in the style of {}",
51 | "a cool painting in the style of {}",
52 | "a close-up painting in the style of {}",
53 | "a bright painting in the style of {}",
54 | "a cropped painting in the style of {}",
55 | "a good painting in the style of {}",
56 | "a close-up painting in the style of {}",
57 | "a rendition in the style of {}",
58 | "a nice painting in the style of {}",
59 | "a small painting in the style of {}",
60 | "a weird painting in the style of {}",
61 | "a large painting in the style of {}",
62 | ]
63 |
64 |
65 |
66 | automatic_subjective_classnames = [
67 | "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
68 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
69 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
70 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
71 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
72 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
73 | "box turtle", "banded gecko", "green iguana", "Carolina anole",
74 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
75 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
76 | "American alligator", "triceratops", "worm snake", "ring-necked snake",
77 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
78 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
79 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
80 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
81 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
82 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
83 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
84 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
85 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
86 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
87 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
88 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
89 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
90 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
91 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
92 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
93 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
94 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
95 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
96 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
97 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
98 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
99 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
100 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
101 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
102 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
103 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
104 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
105 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
106 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
107 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
108 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
109 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
110 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
111 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
112 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
113 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
114 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
115 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
116 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
117 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
118 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
119 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
120 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
121 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
122 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
123 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
124 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
125 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
126 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
127 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
128 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
129 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
130 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
131 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
132 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
133 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
134 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
135 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
136 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
137 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
138 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
139 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
140 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
141 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
142 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
143 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
144 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
145 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
146 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
147 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
148 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
149 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
150 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
151 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
152 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
153 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
154 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
155 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
156 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
157 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
158 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
159 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
160 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
161 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
162 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
163 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
164 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
165 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
166 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
167 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
168 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
169 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
170 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
171 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
172 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
173 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
174 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
175 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
176 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
177 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
178 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
179 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
180 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
181 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
182 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
183 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
184 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
185 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
186 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
187 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
188 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
189 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
190 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
191 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
192 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
193 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
194 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
195 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
196 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
197 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
198 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
199 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
200 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
201 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
202 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
203 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
204 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
205 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
206 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
207 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
208 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
209 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
210 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
211 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
212 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
213 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
214 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
215 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
216 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
217 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
218 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
219 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
220 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
221 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
222 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
223 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
224 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
225 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
226 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
227 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
228 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
229 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
230 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
231 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper",
232 | ]
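This sketch assumes the list follows the ImageNet-1k class-index order, which is the usual convention for zero-shot class-name lists like this one. Under that assumption, a name's position in the list is its label. Some names occur twice (for example "missile" and "sunglasses"), so look names up by position and do not deduplicate the list or turn it into a name-keyed dict. `index_to_name` and `find_duplicates` are hypothetical helpers, and the excerpt only illustrates the idea; it does not reproduce the real ordering.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def index_to_name(classnames: Sequence[str], index: int) -> str:
    """Return the class name for a predicted class index (its position in the list)."""
    return classnames[index]


def find_duplicates(classnames: Sequence[str]) -> Dict[str, List[int]]:
    """Map every name that occurs more than once to all indices where it occurs."""
    positions: Dict[str, List[int]] = defaultdict(list)
    for i, name in enumerate(classnames):
        positions[name].append(i)
    return {name: idx for name, idx in positions.items() if len(idx) > 1}


# Illustrative excerpt only; not the real index order of the list above.
excerpt = ["minivan", "missile", "mitten", "projector", "missile", "sunglasses", "sunglasses"]
print(find_duplicates(excerpt))   # {'missile': [1, 4], 'sunglasses': [5, 6]}
print(index_to_name(excerpt, 4))  # missile
```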
233 |
234 |
235 | # refer to: https://github.com/samiramunir/Classifying-Painting-Art-Style-With-Deep-Learning
236 | automatic_style_classnames = [
237 | "realism",
238 | "photorealism",
239 | "expressionism",
240 | "impressionism",
241 | "abstract",
242 | "surrealism",
243 | "pop art",
244 | "oil",
245 | "watercolor",
246 | "acrylic",
247 | "gouache",
248 | "pastel",
249 | "encaustic",
250 | "fresco",
251 | "spray paint",
252 | "digital",
253 | "history",
254 | "portrait",
255 | "genre",
256 | "landscape"
257 | ]
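The style names above appear to be labels for CLIP-style zero-shot classification. Assuming that is their purpose, choosing a label takes three steps: normalize one text embedding per label and the image embedding, score each label by cosine similarity, and take the argmax. In this minimal sketch, hand-made vectors stand in for encoder outputs. The encoder and `zero_shot_label` are assumptions and are not part of this file.

```python
from typing import Sequence

import numpy as np


def zero_shot_label(image_emb: np.ndarray, text_embs: np.ndarray, labels: Sequence[str]) -> str:
    """Return the label whose text embedding is most cosine-similar to image_emb.

    image_emb has shape (d,); text_embs has shape (len(labels), d), one row per label.
    """
    image_emb = image_emb / np.linalg.norm(image_emb)
    text_embs = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    scores = text_embs @ image_emb  # cosine similarity of each label with the image
    return labels[int(np.argmax(scores))]


labels = ["realism", "impressionism", "watercolor"]
# Hand-made stand-ins for encoded prompts such as "a watercolor painting".
text_embs = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])
image_emb = np.array([0.1, 0.2, 0.9])  # stand-in for an encoded image
print(zero_shot_label(image_emb, text_embs, labels))  # watercolor
```

Because both sides are normalized, rescaling an embedding does not change the chosen label.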
258 |
259 |
260 | imagenet_template = [
261 | lambda c: f'a bad photo of a {c}.',
262 | lambda c: f'a photo of many {c}.',
263 | lambda c: f'a sculpture of a {c}.',
264 | lambda c: f'a photo of the hard to see {c}.',
265 | lambda c: f'a low resolution photo of the {c}.',
266 | lambda c: f'a rendering of a {c}.',
267 | lambda c: f'graffiti of a {c}.',
268 | lambda c: f'a bad photo of the {c}.',
269 | lambda c: f'a cropped photo of the {c}.',
270 | lambda c: f'a tattoo of a {c}.',
271 | lambda c: f'the embroidered {c}.',
272 | lambda c: f'a photo of a hard to see {c}.',
273 | lambda c: f'a bright photo of a {c}.',
274 | lambda c: f'a photo of a clean {c}.',
275 | lambda c: f'a photo of a dirty {c}.',
276 | lambda c: f'a dark photo of the {c}.',
277 | lambda c: f'a drawing of a {c}.',
278 | lambda c: f'a photo of my {c}.',
279 | lambda c: f'the plastic {c}.',
280 | lambda c: f'a photo of the cool {c}.',
281 | lambda c: f'a close-up photo of a {c}.',
282 | lambda c: f'a black and white photo of the {c}.',
283 | lambda c: f'a painting of the {c}.',
284 | lambda c: f'a painting of a {c}.',
285 | lambda c: f'a pixelated photo of the {c}.',
286 | lambda c: f'a sculpture of the {c}.',
287 | lambda c: f'a bright photo of the {c}.',
288 | lambda c: f'a cropped photo of a {c}.',
289 | lambda c: f'a plastic {c}.',
290 | lambda c: f'a photo of the dirty {c}.',
291 | lambda c: f'a jpeg corrupted photo of a {c}.',
292 | lambda c: f'a blurry photo of the {c}.',
293 | lambda c: f'a photo of the {c}.',
294 | lambda c: f'a good photo of the {c}.',
295 | lambda c: f'a rendering of the {c}.',
296 | lambda c: f'a {c} in a video game.',
297 | lambda c: f'a photo of one {c}.',
298 | lambda c: f'a doodle of a {c}.',
299 | lambda c: f'a close-up photo of the {c}.',
300 | lambda c: f'a photo of a {c}.',
301 | lambda c: f'the origami {c}.',
302 | lambda c: f'the {c} in a video game.',
303 | lambda c: f'a sketch of a {c}.',
304 | lambda c: f'a doodle of the {c}.',
305 | lambda c: f'an origami {c}.',
306 | lambda c: f'a low resolution photo of a {c}.',
307 | lambda c: f'the toy {c}.',
308 | lambda c: f'a rendition of the {c}.',
309 | lambda c: f'a photo of the clean {c}.',
310 | lambda c: f'a photo of a large {c}.',
311 | lambda c: f'a rendition of a {c}.',
312 | lambda c: f'a photo of a nice {c}.',
313 | lambda c: f'a photo of a weird {c}.',
314 | lambda c: f'a blurry photo of a {c}.',
315 | lambda c: f'a cartoon {c}.',
316 | lambda c: f'art of a {c}.',
317 | lambda c: f'a sketch of the {c}.',
318 | lambda c: f'an embroidered {c}.',
319 | lambda c: f'a pixelated photo of a {c}.',
320 | lambda c: f'itap of the {c}.',
321 | lambda c: f'a jpeg corrupted photo of the {c}.',
322 | lambda c: f'a good photo of a {c}.',
323 | lambda c: f'a plushie {c}.',
324 | lambda c: f'a photo of the nice {c}.',
325 | lambda c: f'a photo of the small {c}.',
326 | lambda c: f'a photo of the weird {c}.',
327 | lambda c: f'the cartoon {c}.',
328 | lambda c: f'art of the {c}.',
329 | lambda c: f'a drawing of the {c}.',
330 | lambda c: f'a photo of the large {c}.',
331 | lambda c: f'a black and white photo of a {c}.',
332 | lambda c: f'the plushie {c}.',
333 | lambda c: f'a dark photo of a {c}.',
334 | lambda c: f'itap of a {c}.',
335 | lambda c: f'graffiti of the {c}.',
336 | lambda c: f'a toy {c}.',
337 | lambda c: f'itap of my {c}.',
338 | lambda c: f'a photo of a cool {c}.',
339 | lambda c: f'a photo of a small {c}.',
340 | lambda c: f'a tattoo of the {c}.',
341 | ]
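Each entry of `imagenet_template` is a callable that returns a prompt for a class name. `TextualInversionDataset.__getitem__`, by contrast, calls `str.format` on the entries of `self.templates`. The two template kinds are therefore not interchangeable, because a lambda has no `.format` method. This minimal sketch fills both kinds. The template excerpts, the format strings, and the `fill_*` helper names are hypothetical.

```python
import random
from typing import Callable, List, Sequence

# Excerpt of callable templates, in the style of imagenet_template.
callable_templates: List[Callable[[str], str]] = [
    lambda c: f'a photo of a {c}.',
    lambda c: f'a sketch of the {c}.',
]

# Hypothetical format-string templates, the kind the dataset's .format() call expects.
format_templates: List[str] = [
    'a photo of a {}',
    'a painting in the style of {}',
]


def fill_callables(templates: Sequence[Callable[[str], str]], name: str) -> List[str]:
    """Apply every callable template to one class name (a prompt ensemble)."""
    return [t(name) for t in templates]


def fill_random_format(templates: Sequence[str], token: str, rng: random.Random) -> str:
    """Pick one format-string template at random and insert the placeholder token."""
    return rng.choice(templates).format(token)


print(fill_callables(callable_templates, "goldfish"))
# ['a photo of a goldfish.', 'a sketch of the goldfish.']
print(fill_random_format(format_templates, "*", random.Random(0)))
```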
342 |
343 |
344 |
345 | class TextualInversionDataset(Dataset):
346 |     def __init__(
347 |         self,
348 |         data_root,
349 |         tokenizer,
350 |         learnable_property="style",  # [object, style]
351 |         size=512,
352 |         repeats=1,
353 |         interpolation="bicubic",
354 |         flip_p=0.5,
355 |         set="train",
356 |         placeholder_token="*",
357 |         center_crop=False,
358 |     ):
359 |         self.data_root = data_root
360 |         self.tokenizer = tokenizer
361 |         self.learnable_property = learnable_property
362 |         self.size = size
363 |         self.placeholder_token = placeholder_token
364 |         self.center_crop = center_crop
365 |         self.flip_p = flip_p
366 |
367 |         self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
368 |
369 |         self.num_images = len(self.image_paths)
370 |         self._length = self.num_images
371 |
372 |         if set == "train":
373 |             self._length = self.num_images * repeats
374 |
375 |         self.interpolation = {
376 |             "linear": PIL.Image.BILINEAR,  # PIL.Image.LINEAR (an alias of BILINEAR) was removed in Pillow 10
377 |             "bilinear": PIL.Image.BILINEAR,
378 |             "bicubic": PIL.Image.BICUBIC,
379 |             "lanczos": PIL.Image.LANCZOS,
380 |         }[interpolation]
381 |
382 |         self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
383 |         self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
384 |
385 |     def __len__(self):
386 |         return self._length
387 |
388 |     def __getitem__(self, i):
389 |         example = {}
390 |         try:
391 |             image = Image.open(self.image_paths[i % self.num_images])
392 |         except OSError:  # unreadable or non-image file: fall back to the next path
393 |             image = Image.open(self.image_paths[(i + 1) % self.num_images])
394 |         if image.mode != "RGB":
395 |             image = image.convert("RGB")
396 |
397 |         placeholder_string = self.placeholder_token
398 |         text = random.choice(self.templates).format(placeholder_string)
399 |
400 |         example["input_ids"] = self.tokenizer(
401 |             text,
402 |             padding="max_length",
403 |             truncation=True,
404 |             max_length=self.tokenizer.model_max_length,
405 |             return_tensors="pt",
406 |         ).input_ids[0]
407 |
408 |         # default to score-sde preprocessing
409 |         img = np.array(image).astype(np.uint8)
410 |
411 |         if self.center_crop:
412 |             # Crop the largest centered square (side = shorter image edge)
413 |             # before resizing to (self.size, self.size).
414 |             h, w = img.shape[:2]
415 |             crop = min(h, w)
416 |             img = img[(h - crop) // 2 : (h + crop) // 2,
417 |                       (w - crop) // 2 : (w + crop) // 2]
418 |
419 |         image = Image.fromarray(img)
420 |         image = image.resize((self.size, self.size), resample=self.interpolation)
421 |
422 |         image = self.flip_transform(image)
423 |         image = np.array(image).astype(np.uint8)
424 |         image = (image / 127.5 - 1.0).astype(np.float32)
425 |
426 |         example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
427 |         return example
428 |
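`__getitem__` preprocesses each image in five steps:

1. Optionally center-crop to the largest square.
2. Resize.
3. Apply a random horizontal flip.
4. Scale uint8 pixels from [0, 255] to [-1, 1].
5. Reorder the array from HWC to CHW.

The crop arithmetic and value mapping can be checked without PIL or torch. This minimal NumPy-only sketch mirrors only the crop and the normalization and leaves out resizing and flipping. `center_crop_square` and `to_model_range` are hypothetical names.

```python
import numpy as np


def center_crop_square(img: np.ndarray) -> np.ndarray:
    """Crop the largest centered square from an HxWxC array (same slicing as the dataset)."""
    h, w = img.shape[:2]
    crop = min(h, w)
    return img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]


def to_model_range(img: np.ndarray) -> np.ndarray:
    """Map uint8 HxWxC pixels to float32 CxHxW values in [-1, 1], as the dataset does."""
    scaled = (img / 127.5 - 1.0).astype(np.float32)
    return np.transpose(scaled, (2, 0, 1))


img = np.zeros((4, 6, 3), dtype=np.uint8)
img[:, :, 0] = 255  # saturate the red channel
cropped = center_crop_square(img)
print(cropped.shape)  # (4, 4, 3)
x = to_model_range(cropped)
print(x.shape, x.min(), x.max())  # (3, 4, 4) -1.0 1.0
```

In the dataset itself, PIL does the resizing and torchvision does the random flip. The sketch only pins down the crop offsets and the value range.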
429 |
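With `set="train"`, `__len__` reports `num_images * repeats`, and `__getitem__` wraps the index with `i % num_images`. Each image is therefore visited `repeats` times per pass, each time with a freshly sampled template and flip. When `Image.open` fails, the fallback reads the image at index `i + 1`, so an unreadable file makes its neighbor appear twice. This minimal pure-Python sketch shows only the indexing. `epoch_indices` is a hypothetical helper.

```python
from collections import Counter
from typing import List


def epoch_indices(num_images: int, repeats: int, train: bool = True) -> List[int]:
    """Image indices touched in one pass over the dataset, mirroring __len__
    (num_images * repeats when training) and i % num_images in __getitem__."""
    length = num_images * repeats if train else num_images
    return [i % num_images for i in range(length)]


print(epoch_indices(3, 2))               # [0, 1, 2, 0, 1, 2]
print(Counter(epoch_indices(3, 4)))      # Counter({0: 4, 1: 4, 2: 4})
print(epoch_indices(3, 4, train=False))  # [0, 1, 2]
```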
--------------------------------------------------------------------------------