├── .gitattributes
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── global.p
│   ├── loading.py
│   └── local.p
├── download_data.sh
├── environment
│   ├── DefaultGenerator.py
│   ├── Generator.py
│   ├── GlobalGenerator.py
│   ├── MDP.py
│   ├── NonTextGenerator.py
│   ├── NonUniqueGenerator.py
│   ├── SpriteFigure.py
│   ├── ValueIteration.py
│   ├── __init__.py
│   ├── figure_library.py
│   ├── library.py
│   ├── reference_instructions.py
│   └── visualization.py
├── generate_worlds.py
├── logs
│   ├── .gitkeep
│   └── example
│       ├── git_sprites.png
│       └── predictions.png
├── models
│   ├── Linear_custom.py
│   ├── __init__.py
│   ├── attention_direct.py
│   ├── attention_global.py
│   ├── attention_heatmap.py
│   ├── attention_model.py
│   ├── cnn_lstm.py
│   ├── compositor_model.py
│   ├── constraint_factorizer.py
│   ├── conv_to_vector.py
│   ├── custom.py
│   ├── goal_model.py
│   ├── initializations.py
│   ├── lookup_location.py
│   ├── lookup_model.py
│   ├── map_model.py
│   ├── mlp.py
│   ├── model_factorizer.py
│   ├── multi_global.py
│   ├── multi_model.py
│   ├── multi_nobases.py
│   ├── multi_nocnn.py
│   ├── multi_nonsep.py
│   ├── multi_norbf.py
│   ├── object_model.py
│   ├── object_model_10.py
│   ├── object_model_fc.py
│   ├── object_model_pos.py
│   ├── object_model_rnn.py
│   ├── simple_conv.py
│   ├── state_model.py
│   ├── state_model_10.py
│   ├── text_model.py
│   ├── uvfa_pos.py
│   ├── uvfa_text.py
│   └── values_factorized.py
├── pipeline
│   ├── __init__.py
│   ├── agent.py
│   ├── evaluation.py
│   ├── run_eval.py
│   ├── score_iteration.py
│   └── training.py
├── psiturk-vi
│   ├── tmp.p
│   ├── turk_jun22.p
│   ├── turk_jun23_global.p
│   ├── turk_jun26.p
│   └── turk_local_merged.p
├── reinforcement.py
├── representation.py
├── requirements.txt
├── slurm
│   ├── meta-generate.py
│   └── meta_reinforce.sh
├── utils
│   ├── __init__.py
│   ├── filesystem.py
│   └── rbf.py
└── visualization
    ├── __init__.py
    ├── place_goal.py
    ├── run_vis.py
    └── vis_predictions.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.p linguist-language=Python
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Michael Janner
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spatial Reasoning
2 | Code and data to reproduce the experiments in [Representation Learning for Grounded Spatial Reasoning](https://arxiv.org/abs/1707.03938).
3 |
4 | ## Installation
5 | Install [PyTorch](http://pytorch.org/), then run `pip install -r requirements.txt`.
6 |
7 | Run `./download_data.sh` to get the annotated map data and the sprites used to render new maps.
8 |
9 | Note: a few breaking changes in the latest PyTorch release affected this repo, so make sure you are using v0.2.0+ (check with `torch.__version__`).
10 |
11 | ## Data
12 | We have collected human annotations of 3308 goals in 10x10 maps, specified by referencing one or more of 12 object types placed randomly in the world. To load up to `max_train` train maps and `max_val` val maps with `mode = [ local | global ]` instructions and `annotations = [ human | synthetic ]` descriptions, run:
13 | ```
14 | >>> import data
15 | >>> train_data, val_data = data.load(mode, annotations, max_train, max_val)
16 | >>> layouts, objects, rewards, terminal, instructions, values, goals = train_data
17 | ```
18 | where `layouts` and `objects` are arrays of item identifiers, `rewards` and `terminal` are arrays that can be used to construct an MDP, `instructions` is a list of text descriptions of the coordinates in `goals`, and `values` holds the ground-truth state values computed by value iteration.
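
For example, a quick sanity check of a loaded split (a minimal sketch; it assumes `./download_data.sh` has already been run and uses `local` / `human` as example arguments):
```
>>> import data
>>> train_data, val_data = data.load('local', 'human', 100, 10)
>>> layouts, objects, rewards, terminal, instructions, values, goals = train_data
>>> print layouts[0].shape     # one 10x10 grid of item identifiers
>>> print instructions[0]      # text description of the goal at goals[0]
```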
19 |
20 | To generate more maps (with synthetic annotations):
21 | ```
22 | $ python generate_worlds.py --mode [ local | global ] --save_path data/example_worlds/ --vis_path data/example_env/
23 | ```
24 | which will save pickle files (loadable with `data.load()`) to `data/example_worlds/` and visualizations to `data/example_env/`. Visualizations of a few of the maps downloaded to `data/local/` are in `data/local_sprites/`.
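
Each saved pickle holds one map together with all of its goal configurations. A minimal sketch for inspecting the first generated world (the path assumes the `--save_path` above):
```
>>> import pickle
>>> info = pickle.load(open('data/example_worlds/0.p'))
>>> print info.keys()               # typically 'map', 'rewards', 'terminal', 'instructions', 'goals', 'values'
>>> print len(info['instructions']) # number of goal configurations for this map
```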
25 |
26 | *Visualizations of randomly generated worlds*
27 |
33 | ## Training
34 |
35 | To train the model with reinforcement learning:
36 | ```
37 | $ python reinforcement.py --annotations [ human | synthetic ] --mode [ local | global ] --save_path logs/trial/
38 | ```
39 |
40 | This will save the model, pickle files with the predictions, and visualizations to `logs/trial/`.
41 |
42 | To train the models in a supervised manner for the representation analysis, run `python representation.py` with the same arguments as above.
43 |
44 | *Predicted values for two maps with two instructions each. In the first map, the instructions share no objects but refer to the same location. In the second map, the instructions refer to different locations.*
45 |
52 | ## Acknowledgments
53 | A huge thank you to Daniel Fried and Jacob Andreas for providing a copy of the human annotations when the original Dropbox link was purged.
54 |
55 |
56 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from loading import *
--------------------------------------------------------------------------------
/download_data.sh:
--------------------------------------------------------------------------------
1 | wget https://www.dropbox.com/sh/9at8tlmjp1ocrpg/AAAmhzZHNcJHNf5XXflkIHnha?dl=1 -O data.zip
2 | unzip data.zip -d data
3 | mv data/sprites environment/
4 | rm data.zip
5 |
--------------------------------------------------------------------------------
/environment/DefaultGenerator.py:
--------------------------------------------------------------------------------
1 | import math, random, pdb
2 | import numpy as np
3 | import library
4 | from Generator import Generator
5 |
6 | # print library.objects
7 |
8 | class DefaultGenerator(Generator):
9 |
10 | def __init__(self, objects, directions, shape = (20,20), goal_value = 3, num_steps = 50):
11 | self.objects = objects
12 | self.directions = directions
13 | self.shape = shape
14 | self.goal_value = goal_value
15 | self.num_steps = num_steps
16 |
17 | self.default_world = np.array(
18 | [ [1,1,0,0,0,0,0,0,0,1,1,1,1,1,0,1,1,1,1,1],
19 | [1,1,1,0,0,0,0,0,0,1,0,1,1,1,0,1,0,1,0,0],
20 | [1,1,1,0,0,0,0,0,1,1,0,1,1,1,1,1,0,1,0,1],
21 | [1,1,1,1,1,1,1,0,1,0,0,0,0,0,0,0,0,1,1,1],
22 | [0,0,0,0,0,0,1,1,1,0,0,0,0,0,1,0,1,1,1,1],
23 | [0,0,0,1,1,0,1,0,0,0,0,0,1,0,1,1,1,1,1,1],
24 | [0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1],
25 | [0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,1,0,0,1],
26 | [0,0,0,1,1,1,0,1,1,1,0,0,0,0,0,0,0,0,0,0],
27 | [0,0,0,0,1,1,1,1,0,1,1,1,0,0,0,0,0,0,0,0],
28 | [0,1,1,1,1,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0],
29 | [1,1,1,1,1,1,1,0,0,0,1,1,0,0,0,0,0,0,0,0],
30 | [1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
31 | [1,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
32 | [1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
33 | [1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
34 | [1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
35 | [1,0,0,1,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0],
36 | [1,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0],
37 | [1,1,1,1,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0] ]
38 | )
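## note (editor's addition): in the hand-written template above, 1 marks a traversable cell;
## the flip below (* -1 + 1) converts it to the library convention (grass = 0, puddle = 1),
## and the slice keeps only the top-left 10x10 corner of the 20x20 template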
39 | self.default_world = self.default_world * -1 + 1
40 | self.default_world = self.default_world[:10,:10]
41 | print 'world size: ', self.default_world.shape
42 |
43 | def new(self):
44 | directions = {}
45 | world = self.default_world.copy()
46 | # print world
47 |
48 | states = [(i,j) for i in range(world.shape[0]) \
49 | for j in range(world.shape[1]) \
50 | if world[i][j] != self.objects['puddle']['index'] ]
51 |
52 | used_indices = set( np.unique(world).tolist() )
53 | positions = {}
54 | for name, obj in self.objects.iteritems():
55 | # print obj['name']
56 | if not obj['background']:
57 | ind = obj['index']
58 | assert( ind not in used_indices )
59 |
60 |
61 | # if name == 'square':
62 | # pos = (2,1)
63 | # else:
64 | pos = random.choice(states) #self.randomPosition()
65 | while pos in positions.values():
66 | pos = random.choice(states)
67 |
68 | world[pos] = ind
69 | used_indices.add(ind)
70 |
71 | # print name
72 | positions[name] = pos
73 | object_specific_dir = self.generateDirections(world, pos, name)
74 | # print object_specific_dir
75 | directions[name] = object_specific_dir
76 | # print directions
77 | # pdb.set_trace()
78 |
79 | reward_maps, terminal_maps, instructions, goals = \
80 | self.addRewards(world, positions, directions)
81 |
82 | info = {
83 | 'map': world,
84 | 'rewards': reward_maps,
85 | 'terminal': terminal_maps,
86 | 'instructions': instructions,
87 | 'goals': goals
88 | }
89 | return info
90 |
--------------------------------------------------------------------------------
/environment/Generator.py:
--------------------------------------------------------------------------------
1 | import math, random, pdb
2 | import numpy as np
3 | import library
4 |
5 | # print library.objects
6 |
7 | class Generator:
8 |
9 | def __init__(self, objects, directions, shape = (20,20), goal_value = 3, num_steps = 50):
10 | self.objects = objects
11 | self.directions = directions
12 | self.shape = shape
13 | self.goal_value = goal_value
14 | self.num_steps = num_steps
15 |
16 | def new(self):
17 | directions = {}
18 | world = self.__puddles(self.num_steps)
19 | # print world
20 |
21 | states = [(i,j) for i in range(world.shape[0]) \
22 | for j in range(world.shape[1]) \
23 | if world[i][j] != self.objects['puddle']['index'] ]
24 |
25 | used_indices = set( np.unique(world).tolist() )
26 | positions = {}
27 | for name, obj in self.objects.iteritems():
28 | # print obj['name']
29 | if not obj['background']:
30 | ind = obj['index']
31 | assert( ind not in used_indices )
32 |
33 |
34 | pos = random.choice(states) #self.randomPosition()
35 | while pos in positions.values():
36 | pos = random.choice(states)
37 |
38 | world[pos] = ind
39 | used_indices.add(ind)
40 |
41 | # print name
42 | positions[name] = pos
43 | object_specific_dir = self.generateDirections(world, pos, name)
44 | # print object_specific_dir
45 | directions[name] = object_specific_dir
46 | # print directions
47 | # pdb.set_trace()
48 |
49 | reward_maps, terminal_maps, instructions, goals = \
50 | self.addRewards(world, positions, directions)
51 |
52 | info = {
53 | 'map': world,
54 | 'rewards': reward_maps,
55 | 'terminal': terminal_maps,
56 | 'instructions': instructions,
57 | 'goals': goals
58 | }
59 | return info
60 |
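## editor's note: __puddles below carves the map by starting from an all-puddle grid and
## random-walking a grass "brush" of random width and length for `iters` turns, so the
## traversable area forms a connected region and every untouched cell stays puddle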
61 | def __puddles(self, iters, max_width=3, max_steps=8):
62 | (M,N) = self.shape
63 | turns = ['up', 'down', 'left', 'right']
64 | world = np.zeros( self.shape )
65 | world.fill( self.objects['puddle']['index'] )
66 | # position = np.floor(np.random.uniform(size=2)*self.shape[0]).astype(int)
67 |
68 | position = (np.random.uniform()*M, np.random.uniform()*N)
69 | position = map(int, position)
70 |
71 | for i in range(iters):
72 | direction = random.choice(turns)
73 | width = int(np.random.uniform(low=1, high=max_width))
74 | steps = int(np.random.uniform(low=1, high=max_steps))
75 | if direction == 'up':
76 | top = max(position[0] - steps, 0)
77 | bottom = position[0]
78 | left = max(position[1] - int(math.floor(width/2.)), 0)
79 | right = min(position[1] + int(math.ceil(width/2.)), N)
80 | position[0] = top
81 | elif direction == 'down':
82 | top = position[0]
83 | bottom = min(position[0] + steps, M)
84 | left = max(position[1] - int(math.floor(width/2.)), 0)
85 | right = min(position[1] + int(math.ceil(width/2.)), N)
86 | position[0] = bottom
87 | elif direction == 'left':
88 | top = max(position[0] - int(math.floor(width/2.)), 0)
89 | bottom = min(position[0] + int(math.ceil(width/2.)), M)
90 | left = max(position[1] - steps, 0)
91 | right = position[1]
92 | position[1] = left
93 | elif direction == 'right':
94 | top = max(position[0] - int(math.floor(width/2.)), 0)
95 | bottom = min(position[0] + int(math.ceil(width/2.)), M)
96 | left = position[1]
97 | right = min(position[1] + steps, N)
98 | position[1] = right
99 | # print top, bottom, left, right, self.objects['grass']['index']
100 | # print world.shape
101 | world[top:bottom+1, left:right+1] = self.objects['grass']['index']
102 |
103 | return world
104 |
105 | def addRewards(self, world, positions, directions):
106 | reward_maps = []
107 | terminal_maps = []
108 | instruction_set = []
109 | goal_positions = []
110 |
111 | object_values = np.zeros( self.shape )
112 | ## add non-background values
113 | for name, obj in self.objects.iteritems():
114 | value = obj['value']
115 | if not obj['background']:
116 | pos = positions[name]
117 | object_values[pos] = value
118 | else:
119 | mask = np.ma.masked_equal(world, obj['index']).mask
120 | # print name, obj['index']
121 | # print mask
122 | # print value
123 | object_values[mask] += value
124 | # for st
125 | # value = obj['v']
126 | # print 'values: '
127 | # print object_values
128 |
129 | for name, obj in self.objects.iteritems():
130 | if not obj['background']:
131 | # ind = obj['index']
132 | # pos = positions[name]
133 | for (phrase, target_pos) in directions[name]:
134 | rewards = object_values.copy()
135 | rewards[target_pos] += self.goal_value
136 | terminal = np.zeros( self.shape )
137 | terminal[target_pos] = 1
138 |
139 | reward_maps.append(rewards)
140 | terminal_maps.append(terminal)
141 | instruction_set.append(phrase)
142 | goal_positions.append(target_pos)
143 |
144 | # print name, pos, direct
145 | # print reward_maps
146 | # print instruction_set
147 | # print goal_positions
148 | return reward_maps, terminal_maps, instruction_set, goal_positions
149 |
150 | # def randomPosition(self):
151 | # pos = tuple( (np.random.uniform(size=2) * self.dim).astype(int) )
152 | # return pos
153 |
154 | def generateDirections(self, world, pos, name):
155 | directions = []
156 | for identifier, offset in self.directions.iteritems():
157 | phrase = 'reach cell ' + identifier + ' ' + name
158 | absolute_pos = tuple(map(sum, zip(pos, offset)))
159 | # print absolute_pos
160 | (i,j) = absolute_pos
161 | (M,N) = self.shape
162 | out_of_bounds = (i < 0 or i >= M) or (j < 0 or j >= N)
163 |
164 | # absolute_pos[0] < 0 or absolute_pos[0] > [coord < 0 or coord >= self.dim for coord in absolute_pos]
165 | if not out_of_bounds:
166 | in_puddle = world[absolute_pos] == self.objects['puddle']['index']
167 | if not in_puddle:
168 | directions.append( (phrase, absolute_pos) )
169 | # print directions
170 | return directions
171 |
172 | # for i in [-1, 0, 1]:
173 | # for j in [-1, 0, 1]:
174 | # if i == 0 and j == 0:
175 | # pass
176 | # else:
177 |
178 |
179 | if __name__ == '__main__':
180 | gen = Generator(library.objects, library.directions)
181 | info = gen.new()
182 |
183 | print info['map']
184 | print info['instructions'], len(info['instructions'])
185 | print len(info['rewards']), len(info['terminal'])
186 | # print info['terminal']
187 |
188 |
189 |
190 |
--------------------------------------------------------------------------------
/environment/MDP.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class MDP:
4 |
5 | def __init__(self, world, rewards, terminal):
6 | self.world = world
7 | self.reward_map = rewards
8 | self.terminal_map = terminal
9 | self.shape = self.reward_map.shape
10 |
11 | self.M, self.N = self.shape
12 | self.states = [(i,j) for i in range(self.M) for j in range(self.N)]
13 | self.children = self.get_children( self.M, self.N )
14 |
15 | self.actions = [(-1,0),(1,0),(0,-1),(0,1)]
16 | self.states = [(i,j) for i in range(self.shape[0]) for j in range(self.shape[1])]
17 |
18 | def getActions(self):
19 | return [i for i in range(len(self.actions))]
20 |
21 | def getStates(self):
22 | return self.states
23 |
24 | def transition(self, position, action_ind, fullstate=False):
25 | action = self.actions[action_ind]
26 | # print 'transitioning: ', action, position
27 | candidate = tuple(map(sum, zip(position, action)))
28 |
29 | ## if new location is valid,
30 | ## update the position
31 | if self.valid(candidate):
32 | position = candidate
33 |
34 | if fullstate:
35 | state = self.observe(position)
36 | else:
37 | state = position
38 |
39 | return state
40 |
41 | def valid(self, position):
42 | x, y = position[0], position[1]
43 | if x >= 0 and x < self.shape[0] and y >= 0 and y < self.shape[1]:
44 | return True
45 | else:
46 | return False
47 |
48 | def reward(self, position):
49 | rew = self.reward_map[position]
50 | return rew
51 |
52 | def terminal(self, position):
53 | term = self.terminal_map[position]
54 | return term
55 |
56 | def representValues(self, values):
57 | value_map = np.zeros( self.shape )
58 | for pos, val in values.iteritems():
59 | assert(value_map[pos] == 0)
60 | value_map[pos] = val
61 | return value_map
62 |
63 | # '''
64 | # start_pos is (i,j)
65 | # policy is dict from (i,j) --> (delta_i, delta_j)
66 | # '''
67 | # def simulate(self, policy, start_pos, num_steps = 100):
68 | # pos = start_pos
69 | # visited = set([pos])
70 | # for step in range(num_steps):
71 | # # rew = self.reward(pos)
72 | # term = self.terminal(pos)
73 | # if term:
74 | # return 0
75 | # reachable = policy[pos]
76 | # selected = 0
77 | # while selected < len(reachable) and reachable[selected] in visited:
78 | # # print ' visited ', selected, reachable[selected]
79 | # selected += 1
80 | # if selected == len(reachable):
81 | # selected = 0
82 | # pos = policy[pos][selected]
83 | # visited.add(pos)
84 | # print 'position: ', pos
85 | # # print pos, goal
86 | # goal = np.argwhere( self.terminal_map ).flatten().tolist()
87 | # manhattan_dist = abs(pos[0] - goal[0]) + abs(pos[1] - goal[1])
88 | # return manhattan_dist
89 |
90 | '''
91 | start_pos is (i,j)
92 | policy is dict from (i,j) --> (delta_i, delta_j)
93 | '''
94 | def simulate(self, policy, start_pos, num_steps = 100):
95 | pos = start_pos
96 | visited = set([pos])
97 | for step in range(num_steps):
98 | # rew = self.reward(pos)
99 | term = self.terminal(pos)
100 | if term:
101 | return step
102 | reachable = policy[pos]
103 | selected = 0
104 | while selected < len(reachable) and reachable[selected] in visited:
105 | # print ' visited ', selected, reachable[selected]
106 | selected += 1
107 | if selected == len(reachable):
108 | selected = 0
109 | pos = policy[pos][selected]
110 | visited.add(pos)
111 | # print 'position: ', pos
112 | return step
113 |
114 | def get_children(self, M, N):
115 | children = {}
116 | for i in range(M):
117 | for j in range(N):
118 | pos = (i,j)
119 | children[pos] = []
120 | for di in range( max(i-1, 0), min(i+1, M-1)+1 ):
121 | for dj in range( max(j-1, 0), min(j+1, N-1)+1 ):
122 | child = (di, dj)
123 | if pos != child and (i == di or j == dj):
124 | children[pos].append( child )
125 | return children
126 |
127 |
128 | '''
129 | values is M x N map of predicted values
130 | '''
131 | def get_policy(self, values):
132 | policy = {}
133 | for state in self.states:
134 | reachable = self.children[state]
135 | selected = sorted(reachable, key = lambda x: values[x], reverse=True)
136 | policy[state] = selected
137 | return policy
138 |
139 |
140 |
141 | if __name__ == '__main__':
142 | import pickle, pdb, numpy as np
143 | info = pickle.load( open('../data/train_10/0.p') )
144 | print info.keys()
145 | mdp = MDP(info['map'], info['rewards'][0], info['terminal'][0])
146 | print mdp.world
147 | print mdp.children
148 |
149 | values = np.random.randn(10,10)
150 | policy = mdp.get_policy(values)
151 |
152 | steps = mdp.simulate(policy, (0,0))
153 | print steps
154 |
155 | pdb.set_trace()
156 |
157 |
158 |
--------------------------------------------------------------------------------
/environment/NonTextGenerator.py:
--------------------------------------------------------------------------------
1 | import math, random, pdb
2 | import numpy as np
3 | import library
4 | from Generator import Generator
5 |
6 | # print library.objects
7 |
8 | class NonTextGenerator(Generator):
9 |
10 | def __init__(self, objects, shape = (20,20), goal_value = 3, num_steps = 50):
11 | self.objects = objects
12 | self.shape = shape
13 | self.goal_value = goal_value
14 | self.num_steps = num_steps
15 |
16 | def new(self):
17 | directions = {}
18 | world = self.__puddles(self.num_steps)
19 |
20 | # states = [(i,j) for i in range(world.shape[0]) \
21 | # for j in range(world.shape[1]) \
22 | # if world[i][j] != self.objects['puddle']['index'] ]
23 |
24 | ## we are using ALL of the locations as goals,
25 | ## even those in puddles
26 | states = [ (i,j) for i in range(world.shape[0]) \
27 | for j in range(world.shape[1]) ]
28 |
29 | reward_maps = []
30 | terminal_maps = []
31 | goals = []
32 | for state in states:
33 | reward = np.zeros( world.shape )
34 | terminal = np.zeros( world.shape )
35 | reward[state] = self.goal_value
36 | terminal[state] = 1
37 | reward_maps.append(reward)
38 | terminal_maps.append(terminal)
39 | goals.append(state)
40 |
41 | info = {
42 | 'map': world,
43 | 'rewards': reward_maps,
44 | 'terminal': terminal_maps,
45 | 'goals': goals
46 | }
47 |
48 | # pdb.set_trace()
49 |
50 | return info
51 |
52 | def __puddles(self, iters, max_width=3, max_steps=10):
53 | (M,N) = self.shape
54 | turns = ['up', 'down', 'left', 'right']
55 | world = np.zeros( self.shape )
56 | world.fill( self.objects['puddle']['index'] )
57 | # position = np.floor(np.random.uniform(size=2)*self.shape[0]).astype(int)
58 |
59 | position = (np.random.uniform()*M, np.random.uniform()*N)
60 | position = map(int, position)
61 |
62 | for i in range(iters):
63 | direction = random.choice(turns)
64 | width = int(np.random.uniform(low=1, high=max_width))
65 | steps = int(np.random.uniform(low=1, high=max_steps))
66 | if direction == 'up':
67 | top = max(position[0] - steps, 0)
68 | bottom = position[0]
69 | left = max(position[1] - int(math.floor(width/2.)), 0)
70 | right = min(position[1] + int(math.ceil(width/2.)), N)
71 | position[0] = top
72 | elif direction == 'down':
73 | top = position[0]
74 | bottom = min(position[0] + steps, M)
75 | left = max(position[1] - int(math.floor(width/2.)), 0)
76 | right = min(position[1] + int(math.ceil(width/2.)), N)
77 | position[0] = bottom
78 | elif direction == 'left':
79 | top = max(position[0] - int(math.floor(width/2.)), 0)
80 | bottom = min(position[0] + int(math.ceil(width/2.)), M)
81 | left = max(position[1] - steps, 0)
82 | right = position[1]
83 | position[1] = left
84 | elif direction == 'right':
85 | top = max(position[0] - int(math.floor(width/2.)), 0)
86 | bottom = min(position[0] + int(math.ceil(width/2.)), M)
87 | left = position[1]
88 | right = min(position[1] + steps, N)
89 | position[1] = right
90 | # print top, bottom, left, right, self.objects['grass']['index']
91 | # print world.shape
92 | world[top:bottom+1, left:right+1] = self.objects['grass']['index']
93 |
94 | return world
95 |
96 | # def addRewards(self, world, positions, directions):
97 | # reward_maps = []
98 | # terminal_maps = []
99 | # instruction_set = []
100 | # goal_positions = []
101 |
102 | # object_values = np.zeros( self.shape )
103 | # ## add non-background values
104 | # for name, obj in self.objects.iteritems():
105 | # value = obj['value']
106 | # if not obj['background']:
107 | # pos = positions[name]
108 | # object_values[pos] = value
109 | # else:
110 | # mask = np.ma.masked_equal(world, obj['index']).mask
111 | # # print name, obj['index']
112 | # # print mask
113 | # # print value
114 | # object_values[mask] += value
115 | # # for st
116 | # # value = obj['v']
117 | # # print 'values: '
118 | # # print object_values
119 |
120 | # for name, obj in self.objects.iteritems():
121 | # if not obj['background']:
122 | # # ind = obj['index']
123 | # # pos = positions[name]
124 | # for (phrase, target_pos) in directions[name]:
125 | # rewards = object_values.copy()
126 | # rewards[target_pos] += self.goal_value
127 | # terminal = np.zeros( self.shape )
128 | # terminal[target_pos] = 1
129 |
130 | # reward_maps.append(rewards)
131 | # terminal_maps.append(terminal)
132 | # instruction_set.append(phrase)
133 | # goal_positions.append(target_pos)
134 |
135 | # # print name, pos, direct
136 | # # print reward_maps
137 | # # print instruction_set
138 | # # print goal_positions
139 | # return reward_maps, terminal_maps, instruction_set, goal_positions
140 |
141 | # # def randomPosition(self):
142 | # # pos = tuple( (np.random.uniform(size=2) * self.dim).astype(int) )
143 | # # return pos
144 |
145 | # def generateDirections(self, world, pos, name):
146 | # directions = []
147 | # for identifier, offset in self.directions.iteritems():
148 | # phrase = 'reach cell ' + identifier + ' ' + name
149 | # absolute_pos = tuple(map(sum, zip(pos, offset)))
150 | # # print absolute_pos
151 | # (i,j) = absolute_pos
152 | # (M,N) = self.shape
153 | # out_of_bounds = (i < 0 or i >= M) or (j < 0 or j >= N)
154 |
155 | # # absolute_pos[0] < 0 or absolute_pos[0] > [coord < 0 or coord >= self.dim for coord in absolute_pos]
156 | # if not out_of_bounds:
157 | # in_puddle = world[absolute_pos] == self.objects['puddle']['index']
158 | # if not in_puddle:
159 | # directions.append( (phrase, absolute_pos) )
160 | # # print directions
161 | # return directions
162 |
163 | # # for i in [-1, 0, 1]:
164 | # # for j in [-1, 0, 1]:
165 | # # if i == 0 and j == 0:
166 | # # pass
167 | # # else:
168 |
169 |
170 | if __name__ == '__main__':
171 | gen = NonTextGenerator(library.objects)
172 | info = gen.new()
173 |
174 | print info['map']
175 | print len(info['goals']), 'goal locations'
176 | print len(info['rewards']), len(info['terminal'])
177 | # print info['terminal']
178 |
179 |
180 |
181 |
--------------------------------------------------------------------------------
/environment/SpriteFigure.py:
--------------------------------------------------------------------------------
1 | import os, numpy as np, scipy.misc, pdb
2 |
3 | class SpriteFigure:
4 |
5 | def __init__(self, objects, background, dim = 20):
6 | self.objects = objects
7 | self.dim = dim
8 | background = self.loadImg(background)
9 | self.sprites = {0: background}
10 | for name, obj in objects.iteritems():
11 | ind = obj['index']
12 | # print name
13 | sprite = self.loadImg( obj['sprite'] )
14 | if not obj['background']:
15 | # overlay = background.copy()
16 | # masked = np.ma.masked_greater( sprite[:,:,-1], 0 ).mask
17 | # overlay[masked] = sprite[:,:,:-1][masked]
18 | overlay = sprite
19 | else:
20 | overlay = sprite
21 |
22 | self.sprites[ind] = overlay
23 |
24 |
25 | def loadImg(self, path, dim = None):
26 | if dim == None:
27 | dim = self.dim
28 | path = os.path.join('environment', path)
29 | img = scipy.misc.imread(path)
30 | img = scipy.misc.imresize(img, (dim, dim) )
31 | return img
32 |
33 | def makeGrid(self, world, filename, boundary_width = 4):
34 | shape = world.shape
35 |
36 | grass_ind = self.objects['grass']['index']
37 | puddle_ind = self.objects['puddle']['index']
38 |
39 | state = self.loadImg( self.objects['grass']['sprite'], dim = shape[0] * self.dim )
40 | puddle = self.loadImg( self.objects['puddle']['sprite'], dim = shape[0] * self.dim )
41 |
42 | for i in range(shape[0]):
43 | for j in range(shape[1]):
44 |
45 | row_low = i*self.dim
46 | row_high = (i+1)*self.dim
47 | col_low = j*self.dim
48 | col_high = (j+1)*self.dim
49 |
50 | ind = int(world[i,j])
51 | sprite = self.sprites[ind]
52 |
53 | if ind == grass_ind:
54 | continue
55 | elif ind == puddle_ind:
56 | state[row_low:row_high, col_low:col_high, :] = puddle[row_low:row_high, col_low:col_high, :]
57 | ## background
58 | else:
59 | masked = np.ma.masked_greater( sprite[:,:,-1], 0 ).mask
60 | state[row_low:row_high, col_low:col_high, :][masked] = sprite[:,:,:-1][masked]
61 | # overlay[masked] = sprite[:,:,:-1][masked]
62 | # sprite = self.sprites[world[i,j].astype('int')]
63 | # state[i*self.dim:(i+1)*self.dim, j*self.dim:(j+1)*self.dim, :] = sprite
64 |
65 | for i in range(shape[0]):
66 | for j in range(shape[1]):
67 |
68 | row_low = i*self.dim
69 | row_high = (i+1)*self.dim
70 | col_low = j*self.dim
71 | col_high = (j+1)*self.dim
72 |
73 | ind = int(world[i,j])
74 |
75 | if i < shape[0] - 1:
76 | below = int(world[i+1,j])
77 | if (ind != puddle_ind and below == puddle_ind) or (ind == puddle_ind and below != puddle_ind):
78 | # print 'BELOW: ', i, j
79 | state[row_high-boundary_width:row_high+boundary_width, col_low:col_high, :] = 0.
80 |
81 | if j < shape[1] - 1:
82 | right = int(world[i,j+1])
83 | if (ind != puddle_ind and right == puddle_ind) or (ind == puddle_ind and right != puddle_ind):
84 | # print 'BELOW: ', i, j
85 | state[row_low:row_high, col_high-boundary_width:col_high+boundary_width, :] = 0.
86 |
87 |
88 | scipy.misc.imsave(filename + '.png', state)
89 | return state
90 |
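
## Usage sketch (editor's addition, mirroring generate_worlds.py): render a map of
## object indices to a sprite PNG. Assumes the sprites were fetched by ./download_data.sh
## and that this is run from the repository root (loadImg prepends 'environment/').
if __name__ == '__main__':
    import figure_library

    world = np.zeros((10, 10))                                 ## all grass
    world[2, 3] = figure_library.objects['star']['index']
    world[7, 7] = figure_library.objects['puddle']['index']

    fig = SpriteFigure(figure_library.objects, figure_library.background, dim=100)
    fig.makeGrid(world, 'example_sprites')                     ## writes example_sprites.png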
--------------------------------------------------------------------------------
/environment/ValueIteration.py:
--------------------------------------------------------------------------------
1 | class ValueIteration:
2 |
3 | def __init__(self, mdp):
4 | self.refresh(mdp)
5 |
6 | def refresh(self, mdp):
7 | self.mdp = mdp
8 | self.states = mdp.getStates()
9 | self.actions = mdp.getActions()
10 | self.transition = mdp.transition
11 | self.reward = mdp.reward
12 | self.terminal = mdp.terminal
13 | self.values = {state: 0 for state in self.states}
14 | self.policy = {state: None for state in self.states}
15 | self.discount = 0.9
16 |
17 |
18 | def iterate(self):
19 | for k in range(0, 1000):
20 | for state in self.states:
21 | max_val = -float('inf')
22 | term = self.terminal(state)
23 | if term:
24 | max_val = self.reward(state)
25 | else:
26 | for action in self.actions:
27 | new_state = self.transition(state, action)
28 | new_val = self.reward(state) + self.discount * self.values[new_state]
29 | if new_val > max_val:
30 | max_val = new_val
31 | self.policy[state] = action
32 | self.values[state] = max_val
33 | return self.values, self.policy
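

## Usage sketch (editor's addition): solve a small toy MDP and recover the ground-truth
## value map, the way generate_worlds.py does for every goal configuration.
## The 5x5 world and goal below are made up for illustration.
if __name__ == '__main__':
    import numpy as np
    from MDP import MDP

    world = np.zeros((5, 5))                        ## all grass
    rewards = np.zeros((5, 5)); rewards[4, 4] = 3
    terminal = np.zeros((5, 5)); terminal[4, 4] = 1

    mdp = MDP(world, rewards, terminal)
    values, policy = ValueIteration(mdp).iterate()
    value_map = mdp.representValues(values)
    print value_map                                 ## 5x5 array of state values

    ## greedy rollout against the (here ground-truth) value map
    greedy = mdp.get_policy(value_map)
    print mdp.simulate(greedy, (0, 0)), 'steps to the goal'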
--------------------------------------------------------------------------------
/environment/__init__.py:
--------------------------------------------------------------------------------
1 | import library, figure_library
2 | from MDP import *
3 | from ValueIteration import ValueIteration
4 | from SpriteFigure import SpriteFigure
--------------------------------------------------------------------------------
/environment/figure_library.py:
--------------------------------------------------------------------------------
1 | spritepath = 'sprites/'
2 |
3 | objects = {
4 | 'grass': {
5 | 'index': 0,
6 | 'value': 0,
7 | 'sprite': 'sprites/grass_figure_4.png', # 'sprites/white.png',
8 | 'background': True,
9 | 'unique': False,
10 | },
11 | 'puddle': {
12 | 'index': 1,
13 | 'value': -1,
14 | 'sprite': 'sprites/water_figure_2.png',
15 | 'background': True,
16 | 'unique': False,
17 | },
18 | ## unique
19 | 'star': {
20 | 'index': 2,
21 | 'value': 0,
22 | 'sprite': 'sprites/star_figure-01.png', ## white_alpha.png
23 | 'background': False,
24 | 'unique': True,
25 | },
26 | 'circle': {
27 | 'index': 3,
28 | 'value': 0,
29 | 'sprite': 'sprites/circle_figure-01.png',
30 | 'background': False,
31 | 'unique': True,
32 | },
33 | 'triangle': {
34 | 'index': 4,
35 | 'value': 0,
36 | 'sprite': 'sprites/triangle_figure-01.png',
37 | 'background': False,
38 | 'unique': True,
39 | },
40 | 'heart': {
41 | 'index': 5,
42 | 'value': 0,
43 | 'sprite': 'sprites/heart_figure-01.png',
44 | 'background': False,
45 | 'unique': True,
46 | },
47 | 'spade': {
48 | 'index': 6,
49 | 'value': 0,
50 | 'sprite': 'sprites/spade_figure-01.png',
51 | 'background': False,
52 | 'unique': True,
53 | },
54 | 'diamond': {
55 | 'index': 7,
56 | 'value': 0,
57 | 'sprite': 'sprites/diamond_figure-01.png',
58 | 'background': False,
59 | 'unique': True,
60 | },
61 | ## non-unique
62 | 'rock': {
63 | 'index': 8,
64 | 'value': 0,
65 | 'sprite': 'sprites/rock_figure-01.png',
66 | 'background': False,
67 | 'unique': False,
68 | },
69 | 'tree': {
70 | 'index': 9,
71 | 'value': 0,
72 | 'sprite': 'sprites/tree_figure-01.png',
73 | 'background': False,
74 | 'unique': False,
75 | },
76 | 'house': {
77 | 'index': 10,
78 | 'value': 0,
79 | 'sprite': 'sprites/house_figure-01.png',
80 | 'background': False,
81 | 'unique': False,
82 | },
83 | 'horse': {
84 | 'index': 11,
85 | 'value': 0,
86 | 'sprite': 'sprites/horse_figure-01.png',
87 | 'background': False,
88 | 'unique': False,
89 | },
90 | }
91 |
92 | unique_instructions = {
93 | ## original
94 | 'to top left of': (-1, -1),
95 | 'on top of': (-1, 0),
96 | 'to top right of': (-1, 1),
97 | 'to left of': (0, -1),
98 | 'with': (0, 0),
99 | 'to right of': (0, 1),
100 | 'to bottom left of': (1, -1),
101 | 'on bottom of': (1, 0),
102 | 'to bottom right of': (1, 1),
103 | ## two steps away
104 | 'two to the left and two above': (-2, -2),
105 | 'one to the left and two above': (-2, -1),
106 | 'two above': (-2, 0),
107 | 'one to the right and two above': (-2, 1),
108 | 'two to the right and two above': (-2, 2),
109 | 'two to the right and one above': (-1, 2),
110 | 'two to the right of': (0, 2),
111 | 'two to the right and one below': (1, 2),
112 | 'two to the right and two below': (2, 2),
113 | 'one to the right and two below': (2, 1),
114 | 'two below': (2, 0),
115 | 'one to the left and two below': (2, -1),
116 | 'two to the left and two below': (2, -2),
117 | 'two to the left and one below': (1, -2),
118 | 'two to the left': (0, -2),
119 | 'two to the left and one above': (-1, -2)
120 |
121 | }
122 |
123 | background = 'sprites/grass_figure_4.png'
124 |
125 | # print objects
126 |
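## Editor's sketch: the offsets in unique_instructions are (row, col) deltas applied to
## an object's position, so negative rows are above and negative columns are to the left.
## Generator.generateDirections resolves a phrase to a goal cell like this:
if __name__ == '__main__':
    pos = (4, 7)                                   ## example object location
    offset = unique_instructions['on top of']      ## (-1, 0)
    goal = tuple(map(sum, zip(pos, offset)))
    print goal                                     ## (3, 7), the cell directly above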
--------------------------------------------------------------------------------
/environment/library.py:
--------------------------------------------------------------------------------
1 | spritepath = 'sprites/'
2 |
3 | objects = {
4 | 'grass': {
5 | 'index': 0,
6 | 'value': 0,
7 | 'sprite': 'environment/sprites/white.png', # 'sprites/white.png',
8 | 'background': True,
9 | 'unique': False,
10 | },
11 | 'puddle': {
12 | 'index': 1,
13 | 'value': -1,
14 | 'sprite': 'environment/sprites/white.png',
15 | 'background': True,
16 | 'unique': False,
17 | },
18 | ## unique
19 | 'star': {
20 | 'index': 2,
21 | 'value': 0,
22 | 'sprite': 'environment/sprites/white_alpha.png', ## white_alpha.png
23 | 'background': False,
24 | 'unique': True,
25 | },
26 | 'circle': {
27 | 'index': 3,
28 | 'value': 0,
29 | 'sprite': 'environment/sprites/white_alpha.png',
30 | 'background': False,
31 | 'unique': True,
32 | },
33 | 'triangle': {
34 | 'index': 4,
35 | 'value': 0,
36 | 'sprite': 'environment/sprites/white_alpha.png',
37 | 'background': False,
38 | 'unique': True,
39 | },
40 | 'heart': {
41 | 'index': 5,
42 | 'value': 0,
43 | 'sprite': 'environment/sprites/white_alpha.png',
44 | 'background': False,
45 | 'unique': True,
46 | },
47 | 'spade': {
48 | 'index': 6,
49 | 'value': 0,
50 | 'sprite': 'environment/sprites/white_alpha.png',
51 | 'background': False,
52 | 'unique': True,
53 | },
54 | 'diamond': {
55 | 'index': 7,
56 | 'value': 0,
57 | 'sprite': 'environment/sprites/white_alpha.png',
58 | 'background': False,
59 | 'unique': True,
60 | },
61 | ## non-unique
62 | 'rock': {
63 | 'index': 8,
64 | 'value': 0,
65 | 'sprite': 'environment/sprites/rock_hires.png',
66 | 'background': False,
67 | 'unique': False,
68 | },
69 | 'tree': {
70 | 'index': 9,
71 | 'value': 0,
72 | 'sprite': 'environment/sprites/tree_hires.png',
73 | 'background': False,
74 | 'unique': False,
75 | },
76 | 'house': {
77 | 'index': 10,
78 | 'value': 0,
79 | 'sprite': 'environment/sprites/house_hires.png',
80 | 'background': False,
81 | 'unique': False,
82 | },
83 | 'horse': {
84 | 'index': 11,
85 | 'value': 0,
86 | 'sprite': 'environment/sprites/horse_hires.png',
87 | 'background': False,
88 | 'unique': False,
89 | },
90 | }
91 |
92 | unique_instructions = {
93 | ## original
94 | 'to top left of': (-1, -1),
95 | 'on top of': (-1, 0),
96 | 'to top right of': (-1, 1),
97 | 'to left of': (0, -1),
98 | 'with': (0, 0),
99 | 'to right of': (0, 1),
100 | 'to bottom left of': (1, -1),
101 | 'on bottom of': (1, 0),
102 | 'to bottom right of': (1, 1),
103 | ## two steps away
104 | 'two to the left and two above': (-2, -2),
105 | 'one to the left and two above': (-2, -1),
106 | 'two above': (-2, 0),
107 | 'one to the right and two above': (-2, 1),
108 | 'two to the right and two above': (-2, 2),
109 | 'two to the right and one above': (-1, 2),
110 | 'two to the right of': (0, 2),
111 | 'two to the right and one below': (1, 2),
112 | 'two to the right and two below': (2, 2),
113 | 'one to the right and two below': (2, 1),
114 | 'two below': (2, 0),
115 | 'one to the left and two below': (2, -1),
116 | 'two to the left and two below': (2, -2),
117 | 'two to the left and one below': (1, -2),
118 | 'two to the left': (0, -2),
119 | 'two to the left and one above': (-1, -2)
120 |
121 | }
122 |
123 | background = 'environment/sprites/white.png'
124 |
125 | # print objects
--------------------------------------------------------------------------------
/environment/reference_instructions.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | from collections import defaultdict
3 |
4 | def num_to_str(num):
5 | str_rep = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
6 | return str_rep[num]
7 |
8 | def create_references(max_dist = 2, verbose = False):
9 | references = defaultdict(lambda: [])
10 | references[( (-2,0), (-1,0) )].append('reach between and ')
11 | references[( (2, 0), (1, 0) )].append('reach between and ')
12 | references[( (0,-2), (0,-1) )].append('reach between and ')
13 | references[( (0, 2), (0, 1) )].append('reach between and ')
14 |
15 | for i in range(-2, 3):
16 | for j in range(-2, 3):
17 | for goal_i in range(-2, 3):
18 | for goal_j in range(-2, 3):
19 | if goal_i != i and goal_j != j and goal_i != 0 and goal_j != 0:
20 |
21 | if abs(goal_j - j) > max_dist or abs(goal_i - i) > max_dist:
22 | continue
23 | ## i with reference to obj 1
24 | ## j with reference to obj 2
25 | row = goal_i
26 | col = goal_j - j
27 | if row > 0:
28 | vertical = '{} below '.format( num_to_str(row) )
29 | elif row < 0:
30 | vertical = '{} above '.format( num_to_str(abs(row)) )
31 | else:
32 | raise RuntimeError('goal_i should not be in line with obj 1')
33 |
34 | if col > 0:
35 | horizontal = '{} to the right of '.format( num_to_str(col) )
36 | elif col < 0:
37 | horizontal = '{} to the left of '.format( num_to_str(abs(col)) )
38 | else:
39 | raise RuntimeError('goal_j should not be in line with obj 2')
40 |
41 | if verbose:
42 | print 'OBJ2: ', i, j, ' Goal: ', goal_i, goal_j
43 |
44 | instructions = 'reach ' + vertical + ' and ' + horizontal
45 | references[( (i,j), (goal_i,goal_j) )].append(instructions)
46 |
47 | if verbose:
48 | print ' ', instructions
49 |
50 | ## i with reference to obj 2
51 | ## j with reference to obj 1
52 | row = goal_i - i
53 | col = goal_j
54 | if row > 0:
55 | vertical = '{} below '.format( num_to_str(row) )
56 | elif row < 0:
57 | vertical = '{} above '.format( num_to_str(abs(row)) )
58 | else:
59 | raise RuntimeError('goal_i should not be in line with obj 2')
60 |
61 | if col > 0:
62 | horizontal = '{} to the right of '.format( num_to_str(col) )
63 | elif col < 0:
64 | horizontal = '{} to the left of '.format( num_to_str(abs(col)) )
65 | else:
66 | raise RuntimeError('goal_j should not be in line with obj 1')
67 |
68 | instructions = 'reach ' + vertical + ' and ' + horizontal
69 |
70 | if verbose:
71 | print ' ', instructions
72 |
73 | references[( (i,j), (goal_i,goal_j) )].append(instructions)
74 |
75 | return references
76 |
77 |
78 | if __name__ == '__main__':
79 | references = create_references()
80 | pdb.set_trace()
81 |
82 |
83 |
84 |
85 |
86 |
87 |
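## Note (editor's addition): create_references() returns a dict keyed by a pair of offsets
## ((obj2_row, obj2_col), (goal_row, goal_col)), both measured from a first reference
## object at the origin; each value lists the instruction templates for that goal, e.g.
## references[((1, 1), (2, 2))] holds the two ways of phrasing the goal at offset (2, 2)
## when the second object sits at offset (1, 1).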
--------------------------------------------------------------------------------
/environment/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | from matplotlib import pyplot as plt
5 |
6 | def visualize_values(mdp, values, policy, filename, title=None):
7 | states = mdp.states
8 | # print states
9 | plt.clf()
10 | m = max(states, key=lambda x: x[0])[0] + 1
11 | n = max(states, key=lambda x: x[1])[1] + 1
12 | data = np.zeros((m,n))
13 | for i in range(m):
14 | for j in range(n):
15 | state = (i,j)
16 | if type(values) == dict:
17 | data[i][j] = values[state]
18 | else:
19 | # print values[i][j]
20 | data[i][j] = values[i][j]
21 | action = policy[state]
22 | ## if using all_reachable actions, pick the best one
23 | if type(action) == tuple:
24 | action = action[0]
25 | if action != None:
26 | x, y, w, h = arrow(i, j, action)
27 | plt.arrow(x,y,w,h,head_length=0.4,head_width=0.4,fc='k',ec='k')
28 | heatmap = plt.pcolor(data, cmap=plt.get_cmap('jet'))
29 | plt.colorbar()
30 | plt.gca().invert_yaxis()
31 |
32 | if title:
33 | plt.title(title)
34 | plt.savefig(filename + '.png')
35 | # print data
36 |
37 | def arrow(i, j, action):
38 | ## up, down, left, right
39 | ## x, y, w, h
40 | arrows = {0: (.5,.95,0,-.4), 1: (.5,.05,0,.4), 2: (.95,.5,-.4,0), 3: (.05,.5,.4,0)}
41 | arrow = arrows[action]
42 | return j+arrow[0], i+arrow[1], arrow[2], arrow[3]
43 |
44 |
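## Usage sketch (editor's addition): plot a predicted or ground-truth value map with the
## greedy policy arrows and save it as a PNG. Variables here are illustrative; `mdp` is an
## environment.MDP and `values`, `policy` come from ValueIteration.
##
##     from environment import MDP, ValueIteration, visualization
##     values, policy = ValueIteration(mdp).iterate()
##     visualization.visualize_values(mdp, values, policy, 'logs/example_values')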
--------------------------------------------------------------------------------
/generate_worlds.py:
--------------------------------------------------------------------------------
1 | #!/om/user/janner/anaconda2/envs/pytorch/bin/python
2 |
3 | import sys, os, argparse, pickle, subprocess, pdb
4 | from tqdm import tqdm
5 | import environment, utils
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--lower', type=int, default=0)
9 | parser.add_argument('--num_worlds', type=int, default=10)
10 | parser.add_argument('--vis_path', type=str, default='data/example_vis/')
11 | parser.add_argument('--save_path', type=str, default='data/example_env/')
12 | parser.add_argument('--dim', type=int, default=10)
13 | parser.add_argument('--mode', type=str, default='local', choices=['local', 'global'])
14 | parser.add_argument('--only_global', type=bool, default=False)
15 | parser.add_argument('--sprite_dim', type=int, default=100)
16 | parser.add_argument('--num_steps', type=int, default=10)
17 | args = parser.parse_args()
18 |
19 | print args, '\n'
20 |
21 | utils.mkdir(args.vis_path)
22 | utils.mkdir(args.save_path)
23 |
24 |
25 | if args.mode == 'local':
26 | from environment.NonUniqueGenerator import NonUniqueGenerator
27 | gen = NonUniqueGenerator( environment.figure_library.objects, environment.figure_library.unique_instructions, shape=(args.dim, args.dim), num_steps=args.num_steps, only_global=args.only_global )
28 | elif args.mode == 'global':
29 | from environment.GlobalGenerator import GlobalGenerator
30 | gen = GlobalGenerator( environment.figure_library.objects, environment.figure_library.unique_instructions, shape=(args.dim, args.dim), num_steps=args.num_steps, only_global=args.only_global )
31 |
32 |
33 | for outer in range(args.lower, args.lower + args.num_worlds):
34 | info = gen.new()
35 | configurations = len(info['rewards'])
36 |
37 | print 'Generating map', outer, '(', configurations, 'configurations )'
38 | sys.stdout.flush()
39 |
40 | world = info['map']
41 | rewards = info['rewards']
42 | terminal = info['terminal']
43 | values = []
44 |
45 | sprite = environment.SpriteFigure(environment.figure_library.objects, environment.figure_library.background, dim=args.sprite_dim)
46 | sprite.makeGrid(world, args.vis_path + str(outer) + '_sprites')
47 |
48 | for inner in tqdm(range(configurations)):
49 | reward_map = rewards[inner]
50 | terminal_map = terminal[inner]
51 |
52 | mdp = environment.MDP(world, reward_map, terminal_map)
53 | vi = environment.ValueIteration(mdp)
54 |
55 | values_list, policy = vi.iterate()
56 | value_map = mdp.representValues(values_list)
57 | values.append(value_map)
58 |
59 |
60 | info['values'] = values
61 | filename = os.path.join( args.save_path, str(outer) + '.p' )
62 | pickle.dump( info, open( filename, 'wb' ) )
63 |
64 |
65 |
--------------------------------------------------------------------------------
/logs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jannerm/spatial-reasoning/e163003a33177e41ca02d5feefee3fdfca5ba154/logs/.gitkeep
--------------------------------------------------------------------------------
/logs/example/git_sprites.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jannerm/spatial-reasoning/e163003a33177e41ca02d5feefee3fdfca5ba154/logs/example/git_sprites.png
--------------------------------------------------------------------------------
/logs/example/predictions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jannerm/spatial-reasoning/e163003a33177e41ca02d5feefee3fdfca5ba154/logs/example/predictions.png
--------------------------------------------------------------------------------
/models/Linear_custom.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 |
4 |
5 | class Linear_custom(Function):
6 |
7 | def forward(self, input, weight, bias=None):
8 | self.save_for_backward(input, weight, bias)
9 | output = input.new(input.size(0), weight.size(0))
10 | output.addmm_(0, 1, input, weight.t())
11 | if bias is not None:
12 | # cuBLAS doesn't support 0 strides in sger, so we can't use expand
13 | self.add_buffer = input.new(input.size(0)).fill_(1)
14 | output.addr_(self.add_buffer, bias)
15 | return output
16 |
17 | def backward(self, grad_output):
18 | input, weight, bias = self.saved_tensors
19 |
20 | grad_input = grad_weight = grad_bias = None
21 | if self.needs_input_grad[0]:
22 | grad_output = grad_output.squeeze()
23 | grad_input = torch.mm(grad_output, weight)
24 | if self.needs_input_grad[1]:
25 | grad_weight = torch.mm(grad_output.t(), input)
26 | if bias is not None and self.needs_input_grad[2]:
27 | grad_bias = torch.mv(grad_output.t(), self.add_buffer)
28 |
29 | if bias is not None:
30 | return grad_input, grad_weight, grad_bias
31 | else:
32 | return grad_input, grad_weight
33 |
34 |
35 | class Bilinear(Function):
36 |
37 | def forward(self, input1, input2, weight, bias=None):
38 | self.save_for_backward(input1, input2, weight, bias)
39 |
40 | output = input1.new(input1.size(0), weight.size(0))
41 |
42 | buff = input1.new()
43 |
44 | # compute output scores:
45 | for k, w in enumerate(weight):
46 | torch.mm(input1, w, out=buff)
47 | buff.mul_(input2)
48 | torch.sum(buff, 1, out=output.narrow(1, k, 1))
49 |
50 | if bias is not None:
51 | output.add_(bias.expand_as(output))
52 |
53 | return output
54 |
55 | def backward(self, grad_output):
56 | input1, input2, weight, bias = self.saved_tensors
57 | grad_input1 = grad_input2 = grad_weight = grad_bias = None
58 |
59 | buff = input1.new()
60 |
61 | if self.needs_input_grad[0] or self.needs_input_grad[1]:
62 | grad_input1 = torch.mm(input2, weight[0].t())
63 | grad_input1.mul_(grad_output.narrow(1, 0, 1).expand(grad_input1.size()))
64 | grad_input2 = torch.mm(input1, weight[0])
65 | grad_input2.mul_(grad_output.narrow(1, 0, 1).expand(grad_input2.size()))
66 |
67 | for k in range(1, weight.size(0)):
68 | torch.mm(input2, weight[k].t(), out=buff)
69 | buff.mul_(grad_output.narrow(1, k, 1).expand(grad_input1.size()))
70 | grad_input1.add_(buff)
71 |
72 | torch.mm(input1, weight[k], out=buff)
73 | buff.mul_(grad_output.narrow(1, k, 1).expand(grad_input2.size()))
74 | grad_input2.add_(buff)
75 |
76 | if self.needs_input_grad[2]:
77 | # accumulate parameter gradients:
78 | for k in range(weight.size(0)):
79 | torch.mul(input1, grad_output.narrow(1, k, 1).expand_as(input1), out=buff)
80 | grad_weight = torch.mm(buff.t(), input2)
81 |
82 | if bias is not None and self.needs_input_grad[3]:
83 | grad_bias = grad_output.sum(0)
84 |
85 | return grad_input1, grad_input2, grad_weight, grad_bias
86 |
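## Usage sketch (editor's addition): these are old-style autograd Functions (as used by
## PyTorch 0.2/0.3), so each call instantiates the Function. Linear_custom should match
## torch.nn.functional.linear on the same inputs.
if __name__ == '__main__':
    from torch.autograd import Variable

    x = Variable(torch.randn(4, 3), requires_grad=True)
    w = Variable(torch.randn(5, 3), requires_grad=True)
    b = Variable(torch.randn(5), requires_grad=True)

    out = Linear_custom()(x, w, b)       ## equivalent to F.linear(x, w, b)
    out.sum().backward()
    print out.size(), x.grad.size()      ## (4, 5) and (4, 3)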
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from state_model_10 import *
2 | from object_model_10 import *
3 | from text_model import *
4 | from goal_model import *
5 | from compositor_model import *
6 | from constraint_factorizer import *
7 | from model_factorizer import *
8 | from values_factorized import *
9 | from map_model import *
10 |
11 | ## attention
12 | from lookup_model import *
13 | from lookup_location import *
14 | from attention_model import *
15 | from attention_direct import *
16 | from attention_heatmap import *
17 |
18 | ## global
19 | from attention_global import *
20 | from multi_global import *
21 | from multi_nonsep import *
22 | from multi_nobases import *
23 | from multi_norbf import *
24 | from multi_nocnn import *
25 | from cnn_lstm import *
26 |
27 | ## full map (instead of single value)
28 | from multi_model import *
29 | from simple_conv import *
30 | from conv_to_vector import *
31 |
32 | ## uvfa
33 | from uvfa_pos import *
34 | # from uvfa_3 import *
35 | from uvfa_text import *
36 | from mlp import *
37 |
38 | from initializations import *
--------------------------------------------------------------------------------
/models/attention_direct.py:
--------------------------------------------------------------------------------
1 | ## attention model with only a single convolution
2 | ## LSTM kernel isn't really used as attention map,
3 | ## but just as a kernel for convolution
4 |
5 | import torch
6 | import math, torch.nn as nn
7 | import torch.nn.functional as F
8 | import utils
9 |
10 | class AttentionDirect(nn.Module):
11 | # args.lstm_out, args.goal_hid, args.rank, args.obj_embed
12 | def __init__(self, text_model, object_model, args, final_hidden = 20, map_dim = 10):
13 | super(AttentionDirect, self).__init__()
14 |
15 | self.text_model = text_model
16 | self.object_model = object_model
17 |
18 | self.embed_dim = args.obj_embed
19 | self.kernel_out_dim = args.attention_out_dim
20 | self.kernel_size = args.attention_kernel
21 |
22 | self.conv_custom = utils.ConvKernel(self.embed_dim, self.kernel_out_dim, self.kernel_size, bias=False)
23 |
24 | self.reshape_dim = self.kernel_out_dim * (map_dim-self.kernel_size+1)**2
25 | self.fc1 = nn.Linear(self.reshape_dim, args.rank)
26 |
27 | def __conv(self, inp, kernel):
28 | batch_size = inp.size(0)
29 | out = [ self.conv_custom(inp[i].unsqueeze(0), kernel[i]) for i in range(batch_size) ]
30 | out = torch.cat(out, 0)
31 | return out
32 |
33 | def forward(self, inp):
34 | (obj, text) = inp
35 | batch_size = obj.size(0)
36 | text = text.transpose(0,1)
37 | hidden = self.text_model.init_hidden(batch_size)
38 |
39 | embeddings = self.object_model.forward(obj)
40 |
41 | lstm_out = self.text_model.forward(text, hidden)
42 | lstm_out = lstm_out.view(-1, self.kernel_out_dim, self.embed_dim, self.kernel_size, self.kernel_size)
43 |
44 | conv = self.__conv(embeddings, lstm_out)
45 | # print conv.size()
46 | conv = conv.view(-1, self.reshape_dim)
47 | # print conv.size()
48 |
49 | out = F.relu(self.fc1(conv))
50 | # print out.size()
51 | # out = F.relu(self.fc2(out))
52 |
53 | return out
54 |
55 |
56 | if __name__ == '__main__':
57 | from text_model import *
58 | from object_model import *
59 | from lookup_model import *
60 |
61 | batch = 2
62 | seq = 10
63 |
64 | text_vocab = 10
65 | lstm_inp = 5
66 | lstm_hid = 3
67 | lstm_layers = 1
68 |
69 | # obj_inp = torch.LongTensor(1,10,10).zero_()
70 | obj_vocab = 3
71 | emb_dim = 3
72 |
73 | concat_dim = 27
74 | hidden_dim = 5
75 | out_dim = 7
76 |
77 | text_model = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, concat_dim)
78 | object_model = LookupModel(obj_vocab, emb_dim, concat_dim)
79 |
80 | psi = AttentionModel(text_model, object_model, concat_dim, hidden_dim, out_dim)
81 |
82 | hidden = text_model.init_hidden(batch)
83 | text_inp = Variable(torch.floor(torch.rand(batch,seq)*text_vocab).long())
84 | obj_inp = Variable(torch.floor(torch.rand(batch,1,10,10)*obj_vocab).long())
85 |
86 | to = psi.forward( (obj_inp, text_inp) )
87 | print to.size()
88 |
89 | # # inp = Variable(torch.LongTensor((1,2,3)))
90 | # print 'INPUT: ', text_inp
91 | # print text_inp.size()
92 | # out = text_model.forward( text_inp, hidden )
93 |
94 | # print 'OUT: ', out
95 | # print out.size()
96 | # # print 'HID: ', hid
97 |
98 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_())
99 | # obj_inp = Variable(obj_inp)
100 | # obj_out = object_model.forward(obj_inp)
101 |
102 | # print obj_out.data.size()
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
--------------------------------------------------------------------------------
/models/attention_global.py:
--------------------------------------------------------------------------------
1 | ## attention model with only a single convolution
2 | ## LSTM kernel isn't really used as attention map,
3 | ## but just as a kernel for convolution
4 |
5 | import math, pdb
6 | import torch, torch.nn as nn, torch.nn.functional as F
7 | import custom
8 |
9 | class AttentionGlobal(nn.Module):
10 | # args.lstm_out, args.goal_hid, args.rank, args.obj_embed
11 | def __init__(self, text_model, args, map_dim = 10):
12 | super(AttentionGlobal, self).__init__()
13 |
14 | assert args.attention_kernel % 2 == 1
15 |
16 | self.text_model = text_model
17 | # self.object_model = object_model
18 |
19 | self.embed_dim = args.attention_in_dim
20 | self.kernel_out_dim = args.attention_out_dim
21 | self.kernel_size = args.attention_kernel
22 | self.global_coeffs = args.global_coeffs
23 |
24 | padding = int(math.ceil(self.kernel_size/2.)) - 1
25 | self.conv_custom = custom.ConvKernel(self.embed_dim, self.kernel_out_dim, self.kernel_size, bias=False, padding=padding)
26 |
27 | self.reshape_dim = self.kernel_out_dim * (map_dim-self.kernel_size+1)**2
28 |
29 | def __conv(self, inp, kernel):
30 | batch_size = inp.size(0)
31 | out = [ self.conv_custom(inp[i].unsqueeze(0), kernel[i]) for i in range(batch_size) ]
32 | out = torch.cat(out, 0)
33 | return out
34 |
35 | def forward(self, inp):
36 | (embeddings, text) = inp
37 | batch_size = embeddings.size(0)
38 | text = text.transpose(0,1)
39 | hidden = self.text_model.init_hidden(batch_size)
40 |
41 | # embeddings = self.object_model.forward(obj)
42 | # print embeddings.size()
43 |
44 | lstm_out = self.text_model.forward(text, hidden)
45 | lstm_kernel = lstm_out[:,:-self.global_coeffs].contiguous()
46 | lstm_kernel = lstm_kernel.view(-1, self.kernel_out_dim, self.embed_dim, self.kernel_size, self.kernel_size)
47 | # print 'LSTM_OUT: ', lstm_out.size()
48 | # print 'EMBEDDINGS: ', embeddings.size()
49 | # print self.kernel_size/2 - 1, self.kernel_size
50 | local_heatmap = self.__conv(embeddings, lstm_kernel)
51 |
52 | ## sum along attention_out_dim
53 | ## < batch x attention_out_dim x map_dim x map_dim >
54 | ## < batch x 1 x map_dim x map_dim >
55 | local_heatmap = local_heatmap.sum(1, keepdim=True)
56 |
57 | lstm_global = lstm_out[:,-self.global_coeffs:]
58 | # global_heatmap = self._global(lstm_global)
59 |
60 | # out = local_heatmap + global_heatmap
61 | # print conv.size()
62 | # conv = conv.view(-1, self.reshape_dim)
63 |
64 | ## save outputs for kernel visualization
65 | self.output_local = local_heatmap
66 | self.output_global = lstm_global
67 |
68 | return local_heatmap, lstm_global
69 |
70 |
71 | if __name__ == '__main__':
72 | import argparse
73 | from text_model import *
74 | from object_model import *
75 | from lookup_model import *
76 |
77 | parser = argparse.ArgumentParser()
78 | parser.add_argument('--attention_kernel', type=int, default=3)
79 | parser.add_argument('--attention_out_dim', type=int, default=3)
80 | parser.add_argument('--obj_embed', type=int, default=5)
81 | parser.add_argument('--map_dim', type=int, default=10)
82 | args = parser.parse_args()
83 |
84 | batch = 2
85 | seq = 10
86 |
87 | text_vocab = 10
88 | lstm_inp = 5
89 | lstm_hid = 3
90 | lstm_layers = 1
91 |
92 | # obj_inp = torch.LongTensor(1,10,10).zero_()
93 | obj_vocab = 3
94 | # emb_dim = 3
95 | args.attention_in_dim, args.global_coeffs = args.obj_embed, 3  ## matches initializations.init_full
96 | lstm_out = args.obj_embed * args.attention_out_dim * args.attention_kernel**2 + args.global_coeffs
97 | # hidden_dim = 5
98 | # out_dim = 7
99 |
100 | text_model = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, lstm_out)
101 | object_model = LookupModel(obj_vocab, args.obj_embed)
102 |
103 | psi = AttentionGlobal(text_model, args, map_dim = args.map_dim)
104 |
105 | hidden = text_model.init_hidden(batch)
106 | text_inp = Variable(torch.floor(torch.rand(batch,seq)*text_vocab).long())
107 | obj_inp = Variable(torch.floor(torch.rand(batch,1,10,10)*obj_vocab).long())
108 | obj_emb = object_model.forward(obj_inp)
109 | heatmap, global_coeffs = psi.forward( (obj_emb, text_inp) )
110 | print heatmap.size(), global_coeffs.size()
111 |
112 | # # inp = Variable(torch.LongTensor((1,2,3)))
113 | # print 'INPUT: ', text_inp
114 | # print text_inp.size()
115 | # out = text_model.forward( text_inp, hidden )
116 |
117 | # print 'OUT: ', out
118 | # print out.size()
119 | # # print 'HID: ', hid
120 |
121 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_())
122 | # obj_inp = Variable(obj_inp)
123 | # obj_out = object_model.forward(obj_inp)
124 |
125 | # print obj_out.data.size()
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
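135 | ## (added note, not in the original file) the (local_heatmap, global_coeffs) pair returned
136 | ## by forward() is consumed by wrappers such as multi_global.py, which spread the heatmap
137 | ## over the map with a precomputed Manhattan RBF and add the row/col/bias gradient term
138 | ## before the final convolution.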
--------------------------------------------------------------------------------
/models/attention_heatmap.py:
--------------------------------------------------------------------------------
1 | ## attention model with only a single convolution
2 | ## the LSTM output isn't really used as an attention map,
3 | ## but just as a kernel for convolution
4 |
5 | import torch
6 | import math, torch.nn as nn
7 | import torch.nn.functional as F
8 | import custom
9 |
10 | class AttentionHeatmap(nn.Module):
11 | # args.lstm_out, args.goal_hid, args.rank, args.obj_embed
12 | def __init__(self, text_model, args, map_dim = 10):
13 | super(AttentionHeatmap, self).__init__()
14 |
15 | assert args.attention_kernel % 2 == 1
16 |
17 | self.text_model = text_model
18 |
19 | self.embed_dim = args.attention_in_dim
20 | self.kernel_out_dim = args.attention_out_dim
21 | self.kernel_size = args.attention_kernel
22 |
23 | padding = int(math.ceil(self.kernel_size/2.)) - 1
24 | self.conv_custom = custom.ConvKernel(self.embed_dim, self.kernel_out_dim, self.kernel_size, bias=False, padding=padding)
25 |
26 | self.reshape_dim = self.kernel_out_dim * (map_dim-self.kernel_size+1)**2
27 |
28 | def __conv(self, inp, kernel):
29 | batch_size = inp.size(0)
30 | out = [ self.conv_custom(inp[i].unsqueeze(0), kernel[i]) for i in range(batch_size) ]
31 | out = torch.cat(out, 0)
32 | return out
33 |
34 | def forward(self, inp):
35 | (embeddings, text) = inp
36 | batch_size = embeddings.size(0)
37 | text = text.transpose(0,1)
38 | hidden = self.text_model.init_hidden(batch_size)
39 |
40 | lstm_out = self.text_model.forward(text, hidden)
41 | lstm_kernel = lstm_out.view(-1, self.kernel_out_dim, self.embed_dim, self.kernel_size, self.kernel_size)
42 | # print 'LSTM_OUT: ', lstm_out.size()
43 | # print 'EMBEDDINGS: ', embeddings.size()
44 | # print self.kernel_size/2 - 1, self.kernel_size
45 | local_heatmap = self.__conv(embeddings, lstm_kernel)
46 |
47 | ## sum along attention_out_dim
48 | ## < batch x attention_out_dim x map_dim x map_dim >
49 | ## < batch x 1 x map_dim x map_dim >
50 | local_heatmap = local_heatmap.sum(1, keepdim=True)
51 | # print conv.size()
52 | # conv = conv.view(-1, self.reshape_dim)
53 |
54 | return local_heatmap
55 |
56 |
57 | if __name__ == '__main__':
58 | import argparse
59 | from text_model import *
60 | from object_model import *
61 | from lookup_model import *
62 |
63 | parser = argparse.ArgumentParser()
64 | parser.add_argument('--attention_kernel', type=int, default=3)
65 | parser.add_argument('--attention_out_dim', type=int, default=3)
66 | parser.add_argument('--obj_embed', type=int, default=5)
67 | parser.add_argument('--map_dim', type=int, default=10)
68 | args = parser.parse_args()
69 |
70 | batch = 2
71 | seq = 10
72 |
73 | text_vocab = 10
74 | lstm_inp = 5
75 | lstm_hid = 3
76 | lstm_layers = 1
77 |
78 | # obj_inp = torch.LongTensor(1,10,10).zero_()
79 | obj_vocab = 3
80 | # emb_dim = 3
81 |
82 | lstm_out = args.obj_embed * args.attention_out_dim * args.attention_kernel**2
83 | # hidden_dim = 5
84 | # out_dim = 7
85 |
86 | text_model = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, lstm_out)
87 | object_model = LookupModel(obj_vocab, args.obj_embed)
88 | args.attention_in_dim = args.obj_embed  ## kernel input channels = object embedding dim
89 | psi = AttentionHeatmap(text_model, args, map_dim = args.map_dim)
90 |
91 | hidden = text_model.init_hidden(batch)
92 | text_inp = Variable(torch.floor(torch.rand(batch,seq)*text_vocab).long())
93 | obj_inp = Variable(torch.floor(torch.rand(batch,1,10,10)*obj_vocab).long())
94 | obj_emb = object_model.forward(obj_inp)
95 | heatmap = psi.forward( (obj_emb, text_inp) )
96 | print heatmap.size()
97 |
98 | # # inp = Variable(torch.LongTensor((1,2,3)))
99 | # print 'INPUT: ', text_inp
100 | # print text_inp.size()
101 | # out = text_model.forward( text_inp, hidden )
102 |
103 | # print 'OUT: ', out
104 | # print out.size()
105 | # # print 'HID: ', hid
106 |
107 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_())
108 | # obj_inp = Variable(obj_inp)
109 | # obj_out = object_model.forward(obj_inp)
110 |
111 | # print obj_out.data.size()
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
/models/attention_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math, torch.nn as nn
3 | import torch.nn.functional as F
4 | import utils
5 |
6 | class AttentionModel(nn.Module):
7 |
8 | def __init__(self, text_model, object_model, args, final_hidden = 20, map_dim = 10):
9 | super(AttentionModel, self).__init__()
10 |
11 | assert args.attention_kernel % 2 == 1
12 |
13 | self.text_model = text_model
14 | self.object_model = object_model
15 |
16 | self.embed_dim = args.obj_embed
17 | self.kernel_out_dim = args.attention_out_dim
18 | self.kernel_size = args.attention_kernel
19 |
20 | self.conv_custom = utils.ConvKernel(self.embed_dim, self.kernel_out_dim, self.kernel_size, bias=False, padding=1)
21 |
22 | ## final_hidden and fc1 are hard-coded
23 | ## should be a command-line arg / inferred from the input sizes
24 | self.conv1_kernel = 3
25 | self.conv1 = nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=self.conv1_kernel, padding=1)
26 | # self.conv2 = nn.Conv2d(3, 6, kernel_size=5)
27 | # self.conv3 = nn.Conv2d(6,12, kernel_size=5)
28 | self.reshape_dim = self.embed_dim * (map_dim)**2
29 | self.fc1 = nn.Linear(self.reshape_dim, final_hidden)
30 | self.fc2 = nn.Linear(final_hidden, args.rank)
31 |
32 | def __conv(self, inp, kernel):
33 | # print '__conv: ', inp.size(), kernel.size()
34 | batch_size = inp.size(0)
35 | out = [ self.conv_custom(inp[i].unsqueeze(0), kernel[i]) for i in range(batch_size) ]
36 | out = torch.cat(out, 0)
37 | return out
38 |
39 | def forward(self, inp):
40 | (obj, text) = inp
41 | batch_size = obj.size(0)
42 | text = text.transpose(0,1)
43 | hidden = self.text_model.init_hidden(batch_size)
44 |
45 | embeddings = self.object_model.forward(obj)
46 |
47 | lstm_out = self.text_model.forward(text, hidden)
48 | lstm_out = lstm_out.view(-1, 1, self.embed_dim, self.kernel_size, self.kernel_size)
49 | # print 'lstm: ', lstm_out.size()
50 |
51 | ## first convolve the object embeddings
52 | ## to populate the states neighboring objects
53 | ## (otherwise they would be all 0's)
54 | conv = F.relu(self.conv1(embeddings))
55 | # print 'conv: ', conv.size()
56 | ## get attention map
57 | attention = self.__conv(conv, lstm_out)
58 | # print 'attention: ', attention.size()
59 | # attention = self.__conv(embeddings, lstm_out)
60 | attention = attention.repeat(1,self.embed_dim,1,1)
61 | # print 'conv: ', conv.size()
62 | attended = attention * conv
63 | attended = attended.view(-1, self.reshape_dim)
64 | # print 'attended: ', attended.size()
65 | out = F.relu(self.fc1(attended))
66 | out = F.relu(self.fc2(out))
67 |
68 | # out = F.relu(self.conv1(attended))
69 | # out = F.relu(self.conv2(out))
70 | # out = F.relu(self.conv3(out))
71 | # out = out.view(-1, 12*2*2)
72 | # out = F.relu(self.fc1(out))
73 | # out = self.fc2(out)
74 |
75 | return out
76 |
77 |
78 | if __name__ == '__main__':
79 | from text_model import *
80 | from object_model import *
81 | from lookup_model import *
82 |
83 | batch = 2
84 | seq = 10
85 |
86 | text_vocab = 10
87 | lstm_inp = 5
88 | lstm_hid = 3
89 | lstm_layers = 1
90 |
91 | # obj_inp = torch.LongTensor(1,10,10).zero_()
92 | obj_vocab = 3
93 | emb_dim = 3
94 |
95 | from argparse import Namespace
96 | ## AttentionModel reads its hyperparameters from an args object
97 | args = Namespace(attention_kernel=3, attention_out_dim=1, obj_embed=emb_dim, rank=7)
98 | concat_dim = emb_dim * args.attention_kernel**2  ## lstm output is reshaped into a conv kernel
99 | text_model = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, concat_dim)
100 | object_model = LookupModel(obj_vocab, emb_dim)
101 |
102 | psi = AttentionModel(text_model, object_model, args)
103 |
104 | hidden = text_model.init_hidden(batch)
105 | text_inp = Variable(torch.floor(torch.rand(batch,seq)*text_vocab).long())
106 | obj_inp = Variable(torch.floor(torch.rand(batch,1,10,10)*obj_vocab).long())
107 |
108 | to = psi.forward( (obj_inp, text_inp) )
109 | print to.size()
110 |
111 | # # inp = Variable(torch.LongTensor((1,2,3)))
112 | # print 'INPUT: ', text_inp
113 | # print text_inp.size()
114 | # out = text_model.forward( text_inp, hidden )
115 |
116 | # print 'OUT: ', out
117 | # print out.size()
118 | # # print 'HID: ', hid
119 |
120 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_())
121 | # obj_inp = Variable(obj_inp)
122 | # obj_out = object_model.forward(obj_inp)
123 |
124 | # print obj_out.data.size()
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
/models/cnn_lstm.py:
--------------------------------------------------------------------------------
1 | ## CNN + LSTM baseline: embeds the layout and objects, appends a one-hot
2 | ## agent-position plane for every cell, runs a small CNN per position, and
3 | ## combines the result with an LSTM encoding of the instruction to predict the value map
4 |
5 | import torch
6 | import math, torch.nn as nn, pdb
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 | import models, utils
10 |
11 | class CNN_LSTM(nn.Module):
12 | # args.lstm_out, args.goal_hid, args.rank, args.obj_embed
13 | def __init__(self, state_model, object_model, lstm, args, map_dim = 10, batch_size = 32):
14 | super(CNN_LSTM, self).__init__()
15 |
16 | self.state_model = state_model
17 | self.object_model = object_model
18 |
19 | self.cnn_inp_dim = args.obj_embed + args.state_embed + 1
20 | self.cnn = models.ConvToVector(self.cnn_inp_dim)
21 | self.lstm = lstm
22 |
23 | self.fc1 = nn.Linear(args.cnn_out_dim, 16)
24 | self.fc2 = nn.Linear(16, 1)
25 |
26 | # self.state_vocab = args.state_embed
27 | # self.object_vocab = args.obj_embed
28 |
29 | self.state_dim = args.state_embed
30 | self.object_dim = args.obj_embed
31 |
32 | self.map_dim = map_dim
33 | self.batch_size = batch_size
34 | self.positions = self.__agent_pos_2d()
35 |
36 | # '''
37 | # returns tensor with one-hot vector encoding
38 | # [1, 2, 3, ..., map_dim] repeated batch_size times
39 | # < batch_size * map_dim, state_vocab >
40 | # '''
41 | # def __agent_pos(self):
42 | # size = self.map_dim**2
43 | # positions = torch.zeros(self.batch_size*size, 100, 1)
44 | # # print positions.size()
45 | # for ind in range(size):
46 | # # print ind, ind*self.batch_size, (ind+1)*self.batch_size, ind, positions.size()
47 | # # positions[ind*self.batch_size:(ind+1)*self.batch_size, ind] = 1
48 | # positions[ind:self.batch_size*size:size, ind] = 1
49 | # # pdb.set_trace()
50 | # return Variable( positions.cuda() )
51 |
52 | def __agent_pos_2d(self):
53 | ## < 10 x 10 >
54 | positions = torch.zeros(self.map_dim**2, self.map_dim**2)
55 | for i in range(self.map_dim**2):
56 | positions[i][i] = 1
57 |
58 | ## < 100 x 10 x 10 >
59 | positions = positions.view(self.map_dim**2, self.map_dim, self.map_dim)
60 | ## < 100 x 1 x 10 x 10 >
61 | ## < 100*batch x 1 x 10 x 10 >
62 | positions = positions.unsqueeze(1).repeat(self.batch_size,1,1,1)
63 | return Variable( positions.cuda() )
64 |
65 |
66 | def __repeat_position(self, x):
67 | # print 'X: ', x.size()
68 | if x.dim() == 2:
69 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1)
70 | elif x.dim() == 3:
71 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1,1)
72 | else:
73 | return x.repeat(1,self.map_dim**2,1,1)
74 |
75 | # '''
76 | # < batch_size x N >
77 | # < batch_size*100 x N >
78 | # '''
79 | # def __construct_inp(self, state, obj, text):
80 | # state = self.__repeat_position(state)
81 | # # pdb.set_trace()
82 | # state = state.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.state_dim)
83 | # ## add agent position
84 | # state = torch.cat( (state, self.positions), -1)
85 | # ## reshape to (batched) vector for input to MLPs
86 | # state = state.view(-1, self.state_dim+1)
87 |
88 | # obj = self.__repeat_position(obj)
89 | # obj = obj.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.object_dim)
90 | # obj = obj.view(-1, self.object_dim)
91 |
92 | # instr_length = text.size(1)
93 | # ## < batch x length >
94 | # ## < batch x 100 x length >
95 | # text = self.__repeat_position(text)
96 | # ## < batch*100 x length >
97 | # text = text.view(self.batch_size*self.map_dim**2,instr_length)
98 | # ## < length x batch*100 >
99 | # text = text.transpose(0,1)
100 | # ## < batch*100 x rank >
101 |
102 | # return state, obj, text
103 |
104 | def forward(self, inp):
105 | (state, obj, text) = inp
106 | batch_size = state.size(0)
107 |
108 | if batch_size != self.batch_size:
109 | self.batch_size = batch_size
110 | self.positions = self.__agent_pos_2d()
111 |
112 | ## < batch x layout x 10 x 10 >
113 | state_out = self.state_model(state)
114 | ## < batch x object x 10 x 10 >
115 | obj_out = self.object_model.forward(obj)
116 |
117 | ## < batch x layout+object x 10 x 10 >
118 | embeddings = torch.cat( (state_out, obj_out), 1 )
119 | ## < batch x (layout+object)*100 x 10 x 10 >
120 | ## < batch*100 x layout+object x 10 x 10 >
121 | embeddings = self.__repeat_position(embeddings).view(self.batch_size*self.map_dim**2,self.cnn_inp_dim-1,self.map_dim,self.map_dim)
122 |
123 | ## < batch*100, layout+object+agent, 10, 10 >
124 | ## state + object embeddings + agent position
125 | concat = torch.cat( (embeddings, self.positions), 1)
126 | ## < batch*100 x 1 x 4 x 4 >
127 | ## < batch*100 x 16 >
128 | cnn_out = self.cnn( concat ).view(batch_size*self.map_dim**2, -1)
129 |
130 |
131 | instr_length = text.size(1)
132 | ## < batch x length >
133 | ## < batch x 100 x length >
134 | text = self.__repeat_position(text)
135 | ## < batch*100 x length >
136 | text = text.view(self.batch_size*self.map_dim**2,instr_length)
137 | ## < length x batch*100 >
138 | text = text.transpose(0,1)
139 | hidden = self.lstm.init_hidden(self.batch_size*self.map_dim**2)
140 | ## < batch*100 x rank >
141 | lstm_out = self.lstm.forward(text, hidden)
142 |
143 | concat = torch.cat( (cnn_out, lstm_out), 1 )
144 |
145 | out = F.relu(self.fc1(concat))
146 | out = self.fc2(out)
147 |
148 | map_pred = out.view(self.batch_size,self.map_dim,self.map_dim)
149 |
150 | return map_pred
151 |
152 |
153 | if __name__ == '__main__':
154 | from text_model import *
155 | from object_model import *
156 | from lookup_model import *
157 |
158 | batch = 2
159 | seq = 10
160 |
161 | text_vocab = 10
162 | lstm_inp = 5
163 | lstm_hid = 3
164 | lstm_layers = 1
165 |
166 | # obj_inp = torch.LongTensor(1,10,10).zero_()
167 | obj_vocab = 3
168 | emb_dim = 3
169 |
170 | from argparse import Namespace
171 | layout_vocab, state_embed, lstm_out = 3, 3, 16
172 | args = Namespace(obj_embed=emb_dim, state_embed=state_embed, cnn_out_dim=2*lstm_out)
173 |
174 | state_model = LookupModel(layout_vocab, state_embed)
175 | object_model = LookupModel(obj_vocab, emb_dim)
176 | lstm = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, lstm_out)
177 |
178 | ## the position planes are allocated on the GPU in __init__, so this test assumes CUDA
179 | model = CNN_LSTM(state_model, object_model, lstm, args, batch_size=batch).cuda()
180 |
181 | state_inp = Variable(torch.floor(torch.rand(batch,1,10,10)*layout_vocab).long().cuda())
182 | obj_inp = Variable(torch.floor(torch.rand(batch,1,10,10)*obj_vocab).long().cuda())
183 | text_inp = Variable(torch.floor(torch.rand(batch,seq)*text_vocab).long().cuda())
184 | map_pred = model.forward( (state_inp, obj_inp, text_inp) )
185 | print map_pred.size()
186 | # # inp = Variable(torch.LongTensor((1,2,3)))
187 | # print 'INPUT: ', text_inp
188 | # print text_inp.size()
189 | # out = text_model.forward( text_inp, hidden )
190 |
191 | # print 'OUT: ', out
192 | # print out.size()
193 | # # print 'HID: ', hid
194 |
195 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_())
196 | # obj_inp = Variable(obj_inp)
197 | # obj_out = object_model.forward(obj_inp)
198 |
199 | # print obj_out.data.size()
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
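209 | ## (added note) the baseline scores every cell by replicating each example map_dim**2 times,
210 | ## appending a one-hot agent-position plane to the embeddings, running the shared CNN on
211 | ## each copy, and concatenating the CNN features with the LSTM instruction encoding before
212 | ## the two fully-connected layers reshape the outputs into a <batch x 10 x 10> value map.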
--------------------------------------------------------------------------------
/models/compositor_model.py:
--------------------------------------------------------------------------------
1 | import torch, torch.nn as nn
2 |
3 | class CompositorModel(nn.Module):
4 |
5 | def __init__(self, state_model, goal_model):
6 | super(CompositorModel, self).__init__()
7 |
8 | self.state_model = state_model
9 | self.goal_model = goal_model
10 |
11 |
12 | def forward(self, state, objects, instructions):
13 | state_embedding = self.state_model.forward(state)
14 | goal_embedding = self.goal_model.forward( (objects, instructions) )
15 |
16 | ## num_states x rank
17 | num_states = state_embedding.size(0)
18 | ## num_goals x rank
19 | num_goals = goal_embedding.size(0)
20 |
21 | ## num_states x num_goals x rank
22 | state_rep = state_embedding.unsqueeze(1).repeat(1,num_goals,1)
23 | goal_rep = goal_embedding.repeat(num_states,1,1)
24 |
25 | values = state_rep * goal_rep
26 | values = values.sum(2, keepdim=True).squeeze()
27 | values = values.transpose(0,1)
28 |
29 | return values
30 |
31 |
32 |
33 |
34 |
35 |
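36 | ## (added note) forward() returns values with values[g, s] = <state_embedding[s], goal_embedding[g]>,
37 | ## i.e. a low-rank factorization of the goal x state value matrix; the final transpose
38 | ## makes rows index goals and columns index states.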
--------------------------------------------------------------------------------
/models/constraint_factorizer.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 | from torch.autograd import Variable
7 | from torch.nn.modules.utils import _pair
8 | import pdb, pickle
9 |
10 | class ConstraintFactorizer(nn.Module):
11 | def __init__(self, sparse_value_mat, rank, dissimilarity_lambda, world_lambda, location_lambda):
12 | super(ConstraintFactorizer, self).__init__()
13 | self.M, self.N = sparse_value_mat.shape
14 | self.mat = Variable( torch.Tensor(sparse_value_mat).cuda() )
15 | self.mask = (self.mat == 0)
16 | self.rank = rank
17 | # print self.M, self.N
18 | self.M_embed = nn.Embedding(self.M, self.rank)
19 | self.N_embed = nn.Embedding(self.N, self.rank)
20 | self.M_inp = Variable( torch.range(0, self.M-1).cuda().long() )
21 | self.N_inp = Variable( torch.range(0, self.N-1).cuda().long() )
22 | self.dim = int(math.sqrt(self.N))
23 |
24 | # self.conv_kernel = torch.ones(1,self.rank,3,3)/8./self.rank
25 | # self.conv_kernel[:,:,1,1] = -1./self.rank
26 | self.conv_kernel = Variable( self.avg_conv(self.rank).cuda() )
27 | self.conv_bias = None
28 | self.stride = _pair(1)
29 | self.padding = _pair(1)
30 |
31 | ######## similarity loss on states
32 | # self.conv_vector = self.avg_vector(self.rank)
33 | ########
34 |
35 | # print 'BEFORE PARAMETER:'
36 | # print self.conv_kernel
37 | # pdb.set_trace()
38 | # self.param = self.conv_kernel
39 | # print 'CONV KERNEL:'
40 | # print self.param
41 | # self.conv_kernel = self.conv_kernel.cuda()
42 | # self.conv_bias = self.conv_bias.cuda()
43 | # self.conv = nn.Conv2d(self.rank,1,kernel_size=3,padding=1)
44 |
45 | self.dissimilarity_lambda = dissimilarity_lambda
46 | self.world_lambda = world_lambda
47 | self.location_lambda = location_lambda
48 |
49 | # print self.mat
50 | # print self.mask
51 |
52 | def __lookup(self):
53 | rows = self.M_embed(self.M_inp)
54 | columns = self.N_embed(self.N_inp)
55 | return rows, columns
56 |
57 | # def __reset_conv(self):
58 | # # print self.conv_kernel
59 | # self.conv.weight = nn.Parameter(self.conv_kernel)
60 | # self.conv.bias = nn.Parameter(self.conv_bias)
61 |
62 | ## batch x 5 x 3
63 | def forward(self, x):
64 | # self.__reset_conv()
65 | # print self.conv.weight
66 | # print self.conv.bias
67 | ## batch x 5 x 1
68 | rows, columns = self.__lookup()
69 | out = torch.mm(rows, columns.t())
70 | # print out
71 | out[self.mask] = 0
72 |
73 | # print out
74 | # print self.mat
75 | diff = torch.pow(out - self.mat, 2)
76 | mse = diff.sum()
77 |
78 | ## 1 x M x N x rank
79 | layout = columns.view(self.dim,self.dim,self.rank).unsqueeze(0)
80 | ## 1 x rank x N x M
81 | layout = layout.transpose(1,3)
82 | ## 1 x rank x M x N
83 | layout = layout.transpose(2,3)
84 |
85 | average = F.conv2d(layout,self.conv_kernel,self.conv_bias,self.stride,self.padding)
86 | divergence_penalty = torch.pow(layout - average, 2).sum()
87 |
88 | # print 'divergenec: ', divergence_penalty.size()
89 | # print 'layout: ', layout.size()
90 | # conv = self.conv(layout)
91 | # divergence_penalty = conv.sum()
92 | # print 'conv: ', conv.size()
93 | # pdb.set_trace()
94 | self.mse = mse.data[0]
95 | self.divergence = divergence_penalty.data[0]
96 |
97 | loss = mse + self.dissimilarity_lambda * divergence_penalty
98 |
99 | ######## similarity loss on states
100 | ## worlds x state_size x rank
101 | states = rows.view(self.M / self.N, self.N, self.rank)
102 |
103 | world_avg = states.mean(1).repeat(1,self.N,1)
104 | location_avg = states.mean(0).repeat(self.M/self.N,1,1)
105 | # pdb.set_trace()
106 | world_mse = torch.pow(states - world_avg, 2).sum()
107 | location_mse = torch.pow(states - location_avg, 2).sum()
108 | self.world_mse = world_mse.data[0]
109 | self.location_mse = location_mse.data[0]
110 |
111 | loss += (self.world_lambda * world_mse) + (self.location_lambda * location_mse)
112 | ## rank x state_size x worlds
113 | # states = states.transpose(0,2)
114 | ## rank x worlds x state_size
115 | # states = states.transpose(1,2)
116 | # F.conv2d(states, self.conv_vector,self.conv_bias,self.stride,self.padding)
117 | ########
118 |
119 |
120 | # print rows.size(), columns.size(), out.size()
121 | # print out
122 | # print out
123 | # print loss
124 |
125 | return loss
126 |
127 | def embeddings(self):
128 | return self.__lookup()
129 |
130 | def avg_conv(self, out_dim):
131 | kernel = torch.zeros(out_dim,out_dim,3,3)
132 | for i in range(out_dim):
133 | kernel[i][i] = 1./8
134 | kernel[i][i][1][1] = 0
135 | return kernel
136 |
137 | # def avg_vector(self, out_dim):
138 | # kernel = torch.zeros(out_dim,out_dim,1,3)
139 | # for i in range(out_dim):
140 | # kernel[i][i][0] = torch.Tensor((1,0,1))/2.
141 | # return kernel
142 |
143 | def train(self, lr, iters):
144 | optimizer = optim.Adam(self.parameters(), lr=lr)
145 |
146 | t = trange(iters)
147 | for i in t:
148 | optimizer.zero_grad()
149 | loss = self.forward( () )
150 | # print loss.data[0]
151 | t.set_description( '%.3f | %.3f | %.3f | %.3f' % (self.mse, self.divergence, self.world_mse, self.location_mse) )
152 | loss.backward()
153 | optimizer.step()
154 |
155 | U, V = self.__lookup()
156 | recon = torch.mm(U, V.t())
157 | # print U, V, recon
158 | U = U.data.cpu().numpy()
159 | V = V.data.cpu().numpy()
160 | recon = recon.data.cpu().numpy()
161 | return U, V, recon
162 |
163 | def avg_conv(out_dim):
164 | kernel = torch.zeros(out_dim,out_dim,3,3)
165 | for i in range(out_dim):
166 | kernel[i][i] = 1./8
167 | kernel[i][i][1][1] = 0
168 | return kernel
169 |
170 | def avg_vector(out_dim):
171 | kernel = torch.zeros(out_dim,out_dim,1,3)
172 | for i in range(out_dim):
173 | kernel[i][i][0] = torch.Tensor((1,0,1))/2.
174 | return kernel
175 |
176 | if __name__ == '__main__':
177 | # from torch.nn.modules.utils import _pair
178 | # rank = 4
179 | # # kernel = Variable(avg_conv(rank))
180 | # kernel = Variable( avg_vector(rank) )
181 |
182 | # inp = Variable(torch.randn(1,rank,20,400))
183 | # bias = None
184 | # stride = _pair(1)
185 | # padding = _pair(1)
186 | # conv = F.conv2d(inp,kernel,bias,stride,padding)[:,:,1:-1,:]
187 | # print 'inp:', inp.size()
188 | # print 'kernel:', kernel.size()
189 | # print 'conv:', conv.size()
190 |
191 | # pdb.set_trace()
192 |
193 | # from torch.nn.modules.utils import _pair
194 | # inp = Variable(torch.randn(1,5,8,8))
195 | # kern = Variable(torch.ones(7,5,3,3))
196 | # bias = None
197 | # stride = _pair(1)
198 | # padding = _pair(1)
199 | # # dilation = _pair(1)
200 | # # print dilation
201 | # # groups = 1
202 | # # conv = F.conv2d(inp,kern,bias,stride,padding,dilation,groups)
203 | # conv = F.conv2d(inp,kern,bias,stride,padding)
204 | # print 'inp:', inp
205 | # print 'conv: ', conv
206 | # print 'sum: ', inp[:,:,:3,:3].sum()
207 |
208 | # print a.b
209 | # print 'conv: ', conv
210 | value_mat = pickle.load( open('../pickle/value_mat20.p') )
211 | rank = 10
212 |
213 | print 'value_mat: ', value_mat.shape
214 | dissimilarity_lambda = .1
215 | world_lambda = 0
216 | location_lambda = .001
217 | model = ConstraintFactorizer(value_mat, rank, dissimilarity_lambda, world_lambda, location_lambda).cuda()
218 | lr = 0.001
219 | iters = 500000
220 | U, V, recon = model.train(lr, iters)
221 |
222 | # lr = 0.001
223 | # optimizer = optim.Adam(model.parameters(), lr=lr)
224 |
225 | # loss = model.forward( () )
226 | # print loss
227 | # iters = 50000
228 | # t = trange(iters)
229 | # for i in t:
230 | # optimizer.zero_grad()
231 | # loss = model.forward( () )
232 | # # print loss.data[0]
233 | # t.set_description( str(model.mse) + ' ' + str(model.divergence) )
234 | # loss.backward()
235 | # optimizer.step()
236 |
237 | # U, V = model.embeddings()
238 |
239 | print 'recon'
240 | print recon
241 | print 'true'
242 | print torch.Tensor(value_mat)
243 |
244 | pickle.dump(U, open('../pickle/U_lambda_' + str(dissimilarity_lambda) + '.p', 'w') )
245 | pickle.dump(V, open('../pickle/V_lambda_' + str(dissimilarity_lambda) + '.p', 'w') )
246 | # pdb.set_trace()
247 |
248 |
249 | # inp = torch.LongTensor(5,3).zero_()
250 | # vocab_size = 10
251 | # emb_dim = 3
252 | # rank = 10
253 | # phi = ObjectModel(vocab_size, emb_dim, inp.size(), rank)
254 |
255 | # batch = 2
256 | # hidden = phi.init_hidden(batch)
257 |
258 | # # enc = nn.Embedding(10,emb_dim,padding_idx=0)
259 | # inp = torch.LongTensor(batch,5,3).zero_()
260 | # inp[0][0]=1
261 | # inp[0][1]=1
262 | # inp[1][0]=8
263 | # inp[1][1]=8
264 | # print inp
265 | # # inp[0][0][0][0]=1
266 | # # inp[0][1][0][0]=1
267 | # # inp[1][0][0][2]=1
268 | # # print inp
269 | # inp = Variable(inp)
270 |
271 | # out = phi.forward(inp, hidden)
272 | # # print out
273 | # # out = out.view(-1,2,3,3,emb_dim)
274 | # out = out.data
275 | # print out.size()
276 |
277 | # # print out[0][0][0]
278 | # # print out[1][0][0]
279 |
280 |
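281 | ## (added note) the training objective assembled in forward() has three parts:
282 | ## (1) masked reconstruction error on the observed entries of the value matrix,
283 | ## (2) a dissimilarity penalty pushing each location embedding toward the average of
284 | ##     its 8 neighbours (the fixed averaging kernel from avg_conv), and
285 | ## (3) world / location penalties tying each state embedding to the per-world and
286 | ##     per-location means, weighted by world_lambda and location_lambda.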
--------------------------------------------------------------------------------
/models/conv_to_vector.py:
--------------------------------------------------------------------------------
1 | import sys, math, pdb
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | '''
8 | Stack of 3x3 convolutions that maps a <batch x in_channels x 10 x 10>
9 | embedding map down to a <batch x 1 x 4 x 4> feature map.
10 |
11 | '''
12 |
13 | class ConvToVector(nn.Module):
14 | def __init__(self, in_channels, padding=1):
15 | super(ConvToVector, self).__init__()
16 |
17 | self.in_channels = in_channels
18 |
19 | # self.embed = nn.Embedding(vocab_size, in_channels)
20 | self.conv1 = nn.Conv2d(in_channels, 3, kernel_size=3, padding=padding)
21 | self.conv2 = nn.Conv2d(3, 6, kernel_size=3, padding=padding)
22 | self.conv3 = nn.Conv2d(6, 12, kernel_size=3, padding=padding)
23 | # self.conv4 = nn.Conv2d(12,18, kernel_size=2, padding=padding)
24 | # self.conv5 = nn.Conv2d(18,24, kernel_size=2, padding=padding)
25 | # self.conv6 = nn.Conv2d(24,18, kernel_size=2, padding=padding)
26 | # self.conv7 = nn.Conv2d(18,12, kernel_size=2, padding=padding)
27 | self.conv8 = nn.Conv2d(12, 6, kernel_size=3, padding=0)
28 | self.conv9 = nn.Conv2d(6, 3, kernel_size=3, padding=0)
29 | self.conv10 = nn.Conv2d(3, 1, kernel_size=3, padding=0)
30 | # # self.conv4 = nn.Conv2d(12,12, kernel_size=5)
31 | # self.fc1 = nn.Linear(192, out_dim)
32 |
33 | def forward(self, x):
34 | # reshape = []
35 | # for dim in x.size(): reshape.append(dim)
36 | # reshape.append(self.in_channels)
37 |
38 | # ## reshape to vector
39 | # x = x.view(-1)
40 | # ## get embeddings
41 | # x = self.embed(x)
42 | # ## reshape to batch x channels x M x N x embed_dim
43 | # x = x.view(*reshape)
44 | # ## sum over channels in input
45 | # x = x.sum(1)
46 | # # pdb.set_trace()
47 | # ## reshape to batch x embed_dim x M x N
48 | # ## (treats embedding dims as channels)
49 | # x = x.transpose(1,-1)[:,:,:,:,0] #.squeeze() #
50 | # print 'SIZE:', x.size()
51 | # pdb.set_trace()
52 | x = F.relu(self.conv1(x))
53 | x = F.relu(self.conv2(x))
54 | x = F.relu(self.conv3(x))
55 | # x = F.relu(self.conv4(x))
56 | # x = F.relu(self.conv5(x))
57 | # x = F.relu(self.conv6(x))
58 | # x = F.relu(self.conv7(x))
59 | x = F.relu(self.conv8(x))
60 | x = F.relu(self.conv9(x))
61 | x = self.conv10(x)
62 |
63 | # x = x.view(-1, 192)
64 | # x = self.fc1(x)
65 | return x
66 |
67 |
68 | if __name__ == '__main__':
69 | from torch.autograd import Variable
70 | # inp = torch.LongTensor(2,10,10).zero_()
71 | inp = Variable( torch.randn(5,1,10,10) )
72 |
73 | model = ConvToVector(1)
74 |
75 | out = model(inp)
76 |
77 | print inp.size()
78 | print out.size()
79 |
80 |
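81 | ## (added note) with the default 10x10 inputs, the three padded convolutions preserve the
82 | ## spatial size and the three unpadded ones shrink it 10 -> 8 -> 6 -> 4, so the output is
83 | ## <batch x 1 x 4 x 4>, i.e. 16 features once flattened (see cnn_lstm.py, which pairs this
84 | ## with a 16-dimensional LSTM output).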
--------------------------------------------------------------------------------
/models/custom.py:
--------------------------------------------------------------------------------
1 | import math, torch, torch.nn.functional as F, pdb
2 | from torch.autograd import Variable
3 | from torch.nn.parameter import Parameter
4 |
5 | from torch.nn.modules.module import Module
6 | from torch.nn.modules.utils import _single, _pair, _triple
7 | from torch.nn.modules.conv import _ConvNd
8 |
9 | class _ConvNdKernel(Module):
10 |
11 | def __init__(self, in_channels, out_channels, kernel_size, stride,
12 | padding, dilation, transposed, output_padding, groups, bias):
13 | super(_ConvNdKernel, self).__init__()
14 | if in_channels % groups != 0:
15 | raise ValueError('in_channels must be divisible by groups')
16 | if out_channels % groups != 0:
17 | raise ValueError('out_channels must be divisible by groups')
18 | self.in_channels = in_channels
19 | self.out_channels = out_channels
20 | self.kernel_size = kernel_size
21 | self.stride = stride
22 | self.padding = padding
23 | self.dilation = dilation
24 | self.transposed = transposed
25 | self.output_padding = output_padding
26 | self.groups = groups
27 | # if transposed:
28 | # self.weight = Parameter(torch.Tensor(
29 | # in_channels, out_channels // groups, *kernel_size))
30 | # else:
31 | # self.weight = Parameter(torch.Tensor(
32 | # out_channels, in_channels // groups, *kernel_size))
33 | if bias:
34 | self.bias = Parameter(torch.Tensor(out_channels))
35 | else:
36 | self.register_parameter('bias', None)
37 | self.reset_parameters()
38 |
39 | def reset_parameters(self):
40 | n = self.in_channels
41 | for k in self.kernel_size:
42 | n *= k
43 | stdv = 1. / math.sqrt(n)
44 | # self.weight.data.uniform_(-stdv, stdv)
45 | if self.bias is not None:
46 | self.bias.data.uniform_(-stdv, stdv)
47 |
48 | def __repr__(self):
49 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
50 | ', stride={stride}')
51 | if self.padding != (0,) * len(self.padding):
52 | s += ', padding={padding}'
53 | if self.dilation != (1,) * len(self.dilation):
54 | s += ', dilation={dilation}'
55 | if self.output_padding != (0,) * len(self.output_padding):
56 | s += ', output_padding={output_padding}'
57 | if self.groups != 1:
58 | s += ', groups={groups}'
59 | if self.bias is None:
60 | s += ', bias=False'
61 | s += ')'
62 | return s.format(name=self.__class__.__name__, **self.__dict__)
63 |
64 |
65 | class ConvKernel(_ConvNdKernel):
66 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
67 | padding=0, dilation=1, groups=1, bias=True):
68 | kernel_size = _pair(kernel_size)
69 | stride = _pair(stride)
70 | padding = _pair(padding)
71 | dilation = _pair(dilation)
72 | super(ConvKernel, self).__init__(
73 | in_channels, out_channels, kernel_size, stride, padding, dilation,
74 | False, _pair(0), groups, bias)
75 |
76 | def forward(self, input, kernel):
77 | self.weight = Parameter(kernel.data)
78 | # print 'weight: ', self.weight.size()
79 | # print 'bias: ', self.bias.size()
80 | # print 'forward:', type(input.data), type(self.weight.data)
81 | # print 'forward: ', input.size(), self.weight.size()
82 | return F.conv2d(input, kernel, self.bias, self.stride,
83 | self.padding, self.dilation, self.groups)
84 |
85 |
86 |
87 |
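88 | ## minimal usage sketch (added, not part of the original file): ConvKernel applies whatever
89 | ## kernel is passed at call time instead of a learned weight, which is how the
90 | ## text-conditioned kernels produced by the LSTM are convolved with the object embeddings.
91 | if __name__ == '__main__':
92 | conv = ConvKernel(5, 3, 3, bias=False, padding=1)
93 | inp = Variable(torch.randn(1, 5, 10, 10))      ## one example with 5 embedding channels
94 | kernel = Variable(torch.randn(3, 5, 3, 3))     ## per-example kernel, e.g. from the LSTM
95 | out = conv(inp, kernel)
96 | print out.size()                               ## (1, 3, 10, 10)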
--------------------------------------------------------------------------------
/models/goal_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class Psi(nn.Module):
6 |
7 | def __init__(self, text_model, object_model, concat_dim, hidden_dim, out_dim):
8 | super(Psi, self).__init__()
9 |
10 | self.text_model = text_model
11 | self.object_model = object_model
12 | self.fc1 = nn.Linear(concat_dim * 2, hidden_dim)
13 | self.fc2 = nn.Linear(hidden_dim, hidden_dim)
14 | self.fc3 = nn.Linear(hidden_dim, out_dim)
15 |
16 |
17 | def forward(self, inp):
18 | (obj, text) = inp
19 | batch_size = obj.size(0)
20 | text = text.transpose(0,1)
21 | hidden = self.text_model.init_hidden(batch_size)
22 | # print obj, text, hidden
23 |
24 | obj_out = self.object_model.forward(obj)
25 | # print text
26 | # print 'obj out: ', obj_out.data.size()
27 | # print 'text inp: ', text.data.size()
28 | # print 'hidden: ', hidden.data.size()
29 | text_out = self.text_model.forward(text, hidden)
30 | concat = F.relu( torch.cat((obj_out, text_out), 1) )
31 | output = F.relu(self.fc1(concat))
32 | output = F.relu(self.fc2(output))
33 | output = self.fc3(output)
34 | return output
35 |
36 |
37 | if __name__ == '__main__':
38 | from text_model import *
39 | from object_model import *
40 |
41 | batch = 2
42 | seq = 10
43 |
44 | text_vocab = 10
45 | lstm_inp = 5
46 | lstm_hid = 3
47 | lstm_layers = 2
48 |
49 | obj_inp = torch.LongTensor(1,20,20).zero_()
50 | obj_vocab = 3
51 | emb_dim = 3
52 |
53 | concat_dim = 10
54 | hidden_dim = 5
55 | out_dim = 7
56 |
57 | text_model = TextModel(text_vocab, lstm_inp, lstm_hid, lstm_layers, concat_dim)
58 | object_model = ObjectModel(obj_vocab, emb_dim, obj_inp.size(), concat_dim)
59 |
60 | psi = Psi(text_model, object_model, concat_dim, hidden_dim, out_dim)
61 |
62 | hidden = text_model.init_hidden(batch)
63 | text_inp = Variable(torch.floor(torch.rand(batch,seq)*text_vocab).long())
64 | obj_inp = Variable(torch.floor(torch.rand(batch,1,20,20)*obj_vocab).long())
65 |
66 | to = psi.forward( (obj_inp, text_inp) )
67 | print to
68 |
69 | # # inp = Variable(torch.LongTensor((1,2,3)))
70 | # print 'INPUT: ', text_inp
71 | # print text_inp.size()
72 | # out = text_model.forward( text_inp, hidden )
73 |
74 | # print 'OUT: ', out
75 | # print out.size()
76 | # # print 'HID: ', hid
77 |
78 | # obj_inp = Variable(torch.LongTensor(5,1,20,20).zero_())
79 | # obj_inp = Variable(obj_inp)
80 | # obj_out = object_model.forward(obj_inp)
81 |
82 | # print obj_out.data.size()
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
--------------------------------------------------------------------------------
/models/initializations.py:
--------------------------------------------------------------------------------
1 | import models
2 |
3 | '''
4 | norbf (full), nobases (no gradient), nonsep
5 | uvfa-pos, uvfa-text, cnn+lstm
6 | (nocnn)
7 |
8 | rbf + gradient + cnn: full
9 | no rbf: norbf
10 | no gradient: noglobal
11 | no rbf / gradient: nobases
12 | no cnn: nocnn
13 | '''
14 |
15 | def init(args, layout_vocab_size, object_vocab_size, text_vocab_size):
16 | if args.model == 'full': ## new
17 | model = init_full(args, layout_vocab_size, object_vocab_size, text_vocab_size)
18 | elif args.model == 'no-gradient':
19 | model = init_nogradient(args, layout_vocab_size, object_vocab_size, text_vocab_size)
20 | elif args.model == 'cnn-lstm':
21 | model = init_cnn_lstm(args, layout_vocab_size, object_vocab_size, text_vocab_size)
22 | elif args.model == 'uvfa-text':
23 | model = init_uvfa_text(args, layout_vocab_size, object_vocab_size, text_vocab_size)
24 | # TODO: clean up UVFA-pos goal loading
25 | elif args.model == 'uvfa-pos':
26 | model = init_uvfa_pos(args, layout_vocab_size, object_vocab_size, text_vocab_size)
27 | train_indices = train_goals
28 | val_indices = val_goals
29 | return model
30 |
31 |
32 | def init_full(args, layout_vocab_size, object_vocab_size, text_vocab_size):
33 | args.global_coeffs = 3
34 | args.attention_in_dim = args.obj_embed
35 | args.lstm_out = args.attention_in_dim * args.attention_out_dim * args.attention_kernel**2 + args.global_coeffs
36 |
37 | state_model = models.LookupModel(layout_vocab_size, args.state_embed).cuda()
38 | object_model = models.LookupModel(object_vocab_size, args.obj_embed)
39 |
40 | text_model = models.TextModel(text_vocab_size, args.lstm_inp, args.lstm_hid, args.lstm_layers, args.lstm_out)
41 | heatmap_model = models.AttentionGlobal(text_model, args, map_dim=args.map_dim).cuda()
42 |
43 | model = models.MultiNoRBF(state_model, object_model, heatmap_model, args, map_dim=args.map_dim).cuda()
44 | return model
45 |
46 |
47 | def init_nogradient(args, layout_vocab_size, object_vocab_size, text_vocab_size):
48 | args.global_coeffs = 0
49 | args.attention_in_dim = args.obj_embed
50 | args.lstm_out = args.attention_in_dim * args.attention_out_dim * args.attention_kernel**2
51 |
52 | state_model = models.LookupModel(layout_vocab_size, args.state_embed).cuda()
53 | object_model = models.LookupModel(object_vocab_size, args.obj_embed)
54 |
55 | text_model = models.TextModel(text_vocab_size, args.lstm_inp, args.lstm_hid, args.lstm_layers, args.lstm_out)
56 | heatmap_model = models.AttentionHeatmap(text_model, args, map_dim=args.map_dim).cuda()
57 |
58 | model = models.MultiNoBases(state_model, object_model, heatmap_model, args, map_dim=args.map_dim).cuda()
59 | return model
60 |
61 |
62 | def init_cnn_lstm(args, layout_vocab_size, object_vocab_size, text_vocab_size):
63 | args.lstm_out = 16
64 | args.cnn_out_dim = 2*args.lstm_out
65 |
66 | state_model = models.LookupModel(layout_vocab_size, args.state_embed)
67 | object_model = models.LookupModel(object_vocab_size, args.obj_embed)
68 |
69 | lstm = models.TextModel(text_vocab_size, args.lstm_inp, args.lstm_hid, args.lstm_layers, args.lstm_out)
70 |
71 | model = models.CNN_LSTM(state_model, object_model, lstm, args).cuda()
72 | return model
73 |
74 |
75 | def init_uvfa_text(args, layout_vocab_size, object_vocab_size, text_vocab_size, rank = 7):
76 | print ' Using UVFA variant, consider using a lower learning rate (e.g., 0.0001)'
77 | print ' UVFA rank: {} '.format(rank)
78 |
79 | args.rank = rank
80 | args.lstm_out = rank
81 |
82 | text_model = models.TextModel(text_vocab_size, args.lstm_inp, args.lstm_hid, args.lstm_layers, args.lstm_out)
83 | model = models.UVFA_text(text_model, layout_vocab_size, object_vocab_size, args, map_dim=args.map_dim).cuda()
84 | return model
85 |
86 |
87 | def init_uvfa_pos(args, layout_vocab_size, object_vocab_size, text_vocab_size, rank = 7):
88 | print ' Using UVFA variant, consider using a lower learning rate (e.g., 0.0001)'
89 | print ' UVFA rank: {} '.format(rank)
90 |
91 | args.rank = rank
92 | args.lstm_out = rank
93 |
94 | model = models.UVFA_pos(layout_vocab_size, object_vocab_size, args, map_dim=args.map_dim).cuda()
95 | return model
96 |
97 |
98 |
99 |
100 |
101 |
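102 | ## (added note) args.model values map to: 'full' -> AttentionGlobal + MultiNoRBF,
103 | ## 'no-gradient' -> AttentionHeatmap + MultiNoBases, 'cnn-lstm' -> CNN_LSTM,
104 | ## 'uvfa-text' -> UVFA_text, 'uvfa-pos' -> UVFA_pos (the latter still needs its
105 | ## goal-index loading cleaned up, see the TODO above).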
--------------------------------------------------------------------------------
/models/lookup_location.py:
--------------------------------------------------------------------------------
1 | import sys, math, pdb
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import torch.optim as optim
7 |
8 | '''
9 | Object observations are single-channel images
10 | with positive indices denoting objects.
11 | 0's denote no object.
12 |
13 | '''
14 |
15 | class LookupLocationModel(nn.Module):
16 | def __init__(self, vocab_size, embed_dim, map_dim = 10):
17 | super(LookupLocationModel, self).__init__()
18 |
19 | ## the forward pass appends two (x,y) location channels to the embeddings
20 | self.embed_dim = embed_dim
21 |
22 | # self.reshape = [-1]
23 | # for dim in inp_size:
24 | # self.reshape.append(dim)
25 | # self.reshape.append(self.embed_dim)
26 |
27 | # self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
28 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
29 |
30 | # self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5)
31 | # self.conv2 = nn.Conv2d(3, 6, kernel_size=5)
32 | # self.conv3 = nn.Conv2d(6,12, kernel_size=5)
33 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5)
34 | # self.fc1 = nn.Linear(192, out_dim)
35 | # self.init_weights()
36 |
37 | self.map_dim = map_dim
38 | self.locations = self.__init_locations(self.map_dim)
39 |
40 | def __init_locations(self, map_dim):
41 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim)
42 | col = torch.arange(0,map_dim).repeat(map_dim,1)
43 | locations = torch.stack( (row, col) )
44 | return Variable(locations.cuda())
45 |
46 | def init_weights(self):
47 | initrange = 0.1
48 | self.embed.weight.data.uniform_(-initrange, initrange)
49 | # self.decoder.bias.data.fill_(0)
50 | # self.decoder.weight.data.uniform_(-initrange, initrange)
51 |
52 | def forward(self, x):
53 | # print 'LOOKUP: ', x.size(), type(x)
54 | batch_size = x.size(0)
55 |
56 | reshape = []
57 | for dim in x.size(): reshape.append(dim)
58 | reshape.append(self.embed_dim)
59 |
60 | if x.size(-1) != self.map_dim:
61 | self.map_dim = x.size(-1)
62 | self.locations = self.__init_locations(self.map_dim)
63 |
64 | ## reshape to vector
65 | x = x.view(-1)
66 | ## get embeddings
67 | x = self.embed(x)
68 | ## reshape to batch x channels x M x N x embed_dim
69 | x = x.view(*reshape)
70 | ## sum over channels in input
71 | x = x.sum(1, keepdim=True)
72 | # pdb.set_trace()
73 | ## reshape to batch x embed_dim x M x N
74 | ## (treats embedding dims as channels)
75 | x = x.transpose(1,-1)[:,:,:,:,0] #.squeeze() #
76 |
77 | locations = self.locations.unsqueeze(0).repeat(batch_size,1,1,1)
78 |
79 | # pdb.set_trace()
80 | x = torch.cat( (x, locations), 1 )
81 | # print self.locations
82 | # print locations
83 | # x = F.relu(self.conv1(x))
84 | # x = F.relu(self.conv2(x))
85 | # x = F.relu(self.conv3(x))
86 | # x = F.relu(self.conv4(x))
87 | # x = x.view(-1, 192)
88 | # x = self.fc1(x)
89 | return x
90 |
91 | if __name__ == '__main__':
92 | vocab_size = 10
93 | emb_dim = 3
94 | map_dim = 10
95 | phi = LookupLocationModel(vocab_size, emb_dim, map_dim=map_dim)
96 |
97 | # enc = nn.Embedding(10,emb_dim,padding_idx=0)
98 | inp = torch.LongTensor(5,1,map_dim,map_dim).zero_()
99 | inp[0][0][0][0]=1
100 | # inp[0][1][0][0]=1
101 | inp[1][0][0][2]=1
102 | # print inp
103 | inp = Variable(inp)
104 |
105 | out = phi.forward(inp)
106 | # print out
107 | # out = out.view(-1,2,3,3,emb_dim)
108 | out = out.data
109 | print out.size()
110 |
111 | # print out[0][0][0]
112 | # print out[1][0][0]
113 |
114 |
--------------------------------------------------------------------------------
/models/lookup_model.py:
--------------------------------------------------------------------------------
1 | import sys, math, pdb
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import torch.optim as optim
7 |
8 | '''
9 | Object observations are single-channel images
10 | with positive indices denoting objects.
11 | 0's denote no object.
12 |
13 | '''
14 |
15 | class LookupModel(nn.Module):
16 | def __init__(self, vocab_size, embed_dim):
17 | super(LookupModel, self).__init__()
18 |
19 | self.embed_dim = embed_dim
20 |
21 | # self.reshape = [-1]
22 | # for dim in inp_size:
23 | # self.reshape.append(dim)
24 | # self.reshape.append(self.embed_dim)
25 |
26 | # self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
27 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
28 |
29 | # self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5)
30 | # self.conv2 = nn.Conv2d(3, 6, kernel_size=5)
31 | # self.conv3 = nn.Conv2d(6,12, kernel_size=5)
32 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5)
33 | # self.fc1 = nn.Linear(192, out_dim)
34 | # self.init_weights()
35 |
36 | # self.map_dim = map_dim
37 | # self.locations = self.__init_locations(self.map_dim)
38 |
39 | def init_weights(self):
40 | initrange = 0.1
41 | self.embed.weight.data.uniform_(-initrange, initrange)
42 | # self.decoder.bias.data.fill_(0)
43 | # self.decoder.weight.data.uniform_(-initrange, initrange)
44 |
45 | def forward(self, x):
46 | # print 'LOOKUP: ', x.size(), type(x)
47 | batch_size = x.size(0)
48 |
49 | reshape = []
50 | for dim in x.size(): reshape.append(dim)
51 | reshape.append(self.embed_dim)
52 |
53 | ## reshape to vector
54 | x = x.view(-1)
55 | ## get embeddings
56 | x = self.embed(x)
57 | ## reshape to batch x channels x M x N x embed_dim
58 | x = x.view(*reshape)
59 | ## sum over channels in input
60 | x = x.sum(1, keepdim=True)
61 | # pdb.set_trace()
62 | ## reshape to batch x embed_dim x M x N
63 | ## (treats embedding dims as channels)
64 | x = x.transpose(1,-1)[:,:,:,:,0] #.squeeze() #
65 |
66 | # locations = self.locations.unsqueeze(0).repeat(batch_size,1,1,1)
67 |
68 | # pdb.set_trace()
69 | # x = torch.cat( (x, locations), 1 )
70 | # print self.locations
71 | # print locations
72 | # x = F.relu(self.conv1(x))
73 | # x = F.relu(self.conv2(x))
74 | # x = F.relu(self.conv3(x))
75 | # x = F.relu(self.conv4(x))
76 | # x = x.view(-1, 192)
77 | # x = self.fc1(x)
78 | return x
79 |
80 | if __name__ == '__main__':
81 | vocab_size = 10
82 | emb_dim = 3
83 | map_dim = 10
84 | phi = LookupModel(vocab_size, emb_dim)
85 |
86 | # enc = nn.Embedding(10,emb_dim,padding_idx=0)
87 | inp = torch.LongTensor(5,1,map_dim,map_dim).zero_()
88 | inp[0][0][0][0]=1
89 | # inp[0][1][0][0]=1
90 | inp[1][0][0][2]=1
91 | # print inp
92 | inp = Variable(inp)
93 |
94 | out = phi.forward(inp)
95 | # print out
96 | # out = out.view(-1,2,3,3,emb_dim)
97 | out = out.data
98 | print out.size()
99 |
100 | # print out[0][0][0]
101 | # print out[1][0][0]
102 |
103 |
--------------------------------------------------------------------------------
/models/map_model.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | '''
8 | State observations are two-channel images
9 | with 0: puddle, 1: grass, 2: agent
10 |
11 | '''
12 |
13 | class MapModel(nn.Module):
14 | def __init__(self, vocab_size, embed_dim, out_dim):
15 | super(MapModel, self).__init__()
16 |
17 | self.embed_dim = embed_dim
18 |
19 | self.embed = nn.Embedding(vocab_size, embed_dim)
20 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=3)
21 | self.conv2 = nn.Conv2d(3, 6, kernel_size=3)
22 | self.conv3 = nn.Conv2d(6,12, kernel_size=3)
23 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5)
24 | self.fc1 = nn.Linear(192, out_dim)
25 |
26 | def forward(self, x):
27 | reshape = []
28 | for dim in x.size(): reshape.append(dim)
29 | reshape.append(self.embed_dim)
30 |
31 | ## reshape to vector
32 | x = x.view(-1)
33 | ## get embeddings
34 | x = self.embed(x)
35 | ## reshape to batch x channels x M x N x embed_dim
36 | x = x.view(*reshape)
37 | ## sum over channels in input
38 | x = x.sum(1, keepdim=True)
39 | ## reshape to batch x embed_dim x M x N
40 | ## (treats embedding dims as channels)
41 | x = x.transpose(1,-1).squeeze()
42 | x = F.relu(self.conv1(x))
43 | x = F.relu(self.conv2(x))
44 | x = F.relu(self.conv3(x))
45 |
46 | x = x.view(-1, 192)
47 | x = self.fc1(x)
48 | return x
49 |
50 |
51 | if __name__ == '__main__':
52 | from torch.autograd import Variable
53 | # inp = torch.LongTensor(2,10,10).zero_()
54 | vocab_size = 10
55 | emb_dim = 3
56 | rank = 7
57 | phi = MapModel(vocab_size, emb_dim, rank)
58 |
59 | # enc = nn.Embedding(10,emb_dim,padding_idx=0)
60 | inp = torch.LongTensor(5,2,10,10).zero_()
61 | inp[0][0][0][0]=1
62 | # inp[0][1][0][0]=1
63 | inp[1][0][0][2]=1
64 | print inp.size()
65 | inp = Variable(inp)
66 |
67 | out = phi.forward(inp)
68 | # print out
69 | # out = out.view(-1,2,3,3,emb_dim)
70 | out = out.data
71 | print out.size()
72 |
73 | # print out[0][0][0]
74 | # print out[1][0][0]
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
/models/mlp.py:
--------------------------------------------------------------------------------
1 | ## predicts entire value map
2 | ## rather than a single value
3 |
4 | import torch
5 | import math, torch.nn as nn, pdb
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
9 | class MLP(nn.Module):
10 | def __init__(self, sizes):
11 | super(MLP, self).__init__()
12 |
13 | layers = []
14 |
15 | for ind in range(len(sizes)-1):
16 | layers.append( nn.Linear(sizes[ind], sizes[ind+1]) )
17 | layers.append( nn.ReLU() )
18 | layers.pop(-1)  ## drop the trailing ReLU so the final layer stays linear
19 |
20 | self.layers = nn.ModuleList(layers)
21 |
22 | def forward(self, x):
23 | for lay in self.layers:
24 | x = lay(x)
25 | return x
26 |
27 | if __name__ == '__main__':
28 | batch_size = 32
29 | sizes = [10, 128, 128, 4]
30 | mlp = MLP(sizes)
31 | inp = Variable( torch.Tensor(batch_size, sizes[0]) )
32 | print inp.size()
33 | out = mlp(inp)
34 | print out.size()
35 | loss = out.sum()
36 | loss.backward()
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/models/model_factorizer.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 | from torch.autograd import Variable
7 | from torch.nn.modules.utils import _pair
8 | import pdb, pickle
9 |
10 | class ModelFactorizer(nn.Module):
11 | def __init__(self, state_model, goal_model, state_inp, goal_inp, sparse_value_mat):
12 | super(ModelFactorizer, self).__init__()
13 | self.state_model = state_model
14 | self.goal_model = goal_model
15 | self.state_inp = Variable(state_inp.cuda())
16 | self.goal_inp = [Variable(i).cuda() for i in goal_inp]
17 | self.mat = Variable( sparse_value_mat )
18 | self.mask = (self.mat == 0)
19 | self.M = state_inp.size(0)
20 |
21 | def forward(self, inds):
22 | state_inp = self.state_inp.index_select(0, inds)
23 | state_out = self.state_model.forward(state_inp)
24 | goal_out = self.goal_model.forward(self.goal_inp)
25 |
26 | recon = torch.mm(state_out, goal_out.t())
27 | mask_select = self.mask.index_select(0, inds)
28 | true_select = self.mat.index_select(0, inds)
29 |
30 | # pdb.set_trace()
31 |
32 | diff = torch.pow(recon - true_select, 2)
33 |
34 | mse = diff.sum()
35 |
36 | return mse
37 |
38 | def train(self, lr, iters, batch_size = 256):
39 | optimizer = optim.Adam(self.parameters(), lr=lr)
40 |
41 | t = trange(iters)
42 | for i in t:
43 | optimizer.zero_grad()
44 | inds = torch.floor(torch.rand(batch_size) * self.M).long().cuda()
45 | # bug: floor(rand()) sometimes gives 1
46 | inds[inds >= self.M] = self.M - 1
47 | inds = Variable(inds)
48 |
49 | loss = self.forward(inds)
50 | # print loss.data[0]
51 | t.set_description( str(loss.data[0]) )
52 | loss.backward()
53 | optimizer.step()
54 |
55 | return self.state_model, self.goal_model
56 |
57 |
58 |
59 | if __name__ == '__main__':
60 | from state_model import *
61 | from object_model_10 import *
62 | from text_model import *
63 | from goal_model import *
64 | rank = 7
65 | state_vocab_size = 20
66 | embed_size = 3
67 | state_obs_size = (2,10,10)
68 | goal_obs_size = (1,10,10)
69 | lstm_size = 15
70 | lstm_nlayer = 1
71 | phi = Phi(state_vocab_size, embed_size, state_obs_size, rank).cuda()
72 | text_model = TextModel(state_vocab_size, lstm_size, lstm_size, lstm_nlayer, lstm_size)
73 | object_model = ObjectModel(state_vocab_size, embed_size, goal_obs_size, lstm_size)
74 | psi = Psi(text_model, object_model, lstm_size, lstm_size, rank).cuda()
75 | print phi
76 |
77 | state_obs = Variable((torch.rand(20*100,2,10,10)*10).long().cuda())
78 | pdb.set_trace()
79 |
80 | out = phi.forward(state_obs)
81 | print out.size()
82 | # from torch.nn.modules.utils import _pair
83 | # rank = 4
84 | # # kernel = Variable(avg_conv(rank))
85 | # kernel = Variable( avg_vector(rank) )
86 |
87 | # inp = Variable(torch.randn(1,rank,20,400))
88 | # bias = None
89 | # stride = _pair(1)
90 | # padding = _pair(1)
91 | # conv = F.conv2d(inp,kernel,bias,stride,padding)[:,:,1:-1,:]
92 | # print 'inp:', inp.size()
93 | # print 'kernel:', kernel.size()
94 | # print 'conv:', conv.size()
95 |
96 | # pdb.set_trace()
97 |
98 | # from torch.nn.modules.utils import _pair
99 | # inp = Variable(torch.randn(1,5,8,8))
100 | # kern = Variable(torch.ones(7,5,3,3))
101 | # bias = None
102 | # stride = _pair(1)
103 | # padding = _pair(1)
104 | # # dilation = _pair(1)
105 | # # print dilation
106 | # # groups = 1
107 | # # conv = F.conv2d(inp,kern,bias,stride,padding,dilation,groups)
108 | # conv = F.conv2d(inp,kern,bias,stride,padding)
109 | # print 'inp:', inp
110 | # print 'conv: ', conv
111 | # print 'sum: ', inp[:,:,:3,:3].sum()
112 |
113 | # print a.b
114 | # print 'conv: ', conv
115 | value_mat = pickle.load( open('../pickle/value_mat20.p') )
116 | rank = 10
117 |
118 | print 'value_mat: ', value_mat.shape
119 | dissimilarity_lambda = .1
120 | world_lambda = 0
121 | location_lambda = .001
122 | model = ConstraintFactorizer(value_mat, rank, dissimilarity_lambda, world_lambda, location_lambda).cuda()
123 | lr = 0.001
124 | iters = 500000
125 | U, V, recon = model.train(lr, iters)
126 |
127 | # lr = 0.001
128 | # optimizer = optim.Adam(model.parameters(), lr=lr)
129 |
130 | # loss = model.forward( () )
131 | # print loss
132 | # iters = 50000
133 | # t = trange(iters)
134 | # for i in t:
135 | # optimizer.zero_grad()
136 | # loss = model.forward( () )
137 | # # print loss.data[0]
138 | # t.set_description( str(model.mse) + ' ' + str(model.divergence) )
139 | # loss.backward()
140 | # optimizer.step()
141 |
142 | # U, V = model.embeddings()
143 |
144 | print 'recon'
145 | print recon
146 | print 'true'
147 | print torch.Tensor(value_mat)
148 |
149 | pickle.dump(U, open('../pickle/U_lambda_' + str(dissimilarity_lambda) + '.p', 'w') )
150 | pickle.dump(V, open('../pickle/V_lambda_' + str(dissimilarity_lambda) + '.p', 'w') )
151 | # pdb.set_trace()
152 |
153 |
154 | # inp = torch.LongTensor(5,3).zero_()
155 | # vocab_size = 10
156 | # emb_dim = 3
157 | # rank = 10
158 | # phi = ObjectModel(vocab_size, emb_dim, inp.size(), rank)
159 |
160 | # batch = 2
161 | # hidden = phi.init_hidden(batch)
162 |
163 | # # enc = nn.Embedding(10,emb_dim,padding_idx=0)
164 | # inp = torch.LongTensor(batch,5,3).zero_()
165 | # inp[0][0]=1
166 | # inp[0][1]=1
167 | # inp[1][0]=8
168 | # inp[1][1]=8
169 | # print inp
170 | # # inp[0][0][0][0]=1
171 | # # inp[0][1][0][0]=1
172 | # # inp[1][0][0][2]=1
173 | # # print inp
174 | # inp = Variable(inp)
175 |
176 | # out = phi.forward(inp, hidden)
177 | # # print out
178 | # # out = out.view(-1,2,3,3,emb_dim)
179 | # out = out.data
180 | # print out.size()
181 |
182 | # # print out[0][0][0]
183 | # # print out[1][0][0]
184 |
185 |
--------------------------------------------------------------------------------
/models/multi_global.py:
--------------------------------------------------------------------------------
1 | ## predicts entire value map
2 | ## rather than a single value
3 |
4 | import torch
5 | import math, torch.nn as nn, pdb
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import models, utils
9 |
10 | class MultiGlobal(nn.Module):
11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10):
12 | super(MultiGlobal, self).__init__()
13 |
14 | self.state_model = state_model
15 | self.object_model = object_model
16 | self.heatmap_model = heatmap_model
17 | self.simple_conv = models.SimpleConv(3).cuda()
18 | self.rbf = Variable( utils.meta_rbf(map_dim).cuda() )
19 | self.positions = Variable( self.__init_positions(map_dim).cuda() )
20 |
21 | self.map_dim = map_dim
22 | self.batch_size = args.batch_size
23 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
24 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1)
25 |
26 | '''
27 | global_coeffs are < batch x 3 >
28 | 3: row, col, bias
29 | '''
30 | def _global(self, global_coeffs):
31 | pos_coeffs = global_coeffs[:,:-1]
32 | bias = global_coeffs[:,-1]
33 |
34 | coeffs_batch = pos_coeffs.unsqueeze(-1).unsqueeze(-1).repeat(1,1,self.map_dim,self.map_dim)
35 | bias_batch = bias.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1,1,self.map_dim,self.map_dim)
36 |
37 | ## sum over row, col and add bias
38 | obj_global = (coeffs_batch * self.positions_batch).sum(1, keepdim=True) + bias_batch
39 | return obj_global
40 |
41 |
42 |
43 | def __init_positions(self, map_dim):
44 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim)
45 | col = torch.arange(0,map_dim).repeat(map_dim,1)
46 | positions = torch.stack( (row, col) )
47 | return positions
48 |
49 |
50 | def forward(self, inp):
51 | (state, obj, text) = inp
52 | batch_size = state.size(0)
53 | if batch_size != self.batch_size:
54 | self.batch_size = batch_size
55 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
56 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1)
57 |
58 | ## get state map
59 | state_out = self.state_model(state)
60 | obj_out = self.object_model.forward(obj)
61 |
62 | # print 'state_out: ', state_out.size()
63 | # print 'obj_out: ', obj_out.size()
64 |
65 | ## get object map
66 | heatmap, global_coeffs = self.heatmap_model((obj_out, text))
67 |
68 | ## repeat heatmap for multiplication by rbf batch
69 | heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim)
70 |
71 | ## multiply object map by pre-computed manhattan rbf
72 | ## < batch x size^2 x size x size >
73 | obj_local = heatmap_batch * self.rbf_batch
74 | ## sum contributions from rbf from every source
75 | ## < batch x 1 x size x size >
76 |         obj_local = obj_local.sum(1, keepdim=True)
77 | # print 'obj_out:', obj_out.size()
78 |
79 | obj_global = self._global(global_coeffs)
80 |
81 | # pdb.set_trace()
82 |
83 | # obj_out = obj_local + obj_global
84 |
85 | map_pred = torch.cat( (state_out, obj_local, obj_global), 1 )
86 | # pdb.set_trace()
87 | map_pred = self.simple_conv(map_pred)
88 | # map_pred = state_out + obj_out
89 |
90 | return map_pred
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
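105 | if __name__ == '__main__':
106 |     ## a minimal sketch of the "global" component computed by _global():
107 |     ## assuming coefficients (a, b, c) = (row, col, bias) as documented above,
108 |     ## cell (i, j) of the 10 x 10 map gets the value a*i + b*j + c
109 |     a, b, c = 1.0, -2.0, 0.5
110 |     row = torch.arange(0, 10).unsqueeze(1).repeat(1, 10)
111 |     col = torch.arange(0, 10).repeat(10, 1)
112 |     plane = a * row + b * col + c
113 |     print plane[3][7]   ## 1.0*3 - 2.0*7 + 0.5 = -10.5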
--------------------------------------------------------------------------------
/models/multi_model.py:
--------------------------------------------------------------------------------
1 | ## predicts entire value map
2 | ## rather than a single value
3 |
4 | import torch
5 | import math, torch.nn as nn, pdb
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import models, utils
9 |
10 | class MultiModel(nn.Module):
11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10):
12 | super(MultiModel, self).__init__()
13 |
14 | self.state_model = state_model
15 | self.object_model = object_model
16 | self.heatmap_model = heatmap_model
17 | self.simple_conv = models.SimpleConv(2).cuda()
18 | self.rbf = Variable( utils.meta_rbf(map_dim).cuda() )
19 |
20 | self.map_dim = map_dim
21 | self.batch_size = args.batch_size
22 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
23 |
24 |
25 |
26 | def forward(self, inp):
27 | (state, obj, text) = inp
28 | batch_size = state.size(0)
29 | if batch_size != self.batch_size:
30 | self.batch_size = batch_size
31 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
32 |
33 | ## get state map
34 | state_out = self.state_model(state)
35 | obj_out = self.object_model.forward(obj)
36 |
37 | # print 'state_out: ', state_out.size()
38 | ## get object map
39 | heatmap = self.heatmap_model((obj_out, text))
40 | ## repeat heatmap for multiplication by rbf batch
41 | heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim)
42 | ## multiply object map by pre-computed manhattan rbf
43 | ## < batch x size^2 x size x size >
44 | # pdb.set_trace()
45 | obj_local = heatmap_batch * self.rbf_batch
46 | ## sum contributions from rbf from every source
47 | ## < batch x 1 x size x size >
48 | obj_local = obj_local.sum(1, keepdim=True)
49 | # print 'obj_out:', obj_out.size()
50 | # pdb.set_trace()
51 | map_pred = torch.cat( (state_out, obj_local), 1 )
52 | # pdb.set_trace()
53 | map_pred = self.simple_conv(map_pred)
54 | # map_pred = state_out + obj_out
55 |
56 | return map_pred
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
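71 | if __name__ == '__main__':
72 |     ## a minimal shape-check sketch of the heatmap / RBF composition in forward(),
73 |     ## using random stand-ins for the heatmap and for utils.meta_rbf
74 |     batch, dim = 4, 10
75 |     heatmap = torch.rand(batch, dim**2)             ## one weight per source cell
76 |     rbf = torch.rand(dim**2, dim, dim)              ## one basis map per source cell
77 |     rbf_batch = rbf.repeat(batch, 1, 1, 1)          ## < batch x 100 x 10 x 10 >
78 |     heatmap_batch = heatmap.view(batch, dim**2, 1, 1).repeat(1, 1, dim, dim)
79 |     obj_local = (heatmap_batch * rbf_batch).sum(1, keepdim=True)
80 |     print obj_local.size()                          ## < batch x 1 x 10 x 10 >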
--------------------------------------------------------------------------------
/models/multi_nobases.py:
--------------------------------------------------------------------------------
1 | ## predicts entire value map
2 | ## rather than a single value
3 |
4 | import torch
5 | import math, torch.nn as nn, pdb
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import models, utils
9 |
10 | class MultiNoBases(nn.Module):
11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10):
12 | super(MultiNoBases, self).__init__()
13 |
14 | self.state_model = state_model
15 | self.object_model = object_model
16 | self.heatmap_model = heatmap_model
17 | self.simple_conv = models.SimpleConv(2).cuda()
18 | self.rbf = Variable( utils.meta_rbf(map_dim).cuda() )
19 | self.positions = Variable( self.__init_positions(map_dim).cuda() )
20 |
21 | self.map_dim = map_dim
22 | self.batch_size = args.batch_size
23 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
24 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1)
25 |
26 |
27 | def __init_positions(self, map_dim):
28 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim)
29 | col = torch.arange(0,map_dim).repeat(map_dim,1)
30 | positions = torch.stack( (row, col) )
31 | return positions
32 |
33 |
34 | def forward(self, inp):
35 | (state, obj, text) = inp
36 | batch_size = state.size(0)
37 | if batch_size != self.batch_size:
38 | self.batch_size = batch_size
39 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
40 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1)
41 |
42 | ## get state map
43 | state_out = self.state_model(state)
44 | obj_out = self.object_model.forward(obj)
45 |
46 | # print 'state_out: ', state_out.size()
47 | # print 'obj_out: ', obj_out.size()
48 |
49 | ## get object map
50 | heatmap = self.heatmap_model((obj_out, text))
51 |
52 | ## repeat heatmap for multiplication by rbf batch
53 | # heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim)
54 |
55 | ## multiply object map by pre-computed manhattan rbf
56 | ## < batch x size^2 x size x size >
57 | # obj_local = heatmap_batch * self.rbf_batch
58 | ## sum contributions from rbf from every source
59 | ## < batch x 1 x size x size >
60 | # obj_local = obj_local.sum(1)
61 | # print 'obj_out:', obj_out.size()
62 |
63 | obj_local = heatmap
64 |
65 | # print obj_local.size()
66 |
67 | # obj_global = self._global(global_coeffs)
68 |
69 | # pdb.set_trace()
70 |
71 | # obj_out = obj_local + obj_global
72 |
73 | map_pred = torch.cat( (state_out, obj_local), 1 )
74 | # pdb.set_trace()
75 | map_pred = self.simple_conv(map_pred)
76 | # map_pred = state_out + obj_out
77 |
78 | return map_pred
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
/models/multi_nocnn.py:
--------------------------------------------------------------------------------
1 | ## predicts entire value map
2 | ## rather than a single value
3 |
4 | import torch
5 | import math, torch.nn as nn, pdb
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import models, utils
9 |
10 | class MultiNoCNN(nn.Module):
11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10):
12 | super(MultiNoCNN, self).__init__()
13 |
14 | self.state_model = state_model
15 | self.object_model = object_model
16 | self.heatmap_model = heatmap_model
17 | # self.simple_conv = models.SimpleConv(3).cuda()
18 | self.rbf = Variable( utils.meta_rbf(map_dim).cuda() )
19 | self.positions = Variable( self.__init_positions(map_dim).cuda() )
20 |
21 | self.map_dim = map_dim
22 | self.batch_size = args.batch_size
23 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
24 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1)
25 |
26 | '''
27 | global_coeffs are < batch x 3 >
28 | 3: row, col, bias
29 | '''
30 | def _global(self, global_coeffs):
31 | pos_coeffs = global_coeffs[:,:-1]
32 | bias = global_coeffs[:,-1]
33 |
34 | coeffs_batch = pos_coeffs.unsqueeze(-1).unsqueeze(-1).repeat(1,1,self.map_dim,self.map_dim)
35 | bias_batch = bias.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1,1,self.map_dim,self.map_dim)
36 |
37 | ## sum over row, col and add bias
38 | obj_global = (coeffs_batch * self.positions_batch).sum(1, keepdim=True) + bias_batch
39 | return obj_global
40 |
41 |
42 |
43 | def __init_positions(self, map_dim):
44 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim)
45 | col = torch.arange(0,map_dim).repeat(map_dim,1)
46 | positions = torch.stack( (row, col) )
47 | return positions
48 |
49 |
50 | def forward(self, inp):
51 | (state, obj, text) = inp
52 | batch_size = state.size(0)
53 | if batch_size != self.batch_size:
54 | self.batch_size = batch_size
55 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
56 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1)
57 |
58 | ## get state map
59 | state_out = self.state_model(state)
60 | obj_out = self.object_model.forward(obj)
61 |
62 | # print 'state_out: ', state_out.size()
63 | # print 'obj_out: ', obj_out.size()
64 |
65 | ## get object map
66 | heatmap, global_coeffs = self.heatmap_model((obj_out, text))
67 |
68 | ## repeat heatmap for multiplication by rbf batch
69 | heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim)
70 |
71 | ## multiply object map by pre-computed manhattan rbf
72 | ## < batch x size^2 x size x size >
73 | obj_local = heatmap_batch * self.rbf_batch
74 | ## sum contributions from rbf from every source
75 | ## < batch x 1 x size x size >
76 |         obj_local = obj_local.sum(1, keepdim=True)
77 | # print 'obj_out:', obj_out.size()
78 |
79 | obj_global = self._global(global_coeffs)
80 |
81 | # pdb.set_trace()
82 |
83 | # obj_out = obj_local + obj_global
84 |
85 | # map_pred = torch.cat( (state_out, obj_local, obj_global), 1 )
86 | # pdb.set_trace()
87 | # map_pred = self.simple_conv(map_pred)
88 | # map_pred = state_out + obj_out
89 |
90 | map_pred = state_out + obj_local + obj_global
91 |
92 | return map_pred
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
--------------------------------------------------------------------------------
/models/multi_nonsep.py:
--------------------------------------------------------------------------------
1 | ## predicts entire value map
2 | ## rather than a single value
3 |
4 | import torch
5 | import math, torch.nn as nn, pdb
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import models, utils
9 |
10 | class MultiNonSep(nn.Module):
11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10):
12 | super(MultiNonSep, self).__init__()
13 |
14 | self.state_model = state_model
15 | self.object_model = object_model
16 | self.heatmap_model = heatmap_model
17 | self.simple_conv = models.SimpleConv(2).cuda()
18 | self.rbf = Variable( utils.meta_rbf(map_dim).cuda() )
19 | self.positions = Variable( self.__init_positions(map_dim).cuda() )
20 |
21 | self.map_dim = map_dim
22 | self.batch_size = args.batch_size
23 | # self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
24 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1)
25 |
26 | '''
27 | global_coeffs are < batch x 3 >
28 | 3: row, col, bias
29 | '''
30 | def _global(self, global_coeffs):
31 | pos_coeffs = global_coeffs[:,:-1]
32 | bias = global_coeffs[:,-1]
33 |
34 | coeffs_batch = pos_coeffs.unsqueeze(-1).unsqueeze(-1).repeat(1,1,self.map_dim,self.map_dim)
35 | bias_batch = bias.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1,1,self.map_dim,self.map_dim)
36 |
37 | ## sum over row, col and add bias
38 | obj_global = (coeffs_batch * self.positions_batch).sum(1, keepdim=True) + bias_batch
39 | return obj_global
40 |
41 |
42 |
43 | def __init_positions(self, map_dim):
44 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim)
45 | col = torch.arange(0,map_dim).repeat(map_dim,1)
46 | positions = torch.stack( (row, col) )
47 | return positions
48 |
49 |
50 | def forward(self, inp):
51 | (state, obj, text) = inp
52 | batch_size = state.size(0)
53 | if batch_size != self.batch_size:
54 | self.batch_size = batch_size
55 | # self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
56 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1)
57 |
58 | ## get state map
59 | state_out = self.state_model(state)
60 | obj_out = self.object_model.forward(obj)
61 | embeddings = torch.cat( (state_out, obj_out), 1 )
62 |
63 | # print 'state_out: ', state_out.size()
64 | # print 'obj_out: ', obj_out.size()
65 |
66 | ## get object map
67 | heatmap, global_coeffs = self.heatmap_model((embeddings, text))
68 |
69 | ## repeat heatmap for multiplication by rbf batch
70 | # heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim)
71 |
72 | ## multiply object map by pre-computed manhattan rbf
73 | ## < batch x size^2 x size x size >
74 | # obj_local = heatmap_batch * self.rbf_batch
75 | ## sum contributions from rbf from every source
76 | ## < batch x 1 x size x size >
77 | # obj_local = obj_local.sum(1)
78 | # print 'obj_out:', obj_out.size()
79 |
80 | obj_local = heatmap
81 | obj_global = self._global(global_coeffs)
82 |
83 | # pdb.set_trace()
84 |
85 | # obj_out = obj_local + obj_global
86 |
87 | map_pred = torch.cat( (obj_local, obj_global), 1 )
88 | # pdb.set_trace()
89 | map_pred = self.simple_conv(map_pred)
90 | # map_pred = state_out + obj_out
91 |
92 | return map_pred
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
--------------------------------------------------------------------------------
/models/multi_norbf.py:
--------------------------------------------------------------------------------
1 | ## predicts entire value map
2 | ## rather than a single value
3 |
4 | import torch
5 | import math, torch.nn as nn, pdb
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import models, utils
9 |
10 | class MultiNoRBF(nn.Module):
11 | def __init__(self, state_model, object_model, heatmap_model, args, map_dim = 10):
12 | super(MultiNoRBF, self).__init__()
13 |
14 | self.state_model = state_model
15 | self.object_model = object_model
16 | self.heatmap_model = heatmap_model
17 | self.simple_conv = models.SimpleConv(3).cuda()
18 | self.rbf = Variable( utils.meta_rbf(map_dim).cuda() )
19 | self.positions = Variable( self.__init_positions(map_dim).cuda() )
20 |
21 | self.map_dim = map_dim
22 | self.batch_size = args.batch_size
23 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
24 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1)
25 |
26 | '''
27 | global_coeffs are < batch x 3 >
28 | 3: row, col, bias
29 | '''
30 | def _global(self, global_coeffs):
31 | pos_coeffs = global_coeffs[:,:-1]
32 | bias = global_coeffs[:,-1]
33 |
34 | coeffs_batch = pos_coeffs.unsqueeze(-1).unsqueeze(-1).repeat(1,1,self.map_dim,self.map_dim)
35 | bias_batch = bias.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1,1,self.map_dim,self.map_dim)
36 |
37 | ## sum over row, col and add bias
38 | obj_global = (coeffs_batch * self.positions_batch).sum(1, keepdim=True) + bias_batch
39 | return obj_global
40 |
41 |
42 |
43 | def __init_positions(self, map_dim):
44 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim)
45 | col = torch.arange(0,map_dim).repeat(map_dim,1)
46 | positions = torch.stack( (row, col) )
47 | return positions
48 |
49 |
50 | def forward(self, inp):
51 | (state, obj, text) = inp
52 | batch_size = state.size(0)
53 | if batch_size != self.batch_size:
54 | self.batch_size = batch_size
55 | self.rbf_batch = self.rbf.repeat(self.batch_size,1,1,1)
56 | self.positions_batch = self.positions.repeat(self.batch_size,1,1,1)
57 |
58 | ## get state map
59 | state_out = self.state_model(state)
60 | obj_out = self.object_model.forward(obj)
61 |
62 | # print 'state_out: ', state_out.size()
63 | # print 'obj_out: ', obj_out.size()
64 |
65 | ## get object map
66 | heatmap, global_coeffs = self.heatmap_model((obj_out, text))
67 |
68 | ## repeat heatmap for multiplication by rbf batch
69 | # heatmap_batch = heatmap.view(self.batch_size,self.map_dim**2,1,1).repeat(1,1,self.map_dim,self.map_dim)
70 |
71 | ## multiply object map by pre-computed manhattan rbf
72 | ## < batch x size^2 x size x size >
73 | # obj_local = heatmap_batch * self.rbf_batch
74 | ## sum contributions from rbf from every source
75 | ## < batch x 1 x size x size >
76 | # obj_local = obj_local.sum(1)
77 | # print 'obj_out:', obj_out.size()
78 |
79 | obj_local = heatmap
80 | obj_global = self._global(global_coeffs)
81 |
82 | self.output_local = obj_local
83 | self.output_global = obj_global
84 |
85 | # pdb.set_trace()
86 |
87 | # obj_out = obj_local + obj_global
88 | map_pred = torch.cat( (state_out, obj_local, obj_global), 1 )
89 | # pdb.set_trace()
90 | map_pred = self.simple_conv(map_pred)
91 | # map_pred = state_out + obj_out
92 |
93 | return map_pred
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
--------------------------------------------------------------------------------
/models/object_model.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | '''
8 | Object observations are single-channel images
9 | with positive indices denoting objects.
10 | 0's denote no object.
11 |
12 | '''
13 |
14 | class ObjectModel(nn.Module):
15 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim):
16 | super(ObjectModel, self).__init__()
17 |
18 | self.reshape = [-1]
19 | for dim in inp_size:
20 | self.reshape.append(dim)
21 | self.reshape.append(embed_dim)
22 |
23 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
24 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5)
25 | self.conv2 = nn.Conv2d(3, 6, kernel_size=5)
26 | self.conv3 = nn.Conv2d(6,12, kernel_size=5)
27 | self.conv4 = nn.Conv2d(12,12, kernel_size=5)
28 |         self.fc1 = nn.Linear(192, out_dim)   ## 12 channels * 4 * 4 spatial after the convs on a 20 x 20 input
29 |
30 | def forward(self, x):
31 | x = x.view(-1)
32 | x = self.embed(x)
33 | x = x.view(*self.reshape)
34 | x = x.transpose(1,-1).squeeze()
35 | x = F.relu(self.conv1(x))
36 | x = F.relu(self.conv2(x))
37 | x = F.relu(self.conv3(x))
38 | x = F.relu(self.conv4(x))
39 | x = x.view(-1, 192)
40 | x = self.fc1(x)
41 | return x
42 |
43 | if __name__ == '__main__':
44 |     from torch.autograd import Variable; inp = torch.LongTensor(1,20,20).zero_()
45 | vocab_size = 10
46 | emb_dim = 3
47 | rank = 10
48 | phi = ObjectModel(vocab_size, emb_dim,inp.size(), rank)
49 |
50 | # enc = nn.Embedding(10,emb_dim,padding_idx=0)
51 | inp = torch.LongTensor(5,1,20,20).zero_()
52 | inp[0][0][0][0]=1
53 | # inp[0][1][0][0]=1
54 | inp[1][0][0][2]=1
55 | print inp
56 | inp = Variable(inp.view(-1))
57 |
58 | out = phi.forward(inp)
59 | # print out
60 | # out = out.view(-1,2,3,3,emb_dim)
61 | out = out.data
62 | print out.size()
63 |
64 | # print out[0][0][0]
65 | # print out[1][0][0]
66 |
67 |
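68 |     ## a quick trace (sketch) of the embedding-as-channels reshape used in
69 |     ## forward(): indices -> embeddings -> < batch x emb_dim x 20 x 20 >
70 |     x = phi.embed(inp)                   ## < batch*1*20*20 x emb_dim >
71 |     x = x.view(-1, 1, 20, 20, emb_dim)   ## < batch x 1 x 20 x 20 x emb_dim >
72 |     x = x.transpose(1, -1).squeeze()     ## < batch x emb_dim x 20 x 20 >, ready for conv2d
73 |     print x.size()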
--------------------------------------------------------------------------------
/models/object_model_10.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | '''
8 | Object observations are single-channel images
9 | with positive indices denoting objects.
10 | 0's denote no object.
11 |
12 | '''
13 |
14 | class ObjectModel(nn.Module):
15 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim):
16 | super(ObjectModel, self).__init__()
17 |
18 | self.reshape = [-1]
19 | for dim in inp_size:
20 | self.reshape.append(dim)
21 | self.reshape.append(embed_dim)
22 |
23 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
24 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=3)
25 | self.conv2 = nn.Conv2d(3, 6, kernel_size=3)
26 | self.conv3 = nn.Conv2d(6,12, kernel_size=3)
27 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5)
28 |         self.fc1 = nn.Linear(192, out_dim)   ## 12 channels * 4 * 4 spatial after the convs on a 10 x 10 input
29 |
30 | def forward(self, x):
31 | x = x.view(-1)
32 | x = self.embed(x)
33 | x = x.view(*self.reshape)
34 | x = x.transpose(1,-1).squeeze()
35 | x = F.relu(self.conv1(x))
36 | x = F.relu(self.conv2(x))
37 | x = F.relu(self.conv3(x))
38 | # x = F.relu(self.conv4(x))
39 | x = x.view(-1, 192)
40 | x = self.fc1(x)
41 | return x
42 |
43 | if __name__ == '__main__':
44 | from torch.autograd import Variable
45 | inp = torch.LongTensor(1,10,10).zero_()
46 | vocab_size = 10
47 | emb_dim = 3
48 | rank = 10
49 | phi = ObjectModel(vocab_size, emb_dim,inp.size(), rank)
50 |
51 | # enc = nn.Embedding(10,emb_dim,padding_idx=0)
52 | inp = torch.LongTensor(5,1,20,20).zero_()
53 | inp[0][0][0][0]=1
54 | # inp[0][1][0][0]=1
55 | inp[1][0][0][2]=1
56 | # print inp
57 | inp = Variable(inp.view(-1))
58 |
59 | out = phi.forward(inp)
60 | # print out
61 | # out = out.view(-1,2,3,3,emb_dim)
62 | # out = out.data
63 | print out.size()
64 | loss = out.sum()
65 | loss.backward()
66 |
67 | print phi.parameters
68 | print loss
69 | print loss.grad
70 | # print a
71 |
72 | # print out[0][0][0]
73 | # print out[1][0][0]
74 |
75 |
--------------------------------------------------------------------------------
/models/object_model_fc.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | '''
8 | Object observations are single-channel images
9 | with positive indices denoting objects.
10 | 0's denote no object.
11 |
12 | '''
13 |
14 | class ObjectModel(nn.Module):
15 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim):
16 | super(ObjectModel, self).__init__()
17 |
18 | self.reshape = [-1]
19 | for dim in inp_size:
20 | self.reshape.append(dim)
21 | self.reshape.append(embed_dim)
22 |
23 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
24 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=3)
25 | self.conv2 = nn.Conv2d(3, 6, kernel_size=3)
26 | self.conv3 = nn.Conv2d(6,12, kernel_size=3)
27 | self.conv4 = nn.Conv2d(12,12, kernel_size=5)
28 | self.fc1 = nn.Linear(192, out_dim)
29 |
30 | def forward(self, x):
31 | ### 1D
32 | # print 'x: ', x.size()
33 | x = x.view(-1)
34 | x = self.embed(x)
35 | # print 'embed: ', x.size()
36 | ### -1, C, M, N, embed
37 | x = x.view(*self.reshape)
38 | # print 'view: ', x.size()
39 | ### -1, embed, M, N, C
40 | x = x.transpose(1,-1).squeeze()
41 | x = F.relu(self.conv1(x))
42 | x = F.relu(self.conv2(x))
43 | x = F.relu(self.conv3(x))
44 | # x = F.relu(self.conv4(x))
45 | x = x.view(-1, 192)
46 | x = self.fc1(x)
47 | return x
48 |
49 | if __name__ == '__main__':
50 | from torch.autograd import Variable
51 | inp = torch.LongTensor(1,10,10).zero_()
52 | vocab_size = 10
53 | emb_dim = 3
54 | rank = 10
55 | phi = ObjectModel(vocab_size, emb_dim,inp.size(), rank)
56 |
57 | # enc = nn.Embedding(10,emb_dim,padding_idx=0)
58 | inp = torch.LongTensor(5,1,20,20).zero_()
59 | inp[0][0][0][0]=1
60 | # inp[0][1][0][0]=1
61 | inp[1][0][0][2]=1
62 | # print inp
63 | inp = Variable(inp.view(-1))
64 |
65 | out = phi.forward(inp)
66 | # print out
67 | # out = out.view(-1,2,3,3,emb_dim)
68 | # out = out.data
69 | print out.size()
70 | loss = out.sum()
71 | loss.backward()
72 |
73 | print phi.parameters
74 | print loss
75 | print loss.grad
76 | # print a
77 |
78 | # print out[0][0][0]
79 | # print out[1][0][0]
80 |
81 |
--------------------------------------------------------------------------------
/models/object_model_pos.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 | from torch.autograd import Variable
7 | import pdb
8 |
9 | '''
10 | Object observations are single-channel images
11 | with positive indices denoting objects.
12 | 0's denote no object.
13 |
14 | '''
15 |
16 | class ObjectModel(nn.Module):
17 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim):
18 | super(ObjectModel, self).__init__()
19 | (self.num_objects, self.embed_dim) = inp_size
20 | self.hidden_dim_1 = 10
21 | self.hidden_dim_2 = 10
22 | self.out_dim = out_dim
23 |
24 | self.reshape = [-1]
25 | for dim in inp_size:
26 | self.reshape.append(dim)
27 | self.reshape.append(embed_dim)
28 |
29 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
30 | self.fc1 = nn.Linear(embed_dim + 2, self.hidden_dim_1)
31 | self.fc2 = nn.Linear(self.hidden_dim_1, self.hidden_dim_2)
32 | self.fc3 = nn.Linear(self.hidden_dim_2, self.out_dim)
33 | # self.fc1 = nn.Linear(embed_dim+2, 20)
34 | # self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5)
35 | # self.conv2 = nn.Conv2d(3, 6, kernel_size=5)
36 | # self.conv3 = nn.Conv2d(6,12, kernel_size=5)
37 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5)
38 | # self.fc1 = nn.Linear(192, out_dim)
39 | self.fc_temp = nn.Linear(self.hidden_dim_1, self.hidden_dim_2)
40 |
41 | ## batch x 5 x 3
42 | def forward(self, x):
43 | ## batch x 5 x 1
44 | indices = x[:,:,0]
45 |
46 | ## batch x 5 x embed
47 | embeddings = self.embed(indices)
48 |
49 | ## batch x 5 x pos
50 | positions = x[:,:,1:].float()
51 |
52 | ## join embeddings and positions
53 | ## batch x 5 x (embed + pos)
54 | x = torch.cat( (embeddings, positions), 2 )
55 |
56 | ## reshape
57 | ## (batch * 5) x (embed + pos)
58 | x = x.view(-1, self.embed_dim + 2).float()
59 |
60 | ## fc1
61 | ## (batch * 5) x hidden1
62 | x = F.relu( self.fc1(x) )
63 | x = F.relu( self.fc_temp(x) )
64 |
65 | ## reshape
66 | ## batch x 5 x hidden1
67 | x = x.view(-1, self.num_objects, self.hidden_dim_1)
68 |
69 | ## get rid of middle object dimension
70 | ## batch x hidden1
71 | x = x.max(1)[0].squeeze()
72 |
73 | ## fc2
74 | ## batch x hidden2
75 | x = F.relu( self.fc2(x) )
76 |
77 | ## fc3
78 | ## batch x out
79 | x = self.fc3(x)
80 |
81 | return x
82 |
83 | if __name__ == '__main__':
84 | inp = torch.LongTensor(5,3).zero_()
85 | vocab_size = 10
86 | emb_dim = 3
87 | rank = 10
88 | phi = ObjectModel(vocab_size, emb_dim, inp.size(), rank)
89 |
90 | # enc = nn.Embedding(10,emb_dim,padding_idx=0)
91 | inp = torch.LongTensor(2,5,3).zero_()
92 | inp[0][0]=1
93 | inp[0][1]=1
94 | inp[1][0]=8
95 | inp[1][1]=8
96 | print inp
97 | # inp[0][0][0][0]=1
98 | # inp[0][1][0][0]=1
99 | # inp[1][0][0][2]=1
100 | # print inp
101 | inp = Variable(inp)
102 |
103 | out = phi.forward(inp)
104 | # print out
105 | # out = out.view(-1,2,3,3,emb_dim)
106 | out = out.data
107 | print out.size()
108 |
109 | # print out[0][0][0]
110 | # print out[1][0][0]
111 |
112 |
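113 |     ## a quick check (sketch): the max over the object dimension in forward()
114 |     ## makes the output invariant to the order of the (index, row, col) tuples
115 |     perm = torch.LongTensor([1, 0, 2, 3, 4])
116 |     out_permuted = phi.forward(Variable(inp.data.index_select(1, perm))).data
117 |     print (out - out_permuted).abs().max()          ## 0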
--------------------------------------------------------------------------------
/models/object_model_rnn.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 | from torch.autograd import Variable
7 | import pdb
8 |
9 | '''
10 | Object observations are single-channel images
11 | with positive indices denoting objects.
12 | 0's denote no object.
13 |
14 | '''
15 |
16 | class ObjectModel(nn.Module):
17 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim):
18 | super(ObjectModel, self).__init__()
19 | (self.num_objects, self.embed_dim) = inp_size
20 | self.hidden_dim_1 = 10
21 | self.hidden_dim_2 = 10
22 | self.out_dim = out_dim
23 |
24 | self.reshape = [-1]
25 | for dim in inp_size:
26 | self.reshape.append(dim)
27 | self.reshape.append(embed_dim)
28 |
29 | self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
30 | self.fc1 = nn.Linear(embed_dim + 2, self.hidden_dim_1)
31 | self.fc2 = nn.Linear(self.hidden_dim_1, self.hidden_dim_2)
32 | self.fc3 = nn.Linear(self.hidden_dim_2, self.out_dim)
33 |
34 | self.ninp = self.embed_dim + 2
35 | self.nhid = 10
36 | self.nlayers = 1
37 | self.rnn = nn.LSTM(self.ninp, self.nhid, self.nlayers)
38 | # self.fc1 = nn.Linear(embed_dim+2, 20)
39 | # self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5)
40 | # self.conv2 = nn.Conv2d(3, 6, kernel_size=5)
41 | # self.conv3 = nn.Conv2d(6,12, kernel_size=5)
42 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5)
43 | # self.fc1 = nn.Linear(192, out_dim)
44 | self.fc_temp = nn.Linear(self.hidden_dim_1, self.hidden_dim_2)
45 |
46 | ## batch x 5 x 3
47 | def forward(self, x, hidden):
48 | ## batch x 5 x 1
49 | indices = x[:,:,0]
50 |
51 | ## batch x 5 x embed
52 | embeddings = self.embed(indices)
53 |
54 | ## batch x 5 x pos
55 | positions = x[:,:,1:].float()
56 |
57 | ## join embeddings and positions
58 | ## batch x 5 x (embed + pos)
59 | x = torch.cat( (embeddings, positions), 2 )
60 | print 'cat: ', x.size()
61 |
62 |         x, hidden = self.rnn(x.transpose(0, 1), hidden)   ## LSTM input is < seq x batch x feature >; use the objects as the sequence
63 |         print 'x: ', x.size()
64 |
65 | # ## reshape
66 | # ## (batch * 5) x (embed + pos)
67 | # x = x.view(-1, self.embed_dim + 2).float()
68 |
69 | # ## fc1
70 | # ## (batch * 5) x hidden1
71 | # x = F.relu( self.fc1(x) )
72 | # x = F.relu( self.fc_temp(x) )
73 |
74 | # ## reshape
75 | # ## batch x 5 x hidden1
76 | # x = x.view(-1, self.num_objects, self.hidden_dim_1)
77 |
78 | # ## get rid of middle object dimension
79 | # ## batch x hidden1
80 | # x = x.max(1)[0].squeeze()
81 |
82 | # ## fc2
83 | # ## batch x hidden2
84 | # x = F.relu( self.fc2(x) )
85 |
86 | # ## fc3
87 | # ## batch x out
88 | # x = self.fc3(x)
89 |
90 | return x
91 |
92 | def init_hidden(self, bsz):
93 | weight = next(self.parameters()).data
94 | return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
95 | Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
96 |
97 | if __name__ == '__main__':
98 | inp = torch.LongTensor(5,3).zero_()
99 | vocab_size = 10
100 | emb_dim = 3
101 | rank = 10
102 | phi = ObjectModel(vocab_size, emb_dim, inp.size(), rank)
103 |
104 | batch = 2
105 | hidden = phi.init_hidden(batch)
106 |
107 | # enc = nn.Embedding(10,emb_dim,padding_idx=0)
108 | inp = torch.LongTensor(batch,5,3).zero_()
109 | inp[0][0]=1
110 | inp[0][1]=1
111 | inp[1][0]=8
112 | inp[1][1]=8
113 | print inp
114 | # inp[0][0][0][0]=1
115 | # inp[0][1][0][0]=1
116 | # inp[1][0][0][2]=1
117 | # print inp
118 | inp = Variable(inp)
119 |
120 | out = phi.forward(inp, hidden)
121 | # print out
122 | # out = out.view(-1,2,3,3,emb_dim)
123 | out = out.data
124 | print out.size()
125 |
126 | # print out[0][0][0]
127 | # print out[1][0][0]
128 |
129 |
--------------------------------------------------------------------------------
/models/simple_conv.py:
--------------------------------------------------------------------------------
1 | import sys, math, pdb
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | '''
8 | State observations are two-channel images
9 | with 0: puddle, 1: grass, 2: agent
10 |
11 | '''
12 |
13 | class SimpleConv(nn.Module):
14 | def __init__(self, in_channels):
15 | super(SimpleConv, self).__init__()
16 |
17 | self.in_channels = in_channels
18 |
19 | # self.embed = nn.Embedding(vocab_size, in_channels)
20 | self.conv1 = nn.Conv2d(in_channels, 3, kernel_size=3, padding=1)
21 | self.conv2 = nn.Conv2d(3, 6, kernel_size=3, padding=1)
22 | self.conv3 = nn.Conv2d(6, 12, kernel_size=3, padding=1)
23 | self.conv4 = nn.Conv2d(12,18, kernel_size=3, padding=1)
24 | self.conv5 = nn.Conv2d(18,24, kernel_size=3, padding=1)
25 | self.conv6 = nn.Conv2d(24,18, kernel_size=3, padding=1)
26 | self.conv7 = nn.Conv2d(18,12, kernel_size=3, padding=1)
27 | self.conv8 = nn.Conv2d(12, 6, kernel_size=3, padding=1)
28 | self.conv9 = nn.Conv2d(6, 3, kernel_size=3, padding=1)
29 | self.conv10 = nn.Conv2d(3, 1, kernel_size=3, padding=1)
30 | # # self.conv4 = nn.Conv2d(12,12, kernel_size=5)
31 | # self.fc1 = nn.Linear(192, out_dim)
32 |
33 | def forward(self, x):
34 | # reshape = []
35 | # for dim in x.size(): reshape.append(dim)
36 | # reshape.append(self.in_channels)
37 |
38 | # ## reshape to vector
39 | # x = x.view(-1)
40 | # ## get embeddings
41 | # x = self.embed(x)
42 | # ## reshape to batch x channels x M x N x embed_dim
43 | # x = x.view(*reshape)
44 | # ## sum over channels in input
45 | # x = x.sum(1)
46 | # # pdb.set_trace()
47 | # ## reshape to batch x embed_dim x M x N
48 | # ## (treats embedding dims as channels)
49 | # x = x.transpose(1,-1)[:,:,:,:,0] #.squeeze() #
50 | # print 'SIZE:', x.size()
51 | # pdb.set_trace()
52 | x = F.relu(self.conv1(x))
53 | x = F.relu(self.conv2(x))
54 | x = F.relu(self.conv3(x))
55 | # x = F.relu(self.conv4(x))
56 | # x = F.relu(self.conv5(x))
57 | # x = F.relu(self.conv6(x))
58 | # x = F.relu(self.conv7(x))
59 | x = F.relu(self.conv8(x))
60 | x = F.relu(self.conv9(x))
61 | x = self.conv10(x)
62 |
63 | # x = x.view(-1, 192)
64 | # x = self.fc1(x)
65 | return x
66 |
67 |
68 | if __name__ == '__main__':
69 | from torch.autograd import Variable
70 | # inp = torch.LongTensor(2,10,10).zero_()
71 |     # vocab_size = 10
72 |     # emb_dim = 3
73 |     # rank = 7
74 |     phi = SimpleConv(2)   ## two input channels, matching inp below
75 |
76 | # enc = nn.Embedding(10,emb_dim,padding_idx=0)
77 | inp = torch.LongTensor(5,2,10,10).zero_()
78 | inp[0][0][0][0]=1
79 | # inp[0][1][0][0]=1
80 | inp[1][0][0][2]=1
81 | print inp.size()
82 |     inp = Variable(inp.float())   ## conv layers expect float input
83 |
84 | out = phi.forward(inp)
85 | # print out
86 | # out = out.view(-1,2,3,3,emb_dim)
87 | out = out.data
88 | print out.size()
89 |
90 | # print out[0][0][0]
91 | # print out[1][0][0]
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
/models/state_model.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | '''
8 | State observations are two-channel images
9 | with 0: puddle, 1: grass, 2: agent
10 |
11 | '''
12 |
13 | class Phi(nn.Module):
14 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim):
15 | super(Phi, self).__init__()
16 |
17 | self.reshape = [-1]
18 | for dim in inp_size:
19 | self.reshape.append(dim)
20 | self.reshape.append(embed_dim)
21 |
22 | self.embed = nn.Embedding(vocab_size, embed_dim)
23 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=5)
24 | self.conv2 = nn.Conv2d(3, 6, kernel_size=5)
25 | self.conv3 = nn.Conv2d(6,12, kernel_size=5)
26 | self.conv4 = nn.Conv2d(12,12, kernel_size=5)
27 | self.fc1 = nn.Linear(192, out_dim)
28 |
29 | def forward(self, x):
30 | x = x.view(-1)
31 | x = self.embed(x)
32 | x = x.view(*self.reshape)
33 | x = x.sum(1, keepdim=True)
34 | x = x.transpose(1,-1).squeeze()
35 | x = F.relu(self.conv1(x))
36 | x = F.relu(self.conv2(x))
37 | x = F.relu(self.conv3(x))
38 | x = F.relu(self.conv4(x))
39 | x = x.view(-1, 192)
40 | x = self.fc1(x)
41 | return x
42 |
43 |
44 | if __name__ == '__main__':
45 |     from torch.autograd import Variable; inp = torch.LongTensor(2,20,20).zero_()
46 | vocab_size = 10
47 | emb_dim = 3
48 | rank = 7
49 | phi = Phi(vocab_size, emb_dim,inp.size(), rank)
50 |
51 | # enc = nn.Embedding(10,emb_dim,padding_idx=0)
52 | inp = torch.LongTensor(5,2,20,20).zero_()
53 | inp[0][0][0][0]=1
54 | inp[0][1][0][0]=1
55 | inp[1][0][0][2]=1
56 | print inp
57 | inp = Variable(inp.view(-1))
58 |
59 | out = phi.forward(inp)
60 | # print out
61 | # out = out.view(-1,2,3,3,emb_dim)
62 | out = out.data
63 | print out.size()
64 |
65 | # print out[0][0][0]
66 | # print out[1][0][0]
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
--------------------------------------------------------------------------------
/models/state_model_10.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | '''
8 | State observations are two-channel images
9 | with 0: puddle, 1: grass, 2: agent
10 |
11 | '''
12 |
13 | class Phi(nn.Module):
14 | def __init__(self, vocab_size, embed_dim, inp_size, out_dim):
15 | super(Phi, self).__init__()
16 |
17 | self.reshape = [-1]
18 | for dim in inp_size:
19 | self.reshape.append(dim)
20 | self.reshape.append(embed_dim)
21 |
22 | self.embed = nn.Embedding(vocab_size, embed_dim)
23 | self.conv1 = nn.Conv2d(embed_dim, 3, kernel_size=3)
24 | self.conv2 = nn.Conv2d(3, 6, kernel_size=3)
25 | self.conv3 = nn.Conv2d(6,12, kernel_size=3)
26 | # self.conv4 = nn.Conv2d(12,12, kernel_size=5)
27 | self.fc1 = nn.Linear(192, out_dim)
28 |
29 | def forward(self, x):
30 | x = x.view(-1)
31 | x = self.embed(x)
32 | x = x.view(*self.reshape)
33 | x = x.sum(1, keepdim=True)
34 | x = x.transpose(1,-1).squeeze()
35 | x = F.relu(self.conv1(x))
36 | x = F.relu(self.conv2(x))
37 | x = F.relu(self.conv3(x))
38 | # x = F.relu(self.conv4(x))
39 | x = x.view(-1, 192)
40 | x = self.fc1(x)
41 | return x
42 |
43 |
44 | if __name__ == '__main__':
45 | from torch.autograd import Variable
46 | inp = torch.LongTensor(2,10,10).zero_()
47 | vocab_size = 10
48 | emb_dim = 3
49 | rank = 7
50 | phi = Phi(vocab_size, emb_dim,inp.size(), rank)
51 |
52 | # enc = nn.Embedding(10,emb_dim,padding_idx=0)
53 | inp = torch.LongTensor(5,2,20,20).zero_()
54 | inp[0][0][0][0]=1
55 | inp[0][1][0][0]=1
56 | inp[1][0][0][2]=1
57 | print inp
58 | inp = Variable(inp.view(-1))
59 |
60 | out = phi.forward(inp)
61 | # print out
62 | # out = out.view(-1,2,3,3,emb_dim)
63 | out = out.data
64 | print out.size()
65 |
66 | # print out[0][0][0]
67 | # print out[1][0][0]
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
/models/text_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 |
6 | '''
7 | Text inputs are seq x batch
8 | '''
9 | class TextModel(nn.Module):
10 |
11 | def __init__(self, vocab_size, ninp, nhid, nlayers, out_dim):
12 | super(TextModel, self).__init__()
13 |
14 | self.rnn_type = 'LSTM'
15 | self.nhid = nhid
16 | self.nlayers = nlayers
17 | self.out_dim = out_dim
18 |
19 | self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=0)
20 | self.rnn = nn.LSTM(ninp, nhid, nlayers)
21 | self.decoder = nn.Linear(nhid, out_dim)
22 | self.init_weights()
23 |
24 | def init_weights(self):
25 | initrange = 0.1
26 | # self.encoder.weight.data.uniform_(-initrange, initrange)
27 | self.decoder.bias.data.fill_(0)
28 | self.decoder.weight.data.uniform_(-initrange, initrange)
29 |
30 | def forward(self, inp, hidden):
31 | emb = self.encoder(inp)
32 | # print 'emb: ', emb.size()
33 | output, hidden = self.rnn(emb, hidden)
34 | # print 'TEXT: ', output.size()
35 | final_output = output[-1,:,:]
36 | decoded = self.decoder(final_output)
37 | return decoded
38 |
39 | def init_hidden(self, bsz):
40 | weight = next(self.parameters()).data
41 | return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
42 | Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
43 |
44 |
45 | if __name__ == '__main__':
46 | batch = 12
47 | seq = 4
48 | vocab = 20
49 | ninp = 5
50 | nhid = 3
51 | nlayers = 2
52 | out_dim = 7
53 | rnn = TextModel(vocab, ninp, nhid, nlayers, out_dim)
54 |
55 | hidden = rnn.init_hidden(batch)
56 | ## inds x batch
57 | ##
58 | inp = torch.floor(torch.rand(seq,batch)*vocab).long()
59 | inp[0,:5] = 0
60 | inp[1,:5] = 0
61 | # inp[2,:] = 0
62 | # inp[3,:] = 0
63 | inp = Variable(inp)
64 | # inp = Variable(torch.LongTensor((1,2,3)))
65 | print 'INPUT: ', inp
66 | print inp.size()
67 | out = rnn.forward( inp,hidden )
68 |
69 | print 'OUT: ', out
70 | print out.size()
71 | # print 'HID: ', hid
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/models/uvfa_pos.py:
--------------------------------------------------------------------------------
1 | ## predicts entire value map
2 | ## rather than a single value
3 |
4 | import torch
5 | import math, torch.nn as nn, pdb
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import models, utils
9 |
10 | class UVFA_pos(nn.Module):
11 | def __init__(self, state_vocab, object_vocab, args, map_dim = 10, batch_size = 32):
12 | super(UVFA_pos, self).__init__()
13 |
14 | self.state_vocab = state_vocab
15 | self.object_vocab = object_vocab
16 | self.total_vocab = state_vocab + object_vocab
17 | self.pos_size = 2
18 |
19 | self.rank = args.rank
20 | self.map_dim = map_dim
21 | self.batch_size = batch_size
22 | self.positions = self.__agent_pos()
23 |
24 | ## add one for agent position
25 | self.input_dim = (self.total_vocab + 1) * (map_dim**2)
26 | self.world_layers = [self.input_dim, 128, 128, args.rank]
27 | self.world_mlp = models.MLP(self.world_layers)
28 |
29 | # self.object_dim = self.object_vocab * (map_dim**2)
30 | self.pos_layers = [self.pos_size, 128, 128, args.rank]
31 | self.pos_mlp = models.MLP(self.pos_layers)
32 |
33 |     '''
34 |     returns a one-hot encoding of the agent position for every cell,
35 |     tiled over the batch:
36 |     < batch_size * map_dim**2, map_dim**2, 1 >
37 |     '''
38 | def __agent_pos(self):
39 | size = self.map_dim**2
40 |         positions = torch.zeros(self.batch_size*size, size, 1)
41 | # print positions.size()
42 | for ind in range(size):
43 | # print ind, ind*self.batch_size, (ind+1)*self.batch_size, ind, positions.size()
44 | # positions[ind*self.batch_size:(ind+1)*self.batch_size, ind] = 1
45 | positions[ind:self.batch_size*size:size, ind] = 1
46 | # pdb.set_trace()
47 | return Variable( positions.cuda() )
48 |
49 | def __repeat_position(self, x):
50 |         if x.dim() == 2:
51 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1)
52 | else:
53 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1,1)
54 |
55 | '''
56 | < batch_size x N >
57 | < batch_size*100 x N >
58 | '''
59 | def __construct_inp(self, world, pos):
60 | world = self.__repeat_position(world)
61 | world = world.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.total_vocab)
62 | ## add agent position
63 | world = torch.cat( (world, self.positions), -1)
64 | ## reshape to (batched) vector for input to MLPs
65 | world = world.view(-1, self.input_dim)
66 |
67 | # obj = self.__repeat_position(obj)
68 | # obj = obj.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.object_vocab)
69 | # obj = obj.view(-1, self.object_dim)
70 |
71 | pos = self.__repeat_position(pos)
72 | # pos = pos.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.pos_size)
73 | pos = pos.view(-1, self.pos_size)
74 |
75 | return world, pos
76 |
77 |
78 | def forward(self, inp):
79 | (state, obj, pos) = inp
80 | batch_size = state.size(0)
81 | # text = text.transpose(0,1)
82 | # hidden = self.lstm.init_hidden(batch_size * self.map_dim**2)
83 |
84 | if batch_size != self.batch_size:
85 | self.batch_size = batch_size
86 | self.positions = self.__agent_pos()
87 |
88 | ## reshape to (batched) vectors
89 | ## can't scatter Variables
90 | state = state.data.view(-1, self.map_dim**2, 1)
91 | obj = obj.data.view(-1, self.map_dim**2, 1)
92 |
93 | ## make state / object indices into one-hot vectors
94 | state_binary = torch.zeros(batch_size, self.map_dim**2, self.total_vocab).cuda()
95 | object_binary = torch.zeros(batch_size, self.map_dim**2, self.total_vocab).cuda()
96 | state_binary.scatter_(2, state, 1)
97 | object_binary.scatter_(2, obj+self.state_vocab, 1)
98 |
99 | ## < batch x 100 x total_vocab >
100 | ## state_binary will only have non-zero components in the first state_vocab components
101 | ## object_binary will only have non-zero components in state_vocab:total_vocab components
102 | input_binary = state_binary + object_binary
103 | # pdb.set_trace()
104 |
105 | input_binary = Variable( input_binary )
106 | # object_binary = Variable( object_binary )
107 | # print input_binary.size(), pos.size()
108 | # pdb.set_trace()
109 | input_binary, pos = self.__construct_inp(input_binary, pos)
110 |
111 | # print state_binary.size(), object_binary.size(), text.size()
112 | # object_binary = self.__repeat_position(object_binary)
113 | # object_binary = object_binary.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.object_vocab)
114 |
115 | ## add in agent position
116 | ## < batch x 100 x 2 >
117 | ## < batch x 100 x 100 x 2 >
118 | # state_binary = self.__repeat_position(state_binary)
119 | ## < batch*100 x 100 x 2 >
120 | # state_binary = state_binary.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.state_vocab)
121 |
122 | ## add agent position
123 | # state_binary = torch.cat( (state_binary, self.positions), -1)
124 |
125 | # pdb.set_trace()
126 | # print state_binary.size(), object_binary.size()
127 |
128 | ## reshape to (batched) vectors for input to MLPs
129 | ## turn back into Variables for backprop
130 | # state_binary = state_binary.view(-1, self.state_dim)
131 | # object_binary = object_binary.view(-1, self.object_dim)
132 | # print input_binary.size()
133 | # pdb.set_trace()
134 | ## < batch*100 x rank >
135 | world_out = self.world_mlp(input_binary)
136 | pos_out = self.pos_mlp(pos)
137 |
138 | # lstm_out = self.lstm.forward(text, hidden)
139 |
140 |
141 | # print lstm_out.size()
142 |
143 | # print world_out.size(), pos_out.size()
144 |
145 | values = world_out * pos_out
146 | map_pred = values.sum(1, keepdim=True).view(self.batch_size, self.map_dim, self.map_dim)
147 |
148 |
149 | return map_pred
150 |
151 |
--------------------------------------------------------------------------------
/models/uvfa_text.py:
--------------------------------------------------------------------------------
1 | ## predicts entire value map
2 | ## rather than a single value
3 |
4 | import torch
5 | import math, torch.nn as nn, pdb
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import models, utils
9 |
10 | class UVFA_text(nn.Module):
11 | def __init__(self, lstm, state_vocab, object_vocab, args, map_dim = 10, batch_size = 32):
12 | super(UVFA_text, self).__init__()
13 |
14 | self.state_vocab = state_vocab
15 | self.object_vocab = object_vocab
16 | self.lstm = lstm
17 | self.rank = args.rank
18 | self.map_dim = map_dim
19 | self.batch_size = batch_size
20 | self.positions = self.__agent_pos()
21 |
22 | ## add one for agent position
23 | self.state_dim = (self.state_vocab+1) * (map_dim**2)
24 | # self.state_dim = self.state_vocab * map_dim**2
25 | self.state_layers = [self.state_dim, 128, 128, args.rank]
26 | self.state_mlp = models.MLP(self.state_layers)
27 |
28 | self.object_dim = self.object_vocab * (map_dim**2)
29 | self.object_layers = [self.object_dim, 128, 128, args.rank]
30 | self.object_mlp = models.MLP(self.object_layers)
31 |
32 |     '''
33 |     returns a one-hot encoding of the agent position for every cell,
34 |     tiled over the batch:
35 |     < batch_size * map_dim**2, map_dim**2, 1 >
36 |     '''
37 | def __agent_pos(self):
38 | size = self.map_dim**2
39 |         positions = torch.zeros(self.batch_size*size, size, 1)
40 | # print positions.size()
41 | for ind in range(size):
42 | # print ind, ind*self.batch_size, (ind+1)*self.batch_size, ind, positions.size()
43 | # positions[ind*self.batch_size:(ind+1)*self.batch_size, ind] = 1
44 | positions[ind:self.batch_size*size:size, ind] = 1
45 | # pdb.set_trace()
46 | return Variable( positions.cuda() )
47 |
48 | def __repeat_position(self, x):
49 |         if x.dim() == 2:
50 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1)
51 | else:
52 | return x.unsqueeze(1).repeat(1,self.map_dim**2,1,1)
53 |
54 | '''
55 | < batch_size x N >
56 | < batch_size*100 x N >
57 | '''
58 | def __construct_inp(self, state, obj, text):
59 | state = self.__repeat_position(state)
60 | state = state.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.state_vocab)
61 | ## add agent position
62 | state = torch.cat( (state, self.positions), -1)
63 | ## reshape to (batched) vector for input to MLPs
64 | state = state.view(-1, self.state_dim)
65 |
66 | obj = self.__repeat_position(obj)
67 | obj = obj.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.object_vocab)
68 | obj = obj.view(-1, self.object_dim)
69 |
70 | instr_length = text.size(1)
71 | ## < batch x length >
72 | ## < batch x 100 x length >
73 | text = self.__repeat_position(text)
74 | ## < batch*100 x length >
75 | text = text.view(self.batch_size*self.map_dim**2,instr_length)
76 | ## < length x batch*100 >
77 | text = text.transpose(0,1)
78 | ## < batch*100 x rank >
79 |
80 | return state, obj, text
81 |
82 |
83 | def forward(self, inp):
84 | (state, obj, text) = inp
85 | batch_size = state.size(0)
86 | # text = text.transpose(0,1)
87 | hidden = self.lstm.init_hidden(batch_size * self.map_dim**2)
88 |
89 | if batch_size != self.batch_size:
90 | self.batch_size = batch_size
91 | self.positions = self.__agent_pos()
92 |
93 | ## reshape to (batched) vectors
94 | ## can't scatter Variables
95 | state = state.data.view(-1, self.map_dim**2, 1)
96 | obj = obj.data.view(-1, self.map_dim**2, 1)
97 |
98 | ## make state / object indices into one-hot vectors
99 | state_binary = torch.zeros(batch_size, self.map_dim**2, self.state_vocab).cuda()
100 | object_binary = torch.zeros(batch_size, self.map_dim**2, self.object_vocab).cuda()
101 | state_binary.scatter_(2, state, 1)
102 | object_binary.scatter_(2, obj, 1)
103 |
104 | state_binary = Variable( state_binary )
105 | object_binary = Variable( object_binary )
106 |
107 | state_binary, object_binary, text = self.__construct_inp(state_binary, object_binary, text)
108 |
109 | # print state_binary.size(), object_binary.size(), text.size()
110 | # object_binary = self.__repeat_position(object_binary)
111 | # object_binary = object_binary.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.object_vocab)
112 |
113 | ## add in agent position
114 | ## < batch x 100 x 2 >
115 | ## < batch x 100 x 100 x 2 >
116 | # state_binary = self.__repeat_position(state_binary)
117 | ## < batch*100 x 100 x 2 >
118 | # state_binary = state_binary.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.state_vocab)
119 |
120 | ## add agent position
121 | # state_binary = torch.cat( (state_binary, self.positions), -1)
122 |
123 | # pdb.set_trace()
124 | # print state_binary.size(), object_binary.size()
125 |
126 | ## reshape to (batched) vectors for input to MLPs
127 | ## turn back into Variables for backprop
128 | # state_binary = state_binary.view(-1, self.state_dim)
129 | # object_binary = object_binary.view(-1, self.object_dim)
130 |
131 | ## < batch*100 x rank >
132 | state_out = self.state_mlp(state_binary)
133 | object_out = self.object_mlp(object_binary)
134 |
135 | lstm_out = self.lstm.forward(text, hidden)
136 |
137 |
138 | # print lstm_out.size()
139 |
140 | values = state_out * object_out * lstm_out
141 | map_pred = values.sum(1, keepdim=True).view(self.batch_size, self.map_dim, self.map_dim)
142 |
143 |
144 | return map_pred
145 |
146 | # def forward(self, inp):
147 | # (state, obj, text) = inp
148 | # batch_size = state.size(0)
149 | # text = text.transpose(0,1)
150 | # hidden = self.lstm.init_hidden(batch_size)
151 |
152 | # if batch_size != self.batch_size:
153 | # self.batch_size = batch_size
154 | # # self.positions = self.__agent_pos()
155 |
156 | # ## reshape to (batched) vectors
157 | # ## can't scatter Variables
158 | # state = state.data.view(-1, self.map_dim**2, 1)
159 | # obj = obj.data.view(-1, self.map_dim**2, 1)
160 |
161 | # ## make state / object indices into one-hot vectors
162 | # state_binary = torch.zeros(batch_size, self.map_dim**2, self.state_vocab).cuda()
163 | # object_binary = torch.zeros(batch_size, self.map_dim**2, self.object_vocab).cuda()
164 | # state_binary.scatter_(2, state, 1)
165 | # object_binary.scatter_(2, obj, 1)
166 |
167 | # state_binary = Variable( state_binary )
168 | # object_binary = Variable( object_binary )
169 |
170 | # ## add in agent position
171 | # ## < batch x 100 x 2 >
172 | # ## < batch x 100 x 100 x 2 >
173 | # # state_binary = state_binary.unsqueeze(1).repeat(1,self.map_dim**2,1,1)
174 | # # state_binary = self.__repeat_position(state_binary)
175 | # ## < batch*100 x 100 x 2 >
176 | # # state_binary = state_binary.view(self.batch_size*self.map_dim**2,self.map_dim**2,self.state_vocab)
177 |
178 | # ## add agent position
179 | # # state_binary = torch.cat( (state_binary, self.positions), -1)
180 |
181 | # # pdb.set_trace()
182 | # # print state_binary.size(), object_binary.size()
183 |
184 | # ## reshape to (batched) vectors for input to MLPs
185 | # ## turn back into Variables for backprop
186 | # state_binary = state_binary.view(-1, self.state_dim)
187 | # object_binary = object_binary.view(-1, self.object_dim)
188 | # print 'state: ', state_binary.size(), object_binary.size()
189 | # state_out = self.state_mlp(state_binary)
190 | # object_out = self.object_mlp(object_binary)
191 |
192 | # # print state_out.size(), object_out.size()
193 |
194 | # lstm_out = self.lstm.forward(text, hidden)
195 |
196 | # # object_out = self.__repeat_position(object_out).view(-1, self.rank)
197 | # # lstm_out = self.__repeat_position(lstm_out).view(-1, self.rank)
198 |
199 | # # print state_out.size(), object_out.size(), lstm_out.size()
200 |
201 | # values = state_out * object_out * lstm_out
202 | # # map_pred = values.sum(1).view(self.batch_size, self.map_dim**2)
203 | # values = values.sum(1)
204 | # # print values.size()
205 | # map_pred = values.unsqueeze(-1).repeat(1,self.map_dim,self.map_dim)
206 | # print values.size(), map_pred.size()
207 |
208 | # return map_pred
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
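224 | if __name__ == '__main__':
225 |     ## a minimal sketch of the layout built by __agent_pos, on CPU and with a
226 |     ## toy map_dim = 2 / batch_size = 3: row r holds a one-hot encoding of
227 |     ## agent cell r % 4 and is paired with batch element r // 4 of the
228 |     ## repeated world encoding
229 |     batch, size = 3, 4
230 |     positions = torch.zeros(batch * size, size, 1)
231 |     for ind in range(size):
232 |         positions[ind:batch * size:size, ind] = 1
233 |     print positions[5].squeeze()   ## cell 1 of batch element 1: [0, 1, 0, 0]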
--------------------------------------------------------------------------------
/models/values_factorized.py:
--------------------------------------------------------------------------------
1 | import torch, torch.nn as nn
2 |
3 | class ValuesFactorized(nn.Module):
4 |
5 | def __init__(self, state_model, goal_model):
6 | super(ValuesFactorized, self).__init__()
7 |
8 | self.state_model = state_model
9 | self.goal_model = goal_model
10 |
11 | # def forward_large(self, inputs, batch_size = 32):
12 |
13 |
14 | def forward(self, inp):
15 | state, objects, instructions = inp
16 | # print 'in Values: ', state.size(), objects.size(), instructions.size()
17 | state_embedding = self.state_model.forward(state)
18 | goal_embedding = self.goal_model.forward( (objects, instructions) )
19 |
20 | values = state_embedding * goal_embedding
21 | values = values.sum(1, keepdim=True)
22 | # ## num_states x rank
23 | # num_states = state_embedding.size(0)
24 | # ## num_goals x rank
25 | # num_goals = goal_embedding.size(0)
26 |
27 | # ## num_states x num_goals x rank
28 | # state_rep = state_embedding.unsqueeze(1).repeat(1,num_goals,1)
29 | # goal_rep = goal_embedding.repeat(num_states,1,1)
30 |
31 | # values = state_rep * goal_rep
32 | # values = values.sum(2).squeeze()
33 | # values = values.transpose(0,1)
34 |
35 | return values
36 |
37 |
38 |
39 |
40 |
41 |
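42 | if __name__ == '__main__':
43 |     ## a minimal sketch of the factorization: state and goal embeddings are
44 |     ## combined by an elementwise product and summed over the rank dimension,
45 |     ## here with stand-in linear modules in place of the real state / goal models
46 |     from torch.autograd import Variable
47 |
48 |     class StubState(nn.Module):
49 |         def __init__(self, rank):
50 |             super(StubState, self).__init__()
51 |             self.fc = nn.Linear(4, rank)
52 |         def forward(self, state):
53 |             return self.fc(state)
54 |
55 |     class StubGoal(nn.Module):
56 |         def __init__(self, rank):
57 |             super(StubGoal, self).__init__()
58 |             self.fc = nn.Linear(6, rank)
59 |         def forward(self, inp):
60 |             objects, instructions = inp
61 |             return self.fc(torch.cat((objects, instructions), 1))
62 |
63 |     rank = 3
64 |     model = ValuesFactorized(StubState(rank), StubGoal(rank))
65 |     state = Variable(torch.randn(5, 4))
66 |     objects = Variable(torch.randn(5, 2))
67 |     instructions = Variable(torch.randn(5, 4))
68 |     values = model((state, objects, instructions))
69 |     print values.size()   ## < 5 x 1 >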
--------------------------------------------------------------------------------
/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | from training import *
2 | from evaluation import *
3 | from agent import *
4 | from score_iteration import *
5 |
--------------------------------------------------------------------------------
/pipeline/evaluation.py:
--------------------------------------------------------------------------------
1 | import os, math, pickle, torch, numpy as np, pdb, pipeline
2 | from torch.autograd import Variable
3 | import matplotlib; matplotlib.use('Agg')
4 | from matplotlib import cm
5 | from matplotlib import pyplot as plt
6 | from tqdm import tqdm
7 |
8 |
9 | '''
10 | saves predictions of model, targets, and MDP info (rewards / terminal map) as pickle files
11 | inputs is a tuple of (layouts, objects, instruction_indices)
12 | assumes that save path already exists
13 | '''
14 | def save_predictions(model, inputs, targets, rewards, terminal, text_vocab, save_path, prefix=''):
15 | ## wrap tensors in Variables to pass to model
16 | input_vars = ( Variable(tensor.contiguous()) for tensor in inputs )
17 | predictions = model(input_vars)
18 |
19 | ## convert to numpy arrays for saving to disk
20 | predictions = predictions.data.cpu().numpy()
21 | targets = targets.cpu().numpy()
22 |
23 | ## save the predicted and target value maps
24 | ## as well as info about the MDP and instruction
25 | pickle.dump(predictions, open(os.path.join(save_path, prefix+'predictions.p'), 'wb') )
26 | pickle.dump(targets, open(os.path.join(save_path, prefix+'targets.p'), 'wb') )
27 | pickle.dump(rewards, open(os.path.join(save_path, prefix+'rewards.p'), 'wb') )
28 | pickle.dump(terminal, open(os.path.join(save_path, prefix+'terminal.p'), 'wb') )
29 | pickle.dump(text_vocab, open(os.path.join(save_path, prefix+'vocab.p'), 'wb') )
30 |
31 |
32 | '''
33 | test set is dict from
34 | world number --> (state_obs, goal_obs, instruct_words, instruct_inds, values)
35 | '''
36 |
37 | def evaluate(model, test_set, savepath=None):
38 | progress = tqdm(total=len(test_set))
39 | count = 0
40 | for key, (state_obs, goal_obs, instruct_words, instruct_inds, targets) in test_set.iteritems():
41 | progress.update(1)
42 |
43 | state = Variable( torch.Tensor(state_obs).long().cuda() )
44 | objects = Variable( torch.Tensor(goal_obs).long().cuda() )
45 | instructions = Variable( torch.Tensor(instruct_inds).long().cuda() )
46 | targets = torch.Tensor(targets)
47 | # print state.size(), objects.size(), instructions.size(), targets.size()
48 |
49 | preds = model.forward( (state, objects, instructions) ).data.cpu()
50 |
51 | state_dim = 1
52 | for dim in state.size()[-2:]:
53 | state_dim *= dim
54 |
55 | if savepath:
56 |             num_goals = preds.size(0) // state_dim
57 | for goal_num in range(num_goals):
58 | lower = goal_num * state_dim
59 | upper = (goal_num + 1) * state_dim
60 | fullpath = os.path.join(savepath, \
61 | str(key) + '_' + str(goal_num) + '.png')
62 | pred = preds[lower:upper].numpy()
63 | targ = targets[lower:upper].numpy()
64 | instr = instruct_words[lower]
65 |
66 | pipeline.visualize_value_map(pred, targ, fullpath, title=instr)
67 |
68 |
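   | ## maps each cell (i,j) of an M x N grid to its 4-connected neighbors
   | ## (up, down, left, right), clipped at the grid borders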
69 | def get_children(M, N):
70 | children = {}
71 | for i in range(M):
72 | for j in range(N):
73 | pos = (i,j)
74 | children[pos] = []
75 | for di in range( max(i-1, 0), min(i+1, M-1)+1 ):
76 | for dj in range( max(j-1, 0), min(j+1, N-1)+1 ):
77 | child = (di, dj)
78 | if pos != child and (i == di or j == dj):
79 | children[pos].append( child )
80 | return children
81 |
82 |
83 | '''
84 | values is an M x N map of predicted values; returns a greedy policy mapping each state to its highest-valued neighbor
85 | '''
86 | def get_policy(values):
87 | values = values.squeeze()
88 | M, N = values.shape
89 | states = [(i,j) for i in range(M) for j in range(N)]
90 | children = get_children( M, N )
91 | policy = {}
92 | for state in states:
93 | reachable = children[state]
94 | selected = sorted(reachable, key = lambda x: values[x], reverse=True)
95 | policy[state] = selected[0]
96 | return policy
97 |
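   | ## rolls out the greedy policy induced by the predicted value map for every
   | ## (world, goal) pair and every start position, returning the average number of steps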
98 | def simulate(model, sim_set):
99 | # progress = tqdm(total=len(test_set))
100 | steps_list = []
101 | count = 0
102 | for key in tqdm(range(len(sim_set))):
103 | (state_obs, goal_obs, instruct_words, instruct_inds, targets, mdps) = sim_set[key]
104 | # progress.update(1)
105 | # print torch.Tensor(state_obs).long().cuda()
106 | state = Variable( torch.Tensor(state_obs).long().cuda() )
107 | objects = Variable( torch.Tensor(goal_obs).long().cuda() )
108 | instructions = Variable( torch.Tensor(instruct_inds).long().cuda() )
109 | targets = torch.Tensor(targets)
110 | # print state.size(), objects.size(), instructions.size()
111 |
112 |         preds = model.forward( (state, objects, instructions) ).data.cpu().numpy()
113 | # print 'sim preds: ', preds.shape
114 |
115 | ## average over all goals
116 | num_goals = preds.shape[0]
117 | for ind in range(num_goals):
118 | # print ind
119 | mdp = mdps[ind]
120 | values = preds[ind,:]
121 | dim = int(math.sqrt(values.size))
122 | positions = [(i,j) for i in range(dim) for j in range(dim)]
123 | # print 'dim: ', dim
124 | values = preds[ind,:].reshape(dim, dim)
125 | policy = mdp.get_policy(values)
126 |
127 | # plt.clf()
128 | # plt.pcolor(policy)
129 |
130 |
131 | ## average over all start positions
132 | for start_pos in positions:
133 | steps = mdp.simulate(policy, start_pos)
134 | steps_list.append(steps)
135 | # pdb.set_trace()
136 | # print 'simulating: ', start_pos, steps
137 | avg_steps = np.mean(steps_list)
138 | # print 'avg steps: ', avg_steps, len(steps_list), len(sim_set), num_goals
139 | return avg_steps
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
--------------------------------------------------------------------------------
/pipeline/run_eval.py:
--------------------------------------------------------------------------------
1 | #!/om/user/janner/anaconda2/envs/pytorch/bin/python
2 |
3 | ## norbf
4 | # python run_eval.py --save_path rl_logs_analysis/qos_1250epochs_reinforce_norbfmodel_localmode_1566turk_0.001lr_0.95gamma_4batch_1rep/
5 | # python run_eval.py --save_path rl_logs_analysis/qos_reinforce_norbfmodel_globalmode_1566turk_0.001lr_0.95gamma_4batch_1rep/
6 |
7 | ## uvfa-text
8 | # python run_eval.py --save_path rl_logs_analysis//reinforce_uvfa-textmodel_localmode_1566turk_0.001lr_0.95gamma_32batch_3rep/
9 | # python run_eval.py --save_path rl_logs_analysis//reinforce_uvfa-textmodel_globalmode_1566turk_0.001lr_0.95gamma_32batch_2rep/
10 |
11 | ## cnn-lstm
12 | # run_eval.py --save_path rl_logs_analysis/reinforce_cnn-lstmmodel_localmode_1566turk_0.001lr_0.95gamma_16batch_2rep/
13 | # python run_eval.py --save_path rl_logs_analysis/reinforce_cnn-lstmmodel_globalmode_1566turk_0.001lr_0.95gamma_32batch_3rep/
14 |
15 | import sys, os, subprocess, argparse, numpy as np, pickle, pdb
16 | sys.path.append('/om/user/janner/mit/urop/direction_decomposition/')
17 | import environment, pipeline, reinforce
18 |
19 | parser = argparse.ArgumentParser()
20 | # parser.add_argument('--save_path', type=str, default='logs/trial_nobases_test')
21 | parser.add_argument('--save_path', type=str, default='curves/local_full_1500_wsynth-0-18,20,21,22,23,24,25,26,27,49,53,20,30,40,50,35,45_4/')
22 | # parser.add_argument('--metric_path', type=str, default='curves/local_full_1500_wsynth-0-18,20,21,22,23,24,25,26,27,49,53,20,30,40,50,35,45_4/')
23 | args = parser.parse_args()
24 | args.metric_path = args.save_path
25 |
26 | print args.save_path
27 | predictions = pickle.load( open(os.path.join(args.save_path, 'test_predictions.p'), 'rb') ).squeeze()
28 | targets = pickle.load( open(os.path.join(args.save_path, 'test_targets.p'), 'rb') ).squeeze()
29 | rewards = pickle.load( open(os.path.join(args.save_path, 'test_rewards.p'), 'rb') ).squeeze()
30 | terminal = pickle.load( open(os.path.join(args.save_path, 'test_terminal.p'), 'rb') ).squeeze()
31 |
32 | rewards = rewards.cpu().numpy()
33 | terminal = terminal.cpu().numpy()
34 |
35 | def get_states(M, N):
36 | states = [(i,j) for i in range(M) for j in range(N)]
37 | return states
38 |
39 | def get_children(M, N):
40 | children = {}
41 | for i in range(M):
42 | for j in range(N):
43 | pos = (i,j)
44 | children[pos] = []
45 | for di in range( max(i-1, 0), min(i+1, M-1)+1 ):
46 | for dj in range( max(j-1, 0), min(j+1, N-1)+1 ):
47 | child = (di, dj)
48 | if pos != child and (i == di or j == dj):
49 | children[pos].append( child )
50 | return children
51 |
52 | '''
53 | for each state (i,j) in states, returns a list
54 | of neighboring states with approximated values
55 | in descending order
56 | '''
57 | def get_policy(values):
58 | values = values.squeeze()
59 | policy = {}
60 | for state in STATES:
61 | reachable = CHILDREN[state]
62 | selected = sorted(reachable, key = lambda x: values[x], reverse = True)
63 | policy[state] = selected
64 | return policy
65 |
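   | ## greedily follows the approximated values from start_pos, skipping neighbors that
   | ## were already visited to avoid loops, and returns the discounted return of the
   | ## resulting trajectory (capped at max_steps)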
66 | def simulate_single(reward_map, terminal_map, approx_values, start_pos, max_steps = 75):
67 | # world = world.squeeze()
68 | # reward_map = reward_map.data.squeeze()
69 | # terminal_map = terminal_map.data.squeeze()
70 | # approx_values = approx_values
71 |
72 | M, N = reward_map.shape
73 | # print 'M, N: ', M, N
74 | if M != 10 or N != 10:
75 |         raise RuntimeError( 'wrong size: {}x{}, expected: 10x10'.format(M, N) )
76 | # self._refresh_size(M, N)
77 | policy = get_policy(approx_values)
78 |
79 | pos = start_pos
80 | visited = set([pos])
81 | trajectory = []
82 |
83 | total_reward = 0
84 | for step in range(max_steps):
85 | val = approx_values[pos]
86 | rew = reward_map[pos]
87 | term = terminal_map[pos]
88 | # print 'VAL: ', val, '\nREW: ', rew, '\nTERM: ', term, approx_values.size(), reward_map.size(), terminal_map.size()
89 | trajectory.append( (pos, val, rew, term) )
90 |
91 | total_reward += rew * (GAMMA ** step)
92 | if term:
93 | # print 'GOT REWARD: ', rew, step, rew * (self.gamma ** step)
94 | # print '\n\nDONE\n\n\n\n'
95 | break
96 |
97 | reachable = policy[pos]
98 | selected = 0
99 | while selected < len(reachable) and reachable[selected] in visited:
100 | # print ' visited ', selected, reachable[selected]
101 | selected += 1
102 | if selected == len(reachable):
103 | # print '\n\nVISITED ALL', pos, [n in visited for n in reachable], '\n\n\n\n'
104 | selected = 0
105 | # return trajectory
106 | break
107 |
108 | pos = reachable[selected]
109 | visited.add(pos)
110 | # print 'position: ', pos
111 | # print 'traj: ', len(trajectory), 'rew: ', total_reward
112 | return total_reward
113 |
114 | STATES = get_states(10,10)
115 | CHILDREN = get_children(10,10)
116 | GAMMA = 0.95
117 |
118 | # pdb.set_trace()
119 | quality_path = os.path.join(args.metric_path, 'quality')
120 | if not os.path.exists( quality_path ):
121 | subprocess.call(['mkdir', quality_path])
122 |
123 | num_worlds = targets.shape[0]
124 | print 'Num worlds: {}'.format(num_worlds)
125 |
126 | mse = np.sum(np.power(predictions - targets, 2)) / predictions.size
127 | print 'MSE: {}'.format(mse)
128 |
129 | cumulative_normed = 0
130 | manhattan = 0
131 | cumulative_per_score = 0
132 | cumulative_score = 0
133 | for ind in range(num_worlds):
134 | pred = predictions[ind]
135 | targ = targets[ind]
136 |
137 | pred_max = np.unravel_index(np.argmax(pred), pred.shape)
138 | targ_max = np.unravel_index(np.argmax(targ), targ.shape)
139 | man = abs(pred_max[0] - targ_max[0]) + abs(pred_max[1] - targ_max[1])
140 |
141 | # pdb.set_trace()
142 |
143 | unif = np.ones( pred.shape )
144 | rew = rewards[ind]
145 | term = terminal[ind]
146 |
147 | mdp = environment.MDP(None, rew, term)
148 | si = pipeline.ScoreIteration(mdp, pred)
149 | avg_pred, scores_pred = si.iterate()
150 |
151 | mdp = environment.MDP(None, rew, term)
152 | si = pipeline.ScoreIteration(mdp, targ)
153 | avg_targ, scores_targ = si.iterate()
154 |
155 | mdp = environment.MDP(None, rew, term)
156 | si = pipeline.ScoreIteration(mdp, unif)
157 | avg_unif, scores_unif = si.iterate()
158 |
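   |     ## normalize per-cell scores so that a uniform value map scores 0 and the
   |     ## ground-truth value map scores 1; zero denominators produce NaNs, which are
   |     ## treated as a perfect score before averaging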
159 | avg_per_score = np.divide(scores_pred-scores_unif, scores_targ-scores_unif)
160 | avg_per_score[avg_per_score != avg_per_score] = 1
161 | avg_per_score = np.mean(avg_per_score)
162 | # pdb.set_trace()
163 |
164 | start_pos = (np.random.randint(10), np.random.randint(10))
165 | score = 0 #simulate_single(rew, term, pred, start_pos)
166 |
167 | normed = (avg_pred - avg_unif) / (avg_targ - avg_unif)
168 | cumulative_normed += normed
169 | manhattan += man
170 | cumulative_per_score += avg_per_score
171 | cumulative_score += score
172 |
173 | sys.stdout.write( '{:4}:\t{:4}\t{:4}\t{:4}\t{:4}\t{:4}\t{:.4}\t{:.4}\t{:.3}\r'.format(ind, avg_pred, avg_targ, avg_unif, normed,
174 | cumulative_normed / (ind + 1), cumulative_per_score / (ind + 1), float(manhattan) / (ind + 1), float(cumulative_score) / (ind + 1)))
175 | sys.stdout.flush()
176 |
177 |
178 | # fullpath = os.path.join(quality_path, str(ind) + '.png')
179 | # pipeline.visualize_value_map(scores_pred, scores_targ, fullpath)
180 |
181 |
182 |
183 |
184 |
185 |
186 | # pred_policy = pipeline.get_policy(pred)
187 | # targ_policy = pipeline.get_policy(targ)
188 |
189 |
190 |
191 | # corr = 0
192 | # for pos, targ in targ_policy.iteritems():
193 | # pred = pred_policy[pos]
194 | # if pred == targ:
195 | # # print 'count: ', pred, targ
196 | # corr += 1
197 | # # else:
198 | # # print 'not counting: ', pred, targ
199 | # print corr, ' / ', len(targ_policy)
200 | # cumulative_correct += corr
201 |
202 | avg_normed = float(cumulative_normed) / num_worlds
203 | avg_manhattan = float(manhattan) / num_worlds
204 | avg_score = float(cumulative_score) / num_worlds
205 | print 'Avg normed: {}'.format(avg_normed)
206 | print 'Avg manhattan: {}'.format(avg_manhattan)
207 | print 'Avg score: {}'.format(avg_score)
208 |
209 | if args.metric_path is not None:
210 | results = {'mse': mse, 'quality': avg_normed, 'manhattan': avg_manhattan}
211 | pickle.dump(results, open(os.path.join(args.metric_path, 'results.p'), 'wb'))
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
--------------------------------------------------------------------------------
/pipeline/score_iteration.py:
--------------------------------------------------------------------------------
1 | import numpy as np, pdb
2 |
3 | class ScoreIteration:
4 |
5 | def __init__(self, mdp, values):
6 | self.refresh(mdp, values)
7 |
8 | def _softmax(self, x):
9 | return np.exp(x) / np.sum(np.exp(x), axis=0)
10 |
11 | def get_children(M, N):
12 | children = {}
13 | for i in range(M):
14 | for j in range(N):
15 | pos = (i,j)
16 | children[pos] = []
17 | for di in range( max(i-1, 0), min(i+1, M-1)+1 ):
18 | for dj in range( max(j-1, 0), min(j+1, N-1)+1 ):
19 | child = (di, dj)
20 | if pos != child and (i == di or j == dj):
21 | children[pos].append( child )
22 | return children
23 |
24 | '''
25 | values is M x N map of predicted values
26 | '''
27 | def get_policy(values):
28 | values = values.squeeze()
29 | M, N = values.shape
30 | states = [(i,j) for i in range(M) for j in range(N)]
31 | children = get_children( M, N )
32 | policy = {}
33 | for state in states:
34 | reachable = children[state]
35 | selected = sorted(reachable, key = lambda x: values[x], reverse=True)
36 | policy[state] = selected[0]
37 | return policy
38 |
39 | def refresh(self, mdp, values):
40 | self.mdp = mdp
41 | self.states = mdp.getStates()
42 | self.actions = mdp.getActions()
43 | self.transition = mdp.transition
44 | self.reward = mdp.reward
45 | self.terminal = mdp.terminal
46 | self.values = values
47 | self.scores = np.zeros( self.mdp.reward_map.shape )
48 | self.discount = 0.9
49 |
50 |
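   |     ## propagates rewards through the MDP under a stochastic policy whose transition
   |     ## probabilities are a softmax over the predicted values of neighboring states,
   |     ## iterating until the scores change by less than 0.01 in total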
51 | def iterate(self):
52 | old_scores = self.scores.copy()
53 | old_scores.fill(-float('inf'))
54 |
55 | count = 0
56 | while np.sum(np.abs(self.scores - old_scores)) > 0.01:
57 | old_scores = self.scores.copy()
58 | for state in self.states:
59 | max_val = self.reward(state)
60 | term = self.terminal(state)
61 |
62 | if not term:
63 | neighbors = [self.transition(state, action) for action in self.actions]
64 | values = [self.values[s] for s in neighbors]
65 | transition_probs = self._softmax(values)
66 |
67 | neighbor_contributions = [transition_probs[ind] * self.scores[neighbors[ind]] \
68 | for ind in range(len(transition_probs))]
69 |
70 | max_val = self.reward(state) + self.discount * sum(neighbor_contributions)
71 |
72 | self.scores[state] = max_val
73 | count += 1
74 |
75 | avg_score = np.mean(self.scores)
76 | return avg_score, self.scores
77 |
--------------------------------------------------------------------------------
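
Note: ScoreIteration is driven exactly as in run_eval.py above: wrap the reward and
terminal maps in an environment.MDP and score a candidate value map against them.
A minimal sketch, assuming 10x10 maps; the reward values, goal cell, and random
candidate map below are illustrative placeholders rather than data from the repo:

    import numpy as np
    import environment, pipeline

    ## stand-in 10x10 reward and terminal maps with a single goal cell (placeholders)
    rewards = -0.01 * np.ones((10, 10))
    terminal = np.zeros((10, 10))
    rewards[5, 5] = 3.0
    terminal[5, 5] = 1.0

    ## candidate value map to be scored (random, purely for illustration)
    values = np.random.rand(10, 10)

    mdp = environment.MDP(None, rewards, terminal)
    si = pipeline.ScoreIteration(mdp, values)
    avg_score, scores = si.iterate()
    print 'average score: {}'.format(avg_score)
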
/pipeline/training.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | import numpy as np
3 | from tqdm import tqdm, trange
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | from torchvision import datasets, transforms
9 | from torch.autograd import Variable
10 |
11 | class Trainer:
12 | def __init__(self, model, lr, batch_size):
13 | self.model = model
14 | self.batch_size = batch_size
15 | self.criterion = nn.MSELoss(size_average=True).cuda()
16 | self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
17 |
18 | def __epoch(self, inputs, targets, repeats = 1):
19 | self.model.train()
20 | if type(inputs) == tuple:
21 | data_size = inputs[0].size(0)
22 | else:
23 | data_size = inputs.size(0)
24 | num_batches = int(math.ceil(data_size / float(self.batch_size)) * repeats)
25 |
26 | err = 0
27 | for i in range(num_batches):
28 | inp, targ = self.__get_batch(inputs, targets)
29 | self.optimizer.zero_grad()
30 | out = self.model.forward(inp)
31 | loss = self.criterion(out, targ)
32 | loss.backward()
33 | self.optimizer.step()
34 | err += loss.data[0]
35 | err = err / float(num_batches)
36 | return err
37 |
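   |     ## samples a batch of `batch_size` examples uniformly at random (with replacement)
   |     ## and wraps them in Variables on the GPU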
38 | def __get_batch(self, inputs, targets):
39 | data_size = targets.size(0)
40 |
41 | inds = torch.floor(torch.rand(self.batch_size) * data_size).long().cuda()
42 | # bug: floor(rand()) sometimes gives 1
43 | inds[inds >= data_size] = data_size - 1
44 |
45 | if type(inputs) == tuple:
46 | inp = tuple([Variable( i.index_select(0, inds).cuda() ) for i in inputs])
47 | else:
48 | inp = Variable( inputs.index_select(0, inds).cuda() )
49 |
50 | targ = Variable( targets.index_select(0, inds).cuda() )
51 | return inp, targ
52 |
53 | def train(self, inputs, targets, val_inputs, val_targets, iters = 10):
54 | t = trange(iters)
55 | for i in t:
56 | err = self.__epoch(inputs, targets)
57 | t.set_description( str(err) )
58 | return self.model
59 |
60 |
--------------------------------------------------------------------------------
/reinforcement.py:
--------------------------------------------------------------------------------
1 | #!/om/user/janner/anaconda2/envs/pytorch/bin/python
2 |
3 | import os, argparse, pickle, torch
4 | import pipeline, models, data, utils, visualization
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--save_path', type=str, default='logs/trial')
8 | parser.add_argument('--max_train', type=int, default=5000)
9 | parser.add_argument('--max_test', type=int, default=500)
10 |
11 | parser.add_argument('--mode', type=str, default='local', choices=['local', 'global'])
12 | parser.add_argument('--annotations', type=str, default='human', choices=['synthetic', 'human'])
13 | parser.add_argument('--model', type=str, default='full', choices=['full', 'no-gradient', 'cnn-lstm', 'uvfa-text'])
14 |
15 | parser.add_argument('--map_dim', type=int, default=10)
16 | parser.add_argument('--state_embed', type=int, default=1)
17 | parser.add_argument('--obj_embed', type=int, default=7)
18 |
19 | parser.add_argument('--lstm_inp', type=int, default=15)
20 | parser.add_argument('--lstm_hid', type=int, default=30)
21 | parser.add_argument('--lstm_layers', type=int, default=1)
22 | parser.add_argument('--attention_kernel', type=int, default=3)
23 | parser.add_argument('--attention_out_dim', type=int, default=1)
24 |
25 | parser.add_argument('--batch_size', type=int, default=4)
26 | parser.add_argument('--replay_size', type=int, default=100000)
27 | parser.add_argument('--learn_start', type=int, default=1000)
28 | parser.add_argument('--gamma', type=float, default=0.95)
29 | parser.add_argument('--lr', type=float, default=0.001)
30 | parser.add_argument('--epochs', type=int, default=1250)
31 |
32 | args = parser.parse_args()
33 |
34 |
35 | #################################
36 | ############## Data #############
37 | #################################
38 |
39 | train_data, test_data = data.load(args.mode, args.annotations, args.max_train, args.max_test)
40 | layout_vocab_size, object_vocab_size, text_vocab_size, text_vocab = data.get_statistics(train_data, test_data)
41 |
42 | print '\n Converting to tensors'
43 | train_layouts, train_objects, train_rewards, train_terminal, \
44 | train_instructions, train_indices, train_values, train_goals = data.to_tensor(train_data, text_vocab)
45 |
46 | test_layouts, test_objects, test_rewards, test_terminal, \
47 | test_instructions, test_indices, test_values, test_goals = data.to_tensor(test_data, text_vocab)
48 |
49 | print ' Training:', train_layouts.size(), 'x', train_objects.size(), 'x', train_indices.size()
50 | print ' Rewards: ', train_rewards.size(), ' Terminal: ', train_terminal.size()
51 | print ' Test :', test_layouts.size(), 'x', test_objects.size(), 'x', test_indices.size()
52 | print ' Rewards: ', test_rewards.size(), ' Terminal: ', test_terminal.size()
53 |
54 |
55 | #################################
56 | ############ Training ###########
57 | #################################
58 |
59 | print '\n Initializing model: {}'.format(args.model)
60 | model = models.init(args, layout_vocab_size, object_vocab_size, text_vocab_size)
61 | target_model = models.init(args, layout_vocab_size, object_vocab_size, text_vocab_size)
62 |
63 | ## initialize agent
64 | agent = pipeline.Agent(model, target_model, map_dim = args.map_dim, instr_len = train_indices.size(1),
65 | batch_size = args.batch_size, learn_start = args.learn_start,
66 | replay_size = args.replay_size, lr = args.lr, gamma = args.gamma)
67 |
68 | train_inputs = (train_layouts, train_objects, train_indices)
69 | test_inputs = (test_layouts, test_objects, test_indices)
70 |
71 | ## train agent
72 | scores = agent.train( train_inputs, train_rewards, train_terminal,
73 | test_inputs, test_rewards, test_terminal, epochs = args.epochs )
74 |
75 |
76 | #################################
77 | ######## Save predictions #######
78 | #################################
79 |
80 | ## make logging directories
81 | pickle_path = os.path.join(args.save_path, 'pickle')
82 | utils.mkdir(args.save_path)
83 | utils.mkdir(pickle_path)
84 |
85 | print '\n Saving model and scores to {}'.format(args.save_path)
86 | ## save model
87 | torch.save(model, os.path.join(args.save_path, 'model.pth'))
88 | ## save scores from training
89 | score_path = os.path.join(pickle_path, 'scores.p')
90 | pickle.dump(scores, open(score_path, 'wb') )
91 |
92 |
93 | print ' Saving predictions to {}'.format(pickle_path)
94 | ## save inputs, outputs, and MDP info (rewards and terminal maps)
95 | pipeline.save_predictions(model, train_inputs, train_values, train_rewards, train_terminal, text_vocab, pickle_path, prefix='train_')
96 | pipeline.save_predictions(model, test_inputs, test_values, test_rewards, test_terminal, text_vocab, pickle_path, prefix='test_')
97 |
98 |
99 | #################################
100 | ######### Visualization #########
101 | #################################
102 |
103 | vis_path = os.path.join(args.save_path, 'visualization')
104 | utils.mkdir(vis_path)
105 |
106 | print ' Saving visualizations to {}'.format(vis_path)
107 |
108 | ## save inputs, outputs, and MDP info (rewards and terminal maps)
109 | visualization.vis_predictions(model, train_inputs, train_values, train_instructions, vis_path, prefix='train_')
110 | visualization.vis_predictions(model, test_inputs, test_values, test_instructions, vis_path, prefix='test_')
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
--------------------------------------------------------------------------------
/representation.py:
--------------------------------------------------------------------------------
1 | #!/om/user/janner/anaconda2/envs/pytorch/bin/python
2 |
3 | import os, argparse, numpy as np, torch, pdb
4 | import pipeline, models, data, utils, visualization
5 |
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--save_path', type=str, default='logs/trial/')
9 | parser.add_argument('--max_train', type=int, default=5000)
10 | parser.add_argument('--max_test', type=int, default=500)
11 |
12 | parser.add_argument('--model', type=str, default='full', choices=['full', 'no-gradient', 'cnn-lstm', 'uvfa-text'])
13 | parser.add_argument('--annotations', type=str, default='human', choices=['synthetic', 'human'])
14 | parser.add_argument('--mode', type=str, default='local', choices=['local', 'global'])
15 |
16 | parser.add_argument('--map_dim', type=int, default=10)
17 | parser.add_argument('--state_embed', type=int, default=1)
18 | parser.add_argument('--obj_embed', type=int, default=7)
19 |
20 | parser.add_argument('--lstm_inp', type=int, default=15)
21 | parser.add_argument('--lstm_hid', type=int, default=30)
22 | parser.add_argument('--lstm_layers', type=int, default=1)
23 | parser.add_argument('--attention_kernel', type=int, default=3)
24 | parser.add_argument('--attention_out_dim', type=int, default=1)
25 |
26 | parser.add_argument('--batch_size', type=int, default=32)
27 | parser.add_argument('--lr', type=float, default=0.001)
28 | parser.add_argument('--iters', type=int, default=200)
29 | args = parser.parse_args()
30 |
31 |
32 | #################################
33 | ############## Data #############
34 | #################################
35 |
36 | train_data, test_data = data.load(args.mode, args.annotations, args.max_train, args.max_test)
37 | layout_vocab_size, object_vocab_size, text_vocab_size, text_vocab = data.get_statistics(train_data, test_data)
38 |
39 | print '\n Converting to tensors'
40 | train_layouts, train_objects, train_rewards, train_terminal, \
41 | train_instructions, train_indices, train_values, train_goals = data.to_tensor(train_data, text_vocab)
42 |
43 | test_layouts, test_objects, test_rewards, test_terminal, \
44 | test_instructions, test_indices, test_values, test_goals = data.to_tensor(test_data, text_vocab)
45 |
46 | print ' Training: (', train_layouts.size(), 'x', train_objects.size(), 'x', train_indices.size(), ') -->', train_values.size()
47 | print ' test : (', test_layouts.size(), 'x', test_objects.size(), 'x', test_indices.size(), ') -->', test_values.size()
48 |
49 |
50 | #################################
51 | ############ Training ###########
52 | #################################
53 |
54 | print '\n Initializing model: {}'.format(args.model)
55 | model = models.init(args, layout_vocab_size, object_vocab_size, text_vocab_size)
56 |
57 | train_inputs = (train_layouts, train_objects, train_indices)
58 | test_inputs = (test_layouts, test_objects, test_indices)
59 |
60 | print ' Training model'
61 | trainer = pipeline.Trainer(model, args.lr, args.batch_size)
62 | trainer.train(train_inputs, train_values, test_inputs, test_values, iters=args.iters)
63 |
64 |
65 | #################################
66 | ######## Save predictions #######
67 | #################################
68 |
69 | ## make logging directories
70 | pickle_path = os.path.join(args.save_path, 'pickle')
71 | utils.mkdir(args.save_path)
72 | utils.mkdir(pickle_path)
73 |
74 | print '\n Saving model to {}'.format(args.save_path)
75 | ## save model
76 | torch.save(model, os.path.join(args.save_path, 'model.pth'))
77 |
78 | print ' Saving predictions to {}'.format(pickle_path)
79 | ## save inputs, outputs, and MDP info (rewards and terminal maps)
80 | pipeline.save_predictions(model, test_inputs, test_values, test_rewards, test_terminal, text_vocab, pickle_path, prefix='test_')
81 |
82 |
83 | #################################
84 | ######### Visualization #########
85 | #################################
86 |
87 | vis_path = os.path.join(args.save_path, 'visualization')
88 | utils.mkdir(vis_path)
89 |
90 | print ' Saving visualizations to {}'.format(vis_path)
91 | ## save images with predicted and target value maps
92 | visualization.vis_predictions(model, test_inputs, test_values, test_instructions, vis_path, prefix='test_')
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.12.1
2 | scipy>=0.19.0
3 | matplotlib==2.0.0
4 | tqdm
5 |
--------------------------------------------------------------------------------
/slurm/meta-generate.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | save_path = '../data/mturk_only_global_10dim_8max_2ref/'
4 | vis_path = '../data/mturk_only_global_10dim_8max_2ref_vis/'
5 | start = 200
6 | end = 300
7 | step = 1
8 | dim = 10
9 | mode = 'global'
10 | # only_global = False
11 |
12 | for lower in range(start, end, step):
13 | command = [ 'sbatch', '--qos=tenenbaum', '-c', '2', '--time=1-12:0', '-J', str(lower), 'generate_worlds.py', \
14 | '--lower', str(lower), '--num_worlds', str(step), '--dim', str(dim), \
15 | '--save_path', save_path, '--vis_path', vis_path, '--mode', mode]
16 | # print command
17 | subprocess.Popen( command )
--------------------------------------------------------------------------------
/slurm/meta_reinforce.sh:
--------------------------------------------------------------------------------
1 | # lr=0.001
2 | # gamma=0.95
3 | # kernel_out=1
4 |
5 | # for batch in 4 8 16 32
6 | # do
7 | # for worlds in 5 10 20
8 | # do
9 | # save_path=logs/reinforce_${lr}lr_${gamma}gamma_${batch}batch_${kernel_out}attn_${worlds}worlds
10 | # mkdir ${save_path}
11 | # sbatch -c 2 --gres=gpu:titan-x:1 --qos=tenenbaum --time=4-12:0 --mem=40G -J ${worlds}_${batch}_${lr} -o ${save_path}/out.txt run_reinforce.py \
12 | # --lr ${lr} --attention_out_dim ${kernel_out} --num_worlds ${worlds} --save_path ${save_path} \
13 | # --batch_size ${batch} --gamma ${gamma}
14 | # done
15 | # done
16 |
17 | lr=0.001
18 | gamma=0.95
19 | epochs=1250
20 |
21 | for rep in 1
22 | do
23 | for batch in 4
24 | do
25 | for model in norbf #uvfa-text cnn-lstm
26 | do
27 | for mode in local global
28 | do
29 | for num_turk in 1566
30 | do
31 | save_path=rl_logs_analysis/qos_${epochs}epochs_reinforce_${model}model_${mode}mode_${num_turk}turk_${lr}lr_${gamma}gamma_${batch}batch_${rep}rep
32 | mkdir ${save_path}
33 | sbatch -c 2 --gres=gpu:titan-x:1 --qos=tenenbaum --time=4-12:0 --mem=40G -J ${rep}${model}${num_turk}_${batch}_${lr} -o ${save_path}/_out.txt turk_reinforce.py \
34 | --model ${model} --mode ${mode} --num_turk ${num_turk} \
35 | --lr ${lr} --save_path ${save_path} \
36 | --batch_size ${batch} --gamma ${gamma} --epochs ${epochs}
37 | done
38 | done
39 | done
40 | done
41 | done
42 |
43 | # -o ${save_path}/_out.txt
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from rbf import *
2 | from filesystem import *
--------------------------------------------------------------------------------
/utils/filesystem.py:
--------------------------------------------------------------------------------
1 | import os, subprocess
2 |
3 | def mkdir(path):
4 | if not os.path.exists(path):
5 | subprocess.call(['mkdir', path])
--------------------------------------------------------------------------------
/utils/rbf.py:
--------------------------------------------------------------------------------
1 | import torch, pdb
2 |
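  | ## builds one location basis per grid cell: batch[k] holds the Manhattan distances
  | ## from cell (k // size, k % size) to every other cell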
3 | def meta_rbf(size):
4 | batch = torch.zeros(size*size, size, size)
5 | count = 0
6 | for i in range(size):
7 | for j in range(size):
8 | batch[count,:,:] = rbf(size, (i,j)).clone()
9 | count += 1
10 | return batch
11 |
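  | ## returns a size x size grid whose (i, j) entry is the Manhattan distance
  | ## |i - x| + |j - y|, assembled from four precomputed distance quadrants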
12 | def rbf(size, position):
13 | x,y = position
14 | grid = torch.zeros(size, size)
15 |
16 | top_left = manhattan(size, row='increasing', col='increasing')
17 | bottom_left = manhattan(size, row='increasing', col='decreasing')
18 | top_right = manhattan(size, row='decreasing', col='increasing')
19 | bottom_right = manhattan(size, row='decreasing', col='decreasing')
20 |
21 | ## top left
22 | if x > 0 and y > 0:
23 | grid[:x+1, :y+1] = bottom_right[-x-1:, -y-1:]
24 | ## bottom left
25 | if x < size and y > 0:
26 | grid[x:, :y+1] = top_right[:size-x, -y-1:]
27 | ## top right
28 | if x > 0 and y < size:
29 | grid[:x+1, y:] = bottom_left[size-x-1:, :size-y]
30 | ## bottom right
31 | if x < size and y < size:
32 | grid[x:, y:] = top_left[:size-x, :size-y]
33 |
34 | return grid
35 |
36 | def manhattan(size, row='increasing', col='increasing'):
37 | if row == 'increasing':
38 | rows = range_grid(0, size, 1, size)
39 | elif row == 'decreasing':
40 | rows = range_grid(size-1, -1, -1, size)
41 | else:
42 | raise RuntimeError('Unrecognized row in manhattan: ', row)
43 |
44 | if col == 'increasing':
45 | cols = range_grid(0, size, 1, size).t()
46 | elif col == 'decreasing':
47 | cols = range_grid(size-1, -1, -1, size).t()
48 | else:
49 | raise RuntimeError('Unrecognized col in manhattan: ', col)
50 |
51 | distance = rows + cols
52 | return distance
53 |
54 | def range_grid(low, high, step, repeat):
55 | grid = torch.arange(low, high, step).repeat(repeat, 1)
56 | return grid
57 |
58 |
59 |
60 | if __name__ == '__main__':
61 | import sys
62 | sys.path.append('../')
63 | import pipeline
64 |
65 | map_dim = 10
66 |
67 | row = torch.arange(0,map_dim).unsqueeze(1).repeat(1,map_dim).numpy()
68 | col = torch.arange(0,map_dim).repeat(map_dim,1).numpy()
69 |
70 | pipeline.visualize_fig(row, 'figures/grad_row.png')
71 | pipeline.visualize_fig(col, 'figures/grad_col.png')
72 |
73 | # import scipy.misc
74 | # grid = rbf(10, (5,5))
75 | # print grid
76 | batch = meta_rbf(10).numpy()
77 |
78 | # pdb.set_trace()
79 |
80 | for b in range(batch.shape[0]):
81 | pipeline.visualize_fig(batch[b], 'figures/' + str(b) + '.png')
82 | print batch.size
83 | # print batch[-1]
84 | pdb.set_trace()
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
--------------------------------------------------------------------------------
/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | from place_goal import *
2 | from vis_predictions import *
--------------------------------------------------------------------------------
/visualization/place_goal.py:
--------------------------------------------------------------------------------
1 | import os, pickle, argparse, subprocess, numpy as np, scipy.misc
2 |
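  | ## draws a black grid over the sprite image and, if a goal cell is given,
  | ## outlines that cell in red before writing the result to save_path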
3 | def place_goal(img, goal, save_path, num_cells = 10, width = 4):
4 | img = img.copy()
5 | dim = img.shape[0]
6 | cell_size = float(dim) / num_cells
7 | half_width = width / 2
8 |
9 | ## add black rectangular grid
10 | for i in range(0, num_cells+1):
11 | low = int(i * cell_size) - half_width
12 | high = low + half_width
13 | img[low:high, :, :] = 0
14 | img[:, low:high, :] = 0
15 |
16 |     if goal is not None:
17 | ## add red square around goal cell
18 | goal_row_low = int(goal[0] * cell_size)
19 | goal_row_high = min( int(goal_row_low + cell_size), img.shape[0] - 1)
20 | goal_col_low = int(goal[1] * cell_size)
21 | goal_col_high = min( int(goal_col_low + cell_size), img.shape[1] - 1)
22 |
23 | img[max(goal_row_low-width, 0):goal_row_low+width, goal_col_low:goal_col_high+1, :] = [255,0,0]
24 | img[goal_row_high-width:goal_row_high+width, goal_col_low:goal_col_high+1, :] = [255,0,0]
25 | img[goal_row_low:goal_row_high+1, max(goal_col_low-width, 0):goal_col_low+width, :] = [255,0,0]
26 | img[goal_row_low:goal_row_high+1, goal_col_high-width:goal_col_high+width, :] = [255,0,0]
27 |
28 | scipy.misc.imsave(save_path, img)
29 |
30 | if __name__ == '__main__':
31 | parser = argparse.ArgumentParser()
32 | # parser.add_argument('--data_path', type=str, default='../data/mturk_nonunique_global_10dim_8max_2ref/')
33 | # parser.add_argument('--sprite_path', type=str, default='../data/mturk_nonunique_global_10dim_8max_2ref_vis/')
34 | parser.add_argument('--data_path', type=str, default='../data/mturk_only_global_10dim_8max_2ref/')
35 | parser.add_argument('--sprite_path', type=str, default='../data/mturk_only_global_10dim_8max_2ref_vis/')
36 | parser.add_argument('--save_path', type=str, default='global_trial/')
37 | parser.add_argument('--start', type=int, default=200)
38 | parser.add_argument('--end', type=int, default=300)
39 | parser.add_argument('--dim', type=int, default=10)
40 | args = parser.parse_args()
41 |
42 | print args, '\n'
43 |
44 | if not os.path.exists(args.save_path):
45 |         subprocess.call(['mkdir', args.save_path])
46 |
47 | for i in range(args.start, args.end):
48 | data_path = os.path.join(args.data_path, str(i) + '.p')
49 | data = pickle.load( open(data_path, 'rb') )
50 |
51 | sprite_path = os.path.join(args.sprite_path, str(i) + '_sprites.png')
52 | sprites = scipy.misc.imread(sprite_path)
53 |
54 | dump_path = os.path.join(args.save_path, str(i) + '.p')
55 | pickle.dump(data, open(dump_path, 'wb'))
56 |
57 | save_path = os.path.join(args.save_path, str(i) + '_clean.png')
58 | place_goal(sprites, None, save_path, num_cells = args.dim)
59 |
60 | for g, goal in enumerate(data['goals']):
61 | print data['instructions'][g], goal
62 | save_path = os.path.join(args.save_path, str(i) + '_' + str(g) + '.png')
63 | print save_path, '\n'
64 | place_goal(sprites, goal, save_path, num_cells = args.dim)
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
/visualization/run_vis.py:
--------------------------------------------------------------------------------
1 | #!/om/user/janner/anaconda2/envs/pytorch/bin/python
2 |
3 | import sys, os, subprocess, argparse, numpy as np, pickle, pdb
4 | from tqdm import tqdm
5 | from matplotlib import cm
6 | from vis_predictions import *
7 | sys.path.append('../')
8 | import environment, pipeline
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--save_path', type=str, default='../logs/example/pickle/')
12 | args = parser.parse_args()
13 |
14 | predictions = pickle.load( open(os.path.join(args.save_path, 'predictions.p'), 'rb') ).squeeze()
15 | targets = pickle.load( open(os.path.join(args.save_path, 'targets.p'), 'rb') ).squeeze()
16 | rewards = pickle.load( open(os.path.join(args.save_path, 'rewards.p'), 'rb') ).squeeze()
17 | terminal = pickle.load( open(os.path.join(args.save_path, 'terminal.p'), 'rb') ).squeeze()
18 |
19 | vis_path = os.path.join(args.save_path, 'vis')
20 | if not os.path.exists(vis_path):
21 | subprocess.call(['mkdir', vis_path])
22 |
23 | num_worlds = targets.shape[0]
24 |
25 | for ind in tqdm(range(num_worlds)):
26 | pred = predictions[ind]
27 | targ = targets[ind]
28 |
29 | vmax = max(pred.max(), targ.max())
30 | vmin = min(pred.min(), targ.min())
31 |
32 | pred_path = os.path.join(vis_path, str(ind) + '_pred.png')
33 | targ_path = os.path.join(vis_path, str(ind) + '_targ.png')
34 |
35 | vis_fig(pred, pred_path, vmax=vmax, vmin=vmin)
36 | vis_fig(targ, targ_path, vmax=vmax, vmin=vmin)
37 |
38 |
39 |
--------------------------------------------------------------------------------
/visualization/vis_predictions.py:
--------------------------------------------------------------------------------
1 | import os, math, torch, pdb
2 | from tqdm import tqdm
3 | from torch.autograd import Variable
4 | import matplotlib; matplotlib.use('Agg')
5 | from matplotlib import cm
6 | from matplotlib import pyplot as plt
7 |
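  | ## plots the predicted and target value maps side by side as heatmaps; with
  | ## share=True both panels use a common color scale, otherwise each is scaled independently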
8 | def vis_value_map(pred, targ, save_path, title='prediction', share=True):
9 | # print 'in vis: ', pred.shape, targ.shape
10 | dim = int(math.sqrt(pred.size))
11 | if share:
12 | vmin = min(pred.min(), targ.min())
13 | vmax = max(pred.max(), targ.max())
14 | else:
15 | vmin = None
16 | vmax = None
17 |
18 | plt.clf()
19 | fig, (ax0,ax1) = plt.subplots(1,2,sharey=True)
20 | heat0 = ax0.pcolor(pred.reshape(dim,dim), vmin=vmin, vmax=vmax, cmap=cm.jet)
21 | ax0.set_title(title, fontsize=5)
22 | if not share:
23 | fig.colorbar(heat0)
24 | heat1 = ax1.pcolor(targ.reshape(dim,dim), vmin=vmin, vmax=vmax, cmap=cm.jet)
25 | ax1.invert_yaxis()
26 | ax1.set_title('target')
27 |
28 | fig.subplots_adjust(right=0.8)
29 | cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
30 | fig.colorbar(heat1, cax=cbar_ax)
31 |
32 | # print 'saving to: ', fullpath
33 | plt.savefig(save_path, bbox_inches='tight')
34 | plt.close(fig)
35 |
36 | # print pred.shape, targ.shape
37 |
38 | def vis_fig(data, save_path, title=None, vmax=None, vmin=None, cmap=cm.jet):
39 | # print 'in vis: ', pred.shape, targ.shape
40 | dim = int(math.sqrt(data.size))
41 |
42 | # if share:
43 | # vmin = min(pred.min(), targ.min())
44 | # vmax = max(pred.max(), targ.max())
45 | # else:
46 | # vmin = None
47 | # vmax = None
48 |
49 | plt.clf()
50 | # fig, (ax0,ax1) = plt.subplots(1,2,sharey=True)
51 | plt.pcolor(data.reshape(dim,dim), vmin=vmin, vmax=vmax, cmap=cmap)
52 | plt.xticks([])
53 | plt.yticks([])
54 | # ax0.set_title(title, fontsize=5)
55 | # if not share:
56 | # fig.colorbar(heat0)
57 | # heat1 = ax1.pcolor(targ.reshape(dim,dim), vmin=vmin, vmax=vmax, cmap=cm.jet)
58 | fig = plt.gcf()
59 | ax = plt.gca()
60 |
61 | if title:
62 | ax.set_title(title)
63 | ax.invert_yaxis()
64 |
65 | fig.set_size_inches(4,4)
66 |
67 | # ax1.invert_yaxis()
68 | # ax1.set_title('target')
69 |
70 | # fig.subplots_adjust(right=0.8)
71 | # cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
72 | # fig.colorbar(heat1, cax=cbar_ax)
73 |
74 | # print 'saving to: ', fullpath
75 | plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0)
76 | plt.close(fig)
77 |
78 | # print pred.shape, targ.shape
79 |
80 | def vis_predictions(model, inputs, targets, instructions, save_path, prefix=''):
81 | ## wrap tensors in Variables to pass to model
82 | input_vars = ( Variable(tensor.contiguous()) for tensor in inputs )
83 | predictions = model(input_vars)
84 |
85 | ## convert to numpy arrays for saving to disk
86 | predictions = predictions.data.cpu().numpy()
87 | targets = targets.cpu().numpy()
88 |
89 | num_inputs = inputs[0].size(0)
90 | for ind in tqdm(range(num_inputs)):
91 | pred = predictions[ind]
92 | targ = targets[ind]
93 | instr = instructions[ind]
94 |
95 | full_path = os.path.join(save_path, prefix + str(ind) + '.png')
96 |
97 | vis_value_map(pred, targ, full_path, title=instr, share=False)
98 |
99 |
--------------------------------------------------------------------------------