├── README.md
├── analyse.py
├── envs
│   ├── 10x10.json
│   ├── 4x4.json
│   └── 5x5.json
├── model.py
├── parameters.py
├── plot.py
├── run.py
├── run_lstm.py
├── test.py
├── utils.py
└── world.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Torch TEM

## Table of Contents

* [About the Project](#about-the-project)
* [Getting Started](#getting-started)
* [Installation](#installation)
* [Model Training](#model-training)
* [Model Analysis](#model-analysis)
* [Contact](#contact)
* [Acknowledgements](#acknowledgements)

## About The Project

This is an implementation of the Tolman-Eichenbaum Machine in PyTorch, written from scratch by following the Supplementary Material of [the original paper](https://www.biorxiv.org/content/10.1101/770495v2.full). It is extensively annotated and tries to follow the notation and terminology from the publication as closely as possible.

## Getting Started

You need to install [python >= 3.6.0](https://www.python.org/downloads/) and [pytorch >= 1.6.0](https://pytorch.org/).

### Installation

Clone the repo
```sh
git clone https://github.com/jbakermans/torch_tem.git
```
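The code in this repository also imports `numpy` and `scipy` (and plotting presumably relies on `matplotlib`), so a typical environment can be set up along these lines:
```sh
# Hedged example; any recent versions satisfying the requirements above should do
pip install torch numpy scipy matplotlib
```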
### Model Training

With the repo as working directory, train a model by running
```sh
python run.py
```
Model parameters are specified in `parameters.py`.

### Model Analysis

After training a model, analyse it and plot the analysis results by running
```sh
python test.py
```
You will need to specify the correct model run in `test.py`.

## Contact

[Jacob Bakermans](http://users.ox.ac.uk/~phys1358/) - jacob.bakermans [at] gmail.com

Project Link: [https://github.com/jbakermans/torch_tem](https://github.com/jbakermans/torch_tem)

## Acknowledgements

Many thanks to James Whittington for advice and assistance throughout.

--------------------------------------------------------------------------------
/analyse.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu May 28 15:01:34 2020

@author: jacobb
"""
import numpy as np
import torch
import pdb
import copy

# Track prediction accuracy over walk, and calculate fraction of locations visited and actions taken to assess performance
def performance(forward, model, environments):
    # Keep track of whether model predictions were correct, as well as the fraction of nodes/edges visited, across environments
    all_correct, all_location_frac, all_action_frac = [], [], []
    # Run through environments and monitor performance in each
    for env_i, env in enumerate(environments):
        # Keep track for each location whether it has been visited
        location_visited = np.full(env.n_locations, False)
        # And for each action in each location whether it has been taken
        action_taken = np.full((env.n_locations, model.hyper['n_actions']), False)
        # Not all actions are available at every location (e.g. edges of grid world). Find how many actions can be taken
        action_available = np.full((env.n_locations, model.hyper['n_actions']), False)
        for currLocation in env.locations:
            for currAction in currLocation['actions']:
                if np.sum(currAction['transition']) > 0:
                    if model.hyper['has_static_action']:
                        if currAction['id'] > 0:
                            action_available[currLocation['id'], currAction['id'] - 1] = True
                    else:
                        action_available[currLocation['id'], currAction['id']] = True
        # Make array to list whether the observation was predicted correctly or not
        correct = []
        # Make array that stores for each step the fraction of locations visited
        location_frac = []
        # And an array that stores for each step the fraction of actions taken
        action_frac = []
        # Run through iterations of forward pass to check when an action is taken for the first time
        for step in forward:
            # Update the states that have now been visited
            location_visited[step.g[env_i]['id']] = True
            # ... And the actions that now have been taken
            if model.hyper['has_static_action']:
                if step.a[env_i] > 0:
                    action_taken[step.g[env_i]['id'], step.a[env_i] - 1] = True
            else:
                action_taken[step.g[env_i]['id'], step.a[env_i]] = True
            # Check whether the prediction for this step was correct
            correct.append((torch.argmax(step.x_gen[2][env_i]) == torch.argmax(step.x[env_i])).numpy())
            # Add the fraction of locations visited for this step
            location_frac.append(np.sum(location_visited) / location_visited.size)
            # ... And also add the fraction of actions taken for this step
            action_frac.append(np.sum(action_taken) / np.sum(action_available))
        # Add performance and visitation fractions of this environment to performance list across environments
        all_correct.append(correct)
        all_location_frac.append(location_frac)
        all_action_frac.append(action_frac)
    # Return
    return all_correct, all_location_frac, all_action_frac

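As a hedged usage sketch (`tem`, `envs`, and `forward` below are stand-ins for a trained model, its environments, and the list of iterations returned by the model's forward pass, as produced in `test.py`):
```python
# Hypothetical usage; see test.py for how these objects are actually created.
correct, loc_frac, act_frac = performance(forward, tem, envs)
print(np.mean(correct[0]))  # average prediction accuracy in the first environment
print(loc_frac[0][-1])      # fraction of locations visited by the end of the walk
```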
# Track prediction accuracy per location, after a transition towards the location
def location_accuracy(forward, model, environments):
    # Keep track of whether model predictions were correct for each environment, separated by arrival and departure location
    accuracy_from, accuracy_to = [], []
    # Run through environments and monitor performance in each
    for env_i, env in enumerate(environments):
        # Make arrays that list whether the observation was predicted correctly or not
        correct_from = [[] for _ in range(env.n_locations)]
        correct_to = [[] for _ in range(env.n_locations)]
        # Run through iterations of forward pass to check when an action is taken for the first time
        for step_i, step in enumerate(forward[1:]):
            # Prediction on arrival: sensory prediction when arriving at given node
            correct_to[step.g[env_i]['id']].append((torch.argmax(step.x_gen[2][env_i]) == torch.argmax(step.x[env_i])).numpy().tolist())
            # Prediction on departure: sensory prediction after leaving given node - i.e. store whether the current prediction is correct for the previous location
            correct_from[forward[step_i].g[env_i]['id']].append((torch.argmax(step.x_gen[2][env_i]) == torch.argmax(step.x[env_i])).numpy().tolist())
        # Add performance and visitation fractions of this environment to performance list across environments
        accuracy_from.append([sum(correct_from_location) / (len(correct_from_location) if len(correct_from_location) > 0 else 1) for correct_from_location in correct_from])
        accuracy_to.append([sum(correct_to_location) / (len(correct_to_location) if len(correct_to_location) > 0 else 1) for correct_to_location in correct_to])
    # Return
    return accuracy_from, accuracy_to

# Track occupation per location
def location_occupation(forward, model, environments):
    # Keep track of how many times each location was visited
    occupation = []
    # Run through environments and monitor performance in each
    for env_i, env in enumerate(environments):
        # Make array to count the number of visits to each location
        visits = [0 for _ in range(env.n_locations)]
        # Run through iterations of forward pass
        for step in forward:
            # Count a visit to the location of this step
            visits[step.g[env_i]['id']] += 1
        # Add the visit counts of this environment to the occupation list across environments
        occupation.append(visits)
    # Return occupation of states during walk across environments
    return occupation

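The bookkeeping pattern used throughout this file is a boolean visit mask that is updated in place and then summarised as a fraction; a minimal self-contained illustration (with a hypothetical 4x4 environment size):
```python
import numpy as np

location_visited = np.full(16, False)   # one flag per location, e.g. a 4x4 grid
location_visited[[0, 1, 4]] = True      # mark three locations as visited
print(np.sum(location_visited) / location_visited.size)  # 0.1875
```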
# Measure zero-shot inference for this model: see if it can predict an observation following a new action to a known location
def zero_shot(forward, model, environments, include_stay_still=True):
    # Get the number of actions in this model
    n_actions = model.hyper['n_actions'] + model.hyper['has_static_action']
    # Track for all opportunities for zero-shot inference whether the predictions were correct across environments
    all_correct_zero_shot = []
    # Run through environments and check for zero-shot inference in each of them
    for env_i, env in enumerate(environments):
        # Keep track for each location whether it has been visited
        location_visited = np.full(env.n_locations, False)
        # And for each action in each location whether it has been taken
        action_taken = np.full((env.n_locations, n_actions), False)
        # Get the very first iteration
        prev_iter = forward[0]
        # Make list that tracks, for all opportunities for zero-shot inference, whether the predictions were correct
        correct_zero_shot = []
        # Run through iterations of forward pass to check when an action is taken for the first time
        for step in forward[1:]:
            # Get the previous action and previous location
            prev_a, prev_g = prev_iter.a[env_i], prev_iter.g[env_i]['id']
            # If the previous action was standing still: only count it as a valid transition if standing-still actions are included as zero-shot inference
            if model.hyper['has_static_action'] and prev_a == 0 and not include_stay_still:
                prev_a = None
            # Mark the location of the previous iteration as visited
            location_visited[prev_g] = True
            # Zero-shot inference occurs when the current location was visited, but the previous action wasn't taken before
            if location_visited[step.g[env_i]['id']] and prev_a is not None and not action_taken[prev_g, prev_a]:
                # Find whether the prediction was correct
                correct_zero_shot.append((torch.argmax(step.x_gen[2][env_i]) == torch.argmax(step.x[env_i])).numpy())
            # Update the previous action as taken
            if prev_a is not None:
                action_taken[prev_g, prev_a] = True
            # And update the previous iteration to the current iteration
            prev_iter = step
        # Having gone through the full forward pass for one environment, add the zero-shot performance to the list of all
        all_correct_zero_shot.append(correct_zero_shot)
    # Return lists of success of zero-shot inference for all environments
    return all_correct_zero_shot

# Compare TEM performance to a 'node' and an 'edge' agent, that remember previous observations and guess others
def compare_to_agents(forward, model, environments, include_stay_still=True):
    # Get the number of actions in this model
    n_actions = model.hyper['n_actions'] + model.hyper['has_static_action']
    # Store for each environment for each step whether it was predicted correctly by the model, and by a perfect node and perfect edge agent
    all_correct_model, all_correct_node, all_correct_edge = [], [], []
    # Run through environments and check for correct or incorrect prediction
    for env_i, env in enumerate(environments):
        # Keep track for each location whether it has been visited
        location_visited = np.full(env.n_locations, False)
        # And for each action in each location whether it has been taken
        action_taken = np.full((env.n_locations, n_actions), False)
        # Make array to list whether the observation was predicted correctly or not for the model
        correct_model = []
        # And the same for a node agent, that picks a random observation on first encounter of a node, and the correct one every next time
        correct_node = []
        # And the same for an edge agent, that picks a random observation on first encounter of an edge, and the correct one every next time
        correct_edge = []
        # Get the very first iteration
        prev_iter = forward[0]
        # Run through iterations of forward pass to check when an action is taken for the first time
        for step in forward[1:]:
            # Get the previous action and previous location
            prev_a, prev_g = prev_iter.a[env_i], prev_iter.g[env_i]['id']
            # If the previous action was standing still: only count it as a valid transition if standing-still actions are included as zero-shot inference
            if model.hyper['has_static_action'] and prev_a == 0 and not include_stay_still:
                prev_a = None
            # Mark the location of the previous iteration as visited
            location_visited[prev_g] = True
            # Update model prediction for this step
            correct_model.append((torch.argmax(step.x_gen[2][env_i]) == torch.argmax(step.x[env_i])).numpy())
            # Update node agent prediction for this step: correct when this state was visited before, otherwise chance
            correct_node.append(True if location_visited[step.g[env_i]['id']] else np.random.randint(model.hyper['n_x']) == torch.argmax(step.x[env_i]).numpy())
            # Update edge agent prediction for this step: always correct if no action taken, correct when action leading to this state was taken before, otherwise chance
            correct_edge.append(True if prev_a is None else True if action_taken[prev_g, prev_a] else np.random.randint(model.hyper['n_x']) == torch.argmax(step.x[env_i]).numpy())
            # Update the previous action as taken
            if prev_a is not None:
                action_taken[prev_g, prev_a] = True
            # And update the previous iteration to the current iteration
            prev_iter = step
        # Add the performance of model, node agent, and edge agent for this environment to list across environments
        all_correct_model.append(correct_model)
        all_correct_node.append(correct_node)
        all_correct_edge.append(correct_edge)
    # Return list of prediction success for all three agents across environments
    return all_correct_model, all_correct_node, all_correct_edge

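Chance level for these baseline agents on a first encounter is a uniform guess over the `n_x` possible observations (the environments provided here use 45 observation types), which is easy to sanity-check:
```python
import numpy as np

n_x = 45                                            # observation types in the envs here
guesses = np.random.randint(n_x, size=100000) == 0  # random guesses against a fixed target
print(guesses.mean())                               # ~ 1 / 45 ~ 0.022
```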
# Calculate rate maps for this model: what is the firing pattern for each cell at all locations?
def rate_map(forward, model, environments):
    # Store location x cell firing rate matrix for abstract and grounded location representation across environments
    all_g, all_p = [], []
    # Go through environments and collect firing rates in each
    for env_i, env in enumerate(environments):
        # Collect grounded location/hippocampal/place cell representation during walk: separate into frequency modules, then locations
        p = [[[] for loc in range(env.n_locations)] for f in range(model.hyper['n_f'])]
        # Collect abstract location/entorhinal/grid cell representation during walk: separate into frequency modules, then locations
        g = [[[] for loc in range(env.n_locations)] for f in range(model.hyper['n_f'])]
        # In each step, concatenate the representations to the appropriate list
        for step in forward:
            # Run through frequency modules and append the firing rates to the correct location list
            for f in range(model.hyper['n_f']):
                g[f][step.g[env_i]['id']].append(step.g_inf[f][env_i].numpy())
                p[f][step.g[env_i]['id']].append(step.p_inf[f][env_i].numpy())
        # Now average across location visits to get a single representation vector for each location for each frequency
        for cells, n_cells in zip([p, g], [model.hyper['n_p'], model.hyper['n_g']]):
            for f, frequency in enumerate(cells):
                # Average across visits of each location, using only the second half of the visits so the model roughly knows the environment
                for l, location in enumerate(frequency):
                    frequency[l] = sum(location[int(len(location)/2):]) / len(location[int(len(location)/2):]) if len(location[int(len(location)/2):]) > 0 else np.zeros(n_cells[f])
                # Then concatenate the locations to get a [locations x cells for this frequency] matrix
                cells[f] = np.stack(frequency, axis=0)
        # Append the final average representations of this environment to the list of representations across environments
        all_g.append(g)
        all_p.append(p)
    # Return list of locations x cells matrix of firing rates for each frequency module for each environment
    return all_g, all_p

# Helper function to generate input for the model
def generate_input(environment, walk):
    # If no walk was provided: use the environment to generate one
    if walk is None:
        # Generate a single walk from environment with length depending on number of locations (so you're likely to visit each location)
        walk = environment.generate_walks(environment.graph['n_locations']*100, 1)[0]
        # Now this walk needs to be adjusted so that it looks like a batch with batch size 1
        for step in walk:
            # Make single location into list
            step[0] = [step[0]]
            # Make single 1D observation vector into 2D row vector
            step[1] = step[1].unsqueeze(dim=0)
            # Make single action into list
            step[2] = [step[2]]
    return walk

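The batch-of-1 reshaping above is standard tensor bookkeeping; for example, with the 45 observation types used by the environments in `envs/`:
```python
import torch

x = torch.zeros(45)     # a single one-hot observation vector, shape (45,)
x = x.unsqueeze(dim=0)  # now a batch of size 1, shape (1, 45)
print(x.shape)          # torch.Size([1, 45])
```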
# Smoothing function (originally written by James)
def smooth(a, wsz):
    # a: NumPy 1-D array containing the data to be smoothed
    # wsz: smoothing window size, which must be an odd number
    out0 = np.convolve(a, np.ones(wsz, dtype=int), 'valid') / wsz
    r = np.arange(1, wsz - 1, 2)
    start = np.cumsum(a[:wsz - 1])[::2] / r
    stop = (np.cumsum(a[:-wsz:-1])[::2] / r)[::-1]
    return np.concatenate((start, out0, stop))

--------------------------------------------------------------------------------
/envs/4x4.json:
--------------------------------------------------------------------------------
1 | {"n_locations":16,"n_observations":45,"n_actions":5,"adjacency":[[1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0],[1,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0],[0,1,1,1,0,0,1,0,0,0,0,0,0,0,0,0],[0,0,1,1,0,0,0,1,0,0,0,0,0,0,0,0],[1,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0],[0,1,0,0,1,1,1,0,0,1,0,0,0,0,0,0],[0,0,1,0,0,1,1,1,0,0,1,0,0,0,0,0],[0,0,0,1,0,0,1,1,0,0,0,1,0,0,0,0],[0,0,0,0,1,0,0,0,1,1,0,0,1,0,0,0],[0,0,0,0,0,1,0,0,1,1,1,0,0,1,0,0],[0,0,0,0,0,0,1,0,0,1,1,1,0,0,1,0],[0,0,0,0,0,0,0,1,0,0,1,1,0,0,0,1],[0,0,0,0,0,0,0,0,1,0,0,0,1,1,0,0],[0,0,0,0,0,0,0,0,0,1,0,0,1,1,1,0],[0,0,0,0,0,0,0,0,0,0,1,0,0,1,1,1],[0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,1]],"locations":[{"id":0,"observation":9,"x":0.125,"y":0.125,"in_locations":[0,1,4],"in_degree":3,"out_locations":[0,1,4],"out_degree":3,"actions":[{"id":0,"transition":[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":2,"transition":[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":3,"transition":[0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0}]},{"id":1,"observation":22,"x":0.375,"y":0.125,"in_locations":[0,1,2,5],"in_degree":4,"out_locations":[0,1,2,5],"out_degree":4,"actions":[{"id":0,"transition":[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":2,"transition":[0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":4,"transition":[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25}]},{"id":2,"observation":23,"x":0.625,"y":0.125,"in_locations":[1,2,3,6],"in_degree":4,"out_locations":[1,2,3,6],"out_degree":4,"actions":[{"id":0,"transition":[0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":2,"transition":[0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":4,"transition":[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25}]},{"id":3,"observation":8,"x":0.875,"y":0.125,"in_locations":[2,3,7],"in_degree":3,"out_locations":[2,3,7],"out_degree":3,"actions":[{"id":0,"transition":[0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":3,"transition":[0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":4,"transition":[0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331}]},{"id":4,"observation":31,"x":0.125,"y":0.375,"in_locations":[0,4,5,8],"in_degree":4,"out_locations":[0,4,5,8],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],"probabi
lity":0.25},{"id":1,"transition":[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],"probability":0.25},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0}]},{"id":5,"observation":14,"x":0.375,"y":0.375,"in_locations":[1,4,5,6,9],"in_degree":5,"out_locations":[1,4,5,6,9],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2}]},{"id":6,"observation":28,"x":0.625,"y":0.375,"in_locations":[2,5,6,7,10],"in_degree":5,"out_locations":[2,5,6,7,10],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0],"probability":0.2}]},{"id":7,"observation":32,"x":0.875,"y":0.375,"in_locations":[3,6,7,11],"in_degree":4,"out_locations":[3,6,7,11],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0],"probability":0.25},{"id":4,"transition":[0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],"probability":0.25}]},{"id":8,"observation":26,"x":0.125,"y":0.625,"in_locations":[4,8,9,12],"in_degree":4,"out_locations":[4,8,9,12],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0],"probability":0.25},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0}]},{"id":9,"observation":16,"x":0.375,"y":0.625,"in_locations":[5,8,9,10,13],"in_degree":5,"out_locations":[5,8,9,10,13],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],"probability":0.2}]},{"id":10,"observation":24,"x":0.625,"y":0.625,"in_locations":[6,9,10,11,14],"in_degree":5,"out_locations":[6,9,10,11,14],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0],"probability":0.2}]},{"id":11,"observation":5,"x":0.875,"y":0.625,"in_locations":[7,10,11,15],"in_degree":4,"out_locations":[7,10,11,15],"out_degree":4,"actions":[{"id":0,"transiti
on":[0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1],"probability":0.25},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0],"probability":0.25}]},{"id":12,"observation":40,"x":0.125,"y":0.875,"in_locations":[8,12,13],"in_degree":3,"out_locations":[8,12,13],"out_degree":3,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0],"probability":0.33333333333333331},{"id":1,"transition":[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0],"probability":0.33333333333333331},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0}]},{"id":13,"observation":12,"x":0.375,"y":0.875,"in_locations":[9,12,13,14],"in_degree":4,"out_locations":[9,12,13,14],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0],"probability":0.25}]},{"id":14,"observation":30,"x":0.625,"y":0.875,"in_locations":[10,13,14,15],"in_degree":4,"out_locations":[10,13,14,15],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0],"probability":0.25}]},{"id":15,"observation":35,"x":0.875,"y":0.875,"in_locations":[11,14,15],"in_degree":3,"out_locations":[11,14,15],"out_degree":3,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1],"probability":0.33333333333333331},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0],"probability":0.33333333333333331},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0],"probability":0.33333333333333331}]}]} -------------------------------------------------------------------------------- /envs/5x5.json: -------------------------------------------------------------------------------- 1 | 
{"n_locations":25,"n_observations":45,"n_actions":5,"adjacency":[[1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[1,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,1,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,1,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,1,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,1,0,0,0,1,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,1,0,0,0,1,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,1,0,0,0,1,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,1,0,0,0,1,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,1,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,1,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,1,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,0,0,0,0,1,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,1,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,1,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,1,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,0,0,0,0,1],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,1,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,1,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,1],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1]],"locations":[{"id":0,"observation":31,"x":0.1,"y":0.1,"in_locations":[0,1,5],"in_degree":3,"out_locations":[0,1,5],"out_degree":3,"actions":[{"id":0,"transition":[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":2,"transition":[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":3,"transition":[0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0}]},{"id":1,"observation":39,"x":0.3,"y":0.1,"in_locations":[0,1,2,6],"in_degree":4,"out_locations":[0,1,2,6],"out_degree":4,"actions":[{"id":0,"transition":[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":2,"transition":[0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":4,"transition":[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25}]},{"id":2,"observation":21,"x":0.5,"y":0.1,"in_locations":[1,2,3,7],"in_degree":4,"out_locations":[1,2,3,7],"out_degree":4,"actions":[{"id":0,"transition":[0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":2,"transition":[0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":4,"transition":[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25}]},{"id":3,"observation":33,"x":0.7,"y":0.1,"in_locations":[2,3,4,8],"in_degree":4,"out_locations":[2,3,4,8],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":2,"transition":[0,0,0,0,1,0,0
,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":4,"transition":[0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25}]},{"id":4,"observation":34,"x":0.9,"y":0.1,"in_locations":[3,4,9],"in_degree":3,"out_locations":[3,4,9],"out_degree":3,"actions":[{"id":0,"transition":[0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":4,"transition":[0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331}]},{"id":5,"observation":5,"x":0.1,"y":0.3,"in_locations":[0,5,6,10],"in_degree":4,"out_locations":[0,5,6,10],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0}]},{"id":6,"observation":2,"x":0.3,"y":0.3,"in_locations":[1,5,6,7,11],"in_degree":5,"out_locations":[1,5,6,7,11],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2}]},{"id":7,"observation":15,"x":0.5,"y":0.3,"in_locations":[2,6,7,8,12],"in_degree":5,"out_locations":[2,6,7,8,12],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2}]},{"id":8,"observation":10,"x":0.7,"y":0.3,"in_locations":[3,7,8,9,13],"in_degree":5,"out_locations":[3,7,8,9,13],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2}]},{"id":9,"observation":29,"x":0.9,"y":0.3,"in_locations":[4,8,9,14],"in_degree":4,"out_locations":[4,8,9,14],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":4,"transition":[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25}]},{"id":10,"observation":44,"x":0.1,"y":0.5,"in_locations":[5,10,11,15],"in_degree":4,"out_locations":[5,10,11,15],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0}]},{"id":11,"observation":32,"x":0.3,"y":0.5,"in_locations":[6,10,11,12,16],"in_degree":5,"out_locations":[6,10,11,12,16],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2}]},{"id":12,"observation":6,"x":0.5,"y":0.5,"in_locations":[7,11,12,13,17],"in_degree":5,"out_locations":[7,11,12,13,17],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2}]},{"id":13,"observation":37,"x":0.7,"y":0.5,"in_locations":[8,12,13,14,18],"in_degree":5,"out_locations":[8,12,13,14,18],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2}]},{"id":14,"observation":41,"x":0.9,"y":0.5,"in_locations":[9,13,14,19],"in_degree":4,"out_locations":[9,13,14,19],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0],"probability":0.25},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25}]},{"id":15,"observation":27,"x":0.1,"y":0.7,"in_locations":[10,15,16,20],"in_degree":4,"out_locations":[10,15,16,20],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0
,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0],"probability":0.25},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0}]},{"id":16,"observation":16,"x":0.3,"y":0.7,"in_locations":[11,15,16,17,21],"in_degree":5,"out_locations":[11,15,16,17,21],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],"probability":0.2}]},{"id":17,"observation":40,"x":0.5,"y":0.7,"in_locations":[12,16,17,18,22],"in_degree":5,"out_locations":[12,16,17,18,22],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0],"probability":0.2}]},{"id":18,"observation":13,"x":0.7,"y":0.7,"in_locations":[13,17,18,19,23],"in_degree":5,"out_locations":[13,17,18,19,23],"out_degree":5,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0],"probability":0.2},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0],"probability":0.2},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0],"probability":0.2},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0],"probability":0.2},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],"probability":0.2}]},{"id":19,"observation":7,"x":0.9,"y":0.7,"in_locations":[14,18,19,24],"in_degree":4,"out_locations":[14,18,19,24],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1],"probability":0.25},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0],"probability":0.25}]},{"id":20,"observation":4,"x":0.1,"y":0.9,"in_locations":[15,20,21],"in_degree":3,"out_locations":[15,20,21],"out_degree":3,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0],"probability":0.33333333333333331},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0],"probability":0.33333333333333331},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0],"probability":0.33333333333333331},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0}]},{"id":21,"observation":28,"x":0.3,"y":0.9,"in_locations":[16,20,21,22],"in_degree":4,"out
_locations":[16,20,21,22],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0],"probability":0.25}]},{"id":22,"observation":20,"x":0.5,"y":0.9,"in_locations":[17,21,22,23],"in_degree":4,"out_locations":[17,21,22,23],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0],"probability":0.25}]},{"id":23,"observation":24,"x":0.7,"y":0.9,"in_locations":[18,22,23,24],"in_degree":4,"out_locations":[18,22,23,24],"out_degree":4,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0],"probability":0.25},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0],"probability":0.25},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1],"probability":0.25},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0],"probability":0.25}]},{"id":24,"observation":36,"x":0.9,"y":0.9,"in_locations":[19,23,24],"in_degree":3,"out_locations":[19,23,24],"out_degree":3,"actions":[{"id":0,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1],"probability":0.33333333333333331},{"id":1,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0],"probability":0.33333333333333331},{"id":2,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":3,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"probability":0},{"id":4,"transition":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0],"probability":0.33333333333333331}]}]} -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Feb 11 14:26:32 2020 5 | 6 | This is a pytorch implementation of the Tolman-Eichenbaum Machine, 7 | written by Jacob Bakermans after the original by James Whittington. 8 | The referenced paper is the bioRxiv publication at https://www.biorxiv.org/content/10.1101/770495v2 9 | 10 | Release v1.0.0: Fully functional pytorch model, without any extensions 11 | 12 | @author: jacobb 13 | """ 14 | # Standard modules 15 | import numpy as np 16 | import torch 17 | import pdb 18 | import copy 19 | from scipy.stats import truncnorm 20 | # Custom modules 21 | import utils 22 | 23 | class Model(torch.nn.Module): 24 | def __init__(self, params): 25 | # First call super class init function to set up torch.nn.Module style model and inherit it's functionality 26 | super(Model, self).__init__() 27 | # Copy hyperparameters (e.g. 
class Model(torch.nn.Module):
    def __init__(self, params):
        # First call super class init function to set up torch.nn.Module style model and inherit its functionality
        super(Model, self).__init__()
        # Copy hyperparameters (e.g. network sizes) from parameter dict, usually generated from parameters() in parameters.py
        self.hyper = copy.deepcopy(params)
        # Create trainable parameters
        self.init_trainable()

    def forward(self, walk, prev_iter = None, prev_M = None):
        # The previous iteration may contain walks without action. These are new walks, for which some parameters need to be reset.
        steps = self.init_walks(prev_iter)
        # Forward pass: perform a TEM iteration for each set of [place, observation, action], and produce inferred and generated variables for each step.
        for g, x, a in walk:
            # If there is no previous iteration at all: all walks are new, initialise a whole new iteration object
            if steps is None:
                # Use an Iteration object to set initial values before any real iterations, initialising M, x_inf as zero. Set actions to None to indicate there was no previous action
                steps = [self.init_iteration(g, x, [None for _ in range(len(a))], prev_M)]
            # Perform TEM iteration using transition from previous iteration
            L, M, g_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf = self.iteration(x, g, steps[-1].a, steps[-1].M, steps[-1].x_inf, steps[-1].g_inf)
            # Store this iteration in iteration object in steps list
            steps.append(Iteration(g, x, a, L, M, g_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf))
        # The first step is either a step from a previous walk or initialisation rubbish, so remove it
        steps = steps[1:]
        # Return steps, which is a list of Iteration objects
        return steps

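A hedged sketch of the walk format `forward` consumes, inferred from the loop above and from `generate_input` in `analyse.py` (`env` and `tem` are stand-ins):
```python
# Hypothetical usage: each step of a walk is [locations, observation, actions],
# batched across parallel walks (here a batch of one, as generate_input produces).
# walk = analyse.generate_input(env, None)
# steps = tem.forward(walk, prev_iter=None, prev_M=None)
# assert len(steps) == len(walk)   # one Iteration object per walk step
```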
    def iteration(self, x, locations, a_prev, M_prev, x_prev, g_prev):
        # First, do the transition step, as it will be necessary for both the inference and generative part of the model
        gt_gen, gt_inf = self.gen_g(a_prev, g_prev, locations)
        # Run inference model: infer grounded location p_inf (hippocampus), abstract location g_inf (entorhinal). Also keep filtered sensory observation (x_inf), and retrieved grounded location p_inf_x
        x_inf, g_inf, p_inf_x, p_inf = self.inference(x, locations, M_prev, x_prev, gt_inf)
        # Run generative model: since generative model is only used for training purposes, it will generate from *inferred* variables instead of *generated* variables (as it would when used for generation)
        x_gen, x_logits, p_gen = self.generative(M_prev, p_inf, g_inf, gt_gen)
        # Update generative memory with generated and inferred grounded location.
        M = [self.hebbian(M_prev[0], torch.cat(p_inf, dim=1), torch.cat(p_gen, dim=1))]
        # If using memory for grounded location inference: append inference memory
        if self.hyper['use_p_inf']:
            # Inference memory is identical to generative memory if using common memory, and updated separately if not
            M.append(M[0] if self.hyper['common_memory'] else self.hebbian(M_prev[1], torch.cat(p_inf, dim=1), torch.cat(p_inf_x, dim=1), do_hierarchical_connections=False))
        # Calculate loss of this step
        L = self.loss(gt_gen, p_gen, x_logits, x, g_inf, p_inf, p_inf_x, M_prev)
        # Return all iteration values
        return L, M, gt_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf

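The `hebbian` call above (defined further down in this file, outside this excerpt) updates a batched memory matrix. Generically, Hebbian learning of this kind is a batched outer-product update, sketched here with illustrative sizes and a made-up learning rule; the exact decay and learning rate live in `hebbian` itself:
```python
import torch

batch, n_p = 2, 5
M = torch.zeros(batch, n_p, n_p)   # one memory matrix per walk in the batch
p_inf = torch.randn(batch, n_p)    # inferred grounded location
p_gen = torch.randn(batch, n_p)    # generated grounded location
# Outer product of the prediction mismatch with the inferred pattern, per batch element
M = M + torch.bmm((p_inf - p_gen).unsqueeze(2), p_inf.unsqueeze(1))
print(M.shape)                     # torch.Size([2, 5, 5])
```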
    def inference(self, x, locations, M_prev, x_prev, g_gen):
        # Compress sensory observation from one-hot to two-hot (or alternatively, whatever an MLP makes of it)
        x_c = self.f_c(x)
        # Temporally filter sensory observation by mixing it with previous experience
        x_f = self.x_prev2x(x_prev, x_c)
        # Prepare sensory experience for input to memory by normalisation and weighting
        x_ = self.x2x_(x_f)
        # Retrieve grounded location from memory by doing pattern completion on current sensory experience
        p_x = self.attractor(x_, M_prev[1], retrieve_it_mask=self.hyper['p_retrieve_mask_inf']) if self.hyper['use_p_inf'] else None
        # Infer abstract location by combining previous abstract location and grounded location retrieved from memory by current sensory experience
        g = self.inf_g(p_x, g_gen, x, locations)
        # Prepare abstract location for input to memory by downsampling and weighting
        g_ = self.g2g_(g)
        # Infer grounded location from sensory experience and inferred abstract location
        p = self.inf_p(x_, g_)
        # Return variables in order that they were created
        return x_f, g, p_x, p

    def generative(self, M_prev, p_inf, g_inf, g_gen):
        # Generate observation from inferred grounded location, using only the highest frequency. Also keep non-softmaxed logits which are used in the loss later
        x_p, x_p_logits = self.gen_x(p_inf[0])
        # Retrieve grounded location from memory by pattern completion on inferred abstract location
        p_g_inf = self.gen_p(g_inf, M_prev[0])  # was p_mem_gen
        # And generate observation from the grounded location retrieved from inferred abstract location
        x_g, x_g_logits = self.gen_x(p_g_inf[0])
        # Retrieve grounded location from memory by pattern completion on abstract location obtained by transitioning
        p_g_gen = self.gen_p(g_gen, M_prev[0])
        # Generate observation from sampled grounded location
        x_gt, x_gt_logits = self.gen_x(p_g_gen[0])
        # Return all generated observations and their corresponding logits
        return (x_p, x_g, x_gt), (x_p_logits, x_g_logits, x_gt_logits), p_g_inf

    def loss(self, g_gen, p_gen, x_logits, x, g_inf, p_inf, p_inf_x, M_prev):
        # Calculate loss function, separately for each component because you might want to reweight contributions later
        # L_p_gen is squared error loss between inferred grounded location and grounded location retrieved from inferred abstract location
        L_p_g = torch.sum(torch.stack(utils.squared_error(p_inf, p_gen), dim=0), dim=0)
        # L_p_inf is squared error loss between inferred grounded location and grounded location retrieved from sensory experience
        L_p_x = torch.sum(torch.stack(utils.squared_error(p_inf, p_inf_x), dim=0), dim=0) if self.hyper['use_p_inf'] else torch.zeros_like(L_p_g)
        # L_g is squared error loss between generated abstract location and inferred abstract location
        L_g = torch.sum(torch.stack(utils.squared_error(g_inf, g_gen), dim=0), dim=0)
        # L_x is a cross-entropy loss between sensory experience and different model predictions. First get true labels from sensory experience
        labels = torch.argmax(x, 1)
        # L_x_gen: losses generated by generative model from g_prev -> g -> p -> x
        L_x_gen = utils.cross_entropy(x_logits[2], labels)
        # L_x_g: losses generated by generative model from g_inf -> p -> x
        L_x_g = utils.cross_entropy(x_logits[1], labels)
        # L_x_p: losses generated by generative model from p_inf -> x
        L_x_p = utils.cross_entropy(x_logits[0], labels)
        # L_reg are regularisation losses, L_reg_g on L2 norm of g
        L_reg_g = torch.sum(torch.stack([torch.sum(g ** 2, dim=1) for g in g_inf], dim=0), dim=0)
        # And L_reg_p regularisation on L1 norm of p
        L_reg_p = torch.sum(torch.stack([torch.sum(torch.abs(p), dim=1) for p in p_inf], dim=0), dim=0)
        # Return total loss as list of losses, so you can possibly reweight them
        L = [L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p]
        return L

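In `init_trainable` below, the Laplacian scale factor is stored as its inverse sigmoid (logit), so the optimiser works on an unconstrained parameter; a quick numerical check of that round trip:
```python
import numpy as np

f_initial = 0.9
alpha = np.log(f_initial / (1 - f_initial))  # stored parameter: logit / inverse sigmoid
f = 1 / (1 + np.exp(-alpha))                 # sigmoid recovers the scale factor
print(np.isclose(f, f_initial))              # True
```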
    def init_trainable(self):
        # Scale factor in Laplacian transform for each frequency module. High frequency comes first, low frequency comes last. Learn inverse sigmoid instead of scale factor directly, so domain of alpha is -inf, inf
        self.alpha = torch.nn.ParameterList([torch.nn.Parameter(torch.tensor(np.log(self.hyper['f_initial'][f] / (1 - self.hyper['f_initial'][f])), dtype=torch.float)) for f in range(self.hyper['n_f'])])
        # Entorhinal preference weights
        self.w_x = torch.nn.Parameter(torch.tensor(1.0))
        # Entorhinal preference bias
        self.b_x = torch.nn.Parameter(torch.zeros(self.hyper['n_x_c']))
        # Frequency module specific scaling of sensory experience before input to hippocampus
        self.w_p = torch.nn.ParameterList([torch.nn.Parameter(torch.tensor(1.0)) for f in range(self.hyper['n_f'])])
        # Initial activity of abstract location cells when entering a new environment, like a prior on g. Initialise with truncated normal
        self.g_init = torch.nn.ParameterList([torch.nn.Parameter(torch.tensor(truncnorm.rvs(-2, 2, size=self.hyper['n_g'][f], loc=0, scale=self.hyper['g_init_std']), dtype=torch.float)) for f in range(self.hyper['n_f'])])
        # Log of standard deviation of abstract location cells when entering a new environment; standard deviation of the prior on g. Initialise with truncated normal
        self.logsig_g_init = torch.nn.ParameterList([torch.nn.Parameter(torch.tensor(truncnorm.rvs(-2, 2, size=self.hyper['n_g'][f], loc=0, scale=self.hyper['g_init_std']), dtype=torch.float)) for f in range(self.hyper['n_f'])])
        # MLP for transition weights (not in paper, but recommended by James so you can learn about similarities between actions). Size is given by grid connections
        self.MLP_D_a = MLP([self.hyper['n_actions'] for _ in range(self.hyper['n_f'])],
                           [sum([self.hyper['n_g'][f_from] for f_from in range(self.hyper['n_f']) if self.hyper['g_connections'][f_to][f_from]])*self.hyper['n_g'][f_to] for f_to in range(self.hyper['n_f'])],
                           activation=[torch.tanh, None],
                           hidden_dim=[self.hyper['d_hidden_dim'] for _ in range(self.hyper['n_f'])],
                           bias=[True, False])
        # Initialise the hidden to output weights as zero, so initially you simply keep the current abstract location to predict the next abstract location
        self.MLP_D_a.set_weights(1, 0.0)
        # Transition weights without specifying an action for use in generative model with shiny objects
        self.D_no_a = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros(sum([self.hyper['n_g'][f_from] for f_from in range(self.hyper['n_f']) if self.hyper['g_connections'][f_to][f_from]])*self.hyper['n_g'][f_to])) for f_to in range(self.hyper['n_f'])])
        # MLP for standard deviation of transition sample
        self.MLP_sigma_g_path = MLP(self.hyper['n_g'], self.hyper['n_g'], activation=[torch.tanh, torch.exp], hidden_dim=[2 * g for g in self.hyper['n_g']])
        # MLP for standard deviation of grounded location from retrieved memory sample
        self.MLP_sigma_p = MLP(self.hyper['n_p'], self.hyper['n_p'], activation=[torch.tanh, torch.exp])
        # MLP to generate mean of abstract location from downsampled abstract location, obtained by summing grounded location over sensory preferences in inference model
        self.MLP_mu_g_mem = MLP(self.hyper['n_g_subsampled'], self.hyper['n_g'], hidden_dim=[2 * g for g in self.hyper['n_g']])
        # Initialise weights in last layer of MLP_mu_g_mem as truncated normal for each frequency module
        self.MLP_mu_g_mem.set_weights(-1, [torch.tensor(truncnorm.rvs(-2, 2, size=list(self.MLP_mu_g_mem.w[f][-1].weight.shape), loc=0, scale=self.hyper['g_mem_std']), dtype=torch.float) for f in range(self.hyper['n_f'])])
        # MLP to generate standard deviation of abstract location from two measures (generated observation error and inferred abstract location vector norm) of memory quality
        self.MLP_sigma_g_mem = MLP([2 for _ in self.hyper['n_g_subsampled']], self.hyper['n_g'], activation=[torch.tanh, torch.exp], hidden_dim=[2 * g for g in self.hyper['n_g']])
        # MLP to generate mean of abstract location directly from shiny object presence. Outputs to object vector cell modules if they're separated, else to all abstract location modules
        self.MLP_mu_g_shiny = MLP([1 for _ in range(self.hyper['n_f_ovc'] if self.hyper['separate_ovc'] else self.hyper['n_f'])],
                                  [n_g for n_g in self.hyper['n_g'][(self.hyper['n_f_g'] if self.hyper['separate_ovc'] else 0):]],
                                  hidden_dim=[2*n_g for n_g in self.hyper['n_g'][(self.hyper['n_f_g'] if self.hyper['separate_ovc'] else 0):]])
        # MLP to generate standard deviation of abstract location directly from shiny object presence. Outputs to object vector cell modules if they're separated, else to all abstract location modules
        self.MLP_sigma_g_shiny = MLP([1 for _ in range(self.hyper['n_f_ovc'] if self.hyper['separate_ovc'] else self.hyper['n_f'])],
                                     [n_g for n_g in self.hyper['n_g'][(self.hyper['n_f_g'] if self.hyper['separate_ovc'] else 0):]],
                                     hidden_dim=[2*n_g for n_g in self.hyper['n_g'][(self.hyper['n_f_g'] if self.hyper['separate_ovc'] else 0):]], activation=[torch.tanh, torch.exp])
        # MLP for decompressing highest frequency sensory experience to sensory observation
        self.MLP_c_star = MLP(self.hyper['n_x_f'][0], self.hyper['n_x'], hidden_dim=20 * self.hyper['n_x_c'])

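`truncnorm.rvs(-2, 2, ...)`, used for several of the initialisations above, draws from a normal distribution clipped to two standard deviations, so no parameter starts as an outlier; a standalone illustration:
```python
from scipy.stats import truncnorm

sample = truncnorm.rvs(-2, 2, size=1000, loc=0, scale=0.5)
print(sample.min() >= -1.0 and sample.max() <= 1.0)  # True: clipped to +/- 2 * scale
```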
Else, create a new empty memory list for inference network 178 | M.append(M[0] if self.hyper['common_memory'] else torch.zeros((self.hyper['batch_size'],sum(self.hyper['n_p']),sum(self.hyper['n_p'])), dtype=torch.float)) 179 | # Initialise previous abstract location by stacking abstract location prior 180 | g_inf = [torch.stack([self.g_init[f] for _ in range(self.hyper['batch_size'])]) for f in range(self.hyper['n_f'])] 181 | # Initialise previous sensory experience with zeros, as there is no data yet for temporal smoothing 182 | x_inf = [torch.zeros((self.hyper['batch_size'], self.hyper['n_x_f'][f])) for f in range(self.hyper['n_f'])] 183 | # And construct new iteration for that g, x, a, and M 184 | return Iteration(g=g, x=x, a=a, M=M, x_inf=x_inf, g_inf=g_inf) 185 | 186 | def init_walks(self, prev_iter): 187 | # Only reset parameters for previous iteration if a previous iteration was actually provided - if it wasn't, all parameters will be reset when creating a fresh Iteration object in init_iteration 188 | if prev_iter is not None: 189 | # The supplied previous iteration might have new walks starting, with empty actions. For these walks some parameters need to be reset 190 | for a_i, a in enumerate(prev_iter[0].a): 191 | # A new walk is indicated by having a None action in the previous iteration 192 | if a is None: 193 | # Reset the initial connectivity matrix for this walk 194 | for M in prev_iter[0].M: 195 | M[a_i,:,:] = 0 196 | # Reset the abstract location for this walk 197 | for f, g_inf in enumerate(prev_iter[0].g_inf): 198 | g_inf[a_i,:] = self.g_init[f] 199 | # Reset the sensory experience for this walk 200 | for f, x_inf in enumerate(prev_iter[0].x_inf): 201 | x_inf[a_i,:] = torch.zeros(self.hyper['n_x_f'][f]) 202 | # Return the iteration with reset parameters (or simply None if no previous iteration was provided) 203 | return prev_iter 204 | 205 | def gen_g(self, a_prev, g_prev, locations): 206 | # Transition from previous abstract location to new abstract location using weights specific to action taken for each frequency module 207 | mu_g = self.f_mu_g_path(a_prev, g_prev) 208 | sigma_g = self.f_sigma_g_path(a_prev, g_prev) 209 | # Either sample new abstract location g or simply take the mean of distribution in noiseless case. 210 | g = [mu_g[f] + sigma_g[f] * np.random.randn() if self.hyper['do_sample'] else mu_g[f] for f in range(self.hyper['n_f'])] 211 | # But for environments with shiny objects, the transition to the new abstract location shouldn't have access to the action direction in the generative model 212 | shiny_envs = [location['shiny'] is not None for location in locations] 213 | # If there are any shiny environments, the abstract locations for the generative model will need to be re-calculated without providing actions for those 214 | g_gen = self.f_mu_g_path(a_prev, g_prev, no_direc=shiny_envs) if any(shiny_envs) else g 215 | # Return generated abstract location after transition 216 | return g_gen, (g, sigma_g) 217 | 218 | def gen_p(self, g, M_prev): 219 | # We want to use g as an index for memory retrieval, but it doesn't have the right dimensions (these are grid cells, we need place cells).
We need g_ instead 220 | g_ = self.g2g_(g) 221 | # Retrieve memory: do pattern completion on abstract location to get grounded location 222 | mu_p = self.attractor(g_, M_prev, retrieve_it_mask=self.hyper['p_retrieve_mask_gen']) 223 | sigma_p = self.f_sigma_p(mu_p) 224 | # Either sample new grounded location p or simply take the mean of distribution in noiseless case 225 | p = [mu_p[f] + sigma_p[f] * np.random.randn() if self.hyper['do_sample'] else mu_p[f] for f in range(self.hyper['n_f'])] 226 | # Return pattern-completed grounded location p after memory retrieval 227 | return p 228 | 229 | def gen_x(self, p): 230 | # Get categorical distribution over observations from grounded location 231 | # If you actually want to sample an observation, you need a reparameterisation trick for categorical distributions 232 | # Sampling would be the correct way to do this, since observations are discrete, and it's also what the TEM paper says 233 | # However, it looks like you could also get away with using the categorical distribution directly as an approximation of the one-hot observations 234 | if self.hyper['do_sample']: 235 | x, logits = self.f_x(p) # This is a placeholder! Should be done using a reparameterisation trick (like https://blog.evjang.com/2016/11/tutorial-categorical-variational.html) 236 | else: 237 | x, logits = self.f_x(p) 238 | # Return one-hot (or almost one-hot...) observation obtained from grounded location, and also the non-softmaxed logits 239 | return x, logits 240 |
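The placeholder in gen_x above points at the Gumbel-softmax (concrete) relaxation as the missing reparameterisation trick for categorical samples. As a rough illustration of what that could look like - this sketch is not part of the repository, and the temperature value is an arbitrary choice:

```python
import torch

def gumbel_softmax_sample(logits, temperature=0.5):
    # Draw Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1)
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    # Temperature-scaled softmax over the perturbed logits: as temperature -> 0
    # samples approach one-hot vectors, while gradients still flow through the softmax
    return torch.softmax((logits + gumbel) / temperature, dim=-1)

# Relaxed one-hot samples for a batch of 16 observations over 45 sensory identities
x_sample = gumbel_softmax_sample(torch.randn(16, 45))
```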
241 | def inf_g(self, p_x, g_gen, x, locations): 242 | # Infer abstract location from the combination of [grounded location retrieved from memory by sensory experience] ... 243 | if self.hyper['use_p_inf']: 244 | # Not in paper, but makes sense from symmetry with f_x: first get g from p by "summing over sensory preferences" g = p * W_repeat^T 245 | g_downsampled = [torch.matmul(p_x[f], torch.t(self.hyper['W_repeat'][f])) for f in range(self.hyper['n_f'])] 246 | # Then use abstract location after summing over sensory preferences as input to MLP to obtain the inferred abstract location from memory 247 | mu_g_mem = self.f_mu_g_mem(g_downsampled) 248 | # Not in paper, but this greatly improves zero-shot inference: provide the uncertainty function of the inferred abstract location with measures of memory quality 249 | with torch.no_grad(): 250 | # For the first measure, use the grounded location inferred from memory to generate an observation 251 | x_hat, x_hat_logits = self.gen_x(p_x[0]) 252 | # Then calculate the error between the generated observation and the actual observation: if the memory is working well, this error should be small 253 | err = utils.squared_error(x, x_hat) 254 | # The second measure is the vector norm of the inferred abstract location; good memories should have similar vector norms. Concatenate the two measures as input for the abstract location uncertainty function 255 | sigma_g_input = [torch.cat((torch.sum(g ** 2, dim=1, keepdim=True), torch.unsqueeze(err, dim=1)), dim=1) for g in mu_g_mem] 256 | # Not in paper, but recommended by James for stability: get final mean of inferred abstract location by clamping activations between -1 and 1 257 | mu_g_mem = self.f_g_clamp(mu_g_mem) 258 | # And get standard deviation/uncertainty of inferred abstract location by providing uncertainty function with memory quality measures 259 | sigma_g_mem = self.f_sigma_g_mem(sigma_g_input) 260 | # ... and [previous abstract location and action (path integration)] 261 | mu_g_path = g_gen[0] 262 | sigma_g_path = g_gen[1] 263 | # Infer abstract location by combining previous abstract location and grounded location retrieved from memory by current sensory experience 264 | mu_g, sigma_g = [], [] 265 | for f in range(self.hyper['n_f']): 266 | if self.hyper['use_p_inf']: 267 | # Then get full gaussian distribution of inferred abstract location by calculating precision weighted mean 268 | mu, sigma = utils.inv_var_weight([mu_g_path[f], mu_g_mem[f]],[sigma_g_path[f], sigma_g_mem[f]]) 269 | else: 270 | # Or simply completely ignore the inference memory here, to test if things are working 271 | mu, sigma = mu_g_path[f], sigma_g_path[f] 272 | # Append mu and sigma to list for all frequency modules 273 | mu_g.append(mu) 274 | sigma_g.append(sigma) 275 | # Finally (though not in paper), also add object vector cell information to inferred abstract location for environments with shiny objects 276 | shiny_envs = [location['shiny'] is not None for location in locations] 277 | if any(shiny_envs): 278 | # Find for which environments the current location has a shiny object 279 | shiny_locations = torch.unsqueeze(torch.stack([torch.tensor(location['shiny'], dtype=torch.float) for location in locations if location['shiny'] is not None]), dim=-1) 280 | # Get abstract location for environments with shiny objects and feed to each of the object vector cell modules 281 | mu_g_shiny = self.f_mu_g_shiny([shiny_locations for _ in range(self.hyper['n_f_g'] if self.hyper['separate_ovc'] else self.hyper['n_f'])]) 282 | sigma_g_shiny = self.f_sigma_g_shiny([shiny_locations for _ in range(self.hyper['n_f_g'] if self.hyper['separate_ovc'] else self.hyper['n_f'])]) 283 | # Update only object vector modules with shiny-inferred abstract location: start from offset if object vector modules are separate 284 | module_start = self.hyper['n_f_g'] if self.hyper['separate_ovc'] else 0 285 | # Inverse variance weighting is associative, so I can just do additional inverse variance weighting to the previously obtained mu and sigma - but only for object vector cell modules!
286 | for f in range(module_start, self.hyper['n_f']): 287 | # Add inferred abstract location from shiny objects to previously obtained position, only for environments with shiny objects 288 | mu, sigma = utils.inv_var_weight([mu_g[f][shiny_envs,:], mu_g_shiny[f - module_start]], [sigma_g[f][shiny_envs,:], sigma_g_shiny[f - module_start]]) 289 | # In order to update only the environments with shiny objects, without in-place value assignment, construct a mask of shiny environments 290 | mask = torch.zeros_like(mu_g[f], dtype=torch.bool) 291 | mask[shiny_envs,:] = True 292 | # Use mask to update the shiny environment entries in inferred abstract locations 293 | mu_g[f] = mu_g[f].masked_scatter(mask,mu) 294 | sigma_g[f] = sigma_g[f].masked_scatter(mask,sigma) 295 | # Either sample inferred abstract location from combined (precision weighted) distribution or just take mean 296 | g = [mu_g[f] + sigma_g[f] * np.random.randn() if self.hyper['do_sample'] else mu_g[f] for f in range(self.hyper['n_f'])] 297 | # Return abstract location inferred from grounded location from memory and previous abstract location 298 | return g 299 | 300 | def inf_p(self, x_, g_): 301 | # Infer grounded location from sensory experience and inferred abstract location for each module 302 | p = [] 303 | # Use the same transformation for each frequency module: leaky relu for sparsity 304 | for f in range(self.hyper['n_f']): 305 | mu_p = self.f_p(g_[f] * x_[f]) # This is element-wise multiplication 306 | sigma_p = 0 # Unclear from paper (typo?). Some undefined function f that takes two arguments: f(f_n(x),g) 307 | # Either sample inferred grounded location or just take mean 308 | if self.hyper['do_sample']: 309 | p.append(mu_p + sigma_p * np.random.randn()) 310 | else: 311 | p.append(mu_p) 312 | # Return new memory constructed from sensory experience and inferred abstract location 313 | return p 314 | 315 | def x_prev2x(self, x_prev, x_c): 316 | # Calculate factor for filtering from sigmoid of learned parameter 317 | alpha = [torch.nn.Sigmoid()(self.alpha[f]) for f in range(self.hyper['n_f'])] 318 | # Do exponential temporal filtering for each frequency module 319 | x = [(1 - alpha[f]) * x_prev[f] + alpha[f] * x_c for f in range(self.hyper['n_f'])] 320 | return x 321 | 322 | def x2x_(self, x): 323 | # Prepare sensory input for input to memory by weighting and normalisation for each frequency module 324 | # Get normalised sensory input for each frequency module 325 | normalised = self.f_n(x) 326 | # Then reshape and reweight (use sigmoid to keep weight between 0 and 1) each frequency module separately: matrix multiplication by W_tile prepares x for outer product with g by element-wise multiplication 327 | x_ = [torch.nn.Sigmoid()(self.w_p[f]) * torch.matmul(normalised[f],self.hyper['W_tile'][f]) for f in range(self.hyper['n_f'])] 328 | return x_ 329 | 330 | def g2g_(self, g): 331 | # Prepares abstract location for input to memory by reshaping and down-sampling for each frequency module 332 | # Get downsampled abstract location for each frequency module 333 | downsampled = self.f_g(g) 334 | # Then reshape and reweight each frequency module separately 335 | g_ = [torch.matmul(downsampled[f], self.hyper['W_repeat'][f]) for f in range(self.hyper['n_f'])] 336 | return g_ 337 |
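x_prev2x above is a simple exponential smoothing filter, with the filtering factor of each frequency module stored as its inverse sigmoid (logit) so that the learned parameter is unconstrained. A standalone sketch of the round trip, using the highest-frequency initialisation from parameters.py:

```python
import numpy as np

f_initial = 0.99                             # initial filtering factor of the fastest module
param = np.log(f_initial / (1 - f_initial))  # inverse sigmoid: the value actually learned
alpha = 1 / (1 + np.exp(-param))             # sigmoid maps it back into (0, 1): 0.99

# Exponential temporal filtering: high alpha tracks the current input closely,
# low alpha keeps a slowly decaying trace of earlier observations
x_prev, x_c = 0.2, 1.0
x_new = (1 - alpha) * x_prev + alpha * x_c   # 0.992 here
```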
338 | def f_mu_g_path(self, a_prev, g_prev, no_direc=None): 339 | # If there are no environments where the transition direction needs to be omitted (e.g. no shiny objects, or in the inference model): set to all false 340 | no_direc = [False for _ in a_prev] if no_direc is None else no_direc 341 | # Remove all Nones from a_prev: these are walks where there was no previous action, so no step needs to be calculated for those 342 | a_prev_step = [a if a is not None else 0 for a in a_prev] 343 | # And also keep track of which walks these valid step actions are for 344 | a_do_step = [a is not None for a in a_prev] 345 | # Transform list of actions into batch of one-hot row vectors. 346 | if self.hyper['has_static_action']: 347 | # If this world has static actions: whenever action 0 (standing still) appears, the action vector should be all zeros. All other actions should have a 1 in the label-1 entry 348 | a = torch.zeros((len(a_prev_step),self.hyper['n_actions'])).scatter_(1, torch.clamp(torch.tensor(a_prev_step).unsqueeze(1)-1,min=0), 1.0*(torch.tensor(a_prev_step).unsqueeze(1)>0)) 349 | else: 350 | # Without static actions: each action label should become a one-hot vector for that label 351 | a = torch.zeros((len(a_prev_step),self.hyper['n_actions'])).scatter_(1, torch.tensor(a_prev_step).unsqueeze(1), 1.0) 352 | # Get vector of transition weights by feeding actions into MLP 353 | D_a = self.MLP_D_a([a for _ in range(self.hyper['n_f'])]) 354 | # Replace transition weights by non-directional transition weights in environments where the transition direction needs to be omitted 355 | for f in range(self.hyper['n_f']): 356 | D_a[f][no_direc,:] = self.D_no_a[f] 357 | # Reshape transition weight vector into transition matrix. The number of rows in the transition matrix is given by the incoming abstract location connections for each frequency module 358 | D_a = [torch.reshape(D_a[f_to],(-1, sum([self.hyper['n_g'][f_from] for f_from in range(self.hyper['n_f']) if self.hyper['g_connections'][f_to][f_from]]), self.hyper['n_g'][f_to])) for f_to in range(self.hyper['n_f'])] 359 | # Select the frequency modules of the previous abstract location that are connected to each frequency module, to use as input to the transition 360 | g_in = [torch.unsqueeze(torch.cat([g_prev[f_from] for f_from in range(self.hyper['n_f']) if self.hyper['g_connections'][f_to][f_from]], dim=1),1) for f_to in range(self.hyper['n_f'])] 361 | # Multiply the selected previous abstract locations by the transition matrix to get the proposed change in abstract location 362 | delta = [torch.squeeze(torch.matmul(g, T)) for g, T in zip(g_in, D_a)] 363 | # Not in the paper, but recommended by James for stability: use inferred code as *difference* in abstract location.
Calculate new abstract location from previous abstract location and difference 364 | g_step = [g + d if g.dim() > 1 else torch.unsqueeze(g + d, 0) for g, d in zip(g_prev, delta)] 365 | # Not in paper, but recommended by James for stability: clamp activations between -1 and 1 366 | g_step = self.f_g_clamp(g_step) 367 | # Build new abstract location from result of transition if there was one, or from prior on abstract location if there wasn't 368 | return [torch.stack([g_step[f][batch_i, :] if do_step else self.g_init[f] for batch_i, do_step in enumerate(a_do_step)]) for f in range(self.hyper['n_f'])] 369 | 370 | def f_sigma_g_path(self, a_prev, g_prev): 371 | # Keep track of which walks these valid step actions are for 372 | a_do_step = [a is not None for a in a_prev] 373 | # Multi layer perceptron to generate standard deviation from all previous abstract locations, including those that were just initialised and are not real previous locations 374 | from_g = self.MLP_sigma_g_path(g_prev) 375 | # And take the exponent of the learned log sigma to get the prior sigma for the walks that didn't have a previous location 376 | from_prior = [torch.exp(logsig) for logsig in self.logsig_g_init] 377 | # Now select the standard deviation generated from the previous abstract location if there was one, and the prior standard deviation on abstract location otherwise 378 | return [torch.stack([from_g[f][batch_i, :] if do_step else from_prior[f] for batch_i, do_step in enumerate(a_do_step)]) for f in range(self.hyper['n_f'])] 379 | 380 | def f_mu_g_mem(self, g_downsampled): 381 | # Multi layer perceptron to generate mean of abstract location from down-sampled abstract location, obtained by summing over sensory dimension of grounded location 382 | return self.MLP_mu_g_mem(g_downsampled) 383 | 384 | def f_sigma_g_mem(self, g_downsampled): 385 | # Multi layer perceptron to generate standard deviation of abstract location from down-sampled abstract location, obtained by summing over sensory dimension of grounded location 386 | sigma = self.MLP_sigma_g_mem(g_downsampled) 387 | # Not in paper, but also offset this sigma over training, so you can reduce influence of inferred p early on 388 | return [sigma[f] + self.hyper['p2g_scale_offset'] * self.hyper['p2g_sig_val'] for f in range(self.hyper['n_f'])] 389 | 390 | def f_mu_g_shiny(self, shiny): 391 | # Multi layer perceptron to generate mean of abstract location from boolean location shiny-ness 392 | mu_g = self.MLP_mu_g_shiny(shiny) 393 | # Take absolute value because James wants object vector cells to be positive 394 | mu_g = [torch.abs(mu) for mu in mu_g] 395 | # Then apply clamp and leaky relu to get object vector module activations, like it's done for grounded location activations 396 | g = self.f_p(mu_g) 397 | return g 398 | 399 | def f_sigma_g_shiny(self, shiny): 400 | # Multi layer perceptron to generate standard deviation of abstract location from boolean location shiny-ness 401 | return self.MLP_sigma_g_shiny(shiny) 402 | 403 | def f_sigma_p(self, p): 404 | # Multi layer perceptron to generate standard deviation of grounded location retrieval 405 | return self.MLP_sigma_p(p) 406 | 407 | def f_x(self, p): 408 | # Calculate categorical probability distribution over observations for a given grounded location 409 | # p has dimensions n_p[0]. We'll need to transform those to temporally filtered sensory experience, before we can decompress 410 | # p is the flattened (by concatenating rows - like reading sentences) outer product of g and x (p = g^T * x).
411 | # Therefore to get the sensory experience x for a grounded location p, sum over all abstract locations g for each component of x 412 | # That's what the paper means when it says "sum over entorhinal preferences". It can be done with the transpose of W_tile 413 | x = self.w_x * torch.matmul(p, torch.t(self.hyper['W_tile'][0])) + self.b_x 414 | # Then we need to decompress the temporally filtered sensory experience into a single current experience prediction 415 | logits = self.f_c_star(x) 416 | # We'll keep both the logits (domain -inf, inf) and probabilities (domain 0, 1) because both are needed later on 417 | probability = utils.softmax(logits) 418 | return probability, logits 419 | 420 | def f_c_star(self, compressed): 421 | # Multi layer perceptron to decompress sensory experience at highest frequency 422 | return self.MLP_c_star(compressed) 423 | 424 | def f_c(self, decompressed): 425 | # Compress sensory observation from one-hot provided by world to two-hot for ease of computation 426 | return torch.stack([self.hyper['two_hot_table'][i] for i in torch.argmax(decompressed, dim=1)], dim=0) 427 | 428 | def f_n(self, x): 429 | # Normalise sensory observation for each frequency module 430 | normalised = [utils.normalise(utils.relu(x[f] - torch.mean(x[f]))) for f in range(self.hyper['n_f'])] 431 | return normalised 432 | 433 | def f_g(self, g): 434 | # Downsample abstract location for each frequency module 435 | downsampled = [torch.matmul(g[f], self.hyper['g_downsample'][f]) for f in range(self.hyper['n_f'])] 436 | return downsampled 437 | 438 | def f_g_clamp(self, g): 439 | # Calculate activation for abstract location, thresholding between -1 and 1 440 | activation = [torch.clamp(g_f, min=-1, max=1) for g_f in g] 441 | return activation 442 | 443 | def f_p(self, p): 444 | # Calculate activation for inferred grounded location, using a leaky relu for sparsity. Either apply to full multi-frequency grounded location or single frequency module 445 | activation = [utils.leaky_relu(torch.clamp(p_f, min=-1, max=1)) for p_f in p] if type(p) is list else utils.leaky_relu(torch.clamp(p, min=-1, max=1)) 446 | return activation 447 | 448 | def attractor(self, p_query, M, retrieve_it_mask=None): 449 | # Retrieve grounded location from attractor network memory with weights M by pattern-completing query 450 | # For example, initial attractor input can come from abstract location (g_) or sensory experience (x_) 451 | # Start by flattening query grounded locations across frequency modules 452 | h_t = torch.cat(p_query, dim=1) 453 | # Apply activation function to initial memory index 454 | h_t = self.f_p(h_t) 455 | # Hierarchical retrieval (not in paper) is implemented by early stopping retrieval for low frequencies, using a mask. If not specified: initialise the mask as all 1s, with one mask per attractor iteration 456 | retrieve_it_mask = [torch.ones(sum(self.hyper['n_p'])) for _ in range(self.hyper['i_attractor'])] if retrieve_it_mask is None else retrieve_it_mask 457 | # Iterate attractor dynamics to do pattern completion 458 | for tau in range(self.hyper['i_attractor']): 459 | # Apply one iteration of attractor dynamics, but only where there is a 1 in the mask.
NB retrieve_it_mask entries have only one row, but are broadcasted to batch_size 460 | h_t = (1-retrieve_it_mask[tau])*h_t + retrieve_it_mask[tau]*(self.f_p(self.hyper['kappa'] * h_t + torch.squeeze(torch.matmul(torch.unsqueeze(h_t,1), M)))) 461 | # Make helper list of cumulative neurons per frequency module for grounded locations 462 | n_p = np.cumsum(np.concatenate(([0],self.hyper['n_p']))) 463 | # Now re-cast the grounded location into different frequency modules, since memory retrieval turned it into one long vector 464 | p = [h_t[:,n_p[f]:n_p[f+1]] for f in range(self.hyper['n_f'])] 465 | return p 466 | 467 | def hebbian(self, M_prev, p_inferred, p_generated, do_hierarchical_connections=True): 468 | # Create new ground memory for attractor network by setting weights to outer product of learned vectors 469 | # p_inferred corresponds to p in the paper, and p_generated corresponds to p^. 470 | # The order of p + p^ and p - p^ is reversed since these are row vectors, instead of column vectors as in the paper. 471 | M_new = torch.squeeze(torch.matmul(torch.unsqueeze(p_inferred + p_generated, 2),torch.unsqueeze(p_inferred - p_generated,1))) 472 | # Multiply by connection vector, e.g. only keeping weights from low to high frequencies for hierarchical retrieval 473 | if do_hierarchical_connections: 474 | M_new = M_new * self.hyper['p_update_mask'] 475 | # Store grounded location in attractor network memory with weights M by Hebbian learning of pattern 476 | M = torch.clamp(self.hyper['lambda'] * M_prev + self.hyper['eta'] * M_new, min=-1, max=1) 477 | return M 478 |
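The attractor and hebbian methods above together form a Hopfield-style memory: patterns are written as (masked) outer products and read back by iterating the dynamics h_{t+1} = f_p(kappa * h_t + h_t M). A minimal toy demonstration of the same mechanism, leaving out the batch dimension, frequency modules, update masks, and the leaky relu:

```python
import torch

n, eta, lamb, kappa = 12, 0.5, 0.9999, 0.8
M = torch.zeros(n, n)

# Write one pattern. With an uninformative generated pattern p_hat = 0,
# the (p + p_hat)(p - p_hat)^T update reduces to the outer product p p^T
p = torch.sign(torch.randn(n))
p_hat = torch.zeros(n)
M = torch.clamp(lamb * M + eta * (p + p_hat).unsqueeze(1) * (p - p_hat).unsqueeze(0), min=-1, max=1)

# Read the pattern back from a noisy cue by iterating the attractor dynamics
h = p + 0.5 * torch.randn(n)
for tau in range(5):
    h = torch.clamp(kappa * h + torch.matmul(h, M), min=-1, max=1)
print(torch.cosine_similarity(h, p, dim=0))  # close to 1: pattern completed
```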
479 | class MLP(torch.nn.Module): 480 | def __init__(self, in_dim, out_dim, activation=(torch.nn.functional.elu, None), hidden_dim=None, bias=(True, True)): 481 | # First call super class init function to set up torch.nn.Module style model and inherit its functionality 482 | super(MLP, self).__init__() 483 | # Check if this network consists of modules: are input and output dimensions lists? If not, make them lists (but remember it wasn't) 484 | if type(in_dim) is list: 485 | self.is_list = True 486 | else: 487 | in_dim = [in_dim] 488 | out_dim = [out_dim] 489 | self.is_list = False 490 | # Find number of modules 491 | self.N = len(in_dim) 492 | # Create weights (input->hidden, hidden->output) for each module 493 | self.w = torch.nn.ModuleList([]) 494 | for n in range(self.N): 495 | # If number of hidden dimensions is not specified: mean of input and output 496 | if hidden_dim is None: 497 | hidden = int(np.mean([in_dim[n],out_dim[n]])) 498 | else: 499 | hidden = hidden_dim[n] if self.is_list else hidden_dim 500 | # Each module has two sets of weights: input->hidden and hidden->output 501 | self.w.append(torch.nn.ModuleList([torch.nn.Linear(in_dim[n], hidden, bias=bias[0]), torch.nn.Linear(hidden, out_dim[n], bias=bias[1])])) 502 | # Copy activation function for hidden layer and output layer 503 | self.activation = activation 504 | # Initialise all weights 505 | with torch.no_grad(): 506 | for from_layer in range(2): 507 | for n in range(self.N): 508 | # Set weights to xavier initialisation 509 | torch.nn.init.xavier_normal_(self.w[n][from_layer].weight) 510 | # Set biases to 0 511 | if bias[from_layer]: 512 | self.w[n][from_layer].bias.fill_(0.0) 513 | 514 | def set_weights(self, from_layer, value): 515 | # If single value is provided: copy it for each module 516 | if type(value) is not list: 517 | input_value = [value for n in range(self.N)] 518 | else: 519 | input_value = value 520 | # Run through all modules and set weights starting from requested layer to the specified value 521 | with torch.no_grad(): 522 | # MLP is setup as follows: w[module][layer] is a Linear object, w[module][layer].weight is a Parameter object for the linear weights, and w[module][layer].weight.data is the tensor of weight values 523 | for n in range(self.N): 524 | # If a tensor is provided: copy the tensor to the weights 525 | if type(input_value[n]) is torch.Tensor: 526 | self.w[n][from_layer].weight.copy_(input_value[n]) 527 | # If only a single value is provided: set that value everywhere 528 | else: 529 | self.w[n][from_layer].weight.fill_(input_value[n]) 530 | 531 | def forward(self, data): 532 | # Make input data into list, if this network doesn't consist of modules 533 | if self.is_list: 534 | input_data = data 535 | else: 536 | input_data = [data] 537 | # Run input through network for each module 538 | output = [] 539 | for n in range(self.N): 540 | # Pass through first weights from input to hidden layer 541 | module_output = self.w[n][0](input_data[n]) 542 | # Apply hidden layer activation 543 | if self.activation[0] is not None: 544 | module_output = self.activation[0](module_output) 545 | # Pass through second weights from hidden to output layer 546 | module_output = self.w[n][1](module_output) 547 | # Apply output layer activation 548 | if self.activation[1] is not None: 549 | module_output = self.activation[1](module_output) 550 | # Collect the output of this module 551 | output.append(module_output) 552 | # If this network doesn't consist of modules: select output from first module to return 553 | if not self.is_list: 554 | output = output[0] 555 | # And return output 556 | return output 557 |
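The MLP class above runs one small two-layer network per frequency module, with list-valued arguments switching it into this module mode. A quick usage sketch with illustrative toy sizes (assuming the class has been imported from model.py):

```python
import torch
from model import MLP  # MLP as defined above

# Three modules with different input sizes, tanh hidden layers, and an
# exponential output activation, as used for the standard deviation MLPs
mlp = MLP([4, 6, 8], [3, 3, 3], activation=[torch.tanh, torch.exp], hidden_dim=[8, 12, 16])
# Input is one batch of row vectors per module; output is one tensor per module
data = [torch.randn(16, d) for d in [4, 6, 8]]
out = mlp(data)  # list of three [16, 3] tensors, all entries positive due to exp
```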
558 | class LSTM(torch.nn.Module): 559 | def __init__(self, in_dim, hidden_dim, out_dim, n_layers = 1, n_a = 4): 560 | # First call super class init function to set up torch.nn.Module style model and inherit its functionality 561 | super(LSTM, self).__init__() 562 | # LSTM layer 563 | self.lstm = torch.nn.LSTM(in_dim, hidden_dim, n_layers, batch_first=True) 564 | # Hidden to output 565 | self.lin = torch.nn.Linear(hidden_dim, out_dim) 566 | # Copy number of actions, will be needed for input data vector 567 | self.n_a = n_a 568 | 569 | def forward(self, data, prev_hidden = None): 570 | # If previous hidden and cell state are not provided: initialise them randomly 571 | if prev_hidden is None: 572 | hidden_state = torch.randn(self.lstm.num_layers, data.shape[0], self.lstm.hidden_size) 573 | cell_state = torch.randn(self.lstm.num_layers, data.shape[0], self.lstm.hidden_size) 574 | prev_hidden = (hidden_state, cell_state) 575 | # Run input through lstm 576 | lstm_out, lstm_hidden = self.lstm(data, prev_hidden) 577 | # Apply linear network to lstm output to get output: prediction at each timestep 578 | lin_out = self.lin(lstm_out) 579 | # And since we want a one-hot prediction: do softmax on top 580 | out = utils.softmax(lin_out) 581 | # Return output and hidden state 582 | return out, lstm_hidden 583 | 584 | def prepare_data(self, data_in): 585 | # Transform list of actions of each step into batch of one-hot row vectors 586 | actions = [torch.zeros((len(step[2]),self.n_a)).scatter_(1, torch.tensor(step[2]).unsqueeze(1), 1.0) for step in data_in] 587 | # Concatenate observation and action together along column direction in each step 588 | vectors = [torch.cat((step[1], action), dim=1) for step, action in zip(data_in, actions)] 589 | # Then stack all these together along the second dimension, which is sequence length 590 | data = torch.stack(vectors, dim=1) 591 | # Return data in [batch_size, seq_len, input_dim] dimension as expected by lstm 592 | return data 593 |
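The LSTM baseline above consumes walks as [batch_size, seq_len, input_dim] tensors in which each step concatenates the sensory observation with a one-hot action vector (built by prepare_data). A toy forward pass with illustrative sizes (assuming the class has been imported from model.py; the action count here is an assumption, not the repo's setting):

```python
import torch
from model import LSTM  # LSTM as defined above

# 45-dimensional observations plus 5 action channels in, 100 hidden units,
# and a 45-dimensional softmax prediction of the next observation out
lstm = LSTM(in_dim=45 + 5, hidden_dim=100, out_dim=45, n_a=5)
data = torch.randn(16, 20, 50)                  # [batch_size, seq_len, input_dim]
out, hidden = lstm(data)                        # out: [16, 20, 45] predictions
# Continue the same walk by feeding the returned hidden state back in
out_next, hidden = lstm(torch.randn(16, 1, 50), prev_hidden=hidden)
```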
594 | class Iteration: 595 | def __init__(self, g = None, x = None, a = None, L = None, M = None, g_gen = None, p_gen = None, x_gen = None, x_logits = None, x_inf = None, g_inf = None, p_inf = None): 596 | # Copy all inputs 597 | self.g = g 598 | self.x = x 599 | self.a = a 600 | self.L = L 601 | self.M = M 602 | self.g_gen = g_gen 603 | self.p_gen = p_gen 604 | self.x_gen = x_gen 605 | self.x_logits = x_logits 606 | self.x_inf = x_inf 607 | self.g_inf = g_inf 608 | self.p_inf = p_inf 609 | 610 | def correct(self): 611 | # Detach observation and all predictions 612 | observation = self.x.detach().numpy() 613 | predictions = [tensor.detach().numpy() for tensor in self.x_gen] 614 | # Did the model predict the right observation in this iteration? 615 | accuracy = [np.argmax(prediction, axis=-1) == np.argmax(observation, axis=-1) for prediction in predictions] 616 | return accuracy 617 | 618 | def detach(self): 619 | # Detach all tensors contained in this iteration 620 | self.L = [tensor.detach() for tensor in self.L] 621 | self.M = [tensor.detach() for tensor in self.M] 622 | self.g_gen = [tensor.detach() for tensor in self.g_gen] 623 | self.p_gen = [tensor.detach() for tensor in self.p_gen] 624 | self.x_gen = [tensor.detach() for tensor in self.x_gen] 625 | self.x_inf = [tensor.detach() for tensor in self.x_inf] 626 | self.g_inf = [tensor.detach() for tensor in self.g_inf] 627 | self.p_inf = [tensor.detach() for tensor in self.p_inf] 628 | # Return self after detaching everything 629 | return self 630 | -------------------------------------------------------------------------------- /parameters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Feb 12 09:42:32 2020 5 | 6 | @author: jacobb 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | from scipy.special import comb 12 | 13 | # This contains one single function that generates a dictionary of parameters, which is provided to the model on initialisation 14 | def parameters(): 15 | params = {} 16 | # -- World parameters 17 | # Does this world include the standing still action? 18 | params['has_static_action'] = True 19 | # Number of available actions, excluding the stand still action (since standing still has an action vector full of zeros, it won't add to the action vector dimension) 20 | params['n_actions'] = 4 21 | # Bias for explorative behaviour to pick the same action again, to encourage straight walks 22 | params['explore_bias'] = 2 23 | # Rate at which environments with shiny objects occur among training environments. Set to 0 for no shiny environments at all 24 | params['shiny_rate'] = 0 25 | # Discount factor in calculating Q-values to generate shiny object oriented behaviour 26 | params['shiny_gamma'] = 0.7 27 | # Inverse temperature for shiny object behaviour to pick actions based on Q-values 28 | params['shiny_beta'] = 1.5 29 | # Number of shiny objects in the arena 30 | params['shiny_n'] = 2 31 | # Number of times to return to a shiny object after finding it 32 | params['shiny_returns'] = 15 33 | # Group all shiny parameters together to pass them to the world object 34 | params['shiny'] = {'gamma' : params['shiny_gamma'], 'beta' : params['shiny_beta'], 'n' : params['shiny_n'], 'returns' : params['shiny_returns']} 35 | 36 | # -- Training parameters 37 | # Number of walks to generate 38 | params['train_it'] = 20000 39 | # Number of steps to roll out before backpropagation through time 40 | params['n_rollout'] = 20 41 | # Batch size: number of walks for training simultaneously 42 | params['batch_size'] = 16 43 | # Minimum length of a walk on one environment. Walk lengths are sampled uniformly from a window that shifts down until its lower limit is walk_it_min at the end of training 44 | params['walk_it_min'] = 25 45 | # Maximum length of a walk on one environment.
Walk lengths are sampled uniformly from a window that starts with its upper limit at walk_it_max at the beginning of training, then shifts down 46 | params['walk_it_max'] = 300 47 | # Width of window from which walk lengths are sampled: at any moment, new walk lengths are sampled window_center +/- 0.5 * walk_it_window where window_center shifts down 48 | params['walk_it_window'] = 0.2 * (params['walk_it_max'] - params['walk_it_min']) 49 | # Weights of prediction losses 50 | params['loss_weights_x'] = 1 51 | # Weights of grounded location losses 52 | params['loss_weights_p'] = 1 53 | # Weights of abstract location losses 54 | params['loss_weights_g'] = 1 55 | # Weights of regularisation losses 56 | params['loss_weights_reg_g'] = 0.01 57 | params['loss_weights_reg_p'] = 0.02 58 | # Weights of losses: re-balance contributions of L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p 59 | params['loss_weights'] = torch.tensor([params['loss_weights_p'], params['loss_weights_p'], params['loss_weights_x'], params['loss_weights_x'], params['loss_weights_x'], params['loss_weights_g'], params['loss_weights_reg_g'], params['loss_weights_reg_p']], dtype=torch.float) 60 | # Number of backprop iters until latent parameter losses (L_p_g, L_p_x, L_g) are all fully weighted 61 | params['loss_weights_p_g_it'] = 2000 62 | # Number of backprop iters until regularisation losses are fully weighted 63 | params['loss_weights_reg_p_it'] = 4000 64 | params['loss_weights_reg_g_it'] = 40000000 65 | # Number of backprop iters until eta (rate of remembering) is completely 'on' 66 | params['eta_it'] = 16000 67 | # Number of backprop iters until lambda (rate of forgetting) is completely 'on' 68 | params['lambda_it'] = 200 69 | # Determine how much to use an offset for the standard deviation of the inferred grounded location to reduce its influence 70 | params['p2g_scale_offset'] = 0 71 | # Additional value to offset standard deviation of inferred grounded location when inferring new abstract location, to reduce influence in precision weighted mean 72 | params['p2g_sig_val'] = 10000 73 | # Set number of iterations where offset scaling should be 0.5 74 | params['p2g_sig_half_it'] = 400 75 | # Set how fast offset scaling should decrease - after p2g_sig_half_it + p2g_sig_scale_it the offset scaling is down to ~0.25 (1/(1+e) to be exact) 76 | params['p2g_sig_scale_it'] = 200 77 | # Maximum learning rate 78 | params['lr_max'] = 9.4e-4 79 | # Minimum learning rate 80 | params['lr_min'] = 8e-5 81 | # Rate of learning rate decay 82 | params['lr_decay_rate'] = 0.5 83 | # Steps of learning rate decay 84 | params['lr_decay_steps'] = 4000 85 | 86 | # -- Model parameters 87 | # Decide whether to sample, or assume no noise and simply take mean of all distributions 88 | params['do_sample'] = False 89 | # Decide whether to use inferred grounded location while inferring new abstract location, instead of only previous grounded location (James's infer_g_type) 90 | params['use_p_inf'] = True 91 | # Decide whether to use separate grid modules that receive shiny information for object vector cells.
To disable OVC, set this to False, and set n_ovc to [0 for _ in range(len(params['n_g_subsampled']))] 92 | params['separate_ovc'] = False 93 | # Standard deviation for the initial g (which will then be learned) 94 | params['g_init_std'] = 0.5 95 | # Standard deviation to initialise hidden to output layer of MLP for inferring new abstract location from memory of grounded location 96 | params['g_mem_std'] = 0.1 97 | # Hidden layer size of MLP for abstract location transitions 98 | params['d_hidden_dim'] = 20 99 | 100 | # ---- Neuron and module parameters 101 | # Neurons for subsampled entorhinal abstract location f_g(g) for each frequency module 102 | params['n_g_subsampled'] = [10, 10, 8, 6, 6] 103 | # Neurons for object vector cells. Neurons will get new modules if object vector cell modules are separated; otherwise, they are added to existing abstract location modules. 104 | # a) No additional modules, no additional object vector neurons (e.g. when not using shiny environments): [0 for _ in range(len(params['n_g_subsampled']))], and separate_ovc set to False 105 | # b) No additional modules, but n additional object vector neurons in each grid module: [n for _ in range(len(params['n_g_subsampled']))], and separate_ovc set to False 106 | # c) Additional separate object vector modules, with n, m neurons: [n, m], and separate_ovc set to True 107 | params['n_ovc'] = [0 for _ in range(len(params['n_g_subsampled']))] 108 | # Add neurons for object vector cells. Add new modules if object vector cells get separate modules, or else add neurons to existing modules 109 | params['n_g_subsampled'] = params['n_g_subsampled'] + params['n_ovc'] if params['separate_ovc'] else [grid + ovc for grid, ovc in zip(params['n_g_subsampled'],params['n_ovc'])] 110 | # Number of hierarchical frequency modules for object vector cells 111 | params['n_f_ovc'] = len(params['n_ovc']) if params['separate_ovc'] else 0 112 | # Number of hierarchical frequency modules for grid cells 113 | params['n_f_g'] = len(params['n_g_subsampled']) - params['n_f_ovc'] 114 | # Total number of modules 115 | params['n_f'] = len(params['n_g_subsampled']) 116 | # Number of neurons of entorhinal abstract location g for each frequency 117 | params['n_g'] = [3 * g for g in params['n_g_subsampled']] 118 | # Neurons for sensory observation x 119 | params['n_x'] = 45 120 | # Neurons for compressed sensory experience x_c 121 | params['n_x_c'] = 10 122 | # Neurons for temporally filtered sensory experience x for each frequency 123 | params['n_x_f'] = [params['n_x_c'] for _ in range(params['n_f'])] 124 | # Neurons for hippocampal grounded location p for each frequency 125 | params['n_p'] = [g * x for g, x in zip(params['n_g_subsampled'], params['n_x_f'])]
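A quick worked check of the sizes these defaults produce (no object vector cells, so n_ovc adds nothing):

```python
n_g_subsampled = [10, 10, 8, 6, 6]      # subsampled entorhinal cells per module
n_g = [3 * g for g in n_g_subsampled]   # [30, 30, 24, 18, 18] entorhinal cells
n_x_f = [10] * 5                        # compressed sensory neurons per module (n_x_c = 10)
n_p = [g * x for g, x in zip(n_g_subsampled, n_x_f)]  # [100, 100, 80, 60, 60]
print(sum(n_p))                         # 400 hippocampal neurons in total
```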
126 | # Initial frequencies of each module. For ease of interpretation (higher number = higher frequency) this is 1 - the frequency as James uses it 127 | params['f_initial'] = [0.99, 0.3, 0.09, 0.03, 0.01] 128 | # Add frequencies of object vector cell modules, if object vector cells get separate modules 129 | params['f_initial'] = params['f_initial'] + params['f_initial'][0:params['n_f_ovc']] 130 | 131 | # ---- Memory parameters 132 | # Use common memory for generative and inference network 133 | params['common_memory'] = False 134 | # Hebbian rate of forgetting 135 | params['lambda'] = 0.9999 136 | # Hebbian rate of remembering 137 | params['eta'] = 0.5 138 | # Hebbian retrieval decay term 139 | params['kappa'] = 0.8 140 | # Number of iterations of attractor dynamics for memory retrieval 141 | params['i_attractor'] = params['n_f_g'] 142 | # Maximum iterations of attractor dynamics per frequency in inference model, so you can early stop low-frequency modules. Set to None for no early stopping 143 | params['i_attractor_max_freq_inf'] = [params['i_attractor'] for _ in range(params['n_f'])] 144 | # Maximum iterations of attractor dynamics per frequency in generative model, so you can early stop low-frequency modules. Don't early stop for object vector cell modules. 145 | params['i_attractor_max_freq_gen'] = [params['i_attractor'] - freq_nr for freq_nr in range(params['n_f_g'])] + [params['i_attractor'] for _ in range(params['n_f_ovc'])] 146 | 147 | # --- Connectivity matrices 148 | # Set connections when forming Hebbian memory of grounded locations: from low frequency modules to high. High frequency modules come first (different from James!) 149 | params['p_update_mask'] = torch.zeros((np.sum(params['n_p']),np.sum(params['n_p'])), dtype=torch.float) 150 | n_p = np.cumsum(np.concatenate(([0],params['n_p']))) 151 | # Entry M_ij (row i, col j) is the connection FROM cell i TO cell j. Memory is retrieved by h_t+1 = h_t * M, i.e.
h_t+1_j = sum_i {connection from i to j * h_t_i} 152 | for f_from in range(params['n_f']): 153 | for f_to in range(params['n_f']): 154 | # For connections that involve separate object vector modules: these are connected to all normal modules, but hierarchically between object vector modules 155 | if f_from > params['n_f_g'] or f_to > params['n_f_g']: 156 | # If this is a connection between object vector modules: only allow for connection from low to high frequency 157 | if (f_from > params['n_f_g'] and f_to > params['n_f_g']): 158 | if params['f_initial'][f_from] <= params['f_initial'][f_to]: 159 | params['p_update_mask'][n_p[f_from]:n_p[f_from+1],n_p[f_to]:n_p[f_to+1]] = 1.0 160 | # If this is a connection between object vector and normal modules: allow any connections, in both directions 161 | else: 162 | params['p_update_mask'][n_p[f_from]:n_p[f_from+1],n_p[f_to]:n_p[f_to+1]] = 1.0 163 | # Else: this is a connection between abstract location frequency modules; only allow for connections if it goes from low to high frequency 164 | else: 165 | if params['f_initial'][f_from] <= params['f_initial'][f_to]: 166 | params['p_update_mask'][n_p[f_from]:n_p[f_from+1],n_p[f_to]:n_p[f_to+1]] = 1.0 167 | # During memory retrieval, hierarchical memory retrieval of grounded location is implemented by early-stopping low-frequency memory updates, using a mask for updates at every retrieval iteration 168 | params['p_retrieve_mask_inf'] = [torch.zeros(sum(params['n_p'])) for _ in range(params['i_attractor'])] 169 | params['p_retrieve_mask_gen'] = [torch.zeros(sum(params['n_p'])) for _ in range(params['i_attractor'])] 170 | # Build masks for each retrieval iteration 171 | for mask, max_iters in zip([params['p_retrieve_mask_inf'], params['p_retrieve_mask_gen']], [params['i_attractor_max_freq_inf'], params['i_attractor_max_freq_gen']]): 172 | # For each frequency, we get the number of update iterations, and insert ones in the mask for those iterations 173 | for f, max_i in enumerate(max_iters): 174 | # Update masks up to maximum iteration 175 | for i in range(max_i): 176 | mask[i][n_p[f]:n_p[f+1]] = 1.0 177 | # In path integration, abstract location frequency modules can influence the transition of other modules hierarchically (low to high). Set for each frequency module from which other frequencies input is received 178 | params['g_connections'] = [[params['f_initial'][f_from] <= params['f_initial'][f_to] for f_from in range(params['n_f_g'])] + [False for _ in range(params['n_f_ovc'])] for f_to in range(params['n_f_g'])] 179 | # Add connections for separate object vector cell module: only between object vector cell modules - and make those hierarchical too 180 | params['g_connections'] = params['g_connections'] + [[False for _ in range(params['n_f_g'])] + [params['f_initial'][f_from] <= params['f_initial'][f_to] for f_from in range(params['n_f_g'], params['n_f'])] for f_to in range(params['n_f_g'], params['n_f'])] 181 |
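To make the mask semantics concrete, here is the same low-to-high-frequency rule for two toy modules without object vector cells (the sizes here are illustrative, not the model's):

```python
import numpy as np
import torch

f_initial, n_p = [0.99, 0.3], [4, 2]    # module 0 is fast, module 1 is slow
bounds = np.cumsum(np.concatenate(([0], n_p)))
mask = torch.zeros(sum(n_p), sum(n_p))
for f_from in range(2):
    for f_to in range(2):
        # M_ij connects cell i to cell j: only allow low -> high frequency (and within-module)
        if f_initial[f_from] <= f_initial[f_to]:
            mask[bounds[f_from]:bounds[f_from+1], bounds[f_to]:bounds[f_to+1]] = 1.0
# Rows of the slow module project to both modules; the fast module only projects to itself
```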
182 | # ---- Static matrices 183 | # Matrix for repeating abstract location g to do outer product with sensory information x with elementwise product. Also see (*) note at bottom 184 | params['W_repeat'] = [torch.tensor(np.kron(np.eye(params['n_g_subsampled'][f]),np.ones((1,params['n_x_f'][f]))), dtype=torch.float) for f in range(params['n_f'])] 185 | # Matrix for tiling sensory observation x to do outer product with abstract location g with elementwise product. Also see (*) note at bottom 186 | params['W_tile'] = [torch.tensor(np.kron(np.ones((1,params['n_g_subsampled'][f])),np.eye(params['n_x_f'][f])), dtype=torch.float) for f in range(params['n_f'])] 187 | # Table for converting one-hot to two-hot compressed representation 188 | params['two_hot_table'] = [[0]*(params['n_x_c']-2) + [1]*2] 189 | # We need a compressed code for each possible observation, but it's impossible to have more compressed codes than "n_x_c choose 2" 190 | for i in range(1, min(int(comb(params['n_x_c'],2)), params['n_x'])): 191 | # Copy previous code 192 | code = params['two_hot_table'][-1].copy() 193 | # Find latest occurrence of [0 1] in that code 194 | swap = [index for index in range(len(code)-1,-1,-1) if code[index:index+2] == [0,1]][0] 195 | # Swap those to get new code 196 | code[swap:swap+2] = [1,0] 197 | # If the first one was swapped: value after swapped pair is 1 198 | if swap+2 < len(code) and code[swap+2] == 1: 199 | # In that case: move the second 1 all the way back - reverse everything after the swapped pair 200 | code[swap+2:] = code[:swap+1:-1] 201 | # And append new code to array 202 | params['two_hot_table'].append(code) 203 | # Convert each code to column vector pytorch tensor 204 | params['two_hot_table'] = [torch.tensor(code) for code in params['two_hot_table']] 205 | # Downsampling matrix to go from grid cells to compressed grid cells for indexing memories by simply taking only the first n_g_subsampled grid cells 206 | params['g_downsample'] = [torch.cat([torch.eye(dim_out, dtype=torch.float),torch.zeros((dim_in-dim_out,dim_out), dtype=torch.float)]) for dim_in, dim_out in zip(params['n_g'],params['n_g_subsampled'])] 207 | return params 208 |
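The loop above enumerates the two-hot codes by repeatedly shifting the right-most [0, 1] pair (with a tail reset), starting from [0, ..., 0, 1, 1]. Tracing it for an illustrative n_x_c = 4 gives the comb(4, 2) = 6 codes in this order: [0,0,1,1], [0,1,0,1], [0,1,1,0], [1,0,0,1], [1,0,1,0], [1,1,0,0]. The same set of codes (though not in the same order) can be generated more directly:

```python
from itertools import combinations

# All two-hot codes for an illustrative compressed size of n_x_c = 4:
# every code has exactly two active units, comb(4, 2) = 6 codes in total
n_x_c = 4
codes = [[1 if i in pair else 0 for i in range(n_x_c)]
         for pair in combinations(range(n_x_c), 2)]
```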
209 | # This specifies how parameters are updated at every backpropagation iteration/gradient update 210 | def parameter_iteration(iteration, params): 211 | # Calculate eta (rate of remembering) and lambda (rate of forgetting) for Hebbian memory updates 212 | eta = min((iteration+1)/params['eta_it'], 1) * params['eta'] 213 | lamb = min((iteration+1)/params['lambda_it'], 1) * params['lambda'] 214 | # Calculate current scaling of variance offset for grounded location inference 215 | p2g_scale_offset = 1/(1+np.exp((iteration - params['p2g_sig_half_it'])/params['p2g_sig_scale_it'])) 216 | # Calculate current learning rate 217 | lr = max(params['lr_min'] + (params['lr_max'] - params['lr_min']) * (params['lr_decay_rate'] ** (iteration / params['lr_decay_steps'])), params['lr_min']) 218 | # Calculate center of walk length window, within which the walk lengths of new walks are uniformly sampled 219 | walk_length_center = params['walk_it_max'] - params['walk_it_window'] * 0.5 - min((iteration+1)/params['train_it'],1) * (params['walk_it_max'] - params['walk_it_min'] - params['walk_it_window']) 220 | # Calculate current loss weights 221 | L_p_g = min((iteration+1)/params['loss_weights_p_g_it'], 1) * params['loss_weights_p'] 222 | L_p_x = min((iteration+1)/params['loss_weights_p_g_it'], 1) * params['loss_weights_p'] * (1 - p2g_scale_offset) 223 | L_x_gen = params['loss_weights_x'] 224 | L_x_g = params['loss_weights_x'] 225 | L_x_p = params['loss_weights_x'] 226 | L_g = min((iteration+1)/params['loss_weights_p_g_it'], 1) * params['loss_weights_g'] 227 | L_reg_g = (1 - min((iteration+1) / params['loss_weights_reg_g_it'], 1)) * params['loss_weights_reg_g'] 228 | L_reg_p = (1 - min((iteration+1) / params['loss_weights_reg_p_it'], 1)) * params['loss_weights_reg_p'] 229 | # And concatenate them in the order expected by the model 230 | loss_weights = torch.tensor([L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p]) 231 | # Return all updated parameters 232 | return eta, lamb, p2g_scale_offset, lr, walk_length_center, loss_weights 233 | 234 | 235 | ''' 236 | (*) Note on W_tile and W_repeat: 237 | W_tile and W_repeat are for calculating outer products then vector flattening by matrix multiplication then elementwise product: 238 | g = np.random.rand(4,1) 239 | x = np.random.rand(3,1) 240 | out1 = np.matmul(g,np.transpose(x)).reshape((4*3,1)) 241 | W_repeat = np.kron(np.eye(4),np.ones((3,1))) 242 | W_tile = np.kron(np.ones((4,1)),np.eye(3)) 243 | out2 = np.matmul(W_repeat,g) * np.matmul(W_tile,x) 244 | Or in the case of row vectors, which is what you'd do for batch calculation: 245 | g = g.T 246 | x = x.T 247 | out3 = np.matmul(np.transpose(g), x).reshape((1,4*3)) # Notice how this is not batch-proof! 248 | W_repeat = np.kron(np.eye(4), np.ones((1,3))) 249 | W_tile = np.kron(np.ones((1,4)),np.eye(3)) 250 | out4 = np.matmul(g, W_repeat) * np.matmul(x,W_tile) # This is batch-proof 251 | '''
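Plugging the defaults from parameters() into these schedules gives a feel for the annealing (a worked check of the formulas above, not additional functionality):

```python
import numpy as np

eta_it, lambda_it, eta_max, lambda_max = 16000, 200, 0.5, 0.9999
lr_max, lr_min, decay_rate, decay_steps = 9.4e-4, 8e-5, 0.5, 4000

for iteration in [0, 2000, 16000]:
    eta = min((iteration + 1) / eta_it, 1) * eta_max        # ramps linearly up to 0.5
    lamb = min((iteration + 1) / lambda_it, 1) * lambda_max # fully 'on' after 200 iterations
    lr = max(lr_min + (lr_max - lr_min) * decay_rate ** (iteration / decay_steps), lr_min)
    print(iteration, eta, lamb, lr)
# iteration 0:     eta ~ 0.00003, lamb ~ 0.005,  lr = 9.4e-4
# iteration 2000:  eta ~ 0.0625,  lamb = 0.9999, lr ~ 6.9e-4
# iteration 16000: eta = 0.5,     lamb = 0.9999, lr ~ 1.3e-4
```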
-------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Mar 30 14:35:30 2020 5 | 6 | @author: jacobb 7 | """ 8 | 9 | # Functions for plotting training and results of TEM 10 | 11 | # Standard library imports 12 | import matplotlib.pyplot as plt 13 | from matplotlib import cm 14 | import numpy as np 15 | 16 | def plot_weights(models, params = None, steps = None, do_save = False): 17 | # If no parameter names specified: just take all of the trained ones from the model 18 | if params is None: 19 | params = [item[0] for item in models[0].named_parameters() if item[1].requires_grad] 20 | # If no steps specified: just make them increase by 1 for each model 21 | if steps is None: 22 | steps = [i for i in range(len(models))] 23 | # Collect this parameter in each model as provided 24 | model_dicts = [{model_params[0] : model_params[1] for model_params in model.named_parameters()} for model in models] 25 | # Plot each parameter separately 26 | for param in params: 27 | # Create figure and subplots 28 | fig, axs = plt.subplots(2, len(steps)) 29 | # Set its size to something that is stretched horizontally so you can read titles 30 | fig.set_size_inches(10, 4) 31 | # Figure overall title is the parameter name 32 | fig.suptitle(param) 33 | values = [model_params[param].detach().numpy() for model_params in model_dicts] 34 | # On the first line of this figure: plot params at each step 35 | for i, step in enumerate(steps): 36 | # Plot variable values in subplot 37 | axs[0, i].imshow(values[i]) 38 | axs[0, i].set_title('Step ' + str(step)) 39 | # On the second line of this figure: plot change in params between steps 40 | for i in range(len(steps)-1): 41 | # Plot the change in variables 42 | axs[1, i].imshow(values[i+1] - values[i]) 43 | axs[1, i].set_title(str(steps[i]) + ' to ' + str(steps[i+1]) + ', ' + '{:.2E}'.format(np.mean(np.abs(values[i+1] - values[i]))/(steps[i+1]-steps[i]))) 44 | # On the very last axis: plot the total difference between the first and the last 45 | axs[1, -1].imshow(values[-1] - values[0]) 46 | axs[1, -1].set_title(str(steps[0]) + ' to ' + str(steps[-1]) + ', ' + '{:.2E}'.format(np.mean(np.abs(values[-1] - values[0])))) 47 | # If you want to save this figure: do so 48 | if do_save: 49 | fig.savefig('./figs/plot_weights_' + param + '.png') 50 | 51 | def plot_memory(iters, steps = None, do_save = False): 52 | # If no steps specified: just make them increase by 1 for each model 53 | if steps is None: 54 | steps = [i for i in range(len(iters))] 55 | # Set names of memory: inference and generative 56 | names = ['Generative','Inference'] 57 | # Plot each parameter separately 58 | for mem in range(len(iters[0].M)): 59 | # Get current memory name 60 | name = names[mem] 61 | # Create figure and subplots 62 | fig, axs = plt.subplots(len(iters[0].M[0]), len(steps)) 63 | # Set its size to something that is stretched horizontally so you can read titles 64 | fig.set_size_inches(len(steps)*2, len(iters[0].M[0])) 65 | # Figure overall title is the parameter name 66 | fig.suptitle(name + ' memory') 67 | # Load the memory matrices for each step - one per batch element 68 | batches = [iteration.M[mem] for iteration in iters] 69 | # On the first line of this figure: plot params at each step 70 | for col, step in enumerate(steps): 71 | for row, batch in enumerate(batches[col]): 72 | if len(steps) == 1: 73 | # Plot variable values in subplot 74 | axs[row].imshow(batch.numpy()) 75 | axs[row].set_title('Step ' + str(step) + ', batch ' + str(row)) 76 | else: 77 | # Plot variable values in subplot 78 | axs[row, col].imshow(batch.numpy()) 79 | axs[row, col].set_title('Step ' + str(step) + ', batch ' + str(row)) 80 | # If you want to save this figure: do so 81 | if do_save: 82 | fig.savefig('./figs/plot_mem_' + name + '.png') 83 | 84 | def plot_map(environment, values, ax=None, min_val=None, max_val=None, num_cols=100, location_cm='viridis', action_cm='Pastel1', do_plot_actions=False, shape='circle', radius=None): 85 | # If min_val and max_val are not specified: take the minimum and maximum of the supplied values 86 | min_val = np.min(values) if min_val is None else min_val 87 | max_val = np.max(values) if max_val is None else max_val 88 | # Create color map for locations: colour given by value input 89 | location_cm = cm.get_cmap(location_cm, num_cols) 90 | # Create color map for actions: colour given by action index 91 | action_cm = cm.get_cmap(action_cm, environment.n_actions) 92 | # Calculate colour corresponding to each value 93 | plotvals = np.floor((values - min_val) / (max_val - min_val) * num_cols) if max_val != min_val else np.ones(values.shape) 94 | # Calculate radius of location circles based on how many nodes there are 95 | radius = 2*(0.01 + 1/(10*np.sqrt(environment.n_locations))) if radius is None else radius 96 | # Initialise empty axis 97 | ax = initialise_axes(ax) 98 | # Create empty list of location patches and action patches 99 | location_patches, action_patches = [], [] 100 | # Now start drawing locations and actions 101 | for i, location in enumerate(environment.locations): 102 | # Create patch for location 103 | location_patches.append(plt.Rectangle((location['x']-radius/2, location['y']-radius/2), radius, radius, color=location_cm(int(plotvals[i]))) if shape == 'square' 104 | else plt.Circle((location['x'], location['y']), radius, color=location_cm(int(plotvals[i])))) 105 | # And create action patches, if action plotting is switched on 106 | if do_plot_actions: 107 | for a, action in enumerate(location['actions']): 108 | # Only draw patch if action probability is larger than 0 109 | if action['probability'] > 0: 110 | # Find where this action takes you 111 | locations_to = [environment.locations[loc_to] for loc_to in np.where(np.array(action['transition'])>0)[0]] 112 | # Create an
action patch for each possible transition for this action 113 | for loc_to in locations_to: 114 | action_patches.append(action_patch(location, loc_to, radius, action_cm(action['id']))) 115 | # After drawing all locations, add shiny patches 116 | for location in environment.locations: 117 | # For shiny locations, add big red patch to indicate shiny 118 | if location['shiny']: 119 | # Create square patch for location 120 | location_patches.append(plt.Rectangle((location['x']-radius/2, location['y']-radius/2), radius, radius, linewidth=1, facecolor='none', edgecolor=[1,0,0]) if shape == 'square' 121 | else plt.Circle((location['x'], location['y']), radius, linewidth=1, facecolor='none', edgecolor=[1,0,0])) 122 | # Add patches to axes 123 | for patch in location_patches + action_patches: 124 | ax.add_patch(patch) 125 | # Return axes for further use 126 | return ax 127 | 128 | def plot_actions(environment, field='probability', ax=None, min_val=None, max_val=None, num_cols=100, action_cm='viridis'): 129 | # If min_val and max_val are not specified: take the minimum and maximum of the supplied values 130 | min_val = min([action[field] for location in environment.locations for action in location['actions']]) if min_val is None else min_val 131 | max_val = max([action[field] for location in environment.locations for action in location['actions']]) if max_val is None else max_val 132 | # Create color map for locations: colour given by value input 133 | action_cm = cm.get_cmap(action_cm, num_cols) 134 | # Calculate radius of location circles based on how many nodes there are 135 | radius = 2*(0.01 + 1/(10*np.sqrt(environment.n_locations))) 136 | # Initialise empty axis 137 | ax = initialise_axes(ax) 138 | # Create empty list of location patches and action patches 139 | location_patches, action_patches = [], [] 140 | # Now start drawing locations and actions 141 | for i, location in enumerate(environment.locations): 142 | # Create circle patch for location 143 | location_patches.append(plt.Circle((location['x'], location['y']), radius, color=[0, 0, 0])) 144 | # And create action patches 145 | for a, action in enumerate(location['actions']): 146 | # Only draw patch if action probability is larger than 0 147 | if action['probability'] > 0: 148 | # Calculate colour for this action from colour map 149 | action_colour = action_cm(int(np.floor((action[field] - min_val) / (max_val - min_val) * num_cols))) 150 | # Find where this action takes you 151 | locations_to = [environment.locations[loc_to] for loc_to in np.where(np.array(action['transition'])>0)[0]] 152 | # Create an action patch for each possible transition for this action 153 | for loc_to in locations_to: 154 | action_patches.append(action_patch(location, loc_to, radius, action_colour)) 155 | # Add patches to axes 156 | for patch in (location_patches + action_patches): 157 | ax.add_patch(patch) 158 | # Return axes for further use 159 | return ax 160 | 161 | def plot_walk(environment, walk, max_steps=None, n_steps=1, ax=None): 162 | # Set maximum number of steps if not provided 163 | max_steps = len(walk) if max_steps is None else min(max_steps, len(walk)) 164 | # Initialise empty axis if axis wasn't provided 165 | if ax is None: 166 | ax = initialise_axes(ax) 167 | # Find all circle patches on current axis 168 | location_patches = [patch_i for patch_i, patch in enumerate(ax.patches) if type(patch) is plt.Circle or type(patch) is plt.Rectangle] 169 | # Get radius of location circles on this map 170 | radius = 
(ax.patches[location_patches[-1]].get_radius() if type(ax.patches[location_patches[-1]]) is plt.Circle 171 | else ax.patches[location_patches[-1]].get_width()) if len(location_patches) > 0 else 0.02 172 | # Initialise previous position: the position of the walk's first location 173 | prev_loc = np.array([environment.locations[walk[0][0]['id']]['x'], environment.locations[walk[0][0]['id']]['y']]) 174 | # Run through walk, creating lines 175 | for step_i in range(1, max_steps, n_steps): 176 | # Get position of current location, with some jitter so lines don't overlap 177 | new_loc = np.array([environment.locations[walk[step_i][0]['id']]['x'], environment.locations[walk[step_i][0]['id']]['y']]) 178 | # Add jitter (need to unpack shape for rand - annoyingly np.random.rand takes dimensions separately) 179 | new_loc = new_loc + 0.8*(-radius + 2*radius*np.random.rand(*new_loc.shape)) 180 | # Plot line from previous location to current location 181 | plt.plot([prev_loc[0], new_loc[0]], [prev_loc[1], new_loc[1]], color=[step_i/max_steps for _ in range(3)]) 182 | # Update previous location to current location 183 | prev_loc = new_loc 184 | # Return axes that this was plotted on 185 | return ax 186 | 187 | def plot_cells(p, g, environment, n_f_ovc=0, columns=10): 188 | # Run through all hippocampal and entorhinal rate maps, big nested arrays arranged as [frequency][location][cell] 189 | for cells, names in zip([p, g],['Hippocampal','Entorhinal']): 190 | # Calculate the number of rows that each frequency module requires 191 | n_rows_f = np.cumsum([0] + [np.ceil(len(c[0]) * 1.0 / columns) for c in cells]).astype(int) 192 | # Create subplots for cells across frequencies 193 | fig, ax = plt.subplots(nrows=n_rows_f[-1], ncols=columns) 194 | # Switch all axes off 195 | for row in ax: 196 | for col in row: 197 | col.axis('off') 198 | # And run through all frequencies to plot cells for that frequency 199 | for f, loc_rates in enumerate(cells): 200 | # Set title for current axis 201 | ax[n_rows_f[f], int(columns/2)].set_title(names + ('' if f < len(cells) - n_f_ovc else ' object vector ') + ' cells, frequency ' 202 | + str(f if f < len(cells) - n_f_ovc else f - (len(cells) - n_f_ovc))) 203 | # Plot map for each cell 204 | for c in range(len(loc_rates[0])): 205 | # Get current row and column 206 | row = int(n_rows_f[f] + np.floor(c / columns)) 207 | col = int(c % columns) 208 | # Plot rate map for this cell by collecting the firing rate at each location 209 | plot_map(environment, np.array([loc_rates[l][c] for l in range(len(loc_rates))]), ax[row, col], shape='square', radius=1/np.sqrt(len(loc_rates))) 210 | 211 | def initialise_axes(ax=None): 212 | # If no axes specified: create new figure with new empty axes 213 | if ax is None: 214 | plt.figure() 215 | ax = plt.axes() 216 | # Set axes limits to 0, 1 as this is how the positions in the environment are set up 217 | ax.set_xlim([0, 1]) 218 | ax.set_ylim([0, 1]) 219 | # Force axes to be square to keep proper aspect ratio 220 | ax.set_aspect(1) 221 | # Invert y-axis so y position increases downwards (as it usually does in graphics/pixels) 222 | ax.invert_yaxis() 223 | # And don't show any axes 224 | ax.axis('off') 225 | # Return axes object 226 | return ax 227 | 228 | def action_patch(location_from, location_to, radius, colour): 229 | # Set patch coordinates 230 | if location_to['id'] == location_from['id']: 231 | # If this is a transition to self: the arrow points down (the y-axis is inverted, so an angle of pi/2 points down) 232 | a_dir = np.pi/2 233 | # Set the patch coordinates to point from this location to the transition location (but shifted upward for self transition)
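# Worked example of the arrowhead geometry below (numbers for illustration
# only): with a_dir = pi/2 the two base vertices of the triangle sit at angles
# pi/2 - pi/6 and pi/2 + pi/6 at distance 2*radius from the location centre,
# and the tip sits at angle pi/2 at distance 3*radius, giving a small triangle
# that points along a_dir.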
234 | xdat = location_from['x'] + radius * np.array([2*np.cos((a_dir-np.pi/6)), 2*np.cos((a_dir+np.pi/6)), 3*np.cos((a_dir))]) 235 | ydat = location_from['y'] - radius * 3 + radius * np.array([2*np.sin((a_dir-np.pi/6)), 2*np.sin((a_dir+np.pi/6)), 3*np.sin((a_dir))]) 236 | else: 237 | # This is not a transition to self. Find the direction from the current location to the transitioned location (the arctan2 expression below reduces to arctan2(-yvec, xvec)) 238 | xvec = location_to['x']-location_from['x'] 239 | yvec = location_from['y']-location_to['y'] 240 | a_dir = np.arctan2(xvec*0-yvec*1,xvec*1+yvec*0) 241 | # Set the patch coordinates to point from this location to transition location 242 | xdat = location_from['x'] + radius * np.array([2*np.cos((a_dir-np.pi/6)), 2*np.cos((a_dir+np.pi/6)), 3*np.cos((a_dir))]) 243 | ydat = location_from['y'] + radius * np.array([2*np.sin((a_dir-np.pi/6)), 2*np.sin((a_dir+np.pi/6)), 3*np.sin((a_dir))]) 244 | # Return action patch for provided data 245 | return plt.Polygon(np.stack([xdat, ydat], axis=1), color=colour) 246 | 247 | 248 | ## Just for convenience: all parameters in TEM 249 | #for name, param in model.named_parameters(): 250 | # if param.requires_grad: 251 | # print(name, param.data) 252 | ''' 253 | w_x 254 | b_x 255 | w_p.0 256 | w_p.1 257 | w_p.2 258 | w_p.3 259 | w_p.4 260 | MLP_D_a.w.0.0.weight 261 | MLP_D_a.w.0.0.bias 262 | MLP_D_a.w.0.1.weight 263 | MLP_D_a.w.0.1.bias 264 | MLP_D_a.w.1.0.weight 265 | MLP_D_a.w.1.0.bias 266 | MLP_D_a.w.1.1.weight 267 | MLP_D_a.w.1.1.bias 268 | MLP_D_a.w.2.0.weight 269 | MLP_D_a.w.2.0.bias 270 | MLP_D_a.w.2.1.weight 271 | MLP_D_a.w.2.1.bias 272 | MLP_D_a.w.3.0.weight 273 | MLP_D_a.w.3.0.bias 274 | MLP_D_a.w.3.1.weight 275 | MLP_D_a.w.3.1.bias 276 | MLP_D_a.w.4.0.weight 277 | MLP_D_a.w.4.0.bias 278 | MLP_D_a.w.4.1.weight 279 | MLP_D_a.w.4.1.bias 280 | MLP_sigma_g_path.w.0.0.weight 281 | MLP_sigma_g_path.w.0.0.bias 282 | MLP_sigma_g_path.w.0.1.weight 283 | MLP_sigma_g_path.w.0.1.bias 284 | MLP_sigma_g_path.w.1.0.weight 285 | MLP_sigma_g_path.w.1.0.bias 286 | MLP_sigma_g_path.w.1.1.weight 287 | MLP_sigma_g_path.w.1.1.bias 288 | MLP_sigma_g_path.w.2.0.weight 289 | MLP_sigma_g_path.w.2.0.bias 290 | MLP_sigma_g_path.w.2.1.weight 291 | MLP_sigma_g_path.w.2.1.bias 292 | MLP_sigma_g_path.w.3.0.weight 293 | MLP_sigma_g_path.w.3.0.bias 294 | MLP_sigma_g_path.w.3.1.weight 295 | MLP_sigma_g_path.w.3.1.bias 296 | MLP_sigma_g_path.w.4.0.weight 297 | MLP_sigma_g_path.w.4.0.bias 298 | MLP_sigma_g_path.w.4.1.weight 299 | MLP_sigma_g_path.w.4.1.bias 300 | MLP_sigma_p.w.0.0.weight 301 | MLP_sigma_p.w.0.0.bias 302 | MLP_sigma_p.w.0.1.weight 303 | MLP_sigma_p.w.0.1.bias 304 | MLP_sigma_p.w.1.0.weight 305 | MLP_sigma_p.w.1.0.bias 306 | MLP_sigma_p.w.1.1.weight 307 | MLP_sigma_p.w.1.1.bias 308 | MLP_sigma_p.w.2.0.weight 309 | MLP_sigma_p.w.2.0.bias 310 | MLP_sigma_p.w.2.1.weight 311 | MLP_sigma_p.w.2.1.bias 312 | MLP_sigma_p.w.3.0.weight 313 | MLP_sigma_p.w.3.0.bias 314 | MLP_sigma_p.w.3.1.weight 315 | MLP_sigma_p.w.3.1.bias 316 | MLP_sigma_p.w.4.0.weight 317 | MLP_sigma_p.w.4.0.bias 318 | MLP_sigma_p.w.4.1.weight 319 | MLP_sigma_p.w.4.1.bias 320 | MLP_mu_g_mem.w.0.0.weight 321 | MLP_mu_g_mem.w.0.0.bias 322 | MLP_mu_g_mem.w.0.1.weight 323 | MLP_mu_g_mem.w.0.1.bias 324 | MLP_mu_g_mem.w.1.0.weight 325 | MLP_mu_g_mem.w.1.0.bias 326 | MLP_mu_g_mem.w.1.1.weight 327 | MLP_mu_g_mem.w.1.1.bias 328 | MLP_mu_g_mem.w.2.0.weight 329 | MLP_mu_g_mem.w.2.0.bias 330 | MLP_mu_g_mem.w.2.1.weight 331 | MLP_mu_g_mem.w.2.1.bias 332 |
MLP_mu_g_mem.w.3.0.weight 333 | MLP_mu_g_mem.w.3.0.bias 334 | MLP_mu_g_mem.w.3.1.weight 335 | MLP_mu_g_mem.w.3.1.bias 336 | MLP_mu_g_mem.w.4.0.weight 337 | MLP_mu_g_mem.w.4.0.bias 338 | MLP_mu_g_mem.w.4.1.weight 339 | MLP_mu_g_mem.w.4.1.bias 340 | MLP_sigma_g_mem.w.0.0.weight 341 | MLP_sigma_g_mem.w.0.0.bias 342 | MLP_sigma_g_mem.w.0.1.weight 343 | MLP_sigma_g_mem.w.0.1.bias 344 | MLP_sigma_g_mem.w.1.0.weight 345 | MLP_sigma_g_mem.w.1.0.bias 346 | MLP_sigma_g_mem.w.1.1.weight 347 | MLP_sigma_g_mem.w.1.1.bias 348 | MLP_sigma_g_mem.w.2.0.weight 349 | MLP_sigma_g_mem.w.2.0.bias 350 | MLP_sigma_g_mem.w.2.1.weight 351 | MLP_sigma_g_mem.w.2.1.bias 352 | MLP_sigma_g_mem.w.3.0.weight 353 | MLP_sigma_g_mem.w.3.0.bias 354 | MLP_sigma_g_mem.w.3.1.weight 355 | MLP_sigma_g_mem.w.3.1.bias 356 | MLP_sigma_g_mem.w.4.0.weight 357 | MLP_sigma_g_mem.w.4.0.bias 358 | MLP_sigma_g_mem.w.4.1.weight 359 | MLP_sigma_g_mem.w.4.1.bias 360 | MLP_c_star.w.0.0.weight 361 | MLP_c_star.w.0.0.bias 362 | MLP_c_star.w.0.1.weight 363 | MLP_c_star.w.0.1.bias 364 | ''' -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Feb 20 14:57:45 2020 5 | 6 | @author: jacobb 7 | """ 8 | 9 | # Standard library imports 10 | import numpy as np 11 | import torch 12 | from torch.utils.tensorboard import SummaryWriter 13 | import time 14 | import glob, os, shutil 15 | import importlib.util 16 | # Own module imports 17 | import world 18 | import utils 19 | import parameters 20 | import model as model 21 | 22 | # Set random seeds for reproducibility 23 | np.random.seed(0) 24 | torch.manual_seed(0) 25 | 26 | # Either load a trained model and continue training, or start afresh 27 | load_existing_model = False 28 | if load_existing_model: 29 | # Choose which trained model to load 30 | date = '2020-10-06' # 2020-07-05 run 0 for successful node agent 31 | run = '2' 32 | i_start = 40 33 | 34 | # Set all paths from existing run 35 | run_path, train_path, model_path, save_path, script_path, envs_path = utils.set_directories(date, run) 36 | 37 | # Load the model: use import library to import module from specified path 38 | model_spec = importlib.util.spec_from_file_location("model", script_path + '/model.py') 39 | model = importlib.util.module_from_spec(model_spec) 40 | model_spec.loader.exec_module(model) 41 | 42 | # Load the parameters of the model 43 | params = torch.load(model_path + '/params_' + str(i_start) + '.pt') 44 | # But certain parameters (like the total number of training iterations) may need to be copied from the current set of parameters 45 | new_params = {'train_it':40000} 46 | # Update those in params 47 | for key in new_params: 48 | params[key] = new_params[key] 49 | 50 | # Create a new tem model with the loaded parameters 51 | tem = model.Model(params) 52 | # Load the model weights after training 53 | model_weights = torch.load(model_path + '/tem_' + str(i_start) + '.pt') 54 | # Set the model weights to the loaded trained model weights 55 | tem.load_state_dict(model_weights) 56 | 57 | # Make list of all the environments that this model was trained on 58 | envs = list(glob.iglob(envs_path + '/*')) 59 | 60 | # And increase starting iteration by 1, since the loaded model already carried out the current starting iteration 61 | i_start = i_start + 1 62 | else: 63 | # Start training from step 0 64 | i_start = 0 65 | 66 | # Create directories for storing all information about the current run
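# Sketch of the folder layout that utils.make_directories() creates below
# (reconstructed from utils.py further down in this document):
#   ../Summaries/<YYYY-MM-DD>/run<N>/
#       train/        - tensorboard summaries
#       model/        - periodic checkpoints: tem_<i>.pt and params_<i>.pt
#       save/         - numpy copy of the training parameters
#       script/       - snapshot of all .py files of this run
#       script/envs/  - copies of the environment files used in training
# so every run can be reproduced and analysed from its own folder.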
67 | run_path, train_path, model_path, save_path, script_path, envs_path = utils.make_directories() 68 | # Save all python files in current directory to script directory 69 | files = glob.iglob(os.path.join('.', '*.py')) 70 | for file in files: 71 | if os.path.isfile(file): 72 | shutil.copy2(file, os.path.join(script_path, file)) 73 | 74 | # Initialise hyperparameters for model 75 | params = parameters.parameters() 76 | # Save parameters 77 | np.save(os.path.join(save_path, 'params'), params) 78 | 79 | # And create instance of TEM with those parameters 80 | tem = model.Model(params) 81 | 82 | # Create list of environments that we will sample from during training to provide TEM with trajectory input 83 | envs = ['./envs/5x5.json'] 84 | # Save all environment files that are being used in training in the script directory 85 | for file in set(envs): 86 | shutil.copy2(file, os.path.join(envs_path, os.path.basename(file))) 87 | 88 | # Create a tensor board to stay updated on training progress. Start tensorboard with tensorboard --logdir=runs 89 | writer = SummaryWriter(train_path) 90 | # Create a logger to write log output to file 91 | logger = utils.make_logger(run_path) 92 | 93 | # Make an ADAM optimizer for TEM 94 | adam = torch.optim.Adam(tem.parameters(), lr = params['lr_max']) 95 | 96 | # Make set of environments: one for each batch, randomly choosing to use shiny objects or not 97 | environments = [world.World(graph, randomise_observations=True, shiny=(params['shiny'] if np.random.rand() < params['shiny_rate'] else None)) for graph in np.random.choice(envs,params['batch_size'])] 98 | # Initialise whether a state has been visited for each world 99 | visited = [[False for _ in range(env.n_locations)] for env in environments] 100 | # And make a single walk for each environment, where walk lengths can be anything between the min and max length to de-synchronise world switches 101 | walks = [env.generate_walks(params['n_rollout']*np.random.randint(params['walk_it_min'], params['walk_it_max']), 1)[0] for env in environments] 102 | # Initialise the previous iteration as None: we start from the beginning of the walk, so there is no previous iteration yet 103 | prev_iter = None 104 | 105 | # Train TEM on walks in different environments 106 | for i in range(i_start, params['train_it']): 107 | 108 | # Get start time for function timing 109 | start_time = time.time() 110 | # Get updated parameters for this backprop iteration 111 | eta_new, lambda_new, p2g_scale_offset, lr, walk_length_center, loss_weights = parameters.parameter_iteration(i, params) 112 | # Update eta and lambda 113 | tem.hyper['eta'] = eta_new 114 | tem.hyper['lambda'] = lambda_new 115 | # Update scaling of offset for variance of inferred grounded position 116 | tem.hyper['p2g_scale_offset'] = p2g_scale_offset 117 | # Update learning rate (the neater torch-way of doing this would be a scheduler, but this is quick and easy) 118 | for param_group in adam.param_groups: 119 | param_group['lr'] = lr 120 | 121 | # Make an empty chunk that will be fed to TEM in this backprop iteration 122 | chunk = [] 123 | # For each environment: fill chunk by popping the first n_rollout steps of the walk 124 | for env_i, walk in enumerate(walks): 125 | # Make sure this walk has enough steps in it for a whole backprop iteration 126 | if len(walk) < params['n_rollout']: 127 | # If it doesn't: create a new environment 128 | environments[env_i] = world.World(envs[np.random.randint(len(envs))],
randomise_observations=True, shiny=(params['shiny'] if np.random.rand() < params['shiny_rate'] else None)) 129 | # Initialise whether a state has been visited for each world 130 | visited[env_i] = [False for _ in range(environments[env_i].n_locations)] 131 | # Generate a new walk on that environment 132 | walk = environments[env_i].generate_walks(params['n_rollout']*np.random.randint(walk_length_center - params['walk_it_window'] * 0.5, walk_length_center + params['walk_it_window'] * 0.5), 1)[0] 133 | # And store it in walks array 134 | walks[env_i] = walk 135 | # Finally, set the action of the previous iteration for this environment to None, to indicate that this is a new walk 136 | prev_iter[0].a[env_i] = None 137 | # Log progress 138 | logger.info('Iteration {:d}: new walk of length {:d} for batch entry {:d}'.format(i, len(walk), env_i)) 139 | # Now pop the first n_rollout steps from this walk and append them to the chunk 140 | for step in range(params['n_rollout']): 141 | # For the first environment: simply copy the components (g, x, a) of each step 142 | if len(chunk) < params['n_rollout']: 143 | chunk.append([[comp] for comp in walk.pop(0)]) 144 | # For all next environments: add the components to the existing list of components for each step 145 | else: 146 | for comp_i, comp in enumerate(walk.pop(0)): 147 | chunk[step][comp_i].append(comp) 148 | # Stack all observations (x, component 1) into tensors along the first dimension for batch processing 149 | for i_step, step in enumerate(chunk): 150 | chunk[i_step][1] = torch.stack(step[1], dim=0) 151 | 152 | # Forward-pass this walk through the network 153 | forward = tem(chunk, prev_iter) 154 | 155 | # Accumulate loss from forward pass 156 | loss = torch.tensor(0.0) 157 | # Make vector for plotting losses 158 | plot_loss = 0 159 | # Collect all losses 160 | for step in forward: 161 | # Make list of losses included in this step 162 | step_loss = [] 163 | # Only include loss for locations that have been visited before 164 | for env_i, env_visited in enumerate(visited): 165 | if env_visited[step.g[env_i]['id']]: 166 | step_loss.append(loss_weights*torch.stack([l[env_i] for l in step.L])) 167 | else: 168 | env_visited[step.g[env_i]['id']] = True 169 | # Stack losses in this step along first dimension, then average across that dimension to get mean loss for this step 170 | step_loss = torch.tensor(0) if not step_loss else torch.mean(torch.stack(step_loss, dim=0), dim=0) 171 | # Save all separate components of loss for monitoring 172 | plot_loss = plot_loss + step_loss.detach().numpy() 173 | # And sum all components, then add them to total loss of this step 174 | loss = loss + torch.sum(step_loss) 175 | 176 | # Reset gradients 177 | adam.zero_grad() 178 | # Do backward pass to calculate gradients with respect to total loss of this chunk 179 | loss.backward(retain_graph=True) 180 | # Then do optimiser step to update parameters of model 181 | adam.step() 182 | # Update the previous iteration for the next chunk with the final step of this chunk, removing all operation history 183 | prev_iter = [forward[-1].detach()] 184 | 185 | # Compute model accuracies 186 | acc_p, acc_g, acc_gt = np.mean([[np.mean(a) for a in step.correct()] for step in forward], axis=0) 187 | acc_p, acc_g, acc_gt = [a * 100 for a in (acc_p, acc_g, acc_gt)] 188 | # Log progress 189 | if i % 10 == 0: 190 | # Write series of messages to logger from this backprop iteration 191 | logger.info('Finished backprop iter {:d} in {:.2f} seconds.'.format(i,time.time()-start_time))
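# Why the first visit to a location is excluded from the loss in the loop
# above: before any memory of a location exists its observation cannot be
# predicted, so on the first visit the location is only marked as visited,
# and from the second visit onwards its steps contribute to the averaged loss.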
192 | logger.info('Loss: {:.2f}. {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f}'.format(loss.detach().numpy(), *plot_loss)) 193 | logger.info('Accuracy:
{:.2f}% {:.2f}% {:.2f}%'.format(acc_p, acc_g, acc_gt)) 194 | logger.info('Parameters: {:.2f} {:.2f} {:.2f} {:.2f}'.format(np.max(np.abs(prev_iter[0].M[0].numpy())), tem.hyper['eta'], tem.hyper['lambda'], tem.hyper['p2g_scale_offset'])) 195 | logger.info('Weights:' + str([w for w in loss_weights.numpy()])) 196 | logger.info(' ') 197 | # Also write progress to tensorboard, and all loss components. Order: [L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p] 198 | writer.add_scalar('Losses/Total', loss.detach().numpy(), i) 199 | writer.add_scalar('Losses/p_g', plot_loss[0], i) 200 | writer.add_scalar('Losses/p_x', plot_loss[1], i) 201 | writer.add_scalar('Losses/x_gen', plot_loss[2], i) 202 | writer.add_scalar('Losses/x_g', plot_loss[3], i) 203 | writer.add_scalar('Losses/x_p', plot_loss[4], i) 204 | writer.add_scalar('Losses/g', plot_loss[5], i) 205 | writer.add_scalar('Losses/reg_g', plot_loss[6], i) 206 | writer.add_scalar('Losses/reg_p', plot_loss[7], i) 207 | writer.add_scalar('Accuracies/p', acc_p, i) 208 | writer.add_scalar('Accuracies/g', acc_g, i) 209 | writer.add_scalar('Accuracies/gt', acc_gt, i) 210 | # Also store the internal state (all learnable parameters) and the hyperparameters periodically 211 | if i % 1000 == 0: 212 | torch.save(tem.state_dict(), model_path + '/tem_' + str(i) + '.pt') 213 | torch.save(tem.hyper, model_path + '/params_' + str(i) + '.pt') 214 | 215 | # Save the final state of the model after training has finished 216 | torch.save(tem.state_dict(), model_path + '/tem_' + str(i) + '.pt') 217 | torch.save(tem.hyper, model_path + '/params_' + str(i) + '.pt') -------------------------------------------------------------------------------- /run_lstm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Feb 20 14:57:45 2020 5 | 6 | @author: jacobb 7 | """ 8 | 9 | # Standard library imports 10 | import numpy as np 11 | import torch 12 | from torch.utils.tensorboard import SummaryWriter 13 | import time 14 | # Own module imports 15 | import world 16 | import parameters 17 | import model 18 | import plot 19 | 20 | # Set random seeds for reproducibility 21 | np.random.seed(0) 22 | torch.manual_seed(0) 23 | 24 | # Create world: 5x5 grid with actions [North, East, South, West] with random policy, with observations randomly assigned to locations 25 | grid = world.World('./envs/5x5.json', randomise_observations=True) 26 | 27 | # Initialise hyperparameters for model 28 | params = parameters.parameters(grid) 29 | 30 | # Create lstm, to see if that learns well 31 | lstm = model.LSTM(params['n_x'] + params['n_actions'], 100, params['n_x'], n_a = params['n_actions']) 32 | 33 | # Create set of training worlds, as many as there are batches 34 | environments = [world.World('./envs/5x5.json', randomise_observations=True) for batch in range(params['n_batches'])] 35 | 36 | # Create walks on each world 37 | walks = [env.generate_walks(params['walk_length'], params['n_walks']) for env in environments] 38 | 39 | # Create batched walks: instead of keeping walks separated by environment, collect the steps of each walk across environments 40 | batches = [[[[],[],[]] for l in range(params['walk_length'])] for w in range(params['n_walks'])] 41 | for env in walks: 42 | for i_walk, walk in enumerate(env): 43 | for i_step, step in enumerate(walk): 44 | for i_comp, component in enumerate(step): 45 | # Append state, observation, action across environments 46 | batches[i_walk][i_step][i_comp].append(component) 47 | # Stack all observations into tensors along the first dimension for batch processing
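# Shape note (sizes for illustration only): after the loop above, each step
# holds one list per component with one entry per environment; the
# torch.stack(step[1], dim=0) below turns the list of n_batches observation
# vectors of length n_x into a single [n_batches, n_x] tensor, so all
# environments are processed in parallel.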
48 | for i_walk, walk in enumerate(batches): 49 | for i_step, step in enumerate(walk): 50 | batches[i_walk][i_step][1] = torch.stack(step[1], dim=0) 51 | 52 | # Create a tensor board to stay updated on training progress. Start tensorboard with tensorboard --logdir=runs 53 | writer = SummaryWriter() 54 | 55 | # Make an ADAM optimizer for the LSTM 56 | adam = torch.optim.Adam(lstm.parameters(), lr = 0.1) 57 | 58 | # Create learning rate scheduler that reduces learning rate over training 59 | lr_factor = lambda epoch: 0.75 60 | scheduler = torch.optim.lr_scheduler.MultiplicativeLR(adam,lr_factor) 61 | 62 | # Train LSTM 63 | for i, walk in enumerate(batches): 64 | # Don't feed walk all at once; instead, feed limited number of forward rollouts, then backprop through time 65 | chunks = [[i, min(i + params['n_rollout'],len(walk))] for i in range(0, len(walk), params['n_rollout'])] 66 | # Initialise the previous hidden state as None: at the beginning of a walk, there is no hidden state yet 67 | prev_hidden = None 68 | # Run through all chunks that we are going to backprop for 69 | for j, [start, stop] in enumerate(chunks): 70 | # Get start time for function timing 71 | start_time = time.time() 72 | # Prepare data for feeding into lstm 73 | data = lstm.prepare_data(walk[start:stop]) 74 | # Forward-pass this data through the network 75 | predictions, prev_hidden = lstm(data, prev_hidden) 76 | # Calculate loss from forward pass: difference between predicted and real observation at each step 77 | loss = torch.nn.BCELoss()(predictions[:,:-1,:], data[:,1:,:params['n_x']]) 78 | # Reset gradients 79 | adam.zero_grad() 80 | # Do backward pass to calculate gradients with respect to total loss of this chunk 81 | loss.backward(retain_graph=True) 82 | # Then do optimiser step to update parameters of model 83 | adam.step() 84 | # And detach previous hidden state to prevent gradients going back forever 85 | prev_hidden = tuple([hidden.detach() for hidden in prev_hidden]) 86 | # Calculate accuracy: how often was the best guess from the predictions correct? 87 | accuracy = torch.mean((torch.argmax(data[:,1:,:params['n_x']], dim=-1) == torch.argmax(predictions[:,:-1,:], dim=-1)).type(torch.float)).numpy() 88 | # Show progress 89 | if j % 10 == 0: 90 | print('Finished walk {:d}, chunk {:d} in {:.2f} seconds.\n'.format(i,j,time.time()-start_time) + 91 | 'Loss: {:.2f}, accuracy: {:.2f} %'.format(loss.detach().numpy(), accuracy * 100.0)) 92 | # Also write progress to tensorboard 93 | writer.add_scalar('Walk ' + str(i + 1) + '/Loss', loss.detach().numpy(), j) 94 | writer.add_scalar('Walk ' + str(i + 1) + '/Accuracy', accuracy * 100, j) 95 | # Also step the learning rate down after each walk 96 | scheduler.step() 97 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Jul 2 09:35:57 2020 5 | 6 | @author: jacobb 7 | """ 8 | 9 | # Standard library imports 10 | import numpy as np 11 | import torch 12 | import glob 13 | import matplotlib.pyplot as plt 14 | import importlib.util 15 | # Own module imports. Note how the model module is not imported, since we'll use the model from the training run
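# Rationale: analysis loads the model.py snapshot that was copied into the
# run's script/ folder at training time, so the architecture always matches
# the saved weights even if the code in this repository has changed since.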
16 | import world 17 | import analyse 18 | import plot 19 | 20 | # Set random seeds for reproducibility 21 | np.random.seed(0) 22 | torch.manual_seed(0) 23 | 24 | # Choose which trained model to load 25 | date = '2020-10-19' # 2020-10-13 run 0 for successful node agent 26 | run = '0' 27 | index = '32000' 28 | 29 | # Load the model: use import library to import module from specified path 30 | model_spec = importlib.util.spec_from_file_location("model", '../Summaries/' + date + '/run' + run + '/script/model.py') 31 | model = importlib.util.module_from_spec(model_spec) 32 | model_spec.loader.exec_module(model) 33 | 34 | # Load the parameters of the model 35 | params = torch.load('../Summaries/' + date + '/run' + run + '/model/params_' + index + '.pt') 36 | # Create a new tem model with the loaded parameters 37 | tem = model.Model(params) 38 | # Load the model weights after training 39 | model_weights = torch.load('../Summaries/' + date + '/run' + run + '/model/tem_' + index + '.pt') 40 | # Set the model weights to the loaded trained model weights 41 | tem.load_state_dict(model_weights) 42 | # Make sure model is in evaluation mode (not crucial because it doesn't currently use dropout or batchnorm layers) 43 | tem.eval() 44 | 45 | # Make list of all the environments that this model was trained on 46 | envs = list(glob.iglob('../Summaries/' + date + '/run' + run + '/script/envs/*')) 47 | # Set which environments will include shiny objects 48 | shiny_envs = [False, False, True, True] 49 | # Set the number of walks to execute in parallel (batch size) 50 | n_walks = len(shiny_envs) 51 | # Select environments from the environments included in training 52 | environments = [world.World(graph, randomise_observations=True, shiny=(params['shiny'] if shiny_envs[env_i] else None)) 53 | for env_i, graph in enumerate(np.random.choice(envs, n_walks))] 54 | # Determine the length of each walk 55 | walk_len = np.median([env.n_locations * 50 for env in environments]).astype(int) 56 | # And generate walks for each environment 57 | walks = [env.generate_walks(walk_len, 1)[0] for env in environments] 58 | 59 | # Generate model input from specified walk and environment: group steps from all environments together to feed to model in parallel 60 | model_input = [[[[walks[i][j][k]][0] for i in range(len(walks))] for k in range(3)] for j in range(walk_len)] 61 | for i_step, step in enumerate(model_input): 62 | model_input[i_step][1] = torch.stack(step[1], dim=0) 63 | 64 | # Run a forward pass through the model using this data, without accumulating gradients 65 | with torch.no_grad(): 66 | forward = tem(model_input, prev_iter=None) 67 | 68 | # Decide whether to include stay-still actions as valid occasions for inference 69 | include_stay_still = True 70 | 71 | # Compare trained model performance to a node agent and an edge agent 72 | correct_model, correct_node, correct_edge = analyse.compare_to_agents(forward, tem, environments, include_stay_still=include_stay_still) 73 | 74 | # Analyse occurrences of zero-shot inference: predict the right observation arriving from a visited node with a new action 75 | zero_shot = analyse.zero_shot(forward, tem, environments, include_stay_still=include_stay_still) 76 | 77 | # Generate occupancy maps: how much time TEM spends at every location 78 | occupation = analyse.location_occupation(forward, tem, environments) 79 | 80 | # Generate rate maps 81 | g, p = analyse.rate_map(forward, tem, environments)
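# Shape note (reconstructed from how plot_cells indexes these arrays below):
# g and p are nested as [environment][frequency][location][cell], so for
# example p[0][1][12][3] would be the rate of hippocampal cell 3 in frequency
# module 1 at location 12 of the first environment.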
82 | 83 | # Calculate accuracy leaving from and arriving to each location 84 | from_acc, to_acc = analyse.location_accuracy(forward, tem, environments) 85 | 86 | # Choose which environment to plot 87 | env_to_plot = 0 88 | # And when averaging environments, e.g. for calculating average accuracy, decide which environments to include 89 | envs_to_avg = shiny_envs if shiny_envs[env_to_plot] else [not shiny_env for shiny_env in shiny_envs] 90 | 91 | # Plot results of agent comparison and zero-shot inference analysis 92 | filt_size = 41 93 | plt.figure() 94 | plt.plot(analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_model) if envs_to_avg[env_i]]),0)[1:], filt_size), label='tem') 95 | plt.plot(analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_node) if envs_to_avg[env_i]]),0)[1:], filt_size), label='node') 96 | plt.plot(analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_edge) if envs_to_avg[env_i]]),0)[1:], filt_size), label='edge') 97 | plt.ylim(0, 1) 98 | plt.legend() 99 | plt.title('Zero-shot inference: ' + str(np.mean([np.mean(env) for env_i, env in enumerate(zero_shot) if envs_to_avg[env_i]]) * 100) + '%') 100 | plt.show() 101 | 102 | # Plot rate maps for all cells 103 | plot.plot_cells(p[env_to_plot], g[env_to_plot], environments[env_to_plot], n_f_ovc=(params['n_f_ovc'] if 'n_f_ovc' in params else 0), columns = 25) 104 | 105 | # Plot accuracy separated by location 106 | plt.figure() 107 | ax = plt.subplot(1,2,1) 108 | plot.plot_map(environments[env_to_plot], np.array(to_acc[env_to_plot]), ax) 109 | ax.set_title('Accuracy to location') 110 | ax = plt.subplot(1,2,2) 111 | plot.plot_map(environments[env_to_plot], np.array(from_acc[env_to_plot]), ax) 112 | ax.set_title('Accuracy from location') 113 | 114 | # Plot occupation per location, then add walks on top 115 | ax = plot.plot_map(environments[env_to_plot], np.array(occupation[env_to_plot])/sum(occupation[env_to_plot])*environments[env_to_plot].n_locations, 116 | min_val=0, max_val=2, ax=None, shape='square', radius=1/np.sqrt(environments[env_to_plot].n_locations)) 117 | ax = plot.plot_walk(environments[env_to_plot], walks[env_to_plot], ax=ax, n_steps=max(1, int(len(walks[env_to_plot])/500))) 118 | plt.title('Walk and average occupation') 119 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Feb 12 15:03:58 2020 5 | 6 | @author: jacobb 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import datetime 13 | import logging 14 | 15 | def inv_var_weight(mus, sigmas): 16 | ''' 17 | Accepts lists of batches of row vectors of means and standard deviations, with batches along dim 0 18 | Returns tensors of inverse-variance weighted averages and tensors of inverse-variance weighted standard deviations 19 | ''' 20 | # Stack vectors together along first dimension 21 | mus = torch.stack(mus, dim = 0) 22 | sigmas = torch.stack(sigmas, dim = 0) 23 | # Calculate inverse variance weighted variance from sum over reciprocal of squared sigmas 24 | inv_var_var = 1.0 / torch.sum(1.0 / (sigmas**2), dim = 0) 25 | # Calculate inverse variance weighted average 26 | inv_var_avg = torch.sum(mus / (sigmas**2), dim = 0) * inv_var_var 27 | # Convert weighted variance to sigma 28 | inv_var_sigma = torch.sqrt(inv_var_var) 29 | # And return results 30 | return inv_var_avg, inv_var_sigma
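# Worked example (illustrative numbers): combining mu_1 = 0, sigma_1 = 1 with
# mu_2 = 2, sigma_2 = 1 gives inv_var_var = 1 / (1/1 + 1/1) = 0.5, so
# inv_var_avg = (0/1 + 2/1) * 0.5 = 1.0 and inv_var_sigma = sqrt(0.5) ~ 0.71,
# i.e. the standard precision-weighted combination of two Gaussian estimates.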
31 | 32 | def softmax(x): 33 | ''' 34 | Applies softmax to tensors of inputs, using torch softmax function 35 | Assumes x is a 1D vector, or batches of row vectors with the batches along dim 0 36 | ''' 37 | # Return torch softmax 38 | return torch.nn.Softmax(dim=-1)(x) 39 | 40 | def normalise(x): 41 | ''' 42 | Normalises vector of input to unit norm, using torch normalise function 43 | Assumes x is a 1D vector, or batches of row vectors with the batches along dim 0 44 | ''' 45 | # Return torch normalise with p=2 for L2 norm 46 | return torch.nn.functional.normalize(x, p=2, dim=-1) 47 | 48 | def relu(x): 49 | ''' 50 | Applies rectified linear activation unit to tensors of inputs, using torch relu function 51 | ''' 52 | # Return torch relu 53 | return torch.nn.functional.relu(x) 54 | 55 | def leaky_relu(x): 56 | ''' 57 | Applies leaky (meaning small negative slope instead of zeros) rectified linear activation unit to tensors of inputs, using torch leaky relu function 58 | ''' 59 | # Return torch leaky relu 60 | return torch.nn.functional.leaky_relu(x) 61 | 62 | def squared_error(value, target): 63 | ''' 64 | Calculates mean squared error (L2 norm) between (list of) tensors value and target by using torch MSE loss 65 | Includes a factor 0.5 to the squared error by convention 66 | Sets reduction to none, then sums over the last dimension to keep losses of different batches separate 67 | ''' 68 | # Return torch MSE loss 69 | if type(value) is list: 70 | loss = [0.5 * torch.sum(torch.nn.MSELoss(reduction='none')(value[i], target[i]),dim=-1) for i in range(len(value))] 71 | else: 72 | loss = 0.5 * torch.sum(torch.nn.MSELoss(reduction='none')(value, target),dim=-1) 73 | return loss 74 | 75 | def cross_entropy(value, target): 76 | ''' 77 | Calculates categorical cross entropy between tensors value and target by using torch cross entropy loss 78 | Sets reduction to none to keep losses of different batches separate 79 | ''' 80 | # Return torch cross entropy loss 81 | if type(value) is list: 82 | loss = [torch.nn.CrossEntropyLoss(reduction='none')(val, targ) for val, targ in zip(value, target)] 83 | else: 84 | loss = torch.nn.CrossEntropyLoss(reduction='none')(value, target) 85 | return loss 86 | 87 | def downsample(value, target_dim): 88 | ''' 89 | Does downsampling by taking an input vector, then averaging chunks to make it of requested dimension 90 | Assumes x is a 1D vector, or batches of row vectors with the batches along dim 0 91 | ''' 92 | # Get input dimension 93 | value_dim = value.size()[-1] 94 | # Set places to break up input vector into chunks 95 | edges = np.append(np.round(np.arange(0, value_dim, float(value_dim) / target_dim)),value_dim).astype(int) 96 | # Create downsampling matrix 97 | downsample = torch.zeros((value_dim,target_dim), dtype = torch.float) 98 | # Fill downsampling matrix with chunks 99 | for curr_entry in range(target_dim): 100 | downsample[edges[curr_entry]:edges[curr_entry+1],curr_entry] = torch.tensor(1.0/(edges[curr_entry+1]-edges[curr_entry]), dtype=torch.float) 101 | # Do downsampling by matrix multiplication 102 | return torch.matmul(value,downsample) 103 | 104 | def make_directories(): 105 | ''' 106 | Creates directories for storing data during a model training run 107 | ''' 108 | # Get current date for saving folder 109 | date = datetime.datetime.today().strftime('%Y-%m-%d') 110 | # Initialise the run and dir_check to create a new run folder within the current date 111 | run = 0
112 | dir_check = True 113 | # Initialise all paths 114 | train_path, model_path, save_path, script_path, run_path = None, None, None, None, None 115 | # Find the current run: the first run that doesn't exist yet 116 | while dir_check: 117 | # Construct new paths 118 | run_path = '../Summaries/' + date + '/run' + str(run) + '/' 119 | train_path = run_path + 'train' 120 | model_path = run_path + 'model' 121 | save_path = run_path + 'save' 122 | script_path = run_path + 'script' 123 | envs_path = script_path + '/envs' 124 | run += 1 125 | # And once a path doesn't exist yet: create new folders 126 | if not os.path.exists(train_path) and not os.path.exists(model_path) and not os.path.exists(save_path): 127 | os.makedirs(train_path) 128 | os.makedirs(model_path) 129 | os.makedirs(save_path) 130 | os.makedirs(script_path) 131 | os.makedirs(envs_path) 132 | dir_check = False 133 | # Return all paths for the new run 134 | return run_path, train_path, model_path, save_path, script_path, envs_path 135 | 136 | def set_directories(date, run): 137 | ''' 138 | Returns directories for storing data during a model training run from a given previous training run 139 | ''' 140 | # Initialise all paths 141 | train_path, model_path, save_path, script_path, run_path = None, None, None, None, None 142 | # Construct the paths for the given run 143 | run_path = '../Summaries/' + date + '/run' + str(run) + '/' 144 | train_path = run_path + 'train' 145 | model_path = run_path + 'model' 146 | save_path = run_path + 'save' 147 | script_path = run_path + 'script' 148 | envs_path = script_path + '/envs' 149 | # Return all paths for the existing run 150 | return run_path, train_path, model_path, save_path, script_path, envs_path 151 | 152 | def make_logger(run_path): 153 | ''' 154 | Creates logger so output during training can be stored to file in a consistent way 155 | ''' 156 | # Create new logger 157 | logger = logging.getLogger(__name__) 158 | logger.setLevel(logging.INFO) 159 | # Remove any existing handlers so you don't output to old files, or to new files twice 160 | logger.handlers = [] 161 | # Create a file handler that writes log output to the run's report file 162 | handler = logging.FileHandler(run_path + 'report.log') 163 | handler.setLevel(logging.INFO) 164 | # Create a logging format 165 | formatter = logging.Formatter('%(asctime)s: %(message)s') 166 | handler.setFormatter(formatter) 167 | # Add the handlers to the logger 168 | logger.addHandler(handler) 169 | # Return the logger object 170 | return logger 171 | -------------------------------------------------------------------------------- /world.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Feb 11 14:33:06 2020 5 | 6 | @author: jacobb 7 | """ 8 | import json 9 | import numpy as np 10 | import torch 11 | import copy 12 | from scipy.sparse.csgraph import shortest_path 13 | 14 | # Functions for generating data that TEM trains on: sequences of [state,observation,action] tuples 15 | 16 | class World: 17 | def __init__(self, env, randomise_observations=False, randomise_policy=False, shiny=None): 18 | # If the environment is provided as a filename: load the corresponding file. If it's not a filename, it's assumed to be an environment dictionary
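# Expected structure of an environment file (reconstructed from the fields
# accessed below; all values are illustrative):
# {
#   "n_locations": 25, "n_observations": 45, "n_actions": 5,
#   "adjacency": [[0, 1, 0, ...], ...],
#   "locations": [{"id": 0, "x": 0.1, "y": 0.1, "observation": 3,
#                  "actions": [{"id": 0, "transition": [0, 1, 0, ...],
#                               "probability": 0.25}, ...]}, ...]
# }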
19 | if type(env) == str or type(env) == np.str_: 20 | # Filename provided, load graph from json file 21 | file = open(env, 'r') 22 | json_text = file.read() 23 | env = json.loads(json_text) 24 | file.close() 25 | 26 | # Now env holds a dictionary that describes this world 27 | try: 28 | # Copy expected fields to object attributes 29 | self.adjacency = env['adjacency'] 30 | self.locations = env['locations'] 31 | self.n_actions = env['n_actions'] 32 | self.n_locations = env['n_locations'] 33 | self.n_observations = env['n_observations'] 34 | except (KeyError, TypeError) as e: 35 | # If any of the expected fields is missing: treat this as an invalid environment 36 | print('Invalid environment: bad dictionary\n', e) 37 | # Initialise all environment fields for an empty environment 38 | self.adjacency = [] 39 | self.locations = [] 40 | self.n_actions = 0 41 | self.n_locations = 0 42 | self.n_observations = 0 43 | 44 | # If requested: shuffle observations from original assignments 45 | if randomise_observations: 46 | self.observations_randomise() 47 | 48 | # If requested: randomise policy by setting equal probability for each action 49 | if randomise_policy: 50 | self.policy_random() 51 | 52 | # Copy the shiny input 53 | self.shiny = copy.deepcopy(shiny) 54 | # If there's no shiny data provided: initialise this world as a non-shiny environment 55 | if self.shiny is None: 56 | # TEM needs to know that this is a non-shiny environment (e.g. for providing actions to generative model), so set shiny to None for each location 57 | for location in self.locations: 58 | location['shiny'] = None 59 | # If shiny data is provided: initialise shiny properties 60 | else: 61 | # Initially make all locations non-shiny 62 | for location in self.locations: 63 | location['shiny'] = False 64 | # Calculate all graph distances, since shiny objects aren't allowed to be too close together 65 | dist_matrix = shortest_path(csgraph=np.array(self.adjacency), directed=False) 66 | # Initialise the list of shiny locations as empty 67 | self.shiny['locations'] = [] 68 | # Then select shiny locations by adding them one-by-one, with the constraint that they can't be too close to each other 69 | while len(self.shiny['locations']) < self.shiny['n']: 70 | new = np.random.randint(self.n_locations) 71 | too_close = [dist_matrix[new,existing] < np.max(dist_matrix) / self.shiny['n'] for existing in self.shiny['locations']] 72 | if not any(too_close): 73 | self.shiny['locations'].append(new) 74 | # Set those locations to be shiny 75 | for shiny_location in self.shiny['locations']: 76 | self.locations[shiny_location]['shiny'] = True 77 | # Get objects at shiny locations 78 | self.shiny['objects'] = [self.locations[location]['observation'] for location in self.shiny['locations']] 79 | # Make list of objects that are not shiny 80 | not_shiny = [observation for observation in range(self.n_observations) if observation not in self.shiny['objects']] 81 | # Update observations so there is no non-shiny occurrence of the shiny objects 82 | for location in self.locations: 83 | # Update a non-shiny location if it has a shiny object observation 84 | if location['id'] not in self.shiny['locations'] and location['observation'] in self.shiny['objects']: 85 | # Pick new observation from non-shiny objects 86 | location['observation'] = np.random.choice(not_shiny) 87 | # Generate a policy towards each of the shiny objects 88 | self.shiny['policies'] = [self.policy_distance(shiny_location)
for shiny_location in self.shiny['locations']] 89 | 90 | def observations_randomise(self): 91 | # Run through every abstract location 92 | for location in self.locations: 93 | # Pick random observation from any of the observations 94 | location['observation'] = np.random.randint(self.n_observations) 95 | return self 96 | 97 | def policy_random(self): 98 | # Run through every abstract location 99 | for location in self.locations: 100 | # Count the number of actions that can transition anywhere for this location 101 | count = sum([sum(action['transition']) > 0 for action in location['actions']]) 102 | # Run through all actions at this location to update their probability 103 | for action in location['actions']: 104 | # If this action transitions anywhere: it is an available action, so set its probability to 1/count 105 | action['probability'] = 1.0/count if sum(action['transition']) > 0 else 0 106 | return self 107 | 108 | def policy_learned(self, reward_locations): 109 | # This generates a Q-learned policy towards reward locations. 110 | # Prepare new set of locations to hold policies towards reward locations 111 | new_locations, reward_locations = self.get_reward(reward_locations) 112 | # Initialise state-action values Q at 0 113 | for location in new_locations: 114 | for action in location['actions']: 115 | action['Q'] = 0 116 | # Do value iteration in order to find a policy toward a given location 117 | iters = 10*self.n_locations 118 | # Run value iterations by looping through all actions iteratively 119 | for i in range(iters): 120 | # Deepcopy the current Q-values so they are the same for all updates (don't update values that you later need) 121 | prev_locations = copy.deepcopy(new_locations) 122 | for location in new_locations: 123 | for action in location['actions']: 124 | # Q-value update from value iteration of Bellman equation: Q(s,a) <- sum_across_s'(p(s,a,s') * (r(s') + gamma * max_across_a'(Q(s', a')))) 125 | action['Q'] = sum([probability * ((new_location in reward_locations) + self.shiny['gamma'] * max([new_action['Q'] for new_action in prev_locations[new_location]['actions']])) for new_location, probability in enumerate(action['transition'])]) 126 | # Calculate policy from softmax over Q-values for every state 127 | for location in new_locations: 128 | exp = np.exp(self.shiny['beta'] * np.array([action['Q'] if action['probability']>0 else -np.inf for action in location['actions']])) 129 | for action, probability in zip(location['actions'], exp/sum(exp)): 130 | # Policy from softmax: p(a) = exp(beta*a)/sum_over_as(exp(beta*a_s)) 131 | action['probability'] = probability 132 | # Return new locations with updated policy for given reward locations 133 | return new_locations 134 | 135 | def policy_distance(self, reward_locations): 136 | # This generates a distance-based policy towards reward locations, which is much faster than Q-learning but ignores policy and transition probabilities 137 | # Prepare new set of locations to hold policies towards reward locations 138 | new_locations, reward_locations = self.get_reward(reward_locations) 139 | # Create boolean vector of reward locations for matrix indexing 140 | is_reward_location = np.zeros(self.n_locations, dtype=bool) 141 | is_reward_location[reward_locations] = True 142 | # Calculate distances between all locations based on adjacency matrix - this doesn't take transition probabilities into account!
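# Worked softmax example for the policies computed in these methods
# (illustrative numbers): with beta = 1 and two available actions with values
# [-1, -2] (e.g. negative distances to reward), exp([-1, -2]) = [0.37, 0.14],
# so the policy becomes p = [0.73, 0.27]; actions towards reward are
# proportionally more likely, and unavailable actions get exp(-inf) = 0.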
143 | dist_matrix = shortest_path(csgraph=np.array(self.adjacency), directed=True) 144 | # Fill out minimum distance to any reward state for each action 145 | for location in new_locations: 146 | for action in location['actions']: 147 | action['d'] = np.min(dist_matrix[is_reward_location, np.array(action['transition']) > 0]) if any(action['transition']) else np.inf 148 | # Calculate policy from softmax over negative distances for every action 149 | for location in new_locations: 150 | exp = np.exp(self.shiny['beta'] * np.array([-action['d'] if action['probability']>0 else -np.inf for action in location['actions']])) 151 | for action, probability in zip(location['actions'], exp/sum(exp)): 152 | # Policy from softmax: p(a) = exp(beta*a)/sum_over_as(exp(beta*a_s)) 153 | action['probability'] = probability 154 | # Return new locations with updated policy for given reward locations 155 | return new_locations 156 | 157 | def generate_walks(self, walk_length=10, n_walk=100, repeat_bias_factor=2): 158 | # Generate walk by sampling actions according to policy, then next state according to graph 159 | walks = [] # This is going to contain a list of (state, observation, action) tuples 160 | for currWalk in range(n_walk): 161 | new_walk = [] 162 | # If shiny hasn't been specified: there are no shiny objects, generate default policy 163 | if self.shiny is None: 164 | new_walk = self.walk_default(new_walk, walk_length, repeat_bias_factor) 165 | # If shiny was specified: use policy that uses shiny policy to approach shiny objects sequentially 166 | else: 167 | new_walk = self.walk_shiny(new_walk, walk_length, repeat_bias_factor) 168 | # Clean up walk a bit by only keeping essential location dictionary entries 169 | for step in new_walk[:-1]: 170 | step[0] = {'id': step[0]['id'], 'shiny': step[0]['shiny']} 171 | # Append new walk to list of walks 172 | walks.append(new_walk) 173 | return walks 174 | 175 | def walk_default(self, walk, walk_length, repeat_bias_factor=2): 176 | # Finish the provided walk until it contains walk_length steps 177 | for curr_step in range(walk_length - len(walk)): 178 | # Get new location based on previous action and location 179 | new_location = self.get_location(walk) 180 | # Get new observation at new location 181 | new_observation = self.get_observation(new_location) 182 | # Get new action based on policy at new location 183 | new_action = self.get_action(new_location, walk) 184 | # Append location, observation, and action to the walk 185 | walk.append([new_location, new_observation, new_action]) 186 | # Return the final walk 187 | return walk 188 | 189 | def walk_shiny(self, walk, walk_length, repeat_bias_factor=2): 190 | # Pick current shiny object to approach 191 | shiny_current = np.random.randint(self.shiny['n']) 192 | # Reset number of iterations to hang around an object once found 193 | shiny_returns = self.shiny['returns'] 194 | # Finish the provided walk until it contains walk_length steps 195 | for curr_step in range(walk_length - len(walk)): 196 | # Get new location based on previous action and location 197 | new_location = self.get_location(walk) 198 | # Check if the shiny object was found in this step 199 | if new_location['id'] == self.shiny['locations'][shiny_current]: 200 | # After shiny object is found, start counting down for hanging around 201 | shiny_returns -= 1 202 | # Check if it's time to select new object to approach 203 | if shiny_returns < 0: 204 | # Pick new current shiny object to approach 205 | shiny_current = np.random.randint(self.shiny['n']) 206
| # Reset number of iterations to hang around an object once found 207 | shiny_returns = self.shiny['returns'] 208 | # Get new observation at new location 209 | new_observation = self.get_observation(new_location) 210 | # Get new action based on policy of new location towards shiny object 211 | new_action = self.get_action(self.shiny['policies'][shiny_current][new_location['id']], walk) 212 | # Append location, observation, and action to the walk 213 | walk.append([new_location, new_observation, new_action]) 214 | # Return the final walk 215 | return walk 216 | 217 | def get_location(self, walk): 218 | # First step: start at random location 219 | if len(walk) == 0: 220 | new_location = np.random.randint(self.n_locations) 221 | # Any other step: get new location from previous location and action 222 | else: 223 | new_location = int(np.flatnonzero(np.cumsum(walk[-1][0]['actions'][walk[-1][2]]['transition'])>np.random.rand())[0]) 224 | # Return the location dictionary of the new location 225 | return self.locations[new_location] 226 | 227 | def get_observation(self, new_location): 228 | # Find sensory observation for new state, and store it as one-hot vector 229 | new_observation = np.eye(self.n_observations)[new_location['observation']] 230 | # Create a new observation by converting the new observation to a torch tensor 231 | new_observation = torch.tensor(new_observation, dtype=torch.float).view((new_observation.shape[0])) 232 | # Return the new observation 233 | return new_observation 234 | 235 | def get_action(self, new_location, walk, repeat_bias_factor=2): 236 | # Build policy from action probability of each action of provided location dictionary 237 | policy = np.array([action['probability'] for action in new_location['actions']]) 238 | # Add a bias for repeating previous action to walk in straight lines, only if (this is not the first step) and (the previous action was a move) 239 | policy[[] if len(walk) == 0 or new_location['id'] == walk[-1][0]['id'] else walk[-1][2]] *= repeat_bias_factor 240 | # And renormalise policy (note that for unavailable actions, the policy was 0 and remains 0, so in that case no renormalisation needed) 241 | policy = policy / sum(policy) if sum(policy) > 0 else policy 242 | # Select action in new state 243 | new_action = int(np.flatnonzero(np.cumsum(policy)>np.random.rand())[0]) 244 | # Return the new action 245 | return new_action 246 | 247 | def get_reward(self, reward_locations): 248 | # Stick reward location into a list if there is only one reward location. Use multiple reward locations simultaneously for e.g. 
wall attraction 249 | reward_locations = [reward_locations] if type(reward_locations) is not list else reward_locations 250 | # Copy locations for updated policy towards goal 251 | new_locations = copy.deepcopy(self.locations) 252 | # Disable self-actions at reward locations because they will be very attractive 253 | for reward_location in reward_locations: 254 | # Check for each action if it's a self-action 255 | for action in new_locations[reward_location]['actions']: 256 | if action['transition'][reward_location] == 1: 257 | action['probability'] = 0 258 | # Count total action probability to renormalise after disabling self-action 259 | total_probability = sum([action['probability'] for action in new_locations[reward_location]['actions']]) 260 | # Renormalise action probabilities 261 | for action in new_locations[reward_location]['actions']: 262 | action['probability'] = action['probability'] / total_probability if total_probability > 0 else action['probability'] 263 | return new_locations, reward_locations 264 | 265 | 266 | 267 | 268 | 269 | --------------------------------------------------------------------------------
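# [Added illustration, not a file in the repository] Minimal end-to-end use of
# world.py, assuming the repository root as working directory:
import world

# Load a 5x5 grid world and shuffle which observation appears at each location
env = world.World('./envs/5x5.json', randomise_observations=True)
# Generate one walk of 10 steps; each step is a [location, observation, action] triple
walk = env.generate_walks(walk_length=10, n_walk=1)[0]
location, observation, action = walk[0]
# location is a dict with 'id' and 'shiny', observation is a one-hot torch tensor, action is an int
print(location['id'], observation.shape, action)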