├── tests
│   ├── __init__.py
│   ├── utils.py
│   ├── test_cupy_utils.py
│   ├── test_examples.py
│   ├── test_nodes.py
│   ├── test_model_check.py
│   └── test_utils.py
├── images
│   ├── nmf_cost_function.png
│   └── nmf_tree.svg
├── .gitignore
├── wonterfact
│   ├── __init__.py
│   ├── examples
│   │   ├── conv_nmf.py
│   │   ├── snmf.py
│   │   └── nmf.py
│   ├── buds.py
│   ├── glob_var_manager.py
│   ├── observers.py
│   ├── cupy_utils.py
│   ├── graphviz.py
│   └── operators.py
├── pyproject.toml
└── README.md

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/images/nmf_cost_function.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SmartImpulse/Wonterfact/HEAD/images/nmf_cost_function.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | *~
 2 | *.pyc
 3 | *.egg-info
 4 | build/*
 5 | dist/*
 6 | .coverage
 7 | htmlcov
 8 | jupyter/*
 9 | .vscode/*
10 | wonterfact/untracked/*
11 | !jupyter/skellam-snmf-experiments.ipynb
--------------------------------------------------------------------------------
/wonterfact/__init__.py:
--------------------------------------------------------------------------------
 1 | # ----------------------------------------------------------------------------
 2 | # Copyright 2020 Smart Impulse SAS, Benoit Fuentes
 3 | #
 4 | # This file is part of Wonterfact.
 5 | #
 6 | # Wonterfact is free software: you can redistribute it and/or modify
 7 | # it under the terms of the GNU General Public License as published by
 8 | # the Free Software Foundation, either version 3 of the License, or
 9 | # any later version.
10 | #
11 | # Wonterfact is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Wonterfact. If not, see <https://www.gnu.org/licenses/>.
18 | # ----------------------------------------------------------------------------
19 | 
20 | """Initialization file for the wonterfact package"""
21 | 
22 | # Python Future imports
23 | from __future__ import division, unicode_literals, print_function, absolute_import
24 | 
25 | # Python System imports
26 | 
27 | # Third-party imports
28 | 
29 | # Smart impulse common modules
30 | 
31 | # Relative imports
32 | from . import utils
33 | from .utils import create_filiation
34 | from .root import Root
35 | from .leaves import LeafGamma, LeafDirichlet, LeafGammaNorm
36 | from .operators import Multiplier, Multiplexer, Integrator, Adder, Proxy
37 | from .observers import PosObserver, RealObserver, BlindObs
38 | from .glob_var_manager import glob
39 | 
40 | 
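# A minimal usage sketch of the names exported below (hedged: the tensor shapes
# and hyperparameters here are illustrative assumptions, and `data_tf` stands
# for your own nonnegative array of shape (100, 10); see
# wonterfact/examples/nmf.py for a complete, working NMF model):
#
#     import numpy as np
#     import wonterfact as wtf
#     atoms = wtf.LeafDirichlet(
#         name="atoms", index_id="kf", norm_axis=(1,), tensor=np.full((3, 10), 0.1)
#     )
#     activations = wtf.LeafGamma(
#         name="activations", index_id="tk", tensor=np.ones((100, 3)), prior_rate=0.1
#     )
#     mul_tf = wtf.Multiplier(name="multiplier", index_id="tf")
#     mul_tf.new_parents(atoms, activations)
#     obs_tf = wtf.PosObserver(name="observer", index_id="tf", tensor=data_tf)
#     mul_tf.new_child(obs_tf)
#     root = wtf.Root(name="root")
#     obs_tf.new_child(root)
#     root.estimate_param(n_iter=100)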
41 | 
42 | __all__ = [
43 |     "LeafGamma",
44 |     "LeafDirichlet",
45 |     "Multiplier",
46 |     "Multiplexer",
47 |     "PosObserver",
48 |     "RealObserver",
49 |     "Root",
50 |     "Integrator",
51 |     "Adder",
52 |     "Proxy",
53 |     "BlindObs",
54 |     "LeafGammaNorm",
55 |     "utils",
56 |     "glob",
57 |     "create_filiation",
58 | ]
59 | 
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
 1 | # ----------------------------------------------------------------------------
 2 | # Copyright 2020 Benoit Fuentes
 3 | #
 4 | # This file is part of Wonterfact.
 5 | #
 6 | # Wonterfact is free software: you can redistribute it and/or modify
 7 | # it under the terms of the GNU General Public License as published by
 8 | # the Free Software Foundation, either version 3 of the License, or
 9 | # any later version.
10 | #
11 | # Wonterfact is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Wonterfact. If not, see <https://www.gnu.org/licenses/>.
18 | # ----------------------------------------------------------------------------
19 | 
20 | """Useful methods for testing a wonterfact model"""
21 | 
22 | # Python standard library
23 | from pathlib import Path
24 | import tempfile
25 | 
26 | # Third-party imports
27 | import numpy as np
28 | 
29 | # wonterfact imports
30 | 
31 | 
32 | def _assert_graphviz_ok(tree, legend_dict=None):
33 |     with tempfile.TemporaryDirectory() as tmpdir:
34 |         filename = Path(tmpdir) / "test.pdf"
35 |         tree.draw_tree(
36 |             filename=filename, legend_dict=legend_dict, prior_nodes=True, view=False
37 |         )
38 |         return filename.exists()
39 | 
40 | 
41 | def _assert_cost_decrease(tree):
42 |     return (np.diff(tree.cost_record) <= 0).all()
43 | 
44 | 
45 | def _setup_tree_for_decreasing_cost(
46 |     tree, inference_mode="EM", update_type="parabolic", limit_skellam=True
47 | ):
48 |     tree.cost_computation_iter = 1
49 |     tree.acceleration_start_iter = 10
50 |     for parent in tree.list_of_parents:
51 |         parent.drawings_step = parent.drawings_max
52 |         parent.drawings = parent.drawings_max
53 |         parent.limit_skellam_update = limit_skellam
54 |     tree.inference_mode = inference_mode
55 |     tree.update_type = update_type
56 |     tree.stop_estim_threshold = 1e-10
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [tool.poetry]
 2 | name = "wonterfact"
 3 | version = "3.0.0"
 4 | description = "A powerful tool to design any tensor factorization model and estimate the corresponding parameters"
 5 | authors = ["Smart-Impulse ", "Benoit Fuentes "]
 6 | license = "GPL-3.0-or-later"
 7 | readme = "README.md"
 8 | repository = "https://github.com/smartimpulse/wonterfact"
 9 | keywords = ["tensor", "factorization", "AI", "signal", "bayes"]
10 | classifiers = [
11 |     "Intended Audience :: Developers",
12 |     "Intended Audience :: Education",
13 |     "Intended Audience :: Information Technology",
14 |     "Intended Audience :: Science/Research",
15 |     "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)",
16 |     "Operating System :: OS Independent",
17 |     "Programming Language :: Python :: 3",
18 |     "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3 :: Only", 21 | "Topic :: Scientific/Engineering", 22 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 23 | "Topic :: Scientific/Engineering :: Information Analysis", 24 | "Topic :: Scientific/Engineering :: Mathematics", 25 | "Topic :: Software Development", 26 | "Topic :: Software Development :: Libraries", 27 | "Topic :: Software Development :: Libraries :: Python Modules", 28 | "Topic :: Utilities", 29 | ] 30 | 31 | [tool.poetry.dependencies] 32 | python = ">=3.8, <3.10" 33 | numpy = ">=1.20.2" 34 | numba = ">=0.53.1" 35 | scipy = ">=1.6.2" 36 | opt_einsum = ">=3.3.0" 37 | custom_inherit = ">=2.3.1" 38 | python-baseconv = ">=1.2.2" 39 | graphviz = ">=0.16" 40 | methodtools = ">=0.4.2" 41 | 42 | [tool.poetry.dev-dependencies] 43 | pytest = ">=5.2" 44 | pylint = ">=2.5.3" 45 | black = {version = ">=19.10b0", allow-prereleases = true} 46 | rope = ">=0.18.0" 47 | line_profiler = "^3.2.6" 48 | 49 | [tool.pytest.ini_options] 50 | markers = [ 51 | "gpu: marks tests as running on gpu (deselect with '-m \"not gpu\"')" 52 | ] 53 | 54 | [build-system] 55 | requires = ["poetry>=0.12"] 56 | build-backend = "poetry.masonry.api" 57 | -------------------------------------------------------------------------------- /wonterfact/examples/conv_nmf.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright 2020 Benoit Fuentes 3 | # 4 | # This file is part of Wonterfact. 5 | # 6 | # Wonterfact is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # any later version. 10 | # 11 | # Wonterfact is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Wonterfact. If not, see . 
18 | # ----------------------------------------------------------------------------
19 | 
20 | """Examples of wonterfact models involving convolution"""
21 | 
22 | # Python standard library
23 | 
24 | # Third-party imports
25 | import numpy as np
26 | import numpy.random as npr
27 | from scipy.signal import convolve2d
28 | 
29 | # wonterfact imports
30 | import wonterfact as wtf
31 | import wonterfact.utils as wtfu
32 | 
33 | 
34 | def make_deconv_tree():
35 |     dim_x0, dim_y0, dim_x1, dim_y1 = 20, 20, 3, 3
36 |     dim_x = dim_x0 - dim_x1 + 1
37 |     dim_y = dim_y0 - dim_y1 + 1
38 |     dim_k = 2
39 | 
40 |     kernel_kyx = np.zeros((dim_k, dim_y1, dim_x1))
41 |     kernel_kyx[0, [1, 0, 1, 2, 1], [0, 1, 1, 1, 2]] = 1
42 |     kernel_kyx[1, [0, 0, 1, 1, 2, 2, 2], [0, 2, 0, 2, 0, 1, 2]] = 1
43 |     kernel_kyx /= kernel_kyx.sum((1, 2), keepdims=True)
44 | 
45 |     impulse_kyx = npr.gamma(shape=0.08, scale=200, size=(dim_k, dim_y0, dim_x0))
46 |     impulse_kyx[impulse_kyx < 200] = 0
47 |     impulse_kyx[impulse_kyx >= 200] = 200
48 |     image_yx = np.zeros((dim_y, dim_x))
49 |     for kk in range(dim_k):
50 |         image_yx += convolve2d(impulse_kyx[kk], kernel_kyx[kk], mode="valid")
51 | 
52 |     leaf_kernel = wtf.LeafDirichlet(
53 |         name="kernel",
54 |         index_id=("k", "j", "i"),
55 |         norm_axis=(1, 2),
56 |         tensor=np.ones((dim_k, dim_y1, dim_x1)),
57 |         prior_shape=100 + 1e-4 * npr.rand(dim_k, dim_y1, dim_x1),
58 |         init_type="prior",
59 |     )
60 |     leaf_impulse = wtf.LeafGamma(
61 |         name="impulse",
62 |         index_id=("k", "y", "x"),
63 |         tensor=np.ones((dim_k, dim_y0, dim_x0)),
64 |         prior_shape=1,
65 |         prior_rate=0.0001,
66 |     )
67 |     mul_image = wtf.Multiplier(
68 |         name="reconstruction", index_id="yx", conv_idx_ids=("y", "x")
69 |     )
70 |     leaf_kernel.new_child(mul_image, index_id_for_child=("k", "y", "x"))
71 |     leaf_impulse.new_child(mul_image, index_id_for_child=("k", "y", "x"))
72 |     obs = wtf.PosObserver(
73 |         name="image",
74 |         index_id="yx",
75 |         tensor=image_yx,
76 |         drawings_max=200 * dim_x * dim_y,
77 |         drawings_step=10,
78 |     )
79 |     mul_image.new_child(obs)
80 |     root = wtf.Root(name="root")
81 |     obs.new_child(root)
82 |     return root
83 | 
--------------------------------------------------------------------------------
/tests/test_cupy_utils.py:
--------------------------------------------------------------------------------
 1 | # ----------------------------------------------------------------------------
 2 | # Copyright 2020 Benoit Fuentes
 3 | #
 4 | # This file is part of Wonterfact.
 5 | #
 6 | # Wonterfact is free software: you can redistribute it and/or modify
 7 | # it under the terms of the GNU General Public License as published by
 8 | # the Free Software Foundation, either version 3 of the License, or
 9 | # any later version.
10 | #
11 | # Wonterfact is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Wonterfact. If not, see <https://www.gnu.org/licenses/>.
 18 | # ----------------------------------------------------------------------------
 19 | 
 20 | """Tests for the wonterfact/cupy_utils module"""
 21 | 
 22 | # Python standard library
 23 | 
 24 | # Third-party imports
 25 | import pytest
 26 | import numpy as np
 27 | 
 28 | try:
 29 |     import cupy as cp  # pylint: disable=import-error
 30 | except ImportError:
 31 |     cp = None
 32 | 
 33 | # wonterfact and relative imports
 34 | import wonterfact.cupy_utils as c_utils
 35 | 
 36 | pytestmark = pytest.mark.gpu  # marks all file methods as "gpu"
 37 | 
 38 | 
 39 | @pytest.mark.parametrize("dim_cumsum", [10, 100, 1000])
 40 | def test_cupy_cumsum_2d(dim_cumsum):
 41 |     arr_in = cp.ones((2, dim_cumsum), dtype=float)
 42 |     arr_out = cp.empty_like(arr_in)
 43 |     c_utils.cupy_cumsum_2d(arr_in, arr_out)
 44 |     assert cp.allclose(arr_out, cp.cumsum(arr_in, axis=-1))
 45 | 
 46 | 
 47 | def test_find_cumsum_max_threads():
 48 |     max_threads = c_utils.find_cumsum_max_threads(10, 100)
 49 |     assert isinstance(max_threads, int)
 50 | 
 51 | 
 52 | def test_min_clip():
 53 |     arr = cp.array([3.0, 4.0])
 54 |     assert (c_utils.min_clip(arr, 3.5) == cp.array([3.5, 4.0])).all()
 55 | 
 56 | 
 57 | def test_max_clip():
 58 |     arr = cp.array([3.0, 4.0])
 59 |     assert (c_utils.max_clip(arr, 3.5) == cp.array([3.0, 3.5])).all()
 60 | 
 61 | 
 62 | def test_normalize_l1_l2_tensor_numba_core():
 63 |     dim0, dim1 = 10, 10
 64 |     tensor_init = 100 * cp.random.rand(dim0, dim1)
 65 |     tensor_out = tensor_init.copy()
 66 |     # pylint: disable=unsubscriptable-object
 67 |     c_utils.normalize_l1_l2_tensor_numba_core[dim0, (dim1 + 1) // 2](
 68 |         tensor_out, 20, 0.01
 69 |     )
 70 |     # pylint: enable=unsubscriptable-object
 71 |     from wonterfact import LeafGammaNorm
 72 | 
 73 |     tensor_out_np = cp.asnumpy(tensor_init)
 74 |     LeafGammaNorm._normalize_l1_l2_tensor(tensor_out_np, (1,), 20, 0.01)
 75 | 
 76 |     assert np.allclose(tensor_out_np, cp.asnumpy(tensor_out))
 77 | 
 78 | 
 79 | def test_set_bezier_point():
 80 |     arr1 = cp.random.rand(10)
 81 |     arr2 = cp.random.rand(10)
 82 |     arr3 = cp.random.rand(10)
 83 |     arr_out = cp.empty_like(arr1)
 84 |     param = 0.5
 85 | 
 86 |     c_utils._set_bezier_point(
 87 |         arr1, arr2, arr3, param, arr_out,
 88 |     )
 89 |     arr_out2 = (
 90 |         (1 - param) ** 2 * arr1 + 2 * (1 - param) * param * arr2 + (param ** 2) * arr3
 91 |     )
 92 |     assert cp.allclose(arr_out, arr_out2)
 93 | 
 94 | 
 95 | def test_multiply_and_sum():
 96 |     arr1 = cp.ones(4, dtype=float)
 97 |     arr2 = cp.ones(4, dtype=float)
 98 |     arr_out = c_utils.multiply_and_sum(arr1, arr2)
 99 |     assert arr_out == 4.0
100 | 
101 | 
102 | def test_hyp0f1ln():
103 |     # already tested in test_utils
104 |     pass
105 | 
106 | 
107 | def test_bessel_ratio():
108 |     # already tested in test_utils
109 |     pass
110 | 
111 | 
112 | def test_inclusive_scan_2d():
113 |     # tested through cupy_cumsum_2d
114 |     pass
115 | 
116 | 
117 | def test_sum_inclusive_scan_2d():
118 |     # tested through cupy_cumsum_2d
119 |     pass
120 | 
--------------------------------------------------------------------------------
/tests/test_examples.py:
--------------------------------------------------------------------------------
 1 | # ----------------------------------------------------------------------------
 2 | # Copyright 2020 Benoit Fuentes
 3 | #
 4 | # This file is part of Wonterfact.
 5 | #
 6 | # Wonterfact is free software: you can redistribute it and/or modify
 7 | # it under the terms of the GNU General Public License as published by
 8 | # the Free Software Foundation, either version 3 of the License, or
 9 | # any later version.
 10 | #
 11 | # Wonterfact is distributed in the hope that it will be useful,
 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 14 | # GNU General Public License for more details.
 15 | #
 16 | # You should have received a copy of the GNU General Public License
 17 | # along with Wonterfact. If not, see <https://www.gnu.org/licenses/>.
 18 | # ----------------------------------------------------------------------------
 19 | 
 20 | """Tests for all examples in the wonterfact/examples directory"""
 21 | 
 22 | # Python standard library
 23 | from itertools import product
 24 | 
 25 | # Third-party imports
 26 | import pytest
 27 | import numpy as np
 28 | import numpy.random as npr
 29 | 
 30 | # wonterfact and relative imports
 31 | from wonterfact import glob, LeafGammaNorm
 32 | from wonterfact import utils as wtfu
 33 | from wonterfact.examples import snmf
 34 | from wonterfact.examples import nmf
 35 | from wonterfact.examples import conv_nmf
 36 | 
 37 | # relative imports
 38 | from . import utils as t_utils
 39 | 
 40 | # args for make_snmf
 41 | data = 100 * npr.randn(10, 5)
 42 | atoms_nonneg_init = npr.rand(3, 5, 2)
 43 | activations_init = np.ones((10, 3))
 44 | 
 45 | 
 46 | list_of_tree_makers_tuple = [
 47 |     (nmf.make_nmf, (), {}),
 48 |     (nmf.make_smooth_activation_nmf, (), {}),
 49 |     (nmf.make_smooth_activation_nmf2, (), {}),
 50 |     (nmf.make_sparse_nmf, (), {}),
 51 |     (nmf.make_sparse_nmf2, (), {}),
 52 |     (nmf.make_sparse_nmf3, (), {}),
 53 |     (snmf.make_snmf, (data, atoms_nonneg_init, activations_init), {"fix_atoms": True}),
 54 |     (snmf.make_snmf, (data, atoms_nonneg_init, activations_init), {"fix_atoms": False}),
 55 |     (snmf.make_cluster_snmf, (), {}),
 56 |     (snmf.make_cluster_snmf2, (), {}),
 57 |     (conv_nmf.make_deconv_tree, (), {}),
 58 | ]
 59 | # let us add unique_id to the tuples
 60 | list_of_tree_makers_tuple = [
 61 |     elem + (num,) for num, elem in enumerate(list_of_tree_makers_tuple)
 62 | ]
 63 | 
 64 | 
 65 | @pytest.fixture(scope="module")
 66 | def cost_record_results():
 67 |     return {}
 68 | 
 69 | 
 70 | @pytest.mark.parametrize("tree_maker_tuple", list_of_tree_makers_tuple)
 71 | @pytest.mark.parametrize("inference_mode", ["EM", "VBEM"])
 72 | @pytest.mark.parametrize("update_type", ["regular", "parabolic"])
 73 | @pytest.mark.parametrize("limit_skellam", [True, False])
 74 | @pytest.mark.parametrize("backend", ["cpu", pytest.param("gpu", marks=pytest.mark.gpu)])
 75 | def test_example(
 76 |     tree_maker_tuple,
 77 |     inference_mode,
 78 |     update_type,
 79 |     limit_skellam,
 80 |     backend,
 81 |     cost_record_results,
 82 | ):
 83 |     glob.set_backend_processor(backend, force=True)
 84 |     np.random.seed(0)
 85 |     tree_maker, args, kwargs, unique_id = tree_maker_tuple
 86 |     tree = tree_maker(*args, **kwargs)
 87 |     t_utils._setup_tree_for_decreasing_cost(
 88 |         tree,
 89 |         inference_mode=inference_mode,
 90 |         update_type=update_type,
 91 |         limit_skellam=limit_skellam,
 92 |     )
 93 |     if inference_mode == "VBEM":
 94 |         if any(type(node) == LeafGammaNorm for node in tree.census()):
 95 |             with pytest.raises(NotImplementedError):
 96 |                 tree.estimate_param(n_iter=100)
 97 |             return
 98 |         for ii in range(10):
 99 |             tree.estimate_param(n_iter=10)
100 |             tree.estimate_hyperparam(n_iter=ii)
101 |     else:
102 |         tree.estimate_param(n_iter=100)
103 |     assert t_utils._assert_cost_decrease(tree)
104 |     assert t_utils._assert_graphviz_ok(tree)
105 |     base_key = (
106 |         unique_id,
107 |         inference_mode,
108 |         update_type,
109 |         limit_skellam,
110 |     )
111 | 
112 |     cost_record_results[base_key + (backend,)] = np.array(tree.cost_record)
113 |     if (
114 |         base_key + ("cpu",) in cost_record_results
115 |         and base_key + ("gpu",) in cost_record_results
116 |     ):
117 |         assert np.allclose(
118 |             cost_record_results[base_key + ("cpu",)],
119 |             cost_record_results[base_key + ("gpu",)],
120 |         )
121 | 
--------------------------------------------------------------------------------
/wonterfact/buds.py:
--------------------------------------------------------------------------------
 1 | # ----------------------------------------------------------------------------
 2 | # Copyright 2021 Benoit Fuentes
 3 | #
 4 | # This file is part of Wonterfact.
 5 | #
 6 | # Wonterfact is free software: you can redistribute it and/or modify
 7 | # it under the terms of the GNU General Public License as published by
 8 | # the Free Software Foundation, either version 3 of the License, or
 9 | # any later version.
10 | #
11 | # Wonterfact is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Wonterfact. If not, see <https://www.gnu.org/licenses/>.
18 | # ----------------------------------------------------------------------------
19 | 
20 | """Module for all buds (i.e. hyperparameter nodes) classes"""
21 | 
22 | # Python System imports
23 | from functools import cached_property
24 | 
25 | # Third-party imports
26 | import numpy as np
27 | from methodtools import lru_cache
28 | 
29 | # Relative imports
30 | from . import utils, core_nodes
31 | from .glob_var_manager import glob
32 | 
33 | 
34 | class _Bud(core_nodes._DynNodeData0):
35 |     """
36 |     Mother class for the hyperparameter buds of a graphical model
37 |     """
38 | 
39 |     def __init__(self, **kwargs):
40 |         self.new_born = True
41 |         super().__init__(**kwargs)
42 | 
43 |     @property
44 |     def level(self):
45 |         """
46 |         Returns 0, which is the default level of a bud.
47 | 
48 |         Returns
49 |         -------
50 |         int
51 |         """
52 |         return 0
53 | 
54 |     def _check_model_validity(self):
55 |         super()._check_model_validity()
56 |         if not self.are_all_tensor_coefs_linked_to_at_least_one_child:
57 |             raise ValueError(
Please make sure that each" 59 | "hyperparameters is linked to at least one child.".format(self) 60 | ) 61 | 62 | def _set_inference_mode(self, mode="EM"): 63 | super()._set_inference_mode(mode=mode) 64 | if mode == "EM": 65 | self.update_period = 0 66 | 67 | def _first_iteration(self): 68 | # this one is for e_d or m_e (cf technical report) 69 | self.tensor_update = glob.xp.empty_like(self.tensor) 70 | # this one is for hyperparameter optimization algorithm 71 | self.tensor_update_bis = glob.xp.empty_like(self.tensor) 72 | # at the first run, self.tensor_update needs to be initialized 73 | self._compute_tensor_update_aux2( 74 | tensor_to_fill=self.tensor_update, 75 | method_to_call="_give_update_first_iteration", 76 | ) 77 | 78 | @cached_property 79 | def number_of_users(self): 80 | """ 81 | Gives the number of parameters that share a same hyperparameter for each 82 | hyperparameter (corresponds to $|\\phi^{-1}(d)|$ in tech report) 83 | """ 84 | number = glob.xp.zeros_like(self.tensor) 85 | self._compute_tensor_update_aux2( 86 | tensor_to_fill=number, 87 | method_to_call="_give_number_of_users", 88 | ) 89 | return number 90 | 91 | def get_update_bis(self, tensor_to_fill): 92 | self._compute_tensor_update_aux2( 93 | tensor_to_fill=tensor_to_fill, 94 | method_to_call="_give_update_bis", 95 | ) 96 | 97 | def compute_tensor_update_online(self, learning_rate=1.0): 98 | if self.new_born: 99 | self._first_iteration() 100 | self.new_born = False 101 | past_tensor_update = self.tensor_update * (1.0 - learning_rate) 102 | self.compute_tensor_update() # new values in self.tensor_update 103 | self.tensor_update *= learning_rate 104 | self.tensor_update += past_tensor_update 105 | 106 | 107 | class BudShape(_Bud): 108 | @property 109 | def tensor_has_energy(self): 110 | return False 111 | 112 | def update_tensor(self): 113 | self.get_update_bis(tensor_to_fill=self.tensor_update_bis) 114 | utils.inverse_digamma( 115 | (self.tensor_update_bis + self.tensor_update) / self.number_of_users, 116 | out=self.tensor, 117 | ) 118 | 119 | 120 | class BudRate(_Bud): 121 | def init(self, prior_rate): 122 | prior_rate = glob.xp.array(prior_rate, dtype=glob.float) 123 | self.prior_rate = prior_rate 124 | if self.prior_rate.size > 1 and not (self.prior_rate > 0).all(): 125 | raise ValueError("prior_rate, if not None, must be > 0") 126 | 127 | @property 128 | def tensor_has_energy(self): 129 | return True 130 | 131 | def update_tensor(self): 132 | self.get_update_bis(tensor_to_fill=self.tensor_update_bis) 133 | self.tensor[...] 
133 |         self.tensor[...] = self.tensor_update_bis / self.tensor_update
134 | 
--------------------------------------------------------------------------------
/images/nmf_tree.svg:
--------------------------------------------------------------------------------
[nmf_tree.svg: SVG rendering of the NMF tree as drawn by draw_tree — the leaves
"atoms" (f|k) and "activations" (kt) feed the "multiplier" node (×, ft), which
feeds the "observer" node (+, ft), itself linked to the "root".]
--------------------------------------------------------------------------------
/wonterfact/glob_var_manager.py:
--------------------------------------------------------------------------------
 1 | # ----------------------------------------------------------------------------
 2 | # Copyright 2019 Smart Impulse SAS
 3 | #
 4 | # This file is part of Wonterfact.
 5 | #
 6 | # Wonterfact is free software: you can redistribute it and/or modify
 7 | # it under the terms of the GNU General Public License as published by
 8 | # the Free Software Foundation, either version 3 of the License, or
 9 | # any later version.
10 | #
11 | # Wonterfact is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Wonterfact. If not, see <https://www.gnu.org/licenses/>.
18 | # ----------------------------------------------------------------------------
19 | 
20 | """Module that deals with global variables of the program"""
21 | 
22 | # Python Future imports
23 | from __future__ import division, unicode_literals, print_function, absolute_import
24 | 
25 | # Python System imports
26 | 
27 | # Third party import
28 | from functools import cached_property
29 | 
30 | # Relative imports
31 | from . import utils
32 | 
33 | 
34 | class GlobalVarManager:
35 |     CPU = "cpu"
36 |     GPU = "gpu"
37 |     FLOAT64 = "float64"
38 |     FLOAT32 = "float32"
39 |     CUPY = "cupy"
40 |     NUMPY = "numpy"
41 | 
42 |     def __init__(self, processor=None, float_precision=None):
43 |         self._processor = processor or self.CPU
44 |         self._float = float_precision or self.FLOAT64
45 |         self._can_change_backend = True
46 |         self._can_change_float_precision = True
47 | 
48 |     @property
49 |     def processor(self):
50 |         """
51 |         Returns the current backend processor (gpu or cpu)
52 |         """
53 |         return self._processor
54 | 
55 |     @property
56 |     def backend(self):
57 |         """
58 |         Returns the current backend python package for array manipulation and
59 |         operation (numpy or cupy)
60 |         """
61 |         if self.processor == self.GPU:
62 |             backend = self.CUPY
63 |         elif self.processor == self.CPU:
64 |             backend = self.NUMPY
65 |         return backend
66 | 
67 |     @cached_property
68 |     def float(self):
69 |         """
70 |         Returns the current float precision (32 or 64). Careful: once this
71 |         property is called, the float precision can no longer be changed.
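        Use `set_float_precision` right after importing wonterfact if another
        precision is needed.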
72 | """ 73 | self._forbid_float_precision_change() 74 | return self._float 75 | 76 | def _forbid_backend_change(self): 77 | self._can_change_backend = False 78 | 79 | def _forbid_float_precision_change(self): 80 | self._can_change_float_precision = False 81 | 82 | def set_float_precision(self, float_precision): 83 | """ 84 | Set the float precision for all tensors in wonterfact's nodes. By 85 | default the float precision is `float64`. It can be changed to `float32` 86 | only once and should be just after importing wonterfact package. 87 | """ 88 | if float_precision != self.FLOAT64: 89 | raise NotImplementedError 90 | if float_precision not in [self.FLOAT64, self.FLOAT32]: 91 | raise ValueError( 92 | "float_precision can be '{}' or '{}'".format(self.FLOAT64, self.FLOAT32) 93 | ) 94 | if float_precision != self._float: 95 | if self._can_change_float_precision: 96 | self._float = float_precision 97 | else: 98 | raise ValueError( 99 | """ 100 | This call to set_float_precision has no effect because the precision has already been set. 101 | You can change float precision only just after importing wonterfact 102 | """ 103 | ) 104 | 105 | def _force_backend_reinit(self): 106 | for name_attr in ["xp", "sps", "as_strided"]: 107 | try: 108 | self.__delattr__(name_attr) 109 | except AttributeError: 110 | pass 111 | self._can_change_backend = True 112 | 113 | def set_backend_processor(self, processor, force=False): 114 | """ 115 | Set the backend processor. It should be set just after importing 116 | wonterfact package. By default, wonterfact uses cpu. 117 | 118 | Parameters 119 | ---------- 120 | processor: 'cpu' or 'gpu' 121 | force: bool, optional, default False 122 | In order to prevent unexpected errors, the backend can be changed 123 | only once, right after wonterfact package import. If True, user can 124 | force the backend setup any time, but at its own risk. 125 | """ 126 | if force: 127 | self._force_backend_reinit() 128 | if processor not in ["cpu", "gpu"]: 129 | raise ValueError("processor can be 'gpu' or 'cpu'") 130 | if processor != self.processor: 131 | if self._can_change_backend: 132 | self._processor = processor 133 | else: 134 | raise ValueError( 135 | """ 136 | This call to set_backend_processor has no effect because the processor has already been chosen. 137 | You can change processor engine only just after importing wonterfact 138 | """ 139 | ) 140 | 141 | @cached_property 142 | def xp(self): 143 | """ 144 | Return either cupy or numpy depending on the backend processor used. 145 | """ 146 | self._forbid_backend_change() 147 | if self.processor == "cpu": 148 | import numpy 149 | 150 | return numpy 151 | elif self.processor == "gpu": 152 | import cupy # pylint: disable=E0401 153 | 154 | return cupy 155 | 156 | @cached_property 157 | def sps(self): 158 | """ 159 | Returns either scipy.special or cupyx.scipy.special depending on the 160 | backend processor used. 161 | """ 162 | self._forbid_backend_change() 163 | if self.processor == "cpu": 164 | import scipy.special 165 | 166 | return scipy.special 167 | elif self.processor == "gpu": 168 | import cupyx.scipy.special # pylint: disable=E0401 169 | 170 | return cupyx.scipy.special 171 | 172 | @cached_property 173 | def as_strided(self): 174 | """ 175 | Returns either numpy.lib.stride_tricks.as_strided or 176 | cupy.lib.stride_tricks.as_strided depending on the backend processor 177 | used. 
178 |         """
179 |         self._forbid_backend_change()
180 |         if self.processor == "cpu":
181 |             from numpy.lib.stride_tricks import as_strided
182 | 
183 |             return as_strided
184 |         elif self.processor == "gpu":
185 |             from cupy.lib.stride_tricks import as_strided  # pylint: disable=E0401
186 | 
187 |             return as_strided
188 | 
189 | 
190 | glob = GlobalVarManager()
191 | 
--------------------------------------------------------------------------------
/wonterfact/examples/snmf.py:
--------------------------------------------------------------------------------
 1 | # ----------------------------------------------------------------------------
 2 | # Copyright 2020 Benoit Fuentes
 3 | #
 4 | # This file is part of Wonterfact.
 5 | #
 6 | # Wonterfact is free software: you can redistribute it and/or modify
 7 | # it under the terms of the GNU General Public License as published by
 8 | # the Free Software Foundation, either version 3 of the License, or
 9 | # any later version.
10 | #
11 | # Wonterfact is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Wonterfact. If not, see <https://www.gnu.org/licenses/>.
18 | # ----------------------------------------------------------------------------
19 | 
20 | """Examples of semi-nonnegative models for real-valued observed tensor"""
21 | 
22 | 
23 | # Python standard library
24 | 
25 | # Third-party imports
26 | import numpy as np
27 | import numpy.random as npr
28 | 
29 | # wonterfact imports
30 | import wonterfact as wtf
31 | import wonterfact.utils as wtfu
32 | 
33 | 
34 | def make_snmf(
35 |     data,
36 |     atoms_nonneg_init,
37 |     activations_init,
38 |     fix_atoms=False,
39 |     atoms_shape_prior=1,
40 |     activations_shape_prior=1,
41 |     activations_rate_prior=0.001,
42 |     inference_mode="EM",
43 |     integer_data=False,
44 | ):
45 |     """
46 |     This returns a wonterfact tree corresponding to the Skellam-SNMF algorithm.
47 | 
48 |     Parameters
49 |     ----------
50 |     data: array_like of shape [J x I]
51 |         Real-valued input array to factorize
52 |     atoms_nonneg_init: array_like of shape [K x I x 2]
53 |         Initialization for nonnegative atoms tensor
54 |     activations_init: array_like of shape [J x K]
55 |         Initialization for activations matrix
56 |     fix_atoms: bool, default False
57 |         Whether atoms should be updated or left to initial value
58 |     atoms_shape_prior: array_like or float
59 |         Shape hyperparameters for atoms (broadcasting
60 |         `atoms_shape_prior + atoms_nonneg_init` should raise no error)
61 |     activations_shape_prior: array_like or float
62 |         Shape hyperparameters for activations (broadcasting
63 |         `activations_shape_prior + activations_init` should raise no error)
64 |     activations_rate_prior: array_like or float
65 |         Rate hyperparameters for activations (broadcasting
66 |         `activations_rate_prior + activations_init` should raise no error)
67 |     inference_mode: 'EM' or 'VBEM', default 'EM'
68 |         Algorithm that should be used to infer parameters
69 |     integer_data: bool, default False
70 |         Whether data are integers or real numbers. If True, Skellam-SNMF is
71 |         performed with :math:`M=1`, otherwise with :math:`M=\\infty` (see [1])
72 | 
73 |     Returns
74 |     -------
75 |     wonterfact.Root
76 |         root of the wonterfact tree
77 | 
78 |     Notes
79 |     -----
80 |     Allows solving the following problem:
 81 |     .. math::
 82 |         X_{ji} \\approx \\sum_{k} \\lambda_{jk} W_{ki} \\quad \\textrm{with} \\quad
 83 |         W_{ki} = \\theta_{ki, s=0} - \\theta_{ki, s=1}
 84 |     Beware that axes are reversed compared to the model in [1]. This is due to
 85 |     the indexing conventions of wonterfact.
 86 | 
 87 |     References
 88 |     ----------
 89 |     .. [1] B. Fuentes et al., Probabilistic semi-nonnegative matrix factorization:
 90 |        a Skellam-based framework, 2021
 91 | 
 92 |     """
 93 |     ### creation of atoms leaf
 94 |     # to be sure atoms are well normalized
 95 |     atoms_nonneg_init /= atoms_nonneg_init.sum(axis=(1, 2), keepdims=True)
 96 |     leaf_kis = wtf.LeafDirichlet(
 97 |         name="atoms",  # name of the node
 98 |         index_id="kis",  # names of the axes
 99 |         norm_axis=(1, 2),  # normalization axes
100 |         tensor=atoms_nonneg_init,  # instantiation of leaf's tensor
101 |         init_type="custom",  # to be sure value of `tensor` is kept as initialization
102 |         update_period=0 if fix_atoms else 1,  # whether atoms should be updated or not
103 |         prior_shape=atoms_shape_prior,  # shape hyperparameters
104 |     )
105 |     ### creation of activations leaf
106 |     leaf_jk = wtf.LeafGamma(
107 |         name="activations",  # name of the node
108 |         index_id="jk",  # names of the axes
109 |         tensor=activations_init,  # instantiation of leaf's tensor
110 |         init_type="custom",  # to be sure value of `tensor` is kept as initialization
111 |         prior_rate=activations_rate_prior,  # rate hyperparameters
112 |         prior_shape=activations_shape_prior,  # shape hyperparameters
113 |     )
114 |     ### resulting skellam parameters for observed data
115 |     mul_jis = wtf.Multiplier(name="multiplier", index_id="jis")
116 |     mul_jis.new_parents(leaf_kis, leaf_jk)  # creation of filiations
117 | 
118 |     ### observed real-valued data
119 |     obs_ji = wtf.RealObserver(
120 |         name="observer",
121 |         index_id="ji",
122 |         tensor=data,
123 |         limit_skellam_update=not integer_data,  # whether data are considered as integers
124 |     )
125 |     mul_jis.new_child(obs_ji)  # filiation between data and model parameters
126 | 
127 |     ### creation of the root
128 |     root = wtf.Root(
129 |         inference_mode=inference_mode,
130 |         stop_estim_threshold=1e-7,
131 |         cost_computation_iter=50,
132 |     )
133 |     obs_ji.new_child(root)
134 |     return root
135 | 
136 | 
137 | def make_cluster_snmf():
138 |     """
139 |     Convex S-NMF for automatic clustering, inspired by Ding2010_IEEE.
140 |     Data to cluster are considered scale invariant, meaning we cluster directions
141 |     rather than points in some n-dimensional space.
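    Each class barycenter is modeled as a convex combination of the
    (sign-encoded) samples, and a gamma leaf carries one energy per class.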
142 | """ 143 | dim_d, dim_f = 400, 10 144 | tensor_df = np.zeros((dim_d, dim_f)) 145 | for ii in range(4): 146 | shape = npr.dirichlet(np.ones(dim_f)) ** 2 * 100 147 | tensor_df[ii * dim_d // 4 : (ii + 1) * dim_d // 4, :] = npr.dirichlet( 148 | shape, size=dim_d // 4 149 | ) * npr.choice([-1, 1], size=(dim_f)) 150 | 151 | tensor_df = tensor_df[npr.permutation(dim_d)] 152 | tensor_df *= 200 153 | # tensor_df[tensor_df<1e-10] = 0 154 | 155 | dim_q = 10 156 | 157 | tensor_dfs = wtfu.normalize(wtfu.real_to_2D_nonnegative(tensor_df), (1, 2)) 158 | leaf_dfs = wtf.LeafDirichlet( 159 | name="samples", 160 | index_id="dfs", 161 | norm_axis=(1, 2), 162 | tensor=tensor_dfs, 163 | update_period=0, 164 | ) 165 | 166 | tensor_qd = np.ones((dim_q, dim_d)) / dim_d 167 | leaf_qd = wtf.LeafDirichlet( 168 | name="samples_by_class", 169 | index_id="qd", 170 | norm_axis=(1,), 171 | tensor=tensor_qd, 172 | prior_shape=2, 173 | ) 174 | mul_qfs = wtf.Multiplier(name="barycenters", index_id="qfs") 175 | mul_qfs.new_parents(leaf_qd, leaf_dfs) 176 | 177 | tensor_q = 1 + 1e-4 * (npr.rand(dim_q)) 178 | leaf_q = wtf.LeafGamma( 179 | name="class_energy", 180 | index_id="q", 181 | norm_axis=(0,), 182 | tensor=tensor_q, 183 | # total_max_energy=2 * np.abs(tensor_df).sum(), 184 | prior_shape=1.1 * tensor_q, 185 | prior_rate=1e-4, 186 | ) 187 | mul_qd = wtf.Multiplier(name="sample_energy", index_id="qd") 188 | mul_qd.new_parents(leaf_q, leaf_qd) 189 | 190 | mul_dfs = wtf.Multiplier(name="reconstruction", index_id="dfs") 191 | mul_dfs.new_parents(mul_qfs, mul_qd) 192 | 193 | # observations 194 | obs_df = wtf.RealObserver( 195 | name="observations", 196 | index_id="df", 197 | tensor=tensor_df, 198 | # drawings_max=drawing_bin_max * tensor_df.size, 199 | # drawings_step = drawings_bin_step * tensor_df.size, 200 | ) 201 | mul_dfs.new_child(obs_df) 202 | root = wtf.Root( 203 | name="root", 204 | verbose_iter=200, 205 | cost_computation_iter=10, 206 | # update_type='regular' 207 | ) 208 | obs_df.new_child(root) 209 | return root 210 | 211 | 212 | def make_cluster_snmf2(nb_cluster=4, prior_rate=0.001): 213 | """ 214 | Convex S-NMF for automatic clustering, inspired by Ding2010_IEEE 215 | Data to cluster are points in some n-dimensional space. 
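    Here each sample carries its own energy (the "sample_energy" leaf) and a
    Dirichlet distribution over classes, so cluster assignments can be read
    from the "class_by_sample" leaf.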
216 | """ 217 | dim_d, dim_f = nb_cluster * 100, 2 218 | tensor_df = np.zeros((dim_d, dim_f)) 219 | for ii in range(nb_cluster): 220 | mean = npr.uniform(low=-10, high=10, size=dim_f) 221 | cov = npr.uniform(low=0.0, high=1.0, size=(dim_f, dim_f)) 222 | cov = np.dot(cov, cov.T) + np.diag(npr.uniform(low=0.0, high=0.5, size=dim_f)) 223 | tensor_df[ 224 | ii * dim_d // nb_cluster : (ii + 1) * dim_d // nb_cluster, : 225 | ] = npr.multivariate_normal(mean, cov, size=dim_d // nb_cluster) 226 | 227 | tensor_df = tensor_df[npr.permutation(dim_d)].copy() 228 | 229 | dim_q = nb_cluster 230 | 231 | tensor_dfs = wtfu.normalize(wtfu.real_to_2D_nonnegative(tensor_df), None) 232 | leaf_dfs = wtf.LeafDirichlet( 233 | name="samples", 234 | index_id="dfs", 235 | norm_axis=(0, 1, 2), 236 | tensor=tensor_dfs, 237 | update_period=0, 238 | ) 239 | 240 | tensor_dq = np.ones((dim_d, dim_q)) 241 | leaf_dq = wtf.LeafDirichlet( 242 | name="class_by_sample", 243 | index_id="dq", 244 | norm_axis=(1,), 245 | tensor=wtfu.normalize(npr.rand(dim_d, dim_q), (1,)), 246 | prior_shape=1, 247 | ) 248 | mul_pfs = wtf.Multiplier(name="barycenters", index_id="pfs") 249 | leaf_dq.new_child(mul_pfs, index_id_for_child="dp") 250 | leaf_dfs.new_child(mul_pfs) 251 | 252 | tensor_pqm = np.zeros((dim_q, dim_q, 2)) 253 | tensor_pqm[:, :, 0] = np.eye(dim_q) 254 | tensor_pqm[:, :, 1] = 1 - np.eye(dim_q) 255 | leaf_pqm = wtf.LeafDirichlet( 256 | name="energy dispatcher", 257 | index_id="pqm", 258 | tensor=tensor_pqm, 259 | norm_axis=(2,), 260 | update_period=0, 261 | ) 262 | 263 | mul_qfsm = wtf.Multiplier(index_id="qfsm") 264 | mul_qfsm.new_parents(mul_pfs, leaf_pqm) 265 | 266 | leaf_d = wtf.LeafGamma( 267 | name="sample_energy", 268 | index_id="d", 269 | norm_axis=(0,), 270 | tensor=np.ones(dim_d), 271 | prior_shape=1, 272 | prior_rate=prior_rate, 273 | ) 274 | 275 | mul_dq = wtf.Multiplier(name="sample_class_energy", index_id="dq") 276 | mul_dq.new_parents(leaf_d, leaf_dq) 277 | 278 | mul_dfsm = wtf.Multiplier(name="reconstruction", index_id="dfsm") 279 | mul_dfsm.new_parents(mul_qfsm, mul_dq) 280 | 281 | # observations 282 | obs_df = wtf.RealObserver( 283 | name="observations", 284 | index_id="df", 285 | tensor=tensor_df, 286 | ) 287 | mul_dfsm.new_child(obs_df, index_id_for_child="dfs", slice_for_child=(Ellipsis, 0)) 288 | root = wtf.Root( 289 | name="root", 290 | verbose_iter=50, 291 | cost_computation_iter=10, 292 | # update_type='regular' 293 | ) 294 | obs_df.new_child(root) 295 | 296 | # # null obs 297 | # obs_df2 = wtf.PosObserver( 298 | # name="null_observations", index_id="dfs", tensor=np.zeros((dim_d, dim_f, 2)) 299 | # ) 300 | # mul_dfsm.new_child(obs_df2, index_id_for_child="dfs", slice_for_child=(Ellipsis, 1)) 301 | 302 | # obs_df2.new_child(root) 303 | 304 | return root 305 | -------------------------------------------------------------------------------- /tests/test_nodes.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright 2020 Benoit Fuentes 3 | # 4 | # This file is part of Wonterfact. 5 | # 6 | # Wonterfact is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # any later version. 
10 | # 11 | # Wonterfact is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Wonterfact. If not, see . 18 | # ---------------------------------------------------------------------------- 19 | 20 | """Tests for basic methods of nodes""" 21 | 22 | # Python standard library 23 | from pathlib import Path 24 | import tempfile 25 | 26 | # Third-party imports 27 | import numpy as np 28 | import numpy.random as npr 29 | import pytest 30 | 31 | # wonterfact imports 32 | import wonterfact as wtf 33 | from wonterfact import buds 34 | 35 | # relative imports 36 | from . import utils as t_utils 37 | 38 | 39 | def normalize(arr, axis): 40 | return arr / arr.sum(axis, keepdims=True) 41 | 42 | 43 | # let us create a wonterfact tree once for all tests 44 | def make_test_tree(): 45 | """ 46 | A toy wonterfact tree is designed in order to test the maximum 47 | possibilities. 48 | """ 49 | 50 | dim_k, dim_s, dim_t, dim_f, dim_c = 3, 2, 10, 4, 2 51 | 52 | leaf_energy = wtf.LeafGamma( 53 | name="leaf_energy", 54 | index_id="", 55 | tensor=np.array(1), 56 | prior_shape=1.1, 57 | prior_rate=0.01, 58 | ) 59 | leaf_k = wtf.LeafDirichlet( 60 | name="leaf_k", 61 | index_id="k", 62 | norm_axis=(0,), 63 | tensor=normalize(npr.rand(dim_k), 0), 64 | prior_shape=1, 65 | ) 66 | mul_k = wtf.Multiplier(name="mul_k", index_id=("k",)) 67 | mul_k.new_parents(leaf_energy, leaf_k) 68 | 69 | leaf_kts_0 = wtf.LeafDirichlet( 70 | name="leaf_kts_0", 71 | index_id="kts", 72 | norm_axis=(1, 2), 73 | tensor=normalize(npr.rand(dim_k, dim_t, dim_s), (1, 2)), 74 | prior_shape=1, 75 | ) 76 | mul_kts = wtf.Multiplier(name="mul_kts", index_id="kts") 77 | mul_kts.new_parents(leaf_kts_0, mul_k) 78 | 79 | leaf_kts_1 = wtf.LeafGamma( 80 | name="leaf_kts_1", 81 | index_id="kts", 82 | tensor=npr.rand(dim_k, dim_t, dim_s), 83 | prior_shape=1, 84 | prior_rate=0.001, 85 | ) 86 | mult_ktsc = wtf.Multiplexer(name="mult_ktsc", index_id="ktsc") 87 | mult_ktsc.new_parents(leaf_kts_1, mul_kts) 88 | 89 | # no update for this leaf 90 | # it acts like an Adder (no normalization) 91 | leaf_c = wtf.LeafDirichlet( 92 | name="leaf_c", 93 | norm_axis=(), 94 | index_id="c", 95 | tensor=np.ones(dim_c), 96 | update_period=0, 97 | ) 98 | mul_kts_1 = wtf.Multiplier(name="mul_kts_1", index_id="kts") 99 | mul_kts_1.new_parents(mult_ktsc, leaf_c) 100 | 101 | # two updates every 3 iterations 102 | leaf_kf = wtf.LeafDirichlet( 103 | name="leaf_kf", 104 | index_id="kf", 105 | norm_axis=(1,), 106 | tensor=normalize(npr.rand(dim_k, dim_f), 1), 107 | prior_shape=1, 108 | update_period=3, 109 | update_succ=2, 110 | update_offset=0, 111 | ) 112 | mul_tfs = wtf.Multiplier(name="mul_tfs", index_id="tfs") 113 | mul_tfs.new_parents(mul_kts_1, leaf_kf) 114 | 115 | obs_tf = wtf.RealObserver( 116 | name="obs_tf", index_id="tf", tensor=100 * npr.randn(dim_t, dim_f) 117 | ) 118 | mul_tfs.new_child(obs_tf) 119 | 120 | root = wtf.Root( 121 | name="root", 122 | cost_computation_iter=1, 123 | stop_estim_threshold=0, 124 | update_type="regular", 125 | inference_mode="EM", 126 | verbose_iter=10, 127 | ) 128 | obs_tf.new_child(root) 129 | 130 | leaf_k_1 = wtf.LeafGamma( 131 | name="leaf_k_1", 132 | index_id="k", 133 | tensor=npr.rand(dim_k), 134 | prior_shape=1, 135 | prior_rate=0.001, 136 | ) 137 | # one update 
every 3 iterations 138 | leaf_kt = wtf.LeafDirichlet( 139 | name="leaf_kt", 140 | index_id="kt", 141 | norm_axis=(1,), 142 | tensor=normalize(np.ones((dim_k, dim_t)), 1), 143 | prior_shape=1, 144 | update_period=3, 145 | update_succ=1, 146 | update_offset=2, 147 | ) 148 | mul_kt = wtf.Multiplier(name="mul_kt", index_id="kt") 149 | mul_kt.new_parents(leaf_k_1, leaf_kt) 150 | mul_tf = wtf.Multiplier(name="mul_tf", index_id="tf") 151 | mul_tf.new_parents(leaf_kf, mul_kt) 152 | 153 | obs_tf_2 = wtf.PosObserver( 154 | name="obs_tf_2", index_id="tf", tensor=100 * npr.rand(dim_t, dim_f) 155 | ) 156 | mul_tf.new_child(obs_tf_2) 157 | obs_tf_2.new_child(root) 158 | 159 | root.dim_k, root.dim_s, root.dim_t, root.dim_f, root.dim_c = ( 160 | dim_k, 161 | dim_s, 162 | dim_t, 163 | dim_f, 164 | dim_c, 165 | ) 166 | 167 | return root 168 | 169 | 170 | @pytest.fixture( 171 | scope="module", params=["cpu", pytest.param("gpu", marks=pytest.mark.gpu)] 172 | ) 173 | def tree(request): 174 | backend = request.param 175 | wtf.glob.set_backend_processor(backend, force=True) 176 | return make_test_tree() 177 | 178 | 179 | def test_filiation(tree): 180 | with pytest.raises(ValueError, match=r".* cannot be linked several times.*"): 181 | tree.leaf_kf.new_child(tree.mul_tf) 182 | with pytest.raises(ValueError, match=r".* nodes cannot have more than .*"): 183 | tree.leaf_energy.new_child(tree.mul_tf) 184 | assert set(tree.leaf_kf.list_of_children) == set([tree.mul_tfs, tree.mul_tf]) 185 | assert set(tree.mul_tf.list_of_parents) == set([tree.leaf_kf, tree.mul_kt]) 186 | assert tree.leaf_energy.first_child == tree.mul_k 187 | assert tree.mul_k.first_parent == tree.leaf_energy 188 | assert tree.leaf_k_1.has_a_single_child 189 | assert not tree.leaf_kf.has_a_single_child 190 | 191 | 192 | def test_level(tree): 193 | assert tree.leaf_energy_rate.level == 0 194 | assert tree.leaf_energy_shape.level == 0 195 | assert tree.leaf_energy.level == 1 196 | assert tree.leaf_energy.level == 1 197 | assert tree.mult_ktsc.level == 4 198 | assert tree.obs_tf_2.level == 4 199 | assert tree.obs_tf.level == 7 200 | assert tree.root.level == 8 201 | 202 | 203 | def test_census(tree): 204 | all_nodes = set( 205 | [ 206 | tree.leaf_energy, 207 | tree.leaf_energy_rate, 208 | tree.leaf_energy_shape, 209 | tree.leaf_k, 210 | tree.leaf_k_shape, 211 | tree.leaf_k_1, 212 | tree.leaf_k_1_rate, 213 | tree.leaf_k_1_shape, 214 | tree.leaf_kf, 215 | tree.leaf_kf_shape, 216 | tree.leaf_kt, 217 | tree.leaf_kt_shape, 218 | tree.leaf_kts_0, 219 | tree.leaf_kts_0_shape, 220 | tree.leaf_kts_1, 221 | tree.leaf_kts_1_rate, 222 | tree.leaf_kts_1_shape, 223 | tree.leaf_c, 224 | tree.mul_k, 225 | tree.mul_kts, 226 | tree.mul_kts_1, 227 | tree.mul_tf, 228 | tree.mul_tfs, 229 | tree.mul_kt, 230 | tree.mult_ktsc, 231 | tree.obs_tf, 232 | tree.obs_tf_2, 233 | tree.root, 234 | ] 235 | ) 236 | assert all_nodes == tree.root.census() 237 | 238 | 239 | def test_get_tensor(tree): 240 | tree.leaf_c._set_inference_mode(mode="EM") 241 | tree.leaf_c._initialization() 242 | tensor1 = tree.leaf_c.get_tensor(force_numpy=True) 243 | assert np.allclose(tensor1, 1.0) 244 | tensor2 = tree.leaf_c.get_tensor_for_children(tree.mul_kts_1, force_numpy=True) 245 | assert np.allclose(tensor2, 1.0) 246 | 247 | 248 | def test_message_passing(tree): 249 | def set_foo(iteration_number=None, mode="top-down"): 250 | tree.root.tree_traversal( 251 | "__setattr__", 252 | mode=mode, 253 | method_input=(("foo", iteration_number), {}), 254 | iteration_number=iteration_number, 255 | ) 256 | 257 | 
all_nodes = tree.root.census() 258 | all_nodes_that_always_update = all_nodes.difference( 259 | set( 260 | [ 261 | tree.leaf_c, 262 | tree.leaf_kf, 263 | tree.leaf_kf_shape, 264 | tree.leaf_kt, 265 | tree.leaf_kt_shape, 266 | ] 267 | ) 268 | ) 269 | set_foo(None) 270 | assert all(node.foo == None for node in all_nodes) 271 | # pylint: disable=no-member 272 | set_foo(0) 273 | assert all(node.foo == 0 for node in all_nodes_that_always_update) 274 | assert tree.leaf_c.foo is None 275 | assert tree.leaf_kf.foo == 0 276 | assert tree.leaf_kf_shape.foo == 0 277 | assert tree.leaf_kt.foo is None 278 | assert tree.leaf_kt_shape.foo is None 279 | set_foo(1) 280 | assert all(node.foo == 1 for node in all_nodes_that_always_update) 281 | assert tree.leaf_c.foo is None 282 | assert tree.leaf_kf.foo == 1 283 | assert tree.leaf_kf_shape.foo == 1 284 | assert tree.leaf_kt.foo is None 285 | assert tree.leaf_kt_shape.foo is None 286 | 287 | set_foo(2) 288 | assert all(node.foo == 2 for node in all_nodes_that_always_update) 289 | assert tree.leaf_c.foo is None 290 | assert tree.leaf_kf.foo == 1 291 | assert tree.leaf_kf_shape.foo == 1 292 | assert tree.leaf_kt.foo == 2 293 | assert tree.leaf_kt_shape.foo == 2 294 | 295 | set_foo(3) 296 | assert all(node.foo == 3 for node in all_nodes_that_always_update) 297 | assert tree.leaf_c.foo is None 298 | assert tree.leaf_kf.foo == 3 299 | assert tree.leaf_kf_shape.foo == 3 300 | assert tree.leaf_kt.foo == 2 301 | assert tree.leaf_kt_shape.foo == 2 302 | # pylint: enable=no-member 303 | 304 | 305 | def test_initialization(tree): 306 | tree.root.tree_traversal( 307 | "_set_inference_mode", 308 | mode="top-down", 309 | method_input=((), dict(mode=tree.root.inference_mode)), 310 | ) 311 | tree.root.tree_traversal( 312 | "_initialization", 313 | mode="top-down", 314 | ) 315 | norm_tensor_k = tree.leaf_kf.get_tensor(force_numpy=True).sum(1) 316 | assert np.allclose(norm_tensor_k, 1.0) 317 | norm_tensor_k = tree.leaf_kts_0.get_tensor(force_numpy=True).sum((1, 2)) 318 | assert np.allclose(norm_tensor_k, 1.0) 319 | assert np.isclose(tree.leaf_k.get_tensor(force_numpy=True).sum(), 1.0) 320 | 321 | assert tree.mul_tfs.get_tensor().shape == (tree.dim_t, tree.dim_f, tree.dim_s) 322 | assert tree.mul_tf.get_tensor().shape == (tree.dim_t, tree.dim_f) 323 | assert tree.mul_kt.get_tensor().shape == (tree.dim_k, tree.dim_t) 324 | assert tree.mult_ktsc.get_tensor().shape == ( 325 | tree.dim_k, 326 | tree.dim_t, 327 | tree.dim_s, 328 | tree.dim_c, 329 | ) 330 | 331 | all_nodes = tree.root.census() 332 | assert all( 333 | node.tensor_update.shape == node.tensor.shape 334 | for node in all_nodes 335 | if getattr(node, "tensor_update", None) is not None 336 | ) 337 | assert tree.leaf_c.tensor_update == None 338 | 339 | 340 | def test_compute_tensor_update(tree): 341 | tree.root.tree_traversal( 342 | "_set_inference_mode", 343 | mode="top-down", 344 | method_input=((), dict(mode=tree.root.inference_mode)), 345 | ) 346 | tree.root.tree_traversal( 347 | "_initialization", 348 | mode="top-down", 349 | ) 350 | tree.root.tree_traversal( 351 | "compute_tensor_update", 352 | mode="bottom-up", 353 | iteration_number=0, 354 | type_filter_list=[ 355 | buds._Bud, 356 | ], 357 | ) 358 | 359 | assert wtf.glob.xp.allclose( 360 | tree.mul_tf.tensor_update, tree.obs_tf_2.tensor / tree.mul_tf.tensor 361 | ) 362 | assert wtf.glob.xp.allclose( 363 | tree.leaf_k_1.tensor_update, 364 | wtf.glob.xp.einsum("kt,kt->k", tree.mul_kt.tensor_update, tree.leaf_kt.tensor), 365 | ) 366 | assert wtf.glob.xp.allclose( 367 
| tree.leaf_energy.tensor_update, 368 | (tree.mul_k.tensor_update * tree.leaf_k.tensor).sum(), 369 | ) 370 | assert wtf.glob.xp.allclose( 371 | tree.leaf_kf.tensor_update, 372 | ( 373 | wtf.glob.xp.einsum( 374 | "tfs,kts->kf", tree.mul_tfs.tensor_update, tree.mul_kts_1.tensor 375 | ) 376 | + wtf.glob.xp.einsum( 377 | "tf,kt->kf", tree.mul_tf.tensor_update, tree.mul_kt.tensor 378 | ) 379 | ), 380 | ) 381 | 382 | 383 | def test_param_estimation(tree): 384 | tree.root.estimate_param(n_iter=100) 385 | cost_func = np.array(tree.root.cost_record) 386 | assert all(cost_func[:-1] >= cost_func[1:]) 387 | 388 | 389 | def test_graphviz(tree): 390 | legend_dict = { 391 | "k": {"description": "atom", "letter": "k"}, 392 | "f": {"description": "frequency"}, 393 | "t": {"description": "time"}, 394 | "s": {"description": "sign"}, 395 | "c": {"description": "complex part"}, 396 | } 397 | assert t_utils._assert_graphviz_ok(tree, legend_dict) 398 | -------------------------------------------------------------------------------- /tests/test_model_check.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright 2020 Benoit Fuentes 3 | # 4 | # This file is part of Wonterfact. 5 | # 6 | # Wonterfact is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # any later version. 10 | # 11 | # Wonterfact is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Wonterfact. If not, see . 
18 | # ---------------------------------------------------------------------------- 19 | 20 | """Tests for the model validation feature of wonterfact""" 21 | 22 | # Python standard library 23 | 24 | # Third-party imports 25 | import numpy as np 26 | import numpy.random as npr 27 | import pytest 28 | 29 | # wonterfact imports 30 | import wonterfact as wtf 31 | import wonterfact.core_nodes as wtfc 32 | import wonterfact.operators as wtfo 33 | 34 | 35 | def test_get_norm_for_children(): 36 | parent = wtfc._DynNodeData() 37 | parent.tensor_has_energy = False 38 | child = wtfo._Operator() 39 | sl = slice(None) 40 | # list of (parent_shape, parent_norm_axis, slice_for_child, shape_for_child, norm_axis_for_child) 41 | input_output_list = [ 42 | ((2,) * 7, (1, 3, 4), (0, sl, 1, sl, sl, 1, sl), (2, 2, 2, 2), (0, 1, 2)), 43 | ((2,) * 7, (1, 3, 4), (0, sl, 1, sl, sl, 1, slice(1)), (2, 2, 2, 1), (0, 1, 2)), 44 | ((2,) * 7, (1, 3, 4), (0, sl, 1, sl, sl, 1, 0), (2, 2, 2), (0, 1, 2)), 45 | ((2,) * 7, (1, 3, 4), (0, sl, 1, sl, sl, 1, sl), (2, 4, 2), (0, 1)), 46 | ((2,) * 7, (1, 3, 4), (0, sl, 1, sl, sl, sl, sl), (2, 4, 4), (0, 1)), 47 | ((2,) * 7, (4,), (0, sl, 1, sl, sl, 1, sl), (4, 2, 2), (1,)), 48 | ((2, 3, 4, 2), (1, 2), (0, Ellipsis), (2, 6, 2), (0, 1)), 49 | ((2, 3, 4, 1), (1, 2), (0, Ellipsis, 0), (2, 6), (0, 1)), 50 | ((2,), (), Ellipsis, None, ()), 51 | ((2, 2, 2), (1,), (slice(1),), (1, 2, 2), (1,)), 52 | ] 53 | 54 | def setup_nodes(parent_shape, parent_norm_axis, slice_for_child, shape_for_child): 55 | parent.remove_child(child) 56 | parent.tensor = npr.rand(*parent_shape) 57 | parent.norm_axis = parent_norm_axis 58 | parent.new_child( 59 | child, slice_for_child=slice_for_child, shape_for_child=shape_for_child 60 | ) 61 | 62 | for elem in input_output_list: 63 | setup_nodes(*elem[:-1]) 64 | assert parent.get_norm_axis_for_children(child) == elem[-1] 65 | 66 | # expected errors 67 | # list of (parent_shape, parent_norm_axis, slice_for_child, shape_for_child, expected error) 68 | input_output_list = [ 69 | ((2, 2, 2), (2,), Ellipsis, (2, 4), "Invalid shape_for_child .*"), 70 | ((2, 2, 2), (1,), Ellipsis, (4, 2), "Invalid shape_for_child .*"), 71 | ( 72 | (2, 2, 2), 73 | (1,), 74 | (slice(1),), 75 | (2, 2), 76 | ".* 1-size dimension of the sliced tensor.*", 77 | ), 78 | ( 79 | (2, 2, 2), 80 | (2,), 81 | Ellipsis, 82 | (1, 4, 2), 83 | ".* 1-size dimensions in `shape_for_child`.*", 84 | ), 85 | ] 86 | for elem in input_output_list: 87 | setup_nodes(*elem[:-1]) 88 | with pytest.raises(ValueError, match=elem[-1]): 89 | parent.get_norm_axis_for_children(child) 90 | 91 | 92 | def test_check_model_validity_dynnodedata(): 93 | parent = wtfc._DynNodeData(index_id="d", tensor=npr.rand(2)) 94 | parent.tensor_has_energy = False 95 | child1 = wtfo._Operator(index_id="d", tensor=npr.rand(1)) 96 | parent.new_child(child1, slice_for_child=slice(0, 1)) 97 | with pytest.raises( 98 | ValueError, 99 | match=".*Each coefficient of inner tensor should be seen by at least", 100 | ): 101 | parent._check_model_validity() 102 | 103 | parent = wtfc._DynNodeData(index_id="td", tensor=npr.rand(2, 2)) 104 | parent.tensor_has_energy = False 105 | parent.norm_axis = (1,) 106 | child1 = wtfo._Operator(index_id="td", tensor=npr.rand(2, 1)) 107 | parent.new_child(child1, slice_for_child=(..., slice(0, 1))) 108 | child2 = wtfo._Operator(index_id="td", tensor=npr.rand(2, 1)) 109 | parent.new_child(child2, slice_for_child=(..., slice(1, 2))) 110 | with pytest.raises( 111 | ValueError, match=".*its children should not see an 
incomplete piece of" 112 | ): 113 | parent._check_model_validity() 114 | 115 | parent = wtfc._DynNodeData(index_id="td", tensor=npr.rand(2, 2)) 116 | parent.tensor_has_energy = False 117 | parent.norm_axis = (1,) 118 | child1 = wtfo._Operator(index_id="t", tensor=npr.rand(1)) 119 | parent.new_child(child1, slice_for_child=(..., 0)) 120 | child2 = wtfo._Operator(index_id="td", tensor=npr.rand(1)) 121 | parent.new_child(child2, slice_for_child=(..., [1,])) 122 | with pytest.raises( 123 | ValueError, match=".*its children should not see an incomplete piece of" 124 | ): 125 | parent._check_model_validity() 126 | 127 | parent = wtfc._DynNodeData(index_id="d", tensor=npr.rand(3)) 128 | parent.tensor_has_energy = True 129 | child1 = wtfo._Operator(index_id="d", tensor=npr.rand(2)) 130 | parent.new_child(child1, slice_for_child=slice(0, 2)) 131 | child2 = wtfo._Operator(index_id="d", tensor=npr.rand(1)) 132 | parent.new_child(child2, slice_for_child=slice(0, 1)) 133 | with pytest.raises( 134 | ValueError, 135 | match=".*Each coefficient of inner tensor should be seen by at most", 136 | ): 137 | parent._check_model_validity() 138 | 139 | 140 | def test_norm_axis_and_check_model_validity_multiplier(): 141 | tensor1 = npr.rand(2, 2) 142 | tensor2 = npr.rand(2, 2) 143 | 144 | # list of (has_energy1, has_energy2, index_id1, index_id2, index_id_child, norm_axis1, norm_axis2, norm_axis_child) 145 | input_output_list = [ 146 | (True, False, "td", "df", "tf", None, (1,), None), 147 | (False, False, "td", "df", "tf", (0, 1), (1,), (0, 1)), 148 | (False, False, "td", "df", "tf", (1,), (1,), (1,)), 149 | (False, False, "td", "df", "tf", (0,), (1, 1), (0, 1)), 150 | ] 151 | 152 | def setup( 153 | has_energy1, 154 | has_energy2, 155 | index_id1, 156 | index_id2, 157 | index_id_child, 158 | norm_axis1, 159 | norm_axis2, 160 | ): 161 | parent1 = wtfc._DynNodeData(tensor=tensor1, index_id=index_id1) 162 | parent2 = wtfc._DynNodeData(tensor=tensor2, index_id=index_id2) 163 | child = wtfo.Multiplier( 164 | index_id=index_id_child, tensor=npr.rand(*(2,) * len(index_id_child)) 165 | ) 166 | obs = wtf.PosObserver(tensor=child.tensor) 167 | parent1.tensor_has_energy = has_energy1 168 | parent2.tensor_has_energy = has_energy2 169 | parent1.norm_axis = norm_axis1 170 | parent2.norm_axis = norm_axis2 171 | child.new_parents(parent1, parent2) 172 | child.new_child(obs) 173 | return child 174 | 175 | for elem in input_output_list: 176 | child = setup(*elem[:-1]) 177 | assert child.norm_axis == elem[-1] 178 | 179 | # list of (has_energy1, has_energy2, index_id1, index_id2, index_id_child, norm_axis1, norm_axis2, expected_error) 180 | input_output_list = [ 181 | (True, True, "td", "df", "tf", None, None, ".*At most one parent can have.*"), 182 | ( 183 | False, 184 | False, 185 | "td", 186 | "df", 187 | "tf", 188 | (1,), 189 | (0,), 190 | ".*An index_id cannot be normalized twice.*", 191 | ), 192 | ( 193 | True, 194 | False, 195 | "t", 196 | "df", 197 | "tdf", 198 | None, 199 | (0,), 200 | ".*before multiplication with a tensor that has energy.*", 201 | ), 202 | ( 203 | False, 204 | False, 205 | "td", 206 | "df", 207 | "tf", 208 | (0,), 209 | (1,), 210 | ".*should be normalized before marginalization.*", 211 | ), 212 | ] 213 | for elem in input_output_list: 214 | child = setup(*elem[:-1]) 215 | with pytest.raises(ValueError, match=elem[-1]): 216 | child._check_model_validity() 217 | 218 | # The followings should work 219 | child = setup(True, False, "td", "df", "tf", None, (1,)) 220 | assert child._check_model_validity() is 
None 221 | child = setup(False, False, "td", "df", "tf", (0, 1), (1,)) 222 | assert child._check_model_validity() is None 223 | 224 | 225 | def test_check_model_validity_convolver(): 226 | tensor1 = npr.rand(4, 4, 2) 227 | tensor2 = npr.rand(2, 2, 2) 228 | tensor_child = npr.rand(3, 3, 2) 229 | 230 | def setup( 231 | has_energy1, 232 | has_energy2, 233 | index_id1, 234 | index_id2, 235 | index_id_child, 236 | norm_axis1, 237 | norm_axis2, 238 | conv_idx_ids, 239 | ): 240 | parent1 = wtfc._DynNodeData(tensor=tensor1, index_id=index_id1) 241 | parent1.tensor_has_energy = has_energy1 242 | parent1.norm_axis = norm_axis1 243 | parent2 = wtfc._DynNodeData(tensor=tensor2, index_id=index_id2) 244 | parent2.tensor_has_energy = has_energy2 245 | parent2.norm_axis = norm_axis2 246 | child = wtfo.Multiplier( 247 | index_id=index_id_child, tensor=tensor_child, conv_idx_ids=conv_idx_ids 248 | ) 249 | child.new_parents(parent1, parent2) 250 | obs = wtf.PosObserver(tensor=child.tensor) 251 | obs.new_parent(child) 252 | return child 253 | 254 | # list of (has_energy1, has_energy2, index_id1, index_id2, index_id_child,... 255 | # ... norm_axis1, norm_axis2, conv_idx_ids, expected_error) 256 | input_output_list = [ 257 | ( 258 | False, 259 | False, 260 | "ftd", 261 | "ftd", 262 | "ftd", 263 | (0, 1, 2), 264 | (0, 1), 265 | ("f", "t"), 266 | ".*inner tensor of a Multiplier must have energy.*", 267 | ), 268 | ( 269 | True, 270 | False, 271 | "fid", 272 | "ftd", 273 | "ftd", 274 | None, 275 | (0, 1), 276 | ("f", "t"), 277 | ".*This index should belong to two parents.*", 278 | ), 279 | ( 280 | True, 281 | False, 282 | "ftd", 283 | "ftd", 284 | "ftd", 285 | None, 286 | (0,), 287 | ("f", "t"), 288 | ".*Axis to be convolved should be normalized.*", 289 | ), 290 | ] 291 | for elem in input_output_list: 292 | child = setup(*elem[:-1]) 293 | with pytest.raises(ValueError, match=elem[-1]): 294 | child._check_model_validity() 295 | 296 | # should not raise any error 297 | child = setup(True, False, "ftd", "ftd", "ftd", None, (0, 1), ("f", "t")) 298 | assert child._check_model_validity() is None 299 | 300 | 301 | def test_check_model_validity_multiplexer(): 302 | def setup(has_energy1, has_energy2, norm_axis1, norm_axis2, concatenate): 303 | parent1 = wtfc._DynNodeData(index_id="f", tensor=npr.rand(2)) 304 | parent1.tensor_has_energy = has_energy1 305 | parent1.norm_axis = norm_axis1 306 | parent2 = wtfc._DynNodeData(index_id="f", tensor=npr.rand(2)) 307 | parent2.tensor_has_energy = has_energy2 308 | parent2.norm_axis = norm_axis2 309 | index_id_child = "f" if concatenate else "fd" 310 | multiplexer_idx = "f" if concatenate else None 311 | tensor_child = npr.rand(4) if concatenate else npr.rand(2, 2) 312 | multiplexer = wtf.Multiplexer( 313 | index_id=index_id_child, 314 | tensor=tensor_child, 315 | multiplexer_idx=multiplexer_idx, 316 | ) 317 | multiplexer.new_parents(parent1, parent2) 318 | obs = wtf.PosObserver(index_id=index_id_child, tensor=tensor_child) 319 | obs.new_parent(multiplexer) 320 | return multiplexer 321 | 322 | with pytest.raises(ValueError, match=".*Either all the parents' tensor.*"): 323 | child = setup(False, True, (0,), None, True) 324 | child._check_model_validity() 325 | with pytest.raises(ValueError, match=".*When `multiplexer_idx` is provided.*"): 326 | child = setup(False, False, (0,), (0,), True) 327 | child._check_model_validity() 328 | 329 | child = setup(False, False, (0,), (0,), False) 330 | assert child._check_model_validity() is None 331 | child = setup(True, True, (0,), (0,), False) 
332 | assert child._check_model_validity() is None 333 | child = setup(True, True, None, None, True) 334 | assert child._check_model_validity() is None 335 | 336 | 337 | def test_check_model_validity_integrator(): 338 | parent = wtfc._DynNodeData(index_id="dt", tensor=npr.rand(2, 4)) 339 | parent.tensor_has_energy = False 340 | parent.norm_axis = (0,) 341 | child = wtf.Integrator(index_id="dt", tensor=npr.rand(2, 4)) 342 | obs = wtf.PosObserver(index_id="dt", tensor=npr.rand(2, 4)) 343 | parent.new_child(child) 344 | child.new_child(obs) 345 | with pytest.raises(ValueError, match=".*or its last axis must be normalized.*"): 346 | child._check_model_validity() 347 | 348 | 349 | def test_check_model_validity_adder(): 350 | parent1 = wtfc._DynNodeData(index_id="ft", tensor=npr.rand(2, 2)) 351 | parent1.tensor_has_energy = False 352 | parent1.norm_axis = (0, 1) 353 | parent2 = wtfc._DynNodeData(index_id="ft", tensor=npr.rand(2, 2)) 354 | parent2.tensor_has_energy = True 355 | child = wtf.Adder(index_id="ft", tensor=npr.rand(2, 2)) 356 | child.new_parents(parent1, parent2) 357 | obs = wtf.PosObserver(index_id="ft", tensor=npr.rand(2, 2)) 358 | child.new_child(obs) 359 | with pytest.raises(ValueError, match=".*should all have energy.*"): 360 | child._check_model_validity() 361 | 362 | -------------------------------------------------------------------------------- /wonterfact/examples/nmf.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright 2020 Benoit Fuentes 3 | # 4 | # This file is part of Wonterfact. 5 | # 6 | # Wonterfact is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # any later version. 10 | # 11 | # Wonterfact is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Wonterfact. If not, see . 
18 | # ---------------------------------------------------------------------------- 19 | 20 | """Examples of nonnegative models for nonnegative observed tensor""" 21 | 22 | 23 | # Python standard library 24 | 25 | # Third-party imports 26 | import numpy as np 27 | import numpy.random as npr 28 | 29 | # wonterfact imports 30 | import wonterfact as wtf 31 | import wonterfact.utils as wtfu 32 | 33 | 34 | def make_nmf(fix_atoms=False): 35 | dim_k, dim_f, dim_t = 5, 20, 100 36 | 37 | atoms_kf = npr.dirichlet(np.ones(dim_f) * 0.9, size=dim_k) 38 | 39 | activations_tk = npr.gamma(shape=0.6, scale=200, size=(dim_t, dim_k)) 40 | 41 | observations_tf = np.einsum("tk,kf->tf", activations_tk, atoms_kf) 42 | observations_tf += npr.randn(dim_t, dim_f) * 1e-4 43 | 44 | leaf_kf = wtf.LeafDirichlet( 45 | name="atoms", 46 | index_id="kf", 47 | norm_axis=(1,), 48 | tensor=atoms_kf if fix_atoms else wtfu.normalize(npr.rand(dim_k, dim_f), (1,)), 49 | update_period=0 if fix_atoms else 1, 50 | prior_shape=1, 51 | ) 52 | leaf_tk = wtf.LeafGamma( 53 | name="activations", 54 | index_id="tk", 55 | tensor=np.ones_like(activations_tk), 56 | prior_rate=1e-5, 57 | prior_shape=1, 58 | ) 59 | mul_tf = wtf.Multiplier(name="multiplier", index_id="tf") 60 | mul_tf.new_parents(leaf_kf, leaf_tk) 61 | obs_tf = wtf.PosObserver(name="observer", index_id="tf", tensor=observations_tf) 62 | mul_tf.new_child(obs_tf) 63 | root = wtf.Root(name="nmf") 64 | obs_tf.new_child(root) 65 | 66 | return root 67 | 68 | 69 | def _aux_smooth_activation_nmf(): 70 | dim_k, dim_f, dim_t = 5, 20, 100 71 | dim_w = 9 72 | 73 | atoms_kf = npr.dirichlet(np.ones(dim_f) * 0.9, size=dim_k) 74 | 75 | activations_tk = npr.gamma(shape=0.6, scale=200, size=(dim_t + dim_w - 1, dim_k)) 76 | spread_tkw = npr.dirichlet(3 * np.ones(dim_w), size=(dim_t + dim_w - 1, dim_k)) 77 | act_tkw = np.einsum("tk,tkw->tkw", activations_tk, spread_tkw) 78 | activations_tk = np.zeros((dim_t, dim_k)) 79 | for ww in range(dim_w): 80 | activations_tk += act_tkw[ww : dim_t + ww, :, ww] 81 | 82 | observations_tf = np.einsum("tk,kf->tf", activations_tk, atoms_kf) 83 | observations_tf += npr.randn(dim_t, dim_f) * 1e-4 84 | 85 | leaf_kf = wtf.LeafDirichlet( 86 | name="atoms", 87 | index_id="kf", 88 | norm_axis=(1,), 89 | tensor=wtfu.normalize(npr.rand(dim_k, dim_f), (1,)), 90 | prior_shape=1, 91 | ) 92 | 93 | leaf_kt = wtf.LeafGamma( 94 | name="impulse", 95 | index_id="kt", 96 | tensor=np.ones((dim_k, dim_t + dim_w - 1)), 97 | prior_shape=1, 98 | prior_rate=1e-5, 99 | ) 100 | leaf_ktw = wtf.LeafDirichlet( 101 | name="spreading", 102 | index_id="ktw", 103 | norm_axis=(2,), 104 | tensor=np.ones((dim_k, dim_t + dim_w - 1, dim_w)) / dim_w, 105 | prior_shape=10, 106 | ) 107 | mul_kwt = wtf.Multiplier(name="mul_kwt", index_id="kwt") 108 | mul_kwt.new_parents(leaf_kt, leaf_ktw) 109 | 110 | mul_tf = wtf.Multiplier(name="reconstruction", index_id="tf") 111 | mul_tf.new_parent(leaf_kf) 112 | obs_tf = wtf.PosObserver(name="observer", index_id="tf", tensor=observations_tf) 113 | mul_tf.new_child(obs_tf) 114 | 115 | return leaf_kt, leaf_ktw, mul_kwt, mul_tf, obs_tf 116 | 117 | 118 | def make_smooth_activation_nmf(): 119 | """ 120 | NMF with overlapping activation (each time coefficient spreads out on 121 | both sides). 
We use this model in order to test strides_for_child
122 |     """
123 |     _, leaf_ktw, mul_kwt, mul_tf, obs_tf = _aux_smooth_activation_nmf()
124 | 
125 |     dim_k, _, dim_w = leaf_ktw.tensor.shape
126 |     dim_t = obs_tf.tensor.shape[0]
127 | 
128 |     leaf_w = wtf.LeafDirichlet(
129 |         name="adder", index_id="w", norm_axis=(), tensor=np.ones(dim_w), update_period=0
130 |     )
131 |     mul_tk = wtf.Multiplier(name="activations", index_id="tk")
132 |     shape = (dim_k, dim_w, dim_t + dim_w - 1)  # shape of the parent (mul_kwt) tensor
133 |     strides_for_child = (  # C-contiguous byte strides of `shape` (8 bytes per float64)
134 |         [
135 |             8,
136 |         ]
137 |         + list(np.cumprod(shape[:0:-1]) * 8)
138 |     )[::-1]
139 |     strides_for_child[-2] += 8  # each w-slice starts one t-step later: child[k, w, t] views parent[k, w, t + w]
140 |     mul_kwt.new_child(
141 |         mul_tk,
142 |         shape_for_child=(dim_k, dim_w, dim_t),
143 |         strides_for_child=strides_for_child,
144 |     )
145 |     leaf_w.new_child(mul_tk)
146 | 
147 |     mul_tf.new_parent(mul_tk)
148 |     root = wtf.Root(name="smooth_activation_nmf")
149 |     obs_tf.new_child(root)
150 | 
151 |     return root
152 | 
153 | 
154 | def make_smooth_activation_nmf2():
155 |     """
156 |     Same model as in the `make_smooth_activation_nmf` method, but with Proxys
157 |     and Adder instead of strides
158 |     """
159 |     _, leaf_ktw, mul_kwt, mul_tf, obs_tf = _aux_smooth_activation_nmf()
160 | 
161 |     dim_w = leaf_ktw.tensor.shape[2]
162 |     dim_t = obs_tf.tensor.shape[0]
163 | 
164 |     add_kt = wtf.Adder(name="activations", index_id="kt")
165 |     for ww in range(dim_w):
166 |         proxy = wtf.Proxy(index_id="kt")
167 |         proxy.new_parent(
168 |             mul_kwt,
169 |             slice_for_child=(slice(None), ww, slice(ww, ww + dim_t)),
170 |             index_id_for_child="kt",
171 |         )
172 |         proxy.new_child(add_kt)
173 |     mul_tf.new_parent(add_kt)
174 |     root = wtf.Root(name="smooth_activation_nmf2")
175 |     obs_tf.new_child(root)
176 | 
177 |     return root
178 | 
179 | 
180 | def make_sparse_nmf(prior_rate=0.001, obs=None, atoms=None):
181 |     """
182 |     NMF with minimization of \\sum_{k != k'} P(k|t)P(k'|t)E(t) where P(k|t) are
183 |     the activations and E(t) is the total energy at time t
184 |     """
185 |     dim_f, dim_t, dim_k = 2, 100, 2
186 | 
187 |     # gt_kf = npr.dirichlet(np.ones(dim_f), size=dim_k)
188 |     gt_kf = np.array([[4.0, 1.0], [4.0, 3.0]])
189 |     gt_kf /= gt_kf.sum(1, keepdims=True)
190 |     gt_tk = npr.gamma(shape=0.3, scale=100, size=(dim_t, dim_k))
191 |     gt_tf = np.dot(gt_tk, gt_kf)
192 |     gt_tf += npr.rand(dim_t, dim_f)
193 |     if obs is not None:
194 |         gt_tf = obs
195 | 
196 |     leaf_t = wtf.LeafGamma(
197 |         name="time_energy",
198 |         index_id="t",
199 |         tensor=np.ones(dim_t),
200 |         prior_rate=prior_rate,
201 |         prior_shape=1,
202 |     )
203 | 
204 |     leaf_tk = wtf.LeafDirichlet(
205 |         name="activations",
206 |         index_id="tk",
207 |         norm_axis=(1,),
208 |         tensor=np.ones((dim_t, dim_k)) / dim_k,
209 |         prior_shape=1,
210 |     )
211 | 
212 |     mul_tk = wtf.Multiplier(index_id="tk")
213 |     mul_tk.new_parents(leaf_t, leaf_tk)
214 | 
215 |     mul_tkl = wtf.Multiplier(name="activations_square", index_id="tkl")
216 |     leaf_tk.new_child(mul_tkl, index_id_for_child="tl")
217 |     mul_tk.new_child(mul_tkl)
218 |     if atoms is None:
219 |         atoms = np.ones((dim_k, dim_f)) / dim_f
220 |         update_period = 1
221 |     else:
222 |         update_period = 0
223 |     leaf_kf = wtf.LeafDirichlet(
224 |         name="atoms",
225 |         index_id="kf",
226 |         norm_axis=(1,),
227 |         tensor=atoms,
228 |         prior_shape=1,
229 |         update_period=update_period,
230 |     )
231 |     mul_tf = wtf.Multiplier(name="reconstruction", index_id="tf")
232 |     leaf_kf.new_child(mul_tf)
233 | 
234 |     test_arr = npr.rand(2, dim_k, dim_k)
235 |     strides = (test_arr.strides[0],) + np.diag(test_arr[0]).strides  # strides of a diagonal view: child[t, k] = parent[t, k, k]
236 |     mul_tkl.new_child(
237 |         mul_tf,
238 |         shape_for_child=(dim_t, dim_k),
239 | strides_for_child=strides, 240 | index_id_for_child="tk", 241 | ) 242 | 243 | obs_tf = wtf.PosObserver(name="observations", index_id="tf", tensor=gt_tf) 244 | mul_tf.new_child(obs_tf) 245 | root = wtf.Root(name="root", verbose_iter=50, cost_computation_iter=10) 246 | obs_tf.new_child(root) 247 | return root 248 | 249 | 250 | def make_sparse_nmf2(prior_rate=0.001, obs=None): 251 | """ 252 | NMF with l1/l2 sparse constraint on atoms 253 | """ 254 | dim_f, dim_t, dim_k = 2, 100, 2 255 | 256 | # gt_kf = npr.dirichlet(np.ones(dim_f), size=dim_k) 257 | gt_kf = np.array([[4.0, 1.0], [4.0, 3.0]]) 258 | gt_kf /= gt_kf.sum(1, keepdims=True) 259 | gt_tk = npr.gamma(shape=0.3, scale=100, size=(dim_t, dim_k)) 260 | gt_tf = np.dot(gt_tk, gt_kf) 261 | gt_tf += npr.rand(dim_t, dim_f) 262 | if obs is not None: 263 | gt_tf = obs 264 | 265 | leaf_kf = wtf.LeafGammaNorm( 266 | name="atoms", 267 | index_id="kf", 268 | tensor=np.ones((dim_k, dim_f)), 269 | l2_norm_axis=(1,), 270 | prior_rate=prior_rate, 271 | prior_shape=1 + 1e-4 * npr.rand(dim_k, dim_k), 272 | ) 273 | 274 | leaf_kt = wtf.LeafDirichlet( 275 | name="activations", 276 | index_id="kt", 277 | norm_axis=(1,), 278 | tensor=np.ones((dim_k, dim_t)) / dim_t, 279 | prior_shape=1, 280 | ) 281 | 282 | mul_tf = wtf.Multiplier(name="reconstruction", index_id="tf") 283 | mul_tf.new_parents(leaf_kt, leaf_kf) 284 | 285 | obs_tf = wtf.PosObserver(name="observations", index_id="tf", tensor=gt_tf) 286 | mul_tf.new_child(obs_tf) 287 | root = wtf.Root(name="root", verbose_iter=50, cost_computation_iter=10) 288 | obs_tf.new_child(root) 289 | return root 290 | 291 | 292 | def make_sparse_nmf3(prior_rate=0.001, obs=None): 293 | """ 294 | NMF with approximation of l2 norm for atoms 295 | """ 296 | dim_f, dim_t, dim_k, dim_a = 2, 100, 2, 2 297 | 298 | gt_kf = npr.dirichlet(np.ones(dim_f), size=dim_k) 299 | # gt_kf = np.array([[4.0, 1.0], [4.0, 3.0]]) 300 | gt_kf /= gt_kf.sum(1, keepdims=True) 301 | gt_tk = npr.gamma(shape=0.3, scale=100, size=(dim_t, dim_k)) 302 | gt_tf = np.dot(gt_tk, gt_kf) 303 | gt_tf += npr.rand(dim_t, dim_f) 304 | if obs is not None: 305 | gt_tf = obs 306 | 307 | leaf_k = wtf.LeafGamma( 308 | name="atoms_energy", 309 | index_id="k", 310 | tensor=np.ones((dim_k)), 311 | prior_shape=1, 312 | prior_rate=prior_rate, 313 | ) 314 | 315 | leaf_kf = wtf.LeafDirichlet( 316 | name="atoms_init", 317 | index_id="kf", 318 | norm_axis=(1,), 319 | tensor=wtfu.normalize(npr.rand(dim_k, dim_f), (1,)), 320 | prior_shape=1, 321 | ) 322 | mul_kf = wtf.Multiplier(index_id="kf") 323 | mul_kf.new_parents(leaf_kf, leaf_k) 324 | mul_kfg = wtf.Multiplier(index_id="kfg") 325 | mul_kf.new_child(mul_kfg) 326 | leaf_kf.new_child(mul_kfg, index_id_for_child="kg") 327 | 328 | leaf_c = wtf.LeafDirichlet( 329 | index_id="c", norm_axis=(0,), tensor=np.array([0.5, 0.5]), update_period=0 330 | ) 331 | mul_ckf = wtf.Multiplier(index_id="ckf") 332 | leaf_c.new_child(mul_ckf) 333 | test_arr = npr.rand(2, dim_f, dim_f) 334 | strides = (test_arr.strides[0],) + np.diag(test_arr[0]).strides 335 | mul_kfg.new_child( 336 | mul_ckf, 337 | index_id_for_child="kf", 338 | shape_for_child=(dim_k, dim_f), 339 | strides_for_child=strides, 340 | ) 341 | 342 | leaf_g = wtf.LeafDirichlet( 343 | index_id="g", norm_axis=(), tensor=np.ones(dim_f - 1), update_period=0 344 | ) 345 | mul_kf2 = wtf.Multiplier(index_id="kf") 346 | leaf_g.new_child(mul_kf2) 347 | mask = np.logical_not(np.eye(dim_f, dtype=bool)) 348 | mul_kfg.new_child( 349 | mul_kf2, 350 | slice_for_child=[slice(None), mask], 351 | 
shape_for_child=(dim_k, dim_f, dim_f - 1), 352 | ) 353 | 354 | add_kf = wtf.Adder(name="atoms", index_id="kf") 355 | mul_kf2.new_child(add_kf) 356 | mul_ckf.new_child(add_kf, index_id_for_child="kf", slice_for_child=(0, Ellipsis)) 357 | 358 | # leaf_ka = wtf.LeafDirichlet( 359 | # name="angle", 360 | # index_id="ka", 361 | # norm_axis=(1,), 362 | # tensor=np.ones((dim_k, dim_a)), 363 | # l2_norm_axis=(1,), 364 | # prior_shape=1 + 1e-4 * npr.rand(dim_k, dim_a), 365 | # ) 366 | 367 | # mul_ka = wtf.Multiplier(index_id="ka") 368 | # mul_ka.new_parents(leaf_k, leaf_ka) 369 | 370 | # mul_kab = wtf.Multiplier(index_id="kab") 371 | # mul_ka.new_child(mul_kab) 372 | # leaf_ka.new_child(mul_kab, index_id_for_child="kb") 373 | 374 | # leaf_abm = wtf.LeafDirichlet( 375 | # index_id="abm", 376 | # norm_axis=(2,), 377 | # tensor=np.array([[[1, 0, 0], [0, 1, 0]], [[0, 1, 0], [0, 0, 1]]]), 378 | # update_period=0, 379 | # ) 380 | 381 | # mul_km = wtf.Multiplier(index_id="km") 382 | # mul_km.new_parents(leaf_abm, mul_kab) 383 | 384 | # leaf_mnf = wtf.LeafDirichlet( 385 | # name="basis", 386 | # index_id="mnf", 387 | # norm_axis=(1, 2), 388 | # tensor=np.array( 389 | # [[[0, 0.5], [0.5, 0]], [[0.5, 0.5], [0, 0]], [[0.5, 0], [0, 0.5]]] 390 | # ), 391 | # update_period=0, 392 | # ) 393 | 394 | # mul_nkf = wtf.Multiplier(index_id="nkf", name="atoms") 395 | # mul_nkf.new_parents(leaf_mnf, mul_km) 396 | 397 | leaf_kt = wtf.LeafDirichlet( 398 | name="activations", 399 | index_id="kt", 400 | norm_axis=(1,), 401 | tensor=np.ones((dim_k, dim_t)) / dim_t, 402 | prior_shape=1, 403 | ) 404 | 405 | mul_tf = wtf.Multiplier(name="reconstruction", index_id="tf") 406 | leaf_kt.new_child(mul_tf) 407 | # mul_nkf.new_child(mul_tf, index_id_for_child="kf", slice_for_child=[0, Ellipsis]) 408 | add_kf.new_child(mul_tf) 409 | 410 | obs_tf = wtf.PosObserver(name="observations", index_id="tf", tensor=gt_tf) 411 | mul_tf.new_child(obs_tf) 412 | root = wtf.Root(name="root", verbose_iter=50, cost_computation_iter=10) 413 | obs_tf.new_child(root) 414 | return root 415 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WONTERFACT 2 | 3 | ## Overview 4 | Wonterfact, or WONderful TEnsoR FACTorization, is a python package which provides a powerful tool to design any tensor factorization model and estimate the corresponding parameters. 5 | 6 | This project has been initiated from 2015 to 2019 by Benoit Fuentes as a researcher in [Smart Impulse](https://www.smart-impulse.com/) R&D team. 7 | 8 | ## Features 9 | - It is a generalization of coupled tensor factorization, allowing to design an infinite number of models. 10 | - Many kind of operations can be used to design models: tensor multiplication, addition, integration, convolution, concatenation, etc. 11 | - Tensors to be factorized (i.e. observed tensors) can be either nonnegative or real-valued. 12 | - Each factor tensor of the model can be either nonnegative or real-valued, allowing to design semi-nonnegative models. 13 | - It is based on solid probabilistic framework: Poisson-based for nonnegative observed tensors or Skellam-based for real valued observed tensors. 14 | - Smooth constraints can be added to factors via prior distributions. 15 | - Hard linear equality or inequality constraints can be applied on factor tensors. 
16 | - Two inference algorithms have been implemented: the Expectation-Maximization (EM) algorithm, to find the posterior maximum of the parameters, and the Variational Bayes EM (VBEM) algorithm, to find an approximation of the posterior distributions of the parameters.
17 | - In VBEM mode, hyperparameters of the prior distributions can also be inferred.
18 | - In VBEM mode, wonterfact is a generalization of Latent Dirichlet Allocation (LDA).
19 | - Possibility to use either a cpu or a gpu backend.
20 | 
21 | 
22 | ## References
23 | For the moment, no article has been published to introduce this package and describe the theoretical background on which it relies. Such an article should be published during the year 2021. Meanwhile, you can check the [technical report](https://github.com/SmartImpulse/Wonterfact/releases/download/v3.0.0/Wonterfact_technical_report.pdf) (beware: this is a work in progress) if you're interested in the theory.
24 | 
25 | 
26 | ## Installation
27 | 
28 | ### Requirements
29 | Wonterfact only runs with python 3.8+.
30 | 
31 | [Graphviz](https://graphviz.org/) needs to be installed in order to render graphs of designed models (optional).
32 | 
33 | To use wonterfact on gpu (optional), [CUDA](https://developer.nvidia.com/cuda-toolkit-archive) and `cupy` need to be installed (see [here](https://docs.cupy.dev/en/master/install.html) for cupy installation).
34 | Wonterfact has only been tested with CUDA 10.2 and the corresponding `cupy` version:
35 | ```bash
36 | pip install cupy-cuda102
37 | ```
38 | It might work with other versions of CUDA (not tested).
39 | 
40 | For conda users who would rather "conda install" than "pip install" the main dependencies (numpy for instance) before installing the wonterfact package, the list of requirements can be found in the "[tool.poetry.dependencies]" section of the `pyproject.toml` file.
41 | 
42 | ### Installation of wonterfact
43 | With `pip`:
44 | ```bash
45 | pip install wonterfact
46 | ```
47 | 
48 | ## Getting started: Nonnegative Matrix Factorization (NMF)
49 | Wonterfact offers many possibilities, and mastering it requires time and effort (a full documentation as well as many tutorials should be released soon). However, once mastered, implementing and testing any tensor factorization model is quite simple and quick. We introduce here, by means of a very simple example, the main principles of wonterfact.
50 | 
51 | (If you are reading these lines on github, you can install an appropriate extension such as [this one for Firefox](https://addons.mozilla.org/fr/firefox/addon/latexmathifygithub/) or [this one for Chrome](https://chrome.google.com/webstore/detail/github-math-display/cgolaobglebjonjiblcjagnpmdmlgmda) in order to correctly render mathematical formulas.)
52 | 
53 | ### Formulation of the model
54 | The goal of NMF is to find a good approximation of a nonnegative matrix $X$ as the product of two nonnegative matrices $W$ and $H$: $X\approx WH,W\geq 0,H\geq 0$. In some applications of the NMF problem, $W$ is called *atoms* and $H$ is called *activations*, terms that will be kept in the following to refer to these two factors. Moreover, the more generic term *tensor* can be used instead of *matrix*.
55 | 
56 | The first stage in implementing such a problem with wonterfact, and the main difficulty, is to reformulate it with normalization constraints. We do not explain why here (there are theoretical reasons), but factors need to be normalized in a specific way.
We give here a simple step-by-step recipe to obtain a valid normalization.
57 | 
58 | First, give names to each tensor index: $X_{ft}\approx\sum_{k}W_{fk}H_{kt},W\geq 0,H\geq 0$. In wonterfact, tensors are identified with their indexes' names.
59 | 
60 | Then, add a scalar factor $\lambda$ to the model: $X_{ft}\approx\lambda\sum_{k}W_{fk}H_{kt},W\geq 0,H\geq 0$. This scalar can be called the *overall energy*.
61 | 
62 | Find normalization constraints on each factor so that the sum over all indexes of the right-hand side expression, once the overall energy has been withdrawn, is equal to 1. Here, we then want $\sum_{ftk}W_{fk}H_{kt}=1$. It can be verified that a valid normalization is: $\sum_{kt}H_{kt}=1,\forall k,\sum_{f}W_{fk}=1$.
63 | 
64 | A convenient convention to designate nonnegative tensors subject to normalization constraints is to use the same letter $\theta$ for them, as well as the sign "|" for partial normalization: any tensor written for instance as $\theta_{wx\mid zy}$ should verify $\forall y,z, \sum_{wx}\theta_{wx\mid zy}=1$ and $\forall w,x,y,z, \theta_{wx\mid zy}\geq 0$. In this way, the model can be reformulated as: $X_{ft}\approx\lambda\sum_{k}\theta_{f \mid k}\theta_{kt}$. All normalization constraints are implicitly expressed through the suggested convention, and the names of the indexes are sufficient to identify the former $W$ and $H$ matrices. Generally, there are several ways to find a valid normalization for the factors, but they all respect the following rule: each index must appear once and only once on the left side of the sign "|" (all indexes of a tensor subject to normalization constraints and not having a "|" sign are considered to be on the left side). Feel free to switch the order of indexes in a given tensor if needed.
65 | 
66 | Then, the overall energy can be reintegrated into a fully normalized tensor, i.e. one not having the "|" sign, getting rid of this specific normalization constraint. Non-normalized tensors are expressed with the letter $\lambda$, and $\lambda\theta_{kt}$ then becomes $\lambda_{kt}$.
67 | 
68 | 
69 | Finally, all intermediate operations should be expressed, so that each multiplication of tensors has only two operands. In our case, this gives: $X_{ft}\approx\lambda_{ft}\text{ with }\lambda_{ft}=\sum_{k}\theta_{f\mid k}\lambda_{kt}, \lambda_{kt}\geq 0, \theta_{f\mid k}\geq 0, \sum_{f}\theta_{f\mid k}=1$. $\lambda_{kt}$ are the activations and $\theta_{f\mid k}$ the atoms.
70 | 
71 | ### A graphical way to represent the model
72 | 
73 | Before implementing the model with wonterfact, it is recommended to draw the tree of tensor operations. In this tree, *leaves* (i.e. nodes with no parents) correspond to the factors of the model; level 1 nodes, called *observers*, are the tensors to be decomposed; the single *root* (level 0) is used to identify the whole tree; and all other inner nodes, called *operators*, correspond to tensors resulting from a specific operation on their parent nodes. The graphical representation of our current NMF model is as follows.
74 | 
75 | 
76 | ![NMF tree](images/nmf_tree.svg)
77 | 
78 | Node labels correspond to the indexes of the tensors. Indexes are underlined for tensors not subject to normalization constraints (like $\lambda_{kt}$), and not underlined for normalized tensors (such as $\theta_{f\mid k}$).
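
Before moving on to the implementation, here is a quick numpy sanity check of the normalization constraints derived above (a minimal sketch, independent of wonterfact, with arbitrary dimensions): the atoms $\theta_{f\mid k}$ must sum to 1 over $f$ for each $k$, while the activations $\lambda_{kt}$ only need to be nonnegative.

```python
import numpy as np
import numpy.random as npr

dim_f, dim_k, dim_t = 10, 3, 30
theta_kf = npr.dirichlet(np.ones(dim_f), size=dim_k)  # one point of the simplex per atom k
lambda_kt = npr.gamma(shape=1.0, scale=1.0, size=(dim_k, dim_t))  # nonnegative, unnormalized

assert np.allclose(theta_kf.sum(axis=1), 1)  # sum_f theta_{f|k} = 1 for every k
lambda_ft = np.einsum("kf,kt->ft", theta_kf, lambda_kt)  # the model's reconstruction
```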
79 | 
80 | ### Implementation with wonterfact
81 | 
82 | Implementing a tensor factorization model with wonterfact consists in translating its graphical representation (like the one presented in the previous section) almost literally into wonterfact language. There is a specific class for each kind of node (leaves, operators, observers, etc.), and users just need to know their names and attributes. We describe here how to implement our NMF model. We will proceed bottom-up, starting from the root and ending with the leaves.
83 | 
84 | First of all, let us import the needed packages and define arbitrary dimensions for our problem:
85 | ```python
86 | import wonterfact as wtf
87 | import numpy as np
88 | import numpy.random as npr
89 | dim_f, dim_k, dim_t = 10, 3, 30
90 | ```
91 | 
92 | Then let us create the root. We recall that the root is used to represent the whole tree. The Root object has many attributes, but all of them have default values and are therefore optional. We will nevertheless specify a name for this node (not necessary, but quite useful for debugging when errors are raised) and the type of inference algorithm (here the EM algorithm, which is the default value anyway). We will also ask for the value of the cost function to minimize to be computed at each iteration (not necessary either, but we will use it to show that this cost decreases over the iterations).
93 | ```python
94 | my_tree = wtf.Root(name='root', inference_mode='EM', cost_computation_iter=1)
95 | ```
96 | 
97 | Now it is time to create the observer. Since the tensor to be decomposed is nonnegative, the class `PosObserver` is used. To create a node that represents a tensor (i.e. all nodes but the root), it is necessary to provide the names of its indexes in the right order via the `index_id` attribute. The only convention you need to keep in mind is that **the order of indexes must be reversed compared to the order on the graph representation** (an explanation will be given when the normalized leaf representing the atoms is instantiated). For observer objects, the tensor to decompose is passed through the `tensor` attribute. We define here some random tensor.
98 | 
99 | ```python
100 | observer = wtf.PosObserver(
101 |     name='observer',
102 |     index_id='tf',  # could also define a tuple of strings like ('t', 'f')
103 |     tensor=100 * npr.rand(dim_t, dim_f)
104 | )
105 | ```
106 | 
107 | We can create the edge between the two nodes:
108 | 
109 | ```python
110 | observer.new_child(my_tree)
111 | # my_tree.new_parent(observer)  # is equivalent
112 | # wtf.create_filiation(observer, my_tree)  # is also equivalent
113 | ```
114 | 
115 | Let us go further and define the node above the observer that represents the approximation $\lambda_{ft}$. It is defined as the product of its two parents, therefore the class `Multiplier` is used. Only the `index_id` attribute is necessary (do not forget to define indexes backwards).
116 | 
117 | ```python
118 | multiplier = wtf.Multiplier(name='multiplier', index_id='tf')
119 | multiplier.new_child(observer)  # edge with its child
120 | ```
121 | 
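For intuition: once its two parents (the leaves defined next) are attached, this multiplier node will hold the tensor $\lambda_{ft}=\sum_{k}\theta_{f\mid k}\lambda_{kt}$. In plain numpy terms, it computes the equivalent of the sketch below, where `leaf_kf_tensor` and `leaf_tk_tensor` are hypothetical stand-ins for the two leaf tensors defined in the next code blocks:

```python
lambda_tf = np.einsum('kf,tk->tf', leaf_kf_tensor, leaf_tk_tensor)  # sum over k
```
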
122 | For the leaves, let us start with the one representing the activations $\lambda_{kt}$. This tensor has no normalization constraint, therefore the right class to represent it is `LeafGamma`. The reason for such a class name is that, in the probabilistic framework on which wonterfact relies, each coefficient of a nonnegative tensor factor not subject to normalization constraints is considered a gamma random variable. There are two hyperparameters for gamma random variables: shape and rate. These two hyperparameters then define the prior distribution of the activations and can be specified at the creation of the leaf with the two attributes `prior_shape` and `prior_rate`. Leaving these two attributes to their default values is equivalent to not considering any prior. The `tensor` attribute is compulsory and defines the initial values of the activations.
123 | 
124 | ```python
125 | leaf_tk = wtf.LeafGamma(
126 |     name='activations',
127 |     index_id='tk',
128 |     tensor=np.ones((dim_t, dim_k)),  # initialization with uniform activations
129 |     prior_shape=1,  # default value, meaning improper "uniform prior" over the nonnegative reals
130 |     prior_rate=0  # default value, meaning improper "uniform prior" over the nonnegative reals
131 | )
132 | ```
133 | 
134 | The last node to create is the leaf representing the atoms $\theta_{f\mid k}$. Since this tensor is subject to a normalization constraint, the right class to use is `LeafDirichlet`, referring to the Dirichlet prior distribution. This distribution can be defined with its shape hyperparameters. Besides the `index_id`, `tensor`, and `prior_shape` attributes, the `LeafDirichlet` class needs a `norm_axis` attribute that specifies the axes on which the tensor has the normalization constraint. For a reason internal to wonterfact (related to the way numpy manages automatic array broadcasting), **it is necessary that the normalization axes are the last axes of the tensor**, hence the convention to define indexes and tensors' shapes backwards. We decide to randomly initialize the atoms. Beware that the initialization must respect the normalization constraint.
135 | 
136 | ```python
137 | leaf_kf = wtf.LeafDirichlet(
138 |     name='atoms',
139 |     index_id='kf',
140 |     norm_axis=(1, ),  # self.tensor.sum(axis=self.norm_axis) must be equal to 1
141 |     tensor=npr.dirichlet(np.ones(dim_f), size=dim_k),  # random initialization
142 |     prior_shape=1,  # default value, meaning "uniform prior" over the simplex
143 | )
144 | ```
145 | 
146 | Finally, we create the last needed edges:
147 | ```python
148 | leaf_tk.new_child(multiplier)
149 | leaf_kf.new_child(multiplier)
150 | # multiplier.new_parents(leaf_tk, leaf_kf)  # is equivalent
151 | ```
152 | 
153 | This is it, the NMF model is now instantiated. You can check that the tree of operations corresponds to the one you drew yourself (this is how we actually generated the figure in the previous section):
154 | ```python
155 | my_tree.draw_tree(show_node_names=True, filename='nmf_tree.svg')
156 | ```
157 | 
158 | You can use the name you gave to each node (provided that each name is unique) to access a particular node:
159 | ```python
160 | leaf_kf == my_tree.atoms == my_tree.nodes_by_id['atoms']  # returns True
161 | ```
162 | 
163 | In order to estimate the parameters of the model (here atoms and activations), use the method `estimate_param` of your tree and decide on the number of iterations:
164 | ```python
165 | my_tree.estimate_param(n_iter=30)
166 | ```
167 | If you think you need more iterations, just call the same method again (stop-and-go mode). The algorithm might stop before the requested number of iterations if convergence has been reached.
168 | ```python
169 | my_tree.estimate_param(n_iter=70)  # up to a total of 100 iterations
170 | ```
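
Since EM updates are guaranteed not to increase the cost, you can sanity-check a run the way the package's own tests do (a minimal sketch, using the `cost_record` attribute enabled above via `cost_computation_iter=1`):

```python
cost_func = np.array(my_tree.cost_record)
assert all(cost_func[:-1] >= cost_func[1:])  # the cost never increases
```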
171 | 
172 | Plot the cost function and observe its decrease:
173 | ```python
174 | import matplotlib.pyplot as plt
175 | plt.plot(my_tree.cost_record)
176 | plt.xlabel('iteration')
177 | plt.ylabel('cost function')
178 | ```
179 | ![Cost function](images/nmf_cost_function.png)
180 | 
181 | Get the atoms and activations values:
182 | ```python
183 | print(my_tree.atoms.tensor)
184 | print(my_tree.activations.tensor)
185 | ```
186 | 
--------------------------------------------------------------------------------
/wonterfact/observers.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------------------------------------
2 | # Copyright 2020 Smart Impulse SAS, Benoit Fuentes
3 | #
4 | # This file is part of Wonterfact.
5 | #
6 | # Wonterfact is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # any later version.
10 | #
11 | # Wonterfact is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with Wonterfact. If not, see .
18 | # ----------------------------------------------------------------------------
19 | 
20 | """Module for all observer classes"""
21 | 
22 | 
23 | # Python System imports
24 | 
25 | # Third-party imports
26 | import numpy as np
27 | import numpy.random as npr
28 | from functools import cached_property
29 | 
30 | 
31 | # Relative imports
32 | from . import utils, core_nodes
33 | from .glob_var_manager import glob
34 | 
35 | 
36 | class _Observer(
37 |     core_nodes._NodeData, core_nodes._ChildNode
38 | ):  # TODO: optimize mask_data
39 |     """
40 |     Base class for all Observer nodes, i.e. nodes that carry observed data.
41 |     """
42 | 
43 |     def __init__(self, **kwargs):
44 |         """
45 |         Parameters
46 |         ----------
47 |         mask_data: array_like of booleans or None, default None
48 |             Boolean mask to apply to observed data to specify which
49 |             coefficients are masked and which are not. If masked, a coefficient
50 |             plays no role in the optimization process.
51 |         drawings_max: float or None, optional, default None
52 |             During the optimization algorithm, the inner tensor is normalized with a
53 |             coefficient which can increase along iterations up to a limit
54 |             value. If not None, this limit value is computed as `drawings_max /
55 |             abs(self.tensor).sum()`. If None, the limit value is `1`.
56 |         drawings_update_iter: int, optional, default 1
57 |             The normalization coefficient (see `drawings_max` section) is
58 |             updated every `drawings_update_iter` iterations.
59 |         drawings_step: float or None, optional, default None
60 |             Normalization coefficient is initialized as `drawings_step /
61 |             abs(self.tensor).sum()` and when it has to be updated (see
62 |             `drawings_update_iter`), the same amount is added to the current
63 |             normalization coefficient until it reaches its limit (see
64 |             `drawings_max` section). If None, `drawings_step` is set to
65 |             `drawings_max` value so that the normalization coefficient remains
66 |             fixed during the algorithm.
67 | 68 | Notes 69 | ----- 70 | The dynamic normalization feature (see `drawings_max`, 71 | `drawings_update_iter` and `drawings_step` sections) aims at giving more 72 | weight to the priors in the early stage of the algorithm. If you do not 73 | want to use this feature, just leave default values for those arguments. 74 | """ 75 | self.mask_data = kwargs.pop("mask_data", None) 76 | self.drawings_max = kwargs.pop("drawings_max", None) 77 | self.drawings_step = kwargs.pop("drawings_step", None) 78 | self.drawings_update_iter = kwargs.pop("drawings_update_iter", 1) 79 | super().__init__(**kwargs) 80 | 81 | self.drawings_max = self.drawings_max or self.sum_tensor 82 | self.drawings_step = self.drawings_step or self.sum_tensor 83 | self.drawings = self.drawings_step 84 | self.drawings_update_counter = 0 85 | 86 | @cached_property 87 | def is_tensor_null(self): 88 | return (self.tensor == 0).astype(self.tensor.dtype) 89 | 90 | @cached_property 91 | def is_tensor_pos(self): 92 | return self.tensor > 0 93 | 94 | @cached_property 95 | def sum_tensor(self): 96 | return glob.xp.abs(self.tensor).sum() 97 | 98 | def there_is_a_mult_factor(self): 99 | return self.drawings != self.sum_tensor 100 | 101 | def get_current_mult_factor(self): 102 | return self.drawings / self.sum_tensor if self.sum_tensor else 1 103 | 104 | def update_drawings(self): 105 | self.drawings_update_counter += 1 106 | if self.drawings_update_counter % self.drawings_update_iter == 0: 107 | self.drawings = min(self.drawings_max, self.drawings + self.drawings_step) 108 | 109 | def apply_mask_to_tensor_update(self, tensor_update): 110 | if self.mask_data is not None: 111 | tensor_update[self.mask_data] = 1 112 | 113 | def get_current_reconstruction(self): 114 | """ 115 | Returns the current approximation of the observed tensor. 116 | """ 117 | raise NotImplementedError 118 | 119 | def tensor_has_energy(self): 120 | return True 121 | 122 | 123 | class PosObserver(_Observer): 124 | """ 125 | Class for nonnegative observations. 126 | """ 127 | 128 | def _initialization(self): 129 | pass 130 | 131 | def _give_update(self, parent, out=None): 132 | parent_tensor = parent.get_tensor_for_children(self) 133 | 134 | if out is None: 135 | tensor_update = glob.xp.empty_like(self.tensor) 136 | else: 137 | tensor_update = out 138 | 139 | denominator = parent_tensor + self.is_tensor_null 140 | if self._inference_mode == "VBEM": 141 | # IN VBEM mode, underflow might happen, leading to null parent_tensor 142 | denominator += parent_tensor == 0 143 | 144 | tensor_update[...] = self.get_current_mult_factor() * self.tensor / denominator 145 | 146 | self.update_drawings() 147 | self.apply_mask_to_tensor_update(tensor_update) 148 | 149 | return tensor_update 150 | 151 | def get_current_reconstruction(self, parent, force_numpy=False): 152 | tensor_to_give = ( 153 | parent.get_tensor_for_children(self) / self.get_current_mult_factor() 154 | ) 155 | if force_numpy and utils.infer_backend(tensor_to_give) == glob.CUPY: 156 | return glob.xp.asnumpy(tensor_to_give) 157 | return tensor_to_give 158 | 159 | def get_kl_divergence(self): 160 | """ 161 | Returns the kullback-Leibler divergence between observer tensor and the 162 | current reconstruction. 
163 | """ 164 | kl_div = 0 165 | for parent in self.list_of_parents: 166 | reconstruction = self.get_current_reconstruction(parent) 167 | my_tensor = self.tensor 168 | kl_div -= utils.xlogy(my_tensor, reconstruction).sum() 169 | kl_div += reconstruction.sum() 170 | kl_div -= my_tensor.sum() - utils.xlogy(my_tensor, my_tensor).sum() 171 | return kl_div.item() 172 | 173 | def _get_data_fitting(self): 174 | """ 175 | Returns minus log-likelihood of Poisson distribution 176 | """ 177 | lh = glob.xp.zeros_like(self.tensor) 178 | for parent in self.list_of_parents: 179 | my_tensor = self.get_current_mult_factor() * self.tensor 180 | parent_tensor = parent.get_tensor_for_children(self) 181 | if self._inference_mode == "EM": 182 | lh -= parent_tensor 183 | elif self._inference_mode == "VBEM": 184 | pass 185 | # this part is canceled with the gamma leaves prior values 186 | # and therefore, method "_get_mean_tensor_for_VBEM" has been 187 | # removed (check v2.1.2 to get it back) 188 | # lh -= parent._get_mean_tensor_for_VBEM(self) 189 | lh += utils.xlogy(my_tensor, parent_tensor) 190 | lh -= glob.sps.gammaln(my_tensor + 1) 191 | if self.mask_data is not None: 192 | if self._inference_mode == "EM": 193 | lh[self.mask_data] = 0 194 | elif self._inference_mode == "VBEM": 195 | lh[self.mask_data] = parent_tensor[self.mask_data] 196 | return -lh.sum().item() 197 | 198 | 199 | class RealObserver(_Observer): 200 | """ 201 | Class for real observations 202 | """ 203 | 204 | def __init__(self, **kwargs): 205 | """ 206 | Parameters 207 | ---------- 208 | limit_skellam_update: bool, default True 209 | Set to True if data are real, False if data are integer 210 | """ 211 | self.limit_skellam_update = kwargs.pop("limit_skellam_update", True) 212 | super().__init__(**kwargs) 213 | 214 | @cached_property 215 | def abs_tensor(self): 216 | return glob.xp.abs(self.tensor) 217 | 218 | @cached_property 219 | def abs_tensor_plus_1(self): 220 | return self.abs_tensor + 1 221 | 222 | @cached_property 223 | def abs_tensor_power2(self): 224 | return self.abs_tensor ** 2 225 | 226 | @cached_property 227 | def nonneg_tensor(self): 228 | return utils.real_to_2D_nonnegative(self.tensor) 229 | 230 | def _initialization(self): 231 | pass 232 | 233 | def get_current_reconstruction(self, parent, force_numpy=False): 234 | parent_tensor = parent.get_tensor_for_children(self) 235 | real_parent_tensor = parent_tensor[..., 0] - parent_tensor[..., 1] 236 | tensor_to_give = real_parent_tensor / self.get_current_mult_factor() 237 | if force_numpy and utils.infer_backend(tensor_to_give) == glob.CUPY: 238 | return glob.xp.asnumpy(tensor_to_give) 239 | return tensor_to_give 240 | 241 | def _get_data_fitting(self): 242 | """ 243 | Returns minus log-likelihood of either Skellam distribution or extended 244 | real KL divergence. 
245 | """ 246 | lh = glob.xp.zeros_like(self.tensor) 247 | for parent in self.list_of_parents: 248 | mult_fact = self.get_current_mult_factor() 249 | parent_tensor = parent.get_tensor_for_children(self) 250 | if self._inference_mode == "EM": 251 | lh -= parent_tensor.sum(-1) 252 | elif self._inference_mode == "VBEM": 253 | pass 254 | # this part is canceled with the gamma leaves prior values 255 | # and therefore, method "_get_mean_tensor_for_VBEM" has been 256 | # removed (check v2.1.2 to get it back) 257 | # lh -= parent._get_mean_tensor_for_VBEM(self).sum(-1) 258 | abs_tensor = mult_fact * self.abs_tensor 259 | inside_log = ( 260 | parent_tensor[..., 0] * self.is_tensor_pos 261 | + parent_tensor[..., 1] * glob.xp.logical_not(self.is_tensor_pos) 262 | + self.is_tensor_null 263 | ) 264 | lh += utils.xlogy(abs_tensor, inside_log) 265 | if not self.limit_skellam_update: 266 | abs_tensor += 1 267 | lh += utils.hyp0f1ln( 268 | abs_tensor, parent_tensor[..., 0] * parent_tensor[..., 1] 269 | ) 270 | lh -= glob.sps.gammaln(abs_tensor) 271 | else: 272 | temp_calculus = self.temp_calculus(parent) 273 | lh += temp_calculus 274 | lh -= utils.xlogy( 275 | abs_tensor, (abs_tensor + temp_calculus + self.is_tensor_null) / 2 276 | ) 277 | if self.mask_data is not None: 278 | if self._inference_mode == "EM": 279 | lh[self.mask_data] = 0 280 | elif self._inference_mode == "VBEM": 281 | lh[self.mask_data] = parent_tensor[self.mask_data, :].sum(-1) 282 | return -lh.sum().item() 283 | 284 | def temp_calculus(self, parent): 285 | parent_tensor = parent.get_tensor_for_children(self) 286 | model_param_prod = parent_tensor[..., 0] * parent_tensor[..., 1] 287 | if self.there_is_a_mult_factor(): 288 | abs_tensor_power2 = ( 289 | self.get_current_mult_factor() ** 2 * self.abs_tensor_power2 290 | ) 291 | else: 292 | abs_tensor_power2 = self.abs_tensor_power2 293 | temp = glob.xp.sqrt(4 * model_param_prod + abs_tensor_power2) 294 | return temp 295 | 296 | def _give_update(self, parent, out=None): 297 | 298 | # model 299 | parent_tensor = parent.get_tensor_for_children(self) 300 | if self.there_is_a_mult_factor(): 301 | mult_fact = self.get_current_mult_factor() 302 | abs_tensor = mult_fact * self.abs_tensor 303 | nonneg_tensor = mult_fact * self.nonneg_tensor 304 | else: 305 | abs_tensor = self.abs_tensor 306 | nonneg_tensor = self.nonneg_tensor 307 | 308 | if out is None: 309 | tensor_update = glob.xp.empty_like(parent_tensor) 310 | else: 311 | tensor_update = out 312 | 313 | # compute tensor_update 314 | if not self.limit_skellam_update: 315 | model_param_prod = parent_tensor[..., 0] * parent_tensor[..., 1] 316 | for ii in range(2): 317 | tensor_update[..., ii] = (2 * parent_tensor[..., 1 - ii]) / ( 318 | 2 * (1 + abs_tensor) 319 | + utils.bessel_ratio( 320 | abs_tensor + 1, 321 | 2 * (model_param_prod ** 0.5), 322 | 1e-16, 323 | ) 324 | ) 325 | else: 326 | temp_calculus = self.temp_calculus(parent) 327 | for ii in range(2): 328 | tensor_update[..., ii] = (2 * parent_tensor[..., 1 - ii]) / ( 329 | abs_tensor 330 | + temp_calculus 331 | # to avoid x/0, in which case value of tensor_update is not important 332 | ) 333 | if glob.xp == np: 334 | with np.errstate(invalid="ignore"): 335 | tensor_update += nonneg_tensor / parent_tensor 336 | else: 337 | tensor_update += nonneg_tensor / parent_tensor 338 | tensor_update[parent_tensor == 0] = 0 339 | # glob.xp.nan_to_num(tensor_update, copy=False) 340 | 341 | self.update_drawings() 342 | self.apply_mask_to_tensor_update(tensor_update) 343 | 344 | return tensor_update 345 | 
346 | 347 | class BlindObs(core_nodes._ChildNode, core_nodes._ParentNode): 348 | def __init__(self, **kwargs): 349 | super().__init__(**kwargs) 350 | 351 | def _give_update(self, parent, out=None): 352 | if out is None: 353 | return glob.xp.ones_like(parent.get_tensor_for_children(self)) 354 | out[...] = 1.0 355 | 356 | return out 357 | 358 | def _get_data_fitting(self): 359 | if self._inference_mode == "EM": 360 | return 0 361 | elif self._inference_mode == "VBEM": 362 | total_energy = 0 363 | for parent in self.list_of_parents: 364 | total_energy += parent.get_tensor_for_children(self).sum() 365 | return -total_energy.item() 366 | -------------------------------------------------------------------------------- /wonterfact/cupy_utils.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright 2019 Smart Impulse SAS 3 | # 4 | # This file is part of Wonterfact. 5 | # 6 | # Wonterfact is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # any later version. 10 | # 11 | # Wonterfact is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Wonterfact. If not, see . 18 | # ---------------------------------------------------------------------------- 19 | 20 | """Methods written for cupy or cuda""" 21 | 22 | # Python System imports 23 | import sys 24 | import time 25 | from functools import lru_cache # pylint: disable=E0611 26 | 27 | # Third-party imports 28 | from numba import cuda, float32, float64 29 | 30 | # Relative imports 31 | from .glob_var_manager import glob 32 | from . import utils 33 | 34 | if glob.float == glob.FLOAT32: 35 | numba_float = float32 36 | elif glob.float == glob.FLOAT64: 37 | numba_float = float64 38 | 39 | 40 | def cupy_cumsum_2d(arr_in, arr_out, max_threads=256): 41 | cp = utils.xp_utils.back("cupy") 42 | if arr_in.shape[-1] == 1: 43 | arr_out[...] 
= arr_in 44 | return 45 | batch_dim_y = utils.next_pow_of_two(min(arr_in.shape[1], 2 * max_threads)) 46 | block_dim_y = batch_dim_y // 2 47 | block_dim_x = min(arr_in.shape[0], max_threads // block_dim_y) 48 | blocks_number_y = (arr_in.shape[1] + batch_dim_y - 1) // batch_dim_y 49 | blocks_number_x = (arr_in.shape[0] + block_dim_x - 1) // block_dim_x 50 | print(blocks_number_y) 51 | 52 | if blocks_number_y > 1: 53 | store = True 54 | aux = cp.zeros((blocks_number_x, blocks_number_y)) 55 | else: 56 | store = False 57 | aux = cp.zeros(()) 58 | 59 | inclusive_scan_2d( 60 | ( 61 | blocks_number_x, 62 | blocks_number_y, 63 | ), 64 | ( 65 | block_dim_x, 66 | block_dim_y, 67 | ), 68 | ( 69 | arr_in, 70 | arr_out, 71 | arr_in.shape[0], 72 | arr_in.shape[1], 73 | batch_dim_y, 74 | aux, 75 | store, 76 | arr_in.strides[0] // 8, 77 | arr_in.strides[1] // 8, 78 | arr_out.strides[0] // 8, 79 | arr_out.strides[1] // 8, 80 | ), 81 | shared_mem=(block_dim_x * (batch_dim_y + 1)) * 8, 82 | ) 83 | if blocks_number_y > 1: 84 | incr = cp.zeros((blocks_number_x, blocks_number_y)) 85 | cupy_cumsum_2d(aux, incr, max_threads=max_threads) 86 | sum_inclusive_scan_2d( 87 | ( 88 | blocks_number_x, 89 | blocks_number_y, 90 | ), 91 | ( 92 | block_dim_x, 93 | batch_dim_y, 94 | ), 95 | ( 96 | incr, 97 | arr_out, 98 | arr_in.shape[0], 99 | arr_in.shape[1], 100 | arr_out.strides[0] // 8, 101 | arr_out.strides[1] // 8, 102 | ), 103 | ) 104 | 105 | 106 | @lru_cache(maxsize=1024) 107 | def find_cumsum_max_threads(size_arr, size_cumsum): 108 | cp = utils.xp_utils.back("cupy") 109 | arr = cp.random.rand(size_arr, size_cumsum) 110 | out = arr.copy() 111 | tac_list = [] 112 | max_threads_list = [2 ** nn for nn in range(5, 10)] 113 | for max_threads in max_threads_list: 114 | tic = time.time() 115 | for __ in range(100): 116 | cupy_cumsum_2d(arr, out, max_threads=max_threads) 117 | device = cp.cuda.Device() 118 | device.synchronize() 119 | tac_list.append(time.time() - tic) 120 | __, max_threads = min(zip(tac_list, max_threads_list), key=lambda x: x[0]) 121 | return max_threads 122 | 123 | 124 | min_clip = utils.xp_utils.back("cupy").ElementwiseKernel( 125 | "float64 x, float64 min_val", # input params 126 | "float64 y", # output params 127 | "y = ((x < min_val) ? min_val:x)", 128 | "min_clip", 129 | ) 130 | 131 | 132 | max_clip = utils.xp_utils.back("cupy").ElementwiseKernel( 133 | "float64 x, float64 max_val", # input params 134 | "float64 y", # output params 135 | "y = ((x > max_val) ? max_val:x)", 136 | "max_clip", 137 | ) 138 | 139 | _xlogy = utils.xp_utils.back("cupy").ElementwiseKernel( 140 | "float64 x, float64 y", # input params 141 | "float64 z", # output params 142 | "z = ((x != 0.0) ? 
x*log(y):0.0)", 143 | "_xlogy", 144 | ) 145 | 146 | 147 | def xlogy(arr1, arr2, out=None): 148 | if out is None: 149 | return _xlogy(arr1, arr2) 150 | return _xlogy(arr1, arr2, out) 151 | 152 | 153 | @cuda.jit() 154 | def normalize_l1_l2_tensor_numba_core(tensor, n_iter, prior_rate): 155 | """ 156 | Core function for l1/l2 normalization of LeafGammaNorm' tensor attribute 157 | transforms tensor inplace 158 | This code is inspired by [1] for the reduction part 159 | 160 | input: 161 | - tensor: input tensor to normalize (normalization is performed inplace) 162 | - n_iter: number of iteration for the fix-point algorithm 163 | - prior_beta_arr: a cupy array with values : [2 * prior_rate, 4 * prior_rate] 164 | 165 | [1] https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf 166 | """ 167 | 168 | # shared array def (for output data computation and reduction) 169 | # shared_memo = cuda.shared.array(shape=100, dtype=numba_float[glob.float]) # pylint disable=E1102 170 | shared_memo = cuda.shared.array( 171 | shape=100, dtype=numba_float 172 | ) # pylint disable=E1102 173 | 174 | # get block and thread numbers 175 | dim_0, dim_1 = cuda.blockIdx.x, cuda.threadIdx.x 176 | block_dim = cuda.blockDim.x 177 | 178 | # leave if thread is out of bound 179 | if dim_0 >= tensor.shape[0] and dim_1 >= tensor.shape[1]: 180 | return 181 | 182 | # import input data 183 | out_val_1 = tensor[dim_0, dim_1] 184 | dim_1_bis = dim_1 + block_dim 185 | if dim_1_bis < tensor.shape[1]: 186 | out_val_2 = tensor[ 187 | dim_0, dim_1_bis 188 | ] # each thread deals with 2 values for speed purpose (see [1]: scenario 4) 189 | else: 190 | out_val_2 = 0.0 191 | in_val_1 = out_val_1 192 | in_val_2 = out_val_2 193 | 194 | for __ in range(n_iter): 195 | # First we load data 196 | shared_memo[dim_1] = out_val_1 ** 2 + out_val_2 ** 2 197 | 198 | # A synchronize loop to performe the sum in O(log(N)) 199 | # at the end, the sum is the first component of the shared memory 200 | max_idx = (block_dim) >> 1 # bin swaping is equivalent to // 2 201 | jump_idx = ( 202 | block_dim + 1 203 | ) >> 1 # we need some tricks to deals with odd dimensions 204 | while max_idx > 0: 205 | cuda.syncthreads() 206 | if dim_1 < max_idx: 207 | shared_memo[dim_1] += shared_memo[dim_1 + jump_idx] 208 | max_idx = jump_idx >> 1 209 | jump_idx = (jump_idx + 1) >> 1 210 | 211 | if dim_1 == 0: 212 | shared_memo[1] = shared_memo[0] ** 0.5 213 | shared_memo[2] = 4 * prior_rate * shared_memo[1] 214 | # here, norm22 is in shared_mem0[0], norm2 in shared_memo[1] and shared_memo[2] contains some temp computation 215 | 216 | cuda.syncthreads() 217 | # norm22 = shared_memo[0] 218 | # norm2 = shared_memo[1] 219 | 220 | delta1 = shared_memo[0] + shared_memo[2] * in_val_1 221 | delta2 = shared_memo[0] + shared_memo[2] * in_val_2 222 | out_val_1 = (-shared_memo[1] + delta1 ** 0.5) / (2 * prior_rate) 223 | out_val_2 = (-shared_memo[1] + delta2 ** 0.5) / (2 * prior_rate) 224 | cuda.syncthreads() 225 | 226 | tensor[dim_0, dim_1] = out_val_1 227 | if dim_1_bis < tensor.shape[1]: 228 | tensor[ 229 | dim_0, dim_1_bis 230 | ] = out_val_2 # each thread deals with 2 values for speed purpose (see [1]: scenario 4) 231 | 232 | 233 | _set_bezier_point = utils.xp_utils.back("cupy").ElementwiseKernel( 234 | "float64 val1, float64 val2, float64 val3, float64 p", 235 | "float64 z", 236 | """ 237 | double omp = 1 - p; 238 | z = omp * omp * val1 + 2 * omp * p * val2 + p * p * val3; 239 | """, 240 | "_set_bezier_point", 241 | ) 242 | 243 | 244 | multiply_and_sum = 
utils.xp_utils.back("cupy").ReductionKernel( 245 | "float64 x, float64 y", # input params 246 | "float64 z", # output params 247 | "x * y", # map 248 | "a + b", # reduce 249 | "z = a", # post-reduction map 250 | "0", # identity value 251 | "multiply_and_sum", # kernel name 252 | ) 253 | 254 | 255 | _exp_digamma_c_code = """ 256 | double coef_list[] = { 257 | 0.041666666666666664, 258 | -0.006423611111111111, 259 | 0.003552482914462081, 260 | -0.0039535574489730305, 261 | }; 262 | double tmp2, temp; 263 | 264 | double input_plus_n = input_val; 265 | if (input_plus_n == 0){ 266 | output_val = 0.0; 267 | } 268 | else { 269 | tmp2 = 0.0; 270 | while (input_plus_n < 10.0){ 271 | tmp2 -= 1.0 / input_plus_n; 272 | input_plus_n += 1.0; 273 | } 274 | input_plus_n -= 0.5; 275 | output_val = input_plus_n; 276 | temp = input_plus_n; 277 | input_plus_n *= input_plus_n; 278 | for (int idx = 0; idx < 4; ++idx){ 279 | temp /= input_plus_n; 280 | output_val += coef_list[idx] * temp; 281 | } 282 | if (tmp2 != 0.0){ 283 | output_val *= exp(tmp2); 284 | } 285 | } 286 | """ 287 | 288 | 289 | # This method corresponds to utils._exp_digamma method 290 | exp_digamma = utils.xp_utils.back("cupy").ElementwiseKernel( 291 | "float64 input_val", # input params 292 | "float64 output_val", # output params 293 | _exp_digamma_c_code, 294 | "exp_digamma", 295 | ) 296 | 297 | 298 | _hyp0f1ln_c_code = """ 299 | double temp_r, temp_a, output_temp_val, temp_v_val; 300 | bool keep_going; 301 | 302 | temp_r = 1.0; 303 | temp_a = 1.0; 304 | temp_v_val = v_val; 305 | keep_going = true; 306 | output_temp_val = 1.0; 307 | output_val = 0.0; 308 | 309 | while (keep_going){ 310 | if (output_temp_val > 1e300){ 311 | output_val += log(output_temp_val); 312 | temp_a /= output_temp_val; 313 | output_temp_val = 1.0; 314 | } 315 | temp_a *= z_val / (temp_r * temp_v_val); 316 | output_temp_val += temp_a; 317 | temp_r += 1.0; 318 | temp_v_val += 1.0; 319 | keep_going = temp_a > (tol * output_temp_val); 320 | } 321 | output_val += log(output_temp_val); 322 | """ 323 | 324 | _hyp0f1ln = utils.xp_utils.back("cupy").ElementwiseKernel( 325 | "float64 v_val, float64 z_val, float64 tol", # input params 326 | "float64 output_val", # output params 327 | _hyp0f1ln_c_code, 328 | "_hyp0f1ln", 329 | ) 330 | 331 | 332 | # This method corresponds to utils.hyp0f1ln method 333 | def hyp0f1ln(v_arr, z_arr, tol=1e-16): 334 | return _hyp0f1ln(v_arr, z_arr, tol) 335 | 336 | 337 | _bessel_ratio_c_code = """ 338 | double z_arr2, temp_pr, temp_v0, temp_v, temp_u, temp_w, temp_p, tol2, temp_t; 339 | bool keep_going; 340 | 341 | z_arr2 = z_arr / 2.0; 342 | output_val = 1.0; 343 | temp_pr = 1.0; 344 | temp_v0 = v_arr + z_arr2 + 1.0; 345 | temp_v = v_arr + z_arr + 1.5; 346 | temp_u = (v_arr + 1 + z_arr) * temp_v; 347 | temp_w = z_arr2 * (v_arr + 1.5); 348 | temp_p = temp_w / (temp_v0 * temp_v - temp_w); 349 | tol2 = tol * (1 + temp_p); 350 | temp_pr *= temp_p; 351 | output_val += temp_pr; 352 | keep_going = true; 353 | while (keep_going){ 354 | temp_u += temp_v; 355 | temp_v += 0.5; 356 | temp_w += z_arr2; 357 | temp_t = temp_w * (1 + temp_p); 358 | temp_p = temp_t / (temp_u - temp_t); 359 | temp_pr *= temp_p; 360 | output_val += temp_pr; 361 | keep_going = temp_pr > tol2; 362 | } 363 | output_val *= z_arr * z_arr / (z_arr+ 2 * v_arr + 2); 364 | """ 365 | 366 | _bessel_ratio = utils.xp_utils.back("cupy").ElementwiseKernel( 367 | "float64 v_arr, float64 z_arr, float64 tol", # input params 368 | "float64 output_val", # output params 369 | _bessel_ratio_c_code, 370 | 
"_bessel_ratio", 371 | ) 372 | 373 | 374 | # This method corresponds to utils.bessel_ratio method 375 | def bessel_ratio(v_arr, z_arr, tol=1e-16): 376 | return _bessel_ratio(v_arr, z_arr, tol) 377 | 378 | 379 | inclusive_scan_2d = utils.xp_utils.back("cupy").RawKernel( 380 | r""" 381 | // Inclusive scan on CUDA. 382 | extern "C" __global__ 383 | void inclusive_scan_2d( 384 | double *d_array, 385 | double *d_result, 386 | int dimRow, 387 | int dimCol, 388 | int batchDim, // must be power of two 389 | double *d_aux, 390 | bool store, 391 | int strideRowIn, 392 | int strideColIn, 393 | int strideRowOut, 394 | int strideColOut 395 | ) { 396 | extern __shared__ double temp[]; // dim is blockDim.x * (batchDim + 1) 397 | 398 | // index of input and result arrays 399 | int realIndexRow = blockDim.x * blockIdx.x + threadIdx.x; 400 | int realIndexCol = 2 * (blockDim.y * blockIdx.y + threadIdx.y); // 2 * blockDim.y must be == batchDim 401 | // int realIndexFlat = realIndexRow * dimCol + realIndexCol; 402 | int realIndexFlatIn = realIndexRow * strideRowIn + realIndexCol * strideColIn; 403 | int realIndexFlatOut = realIndexRow * strideRowOut + realIndexCol * strideColOut; 404 | 405 | int threadIndexRow = threadIdx.x; 406 | int threadIndexCol = threadIdx.y; 407 | 408 | // index of temp arr 409 | int indexRow = threadIndexRow; 410 | int indexCol = 2 * threadIndexCol; 411 | int indexStartFlat = indexRow * (batchDim + 1); 412 | int indexFlat = indexStartFlat + indexCol; 413 | 414 | int offset = 1; 415 | 416 | // Copy from the array to shared memory. 417 | if (realIndexCol < (dimCol - 1)){ 418 | temp[indexFlat] = d_array[realIndexFlatIn]; 419 | temp[indexFlat + 1] = d_array[realIndexFlatIn + strideColIn]; 420 | } 421 | else if (realIndexCol == (dimCol - 1)){ 422 | temp[indexFlat] = d_array[realIndexFlatIn]; 423 | temp[indexFlat + 1] = 0.; 424 | } 425 | else{ 426 | temp[indexFlat] = 0.; 427 | temp[indexFlat + 1] = 0.; 428 | } 429 | // Reduce by storing the intermediate values. The last element will be 430 | // the sum of n-1 elements. 431 | for (int d = blockDim.y; d > 0; d = d/2) { 432 | __syncthreads(); 433 | 434 | // Regulates the amount of threads operating. 435 | if (threadIndexCol < d) { 436 | // Swap the numbers 437 | int current = offset * (indexCol + 1) - 1; 438 | int next = offset * (indexCol + 2) - 1; 439 | if (next < batchDim){ 440 | temp[indexStartFlat + next] += temp[indexStartFlat + current]; 441 | } 442 | /* 443 | temp[indexStartFlat + next] += temp[indexStartFlat + current]; 444 | */ 445 | } 446 | 447 | // Increase the offset by multiple of 2. 448 | offset *= 2; 449 | } 450 | 451 | // Only one thread performs this. 452 | if (threadIndexCol == 0) { 453 | // Store the sum on the last index of temp 454 | temp[indexStartFlat + batchDim] = temp[indexStartFlat + batchDim - 1]; 455 | // Store the sum to the auxiliary array. 456 | if (store) { 457 | d_aux[blockIdx.x * gridDim.y + blockIdx.y] = temp[indexStartFlat + batchDim]; 458 | } 459 | // Reset the last element with identity. Only the first thread will do 460 | // the job. 461 | temp[indexStartFlat + batchDim - 1] = 0; 462 | } 463 | // Down sweep to build scan. 464 | for (int d = 1; d < blockDim.y*2; d *= 2) { 465 | 466 | // Reduce the offset by division of 2. 
467 | offset = offset / 2; 468 | 469 | __syncthreads(); 470 | 471 | if (threadIndexCol < d) 472 | { 473 | int current = offset * (indexCol + 1) - 1; 474 | int next = offset * (indexCol + 2) - 1; 475 | 476 | // Swap 477 | if (next < batchDim){ 478 | double tempCurrent = temp[indexStartFlat + current]; 479 | temp[indexStartFlat + current] = temp[indexStartFlat + next]; 480 | temp[indexStartFlat + next] += tempCurrent; 481 | } else if (current < batchDim){ // maybe not necessary... 482 | temp[indexStartFlat + current] = 0; 483 | } 484 | /* 485 | double tempCurrent = temp[indexStartFlat + current]; 486 | temp[indexStartFlat + current] = temp[indexStartFlat + next]; 487 | temp[indexStartFlat + next] += tempCurrent; 488 | */ 489 | } 490 | } 491 | __syncthreads(); 492 | 493 | if (realIndexCol >= dimCol) {return;} 494 | d_result[realIndexFlatOut] = temp[indexFlat + 1]; // write results to device memory 495 | if (realIndexCol < dimCol - 1){ 496 | d_result[realIndexFlatOut + strideColOut] = temp[indexFlat + 2]; 497 | } 498 | } 499 | """, 500 | "inclusive_scan_2d", 501 | ) 502 | 503 | 504 | sum_inclusive_scan_2d = utils.xp_utils.back("cupy").RawKernel( 505 | r""" 506 | extern "C" __global__ 507 | void sum_inclusive_scan_2d( 508 | double *d_incr, 509 | double *d_result, 510 | int dimRow, 511 | int dimCol, 512 | int strideOutRow, 513 | int strideOutCol 514 | ) { 515 | if (blockIdx.y > 0){ 516 | double addThis = d_incr[blockIdx.x * gridDim.y + blockIdx.y - 1]; 517 | int tid_row = threadIdx.x + blockDim.x * blockIdx.x; 518 | int tidCol = threadIdx.y + blockDim.y * blockIdx.y; 519 | if ((tidCol < dimCol) && (tid_row < dimRow)){ 520 | d_result[tid_row * strideOutRow + tidCol * strideOutCol] += addThis; 521 | } 522 | } 523 | } 524 | """, 525 | "sum_inclusive_scan_2d", 526 | ) 527 | -------------------------------------------------------------------------------- /wonterfact/graphviz.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright 2020 Benoit Fuentes 3 | # 4 | # This file is part of Wonterfact. 5 | # 6 | # Wonterfact is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # any later version. 10 | # 11 | # Wonterfact is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Wonterfact. If not, see <https://www.gnu.org/licenses/>. 18 | # ---------------------------------------------------------------------------- 19 | 20 | """Module to draw operation tree of wonterfact models""" 21 | 22 | # Python System imports 23 | import re 24 | from pathlib import Path 25 | 26 | # Relative imports 27 | from . 
import core_nodes, root, observers, operators, utils, buds 28 | 29 | # Third-party imports 30 | import numpy as np 31 | import graphviz 32 | import string 33 | 34 | 35 | def _get_node_shape(node): 36 | if isinstance(node, observers._Observer): 37 | return "doublecircle" 38 | if isinstance(node, root.Root): 39 | return "none" 40 | return "ellipse" 41 | 42 | 43 | def _get_node_prefix(node, legend_dict, **extra_param): 44 | if isinstance(node, operators.Multiplier): 45 | if node.conv_idx_ids: 46 | conv_idx = "".join( 47 | [ 48 | legend_dict[idx]["letter"] 49 | for idx in node.index_id[::-1] 50 | if idx in node.conv_idx_ids 51 | ] 52 | ) 53 | # return "⊛{}".format(conv_idx) 54 | return "(∗){}".format(conv_idx) 55 | # return "×" 56 | # return "Π" 57 | return "(×)" 58 | 59 | if isinstance(node, operators.Multiplexer): 60 | if node.multiplexer_idx is not None: 61 | return "(∥{})".format( 62 | legend_dict[node.multiplexer_idx]["letter"] 63 | ) 64 | return "(∥)" 65 | 66 | if isinstance(node, operators.Adder): 67 | # return "+" 68 | # return "Σ" 69 | return "(+)" 70 | 71 | if isinstance(node, observers.PosObserver): 72 | integer_observations = extra_param.get("integer_observations", False) 73 | if integer_observations: 74 | return "ℕ" 75 | return "ℝ+" 76 | 77 | if isinstance(node, observers.RealObserver): 78 | integer_observations = extra_param.get("integer_observations", False) 79 | if integer_observations: 80 | return "ℤ" 81 | return "ℝ" 82 | 83 | if isinstance(node, operators.Integrator): 84 | return "∫{}".format( 85 | legend_dict[node.index_id[-1]]["letter"] 86 | ) 87 | 88 | if isinstance(node, operators.Proxy): 89 | return "(=)" 90 | 91 | return "" 92 | 93 | 94 | def _get_edge_label(node, child, legend_dict): 95 | if not isinstance(node, core_nodes._DynNodeData): 96 | return None 97 | label = "" 98 | slice_for_child = node.slicing_for_children_dict[child] 99 | if slice_for_child != Ellipsis: 100 | explicit_slice_raw = utils.explicit_slice(slice_for_child, node.tensor.ndim) 101 | explicit_slice = [] 102 | masked_idx = [] 103 | num_idx = 0 104 | for elem in explicit_slice_raw: 105 | elem_as_np = utils._is_bool_masking(elem) 106 | if elem_as_np is not None: 107 | masked_idx += node.index_id[num_idx : num_idx + elem_as_np.ndim] 108 | num_idx += elem_as_np.ndim 109 | explicit_slice += [slice(None),] * elem_as_np.ndim 110 | else: 111 | explicit_slice.append(elem) 112 | num_idx += 1 113 | 114 | for idx, sl, dim_axis in zip( 115 | node.index_id[::-1], explicit_slice[::-1], node.tensor.shape[::-1] 116 | ): 117 | if sl != slice(None): 118 | letter = legend_dict[idx]["letter"] 119 | if isinstance(sl, int): 120 | label += "{}={};".format(letter, sl) 121 | elif isinstance(sl, slice): 122 | start = "" if sl.start in [0, None] else sl.start 123 | step = "" if sl.step in [1, None] else sl.step 124 | stop = "" if sl.stop in [dim_axis, None] else sl.stop 125 | label += "{}={}:{}:{};".format(letter, start, stop, step) 126 | else: 127 | label += "{}=[".format(letter) 128 | label += ",".join(str(ax) for ax in sl) 129 | label += "]" 130 | if masked_idx: 131 | masked_letters = "".join( 132 | [legend_dict[idx]["letter"] for idx in masked_idx[::-1]] 133 | ) 134 | label += "mask: " + masked_letters 135 | if node.strides_for_children_dict[child]: 136 | label += "strides: {}".format(node.strides_for_children_dict[child]) 137 | index_id_for_child = node.get_index_id_for_children(child) 138 | if index_id_for_child != node.index_id: 139 | label_new_idx = _insert_given_symbol( 140 | index_id_for_child, 
node.get_norm_axis_for_children(child), 142 | node.get_tensor_for_children(child).ndim, 143 | legend_dict, 144 | ) 145 | # label_new_idx = '→' + label_new_idx 146 | if label: 147 | label += "<br/>
" + label_new_idx 148 | else: 149 | label = label_new_idx 150 | return label or None 151 | 152 | 153 | def _draw_tree( 154 | tree, 155 | fileformat=None, 156 | filename=None, 157 | legend_dict=None, 158 | prior_nodes=False, 159 | view=True, 160 | show_node_names=False, 161 | integer_observations=False, 162 | ): 163 | if tree.current_iter == 0: 164 | tree._first_iteration(check_model_validity=False) 165 | fileformat = fileformat or "pdf" 166 | filename = filename or tree.name 167 | filename = Path(filename) 168 | suffix = filename.suffix 169 | if suffix: 170 | fileformat = suffix[1:] 171 | filename = filename.with_suffix("") 172 | legend_dict = legend_dict or {} 173 | all_index_id = set.union( 174 | *(set(node.index_id) for node in tree.census() if hasattr(node, "index_id")) 175 | ) 176 | used_letters = set() 177 | for elem in legend_dict.values(): 178 | letter = elem.get("letter", None) 179 | if letter is not None: 180 | used_letters.add(letter[0]) 181 | for index_id in all_index_id: 182 | if not index_id in legend_dict or "letter" not in legend_dict[index_id]: 183 | idx_id = index_id if isinstance(index_id, str) else "" 184 | idx_id2 = re.sub("[^a-z]+", "", idx_id) 185 | letter = next( 186 | ( 187 | let 188 | for let in idx_id2 + string.ascii_lowercase 189 | if let not in used_letters 190 | ), 191 | None, 192 | ) 193 | if letter is None: 194 | raise ValueError( 195 | "Not enough letters in the alphabet. Please provide" 196 | "`legend_dict` with your own letters with subscripts" 197 | "to represent each index_id" 198 | ) 199 | if not index_id in legend_dict: 200 | legend_dict[index_id] = {} 201 | legend_dict[index_id]["letter"] = letter 202 | used_letters.add(letter[0]) 203 | if "description" not in legend_dict[index_id] and letter != index_id: 204 | legend_dict[index_id]["description"] = index_id 205 | 206 | graph = graphviz.Digraph( 207 | name=tree.name, format=fileformat, filename=filename, engine="dot" 208 | ) 209 | graph.attr("node", color="#0b51c3f2", fontname="Times-Roman", height="0") 210 | graph.attr("edge", color="#0b51c3f2", arrowhead="none", fontname="Times-Roman") 211 | for node in tree.census(): 212 | xlabel = _html(_small_font(node.name)) if show_node_names else None 213 | 214 | # special form for root 215 | if node is tree: 216 | # node_label = """</ / / / / />""" 217 | # node_label = '<___
218 | node_label = ( 219 | """<""" 220 | """<table border="0" cellborder="0" cellspacing="0" cellpadding="0">""" 221 | """<tr><td border="1" sides="b" width="40" height="4" fixedsize="true"></td></tr>""" 222 | """<tr><td width="40" height="3" fixedsize="true"></td></tr>""" 223 | """<tr><td border="1" sides="b" width="32" height="1" fixedsize="true"></td></tr>""" 224 | """<tr><td width="40" height="3" fixedsize="true"></td></tr>""" 225 | """<tr><td border="1" sides="b" width="24" height="1" fixedsize="true"></td></tr>""" 226 | """<tr><td width="40" height="3" fixedsize="true"></td></tr>""" 227 | """<tr><td border="1" sides="b" width="16" height="1" fixedsize="true"></td></tr>""" 228 | """<tr><td width="40" height="3" fixedsize="true"></td></tr>""" 229 | """<tr><td border="1" sides="b" width="8" height="1" fixedsize="true"></td></tr>""" 230 | """<tr><td width="40" height="3" fixedsize="true"></td></tr>""" 231 | """<tr><td border="1" sides="b" width="2" height="1" fixedsize="true"></td></tr>""" 232 | """<tr><td width="40" height="3" fixedsize="true"></td></tr>""" 233 | """</table>""" 234 | """
>""" 235 | ) 236 | # node_label = "" 237 | graph.node( 238 | str(id(node)), 239 | label=node_label, 240 | shape="plain", 241 | peripheries="0", 242 | xlabel=xlabel, 243 | forcelabels="true", 244 | # image="/home/fuentes/Projects/wonterfact/wonterfact/images/ground.svg", 245 | # shape="epsf", 246 | # shapefile="/home/fuentes/Projects/wonterfact/wonterfact/images/ground.ps", 247 | ) 248 | elif isinstance(node, observers.BlindObs): 249 | node_label = _make_node_label("", "?", True) 250 | with graph.subgraph(name="observers") as subg: 251 | subg.attr(rank="same") 252 | subg.node( 253 | str(id(node)), 254 | label=node_label, 255 | shape="ellipse", 256 | # shape="doublecircle", 257 | peripheries="2", 258 | style="diagonals", 259 | xlabel=xlabel, 260 | forcelabels="true", 261 | ) 262 | # graph.node( 263 | # str(id(node)), 264 | # label=node_label, 265 | # shape="ellipse", 266 | # style="diagonals", 267 | # peripheries="2", 268 | # xlabel=xlabel, 269 | # ) 270 | # all nodes except root 271 | else: 272 | # let us compute node_label 273 | if not node.index_id: 274 | node_label = _make_node_label( 275 | "", "·", underline=node.tensor_has_energy 276 | ) 277 | else: 278 | if node.tensor_has_energy or node.level == 0: 279 | index_label = "".join( 280 | [legend_dict[idx]["letter"] for idx in node.index_id[::-1]] 281 | ) 282 | index_label = _italic(index_label) 283 | underline = node.tensor_has_energy 284 | else: 285 | index_label = _insert_given_symbol( 286 | node.index_id, node.norm_axis, node.tensor.ndim, legend_dict 287 | ) 288 | underline = False 289 | node_prefix = _get_node_prefix( 290 | node, legend_dict, integer_observations=integer_observations 291 | ) 292 | node_prefix = _small_font(node_prefix) 293 | node_label = _make_node_label(node_prefix, index_label, underline) 294 | 295 | # special shape for observers 296 | if isinstance(node, observers._Observer): 297 | with graph.subgraph(name="observers") as subg: 298 | subg.attr(rank="same") 299 | subg.node( 300 | str(id(node)), 301 | label=node_label, 302 | shape="ellipse", 303 | # shape="doublecircle", 304 | peripheries="2", 305 | style="diagonals", 306 | xlabel=xlabel, 307 | forcelabels="true", 308 | ) 309 | # special shape for hyperparameter buds 310 | elif isinstance(node, buds._Bud): 311 | if hasattr(node, "update_period") and node.update_period == 0: 312 | peripheries = "2" 313 | else: 314 | peripheries = "1" 315 | if prior_nodes: 316 | graph.node( 317 | str(id(node)), 318 | label=node_label, 319 | shape="box", 320 | peripheries=peripheries, 321 | # style="diagonals", 322 | xlabel=xlabel, 323 | forcelabels="true", 324 | ) 325 | # leaves and operators 326 | else: 327 | if hasattr(node, "update_period") and node.update_period == 0: 328 | peripheries = "2" 329 | else: 330 | peripheries = "1" 331 | graph.node( 332 | str(id(node)), 333 | label=node_label, 334 | shape="ellipse", 335 | peripheries=peripheries, 336 | xlabel=xlabel, 337 | forcelabels="true", 338 | ) 339 | for node in tree.census(): 340 | if node != tree: 341 | if node.level != 0 or prior_nodes: 342 | for child in node.list_of_children: 343 | edge_label = _get_edge_label(node, child, legend_dict) 344 | if edge_label: 345 | edge_label = _html(_small_font(edge_label)) 346 | graph.edge(str(id(node)), str(id(child)), taillabel=edge_label) 347 | 348 | if any("description" in idx_dict for idx_dict in legend_dict.values()): 349 | label_legend = ( 350 | '' 351 | '" 353 | ) 354 | for idx_dict in legend_dict.values(): 355 | if "description" in idx_dict: 356 | label_legend += ( 357 | '' 358 | '' 
359 | "".format(idx_dict["letter"], idx_dict["description"]) 360 | ) 361 | label_legend = label_legend + "
' 352 | "Indexes
{}:{}
" 362 | with graph.subgraph(name="observers") as subgraph: 363 | subgraph.attr(rank="same") 364 | subgraph.node( 365 | "legend", 366 | label=_html(label_legend), 367 | shape="box", 368 | # fontsize='8', 369 | style="dotted", 370 | ) 371 | 372 | if view: 373 | graph.view(cleanup=True) 374 | else: 375 | graph.render(cleanup=True) 376 | return graph 377 | 378 | 379 | def _html(input_str): 380 | return "<" + _clean_html(_detect_sub(input_str)) + ">" 381 | 382 | 383 | def _clean_html(input_str): 384 | input_str = "".join(input_str.split("")) 385 | input_str = "".join(input_str.split("")) 386 | return input_str 387 | 388 | 389 | def _italic(input_str): 390 | output_str = "" + input_str + "" 391 | return output_str 392 | 393 | 394 | def _underline(input_str): 395 | return "" + _italic(input_str) + "" 396 | 397 | 398 | def _small_font(input_str, fontsize=8): 399 | if not input_str: 400 | return input_str 401 | return ''.format(fontsize) + input_str + "" 402 | 403 | 404 | def _detect_sub(input_str): 405 | def change_sub(input_str): 406 | return "" + input_str.group()[2:-1] + "" 407 | 408 | return re.sub("_{[^}]+}", change_sub, input_str) 409 | 410 | 411 | def _insert_given_symbol(index_id, norm_axis, ndim, legend_dict): 412 | norm_axis = range(ndim) if norm_axis is None else norm_axis 413 | len_norm = len(norm_axis) 414 | if set(norm_axis) == set(range(ndim - len_norm, ndim)): 415 | index_id_list = index_id 416 | else: 417 | index_id_list = [ 418 | idx for num_idx, idx in enumerate(index_id) if num_idx not in norm_axis 419 | ] + [idx for num_idx, idx in enumerate(index_id) if num_idx in norm_axis] 420 | index_label = "".join([legend_dict[idx]["letter"] for idx in index_id_list[::-1]]) 421 | if len_norm < len(index_id_list): 422 | index_label = ( 423 | _italic(index_label[:len_norm]) + " |" + _italic(index_label[len_norm:]) 424 | ) 425 | else: 426 | index_label = _italic(index_label) 427 | if len_norm == 0: 428 | index_label = "·" + index_label 429 | return index_label 430 | 431 | 432 | def _make_node_label(prefix, suffix, underline=False): 433 | if underline: 434 | underline_str = ' sides="b"' 435 | else: 436 | underline_str = ' border="0"' 437 | if prefix: 438 | first_col = '{} '.format(prefix) 439 | else: 440 | first_col = "" 441 | node_label = ( 442 | '' 443 | '' 444 | "{}{}" 445 | "" 446 | "
".format(first_col, underline_str, suffix) 447 | ) 448 | return _html(node_label) 449 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright 2020 Benoit Fuentes 3 | # 4 | # This file is part of Wonterfact. 5 | # 6 | # Wonterfact is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # any later version. 10 | # 11 | # Wonterfact is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Wonterfact. If not, see . 18 | # ---------------------------------------------------------------------------- 19 | 20 | """Tests for all methods in utils module""" 21 | 22 | # Python standard library 23 | 24 | # Third-party imports 25 | import numpy as np 26 | import numpy.random as npr 27 | import pytest 28 | 29 | try: 30 | import cupy as cp # pylint: disable=import-error 31 | except: 32 | cp = None 33 | 34 | # wonterfact imports 35 | from wonterfact import utils 36 | from wonterfact import glob 37 | 38 | 39 | def prepare_for_numpy_einsum(args): 40 | letter_to_int = {} 41 | for num_letter, letter in enumerate(set(args[1] + args[3] + args[5])): 42 | letter_to_int[letter] = num_letter 43 | sub0 = [letter_to_int[letter] for letter in args[1]] 44 | sub1 = [letter_to_int[letter] for letter in args[3]] 45 | subout = [letter_to_int[letter] for letter in args[5]] 46 | return [args[0], sub0, args[2], sub1, subout] 47 | 48 | 49 | @pytest.fixture( 50 | scope="module", params=["cpu", pytest.param("gpu", marks=pytest.mark.gpu)] 51 | ) 52 | def backend_and_operations_as_list(request): 53 | """ 54 | Returns a list of (op0, sub0, op1, sub1, out, subout) 55 | """ 56 | xp = cp if request.param == "gpu" else np 57 | arr0 = xp.random.rand(4, 4, 4) 58 | arr1 = xp.random.rand(4, 4, 4) 59 | # Please update 'expected_func' dictionary in 60 | # 'test_parse_einsum_two_operands_input' if you update the following list. 
61 | input_list = [ 62 | ["ijk", "ijk", "ijk"], 63 | ["ijk", "ijk", "ij"], 64 | ["ijk", "ijk", "i"], 65 | ["ijk", "ijk", ""], 66 | ["ijk", "ijk", "kji"], 67 | ["ijk", "jik", "jk"], 68 | ["ijk", "klm", "ijlm"], 69 | ["ijk", "klm", "im"], 70 | ["ijk", "mkj", "im"], 71 | ["ijk", "imj", "im"], 72 | ["ijk", "ikl", "ijl"], 73 | ["ijk", "ikl", "jl"], 74 | ["ijk", "kli", "ji"], 75 | ] 76 | operation_list = [ 77 | ( 78 | arr0, 79 | elem[0], 80 | arr1, 81 | elem[1], 82 | xp.zeros((4,) * len(elem[2]), dtype=float), 83 | elem[2], 84 | ) 85 | for elem in input_list 86 | ] 87 | return xp, operation_list 88 | 89 | 90 | @pytest.fixture(scope="module") 91 | def operations_as_list(backend_and_operations_as_list): 92 | return backend_and_operations_as_list[1] 93 | 94 | 95 | @pytest.fixture(scope="module") 96 | def operations_as_string(operations_as_list): 97 | """ 98 | Returns a list of ('sub0,sub1->subout', op0, op1, out) 99 | """ 100 | operation_list = [ 101 | ("->".join([",".join(elem[1:4:2]), elem[5]]), elem[0], elem[2], elem[4]) 102 | for elem in operations_as_list 103 | ] 104 | return operation_list 105 | 106 | 107 | @pytest.fixture(scope="module") 108 | def xp(backend_and_operations_as_list): 109 | return backend_and_operations_as_list[0] 110 | 111 | 112 | def test_has_a_shared_sublist(): 113 | assert utils._has_a_shared_sublist([], [1, 2, 3]) 114 | assert utils._has_a_shared_sublist([1, 2, 3, 4], [2, 5, 3]) 115 | assert not utils._has_a_shared_sublist([1, 2, 3, 4], [0, 3, 2]) 116 | assert not utils._has_a_shared_sublist( 117 | [1, 2, 3, 4], [1, 2, 5, 6, 7, 8, 5, 76, 4, 3] 118 | ) 119 | 120 | 121 | def test_get_transpose_and_slice(): 122 | sub1, sub2, sub_out = "ij", "jkl", "kilj" 123 | transpose1, slice1 = utils.get_transpose_and_slice(sub1, sub_out) 124 | transpose2, slice2 = utils.get_transpose_and_slice(sub2, sub_out) 125 | transpose3, slice3 = utils.get_transpose_and_slice("", "") 126 | 127 | assert transpose1 == (0, 1) 128 | assert transpose2 == (1, 2, 0) 129 | assert slice1 == (None, slice(None), None, slice(None)) 130 | assert slice2 == (slice(None), None, slice(None), slice(None)) 131 | assert transpose3 == () 132 | assert slice3 == Ellipsis 133 | 134 | 135 | def test_supscript_summation(xp): 136 | sub1, sub2, sub_out = "ij", "jkl", "kilj" 137 | op1 = xp.arange(3 * 5).reshape((3, 5)) 138 | op2 = xp.arange(5 * 2 * 4).reshape((5, 2, 4)) 139 | res = utils.supscript_summation(op1, sub1, op2, sub2, sub_out) 140 | assert xp.allclose(res, op1[:, None] + op2.transpose((1, 2, 0))[:, None]) 141 | 142 | 143 | def test_parse_einsum_args(): 144 | args1 = ("fd,dt,d->ft", 1, 2, 3) 145 | args2 = (1, "fd", 2, "dt", 3, "d", "ft") 146 | out1 = utils._parse_einsum_args(*args1) 147 | out2 = utils._parse_einsum_args(*args2) 148 | assert tuple(out1[:-1:2]) == args2[:-1:2] 149 | assert tuple(out2[:-1:2]) == args2[:-1:2] 150 | 151 | 152 | def test_einsum(operations_as_string, xp): 153 | for elem in operations_as_string: 154 | assert xp.allclose(utils.einsum(*elem[:3]), xp.einsum(*elem[:3])) 155 | for elem in operations_as_string: 156 | elem[3][...] 
= 0.0 157 | utils.einsum(*elem[:3], out=elem[3]) 158 | assert xp.allclose(elem[3], xp.einsum(*elem[:3])) 159 | 160 | arr1 = xp.random.rand(4, 4, 4) 161 | arr2 = xp.random.rand(4, 4, 4) 162 | arr3 = xp.random.rand(4, 4) 163 | assert xp.allclose( 164 | utils.einsum("ijk,ikl,il->ijl", arr1, arr2, arr3), 165 | xp.einsum("ijk,ikl,il->ijl", arr1, arr2, arr3), 166 | ) 167 | 168 | 169 | def test_einsum_two_operands(operations_as_list, xp): 170 | backend = xp.__name__ 171 | for elem in operations_as_list: 172 | if backend == "cupy": 173 | pass 174 | assert xp.allclose( 175 | utils._einsum_two_operands(*elem[:4], elem[5], backend), 176 | xp.einsum(*prepare_for_numpy_einsum(elem)), 177 | ) 178 | for elem in operations_as_list: 179 | elem[4][...] = 0.0 180 | utils._einsum_two_operands(*elem[:4], elem[5], backend, out=elem[4]), 181 | assert xp.allclose(elem[4], xp.einsum(*prepare_for_numpy_einsum(elem))) 182 | 183 | 184 | def test_element_wise_mult(xp): 185 | arr1 = xp.random.rand(4, 4, 4) 186 | arr2 = xp.random.rand(4, 4) 187 | out = xp.ones((4, 4, 4)) 188 | sub1 = sub2 = sub_out = None 189 | backend = xp.__name__ 190 | transpose1 = (2, 0, 1) 191 | transpose2 = (1, 0) 192 | slice1 = Ellipsis 193 | slice2 = (slice(None), None) 194 | utils._element_wise_mult( 195 | arr1, 196 | sub1, 197 | arr2, 198 | sub2, 199 | out, 200 | sub_out, 201 | backend, 202 | transpose1, 203 | transpose2, 204 | slice1, 205 | slice2, 206 | ) 207 | assert xp.allclose(out, xp.einsum("ijk,jk->kij", arr1, arr2)) 208 | 209 | 210 | def test_regular_einsum(operations_as_list, xp): 211 | backend = xp.__name__ 212 | for elem in operations_as_list: 213 | elem[4][...] = 0.0 214 | utils._regular_einsum(*elem, backend) 215 | assert xp.allclose(elem[4], xp.einsum(*prepare_for_numpy_einsum(elem))) 216 | 217 | 218 | def test_sequential_tensor_dot(xp): 219 | op1 = xp.random.rand(4, 4, 4) 220 | sub1 = "ijk" 221 | op2 = xp.random.rand(4, 4, 4) 222 | sub2 = "kil" 223 | out = xp.random.rand(4, 4, 4) 224 | sub_out = "lij" 225 | backend = xp.__name__ 226 | list_of_slice_out = [(slice(None), dim) for dim in range(4)] 227 | list_of_slice_1 = [(dim,) for dim in range(4)] 228 | list_of_slice_2 = [(slice(None), dim) for dim in range(4)] 229 | new_sub1 = ["j", "k"] 230 | new_sub2 = ["k", "l"] 231 | new_sub_out = ["l", "j"] 232 | out_shape = (4, 4, 4) 233 | utils._sequential_tensor_dot( 234 | op1, 235 | sub1, 236 | op2, 237 | sub2, 238 | out, 239 | sub_out, 240 | backend, 241 | list_of_slice_out, 242 | list_of_slice_1, 243 | list_of_slice_2, 244 | new_sub1, 245 | new_sub2, 246 | new_sub_out, 247 | out_shape, 248 | ) 249 | assert xp.allclose( 250 | out, xp.einsum("{},{}->{}".format(sub1, sub2, sub_out), op1, op2) 251 | ) 252 | 253 | 254 | def test_parse_einsum_two_operands_input(operations_as_list, xp): 255 | # this refers operations_as_list 256 | expected_func = { 257 | # cupy 258 | cp: [ 259 | utils._element_wise_mult, 260 | utils._element_wise_mult_and_sum, 261 | utils._element_wise_mult_and_sum, 262 | utils._element_wise_mult_and_sum, 263 | utils._element_wise_mult, 264 | utils._element_wise_mult_and_sum, 265 | utils._sequential_tensor_dot, 266 | utils._regular_einsum, 267 | utils._sequential_tensor_dot, 268 | utils._regular_einsum, 269 | utils._sequential_tensor_dot, 270 | utils._sequential_tensor_dot, 271 | utils._regular_einsum, 272 | ], 273 | # numpy 274 | np: [ 275 | utils._element_wise_mult, 276 | utils._regular_einsum, 277 | utils._regular_einsum, 278 | utils._regular_einsum, 279 | utils._element_wise_mult, 280 | utils._regular_einsum, 281 | 
utils._sequential_tensor_dot, 282 | utils._regular_einsum, 283 | utils._sequential_tensor_dot, 284 | utils._regular_einsum, 285 | utils._sequential_tensor_dot, 286 | utils._sequential_tensor_dot, 287 | utils._regular_einsum, 288 | ], 289 | } 290 | backend = xp.__name__ 291 | for elem, func in zip(operations_as_list, expected_func[xp]): 292 | assert ( 293 | utils._parse_einsum_two_operands_input( 294 | elem[0].shape, elem[1], elem[2].shape, elem[3], elem[5], backend 295 | )[0] 296 | == func 297 | ) 298 | 299 | 300 | def test_element_wise_mult_and_sum(xp): 301 | if xp == cp: 302 | op1 = xp.random.rand(4, 4, 4) 303 | sub1 = "ijk" 304 | op2 = xp.random.rand(4, 4, 4) 305 | sub2 = "jik" 306 | out = xp.random.rand(4, 4) 307 | sub_out = "jk" 308 | backend = xp.__name__ 309 | switch_operand = False 310 | transpose1 = (1, 2, 0) 311 | transpose2 = (0, 2, 1) 312 | slice1 = Ellipsis 313 | sum_axis = 2 314 | utils._element_wise_mult_and_sum( 315 | op1, 316 | sub1, 317 | op2, 318 | sub2, 319 | out, 320 | sub_out, 321 | backend, 322 | switch_operand, 323 | transpose1, 324 | transpose2, 325 | slice1, 326 | sum_axis, 327 | ) 328 | assert xp.allclose( 329 | out, xp.einsum("{},{}->{}".format(sub1, sub2, sub_out), op1, op2) 330 | ) 331 | else: 332 | assert True 333 | 334 | 335 | def test_einsum_as_dot(xp): 336 | op1 = xp.random.rand(4, 4, 4) 337 | sub1 = "ijk" 338 | op2 = xp.random.rand(4, 4, 4) 339 | sub2 = "kil" 340 | sub_out = "lj" 341 | backend = xp.__name__ 342 | assert xp.allclose( 343 | utils._einsum_as_dot(op1, sub1, op2, sub2, sub_out, backend), 344 | xp.einsum("{},{}->{}".format(sub1, sub2, sub_out), op1, op2), 345 | ) 346 | 347 | 348 | def test_einconv(operations_as_list, xp): 349 | # first, we test with no convolution 350 | for elem in operations_as_list: 351 | elem[4][...] = 0.0 352 | utils.einconv(*elem[:4], elem[5], out=elem[4]) 353 | assert xp.allclose(elem[4], xp.einsum(*prepare_for_numpy_einsum(elem))) 354 | 355 | # get convolution method for numpy or scipy 356 | if xp == np: 357 | from scipy.ndimage import convolve 358 | elif xp == cp: 359 | from cupyx.scipy.ndimage import convolve # pylint: disable=import-error 360 | 361 | # 1D convolution 362 | arr1 = xp.random.rand(10, 3, 10) 363 | arr2 = xp.random.rand(2, 3) 364 | sub1 = "ijk" 365 | sub2 = "kj" 366 | subout = "jki" 367 | conv_idx_list = [ 368 | "k", 369 | ] 370 | args = (arr1, sub1, arr2, sub2, subout) 371 | out = utils.einconv(*args, conv_idx_list=conv_idx_list) 372 | out2 = xp.empty((3, 9, 10)) 373 | for ii in range(10): 374 | for jj in range(3): 375 | out2[jj, :, ii] = convolve(arr1[ii, jj, :], arr2[:, jj])[:-1] 376 | assert xp.allclose(out, out2) 377 | 378 | # independent with respect to operands order ? 379 | assert xp.allclose( 380 | out, utils.einconv(*args[2:4], *args[:2], args[-1], conv_idx_list=conv_idx_list) 381 | ) 382 | 383 | # in correlation mode 384 | out = utils.einconv(*args, conv_idx_list=conv_idx_list, compute_correlation=True) 385 | for ii in range(10): 386 | for jj in range(3): 387 | out2[jj, :, ii] = convolve(arr1[ii, jj, :], arr2[::-1, jj])[:-1] 388 | assert xp.allclose(out, out2) 389 | 390 | # 2D convolution 391 | conv_idx_list = ["k", "j"] 392 | out = utils.einconv(*args, conv_idx_list=conv_idx_list) 393 | out2 = xp.empty((1, 9, 10)) 394 | for ii in range(10): 395 | out2[:, :, ii] = convolve(arr1[ii, :, :], arr2.T)[1:-1, :-1] 396 | assert xp.allclose(out, out2) 397 | 398 | # independent with respect to operands order ? 
399 | assert xp.allclose( 400 | out, utils.einconv(*args[2:4], *args[:2], args[-1], conv_idx_list=conv_idx_list) 401 | ) 402 | # in correlation mode 403 | out = utils.einconv(*args, conv_idx_list=conv_idx_list, compute_correlation=True) 404 | for ii in range(10): 405 | for jj in range(3): 406 | out2[:, :, ii] = convolve(arr1[ii, :, :], arr2.T[::-1, ::-1])[1:-1, :-1] 407 | assert xp.allclose(out, out2) 408 | 409 | # should raise an error if one array is not larger than the other in every 410 | # dimension where convolution is performed 411 | with pytest.raises(ValueError, match=r".*must be at least as large*"): 412 | utils.einconv( 413 | xp.ones((2, 1)), 414 | "ij", 415 | xp.ones((1, 2)), 416 | "ij", 417 | "ij", 418 | conv_idx_list=["i", "j"], 419 | ) 420 | with pytest.raises(ValueError, match=r".*must be at least as large*"): 421 | utils.einconv( 422 | xp.ones((1, 2)), 423 | "ij", 424 | xp.ones((2, 1)), 425 | "ij", 426 | "ij", 427 | conv_idx_list=["i", "j"], 428 | ) 429 | 430 | 431 | def test_make_unique_hashable(): 432 | forbiden_set = set([0, 1, 2, 3, "a", "b", "c"]) 433 | assert utils._make_unique_hashable(forbiden_set) not in forbiden_set 434 | 435 | 436 | def test_find_equality_root(xp): 437 | backend = xp.__name__ 438 | max_iter = 10 439 | 440 | # inequality constraint should not be binding 441 | arr = 100 * xp.ones((2, 2, 2)) 442 | const_coef = xp.array([1, -1]) 443 | arr[:, :, :1] = 200 444 | norm_arr = arr.sum(2, keepdims=True) 445 | sigma = utils._find_equality_root( 446 | arr, norm_arr, const_coef, max_iter, type="inequality", backend=backend 447 | ) 448 | assert xp.allclose(sigma, 0) 449 | assert xp.allclose((arr / norm_arr).sum(2), 1.0) 450 | 451 | # inequality constraint should be binding 452 | arr[:, :, :1] = 50 453 | norm_arr = arr.sum(2, keepdims=True) 454 | sigma = utils._find_equality_root( 455 | arr, norm_arr, const_coef, max_iter, type="inequality", backend=backend 456 | ) 457 | assert xp.allclose(((arr / (norm_arr - sigma * const_coef)) * const_coef).sum(2), 0) 458 | assert xp.allclose((arr / (norm_arr - sigma * const_coef)).sum(2), 1.0) 459 | 460 | # force equality constraint 461 | arr[:, :, :1] = 200 462 | norm_arr = arr.sum(2, keepdims=True) 463 | sigma = utils._find_equality_root( 464 | arr, norm_arr, const_coef, max_iter, type="equality", backend=backend 465 | ) 466 | assert np.allclose(((arr / (norm_arr - sigma * const_coef)) * const_coef).sum(2), 0) 467 | assert np.allclose((arr / (norm_arr - sigma * const_coef)).sum(2), 1.0) 468 | 469 | # inequality constraint should not be binding 470 | const_coef = xp.ones((2, 2)) 471 | const_coef[:, 1] = -1 472 | arr[:, :, 0] = 200 473 | norm_arr = arr.sum((1, 2), keepdims=True) 474 | sigma = utils._find_equality_root( 475 | arr, norm_arr, const_coef, max_iter, type="inequality", backend=backend 476 | ) 477 | assert xp.allclose(sigma, 0) 478 | 479 | # inequality constraint should be binding 480 | arr[:, :, 0] = 50 481 | norm_arr = arr.sum((1, 2), keepdims=True) 482 | sigma = utils._find_equality_root( 483 | arr, norm_arr, const_coef, max_iter, type="inequality", backend=backend 484 | ) 485 | assert xp.allclose( 486 | ((arr / (norm_arr - sigma * const_coef)) * const_coef).sum((1, 2)), 0 487 | ) 488 | assert xp.allclose((arr / (norm_arr - sigma * const_coef)).sum((1, 2)), 1.0) 489 | 490 | # when no normalization 491 | norm_arr = xp.array(1 + 0.001) 492 | const_coef = xp.ones((2, 2, 2)) 493 | const_coef[:, :, 1] = -1 494 | arr[:, :, 0] = 200 495 | sigma = utils._find_equality_root( 496 | arr, norm_arr, const_coef, max_iter, 
type="inequality", backend=backend 497 | ) 498 | assert xp.allclose(sigma, 0) 499 | 500 | arr[:, :, 0] = 50 501 | sigma = utils._find_equality_root( 502 | arr, norm_arr, const_coef, max_iter, type="inequality", backend=backend 503 | ) 504 | assert xp.allclose(((arr / (norm_arr - sigma * const_coef)) * const_coef).sum(), 0) 505 | 506 | 507 | def test_xlogy(xp): 508 | arr1 = xp.ones(1, dtype=float) 509 | arr2 = xp.ones(1, dtype=float) 510 | assert xp.allclose(utils.xlogy(arr1, arr2), 0.0) 511 | 512 | arr1[...] = 0.0 513 | arr2[...] = 0.0 514 | assert xp.allclose(utils.xlogy(arr1, arr2), 0.0) 515 | 516 | arr1[...] = 2.0 517 | arr2[...] = 2.0 518 | assert xp.allclose(utils.xlogy(arr1, arr2), arr1 * xp.log(arr2)) 519 | 520 | out = xp.ones_like(arr1) 521 | utils.xlogy(arr1, arr2, out=out) 522 | assert xp.allclose(utils.xlogy(arr1, arr2), out) 523 | 524 | 525 | def test_cumsum_last_axis(xp): 526 | arr = xp.random.rand(4, 4, 10) 527 | out = xp.ones_like(arr) 528 | utils.cumsum_last_axis(arr, out) 529 | assert xp.allclose(out, xp.cumsum(arr, axis=-1)) 530 | 531 | 532 | def test_next_pow_of_two(): 533 | assert utils.next_pow_of_two(0) == 1 534 | assert utils.next_pow_of_two(0.5) == 1 535 | assert utils.next_pow_of_two(1) == 1 536 | assert utils.next_pow_of_two(1.1) == 2 537 | assert utils.next_pow_of_two(2) == 2 538 | assert utils.next_pow_of_two(8.5) == 16 539 | 540 | 541 | def test_exp_digamma(xp): 542 | if xp == np: 543 | from scipy.special import digamma 544 | elif xp == cp: 545 | from cupyx.scipy.special import digamma # pylint: disable=import-error 546 | 547 | arr = xp.array(1.0) 548 | assert xp.allclose(utils.exp_digamma(arr), xp.exp(digamma(arr))) 549 | arr[...] = 0.0 550 | assert xp.allclose(utils.exp_digamma(arr), xp.exp(digamma(arr))) 551 | arr[...] = 10.0 552 | assert xp.allclose(utils.exp_digamma(arr), xp.exp(digamma(arr))) 553 | 554 | 555 | def test_hyp0f1ln(xp): 556 | from scipy.special import hyp0f1 557 | 558 | if xp == cp: 559 | hyp0f1_np = hyp0f1 560 | 561 | def hyp0f1(arr1, arr2): # pylint: disable=function-redefined 562 | arr1 = xp.asnumpy(arr1) 563 | arr2 = xp.asnumpy(arr2) 564 | return xp.array(hyp0f1_np(arr1, arr2)) 565 | 566 | arr1 = xp.array([2.0, 1e-8, 1e5]) 567 | arr2 = xp.array([3.0, 2e-8, 2e5]) 568 | assert xp.allclose(utils.hyp0f1ln(arr1, arr2), xp.log(hyp0f1(arr1, arr2))) 569 | arr2[...] = 0.0 570 | assert xp.allclose(utils.hyp0f1ln(arr1, arr2), 0) 571 | 572 | 573 | def test_bessel_ratio(xp): 574 | from scipy.special import iv 575 | 576 | if xp == cp: 577 | iv_np = iv 578 | 579 | def iv(arr1, arr2): # pylint: disable=function-redefined 580 | arr1 = xp.asnumpy(arr1) 581 | arr2 = xp.asnumpy(arr2) 582 | return xp.array(iv_np(arr1, arr2)) 583 | 584 | arr1 = xp.array([2.0, 1e-8, 1e2]) 585 | arr2 = xp.array([3.0, 2e-8, 2e2]) 586 | assert xp.allclose( 587 | utils.bessel_ratio(arr1, arr2), arr2 * iv(arr1 + 1, arr2) / iv(arr1, arr2) 588 | ) 589 | arr2[...] 
= 0.0 590 | assert xp.allclose(utils.bessel_ratio(arr1, arr2), 0) 591 | 592 | 593 | def test_forced_iter(): 594 | elem = [3, 4] 595 | output = tuple(ii for ii in utils.forced_iter(elem)) 596 | assert output == tuple(elem) 597 | elem = 3 598 | output = tuple(ii for ii in utils.forced_iter(elem)) 599 | assert output == (elem,) 600 | elem = () 601 | output = tuple(ii for ii in utils.forced_iter(elem)) 602 | assert output == () 603 | 604 | 605 | def test_explicit_slice(): 606 | ndim = 3 607 | list_of_input = [ 608 | Ellipsis, 609 | (Ellipsis,), 610 | (slice(None), Ellipsis), 611 | (slice(None), Ellipsis, slice(None)), 612 | (Ellipsis, slice(None)), 613 | (slice(None),), 614 | slice(None), 615 | ] 616 | for sl in list_of_input: 617 | assert utils.explicit_slice(sl, ndim) == (slice(None),) * ndim 618 | 619 | list_of_input = [ 620 | 1, 621 | (1,), 622 | (1, Ellipsis), 623 | (1, Ellipsis, slice(None)), 624 | (1, slice(None)), 625 | (1, slice(None), Ellipsis), 626 | ] 627 | for sl in list_of_input: 628 | assert utils.explicit_slice(sl, ndim) == (1,) + (slice(None),) * (ndim - 1) 629 | 630 | list_of_input = [(slice(1), 2), (slice(1), 2, Ellipsis)] 631 | for sl in list_of_input: 632 | assert utils.explicit_slice(sl, ndim) == (slice(1), 2, slice(None)) 633 | 634 | assert utils.explicit_slice(Ellipsis, 0) == Ellipsis 635 | assert utils.explicit_slice(slice(None), 1) == (slice(None),) 636 | 637 | mask = np.array([[True, False], [False, True]]) 638 | list_of_input = [ 639 | mask, 640 | (mask, slice(None)), 641 | (mask, Ellipsis), 642 | ] 643 | for sl in list_of_input: 644 | explicit_sl = utils.explicit_slice(sl, ndim) 645 | assert len(explicit_sl) == ndim - 1 646 | assert np.alltrue(explicit_sl[0] == mask) 647 | assert explicit_sl[1:] == (slice(None),) * (ndim - 2) 648 | 649 | 650 | def test_real_to_2D_nonnegative(): 651 | arr = np.array([[1, -1, 2], [-3, 3.2, 0]]) 652 | arr_nonneg = utils.real_to_2D_nonnegative(arr) 653 | assert arr_nonneg.shape == arr.shape + (2,) 654 | assert np.allclose(arr_nonneg[..., 0], arr.clip(min=0)) 655 | assert np.allclose(arr_nonneg[..., 1], (-arr).clip(min=0)) 656 | 657 | 658 | def test_complex_to_2D_real(): 659 | arr = np.array([[1, -1, 2], [-3, 3.2, 0]]) + 1j * np.array( 660 | [[3, 1, -2], [3, 0, -1.1]] 661 | ) 662 | arr_real = utils.complex_to_2D_real(arr) 663 | assert arr_real.shape == arr.shape + (2,) 664 | assert np.allclose(arr_real[..., 0], arr.real) 665 | assert np.allclose(arr_real[..., 1], arr.imag) 666 | 667 | 668 | def test_complex_to_4D_nonnegative(): 669 | arr = np.array([[1, -1, 2], [-3, 3.2, 0]]) + 1j * np.array( 670 | [[3, 1, -2], [3, 0, -1.1]] 671 | ) 672 | arr_real = utils.complex_to_4D_nonnegative(arr) 673 | assert arr_real.shape == arr.shape + (2, 2) 674 | assert np.allclose(arr_real[..., 0, 0], arr.real.clip(min=0)) 675 | assert np.allclose(arr_real[..., 0, 1], (-arr.real).clip(min=0)) 676 | assert np.allclose(arr_real[..., 1, 0], arr.imag.clip(min=0)) 677 | assert np.allclose(arr_real[..., 1, 1], (-arr.imag).clip(min=0)) 678 | 679 | 680 | def test_clip_inplace(xp): 681 | arr_init = xp.array([3.0, 4.0]) 682 | 683 | arr_out = arr_init.copy() 684 | utils.clip_inplace(arr_out, a_min=3.5) 685 | assert (arr_out == xp.array([3.5, 4.0])).all() 686 | 687 | arr_out = arr_init.copy() 688 | utils.clip_inplace(arr_out, a_max=3.5) 689 | assert (arr_out == xp.array([3.0, 3.5])).all() 690 | 691 | arr_out = arr_init.copy() 692 | utils.clip_inplace(arr_out, a_min=3.5, a_max=3.8) 693 | assert (arr_out == xp.array([3.5, 3.8])).all() 694 | 695 | 696 | def 
test_inverse_gamma(xp): 697 | if xp == np: 698 | from scipy.special import digamma 699 | elif xp == cp: 700 | from cupyx.scipy.special import digamma # pylint: disable=import-error 701 | input_arr_pos = xp.array([1.0, 2.0, 3.443]) 702 | assert xp.allclose(utils.inverse_digamma(digamma(input_arr_pos)), input_arr_pos) 703 | assert xp.allclose(digamma(utils.inverse_digamma(input_arr_pos)), input_arr_pos) 704 | input_arr_neg = xp.array([-1.0, -2.0, -3.443]) 705 | assert xp.allclose(digamma(utils.inverse_digamma(input_arr_neg)), input_arr_neg) 706 | 707 | output_arr = xp.zeros_like(input_arr_pos) 708 | utils.inverse_digamma(input_arr_pos, out=output_arr) 709 | assert xp.allclose(utils.inverse_digamma(input_arr_pos), output_arr) 710 | 711 | 712 | def test_normalize(xp): 713 | tensor = xp.random.rand(3, 4, 2) 714 | axis = (1, 2) 715 | assert xp.allclose(utils.normalize(tensor, axis).sum(axis), 1) 716 | -------------------------------------------------------------------------------- /wonterfact/operators.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # Copyright 2020 Smart Impulse SAS, Benoit Fuentes 3 | # 4 | # This file is part of Wonterfact. 5 | # 6 | # Wonterfact is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # any later version. 10 | # 11 | # Wonterfact is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with Wonterfact. If not, see . 18 | # ---------------------------------------------------------------------------- 19 | 20 | """Module for all operator classes""" 21 | 22 | # Python System imports 23 | from functools import cached_property 24 | 25 | # Third-party imports 26 | import numpy as np 27 | from methodtools import lru_cache 28 | 29 | # Relative imports 30 | from . import utils 31 | from .core_nodes import _ChildNode, _DynNodeData 32 | from .glob_var_manager import glob 33 | 34 | 35 | class _Operator(_ChildNode, _DynNodeData): 36 | """ 37 | Mother class of all operators 38 | """ 39 | 40 | @cached_property 41 | def tensor_has_energy(self): 42 | return any(parent.tensor_has_energy for parent in self.list_of_parents) 43 | 44 | @cached_property 45 | def norm_axis(self): 46 | if self.tensor_has_energy: 47 | return None 48 | norm_axis = [] 49 | for num_idx, idx in enumerate(self.index_id): 50 | is_idx_in_norm_axis = next( 51 | ( 52 | bool(parent) 53 | for parent in self.list_of_parents 54 | if idx in parent.get_index_id_for_children(self) 55 | and parent.get_index_id_for_children(self).index(idx) 56 | in parent.get_norm_axis_for_children(self) 57 | ), 58 | False, 59 | ) 60 | if is_idx_in_norm_axis: 61 | norm_axis.append(num_idx) 62 | return tuple(norm_axis) 63 | 64 | 65 | class Proxy(_Operator): 66 | """ 67 | A proxy class to any node. 68 | 69 | Can be useful when one wants a parent and a child being linked several 70 | times (which is not possible straightforward). A proxy can have only one 71 | parent. 
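    For instance, a minimal sketch (assuming `import wonterfact as wtf`;
    node names and constructor arguments are purely illustrative, and the
    exact signature of `new_child` is not detailed here):

        leaf = wtf.LeafGamma(name="leaf")
        proxy = wtf.Proxy(name="proxy")
        leaf.new_child(proxy)
        # 'leaf' and 'proxy' can now both be linked to the same child,
        # so that the leaf's tensor enters that child twice.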
72 | """ 73 | 74 | # TODO: compatible with Mutliplexer as child 75 | # TODO: if no brother, could optimize tensor_update which is copied by its parent 76 | 77 | max_parents = 1 78 | 79 | def _give_update(self, parent, out=None): 80 | if out is None: 81 | return self.tensor_update 82 | out[...] = self.tensor_update 83 | 84 | def _initialization(self): 85 | self.tensor = self.first_parent.get_tensor_for_children(self) 86 | if not glob.xp.may_share_memory(self.tensor, self.first_parent.tensor): 87 | raise ValueError( 88 | """A Proxy's tensor should be a view to its parent's: 89 | I guess you cannot use this 'slice_for_child'""" 90 | ) 91 | self.tensor_update = glob.xp.zeros_like(self.tensor) 92 | 93 | def _check_filiation_ok(self, child=None, parent=None, **kwargs): 94 | if child is not None: 95 | if kwargs.get("slice_for_child", Ellipsis) != Ellipsis: 96 | raise ValueError( 97 | "`slice_for_child` argument cannot be specified when parent is a {} object".format( 98 | type(self) 99 | ) 100 | ) 101 | if kwargs.get("strides_for_child", None) != None: 102 | raise ValueError( 103 | "`strides_for_child` argument cannot be specified when parent is a {} object".format( 104 | type(self) 105 | ) 106 | ) 107 | if isinstance(child, Multiplexer): 108 | raise ValueError( 109 | "Proxy object cannot be a parent of Multiplexer object." 110 | ) 111 | super()._check_filiation_ok(child=child, parent=parent, **kwargs) 112 | 113 | 114 | class Multiplier(_Operator): 115 | """ 116 | Class of multiplier operator. 117 | 118 | Must have two and only two parents. Inner tensor is the multiplication of 119 | its parent's tensor. It can also be used to convolve its parent's tensor 120 | when the attribute 'conv_idx_ids' is provided (cf __init__ doctstring). The 121 | way the two parent's tensor are multiplied or convolved is automatically 122 | computed according the index IDs of all involved tensors. Only 'valid' mode 123 | for convolution is available (cf docstring of numpy.convolve for the 124 | definition of this mode). 125 | """ 126 | 127 | max_parents = 2 128 | 129 | def __init__(self, conv_idx_ids=None, **kwargs): 130 | """ 131 | Returns a Multiplier object which multiplies or convolves the parent's 132 | tensors. 133 | 134 | Parameters 135 | ---------- 136 | conv_idx_ids: sequence hashable, optional, default None 137 | When provided and different from None, the operator acts as a 138 | convolver. Must be of the form [idx1, idx2, ...] 139 | where each idx is the name of parents' and self's axes to be 140 | convolved (convolution can be performed on several axes). If 141 | parent's names of axes to be convolved are different, you should 142 | use option `index_id_for_child` in the filiation creation process 143 | (i.e. when calling `new_child` or `new_parent` method) to change 144 | them. 
145 | """ 146 | self.conv_idx_ids = conv_idx_ids or [] 147 | super().__init__(**kwargs) 148 | 149 | @cached_property 150 | def conv_idx_list_for_einconv(self): 151 | return [(idx,) * 3 for idx in self.conv_idx_ids] 152 | 153 | def _update_tensor(self, **kwargs): 154 | input_einconv = [ 155 | elem 156 | for node in self.list_of_parents 157 | for elem in [ 158 | node.get_tensor_for_children(self), 159 | node.get_index_id_for_children(self), 160 | ] 161 | ] + [self.index_id] 162 | 163 | utils.einconv( 164 | *input_einconv, out=self.tensor[...], conv_idx_list=self.conv_idx_ids 165 | ) 166 | 167 | def _bump(self, **kwargs): 168 | self._update_tensor() 169 | 170 | def _give_update(self, parent, out=None): 171 | input_einconv = [ 172 | elem 173 | for node in self.list_of_parents 174 | if node != parent 175 | for elem in [ 176 | node.get_tensor_for_children(self), 177 | node.get_index_id_for_children(self), 178 | ] 179 | ] + [ 180 | self.full_tensor_update, 181 | self.index_id, 182 | parent.get_index_id_for_children(self), 183 | ] 184 | update_tensor = utils.einconv( 185 | *input_einconv, 186 | compute_correlation=True, 187 | conv_idx_list=self.conv_idx_ids, 188 | out=out 189 | ) 190 | if out is None: 191 | return update_tensor 192 | 193 | def _initialization(self): 194 | # instantiation of tensor 195 | shape_dict = {} 196 | # first find all indexes that are involved in convolution 197 | for parent in self.list_of_parents: 198 | parent_tensor_shape = parent.get_tensor_for_children(self).shape 199 | parent_idx = parent.get_index_id_for_children(self) 200 | for idx, size in zip(parent_idx, parent_tensor_shape): 201 | if idx not in self.conv_idx_ids: 202 | shape_dict.update({idx: size}) 203 | # then we compute the resulting dimension for convolution indexes 204 | shape_dict_conv_diff = {} 205 | shape_dict_conv_sum = {} 206 | shape_dict_conv_min = {} 207 | for idx in self.conv_idx_ids: 208 | dim_iter = iter( 209 | parent.idx_to_shape_for_child(idx, self) 210 | for parent in self.list_of_parents 211 | if idx in parent.get_index_id_for_children(self) 212 | ) 213 | dim1, dim2 = next(dim_iter), next(dim_iter) 214 | dim1, dim2 = (dim2, dim1) if dim1 >= dim2 else (dim1, dim2) 215 | shape_dict_conv_diff.update({idx: dim2 - dim1 + 1}) 216 | shape_dict_conv_sum.update({idx: dim2 + dim1 - 1}) 217 | shape_dict_conv_min.update({idx: dim1}) 218 | 219 | shape_dict.update(shape_dict_conv_diff) # for inner tensor 220 | self.tensor = glob.xp.empty( 221 | tuple(shape_dict[idx] for idx in self.index_id), dtype=glob.float 222 | ) 223 | if self.update_period != 0: 224 | shape_dict.update(shape_dict_conv_sum) # for inner full_tensor_update 225 | full_shape = tuple(shape_dict[idx] for idx in self.index_id) 226 | self.full_tensor_update = glob.xp.ones(full_shape, dtype=glob.float) 227 | tensor_update_slice = tuple( 228 | slice(None) 229 | if idx not in self.conv_idx_ids 230 | else slice(shape_dict_conv_min[idx] - 1, -shape_dict_conv_min[idx] + 1) 231 | for idx in self.index_id 232 | ) 233 | self.tensor_update = self.full_tensor_update[tensor_update_slice] 234 | 235 | self._update_tensor() 236 | 237 | def _check_model_validity(self): 238 | super()._check_model_validity() 239 | parent_with_energy_list = [ 240 | parent for parent in self.list_of_parents if parent.tensor_has_energy 241 | ] 242 | if len(parent_with_energy_list) > 1: 243 | raise ValueError( 244 | "Problem with {}'s parents. At most one parent can have a " 245 | "tensor having energy i.e. 
not subject to normalization " 246 | "constraint.".format(self) 247 | ) 248 | if not self.tensor_has_energy and self.conv_idx_ids: 249 | raise ValueError( 250 | "Problem with {}. In convolution mode, inner tensor of a " 251 | "Multiplier must have energy. Please put a LeafGamma (or any " 252 | "Leaf carrying energy) upstream.".format(self) 253 | ) 254 | set_of_idx = set.union( 255 | *( 256 | set(parent.get_index_id_for_children(self)) 257 | for parent in self.list_of_parents 258 | ) 259 | ) 260 | set_of_idx.update(self.index_id) 261 | dict_idx = {idx: {} for idx in set_of_idx} 262 | for parent in self.list_of_parents: 263 | parent_norm_axis = parent.get_norm_axis_for_children(self) 264 | parent_index_id = parent.get_index_id_for_children(self) 265 | for idx in parent_index_id: 266 | dict_idx[idx].update( 267 | { 268 | parent: { 269 | "has_energy": parent.tensor_has_energy, 270 | "is_normalized": ( 271 | True 272 | if parent.tensor_has_energy 273 | else parent_index_id.index(idx) in parent_norm_axis 274 | ), # it is simpler to consider normalized when tensor_has_energy 275 | "is_marginalized": not idx in self.index_id, 276 | "is_convolved": idx in self.conv_idx_ids, 277 | } 278 | } 279 | ) 280 | for idx, parents_dict in dict_idx.items(): 281 | # if idx is convolved, it must appear in two parents and it must 282 | # be normalized in both parents. 283 | if idx in self.conv_idx_ids: 284 | if len(parents_dict) != 2: 285 | raise ValueError( 286 | "Cannot convolve '{}' in {}. This index should belong " 287 | "to two parents. You might make use of `index_id_for_child`" 288 | " argument during filiation creation.".format(idx, self) 289 | ) 290 | if not all(info["is_normalized"] for info in parents_dict.values()): 291 | raise ValueError( 292 | "Model is wrong at {} level. Axis to be convolved " 293 | "should be normalized or belong to a tensor that has energy.".format(self) 294 | ) 295 | # idx can only be normalized once unless it is convolved 296 | normalized_parent_no_conv_list = [ 297 | parent 298 | for parent, info in parents_dict.items() 299 | if info["is_normalized"] and not info["is_convolved"] 300 | ] 301 | normalized_parent_list = [ 302 | parent for parent, info in parents_dict.items() if info["is_normalized"] 303 | ] 304 | if len(normalized_parent_no_conv_list) > 1: 305 | raise ValueError( 306 | "Model is wrong at {} level. An index_id cannot be " 307 | "normalized twice".format(self) 308 | ) 309 | # if there is a parent with energy, idx must be normalized once 310 | if parent_with_energy_list and not normalized_parent_list: 311 | raise ValueError( 312 | "Model is wrong at {} level. An index_id should be " 313 | "normalized once before multiplication with a tensor that " 314 | "has energy.".format(self) 315 | ) 316 | # if idx is not normalized, it cannot be marginalized 317 | if ( 318 | not normalized_parent_no_conv_list 319 | and next(iter(parents_dict.values()))["is_marginalized"] 320 | ): 321 | raise ValueError( 322 | "Model is wrong at {} level. 


class Multiplexer(_Operator):
    """
    Class of the Multiplexer operator.

    It aims at concatenating the inner tensors of the parent nodes.
    Concatenation can be performed along a new axis (stacking), which is
    automatically detected by comparing the parents' index IDs with self's.
    It can also be performed along an existing axis when the
    `multiplexer_idx` attribute is provided (see `__init__` docstring).
    """

    # TODO: if no brother, could optimize tensor_update which is copied by its parent

    def __init__(self, multiplexer_idx=None, **kwargs):
        """
        Returns a Multiplexer object which concatenates or stacks the parents'
        tensors.

        Parameters
        ----------
        multiplexer_idx: hashable or None, optional, default None
            If None, the multiplexer will stack all parents' tensors along a
            new axis which is automatically detected by comparing self's
            `index_id` argument with its parents' ones. If not None, it should
            refer to the axis ID along which parents' tensors are concatenated.
        """
        self.multiplexer_idx = multiplexer_idx
        super().__init__(**kwargs)
        self.parent_slicing_dict = {}

    def _check_filiation_ok(self, child=None, parent=None, **kwargs):
        if parent is not None:
            if kwargs.get("slice_for_child", None) is not None:
                raise ValueError(
                    "`slice_for_child` argument cannot be specified when child is a {} object".format(
                        type(self)
                    )
                )
            if kwargs.get("strides_for_child", None) is not None:
                raise ValueError(
                    "`strides_for_child` argument cannot be specified when child is a {} object".format(
                        type(self)
                    )
                )
            if isinstance(parent, Proxy):
                raise ValueError(
                    "A Proxy object cannot be a parent of a Multiplexer object."
                )
        super()._check_filiation_ok(child=child, parent=parent, **kwargs)

    def _give_update(self, parent, out=None):
        update = self.tensor_update[self.parent_slicing_dict[parent]]
        if out is None:
            return update
        out[...] = update
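
    # A minimal sketch (added for illustration, not part of wonterfact's API):
    # the two concatenation modes implemented by _initialization below,
    # mimicked with plain numpy. Stacking inserts a new axis of length
    # n_parents; concatenation extends an existing axis.
    @staticmethod
    def _doc_example_stack_vs_concatenate():
        import numpy as np

        parent1, parent2 = np.zeros((4, 6)), np.ones((4, 6))
        # new axis (multiplexer_idx is None): shape becomes (2, 4, 6)
        assert np.stack((parent1, parent2), axis=0).shape == (2, 4, 6)
        # existing axis (multiplexer_idx given): shape becomes (8, 6)
        assert np.concatenate((parent1, parent2), axis=0).shape == (8, 6)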

    def _initialization(self):
        if (
            len(
                set(
                    parent.get_index_id_for_children(self)
                    for parent in self.list_of_parents
                )
            )
            > 1
        ):
            raise ValueError(
                "All parents of a Multiplexer object must have the same index_id"
            )

        multiplexer_idx_set = set(self.index_id) - set(
            self.list_of_parents[0].get_index_id_for_children(self)
        )
        # if concatenation is performed along a new axis
        if len(multiplexer_idx_set) == 1:
            multiplexer_idx = multiplexer_idx_set.pop()
            multiplexer_idx_number = self.index_id.index(multiplexer_idx)

            # tensor definition (stacking of parents' tensors along a new axis)
            self.tensor = glob.xp.concatenate(
                tuple(
                    parent.get_tensor_for_children(self)[
                        (slice(None),) * multiplexer_idx_number + (None,)
                    ]
                    for parent in self.list_of_parents
                ),
                axis=multiplexer_idx_number,
            )
            for num_parent, parent in enumerate(self.list_of_parents):
                self.parent_slicing_dict.update(
                    {parent: (slice(None),) * multiplexer_idx_number + (num_parent,)}
                )

        # if concatenation is performed along an existing axis
        elif not multiplexer_idx_set:
            if self.multiplexer_idx is None:
                raise ValueError(
                    "`multiplexer_idx` must be provided when concatenating "
                    "along an existing axis"
                )
            multiplexer_idx_number = self.index_id.index(self.multiplexer_idx)

            # tensor definition (concatenation of parents' tensors)
            self.tensor = glob.xp.concatenate(
                tuple(
                    parent.get_tensor_for_children(self)
                    for parent in self.list_of_parents
                ),
                axis=multiplexer_idx_number,
            )
            index_init = 0
            for num_parent, parent in enumerate(self.list_of_parents):
                index_end = (
                    index_init
                    + parent.get_tensor_for_children(self).shape[multiplexer_idx_number]
                )
                self.parent_slicing_dict.update(
                    {
                        parent: (slice(None),) * multiplexer_idx_number
                        + (slice(index_init, index_end),)
                    }
                )
                index_init = index_end
        else:
            raise ValueError(
                "index_id mismatch between the Multiplexer object and its parents"
            )

        if self.tensor_update is None and self.update_period != 0:
            self.tensor_update = glob.xp.ones_like(self.tensor)

        self._redefine_parent_tensor()

    def _redefine_parent_tensor(self):
        for parent in self.list_of_parents:
            parent.tensor = self.tensor[self.parent_slicing_dict[parent]].reshape(
                parent.tensor.shape
            )
            # must be recursive in case of consecutive multiplexers
            if isinstance(parent, Multiplexer):
                parent._redefine_parent_tensor()

    @cached_property
    def multiplexer_idx_number(self):
        if self.multiplexer_idx is not None:
            return self.index_id.index(self.multiplexer_idx)
        return 0

    def _check_model_validity(self):
        super()._check_model_validity()
        if not all(parent.tensor_has_energy for parent in self.list_of_parents) and any(
            parent.tensor_has_energy for parent in self.list_of_parents
        ):
            raise ValueError(
                "Model is wrong at {} level. Either all of the parents' "
                "tensors or none of them should have energy.".format(self)
            )
        if self.multiplexer_idx is not None and not self.tensor_has_energy:
            raise ValueError(
                "Problem at {} level. When `multiplexer_idx` is provided, i.e. "
                "when a Multiplexer concatenates its parents' tensors along an "
                "existing axis, all parents' tensors should have energy.".format(self)
            )
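
# A minimal sketch (added for illustration, not part of wonterfact's API):
# after concatenation, Multiplexer._redefine_parent_tensor rebinds each
# parent's tensor to a *view* of the multiplexer's tensor, so parent and child
# share memory, as mimicked here with plain numpy:
def _doc_example_shared_views():
    import numpy as np

    child = np.zeros((2, 3))
    parent_view = child[0]  # analogous to a parent_slicing_dict lookup
    parent_view[...] = 7.0
    assert child[0, 0] == 7.0  # writing through the view updates the child
    assert np.may_share_memory(parent_view, child)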
" 504 | "when a Multiplexer concatenates its parent's tensor along an " 505 | "existing axis, all parents' tensor should have energy." 506 | ) 507 | 508 | 509 | class Integrator(_Operator): 510 | """ 511 | Class for Integrator operation. 512 | 513 | Its goal is to compute integration (i.e. cumulative sum) of its single 514 | parent's tensor along the last axis. Actually, since theory force operators 515 | to respect some kind of energy conservation, the real operation here is a 516 | weighted cumulative sum where weights are defined as [1/T, 1/(T-1), ... 1/2, 517 | 1] and where T is the size of the last axis of parent's tensor. This 518 | weighted cumulative sum can also be performed backwards (see docstring of 519 | `__init__` method). 520 | """ 521 | 522 | max_parents = 1 523 | 524 | def __init__(self, **kwargs): 525 | """ 526 | Returns an Integrator object which computes a weighted cumulative sum 527 | along last axis of a parent's tensor. 528 | 529 | Parameters 530 | ---------- 531 | backward_integration: bool, optional, default False 532 | If True, integration is performed backwards. 533 | """ 534 | self.backward_integration = kwargs.pop("backward_integration", False) 535 | super().__init__(**kwargs) 536 | 537 | def _update_tensor(self, **kwargs): 538 | direction = -1 if self.backward_integration else 1 539 | parent_tensor = ( 540 | self.first_parent.get_tensor_for_children(self) * self._normalization_coef 541 | ) 542 | utils.cumsum_last_axis( 543 | parent_tensor[..., ::direction], self.tensor[..., ::direction] 544 | ) 545 | 546 | def _give_update(self, parent, out=None): 547 | direction = 1 if self.backward_integration else -1 548 | utils.cumsum_last_axis( 549 | self.tensor_update[..., ::direction], self.update_to_give[..., ::direction] 550 | ) 551 | self.update_to_give *= self._normalization_coef 552 | if out is None: 553 | return self.update_to_give 554 | out[...] = self.update_to_give 555 | 556 | def _initialization(self): 557 | self.tensor = glob.xp.empty_like( 558 | self.first_parent.get_tensor_for_children(self) 559 | ) 560 | self._update_tensor() 561 | if self.update_period != 0: 562 | self.full_tensor_update = glob.xp.zeros_like(self.tensor) 563 | self.tensor_update = self.full_tensor_update 564 | self.update_to_give = glob.xp.empty_like(self.tensor) 565 | self.integration_dim = self.tensor.shape[-1] 566 | 567 | @cached_property 568 | def _normalization_coef(self): 569 | parent_tensor = self.first_parent.get_tensor_for_children(self) 570 | if self.backward_integration: 571 | norm_coef = 1.0 / glob.xp.arange( 572 | 1, parent_tensor.shape[-1] + 1, dtype=glob.float 573 | ) 574 | else: 575 | norm_coef = 1.0 / glob.xp.arange( 576 | parent_tensor.shape[-1], 0, -1, dtype=glob.float 577 | ) 578 | return norm_coef 579 | 580 | def _check_model_validity(self): 581 | super()._check_model_validity() 582 | if not self.tensor_has_energy and self.tensor.ndim - 1 not in self.norm_axis: 583 | raise ValueError( 584 | "Model is wrong at {} level. Parent's tensor must have energy " 585 | "or its last axis must be normalized." 586 | ) 587 | 588 | 589 | class Adder(_Operator): 590 | """ 591 | Class for Adder operator. 592 | 593 | It aims at summing tensors of all parents into a single tensor. Allows 594 | self's tensor to be manipulated with reshapes and slices before actual 595 | summation (see `wonterfact.DynNodeData.new_parent` method's docstring). 
596 | """ 597 | 598 | def __init__(self, **kwargs): 599 | self.pre_slice_dict = {} 600 | self.shape_dict = {} 601 | self.post_slice_dict = {} 602 | super().__init__(**kwargs) 603 | 604 | def _update_tensor(self, **kwargs): 605 | self.tensor[...] = 0 606 | for parent in self.list_of_parents: 607 | tensor = self.apply_slice_and_shape(self.tensor, parent) 608 | if not glob.xp.may_share_memory(tensor, self.tensor): 609 | raise ValueError("something is wrong in the reslicing and reshaping") 610 | tensor += parent.get_tensor_for_children(self) 611 | 612 | def _give_update(self, parent, out=None): 613 | update = self.apply_slice_and_shape(self.tensor_update, parent) 614 | if out is None: 615 | return update 616 | out[...] = update 617 | 618 | def _initialization(self): 619 | if self.tensor is None: 620 | if self.parent_full_shape is None: 621 | raise ValueError( 622 | "Please manually instantiate a tensor for this Adder (cannot infer the proper shape)" 623 | ) 624 | self.tensor = glob.xp.zeros_like( 625 | self.parent_full_shape.get_tensor_for_children(self) 626 | ) 627 | self._update_tensor() 628 | if self.update_period != 0: 629 | self.tensor_update = glob.xp.zeros_like(self.tensor) 630 | 631 | def _parse_kwargs_for_filiation(self, parent, **kwargs): 632 | pre_slice_for_adder = kwargs.pop("pre_slice_for_adder", Ellipsis) 633 | shape_for_adder = kwargs.pop("shape_for_adder", None) 634 | post_slice_for_adder = kwargs.pop("post_slice_for_adder", Ellipsis) 635 | try: 636 | pre_slice_for_adder = tuple(pre_slice_for_adder) 637 | except TypeError: 638 | pass 639 | try: 640 | post_slice_for_adder = tuple(post_slice_for_adder) 641 | except TypeError: 642 | pass 643 | self.pre_slice_dict[parent] = pre_slice_for_adder 644 | self.shape_dict[parent] = shape_for_adder 645 | self.post_slice_dict[parent] = post_slice_for_adder 646 | return kwargs 647 | 648 | def apply_slice_and_shape(self, tensor, parent): 649 | tensor1 = tensor[self.pre_slice_dict[parent]] 650 | if self.shape_dict[parent]: 651 | tensor1 = tensor1.reshape(self.shape_dict[parent]) 652 | if self.post_slice_dict[parent]: 653 | tensor1 = tensor1[self.post_slice_dict[parent]] 654 | return tensor1 655 | 656 | @cached_property 657 | def parent_full_shape(self): 658 | parent_full_shape = next( 659 | ( 660 | parent 661 | for parent in self.list_of_parents 662 | if self.pre_slice_dict[parent] == Ellipsis 663 | and self.post_slice_dict[parent] == Ellipsis 664 | and self.shape_dict[parent] is None 665 | ), 666 | None, 667 | ) 668 | return parent_full_shape 669 | 670 | def _check_model_validity(self): 671 | super()._check_model_validity() 672 | if not all(parent.tensor_has_energy for parent in self.list_of_parents): 673 | raise ValueError( 674 | "Model is wrong at {} level. Parents' tensors of an Adder " 675 | "should all have energy." 


# class RealMultiplier(DynNodeData):
#     def __init__(self, **kwargs):
#         self.conv_idx_ids = kwargs.pop('conv_idx_ids', [])
#         self.sign_id = kwargs.pop('sign_id', None)
#         self.parent_sign_id_dict = {}
#         super(RealMultiplier, self).__init__(**kwargs)

#     def parent_sign_axis(self, parent):
#         return parent.index_id.index(self.parent_sign_id_dict[parent])

#     def _new_parent(self, parent, **kwargs):
#         parent_sign_id = kwargs.pop('parent_sign_id', None)
#         if parent_sign_id is None:
#             raise ValueError(
#                 "Please provide 'parent_sign_id' during call of create_filiation"
#             )
#         self.parent_sign_id_dict[parent] = parent_sign_id

#     def get_parent_slice(self, parent, sign_num):
#         return [slice(None), ] * self.parent_sign_axis(parent) + [sign_num, ]

#     def _update_tensor(self, **kwargs):
#         pass
--------------------------------------------------------------------------------