├── .gitignore
├── README.md
├── SpikingNeuroController.py
├── Thesis.pdf
├── actor.py
├── agent.py
├── connectome.py
├── critic.py
├── draw.py
├── exp.py
├── experiments
│   ├── amplificationtest.py
│   ├── fullynet.py
│   ├── lf_absolutenet.py
│   ├── lf_placecells.py
│   ├── lf_relative_prewired.py
│   ├── lf_simplified.py
│   ├── polebalancing.py
│   └── polebalancing_simplified.py
├── fremauxfilter.py
├── globalvalues.py
├── gridsearch.sh
├── gridsearch
│   ├── gridsearch.json
│   └── vq.json
├── lineFollowingEnvironment.py
├── lineFollowingEnvironment2.py
├── models
│   ├── __init__.py
│   └── trainingrun.py
├── nestdockerrun.sh
├── requirements.txt
├── rstdp_imp_test.py
├── symbolicactor.py
└── utilityTest.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
experimentdata/
.idea
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SNNexperiments
Reinforcement learning framework for spiking neural network actors with R-STDP, written for the master's thesis "Training Spiking Neural Networks with Reinforcement Learning". The thesis is included in the repository.

This is an actor-critic reinforcement learning framework where the actor/controller is a neurocontroller using spiking neural networks. The critic uses traditional, episodically calculated value iteration.

## Requirements
Python 3.7 or 3.8

## Installation
First, NEST 3 needs to be installed.

### Compile NEST 3 on macOS:

From the NEST source directory, run the following (replace the install prefix and compiler paths with your own):
```
cmake -DCMAKE_INSTALL_PREFIX:PATH=~/Studium/Masterarbeit/nest \
-DCMAKE_C_COMPILER=/usr/local/bin/gcc-10\
-DCMAKE_CXX_COMPILER=/usr/local/bin/g++-10 \
./
```

then `make install` and afterwards `source ~/.bashrc`

### Install NEST via Docker
Alternatively, install via the Docker image by running nestdockerrun.sh or using this line:
```docker run --rm -it -e LOCAL_USER_ID=`id -u $USER` -v $(pwd):/opt/data -p 8080:8080 nestsim/nest:latest /bin/bash```

then attach with

```docker ps```

```docker attach```

## Install Python Dependencies
Run

```pip install -r requirements.txt```


## Running Experiments
Default experiment:

`python exp.py`

Some predefined experiments are included in the `experiments` directory.

Grid search on global values can be performed by adding the argument

```-g```

to any experiment.

Example files are in the `gridsearch` directory. The parameter names must match the parameters in globalvalues.py.
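
For orientation, a config of this kind might look like the following sketch (the actual schema is defined by the grid search parser in exp.py; the parameter names below are attributes that exist in globalvalues.py, the values and the list-per-parameter layout are illustrative assumptions):

```json
{
    "errsig_factor": [0.001, 0.01, 0.1],
    "cycle_length": [40.0, 50.0, 60.0]
}
```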

## Obtaining Results
Plots are written to ./experimentdata/

If a MongoDB connection string is specified in exp.py, writing to the DB can be enabled. Training progress and data dumps are then written to the DB.

## Config
`render` enables live rendering

`headless` should be enabled when running on a server without a display driver
--------------------------------------------------------------------------------
/SpikingNeuroController.py:
--------------------------------------------------------------------------------
from typing import List, Union, Tuple
import numpy as np
from actor import Actor, PlaceCellAnalog2ActivationLayer, Weightstorage
from globalvalues import gv
try:
    import nest
except ImportError:
    print("Neural simulator Nest not found (import nest). Only able to run the simplified architecture.")


class SpikingNeurocontroller(Actor):
    def __init__(self, placecell_range: Union[np.ndarray, None], num_neurons_per_dim: Union[np.ndarray, int], env,
                 neuron_labels: List[str] = None):
        """

        :param placecell_range: if not using place cell encoding, pass None
        :param num_neurons_per_dim: number of place fields for each dimension
        :param env:
        :param neuron_labels:
        """
        super().__init__()
        self.env = env
        self.obsfactor = 1  # factor to scale observations; can be a numpy array with one entry per place cell
        # place cells
        num_place_cells: int = np.multiply.reduce(num_neurons_per_dim)
        self.placell_nneurons_per_dim = num_neurons_per_dim
        self.placecell_range: np.ndarray = placecell_range  # axis 0: dimension, axis 1: [from, to]
        if placecell_range is not None:
            labels = []
            for i in range(num_place_cells):
                pc_center = placecell_range[0, 0] + i * abs((placecell_range[0, 1] - placecell_range[0, 0])) / (
                        num_place_cells - 1)
                labels.append(f"pc_{pc_center}")
            self.placeCellA2SLayer: PlaceCellAnalog2ActivationLayer = PlaceCellAnalog2ActivationLayer(
                placecellrange=self.placecell_range,
                num_cells_per_dim=self.placell_nneurons_per_dim)

        from connectome import Connectome
        self.connectome = Connectome(actor=self, num_input=num_place_cells, neuron_labels=neuron_labels)
        # a matrix storing the history (time, from, to) for each synapse; +1 because index 0 holds the initial weights
        self.weightlog: List[Weightstorage] = []
        self.only_positive_input = False  # whether the input is only positive

        # log
        self.end_episode(-1)

    def _set_input(self, observation: np.ndarray):
        """translate an observation into a rate code (analog to spike rate)"""
        rate: np.ndarray
        # using place cells?
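        # Worked examples for the two encodings below (illustrative numbers, not thesis values):
        # 1) Place cells: each cell responds with exp(-||(center - obs) / distance_pc||^2),
        #    normalized over all cells (see PlaceCellAnalog2ActivationLayer.activation).
        #    With centers [0.0, 0.5, 1.0], distance_pc = 0.5 and obs = 0.4, the raw activations
        #    are exp(-[0.64, 0.04, 1.44]) ~ [0.53, 0.96, 0.24], normalized to ~[0.31, 0.56, 0.14]:
        #    the cell closest to the observation responds most strongly.
        # 2) Direct rate coding with signed input: an observation [-0.3, 0.2] is split onto two
        #    neurons per dimension (negative part, positive part), giving the rates
        #    [0.3, 0.0, 0.0, 0.2] before scaling with obsfactor and clipping to gv.max_poisson_freq.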
        if self.placecell_range is not None:
            rate = self.placeCellA2SLayer.activation(
                observation=observation)
        else:
            # encode directly via rate
            if self.only_positive_input:
                rate: np.ndarray = observation
            else:
                # if it can also be negative, split into two neurons
                rate = np.empty((observation.shape[0] * 2))
                for i, obs in enumerate(observation):
                    rate[i * 2] = -np.minimum(obs, 0)
                    rate[i * 2 + 1] = np.maximum(obs, 0)

        # logging
        # translate to a scaled numpy array
        rate = rate * self.obsfactor
        rate = np.clip(rate, 0, gv.max_poisson_freq)
        self.connectome.set_inputactivation(rate)

    def cycle(self, time: float, observation_in: np.ndarray) -> Tuple[Union[int, List[float]], List]:
        """has side effects, call only once per cycle"""
        # feed observations into the brain
        self._set_input(observation_in)
        nest.Simulate(gv.cycle_length)  # Run() cannot be used here because parameters are set between cycles
        return self.get_action(time=time)

    def release_neurotransmitter(self, amount: float):
        super().release_neurotransmitter(amount)
        if len(self.connectome.conns_nest_in) >= 1:
            self.connectome.conns_nest_in.set({"n": -amount})
        # because of some strange bug where the first is a proxy node
        self.connectome.conns_nest_ex.set({"n": amount})

    def end_cycle(self, cycle_num):
        self.connectome.end_cycle(cycle_num)

    def end_episode(self, episode):
        super().end_episode(episode)
        weights = self.connectome.get_weights()
        self.weightlog.append(weights)
--------------------------------------------------------------------------------
/Thesis.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BSVogler/SNN-RL/67df1115fd604706ab08200ba17eab11d5f844e2/Thesis.pdf
--------------------------------------------------------------------------------
/actor.py:
--------------------------------------------------------------------------------
from typing import List, Tuple, Union

from lineFollowingEnvironment import LineFollowingEnv
from lineFollowingEnvironment2 import LineFollowingEnv2
from globalvalues import gv
import fremauxfilter
import numpy as np
import sys

class PlaceCellAnalog2ActivationLayer:
    positions: np.ndarray = None  # list of center positions
    distance_pc: np.ndarray  # can be cached when vector quantization is off

    def __init__(self, placecellrange: np.ndarray, num_cells_per_dim: np.ndarray):
        """

        :param placecellrange: min and max for each dimension
        :param num_cells_per_dim:
        """
        # distance between the centers covered by the place cells, per dimension
        self.distance_pc = np.abs((placecellrange[:, 1] - placecellrange[:, 0])) / (num_cells_per_dim - 1)
        numdims = placecellrange.shape[0]
        # for each dimension create a grid, then transform to have all coordinates
        pos = np.mgrid[tuple(slice(0., num_cells_per_dim[dim]) for dim in range(numdims))].T.reshape(-1,
                                                                                                     numdims)  # shape: (num_cells_per_dim^numdims, numdims)
        self.positions = pos * self.distance_pc
        self.positions += placecellrange[:, 0]  # add offset to each element

        if sys.maxsize > 2 ** 32:  # check if 64 bit
            self.vq_decay = np.float128(1.0)  # of vector quantization
        else:
            self.vq_decay = np.float64(1.0)  # of vector quantization

        self.sigmafactor = 1
        if gv.workerdata is not None and "receptivefieldsize" in gv.workerdata:
            self.sigmafactor = gv.workerdata["receptivefieldsize"]

    def activation(self, observation: np.ndarray) -> np.ndarray:
        """
        Calculate the activation for each neuron with an exponential kernel as in Frémaux et al. 2013 (eq. 22).
        When vector quantization is activated, this call has side effects because it moves the positions.
        :param observation:
        :return: activation levels
        """

        if gv.vq_learning_scale > 0:
            # changes every time, so it cannot be cached
            # rezsigma = len(dists) / np.sum(dists)
            k_averagedistancing = False
            if k_averagedistancing:
                rezsigma = np.empty_like(self.positions)
                k = 4
                for i in range(len(rezsigma)):
                    # distance to other place cells
                    dist = np.linalg.norm(self.positions[i] - self.positions, axis=1)
                    rezsigma[i] = np.average(dist[np.argpartition(dist, 1 + k)[1:1 + k]])  # ignore self
            else:
                rezsigma = self.sigmafactor / self.distance_pc
            dists: np.ndarray = np.linalg.norm((self.positions - observation) * rezsigma, axis=1)
            dists2: np.ndarray = dists ** 2
            input_activation = dists2  # use average distance
            self.vector_quantization(observation, dists2)
        else:
            # use the squared L2 norm, weighted by the density per dimension
            scaleddistance = (self.positions - observation) / self.distance_pc  # calculation per dimension
            input_activation = np.linalg.norm(scaleddistance, axis=1) ** 2
        vec = np.exp(-input_activation)
        return vec / np.sum(vec)  # normalize

    def vector_quantization(self, observation: np.ndarray, dist2: np.ndarray):
        if sys.maxsize > 2 ** 32:  # check if 64 bit
            self.vq_decay *= np.float128(1 - gv.vq_decay)  # exponentially decrease the strength of vq
        else:
            self.vq_decay *= np.float64(1 - gv.vq_decay)  # exponentially decrease the strength of vq
        changeamount = gv.vq_learning_scale * np.exp(-dist2 / self.vq_decay)
        self.positions += (observation - self.positions) * changeamount[:, np.newaxis]


Weightstorage = np.array


class Actor:
    """Part of the agent that manages behaviour. Includes analog-to-spike and spike-to-analog conversion."""

    def __init__(self):
        self.log_m: List[float] = []  # log every cycle

    def read_output(self, time: float) -> List[float]:
        """Returns a vector containing the sampled output signal per neuron (spike to analog).
        time: time when the output is read. Only needed for the filtered variant."""

        # temporal average by filtering
        # for each neuron there is a signal
        spike_signals = self.connectome.get_outspikes()
        if gv.filter_outsignals:
            # plot spikes
            return fremauxfilter.filter(time, spike_signals)
        else:
            # just count the spikes in the cycle
            return list(map(lambda x: len(x), spike_signals))

    def cycle(self, time: float, observation_in: np.ndarray) -> Tuple[Union[int, List[float]], List]:
        pass

    def get_action(self, time: float) -> Tuple[Union[int, List[float]], List]:
        """
        Get the action this actor decided on.
        :param time: the time at which the output is read
        :return: Tuple[activation of action for gym, neural output activations]
        """
        outputs = self.read_output(time=time)
        left = outputs[0]
        right = outputs[1]

        if isinstance(self.env, LineFollowingEnv) or isinstance(self.env, LineFollowingEnv2):
            sensitivitiy = 1 / 400.0
            if gv.filter_outsignals:
                sensitivitiy *= 51.081  # empirical value from one episode, obtained with least squares

            # in line following the action 0.5 means to go straight,
            action = [np.clip(0.5 - (left - right) * sensitivitiy, 0., 1.)]
        else:
            action = 0 if left >= right else 1
        return action, outputs

    def release_neurotransmitter(self, amount: float):
        """insert reward into the nest synapses"""
        self.log_m.append(amount)

    def end_cycle(self, cycle_num):
        pass

    def end_episode(self, episode):
        """
        store weights for loading in the next episode and for logging
        :return:
        """
        self.log_m.clear()
--------------------------------------------------------------------------------
/agent.py:
--------------------------------------------------------------------------------
from typing import List, Dict, Tuple
try:
    import nest
except ImportError:
    print("Neural simulator Nest not found (import nest). Only able to run the simplified architecture.")

from actor import Actor, Weightstorage
from critic import AbstractCritic
from globalvalues import gv


class Agent:
    def __init__(self, environment, actor: Actor, critic: AbstractCritic):
        self.actor = actor
        self.critic: AbstractCritic = critic
        self.environment = environment

    def get_action(self, time) -> List[float]:
        return self.actor.get_action(time)

    def end_cycle(self, cycle_num):
        self.actor.end_cycle(cycle_num)

    def end_episode(self, episode):
        self.critic.end_episode()
        self.actor.end_episode(episode)

    def prepare_episode(self):
        try:
            self.actor.connectome.rebuild()
        except AttributeError:
            pass

    def post_episode(self):
        try:
            # only simulate if there is a connectome
            self.actor.connectome
            nest.Simulate(gv.cycle_length)
        except AttributeError:
            pass

    def get_weights(self) -> Weightstorage:
        try:
            return self.actor.connectome.get_weights()
        except AttributeError:
            # the simplified (non-spiking) actor stores its policy directly
            return self.actor.placecellaction.copy()
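
# A rough sketch of how actor, critic and agent interact during training (the actual loop
# lives in exp.py; gym and timing details are simplified here, and `max_cycles` is a
# made-up name, not a variable from this repository):
#
#     agent = Agent(env, actor, critic)
#     for episode in range(gv.num_episodes):
#         agent.prepare_episode()  # rebuilds the nest network and restores the last weights
#         observation = env.reset()
#         for cycle in range(max_cycles):
#             action, _ = actor.cycle(time, observation)  # encode, simulate one cycle, decode
#             observation, reward, done, _ = env.step(action)
#             errsignal, _ = critic.tick(tuple(observation), reward)
#             actor.release_neurotransmitter(errsignal * gv.errsig_factor)  # R-STDP reward
#             agent.end_cycle(cycle)
#             if done:
#                 break
#         agent.post_episode()
#         agent.end_episode(episode)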
--------------------------------------------------------------------------------
/connectome.py:
--------------------------------------------------------------------------------
import collections
import pickle
from typing import List, Dict, Set, Tuple

try:
    import nest
except ImportError:
    print("Neural simulator Nest not found (import nest). Only able to run the simplified architecture.")

import numpy as np
from nest.lib.hl_api_types import NodeCollection, SynapseCollection

import draw
from actor import Weightstorage
from globalvalues import gv


class Connectome:
    def __init__(self, actor, num_input=8, num_output=2, neuron_labels: List[str] = []):
        """
        Create and connect a nest network.
        :param actor:
        :param num_input:
        :param num_output:
        :return:
        """
        self.actor = actor
        # note that the order does not reflect the internal nest order
        self.neur_ids_parrot = []
        self.neur_ids_out: List[int] = []
        self.neur_ids_core: np.ndarray  # the neurons which are used for wiring (except spike generators)
        self.neur_ids_ex = []  # includes input
        self.neur_ids_hidden_in = []  # ids of the free (not in a population) inhibitory neurons in the hidden layer
        self.neur_ids_hidden_ex = []  # ids of the free (not in a population) excitatory neurons in the hidden layer

        # store connected pairs (source, target) for reconnecting; order differs from nest
        self.conns: np.ndarray
        self.conns_in: np.ndarray
        self.conns_ex: np.ndarray

        # redundancy in nest format/indexing
        self.conns_nest: SynapseCollection = []  # contains all stdp synapses
        self.conns_nest_ex: SynapseCollection = []
        self.conns_nest_in: SynapseCollection = []
        self.populations_nest: List[NodeCollection] = []
        self.neurons_nest_ex: NodeCollection = []
        self.neurons_nest_in: NodeCollection = []
        self.neurons_nest: NodeCollection = []
        self.neurons_input: NodeCollection = []

        self.cycles_to_reconnect = 2  # counts the cycles until a reconnect (structural plasticity) happens
        self.num_input = num_input
        self.num_output = num_output
        self.last_num_spikes = []  # number of spikes recorded in this trial, used to get only the new spikes per out population
        self.multimeter = None
        self.synapsecontingent = 1

        self.spike_detector = None
        self.spike_detectors_populations = []
        # log
        self.lastweights: Weightstorage = []  # only stores the last entries for restoring in the next episode, ordered like the nest connections

        self.neuron_labels: List[str] = neuron_labels

        # after initializing the fields, construct the network
        nest.set_verbosity("M_WARNING")
        self.rebuild(True)
        print("Training " + str(len(self.conns_nest)) + " weights")

    def create_input_layer(self, initial):
        """input layer in nest"""
        # todo: use populations of size > 1
        self.neurons_input = nest.Create("poisson_generator", self.num_input)
        # introduce parrot neurons to work around the limitation that devices cannot have STDP synapses
        self.parrots = nest.Create("parrot_neuron", self.num_input)
        self.neur_ids_parrot = self.parrots.tolist()
        if initial:
            self.neur_ids_ex.extend(self.neur_ids_parrot)
        # add each parrot as a population
        for i in range(self.num_input):
            self.populations_nest.append(self.parrots[i])
        # connect without adding to the front-end
        nest.Connect(self.neurons_input, self.parrots, "one_to_one")

        self.spike_detector = nest.Create('spike_detector')
        nest.Connect(self.parrots, self.spike_detector, 'all_to_all')

    def init_labels(self):
        """
        Set labels for matching neuron ids to labels.
        :return:
        """
        if self.neuron_labels is not None and self.num_input < 3000:
            neuronlist = self.neurons_input.tolist()
            parrotlist = self.parrots.tolist()
            for i in range(self.num_input):
                gv.neuronLabelMap[neuronlist[i]] = self.neuron_labels[i]
                # copy labels
                gv.neuronLabelMap[parrotlist[i]] = gv.neuronLabelMap[neuronlist[i]]

        gv.neuronLabelMap[self.neur_ids_out[0]] = "Left"
        gv.neuronLabelMap[self.neur_ids_out[1]] = "Right"

    def connect_front_end(self):
        """only front-end related connections"""
        # temporary lists for faster creation
        self.conns_ex = []
        self.conns_in = []
        self.conns = []
        if gv.manualwiring is not None:
            for conn in gv.manualwiring:
                for from_neuron in self.populations_nest[conn[0]].tolist():
                    for to_neuron in self.populations_nest[conn[1]].tolist():
                        self.conns.append((from_neuron, to_neuron))
                        if conn[2]:
                            self.conns_ex.append((from_neuron, to_neuron))
                        else:
                            self.conns_in.append((from_neuron, to_neuron))
        else:
            # connect input to everything
            self.add_connection(self.neur_ids_parrot, self.neur_ids_out, 'excitatory')
            self.add_connection(self.neur_ids_parrot, self.neur_ids_hidden_ex, 'excitatory')
            # connect hidden with out
            self.add_connection(self.neur_ids_hidden_ex, self.neur_ids_out, 'excitatory')

        self.conns_ex = np.array(self.conns_ex)
        self.conns_in = np.array(self.conns_in)
        self.conns = np.array(self.conns)

    def rebuild(self, initial=False):
        """In the initial phase create connections in the front-end (neur_ids); always creates nest connections (back-end). Keeps the weights."""
        nest.set_verbosity("M_ERROR")
        nest.ResetKernel()
        self.populations_nest.clear()
        self.create_input_layer(initial)

        if gv.manualwiring is not None:
            num_inhibitory = gv.manual_num_inhibitory
        else:
            if gv.num_hidden_neurons > self.num_output + 1:
                num_inhibitory = int(np.ceil((self.num_output + gv.num_hidden_neurons) * gv.fraction_hidden_inhibitory))
            else:
                num_inhibitory = 0
        # create populations
        # self.neurons_nest_ex = nest.NodeCollection([])
        # hidden population
        if gv.num_hidden_neurons - num_inhibitory < 0:
            raise AssertionError("Invalid number of excitatory and inhibitory neurons")
        # reset last detected spikes
        self.spike_detectors_populations = []
        num_hidden_ex = gv.num_hidden_neurons - num_inhibitory
        self.neurons_nest_ex = None
        if num_hidden_ex > 0:
            # create_population returns (nest population, id list); keep only the id list here
            self.neur_ids_hidden_ex = self.create_population(num_hidden_ex,
                                                             initial=initial,
                                                             recurrent=True)[1]  # todo: add fraction of inhibitory
        # create output populations
        self.neur_ids_out = []
        self.last_num_spikes = []
        outpopulations = []
        for i in range(gv.num_output_populations):
            popnest, poplist = self.create_population(gv.population_size, initial=initial, recurrent=False)
            self.neur_ids_out.extend(poplist)
            outpopulations.append(popnest)

        # self.neur_ids_in.extend(self.neur_ids_layer_in)
        self.neur_ids_ex.extend(self.neur_ids_out)

        # connect the out populations with lateral inhibition; not in the front-end because it is not an STDP synapse
        if gv.lateral_inhibition > 0:
            # todo: replace with 'autapses': False, all to all
            for i in outpopulations:
                for j in outpopulations:
                    if i != j:
                        nest.Connect(i, j, syn_spec={'weight': -gv.lateral_inhibition})  # static synapses

        if initial:
            self.connect_front_end()
            self.init_labels()

        # create the neuromodulated synapse
        vt = nest.Create('volume_transmitter')
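        # gv.define_stdp is defined in globalvalues.py (not shown in this excerpt). For
        # orientation, a typical NEST 3 setup of the two R-STDP synapse models used below
        # might look roughly like this (a sketch with illustrative parameter values, not
        # the thesis configuration):
        #
        #     nest.CopyModel("stdp_dopamine_synapse", "stdp_dopamine_synapse_ex",
        #                    {"vt": vt.get("global_id"),   # volume transmitter carrying the reward
        #                     "A_plus": 1.0, "A_minus": 1.5,    # potentiation/depression amplitudes
        #                     "tau_c": 200.0, "tau_n": 200.0,   # eligibility trace / dopamine time constants
        #                     "Wmin": 0.0, "Wmax": 100.0})
        #     nest.CopyModel("stdp_dopamine_synapse", "stdp_dopamine_synapse_in", {...})
        #
        # The critic's error signal reaches these synapses via set({"n": amount}) in
        # SpikingNeurocontroller.release_neurotransmitter.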
        gv.define_stdp(vt)
        # using a distribution is not possible with stdp_dopamine_synapse
        # {"distribution": "uniform", "low": -weight, "high": weight}

        # set up measurement devices
        nest.Connect(self.neurons_nest_ex, self.spike_detector)
        self.multimeter = nest.Create("multimeter")
        self.multimeter.set({"record_from": ["V_m"]})  # removed in nest3: "withtime": True,
        nest.Connect(self.multimeter, self.neurons_nest_ex)
        gv.voltageRecordings = set(self.neurons_nest_ex.tolist())

        # the connections in the front-end are already established: now connect the neurons in the nest back-end
        nest.Connect(self.conns_ex[:, 0],
                     self.conns_ex[:, 1],
                     syn_spec={'synapse_model': 'stdp_dopamine_synapse_ex'},
                     conn_spec="one_to_one")
        if len(self.conns_in) > 0:
            nest.Connect(self.conns_in[:, 0],
                         self.conns_in[:, 1],
                         syn_spec={'synapse_model': 'stdp_dopamine_synapse_in'},
                         conn_spec="one_to_one")

        # update the references to nest
        self.update_connections_nest()

        # random initial weights
        if initial:
            # randomize weights for inhibitory
            if gv.num_hidden_neurons > 1:  # can only have inhibitory neurons if there is more than one excitatory
                rand_weight = gv.pyrngs[0].uniform(gv.w0_min, gv.w0_max, size=len(self.conns_nest_in))
                self.conns_nest_in.set({"weight": -rand_weight})

            rand_weight = gv.pyrngs[0].uniform(gv.w0_min, gv.w0_max, size=len(self.conns_nest_ex))
            self.conns_nest_ex.set({"weight": rand_weight})

        if initial:
            self.neur_ids_core = np.unique(
                [x[0] for x in self.conns] + [x[1] for x in self.conns])  # the connected neurons of interest
        else:
            # use the weights from the last episode; offset of one because the initial weights are at index 0
            self.restorelastweights()
            if gv.structural_plasticity and gv.random_reconnect:
                self.random_reconnect()

        nest.set_verbosity("M_WARNING")

    def set_inputactivation(self, rate: np.ndarray):
        """Set the activation levels of the input neurons."""
        self.neurons_input.set({"rate": rate})
        time = nest.GetKernelStatus("time")
        self.neurons_input.set({"origin": time})
        self.neurons_input.set({"stop": gv.cycle_length})

    def get_weights(self) -> Weightstorage:
        """
        :return: an array with one row (source, target, weight) per plastic synapse
        """
        data = self.conns_nest.get(keys={"weight", "source", "target"})
        # save in history
        self.lastweights = np.array((data["source"], data["target"], data["weight"])).T
        return self.lastweights

    def restorelastweights(self):
        """Restores the most recently stored weights in the nest back-end"""
        self.set_weights(self.lastweights)

    def load(self, mongoid: str):
        from models import trainingrun
        weights = trainingrun.Episode.objects(id=mongoid).first().weights
        # unpickle
        wlist = pickle.loads(weights)
        self.set_weights(wlist)

    def set_weights(self, weights: Weightstorage, pick=False):
        """
        Assumes that the back-end contains the connections from which these weights were extracted.
        :param weights: source, target, weights
        :param pick: if true, matches the connections via the source and target information (slower)
        :return:
        """
        if pick:
            for (source, target, weight) in weights:
                nest.GetConnections(source=NodeCollection([int(source)]),
                                    target=NodeCollection([int(target)]),
                                    synapse_model='stdp_dopamine_synapse_ex').set({"weight": weight})
        else:
            self.conns_nest.set({"weight": weights[:, 2]})

    def create_population(self, num_neurons: int, initial: bool, recurrent=False) -> Tuple[NodeCollection, List[int]]:
        """Creates nest neurons and automatically connects them"""

        # back-end
        neurons_nest = nest.Create(gv.neuronmodel, num_neurons)
        if self.neurons_nest_ex is None:
            self.neurons_nest_ex = neurons_nest
        else:
            self.neurons_nest_ex += neurons_nest
        self.neurons_nest = neurons_nest
        neurons_list = neurons_nest.tolist()
        self.populations_nest.append(neurons_nest)
        # add a detector
        detector = nest.Create('spike_detector')
        self.spike_detectors_populations.append(detector)
        self.last_num_spikes.append(0)
        nest.Connect(neurons_nest, detector)

        if initial and recurrent:
            for neur_a in neurons_list:
                for neur_b in neurons_list:
                    if neur_a != neur_b:  # no autapses
                        self.add_connection(neur_a, neur_b, True)

        return neurons_nest, neurons_list

    def add_connection(self, from_nid, to_nid, synapse_type: bool):
        """
        Connects lists or single neurons in an all-to-all fashion in the front-end.
        :param from_nid: global neuron id
        :param to_nid: global neuron id
        :param synapse_type:
        :return:
        """
        # resolve tuples
        if isinstance(from_nid, collections.Sequence):
            for f in from_nid:
                if isinstance(to_nid, collections.Sequence):
                    for t in to_nid:
                        if f != t:
                            self.add_connection(f, t, synapse_type)
                else:
                    if f != to_nid:
                        self.add_connection(f, to_nid, synapse_type)
            return
        elif isinstance(to_nid, collections.Sequence):
            for t in to_nid:
                if from_nid != t:
                    self.add_connection(from_nid, t, synapse_type)
            return

        self.conns.append((from_nid, to_nid))
        if synapse_type:
            self.conns_ex.append((from_nid, to_nid))
        else:
            self.conns_in.append((from_nid, to_nid))

    def update_connections_nest(self):
        """
        Before calling this method, connect everything in the front- and back-end. Updates the references to the nest back-end based on the front-end connectome.
        """
        # todo: should include inhibitory; todo: skip the first because of some bug when concatenating synapse collections
        self.conns_nest: SynapseCollection = nest.GetConnections(synapse_model='stdp_dopamine_synapse_ex')
        self.conns_nest_ex: SynapseCollection = nest.GetConnections(synapse_model='stdp_dopamine_synapse_ex')
        if len(self.neurons_nest_in) > 0:
            self.conns_nest_in: SynapseCollection = nest.GetConnections(source=self.neurons_nest_in,
                                                                        target=self.neurons_nest_in,
                                                                        synapse_model='stdp_dopamine_synapse_in')

    def remove_weak_conns(self, connlist, model):
        """
        Checks every synapse in connlist and removes the weak ones.
        :param connlist:
        :param model:
        :return:
        """

        def checkIfMatch(a, b):
            return a[0] == b[0] and a[1] == b[1]

        if len(connlist) == 0:
            return False
        removed = False
        tobe_removed = []
        for i, conn in enumerate(connlist):
            # get directly from the back-end, because we are editing the back-end, which will outdate the connection
            nestconn = nest.GetConnections(source=nest.NodeCollection([conn[0]]), target=nest.NodeCollection([conn[1]]))
            w = nestconn.get({"weight"})["weight"]
            if abs(w) < gv.strp_min:
                print(f"Disconnecting {conn}")
                nest.Disconnect(nest.NodeCollection([conn[0]]),
                                nest.NodeCollection([conn[1]]),
                                syn_spec={'synapse_model': model})
                tobe_removed.append(conn)
                removed = True

                # delete from conns
                connpair = (conn[0], conn[1])
                if model == "stdp_dopamine_synapse_ex":
                    for i, n in enumerate(self.conns_ex):
                        if checkIfMatch(n, connpair):
                            del self.conns_ex[i]
                            break
                else:
                    for i, n in enumerate(self.conns_in):
                        if checkIfMatch(n, connpair):
                            del self.conns_in[i]
                            break
                for i, n in enumerate(self.conns):
                    if checkIfMatch(n, connpair):
                        del self.conns[i]
                        break
        if removed:
            self.synapsecontingent += len(tobe_removed)
            for delete in tobe_removed:
                if delete in connlist:
                    # might already be deleted when connlist is conns_ex or conns_in
                    connlist.remove(delete)
            self.update_connections_nest()

        return removed

    def update_structural_plasticity(self):
        """Cyclic update to check whether connections should be removed."""
        if self.cycles_to_reconnect <= 0:
            removed = False
            removed |= self.remove_weak_conns(self.conns_ex, "stdp_dopamine_synapse_ex")
            removed |= self.remove_weak_conns(self.conns_in, "stdp_dopamine_synapse_in")
            if removed:
                self.update_connections_nest()
            self.cycles_to_reconnect = 2
        else:
            self.cycles_to_reconnect -= 1
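    # Summary of the structural plasticity above: every few cycles each plastic synapse is
    # tested against gv.strp_min and dropped if its |weight| falls below it; every removal
    # increases `synapsecontingent`, which random_reconnect (below) can spend to grow a new
    # random edge from a recently firing neuron.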
    def get_outspikes(self, recent=True) -> List[List[float]]:
        """Get the aggregated spike times for each population.
        :recent: if false, returns the signal over the whole trial"""
        spikes = [[]] * len(self.spike_detectors_populations)
        for pop_i, detector in enumerate(self.spike_detectors_populations):
            if recent:
                # filter the spikes since the last cycle
                spikes[pop_i] = detector.get({"events"})["events"]["times"][self.last_num_spikes[pop_i]:]
            else:
                spikes[pop_i] = detector.get({"events"})["events"]["times"]
        return spikes

    def end_cycle(self, cyclenum):
        for i, pop in enumerate(self.spike_detectors_populations):
            self.last_num_spikes[i] = len(pop.get({"events"})["events"]["times"])

    def random_reconnect(self):
        """Adds a connection from a neuron which recently fired to a random target. Can result in no change."""
        # todo: use self.getrecentfiring()

        # determine the neurons which fired in the last cycle
        spikesenders: List[float] = self.spike_detector.get({"events"})["events"]["senders"]
        # filter the spikes since the last cycle
        spikesenders: List[float] = spikesenders[self.last_num_spikes:]
        neurons_fired_cycle = np.unique(spikesenders)

        source = np.random.choice(neurons_fired_cycle, 1)[0]
        type = "excitatory" if source in self.neur_ids_ex else "inhibitory"
        # get the synapses where there is zero weight
        noconn_from_source = set(np.where(self.actor.lastweightsmatrix[source, :] == 0)[0])
        candidates = set(self.neur_ids_core) & noconn_from_source
        if len(candidates) > 1 and self.synapsecontingent > 0:
            self.synapsecontingent -= 1
            # no self connection (discard instead of remove: source might not be a candidate)
            candidates.discard(source)
            target = np.random.choice(list(candidates), 1)[0]
            print(f"random connect of {source}->{target}")
            # add to the front-end
            self.add_connection(source, target, type)
            nest.set_verbosity("M_ERROR")
            nest.Connect(nest.NodeCollection([source]), nest.NodeCollection([target]),
                         syn_spec={
                             'synapse_model': 'stdp_dopamine_synapse_in' if type == "inhibitory" else 'stdp_dopamine_synapse_ex'})
            nest.set_verbosity("M_WARNING")
            synapse = nest.GetConnections(source=nest.NodeCollection([source]), target=nest.NodeCollection([target]))
            if type == "inhibitory":
                synapse.set({"weight": -gv.w0_min})
            else:
                synapse.set({"weight": gv.w0_min})
            # the indices have changed, so update everything
            self.update_connections_nest()

    def drawspikes(self):
        draw.spikes(nest.GetStatus(self.spike_detector)[0]["events"],
                    outsignal=self.get_outspikes(recent=False),
                    output_ids=self.neur_ids_out)
--------------------------------------------------------------------------------
/critic.py:
--------------------------------------------------------------------------------
import os
from abc import ABCMeta, abstractmethod

import numpy as np
from typing import List, Dict, NamedTuple, Any, Type, Tuple, Union
import matplotlib.pyplot as plt
from gym.spaces import Box
from sklearn.neighbors import KNeighborsRegressor
from dataclasses import dataclass

from globalvalues import gv

State = Tuple[float, ...]
BucketIndex = Tuple[int, ...]
Rewards = Union[float, Tuple[float, ...]]


def hash_bucketed(bucket: BucketIndex) -> int:
    """get the number of the bucket"""
    return hash(bucket)


@dataclass
class RewardEntry:
    """Used for recording the values belonging to a percept."""
    state_index: BucketIndex
    rewards: List[Rewards]  # recordings of previous rewards


class AbstractCritic(metaclass=ABCMeta):
    """
    Contains all the methods needed to implement a critic.
    """

    def __init__(self, obsranges: np.ndarray, recordingtype: Type = RewardEntry):
        self.tickentries: List["DynamicBaseline.TicksEpisode"] = []  # stores the rewards in this trial
        self.entrytype = recordingtype
        self.recordings: Dict[int, recordingtype] = dict()  # maps a hash to rewards and utility
        self.bucketsperdim = gv.criticresolution  # good default value found by simulations on the simple net
        if gv.workerdata is not None and "rescritic" in gv.workerdata:
            self.bucketsperdim = gv.workerdata["rescritic"]
        self.completedATrial = False
        self.state_limits = obsranges  # the ranges in which the states are expected
        self.bucketing_steps: np.ndarray = (self.state_limits[:, 1] - self.state_limits[:, 0]) / self.bucketsperdim
        self.draw_limits: np.ndarray = self.state_limits.T  # the state limits which are drawn; the first row is the minimum, the second the maximum
        self.knn_util = KNeighborsRegressor(n_neighbors=2, weights="distance")
        self.displayrange = None
        self.maxm: float = 6  # best default value found by simulations on the simple net

    # @final  # final is not supported in py37
    def tick(self, state: State, new_rewards: Rewards) -> (float, float):
        """
        Input the state and rewards for a frame and get the utility back. Returns a single reward value based on the new rewards. Later this should return the utility.
        :param state:
        :param new_rewards: the rewards for this state
        :returns reward signal and absolute utility (rating of this state)
        """
        errsignal, util = self._tick(state, new_rewards)

        if errsignal < 0:
            errsignal *= gv.factor_negative_util
        return errsignal, util

    @abstractmethod
    def _tick(self, state: State, new_rewards: Rewards) -> (float, float):
        """
        Internal tick function. Needs to be overwritten by the implementation.
        :param new_rewards:
        :return:
        """
        pass

    def bucketandhash(self, state: State) -> int:
        """get the hash number of the bucket"""
        return hash_bucketed(self.bucket_index(state))

    def bucket_index(self, state: State) -> BucketIndex:
        """get the index when put into a bucket. Can easily be used to get the bucketed state."""
        # floordiv
        return tuple((v - self.state_limits[dim, 0]) // self.bucketing_steps[dim] for dim, v in enumerate(state))

    def bucket_index_floating(self, state: State) -> Tuple[float, ...]:
        """like bucket_index, but without flooring (used as continuous coordinates for the kNN predictors)"""
        # truediv
        return tuple((v - self.state_limits[dim, 0]) / self.bucketing_steps[dim] for dim, v in enumerate(state))
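    # Worked example of the bucketing (illustrative numbers): with state_limits = [[-2.4, 2.4]]
    # and bucketsperdim = 48, bucketing_steps = (2.4 - (-2.4)) / 48 = 0.1, so the state (0.53,)
    # falls into bucket_index ((0.53 + 2.4) // 0.1,) = (29.0,), and hash_bucketed turns that
    # tuple into the key used by `recordings`.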
    def query_bucket(self, state: State, bucket: int = 0, bucketed=None) -> Any:
        """return the bucket content for an already bucketed state"""
        if bucketed is None:
            bucketed = self.bucket_index(state)
        if bucket == 0:
            bucket = hash_bucketed(bucketed)

        if self.completedATrial < self.knn_util.n_neighbors:
            return self.entrytype(bucketed, [], .0, None)
        elif bucket in self.recordings:
            return self.recordings[bucket]
        else:  # inference
            return self.entrytype(bucketed, [], self.knn_util.predict([self.bucket_index_floating(state)])[0], None)

    def query(self, state: State) -> Any:
        """Get the utility for any unbucketed state"""
        if not self.completedATrial:
            raise AssertionError("Trial not ended. Call end_episode after collecting data.")

        return self.query_bucket(state=state, bucket=self.bucketandhash(state))

    # @final
    def end_episode(self):
        """Called at the end of a trial"""
        self._end_episode()

    def valueiterate(self, trajectory):
        entry = self.recordings[trajectory[-1].bucket]
        if entry.utility is not None:
            reward = np.average(entry.rewards)
            entry.utility += gv.util_learn_rate * (reward - entry.utility)  # exp. moving average
        prev_util = entry.utility  # the very first will be None
        if prev_util is None:
            entry.utility = np.average(entry.rewards)
            prev_util = entry.utility
        # propagate backwards, but skip the first
        for (bucket, rewards) in reversed(trajectory[:-1]):
            entry = self.recordings[bucket]
            entry_util = 0 if entry.utility is None else entry.utility
            td_error = gv.util_learn_rate * (np.average(rewards) + gv.util_discount_factor * prev_util - entry_util)
            entry.utility = entry_util + td_error
            prev_util = entry.utility

        trajectory.clear()

    def _end_episode(self):
        """Overwrite to specify what should be performed when a trial has ended."""
        # implements value iteration, but assumes that the current policy is the optimal policy
        self.valueiterate(self.tickentries)

        # fit the utility predictor
        self.completedATrial = len(self.recordings.values())
        states_index = []
        utility = []
        for rec in self.recordings.values():
            states_index.append(rec.state_index)
            utility.append(rec.utility)
        self.knn_util.fit(states_index, utility)
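
# The backward sweep in valueiterate is a plain TD(0)-style update along the recorded
# trajectory: U(s_t) <- U(s_t) + alpha * (r_t + gamma * U(s_{t+1}) - U(s_t)),
# with alpha = gv.util_learn_rate and gamma = gv.util_discount_factor. A quick worked
# example with illustrative values alpha = 0.5, gamma = 0.9: if U(s_{t+1}) = 1.0,
# U(s_t) = 0.0 and r_t = 0.0, then U(s_t) becomes 0.5 * (0.0 + 0.9 * 1.0 - 0.0) = 0.45.
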
@dataclass
class BaselineEntry(RewardEntry):
    utility: float
    baseline: float


class DynamicBaseline(AbstractCritic):
    """Implements a critic as in Wunderlich et al. 2019"""
    TicksEpisode = NamedTuple("trialentry", [("bucket", int), ("rewards", Rewards)])

    def __init__(self, obsranges: np.ndarray, state_labels=None):
        super().__init__(obsranges, recordingtype=BaselineEntry)
        if state_labels is None:
            # todo: should be a generic "observation 1"
            state_labels = ["Cart Position", "Cart Velocity", "Pole Angle", "Pole Velocity At Tip"]
        self.state_labels = state_labels
        self.bias = 0.0  # bias value, added each tick
        self.last_r = None
        self.last_util = 0
        self.numbaselines = 0
        self.knn_baseline = KNeighborsRegressor(n_neighbors=2, weights="distance")

    learning_rate = 0.9  # for the baseline

    def _end_episode(self):
        super()._end_episode()
        self.last_r = None
        self.last_util = None

        if gv.dynamic_baseline:
            # fit the baseline predictor
            states = []
            baselines = []
            self.numbaselines = 0
            for rec in self.recordings.values():
                if rec.baseline is not None:
                    states.append(rec.state_index)
                    baselines.append(rec.baseline)
                    self.numbaselines += 1
            self.knn_baseline.fit(states, baselines)

    def _tick(self, state: State, new_rewards: Rewards) -> (float, float):
        """

        :param state:
        :param new_rewards:
        :return:
        """
        bucketed_state = self.bucket_index(state)
        bucket = hash_bucketed(bucketed_state)
        new_reward_reduced = np.average(new_rewards)  # reduce the reward to a float
        self.tickentries.append(DynamicBaseline.TicksEpisode(bucket=bucket, rewards=new_rewards))

        # check if the utility can be retrieved
        entry = self.query_bucket(state, bucket, bucketed=bucketed_state)
        # save if new
        self.recordings[bucket] = entry
        # add the new reward to the recordings
        entry.rewards.append(new_rewards)

        # the utility of this state is the current reward plus future rewards
        util: float = new_reward_reduced + entry.utility
        # if delta_r/u can be computed
        if self.last_r is not None:
            # return the delta in utility for R-STDP; in the beginning return 0
            errsignal = util - self.last_util
            if gv.dynamic_baseline:
                if entry.baseline is None:
                    # init with r0
                    if self.numbaselines >= self.knn_baseline.n_neighbors:
                        entry.baseline = self.knn_baseline.predict([self.bucket_index_floating(state)])[0]
                    else:
                        entry.baseline = errsignal  # δu, because we don't want to introduce a bias
                else:
                    # exponentially weighted update
                    entry.baseline += DynamicBaseline.learning_rate * errsignal
                # subtract the dynamic baseline
                errsignal -= entry.baseline
            errsignal = np.clip(errsignal, -self.maxm, self.maxm) + self.bias
        else:
            errsignal = 0
        self.last_r = new_reward_reduced  # to compute delta r in the next step
        self.last_util = util
        return errsignal, entry.utility
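    # In summary, the per-tick error signal with the dynamic baseline enabled is roughly
    #     m(t) = clip(delta_u(t) - b(s_t), -maxm, maxm) + bias,   delta_u(t) = u(t) - u(t-1),
    # where b(s_t) is the per-bucket baseline updated towards delta_u. A quick example with
    # illustrative numbers: u(t) = 1.2, u(t-1) = 1.0 and b(s_t) = 0.15 give delta_u = 0.2 and
    # m = 0.05, which is then injected as the dopamine concentration "n" of the
    # stdp_dopamine_synapse models (see SpikingNeuroController.release_neurotransmitter).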
    def draw_rewards(self, xaxis=3, yaxis=2, show=True):
        """
        Draw the recorded reward map.
        """
        if not self.completedATrial:
            raise AssertionError("Trial not ended. Call end_trial after collecting data.")
        field = np.full((self.bucketsperdim, self.bucketsperdim), np.nan)
        # get the reward values for each bucket
        for entry in self.recordings.values():
            # normalize, then map to bucketsperdim
            coords = np.array(entry.state_index, dtype=np.int)
            # ignore recorded states outside the observed range
            if 0 < coords[xaxis] < self.bucketsperdim and 0 < coords[yaxis] < self.bucketsperdim:
                field[coords[xaxis], coords[yaxis]] = np.average(entry.rewards)
        current_cmap = plt.cm.get_cmap("inferno")
        current_cmap.set_bad(color='green')
        heatmap = plt.imshow(field, cmap=current_cmap, interpolation='nearest')
        plt.title("Reward for States")
        steps = 8
        labels_x = [f"{x:.1f}" for x in
                    np.linspace(self.state_limits.T[0, xaxis], self.state_limits.T[1, xaxis], num=steps)]
        labels_y = [f"{y:.1f}" for y in
                    np.linspace(self.state_limits.T[0, yaxis], self.state_limits.T[1, yaxis], num=steps)]
        plt.xticks(np.linspace(0, self.bucketsperdim, num=steps), labels_x)
        plt.yticks(np.linspace(0, self.bucketsperdim, num=steps), labels_y)
        plt.xlabel(self.state_labels[xaxis])
        plt.ylabel(self.state_labels[yaxis])
        if show:
            plt.show()
        return heatmap

    def draw_utility(self, xaxis=3, yaxis=2, show=True):
        """
        Draw the recorded state-value (utility) map.
        """
        if not self.completedATrial:
            raise AssertionError("Trial not ended. Call end_trial after collecting data.")
        field = np.full((self.bucketsperdim, self.bucketsperdim), np.nan)
        # get the utility values for each bucket
        for entry in self.recordings.values():
            # normalize, then map to bucketsperdim
            coords = np.array(entry.state_index, dtype=np.int)
            # ignore recorded states outside the observed range
            if 0 < coords[xaxis] < self.bucketsperdim and 0 < coords[yaxis] < self.bucketsperdim:
                field[coords[xaxis], coords[yaxis]] = entry.utility
        current_cmap = plt.cm.get_cmap("inferno")
        current_cmap.set_bad(color='green')
        heatmap = plt.imshow(field, cmap=current_cmap, interpolation='nearest')
        plt.title("Value for Visited States")

        steps = 8
        labels_x = [f"{x:.1f}" for x in
                    np.linspace(self.state_limits.T[0, xaxis], self.state_limits.T[1, xaxis], num=steps)]
        labels_y = [f"{y:.1f}" for y in
                    np.linspace(self.state_limits.T[0, yaxis], self.state_limits.T[1, yaxis], num=steps)]
        plt.xticks(np.linspace(0, self.bucketsperdim, num=steps), labels_x)
        plt.yticks(np.linspace(0, self.bucketsperdim, num=steps), labels_y)
        plt.xlabel(self.state_labels[xaxis])
        plt.ylabel(self.state_labels[yaxis])
        if show:
            plt.show()
        return heatmap

    def draw_utility_inferred(self, xaxis=3, yaxis=2, show=True, legend=None):
        """
        Draw the interpolated state-value map.
        """
        # interpolate by querying when available, else interpolate the image
        if self.state_limits.shape[0] == 2:
            # get the values for each bucket, but limit the resolution
            res = np.minimum(self.bucketsperdim, 100)
            field = np.empty((res, res))
            range_x = np.linspace(self.state_limits.T[0, xaxis], self.state_limits.T[1, xaxis], res)
            range_y = np.linspace(self.state_limits.T[0, yaxis], self.state_limits.T[1, yaxis], res)
            # loop over the states in the graph
            # only two dimensions can be iterated and visualized. Therefore,
            # use a state to pick a prototype which can be visualized by modifying only the other states
            bestrecording = next(iter(self.recordings.values()))
            # find the best state in O(n)
            for state in self.recordings.values():
                if state.utility > bestrecording.utility:
                    bestrecording = state
            prototypestate = list(bestrecording.state_index)  # create a mutable copy
            for x, stateX in enumerate(range_x):
                for y, stateY in enumerate(range_y):
                    prototypestate[xaxis] = stateX
                    prototypestate[yaxis] = stateY
                    field[x, y] = self.query_bucket(state=tuple(prototypestate)).utility
        else:
            res = self.bucketsperdim
            interpollist = []
            for entry in self.recordings.values():
                # normalize, then map to bucketsperdim
                coords = np.array(entry.state_index, dtype=np.int)
                # ignore recorded states outside the observed range
                if 0 < coords[xaxis] < self.bucketsperdim and 0 < coords[yaxis] < self.bucketsperdim:
                    interpollist.append((coords[xaxis], coords[yaxis], entry.utility))
            interpollist = np.array(interpollist)
            from scipy.interpolate import griddata
            xi = np.arange(0, self.bucketsperdim)
            yi = np.arange(0, self.bucketsperdim)
            field = griddata(interpollist[:, 0:2], interpollist[:, 2], (xi[None, :], yi[:, None]), method='nearest').T
        current_cmap = plt.cm.get_cmap("inferno")
        current_cmap.set_bad(color='green')
        plt.imshow(field, cmap=current_cmap, interpolation='nearest')
        if legend is not None:
            plt.clim(legend.norm.vmin, legend.norm.vmax)

        plt.title("Interpolated State-Values for every State")

        steps = 8
        labels_x = [f"{x:.1f}" for x in
                    np.linspace(self.state_limits.T[0, xaxis], self.state_limits.T[1, xaxis], num=steps)]
        labels_y = [f"{y:.1f}" for y in
                    np.linspace(self.state_limits.T[0, yaxis], self.state_limits.T[1, yaxis], num=steps)]
        plt.xticks(np.linspace(0, res, num=steps), labels_x)
        plt.yticks(np.linspace(0, res, num=steps), labels_y)
        plt.xlabel(self.state_labels[xaxis])
        plt.ylabel(self.state_labels[yaxis])
        if show:
            plt.show()

    def draw(self, xaxis=3, yaxis=2):
        """Draws plots showing different state-value function mappings."""
        if len(list(self.recordings.values())[0].state_index) < 2:
            print("Cannot draw state maps with only one input dimension.")
            return
        fig, axis = plt.subplots(1, 4, figsize=(18, 4))
        fig.tight_layout()
        # fig.suptitle("State Maps")
        plt.subplot(141)
        self.draw_rewards(xaxis=xaxis, yaxis=yaxis, show=False)
        plt.subplot(142)
        legend = self.draw_utility(xaxis=xaxis, yaxis=yaxis, show=False)
        plt.subplot(143)
        self.draw_utility_inferred(xaxis, yaxis, show=False, legend=legend)
        ax3 = plt.subplot(144)
        ax3.axis('off')
        # might be improved with this: https://stackoverflow.com/a/38940369/2768715
        plt.colorbar(legend)
        # find a free filename
        counter = 0
        filename = f"utility{counter}.pdf"
        while os.path.isfile(filename):
            counter += 1
            filename = f"utility{counter}.pdf"
        plt.savefig(filename)


AbstractCritic.register(DynamicBaseline)
--------------------------------------------------------------------------------
/draw.py:
--------------------------------------------------------------------------------
import os
import sys
from typing import Dict, List, Union

import matplotlib.pyplot as plt
import numpy as np

import fremauxfilter
from actor import Weightstorage
from globalvalues import gv
import datetime
from matplotlib.colors import LightSource


def voltage(measurements, persp="3d"):
    """

    :param measurements: nest.GetStatus(multimeter)[0]["events"]
    :param persp:
    :return:
    """
    # draw voltage curves
    fig = plt.figure(num=0, figsize=(6.5, 3.5), dpi=150)
    if persp == "3d":
        ax = fig.add_subplot(111, projection='3d')
        ax.set_proj_type('ortho')
        for n in range(len(gv.voltageRecordings)):
            voltage = measurements["V_m"][n::len(gv.voltageRecordings)]  # pick every len(...)th measurement, start with offset n
            times = measurements["times"][n::len(gv.voltageRecordings)]
            ax.plot(xs=times % 100, ys=times / 100, zs=voltage, label=f"$V_m$ {list(gv.neuronLabelMap.values())[n]}")
        plt.legend()
        plt.show()
    else:
        ax = fig.add_subplot(111)
        for neuron in gv.voltageRecordings:
            voltage = measurements["V_m"][
                      neuron::len(gv.voltageRecordings)]  # pick every numMeasurements, start with offset n
            times = measurements["times"][neuron::len(gv.voltageRecordings)]
            ax.plot(times, voltage, label=f"$V_m$ {neuron}: {gv.neuronLabelMap[neuron]}")
        plt.legend()
        plt.ylabel("voltage [mV]")
        plt.xlabel("time [ms]")
        plt.title("Voltage over time")
        plt.show()


def spikes(spikes_nest: Dict, outsignal: List[List[float]], output_ids: List[int] = []):
    """
    Draws only the recorded spikes.
    :param outsignal:
    :param spikes_nest:
    :param output_ids: pass the ids of the output channels for a different color and further analysis
    :return:
    """
    last_spiketime: int = spikes_nest["times"][-1]
    lastvalid = np.arange(0, last_spiketime, gv.cycle_length)[-1]  # lazy hack
    # create a list of signals from the nest format
    spiketimes = {}  # dicts are only filled for non-outspikes
    colors = []
    colorcounter = 0
    onlyout = False
    import matplotlib.colors as mcolors
    for i, spiketime in enumerate(spikes_nest["times"]):
        if spiketime > lastvalid:
            continue
        neurid = spikes_nest['senders'][i]
        # if this neuron appears for the first time, create a new list
        if neurid not in spiketimes:
            if not onlyout or neurid in output_ids:
                spiketimes[neurid] = []
            # sort into categories
            if neurid in output_ids:
                colors.append(list(mcolors.TABLEAU_COLORS)[colorcounter % 10])
                colorcounter += 1
            else:
                colors.append("black")
        if not onlyout or neurid in output_ids:
            spiketimes[neurid].append(spiketime)

    # put labels on it
    labels = []
    if len(gv.neuronLabelMap) > 0:
        for key in list(spiketimes.keys()):
            label = gv.neuronLabelMap[key] if key in gv.neuronLabelMap else ""
            labels.append(label)

    # now draw
    height = max(2 + len(spiketimes) / 3, 15)
    fig = plt.figure(figsize=(14, height))

    plt.subplot(411)
    plt.eventplot(spiketimes.values(), linewidths=0.8, colors=colors)
    plt.yticks(range(len(spiketimes.keys())), labels)
    # include the upper limit
    xticks = np.arange(0, last_spiketime, gv.cycle_length)
    if len(xticks) < 500:
        # only draw the xticks if there are not too many
        plt.xticks(xticks)
    plt.grid()
    plt.margins(x=0.03)  # somewhat misaligned because spiking starts neither at 0 nor at the last cycle time
    plt.ylabel("Neuron")
    plt.xlabel("time [ms]")
    plt.title('Spike Events per Neuron')

    if len(output_ids) > 0:
        # plot the filtered activity
        plt.subplot(412)
        outsignal_dict: Dict[str, List[float]] = dict()
        for i, population in enumerate(outsignal):
            outsignal_dict[f"population{i}"] = population
        filtersig = filtered_signal(last_spiketime, outsignal_dict)
        if len(xticks) < 500:
            plt.xticks(xticks)
        plt.margins(x=0.02)
        plt.xlabel("time [ms]")
        plt.ylabel("Activity")

        # plot the sampled activity
        plt.subplot(413)
        read_out_activity(last_spiketime, outsignal_dict, filtered_signal=filtersig)
        if len(xticks) < 500:
            plt.xticks(xticks)
        plt.margins(x=0.02)
        plt.xlabel("time [ms]")
        plt.ylabel("Activity")

        # plot the sampled activity of all recorded neurons
        plt.subplot(414)
        insignal: Dict[str, List[float]] = dict()
        insignals = list(spiketimes.values())
        for i, population in enumerate(insignals):
            insignal[f"population{i}"] = population
        read_out_activity(last_spiketime, insignal)
        plt.margins(x=0.02)
        if len(xticks) < 500:
            plt.xticks(xticks)
        plt.xlabel("time [ms]")
    plt.show()


def filtered_signal(last_spiketime: int, spikes_per_id: Dict[str, List[float]]):
    """Draw a plot with Fremaux filtering"""
    signal = np.empty((len(spikes_per_id), int(last_spiketime)))  # for every ms

    outspikes_list = list(spikes_per_id.values())
    for t in range(0, int(last_spiketime)):
        signal[:, t] = fremauxfilter.filter(t, outspikes_list)

    # draw
    plt.title('Continuous, Filtered Out-Neuron Activity')
    plt.plot(signal.T)
    plt.ylabel("Activity")
    ax = plt.gca()
    ax.xaxis.grid(True)
    plt.legend(spikes_per_id.keys())
    return signal


def read_out_activity(last_spiketime: int, spikes_per_id: Dict[str, List[float]], filtered_signal=None):
    """draws a plot showing the activity of the out-neurons when read."""
    # data for the activity diagram
    activity_in_cycle = np.zeros((len(spikes_per_id),  # y axis
                                  int(last_spiketime // gv.cycle_length) + 1))  # time axis
    nidx = 0
    # for each neuron
    for spiketimes in spikes_per_id.values():
        if len(spiketimes) == 0:
            continue
        for spike in spiketimes:
            cycle = int(spike // gv.cycle_length) + 1  # offset of one because 0-50 is cycle 1
            # might be in no cycle
            if cycle < len(activity_in_cycle[nidx]):
                activity_in_cycle[nidx][cycle] += 1
        nidx += 1
    # x coordinates
    sample_times = range(0, int(activity_in_cycle.shape[1] * gv.cycle_length), int(gv.cycle_length))

    # filtered signal, only at the sample times
    if filtered_signal is not None:
        filtered_signal_sampled = np.empty_like(activity_in_cycle)
        for i, t in enumerate(range(0, filtered_signal.shape[1], int(gv.cycle_length))):
            filtered_signal_sampled[:, i] = filtered_signal[:, t]

        # use least squares to find a scaling that almost fits everywhere
        tominimize = lambda scaling: activity_in_cycle.flatten() - scaling * filtered_signal_sampled.flatten()
        import scipy
        scaling = scipy.optimize.leastsq(tominimize, x0=10)[0][0]
        # normalize
        # scaling = np.max(activity_in_cycle) / np.max(filtered_signal_sampled)
        filtered_signal_sampled *= scaling
        plt.plot(sample_times, filtered_signal_sampled.T, drawstyle='steps-post',
                 label="Filtered Signal Sampled (normalized)")

        plt.title(f"Sampled Activity Out-Neurons (Scaling {scaling:.2f})")
    else:
        plt.title(f"Sampled Activity Out-Neurons")
    # todo: use the neuron output id
    labels = None if len(spikes_per_id) > 5 else "Number of Spikes per Cycle"
    # draw
    plt.plot(sample_times, activity_in_cycle.T, drawstyle='steps-post', label=labels)
    plt.ylabel("Activity")
    ax = plt.gca()
    ax.xaxis.grid(True)
    plt.legend()
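
# fremauxfilter (imported above) is not included in this excerpt. Conceptually, following
# Fremaux et al., spike trains are turned into a continuous activity estimate by summing a
# causal exponential kernel over past spike times. A minimal sketch (the interface and the
# kernel time constant are assumptions, not the repository's actual implementation):
#
#     def filter(time_ms, spike_trains, tau=20.0):
#         """Exponentially filtered spike count per train, evaluated at time_ms."""
#         return [sum(np.exp(-(time_ms - s) / tau) for s in train if s <= time_ms)
#                 for train in spike_trains]
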
Out-Neurons") 197 | #todo use neuron output id 198 | labels = None if len(spikes_per_id) > 5 else "Number of Spikes per Cycle" 199 | # draw 200 | plt.plot(sample_times, activity_in_cycle.T, drawstyle='steps-post', label=labels) 201 | plt.ylabel("Activity") 202 | ax = plt.gca() 203 | ax.xaxis.grid(True) 204 | plt.legend() 205 | 206 | 207 | def return_graph(rewards, show=True, ax=None): 208 | """A plot showing the average reward over time.""" 209 | plt.plot(rewards, label="Returns") 210 | 211 | def running_mean(x, N: int): 212 | cumsum = np.cumsum(np.insert(x, 0, 0)) 213 | return (cumsum[N:] - cumsum[:-N]) / float(N) 214 | 215 | try: 216 | plt.plot(running_mean(rewards, min(int(len(rewards) / 4), 40)), label="moving average") 217 | except ValueError: 218 | pass 219 | print(f"final avg. reward of last 100 {np.average(rewards[-100:])}/195") 220 | plt.legend() 221 | plt.ylabel("Return") 222 | plt.xlabel("Trial Number") 223 | plt.title('Return per trial') 224 | if ax is not None: 225 | ax.axhline(y=np.average(rewards), xmin=0.0, xmax=1.0, linestyle='--', dashes=(0.5, 5.)) 226 | if show: 227 | plt.show() 228 | 229 | 230 | def weight_changes(syn_w: np.ndarray, connections: np.ndarray = None, show=True): 231 | """ 232 | Draw changes in weights 233 | :param syn_w: time dimension, then list of weight 234 | :param connections: should match connetions 235 | :param show: 236 | :return: 237 | """ 238 | plt.title('Weight with reward factor ' + str(gv.errsig_factor)) 239 | # syn_w = syn_w.reshape(syn_w.shape[0], syn_w.shape[1]*syn_w.shape[2]) 240 | 241 | weights = syn_w[..., 2] # all times, all weights, only weights 242 | 243 | plt.plot(range(len(weights)), weights) 244 | if connections is not None and len(connections) <= 11: 245 | labels = [f"{x[0]}->{x[1]}" for x in connections] 246 | plt.gca().legend(labels) 247 | plt.ylabel("weight") 248 | plt.xlabel("episode") 249 | if show: 250 | plt.show() 251 | 252 | 253 | def error_signal(utility, fig=None, persp="heat"): 254 | plt.title('Used error signals') 255 | max_cycles = np.max(np.argmin(utility, axis=1)) + 1 256 | valuesToShow = utility[:, :max_cycles].T # switch cycles and episodes 257 | # plt.plot(valuesToShow) 258 | 259 | if persp == "3d": 260 | cycle_axis = np.linspace(0, valuesToShow.shape[0], valuesToShow.shape[0]) 261 | episode_axis = np.linspace(0, valuesToShow.shape[1], valuesToShow.shape[1]).T 262 | 263 | sx = cycle_axis.size 264 | sy = episode_axis.size 265 | 266 | cycle_axis = np.tile(cycle_axis, (sy, 1)) 267 | episode_axis = np.tile(episode_axis, (sx, 1)).T 268 | if fig is None: 269 | newfig = plt.figure() 270 | ax = newfig.add_subplot(1, 1, 1, projection='3d') 271 | else: 272 | ax = fig.add_subplot(3, 1, 1, projection='3d') 273 | light = LightSource(315, 45) 274 | cm = plt.cm.get_cmap("inferno") 275 | azimuth = 45 276 | altitude = 60 277 | ax.view_init(altitude, azimuth) 278 | if gv.num_episodes > 1: 279 | illuminated_surface = light.shade(valuesToShow, cmap=cm) 280 | ax.plot_surface(cycle_axis, episode_axis, valuesToShow, rstride=1, cstride=1, linewidth=0, 281 | antialiased=False, facecolors=illuminated_surface, label="utilities") 282 | else: 283 | ax.plot_surface(cycle_axis, episode_axis, valuesToShow, rstride=1, cstride=1, linewidth=0, 284 | antialiased=False, label="utilities") 285 | else: 286 | current_cmap = plt.cm.get_cmap("inferno") 287 | current_cmap.set_bad(color='green') 288 | legend = plt.imshow(valuesToShow, cmap=current_cmap, interpolation='nearest') 289 | plt.colorbar(legend) 290 | plt.ylabel("cycle") 291 | 
plt.xlabel("episode") 292 | # plt.title("Utility per cycle") #title is covered in condensed view 293 | # plt.legend() 294 | if fig is None: 295 | plt.show() 296 | 297 | 298 | def utilities_over_time(utils): 299 | plt.plot(utils.T) 300 | plt.title("Utilities over time") 301 | plt.xlabel("cycles") 302 | plt.ylabel("Utility") 303 | plt.show() 304 | 305 | 306 | def report(utility, weights: Union[List[Weightstorage], np.ndarray], returnpereps, connections: np.ndarray, filename=None, 307 | env=None): 308 | """ 309 | Draw a report consisting of many parts 310 | :param env: 311 | :param filename: 312 | :param utility: 313 | :param weights: 314 | :param returnpereps: 315 | :param connections: 316 | :return: 317 | """ 318 | fig = plt.figure(figsize=(16, 14)) 319 | fig.suptitle("Report " + str(datetime.datetime.now())) 320 | # plt.subplot(321) 321 | # connectome(connections) 322 | plt.subplot(322) 323 | try: 324 | error_signal(utility, fig=fig) 325 | except: 326 | print("Rendering of error signal history failed.") 327 | 328 | if env is not None and not gv.headless: 329 | plt.subplot(323) 330 | try: 331 | plt.imshow(env.render(mode='rgb_array')) 332 | except: 333 | print("Rendering of env failed.") 334 | 335 | plt.subplot(324) 336 | try: 337 | if not isinstance(weights, np.ndarray): 338 | weights = np.array(weights) 339 | weight_changes(weights, connections, show=False) 340 | except: 341 | print("Rendering of weight changes failed.") 342 | 343 | plt.subplot(325) 344 | plt.text(0.0, 0.5, str(gv.workerdata), fontsize=12, wrap=True) 345 | 346 | ax = plt.subplot(326) 347 | returnpereps = np.array(returnpereps) 348 | try: 349 | nr = np.isnan(returnpereps) 350 | returnpereps[nr] = 0 351 | return_graph(returnpereps, show=False, ax=ax) 352 | except: 353 | e = sys.exc_info()[0] 354 | print("Rendering of return history failed:" + str(e)) 355 | 356 | if filename is None: 357 | # find free filename 358 | counter = 0 359 | filename = f"report{counter}.pdf" 360 | while os.path.isfile(filename): 361 | counter += 1 362 | filename = f"report{counter}.pdf" 363 | plt.savefig(filename) 364 | -------------------------------------------------------------------------------- /exp.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import multiprocessing 3 | import os 4 | import argparse 5 | import pickle 6 | import sys 7 | from typing import Tuple, List, Dict, Callable, Optional, Union 8 | 9 | import mongoengine 10 | import numpy as np 11 | import atexit 12 | import datetime 13 | 14 | from gym.envs.classic_control import CartPoleEnv 15 | from globalvalues import gv 16 | import draw 17 | import models.trainingrun as tg 18 | from agent import Agent 19 | from lineFollowingEnvironment import LineFollowingEnv 20 | from lineFollowingEnvironment2 import LineFollowingEnv2 21 | from mongoengine import Document, FileField, ListField, StringField, BinaryField, IntField, DateTimeField, FloatField, \ 22 | ReferenceField, disconnect, DictField 23 | import json 24 | 25 | 26 | class Experiment(Document): 27 | """ 28 | Creates and manages a traiing environment with a SNN RL algorithm. 
29 | """ 30 | training = ReferenceField(tg.Trainingrun) 31 | parameterdump = StringField() 32 | time_start = DateTimeField() 33 | time_end = DateTimeField() 34 | time_elapsed = FloatField() # in s 35 | diagrams = ListField(FileField()) 36 | cycle_i = IntField(default=0) 37 | totalCycleCounter = IntField(default=0) 38 | episode = IntField(default=0) 39 | return_per_episode_sum = ListField(FloatField()) # per episode 40 | log_Δw = ListField(FloatField()) # per episode 41 | log_m = ListField(FloatField()) # per episode 42 | epslength = ListField(IntField()) 43 | episodedata = ListField(ReferenceField(tg.Episode))#stores references to all episodes 44 | workerdata = DictField() 45 | 46 | def __init__(self, *args, **values): 47 | super().__init__(*args, **values) 48 | gv.init() 49 | self.printedbias = False 50 | self.env = None 51 | self.penalty = -8 # penalty for ending 52 | self.errsig_contingent = [0] 53 | 54 | self.return_per_episode_sum = [] 55 | self.totalCycleCounter = -1 # will be increased at the beginning of the cycle 56 | 57 | self.log_Δw = [] 58 | self.log_m = [] 59 | self.rewards: List = [] # reward of last episode 60 | self.errsigs = None 61 | self.utils = None 62 | self.agent: Agent = None 63 | 64 | self.lastweights: np.array = 0 # initialized with 0 so that it can be used in computation 65 | 66 | self.epslength = [] # stores the number of cycles for each episode 67 | 68 | def cycle(self, observation_in: np.array) -> np.array: 69 | """Calculates brain one frame, applies action and simulates environment for a frame 70 | : observation_in: the last observation 71 | :return float values 72 | """ 73 | if gv.render: 74 | self.env.render() 75 | self.totalCycleCounter += 1 76 | # feed observations into brain 77 | action, neural_activity = self.agent.actor.cycle(time=gv.cycle_length * self.cycle_i, 78 | observation_in=observation_in) 79 | # simulate environment 80 | observation, reward, done, info = self.env.step(action) 81 | reward_internal = reward 82 | # distance from ideal position 83 | # if isinstance(self.env.env, CartPoleEnv): 84 | # reward_internal = 50 * np.math.cos(observation[2]) 85 | if not self.printedbias: 86 | print("Bias: " + str(reward + self.penalty)) 87 | self.printedbias = True 88 | 89 | try: # try because of env.env 90 | if done and not (isinstance(self.env.env, CartPoleEnv) and self.cycle_i >= 200): 91 | #add a penalty for cartpole when failed 92 | reward_internal += self.penalty 93 | except: 94 | pass 95 | 96 | err_signal, util = self.agent.critic.tick(state=observation, new_rewards=[reward_internal]) 97 | # store unedited 98 | if not gv.demo: 99 | self.errsigs[self.episode, self.cycle_i] = err_signal 100 | # self.utils[self.episode, self.totalCycleCounter] = util 101 | self.rewards.append(reward) 102 | 103 | # clamp utility 104 | if gv.max_util_integral != float("inf"): 105 | if abs(self.errsig_contingent[-1] + err_signal) >= gv.max_util_integral: 106 | err_signal = 0 107 | self.errsig_contingent.append(self.errsig_contingent[-1] + err_signal) 108 | 109 | # gv.outactivity["utility"].append(utility) 110 | 111 | if gv.structural_plasticity: 112 | self.agent.actor.connectome.update_structural_plasticity() 113 | 114 | # Set reward signal for left and right network 115 | self.agent.actor.release_neurotransmitter(err_signal * gv.errsig_factor) 116 | 117 | self.agent.end_cycle(self.cycle_i) 118 | return done, observation 119 | 120 | def simulate_episode(self) -> bool: 121 | """Simulate one episode 122 | :return: True if everything went okay. 
False if training needs to be canceled 123 | """ 124 | if self.episode > 0: 125 | self.agent.prepare_episode() 126 | 127 | observation = self.env.reset() 128 | self.rewards.clear() 129 | for self.cycle_i in range(gv.max_cycles): 130 | # if failed, break early 131 | done, observation = self.cycle(observation_in=observation) 132 | if done: 133 | break 134 | # extra simulation time to apply changes in last cycle before resetting 135 | self.agent.post_episode() 136 | self.epslength.append(self.cycle_i) 137 | 138 | return self.post_episode() 139 | 140 | def post_episode(self) -> bool: 141 | """ 142 | :return: True if everything went okay. False if training needs to be canceled 143 | """ 144 | eps: tg.Episode = tg.Episode() 145 | eps.rewards = self.rewards 146 | if gv.save_to_db: 147 | eps.episode = self.episode 148 | if len(self.agent.actor.log_m)>0: 149 | eps.neuromodulator = self.agent.actor.log_m 150 | self.log_m.append(np.average(eps.neuromodulator)) 151 | 152 | # extract the last weights 153 | try: 154 | weights = np.array(list(self.agent.get_weights().values())) 155 | except: 156 | weights = self.agent.get_weights() 157 | # check if no weight changed -> Early termination 158 | Δw: float = np.sum(weights - self.lastweights) 159 | self.log_Δw.append(Δw) 160 | if gv.allow_early_termination and self.episode > 50 and -0.00001 < Δw < 0.00001: 161 | self.early_termination(eps, weights) 162 | return False 163 | self.lastweights = weights 164 | 165 | self.return_per_episode_sum.append(np.sum(self.rewards)) 166 | if gv.save_to_db: 167 | # save at the end of the training 168 | if self.episode > 0 and self.episode % (gv.num_episodes-1) == 0: 169 | self.save_episode(eps, weights) 170 | self.save() 171 | if not gv.demo: 172 | self.agent.end_episode(self.episode) 173 | return True 174 | 175 | def early_termination(self, eps, weights): 176 | print("\nEarly termination because Δw=0.") 177 | # todo log a message in the db 178 | if gv.save_to_db: 179 | #eps.activation = list(np.average(np.array(self.agent.actor.log_activation), axis=0)) 180 | eps.neuromodulator = self.agent.actor.log_m 181 | self.save_episode(eps, weights) 182 | try: 183 | self.agent.actor.connectome.drawspikes() 184 | except AttributeError: 185 | pass 186 | self.save() 187 | 188 | def save_episode(self, eps, weights): 189 | eps.weights_human = weights.tolist() 190 | eps.weights = pickle.dumps(weights) 191 | eps.save() 192 | self.episodedata.append(eps.id) 193 | 194 | def train(self): 195 | """Trains the agent for given numbers""" 196 | # extend on existing recordings 197 | self.errsigs = np.full((self.episode + gv.num_episodes, gv.max_cycles), np.nan) 198 | for episode_training in range(gv.num_episodes): 199 | # episode_training=0 200 | # while self.totalCycleCounter < gv.max_cycles: 201 | episode_training += 1 202 | 203 | # simulate 204 | if not self.simulate_episode(): 205 | break 206 | # "CartPole-v0 defines solving as getting average return of 195.0 over 100 consecutive trials." 
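# a hedged sketch of the corresponding solved check (a hypothetical helper, not
# part of this code base; it assumes gym's criterion quoted in the comment above):
# def solved(returns_per_episode: List[float]) -> bool:
#     return (len(returns_per_episode) >= 100
#             and np.average(returns_per_episode[-100:]) >= 195.0)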
207 | last100return = np.average(self.return_per_episode_sum[self.episode-100:self.episode+1]) 208 | 209 | # time/performance evaluation 210 | tpe = (datetime.datetime.utcnow() - self.time_start) / episode_training 211 | # tpc = (datetime.datetime.utcnow() - self.time_start) / self.totalCycleCounter 212 | # eta = tpc * (gv.max_cycles - self.totalCycleCounter) 213 | eta = tpe * (gv.num_episodes - episode_training) 214 | overwrite = "\r" if self.episode > 0 else "" 215 | sys.stdout.write( 216 | f"{overwrite}{self.episode * 100 / gv.num_episodes:3.3f}% (Episode: {self.episode}, Cycle:{self.totalCycleCounter}) ETA {eta}. Avg. return: {last100return:.1f}") 217 | sys.stdout.flush() 218 | 219 | # plots 220 | if gv.num_plots > 0 and gv.num_episodes > gv.num_plots and self.episode % ( 221 | gv.num_episodes // gv.num_plots) == 0: 222 | # draw.voltage(self.agent.actor.connectome.multimeter, persp="2d") 223 | try: 224 | self.agent.actor.connectome.drawspikes() 225 | except AttributeError: 226 | pass 227 | self.episode += 1 228 | print(f"Cycles: {self.totalCycleCounter}") 229 | 230 | def drawreport(self): 231 | # self.agent.critic.draw(xaxis=0, yaxis=1) 232 | filename = f"{self.id}.png" if self.id is not None else None 233 | try: 234 | connectome = self.agent.actor.connectome.conns 235 | except: 236 | connectome = None 237 | draw.report(utility=self.errsigs, 238 | weights=np.array(self.agent.actor.weightlog), 239 | returnpereps=self.return_per_episode_sum, 240 | connections=connectome, 241 | filename=filename, 242 | env=self.env) 243 | 244 | def presetup(self): 245 | print("Process w/ worker id " + str(multiprocessing.current_process())) 246 | dbconnect() 247 | 248 | self.time_start = datetime.datetime.utcnow() 249 | if gv.save_to_db: 250 | self.save() # safe first to get id 251 | 252 | # pre-training 253 | def dump(obj): 254 | f = "" 255 | for attr in dir(obj): 256 | if attr != "__dict__": 257 | f += "obj.%s = %r" % (attr, getattr(obj, attr)) + "\n" 258 | return f 259 | 260 | self.parameterdump = dump(gv) 261 | # dump(f, self) 262 | # dump(f, self.agent.critic) 263 | 264 | # register instance 265 | self.training.instances.append(str(self.id)) 266 | if gv.save_to_db: 267 | self.training.save() 268 | 269 | def posttrain(self): 270 | # stats 271 | self.time_end = datetime.datetime.utcnow() 272 | self.time_elapsed = (self.time_end - self.time_start).total_seconds() 273 | if gv.save_to_db: 274 | self.save() 275 | if isinstance(self.env, LineFollowingEnv) or isinstance(self.env, LineFollowingEnv2): 276 | self.drawreport() 277 | 278 | self.env.close() 279 | # if not gv.render: 280 | # self.show() 281 | 282 | def run(self, workerdata: Dict = None) -> List[float]: 283 | """ 284 | Create and trains the network. 
285 | :param workerdata: per-worker parameters; must contain the "training" reference
286 | and may contain a "configurator" callable
287 | :return: the results of the training (return sum per episode)
288 | """
289 | self.training = workerdata.pop("training")
290 | self.presetup()
291 | 
292 | self.workerdata = workerdata
293 | gv.workerdata = workerdata  # not nice to add it as a global variable
294 | 
295 | # create experiment
296 | configurator: Callable
297 | if "configurator" in workerdata and workerdata["configurator"] is not None:
298 | configurator = workerdata.pop("configurator")
299 | else:
300 | from experiments import lf_placecells
301 | configurator = lf_placecells.configure_training
302 | configurator(self)
303 | 
304 | # parse some gridsearch parameters to overwrite the configurator
305 | if workerdata:
306 | for (key, value) in self.workerdata.items():
307 | if hasattr(gv, key):
308 | setattr(gv, key, value)
309 | elif key == "vq_lr_int":
310 | gv.vq_learning_scale = list([0, 10 ** -4, 10 ** -3, 10 ** -2])[int(value)]
311 | elif key == "vq_decay_int":
312 | gv.vq_decay = list([0, 10 ** -4, 10 ** -3, 10 ** -2])[int(value)]
313 | else:
314 | print("unknown gridsearch hyperparameter " + key)
315 | 
316 | # run the training
317 | self.train()
318 | self.posttrain()
319 | 
320 | return self.return_per_episode_sum
321 | 
322 | def show(self):
323 | global gv
324 | gv_old = copy.deepcopy(gv)
325 | gv.errsig_factor = 0.
326 | gv.structural_plasticity = False
327 | gv.render = True
328 | gv.demo = True
329 | self.agent.prepare_episode()
330 | self.simulate_episode()
331 | gv = gv_old
332 | 
333 | 
334 | def runworker(dataperworker: Optional[Dict]) -> List[float]:
335 | """
336 | Set up a worker (process) and run an experiment.
337 | :param dataperworker: parameters for this worker, including the training reference
338 | :return: the return per episode of this run
339 | """
340 | # redundant copy of the method because the gridsearch returns validation errors
341 | # there was a crash when the db was disabled with a gridsearch pool
342 | # this cannot be a local function because it would cause a crash
343 | return Experiment().run(dataperworker)
344 | 
345 | 
346 | def gridsearch(num_processes: int, training, configurator: Callable) -> List:
347 | """Perform a grid search on the given training data."""
348 | pool = multiprocessing.Pool(num_processes)
349 | withoutgivenvalues = filter(lambda v: "from" in v, training.gridsearch.values())
350 | # todo insert ranges in gridsearch
351 | withgivenvalues = filter(lambda v: "range" in v, training.gridsearch.values())
352 | parameters: List[slice] = [slice(rangedetails["from"], rangedetails["to"], complex(rangedetails["steps"])) for
353 | rangedetails in withoutgivenvalues]
354 | rangesgridsearch: np.array = np.mgrid[parameters].reshape(len(parameters), -1).T
355 | # build an array containing the parameters per worker
356 | workload: List[Dict] = []
357 | paramnameslist = list(training.gridsearch.keys())
358 | for workerdata in rangesgridsearch:
359 | # each gets a training reference
360 | obj = {"training": training}
361 | if configurator is not None:
362 | obj["configurator"] = configurator
363 | for paramidx, param in enumerate(workerdata):
364 | obj[paramnameslist[paramidx]] = param
365 | workload.append(obj)
366 | result = pool.map(func=runworker, iterable=workload)
367 | pool.close()
368 | pool.join()
369 | 
370 | if len(parameters) == 2:
371 | numcolums = list(training.gridsearch.values())[0]["steps"]
372 | resultnp = np.array(result).reshape((-1, numcolums))
373 | table = "\\begin{center}\\begin{tabular}{ | l | l | l | l | l |}\\hline\n"
374 | table += "num.~cells & $\\lambda =0$ (no vq) & $\\lambda =0.0001$ & $\\lambda =0.001$ & $\\lambda =0.01$"
375 | for 
i, resultitem in enumerate(result): 376 | if i % numcolums == 0: 377 | table += "\\\\ \\hline\n" 378 | table += str(int(i / numcolums)) # row name 379 | 380 | table += f" & {np.average(resultitem):.0f}" 381 | table += "\\\\ \\hline\n" 382 | table += "\\end{tabular}\\end{center}" 383 | print(table) 384 | return result 385 | 386 | 387 | dirname = "" 388 | 389 | 390 | def exit_handler(): 391 | os.chdir("../../") 392 | global dirname 393 | if len(os.listdir(dirname)) == 0: 394 | os.rmdir(dirname) 395 | 396 | 397 | def createexpdatadir(): 398 | """create new directory for test results and switches to it""" 399 | counter = 0 400 | dirbase = "experimentdata/gsrewardsignal" 401 | global dirname 402 | dirname = dirbase + str(counter) 403 | while os.path.isdir(dirname): 404 | counter += 1 405 | dirname = dirbase + str(counter) 406 | os.makedirs(dirname) 407 | os.chdir(dirname) 408 | print(f"saving to {dirname}\n") 409 | 410 | atexit.register(exit_handler) 411 | 412 | 413 | def dbconnect(): 414 | mongoengine.connect( 415 | db='snntrainings', 416 | username='', 417 | password='', 418 | port=45920, 419 | authentication_source='admin', 420 | host='' 421 | ) 422 | 423 | 424 | def trainingrun(configurator: Callable = None, num_processes: int = 1, gridsearchpath: str = None) -> Tuple[Union[None, Experiment], List]: 425 | """ 426 | Creates an experiemnt and runs it. 427 | :param configurator: 428 | :param num_processes: 429 | :param gridsearchpath: 430 | :return: if a single experiment returns this. None if gridsearch. 431 | """ 432 | training = tg.Trainingrun() 433 | 434 | training.time_start = datetime.datetime.utcnow() 435 | if gv.save_to_db: 436 | dbconnect() 437 | training.save() 438 | print(f"💾DB Trainingrun: ObjectId(\"{training.id}\")") 439 | disconnect() 440 | # if gridsearch 441 | singleexp = None 442 | if gridsearchpath is not None: 443 | with open(gridsearchpath, "r") as file: 444 | training.gridsearch = json.loads(file.read()) 445 | createexpdatadir() 446 | result = gridsearch(num_processes, training, configurator) 447 | else: 448 | # if not a gridsearch 449 | createexpdatadir() 450 | datasingleworker = {"training": training} if configurator is None else {"training": training, 451 | "configurator": configurator} 452 | singleexp = Experiment() 453 | result = singleexp.run(datasingleworker) 454 | 455 | training.time_end = datetime.datetime.utcnow() 456 | training.time_elapsed = (training.time_end - training.time_start).total_seconds() 457 | if gv.save_to_db: 458 | dbconnect() 459 | training.save() 460 | disconnect() 461 | print(f"{training.time_elapsed / 60:10.1f} min") 462 | 463 | return singleexp, result 464 | 465 | import matplotlib.pyplot as plt 466 | 467 | # plt.plot(gv.outactivity["out1"], label="ouput 0") 468 | # plt.plot(gv.outactivity["out2"], label="ouput 1") 469 | # #plt.plot(np.array(gv.outactivity["action"])*30, label="action") 470 | # plt.plot(gv.outactivity["in1"], label="input 1") 471 | # plt.plot(gv.outactivity["in2"], label="input 2") 472 | # # plt.plot(np.array(exp.outactivity[2])*80, label="utility") 473 | # # plt.plot(exp.utilitycontingent, label="used utility") 474 | # for xc in exp.epslength: 475 | # plt.axvline(x=xc, color='k') 476 | # plt.title("Experiment") 477 | # plt.xlabel("cycle") 478 | # plt.legend() 479 | # plt.show() 480 | 481 | # for exp in exps: 482 | # exp.join() 483 | 484 | 485 | def parseargs(): 486 | parser = argparse.ArgumentParser(description='Process some integers.') 487 | parser.add_argument('--processes', type=int, default=multiprocessing.cpu_count(), 
488 | help='The number of cores. Currently only supporting multi-cores in grid search.') 489 | parser.add_argument('-g', '--gridsearch', type=str, default=None, help='json specifing grid search parameter') 490 | parser.add_argument('--headless', action='store_true', help='Do not render.') 491 | args = parser.parse_args() 492 | gv.headless = args.headless 493 | if gv.headless: 494 | gv.render = False 495 | return args 496 | 497 | 498 | if __name__ == "__main__": 499 | args = parseargs() 500 | trainingrun(num_processes=args.processes, gridsearchpath=args.gridsearch) 501 | -------------------------------------------------------------------------------- /experiments/amplificationtest.py: -------------------------------------------------------------------------------- 1 | #%% 2 | # import nest 3 | import numpy as np 4 | import nest.raster_plot 5 | 6 | import matplotlib.pylab as plt 7 | 8 | 9 | # this experiments shows the transfer function of a rate coded 10 | if __name__ == "__main__": 11 | # %% 12 | inoutmap = {} 13 | #nest.SetKernelStatus({'resolution': 0.01}) 14 | nest.set_verbosity("M_ERROR") 15 | nsteps = 460 16 | for i in range(nsteps): 17 | nest.ResetKernel() 18 | #alternatively use a poisson generator 19 | #inp = nest.Create("poisson_generator", {"origin":0.0, "start":0.0,"stop":100.0, "rate": float(i*6)}}) 20 | inp = nest.Create("spike_generator", {"allow_offgrid_times": True,"spike_times": np.linspace(0.1,100.1,int(i*0.4))}) 21 | outp = nest.Create("iaf_psc_alpha") #linear with aeif_psc_alpha 22 | sd = nest.Create('spike_detector') 23 | sd2 = nest.Create('spike_detector') 24 | nest.Connect(inp, outp, syn_spec={"weight": 1000.0}) 25 | nest.Connect(inp, sd) 26 | nest.Connect(outp, sd2) 27 | spike_det = nest.Create("spike_detector") 28 | nest.Connect(inp, spike_det) 29 | nest.Connect(outp, spike_det) 30 | 31 | nest.Simulate(100.0) 32 | 33 | count_in = len(nest.GetStatus(sd)[0]['events']['times']) 34 | count_out = len(nest.GetStatus(sd2)[0]['events']['times']) 35 | inoutmap[count_in] = count_out 36 | #%% 37 | # scount = spike_det.get("n_events") 38 | # if scount>0: 39 | # nest.raster_plot.from_device(spike_det, hist=False) 40 | # plt.show() 41 | #%% 42 | x,y = zip(*inoutmap.items()) 43 | plt.plot(x,y,marker='.', label="rate code transfer function") 44 | plt.title("Rate Code Transfer Function") 45 | plt.xlabel("Number of Presynaptic Spikes") 46 | plt.ylabel("Number of Elicited Spikes") 47 | plt.grid() 48 | plt.show() 49 | #draw.spikes(nest.GetStatus(sd)[0]["events"], outsignal=nest.GetStatus(sd2)[0]["events"], output_ids=["in", "out"]) -------------------------------------------------------------------------------- /experiments/fullynet.py: -------------------------------------------------------------------------------- 1 | #this fixes exp not beeing able to import because it is not in the pythonpath 2 | import os,sys,inspect 3 | 4 | 5 | current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 6 | parent_dir = os.path.dirname(current_dir) 7 | sys.path.insert(0, parent_dir) 8 | 9 | import exp 10 | from SpikingNeuroController import SpikingNeurocontroller 11 | from globalvalues import gv 12 | import numpy as np 13 | 14 | def configure_training(expenv): 15 | # fully connected net with relative input encoding 16 | gv.num_hidden_neurons = 2 17 | gv.manual_num_inhibitory = 0 18 | gv.w0_min = 200. # Minimum initial random value 19 | gv.w0_max = 800. 
# Maximum initial random value 20 | gv.num_episodes = 10 21 | gv.errsig_factor = 0.03 22 | gv.render = False 23 | gv.structural_plasticity = False 24 | gv.random_reconnect = True 25 | expenv.createLineFollowing(SpikingNeurocontroller(num_inputneurons=2, neuron_labels=["error pos", "error neg"])) 26 | expenv.agent.actor.obsfactor = np.array([800 for x in range(2)]) # clamped anyway 27 | 28 | if __name__ == "__main__": 29 | args = exp.parseargs() 30 | exp.trainingrun(configure_training, num_processes=args.processes, gridsearchpath=args.gridsearch) 31 | -------------------------------------------------------------------------------- /experiments/lf_absolutenet.py: -------------------------------------------------------------------------------- 1 | #this fixes exp not beeing able to import because it is not in the pythonpath 2 | import os,sys,inspect 3 | 4 | 5 | current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 6 | parent_dir = os.path.dirname(current_dir) 7 | sys.path.insert(0, parent_dir) 8 | 9 | from SpikingNeuroController import SpikingNeurocontroller 10 | from critic import DynamicBaseline 11 | from lineFollowingEnvironment2 import LineFollowingEnv2 12 | import exp 13 | from globalvalues import gv 14 | import numpy as np 15 | from agent import Agent 16 | 17 | def configure_training(expenv): 18 | # manual minimal net with absolute input encoding as depictured in figure c) 19 | gv.num_hidden_neurons = 4 20 | gv.manual_num_inhibitory = 2 21 | gv.manualwiring = [ 22 | (3, 6, True), 23 | (3, 9, True), 24 | (4, 7, True), 25 | (4, 8, True), 26 | (7, 9, False), 27 | (6, 8, False) 28 | ] 29 | gv.w0_min = 800. # Minimum initial random value 30 | gv.w0_max = 800. # Maximum initial random value 31 | gv.num_episodes = 200 32 | gv.structural_plasticity = False 33 | lfenv = LineFollowingEnv2(absolute_observation=True) 34 | lfenv.seed(gv.seed) 35 | placecell_range = np.array([[-lfenv.track_width * 1 / 3.0, lfenv.track_width * 1 / 3.0], 36 | [-lfenv.track_width * 1 / 3.0, lfenv.track_width * 1 / 3.0]]) 37 | critic = DynamicBaseline(obsranges=placecell_range) 38 | expenv.agent = Agent(expenv.env, 39 | actor=SpikingNeurocontroller(placecell_range=None, num_neurons_per_dim=1, neuron_labels=["current Position", "next Position"], env=expenv.env), 40 | critic=critic) 41 | expenv.penalty = -2. 
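# for reference: each gv.manualwiring tuple above reads (pre population, post population, excitatory?).
# a minimal sketch of how such a table could be expanded into synapses, assuming a
# hypothetical connect() helper (the real wiring happens inside the connectome):
# for pre, post, excitatory in gv.manualwiring:
#     connect(pre, post, sign=+1 if excitatory else -1)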
42 | expenv.agent.actor.only_positive_input = True 43 | expenv.agent.actor.obsfactor = np.array([800 for x in range(2)]) # clamped anyway 44 | 45 | 46 | if __name__ == "__main__": 47 | args = exp.parseargs() 48 | exp.trainingrun(configure_training, num_processes=args.processes, gridsearchpath=args.gridsearch) 49 | -------------------------------------------------------------------------------- /experiments/lf_placecells.py: -------------------------------------------------------------------------------- 1 | #this fixes exp not beeing able to import because it is not in the pythonpath 2 | import os,sys,inspect 3 | 4 | 5 | current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 6 | parent_dir = os.path.dirname(current_dir) 7 | sys.path.insert(0, parent_dir) 8 | import exp 9 | from SpikingNeuroController import SpikingNeurocontroller 10 | from agent import Agent 11 | from exp import Experiment 12 | from globalvalues import gv 13 | import numpy as np 14 | 15 | from critic import DynamicBaseline 16 | from lineFollowingEnvironment2 import LineFollowingEnv2 17 | 18 | 19 | def configure_training(expenv: Experiment): 20 | """error encoded in place cells""" 21 | print("Preparing placenet experiment.") 22 | gv.manual_num_inhibitory = 0 23 | gv.structural_plasticity = False # might be swiched to true later 24 | gv.random_reconnect = False 25 | gv.population_size = 1 26 | gv.num_episodes = 300 27 | gv.criticresolution = 40 28 | gv.dynamic_baseline = False 29 | gv.vq_learning_scale = 0 30 | #the reward factor is closely connected to the output size of the ciritic 31 | # values are so small, so we scale it to boost learning 32 | gv.errsig_factor = 0.002 #0.03 unfiltered, 0.008 for utility 33 | gv.w0_min = 500. # Minimum initial random value 34 | gv.w0_max = 800. # Maximum initial random value 35 | #relative input with place cell encoding currently does not work 36 | relative_input = False #externalize computations of difference 37 | lfenv = LineFollowingEnv2(absolute_observation=not relative_input) 38 | if relative_input: 39 | placecell_range = np.array([[-lfenv.track_width * 1 / 3.0, lfenv.track_width * 1 / 3.0]]) 40 | else: 41 | placecell_range = np.array([[-lfenv.track_width * 1 / 3.0, lfenv.track_width * 1 / 3.0], 42 | [-lfenv.track_width * 1 / 3.0, lfenv.track_width * 1 / 3.0]]) 43 | num_neurons_per_dim = np.full(len(placecell_range), fill_value=2) 44 | 45 | expenv.env = lfenv 46 | expenv.env.seed(gv.seed) 47 | critic = DynamicBaseline(obsranges=placecell_range) 48 | expenv.agent = Agent(expenv.env, 49 | actor=SpikingNeurocontroller(placecell_range=placecell_range, num_neurons_per_dim=num_neurons_per_dim, env=expenv.env), 50 | critic=critic) 51 | expenv.agent.placell_nneurons_per_dim = num_neurons_per_dim 52 | expenv.penalty = 0. 
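# a minimal sketch of what the place cell encoding above amounts to, assuming evenly
# spaced Gaussian tuning curves over placecell_range (the actual activation layer in
# the actor may differ in detail):
# centers = np.linspace(placecell_range[0][0], placecell_range[0][1], num_neurons_per_dim[0])
# sigma = (centers[1] - centers[0]) / 2
# activation = np.exp(-(observation[0] - centers) ** 2 / (2 * sigma ** 2))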
53 | expenv.agent.actor.obsfactor = 400 # clamped if too big 54 | 55 | if __name__ == "__main__": 56 | args = exp.parseargs() 57 | single, res = exp.trainingrun(configure_training, num_processes=args.processes, gridsearchpath=args.gridsearch) 58 | if single is not None: 59 | single.drawspikes() 60 | -------------------------------------------------------------------------------- /experiments/lf_relative_prewired.py: -------------------------------------------------------------------------------- 1 | #this fixes exp not beeing able to import because it is not in the pythonpath 2 | import os,sys,inspect 3 | 4 | 5 | current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 6 | parent_dir = os.path.dirname(current_dir) 7 | sys.path.insert(0, parent_dir) 8 | 9 | from SpikingNeuroController import SpikingNeurocontroller 10 | import exp 11 | from agent import Agent 12 | from globalvalues import gv 13 | import numpy as np 14 | 15 | from critic import DynamicBaseline 16 | from lineFollowingEnvironment2 import LineFollowingEnv2 17 | 18 | 19 | def configure_training(expenv): 20 | # manual minimal net with relative input encoding as depictured in figure d) 21 | gv.num_hidden_neurons = 0 22 | gv.manual_num_inhibitory = 0 23 | gv.manualwiring = [ 24 | (0, 3, True), 25 | (1, 2, True), 26 | ] 27 | gv.structural_plasticity = False 28 | gv.w0_min = 500. # Minimum initial random value 29 | gv.w0_max = 800. # Maximum initial random value 30 | gv.num_episodes = 40 31 | gv.errsig_factor = 0.003 32 | lfenv = LineFollowingEnv2(absolute_observation=False) 33 | 34 | expenv.env = lfenv 35 | expenv.env.seed(gv.seed) 36 | # the error can not get bigger than a third of the width 37 | env_range = np.array([[-lfenv.track_width * 1 / 3.0, lfenv.track_width * 1 / 3.0]]) 38 | critic = DynamicBaseline(obsranges=env_range) 39 | expenv.agent = Agent(expenv.env, 40 | actor=SpikingNeurocontroller(placecell_range=None, num_neurons_per_dim=2, env=expenv.env, neuron_labels=["neg", "pos"]), 41 | critic=critic) 42 | expenv.agent.actor.only_positive_input = False 43 | expenv.agent.actor.obsfactor = np.array([400 for x in range(2)]) # clamped anyway 44 | 45 | 46 | if __name__ == "__main__": 47 | args = exp.parseargs() 48 | exp.trainingrun(configure_training, num_processes=args.processes, gridsearchpath=args.gridsearch) 49 | -------------------------------------------------------------------------------- /experiments/lf_simplified.py: -------------------------------------------------------------------------------- 1 | 2 | #this fixes exp not beeing able to import because it is not in the pythonpath 3 | import os,sys,inspect 4 | current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 5 | parent_dir = os.path.dirname(current_dir) 6 | sys.path.insert(0, parent_dir) 7 | 8 | import exp 9 | from agent import Agent 10 | from exp import Experiment, trainingrun 11 | from globalvalues import gv 12 | import numpy as np 13 | 14 | from critic import DynamicBaseline 15 | from lineFollowingEnvironment2 import LineFollowingEnv2 16 | from symbolicactor import SymbolicActor 17 | 18 | 19 | def configure_training(expenv: Experiment): 20 | """error encoded in place cells""" 21 | print("Preparing line following with symbolic computation experiment.") 22 | gv.manual_num_inhibitory = 0 23 | gv.structural_plasticity = False # might be switched to true later 24 | gv.random_reconnect = False 25 | gv.population_size = 1 26 | gv.errsig_factor = 0.001 # 0.03 unfiltered, 0.008 for utility 27 | 
gv.num_episodes = 1500 28 | gv.criticresolution = 40 29 | gv.dynamic_baseline = False 30 | #relative input with palce cells currently does not work 31 | relative_input = False #externalize computations of difference 32 | lfenv = LineFollowingEnv2(absolute_observation=not relative_input) 33 | if relative_input: 34 | placecell_range = np.array([[-lfenv.track_width * 1 / 3.0, lfenv.track_width * 1 / 3.0]]) 35 | state_labels = ["Error from optimum"] 36 | else: 37 | placecell_range = np.array([[-lfenv.track_width * 1 / 3.0, lfenv.track_width * 1 / 3.0], 38 | [-lfenv.track_width * 1 / 3.0, lfenv.track_width * 1 / 3.0]]) 39 | state_labels = ["Current Position", "Target Position"] 40 | num_neurons_per_dim = np.full(len(placecell_range), fill_value=2) 41 | 42 | expenv.env = lfenv 43 | expenv.env.seed(gv.seed) 44 | gv.w_max = 2 45 | 46 | expenv.env.seed(gv.seed) 47 | critic = DynamicBaseline(obsranges=placecell_range, state_labels=state_labels) 48 | expenv.agent = Agent(expenv.env, 49 | actor=SymbolicActor(placecell_range=placecell_range, num_neurons_per_dim=num_neurons_per_dim, 50 | env=expenv.env), 51 | critic=critic) 52 | expenv.agent.placell_nneurons_per_dim = num_neurons_per_dim 53 | expenv.penalty = 0. 54 | expenv.agent.actor.obsfactor = 400 # clamped if too big 55 | 56 | 57 | if __name__ == "__main__": 58 | args = exp.parseargs() 59 | single, res = exp.trainingrun(configure_training, num_processes=args.processes, gridsearchpath=args.gridsearch) 60 | if single is not None: 61 | single.agent.critic.draw(xaxis=1, yaxis=0) 62 | -------------------------------------------------------------------------------- /experiments/polebalancing.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import gym 4 | 5 | #this fixes exp not beeing able to import because it is not in the pythonpath 6 | import os,sys,inspect 7 | 8 | 9 | current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 10 | parent_dir = os.path.dirname(current_dir) 11 | sys.path.insert(0, parent_dir) 12 | 13 | import exp 14 | from SpikingNeuroController import SpikingNeurocontroller 15 | from agent import Agent 16 | from globalvalues import gv 17 | import numpy as np 18 | 19 | from critic import DynamicBaseline 20 | 21 | 22 | def configure_training(expenv: exp.Experiment): 23 | """error encoded in place cells""" 24 | print("Preparing polebalancing experiment.") 25 | gv.manual_num_inhibitory = 0 26 | gv.structural_plasticity = False # might be swiched to true later 27 | gv.random_reconnect = False 28 | gv.population_size = 1 29 | #the reward factor is closely connected to the output size of the ciritic 30 | # values are so small, so we scale it to boost learning 31 | gv.errsig_factor = 0.003 #0.03 unfiltered, 0.008 for utility 32 | gv.num_episodes = 1500 33 | env = gym.make('CartPole-v1') 34 | env.seed(gv.seed) 35 | gv.w0_min = 300. # Minimum initial random value 36 | gv.w0_max = 800. 
# Maximum initial random value 37 | 38 | placecell_range = [[]] * 4 39 | # for cartpole 40 | # actual ranges areonly have the size 41 | placecell_range[0] = [-2.4, 2.4] # Cart pos 42 | placecell_range[1] = [-3, 3] # Cart Velocity #usually upper limit is aroudn 1.9 43 | theta_threshold_radians = 12 * 2 * math.pi / 360 44 | placecell_range[2] = [-theta_threshold_radians, theta_threshold_radians] # is in rad 45 | placecell_range[3] = [-4, 4] # angular velocity 46 | 47 | placecell_range = np.array(placecell_range) 48 | 49 | if "num_cells" in gv.workerdata: 50 | num_neurons_per_dim = [np.array([3, 3, 5, 5]), 51 | np.array([5, 5, 7, 7]), 52 | np.array([7, 7, 15, 15])][int(gv.workerdata["num_cells"])] 53 | else: 54 | num_neurons_per_dim = np.array([5, 5, 7, 7]) 55 | neuron_labels = ["Cart Pos. +", "Cart Pos. -", "Cart Vel. +", "Cart Vel. -", "Pole Angle +", 56 | "Pole Angle -", "Pole Vel. +", 57 | "Pole Vel. -"] 58 | expenv.env = env 59 | expenv.env.seed(gv.seed) 60 | critic = DynamicBaseline(obsranges=placecell_range) 61 | expenv.agent = Agent(expenv.env, 62 | actor=SpikingNeurocontroller(placecell_range=placecell_range, num_neurons_per_dim=num_neurons_per_dim, env=expenv.env), 63 | critic=critic) 64 | expenv.agent.placell_nneurons_per_dim = num_neurons_per_dim 65 | expenv.penalty = -50. 66 | expenv.agent.actor.obsfactor = 400 # clamped if too big 67 | #expenv.agent.actor.connectome.load("5ebd31ed74bd5dfd5c40804e") 68 | 69 | if __name__ == "__main__": 70 | args = exp.parseargs() 71 | exp.trainingrun(configure_training, num_processes=args.processes, gridsearchpath=args.gridsearch) 72 | -------------------------------------------------------------------------------- /experiments/polebalancing_simplified.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import gym 4 | 5 | #this fixes exp not beeing able to import because it is not in the pythonpath 6 | import os,sys,inspect 7 | current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 8 | parent_dir = os.path.dirname(current_dir) 9 | sys.path.insert(0, parent_dir) 10 | 11 | import exp 12 | from agent import Agent 13 | from globalvalues import gv 14 | import numpy as np 15 | from symbolicactor import SymbolicActor 16 | from critic import DynamicBaseline 17 | 18 | 19 | def configure_training(expenv: exp.Experiment): 20 | """error encoded in place cells""" 21 | print("Preparing polebalancing with symbolic computation experiment.") 22 | gv.manual_num_inhibitory = 0 23 | gv.structural_plasticity = False # might be swiched to true later 24 | gv.random_reconnect = False 25 | gv.population_size = 1 26 | gv.num_episodes = 300 27 | gv.vq_learning_scale = 1e-3 28 | 29 | # the reward factor is closely connected to the output size of the ciritic 30 | # values are so small, so we scale it to boost learning 31 | gv.errsig_factor = 0.06 # 0.03 unfiltered, 0.008 for utility 32 | 33 | env = gym.make('CartPole-v1') 34 | env.seed(gv.seed) 35 | gv.w_max = 2 36 | 37 | placecell_range = [[]] * 4 38 | # for cartpole 39 | # actual ranges areonly have the size 40 | placecell_range[0] = [-2.4, 2.4] # Cart pos 41 | placecell_range[1] = [-3, 3] # Cart Velocity #usually upper limit is aroudn 1.9 42 | theta_threshold_radians = 12 * 2 * math.pi / 360 43 | placecell_range[2] = [-theta_threshold_radians, theta_threshold_radians] # is in rad 44 | placecell_range[3] = [-4, 4] # angular velocity 45 | 46 | placecell_range = np.array(placecell_range) 47 | 48 | if "num_cells" in gv.workerdata: 49 | 
num_neurons_per_dim = [np.array([3, 3, 5, 5]), 50 | np.array([5, 5, 7, 7]), 51 | np.array([7, 7, 15, 15])][int(gv.workerdata["num_cells"])] 52 | else: 53 | num_neurons_per_dim = np.array([5, 5, 7, 7]) 54 | neuron_labels = ["Cart Pos. +", "Cart Pos. -", "Cart Vel. +", "Cart Vel. -", "Pole Angle +", 55 | "Pole Angle -", "Pole Vel. +", 56 | "Pole Vel. -"] 57 | expenv.env = env 58 | expenv.env.seed(gv.seed) 59 | critic = DynamicBaseline(obsranges=placecell_range) 60 | expenv.agent = Agent(expenv.env, 61 | actor=SymbolicActor(placecell_range=placecell_range, num_neurons_per_dim=num_neurons_per_dim, 62 | env=expenv.env), 63 | critic=critic) 64 | expenv.agent.placell_nneurons_per_dim = num_neurons_per_dim 65 | expenv.penalty = -50. 66 | expenv.agent.actor.obsfactor = 400 # clamped if too big 67 | expenv.agent.actor.placecell_range = placecell_range # in line environment.width 68 | 69 | 70 | if __name__ == "__main__": 71 | args = exp.parseargs() 72 | single, res = exp.trainingrun(configure_training, num_processes=args.processes, gridsearchpath=args.gridsearch) 73 | if single is not None: 74 | single.agent.critic.draw(xaxis=2, yaxis=3) 75 | -------------------------------------------------------------------------------- /fremauxfilter.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | from globalvalues import gv 4 | 5 | 6 | def kernel(t: np.ndarray) -> np.ndarray: 7 | """ 8 | kernel as defined in fremaux2013 9 | :param t: first axis: neuron number, second: list of spike times 10 | :return: 11 | """ 12 | tau = 40 13 | ypsilon = 10 14 | for i, spiketime in enumerate(t): 15 | #invalidmask = (0 > spiketime) | (spiketime > gv.cycle_length) 16 | #clip negative spike times 17 | spiketime = np.clip(spiketime, a_min=0, a_max=None) 18 | #spiketime[invalidmask] = 0 19 | t[i] = (np.exp(-spiketime / tau) - np.exp(-spiketime / ypsilon)) / (tau - ypsilon) 20 | return t 21 | 22 | 23 | def filter(time, spikesignals: List[List[float]]) -> np.ndarray: 24 | """Converts vector of list of spike times to vector of list of float at time t.""" 25 | spikesignals_np = np.array([np.array(xi) for xi in spikesignals], dtype=object) 26 | #per neuron thre can only be one spike at each timestep therefore use 1 for f(tau) 27 | integrands = kernel(time - spikesignals_np) 28 | # calc. discrete integral by summing the integrands for each neuron 29 | ouputs: np.ndarray = np.zeros(len(spikesignals)) 30 | for nidx, integrand in enumerate(integrands): 31 | ouputs[nidx] = np.sum(integrand) 32 | return ouputs 33 | 34 | #todo allow filter for every t, e.g. when t is None 35 | -------------------------------------------------------------------------------- /globalvalues.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import typing 3 | 4 | 5 | class Settings: 6 | workerdata = None # per training instance 7 | demo = False # when set to true will only display and not learn or train. 
there are more settings to disable training
8 | neuronLabelMap: typing.Dict[int, str] = {}  # map neuron ids to strings
9 | seed = 42
10 | cycle_length = 40.0  # ms
11 | voltageRecordings = 0  # a set of neuron ids whose voltages are recorded
12 | delay = 1.0  # delay between EPSP/IPSP and action potential
13 | filter_outsignals = False  # filter signal as in Fremaux 2013
14 | neuronmodel = "iaf_psc_alpha"  # Wunderlich et al. use iaf_psc_exp
15 | num_hidden_neurons = 0  # only free neurons
16 | num_output_populations = 2  # number of output populations
17 | population_size = 5
18 | lateral_inhibition = 300  # weight of the lateral inhibition; 0 disables it
19 | criticresolution = 300
20 | dynamic_baseline = True
21 | 
22 | fraction_hidden_inhibitory = 0.2  # can be overwritten by using a manual wiring
23 | manualwiring: typing.List[typing.Tuple[int, int, bool]] = None  # which population is connected to which; the last element marks the connection as excitatory or inhibitory
24 | 
25 | manual_num_inhibitory = 0  # overwrites automatic creation of inhibitory neurons
26 | errsig_factor = 0.052  # https://github.com/clamesc/Training-Neural-Networks-for-Event-Based-End-to-End-Robot-Control/blob/56dc686cbc660e8c462c2c2b6f1310f83ee70ea9/Controller/R-STDP/parameters.py#L29
27 | factor_negative_util = 1  # factor for negative utilities, Dabney2020
28 | tau_n = 200.  # Time constant of reward signal
29 | tau_c = 1000.  # Time constant of eligibility trace
30 | w0_min = 100.  # Minimum initial random value
31 | w0_max = 500.  # Maximum initial random value
32 | w_max = 3000.0  # https://github.com/clamesc/Training-Neural-Networks-for-Event-Based-End-to-End-Robot-Control/blob/56dc686cbc660e8c462c2c2b6f1310f83ee70ea9/Controller/R-STDP/parameters.py#L24
33 | 
34 | max_poisson_freq = 500.
35 | max_util_integral = float("inf")  # 0.7
36 | util_discount_factor = 0.8
37 | util_learn_rate = 1.0  # usually 0.9; set to 1 to disable TD learning
38 | 
39 | vq_learning_scale = 1e-3  # learn rate of vector quantization (scaling), 0 disables it
40 | vq_decay = 1e-4  # decay speed
41 | 
42 | num_episodes = 6000
43 | max_cycles = 600  # ~300 cycles per episode in lf2
44 | num_plots = 0  # number of spiking plots
45 | render = False  # render the environment
46 | headless = False  # prevents every rendering, even in pyplot; prevents a crash when importing rendering from gym.envs.classic_control
47 | save_to_db = False
48 | allow_early_termination = False
49 | 
50 | structural_plasticity = True  # allow removal or adding of synapses
51 | random_reconnect = False  # randomly adds new synapses
52 | strp_min = w0_min / 4  # weight below which a synapse is removed
53 | @staticmethod
54 | def define_stdp(vt):
55 | try:
56 | import nest
57 | # the time constant of the depressing window of STDP is a parameter of the post-synaptic neuron. 
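# schematic form of the R-STDP dynamics that tau_c, tau_n and the error signal enter
# (a sketch; see the NEST stdp_dopamine_synapse documentation for the exact equations):
#   dc/dt = -c / tau_c + STDP(pre, post)    eligibility trace
#   dn/dt = -n / tau_n + dopamine events    neuromodulator concentration
#   dw/dt = c * n                           weight change gated by the reward signal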
58 | rstdp_syn_spec_exitory = {'Wmax': gv.w_max, 59 | 'Wmin': 0.0, 60 | 'delay': gv.delay, 61 | "weight": gv.w0_min, # will be overwritten with randomized value later 62 | "A_plus": 1.0, 63 | "A_minus": 1.0, 64 | "tau_n": gv.tau_n, 65 | "tau_c": gv.tau_c, 66 | 'vt': vt.tolist()[0]} 67 | 68 | rstdp_syn_spec_inhibitory = rstdp_syn_spec_exitory.copy() 69 | rstdp_syn_spec_inhibitory["Wmin"] = -gv.w_max 70 | rstdp_syn_spec_inhibitory['Wmax'] = 0.0 71 | 72 | # alternative to set defaults is to create a new model from a copy 73 | nest.CopyModel('stdp_dopamine_synapse', 'stdp_dopamine_synapse_ex', rstdp_syn_spec_exitory) 74 | # nest.SetDefaults("stdp_dopamine_synapse_ex", rstdp_syn_spec_exitory) 75 | nest.CopyModel('stdp_dopamine_synapse', 'stdp_dopamine_synapse_in', rstdp_syn_spec_inhibitory) 76 | # nest.SetDefaults("stdp_dopamine_synapse_in", rstdp_syn_spec_inhibitory) 77 | except ImportError: 78 | print("Neural simulator Nest not found (import nest). Only able to run the simplified architecture.") 79 | 80 | @staticmethod 81 | def init(): 82 | numproc = 1 83 | try: 84 | import nest 85 | nest.ResetKernel() 86 | # run for every process 87 | # same results every run 88 | nest.SetKernelStatus({"grng_seed": gv.seed}) 89 | # numpy seed 90 | nest.EnableStructuralPlasticity() 91 | numproc = nest.GetKernelStatus(['total_num_virtual_procs'])[0] 92 | except ImportError: 93 | print("Neural simulator Nest not found (import nest). Only able to run the simplified architecture.") 94 | 95 | np.random.seed(gv.seed) 96 | gv.pyrngs = [np.random.RandomState(s) for s in range(gv.seed, gv.seed + numproc)] 97 | 98 | 99 | gv = Settings() 100 | -------------------------------------------------------------------------------- /gridsearch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | rsync -rtuv /Users/benediktvogler/Studium/Masterarbeit/SNNexperiments/* ubuntu:~/Dokumente/SNNexperiments/ 4 | ssh ubuntu << EOF 5 | cd ~/Dokumente/SNNexperiments/ 6 | source /home/bsvogler/Dokumente/nest/bin/nest_vars.sh 7 | screen -S "Gridsearch" -d -m 8 | screen -r "Gridsearch" -X stuff $'python3 ./experiments/polebalancing.py --processes=6 -g gridsearch.json --headless;rsync -rtuv ./experimentdata benediktvogler@192.168.2.3:~/Studium/Masterarbeit/SNNexperiments/;\n' 9 | EOF -------------------------------------------------------------------------------- /gridsearch/gridsearch.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "errsig_factor": { 4 | "from": 0.001, 5 | "to": 0.02, 6 | "steps": 5 7 | }, 8 | "criticresolution": { 9 | "from": 50, 10 | "to": 400, 11 | "steps": 3 12 | }, 13 | "lateral_inhibition": { 14 | "from": 100, 15 | "to": 400, 16 | "steps": 3 17 | } 18 | } -------------------------------------------------------------------------------- /gridsearch/vq.json: -------------------------------------------------------------------------------- 1 | { 2 | "vq_lr_int": { 3 | "from": 1, 4 | "to": 3, 5 | "steps": 3 6 | }, 7 | "vq_decay_int": { 8 | "from": 1, 9 | "to": 3, 10 | "steps": 3 11 | } 12 | } -------------------------------------------------------------------------------- /lineFollowingEnvironment.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any 2 | 3 | import gym 4 | from gym import spaces, logger 5 | from gym.utils import seeding 6 | import numpy as np 7 | 8 | from globalvalues import gv 9 | 10 | 11 | class LineFollowingEnv(gym.Env): 12 | 
""" 13 | Description: 14 | A one dimensional lien following task. 15 | Source: 16 | This environment corresponds to the version of the cart-pole problem described by Barto, Sutton, and Anderson 17 | Observation: 18 | Type: Box(2) 19 | Num Observation Min Max 20 | 0 Current Position 0 width 21 | 1 Next step line position 0 width 22 | 23 | Actions: 24 | Type: Discrete(2) 25 | Num Action 26 | 0 Push cart to the left 27 | 1 Push cart to the right 28 | 29 | Note: The amount the velocity that is reduced or increased is not fixed; it depends on the angle the pole is pointing. This is because the center of gravity of the pole increases the amount of energy needed to move the cart underneath it 30 | Reward: 31 | Reward is 1 for every step taken, including the termination step 32 | Starting State: 33 | All observations are assigned a uniform random value in [-0.05..0.05] 34 | Episode Termination: 35 | Pole Angle is more than 12 degrees 36 | Cart Position is more than 2.4 (center of the cart reaches the edge of the display) 37 | Episode length is greater than 200 38 | Solved Requirements 39 | Considered solved when the average reward is greater than or equal to 195.0 over 100 consecutive trials. 40 | """ 41 | 42 | metadata = { 43 | 'render.modes': ['human', 'rgb_array'], 44 | 'video.frames_per_second': 50 45 | } 46 | 47 | def __init__(self, absolute_observation=False): 48 | self.absolute_observation = absolute_observation 49 | self.tracklength = 100 50 | self.width = 3.0 51 | self.track = np.sin(np.linspace(0, 2 * np.pi, self.tracklength)) + self.width / 2 52 | self.trackPiece: List[Any] = [None] * self.tracklength 53 | self.bound = [] 54 | self.botpos = [0,0] 55 | 56 | self.action_space = spaces.Box(np.array([0.0]), np.array([1])) 57 | self.observation_space = spaces.Box(np.array([0.0, 0.0]), np.array([self.width, self.width]), dtype=np.float32) 58 | 59 | self.seed() 60 | self.viewer = None 61 | self.state = None 62 | self.previouspath = [] 63 | self.steps_beyond_done = None 64 | self.reset() 65 | 66 | def update_state(self): 67 | self.state = (self.botpos[0], self.track[(self.botpos[1] + 1) % self.tracklength]) 68 | return self.state 69 | 70 | def seed(self, seed=None): 71 | self.np_random, seed = seeding.np_random(seed) 72 | return [seed] 73 | 74 | def step(self, action): 75 | assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action)) 76 | self.previouspath[-1].append(self.botpos[0]) 77 | self.botpos[0] += action[0] - 0.5 78 | self.update_state() 79 | reward = self.width / 4 - abs(self.botpos[0] - self.track[self.botpos[1]]) # surviving is positive 80 | self.botpos[1] += 1 81 | self.botpos[1] %= self.tracklength 82 | done = reward <= 0 83 | 84 | if done: 85 | #clamp so that every failing state is equal regarding reward 86 | reward = 0 87 | if self.steps_beyond_done == 0: 88 | logger.warn( 89 | "You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.") 90 | if self.steps_beyond_done is not None: 91 | self.steps_beyond_done += 1 92 | if self.steps_beyond_done is None: 93 | self.steps_beyond_done = 0 94 | 95 | # relative 96 | if self.absolute_observation: 97 | return np.array(self.state), reward, done, {} 98 | else: 99 | return np.array([self.state[0] - self.state[1]]), reward, done, {} 100 | 101 | def reset(self): 102 | """Reset the enviroment. 
Returns initial reward""" 103 | # self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,)) 104 | self.previouspath.append([]) 105 | self.steps_beyond_done = None 106 | self.botpos = [self.width * gv.pyrngs[0].uniform(0.7, 0.74), 0] # x float, y index # middle is self.width / 2 107 | self.update_state() 108 | initalreward = self.width / 4 - abs(self.botpos[0] - self.track[self.botpos[1]]) # surviving is positive 109 | if self.absolute_observation: 110 | return np.array(self.state), initalreward 111 | else: 112 | return np.array([self.state[0] - self.state[1]]), initalreward 113 | 114 | def render(self, mode='human'): 115 | screen_width = 600 116 | screen_height = 400 117 | stepHeight = screen_height / float(self.tracklength) 118 | unitLength = screen_width / self.width / 2 119 | offsety = 10 120 | from gym.envs.classic_control import rendering 121 | if self.viewer is None: 122 | self.viewer = rendering.Viewer(screen_width, screen_height) 123 | cartwidth = 4.0 124 | cartheight = 4.0 125 | l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2 126 | cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 127 | self.carttrans = rendering.Transform() 128 | cart.add_attr(self.carttrans) 129 | self.viewer.add_geom(cart) 130 | 131 | # track 132 | self.trackTrans = rendering.Transform() 133 | for i in range(0, self.tracklength): 134 | self.trackPiece[i] = rendering.Line((self.track[i] * unitLength, i * stepHeight), 135 | (self.track[(i + 1) % self.tracklength] * unitLength, 136 | (i + 1) * stepHeight)) 137 | # self.trackPiece[i] = rendering.Line((0, 30), (10, 40)) 138 | self.trackPiece[i].add_attr(self.trackTrans) 139 | self.trackPiece[i].set_color(.5, .5, .8) 140 | self.viewer.add_geom(self.trackPiece[i]) 141 | # bound left 142 | for i in range(0, self.tracklength): 143 | newPiece = rendering.Line(((self.track[i] + self.width / 2) * unitLength, i * stepHeight), 144 | ((self.track[(i + 1) % self.tracklength] + self.width / 4) * unitLength, 145 | (i + 1) * stepHeight)) 146 | newPiece.add_attr(self.trackTrans) 147 | newPiece.set_color(.5, .0, .0) 148 | self.bound.append(newPiece) 149 | self.viewer.add_geom(newPiece) 150 | # bound right 151 | for i in range(0, self.tracklength): 152 | newPiece = rendering.Line(((self.track[i] - self.width / 2) * unitLength, i * stepHeight), 153 | ((self.track[(i + 1) % self.tracklength] - self.width / 4) * unitLength, 154 | (i + 1) * stepHeight)) 155 | newPiece.add_attr(self.trackTrans) 156 | newPiece.set_color(.5, .0, .0) 157 | self.bound.append(newPiece) 158 | self.viewer.add_geom(newPiece) 159 | 160 | if self.state is None: return None 161 | 162 | for i, trial in enumerate(self.previouspath[-20:]): 163 | grayvalue = 1 - float(i) / len(self.previouspath[-20:]) 164 | color = (grayvalue, grayvalue, grayvalue) 165 | for i in range(1, len(trial)): 166 | self.viewer.draw_line((trial[i - 1] * unitLength, (i - 1) * stepHeight + offsety), 167 | (trial[i] * unitLength, i * stepHeight + offsety), 168 | color=color) 169 | 170 | cartx = self.botpos[0] * unitLength # MIDDLE OF CART 171 | self.carttrans.set_translation(cartx, self.botpos[1] * stepHeight + offsety) 172 | self.trackTrans.set_translation(0, 0 + offsety) 173 | # self.poletrans.set_rotation(-x[2]) 174 | 175 | return self.viewer.render(return_rgb_array=mode == 'rgb_array') 176 | 177 | def close(self): 178 | if self.viewer: 179 | self.viewer.close() 180 | self.viewer = None 181 | -------------------------------------------------------------------------------- 
/lineFollowingEnvironment2.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Any
3 | 
4 | import gym
5 | from gym import spaces, logger
6 | from gym.utils import seeding
7 | import numpy as np
8 | 
9 | from globalvalues import gv
10 | 
11 | 
12 | class LineFollowingEnv2(gym.Env):
13 |     """
14 |     Description:
15 |         A one-dimensional line following task on a randomly generated,
16 |         piecewise-constant track. The bot advances one step along the track
17 |         per cycle and must steer sideways to stay on the line.
18 |     Observation:
19 |         Type: Box(2)
20 |         Num   Observation               Min   Max
21 |         0     Current position          0     track_width
22 |         1     Next step line position   0     track_width
23 | 
24 |         If absolute_observation is False, a single relative observation
25 |         (current position - next line position) is returned instead.
26 |     Actions:
27 |         Type: Box(1)
28 |         A continuous steering value in [0, 1]: 0.5 keeps the current lateral
29 |         position, smaller values steer left, larger values steer right.
30 | 
31 |     Reward:
32 |         track_width minus the absolute distance to the line, so staying
33 |         close to the line yields a positive reward.
34 |     Starting State:
35 |         The track is regenerated from random segments; the lateral start
36 |         position is drawn uniformly from [-track_width / 4, track_width / 4].
37 |     Episode Termination:
38 |         The reward becomes non-positive, or the bot leaves the visible area
39 |         (|lateral position| >= track_width / 2).
40 | 
41 |     """
42 | 
43 |     metadata = {
44 |         'render.modes': ['human', 'rgb_array'],
45 |         'video.frames_per_second': int(1000 / gv.cycle_length)
46 |     }
47 | 
48 |     segments = 12
49 |     def __init__(self, absolute_observation=False):
50 |         self.absolute_observation = absolute_observation
51 |         self.tracklength = 600
52 |         self.track_width = 3.0
53 | 
54 |         self.track: List[float] = [0.0] * self.tracklength
55 |         self.trackPiece: List[Any] = [None] * self.tracklength
56 |         self.bound = []
57 |         self.botpos = [0, 0.0]  # x index, y float
58 | 
59 |         self.action_space = spaces.Box(np.array([0.0]), np.array([1.0]))
60 |         self.observation_space = spaces.Box(np.array([0.0, 0.0]), np.array([self.track_width, self.track_width]), dtype=np.float32)
61 | 
62 |         self.seed()
63 |         self.viewer = None
64 |         self.state = None
65 |         self.previouspath = []
66 |         self.steps_beyond_done = None
67 |         self.grid_from = math.floor(-self.track_width / 2)
68 |         self.grid_to = math.ceil(self.track_width / 2)
69 |         self.grid_res = 0
70 |         self.grid = []
71 |         self.reset()
72 | 
73 |     def enable_grid(self):
74 |         self.grid_res = math.ceil(abs(self.grid_from - self.grid_to))  # count of grid lines
75 | 
76 |     def update_state(self):
77 |         # current position, next track position
78 |         self.state = (self.botpos[1], self.track[(self.botpos[0] + 1) % self.tracklength])
79 |         return self.state
80 | 
81 |     def seed(self, seed=None):
82 |         self.np_random, seed = seeding.np_random(seed)
83 |         return [seed]
84 | 
85 |     def step(self, action):
86 |         assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action))
87 |         self.previouspath[-1].append(self.botpos[1])
88 |         self.botpos[1] += action[0] - 0.5
89 |         self.update_state()
90 |         reward = self.track_width - abs(self.botpos[1] - self.track[self.botpos[0]])  # surviving is positive
91 |         self.botpos[0] += 1
92 |         self.botpos[0] %= self.tracklength
93 |         done = reward <= 0
94 | 
95 |         # stop if not visible any more
96 |         if abs(self.botpos[1]) >= self.track_width / 2:
97 |             done = True
98 | 
99 |         if done:
100 |             # clamp so 
that all failing states yield the same reward
101 |             reward = 0
102 |             if self.steps_beyond_done == 0:
103 |                 logger.warn(
104 |                     "You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.")
105 |             if self.steps_beyond_done is not None:
106 |                 self.steps_beyond_done += 1
107 |             if self.steps_beyond_done is None:
108 |                 self.steps_beyond_done = 0
109 | 
110 |         # relative observation if not absolute
111 |         if self.absolute_observation:
112 |             return np.array(self.state), reward, done, {}
113 |         else:
114 |             return np.array([self.state[0] - self.state[1]]), reward, done, {}
115 | 
116 |     def reset(self):
117 |         """Reset the environment and regenerate the track.
118 |         Returns the initial observation (unlike LineFollowingEnv, no reward)."""
119 |         # create a new track from piecewise-constant random segments
120 |         # with offsets centered at 0, within +-track_width / 3
121 |         seglength = int(self.tracklength // LineFollowingEnv2.segments)
122 |         randomvar = self.track_width * 2 / 3.0
123 |         for seg in range(LineFollowingEnv2.segments):
124 |             offset = np.random.random_sample() * randomvar - randomvar / 2  # centered at 0
125 | 
126 |             for i in range(seg * seglength, (seg + 1) * seglength):
127 |                 self.track[i] = offset
128 | 
129 |         self.previouspath.append([])
130 |         self.steps_beyond_done = None
131 |         self.botpos = [0, gv.pyrngs[0].uniform(0, self.track_width / 2) - self.track_width / 4]  # x index, y float; centered around the line
132 |         self.update_state()
133 |         if self.absolute_observation:
134 |             return np.array(self.state)
135 |         else:
136 |             return np.array([self.state[0] - self.state[1]])
137 | 
138 |     def render(self, mode='human'):
139 |         screen_width = 600
140 |         screen_height = 400
141 |         unit_height = screen_height / self.track_width
142 |         unit_length = screen_width / float(self.tracklength)
143 |         centerY = screen_height / 2
144 |         from gym.envs.classic_control import rendering
145 |         if self.viewer is None:
146 |             self.viewer = rendering.Viewer(screen_width, screen_height)
147 | 
148 |             # grid
149 |             if self.grid_res > 0:
150 |                 gridstep = (self.grid_to - self.grid_from) / self.grid_res * unit_height
151 |                 self.grid.append(rendering.Line((0, centerY),
152 |                                                 (screen_width, centerY)))
153 |                 self.grid[0].set_color(.0, .9, .0)
154 |                 self.viewer.add_geom(self.grid[0])
155 |                 for y in range(self.grid_res):
156 |                     ypos = y * gridstep + self.grid_from * unit_height
157 |                     self.grid.append(rendering.Line((0, ypos + centerY),
158 |                                                     (screen_width, ypos + centerY)))
159 |                     self.grid[y + 1].set_color(.8, .2, .2)
160 |                     self.viewer.add_geom(self.grid[y + 1])
161 | 
162 |             # add cart
163 |             cartwidth = 4.0
164 |             cartheight = 4.0
165 |             l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
166 |             cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
167 |             self.carttrans = rendering.Transform()
168 |             cart.add_attr(self.carttrans)
169 |             self.viewer.add_geom(cart)
170 | 
171 |         if self.state is None:
172 |             return None
173 | 
174 |         # track
175 |         color = (.5, .5, .8)
176 |         for i in range(0, self.tracklength):
177 |             self.viewer.draw_line((i * unit_length,
178 |                                    self.track[i] * unit_height + centerY),
179 |                                   ((i + 1) * unit_length,
180 |                                    self.track[(i + 1) % self.tracklength] * unit_height + centerY),
181 |                                   color=color)
182 | 
183 |         # trail of the last 20 episodes, fading from white (old) to black (recent)
184 |         for i, trial in enumerate(self.previouspath[-20:]):
185 |             grayvalue = 1 - float(i) / len(self.previouspath[-20:])
186 |             color = (grayvalue, grayvalue, grayvalue)
187 |             for j in range(1, len(trial)):
188 |                 self.viewer.draw_line(((j - 1) * unit_length, trial[j - 1] * unit_height + centerY),
189 |                                       (j * unit_length, trial[j] * unit_height + centerY),
190 |                                       color=color)
191 | 
192 |         cartx = self.botpos[0] * unit_length  # middle of the cart
193 |         self.carttrans.set_translation(cartx, self.botpos[1] * unit_height + centerY)
194 | 
195 |         return self.viewer.render(return_rgb_array=mode == 'rgb_array')
196 | 
197 |     def close(self):
198 |         if self.viewer:
199 |             self.viewer.close()
200 |             self.viewer = None
201 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BSVogler/SNN-RL/67df1115fd604706ab08200ba17eab11d5f844e2/models/__init__.py
--------------------------------------------------------------------------------
/models/trainingrun.py:
--------------------------------------------------------------------------------
1 | from mongoengine import Document, DateTimeField, ListField, ReferenceField, DictField, FloatField, IntField, \
2 |     BinaryField, StringField
3 | 
4 | 
5 | class Episode(Document):
6 |     """Everything logged per cycle."""
7 |     episode = IntField()
8 |     rewards = ListField(FloatField())
9 |     activation = ListField(FloatField())  # stores the average activation of the input layer
10 |     weights = BinaryField()
11 |     weights_human = ListField(ListField(FloatField()))
12 |     neuromodulator = ListField(FloatField())
13 | 
14 | class Trainingrun(Document):
15 |     time_start = DateTimeField()
16 |     time_end = DateTimeField()
17 |     time_elapsed = FloatField()  # in s
18 |     # from exp import Experiment
19 |     # a ReferenceField to Experiment would cause a circular import, so ids are stored as strings
20 |     instances = ListField(StringField())
21 |     gridsearch = DictField()
22 | 
--------------------------------------------------------------------------------
/nestdockerrun.sh:
--------------------------------------------------------------------------------
1 | docker run --rm -it -e LOCAL_USER_ID=`id -u $USER` -v $(pwd):/opt/data -p 8080:8080 nestsim/nest:latest /bin/bash
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gym>=0.17.1
2 | numpy==1.18
3 | matplotlib==3.2.2
4 | networkx
5 | scikit-learn==0.22
6 | mongoengine>=0.20.0
--------------------------------------------------------------------------------
/rstdp_imp_test.py:
--------------------------------------------------------------------------------
1 | import nest
2 | import nest.raster_plot
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import draw
6 | from globalvalues import gv
7 | 
8 | gv.configure_nest()
9 | nest.ResetKernel()
10 | neuron_model = "iaf_psc_alpha"
11 | inputs = nest.Create("poisson_generator", 1)
12 | # introduce a parrot neuron to work around the limitation that devices cannot have STDP synapses
13 | parrots = nest.Create("parrot_neuron", 1)
14 | nest.Connect(inputs, parrots)
15 | spike_detector = nest.Create('spike_detector')
16 | nest.Connect(parrots, spike_detector, 'all_to_all')
17 | out = nest.Create(neuron_model, 1)
18 | 
19 | vt = nest.Create('volume_transmitter')
20 | gv.define_stdp(vt)
21 | nest.Connect(parrots, out, syn_spec={'model': 'stdp_dopamine_synapse_ex'})
22 | 
23 | nest.SetStatus(inputs, {"start": 0.,
24 |                         "rate": 1000.,
25 |                         "stop": gv.cycle_length})
26 | nest.Simulate(gv.cycle_length)
27 | # now stop spiking and give a reward by raising the dopamine concentration "n"
28 | conn = nest.GetConnections(source=parrots, target=out)
29 | weights = []
30 | weights.append(np.array(nest.GetStatus(conn, keys="weight")))
31 | nest.SetStatus(conn, {"n": 100.})
32 | weights.append(np.array(nest.GetStatus(conn, keys="weight")))
33 | # sample the weights after each further millisecond of simulation
34 | nest.Simulate(1.)
35 | weights.append(np.array(nest.GetStatus(conn, keys="weight")))
36 | nest.Simulate(1.)
37 | weights.append(np.array(nest.GetStatus(conn, keys="weight")))
38 | nest.Simulate(1.)
39 | weights.append(np.array(nest.GetStatus(conn, keys="weight")))
40 | 
41 | spikes = nest.GetStatus(spike_detector)[0]["events"]
42 | draw.spikes(spikes)
43 | 
44 | # nest.raster_plot.from_device(spike_detector, hist=True, hist_binwidth=40.,
45 | #                              title='Repeated stimulation by Poisson generator')
46 | # nest.raster_plot.show()
47 | 
48 | plt.plot(weights)
49 | plt.show()
--------------------------------------------------------------------------------
/symbolicactor.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Union
2 | 
3 | from actor import Actor, PlaceCellAnalog2ActivationLayer
4 | from globalvalues import gv
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | 
8 | from lineFollowingEnvironment import LineFollowingEnv
9 | from lineFollowingEnvironment2 import LineFollowingEnv2
10 | 
11 | 
12 | class SymbolicActor(Actor):
13 |     """Part of the agent that manages behaviour.
Includes the analog-2-spike and spike-2-analog conversion."""
14 | 
15 |     def __init__(self, placecell_range: np.ndarray, num_neurons_per_dim: np.ndarray, env,
16 |                  neuron_labels: List[str] = None):
17 |         """
18 |         :param placecell_range: per dimension the covered observation range (axis 0: dimension, axis 1: [from, to])
19 |         :param num_neurons_per_dim: number of place fields per dimension
20 |         :param env: the environment the actor acts in
21 |         :param neuron_labels: optional labels for the place cells
22 |         """
23 |         super().__init__()
24 |         self.env = env
25 |         self.obsfactor = 1  # factor to scale observations, can be a numpy array with one entry per place cell
26 |         # place cells
27 |         num_place_cells: int = np.multiply.reduce(num_neurons_per_dim)
28 |         self.placell_nneurons_per_dim = num_neurons_per_dim
29 |         self.placecell_range: np.ndarray = placecell_range  # axis 0: dimension, axis 1: [from, to]
30 |         if placecell_range is not None:
31 |             labels = []
32 |             for i in range(num_place_cells):
33 |                 pc_center = placecell_range[0, 0] + i * abs((placecell_range[0, 1] - placecell_range[0, 0])) / (
34 |                         num_place_cells - 1)
35 |                 labels.append(f"pc_{pc_center}")
36 |             self.placeCellA2SLayer: PlaceCellAnalog2ActivationLayer = PlaceCellAnalog2ActivationLayer(
37 |                 placecellrange=self.placecell_range,
38 |                 num_cells_per_dim=self.placell_nneurons_per_dim)
39 | 
40 |         self.lastpc: int  # index of the last active place cell
41 |         self.placecellaction = np.random.uniform(low=-0.5, high=.5,
42 |                                                  size=self.placeCellA2SLayer.positions.shape[0])  # default action is random
43 |         # log
44 |         self.lastactivation: np.array = np.array([])  # array, one hot encoded
45 |         self.weightlog: List[Actor.Weightstorage] = []
46 |         self.end_episode(-1)
47 |         self.positive_input = False  # if input is only positive
48 | 
49 |     def read_output(self, time: float) -> List[float]:
50 |         output = self.placecellaction * self.lastactivation
51 |         return output
52 | 
53 |     def cycle(self, time: float, observation_in: np.ndarray) -> Tuple[Union[int, List[float]], List]:
54 |         """Has side effects, call only once per cycle."""
55 |         # activation of each place cell for the current observation
56 |         activations = self.placeCellA2SLayer.activation(observation_in)
57 |         # lateral inhibition causes a one-hot encoding of the winning place cell
58 |         self.lastactivation = np.zeros_like(activations)
59 |         self.lastpc = int(np.argmax(activations))
60 |         self.lastactivation[self.lastpc] = 1
61 | 
62 |         # vector quantization
63 |         # plt.scatter(self.placecellpos[:, 0], self.placecellpos[:, 1], c=self.placecellpos[:, 3])
64 |         # plt.show()
65 | 
66 |         # calculate output layer
67 |         output = self.read_output(time)
68 |         if isinstance(self.env, (LineFollowingEnv, LineFollowingEnv2)):
69 |             action = [np.clip(0.5 - np.sum(output), 0., 1.)]
70 |         else:
71 |             # rate should only be a scalar value
72 |             action: int = int(np.sign(np.sum(output)) == 1)  # 0 or 1
73 |         return action, None
74 | 
75 |     def release_neurotransmitter(self, amount: float):
76 |         """Update the action weight of the last active place cell."""
77 |         super().release_neurotransmitter(amount)
78 | 
79 |         def g(weight: float) -> float:
80 |             # eligibility trace after Foderaro et al.
2010; proved not very beneficial here
81 |             return 0.2 + np.abs(weight * 5 * np.exp(-np.abs(weight) / gv.w_max))
82 | 
83 |         # only the weight of the one-hot active place cell changes
84 |         self.placecellaction += np.sign(self.placecellaction * self.lastactivation) * amount  # * g(self.placecellaction) proved not very beneficial
85 |         self.placecellaction = np.clip(self.placecellaction, -gv.w_max, gv.w_max)
86 | 
87 |     def end_cycle(self, cycle_num):
88 |         pass
89 | 
90 |     def end_episode(self, episode):
91 |         """
92 |         Store the current weights for loading in the next episode and for logging.
93 |         """
94 |         self.weightlog.append(self.placecellaction.copy())
95 | 
--------------------------------------------------------------------------------
/utilityTest.py:
--------------------------------------------------------------------------------
1 | import math
2 | import unittest
3 | import numpy as np
4 | 
5 | from critic import Utility
6 | 
7 | 
8 | class MyTestCase(unittest.TestCase):
9 | 
10 |     def test_tick(self):
11 |         # test two-dimensional states
12 |         util = Utility(obsranges=np.array([[0.0, 0.0], [3.0, 3.0]]))
13 |         errsig, utilValue = util.tick((0.0, 3.0), (1.0, 1.0))
14 |         self.assertAlmostEqual(utilValue, 1.0)
15 |         errsig, a = util.tick((0.0, 2.0), (0.0, 0.0))
16 |         errsig, b = util.tick((0.0, 2.0), (3.0, 3.0))
17 |         self.assertAlmostEqual(a, 0)
18 |         self.assertAlmostEqual(b, 3)  # should return the last value
19 | 
20 |     def test_drawrewards(self):
21 |         util = Utility(obsranges=np.array([[0.0, 3.4], [0.0, 5.0], [2., 3.]]))
22 |         util.tick((0.0, 3.0, 2.), (1.0, 4.0))
23 |         util.tick((2.0, 0.3, 3.), (2.0, 3.0))
24 |         util.tick((1.0, 2.0, 2.), (1.0, 2.))
25 |         util.tick((-4.0, 2.0, 2.), (3.4, 1.0))
26 |         util.tick((1.3, 1.2, 2.), (-2.0, 2.))
27 |         util.tick((-4.0, 2.0, 2.), (3.4, 1.0))
28 |         util.tick((3.32, 1.2, 2.), (-2.0, 2.))
29 |         util.tick((-4.0, 1.52, 3.), (2.24, 2.))
30 |         util.tick((2.34, 1.4, 2.), (-1.4, 2.))
31 |         util.end_episode()
32 |         util.draw_rewards(xaxis=0, yaxis=1)
33 | 
34 |     def test_draw(self):
35 |         util = Utility(obsranges=np.array([[0.0, 2.0], [0.0, 1.0]]))
36 |         for x in range(0, 40):
37 |             util.tick((x / 30.0, math.sin(x / 20) * 1.0), (2 * x))
38 |         util.end_episode()
39 |         util.draw(xaxis=0, yaxis=1)
40 | 
41 | 
42 | if __name__ == '__main__':
43 |     unittest.main()
44 | 
--------------------------------------------------------------------------------
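To make the learning rule in `symbolicactor.py` concrete, here is a standalone numeric sketch (hypothetical values, not repository code) of what `release_neurotransmitter` does: lateral inhibition leaves a one-hot activation vector, so only the winning place cell's action weight is pushed further along its current sign, scaled by the released neuromodulator amount, and all weights are then clipped to `[-w_max, w_max]`.

```python
# Standalone sketch of the update in SymbolicActor.release_neurotransmitter;
# w_max stands in for gv.w_max.
import numpy as np

w_max = 1.0
placecellaction = np.array([0.3, -0.2, 0.05])  # per-place-cell action weights
lastactivation = np.array([0.0, 1.0, 0.0])     # one-hot winner after lateral inhibition

amount = 0.1  # neuromodulator released for this cycle
placecellaction += np.sign(placecellaction * lastactivation) * amount
placecellaction = np.clip(placecellaction, -w_max, w_max)

print(placecellaction)  # [ 0.3  -0.3   0.05] -- only the active cell's weight moved
```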