├── data └── .gitkeep ├── src ├── pydybm │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ └── metrics.py │ ├── tools │ │ └── __init__.py │ ├── reinforce │ │ ├── __init__.py │ │ └── agent.py │ ├── time_series │ │ └── __init__.py │ ├── docs │ │ ├── pydybm.base.sgd.rst │ │ ├── pydybm.base.metrics.rst │ │ ├── pydybm.rst │ │ ├── pydybm.base.generator.rst │ │ ├── pydybm.reinforce.agent.rst │ │ ├── pydybm.time_series.esn.rst │ │ ├── pydybm.base.rst │ │ ├── pydybm.reinforce.bandit.rst │ │ ├── pydybm.time_series.dybm.rst │ │ ├── pydybm.reinforce.dysarsa.rst │ │ ├── pydybm.arraymath.dycupy.fifo.rst │ │ ├── pydybm.arraymath.dycupy.magma.rst │ │ ├── pydybm.arraymath.dynumpy.fifo.rst │ │ ├── pydybm.time_series.batch_dybm.rst │ │ ├── pydybm.arraymath.dycupy.random.rst │ │ ├── pydybm.reinforce.discrete_agent.rst │ │ ├── pydybm.arraymath.dycupy.data_queue.rst │ │ ├── pydybm.arraymath.dycupy.operations.rst │ │ ├── pydybm.time_series.functional_dybm.rst │ │ ├── pydybm.arraymath.dynumpy.data_queue.rst │ │ ├── pydybm.arraymath.dynumpy.operations.rst │ │ ├── pydybm.time_series.rnn_gaussian_dybm.rst │ │ ├── pydybm.time_series.time_series_model.rst │ │ ├── pydybm.time_series.vector_regression.rst │ │ ├── pydybm.reinforce.rst │ │ ├── pydybm.time_series.batch_gaussian_dybm.rst │ │ ├── pydybm.arraymath.rst │ │ ├── pydybm.arraymath.dynumpy.rst │ │ ├── pydybm.time_series.rst │ │ ├── pydybm.arraymath.dycupy.rst │ │ ├── index.rst │ │ └── conf.py │ ├── arraymath │ │ ├── dynumpy │ │ │ ├── fifo.py │ │ │ ├── data_queue.py │ │ │ ├── __init__.py │ │ │ └── operations.py │ │ ├── dycupy │ │ │ ├── fifo.py │ │ │ ├── random.py │ │ │ ├── data_queue.py │ │ │ ├── magma.py │ │ │ └── operations.py │ │ └── __init__.py │ └── Readme.md ├── tests │ ├── __init__.py │ ├── sgd_test.py │ ├── arraymath.py │ ├── bandit.py │ ├── tools_test.py │ ├── esn_test.py │ ├── complex_dybm_test.py │ ├── gaussian_dybm_test.py │ ├── dybm_esn_test.py │ ├── rnn_gaussian_dybm_test.py │ ├── dysarsa_test.py │ ├── generator_test.py │ 
└── vector_regression_with_hidden_test.py ├── jdybm │ ├── python │ │ ├── fig │ │ │ └── .gitkeep │ │ ├── MakeMirrorStepsFigure.py │ │ ├── MakeEvolutionFigure.py │ │ ├── MakeSingleFigure.py │ │ ├── MakeMirrorFigure.py │ │ └── MakeMusicalNote.py │ ├── img │ │ ├── dynamic_science.gif │ │ └── dynamic_evolution.gif │ ├── src │ │ └── com │ │ │ └── ibm │ │ │ └── stdp │ │ │ ├── FIFO.java │ │ │ ├── science │ │ │ ├── WriteNetwork.java │ │ │ ├── AnalyzeEvolution.java │ │ │ ├── AnalyzeMusician.java │ │ │ ├── AnalyzeSingleSCIENCE.java │ │ │ ├── AnalyzeMirrorSCIENCE.java │ │ │ ├── EvolutionExperiment.java │ │ │ ├── SingleExperiment3.java │ │ │ └── MusicianExperiment.java │ │ │ ├── BitMapFont.java │ │ │ ├── BitMapMusic.java │ │ │ ├── Parameter.java │ │ │ ├── BitMapImage.java │ │ │ ├── IntFIFO.java │ │ │ ├── BinaryPattern.java │ │ │ └── BinaryFIFO.java │ ├── run.sh │ ├── run.bat │ ├── build.xml │ └── Readme.md └── tasks │ ├── dydpp │ ├── data │ │ └── .gitkeep │ ├── requirements.txt │ ├── README.md │ └── plot.py │ └── icml17 │ ├── Readme.md │ ├── convert.py │ ├── plot_icml17_fig4.py │ └── plot_icml17_fig3.py ├── setup.py ├── examples ├── Readme.md ├── reinforce │ └── DySARSA_discreteAgent_Demo.py └── time_series │ └── GPUDemo.ipynb ├── .gitignore └── Readme.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/pydybm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/pydybm/base/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /src/pydybm/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/jdybm/python/fig/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/pydybm/reinforce/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/tasks/dydpp/data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/pydybm/time_series/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/jdybm/img/dynamic_science.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ibm-research-tokyo/dybm/HEAD/src/jdybm/img/dynamic_science.gif -------------------------------------------------------------------------------- /src/jdybm/img/dynamic_evolution.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ibm-research-tokyo/dybm/HEAD/src/jdybm/img/dynamic_evolution.gif -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.base.sgd.rst: -------------------------------------------------------------------------------- 1 | pydybm\.base\.sgd module 2 | ======================== 3 | 4 | .. 
automodule:: pydybm.base.sgd 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.base.metrics.rst: -------------------------------------------------------------------------------- 1 | pydybm\.base\.metrics module 2 | ============================ 3 | 4 | .. automodule:: pydybm.base.metrics 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.rst: -------------------------------------------------------------------------------- 1 | pydybm package 2 | ============== 3 | 4 | ---------- 5 | 6 | .. toctree:: 7 | 8 | pydybm.arraymath 9 | pydybm.base 10 | pydybm.reinforce 11 | pydybm.time_series 12 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.base.generator.rst: -------------------------------------------------------------------------------- 1 | pydybm\.base\.generator module 2 | ============================== 3 | 4 | .. automodule:: pydybm.base.generator 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.reinforce.agent.rst: -------------------------------------------------------------------------------- 1 | pydybm\.reinforce\.agent module 2 | =============================== 3 | 4 | .. automodule:: pydybm.reinforce.agent 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.time_series.esn.rst: -------------------------------------------------------------------------------- 1 | pydybm\.time\_series\.esn module 2 | ================================ 3 | 4 | .. 
automodule:: pydybm.time_series.esn 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.base.rst: -------------------------------------------------------------------------------- 1 | pydybm\.base package 2 | ==================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | 9 | pydybm.base.generator 10 | pydybm.base.metrics 11 | pydybm.base.sgd 12 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.reinforce.bandit.rst: -------------------------------------------------------------------------------- 1 | pydybm\.reinforce\.bandit module 2 | ================================ 3 | 4 | .. automodule:: pydybm.reinforce.bandit 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.time_series.dybm.rst: -------------------------------------------------------------------------------- 1 | pydybm\.time\_series\.dybm module 2 | ================================= 3 | 4 | .. automodule:: pydybm.time_series.dybm 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.reinforce.dysarsa.rst: -------------------------------------------------------------------------------- 1 | pydybm\.reinforce\.dysarsa module 2 | ================================= 3 | 4 | .. automodule:: pydybm.reinforce.dysarsa 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.arraymath.dycupy.fifo.rst: -------------------------------------------------------------------------------- 1 | pydybm\.arraymath\.dycupy\.fifo module 2 | ====================================== 3 | 4 | .. 
automodule:: pydybm.arraymath.dycupy.fifo 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.arraymath.dycupy.magma.rst: -------------------------------------------------------------------------------- 1 | pydybm\.arraymath\.dycupy\.magma module 2 | ======================================= 3 | 4 | .. automodule:: pydybm.arraymath.dycupy.magma 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.arraymath.dynumpy.fifo.rst: -------------------------------------------------------------------------------- 1 | pydybm\.arraymath\.dynumpy\.fifo module 2 | ======================================= 3 | 4 | .. automodule:: pydybm.arraymath.dynumpy.fifo 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.time_series.batch_dybm.rst: -------------------------------------------------------------------------------- 1 | pydybm\.time\_series\.batch\_dybm module 2 | ======================================== 3 | 4 | .. automodule:: pydybm.time_series.batch_dybm 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.arraymath.dycupy.random.rst: -------------------------------------------------------------------------------- 1 | pydybm\.arraymath\.dycupy\.random module 2 | ======================================== 3 | 4 | .. 
automodule:: pydybm.arraymath.dycupy.random 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.reinforce.discrete_agent.rst: -------------------------------------------------------------------------------- 1 | pydybm\.reinforce\.discrete\_agent module 2 | ========================================= 3 | 4 | .. automodule:: pydybm.reinforce.discrete_agent 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.arraymath.dycupy.data_queue.rst: -------------------------------------------------------------------------------- 1 | pydybm\.arraymath\.dycupy\.data\_queue module 2 | ============================================= 3 | 4 | .. automodule:: pydybm.arraymath.dycupy.data_queue 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.arraymath.dycupy.operations.rst: -------------------------------------------------------------------------------- 1 | pydybm\.arraymath\.dycupy\.operations module 2 | ============================================ 3 | 4 | .. automodule:: pydybm.arraymath.dycupy.operations 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.time_series.functional_dybm.rst: -------------------------------------------------------------------------------- 1 | pydybm\.time\_series\.functional\_dybm module 2 | ============================================= 3 | 4 | .. 
automodule:: pydybm.time_series.functional_dybm 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.arraymath.dynumpy.data_queue.rst: -------------------------------------------------------------------------------- 1 | pydybm\.arraymath\.dynumpy\.data\_queue module 2 | ============================================== 3 | 4 | .. automodule:: pydybm.arraymath.dynumpy.data_queue 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.arraymath.dynumpy.operations.rst: -------------------------------------------------------------------------------- 1 | pydybm\.arraymath\.dynumpy\.operations module 2 | ============================================= 3 | 4 | .. automodule:: pydybm.arraymath.dynumpy.operations 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.time_series.rnn_gaussian_dybm.rst: -------------------------------------------------------------------------------- 1 | pydybm\.time\_series\.rnn\_gaussian\_dybm module 2 | ================================================ 3 | 4 | .. automodule:: pydybm.time_series.rnn_gaussian_dybm 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.time_series.time_series_model.rst: -------------------------------------------------------------------------------- 1 | pydybm\.time\_series\.time\_series\_model module 2 | ================================================ 3 | 4 | .. 
automodule:: pydybm.time_series.time_series_model 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.time_series.vector_regression.rst: -------------------------------------------------------------------------------- 1 | pydybm\.time\_series\.vector\_regression module 2 | =============================================== 3 | 4 | .. automodule:: pydybm.time_series.vector_regression 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.reinforce.rst: -------------------------------------------------------------------------------- 1 | pydybm\.reinforce package 2 | ========================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | 9 | pydybm.reinforce.agent 10 | pydybm.reinforce.bandit 11 | pydybm.reinforce.discrete_agent 12 | pydybm.reinforce.dysarsa 13 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.time_series.batch_gaussian_dybm.rst: -------------------------------------------------------------------------------- 1 | pydybm\.time\_series\.batch\_gaussian\_dybm module 2 | ================================================== 3 | 4 | .. 
automodule:: pydybm.time_series.batch_gaussian_dybm 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /src/tasks/dydpp/requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | joblib==1.2.0 3 | kiwisolver==1.1.0 4 | matplotlib==3.1.1 5 | numpy==1.22.0 6 | -e git+https://github.com/ibm-research-tokyo/dybm.git@e209f68016f717d792756d3aea94cac0cf03ac26#egg=pydybm 7 | pyparsing==2.4.2 8 | python-dateutil==2.8.0 9 | scikit-learn==0.21.3 10 | scipy==1.10.0 11 | six==1.12.0 12 | sklearn==0.0 13 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.arraymath.rst: -------------------------------------------------------------------------------- 1 | pydybm\.arraymath package 2 | ========================= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | pydybm.arraymath.dycupy 10 | pydybm.arraymath.dynumpy 11 | 12 | Module contents 13 | --------------- 14 | 15 | .. automodule:: pydybm.arraymath 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.arraymath.dynumpy.rst: -------------------------------------------------------------------------------- 1 | pydybm\.arraymath\.dynumpy package 2 | ================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | 9 | pydybm.arraymath.dynumpy.data_queue 10 | pydybm.arraymath.dynumpy.fifo 11 | pydybm.arraymath.dynumpy.operations 12 | 13 | Module contents 14 | --------------- 15 | 16 | .. 
automodule:: pydybm.arraymath.dynumpy 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.time_series.rst: -------------------------------------------------------------------------------- 1 | pydybm\.time\_series package 2 | ============================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | 9 | pydybm.time_series.batch_dybm 10 | pydybm.time_series.batch_gaussian_dybm 11 | pydybm.time_series.dybm 12 | pydybm.time_series.esn 13 | pydybm.time_series.functional_dybm 14 | pydybm.time_series.rnn_gaussian_dybm 15 | pydybm.time_series.time_series_model 16 | pydybm.time_series.vector_regression 17 | -------------------------------------------------------------------------------- /src/pydybm/docs/pydybm.arraymath.dycupy.rst: -------------------------------------------------------------------------------- 1 | pydybm\.arraymath\.dycupy package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | 9 | pydybm.arraymath.dycupy.data_queue 10 | pydybm.arraymath.dycupy.fifo 11 | pydybm.arraymath.dycupy.magma 12 | pydybm.arraymath.dycupy.operations 13 | pydybm.arraymath.dycupy.random 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: pydybm.arraymath.dycupy 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ setup file for DyBM. """ 4 | 5 | __author__ = "Hiroshi Kajino" 6 | __version__ = "3.2" 7 | __date__ = "December 22, 2016" 8 | __copyright__ = "(C) Copyright IBM Corp. 
2016" 9 | 10 | from setuptools import setup, find_packages 11 | import sys 12 | import os 13 | 14 | setup( 15 | name = "pydybm", 16 | version = "3.2.1", 17 | author = "DyBM developers at IBM Research - Tokyo", 18 | package_dir = {"": "src"}, 19 | packages = find_packages(), 20 | test_suite = "tests", 21 | ) 22 | -------------------------------------------------------------------------------- /src/pydybm/arraymath/dynumpy/fifo.py: -------------------------------------------------------------------------------- 1 | """``numpy``-based implementation of FIFO queues 2 | """ 3 | 4 | __author__ = "Taro Sekiyama" 5 | __copyright__ = "(C) Copyright IBM Corp. 2016" 6 | 7 | 8 | import numpy 9 | import collections 10 | 11 | 12 | class FIFO: 13 | def __init__(self, shape): 14 | self._fifo = collections.deque(numpy.zeros(shape)) 15 | self._arr = numpy.array(self._fifo) 16 | 17 | def __len__(self): 18 | return len(self._fifo) 19 | 20 | def push(self, a): 21 | b = self._fifo.pop() 22 | self._fifo.appendleft(a) 23 | self._arr = numpy.array(self._fifo) 24 | return b 25 | 26 | def to_array(self): 27 | return self._arr 28 | -------------------------------------------------------------------------------- /examples/Readme.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | Here we provide examples of time-series learning and reinfrocement learning with DyBMs. 4 | 5 | ## Prerequisites 6 | 7 | Some of the esamples rely on `__gym__` and `__matplotlib__`. 8 | 9 | ``` 10 | pip install gym gym[atari] matplotlib 11 | ``` 12 | 13 | ## Time-series learning 14 | 15 | Examples are provided under `time-series` 16 | 17 | ## Reinforcement learning 18 | 19 | Examples are provided under `reinforce` 20 | 21 | `DySARSA_discreteAgent_Demo.py` demonstrates an example for using a model based on `pyDyBM.reinforce.DySARSA` for learning to play Atari games directly from screen pixels, using the Arcade Learning Environment. 
This example uses the atari games environment as provided by OpenAI Gym (https://gym.openai.com/). 22 | 23 | -------------------------------------------------------------------------------- /src/pydybm/arraymath/dynumpy/data_queue.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright IBM Corp. 2016 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | # not use this file except in compliance with the License. You may obtain 5 | # a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | """``numpy``-based implementation of data queues 16 | """ 17 | 18 | __author__ = "Taro Sekiyama" 19 | 20 | 21 | class DataQueue: 22 | def __init__(self, data): 23 | self._data = data 24 | 25 | def __iter__(self): 26 | return enumerate(self._data) 27 | -------------------------------------------------------------------------------- /src/pydybm/arraymath/dycupy/fifo.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright IBM Corp. 2016 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | # not use this file except in compliance with the License. You may obtain 5 | # a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | """``cupy``-based implementation of FIFO queues 16 | """ 17 | 18 | __author__ = "Taro Sekiyama" 19 | 20 | 21 | import cupy 22 | 23 | 24 | class FIFO: 25 | def __init__(self, shape): 26 | self._fifo = cupy.zeros(shape) 27 | 28 | def __len__(self): 29 | return len(self._fifo) 30 | 31 | def push(self, a): 32 | b = self._fifo[-1, :] 33 | self._fifo = cupy.roll(self._fifo, 1, axis=0) 34 | self._fifo[0, :] = a 35 | return b 36 | 37 | def to_array(self): 38 | return self._fifo 39 | -------------------------------------------------------------------------------- /src/tests/sgd_test.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright IBM Corp. 2017 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | # not use this file except in compliance with the License. You may obtain 5 | # a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 
14 | 15 | import unittest 16 | 17 | 18 | class TestCaseSGD(unittest.TestCase): 19 | 20 | def test_vSGD(self): 21 | from collections import defaultdict 22 | from pydybm.base.sgd import vSGD 23 | from pydybm.time_series.dybm import LinearDyBM 24 | from pydybm.base.generator import Uniform 25 | 26 | gen = Uniform(length=1000, low=0, high=1, dim=1) 27 | 28 | dybm = LinearDyBM(in_dim=1, delay=1, decay_rates=[], SGD=vSGD(hessian=defaultdict(lambda: 1))) 29 | dybm.learn(gen) 30 | 31 | self.assertEqual(float(dybm.variables['b']), 0.4932564368799297) 32 | 33 | 34 | if __name__ == '__main__': 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /src/tasks/dydpp/README.md: -------------------------------------------------------------------------------- 1 | (C) Copyright IBM Corp. 2019 2 | 3 | # Dynamic determinantal point processes 4 | 5 | DyDPP.py contains core methods of learning and inference with dynamic determinantal point processes. This directory also contains scripts that can be used to reproduce the experimental results and figures reported in the following paper: 6 | 7 | T. Osogami, R. Raymond, A. Goel, T. Shirai, T. Maehara, "Dynamic Determinantal Point Processes," AAAI-18 8 | 9 | # Preparation 10 | 11 | Place the pickle file of data as follows: 12 | ``` 13 | data/XXXX.pickle 14 | ``` 15 | We assume that the pickle file is formatted in the same way as the datasets published at http://www-etud.iro.umontreal.ca/~boulanni/icml2012. Note that these datasets are provided by a third party under the terms and conditions specified by the third party. 16 | 17 | The experiments run with Python 3. Install the dependencies by 18 | ``` 19 | pip install requirements.txt 20 | ``` 21 | 22 | # Run experiments 23 | 24 | ``` 25 | python experiment.py data/XXXX.pickle 26 | ``` 27 | 28 | # Plot figures 29 | 30 | ``` 31 | python plot.py XXX 32 | ``` 33 | where "XXX" is the first three letters of the pickle file (XXXX.pickle). 
34 | 35 | Note that the generated figures are similar but not completely identical to those in the paper. We have modified the code to run with Python 3, while the original experiments were run with Python 2. 36 | -------------------------------------------------------------------------------- /src/tests/arraymath.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright IBM Corp. 2017 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | # not use this file except in compliance with the License. You may obtain 5 | # a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | import pydybm.arraymath as amath 16 | import pydybm.arraymath.dynumpy as dynumpy 17 | 18 | 19 | class NumpyTestMixin(object): 20 | def setUp(self): 21 | amath.setup(dynumpy) 22 | print('\nnumpy test') 23 | setup = getattr(super(NumpyTestMixin, self), 'setUp', None) 24 | if setup is not None: 25 | setup() 26 | 27 | 28 | class CupyTestMixin(object): 29 | def setUp(self): 30 | try: 31 | import pydybm.arraymath.dycupy as dycupy 32 | amath.setup(dycupy) 33 | print('\ncupy test') 34 | setup = getattr(super(CupyTestMixin, self), 'setUp', None) 35 | if setup is not None: 36 | setup() 37 | except ImportError: 38 | print('cupy test skipped') 39 | self.skipTest( 40 | 'cupy is not installed and tests with cupy are passed') 41 | -------------------------------------------------------------------------------- /src/jdybm/src/com/ibm/stdp/FIFO.java: -------------------------------------------------------------------------------- 1 | // (C) Copyright IBM Corp. 
2015 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | // not use this file except in compliance with the License. You may obtain 5 | // a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations 13 | // under the License. 14 | 15 | package com.ibm.stdp; 16 | 17 | /* 18 | * FIFO queue 19 | * @author osogami 20 | */ 21 | public interface FIFO { 22 | 23 | public boolean CONNECT_ZERO = false; 24 | 25 | /** 26 | * Initialize the values of the FIFO queue 27 | */ 28 | public void init(); 29 | 30 | /** 31 | * Add a value, x, to the tail of the FIFO queue, and 32 | * remove the value from the head of the FIFO queue 33 | * @param x 34 | * @return the value removed from the head 35 | */ 36 | public boolean offer(boolean x); 37 | 38 | /** 39 | * Get the value of the k-th beta eligibility trace of the FIFO queue 40 | * @param k 41 | * @return 42 | */ 43 | public double getBeta(int k); 44 | 45 | /** 46 | * Store the values of the FIFO queue 47 | */ 48 | public void store(); 49 | 50 | /** 51 | * Restore the values of the FIFO queue 52 | */ 53 | public void restore(); 54 | } 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | 
.installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | .static_storage/ 56 | .media/ 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | _build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | -------------------------------------------------------------------------------- /src/pydybm/arraymath/dycupy/random.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright IBM Corp. 2016 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | # not use this file except in compliance with the License. 
You may obtain 5 | # a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | """``cupy``-based implementation of the random module 16 | """ 17 | 18 | __author__ = "Taro Sekiyama" 19 | 20 | 21 | import numpy.random as r 22 | import cupy as cp 23 | 24 | 25 | def _to_gpu(a): 26 | arr = cp.empty_like(a) 27 | arr.set(a) 28 | return arr 29 | 30 | 31 | class RandomState: 32 | def __init__(self, seed): 33 | self._random = r.RandomState(seed) 34 | 35 | def uniform(self, low=0.0, high=1.0, size=None): 36 | return _to_gpu(self._random.uniform(low=low, high=high, size=size)) 37 | 38 | def normal(self, loc=0.0, scale=1.0, size=None): 39 | return _to_gpu(self._random.normal(loc=loc, scale=scale, size=size)) 40 | 41 | def get_state(self): 42 | return self._random.get_state() 43 | 44 | def set_state(self, *args): 45 | return self._random.set_state(*args) 46 | 47 | def rand(self, *args): 48 | return _to_gpu(self._random.rand(*args)) 49 | 50 | 51 | seed = r.seed 52 | 53 | 54 | def normal(loc=0.0, scale=1.0, size=None): 55 | return _to_gpu(r.normal(loc=loc, scale=scale, size=size)) 56 | 57 | 58 | def uniform(low=0.0, high=1.0, size=None): 59 | return _to_gpu(r.uniform(low=low, high=high, size=size)) 60 | 61 | 62 | def rand(*args): 63 | return _to_gpu(r.rand(*args)) 64 | 65 | 66 | def randn(*args): 67 | return _to_gpu(r.randn(*args)) 68 | 69 | 70 | def random(size=None): 71 | return _to_gpu(r.random(size=size)) 72 | -------------------------------------------------------------------------------- /src/jdybm/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 
| # (C) Copyright IBM Corp. 2017 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | # not use this file except in compliance with the License. You may obtain 6 | # a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13 | # License for the specific language governing permissions and limitations 14 | # under the License. 15 | cd "$(dirname "$0")" || exit 1 # run from the jdybm directory so the relative paths below resolve, as run.bat does with pushd 16 | # Run experiments ("$1" is quoted throughout so a missing argument does not break the [ ] tests) 17 | if [ "$1" = "1" ]; then 18 | java -classpath lib/commons-math.jar:bin com.ibm.stdp.science.SingleExperiment3 SCIENCE 3 3 9 19 | elif [ "$1" = "2" ]; then 20 | java -classpath lib/commons-math.jar:bin com.ibm.stdp.science.MirrorExperiment3 SCIENCE 5 3 3 9 21 | elif [ "$1" = "3" ]; then 22 | java -classpath lib/commons-math.jar:bin com.ibm.stdp.science.EvolutionExperiment 3 3 9 23 | elif [ "$1" = "4" ]; then 24 | java -classpath lib/commons-math.jar:bin com.ibm.stdp.science.MusicianExperiment 3 3 9 25 | fi 26 | # NOTE(review): MirrorExperiment3 and python/MakeSingleNetwork.py do not appear in this repository's file listing -- confirm they exist 27 | # Analyze results 28 | if [ "$1" = "1" ]; then 29 | java -classpath bin com.ibm.stdp.science.AnalyzeSingleSCIENCE > python/single.py 30 | java -classpath bin com.ibm.stdp.science.WriteNetwork 31 | elif [ "$1" = "2" ]; then 32 | java -classpath bin com.ibm.stdp.science.AnalyzeMirrorSCIENCE > python/mirror.py 33 | elif [ "$1" = "3" ]; then 34 | java -classpath bin com.ibm.stdp.science.AnalyzeEvolution > python/evolution.py 35 | elif [ "$1" = "4" ]; then 36 | java -classpath bin com.ibm.stdp.science.AnalyzeMusician > python/musician.py 37 | fi 38 | 39 | # Draw figures 40 | cd python || exit 1 41 | if [ "$1" = "1" ]; then 42 | python MakeSingleFigure.py 43 | python MakeSingleNetwork.py 44 | elif [ "$1" = "2" ]; then 45 | python MakeMirrorFigure.py 46 | python MakeMirrorStepsFigure.py 47 | elif [ "$1" = "3" ]; then 48 | python
MakeEvolutionFigure.py 49 | elif [ "$1" = "4" ]; then 50 | python MakeMusicalNote.py 51 | fi 52 | cd .. 53 | -------------------------------------------------------------------------------- /src/jdybm/run.bat: -------------------------------------------------------------------------------- 1 | rem (C) Copyright IBM Corp. 2015 2 | rem 3 | rem Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | rem not use this file except in compliance with the License. You may obtain 5 | rem a copy of the License at 6 | rem 7 | rem http://www.apache.org/licenses/LICENSE-2.0 8 | rem 9 | rem Unless required by applicable law or agreed to in writing, software 10 | rem distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | rem WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | rem License for the specific language governing permissions and limitations 13 | rem under the License. 14 | 15 | @echo off 16 | pushd "%~dp0" 17 | 18 | rem 19 | rem Run experiments (the argument is compared in quotes so an empty one is not a syntax error) 20 | rem 21 | if "%~1"=="1" ( 22 | echo Running 1 23 | java -classpath lib\commons-math.jar;bin com.ibm.stdp.science.SingleExperiment3 SCIENCE 3 3 9 24 | ) else if "%~1"=="2" ( 25 | java -classpath lib\commons-math.jar;bin com.ibm.stdp.science.MirrorExperiment3 SCIENCE 5 3 3 9 26 | ) else if "%~1"=="3" ( 27 | java -classpath lib\commons-math.jar;bin com.ibm.stdp.science.EvolutionExperiment 3 3 9 28 | ) else if "%~1"=="4" ( 29 | java -classpath lib\commons-math.jar;bin com.ibm.stdp.science.MusicianExperiment 3 3 9 30 | ) 31 | 32 | rem 33 | rem Analyze results 34 | rem 35 | 36 | if "%~1"=="1" ( 37 | java -classpath bin com.ibm.stdp.science.AnalyzeSingleSCIENCE > python\single.py 38 | java -classpath bin com.ibm.stdp.science.WriteNetwork 39 | ) else if "%~1"=="2" ( 40 | java -classpath bin com.ibm.stdp.science.AnalyzeMirrorSCIENCE > python\mirror.py 41 | ) else if "%~1"=="3" ( 42 | java -classpath bin com.ibm.stdp.science.AnalyzeEvolution > python\evolution.py 43 | ) else if "%~1"=="4" ( 44 | java
-classpath bin com.ibm.stdp.science.AnalyzeMusician > python\musician.py 45 | ) 46 | 47 | rem 48 | rem Draw figures 49 | rem 50 | 51 | pushd python 52 | if "%~1"=="1" ( 53 | python MakeSingleFigure.py 54 | python MakeSingleNetwork.py 55 | ) else if "%~1"=="2" ( 56 | python MakeMirrorFigure.py 57 | python MakeMirrorStepsFigure.py 58 | ) else if "%~1"=="3" ( 59 | python MakeEvolutionFigure.py 60 | ) else if "%~1"=="4" ( 61 | python MakeMusicalNote.py 62 | ) 63 | popd 64 | popd 65 | pause 66 | -------------------------------------------------------------------------------- /src/jdybm/src/com/ibm/stdp/science/WriteNetwork.java: -------------------------------------------------------------------------------- 1 | // (C) Copyright IBM Corp. 2015 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | // not use this file except in compliance with the License. You may obtain 5 | // a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations 13 | // under the License.
14 | 15 | package com.ibm.stdp.science; 16 | 17 | import java.io.BufferedWriter; 18 | import java.io.File; 19 | import java.io.FileInputStream; 20 | import java.io.FileWriter; 21 | import java.io.IOException; 22 | import java.io.ObjectInputStream; 23 | import java.io.PrintWriter; 24 | 25 | import com.ibm.stdp.Network; 26 | 27 | /** 28 | * Writing out the DyBM after training: every serialized Network NN*.bin in the 29 | * results directory is read back and written as a .csv file with the same base name 30 | * @author osogami 31 | */ 32 | public class WriteNetwork { 33 | public static void main(String[] args) { 34 | String directory = "Results/SCIENCE/Single3b/"; 35 | File dir = new File(directory); 36 | File[] files = dir.listFiles(); 37 | if(files == null){ // listFiles() returns null if the directory is missing or unreadable 38 | System.err.println("Directory not found: " + directory); 39 | return; 40 | } 41 | 42 | for(File file : files){ 43 | String filename = file.getName(); 44 | System.out.println(filename); 45 | if(!(filename.startsWith("NN") && filename.endsWith(".bin"))){ 46 | continue; 47 | } 48 | Network ann; 49 | // try-with-resources closes the stream even if readObject throws 50 | try (ObjectInputStream is = new ObjectInputStream(new FileInputStream(file))) { 51 | ann = (Network) is.readObject(); 52 | } catch (IOException | ClassNotFoundException e) { 53 | e.printStackTrace(); 54 | continue; // skip this file rather than abandoning the remaining ones 55 | } 56 | String[] split = filename.split("\\."); 57 | String outname = directory + split[0] + ".csv"; 58 | System.out.println(outname); 59 | try (PrintWriter pw = new PrintWriter(new BufferedWriter(new FileWriter(new File(outname))))) { 60 | pw.println(ann.toCSV()); 61 | if(pw.checkError()){ // PrintWriter reports write failures only through checkError() 62 | System.err.println("Failed to write " + outname); 63 | } 64 | } catch (IOException e) { 65 | e.printStackTrace(); 66 | } 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/pydybm/base/metrics.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright IBM Corp. 2016 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | # not use this file except in compliance with the License.
You may obtain 5 | # a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | 16 | __author__ = "Takayuki Osogami" 17 | 18 | 19 | from .. import arraymath as amath 20 | 21 | 22 | def MSE(y_true, y_pred): 23 | """ 24 | Mean squared error of a sequence of predicted vectors 25 | 26 | y_true : array, shape(L, N) 27 | y_pred : array, shape(L, N) 28 | 29 | mean of (dy_1^2 + ... + dy_N^2 ) over L pairs of vectors 30 | (y_true[i], y_pred[i]) 31 | """ 32 | MSE_each_coordinate = amath.mean_squared_error(y_true, y_pred, 33 | multioutput="raw_values") 34 | return amath.sum(MSE_each_coordinate) 35 | 36 | 37 | def RMSE(y_true, y_pred): 38 | """ 39 | Root mean squared error of a sequence of predicted vectors 40 | 41 | y_true: array, shape(L, N) 42 | y_pred: array, shape(L, N) 43 | 44 | squared root of the mean of (dy_1^2 + ... + dy_N^2 ) over L pairs of 45 | vectors (y_true[i], y_pred[i]) 46 | """ 47 | return amath.sqrt(MSE(y_true, y_pred)) 48 | 49 | 50 | def baseline_RMSE(init_pred, sequence): 51 | """ 52 | Baseline RMSE where predictions are made using the previous observation. 53 | 54 | Parameters 55 | ---------- 56 | init_pred : float or array, length n_dim 57 | prediction made at time step 0. 58 | sequence : list or generator 59 | time series used for performance evaluation. 60 | 61 | Returns 62 | ------- 63 | float 64 | RMSE of the baseline method, forecasting using the previous observation. 
65 | """ 66 | last_pattern = [init_pred] + list(sequence[:-1]) 67 | last_pattern = amath.array(last_pattern) 68 | baseline = RMSE(sequence, last_pattern) 69 | return baseline 70 | -------------------------------------------------------------------------------- /src/tests/bandit.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright IBM Corp. 2017 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | # not use this file except in compliance with the License. You may obtain 5 | # a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | import numpy as np 16 | import gym 17 | from gym import spaces 18 | from gym.utils import seeding 19 | 20 | 21 | class BanditEnv(gym.Env): 22 | """ 23 | Bandit environment base to allow agents to interact with the class n-armed bandit 24 | 25 | p_dist: 26 | A list of probabilities of the likelihood that a particular bandit will pay out 27 | r_dist: 28 | A list of either rewards (if number) or means and standard deviations (if list) 29 | of the payout that bandit has 30 | """ 31 | def __init__(self, p_dist, r_dist): 32 | 33 | self.p_dist = p_dist 34 | self.r_dist = r_dist 35 | 36 | self.n_bandits = len(p_dist) 37 | self.action_space = spaces.Discrete(self.n_bandits) 38 | self.observation_space = spaces.Discrete(1) 39 | 40 | self.seed() 41 | 42 | def seed(self, seed=None): 43 | self.np_random, seed = seeding.np_random(seed) 44 | 45 | return [seed] 46 | 47 | def step(self, action): 48 | assert self.action_space.contains(action) 49 | 50 | done = False 51 | 52 | if np.random.randn(1) > 
self.p_dist[action]: 53 | 54 | reward = 1 # self.r_dist[0] 55 | 56 | else: 57 | reward = -1 # self.r_dist[1] 58 | 59 | return 0.0, reward, done, {} 60 | 61 | def reset(self): 62 | return 0 63 | 64 | def render(self, mode='human', close=False): 65 | pass 66 | 67 | 68 | class FourArmedBandit(BanditEnv): 69 | """Stochastic version of four-armed bandit where bandit four pays out with highest reward""" 70 | def __init__(self): 71 | BanditEnv.__init__(self, p_dist=[0.2, 0.0, -0.2, -5], r_dist=[1, -1]) 72 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | (C) Copyright IBM Corp. 2016 2 | 3 | This library contains multiple implementations of dynamic Boltzmann machines (DyBMs) and relevant tools. The core of this library is __pydybm__, a Python implementation for learning time-series with DyBMs (see [src/pydybm/Readme.md](src/pydybm/Readme.md)), and __jdybm__, a Java implementation used in the first publication of the DyBM in [www.nature.com/articles/srep14149](http://www.nature.com/articles/srep14149) (see [src/jdybm/Readme.md](src/jdybm/Readme.md)). 4 | 5 | ## What is DyBM? 6 | 7 | The DyBM is an IBM’s artificial neural network, proposed in [www.nature.com/articles/srep14149](http://www.nature.com/articles/srep14149), that is trained via biologically plausible spike-timing dependent plasticity (STDP) in an online and distributed manner for prediction, anomaly detection, classification, reinforcement learning, and other tasks with time-series. DyBM’s learning time per step is independent of the length of (the dependency in) the time-series under consideration (i.e., local in time), whereas existing recurrent neural networks including long short term memory (LSTM) perform, at each step, backpropagation through time whose computational complexity grows linearly with respect to that length. 
DyBM’s computation for learning, prediction, sampling, and other operations can all be performed in a distributed manner, and the computational complexity of the operation at each unit is independent of the size of the network (i.e., local in space). 8 | 9 | DyBM stands for Dynamic Boltzmann Machine. It is abbreviated as DyBM instead of DBM, because DBM already denotes the Deep Boltzmann Machine in the community. 10 | 11 | ## Directory structure 12 | 13 | Here we describe some of the important directories in this library. 14 | 15 | - `src/`: Source code. 16 | - `src/pydybm/`: The __pydybm__ package. See [src/pydybm/Readme.md](src/pydybm/Readme.md). 17 | - `src/jdybm/`: A Java implementation, used for the experiments in [www.nature.com/articles/srep14149](http://www.nature.com/articles/srep14149). See [Readme.md for jdybm](src/jdybm/Readme.md). 18 | - `examples/`: Examples of using __pydybm__. Run `jupyter notebook` in this directory to see the examples. See also [src/pydybm/Readme.md](src/pydybm/Readme.md). 19 | - `data/`: Place datasets here. 20 | 21 | -------------------------------------------------------------------------------- /src/jdybm/src/com/ibm/stdp/science/AnalyzeEvolution.java: -------------------------------------------------------------------------------- 1 | // (C) Copyright IBM Corp. 2015 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | // not use this file except in compliance with the License. You may obtain 5 | // a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the 12 | // License for the specific language governing permissions and limitations 13 | // under the License. 14 | 15 | 16 | package com.ibm.stdp.science; 17 | 18 | import java.io.File; 19 | import java.io.FileInputStream; 20 | import java.io.IOException; 21 | import java.io.ObjectInputStream; 22 | import java.util.Arrays; 23 | 24 | import com.ibm.stdp.BitMapImage; 25 | import com.ibm.stdp.Network; 26 | import com.ibm.stdp.TimeSeries; 27 | 28 | /** 29 | * Analysis of the results of the Evolution experiment: prints, as Python source on stdout, 30 | * the freeRun output (twice the training length) of every network NN*param*.bin in Results/Evolution/ as x[n] 31 | * @author osogami 32 | */ 33 | public class AnalyzeEvolution { 34 | 35 | public static void main(String[] args) { 36 | String directory = "Results/Evolution/"; 37 | 38 | BitMapImage image = new BitMapImage(); 39 | final TimeSeries trainingData = image.getEvolution(); 40 | 41 | System.out.println("x = dict()"); 42 | 43 | File dir = new File(directory); 44 | File[] files = dir.listFiles(); 45 | if(files == null){ // listFiles() returns null if the directory is missing or unreadable 46 | System.err.println("Directory not found: " + directory); // stderr keeps the generated Python on stdout clean 47 | return; 48 | } 49 | for(File file : files){ 50 | String filename = file.getName(); 51 | if(!filename.endsWith(".bin")){ 52 | continue; 53 | } 54 | String number = filename.split("param")[0].split("NN")[1]; 55 | int n = Integer.parseInt(number); 56 | String result = directory+filename; 57 | 58 | Network ann; 59 | 60 | // try-with-resources closes the stream even if readObject throws 61 | try (ObjectInputStream is = new ObjectInputStream(new FileInputStream(result))) { 62 | ann = (Network) is.readObject(); 63 | } catch (IOException | ClassNotFoundException e) { 64 | e.printStackTrace(); 65 | continue; // skip this network rather than abandoning the remaining ones 66 | } 67 | 68 | // prediction (store()/restore() bracket the free run, presumably so it leaves the trained state untouched) 69 | ann.store(); 70 | TimeSeries prediction = ann.freeRun(trainingData.size()*2,0); 71 | System.out.println("x["+n+"] = "+prediction.toPython()); 72 | ann.restore(); 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/jdybm/src/com/ibm/stdp/BitMapFont.java: -------------------------------------------------------------------------------- 1 | // (C) Copyright IBM Corp.
2015 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | // not use this file except in compliance with the License. You may obtain 5 | // a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations 13 | // under the License. 14 | 15 | package com.ibm.stdp; 16 | 17 | import java.util.HashMap; 18 | 19 | /** 20 | * Bitmap patterns of letters (upper-case A, B, C, E, I, L, M, N, P, R, S, T, U and larger lower-case i, b, m), looked up by character through map 21 | * @author osogami 22 | */ 23 | public class BitMapFont { 24 | private final static int[] A = {0,63,72,72,63}; // each int presumably encodes one glyph column as a bit pattern (127 = seven set bits), with a leading 0 as a blank spacer column -- TODO confirm with the consumer 25 | private final static int[] B = {0,127,73,73,54}; 26 | private final static int[] C = {0,62,65,65,34}; 27 | private final static int[] E = {0,127,73,73,65}; 28 | private final static int[] I = {0,65,127,65}; 29 | private final static int[] L = {0,127,1,1,1}; 30 | private final static int[] M = {0,127,16,8,16,127}; 31 | private final static int[] N = {0,127,16,8,4,127}; 32 | private final static int[] P = {0,127,72,72,48}; 33 | private final static int[] R = {0,127,76,74,49}; 34 | private final static int[] S = {0,49,73,73,70}; 35 | private final static int[] T = {0,64,64,127,64,64}; 36 | private final static int[] U = {0,126,1,1,126}; 37 | 38 | private final static int[] i = {0,7175,7175,7175,8191,8191,8191,8191,7175,7175,7175}; // lower-case glyphs use wider patterns (up to 8191 = 13 set bits) 39 | private final static int[] b = {0,7175,7175,7175,8191,8191,8191,8191,7399,7399,7399,7399,8191,8191,4030,4030,1820}; 40 | private final static int[] m = {0,7175,7175,7175,8191,8191,8191,8191,8064,4064,1016,254,127,254,1016,4064,8064,8191,8191,8191,8191,7175,7175,7175}; 41 | 42 | public final static HashMap map = new HashMap(); // Character -> int[] glyph, filled in the static block below; NOTE(review): the raw types may be an artifact of this dump stripping angle-bracket generics -- confirm against the source file 43 | static{ 44 | map.put('A',A); 45 | map.put('B',B); 46 | map.put('C',C); 47 |
map.put('E',E); 48 | map.put('I',I); 49 | map.put('L',L); 50 | map.put('M',M); 51 | map.put('N',N); 52 | map.put('P',P); 53 | map.put('S',S); 54 | map.put('T',T); 55 | map.put('U',U); 56 | map.put('R',R); 57 | 58 | map.put('i',i); 59 | map.put('b',b); 60 | map.put('m',m); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/jdybm/src/com/ibm/stdp/science/AnalyzeMusician.java: -------------------------------------------------------------------------------- 1 | // (C) Copyright IBM Corp. 2015 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | // not use this file except in compliance with the License. You may obtain 5 | // a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations 13 | // under the License.
14 | 15 | package com.ibm.stdp.science; 16 | 17 | import java.io.File; 18 | import java.io.FileInputStream; 19 | import java.io.IOException; 20 | import java.io.ObjectInputStream; 21 | import java.util.Arrays; 22 | 23 | import com.ibm.stdp.BitMapMusic; 24 | import com.ibm.stdp.Network; 25 | import com.ibm.stdp.TimeSeries; 26 | 27 | /** 28 | * Analysis of the results of the Musician experiment: prints, as Python source on stdout, the training data (y) 29 | * and the freeRun output of every network NN*param*.bin in Results/Musician/ (x[n]) 30 | * @author osogami 31 | */ 32 | public class AnalyzeMusician { 33 | 34 | public static void main(String[] args) { 35 | String directory = "Results/Musician/"; 36 | 37 | BitMapMusic music = new BitMapMusic(); 38 | final TimeSeries trainingData = music.getMusician(); 39 | 40 | System.out.println("y = "+trainingData.toPython()); 41 | System.out.println("x = dict()"); 42 | 43 | File dir = new File(directory); 44 | File[] files = dir.listFiles(); 45 | if(files == null){ // listFiles() returns null if the directory is missing or unreadable 46 | System.err.println("Directory not found: " + directory); // stderr keeps the generated Python on stdout clean 47 | return; 48 | } 49 | for(File file : files){ 50 | String filename = file.getName(); 51 | if(!filename.endsWith(".bin")){ 52 | continue; 53 | } 54 | String number = filename.split("param")[0].split("NN")[1]; 55 | int n = Integer.parseInt(number); 56 | String result = directory+filename; 57 | 58 | Network ann; 59 | 60 | // try-with-resources closes the stream even if readObject throws 61 | try (ObjectInputStream is = new ObjectInputStream(new FileInputStream(result))) { 62 | ann = (Network) is.readObject(); 63 | } catch (IOException | ClassNotFoundException e) { 64 | e.printStackTrace(); 65 | continue; // skip this network rather than abandoning the remaining ones 66 | } 67 | 68 | // prediction; 56 equals the number of rows in BitMapMusic.musician 69 | ann.store(); 70 | TimeSeries prediction = ann.freeRun(56,0); 71 | System.out.println("x["+n+"] = "+prediction.toPython()); 72 | ann.restore(); 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/pydybm/reinforce/agent.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright IBM Corp. 2016 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | # not use this file except in compliance with the License.
You may obtain 5 | # a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 14 | 15 | """ 16 | Base class for a reinforcement learning agent 17 | 18 | __Author__: Sakyasingha Dasgupta 19 | """ 20 | 21 | 22 | # Agent abstractions 23 | class Agent(object): 24 | """The Reinforcement learning agent BaseClass. 25 | 26 | The main methods that users of this class need to know are: 27 | fit 28 | fit_episode 29 | predict 30 | 31 | When implementing an environment, override the following methods 32 | in your subclass: 33 | _fit 34 | _fit_episode 35 | _predict 36 | 37 | """ 38 | 39 | def __new__(cls, *args, **kwargs): 40 | # We use __new__ since we want the env author to be able to 41 | # override __init__ without remembering to call super. 
42 | env = super(Agent, cls).__new__(cls) 43 | 44 | # Will be automatically set when creating an environment via 'make' 45 | return env 46 | 47 | # Override in ALL subclasses 48 | def _fit(self, test_every, test_num_eps, break_reward, render): 49 | raise NotImplementedError 50 | 51 | def _fit_episode(self, episode, test_every, test_num_eps, break_reward): 52 | raise NotImplementedError 53 | 54 | def _predict(self, render): 55 | raise NotImplementedError 56 | 57 | def fit(self, test_every, test_num_eps, break_reward, render): 58 | """ 59 | Run the training algorithm 60 | """ 61 | return self._fit(test_every, test_num_eps, break_reward, render) 62 | 63 | def fit_episode(self): 64 | """ 65 | Run the training algorithm for 1 step 66 | """ 67 | return self._fit_episode() 68 | 69 | def predict(self, render=True): 70 | """ 71 | Test the training algo 72 | """ 73 | return self._predict(render) 74 | -------------------------------------------------------------------------------- /src/tests/tools_test.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright IBM Corp. 2016 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | # not use this file except in compliance with the License. You may obtain 5 | # a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | # License for the specific language governing permissions and limitations 13 | # under the License. 
14 | 15 | 16 | __author__ = "Takayuki Osogami" 17 | 18 | 19 | import unittest 20 | import numpy as np 21 | from tests.arraymath import NumpyTestMixin, CupyTestMixin 22 | from pydybm.base.metrics import MSE, RMSE, baseline_RMSE 23 | 24 | 25 | class metricsTestCase(object): 26 | """ 27 | unit tests for pydybm.base.metrics, shared by the NumPy and CuPy TestCase classes below via multiple inheritance 28 | """ 29 | 30 | def setUp(self): 31 | self.L = 4 # sequence length (rows) 32 | self.N = 3 # vector dimension (columns) 33 | 34 | def tearDown(self): 35 | pass 36 | 37 | def testMSE(self): # MSE sums per-coordinate errors, so an offset of 1 in each of N coordinates gives N 38 | np.random.seed(0) 39 | y = np.random.random((self.L, self.N)) 40 | z = y + 1.0 41 | err = MSE(y, y) 42 | self.assertAlmostEqual(err, 0) 43 | err = MSE(y, z) 44 | self.assertAlmostEqual(err, self.N) 45 | 46 | def testRMSE(self): # RMSE is the square root of MSE, hence sqrt(N) for the same offset 47 | np.random.seed(0) 48 | y = np.random.random((self.L, self.N)) 49 | z = y + 1.0 50 | err = RMSE(y, y) 51 | self.assertAlmostEqual(err, 0) 52 | err = RMSE(y, z) 53 | self.assertAlmostEqual(err, np.sqrt(self.N)) 54 | 55 | def testBaseline_RMSE(self): # a constant sequence is predicted perfectly; for range(L) with init_pred -1 every step is off by exactly 1 56 | np.random.seed(0) 57 | pattern = np.random.random(self.N) 58 | y = [pattern] * self.L 59 | err = baseline_RMSE(pattern, y) 60 | self.assertAlmostEqual(err,0) 61 | y = range(self.L) 62 | init_pred = -1 63 | err = baseline_RMSE(init_pred, y) 64 | self.assertAlmostEqual(err, 1) 65 | 66 | 67 | class metricsTestCaseNumpy(NumpyTestMixin, 68 | metricsTestCase, 69 | unittest.TestCase): 70 | pass 71 | 72 | 73 | class metricsTestCaseCupy(CupyTestMixin, 74 | metricsTestCase, 75 | unittest.TestCase): 76 | pass 77 | 78 | 79 | if __name__ == "__main__": 80 | 81 | unittest.main() 82 | 83 | -------------------------------------------------------------------------------- /src/jdybm/src/com/ibm/stdp/BitMapMusic.java: -------------------------------------------------------------------------------- 1 | // (C) Copyright IBM Corp. 2015 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | // not use this file except in compliance with the License.
You may obtain 5 | // a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations 13 | // under the License. 14 | 15 | package com.ibm.stdp; 16 | 17 | /* 18 | * Bitmap pattern of Ich bin ein Musikante 19 | * @author osogami 20 | */ 21 | public class BitMapMusic { 22 | static int[][] musician = { 23 | {0}, 24 | 25 | {3,-4}, 26 | {3,0}, 27 | {4,-1}, 28 | {4,0}, 29 | 30 | {5,-4}, 31 | {0}, 32 | {7,-2}, 33 | {6,0}, 34 | 35 | {5,0,-4}, 36 | {5,0,-4}, 37 | {4,0,-1}, 38 | {4,0,-1}, 39 | 40 | {3,0,-2}, 41 | {0,-2}, 42 | {0,-2}, 43 | {0}, 44 | 45 | {3,-4}, 46 | {0}, 47 | {4,-1}, 48 | {4,0}, 49 | 50 | {5,-4}, 51 | {0}, 52 | {7,-2}, 53 | {6,0}, 54 | 55 | {5,0,-4}, 56 | {5,0,-4}, 57 | {4,0,-1}, 58 | {4,0,-1}, 59 | 60 | {3,0,-2}, 61 | {0,-2}, 62 | {0,-2}, 63 | {7}, 64 | 65 | {3,5,-4}, 66 | {3,5,-2}, 67 | {3,5,0}, 68 | {6}, 69 | 70 | {2,4,-3}, 71 | {2,4,-1}, 72 | {2,4,0}, 73 | {7}, 74 | 75 | {3,5,-4}, 76 | {3,5,-2}, 77 | {3,5,0}, 78 | {6}, 79 | 80 | {2,4,-3}, 81 | {2,4,-1}, 82 | {2,4,0}, 83 | {7}, 84 | 85 | {5,0}, 86 | {3}, 87 | {4,-1}, 88 | {4}, 89 | 90 | {3,0,-2}, 91 | {0,-2}, 92 | {0,-2} 93 | }; 94 | 95 | public TimeSeries getMusician(){ 96 | int min = Integer.MAX_VALUE; 97 | int max = Integer.MIN_VALUE; 98 | for(int i=0;imax){ 105 | max = n; 106 | } 107 | } 108 | } 109 | 110 | boolean[][] data = new boolean[musician.length][max-min+1]; 111 | for(int i=0;i 0: 47 | self._next = (cupy.cuda.Event(block=True), 48 | cupy.empty_like(list(self._queue[0])[1])) 49 | 50 | def next(self): 51 | if len(self._queue) == 0: 52 | raise StopIteration 53 | 54 | e, a = self._queue.popleft() 55 | 56 | if self._idx < len(self._data): 57 | next_e, next_a = 
self._next 58 | next_a.set(np.array(self._current_data, copy=False), self._stream) 59 | next_e.record(self._stream) 60 | self._queue.append(self._next) 61 | self._next = (e, a) 62 | 63 | self._idx += 1 64 | e.synchronize() 65 | return (self._idx - self._prefetch - 1, a) 66 | 67 | @property 68 | def _current_data(self): 69 | return self._data[self._idx] 70 | 71 | 72 | class DataQueue: 73 | def __init__(self, data, prefetch=5): 74 | self._data = data 75 | self._prefetch = prefetch 76 | 77 | def __iter__(self): 78 | return Iterator(self._data, self._prefetch) 79 | -------------------------------------------------------------------------------- /src/jdybm/src/com/ibm/stdp/Parameter.java: -------------------------------------------------------------------------------- 1 | // (C) Copyright IBM Corp. 2015 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may 4 | // not use this file except in compliance with the License. You may obtain 5 | // a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations 13 | // under the License. 
14 | 15 | package com.ibm.stdp; 16 | 17 | import java.io.Serializable; 18 | import java.util.Arrays; 19 | 20 | /** 21 | * Global parameters of a DyBM 22 | * @author osogami 23 | * 24 | */ 25 | public class Parameter implements Serializable { 26 | 27 | private static final long serialVersionUID = 1L; 28 | 29 | public int nRatesPositive;// = ratePositive.length; 30 | public int nRatesNegative;// = rateNegative.length; 31 | public double[] ratePositive; 32 | public double[] rateNegative; 33 | 34 | public int minDelay; 35 | public int maxDelay; 36 | 37 | public Parameter(int n, int m){ 38 | nRatesNegative = n; 39 | rateNegative = new double[n]; 40 | for(int i=0;i