├── .gitattributes ├── LICENSE ├── README.md ├── data ├── __init__.py ├── global.p ├── loading.py └── local.p ├── download_data.sh ├── environment ├── DefaultGenerator.py ├── Generator.py ├── GlobalGenerator.py ├── MDP.py ├── NonTextGenerator.py ├── NonUniqueGenerator.py ├── SpriteFigure.py ├── ValueIteration.py ├── __init__.py ├── figure_library.py ├── library.py ├── reference_instructions.py └── visualization.py ├── generate_worlds.py ├── logs ├── .gitkeep └── example │ ├── git_sprites.png │ └── predictions.png ├── models ├── Linear_custom.py ├── __init__.py ├── attention_direct.py ├── attention_global.py ├── attention_heatmap.py ├── attention_model.py ├── cnn_lstm.py ├── compositor_model.py ├── constraint_factorizer.py ├── conv_to_vector.py ├── custom.py ├── goal_model.py ├── initializations.py ├── lookup_location.py ├── lookup_model.py ├── map_model.py ├── mlp.py ├── model_factorizer.py ├── multi_global.py ├── multi_model.py ├── multi_nobases.py ├── multi_nocnn.py ├── multi_nonsep.py ├── multi_norbf.py ├── object_model.py ├── object_model_10.py ├── object_model_fc.py ├── object_model_pos.py ├── object_model_rnn.py ├── simple_conv.py ├── state_model.py ├── state_model_10.py ├── text_model.py ├── uvfa_pos.py ├── uvfa_text.py └── values_factorized.py ├── pipeline ├── __init__.py ├── agent.py ├── evaluation.py ├── run_eval.py ├── score_iteration.py └── training.py ├── psiturk-vi ├── tmp.p ├── turk_jun22.p ├── turk_jun23_global.p ├── turk_jun26.p └── turk_local_merged.p ├── reinforcement.py ├── representation.py ├── requirements.txt ├── slurm ├── meta-generate.py └── meta_reinforce.sh ├── utils ├── __init__.py ├── filesystem.py └── rbf.py └── visualization ├── __init__.py ├── place_goal.py ├── run_vis.py └── vis_predictions.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.p linguist-language=Python 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Michael Janner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spatial Reasoning 2 | Code and data to reproduce the experiments in [Representation Learning for Grounded Spatial Reasoning](https://arxiv.org/abs/1707.03938). 3 | 4 | ## Installation 5 | Get [PyTorch](http://pytorch.org/) and `pip install -r requirements` 6 | 7 | `./download_data.sh` to get the annotated map data and sprites to make new maps. 8 | 9 | Note: There were a few breaking changes in the latest PyTorch release that affected this repo. Make sure you are using v0.2.0+ (`torch.__version__` to check) 10 | 11 | ## Data 12 | We have collected human annotations of 3308 goals in 10x10 maps, specified by referencing one or more of 12 object types placed randomly in the world. To load up to `max_train` train maps and `max_val` val maps with `mode = [ local | global ]` instructions and `annotations = [ human | synthetic ]` descriptions, run: 13 | ``` 14 | >>> import data 15 | >>> train_data, val_data = data.load(mode, annotations, max_train, max_val) 16 | >>> layouts, objects, rewards, terminal, instructions, values, goals = train_data 17 | ``` 18 | where `layouts` and `objects` are arrays with item identifiers, `rewards` and `terminal` are arrays that can be used to construct an MDP, `instructions` are a list of text descriptions of the coordinates in `goals`, and `values` has the ground truth state values from Value Iteration. 19 | 20 | To generate more maps (with synthetic annotations): 21 | ``` 22 | $ python generate_worlds.py --mode [ local | global ] --save_path data/example_worlds/ --vis_path data/example_env/ 23 | ``` 24 | which will save pickle files that can be loaded with `data.load()` in `data/example_worlds/` and visualizations in `data/example_env/`. Visualizations of a few of the maps downloaded to `data/local/` are in `data/local_sprites/`. 25 | 26 |
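As an unofficial sketch of how these arrays fit together, the snippet below rebuilds the value map for the first training example with the `MDP` and `ValueIteration` classes in `environment/`. It assumes the data has been downloaded via `./download_data.sh` and that the loaded arrays are index-aligned (one layout, reward map, terminal map, and value map per instruction):
```
>>> import data, environment
>>> train_data, val_data = data.load('local', 'human', 1, 1)
>>> layouts, objects, rewards, terminal, instructions, values, goals = train_data
>>> mdp = environment.MDP(layouts[0], rewards[0], terminal[0])
>>> vi = environment.ValueIteration(mdp)
>>> state_values, policy = vi.iterate()
>>> value_map = mdp.representValues(state_values)  # 10x10 value map; should line up with values[0]
```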

30 | Visualizations of randomly generated worlds
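The same maps can also be generated from a Python session rather than the command line. Below is a minimal sketch mirroring the calls made by `generate_worlds.py` for a single local-mode map (it assumes the sprites from `./download_data.sh` are in `environment/sprites/` and that Python is run from the repository root):
```
>>> import environment
>>> from environment.NonUniqueGenerator import NonUniqueGenerator
>>> lib = environment.figure_library
>>> gen = NonUniqueGenerator(lib.objects, lib.unique_instructions, shape=(10, 10), num_steps=10, only_global=False)
>>> info = gen.new()  # dict with 'map', 'rewards', 'terminal', 'instructions', 'goals'
>>> sprite = environment.SpriteFigure(lib.objects, lib.background, dim=100)
>>> sprite.makeGrid(info['map'], 'example_sprites')  # writes example_sprites.png
```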

32 | 33 | ## Training 34 | 35 | To train the model with reinforcement learning: 36 | ``` 37 | $ python reinforcement.py --annotations [ human | synthetic ] --mode [ local | global ] --save_path logs/trial/ 38 | ``` 39 | 40 | This will save the model, pickle files with the predictions, and visualizations to `logs/trial/`. 41 | 42 | To train the models in a supervised manner for the representation analysis, run `python representation.py` with the same arguments as above. 43 | 44 |
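A predicted value map can be sanity-checked by converting it into a greedy policy and rolling it out on the corresponding MDP, as in the `__main__` block of `environment/MDP.py`. A minimal sketch, with a random array standing in for a model prediction and the arrays loaded as in the Data section above:
```
>>> import numpy as np
>>> mdp = environment.MDP(layouts[0], rewards[0], terminal[0])
>>> predicted = np.random.randn(10, 10)   # stand-in for a predicted value map
>>> policy = mdp.get_policy(predicted)    # neighbors of each cell, ranked by predicted value
>>> steps = mdp.simulate(policy, (0, 0))  # steps taken before reaching the goal (capped at 100)
```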

48 | Predicted values for two maps with two instructions each. In the first map, the instructions share no objects but refer to the same location. In the second map, the instructions refer to different locations.

51 | 52 | ## Acknowledgments 53 | A huge thank you to Daniel Fried and Jacob Andreas for providing a copy of the human annotations when the original Dropbox link was purged. 54 | 55 | 56 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from loading import * -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | wget https://www.dropbox.com/sh/9at8tlmjp1ocrpg/AAAmhzZHNcJHNf5XXflkIHnha?dl=1 -O data.zip 2 | unzip data.zip -d data 3 | mv data/sprites environment/ 4 | rm data.zip 5 | -------------------------------------------------------------------------------- /environment/DefaultGenerator.py: -------------------------------------------------------------------------------- 1 | import math, random, pdb 2 | import numpy as np 3 | import library 4 | from Generator import Generator 5 | 6 | # print library.objects 7 | 8 | class DefaultGenerator(Generator): 9 | 10 | def __init__(self, objects, directions, shape = (20,20), goal_value = 3, num_steps = 50): 11 | self.objects = objects 12 | self.directions = directions 13 | self.shape = shape 14 | self.goal_value = goal_value 15 | self.num_steps = num_steps 16 | 17 | self.default_world = np.array( 18 | [ [1,1,0,0,0,0,0,0,0,1,1,1,1,1,0,1,1,1,1,1], 19 | [1,1,1,0,0,0,0,0,0,1,0,1,1,1,0,1,0,1,0,0], 20 | [1,1,1,0,0,0,0,0,1,1,0,1,1,1,1,1,0,1,0,1], 21 | [1,1,1,1,1,1,1,0,1,0,0,0,0,0,0,0,0,1,1,1], 22 | [0,0,0,0,0,0,1,1,1,0,0,0,0,0,1,0,1,1,1,1], 23 | [0,0,0,1,1,0,1,0,0,0,0,0,1,0,1,1,1,1,1,1], 24 | [0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1], 25 | [0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,1,0,0,1], 26 | [0,0,0,1,1,1,0,1,1,1,0,0,0,0,0,0,0,0,0,0], 27 | [0,0,0,0,1,1,1,1,0,1,1,1,0,0,0,0,0,0,0,0], 28 | [0,1,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0], 29 | [1,1,1,1,1,1,1,0,0,0,1,1,0,0,0,0,0,0,0,0], 30 | [1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0], 31 | [1,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0], 32 | [1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], 33 | [1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], 34 | [1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], 35 | [1,0,0,1,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0], 36 | [1,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0], 37 | [1,1,1,1,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0] ] 38 | ) 39 | self.default_world = self.default_world * -1 + 1 40 | self.default_world = self.default_world[:10,:10] 41 | print 'world size: ', self.default_world.shape 42 | 43 | def new(self): 44 | directions = {} 45 | world = self.default_world.copy() 46 | # print world 47 | 48 | states = [(i,j) for i in range(world.shape[0]) \ 49 | for j in range(world.shape[1]) \ 50 | if world[i][j] != self.objects['puddle']['index'] ] 51 | 52 | used_indices = set( np.unique(world).tolist() ) 53 | positions = {} 54 | for name, obj in self.objects.iteritems(): 55 | # print obj['name'] 56 | if not obj['background']: 57 | ind = obj['index'] 58 | assert( ind not in used_indices ) 59 | 60 | 61 | # if name == 'square': 62 | # pos = (2,1) 63 | # else: 64 | pos = random.choice(states) #self.randomPosition() 65 | while pos in positions.values(): 66 | pos = random.choice(states) 67 | 68 | world[pos] = ind 69 | used_indices.add(ind) 70 | 71 | # print name 72 | positions[name] = pos 73 | object_specific_dir = self.generateDirections(world, pos, name) 74 | # print object_specific_dir 75 | directions[name] = object_specific_dir 76 | # print directions 77 | # pdb.set_trace() 78 | 
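## addRewards (inherited from Generator) builds one reward map, terminal map, instruction string, and goal position for every valid direction phrase around each placed object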
79 | reward_maps, terminal_maps, instructions, goals = \ 80 | self.addRewards(world, positions, directions) 81 | 82 | info = { 83 | 'map': world, 84 | 'rewards': reward_maps, 85 | 'terminal': terminal_maps, 86 | 'instructions': instructions, 87 | 'goals': goals 88 | } 89 | return info 90 | -------------------------------------------------------------------------------- /environment/Generator.py: -------------------------------------------------------------------------------- 1 | import math, random, pdb 2 | import numpy as np 3 | import library 4 | 5 | # print library.objects 6 | 7 | class Generator: 8 | 9 | def __init__(self, objects, directions, shape = (20,20), goal_value = 3, num_steps = 50): 10 | self.objects = objects 11 | self.directions = directions 12 | self.shape = shape 13 | self.goal_value = goal_value 14 | self.num_steps = num_steps 15 | 16 | def new(self): 17 | directions = {} 18 | world = self.__puddles(self.num_steps) 19 | # print world 20 | 21 | states = [(i,j) for i in range(world.shape[0]) \ 22 | for j in range(world.shape[1]) \ 23 | if world[i][j] != self.objects['puddle']['index'] ] 24 | 25 | used_indices = set( np.unique(world).tolist() ) 26 | positions = {} 27 | for name, obj in self.objects.iteritems(): 28 | # print obj['name'] 29 | if not obj['background']: 30 | ind = obj['index'] 31 | assert( ind not in used_indices ) 32 | 33 | 34 | pos = random.choice(states) #self.randomPosition() 35 | while pos in positions.values(): 36 | pos = random.choice(states) 37 | 38 | world[pos] = ind 39 | used_indices.add(ind) 40 | 41 | # print name 42 | positions[name] = pos 43 | object_specific_dir = self.generateDirections(world, pos, name) 44 | # print object_specific_dir 45 | directions[name] = object_specific_dir 46 | # print directions 47 | # pdb.set_trace() 48 | 49 | reward_maps, terminal_maps, instructions, goals = \ 50 | self.addRewards(world, positions, directions) 51 | 52 | info = { 53 | 'map': world, 54 | 'rewards': reward_maps, 55 | 'terminal': terminal_maps, 56 | 'instructions': instructions, 57 | 'goals': goals 58 | } 59 | return info 60 | 61 | def __puddles(self, iters, max_width=3, max_steps=8): 62 | (M,N) = self.shape 63 | turns = ['up', 'down', 'left', 'right'] 64 | world = np.zeros( self.shape ) 65 | world.fill( self.objects['puddle']['index'] ) 66 | # position = np.floor(np.random.uniform(size=2)*self.shape[0]).astype(int) 67 | 68 | position = (np.random.uniform()*M, np.random.uniform()*N) 69 | position = map(int, position) 70 | 71 | for i in range(iters): 72 | direction = random.choice(turns) 73 | width = int(np.random.uniform(low=1, high=max_width)) 74 | steps = int(np.random.uniform(low=1, high=max_steps)) 75 | if direction == 'up': 76 | top = max(position[0] - steps, 0) 77 | bottom = position[0] 78 | left = max(position[1] - int(math.floor(width/2.)), 0) 79 | right = min(position[1] + int(math.ceil(width/2.)), N) 80 | position[0] = top 81 | elif direction == 'down': 82 | top = position[0] 83 | bottom = min(position[0] + steps, M) 84 | left = max(position[1] - int(math.floor(width/2.)), 0) 85 | right = min(position[1] + int(math.ceil(width/2.)), N) 86 | position[0] = bottom 87 | elif direction == 'left': 88 | top = max(position[0] - int(math.floor(width/2.)), 0) 89 | bottom = min(position[0] + int(math.ceil(width/2.)), M) 90 | left = max(position[1] - steps, 0) 91 | right = position[1] 92 | position[1] = left 93 | elif direction == 'right': 94 | top = max(position[0] - int(math.floor(width/2.)), 0) 95 | bottom = min(position[0] + int(math.ceil(width/2.)), M) 96 
| left = position[1] 97 | right = min(position[1] + steps, N) 98 | position[1] = right 99 | # print top, bottom, left, right, self.objects['grass']['index'] 100 | # print world.shape 101 | world[top:bottom+1, left:right+1] = self.objects['grass']['index'] 102 | 103 | return world 104 | 105 | def addRewards(self, world, positions, directions): 106 | reward_maps = [] 107 | terminal_maps = [] 108 | instruction_set = [] 109 | goal_positions = [] 110 | 111 | object_values = np.zeros( self.shape ) 112 | ## add non-background values 113 | for name, obj in self.objects.iteritems(): 114 | value = obj['value'] 115 | if not obj['background']: 116 | pos = positions[name] 117 | object_values[pos] = value 118 | else: 119 | mask = np.ma.masked_equal(world, obj['index']).mask 120 | # print name, obj['index'] 121 | # print mask 122 | # print value 123 | object_values[mask] += value 124 | # for st 125 | # value = obj['v'] 126 | # print 'values: ' 127 | # print object_values 128 | 129 | for name, obj in self.objects.iteritems(): 130 | if not obj['background']: 131 | # ind = obj['index'] 132 | # pos = positions[name] 133 | for (phrase, target_pos) in directions[name]: 134 | rewards = object_values.copy() 135 | rewards[target_pos] += self.goal_value 136 | terminal = np.zeros( self.shape ) 137 | terminal[target_pos] = 1 138 | 139 | reward_maps.append(rewards) 140 | terminal_maps.append(terminal) 141 | instruction_set.append(phrase) 142 | goal_positions.append(target_pos) 143 | 144 | # print name, pos, direct 145 | # print reward_maps 146 | # print instruction_set 147 | # print goal_positions 148 | return reward_maps, terminal_maps, instruction_set, goal_positions 149 | 150 | # def randomPosition(self): 151 | # pos = tuple( (np.random.uniform(size=2) * self.dim).astype(int) ) 152 | # return pos 153 | 154 | def generateDirections(self, world, pos, name): 155 | directions = [] 156 | for identifier, offset in self.directions.iteritems(): 157 | phrase = 'reach cell ' + identifier + ' ' + name 158 | absolute_pos = tuple(map(sum, zip(pos, offset))) 159 | # print absolute_pos 160 | (i,j) = absolute_pos 161 | (M,N) = self.shape 162 | out_of_bounds = (i < 0 or i >= M) or (j < 0 or j >= N) 163 | 164 | # absolute_pos[0] < 0 or absolute_pos[0] > [coord < 0 or coord >= self.dim for coord in absolute_pos] 165 | if not out_of_bounds: 166 | in_puddle = world[absolute_pos] == self.objects['puddle']['index'] 167 | if not in_puddle: 168 | directions.append( (phrase, absolute_pos) ) 169 | # print directions 170 | return directions 171 | 172 | # for i in [-1, 0, 1]: 173 | # for j in [-1, 0, 1]: 174 | # if i == 0 and j == 0: 175 | # pass 176 | # else: 177 | 178 | 179 | if __name__ == '__main__': 180 | gen = Generator(library.objects, library.directions) 181 | info = gen.new() 182 | 183 | print info['map'] 184 | print info['instructions'], len(info['instructions']) 185 | print len(info['rewards']), len(info['terminal']) 186 | # print info['terminal'] 187 | 188 | 189 | 190 | -------------------------------------------------------------------------------- /environment/MDP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class MDP: 4 | 5 | def __init__(self, world, rewards, terminal): 6 | self.world = world 7 | self.reward_map = rewards 8 | self.terminal_map = terminal 9 | self.shape = self.reward_map.shape 10 | 11 | self.M, self.N = self.shape 12 | self.states = [(i,j) for i in range(self.M) for j in range(self.N)] 13 | self.children = self.get_children( self.M, self.N ) 
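## children maps each cell to its in-bounds up/down/left/right neighbors; get_policy uses it to rank reachable cells by predicted value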
14 | 15 | self.actions = [(-1,0),(1,0),(0,-1),(0,1)] 16 | self.states = [(i,j) for i in range(self.shape[0]) for j in range(self.shape[1])] 17 | 18 | def getActions(self): 19 | return [i for i in range(len(self.actions))] 20 | 21 | def getStates(self): 22 | return self.states 23 | 24 | def transition(self, position, action_ind, fullstate=False): 25 | action = self.actions[action_ind] 26 | # print 'transitioning: ', action, position 27 | candidate = tuple(map(sum, zip(position, action))) 28 | 29 | ## if new location is valid, 30 | ## update the position 31 | if self.valid(candidate): 32 | position = candidate 33 | 34 | if fullstate: 35 | state = self.observe(position) 36 | else: 37 | state = position 38 | 39 | return state 40 | 41 | def valid(self, position): 42 | x, y = position[0], position[1] 43 | if x >= 0 and x < self.shape[0] and y >= 0 and y < self.shape[1]: 44 | return True 45 | else: 46 | return False 47 | 48 | def reward(self, position): 49 | rew = self.reward_map[position] 50 | return rew 51 | 52 | def terminal(self, position): 53 | term = self.terminal_map[position] 54 | return term 55 | 56 | def representValues(self, values): 57 | value_map = np.zeros( self.shape ) 58 | for pos, val in values.iteritems(): 59 | assert(value_map[pos] == 0) 60 | value_map[pos] = val 61 | return value_map 62 | 63 | # ''' 64 | # start_pos is (i,j) 65 | # policy is dict from (i,j) --> (delta_i, delta_j) 66 | # ''' 67 | # def simulate(self, policy, start_pos, num_steps = 100): 68 | # pos = start_pos 69 | # visited = set([pos]) 70 | # for step in range(num_steps): 71 | # # rew = self.reward(pos) 72 | # term = self.terminal(pos) 73 | # if term: 74 | # return 0 75 | # reachable = policy[pos] 76 | # selected = 0 77 | # while selected < len(reachable) and reachable[selected] in visited: 78 | # # print ' visited ', selected, reachable[selected] 79 | # selected += 1 80 | # if selected == len(reachable): 81 | # selected = 0 82 | # pos = policy[pos][selected] 83 | # visited.add(pos) 84 | # print 'position: ', pos 85 | # # print pos, goal 86 | # goal = np.argwhere( self.terminal_map ).flatten().tolist() 87 | # manhattan_dist = abs(pos[0] - goal[0]) + abs(pos[1] - goal[1]) 88 | # return manhattan_dist 89 | 90 | ''' 91 | start_pos is (i,j) 92 | policy is dict from (i,j) --> (delta_i, delta_j) 93 | ''' 94 | def simulate(self, policy, start_pos, num_steps = 100): 95 | pos = start_pos 96 | visited = set([pos]) 97 | for step in range(num_steps): 98 | # rew = self.reward(pos) 99 | term = self.terminal(pos) 100 | if term: 101 | return step 102 | reachable = policy[pos] 103 | selected = 0 104 | while selected < len(reachable) and reachable[selected] in visited: 105 | # print ' visited ', selected, reachable[selected] 106 | selected += 1 107 | if selected == len(reachable): 108 | selected = 0 109 | pos = policy[pos][selected] 110 | visited.add(pos) 111 | # print 'position: ', pos 112 | return step 113 | 114 | def get_children(self, M, N): 115 | children = {} 116 | for i in range(M): 117 | for j in range(N): 118 | pos = (i,j) 119 | children[pos] = [] 120 | for di in range( max(i-1, 0), min(i+1, M-1)+1 ): 121 | for dj in range( max(j-1, 0), min(j+1, N-1)+1 ): 122 | child = (di, dj) 123 | if pos != child and (i == di or j == dj): 124 | children[pos].append( child ) 125 | return children 126 | 127 | 128 | ''' 129 | values is M x N map of predicted values 130 | ''' 131 | def get_policy(self, values): 132 | policy = {} 133 | for state in self.states: 134 | reachable = self.children[state] 135 | selected = sorted(reachable, key 
= lambda x: values[x], reverse=True) 136 | policy[state] = selected 137 | return policy 138 | 139 | 140 | 141 | if __name__ == '__main__': 142 | import pickle, pdb, numpy as np 143 | info = pickle.load( open('../data/train_10/0.p') ) 144 | print info.keys() 145 | mdp = MDP(info['map'], info['rewards'][0], info['terminal'][0]) 146 | print mdp.world 147 | print mdp.children 148 | 149 | values = np.random.randn(10,10) 150 | policy = mdp.get_policy(values) 151 | 152 | steps = mdp.simulate(policy, (0,0)) 153 | print steps 154 | 155 | pdb.set_trace() 156 | 157 | 158 | -------------------------------------------------------------------------------- /environment/NonTextGenerator.py: -------------------------------------------------------------------------------- 1 | import math, random, pdb 2 | import numpy as np 3 | import library 4 | from Generator import Generator 5 | 6 | # print library.objects 7 | 8 | class NonTextGenerator(Generator): 9 | 10 | def __init__(self, objects, shape = (20,20), goal_value = 3, num_steps = 50): 11 | self.objects = objects 12 | self.shape = shape 13 | self.goal_value = goal_value 14 | self.num_steps = num_steps 15 | 16 | def new(self): 17 | directions = {} 18 | world = self.__puddles(self.num_steps) 19 | 20 | # states = [(i,j) for i in range(world.shape[0]) \ 21 | # for j in range(world.shape[1]) \ 22 | # if world[i][j] != self.objects['puddle']['index'] ] 23 | 24 | ## we are using ALL of the locations as goals, 25 | ## even those in puddles 26 | states = [ (i,j) for i in range(world.shape[0]) \ 27 | for j in range(world.shape[1]) ] 28 | 29 | reward_maps = [] 30 | terminal_maps = [] 31 | goals = [] 32 | for state in states: 33 | reward = np.zeros( world.shape ) 34 | terminal = np.zeros( world.shape ) 35 | reward[state] = self.goal_value 36 | terminal[state] = 1 37 | reward_maps.append(reward) 38 | terminal_maps.append(terminal) 39 | goals.append(state) 40 | 41 | info = { 42 | 'map': world, 43 | 'rewards': reward_maps, 44 | 'terminal': terminal_maps, 45 | 'goals': goals 46 | } 47 | 48 | # pdb.set_trace() 49 | 50 | return info 51 | 52 | def __puddles(self, iters, max_width=3, max_steps=10): 53 | (M,N) = self.shape 54 | turns = ['up', 'down', 'left', 'right'] 55 | world = np.zeros( self.shape ) 56 | world.fill( self.objects['puddle']['index'] ) 57 | # position = np.floor(np.random.uniform(size=2)*self.shape[0]).astype(int) 58 | 59 | position = (np.random.uniform()*M, np.random.uniform()*N) 60 | position = map(int, position) 61 | 62 | for i in range(iters): 63 | direction = random.choice(turns) 64 | width = int(np.random.uniform(low=1, high=max_width)) 65 | steps = int(np.random.uniform(low=1, high=max_steps)) 66 | if direction == 'up': 67 | top = max(position[0] - steps, 0) 68 | bottom = position[0] 69 | left = max(position[1] - int(math.floor(width/2.)), 0) 70 | right = min(position[1] + int(math.ceil(width/2.)), N) 71 | position[0] = top 72 | elif direction == 'down': 73 | top = position[0] 74 | bottom = min(position[0] + steps, M) 75 | left = max(position[1] - int(math.floor(width/2.)), 0) 76 | right = min(position[1] + int(math.ceil(width/2.)), N) 77 | position[0] = bottom 78 | elif direction == 'left': 79 | top = max(position[0] - int(math.floor(width/2.)), 0) 80 | bottom = min(position[0] + int(math.ceil(width/2.)), M) 81 | left = max(position[1] - steps, 0) 82 | right = position[1] 83 | position[1] = left 84 | elif direction == 'right': 85 | top = max(position[0] - int(math.floor(width/2.)), 0) 86 | bottom = min(position[0] + int(math.ceil(width/2.)), M) 87 | 
left = position[1] 88 | right = min(position[1] + steps, N) 89 | position[1] = right 90 | # print top, bottom, left, right, self.objects['grass']['index'] 91 | # print world.shape 92 | world[top:bottom+1, left:right+1] = self.objects['grass']['index'] 93 | 94 | return world 95 | 96 | # def addRewards(self, world, positions, directions): 97 | # reward_maps = [] 98 | # terminal_maps = [] 99 | # instruction_set = [] 100 | # goal_positions = [] 101 | 102 | # object_values = np.zeros( self.shape ) 103 | # ## add non-background values 104 | # for name, obj in self.objects.iteritems(): 105 | # value = obj['value'] 106 | # if not obj['background']: 107 | # pos = positions[name] 108 | # object_values[pos] = value 109 | # else: 110 | # mask = np.ma.masked_equal(world, obj['index']).mask 111 | # # print name, obj['index'] 112 | # # print mask 113 | # # print value 114 | # object_values[mask] += value 115 | # # for st 116 | # # value = obj['v'] 117 | # # print 'values: ' 118 | # # print object_values 119 | 120 | # for name, obj in self.objects.iteritems(): 121 | # if not obj['background']: 122 | # # ind = obj['index'] 123 | # # pos = positions[name] 124 | # for (phrase, target_pos) in directions[name]: 125 | # rewards = object_values.copy() 126 | # rewards[target_pos] += self.goal_value 127 | # terminal = np.zeros( self.shape ) 128 | # terminal[target_pos] = 1 129 | 130 | # reward_maps.append(rewards) 131 | # terminal_maps.append(terminal) 132 | # instruction_set.append(phrase) 133 | # goal_positions.append(target_pos) 134 | 135 | # # print name, pos, direct 136 | # # print reward_maps 137 | # # print instruction_set 138 | # # print goal_positions 139 | # return reward_maps, terminal_maps, instruction_set, goal_positions 140 | 141 | # # def randomPosition(self): 142 | # # pos = tuple( (np.random.uniform(size=2) * self.dim).astype(int) ) 143 | # # return pos 144 | 145 | # def generateDirections(self, world, pos, name): 146 | # directions = [] 147 | # for identifier, offset in self.directions.iteritems(): 148 | # phrase = 'reach cell ' + identifier + ' ' + name 149 | # absolute_pos = tuple(map(sum, zip(pos, offset))) 150 | # # print absolute_pos 151 | # (i,j) = absolute_pos 152 | # (M,N) = self.shape 153 | # out_of_bounds = (i < 0 or i >= M) or (j < 0 or j >= N) 154 | 155 | # # absolute_pos[0] < 0 or absolute_pos[0] > [coord < 0 or coord >= self.dim for coord in absolute_pos] 156 | # if not out_of_bounds: 157 | # in_puddle = world[absolute_pos] == self.objects['puddle']['index'] 158 | # if not in_puddle: 159 | # directions.append( (phrase, absolute_pos) ) 160 | # # print directions 161 | # return directions 162 | 163 | # # for i in [-1, 0, 1]: 164 | # # for j in [-1, 0, 1]: 165 | # # if i == 0 and j == 0: 166 | # # pass 167 | # # else: 168 | 169 | 170 | if __name__ == '__main__': 171 | gen = Generator(library.objects, library.directions) 172 | info = gen.new() 173 | 174 | print info['map'] 175 | print info['instructions'], len(info['instructions']) 176 | print len(info['rewards']), len(info['terminal']) 177 | # print info['terminal'] 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /environment/SpriteFigure.py: -------------------------------------------------------------------------------- 1 | import os, numpy as np, scipy.misc, pdb 2 | 3 | class SpriteFigure: 4 | 5 | def __init__(self, objects, background, dim = 20): 6 | self.objects = objects 7 | self.dim = dim 8 | background = self.loadImg(background) 9 | self.sprites = {0: background} 10 | for 
name, obj in objects.iteritems(): 11 | ind = obj['index'] 12 | # print name 13 | sprite = self.loadImg( obj['sprite'] ) 14 | if not obj['background']: 15 | # overlay = background.copy() 16 | # masked = np.ma.masked_greater( sprite[:,:,-1], 0 ).mask 17 | # overlay[masked] = sprite[:,:,:-1][masked] 18 | overlay = sprite 19 | else: 20 | overlay = sprite 21 | 22 | self.sprites[ind] = overlay 23 | 24 | 25 | def loadImg(self, path, dim = None): 26 | if dim == None: 27 | dim = self.dim 28 | path = os.path.join('environment', path) 29 | img = scipy.misc.imread(path) 30 | img = scipy.misc.imresize(img, (dim, dim) ) 31 | return img 32 | 33 | def makeGrid(self, world, filename, boundary_width = 4): 34 | shape = world.shape 35 | 36 | grass_ind = self.objects['grass']['index'] 37 | puddle_ind = self.objects['puddle']['index'] 38 | 39 | state = self.loadImg( self.objects['grass']['sprite'], dim = shape[0] * self.dim ) 40 | puddle = self.loadImg( self.objects['puddle']['sprite'], dim = shape[0] * self.dim ) 41 | 42 | for i in range(shape[0]): 43 | for j in range(shape[1]): 44 | 45 | row_low = i*self.dim 46 | row_high = (i+1)*self.dim 47 | col_low = j*self.dim 48 | col_high = (j+1)*self.dim 49 | 50 | ind = int(world[i,j]) 51 | sprite = self.sprites[ind] 52 | 53 | if ind == grass_ind: 54 | continue 55 | elif ind == puddle_ind: 56 | state[row_low:row_high, col_low:col_high, :] = puddle[row_low:row_high, col_low:col_high, :] 57 | ## background 58 | else: 59 | masked = np.ma.masked_greater( sprite[:,:,-1], 0 ).mask 60 | state[row_low:row_high, col_low:col_high, :][masked] = sprite[:,:,:-1][masked] 61 | # overlay[masked] = sprite[:,:,:-1][masked] 62 | # sprite = self.sprites[world[i,j].astype('int')] 63 | # state[i*self.dim:(i+1)*self.dim, j*self.dim:(j+1)*self.dim, :] = sprite 64 | 65 | for i in range(shape[0]): 66 | for j in range(shape[1]): 67 | 68 | row_low = i*self.dim 69 | row_high = (i+1)*self.dim 70 | col_low = j*self.dim 71 | col_high = (j+1)*self.dim 72 | 73 | ind = int(world[i,j]) 74 | 75 | if i < shape[0] - 1: 76 | below = int(world[i+1,j]) 77 | if (ind != puddle_ind and below == puddle_ind) or (ind == puddle_ind and below != puddle_ind): 78 | # print 'BELOW: ', i, j 79 | state[row_high-boundary_width:row_high+boundary_width, col_low:col_high, :] = 0. 80 | 81 | if j < shape[1] - 1: 82 | right = int(world[i,j+1]) 83 | if (ind != puddle_ind and right == puddle_ind) or (ind == puddle_ind and right != puddle_ind): 84 | # print 'BELOW: ', i, j 85 | state[row_low:row_high, col_high-boundary_width:col_high+boundary_width, :] = 0. 
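## the second pass above draws black boundary lines wherever a non-puddle cell borders a puddle cell, before the composed grid is saved below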
86 | 87 | 88 | scipy.misc.imsave(filename + '.png', state) 89 | return state 90 | -------------------------------------------------------------------------------- /environment/ValueIteration.py: -------------------------------------------------------------------------------- 1 | class ValueIteration: 2 | 3 | def __init__(self, mdp): 4 | self.refresh(mdp) 5 | 6 | def refresh(self, mdp): 7 | self.mdp = mdp 8 | self.states = mdp.getStates() 9 | self.actions = mdp.getActions() 10 | self.transition = mdp.transition 11 | self.reward = mdp.reward 12 | self.terminal = mdp.terminal 13 | self.values = {state: 0 for state in self.states} 14 | self.policy = {state: None for state in self.states} 15 | self.discount = 0.9 16 | 17 | 18 | def iterate(self): 19 | for k in range(0, 1000): 20 | for state in self.states: 21 | max_val = -float('inf') 22 | term = self.terminal(state) 23 | if term: 24 | max_val = self.reward(state) 25 | else: 26 | for action in self.actions: 27 | new_state = self.transition(state, action) 28 | new_val = self.reward(state) + self.discount * self.values[new_state] 29 | if new_val > max_val: 30 | max_val = new_val 31 | self.policy[state] = action 32 | self.values[state] = max_val 33 | return self.values, self.policy -------------------------------------------------------------------------------- /environment/__init__.py: -------------------------------------------------------------------------------- 1 | import library, figure_library 2 | from MDP import * 3 | from ValueIteration import ValueIteration 4 | from SpriteFigure import SpriteFigure -------------------------------------------------------------------------------- /environment/figure_library.py: -------------------------------------------------------------------------------- 1 | spritepath = 'sprites/' 2 | 3 | objects = { 4 | 'grass': { 5 | 'index': 0, 6 | 'value': 0, 7 | 'sprite': 'sprites/grass_figure_4.png', # 'sprites/white.png', 8 | 'background': True, 9 | 'unique': False, 10 | }, 11 | 'puddle': { 12 | 'index': 1, 13 | 'value': -1, 14 | 'sprite': 'sprites/water_figure_2.png', 15 | 'background': True, 16 | 'unique': False, 17 | }, 18 | ## unique 19 | 'star': { 20 | 'index': 2, 21 | 'value': 0, 22 | 'sprite': 'sprites/star_figure-01.png', ## white_alpha.png 23 | 'background': False, 24 | 'unique': True, 25 | }, 26 | 'circle': { 27 | 'index': 3, 28 | 'value': 0, 29 | 'sprite': 'sprites/circle_figure-01.png', 30 | 'background': False, 31 | 'unique': True, 32 | }, 33 | 'triangle': { 34 | 'index': 4, 35 | 'value': 0, 36 | 'sprite': 'sprites/triangle_figure-01.png', 37 | 'background': False, 38 | 'unique': True, 39 | }, 40 | 'heart': { 41 | 'index': 5, 42 | 'value': 0, 43 | 'sprite': 'sprites/heart_figure-01.png', 44 | 'background': False, 45 | 'unique': True, 46 | }, 47 | 'spade': { 48 | 'index': 6, 49 | 'value': 0, 50 | 'sprite': 'sprites/spade_figure-01.png', 51 | 'background': False, 52 | 'unique': True, 53 | }, 54 | 'diamond': { 55 | 'index': 7, 56 | 'value': 0, 57 | 'sprite': 'sprites/diamond_figure-01.png', 58 | 'background': False, 59 | 'unique': True, 60 | }, 61 | ## non-unique 62 | 'rock': { 63 | 'index': 8, 64 | 'value': 0, 65 | 'sprite': 'sprites/rock_figure-01.png', 66 | 'background': False, 67 | 'unique': False, 68 | }, 69 | 'tree': { 70 | 'index': 9, 71 | 'value': 0, 72 | 'sprite': 'sprites/tree_figure-01.png', 73 | 'background': False, 74 | 'unique': False, 75 | }, 76 | 'house': { 77 | 'index': 10, 78 | 'value': 0, 79 | 'sprite': 'sprites/house_figure-01.png', 80 | 'background': False, 81 | 'unique': False, 82 
| }, 83 | 'horse': { 84 | 'index': 11, 85 | 'value': 0, 86 | 'sprite': 'sprites/horse_figure-01.png', 87 | 'background': False, 88 | 'unique': False, 89 | }, 90 | } 91 | 92 | unique_instructions = { 93 | ## original 94 | 'to top left of': (-1, -1), 95 | 'on top of': (-1, 0), 96 | 'to top right of': (-1, 1), 97 | 'to left of': (0, -1), 98 | 'with': (0, 0), 99 | 'to right of': (0, 1), 100 | 'to bottom left of': (1, -1), 101 | 'on bottom of': (1, 0), 102 | 'to bottom right of': (1, 1), 103 | ## two steps away 104 | 'two to the left and two above': (-2, -2), 105 | 'one to the left and two above': (-2, -1), 106 | 'two above': (-2, 0), 107 | 'one to the right and two above': (-2, 1), 108 | 'two to the right and two above': (-2, 2), 109 | 'two to the right and one above': (-1, 2), 110 | 'two to the right of': (0, 2), 111 | 'two to the right and one below': (1, 2), 112 | 'two to the right and two below': (2, 2), 113 | 'one to the right and two below': (2, 1), 114 | 'two below': (2, 0), 115 | 'one to the left and two below': (2, -1), 116 | 'two to the left and two below': (2, -2), 117 | 'two to the left and one below': (1, -2), 118 | 'two to the left': (0, -2), 119 | 'two to the left and one above': (-1, -2) 120 | 121 | } 122 | 123 | background = 'sprites/grass_figure_4.png' 124 | 125 | # print objects 126 | -------------------------------------------------------------------------------- /environment/library.py: -------------------------------------------------------------------------------- 1 | spritepath = 'sprites/' 2 | 3 | objects = { 4 | 'grass': { 5 | 'index': 0, 6 | 'value': 0, 7 | 'sprite': 'environment/sprites/white.png', # 'sprites/white.png', 8 | 'background': True, 9 | 'unique': False, 10 | }, 11 | 'puddle': { 12 | 'index': 1, 13 | 'value': -1, 14 | 'sprite': 'environment/sprites/white.png', 15 | 'background': True, 16 | 'unique': False, 17 | }, 18 | ## unique 19 | 'star': { 20 | 'index': 2, 21 | 'value': 0, 22 | 'sprite': 'environment/sprites/white_alpha.png', ## white_alpha.png 23 | 'background': False, 24 | 'unique': True, 25 | }, 26 | 'circle': { 27 | 'index': 3, 28 | 'value': 0, 29 | 'sprite': 'environment/sprites/white_alpha.png', 30 | 'background': False, 31 | 'unique': True, 32 | }, 33 | 'triangle': { 34 | 'index': 4, 35 | 'value': 0, 36 | 'sprite': 'environment/sprites/white_alpha.png', 37 | 'background': False, 38 | 'unique': True, 39 | }, 40 | 'heart': { 41 | 'index': 5, 42 | 'value': 0, 43 | 'sprite': 'environment/sprites/white_alpha.png', 44 | 'background': False, 45 | 'unique': True, 46 | }, 47 | 'spade': { 48 | 'index': 6, 49 | 'value': 0, 50 | 'sprite': 'environment/sprites/white_alpha.png', 51 | 'background': False, 52 | 'unique': True, 53 | }, 54 | 'diamond': { 55 | 'index': 7, 56 | 'value': 0, 57 | 'sprite': 'environment/sprites/white_alpha.png', 58 | 'background': False, 59 | 'unique': True, 60 | }, 61 | ## non-unique 62 | 'rock': { 63 | 'index': 8, 64 | 'value': 0, 65 | 'sprite': 'environment/sprites/rock_hires.png', 66 | 'background': False, 67 | 'unique': False, 68 | }, 69 | 'tree': { 70 | 'index': 9, 71 | 'value': 0, 72 | 'sprite': 'environment/sprites/tree_hires.png', 73 | 'background': False, 74 | 'unique': False, 75 | }, 76 | 'house': { 77 | 'index': 10, 78 | 'value': 0, 79 | 'sprite': 'environment/sprites/house_hires.png', 80 | 'background': False, 81 | 'unique': False, 82 | }, 83 | 'horse': { 84 | 'index': 11, 85 | 'value': 0, 86 | 'sprite': 'environment/sprites/horse_hires.png', 87 | 'background': False, 88 | 'unique': False, 89 | }, 90 | } 91 | 92 | 
unique_instructions = { 93 | ## original 94 | 'to top left of': (-1, -1), 95 | 'on top of': (-1, 0), 96 | 'to top right of': (-1, 1), 97 | 'to left of': (0, -1), 98 | 'with': (0, 0), 99 | 'to right of': (0, 1), 100 | 'to bottom left of': (1, -1), 101 | 'on bottom of': (1, 0), 102 | 'to bottom right of': (1, 1), 103 | ## two steps away 104 | 'two to the left and two above': (-2, -2), 105 | 'one to the left and two above': (-2, -1), 106 | 'two above': (-2, 0), 107 | 'one to the right and two above': (-2, 1), 108 | 'two to the right and two above': (-2, 2), 109 | 'two to the right and one above': (-1, 2), 110 | 'two to the right of': (0, 2), 111 | 'two to the right and one below': (1, 2), 112 | 'two to the right and two below': (2, 2), 113 | 'one to the right and two below': (2, 1), 114 | 'two below': (2, 0), 115 | 'one to the left and two below': (2, -1), 116 | 'two to the left and two below': (2, -2), 117 | 'two to the left and one below': (1, -2), 118 | 'two to the left': (0, -2), 119 | 'two to the left and one above': (-1, -2) 120 | 121 | } 122 | 123 | background = 'environment/sprites/white.png' 124 | 125 | # print objects -------------------------------------------------------------------------------- /environment/reference_instructions.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | from collections import defaultdict 3 | 4 | def num_to_str(num): 5 | str_rep = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine'] 6 | return str_rep[num] 7 | 8 | def create_references(max_dist = 2, verbose = False): 9 | references = defaultdict(lambda: []) 10 | references[( (-2,0), (-1,0) )].append('reach between and ') 11 | references[( (2, 0), (1, 0) )].append('reach between and ') 12 | references[( (0,-2), (0,-1) )].append('reach between and ') 13 | references[( (0, 2), (0, 1) )].append('reach between and ') 14 | 15 | for i in range(-2, 3): 16 | for j in range(-2, 3): 17 | for goal_i in range(-2, 3): 18 | for goal_j in range(-2, 3): 19 | if goal_i != i and goal_j != j and goal_i != 0 and goal_j != 0: 20 | 21 | if abs(goal_j - j) > max_dist or abs(goal_i - i) > max_dist: 22 | continue 23 | ## i with reference to obj 1 24 | ## j with reference to obj 2 25 | row = goal_i 26 | col = goal_j - j 27 | if row > 0: 28 | vertical = '{} below '.format( num_to_str(row) ) 29 | elif row < 0: 30 | vertical = '{} above '.format( num_to_str(abs(row)) ) 31 | else: 32 | raise RuntimeError('goal_i should not be in line with obj 1') 33 | 34 | if col > 0: 35 | horizontal = '{} to the right of '.format( num_to_str(col) ) 36 | elif col < 0: 37 | horizontal = '{} to the left of '.format( num_to_str(abs(col)) ) 38 | else: 39 | raise RuntimeError('goal_j should not be in line with obj 2') 40 | 41 | if verbose: 42 | print 'OBJ2: ', i, j, ' Goal: ', goal_i, goal_j 43 | 44 | instructions = 'reach ' + vertical + ' and ' + horizontal 45 | references[( (i,j), (goal_i,goal_j) )].append(instructions) 46 | 47 | if verbose: 48 | print ' ', instructions 49 | 50 | ## i with reference to obj 2 51 | ## j with reference to obj 1 52 | row = goal_i - i 53 | col = goal_j 54 | if row > 0: 55 | vertical = '{} below '.format( num_to_str(row) ) 56 | elif row < 0: 57 | vertical = '{} above '.format( num_to_str(abs(row)) ) 58 | else: 59 | raise RuntimeError('goal_i should not be in line with obj 2') 60 | 61 | if col > 0: 62 | horizontal = '{} to the right of '.format( num_to_str(col) ) 63 | elif col < 0: 64 | horizontal = '{} to the left of '.format( 
num_to_str(abs(col)) ) 65 | else: 66 | raise RuntimeError('goal_j should not be in line with obj 1') 67 | 68 | instructions = 'reach ' + vertical + ' and ' + horizontal 69 | 70 | if verbose: 71 | print ' ', instructions 72 | 73 | references[( (i,j), (goal_i,goal_j) )].append(instructions) 74 | 75 | return references 76 | 77 | 78 | if __name__ == '__main__': 79 | references = create_references() 80 | pdb.set_trace() 81 | 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /environment/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | from matplotlib import pyplot as plt 5 | 6 | def visualize_values(mdp, values, policy, filename, title=None): 7 | states = mdp.states 8 | # print states 9 | plt.clf() 10 | m = max(states, key=lambda x: x[0])[0] + 1 11 | n = max(states, key=lambda x: x[1])[1] + 1 12 | data = np.zeros((m,n)) 13 | for i in range(m): 14 | for j in range(n): 15 | state = (i,j) 16 | if type(values) == dict: 17 | data[i][j] = values[state] 18 | else: 19 | # print values[i][j] 20 | data[i][j] = values[i][j] 21 | action = policy[state] 22 | ## if using all_reachable actions, pick the best one 23 | if type(action) == tuple: 24 | action = action[0] 25 | if action != None: 26 | x, y, w, h = arrow(i, j, action) 27 | plt.arrow(x,y,w,h,head_length=0.4,head_width=0.4,fc='k',ec='k') 28 | heatmap = plt.pcolor(data, cmap=plt.get_cmap('jet')) 29 | plt.colorbar() 30 | plt.gca().invert_yaxis() 31 | 32 | if title: 33 | plt.title(title) 34 | plt.savefig(filename + '.png') 35 | # print data 36 | 37 | def arrow(i, j, action): 38 | ## up, down, left, right 39 | ## x, y, w, h 40 | arrows = {0: (.5,.95,0,-.4), 1: (.5,.05,0,.4), 2: (.95,.5,-.4,0), 3: (.05,.5,.4,0)} 41 | arrow = arrows[action] 42 | return j+arrow[0], i+arrow[1], arrow[2], arrow[3] 43 | 44 | -------------------------------------------------------------------------------- /generate_worlds.py: -------------------------------------------------------------------------------- 1 | #!/om/user/janner/anaconda2/envs/pytorch/bin/python 2 | 3 | import sys, os, argparse, pickle, subprocess, pdb 4 | from tqdm import tqdm 5 | import environment, utils 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--lower', type=int, default=0) 9 | parser.add_argument('--num_worlds', type=int, default=10) 10 | parser.add_argument('--vis_path', type=str, default='data/example_vis/') 11 | parser.add_argument('--save_path', type=str, default='data/example_env/') 12 | parser.add_argument('--dim', type=int, default=10) 13 | parser.add_argument('--mode', type=str, default='local', choices=['local', 'global']) 14 | parser.add_argument('--only_global', type=bool, default=False) 15 | parser.add_argument('--sprite_dim', type=int, default=100) 16 | parser.add_argument('--num_steps', type=int, default=10) 17 | args = parser.parse_args() 18 | 19 | print args, '\n' 20 | 21 | utils.mkdir(args.vis_path) 22 | utils.mkdir(args.save_path) 23 | 24 | 25 | if args.mode == 'local': 26 | from environment.NonUniqueGenerator import NonUniqueGenerator 27 | gen = NonUniqueGenerator( environment.figure_library.objects, environment.figure_library.unique_instructions, shape=(args.dim, args.dim), num_steps=args.num_steps, only_global=args.only_global ) 28 | elif args.mode == 'global': 29 | from environment.GlobalGenerator import GlobalGenerator 30 | gen = GlobalGenerator( 
environment.figure_library.objects, environment.figure_library.unique_instructions, shape=(args.dim, args.dim), num_steps=args.num_steps, only_global=args.only_global ) 31 | 32 | 33 | for outer in range(args.lower, args.lower + args.num_worlds): 34 | info = gen.new() 35 | configurations = len(info['rewards']) 36 | 37 | print 'Generating map', outer, '(', configurations, 'configuations )' 38 | sys.stdout.flush() 39 | 40 | world = info['map'] 41 | rewards = info['rewards'] 42 | terminal = info['terminal'] 43 | values = [] 44 | 45 | sprite = environment.SpriteFigure(environment.figure_library.objects, environment.figure_library.background, dim=args.sprite_dim) 46 | sprite.makeGrid(world, args.vis_path + str(outer) + '_sprites') 47 | 48 | for inner in tqdm(range(configurations)): 49 | reward_map = rewards[inner] 50 | terminal_map = terminal[inner] 51 | 52 | mdp = environment.MDP(world, reward_map, terminal_map) 53 | vi = environment.ValueIteration(mdp) 54 | 55 | values_list, policy = vi.iterate() 56 | value_map = mdp.representValues(values_list) 57 | values.append(value_map) 58 | 59 | 60 | info['values'] = values 61 | filename = os.path.join( args.save_path, str(outer) + '.p' ) 62 | pickle.dump( info, open( filename, 'wb' ) ) 63 | 64 | 65 | -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jannerm/spatial-reasoning/e163003a33177e41ca02d5feefee3fdfca5ba154/logs/.gitkeep -------------------------------------------------------------------------------- /logs/example/git_sprites.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jannerm/spatial-reasoning/e163003a33177e41ca02d5feefee3fdfca5ba154/logs/example/git_sprites.png -------------------------------------------------------------------------------- /logs/example/predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jannerm/spatial-reasoning/e163003a33177e41ca02d5feefee3fdfca5ba154/logs/example/predictions.png -------------------------------------------------------------------------------- /models/Linear_custom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | 5 | class Linear_custom(Function): 6 | 7 | def forward(self, input, weight, bias=None): 8 | self.save_for_backward(input, weight, bias) 9 | output = input.new(input.size(0), weight.size(0)) 10 | output.addmm_(0, 1, input, weight.t()) 11 | if bias is not None: 12 | # cuBLAS doesn't support 0 strides in sger, so we can't use expand 13 | self.add_buffer = input.new(input.size(0)).fill_(1) 14 | output.addr_(self.add_buffer, bias) 15 | return output 16 | 17 | def backward(self, grad_output): 18 | input, weight, bias = self.saved_tensors 19 | 20 | grad_input = grad_weight = grad_bias = None 21 | if self.needs_input_grad[0]: 22 | grad_output = grad_output.squeeze() 23 | grad_input = torch.mm(grad_output, weight) 24 | if self.needs_input_grad[1]: 25 | grad_weight = torch.mm(grad_output.t(), input) 26 | if bias is not None and self.needs_input_grad[2]: 27 | grad_bias = torch.mv(grad_output.t(), self.add_buffer) 28 | 29 | if bias is not None: 30 | return grad_input, grad_weight, grad_bias 31 | else: 32 | return grad_input, grad_weight 33 | 34 | 35 | class Bilinear(Function): 36 | 37 | def 
forward(self, input1, input2, weight, bias=None): 38 | self.save_for_backward(input1, input2, weight, bias) 39 | 40 | output = input1.new(input1.size(0), weight.size(0)) 41 | 42 | buff = input1.new() 43 | 44 | # compute output scores: 45 | for k, w in enumerate(weight): 46 | torch.mm(input1, w, out=buff) 47 | buff.mul_(input2) 48 | torch.sum(buff, 1, out=output.narrow(1, k, 1)) 49 | 50 | if bias is not None: 51 | output.add_(bias.expand_as(output)) 52 | 53 | return output 54 | 55 | def backward(self, grad_output): 56 | input1, input2, weight, bias = self.saved_tensors 57 | grad_input1 = grad_input2 = grad_weight = grad_bias = None 58 | 59 | buff = input1.new() 60 | 61 | if self.needs_input_grad[0] or self.needs_input_grad[1]: 62 | grad_input1 = torch.mm(input2, weight[0].t()) 63 | grad_input1.mul_(grad_output.narrow(1, 0, 1).expand(grad_input1.size())) 64 | grad_input2 = torch.mm(input1, weight[0]) 65 | grad_input2.mul_(grad_output.narrow(1, 0, 1).expand(grad_input2.size())) 66 | 67 | for k in range(1, weight.size(0)): 68 | torch.mm(input2, weight[k].t(), out=buff) 69 | buff.mul_(grad_output.narrow(1, k, 1).expand(grad_input1.size())) 70 | grad_input1.add_(buff) 71 | 72 | torch.mm(input1, weight[k], out=buff) 73 | buff.mul_(grad_output.narrow(1, k, 1).expand(grad_input2.size())) 74 | grad_input2.add_(buff) 75 | 76 | if self.needs_input_grad[2]: 77 | # accumulate parameter gradients: 78 | for k in range(weight.size(0)): 79 | torch.mul(input1, grad_output.narrow(1, k, 1).expand_as(input1), out=buff) 80 | grad_weight = torch.mm(buff.t(), input2) 81 | 82 | if bias is not None and self.needs_input_grad[3]: 83 | grad_bias = grad_output.sum(0) 84 | 85 | return grad_input1, grad_input2, grad_weight, grad_bias 86 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from state_model_10 import * 2 | from object_model_10 import * 3 | from text_model import * 4 | from goal_model import * 5 | from compositor_model import * 6 | from constraint_factorizer import * 7 | from model_factorizer import * 8 | from values_factorized import * 9 | from map_model import * 10 | 11 | ## attention 12 | from lookup_model import * 13 | from lookup_location import * 14 | from attention_model import * 15 | from attention_direct import * 16 | from attention_heatmap import * 17 | 18 | ## global 19 | from attention_global import * 20 | from multi_global import * 21 | from multi_nonsep import * 22 | from multi_nobases import * 23 | from multi_norbf import * 24 | from multi_nocnn import * 25 | from cnn_lstm import * 26 | 27 | ## full map (instead of single value) 28 | from multi_model import * 29 | from simple_conv import * 30 | from conv_to_vector import * 31 | 32 | ## uvfa 33 | from uvfa_pos import * 34 | # from uvfa_3 import * 35 | from uvfa_text import * 36 | from mlp import * 37 | 38 | from initializations import * -------------------------------------------------------------------------------- /models/attention_direct.py: -------------------------------------------------------------------------------- 1 | ## attention model with only a single convolution 2 | ## LSTM kernel isn't really used as attention map, 3 | ## but just as a kernel for convolution 4 | 5 | import torch 6 | import math, torch.nn as nn 7 | import torch.nn.functional as F 8 | import utils 9 | 10 | class AttentionDirect(nn.Module): 11 | # args.lstm_out, args.goal_hid, args.rank, args.obj_embed 12 | def __init__(self, 
text_model, object_model, args, final_hidden = 20, map_dim = 10): 13 | super(AttentionDirect, self).__init__() 14 | 15 | self.text_model = text_model 16 | self.object_model = object_model 17 | 18 | self.embed_dim = args.obj_embed 19 | self.kernel_out_dim = args.attention_out_dim 20 | self.kernel_size = args.attention_kernel 21 | 22 | self.conv_custom = utils.ConvKernel(self.embed_dim, self.kernel_out_dim, self.kernel_size, bias=False) 23 | 24 | self.reshape_dim = self.kernel_out_dim * (map_dim-self.kernel_size+1)**2 25 | self.fc1 = nn.Linear(self.reshape_dim, args.rank) 26 | 27 | def __conv(self, inp, kernel): 28 | batch_size = inp.size(0) 29 | out = [ self.conv_custom(inp[i].unsqueeze(0), kernel[i]) for i in range(batch_size) ] 30 | out = torch.cat(out, 0) 31 | return out 32 | 33 | def forward(self, inp): 34 | (obj, text) = inp 35 | batch_size = obj.size(0) 36 | text = text.transpose(0,1) 37 | hidden = self.text_model.init_hidden(batch_size) 38 | 39 | embeddings = self.object_model.forward(obj) 40 | 41 | lstm_out = self.text_model.forward(text, hidden) 42 | lstm_out = lstm_out.view(-1, self.kernel_out_dim, self.embed_dim, self.kernel_size, self.kernel_size) 43 | 44 | conv = self.__conv(embeddings, lstm_out) 45 | # print conv.size() 46 | conv = conv.view(-1, self.reshape_dim) 47 | # print conv.size() 48 | 49 | out = F.relu(self.fc1(conv)) 50 | # print out.size() 51 | # out = F.relu(self.fc2(out)) 52 | 53 | return out 54 | 55 | 56 | if __name__ == '__main__': 57 | from text_model import * 58 | from object_model import * 59 | from lookup_model import * 60 | 61 | batch = 2 62 | seq = 10 63 | 64 | text_vocab = 10 65 | lstm_inp = 5 66 | lstm_hid = 3 67 | lstm_layers = 1 68 | 69 | # obj_inp = torch.LongTensor(1,10,10).zero_() 70 | obj_vocab = 3 71 | emb_dim = 3 72 | 73 | concat_dim = 27 74 | hidden_dim = 5 75 | out_dim = 7 76 | 77 | text_model = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, concat_dim) 78 | object_model = LookupModel(obj_vocab, emb_dim, concat_dim) 79 | 80 | psi = AttentionModel(text_model, object_model, concat_dim, hidden_dim, out_dim) 81 | 82 | hidden = text_model.init_hidden(batch) 83 | text_inp = Variable(torch.floor(torch.rand(batch,seq)*text_vocab).long()) 84 | obj_inp = Variable(torch.floor(torch.rand(batch,1,10,10)*obj_vocab).long()) 85 | 86 | to = psi.forward( (obj_inp, text_inp) ) 87 | print to.size() 88 | 89 | # # inp = Variable(torch.LongTensor((1,2,3))) 90 | # print 'INPUT: ', text_inp 91 | # print text_inp.size() 92 | # out = text_model.forward( text_inp, hidden ) 93 | 94 | # print 'OUT: ', out 95 | # print out.size() 96 | # # print 'HID: ', hid 97 | 98 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_()) 99 | # obj_inp = Variable(obj_inp) 100 | # obj_out = object_model.forward(obj_inp) 101 | 102 | # print obj_out.data.size() 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /models/attention_global.py: -------------------------------------------------------------------------------- 1 | ## attention model with only a single convolution 2 | ## LSTM kernel isn't really used as attention map, 3 | ## but just as a kernel for convolution 4 | 5 | import math, pdb 6 | import torch, torch.nn as nn, torch.nn.functional as F 7 | import custom 8 | 9 | class AttentionGlobal(nn.Module): 10 | # args.lstm_out, args.goal_hid, args.rank, args.obj_embed 11 | def __init__(self, text_model, args, map_dim = 10): 12 | super(AttentionGlobal, self).__init__() 13 | 14 | assert 
args.attention_kernel % 2 == 1 15 | 16 | self.text_model = text_model 17 | # self.object_model = object_model 18 | 19 | self.embed_dim = args.attention_in_dim 20 | self.kernel_out_dim = args.attention_out_dim 21 | self.kernel_size = args.attention_kernel 22 | self.global_coeffs = args.global_coeffs 23 | 24 | padding = int(math.ceil(self.kernel_size/2.)) - 1 25 | self.conv_custom = custom.ConvKernel(self.embed_dim, self.kernel_out_dim, self.kernel_size, bias=False, padding=padding) 26 | 27 | self.reshape_dim = self.kernel_out_dim * (map_dim-self.kernel_size+1)**2 28 | 29 | def __conv(self, inp, kernel): 30 | batch_size = inp.size(0) 31 | out = [ self.conv_custom(inp[i].unsqueeze(0), kernel[i]) for i in range(batch_size) ] 32 | out = torch.cat(out, 0) 33 | return out 34 | 35 | def forward(self, inp): 36 | (embeddings, text) = inp 37 | batch_size = embeddings.size(0) 38 | text = text.transpose(0,1) 39 | hidden = self.text_model.init_hidden(batch_size) 40 | 41 | # embeddings = self.object_model.forward(obj) 42 | # print embeddings.size() 43 | 44 | lstm_out = self.text_model.forward(text, hidden) 45 | lstm_kernel = lstm_out[:,:-self.global_coeffs].contiguous() 46 | lstm_kernel = lstm_kernel.view(-1, self.kernel_out_dim, self.embed_dim, self.kernel_size, self.kernel_size) 47 | # print 'LSTM_OUT: ', lstm_out.size() 48 | # print 'EMBEDDINGS: ', embeddings.size() 49 | # print self.kernel_size/2 - 1, self.kernel_size 50 | local_heatmap = self.__conv(embeddings, lstm_kernel) 51 | 52 | ## sum along attention_out_dim 53 | ## < batch x attention_out_dim x map_dim x map_dim > 54 | ## < batch x 1 x map_dim x map_dim > 55 | local_heatmap = local_heatmap.sum(1, keepdim=True) 56 | 57 | lstm_global = lstm_out[:,-self.global_coeffs:] 58 | # global_heatmap = self._global(lstm_global) 59 | 60 | # out = local_heatmap + global_heatmap 61 | # print conv.size() 62 | # conv = conv.view(-1, self.reshape_dim) 63 | 64 | ## save outputs for kernel visualization 65 | self.output_local = local_heatmap 66 | self.output_global = lstm_global 67 | 68 | return local_heatmap, lstm_global 69 | 70 | 71 | if __name__ == '__main__': 72 | import argparse 73 | from text_model import * 74 | from object_model import * 75 | from lookup_model import * 76 | 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('--attention_kernel', type=int, default=3) 79 | parser.add_argument('--attention_out_dim', type=int, default=3) 80 | parser.add_argument('--obj_embed', type=int, default=5) 81 | parser.add_argument('--map_dim', type=int, default=10) 82 | args = parser.parse_args() 83 | 84 | batch = 2 85 | seq = 10 86 | 87 | text_vocab = 10 88 | lstm_inp = 5 89 | lstm_hid = 3 90 | lstm_layers = 1 91 | 92 | # obj_inp = torch.LongTensor(1,10,10).zero_() 93 | obj_vocab = 3 94 | # emb_dim = 3 95 | 96 | lstm_out = args.obj_embed * args.attention_out_dim * args.attention_kernel**2 97 | # hidden_dim = 5 98 | # out_dim = 7 99 | 100 | text_model = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, lstm_out) 101 | object_model = LookupModel(obj_vocab, args.obj_embed) 102 | 103 | psi = AttentionHeatmap(text_model, object_model, args, map_dim = args.map_dim) 104 | 105 | hidden = text_model.init_hidden(batch) 106 | text_inp = Variable(torch.floor(torch.rand(batch,seq)*text_vocab).long()) 107 | obj_inp = Variable(torch.floor(torch.rand(batch,1,10,10)*obj_vocab).long()) 108 | 109 | to = psi.forward( (obj_inp, text_inp) ) 110 | print to.size() 111 | 112 | # # inp = Variable(torch.LongTensor((1,2,3))) 113 | # print 'INPUT: ', text_inp 114 | # print 
text_inp.size() 115 | # out = text_model.forward( text_inp, hidden ) 116 | 117 | # print 'OUT: ', out 118 | # print out.size() 119 | # # print 'HID: ', hid 120 | 121 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_()) 122 | # obj_inp = Variable(obj_inp) 123 | # obj_out = object_model.forward(obj_inp) 124 | 125 | # print obj_out.data.size() 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /models/attention_heatmap.py: -------------------------------------------------------------------------------- 1 | ## attention model with only a single convolution 2 | ## LSTM kernel isn't really used as attention map, 3 | ## but just as a kernel for convolution 4 | 5 | import torch 6 | import math, torch.nn as nn 7 | import torch.nn.functional as F 8 | import custom 9 | 10 | class AttentionHeatmap(nn.Module): 11 | # args.lstm_out, args.goal_hid, args.rank, args.obj_embed 12 | def __init__(self, text_model, args, map_dim = 10): 13 | super(AttentionHeatmap, self).__init__() 14 | 15 | assert args.attention_kernel % 2 == 1 16 | 17 | self.text_model = text_model 18 | 19 | self.embed_dim = args.attention_in_dim 20 | self.kernel_out_dim = args.attention_out_dim 21 | self.kernel_size = args.attention_kernel 22 | 23 | padding = int(math.ceil(self.kernel_size/2.)) - 1 24 | self.conv_custom = custom.ConvKernel(self.embed_dim, self.kernel_out_dim, self.kernel_size, bias=False, padding=padding) 25 | 26 | self.reshape_dim = self.kernel_out_dim * (map_dim-self.kernel_size+1)**2 27 | 28 | def __conv(self, inp, kernel): 29 | batch_size = inp.size(0) 30 | out = [ self.conv_custom(inp[i].unsqueeze(0), kernel[i]) for i in range(batch_size) ] 31 | out = torch.cat(out, 0) 32 | return out 33 | 34 | def forward(self, inp): 35 | (embeddings, text) = inp 36 | batch_size = embeddings.size(0) 37 | text = text.transpose(0,1) 38 | hidden = self.text_model.init_hidden(batch_size) 39 | 40 | lstm_out = self.text_model.forward(text, hidden) 41 | lstm_kernel = lstm_out.view(-1, self.kernel_out_dim, self.embed_dim, self.kernel_size, self.kernel_size) 42 | # print 'LSTM_OUT: ', lstm_out.size() 43 | # print 'EMBEDDINGS: ', embeddings.size() 44 | # print self.kernel_size/2 - 1, self.kernel_size 45 | local_heatmap = self.__conv(embeddings, lstm_kernel) 46 | 47 | ## sum along attention_out_dim 48 | ## < batch x attention_out_dim x map_dim x map_dim > 49 | ## < batch x 1 x map_dim x map_dim > 50 | local_heatmap = local_heatmap.sum(1, keepdim=True) 51 | # print conv.size() 52 | # conv = conv.view(-1, self.reshape_dim) 53 | 54 | return local_heatmap 55 | 56 | 57 | if __name__ == '__main__': 58 | import argparse 59 | from text_model import * 60 | from object_model import * 61 | from lookup_model import * 62 | 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('--attention_kernel', type=int, default=3) 65 | parser.add_argument('--attention_out_dim', type=int, default=3) 66 | parser.add_argument('--obj_embed', type=int, default=5) 67 | parser.add_argument('--map_dim', type=int, default=10) 68 | args = parser.parse_args() 69 | 70 | batch = 2 71 | seq = 10 72 | 73 | text_vocab = 10 74 | lstm_inp = 5 75 | lstm_hid = 3 76 | lstm_layers = 1 77 | 78 | # obj_inp = torch.LongTensor(1,10,10).zero_() 79 | obj_vocab = 3 80 | # emb_dim = 3 81 | 82 | lstm_out = args.obj_embed * args.attention_out_dim * args.attention_kernel**2 83 | # hidden_dim = 5 84 | # out_dim = 7 85 | 86 | text_model = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, lstm_out) 87 
| object_model = LookupModel(obj_vocab, args.obj_embed) 88 | 89 | psi = AttentionHeatmap(text_model, object_model, args, map_dim = args.map_dim) 90 | 91 | hidden = text_model.init_hidden(batch) 92 | text_inp = Variable(torch.floor(torch.rand(batch,seq)*text_vocab).long()) 93 | obj_inp = Variable(torch.floor(torch.rand(batch,1,10,10)*obj_vocab).long()) 94 | 95 | to = psi.forward( (obj_inp, text_inp) ) 96 | print to.size() 97 | 98 | # # inp = Variable(torch.LongTensor((1,2,3))) 99 | # print 'INPUT: ', text_inp 100 | # print text_inp.size() 101 | # out = text_model.forward( text_inp, hidden ) 102 | 103 | # print 'OUT: ', out 104 | # print out.size() 105 | # # print 'HID: ', hid 106 | 107 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_()) 108 | # obj_inp = Variable(obj_inp) 109 | # obj_out = object_model.forward(obj_inp) 110 | 111 | # print obj_out.data.size() 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /models/attention_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math, torch.nn as nn 3 | import torch.nn.functional as F 4 | import utils 5 | 6 | class AttentionModel(nn.Module): 7 | 8 | def __init__(self, text_model, object_model, args, final_hidden = 20, map_dim = 10): 9 | super(AttentionModel, self).__init__() 10 | 11 | assert args.attention_kernel % 2 == 1 12 | 13 | self.text_model = text_model 14 | self.object_model = object_model 15 | 16 | self.embed_dim = args.obj_embed 17 | self.kernel_out_dim = args.attention_out_dim 18 | self.kernel_size = args.attention_kernel 19 | 20 | self.conv_custom = utils.ConvKernel(self.embed_dim, self.kernel_out_dim, self.kernel_size, bias=False, padding=1) 21 | 22 | ## final_hidden and fc1 are hard-coded 23 | ## should be bash arg / inferred 24 | self.conv1_kernel = 3 25 | self.conv1 = nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=self.conv1_kernel, padding=1) 26 | # self.conv2 = nn.Conv2d(3, 6, kernel_size=5) 27 | # self.conv3 = nn.Conv2d(6,12, kernel_size=5) 28 | self.reshape_dim = self.embed_dim * (map_dim)**2 29 | self.fc1 = nn.Linear(self.reshape_dim, final_hidden) 30 | self.fc2 = nn.Linear(final_hidden, args.rank) 31 | 32 | def __conv(self, inp, kernel): 33 | # print '__conv: ', inp.size(), kernel.size() 34 | batch_size = inp.size(0) 35 | out = [ self.conv_custom(inp[i].unsqueeze(0), kernel[i]) for i in range(batch_size) ] 36 | out = torch.cat(out, 0) 37 | return out 38 | 39 | def forward(self, inp): 40 | (obj, text) = inp 41 | batch_size = obj.size(0) 42 | text = text.transpose(0,1) 43 | hidden = self.text_model.init_hidden(batch_size) 44 | 45 | embeddings = self.object_model.forward(obj) 46 | 47 | lstm_out = self.text_model.forward(text, hidden) 48 | lstm_out = lstm_out.view(-1, 1, self.embed_dim, self.kernel_size, self.kernel_size) 49 | # print 'lstm: ', lstm_out.size() 50 | 51 | ## first convolve the object embeddings 52 | ## to populate the states neighboring objects 53 | ## (otherwise they would be all 0's) 54 | conv = F.relu(self.conv1(embeddings)) 55 | # print 'conv: ', conv.size() 56 | ## get attention map 57 | attention = self.__conv(conv, lstm_out) 58 | # print 'attention: ', attention.size() 59 | # attention = self.__conv(embeddings, lstm_out) 60 | attention = attention.repeat(1,self.embed_dim,1,1) 61 | # print 'conv: ', conv.size() 62 | attended = attention * conv 63 | attended = attended.view(-1, self.reshape_dim) 64 | # print 'attended: ', attended.size() 65 | 
out = F.relu(self.fc1(attended)) 66 | out = F.relu(self.fc2(out)) 67 | 68 | # out = F.relu(self.conv1(attended)) 69 | # out = F.relu(self.conv2(out)) 70 | # out = F.relu(self.conv3(out)) 71 | # out = out.view(-1, 12*2*2) 72 | # out = F.relu(self.fc1(out)) 73 | # out = self.fc2(out) 74 | 75 | return out 76 | 77 | 78 | if __name__ == '__main__': 79 | from text_model import * 80 | from object_model import * 81 | from lookup_model import * 82 | 83 | batch = 2 84 | seq = 10 85 | 86 | text_vocab = 10 87 | lstm_inp = 5 88 | lstm_hid = 3 89 | lstm_layers = 1 90 | 91 | # obj_inp = torch.LongTensor(1,10,10).zero_() 92 | obj_vocab = 3 93 | emb_dim = 3 94 | 95 | concat_dim = 27 96 | hidden_dim = 5 97 | out_dim = 7 98 | 99 | text_model = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, concat_dim) 100 | object_model = LookupModel(obj_vocab, emb_dim, concat_dim) 101 | 102 | psi = AttentionModel(text_model, object_model, concat_dim, hidden_dim, out_dim) 103 | 104 | hidden = text_model.init_hidden(batch) 105 | text_inp = Variable(torch.floor(torch.rand(batch,seq)*text_vocab).long()) 106 | obj_inp = Variable(torch.floor(torch.rand(batch,1,10,10)*obj_vocab).long()) 107 | 108 | to = psi.forward( (obj_inp, text_inp) ) 109 | print to.size() 110 | 111 | # # inp = Variable(torch.LongTensor((1,2,3))) 112 | # print 'INPUT: ', text_inp 113 | # print text_inp.size() 114 | # out = text_model.forward( text_inp, hidden ) 115 | 116 | # print 'OUT: ', out 117 | # print out.size() 118 | # # print 'HID: ', hid 119 | 120 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_()) 121 | # obj_inp = Variable(obj_inp) 122 | # obj_out = object_model.forward(obj_inp) 123 | 124 | # print obj_out.data.size() 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /models/cnn_lstm.py: -------------------------------------------------------------------------------- 1 | ## attention model with only a single convolution 2 | ## LSTM kernel isn't really used as attention map, 3 | ## but just as a kernel for convolution 4 | 5 | import torch 6 | import math, torch.nn as nn, pdb 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | import models, utils 10 | 11 | class CNN_LSTM(nn.Module): 12 | # args.lstm_out, args.goal_hid, args.rank, args.obj_embed 13 | def __init__(self, state_model, object_model, lstm, args, map_dim = 10, batch_size = 32): 14 | super(CNN_LSTM, self).__init__() 15 | 16 | self.state_model = state_model 17 | self.object_model = object_model 18 | 19 | self.cnn_inp_dim = args.obj_embed + args.state_embed + 1 20 | self.cnn = models.ConvToVector(self.cnn_inp_dim) 21 | self.lstm = lstm 22 | 23 | self.fc1 = nn.Linear(args.cnn_out_dim, 16) 24 | self.fc2 = nn.Linear(16, 1) 25 | 26 | # self.state_vocab = args.state_embed 27 | # self.object_vocab = args.obj_embed 28 | 29 | self.state_dim = args.state_embed 30 | self.object_dim = args.obj_embed 31 | 32 | self.map_dim = map_dim 33 | self.batch_size = batch_size 34 | self.positions = self.__agent_pos_2d() 35 | 36 | # ''' 37 | # returns tensor with one-hot vector encoding 38 | # [1, 2, 3, ..., map_dim] repeated batch_size times 39 | # < batch_size * map_dim, state_vocab > 40 | # ''' 41 | # def __agent_pos(self): 42 | # size = self.map_dim**2 43 | # positions = torch.zeros(self.batch_size*size, 100, 1) 44 | # # print positions.size() 45 | # for ind in range(size): 46 | # # print ind, ind*self.batch_size, (ind+1)*self.batch_size, ind, positions.size() 47 | # # 
positions[ind*self.batch_size:(ind+1)*self.batch_size, ind] = 1 48 | # positions[ind:self.batch_size*size:size, ind] = 1 49 | # # pdb.set_trace() 50 | # return Variable( positions.cuda() ) 51 | 52 | def __agent_pos_2d(self): 53 | ## < 10 x 10 > 54 | positions = torch.zeros(self.map_dim**2, self.map_dim**2) 55 | for i in range(self.map_dim**2): 56 | positions[i][i] = 1 57 | 58 | ## < 100 x 10 x 10 > 59 | positions = positions.view(self.map_dim**2, self.map_dim, self.map_dim) 60 | ## < 100 x 1 x 10 x 10 > 61 | ## < 100*batch x 1 x 10 x 10 > 62 | positions = positions.unsqueeze(1).repeat(self.batch_size,1,1,1) 63 | return Variable( positions.cuda() ) 64 | 65 | 66 | def __repeat_position(self, x): 67 | # print 'X: ', x.size() 68 | if x.size() == 2: 69 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1) 70 | elif x.size() == 3: 71 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1,1) 72 | else: 73 | return x.repeat(1,self.map_dim**2,1,1) 74 | 75 | # ''' 76 | # < batch_size x N > 77 | # < batch_size*100 x N > 78 | # ''' 79 | # def __construct_inp(self, state, obj, text): 80 | # state = self.__repeat_position(state) 81 | # # pdb.set_trace() 82 | # state = state.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.state_dim) 83 | # ## add agent position 84 | # state = torch.cat( (state, self.positions), -1) 85 | # ## reshape to (batched) vector for input to MLPs 86 | # state = state.view(-1, self.state_dim+1) 87 | 88 | # obj = self.__repeat_position(obj) 89 | # obj = obj.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.object_dim) 90 | # obj = obj.view(-1, self.object_dim) 91 | 92 | # instr_length = text.size(1) 93 | # ## < batch x length > 94 | # ## < batch x 100 x length > 95 | # text = self.__repeat_position(text) 96 | # ## < batch*100 x length > 97 | # text = text.view(self.batch_size*self.map_dim**2,instr_length) 98 | # ## < length x batch*100 > 99 | # text = text.transpose(0,1) 100 | # ## < batch*100 x rank > 101 | 102 | # return state, obj, text 103 | 104 | def forward(self, inp): 105 | (state, obj, text) = inp 106 | batch_size = state.size(0) 107 | 108 | if batch_size != self.batch_size: 109 | self.batch_size = batch_size 110 | self.positions = self.__agent_pos_2d() 111 | 112 | ## < batch x layout x 10 x 10 > 113 | state_out = self.state_model(state) 114 | ## < batch x object x 10 x 10 > 115 | obj_out = self.object_model.forward(obj) 116 | 117 | ## < batch x layout+object x 10 x 10 > 118 | embeddings = torch.cat( (state_out, obj_out), 1 ) 119 | ## < batch x (layout+object)*100 x 10 x 10 > 120 | ## < batch*100 x layout+object x 10 x 10 > 121 | embeddings = self.__repeat_position(embeddings).view(self.batch_size*self.map_dim**2,self.cnn_inp_dim-1,self.map_dim,self.map_dim) 122 | 123 | ## < batch*100, layout+object+agent, 10, 10 > 124 | ## state + object embeddings + agent position 125 | concat = torch.cat( (embeddings, self.positions), 1) 126 | ## < batch*100 x 1 x 4 x 4 > 127 | ## < batch*100 x 16 > 128 | cnn_out = self.cnn( concat ).view(batch_size*self.map_dim**2, -1) 129 | 130 | 131 | instr_length = text.size(1) 132 | ## < batch x length > 133 | ## < batch x 100 x length > 134 | text = self.__repeat_position(text) 135 | ## < batch*100 x length > 136 | text = text.view(self.batch_size*self.map_dim**2,instr_length) 137 | ## < length x batch*100 > 138 | text = text.transpose(0,1) 139 | hidden = self.lstm.init_hidden(self.batch_size*self.map_dim**2) 140 | ## < batch*100 x rank > 141 | lstm_out = self.lstm.forward(text, hidden) 142 | 143 | concat = torch.cat( (cnn_out, 
lstm_out), 1 ) 144 | 145 | out = F.relu(self.fc1(concat)) 146 | out = self.fc2(out) 147 | 148 | map_pred = out.view(self.batch_size,self.map_dim,self.map_dim) 149 | 150 | return map_pred 151 | 152 | 153 | if __name__ == '__main__': 154 | from text_model import * 155 | from object_model import * 156 | from lookup_model import * 157 | 158 | batch = 2 159 | seq = 10 160 | 161 | text_vocab = 10 162 | lstm_inp = 5 163 | lstm_hid = 3 164 | lstm_layers = 1 165 | 166 | # obj_inp = torch.LongTensor(1,10,10).zero_() 167 | obj_vocab = 3 168 | emb_dim = 3 169 | 170 | concat_dim = 27 171 | hidden_dim = 5 172 | out_dim = 7 173 | 174 | text_model = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, concat_dim) 175 | object_model = LookupModel(obj_vocab, emb_dim, concat_dim) 176 | 177 | psi = AttentionModel(text_model, object_model, concat_dim, hidden_dim, out_dim) 178 | 179 | hidden = text_model.init_hidden(batch) 180 | text_inp = Variable(torch.floor(torch.rand(batch,seq)*text_vocab).long()) 181 | obj_inp = Variable(torch.floor(torch.rand(batch,1,10,10)*obj_vocab).long()) 182 | 183 | to = psi.forward( (obj_inp, text_inp) ) 184 | print to.size() 185 | 186 | # # inp = Variable(torch.LongTensor((1,2,3))) 187 | # print 'INPUT: ', text_inp 188 | # print text_inp.size() 189 | # out = text_model.forward( text_inp, hidden ) 190 | 191 | # print 'OUT: ', out 192 | # print out.size() 193 | # # print 'HID: ', hid 194 | 195 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_()) 196 | # obj_inp = Variable(obj_inp) 197 | # obj_out = object_model.forward(obj_inp) 198 | 199 | # print obj_out.data.size() 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | -------------------------------------------------------------------------------- /models/compositor_model.py: -------------------------------------------------------------------------------- 1 | import torch, torch.nn as nn 2 | 3 | class CompositorModel(nn.Module): 4 | 5 | def __init__(self, state_model, goal_model): 6 | super(CompositorModel, self).__init__() 7 | 8 | self.state_model = state_model 9 | self.goal_model = goal_model 10 | 11 | 12 | def forward(self, state, objects, instructions): 13 | state_embedding = self.state_model.forward(state) 14 | goal_embedding = self.goal_model.forward( (objects, instructions) ) 15 | 16 | ## num_states x rank 17 | num_states = state_embedding.size(0) 18 | ## num_goals x rank 19 | num_goals = goal_embedding.size(0) 20 | 21 | ## num_states x num_goals x rank 22 | state_rep = state_embedding.unsqueeze(1).repeat(1,num_goals,1) 23 | goal_rep = goal_embedding.repeat(num_states,1,1) 24 | 25 | values = state_rep * goal_rep 26 | values = values.sum(2, keepdim=True).squeeze() 27 | values = values.transpose(0,1) 28 | 29 | return values 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /models/constraint_factorizer.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | from torch.nn.modules.utils import _pair 8 | import pdb, pickle 9 | 10 | class ConstraintFactorizer(nn.Module): 11 | def __init__(self, sparse_value_mat, rank, dissimilarity_lambda, world_lambda, location_lambda): 12 | super(ConstraintFactorizer, self).__init__() 13 | self.M, self.N = sparse_value_mat.shape 14 | self.mat = Variable( 
torch.Tensor(sparse_value_mat).cuda() ) 15 | self.mask = (self.mat == 0) 16 | self.rank = rank 17 | # print self.M, self.N 18 | self.M_embed = nn.Embedding(self.M, self.rank) 19 | self.N_embed = nn.Embedding(self.N, self.rank) 20 | self.M_inp = Variable( torch.range(0, self.M-1).cuda().long() ) 21 | self.N_inp = Variable( torch.range(0, self.N-1).cuda().long() ) 22 | self.dim = int(math.sqrt(self.N)) 23 | 24 | # self.conv_kernel = torch.ones(1,self.rank,3,3)/8./self.rank 25 | # self.conv_kernel[:,:,1,1] = -1./self.rank 26 | self.conv_kernel = Variable( self.avg_conv(self.rank).cuda() ) 27 | self.conv_bias = None 28 | self.stride = _pair(1) 29 | self.padding = _pair(1) 30 | 31 | ######## similarity loss on states 32 | # self.conv_vector = self.avg_vector(self.rank) 33 | ######## 34 | 35 | # print 'BEFORE PARAMETER:' 36 | # print self.conv_kernel 37 | # pdb.set_trace() 38 | # self.param = self.conv_kernel 39 | # print 'CONV KERNEL:' 40 | # print self.param 41 | # self.conv_kernel = self.conv_kernel.cuda() 42 | # self.conv_bias = self.conv_bias.cuda() 43 | # self.conv = nn.Conv2d(self.rank,1,kernel_size=3,padding=1) 44 | 45 | self.dissimilarity_lambda = dissimilarity_lambda 46 | self.world_lambda = world_lambda 47 | self.location_lambda = location_lambda 48 | 49 | # print self.mat 50 | # print self.mask 51 | 52 | def __lookup(self): 53 | rows = self.M_embed(self.M_inp) 54 | columns = self.N_embed(self.N_inp) 55 | return rows, columns 56 | 57 | # def __reset_conv(self): 58 | # # print self.conv_kernel 59 | # self.conv.weight = nn.Parameter(self.conv_kernel) 60 | # self.conv.bias = nn.Parameter(self.conv_bias) 61 | 62 | ## batch x 5 x 3 63 | def forward(self, x): 64 | # self.__reset_conv() 65 | # print self.conv.weight 66 | # print self.conv.bias 67 | ## batch x 5 x 1 68 | rows, columns = self.__lookup() 69 | out = torch.mm(rows, columns.t()) 70 | # print out 71 | out[self.mask] = 0 72 | 73 | # print out 74 | # print self.mat 75 | diff = torch.pow(out - self.mat, 2) 76 | mse = diff.sum() 77 | 78 | ## 1 x M x N x rank 79 | layout = columns.view(self.dim,self.dim,self.rank).unsqueeze(0) 80 | ## 1 x rank x N x M 81 | layout = layout.transpose(1,3) 82 | ## 1 x rank x M x N 83 | layout = layout.transpose(2,3) 84 | 85 | average = F.conv2d(layout,self.conv_kernel,self.conv_bias,self.stride,self.padding) 86 | divergence_penalty = torch.pow(layout - average, 2).sum() 87 | 88 | # print 'divergenec: ', divergence_penalty.size() 89 | # print 'layout: ', layout.size() 90 | # conv = self.conv(layout) 91 | # divergence_penalty = conv.sum() 92 | # print 'conv: ', conv.size() 93 | # pdb.set_trace() 94 | self.mse = mse.data[0] 95 | self.divergence = divergence_penalty.data[0] 96 | 97 | loss = mse + self.dissimilarity_lambda * divergence_penalty 98 | 99 | ######## similarity loss on states 100 | ## worlds x state_size x rank 101 | states = rows.view(self.M / self.N, self.N, self.rank) 102 | 103 | world_avg = states.mean(1).repeat(1,self.N,1) 104 | location_avg = states.mean(0).repeat(self.M/self.N,1,1) 105 | # pdb.set_trace() 106 | world_mse = torch.pow(states - world_avg, 2).sum() 107 | location_mse = torch.pow(states - location_avg, 2).sum() 108 | self.world_mse = world_mse.data[0] 109 | self.location_mse = location_mse.data[0] 110 | 111 | loss += (self.world_lambda * world_mse) + (self.location_lambda * location_mse) 112 | ## rank x state_size x worlds 113 | # states = states.transpose(0,2) 114 | ## rank x worlds x state_size 115 | # states = states.transpose(1,2) 116 | # F.conv2d(states, 
self.conv_vector,self.conv_bias,self.stride,self.padding) 117 | ######## 118 | 119 | 120 | # print rows.size(), columns.size(), out.size() 121 | # print out 122 | # print out 123 | # print loss 124 | 125 | return loss 126 | 127 | def embeddings(self): 128 | return self.__lookup() 129 | 130 | def avg_conv(self, out_dim): 131 | kernel = torch.zeros(out_dim,out_dim,3,3) 132 | for i in range(out_dim): 133 | kernel[i][i] = 1./8 134 | kernel[i][i][1][1] = 0 135 | return kernel 136 | 137 | # def avg_vector(self, out_dim): 138 | # kernel = torch.zeros(out_dim,out_dim,1,3) 139 | # for i in range(out_dim): 140 | # kernel[i][i][0] = torch.Tensor((1,0,1))/2. 141 | # return kernel 142 | 143 | def train(self, lr, iters): 144 | optimizer = optim.Adam(self.parameters(), lr=lr) 145 | 146 | t = trange(iters) 147 | for i in t: 148 | optimizer.zero_grad() 149 | loss = self.forward( () ) 150 | # print loss.data[0] 151 | t.set_description( '%.3f | %.3f | %.3f | %.3f' % (self.mse, self.divergence, self.world_mse, self.location_mse) ) 152 | loss.backward() 153 | optimizer.step() 154 | 155 | U, V = self.__lookup() 156 | recon = torch.mm(U, V.t()) 157 | # print U, V, recon 158 | U = U.data.cpu().numpy() 159 | V = V.data.cpu().numpy() 160 | recon = recon.data.cpu().numpy() 161 | return U, V, recon 162 | 163 | def avg_conv(out_dim): 164 | kernel = torch.zeros(out_dim,out_dim,3,3) 165 | for i in range(out_dim): 166 | kernel[i][i] = 1./8 167 | kernel[i][i][1][1] = 0 168 | return kernel 169 | 170 | def avg_vector(out_dim): 171 | kernel = torch.zeros(out_dim,out_dim,1,3) 172 | for i in range(out_dim): 173 | kernel[i][i][0] = torch.Tensor((1,0,1))/2. 174 | return kernel 175 | 176 | if __name__ == '__main__': 177 | # from torch.nn.modules.utils import _pair 178 | # rank = 4 179 | # # kernel = Variable(avg_conv(rank)) 180 | # kernel = Variable( avg_vector(rank) ) 181 | 182 | # inp = Variable(torch.randn(1,rank,20,400)) 183 | # bias = None 184 | # stride = _pair(1) 185 | # padding = _pair(1) 186 | # conv = F.conv2d(inp,kernel,bias,stride,padding)[:,:,1:-1,:] 187 | # print 'inp:', inp.size() 188 | # print 'kernel:', kernel.size() 189 | # print 'conv:', conv.size() 190 | 191 | # pdb.set_trace() 192 | 193 | # from torch.nn.modules.utils import _pair 194 | # inp = Variable(torch.randn(1,5,8,8)) 195 | # kern = Variable(torch.ones(7,5,3,3)) 196 | # bias = None 197 | # stride = _pair(1) 198 | # padding = _pair(1) 199 | # # dilation = _pair(1) 200 | # # print dilation 201 | # # groups = 1 202 | # # conv = F.conv2d(inp,kern,bias,stride,padding,dilation,groups) 203 | # conv = F.conv2d(inp,kern,bias,stride,padding) 204 | # print 'inp:', inp 205 | # print 'conv: ', conv 206 | # print 'sum: ', inp[:,:,:3,:3].sum() 207 | 208 | # print a.b 209 | # print 'conv: ', conv 210 | value_mat = pickle.load( open('../pickle/value_mat20.p') ) 211 | rank = 10 212 | 213 | print 'value_mat: ', value_mat.shape 214 | dissimilarity_lambda = .1 215 | world_lambda = 0 216 | location_lambda = .001 217 | model = ConstraintFactorizer(value_mat, rank, dissimilarity_lambda, world_lambda, location_lambda).cuda() 218 | lr = 0.001 219 | iters = 500000 220 | U, V, recon = model.train(lr, iters) 221 | 222 | # lr = 0.001 223 | # optimizer = optim.Adam(model.parameters(), lr=lr) 224 | 225 | # loss = model.forward( () ) 226 | # print loss 227 | # iters = 50000 228 | # t = trange(iters) 229 | # for i in t: 230 | # optimizer.zero_grad() 231 | # loss = model.forward( () ) 232 | # # print loss.data[0] 233 | # t.set_description( str(model.mse) + ' ' + str(model.divergence) ) 
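## note: the commented-out loop below mirrors ConstraintFactorizer.train() above (zero the gradients, recompute the masked reconstruction loss, backprop, step the Adam optimizer); kept for reference only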
234 | # loss.backward() 235 | # optimizer.step() 236 | 237 | # U, V = model.embeddings() 238 | 239 | print 'recon' 240 | print recon 241 | print 'true' 242 | print torch.Tensor(value_mat) 243 | 244 | pickle.dump(U, open('../pickle/U_lambda_' + str(dissimilarity_lambda) + '.p', 'w') ) 245 | pickle.dump(V, open('../pickle/V_lambda_' + str(dissimilarity_lambda) + '.p', 'w') ) 246 | # pdb.set_trace() 247 | 248 | 249 | # inp = torch.LongTensor(5,3).zero_() 250 | # vocab_size = 10 251 | # emb_dim = 3 252 | # rank = 10 253 | # phi = ObjectModel(vocab_size, emb_dim, inp.size(), rank) 254 | 255 | # batch = 2 256 | # hidden = phi.init_hidden(batch) 257 | 258 | # # enc = nn.Embedding(10,emb_dim,padding_idx=0) 259 | # inp = torch.LongTensor(batch,5,3).zero_() 260 | # inp[0][0]=1 261 | # inp[0][1]=1 262 | # inp[1][0]=8 263 | # inp[1][1]=8 264 | # print inp 265 | # # inp[0][0][0][0]=1 266 | # # inp[0][1][0][0]=1 267 | # # inp[1][0][0][2]=1 268 | # # print inp 269 | # inp = Variable(inp) 270 | 271 | # out = phi.forward(inp, hidden) 272 | # # print out 273 | # # out = out.view(-1,2,3,3,emb_dim) 274 | # out = out.data 275 | # print out.size() 276 | 277 | # # print out[0][0][0] 278 | # # print out[1][0][0] 279 | 280 | -------------------------------------------------------------------------------- /models/conv_to_vector.py: -------------------------------------------------------------------------------- 1 | import sys, math, pdb 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | ''' 8 | State observations are two-channel images 9 | with 0: puddle, 1: grass, 2: agent 10 | 11 | ''' 12 | 13 | class ConvToVector(nn.Module): 14 | def __init__(self, in_channels, padding=1): 15 | super(ConvToVector, self).__init__() 16 | 17 | self.in_channels = in_channels 18 | 19 | # self.embed = nn.Embedding(vocab_size, in_channels) 20 | self.conv1 = nn.Conv2d(in_channels, 3, kernel_size=3, padding=padding) 21 | self.conv2 = nn.Conv2d(3, 6, kernel_size=3, padding=padding) 22 | self.conv3 = nn.Conv2d(6, 12, kernel_size=3, padding=padding) 23 | # self.conv4 = nn.Conv2d(12,18, kernel_size=2, padding=padding) 24 | # self.conv5 = nn.Conv2d(18,24, kernel_size=2, padding=padding) 25 | # self.conv6 = nn.Conv2d(24,18, kernel_size=2, padding=padding) 26 | # self.conv7 = nn.Conv2d(18,12, kernel_size=2, padding=padding) 27 | self.conv8 = nn.Conv2d(12, 6, kernel_size=3, padding=0) 28 | self.conv9 = nn.Conv2d(6, 3, kernel_size=3, padding=0) 29 | self.conv10 = nn.Conv2d(3, 1, kernel_size=3, padding=0) 30 | # # self.conv4 = nn.Conv2d(12,12, kernel_size=5) 31 | # self.fc1 = nn.Linear(192, out_dim) 32 | 33 | def forward(self, x): 34 | # reshape = [] 35 | # for dim in x.size(): reshape.append(dim) 36 | # reshape.append(self.in_channels) 37 | 38 | # ## reshape to vector 39 | # x = x.view(-1) 40 | # ## get embeddings 41 | # x = self.embed(x) 42 | # ## reshape to batch x channels x M x N x embed_dim 43 | # x = x.view(*reshape) 44 | # ## sum over channels in input 45 | # x = x.sum(1) 46 | # # pdb.set_trace() 47 | # ## reshape to batch x embed_dim x M x N 48 | # ## (treats embedding dims as channels) 49 | # x = x.transpose(1,-1)[:,:,:,:,0] #.squeeze() # 50 | # print 'SIZE:', x.size() 51 | # pdb.set_trace() 52 | x = F.relu(self.conv1(x)) 53 | x = F.relu(self.conv2(x)) 54 | x = F.relu(self.conv3(x)) 55 | # x = F.relu(self.conv4(x)) 56 | # x = F.relu(self.conv5(x)) 57 | # x = F.relu(self.conv6(x)) 58 | # x = F.relu(self.conv7(x)) 59 | x = 
F.relu(self.conv8(x)) 60 | x = F.relu(self.conv9(x)) 61 | x = self.conv10(x) 62 | 63 | # x = x.view(-1, 192) 64 | # x = self.fc1(x) 65 | return x 66 | 67 | 68 | if __name__ == '__main__': 69 | from torch.autograd import Variable 70 | # inp = torch.LongTensor(2,10,10).zero_() 71 | inp = Variable( torch.randn(5,1,10,10) ) 72 | 73 | model = ConvToVector(1) 74 | 75 | out = model(inp) 76 | 77 | print inp.size() 78 | print out.size() 79 | 80 | -------------------------------------------------------------------------------- /models/custom.py: -------------------------------------------------------------------------------- 1 | import math, torch, torch.nn.functional as F, pdb 2 | from torch.autograd import Variable 3 | from torch.nn.parameter import Parameter 4 | 5 | from torch.nn.modules.module import Module 6 | from torch.nn.modules.utils import _single, _pair, _triple 7 | from torch.nn.modules.conv import _ConvNd 8 | 9 | class _ConvNdKernel(Module): 10 | 11 | def __init__(self, in_channels, out_channels, kernel_size, stride, 12 | padding, dilation, transposed, output_padding, groups, bias): 13 | super(_ConvNdKernel, self).__init__() 14 | if in_channels % groups != 0: 15 | raise ValueError('in_channels must be divisible by groups') 16 | if out_channels % groups != 0: 17 | raise ValueError('out_channels must be divisible by groups') 18 | self.in_channels = in_channels 19 | self.out_channels = out_channels 20 | self.kernel_size = kernel_size 21 | self.stride = stride 22 | self.padding = padding 23 | self.dilation = dilation 24 | self.transposed = transposed 25 | self.output_padding = output_padding 26 | self.groups = groups 27 | # if transposed: 28 | # self.weight = Parameter(torch.Tensor( 29 | # in_channels, out_channels // groups, *kernel_size)) 30 | # else: 31 | # self.weight = Parameter(torch.Tensor( 32 | # out_channels, in_channels // groups, *kernel_size)) 33 | if bias: 34 | self.bias = Parameter(torch.Tensor(out_channels)) 35 | else: 36 | self.register_parameter('bias', None) 37 | self.reset_parameters() 38 | 39 | def reset_parameters(self): 40 | n = self.in_channels 41 | for k in self.kernel_size: 42 | n *= k 43 | stdv = 1. 
/ math.sqrt(n) 44 | # self.weight.data.uniform_(-stdv, stdv) 45 | if self.bias is not None: 46 | self.bias.data.uniform_(-stdv, stdv) 47 | 48 | def __repr__(self): 49 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 50 | ', stride={stride}') 51 | if self.padding != (0,) * len(self.padding): 52 | s += ', padding={padding}' 53 | if self.dilation != (1,) * len(self.dilation): 54 | s += ', dilation={dilation}' 55 | if self.output_padding != (0,) * len(self.output_padding): 56 | s += ', output_padding={output_padding}' 57 | if self.groups != 1: 58 | s += ', groups={groups}' 59 | if self.bias is None: 60 | s += ', bias=False' 61 | s += ')' 62 | return s.format(name=self.__class__.__name__, **self.__dict__) 63 | 64 | 65 | class ConvKernel(_ConvNdKernel): 66 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 67 | padding=0, dilation=1, groups=1, bias=True): 68 | kernel_size = _pair(kernel_size) 69 | stride = _pair(stride) 70 | padding = _pair(padding) 71 | dilation = _pair(dilation) 72 | super(ConvKernel, self).__init__( 73 | in_channels, out_channels, kernel_size, stride, padding, dilation, 74 | False, _pair(0), groups, bias) 75 | 76 | def forward(self, input, kernel): 77 | self.weight = Parameter(kernel.data) 78 | # print 'weight: ', self.weight.size() 79 | # print 'bias: ', self.bias.size() 80 | # print 'forward:', type(input.data), type(self.weight.data) 81 | # print 'forward: ', input.size(), self.weight.size() 82 | return F.conv2d(input, kernel, self.bias, self.stride, 83 | self.padding, self.dilation, self.groups) 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /models/goal_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Psi(nn.Module): 6 | 7 | def __init__(self, text_model, object_model, concat_dim, hidden_dim, out_dim): 8 | super(Psi, self).__init__() 9 | 10 | self.text_model = text_model 11 | self.object_model = object_model 12 | self.fc1 = nn.Linear(concat_dim * 2, hidden_dim) 13 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 14 | self.fc3 = nn.Linear(hidden_dim, out_dim) 15 | 16 | 17 | def forward(self, inp): 18 | (obj, text) = inp 19 | batch_size = obj.size(0) 20 | text = text.transpose(0,1) 21 | hidden = self.text_model.init_hidden(batch_size) 22 | # print obj, text, hidden 23 | 24 | obj_out = self.object_model.forward(obj) 25 | # print text 26 | # print 'obj out: ', obj_out.data.size() 27 | # print 'text inp: ', text.data.size() 28 | # print 'hidden: ', hidden.data.size() 29 | text_out = self.text_model.forward(text, hidden) 30 | concat = F.relu( torch.cat((obj_out, text_out), 1) ) 31 | output = F.relu(self.fc1(concat)) 32 | output = F.relu(self.fc2(output)) 33 | output = self.fc3(output) 34 | return output 35 | 36 | 37 | if __name__ == '__main__': 38 | from text_model import * 39 | from object_model import * 40 | 41 | batch = 2 42 | seq = 10 43 | 44 | text_vocab = 10 45 | lstm_inp = 5 46 | lstm_hid = 3 47 | lstm_layers = 2 48 | 49 | obj_inp = torch.LongTensor(1,20,20).zero_() 50 | obj_vocab = 3 51 | emb_dim = 3 52 | 53 | concat_dim = 10 54 | hidden_dim = 5 55 | out_dim = 7 56 | 57 | text_model = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, concat_dim) 58 | object_model = ObjectModel(obj_vocab, emb_dim, obj_inp.size(), concat_dim) 59 | 60 | psi = Psi(text_model, object_model, concat_dim, hidden_dim, out_dim) 61 | 62 | hidden = 
text_model.init_hidden(batch) 63 | text_inp = Variable(torch.floor(torch.rand(seq,batch)*text_vocab).long()) 64 | obj_inp = Variable(torch.floor(torch.rand(batch,1,20,20)*obj_vocab).long()) 65 | 66 | to = psi.forward(obj_inp, text_inp, hidden) 67 | print to 68 | 69 | # # inp = Variable(torch.LongTensor((1,2,3))) 70 | # print 'INPUT: ', text_inp 71 | # print text_inp.size() 72 | # out = text_model.forward( text_inp, hidden ) 73 | 74 | # print 'OUT: ', out 75 | # print out.size() 76 | # # print 'HID: ', hid 77 | 78 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_()) 79 | # obj_inp = Variable(obj_inp) 80 | # obj_out = object_model.forward(obj_inp) 81 | 82 | # print obj_out.data.size() 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /models/initializations.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | ''' 4 | norbf (full), nobases (no gradient), nonsep 5 | uvfa-pos, uvfa-text, cnn+lstm 6 | (nocnn) 7 | 8 | rbf + gradient + cnn: full 9 | no rbf: norbf 10 | no gradient: noglobal 11 | no rbf / gradient: nobases 12 | no cnn: nocnn 13 | ''' 14 | 15 | def init(args, layout_vocab_size, object_vocab_size, text_vocab_size): 16 | if args.model == 'full': ## new 17 | model = init_full(args, layout_vocab_size, object_vocab_size, text_vocab_size) 18 | elif args.model == 'no-gradient': 19 | model = init_nogradient(args, layout_vocab_size, object_vocab_size, text_vocab_size) 20 | elif args.model == 'cnn-lstm': 21 | model = init_cnn_lstm(args, layout_vocab_size, object_vocab_size, text_vocab_size) 22 | elif args.model == 'uvfa-text': 23 | model = init_uvfa_text(args, layout_vocab_size, object_vocab_size, text_vocab_size) 24 | # TODO: clean up UVFA-pos goal loading 25 | elif args.model == 'uvfa-pos': 26 | model = init_uvfa_pos(args, layout_vocab_size, object_vocab_size, text_vocab_size) 27 | train_indices = train_goals 28 | val_indices = val_goals 29 | return model 30 | 31 | 32 | def init_full(args, layout_vocab_size, object_vocab_size, text_vocab_size): 33 | args.global_coeffs = 3 34 | args.attention_in_dim = args.obj_embed 35 | args.lstm_out = args.attention_in_dim * args.attention_out_dim * args.attention_kernel**2 + args.global_coeffs 36 | 37 | state_model = models.LookupModel(layout_vocab_size, args.state_embed).cuda() 38 | object_model = models.LookupModel(object_vocab_size, args.obj_embed) 39 | 40 | text_model = models.TextModel(text_vocab_size, args.lstm_inp, args.lstm_hid, args.lstm_layers, args.lstm_out) 41 | heatmap_model = models.AttentionGlobal(text_model, args, map_dim=args.map_dim).cuda() 42 | 43 | model = models.MultiNoRBF(state_model, object_model, heatmap_model, args, map_dim=args.map_dim).cuda() 44 | return model 45 | 46 | 47 | def init_nogradient(args, layout_vocab_size, object_vocab_size, text_vocab_size): 48 | args.global_coeffs = 0 49 | args.attention_in_dim = args.obj_embed 50 | args.lstm_out = args.attention_in_dim * args.attention_out_dim * args.attention_kernel**2 51 | 52 | state_model = models.LookupModel(layout_vocab_size, args.state_embed).cuda() 53 | object_model = models.LookupModel(object_vocab_size, args.obj_embed) 54 | 55 | text_model = models.TextModel(text_vocab_size, args.lstm_inp, args.lstm_hid, args.lstm_layers, args.lstm_out) 56 | heatmap_model = models.AttentionHeatmap(text_model, args, map_dim=args.map_dim).cuda() 57 | 58 | model = models.MultiNoBases(state_model, object_model, heatmap_model, args, 
map_dim=args.map_dim).cuda() 59 | return model 60 | 61 | 62 | def init_cnn_lstm(args, layout_vocab_size, object_vocab_size, text_vocab_size): 63 | args.lstm_out = 16 64 | args.cnn_out_dim = 2*args.lstm_out 65 | 66 | state_model = models.LookupModel(layout_vocab_size, args.state_embed) 67 | object_model = models.LookupModel(object_vocab_size, args.obj_embed) 68 | 69 | lstm = models.TextModel(text_vocab_size, args.lstm_inp, args.lstm_hid, args.lstm_layers, args.lstm_out) 70 | 71 | model = models.CNN_LSTM(state_model, object_model, lstm, args).cuda() 72 | return model 73 | 74 | 75 | def init_uvfa_text(args, layout_vocab_size, object_vocab_size, text_vocab_size, rank = 7): 76 | print ' Using UVFA variant, consider using a lower learning rate (eg, 0.0001)' 77 | print ' UVFA rank: {} '.format(rank) 78 | 79 | args.rank = rank 80 | args.lstm_out = rank 81 | 82 | text_model = models.TextModel(text_vocab_size, args.lstm_inp, args.lstm_hid, args.lstm_layers, args.lstm_out) 83 | model = models.UVFA_text(text_model, layout_vocab_size, object_vocab_size, args, map_dim=args.map_dim).cuda() 84 | return model 85 | 86 | 87 | def init_uvfa_pos(args, layout_vocab_size, object_vocab_size, text_vocab_size, rank = 7): 88 | print ' Using UVFA variant, consider using a lower learning rate (eg, 0.0001)' 89 | print ' UVFA rank: {} '.format(rank) 90 | 91 | args.rank = rank 92 | args.lstm_out = rank 93 | 94 | model = models.UVFA_pos(layout_vocab_size, object_vocab_size, args, map_dim=args.map_dim).cuda() 95 | return model 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /models/lookup_location.py: -------------------------------------------------------------------------------- 1 | import sys, math, pdb 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import torch.optim as optim 7 | 8 | ''' 9 | Object observations are single-channels images 10 | with positive indices denoting objects. 11 | 0's denote no object. 
12 | 13 | ''' 14 | 15 | class LookupLocationModel(nn.Module): 16 | def __init__(self, vocab_size, embed_dim, map_dim = 10): 17 | super(LookupLocationModel, self).__init__() 18 | 19 | ## add two for (x,y) location channels 20 | self.embed_dim = embed_dim 21 | 22 | # self.reshape = [-1] 23 | # for dim in inp_size: 24 | # self.reshape.append(dim) 25 | # self.reshape.append(self.embed_dim) 26 | 27 | # self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0) 28 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0) 29 | 30 | # self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5) 31 | # self.conv2 = nn.Conv2d(3, 6, kernel_size=5) 32 | # self.conv3 = nn.Conv2d(6,12, kernel_size=5) 33 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5) 34 | # self.fc1 = nn.Linear(192, out_dim) 35 | # self.init_weights() 36 | 37 | self.map_dim = map_dim 38 | self.locations = self.__init_locations(self.map_dim) 39 | 40 | def __init_locations(self, map_dim): 41 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim) 42 | col = torch.arange(0,map_dim).repeat(map_dim,1) 43 | locations = torch.stack( (row, col) ) 44 | return Variable(locations.cuda()) 45 | 46 | def init_weights(self): 47 | initrange = 0.1 48 | self.embed.weight.data.uniform_(-initrange, initrange) 49 | # self.decoder.bias.data.fill_(0) 50 | # self.decoder.weight.data.uniform_(-initrange, initrange) 51 | 52 | def forward(self, x): 53 | # print 'LOOKUP: ', x.size(), type(x) 54 | batch_size = x.size(0) 55 | 56 | reshape = [] 57 | for dim in x.size(): reshape.append(dim) 58 | reshape.append(self.embed_dim) 59 | 60 | if x.size(-1) != self.map_dim: 61 | self.map_dim = x.size(-1) 62 | self.locations = self.__init_locations(self.map_dim) 63 | 64 | ## reshape to vector 65 | x = x.view(-1) 66 | ## get embeddings 67 | x = self.embed(x) 68 | ## reshape to batch x channels x M x N x embed_dim 69 | x = x.view(*reshape) 70 | ## sum over channels in input 71 | x = x.sum(1, keepdim=True) 72 | # pdb.set_trace() 73 | ## reshape to batch x embed_dim x M x N 74 | ## (treats embedding dims as channels) 75 | x = x.transpose(1,-1)[:,:,:,:,0] #.squeeze() # 76 | 77 | locations = self.locations.unsqueeze(0).repeat(batch_size,1,1,1) 78 | 79 | # pdb.set_trace() 80 | x = torch.cat( (x, locations), 1 ) 81 | # print self.locations 82 | # print locations 83 | # x = F.relu(self.conv1(x)) 84 | # x = F.relu(self.conv2(x)) 85 | # x = F.relu(self.conv3(x)) 86 | # x = F.relu(self.conv4(x)) 87 | # x = x.view(-1, 192) 88 | # x = self.fc1(x) 89 | return x 90 | 91 | if __name__ == '__main__': 92 | vocab_size = 10 93 | emb_dim = 3 94 | map_dim = 10 95 | phi = LookupLocationModel(vocab_size, emb_dim, map_dim=map_dim) 96 | 97 | # enc = nn.Embedding(10,emb_dim,padding_idx=0) 98 | inp = torch.LongTensor(5,1,map_dim,map_dim).zero_() 99 | inp[0][0][0][0]=1 100 | # inp[0][1][0][0]=1 101 | inp[1][0][0][2]=1 102 | # print inp 103 | inp = Variable(inp) 104 | 105 | out = phi.forward(inp) 106 | # print out 107 | # out = out.view(-1,2,3,3,emb_dim) 108 | out = out.data 109 | print out.size() 110 | 111 | # print out[0][0][0] 112 | # print out[1][0][0] 113 | 114 | -------------------------------------------------------------------------------- /models/lookup_model.py: -------------------------------------------------------------------------------- 1 | import sys, math, pdb 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import torch.optim as optim 7 | 8 | ''' 9 | Object 
observations are single-channels images 10 | with positive indices denoting objects. 11 | 0's denote no object. 12 | 13 | ''' 14 | 15 | class LookupModel(nn.Module): 16 | def __init__(self, vocab_size, embed_dim): 17 | super(LookupModel, self).__init__() 18 | 19 | self.embed_dim = embed_dim 20 | 21 | # self.reshape = [-1] 22 | # for dim in inp_size: 23 | # self.reshape.append(dim) 24 | # self.reshape.append(self.embed_dim) 25 | 26 | # self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0) 27 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0) 28 | 29 | # self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5) 30 | # self.conv2 = nn.Conv2d(3, 6, kernel_size=5) 31 | # self.conv3 = nn.Conv2d(6,12, kernel_size=5) 32 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5) 33 | # self.fc1 = nn.Linear(192, out_dim) 34 | # self.init_weights() 35 | 36 | # self.map_dim = map_dim 37 | # self.locations = self.__init_locations(self.map_dim) 38 | 39 | def init_weights(self): 40 | initrange = 0.1 41 | self.embed.weight.data.uniform_(-initrange, initrange) 42 | # self.decoder.bias.data.fill_(0) 43 | # self.decoder.weight.data.uniform_(-initrange, initrange) 44 | 45 | def forward(self, x): 46 | # print 'LOOKUP: ', x.size(), type(x) 47 | batch_size = x.size(0) 48 | 49 | reshape = [] 50 | for dim in x.size(): reshape.append(dim) 51 | reshape.append(self.embed_dim) 52 | 53 | ## reshape to vector 54 | x = x.view(-1) 55 | ## get embeddings 56 | x = self.embed(x) 57 | ## reshape to batch x channels x M x N x embed_dim 58 | x = x.view(*reshape) 59 | ## sum over channels in input 60 | x = x.sum(1, keepdim=True) 61 | # pdb.set_trace() 62 | ## reshape to batch x embed_dim x M x N 63 | ## (treats embedding dims as channels) 64 | x = x.transpose(1,-1)[:,:,:,:,0] #.squeeze() # 65 | 66 | # locations = self.locations.unsqueeze(0).repeat(batch_size,1,1,1) 67 | 68 | # pdb.set_trace() 69 | # x = torch.cat( (x, locations), 1 ) 70 | # print self.locations 71 | # print locations 72 | # x = F.relu(self.conv1(x)) 73 | # x = F.relu(self.conv2(x)) 74 | # x = F.relu(self.conv3(x)) 75 | # x = F.relu(self.conv4(x)) 76 | # x = x.view(-1, 192) 77 | # x = self.fc1(x) 78 | return x 79 | 80 | if __name__ == '__main__': 81 | vocab_size = 10 82 | emb_dim = 3 83 | map_dim = 10 84 | phi = LookupModel(vocab_size, emb_dim, map_dim=map_dim) 85 | 86 | # enc = nn.Embedding(10,emb_dim,padding_idx=0) 87 | inp = torch.LongTensor(5,1,map_dim,map_dim).zero_() 88 | inp[0][0][0][0]=1 89 | # inp[0][1][0][0]=1 90 | inp[1][0][0][2]=1 91 | # print inp 92 | inp = Variable(inp) 93 | 94 | out = phi.forward(inp) 95 | # print out 96 | # out = out.view(-1,2,3,3,emb_dim) 97 | out = out.data 98 | print out.size() 99 | 100 | # print out[0][0][0] 101 | # print out[1][0][0] 102 | 103 | -------------------------------------------------------------------------------- /models/map_model.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | ''' 8 | State observations are two-channel images 9 | with 0: puddle, 1: grass, 2: agent 10 | 11 | ''' 12 | 13 | class MapModel(nn.Module): 14 | def __init__(self, vocab_size, embed_dim, out_dim): 15 | super(MapModel, self).__init__() 16 | 17 | self.embed_dim = embed_dim 18 | 19 | self.embed = nn.Embedding(vocab_size, embed_dim) 20 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=3) 21 | self.conv2 = nn.Conv2d(3, 6, 
kernel_size=3) 22 | self.conv3 = nn.Conv2d(6,12, kernel_size=3) 23 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5) 24 | self.fc1 = nn.Linear(192, out_dim) 25 | 26 | def forward(self, x): 27 | reshape = [] 28 | for dim in x.size(): reshape.append(dim) 29 | reshape.append(self.embed_dim) 30 | 31 | ## reshape to vector 32 | x = x.view(-1) 33 | ## get embeddings 34 | x = self.embed(x) 35 | ## reshape to batch x channels x M x N x embed_dim 36 | x = x.view(*reshape) 37 | ## sum over channels in input 38 | x = x.sum(1, keepdim=True) 39 | ## reshape to batch x embed_dim x M x N 40 | ## (treats embedding dims as channels) 41 | x = x.transpose(1,-1).squeeze() 42 | x = F.relu(self.conv1(x)) 43 | x = F.relu(self.conv2(x)) 44 | x = F.relu(self.conv3(x)) 45 | 46 | x = x.view(-1, 192) 47 | x = self.fc1(x) 48 | return x 49 | 50 | 51 | if __name__ == '__main__': 52 | from torch.autograd import Variable 53 | # inp = torch.LongTensor(2,10,10).zero_() 54 | vocab_size = 10 55 | emb_dim = 3 56 | rank = 7 57 | phi = MapModel(vocab_size, emb_dim, rank) 58 | 59 | # enc = nn.Embedding(10,emb_dim,padding_idx=0) 60 | inp = torch.LongTensor(5,2,10,10).zero_() 61 | inp[0][0][0][0]=1 62 | # inp[0][1][0][0]=1 63 | inp[1][0][0][2]=1 64 | print inp.size() 65 | inp = Variable(inp) 66 | 67 | out = phi.forward(inp) 68 | # print out 69 | # out = out.view(-1,2,3,3,emb_dim) 70 | out = out.data 71 | print out.size() 72 | 73 | # print out[0][0][0] 74 | # print out[1][0][0] 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | ## predicts entire value map 2 | ## rather than a single value 3 | 4 | import torch 5 | import math, torch.nn as nn, pdb 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | class MLP(nn.Module): 10 | def __init__(self, sizes): 11 | super(MLP, self).__init__() 12 | 13 | layers = [] 14 | 15 | for ind in range(len(sizes)-1): 16 | layers.append( nn.Linear(sizes[ind], sizes[ind+1]) ) 17 | layers.append( nn.ReLU() ) 18 | layers.pop(-1) 19 | 20 | self.layers = nn.ModuleList(layers) 21 | 22 | def forward(self, x): 23 | for lay in self.layers: 24 | x = lay(x) 25 | return x 26 | 27 | if __name__ == '__main__': 28 | batch_size = 32 29 | sizes = [10, 128, 128, 4] 30 | mlp = MLP(sizes) 31 | inp = Variable( torch.Tensor(batch_size, sizes[0]) ) 32 | print inp.size() 33 | out = mlp(inp) 34 | print out.size() 35 | loss = out.sum() 36 | loss.backward() 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /models/model_factorizer.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | from torch.nn.modules.utils import _pair 8 | import pdb, pickle 9 | 10 | class ModelFactorizer(nn.Module): 11 | def __init__(self, state_model, goal_model, state_inp, goal_inp, sparse_value_mat): 12 | super(ModelFactorizer, self).__init__() 13 | self.state_model = state_model 14 | self.goal_model = goal_model 15 | self.state_inp = Variable(state_inp.cuda()) 16 | self.goal_inp = [Variable(i).cuda() for i in goal_inp] 17 | self.mat = Variable( sparse_value_mat ) 18 | self.mask = (self.mat == 0) 19 | self.M = 
state_inp.size(0) 20 | 21 | def forward(self, inds): 22 | state_inp = self.state_inp.index_select(0, inds) 23 | state_out = self.state_model.forward(state_inp) 24 | goal_out = self.goal_model.forward(self.goal_inp) 25 | 26 | recon = torch.mm(state_out, goal_out.t()) 27 | mask_select = self.mask.index_select(0, inds) 28 | true_select = self.mat.index_select(0, inds) 29 | 30 | # pdb.set_trace() 31 | 32 | diff = torch.pow(recon - true_select, 2) 33 | 34 | mse = diff.sum() 35 | 36 | return mse 37 | 38 | def train(self, lr, iters, batch_size = 256): 39 | optimizer = optim.Adam(self.parameters(), lr=lr) 40 | 41 | t = trange(iters) 42 | for i in t: 43 | optimizer.zero_grad() 44 | inds = torch.floor(torch.rand(batch_size) * self.M).long().cuda() 45 | # bug: floor(rand()) sometimes gives 1 46 | inds[inds >= self.M] = self.M - 1 47 | inds = Variable(inds) 48 | 49 | loss = self.forward(inds) 50 | # print loss.data[0] 51 | t.set_description( str(loss.data[0]) ) 52 | loss.backward() 53 | optimizer.step() 54 | 55 | return self.state_model, self.goal_model 56 | 57 | 58 | 59 | if __name__ == '__main__': 60 | from state_model import * 61 | from object_model_10 import * 62 | from text_model import * 63 | from goal_model import * 64 | rank = 7 65 | state_vocab_size = 20 66 | embed_size = 3 67 | state_obs_size = (2,10,10) 68 | goal_obs_size = (1,10,10) 69 | lstm_size = 15 70 | lstm_nlayer = 1 71 | phi = Phi(state_vocab_size, embed_size, state_obs_size, rank).cuda() 72 | text_model = TextModel(state_vocab_size, lstm_size, lstm_size, lstm_nlayer, lstm_size) 73 | object_model = ObjectModel(state_vocab_size, embed_size, goal_obs_size, lstm_size) 74 | psi = Psi(text_model, object_model, lstm_size, lstm_size, rank).cuda() 75 | print phi 76 | 77 | state_obs = Variable((torch.rand(20*100,2,10,10)*10).long().cuda()) 78 | pdb.set_trace() 79 | 80 | out = phi.forward(state_obs) 81 | print out.size() 82 | # from torch.nn.modules.utils import _pair 83 | # rank = 4 84 | # # kernel = Variable(avg_conv(rank)) 85 | # kernel = Variable( avg_vector(rank) ) 86 | 87 | # inp = Variable(torch.randn(1,rank,20,400)) 88 | # bias = None 89 | # stride = _pair(1) 90 | # padding = _pair(1) 91 | # conv = F.conv2d(inp,kernel,bias,stride,padding)[:,:,1:-1,:] 92 | # print 'inp:', inp.size() 93 | # print 'kernel:', kernel.size() 94 | # print 'conv:', conv.size() 95 | 96 | # pdb.set_trace() 97 | 98 | # from torch.nn.modules.utils import _pair 99 | # inp = Variable(torch.randn(1,5,8,8)) 100 | # kern = Variable(torch.ones(7,5,3,3)) 101 | # bias = None 102 | # stride = _pair(1) 103 | # padding = _pair(1) 104 | # # dilation = _pair(1) 105 | # # print dilation 106 | # # groups = 1 107 | # # conv = F.conv2d(inp,kern,bias,stride,padding,dilation,groups) 108 | # conv = F.conv2d(inp,kern,bias,stride,padding) 109 | # print 'inp:', inp 110 | # print 'conv: ', conv 111 | # print 'sum: ', inp[:,:,:3,:3].sum() 112 | 113 | # print a.b 114 | # print 'conv: ', conv 115 | value_mat = pickle.load( open('../pickle/value_mat20.p') ) 116 | rank = 10 117 | 118 | print 'value_mat: ', value_mat.shape 119 | dissimilarity_lambda = .1 120 | world_lambda = 0 121 | location_lambda = .001 122 | model = ConstraintFactorizer(value_mat, rank, dissimilarity_lambda, world_lambda, location_lambda).cuda() 123 | lr = 0.001 124 | iters = 500000 125 | U, V, recon = model.train(lr, iters) 126 | 127 | # lr = 0.001 128 | # optimizer = optim.Adam(model.parameters(), lr=lr) 129 | 130 | # loss = model.forward( () ) 131 | # print loss 132 | # iters = 50000 133 | # t = trange(iters) 134 | # 
for i in t: 135 | # optimizer.zero_grad() 136 | # loss = model.forward( () ) 137 | # # print loss.data[0] 138 | # t.set_description( str(model.mse) + ' ' + str(model.divergence) ) 139 | # loss.backward() 140 | # optimizer.step() 141 | 142 | # U, V = model.embeddings() 143 | 144 | print 'recon' 145 | print recon 146 | print 'true' 147 | print torch.Tensor(value_mat) 148 | 149 | pickle.dump(U, open('../pickle/U_lambda_' + str(dissimilarity_lambda) + '.p', 'w') ) 150 | pickle.dump(V, open('../pickle/V_lambda_' + str(dissimilarity_lambda) + '.p', 'w') ) 151 | # pdb.set_trace() 152 | 153 | 154 | # inp = torch.LongTensor(5,3).zero_() 155 | # vocab_size = 10 156 | # emb_dim = 3 157 | # rank = 10 158 | # phi = ObjectModel(vocab_size, emb_dim, inp.size(), rank) 159 | 160 | # batch = 2 161 | # hidden = phi.init_hidden(batch) 162 | 163 | # # enc = nn.Embedding(10,emb_dim,padding_idx=0) 164 | # inp = torch.LongTensor(batch,5,3).zero_() 165 | # inp[0][0]=1 166 | # inp[0][1]=1 167 | # inp[1][0]=8 168 | # inp[1][1]=8 169 | # print inp 170 | # # inp[0][0][0][0]=1 171 | # # inp[0][1][0][0]=1 172 | # # inp[1][0][0][2]=1 173 | # # print inp 174 | # inp = Variable(inp) 175 | 176 | # out = phi.forward(inp, hidden) 177 | # # print out 178 | # # out = out.view(-1,2,3,3,emb_dim) 179 | # out = out.data 180 | # print out.size() 181 | 182 | # # print out[0][0][0] 183 | # # print out[1][0][0] 184 | 185 | -------------------------------------------------------------------------------- /models/multi_global.py: -------------------------------------------------------------------------------- 1 | ## predicts entire value map 2 | ## rather than a single value 3 | 4 | import torch 5 | import math, torch.nn as nn, pdb 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import models, utils 9 | 10 | class MultiGlobal(nn.Module): 11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10): 12 | super(MultiGlobal, self).__init__() 13 | 14 | self.state_model = state_model 15 | self.object_model = object_model 16 | self.heatmap_model = heatmap_model 17 | self.simple_conv = models.SimpleConv(3).cuda() 18 | self.rbf = Variable( utils.meta_rbf(map_dim).cuda() ) 19 | self.positions = Variable( self.__init_positions(map_dim).cuda() ) 20 | 21 | self.map_dim = map_dim 22 | self.batch_size = args.batch_size 23 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1) 24 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1) 25 | 26 | ''' 27 | global_coeffs are < batch x 3 > 28 | 3: row, col, bias 29 | ''' 30 | def _global(self, global_coeffs): 31 | pos_coeffs = global_coeffs[:,:-1] 32 | bias = global_coeffs[:,-1] 33 | 34 | coeffs_batch = pos_coeffs.unsqueeze(-1).unsqueeze(-1).repeat(1,1,self.map_dim,self.map_dim) 35 | bias_batch = bias.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1,1,self.map_dim,self.map_dim) 36 | 37 | ## sum over row, col and add bias 38 | obj_global = (coeffs_batch * self.positions_batch).sum(1, keepdim=True) + bias_batch 39 | return obj_global 40 | 41 | 42 | 43 | def __init_positions(self, map_dim): 44 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim) 45 | col = torch.arange(0,map_dim).repeat(map_dim,1) 46 | positions = torch.stack( (row, col) ) 47 | return positions 48 | 49 | 50 | def forward(self, inp): 51 | (state, obj, text) = inp 52 | batch_size = state.size(0) 53 | if batch_size != self.batch_size: 54 | self.batch_size = batch_size 55 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1) 56 | self.positions_batch = 
self.positions.repeat(self.batch_size,1,1,1) 57 | 58 | ## get state map 59 | state_out = self.state_model(state) 60 | obj_out = self.object_model.forward(obj) 61 | 62 | # print 'state_out: ', state_out.size() 63 | # print 'obj_out: ', obj_out.size() 64 | 65 | ## get object map 66 | heatmap, global_coeffs = self.heatmap_model((obj_out, text)) 67 | 68 | ## repeat heatmap for multiplication by rbf batch 69 | heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim) 70 | 71 | ## multiply object map by pre-computed manhattan rbf 72 | ## < batch x size^2 x size x size > 73 | obj_local = heatmap_batch * self.rbf_batch 74 | ## sum contributions from rbf from every source 75 | ## < batch x 1 x size x size > 76 | obj_local = obj_local.sum(1) 77 | # print 'obj_out:', obj_out.size() 78 | 79 | obj_global = self._global(global_coeffs) 80 | 81 | # pdb.set_trace() 82 | 83 | # obj_out = obj_local + obj_global 84 | 85 | map_pred = torch.cat( (state_out, obj_local, obj_global), 1 ) 86 | # pdb.set_trace() 87 | map_pred = self.simple_conv(map_pred) 88 | # map_pred = state_out + obj_out 89 | 90 | return map_pred 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /models/multi_model.py: -------------------------------------------------------------------------------- 1 | ## predicts entire value map 2 | ## rather than a single value 3 | 4 | import torch 5 | import math, torch.nn as nn, pdb 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import models, utils 9 | 10 | class MultiModel(nn.Module): 11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10): 12 | super(MultiModel, self).__init__() 13 | 14 | self.state_model = state_model 15 | self.object_model = object_model 16 | self.heatmap_model = heatmap_model 17 | self.simple_conv = models.SimpleConv(2).cuda() 18 | self.rbf = Variable( utils.meta_rbf(map_dim).cuda() ) 19 | 20 | self.map_dim = map_dim 21 | self.batch_size = args.batch_size 22 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1) 23 | 24 | 25 | 26 | def forward(self, inp): 27 | (state, obj, text) = inp 28 | batch_size = state.size(0) 29 | if batch_size != self.batch_size: 30 | self.batch_size = batch_size 31 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1) 32 | 33 | ## get state map 34 | state_out = self.state_model(state) 35 | obj_out = self.object_model.forward(obj) 36 | 37 | # print 'state_out: ', state_out.size() 38 | ## get object map 39 | heatmap = self.heatmap_model((obj_out, text)) 40 | ## repeat heatmap for multiplication by rbf batch 41 | heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim) 42 | ## multiply object map by pre-computed manhattan rbf 43 | ## < batch x size^2 x size x size > 44 | # pdb.set_trace() 45 | obj_local = heatmap_batch * self.rbf_batch 46 | ## sum contributions from rbf from every source 47 | ## < batch x 1 x size x size > 48 | obj_local = obj_local.sum(1, keepdim=True) 49 | # print 'obj_out:', obj_out.size() 50 | # pdb.set_trace() 51 | map_pred = torch.cat( (state_out, obj_local), 1 ) 52 | # pdb.set_trace() 53 | map_pred = self.simple_conv(map_pred) 54 | # map_pred = state_out + obj_out 55 | 56 | return map_pred 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /models/multi_nobases.py: 
-------------------------------------------------------------------------------- 1 | ## predicts entire value map 2 | ## rather than a single value 3 | 4 | import torch 5 | import math, torch.nn as nn, pdb 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import models, utils 9 | 10 | class MultiNoBases(nn.Module): 11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10): 12 | super(MultiNoBases, self).__init__() 13 | 14 | self.state_model = state_model 15 | self.object_model = object_model 16 | self.heatmap_model = heatmap_model 17 | self.simple_conv = models.SimpleConv(2).cuda() 18 | self.rbf = Variable( utils.meta_rbf(map_dim).cuda() ) 19 | self.positions = Variable( self.__init_positions(map_dim).cuda() ) 20 | 21 | self.map_dim = map_dim 22 | self.batch_size = args.batch_size 23 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1) 24 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1) 25 | 26 | 27 | def __init_positions(self, map_dim): 28 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim) 29 | col = torch.arange(0,map_dim).repeat(map_dim,1) 30 | positions = torch.stack( (row, col) ) 31 | return positions 32 | 33 | 34 | def forward(self, inp): 35 | (state, obj, text) = inp 36 | batch_size = state.size(0) 37 | if batch_size != self.batch_size: 38 | self.batch_size = batch_size 39 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1) 40 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1) 41 | 42 | ## get state map 43 | state_out = self.state_model(state) 44 | obj_out = self.object_model.forward(obj) 45 | 46 | # print 'state_out: ', state_out.size() 47 | # print 'obj_out: ', obj_out.size() 48 | 49 | ## get object map 50 | heatmap = self.heatmap_model((obj_out, text)) 51 | 52 | ## repeat heatmap for multiplication by rbf batch 53 | # heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim) 54 | 55 | ## multiply object map by pre-computed manhattan rbf 56 | ## < batch x size^2 x size x size > 57 | # obj_local = heatmap_batch * self.rbf_batch 58 | ## sum contributions from rbf from every source 59 | ## < batch x 1 x size x size > 60 | # obj_local = obj_local.sum(1) 61 | # print 'obj_out:', obj_out.size() 62 | 63 | obj_local = heatmap 64 | 65 | # print obj_local.size() 66 | 67 | # obj_global = self._global(global_coeffs) 68 | 69 | # pdb.set_trace() 70 | 71 | # obj_out = obj_local + obj_global 72 | 73 | map_pred = torch.cat( (state_out, obj_local), 1 ) 74 | # pdb.set_trace() 75 | map_pred = self.simple_conv(map_pred) 76 | # map_pred = state_out + obj_out 77 | 78 | return map_pred 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /models/multi_nocnn.py: -------------------------------------------------------------------------------- 1 | ## predicts entire value map 2 | ## rather than a single value 3 | 4 | import torch 5 | import math, torch.nn as nn, pdb 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import models, utils 9 | 10 | class MultiNoCNN(nn.Module): 11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10): 12 | super(MultiNoCNN, self).__init__() 13 | 14 | self.state_model = state_model 15 | self.object_model = object_model 16 | self.heatmap_model = heatmap_model 17 | # self.simple_conv = models.SimpleConv(3).cuda() 18 | self.rbf = Variable( 
utils.meta_rbf(map_dim).cuda() ) 19 | self.positions = Variable( self.__init_positions(map_dim).cuda() ) 20 | 21 | self.map_dim = map_dim 22 | self.batch_size = args.batch_size 23 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1) 24 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1) 25 | 26 | ''' 27 | global_coeffs are < batch x 3 > 28 | 3: row, col, bias 29 | ''' 30 | def _global(self, global_coeffs): 31 | pos_coeffs = global_coeffs[:,:-1] 32 | bias = global_coeffs[:,-1] 33 | 34 | coeffs_batch = pos_coeffs.unsqueeze(-1).unsqueeze(-1).repeat(1,1,self.map_dim,self.map_dim) 35 | bias_batch = bias.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1,1,self.map_dim,self.map_dim) 36 | 37 | ## sum over row, col and add bias 38 | obj_global = (coeffs_batch * self.positions_batch).sum(1, keepdim=True) + bias_batch 39 | return obj_global 40 | 41 | 42 | 43 | def __init_positions(self, map_dim): 44 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim) 45 | col = torch.arange(0,map_dim).repeat(map_dim,1) 46 | positions = torch.stack( (row, col) ) 47 | return positions 48 | 49 | 50 | def forward(self, inp): 51 | (state, obj, text) = inp 52 | batch_size = state.size(0) 53 | if batch_size != self.batch_size: 54 | self.batch_size = batch_size 55 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1) 56 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1) 57 | 58 | ## get state map 59 | state_out = self.state_model(state) 60 | obj_out = self.object_model.forward(obj) 61 | 62 | # print 'state_out: ', state_out.size() 63 | # print 'obj_out: ', obj_out.size() 64 | 65 | ## get object map 66 | heatmap, global_coeffs = self.heatmap_model((obj_out, text)) 67 | 68 | ## repeat heatmap for multiplication by rbf batch 69 | heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim) 70 | 71 | ## multiply object map by pre-computed manhattan rbf 72 | ## < batch x size^2 x size x size > 73 | obj_local = heatmap_batch * self.rbf_batch 74 | ## sum contributions from rbf from every source 75 | ## < batch x 1 x size x size > 76 | obj_local = obj_local.sum(1) 77 | # print 'obj_out:', obj_out.size() 78 | 79 | obj_global = self._global(global_coeffs) 80 | 81 | # pdb.set_trace() 82 | 83 | # obj_out = obj_local + obj_global 84 | 85 | # map_pred = torch.cat( (state_out, obj_local, obj_global), 1 ) 86 | # pdb.set_trace() 87 | # map_pred = self.simple_conv(map_pred) 88 | # map_pred = state_out + obj_out 89 | 90 | map_pred = state_out + obj_local + obj_global 91 | 92 | return map_pred 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /models/multi_nonsep.py: -------------------------------------------------------------------------------- 1 | ## predicts entire value map 2 | ## rather than a single value 3 | 4 | import torch 5 | import math, torch.nn as nn, pdb 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import models, utils 9 | 10 | class MultiNonSep(nn.Module): 11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10): 12 | super(MultiNonSep, self).__init__() 13 | 14 | self.state_model = state_model 15 | self.object_model = object_model 16 | self.heatmap_model = heatmap_model 17 | self.simple_conv = models.SimpleConv(2).cuda() 18 | self.rbf = Variable( utils.meta_rbf(map_dim).cuda() ) 19 | self.positions = Variable( self.__init_positions(map_dim).cuda() ) 20 | 21 | 
self.map_dim = map_dim 22 | self.batch_size = args.batch_size 23 | # self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1) 24 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1) 25 | 26 | ''' 27 | global_coeffs are < batch x 3 > 28 | 3: row, col, bias 29 | ''' 30 | def _global(self, global_coeffs): 31 | pos_coeffs = global_coeffs[:,:-1] 32 | bias = global_coeffs[:,-1] 33 | 34 | coeffs_batch = pos_coeffs.unsqueeze(-1).unsqueeze(-1).repeat(1,1,self.map_dim,self.map_dim) 35 | bias_batch = bias.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1,1,self.map_dim,self.map_dim) 36 | 37 | ## sum over row, col and add bias 38 | obj_global = (coeffs_batch * self.positions_batch).sum(1, keepdim=True) + bias_batch 39 | return obj_global 40 | 41 | 42 | 43 | def __init_positions(self, map_dim): 44 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim) 45 | col = torch.arange(0,map_dim).repeat(map_dim,1) 46 | positions = torch.stack( (row, col) ) 47 | return positions 48 | 49 | 50 | def forward(self, inp): 51 | (state, obj, text) = inp 52 | batch_size = state.size(0) 53 | if batch_size != self.batch_size: 54 | self.batch_size = batch_size 55 | # self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1) 56 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1) 57 | 58 | ## get state map 59 | state_out = self.state_model(state) 60 | obj_out = self.object_model.forward(obj) 61 | embeddings = torch.cat( (state_out, obj_out), 1 ) 62 | 63 | # print 'state_out: ', state_out.size() 64 | # print 'obj_out: ', obj_out.size() 65 | 66 | ## get object map 67 | heatmap, global_coeffs = self.heatmap_model((embeddings, text)) 68 | 69 | ## repeat heatmap for multiplication by rbf batch 70 | # heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim) 71 | 72 | ## multiply object map by pre-computed manhattan rbf 73 | ## < batch x size^2 x size x size > 74 | # obj_local = heatmap_batch * self.rbf_batch 75 | ## sum contributions from rbf from every source 76 | ## < batch x 1 x size x size > 77 | # obj_local = obj_local.sum(1) 78 | # print 'obj_out:', obj_out.size() 79 | 80 | obj_local = heatmap 81 | obj_global = self._global(global_coeffs) 82 | 83 | # pdb.set_trace() 84 | 85 | # obj_out = obj_local + obj_global 86 | 87 | map_pred = torch.cat( (obj_local, obj_global), 1 ) 88 | # pdb.set_trace() 89 | map_pred = self.simple_conv(map_pred) 90 | # map_pred = state_out + obj_out 91 | 92 | return map_pred 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /models/multi_norbf.py: -------------------------------------------------------------------------------- 1 | ## predicts entire value map 2 | ## rather than a single value 3 | 4 | import torch 5 | import math, torch.nn as nn, pdb 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import models, utils 9 | 10 | class MultiNoRBF(nn.Module): 11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10): 12 | super(MultiNoRBF, self).__init__() 13 | 14 | self.state_model = state_model 15 | self.object_model = object_model 16 | self.heatmap_model = heatmap_model 17 | self.simple_conv = models.SimpleConv(3).cuda() 18 | self.rbf = Variable( utils.meta_rbf(map_dim).cuda() ) 19 | self.positions = Variable( self.__init_positions(map_dim).cuda() ) 20 | 21 | self.map_dim = map_dim 22 | self.batch_size = args.batch_size 23 | self.rbf_batch = 
self.rbf.repeat(self.batch_size,1,1,1) 24 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1) 25 | 26 | ''' 27 | global_coeffs are < batch x 3 > 28 | 3: row, col, bias 29 | ''' 30 | def _global(self, global_coeffs): 31 | pos_coeffs = global_coeffs[:,:-1] 32 | bias = global_coeffs[:,-1] 33 | 34 | coeffs_batch = pos_coeffs.unsqueeze(-1).unsqueeze(-1).repeat(1,1,self.map_dim,self.map_dim) 35 | bias_batch = bias.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1,1,self.map_dim,self.map_dim) 36 | 37 | ## sum over row, col and add bias 38 | obj_global = (coeffs_batch * self.positions_batch).sum(1, keepdim=True) + bias_batch 39 | return obj_global 40 | 41 | 42 | 43 | def __init_positions(self, map_dim): 44 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim) 45 | col = torch.arange(0,map_dim).repeat(map_dim,1) 46 | positions = torch.stack( (row, col) ) 47 | return positions 48 | 49 | 50 | def forward(self, inp): 51 | (state, obj, text) = inp 52 | batch_size = state.size(0) 53 | if batch_size != self.batch_size: 54 | self.batch_size = batch_size 55 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1) 56 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1) 57 | 58 | ## get state map 59 | state_out = self.state_model(state) 60 | obj_out = self.object_model.forward(obj) 61 | 62 | # print 'state_out: ', state_out.size() 63 | # print 'obj_out: ', obj_out.size() 64 | 65 | ## get object map 66 | heatmap, global_coeffs = self.heatmap_model((obj_out, text)) 67 | 68 | ## repeat heatmap for multiplication by rbf batch 69 | # heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim) 70 | 71 | ## multiply object map by pre-computed manhattan rbf 72 | ## < batch x size^2 x size x size > 73 | # obj_local = heatmap_batch * self.rbf_batch 74 | ## sum contributions from rbf from every source 75 | ## < batch x 1 x size x size > 76 | # obj_local = obj_local.sum(1) 77 | # print 'obj_out:', obj_out.size() 78 | 79 | obj_local = heatmap 80 | obj_global = self._global(global_coeffs) 81 | 82 | self.output_local = obj_local 83 | self.output_global = obj_global 84 | 85 | # pdb.set_trace() 86 | 87 | # obj_out = obj_local + obj_global 88 | map_pred = torch.cat( (state_out, obj_local, obj_global), 1 ) 89 | # pdb.set_trace() 90 | map_pred = self.simple_conv(map_pred) 91 | # map_pred = state_out + obj_out 92 | 93 | return map_pred 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /models/object_model.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | ''' 8 | Object observations are single-channels images 9 | with positive indices denoting objects. 10 | 0's denote no object. 
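For example, an observation that is zero everywhere except for a 3 at one cell encodes a single object of type 3 at that location.

Shape sketch (an illustration, assuming the 20x20 maps that this file's layer sizes imply): the index map is embedded to < batch x embed_dim x 20 x 20 >, four valid 5x5 convolutions reduce it to < batch x 12 x 4 x 4 > = 192 features, and fc1 projects those to out_dim.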
11 | 12 | ''' 13 | 14 | class ObjectModel(nn.Module): 15 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim): 16 | super(ObjectModel, self).__init__() 17 | 18 | self.reshape = [-1] 19 | for dim in inp_size: 20 | self.reshape.append(dim) 21 | self.reshape.append(embed_dim) 22 | 23 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0) 24 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5) 25 | self.conv2 = nn.Conv2d(3, 6, kernel_size=5) 26 | self.conv3 = nn.Conv2d(6,12, kernel_size=5) 27 | self.conv4 = nn.Conv2d(12,12, kernel_size=5) 28 | self.fc1 = nn.Linear(192, out_dim) 29 | 30 | def forward(self, x): 31 | x = x.view(-1) 32 | x = self.embed(x) 33 | x = x.view(*self.reshape) 34 | x = x.transpose(1,-1).squeeze() 35 | x = F.relu(self.conv1(x)) 36 | x = F.relu(self.conv2(x)) 37 | x = F.relu(self.conv3(x)) 38 | x = F.relu(self.conv4(x)) 39 | x = x.view(-1, 192) 40 | x = self.fc1(x) 41 | return x 42 | 43 | if __name__ == '__main__': 44 | inp = torch.LongTensor(1,20,20).zero_() 45 | vocab_size = 10 46 | emb_dim = 3 47 | rank = 10 48 | phi = ObjectModel(vocab_size, emb_dim,inp.size(), rank) 49 | 50 | # enc = nn.Embedding(10,emb_dim,padding_idx=0) 51 | inp = torch.LongTensor(5,1,20,20).zero_() 52 | inp[0][0][0][0]=1 53 | # inp[0][1][0][0]=1 54 | inp[1][0][0][2]=1 55 | print inp 56 | inp = Variable(inp.view(-1)) 57 | 58 | out = phi.forward(inp) 59 | # print out 60 | # out = out.view(-1,2,3,3,emb_dim) 61 | out = out.data 62 | print out.size() 63 | 64 | # print out[0][0][0] 65 | # print out[1][0][0] 66 | 67 | -------------------------------------------------------------------------------- /models/object_model_10.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | ''' 8 | Object observations are single-channels images 9 | with positive indices denoting objects. 10 | 0's denote no object. 
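Note: this variant is sized for 10x10 maps; three valid 3x3 convolutions take a 10x10 input down to < batch x 12 x 4 x 4 > = 192 features before fc1, in place of the four 5x5 convolutions of object_model.py.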
11 | 12 | ''' 13 | 14 | class ObjectModel(nn.Module): 15 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim): 16 | super(ObjectModel, self).__init__() 17 | 18 | self.reshape = [-1] 19 | for dim in inp_size: 20 | self.reshape.append(dim) 21 | self.reshape.append(embed_dim) 22 | 23 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0) 24 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=3) 25 | self.conv2 = nn.Conv2d(3, 6, kernel_size=3) 26 | self.conv3 = nn.Conv2d(6,12, kernel_size=3) 27 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5) 28 | self.fc1 = nn.Linear(192, out_dim) 29 | 30 | def forward(self, x): 31 | x = x.view(-1) 32 | x = self.embed(x) 33 | x = x.view(*self.reshape) 34 | x = x.transpose(1,-1).squeeze() 35 | x = F.relu(self.conv1(x)) 36 | x = F.relu(self.conv2(x)) 37 | x = F.relu(self.conv3(x)) 38 | # x = F.relu(self.conv4(x)) 39 | x = x.view(-1, 192) 40 | x = self.fc1(x) 41 | return x 42 | 43 | if __name__ == '__main__': 44 | from torch.autograd import Variable 45 | inp = torch.LongTensor(1,10,10).zero_() 46 | vocab_size = 10 47 | emb_dim = 3 48 | rank = 10 49 | phi = ObjectModel(vocab_size, emb_dim,inp.size(), rank) 50 | 51 | # enc = nn.Embedding(10,emb_dim,padding_idx=0) 52 | inp = torch.LongTensor(5,1,20,20).zero_() 53 | inp[0][0][0][0]=1 54 | # inp[0][1][0][0]=1 55 | inp[1][0][0][2]=1 56 | # print inp 57 | inp = Variable(inp.view(-1)) 58 | 59 | out = phi.forward(inp) 60 | # print out 61 | # out = out.view(-1,2,3,3,emb_dim) 62 | # out = out.data 63 | print out.size() 64 | loss = out.sum() 65 | loss.backward() 66 | 67 | print phi.parameters 68 | print loss 69 | print loss.grad 70 | # print a 71 | 72 | # print out[0][0][0] 73 | # print out[1][0][0] 74 | 75 | -------------------------------------------------------------------------------- /models/object_model_fc.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | ''' 8 | Object observations are single-channels images 9 | with positive indices denoting objects. 10 | 0's denote no object. 
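Note: despite the _fc suffix, the forward pass below mirrors object_model_10.py (three 3x3 convolutions followed by a single linear layer); conv4 is defined but never applied.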
11 | 12 | ''' 13 | 14 | class ObjectModel(nn.Module): 15 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim): 16 | super(ObjectModel, self).__init__() 17 | 18 | self.reshape = [-1] 19 | for dim in inp_size: 20 | self.reshape.append(dim) 21 | self.reshape.append(embed_dim) 22 | 23 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0) 24 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=3) 25 | self.conv2 = nn.Conv2d(3, 6, kernel_size=3) 26 | self.conv3 = nn.Conv2d(6,12, kernel_size=3) 27 | self.conv4 = nn.Conv2d(12,12, kernel_size=5) 28 | self.fc1 = nn.Linear(192, out_dim) 29 | 30 | def forward(self, x): 31 | ### 1D 32 | # print 'x: ', x.size() 33 | x = x.view(-1) 34 | x = self.embed(x) 35 | # print 'embed: ', x.size() 36 | ### -1, C, M, N, embed 37 | x = x.view(*self.reshape) 38 | # print 'view: ', x.size() 39 | ### -1, embed, M, N, C 40 | x = x.transpose(1,-1).squeeze() 41 | x = F.relu(self.conv1(x)) 42 | x = F.relu(self.conv2(x)) 43 | x = F.relu(self.conv3(x)) 44 | # x = F.relu(self.conv4(x)) 45 | x = x.view(-1, 192) 46 | x = self.fc1(x) 47 | return x 48 | 49 | if __name__ == '__main__': 50 | from torch.autograd import Variable 51 | inp = torch.LongTensor(1,10,10).zero_() 52 | vocab_size = 10 53 | emb_dim = 3 54 | rank = 10 55 | phi = ObjectModel(vocab_size, emb_dim,inp.size(), rank) 56 | 57 | # enc = nn.Embedding(10,emb_dim,padding_idx=0) 58 | inp = torch.LongTensor(5,1,20,20).zero_() 59 | inp[0][0][0][0]=1 60 | # inp[0][1][0][0]=1 61 | inp[1][0][0][2]=1 62 | # print inp 63 | inp = Variable(inp.view(-1)) 64 | 65 | out = phi.forward(inp) 66 | # print out 67 | # out = out.view(-1,2,3,3,emb_dim) 68 | # out = out.data 69 | print out.size() 70 | loss = out.sum() 71 | loss.backward() 72 | 73 | print phi.parameters 74 | print loss 75 | print loss.grad 76 | # print a 77 | 78 | # print out[0][0][0] 79 | # print out[1][0][0] 80 | 81 | -------------------------------------------------------------------------------- /models/object_model_pos.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | import pdb 8 | 9 | ''' 10 | Object observations are single-channels images 11 | with positive indices denoting objects. 12 | 0's denote no object. 
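Note: unlike the image-based variants above, forward() below expects a < batch x num_objects x 3 > tensor of per-object triples: the object's vocabulary index followed by its two position coordinates (assumed here to be row and col). Each object is embedded, concatenated with its position, passed through shared fully-connected layers, and the object dimension is then collapsed with a max, making the representation invariant to object ordering.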
13 | 14 | ''' 15 | 16 | class ObjectModel(nn.Module): 17 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim): 18 | super(ObjectModel, self).__init__() 19 | (self.num_objects, self.embed_dim) = inp_size 20 | self.hidden_dim_1 = 10 21 | self.hidden_dim_2 = 10 22 | self.out_dim = out_dim 23 | 24 | self.reshape = [-1] 25 | for dim in inp_size: 26 | self.reshape.append(dim) 27 | self.reshape.append(embed_dim) 28 | 29 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0) 30 | self.fc1 = nn.Linear(embed_dim + 2, self.hidden_dim_1) 31 | self.fc2 = nn.Linear(self.hidden_dim_1, self.hidden_dim_2) 32 | self.fc3 = nn.Linear(self.hidden_dim_2, self.out_dim) 33 | # self.fc1 = nn.Linear(embed_dim+2, 20) 34 | # self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5) 35 | # self.conv2 = nn.Conv2d(3, 6, kernel_size=5) 36 | # self.conv3 = nn.Conv2d(6,12, kernel_size=5) 37 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5) 38 | # self.fc1 = nn.Linear(192, out_dim) 39 | self.fc_temp = nn.Linear(self.hidden_dim_1, self.hidden_dim_2) 40 | 41 | ## batch x 5 x 3 42 | def forward(self, x): 43 | ## batch x 5 x 1 44 | indices = x[:,:,0] 45 | 46 | ## batch x 5 x embed 47 | embeddings = self.embed(indices) 48 | 49 | ## batch x 5 x pos 50 | positions = x[:,:,1:].float() 51 | 52 | ## join embeddings and positions 53 | ## batch x 5 x (embed + pos) 54 | x = torch.cat( (embeddings, positions), 2 ) 55 | 56 | ## reshape 57 | ## (batch * 5) x (embed + pos) 58 | x = x.view(-1, self.embed_dim + 2).float() 59 | 60 | ## fc1 61 | ## (batch * 5) x hidden1 62 | x = F.relu( self.fc1(x) ) 63 | x = F.relu( self.fc_temp(x) ) 64 | 65 | ## reshape 66 | ## batch x 5 x hidden1 67 | x = x.view(-1, self.num_objects, self.hidden_dim_1) 68 | 69 | ## get rid of middle object dimension 70 | ## batch x hidden1 71 | x = x.max(1)[0].squeeze() 72 | 73 | ## fc2 74 | ## batch x hidden2 75 | x = F.relu( self.fc2(x) ) 76 | 77 | ## fc3 78 | ## batch x out 79 | x = self.fc3(x) 80 | 81 | return x 82 | 83 | if __name__ == '__main__': 84 | inp = torch.LongTensor(5,3).zero_() 85 | vocab_size = 10 86 | emb_dim = 3 87 | rank = 10 88 | phi = ObjectModel(vocab_size, emb_dim, inp.size(), rank) 89 | 90 | # enc = nn.Embedding(10,emb_dim,padding_idx=0) 91 | inp = torch.LongTensor(2,5,3).zero_() 92 | inp[0][0]=1 93 | inp[0][1]=1 94 | inp[1][0]=8 95 | inp[1][1]=8 96 | print inp 97 | # inp[0][0][0][0]=1 98 | # inp[0][1][0][0]=1 99 | # inp[1][0][0][2]=1 100 | # print inp 101 | inp = Variable(inp) 102 | 103 | out = phi.forward(inp) 104 | # print out 105 | # out = out.view(-1,2,3,3,emb_dim) 106 | out = out.data 107 | print out.size() 108 | 109 | # print out[0][0][0] 110 | # print out[1][0][0] 111 | 112 | -------------------------------------------------------------------------------- /models/object_model_rnn.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | import pdb 8 | 9 | ''' 10 | Object observations are single-channels images 11 | with positive indices denoting objects. 12 | 0's denote no object. 
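Note: this variant runs the per-object (index, position) features through an LSTM instead of max-pooling over objects. As written, self.rnn(x, hidden) returns an (output, hidden) tuple, so the x.size() call that follows will fail and the remaining layers are commented out; it appears to be an unfinished experiment.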
13 | 14 | ''' 15 | 16 | class ObjectModel(nn.Module): 17 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim): 18 | super(ObjectModel, self).__init__() 19 | (self.num_objects, self.embed_dim) = inp_size 20 | self.hidden_dim_1 = 10 21 | self.hidden_dim_2 = 10 22 | self.out_dim = out_dim 23 | 24 | self.reshape = [-1] 25 | for dim in inp_size: 26 | self.reshape.append(dim) 27 | self.reshape.append(embed_dim) 28 | 29 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0) 30 | self.fc1 = nn.Linear(embed_dim + 2, self.hidden_dim_1) 31 | self.fc2 = nn.Linear(self.hidden_dim_1, self.hidden_dim_2) 32 | self.fc3 = nn.Linear(self.hidden_dim_2, self.out_dim) 33 | 34 | self.ninp = self.embed_dim + 2 35 | self.nhid = 10 36 | self.nlayers = 1 37 | self.rnn = nn.LSTM(self.ninp, self.nhid, self.nlayers) 38 | # self.fc1 = nn.Linear(embed_dim+2, 20) 39 | # self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5) 40 | # self.conv2 = nn.Conv2d(3, 6, kernel_size=5) 41 | # self.conv3 = nn.Conv2d(6,12, kernel_size=5) 42 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5) 43 | # self.fc1 = nn.Linear(192, out_dim) 44 | self.fc_temp = nn.Linear(self.hidden_dim_1, self.hidden_dim_2) 45 | 46 | ## batch x 5 x 3 47 | def forward(self, x, hidden): 48 | ## batch x 5 x 1 49 | indices = x[:,:,0] 50 | 51 | ## batch x 5 x embed 52 | embeddings = self.embed(indices) 53 | 54 | ## batch x 5 x pos 55 | positions = x[:,:,1:].float() 56 | 57 | ## join embeddings and positions 58 | ## batch x 5 x (embed + pos) 59 | x = torch.cat( (embeddings, positions), 2 ) 60 | print 'cat: ', x.size() 61 | 62 | x = self.rnn(x, hidden) 63 | print 'x; ', x.size() 64 | 65 | # ## reshape 66 | # ## (batch * 5) x (embed + pos) 67 | # x = x.view(-1, self.embed_dim + 2).float() 68 | 69 | # ## fc1 70 | # ## (batch * 5) x hidden1 71 | # x = F.relu( self.fc1(x) ) 72 | # x = F.relu( self.fc_temp(x) ) 73 | 74 | # ## reshape 75 | # ## batch x 5 x hidden1 76 | # x = x.view(-1, self.num_objects, self.hidden_dim_1) 77 | 78 | # ## get rid of middle object dimension 79 | # ## batch x hidden1 80 | # x = x.max(1)[0].squeeze() 81 | 82 | # ## fc2 83 | # ## batch x hidden2 84 | # x = F.relu( self.fc2(x) ) 85 | 86 | # ## fc3 87 | # ## batch x out 88 | # x = self.fc3(x) 89 | 90 | return x 91 | 92 | def init_hidden(self, bsz): 93 | weight = next(self.parameters()).data 94 | return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()), 95 | Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())) 96 | 97 | if __name__ == '__main__': 98 | inp = torch.LongTensor(5,3).zero_() 99 | vocab_size = 10 100 | emb_dim = 3 101 | rank = 10 102 | phi = ObjectModel(vocab_size, emb_dim, inp.size(), rank) 103 | 104 | batch = 2 105 | hidden = phi.init_hidden(batch) 106 | 107 | # enc = nn.Embedding(10,emb_dim,padding_idx=0) 108 | inp = torch.LongTensor(batch,5,3).zero_() 109 | inp[0][0]=1 110 | inp[0][1]=1 111 | inp[1][0]=8 112 | inp[1][1]=8 113 | print inp 114 | # inp[0][0][0][0]=1 115 | # inp[0][1][0][0]=1 116 | # inp[1][0][0][2]=1 117 | # print inp 118 | inp = Variable(inp) 119 | 120 | out = phi.forward(inp, hidden) 121 | # print out 122 | # out = out.view(-1,2,3,3,emb_dim) 123 | out = out.data 124 | print out.size() 125 | 126 | # print out[0][0][0] 127 | # print out[1][0][0] 128 | 129 | -------------------------------------------------------------------------------- /models/simple_conv.py: -------------------------------------------------------------------------------- 1 | import sys, math, pdb 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, 
torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | ''' 8 | State observations are two-channel images 9 | with 0: puddle, 1: grass, 2: agent 10 | 11 | ''' 12 | 13 | class SimpleConv(nn.Module): 14 | def __init__(self, in_channels): 15 | super(SimpleConv, self).__init__() 16 | 17 | self.in_channels = in_channels 18 | 19 | # self.embed = nn.Embedding(vocab_size, in_channels) 20 | self.conv1 = nn.Conv2d(in_channels, 3, kernel_size=3, padding=1) 21 | self.conv2 = nn.Conv2d(3, 6, kernel_size=3, padding=1) 22 | self.conv3 = nn.Conv2d(6, 12, kernel_size=3, padding=1) 23 | self.conv4 = nn.Conv2d(12,18, kernel_size=3, padding=1) 24 | self.conv5 = nn.Conv2d(18,24, kernel_size=3, padding=1) 25 | self.conv6 = nn.Conv2d(24,18, kernel_size=3, padding=1) 26 | self.conv7 = nn.Conv2d(18,12, kernel_size=3, padding=1) 27 | self.conv8 = nn.Conv2d(12, 6, kernel_size=3, padding=1) 28 | self.conv9 = nn.Conv2d(6, 3, kernel_size=3, padding=1) 29 | self.conv10 = nn.Conv2d(3, 1, kernel_size=3, padding=1) 30 | # # self.conv4 = nn.Conv2d(12,12, kernel_size=5) 31 | # self.fc1 = nn.Linear(192, out_dim) 32 | 33 | def forward(self, x): 34 | # reshape = [] 35 | # for dim in x.size(): reshape.append(dim) 36 | # reshape.append(self.in_channels) 37 | 38 | # ## reshape to vector 39 | # x = x.view(-1) 40 | # ## get embeddings 41 | # x = self.embed(x) 42 | # ## reshape to batch x channels x M x N x embed_dim 43 | # x = x.view(*reshape) 44 | # ## sum over channels in input 45 | # x = x.sum(1) 46 | # # pdb.set_trace() 47 | # ## reshape to batch x embed_dim x M x N 48 | # ## (treats embedding dims as channels) 49 | # x = x.transpose(1,-1)[:,:,:,:,0] #.squeeze() # 50 | # print 'SIZE:', x.size() 51 | # pdb.set_trace() 52 | x = F.relu(self.conv1(x)) 53 | x = F.relu(self.conv2(x)) 54 | x = F.relu(self.conv3(x)) 55 | # x = F.relu(self.conv4(x)) 56 | # x = F.relu(self.conv5(x)) 57 | # x = F.relu(self.conv6(x)) 58 | # x = F.relu(self.conv7(x)) 59 | x = F.relu(self.conv8(x)) 60 | x = F.relu(self.conv9(x)) 61 | x = self.conv10(x) 62 | 63 | # x = x.view(-1, 192) 64 | # x = self.fc1(x) 65 | return x 66 | 67 | 68 | if __name__ == '__main__': 69 | from torch.autograd import Variable 70 | # inp = torch.LongTensor(2,10,10).zero_() 71 | vocab_size = 10 72 | emb_dim = 3 73 | rank = 7 74 | phi = MapModel(vocab_size, emb_dim, rank) 75 | 76 | # enc = nn.Embedding(10,emb_dim,padding_idx=0) 77 | inp = torch.LongTensor(5,2,10,10).zero_() 78 | inp[0][0][0][0]=1 79 | # inp[0][1][0][0]=1 80 | inp[1][0][0][2]=1 81 | print inp.size() 82 | inp = Variable(inp) 83 | 84 | out = phi.forward(inp) 85 | # print out 86 | # out = out.view(-1,2,3,3,emb_dim) 87 | out = out.data 88 | print out.size() 89 | 90 | # print out[0][0][0] 91 | # print out[1][0][0] 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /models/state_model.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | ''' 8 | State observations are two-channel images 9 | with 0: puddle, 1: grass, 2: agent 10 | 11 | ''' 12 | 13 | class Phi(nn.Module): 14 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim): 15 | super(Phi, self).__init__() 16 | 17 | self.reshape = [-1] 18 | for dim in inp_size: 19 | self.reshape.append(dim) 20 | self.reshape.append(embed_dim) 21 | 22 | self.embed = 
nn.Embedding(vocab_size, embed_dim) 23 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5) 24 | self.conv2 = nn.Conv2d(3, 6, kernel_size=5) 25 | self.conv3 = nn.Conv2d(6,12, kernel_size=5) 26 | self.conv4 = nn.Conv2d(12,12, kernel_size=5) 27 | self.fc1 = nn.Linear(192, out_dim) 28 | 29 | def forward(self, x): 30 | x = x.view(-1) 31 | x = self.embed(x) 32 | x = x.view(*self.reshape) 33 | x = x.sum(1, keepdim=True) 34 | x = x.transpose(1,-1).squeeze() 35 | x = F.relu(self.conv1(x)) 36 | x = F.relu(self.conv2(x)) 37 | x = F.relu(self.conv3(x)) 38 | x = F.relu(self.conv4(x)) 39 | x = x.view(-1, 192) 40 | x = self.fc1(x) 41 | return x 42 | 43 | 44 | if __name__ == '__main__': 45 | inp = torch.LongTensor(2,20,20).zero_() 46 | vocab_size = 10 47 | emb_dim = 3 48 | rank = 7 49 | phi = Phi(vocab_size, emb_dim,inp.size(), rank) 50 | 51 | # enc = nn.Embedding(10,emb_dim,padding_idx=0) 52 | inp = torch.LongTensor(5,2,20,20).zero_() 53 | inp[0][0][0][0]=1 54 | inp[0][1][0][0]=1 55 | inp[1][0][0][2]=1 56 | print inp 57 | inp = Variable(inp.view(-1)) 58 | 59 | out = phi.forward(inp) 60 | # print out 61 | # out = out.view(-1,2,3,3,emb_dim) 62 | out = out.data 63 | print out.size() 64 | 65 | # print out[0][0][0] 66 | # print out[1][0][0] 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /models/state_model_10.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | ''' 8 | State observations are two-channel images 9 | with 0: puddle, 1: grass, 2: agent 10 | 11 | ''' 12 | 13 | class Phi(nn.Module): 14 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim): 15 | super(Phi, self).__init__() 16 | 17 | self.reshape = [-1] 18 | for dim in inp_size: 19 | self.reshape.append(dim) 20 | self.reshape.append(embed_dim) 21 | 22 | self.embed = nn.Embedding(vocab_size, embed_dim) 23 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=3) 24 | self.conv2 = nn.Conv2d(3, 6, kernel_size=3) 25 | self.conv3 = nn.Conv2d(6,12, kernel_size=3) 26 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5) 27 | self.fc1 = nn.Linear(192, out_dim) 28 | 29 | def forward(self, x): 30 | x = x.view(-1) 31 | x = self.embed(x) 32 | x = x.view(*self.reshape) 33 | x = x.sum(1, keepdim=True) 34 | x = x.transpose(1,-1).squeeze() 35 | x = F.relu(self.conv1(x)) 36 | x = F.relu(self.conv2(x)) 37 | x = F.relu(self.conv3(x)) 38 | # x = F.relu(self.conv4(x)) 39 | x = x.view(-1, 192) 40 | x = self.fc1(x) 41 | return x 42 | 43 | 44 | if __name__ == '__main__': 45 | from torch.autograd import Variable 46 | inp = torch.LongTensor(2,10,10).zero_() 47 | vocab_size = 10 48 | emb_dim = 3 49 | rank = 7 50 | phi = Phi(vocab_size, emb_dim,inp.size(), rank) 51 | 52 | # enc = nn.Embedding(10,emb_dim,padding_idx=0) 53 | inp = torch.LongTensor(5,2,20,20).zero_() 54 | inp[0][0][0][0]=1 55 | inp[0][1][0][0]=1 56 | inp[1][0][0][2]=1 57 | print inp 58 | inp = Variable(inp.view(-1)) 59 | 60 | out = phi.forward(inp) 61 | # print out 62 | # out = out.view(-1,2,3,3,emb_dim) 63 | out = out.data 64 | print out.size() 65 | 66 | # print out[0][0][0] 67 | # print out[1][0][0] 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /models/text_model.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | ''' 7 | Text inputs are seq x batch 8 | ''' 9 | class TextModel(nn.Module): 10 | 11 | def __init__(self, vocab_size, ninp, nhid, nlayers, out_dim): 12 | super(TextModel, self).__init__() 13 | 14 | self.rnn_type = 'LSTM' 15 | self.nhid = nhid 16 | self.nlayers = nlayers 17 | self.out_dim = out_dim 18 | 19 | self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=0) 20 | self.rnn = nn.LSTM(ninp, nhid, nlayers) 21 | self.decoder = nn.Linear(nhid, out_dim) 22 | self.init_weights() 23 | 24 | def init_weights(self): 25 | initrange = 0.1 26 | # self.encoder.weight.data.uniform_(-initrange, initrange) 27 | self.decoder.bias.data.fill_(0) 28 | self.decoder.weight.data.uniform_(-initrange, initrange) 29 | 30 | def forward(self, inp, hidden): 31 | emb = self.encoder(inp) 32 | # print 'emb: ', emb.size() 33 | output, hidden = self.rnn(emb, hidden) 34 | # print 'TEXT: ', output.size() 35 | final_output = output[-1,:,:] 36 | decoded = self.decoder(final_output) 37 | return decoded 38 | 39 | def init_hidden(self, bsz): 40 | weight = next(self.parameters()).data 41 | return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()), 42 | Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())) 43 | 44 | 45 | if __name__ == '__main__': 46 | batch = 12 47 | seq = 4 48 | vocab = 20 49 | ninp = 5 50 | nhid = 3 51 | nlayers = 2 52 | out_dim = 7 53 | rnn = TextModel(vocab, ninp, nhid, nlayers, out_dim) 54 | 55 | hidden = rnn.init_hidden(batch) 56 | ## inds x batch 57 | ## 58 | inp = torch.floor(torch.rand(seq,batch)*vocab).long() 59 | inp[0,:5] = 0 60 | inp[1,:5] = 0 61 | # inp[2,:] = 0 62 | # inp[3,:] = 0 63 | inp = Variable(inp) 64 | # inp = Variable(torch.LongTensor((1,2,3))) 65 | print 'INPUT: ', inp 66 | print inp.size() 67 | out = rnn.forward( inp,hidden ) 68 | 69 | print 'OUT: ', out 70 | print out.size() 71 | # print 'HID: ', hid 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /models/uvfa_pos.py: -------------------------------------------------------------------------------- 1 | ## predicts entire value map 2 | ## rather than a single value 3 | 4 | import torch 5 | import math, torch.nn as nn, pdb 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import models, utils 9 | 10 | class UVFA_pos(nn.Module): 11 | def __init__(self, state_vocab, object_vocab, args, map_dim = 10, batch_size = 32): 12 | super(UVFA_pos, self).__init__() 13 | 14 | self.state_vocab = state_vocab 15 | self.object_vocab = object_vocab 16 | self.total_vocab = state_vocab + object_vocab 17 | self.pos_size = 2 18 | 19 | self.rank = args.rank 20 | self.map_dim = map_dim 21 | self.batch_size = batch_size 22 | self.positions = self.__agent_pos() 23 | 24 | ## add one for agent position 25 | self.input_dim = (self.total_vocab + 1) * (map_dim**2) 26 | self.world_layers = [self.input_dim, 128, 128, args.rank] 27 | self.world_mlp = models.MLP(self.world_layers) 28 | 29 | # self.object_dim = self.object_vocab * (map_dim**2) 30 | self.pos_layers = [self.pos_size, 128, 128, args.rank] 31 | self.pos_mlp = models.MLP(self.pos_layers) 32 | 33 | ''' 34 | returns tensor with one-hot vector encoding 35 | [1, 2, 3, ..., map_dim] repeated batch_size times 36 | < batch_size * map_dim, state_vocab > 37 | ''' 38 | def __agent_pos(self): 39 | size = self.map_dim**2 40 | positions = torch.zeros(self.batch_size*size, 100, 1) 41 | # 
print positions.size() 42 | for ind in range(size): 43 | # print ind, ind*self.batch_size, (ind+1)*self.batch_size, ind, positions.size() 44 | # positions[ind*self.batch_size:(ind+1)*self.batch_size, ind] = 1 45 | positions[ind:self.batch_size*size:size, ind] = 1 46 | # pdb.set_trace() 47 | return Variable( positions.cuda() ) 48 | 49 | def __repeat_position(self, x): 50 | if x.size() == 2: 51 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1) 52 | else: 53 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1,1) 54 | 55 | ''' 56 | < batch_size x N > 57 | < batch_size*100 x N > 58 | ''' 59 | def __construct_inp(self, world, pos): 60 | world = self.__repeat_position(world) 61 | world = world.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.total_vocab) 62 | ## add agent position 63 | world = torch.cat( (world, self.positions), -1) 64 | ## reshape to (batched) vector for input to MLPs 65 | world = world.view(-1, self.input_dim) 66 | 67 | # obj = self.__repeat_position(obj) 68 | # obj = obj.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.object_vocab) 69 | # obj = obj.view(-1, self.object_dim) 70 | 71 | pos = self.__repeat_position(pos) 72 | # pos = pos.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.pos_size) 73 | pos = pos.view(-1, self.pos_size) 74 | 75 | return world, pos 76 | 77 | 78 | def forward(self, inp): 79 | (state, obj, pos) = inp 80 | batch_size = state.size(0) 81 | # text = text.transpose(0,1) 82 | # hidden = self.lstm.init_hidden(batch_size * self.map_dim**2) 83 | 84 | if batch_size != self.batch_size: 85 | self.batch_size = batch_size 86 | self.positions = self.__agent_pos() 87 | 88 | ## reshape to (batched) vectors 89 | ## can't scatter Variables 90 | state = state.data.view(-1, self.map_dim**2, 1) 91 | obj = obj.data.view(-1, self.map_dim**2, 1) 92 | 93 | ## make state / object indices into one-hot vectors 94 | state_binary = torch.zeros(batch_size, self.map_dim**2, self.total_vocab).cuda() 95 | object_binary = torch.zeros(batch_size, self.map_dim**2, self.total_vocab).cuda() 96 | state_binary.scatter_(2, state, 1) 97 | object_binary.scatter_(2, obj+self.state_vocab, 1) 98 | 99 | ## < batch x 100 x total_vocab > 100 | ## state_binary will only have non-zero components in the first state_vocab components 101 | ## object_binary will only have non-zero components in state_vocab:total_vocab components 102 | input_binary = state_binary + object_binary 103 | # pdb.set_trace() 104 | 105 | input_binary = Variable( input_binary ) 106 | # object_binary = Variable( object_binary ) 107 | # print input_binary.size(), pos.size() 108 | # pdb.set_trace() 109 | input_binary, pos = self.__construct_inp(input_binary, pos) 110 | 111 | # print state_binary.size(), object_binary.size(), text.size() 112 | # object_binary = self.__repeat_position(object_binary) 113 | # object_binary = object_binary.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.object_vocab) 114 | 115 | ## add in agent position 116 | ## < batch x 100 x 2 > 117 | ## < batch x 100 x 100 x 2 > 118 | # state_binary = self.__repeat_position(state_binary) 119 | ## < batch*100 x 100 x 2 > 120 | # state_binary = state_binary.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.state_vocab) 121 | 122 | ## add agent position 123 | # state_binary = torch.cat( (state_binary, self.positions), -1) 124 | 125 | # pdb.set_trace() 126 | # print state_binary.size(), object_binary.size() 127 | 128 | ## reshape to (batched) vectors for input to MLPs 129 | ## turn back into Variables for backprop 130 | # 
state_binary = state_binary.view(-1, self.state_dim) 131 | # object_binary = object_binary.view(-1, self.object_dim) 132 | # print input_binary.size() 133 | # pdb.set_trace() 134 | ## < batch*100 x rank > 135 | world_out = self.world_mlp(input_binary) 136 | pos_out = self.pos_mlp(pos) 137 | 138 | # lstm_out = self.lstm.forward(text, hidden) 139 | 140 | 141 | # print lstm_out.size() 142 | 143 | # print world_out.size(), pos_out.size() 144 | 145 | values = world_out * pos_out 146 | map_pred = values.sum(1, keepdim=True).view(self.batch_size, self.map_dim, self.map_dim) 147 | 148 | 149 | return map_pred 150 | 151 | -------------------------------------------------------------------------------- /models/uvfa_text.py: -------------------------------------------------------------------------------- 1 | ## predicts entire value map 2 | ## rather than a single value 3 | 4 | import torch 5 | import math, torch.nn as nn, pdb 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import models, utils 9 | 10 | class UVFA_text(nn.Module): 11 | def __init__(self, lstm, state_vocab, object_vocab, args, map_dim = 10, batch_size = 32): 12 | super(UVFA_text, self).__init__() 13 | 14 | self.state_vocab = state_vocab 15 | self.object_vocab = object_vocab 16 | self.lstm = lstm 17 | self.rank = args.rank 18 | self.map_dim = map_dim 19 | self.batch_size = batch_size 20 | self.positions = self.__agent_pos() 21 | 22 | ## add one for agent position 23 | self.state_dim = (self.state_vocab+1) * (map_dim**2) 24 | # self.state_dim = self.state_vocab * map_dim**2 25 | self.state_layers = [self.state_dim, 128, 128, args.rank] 26 | self.state_mlp = models.MLP(self.state_layers) 27 | 28 | self.object_dim = self.object_vocab * (map_dim**2) 29 | self.object_layers = [self.object_dim, 128, 128, args.rank] 30 | self.object_mlp = models.MLP(self.object_layers) 31 | 32 | ''' 33 | returns tensor with one-hot vector encoding 34 | [1, 2, 3, ..., map_dim] repeated batch_size times 35 | < batch_size * map_dim, state_vocab > 36 | ''' 37 | def __agent_pos(self): 38 | size = self.map_dim**2 39 | positions = torch.zeros(self.batch_size*size, 100, 1) 40 | # print positions.size() 41 | for ind in range(size): 42 | # print ind, ind*self.batch_size, (ind+1)*self.batch_size, ind, positions.size() 43 | # positions[ind*self.batch_size:(ind+1)*self.batch_size, ind] = 1 44 | positions[ind:self.batch_size*size:size, ind] = 1 45 | # pdb.set_trace() 46 | return Variable( positions.cuda() ) 47 | 48 | def __repeat_position(self, x): 49 | if x.size() == 2: 50 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1) 51 | else: 52 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1,1) 53 | 54 | ''' 55 | < batch_size x N > 56 | < batch_size*100 x N > 57 | ''' 58 | def __construct_inp(self, state, obj, text): 59 | state = self.__repeat_position(state) 60 | state = state.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.state_vocab) 61 | ## add agent position 62 | state = torch.cat( (state, self.positions), -1) 63 | ## reshape to (batched) vector for input to MLPs 64 | state = state.view(-1, self.state_dim) 65 | 66 | obj = self.__repeat_position(obj) 67 | obj = obj.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.object_vocab) 68 | obj = obj.view(-1, self.object_dim) 69 | 70 | instr_length = text.size(1) 71 | ## < batch x length > 72 | ## < batch x 100 x length > 73 | text = self.__repeat_position(text) 74 | ## < batch*100 x length > 75 | text = text.view(self.batch_size*self.map_dim**2,instr_length) 76 | ## < length x 
batch*100 > 77 | text = text.transpose(0,1) 78 | ## < batch*100 x rank > 79 | 80 | return state, obj, text 81 | 82 | 83 | def forward(self, inp): 84 | (state, obj, text) = inp 85 | batch_size = state.size(0) 86 | # text = text.transpose(0,1) 87 | hidden = self.lstm.init_hidden(batch_size * self.map_dim**2) 88 | 89 | if batch_size != self.batch_size: 90 | self.batch_size = batch_size 91 | self.positions = self.__agent_pos() 92 | 93 | ## reshape to (batched) vectors 94 | ## can't scatter Variables 95 | state = state.data.view(-1, self.map_dim**2, 1) 96 | obj = obj.data.view(-1, self.map_dim**2, 1) 97 | 98 | ## make state / object indices into one-hot vectors 99 | state_binary = torch.zeros(batch_size, self.map_dim**2, self.state_vocab).cuda() 100 | object_binary = torch.zeros(batch_size, self.map_dim**2, self.object_vocab).cuda() 101 | state_binary.scatter_(2, state, 1) 102 | object_binary.scatter_(2, obj, 1) 103 | 104 | state_binary = Variable( state_binary ) 105 | object_binary = Variable( object_binary ) 106 | 107 | state_binary, object_binary, text = self.__construct_inp(state_binary, object_binary, text) 108 | 109 | # print state_binary.size(), object_binary.size(), text.size() 110 | # object_binary = self.__repeat_position(object_binary) 111 | # object_binary = object_binary.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.object_vocab) 112 | 113 | ## add in agent position 114 | ## < batch x 100 x 2 > 115 | ## < batch x 100 x 100 x 2 > 116 | # state_binary = self.__repeat_position(state_binary) 117 | ## < batch*100 x 100 x 2 > 118 | # state_binary = state_binary.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.state_vocab) 119 | 120 | ## add agent position 121 | # state_binary = torch.cat( (state_binary, self.positions), -1) 122 | 123 | # pdb.set_trace() 124 | # print state_binary.size(), object_binary.size() 125 | 126 | ## reshape to (batched) vectors for input to MLPs 127 | ## turn back into Variables for backprop 128 | # state_binary = state_binary.view(-1, self.state_dim) 129 | # object_binary = object_binary.view(-1, self.object_dim) 130 | 131 | ## < batch*100 x rank > 132 | state_out = self.state_mlp(state_binary) 133 | object_out = self.object_mlp(object_binary) 134 | 135 | lstm_out = self.lstm.forward(text, hidden) 136 | 137 | 138 | # print lstm_out.size() 139 | 140 | values = state_out * object_out * lstm_out 141 | map_pred = values.sum(1, keepdim=True).view(self.batch_size, self.map_dim, self.map_dim) 142 | 143 | 144 | return map_pred 145 | 146 | # def forward(self, inp): 147 | # (state, obj, text) = inp 148 | # batch_size = state.size(0) 149 | # text = text.transpose(0,1) 150 | # hidden = self.lstm.init_hidden(batch_size) 151 | 152 | # if batch_size != self.batch_size: 153 | # self.batch_size = batch_size 154 | # # self.positions = self.__agent_pos() 155 | 156 | # ## reshape to (batched) vectors 157 | # ## can't scatter Variables 158 | # state = state.data.view(-1, self.map_dim**2, 1) 159 | # obj = obj.data.view(-1, self.map_dim**2, 1) 160 | 161 | # ## make state / object indices into one-hot vectors 162 | # state_binary = torch.zeros(batch_size, self.map_dim**2, self.state_vocab).cuda() 163 | # object_binary = torch.zeros(batch_size, self.map_dim**2, self.object_vocab).cuda() 164 | # state_binary.scatter_(2, state, 1) 165 | # object_binary.scatter_(2, obj, 1) 166 | 167 | # state_binary = Variable( state_binary ) 168 | # object_binary = Variable( object_binary ) 169 | 170 | # ## add in agent position 171 | # ## < batch x 100 x 2 > 172 | # ## < batch x 100 x 
100 x 2 > 173 | # # state_binary = state_binary.unsqueeze(1).repeat(1,self.map_dim**2,1,1) 174 | # # state_binary = self.__repeat_position(state_binary) 175 | # ## < batch*100 x 100 x 2 > 176 | # # state_binary = state_binary.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.state_vocab) 177 | 178 | # ## add agent position 179 | # # state_binary = torch.cat( (state_binary, self.positions), -1) 180 | 181 | # # pdb.set_trace() 182 | # # print state_binary.size(), object_binary.size() 183 | 184 | # ## reshape to (batched) vectors for input to MLPs 185 | # ## turn back into Variables for backprop 186 | # state_binary = state_binary.view(-1, self.state_dim) 187 | # object_binary = object_binary.view(-1, self.object_dim) 188 | # print 'state: ', state_binary.size(), object_binary.size() 189 | # state_out = self.state_mlp(state_binary) 190 | # object_out = self.object_mlp(object_binary) 191 | 192 | # # print state_out.size(), object_out.size() 193 | 194 | # lstm_out = self.lstm.forward(text, hidden) 195 | 196 | # # object_out = self.__repeat_position(object_out).view(-1, self.rank) 197 | # # lstm_out = self.__repeat_position(lstm_out).view(-1, self.rank) 198 | 199 | # # print state_out.size(), object_out.size(), lstm_out.size() 200 | 201 | # values = state_out * object_out * lstm_out 202 | # # map_pred = values.sum(1).view(self.batch_size, self.map_dim**2) 203 | # values = values.sum(1) 204 | # # print values.size() 205 | # map_pred = values.unsqueeze(-1).repeat(1,self.map_dim,self.map_dim) 206 | # print values.size(), map_pred.size() 207 | 208 | # return map_pred 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | -------------------------------------------------------------------------------- /models/values_factorized.py: -------------------------------------------------------------------------------- 1 | import torch, torch.nn as nn 2 | 3 | class ValuesFactorized(nn.Module): 4 | 5 | def __init__(self, state_model, goal_model): 6 | super(ValuesFactorized, self).__init__() 7 | 8 | self.state_model = state_model 9 | self.goal_model = goal_model 10 | 11 | # def forward_large(self, inputs, batch_size = 32): 12 | 13 | 14 | def forward(self, inp): 15 | state, objects, instructions = inp 16 | # print 'in Values: ', state.size(), objects.size(), instructions.size() 17 | state_embedding = self.state_model.forward(state) 18 | goal_embedding = self.goal_model.forward( (objects, instructions) ) 19 | 20 | values = state_embedding * goal_embedding 21 | values = values.sum(1, keepdim=True) 22 | # ## num_states x rank 23 | # num_states = state_embedding.size(0) 24 | # ## num_goals x rank 25 | # num_goals = goal_embedding.size(0) 26 | 27 | # ## num_states x num_goals x rank 28 | # state_rep = state_embedding.unsqueeze(1).repeat(1,num_goals,1) 29 | # goal_rep = goal_embedding.repeat(num_states,1,1) 30 | 31 | # values = state_rep * goal_rep 32 | # values = values.sum(2).squeeze() 33 | # values = values.transpose(0,1) 34 | 35 | return values 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from training import * 2 | from evaluation import * 3 | from agent import * 4 | from score_iteration import * 5 | -------------------------------------------------------------------------------- /pipeline/evaluation.py: -------------------------------------------------------------------------------- 1 | import os, 
math, pickle, torch, numpy as np, pdb, pipeline 2 | from torch.autograd import Variable 3 | import matplotlib; matplotlib.use('Agg') 4 | from matplotlib import cm 5 | from matplotlib import pyplot as plt 6 | from tqdm import tqdm 7 | 8 | 9 | ''' 10 | saves predictions of model, targets, and MDP info (rewards / terminal map) as pickle files 11 | inputs is a tuple of (layouts, objects, instruction_indices) 12 | assumes that save path already exists 13 | ''' 14 | def save_predictions(model, inputs, targets, rewards, terminal, text_vocab, save_path, prefix=''): 15 | ## wrap tensors in Variables to pass to model 16 | input_vars = ( Variable(tensor.contiguous()) for tensor in inputs ) 17 | predictions = model(input_vars) 18 | 19 | ## convert to numpy arrays for saving to disk 20 | predictions = predictions.data.cpu().numpy() 21 | targets = targets.cpu().numpy() 22 | 23 | ## save the predicted and target value maps 24 | ## as well as info about the MDP and instruction 25 | pickle.dump(predictions, open(os.path.join(save_path, prefix+'predictions.p'), 'wb') ) 26 | pickle.dump(targets, open(os.path.join(save_path, prefix+'targets.p'), 'wb') ) 27 | pickle.dump(rewards, open(os.path.join(save_path, prefix+'rewards.p'), 'wb') ) 28 | pickle.dump(terminal, open(os.path.join(save_path, prefix+'terminal.p'), 'wb') ) 29 | pickle.dump(text_vocab, open(os.path.join(save_path, prefix+'vocab.p'), 'wb') ) 30 | 31 | 32 | ''' 33 | test set is dict from 34 | world number --> (state_obs, goal_obs, instruct_inds, values) 35 | ''' 36 | 37 | def evaluate(model, test_set, savepath=None): 38 | progress = tqdm(total=len(test_set)) 39 | count = 0 40 | for key, (state_obs, goal_obs, instruct_words, instruct_inds, targets) in test_set.iteritems(): 41 | progress.update(1) 42 | 43 | state = Variable( torch.Tensor(state_obs).long().cuda() ) 44 | objects = Variable( torch.Tensor(goal_obs).long().cuda() ) 45 | instructions = Variable( torch.Tensor(instruct_inds).long().cuda() ) 46 | targets = torch.Tensor(targets) 47 | # print state.size(), objects.size(), instructions.size(), targets.size() 48 | 49 | preds = model.forward( (state, objects, instructions) ).data.cpu() 50 | 51 | state_dim = 1 52 | for dim in state.size()[-2:]: 53 | state_dim *= dim 54 | 55 | if savepath: 56 | num_goals = preds.size(0) / state_dim 57 | for goal_num in range(num_goals): 58 | lower = goal_num * state_dim 59 | upper = (goal_num + 1) * state_dim 60 | fullpath = os.path.join(savepath, \ 61 | str(key) + '_' + str(goal_num) + '.png') 62 | pred = preds[lower:upper].numpy() 63 | targ = targets[lower:upper].numpy() 64 | instr = instruct_words[lower] 65 | 66 | pipeline.visualize_value_map(pred, targ, fullpath, title=instr) 67 | 68 | 69 | def get_children(M, N): 70 | children = {} 71 | for i in range(M): 72 | for j in range(N): 73 | pos = (i,j) 74 | children[pos] = [] 75 | for di in range( max(i-1, 0), min(i+1, M-1)+1 ): 76 | for dj in range( max(j-1, 0), min(j+1, N-1)+1 ): 77 | child = (di, dj) 78 | if pos != child and (i == di or j == dj): 79 | children[pos].append( child ) 80 | return children 81 | 82 | 83 | ''' 84 | values is M x N map of predicted values 85 | ''' 86 | def get_policy(values): 87 | values = values.squeeze() 88 | M, N = values.shape 89 | states = [(i,j) for i in range(M) for j in range(N)] 90 | children = get_children( M, N ) 91 | policy = {} 92 | for state in states: 93 | reachable = children[state] 94 | selected = sorted(reachable, key = lambda x: values[x], reverse=True) 95 | policy[state] = selected[0] 96 | return policy 97 | 98 | def 
simulate(model, sim_set): 99 | # progress = tqdm(total=len(test_set)) 100 | steps_list = [] 101 | count = 0 102 | for key in tqdm(range(len(sim_set))): 103 | (state_obs, goal_obs, instruct_words, instruct_inds, targets, mdps) = sim_set[key] 104 | # progress.update(1) 105 | # print torch.Tensor(state_obs).long().cuda() 106 | state = Variable( torch.Tensor(state_obs).long().cuda() ) 107 | objects = Variable( torch.Tensor(goal_obs).long().cuda() ) 108 | instructions = Variable( torch.Tensor(instruct_inds).long().cuda() ) 109 | targets = torch.Tensor(targets) 110 | # print state.size(), objects.size(), instructions.size() 111 | 112 | preds = model.forward(state, objects, instructions).data.cpu().numpy() 113 | # print 'sim preds: ', preds.shape 114 | 115 | ## average over all goals 116 | num_goals = preds.shape[0] 117 | for ind in range(num_goals): 118 | # print ind 119 | mdp = mdps[ind] 120 | values = preds[ind,:] 121 | dim = int(math.sqrt(values.size)) 122 | positions = [(i,j) for i in range(dim) for j in range(dim)] 123 | # print 'dim: ', dim 124 | values = preds[ind,:].reshape(dim, dim) 125 | policy = mdp.get_policy(values) 126 | 127 | # plt.clf() 128 | # plt.pcolor(policy) 129 | 130 | 131 | ## average over all start positions 132 | for start_pos in positions: 133 | steps = mdp.simulate(policy, start_pos) 134 | steps_list.append(steps) 135 | # pdb.set_trace() 136 | # print 'simulating: ', start_pos, steps 137 | avg_steps = np.mean(steps_list) 138 | # print 'avg steps: ', avg_steps, len(steps_list), len(sim_set), num_goals 139 | return avg_steps 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /pipeline/run_eval.py: -------------------------------------------------------------------------------- 1 | #!/om/user/janner/anaconda2/envs/pytorch/bin/python 2 | 3 | ## norbf 4 | # python run_eval.py --save_path rl_logs_analysis/qos_1250epochs_reinforce_norbfmodel_localmode_1566turk_0.001lr_0.95gamma_4batch_1rep/ 5 | # python run_eval.py --save_path rl_logs_analysis/qos_reinforce_norbfmodel_globalmode_1566turk_0.001lr_0.95gamma_4batch_1rep/ 6 | 7 | ## uvfa-text 8 | # python run_eval.py --save_path rl_logs_analysis//reinforce_uvfa-textmodel_localmode_1566turk_0.001lr_0.95gamma_32batch_3rep/ 9 | # python run_eval.py --save_path rl_logs_analysis//reinforce_uvfa-textmodel_globalmode_1566turk_0.001lr_0.95gamma_32batch_2rep/ 10 | 11 | ## cnn-lstm 12 | # run_eval.py --save_path rl_logs_analysis/reinforce_cnn-lstmmodel_localmode_1566turk_0.001lr_0.95gamma_16batch_2rep/ 13 | # python run_eval.py --save_path rl_logs_analysis/reinforce_cnn-lstmmodel_globalmode_1566turk_0.001lr_0.95gamma_32batch_3rep/ 14 | 15 | import sys, os, subprocess, argparse, numpy as np, pickle, pdb 16 | sys.path.append('/om/user/janner/mit/urop/direction_decomposition/') 17 | import environment, pipeline, reinforce 18 | 19 | parser = argparse.ArgumentParser() 20 | # parser.add_argument('--save_path', type=str, default='logs/trial_nobases_test') 21 | parser.add_argument('--save_path', type=str, default='curves/local_full_1500_wsynth-0-18,20,21,22,23,24,25,26,27,49,53,20,30,40,50,35,45_4/') 22 | # parser.add_argument('--metric_path', type=str, default='curves/local_full_1500_wsynth-0-18,20,21,22,23,24,25,26,27,49,53,20,30,40,50,35,45_4/') 23 | args = parser.parse_args() 24 | args.metric_path = args.save_path 25 | 26 | print args.save_path 27 | predictions = pickle.load( open(os.path.join(args.save_path, 'test_predictions.p'), 'rb') ).squeeze() 28 | 
targets = pickle.load( open(os.path.join(args.save_path, 'test_targets.p'), 'rb') ).squeeze() 29 | rewards = pickle.load( open(os.path.join(args.save_path, 'test_rewards.p'), 'rb') ).squeeze() 30 | terminal = pickle.load( open(os.path.join(args.save_path, 'test_terminal.p'), 'rb') ).squeeze() 31 | 32 | rewards = rewards.cpu().numpy() 33 | terminal = terminal.cpu().numpy() 34 | 35 | def get_states(M, N): 36 | states = [(i,j) for i in range(M) for j in range(N)] 37 | return states 38 | 39 | def get_children(M, N): 40 | children = {} 41 | for i in range(M): 42 | for j in range(N): 43 | pos = (i,j) 44 | children[pos] = [] 45 | for di in range( max(i-1, 0), min(i+1, M-1)+1 ): 46 | for dj in range( max(j-1, 0), min(j+1, N-1)+1 ): 47 | child = (di, dj) 48 | if pos != child and (i == di or j == dj): 49 | children[pos].append( child ) 50 | return children 51 | 52 | ''' 53 | for each state (i,j) in states, returns a list 54 | of neighboring states with approximated values 55 | in descending order 56 | ''' 57 | def get_policy(values): 58 | values = values.squeeze() 59 | policy = {} 60 | for state in STATES: 61 | reachable = CHILDREN[state] 62 | selected = sorted(reachable, key = lambda x: values[x], reverse = True) 63 | policy[state] = selected 64 | return policy 65 | 66 | def simulate_single(reward_map, terminal_map, approx_values, start_pos, max_steps = 75): 67 | # world = world.squeeze() 68 | # reward_map = reward_map.data.squeeze() 69 | # terminal_map = terminal_map.data.squeeze() 70 | # approx_values = approx_values 71 | 72 | M, N = reward_map.shape 73 | # print 'M, N: ', M, N 74 | if M != 10 or N != 10: 75 | raise RuntimeError( 'wrong size: {}x{}, expected: {}x{}'.format(M, N, self.M, self.N) ) 76 | # self._refresh_size(M, N) 77 | policy = get_policy(approx_values) 78 | 79 | pos = start_pos 80 | visited = set([pos]) 81 | trajectory = [] 82 | 83 | total_reward = 0 84 | for step in range(max_steps): 85 | val = approx_values[pos] 86 | rew = reward_map[pos] 87 | term = terminal_map[pos] 88 | # print 'VAL: ', val, '\nREW: ', rew, '\nTERM: ', term, approx_values.size(), reward_map.size(), terminal_map.size() 89 | trajectory.append( (pos, val, rew, term) ) 90 | 91 | total_reward += rew * (GAMMA ** step) 92 | if term: 93 | # print 'GOT REWARD: ', rew, step, rew * (self.gamma ** step) 94 | # print '\n\nDONE\n\n\n\n' 95 | break 96 | 97 | reachable = policy[pos] 98 | selected = 0 99 | while selected < len(reachable) and reachable[selected] in visited: 100 | # print ' visited ', selected, reachable[selected] 101 | selected += 1 102 | if selected == len(reachable): 103 | # print '\n\nVISITED ALL', pos, [n in visited for n in reachable], '\n\n\n\n' 104 | selected = 0 105 | # return trajectory 106 | break 107 | 108 | pos = reachable[selected] 109 | visited.add(pos) 110 | # print 'position: ', pos 111 | # print 'traj: ', len(trajectory), 'rew: ', total_reward 112 | return total_reward 113 | 114 | STATES = get_states(10,10) 115 | CHILDREN = get_children(10,10) 116 | GAMMA = 0.95 117 | 118 | # pdb.set_trace() 119 | quality_path = os.path.join(args.metric_path, 'quality') 120 | if not os.path.exists( quality_path ): 121 | subprocess.call(['mkdir', quality_path]) 122 | 123 | num_worlds = targets.shape[0] 124 | print 'Num worlds: {}'.format(num_worlds) 125 | 126 | mse = np.sum(np.power(predictions - targets, 2)) / predictions.size 127 | print 'MSE: {}'.format(mse) 128 | 129 | cumulative_normed = 0 130 | manhattan = 0 131 | cumulative_per_score = 0 132 | cumulative_score = 0 133 | for ind in range(num_worlds): 134 | 
pred = predictions[ind] 135 | targ = targets[ind] 136 | 137 | pred_max = np.unravel_index(np.argmax(pred), pred.shape) 138 | targ_max = np.unravel_index(np.argmax(targ), targ.shape) 139 | man = abs(pred_max[0] - targ_max[0]) + abs(pred_max[1] - targ_max[1]) 140 | 141 | # pdb.set_trace() 142 | 143 | unif = np.ones( pred.shape ) 144 | rew = rewards[ind] 145 | term = terminal[ind] 146 | 147 | mdp = environment.MDP(None, rew, term) 148 | si = pipeline.ScoreIteration(mdp, pred) 149 | avg_pred, scores_pred = si.iterate() 150 | 151 | mdp = environment.MDP(None, rew, term) 152 | si = pipeline.ScoreIteration(mdp, targ) 153 | avg_targ, scores_targ = si.iterate() 154 | 155 | mdp = environment.MDP(None, rew, term) 156 | si = pipeline.ScoreIteration(mdp, unif) 157 | avg_unif, scores_unif = si.iterate() 158 | 159 | avg_per_score = np.divide(scores_pred-scores_unif, scores_targ-scores_unif) 160 | avg_per_score[avg_per_score != avg_per_score] = 1 161 | avg_per_score = np.mean(avg_per_score) 162 | # pdb.set_trace() 163 | 164 | start_pos = (np.random.randint(10), np.random.randint(10)) 165 | score = 0 #simulate_single(rew, term, pred, start_pos) 166 | 167 | normed = (avg_pred - avg_unif) / (avg_targ - avg_unif) 168 | cumulative_normed += normed 169 | manhattan += man 170 | cumulative_per_score += avg_per_score 171 | cumulative_score += score 172 | 173 | sys.stdout.write( '{:4}:\t{:4}\t{:4}\t{:4}\t{:4}\t{:4}\t{:.4}\t{:.4}\t{:.3}\r'.format(ind, avg_pred, avg_targ, avg_unif, normed, 174 | cumulative_normed / (ind + 1), cumulative_per_score / (ind + 1), float(manhattan) / (ind + 1), float(cumulative_score) / (ind + 1))) 175 | sys.stdout.flush() 176 | 177 | 178 | # fullpath = os.path.join(quality_path, str(ind) + '.png') 179 | # pipeline.visualize_value_map(scores_pred, scores_targ, fullpath) 180 | 181 | 182 | 183 | 184 | 185 | 186 | # pred_policy = pipeline.get_policy(pred) 187 | # targ_policy = pipeline.get_policy(targ) 188 | 189 | 190 | 191 | # corr = 0 192 | # for pos, targ in targ_policy.iteritems(): 193 | # pred = pred_policy[pos] 194 | # if pred == targ: 195 | # # print 'count: ', pred, targ 196 | # corr += 1 197 | # # else: 198 | # # print 'not counting: ', pred, targ 199 | # print corr, ' / ', len(targ_policy) 200 | # cumulative_correct += corr 201 | 202 | avg_normed = float(cumulative_normed) / num_worlds 203 | avg_manhattan = float(manhattan) / num_worlds 204 | avg_score = float(cumulative_score) / num_worlds 205 | print 'Avg normed: {}'.format(avg_normed) 206 | print 'Avg manhattan: {}'.format(avg_manhattan) 207 | print 'Avg score: {}'.format(avg_score) 208 | 209 | if args.metric_path != None: 210 | results = {'mse': mse, 'quality': avg_normed, 'manhattan': avg_manhattan} 211 | pickle.dump(results, open(os.path.join(args.metric_path, 'results.p'), 'wb')) 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | -------------------------------------------------------------------------------- /pipeline/score_iteration.py: -------------------------------------------------------------------------------- 1 | import numpy as np, pdb 2 | 3 | class ScoreIteration: 4 | 5 | def __init__(self, mdp, values): 6 | self.refresh(mdp, values) 7 | 8 | def _softmax(self, x): 9 | return np.exp(x) / np.sum(np.exp(x), axis=0) 10 | 11 | def get_children(M, N): 12 | children = {} 13 | for i in range(M): 14 | for j in range(N): 15 | pos = (i,j) 16 | children[pos] = [] 17 | for di in range( max(i-1, 0), min(i+1, M-1)+1 ): 18 | for dj in range( max(j-1, 0), min(j+1, N-1)+1 ): 19 | child = (di, dj) 20 | if pos != child 
and (i == di or j == dj): 21 | children[pos].append( child ) 22 | return children 23 | 24 | ''' 25 | values is M x N map of predicted values 26 | ''' 27 | def get_policy(values): 28 | values = values.squeeze() 29 | M, N = values.shape 30 | states = [(i,j) for i in range(M) for j in range(N)] 31 | children = get_children( M, N ) 32 | policy = {} 33 | for state in states: 34 | reachable = children[state] 35 | selected = sorted(reachable, key = lambda x: values[x], reverse=True) 36 | policy[state] = selected[0] 37 | return policy 38 | 39 | def refresh(self, mdp, values): 40 | self.mdp = mdp 41 | self.states = mdp.getStates() 42 | self.actions = mdp.getActions() 43 | self.transition = mdp.transition 44 | self.reward = mdp.reward 45 | self.terminal = mdp.terminal 46 | self.values = values 47 | self.scores = np.zeros( self.mdp.reward_map.shape ) 48 | self.discount = 0.9 49 | 50 | 51 | def iterate(self): 52 | old_scores = self.scores.copy() 53 | old_scores.fill(-float('inf')) 54 | 55 | count = 0 56 | while np.sum(np.abs(self.scores - old_scores)) > 0.01: 57 | old_scores = self.scores.copy() 58 | for state in self.states: 59 | max_val = self.reward(state) 60 | term = self.terminal(state) 61 | 62 | if not term: 63 | neighbors = [self.transition(state, action) for action in self.actions] 64 | values = [self.values[s] for s in neighbors] 65 | transition_probs = self._softmax(values) 66 | 67 | neighbor_contributions = [transition_probs[ind] * self.scores[neighbors[ind]] \ 68 | for ind in range(len(transition_probs))] 69 | 70 | max_val = self.reward(state) + self.discount * sum(neighbor_contributions) 71 | 72 | self.scores[state] = max_val 73 | count += 1 74 | 75 | avg_score = np.mean(self.scores) 76 | return avg_score, self.scores 77 | -------------------------------------------------------------------------------- /pipeline/training.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torchvision import datasets, transforms 9 | from torch.autograd import Variable 10 | 11 | class Trainer: 12 | def __init__(self, model, lr, batch_size): 13 | self.model = model 14 | self.batch_size = batch_size 15 | self.criterion = nn.MSELoss(size_average=True).cuda() 16 | self.optimizer = optim.Adam(self.model.parameters(), lr=lr) 17 | 18 | def __epoch(self, inputs, targets, repeats = 1): 19 | self.model.train() 20 | if type(inputs) == tuple: 21 | data_size = inputs[0].size(0) 22 | else: 23 | data_size = inputs.size(0) 24 | num_batches = int(math.ceil(data_size / float(self.batch_size)) * repeats) 25 | 26 | err = 0 27 | for i in range(num_batches): 28 | inp, targ = self.__get_batch(inputs, targets) 29 | self.optimizer.zero_grad() 30 | out = self.model.forward(inp) 31 | loss = self.criterion(out, targ) 32 | loss.backward() 33 | self.optimizer.step() 34 | err += loss.data[0] 35 | err = err / float(num_batches) 36 | return err 37 | 38 | def __get_batch(self, inputs, targets): 39 | data_size = targets.size(0) 40 | 41 | inds = torch.floor(torch.rand(self.batch_size) * data_size).long().cuda() 42 | # bug: floor(rand()) sometimes gives 1 43 | inds[inds >= data_size] = data_size - 1 44 | 45 | if type(inputs) == tuple: 46 | inp = tuple([Variable( i.index_select(0, inds).cuda() ) for i in inputs]) 47 | else: 48 | inp = Variable( inputs.index_select(0, inds).cuda() ) 49 | 50 | targ = Variable( 
targets.index_select(0, inds).cuda() ) 51 | return inp, targ 52 | 53 | def train(self, inputs, targets, val_inputs, val_targets, iters = 10): 54 | t = trange(iters) 55 | for i in t: 56 | err = self.__epoch(inputs, targets) 57 | t.set_description( str(err) ) 58 | return self.model 59 | 60 | -------------------------------------------------------------------------------- /reinforcement.py: -------------------------------------------------------------------------------- 1 | #!/om/user/janner/anaconda2/envs/pytorch/bin/python 2 | 3 | import os, argparse, pickle, torch 4 | import pipeline, models, data, utils, visualization 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--save_path', type=str, default='logs/trial') 8 | parser.add_argument('--max_train', type=int, default=5000) 9 | parser.add_argument('--max_test', type=int, default=500) 10 | 11 | parser.add_argument('--mode', type=str, default='local', choices=['local', 'global']) 12 | parser.add_argument('--annotations', type=str, default='human', choices=['synthetic', 'human']) 13 | parser.add_argument('--model', type=str, default='full', choices=['full', 'no-gradient', 'cnn-lstm', 'uvfa-text']) 14 | 15 | parser.add_argument('--map_dim', type=int, default=10) 16 | parser.add_argument('--state_embed', type=int, default=1) 17 | parser.add_argument('--obj_embed', type=int, default=7) 18 | 19 | parser.add_argument('--lstm_inp', type=int, default=15) 20 | parser.add_argument('--lstm_hid', type=int, default=30) 21 | parser.add_argument('--lstm_layers', type=int, default=1) 22 | parser.add_argument('--attention_kernel', type=int, default=3) 23 | parser.add_argument('--attention_out_dim', type=int, default=1) 24 | 25 | parser.add_argument('--batch_size', type=int, default=4) 26 | parser.add_argument('--replay_size', type=int, default=100000) 27 | parser.add_argument('--learn_start', type=int, default=1000) 28 | parser.add_argument('--gamma', type=float, default=0.95) 29 | parser.add_argument('--lr', type=float, default=0.001) 30 | parser.add_argument('--epochs', type=int, default=1250) 31 | 32 | args = parser.parse_args() 33 | 34 | 35 | ################################# 36 | ############## Data ############# 37 | ################################# 38 | 39 | train_data, test_data = data.load(args.mode, args.annotations, args.max_train, args.max_test) 40 | layout_vocab_size, object_vocab_size, text_vocab_size, text_vocab = data.get_statistics(train_data, test_data) 41 | 42 | print '\n
Converting to tensors' 43 | train_layouts, train_objects, train_rewards, train_terminal, \ 44 | train_instructions, train_indices, train_values, train_goals = data.to_tensor(train_data, text_vocab) 45 | 46 | test_layouts, test_objects, test_rewards, test_terminal, \ 47 | test_instructions, test_indices, test_values, test_goals = data.to_tensor(test_data, text_vocab) 48 | 49 | print '
Training:', train_layouts.size(), 'x', train_objects.size(), 'x', train_indices.size() 50 | print '
Rewards: ', train_rewards.size(), ' Terminal: ', train_terminal.size() 51 | print '
Test :', test_layouts.size(), 'x', test_objects.size(), 'x', test_indices.size() 52 | print '
Rewards: ', test_rewards.size(), ' Terminal: ', test_terminal.size() 53 | 54 | 55 | ################################# 56 | ############ Training ########### 57 | ################################# 58 | 59 | print '\n
Initializing model: {}'.format(args.model) 60 | model = models.init(args, layout_vocab_size, object_vocab_size, text_vocab_size) 61 | target_model = models.init(args, layout_vocab_size, object_vocab_size, text_vocab_size) 62 | 63 | ## initialize agent 64 | agent = pipeline.Agent(model, target_model, map_dim = args.map_dim, instr_len = train_indices.size(1), 65 | batch_size = args.batch_size, learn_start = args.learn_start, 66 | replay_size = args.replay_size, lr = args.lr, gamma = args.gamma) 67 | 68 | train_inputs = (train_layouts, train_objects, train_indices) 69 | test_inputs = (test_layouts, test_objects, test_indices) 70 | 71 | ## train agent 72 | scores = agent.train( train_inputs, train_rewards, train_terminal, 73 | test_inputs, test_rewards, test_terminal, epochs = args.epochs ) 74 | 75 | 76 | ################################# 77 | ######## Save predictions ####### 78 | ################################# 79 | 80 | ## make logging directories 81 | pickle_path = os.path.join(args.save_path, 'pickle') 82 | utils.mkdir(args.save_path) 83 | utils.mkdir(pickle_path) 84 | 85 | print '\n
Saving model and scores to {}'.format(args.save_path) 86 | ## save model 87 | torch.save(model, os.path.join(args.save_path, 'model.pth')) 88 | ## save scores from training 89 | score_path = os.path.join(pickle_path, 'scores.p') 90 | pickle.dump(scores, open(score_path, 'wb') ) 91 | 92 | 93 | print '
Saving predictions to {}'.format(pickle_path) 94 | ## save inputs, outputs, and MDP info (rewards and terminal maps) 95 | pipeline.save_predictions(model, train_inputs, train_values, train_rewards, train_terminal, text_vocab, pickle_path, prefix='train_') 96 | pipeline.save_predictions(model, test_inputs, test_values, test_rewards, test_terminal, text_vocab, pickle_path, prefix='test_') 97 | 98 | 99 | ################################# 100 | ######### Visualization ######### 101 | ################################# 102 | 103 | vis_path = os.path.join(args.save_path, 'visualization') 104 | utils.mkdir(vis_path) 105 | 106 | print '
Saving visualizations to {}'.format(vis_path) 107 | 108 | ## save inputs, outputs, and MDP info (rewards and terminal maps) 109 | visualization.vis_predictions(model, train_inputs, train_values, train_instructions, vis_path, prefix='train_') 110 | visualization.vis_predictions(model, test_inputs, test_values, test_instructions, vis_path, prefix='test_') 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /representation.py: -------------------------------------------------------------------------------- 1 | #!/om/user/janner/anaconda2/envs/pytorch/bin/python 2 | 3 | import os, argparse, numpy as np, torch, pdb 4 | import pipeline, models, data, utils, visualization 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--save_path', type=str, default='logs/trial/') 9 | parser.add_argument('--max_train', type=int, default=5000) 10 | parser.add_argument('--max_test', type=int, default=500) 11 | 12 | parser.add_argument('--model', type=str, default='full', choices=['full', 'no-gradient', 'cnn-lstm', 'uvfa-text']) 13 | parser.add_argument('--annotations', type=str, default='human', choices=['synthetic', 'human']) 14 | parser.add_argument('--mode', type=str, default='local', choices=['local', 'global']) 15 | 16 | parser.add_argument('--map_dim', type=int, default=10) 17 | parser.add_argument('--state_embed', type=int, default=1) 18 | parser.add_argument('--obj_embed', type=int, default=7) 19 | 20 | parser.add_argument('--lstm_inp', type=int, default=15) 21 | parser.add_argument('--lstm_hid', type=int, default=30) 22 | parser.add_argument('--lstm_layers', type=int, default=1) 23 | parser.add_argument('--attention_kernel', type=int, default=3) 24 | parser.add_argument('--attention_out_dim', type=int, default=1) 25 | 26 | parser.add_argument('--batch_size', type=int, default=32) 27 | parser.add_argument('--lr', type=float, default=0.001) 28 | parser.add_argument('--iters', type=int, default=200) 29 | args = parser.parse_args() 30 | 31 | 32 | ################################# 33 | ############## Data ############# 34 | ################################# 35 | 36 | train_data, test_data = data.load(args.mode, args.annotations, args.max_train, args.max_test) 37 | layout_vocab_size, object_vocab_size, text_vocab_size, text_vocab = data.get_statistics(train_data, test_data) 38 | 39 | print '\n
Converting to tensors' 40 | train_layouts, train_objects, train_rewards, train_terminal, \ 41 | train_instructions, train_indices, train_values, train_goals = data.to_tensor(train_data, text_vocab) 42 | 43 | test_layouts, test_objects, test_rewards, test_terminal, \ 44 | test_instructions, test_indices, test_values, test_goals = data.to_tensor(test_data, text_vocab) 45 | 46 | print '
Training: (', train_layouts.size(), 'x', train_objects.size(), 'x', train_indices.size(), ') -->', train_values.size() 47 | print '
Test : (', test_layouts.size(), 'x', test_objects.size(), 'x', test_indices.size(), ') -->', test_values.size() 48 | 49 | 50 | ################################# 51 | ############ Training ########### 52 | ################################# 53 | 54 | print '\n
Initializing model: {}'.format(args.model) 55 | model = models.init(args, layout_vocab_size, object_vocab_size, text_vocab_size) 56 | 57 | train_inputs = (train_layouts, train_objects, train_indices) 58 | test_inputs = (test_layouts, test_objects, test_indices) 59 | 60 | print '
Training model' 61 | trainer = pipeline.Trainer(model, args.lr, args.batch_size) 62 | trainer.train(train_inputs, train_values, test_inputs, test_values, iters=args.iters) 63 | 64 | 65 | ################################# 66 | ######## Save predictions ####### 67 | ################################# 68 | 69 | ## make logging directories 70 | pickle_path = os.path.join(args.save_path, 'pickle') 71 | utils.mkdir(args.save_path) 72 | utils.mkdir(pickle_path) 73 | 74 | print '\n
Saving model to {}'.format(args.save_path) 75 | ## save model 76 | torch.save(model, os.path.join(args.save_path, 'model.pth')) 77 | 78 | print '
Saving predictions to {}'.format(pickle_path) 79 | ## save inputs, outputs, and MDP info (rewards and terminal maps) 80 | pipeline.save_predictions(model, test_inputs, test_values, test_rewards, test_terminal, text_vocab, pickle_path, prefix='test_') 81 | 82 | 83 | ################################# 84 | ######### Visualization ######### 85 | ################################# 86 | 87 | vis_path = os.path.join(args.save_path, 'visualization') 88 | utils.mkdir(vis_path) 89 | 90 | print '
Saving visualizations to {}'.format(vis_path) 91 | ## save images with predicted and target value maps 92 | visualization.vis_predictions(model, test_inputs, test_values, test_instructions, vis_path, prefix='test_') 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.12.1 2 | scipy>=0.19.0 3 | matplotlib==2.0.0 4 | tqdm 5 | -------------------------------------------------------------------------------- /slurm/meta-generate.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | save_path = '../data/mturk_only_global_10dim_8max_2ref/' 4 | vis_path = '../data/mturk_only_global_10dim_8max_2ref_vis/' 5 | start = 200 6 | end = 300 7 | step = 1 8 | dim = 10 9 | mode = 'global' 10 | # only_global = False 11 | 12 | for lower in range(start, end, step): 13 | command = [ 'sbatch', '--qos=tenenbaum', '-c', '2', '--time=1-12:0', '-J', str(lower), 'generate_worlds.py', \ 14 | '--lower', str(lower), '--num_worlds', str(step), '--dim', str(dim), \ 15 | '--save_path', save_path, '--vis_path', vis_path, '--mode', mode] 16 | # print command 17 | subprocess.Popen( command ) -------------------------------------------------------------------------------- /slurm/meta_reinforce.sh: -------------------------------------------------------------------------------- 1 | # lr=0.001 2 | # gamma=0.95 3 | # kernel_out=1 4 | 5 | # for batch in 4 8 16 32 6 | # do 7 | # for worlds in 5 10 20 8 | # do 9 | # save_path=logs/reinforce_${lr}lr_${gamma}gamma_${batch}batch_${kernel_out}attn_${worlds}worlds 10 | # mkdir ${save_path} 11 | # sbatch -c 2 --gres=gpu:titan-x:1 --qos=tenenbaum --time=4-12:0 --mem=40G -J ${worlds}_${batch}_${lr} -o ${save_path}/out.txt run_reinforce.py \ 12 | # --lr ${lr} --attention_out_dim ${kernel_out} --num_worlds ${worlds} --save_path ${save_path} \ 13 | # --batch_size ${batch} --gamma ${gamma} 14 | # done 15 | # done 16 | 17 | lr=0.001 18 | gamma=0.95 19 | epochs=1250 20 | 21 | for rep in 1 22 | do 23 | for batch in 4 24 | do 25 | for model in norbf #uvfa-text cnn-lstm 26 | do 27 | for mode in local global 28 | do 29 | for num_turk in 1566 30 | do 31 | save_path=rl_logs_analysis/qos_${epochs}epochs_reinforce_${model}model_${mode}mode_${num_turk}turk_${lr}lr_${gamma}gamma_${batch}batch_${rep}rep 32 | mkdir ${save_path} 33 | sbatch -c 2 --gres=gpu:titan-x:1 --qos=tenenbaum --time=4-12:0 --mem=40G -J ${rep}${model}${num_turk}_${batch}_${lr} -o ${save_path}/_out.txt turk_reinforce.py \ 34 | --model ${model} --mode ${mode} --num_turk ${num_turk} \ 35 | --lr ${lr} --save_path ${save_path} \ 36 | --batch_size ${batch} --gamma ${gamma} --epochs ${epochs} 37 | done 38 | done 39 | done 40 | done 41 | done 42 | 43 | # -o ${save_path}/_out.txt -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from rbf import * 2 | from filesystem import * -------------------------------------------------------------------------------- /utils/filesystem.py: -------------------------------------------------------------------------------- 1 | import os, subprocess 2 | 3 | def mkdir(path): 4 | if not os.path.exists(path): 5 | subprocess.call(['mkdir', path]) -------------------------------------------------------------------------------- /utils/rbf.py: 
-------------------------------------------------------------------------------- 1 | import torch, pdb 2 | 3 | def meta_rbf(size): 4 | batch = torch.zeros(size*size, size, size) 5 | count = 0 6 | for i in range(size): 7 | for j in range(size): 8 | batch[count,:,:] = rbf(size, (i,j)).clone() 9 | count += 1 10 | return batch 11 | 12 | def rbf(size, position): 13 | x,y = position 14 | grid = torch.zeros(size, size) 15 | 16 | top_left = manhattan(size, row='increasing', col='increasing') 17 | bottom_left = manhattan(size, row='increasing', col='decreasing') 18 | top_right = manhattan(size, row='decreasing', col='increasing') 19 | bottom_right = manhattan(size, row='decreasing', col='decreasing') 20 | 21 | ## top left 22 | if x > 0 and y > 0: 23 | grid[:x+1, :y+1] = bottom_right[-x-1:, -y-1:] 24 | ## bottom left 25 | if x < size and y > 0: 26 | grid[x:, :y+1] = top_right[:size-x, -y-1:] 27 | ## top right 28 | if x > 0 and y < size: 29 | grid[:x+1, y:] = bottom_left[size-x-1:, :size-y] 30 | ## bottom right 31 | if x < size and y < size: 32 | grid[x:, y:] = top_left[:size-x, :size-y] 33 | 34 | return grid 35 | 36 | def manhattan(size, row='increasing', col='increasing'): 37 | if row == 'increasing': 38 | rows = range_grid(0, size, 1, size) 39 | elif row == 'decreasing': 40 | rows = range_grid(size-1, -1, -1, size) 41 | else: 42 | raise RuntimeError('Unrecognized row in manhattan: ', row) 43 | 44 | if col == 'increasing': 45 | cols = range_grid(0, size, 1, size).t() 46 | elif col == 'decreasing': 47 | cols = range_grid(size-1, -1, -1, size).t() 48 | else: 49 | raise RuntimeError('Unrecognized col in manhattan: ', col) 50 | 51 | distance = rows + cols 52 | return distance 53 | 54 | def range_grid(low, high, step, repeat): 55 | grid = torch.arange(low, high, step).repeat(repeat, 1) 56 | return grid 57 | 58 | 59 | 60 | if __name__ == '__main__': 61 | import sys 62 | sys.path.append('../') 63 | import pipeline 64 | 65 | map_dim = 10 66 | 67 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim).numpy() 68 | col = torch.arange(0,map_dim).repeat(map_dim,1).numpy() 69 | 70 | pipeline.visualize_fig(row, 'figures/grad_row.png') 71 | pipeline.visualize_fig(col, 'figures/grad_col.png') 72 | 73 | # import scipy.misc 74 | # grid = rbf(10, (5,5)) 75 | # print grid 76 | batch = meta_rbf(10).numpy() 77 | 78 | # pdb.set_trace() 79 | 80 | for b in range(batch.shape[0]): 81 | pipeline.visualize_fig(batch[b], 'figures/' + str(b) + '.png') 82 | print batch.size 83 | # print batch[-1] 84 | pdb.set_trace() 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from place_goal import * 2 | from vis_predictions import * -------------------------------------------------------------------------------- /visualization/place_goal.py: -------------------------------------------------------------------------------- 1 | import os, pickle, argparse, subprocess, numpy as np, scipy.misc 2 | 3 | def place_goal(img, goal, save_path, num_cells = 10, width = 4): 4 | img = img.copy() 5 | dim = img.shape[0] 6 | cell_size = float(dim) / num_cells 7 | half_width = width / 2 8 | 9 | ## add black rectangular grid 10 | for i in range(0, num_cells+1): 11 | low = int(i * cell_size) - half_width 12 | high = low + half_width 13 | img[low:high, :, :] = 0 14 | img[:, low:high, :] = 0 15 | 16 | if goal != None: 17 | ## add red square around goal cell 
18 | goal_row_low = int(goal[0] * cell_size) 19 | goal_row_high = min( int(goal_row_low + cell_size), img.shape[0] - 1) 20 | goal_col_low = int(goal[1] * cell_size) 21 | goal_col_high = min( int(goal_col_low + cell_size), img.shape[1] - 1) 22 | 23 | img[max(goal_row_low-width, 0):goal_row_low+width, goal_col_low:goal_col_high+1, :] = [255,0,0] 24 | img[goal_row_high-width:goal_row_high+width, goal_col_low:goal_col_high+1, :] = [255,0,0] 25 | img[goal_row_low:goal_row_high+1, max(goal_col_low-width, 0):goal_col_low+width, :] = [255,0,0] 26 | img[goal_row_low:goal_row_high+1, goal_col_high-width:goal_col_high+width, :] = [255,0,0] 27 | 28 | scipy.misc.imsave(save_path, img) 29 | 30 | if __name__ == '__main__': 31 | parser = argparse.ArgumentParser() 32 | # parser.add_argument('--data_path', type=str, default='../data/mturk_nonunique_global_10dim_8max_2ref/') 33 | # parser.add_argument('--sprite_path', type=str, default='../data/mturk_nonunique_global_10dim_8max_2ref_vis/') 34 | parser.add_argument('--data_path', type=str, default='../data/mturk_only_global_10dim_8max_2ref/') 35 | parser.add_argument('--sprite_path', type=str, default='../data/mturk_only_global_10dim_8max_2ref_vis/') 36 | parser.add_argument('--save_path', type=str, default='global_trial/') 37 | parser.add_argument('--start', type=int, default=200) 38 | parser.add_argument('--end', type=int, default=300) 39 | parser.add_argument('--dim', type=int, default=10) 40 | args = parser.parse_args() 41 | 42 | print args, '\n' 43 | 44 | if not os.path.exists(args.save_path): 45 | subprocess.Popen(['mkdir', args.save_path]) 46 | 47 | for i in range(args.start, args.end): 48 | data_path = os.path.join(args.data_path, str(i) + '.p') 49 | data = pickle.load( open(data_path, 'rb') ) 50 | 51 | sprite_path = os.path.join(args.sprite_path, str(i) + '_sprites.png') 52 | sprites = scipy.misc.imread(sprite_path) 53 | 54 | dump_path = os.path.join(args.save_path, str(i) + '.p') 55 | pickle.dump(data, open(dump_path, 'wb')) 56 | 57 | save_path = os.path.join(args.save_path, str(i) + '_clean.png') 58 | place_goal(sprites, None, save_path, num_cells = args.dim) 59 | 60 | for g, goal in enumerate(data['goals']): 61 | print data['instructions'][g], goal 62 | save_path = os.path.join(args.save_path, str(i) + '_' + str(g) + '.png') 63 | print save_path, '\n' 64 | place_goal(sprites, goal, save_path, num_cells = args.dim) 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /visualization/run_vis.py: -------------------------------------------------------------------------------- 1 | #!/om/user/janner/anaconda2/envs/pytorch/bin/python 2 | 3 | import sys, os, subprocess, argparse, numpy as np, pickle, pdb 4 | from tqdm import tqdm 5 | from matplotlib import cm 6 | from vis_predictions import * 7 | sys.path.append('../') 8 | import environment, pipeline 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--save_path', type=str, default='../logs/example/pickle/') 12 | args = parser.parse_args() 13 | 14 | predictions = pickle.load( open(os.path.join(args.save_path, 'predictions.p'), 'rb') ).squeeze() 15 | targets = pickle.load( open(os.path.join(args.save_path, 'targets.p'), 'rb') ).squeeze() 16 | rewards = pickle.load( open(os.path.join(args.save_path, 'rewards.p'), 'rb') ).squeeze() 17 | terminal = pickle.load( open(os.path.join(args.save_path, 'terminal.p'), 'rb') ).squeeze() 18 | 19 | vis_path = os.path.join(args.save_path, 'vis') 20 | if not 
os.path.exists(vis_path): 21 | subprocess.call(['mkdir', vis_path]) 22 | 23 | num_worlds = targets.shape[0] 24 | 25 | for ind in tqdm(range(num_worlds)): 26 | pred = predictions[ind] 27 | targ = targets[ind] 28 | 29 | vmax = max(pred.max(), targ.max()) 30 | vmin = min(pred.min(), targ.min()) 31 | 32 | pred_path = os.path.join(vis_path, str(ind) + '_pred.png') 33 | targ_path = os.path.join(vis_path, str(ind) + '_targ.png') 34 | 35 | vis_fig(pred, pred_path, vmax=vmax, vmin=vmin) 36 | vis_fig(targ, targ_path, vmax=vmax, vmin=vmin) 37 | 38 | 39 | -------------------------------------------------------------------------------- /visualization/vis_predictions.py: -------------------------------------------------------------------------------- 1 | import os, math, torch, pdb 2 | from tqdm import tqdm 3 | from torch.autograd import Variable 4 | import matplotlib; matplotlib.use('Agg') 5 | from matplotlib import cm 6 | from matplotlib import pyplot as plt 7 | 8 | def vis_value_map(pred, targ, save_path, title='prediction', share=True): 9 | # print 'in vis: ', pred.shape, targ.shape 10 | dim = int(math.sqrt(pred.size)) 11 | if share: 12 | vmin = min(pred.min(), targ.min()) 13 | vmax = max(pred.max(), targ.max()) 14 | else: 15 | vmin = None 16 | vmax = None 17 | 18 | plt.clf() 19 | fig, (ax0,ax1) = plt.subplots(1,2,sharey=True) 20 | heat0 = ax0.pcolor(pred.reshape(dim,dim), vmin=vmin, vmax=vmax, cmap=cm.jet) 21 | ax0.set_title(title, fontsize=5) 22 | if not share: 23 | fig.colorbar(heat0) 24 | heat1 = ax1.pcolor(targ.reshape(dim,dim), vmin=vmin, vmax=vmax, cmap=cm.jet) 25 | ax1.invert_yaxis() 26 | ax1.set_title('target') 27 | 28 | fig.subplots_adjust(right=0.8) 29 | cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) 30 | fig.colorbar(heat1, cax=cbar_ax) 31 | 32 | # print 'saving to: ', fullpath 33 | plt.savefig(save_path, bbox_inches='tight') 34 | plt.close(fig) 35 | 36 | # print pred.shape, targ.shape 37 | 38 | def vis_fig(data, save_path, title=None, vmax=None, vmin=None, cmap=cm.jet): 39 | # print 'in vis: ', pred.shape, targ.shape 40 | dim = int(math.sqrt(data.size)) 41 | 42 | # if share: 43 | # vmin = min(pred.min(), targ.min()) 44 | # vmax = max(pred.max(), targ.max()) 45 | # else: 46 | # vmin = None 47 | # vmax = None 48 | 49 | plt.clf() 50 | # fig, (ax0,ax1) = plt.subplots(1,2,sharey=True) 51 | plt.pcolor(data.reshape(dim,dim), vmin=vmin, vmax=vmax, cmap=cmap) 52 | plt.xticks([]) 53 | plt.yticks([]) 54 | # ax0.set_title(title, fontsize=5) 55 | # if not share: 56 | # fig.colorbar(heat0) 57 | # heat1 = ax1.pcolor(targ.reshape(dim,dim), vmin=vmin, vmax=vmax, cmap=cm.jet) 58 | fig = plt.gcf() 59 | ax = plt.gca() 60 | 61 | if title: 62 | ax.set_title(title) 63 | ax.invert_yaxis() 64 | 65 | fig.set_size_inches(4,4) 66 | 67 | # ax1.invert_yaxis() 68 | # ax1.set_title('target') 69 | 70 | # fig.subplots_adjust(right=0.8) 71 | # cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) 72 | # fig.colorbar(heat1, cax=cbar_ax) 73 | 74 | # print 'saving to: ', fullpath 75 | plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0) 76 | plt.close(fig) 77 | 78 | # print pred.shape, targ.shape 79 | 80 | def vis_predictions(model, inputs, targets, instructions, save_path, prefix=''): 81 | ## wrap tensors in Variables to pass to model 82 | input_vars = ( Variable(tensor.contiguous()) for tensor in inputs ) 83 | predictions = model(input_vars) 84 | 85 | ## convert to numpy arrays for saving to disk 86 | predictions = predictions.data.cpu().numpy() 87 | targets = targets.cpu().numpy() 88 | 89 | num_inputs = 
inputs[0].size(0) 90 | for ind in tqdm(range(num_inputs)): 91 | pred = predictions[ind] 92 | targ = targets[ind] 93 | instr = instructions[ind] 94 | 95 | full_path = os.path.join(save_path, prefix + str(ind) + '.png') 96 | 97 | vis_value_map(pred, targ, full_path, title=instr, share=False) 98 | 99 | --------------------------------------------------------------------------------
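
The following is a minimal usage sketch and not part of the repository: it assumes a model has already been trained and its outputs written by `pipeline.save_predictions(..., prefix='test_')` (as done at the end of representation.py), and it treats the pickle directory `logs/example/pickle/`, the map index `0`, and the output filename as illustrative placeholders.

```python
## Hedged usage sketch: render one predicted value map next to its ground-truth
## target with visualization.vis_value_map. Assumes the repository root is on the
## Python path and that pipeline.save_predictions has already written
## 'test_predictions.p' and 'test_targets.p' to the (illustrative) directory below.
import os, pickle
from visualization import vis_value_map

pickle_path = 'logs/example/pickle/'   ## illustrative: <--save_path>/pickle from representation.py

## arrays saved by save_predictions: one value map per (world, instruction) pair
predictions = pickle.load(open(os.path.join(pickle_path, 'test_predictions.p'), 'rb')).squeeze()
targets = pickle.load(open(os.path.join(pickle_path, 'test_targets.p'), 'rb')).squeeze()

## plot the first prediction/target pair side by side on a shared color scale
vis_value_map(predictions[0], targets[0],
              os.path.join(pickle_path, 'example_0.png'),
              title='predicted values', share=True)
```

The same pattern applies to the training split by swapping the `test_` prefix for `train_`, since save_predictions writes both splits with identical file layouts.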