├── DQN-HDRLN
│   ├── NeuralQLearner.lua
│   ├── Rectifier.lua
│   ├── Scale.lua
│   ├── TransitionTable.lua
│   ├── convnet.lua
│   ├── convnet_atari3.lua
│   ├── convnet_atari_main.lua
│   ├── initenv.lua
│   ├── net_downsample_2x_full_y.lua
│   ├── nnutils.lua
│   ├── run_gpu_dist
│   ├── test_agent.lua
│   ├── test_dist_agent.lua
│   ├── test_gpu_dist
│   └── train_agent.lua
├── Distillation process
│   ├── NeuralQLearner.lua
│   ├── README.md
│   ├── TransitionTable.lua
│   ├── convnet.lua
│   ├── convnet_atari3.lua
│   ├── convnet_atari_main.lua
│   ├── distill_agent.lua
│   ├── distill_gpu
│   ├── initenv.lua
│   └── test_distill
├── LICENSE
├── README.md
├── Utilities
│   ├── README.md
│   ├── activations_hdf5_to_tsne.lua
│   ├── clean_data.py
│   ├── cluster.py
│   └── learn_weights.py
└── graying_the_box
    ├── .gitignore
    ├── .idea
    │   ├── .name
    │   ├── encodings.xml
    │   ├── graying_the_box.iml
    │   ├── misc.xml
    │   ├── modules.xml
    │   ├── vcs.xml
    │   └── workspace.xml
    ├── LUA
    │   ├── dqn
    │   │   ├── NeuralQLearner.lua
    │   │   ├── Rectifier.lua
    │   │   ├── Scale.lua
    │   │   ├── TransitionTable.lua
    │   │   ├── convnet.lua
    │   │   ├── convnet_atari3.lua
    │   │   ├── initenv.lua
    │   │   ├── net_downsample_2x_full_y.lua
    │   │   ├── nnutils.lua
    │   │   ├── test_agent.lua
    │   │   ├── train_agent.lua
    │   │   └── train_agent_tmp.lua
    │   ├── logs
    │   │   └── Results
    │   │       ├── Plot.lua
    │   │       ├── Plot2.lua
    │   │       ├── Plot3.lua
    │   │       ├── Process.sh
    │   │       ├── Process2.sh
    │   │       ├── Process3.sh
    │   │       └── tmp
    │   │           ├── TD_Error.png
    │   │           ├── conv1.png
    │   │           ├── conv1_max.png
    │   │           ├── conv2.png
    │   │           ├── conv2_max.png
    │   │           ├── conv3.png
    │   │           ├── conv3_max.png
    │   │           ├── lin1.png
    │   │           ├── lin1_max.png
    │   │           ├── lin2.png
    │   │           ├── lin2_max.png
    │   │           ├── reward.png
    │   │           └── vavg.png
    │   ├── roms
    │   │   └── README
    │   └── run_gpu
    ├── bhtsne
    │   ├── Makefile.win
    │   ├── parse_lua_tensor.py
    │   ├── sptree.cpp
    │   ├── sptree.h
    │   ├── tsne.cpp
    │   ├── tsne.h
    │   ├── vptree.h
    │   └── write_movie.py
    ├── clustering.py
    ├── common.py
    ├── digraph.py
    ├── emhc.py
    ├── hand_crafted_features
    │   ├── add_breakout_buttons.py
    │   ├── add_global_features.py
    │   ├── add_global_features.py~
    │   ├── add_pacman_buttons.py
    │   ├── add_seaquest_buttons.py
    │   ├── control_buttons.py
    │   ├── label_states_breakout.py
    │   ├── label_states_packman.py
    │   └── label_states_seaquest.py
    ├── main.py
    ├── others
    │   ├── SaliencyScore.py
    │   ├── Score_trans.py
    │   ├── Seaquest_SaliencyDiver.py
    │   └── TermFig.py
    ├── prepare_data.py
    ├── prepare_global_features.py
    ├── smdp.py
    └── vis_tool.py
/DQN-HDRLN/Rectifier.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | --[[ Rectified Linear Unit.
8 |
9 | The output is max(0, input).
10 | --]]
11 |
12 | local Rectifier, parent = torch.class('nn.Rectifier', 'nn.Module')
13 |
14 | -- This module accepts minibatches
15 | function Rectifier:updateOutput(input)
16 | return self.output:resizeAs(input):copy(input):abs():add(input):div(2)
17 | end
18 |
19 | function Rectifier:updateGradInput(input, gradOutput)
20 | self.gradInput:resizeAs(self.output)
21 | return self.gradInput:sign(self.output):cmul(gradOutput)
22 | end
--------------------------------------------------------------------------------
/DQN-HDRLN/Scale.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require "nn"
8 | require "image"
9 |
10 | local scale = torch.class('nn.Scale', 'nn.Module')
11 |
12 |
13 | function scale:__init(height, width)
14 | self.height = height
15 | self.width = width
16 | end
17 |
18 | function scale:forward(x)
19 | local x = x
20 | if x:dim() > 3 then
21 | x = x[1]
22 | end
23 |
24 | x = image.rgb2y(x)
25 | x = image.scale(x, self.width, self.height, 'bilinear')
26 | return x
27 | end
28 |
29 | function scale:updateOutput(input)
30 | return self:forward(input)
31 | end
32 |
33 | function scale:float()
34 | end
35 |
--------------------------------------------------------------------------------
/DQN-HDRLN/convnet.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require "initenv"
8 |
9 | function create_network(args)
10 |
11 | local net = nn.Sequential()
12 | net:add(nn.Reshape(unpack(args.input_dims)))
13 |
14 | --- first convolutional layer
15 | local convLayer = nn.SpatialConvolution
16 |
17 | net:add(convLayer(args.hist_len*args.ncols, args.n_units[1],
18 | args.filter_size[1], args.filter_size[1],
19 | args.filter_stride[1], args.filter_stride[1],1))
20 | net:add(args.nl())
21 |
22 | -- Add convolutional layers
23 | for i=1,(#args.n_units-1) do
24 | -- second convolutional layer
25 | net:add(convLayer(args.n_units[i], args.n_units[i+1],
26 | args.filter_size[i+1], args.filter_size[i+1],
27 | args.filter_stride[i+1], args.filter_stride[i+1]))
28 | net:add(args.nl())
29 | end
30 |
31 | local nel
32 | if args.gpu >= 0 then
33 | nel = net:cuda():forward(torch.zeros(1,unpack(args.input_dims))
34 | :cuda()):nElement()
35 | else
36 | nel = net:forward(torch.zeros(1,unpack(args.input_dims))):nElement()
37 | end
38 |
39 | -- reshape all feature planes into a vector per example
40 | net:add(nn.Reshape(nel))
41 |
42 | -- fully connected layer
43 | net:add(nn.Linear(nel, args.n_hid[1]))
44 | net:add(args.nl())
45 | local last_layer_size = args.n_hid[1]
46 |
47 | for i=1,(#args.n_hid-1) do
48 | -- add Linear layer
49 | last_layer_size = args.n_hid[i+1]
50 | net:add(nn.Linear(args.n_hid[i], last_layer_size))
51 | net:add(args.nl())
52 | end
53 |
54 |     if args.distilled_network == true and false then -- note: the trailing 'and false' disables this branch, so the single Linear output head below is always used
55 | -- add the last fully connected layer (to actions)
56 | concat = nn.Concat(2)
57 |
58 |
59 | local MCgameActions = {1,3,4,5,0,6,7,8,9} --,5,6,7,8} -- this is our game actions table
60 | local MCgameActions_primitive = {1,3,4,0,5} --,5} -- this is our game actions table
61 | local optionsActions = {6,7,8,9} -- these actions are correlated to an OPTION, 20 = solve room (make this struct with max iterations per option and socket port and ip)
62 |
63 | local navigateActions = {1,3,4}
64 | local pickupActions = {1,3,4}
65 | local breakActions = {1,3,4,5}
66 | local placeActions = {1,3,4,0}
67 |
68 |
69 | --args.skills = {navigateActions, pickupActions, breakActions, placeActions}
70 |
71 | for i=1,#args.skills do
72 | print('Added new skill layer '..i..' with '..(args.skills[i])..' actions')
73 | skill = nn.Sequential()
74 | skill:add(nn.Linear(last_layer_size, (args.skills[i])))
75 | concat:add(skill)
76 | end
77 |
78 | net:add(concat)
79 | else
80 | -- add the last fully connected layer (to actions)
81 | net:add(nn.Linear(last_layer_size, args.n_actions))
82 | end
83 |
84 | if args.gpu >=0 then
85 | net:cuda()
86 | end
87 | if args.verbose >= 2 then
88 | print(net)
89 | print('Convolutional layers flattened output size:', nel)
90 | end
91 | return net
92 | end
93 |
--------------------------------------------------------------------------------
/DQN-HDRLN/convnet_atari3.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require 'convnet'
8 |
9 | return function(args)
10 | args.n_units = {32, 64, 64}
11 | args.filter_size = {8, 4, 3}
12 | args.filter_stride = {4, 2, 1}
13 | args.n_hid = {512}
14 | args.nl = nn.Rectifier
15 | args.distilled_network = true
16 | args.skills = {3,3,4,4}
17 | return create_network(args)
18 | end
19 |
20 |
--------------------------------------------------------------------------------
/DQN-HDRLN/convnet_atari_main.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require 'convnet'
8 |
9 | return function(args)
10 | args.n_units = {32, 64, 64}
11 | args.filter_size = {8, 4, 3}
12 | args.filter_stride = {4, 2, 1}
13 | args.n_hid = {512}
14 | args.nl = nn.Rectifier
15 | return create_network(args)
16 | end
17 |
--------------------------------------------------------------------------------
/DQN-HDRLN/initenv.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 | dqn = {}
7 |
8 | require 'torch'
9 | require 'nn'
10 | require 'nngraph'
11 | require 'nnutils'
12 | require 'image'
13 | require 'Scale'
14 | require 'NeuralQLearner'
15 | require 'TransitionTable'
16 | require 'Rectifier'
17 |
18 |
19 | function torchSetup(_opt)
20 | _opt = _opt or {}
21 | local opt = table.copy(_opt)
22 | assert(opt)
23 |
24 | -- preprocess options:
25 | --- convert options strings to tables
26 | if opt.pool_frms then
27 | opt.pool_frms = str_to_table(opt.pool_frms)
28 | end
29 | if opt.env_params then
30 | opt.env_params = str_to_table(opt.env_params)
31 | end
32 | if opt.agent_params then
33 | opt.agent_params = str_to_table(opt.agent_params)
34 | opt.agent_params.gpu = opt.gpu
35 | opt.agent_params.best = opt.best
36 | opt.agent_params.verbose = opt.verbose
37 | if opt.network ~= '' then
38 | opt.agent_params.network = opt.network
39 | end
40 | end
41 |
42 | --- general setup
43 | opt.tensorType = opt.tensorType or 'torch.FloatTensor'
44 | torch.setdefaulttensortype(opt.tensorType)
45 | if not opt.threads then
46 | opt.threads = 4
47 | end
48 | torch.setnumthreads(opt.threads)
49 | if not opt.verbose then
50 | opt.verbose = 10
51 | end
52 | if opt.verbose >= 1 then
53 | print('Torch Threads:', torch.getnumthreads())
54 | end
55 |
56 | --- set gpu device
57 | if opt.gpu and opt.gpu >= 0 then
58 | require 'cutorch'
59 | require 'cunn'
60 | if opt.gpu == 0 then
61 | local gpu_id = tonumber(os.getenv('GPU_ID'))
62 | if gpu_id then opt.gpu = gpu_id+1 end
63 | end
64 | if opt.gpu > 0 then cutorch.setDevice(opt.gpu) end
65 | opt.gpu = cutorch.getDevice()
66 | print('Using GPU device id:', opt.gpu-1)
67 | else
68 | opt.gpu = -1
69 | if opt.verbose >= 1 then
70 | print('Using CPU code only. GPU device id:', opt.gpu)
71 | end
72 | end
73 |
74 | --- set up random number generators
75 | -- removing lua RNG; seeding torch RNG with opt.seed and setting cutorch
76 | -- RNG seed to the first uniform random int32 from the previous RNG;
77 | -- this is preferred because using the same seed for both generators
78 | -- may introduce correlations; we assume that both torch RNGs ensure
79 | -- adequate dispersion for different seeds.
80 | math.random = nil
81 | opt.seed = opt.seed or 1
82 | torch.manualSeed(opt.seed)
83 | if opt.verbose >= 1 then
84 | print('Torch Seed:', torch.initialSeed())
85 | end
86 | local firstRandInt = torch.random()
87 | if opt.gpu >= 0 then
88 | cutorch.manualSeed(firstRandInt)
89 | if opt.verbose >= 1 then
90 | print('CUTorch Seed:', cutorch.initialSeed())
91 | end
92 | end
93 |
94 | return opt
95 | end
96 |
97 |
98 | function setup(_opt)
99 | assert(_opt)
100 |
101 | --preprocess options:
102 | --- convert options strings to tables
103 | _opt.pool_frms = str_to_table(_opt.pool_frms)
104 | _opt.env_params = str_to_table(_opt.env_params)
105 | _opt.agent_params = str_to_table(_opt.agent_params)
106 | _opt.skill_agent_params = str_to_table(_opt.skill_agent_params)
107 | if _opt.agent_params.transition_params then
108 | _opt.skill_agent_params.transition_params =
109 | str_to_table(_opt.agent_params.transition_params)
110 |
111 | _opt.agent_params.transition_params =
112 | str_to_table(_opt.agent_params.transition_params)
113 | end
114 |
115 | --- first things first
116 | local opt = torchSetup(_opt)
117 |
118 | local distilled_hdrln = _opt.distilled_hdrln
119 | if distilled_hdrln == "true" then
120 | distilled_hdrln = true
121 | else
122 | distilled_hdrln = false
123 | end
124 | local supervised_skills = _opt.supervised_skills
125 | if supervised_skills == "true" then
126 | supervised_skills = true
127 | else
128 | supervised_skills = false
129 | end
130 |
131 | local num_skills = tonumber(_opt.num_skills)
132 |    if _opt.supervised_file then
133 |        local myFile = hdf5.open(_opt.supervised_file, 'r') -- requires the torch-hdf5 package
134 | num_skills = myFile:read('numberSkills'):all()
135 | myFile:close()
136 | end
137 | local gameEnv = nil
138 |
139 | local MCgameActions_primitive = {1,3,4,0,5} -- this is our game actions table
140 |    local MCgameActions = table.copy(MCgameActions_primitive) -- plain Lua tables have no :copy(); use the helper defined at the bottom of this file
141 |    local options = {} -- these actions are correlated to an OPTION, i.e. an action the HDRLN selects that is mapped to a skill
142 |    local optionsActions = {} -- for each skill we map the available actions. Currently each skill maps all available actions (doesn't have to be this way)
143 |
144 | local max_action_val = MCgameActions_primitive[1]
145 | for i = 2, #MCgameActions_primitive
146 | do
147 |        max_action_val = math.max(max_action_val, MCgameActions_primitive[i])
148 | end
149 |
150 | for i = 1, num_skills
151 | do
152 | MCgameActions[#MCgameActions + 1] = i + max_action_val
153 | options[#options + 1] = i + max_action_val -- we want all actions mapped to skills to be larger than the maximal primitive action value
154 | optionsActions[#optionsActions + 1] = MCgameActions_primitive -- map all primitive actions for each skill
155 | end
156 |
157 |
158 | -- agent options
159 | _opt.agent_params.actions = MCgameActions
160 | _opt.agent_params.options = options
161 | _opt.agent_params.optionsActions = optionsActions
162 | _opt.agent_params.gpu = _opt.gpu
163 | _opt.agent_params.best = _opt.best
164 | _opt.agent_params.distilled_network = distilled_hdrln
165 | _opt.agent_params.distill = false
166 | if _opt.agent_params.network then
167 | print(_opt.agent_params.network)
168 | _opt.agent_params.network = "convnet_atari_main"
169 | end
170 | --_opt.agent_params.network = "convnet_atari3"
171 | if _opt.network ~= '' then
172 | _opt.agent_params.network = _opt.network
173 | end
174 |
175 | _opt.agent_params.verbose = _opt.verbose
176 | if not _opt.agent_params.state_dim then
177 | _opt.agent_params.state_dim = gameEnv:nObsFeature()
178 | end
179 | if distilled_hdrln then -- distilled means single main network with multiple skills integrated into it
180 | _opt.skill_agent_params.actions = MCgameActions_primitive
181 | _opt.skill_agent_params.gpu = _opt.gpu
182 | _opt.skill_agent_params.best = _opt.best
183 | _opt.skill_agent_params.distilled_network = true
184 | _opt.skill_agent_params.distill = false
185 | _opt.agent_params.supervised_skills = supervised_skills
186 |        _opt.agent_params.supervised_file = _opt.supervised_file
187 | if _opt.distilled_network ~= '' then
188 | _opt.skill_agent_params.network = _opt.distilled_network
189 | end
190 | _opt.skill_agent_params.verbose = _opt.verbose
191 | if not _opt.skill_agent_params.state_dim then
192 | _opt.skill_agent_params.state_dim = gameEnv:nObsFeature()
193 | end
194 | print("SKILL NETWORK")
195 | local distilled_agent = dqn[_opt.agent](_opt.skill_agent_params)
196 | print("END SKILL NETWORK")
197 |
198 | _opt.agent_params.skill_agent = distilled_agent
199 | else
200 | for i = 1, num_skills
201 | do
202 | _opt.skill_agent_params.actions = MCgameActions_primitive
203 | _opt.skill_agent_params.gpu = _opt.gpu
204 | _opt.skill_agent_params.best = _opt.best
205 | _opt.skill_agent_params.distilled_network = false
206 | _opt.skill_agent_params.distill = false
207 |            if _opt['skill_' .. i] and _opt['skill_' .. i] ~= '' then
208 |                _opt.skill_agent_params.network = _opt['skill_' .. i]
209 | end
210 | _opt.skill_agent_params.verbose = _opt.verbose
211 | if not _opt.skill_agent_params.state_dim then
212 | _opt.skill_agent_params.state_dim = gameEnv:nObsFeature()
213 | end
214 | print("SKILL NETWORK " .. i)
215 | local skill_agent = dqn[_opt.agent](_opt.skill_agent_params)
216 |            _opt.agent_params.skill_agent = _opt.agent_params.skill_agent or {}; _opt.agent_params.skill_agent[#_opt.agent_params.skill_agent + 1] = skill_agent
217 | print("END SKILL NETWORK " .. i)
218 | end
219 | end
220 |
221 | _opt.agent_params.primitive_actions = MCgameActions_primitive
222 | print("MAIN AGENT")
223 | local agent = dqn[_opt.agent](_opt.agent_params)
224 | print("MAIN AGENT")
225 | if opt.verbose >= 1 then
226 | print('Set up Torch using these options:')
227 | for k, v in pairs(opt) do
228 | print(k, v)
229 | end
230 | end
231 |
232 | return gameEnv, MCgameActions_primitive, agent, opt
233 | end
234 |
235 |
236 |
237 | --- other functions
238 |
239 | function str_to_table(str)
240 | if type(str) == 'table' then
241 | return str
242 | end
243 | if not str or type(str) ~= 'string' then
244 | if type(str) == 'table' then
245 | return str
246 | end
247 | return {}
248 | end
249 | local ttr
250 | if str ~= '' then
251 | local ttx=tt
252 | loadstring('tt = {' .. str .. '}')()
253 | ttr = tt
254 | tt = ttx
255 | else
256 | ttr = {}
257 | end
258 | return ttr
259 | end
260 |
261 | function table.copy(t)
262 | if t == nil then return nil end
263 | local nt = {}
264 | for k, v in pairs(t) do
265 | if type(v) == 'table' then
266 | nt[k] = table.copy(v)
267 | else
268 | nt[k] = v
269 | end
270 | end
271 | setmetatable(nt, table.copy(getmetatable(t)))
272 | return nt
273 | end
274 |
--------------------------------------------------------------------------------
/DQN-HDRLN/net_downsample_2x_full_y.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require "image"
8 | require "Scale"
9 |
10 | local function create_network(args)
11 | -- Y (luminance)
12 | return nn.Scale(84, 84, true)
13 | end
14 |
15 | return create_network
16 |
--------------------------------------------------------------------------------
/DQN-HDRLN/nnutils.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require "torch"
8 |
9 | function recursive_map(module, field, func)
10 | local str = ""
11 | if module[field] or module.modules then
12 | str = str .. torch.typename(module) .. ": "
13 | end
14 | if module[field] then
15 | str = str .. func(module[field])
16 | end
17 | if module.modules then
18 | str = str .. "["
19 | for i, submodule in ipairs(module.modules) do
20 | local submodule_str = recursive_map(submodule, field, func)
21 | str = str .. submodule_str
22 | if i < #module.modules and string.len(submodule_str) > 0 then
23 | str = str .. " "
24 | end
25 | end
26 | str = str .. "]"
27 | end
28 |
29 | return str
30 | end
31 |
32 | function abs_mean(w)
33 | return torch.mean(torch.abs(w:clone():float()))
34 | end
35 |
36 | function abs_max(w)
37 | return torch.abs(w:clone():float()):max()
38 | end
39 |
40 | -- Build a string of average absolute weight values for the modules in the
41 | -- given network.
42 | function get_weight_norms(module)
43 | return "Weight norms:\n" .. recursive_map(module, "weight", abs_mean) ..
44 | "\nWeight max:\n" .. recursive_map(module, "weight", abs_max)
45 | end
46 |
47 | -- Build a string of average absolute weight gradient values for the modules
48 | -- in the given network.
49 | function get_grad_norms(module)
50 | return "Weight grad norms:\n" ..
51 | recursive_map(module, "gradWeight", abs_mean) ..
52 | "\nWeight grad max:\n" .. recursive_map(module, "gradWeight", abs_max)
53 | end
54 |
--------------------------------------------------------------------------------
/DQN-HDRLN/run_gpu_dist:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -z "$1" ]
4 | then echo "Please provide the name of the game, e.g. ./run_gpu breakout "; exit 0
5 | fi
6 | ENV=$1
7 | FRAMEWORK="alewrap"
8 |
9 | game_path=$PWD"/roms/"
10 | env_params="useRGB=true"
11 | agent="NeuralQLearner"
12 | n_replay=16
13 | netfile="\"convnet_atari3\""
14 | update_freq=4
15 | actrep=1
16 | discount=0.99
17 | seed=1
18 | learn_start=20000 #2000 #20000 #5000 #50000
19 | pool_frms_type="\"max\""
20 | pool_frms_size=2
21 | initial_priority="false"
22 | replay_memory=100000 #1000000
23 | eps_end=0.1
24 | eps_endt=100000 #400000 #75000 #replay_memory
25 | lr=0.0025 #0.00025
26 | agent_type="DQN3_0_1"
27 | preproc_net="\"net_downsample_2x_full_y\""
28 | agent_name=$agent_type"_"$1"_FULL_Y"
29 | state_dim=7056 #dimensions of image we send widthXheight
30 | ncols=1
31 | ddqn=1
32 | reward_shaping=0
33 | option_length=5
34 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,ddqn="$ddqn",option_length="$option_length"" #,min_reward=-2000,max_reward=2000" #minibatch_size=32,min_reward=-1,max_reward=1,hist_len=4
35 | steps=50000000
36 | eval_freq=5000 #5000 #20000 #250000
37 | eval_steps=1000 #1000 #125000
38 | prog_freq=10000 # how often it outputs to stdout
39 | save_freq=50000
40 | gpu=0
41 | random_starts=30
42 | pool_frms="type="$pool_frms_type",size="$pool_frms_size
43 | num_threads=4
44 | NETWORK=$agent_name
45 | skill_agent_params=$agent_params",network=\"convnet_atari3\""
46 | agent_params=$agent_params",network="$netfile
47 | distilled_hdrln="true" # this tells us if HDRLN setup is using a distilled agent or using multiple DQN array
48 | supervised_skills="true" # this goes together with distilled agent
49 | supervised_file="/home/deep5/DQN_Shahar_Chen_oldpc/dqn_distill/skillWeights.h5"
50 | num_skills=4 # how many skills, relevant both for distilled and non distilled setup
51 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -skill_agent_params $skill_agent_params -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -reward_shaping $reward_shaping -distilled_hdrln $distilled_hdrln -supervised_skills $supervised_skills -supervised_file $supervised_file"
52 | args=$args" -socket $2"
53 |
54 | args=$args" -num_skills $num_skills -skill_1 pickup.t7 -skill_2 navigate.t7 -skill_3 break.t7 -skill_4 break.t7 -distilled_network DQN3_0_1_hdrln_dist_fix_size_FULL_Y.t7"; #distilled_network_0.1temp.t7";
55 |
56 | echo $args
57 |
58 | cd dqn_distill
59 | ../torch/bin/qlua train_agent.lua $args
60 |
--------------------------------------------------------------------------------
/DQN-HDRLN/test_gpu_dist:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -z "$1" ]
4 | then echo "Please provide the name of the game, e.g. ./watch_pretrained breakout"; exit 0
5 | fi
6 |
7 | if [ -z "$2" ]
8 | then echo "Please provide the pretrained network file, e.g. ./watch_pretrained breakout DQN3_0_1_breakout_FULL_Y.t7"; exit 0
9 | fi
10 |
11 | ENV=$1
12 | NETWORK=$2
13 | FRAMEWORK="alewrap"
14 |
15 | game_path=$PWD"/roms/"
16 | env_params="useRGB=true"
17 | agent="NeuralQLearner"
18 | n_replay=16
19 | netfile="\"convnet_atari3\""
20 | update_freq=4
21 | actrep=1
22 | discount=0.99
23 | seed=1
24 | learn_start=50000
25 | pool_frms_type="\"max\""
26 | pool_frms_size=2
27 | initial_priority="false"
28 | replay_memory=500010
29 | eps_end=0
30 | eps_endt=0
31 | lr=0.00025
32 | agent_type="DQN3_0_1"
33 | preproc_net="\"net_downsample_2x_full_y\""
34 | agent_name=$agent_type"_"$1"_FULL_Y"
35 | state_dim=7056
36 | ncols=1
37 | reward_shaping=0
38 | ddqn=1
39 | option_length=5
40 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-300,max_reward=300,ddqn="$ddqn",option_length="$option_length""
41 | gif_file="../gifs/$ENV."
42 | gpu=0
43 | random_starts=30
44 | pool_frms="type="$pool_frms_type",size="$pool_frms_size
45 | num_threads=4
46 | skill_agent_params=$agent_params",network=\"convnet_atari3\""
47 |
48 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -skill_agent_params $skill_agent_params -agent_params $agent_params -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -network $NETWORK -gif_file $gif_file -reward_shaping $reward_shaping"
49 | args=$args" -socket $3"
50 | args=$args" -distilled_network distilled_network_0.1temp.t7";
51 |
52 |
53 | echo $args
54 |
55 | cd dqn_distill
56 | ../torch/bin/qlua test_dist_agent.lua $args
57 |
--------------------------------------------------------------------------------
/Distillation process/README.md:
--------------------------------------------------------------------------------
1 | # Distilled Multi-Skill Agent Process
2 | __This is in the development stage__
3 |
4 | For further explanation see the [website](http://chentessler.wixsite.com/hdrlnminecraft).
5 |
6 | This code requires a lot of refactoring and cleanup.
7 | Let your "teacher" agents play for N time steps and record the data (input: state, output: Q-values).
8 | Then use this code to load the recorded data and train the new distilled "student".
9 |
10 | The student is currently trained with the MSE loss function; this should be swapped for a KL-divergence loss (a sketch of both losses follows this file).
11 |
--------------------------------------------------------------------------------
/Distillation process/convnet.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require "initenv"
8 |
9 | function create_network(args)
10 |
11 | local net = nn.Sequential()
12 | net:add(nn.Reshape(unpack(args.input_dims)))
13 |
14 | --- first convolutional layer
15 | local convLayer = nn.SpatialConvolution
16 |
17 | net:add(convLayer(args.hist_len*args.ncols, args.n_units[1],
18 | args.filter_size[1], args.filter_size[1],
19 | args.filter_stride[1], args.filter_stride[1],1))
20 | net:add(args.nl())
21 |
22 | -- Add convolutional layers
23 | for i=1,(#args.n_units-1) do
24 | -- second convolutional layer
25 | net:add(convLayer(args.n_units[i], args.n_units[i+1],
26 | args.filter_size[i+1], args.filter_size[i+1],
27 | args.filter_stride[i+1], args.filter_stride[i+1]))
28 | net:add(args.nl())
29 | end
30 |
31 | local nel
32 | if args.gpu >= 0 then
33 | nel = net:cuda():forward(torch.zeros(1,unpack(args.input_dims))
34 | :cuda()):nElement()
35 | else
36 | nel = net:forward(torch.zeros(1,unpack(args.input_dims))):nElement()
37 | end
38 |
39 | -- reshape all feature planes into a vector per example
40 | net:add(nn.Reshape(nel))
41 |
42 | -- fully connected layer
43 | net:add(nn.Linear(nel, args.n_hid[1]))
44 | net:add(args.nl())
45 | local last_layer_size = args.n_hid[1]
46 |
47 | for i=1,(#args.n_hid-1) do
48 | -- add Linear layer
49 | last_layer_size = args.n_hid[i+1]
50 | net:add(nn.Linear(args.n_hid[i], last_layer_size))
51 | net:add(args.nl())
52 | end
53 |
54 |     if args.distilled_network == true and false then -- note: the trailing 'and false' disables this branch, so the single Linear output head below is always used
55 | -- add the last fully connected layer (to actions)
56 | concat = nn.Concat(2)
57 |
58 |
59 | local MCgameActions = {1,3,4,5,0,6,7,8,9} --,5,6,7,8} -- this is our game actions table
60 | local MCgameActions_primitive = {1,3,4,0,5} --,5} -- this is our game actions table
61 | local optionsActions = {6,7,8,9} -- these actions are correlated to an OPTION, 20 = solve room (make this struct with max iterations per option and socket port and ip)
62 |
63 | local navigateActions = {1,3,4}
64 | local pickupActions = {1,3,4}
65 | local breakActions = {1,3,4,5}
66 | local placeActions = {1,3,4,0}
67 |
68 |
69 | --args.skills = {navigateActions, pickupActions, breakActions, placeActions}
70 |
71 | for i=1,#args.skills do
72 | print('Added new skill layer '..i..' with '..(args.skills[i])..' actions')
73 | skill = nn.Sequential()
74 | skill:add(nn.Linear(last_layer_size, (args.skills[i])))
75 | concat:add(skill)
76 | end
77 |
78 | net:add(concat)
79 | else
80 | -- add the last fully connected layer (to actions)
81 | net:add(nn.Linear(last_layer_size, args.n_actions))
82 | end
83 |
84 | if args.gpu >=0 then
85 | net:cuda()
86 | end
87 | if args.verbose >= 2 then
88 | print(net)
89 | print('Convolutional layers flattened output size:', nel)
90 | end
91 | return net
92 | end
93 |
--------------------------------------------------------------------------------
/Distillation process/convnet_atari3.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require 'convnet'
8 |
9 | return function(args)
10 | args.n_units = {32, 64, 64}
11 | args.filter_size = {8, 4, 3}
12 | args.filter_stride = {4, 2, 1}
13 | args.n_hid = {512}
14 | args.nl = nn.Rectifier
15 | args.distilled_network = true
16 | args.skills = {3,3,4,4}
17 | return create_network(args)
18 | end
19 |
20 |
--------------------------------------------------------------------------------
/Distillation process/convnet_atari_main.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require 'convnet'
8 |
9 | return function(args)
10 | args.n_units = {32, 64, 64}
11 | args.filter_size = {8, 4, 3}
12 | args.filter_stride = {4, 2, 1}
13 | args.n_hid = {512}
14 | args.nl = nn.Rectifier
15 | return create_network(args)
16 | end
17 |
--------------------------------------------------------------------------------
/Distillation process/distill_gpu:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -z "$1" ]
4 | then echo "Please provide the name of the game, e.g. ./run_gpu breakout "; exit 0
5 | fi
6 | ENV=$1
7 | FRAMEWORK="alewrap"
8 |
9 | game_path=$PWD"/roms/"
10 | env_params="useRGB=true"
11 | agent="NeuralQLearner"
12 | n_replay=16
13 | netfile="\"convnet_atari3\""
14 | update_freq=4
15 | actrep=1
16 | discount=0.99
17 | seed=1
18 | learn_start=2000 #20000 #5000 #50000
19 | pool_frms_type="\"max\""
20 | pool_frms_size=2
21 | initial_priority="false"
22 | replay_memory=100000 #1000000
23 | eps_end=0.1
24 | eps_endt=100000 #400000 #75000 #replay_memory
25 | lr=0.00025
26 | agent_type="DQN3_0_1"
27 | preproc_net="\"net_downsample_2x_full_y\""
28 | agent_name=$agent_type"_"$1"_FULL_Y"
29 | state_dim=7056 #dimensions of image we send widthXheight
30 | ncols=1
31 | ddqn=1
32 | reward_shaping=1
33 | option_length=5
34 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=1,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,ddqn="$ddqn",option_length="$option_length"" #,min_reward=-2000,max_reward=2000" #minibatch_size=32,min_reward=-1,max_reward=1,hist_len=4
35 | steps=50000000
36 | eval_freq=5000 #5000 #20000 #250000
37 | eval_steps=1000 #1000 #125000
38 | prog_freq=10000 # how often it outputs to stdout
39 | save_freq=50000
40 | gpu=2
41 | random_starts=30
42 | pool_frms="type="$pool_frms_type",size="$pool_frms_size
43 | num_threads=4
44 | NETWORK=$agent_name
45 |
46 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -skill_agent_params $agent_params -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -reward_shaping $reward_shaping"
47 | args=$args" -socket $2"
48 | if [ -n "$3" ]
49 | then args=$args" -pickup_network pickup.t7 -navigate_network navigate.t7 -break_network break.t7 -place_network break.t7";
50 | fi
51 | echo $args
52 |
53 | cd dqn_distill
54 | ../torch/bin/qlua distill_agent.lua $args
55 |
--------------------------------------------------------------------------------
/Distillation process/test_distill:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -z "$1" ]
4 | then echo "Please provide the name of the game, e.g. ./run_gpu breakout "; exit 0
5 | fi
6 | ENV=$1
7 | FRAMEWORK="alewrap"
8 |
9 | game_path=$PWD"/roms/"
10 | env_params="useRGB=true"
11 | agent="NeuralQLearner"
12 | n_replay=16
13 | netfile="\"convnet_atari3\""
14 | update_freq=4
15 | actrep=1
16 | discount=0.99
17 | seed=1
18 | learn_start=2000 #20000 #5000 #50000
19 | pool_frms_type="\"max\""
20 | pool_frms_size=2
21 | initial_priority="false"
22 | replay_memory=100000 #1000000
23 | eps_end=0.1
24 | eps_endt=100000 #400000 #75000 #replay_memory
25 | lr=0.0025 #0.00025
26 | agent_type="DQN3_0_1"
27 | preproc_net="\"net_downsample_2x_full_y\""
28 | agent_name=$agent_type"_"$1"_FULL_Y"
29 | state_dim=7056 #dimensions of image we send widthXheight
30 | ncols=1
31 | ddqn=1
32 | reward_shaping=1
33 | option_length=5
34 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,ddqn="$ddqn",option_length="$option_length"" #,min_reward=-2000,max_reward=2000" #minibatch_size=32,min_reward=-1,max_reward=1,hist_len=4
35 | steps=50000000
36 | eval_freq=5000 #5000 #20000 #250000
37 | eval_steps=1000 #1000 #125000
38 | prog_freq=10000 # how often it outputs to stdout
39 | save_freq=50000
40 | gpu=2
41 | random_starts=30
42 | pool_frms="type="$pool_frms_type",size="$pool_frms_size
43 | num_threads=4
44 | NETWORK=$2
45 |
46 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -skill_agent_params $agent_params -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -reward_shaping $reward_shaping -network $NETWORK"
47 | args=$args" -socket $3"
48 |
49 | args=$args" -pickup_network pickup.t7 -navigate_network navigate.t7 -break_network break.t7 -place_network break.t7 -distilled_network distilled_network_0.1temp_batchof1.t7"
50 |
51 | echo $args
52 |
53 | cd dqn_distill
54 | ../torch/bin/qlua test_agent.lua $args
55 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Chen Tessler
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # H-DRLN
2 | __This is in the development stage__
3 |
4 | The H-DRLN (Hierarchical Deep RL Network) is a novel architecture that incorporates options (skills) within the DQN.
5 | For further explanation see the [website](http://chentessler.wixsite.com/hdrlnminecraft).
6 |
7 | ### Graphic view of created skill clusters
8 | We build on the work of [Zahavy et al.](https://arxiv.org/abs/1602.02658) and adapt the [interface they created](https://github.com/bentzinir/graying_the_box) to graphically view the skill clusters on the t-SNE plot.
9 |
10 | ### References
11 | - [A Deep Hierarchical Approach to Lifelong Learning in Minecraft - Tessler et al](https://arxiv.org/abs/1604.07255)
12 | - [Human-level control through deep reinforcement learning. Nature 518(7540):529–533 - Mnih et al](https://arxiv.org/abs/1312.5602)
13 | - [Graying the black box: Understanding DQNs. Proceedings of the 33rd International Conference on Machine Learning (ICML) - Zahavy et al](https://arxiv.org/abs/1602.02658)
14 |
--------------------------------------------------------------------------------
/Utilities/README.md:
--------------------------------------------------------------------------------
1 | # Utilities
2 | __This is in the development stage__
3 |
4 | - activations_hdf5_to_tsne.lua - use this to extract a 2D t-SNE map of the activations (it also contains some code to convert the data structure from t7 to HDF5 if needed).
5 | - clean_data.py - runs over the data and keeps only trajectories that end in success; we don't want to learn a skill from the agent getting stuck in a corner.
6 | - cluster.py - runs multiple Gaussian Mixture Models to find the number of clusters that best fits the data (see the sketch after this file).
7 | - learn_weights.py - TensorFlow script that learns the weights and saves them. The weights are an extension of the "expert network".
8 |
--------------------------------------------------------------------------------
/Utilities/activations_hdf5_to_tsne.lua:
--------------------------------------------------------------------------------
1 | require 'hdf5'
2 | require 'unsup'
3 | m = require 'manifold'
4 | bh_tsne = require('tsne')
5 |
6 | local cmd = torch.CmdLine()
7 | cmd:text()
8 | cmd:option('-N', '', 'number of states')
9 | cmd:text()
10 | cmd:option('-iter', '', 'number of iterations')
11 | cmd:option('-convert', false, 'convert t7 to hdf5')
12 | local args = cmd:parse(arg)
13 | local N = tonumber(args.N)
14 | local maxiter = tonumber(args.iter)
15 |
16 | local pca_dims = 50
17 |
18 | local convert_t7_to_hdf5 = args.convert
19 |
20 | if convert_t7_to_hdf5 then -- myFile/data are left global here so they remain visible after the if/else
21 |   myFile = torch.DiskFile('./dqn_distill/hdrln_activations.t7', 'r')
22 |   data = myFile:readObject()--myFile:read('data'):all()
23 | else
24 |   myFile = hdf5.open('./dqn_distill/activationsClean.h5', 'r')
25 |   data = myFile:read('data'):all()
26 | end
27 | myFile:close()
28 |
29 | print('Data loaded ' .. data:size(1) .. ' ' .. data:size(2))
30 |
31 | if convert_t7_to_hdf5 then
32 | data = data:narrow(1, 1, N):double()
33 | local myFile2= hdf5.open('./dqn_distill/activations.h5', 'w')
34 | myFile2:write('data', data)
35 | myFile2:close()
36 | print('saved activations')
37 |
38 | local myFile = torch.DiskFile('./dqn_distill/hdrln_actions.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r')
39 | local data = myFile:readObject()--myFile:read('data'):all()
40 | myFile:close()
41 | data = data:narrow(1, 1, N)
42 | local myFile2= hdf5.open('./dqn_distill/actions.h5', 'w')
43 | myFile2:write('data', data)
44 | myFile2:close()
45 | print('saved actions')
46 |
47 | local myFile = torch.DiskFile('./dqn_distill/hdrln_fullstates.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r')
48 | local data = myFile:readObject()--myFile:read('data'):all()
49 | myFile:close()
50 | data = data:narrow(1, 1, N)
51 | local myFile2= hdf5.open('./dqn_distill/screens.h5', 'w')
52 | myFile2:write('data', data)
53 | myFile2:close()
54 | print('saved screens')
55 |
56 | local myFile = torch.DiskFile('./dqn_distill/hdrln_qvals.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r')
57 | local data = myFile:readObject()--myFile:read('data'):all()
58 | myFile:close()
59 | data = data:narrow(1, 1, N)
60 | local myFile2= hdf5.open('./dqn_distill/qvals.h5', 'w')
61 | myFile2:write('data', data)
62 | myFile2:close()
63 | print('saved qvals')
64 |
65 | local myFile = torch.DiskFile('./dqn_distill/hdrln_rewards.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r')
66 | local data = myFile:readObject()--myFile:read('data'):all()
67 | myFile:close()
68 | data = data:narrow(1, 1, N)
69 | local myFile2= hdf5.open('./dqn_distill/reward.h5', 'w')
70 | myFile2:write('data', data)
71 | myFile2:close()
72 | print('saved reward')
73 |
74 | local myFile = torch.DiskFile('./dqn_distill/hdrln_statespace.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r')
75 | local data = myFile:readObject()--myFile:read('data'):all()
76 | myFile:close()
77 | data = data:narrow(1, 1, N)
78 | local myFile2= hdf5.open('./dqn_distill/states.h5', 'w')
79 | myFile2:write('data', data)
80 | myFile2:close()
81 | print('saved states')
82 |
83 | local myFile = torch.DiskFile('./dqn_distill/hdrln_terminal.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r')
84 | local data = myFile:readObject()--myFile:read('data'):all()
85 | myFile:close()
86 | data = data:narrow(1, 1, N)
87 | local myFile2= hdf5.open('./dqn_distill/termination.h5', 'w')
88 | myFile2:write('data', data)
89 | myFile2:close()
90 | print('saved termination')
91 | else
92 | data = data:double()
93 |
94 | local p = 512
95 | -- perform pca on the transpose
96 | local mean = torch.mean(data,1)
97 | local xm = data - torch.ger(torch.ones(data:size(1)),mean:squeeze())
98 | local c = torch.mm(xm:t(),xm)
99 | c:div(data:size(1)-1)
100 | local ce,cv = torch.symeig(c,'V')
101 |
102 | cv = cv:narrow(2,p-pca_dims+1,pca_dims)
103 | data = torch.mm(data, cv)
104 |
105 | --
106 | --opts = {ndims = 2, perplexity = 50,pca = 50, use_bh = true, theta = 0.5}
107 | for perp = 750,850,100
108 | do
109 | opts = {ndims = 2, perplexity = perp, use_bh = true, pca=pca_dims, theta = 0.5, n_iter=maxiter, max_iter = maxiter, method='barnes_hut'}
110 | print('run t-SNE')
111 | --mapped_activations = m.embedding.tsne(data:double(), opts)
112 | mapped_activations = bh_tsne(data:double(), opts)
113 | print('save t-SNE')
114 | local myFile2= hdf5.open('./dqn_distill/lowd_activations_'..perp..'.h5', 'w')
115 | myFile2:write('data', mapped_activations)
116 | myFile2:close()
117 | print('saved t-SNE')
118 | end
119 | end
120 |
--------------------------------------------------------------------------------
/Utilities/clean_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 |
4 | debug = False
5 |
6 | termination = h5py.File('./termination.h5', 'r').get('data')
7 | reward = h5py.File('./reward.h5', 'r').get('data')
8 | activations = h5py.File('./activations.h5', 'r').get('data')
9 | actions = h5py.File('./actions.h5', 'r').get('data')
10 | qvals = h5py.File('./qvals.h5', 'r').get('data')
11 | states = h5py.File('./states.h5', 'r').get('data')
12 |
13 | if debug:
14 | print('Shape of termination: \n', (np.array(termination)).shape)
15 | print('Shape of activations: \n', (np.array(activations)).shape)
16 | print('Shape of actions: \n', (np.array(actions)).shape)
17 | print('Shape of reward: \n', (np.array(reward)).shape)
18 | print('Shape of qvals: \n', (np.array(qvals)).shape)
19 | print('Shape of states: \n', (np.array(states)).shape)
20 |
21 | startTrajectory = [] # first index in a trajectory that leads to success
22 | endTrajectory = [] # last index in a trajectory that leads to success
23 |
24 | initialIndex = 0
25 | for i in range(len(termination) - 1):
26 | if debug:
27 | print(str(i) + ' , ' + str(termination[i]) + ' , ' + str(reward[i]))
28 | if termination[i + 1] == 1 and reward[i] == 0:
29 | startTrajectory.append(initialIndex)
30 | endTrajectory.append(i)
31 | if debug:
32 | print('Success: ' + str(initialIndex) + ' , ' + str(i))
33 | initialIndex = i + 2
34 | elif termination[i + 1] == 1:
35 | initialIndex = i + 2
36 |
37 | rewardClean = []
38 | activationsClean = []
39 | qvalsClean = []
40 | statesClean = []
41 | actionsClean = []
42 |
43 | totalStates = 0
44 |
45 | for i in range(len(startTrajectory)):
46 | totalStates += endTrajectory[i] - startTrajectory[i] + 1
47 |     for j in range(startTrajectory[i], endTrajectory[i] + 1):  # include the final state, matching the count above
48 | rewardClean.append(reward[j])
49 | activationsClean.append(activations[j, :])
50 | qvalsClean.append(qvals[j])
51 | actionsClean.append(actions[j])
52 | statesClean.append(states[j, :])
53 |
54 | rCleanFile = h5py.File('rewardClean.h5', 'w')
55 | rCleanFile.create_dataset('data', data=rewardClean)
56 |
57 | aCleanFile = h5py.File('actionsClean.h5', 'w')
58 | aCleanFile.create_dataset('data', data=actionsClean)
59 |
60 | actCleanFile = h5py.File('activationsClean.h5', 'w')
61 | actCleanFile.create_dataset('data', data=activationsClean)
62 |
63 | qCleanFile = h5py.File('qvalsClean.h5', 'w')
64 | qCleanFile.create_dataset('data', data=qvalsClean)
65 |
66 | sCleanFile = h5py.File('statesClean.h5', 'w')
67 | sCleanFile.create_dataset('data', data=statesClean)
68 |
69 | print('Done! Total states kept: ' + str(totalStates))
70 |
--------------------------------------------------------------------------------
/Utilities/learn_weights.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import sys
7 | import h5py
8 | import numpy as np
9 | import math
10 |
11 | import tensorflow as tf
12 |
13 | FLAGS = None
14 |
15 | def main(_):
16 | hiddenWidth1 = 100
17 | hiddenWidth2 = 64
18 | outputWidth = 5
19 | weightInit = -1
20 | batchSize = 4
21 | gamma = 0.7
22 |
23 | dataOut = h5py.File('skillWeightsQ.h5', 'w') #_2Layer.h5', 'w')
24 |
25 | # Import data
26 | print('Loading data...')
27 | data = h5py.File(FLAGS.file, 'r')
28 | numSkills = data.get('numberSkills')
29 | print('Number of skills is ' + str(numSkills[()]))
30 |
31 | dataOut.create_dataset('hiddenWidth', data=hiddenWidth1)
32 | dataOut.create_dataset('numberSkills', data=numSkills)
33 |
34 | for skill in range(numSkills[()]):
35 | activations = np.array(data.get('activations_' + str(skill)))
36 | actions = (np.array(data.get('actions_' + str(skill))) - 1)
37 | termination = np.array(data.get('termination_' + str(skill)))
38 |
39 | print('Creating model...')
40 | # Create the model
41 | step = tf.Variable(0, trainable=False) # cant attach non trainable variable to gpu
42 | with tf.device('/gpu:1'):
43 | x = tf.placeholder(tf.float32, [None, 512, ])
44 |
45 |
46 | # Hidden Layer1
47 | W_hidden1 = tf.Variable(tf.truncated_normal([512, hiddenWidth1], stddev=0.1))
48 | b_hidden1 = tf.Variable(tf.constant(0.1, shape=[hiddenWidth1]))
49 | y_hidden1 = tf.add(tf.matmul(x, W_hidden1), b_hidden1)
50 | act_hidden1 = tf.nn.relu(y_hidden1)
51 |
52 | # Hidden Layer2
53 | W_hidden2 = tf.Variable(tf.random_uniform([hiddenWidth1, hiddenWidth2], weightInit, 1))
54 | b_hidden2 = tf.Variable(tf.random_uniform([hiddenWidth2], weightInit, 1))
55 | y_hidden2 = tf.add(tf.matmul(act_hidden1, W_hidden2), b_hidden2)
56 | act_hidden2 = tf.nn.relu(y_hidden2)
57 |
58 | # Output Layer
59 | W_output = tf.Variable(tf.truncated_normal([hiddenWidth1, outputWidth], stddev=0.1))
60 | #W_output = tf.Variable(tf.truncated_normal([hiddenWidth2, outputWidth], stddev=0.1))
61 | b_output = tf.Variable(tf.constant(0.1, shape=[outputWidth]))
62 | y = tf.add(tf.matmul(act_hidden1, W_output), b_output)
63 | #y = tf.add(tf.matmul(act_hidden2, W_output), b_output)
64 | predict = tf.argmax(y, 1)
65 | '''
66 |
67 | # Linear only
68 | W = tf.Variable(tf.random_uniform([512, outputWidth], weightInit, 0.01))
69 | b = tf.Variable(tf.random_uniform([outputWidth], weightInit, 0.01))
70 | y = tf.add(tf.matmul(x, W), b)
71 | predict = tf.argmax(y, 1)
72 | '''
73 | nextQ = tf.placeholder(shape=[None, outputWidth, ], dtype=tf.float32)
74 | loss = tf.reduce_sum(tf.square(nextQ - y))
75 | #loss = tf.nn.softmax_cross_entropy_with_logits(y, nextQ)
76 |
77 | rate = tf.train.exponential_decay(0.0005, step, 250, 0.9999)
78 | trainer = tf.train.AdamOptimizer(rate) #learning_rate=0.000001) # GradientDescentOptimizer(learning_rate=0.0001)
79 | updateModel = trainer.minimize(loss, global_step=step)
80 |
81 | # train_step = tf.train.AdamOptimizer().minimize(cross_entropy)
82 | # train_step = tf.train.RMSPropOptimizer(0.1).minimize(cross_entropy)
83 | # train_step = tf.train.GradientDescentOptimizer(0.1).minimize(cross_entropy)
84 |
85 | sess = tf.InteractiveSession(config=tf.ConfigProto(log_device_placement=True))
86 | #tf.global_variables_initializer().run()
87 | tf.initialize_all_variables().run()
88 | #sess.run(tf.initialize_all_variables())
89 | # Train
90 | maxQ = 1
91 | iteration = 0
92 | print('Training...')
93 | for _ in range(20000000):
94 | if (_ % 1000000 == 0 and _ > 0): # and False):
95 | testPredictions = sess.run(predict, feed_dict={x: activations[int(math.ceil(activations.shape[0] * 0.8)) + 1:activations.shape[0],:]})
96 | trainPredictions = sess.run(predict, feed_dict={x: activations[0:int(math.ceil(activations.shape[0] * 0.8)),:]})
97 | print('Done ' + str(_) + ' iterations. testing error is: ' + str(100 * np.sum(np.sign(np.absolute(testPredictions - actions[int(math.ceil(activations.shape[0] * 0.8)) + 1:activations.shape[0]]))) * 1.0 / (activations.shape[0] - int(math.ceil(activations.shape[0] * 0.8)) + 1)) + '%, training error is: ' + str(100 * np.sum(np.sign(np.absolute(trainPredictions - actions[0:int(math.ceil(activations.shape[0] * 0.8))]))) * 1.0 / (int(math.ceil(activations.shape[0] * 0.8)))) + '%')
98 | print('Loss: ' + str(loss_val) + ', Skill#: ' + str(skill))
99 |
100 | index = np.random.randint(int(math.ceil(activations.shape[0] * 0.8)), size=batchSize)
101 | '''
102 | iteration = iteration + 1
103 | iteration = iteration % int(math.ceil(activations.shape[0] * 0.8))
104 | index = np.array([iteration])
105 | '''
106 |
107 | allQ = sess.run(y,feed_dict={x: activations[index, :]})
108 |
109 | Q1 = sess.run(y,feed_dict={x: activations[index + 1, :]})
110 | targetQ = np.ones(allQ.shape) * -1
111 | #targetQ = allQ
112 | for i in range(index.shape[0]):
113 | if termination[index[i]] == 1:
114 | Q = 0
115 | else:
116 | Q = np.max(Q1[i, :]) * gamma
117 |
118 | # maxQ = max(maxQ, abs(Q))
119 | targetQ[i, :] = targetQ[i, :] + Q - gamma * gamma
120 | targetQ[i, int(actions[index[i]])] = targetQ[i, int(actions[index[i]])] + gamma * gamma
121 | targetQ = targetQ * 1.0 / maxQ
122 |
123 | '''
124 | targetQ = np.zeros(allQ.shape)
125 | for i in range(index.shape[0]):
126 | targetQ[i, int(actions[index[i]])] = 1
127 | '''
128 |
129 | _, loss_val = sess.run([updateModel, loss], feed_dict={x: activations[index, :], nextQ: targetQ})
130 |
131 | # Test trained model
132 | print('Testing model on ' + str(len(actions[int(math.ceil(activations.shape[0] * 0.8)) + 1:activations.shape[0]])) + ' samples...')
133 |
134 | prediction = tf.argmax(y,1)
135 | predictions = prediction.eval(feed_dict={x: activations[int(math.ceil(activations.shape[0] * 0.8)) + 1:activations.shape[0],:]}, session=sess)
136 | # print(predictions)
137 | # print(actions[int(math.ceil(np.array(activations).shape[0] * 0.8)) + 1:np.array(activations).shape[0]])
138 | print('Testing error:')
139 | print(100 * np.sum(np.sign(np.absolute(predictions - actions[int(math.ceil(activations.shape[0] * 0.8)) + 1:activations.shape[0]]))) * 1.0 / (activations.shape[0] - int(math.ceil(activations.shape[0] * 0.8)) + 1))
140 | print('Training error:')
141 | predictions = prediction.eval(feed_dict={x: activations[0:int(math.ceil(activations.shape[0] * 0.8)),:]}, session=sess)
142 | print(100 * np.sum(np.sign(np.absolute(predictions - actions[0:int(math.ceil(activations.shape[0] * 0.8))]))) * 1.0 / (int(math.ceil(activations.shape[0] * 0.8))))
143 |
144 | dataOut.create_dataset('W_hidden_' + str(skill), data=sess.run(W_hidden1))
145 | dataOut.create_dataset('b_hidden_' + str(skill), data=sess.run(b_hidden1))
146 | dataOut.create_dataset('W_output_' + str(skill), data=sess.run(W_output))
147 | dataOut.create_dataset('b_output_' + str(skill), data=sess.run(b_output))
148 |
149 | if __name__ == '__main__':
150 | parser = argparse.ArgumentParser()
151 | parser.add_argument('-file', type=str, required=True,
152 | help='Name of Skill extraction file.')
153 | FLAGS, unparsed = parser.parse_known_args()
154 | tf.app.run(main=main)
155 |
--------------------------------------------------------------------------------
/graying_the_box/.gitignore:
--------------------------------------------------------------------------------
1 | *.bin
2 | *.h5
3 | *.pyc
4 | *.so
5 | *.o
6 | data/
--------------------------------------------------------------------------------
/graying_the_box/.idea/.name:
--------------------------------------------------------------------------------
1 | graying_the_box
--------------------------------------------------------------------------------
/graying_the_box/.idea/encodings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/graying_the_box/.idea/graying_the_box.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/graying_the_box/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/graying_the_box/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/graying_the_box/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/graying_the_box/LUA/dqn/Rectifier.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | --[[ Rectified Linear Unit.
8 |
9 | The output is max(0, input).
10 | --]]
11 |
12 | local Rectifier, parent = torch.class('nn.Rectifier', 'nn.Module')
13 |
14 | -- This module accepts minibatches
15 | function Rectifier:updateOutput(input)
16 | return self.output:resizeAs(input):copy(input):abs():add(input):div(2)
17 | end
18 |
19 | function Rectifier:updateGradInput(input, gradOutput)
20 | self.gradInput:resizeAs(self.output)
21 | return self.gradInput:sign(self.output):cmul(gradOutput)
22 | end
--------------------------------------------------------------------------------
/graying_the_box/LUA/dqn/Scale.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require "nn"
8 | require "image"
9 |
10 | local scale = torch.class('nn.Scale', 'nn.Module')
11 |
12 |
13 | function scale:__init(height, width)
14 | self.height = height
15 | self.width = width
16 | end
17 |
18 | function scale:forward(x)
19 | local x = x
20 | if x:dim() > 3 then
21 | x = x[1]
22 | end
23 |
24 | x = image.rgb2y(x)
25 | x = image.scale(x, self.width, self.height, 'bilinear')
26 | return x
27 | end
28 |
29 | function scale:updateOutput(input)
30 | return self:forward(input)
31 | end
32 |
33 | function scale:float()
34 | end
35 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/dqn/convnet.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require "initenv"
8 |
9 | function create_network(args)
10 |
11 | local net = nn.Sequential()
12 | net:add(nn.Reshape(unpack(args.input_dims)))
13 |
14 | --- first convolutional layer
15 | local convLayer = nn.SpatialConvolution
16 |
17 | net:add(convLayer(args.hist_len*args.ncols, args.n_units[1],
18 | args.filter_size[1], args.filter_size[1],
19 | args.filter_stride[1], args.filter_stride[1],1))
20 | net:add(args.nl())
21 |
22 | -- Add convolutional layers
23 | for i=1,(#args.n_units-1) do
24 | -- second convolutional layer
25 | net:add(convLayer(args.n_units[i], args.n_units[i+1],
26 | args.filter_size[i+1], args.filter_size[i+1],
27 | args.filter_stride[i+1], args.filter_stride[i+1]))
28 | net:add(args.nl())
29 | end
30 |
31 | local nel
32 | if args.gpu >= 0 then
33 | nel = net:cuda():forward(torch.zeros(1,unpack(args.input_dims))
34 | :cuda()):nElement()
35 | else
36 | nel = net:forward(torch.zeros(1,unpack(args.input_dims))):nElement()
37 | end
38 |
39 | -- reshape all feature planes into a vector per example
40 | net:add(nn.Reshape(nel))
41 |
42 | -- fully connected layer
43 | net:add(nn.Linear(nel, args.n_hid[1]))
44 | net:add(args.nl())
45 | local last_layer_size = args.n_hid[1]
46 |
47 | for i=1,(#args.n_hid-1) do
48 | -- add Linear layer
49 | last_layer_size = args.n_hid[i+1]
50 | net:add(nn.Linear(args.n_hid[i], last_layer_size))
51 | net:add(args.nl())
52 | end
53 |
54 | -- add the last fully connected layer (to actions)
55 | net:add(nn.Linear(last_layer_size, args.n_actions))
56 |
57 | if args.gpu >=0 then
58 | net:cuda()
59 | end
60 | if args.verbose >= 2 then
61 | print(net)
62 | print('Convolutional layers flattened output size:', nel)
63 | end
64 | return net
65 | end
66 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/dqn/convnet_atari3.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require 'convnet'
8 |
9 | return function(args)
10 | args.n_units = {32, 64, 64}
11 | args.filter_size = {8, 4, 3}
12 | args.filter_stride = {4, 2, 1}
13 | args.n_hid = {512}
14 | args.nl = nn.Rectifier
15 |
16 | return create_network(args)
17 | end
18 |
19 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/dqn/initenv.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 | dqn = {}
7 |
8 | require 'torch'
9 | require 'nn'
10 | require 'nngraph'
11 | require 'nnutils'
12 | require 'image'
13 | require 'Scale'
14 | require 'NeuralQLearner'
15 | require 'TransitionTable'
16 | require 'Rectifier'
17 |
18 |
19 | function torchSetup(_opt)
20 | _opt = _opt or {}
21 | local opt = table.copy(_opt)
22 | assert(opt)
23 |
24 | -- preprocess options:
25 | --- convert options strings to tables
26 | if opt.pool_frms then
27 | opt.pool_frms = str_to_table(opt.pool_frms)
28 | end
29 | if opt.env_params then
30 | opt.env_params = str_to_table(opt.env_params)
31 | end
32 | if opt.agent_params then
33 | opt.agent_params = str_to_table(opt.agent_params)
34 | opt.agent_params.gpu = opt.gpu
35 | opt.agent_params.best = opt.best
36 | opt.agent_params.verbose = opt.verbose
37 | if opt.network ~= '' then
38 | opt.agent_params.network = opt.network
39 | end
40 | end
41 |
42 | --- general setup
43 | opt.tensorType = opt.tensorType or 'torch.FloatTensor'
44 | torch.setdefaulttensortype(opt.tensorType)
45 | if not opt.threads then
46 | opt.threads = 4
47 | end
48 | torch.setnumthreads(opt.threads)
49 | if not opt.verbose then
50 | opt.verbose = 10
51 | end
52 | if opt.verbose >= 1 then
53 | print('Torch Threads:', torch.getnumthreads())
54 | end
55 |
56 | --- set gpu device
57 | if opt.gpu and opt.gpu >= 0 then
58 | require 'cutorch'
59 | require 'cunn'
60 | if opt.gpu == 0 then
61 | local gpu_id = tonumber(os.getenv('GPU_ID'))
62 | if gpu_id then opt.gpu = gpu_id+1 end
63 | end
64 | if opt.gpu > 0 then cutorch.setDevice(opt.gpu) end
65 | opt.gpu = cutorch.getDevice()
66 | print('Using GPU device id:', opt.gpu-1)
67 | else
68 | opt.gpu = -1
69 | if opt.verbose >= 1 then
70 | print('Using CPU code only. GPU device id:', opt.gpu)
71 | end
72 | end
73 |
74 | --- set up random number generators
75 | -- removing lua RNG; seeding torch RNG with opt.seed and setting cutorch
76 | -- RNG seed to the first uniform random int32 from the previous RNG;
77 | -- this is preferred because using the same seed for both generators
78 | -- may introduce correlations; we assume that both torch RNGs ensure
79 | -- adequate dispersion for different seeds.
80 | math.random = nil
81 | opt.seed = opt.seed or 1
82 | torch.manualSeed(opt.seed)
83 | if opt.verbose >= 1 then
84 | print('Torch Seed:', torch.initialSeed())
85 | end
86 | local firstRandInt = torch.random()
87 | if opt.gpu >= 0 then
88 | cutorch.manualSeed(firstRandInt)
89 | if opt.verbose >= 1 then
90 | print('CUTorch Seed:', cutorch.initialSeed())
91 | end
92 | end
93 |
94 | return opt
95 | end
96 |
97 |
98 | function setup(_opt)
99 | assert(_opt)
100 |
101 | --preprocess options:
102 | --- convert options strings to tables
103 | _opt.pool_frms = str_to_table(_opt.pool_frms)
104 | _opt.env_params = str_to_table(_opt.env_params)
105 | _opt.agent_params = str_to_table(_opt.agent_params)
106 | if _opt.agent_params.transition_params then
107 | _opt.agent_params.transition_params =
108 | str_to_table(_opt.agent_params.transition_params)
109 | end
110 |
111 | --- first things first
112 | local opt = torchSetup(_opt)
113 |
114 | -- load training framework and environment
115 | local framework = require(opt.framework)
116 | assert(framework)
117 |
118 | local gameEnv = framework.GameEnvironment(opt)
119 | local gameActions = gameEnv:getActions()
120 |
121 | -- agent options
122 | _opt.agent_params.actions = gameActions
123 | _opt.agent_params.gpu = _opt.gpu
124 | _opt.agent_params.best = _opt.best
125 | if _opt.network ~= '' then
126 | _opt.agent_params.network = _opt.network
127 | end
128 | _opt.agent_params.verbose = _opt.verbose
129 | if not _opt.agent_params.state_dim then
130 | _opt.agent_params.state_dim = gameEnv:nObsFeature()
131 | end
132 |
133 | local agent = dqn[_opt.agent](_opt.agent_params)
134 |
135 | if opt.verbose >= 1 then
136 | print('Set up Torch using these options:')
137 | for k, v in pairs(opt) do
138 | print(k, v)
139 | end
140 | end
141 |
142 | return gameEnv, gameActions, agent, opt
143 | end
144 |
145 |
146 |
147 | --- other functions
148 |
149 | function str_to_table(str)
150 | if type(str) == 'table' then
151 | return str
152 | end
153 | if not str or type(str) ~= 'string' then
154 | if type(str) == 'table' then
155 | return str
156 | end
157 | return {}
158 | end
159 | local ttr
160 | if str ~= '' then
161 | local ttx=tt
162 | loadstring('tt = {' .. str .. '}')()
163 | ttr = tt
164 | tt = ttx
165 | else
166 | ttr = {}
167 | end
168 | return ttr
169 | end
170 |
171 | function table.copy(t)
172 | if t == nil then return nil end
173 | local nt = {}
174 | for k, v in pairs(t) do
175 | if type(v) == 'table' then
176 | nt[k] = table.copy(v)
177 | else
178 | nt[k] = v
179 | end
180 | end
181 | setmetatable(nt, table.copy(getmetatable(t)))
182 | return nt
183 | end
184 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/dqn/net_downsample_2x_full_y.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require "image"
8 | require "Scale"
9 |
10 | local function create_network(args)
11 | -- Y (luminance)
12 | return nn.Scale(84, 84, true)
13 | end
14 |
15 | return create_network
16 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/dqn/nnutils.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | require "torch"
8 |
9 | function recursive_map(module, field, func)
10 | local str = ""
11 | if module[field] or module.modules then
12 | str = str .. torch.typename(module) .. ": "
13 | end
14 | if module[field] then
15 | str = str .. func(module[field])
16 | end
17 | if module.modules then
18 | str = str .. "["
19 | for i, submodule in ipairs(module.modules) do
20 | local submodule_str = recursive_map(submodule, field, func)
21 | str = str .. submodule_str
22 | if i < #module.modules and string.len(submodule_str) > 0 then
23 | str = str .. " "
24 | end
25 | end
26 | str = str .. "]"
27 | end
28 |
29 | return str
30 | end
31 |
32 | function abs_mean(w)
33 | return torch.mean(torch.abs(w:clone():float()))
34 | end
35 |
36 | function abs_max(w)
37 | return torch.abs(w:clone():float()):max()
38 | end
39 |
40 | -- Build a string of average absolute weight values for the modules in the
41 | -- given network.
42 | function get_weight_norms(module)
43 | return "Weight norms:\n" .. recursive_map(module, "weight", abs_mean) ..
44 | "\nWeight max:\n" .. recursive_map(module, "weight", abs_max)
45 | end
46 |
47 | -- Build a string of average absolute weight gradient values for the modules
48 | -- in the given network.
49 | function get_grad_norms(module)
50 | return "Weight grad norms:\n" ..
51 | recursive_map(module, "gradWeight", abs_mean) ..
52 | "\nWeight grad max:\n" .. recursive_map(module, "gradWeight", abs_max)
53 | end
54 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/dqn/test_agent.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 | See LICENSE file for full terms of limited license.
4 | ]]
5 |
6 | gd = require "gd"
7 |
8 | if not dqn then
9 | require "initenv"
10 | end
11 |
12 | local cmd = torch.CmdLine()
13 | cmd:text()
14 | cmd:text('Train Agent in Environment:')
15 | cmd:text()
16 | cmd:text('Options:')
17 |
18 | cmd:option('-framework', '', 'name of training framework')
19 | cmd:option('-env', '', 'name of environment to use')
20 | cmd:option('-game_path', '', 'path to environment file (ROM)')
21 | cmd:option('-env_params', '', 'string of environment parameters')
22 | cmd:option('-pool_frms', '',
23 | 'string of frame pooling parameters (e.g.: size=2,type="max")')
24 | cmd:option('-actrep', 1, 'how many times to repeat action')
25 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' ..
26 | 'number of times at the start of each training episode')
27 |
28 | cmd:option('-name', '', 'filename used for saving network and training history')
29 | cmd:option('-network', '', 'reload pretrained network')
30 | cmd:option('-agent', '', 'name of agent file to use')
31 | cmd:option('-agent_params', '', 'string of agent parameters')
32 | cmd:option('-seed', 1, 'fixed input seed for repeatable experiments')
33 |
34 | cmd:option('-verbose', 2,
35 | 'the higher the level, the more information is printed to screen')
36 | cmd:option('-threads', 1, 'number of BLAS threads')
37 | cmd:option('-gpu', -1, 'gpu flag')
38 | cmd:option('-gif_file', '', 'GIF path to write session screens')
39 | cmd:option('-csv_file', '', 'CSV path to write session data')
40 |
41 | cmd:text()
42 |
43 | local opt = cmd:parse(arg)
44 |
45 | --- General setup.
46 | local game_env, game_actions, agent, opt = setup(opt)
47 |
48 | -- override print to always flush the output
49 | local old_print = print
50 | local print = function(...)
51 | old_print(...)
52 | io.flush()
53 | end
54 |
55 | -- file names from command line
56 | local gif_filename = opt.gif_file
57 |
58 | -- start a new game
59 | local screen, reward, terminal = game_env:newGame()
60 |
61 | -- compress screen to JPEG with 100% quality
62 | local jpg = image.compressJPG(screen:squeeze(), 100)
63 | -- create gd image from JPEG string
64 | local im = gd.createFromJpegStr(jpg:storage():string())
65 | -- convert truecolor to palette
66 | im:trueColorToPalette(false, 256)
67 |
68 | -- write GIF header, use global palette and infinite looping
69 | im:gifAnimBegin(gif_filename, true, 0)
70 | -- write first frame
71 | im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE)
72 |
73 | -- remember the image and show it first
74 | local previm = im
75 | local win = image.display({image=screen})
76 |
77 | print("Started playing...")
78 |
79 | -- play one episode (game)
80 | while not terminal do
81 | -- if action was chosen randomly, Q-value is 0
82 | agent.bestq = 0
83 |
84 | -- choose the best action
85 | local action_index = agent:perceive(reward, screen, terminal, true, 0.05)
86 |
87 | -- play game in test mode (episodes don't end when losing a life)
88 | screen, reward, terminal = game_env:step(game_actions[action_index], false)
89 |
90 | -- display screen
91 | image.display({image=screen, win=win})
92 |
93 | -- create gd image from tensor
94 | jpg = image.compressJPG(screen:squeeze(), 100)
95 | im = gd.createFromJpegStr(jpg:storage():string())
96 |
97 | -- use palette from previous (first) image
98 | im:trueColorToPalette(false, 256)
99 | im:paletteCopy(previm)
100 |
101 | -- write new GIF frame, no local palette, starting from left-top, 7ms delay
102 | im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE)
103 | -- remember previous screen for optimal compression
104 | previm = im
105 |
106 | end
107 |
108 | -- end GIF animation and close CSV file
109 | gd.gifAnimEnd(gif_filename)
110 |
111 | print("Finished playing, close window to exit!")
--------------------------------------------------------------------------------
/graying_the_box/LUA/dqn/train_agent.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 |
4 | See LICENSE file for full terms of limited license.
5 | ]]
6 |
7 | if not dqn then
8 | require "initenv"
9 | end
10 |
11 | local cmd = torch.CmdLine()
12 | cmd:text()
13 | cmd:text('Train Agent in Environment:')
14 | cmd:text()
15 | cmd:text('Options:')
16 |
17 | cmd:option('-framework', '', 'name of training framework')
18 | cmd:option('-env', '', 'name of environment to use')
19 | cmd:option('-game_path', '', 'path to environment file (ROM)')
20 | cmd:option('-env_params', '', 'string of environment parameters')
21 | cmd:option('-pool_frms', '',
22 | 'string of frame pooling parameters (e.g.: size=2,type="max")')
23 | cmd:option('-actrep', 1, 'how many times to repeat action')
24 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' ..
25 | 'number of times at the start of each training episode')
26 |
27 | cmd:option('-name', '', 'filename used for saving network and training history')
28 | cmd:option('-network', '', 'reload pretrained network')
29 | cmd:option('-agent', '', 'name of agent file to use')
30 | cmd:option('-agent_params', '', 'string of agent parameters')
31 | cmd:option('-seed', 1, 'fixed input seed for repeatable experiments')
32 | cmd:option('-saveNetworkParams', true,
33 | 'saves the agent network in a separate file') -- tom
34 | cmd:option('-prog_freq', 50000, 'frequency of progress output')
35 | cmd:option('-save_freq', 1000000, 'the model is saved every save_freq steps')
36 | cmd:option('-eval_freq', 10^4, 'frequency of greedy evaluation')
37 | cmd:option('-save_versions', 5, '') --tom
38 |
39 | cmd:option('-steps', 10^5, 'number of training steps to perform')
40 | cmd:option('-eval_steps', 10^5, 'number of evaluation steps')
41 |
42 | cmd:option('-verbose', 2,
43 | 'the higher the level, the more information is printed to screen')
44 | cmd:option('-threads', 1, 'number of BLAS threads')
45 | cmd:option('-gpu', -1, 'gpu flag')
46 |
47 | cmd:text()
48 |
49 | local opt = cmd:parse(arg)
50 |
51 | --- General setup.
52 | local game_env, game_actions, agent, opt = setup(opt)
53 |
54 | -- override print to always flush the output
55 | local old_print = print
56 | local print = function(...)
57 | old_print(...)
58 | io.flush()
59 | end
60 |
61 | local learn_start = agent.learn_start
62 | local start_time = sys.clock()
63 | local reward_counts = {}
64 | local episode_counts = {}
65 | local time_history = {}
66 | local v_history = {}
67 | local qmax_history = {}
68 | local td_history = {}
69 | local reward_history = {}
70 | local step = 0
71 | time_history[1] = 0
72 |
73 | local total_reward
74 | local nrewards
75 | local nepisodes
76 | local episode_reward
77 |
78 | local screen, reward, terminal = game_env:getState()
79 |
80 | print("Iteration ..", step)
81 | local win = nil
82 | win1 = nil
83 | win2 = nil
84 | win3 = nil
85 |
86 | -- Tom
87 | local testing_eps = 0.05
88 | local in_episode_time = 0
89 | local episode_index = 1
90 | local episode_reward = 0
91 | local life_count = 5
92 | -- Tom End
93 | while step < opt.steps do
94 | step = step + 1
95 |
96 | -- Tom
97 | --local action_index = agent:perceive(reward, screen, terminal) more arguments
98 | in_episode_time = in_episode_time + 1
99 | local action_index = agent:perceive(reward, screen, terminal)
100 |
101 | -- Tom End
102 |
103 | -- game over? get next game!
104 | if not terminal then
105 | screen, reward, terminal = game_env:step(game_actions[action_index], true)
106 | episode_reward = episode_reward + reward
107 | else
108 | -- Tom
109 | --[[agent:UpdateEpisodeData(episode_index,in_episode_time)
110 | in_episode_time = 0
111 | life_count = life_count - 1
112 | if life_count == 0 then
113 | --print("Episode Reward: ", episode_reward)
114 | agent:UpdateEpisodeReward(episode_index,episode_reward)
115 | life_count = 5
116 | episode_reward = 0
117 | end
118 | episode_index = episode_index +1
119 | -- Tom End
120 | --]]
121 | if opt.random_starts > 0 then
122 | screen, reward, terminal = game_env:nextRandomGame()
123 | else
124 | screen, reward, terminal = game_env:newGame()
125 | end
126 | end
127 |
128 | -- display screen
129 | win = image.display({image=screen, win=win})
130 |
131 | if step % opt.prog_freq == 0 then
132 | assert(step==agent.numSteps, 'trainer step: ' .. step ..
133 | ' & agent.numSteps: ' .. agent.numSteps)
134 | print("Steps: ", step)
135 | agent:report()
136 | print(agent.transitions.ER_Cluster_counter)
137 | collectgarbage()
138 | end
139 |
140 | if step%1000 == 0 then collectgarbage() end
141 |
142 | if step % opt.eval_freq == 0 and step > learn_start then
143 |
144 | screen, reward, terminal = game_env:newGame()
145 |
146 | total_reward = 0
147 | nrewards = 0
148 | nepisodes = 0
149 | episode_reward = 0
150 |
151 | local eval_time = sys.clock()
152 | for estep=1,opt.eval_steps do
153 | local action_index = agent:perceive(reward, screen, terminal, true, testing_eps)
154 |
155 | -- Play game in test mode (episodes don't end when losing a life)
156 | screen, reward, terminal = game_env:step(game_actions[action_index])
157 |
158 | -- display screen
159 | win = image.display({image=screen, win=win})
160 |
161 | if estep%1000 == 0 then collectgarbage() end
162 |
163 | -- record every reward
164 | episode_reward = episode_reward + reward
165 | if reward ~= 0 then
166 | nrewards = nrewards + 1
167 | end
168 |
169 | if terminal then
170 | total_reward = total_reward + episode_reward
171 | episode_reward = 0
172 | nepisodes = nepisodes + 1
173 | screen, reward, terminal = game_env:nextRandomGame()
174 | end
175 | end
176 |
177 | eval_time = sys.clock() - eval_time
178 | start_time = start_time + eval_time
179 | agent:compute_validation_statistics()
180 | local ind = #reward_history+1
181 | total_reward = total_reward/math.max(1, nepisodes)
182 |
183 | if #reward_history == 0 or total_reward > torch.Tensor(reward_history):max() then
184 | agent.best_network = agent.network:clone()
185 | end
186 |
187 | if agent.v_avg then
188 | v_history[ind] = agent.v_avg
189 | td_history[ind] = agent.tderr_avg
190 | qmax_history[ind] = agent.q_max
191 | end
192 | print("V", v_history[ind], "TD error", td_history[ind], "Qmax", qmax_history[ind])
193 |
194 | reward_history[ind] = total_reward
195 | reward_counts[ind] = nrewards
196 | episode_counts[ind] = nepisodes
197 |
198 | time_history[ind+1] = sys.clock() - start_time
199 |
200 | local time_dif = time_history[ind+1] - time_history[ind]
201 |
202 | local training_rate = opt.actrep*opt.eval_freq/time_dif
203 |
204 | print(string.format(
205 | '\nSteps: %d (frames: %d), reward: %.2f, epsilon: %.2f, lr: %G, ' ..
206 | 'training time: %ds, training rate: %dfps, testing time: %ds, ' ..
207 | 'testing rate: %dfps, num. ep.: %d, num. rewards: %d',
208 | step, step*opt.actrep, total_reward, agent.ep, agent.lr, time_dif,
209 | training_rate, eval_time, opt.actrep*opt.eval_steps/eval_time,
210 | nepisodes, nrewards))
211 | end
212 |
213 | if step % opt.save_freq == 0 or step == opt.steps then
214 | local s, a, r, s2, term = agent.valid_s, agent.valid_a, agent.valid_r,
215 | agent.valid_s2, agent.valid_term
216 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2,
217 | agent.valid_term = nil, nil, nil, nil, nil, nil, nil
218 | local w, dw, g, g2, delta, delta2, deltas, tmp = agent.w, agent.dw,
219 | agent.g, agent.g2, agent.delta, agent.delta2, agent.deltas, agent.tmp
220 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2,
221 | agent.deltas, agent.tmp = nil, nil, nil, nil, nil, nil, nil, nil
222 |
223 | local filename = opt.name
224 | if opt.save_versions > 0 then
225 | filename = filename .. "_" .. math.floor(step / opt.save_versions)
226 | end
227 | filename = filename
228 | torch.save(filename .. ".t7", {agent = agent,
229 | model = agent.network,
230 | best_model = agent.best_network,
231 | reward_history = reward_history,
232 | reward_counts = reward_counts,
233 | episode_counts = episode_counts,
234 | time_history = time_history,
235 | v_history = v_history,
236 | td_history = td_history,
237 | qmax_history = qmax_history,
238 | arguments=opt})
239 | if opt.saveNetworkParams then
240 | local nets = {network=w:clone():float()}
241 | torch.save(filename..'.params.t7', nets, 'ascii')
242 | end
243 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2,
244 | agent.valid_term = s, a, r, s2, term
245 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2,
246 | agent.deltas, agent.tmp = w, dw, g, g2, delta, delta2, deltas, tmp
247 | print('Saved:', filename .. '.t7')
248 | io.flush()
249 | collectgarbage()
250 | end
251 | end
252 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/dqn/train_agent_tmp.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2014 Google Inc.
3 | See LICENSE file for full terms of limited license.
4 | ]]
5 |
6 | if not dqn then
7 | require "initenv"
8 | end
9 |
10 | local cmd = torch.CmdLine()
11 | cmd:text()
12 | cmd:text('Train Agent in Environment:')
13 | cmd:text()
14 | cmd:text('Options:')
15 |
16 | cmd:option('-framework', '', 'name of training framework')
17 | cmd:option('-env', '', 'name of environment to use')
18 | cmd:option('-game_path', '', 'path to environment file (ROM)')
19 | cmd:option('-env_params', '', 'string of environment parameters')
20 | cmd:option('-pool_frms', '',
21 | 'string of frame pooling parameters (e.g.: size=2,type="max")')
22 | cmd:option('-actrep', 1, 'how many times to repeat action')
23 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' ..
24 | 'number of times at the start of each training episode')
25 |
26 | cmd:option('-name', '', 'filename used for saving network and training history')
27 | cmd:option('-network', '', 'reload pretrained network')
28 | cmd:option('-agent', '', 'name of agent file to use')
29 | cmd:option('-agent_params', '', 'string of agent parameters')
30 | cmd:option('-seed', 1, 'fixed input seed for repeatable experiments')
31 | cmd:option('-saveNetworkParams', false,
32 | 'saves the agent network in a separate file')
33 | cmd:option('-prog_freq', 5*10^3, 'frequency of progress output')
34 | cmd:option('-save_freq', 5*10^4, 'the model is saved every save_freq steps')
35 | cmd:option('-eval_freq', 10^4, 'frequency of greedy evaluation')
36 | cmd:option('-save_versions', 0, '')
37 |
38 | cmd:option('-steps', 10^5, 'number of training steps to perform')
39 | cmd:option('-eval_steps', 10^5, 'number of evaluation steps')
40 |
41 | cmd:option('-verbose', 2,
42 | 'the higher the level, the more information is printed to screen')
43 | cmd:option('-threads', 1, 'number of BLAS threads')
44 | cmd:option('-gpu', -1, 'gpu flag')
45 |
46 | cmd:text()
47 |
48 | local opt = cmd:parse(arg)
49 |
50 | --- General setup.
51 | local game_env, game_actions, agent, opt = setup(opt)
52 |
53 | -- override print to always flush the output
54 | local old_print = print
55 | local print = function(...)
56 | old_print(...)
57 | io.flush()
58 | end
59 |
60 | local learn_start = agent.learn_start
61 | local start_time = sys.clock()
62 | local reward_counts = {}
63 | local episode_counts = {}
64 | local time_history = {}
65 | local v_history = {}
66 | local qmax_history = {}
67 | local td_history = {}
68 | local reward_history = {}
69 | local step = 0
70 | time_history[1] = 0
71 |
72 | local total_reward
73 | local nrewards
74 | local nepisodes
75 | local episode_reward
76 |
77 | local screen, reward, terminal = game_env:getState()
78 |
79 | print("Iteration ..", step)
80 | local win = nil
81 | while step < opt.steps do
82 | step = step + 1
83 | local action_index = agent:perceive(reward, screen, terminal)
84 |
85 | -- game over? get next game!
86 | if not terminal then
87 | screen, reward, terminal = game_env:step(game_actions[action_index], true)
88 | else
89 | if opt.random_starts > 0 then
90 | screen, reward, terminal = game_env:nextRandomGame()
91 | else
92 | screen, reward, terminal = game_env:newGame()
93 | end
94 | end
95 |
96 | -- display screen
97 | win = image.display({image=screen, win=win})
98 |
99 | if step % opt.prog_freq == 0 then
100 | assert(step==agent.numSteps, 'trainer step: ' .. step ..
101 | ' & agent.numSteps: ' .. agent.numSteps)
102 | print("Steps: ", step)
103 | agent:report()
104 | collectgarbage()
105 | end
106 |
107 | if step%1000 == 0 then collectgarbage() end
108 |
109 | if step % opt.eval_freq == 0 and step > learn_start then
110 |
111 | screen, reward, terminal = game_env:newGame()
112 |
113 | total_reward = 0
114 | nrewards = 0
115 | nepisodes = 0
116 | episode_reward = 0
117 |
118 | local eval_time = sys.clock()
119 | for estep=1,opt.eval_steps do
120 | local action_index = agent:perceive(reward, screen, terminal, true, 0.05)
121 |
122 | -- Play game in test mode (episodes don't end when losing a life)
123 | screen, reward, terminal = game_env:step(game_actions[action_index])
124 |
125 | -- display screen
126 | win = image.display({image=screen, win=win})
127 |
128 | if estep%1000 == 0 then collectgarbage() end
129 |
130 | -- record every reward
131 | episode_reward = episode_reward + reward
132 | if reward ~= 0 then
133 | nrewards = nrewards + 1
134 | end
135 |
136 | if terminal then
137 | total_reward = total_reward + episode_reward
138 | episode_reward = 0
139 | nepisodes = nepisodes + 1
140 | screen, reward, terminal = game_env:nextRandomGame()
141 | end
142 | end
143 |
144 | eval_time = sys.clock() - eval_time
145 | start_time = start_time + eval_time
146 | agent:compute_validation_statistics()
147 | local ind = #reward_history+1
148 | total_reward = total_reward/math.max(1, nepisodes)
149 |
150 | if #reward_history == 0 or total_reward > torch.Tensor(reward_history):max() then
151 | agent.best_network = agent.network:clone()
152 | end
153 |
154 | if agent.v_avg then
155 | v_history[ind] = agent.v_avg
156 | td_history[ind] = agent.tderr_avg
157 | qmax_history[ind] = agent.q_max
158 | end
159 | print("V", v_history[ind], "TD error", td_history[ind], "Qmax", qmax_history[ind])
160 |
161 | reward_history[ind] = total_reward
162 | reward_counts[ind] = nrewards
163 | episode_counts[ind] = nepisodes
164 |
165 | time_history[ind+1] = sys.clock() - start_time
166 |
167 | local time_dif = time_history[ind+1] - time_history[ind]
168 |
169 | local training_rate = opt.actrep*opt.eval_freq/time_dif
170 |
171 | print(string.format(
172 | '\nSteps: %d (frames: %d), reward: %.2f, epsilon: %.2f, lr: %G, ' ..
173 | 'training time: %ds, training rate: %dfps, testing time: %ds, ' ..
174 | 'testing rate: %dfps, num. ep.: %d, num. rewards: %d',
175 | step, step*opt.actrep, total_reward, agent.ep, agent.lr, time_dif,
176 | training_rate, eval_time, opt.actrep*opt.eval_steps/eval_time,
177 | nepisodes, nrewards))
178 | end
179 |
180 | if step % opt.save_freq == 0 or step == opt.steps then
181 | local s, a, r, s2, term = agent.valid_s, agent.valid_a, agent.valid_r,
182 | agent.valid_s2, agent.valid_term
183 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2,
184 | agent.valid_term = nil, nil, nil, nil, nil, nil, nil
185 | local w, dw, g, g2, delta, delta2, deltas, tmp = agent.w, agent.dw,
186 | agent.g, agent.g2, agent.delta, agent.delta2, agent.deltas, agent.tmp
187 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2,
188 | agent.deltas, agent.tmp = nil, nil, nil, nil, nil, nil, nil, nil
189 |
190 | local filename = opt.name
191 | if opt.save_versions > 0 then
192 | filename = filename .. "_" .. math.floor(step / opt.save_versions)
193 | end
194 | filename = filename
195 | torch.save(filename .. ".t7", {agent = agent,
196 | model = agent.network,
197 | best_model = agent.best_network,
198 | reward_history = reward_history,
199 | reward_counts = reward_counts,
200 | episode_counts = episode_counts,
201 | time_history = time_history,
202 | v_history = v_history,
203 | td_history = td_history,
204 | qmax_history = qmax_history,
205 | arguments=opt})
206 | if opt.saveNetworkParams then
207 | local nets = {network=w:clone():float()}
208 | torch.save(filename..'.params.t7', nets, 'ascii')
209 | end
210 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2,
211 | agent.valid_term = s, a, r, s2, term
212 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2,
213 | agent.deltas, agent.tmp = w, dw, g, g2, delta, delta2, deltas, tmp
214 | print('Saved:', filename .. '.t7')
215 | io.flush()
216 | collectgarbage()
217 | end
218 | end
219 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/Plot.lua:
--------------------------------------------------------------------------------
1 | require 'gnuplot'
2 |
3 | n = 0
4 | local file1 = io.open("./Results/reward.txt")
5 | if file1 then
6 | for line in file1:lines() do
7 | n = n+1
8 | end
9 | end
10 | file1:close()
11 |
12 | m = 0
13 | local file1 = io.open("./Results/conv1.txt")
14 | if file1 then
15 | for line in file1:lines() do
16 | m = m+1
17 | end
18 | end
19 | file1:close()
20 | m = m-1
21 | local file1 = io.open("./Results/reward.txt")
22 | local file2 = io.open("./Results/TD.txt")
23 | local file3 = io.open("./Results/vavg.txt")
24 | local file4 = io.open("./Results/conv1.txt")
25 | local file5 = io.open("./Results/conv2.txt")
26 | local file6 = io.open("./Results/conv3.txt")
27 | local file7 = io.open("./Results/lin1.txt")
28 | local file8 = io.open("./Results/lin2.txt")
29 |
30 | i = 1
31 | conv1_norm = torch.Tensor(m/4-1)
32 | conv1_norm_max = torch.Tensor(m/4-1)
33 | conv1_grad = torch.Tensor(m/4-1)
34 | conv1_grad_max = torch.Tensor(m/4-1)
35 | -- conv1.txt carries four values per progress report (weight abs-mean/abs-max and grad abs-mean/abs-max, see nnutils.lua); the i%4 cases below separate them
36 | if file4 then
37 | for line in file4:lines() do
38 | if i == 1 then
39 | print(i)
40 | elseif i == m-4 then
41 | break
42 | elseif i % 4 == 2 then
43 | conv1_norm[math.floor(i/4)+1] = tonumber(line)
44 | elseif i % 4 == 3 then
45 | conv1_norm_max[math.floor(i/4)+1] = tonumber(line)
46 | elseif i % 4 == 0 then
47 | conv1_grad[math.floor(i/4)+1] = tonumber(line)
48 | elseif i % 4 == 1 then
49 | conv1_grad_max[math.floor(i/4)+1] = tonumber(line)
50 | end
51 | i = i+1
52 | end
53 | end
54 |
55 | gnuplot.pngfigure('./Results/tmp/conv1.png')
56 | gnuplot.title('conv 1 over training')
57 | gnuplot.plot({'Norm',conv1_norm},{'Grad',conv1_grad})
58 | gnuplot.xlabel('Training epochs')
59 | gnuplot.movelegend('left','top')
60 | gnuplot.plotflush()
61 |
62 |
63 | gnuplot.pngfigure('./Results/tmp/conv1_max.png')
64 | gnuplot.title('conv 1 max over training')
65 | gnuplot.plot({'Max norm',conv1_norm_max},{'Max grad',conv1_grad_max})
66 | gnuplot.xlabel('Training epochs')
67 | gnuplot.movelegend('left','top')
68 | gnuplot.plotflush()
69 |
70 |
71 | i = 1
72 | conv2_norm = torch.Tensor(m/4-1)
73 | conv2_norm_max = torch.Tensor(m/4-1)
74 | conv2_grad = torch.Tensor(m/4-1)
75 | conv2_grad_max = torch.Tensor(m/4-1)
76 |
77 | if file5 then
78 | for line in file5:lines() do
79 | if i == 1 then
80 | print(i)
81 | elseif i == m-4 then
82 | break
83 | elseif i % 4 == 2 then
84 | conv2_norm[math.floor(i/4)+1] = tonumber(line)
85 | elseif i % 4 == 3 then
86 | conv2_norm_max[math.floor(i/4)+1] = tonumber(line)
87 | elseif i % 4 == 0 then
88 | conv2_grad[math.floor(i/4)+1] = tonumber(line)
89 | elseif i % 4 == 1 then
90 | conv2_grad_max[math.floor(i/4)+1] = tonumber(line)
91 | end
92 | i = i+1
93 | end
94 | end
95 |
96 | gnuplot.pngfigure('./Results/tmp/conv2.png')
97 | gnuplot.title('conv 2 over training')
98 | gnuplot.plot({'Norm',conv2_norm},{'Grad',conv2_grad})
99 | gnuplot.xlabel('Training epochs')
100 | gnuplot.movelegend('left','top')
101 | gnuplot.plotflush()
102 |
103 |
104 | gnuplot.pngfigure('./Results/tmp/conv2_max.png')
105 | gnuplot.title('conv 2 max over training')
106 | gnuplot.plot({'Max norm',conv2_norm_max},{'Max grad',conv2_grad_max})
107 | gnuplot.xlabel('Training epochs')
108 | gnuplot.movelegend('left','top')
109 | gnuplot.plotflush()
110 |
111 | i = 1
112 | conv3_norm = torch.Tensor(m/4-1)
113 | conv3_norm_max = torch.Tensor(m/4-1)
114 | conv3_grad = torch.Tensor(m/4-1)
115 | conv3_grad_max = torch.Tensor(m/4-1)
116 |
117 | if file6 then
118 | for line in file6:lines() do
119 | if i == 1 then
120 | print(i)
121 | elseif i == m-4 then
122 | break
123 | elseif i % 4 == 2 then
124 | conv3_norm[math.floor(i/4)+1] = tonumber(line)
125 | elseif i % 4 == 3 then
126 | conv3_norm_max[math.floor(i/4)+1] = tonumber(line)
127 | elseif i % 4 == 0 then
128 | conv3_grad[math.floor(i/4)+1] = tonumber(line)
129 | elseif i % 4 == 1 then
130 | conv3_grad_max[math.floor(i/4)+1] = tonumber(line)
131 | end
132 | i = i+1
133 | end
134 | end
135 |
136 | gnuplot.pngfigure('./Results/tmp/conv3.png')
137 | gnuplot.title('conv 3 over training')
138 | gnuplot.plot({'Norm',conv3_norm},{'Grad',conv3_grad})
139 | gnuplot.xlabel('Training epochs')
140 | gnuplot.movelegend('left','top')
141 | gnuplot.plotflush()
142 |
143 |
144 | gnuplot.pngfigure('./Results/tmp/conv3_max.png')
145 | gnuplot.title('conv 3 max over training')
146 | gnuplot.plot({'Max norm',conv3_norm_max},{'Max grad',conv3_grad_max})
147 | gnuplot.xlabel('Training epochs')
148 | gnuplot.movelegend('left','top')
149 | gnuplot.plotflush()
150 |
151 | i = 1
152 | lin1_norm = torch.Tensor(m/4-1)
153 | lin1_norm_max = torch.Tensor(m/4-1)
154 | lin1_grad = torch.Tensor(m/4-1)
155 | lin1_grad_max = torch.Tensor(m/4-1)
156 |
157 | if file7 then
158 | for line in file7:lines() do
159 | if i == 1 then
160 | print(i)
161 | elseif i == m-4 then
162 | break
163 | elseif i % 4 == 2 then
164 | lin1_norm[math.floor(i/4)+1] = tonumber(line)
165 | elseif i % 4 == 3 then
166 | lin1_norm_max[math.floor(i/4)+1] = tonumber(line)
167 | elseif i % 4 == 0 then
168 | lin1_grad[math.floor(i/4)+1] = tonumber(line)
169 | elseif i % 4 == 1 then
170 | lin1_grad_max[math.floor(i/4)+1] = tonumber(line)
171 | end
172 | i = i+1
173 | end
174 | end
175 |
176 | gnuplot.pngfigure('./Results/tmp/lin1.png')
177 | gnuplot.title('lin1 over training')
178 | gnuplot.plot({'Norm',lin1_norm},{'Grad',lin1_grad})
179 | gnuplot.xlabel('Training epochs')
180 | gnuplot.movelegend('left','top')
181 | gnuplot.plotflush()
182 |
183 |
184 | gnuplot.pngfigure('./Results/tmp/lin1_max.png')
185 | gnuplot.title('lin1 max over training')
186 | gnuplot.plot({'Max norm',lin1_norm_max},{'Max grad',lin1_grad_max})
187 | gnuplot.xlabel('Training epochs')
188 | gnuplot.movelegend('left','top')
189 | gnuplot.plotflush()
190 |
191 | i = 1
192 | lin2_norm = torch.Tensor(m/4-1)
193 | lin2_norm_max = torch.Tensor(m/4-1)
194 | lin2_grad = torch.Tensor(m/4-1)
195 | lin2_grad_max = torch.Tensor(m/4-1)
196 |
197 | if file8 then
198 | for line in file8:lines() do
199 | if i == 1 then
200 | print(i)
201 | elseif i == m-4 then
202 | break
203 | elseif i % 4 == 2 then
204 | lin2_norm[math.floor(i/4)+1] = tonumber(line)
205 | elseif i % 4 == 3 then
206 | lin2_norm_max[math.floor(i/4)+1] = tonumber(line)
207 | elseif i % 4 == 0 then
208 | lin2_grad[math.floor(i/4)+1] = tonumber(line)
209 | elseif i % 4 == 1 then
210 | lin2_grad_max[math.floor(i/4)+1] = tonumber(line)
211 | end
212 | i = i+1
213 | end
214 | end
215 |
216 | gnuplot.pngfigure('./Results/tmp/lin2.png')
217 | gnuplot.title('lin2 over training')
218 | gnuplot.plot({'Norm',lin2_norm},{'Grad',lin2_grad})
219 | gnuplot.xlabel('Training epochs')
220 | gnuplot.movelegend('left','top')
221 | gnuplot.plotflush()
222 |
223 |
224 | gnuplot.pngfigure('./Results/tmp/lin2_max.png')
225 | gnuplot.title('lin2 max over training')
226 | gnuplot.plot({'Max norm',lin2_norm_max},{'Max grad',lin2_grad_max})
227 | gnuplot.xlabel('Training epochs')
228 | gnuplot.movelegend('left','top')
229 | gnuplot.plotflush()
230 |
231 |
232 | i = 1
233 | x = torch.Tensor(n)
234 | if file1 then
235 | for line in file1:lines() do
236 | x[i] = tonumber(line)
237 | i = i+1
238 | end
239 | end
240 |
241 | i = 1
242 | y = torch.Tensor(n)
243 | if file2 then
244 | for line in file2:lines() do
245 | y[i] = tonumber(line)
246 | i = i+1
247 | end
248 | end
249 |
250 | i = 1
251 | z = torch.Tensor(n)
252 | if file3 then
253 | for line in file3:lines() do
254 | z[i] = tonumber(line)
255 | z[i] = z[i]/y[i]
256 | i = i+1
257 | end
258 | end
259 |
260 | gnuplot.pngfigure('./Results/tmp/reward.png')
261 | gnuplot.title('reward over testing')
262 | gnuplot.plot(x)
263 | gnuplot.plotflush()
264 |
265 | gnuplot.pngfigure('./Results/tmp/vavg.png')
266 | gnuplot.title('vavg over testing')
267 | gnuplot.plot(y)
268 | gnuplot.plotflush()
269 |
270 | gnuplot.pngfigure('./Results/tmp/TD_Error.png')
271 | gnuplot.title('Normalized TD error over testing')
272 | gnuplot.plot(z)
273 | gnuplot.plotflush()
274 |
275 |
276 |
277 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/Plot2.lua:
--------------------------------------------------------------------------------
1 | require 'gnuplot'
2 |
3 | n = 0
4 | local file1 = io.open("./Results/reward2.txt")
5 | if file1 then
6 | for line in file1:lines() do
7 | n = n+1
8 | end
9 | end
10 | file1:close()
11 |
12 | local file1 = io.open("./Results/reward1.txt")
13 | local file2 = io.open("./Results/reward2.txt")
14 |
15 | i = 1
16 |
17 |
18 | x = torch.Tensor(n)
19 | if file1 then
20 | for line in file1:lines() do
21 | x[i] = tonumber(line)
22 | i = i+1
23 | end
24 | end
25 |
26 |
27 | x_ax = torch.range(1,i-1)
28 |
29 |
30 | i = 1
31 | y = torch.Tensor(n)
32 | if file2 then
33 | for line in file2:lines() do
34 | y[i] = tonumber(line)
35 | i = i+1
36 | end
37 | end
38 | y_ax = torch.range(1,i-1)
39 |
40 |
41 | if x_ax:numel()>y_ax:numel() then
42 | n = y_ax:numel()
43 | else
44 | n = x_ax:numel()
45 | end
46 |
47 | data1 = torch.Tensor(n)
48 | data2 = torch.Tensor(n)
49 | for i = 1,n do
50 | data1[i] = x[i]
51 | data2[i] = y[i]
52 | end
53 |
54 |
55 | gnuplot.pngfigure('reward.png')
56 | gnuplot.title('Average Reward over training')
57 | gnuplot.plot({'Policy Noise',data1},{'Q noise',data2})
58 | gnuplot.xlabel('Training epochs')
59 | gnuplot.ylabel('Average testing reward')
60 | gnuplot.movelegend('left','top')
61 | gnuplot.plotflush()
62 |
63 |
64 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/Plot3.lua:
--------------------------------------------------------------------------------
1 | require 'gnuplot'
2 |
3 | n = 0
4 | local file1 = io.open("./Results/reward1.txt")
5 | if file1 then
6 | for line in file1:lines() do
7 | n = n+1
8 | end
9 | end
10 | file1:close()
11 |
12 | local file1 = io.open("./Results/reward1.txt")
13 | local file2 = io.open("./Results/reward2.txt")
14 | local file3 = io.open("./Results/reward3.txt")
15 | i = 1
16 |
17 |
18 | x = torch.Tensor(n)
19 | if file1 then
20 | for line in file1:lines() do
21 | x[i] = tonumber(line)
22 | i = i+1
23 | end
24 | end
25 |
26 |
27 | x_ax = torch.range(1,i-1)
28 |
29 |
30 | i = 1
31 | y = torch.Tensor(n)
32 | if file2 then
33 | for line in file2:lines() do
34 | y[i] = tonumber(line)
35 | i = i+1
36 | end
37 | end
38 | y_ax = torch.range(1,i-1)
39 |
40 | i = 1
41 | z = torch.Tensor(n)
42 | if file3 then
43 | for line in file3:lines() do
44 | z[i] = tonumber(line)
45 | i = i+1
46 | end
47 | end
48 | z_ax = torch.range(1,i-1)
49 |
50 |
51 | if x_ax:numel()>y_ax:numel() then
52 | n = y_ax:numel()
53 | else
54 | n = x_ax:numel()
55 | end
56 | if z_ax:numel()<n then
57 | n = z_ax:numel()
58 | end
59 | -- (remaining plotting code of Plot3.lua was lost in this export)
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/Process.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cat $1 | grep "TD" | cut -d"V" -f 2|cut -f2 > ./Results/vavg.txt
4 | cat $1 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward.txt
5 | cat $1 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD.txt
6 | cat $1 | grep "nn.Sequential" | cut -d" " -f 3 > ./Results/conv1.txt
7 | cat $1 | grep "nn.Sequential" | cut -d" " -f 5 > ./Results/conv2.txt
8 | cat $1 | grep "nn.Sequential" | cut -d" " -f 7 > ./Results/conv3.txt
9 | cat $1 | grep "nn.Sequential" | cut -d" " -f 9 > ./Results/lin1.txt
10 | cat $1 | grep "nn.Sequential" | cut -d" " -f 11 | cut -d"]" -f1 > ./Results/lin2.txt
11 |
12 |
13 |
14 | th ./Results/Plot.lua
15 |
16 |
17 |
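18 | # $1 is expected to be a training log: the "TD" / "reward:" lines feed the value,
19 | # TD-error and reward plots, and the per-layer "nn.Sequential" norm lines feed conv*/lin*.txt.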
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/Process2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cat $1 | grep "TD" | cut -d"V" -f 2|cut -f2 > ./Results/vavg1.txt
4 | cat $1 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward1.txt
5 | cat $1 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD1.txt
6 |
7 | cat $2 | grep "TD" | cut -d"V" -f 2|cut -f2 > ./Results/vavg2.txt
8 | cat $2 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward2.txt
9 | cat $2 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD2.txt
10 |
11 | th ./Results/Plot2.lua
12 |
13 |
14 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/Process3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cat $1 | grep "TD" | cut -d"V" -f 2|cut -f2 > ./Results/vavg1.txt
4 | cat $1 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward1.txt
5 | cat $1 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD1.txt
6 |
7 | cat $2 | grep "TD" | cut -d"V" -f 2|cut -f2 > ./Results/vavg2.txt
8 | cat $2 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward2.txt
9 | cat $2 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD2.txt
10 |
11 | cat $3 | grep "TD" | cut -d"V" -f 2|cut -f2 > ./Results/vavg3.txt
12 | cat $3 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward3.txt
13 | cat $3 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD3.txt
14 |
15 | th ./Results/Plot3.lua
16 |
17 |
18 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/TD_Error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/TD_Error.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/conv1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv1.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/conv1_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv1_max.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/conv2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv2.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/conv2_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv2_max.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/conv3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv3.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/conv3_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv3_max.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/lin1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/lin1.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/lin1_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/lin1_max.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/lin2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/lin2.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/lin2_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/lin2_max.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/reward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/reward.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/logs/Results/tmp/vavg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/vavg.png
--------------------------------------------------------------------------------
/graying_the_box/LUA/roms/README:
--------------------------------------------------------------------------------
1 | ROM files should be put in this directory
2 |
--------------------------------------------------------------------------------
/graying_the_box/LUA/run_gpu:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -z "$1" ]
4 | then echo "Please provide the name of the game, e.g. ./run_gpu breakout "; exit 0
5 | fi
6 | ENV=$1
7 | FRAMEWORK="alewrap"
8 |
9 | game_path=$PWD"/roms/"
10 | env_params="useRGB=true"
11 | agent="NeuralQLearner"
12 | n_replay=1
13 | netfile="\"convnet_atari3\""
14 | update_freq=4
15 | actrep=4
16 | discount=0.99
17 | seed=1
18 |
19 | learn_start=1000000
20 | replay_memory=1000000
21 | eps_start=0.1 #1
22 | lr=0.00025 #0.00025
23 |
24 |
25 | pool_frms_type="\"max\""
26 | pool_frms_size=2
27 | initial_priority="false"
28 |
29 | eps_end=0.1
30 | eps_endt=replay_memory
31 |
32 | agent_type="DQN3_0_1"
33 | preproc_net="\"net_downsample_2x_full_y\""
34 | agent_name=$agent_type"_"$1"_SeaQuest"
35 | state_dim=7056
36 | ncols=1
37 | agent_params="lr="$lr",ep="$eps_start",ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1"
38 | steps=50000000
39 | eval_freq=250000
40 | eval_steps=125000
41 | prog_freq=10000
42 | save_freq=1000000
43 | gpu=0
44 | random_starts=30
45 | pool_frms="type="$pool_frms_type",size="$pool_frms_size
46 | num_threads=4
47 |
48 | args="-network $2 -framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads"
49 | echo $args
50 |
51 | cd dqn
52 | qlua ./train_agent.lua $args
53 |
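54 | # Usage sketch (checkpoint name is hypothetical): ./run_gpu breakout DQN3_0_1_breakout.t7
55 | # $1 selects the ROM/environment; the optional $2 is forwarded to train_agent.lua as
56 | # -network, so training can resume from a previously saved model.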
--------------------------------------------------------------------------------
/graying_the_box/bhtsne/Makefile.win:
--------------------------------------------------------------------------------
1 | CXX = cl.exe
2 | CFLAGS = /nologo /O2 /EHsc /D "_CRT_SECURE_NO_DEPRECATE" /D "USEOMP" /openmp
3 |
4 | TARGET = windows
5 |
6 | all: $(TARGET) $(TARGET)\bh_tsne.exe
7 |
8 | $(TARGET)\bh_tsne.exe: tsne.obj sptree.obj
9 | $(CXX) $(CFLAGS) tsne.obj sptree.obj -Fe$(TARGET)\bh_tsne.exe
10 |
11 | sptree.obj: sptree.cpp sptree.h
12 | $(CXX) $(CFLAGS) -c sptree.cpp
13 |
14 | tsne.obj: tsne.cpp tsne.h sptree.h vptree.h
15 | $(CXX) $(CFLAGS) -c tsne.cpp
16 |
17 | .PHONY: $(TARGET)
18 | $(TARGET):
19 | -mkdir $(TARGET)
20 |
21 | clean:
22 | -erase /Q *.obj *.exe $(TARGET)\.
23 | -rd $(TARGET)
24 |
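25 | # Build sketch (assumes a Visual Studio developer prompt): nmake -f Makefile.win
26 | # compiles tsne.cpp and sptree.cpp and links them into windows\bh_tsne.exe.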
--------------------------------------------------------------------------------
/graying_the_box/bhtsne/parse_lua_tensor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def parse_lua_tensor(file, dim):
4 |
5 | k = 0
6 | read_data = []
7 | with open(file, 'r') as f:
8 | for line in f:
9 | if 'Columns' in line or line == '\n':
10 | continue
11 | else:
12 | read_data.append(line)
13 | k += 1
14 |
15 | read_data = read_data[2:-1]
16 | samples = []
17 | for line in read_data:
18 | line = line[0:-1] # remove end of line mark
19 | line_ = line.split(' ')
20 | line_ = filter(None,line_)
21 |
22 | samples.append(line_)
23 |
24 | z = np.array(samples)
25 | num_cols = dim # DEBUG HERE!!!
26 | num_lines = z.shape[0] * z.shape[1] / num_cols
27 | my_mat = np.zeros(shape=(num_lines,num_cols), dtype='float32')
28 |
29 | num_blocks = z.shape[0] / num_lines
30 |
31 | for i in range(num_blocks):
32 | my_mat[:,i*z.shape[1]:(i+1)*z.shape[1]] = z[num_lines*i : (i+1)*num_lines,:]
33 |
34 | return my_mat
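35 |
36 | # Usage sketch (file name is hypothetical): torch prints wide tensors in
37 | # "Columns x to y" blocks; parse_lua_tensor('activations.txt', 512) drops those
38 | # headers and stitches the blocks back into a (num_rows, 512) float32 matrix.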
--------------------------------------------------------------------------------
/graying_the_box/bhtsne/sptree.h:
--------------------------------------------------------------------------------
1 | /*
2 | *
3 | * Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology)
4 | * All rights reserved.
5 | *
6 | * Redistribution and use in source and binary forms, with or without
7 | * modification, are permitted provided that the following conditions are met:
8 | * 1. Redistributions of source code must retain the above copyright
9 | * notice, this list of conditions and the following disclaimer.
10 | * 2. Redistributions in binary form must reproduce the above copyright
11 | * notice, this list of conditions and the following disclaimer in the
12 | * documentation and/or other materials provided with the distribution.
13 | * 3. All advertising materials mentioning features or use of this software
14 | * must display the following acknowledgement:
15 | * This product includes software developed by the Delft University of Technology.
16 | * 4. Neither the name of the Delft University of Technology nor the names of
17 | * its contributors may be used to endorse or promote products derived from
18 | * this software without specific prior written permission.
19 | *
20 | * THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS
21 | * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
22 | * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
23 | * EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
26 | * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
28 | * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
29 | * OF SUCH DAMAGE.
30 | *
31 | */
32 |
33 |
34 | #ifndef SPTREE_H
35 | #define SPTREE_H
36 |
37 | using namespace std;
38 |
39 |
40 | class Cell {
41 |
42 | unsigned int dimension;
43 | double* corner;
44 | double* width;
45 |
46 |
47 | public:
48 | Cell(unsigned int inp_dimension);
49 | Cell(unsigned int inp_dimension, double* inp_corner, double* inp_width);
50 | ~Cell();
51 |
52 | double getCorner(unsigned int d);
53 | double getWidth(unsigned int d);
54 | void setCorner(unsigned int d, double val);
55 | void setWidth(unsigned int d, double val);
56 | bool containsPoint(double point[]);
57 | };
58 |
59 |
60 | class SPTree
61 | {
62 |
63 | // Fixed constants
64 | static const unsigned int QT_NODE_CAPACITY = 1;
65 |
66 | // A buffer we use when doing force computations
67 | double* buff;
68 |
69 | // Properties of this node in the tree
70 | SPTree* parent;
71 | unsigned int dimension;
72 | bool is_leaf;
73 | unsigned int size;
74 | unsigned int cum_size;
75 |
76 | // Axis-aligned bounding box stored as a center with half-dimensions to represent the boundaries of this quad tree
77 | Cell* boundary;
78 |
79 | // Indices in this space-partitioning tree node, corresponding center-of-mass, and list of all children
80 | double* data;
81 | double* center_of_mass;
82 | unsigned int index[QT_NODE_CAPACITY];
83 |
84 | // Children
85 | SPTree** children;
86 | unsigned int no_children;
87 |
88 | public:
89 | SPTree(unsigned int D, double* inp_data, unsigned int N);
90 | SPTree(unsigned int D, double* inp_data, double* inp_corner, double* inp_width);
91 | SPTree(unsigned int D, double* inp_data, unsigned int N, double* inp_corner, double* inp_width);
92 | SPTree(SPTree* inp_parent, unsigned int D, double* inp_data, unsigned int N, double* inp_corner, double* inp_width);
93 | SPTree(SPTree* inp_parent, unsigned int D, double* inp_data, double* inp_corner, double* inp_width);
94 | ~SPTree();
95 | void setData(double* inp_data);
96 | SPTree* getParent();
97 | void construct(Cell boundary);
98 | bool insert(unsigned int new_index);
99 | void subdivide();
100 | bool isCorrect();
101 | void rebuildTree();
102 | void getAllIndices(unsigned int* indices);
103 | unsigned int getDepth();
104 | void computeNonEdgeForces(unsigned int point_index, double theta, double neg_f[], double* sum_Q);
105 | void computeEdgeForces(unsigned int* row_P, unsigned int* col_P, double* val_P, int N, double* pos_f);
106 | void print();
107 |
108 | private:
109 | void init(SPTree* inp_parent, unsigned int D, double* inp_data, double* inp_corner, double* inp_width);
110 | void fill(unsigned int N);
111 | unsigned int getAllIndices(unsigned int* indices, unsigned int loc);
112 | bool isChild(unsigned int test_index, unsigned int start, unsigned int end);
113 | };
114 |
115 | #endif
116 |
--------------------------------------------------------------------------------
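A brief aside on how this tree is used: computeNonEdgeForces approximates the repulsive t-SNE forces in the Barnes-Hut style, summarising a whole cell by its centre of mass when the cell is small relative to its distance from the query point (the actual test lives in sptree.cpp, which is not reproduced here). A minimal Python sketch of that standard criterion, with made-up numbers:

import numpy as np

# Barnes-Hut criterion: a cell may stand in for all of its points when
# (maximum cell width) / (distance to the cell's centre of mass) < theta.
# theta = 0 recovers the exact O(N^2) summation.
def can_summarize(point, center_of_mass, cell_width, theta=0.5):
    dist = np.linalg.norm(point - center_of_mass)
    return np.max(cell_width) / dist < theta

print(can_summarize(np.array([0., 0.]), np.array([10., 0.]), np.array([1., 1.])))  # True: small, far-away cell
print(can_summarize(np.array([0., 0.]), np.array([1., 0.]), np.array([1., 1.])))   # False: too close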
/graying_the_box/bhtsne/tsne.h:
--------------------------------------------------------------------------------
1 | /*
2 | *
3 | * Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology)
4 | * All rights reserved.
5 | *
6 | * Redistribution and use in source and binary forms, with or without
7 | * modification, are permitted provided that the following conditions are met:
8 | * 1. Redistributions of source code must retain the above copyright
9 | * notice, this list of conditions and the following disclaimer.
10 | * 2. Redistributions in binary form must reproduce the above copyright
11 | * notice, this list of conditions and the following disclaimer in the
12 | * documentation and/or other materials provided with the distribution.
13 | * 3. All advertising materials mentioning features or use of this software
14 | * must display the following acknowledgement:
15 | * This product includes software developed by the Delft University of Technology.
16 | * 4. Neither the name of the Delft University of Technology nor the names of
17 | * its contributors may be used to endorse or promote products derived from
18 | * this software without specific prior written permission.
19 | *
20 | * THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS
21 | * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
22 | * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
23 | * EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
26 | * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
28 | * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
29 | * OF SUCH DAMAGE.
30 | *
31 | */
32 |
33 |
34 | #ifndef TSNE_H
35 | #define TSNE_H
36 |
37 |
38 | static inline double sign(double x) { return (x == .0 ? .0 : (x < .0 ? -1.0 : 1.0)); }
39 |
40 |
41 | class TSNE
42 | {
43 | public:
44 | void run(double* X, int N, int D, double* Y, int no_dims, double perplexity, double theta);
45 | bool load_data(double** data, int* n, int* d, int* no_dims, double* theta, double* perplexity, int* rand_seed);
46 | void save_data(double* data, int* landmarks, double* costs, int n, int d);
47 | void symmetrizeMatrix(unsigned int** row_P, unsigned int** col_P, double** val_P, int N); // should be static!
48 |
49 |
50 | private:
51 | void computeGradient(double* P, unsigned int* inp_row_P, unsigned int* inp_col_P, double* inp_val_P, double* Y, int N, int D, double* dC, double theta);
52 | void computeExactGradient(double* P, double* Y, int N, int D, double* dC);
53 | double evaluateError(double* P, double* Y, int N, int D);
54 | double evaluateError(unsigned int* row_P, unsigned int* col_P, double* val_P, double* Y, int N, int D, double theta);
55 | void zeroMean(double* X, int N, int D);
56 | void computeGaussianPerplexity(double* X, int N, int D, double* P, double perplexity);
57 | void computeGaussianPerplexity(double* X, int N, int D, unsigned int** _row_P, unsigned int** _col_P, double** _val_P, double perplexity, int K);
58 | void computeSquaredEuclideanDistance(double* X, int N, int D, double* DD);
59 | double randn();
60 | };
61 |
62 | #endif
63 |
64 |
--------------------------------------------------------------------------------
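For orientation, the quantity these methods differentiate is the KL divergence between the input similarities p_ij and the embedding similarities q_ij. In the standard formulation (van der Maaten & Hinton, 2008) the exact gradient is

    \frac{\partial C}{\partial y_i} = 4 \sum_{j \neq i} (p_{ij} - q_{ij}) \, (1 + \lVert y_i - y_j \rVert^2)^{-1} \, (y_i - y_j)

and the Barnes-Hut variant behind computeGradient splits it into an attractive term over the sparse row_P/col_P/val_P similarities and a repulsive term approximated with the SPTree, as described in van der Maaten (2014).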
/graying_the_box/bhtsne/vptree.h:
--------------------------------------------------------------------------------
1 | /*
2 | *
3 | * Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology)
4 | * All rights reserved.
5 | *
6 | * Redistribution and use in source and binary forms, with or without
7 | * modification, are permitted provided that the following conditions are met:
8 | * 1. Redistributions of source code must retain the above copyright
9 | * notice, this list of conditions and the following disclaimer.
10 | * 2. Redistributions in binary form must reproduce the above copyright
11 | * notice, this list of conditions and the following disclaimer in the
12 | * documentation and/or other materials provided with the distribution.
13 | * 3. All advertising materials mentioning features or use of this software
14 | * must display the following acknowledgement:
15 | * This product includes software developed by the Delft University of Technology.
16 | * 4. Neither the name of the Delft University of Technology nor the names of
17 | * its contributors may be used to endorse or promote products derived from
18 | * this software without specific prior written permission.
19 | *
20 | * THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS
21 | * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
22 | * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
23 | * EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
26 | * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
28 | * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
29 | * OF SUCH DAMAGE.
30 | *
31 | */
32 |
33 |
34 | /* This code was adopted with minor modifications from Steve Hanov's great tutorial at http://stevehanov.ca/blog/index.php?id=130 */
35 |
36 | #include <stdlib.h>
37 | #include <algorithm>
38 | #include <vector>
39 | #include <stdio.h>
40 | #include <queue>
41 | #include <limits>
42 | #include <cmath>
43 |
44 |
45 | #ifndef VPTREE_H
46 | #define VPTREE_H
47 |
48 | class DataPoint
49 | {
50 | int _ind;
51 |
52 | public:
53 | double* _x;
54 | int _D;
55 | DataPoint() {
56 | _D = 1;
57 | _ind = -1;
58 | _x = NULL;
59 | }
60 | DataPoint(int D, int ind, double* x) {
61 | _D = D;
62 | _ind = ind;
63 | _x = (double*) malloc(_D * sizeof(double));
64 | for(int d = 0; d < _D; d++) _x[d] = x[d];
65 | }
66 | DataPoint(const DataPoint& other) { // this makes a deep copy -- should not free anything
67 | if(this != &other) {
68 | _D = other.dimensionality();
69 | _ind = other.index();
70 | _x = (double*) malloc(_D * sizeof(double));
71 | for(int d = 0; d < _D; d++) _x[d] = other.x(d);
72 | }
73 | }
74 | ~DataPoint() { if(_x != NULL) free(_x); }
75 | DataPoint& operator= (const DataPoint& other) { // assignment should free the old object
76 | if(this != &other) {
77 | if(_x != NULL) free(_x);
78 | _D = other.dimensionality();
79 | _ind = other.index();
80 | _x = (double*) malloc(_D * sizeof(double));
81 | for(int d = 0; d < _D; d++) _x[d] = other.x(d);
82 | }
83 | return *this;
84 | }
85 | int index() const { return _ind; }
86 | int dimensionality() const { return _D; }
87 | double x(int d) const { return _x[d]; }
88 | };
89 |
90 | double euclidean_distance(const DataPoint &t1, const DataPoint &t2) {
91 | double dd = .0;
92 | double* x1 = t1._x;
93 | double* x2 = t2._x;
94 | double diff;
95 | for(int d = 0; d < t1._D; d++) {
96 | diff = (x1[d] - x2[d]);
97 | dd += diff * diff;
98 | }
99 | return sqrt(dd);
100 | }
101 |
102 |
103 | template<typename T, double (*distance)( const T&, const T& )>
104 | class VpTree
105 | {
106 | public:
107 |
108 | // Default constructor
109 | VpTree() : _root(0) {}
110 |
111 | // Destructor
112 | ~VpTree() {
113 | delete _root;
114 | }
115 |
116 | // Function to create a new VpTree from data
117 | void create(const std::vector<T>& items) {
118 | delete _root;
119 | _items = items;
120 | _root = buildFromPoints(0, items.size());
121 | }
122 |
123 | // Function that uses the tree to find the k nearest neighbors of target
124 | void search(const T& target, int k, std::vector<T>* results, std::vector<double>* distances)
125 | {
126 |
127 | // Use a priority queue to store intermediate results on
128 | std::priority_queue<HeapItem> heap;
129 |
130 | // Variable that tracks the distance to the farthest point in our results
131 | _tau = DBL_MAX;
132 |
133 | // Perform the search
134 | search(_root, target, k, heap);
135 |
136 | // Gather final results
137 | results->clear(); distances->clear();
138 | while(!heap.empty()) {
139 | results->push_back(_items[heap.top().index]);
140 | distances->push_back(heap.top().dist);
141 | heap.pop();
142 | }
143 |
144 | // Results are in reverse order
145 | std::reverse(results->begin(), results->end());
146 | std::reverse(distances->begin(), distances->end());
147 | }
148 |
149 | private:
150 | std::vector<T> _items;
151 | double _tau;
152 |
153 | // Single node of a VP tree (has a point and radius; left children are closer to point than the radius)
154 | struct Node
155 | {
156 | int index; // index of point in node
157 | double threshold; // radius(?)
158 | Node* left; // points closer by than threshold
159 | Node* right; // points farther away than threshold
160 |
161 | Node() :
162 | index(0), threshold(0.), left(0), right(0) {}
163 |
164 | ~Node() { // destructor
165 | delete left;
166 | delete right;
167 | }
168 | }* _root;
169 |
170 |
171 | // An item on the intermediate result queue
172 | struct HeapItem {
173 | HeapItem( int index, double dist) :
174 | index(index), dist(dist) {}
175 | int index;
176 | double dist;
177 | bool operator<(const HeapItem& o) const {
178 | return dist < o.dist;
179 | }
180 | };
181 |
182 | // Distance comparator for use in std::nth_element
183 | struct DistanceComparator
184 | {
185 | const T& item;
186 | DistanceComparator(const T& item) : item(item) {}
187 | bool operator()(const T& a, const T& b) {
188 | return distance(item, a) < distance(item, b);
189 | }
190 | };
191 |
192 | // Function that (recursively) fills the tree
193 | Node* buildFromPoints( int lower, int upper )
194 | {
195 | if (upper == lower) { // indicates that we're done here!
196 | return NULL;
197 | }
198 |
199 | // Lower index is center of current node
200 | Node* node = new Node();
201 | node->index = lower;
202 |
203 | if (upper - lower > 1) { // if we did not arrive at leaf yet
204 |
205 | // Choose an arbitrary point and move it to the start
206 | int i = (int) ((double)rand() / RAND_MAX * (upper - lower - 1)) + lower;
207 | std::swap(_items[lower], _items[i]);
208 |
209 | // Partition around the median distance
210 | int median = (upper + lower) / 2;
211 | std::nth_element(_items.begin() + lower + 1,
212 | _items.begin() + median,
213 | _items.begin() + upper,
214 | DistanceComparator(_items[lower]));
215 |
216 | // Threshold of the new node will be the distance to the median
217 | node->threshold = distance(_items[lower], _items[median]);
218 |
219 | // Recursively build tree
220 | node->index = lower;
221 | node->left = buildFromPoints(lower + 1, median);
222 | node->right = buildFromPoints(median, upper);
223 | }
224 |
225 | // Return result
226 | return node;
227 | }
228 |
229 | // Helper function that searches the tree
230 | void search(Node* node, const T& target, int k, std::priority_queue<HeapItem>& heap)
231 | {
232 | if(node == NULL) return; // indicates that we're done here
233 |
234 | // Compute distance between target and current node
235 | double dist = distance(_items[node->index], target);
236 |
237 | // If current node within radius tau
238 | if(dist < _tau) {
239 | if(heap.size() == k) heap.pop(); // remove furthest node from result list (if we already have k results)
240 | heap.push(HeapItem(node->index, dist)); // add current node to result list
241 | if(heap.size() == k) _tau = heap.top().dist; // update value of tau (farthest point in result list)
242 | }
243 |
244 | // Return if we arrived at a leaf
245 | if(node->left == NULL && node->right == NULL) {
246 | return;
247 | }
248 |
249 | // If the target lies within the radius of ball
250 | if(dist < node->threshold) {
251 | if(dist - _tau <= node->threshold) { // if there can still be neighbors inside the ball, recursively search left child first
252 | search(node->left, target, k, heap);
253 | }
254 |
255 | if(dist + _tau >= node->threshold) { // if there can still be neighbors outside the ball, recursively search right child
256 | search(node->right, target, k, heap);
257 | }
258 |
259 | // If the target lies outside the radius of the ball
260 | } else {
261 | if(dist + _tau >= node->threshold) { // if there can still be neighbors outside the ball, recursively search right child first
262 | search(node->right, target, k, heap);
263 | }
264 |
265 | if (dist - _tau <= node->threshold) { // if there can still be neighbors inside the ball, recursively search left child
266 | search(node->left, target, k, heap);
267 | }
268 | }
269 | }
270 | };
271 |
272 | #endif
273 |
--------------------------------------------------------------------------------
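The search routine above keeps at most k candidates in a bounded priority queue and lets tau, the distance to the worst kept candidate, shrink as better ones arrive; tau is then compared against node->threshold to prune whole subtrees. As a stand-alone illustration of just the queue/tau bookkeeping (not part of the repository; brute-force scan instead of tree pruning):

import heapq

def knn_with_tau(points, target, k, dist):
    heap = []                 # max-heap emulated with negated distances
    tau = float('inf')        # distance to the farthest candidate kept so far
    for idx, p in enumerate(points):
        d = dist(p, target)
        if d < tau:
            if len(heap) == k:
                heapq.heappop(heap)            # drop the current farthest
            heapq.heappush(heap, (-d, idx))
            if len(heap) == k:
                tau = -heap[0][0]              # tighten the radius of interest
    return sorted((-nd, idx) for nd, idx in heap)

pts = [(0.0,), (1.0,), (4.0,), (9.0,)]
print(knn_with_tau(pts, (3.0,), 2, lambda a, b: abs(a[0] - b[0])))  # [(1.0, 2), (2.0, 1)]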
/graying_the_box/bhtsne/write_movie.py:
--------------------------------------------------------------------------------
1 | # This example uses a MovieWriter directly to grab individual frames and
2 | # write them to a file. This avoids any event loop integration, but has
3 | # the advantage of working with even the Agg backend. This is not recommended
4 | # for use in an interactive setting.
5 | # -*- noplot -*-
6 |
7 | import numpy as np
8 | import matplotlib
9 | matplotlib.use("Agg")
10 | import matplotlib.pyplot as plt
11 | import matplotlib.animation as manimation
12 |
13 | FFMpegWriter = manimation.writers['ffmpeg']
14 | metadata = dict(title='Movie Test', artist='Matplotlib',
15 | comment='Movie support!')
16 | writer = FFMpegWriter(fps=15, metadata=metadata)
17 |
18 | fig = plt.figure()
19 | l, = plt.plot([], [], 'k-o')
20 |
21 | plt.xlim(-5, 5)
22 | plt.ylim(-5, 5)
23 |
24 | x0,y0 = 0, 0
25 |
26 | with writer.saving(fig, "writer_test.mp4", 100):  # the third argument is the dpi used for the saved frames
27 | for i in range(100):
28 | x0 += 0.1 * np.random.randn()
29 | y0 += 0.1 * np.random.randn()
30 | l.set_data(x0, y0)
31 | writer.grab_frame()
--------------------------------------------------------------------------------
/graying_the_box/clustering.py:
--------------------------------------------------------------------------------
1 | from sklearn.cluster import SpectralClustering
2 | from sklearn.cluster import KMeans as Kmeans_st
3 | # from sklearn.cluster import KMeans_st as Kmeans_st
4 |
5 | from emhc import EMHC
6 | from smdp import SMDP
7 | import numpy as np
8 | import common
9 | from digraph import draw_transition_table
10 |
11 | def perpare_features(self, n_features=3):
12 |
13 | data = np.zeros(shape=(self.global_feats['tsne'].shape[0],n_features))
14 | data[:,0:2] = self.global_feats['tsne']
15 | data[:,2] = self.global_feats['value']
16 | # data[:,3] = self.global_feats['time']
17 | # data[:,4] = self.global_feats['termination']
18 | # data[:,5] = self.global_feats['tsne3d_norm']
19 | # data[:,6] = self.hand_craft_feats['missing_bricks']
20 | # data[:,6] = self.hand_craft_feats['hole']
21 | # data[:,7] = self.hand_craft_feats['racket']
22 | # data[:,8] = self.hand_craft_feats['ball_dir']
23 | # data[:,9] = self.hand_craft_feats['traj']
24 | # data[:,9:11] = self.hand_craft_feats['ball_pos']
25 | data[np.isnan(data)] = 0
26 | # 1.2 data standardization
27 | # scaler = preprocessing.StandardScaler(with_centering=False).fit(data)
28 | # data = scaler.fit_transform(data)
29 |
30 | # data_mean = data.mean(axis=0)
31 | # data -= data_mean
32 | return data
33 |
34 | def clustering_(self, plt, n_points=None, force=0):
35 |
36 | if n_points is None:
37 | n_points = self.global_feats['termination'].shape[0]
38 |
39 | if self.clustering_labels is not None:
40 | self.tsne_scat.set_array(self.clustering_labels.astype(np.float32)/self.clustering_labels.max())
41 | draw_transition_table(transition_table=self.smdp.P, cluster_centers=self.cluster_centers,
42 | meanscreen=self.meanscreen, tsne=self.global_feats['tsne'], color=self.color, black_edges=self.smdp.edges)
43 | plt.show()
44 | if force==0:
45 | return
46 |
47 | n_clusters = self.cluster_params['n_clusters']
48 | W = self.cluster_params['window_size']
49 | n_iters = self.cluster_params['n_iters']
50 | entropy_iters = self.cluster_params['entropy_iters']
51 |
52 | # slice data by given indices
53 | term = self.global_feats['termination'][:n_points]
54 | reward = self.global_feats['reward'][:n_points]
55 | value = self.global_feats['value'][:n_points]
56 | tsne = self.global_feats['tsne'][:n_points]
57 | traj_ids = self.hand_craft_feats['traj'][:n_points]
58 |
59 | # 1. create data for clustering
60 | data = perpare_features(self)
61 | data = data[:n_points]
62 | data_scale = data.max(axis=0)
63 | data /= data_scale
64 |
65 | # 2. Build cluster model
66 | # 2.1 spatio-temporal K-means
67 | if self.cluster_params['method'] == 0:
68 | windows_vec = np.arange(start=W,stop=W+1,step=1)
69 | clusters_vec = np.arange(start=n_clusters,stop=n_clusters+1,step=1)
70 | models_vec = []
71 | scores = np.zeros(shape=(len(clusters_vec),1))
72 | for i,n_w in enumerate(windows_vec):
73 | for j,n_c in enumerate(clusters_vec):
74 | cluster_model = Kmeans_st(n_clusters=n_clusters,window_size=n_w,n_jobs=8,n_init=n_iters,entropy_iters=entropy_iters)
75 | cluster_model.fit(data, rewards=reward, termination=term, values=value)
76 | labels = cluster_model.labels_
77 | models_vec.append(cluster_model.smdp)
78 | scores[j] = cluster_model.smdp.score
79 | print 'window size: %d , Value mse: %f' % (n_w, cluster_model.smdp.score)
80 | best = np.argmin(scores)
81 | self.cluster_params['n_clusters'] +=best
82 | self.smdp = models_vec[best]
83 |
84 | # 2.1 Spectral clustering
85 | elif self.cluster_params['method'] == 1:
86 | import scipy.spatial.distance
87 | import scipy.sparse
88 | dists = scipy.spatial.distance.pdist(tsne, 'euclidean')
89 | similarity = np.exp(-dists/10)
90 | similarity[similarity<1e-2] = 0
91 | print 'Created similarity matrix'
92 | affine_mat = scipy.spatial.distance.squareform(similarity)
93 | cluster_model = SpectralClustering(n_clusters=n_clusters,affinity='precomputed')
94 | labels = cluster_model.fit_predict(affine_mat)
95 |
96 | # 2.2 EMHC
97 | elif self.cluster_params['method'] == 2:
98 | # cluster with k means down to n_clusters + D
99 | n_clusters_ = n_clusters + 5
100 | kmeans_st_model = Kmeans_st(n_clusters=n_clusters_,window_size=W,n_jobs=8,n_init=n_iters,entropy_iters=entropy_iters, random_state=123)
101 | kmeans_st_model.fit(data, rewards=reward, termination=term, values=value)
102 | cluster_model = EMHC(X=data, labels=kmeans_st_model.labels_, termination=term, min_clusters=n_clusters, max_entropy=np.inf)
103 | cluster_model.fit()
104 | labels = cluster_model.labels_
105 | self.smdp = SMDP(labels=labels, termination=term, rewards=reward, values=value, n_clusters=n_clusters)
106 |
107 | self.smdp.complete_smdp()
108 | self.clustering_labels = self.smdp.labels
109 | common.create_trajectory_data(self, reward, traj_ids)
110 | self.state_pi_correlation = common.reward_policy_correlation(self.traj_list, self.smdp.greedy_policy, self.smdp)
111 |
112 | top_greedy_vec = []
113 | bottom_greedy_vec = []
114 | max_diff = 0
115 | best_d = 1
116 | for i,d in enumerate(xrange(1,30)):
117 | tb_trajs_discr = common.extermum_trajs_discrepency(self.traj_list, self.clustering_labels, term, reward, value, self.smdp.n_clusters, self.smdp.greedy_policy, d=d)
118 | top_greedy_vec.append([i,tb_trajs_discr['top_greedy_sum']])
119 | bottom_greedy_vec.append([i,tb_trajs_discr['bottom_greedy_sum']])
120 | diff_i = tb_trajs_discr['top_greedy_sum'] - tb_trajs_discr['bottom_greedy_sum']
121 | if diff_i > max_diff:
122 | max_diff = diff_i
123 | best_d = d
124 |
125 | self.tb_trajs_discr = common.extermum_trajs_discrepency(self.traj_list, self.clustering_labels, term, reward, value, self.smdp.n_clusters, self.smdp.greedy_policy, d=best_d)
126 | self.top_greedy_vec = top_greedy_vec
127 | self.bottom_greedy_vec = bottom_greedy_vec
128 |
129 | common.draw_skills(self,self.smdp.n_clusters,plt)
130 |
131 |
132 | # 4. collect statistics
133 | cluster_centers = cluster_model.cluster_centers_
134 | cluster_centers *= data_scale
135 |
136 | screen_size = self.screens.shape
137 | meanscreen = np.zeros(shape=(n_clusters,screen_size[1],screen_size[2],screen_size[3]))
138 | cluster_time = np.zeros(shape=(n_clusters,1))
139 | width = int(np.floor(np.sqrt(n_clusters)))
140 | length = int(n_clusters/width)
141 | # f, ax = plt.subplots(length,width)
142 |
143 | for cluster_ind in range(n_clusters):
144 | indices = (labels==cluster_ind)
145 | cluster_data = data[indices]
146 | cluster_time[cluster_ind] = np.mean(self.global_feats['time'][indices])
147 | meanscreen[cluster_ind,:,:,:] = common.calc_cluster_im(self,indices)
148 |
149 | # 5. draw cluster indices
150 | plt.figure(self.fig.number)
151 | data *= data_scale
152 | for i in range(n_clusters):
153 | self.ax_tsne.annotate(i, xy=cluster_centers[i,0:2], size=20, color='r')
154 | draw_transition_table(transition_table=self.smdp.P, cluster_centers=cluster_centers,
155 | meanscreen=meanscreen, tsne=data[:,0:2], color=self.color, black_edges=self.smdp.edges)
156 |
157 | self.cluster_centers = cluster_centers
158 | self.meanscreen =meanscreen
159 | self.cluster_time =cluster_time
160 | common.visualize(self)
161 |
162 | def update_slider(self, name, slider):
163 | def f():
164 | setattr(self, name, slider.val)
165 | return f
166 |
--------------------------------------------------------------------------------
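For readers who want to poke at the spectral-clustering branch above (method == 1) in isolation, here is a self-contained sketch of the same affinity construction on random stand-in data; in the code above the input is the 2-D t-SNE embedding rather than random points:

import numpy as np
import scipy.spatial.distance
from sklearn.cluster import SpectralClustering

X = 10 * np.random.rand(200, 2)                   # stand-in for global_feats['tsne']
dists = scipy.spatial.distance.pdist(X, 'euclidean')
similarity = np.exp(-dists / 10)                  # same kernel width as above
similarity[similarity < 1e-2] = 0                 # drop negligible affinities
affinity = scipy.spatial.distance.squareform(similarity)
labels = SpectralClustering(n_clusters=5, affinity='precomputed').fit_predict(affinity)
print(labels.shape)                               # (200,)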
/graying_the_box/digraph.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import matplotlib.pyplot as plt
3 | import pylab
4 | import pickle
5 | import numpy as np
6 | import common
7 |
8 | def draw_transition_table(transition_table, cluster_centers, meanscreen, tsne ,color, black_edges=None, red_edges=None, title=None):
9 | G = nx.DiGraph()
10 | edge_colors = []
11 |
12 | if red_edges is not None:
13 | for e in red_edges:
14 | G.add_edges_from([e], weight=np.round(transition_table[e[0],e[1]]*100)/100)
15 | edge_colors.append('red')
16 |
17 | if black_edges is not None:
18 | if red_edges is not None:
19 | black_edges = list(set(black_edges)-set(red_edges))
20 |
21 | for e in black_edges:
22 | G.add_edges_from([e], weight=np.round(transition_table[e[0],e[1]]*100)/100)
23 | edge_colors.append('black')
24 |
25 |
26 | edge_labels=dict([((u,v,),d['weight']) for u,v,d in G.edges(data=True)])
27 |
28 | node_labels = {node:node for node in G.nodes()};
29 | counter=0
30 | for key in node_labels.keys():
31 | node_labels[key] = counter
32 | counter+=1
33 |
34 | if title is None:
35 | fig = plt.figure('SMDP')
36 | fig.clear()
37 | else:
38 | fig = plt.figure(title)
39 |
40 | plt.scatter(tsne[:,0],tsne[:,1],s= np.ones(tsne.shape[0])*2,facecolor=color, edgecolor='none')
41 | pos = cluster_centers[:,0:2]
42 | nx.draw_networkx_edge_labels(G,pos,edge_labels=edge_labels,label_pos=0.65,font_size=9)
43 | nx.draw_networkx_labels(G, pos, labels=node_labels,font_color='w',font_size=8)
44 | nx.draw(G,pos,cmap=plt.cm.brg,edge_color=edge_colors)
45 |
46 |
47 | ######Present images on nodes
48 | ax = plt.subplot(111)
49 | plt.axis('off')
50 | trans = ax.transData.transform
51 | trans2 = fig.transFigure.inverted().transform
52 | cut = 1.01
53 | xmax = cut * max(tsne[:,0])
54 | ymax = cut * max(tsne[:,1])
55 | xmin = cut * min(tsne[:,0])
56 | ymin = cut * min(tsne[:,1])
57 | plt.xlim(xmin, xmax)
58 | plt.ylim(ymin, ymax)
59 |
60 | h = 70.0
61 | w = 70.0
62 | counter= 0
63 | for node in G:
64 | xx, yy = trans(pos[node])
65 | # axes coordinates
66 | xa, ya = trans2((xx, yy))
67 |
68 | # this is the image size
69 | piesize_1 = (300.0 / (h*80))
70 | piesize_2 = (300.0 / (w*80))
71 | p2_2 = piesize_2 / 2
72 | p2_1 = piesize_1 / 2
73 | a = plt.axes([xa - p2_2, ya - p2_1, piesize_2, piesize_1])
74 | G.node[node]['image'] = meanscreen[counter]
75 | #display it
76 | a.imshow(G.node[node]['image'])
77 | a.set_title(node_labels[counter])
78 | #turn off the axis from minor plot
79 | a.axis('off')
80 | counter+=1
81 | plt.draw()
82 |
83 | def draw_transition_table_no_image(transition_table,cluster_centers):
84 | G = nx.DiGraph()
85 | G2 = nx.DiGraph()
86 |
87 | # print transition_table.sum(axis=1)
88 |
89 | transition_table = (transition_table.transpose()/transition_table.sum(axis=1)).transpose()
90 | transition_table[np.isnan(transition_table)]=0
91 | # print(transition_table)
92 | # transition_table = (transition_table.transpose()/transition_table.sum(axis=1)).transpose()
93 | # print transition_table
94 | # print transition_table.sum(axis=0)
95 | # assert(np.all(transition_table.sum(axis=0)!=0))
96 | transition_table[transition_table<0.1]=0
97 |
98 | pos = cluster_centers[:,0:2]
99 | m,n = transition_table.shape
100 |
101 | for i in range(m):
102 | for j in range(n):
103 | if transition_table[i,j]!=0:
104 | G.add_edges_from([(i, j)], weight=np.round(transition_table[i,j]*100)/100)
105 | G2.add_edges_from([(i, j)], weight=np.round(transition_table[i,j]*100)/100)
106 | values = cluster_centers[:,2]
107 |
108 | red_edges = []
109 | edges_sizes =[]
110 | for i in range(n):
111 | trans = transition_table[i,:]
112 | indices = (trans!=0)
113 | index = np.argmax(cluster_centers[indices,2])
114 | counter = 0
115 | for j in range(len(indices)):
116 | if indices[j]:
117 | if counter == index:
118 | ind = j
119 | break
120 | else:
121 | counter+=1
122 | edges_sizes.append(ind)
123 | red_edges.append((i,ind))
124 | # print(red_edges)
125 | # sizes = 3000*cluster_centers[:,3]
126 | sizes = np.ones_like(values)*500
127 | edge_labels=dict([((u,v,),d['weight']) for u,v,d in G.edges(data=True)])
128 | edge_colors = ['black' for edge in G.edges()]
129 | # edge_colors = ['black' if not edge in red_edges else 'red' for edge in G.edges()]
130 |
131 |
132 | node_labels = {node:node for node in G.nodes()};
133 | counter=0
134 | for key in node_labels.keys():
135 | # node_labels[key] = np.round(100*cluster_centers[counter,3])/100
136 | node_labels[key] = counter
137 | counter+=1
138 |
139 | fig = plt.figure()
140 | nx.draw_networkx_edge_labels(G,pos,edge_labels=edge_labels,label_pos=0.65,font_size=9)
141 | nx.draw_networkx_labels(G, pos, labels=node_labels,font_color='w',font_size=8)
142 | nx.draw(G,pos, node_color = values,cmap=plt.cm.brg, node_size=np.round(sizes),edge_color=edge_colors,edge_cmap=plt.cm.Reds)
143 |
144 |
145 | ######Present images on nodes
146 | # plt.show()
147 |
148 | #
149 | def test():
150 | gamename = 'breakout' #breakout pacman
151 | transition_table = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'transition_table.bin'))
152 | cluster_centers = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_centers.bin'))
153 | cluster_std = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_std.bin'))
154 | cluster_med = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_med.bin'))
155 | cluster_min = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_min.bin'))
156 | cluster_max = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_max.bin'))
157 | meanscreen = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'meanscreen.bin'))
158 | cluster_time = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_time.bin'))
159 | tsne = common.load_hdf5('lowd_activations', 'data/' + 'breakout' + '/'+'120k/')
160 | q_hdf5 = common.load_hdf5('qvals', 'data/' + 'breakout' + '/'+'120k/')
161 |
162 | num_frames = 120000
163 | V = np.zeros(shape=(num_frames))
164 |
165 | for i in range(0,num_frames):
166 | V[i] = max(q_hdf5[i])
167 | V = V/V.max()
168 | draw_transition_table(transition_table, cluster_centers, meanscreen, tsne, V)
169 | plt.show()
170 | # test()
171 |
172 |
173 |
174 | # stdscreen = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'stdscreen.bin'))
175 | # #
176 | # a = 1
177 | # b = 0
178 | # c = 0
179 | # screen = a*meanscreen + c*stdscreen
180 |
181 | # facecolor = self.color,
182 | # edgecolor='none',picker=5)
183 | # draw_transition_table_no_image(transition_table,cluster_centers)
184 |
185 | # transition_table = pickle.load(file('/home/tom/git/graying_the_box/data/seaquest/120k' + '/knn/' + 'transition_table.bin'))
186 | # transition_table[transition_table<0.1]=0
187 | # cluster_centers = pickle.load(file('/home/tom/git/graying_the_box/data/seaquest/120k' + '/knn/' + 'cluster_centers.bin'))
188 |
189 | # pos2 = np.zeros(shape=(cluster_centers.shape[0],2))
190 | # pos2[:,0] = cluster_time[:,0]
191 | # pos2[:,1] = cluster_centers[:,1]
192 | # plt.figure()
193 | # nx.draw_networkx_edge_labels(G2,pos2,edge_labels=edge_labels,label_pos=0.8,font_size=8)
194 | # nx.draw_networkx_labels(G2, pos2, labels=node_labels,font_color='w',font_size=8)
195 | # nx.draw(G2,pos2, node_color = values,cmap=plt.cm.brg, node_size=np.round(sizes),edge_color=edge_colors,edge_cmap=plt.cm.Reds)
196 |
--------------------------------------------------------------------------------
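A toy, self-contained version of the networkx calls that draw_transition_table builds on, using a hypothetical 3-state transition matrix and fixed node positions (the real code places nodes at the t-SNE cluster centres and overlays the mean screens):

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

P = np.array([[0.0, 0.7, 0.3],
              [0.1, 0.0, 0.9],
              [0.5, 0.5, 0.0]])
pos = {0: (0.0, 0.0), 1: (1.0, 1.0), 2: (2.0, 0.0)}

G = nx.DiGraph()
for i in range(P.shape[0]):
    for j in range(P.shape[1]):
        if P[i, j] > 0:
            G.add_edge(i, j, weight=np.round(P[i, j] * 100) / 100)

edge_labels = {(u, v): d['weight'] for u, v, d in G.edges(data=True)}
nx.draw(G, pos, with_labels=True, font_color='w')
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, label_pos=0.65, font_size=9)
plt.show()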
/graying_the_box/hand_crafted_features/add_breakout_buttons.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def add_game_buttons(self, top_y=0.69, x_left = 0.68):
4 |
5 | self.add_slider_button([x_left, top_y, 0.08, 0.01], 'ball_x', 0, 160)
6 | self.add_slider_button([x_left, top_y-0.04, 0.08, 0.01], 'ball_y', 0, 210)
7 | self.add_slider_button([x_left, top_y-0.08, 0.08, 0.01], 'racket_x', 0, 160)
8 | self.add_slider_button([x_left, top_y-0.12, 0.08, 0.01], 'missing_bricks', 0, 130)
9 | self.add_check_button([x_left, top_y-0.20, 0.06, 0.04], 'hole', ('no-hole','hole'), (True,True))
10 | self.add_check_button([x_left, top_y-0.30, 0.07, 0.08], 'ball_dir', ('down-right','up-right','down-left','up-left'), (True,True,True,True))
11 | self.ball_dir_mat = np.zeros(shape=(self.num_points,4), dtype='bool')
12 | for i,dir in enumerate(self.hand_craft_feats['ball_dir']):
13 | self.ball_dir_mat[i,int(dir)] = 1
14 |
15 | ##############################################
16 | # marking points along trajectories from a to b
17 |
18 | # self.cond_vector_mark_trajs = np.zeros(shape=(self.num_points,), dtype='int8')
19 | #
20 | # self.fig_mark_trajs = plt.figure('mark trajectories')
21 | # self.fig_mark_trajs.canvas.mpl_connect('pick_event', self.on_scatter_pick_mark_trajs)
22 | # self.ax_mark_trajs = self.fig_mark_trajs.add_subplot(111)
23 | # self.scat_mark_trajs = self.ax_mark_trajs.scatter(self.data_t[0],
24 | # self.data_t[1],
25 | # s = 5 * self.cond_vector_mark_trajs + np.ones(self.num_points)*self.pnt_size,
26 | # facecolor = self.V,
27 | # edgecolor='none',
28 | # picker=5)
29 |
30 | ###############################################
31 |
32 | def update_cond_vector(self):
33 |
34 | ball_x = np.asarray([row[0] for row in self.hand_craft_feats['ball_pos']])
35 | ball_y = np.asarray([row[1] for row in self.hand_craft_feats['ball_pos']])
36 | ball_dir = np.asarray(self.hand_craft_feats['ball_dir'])
37 | racket_x = np.asarray(self.hand_craft_feats['racket'])
38 | missing_bricks = np.asarray(self.hand_craft_feats['missing_bricks'])
39 | has_hole = np.asarray(self.hand_craft_feats['hole'])
40 |
41 | self.cond_vector = (ball_x >= self.ball_x_min) * (ball_x <= self.ball_x_max) * \
42 | (ball_y >= self.ball_y_min) * (ball_y <= self.ball_y_max) * \
43 | (racket_x >= self.racket_x_min) * (racket_x <= self.racket_x_max) * \
44 | (missing_bricks >= self.missing_bricks_min) * (missing_bricks <= self.missing_bricks_max)
45 |
46 | # ball dir
47 | dirs_mask = np.zeros_like(self.cond_vector)
48 | for i,val in enumerate(self.ball_dir_check_button.get_status()):
49 | if val:
50 | dirs_mask += self.ball_dir_mat[:,i]
51 | self.cond_vector = self.cond_vector * dirs_mask
52 |
53 | # has a hole
54 | if self.hole_check_button.get_status()[0] == 0: # filter out states with no-hole
55 | self.cond_vector = self.cond_vector * (has_hole)
56 | elif self.hole_check_button.get_status()[1] == 0: # filter out states with a hole
57 | self.cond_vector = self.cond_vector * (1-has_hole)
58 | self.cond_vector = self.cond_vector.astype(int)
59 |
60 | def on_scatter_pick_mark_trajs(self,event):
61 | if hasattr(event,'ind'):
62 | ind = event.ind[0]
63 | traj_ids = self.state_labels[:,6]
64 | times = self.state_labels[:,7]
65 | has_hole = self.state_labels[:,5]
66 |
67 | traj_id = traj_ids[ind]
68 | time = times[ind]
69 |
70 | # if current point is already marked then un-mark the entire trajectory
71 | if self.cond_vector_mark_trajs[ind] == 1:
72 | self.cond_vector_mark_trajs[np.nonzero(traj_ids==traj_id)] = 0
73 | else:
74 | # mark the entire point on the current trajectory from time until has_hole = 1
75 | cond = 1 * (traj_ids == traj_id) * (times >=time) * (has_hole == 0)
76 | self.cond_vector_mark_trajs[np.nonzero(cond)] = 1
77 | self.cond_vector_mark_trajs = self.cond_vector_mark_trajs.astype(int)
78 |
79 | sizes = 5 * self.cond_vector_mark_trajs + np.ones(self.num_points)*self.pnt_size
80 | self.scat_mark_trajs.set_array(self.cond_vector_mark_trajs)
81 | self.scat_mark_trajs.set_sizes(sizes)
82 | plt.pause(0.01)
--------------------------------------------------------------------------------
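The sliders registered above feed update_cond_vector, which intersects one boolean range mask per feature by elementwise multiplication. A toy numpy example of that pattern, with made-up positions and slider bounds:

import numpy as np

ball_x = np.array([10., 80., 150.])
ball_y = np.array([30., 100., 200.])
x_min, x_max, y_min, y_max = 0., 160., 50., 210.   # hypothetical slider values

cond = (ball_x >= x_min) * (ball_x <= x_max) * \
       (ball_y >= y_min) * (ball_y <= y_max)
print(cond.astype(int))   # [0 1 1] -- the first state falls outside the ball_y range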
/graying_the_box/hand_crafted_features/add_global_features.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | def add_buttons(self, global_feats):
4 |
5 | #############################
6 | # 3.1 global coloring buttons
7 | #############################
8 | self.COLORS = {}
9 | self.add_color_button([0.60, 0.95, 0.09, 0.02], 'value', global_feats['value'])
10 | self.add_color_button([0.70, 0.95, 0.09, 0.02], 'actions', global_feats['actions'])
11 | self.add_color_button([0.60, 0.92, 0.09, 0.02], 'rooms', global_feats['rooms'])
12 | self.add_color_button([0.60, 0.89, 0.09, 0.02], 'gauss_clust', global_feats['gauss_clust'])
13 | self.add_color_button([0.70, 0.89, 0.09, 0.02], 'TD', global_feats['TD'])
14 | self.add_color_button([0.60, 0.86, 0.09, 0.02], 'action repetition', global_feats['act_rep'])
15 | self.add_color_button([0.70, 0.86, 0.09, 0.02], 'reward', global_feats['reward'])
16 |
17 | for i in range((global_feats['single_clusters']).shape[0]):
18 | self.add_color_button([0.80 + 0.10 * math.floor(i / 4), 0.95 - (i % 4) * 0.03, 0.09, 0.02], 'cluster_' + str(i), (global_feats['single_clusters'])[i])
19 |
20 |
21 | self.SLIDER_FUNCS = []
22 | self.CHECK_BUTTONS = []
23 |
--------------------------------------------------------------------------------
/graying_the_box/hand_crafted_features/add_global_features.py~:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | def add_buttons(self, global_feats):
4 |
5 | #############################
6 | # 3.1 global coloring buttons
7 | #############################
8 | self.COLORS = {}
9 | self.add_color_button([0.60, 0.95, 0.09, 0.02], 'value', global_feats['value'])
10 | self.add_color_button([0.70, 0.95, 0.09, 0.02], 'actions', global_feats['actions'])
11 | self.add_color_button([0.60, 0.92, 0.09, 0.02], 'rooms', global_feats['rooms'])
12 | self.add_color_button([0.60, 0.89, 0.09, 0.02], 'gauss_clust', global_feats['gauss_clust'])
13 | self.add_color_button([0.70, 0.89, 0.09, 0.02], 'TD', global_feats['TD'])
14 | self.add_color_button([0.60, 0.86, 0.09, 0.02], 'action repetition', global_feats['act_rep'])
15 | self.add_color_button([0.70, 0.86, 0.09, 0.02], 'reward', global_feats['reward'])
16 |
17 | for i in range((global_feats['single_clusters']).shape[0]):
18 | self.add_color_button([0.80 + 0.10 * math.floor(i / 4), 0.95 - (i % 4) * 0.03, 0.09, 0.02], 'cluster_' + str(i), (global_feats['single_clusters'])[i])
19 |
20 |
21 | self.SLIDER_FUNCS = []
22 | self.CHECK_BUTTONS = []
23 |
--------------------------------------------------------------------------------
/graying_the_box/hand_crafted_features/add_pacman_buttons.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import math
4 |
5 | def add_game_buttons(self, top_y=0.69, x_left = 0.68):
6 | self.add_slider_button([x_left, top_y, 0.08, 0.01], 'player_x', 0, 160)
7 | self.add_slider_button([x_left, top_y-0.04, 0.08, 0.01], 'player_y', 0, 210)
8 | self.add_slider_button([x_left, top_y-0.08, 0.08, 0.01], 'bricks', 0, 121)
9 | self.add_slider_button([x_left, top_y-0.12, 0.08, 0.01], 'enemy_distance', 0, 160+210)
10 | self.add_slider_button([x_left, top_y-0.16, 0.08, 0.01], 'lives', 0, 3)
11 | self.add_check_button([x_left, top_y-0.22, 0.06, 0.03], 'ghost', ('no-ghost','ghost'), (True,True))
12 | self.add_check_button([x_left, top_y-0.26, 0.06, 0.03], 'box', ('no-box','box'), (True,True))
13 | self.add_check_button([x_left, top_y-0.35, 0.07, 0.08], 'player_dir', ('stand','right','left','bottom','top'), (True, True, True, True, True))
14 | self.player_dir_mat = np.zeros(shape=(self.num_points,5), dtype='bool')
15 | for i,dir in enumerate(self.hand_craft_feats['player_dir']):
16 | self.player_dir_mat[i,int(dir)] = 1
17 |
18 | def update_cond_vector(self):
19 | player_x = np.asarray([row[0] for row in self.hand_craft_feats['player_pos']])
20 | player_y = np.asarray([row[1] for row in self.hand_craft_feats['player_pos']])
21 | player_dir = np.asarray(self.hand_craft_feats['player_dir'])
22 | bricks = np.asarray(self.hand_craft_feats['bricks'])
23 | nb_lives = np.asarray(self.hand_craft_feats['lives'])
24 | ghost_mode = np.asarray(self.hand_craft_feats['ghost'])
25 | enemies_dist = np.asarray(self.hand_craft_feats['enemy_distance'])
26 | bonus_box = np.asarray(self.hand_craft_feats['box'])
27 |
28 | self.cond_vector = (player_x >= self.player_x_min) * (player_x <= self.player_x_max) * \
29 | (player_y >= self.player_y_min) * (player_y <= self.player_y_max) * \
30 | (bricks >= self.bricks_min) * (bricks <= self.bricks_max) * \
31 | (enemies_dist >= self.enemy_distance_min) * (enemies_dist <= self.enemy_distance_max) * \
32 | (nb_lives >= self.lives_min) * (nb_lives <= self.lives_max)
33 |
34 | # player dir
35 | dirs_mask = np.zeros_like(self.cond_vector)
36 | for i,val in enumerate(self.player_dir_check_button.get_status()):
37 | if val:
38 | dirs_mask += self.player_dir_mat[:,i]
39 |
40 | self.cond_vector = self.cond_vector * dirs_mask
41 |
42 | # ghost mode
43 | if self.ghost_check_button.get_status()[0] == 0: # filter out states with "no-ghost"
44 | self.cond_vector = self.cond_vector * ghost_mode
45 | elif self.ghost_check_button.get_status()[1] == 0: # filter out states with "ghost" mode
46 | self.cond_vector = self.cond_vector * (1-ghost_mode)
47 |
48 | # bonus box
49 | if self.box_check_button.get_status()[0] == 0: # filter out states with "no-box"
50 | self.cond_vector = self.cond_vector * bonus_box
51 | elif self.box_check_button.get_status()[1] == 0: # filter out states with "box"
52 | self.cond_vector = self.cond_vector * (1-bonus_box)
53 |
54 | self.cond_vector = self.cond_vector.astype(int)
55 |
56 | def value_colored_frame(self):
57 | # value colored frame
58 | fig5 = plt.figure('value colored pacman frame')
59 | ax_5 = fig5.add_subplot(111)
60 |
61 | v_mat = np.zeros((210,160))
62 | count_mat = np.zeros((210,160))
63 | player_x = self.state_labels[:,0]
64 | player_y = self.state_labels[:,1]
65 |
66 | for x,y,v in zip(player_x, player_y, self.V):
67 | if math.isnan(x) or math.isnan(y):
68 | continue
69 | v_mat[int(y),int(x)] += v
70 | count_mat[int(y),int(x)] += 1
71 |
72 | v_mat = v_mat / count_mat
73 |
74 | ax_5.imshow(v_mat, interpolation='spline36')
--------------------------------------------------------------------------------
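value_colored_frame above averages the network's value estimate per pixel by accumulating a sum image and a visit-count image and dividing them. A small sketch of the same idea with made-up positions and values; the errstate guard just keeps the NaNs of never-visited pixels from raising warnings:

import numpy as np

v_mat = np.zeros((210, 160))
count_mat = np.zeros((210, 160))

xs, ys, vs = [5, 5, 20], [7, 7, 30], [0.2, 0.4, 0.9]   # hypothetical player positions / values
for x, y, v in zip(xs, ys, vs):
    v_mat[y, x] += v
    count_mat[y, x] += 1

with np.errstate(invalid='ignore', divide='ignore'):
    mean_v = v_mat / count_mat        # NaN wherever the player never stood
print('%.2f %.2f' % (mean_v[7, 5], mean_v[30, 20]))   # 0.30 0.90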
/graying_the_box/hand_crafted_features/add_seaquest_buttons.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def add_game_buttons(self, top_y=0.69, x_left = 0.68):
4 | self.add_slider_button([x_left, top_y , 0.08, 0.01], 'shooter_x', 0, 160)
5 | self.add_slider_button([x_left, top_y-0.04 , 0.08, 0.01], 'shooter_y', 0, 210)
6 | self.add_slider_button([x_left, top_y-0.08 , 0.08, 0.01], 'oxygen', 0, 1)
7 | self.add_slider_button([x_left, top_y-0.12 , 0.08, 0.01], 'divers', 0, 6)
8 | self.add_slider_button([x_left, top_y-0.16 , 0.08, 0.01], 'taken_divers', 0, 3)
9 | self.add_slider_button([x_left, top_y-0.20 , 0.08, 0.01], 'enemies', 0, 8)
10 | self.add_slider_button([x_left, top_y-0.24 , 0.08, 0.01], 'lives', 0, 3)
11 | self.add_check_button([x_left, top_y-0.32 , 0.06, 0.04], 'shooter_dir', ('dont-care','down','up'), (True, True, True))
12 |
13 | self.shooter_dir_mat = np.zeros(shape=(self.num_points,3), dtype='bool')
14 |
15 | for i,dir in enumerate(self.hand_craft_feats['shooter_dir']):
16 | self.shooter_dir_mat[i,int(dir)] = 1
17 |
18 | def update_cond_vector(self):
19 | shooter_x = np.asarray([row[0] for row in self.hand_craft_feats['shooter_pos']])
20 | shooter_y = np.asarray([row[1] for row in self.hand_craft_feats['shooter_pos']])
21 | shooter_dir = self.hand_craft_feats['shooter_dir']
22 | oxygen = np.asarray(self.hand_craft_feats['oxygen'])
23 | nb_divers = np.asarray(self.hand_craft_feats['divers'])
24 | nb_taken_divers = np.asarray(self.hand_craft_feats['taken_divers'])
25 | nb_enemies = np.asarray(self.hand_craft_feats['enemies'])
26 | nb_lives = np.asarray(self.hand_craft_feats['lives'])
27 |
28 | self.cond_vector = (shooter_x >= self.shooter_x_min) * (shooter_x <= self.shooter_x_max) * \
29 | (shooter_y >= self.shooter_y_min) * (shooter_y <= self.shooter_y_max) * \
30 | (oxygen >= self.oxygen_min) * (oxygen <= self.oxygen_max) * \
31 | (nb_divers >= self.divers_min) * (nb_divers <= self.divers_max) * \
32 | (nb_taken_divers >= self.taken_divers_min) * (nb_taken_divers <= self.taken_divers_max) *\
33 | (nb_enemies >= self.enemies_min) * (nb_enemies <= self.enemies_max) * \
34 | (nb_lives >= self.lives_min) * (nb_lives <= self.lives_max)
35 |
36 | # shtr dir
37 | dirs_mask = np.zeros_like(self.cond_vector)
38 | for i,val in enumerate(self.shooter_dir_check_button.get_status()):
39 | if val:
40 | dirs_mask += self.shooter_dir_mat[:,i]
41 |
42 | self.cond_vector = self.cond_vector * dirs_mask
43 |
44 | self.cond_vector = self.cond_vector.astype(int)
45 |
--------------------------------------------------------------------------------
/graying_the_box/hand_crafted_features/label_states_breakout.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | def label_states(states, screens, termination_mat, debug_mode, num_lives):
5 |
6 | im_size = int(np.sqrt(states.shape[1]))
7 | states = np.reshape(states, (states.shape[0], im_size, im_size)).astype('int16')
8 |
9 | screens = np.reshape(np.transpose(screens), (3,210,160,-1))
10 |
11 | screens = np.transpose(screens,(3,1,2,0))
12 |
13 | # masks
14 | ball_mask = np.ones_like(screens[0])
15 | ball_mask[189:] = 0
16 | ball_mask[57:63] = 0
17 |
18 | ball_x_ = 80
19 | ball_y_ = 105
20 |
21 | td_mask = np.ones_like(screens[0])
22 | td_mask[189:] = 0
23 | td_mask[:25] = 0
24 |
25 | features = {
26 | 'ball_pos': [[-1,-1],[-1,-1]],
27 | 'ball_dir': [-1,-1],
28 | 'racket': [-1,-1],
29 | 'missing_bricks': [0,0],
30 | 'hole': [0,0],
31 | 'traj': [0,0],
32 | 'time': [0,0]
33 | }
34 |
35 | if debug_mode:
36 | fig1 = plt.figure('screens')
37 | ax1 = fig1.add_subplot(111)
38 | screen_plt = ax1.imshow(screens[0], interpolation='none')
39 | plt.ion()
40 | plt.show()
41 |
42 | traj_id = 0
43 | time = 0
44 | strike_counter = 0
45 | s_ = screens[1]
46 | for i,s in enumerate(screens[2:]):
47 |
48 | #0. TD
49 | tdiff = (s - s_) * td_mask
50 | s_ = s
51 |
52 | row_ind, col_ind = np.nonzero(tdiff[:,:,0])
53 | ball_y = np.mean(row_ind)
54 | ball_x = np.mean(col_ind)
55 |
56 | # #1. ball location
57 | red_ch = s[:,:,0]
58 | # is_red = 255 * (red_ch == 200)
59 | # ball_filtered = np.zeros_like(s)
60 | # ball_filtered[:,:,0] = is_red
61 | # ball_filtered = ball_mask * ball_filtered
62 | #
63 | # row_ind, col_ind = np.nonzero(ball_filtered[:,:,0])
64 | #
65 | # ball_y = np.mean(row_ind)
66 | # ball_x = np.mean(col_ind)
67 |
68 | #2. ball direction
69 | ball_dir = 0 * (ball_x >= ball_x_ and ball_y >= ball_y_) +\
70 | 1 * (ball_x >= ball_x_ and ball_y < ball_y_) +\
71 | 2 * (ball_x < ball_x_ and ball_y >= ball_y_) +\
72 | 3 * (ball_x < ball_x_ and ball_y < ball_y_)
73 |
74 | ball_x_ = ball_x
75 | ball_y_ = ball_y
76 |
77 | #3. racket position
78 | is_red = 255 * (red_ch[190,8:-8] == 200)
79 | racket_x = np.mean(np.nonzero(is_red)) + 8
80 |
81 | #4. number of bricks
82 | z = red_ch[57:92,8:-8].flatten()
83 | is_brick = np.sum(1*(z>0) + 0*(z==0))
84 |
85 | missing_bricks = (len(z) - is_brick)/40.
86 |
87 | #5. holes
88 | brick_strip = red_ch[57:92,8:-8]
89 | brick_row_sum = brick_strip.sum(axis=0)
90 | has_hole = np.any((brick_row_sum==0))
91 |
92 | #6. traj_id
93 | if termination_mat[i] > 0:
94 | strike_counter+=1
95 | if strike_counter%num_lives==0:
96 | traj_id += 1
97 | time = 0
98 | time += 1
99 |
100 | if debug_mode:
101 | screen_plt.set_data(s)
102 | buf_line = ('Example %d: ball pos (x,y): (%0.2f, %0.2f), ball direction: %d, racket pos: (%0.2f), number of missing bricks: %d, has a hole: %d, traj id: %d, time: %d, st_cnt: %d') % \
103 | (i, ball_x, ball_y, ball_dir, racket_x, missing_bricks, has_hole, traj_id, time, strike_counter)
104 | print buf_line
105 | plt.pause(0.001)
106 |
107 | # labels[i] = (ball_x, ball_y, ball_dir, racket_x, missing_bricks, has_hole, traj_id, time)
108 |
109 | features['ball_pos'].append([ball_x, ball_y])
110 | features['ball_dir'].append(ball_dir)
111 | features['racket'].append(racket_x)
112 | features['missing_bricks'].append(missing_bricks)
113 | features['hole'].append(has_hole)
114 | features['traj'].append(traj_id)
115 | features['time'].append(time)
116 | features['n_trajs'] = traj_id
117 |
118 | return features
--------------------------------------------------------------------------------
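The ball-direction code computed above packs the sign of the x/y motion into a single integer; indices 0..3 line up with the ('down-right','up-right','down-left','up-left') check buttons in add_breakout_buttons.py (rows grow downwards, so a smaller y means the ball is moving up). A toy check with made-up coordinates:

ball_x_, ball_y_ = 80, 105           # previous ball position
ball_x, ball_y = 85, 100             # current position: moving right and up

ball_dir = 0 * (ball_x >= ball_x_ and ball_y >= ball_y_) + \
           1 * (ball_x >= ball_x_ and ball_y < ball_y_) + \
           2 * (ball_x < ball_x_ and ball_y >= ball_y_) + \
           3 * (ball_x < ball_x_ and ball_y < ball_y_)
print(ball_dir)   # 1 -> 'up-right'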
/graying_the_box/hand_crafted_features/label_states_packman.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import scipy.signal
4 | import scipy.ndimage
5 |
6 | def label_states(states, screens, termination_mat, debug_mode, num_lives):
7 |
8 | screens = np.reshape(np.transpose(screens), (3,210,160,-1))
9 |
10 | screens = np.transpose(screens,(3,1,2,0))
11 |
12 | features = {
13 | 'player_pos': [[-1,-1],[-1,-1]],
14 | 'player_dir': [-1,-1],
15 | 'bricks': [-1,-1],
16 | 'lives': [3,3],
17 | 'sweets': [[1,1,1,1],[1,1,1,1]],
18 | 'ghost': [-1,-1],
19 | 'enemy_distance': [-1,-1],
20 | 'box': [0,0]
21 | }
22 |
23 | # player mask
24 | player_mask = np.ones((210,160))
25 | player_mask[91:96,78:82] = 0
26 |
27 | # number of bricks
28 | brown_frame = 1 * (screens[0,:,:,0]==162)
29 |
30 | small_brick_mask = np.asarray([[45,45, 45, 45, 45, 45],
31 | [45,162,162,162,162,45],
32 | [45,162,162,162,162,45],
33 | [45,45, 45, 45, 45, 45]])
34 |
35 | wide_brick_mask = np.asarray([[45,45, 45, 45, 45, 45, 45, 45, 45, 45],
36 | [45,162,162,162,162,162,162,162,162,45],
37 | [45,162,162,162,162,162,162,162,162,45],
38 | [45,45, 45, 45, 45, 45, 45, 45, 45, 45]])
39 |
40 | # sweets
41 | sweets_y = [22,142,22,142]
42 | sweets_x = [6,6,151,151]
43 |
44 | hist_depth = 16
45 | sweets_mat = np.zeros((hist_depth,4))
46 | sweets_vec_ = np.zeros((4,))
47 | bonus_times = [[] for i in range(4)]
48 |
49 | # left bottom brick mask = [139:149,5:11,0]
50 | # top left brick mask = [19:29,5:11,0]
51 | # top right brick mask = [19:29,149:155,0]
52 | # bottom right brick mask = [139:149,149:155,0]
53 | brick_color_1 = 45 * np.ones((10,6))
54 | brick_color_1[1:-1,1:-1] = 180
55 | brick_color_2 = 45 * np.ones((10,6))
56 | brick_color_2[1:-1,1:-1] = 149
57 | brick_color_3 = 45 * np.ones((10,6))
58 | brick_color_3[1:-1,1:-1] = 212
59 | brick_color_4 = 45 * np.ones((10,6))
60 | brick_color_4[1:-1,1:-1] = 232
61 | brick_color_5 = 45 * np.ones((10,6))
62 | brick_color_5[1:-1,1:-1] = 204
63 | #enemies mask
64 | enemies_mask = np.ones((210,160))
65 | enemies_mask[20:28,6:10] = 0
66 | enemies_mask[140:148,6:10] = 0
67 | enemies_mask[20:28,150:154] = 0
68 | enemies_mask[140:148,150:154] = 0
69 |
70 | # bonus box
71 | bonus_color_mask = 162 * np.ones((7,6))
72 | bonus_color_mask[1:-1,1:-1] = 210
73 |
74 | if debug_mode:
75 | fig1 = plt.figure('screens')
76 | ax1 = fig1.add_subplot(111)
77 | screen_plt = ax1.imshow(screens[0], interpolation='none')
78 | fig2 = plt.figure('enemies')
79 | ax2 = fig2.add_subplot(111)
80 | brown_plt = ax2.imshow(brown_frame)
81 | # fig3 = plt.figure('brick heat map')
82 | # ax3 = fig3.add_subplot(111)
83 | # bricks_plt = ax3.imshow(brown_frame, interpolation='none')
84 | plt.ion()
85 | plt.show()
86 |
87 | player_x_ = 79
88 | player_y_ = 133
89 |
90 | restarted_flag = 0
91 | ttime = 0
92 |
93 | for i,s in enumerate(screens[2:]):
94 | # for i,s in enumerate(screens[750:]):
95 |
96 | # 1. player location
97 | player_frame = (s[:,:,0]==210) * player_mask
98 | row_ind, col_ind = np.nonzero(player_frame)
99 | player_y = np.mean(row_ind)
100 | player_x = np.mean(col_ind)
101 |
102 | # 2. player direction
103 | dx = player_x - player_x_
104 | dy = player_y - player_y_
105 |
106 | player_dir = 0
107 | if abs(dx) >= abs(dy):
108 | if dx > 0:
109 | player_dir = 1
110 | elif dx<0:
111 | player_dir = 2
112 | else:
113 | if dy > 0:
114 | player_dir = 3
115 | elif dy < 0:
116 | player_dir = 4
117 |
118 | player_x_ = player_x
119 | player_y_ = player_y
120 |
121 | # 3. number of bricks to eat
122 | # approximated (fast)
123 | brown_frame = 1 * (s[:,:,0]==162)
124 | brown_sum = np.sum(brown_frame)
125 | nb_bricks_apprx = (brown_sum - 6042)/8.
126 | nb_bricks_apprx = np.maximum(nb_bricks_apprx,0)
127 |
128 | # exact (slow)
129 | # nb_bricks = 0
130 | # for r in range(160):
131 | # for c in range (150):
132 | # small_slice = s[r:r+4,c:c+6,0]
133 | # wide_slice = s[r:r+4,c:c+10,0]
134 | # if (small_slice == small_brick_mask).all():
135 | # nb_bricks += 1
136 | # elif (wide_slice == wide_brick_mask).all():
137 | # nb_bricks += 1
138 |
139 | # 4. number of lives
140 | lives_strip = s[184:189,:,0]
141 | _, nb_lives = scipy.ndimage.label(lives_strip)
142 |
143 | # 5. sweets
144 | # look for bottom left brick
145 | sweets_mat[i%hist_depth,0] = 1 * np.all(s[139:149,5:11,0] == brick_color_1) + \
146 | 1 * np.all(s[139:149,5:11,0] == brick_color_2) + \
147 | 1 * np.all(s[139:149,5:11,0] == brick_color_3) + \
148 | 1 * np.all(s[139:149,5:11,0] == brick_color_4) + \
149 | 1 * np.all(s[139:149,5:11,0] == brick_color_5)
150 |
151 |
152 | # look for top right brick
153 | sweets_mat[i%hist_depth,1] = 1 * np.all(s[19:29,149:155,0] == brick_color_1) + \
154 | 1 * np.all(s[19:29,149:155,0] == brick_color_2) + \
155 | 1 * np.all(s[19:29,149:155,0] == brick_color_3) + \
156 | 1 * np.all(s[19:29,149:155,0] == brick_color_4) + \
157 | 1 * np.all(s[19:29,149:155,0] == brick_color_5)
158 |
159 | # look for bottom right brick
160 | sweets_mat[i%hist_depth,2] = 1 * np.all(s[139:149,149:155,0] == brick_color_1) + \
161 | 1 * np.all(s[139:149,149:155,0] == brick_color_2) + \
162 | 1 * np.all(s[139:149,149:155,0] == brick_color_3) + \
163 | 1 * np.all(s[139:149,149:155,0] == brick_color_4) + \
164 | 1 * np.all(s[139:149,149:155,0] == brick_color_5)
165 |
166 | # look for top left brick
167 | sweets_mat[i%hist_depth,3] = 1 * np.all(s[19:29,5:11,0] == brick_color_1) + \
168 | 1 * np.all(s[19:29,5:11,0] == brick_color_2) + \
169 | 1 * np.all(s[19:29,5:11,0] == brick_color_3) + \
170 | 1 * np.all(s[19:29,5:11,0] == brick_color_4) + \
171 | 1 * np.all(s[19:29,5:11,0] == brick_color_5)
172 |
173 | sweets_vec = 1 * np.any(sweets_mat, axis=0)
174 |
175 | if np.all(sweets_vec_ == 1) and restarted_flag == 0:
176 | ttime = 0
177 | restarted_flag = 1
178 | if not np.all(sweets_vec_ == 1):
179 | restarted_flag = 0
180 | ttime += 1
181 |
182 | # 6. ghost mode
183 | ghost_mode = np.any(1 * (s[:,:,0] == 149) + 1 * (s[:,:,0] == 212))
184 |
185 | # 7. enemies map
186 | enemies_map = 1 * (s[:,:,0] == 180) + \
187 | 1 * (s[:,:,0] == 149) + \
188 | 1 * (s[:,:,0] == 212) + \
189 | 1 * (s[:,:,0] == 128) + \
190 | 1 * (s[:,:,0] == 232) + \
191 | 1 * (s[:,:,0] == 204)
192 | enemies_map = enemies_map * enemies_mask
193 |
194 | enemies_dilate_map = scipy.ndimage.binary_dilation(enemies_map, iterations=1)
195 | labeled_array, nb_enemies = scipy.ndimage.label(enemies_dilate_map)
196 |
197 | min_dist = 999
198 | for j in range(nb_enemies):
199 | enemy_j_rows, enemy_j_cols = np.nonzero(labeled_array==j+1)
200 | e_j_y = np.mean(enemy_j_rows)
201 | e_j_x = np.mean(enemy_j_cols)
202 | dist_j = abs(player_x - e_j_x) + abs(player_y - e_j_y)
203 | if dist_j < min_dist:
204 | min_dist = dist_j
205 |
206 | # 8. bonus box
207 | # has_box = np.any(s[91:96,78:82,0]==210)
208 | has_box = 1 * np.all(s[90:97,77:83,0] == bonus_color_mask)
209 |
210 | #6. bonus bricks option
211 | # bottom left brick
212 | if sweets_vec[0] == 0 and sweets_vec_[0] == 1:
213 | bonus_times[0].append(ttime)
214 | # top right brick
215 | elif sweets_vec[1] == 0 and sweets_vec_[1] == 1:
216 | bonus_times[1].append(ttime)
217 | # bottom right brick
218 | elif sweets_vec[2] == 0 and sweets_vec_[2] == 1:
219 | bonus_times[2].append(ttime)
220 | # top left brick
221 | elif sweets_vec[3] == 0 and sweets_vec_[3] == 1:
222 | bonus_times[3].append(ttime)
223 |
224 | sweets_vec_ = sweets_vec
225 |
226 | if debug_mode:
227 | screen_plt.set_data(s)
228 | brown_plt.set_data(brown_frame)
229 | # bricks_plt.set_data(brown_frame)
230 | buf_line = ('Example %d: Player (x,y): (%0.2f,%0.2f), Player dir: %d, number of bricks: %d, number of lives: %d, sweets_vec: %s, ghost mode: %d, minimal enemy distance: %0.2f, bonus box: %d, time: %d') %\
231 | (i, player_x, player_y, player_dir, nb_bricks_apprx, nb_lives, sweets_vec, ghost_mode, min_dist, has_box, ttime)
232 | print buf_line
233 | plt.pause(0.01)
234 |
235 | features['player_pos'].append([player_x, player_y])
236 | features['player_dir'].append(player_dir)
237 | features['bricks'].append(nb_bricks_apprx)
238 | features['sweets'].append(sweets_vec)
239 | features['ghost'].append(ghost_mode)
240 | features['lives'].append(nb_lives)
241 | features['enemy_distance'].append(min_dist)
242 | features['box'].append(has_box)
243 |
244 | # histogram of sweets collection times
245 | if 0:
246 | brick_bl_times = np.asarray(bonus_times[0])
247 | brick_tr_times = np.asarray(bonus_times[1])
248 | brick_br_times = np.asarray(bonus_times[2])
249 | brick_tl_times = np.asarray(bonus_times[3])
250 |
251 | h_bl = plt.hist(brick_bl_times, bins = 70, range=(0,800), normed=1, facecolor='black', alpha=0.75, label='bottom left')
252 | h_tr = plt.hist(brick_tr_times, bins = 70, range=(0,800), normed=1, facecolor='red', alpha=0.75, label='top right')
253 | h_br = plt.hist(brick_br_times, bins = 70, range=(0,800), normed=1, facecolor='green', alpha=0.75, label='bottom right')
254 | h_tl = plt.hist(brick_tl_times, bins = 70, range=(0,800), normed=1, facecolor='blue', alpha=0.75, label='top left')
255 |
256 | plt.legend(prop={'size':20})
257 | plt.tick_params(axis='both',which='both',left='off',right='off', labelsize=20)
258 | plt.locator_params(axis='x',nbins=4)
259 | plt.locator_params(axis='y',nbins=8)
260 |
261 | plt.show()
262 |
263 | return features
--------------------------------------------------------------------------------
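Both this file and label_states_seaquest.py lean on scipy.ndimage.label to count objects (lives, enemies, divers) as connected components of a binary image. A minimal self-contained example on a toy array:

import numpy as np
import scipy.ndimage

strip = np.array([[0, 1, 1, 0, 0, 1, 0, 1, 1, 1],
                  [0, 1, 1, 0, 0, 1, 0, 0, 1, 0]])
labeled, n_components = scipy.ndimage.label(strip)
print(n_components)   # 3 separate blobs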
/graying_the_box/hand_crafted_features/label_states_seaquest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import scipy.signal
4 | import scipy.ndimage
5 |
6 | def label_states(states, screens, termination_mat, debug_mode, num_lives):
7 |
8 | screens = np.reshape(np.transpose(screens), (3,210,160,-1))
9 |
10 | screens = np.transpose(screens,(3,1,2,0))
11 |
12 | features = {
13 | 'shooter_pos': [[-1,-1],[-1,-1]],
14 | 'shooter_dir': [-1,-1],
15 | 'racket': [-1,-1],
16 | 'oxygen': [0,0],
17 | 'divers': [0,0],
18 | 'taken_divers': [0,0],
19 | 'enemies': [0,0],
20 | 'lives': [3,3],
21 | }
22 |
23 | # shooter convolution
24 | yellow_frame = screens[0,:,:,0]==187
25 | shtr_mask = np.asarray([[1,1,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]])
26 | shooter_conv_map = scipy.signal.convolve2d(yellow_frame, shtr_mask, mode='same')
27 | shtr_y_, shtr_x_ = np.unravel_index(np.argmax(shooter_conv_map),(210,160))
28 |
29 | #divers
30 | divers_frame = np.ones_like(screens[0,:,:,0])
31 | divers_frame[:160,:] = 1 * (screens[0,:160,:,0]==66)
32 |
33 | if debug_mode:
34 | fig1 = plt.figure('screens')
35 | ax1 = fig1.add_subplot(111)
36 | screen_plt = ax1.imshow(screens[0], interpolation='none')
37 | fig2 = plt.figure('divers')
38 | ax2 = fig2.add_subplot(111)
39 | divers_plt = ax2.imshow(divers_frame)
40 | # fig3 = plt.figure('enemies conv')
41 | # ax3 = fig3.add_subplot(111)
42 | # diver_conv_plt = ax3.imshow(diver_conv_map, interpolation='none')
43 | plt.ion()
44 | plt.show()
45 |
46 | shtr_dir_ = 0
47 | shtr_dir__ = 0
48 |
49 | for i,s in enumerate(screens[2:]):
50 | # for i,s in enumerate(screens[13210:]):
51 |
52 | # 1. shooter location
53 | yellow_frame = s[:,:,0]==187
54 | shooter_conv_map = scipy.signal.convolve2d(yellow_frame, shtr_mask,mode='same')
55 | shtr_y, shtr_x = np.unravel_index(np.argmax(shooter_conv_map),(210,160))
56 |
57 | # if shtr_x==0 :
58 | # shtr_x = -1
59 | # if shtr_y==0:
60 | # shtr_y = -1
61 |
62 | # 2. shooter direction
63 | shtr_dir = 1 * (shtr_y >= shtr_y_) +\
64 | 2 * (shtr_y < shtr_y_)
65 |
66 | coherent = (shtr_dir == shtr_dir_) * (shtr_dir == shtr_dir__)
67 | shtr_dir__ = shtr_dir_
68 | shtr_dir_ = shtr_dir
69 |
70 | shtr_dir = shtr_dir * coherent
71 |
72 | shtr_x_ = shtr_x
73 | shtr_y_ = shtr_y
74 |
75 | # 3. oxygen
76 | oxgn_line = s[172,49:111,0]
77 | oxgn_lvl = np.sum(1*(oxgn_line==214))/float((111-49))
78 |
79 | # 4. available divers
80 | divers_frame = 0 * divers_frame
81 | divers_frame[:160,:] = 1 * (s[:160,:,0] == 66)
82 | erote_map = scipy.ndimage.binary_erosion(divers_frame, structure=np.ones((2,2)), iterations=1)
83 | diver_conv_map = scipy.ndimage.binary_dilation(erote_map, iterations=3)
84 | _, nb_divers = scipy.ndimage.label(diver_conv_map)
85 |
86 | # 5. taken divers
87 | # divers_frame = 0 * divers_frame
88 | # divers_frame[178:188,:] = 1 * (s[178:188,:,0] == 24)
89 | # diver_conv_map = scipy.ndimage.binary_dilation(divers_frame, iterations=1)
90 | # _, nb_taken_divers = scipy.ndimage.label(diver_conv_map)
91 | nb_taken_divers = np.sum(1 * s[178,:,0] == 24)
92 |
93 | # 6. enemies
94 | enemies_frame = 0 * divers_frame
95 | enemies_frame[:160,:] = 1 * (s[:160,:,0] == 92) + \
96 | 2 * (s[:160,:,0] == 160) + \
97 | 3 * (s[:160,:,0] == 170) + \
98 | 4 * (s[:160,:,0] == 198)
99 |
100 | enemies_conv_map = scipy.ndimage.binary_dilation(enemies_frame, iterations=2)
101 | _, nb_enemies = scipy.ndimage.label(enemies_conv_map)
102 |
103 | # lives
104 | lives_slice = 1 * (s[18:30,:,0] == 210)
105 | _, nb_lives = scipy.ndimage.label(lives_slice)
106 |
107 |
108 | if debug_mode:
109 | screen_plt.set_data(s)
110 | # divers_plt.set_data(diver_conv_map)
111 | # diver_conv_plt.set_data(diver_conv_map)
112 | buf_line = ('Example %d: shooter (x,y): (%0.2f, %0.2f), shooter dir: %d, oxygen level: %0.2f, divers: %d, taken divers: %d, enemies: %d, lives: %d') %\
113 | (i, shtr_x, shtr_y, shtr_dir, oxgn_lvl, nb_divers, nb_taken_divers, nb_enemies, nb_lives)
114 | print buf_line
115 | plt.pause(0.01)
116 |
117 | features['shooter_pos'].append([shtr_x, shtr_y])
118 | features['shooter_dir'].append(shtr_dir)
119 | features['oxygen'].append(oxgn_lvl)
120 | features['divers'].append(nb_divers)
121 | features['taken_divers'].append(nb_taken_divers)
122 | features['enemies'].append(nb_enemies)
123 | features['lives'].append(nb_lives)
124 |
125 | return features
--------------------------------------------------------------------------------
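label_states_seaquest.py above counts divers and enemies by cleaning a colour mask with binary erosion/dilation and then labelling connected components. A minimal sketch of that morphology-then-count pattern, run on a synthetic mask rather than a real Seaquest frame:

import numpy as np
import scipy.ndimage

# synthetic mask: two objects plus an isolated noise pixel (illustrative only)
mask = np.zeros((40, 40), dtype=bool)
mask[5:9, 5:12] = True        # object 1
mask[25:30, 20:28] = True     # object 2
mask[15, 35] = True           # single-pixel noise

# erosion removes specks smaller than the 2x2 structuring element;
# dilation restores (and slightly grows) the surviving blobs
eroded = scipy.ndimage.binary_erosion(mask, structure=np.ones((2, 2)), iterations=1)
dilated = scipy.ndimage.binary_dilation(eroded, iterations=3)

_, nb_objects = scipy.ndimage.label(dilated)
print(nb_objects)  # 2 -- the noise pixel no longer counts
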
/graying_the_box/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('hand_crafted_features')
4 |
5 | from prepare_data import prepare_data
6 | from vis_tool import VIS_TOOL
7 |
8 | # Parameters
9 | run_dir = '/home/deep5/DQN_Shahar_Chen_oldpc/dqn_distill/'
10 | num_frames = 100000
11 | game_id = 0 # 0-breakout, 1-seaquest, 2-pacman
12 | load_data = 0
13 | debug_mode = 0
14 | cluster_method = 0 # 0-kmeans, 1-spectral_clustering, 2-EMHC (entropy minimization hierarchical clustering)
15 | n_clusters = 4
16 | window_size = 2
17 | n_iters = 8
18 | entropy_iters = 0
19 |
20 | cluster_params = {
21 | 'method': cluster_method,
22 | 'n_clusters': n_clusters,
23 | 'window_size': window_size,
24 | 'n_iters': n_iters,
25 | 'entropy_iters': entropy_iters
26 | }
27 | global_feats, hand_crafted_feats = prepare_data(game_id, run_dir, num_frames, load_data, debug_mode)
28 |
29 | vis_tool = VIS_TOOL(global_feats=global_feats, hand_craft_feats=hand_crafted_feats, game_id=game_id, cluster_params=cluster_params)
30 |
31 | vis_tool.show()
32 |
--------------------------------------------------------------------------------
/graying_the_box/others/SaliencyScore.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('/home/tom/OpenBox/bhtsne/')
3 |
4 | import numpy as np
5 | import h5py
6 | import matplotlib.image as mpimg
7 |
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 |
11 | numframes = 30000
12 | Seaquestind = 352
13 | Breakoutind = 56
14 | Pacmanind = 1316
15 |
16 | im_size = 84
17 |
18 |
19 | print "loading states... "
20 |
21 | Breakout_state_file = h5py.File('/home/tom/OpenBox/tsne_res/6k/states.h5', 'r')
22 | Breakout_state_mat = Breakout_state_file['data']
23 | Breakout_states = Breakout_state_mat[:numframes]
24 | Breakout_states = np.reshape(Breakout_states, (Breakout_states.shape[0], im_size,im_size))
25 |
26 | Seaquest_state_file = h5py.File('/home/tom/OpenBox/tsne_res/seaquest/13k/states.h5', 'r')
27 | Seaquest_state_mat = Seaquest_state_file['data']
28 | Seaquest_states = Seaquest_state_mat[:numframes]
29 | Seaquest_states = np.reshape(Seaquest_states, (Seaquest_states.shape[0], im_size,im_size))
30 |
31 | Pacman_state_file = h5py.File('/home/tom/OpenBox/tsne_res/pacman/7k/states.h5', 'r')
32 | Pacman_state_mat = Pacman_state_file['data']
33 | Pacman_states = Pacman_state_mat[:numframes]
34 | Pacman_states = np.reshape(Pacman_states, (Pacman_states.shape[0], im_size,im_size))
35 | print "loading grads... "
36 | thresh = 0.1
37 |
38 | Breakout_grad_file = h5py.File('/home/tom/OpenBox/tsne_res/6k/grads.h5', 'r')
39 | Breakout_grad_mat = Breakout_grad_file['data']
40 | Breakout_grads = Breakout_grad_mat[:numframes]
41 | Breakout_grads[np.abs(Breakout_grads)<thresh] = 0
91 | Breakout_term = Breakout_tsne[:,Breakout_term>0]
92 | Seaquest_term = Seaquest_tsne[:,Seaquest_term>0]
93 | Pacman_term = Pacman_tsne[:,Pacman_term>0]
94 |
95 | axs.flat[0].annotate(
96 | 'Termination',
97 | xy = (Breakout_term[0,0], Breakout_term[1,0]), xytext = (-20, 20),size=8,
98 | textcoords = 'offset points', ha = 'right', va = 'bottom',
99 | bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),
100 | arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0'))
101 |
102 | axs.flat[1].annotate(
103 | 'Termination',
104 | xy = (Seaquest_term[0,0], Seaquest_term[1,0]), xytext = (-20, 20),size=8,
105 | textcoords = 'offset points', ha = 'right', va = 'bottom',
106 | bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),
107 | arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0'))
108 |
109 | axs.flat[2].annotate(
110 | 'Termination',
111 | xy = (Pacman_term[0,0], Pacman_term[1,0]), xytext = (-20, 20),size=8,
112 | textcoords = 'offset points', ha = 'right', va = 'bottom',
113 | bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),
114 | arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0'))
115 |
116 | plt.show()
117 |
--------------------------------------------------------------------------------
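SaliencyScore.py is only partially preserved in this dump; the gradient-loading block defines `thresh = 0.1` and, as far as the surviving lines show, masks out gradient entries whose magnitude falls below it before saliency is scored. A hedged sketch of just that masking step on random data, assuming the same in-place convention:

import numpy as np

thresh = 0.1
grads = np.random.randn(4, 84, 84) * 0.05   # toy gradient maps (illustrative)

# zero out entries whose magnitude falls below the threshold, in place
grads[np.abs(grads) < thresh] = 0
print(np.count_nonzero(grads), '/', grads.size, 'entries survive')
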
/graying_the_box/prepare_data.py:
--------------------------------------------------------------------------------
1 | import cPickle as pickle
2 | from prepare_global_features import prepare_global_features
3 |
4 | def prepare_data(game_id, run_dir, num_frames, load_data, debug_mode):
5 | # 1. switch games TEST
6 | if game_id == 0: #'breakout'
7 | num_actions = 10
8 | num_lives = 1
9 | data_dir = run_dir
10 | from label_states_breakout import label_states
11 | grad_thresh = 0.1
12 |
13 | elif game_id == 1: #'seaquest'
14 | num_actions = 18
15 | num_lives = 3
16 | data_dir = 'data/' + 'seaquest' + '/' + run_dir + '/'
17 | from label_states_seaquest import label_states
18 | grad_thresh = 0.05
19 |
20 | elif game_id == 2: #'pacman'
21 | num_actions = 5
22 | num_lives = 3
23 | data_dir = 'data/' + 'pacman' + '/' + run_dir + '/'
24 | from label_states_packman import label_states
25 | grad_thresh = 0.05
26 |
27 | # 2. load data
28 | if load_data:
29 | # 2.1 global features
30 | global_feats = pickle.load(file(data_dir + 'global_features.bin','rb'))
31 |
32 | # 2.2 hand craft features
33 | hand_craft_feats = pickle.load(file(data_dir + 'hand_craft_features.bin','rb'))
34 |
35 | # 3. prepare data
36 | else:
37 | # 3.1 global features
38 | global_feats = prepare_global_features(data_dir, num_frames, num_actions, num_lives, grad_thresh)
39 | # pickle.dump(global_feats,file(data_dir + 'global_features.bin','wb'))
40 |
41 | # 3.2 hand craft features
42 | hand_craft_feats = label_states(global_feats['states'], global_feats['screens'], global_feats['termination'], debug_mode=debug_mode, num_lives=num_lives)
43 | # pickle.dump(hand_craft_feats,file(data_dir + 'hand_craft_features.bin','wb'))
44 |
45 | return global_feats, hand_craft_feats
46 |
--------------------------------------------------------------------------------
/graying_the_box/prepare_global_features.py:
--------------------------------------------------------------------------------
1 | import common
2 | import numpy as np
3 | import h5py
4 |
5 | def prepare_global_features(data_dir, num_frames, num_actions, num_lives, grad_thresh):
6 |
7 | print "Preparing global features... "
8 |
9 | # load hdf5 files
10 | print "Loading from hdf5 files... "
11 | states_hd5f = common.load_hdf5('statesClean', data_dir, num_frames)
12 | screens_hdf5 = np.ones((3, states_hd5f.shape[0], states_hd5f.shape[1]))
13 | screens_hdf5[0, :, :] = states_hd5f[:, :]
14 | screens_hdf5[1, :, :] = states_hd5f[:, :]
15 | screens_hdf5[2, :, :] = states_hd5f[:, :]
16 | #termination_hdf5 = common.load_hdf5('terminationClean', data_dir, num_frames)
17 | lowd_activation_hdf5 = common.load_hdf5('lowd_activations_800', data_dir, num_frames)
18 | lowd_activation_3d_hdf5 = np.zeros((lowd_activation_hdf5.shape[0], 3)) # common.load_hdf5('lowd_activations3d', data_dir, num_frames)
19 | q_hdf5 = common.load_hdf5('qvalsClean', data_dir, num_frames)
20 | a_hdf5 = np.array(common.load_hdf5('actionsClean', data_dir, num_frames))
21 | reward_hdf5 = common.load_hdf5('rewardClean', data_dir, num_frames)
22 | with h5py.File('/home/deep5/DQN_Shahar_Chen_oldpc/dqn_distill/tsneMatch.h5', 'r') as hf:
23 | gaus_clust = (np.array(hf.get('data')))
24 | with h5py.File('/home/deep5/DQN_Shahar_Chen_oldpc/dqn_distill/tsneRooms.h5', 'r') as hf:
25 | rooms = (np.array(hf.get('data')))
26 | with h5py.File('/home/deep5/DQN_Shahar_Chen_oldpc/dqn_distill/tsneData.h5', 'r') as hf:
27 | single_clusters = (np.array(hf.get('data')))
28 |
29 | # grads_hdf5 = common.load_hdf5('grads', data_dir, num_frames)
30 | grads_hdf5 = states_hd5f # take state image instead of saliency image until we have saliency for all data
31 |
32 | termination_hdf5 = np.sign(np.array(reward_hdf5)) + 1
33 |
34 |
35 | # 9. get Q and action, translate to numpy
36 | V = np.zeros(shape=num_frames)
37 | Q = np.zeros(shape=(num_frames,num_actions))
38 | tsne = np.zeros(shape=(num_frames,2))
39 | a = np.zeros(shape=num_frames)
40 | term = np.zeros(shape=num_frames)
41 | reward = np.zeros(shape=num_frames)
42 | TD = np.zeros(shape=num_frames)
43 | tsne3d = np.zeros(shape=(num_frames,3))
44 | tsne3d_next = np.zeros(shape=(num_frames,3))
45 | time = np.zeros(shape=num_frames)
46 | act_rep = np.zeros(shape=num_frames)
47 | trajectory_index = np.zeros(shape=num_frames)
48 | tsne3d_norm = np.zeros(shape=num_frames)
49 |
50 | counter = 0
51 | term_counter = 0
52 | trajectory_counter = 1
53 |
54 | for i in range(1,num_frames-1):
55 | V[i] = q_hdf5[i]
56 | a[i] = float(a_hdf5[i])-1
57 | reward[i] = reward_hdf5[i]
58 | term[i] = termination_hdf5[i]
59 | Q[i] = q_hdf5[i]
60 | tsne[i] = lowd_activation_hdf5[i]
61 | tsne3d[i] = lowd_activation_3d_hdf5[i]
62 | tsne3d_next[i] = lowd_activation_3d_hdf5[i+1]-lowd_activation_3d_hdf5[i]
63 | tsne3d_norm[i] = np.linalg.norm(tsne3d_next[i])
64 | TD[i] = abs(Q[i-1,int(a[i-1])]-0.99*(Q[i,int(a[i])]+reward[i-1]))
65 |
66 | # calculate time and trajectory index
67 | time[i] = counter
68 | trajectory_index[i] = trajectory_counter
69 |
70 | if (term[i] > 0):#the fifth terminal
71 | term_counter += 1
72 | if term_counter % num_lives == 0:
73 | counter = 0
74 | trajectory_counter+=1
75 | else:
76 | counter += 1
77 | else:
78 | counter += 1
79 |
80 | Advantage = V-Q.T
81 | risk = np.sum(Advantage,axis=0)
82 | term_binary = term
83 | term_binary[np.nonzero(term_binary!=0)]=1
84 |
85 | global_feats = {
86 | 'tsne': tsne,
87 | 'states':states_hd5f,
88 | 'screens':screens_hdf5,
89 | 'value':V,
90 | 'actions':a,
91 | 'termination':term_binary,
92 | 'risk':risk,
93 | 'tsne3d':tsne3d,
94 | 'tsne3d_norm':tsne3d_norm,
95 | 'Advantage':Advantage,
96 | 'time':time,
97 | 'TD':TD,
98 | 'reward':reward,
99 | 'act_rep':act_rep,
100 | 'tsne3d_next':tsne3d_next,
101 | 'grads':grads_hdf5,
102 | 'trajectory_index':trajectory_index,
103 | 'data_dir':data_dir,
104 | 'gauss_clust':gaus_clust,
105 | 'rooms':rooms,
106 | 'single_clusters':single_clusters
107 | }
108 |
109 | return global_feats
110 |
--------------------------------------------------------------------------------
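The TD column built in prepare_global_features.py above is the absolute one-step Bellman residual between consecutive frames, computed exactly as line 64 is written (note that gamma multiplies the reward as well as the next Q-value there). A small numeric sketch of that line with made-up Q-values, actions and reward:

import numpy as np

gamma = 0.99
Q = np.array([[0.5, 1.2, 0.3],     # toy Q-values at frame i-1
              [0.6, 1.0, 0.4]])    # toy Q-values at frame i
a = np.array([1, 2])               # actions taken at frames i-1 and i
reward = np.array([1.0, 0.0])      # reward observed at frame i-1

# one-step TD error, mirroring TD[i] above
td = abs(Q[0, a[0]] - gamma * (Q[1, a[1]] + reward[0]))
print(td)   # |1.2 - 0.99*(0.4 + 1.0)| ~= 0.186
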
/graying_the_box/smdp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.linalg
3 | import scipy.stats
4 |
5 | def divide_tt(X, tt_ratio):
6 | N = X.shape[0]
7 | X_train = X[:int(tt_ratio*N)]
8 | X_test = X[int(tt_ratio*N):]
9 | return X_train, X_test
10 |
11 | class SMDP(object):
12 | def __init__(self, labels, termination, rewards, values, n_clusters, tb=0, gamma=0.99, trunc_th = 0.1, k=5):
13 |
14 | self.k = k
15 | self.gamma = gamma
16 | self.trunc_th = trunc_th
17 | self.rewards = rewards
18 | if tb == 0:
19 | self.labels,self.n_clusters = self.remove_empty_clusters(labels,n_clusters)
20 | else:
21 | self.labels = labels
22 | self.n_clusters = n_clusters
23 | self.termination = termination
24 | self.TT = self.calculate_transition_matrix()
25 | self.P = self.calculate_prob_transition_matrix(self.TT)
26 | if tb == 0:
27 | self.check_empty_P()
28 | self.r, self.skill_time = self.smdp_reward()
29 | self.v_smdp = self.calc_v_smdp(self.P)
30 | self.v_dqn = self.calc_v_dqn(values)
31 | self.score = self.value_score()
32 | self.clusters_count = self.count_clusters()
33 | self.entropy = self.calc_entropy(self.P)
34 | self.edges = self.get_smdp_edges()
35 |
36 | ####### Methods ######
37 |
38 | def check_empty_P(self):
39 | cluster_ind = 0
40 | for p in (self.P):
41 | if p.sum()==0:
42 | indices = np.nonzero(self.labels==cluster_ind)[0]
43 | for i in indices:
44 | self.labels[i]=self.labels[i-1]
45 | cluster_ind+=1
46 | self.labels, self.n_clusters = self.remove_empty_clusters(self.labels,self.n_clusters)
47 | self.TT = self.calculate_transition_matrix()
48 | self.P = self.calculate_prob_transition_matrix(self.TT)
49 |
50 | def count_clusters(self):
51 | cluster_count = np.zeros(self.n_clusters)
52 | for l in self.labels:
53 | cluster_count[l]+=1
54 | return cluster_count
55 |
56 | def remove_empty_clusters(self,labels,n_c):
57 | # remove empty clusters
58 | labels_i = np.copy(labels)
59 | cluster_flags = np.zeros(n_c,dtype=np.bool)
60 | cluster_count = np.zeros(n_c)
61 | for l in labels_i:
62 | cluster_flags[l] = True
63 | cluster_count[l]+=1
64 | shift_vec = np.cumsum(~cluster_flags)
65 | for ind,l in enumerate(labels_i):
66 | labels_i[ind] -= shift_vec[l]
67 | new_n_clusters = np.max(labels_i)+1
68 |
69 | return labels_i,new_n_clusters
70 |
71 | def calc_entropy(self,P):
72 | e = scipy.stats.entropy(P.T)
73 | e_finite_ind = np.isfinite(e)
74 | entropy = np.average(a=e[e_finite_ind],weights=self.clusters_count)
75 | return entropy
76 |
77 | def truncate(self, M, th):
78 | M[np.nonzero(M<th)] = 0
114 | if t>self.k:
115 | mean_rewards[l_p,0] += total_r #/ (t+1)
116 | mean_rewards[l_p,1] += 1
117 | mean_times[l_p,0] += t
118 | mean_times[l_p,1] += 1
119 | l_p = l
120 | total_r = 0
121 | t = 0
122 |
123 | for mr,mt in zip(mean_rewards, mean_times):
124 | mr[0] = mr[0] / mr[1]
125 | mt[0] = mt[0] / mt[1]
126 |
127 | return mean_rewards[:,0], mean_times[:,0]
128 |
129 | def calc_v_smdp(self,P):
130 | GAMMA = np.diag(self.gamma**self.skill_time)
131 | v = np.dot(np.linalg.pinv(np.eye(self.n_clusters)-np.dot(GAMMA,P)),self.r)
132 | return v
133 |
134 | def calc_v_policy(self,P):
135 | skill_time = np.zeros(shape=(self.n_clusters,1))
136 | r = np.zeros(shape=(self.n_clusters,1))
137 | for cluster_ind in xrange(self.n_clusters):
138 | cluster_policy_skills = np.nonzero(P[cluster_ind,:])
139 | cluster_skills = np.nonzero(self.P[cluster_ind,:])
140 | n_skills = len(cluster_policy_skills[0])
141 |
142 | for policy_skill_ind in xrange(n_skills):
143 | next_ind = cluster_policy_skills[0][policy_skill_ind]
144 | skill_ind = np.nonzero(cluster_skills[0]==next_ind)[0][0]
145 | r[cluster_ind] += P[cluster_ind,next_ind]*self.R_skills[cluster_ind][skill_ind,0]
146 | skill_time[cluster_ind] += P[cluster_ind,next_ind]*self.k_skills[cluster_ind][skill_ind,0]
147 | GAMMA = np.diag((self.gamma**skill_time)[:,0])
148 | v = np.dot(np.linalg.pinv(np.eye(self.n_clusters)-np.dot(GAMMA,P)),r)
149 | return v
150 |
151 | def calc_v_dqn(self, values):
152 |
153 | value_vec = np.zeros(shape=(self.n_clusters,2))
154 | for i,(l,v) in enumerate(zip(self.labels, values)):
155 | value_vec[l, 0] += v
156 | value_vec[l, 1] += 1
157 |
158 | # 1.5 normalize rewards
159 | for val in value_vec:
160 | val[0] = val[0]/val[1]
161 |
162 | return value_vec[:,0]
163 |
164 | def value_score(self):
165 | # v_dqn = (self.v_dqn-self.v_dqn.mean())/self.v_dqn.std()
166 | # v_smdp = (self.v_smdp-self.v_smdp.mean())/self.v_smdp.std()
167 | v_dqn = self.v_dqn
168 | v_smdp = self.v_smdp
169 | return np.linalg.norm(v_dqn-v_smdp)/np.linalg.norm(v_dqn)
170 |
171 | def policy_improvement(self):
172 | policy = []
173 | for cluster_ind in xrange(self.n_clusters):
174 | n_skills = len(self.skills[cluster_ind][0])
175 | val = np.zeros(shape=(n_skills,1))
176 | for skill_ind in xrange(n_skills):
177 | r = self.R_skills[cluster_ind][skill_ind,0]
178 | k = self.k_skills[cluster_ind][skill_ind,0]
179 | next_ind = self.skills[cluster_ind][0][skill_ind]
180 | val[skill_ind] = r+(self.gamma**k)*self.v_smdp[next_ind]
181 | policy.append((cluster_ind,self.skills[cluster_ind][0][np.argmax(val)]))
182 | return policy
183 |
184 | def get_smdp_edges(self):
185 | edges = []
186 | for i in xrange(self.n_clusters):
187 | for j in xrange(self.n_clusters):
188 | if self.P[i,j]>0:
189 | edges.append((i,j))
190 | return edges
191 |
192 | def evaluate_greedy_policy(self,l):
193 | PP = np.copy(self.P)
194 | for trans in self.greedy_policy:
195 | p = PP[trans[0],:]
196 | p_greedy = np.zeros_like(p)
197 | p_greedy[trans[1]]=1
198 | p = l*p+(1-l)*p_greedy
199 | PP[trans[0],:]=p
200 | self.v_greedy = self.calc_v_policy(PP)
201 |
202 | def create_skills_model(self,P):
203 | skills = []
204 | for cluster_ind in xrange(self.n_clusters):
205 | skills.append(np.nonzero(P[cluster_ind,:]))
206 | l_p = int(self.labels[0])
207 | total_r = 0
208 | t = 0
209 | R_skills = []
210 | k_skills = []
211 | for cluster_ind in xrange(len(skills)):
212 | R_skills.append(np.zeros(shape=(len(skills[cluster_ind][0]),2)))
213 | k_skills.append(np.zeros(shape=(len(skills[cluster_ind][0]),2)))
214 |
215 | rewards_clip = np.clip(self.rewards,-1,1)
216 | for i, (l, r) in enumerate(zip(self.labels[1:], rewards_clip[1:])):
217 | total_r += self.gamma**t * r
218 | if l == l_p:
219 | t += 1
220 | else:
221 | if self.P[l_p,l]>0:
222 | skill_index = np.flatnonzero(np.asarray(skills[l_p][0])==l)[0]
223 | R_skills[l_p][skill_index,0] += total_r #/ (t+1)
224 | R_skills[l_p][skill_index,1] += 1
225 | k_skills[l_p][skill_index,0] += t
226 | k_skills[l_p][skill_index,1] += 1
227 | l_p = int(l)
228 | total_r = 0
229 | t = 0
230 |
231 | for skill_r,skill_k in zip(R_skills, k_skills):
232 | skill_r[:,0] /= skill_r[:,1]
233 | skill_k[:,0] /= skill_k[:,1]
234 |
235 | return skills,R_skills,k_skills
236 |
237 | def calc_skill_indices(self):
238 | l_p = int(self.labels[0])
239 | current_skill = []
240 | skill_indices = [[] for i in range(self.n_clusters)]
241 | skill_list = [[] for i in range(self.n_clusters)]
242 | skill_time = [[] for i in range(self.n_clusters)]
243 |
244 | for i, l in enumerate(zip(self.labels[1:])):
245 | current_skill.append(i)
246 | if l[0] != l_p:
247 | if self.P[l_p,l[0]]>0:
248 | skill_index = np.nonzero(skill_list[l_p]==l[0])[0] #find skill index in list
249 | curr_length = len(current_skill)
250 | if curr_length > self.k:
251 | length = []
252 | length.append(curr_length)
253 |
254 | if len(skill_index) == 0: #if not found - append
255 | skill_list[l_p].append(l[0])
256 | skill_indices[l_p].append(current_skill)
257 | skill_time[l_p].append(length)
258 | else:
259 | skill_indices[l_p][skill_index].extend(current_skill)
260 | skill_time[l_p][skill_index].extend(length)
261 |
262 | l_p = l[0]
263 | current_skill = []
264 | return skill_indices,skill_list,skill_time
265 |
266 | def complete_smdp(self):
267 | # not needed in spatio-temporal
268 | self.skills, self.R_skills, self.k_skills = self.create_skills_model(self.P)
269 | self.greedy_policy = self.policy_improvement()
270 | self.v_greedy = self.evaluate_greedy_policy(0)
271 | self.skill_indices, self.skill_list,self.skill_time = self.calc_skill_indices()
272 |
--------------------------------------------------------------------------------
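calc_v_smdp above solves the SMDP Bellman equation in closed form: with P the cluster transition matrix, r the mean discounted reward collected inside each cluster and tau the mean skill duration, the cluster values satisfy v = (I - diag(gamma^tau) P)^(-1) r. A tiny worked example on a made-up 2-cluster chain:

import numpy as np

gamma = 0.99
P = np.array([[0.0, 1.0],          # toy 2-cluster transition matrix
              [1.0, 0.0]])
r = np.array([1.0, 0.0])           # mean discounted reward per cluster
skill_time = np.array([5.0, 3.0])  # mean number of frames spent per cluster

# v = (I - diag(gamma**tau) P)^(-1) r, as in calc_v_smdp
GAMMA = np.diag(gamma ** skill_time)
v = np.dot(np.linalg.pinv(np.eye(2) - np.dot(GAMMA, P)), r)
print(v)
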
/graying_the_box/vis_tool.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from matplotlib.widgets import Button, Slider, CheckButtons
4 | from add_global_features import add_buttons as add_global_buttons
5 | from control_buttons import add_buttons as add_control_buttons
6 | import common
7 |
8 | class VIS_TOOL(object):
9 |
10 | def __init__(self, global_feats, hand_craft_feats, game_id, cluster_params):
11 |
12 | # 0. connect arguments
13 | self.global_feats = global_feats
14 | self.game_id = game_id
15 | self.hand_craft_feats = hand_craft_feats
16 | self.num_points = global_feats['tsne'].shape[0]
17 | self.data_t = global_feats['tsne'].T
18 | screens = np.reshape(np.transpose(global_feats['screens']), (3,210,160,-1))
19 | self.screens = np.transpose(screens,(3,1,2,0))
20 | self.im_size = np.sqrt(global_feats['states'].shape[1])
21 | self.states = np.reshape(global_feats['states'], (global_feats['states'].shape[0], self.im_size,self.im_size))
22 | self.tsne3d_next = global_feats['tsne3d_next'].T
23 | self.tsnedata = global_feats['tsne3d'].T
24 | self.traj_index = global_feats['trajectory_index']
25 | self.tsne3d_norm = global_feats['tsne3d_norm']
26 | tmp = global_feats['value'] - np.amin(np.array(global_feats['value']))
27 | self.color = tmp / np.amax(tmp)
28 | self.cluster_params = cluster_params
29 |
30 | # 1. Constants
31 | self.pnt_size = 2
32 | self.ind = 0
33 | self.prev_ind = 0
34 |
35 | # 2. Plots
36 | self.fig = plt.figure('tSNE')
37 |
38 | # 2.1 t-SNE
39 | self.ax_tsne = plt.subplot2grid((3,5),(0,0), rowspan=3, colspan=3)
40 |
41 | self.tsne_scat = self.ax_tsne.scatter(self.data_t[0], self.data_t[1], s= np.ones(self.num_points)*self.pnt_size,c = self.color,edgecolor='none',picker=5)
42 |
43 | self.ax_tsne.set_xticklabels([])
44 | self.ax_tsne.set_yticklabels([])
45 |
46 | # 2.1.5 colorbar
47 | cb_axes = self.fig.add_axes([0.253,0.13,0.2,0.01])
48 | cbar = self.fig.colorbar(self.tsne_scat, cax=cb_axes, orientation='horizontal', ticks=[0,1])
49 | cbar.ax.set_xticklabels(['Low','High'])
50 |
51 | # 2.2 game screen (state)
52 | self.ax_screen = plt.subplot2grid((3,5),(2,3), rowspan=1, colspan=1)
53 |
54 | self.screenplot = self.ax_screen.imshow(self.screens[self.ind], interpolation='none')
55 |
56 | self.ax_screen.set_xticklabels([])
57 | self.ax_screen.set_yticklabels([])
58 |
59 | # 2.3 gradient image (saliency map)
60 | self.ax_state = plt.subplot2grid((3,5),(2,4), rowspan=1, colspan=1)
61 |
62 | self.stateplot = self.ax_state.imshow(self.states[self.ind], interpolation='none', cmap='gray',picker=5)
63 |
64 | self.ax_state.set_xticklabels([])
65 | self.ax_state.set_yticklabels([])
66 |
67 | # 3. Global Features
68 | add_global_buttons(self, global_feats)
69 |
70 | # 4. Control buttons
71 | add_control_buttons(self)
72 |
73 | # 4.1 add game buttons
74 | if game_id == 0: # breakout
75 | from add_breakout_buttons import add_game_buttons, update_cond_vector
76 | if game_id == 1: # seaquest
77 | from add_seaquest_buttons import add_game_buttons, update_cond_vector
78 | if game_id == 2: # pacman
79 | from add_pacman_buttons import add_game_buttons, update_cond_vector
80 |
81 | add_game_buttons(self)
82 |
83 | self.update_cond_vector = update_cond_vector
84 |
85 | def add_color_button(self, pos, name, color):
86 | def set_color(event):
87 | self.color = self.COLORS[id(event.inaxes)]
88 | self.tsne_scat.set_array(self.color)
89 |
90 | ax = plt.axes(pos)
91 | setattr(self, name+'_button', Button(ax, name))
92 | getattr(self, name+'_button').on_clicked(set_color)
93 | color = np.array(color) - np.amin(np.array(color))
94 | self.COLORS[id(ax)] = color/np.amax(color)
95 |
96 | def update_sliders(self, val):
97 | for f in self.SLIDER_FUNCS:
98 | f()
99 | # self.update_cond_vector_breakout()
100 |
101 | def add_slider_button(self, pos, name, v_min, v_max):
102 |
103 | def update_slider(self, name, slider):
104 | def f():
105 | setattr(self, name, slider.val)
106 | return f
107 |
108 | ax_min = plt.axes(pos)
109 | ax_max = plt.axes([pos[0], pos[1]-0.02, pos[2], pos[3]])
110 |
111 | slider_min = Slider(ax_min, name+'_min', valmin=v_min, valmax=v_max, valinit=v_min)
112 | slider_max = Slider(ax_max, name+'_max', valmin=v_min, valmax=v_max, valinit=v_max)
113 |
114 | self.SLIDER_FUNCS.append(update_slider(self, name+'_min', slider_min))
115 | self.SLIDER_FUNCS.append(update_slider(self, name+'_max', slider_max))
116 |
117 | slider_min.on_changed(self.update_sliders)
118 | slider_max.on_changed(self.update_sliders)
119 |
120 | def add_check_button(self, pos, name, options, init_vals):
121 | def set_options(label):
122 | pass
123 |
124 | ax = plt.axes(pos)
125 | setattr(self, name+'_check_button', CheckButtons(ax, options, init_vals))
126 | getattr(self, name+'_check_button').on_clicked(set_options)
127 |
128 | def on_scatter_pick(self,event):
129 | self.ind = event.ind[0]
130 | self.update_plot()
131 | self.prev_ind = self.ind
132 |
133 | def update_plot(self):
134 | self.screenplot.set_array(self.screens[self.ind])
135 | self.stateplot.set_array(self.states[self.ind])
136 | sizes = self.tsne_scat.get_sizes()
137 | sizes[self.ind] = 250
138 | sizes[self.prev_ind] = self.pnt_size
139 | self.tsne_scat.set_sizes(sizes)
140 | self.fig.canvas.draw()
141 | print 'chosen point: %d' % self.ind
142 |
143 | def show(self):
144 | plt.show(block=True)
145 |
--------------------------------------------------------------------------------
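on_scatter_pick in vis_tool.py above relies on the scatter plot being created with picker=5 and wired to Matplotlib's pick events; the mpl_connect call itself is not shown in this excerpt (it presumably lives in add_control_buttons). A minimal, standalone sketch of that wiring, independent of the repository:

import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
pts = np.random.rand(100, 2)
scat = ax.scatter(pts[:, 0], pts[:, 1], s=10, picker=5)  # picker=5: 5-point pick tolerance

def on_pick(event):
    # event.ind lists the indices of all points within the pick tolerance
    ind = event.ind[0]
    print('chosen point: %d' % ind)

fig.canvas.mpl_connect('pick_event', on_pick)
plt.show()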