├── tests ├── __init__.py ├── data │ ├── toy_data_1.txt │ ├── toy_data_2.txt │ ├── iris.data │ └── BinaryIrisData.txt ├── test_harden_layer.py ├── utils.py ├── test_hard_count.py ├── test_symbolic_generation.py ├── test_harden.py ├── test_hard_masks.py ├── test_hard_majority.py ├── test_toy_problem.py ├── test_mnist.py ├── test_noisy_xor.py └── test_symbolic_primitives.py ├── neurallogic ├── __init__.py ├── symbolic_representation.py ├── hard_vmap.py ├── neural_logic_net.py ├── harden_layer.py ├── hard_concatenate.py ├── map_at_elements.py ├── harden.py ├── initialization.py ├── hard_majority.py ├── hard_xor.py ├── symbolic_operator.py ├── hard_not.py ├── real_encoder.py ├── hard_or.py ├── hard_count.py ├── hard_dropout.py ├── hard_and.py ├── hard_masks.py ├── symbolic_generation.py └── symbolic_primitives.py ├── docs ├── db.pdf ├── db-net.png ├── logic-gates.png ├── majority-gates.png ├── margin-trick.png ├── iclr2021_conference.bst ├── mnist-architecture.png ├── binary-iris-architecture.png ├── noisy-xor-architecture.png ├── toy-example-architecture.png ├── ICML_workshop │ ├── icml2023 │ │ ├── example_paper.pdf │ │ ├── icml_numpapers.pdf │ │ ├── example_paper.bib │ │ ├── algorithm.sty │ │ └── algorithmic.sty │ └── icml2023-diffxyz │ │ └── algorithm.sty ├── iclr2021_conference.sty └── math_commands.tex ├── .gitattributes ├── .devcontainer ├── requirements.txt ├── Dockerfile └── devcontainer.json ├── setup.py ├── .github └── workflows │ └── python.yaml ├── README.md └── scratchpad.ipynb /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /neurallogic/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/db.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/db.pdf -------------------------------------------------------------------------------- /docs/db-net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/db-net.png -------------------------------------------------------------------------------- /docs/logic-gates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/logic-gates.png -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Notebooks are binary 2 | prototype/notebooks/* binary 3 | demos/* binary 4 | -------------------------------------------------------------------------------- /docs/majority-gates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/majority-gates.png -------------------------------------------------------------------------------- /docs/margin-trick.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/margin-trick.png -------------------------------------------------------------------------------- /docs/iclr2021_conference.bst: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/iclr2021_conference.bst -------------------------------------------------------------------------------- /docs/mnist-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/mnist-architecture.png -------------------------------------------------------------------------------- /docs/binary-iris-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/binary-iris-architecture.png -------------------------------------------------------------------------------- /docs/noisy-xor-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/noisy-xor-architecture.png -------------------------------------------------------------------------------- /docs/toy-example-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/toy-example-architecture.png -------------------------------------------------------------------------------- /docs/ICML_workshop/icml2023/example_paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/ICML_workshop/icml2023/example_paper.pdf -------------------------------------------------------------------------------- /docs/ICML_workshop/icml2023/icml_numpapers.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Z80coder/db-nets/HEAD/docs/ICML_workshop/icml2023/icml_numpapers.pdf -------------------------------------------------------------------------------- /tests/data/toy_data_1.txt: -------------------------------------------------------------------------------- 1 | 0 0 2 | 0 0 3 | 0 0 4 | 0 0 5 | 0 0 6 | 0 0 7 | 0 0 8 | 0 0 9 | 0 0 10 | 0 0 11 | 0 0 12 | 1 1 13 | 1 1 14 | 1 1 15 | 1 1 16 | 1 1 17 | 1 1 18 | 1 1 19 | 1 1 20 | 1 1 21 | -------------------------------------------------------------------------------- /.devcontainer/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | pytest 4 | jupyter 5 | plum-dispatch 6 | 7 | tensorflow 8 | tensorflow_datasets 9 | ml_collections 10 | 11 | --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 12 | jaxlib 13 | jax[cuda] 14 | flax 15 | clu -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | setup( 3 | name='neurallogic', 4 | packages=find_packages(include=['neurallogic']), 5 | version='0.1.0', 6 | description='A Neural Logic Library', 7 | author='@z80coder', 8 | install_requires=[], 9 | setup_requires=['pytest-runner'], 10 | test_suite='tests', 11 | ) -------------------------------------------------------------------------------- /neurallogic/symbolic_representation.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from plum import dispatch 3 | 4 | # TODO: need a more general solution to unquoting symbolic strings 5 | 6 | @dispatch 7 | def symbolic_representation(x: numpy.ndarray): 8 
| return repr(x).replace('array', 'numpy.array').replace('\n', '').replace('float32', 'numpy.float32').replace('\'', '') 9 | 10 | 11 | @dispatch 12 | def symbolic_representation(x: str): 13 | return x.replace('\'', '') 14 | -------------------------------------------------------------------------------- /neurallogic/hard_vmap.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy 3 | 4 | from neurallogic import neural_logic_net 5 | 6 | 7 | def soft_vmap(f): 8 | return jax.vmap(f) 9 | 10 | 11 | def hard_vmap(f): 12 | return soft_vmap(f) 13 | 14 | 15 | def symbolic_vmap(f): 16 | return numpy.vectorize(f, otypes=[object]) 17 | 18 | vmap = neural_logic_net.select( 19 | lambda f: soft_vmap(f[0]), 20 | lambda f: hard_vmap(f[1]), 21 | lambda f: symbolic_vmap(f[2]) 22 | ) 23 | 24 | # TODO: add tests 25 | -------------------------------------------------------------------------------- /tests/test_harden_layer.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from neurallogic import harden, harden_layer 4 | 5 | 6 | def test_harden_layer(): 7 | test_data = [ 8 | [[0.8, 0.1], [1.0, 0.0]], 9 | [[1.0, 0.52], [1.0, 1.0]], 10 | [[0.3, 0.51], [0.0, 1.0]], 11 | [[0.49, 0.32], [0.0, 0.0]] 12 | ] 13 | for input, expected in test_data: 14 | input = jnp.array(input) 15 | expected = jnp.array(expected) 16 | assert jnp.array_equal(harden_layer.soft_harden_layer(input), expected) 17 | assert jnp.array_equal(harden_layer.hard_harden_layer( 18 | harden.harden(input)), harden.harden(expected)) 19 | symbolic_output = harden_layer.symbolic_harden_layer( 20 | harden.harden(input.tolist())) 21 | assert jnp.array_equal(symbolic_output, harden.harden(expected)) 22 | # TODO: test symbolic harden layer 23 | -------------------------------------------------------------------------------- /neurallogic/neural_logic_net.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from flax import linen as nn 3 | 4 | NetType = Enum("NetType", ["Soft", "Hard", "Symbolic"]) 5 | 6 | 7 | def select(soft, hard, symbolic): 8 | def selector(type: NetType): 9 | return {NetType.Soft: soft, NetType.Hard: hard, NetType.Symbolic: symbolic}[ 10 | type 11 | ] 12 | 13 | return selector 14 | 15 | 16 | def net(f): 17 | class SoftNet(nn.Module): 18 | @nn.compact 19 | def __call__(self, x, **kwargs): 20 | return f(NetType.Soft, x, **kwargs) 21 | 22 | class HardNet(nn.Module): 23 | @nn.compact 24 | def __call__(self, x, **kwargs): 25 | return f(NetType.Hard, x, **kwargs) 26 | 27 | class SymbolicNet(nn.Module): 28 | @nn.compact 29 | def __call__(self, x, **kwargs): 30 | return f(NetType.Symbolic, x, **kwargs) 31 | 32 | return SoftNet(), HardNet(), SymbolicNet() 33 | -------------------------------------------------------------------------------- /tests/data/toy_data_2.txt: -------------------------------------------------------------------------------- 1 | 1 0 0 0 0 1 2 | 0 1 0 0 0 1 3 | 0 0 1 0 0 0 4 | 0 0 0 1 0 0 5 | 1 0 0 0 0 1 6 | 0 1 0 0 0 1 7 | 0 0 1 0 0 0 8 | 0 0 0 1 0 0 9 | 1 0 0 0 1 1 10 | 0 1 0 0 1 1 11 | 0 0 1 0 1 0 12 | 0 0 0 1 1 1 13 | 1 0 0 0 1 1 14 | 0 1 0 0 1 1 15 | 0 0 1 0 1 0 16 | 0 0 0 1 1 1 17 | 0 1 0 0 0 1 18 | 0 0 0 1 1 1 19 | 1 0 0 0 1 1 20 | 0 0 1 0 1 0 21 | 1 0 0 0 0 1 22 | 0 1 0 0 0 1 23 | 0 0 1 0 0 0 24 | 0 0 0 1 0 0 25 | 1 0 0 0 0 1 26 | 0 1 0 0 0 1 27 | 0 0 1 0 0 0 28 | 0 0 0 1 0 0 29 | 1 0 0 0 1 1 30 | 0 1 0 0 1 1 31 | 0 0 1 0 1 0 32 | 0 
0 0 1 1 1 33 | 1 0 0 0 1 1 34 | 0 1 0 0 1 1 35 | 0 0 1 0 1 0 36 | 0 0 0 1 1 1 37 | 0 1 0 0 0 1 38 | 0 0 0 1 1 1 39 | 1 0 0 0 1 1 40 | 0 0 1 0 1 0 41 | 1 0 0 0 0 1 42 | 0 1 0 0 0 1 43 | 0 0 1 0 0 0 44 | 0 0 0 1 0 0 45 | 1 0 0 0 0 1 46 | 0 1 0 0 0 1 47 | 0 0 1 0 0 0 48 | 0 0 0 1 0 0 49 | 1 0 0 0 1 1 50 | 0 1 0 0 1 1 51 | 0 0 1 0 1 0 52 | 0 0 0 1 1 1 53 | 1 0 0 0 1 1 54 | 0 1 0 0 1 1 55 | 0 0 1 0 1 0 56 | 0 0 0 1 1 1 57 | 0 1 0 0 0 1 58 | 0 0 0 1 1 1 59 | 1 0 0 0 1 1 60 | 0 0 1 0 1 0 -------------------------------------------------------------------------------- /neurallogic/harden_layer.py: -------------------------------------------------------------------------------- 1 | import jax 2 | 3 | from neurallogic import neural_logic_net 4 | 5 | 6 | def logistic_clip(x): 7 | return jax.scipy.special.expit(3 * (2 * x - 1)) 8 | 9 | 10 | def harden(x): 11 | # non-differentiable 12 | return jax.lax.cond(x > 0.5, lambda _: 1.0, lambda _: 0.0, None) 13 | 14 | 15 | def straight_through_harden(x): 16 | # The harden operation is non-differentiable. Therefore we need to 17 | # approximate with the straight-through estimator. 18 | 19 | # Create an exactly-zero expression with Sterbenz lemma that has 20 | # an exactly-one gradient. 21 | zero = x - jax.lax.stop_gradient(x) 22 | grad_of_one = zero + jax.lax.stop_gradient(harden(x)) 23 | return grad_of_one 24 | 25 | 26 | soft_harden_layer = jax.vmap(straight_through_harden) 27 | 28 | 29 | def hard_harden_layer(x): 30 | return x 31 | 32 | 33 | # TODO: can we harden arbitrary tensors? 34 | # TODO: is this correct? 35 | def symbolic_harden_layer(x): 36 | return x 37 | 38 | 39 | harden_layer = neural_logic_net.select( 40 | soft_harden_layer, hard_harden_layer, symbolic_harden_layer 41 | ) 42 | -------------------------------------------------------------------------------- /.github/workflows/python.yaml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | - "neurallogic/**" 7 | - "tests/**" 8 | - ".github/workflows/**" 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: ["3.8", "3.9", "3.10"] 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install flake8 pytest 28 | pip install -r .devcontainer/requirements.txt 29 | - name: Lint with flake8 30 | run: | 31 | # stop the build if there are Python syntax errors or undefined names 32 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 33 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 34 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 35 | - name: Test with pytest 36 | run: | 37 | pytest -------------------------------------------------------------------------------- /neurallogic/hard_concatenate.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from flax import linen as nn 3 | 4 | from neurallogic import neural_logic_net, symbolic_generation 5 | 6 | 7 | def soft_concatenate(x, axis): 8 | return jax.numpy.concatenate(x, axis) 9 | 10 | 11 | def hard_concatenate(x, axis): 12 | return soft_concatenate(x, axis) 13 | 14 | 15 | class SoftConcatenate(nn.Module): 16 | axis: int 17 | @nn.compact 18 | def __call__(self, x): 19 | return soft_concatenate(x, self.axis) 20 | 21 | 22 | class HardConcatenate(nn.Module): 23 | axis: int 24 | @nn.compact 25 | def __call__(self, x): 26 | return hard_concatenate(x, self.axis) 27 | 28 | 29 | class SymbolicConcatenate: 30 | def __init__(self, axis): 31 | self.hard_concatenate = HardConcatenate(axis) 32 | 33 | def __call__(self, x): 34 | jaxpr = symbolic_generation.make_symbolic_flax_jaxpr( 35 | self.hard_concatenate, x 36 | ) 37 | return symbolic_generation.symbolic_expression(jaxpr, x) 38 | 39 | 40 | concatenate = neural_logic_net.select( 41 | lambda x, axis: SoftConcatenate(axis)(x), 42 | lambda x, axis: HardConcatenate(axis)(x), 43 | lambda x, axis: SymbolicConcatenate(axis)(x), 44 | ) 45 | 46 | # TODO: add tests -------------------------------------------------------------------------------- /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | # See here for image contents: https://hub.docker.com/r/jupyter/datascience-notebook/ 2 | 3 | FROM jupyter/datascience-notebook 4 | 5 | # We want to run common-debian.sh from here: 6 | # https://github.com/microsoft/vscode-dev-containers/tree/main/script-library#development-container-scripts 7 | # But that script assumes that the main non-root user (in this case jovyan) 8 | # is in a group with the same name (in this case jovyan). So we must first make that so. 9 | COPY library-scripts/common-debian.sh /tmp/library-scripts/ 10 | USER root 11 | RUN apt-get update \ 12 | && groupadd jovyan \ 13 | && usermod -g jovyan -a -G users jovyan \ 14 | && bash /tmp/library-scripts/common-debian.sh \ 15 | && apt-get clean -y && rm -rf /var/lib/apt/lists/* /tmp/library-scripts 16 | 17 | # [Optional] If your pip requirements rarely change, uncomment this section to add them to the image. 18 | COPY requirements.txt /tmp/pip-tmp/ 19 | RUN pip3 --disable-pip-version-check --no-cache-dir install -r /tmp/pip-tmp/requirements.txt \ 20 | && rm -rf /tmp/pip-tmp 21 | 22 | # [Optional] Uncomment this section to install additional OS packages. 
23 | RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ 24 | && apt-get -y install --no-install-recommends nvidia-cuda-toolkit 25 | 26 | USER jovyan -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Jupyter Data Science Notebooks (Community)", 3 | "build": { 4 | "dockerfile": "Dockerfile" 5 | }, 6 | "overrideCommand": false, 7 | 8 | // Forward Jupyter port locally, mark required 9 | "forwardPorts": [8888], 10 | "portsAttributes": { 11 | "8888": { 12 | "label": "Jupyter", 13 | "requireLocalPort": true, 14 | "onAutoForward": "ignore" 15 | } 16 | }, 17 | 18 | // Configure tool-specific properties. 19 | "customizations": { 20 | // Configure properties specific to VS Code. 21 | "vscode": { 22 | // Set *default* container specific settings.json values on container create. 23 | "settings": { 24 | "python.defaultInterpreterPath": "/opt/conda/bin/python" 25 | }, 26 | 27 | // Add the IDs of extensions you want installed when the container is created. 28 | "extensions": [ 29 | "ms-python.python" 30 | ] 31 | } 32 | }, 33 | 34 | // Use 'postCreateCommand' to run commands after the container is created. 35 | // "postCreateCommand": "pip3 install --user -r requirements.txt", 36 | 37 | // Set `remoteUser` to `root` to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. 38 | "remoteUser": "jovyan", 39 | 40 | // Install NVIDIA cuda and deep neural network support 41 | "features": { 42 | "ghcr.io/devcontainers/features/nvidia-cuda:1": { 43 | "installCudnn": true, 44 | "installNvtx": true 45 | } 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /neurallogic/map_at_elements.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import jax 4 | import numpy 5 | from plum import dispatch 6 | 7 | 8 | @dispatch 9 | def map_at_elements(x: str, func: typing.Callable): 10 | return func(x) 11 | 12 | 13 | @dispatch 14 | def map_at_elements(x: bool, func: typing.Callable): 15 | return func(x) 16 | 17 | 18 | @dispatch 19 | def map_at_elements(x: numpy.bool_, func: typing.Callable): 20 | return func(x) 21 | 22 | 23 | @dispatch 24 | def map_at_elements(x: float, func: typing.Callable): 25 | return func(x) 26 | 27 | 28 | @dispatch 29 | def map_at_elements(x: numpy.float32, func: typing.Callable): 30 | return func(x) 31 | 32 | 33 | @dispatch 34 | def map_at_elements(x: numpy.int32, func: typing.Callable): 35 | return func(x) 36 | 37 | 38 | @dispatch 39 | def map_at_elements(x: list, func: typing.Callable): 40 | return [map_at_elements(item, func) for item in x] 41 | 42 | 43 | @dispatch 44 | def map_at_elements(x: numpy.ndarray, func: typing.Callable): 45 | return numpy.array([map_at_elements(item, func) for item in x], dtype=object) 46 | 47 | 48 | @dispatch 49 | def map_at_elements(x: jax.numpy.ndarray, func: typing.Callable): 50 | if x.ndim == 0: 51 | return func(x.item()) 52 | return jax.numpy.array([map_at_elements(item, func) for item in x]) 53 | 54 | 55 | @dispatch 56 | def map_at_elements(x: dict, func: typing.Callable): 57 | return {k: map_at_elements(v, func) for k, v in x.items()} 58 | 59 | 60 | @dispatch 61 | def map_at_elements(x: tuple, func: typing.Callable): 62 | return tuple(map_at_elements(list(x), func)) 63 | -------------------------------------------------------------------------------- 
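A minimal usage sketch (not a repository file; the values are illustrative): map_at_elements applies a function to every leaf of a nested structure while preserving its shape, and harden.py (the next file) builds on it to threshold nested soft-bit parameters at 0.5.

from neurallogic import harden, map_at_elements

params = {"bit_weights": [0.2, 0.9, 0.51]}
map_at_elements.map_at_elements(params, str)  # {'bit_weights': ['0.2', '0.9', '0.51']}
harden.harden(params)                         # {'bit_weights': [False, True, True]}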
/neurallogic/harden.py: -------------------------------------------------------------------------------- 1 | import flax 2 | import jax 3 | import numpy 4 | from plum import dispatch 5 | 6 | from neurallogic import map_at_elements 7 | 8 | 9 | def harden_float(x: float) -> bool: 10 | return x > 0.5 11 | 12 | 13 | harden_array = jax.vmap(harden_float, 0, 0) 14 | 15 | 16 | @dispatch 17 | def harden(x: float): 18 | if numpy.isnan(x): 19 | return x 20 | return harden_float(x) 21 | 22 | 23 | @dispatch 24 | def harden(x: bool): 25 | return x 26 | 27 | 28 | @dispatch 29 | def harden(x: list): 30 | return map_at_elements.map_at_elements(x, harden_float) 31 | 32 | 33 | @dispatch 34 | def harden(x: numpy.ndarray): 35 | if x.ndim == 0: 36 | return harden(x.item()) 37 | return harden_array(x) 38 | 39 | 40 | @dispatch 41 | def harden(x: jax.numpy.ndarray): 42 | if x.ndim == 0: 43 | return harden(x.item()) 44 | return harden_array(x) 45 | 46 | 47 | @dispatch 48 | def harden(x: dict): 49 | # Only harden parameters that explicitly represent bits 50 | def conditional_harden(k, v): 51 | if k.startswith("bit_"): 52 | return map_at_elements.map_at_elements(v, harden) 53 | elif isinstance(v, dict) or isinstance(v, flax.core.FrozenDict) or isinstance(v, list): 54 | return harden(v) 55 | return v 56 | 57 | return {k: conditional_harden(k, v) for k, v in x.items()} 58 | 59 | 60 | @dispatch 61 | def harden(x: flax.core.FrozenDict): 62 | return harden(x.unfreeze()) 63 | 64 | 65 | @dispatch 66 | def map_keys_nested(f, d: dict) -> dict: 67 | return { 68 | f(k): map_keys_nested(f, v) if isinstance(v, dict) else v for k, v in d.items() 69 | } 70 | 71 | 72 | def hard_weights(weights): 73 | return flax.core.FrozenDict( 74 | map_keys_nested( 75 | lambda str: str.replace("Soft", "Hard"), harden(weights.unfreeze()) 76 | ) 77 | ) 78 | -------------------------------------------------------------------------------- /neurallogic/initialization.py: -------------------------------------------------------------------------------- 1 | import jax 2 | 3 | 4 | def initialize_uniform_range(lower=0.0, upper=1.0): 5 | def init(key, shape, dtype): 6 | dtype = jax.dtypes.canonicalize_dtype(dtype) 7 | x = jax.random.uniform(key, shape, dtype, lower, upper) 8 | return x 9 | 10 | return init 11 | 12 | 13 | def initialize_near_to_zero(mean=-1, std=0.5): 14 | def init(key, shape, dtype): 15 | dtype = jax.dtypes.canonicalize_dtype(dtype) 16 | # Sample from standard normal distribution (zero mean, unit variance) 17 | x = jax.random.normal(key, shape, dtype) 18 | # Transform to a normal distribution with mean -1 and standard deviation 0.5 19 | x = std * x + mean 20 | x = jax.numpy.clip(x, 0.001, 0.999) 21 | return x 22 | 23 | return init 24 | 25 | 26 | def initialize_near_to_one(): 27 | def init(key, shape, dtype): 28 | dtype = jax.dtypes.canonicalize_dtype(dtype) 29 | # Sample from standard normal distribution (zero mean, unit variance) 30 | x = jax.random.normal(key, shape, dtype) 31 | # Transform to a normal distribution with mean 1 and standard deviation 0.5 32 | x = 0.5 * x + 1 33 | x = jax.numpy.clip(x, 0.001, 0.999) 34 | return x 35 | 36 | return init 37 | 38 | 39 | # TODO: get rid of symmetry 40 | def initialize_bernoulli(p=0.5, low=0.001, high=0.999): 41 | def init(key, shape, dtype): 42 | x = jax.random.bernoulli(key, p, shape) 43 | x = jax.numpy.where(x, high, low) 44 | x = jax.numpy.asarray(x, dtype) 45 | return x 46 | 47 | return init 48 | 49 | def initialize_bernoulli_uniform(p=0.5, low=0.001, high=0.999): 50 | def init(key, shape, 
dtype): 51 | x = jax.random.bernoulli(key, p, shape) 52 | h = jax.random.uniform(key, shape, dtype, 0.5, high) 53 | l = jax.random.uniform(key, shape, dtype, low, 0.5) 54 | x = jax.numpy.where(x, h, l) 55 | x = jax.numpy.asarray(x, dtype) 56 | return x 57 | 58 | return init 59 | -------------------------------------------------------------------------------- /docs/ICML_workshop/icml2023/example_paper.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{langley00, 2 | author = {P. Langley}, 3 | title = {Crafting Papers on Machine Learning}, 4 | year = {2000}, 5 | pages = {1207--1216}, 6 | editor = {Pat Langley}, 7 | booktitle = {Proceedings of the 17th International Conference 8 | on Machine Learning (ICML 2000)}, 9 | address = {Stanford, CA}, 10 | publisher = {Morgan Kaufmann} 11 | } 12 | 13 | @TechReport{mitchell80, 14 | author = "T. M. Mitchell", 15 | title = "The Need for Biases in Learning Generalizations", 16 | institution = "Computer Science Department, Rutgers University", 17 | year = "1980", 18 | address = "New Brunswick, MA", 19 | } 20 | 21 | @phdthesis{kearns89, 22 | author = {M. J. Kearns}, 23 | title = {Computational Complexity of Machine Learning}, 24 | school = {Department of Computer Science, Harvard University}, 25 | year = {1989} 26 | } 27 | 28 | @Book{MachineLearningI, 29 | editor = "R. S. Michalski and J. G. Carbonell and T. 30 | M. Mitchell", 31 | title = "Machine Learning: An Artificial Intelligence 32 | Approach, Vol. I", 33 | publisher = "Tioga", 34 | year = "1983", 35 | address = "Palo Alto, CA" 36 | } 37 | 38 | @Book{DudaHart2nd, 39 | author = "R. O. Duda and P. E. Hart and D. G. Stork", 40 | title = "Pattern Classification", 41 | publisher = "John Wiley and Sons", 42 | edition = "2nd", 43 | year = "2000" 44 | } 45 | 46 | @misc{anonymous, 47 | title= {Suppressed for Anonymity}, 48 | author= {Author, N. N.}, 49 | year= {2021} 50 | } 51 | 52 | @InCollection{Newell81, 53 | author = "A. Newell and P. S. Rosenbloom", 54 | title = "Mechanisms of Skill Acquisition and the Law of 55 | Practice", 56 | booktitle = "Cognitive Skills and Their Acquisition", 57 | pages = "1--51", 58 | publisher = "Lawrence Erlbaum Associates, Inc.", 59 | year = "1981", 60 | editor = "J. R. Anderson", 61 | chapter = "1", 62 | address = "Hillsdale, NJ" 63 | } 64 | 65 | 66 | @Article{Samuel59, 67 | author = "A. L. 
Samuel", 68 | title = "Some Studies in Machine Learning Using the Game of 69 | Checkers", 70 | journal = "IBM Journal of Research and Development", 71 | year = "1959", 72 | volume = "3", 73 | number = "3", 74 | pages = "211--229" 75 | } 76 | -------------------------------------------------------------------------------- /neurallogic/hard_majority.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from flax import linen as nn 3 | 4 | from neurallogic import neural_logic_net, symbolic_generation 5 | 6 | 7 | def majority_index(input_size: int) -> int: 8 | return (input_size - 1) // 2 9 | 10 | # TODO: properly factor with/without margin versions 11 | 12 | def majority_bit(x: jax.numpy.array) -> float: 13 | index = majority_index(x.shape[-1]) 14 | sorted_x = jax.numpy.sort(x, axis=-1) 15 | return jax.numpy.take(sorted_x, index, axis=-1) 16 | 17 | 18 | def soft_majority(x: jax.numpy.array) -> float: 19 | m_bit = majority_bit(x) 20 | margin = jax.numpy.abs(m_bit - 0.5) 21 | mean = jax.numpy.mean(x, axis=-1) 22 | margin_delta = mean * margin 23 | representative_bit = jax.numpy.where( 24 | m_bit > 0.5, 25 | 0.5 + margin_delta, 26 | m_bit + margin_delta, 27 | ) 28 | return representative_bit 29 | 30 | 31 | def hard_majority(x: jax.numpy.array) -> bool: 32 | threshold = x.shape[-1] - majority_index(x.shape[-1]) 33 | return jax.numpy.sum(x, axis=-1) >= threshold 34 | 35 | 36 | soft_majority_layer = jax.vmap(soft_majority, in_axes=0) 37 | 38 | hard_majority_layer = jax.vmap(hard_majority, in_axes=0) 39 | 40 | 41 | class SoftMajorityLayer(nn.Module): 42 | """ 43 | A soft-bit MAJORITY layer than transforms its inputs along the last dimension. 44 | 45 | Attributes: 46 | layer_size: The number of neurons in the layer. 47 | weights_init: The initializer function for the weight matrix. 48 | """ 49 | 50 | @nn.compact 51 | def __call__(self, x): 52 | return soft_majority_layer(x) 53 | 54 | 55 | class HardMajorityLayer(nn.Module): 56 | @nn.compact 57 | def __call__(self, x): 58 | return hard_majority_layer(x) 59 | 60 | 61 | class SymbolicMajorityLayer: 62 | def __init__(self): 63 | self.hard_majority_layer = HardMajorityLayer() 64 | 65 | def __call__(self, x): 66 | jaxpr = symbolic_generation.make_symbolic_flax_jaxpr( 67 | self.hard_majority_layer, x 68 | ) 69 | return symbolic_generation.symbolic_expression(jaxpr, x) 70 | 71 | 72 | majority_layer = neural_logic_net.select( 73 | lambda: SoftMajorityLayer(), 74 | lambda: HardMajorityLayer(), 75 | lambda: SymbolicMajorityLayer(), 76 | ) 77 | 78 | # TODO: construct a majority-k generalisation of the above 79 | # where k is the number of high-soft bits required for a majority 80 | # and where k is a soft-bit parameter. Requires constructing 81 | # a piecewise-continuous function (as per notebook). 
82 | -------------------------------------------------------------------------------- /docs/ICML_workshop/icml2023/algorithm.sty: -------------------------------------------------------------------------------- 1 | % ALGORITHM STYLE -- Released 8 April 1996 2 | % for LaTeX-2e 3 | % Copyright -- 1994 Peter Williams 4 | % E-mail Peter.Williams@dsto.defence.gov.au 5 | \NeedsTeXFormat{LaTeX2e} 6 | \ProvidesPackage{algorithm} 7 | \typeout{Document Style `algorithm' - floating environment} 8 | 9 | \RequirePackage{float} 10 | \RequirePackage{ifthen} 11 | \newcommand{\ALG@within}{nothing} 12 | \newboolean{ALG@within} 13 | \setboolean{ALG@within}{false} 14 | \newcommand{\ALG@floatstyle}{ruled} 15 | \newcommand{\ALG@name}{Algorithm} 16 | \newcommand{\listalgorithmname}{List of \ALG@name s} 17 | 18 | % Declare Options 19 | % first appearance 20 | \DeclareOption{plain}{ 21 | \renewcommand{\ALG@floatstyle}{plain} 22 | } 23 | \DeclareOption{ruled}{ 24 | \renewcommand{\ALG@floatstyle}{ruled} 25 | } 26 | \DeclareOption{boxed}{ 27 | \renewcommand{\ALG@floatstyle}{boxed} 28 | } 29 | % then numbering convention 30 | \DeclareOption{part}{ 31 | \renewcommand{\ALG@within}{part} 32 | \setboolean{ALG@within}{true} 33 | } 34 | \DeclareOption{chapter}{ 35 | \renewcommand{\ALG@within}{chapter} 36 | \setboolean{ALG@within}{true} 37 | } 38 | \DeclareOption{section}{ 39 | \renewcommand{\ALG@within}{section} 40 | \setboolean{ALG@within}{true} 41 | } 42 | \DeclareOption{subsection}{ 43 | \renewcommand{\ALG@within}{subsection} 44 | \setboolean{ALG@within}{true} 45 | } 46 | \DeclareOption{subsubsection}{ 47 | \renewcommand{\ALG@within}{subsubsection} 48 | \setboolean{ALG@within}{true} 49 | } 50 | \DeclareOption{nothing}{ 51 | \renewcommand{\ALG@within}{nothing} 52 | \setboolean{ALG@within}{true} 53 | } 54 | \DeclareOption*{\edef\ALG@name{\CurrentOption}} 55 | 56 | % ALGORITHM 57 | % 58 | \ProcessOptions 59 | \floatstyle{\ALG@floatstyle} 60 | \ifthenelse{\boolean{ALG@within}}{ 61 | \ifthenelse{\equal{\ALG@within}{part}} 62 | {\newfloat{algorithm}{htbp}{loa}[part]}{} 63 | \ifthenelse{\equal{\ALG@within}{chapter}} 64 | {\newfloat{algorithm}{htbp}{loa}[chapter]}{} 65 | \ifthenelse{\equal{\ALG@within}{section}} 66 | {\newfloat{algorithm}{htbp}{loa}[section]}{} 67 | \ifthenelse{\equal{\ALG@within}{subsection}} 68 | {\newfloat{algorithm}{htbp}{loa}[subsection]}{} 69 | \ifthenelse{\equal{\ALG@within}{subsubsection}} 70 | {\newfloat{algorithm}{htbp}{loa}[subsubsection]}{} 71 | \ifthenelse{\equal{\ALG@within}{nothing}} 72 | {\newfloat{algorithm}{htbp}{loa}}{} 73 | }{ 74 | \newfloat{algorithm}{htbp}{loa} 75 | } 76 | \floatname{algorithm}{\ALG@name} 77 | 78 | \newcommand{\listofalgorithms}{\listof{algorithm}{\listalgorithmname}} 79 | 80 | -------------------------------------------------------------------------------- /docs/ICML_workshop/icml2023-diffxyz/algorithm.sty: -------------------------------------------------------------------------------- 1 | % ALGORITHM STYLE -- Released 8 April 1996 2 | % for LaTeX-2e 3 | % Copyright -- 1994 Peter Williams 4 | % E-mail Peter.Williams@dsto.defence.gov.au 5 | \NeedsTeXFormat{LaTeX2e} 6 | \ProvidesPackage{algorithm} 7 | \typeout{Document Style `algorithm' - floating environment} 8 | 9 | \RequirePackage{float} 10 | \RequirePackage{ifthen} 11 | \newcommand{\ALG@within}{nothing} 12 | \newboolean{ALG@within} 13 | \setboolean{ALG@within}{false} 14 | \newcommand{\ALG@floatstyle}{ruled} 15 | \newcommand{\ALG@name}{Algorithm} 16 | \newcommand{\listalgorithmname}{List of \ALG@name s} 17 | 18 | % 
Declare Options 19 | % first appearance 20 | \DeclareOption{plain}{ 21 | \renewcommand{\ALG@floatstyle}{plain} 22 | } 23 | \DeclareOption{ruled}{ 24 | \renewcommand{\ALG@floatstyle}{ruled} 25 | } 26 | \DeclareOption{boxed}{ 27 | \renewcommand{\ALG@floatstyle}{boxed} 28 | } 29 | % then numbering convention 30 | \DeclareOption{part}{ 31 | \renewcommand{\ALG@within}{part} 32 | \setboolean{ALG@within}{true} 33 | } 34 | \DeclareOption{chapter}{ 35 | \renewcommand{\ALG@within}{chapter} 36 | \setboolean{ALG@within}{true} 37 | } 38 | \DeclareOption{section}{ 39 | \renewcommand{\ALG@within}{section} 40 | \setboolean{ALG@within}{true} 41 | } 42 | \DeclareOption{subsection}{ 43 | \renewcommand{\ALG@within}{subsection} 44 | \setboolean{ALG@within}{true} 45 | } 46 | \DeclareOption{subsubsection}{ 47 | \renewcommand{\ALG@within}{subsubsection} 48 | \setboolean{ALG@within}{true} 49 | } 50 | \DeclareOption{nothing}{ 51 | \renewcommand{\ALG@within}{nothing} 52 | \setboolean{ALG@within}{true} 53 | } 54 | \DeclareOption*{\edef\ALG@name{\CurrentOption}} 55 | 56 | % ALGORITHM 57 | % 58 | \ProcessOptions 59 | \floatstyle{\ALG@floatstyle} 60 | \ifthenelse{\boolean{ALG@within}}{ 61 | \ifthenelse{\equal{\ALG@within}{part}} 62 | {\newfloat{algorithm}{htbp}{loa}[part]}{} 63 | \ifthenelse{\equal{\ALG@within}{chapter}} 64 | {\newfloat{algorithm}{htbp}{loa}[chapter]}{} 65 | \ifthenelse{\equal{\ALG@within}{section}} 66 | {\newfloat{algorithm}{htbp}{loa}[section]}{} 67 | \ifthenelse{\equal{\ALG@within}{subsection}} 68 | {\newfloat{algorithm}{htbp}{loa}[subsection]}{} 69 | \ifthenelse{\equal{\ALG@within}{subsubsection}} 70 | {\newfloat{algorithm}{htbp}{loa}[subsubsection]}{} 71 | \ifthenelse{\equal{\ALG@within}{nothing}} 72 | {\newfloat{algorithm}{htbp}{loa}}{} 73 | }{ 74 | \newfloat{algorithm}{htbp}{loa} 75 | } 76 | \floatname{algorithm}{\ALG@name} 77 | 78 | \newcommand{\listofalgorithms}{\listof{algorithm}{\listalgorithmname}} 79 | 80 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import flax 4 | import jax 5 | import numpy 6 | from plum import dispatch 7 | 8 | from neurallogic import harden, symbolic_generation, map_at_elements 9 | 10 | 11 | def to_string(x): 12 | return str(x) 13 | 14 | 15 | @dispatch 16 | def make_symbolic(x: dict): 17 | return map_at_elements.map_at_elements( 18 | x, to_string 19 | ) 20 | 21 | 22 | @dispatch 23 | def make_symbolic(x: list): 24 | return map_at_elements.map_at_elements( 25 | x, to_string 26 | ) 27 | 28 | 29 | @dispatch 30 | def make_symbolic(x: numpy.ndarray): 31 | return map_at_elements.map_at_elements( 32 | x, to_string 33 | ) 34 | 35 | 36 | @dispatch 37 | def make_symbolic(x: jax.numpy.ndarray): 38 | return map_at_elements.map_at_elements( 39 | convert_jax_to_numpy_arrays(x), to_string 40 | ) 41 | 42 | 43 | @dispatch 44 | def make_symbolic(x: bool): 45 | return to_string(x) 46 | 47 | 48 | @dispatch 49 | def make_symbolic(x: str): 50 | return to_string(x) 51 | 52 | 53 | @dispatch 54 | def convert_jax_to_numpy_arrays(x: jax.numpy.ndarray): 55 | return numpy.asarray(x) 56 | 57 | 58 | @dispatch 59 | def convert_jax_to_numpy_arrays(x: dict): 60 | return {k: convert_jax_to_numpy_arrays(v) for k, v in x.items()} 61 | 62 | 63 | @dispatch 64 | def make_symbolic(x: flax.core.FrozenDict): 65 | x = convert_jax_to_numpy_arrays(x.unfreeze()) 66 | return flax.core.FrozenDict(make_symbolic(x)) 67 | 68 | 69 | @dispatch 70 | def 
make_symbolic(*args): 71 | return tuple([make_symbolic(arg) for arg in args]) 72 | 73 | 74 | def check_consistency(soft: Callable, hard: Callable, expected, *args): 75 | print(f'\nchecking consistency for {soft.__name__}') 76 | # Check that the soft function performs as expected 77 | soft_output = soft(*args) 78 | print(f'Expected: {expected}, Actual soft_output: {repr(soft_output)}') 79 | assert numpy.allclose(soft_output, expected, equal_nan=True) 80 | 81 | # Check that the hard function performs as expected 82 | hard_args = tuple([harden.harden(arg) for arg in args]) 83 | hard_expected = harden.harden(expected) 84 | hard_output = hard(*hard_args) 85 | print(f'Expected: {hard_expected}, Actual hard_output: {repr(hard_output)}') 86 | assert numpy.allclose(hard_output, hard_expected, equal_nan=True) 87 | 88 | # Check that the jaxpr performs as expected 89 | symbolic_f = symbolic_generation.make_symbolic_jaxpr(hard, *hard_args) 90 | symbolic_output = symbolic_generation.eval_symbolic(symbolic_f, *hard_args) 91 | print(f'Expected: {hard_expected}, Actual symbolic_output: {repr(symbolic_output)}') 92 | assert numpy.allclose(symbolic_output, hard_expected, equal_nan=True) 93 | -------------------------------------------------------------------------------- /neurallogic/hard_xor.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | from flax import linen as nn 5 | 6 | from neurallogic import neural_logic_net, symbolic_generation, hard_masks 7 | 8 | 9 | def differentiable_xor(x, y): 10 | return jax.numpy.minimum(jax.numpy.maximum(x, y), 1.0 - jax.numpy.minimum(x, y)) 11 | 12 | 13 | # TODO: seperate out the mask from the xor operation 14 | def soft_xor_neuron(w, x): 15 | # Conditionally include input bits, according to weights 16 | x = jax.vmap(hard_masks.soft_mask_to_false, 0, 0)(w, x) 17 | x = jax.lax.reduce(x, jax.numpy.array(0, dtype=x.dtype), differentiable_xor, (0,)) 18 | return x 19 | 20 | 21 | def hard_xor_neuron(w, x): 22 | x = jax.vmap(hard_masks.hard_mask_to_false, 0, 0)(w, x) 23 | return jax.lax.reduce(x, False, jax.lax.bitwise_xor, [0]) 24 | 25 | 26 | soft_xor_layer = jax.vmap(soft_xor_neuron, (0, None), 0) 27 | 28 | 29 | hard_xor_layer = jax.vmap(hard_xor_neuron, (0, None), 0) 30 | 31 | 32 | class SoftXorLayer(nn.Module): 33 | layer_size: int 34 | weights_init: Callable = ( 35 | nn.initializers.uniform(1.0) 36 | ) 37 | dtype: jax.numpy.dtype = jax.numpy.float32 38 | 39 | @nn.compact 40 | def __call__(self, x): 41 | weights_shape = (self.layer_size, jax.numpy.shape(x)[-1]) 42 | weights = self.param( 43 | "bit_weights", self.weights_init, weights_shape, self.dtype 44 | ) 45 | x = jax.numpy.asarray(x, self.dtype) 46 | return soft_xor_layer(weights, x) 47 | 48 | 49 | class HardXorLayer(nn.Module): 50 | layer_size: int 51 | 52 | @nn.compact 53 | def __call__(self, x): 54 | weights_shape = (self.layer_size, jax.numpy.shape(x)[-1]) 55 | weights = self.param( 56 | "bit_weights", nn.initializers.constant(True), weights_shape 57 | ) 58 | return hard_xor_layer(weights, x) 59 | 60 | 61 | class SymbolicXorLayer: 62 | def __init__(self, layer_size): 63 | self.layer_size = layer_size 64 | self.hard_xor_layer = HardXorLayer(self.layer_size) 65 | 66 | def __call__(self, x): 67 | jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_xor_layer, x) 68 | return symbolic_generation.symbolic_expression(jaxpr, x) 69 | 70 | 71 | xor_layer = neural_logic_net.select( 72 | lambda layer_size, 
weights_init=nn.initializers.uniform( 73 | 1.0 74 | ), dtype=jax.numpy.float32: SoftXorLayer(layer_size, weights_init, dtype), 75 | lambda layer_size, weights_init=nn.initializers.constant( 76 | True 77 | ), dtype=jax.numpy.float32: HardXorLayer(layer_size), 78 | lambda layer_size, weights_init=nn.initializers.constant( 79 | True 80 | ), dtype=jax.numpy.float32: SymbolicXorLayer(layer_size), 81 | ) 82 | -------------------------------------------------------------------------------- /tests/test_hard_count.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import jax 3 | 4 | from neurallogic import hard_count 5 | 6 | def test_soft_count(): 7 | # 2 bits are high in a 7-bit input array, x 8 | x = numpy.array([1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]) 9 | y = hard_count.soft_count(x) 10 | # We expect a 8-bit output array, y, where y[5] is the only high soft-bit (indicating that 5 soft-bits are low in the input) 11 | expected_output = numpy.array([0.25, 0.25, 0.25, 0.25, 0.25, 1., 0.25, 0.25]) 12 | print("soft_count", y) 13 | assert numpy.allclose(y, expected_output) 14 | 15 | # Same example as above, except instead of 0s and 1s, we have soft-bits 16 | x = numpy.array([0.9, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1]) 17 | y = hard_count.soft_count(x) 18 | # We expect an 8-bit output array, y, where y[5] is the only high soft-bit (indicating that 5 soft-bits are low in the input) 19 | expected_output = numpy.array([0.32000002, 0.3, 0.3, 0.3, 0.3, 0.85999995, 0.3, 0.32000002]) 20 | print("soft_count", y) 21 | assert numpy.allclose(y, expected_output) 22 | 23 | x = numpy.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) 24 | y = hard_count.soft_count(x) 25 | # We expect an 8-bit output array, y, where no y[0] is high (indicating that 0 soft-bits are low in the input) 26 | expected_output = numpy.array([1., 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]) 27 | print("soft_count", y) 28 | assert numpy.allclose(y, expected_output) 29 | 30 | # Same example as above, except instead of 0s and 1s, we have soft-bits 31 | x = numpy.array([0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]) 32 | y = hard_count.soft_count(x) 33 | # We expect an 8-bit output array, y, where y[0] is high (indicating that 0 soft-bits are low in the input) 34 | expected_output = numpy.array([0.88, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.32000002]) 35 | print("soft_count", y) 36 | assert numpy.allclose(y, expected_output) 37 | 38 | x = numpy.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) 39 | y = hard_count.soft_count(x) 40 | # We expect an 8-bit output array, y, where y[7] is the only high soft-bit (indicating that 7 soft-bits are low in the input) 41 | expected_output = numpy.array([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 1.]) 42 | print("soft_count", y) 43 | assert numpy.allclose(y, expected_output) 44 | 45 | # Same example as above, except instead of 0s and 1s, we have soft-bits 46 | x = numpy.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) 47 | y = hard_count.soft_count(x) 48 | # We expect an 7-bit output array, y, where y[7] is the only high soft-bit (indicating that 7 soft-bits are low in the input) 49 | expected_output = numpy.array([0.32000002, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.88]) 50 | print("soft_count", y) 51 | assert numpy.allclose(y, expected_output) 52 | 53 | # TODO: test soft_count == hard_count -------------------------------------------------------------------------------- /neurallogic/symbolic_operator.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy 3 
| from plum import dispatch 4 | 5 | 6 | @dispatch 7 | def symbolic_operator(operator: str, x: str) -> str: 8 | return f'{operator}({x})'.replace('\'', '') 9 | 10 | 11 | @dispatch 12 | def symbolic_operator(operator: str, x: str, y: str): 13 | return f'{operator}({x}, {y})'.replace('\'', '') 14 | 15 | 16 | @dispatch 17 | def symbolic_operator(operator: str, x: float, y: str): 18 | return symbolic_operator(operator, str(x), y) 19 | 20 | 21 | @dispatch 22 | def symbolic_operator(operator: str, x: int, y: str): 23 | return symbolic_operator(operator, str(x), y) 24 | 25 | 26 | @dispatch 27 | def symbolic_operator(operator: str, x: str, y: float): 28 | return symbolic_operator(operator, x, str(y)) 29 | 30 | 31 | @dispatch 32 | def symbolic_operator(operator: str, x: float, y: numpy.ndarray): 33 | return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y) 34 | 35 | 36 | @dispatch 37 | def symbolic_operator(operator: str, x: numpy.ndarray, y: numpy.ndarray): 38 | return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y) 39 | 40 | 41 | @dispatch 42 | def symbolic_operator(operator: str, x: numpy.ndarray, y: float): 43 | return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y) 44 | 45 | 46 | @dispatch 47 | def symbolic_operator(operator: str, x: list, y: float): 48 | return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y) 49 | 50 | 51 | @dispatch 52 | def symbolic_operator(operator: str, x: list, y: list): 53 | return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y) 54 | 55 | 56 | @dispatch 57 | def symbolic_operator(operator: str, x: bool, y: str): 58 | return symbolic_operator(operator, str(x), y) 59 | 60 | 61 | @dispatch 62 | def symbolic_operator(operator: str, x: str, y: numpy.ndarray): 63 | return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y) 64 | 65 | 66 | @dispatch 67 | def symbolic_operator(operator: str, x: str, y: int): 68 | return symbolic_operator(operator, x, str(y)) 69 | 70 | 71 | @dispatch 72 | def symbolic_operator(operator: str, x: tuple): 73 | return symbolic_operator(operator, str(x)) 74 | 75 | 76 | @dispatch 77 | def symbolic_operator(operator: str, x: list, y: numpy.ndarray): 78 | return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y) 79 | 80 | 81 | @dispatch 82 | def symbolic_operator(operator: str, x: numpy.ndarray, y: jax.numpy.ndarray): 83 | return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x, y) 84 | 85 | 86 | @dispatch 87 | def symbolic_operator(operator: str, x: numpy.ndarray): 88 | return numpy.vectorize(symbolic_operator, otypes=[object])(operator, x) 89 | 90 | 91 | @dispatch 92 | def symbolic_operator(operator: str, x: list): 93 | return symbolic_operator(operator, numpy.array(x)) 94 | -------------------------------------------------------------------------------- /tests/test_symbolic_generation.py: -------------------------------------------------------------------------------- 1 | import flax 2 | import jax 3 | import jax.numpy as jnp 4 | import numpy 5 | 6 | from neurallogic import (hard_and, hard_majority, hard_not, hard_or, hard_xor, 7 | harden, harden_layer, neural_logic_net, real_encoder, 8 | symbolic_generation, hard_concatenate, hard_vmap, symbolic_primitives) 9 | from tests import utils 10 | 11 | 12 | def nln(type, x, width): 13 | # Can't symbolically support this layer yet since the symbolic output is an unevaluated string that 14 | # lacks the correct tensor structure 15 | # x = 
real_encoder.real_encoder_layer(type)(2)(x) 16 | # x = x.ravel() 17 | y = hard_vmap.vmap(type)((lambda x: 1 - x, lambda x: 1 - x, lambda x: symbolic_primitives.symbolic_not(x)))(x) 18 | x = hard_concatenate.concatenate(type)([x, y], 0) 19 | x = hard_or.or_layer(type)(width)(x) 20 | x = hard_and.and_layer(type)(width)(x) 21 | x = hard_xor.xor_layer(type)(width)(x) 22 | x = hard_not.not_layer(type)(2)(x) 23 | x = hard_majority.majority_layer(type)()(x) 24 | x = harden_layer.harden_layer(type)(x) 25 | x = x.reshape([2, 1]) 26 | x = x.sum(-1) 27 | return x 28 | 29 | 30 | def test_symbolic_generation(): 31 | # Define width of network 32 | width = 2 33 | # Define the neural logic net 34 | soft, hard, symbolic = neural_logic_net.net( 35 | lambda type, x: nln(type, x, width)) 36 | # Initialize a random number generator 37 | rng = jax.random.PRNGKey(0) 38 | #rng, init_rng = jax.random.split(rng) 39 | mock_input = harden.harden(jnp.ones([2 * 2])) 40 | # Initialize the weights of the neural logic net 41 | soft_weights = soft.init(rng, mock_input) 42 | hard_weights = harden.hard_weights(soft_weights) 43 | # Apply the neural logic net to the hard input 44 | hard_output = hard.apply(hard_weights, mock_input) 45 | 46 | # Check the standard evaluation of the network equals the non-standard evaluation 47 | symbolic_weights = harden.hard_weights(soft_weights) 48 | symbolic_output = symbolic.apply(symbolic_weights, mock_input) 49 | assert numpy.array_equal(symbolic_output, hard_output) 50 | 51 | # Check the standard evaluation of the network equals the non-standard symbolic evaluation 52 | symbolic_mock_input = utils.make_symbolic(mock_input) 53 | symbolic_output = symbolic.apply(symbolic_weights, symbolic_mock_input) 54 | assert numpy.array_equal(hard_output.shape, symbolic_output.shape) 55 | 56 | # Compute the symbolic expression, i.e. perform the actual operations in the symbolic expression 57 | eval_symbolic_output = symbolic_generation.eval_symbolic_expression(symbolic_output) 58 | # If this assertion succeeds then the non-standard symbolic evaluation of the jaxpr is is identical to the standard evaluation of network 59 | assert numpy.array_equal(hard_output, eval_symbolic_output) 60 | -------------------------------------------------------------------------------- /neurallogic/hard_not.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | from flax import linen as nn 5 | 6 | from neurallogic import neural_logic_net, symbolic_generation, hard_and, hard_or, initialization 7 | 8 | 9 | def soft_not(w, x): 10 | """ 11 | w > 0.5 implies the not operation is inactive, else active 12 | 13 | Assumes x is in [0, 1] 14 | 15 | Corresponding hard logic: ! (x XOR w) 16 | """ 17 | w = jax.numpy.clip(w, 0.0, 1.0) 18 | return 1.0 - w + x * (2.0 * w - 1.0) 19 | 20 | # TODO: split out function of parameter, and not operation, in order to simplify 21 | def soft_not_deprecated(w: float, x: float) -> float: 22 | w = jax.numpy.clip(w, 0.0, 1.0) 23 | # (w && x) || (! w && ! 
x) 24 | return hard_or.soft_or(hard_and.soft_and(w, x), hard_and.soft_and(1.0 - w, 1.0 - x)) 25 | 26 | 27 | def hard_not(w: bool, x: bool): 28 | return jax.numpy.logical_not(jax.numpy.logical_xor(x, w)) 29 | 30 | 31 | soft_not_neuron = jax.vmap(soft_not, 0, 0) 32 | 33 | hard_not_neuron = jax.vmap(hard_not, 0, 0) 34 | 35 | 36 | soft_not_layer = jax.vmap(soft_not_neuron, (0, None), 0) 37 | 38 | hard_not_layer = jax.vmap(hard_not_neuron, (0, None), 0) 39 | 40 | 41 | class SoftNotLayer(nn.Module): 42 | layer_size: int 43 | weights_init: Callable = initialization.initialize_uniform_range(0.49, 0.51) 44 | dtype: jax.numpy.dtype = jax.numpy.float32 45 | 46 | @nn.compact 47 | def __call__(self, x): 48 | weights_shape = (self.layer_size, jax.numpy.shape(x)[-1]) 49 | weights = self.param( 50 | "bit_weights", self.weights_init, weights_shape, self.dtype 51 | ) 52 | x = jax.numpy.asarray(x, self.dtype) 53 | return soft_not_layer(weights, x) 54 | 55 | 56 | class HardNotLayer(nn.Module): 57 | layer_size: int 58 | weights_init: Callable = nn.initializers.constant(True) 59 | 60 | @nn.compact 61 | def __call__(self, x): 62 | weights_shape = (self.layer_size, jax.numpy.shape(x)[-1]) 63 | weights = self.param("bit_weights", self.weights_init, weights_shape) 64 | return hard_not_layer(weights, x) 65 | 66 | 67 | class SymbolicNotLayer: 68 | def __init__(self, layer_size): 69 | self.layer_size = layer_size 70 | self.hard_not_layer = HardNotLayer(self.layer_size) 71 | 72 | def __call__(self, x): 73 | jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_not_layer, x) 74 | return symbolic_generation.symbolic_expression(jaxpr, x) 75 | 76 | 77 | not_layer = neural_logic_net.select( 78 | lambda layer_size, weights_init=initialization.initialize_uniform_range(0.49, 0.51), dtype=jax.numpy.float32: SoftNotLayer(layer_size, weights_init, dtype), 79 | lambda layer_size, weights_init=initialization.initialize_uniform_range(0.49, 0.51), dtype=jax.numpy.float32: HardNotLayer(layer_size), 80 | lambda layer_size, weights_init=initialization.initialize_uniform_range(0.49, 0.51), dtype=jax.numpy.float32: SymbolicNotLayer(layer_size), 81 | ) 82 | -------------------------------------------------------------------------------- /neurallogic/real_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | from flax import linen as nn 5 | 6 | from neurallogic import neural_logic_net, symbolic_generation 7 | 8 | # TODO: implement a soft_real_decoder that can perhaps replace the port count approach 9 | 10 | 11 | def soft_real_encoder(t: float, x: float): 12 | eps = 0.0000001 13 | # x should be in [0, 1] 14 | t = jax.numpy.clip(t, 0, 1) 15 | return jax.numpy.where( 16 | jax.numpy.isclose(t, x), 17 | 0.5, 18 | # t != x 19 | jax.numpy.where( 20 | x < t, 21 | (x / (2 * t + eps)), 22 | # x > t 23 | (x + 1 - 2 * t) / (2 * (1 - t) + eps) 24 | ) 25 | ) 26 | 27 | 28 | def hard_real_encoder(t, x): 29 | # t and x must be floats 30 | return jax.numpy.where(soft_real_encoder(t, x) > 0.5, True, False) 31 | 32 | 33 | soft_real_encoder_neuron = jax.vmap(soft_real_encoder, in_axes=(0, None)) 34 | 35 | hard_real_encoder_neuron = jax.vmap(hard_real_encoder, in_axes=(0, None)) 36 | 37 | soft_real_encoder_layer = jax.vmap(soft_real_encoder_neuron, (0, 0), 0) 38 | 39 | hard_real_encoder_layer = jax.vmap(hard_real_encoder_neuron, (0, 0), 0) 40 | 41 | 42 | class SoftRealEncoderLayer(nn.Module): 43 | bits_per_real: int 44 | thresholds_init: Callable = 
nn.initializers.uniform(1.0) 45 | dtype: jax.numpy.dtype = jax.numpy.float32 46 | 47 | @nn.compact 48 | def __call__(self, x): 49 | thresholds_shape = (jax.numpy.shape(x)[-1], self.bits_per_real) 50 | thresholds = self.param( 51 | "thresholds", self.thresholds_init, thresholds_shape, self.dtype) 52 | x = jax.numpy.asarray(x, self.dtype) 53 | return soft_real_encoder_layer(thresholds, x) 54 | 55 | 56 | class HardRealEncoderLayer(nn.Module): 57 | bits_per_real: int 58 | 59 | @nn.compact 60 | def __call__(self, x): 61 | thresholds_shape = (jax.numpy.shape(x)[-1], self.bits_per_real) 62 | thresholds = self.param( 63 | "thresholds", nn.initializers.constant(0.0), thresholds_shape) 64 | return hard_real_encoder_layer(thresholds, x) 65 | 66 | 67 | class SymbolicRealEncoderLayer: 68 | def __init__(self, bits_per_real): 69 | self.bits_per_real = bits_per_real 70 | self.hard_real_encoder_layer = HardRealEncoderLayer(self.bits_per_real) 71 | 72 | def __call__(self, x): 73 | jaxpr = symbolic_generation.make_symbolic_flax_jaxpr( 74 | self.hard_real_encoder_layer, x 75 | ) 76 | return symbolic_generation.symbolic_expression(jaxpr, x) 77 | 78 | 79 | real_encoder_layer = neural_logic_net.select( 80 | lambda bits_per_real, weights_init=nn.initializers.uniform( 81 | 1.0 82 | ), dtype=jax.numpy.float32: SoftRealEncoderLayer( 83 | bits_per_real, weights_init, dtype 84 | ), 85 | lambda bits_per_real, weights_init=nn.initializers.uniform( 86 | 1.0 87 | ), dtype=jax.numpy.float32: HardRealEncoderLayer(bits_per_real), 88 | lambda bits_per_real, weights_init=nn.initializers.uniform( 89 | 1.0 90 | ), dtype=jax.numpy.float32: SymbolicRealEncoderLayer(bits_per_real), 91 | ) 92 | -------------------------------------------------------------------------------- /neurallogic/hard_or.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | from flax import linen as nn 5 | 6 | from neurallogic import ( 7 | neural_logic_net, 8 | symbolic_generation, 9 | hard_masks, 10 | initialization, 11 | ) 12 | 13 | 14 | def soft_or(x, y): 15 | m = jax.numpy.maximum(x, y) 16 | return jax.numpy.where( 17 | 2 * m > 1, 18 | 0.5 + 0.5 * (x + y) * (m - 0.5), 19 | m + 0.5 * (x + y) * (0.5 - m), 20 | ) 21 | 22 | def soft_or_vec(x): 23 | m = jax.numpy.max(x) 24 | mean = jax.numpy.mean(x) 25 | delta = jax.numpy.abs(mean - 0.5) 26 | return jax.numpy.where( 27 | 2 * m > 1, 28 | 0.5 + delta, 29 | m + delta 30 | ) 31 | 32 | 33 | # TODO: seperate out the or operation from the mask operation 34 | def soft_or_neuron(w, x): 35 | x = jax.vmap(hard_masks.soft_mask_to_false_margin, 0, 0)(w, x) 36 | return jax.numpy.max(x) 37 | 38 | # TODO: doesn't seem to work as well 39 | def soft_or_neuron_deprecated(w, x): 40 | x = jax.vmap(hard_masks.soft_mask_to_false_margin, 0, 0)(w, x) 41 | return soft_or_vec(x) 42 | 43 | 44 | def hard_or_neuron(w, x): 45 | x = jax.vmap(hard_masks.hard_mask_to_false, 0, 0)(w, x) 46 | return jax.lax.reduce(x, False, jax.lax.bitwise_or, [0]) 47 | 48 | 49 | soft_or_layer = jax.vmap(soft_or_neuron, (0, None), 0) 50 | 51 | hard_or_layer = jax.vmap(hard_or_neuron, (0, None), 0) 52 | 53 | 54 | class SoftOrLayer(nn.Module): 55 | layer_size: int 56 | weights_init: Callable = initialization.initialize_near_to_one() 57 | dtype: jax.numpy.dtype = jax.numpy.float32 58 | 59 | @nn.compact 60 | def __call__(self, x): 61 | weights_shape = (self.layer_size, jax.numpy.shape(x)[-1]) 62 | weights = self.param( 63 | "bit_weights", self.weights_init, weights_shape, 
self.dtype 64 | ) 65 | x = jax.numpy.asarray(x, self.dtype) 66 | return soft_or_layer(weights, x) 67 | 68 | 69 | class HardOrLayer(nn.Module): 70 | layer_size: int 71 | 72 | @nn.compact 73 | def __call__(self, x): 74 | weights_shape = (self.layer_size, jax.numpy.shape(x)[-1]) 75 | weights = self.param( 76 | "bit_weights", nn.initializers.constant(True), weights_shape 77 | ) 78 | return hard_or_layer(weights, x) 79 | 80 | 81 | class SymbolicOrLayer: 82 | def __init__(self, layer_size): 83 | self.layer_size = layer_size 84 | self.hard_or_layer = HardOrLayer(self.layer_size) 85 | 86 | def __call__(self, x): 87 | jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_or_layer, x) 88 | return symbolic_generation.symbolic_expression(jaxpr, x) 89 | 90 | 91 | or_layer = neural_logic_net.select( 92 | lambda layer_size, weights_init=initialization.initialize_near_to_one(), dtype=jax.numpy.float32: SoftOrLayer( 93 | layer_size, weights_init, dtype 94 | ), 95 | lambda layer_size, weights_init=nn.initializers.constant( 96 | True 97 | ), dtype=jax.numpy.float32: HardOrLayer(layer_size), 98 | lambda layer_size, weights_init=nn.initializers.constant( 99 | True 100 | ), dtype=jax.numpy.float32: SymbolicOrLayer(layer_size), 101 | ) 102 | -------------------------------------------------------------------------------- /neurallogic/hard_count.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from flax import linen as nn 3 | 4 | from neurallogic import neural_logic_net, symbolic_generation, hard_and 5 | 6 | 7 | def low_to_high(x, y): 8 | return hard_and.soft_and(1 - x, y) 9 | 10 | def soft_count(x: jax.numpy.array): 11 | """ 12 | Returns an array of soft-bits, of length |x|+1, and where only 1 soft-bit is high. 13 | The index of the high soft-bit indicates the total quantity of low and high bits in the input array. 14 | i.e. if index i is high, then there are i low bits 15 | 16 | E.g. if x = [0.1, 0.9, 0.2, 0.6, 0.4], then the output is y=[low, low, low, high, low, low] 17 | y[3] is high, which indicates that 18 | - 3 bits are low 19 | - 2 bits are high 20 | 21 | E.g. if x = [0.0, 0.2, 0.3, 0.1, 0.4], then the output is y=[low, low, low, low, low, high] 22 | y[5] is high, which indicates that 23 | - 5 bits are low 24 | - 0 bits are high 25 | 26 | E.g. 
if x = [0.9, 0.8, 0.7, 0.6, 0.5], then the output is y=[high, low, low, low, low, low] 27 | y[0] is high, which indicates that 28 | - 0 bits are low 29 | - 5 bits are high 30 | """ 31 | sorted_x = jax.numpy.sort(x, axis=-1) 32 | low = jax.numpy.array([0.0]) 33 | high = jax.numpy.array([1.0]) 34 | sorted_x = jax.numpy.concatenate([low, sorted_x, high]) 35 | return jax.vmap(low_to_high)(sorted_x[:-1], sorted_x[1:]) 36 | 37 | def augmented_bit(mean, representative_bit) -> float: 38 | margin = jax.numpy.abs(representative_bit - 0.5) 39 | margin_delta = mean * margin 40 | representative_bit = jax.numpy.where( 41 | representative_bit > 0.5, 42 | 0.5 + margin_delta, 43 | representative_bit + margin_delta, 44 | ) 45 | return representative_bit 46 | 47 | # TODO: investigate 48 | def soft_count_packed(x: jax.numpy.array): 49 | mean = jax.numpy.mean(x, axis=-1) 50 | sorted_x = jax.numpy.sort(x, axis=-1) 51 | low = jax.numpy.array([0.0]) 52 | high = jax.numpy.array([1.0]) 53 | sorted_x = jax.numpy.concatenate([low, sorted_x, high]) 54 | sorted_x = jax.vmap(low_to_high)(sorted_x[:-1], sorted_x[1:]) 55 | return jax.vmap(lambda x: augmented_bit(mean, x))(sorted_x) 56 | 57 | def hard_count(x: jax.numpy.array): 58 | # We simply count the number of low bits 59 | num_low_bits = jax.numpy.sum(x <= 0.5, axis=-1) 60 | return jax.nn.one_hot(num_low_bits, num_classes=x.shape[-1] + 1) 61 | 62 | 63 | soft_count_layer = jax.vmap(soft_count, in_axes=0) 64 | 65 | hard_count_layer = jax.vmap(hard_count, in_axes=0) 66 | 67 | 68 | class SoftCountLayer(nn.Module): 69 | @nn.compact 70 | def __call__(self, x): 71 | return soft_count_layer(x) 72 | 73 | 74 | class HardCountLayer(nn.Module): 75 | @nn.compact 76 | def __call__(self, x): 77 | return hard_count_layer(x) 78 | 79 | 80 | class SymbolicCountLayer: 81 | def __init__(self): 82 | self.hard_count_layer = HardCountLayer() 83 | 84 | def __call__(self, x): 85 | jaxpr = symbolic_generation.make_symbolic_flax_jaxpr( 86 | self.hard_count_layer, x 87 | ) 88 | return symbolic_generation.symbolic_expression(jaxpr, x) 89 | 90 | 91 | count_layer = neural_logic_net.select( 92 | lambda: SoftCountLayer(), 93 | lambda: HardCountLayer(), 94 | lambda: SymbolicCountLayer(), 95 | ) 96 | 97 | -------------------------------------------------------------------------------- /neurallogic/hard_dropout.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Callable 2 | 3 | import jax 4 | from flax import linen as nn 5 | from jax import lax, random 6 | 7 | from neurallogic import neural_logic_net 8 | 9 | 10 | class SoftHardDropout(nn.Module): 11 | """Create a dropout layer suitable for dropping soft-bit values. 12 | Adapted from flax/stochastic.py 13 | 14 | 15 | Note: When using :meth:`Module.apply() `, make sure 16 | to include an RNG seed named `'dropout'`. For example:: 17 | 18 | model.apply({'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng})` 19 | 20 | Attributes: 21 | rate: the dropout probability. (_not_ the keep rate!) 22 | broadcast_dims: dimensions that will share the same dropout mask 23 | deterministic: if false the inputs are scaled by `1 / (1 - rate)` and 24 | masked, whereas if true, no mask is applied and the inputs are returned 25 | as is. 26 | rng_collection: the rng collection name to use when requesting an rng key. 
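    Editor's note (based on the implementation below, not the original flax docs):
    masked positions are replaced by `dropout_function(inputs)` (a constant 0.0 by
    default), and the surviving values are *not* rescaled by `1 / (1 - rate)`.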
27 | """ 28 | 29 | rate: float 30 | broadcast_dims: Sequence[int] = () 31 | deterministic: Optional[bool] = None 32 | rng_collection: str = "dropout" 33 | dtype: jax.numpy.dtype = jax.numpy.float32 34 | dropout_function: Callable = lambda x: jax.numpy.full_like(x, 0.0) 35 | 36 | @nn.compact 37 | def __call__(self, inputs, deterministic: Optional[bool] = None): 38 | """Applies a random dropout mask to the input. 39 | 40 | Args: 41 | inputs: the inputs that should be randomly masked. 42 | Masking means setting the input bits to 0.5. 43 | deterministic: if false the inputs are masked, 44 | whereas if true, no mask is applied and the inputs are returned 45 | as is. 46 | 47 | Returns: 48 | The masked inputs 49 | """ 50 | deterministic = nn.merge_param( 51 | "deterministic", self.deterministic, deterministic 52 | ) 53 | 54 | if (self.rate == 0.0) or deterministic: 55 | return inputs 56 | 57 | # Prevent gradient NaNs in 1.0 edge-case. 58 | if self.rate == 1.0: 59 | return jax.numpy.zeros_like(inputs) 60 | 61 | keep_prob = 1.0 - self.rate 62 | rng = self.make_rng(self.rng_collection) 63 | broadcast_shape = list(inputs.shape) 64 | for dim in self.broadcast_dims: 65 | broadcast_shape[dim] = 1 66 | mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) 67 | mask = jax.numpy.broadcast_to(mask, inputs.shape) 68 | """ 69 | masked_values = jax.numpy.full_like( 70 | inputs, self.dropout_value, dtype=self.dtype 71 | ) 72 | """ 73 | masked_values = jax.vmap(self.dropout_function)(inputs) 74 | return lax.select(mask, inputs, masked_values) 75 | 76 | 77 | class HardHardDropout(nn.Module): 78 | @nn.compact 79 | def __call__(self, inputs, deterministic: Optional[bool] = None): 80 | return inputs 81 | 82 | 83 | class SymbolicHardDropout(nn.Module): 84 | @nn.compact 85 | def __call__(self, inputs, deterministic: Optional[bool] = None): 86 | return inputs 87 | 88 | 89 | hard_dropout = neural_logic_net.select( 90 | lambda **kwargs: SoftHardDropout(**kwargs), 91 | lambda **kwargs: HardHardDropout(**kwargs), 92 | lambda **kwargs: SymbolicHardDropout(**kwargs), 93 | ) 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ∂B nets 2 | 3 | [![Python package](https://github.com/Z80coder/discrete-differentiable-networks/actions/workflows/python.yaml/badge.svg)](https://github.com/Z80coder/discrete-differentiable-networks/actions/workflows/python.yaml) 4 | 5 | A neural network library for learning boolean-valued, discrete functions on GPUs with gradient descent. 6 | 7 | The library is implemented in Python using the [Flax](https://github.com/google/flax) and [JAX](https://github.com/google/jax) frameworks. 8 | 9 | Questions? Ask @Z80coder 10 | 11 | ## Papers 12 | 13 | [Lossless hardening with ∂𝔹 nets](https://differentiable.xyz/papers/paper_21.pdf). I. Wright. In ["Differentiable Almost Everything: Differentiable Relaxations, Algorithms, Operators, and Simulators"](https://differentiable.xyz/papers), ICML 2023 Workshop, Honolulu, 2023. 14 | 15 | Draft paper: ["∂B nets: learning discrete functions by gradient descent"](https://arxiv.org/abs/2305.07315) (April 2023). 16 | 17 | ## Demos 18 | 19 | [Neural network research with the Wolfram language](https://youtu.be/FeIwI49AEgM?si=HCOAFOTLHmFmwOXn) (30 mins). 20 | 21 | [∂B nets quick overview](https://drive.google.com/file/d/1wi0uCpCgdSSlyXlb1XmzeE_u7adugAEa/view?usp=sharing) (30 mins). 
22 | 23 | [∂B nets overview](https://drive.google.com/file/d/1UUhv6loBrFnZ7jwiHBofnp06at_8bm_F/view?usp=share_link) (1 hour). 24 | 25 | ## Prototype 26 | 27 | The working prototype was implemented in Wolfram. The demos below were snapshots of work-in-progress. 28 | 29 | ### Prototype demos 30 | 31 | - [Neural logic nets](https://drive.google.com/file/d/1_IECuI0f58o_aIIdaQhRo6qPH517YaMa/view?usp=share_link) (15m) 32 | 33 | ### Prototype development snapshots 34 | 35 | - [The Soft-NOT operator](https://drive.google.com/file/d/1z2WFpz4eWLb9xauRnIl6mSXhkbU-XR6X/view?usp=share_link) (10m) 36 | - [The Soft-AND operator](https://drive.google.com/file/d/1l9Y2cWJYYdYSsgqwfH-Dfo2Nxmiewia-/view?usp=share_link) (10m) 37 | - [The differentiable Hard-AND operator](https://drive.google.com/file/d/1Bg1KjKF8KZaBP6jYFhQ5oARrcZYx2O8S/view?usp=share_link) (17m) 38 | - [The differentiable Hard-OR operator](https://drive.google.com/file/d/1WUmJHToU0hQo0YgHlhJb12qECDKzmE8f/view?usp=share_link) (5m) 39 | - [The differentiable Hard-MAJORITY operator](https://drive.google.com/file/d/18oQWhNvbkJGZ49OcQEqGAxkskGZV0e09/view?usp=share_link) (13m) 40 | - [The hardening layer](https://drive.google.com/file/d/1c5K77n9dftsyciq32T7SBBa0PBhIgEq7/view?usp=share_link) (11m) 41 | - [The hardening operation](https://drive.google.com/file/d/1JWA9P9BbfEHWiDfNKVjaH_ssP6CA19Nf/view?usp=share_link) (19m) 42 | - [A classifier architecture](https://drive.google.com/file/d/1KZp8-7hbc_5tHESgmcyBDdBbZDu9UEO9/view?usp=share_link) (20m) 43 | - [Neural logic nets](https://drive.google.com/file/d/1_IECuI0f58o_aIIdaQhRo6qPH517YaMa/view?usp=share_link) (15m) 44 | - [Learning XOR (parity)](https://drive.google.com/file/d/1I2H3iQjM7tNrG83DJFFngQZB_T8jM6uw/view?usp=share_link) (10m) 45 | - [Numerical regression](https://drive.google.com/file/d/1Qx9hBR2nZVymJr3Yoi1CGdg9y8VBxn8P/view?usp=share_link) (23m) 46 | - [If-Then-Else neuron](https://drive.google.com/file/d/1siMqbLr9VYCOwBqNUAnQse9IQSGUjlqo/view?usp=share_link) (23m) 47 | - [Neural conditions and actions](https://drive.google.com/file/d/1WH319bwV55858TYQ9G3C4RPxzdTiA0Ru/view?usp=share_link) (24m) 48 | - [Neural decision lists](https://drive.google.com/file/d/1H0tJtiHz3yXZ7E2xeauaNRd4rnBTUf2v/view?usp=share_link) (15m) 49 | - [Boolean logic nets and MNIST](https://drive.google.com/file/d/12Rwx8H76UTNRdBK4WAwe_QeTWiGrbP-_/view?usp=share_link) (18m) 50 | - [Neural logic nets for differentiable QL](https://drive.google.com/file/d/15rAagCh7LxEN0CHVNkTY6iPWSxrAG0pW/view?usp=share_link) (30m) 51 | 52 | 53 | -------------------------------------------------------------------------------- /neurallogic/hard_and.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | from flax import linen as nn 5 | 6 | from neurallogic import hard_masks, neural_logic_net, symbolic_generation, initialization 7 | 8 | 9 | def soft_and(x, y): 10 | m = jax.numpy.minimum(x, y) 11 | return jax.numpy.where( 12 | 2 * m > 1, 13 | 0.5 + 0.5 * (x + y) * (m - 0.5), 14 | m + 0.5 * (x + y) * (0.5 - m), 15 | ) 16 | 17 | def soft_and_vec(x): 18 | m = jax.numpy.min(x) 19 | mean = jax.numpy.mean(x) 20 | delta = jax.numpy.abs(mean - 0.5) 21 | return jax.numpy.where( 22 | 2 * m > 1, 23 | 0.5 + delta, 24 | m + delta 25 | ) 26 | 27 | # TODO: seperate and operation from mask operation 28 | def soft_and_neuron(w, x): 29 | x = jax.vmap(hard_masks.soft_mask_to_true_margin, 0, 0)(w, x) 30 | #x = jax.vmap(hard_masks.soft_mask_to_true, 0, 0)(w, x) 31 | return 
jax.numpy.min(x) 32 | 33 | # TODO: doesn't seem to work as well 34 | def soft_and_neuron_deprecated(w, x): 35 | x = jax.vmap(hard_masks.soft_mask_to_true_margin, 0, 0)(w, x) 36 | return soft_and_vec(x) 37 | 38 | def hard_and_neuron(w, x): 39 | x = jax.vmap(hard_masks.hard_mask_to_true, 0, 0)(w, x) 40 | return jax.lax.reduce(x, True, jax.lax.bitwise_and, [0]) 41 | 42 | 43 | soft_and_layer = jax.vmap(soft_and_neuron, (0, None), 0) 44 | 45 | hard_and_layer = jax.vmap(hard_and_neuron, (0, None), 0) 46 | 47 | 48 | 49 | class SoftAndLayer(nn.Module): 50 | """ 51 | A soft-bit AND layer than transforms its inputs along the last dimension. 52 | 53 | Attributes: 54 | layer_size: The number of neurons in the layer. 55 | weights_init: The initializer function for the weight matrix. 56 | """ 57 | 58 | layer_size: int 59 | weights_init: Callable = initialization.initialize_near_to_zero() 60 | dtype: jax.numpy.dtype = jax.numpy.float32 61 | 62 | @nn.compact 63 | def __call__(self, x): 64 | weights_shape = (self.layer_size, jax.numpy.shape(x)[-1]) 65 | weights = self.param( 66 | "bit_weights", self.weights_init, weights_shape, self.dtype 67 | ) 68 | x = jax.numpy.asarray(x, self.dtype) 69 | return soft_and_layer(weights, x) 70 | 71 | 72 | class HardAndLayer(nn.Module): 73 | """ 74 | A hard-bit And layer that shadows the SoftAndLayer. 75 | This is a convenience class to make it easier to switch between soft and hard logic. 76 | 77 | Attributes: 78 | layer_size: The number of neurons in the layer. 79 | """ 80 | 81 | layer_size: int 82 | 83 | @nn.compact 84 | def __call__(self, x): 85 | weights_shape = (self.layer_size, jax.numpy.shape(x)[-1]) 86 | weights = self.param( 87 | "bit_weights", nn.initializers.constant(True), weights_shape 88 | ) 89 | return hard_and_layer(weights, x) 90 | 91 | 92 | class SymbolicAndLayer: 93 | def __init__(self, layer_size): 94 | self.layer_size = layer_size 95 | self.hard_and_layer = HardAndLayer(self.layer_size) 96 | 97 | def __call__(self, x): 98 | jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_and_layer, x) 99 | return symbolic_generation.symbolic_expression(jaxpr, x) 100 | 101 | 102 | and_layer = neural_logic_net.select( 103 | lambda layer_size, weights_init=initialization.initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer( 104 | layer_size, weights_init, dtype 105 | ), 106 | lambda layer_size, weights_init=nn.initializers.constant( 107 | True 108 | ), dtype=jax.numpy.float32: HardAndLayer(layer_size), 109 | lambda layer_size, weights_init=nn.initializers.constant( 110 | True 111 | ), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size), 112 | ) 113 | -------------------------------------------------------------------------------- /tests/test_harden.py: -------------------------------------------------------------------------------- 1 | import flax 2 | import jax.numpy as jnp 3 | 4 | from neurallogic import harden 5 | 6 | 7 | def test_harden_float(): 8 | assert harden.harden_float(0.5) == False 9 | assert harden.harden_float(0.6) == True 10 | assert harden.harden_float(0.4) == False 11 | assert harden.harden_float(0.0) == False 12 | assert harden.harden_float(1.0) == True 13 | 14 | 15 | def test_harden_list(): 16 | assert harden.harden([0.5, 0.6, 0.4, 0.0, 1.0]) == [ 17 | False, True, False, False, True] 18 | 19 | 20 | def test_harden_array(): 21 | assert jnp.array_equal(harden.harden( 22 | jnp.array([0.5, 0.6, 0.4, 0.0, 1.0])), [False, True, False, False, True]) 23 | 24 | 25 | def test_harden_dict(): 26 | dict = {'bit_a': 0.5, 'bit_b': 0.6, 
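            # (editor's note) harden() only binarises entries whose keys begin with 'bit_';
            # the 'c' entry below is therefore expected to pass through unchanged.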
'c': 0.4, 'bit_d': 0.0, 'bit_e': 1.0} 27 | expected_dict = {'bit_a': False, 'bit_b': True, 'c': 0.4, 'bit_d': False, 'bit_e': True} 28 | assert harden.harden(dict) == expected_dict 29 | 30 | 31 | def test_harden_frozen_dict(): 32 | dict = flax.core.frozen_dict.FrozenDict( 33 | {'a': 0.5, 'bit_b': 0.6, 'bit_c': 0.4, 'bit_d': 0.0, 'e': 1.0}) 34 | expected_dict = {'a': 0.5, 'bit_b': True, 'bit_c': False, 'bit_d': False, 'e': 1.0} 35 | assert harden.harden(dict) == expected_dict 36 | 37 | 38 | def test_harden(): 39 | assert harden.harden(0.5) == False 40 | assert harden.harden(0.6) == True 41 | assert harden.harden(0.4) == False 42 | assert harden.harden(0.0) == False 43 | assert harden.harden(1.0) == True 44 | assert harden.harden([0.5, 0.6, 0.4, 0.0, 1.0]) == [ 45 | False, True, False, False, True] 46 | assert jnp.array_equal(harden.harden(jnp.array([0.5, 0.6, 0.4, 0.0, 1.0])), [ 47 | False, True, False, False, True]) 48 | dict = {'bit_a': 0.5, 'bit_b': 0.6, 'bit_c': 0.4, 'bit_d': 0.0, 'e': 1.0} 49 | expected_dict = {'bit_a': False, 'bit_b': True, 'bit_c': False, 'bit_d': False, 'e': 1.0} 50 | assert harden.harden(dict) == expected_dict 51 | dict = flax.core.frozen_dict.FrozenDict(dict) 52 | assert harden.harden(dict) == expected_dict 53 | 54 | 55 | def test_harden_compound_dict(): 56 | dict = {'bit_a': 0.5, 'bit_b': 0.6, 'bit_c': 0.4, 'bit_d': 0.0, 'e': 1.0, 57 | 'f': {'bit_a': 0.5, 'bit_b': 0.6, 'bit_c': 0.4, 'bit_d': 0.0, 'e': 1.0}} 58 | expected_dict = {'bit_a': False, 'bit_b': True, 'bit_c': False, 'bit_d': False, 'e': 1.0, 'f': { 59 | 'bit_a': False, 'bit_b': True, 'bit_c': False, 'bit_d': False, 'e': 1.0}} 60 | assert harden.harden(dict) == expected_dict 61 | dict = flax.core.frozen_dict.FrozenDict(dict) 62 | assert harden.harden(dict) == expected_dict 63 | 64 | 65 | def test_harden_complex_compound_dict(): 66 | dict = {'bit_a': 0.5, 'bit_b': 0.6, 'bit_c': 0.4, 'bit_d': 0.0, 'e': 1.0, 'f': { 67 | 'bit_a': 0.5, 'bit_b': 0.6, 'bit_c': 0.4, 'bit_d': 0.0, 'e': 1.0, 'g': [0.5, 0.6, 0.4, 0.0, 1.0]}} 68 | expected_dict = {'bit_a': False, 'bit_b': True, 'bit_c': False, 'bit_d': False, 'e': 1.0, 'f': { 69 | 'bit_a': False, 'bit_b': True, 'bit_c': False, 'bit_d': False, 'e': 1.0, 'g': [False, True, False, False, True]}} 70 | assert harden.harden(dict) == expected_dict 71 | dict = flax.core.frozen_dict.FrozenDict(dict) 72 | assert harden.harden(dict) == expected_dict 73 | 74 | 75 | def test_dict_with_array(): 76 | dict = {'a': 0.5, 'b': 0.6, 'c': 0.4, 'd': 0.0, 77 | 'e': 1.0, 'f': jnp.array([0.5, 0.6, 0.4, 0.0, 1.0])} 78 | expected_dict = {'a': False, 'b': True, 'c': False, 'd': False, 79 | 'e': True, 'f': jnp.array([False, True, False, False, True])} 80 | str(harden.harden(dict)) == str(expected_dict) 81 | 82 | 83 | def test_harden_compound_list(): 84 | list = [0.5, 0.6, 0.4, 0.0, 1.0, [0.5, 0.6, 0.4, 0.0, 1.0]] 85 | expected_list = [False, True, False, False, 86 | True, [False, True, False, False, True]] 87 | assert harden.harden(list) == expected_list 88 | 89 | 90 | def test_hard_weights(): 91 | weights = flax.core.FrozenDict( 92 | {'Soft_params': {'bit_a': 0.5, 'bit_b': 0.6, 'bit_c': 0.4, 'Soft_d': 0.0, 'e': 1.0}}) 93 | expected_weights = flax.core.FrozenDict( 94 | {'Hard_params': {'bit_a': False, 'bit_b': True, 'bit_c': False, 'Hard_d': 0.0, 'e': 1.0}}) 95 | hard_weights = harden.hard_weights(weights).unfreeze() 96 | assert str(hard_weights) == str(expected_weights.unfreeze()) 97 | -------------------------------------------------------------------------------- /tests/data/iris.data: 
-------------------------------------------------------------------------------- 1 | 5.1,3.5,1.4,0.2,Iris-setosa 2 | 4.9,3.0,1.4,0.2,Iris-setosa 3 | 4.7,3.2,1.3,0.2,Iris-setosa 4 | 4.6,3.1,1.5,0.2,Iris-setosa 5 | 5.0,3.6,1.4,0.2,Iris-setosa 6 | 5.4,3.9,1.7,0.4,Iris-setosa 7 | 4.6,3.4,1.4,0.3,Iris-setosa 8 | 5.0,3.4,1.5,0.2,Iris-setosa 9 | 4.4,2.9,1.4,0.2,Iris-setosa 10 | 4.9,3.1,1.5,0.1,Iris-setosa 11 | 5.4,3.7,1.5,0.2,Iris-setosa 12 | 4.8,3.4,1.6,0.2,Iris-setosa 13 | 4.8,3.0,1.4,0.1,Iris-setosa 14 | 4.3,3.0,1.1,0.1,Iris-setosa 15 | 5.8,4.0,1.2,0.2,Iris-setosa 16 | 5.7,4.4,1.5,0.4,Iris-setosa 17 | 5.4,3.9,1.3,0.4,Iris-setosa 18 | 5.1,3.5,1.4,0.3,Iris-setosa 19 | 5.7,3.8,1.7,0.3,Iris-setosa 20 | 5.1,3.8,1.5,0.3,Iris-setosa 21 | 5.4,3.4,1.7,0.2,Iris-setosa 22 | 5.1,3.7,1.5,0.4,Iris-setosa 23 | 4.6,3.6,1.0,0.2,Iris-setosa 24 | 5.1,3.3,1.7,0.5,Iris-setosa 25 | 4.8,3.4,1.9,0.2,Iris-setosa 26 | 5.0,3.0,1.6,0.2,Iris-setosa 27 | 5.0,3.4,1.6,0.4,Iris-setosa 28 | 5.2,3.5,1.5,0.2,Iris-setosa 29 | 5.2,3.4,1.4,0.2,Iris-setosa 30 | 4.7,3.2,1.6,0.2,Iris-setosa 31 | 4.8,3.1,1.6,0.2,Iris-setosa 32 | 5.4,3.4,1.5,0.4,Iris-setosa 33 | 5.2,4.1,1.5,0.1,Iris-setosa 34 | 5.5,4.2,1.4,0.2,Iris-setosa 35 | 4.9,3.1,1.5,0.1,Iris-setosa 36 | 5.0,3.2,1.2,0.2,Iris-setosa 37 | 5.5,3.5,1.3,0.2,Iris-setosa 38 | 4.9,3.1,1.5,0.1,Iris-setosa 39 | 4.4,3.0,1.3,0.2,Iris-setosa 40 | 5.1,3.4,1.5,0.2,Iris-setosa 41 | 5.0,3.5,1.3,0.3,Iris-setosa 42 | 4.5,2.3,1.3,0.3,Iris-setosa 43 | 4.4,3.2,1.3,0.2,Iris-setosa 44 | 5.0,3.5,1.6,0.6,Iris-setosa 45 | 5.1,3.8,1.9,0.4,Iris-setosa 46 | 4.8,3.0,1.4,0.3,Iris-setosa 47 | 5.1,3.8,1.6,0.2,Iris-setosa 48 | 4.6,3.2,1.4,0.2,Iris-setosa 49 | 5.3,3.7,1.5,0.2,Iris-setosa 50 | 5.0,3.3,1.4,0.2,Iris-setosa 51 | 7.0,3.2,4.7,1.4,Iris-versicolor 52 | 6.4,3.2,4.5,1.5,Iris-versicolor 53 | 6.9,3.1,4.9,1.5,Iris-versicolor 54 | 5.5,2.3,4.0,1.3,Iris-versicolor 55 | 6.5,2.8,4.6,1.5,Iris-versicolor 56 | 5.7,2.8,4.5,1.3,Iris-versicolor 57 | 6.3,3.3,4.7,1.6,Iris-versicolor 58 | 4.9,2.4,3.3,1.0,Iris-versicolor 59 | 6.6,2.9,4.6,1.3,Iris-versicolor 60 | 5.2,2.7,3.9,1.4,Iris-versicolor 61 | 5.0,2.0,3.5,1.0,Iris-versicolor 62 | 5.9,3.0,4.2,1.5,Iris-versicolor 63 | 6.0,2.2,4.0,1.0,Iris-versicolor 64 | 6.1,2.9,4.7,1.4,Iris-versicolor 65 | 5.6,2.9,3.6,1.3,Iris-versicolor 66 | 6.7,3.1,4.4,1.4,Iris-versicolor 67 | 5.6,3.0,4.5,1.5,Iris-versicolor 68 | 5.8,2.7,4.1,1.0,Iris-versicolor 69 | 6.2,2.2,4.5,1.5,Iris-versicolor 70 | 5.6,2.5,3.9,1.1,Iris-versicolor 71 | 5.9,3.2,4.8,1.8,Iris-versicolor 72 | 6.1,2.8,4.0,1.3,Iris-versicolor 73 | 6.3,2.5,4.9,1.5,Iris-versicolor 74 | 6.1,2.8,4.7,1.2,Iris-versicolor 75 | 6.4,2.9,4.3,1.3,Iris-versicolor 76 | 6.6,3.0,4.4,1.4,Iris-versicolor 77 | 6.8,2.8,4.8,1.4,Iris-versicolor 78 | 6.7,3.0,5.0,1.7,Iris-versicolor 79 | 6.0,2.9,4.5,1.5,Iris-versicolor 80 | 5.7,2.6,3.5,1.0,Iris-versicolor 81 | 5.5,2.4,3.8,1.1,Iris-versicolor 82 | 5.5,2.4,3.7,1.0,Iris-versicolor 83 | 5.8,2.7,3.9,1.2,Iris-versicolor 84 | 6.0,2.7,5.1,1.6,Iris-versicolor 85 | 5.4,3.0,4.5,1.5,Iris-versicolor 86 | 6.0,3.4,4.5,1.6,Iris-versicolor 87 | 6.7,3.1,4.7,1.5,Iris-versicolor 88 | 6.3,2.3,4.4,1.3,Iris-versicolor 89 | 5.6,3.0,4.1,1.3,Iris-versicolor 90 | 5.5,2.5,4.0,1.3,Iris-versicolor 91 | 5.5,2.6,4.4,1.2,Iris-versicolor 92 | 6.1,3.0,4.6,1.4,Iris-versicolor 93 | 5.8,2.6,4.0,1.2,Iris-versicolor 94 | 5.0,2.3,3.3,1.0,Iris-versicolor 95 | 5.6,2.7,4.2,1.3,Iris-versicolor 96 | 5.7,3.0,4.2,1.2,Iris-versicolor 97 | 5.7,2.9,4.2,1.3,Iris-versicolor 98 | 6.2,2.9,4.3,1.3,Iris-versicolor 99 | 5.1,2.5,3.0,1.1,Iris-versicolor 100 | 
5.7,2.8,4.1,1.3,Iris-versicolor 101 | 6.3,3.3,6.0,2.5,Iris-virginica 102 | 5.8,2.7,5.1,1.9,Iris-virginica 103 | 7.1,3.0,5.9,2.1,Iris-virginica 104 | 6.3,2.9,5.6,1.8,Iris-virginica 105 | 6.5,3.0,5.8,2.2,Iris-virginica 106 | 7.6,3.0,6.6,2.1,Iris-virginica 107 | 4.9,2.5,4.5,1.7,Iris-virginica 108 | 7.3,2.9,6.3,1.8,Iris-virginica 109 | 6.7,2.5,5.8,1.8,Iris-virginica 110 | 7.2,3.6,6.1,2.5,Iris-virginica 111 | 6.5,3.2,5.1,2.0,Iris-virginica 112 | 6.4,2.7,5.3,1.9,Iris-virginica 113 | 6.8,3.0,5.5,2.1,Iris-virginica 114 | 5.7,2.5,5.0,2.0,Iris-virginica 115 | 5.8,2.8,5.1,2.4,Iris-virginica 116 | 6.4,3.2,5.3,2.3,Iris-virginica 117 | 6.5,3.0,5.5,1.8,Iris-virginica 118 | 7.7,3.8,6.7,2.2,Iris-virginica 119 | 7.7,2.6,6.9,2.3,Iris-virginica 120 | 6.0,2.2,5.0,1.5,Iris-virginica 121 | 6.9,3.2,5.7,2.3,Iris-virginica 122 | 5.6,2.8,4.9,2.0,Iris-virginica 123 | 7.7,2.8,6.7,2.0,Iris-virginica 124 | 6.3,2.7,4.9,1.8,Iris-virginica 125 | 6.7,3.3,5.7,2.1,Iris-virginica 126 | 7.2,3.2,6.0,1.8,Iris-virginica 127 | 6.2,2.8,4.8,1.8,Iris-virginica 128 | 6.1,3.0,4.9,1.8,Iris-virginica 129 | 6.4,2.8,5.6,2.1,Iris-virginica 130 | 7.2,3.0,5.8,1.6,Iris-virginica 131 | 7.4,2.8,6.1,1.9,Iris-virginica 132 | 7.9,3.8,6.4,2.0,Iris-virginica 133 | 6.4,2.8,5.6,2.2,Iris-virginica 134 | 6.3,2.8,5.1,1.5,Iris-virginica 135 | 6.1,2.6,5.6,1.4,Iris-virginica 136 | 7.7,3.0,6.1,2.3,Iris-virginica 137 | 6.3,3.4,5.6,2.4,Iris-virginica 138 | 6.4,3.1,5.5,1.8,Iris-virginica 139 | 6.0,3.0,4.8,1.8,Iris-virginica 140 | 6.9,3.1,5.4,2.1,Iris-virginica 141 | 6.7,3.1,5.6,2.4,Iris-virginica 142 | 6.9,3.1,5.1,2.3,Iris-virginica 143 | 5.8,2.7,5.1,1.9,Iris-virginica 144 | 6.8,3.2,5.9,2.3,Iris-virginica 145 | 6.7,3.3,5.7,2.5,Iris-virginica 146 | 6.7,3.0,5.2,2.3,Iris-virginica 147 | 6.3,2.5,5.0,1.9,Iris-virginica 148 | 6.5,3.0,5.2,2.0,Iris-virginica 149 | 6.2,3.4,5.4,2.3,Iris-virginica 150 | 5.9,3.0,5.1,1.8,Iris-virginica 151 | -------------------------------------------------------------------------------- /tests/data/BinaryIrisData.txt: -------------------------------------------------------------------------------- 1 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 2 | 0 0 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 3 | 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 4 | 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 5 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 6 | 0 0 1 1 0 0 1 0 0 0 0 1 0 0 0 0 0 7 | 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 8 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 9 | 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 10 | 0 0 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 11 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 12 | 0 0 1 1 0 0 1 0 0 0 0 1 0 0 0 0 0 13 | 0 0 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 14 | 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 15 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 16 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 17 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 18 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 19 | 0 0 1 1 0 0 1 0 0 0 0 1 0 0 0 0 0 20 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 21 | 0 0 1 1 0 0 1 0 0 0 0 1 0 0 0 0 0 22 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 23 | 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 24 | 0 0 1 1 0 0 1 0 0 0 0 1 0 0 0 0 0 25 | 0 0 1 1 0 0 1 0 0 0 0 1 0 0 0 0 0 26 | 0 0 1 1 0 0 0 1 0 0 0 1 0 0 0 0 0 27 | 0 0 1 1 0 0 1 0 0 0 0 1 0 0 0 0 0 28 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 29 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 30 | 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0 0 31 | 0 0 1 1 0 0 0 1 0 0 0 1 0 0 0 0 0 32 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 33 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 34 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 35 | 0 0 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 36 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 37 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 38 | 0 0 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 
39 | 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 40 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 41 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 42 | 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 43 | 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 44 | 0 0 1 1 0 0 1 0 0 0 0 1 0 0 0 0 0 45 | 0 0 1 1 0 0 1 0 0 0 0 1 0 0 0 0 0 46 | 0 0 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 47 | 0 0 1 1 0 0 1 0 0 0 0 1 0 0 0 0 0 48 | 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 49 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 50 | 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 51 | 0 1 0 0 0 0 1 0 0 0 1 0 0 0 0 0 1 52 | 0 1 0 0 0 0 1 0 0 0 1 0 0 0 0 0 1 53 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 0 1 54 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 55 | 0 1 0 0 0 0 0 1 0 0 1 0 0 0 0 0 1 56 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 57 | 0 0 1 1 0 0 1 0 0 0 1 0 0 0 0 1 1 58 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 59 | 0 1 0 0 0 0 0 1 0 0 1 0 0 0 0 0 1 60 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 61 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 62 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 63 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 64 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 65 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 66 | 0 1 0 0 0 0 0 1 0 0 1 0 0 0 0 0 1 67 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 68 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 69 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 70 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 71 | 0 0 1 1 0 0 1 0 0 0 1 1 0 0 0 1 1 72 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 73 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 0 1 74 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 75 | 0 1 0 0 0 0 0 1 0 0 1 0 0 0 0 0 1 76 | 0 1 0 0 0 0 0 1 0 0 1 0 0 0 0 0 1 77 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 0 1 78 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 1 79 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 80 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 81 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 82 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 83 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 84 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 1 85 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 86 | 0 0 1 1 0 0 1 0 0 0 1 0 0 0 0 1 1 87 | 0 1 0 0 0 0 0 1 0 0 1 0 0 0 0 0 1 88 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 89 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 90 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 91 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 92 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 93 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 94 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 95 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 96 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 97 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 98 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 99 | 0 0 1 1 0 0 0 1 0 0 0 1 0 0 0 0 1 100 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 1 101 | 0 0 1 1 0 0 1 0 0 0 1 1 0 0 0 1 2 102 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 103 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 104 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 105 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 106 | 0 1 0 0 0 0 0 1 0 1 0 0 0 0 0 1 2 107 | 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 1 2 108 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 109 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 110 | 0 1 0 0 0 0 1 0 0 0 1 1 0 0 0 1 2 111 | 0 1 0 0 0 0 1 0 0 0 1 1 0 0 0 1 2 112 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 113 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 114 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 115 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 116 | 0 1 0 0 0 0 1 0 0 0 1 1 0 0 0 1 2 117 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 118 | 0 1 0 0 0 0 1 0 0 1 0 0 0 0 0 1 2 119 | 0 1 0 0 0 0 0 1 0 1 0 0 0 0 0 1 2 120 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 0 2 121 | 0 1 0 0 0 0 1 0 0 0 1 1 0 0 0 1 2 122 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 123 | 0 1 0 0 0 0 0 1 0 1 0 0 0 0 0 1 2 124 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 125 | 0 1 0 0 0 0 1 0 0 0 1 1 0 0 0 1 2 126 | 0 1 0 0 0 0 1 0 0 0 1 1 0 0 0 1 2 127 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 128 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 129 | 0 1 0 0 0 
0 0 1 0 0 1 1 0 0 0 1 2 130 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 131 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 132 | 0 1 0 0 0 0 1 0 0 1 0 0 0 0 0 1 2 133 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 134 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 0 2 135 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 0 2 136 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 137 | 0 0 1 1 0 0 1 0 0 0 1 1 0 0 0 1 2 138 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 139 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 140 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 141 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 142 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 143 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 144 | 0 1 0 0 0 0 1 0 0 0 1 1 0 0 0 1 2 145 | 0 1 0 0 0 0 1 0 0 0 1 1 0 0 0 1 2 146 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 147 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 148 | 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 1 2 149 | 0 0 1 1 0 0 1 0 0 0 1 1 0 0 0 1 2 150 | 0 0 1 1 0 0 0 1 0 0 1 1 0 0 0 1 2 -------------------------------------------------------------------------------- /neurallogic/hard_masks.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | from flax import linen as nn 5 | 6 | from neurallogic import neural_logic_net, symbolic_generation, hard_and, hard_or, initialization 7 | 8 | # TODO: properly factor with/without margin versions 9 | 10 | 11 | def soft_mask_to_true(w: float, x: float): 12 | """ 13 | w > 0.5 implies the mask operation is inactive, else active 14 | 15 | Assumes x is in [0, 1] 16 | 17 | Corresponding hard logic: x OR ! w 18 | """ 19 | w = jax.numpy.clip(w, 0.0, 1.0) 20 | return jax.numpy.maximum(x, 1.0 - w) 21 | 22 | # Superior on noisy XOR 23 | def soft_mask_to_true_margin(w: float, x: float) -> float: 24 | w = jax.numpy.clip(w, 0.0, 1.0) 25 | return hard_or.soft_or(x, 1.0 - w) 26 | 27 | 28 | 29 | def hard_mask_to_true(w, x): 30 | return jax.numpy.logical_or(x, jax.numpy.logical_not(w)) 31 | 32 | 33 | soft_mask_to_true_neuron = jax.vmap(soft_mask_to_true, 0, 0) 34 | soft_mask_to_true_margin_neuron = jax.vmap(soft_mask_to_true_margin, 0, 0) 35 | 36 | hard_mask_to_true_neuron = jax.vmap(hard_mask_to_true, 0, 0) 37 | 38 | 39 | soft_mask_to_true_layer = jax.vmap(soft_mask_to_true_neuron, (0, None), 0) 40 | soft_mask_to_true_margin_layer = jax.vmap(soft_mask_to_true_margin_neuron, (0, None), 0) 41 | 42 | hard_mask_to_true_layer = jax.vmap(hard_mask_to_true_neuron, (0, None), 0) 43 | 44 | 45 | def soft_mask_to_false(w: float, x: float): 46 | """ 47 | w > 0.5 implies the mask is inactive, else active 48 | 49 | Assumes x is in [0, 1] 50 | 51 | Corresponding hard logic: b AND w 52 | """ 53 | w = jax.numpy.clip(w, 0.0, 1.0) 54 | # TODO: what is this madness? 
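    # (editor's note) By De Morgan's identity, 1 - max(1 - x, 1 - w) == min(x, w),
    # so the line below is simply a soft AND of the input bit with the weight mask,
    # mirroring the hard logic noted in the docstring above.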
55 | return 1.0 - jax.numpy.maximum(1.0 - x, 1.0 - w) 56 | 57 | # Superior on noisy XOR 58 | def soft_mask_to_false_margin(w: float, x: float) -> float: 59 | w = jax.numpy.clip(w, 0.0, 1.0) 60 | return hard_and.soft_and(x, w) 61 | 62 | 63 | def hard_mask_to_false(w, x): 64 | return jax.numpy.logical_and(x, w) 65 | 66 | 67 | soft_mask_to_false_neuron = jax.vmap(soft_mask_to_false, 0, 0) 68 | soft_mask_to_false_margin_neuron = jax.vmap(soft_mask_to_false_margin, 0, 0) 69 | 70 | hard_mask_to_false_neuron = jax.vmap(hard_mask_to_false, 0, 0) 71 | 72 | 73 | soft_mask_to_false_layer = jax.vmap(soft_mask_to_false_neuron, (0, None), 0) 74 | soft_mask_to_false_margin_layer = jax.vmap(soft_mask_to_false_margin_neuron, (0, None), 0) 75 | 76 | hard_mask_to_false_layer = jax.vmap(hard_mask_to_false_neuron, (0, None), 0) 77 | 78 | 79 | class SoftMaskLayer(nn.Module): 80 | mask_layer_operation: Callable 81 | layer_size: int 82 | weights_init: Callable = nn.initializers.uniform(1.0) 83 | dtype: jax.numpy.dtype = jax.numpy.float32 84 | 85 | @nn.compact 86 | def __call__(self, x): 87 | weights_shape = (self.layer_size, jax.numpy.shape(x)[-1]) 88 | weights = self.param( 89 | "bit_weights", self.weights_init, weights_shape, self.dtype 90 | ) 91 | x = jax.numpy.asarray(x, self.dtype) 92 | return self.mask_layer_operation(weights, x) 93 | 94 | 95 | class HardMaskLayer(nn.Module): 96 | mask_layer_operation: Callable 97 | layer_size: int 98 | weights_init: Callable = nn.initializers.constant(True) 99 | 100 | @nn.compact 101 | def __call__(self, x): 102 | weights_shape = (self.layer_size, jax.numpy.shape(x)[-1]) 103 | weights = self.param("bit_weights", self.weights_init, weights_shape) 104 | return self.mask_layer_operation(weights, x) 105 | 106 | 107 | class SymbolicMaskLayer: 108 | def __init__(self, mask_layer): 109 | self.hard_mask_layer = mask_layer 110 | 111 | def __call__(self, x): 112 | jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_mask_layer, x) 113 | return symbolic_generation.symbolic_expression(jaxpr, x) 114 | 115 | 116 | mask_to_true_layer = neural_logic_net.select( 117 | lambda layer_size, weights_init=nn.initializers.uniform( 118 | 1.0 119 | ), dtype=jax.numpy.float32: SoftMaskLayer( 120 | soft_mask_to_true_layer, layer_size, weights_init, dtype 121 | ), 122 | lambda layer_size, weights_init=nn.initializers.uniform( 123 | 1.0 124 | ), dtype=jax.numpy.float32: HardMaskLayer(hard_mask_to_true_layer, layer_size), 125 | lambda layer_size, weights_init=nn.initializers.uniform( 126 | 1.0 127 | ), dtype=jax.numpy.float32: SymbolicMaskLayer( 128 | HardMaskLayer(hard_mask_to_true_layer, layer_size) 129 | ), 130 | ) 131 | 132 | 133 | mask_to_true_margin_layer = neural_logic_net.select( 134 | lambda layer_size, weights_init=nn.initializers.uniform( 135 | 1.0 136 | ), dtype=jax.numpy.float32: SoftMaskLayer( 137 | soft_mask_to_true_margin_layer, layer_size, weights_init, dtype 138 | ), 139 | lambda layer_size, weights_init=nn.initializers.uniform( 140 | 1.0 141 | ), dtype=jax.numpy.float32: HardMaskLayer(hard_mask_to_true_layer, layer_size), 142 | lambda layer_size, weights_init=nn.initializers.uniform( 143 | 1.0 144 | ), dtype=jax.numpy.float32: SymbolicMaskLayer( 145 | HardMaskLayer(hard_mask_to_true_layer, layer_size) 146 | ), 147 | ) 148 | 149 | mask_to_false_layer = neural_logic_net.select( 150 | lambda layer_size, weights_init=nn.initializers.uniform( 151 | 1.0 152 | ), dtype=jax.numpy.float32: SoftMaskLayer( 153 | soft_mask_to_false_layer, layer_size, weights_init, dtype 154 | ), 155 | 
lambda layer_size, weights_init=nn.initializers.uniform( 156 | 1.0 157 | ), dtype=jax.numpy.float32: HardMaskLayer(hard_mask_to_false_layer, layer_size), 158 | lambda layer_size, weights_init=nn.initializers.uniform( 159 | 1.0 160 | ), dtype=jax.numpy.float32: SymbolicMaskLayer( 161 | HardMaskLayer(hard_mask_to_false_layer, layer_size) 162 | ), 163 | ) 164 | 165 | # TODO: mask to false margin layer -------------------------------------------------------------------------------- /scratchpad.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "2023-01-17 15:09:13.384599: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", 20 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 21 | "2023-01-17 15:09:14.414425: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", 22 | "2023-01-17 15:09:14.414541: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", 23 | "2023-01-17 15:09:14.414556: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" 24 | ] 25 | }, 26 | { 27 | "ename": "ImportError", 28 | "evalue": "cannot import name 'primitives' from 'neurallogic' (/workspaces/neural-logic/neurallogic/__init__.py)", 29 | "output_type": "error", 30 | "traceback": [ 31 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 32 | "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", 33 | "Cell \u001b[0;32mIn[1], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mjax\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mnumpy\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mjnp\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mflax\u001b[39;00m \u001b[39mimport\u001b[39;00m linen \u001b[39mas\u001b[39;00m nn\n\u001b[0;32m----> 5\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mneurallogic\u001b[39;00m \u001b[39mimport\u001b[39;00m neural_logic_net, harden, harden_layer, hard_or, hard_and, hard_not, primitives, symbolic_primitives\n\u001b[1;32m 6\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mtests\u001b[39;00m \u001b[39mimport\u001b[39;00m test_mnist\n\u001b[1;32m 7\u001b[0m tf\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mexperimental\u001b[39m.\u001b[39mset_visible_devices([], \u001b[39m\"\u001b[39m\u001b[39mGPU\u001b[39m\u001b[39m\"\u001b[39m)\n", 34 | "\u001b[0;31mImportError\u001b[0m: cannot import name 'primitives' from 'neurallogic' (/workspaces/neural-logic/neurallogic/__init__.py)" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "import tensorflow as tf\n", 40 | "import jax\n", 41 | "import jax.numpy as jnp\n", 42 | "from flax import linen as nn\n", 43 | "from neurallogic import neural_logic_net, harden, harden_layer, hard_or, hard_and, hard_not, symbolic_primitives\n", 44 | "from tests import test_mnist\n", 45 | "tf.config.experimental.set_visible_devices([], \"GPU\")\n", 46 | "import numpy" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 10, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# clear the GPU memory\n", 56 | "from numba import cuda\n", 57 | "cuda.select_device(0)\n", 58 | "cuda.close()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "# Sandpit" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 2, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "3" 77 | ] 78 | }, 79 | "execution_count": 2, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "eval('1+2')" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 3, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "ename": "TypeError", 95 | "evalue": "eval() arg 1 must be a string, bytes or code object", 96 | "output_type": "error", 97 | "traceback": [ 98 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 99 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 100 | "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39meval\u001b[39;49m([\u001b[39m'\u001b[39;49m\u001b[39m1+2\u001b[39;49m\u001b[39m'\u001b[39;49m])\n", 101 | "\u001b[0;31mTypeError\u001b[0m: eval() arg 1 must be a string, bytes or code object" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "eval(['1+2'])" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | 
"execution_count": 4, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "['1+2']" 118 | ] 119 | }, 120 | "execution_count": 4, 121 | "metadata": {}, 122 | "output_type": "execute_result" 123 | } 124 | ], 125 | "source": [ 126 | "eval(\"['1+2']\")" 127 | ] 128 | } 129 | ], 130 | "metadata": { 131 | "kernelspec": { 132 | "display_name": "Python 3.10.6 ('base')", 133 | "language": "python", 134 | "name": "python3" 135 | }, 136 | "language_info": { 137 | "codemirror_mode": { 138 | "name": "ipython", 139 | "version": 3 140 | }, 141 | "file_extension": ".py", 142 | "mimetype": "text/x-python", 143 | "name": "python", 144 | "nbconvert_exporter": "python", 145 | "pygments_lexer": "ipython3", 146 | "version": "3.10.8" 147 | }, 148 | "orig_nbformat": 4, 149 | "vscode": { 150 | "interpreter": { 151 | "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" 152 | } 153 | } 154 | }, 155 | "nbformat": 4, 156 | "nbformat_minor": 2 157 | } 158 | -------------------------------------------------------------------------------- /docs/ICML_workshop/icml2023/algorithmic.sty: -------------------------------------------------------------------------------- 1 | % ALGORITHMIC STYLE -- Released 8 APRIL 1996 2 | % for LaTeX version 2e 3 | % Copyright -- 1994 Peter Williams 4 | % E-mail PeterWilliams@dsto.defence.gov.au 5 | % 6 | % Modified by Alex Smola (08/2000) 7 | % E-mail Alex.Smola@anu.edu.au 8 | % 9 | \NeedsTeXFormat{LaTeX2e} 10 | \ProvidesPackage{algorithmic} 11 | \typeout{Document Style `algorithmic' - environment} 12 | % 13 | \RequirePackage{ifthen} 14 | \RequirePackage{calc} 15 | \newboolean{ALC@noend} 16 | \setboolean{ALC@noend}{false} 17 | \newcounter{ALC@line} 18 | \newcounter{ALC@rem} 19 | \newlength{\ALC@tlm} 20 | % 21 | \DeclareOption{noend}{\setboolean{ALC@noend}{true}} 22 | % 23 | \ProcessOptions 24 | % 25 | % ALGORITHMIC 26 | \newcommand{\algorithmicrequire}{\textbf{Require:}} 27 | \newcommand{\algorithmicensure}{\textbf{Ensure:}} 28 | \newcommand{\algorithmiccomment}[1]{\{#1\}} 29 | \newcommand{\algorithmicend}{\textbf{end}} 30 | \newcommand{\algorithmicif}{\textbf{if}} 31 | \newcommand{\algorithmicthen}{\textbf{then}} 32 | \newcommand{\algorithmicelse}{\textbf{else}} 33 | \newcommand{\algorithmicelsif}{\algorithmicelse\ \algorithmicif} 34 | \newcommand{\algorithmicendif}{\algorithmicend\ \algorithmicif} 35 | \newcommand{\algorithmicfor}{\textbf{for}} 36 | \newcommand{\algorithmicforall}{\textbf{for all}} 37 | \newcommand{\algorithmicdo}{\textbf{do}} 38 | \newcommand{\algorithmicendfor}{\algorithmicend\ \algorithmicfor} 39 | \newcommand{\algorithmicwhile}{\textbf{while}} 40 | \newcommand{\algorithmicendwhile}{\algorithmicend\ \algorithmicwhile} 41 | \newcommand{\algorithmicloop}{\textbf{loop}} 42 | \newcommand{\algorithmicendloop}{\algorithmicend\ \algorithmicloop} 43 | \newcommand{\algorithmicrepeat}{\textbf{repeat}} 44 | \newcommand{\algorithmicuntil}{\textbf{until}} 45 | 46 | %changed by alex smola 47 | \newcommand{\algorithmicinput}{\textbf{input}} 48 | \newcommand{\algorithmicoutput}{\textbf{output}} 49 | \newcommand{\algorithmicset}{\textbf{set}} 50 | \newcommand{\algorithmictrue}{\textbf{true}} 51 | \newcommand{\algorithmicfalse}{\textbf{false}} 52 | \newcommand{\algorithmicand}{\textbf{and\ }} 53 | \newcommand{\algorithmicor}{\textbf{or\ }} 54 | \newcommand{\algorithmicfunction}{\textbf{function}} 55 | \newcommand{\algorithmicendfunction}{\algorithmicend\ \algorithmicfunction} 56 | \newcommand{\algorithmicmain}{\textbf{main}} 57 
| \newcommand{\algorithmicendmain}{\algorithmicend\ \algorithmicmain} 58 | %end changed by alex smola 59 | 60 | \def\ALC@item[#1]{% 61 | \if@noparitem \@donoparitem 62 | \else \if@inlabel \indent \par \fi 63 | \ifhmode \unskip\unskip \par \fi 64 | \if@newlist \if@nobreak \@nbitem \else 65 | \addpenalty\@beginparpenalty 66 | \addvspace\@topsep \addvspace{-\parskip}\fi 67 | \else \addpenalty\@itempenalty \addvspace\itemsep 68 | \fi 69 | \global\@inlabeltrue 70 | \fi 71 | \everypar{\global\@minipagefalse\global\@newlistfalse 72 | \if@inlabel\global\@inlabelfalse \hskip -\parindent \box\@labels 73 | \penalty\z@ \fi 74 | \everypar{}}\global\@nobreakfalse 75 | \if@noitemarg \@noitemargfalse \if@nmbrlist \refstepcounter{\@listctr}\fi \fi 76 | \sbox\@tempboxa{\makelabel{#1}}% 77 | \global\setbox\@labels 78 | \hbox{\unhbox\@labels \hskip \itemindent 79 | \hskip -\labelwidth \hskip -\ALC@tlm 80 | \ifdim \wd\@tempboxa >\labelwidth 81 | \box\@tempboxa 82 | \else \hbox to\labelwidth {\unhbox\@tempboxa}\fi 83 | \hskip \ALC@tlm}\ignorespaces} 84 | % 85 | \newenvironment{algorithmic}[1][0]{ 86 | \let\@item\ALC@item 87 | \newcommand{\ALC@lno}{% 88 | \ifthenelse{\equal{\arabic{ALC@rem}}{0}} 89 | {{\footnotesize \arabic{ALC@line}:}}{}% 90 | } 91 | \let\@listii\@listi 92 | \let\@listiii\@listi 93 | \let\@listiv\@listi 94 | \let\@listv\@listi 95 | \let\@listvi\@listi 96 | \let\@listvii\@listi 97 | \newenvironment{ALC@g}{ 98 | \begin{list}{\ALC@lno}{ \itemsep\z@ \itemindent\z@ 99 | \listparindent\z@ \rightmargin\z@ 100 | \topsep\z@ \partopsep\z@ \parskip\z@\parsep\z@ 101 | \leftmargin 1em 102 | \addtolength{\ALC@tlm}{\leftmargin} 103 | } 104 | } 105 | {\end{list}} 106 | \newcommand{\ALC@it}{\addtocounter{ALC@line}{1}\addtocounter{ALC@rem}{1}\ifthenelse{\equal{\arabic{ALC@rem}}{#1}}{\setcounter{ALC@rem}{0}}{}\item} 107 | \newcommand{\ALC@com}[1]{\ifthenelse{\equal{##1}{default}}% 108 | {}{\ \algorithmiccomment{##1}}} 109 | \newcommand{\REQUIRE}{\item[\algorithmicrequire]} 110 | \newcommand{\ENSURE}{\item[\algorithmicensure]} 111 | \newcommand{\STATE}{\ALC@it} 112 | \newcommand{\COMMENT}[1]{\algorithmiccomment{##1}} 113 | %changes by alex smola 114 | \newcommand{\INPUT}{\item[\algorithmicinput]} 115 | \newcommand{\OUTPUT}{\item[\algorithmicoutput]} 116 | \newcommand{\SET}{\item[\algorithmicset]} 117 | % \newcommand{\TRUE}{\algorithmictrue} 118 | % \newcommand{\FALSE}{\algorithmicfalse} 119 | \newcommand{\AND}{\algorithmicand} 120 | \newcommand{\OR}{\algorithmicor} 121 | \newenvironment{ALC@func}{\begin{ALC@g}}{\end{ALC@g}} 122 | \newenvironment{ALC@main}{\begin{ALC@g}}{\end{ALC@g}} 123 | %end changes by alex smola 124 | \newenvironment{ALC@if}{\begin{ALC@g}}{\end{ALC@g}} 125 | \newenvironment{ALC@for}{\begin{ALC@g}}{\end{ALC@g}} 126 | \newenvironment{ALC@whl}{\begin{ALC@g}}{\end{ALC@g}} 127 | \newenvironment{ALC@loop}{\begin{ALC@g}}{\end{ALC@g}} 128 | \newenvironment{ALC@rpt}{\begin{ALC@g}}{\end{ALC@g}} 129 | \renewcommand{\\}{\@centercr} 130 | \newcommand{\IF}[2][default]{\ALC@it\algorithmicif\ ##2\ \algorithmicthen% 131 | \ALC@com{##1}\begin{ALC@if}} 132 | \newcommand{\SHORTIF}[2]{\ALC@it\algorithmicif\ ##1\ 133 | \algorithmicthen\ {##2}} 134 | \newcommand{\ELSE}[1][default]{\end{ALC@if}\ALC@it\algorithmicelse% 135 | \ALC@com{##1}\begin{ALC@if}} 136 | \newcommand{\ELSIF}[2][default]% 137 | {\end{ALC@if}\ALC@it\algorithmicelsif\ ##2\ \algorithmicthen% 138 | \ALC@com{##1}\begin{ALC@if}} 139 | \newcommand{\FOR}[2][default]{\ALC@it\algorithmicfor\ ##2\ \algorithmicdo% 140 | \ALC@com{##1}\begin{ALC@for}} 141 | 
\newcommand{\FORALL}[2][default]{\ALC@it\algorithmicforall\ ##2\ % 142 | \algorithmicdo% 143 | \ALC@com{##1}\begin{ALC@for}} 144 | \newcommand{\SHORTFORALL}[2]{\ALC@it\algorithmicforall\ ##1\ % 145 | \algorithmicdo\ {##2}} 146 | \newcommand{\WHILE}[2][default]{\ALC@it\algorithmicwhile\ ##2\ % 147 | \algorithmicdo% 148 | \ALC@com{##1}\begin{ALC@whl}} 149 | \newcommand{\LOOP}[1][default]{\ALC@it\algorithmicloop% 150 | \ALC@com{##1}\begin{ALC@loop}} 151 | %changed by alex smola 152 | \newcommand{\FUNCTION}[2][default]{\ALC@it\algorithmicfunction\ ##2\ % 153 | \ALC@com{##1}\begin{ALC@func}} 154 | \newcommand{\MAIN}[2][default]{\ALC@it\algorithmicmain\ ##2\ % 155 | \ALC@com{##1}\begin{ALC@main}} 156 | %end changed by alex smola 157 | \newcommand{\REPEAT}[1][default]{\ALC@it\algorithmicrepeat% 158 | \ALC@com{##1}\begin{ALC@rpt}} 159 | \newcommand{\UNTIL}[1]{\end{ALC@rpt}\ALC@it\algorithmicuntil\ ##1} 160 | \ifthenelse{\boolean{ALC@noend}}{ 161 | \newcommand{\ENDIF}{\end{ALC@if}} 162 | \newcommand{\ENDFOR}{\end{ALC@for}} 163 | \newcommand{\ENDWHILE}{\end{ALC@whl}} 164 | \newcommand{\ENDLOOP}{\end{ALC@loop}} 165 | \newcommand{\ENDFUNCTION}{\end{ALC@func}} 166 | \newcommand{\ENDMAIN}{\end{ALC@main}} 167 | }{ 168 | \newcommand{\ENDIF}{\end{ALC@if}\ALC@it\algorithmicendif} 169 | \newcommand{\ENDFOR}{\end{ALC@for}\ALC@it\algorithmicendfor} 170 | \newcommand{\ENDWHILE}{\end{ALC@whl}\ALC@it\algorithmicendwhile} 171 | \newcommand{\ENDLOOP}{\end{ALC@loop}\ALC@it\algorithmicendloop} 172 | \newcommand{\ENDFUNCTION}{\end{ALC@func}\ALC@it\algorithmicendfunction} 173 | \newcommand{\ENDMAIN}{\end{ALC@main}\ALC@it\algorithmicendmain} 174 | } 175 | \renewcommand{\@toodeep}{} 176 | \begin{list}{\ALC@lno}{\setcounter{ALC@line}{0}\setcounter{ALC@rem}{0}% 177 | \itemsep\z@ \itemindent\z@ \listparindent\z@% 178 | \partopsep\z@ \parskip\z@ \parsep\z@% 179 | \labelsep 0.5em \topsep 0.2em% 180 | \ifthenelse{\equal{#1}{0}} 181 | {\labelwidth 0.5em } 182 | {\labelwidth 1.2em } 183 | \leftmargin\labelwidth \addtolength{\leftmargin}{\labelsep} 184 | \ALC@tlm\labelsep 185 | } 186 | } 187 | {\end{list}} 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | -------------------------------------------------------------------------------- /tests/test_hard_masks.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy 3 | from jax import random 4 | 5 | from neurallogic import hard_masks, harden, neural_logic_net 6 | from tests import utils 7 | 8 | 9 | def test_mask_to_true(): 10 | test_data = [ 11 | [[1.0, 1.0], 1.0], 12 | [[1.0, 0.0], 0.0], 13 | [[0.0, 0.0], 1.0], 14 | [[0.0, 1.0], 1.0], 15 | [[1.1, 1.0], 1.0], 16 | [[1.1, 0.0], 0.0], 17 | [[-0.1, 0.0], 1.0], 18 | [[-0.1, 1.0], 1.0], 19 | ] 20 | for input, expected in test_data: 21 | utils.check_consistency( 22 | hard_masks.soft_mask_to_true, 23 | hard_masks.hard_mask_to_true, 24 | expected, 25 | input[0], 26 | input[1], 27 | ) 28 | 29 | 30 | def test_mask_to_true_neuron(): 31 | test_data = [ 32 | [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], 33 | [[0.0, 0.0], [0.0, 0.0], [1.0, 1.0]], 34 | [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], 35 | [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]], 36 | [[0.0, 1.0], [0.0, 0.0], [1.0, 1.0]], 37 | [[0.0, 1.0], [1.0, 1.0], [0.0, 1.0]], 38 | ] 39 | for input, weights, expected in test_data: 40 | 41 | def soft(weights, input): 42 | return hard_masks.soft_mask_to_true_neuron(weights, input) 43 | 44 | def hard(weights, input): 45 | return 
hard_masks.hard_mask_to_true_neuron(weights, input) 46 | 47 | utils.check_consistency( 48 | soft, hard, expected, jax.numpy.array(weights), jax.numpy.array(input) 49 | ) 50 | 51 | 52 | def test_mask_to_true_layer(): 53 | test_data = [ 54 | [ 55 | [1.0, 0.0], 56 | [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.2]], 57 | [[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 0.8]], 58 | ], 59 | [ 60 | [1.0, 0.4], 61 | [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]], 62 | [[1.0, 0.4], [1.0, 0.4], [1.0, 1.0], [1.0, 1.0]], 63 | ], 64 | [ 65 | [0.0, 1.0], 66 | [[1.0, 1.0], [0.0, 0.8], [1.0, 0.0], [0.0, 0.0]], 67 | [[0.0, 1.0], [1.0, 1.0], [0.0, 1.0], [1.0, 1.0]], 68 | ], 69 | [ 70 | [0.0, 0.0], 71 | [[1.0, 0.01], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]], 72 | [[0.0, 0.99], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], 73 | ], 74 | ] 75 | for input, weights, expected in test_data: 76 | 77 | def soft(weights, input): 78 | return hard_masks.soft_mask_to_true_layer(weights, input) 79 | 80 | def hard(weights, input): 81 | return hard_masks.hard_mask_to_true_layer(weights, input) 82 | 83 | utils.check_consistency( 84 | soft, 85 | hard, 86 | jax.numpy.array(expected), 87 | jax.numpy.array(weights), 88 | jax.numpy.array(input), 89 | ) 90 | 91 | 92 | def test_mask_to_true_net(): 93 | def test_net(type, x): 94 | x = hard_masks.mask_to_true_layer(type)(4)(x) 95 | x = x.ravel() 96 | return x 97 | 98 | soft, hard, symbolic = neural_logic_net.net(test_net) 99 | weights = soft.init(random.PRNGKey(0), [0.0, 0.0]) 100 | hard_weights = harden.hard_weights(weights) 101 | 102 | test_data = [ 103 | [ 104 | [1.0, 1.0], 105 | [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 106 | ], 107 | [ 108 | [1.0, 0.0], 109 | [1.0, 0.17739451, 1.0, 0.77752244, 1.0, 0.11280203, 1.0, 0.43465567], 110 | ], 111 | [ 112 | [0.0, 1.0], 113 | [0.6201445, 1.0, 0.7178699, 1.0, 0.29197645, 1.0, 0.41213453, 1.0], 114 | ], 115 | [ 116 | [0.0, 0.0], 117 | [ 118 | 0.6201445, 119 | 0.17739451, 120 | 0.7178699, 121 | 0.77752244, 122 | 0.29197645, 123 | 0.11280203, 124 | 0.41213453, 125 | 0.43465567, 126 | ], 127 | ], 128 | ] 129 | for input, expected in test_data: 130 | # Check that the soft function performs as expected 131 | soft_output = soft.apply(weights, jax.numpy.array(input)) 132 | expected_output = jax.numpy.array(expected) 133 | assert jax.numpy.allclose(soft_output, expected_output) 134 | 135 | # Check that the hard function performs as expected 136 | hard_input = harden.harden(jax.numpy.array(input)) 137 | hard_expected = harden.harden(jax.numpy.array(expected)) 138 | hard_output = hard.apply(hard_weights, hard_input) 139 | assert jax.numpy.allclose(hard_output, hard_expected) 140 | 141 | # Check that the symbolic function performs as expected 142 | symbolic_output = symbolic.apply(hard_weights, hard_input) 143 | assert numpy.allclose(symbolic_output, hard_expected) 144 | 145 | 146 | def test_mask_to_false(): 147 | test_data = [ 148 | [[1.0, 1.0], 1.0], 149 | [[1.0, 0.0], 0.0], 150 | [[0.0, 0.0], 0.0], 151 | [[0.0, 1.0], 0.0], 152 | [[1.1, 1.0], 1.0], 153 | [[1.1, 0.0], 0.0], 154 | [[-0.1, 0.0], 0.0], 155 | [[-0.1, 1.0], 0.0], 156 | ] 157 | for input, expected in test_data: 158 | utils.check_consistency( 159 | hard_masks.soft_mask_to_false, 160 | hard_masks.hard_mask_to_false, 161 | expected, 162 | input[0], 163 | input[1], 164 | ) 165 | 166 | 167 | def test_mask_to_false_neuron(): 168 | test_data = [ 169 | [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], 170 | [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 171 | [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]], 172 | [[0.0, 1.0], [1.0, 0.0], [0.0, 
0.0]], 173 | [[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]], 174 | [[0.0, 1.0], [1.0, 1.0], [0.0, 1.0]], 175 | ] 176 | for input, weights, expected in test_data: 177 | 178 | def soft(weights, input): 179 | return hard_masks.soft_mask_to_false_neuron(weights, input) 180 | 181 | def hard(weights, input): 182 | return hard_masks.hard_mask_to_false_neuron(weights, input) 183 | 184 | utils.check_consistency( 185 | soft, hard, expected, jax.numpy.array(weights), jax.numpy.array(input) 186 | ) 187 | 188 | 189 | def test_mask_to_false_layer(): 190 | test_data = [ 191 | [ 192 | [1.0, 0.0], 193 | [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.2]], 194 | [[1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0]], 195 | ], 196 | [ 197 | [1.0, 0.4], 198 | [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]], 199 | [[1.0, 0.39999998], [0.0, 0.39999998], [1.0, 0.0], [0.0, 0.0]], 200 | ], 201 | [ 202 | [0.0, 1.0], 203 | [[1.0, 1.0], [0.0, 0.8], [1.0, 0.0], [0.0, 0.0]], 204 | [[0.0, 1.0], [0.0, 0.8], [0.0, 0.0], [0.0, 0.0]], 205 | ], 206 | [ 207 | [0.0, 0.0], 208 | [[1.0, 0.01], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]], 209 | [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 210 | ], 211 | ] 212 | for input, weights, expected in test_data: 213 | 214 | def soft(weights, input): 215 | return hard_masks.soft_mask_to_false_layer(weights, input) 216 | 217 | def hard(weights, input): 218 | return hard_masks.hard_mask_to_false_layer(weights, input) 219 | 220 | utils.check_consistency( 221 | soft, 222 | hard, 223 | jax.numpy.array(expected), 224 | jax.numpy.array(weights), 225 | jax.numpy.array(input), 226 | ) 227 | 228 | 229 | def test_mask_to_false_net(): 230 | def test_net(type, x): 231 | x = hard_masks.mask_to_false_layer(type)(4)(x) 232 | x = x.ravel() 233 | return x 234 | 235 | soft, hard, symbolic = neural_logic_net.net(test_net) 236 | weights = soft.init(random.PRNGKey(0), [0.0, 0.0]) 237 | hard_weights = harden.hard_weights(weights) 238 | 239 | test_data = [ 240 | [ 241 | [1.0, 1.0], 242 | [ 243 | 0.3798555, 244 | 0.8226055, 245 | 0.28213012, 246 | 0.22247756, 247 | 0.70802355, 248 | 0.887198, 249 | 0.5878655, 250 | 0.56534433, 251 | ], 252 | ], 253 | [ 254 | [1.0, 0.0], 255 | [0.3798555, 0.0, 0.28213012, 0.0, 0.70802355, 0.0, 0.5878655, 0.0], 256 | ], 257 | [ 258 | [0.0, 1.0], 259 | [0.0, 0.8226055, 0.0, 0.22247756, 0.0, 0.887198, 0.0, 0.56534433] 260 | ], 261 | [ 262 | [0.0, 0.0], 263 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 264 | ], 265 | ] 266 | for input, expected in test_data: 267 | # Check that the soft function performs as expected 268 | soft_output = soft.apply(weights, jax.numpy.array(input)) 269 | expected_output = jax.numpy.array(expected) 270 | assert jax.numpy.allclose(soft_output, expected_output) 271 | 272 | # Check that the hard function performs as expected 273 | hard_input = harden.harden(jax.numpy.array(input)) 274 | hard_expected = harden.harden(jax.numpy.array(expected)) 275 | hard_output = hard.apply(hard_weights, hard_input) 276 | assert jax.numpy.allclose(hard_output, hard_expected) 277 | 278 | # Check that the symbolic function performs as expected 279 | symbolic_output = symbolic.apply(hard_weights, hard_input) 280 | assert numpy.allclose(symbolic_output, hard_expected) 281 | -------------------------------------------------------------------------------- /docs/iclr2021_conference.sty: -------------------------------------------------------------------------------- 1 | %%%% ICLR Macros (LaTex) 2 | %%%% Adapted by Hugo Larochelle from the NIPS stylefile Macros 3 | %%%% Style File 4 | %%%% Dec 12, 1990 Rev 
Aug 14, 1991; Sept, 1995; April, 1997; April, 1999; October 2014 5 | 6 | % This file can be used with Latex2e whether running in main mode, or 7 | % 2.09 compatibility mode. 8 | % 9 | % If using main mode, you need to include the commands 10 | % \documentclass{article} 11 | % \usepackage{iclr14submit_e,times} 12 | % 13 | 14 | % Change the overall width of the page. If these parameters are 15 | % changed, they will require corresponding changes in the 16 | % maketitle section. 17 | % 18 | \usepackage{eso-pic} % used by \AddToShipoutPicture 19 | \RequirePackage{fancyhdr} 20 | \RequirePackage{natbib} 21 | 22 | % modification to natbib citations 23 | \setcitestyle{authoryear,round,citesep={;},aysep={,},yysep={;}} 24 | 25 | \renewcommand{\topfraction}{0.95} % let figure take up nearly whole page 26 | \renewcommand{\textfraction}{0.05} % let figure take up nearly whole page 27 | 28 | % Define iclrfinal, set to true if iclrfinalcopy is defined 29 | \newif\ificlrfinal 30 | \iclrfinalfalse 31 | \def\iclrfinalcopy{\iclrfinaltrue} 32 | \font\iclrtenhv = phvb at 8pt 33 | 34 | % Specify the dimensions of each page 35 | 36 | \setlength{\paperheight}{11in} 37 | \setlength{\paperwidth}{8.5in} 38 | 39 | 40 | \oddsidemargin .5in % Note \oddsidemargin = \evensidemargin 41 | \evensidemargin .5in 42 | \marginparwidth 0.07 true in 43 | %\marginparwidth 0.75 true in 44 | %\topmargin 0 true pt % Nominal distance from top of page to top of 45 | %\topmargin 0.125in 46 | \topmargin -0.625in 47 | \addtolength{\headsep}{0.25in} 48 | \textheight 9.0 true in % Height of text (including footnotes & figures) 49 | \textwidth 5.5 true in % Width of text line. 50 | \widowpenalty=10000 51 | \clubpenalty=10000 52 | 53 | % \thispagestyle{empty} \pagestyle{empty} 54 | \flushbottom \sloppy 55 | 56 | % We're never going to need a table of contents, so just flush it to 57 | % save space --- suggested by drstrip@sandia-2 58 | \def\addcontentsline#1#2#3{} 59 | 60 | % Title stuff, taken from deproc. 61 | \def\maketitle{\par 62 | \begingroup 63 | \def\thefootnote{\fnsymbol{footnote}} 64 | \def\@makefnmark{\hbox to 0pt{$^{\@thefnmark}$\hss}} % for perfect author 65 | % name centering 66 | % The footnote-mark was overlapping the footnote-text, 67 | % added the following to fix this problem (MK) 68 | \long\def\@makefntext##1{\parindent 1em\noindent 69 | \hbox to1.8em{\hss $\m@th ^{\@thefnmark}$}##1} 70 | \@maketitle \@thanks 71 | \endgroup 72 | \setcounter{footnote}{0} 73 | \let\maketitle\relax \let\@maketitle\relax 74 | \gdef\@thanks{}\gdef\@author{}\gdef\@title{}\let\thanks\relax} 75 | 76 | % The toptitlebar has been raised to top-justify the first page 77 | 78 | \usepackage{fancyhdr} 79 | \pagestyle{fancy} 80 | \fancyhead{} 81 | 82 | % Title (includes both anonimized and non-anonimized versions) 83 | \def\@maketitle{\vbox{\hsize\textwidth 84 | %\linewidth\hsize \vskip 0.1in \toptitlebar \centering 85 | {\LARGE\sc \@title\par} 86 | %\bottomtitlebar % \vskip 0.1in % minus 87 | \ificlrfinal 88 | %\lhead{Published as a conference paper at ICLR 2021} 89 | \lhead{April 2023. 
DRAFT 1.1.} 90 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 91 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 92 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 93 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 94 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\@author\end{tabular}% 95 | \else 96 | \lhead{Under review as a conference paper at ICLR 2021} 97 | \def\And{\end{tabular}\hfil\linebreak[0]\hfil 98 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 99 | \def\AND{\end{tabular}\hfil\linebreak[4]\hfil 100 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}\ignorespaces}% 101 | \begin{tabular}[t]{l}\bf\rule{\z@}{24pt}Anonymous authors\\Paper under double-blind review\end{tabular}% 102 | \fi 103 | \vskip 0.3in minus 0.1in}} 104 | 105 | \renewenvironment{abstract}{\vskip.075in\centerline{\large\sc 106 | Abstract}\vspace{0.5ex}\begin{quote}}{\par\end{quote}\vskip 1ex} 107 | 108 | % sections with less space 109 | \def\section{\@startsection {section}{1}{\z@}{-2.0ex plus 110 | -0.5ex minus -.2ex}{1.5ex plus 0.3ex 111 | minus0.2ex}{\large\sc\raggedright}} 112 | 113 | \def\subsection{\@startsection{subsection}{2}{\z@}{-1.8ex plus 114 | -0.5ex minus -.2ex}{0.8ex plus .2ex}{\normalsize\sc\raggedright}} 115 | \def\subsubsection{\@startsection{subsubsection}{3}{\z@}{-1.5ex 116 | plus -0.5ex minus -.2ex}{0.5ex plus 117 | .2ex}{\normalsize\sc\raggedright}} 118 | \def\paragraph{\@startsection{paragraph}{4}{\z@}{1.5ex plus 119 | 0.5ex minus .2ex}{-1em}{\normalsize\bf}} 120 | \def\subparagraph{\@startsection{subparagraph}{5}{\z@}{1.5ex plus 121 | 0.5ex minus .2ex}{-1em}{\normalsize\sc}} 122 | \def\subsubsubsection{\vskip 123 | 5pt{\noindent\normalsize\rm\raggedright}} 124 | 125 | 126 | % Footnotes 127 | \footnotesep 6.65pt % 128 | \skip\footins 9pt plus 4pt minus 2pt 129 | \def\footnoterule{\kern-3pt \hrule width 12pc \kern 2.6pt } 130 | \setcounter{footnote}{0} 131 | 132 | % Lists and paragraphs 133 | \parindent 0pt 134 | \topsep 4pt plus 1pt minus 2pt 135 | \partopsep 1pt plus 0.5pt minus 0.5pt 136 | \itemsep 2pt plus 1pt minus 0.5pt 137 | \parsep 2pt plus 1pt minus 0.5pt 138 | \parskip .5pc 139 | 140 | 141 | %\leftmargin2em 142 | \leftmargin3pc 143 | \leftmargini\leftmargin \leftmarginii 2em 144 | \leftmarginiii 1.5em \leftmarginiv 1.0em \leftmarginv .5em 145 | 146 | %\labelsep \labelsep 5pt 147 | 148 | \def\@listi{\leftmargin\leftmargini} 149 | \def\@listii{\leftmargin\leftmarginii 150 | \labelwidth\leftmarginii\advance\labelwidth-\labelsep 151 | \topsep 2pt plus 1pt minus 0.5pt 152 | \parsep 1pt plus 0.5pt minus 0.5pt 153 | \itemsep \parsep} 154 | \def\@listiii{\leftmargin\leftmarginiii 155 | \labelwidth\leftmarginiii\advance\labelwidth-\labelsep 156 | \topsep 1pt plus 0.5pt minus 0.5pt 157 | \parsep \z@ \partopsep 0.5pt plus 0pt minus 0.5pt 158 | \itemsep \topsep} 159 | \def\@listiv{\leftmargin\leftmarginiv 160 | \labelwidth\leftmarginiv\advance\labelwidth-\labelsep} 161 | \def\@listv{\leftmargin\leftmarginv 162 | \labelwidth\leftmarginv\advance\labelwidth-\labelsep} 163 | \def\@listvi{\leftmargin\leftmarginvi 164 | \labelwidth\leftmarginvi\advance\labelwidth-\labelsep} 165 | 166 | \abovedisplayskip 7pt plus2pt minus5pt% 167 | \belowdisplayskip \abovedisplayskip 168 | \abovedisplayshortskip 0pt plus3pt% 169 | \belowdisplayshortskip 4pt plus3pt minus3pt% 170 | 171 | % Less leading in most fonts (due to the narrow columns) 172 | % The choices were between 1-pt and 1.5-pt leading 173 | %\def\@normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} % got rid of @ (MK) 174 | 
\def\normalsize{\@setsize\normalsize{11pt}\xpt\@xpt} 175 | \def\small{\@setsize\small{10pt}\ixpt\@ixpt} 176 | \def\footnotesize{\@setsize\footnotesize{10pt}\ixpt\@ixpt} 177 | \def\scriptsize{\@setsize\scriptsize{8pt}\viipt\@viipt} 178 | \def\tiny{\@setsize\tiny{7pt}\vipt\@vipt} 179 | \def\large{\@setsize\large{14pt}\xiipt\@xiipt} 180 | \def\Large{\@setsize\Large{16pt}\xivpt\@xivpt} 181 | \def\LARGE{\@setsize\LARGE{20pt}\xviipt\@xviipt} 182 | \def\huge{\@setsize\huge{23pt}\xxpt\@xxpt} 183 | \def\Huge{\@setsize\Huge{28pt}\xxvpt\@xxvpt} 184 | 185 | \def\toptitlebar{\hrule height4pt\vskip .25in\vskip-\parskip} 186 | 187 | \def\bottomtitlebar{\vskip .29in\vskip-\parskip\hrule height1pt\vskip 188 | .09in} % 189 | %Reduced second vskip to compensate for adding the strut in \@author 190 | 191 | 192 | %% % Vertical Ruler 193 | %% % This code is, largely, from the CVPR 2010 conference style file 194 | %% % ----- define vruler 195 | %% \makeatletter 196 | %% \newbox\iclrrulerbox 197 | %% \newcount\iclrrulercount 198 | %% \newdimen\iclrruleroffset 199 | %% \newdimen\cv@lineheight 200 | %% \newdimen\cv@boxheight 201 | %% \newbox\cv@tmpbox 202 | %% \newcount\cv@refno 203 | %% \newcount\cv@tot 204 | %% % NUMBER with left flushed zeros \fillzeros[] 205 | %% \newcount\cv@tmpc@ \newcount\cv@tmpc 206 | %% \def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi 207 | %% \cv@tmpc=1 % 208 | %% \loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi 209 | %% \ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat 210 | %% \ifnum#2<0\advance\cv@tmpc1\relax-\fi 211 | %% \loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat 212 | %% \cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}% 213 | %% % \makevruler[][][][][] 214 | %% \def\makevruler[#1][#2][#3][#4][#5]{\begingroup\offinterlineskip 215 | %% \textheight=#5\vbadness=10000\vfuzz=120ex\overfullrule=0pt% 216 | %% \global\setbox\iclrrulerbox=\vbox to \textheight{% 217 | %% {\parskip=0pt\hfuzz=150em\cv@boxheight=\textheight 218 | %% \cv@lineheight=#1\global\iclrrulercount=#2% 219 | %% \cv@tot\cv@boxheight\divide\cv@tot\cv@lineheight\advance\cv@tot2% 220 | %% \cv@refno1\vskip-\cv@lineheight\vskip1ex% 221 | %% \loop\setbox\cv@tmpbox=\hbox to0cm{{\iclrtenhv\hfil\fillzeros[#4]\iclrrulercount}}% 222 | %% \ht\cv@tmpbox\cv@lineheight\dp\cv@tmpbox0pt\box\cv@tmpbox\break 223 | %% \advance\cv@refno1\global\advance\iclrrulercount#3\relax 224 | %% \ifnum\cv@refno<\cv@tot\repeat}}\endgroup}% 225 | %% \makeatother 226 | %% % ----- end of vruler 227 | 228 | %% % \makevruler[][][][][] 229 | %% \def\iclrruler#1{\makevruler[12pt][#1][1][3][0.993\textheight]\usebox{\iclrrulerbox}} 230 | %% \AddToShipoutPicture{% 231 | %% \ificlrfinal\else 232 | %% \iclrruleroffset=\textheight 233 | %% \advance\iclrruleroffset by -3.7pt 234 | %% \color[rgb]{.7,.7,.7} 235 | %% \AtTextUpperLeft{% 236 | %% \put(\LenToUnit{-35pt},\LenToUnit{-\iclrruleroffset}){%left ruler 237 | %% \iclrruler{\iclrrulercount}} 238 | %% } 239 | %% \fi 240 | %% } 241 | %%% To add a vertical bar on the side 242 | %\AddToShipoutPicture{ 243 | %\AtTextLowerLeft{ 244 | %\hspace*{-1.8cm} 245 | %\colorbox[rgb]{0.7,0.7,0.7}{\small \parbox[b][\textheight]{0.1cm}{}}} 246 | %} 247 | -------------------------------------------------------------------------------- /neurallogic/symbolic_generation.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing 
import Any, Mapping 3 | 4 | import jax 5 | import numpy 6 | from jax import core 7 | from jax._src.util import safe_map 8 | from plum import dispatch 9 | 10 | from neurallogic import symbolic_primitives, map_at_elements 11 | 12 | # Imports required for evaluating symbolic expressions with eval() 13 | import jax._src.lax_reference as lax_reference 14 | 15 | 16 | def symbolic_bind(prim, *args, **params): 17 | #print('\nprimitive: ', prim.name) 18 | #print('\targs:\n\t\t', args) 19 | #print('\tparams\n\t\t: ', params) 20 | symbolic_outvals = { 21 | 'broadcast_in_dim': symbolic_primitives.symbolic_broadcast_in_dim, 22 | 'reshape': symbolic_primitives.symbolic_reshape, 23 | 'transpose': symbolic_primitives.symbolic_transpose, 24 | 'convert_element_type': symbolic_primitives.symbolic_convert_element_type, 25 | 'eq': symbolic_primitives.symbolic_eq, 26 | 'ne': symbolic_primitives.symbolic_ne, 27 | 'le': symbolic_primitives.symbolic_le, 28 | 'lt': symbolic_primitives.symbolic_lt, 29 | 'ge': symbolic_primitives.symbolic_ge, 30 | 'gt': symbolic_primitives.symbolic_gt, 31 | 'add': symbolic_primitives.symbolic_add, 32 | 'sub': symbolic_primitives.symbolic_sub, 33 | 'mul': symbolic_primitives.symbolic_mul, 34 | 'div': symbolic_primitives.symbolic_div, 35 | 'tan': symbolic_primitives.symbolic_tan, 36 | 'max': symbolic_primitives.symbolic_max, 37 | 'min': symbolic_primitives.symbolic_min, 38 | 'abs': symbolic_primitives.symbolic_abs, 39 | 'round': symbolic_primitives.symbolic_round, 40 | 'floor': symbolic_primitives.symbolic_floor, 41 | 'ceil': symbolic_primitives.symbolic_ceil, 42 | 'and': symbolic_primitives.symbolic_and, 43 | 'or': symbolic_primitives.symbolic_or, 44 | 'xor': symbolic_primitives.symbolic_xor, 45 | 'not': symbolic_primitives.symbolic_not, 46 | 'reduce_and': symbolic_primitives.symbolic_reduce_and, 47 | 'reduce_or': symbolic_primitives.symbolic_reduce_or, 48 | 'reduce_xor': symbolic_primitives.symbolic_reduce_xor, 49 | 'reduce_sum': symbolic_primitives.symbolic_reduce_sum, 50 | 'select_n': symbolic_primitives.symbolic_select_n, 51 | }[prim.name](*args, **params) 52 | # print('\tresult:\n\t\t', symbolic_outvals) 53 | return symbolic_outvals 54 | 55 | 56 | def scope_put_variable(self, col: str, name: str, value: Any): 57 | variables = self._collection(col) 58 | 59 | def put(target, key, val): 60 | if key in target and isinstance(target[key], dict) and isinstance(val, Mapping): 61 | for k, v in val.items(): 62 | put(target[key], k, v) 63 | else: 64 | target[key] = val 65 | 66 | put(variables, name, value) 67 | 68 | 69 | def put_variable(self, col: str, name: str, value: Any): 70 | self.scope._variables = self.scope.variables().unfreeze() 71 | scope_put_variable(self.scope, col, name, value) 72 | 73 | 74 | # TODO: make this robust and general over multiple types of param names 75 | 76 | 77 | def convert_to_numeric_params(flax_layer, param_names: str): 78 | actual_weights = flax_layer.get_variable('params', param_names) 79 | # Convert actual weights to dummy numeric weights (if needed) 80 | if isinstance(actual_weights, list) or ( 81 | isinstance(actual_weights, numpy.ndarray) and actual_weights.dtype == object 82 | ): 83 | numeric_weights = map_at_elements.map_at_elements( 84 | actual_weights, lambda x: 0 85 | ) 86 | numeric_weights = numpy.asarray(numeric_weights, dtype=numpy.int32) 87 | put_variable(flax_layer, 'params', param_names, numeric_weights) 88 | return flax_layer, actual_weights 89 | 90 | 91 | def make_symbolic_flax_jaxpr(flax_layer, x): 92 | flax_layer, bit_weights = 
convert_to_numeric_params(flax_layer, 'bit_weights') 93 | flax_layer, thresholds = convert_to_numeric_params(flax_layer, 'thresholds') 94 | # Convert input to dummy numeric input (if needed) 95 | if isinstance(x, list) or (isinstance(x, numpy.ndarray) and x.dtype == object): 96 | x = map_at_elements.map_at_elements(x, lambda x: 0) 97 | x = numpy.asarray(x, dtype=numpy.int32) 98 | # Make the jaxpr that corresponds to the flax layer 99 | jaxpr = make_symbolic_jaxpr(flax_layer, x) 100 | if hasattr(jaxpr, '_consts'): 101 | # Make a list of bit_weights and thresholds but only include each if they are not None 102 | bit_weights_and_thresholds = [x for x in [bit_weights, thresholds] if x is not None] 103 | # Replace the dummy numeric weights with the actual weights in the jaxpr 104 | jaxpr.__setattr__('_consts', bit_weights_and_thresholds) 105 | return jaxpr 106 | 107 | 108 | 109 | def eval_jaxpr(symbolic, jaxpr, consts, *args): 110 | '''Evaluates a jaxpr by interpreting it as Python code. 111 | 112 | Parameters 113 | ---------- 114 | symbolic : bool 115 | Whether to return symbolic values or concrete values. If symbolic is 116 | True, returns symbolic values, and if symbolic is False, returns 117 | concrete values. 118 | jaxpr : Jaxpr 119 | The jaxpr to interpret. 120 | consts : tuple 121 | Constant values for the jaxpr. 122 | args : tuple 123 | Arguments for the jaxpr. 124 | 125 | Returns 126 | ------- 127 | out : tuple 128 | The result of evaluating the jaxpr. 129 | ''' 130 | 131 | # Mapping from variable -> value 132 | env = {} 133 | symbolic_env = {} 134 | 135 | # TODO: unify read and symbolic_read 136 | 137 | def read(var): 138 | # Literals are values baked into the Jaxpr 139 | if type(var) is core.Literal: 140 | return var.val 141 | return env[var] 142 | 143 | def symbolic_read(var): 144 | # Literals are values baked into the Jaxpr 145 | if type(var) is core.Literal: 146 | return var.val 147 | return symbolic_env[var] 148 | 149 | def write(var, val): 150 | env[var] = val 151 | 152 | def symbolic_write(var, val): 153 | symbolic_env[var] = val 154 | 155 | # Bind args and consts to environment 156 | if not symbolic: 157 | safe_map(write, jaxpr.invars, args) 158 | safe_map(write, jaxpr.constvars, consts) 159 | safe_map(symbolic_write, jaxpr.invars, args) 160 | safe_map(symbolic_write, jaxpr.constvars, consts) 161 | 162 | def eval_jaxpr_impl(jaxpr): 163 | # Loop through equations and evaluate primitives using `bind` 164 | for eqn in jaxpr.eqns: 165 | # Read inputs to equation from environment 166 | if not symbolic: 167 | invals = safe_map(read, eqn.invars) 168 | symbolic_invals = safe_map(symbolic_read, eqn.invars) 169 | prim = eqn.primitive 170 | if type(prim) is jax.core.CallPrimitive: 171 | call_jaxpr = eqn.params['call_jaxpr'] 172 | if not symbolic: 173 | safe_map(write, call_jaxpr.invars, map(read, eqn.invars)) 174 | try: 175 | safe_map( 176 | symbolic_write, 177 | call_jaxpr.invars, 178 | map(symbolic_read, eqn.invars), 179 | ) 180 | except: 181 | pass 182 | eval_jaxpr_impl(call_jaxpr) 183 | if not symbolic: 184 | safe_map(write, eqn.outvars, map(read, call_jaxpr.outvars)) 185 | safe_map( 186 | symbolic_write, eqn.outvars, map(symbolic_read, call_jaxpr.outvars) 187 | ) 188 | else: 189 | if not symbolic: 190 | outvals = prim.bind(*invals, **eqn.params) 191 | symbolic_outvals = symbolic_bind(prim, *symbolic_invals, **eqn.params) 192 | # Primitives may return multiple outputs or not 193 | if not prim.multiple_results: 194 | if not symbolic: 195 | outvals = [outvals] 196 | symbolic_outvals = 
[symbolic_outvals] 197 | if not symbolic: 198 | # Always check that the symbolic binding generates the same values as the 199 | # standard jax binding in order to detect bugs early. 200 | # print(f'outvals: {outvals} and symbolic_outvals: {symbolic_outvals}') 201 | assert numpy.allclose( 202 | numpy.array(outvals), symbolic_outvals, equal_nan=True 203 | ) 204 | # Write the results of the primitive into the environment 205 | if not symbolic: 206 | safe_map(write, eqn.outvars, outvals) 207 | safe_map(symbolic_write, eqn.outvars, symbolic_outvals) 208 | 209 | # Read the final result of the Jaxpr from the environment 210 | eval_jaxpr_impl(jaxpr) 211 | if not symbolic: 212 | return safe_map(read, jaxpr.outvars)[0] 213 | else: 214 | return safe_map(symbolic_read, jaxpr.outvars)[0] 215 | 216 | 217 | def make_symbolic_jaxpr(func: typing.Callable, *args): 218 | return jax.make_jaxpr(lambda *args: func(*args))(*args) 219 | 220 | # TODO: better name 221 | def eval_symbolic(jaxpr, *args): 222 | if hasattr(jaxpr, 'literals'): 223 | return eval_jaxpr( 224 | False, jaxpr.jaxpr, jaxpr.literals, *args 225 | ) 226 | return eval_jaxpr(False, jaxpr.jaxpr, [], *args) 227 | 228 | # TODO: better name 229 | def symbolic_expression(jaxpr, *args): 230 | if hasattr(jaxpr, 'literals'): 231 | sym_expr = eval_jaxpr(True, jaxpr.jaxpr, jaxpr.literals, *args) 232 | else: 233 | sym_expr = eval_jaxpr(True, jaxpr.jaxpr, [], *args) 234 | return sym_expr 235 | 236 | 237 | @dispatch 238 | def eval_symbolic_expression(x: str): 239 | # TODO: distinguish python code-gen from other possible code-gen 240 | eval_str = x.replace('inf', 'numpy.inf') 241 | return eval(eval_str) 242 | 243 | 244 | @dispatch 245 | def eval_symbolic_expression(x: numpy.ndarray): 246 | return numpy.vectorize(eval_symbolic_expression)(x) 247 | 248 | 249 | @dispatch 250 | def eval_symbolic_expression(x: list): 251 | return numpy.vectorize(eval)(x) 252 | 253 | -------------------------------------------------------------------------------- /neurallogic/symbolic_primitives.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | import jax._src.lax_reference as lax_reference 5 | import jax._src.lax.lax as lax 6 | import numpy 7 | 8 | from neurallogic import symbolic_operator, symbolic_representation 9 | 10 | 11 | def all_concrete_values(data): 12 | if isinstance(data, str): 13 | return False 14 | if isinstance(data, (list, tuple)): 15 | return all(all_concrete_values(x) for x in data) 16 | if isinstance(data, dict): 17 | return all(all_concrete_values(v) for v in data.values()) 18 | if isinstance(data, numpy.ndarray): 19 | return all_concrete_values(data.tolist()) 20 | if isinstance(data, jax.numpy.ndarray): 21 | return all_concrete_values(data.tolist()) 22 | return True 23 | 24 | 25 | def symbolic(concrete_function: Callable, symbolic_function: str, *args, **kwargs): 26 | if all_concrete_values([*args]): 27 | # We can directly evaluate the function 28 | return concrete_function(*args, **kwargs) 29 | else: 30 | # We need to return a symbolic representation 31 | return symbolic_operator.symbolic_operator(symbolic_function, *args, **kwargs) 32 | 33 | 34 | def symbolic_not(*args, **kwargs): 35 | return symbolic(numpy.logical_not, 'numpy.logical_not', *args, **kwargs) 36 | 37 | 38 | def symbolic_eq(*args, **kwargs): 39 | return symbolic(lax_reference.eq, 'lax_reference.eq', *args, **kwargs) 40 | 41 | 42 | def symbolic_ne(*args, **kwargs): 43 | return symbolic(lax_reference.ne, 
'lax_reference.ne', *args, **kwargs) 44 | 45 | 46 | def symbolic_le(*args, **kwargs): 47 | return symbolic(lax_reference.le, 'lax_reference.le', *args, **kwargs) 48 | 49 | 50 | def symbolic_lt(*args, **kwargs): 51 | return symbolic(lax_reference.lt, 'lax_reference.lt', *args, **kwargs) 52 | 53 | 54 | def symbolic_ge(*args, **kwargs): 55 | return symbolic(lax_reference.ge, 'lax_reference.ge', *args, **kwargs) 56 | 57 | 58 | def symbolic_gt(*args, **kwargs): 59 | return symbolic(lax_reference.gt, 'lax_reference.gt', *args, **kwargs) 60 | 61 | 62 | def symbolic_abs(*args, **kwargs): 63 | return symbolic(lax_reference.abs, 'lax_reference.abs', *args, **kwargs) 64 | 65 | 66 | def symbolic_floor(*args, **kwargs): 67 | return symbolic(lax_reference.floor, 'lax_reference.floor', *args, **kwargs) 68 | 69 | 70 | def symbolic_ceil(*args, **kwargs): 71 | return symbolic(lax_reference.ceil, 'lax_reference.ceil', *args, **kwargs) 72 | 73 | 74 | def symbolic_round(*args, **kwargs): 75 | # The reference implementation only supports away from zero 76 | if kwargs['rounding_method'] == lax.RoundingMethod.AWAY_FROM_ZERO: 77 | return symbolic(lax_reference.round, 'lax_reference.round', *args) 78 | elif kwargs['rounding_method'] == lax.RoundingMethod.TO_NEAREST_EVEN: 79 | return symbolic(numpy.around, 'numpy.around', *args) 80 | else: 81 | raise NotImplementedError( 82 | f'rounding_method {str(kwargs["rounding_method"])} not implemented') 83 | 84 | 85 | def symbolic_add(*args, **kwargs): 86 | return symbolic(lax_reference.add, 'lax_reference.add', *args, **kwargs) 87 | 88 | 89 | def symbolic_sub(*args, **kwargs): 90 | return symbolic(lax_reference.sub, 'lax_reference.sub', *args, **kwargs) 91 | 92 | 93 | def symbolic_mul(*args, **kwargs): 94 | return symbolic(lax_reference.mul, 'lax_reference.mul', *args, **kwargs) 95 | 96 | 97 | def symbolic_div(*args, **kwargs): 98 | return symbolic(lax_reference.div, 'lax_reference.div', *args, **kwargs) 99 | 100 | 101 | def symbolic_tan(*args, **kwargs): 102 | return symbolic(lax_reference.tan, 'lax_reference.tan', *args, **kwargs) 103 | 104 | 105 | def symbolic_max(*args, **kwargs): 106 | return symbolic(lax_reference.max, 'lax_reference.max', *args, **kwargs) 107 | 108 | 109 | def symbolic_min(*args, **kwargs): 110 | return symbolic(lax_reference.min, 'lax_reference.min', *args, **kwargs) 111 | 112 | 113 | def symbolic_and(*args, **kwargs): 114 | return symbolic(numpy.logical_and, 'numpy.logical_and', *args, **kwargs) 115 | 116 | 117 | def symbolic_or(*args, **kwargs): 118 | return symbolic(numpy.logical_or, 'numpy.logical_or', *args, **kwargs) 119 | 120 | 121 | def symbolic_xor(*args, **kwargs): 122 | return symbolic(numpy.logical_xor, 'numpy.logical_xor', *args, **kwargs) 123 | 124 | 125 | def symbolic_sum(*args, **kwargs): 126 | # N.B. 
We pass the tuple directly because we're summing over all args 127 | return symbolic(lax_reference.sum, 'lax_reference.sum', args, **kwargs) 128 | 129 | 130 | def symbolic_broadcast_in_dim(*args, **kwargs): 131 | # broadcast_in_dim requires numpy arrays not lists 132 | args = tuple([numpy.array(arg) if isinstance( 133 | arg, list) else arg for arg in args]) 134 | return lax_reference.broadcast_in_dim(*args, **kwargs) 135 | 136 | 137 | def symbolic_reshape(*args, **kwargs): 138 | return lax_reference.reshape(*args, **kwargs) 139 | 140 | 141 | def symbolic_transpose(*args, **kwargs): 142 | return lax_reference.transpose(*args, axes=kwargs['permutation']) 143 | 144 | 145 | def symbolic_convert_element_type(*args, **kwargs): 146 | # Check if all the boolean arguments are True or False 147 | if all_concrete_values([*args]): 148 | # If so, we can use the lax reference implementation 149 | return lax_reference.convert_element_type(*args, dtype=kwargs['new_dtype']) 150 | else: 151 | # Otherwise, we nop 152 | def convert_element_type(x, dtype): 153 | return x 154 | return convert_element_type(*args, dtype=kwargs['new_dtype']) 155 | 156 | 157 | def symbolic_select_n(*args, **kwargs): 158 | ''' 159 | Important comment from lax.py 160 | # Caution! The select_n_p primitive has the *opposite* order of arguments to 161 | # select(). This is because it implements `select_n`. 162 | ''' 163 | pred = args[0] 164 | on_true = args[1] 165 | on_false = args[2] 166 | if all_concrete_values([*args]): 167 | # swap order of on_true and on_false 168 | return lax_reference.select(pred, on_false, on_true) 169 | else: 170 | # TODO: to retain tensor structure we need to push down the select to the 171 | # lowest level of the symbolic expression tree. This is not currently 172 | # implemented. 173 | print('WARNING: symbolic_select_n is not fully implemented. This may not work as expected.') 174 | # swap order of on_true and on_false 175 | evaluable_pred = symbolic_representation.symbolic_representation(pred) 176 | evaluable_on_true = symbolic_representation.symbolic_representation( 177 | on_true) 178 | evaluable_on_false = symbolic_representation.symbolic_representation( 179 | on_false) 180 | return f'lax_reference.select({evaluable_pred}, {evaluable_on_false}, {evaluable_on_true})' 181 | 182 | 183 | def make_symbolic_reducer(py_binop, init_val): 184 | # This function is a hack to get around the fact that JAX doesn't 185 | # support symbolic reduction operations. It takes a symbolic reduction 186 | # operation and a symbolic initial value and returns a function that 187 | # performs the reduction operation on a numpy array. 188 | def reducer(operand, axis): 189 | # axis=None means we are reducing over all axes of the operand. 190 | axis = range(numpy.ndim(operand)) if axis is None else axis 191 | 192 | # We create a new array with the same shape as the operand, but with the 193 | # dimensions corresponding to the axis argument removed. The values in this 194 | # array will be the result of the reduction. 195 | result = numpy.full( 196 | numpy.delete(numpy.shape(operand), axis), 197 | init_val, 198 | dtype=numpy.asarray(operand).dtype, 199 | ) 200 | 201 | # We iterate over all elements of the operand, computing the reduction. 202 | for idx, _ in numpy.ndenumerate(operand): 203 | # We need to index into the result array with the same indices that we used 204 | # to index into the operand, but with the axis dimensions removed. 
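# For illustration, a rough sketch of what this loop computes when the operand is
# concrete (assumed values, not taken from the original tests):
#
#   operand = numpy.array([[True, False], [True, True]])
#   symbolic_reduce(operand, False, numpy.logical_or, (1,))
#   # -> array([ True,  True])   i.e. an OR-reduction over axis 1
#
# When the operand instead holds expression strings (e.g. 'x0', 'x1') and the
# computation is symbolic_or, the same loop accumulates nested
# 'numpy.logical_or(...)' strings, which is how the symbolic form of a
# reduce_or primitive gets built up.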
205 | out_idx = tuple(numpy.delete(idx, axis)) 206 | result[out_idx] = py_binop(result[out_idx], operand[idx]) 207 | return result 208 | 209 | return reducer 210 | 211 | 212 | def symbolic_reduce(operand, init_value, computation, dimensions): 213 | reducer = make_symbolic_reducer(computation, init_value) 214 | return reducer(operand, tuple(dimensions)).astype(operand.dtype) 215 | 216 | 217 | def symbolic_reduce_and(*args, **kwargs): 218 | if all_concrete_values([*args]): 219 | return lax_reference.reduce( 220 | *args, 221 | init_value=True, 222 | computation=numpy.logical_and, 223 | dimensions=kwargs['axes'], 224 | ) 225 | else: 226 | return symbolic_reduce( 227 | *args, 228 | init_value='True', 229 | computation=symbolic_and, 230 | dimensions=kwargs['axes'], 231 | ) 232 | 233 | 234 | def symbolic_reduce_or(*args, **kwargs): 235 | if all_concrete_values([*args]): 236 | return lax_reference.reduce( 237 | *args, 238 | init_value=False, 239 | computation=numpy.logical_or, 240 | dimensions=kwargs['axes'], 241 | ) 242 | else: 243 | return symbolic_reduce( 244 | *args, 245 | init_value='False', 246 | computation=symbolic_or, 247 | dimensions=kwargs['axes'], 248 | ) 249 | 250 | 251 | def symbolic_reduce_xor(*args, **kwargs): 252 | if all_concrete_values([*args]): 253 | return lax_reference.reduce( 254 | *args, 255 | init_value=False, 256 | computation=numpy.logical_xor, 257 | dimensions=kwargs['axes'], 258 | ) 259 | else: 260 | return symbolic_reduce( 261 | *args, 262 | init_value='False', 263 | computation=symbolic_xor, 264 | dimensions=kwargs['axes'], 265 | ) 266 | 267 | 268 | def symbolic_reduce_sum(*args, **kwargs): 269 | if all_concrete_values([*args]): 270 | return lax_reference.reduce( 271 | *args, init_value=0, computation=numpy.add, dimensions=kwargs['axes'] 272 | ) 273 | else: 274 | return symbolic_reduce( 275 | *args, init_value='0', computation=symbolic_sum, dimensions=kwargs['axes'] 276 | ) 277 | -------------------------------------------------------------------------------- /tests/test_hard_majority.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import jax 3 | 4 | from neurallogic import hard_majority, harden, symbolic_generation 5 | from tests import utils 6 | 7 | 8 | def test_majority_index(): 9 | assert hard_majority.majority_index(1) == 0 10 | assert hard_majority.majority_index(2) == 0 11 | assert hard_majority.majority_index(3) == 1 12 | assert hard_majority.majority_index(4) == 1 13 | assert hard_majority.majority_index(5) == 2 14 | assert hard_majority.majority_index(6) == 2 15 | assert hard_majority.majority_index(7) == 3 16 | assert hard_majority.majority_index(8) == 3 17 | assert hard_majority.majority_index(9) == 4 18 | assert hard_majority.majority_index(10) == 4 19 | assert hard_majority.majority_index(11) == 5 20 | assert hard_majority.majority_index(12) == 5 21 | 22 | 23 | def test_majority_bit(): 24 | assert hard_majority.majority_bit(numpy.array([1.0])) == 1.0 25 | assert hard_majority.majority_bit(numpy.array([2.0, 1.0])) == 1.0 26 | assert hard_majority.majority_bit(numpy.array([1.0, 3.0, 2.0])) == 2.0 27 | assert hard_majority.majority_bit(numpy.array([2.0, 1.0, 4.0, 3.0])) == 2.0 28 | assert hard_majority.majority_bit(numpy.array([1.0, 2.0, 3.0, 4.0, 5.0])) == 3.0 29 | assert ( 30 | hard_majority.majority_bit(numpy.array([6.0, 3.0, 2.0, 4.0, 5.0, 1.0])) == 3.0 31 | ) 32 | assert ( 33 | hard_majority.majority_bit(numpy.array([7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0])) 34 | == 4.0 35 | ) 36 | assert ( 37 | 
hard_majority.majority_bit( 38 | numpy.array([2.0, 1.0, 4.0, 3.0, 6.0, 5.0, 8.0, 7.0]) 39 | ) 40 | == 4.0 41 | ) 42 | assert ( 43 | hard_majority.majority_bit( 44 | numpy.array([1.0, 2.0, 3.0, 5.0, 4.0, 6.0, 7.0, 9.0, 8.0]) 45 | ) 46 | == 5.0 47 | ) 48 | assert ( 49 | hard_majority.majority_bit( 50 | numpy.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]) 51 | ) 52 | == 5.0 53 | ) 54 | assert ( 55 | hard_majority.majority_bit( 56 | numpy.array([11.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]) 57 | ) 58 | == 6.0 59 | ) 60 | assert ( 61 | hard_majority.majority_bit( 62 | numpy.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]) 63 | ) 64 | == 6.0 65 | ) 66 | 67 | 68 | def test_hard_majority(): 69 | assert hard_majority.hard_majority(numpy.array([True])) == True 70 | assert hard_majority.hard_majority(numpy.array([False])) == False 71 | assert hard_majority.hard_majority(numpy.array([True, False])) == False 72 | assert hard_majority.hard_majority(numpy.array([False, True, False])) == False 73 | assert hard_majority.hard_majority(numpy.array([True, False, True, False])) == False 74 | assert ( 75 | hard_majority.hard_majority(numpy.array([False, True, False, True, False])) 76 | == False 77 | ) 78 | assert ( 79 | hard_majority.hard_majority(numpy.array([True, True, True, False, True, False])) 80 | == True 81 | ) 82 | assert ( 83 | hard_majority.hard_majority( 84 | numpy.array([True, False, False, True, True, True, False]) 85 | ) 86 | == True 87 | ) 88 | assert ( 89 | hard_majority.hard_majority( 90 | numpy.array([False, True, False, True, False, True, False, True]) 91 | ) 92 | == False 93 | ) 94 | assert ( 95 | hard_majority.hard_majority( 96 | numpy.array([True, True, True, True, True, False, True, True, True]) 97 | ) 98 | == True 99 | ) 100 | assert ( 101 | hard_majority.hard_majority( 102 | numpy.array( 103 | [True, False, False, False, False, False, True, True, True, True] 104 | ) 105 | ) 106 | == False 107 | ) 108 | 109 | 110 | def test_soft_and_hard_majority_equivalence(): 111 | soft_maj = jax.jit(hard_majority.soft_majority) 112 | hard_maj = jax.jit(hard_majority.hard_majority) 113 | for i in range(1, 50): 114 | input = numpy.random.rand(i) 115 | soft_output = soft_maj(input) 116 | hard_output = hard_maj(harden.harden(input)) 117 | assert harden.harden(soft_output) == hard_output 118 | 119 | 120 | def test_soft_majority_layer(): 121 | assert numpy.allclose( 122 | hard_majority.soft_majority_layer(numpy.array([[0.0, 1.0], [1.0, 0.0]])), 123 | numpy.array([0.25, 0.25]), 124 | ) 125 | assert numpy.allclose( 126 | hard_majority.soft_majority_layer( 127 | numpy.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0]]) 128 | ), 129 | numpy.array([0.8333334, 0.16666667]), 130 | ) 131 | assert numpy.allclose( 132 | hard_majority.soft_majority_layer( 133 | numpy.array([[1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 1.0, 1.0]]) 134 | ), 135 | numpy.array([0.25, 0.875]), 136 | ) 137 | assert numpy.allclose( 138 | hard_majority.soft_majority_layer( 139 | numpy.array([[0.0, 1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 1.0, 1.0]]) 140 | ), 141 | numpy.array([0.2, 0.9]), 142 | ) 143 | assert numpy.allclose( 144 | hard_majority.soft_majority_layer( 145 | numpy.array( 146 | [[0.0, 1.0, 0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0, 0.0]] 147 | ) 148 | ), 149 | numpy.array([0.25, 0.9166667]), 150 | ) 151 | assert numpy.allclose( 152 | hard_majority.soft_majority_layer( 153 | numpy.array( 154 | [ 155 | [1.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.4], 156 | [1.0, 0.9, 0.8, 0.45, 0.48, 0.51, 0.52], 157 | ] 158 | ) 159 | 
), 160 | numpy.array([0.15714286, 0.51331425]), 161 | ) 162 | 163 | 164 | def test_hard_majority_layer(): 165 | assert numpy.all( 166 | hard_majority.hard_majority_layer(numpy.array([[True, False], [False, True]])) 167 | == numpy.array([False, False]) 168 | ) 169 | assert numpy.all( 170 | hard_majority.hard_majority_layer( 171 | numpy.array([[True, False, True], [True, False, False]]) 172 | ) 173 | == numpy.array([True, False]) 174 | ) 175 | assert numpy.all( 176 | hard_majority.hard_majority_layer( 177 | numpy.array([[True, False, True, False], [False, True, False, True]]) 178 | ) 179 | == numpy.array([False, False]) 180 | ) 181 | assert numpy.all( 182 | hard_majority.hard_majority_layer( 183 | numpy.array( 184 | [[True, False, True, False, True], [True, False, True, False, True]] 185 | ) 186 | ) 187 | == numpy.array([True, True]) 188 | ) 189 | assert numpy.all( 190 | hard_majority.hard_majority_layer( 191 | numpy.array( 192 | [ 193 | [True, False, True, False, True, False], 194 | [False, True, False, True, False, True], 195 | ] 196 | ) 197 | ) 198 | == numpy.array([False, False]) 199 | ) 200 | assert numpy.all( 201 | hard_majority.hard_majority_layer( 202 | numpy.array( 203 | [ 204 | [True, False, True, False, True, False, True], 205 | [True, False, True, False, True, False, False], 206 | ] 207 | ) 208 | ) 209 | == numpy.array([True, False]) 210 | ) 211 | 212 | assert numpy.all( 213 | hard_majority.hard_majority_layer( 214 | numpy.array([[True, False], [False, True], [False, True]]) 215 | ) 216 | == numpy.array([False, False, False]) 217 | ) 218 | assert numpy.all( 219 | hard_majority.hard_majority_layer( 220 | numpy.array([[True, False, True], [True, False, True], [True, False, True]]) 221 | ) 222 | == numpy.array([True, True, True]) 223 | ) 224 | assert numpy.all( 225 | hard_majority.hard_majority_layer( 226 | numpy.array( 227 | [ 228 | [True, False, True, False], 229 | [False, True, False, True], 230 | [False, True, False, True], 231 | ] 232 | ) 233 | ) 234 | == numpy.array([False, False, False]) 235 | ) 236 | assert numpy.all( 237 | hard_majority.hard_majority_layer( 238 | numpy.array( 239 | [ 240 | [True, False, True, False, True], 241 | [True, False, True, False, True], 242 | [True, False, True, False, True], 243 | ] 244 | ) 245 | ) 246 | == numpy.array([True, True, True]) 247 | ) 248 | assert numpy.all( 249 | hard_majority.hard_majority_layer( 250 | numpy.array( 251 | [ 252 | [True, False, True, False, True, False], 253 | [False, True, False, True, False, True], 254 | [False, True, False, True, False, True], 255 | ] 256 | ) 257 | ) 258 | == numpy.array([False, False, False]) 259 | ) 260 | 261 | 262 | def test_layer(): 263 | test_data = [ 264 | [[[0.8, 0.1, 0.4], [1.0, 0.0, 0.3]], [0.44333333, 0.3866667]], 265 | [ 266 | [[0.8, 0.1, 0.4], [1.0, 0.0, 0.3], [0.0, 0.0, 0.0]], 267 | [0.44333333, 0.3866667, 0.0], 268 | ], 269 | [ 270 | [[0.8, 0.1, 0.4], [1.0, 0.0, 0.3], [0.8, 0.9, 0.1], [0.2, 0.01, 0.45]], 271 | [0.44333333, 0.3866667, 0.68, 0.266], 272 | ], 273 | [ 274 | [ 275 | [0.8, 0.1, 0.4], 276 | [1.0, 0.0, 0.3], 277 | [0.8, 0.9, 0.1], 278 | [0.2, 0.01, 0.45], 279 | [0.0, 0.0, 0.0], 280 | ], 281 | [0.44333333, 0.3866667, 0.68, 0.266, 0.0], 282 | ], 283 | [ 284 | [ 285 | [0.3, 0.93, 0.01, 0.5], 286 | [0.2, 0.01, 0.45, 0.1], 287 | [0.8, 0.9, 0.1, 0.2], 288 | [0.8, 0.1, 0.4, 0.3], 289 | [0.0, 0.0, 0.0, 0.0], 290 | ], 291 | [0.38700002, 0.176, 0.35000002, 0.38, 0.0], 292 | ], 293 | ] 294 | 295 | for input, expected in test_data: 296 | 297 | def soft(input): 298 | return 
hard_majority.soft_majority_layer(input) 299 | 300 | def hard(input): 301 | return hard_majority.hard_majority_layer(input) 302 | 303 | utils.check_consistency( 304 | soft, hard, jax.numpy.array(expected), jax.numpy.array(input) 305 | ) 306 | 307 | 308 | # TODO: test training the hard majority layer 309 | -------------------------------------------------------------------------------- /tests/test_toy_problem.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import jax 4 | import ml_collections 5 | import numpy 6 | import pytest 7 | import optax 8 | import scipy 9 | from flax.training import train_state 10 | from jax.config import config 11 | from tqdm import tqdm 12 | 13 | from neurallogic import ( 14 | hard_and, 15 | hard_dropout, 16 | hard_majority, 17 | hard_masks, 18 | hard_not, 19 | hard_or, 20 | hard_xor, 21 | hard_count, 22 | harden, 23 | harden_layer, 24 | neural_logic_net, 25 | real_encoder, 26 | initialization, 27 | hard_vmap, 28 | hard_concatenate, 29 | symbolic_primitives 30 | ) 31 | from tests import utils 32 | 33 | config.update("jax_enable_x64", True) 34 | 35 | """ 36 | Temperature: 4 booleans, 1-hot vector 37 | 0 high = very cold 38 | 1 high = cold 39 | 2 high = warm 40 | 3 high = very warm 41 | Outside?: 1 boolean 42 | 0 = no 43 | 1 = yes 44 | Labels: 45 | 0 = wear t-shirt 46 | 1 = wear coat 47 | """ 48 | 49 | toy_data = 2 50 | num_classes = 2 51 | if toy_data == 1: 52 | num_features = 1 53 | else: 54 | num_features = 5 55 | 56 | 57 | def check_symbolic(nets, data, trained_state, dropout_rng): 58 | x_training, y_training, x_test, y_test = data 59 | _, hard, symbolic = nets 60 | _, test_loss, test_accuracy = apply_model_with_grad(trained_state, x_test, y_test, dropout_rng) 61 | print( 62 | "soft_net: final test_loss: %.4f, final test_accuracy: %.2f" 63 | % (test_loss, test_accuracy * 100) 64 | ) 65 | hard_weights = harden.hard_weights(trained_state.params) 66 | hard_trained_state = TrainState.create( 67 | apply_fn=hard.apply, 68 | params=hard_weights, 69 | tx=optax.sgd(1.0, 1.0), 70 | dropout_rng=dropout_rng, 71 | ) 72 | hard_input = harden.harden(x_test) 73 | hard_test_accuracy = apply_hard_model_to_data(hard_trained_state, hard_input, y_test) 74 | print("hard_net: final test_accuracy: %.2f" % (hard_test_accuracy * 100)) 75 | assert numpy.isclose(test_accuracy, hard_test_accuracy, atol=0.0001) 76 | 77 | # CPU and GPU give different results, so we can't easily regress on a static symbolic expression 78 | if toy_data == 1: 79 | symbolic_input = ["outside"] 80 | else: 81 | symbolic_input = ["very-cold", "cold", "warm", "very-warm", "outside"] 82 | # This simply checks that the symbolic output can be generated 83 | symbolic_output = symbolic.apply({"params": hard_weights}, symbolic_input, training=False) 84 | print("symbolic_output: class 1", symbolic_output[0][:10000]) 85 | print("symbolic_output: class 2", symbolic_output[1][:10000]) 86 | 87 | 88 | def nln_1(type, x, training: bool): 89 | dtype = jax.numpy.float32 90 | layer_size = 2 91 | x = hard_not.not_layer(type)(layer_size)(x) 92 | x = x.ravel() 93 | x = x.reshape((num_classes, int(x.shape[0] / num_classes))) 94 | x = hard_majority.majority_layer(type)()(x) 95 | ######################################################## 96 | x = harden_layer.harden_layer(type)(x) 97 | x = x.reshape((num_classes, int(x.shape[0] / num_classes))) 98 | x = x.sum(-1) 99 | return x 100 | 101 | """ 102 | Class 1: 103 | lax_reference.ge(lax_reference.sum((0, 
numpy.logical_not(numpy.logical_xor(lax_reference.ne(x0, 0), False)))), 1) 104 | 105 | is equivalent to: 106 | 107 | sum( 108 | ( 109 | 0, 110 | ! xor(x0 != 0, False) 111 | ) 112 | ) >= 1 113 | 114 | is equivalent to: 115 | 116 | ! xor(x0 != 0, False) >= 1 117 | 118 | is equivalent to: 119 | 120 | ! x0 121 | 122 | Class 2: 123 | class 2 lax_reference.ge(lax_reference.sum((0, numpy.logical_not(numpy.logical_xor(lax_reference.ne(x0, 0), True)))), 1) 124 | 125 | is equivalent to: 126 | 127 | ! xor(x0 != 0, True) >= 1 128 | 129 | is equivalent to: 130 | 131 | x0 132 | 133 | Therefore learned class prediction is [!x, x] 134 | """ 135 | 136 | def nln_2(type, x, training: bool): 137 | dtype = jax.numpy.float32 138 | x = hard_not.not_layer(type)(8)(x) 139 | x = x.ravel() 140 | x = x.reshape((num_classes, int(x.shape[0] / num_classes))) 141 | x = hard_majority.majority_layer(type)()(x) 142 | ######################################################## 143 | x = harden_layer.harden_layer(type)(x) 144 | x = x.reshape((num_classes, int(x.shape[0] / num_classes))) 145 | x = x.sum(-1) 146 | return x 147 | 148 | def nln(type, x, training: bool): 149 | if toy_data == 1: 150 | return nln_1(type, x, training) 151 | else: 152 | return nln_2(type, x, training) 153 | 154 | def batch_nln(type, x, training: bool): 155 | return jax.vmap(lambda x: nln(type, x, training))(x) 156 | 157 | 158 | class TrainState(train_state.TrainState): 159 | dropout_rng: jax.random.KeyArray 160 | 161 | 162 | def create_train_state(net, rng, dropout_rng, config): 163 | mock_input = jax.numpy.ones([1, num_features]) 164 | soft_weights = net.init(rng, mock_input, training=False)["params"] 165 | tx = optax.radam(learning_rate=config.learning_rate) 166 | return TrainState.create( 167 | apply_fn=net.apply, params=soft_weights, tx=tx, dropout_rng=dropout_rng 168 | ) 169 | 170 | 171 | @jax.jit 172 | def update_model(state, grads): 173 | return state.apply_gradients(grads=grads) 174 | 175 | def apply_model_with_grad_impl(state, features, labels, dropout_rng, training: bool): 176 | dropout_train_rng = jax.random.fold_in(key=dropout_rng, data=state.step) 177 | 178 | def loss_fn(params): 179 | logits = state.apply_fn( 180 | {"params": params}, 181 | features, 182 | training=training, 183 | rngs={"dropout": dropout_train_rng}, 184 | ) 185 | one_hot = jax.nn.one_hot(labels, num_classes) 186 | loss = jax.numpy.mean( 187 | optax.softmax_cross_entropy(logits=logits, labels=one_hot) 188 | ) 189 | return loss, logits 190 | 191 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 192 | (loss, logits), grads = grad_fn(state.params) 193 | accuracy = jax.numpy.mean(jax.numpy.argmax(logits, -1) == labels) 194 | return grads, loss, accuracy 195 | 196 | 197 | @jax.jit 198 | def apply_model_with_grad_and_training(state, features, labels, dropout_rng): 199 | return apply_model_with_grad_impl( 200 | state, features, labels, dropout_rng, training=True 201 | ) 202 | 203 | 204 | @jax.jit 205 | def apply_model_with_grad(state, features, labels, dropout_rng): 206 | return apply_model_with_grad_impl( 207 | state, features, labels, dropout_rng, training=False 208 | ) 209 | 210 | 211 | def train_epoch(state, features, labels, batch_size, rng, dropout_rng): 212 | train_ds_size = len(features) 213 | steps_per_epoch = train_ds_size // batch_size 214 | 215 | perms = jax.random.permutation(rng, len(features)) 216 | perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch 217 | perms = perms.reshape((steps_per_epoch, batch_size)) 218 | 219 | epoch_loss = [] 220 | 
epoch_accuracy = [] 221 | 222 | for perm in perms: 223 | batch_features = features[perm, ...] 224 | batch_labels = labels[perm, ...] 225 | grads, loss, accuracy = apply_model_with_grad_and_training( 226 | state, batch_features, batch_labels, dropout_rng 227 | ) 228 | state = update_model(state, grads) 229 | epoch_loss.append(loss) 230 | epoch_accuracy.append(accuracy) 231 | train_loss = numpy.mean(epoch_loss) 232 | train_accuracy = numpy.mean(epoch_accuracy) 233 | return state, train_loss, train_accuracy 234 | 235 | 236 | def train_and_evaluate( 237 | init_rng, dropout_rng, net, data, config: ml_collections.ConfigDict 238 | ): 239 | state = create_train_state(net, init_rng, dropout_rng, config) 240 | x_training, y_training, x_test, y_test = data 241 | best_train_accuracy = 0.0 242 | best_test_accuracy = 0.0 243 | for epoch in range(1, config.num_epochs + 1): 244 | init_rng, input_rng = jax.random.split(init_rng) 245 | state, train_loss, train_accuracy = train_epoch( 246 | state, x_training, y_training, config.batch_size, input_rng, dropout_rng 247 | ) 248 | _, test_loss, test_accuracy = apply_model_with_grad( 249 | state, x_test, y_test, dropout_rng 250 | ) 251 | if train_accuracy > best_train_accuracy: 252 | best_train_accuracy = train_accuracy 253 | if test_accuracy >= best_test_accuracy: 254 | best_test_accuracy = test_accuracy 255 | print( 256 | "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f" 257 | % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100) 258 | ) 259 | if train_accuracy == 1.0 and test_accuracy == 1.0: 260 | break 261 | 262 | # return trained state and final test_accuracy 263 | return state, test_accuracy 264 | 265 | 266 | def apply_hard_model(state, features, label): 267 | def logits_fn(params): 268 | return state.apply_fn({"params": params}, features, training=False) 269 | 270 | logits = logits_fn(state.params) 271 | if isinstance(logits, list): 272 | logits = jax.numpy.array(logits) 273 | logits *= 1.0 274 | accuracy = jax.numpy.mean(jax.numpy.argmax(logits, -1) == label) 275 | return accuracy 276 | 277 | 278 | def apply_hard_model_to_data(state, features, labels): 279 | accuracy = 0 280 | for image, label in tqdm(zip(features, labels), total=len(features)): 281 | accuracy += apply_hard_model(state, image, label) 282 | return accuracy / len(features) 283 | 284 | 285 | def get_config(): 286 | config = ml_collections.ConfigDict() 287 | config.learning_rate = 0.01 288 | config.momentum = 0.9 289 | if toy_data == 1: 290 | config.batch_size = 16 291 | else: 292 | config.batch_size = 48 293 | config.num_epochs = 1000 294 | return config 295 | 296 | def get_toy_data(): 297 | data_dir = Path(__file__).parent.parent / "tests" / "data" 298 | if toy_data == 1: 299 | data = numpy.loadtxt(data_dir / "toy_data_1.txt").astype(dtype=numpy.int32) 300 | else: 301 | data = numpy.loadtxt(data_dir / "toy_data_2.txt").astype(dtype=numpy.int32) 302 | features = data[:, 0:num_features] # Input features 303 | labels = data[:, num_features] # Target value 304 | return features, labels 305 | 306 | 307 | def train_test_split(features, labels, rng, test_size=0.2): 308 | rng, split_rng = jax.random.split(rng) 309 | train_size = int(len(features) * (1 - test_size)) 310 | train_idx = jax.random.permutation(split_rng, len(features))[:train_size] 311 | test_idx = jax.random.permutation(split_rng, len(features))[train_size:] 312 | return ( 313 | features[train_idx], 314 | features[test_idx], 315 | labels[train_idx], 316 | labels[test_idx], 
317 | ) 318 | 319 | @pytest.mark.skip(reason="temporarily off") 320 | def test_toy(): 321 | # Train net 322 | features, labels = get_toy_data() 323 | soft, hard, symbolic = neural_logic_net.net( 324 | lambda type, x, training: batch_nln(type, x, training) 325 | ) 326 | 327 | rng = jax.random.PRNGKey(0) 328 | print(soft.tabulate(rng, features[0:1], training=False)) 329 | 330 | num_experiments = 1 331 | final_test_accuracies = [] 332 | for i in range(num_experiments): 333 | # Split features and labels into 80% training and 20% test 334 | rng, int_rng, dropout_rng = jax.random.split(rng, 3) 335 | x_training, x_test, y_training, y_test = train_test_split( 336 | features, labels, rng, test_size=0.2 337 | ) 338 | trained_state, final_test_accuracy = train_and_evaluate( 339 | int_rng, 340 | dropout_rng, 341 | soft, 342 | (x_training, y_training, x_test, y_test), 343 | get_config(), 344 | ) 345 | final_test_accuracies.append(final_test_accuracy) 346 | print(f"{i}: final test accuracy: {final_test_accuracy * 100:.2f}") 347 | # print mean, standard error of the mean, min, max, lowest 5%, highest 5% of final test accuracies 348 | print( 349 | f"mean: {numpy.mean(final_test_accuracies) * 100:.2f}, " 350 | f"sem: {scipy.stats.sem(final_test_accuracies) * 100:.2f}, " 351 | f"min: {numpy.min(final_test_accuracies) * 100:.2f}, " 352 | f"max: {numpy.max(final_test_accuracies) * 100:.2f}, " 353 | f"5%: {numpy.percentile(final_test_accuracies, 5) * 100:.2f}, " 354 | f"95%: {numpy.percentile(final_test_accuracies, 95) * 100:.2f}" 355 | ) 356 | 357 | # Check symbolic net 358 | _, hard, symbolic = neural_logic_net.net(lambda type, x, training: nln(type, x, training)) 359 | check_symbolic((soft, hard, symbolic), (x_training, y_training, x_test, y_test), trained_state, dropout_rng) 360 | -------------------------------------------------------------------------------- /tests/test_mnist.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import ml_collections 4 | import numpy as np 5 | import optax 6 | import pytest 7 | import tensorflow as tf 8 | import tensorflow_datasets as tfds 9 | from flax import linen as nn 10 | from flax.training import train_state 11 | from jax.config import config 12 | from tqdm import tqdm 13 | 14 | from neurallogic import ( 15 | hard_and, 16 | hard_majority, 17 | hard_not, 18 | hard_or, 19 | hard_xor, 20 | hard_masks, 21 | harden, 22 | harden_layer, 23 | neural_logic_net, 24 | real_encoder, 25 | hard_dropout, 26 | initialization, 27 | hard_count, 28 | hard_vmap, 29 | symbolic_primitives, 30 | hard_concatenate 31 | ) 32 | 33 | # Uncomment to debug NaNs 34 | # config.update("jax_debug_nans", True) 35 | 36 | 37 | class CNN(nn.Module): 38 | """A simple CNN model.""" 39 | 40 | @nn.compact 41 | def __call__(self, x): 42 | x = nn.Conv(features=32, kernel_size=(3, 3))(x) 43 | x = nn.relu(x) 44 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 45 | x = nn.Conv(features=64, kernel_size=(3, 3))(x) 46 | x = nn.relu(x) 47 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 48 | x = x.reshape((x.shape[0], -1)) # flatten 49 | x = nn.Dense(features=256)(x) 50 | x = nn.relu(x) 51 | x = nn.Dense(features=10)(x) 52 | return x 53 | 54 | 55 | def check_symbolic(nets, datasets, trained_state, dropout_rng): 56 | _, test_ds = datasets 57 | _, hard, symbolic = nets 58 | _, test_loss, test_accuracy = apply_model_with_grad( 59 | trained_state, test_ds["image"], test_ds["label"], dropout_rng 60 | ) 61 | print( 62 | 
"soft_net: final test_loss: %.4f, final test_accuracy: %.2f" 63 | % (test_loss, test_accuracy * 100) 64 | ) 65 | hard_weights = harden.hard_weights(trained_state.params) 66 | hard_trained_state = train_state.TrainState.create( 67 | apply_fn=hard.apply, params=hard_weights, tx=optax.sgd(1.0, 1.0) 68 | ) 69 | hard_input = harden.harden(test_ds["image"]) 70 | hard_test_accuracy = apply_hard_model_to_images( 71 | hard_trained_state, hard_input, test_ds["label"] 72 | ) 73 | print("hard_net: final test_accuracy: %.2f" % (hard_test_accuracy * 100)) 74 | assert np.isclose(test_accuracy, hard_test_accuracy, atol=0.0001) 75 | # TODO: activate these checks 76 | if False: 77 | # It takes too long to compute this 78 | symbolic_weights = harden.symbolic_weights(trained_state.params) 79 | symbolic_trained_state = train_state.TrainState.create( 80 | apply_fn=symbolic.apply, params=symbolic_weights, tx=optax.sgd(1.0, 1.0) 81 | ) 82 | symbolic_input = hard_input.tolist() 83 | symbolic_test_accuracy = apply_hard_model_to_images( 84 | symbolic_trained_state, symbolic_input, test_ds["label"] 85 | ) 86 | print( 87 | "symbolic_net: final test_accuracy: %.2f" % (symbolic_test_accuracy * 100) 88 | ) 89 | assert np.isclose(test_accuracy, symbolic_test_accuracy, atol=0.0001) 90 | if False: 91 | # CPU and GPU give different results, so we can't easily regress on a static symbolic expression 92 | symbolic_input = [f"x{i}" for i in range(len(hard_input[0].tolist()))] 93 | symbolic_output = symbolic.apply({"params": symbolic_weights}, symbolic_input) 94 | print("symbolic_output", symbolic_output[0][:10000]) 95 | 96 | 97 | # about 95% training, 93-4% test 98 | # batch size 6000 99 | def nln_1(type, x, training: bool): 100 | input_size = 784 101 | mask_layer_size = 60 102 | dtype = jax.numpy.float32 103 | x = hard_masks.mask_to_true_layer(type)(mask_layer_size, dtype=dtype, 104 | weights_init=initialization.initialize_bernoulli(0.01, 0.3, 0.501))(x) 105 | x = x.reshape((2940, 16)) 106 | x = hard_majority.majority_layer(type)()(x) 107 | x = hard_not.not_layer(type)(20, weights_init=nn.initializers.uniform(1.0), dtype=dtype)(x) 108 | x = x.ravel() 109 | ############################## 110 | x = harden_layer.harden_layer(type)(x) 111 | num_classes = 10 112 | x = x.reshape((num_classes, int(x.shape[0] / num_classes))) 113 | x = x.sum(-1) 114 | return x 115 | 116 | def nln(type, x, training: bool): 117 | input_size = 784 118 | mask_layer_size = 200 119 | dtype = jax.numpy.float32 120 | x = hard_masks.mask_to_true_layer(type)(mask_layer_size, dtype=dtype, 121 | weights_init=initialization.initialize_bernoulli(0.01, 0.3, 0.501))(x) 122 | x = x.reshape((9800, 16)) 123 | x = hard_majority.majority_layer(type)()(x) 124 | x = hard_not.not_layer(type)(20, weights_init=nn.initializers.uniform(1.0), dtype=dtype)(x) 125 | x = x.ravel() 126 | ############################## 127 | x = harden_layer.harden_layer(type)(x) 128 | num_classes = 10 129 | x = x.reshape((num_classes, int(x.shape[0] / num_classes))) 130 | x = x.sum(-1) 131 | return x 132 | 133 | def batch_nln(type, x, training: bool): 134 | return jax.vmap(lambda x: nln(type, x, training))(x) 135 | 136 | 137 | def apply_model_with_grad_impl(state, images, labels, dropout_rng, training: bool): 138 | dropout_train_rng = jax.random.fold_in(key=dropout_rng, data=state.step) 139 | 140 | def loss_fn(params): 141 | logits = state.apply_fn( 142 | {"params": params}, 143 | images, 144 | training=training, 145 | rngs={"dropout": dropout_train_rng}, 146 | ) 147 | one_hot = 
jax.nn.one_hot(labels, 10) 148 | loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) 149 | return loss, logits 150 | 151 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 152 | (loss, logits), grads = grad_fn(state.params) 153 | accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) 154 | return grads, loss, accuracy 155 | 156 | 157 | @jax.jit 158 | def apply_model_with_grad_and_training(state, images, labels, dropout_rng): 159 | return apply_model_with_grad_impl(state, images, labels, dropout_rng, training=True) 160 | 161 | 162 | @jax.jit 163 | def apply_model_with_grad(state, images, labels, dropout_rng): 164 | return apply_model_with_grad_impl( 165 | state, images, labels, dropout_rng, training=False 166 | ) 167 | 168 | 169 | @jax.jit 170 | def update_model(state, grads): 171 | return state.apply_gradients(grads=grads) 172 | 173 | 174 | def train_epoch(state, train_ds, batch_size, rng, dropout_rng): 175 | """Train for a single epoch.""" 176 | train_ds_size = len(train_ds["image"]) 177 | steps_per_epoch = train_ds_size // batch_size 178 | 179 | perms = jax.random.permutation(rng, len(train_ds["image"])) 180 | perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch 181 | perms = perms.reshape((steps_per_epoch, batch_size)) 182 | 183 | epoch_loss = [] 184 | epoch_accuracy = [] 185 | 186 | for perm in perms: 187 | batch_images = train_ds["image"][perm, ...] 188 | batch_labels = train_ds["label"][perm, ...] 189 | grads, loss, accuracy = apply_model_with_grad_and_training( 190 | state, batch_images, batch_labels, dropout_rng 191 | ) 192 | state = update_model(state, grads) 193 | epoch_loss.append(loss) 194 | epoch_accuracy.append(accuracy) 195 | train_loss = np.mean(epoch_loss) 196 | train_accuracy = np.mean(epoch_accuracy) 197 | return state, train_loss, train_accuracy 198 | 199 | 200 | def get_datasets(): 201 | ds_builder = tfds.builder("mnist") 202 | ds_builder.download_and_prepare() 203 | train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1)) 204 | test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1)) 205 | train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0 206 | test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0 207 | # Convert the floating point values in [0,1] to binary values in {0,1} 208 | # If the float value is > 0.3 then we convert to 1, otherwise 0 209 | train_ds["image"] = jnp.where(train_ds["image"] > 0.3, 1.0, 0.0) 210 | test_ds["image"] = jnp.where(test_ds["image"] > 0.3, 1.0, 0.0) 211 | #train_ds["image"] = jnp.round(train_ds["image"]) 212 | #test_ds["image"] = jnp.round(test_ds["image"]) 213 | return train_ds, test_ds 214 | 215 | 216 | def show_img(img, ax=None, title=None): 217 | """Shows a single image.""" 218 | """ 219 | if ax is None: 220 | ax = plt.gca() 221 | ax.imshow(img.reshape(28, 28), cmap="gray") 222 | ax.set_xticks([]) 223 | ax.set_yticks([]) 224 | if title: 225 | ax.set_title(title) 226 | """ 227 | 228 | def show_img_grid(imgs, titles): 229 | """Shows a grid of images.""" 230 | """ 231 | n = int(np.ceil(len(imgs) ** 0.5)) 232 | _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n)) 233 | for i, (img, title) in enumerate(zip(imgs, titles)): 234 | show_img(img, axs[i // n][i % n], title) 235 | """ 236 | 237 | class TrainState(train_state.TrainState): 238 | dropout_rng: jax.random.KeyArray 239 | 240 | 241 | def create_train_state(net, rng, dropout_rng, config): 242 | # for CNN: mock_input = jnp.ones([1, 28, 28, 1]) 243 | mock_input = jnp.ones([1, 28 * 28]) 244 | 
soft_weights = net.init(rng, mock_input, training=False)["params"] 245 | # tx = optax.yogi(config.learning_rate) # for nln_2 246 | tx = optax.radam(config.learning_rate) 247 | return TrainState.create( 248 | apply_fn=net.apply, params=soft_weights, tx=tx, dropout_rng=dropout_rng 249 | ) 250 | 251 | 252 | def train_and_evaluate( 253 | init_rng, 254 | dropout_rng, 255 | net, 256 | datasets, 257 | config: ml_collections.ConfigDict, 258 | workdir: str, 259 | ): 260 | state = create_train_state(net, init_rng, dropout_rng, config) 261 | train_dataset, test_dataset = datasets 262 | 263 | for epoch in range(1, config.num_epochs + 1): 264 | init_rng, input_rng = jax.random.split(init_rng) 265 | state, train_loss, train_accuracy = train_epoch( 266 | state, train_dataset, config.batch_size, input_rng, dropout_rng 267 | ) 268 | _, test_loss, test_accuracy = apply_model_with_grad( 269 | state, test_dataset["image"], test_dataset["label"], dropout_rng 270 | ) 271 | 272 | print( 273 | "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f" 274 | % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100) 275 | ) 276 | 277 | return state 278 | 279 | 280 | def apply_hard_model(state, image, label): 281 | def logits_fn(params): 282 | return state.apply_fn({"params": params}, image) 283 | 284 | logits = logits_fn(state.params) 285 | if isinstance(logits, list): 286 | logits = jnp.array(logits) 287 | logits *= 1.0 288 | accuracy = jnp.mean(jnp.argmax(logits, -1) == label) 289 | return accuracy 290 | 291 | 292 | def apply_hard_model_to_images(state, images, labels): 293 | accuracy = 0 294 | for (image, label) in tqdm(zip(images, labels), total=len(images)): 295 | accuracy += apply_hard_model(state, image, label) 296 | return accuracy / len(images) 297 | 298 | 299 | def get_config(): 300 | config = ml_collections.ConfigDict() 301 | # config for CNN: config.learning_rate = 0.01 302 | config.learning_rate = 0.01 303 | config.momentum = 0.9 304 | config.batch_size = 3000 # 6000 # 128 305 | config.num_epochs = 5000 306 | return config 307 | 308 | 309 | @pytest.mark.skip(reason="temporarily off") 310 | def test_mnist(): 311 | # Make sure tf does not allocate gpu memory. 312 | tf.config.experimental.set_visible_devices([], "GPU") 313 | 314 | rng = jax.random.PRNGKey(0) 315 | rng, int_rng, dropout_rng = jax.random.split(rng, 3) 316 | 317 | # soft = CNN() 318 | soft, _, _ = neural_logic_net.net( 319 | lambda type, x, training: batch_nln(type, x, training) 320 | ) 321 | 322 | train_ds, test_ds = get_datasets() 323 | # If we're using a NLN then flatten the images 324 | train_ds["image"] = jnp.reshape(train_ds["image"], (train_ds["image"].shape[0], -1)) 325 | test_ds["image"] = jnp.reshape(test_ds["image"], (test_ds["image"].shape[0], -1)) 326 | 327 | print(soft.tabulate(rng, train_ds["image"][0:1], training=False)) 328 | 329 | # TODO: 50 experiments 330 | 331 | # Train and evaluate the model. 
332 | trained_state = train_and_evaluate( 333 | int_rng, 334 | dropout_rng, 335 | soft, 336 | (train_ds, test_ds), 337 | config=get_config(), 338 | workdir="./mnist_metrics", 339 | ) 340 | 341 | # Check symbolic net 342 | #_, hard, symbolic = neural_logic_net.net(lambda type, x: nln(type, x, False)) 343 | #check_symbolic( 344 | # (soft, hard, symbolic), (train_ds, test_ds), trained_state, dropout_rng 345 | #) 346 | 347 | -------------------------------------------------------------------------------- /tests/test_noisy_xor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import jax 5 | import ml_collections 6 | import numpy 7 | import optax 8 | import scipy 9 | from flax import linen as nn 10 | from flax.training import train_state 11 | from jax.config import config 12 | from tqdm import tqdm 13 | 14 | from neurallogic import ( 15 | hard_and, 16 | hard_majority, 17 | hard_not, 18 | hard_or, 19 | harden, 20 | harden_layer, 21 | neural_logic_net, 22 | initialization, 23 | symbolic_primitives, 24 | hard_vmap, 25 | hard_concatenate 26 | ) 27 | from tests import utils 28 | 29 | config.update("jax_enable_x64", True) 30 | 31 | def check_symbolic(nets, data, trained_state, dropout_rng): 32 | x_training, y_training, x_test, y_test = data 33 | _, hard, symbolic = nets 34 | _, test_loss, test_accuracy = apply_model_with_grad( 35 | trained_state, x_test, y_test, dropout_rng 36 | ) 37 | print( 38 | "soft_net: final test_loss: %.4f, final test_accuracy: %.2f" 39 | % (test_loss, test_accuracy * 100) 40 | ) 41 | hard_weights = harden.hard_weights(trained_state.params) 42 | hard_trained_state = TrainState.create( 43 | apply_fn=hard.apply, 44 | params=hard_weights, 45 | tx=optax.sgd(1.0, 1.0), 46 | dropout_rng=dropout_rng, 47 | ) 48 | hard_input = harden.harden(x_test) 49 | hard_test_accuracy = apply_hard_model_to_data( 50 | hard_trained_state, hard_input, y_test 51 | ) 52 | print("hard_net: final test_accuracy: %.2f" % (hard_test_accuracy * 100)) 53 | assert numpy.isclose(test_accuracy, hard_test_accuracy, atol=0.0001) 54 | 55 | if False: 56 | symbolic_weights = hard_weights # utils.make_symbolic(hard_weights) 57 | symbolic_trained_state = train_state.TrainState.create( 58 | apply_fn=symbolic.apply, 59 | params=symbolic_weights, 60 | tx=optax.sgd(1.0, 1.0), 61 | dropout_rng=dropout_rng, 62 | ) 63 | symbolic_input = hard_input.tolist() 64 | symbolic_test_accuracy = apply_hard_model( 65 | symbolic_trained_state, symbolic_input, y_test 66 | ) 67 | print( 68 | "symbolic_net: final test_accuracy: %.2f" % (symbolic_test_accuracy * 100) 69 | ) 70 | assert numpy.isclose(test_accuracy, symbolic_test_accuracy, atol=0.0001) 71 | if True: 72 | # CPU and GPU give different results, so we can't easily regress on a static symbolic expression 73 | symbolic_input = [f"x{i}" for i in range(len(hard_input[0].tolist()))] 74 | # This simply checks that the symbolic output can be generated 75 | symbolic_output = symbolic.apply({"params": hard_weights}, symbolic_input, training=False) 76 | 77 | num_features = 12 78 | num_classes = 2 79 | 80 | 81 | def get_data(): 82 | # Create a path to the data directory 83 | data_dir = Path(__file__).parent.parent / "tests" / "data" 84 | # Load the training data 85 | training_data = numpy.loadtxt(data_dir / "NoisyXORTrainingData.txt").astype( 86 | dtype=numpy.float32 87 | ) 88 | # Load the test data 89 | test_data = numpy.loadtxt(data_dir / "NoisyXORTestData.txt").astype( 90 | dtype=numpy.float32 91 | ) 92 | 
return training_data, test_data 93 | 94 | 95 | """ 96 | | Technique/Accuracy | Mean | 5 %ile | 95 %ile | Min | Max | 97 | | ------------------- | -------------- | ------- | ------- | ------ | ------ | 98 | | Tsetlin | 99.3 +/- 0.3 | 95.9 | 100.0 | 91.6 | 100.0 | 99 | | dB | 97.9 +/- 0.2 | 95.4 | 100.0 | 93.6 | 100.0 | 100 | | Neural network | 95.4 +/- 0.5 | 90.1 | 98.6 | 88.2 | 99.9 | 101 | | SVM | 58.0 +/- 0.3 | 56.4 | 59.2 | 55.4 | 66.5 | 102 | | Naive Bayes | 49.8 +/- 0.2 | 48.3 | 51.0 | 41.3 | 52.7 | 103 | | Logistic regression | 49.8 +/- 0.3 | 47.8 | 51.1 | 41.1 | 53.1 | 104 | 105 | Source: https://arxiv.org/pdf/1804.01508.pdf 106 | """ 107 | # N.B. We use marginal versions of and/or layers for this performance 108 | # mean: 97.89, sem: 0.15, min: 93.58, max: 100.00, 5%: 95.40, 95%: 100.00 109 | def nln(type, x, training: bool): 110 | y = hard_vmap.vmap(type)((lambda x: 1 - x, lambda x: 1 - x, lambda x: symbolic_primitives.symbolic_not(x)))(x) 111 | x = hard_concatenate.concatenate(type)([x, y], 0) 112 | 113 | layer_size = 32 114 | dtype = jax.numpy.float64 115 | x = hard_and.and_layer(type)( 116 | layer_size, 117 | dtype=dtype, 118 | weights_init=initialization.initialize_bernoulli(0.01, 0.3, 0.501), 119 | )(x) 120 | x = hard_or.or_layer(type)( 121 | layer_size, 122 | dtype=dtype, 123 | weights_init=initialization.initialize_bernoulli(0.99, 0.499, 0.7), 124 | )(x) 125 | not_layer_size = 16 126 | x = hard_not.not_layer(type)( 127 | not_layer_size, 128 | dtype=dtype, 129 | weights_init=initialization.initialize_uniform_range(0.499, 0.501), 130 | )(x) 131 | 132 | x = x.reshape((1, layer_size * not_layer_size)) 133 | x = hard_majority.majority_layer(type)()(x) 134 | 135 | z = hard_vmap.vmap(type)((lambda x: 1 - x, lambda x: 1 - x, lambda x: symbolic_primitives.symbolic_not(x)))(x) 136 | x = hard_concatenate.concatenate(type)([x, z], 0) 137 | 138 | ######################################################## 139 | 140 | x = x.reshape((num_classes, int(x.shape[0] / num_classes))) 141 | x = x.sum(-1) 142 | return x 143 | 144 | 145 | def batch_nln(type, x, training: bool): 146 | return jax.vmap(lambda x: nln(type, x, training))(x) 147 | 148 | 149 | class TrainState(train_state.TrainState): 150 | dropout_rng: jax.random.KeyArray 151 | 152 | 153 | def create_train_state(net, rng, dropout_rng, config): 154 | mock_input = jax.numpy.ones([1, num_features]) 155 | soft_weights = net.init(rng, mock_input, training=False)["params"] 156 | tx = optax.radam(learning_rate=config.learning_rate) 157 | return TrainState.create( 158 | apply_fn=net.apply, params=soft_weights, tx=tx, dropout_rng=dropout_rng 159 | ) 160 | 161 | 162 | @jax.jit 163 | def update_model(state, grads): 164 | return state.apply_gradients(grads=grads) 165 | 166 | 167 | def apply_model_with_grad_impl(state, features, labels, dropout_rng, training: bool): 168 | dropout_train_rng = jax.random.fold_in(key=dropout_rng, data=state.step) 169 | 170 | def loss_fn(params): 171 | logits = state.apply_fn( 172 | {"params": params}, 173 | features, 174 | training=training, 175 | rngs={"dropout": dropout_train_rng}, 176 | ) 177 | one_hot = jax.nn.one_hot(labels, num_classes, dtype=jax.numpy.int32) 178 | loss = jax.numpy.mean( 179 | optax.softmax_cross_entropy(logits=logits, labels=one_hot) 180 | ) 181 | return loss, logits 182 | 183 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 184 | (loss, logits), grads = grad_fn(state.params) 185 | accuracy = jax.numpy.mean(jax.numpy.argmax(logits, -1) == labels) 186 | return grads, loss, accuracy 187 | 188 | 
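# N.B. `training` is a plain Python bool, so it is baked into the two
# jit-compiled wrappers below (one per value) rather than passed as a traced
# argument: under `jax.jit` a traced value cannot drive Python-level control
# flow inside the net, whereas a compile-time constant can.
# Illustrative usage (a sketch only; `x_batch`/`y_batch` etc. are placeholders):
#
#   grads, loss, acc = apply_model_with_grad_and_training(state, x_batch, y_batch, dropout_rng)
#   _, test_loss, test_acc = apply_model_with_grad(state, x_test, y_test, dropout_rng)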
189 | @jax.jit 190 | def apply_model_with_grad_and_training(state, features, labels, dropout_rng): 191 | return apply_model_with_grad_impl( 192 | state, features, labels, dropout_rng, training=True 193 | ) 194 | 195 | 196 | @jax.jit 197 | def apply_model_with_grad(state, features, labels, dropout_rng): 198 | return apply_model_with_grad_impl( 199 | state, features, labels, dropout_rng, training=False 200 | ) 201 | 202 | 203 | def train_epoch(state, features, labels, batch_size, rng, dropout_rng): 204 | train_ds_size = len(features) 205 | steps_per_epoch = train_ds_size // batch_size 206 | 207 | perms = jax.random.permutation(rng, len(features)) 208 | perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch 209 | perms = perms.reshape((steps_per_epoch, batch_size)) 210 | 211 | epoch_loss = [] 212 | epoch_accuracy = [] 213 | 214 | for perm in perms: 215 | batch_features = features[perm, ...] 216 | batch_labels = labels[perm, ...] 217 | grads, loss, accuracy = apply_model_with_grad_and_training( 218 | state, batch_features, batch_labels, dropout_rng 219 | ) 220 | state = update_model(state, grads) 221 | epoch_loss.append(loss) 222 | epoch_accuracy.append(accuracy) 223 | train_loss = numpy.mean(epoch_loss) 224 | train_accuracy = numpy.mean(epoch_accuracy) 225 | return state, train_loss, train_accuracy 226 | 227 | 228 | def get_train_and_test_data(data): 229 | training_data, test_data = data 230 | x_training = training_data[:, 0:num_features] # Input features 231 | y_training = training_data[:, num_features] # Target value 232 | x_test = test_data[:, 0:num_features] # Input features 233 | y_test = test_data[:, num_features] # Target value 234 | return x_training, y_training, x_test, y_test 235 | 236 | 237 | def train_and_evaluate( 238 | init_rng, dropout_rng, net, data, config: ml_collections.ConfigDict 239 | ): 240 | state = create_train_state(net, init_rng, dropout_rng, config) 241 | x_training, y_training, x_test, y_test = data 242 | for epoch in range(1, config.num_epochs + 1): 243 | init_rng, input_rng = jax.random.split(init_rng) 244 | state, train_loss, train_accuracy = train_epoch( 245 | state, x_training, y_training, config.batch_size, input_rng, dropout_rng 246 | ) 247 | _, test_loss, test_accuracy = apply_model_with_grad( 248 | state, x_test, y_test, dropout_rng 249 | ) 250 | 251 | print( 252 | "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f" 253 | % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100) 254 | ) 255 | 256 | return state, test_accuracy 257 | 258 | 259 | def apply_hard_model(state, features, label): 260 | def logits_fn(params): 261 | return state.apply_fn({"params": params}, features, training=False) 262 | 263 | logits = logits_fn(state.params) 264 | if isinstance(logits, list): 265 | logits = jax.numpy.array(logits) 266 | logits *= 1.0 267 | accuracy = jax.numpy.mean(jax.numpy.argmax(logits, -1) == label) 268 | return accuracy 269 | 270 | 271 | def apply_hard_model_to_data(state, features, labels): 272 | accuracy = 0 273 | for image, label in tqdm(zip(features, labels), total=len(features)): 274 | accuracy += apply_hard_model(state, image, label) 275 | return accuracy / len(features) 276 | 277 | 278 | def get_config(): 279 | config = ml_collections.ConfigDict() 280 | config.learning_rate = 0.01 281 | config.batch_size = 5000 282 | config.num_epochs = 2000 # 2000 for paper 283 | return config 284 | 285 | 286 | def test_noisy_xor(): 287 | # Train net 288 | soft, _, _ = neural_logic_net.net( 289 | 
lambda type, x, training: batch_nln(type, x, training) 290 | ) 291 | 292 | x_training, y_training, x_test, y_test = get_train_and_test_data(get_data()) 293 | 294 | rng = jax.random.PRNGKey(0) 295 | print(soft.tabulate(rng, x_training[0:1], training=False)) 296 | 297 | num_experiments = 1 # 100 for paper 298 | final_test_accuracies = [] 299 | for i in range(num_experiments): 300 | rng, int_rng, dropout_rng = jax.random.split(rng, 3) 301 | trained_state, final_test_accuracy = train_and_evaluate( 302 | int_rng, 303 | dropout_rng, 304 | soft, 305 | (x_training, y_training, x_test, y_test), 306 | get_config(), 307 | ) 308 | final_test_accuracies.append(final_test_accuracy) 309 | print(f"{i}: final test accuracy: {final_test_accuracy * 100:.2f}") 310 | # print mean, standard error of the mean, min, max, lowest 5%, highest 5% of final test accuracies 311 | print( 312 | f"mean: {numpy.mean(final_test_accuracies) * 100:.2f}, " 313 | f"sem: {scipy.stats.sem(final_test_accuracies) * 100:.2f}, " 314 | f"min: {numpy.min(final_test_accuracies) * 100:.2f}, " 315 | f"max: {numpy.max(final_test_accuracies) * 100:.2f}, " 316 | f"5%: {numpy.percentile(final_test_accuracies, 5) * 100:.2f}, " 317 | f"95%: {numpy.percentile(final_test_accuracies, 95) * 100:.2f}" 318 | ) 319 | # numpy.set_printoptions(threshold=sys.maxsize) 320 | # print(f"trained soft weights: {repr(trained_state.params)}") 321 | # hard_weights = harden.hard_weights(trained_state.params) 322 | # print(f"trained hard weights: {repr(hard_weights)}") 323 | 324 | # Check symbolic net 325 | _, hard, symbolic = neural_logic_net.net( 326 | lambda type, x, training: nln(type, x, training) 327 | ) 328 | check_symbolic( 329 | (soft, hard, symbolic), 330 | (x_training, y_training, x_test, y_test), 331 | trained_state, 332 | dropout_rng, 333 | ) 334 | -------------------------------------------------------------------------------- /tests/test_symbolic_primitives.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy 3 | 4 | from neurallogic import symbolic_generation, symbolic_primitives, symbolic_operator 5 | from tests import utils 6 | 7 | 8 | def test_symbolic_expression(): 9 | output = symbolic_operator.symbolic_operator("not", "True") 10 | expected = "not(True)" 11 | assert output == expected 12 | eval_output = symbolic_generation.eval_symbolic_expression(output) 13 | eval_expected = symbolic_generation.eval_symbolic_expression(expected) 14 | assert eval_output == eval_expected 15 | 16 | 17 | def test_symbolic_expression_vector(): 18 | x = numpy.array(["True", "False"]) 19 | output = symbolic_operator.symbolic_operator("not", x) 20 | expected = numpy.array(["not(True)", "not(False)"]) 21 | assert numpy.array_equal(output, expected) 22 | eval_output = symbolic_generation.eval_symbolic_expression(output) 23 | eval_expected = symbolic_generation.eval_symbolic_expression(expected) 24 | assert numpy.array_equal(eval_output, eval_expected) 25 | 26 | 27 | def test_symbolic_expression_matrix(): 28 | x = numpy.array([["True", "False"], ["False", "True"]]) 29 | output = symbolic_operator.symbolic_operator("not", x) 30 | expected = numpy.array( 31 | [["not(True)", "not(False)"], ["not(False)", "not(True)"]]) 32 | assert numpy.array_equal(output, expected) 33 | eval_output = symbolic_generation.eval_symbolic_expression(output) 34 | eval_expected = symbolic_generation.eval_symbolic_expression(expected) 35 | assert numpy.array_equal(eval_output, eval_expected) 36 | 37 | 38 | 39 | def 
test_symbolic_eval(): 40 | output = symbolic_generation.eval_symbolic_expression("1 + 2") 41 | expected = 3 42 | assert output == expected 43 | output = symbolic_generation.eval_symbolic_expression("[1, 2, 3]") 44 | expected = [1, 2, 3] 45 | assert numpy.array_equal(output, expected) 46 | output = symbolic_generation.eval_symbolic_expression( 47 | "[1, 2, 3] + [4, 5, 6]") 48 | expected = [1, 2, 3, 4, 5, 6] 49 | assert numpy.array_equal(output, expected) 50 | output = symbolic_generation.eval_symbolic_expression(['1', '2', '3']) 51 | expected = [1, 2, 3] 52 | assert numpy.array_equal(output, expected) 53 | output = symbolic_generation.eval_symbolic_expression( 54 | ['1', '2', '3'] + ['4', '5', '6']) 55 | expected = [1, 2, 3, 4, 5, 6] 56 | assert numpy.array_equal(output, expected) 57 | output = symbolic_generation.eval_symbolic_expression( 58 | ['not(False)', 'not(True)']) 59 | expected = [True, False] 60 | assert numpy.array_equal(output, expected) 61 | output = symbolic_generation.eval_symbolic_expression( 62 | [['not(False)', 'not(True)'] + ['not(False)', 'not(True)']]) 63 | expected = [[True, False, True, False]] 64 | assert numpy.array_equal(output, expected) 65 | output = symbolic_generation.eval_symbolic_expression(numpy.array( 66 | [['not(False)', 'not(True)'] + ['not(False)', 'not(True)']])) 67 | expected = [[True, False, True, False]] 68 | assert numpy.array_equal(output, expected) 69 | output = symbolic_generation.eval_symbolic_expression(numpy.array( 70 | [['not(False)', False], ['not(False)', 'not(True)']])) 71 | expected = [[True, False], [True, False]] 72 | assert numpy.array_equal(output, expected) 73 | 74 | 75 | def test_symbolic_not(): 76 | x1 = numpy.array([True, False]) 77 | output = symbolic_primitives.symbolic_not(x1) 78 | expected = numpy.array([False, True]) 79 | assert numpy.array_equal(output, expected) 80 | x1 = utils.make_symbolic(x1) 81 | output = symbolic_primitives.symbolic_not(x1) 82 | expected = numpy.array( 83 | ['numpy.logical_not(True)', 'numpy.logical_not(False)']) 84 | assert numpy.array_equal(output, expected) 85 | 86 | 87 | def test_symbolic_and(): 88 | x1 = numpy.array([True, False]) 89 | x2 = numpy.array([True, True]) 90 | output = symbolic_primitives.symbolic_and(x1, x2) 91 | expected = numpy.array([True, False]) 92 | assert numpy.array_equal(output, expected) 93 | x1 = utils.make_symbolic(x1) 94 | x2 = utils.make_symbolic(x2) 95 | output = symbolic_primitives.symbolic_and(x1, x2) 96 | expected = numpy.array( 97 | ['numpy.logical_and(True, True)', 'numpy.logical_and(False, True)']) 98 | assert numpy.array_equal(output, expected) 99 | 100 | 101 | def test_symbolic_xor(): 102 | x1 = numpy.array([True, False]) 103 | x2 = numpy.array([True, True]) 104 | output = symbolic_primitives.symbolic_xor(x1, x2) 105 | expected = numpy.array([False, True]) 106 | assert numpy.array_equal(output, expected) 107 | x1 = utils.make_symbolic(x1) 108 | x2 = utils.make_symbolic(x2) 109 | output = symbolic_primitives.symbolic_xor(x1, x2) 110 | expected = numpy.array( 111 | ['numpy.logical_xor(True, True)', 'numpy.logical_xor(False, True)']) 112 | assert numpy.array_equal(output, expected) 113 | 114 | 115 | def test_symbolic_broadcast_in_dim(): 116 | # Test 1D 117 | input = jax.numpy.array([1, 1]) 118 | output = symbolic_primitives.symbolic_broadcast_in_dim(input, (2, 2), (0,)) 119 | expected = jax.numpy.array([[1, 1], [1, 1]]) 120 | assert numpy.array_equal(output, expected) 121 | # Test 2D 122 | input = jax.numpy.array([[1, 1], [1, 1]]) 123 | output = 
symbolic_primitives.symbolic_broadcast_in_dim( 124 | input, (2, 2, 2), (0, 1)) 125 | expected = jax.numpy.array([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) 126 | assert numpy.array_equal(output, expected) 127 | # Test 3D 128 | input = jax.numpy.array([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) 129 | output = symbolic_primitives.symbolic_broadcast_in_dim( 130 | input, (2, 2, 2, 2), (0, 1, 2)) 131 | expected = jax.numpy.array([[[[1, 1], [1, 1]], [[1, 1], [1, 1]]], [ 132 | [[1, 1], [1, 1]], [[1, 1], [1, 1]]]]) 133 | assert numpy.array_equal(output, expected) 134 | 135 | 136 | def symbolic_reduce_or_impl(input, expected, symbolic_expected, axes): 137 | # symbolic_reduce_or uses the lax reference implementation if its input consists of boolean values, 138 | # otherwise it evaluates symbolically. Therefore we first test the reference implementation and then 139 | # the symbolic implementation, and then compare them. 140 | # Test reference implementation 141 | input = numpy.array(input) 142 | output = symbolic_primitives.symbolic_reduce_or(input, axes=axes) 143 | expected = numpy.array(expected) 144 | assert numpy.array_equal(output, expected) 145 | # Test symbolic implementation 146 | input = utils.make_symbolic(input) 147 | output = symbolic_primitives.symbolic_reduce_or(input, axes=axes) 148 | symbolic_expected = numpy.array(symbolic_expected) 149 | assert numpy.array_equal(output, symbolic_expected) 150 | # Compare the reference and symbolic evaluation 151 | symbolic_expected = symbolic_generation.eval_symbolic_expression( 152 | symbolic_expected) 153 | assert numpy.array_equal(expected, symbolic_expected) 154 | 155 | 156 | def test_symbolic_reduce_or(): 157 | # Test 1: 2D matrix with different axes inputs 158 | symbolic_reduce_or_impl(input=[[True, False], [True, False]], expected=[ 159 | True, True], symbolic_expected=['numpy.logical_or(numpy.logical_or(False, True), False)', 'numpy.logical_or(numpy.logical_or(False, True), False)'], axes=(1,)) 160 | symbolic_reduce_or_impl(input=[[True, False], [True, False]], expected=[ 161 | True, False], symbolic_expected=['numpy.logical_or(numpy.logical_or(False, True), True)', 'numpy.logical_or(numpy.logical_or(False, False), False)'], axes=(0,)) 162 | symbolic_reduce_or_impl(input=[[True, False], [True, False]], expected=True, 163 | symbolic_expected='numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False)', axes=(0, 1)) 164 | # Test 2: 3D matrix with different axes inputs 165 | symbolic_reduce_or_impl(input=[[[True, False], [True, False]], [[True, False], [True, False]]], expected=[True, True], symbolic_expected=['numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False)', 166 | 'numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False)'], axes=(1, 2)) 167 | symbolic_reduce_or_impl(input=[[[True, False], [True, False]], [[True, False], [True, False]]], expected=[True, True], symbolic_expected=['numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False)', 168 | 'numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False)'], axes=(0, 2)) 169 | symbolic_reduce_or_impl(input=[[[True, False], [True, False]], [[True, False], [True, False]]], expected=[True, False], symbolic_expected=['numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), True), True), True)', 170 | 
'numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, False), False), False), False)'], axes=(0, 1)) 171 | symbolic_reduce_or_impl(input=[[[True, False], [True, False]], [[True, False], [True, False]]], expected=True, 172 | symbolic_expected='numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False), True), False), True), False)', axes=(0, 1, 2)) 173 | # Test 3: 4D matrix with different axes inputs 174 | symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, True], symbolic_expected=['numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False), True), False), True), False)', 175 | 'numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False), True), False), True), False)'], axes=(1, 2, 3)) 176 | symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, True], symbolic_expected=['numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False), True), False), True), False)', 177 | 'numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False), True), False), True), False)'], axes=(0, 2, 3)) 178 | symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, True], symbolic_expected=['numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False), True), False), True), False)', 179 | 'numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False), True), False), True), False)'], axes=(0, 1, 3)) 180 | symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=[True, False], symbolic_expected=['numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), True), True), True), True), True), True), True)', 181 | 'numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, False), False), False), False), False), False), False), False)'], axes=(0, 1, 2)) 182 | symbolic_reduce_or_impl(input=[[[[True, False], [True, False]], [[True, False], [True, False]]], [[[True, False], [True, False]], [[True, False], [True, False]]]], expected=True, 183 | 
symbolic_expected='numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(numpy.logical_or(False, True), False), True), False), True), False), True), False), True), False), True), False), True), False), True), False)', axes=(0, 1, 2, 3)) 184 | -------------------------------------------------------------------------------- /docs/math_commands.tex: -------------------------------------------------------------------------------- 1 | %%%%% NEW MATH DEFINITIONS %%%%% 2 | 3 | \usepackage{amsmath,amsfonts,bm} 4 | 5 | % Mark sections of captions for referring to divisions of figures 6 | \newcommand{\figleft}{{\em (Left)}} 7 | \newcommand{\figcenter}{{\em (Center)}} 8 | \newcommand{\figright}{{\em (Right)}} 9 | \newcommand{\figtop}{{\em (Top)}} 10 | \newcommand{\figbottom}{{\em (Bottom)}} 11 | \newcommand{\captiona}{{\em (a)}} 12 | \newcommand{\captionb}{{\em (b)}} 13 | \newcommand{\captionc}{{\em (c)}} 14 | \newcommand{\captiond}{{\em (d)}} 15 | 16 | % Highlight a newly defined term 17 | \newcommand{\newterm}[1]{{\bf #1}} 18 | 19 | 20 | % Figure reference, lower-case. 21 | \def\figref#1{figure~\ref{#1}} 22 | % Figure reference, capital. For start of sentence 23 | \def\Figref#1{Figure~\ref{#1}} 24 | \def\twofigref#1#2{figures \ref{#1} and \ref{#2}} 25 | \def\quadfigref#1#2#3#4{figures \ref{#1}, \ref{#2}, \ref{#3} and \ref{#4}} 26 | % Section reference, lower-case. 27 | \def\secref#1{section~\ref{#1}} 28 | % Section reference, capital. 29 | \def\Secref#1{Section~\ref{#1}} 30 | % Reference to two sections. 31 | \def\twosecrefs#1#2{sections \ref{#1} and \ref{#2}} 32 | % Reference to three sections. 33 | \def\secrefs#1#2#3{sections \ref{#1}, \ref{#2} and \ref{#3}} 34 | % Reference to an equation, lower-case. 35 | \def\eqref#1{equation~\ref{#1}} 36 | % Reference to an equation, upper case 37 | \def\Eqref#1{Equation~\ref{#1}} 38 | % A raw reference to an equation---avoid using if possible 39 | \def\plaineqref#1{\ref{#1}} 40 | % Reference to a chapter, lower-case. 41 | \def\chapref#1{chapter~\ref{#1}} 42 | % Reference to an equation, upper case. 43 | \def\Chapref#1{Chapter~\ref{#1}} 44 | % Reference to a range of chapters 45 | \def\rangechapref#1#2{chapters\ref{#1}--\ref{#2}} 46 | % Reference to an algorithm, lower-case. 47 | \def\algref#1{algorithm~\ref{#1}} 48 | % Reference to an algorithm, upper case. 
49 | \def\Algref#1{Algorithm~\ref{#1}} 50 | \def\twoalgref#1#2{algorithms \ref{#1} and \ref{#2}} 51 | \def\Twoalgref#1#2{Algorithms \ref{#1} and \ref{#2}} 52 | % Reference to a part, lower case 53 | \def\partref#1{part~\ref{#1}} 54 | % Reference to a part, upper case 55 | \def\Partref#1{Part~\ref{#1}} 56 | \def\twopartref#1#2{parts \ref{#1} and \ref{#2}} 57 | 58 | \def\ceil#1{\lceil #1 \rceil} 59 | \def\floor#1{\lfloor #1 \rfloor} 60 | \def\1{\bm{1}} 61 | \newcommand{\train}{\mathcal{D}} 62 | \newcommand{\valid}{\mathcal{D_{\mathrm{valid}}}} 63 | \newcommand{\test}{\mathcal{D_{\mathrm{test}}}} 64 | 65 | \def\eps{{\epsilon}} 66 | 67 | 68 | % Random variables 69 | \def\reta{{\textnormal{$\eta$}}} 70 | \def\ra{{\textnormal{a}}} 71 | \def\rb{{\textnormal{b}}} 72 | \def\rc{{\textnormal{c}}} 73 | \def\rd{{\textnormal{d}}} 74 | \def\re{{\textnormal{e}}} 75 | \def\rf{{\textnormal{f}}} 76 | \def\rg{{\textnormal{g}}} 77 | \def\rh{{\textnormal{h}}} 78 | \def\ri{{\textnormal{i}}} 79 | \def\rj{{\textnormal{j}}} 80 | \def\rk{{\textnormal{k}}} 81 | \def\rl{{\textnormal{l}}} 82 | % rm is already a command, just don't name any random variables m 83 | \def\rn{{\textnormal{n}}} 84 | \def\ro{{\textnormal{o}}} 85 | \def\rp{{\textnormal{p}}} 86 | \def\rq{{\textnormal{q}}} 87 | \def\rr{{\textnormal{r}}} 88 | \def\rs{{\textnormal{s}}} 89 | \def\rt{{\textnormal{t}}} 90 | \def\ru{{\textnormal{u}}} 91 | \def\rv{{\textnormal{v}}} 92 | \def\rw{{\textnormal{w}}} 93 | \def\rx{{\textnormal{x}}} 94 | \def\ry{{\textnormal{y}}} 95 | \def\rz{{\textnormal{z}}} 96 | 97 | % Random vectors 98 | \def\rvepsilon{{\mathbf{\epsilon}}} 99 | \def\rvtheta{{\mathbf{\theta}}} 100 | \def\rva{{\mathbf{a}}} 101 | \def\rvb{{\mathbf{b}}} 102 | \def\rvc{{\mathbf{c}}} 103 | \def\rvd{{\mathbf{d}}} 104 | \def\rve{{\mathbf{e}}} 105 | \def\rvf{{\mathbf{f}}} 106 | \def\rvg{{\mathbf{g}}} 107 | \def\rvh{{\mathbf{h}}} 108 | \def\rvu{{\mathbf{i}}} 109 | \def\rvj{{\mathbf{j}}} 110 | \def\rvk{{\mathbf{k}}} 111 | \def\rvl{{\mathbf{l}}} 112 | \def\rvm{{\mathbf{m}}} 113 | \def\rvn{{\mathbf{n}}} 114 | \def\rvo{{\mathbf{o}}} 115 | \def\rvp{{\mathbf{p}}} 116 | \def\rvq{{\mathbf{q}}} 117 | \def\rvr{{\mathbf{r}}} 118 | \def\rvs{{\mathbf{s}}} 119 | \def\rvt{{\mathbf{t}}} 120 | \def\rvu{{\mathbf{u}}} 121 | \def\rvv{{\mathbf{v}}} 122 | \def\rvw{{\mathbf{w}}} 123 | \def\rvx{{\mathbf{x}}} 124 | \def\rvy{{\mathbf{y}}} 125 | \def\rvz{{\mathbf{z}}} 126 | 127 | % Elements of random vectors 128 | \def\erva{{\textnormal{a}}} 129 | \def\ervb{{\textnormal{b}}} 130 | \def\ervc{{\textnormal{c}}} 131 | \def\ervd{{\textnormal{d}}} 132 | \def\erve{{\textnormal{e}}} 133 | \def\ervf{{\textnormal{f}}} 134 | \def\ervg{{\textnormal{g}}} 135 | \def\ervh{{\textnormal{h}}} 136 | \def\ervi{{\textnormal{i}}} 137 | \def\ervj{{\textnormal{j}}} 138 | \def\ervk{{\textnormal{k}}} 139 | \def\ervl{{\textnormal{l}}} 140 | \def\ervm{{\textnormal{m}}} 141 | \def\ervn{{\textnormal{n}}} 142 | \def\ervo{{\textnormal{o}}} 143 | \def\ervp{{\textnormal{p}}} 144 | \def\ervq{{\textnormal{q}}} 145 | \def\ervr{{\textnormal{r}}} 146 | \def\ervs{{\textnormal{s}}} 147 | \def\ervt{{\textnormal{t}}} 148 | \def\ervu{{\textnormal{u}}} 149 | \def\ervv{{\textnormal{v}}} 150 | \def\ervw{{\textnormal{w}}} 151 | \def\ervx{{\textnormal{x}}} 152 | \def\ervy{{\textnormal{y}}} 153 | \def\ervz{{\textnormal{z}}} 154 | 155 | % Random matrices 156 | \def\rmA{{\mathbf{A}}} 157 | \def\rmB{{\mathbf{B}}} 158 | \def\rmC{{\mathbf{C}}} 159 | \def\rmD{{\mathbf{D}}} 160 | \def\rmE{{\mathbf{E}}} 161 | \def\rmF{{\mathbf{F}}} 162 | 
\def\rmG{{\mathbf{G}}} 163 | \def\rmH{{\mathbf{H}}} 164 | \def\rmI{{\mathbf{I}}} 165 | \def\rmJ{{\mathbf{J}}} 166 | \def\rmK{{\mathbf{K}}} 167 | \def\rmL{{\mathbf{L}}} 168 | \def\rmM{{\mathbf{M}}} 169 | \def\rmN{{\mathbf{N}}} 170 | \def\rmO{{\mathbf{O}}} 171 | \def\rmP{{\mathbf{P}}} 172 | \def\rmQ{{\mathbf{Q}}} 173 | \def\rmR{{\mathbf{R}}} 174 | \def\rmS{{\mathbf{S}}} 175 | \def\rmT{{\mathbf{T}}} 176 | \def\rmU{{\mathbf{U}}} 177 | \def\rmV{{\mathbf{V}}} 178 | \def\rmW{{\mathbf{W}}} 179 | \def\rmX{{\mathbf{X}}} 180 | \def\rmY{{\mathbf{Y}}} 181 | \def\rmZ{{\mathbf{Z}}} 182 | 183 | % Elements of random matrices 184 | \def\ermA{{\textnormal{A}}} 185 | \def\ermB{{\textnormal{B}}} 186 | \def\ermC{{\textnormal{C}}} 187 | \def\ermD{{\textnormal{D}}} 188 | \def\ermE{{\textnormal{E}}} 189 | \def\ermF{{\textnormal{F}}} 190 | \def\ermG{{\textnormal{G}}} 191 | \def\ermH{{\textnormal{H}}} 192 | \def\ermI{{\textnormal{I}}} 193 | \def\ermJ{{\textnormal{J}}} 194 | \def\ermK{{\textnormal{K}}} 195 | \def\ermL{{\textnormal{L}}} 196 | \def\ermM{{\textnormal{M}}} 197 | \def\ermN{{\textnormal{N}}} 198 | \def\ermO{{\textnormal{O}}} 199 | \def\ermP{{\textnormal{P}}} 200 | \def\ermQ{{\textnormal{Q}}} 201 | \def\ermR{{\textnormal{R}}} 202 | \def\ermS{{\textnormal{S}}} 203 | \def\ermT{{\textnormal{T}}} 204 | \def\ermU{{\textnormal{U}}} 205 | \def\ermV{{\textnormal{V}}} 206 | \def\ermW{{\textnormal{W}}} 207 | \def\ermX{{\textnormal{X}}} 208 | \def\ermY{{\textnormal{Y}}} 209 | \def\ermZ{{\textnormal{Z}}} 210 | 211 | % Vectors 212 | \def\vzero{{\bm{0}}} 213 | \def\vone{{\bm{1}}} 214 | \def\vmu{{\bm{\mu}}} 215 | \def\vtheta{{\bm{\theta}}} 216 | \def\va{{\bm{a}}} 217 | \def\vb{{\bm{b}}} 218 | \def\vc{{\bm{c}}} 219 | \def\vd{{\bm{d}}} 220 | \def\ve{{\bm{e}}} 221 | \def\vf{{\bm{f}}} 222 | \def\vg{{\bm{g}}} 223 | \def\vh{{\bm{h}}} 224 | \def\vi{{\bm{i}}} 225 | \def\vj{{\bm{j}}} 226 | \def\vk{{\bm{k}}} 227 | \def\vl{{\bm{l}}} 228 | \def\vm{{\bm{m}}} 229 | \def\vn{{\bm{n}}} 230 | \def\vo{{\bm{o}}} 231 | \def\vp{{\bm{p}}} 232 | \def\vq{{\bm{q}}} 233 | \def\vr{{\bm{r}}} 234 | \def\vs{{\bm{s}}} 235 | \def\vt{{\bm{t}}} 236 | \def\vu{{\bm{u}}} 237 | \def\vv{{\bm{v}}} 238 | \def\vw{{\bm{w}}} 239 | \def\vx{{\bm{x}}} 240 | \def\vy{{\bm{y}}} 241 | \def\vz{{\bm{z}}} 242 | 243 | % Elements of vectors 244 | \def\evalpha{{\alpha}} 245 | \def\evbeta{{\beta}} 246 | \def\evepsilon{{\epsilon}} 247 | \def\evlambda{{\lambda}} 248 | \def\evomega{{\omega}} 249 | \def\evmu{{\mu}} 250 | \def\evpsi{{\psi}} 251 | \def\evsigma{{\sigma}} 252 | \def\evtheta{{\theta}} 253 | \def\eva{{a}} 254 | \def\evb{{b}} 255 | \def\evc{{c}} 256 | \def\evd{{d}} 257 | \def\eve{{e}} 258 | \def\evf{{f}} 259 | \def\evg{{g}} 260 | \def\evh{{h}} 261 | \def\evi{{i}} 262 | \def\evj{{j}} 263 | \def\evk{{k}} 264 | \def\evl{{l}} 265 | \def\evm{{m}} 266 | \def\evn{{n}} 267 | \def\evo{{o}} 268 | \def\evp{{p}} 269 | \def\evq{{q}} 270 | \def\evr{{r}} 271 | \def\evs{{s}} 272 | \def\evt{{t}} 273 | \def\evu{{u}} 274 | \def\evv{{v}} 275 | \def\evw{{w}} 276 | \def\evx{{x}} 277 | \def\evy{{y}} 278 | \def\evz{{z}} 279 | 280 | % Matrix 281 | \def\mA{{\bm{A}}} 282 | \def\mB{{\bm{B}}} 283 | \def\mC{{\bm{C}}} 284 | \def\mD{{\bm{D}}} 285 | \def\mE{{\bm{E}}} 286 | \def\mF{{\bm{F}}} 287 | \def\mG{{\bm{G}}} 288 | \def\mH{{\bm{H}}} 289 | \def\mI{{\bm{I}}} 290 | \def\mJ{{\bm{J}}} 291 | \def\mK{{\bm{K}}} 292 | \def\mL{{\bm{L}}} 293 | \def\mM{{\bm{M}}} 294 | \def\mN{{\bm{N}}} 295 | \def\mO{{\bm{O}}} 296 | \def\mP{{\bm{P}}} 297 | \def\mQ{{\bm{Q}}} 298 | \def\mR{{\bm{R}}} 299 | \def\mS{{\bm{S}}} 300 | 
\def\mT{{\bm{T}}} 301 | \def\mU{{\bm{U}}} 302 | \def\mV{{\bm{V}}} 303 | \def\mW{{\bm{W}}} 304 | \def\mX{{\bm{X}}} 305 | \def\mY{{\bm{Y}}} 306 | \def\mZ{{\bm{Z}}} 307 | \def\mBeta{{\bm{\beta}}} 308 | \def\mPhi{{\bm{\Phi}}} 309 | \def\mLambda{{\bm{\Lambda}}} 310 | \def\mSigma{{\bm{\Sigma}}} 311 | 312 | % Tensor 313 | \DeclareMathAlphabet{\mathsfit}{\encodingdefault}{\sfdefault}{m}{sl} 314 | \SetMathAlphabet{\mathsfit}{bold}{\encodingdefault}{\sfdefault}{bx}{n} 315 | \newcommand{\tens}[1]{\bm{\mathsfit{#1}}} 316 | \def\tA{{\tens{A}}} 317 | \def\tB{{\tens{B}}} 318 | \def\tC{{\tens{C}}} 319 | \def\tD{{\tens{D}}} 320 | \def\tE{{\tens{E}}} 321 | \def\tF{{\tens{F}}} 322 | \def\tG{{\tens{G}}} 323 | \def\tH{{\tens{H}}} 324 | \def\tI{{\tens{I}}} 325 | \def\tJ{{\tens{J}}} 326 | \def\tK{{\tens{K}}} 327 | \def\tL{{\tens{L}}} 328 | \def\tM{{\tens{M}}} 329 | \def\tN{{\tens{N}}} 330 | \def\tO{{\tens{O}}} 331 | \def\tP{{\tens{P}}} 332 | \def\tQ{{\tens{Q}}} 333 | \def\tR{{\tens{R}}} 334 | \def\tS{{\tens{S}}} 335 | \def\tT{{\tens{T}}} 336 | \def\tU{{\tens{U}}} 337 | \def\tV{{\tens{V}}} 338 | \def\tW{{\tens{W}}} 339 | \def\tX{{\tens{X}}} 340 | \def\tY{{\tens{Y}}} 341 | \def\tZ{{\tens{Z}}} 342 | 343 | 344 | % Graph 345 | \def\gA{{\mathcal{A}}} 346 | \def\gB{{\mathcal{B}}} 347 | \def\gC{{\mathcal{C}}} 348 | \def\gD{{\mathcal{D}}} 349 | \def\gE{{\mathcal{E}}} 350 | \def\gF{{\mathcal{F}}} 351 | \def\gG{{\mathcal{G}}} 352 | \def\gH{{\mathcal{H}}} 353 | \def\gI{{\mathcal{I}}} 354 | \def\gJ{{\mathcal{J}}} 355 | \def\gK{{\mathcal{K}}} 356 | \def\gL{{\mathcal{L}}} 357 | \def\gM{{\mathcal{M}}} 358 | \def\gN{{\mathcal{N}}} 359 | \def\gO{{\mathcal{O}}} 360 | \def\gP{{\mathcal{P}}} 361 | \def\gQ{{\mathcal{Q}}} 362 | \def\gR{{\mathcal{R}}} 363 | \def\gS{{\mathcal{S}}} 364 | \def\gT{{\mathcal{T}}} 365 | \def\gU{{\mathcal{U}}} 366 | \def\gV{{\mathcal{V}}} 367 | \def\gW{{\mathcal{W}}} 368 | \def\gX{{\mathcal{X}}} 369 | \def\gY{{\mathcal{Y}}} 370 | \def\gZ{{\mathcal{Z}}} 371 | 372 | % Sets 373 | \def\sA{{\mathbb{A}}} 374 | \def\sB{{\mathbb{B}}} 375 | \def\sC{{\mathbb{C}}} 376 | \def\sD{{\mathbb{D}}} 377 | % Don't use a set called E, because this would be the same as our symbol 378 | % for expectation. 
379 | \def\sF{{\mathbb{F}}} 380 | \def\sG{{\mathbb{G}}} 381 | \def\sH{{\mathbb{H}}} 382 | \def\sI{{\mathbb{I}}} 383 | \def\sJ{{\mathbb{J}}} 384 | \def\sK{{\mathbb{K}}} 385 | \def\sL{{\mathbb{L}}} 386 | \def\sM{{\mathbb{M}}} 387 | \def\sN{{\mathbb{N}}} 388 | \def\sO{{\mathbb{O}}} 389 | \def\sP{{\mathbb{P}}} 390 | \def\sQ{{\mathbb{Q}}} 391 | \def\sR{{\mathbb{R}}} 392 | \def\sS{{\mathbb{S}}} 393 | \def\sT{{\mathbb{T}}} 394 | \def\sU{{\mathbb{U}}} 395 | \def\sV{{\mathbb{V}}} 396 | \def\sW{{\mathbb{W}}} 397 | \def\sX{{\mathbb{X}}} 398 | \def\sY{{\mathbb{Y}}} 399 | \def\sZ{{\mathbb{Z}}} 400 | 401 | % Entries of a matrix 402 | \def\emLambda{{\Lambda}} 403 | \def\emA{{A}} 404 | \def\emB{{B}} 405 | \def\emC{{C}} 406 | \def\emD{{D}} 407 | \def\emE{{E}} 408 | \def\emF{{F}} 409 | \def\emG{{G}} 410 | \def\emH{{H}} 411 | \def\emI{{I}} 412 | \def\emJ{{J}} 413 | \def\emK{{K}} 414 | \def\emL{{L}} 415 | \def\emM{{M}} 416 | \def\emN{{N}} 417 | \def\emO{{O}} 418 | \def\emP{{P}} 419 | \def\emQ{{Q}} 420 | \def\emR{{R}} 421 | \def\emS{{S}} 422 | \def\emT{{T}} 423 | \def\emU{{U}} 424 | \def\emV{{V}} 425 | \def\emW{{W}} 426 | \def\emX{{X}} 427 | \def\emY{{Y}} 428 | \def\emZ{{Z}} 429 | \def\emSigma{{\Sigma}} 430 | 431 | % entries of a tensor 432 | % Same font as tensor, without \bm wrapper 433 | \newcommand{\etens}[1]{\mathsfit{#1}} 434 | \def\etLambda{{\etens{\Lambda}}} 435 | \def\etA{{\etens{A}}} 436 | \def\etB{{\etens{B}}} 437 | \def\etC{{\etens{C}}} 438 | \def\etD{{\etens{D}}} 439 | \def\etE{{\etens{E}}} 440 | \def\etF{{\etens{F}}} 441 | \def\etG{{\etens{G}}} 442 | \def\etH{{\etens{H}}} 443 | \def\etI{{\etens{I}}} 444 | \def\etJ{{\etens{J}}} 445 | \def\etK{{\etens{K}}} 446 | \def\etL{{\etens{L}}} 447 | \def\etM{{\etens{M}}} 448 | \def\etN{{\etens{N}}} 449 | \def\etO{{\etens{O}}} 450 | \def\etP{{\etens{P}}} 451 | \def\etQ{{\etens{Q}}} 452 | \def\etR{{\etens{R}}} 453 | \def\etS{{\etens{S}}} 454 | \def\etT{{\etens{T}}} 455 | \def\etU{{\etens{U}}} 456 | \def\etV{{\etens{V}}} 457 | \def\etW{{\etens{W}}} 458 | \def\etX{{\etens{X}}} 459 | \def\etY{{\etens{Y}}} 460 | \def\etZ{{\etens{Z}}} 461 | 462 | % The true underlying data generating distribution 463 | \newcommand{\pdata}{p_{\rm{data}}} 464 | % The empirical distribution defined by the training set 465 | \newcommand{\ptrain}{\hat{p}_{\rm{data}}} 466 | \newcommand{\Ptrain}{\hat{P}_{\rm{data}}} 467 | % The model distribution 468 | \newcommand{\pmodel}{p_{\rm{model}}} 469 | \newcommand{\Pmodel}{P_{\rm{model}}} 470 | \newcommand{\ptildemodel}{\tilde{p}_{\rm{model}}} 471 | % Stochastic autoencoder distributions 472 | \newcommand{\pencode}{p_{\rm{encoder}}} 473 | \newcommand{\pdecode}{p_{\rm{decoder}}} 474 | \newcommand{\precons}{p_{\rm{reconstruct}}} 475 | 476 | \newcommand{\laplace}{\mathrm{Laplace}} % Laplace distribution 477 | 478 | \newcommand{\E}{\mathbb{E}} 479 | \newcommand{\Ls}{\mathcal{L}} 480 | \newcommand{\R}{\mathbb{R}} 481 | \newcommand{\emp}{\tilde{p}} 482 | \newcommand{\lr}{\alpha} 483 | \newcommand{\reg}{\lambda} 484 | \newcommand{\rect}{\mathrm{rectifier}} 485 | \newcommand{\softmax}{\mathrm{softmax}} 486 | \newcommand{\sigmoid}{\sigma} 487 | \newcommand{\softplus}{\zeta} 488 | \newcommand{\KL}{D_{\mathrm{KL}}} 489 | \newcommand{\Var}{\mathrm{Var}} 490 | \newcommand{\standarderror}{\mathrm{SE}} 491 | \newcommand{\Cov}{\mathrm{Cov}} 492 | % Wolfram Mathworld says $L^2$ is for function spaces and $\ell^2$ is for vectors 493 | % But then they seem to use $L^2$ for vectors throughout the site, and so does 494 | % wikipedia. 
495 | \newcommand{\normlzero}{L^0} 496 | \newcommand{\normlone}{L^1} 497 | \newcommand{\normltwo}{L^2} 498 | \newcommand{\normlp}{L^p} 499 | \newcommand{\normmax}{L^\infty} 500 | 501 | \newcommand{\parents}{Pa} % See usage in notation.tex. Chosen to match Daphne's book. 502 | 503 | \DeclareMathOperator*{\argmax}{arg\,max} 504 | \DeclareMathOperator*{\argmin}{arg\,min} 505 | 506 | \DeclareMathOperator{\sign}{sign} 507 | \DeclareMathOperator{\Tr}{Tr} 508 | \let\ab\allowbreak 509 | --------------------------------------------------------------------------------