├── DQN-HDRLN
│   ├── NeuralQLearner.lua
│   ├── Rectifier.lua
│   ├── Scale.lua
│   ├── TransitionTable.lua
│   ├── convnet.lua
│   ├── convnet_atari3.lua
│   ├── convnet_atari_main.lua
│   ├── initenv.lua
│   ├── net_downsample_2x_full_y.lua
│   ├── nnutils.lua
│   ├── run_gpu_dist
│   ├── test_agent.lua
│   ├── test_dist_agent.lua
│   ├── test_gpu_dist
│   └── train_agent.lua
├── Distillation process
│   ├── NeuralQLearner.lua
│   ├── README.md
│   ├── TransitionTable.lua
│   ├── convnet.lua
│   ├── convnet_atari3.lua
│   ├── convnet_atari_main.lua
│   ├── distill_agent.lua
│   ├── distill_gpu
│   ├── initenv.lua
│   └── test_distill
├── LICENSE
├── README.md
├── Utilities
│   ├── README.md
│   ├── activations_hdf5_to_tsne.lua
│   ├── clean_data.py
│   ├── cluster.py
│   └── learn_weights.py
└── graying_the_box
    ├── .gitignore
    ├── .idea
    │   ├── .name
    │   ├── encodings.xml
    │   ├── graying_the_box.iml
    │   ├── misc.xml
    │   ├── modules.xml
    │   ├── vcs.xml
    │   └── workspace.xml
    ├── LUA
    │   ├── dqn
    │   │   ├── NeuralQLearner.lua
    │   │   ├── Rectifier.lua
    │   │   ├── Scale.lua
    │   │   ├── TransitionTable.lua
    │   │   ├── convnet.lua
    │   │   ├── convnet_atari3.lua
    │   │   ├── initenv.lua
    │   │   ├── net_downsample_2x_full_y.lua
    │   │   ├── nnutils.lua
    │   │   ├── test_agent.lua
    │   │   ├── train_agent.lua
    │   │   └── train_agent_tmp.lua
    │   ├── logs
    │   │   └── Results
    │   │       ├── Plot.lua
    │   │       ├── Plot2.lua
    │   │       ├── Plot3.lua
    │   │       ├── Process.sh
    │   │       ├── Process2.sh
    │   │       ├── Process3.sh
    │   │       └── tmp
    │   │           ├── TD_Error.png
    │   │           ├── conv1.png
    │   │           ├── conv1_max.png
    │   │           ├── conv2.png
    │   │           ├── conv2_max.png
    │   │           ├── conv3.png
    │   │           ├── conv3_max.png
    │   │           ├── lin1.png
    │   │           ├── lin1_max.png
    │   │           ├── lin2.png
    │   │           ├── lin2_max.png
    │   │           ├── reward.png
    │   │           └── vavg.png
    │   ├── roms
    │   │   └── README
    │   └── run_gpu
    ├── bhtsne
    │   ├── Makefile.win
    │   ├── parse_lua_tensor.py
    │   ├── sptree.cpp
    │   ├── sptree.h
    │   ├── tsne.cpp
    │   ├── tsne.h
    │   ├── vptree.h
    │   └── write_movie.py
    ├── clustering.py
    ├── common.py
    ├── digraph.py
    ├── emhc.py
    ├── hand_crafted_features
    │   ├── add_breakout_buttons.py
    │   ├── add_global_features.py
    │   ├── add_global_features.py~
    │   ├── add_pacman_buttons.py
    │   ├── add_seaquest_buttons.py
    │   ├── control_buttons.py
    │   ├── label_states_breakout.py
    │   ├── label_states_packman.py
    │   └── label_states_seaquest.py
    ├── main.py
    ├── others
    │   ├── SaliencyScore.py
    │   ├── Score_trans.py
    │   ├── Seaquest_SaliencyDiver.py
    │   └── TermFig.py
    ├── prepare_data.py
    ├── prepare_global_features.py
    ├── smdp.py
    └── vis_tool.py
/DQN-HDRLN/Rectifier.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | --[[ Rectified Linear Unit. 8 | 9 | The output is max(0, input). 10 | --]] 11 | 12 | local Rectifier, parent = torch.class('nn.Rectifier', 'nn.Module') 13 | 14 | -- This module accepts minibatches 15 | function Rectifier:updateOutput(input) 16 | return self.output:resizeAs(input):copy(input):abs():add(input):div(2) 17 | end 18 | 19 | function Rectifier:updateGradInput(input, gradOutput) 20 | self.gradInput:resizeAs(self.output) 21 | return self.gradInput:sign(self.output):cmul(gradOutput) 22 | end -------------------------------------------------------------------------------- /DQN-HDRLN/Scale.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license.
5 | ]] 6 | 7 | require "nn" 8 | require "image" 9 | 10 | local scale = torch.class('nn.Scale', 'nn.Module') 11 | 12 | 13 | function scale:__init(height, width) 14 | self.height = height 15 | self.width = width 16 | end 17 | 18 | function scale:forward(x) 19 | local x = x 20 | if x:dim() > 3 then 21 | x = x[1] 22 | end 23 | 24 | x = image.rgb2y(x) 25 | x = image.scale(x, self.width, self.height, 'bilinear') 26 | return x 27 | end 28 | 29 | function scale:updateOutput(input) 30 | return self:forward(input) 31 | end 32 | 33 | function scale:float() 34 | end 35 | -------------------------------------------------------------------------------- /DQN-HDRLN/convnet.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "initenv" 8 | 9 | function create_network(args) 10 | 11 | local net = nn.Sequential() 12 | net:add(nn.Reshape(unpack(args.input_dims))) 13 | 14 | --- first convolutional layer 15 | local convLayer = nn.SpatialConvolution 16 | 17 | net:add(convLayer(args.hist_len*args.ncols, args.n_units[1], 18 | args.filter_size[1], args.filter_size[1], 19 | args.filter_stride[1], args.filter_stride[1],1)) 20 | net:add(args.nl()) 21 | 22 | -- Add convolutional layers 23 | for i=1,(#args.n_units-1) do 24 | -- second convolutional layer 25 | net:add(convLayer(args.n_units[i], args.n_units[i+1], 26 | args.filter_size[i+1], args.filter_size[i+1], 27 | args.filter_stride[i+1], args.filter_stride[i+1])) 28 | net:add(args.nl()) 29 | end 30 | 31 | local nel 32 | if args.gpu >= 0 then 33 | nel = net:cuda():forward(torch.zeros(1,unpack(args.input_dims)) 34 | :cuda()):nElement() 35 | else 36 | nel = net:forward(torch.zeros(1,unpack(args.input_dims))):nElement() 37 | end 38 | 39 | -- reshape all feature planes into a vector per example 40 | net:add(nn.Reshape(nel)) 41 | 42 | -- fully connected layer 43 | net:add(nn.Linear(nel, args.n_hid[1])) 44 | net:add(args.nl()) 45 | local last_layer_size = args.n_hid[1] 46 | 47 | for i=1,(#args.n_hid-1) do 48 | -- add Linear layer 49 | last_layer_size = args.n_hid[i+1] 50 | net:add(nn.Linear(args.n_hid[i], last_layer_size)) 51 | net:add(args.nl()) 52 | end 53 | 54 | if args.distilled_network == true and false then 55 | -- add the last fully connected layer (to actions) 56 | concat = nn.Concat(2) 57 | 58 | 59 | local MCgameActions = {1,3,4,5,0,6,7,8,9} --,5,6,7,8} -- this is our game actions table 60 | local MCgameActions_primitive = {1,3,4,0,5} --,5} -- this is our game actions table 61 | local optionsActions = {6,7,8,9} -- these actions are correlated to an OPTION, 20 = solve room (make this struct with max iterations per option and socket port and ip) 62 | 63 | local navigateActions = {1,3,4} 64 | local pickupActions = {1,3,4} 65 | local breakActions = {1,3,4,5} 66 | local placeActions = {1,3,4,0} 67 | 68 | 69 | --args.skills = {navigateActions, pickupActions, breakActions, placeActions} 70 | 71 | for i=1,#args.skills do 72 | print('Added new skill layer '..i..' with '..(args.skills[i])..' 
actions') 73 | skill = nn.Sequential() 74 | skill:add(nn.Linear(last_layer_size, (args.skills[i]))) 75 | concat:add(skill) 76 | end 77 | 78 | net:add(concat) 79 | else 80 | -- add the last fully connected layer (to actions) 81 | net:add(nn.Linear(last_layer_size, args.n_actions)) 82 | end 83 | 84 | if args.gpu >=0 then 85 | net:cuda() 86 | end 87 | if args.verbose >= 2 then 88 | print(net) 89 | print('Convolutional layers flattened output size:', nel) 90 | end 91 | return net 92 | end 93 | -------------------------------------------------------------------------------- /DQN-HDRLN/convnet_atari3.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'convnet' 8 | 9 | return function(args) 10 | args.n_units = {32, 64, 64} 11 | args.filter_size = {8, 4, 3} 12 | args.filter_stride = {4, 2, 1} 13 | args.n_hid = {512} 14 | args.nl = nn.Rectifier 15 | args.distilled_network = true 16 | args.skills = {3,3,4,4} 17 | return create_network(args) 18 | end 19 | 20 | -------------------------------------------------------------------------------- /DQN-HDRLN/convnet_atari_main.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'convnet' 8 | 9 | return function(args) 10 | args.n_units = {32, 64, 64} 11 | args.filter_size = {8, 4, 3} 12 | args.filter_stride = {4, 2, 1} 13 | args.n_hid = {512} 14 | args.nl = nn.Rectifier 15 | return create_network(args) 16 | end 17 | -------------------------------------------------------------------------------- /DQN-HDRLN/initenv.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | dqn = {} 7 | 8 | require 'torch' 9 | require 'nn' 10 | require 'nngraph' 11 | require 'nnutils' 12 | require 'image' 13 | require 'Scale' 14 | require 'NeuralQLearner' 15 | require 'TransitionTable' 16 | require 'Rectifier' 17 | 18 | 19 | function torchSetup(_opt) 20 | _opt = _opt or {} 21 | local opt = table.copy(_opt) 22 | assert(opt) 23 | 24 | -- preprocess options: 25 | --- convert options strings to tables 26 | if opt.pool_frms then 27 | opt.pool_frms = str_to_table(opt.pool_frms) 28 | end 29 | if opt.env_params then 30 | opt.env_params = str_to_table(opt.env_params) 31 | end 32 | if opt.agent_params then 33 | opt.agent_params = str_to_table(opt.agent_params) 34 | opt.agent_params.gpu = opt.gpu 35 | opt.agent_params.best = opt.best 36 | opt.agent_params.verbose = opt.verbose 37 | if opt.network ~= '' then 38 | opt.agent_params.network = opt.network 39 | end 40 | end 41 | 42 | --- general setup 43 | opt.tensorType = opt.tensorType or 'torch.FloatTensor' 44 | torch.setdefaulttensortype(opt.tensorType) 45 | if not opt.threads then 46 | opt.threads = 4 47 | end 48 | torch.setnumthreads(opt.threads) 49 | if not opt.verbose then 50 | opt.verbose = 10 51 | end 52 | if opt.verbose >= 1 then 53 | print('Torch Threads:', torch.getnumthreads()) 54 | end 55 | 56 | --- set gpu device 57 | if opt.gpu and opt.gpu >= 0 then 58 | require 'cutorch' 59 | require 'cunn' 60 | if opt.gpu == 0 then 61 | local gpu_id = tonumber(os.getenv('GPU_ID')) 62 | if gpu_id then opt.gpu = gpu_id+1 end 63 | end 64 | if opt.gpu > 0 then cutorch.setDevice(opt.gpu) end 65 | opt.gpu = cutorch.getDevice() 66 | print('Using GPU device id:', opt.gpu-1) 67 | else 68 | opt.gpu = -1 69 | if opt.verbose >= 1 then 70 | print('Using CPU code only. GPU device id:', opt.gpu) 71 | end 72 | end 73 | 74 | --- set up random number generators 75 | -- removing lua RNG; seeding torch RNG with opt.seed and setting cutorch 76 | -- RNG seed to the first uniform random int32 from the previous RNG; 77 | -- this is preferred because using the same seed for both generators 78 | -- may introduce correlations; we assume that both torch RNGs ensure 79 | -- adequate dispersion for different seeds. 
80 | math.random = nil 81 | opt.seed = opt.seed or 1 82 | torch.manualSeed(opt.seed) 83 | if opt.verbose >= 1 then 84 | print('Torch Seed:', torch.initialSeed()) 85 | end 86 | local firstRandInt = torch.random() 87 | if opt.gpu >= 0 then 88 | cutorch.manualSeed(firstRandInt) 89 | if opt.verbose >= 1 then 90 | print('CUTorch Seed:', cutorch.initialSeed()) 91 | end 92 | end 93 | 94 | return opt 95 | end 96 | 97 | 98 | function setup(_opt) 99 | assert(_opt) 100 | 101 | --preprocess options: 102 | --- convert options strings to tables 103 | _opt.pool_frms = str_to_table(_opt.pool_frms) 104 | _opt.env_params = str_to_table(_opt.env_params) 105 | _opt.agent_params = str_to_table(_opt.agent_params) 106 | _opt.skill_agent_params = str_to_table(_opt.skill_agent_params) 107 | if _opt.agent_params.transition_params then 108 | _opt.skill_agent_params.transition_params = 109 | str_to_table(_opt.agent_params.transition_params) 110 | 111 | _opt.agent_params.transition_params = 112 | str_to_table(_opt.agent_params.transition_params) 113 | end 114 | 115 | --- first things first 116 | local opt = torchSetup(_opt) 117 | 118 | local distilled_hdrln = _opt.distilled_hdrln 119 | if distilled_hdrln == "true" then 120 | distilled_hdrln = true 121 | else 122 | distilled_hdrln = false 123 | end 124 | local supervised_skills = _opt.supervised_skills 125 | if supervised_skills == "true" then 126 | supervised_skills = true 127 | else 128 | supervised_skills = false 129 | end 130 | 131 | local num_skills = tonumber(_opt.num_skills) 132 | if args.supervised_file then 133 | local myFile = hdf5.open(args.supervised_file, 'r') 134 | num_skills = myFile:read('numberSkills'):all() 135 | myFile:close() 136 | end 137 | local gameEnv = nil 138 | 139 | local MCgameActions_primitive = {1,3,4,0,5} -- this is our game actions table 140 | local MCgameActions = MCgameActions_primitive:copy() 141 | local options = {} -- these actions are correlated to an OPTION, i.e an action the HDRLN selects that is mapped to a skill 142 | local optionsActions = {} -- for each skill we map availiable actions. 
Currently each skill maps all aviliable actions (doesn't have to be this way) 143 | 144 | local max_action_val = MCgameActions_primitive[1] 145 | for i = 2, #MCgameActions_primitive 146 | do 147 | max_action_val = max(max_action_val, MCgameActions_primitive[i]) 148 | end 149 | 150 | for i = 1, num_skills 151 | do 152 | MCgameActions[#MCgameActions + 1] = i + max_action_val 153 | options[#options + 1] = i + max_action_val -- we want all actions mapped to skills to be larger than the maximal primitive action value 154 | optionsActions[#optionsActions + 1] = MCgameActions_primitive -- map all primitive actions for each skill 155 | end 156 | 157 | 158 | -- agent options 159 | _opt.agent_params.actions = MCgameActions 160 | _opt.agent_params.options = options 161 | _opt.agent_params.optionsActions = optionsActions 162 | _opt.agent_params.gpu = _opt.gpu 163 | _opt.agent_params.best = _opt.best 164 | _opt.agent_params.distilled_network = distilled_hdrln 165 | _opt.agent_params.distill = false 166 | if _opt.agent_params.network then 167 | print(_opt.agent_params.network) 168 | _opt.agent_params.network = "convnet_atari_main" 169 | end 170 | --_opt.agent_params.network = "convnet_atari3" 171 | if _opt.network ~= '' then 172 | _opt.agent_params.network = _opt.network 173 | end 174 | 175 | _opt.agent_params.verbose = _opt.verbose 176 | if not _opt.agent_params.state_dim then 177 | _opt.agent_params.state_dim = gameEnv:nObsFeature() 178 | end 179 | if distilled_hdrln then -- distilled means single main network with multiple skills integrated into it 180 | _opt.skill_agent_params.actions = MCgameActions_primitive 181 | _opt.skill_agent_params.gpu = _opt.gpu 182 | _opt.skill_agent_params.best = _opt.best 183 | _opt.skill_agent_params.distilled_network = true 184 | _opt.skill_agent_params.distill = false 185 | _opt.agent_params.supervised_skills = supervised_skills 186 | _opt.agent_params.supervised_file = args.supervised_file 187 | if _opt.distilled_network ~= '' then 188 | _opt.skill_agent_params.network = _opt.distilled_network 189 | end 190 | _opt.skill_agent_params.verbose = _opt.verbose 191 | if not _opt.skill_agent_params.state_dim then 192 | _opt.skill_agent_params.state_dim = gameEnv:nObsFeature() 193 | end 194 | print("SKILL NETWORK") 195 | local distilled_agent = dqn[_opt.agent](_opt.skill_agent_params) 196 | print("END SKILL NETWORK") 197 | 198 | _opt.agent_params.skill_agent = distilled_agent 199 | else 200 | for i = 1, num_skills 201 | do 202 | _opt.skill_agent_params.actions = MCgameActions_primitive 203 | _opt.skill_agent_params.gpu = _opt.gpu 204 | _opt.skill_agent_params.best = _opt.best 205 | _opt.skill_agent_params.distilled_network = false 206 | _opt.skill_agent_params.distill = false 207 | if getlocal('_opt.skill_' .. i) ~= '' then 208 | _opt.skill_agent_params.network = getlocal('_opt.skill_' .. i) 209 | end 210 | _opt.skill_agent_params.verbose = _opt.verbose 211 | if not _opt.skill_agent_params.state_dim then 212 | _opt.skill_agent_params.state_dim = gameEnv:nObsFeature() 213 | end 214 | print("SKILL NETWORK " .. i) 215 | local skill_agent = dqn[_opt.agent](_opt.skill_agent_params) 216 | _opt.agent_params.skill_agent[#(_opt.agent_params.skill_agent) + 1] = skill_agent 217 | print("END SKILL NETWORK " .. 
i) 218 | end 219 | end 220 | 221 | _opt.agent_params.primitive_actions = MCgameActions_primitive 222 | print("MAIN AGENT") 223 | local agent = dqn[_opt.agent](_opt.agent_params) 224 | print("MAIN AGENT") 225 | if opt.verbose >= 1 then 226 | print('Set up Torch using these options:') 227 | for k, v in pairs(opt) do 228 | print(k, v) 229 | end 230 | end 231 | 232 | return gameEnv, MCgameActions_primitive, agent, opt 233 | end 234 | 235 | 236 | 237 | --- other functions 238 | 239 | function str_to_table(str) 240 | if type(str) == 'table' then 241 | return str 242 | end 243 | if not str or type(str) ~= 'string' then 244 | if type(str) == 'table' then 245 | return str 246 | end 247 | return {} 248 | end 249 | local ttr 250 | if str ~= '' then 251 | local ttx=tt 252 | loadstring('tt = {' .. str .. '}')() 253 | ttr = tt 254 | tt = ttx 255 | else 256 | ttr = {} 257 | end 258 | return ttr 259 | end 260 | 261 | function table.copy(t) 262 | if t == nil then return nil end 263 | local nt = {} 264 | for k, v in pairs(t) do 265 | if type(v) == 'table' then 266 | nt[k] = table.copy(v) 267 | else 268 | nt[k] = v 269 | end 270 | end 271 | setmetatable(nt, table.copy(getmetatable(t))) 272 | return nt 273 | end 274 | -------------------------------------------------------------------------------- /DQN-HDRLN/net_downsample_2x_full_y.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "image" 8 | require "Scale" 9 | 10 | local function create_network(args) 11 | -- Y (luminance) 12 | return nn.Scale(84, 84, true) 13 | end 14 | 15 | return create_network 16 | -------------------------------------------------------------------------------- /DQN-HDRLN/nnutils.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "torch" 8 | 9 | function recursive_map(module, field, func) 10 | local str = "" 11 | if module[field] or module.modules then 12 | str = str .. torch.typename(module) .. ": " 13 | end 14 | if module[field] then 15 | str = str .. func(module[field]) 16 | end 17 | if module.modules then 18 | str = str .. "[" 19 | for i, submodule in ipairs(module.modules) do 20 | local submodule_str = recursive_map(submodule, field, func) 21 | str = str .. submodule_str 22 | if i < #module.modules and string.len(submodule_str) > 0 then 23 | str = str .. " " 24 | end 25 | end 26 | str = str .. "]" 27 | end 28 | 29 | return str 30 | end 31 | 32 | function abs_mean(w) 33 | return torch.mean(torch.abs(w:clone():float())) 34 | end 35 | 36 | function abs_max(w) 37 | return torch.abs(w:clone():float()):max() 38 | end 39 | 40 | -- Build a string of average absolute weight values for the modules in the 41 | -- given network. 42 | function get_weight_norms(module) 43 | return "Weight norms:\n" .. recursive_map(module, "weight", abs_mean) .. 44 | "\nWeight max:\n" .. recursive_map(module, "weight", abs_max) 45 | end 46 | 47 | -- Build a string of average absolute weight gradient values for the modules 48 | -- in the given network. 49 | function get_grad_norms(module) 50 | return "Weight grad norms:\n" .. 51 | recursive_map(module, "gradWeight", abs_mean) .. 52 | "\nWeight grad max:\n" .. 
recursive_map(module, "gradWeight", abs_max) 53 | end 54 | -------------------------------------------------------------------------------- /DQN-HDRLN/run_gpu_dist: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./run_gpu breakout "; exit 0 5 | fi 6 | ENV=$1 7 | FRAMEWORK="alewrap" 8 | 9 | game_path=$PWD"/roms/" 10 | env_params="useRGB=true" 11 | agent="NeuralQLearner" 12 | n_replay=16 13 | netfile="\"convnet_atari3\"" 14 | update_freq=4 15 | actrep=1 16 | discount=0.99 17 | seed=1 18 | learn_start=20000 #2000 #20000 #5000 #50000 19 | pool_frms_type="\"max\"" 20 | pool_frms_size=2 21 | initial_priority="false" 22 | replay_memory=100000 #1000000 23 | eps_end=0.1 24 | eps_endt=100000 #400000 #75000 #replay_memory 25 | lr=0.0025 #0.00025 26 | agent_type="DQN3_0_1" 27 | preproc_net="\"net_downsample_2x_full_y\"" 28 | agent_name=$agent_type"_"$1"_FULL_Y" 29 | state_dim=7056 #dimensions of image we send widthXheight 30 | ncols=1 31 | ddqn=1 32 | reward_shaping=0 33 | option_length=5 34 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,ddqn="$ddqn",option_length="$option_length"" #,min_reward=-2000,max_reward=2000" #minibatch_size=32,min_reward=-1,max_reward=1,hist_len=4 35 | steps=50000000 36 | eval_freq=5000 #5000 #20000 #250000 37 | eval_steps=1000 #1000 #125000 38 | prog_freq=10000 # how often it outputs to stdout 39 | save_freq=50000 40 | gpu=0 41 | random_starts=30 42 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 43 | num_threads=4 44 | NETWORK=$agent_name 45 | skill_agent_params=$agent_params",network=\"convnet_atari3\"" 46 | agent_params=$agent_params",network="$netfile 47 | distilled_hdrln="true" # this tells us if HDRLN setup is using a distilled agent or using multiple DQN array 48 | supervised_skills="true" # this goes together with distilled agent 49 | supervised_file="/home/deep5/DQN_Shahar_Chen_oldpc/dqn_distill/skillWeights.h5" 50 | num_skills=4 # how many skills, relevant both for distilled and non distilled setup 51 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -skill_agent_params $skill_agent_params -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -reward_shaping $reward_shaping -distilled_hdrln $distilled_hdrln -supervised_skills $supervised_skills -supervised_file $supervised_file" 52 | args=$args" -socket $2" 53 | 54 | args=$args" -num_skills $num_skills -skill_1 pickup.t7 -skill_2 navigate.t7 -skill_3 break.t7 -skill_4 break.t7 -distilled_network DQN3_0_1_hdrln_dist_fix_size_FULL_Y.t7"; #distilled_network_0.1temp.t7"; 55 | 56 | echo $args 57 | 58 | cd dqn_distill 59 | ../torch/bin/qlua train_agent.lua $args 60 | -------------------------------------------------------------------------------- /DQN-HDRLN/test_gpu_dist: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name 
of the game, e.g. ./watch_pretrained breakout"; exit 0 5 | fi 6 | 7 | if [ -z "$2" ] 8 | then echo "Please provide the pretrained network file, e.g. ./watch_pretrained breakout DQN3_0_1_breakout_FULL_Y.t7"; exit 0 9 | fi 10 | 11 | ENV=$1 12 | NETWORK=$2 13 | FRAMEWORK="alewrap" 14 | 15 | game_path=$PWD"/roms/" 16 | env_params="useRGB=true" 17 | agent="NeuralQLearner" 18 | n_replay=16 19 | netfile="\"convnet_atari3\"" 20 | update_freq=4 21 | actrep=1 22 | discount=0.99 23 | seed=1 24 | learn_start=50000 25 | pool_frms_type="\"max\"" 26 | pool_frms_size=2 27 | initial_priority="false" 28 | replay_memory=500010 29 | eps_end=0 30 | eps_endt=0 31 | lr=0.00025 32 | agent_type="DQN3_0_1" 33 | preproc_net="\"net_downsample_2x_full_y\"" 34 | agent_name=$agent_type"_"$1"_FULL_Y" 35 | state_dim=7056 36 | ncols=1 37 | reward_shaping=0 38 | ddqn=1 39 | option_length=5 40 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-300,max_reward=300,ddqn="$ddqn",option_length="$option_length"" 41 | gif_file="../gifs/$ENV." 42 | gpu=0 43 | random_starts=30 44 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 45 | num_threads=4 46 | skill_agent_params=$agent_params",network=\"convnet_atari3\"" 47 | 48 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -skill_agent_params $skill_agent_params -agent_params $agent_params -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -network $NETWORK -gif_file $gif_file -reward_shaping $reward_shaping" 49 | args=$args" -socket $3" 50 | args=$args" -distilled_network distilled_network_0.1temp.t7"; 51 | 52 | 53 | echo $args 54 | 55 | cd dqn_distill 56 | ../torch/bin/qlua test_dist_agent.lua $args 57 | -------------------------------------------------------------------------------- /Distillation process/README.md: -------------------------------------------------------------------------------- 1 | # Distilled Multi Skill Agent Process 2 | __This is in development stage__ 3 | 4 | For further explanation see the [website](http://chentessler.wixsite.com/hdrlnminecraft). 5 | 6 | This code requires a lot of refactoring and cleanup. 7 | You let your "teachers" play for N time steps and record the data (Input: state, Output: Q values). 8 | Then with this code, load the data and train the new distilled "student". 9 | 10 | Currently trains using the MSE loss function, need to swap it to KL-Divergence. 11 | -------------------------------------------------------------------------------- /Distillation process/convnet.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require "initenv" 8 | 9 | function create_network(args) 10 | 11 | local net = nn.Sequential() 12 | net:add(nn.Reshape(unpack(args.input_dims))) 13 | 14 | --- first convolutional layer 15 | local convLayer = nn.SpatialConvolution 16 | 17 | net:add(convLayer(args.hist_len*args.ncols, args.n_units[1], 18 | args.filter_size[1], args.filter_size[1], 19 | args.filter_stride[1], args.filter_stride[1],1)) 20 | net:add(args.nl()) 21 | 22 | -- Add convolutional layers 23 | for i=1,(#args.n_units-1) do 24 | -- second convolutional layer 25 | net:add(convLayer(args.n_units[i], args.n_units[i+1], 26 | args.filter_size[i+1], args.filter_size[i+1], 27 | args.filter_stride[i+1], args.filter_stride[i+1])) 28 | net:add(args.nl()) 29 | end 30 | 31 | local nel 32 | if args.gpu >= 0 then 33 | nel = net:cuda():forward(torch.zeros(1,unpack(args.input_dims)) 34 | :cuda()):nElement() 35 | else 36 | nel = net:forward(torch.zeros(1,unpack(args.input_dims))):nElement() 37 | end 38 | 39 | -- reshape all feature planes into a vector per example 40 | net:add(nn.Reshape(nel)) 41 | 42 | -- fully connected layer 43 | net:add(nn.Linear(nel, args.n_hid[1])) 44 | net:add(args.nl()) 45 | local last_layer_size = args.n_hid[1] 46 | 47 | for i=1,(#args.n_hid-1) do 48 | -- add Linear layer 49 | last_layer_size = args.n_hid[i+1] 50 | net:add(nn.Linear(args.n_hid[i], last_layer_size)) 51 | net:add(args.nl()) 52 | end 53 | 54 | if args.distilled_network == true and false then 55 | -- add the last fully connected layer (to actions) 56 | concat = nn.Concat(2) 57 | 58 | 59 | local MCgameActions = {1,3,4,5,0,6,7,8,9} --,5,6,7,8} -- this is our game actions table 60 | local MCgameActions_primitive = {1,3,4,0,5} --,5} -- this is our game actions table 61 | local optionsActions = {6,7,8,9} -- these actions are correlated to an OPTION, 20 = solve room (make this struct with max iterations per option and socket port and ip) 62 | 63 | local navigateActions = {1,3,4} 64 | local pickupActions = {1,3,4} 65 | local breakActions = {1,3,4,5} 66 | local placeActions = {1,3,4,0} 67 | 68 | 69 | --args.skills = {navigateActions, pickupActions, breakActions, placeActions} 70 | 71 | for i=1,#args.skills do 72 | print('Added new skill layer '..i..' with '..(args.skills[i])..' actions') 73 | skill = nn.Sequential() 74 | skill:add(nn.Linear(last_layer_size, (args.skills[i]))) 75 | concat:add(skill) 76 | end 77 | 78 | net:add(concat) 79 | else 80 | -- add the last fully connected layer (to actions) 81 | net:add(nn.Linear(last_layer_size, args.n_actions)) 82 | end 83 | 84 | if args.gpu >=0 then 85 | net:cuda() 86 | end 87 | if args.verbose >= 2 then 88 | print(net) 89 | print('Convolutional layers flattened output size:', nel) 90 | end 91 | return net 92 | end 93 | -------------------------------------------------------------------------------- /Distillation process/convnet_atari3.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require 'convnet' 8 | 9 | return function(args) 10 | args.n_units = {32, 64, 64} 11 | args.filter_size = {8, 4, 3} 12 | args.filter_stride = {4, 2, 1} 13 | args.n_hid = {512} 14 | args.nl = nn.Rectifier 15 | args.distilled_network = true 16 | args.skills = {3,3,4,4} 17 | return create_network(args) 18 | end 19 | 20 | -------------------------------------------------------------------------------- /Distillation process/convnet_atari_main.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'convnet' 8 | 9 | return function(args) 10 | args.n_units = {32, 64, 64} 11 | args.filter_size = {8, 4, 3} 12 | args.filter_stride = {4, 2, 1} 13 | args.n_hid = {512} 14 | args.nl = nn.Rectifier 15 | return create_network(args) 16 | end 17 | -------------------------------------------------------------------------------- /Distillation process/distill_gpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./run_gpu breakout "; exit 0 5 | fi 6 | ENV=$1 7 | FRAMEWORK="alewrap" 8 | 9 | game_path=$PWD"/roms/" 10 | env_params="useRGB=true" 11 | agent="NeuralQLearner" 12 | n_replay=16 13 | netfile="\"convnet_atari3\"" 14 | update_freq=4 15 | actrep=1 16 | discount=0.99 17 | seed=1 18 | learn_start=2000 #20000 #5000 #50000 19 | pool_frms_type="\"max\"" 20 | pool_frms_size=2 21 | initial_priority="false" 22 | replay_memory=100000 #1000000 23 | eps_end=0.1 24 | eps_endt=100000 #400000 #75000 #replay_memory 25 | lr=0.00025 26 | agent_type="DQN3_0_1" 27 | preproc_net="\"net_downsample_2x_full_y\"" 28 | agent_name=$agent_type"_"$1"_FULL_Y" 29 | state_dim=7056 #dimensions of image we send widthXheight 30 | ncols=1 31 | ddqn=1 32 | reward_shaping=1 33 | option_length=5 34 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=1,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,ddqn="$ddqn",option_length="$option_length"" #,min_reward=-2000,max_reward=2000" #minibatch_size=32,min_reward=-1,max_reward=1,hist_len=4 35 | steps=50000000 36 | eval_freq=5000 #5000 #20000 #250000 37 | eval_steps=1000 #1000 #125000 38 | prog_freq=10000 # how often it outputs to stdout 39 | save_freq=50000 40 | gpu=2 41 | random_starts=30 42 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 43 | num_threads=4 44 | NETWORK=$agent_name 45 | 46 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -skill_agent_params $agent_params -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -reward_shaping $reward_shaping" 47 | args=$args" -socket $2" 48 | if [ -n "$3" ] 49 | then args=$args" -pickup_network pickup.t7 -navigate_network navigate.t7 -break_network break.t7 -place_network break.t7"; 50 | fi 51 | echo $args 52 | 53 | cd dqn_distill 54 | ../torch/bin/qlua distill_agent.lua $args 55 | 
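The distillation run launched by distill_gpu above follows the procedure described in the Distillation process README: the "teacher" networks record (state, Q-value) pairs while playing, and the distilled "student" is then regressed onto those outputs, currently with an MSE loss that is planned to be replaced by KL-divergence. The snippet below is a minimal numpy sketch of those two candidate losses for illustration only; it is not the repository's distill_agent.lua, and the 0.1 temperature is only assumed from the *_0.1temp.t7 network names used elsewhere in these scripts.

```python
import numpy as np

def softmax(q, temperature=0.1):
    # Soften Q-values into a policy distribution; the 0.1 temperature is an assumption.
    z = q / temperature
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mse_distill_loss(student_q, teacher_q):
    # Current approach per the README: plain MSE regression onto the recorded teacher Q-values.
    return np.mean((student_q - teacher_q) ** 2)

def kl_distill_loss(student_q, teacher_q, temperature=0.1):
    # Planned replacement: KL divergence between temperature-softened teacher and student policies.
    p_t = softmax(teacher_q, temperature)
    p_s = softmax(student_q, temperature)
    return np.mean(np.sum(p_t * (np.log(p_t + 1e-12) - np.log(p_s + 1e-12)), axis=-1))

# Example with random stand-ins: a batch of 32 states and the 5 primitive actions
# of MCgameActions_primitive = {1,3,4,0,5}.
teacher_q = np.random.randn(32, 5)
student_q = np.random.randn(32, 5)
print(mse_distill_loss(student_q, teacher_q), kl_distill_loss(student_q, teacher_q))
```

The KL form only constrains the student to reproduce the teacher's action preferences rather than the absolute scale of its Q-values, which is the usual motivation for moving from MSE to KL in policy distillation.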
-------------------------------------------------------------------------------- /Distillation process/test_distill: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./run_gpu breakout "; exit 0 5 | fi 6 | ENV=$1 7 | FRAMEWORK="alewrap" 8 | 9 | game_path=$PWD"/roms/" 10 | env_params="useRGB=true" 11 | agent="NeuralQLearner" 12 | n_replay=16 13 | netfile="\"convnet_atari3\"" 14 | update_freq=4 15 | actrep=1 16 | discount=0.99 17 | seed=1 18 | learn_start=2000 #20000 #5000 #50000 19 | pool_frms_type="\"max\"" 20 | pool_frms_size=2 21 | initial_priority="false" 22 | replay_memory=100000 #1000000 23 | eps_end=0.1 24 | eps_endt=100000 #400000 #75000 #replay_memory 25 | lr=0.0025 #0.00025 26 | agent_type="DQN3_0_1" 27 | preproc_net="\"net_downsample_2x_full_y\"" 28 | agent_name=$agent_type"_"$1"_FULL_Y" 29 | state_dim=7056 #dimensions of image we send widthXheight 30 | ncols=1 31 | ddqn=1 32 | reward_shaping=1 33 | option_length=5 34 | agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,ddqn="$ddqn",option_length="$option_length"" #,min_reward=-2000,max_reward=2000" #minibatch_size=32,min_reward=-1,max_reward=1,hist_len=4 35 | steps=50000000 36 | eval_freq=5000 #5000 #20000 #250000 37 | eval_steps=1000 #1000 #125000 38 | prog_freq=10000 # how often it outputs to stdout 39 | save_freq=50000 40 | gpu=2 41 | random_starts=30 42 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 43 | num_threads=4 44 | NETWORK=$2 45 | 46 | args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -skill_agent_params $agent_params -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -reward_shaping $reward_shaping -network $NETWORK" 47 | args=$args" -socket $3" 48 | 49 | args=$args" -pickup_network pickup.t7 -navigate_network navigate.t7 -break_network break.t7 -place_network break.t7 -distilled_network distilled_network_0.1temp_batchof1.t7" 50 | 51 | echo $args 52 | 53 | cd dqn_distill 54 | ../torch/bin/qlua test_agent.lua $args 55 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Chen Tessler 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # H-DRLN
2 | __This is in development stage__
3 | 
4 | The H-DRLN (Hierarchical Deep RL Network) is a novel architecture for incorporating options (skills) within the DQN.
5 | For further explanation see the [website](http://chentessler.wixsite.com/hdrlnminecraft).
6 | 
7 | ### Graphic view of created skill clusters
8 | We use the work done by [Zahavy et al](https://arxiv.org/abs/1602.02658) and adapt the [interface they created](https://github.com/bentzinir/graying_the_box) in order to view the skill clusters graphically on the t-SNE plot.
9 | 
10 | ### References
11 | - [A Deep Hierarchical Approach to Lifelong Learning in Minecraft - Tessler et al](https://arxiv.org/abs/1604.07255)
12 | - [Human-level control through deep reinforcement learning. Nature 518(7540):529–533 - Mnih et al](https://arxiv.org/abs/1312.5602)
13 | - [Graying the black box: Understanding DQNs. Proceedings of the 33rd International Conference on Machine Learning (ICML) - Zahavy et al](https://arxiv.org/abs/1602.02658)
14 | 
--------------------------------------------------------------------------------
/Utilities/README.md:
--------------------------------------------------------------------------------
1 | # Utilities
2 | __This is in development stage__
3 | 
4 | - activations_hdf5_to_tsne.lua - use this to extract a 2D t-SNE map of the activations (some code to convert the data structure from t7 to h5 is also left in, if needed).
5 | - clean_data.py - runs over the data and keeps only trajectories that end successfully. We don't want to learn a skill from the agent being stuck in a corner.
6 | - cluster.py - runs multiple Gaussian Mixture Models in order to find the optimal # of clusters to fit the data (a minimal sketch of this model-selection loop follows below).
7 | - learn_weights.py - TensorFlow; learns the weights and saves them. The weights are an extension of the "expert network".
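cluster.py itself is not reproduced in this listing, so the following is a minimal sketch of the model-selection loop described above: fit Gaussian Mixture Models with an increasing number of components on the cleaned activations and keep the best-scoring one. Using scikit-learn and BIC as the selection criterion is an assumption made for illustration; the actual script may choose the number of clusters differently.

```python
# Hypothetical sketch of the GMM model selection described for cluster.py.
# Assumes scikit-learn; BIC as the selection criterion is an assumption.
import h5py
import numpy as np
from sklearn.mixture import GaussianMixture

# activationsClean.h5 is produced by clean_data.py above.
activations = np.array(h5py.File('activationsClean.h5', 'r').get('data'))

best_k, best_bic, best_model = None, np.inf, None
for k in range(2, 30):
    gmm = GaussianMixture(n_components=k, covariance_type='diag', n_init=3, random_state=0)
    gmm.fit(activations)
    bic = gmm.bic(activations)  # lower BIC = better fit/complexity trade-off
    if bic < best_bic:
        best_k, best_bic, best_model = k, bic, gmm

print('optimal number of clusters:', best_k)
labels = best_model.predict(activations)  # cluster index per state, e.g. for colouring the t-SNE map
```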
8 | -------------------------------------------------------------------------------- /Utilities/activations_hdf5_to_tsne.lua: -------------------------------------------------------------------------------- 1 | require 'hdf5' 2 | require 'unsup' 3 | m = require 'manifold' 4 | bh_tsne = require('tsne') 5 | 6 | local cmd = torch.CmdLine() 7 | cmd:text() 8 | cmd:option('-N', '', 'number of states') 9 | cmd:text() 10 | cmd:option('-iter', '', 'number of iterations') 11 | cmd:option('-convert', false, 'convert t7 to hdf5') 12 | local args = cmd:parse(arg) 13 | local N = tonumber(args.N) 14 | local maxiter = tonumber(args.iter) 15 | 16 | local pca_dims = 50 17 | 18 | local convert_t7_to_hdf5 = cmd:parse(args.convert) 19 | 20 | if convert_t7_to_hdf5 then 21 | local myFile = torch.DiskFile('./dqn_distill/hdrln_activations.t7', 'r') 22 | local data = myFile:readObject()--myFile:read('data'):all() 23 | else 24 | local myFile = hdf5.open('./dqn_distill/activationsClean.h5', 'r') 25 | local data = myFile:read('data'):all() 26 | end 27 | myFile:close() 28 | 29 | print('Data loaded ' .. data:size(1) .. ' ' .. data:size(2)) 30 | 31 | if convert_t7_to_hdf5 then 32 | data = data:narrow(1, 1, N):double() 33 | local myFile2= hdf5.open('./dqn_distill/activations.h5', 'w') 34 | myFile2:write('data', data) 35 | myFile2:close() 36 | print('saved activations') 37 | 38 | local myFile = torch.DiskFile('./dqn_distill/hdrln_actions.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r') 39 | local data = myFile:readObject()--myFile:read('data'):all() 40 | myFile:close() 41 | data = data:narrow(1, 1, N) 42 | local myFile2= hdf5.open('./dqn_distill/actions.h5', 'w') 43 | myFile2:write('data', data) 44 | myFile2:close() 45 | print('saved actions') 46 | 47 | local myFile = torch.DiskFile('./dqn_distill/hdrln_fullstates.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r') 48 | local data = myFile:readObject()--myFile:read('data'):all() 49 | myFile:close() 50 | data = data:narrow(1, 1, N) 51 | local myFile2= hdf5.open('./dqn_distill/screens.h5', 'w') 52 | myFile2:write('data', data) 53 | myFile2:close() 54 | print('saved screens') 55 | 56 | local myFile = torch.DiskFile('./dqn_distill/hdrln_qvals.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r') 57 | local data = myFile:readObject()--myFile:read('data'):all() 58 | myFile:close() 59 | data = data:narrow(1, 1, N) 60 | local myFile2= hdf5.open('./dqn_distill/qvals.h5', 'w') 61 | myFile2:write('data', data) 62 | myFile2:close() 63 | print('saved qvals') 64 | 65 | local myFile = torch.DiskFile('./dqn_distill/hdrln_rewards.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r') 66 | local data = myFile:readObject()--myFile:read('data'):all() 67 | myFile:close() 68 | data = data:narrow(1, 1, N) 69 | local myFile2= hdf5.open('./dqn_distill/reward.h5', 'w') 70 | myFile2:write('data', data) 71 | myFile2:close() 72 | print('saved reward') 73 | 74 | local myFile = torch.DiskFile('./dqn_distill/hdrln_statespace.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r') 75 | local data = myFile:readObject()--myFile:read('data'):all() 76 | myFile:close() 77 | data = data:narrow(1, 1, N) 78 | local myFile2= hdf5.open('./dqn_distill/states.h5', 'w') 79 | myFile2:write('data', data) 80 | myFile2:close() 81 | print('saved states') 82 | 83 | local myFile = torch.DiskFile('./dqn_distill/hdrln_terminal.t7', 'r') --hdf5.open('./dqn/tmp/global_activations.h5', 'r') 84 | local data = myFile:readObject()--myFile:read('data'):all() 85 | myFile:close() 86 | data = 
data:narrow(1, 1, N) 87 | local myFile2= hdf5.open('./dqn_distill/termination.h5', 'w') 88 | myFile2:write('data', data) 89 | myFile2:close() 90 | print('saved termination') 91 | else 92 | data = data:double() 93 | 94 | local p = 512 95 | -- perform pca on the transpose 96 | local mean = torch.mean(data,1) 97 | local xm = data - torch.ger(torch.ones(data:size(1)),mean:squeeze()) 98 | local c = torch.mm(xm:t(),xm) 99 | c:div(data:size(1)-1) 100 | local ce,cv = torch.symeig(c,'V') 101 | 102 | cv = cv:narrow(2,p-pca_dims+1,pca_dims) 103 | data = torch.mm(data, cv) 104 | 105 | -- 106 | --opts = {ndims = 2, perplexity = 50,pca = 50, use_bh = true, theta = 0.5} 107 | for perp = 750,850,100 108 | do 109 | opts = {ndims = 2, perplexity = perp, use_bh = true, pca=pca_dims, theta = 0.5, n_iter=maxiter, max_iter = maxiter, method='barnes_hut'} 110 | print('run t-SNE') 111 | --mapped_activations = m.embedding.tsne(data:double(), opts) 112 | mapped_activations = bh_tsne(data:double(), opts) 113 | print('save t-SNE') 114 | local myFile2= hdf5.open('./dqn_distill/lowd_activations_'..perp..'.h5', 'w') 115 | myFile2:write('data', mapped_activations) 116 | myFile2:close() 117 | print('saved t-SNE') 118 | end 119 | end 120 | -------------------------------------------------------------------------------- /Utilities/clean_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | 4 | debug = False 5 | 6 | termination = h5py.File('./termination.h5', 'r').get('data') 7 | reward = h5py.File('./reward.h5', 'r').get('data') 8 | activations = h5py.File('./activations.h5', 'r').get('data') 9 | actions = h5py.File('./actions.h5', 'r').get('data') 10 | qvals = h5py.File('./qvals.h5', 'r').get('data') 11 | states = h5py.File('./states.h5', 'r').get('data') 12 | 13 | if debug: 14 | print('Shape of termination: \n', (np.array(termination)).shape) 15 | print('Shape of activations: \n', (np.array(activations)).shape) 16 | print('Shape of actions: \n', (np.array(actions)).shape) 17 | print('Shape of reward: \n', (np.array(reward)).shape) 18 | print('Shape of qvals: \n', (np.array(qvals)).shape) 19 | print('Shape of states: \n', (np.array(states)).shape) 20 | 21 | startTrajectory = [] # first index in a trajectory that leads to success 22 | endTrajectory = [] # last index in a trajectory that leads to success 23 | 24 | initialIndex = 0 25 | for i in range(len(termination) - 1): 26 | if debug: 27 | print(str(i) + ' , ' + str(termination[i]) + ' , ' + str(reward[i])) 28 | if termination[i + 1] == 1 and reward[i] == 0: 29 | startTrajectory.append(initialIndex) 30 | endTrajectory.append(i) 31 | if debug: 32 | print('Success: ' + str(initialIndex) + ' , ' + str(i)) 33 | initialIndex = i + 2 34 | elif termination[i + 1] == 1: 35 | initialIndex = i + 2 36 | 37 | rewardClean = [] 38 | activationsClean = [] 39 | qvalsClean = [] 40 | statesClean = [] 41 | actionsClean = [] 42 | 43 | totalStates = 0 44 | 45 | for i in range(len(startTrajectory)): 46 | totalStates += endTrajectory[i] - startTrajectory[i] + 1 47 | for j in range(startTrajectory[i], endTrajectory[i]): 48 | rewardClean.append(reward[j]) 49 | activationsClean.append(activations[j, :]) 50 | qvalsClean.append(qvals[j]) 51 | actionsClean.append(actions[j]) 52 | statesClean.append(states[j, :]) 53 | 54 | rCleanFile = h5py.File('rewardClean.h5', 'w') 55 | rCleanFile.create_dataset('data', data=rewardClean) 56 | 57 | aCleanFile = h5py.File('actionsClean.h5', 'w') 58 | aCleanFile.create_dataset('data', 
data=actionsClean) 59 | 60 | actCleanFile = h5py.File('activationsClean.h5', 'w') 61 | actCleanFile.create_dataset('data', data=activationsClean) 62 | 63 | qCleanFile = h5py.File('qvalsClean.h5', 'w') 64 | qCleanFile.create_dataset('data', data=qvalsClean) 65 | 66 | sCleanFile = h5py.File('statesClean.h5', 'w') 67 | sCleanFile.create_dataset('data', data=statesClean) 68 | 69 | print('Done! Total states kept: ' + str(totalStates)) 70 | -------------------------------------------------------------------------------- /Utilities/learn_weights.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | import h5py 8 | import numpy as np 9 | import math 10 | 11 | import tensorflow as tf 12 | 13 | FLAGS = None 14 | 15 | def main(_): 16 | hiddenWidth1 = 100 17 | hiddenWidth2 = 64 18 | outputWidth = 5 19 | weightInit = -1 20 | batchSize = 4 21 | gamma = 0.7 22 | 23 | dataOut = h5py.File('skillWeightsQ.h5', 'w') #_2Layer.h5', 'w') 24 | 25 | # Import data 26 | print('Loading data...') 27 | data = h5py.File(FLAGS.file, 'r') 28 | numSkills = data.get('numberSkills') 29 | print('Number of skills is ' + str(numSkills[()])) 30 | 31 | dataOut.create_dataset('hiddenWidth', data=hiddenWidth1) 32 | dataOut.create_dataset('numberSkills', data=numSkills) 33 | 34 | for skill in range(numSkills[()]): 35 | activations = np.array(data.get('activations_' + str(skill))) 36 | actions = (np.array(data.get('actions_' + str(skill))) - 1) 37 | termination = np.array(data.get('termination_' + str(skill))) 38 | 39 | print('Creating model...') 40 | # Create the model 41 | step = tf.Variable(0, trainable=False) # cant attach non trainable variable to gpu 42 | with tf.device('/gpu:1'): 43 | x = tf.placeholder(tf.float32, [None, 512, ]) 44 | 45 | 46 | # Hidden Layer1 47 | W_hidden1 = tf.Variable(tf.truncated_normal([512, hiddenWidth1], stddev=0.1)) 48 | b_hidden1 = tf.Variable(tf.constant(0.1, shape=[hiddenWidth1])) 49 | y_hidden1 = tf.add(tf.matmul(x, W_hidden1), b_hidden1) 50 | act_hidden1 = tf.nn.relu(y_hidden1) 51 | 52 | # Hidden Layer2 53 | W_hidden2 = tf.Variable(tf.random_uniform([hiddenWidth1, hiddenWidth2], weightInit, 1)) 54 | b_hidden2 = tf.Variable(tf.random_uniform([hiddenWidth2], weightInit, 1)) 55 | y_hidden2 = tf.add(tf.matmul(act_hidden1, W_hidden2), b_hidden2) 56 | act_hidden2 = tf.nn.relu(y_hidden2) 57 | 58 | # Output Layer 59 | W_output = tf.Variable(tf.truncated_normal([hiddenWidth1, outputWidth], stddev=0.1)) 60 | #W_output = tf.Variable(tf.truncated_normal([hiddenWidth2, outputWidth], stddev=0.1)) 61 | b_output = tf.Variable(tf.constant(0.1, shape=[outputWidth])) 62 | y = tf.add(tf.matmul(act_hidden1, W_output), b_output) 63 | #y = tf.add(tf.matmul(act_hidden2, W_output), b_output) 64 | predict = tf.argmax(y, 1) 65 | ''' 66 | 67 | # Linear only 68 | W = tf.Variable(tf.random_uniform([512, outputWidth], weightInit, 0.01)) 69 | b = tf.Variable(tf.random_uniform([outputWidth], weightInit, 0.01)) 70 | y = tf.add(tf.matmul(x, W), b) 71 | predict = tf.argmax(y, 1) 72 | ''' 73 | nextQ = tf.placeholder(shape=[None, outputWidth, ], dtype=tf.float32) 74 | loss = tf.reduce_sum(tf.square(nextQ - y)) 75 | #loss = tf.nn.softmax_cross_entropy_with_logits(y, nextQ) 76 | 77 | rate = tf.train.exponential_decay(0.0005, step, 250, 0.9999) 78 | trainer = tf.train.AdamOptimizer(rate) #learning_rate=0.000001) # 
GradientDescentOptimizer(learning_rate=0.0001) 79 | updateModel = trainer.minimize(loss, global_step=step) 80 | 81 | # train_step = tf.train.AdamOptimizer().minimize(cross_entropy) 82 | # train_step = tf.train.RMSPropOptimizer(0.1).minimize(cross_entropy) 83 | # train_step = tf.train.GradientDescentOptimizer(0.1).minimize(cross_entropy) 84 | 85 | sess = tf.InteractiveSession(config=tf.ConfigProto(log_device_placement=True)) 86 | #tf.global_variables_initializer().run() 87 | tf.initialize_all_variables().run() 88 | #sess.run(tf.initialize_all_variables()) 89 | # Train 90 | maxQ = 1 91 | iteration = 0 92 | print('Training...') 93 | for _ in range(20000000): 94 | if (_ % 1000000 == 0 and _ > 0): # and False): 95 | testPredictions = sess.run(predict, feed_dict={x: activations[int(math.ceil(activations.shape[0] * 0.8)) + 1:activations.shape[0],:]}) 96 | trainPredictions = sess.run(predict, feed_dict={x: activations[0:int(math.ceil(activations.shape[0] * 0.8)),:]}) 97 | print('Done ' + str(_) + ' iterations. testing error is: ' + str(100 * np.sum(np.sign(np.absolute(testPredictions - actions[int(math.ceil(activations.shape[0] * 0.8)) + 1:activations.shape[0]]))) * 1.0 / (activations.shape[0] - int(math.ceil(activations.shape[0] * 0.8)) + 1)) + '%, training error is: ' + str(100 * np.sum(np.sign(np.absolute(trainPredictions - actions[0:int(math.ceil(activations.shape[0] * 0.8))]))) * 1.0 / (int(math.ceil(activations.shape[0] * 0.8)))) + '%') 98 | print('Loss: ' + str(loss_val) + ', Skill#: ' + str(skill)) 99 | 100 | index = np.random.randint(int(math.ceil(activations.shape[0] * 0.8)), size=batchSize) 101 | ''' 102 | iteration = iteration + 1 103 | iteration = iteration % int(math.ceil(activations.shape[0] * 0.8)) 104 | index = np.array([iteration]) 105 | ''' 106 | 107 | allQ = sess.run(y,feed_dict={x: activations[index, :]}) 108 | 109 | Q1 = sess.run(y,feed_dict={x: activations[index + 1, :]}) 110 | targetQ = np.ones(allQ.shape) * -1 111 | #targetQ = allQ 112 | for i in range(index.shape[0]): 113 | if termination[index[i]] == 1: 114 | Q = 0 115 | else: 116 | Q = np.max(Q1[i, :]) * gamma 117 | 118 | # maxQ = max(maxQ, abs(Q)) 119 | targetQ[i, :] = targetQ[i, :] + Q - gamma * gamma 120 | targetQ[i, int(actions[index[i]])] = targetQ[i, int(actions[index[i]])] + gamma * gamma 121 | targetQ = targetQ * 1.0 / maxQ 122 | 123 | ''' 124 | targetQ = np.zeros(allQ.shape) 125 | for i in range(index.shape[0]): 126 | targetQ[i, int(actions[index[i]])] = 1 127 | ''' 128 | 129 | _, loss_val = sess.run([updateModel, loss], feed_dict={x: activations[index, :], nextQ: targetQ}) 130 | 131 | # Test trained model 132 | print('Testing model on ' + str(len(actions[int(math.ceil(activations.shape[0] * 0.8)) + 1:activations.shape[0]])) + ' samples...') 133 | 134 | prediction = tf.argmax(y,1) 135 | predictions = prediction.eval(feed_dict={x: activations[int(math.ceil(activations.shape[0] * 0.8)) + 1:activations.shape[0],:]}, session=sess) 136 | # print(predictions) 137 | # print(actions[int(math.ceil(np.array(activations).shape[0] * 0.8)) + 1:np.array(activations).shape[0]]) 138 | print('Testing error:') 139 | print(100 * np.sum(np.sign(np.absolute(predictions - actions[int(math.ceil(activations.shape[0] * 0.8)) + 1:activations.shape[0]]))) * 1.0 / (activations.shape[0] - int(math.ceil(activations.shape[0] * 0.8)) + 1)) 140 | print('Training error:') 141 | predictions = prediction.eval(feed_dict={x: activations[0:int(math.ceil(activations.shape[0] * 0.8)),:]}, session=sess) 142 | print(100 * 
np.sum(np.sign(np.absolute(predictions - actions[0:int(math.ceil(activations.shape[0] * 0.8))]))) * 1.0 / (int(math.ceil(activations.shape[0] * 0.8)))) 143 | 144 | dataOut.create_dataset('W_hidden_' + str(skill), data=sess.run(W_hidden1)) 145 | dataOut.create_dataset('b_hidden_' + str(skill), data=sess.run(b_hidden1)) 146 | dataOut.create_dataset('W_output_' + str(skill), data=sess.run(W_output)) 147 | dataOut.create_dataset('b_output_' + str(skill), data=sess.run(b_output)) 148 | 149 | if __name__ == '__main__': 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument('-file', type=str, required=True, 152 | help='Name of Skill extraction file.') 153 | FLAGS, unparsed = parser.parse_known_args() 154 | tf.app.run(main=main) 155 | -------------------------------------------------------------------------------- /graying_the_box/.gitignore: -------------------------------------------------------------------------------- 1 | *.bin 2 | *.h5 3 | *.pyc 4 | *.so 5 | *.o 6 | data/ -------------------------------------------------------------------------------- /graying_the_box/.idea/.name: -------------------------------------------------------------------------------- 1 | graying_the_box -------------------------------------------------------------------------------- /graying_the_box/.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /graying_the_box/.idea/graying_the_box.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | -------------------------------------------------------------------------------- /graying_the_box/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /graying_the_box/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /graying_the_box/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /graying_the_box/LUA/dqn/Rectifier.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | --[[ Rectified Linear Unit. 8 | 9 | The output is max(0, input). 10 | --]] 11 | 12 | local Rectifier, parent = torch.class('nn.Rectifier', 'nn.Module') 13 | 14 | -- This module accepts minibatches 15 | function Rectifier:updateOutput(input) 16 | return self.output:resizeAs(input):copy(input):abs():add(input):div(2) 17 | end 18 | 19 | function Rectifier:updateGradInput(input, gradOutput) 20 | self.gradInput:resizeAs(self.output) 21 | return self.gradInput:sign(self.output):cmul(gradOutput) 22 | end -------------------------------------------------------------------------------- /graying_the_box/LUA/dqn/Scale.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require "nn" 8 | require "image" 9 | 10 | local scale = torch.class('nn.Scale', 'nn.Module') 11 | 12 | 13 | function scale:__init(height, width) 14 | self.height = height 15 | self.width = width 16 | end 17 | 18 | function scale:forward(x) 19 | local x = x 20 | if x:dim() > 3 then 21 | x = x[1] 22 | end 23 | 24 | x = image.rgb2y(x) 25 | x = image.scale(x, self.width, self.height, 'bilinear') 26 | return x 27 | end 28 | 29 | function scale:updateOutput(input) 30 | return self:forward(input) 31 | end 32 | 33 | function scale:float() 34 | end 35 | -------------------------------------------------------------------------------- /graying_the_box/LUA/dqn/convnet.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "initenv" 8 | 9 | function create_network(args) 10 | 11 | local net = nn.Sequential() 12 | net:add(nn.Reshape(unpack(args.input_dims))) 13 | 14 | --- first convolutional layer 15 | local convLayer = nn.SpatialConvolution 16 | 17 | net:add(convLayer(args.hist_len*args.ncols, args.n_units[1], 18 | args.filter_size[1], args.filter_size[1], 19 | args.filter_stride[1], args.filter_stride[1],1)) 20 | net:add(args.nl()) 21 | 22 | -- Add convolutional layers 23 | for i=1,(#args.n_units-1) do 24 | -- second convolutional layer 25 | net:add(convLayer(args.n_units[i], args.n_units[i+1], 26 | args.filter_size[i+1], args.filter_size[i+1], 27 | args.filter_stride[i+1], args.filter_stride[i+1])) 28 | net:add(args.nl()) 29 | end 30 | 31 | local nel 32 | if args.gpu >= 0 then 33 | nel = net:cuda():forward(torch.zeros(1,unpack(args.input_dims)) 34 | :cuda()):nElement() 35 | else 36 | nel = net:forward(torch.zeros(1,unpack(args.input_dims))):nElement() 37 | end 38 | 39 | -- reshape all feature planes into a vector per example 40 | net:add(nn.Reshape(nel)) 41 | 42 | -- fully connected layer 43 | net:add(nn.Linear(nel, args.n_hid[1])) 44 | net:add(args.nl()) 45 | local last_layer_size = args.n_hid[1] 46 | 47 | for i=1,(#args.n_hid-1) do 48 | -- add Linear layer 49 | last_layer_size = args.n_hid[i+1] 50 | net:add(nn.Linear(args.n_hid[i], last_layer_size)) 51 | net:add(args.nl()) 52 | end 53 | 54 | -- add the last fully connected layer (to actions) 55 | net:add(nn.Linear(last_layer_size, args.n_actions)) 56 | 57 | if args.gpu >=0 then 58 | net:cuda() 59 | end 60 | if args.verbose >= 2 then 61 | print(net) 62 | print('Convolutional layers flattened output size:', nel) 63 | end 64 | return net 65 | end 66 | -------------------------------------------------------------------------------- /graying_the_box/LUA/dqn/convnet_atari3.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require 'convnet' 8 | 9 | return function(args) 10 | args.n_units = {32, 64, 64} 11 | args.filter_size = {8, 4, 3} 12 | args.filter_stride = {4, 2, 1} 13 | args.n_hid = {512} 14 | args.nl = nn.Rectifier 15 | 16 | return create_network(args) 17 | end 18 | 19 | -------------------------------------------------------------------------------- /graying_the_box/LUA/dqn/initenv.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | dqn = {} 7 | 8 | require 'torch' 9 | require 'nn' 10 | require 'nngraph' 11 | require 'nnutils' 12 | require 'image' 13 | require 'Scale' 14 | require 'NeuralQLearner' 15 | require 'TransitionTable' 16 | require 'Rectifier' 17 | 18 | 19 | function torchSetup(_opt) 20 | _opt = _opt or {} 21 | local opt = table.copy(_opt) 22 | assert(opt) 23 | 24 | -- preprocess options: 25 | --- convert options strings to tables 26 | if opt.pool_frms then 27 | opt.pool_frms = str_to_table(opt.pool_frms) 28 | end 29 | if opt.env_params then 30 | opt.env_params = str_to_table(opt.env_params) 31 | end 32 | if opt.agent_params then 33 | opt.agent_params = str_to_table(opt.agent_params) 34 | opt.agent_params.gpu = opt.gpu 35 | opt.agent_params.best = opt.best 36 | opt.agent_params.verbose = opt.verbose 37 | if opt.network ~= '' then 38 | opt.agent_params.network = opt.network 39 | end 40 | end 41 | 42 | --- general setup 43 | opt.tensorType = opt.tensorType or 'torch.FloatTensor' 44 | torch.setdefaulttensortype(opt.tensorType) 45 | if not opt.threads then 46 | opt.threads = 4 47 | end 48 | torch.setnumthreads(opt.threads) 49 | if not opt.verbose then 50 | opt.verbose = 10 51 | end 52 | if opt.verbose >= 1 then 53 | print('Torch Threads:', torch.getnumthreads()) 54 | end 55 | 56 | --- set gpu device 57 | if opt.gpu and opt.gpu >= 0 then 58 | require 'cutorch' 59 | require 'cunn' 60 | if opt.gpu == 0 then 61 | local gpu_id = tonumber(os.getenv('GPU_ID')) 62 | if gpu_id then opt.gpu = gpu_id+1 end 63 | end 64 | if opt.gpu > 0 then cutorch.setDevice(opt.gpu) end 65 | opt.gpu = cutorch.getDevice() 66 | print('Using GPU device id:', opt.gpu-1) 67 | else 68 | opt.gpu = -1 69 | if opt.verbose >= 1 then 70 | print('Using CPU code only. GPU device id:', opt.gpu) 71 | end 72 | end 73 | 74 | --- set up random number generators 75 | -- removing lua RNG; seeding torch RNG with opt.seed and setting cutorch 76 | -- RNG seed to the first uniform random int32 from the previous RNG; 77 | -- this is preferred because using the same seed for both generators 78 | -- may introduce correlations; we assume that both torch RNGs ensure 79 | -- adequate dispersion for different seeds. 
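-- Note that drawing firstRandInt below also advances the torch RNG, so the CPU
-- and GPU streams start from different (but, given opt.seed, still reproducible)
-- states. A minimal sketch of the same pattern, assuming cutorch is loaded:
--
--   torch.manualSeed(seed)            -- seed the CPU generator
--   local gpu_seed = torch.random()   -- derive an independent int32 seed from it
--   cutorch.manualSeed(gpu_seed)      -- seed the GPU generator with the derived value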
80 | math.random = nil 81 | opt.seed = opt.seed or 1 82 | torch.manualSeed(opt.seed) 83 | if opt.verbose >= 1 then 84 | print('Torch Seed:', torch.initialSeed()) 85 | end 86 | local firstRandInt = torch.random() 87 | if opt.gpu >= 0 then 88 | cutorch.manualSeed(firstRandInt) 89 | if opt.verbose >= 1 then 90 | print('CUTorch Seed:', cutorch.initialSeed()) 91 | end 92 | end 93 | 94 | return opt 95 | end 96 | 97 | 98 | function setup(_opt) 99 | assert(_opt) 100 | 101 | --preprocess options: 102 | --- convert options strings to tables 103 | _opt.pool_frms = str_to_table(_opt.pool_frms) 104 | _opt.env_params = str_to_table(_opt.env_params) 105 | _opt.agent_params = str_to_table(_opt.agent_params) 106 | if _opt.agent_params.transition_params then 107 | _opt.agent_params.transition_params = 108 | str_to_table(_opt.agent_params.transition_params) 109 | end 110 | 111 | --- first things first 112 | local opt = torchSetup(_opt) 113 | 114 | -- load training framework and environment 115 | local framework = require(opt.framework) 116 | assert(framework) 117 | 118 | local gameEnv = framework.GameEnvironment(opt) 119 | local gameActions = gameEnv:getActions() 120 | 121 | -- agent options 122 | _opt.agent_params.actions = gameActions 123 | _opt.agent_params.gpu = _opt.gpu 124 | _opt.agent_params.best = _opt.best 125 | if _opt.network ~= '' then 126 | _opt.agent_params.network = _opt.network 127 | end 128 | _opt.agent_params.verbose = _opt.verbose 129 | if not _opt.agent_params.state_dim then 130 | _opt.agent_params.state_dim = gameEnv:nObsFeature() 131 | end 132 | 133 | local agent = dqn[_opt.agent](_opt.agent_params) 134 | 135 | if opt.verbose >= 1 then 136 | print('Set up Torch using these options:') 137 | for k, v in pairs(opt) do 138 | print(k, v) 139 | end 140 | end 141 | 142 | return gameEnv, gameActions, agent, opt 143 | end 144 | 145 | 146 | 147 | --- other functions 148 | 149 | function str_to_table(str) 150 | if type(str) == 'table' then 151 | return str 152 | end 153 | if not str or type(str) ~= 'string' then 154 | if type(str) == 'table' then 155 | return str 156 | end 157 | return {} 158 | end 159 | local ttr 160 | if str ~= '' then 161 | local ttx=tt 162 | loadstring('tt = {' .. str .. '}')() 163 | ttr = tt 164 | tt = ttx 165 | else 166 | ttr = {} 167 | end 168 | return ttr 169 | end 170 | 171 | function table.copy(t) 172 | if t == nil then return nil end 173 | local nt = {} 174 | for k, v in pairs(t) do 175 | if type(v) == 'table' then 176 | nt[k] = table.copy(v) 177 | else 178 | nt[k] = v 179 | end 180 | end 181 | setmetatable(nt, table.copy(getmetatable(t))) 182 | return nt 183 | end 184 | -------------------------------------------------------------------------------- /graying_the_box/LUA/dqn/net_downsample_2x_full_y.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | require "image" 8 | require "Scale" 9 | 10 | local function create_network(args) 11 | -- Y (luminance) 12 | return nn.Scale(84, 84, true) 13 | end 14 | 15 | return create_network 16 | -------------------------------------------------------------------------------- /graying_the_box/LUA/dqn/nnutils.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require "torch" 8 | 9 | function recursive_map(module, field, func) 10 | local str = "" 11 | if module[field] or module.modules then 12 | str = str .. torch.typename(module) .. ": " 13 | end 14 | if module[field] then 15 | str = str .. func(module[field]) 16 | end 17 | if module.modules then 18 | str = str .. "[" 19 | for i, submodule in ipairs(module.modules) do 20 | local submodule_str = recursive_map(submodule, field, func) 21 | str = str .. submodule_str 22 | if i < #module.modules and string.len(submodule_str) > 0 then 23 | str = str .. " " 24 | end 25 | end 26 | str = str .. "]" 27 | end 28 | 29 | return str 30 | end 31 | 32 | function abs_mean(w) 33 | return torch.mean(torch.abs(w:clone():float())) 34 | end 35 | 36 | function abs_max(w) 37 | return torch.abs(w:clone():float()):max() 38 | end 39 | 40 | -- Build a string of average absolute weight values for the modules in the 41 | -- given network. 42 | function get_weight_norms(module) 43 | return "Weight norms:\n" .. recursive_map(module, "weight", abs_mean) .. 44 | "\nWeight max:\n" .. recursive_map(module, "weight", abs_max) 45 | end 46 | 47 | -- Build a string of average absolute weight gradient values for the modules 48 | -- in the given network. 49 | function get_grad_norms(module) 50 | return "Weight grad norms:\n" .. 51 | recursive_map(module, "gradWeight", abs_mean) .. 52 | "\nWeight grad max:\n" .. recursive_map(module, "gradWeight", abs_max) 53 | end 54 | -------------------------------------------------------------------------------- /graying_the_box/LUA/dqn/test_agent.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | See LICENSE file for full terms of limited license. 4 | ]] 5 | 6 | gd = require "gd" 7 | 8 | if not dqn then 9 | require "initenv" 10 | end 11 | 12 | local cmd = torch.CmdLine() 13 | cmd:text() 14 | cmd:text('Train Agent in Environment:') 15 | cmd:text() 16 | cmd:text('Options:') 17 | 18 | cmd:option('-framework', '', 'name of training framework') 19 | cmd:option('-env', '', 'name of environment to use') 20 | cmd:option('-game_path', '', 'path to environment file (ROM)') 21 | cmd:option('-env_params', '', 'string of environment parameters') 22 | cmd:option('-pool_frms', '', 23 | 'string of frame pooling parameters (e.g.: size=2,type="max")') 24 | cmd:option('-actrep', 1, 'how many times to repeat action') 25 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' .. 26 | 'number of times at the start of each training episode') 27 | 28 | cmd:option('-name', '', 'filename used for saving network and training history') 29 | cmd:option('-network', '', 'reload pretrained network') 30 | cmd:option('-agent', '', 'name of agent file to use') 31 | cmd:option('-agent_params', '', 'string of agent parameters') 32 | cmd:option('-seed', 1, 'fixed input seed for repeatable experiments') 33 | 34 | cmd:option('-verbose', 2, 35 | 'the higher the level, the more information is printed to screen') 36 | cmd:option('-threads', 1, 'number of BLAS threads') 37 | cmd:option('-gpu', -1, 'gpu flag') 38 | cmd:option('-gif_file', '', 'GIF path to write session screens') 39 | cmd:option('-csv_file', '', 'CSV path to write session data') 40 | 41 | cmd:text() 42 | 43 | local opt = cmd:parse(arg) 44 | 45 | --- General setup. 46 | local game_env, game_actions, agent, opt = setup(opt) 47 | 48 | -- override print to always flush the output 49 | local old_print = print 50 | local print = function(...) 51 | old_print(...) 
52 | io.flush() 53 | end 54 | 55 | -- file names from command line 56 | local gif_filename = opt.gif_file 57 | 58 | -- start a new game 59 | local screen, reward, terminal = game_env:newGame() 60 | 61 | -- compress screen to JPEG with 100% quality 62 | local jpg = image.compressJPG(screen:squeeze(), 100) 63 | -- create gd image from JPEG string 64 | local im = gd.createFromJpegStr(jpg:storage():string()) 65 | -- convert truecolor to palette 66 | im:trueColorToPalette(false, 256) 67 | 68 | -- write GIF header, use global palette and infinite looping 69 | im:gifAnimBegin(gif_filename, true, 0) 70 | -- write first frame 71 | im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE) 72 | 73 | -- remember the image and show it first 74 | local previm = im 75 | local win = image.display({image=screen}) 76 | 77 | print("Started playing...") 78 | 79 | -- play one episode (game) 80 | while not terminal do 81 | -- if action was chosen randomly, Q-value is 0 82 | agent.bestq = 0 83 | 84 | -- choose the best action 85 | local action_index = agent:perceive(reward, screen, terminal, true, 0.05) 86 | 87 | -- play game in test mode (episodes don't end when losing a life) 88 | screen, reward, terminal = game_env:step(game_actions[action_index], false) 89 | 90 | -- display screen 91 | image.display({image=screen, win=win}) 92 | 93 | -- create gd image from tensor 94 | jpg = image.compressJPG(screen:squeeze(), 100) 95 | im = gd.createFromJpegStr(jpg:storage():string()) 96 | 97 | -- use palette from previous (first) image 98 | im:trueColorToPalette(false, 256) 99 | im:paletteCopy(previm) 100 | 101 | -- write new GIF frame, no local palette, starting from left-top, 7ms delay 102 | im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE) 103 | -- remember previous screen for optimal compression 104 | previm = im 105 | 106 | end 107 | 108 | -- end GIF animation and close CSV file 109 | gd.gifAnimEnd(gif_filename) 110 | 111 | print("Finished playing, close window to exit!") -------------------------------------------------------------------------------- /graying_the_box/LUA/dqn/train_agent.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 5 | ]] 6 | 7 | if not dqn then 8 | require "initenv" 9 | end 10 | 11 | local cmd = torch.CmdLine() 12 | cmd:text() 13 | cmd:text('Train Agent in Environment:') 14 | cmd:text() 15 | cmd:text('Options:') 16 | 17 | cmd:option('-framework', '', 'name of training framework') 18 | cmd:option('-env', '', 'name of environment to use') 19 | cmd:option('-game_path', '', 'path to environment file (ROM)') 20 | cmd:option('-env_params', '', 'string of environment parameters') 21 | cmd:option('-pool_frms', '', 22 | 'string of frame pooling parameters (e.g.: size=2,type="max")') 23 | cmd:option('-actrep', 1, 'how many times to repeat action') 24 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' .. 
25 | 'number of times at the start of each training episode') 26 | 27 | cmd:option('-name', '', 'filename used for saving network and training history') 28 | cmd:option('-network', '', 'reload pretrained network') 29 | cmd:option('-agent', '', 'name of agent file to use') 30 | cmd:option('-agent_params', '', 'string of agent parameters') 31 | cmd:option('-seed', 1, 'fixed input seed for repeatable experiments') 32 | cmd:option('-saveNetworkParams', true, 33 | 'saves the agent network in a separate file') -- tom 34 | cmd:option('-prog_freq', 50000, 'frequency of progress output') 35 | cmd:option('-save_freq', 1000000, 'the model is saved every save_freq steps') 36 | cmd:option('-eval_freq', 10^4, 'frequency of greedy evaluation') 37 | cmd:option('-save_versions', 5, '') --tom 38 | 39 | cmd:option('-steps', 10^5, 'number of training steps to perform') 40 | cmd:option('-eval_steps', 10^5, 'number of evaluation steps') 41 | 42 | cmd:option('-verbose', 2, 43 | 'the higher the level, the more information is printed to screen') 44 | cmd:option('-threads', 1, 'number of BLAS threads') 45 | cmd:option('-gpu', -1, 'gpu flag') 46 | 47 | cmd:text() 48 | 49 | local opt = cmd:parse(arg) 50 | 51 | --- General setup. 52 | local game_env, game_actions, agent, opt = setup(opt) 53 | 54 | -- override print to always flush the output 55 | local old_print = print 56 | local print = function(...) 57 | old_print(...) 58 | io.flush() 59 | end 60 | 61 | local learn_start = agent.learn_start 62 | local start_time = sys.clock() 63 | local reward_counts = {} 64 | local episode_counts = {} 65 | local time_history = {} 66 | local v_history = {} 67 | local qmax_history = {} 68 | local td_history = {} 69 | local reward_history = {} 70 | local step = 0 71 | time_history[1] = 0 72 | 73 | local total_reward 74 | local nrewards 75 | local nepisodes 76 | local episode_reward 77 | 78 | local screen, reward, terminal = game_env:getState() 79 | 80 | print("Iteration ..", step) 81 | local win = nil 82 | win1 = nil 83 | win2 = nil 84 | win3 = nil 85 | 86 | -- Tom 87 | local testing_eps = 0.05 88 | local in_episode_time = 0 89 | local episode_index = 1 90 | local episode_reward = 0 91 | local life_count = 5 92 | -- Tom End 93 | while step < opt.steps do 94 | step = step + 1 95 | 96 | -- Tom 97 | --local action_index = agent:perceive(reward, screen, terminal) more arguments 98 | in_episode_time = in_episode_time + 1 99 | local action_index = agent:perceive(reward, screen, terminal) 100 | 101 | -- Tom End 102 | 103 | -- game over? get next game! 104 | if not terminal then 105 | screen, reward, terminal = game_env:step(game_actions[action_index], true) 106 | episode_reward = episode_reward + reward 107 | else 108 | -- Tom 109 | --[[agent:UpdateEpisodeData(episode_index,in_episode_time) 110 | in_episode_time = 0 111 | life_count = life_count - 1 112 | if life_count == 0 then 113 | --print("Episode Reward: ", episode_reward) 114 | agent:UpdateEpisodeReward(episode_index,episode_reward) 115 | life_count = 5 116 | episode_reward = 0 117 | end 118 | episode_index = episode_index +1 119 | -- Tom End 120 | --]] 121 | if opt.random_starts > 0 then 122 | screen, reward, terminal = game_env:nextRandomGame() 123 | else 124 | screen, reward, terminal = game_env:newGame() 125 | end 126 | end 127 | 128 | -- display screen 129 | win = image.display({image=screen, win=win}) 130 | 131 | if step % opt.prog_freq == 0 then 132 | assert(step==agent.numSteps, 'trainer step: ' .. step .. 133 | ' & agent.numSteps: ' .. 
agent.numSteps) 134 | print("Steps: ", step) 135 | agent:report() 136 | print(agent.transitions.ER_Cluster_counter) 137 | collectgarbage() 138 | end 139 | 140 | if step%1000 == 0 then collectgarbage() end 141 | 142 | if step % opt.eval_freq == 0 and step > learn_start then 143 | 144 | screen, reward, terminal = game_env:newGame() 145 | 146 | total_reward = 0 147 | nrewards = 0 148 | nepisodes = 0 149 | episode_reward = 0 150 | 151 | local eval_time = sys.clock() 152 | for estep=1,opt.eval_steps do 153 | local action_index = agent:perceive(reward, screen, terminal, true, testing_eps) 154 | 155 | -- Play game in test mode (episodes don't end when losing a life) 156 | screen, reward, terminal = game_env:step(game_actions[action_index]) 157 | 158 | -- display screen 159 | win = image.display({image=screen, win=win}) 160 | 161 | if estep%1000 == 0 then collectgarbage() end 162 | 163 | -- record every reward 164 | episode_reward = episode_reward + reward 165 | if reward ~= 0 then 166 | nrewards = nrewards + 1 167 | end 168 | 169 | if terminal then 170 | total_reward = total_reward + episode_reward 171 | episode_reward = 0 172 | nepisodes = nepisodes + 1 173 | screen, reward, terminal = game_env:nextRandomGame() 174 | end 175 | end 176 | 177 | eval_time = sys.clock() - eval_time 178 | start_time = start_time + eval_time 179 | agent:compute_validation_statistics() 180 | local ind = #reward_history+1 181 | total_reward = total_reward/math.max(1, nepisodes) 182 | 183 | if #reward_history == 0 or total_reward > torch.Tensor(reward_history):max() then 184 | agent.best_network = agent.network:clone() 185 | end 186 | 187 | if agent.v_avg then 188 | v_history[ind] = agent.v_avg 189 | td_history[ind] = agent.tderr_avg 190 | qmax_history[ind] = agent.q_max 191 | end 192 | print("V", v_history[ind], "TD error", td_history[ind], "Qmax", qmax_history[ind]) 193 | 194 | reward_history[ind] = total_reward 195 | reward_counts[ind] = nrewards 196 | episode_counts[ind] = nepisodes 197 | 198 | time_history[ind+1] = sys.clock() - start_time 199 | 200 | local time_dif = time_history[ind+1] - time_history[ind] 201 | 202 | local training_rate = opt.actrep*opt.eval_freq/time_dif 203 | 204 | print(string.format( 205 | '\nSteps: %d (frames: %d), reward: %.2f, epsilon: %.2f, lr: %G, ' .. 206 | 'training time: %ds, training rate: %dfps, testing time: %ds, ' .. 207 | 'testing rate: %dfps, num. ep.: %d, num. rewards: %d', 208 | step, step*opt.actrep, total_reward, agent.ep, agent.lr, time_dif, 209 | training_rate, eval_time, opt.actrep*opt.eval_steps/eval_time, 210 | nepisodes, nrewards)) 211 | end 212 | 213 | if step % opt.save_freq == 0 or step == opt.steps then 214 | local s, a, r, s2, term = agent.valid_s, agent.valid_a, agent.valid_r, 215 | agent.valid_s2, agent.valid_term 216 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 217 | agent.valid_term = nil, nil, nil, nil, nil, nil, nil 218 | local w, dw, g, g2, delta, delta2, deltas, tmp = agent.w, agent.dw, 219 | agent.g, agent.g2, agent.delta, agent.delta2, agent.deltas, agent.tmp 220 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2, 221 | agent.deltas, agent.tmp = nil, nil, nil, nil, nil, nil, nil, nil 222 | 223 | local filename = opt.name 224 | if opt.save_versions > 0 then 225 | filename = filename .. "_" .. math.floor(step / opt.save_versions) 226 | end 227 | filename = filename 228 | torch.save(filename .. 
".t7", {agent = agent, 229 | model = agent.network, 230 | best_model = agent.best_network, 231 | reward_history = reward_history, 232 | reward_counts = reward_counts, 233 | episode_counts = episode_counts, 234 | time_history = time_history, 235 | v_history = v_history, 236 | td_history = td_history, 237 | qmax_history = qmax_history, 238 | arguments=opt}) 239 | if opt.saveNetworkParams then 240 | local nets = {network=w:clone():float()} 241 | torch.save(filename..'.params.t7', nets, 'ascii') 242 | end 243 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 244 | agent.valid_term = s, a, r, s2, term 245 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2, 246 | agent.deltas, agent.tmp = w, dw, g, g2, delta, delta2, deltas, tmp 247 | print('Saved:', filename .. '.t7') 248 | io.flush() 249 | collectgarbage() 250 | end 251 | end 252 | -------------------------------------------------------------------------------- /graying_the_box/LUA/dqn/train_agent_tmp.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | See LICENSE file for full terms of limited license. 4 | ]] 5 | 6 | if not dqn then 7 | require "initenv" 8 | end 9 | 10 | local cmd = torch.CmdLine() 11 | cmd:text() 12 | cmd:text('Train Agent in Environment:') 13 | cmd:text() 14 | cmd:text('Options:') 15 | 16 | cmd:option('-framework', '', 'name of training framework') 17 | cmd:option('-env', '', 'name of environment to use') 18 | cmd:option('-game_path', '', 'path to environment file (ROM)') 19 | cmd:option('-env_params', '', 'string of environment parameters') 20 | cmd:option('-pool_frms', '', 21 | 'string of frame pooling parameters (e.g.: size=2,type="max")') 22 | cmd:option('-actrep', 1, 'how many times to repeat action') 23 | cmd:option('-random_starts', 0, 'play action 0 between 1 and random_starts ' .. 24 | 'number of times at the start of each training episode') 25 | 26 | cmd:option('-name', '', 'filename used for saving network and training history') 27 | cmd:option('-network', '', 'reload pretrained network') 28 | cmd:option('-agent', '', 'name of agent file to use') 29 | cmd:option('-agent_params', '', 'string of agent parameters') 30 | cmd:option('-seed', 1, 'fixed input seed for repeatable experiments') 31 | cmd:option('-saveNetworkParams', false, 32 | 'saves the agent network in a separate file') 33 | cmd:option('-prog_freq', 5*10^3, 'frequency of progress output') 34 | cmd:option('-save_freq', 5*10^4, 'the model is saved every save_freq steps') 35 | cmd:option('-eval_freq', 10^4, 'frequency of greedy evaluation') 36 | cmd:option('-save_versions', 0, '') 37 | 38 | cmd:option('-steps', 10^5, 'number of training steps to perform') 39 | cmd:option('-eval_steps', 10^5, 'number of evaluation steps') 40 | 41 | cmd:option('-verbose', 2, 42 | 'the higher the level, the more information is printed to screen') 43 | cmd:option('-threads', 1, 'number of BLAS threads') 44 | cmd:option('-gpu', -1, 'gpu flag') 45 | 46 | cmd:text() 47 | 48 | local opt = cmd:parse(arg) 49 | 50 | --- General setup. 51 | local game_env, game_actions, agent, opt = setup(opt) 52 | 53 | -- override print to always flush the output 54 | local old_print = print 55 | local print = function(...) 56 | old_print(...) 
57 | io.flush() 58 | end 59 | 60 | local learn_start = agent.learn_start 61 | local start_time = sys.clock() 62 | local reward_counts = {} 63 | local episode_counts = {} 64 | local time_history = {} 65 | local v_history = {} 66 | local qmax_history = {} 67 | local td_history = {} 68 | local reward_history = {} 69 | local step = 0 70 | time_history[1] = 0 71 | 72 | local total_reward 73 | local nrewards 74 | local nepisodes 75 | local episode_reward 76 | 77 | local screen, reward, terminal = game_env:getState() 78 | 79 | print("Iteration ..", step) 80 | local win = nil 81 | while step < opt.steps do 82 | step = step + 1 83 | local action_index = agent:perceive(reward, screen, terminal) 84 | 85 | -- game over? get next game! 86 | if not terminal then 87 | screen, reward, terminal = game_env:step(game_actions[action_index], true) 88 | else 89 | if opt.random_starts > 0 then 90 | screen, reward, terminal = game_env:nextRandomGame() 91 | else 92 | screen, reward, terminal = game_env:newGame() 93 | end 94 | end 95 | 96 | -- display screen 97 | win = image.display({image=screen, win=win}) 98 | 99 | if step % opt.prog_freq == 0 then 100 | assert(step==agent.numSteps, 'trainer step: ' .. step .. 101 | ' & agent.numSteps: ' .. agent.numSteps) 102 | print("Steps: ", step) 103 | agent:report() 104 | collectgarbage() 105 | end 106 | 107 | if step%1000 == 0 then collectgarbage() end 108 | 109 | if step % opt.eval_freq == 0 and step > learn_start then 110 | 111 | screen, reward, terminal = game_env:newGame() 112 | 113 | total_reward = 0 114 | nrewards = 0 115 | nepisodes = 0 116 | episode_reward = 0 117 | 118 | local eval_time = sys.clock() 119 | for estep=1,opt.eval_steps do 120 | local action_index = agent:perceive(reward, screen, terminal, true, 0.05) 121 | 122 | -- Play game in test mode (episodes don't end when losing a life) 123 | screen, reward, terminal = game_env:step(game_actions[action_index]) 124 | 125 | -- display screen 126 | win = image.display({image=screen, win=win}) 127 | 128 | if estep%1000 == 0 then collectgarbage() end 129 | 130 | -- record every reward 131 | episode_reward = episode_reward + reward 132 | if reward ~= 0 then 133 | nrewards = nrewards + 1 134 | end 135 | 136 | if terminal then 137 | total_reward = total_reward + episode_reward 138 | episode_reward = 0 139 | nepisodes = nepisodes + 1 140 | screen, reward, terminal = game_env:nextRandomGame() 141 | end 142 | end 143 | 144 | eval_time = sys.clock() - eval_time 145 | start_time = start_time + eval_time 146 | agent:compute_validation_statistics() 147 | local ind = #reward_history+1 148 | total_reward = total_reward/math.max(1, nepisodes) 149 | 150 | if #reward_history == 0 or total_reward > torch.Tensor(reward_history):max() then 151 | agent.best_network = agent.network:clone() 152 | end 153 | 154 | if agent.v_avg then 155 | v_history[ind] = agent.v_avg 156 | td_history[ind] = agent.tderr_avg 157 | qmax_history[ind] = agent.q_max 158 | end 159 | print("V", v_history[ind], "TD error", td_history[ind], "Qmax", qmax_history[ind]) 160 | 161 | reward_history[ind] = total_reward 162 | reward_counts[ind] = nrewards 163 | episode_counts[ind] = nepisodes 164 | 165 | time_history[ind+1] = sys.clock() - start_time 166 | 167 | local time_dif = time_history[ind+1] - time_history[ind] 168 | 169 | local training_rate = opt.actrep*opt.eval_freq/time_dif 170 | 171 | print(string.format( 172 | '\nSteps: %d (frames: %d), reward: %.2f, epsilon: %.2f, lr: %G, ' .. 173 | 'training time: %ds, training rate: %dfps, testing time: %ds, ' .. 
174 | 'testing rate: %dfps, num. ep.: %d, num. rewards: %d', 175 | step, step*opt.actrep, total_reward, agent.ep, agent.lr, time_dif, 176 | training_rate, eval_time, opt.actrep*opt.eval_steps/eval_time, 177 | nepisodes, nrewards)) 178 | end 179 | 180 | if step % opt.save_freq == 0 or step == opt.steps then 181 | local s, a, r, s2, term = agent.valid_s, agent.valid_a, agent.valid_r, 182 | agent.valid_s2, agent.valid_term 183 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 184 | agent.valid_term = nil, nil, nil, nil, nil, nil, nil 185 | local w, dw, g, g2, delta, delta2, deltas, tmp = agent.w, agent.dw, 186 | agent.g, agent.g2, agent.delta, agent.delta2, agent.deltas, agent.tmp 187 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2, 188 | agent.deltas, agent.tmp = nil, nil, nil, nil, nil, nil, nil, nil 189 | 190 | local filename = opt.name 191 | if opt.save_versions > 0 then 192 | filename = filename .. "_" .. math.floor(step / opt.save_versions) 193 | end 194 | filename = filename 195 | torch.save(filename .. ".t7", {agent = agent, 196 | model = agent.network, 197 | best_model = agent.best_network, 198 | reward_history = reward_history, 199 | reward_counts = reward_counts, 200 | episode_counts = episode_counts, 201 | time_history = time_history, 202 | v_history = v_history, 203 | td_history = td_history, 204 | qmax_history = qmax_history, 205 | arguments=opt}) 206 | if opt.saveNetworkParams then 207 | local nets = {network=w:clone():float()} 208 | torch.save(filename..'.params.t7', nets, 'ascii') 209 | end 210 | agent.valid_s, agent.valid_a, agent.valid_r, agent.valid_s2, 211 | agent.valid_term = s, a, r, s2, term 212 | agent.w, agent.dw, agent.g, agent.g2, agent.delta, agent.delta2, 213 | agent.deltas, agent.tmp = w, dw, g, g2, delta, delta2, deltas, tmp 214 | print('Saved:', filename .. 
'.t7') 215 | io.flush() 216 | collectgarbage() 217 | end 218 | end 219 | -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/Plot.lua: -------------------------------------------------------------------------------- 1 | require 'gnuplot' 2 | 3 | n = 0 4 | local file1 = io.open("./Results/reward.txt") 5 | if file1 then 6 | for line in file1:lines() do 7 | n = n+1 8 | end 9 | end 10 | file1:close() 11 | 12 | m = 0 13 | local file1 = io.open("./Results/conv1.txt") 14 | if file1 then 15 | for line in file1:lines() do 16 | m = m+1 17 | end 18 | end 19 | file1:close() 20 | m = m-1 21 | local file1 = io.open("./Results/reward.txt") 22 | local file2 = io.open("./Results/TD.txt") 23 | local file3 = io.open("./Results/vavg.txt") 24 | local file4 = io.open("./Results/conv1.txt") 25 | local file5 = io.open("./Results/conv2.txt") 26 | local file6 = io.open("./Results/conv3.txt") 27 | local file7 = io.open("./Results/lin1.txt") 28 | local file8 = io.open("./Results/lin2.txt") 29 | 30 | i = 1 31 | conv1_norm = torch.Tensor(m/4-1) 32 | conv1_norm_max = torch.Tensor(m/4-1) 33 | conv1_grad = torch.Tensor(m/4-1) 34 | conv1_grad_max = torch.Tensor(m/4-1) 35 | 36 | if file4 then 37 | for line in file4:lines() do 38 | if i == 1 then 39 | print(i) 40 | elseif i == m-4 then 41 | break 42 | elseif i % 4 == 2 then 43 | conv1_norm[math.floor(i/4)+1] = tonumber(line) 44 | elseif i % 4 == 3 then 45 | conv1_norm_max[math.floor(i/4)+1] = tonumber(line) 46 | elseif i % 4 == 0 then 47 | conv1_grad[math.floor(i/4)+1] = tonumber(line) 48 | elseif i % 4 == 1 then 49 | conv1_grad_max[math.floor(i/4)+1] = tonumber(line) 50 | end 51 | i = i+1 52 | end 53 | end 54 | 55 | gnuplot.pngfigure('./Results/tmp/conv1.png') 56 | gnuplot.title('conv 1 over training') 57 | gnuplot.plot({'Norm',conv1_norm},{'Grad',conv1_grad}) 58 | gnuplot.xlabel('Training epochs') 59 | gnuplot.movelegend('left','top') 60 | gnuplot.plotflush() 61 | 62 | 63 | gnuplot.pngfigure('./Results/tmp/conv1_max.png') 64 | gnuplot.title('conv 1 max over training') 65 | gnuplot.plot({'Max norm',conv1_norm_max},{'Max grad',conv1_grad_max}) 66 | gnuplot.xlabel('Training epochs') 67 | gnuplot.movelegend('left','top') 68 | gnuplot.plotflush() 69 | 70 | 71 | i = 1 72 | conv2_norm = torch.Tensor(m/4-1) 73 | conv2_norm_max = torch.Tensor(m/4-1) 74 | conv2_grad = torch.Tensor(m/4-1) 75 | conv2_grad_max = torch.Tensor(m/4-1) 76 | 77 | if file5 then 78 | for line in file5:lines() do 79 | if i == 1 then 80 | print(i) 81 | elseif i == m-4 then 82 | break 83 | elseif i % 4 == 2 then 84 | conv2_norm[math.floor(i/4)+1] = tonumber(line) 85 | elseif i % 4 == 3 then 86 | conv2_norm_max[math.floor(i/4)+1] = tonumber(line) 87 | elseif i % 4 == 0 then 88 | conv2_grad[math.floor(i/4)+1] = tonumber(line) 89 | elseif i % 4 == 1 then 90 | conv2_grad_max[math.floor(i/4)+1] = tonumber(line) 91 | end 92 | i = i+1 93 | end 94 | end 95 | 96 | gnuplot.pngfigure('./Results/tmp/conv2.png') 97 | gnuplot.title('conv 2 over training') 98 | gnuplot.plot({'Norm',conv2_norm},{'Grad',conv2_grad}) 99 | gnuplot.xlabel('Training epochs') 100 | gnuplot.movelegend('left','top') 101 | gnuplot.plotflush() 102 | 103 | 104 | gnuplot.pngfigure('./Results/tmp/conv2_max.png') 105 | gnuplot.title('conv 2 max over training') 106 | gnuplot.plot({'Max norm',conv2_norm_max},{'Max grad',conv2_grad_max}) 107 | gnuplot.xlabel('Training epochs') 108 | gnuplot.movelegend('left','top') 109 | gnuplot.plotflush() 110 | 111 | i = 1 112 | conv3_norm = torch.Tensor(m/4-1) 
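-- As in the conv1 and conv2 blocks above, the four conv3_* tensors collect the
-- weight norm, max weight norm, gradient norm and max gradient norm series for
-- the third convolutional layer, parsed from consecutive groups of four lines in
-- conv3.txt (the i % 4 branches in the loop below pick the group members apart).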
113 | conv3_norm_max = torch.Tensor(m/4-1) 114 | conv3_grad = torch.Tensor(m/4-1) 115 | conv3_grad_max = torch.Tensor(m/4-1) 116 | 117 | if file6 then 118 | for line in file6:lines() do 119 | if i == 1 then 120 | print(i) 121 | elseif i == m-4 then 122 | break 123 | elseif i % 4 == 2 then 124 | conv3_norm[math.floor(i/4)+1] = tonumber(line) 125 | elseif i % 4 == 3 then 126 | conv3_norm_max[math.floor(i/4)+1] = tonumber(line) 127 | elseif i % 4 == 0 then 128 | conv3_grad[math.floor(i/4)+1] = tonumber(line) 129 | elseif i % 4 == 1 then 130 | conv3_grad_max[math.floor(i/4)+1] = tonumber(line) 131 | end 132 | i = i+1 133 | end 134 | end 135 | 136 | gnuplot.pngfigure('./Results/tmp/conv3.png') 137 | gnuplot.title('conv 3 over training') 138 | gnuplot.plot({'Norm',conv3_norm},{'Grad',conv3_grad}) 139 | gnuplot.xlabel('Training epochs') 140 | gnuplot.movelegend('left','top') 141 | gnuplot.plotflush() 142 | 143 | 144 | gnuplot.pngfigure('./Results/tmp/conv3_max.png') 145 | gnuplot.title('conv 3 max over training') 146 | gnuplot.plot({'Max norm',conv3_norm_max},{'Max grad',conv3_grad_max}) 147 | gnuplot.xlabel('Training epochs') 148 | gnuplot.movelegend('left','top') 149 | gnuplot.plotflush() 150 | 151 | i = 1 152 | lin1_norm = torch.Tensor(m/4-1) 153 | lin1_norm_max = torch.Tensor(m/4-1) 154 | lin1_grad = torch.Tensor(m/4-1) 155 | lin1_grad_max = torch.Tensor(m/4-1) 156 | 157 | if file7 then 158 | for line in file7:lines() do 159 | if i == 1 then 160 | print(i) 161 | elseif i == m-4 then 162 | break 163 | elseif i % 4 == 2 then 164 | lin1_norm[math.floor(i/4)+1] = tonumber(line) 165 | elseif i % 4 == 3 then 166 | lin1_norm_max[math.floor(i/4)+1] = tonumber(line) 167 | elseif i % 4 == 0 then 168 | lin1_grad[math.floor(i/4)+1] = tonumber(line) 169 | elseif i % 4 == 1 then 170 | lin1_grad_max[math.floor(i/4)+1] = tonumber(line) 171 | end 172 | i = i+1 173 | end 174 | end 175 | 176 | gnuplot.pngfigure('./Results/tmp/lin1.png') 177 | gnuplot.title('lin1 over training') 178 | gnuplot.plot({'Norm',lin1_norm},{'Grad',lin1_grad}) 179 | gnuplot.xlabel('Training epochs') 180 | gnuplot.movelegend('left','top') 181 | gnuplot.plotflush() 182 | 183 | 184 | gnuplot.pngfigure('./Results/tmp/lin1_max.png') 185 | gnuplot.title('lin1 max over training') 186 | gnuplot.plot({'Max norm',lin1_norm_max},{'Max grad',lin1_grad_max}) 187 | gnuplot.xlabel('Training epochs') 188 | gnuplot.movelegend('left','top') 189 | gnuplot.plotflush() 190 | 191 | i = 1 192 | lin2_norm = torch.Tensor(m/4-1) 193 | lin2_norm_max = torch.Tensor(m/4-1) 194 | lin2_grad = torch.Tensor(m/4-1) 195 | lin2_grad_max = torch.Tensor(m/4-1) 196 | 197 | if file8 then 198 | for line in file8:lines() do 199 | if i == 1 then 200 | print(i) 201 | elseif i == m-4 then 202 | break 203 | elseif i % 4 == 2 then 204 | lin2_norm[math.floor(i/4)+1] = tonumber(line) 205 | elseif i % 4 == 3 then 206 | lin2_norm_max[math.floor(i/4)+1] = tonumber(line) 207 | elseif i % 4 == 0 then 208 | lin2_grad[math.floor(i/4)+1] = tonumber(line) 209 | elseif i % 4 == 1 then 210 | lin2_grad_max[math.floor(i/4)+1] = tonumber(line) 211 | end 212 | i = i+1 213 | end 214 | end 215 | 216 | gnuplot.pngfigure('./Results/tmp/lin2.png') 217 | gnuplot.title('lin2 over training') 218 | gnuplot.plot({'Norm',lin2_norm},{'Grad',lin2_grad}) 219 | gnuplot.xlabel('Training epochs') 220 | gnuplot.movelegend('left','top') 221 | gnuplot.plotflush() 222 | 223 | 224 | gnuplot.pngfigure('./Results/tmp/lin2_max.png') 225 | gnuplot.title('lin2 max over training') 226 | gnuplot.plot({'Max 
norm',lin2_norm_max},{'Max grad',lin2_grad_max}) 227 | gnuplot.xlabel('Training epochs') 228 | gnuplot.movelegend('left','top') 229 | gnuplot.plotflush() 230 | 231 | 232 | i = 1 233 | x = torch.Tensor(n) 234 | if file1 then 235 | for line in file1:lines() do 236 | x[i] = tonumber(line) 237 | i = i+1 238 | end 239 | end 240 | 241 | i = 1 242 | y = torch.Tensor(n) 243 | if file2 then 244 | for line in file2:lines() do 245 | y[i] = tonumber(line) 246 | i = i+1 247 | end 248 | end 249 | 250 | i = 1 251 | z = torch.Tensor(n) 252 | if file3 then 253 | for line in file3:lines() do 254 | z[i] = tonumber(line) 255 | z[i] = z[i]/y[i] 256 | i = i+1 257 | end 258 | end 259 | 260 | gnuplot.pngfigure('./Results/tmp/reward.png') 261 | gnuplot.title('reward over testing') 262 | gnuplot.plot(x) 263 | gnuplot.plotflush() 264 | 265 | gnuplot.pngfigure('./Results/tmp/vavg.png') 266 | gnuplot.title('vavg over testing') 267 | gnuplot.plot(y) 268 | gnuplot.plotflush() 269 | 270 | gnuplot.pngfigure('./Results/tmp/TD_Error.png') 271 | gnuplot.title('Normalized TD error over testing') 272 | gnuplot.plot(z) 273 | gnuplot.plotflush() 274 | 275 | 276 | 277 | -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/Plot2.lua: -------------------------------------------------------------------------------- 1 | require 'gnuplot' 2 | 3 | n = 0 4 | local file1 = io.open("./Results/reward2.txt") 5 | if file1 then 6 | for line in file1:lines() do 7 | n = n+1 8 | end 9 | end 10 | file1:close() 11 | 12 | local file1 = io.open("./Results/reward1.txt") 13 | local file2 = io.open("./Results/reward2.txt") 14 | 15 | i = 1 16 | 17 | 18 | x = torch.Tensor(n) 19 | if file1 then 20 | for line in file1:lines() do 21 | x[i] = tonumber(line) 22 | i = i+1 23 | end 24 | end 25 | 26 | 27 | x_ax = torch.range(1,i-1) 28 | 29 | 30 | i = 1 31 | y = torch.Tensor(n) 32 | if file2 then 33 | for line in file2:lines() do 34 | y[i] = tonumber(line) 35 | i = i+1 36 | end 37 | end 38 | y_ax = torch.range(1,i-1) 39 | 40 | 41 | if x_ax:numel()>y_ax:numel() then 42 | n = y_ax:numel() 43 | else 44 | n = x_ax:numel() 45 | end 46 | 47 | data1 = torch.Tensor(n) 48 | data2 = torch.Tensor(n) 49 | for i = 1,n do 50 | data1[i] = x[i] 51 | data2[i] = y[i] 52 | end 53 | 54 | 55 | gnuplot.pngfigure('reward.png') 56 | gnuplot.title('Average Reward over training') 57 | gnuplot.plot({'Policy Noise',data1},{'Q noise',data2}) 58 | gnuplot.xlabel('Training epochs') 59 | gnuplot.ylabel('Average testing reward') 60 | gnuplot.movelegend('left','top') 61 | gnuplot.plotflush() 62 | 63 | 64 | -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/Plot3.lua: -------------------------------------------------------------------------------- 1 | require 'gnuplot' 2 | 3 | n = 0 4 | local file1 = io.open("./Results/reward1.txt") 5 | if file1 then 6 | for line in file1:lines() do 7 | n = n+1 8 | end 9 | end 10 | file1:close() 11 | 12 | local file1 = io.open("./Results/reward1.txt") 13 | local file2 = io.open("./Results/reward2.txt") 14 | local file3 = io.open("./Results/reward3.txt") 15 | i = 1 16 | 17 | 18 | x = torch.Tensor(n) 19 | if file1 then 20 | for line in file1:lines() do 21 | x[i] = tonumber(line) 22 | i = i+1 23 | end 24 | end 25 | 26 | 27 | x_ax = torch.range(1,i-1) 28 | 29 | 30 | i = 1 31 | y = torch.Tensor(n) 32 | if file2 then 33 | for line in file2:lines() do 34 | y[i] = tonumber(line) 35 | i = i+1 36 | end 37 | end 38 | y_ax = torch.range(1,i-1) 39 | 40 
| i = 1 41 | z = torch.Tensor(n) 42 | if file3 then 43 | for line in file3:lines() do 44 | z[i] = tonumber(line) 45 | i = i+1 46 | end 47 | end 48 | z_ax = torch.range(1,i-1) 49 | 50 | 51 | if x_ax:numel()>y_ax:numel() then 52 | n = y_ax:numel() 53 | else 54 | n = x_ax:numel() 55 | end 56 | if z_ax:numel() ./Results/vavg.txt 4 | cat $1 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward.txt 5 | cat $1 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD.txt 6 | cat $1 | grep "nn.Sequential" | cut -d" " -f 3 > ./Results/conv1.txt 7 | cat $1 | grep "nn.Sequential" | cut -d" " -f 5 > ./Results/conv2.txt 8 | cat $1 | grep "nn.Sequential" | cut -d" " -f 7 > ./Results/conv3.txt 9 | cat $1 | grep "nn.Sequential" | cut -d" " -f 9 > ./Results/lin1.txt 10 | cat $1 | grep "nn.Sequential" | cut -d" " -f 11 | cut -d"]" -f1 > ./Results/lin2.txt 11 | 12 | 13 | 14 | th ./Results/Plot.lua 15 | 16 | 17 | -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/Process2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cat $1 | grep "TD" | cut -d"V" -f 2|cut -f2 > ./Results/vavg1.txt 4 | cat $1 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward1.txt 5 | cat $1 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD1.txt 6 | 7 | cat $2 | grep "TD" | cut -d"V" -f 2|cut -f2 > ./Results/vavg2.txt 8 | cat $2 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward2.txt 9 | cat $2 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD2.txt 10 | 11 | th ./Results/Plot2.lua 12 | 13 | 14 | -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/Process3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cat $1 | grep "TD" | cut -d"V" -f 2|cut -f2 > ./Results/vavg1.txt 4 | cat $1 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward1.txt 5 | cat $1 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD1.txt 6 | 7 | cat $2 | grep "TD" | cut -d"V" -f 2|cut -f2 > ./Results/vavg2.txt 8 | cat $2 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward2.txt 9 | cat $2 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD2.txt 10 | 11 | cat $3 | grep "TD" | cut -d"V" -f 2|cut -f2 > ./Results/vavg3.txt 12 | cat $3 | grep "reward:" | cut -d "," -f 2 | cut -d " " -f 3 > ./Results/reward3.txt 13 | cat $3 | grep "TD" | cut -d"T" -f 2|cut -f2> ./Results/TD3.txt 14 | 15 | th ./Results/Plot3.lua 16 | 17 | 18 | -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/TD_Error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/TD_Error.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/conv1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv1.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/conv1_max.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv1_max.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/conv2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv2.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/conv2_max.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv2_max.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/conv3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv3.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/conv3_max.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/conv3_max.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/lin1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/lin1.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/lin1_max.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/lin1_max.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/lin2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/lin2.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/lin2_max.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/lin2_max.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/reward.png -------------------------------------------------------------------------------- /graying_the_box/LUA/logs/Results/tmp/vavg.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tesslerc/H-DRLN/87c643e193002fce3e1865a2e962351eff6cbdea/graying_the_box/LUA/logs/Results/tmp/vavg.png -------------------------------------------------------------------------------- /graying_the_box/LUA/roms/README: -------------------------------------------------------------------------------- 1 | Rom files should be put in this directory 2 | -------------------------------------------------------------------------------- /graying_the_box/LUA/run_gpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ] 4 | then echo "Please provide the name of the game, e.g. ./run_gpu breakout "; exit 0 5 | fi 6 | ENV=$1 7 | FRAMEWORK="alewrap" 8 | 9 | game_path=$PWD"/roms/" 10 | env_params="useRGB=true" 11 | agent="NeuralQLearner" 12 | n_replay=1 13 | netfile="\"convnet_atari3\"" 14 | update_freq=4 15 | actrep=4 16 | discount=0.99 17 | seed=1 18 | 19 | learn_start=1000000 20 | replay_memory=1000000 21 | eps_start=0.1 #1 22 | lr=0.00025 #0.00025 23 | 24 | 25 | pool_frms_type="\"max\"" 26 | pool_frms_size=2 27 | initial_priority="false" 28 | 29 | eps_end=0.1 30 | eps_endt=replay_memory 31 | 32 | agent_type="DQN3_0_1" 33 | preproc_net="\"net_downsample_2x_full_y\"" 34 | agent_name=$agent_type"_"$1"_SeaQuest" 35 | state_dim=7056 36 | ncols=1 37 | agent_params="lr="$lr",ep="$eps_start",ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1" 38 | steps=50000000 39 | eval_freq=250000 40 | eval_steps=125000 41 | prog_freq=10000 42 | save_freq=1000000 43 | gpu=0 44 | random_starts=30 45 | pool_frms="type="$pool_frms_type",size="$pool_frms_size 46 | num_threads=4 47 | 48 | args="-network $2 -framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads" 49 | echo $args 50 | 51 | cd dqn 52 | qlua ./train_agent.lua $args 53 | -------------------------------------------------------------------------------- /graying_the_box/bhtsne/Makefile.win: -------------------------------------------------------------------------------- 1 | CXX = cl.exe 2 | CFLAGS = /nologo /O2 /EHsc /D "_CRT_SECURE_NO_DEPRECATE" /D "USEOMP" /openmp 3 | 4 | TARGET = windows 5 | 6 | all: $(TARGET) $(TARGET)\bh_tsne.exe 7 | 8 | $(TARGET)\bh_tsne.exe: tsne.obj sptree.obj 9 | $(CXX) $(CFLAGS) tsne.obj sptree.obj -Fe$(TARGET)\bh_tsne.exe 10 | 11 | sptree.obj: sptree.cpp sptree.h 12 | $(CXX) $(CFLAGS) -c sptree.cpp 13 | 14 | tsne.obj: tsne.cpp tsne.h sptree.h vptree.h 15 | $(CXX) $(CFLAGS) -c tsne.cpp 16 | 17 | .PHONY: $(TARGET) 18 | $(TARGET): 19 | -mkdir $(TARGET) 20 | 21 | clean: 22 | -erase /Q *.obj *.exe $(TARGET)\. 
23 | -rd $(TARGET) 24 | -------------------------------------------------------------------------------- /graying_the_box/bhtsne/parse_lua_tensor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def parse_lua_tensor(file, dim): 4 | 5 | k = 0 6 | read_data = [] 7 | with open(file, 'r') as f: 8 | for line in f: 9 | if 'Columns' in line or line == '\n': 10 | continue 11 | else: 12 | read_data.append(line) 13 | k += 1 14 | 15 | read_data = read_data[2:-1] 16 | samples = [] 17 | for line in read_data: 18 | line = line[0:-1] # remove end of line mark 19 | line_ = line.split(' ') 20 | line_ = filter(None,line_) 21 | 22 | samples.append(line_) 23 | 24 | z = np.array(samples) 25 | num_cols = dim # DEBUG HERE!!! 26 | num_lines = z.shape[0] * z.shape[1] / num_cols 27 | my_mat = np.zeros(shape=(num_lines,num_cols), dtype='float32') 28 | 29 | num_blocks = z.shape[0] / num_lines 30 | 31 | for i in range(num_blocks): 32 | my_mat[:,i*z.shape[1]:(i+1)*z.shape[1]] = z[num_lines*i : (i+1)*num_lines,:] 33 | 34 | return my_mat -------------------------------------------------------------------------------- /graying_the_box/bhtsne/sptree.h: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology) 4 | * All rights reserved. 5 | * 6 | * Redistribution and use in source and binary forms, with or without 7 | * modification, are permitted provided that the following conditions are met: 8 | * 1. Redistributions of source code must retain the above copyright 9 | * notice, this list of conditions and the following disclaimer. 10 | * 2. Redistributions in binary form must reproduce the above copyright 11 | * notice, this list of conditions and the following disclaimer in the 12 | * documentation and/or other materials provided with the distribution. 13 | * 3. All advertising materials mentioning features or use of this software 14 | * must display the following acknowledgement: 15 | * This product includes software developed by the Delft University of Technology. 16 | * 4. Neither the name of the Delft University of Technology nor the names of 17 | * its contributors may be used to endorse or promote products derived from 18 | * this software without specific prior written permission. 19 | * 20 | * THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS 21 | * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 22 | * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO 23 | * EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 25 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 26 | * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING 28 | * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY 29 | * OF SUCH DAMAGE. 
30 | * 31 | */ 32 | 33 | 34 | #ifndef SPTREE_H 35 | #define SPTREE_H 36 | 37 | using namespace std; 38 | 39 | 40 | class Cell { 41 | 42 | unsigned int dimension; 43 | double* corner; 44 | double* width; 45 | 46 | 47 | public: 48 | Cell(unsigned int inp_dimension); 49 | Cell(unsigned int inp_dimension, double* inp_corner, double* inp_width); 50 | ~Cell(); 51 | 52 | double getCorner(unsigned int d); 53 | double getWidth(unsigned int d); 54 | void setCorner(unsigned int d, double val); 55 | void setWidth(unsigned int d, double val); 56 | bool containsPoint(double point[]); 57 | }; 58 | 59 | 60 | class SPTree 61 | { 62 | 63 | // Fixed constants 64 | static const unsigned int QT_NODE_CAPACITY = 1; 65 | 66 | // A buffer we use when doing force computations 67 | double* buff; 68 | 69 | // Properties of this node in the tree 70 | SPTree* parent; 71 | unsigned int dimension; 72 | bool is_leaf; 73 | unsigned int size; 74 | unsigned int cum_size; 75 | 76 | // Axis-aligned bounding box stored as a center with half-dimensions to represent the boundaries of this quad tree 77 | Cell* boundary; 78 | 79 | // Indices in this space-partitioning tree node, corresponding center-of-mass, and list of all children 80 | double* data; 81 | double* center_of_mass; 82 | unsigned int index[QT_NODE_CAPACITY]; 83 | 84 | // Children 85 | SPTree** children; 86 | unsigned int no_children; 87 | 88 | public: 89 | SPTree(unsigned int D, double* inp_data, unsigned int N); 90 | SPTree(unsigned int D, double* inp_data, double* inp_corner, double* inp_width); 91 | SPTree(unsigned int D, double* inp_data, unsigned int N, double* inp_corner, double* inp_width); 92 | SPTree(SPTree* inp_parent, unsigned int D, double* inp_data, unsigned int N, double* inp_corner, double* inp_width); 93 | SPTree(SPTree* inp_parent, unsigned int D, double* inp_data, double* inp_corner, double* inp_width); 94 | ~SPTree(); 95 | void setData(double* inp_data); 96 | SPTree* getParent(); 97 | void construct(Cell boundary); 98 | bool insert(unsigned int new_index); 99 | void subdivide(); 100 | bool isCorrect(); 101 | void rebuildTree(); 102 | void getAllIndices(unsigned int* indices); 103 | unsigned int getDepth(); 104 | void computeNonEdgeForces(unsigned int point_index, double theta, double neg_f[], double* sum_Q); 105 | void computeEdgeForces(unsigned int* row_P, unsigned int* col_P, double* val_P, int N, double* pos_f); 106 | void print(); 107 | 108 | private: 109 | void init(SPTree* inp_parent, unsigned int D, double* inp_data, double* inp_corner, double* inp_width); 110 | void fill(unsigned int N); 111 | unsigned int getAllIndices(unsigned int* indices, unsigned int loc); 112 | bool isChild(unsigned int test_index, unsigned int start, unsigned int end); 113 | }; 114 | 115 | #endif 116 | -------------------------------------------------------------------------------- /graying_the_box/bhtsne/tsne.h: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology) 4 | * All rights reserved. 5 | * 6 | * Redistribution and use in source and binary forms, with or without 7 | * modification, are permitted provided that the following conditions are met: 8 | * 1. Redistributions of source code must retain the above copyright 9 | * notice, this list of conditions and the following disclaimer. 10 | * 2. 
Redistributions in binary form must reproduce the above copyright 11 | * notice, this list of conditions and the following disclaimer in the 12 | * documentation and/or other materials provided with the distribution. 13 | * 3. All advertising materials mentioning features or use of this software 14 | * must display the following acknowledgement: 15 | * This product includes software developed by the Delft University of Technology. 16 | * 4. Neither the name of the Delft University of Technology nor the names of 17 | * its contributors may be used to endorse or promote products derived from 18 | * this software without specific prior written permission. 19 | * 20 | * THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS 21 | * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 22 | * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO 23 | * EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 25 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 26 | * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING 28 | * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY 29 | * OF SUCH DAMAGE. 30 | * 31 | */ 32 | 33 | 34 | #ifndef TSNE_H 35 | #define TSNE_H 36 | 37 | 38 | static inline double sign(double x) { return (x == .0 ? .0 : (x < .0 ? -1.0 : 1.0)); } 39 | 40 | 41 | class TSNE 42 | { 43 | public: 44 | void run(double* X, int N, int D, double* Y, int no_dims, double perplexity, double theta); 45 | bool load_data(double** data, int* n, int* d, int* no_dims, double* theta, double* perplexity, int* rand_seed); 46 | void save_data(double* data, int* landmarks, double* costs, int n, int d); 47 | void symmetrizeMatrix(unsigned int** row_P, unsigned int** col_P, double** val_P, int N); // should be static! 48 | 49 | 50 | private: 51 | void computeGradient(double* P, unsigned int* inp_row_P, unsigned int* inp_col_P, double* inp_val_P, double* Y, int N, int D, double* dC, double theta); 52 | void computeExactGradient(double* P, double* Y, int N, int D, double* dC); 53 | double evaluateError(double* P, double* Y, int N, int D); 54 | double evaluateError(unsigned int* row_P, unsigned int* col_P, double* val_P, double* Y, int N, int D, double theta); 55 | void zeroMean(double* X, int N, int D); 56 | void computeGaussianPerplexity(double* X, int N, int D, double* P, double perplexity); 57 | void computeGaussianPerplexity(double* X, int N, int D, unsigned int** _row_P, unsigned int** _col_P, double** _val_P, double perplexity, int K); 58 | void computeSquaredEuclideanDistance(double* X, int N, int D, double* DD); 59 | double randn(); 60 | }; 61 | 62 | #endif 63 | 64 | -------------------------------------------------------------------------------- /graying_the_box/bhtsne/vptree.h: -------------------------------------------------------------------------------- 1 | /* 2 | * 3 | * Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology) 4 | * All rights reserved. 5 | * 6 | * Redistribution and use in source and binary forms, with or without 7 | * modification, are permitted provided that the following conditions are met: 8 | * 1. 
Redistributions of source code must retain the above copyright 9 | * notice, this list of conditions and the following disclaimer. 10 | * 2. Redistributions in binary form must reproduce the above copyright 11 | * notice, this list of conditions and the following disclaimer in the 12 | * documentation and/or other materials provided with the distribution. 13 | * 3. All advertising materials mentioning features or use of this software 14 | * must display the following acknowledgement: 15 | * This product includes software developed by the Delft University of Technology. 16 | * 4. Neither the name of the Delft University of Technology nor the names of 17 | * its contributors may be used to endorse or promote products derived from 18 | * this software without specific prior written permission. 19 | * 20 | * THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS 21 | * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 22 | * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO 23 | * EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 25 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 26 | * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING 28 | * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY 29 | * OF SUCH DAMAGE. 30 | * 31 | */ 32 | 33 | 34 | /* This code was adopted with minor modifications from Steve Hanov's great tutorial at http://stevehanov.ca/blog/index.php?id=130 */ 35 | 36 | #include <stdlib.h> 37 | #include <algorithm> 38 | #include <vector> 39 | #include <stdio.h> 40 | #include <queue> 41 | #include <float.h> 42 | #include <cmath> 43 | 44 | 45 | #ifndef VPTREE_H 46 | #define VPTREE_H 47 | 48 | class DataPoint 49 | { 50 | int _ind; 51 | 52 | public: 53 | double* _x; 54 | int _D; 55 | DataPoint() { 56 | _D = 1; 57 | _ind = -1; 58 | _x = NULL; 59 | } 60 | DataPoint(int D, int ind, double* x) { 61 | _D = D; 62 | _ind = ind; 63 | _x = (double*) malloc(_D * sizeof(double)); 64 | for(int d = 0; d < _D; d++) _x[d] = x[d]; 65 | } 66 | DataPoint(const DataPoint& other) { // this makes a deep copy -- should not free anything 67 | if(this != &other) { 68 | _D = other.dimensionality(); 69 | _ind = other.index(); 70 | _x = (double*) malloc(_D * sizeof(double)); 71 | for(int d = 0; d < _D; d++) _x[d] = other.x(d); 72 | } 73 | } 74 | ~DataPoint() { if(_x != NULL) free(_x); } 75 | DataPoint& operator= (const DataPoint& other) { // assignment should free old object 76 | if(this != &other) { 77 | if(_x != NULL) free(_x); 78 | _D = other.dimensionality(); 79 | _ind = other.index(); 80 | _x = (double*) malloc(_D * sizeof(double)); 81 | for(int d = 0; d < _D; d++) _x[d] = other.x(d); 82 | } 83 | return *this; 84 | } 85 | int index() const { return _ind; } 86 | int dimensionality() const { return _D; } 87 | double x(int d) const { return _x[d]; } 88 | }; 89 | 90 | double euclidean_distance(const DataPoint &t1, const DataPoint &t2) { 91 | double dd = .0; 92 | double* x1 = t1._x; 93 | double* x2 = t2._x; 94 | double diff; 95 | for(int d = 0; d < t1._D; d++) { 96 | diff = (x1[d] - x2[d]); 97 | dd += diff * diff; 98 | } 99 | return sqrt(dd); 100 | } 101 | 102 | 103 | template<typename T, double (*distance)( const T&, const T& )> 104 | class VpTree 105 | { 106 | public: 107 | 108 | // Default constructor 109 | VpTree() : _root(0) {} 110 | 111 | // Destructor
112 | ~VpTree() { 113 | delete _root; 114 | } 115 | 116 | // Function to create a new VpTree from data 117 | void create(const std::vector<T>& items) { 118 | delete _root; 119 | _items = items; 120 | _root = buildFromPoints(0, items.size()); 121 | } 122 | 123 | // Function that uses the tree to find the k nearest neighbors of target 124 | void search(const T& target, int k, std::vector<T>* results, std::vector<double>* distances) 125 | { 126 | 127 | // Use a priority queue to store intermediate results on 128 | std::priority_queue<HeapItem> heap; 129 | 130 | // Variable that tracks the distance to the farthest point in our results 131 | _tau = DBL_MAX; 132 | 133 | // Perform the search 134 | search(_root, target, k, heap); 135 | 136 | // Gather final results 137 | results->clear(); distances->clear(); 138 | while(!heap.empty()) { 139 | results->push_back(_items[heap.top().index]); 140 | distances->push_back(heap.top().dist); 141 | heap.pop(); 142 | } 143 | 144 | // Results are in reverse order 145 | std::reverse(results->begin(), results->end()); 146 | std::reverse(distances->begin(), distances->end()); 147 | } 148 | 149 | private: 150 | std::vector<T> _items; 151 | double _tau; 152 | 153 | // Single node of a VP tree (has a point and radius; left children are closer to point than the radius) 154 | struct Node 155 | { 156 | int index; // index of point in node 157 | double threshold; // radius(?) 158 | Node* left; // points closer by than threshold 159 | Node* right; // points farther away than threshold 160 | 161 | Node() : 162 | index(0), threshold(0.), left(0), right(0) {} 163 | 164 | ~Node() { // destructor 165 | delete left; 166 | delete right; 167 | } 168 | }* _root; 169 | 170 | 171 | // An item on the intermediate result queue 172 | struct HeapItem { 173 | HeapItem( int index, double dist) : 174 | index(index), dist(dist) {} 175 | int index; 176 | double dist; 177 | bool operator<(const HeapItem& o) const { 178 | return dist < o.dist; 179 | } 180 | }; 181 | 182 | // Distance comparator for use in std::nth_element 183 | struct DistanceComparator 184 | { 185 | const T& item; 186 | DistanceComparator(const T& item) : item(item) {} 187 | bool operator()(const T& a, const T& b) { 188 | return distance(item, a) < distance(item, b); 189 | } 190 | }; 191 | 192 | // Function that (recursively) fills the tree 193 | Node* buildFromPoints( int lower, int upper ) 194 | { 195 | if (upper == lower) { // indicates that we're done here!
196 | return NULL; 197 | } 198 | 199 | // Lower index is center of current node 200 | Node* node = new Node(); 201 | node->index = lower; 202 | 203 | if (upper - lower > 1) { // if we did not arrive at leaf yet 204 | 205 | // Choose an arbitrary point and move it to the start 206 | int i = (int) ((double)rand() / RAND_MAX * (upper - lower - 1)) + lower; 207 | std::swap(_items[lower], _items[i]); 208 | 209 | // Partition around the median distance 210 | int median = (upper + lower) / 2; 211 | std::nth_element(_items.begin() + lower + 1, 212 | _items.begin() + median, 213 | _items.begin() + upper, 214 | DistanceComparator(_items[lower])); 215 | 216 | // Threshold of the new node will be the distance to the median 217 | node->threshold = distance(_items[lower], _items[median]); 218 | 219 | // Recursively build tree 220 | node->index = lower; 221 | node->left = buildFromPoints(lower + 1, median); 222 | node->right = buildFromPoints(median, upper); 223 | } 224 | 225 | // Return result 226 | return node; 227 | } 228 | 229 | // Helper function that searches the tree 230 | void search(Node* node, const T& target, int k, std::priority_queue<HeapItem>& heap) 231 | { 232 | if(node == NULL) return; // indicates that we're done here 233 | 234 | // Compute distance between target and current node 235 | double dist = distance(_items[node->index], target); 236 | 237 | // If current node within radius tau 238 | if(dist < _tau) { 239 | if(heap.size() == k) heap.pop(); // remove furthest node from result list (if we already have k results) 240 | heap.push(HeapItem(node->index, dist)); // add current node to result list 241 | if(heap.size() == k) _tau = heap.top().dist; // update value of tau (farthest point in result list) 242 | } 243 | 244 | // Return if we arrived at a leaf 245 | if(node->left == NULL && node->right == NULL) { 246 | return; 247 | } 248 | 249 | // If the target lies within the radius of ball 250 | if(dist < node->threshold) { 251 | if(dist - _tau <= node->threshold) { // if there can still be neighbors inside the ball, recursively search left child first 252 | search(node->left, target, k, heap); 253 | } 254 | 255 | if(dist + _tau >= node->threshold) { // if there can still be neighbors outside the ball, recursively search right child 256 | search(node->right, target, k, heap); 257 | } 258 | 259 | // If the target lies outside the radius of the ball 260 | } else { 261 | if(dist + _tau >= node->threshold) { // if there can still be neighbors outside the ball, recursively search right child first 262 | search(node->right, target, k, heap); 263 | } 264 | 265 | if (dist - _tau <= node->threshold) { // if there can still be neighbors inside the ball, recursively search left child 266 | search(node->left, target, k, heap); 267 | } 268 | } 269 | } 270 | }; 271 | 272 | #endif 273 | -------------------------------------------------------------------------------- /graying_the_box/bhtsne/write_movie.py: -------------------------------------------------------------------------------- 1 | # This example uses a MovieWriter directly to grab individual frames and 2 | # write them to a file. This avoids any event loop integration, but has 3 | # the advantage of working with even the Agg backend. This is not recommended 4 | # for use in an interactive setting.
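# (Descriptive note, not part of the original script: FFMpegWriter pipes frames to an
# external ffmpeg process, so the ffmpeg binary must be installed and on the PATH.
# writer.saving(fig, filename, dpi) keeps the output file open for the duration of the
# `with` block below, and each writer.grab_frame() call renders the figure in its
# current state and appends it as one video frame.)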
5 | # -*- noplot -*- 6 | 7 | import numpy as np 8 | import matplotlib 9 | matplotlib.use("Agg") 10 | import matplotlib.pyplot as plt 11 | import matplotlib.animation as manimation 12 | 13 | FFMpegWriter = manimation.writers['ffmpeg'] 14 | metadata = dict(title='Movie Test', artist='Matplotlib', 15 | comment='Movie support!') 16 | writer = FFMpegWriter(fps=15, metadata=metadata) 17 | 18 | fig = plt.figure() 19 | l, = plt.plot([], [], 'k-o') 20 | 21 | plt.xlim(-5, 5) 22 | plt.ylim(-5, 5) 23 | 24 | x0,y0 = 0, 0 25 | 26 | with writer.saving(fig, "writer_test.mp4", 100): 27 | for i in range(100): 28 | x0 += 0.1 * np.random.randn() 29 | y0 += 0.1 * np.random.randn() 30 | l.set_data(x0, y0) 31 | writer.grab_frame() -------------------------------------------------------------------------------- /graying_the_box/clustering.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import SpectralClustering 2 | from sklearn.cluster import KMeans as Kmeans_st 3 | # from sklearn.cluster import KMeans_st as Kmeans_st 4 | 5 | from emhc import EMHC 6 | from smdp import SMDP 7 | import numpy as np 8 | import common 9 | from digraph import draw_transition_table 10 | 11 | def perpare_features(self, n_features=3): 12 | 13 | data = np.zeros(shape=(self.global_feats['tsne'].shape[0],n_features)) 14 | data[:,0:2] = self.global_feats['tsne'] 15 | data[:,2] = self.global_feats['value'] 16 | # data[:,3] = self.global_feats['time'] 17 | # data[:,4] = self.global_feats['termination'] 18 | # data[:,5] = self.global_feats['tsne3d_norm'] 19 | # data[:,6] = self.hand_craft_feats['missing_bricks'] 20 | # data[:,6] = self.hand_craft_feats['hole'] 21 | # data[:,7] = self.hand_craft_feats['racket'] 22 | # data[:,8] = self.hand_craft_feats['ball_dir'] 23 | # data[:,9] = self.hand_craft_feats['traj'] 24 | # data[:,9:11] = self.hand_craft_feats['ball_pos'] 25 | data[np.isnan(data)] = 0 26 | # 1.2 data standartization 27 | # scaler = preprocessing.StandardScaler(with_centering=False).fit(data) 28 | # data = scaler.fit_transform(data) 29 | 30 | # data_mean = data.mean(axis=0) 31 | # data -= data_mean 32 | return data 33 | 34 | def clustering_(self, plt, n_points=None, force=0): 35 | 36 | if n_points==None: 37 | n_points = self.global_feats['termination'].shape[0] 38 | 39 | if self.clustering_labels is not None: 40 | self.tsne_scat.set_array(self.clustering_labels.astype(np.float32)/self.clustering_labels.max()) 41 | draw_transition_table(transition_table=self.smdp.P, cluster_centers=self.cluster_centers, 42 | meanscreen=self.meanscreen, tsne=self.global_feats['tsne'], color=self.color, black_edges=self.smdp.edges) 43 | plt.show() 44 | if force==0: 45 | return 46 | 47 | n_clusters = self.cluster_params['n_clusters'] 48 | W = self.cluster_params['window_size'] 49 | n_iters = self.cluster_params['n_iters'] 50 | entropy_iters = self.cluster_params['entropy_iters'] 51 | 52 | # slice data by given indices 53 | term = self.global_feats['termination'][:n_points] 54 | reward = self.global_feats['reward'][:n_points] 55 | value = self.global_feats['value'][:n_points] 56 | tsne = self.global_feats['tsne'][:n_points] 57 | traj_ids = self.hand_craft_feats['traj'][:n_points] 58 | 59 | # 1. create data for clustering 60 | data = perpare_features(self) 61 | data = data[:n_points] 62 | data_scale = data.max(axis=0) 63 | data /= data_scale 64 | 65 | # 2. 
Build cluster model 66 | # 2.1 spatio-temporal K-means 67 | if self.cluster_params['method'] == 0: 68 | windows_vec = np.arange(start=W,stop=W+1,step=1) 69 | clusters_vec = np.arange(start=n_clusters,stop=n_clusters+1,step=1) 70 | models_vec = [] 71 | scores = np.zeros(shape=(len(clusters_vec),1)) 72 | for i,n_w in enumerate(windows_vec): 73 | for j,n_c in enumerate(clusters_vec): 74 | cluster_model = Kmeans_st(n_clusters=n_clusters,window_size=n_w,n_jobs=8,n_init=n_iters,entropy_iters=entropy_iters) 75 | cluster_model.fit(data, rewards=reward, termination=term, values=value) 76 | labels = cluster_model.labels_ 77 | models_vec.append(cluster_model.smdp) 78 | scores[j] = cluster_model.smdp.score 79 | print 'window size: %d , Value mse: %f' % (n_w, cluster_model.smdp.score) 80 | best = np.argmin(scores) 81 | self.cluster_params['n_clusters'] +=best 82 | self.smdp = models_vec[best] 83 | 84 | # 2.1 Spectral clustering 85 | elif self.cluster_params['method'] == 1: 86 | import scipy.spatial.distance 87 | import scipy.sparse 88 | dists = scipy.spatial.distance.pdist(tsne, 'euclidean') 89 | similarity = np.exp(-dists/10) 90 | similarity[similarity<1e-2] = 0 91 | print 'Created similarity matrix' 92 | affine_mat = scipy.spatial.distance.squareform(similarity) 93 | cluster_model = SpectralClustering(n_clusters=n_clusters,affinity='precomputed') 94 | labels = cluster_model.fit_predict(affine_mat) 95 | 96 | # 2.2 EMHC 97 | elif self.cluster_params['method'] == 2: 98 | # cluster with k means down to n_clusters + D 99 | n_clusters_ = n_clusters + 5 100 | kmeans_st_model = Kmeans_st(n_clusters=n_clusters_,window_size=W,n_jobs=8,n_init=n_iters,entropy_iters=entropy_iters, random_state=123) 101 | kmeans_st_model.fit(data, rewards=reward, termination=term, values=value) 102 | cluster_model = EMHC(X=data, labels=kmeans_st_model.labels_, termination=term, min_clusters=n_clusters, max_entropy=np.inf) 103 | cluster_model.fit() 104 | labels = cluster_model.labels_ 105 | self.smdp = SMDP(labels=labels, termination=term, rewards=reward, values=value, n_clusters=n_clusters) 106 | 107 | self.smdp.complete_smdp() 108 | self.clustering_labels = self.smdp.labels 109 | common.create_trajectory_data(self, reward, traj_ids) 110 | self.state_pi_correlation = common.reward_policy_correlation(self.traj_list, self.smdp.greedy_policy, self.smdp) 111 | 112 | top_greedy_vec = [] 113 | bottom_greedy_vec = [] 114 | max_diff = 0 115 | best_d = 1 116 | for i,d in enumerate(xrange(1,30)): 117 | tb_trajs_discr = common.extermum_trajs_discrepency(self.traj_list, self.clustering_labels, term, reward, value, self.smdp.n_clusters, self.smdp.greedy_policy, d=d) 118 | top_greedy_vec.append([i,tb_trajs_discr['top_greedy_sum']]) 119 | bottom_greedy_vec.append([i,tb_trajs_discr['bottom_greedy_sum']]) 120 | diff_i = tb_trajs_discr['top_greedy_sum'] - tb_trajs_discr['bottom_greedy_sum'] 121 | if diff_i > max_diff: 122 | max_diff = diff_i 123 | best_d = d 124 | 125 | self.tb_trajs_discr = common.extermum_trajs_discrepency(self.traj_list, self.clustering_labels, term, reward, value, self.smdp.n_clusters, self.smdp.greedy_policy, d=best_d) 126 | self.top_greedy_vec = top_greedy_vec 127 | self.bottom_greedy_vec = bottom_greedy_vec 128 | 129 | common.draw_skills(self,self.smdp.n_clusters,plt) 130 | 131 | 132 | # 4. 
collect statistics 133 | cluster_centers = cluster_model.cluster_centers_ 134 | cluster_centers *= data_scale 135 | 136 | screen_size = self.screens.shape 137 | meanscreen = np.zeros(shape=(n_clusters,screen_size[1],screen_size[2],screen_size[3])) 138 | cluster_time = np.zeros(shape=(n_clusters,1)) 139 | width = int(np.floor(np.sqrt(n_clusters))) 140 | length = int(n_clusters/width) 141 | # f, ax = plt.subplots(length,width) 142 | 143 | for cluster_ind in range(n_clusters): 144 | indices = (labels==cluster_ind) 145 | cluster_data = data[indices] 146 | cluster_time[cluster_ind] = np.mean(self.global_feats['time'][indices]) 147 | meanscreen[cluster_ind,:,:,:] = common.calc_cluster_im(self,indices) 148 | 149 | # 5. draw cluster indices 150 | plt.figure(self.fig.number) 151 | data *= data_scale 152 | for i in range(n_clusters): 153 | self.ax_tsne.annotate(i, xy=cluster_centers[i,0:2], size=20, color='r') 154 | draw_transition_table(transition_table=self.smdp.P, cluster_centers=cluster_centers, 155 | meanscreen=meanscreen, tsne=data[:,0:2], color=self.color, black_edges=self.smdp.edges) 156 | 157 | self.cluster_centers = cluster_centers 158 | self.meanscreen =meanscreen 159 | self.cluster_time =cluster_time 160 | common.visualize(self) 161 | 162 | def update_slider(self, name, slider): 163 | def f(): 164 | setattr(self, name, slider.val) 165 | return f 166 | -------------------------------------------------------------------------------- /graying_the_box/digraph.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import matplotlib.pyplot as plt 3 | import pylab 4 | import pickle 5 | import numpy as np 6 | import common 7 | 8 | def draw_transition_table(transition_table, cluster_centers, meanscreen, tsne ,color, black_edges=None, red_edges=None, title=None): 9 | G = nx.DiGraph() 10 | edge_colors = [] 11 | 12 | if red_edges is not None: 13 | for e in red_edges: 14 | G.add_edges_from([e], weight=np.round(transition_table[e[0],e[1]]*100)/100) 15 | edge_colors.append('red') 16 | 17 | if black_edges is not None: 18 | if red_edges is not None: 19 | black_edges = list(set(black_edges)-set(red_edges)) 20 | 21 | for e in black_edges: 22 | G.add_edges_from([e], weight=np.round(transition_table[e[0],e[1]]*100)/100) 23 | edge_colors.append('black') 24 | 25 | 26 | edge_labels=dict([((u,v,),d['weight']) for u,v,d in G.edges(data=True)]) 27 | 28 | node_labels = {node:node for node in G.nodes()}; 29 | counter=0 30 | for key in node_labels.keys(): 31 | node_labels[key] = counter 32 | counter+=1 33 | 34 | if title is None: 35 | fig = plt.figure('SMDP') 36 | fig.clear() 37 | else: 38 | fig = plt.figure(title) 39 | 40 | plt.scatter(tsne[:,0],tsne[:,1],s= np.ones(tsne.shape[0])*2,facecolor=color, edgecolor='none') 41 | pos = cluster_centers[:,0:2] 42 | nx.draw_networkx_edge_labels(G,pos,edge_labels=edge_labels,label_pos=0.65,font_size=9) 43 | nx.draw_networkx_labels(G, pos, labels=node_labels,font_color='w',font_size=8) 44 | nx.draw(G,pos,cmap=plt.cm.brg,edge_color=edge_colors) 45 | 46 | 47 | ######Present images on nodes 48 | ax = plt.subplot(111) 49 | plt.axis('off') 50 | trans = ax.transData.transform 51 | trans2 = fig.transFigure.inverted().transform 52 | cut = 1.01 53 | xmax = cut * max(tsne[:,0]) 54 | ymax = cut * max(tsne[:,1]) 55 | xmin = cut * min(tsne[:,0]) 56 | ymin = cut * min(tsne[:,1]) 57 | plt.xlim(xmin, xmax) 58 | plt.ylim(ymin, ymax) 59 | 60 | h = 70.0 61 | w = 70.0 62 | counter= 0 63 | for node in G: 64 | xx, yy = trans(pos[node]) 65 | 
# axes coordinates 66 | xa, ya = trans2((xx, yy)) 67 | 68 | # this is the image size 69 | piesize_1 = (300.0 / (h*80)) 70 | piesize_2 = (300.0 / (w*80)) 71 | p2_2 = piesize_2 / 2 72 | p2_1 = piesize_1 / 2 73 | a = plt.axes([xa - p2_2, ya - p2_1, piesize_2, piesize_1]) 74 | G.node[node]['image'] = meanscreen[counter] 75 | #display it 76 | a.imshow(G.node[node]['image']) 77 | a.set_title(node_labels[counter]) 78 | #turn off the axis from minor plot 79 | a.axis('off') 80 | counter+=1 81 | plt.draw() 82 | 83 | def draw_transition_table_no_image(transition_table,cluster_centers): 84 | G = nx.DiGraph() 85 | G2 = nx.DiGraph() 86 | 87 | # print transition_table.sum(axis=1) 88 | 89 | transition_table = (transition_table.transpose()/transition_table.sum(axis=1)).transpose() 90 | transition_table[np.isnan(transition_table)]=0 91 | # print(transition_table) 92 | # transition_table = (transition_table.transpose()/transition_table.sum(axis=1)).transpose() 93 | # print transition_table 94 | # print transition_table.sum(axis=0) 95 | # assert(np.all(transition_table.sum(axis=0)!=0)) 96 | transition_table[transition_table<0.1]=0 97 | 98 | pos = cluster_centers[:,0:2] 99 | m,n = transition_table.shape 100 | 101 | for i in range(m): 102 | for j in range(n): 103 | if transition_table[i,j]!=0: 104 | G.add_edges_from([(i, j)], weight=np.round(transition_table[i,j]*100)/100) 105 | G2.add_edges_from([(i, j)], weight=np.round(transition_table[i,j]*100)/100) 106 | values = cluster_centers[:,2] 107 | 108 | red_edges = [] 109 | edges_sizes =[] 110 | for i in range(n): 111 | trans = transition_table[i,:] 112 | indices = (trans!=0) 113 | index = np.argmax(cluster_centers[indices,2]) 114 | counter = 0 115 | for j in range(len(indices)): 116 | if indices[j]: 117 | if counter == index: 118 | ind = j 119 | break 120 | else: 121 | counter+=1 122 | edges_sizes.append(ind) 123 | red_edges.append((i,ind)) 124 | # print(red_edges) 125 | # sizes = 3000*cluster_centers[:,3] 126 | sizes = np.ones_like(values)*500 127 | edge_labels=dict([((u,v,),d['weight']) for u,v,d in G.edges(data=True)]) 128 | edge_colors = ['black' for edge in G.edges()] 129 | # edge_colors = ['black' if not edge in red_edges else 'red' for edge in G.edges()] 130 | 131 | 132 | node_labels = {node:node for node in G.nodes()}; 133 | counter=0 134 | for key in node_labels.keys(): 135 | # node_labels[key] = np.round(100*cluster_centers[counter,3])/100 136 | node_labels[key] = counter 137 | counter+=1 138 | 139 | fig = plt.figure() 140 | nx.draw_networkx_edge_labels(G,pos,edge_labels=edge_labels,label_pos=0.65,font_size=9) 141 | nx.draw_networkx_labels(G, pos, labels=node_labels,font_color='w',font_size=8) 142 | nx.draw(G,pos, node_color = values,cmap=plt.cm.brg, node_size=np.round(sizes),edge_color=edge_colors,edge_cmap=plt.cm.Reds) 143 | 144 | 145 | ######Present images on nodes 146 | # plt.show() 147 | 148 | # 149 | def test(): 150 | gamename = 'breakout' #breakout pacman 151 | transition_table = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'transition_table.bin')) 152 | cluster_centers = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_centers.bin')) 153 | cluster_std = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_std.bin')) 154 | cluster_med = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_med.bin')) 155 | cluster_min = 
pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_min.bin')) 156 | cluster_max = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_max.bin')) 157 | meanscreen = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'meanscreen.bin')) 158 | cluster_time = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'cluster_time.bin')) 159 | tsne = common.load_hdf5('lowd_activations', 'data/' + 'breakout' + '/'+'120k/') 160 | q_hdf5 = common.load_hdf5('qvals', 'data/' + 'breakout' + '/'+'120k/') 161 | 162 | num_frames = 120000 163 | V = np.zeros(shape=(num_frames)) 164 | 165 | for i in range(0,num_frames): 166 | V[i] = max(q_hdf5[i]) 167 | V = V/V.max() 168 | draw_transition_table(transition_table,cluster_centers,meanscreen,cluster_time,tsne,V) 169 | plt.show() 170 | # test() 171 | 172 | 173 | 174 | # stdscreen = pickle.load(file('/home/tom/git/graying_the_box/data/'+gamename+'/120k' + '/knn/' + 'stdscreen.bin')) 175 | # # 176 | # a = 1 177 | # b = 0 178 | # c = 0 179 | # screen = a*meanscreen + c*stdscreen 180 | 181 | # facecolor = self.color, 182 | # edgecolor='none',picker=5) 183 | # draw_transition_table_no_image(transition_table,cluster_centers) 184 | 185 | # transition_table = pickle.load(file('/home/tom/git/graying_the_box/data/seaquest/120k' + '/knn/' + 'transition_table.bin')) 186 | # transition_table[transition_table<0.1]=0 187 | # cluster_centers = pickle.load(file('/home/tom/git/graying_the_box/data/seaquest/120k' + '/knn/' + 'cluster_centers.bin')) 188 | 189 | # pos2 = np.zeros(shape=(cluster_centers.shape[0],2)) 190 | # pos2[:,0] = cluster_time[:,0] 191 | # pos2[:,1] = cluster_centers[:,1] 192 | # plt.figure() 193 | # nx.draw_networkx_edge_labels(G2,pos2,edge_labels=edge_labels,label_pos=0.8,font_size=8) 194 | # nx.draw_networkx_labels(G2, pos2, labels=node_labels,font_color='w',font_size=8) 195 | # nx.draw(G2,pos2, node_color = values,cmap=plt.cm.brg, node_size=np.round(sizes),edge_color=edge_colors,edge_cmap=plt.cm.Reds) 196 | -------------------------------------------------------------------------------- /graying_the_box/hand_crafted_features/add_breakout_buttons.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def add_game_buttons(self, top_y=0.69, x_left = 0.68): 4 | 5 | self.add_slider_button([x_left, top_y, 0.08, 0.01], 'ball_x', 0, 160) 6 | self.add_slider_button([x_left, top_y-0.04, 0.08, 0.01], 'ball_y', 0, 210) 7 | self.add_slider_button([x_left, top_y-0.08, 0.08, 0.01], 'racket_x', 0, 160) 8 | self.add_slider_button([x_left, top_y-0.12, 0.08, 0.01], 'missing_bricks', 0, 130) 9 | self.add_check_button([x_left, top_y-0.20, 0.06, 0.04], 'hole', ('no-hole','hole'), (True,True)) 10 | self.add_check_button([x_left, top_y-0.30, 0.07, 0.08], 'ball_dir', ('down-right','up-right','down-left','up-left'), (True,True,True,True)) 11 | self.ball_dir_mat = np.zeros(shape=(self.num_points,4), dtype='bool') 12 | for i,dir in enumerate(self.hand_craft_feats['ball_dir']): 13 | self.ball_dir_mat[i,int(dir)] = 1 14 | 15 | ############################################## 16 | # marking points along trajectories from a to b 17 | 18 | # self.cond_vector_mark_trajs = np.zeros(shape=(self.num_points,), dtype='int8') 19 | # 20 | # self.fig_mark_trajs = plt.figure('mark trajectories') 21 | # self.fig_mark_trajs.canvas.mpl_connect('pick_event', self.on_scatter_pick_mark_trajs) 22 | # 
self.ax_mark_trajs = self.fig_mark_trajs.add_subplot(111) 23 | # self.scat_mark_trajs = self.ax_mark_trajs.scatter(self.data_t[0], 24 | # self.data_t[1], 25 | # s = 5 * self.cond_vector_mark_trajs + np.ones(self.num_points)*self.pnt_size, 26 | # facecolor = self.V, 27 | # edgecolor='none', 28 | # picker=5) 29 | 30 | ############################################### 31 | 32 | def update_cond_vector(self): 33 | 34 | ball_x = np.asarray([row[0] for row in self.hand_craft_feats['ball_pos']]) 35 | ball_y = np.asarray([row[1] for row in self.hand_craft_feats['ball_pos']]) 36 | ball_dir = np.asarray(self.hand_craft_feats['ball_dir']) 37 | racket_x = np.asarray(self.hand_craft_feats['racket']) 38 | missing_bricks = np.asarray(self.hand_craft_feats['missing_bricks']) 39 | has_hole = np.asarray(self.hand_craft_feats['hole']) 40 | 41 | self.cond_vector = (ball_x >= self.ball_x_min) * (ball_x <= self.ball_x_max) * \ 42 | (ball_y >= self.ball_y_min) * (ball_y <= self.ball_y_max) * \ 43 | (racket_x >= self.racket_x_min) * (racket_x <= self.racket_x_max) * \ 44 | (missing_bricks >= self.missing_bricks_min) * (missing_bricks <= self.missing_bricks_max) 45 | 46 | # ball dir 47 | dirs_mask = np.zeros_like(self.cond_vector) 48 | for i,val in enumerate(self.ball_dir_check_button.get_status()): 49 | if val: 50 | dirs_mask += self.ball_dir_mat[:,i] 51 | self.cond_vector = self.cond_vector * dirs_mask 52 | 53 | # has a hole 54 | if self.hole_check_button.get_status()[0] == 0: # filter out states with no-hole 55 | self.cond_vector = self.cond_vector * (has_hole) 56 | elif self.hole_check_button.get_status()[1] == 0: # filter out states with a hole 57 | self.cond_vector = self.cond_vector * (1-has_hole) 58 | self.cond_vector = self.cond_vector.astype(int) 59 | 60 | def on_scatter_pick_mark_trajs(self,event): 61 | if hasattr(event,'ind'): 62 | ind = event.ind[0] 63 | traj_ids = self.state_labels[:,6] 64 | times = self.state_labels[:,7] 65 | has_hole = self.state_labels[:,5] 66 | 67 | traj_id = traj_ids[ind] 68 | time = times[ind] 69 | 70 | # if current point is already marked then un-mark the entire trajectory 71 | if self.cond_vector_mark_trajs[ind] == 1: 72 | self.cond_vector_mark_trajs[np.nonzero(traj_ids==traj_id)] = 0 73 | else: 74 | # mark the entire point on the current trajectory from time until has_hole = 1 75 | cond = 1 * (traj_ids == traj_id) * (times >=time) * (has_hole == 0) 76 | self.cond_vector_mark_trajs[np.nonzero(cond)] = 1 77 | self.cond_vector_mark_trajs = self.cond_vector_mark_trajs.astype(int) 78 | 79 | sizes = 5 * self.cond_vector_mark_trajs + np.ones(self.num_points)*self.pnt_size 80 | self.scat_mark_trajs.set_array(self.cond_vector_mark_trajs) 81 | self.scat_mark_trajs.set_sizes(sizes) 82 | plt.pause(0.01) -------------------------------------------------------------------------------- /graying_the_box/hand_crafted_features/add_global_features.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def add_buttons(self, global_feats): 4 | 5 | ############################# 6 | # 3.1 global coloring buttons 7 | ############################# 8 | self.COLORS = {} 9 | self.add_color_button([0.60, 0.95, 0.09, 0.02], 'value', global_feats['value']) 10 | self.add_color_button([0.70, 0.95, 0.09, 0.02], 'actions', global_feats['actions']) 11 | self.add_color_button([0.60, 0.92, 0.09, 0.02], 'rooms', global_feats['rooms']) 12 | self.add_color_button([0.60, 0.89, 0.09, 0.02], 'gauss_clust', global_feats['gauss_clust']) 13 | self.add_color_button([0.70, 
0.89, 0.09, 0.02], 'TD', global_feats['TD']) 14 | self.add_color_button([0.60, 0.86, 0.09, 0.02], 'action repetition', global_feats['act_rep']) 15 | self.add_color_button([0.70, 0.86, 0.09, 0.02], 'reward', global_feats['reward']) 16 | 17 | for i in range((global_feats['single_clusters']).shape[0]): 18 | self.add_color_button([0.80 + 0.10 * math.floor(i / 4), 0.95 - (i % 4) * 0.03, 0.09, 0.02], 'cluster_' + str(i), (global_feats['single_clusters'])[i]) 19 | 20 | 21 | self.SLIDER_FUNCS = [] 22 | self.CHECK_BUTTONS = [] 23 | -------------------------------------------------------------------------------- /graying_the_box/hand_crafted_features/add_global_features.py~: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def add_buttons(self, global_feats): 4 | 5 | ############################# 6 | # 3.1 global coloring buttons 7 | ############################# 8 | self.COLORS = {} 9 | self.add_color_button([0.60, 0.95, 0.09, 0.02], 'value', global_feats['value']) 10 | self.add_color_button([0.70, 0.95, 0.09, 0.02], 'actions', global_feats['actions']) 11 | self.add_color_button([0.60, 0.92, 0.09, 0.02], 'rooms', global_feats['rooms']) 12 | self.add_color_button([0.60, 0.89, 0.09, 0.02], 'gauss_clust', global_feats['gauss_clust']) 13 | self.add_color_button([0.70, 0.89, 0.09, 0.02], 'TD', global_feats['TD']) 14 | self.add_color_button([0.60, 0.86, 0.09, 0.02], 'action repetition', global_feats['act_rep']) 15 | self.add_color_button([0.70, 0.86, 0.09, 0.02], 'reward', global_feats['reward']) 16 | 17 | for i in range((global_feats['single_clusters']).shape[0]): 18 | self.add_color_button([0.80 + 0.10 * math.floor(i / 4), 0.95 - (i % 24 * 0.03, 0.09, 0.02], 'cluster_' + str(i), (global_feats['single_clusters'])[i]) 19 | 20 | 21 | self.SLIDER_FUNCS = [] 22 | self.CHECK_BUTTONS = [] 23 | -------------------------------------------------------------------------------- /graying_the_box/hand_crafted_features/add_pacman_buttons.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import math 4 | 5 | def add_game_buttons(self, top_y=0.69, x_left = 0.68): 6 | self.add_slider_button([x_left, top_y, 0.08, 0.01], 'player_x', 0, 160) 7 | self.add_slider_button([x_left, top_y-0.04, 0.08, 0.01], 'player_y', 0, 210) 8 | self.add_slider_button([x_left, top_y-0.08, 0.08, 0.01], 'bricks', 0, 121) 9 | self.add_slider_button([x_left, top_y-0.12, 0.08, 0.01], 'enemy_distance', 0, 160+210) 10 | self.add_slider_button([x_left, top_y-0.16, 0.08, 0.01], 'lives', 0, 3) 11 | self.add_check_button([x_left, top_y-0.22, 0.06, 0.03], 'ghost', ('no-ghost','ghost'), (True,True)) 12 | self.add_check_button([x_left, top_y-0.26, 0.06, 0.03], 'box', ('no-box','box'), (True,True)) 13 | self.add_check_button([x_left, top_y-0.35, 0.07, 0.08], 'player_dir', ('stand','right','left','bottom','top'), (True, True, True, True, True)) 14 | self.player_dir_mat = np.zeros(shape=(self.num_points,5), dtype='bool') 15 | for i,dir in enumerate(self.hand_craft_feats['player_dir']): 16 | self.player_dir_mat[i,int(dir)] = 1 17 | 18 | def update_cond_vector(self): 19 | player_x = np.asarray([row[0] for row in self.hand_craft_feats['player_pos']]) 20 | player_y = np.asarray([row[1] for row in self.hand_craft_feats['player_pos']]) 21 | player_dir = np.asarray(self.hand_craft_feats['player_dir']) 22 | bricks = np.asarray(self.hand_craft_feats['bricks']) 23 | nb_lives = np.asarray(self.hand_craft_feats['lives']) 
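# (Descriptive note, not part of the original file: the feature arrays gathered above are
# compared elementwise against the slider min/max bounds below, and multiplying the
# resulting boolean masks acts as a logical AND, so cond_vector marks the frames that fall
# inside every selected range; the check-button masks that follow then AND in the
# player-direction, ghost-mode, and bonus-box filters before the vector is cast to int.)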
24 | ghost_mode = np.asarray(self.hand_craft_feats['ghost']) 25 | enemies_dist = np.asarray(self.hand_craft_feats['enemy_distance']) 26 | bonus_box = np.asarray(self.hand_craft_feats['box']) 27 | 28 | self.cond_vector = (player_x >= self.player_x_min) * (player_x <= self.player_x_max) * \ 29 | (player_y >= self.player_y_min) * (player_y <= self.player_y_max) * \ 30 | (bricks >= self.bricks_min) * (bricks <= self.bricks_max) * \ 31 | (enemies_dist >= self.enemy_distance_min) * (enemies_dist <= self.enemy_distance_max) * \ 32 | (nb_lives >= self.lives_min) * (nb_lives <= self.lives_max) 33 | 34 | # player dir 35 | dirs_mask = np.zeros_like(self.cond_vector) 36 | for i,val in enumerate(self.player_dir_check_button.get_status()): 37 | if val: 38 | dirs_mask += self.player_dir_mat[:,i] 39 | 40 | self.cond_vector = self.cond_vector * dirs_mask 41 | 42 | # ghost mode 43 | if self.ghost_check_button.get_status()[0] == 0: # filter out states with "no-ghost" 44 | self.cond_vector = self.cond_vector * ghost_mode 45 | elif self.ghost_check_button.get_status()[1] == 0: # filter out states with "ghost" mode 46 | self.cond_vector = self.cond_vector * (1-ghost_mode) 47 | 48 | # bonus box 49 | if self.box_check_button.get_status()[0] == 0: # filter out states with "no-box" 50 | self.cond_vector = self.cond_vector * bonus_box 51 | elif self.box_check_button.get_status()[1] == 0: # filter out states with "box" 52 | self.cond_vector = self.cond_vector * (1-bonus_box) 53 | 54 | self.cond_vector = self.cond_vector.astype(int) 55 | 56 | def value_colored_frame(self): 57 | # value colored frame 58 | fig5 = plt.figure('value colored pacman frame') 59 | ax_5 = fig5.add_subplot(111) 60 | 61 | v_mat = np.zeros((210,160)) 62 | count_mat = np.zeros((210,160)) 63 | player_x = self.state_labels[:,0] 64 | player_y = self.state_labels[:,1] 65 | 66 | for x,y,v in zip(player_x, player_y, self.V): 67 | if math.isnan(x) or math.isnan(y): 68 | continue 69 | v_mat[int(y),int(x)] += v 70 | count_mat[int(y),int(x)] += 1 71 | 72 | v_mat = v_mat / count_mat 73 | 74 | ax_5.imshow(v_mat, interpolation='spline36') -------------------------------------------------------------------------------- /graying_the_box/hand_crafted_features/add_seaquest_buttons.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def add_game_buttons(self, top_y=0.69, x_left = 0.68): 4 | self.add_slider_button([x_left, top_y , 0.08, 0.01], 'shooter_x', 0, 160) 5 | self.add_slider_button([x_left, top_y-0.04 , 0.08, 0.01], 'shooter_y', 0, 210) 6 | self.add_slider_button([x_left, top_y-0.08 , 0.08, 0.01], 'oxygen', 0, 1) 7 | self.add_slider_button([x_left, top_y-0.12 , 0.08, 0.01], 'divers', 0, 6) 8 | self.add_slider_button([x_left, top_y-0.16 , 0.08, 0.01], 'taken_divers', 0, 3) 9 | self.add_slider_button([x_left, top_y-0.20 , 0.08, 0.01], 'enemies', 0, 8) 10 | self.add_slider_button([x_left, top_y-0.24 , 0.08, 0.01], 'lives', 0, 3) 11 | self.add_check_button([x_left, top_y-0.32 , 0.06, 0.04], 'shooter_dir', ('dont-care','down','up'), (True, True, True)) 12 | 13 | self.shooter_dir_mat = np.zeros(shape=(self.num_points,3), dtype='bool') 14 | 15 | for i,dir in enumerate(self.hand_craft_feats['shooter_dir']): 16 | self.shooter_dir_mat[i,int(dir)] = 1 17 | 18 | def update_cond_vector(self): 19 | shooter_x = np.asarray([row[0] for row in self.hand_craft_feats['shooter_pos']]) 20 | shooter_y = np.asarray([row[1] for row in self.hand_craft_feats['shooter_pos']]) 21 | shooter_dir = 
self.hand_craft_feats['shooter_dir'] 22 | oxygen = np.asarray(self.hand_craft_feats['oxygen']) 23 | nb_divers = np.asarray(self.hand_craft_feats['divers']) 24 | nb_taken_divers = np.asarray(self.hand_craft_feats['taken_divers']) 25 | nb_enemies = np.asarray(self.hand_craft_feats['enemies']) 26 | nb_lives = np.asarray(self.hand_craft_feats['lives']) 27 | 28 | self.cond_vector = (shooter_x >= self.shooter_x_min) * (shooter_x <= self.shooter_x_max) * \ 29 | (shooter_y >= self.shooter_y_min) * (shooter_y <= self.shooter_y_max) * \ 30 | (oxygen >= self.oxygen_min) * (oxygen <= self.oxygen_max) * \ 31 | (nb_divers >= self.divers_min) * (nb_divers <= self.divers_max) * \ 32 | (nb_taken_divers >= self.taken_divers_min) * (nb_taken_divers <= self.taken_divers_max) *\ 33 | (nb_enemies >= self.enemies_min) * (nb_enemies <= self.enemies_max) * \ 34 | (nb_lives >= self.lives_min) * (nb_lives <= self.lives_max) 35 | 36 | # shtr dir 37 | dirs_mask = np.zeros_like(self.cond_vector) 38 | for i,val in enumerate(self.shooter_dir_check_button.get_status()): 39 | if val: 40 | dirs_mask += self.shooter_dir_mat[:,i] 41 | 42 | self.cond_vector = self.cond_vector * dirs_mask 43 | 44 | self.cond_vector = self.cond_vector.astype(int) 45 | -------------------------------------------------------------------------------- /graying_the_box/hand_crafted_features/label_states_breakout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | def label_states(states, screens, termination_mat, debug_mode, num_lives): 5 | 6 | im_size = np.sqrt(states.shape[1]) 7 | states = np.reshape(states, (states.shape[0], im_size, im_size)).astype('int16') 8 | 9 | screens = np.reshape(np.transpose(screens), (3,210,160,-1)) 10 | 11 | screens = np.transpose(screens,(3,1,2,0)) 12 | 13 | # masks 14 | ball_mask = np.ones_like(screens[0]) 15 | ball_mask[189:] = 0 16 | ball_mask[57:63] = 0 17 | 18 | ball_x_ = 80 19 | ball_y_ = 105 20 | 21 | td_mask = np.ones_like(screens[0]) 22 | td_mask[189:] = 0 23 | td_mask[:25] = 0 24 | 25 | features = { 26 | 'ball_pos': [[-1,-1],[-1,-1]], 27 | 'ball_dir': [-1,-1], 28 | 'racket': [-1,-1], 29 | 'missing_bricks': [0,0], 30 | 'hole': [0,0], 31 | 'traj': [0,0], 32 | 'time': [0,0] 33 | } 34 | 35 | if debug_mode: 36 | fig1 = plt.figure('screens') 37 | ax1 = fig1.add_subplot(111) 38 | screen_plt = ax1.imshow(screens[0], interpolation='none') 39 | plt.ion() 40 | plt.show() 41 | 42 | traj_id = 0 43 | time = 0 44 | strike_counter = 0 45 | s_ = screens[1] 46 | for i,s in enumerate(screens[2:]): 47 | 48 | #0. TD 49 | tdiff = (s - s_) * td_mask 50 | s_ = s 51 | 52 | row_ind, col_ind = np.nonzero(tdiff[:,:,0]) 53 | ball_y = np.mean(row_ind) 54 | ball_x = np.mean(col_ind) 55 | 56 | # #1. ball location 57 | red_ch = s[:,:,0] 58 | # is_red = 255 * (red_ch == 200) 59 | # ball_filtered = np.zeros_like(s) 60 | # ball_filtered[:,:,0] = is_red 61 | # ball_filtered = ball_mask * ball_filtered 62 | # 63 | # row_ind, col_ind = np.nonzero(ball_filtered[:,:,0]) 64 | # 65 | # ball_y = np.mean(row_ind) 66 | # ball_x = np.mean(col_ind) 67 | 68 | #2. ball direction 69 | ball_dir = 0 * (ball_x >= ball_x_ and ball_y >= ball_y_) +\ 70 | 1 * (ball_x >= ball_x_ and ball_y < ball_y_) +\ 71 | 2 * (ball_x < ball_x_ and ball_y >= ball_y_) +\ 72 | 3 * (ball_x < ball_x_ and ball_y < ball_y_) 73 | 74 | ball_x_ = ball_x 75 | ball_y_ = ball_y 76 | 77 | #3. 
racket position 78 | is_red = 255 * (red_ch[190,8:-8] == 200) 79 | racket_x = np.mean(np.nonzero(is_red)) + 8 80 | 81 | #4. number of bricks 82 | z = red_ch[57:92,8:-8].flatten() 83 | is_brick = np.sum(1*(z>0) + 0*(z==0)) 84 | 85 | missing_bricks = (len(z) - is_brick)/40. 86 | 87 | #5. holes 88 | brick_strip = red_ch[57:92,8:-8] 89 | brick_row_sum = brick_strip.sum(axis=0) 90 | has_hole = np.any((brick_row_sum==0)) 91 | 92 | #6. traj_id 93 | if termination_mat[i] > 0: 94 | strike_counter+=1 95 | if strike_counter%num_lives==0: 96 | traj_id += 1 97 | time = 0 98 | time += 1 99 | 100 | if debug_mode: 101 | screen_plt.set_data(s) 102 | buf_line = ('Exqample %d: ball pos (x,y): (%0.2f, %0.2f), ball direct: %d, racket pos: (%0.2f), number of missing bricks: %d, has a hole: %d, traj id: %d, time: %d, st_cnt: %d') % \ 103 | (i, ball_x, ball_y, ball_dir, racket_x, missing_bricks, has_hole, traj_id, time, strike_counter) 104 | print buf_line 105 | plt.pause(0.001) 106 | 107 | # labels[i] = (ball_x, ball_y, ball_dir, racket_x, missing_bricks, has_hole, traj_id, time) 108 | 109 | features['ball_pos'].append([ball_x, ball_y]) 110 | features['ball_dir'].append(ball_dir) 111 | features['racket'].append(racket_x) 112 | features['missing_bricks'].append(missing_bricks) 113 | features['hole'].append(has_hole) 114 | features['traj'].append(traj_id) 115 | features['time'].append(time) 116 | features['n_trajs'] = traj_id 117 | 118 | return features -------------------------------------------------------------------------------- /graying_the_box/hand_crafted_features/label_states_packman.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import scipy.signal 4 | import scipy.ndimage 5 | 6 | def label_states(states, screens, termination_mat, debug_mode, num_lives): 7 | 8 | screens = np.reshape(np.transpose(screens), (3,210,160,-1)) 9 | 10 | screens = np.transpose(screens,(3,1,2,0)) 11 | 12 | features = { 13 | 'player_pos': [[-1,-1],[-1,-1]], 14 | 'player_dir': [-1,-1], 15 | 'bricks': [-1,-1], 16 | 'lives': [3,3], 17 | 'sweets': [[1,1,1,1],[1,1,1,1]], 18 | 'ghost': [-1,-1], 19 | 'enemy_distance': [-1,-1], 20 | 'box': [0,0] 21 | } 22 | 23 | # player mask 24 | player_mask = np.ones((210,160)) 25 | player_mask[91:96,78:82] = 0 26 | 27 | # number of bricks 28 | brown_frame = 1 * (screens[0,:,:,0]==162) 29 | 30 | small_brick_mask = np.asarray([[45,45, 45, 45, 45, 45], 31 | [45,162,162,162,162,45], 32 | [45,162,162,162,162,45], 33 | [45,45, 45, 45, 45, 45]]) 34 | 35 | wide_brick_mask = np.asarray([[45,45, 45, 45, 45, 45, 45, 45, 45, 45], 36 | [45,162,162,162,162,162,162,162,162,45], 37 | [45,162,162,162,162,162,162,162,162,45], 38 | [45,45, 45, 45, 45, 45, 45, 45, 45, 45]]) 39 | 40 | # sweets 41 | sweets_y = [22,142,22,142] 42 | sweets_x = [6,6,151,151] 43 | 44 | hist_depth = 16 45 | sweets_mat = np.zeros((hist_depth,4)) 46 | sweets_vec_ = np.zeros((4,)) 47 | bonus_times = [[] for i in range(4)] 48 | 49 | # left bottom brick mask = [139:149,5:11,0] 50 | # top left brick mask = [19:29,5:11,0] 51 | # top right brick mask = [19:29,149:155,0] 52 | # bottom right brick mask = [139:149,149:155,0] 53 | brick_color_1 = 45 * np.ones((10,6)) 54 | brick_color_1[1:-1,1:-1] = 180 55 | brick_color_2 = 45 * np.ones((10,6)) 56 | brick_color_2[1:-1,1:-1] = 149 57 | brick_color_3 = 45 * np.ones((10,6)) 58 | brick_color_3[1:-1,1:-1] = 212 59 | brick_color_4 = 45 * np.ones((10,6)) 60 | brick_color_4[1:-1,1:-1] = 232 61 | brick_color_5 = 45 * 
np.ones((10,6)) 62 | brick_color_5[1:-1,1:-1] = 204 63 | #enemies mask 64 | enemies_mask = np.ones((210,160)) 65 | enemies_mask[20:28,6:10] = 0 66 | enemies_mask[140:148,6:10] = 0 67 | enemies_mask[20:28,150:154] = 0 68 | enemies_mask[140:148,150:154] = 0 69 | 70 | # bonus box 71 | bonus_color_mask = 162 * np.ones((7,6)) 72 | bonus_color_mask[1:-1,1:-1] = 210 73 | 74 | if debug_mode: 75 | fig1 = plt.figure('screens') 76 | ax1 = fig1.add_subplot(111) 77 | screen_plt = ax1.imshow(screens[0], interpolation='none') 78 | fig2 = plt.figure('enemies') 79 | ax2 = fig2.add_subplot(111) 80 | brown_plt = ax2.imshow(brown_frame) 81 | # fig3 = plt.figure('brick heat map') 82 | # ax3 = fig3.add_subplot(111) 83 | # bricks_plt = ax3.imshow(brown_frame, interpolation='none') 84 | plt.ion() 85 | plt.show() 86 | 87 | player_x_ = 79 88 | player_y_ = 133 89 | 90 | restarted_flag = 0 91 | ttime = 0 92 | 93 | for i,s in enumerate(screens[2:]): 94 | # for i,s in enumerate(screens[750:]): 95 | 96 | # 1. player location 97 | player_frame = s[:,:,0]==210 * player_mask 98 | row_ind, col_ind = np.nonzero(player_frame) 99 | player_y = np.mean(row_ind) 100 | player_x = np.mean(col_ind) 101 | 102 | # 2. player direction 103 | dx = player_x - player_x_ 104 | dy = player_y - player_y_ 105 | 106 | player_dir = 0 107 | if abs(dx) >= abs(dy): 108 | if dx > 0: 109 | player_dir = 1 110 | elif dx<0: 111 | player_dir = 2 112 | else: 113 | if dy > 0: 114 | player_dir = 3 115 | elif dy < 0: 116 | player_dir = 4 117 | 118 | player_x_ = player_x 119 | player_y_ = player_y 120 | 121 | # 3. number of bricks to eat 122 | # approximated (fast) 123 | brown_frame = 1 * (s[:,:,0]==162) 124 | brown_sum = np.sum(brown_frame) 125 | nb_bricks_apprx = (brown_sum - 6042)/8. 126 | nb_bricks_apprx = np.maximum(nb_bricks_apprx,0) 127 | 128 | # exact (slow) 129 | # nb_bricks = 0 130 | # for r in range(160): 131 | # for c in range (150): 132 | # small_slice = s[r:r+4,c:c+6,0] 133 | # wide_slice = s[r:r+4,c:c+10,0] 134 | # if (small_slice == small_brick_mask).all(): 135 | # nb_bricks += 1 136 | # elif (wide_slice == wide_brick_mask).all(): 137 | # nb_bricks += 1 138 | 139 | # 4. number of lives 140 | lives_strip = s[184:189,:,0] 141 | _, nb_lives = scipy.ndimage.label(lives_strip) 142 | 143 | # 5. 
sweets 144 | # look for bottom left brick 145 | sweets_mat[i%hist_depth,0] = 1 * np.all(s[139:149,5:11,0] == brick_color_1) + \ 146 | 1 * np.all(s[139:149,5:11,0] == brick_color_2) + \ 147 | 1 * np.all(s[139:149,5:11,0] == brick_color_3) + \ 148 | 1 * np.all(s[139:149,5:11,0] == brick_color_4) + \ 149 | 1 * np.all(s[139:149,5:11,0] == brick_color_5) 150 | 151 | 152 | # look for top right brick 153 | sweets_mat[i%hist_depth,1] = 1 * np.all(s[19:29,149:155,0] == brick_color_1) + \ 154 | 1 * np.all(s[19:29,149:155,0] == brick_color_2) + \ 155 | 1 * np.all(s[19:29,149:155,0] == brick_color_3) + \ 156 | 1 * np.all(s[19:29,149:155,0] == brick_color_4) + \ 157 | 1 * np.all(s[19:29,149:155,0] == brick_color_5) 158 | 159 | # look for bottom right brick 160 | sweets_mat[i%hist_depth,2] = 1 * np.all(s[139:149,149:155,0] == brick_color_1) + \ 161 | 1 * np.all(s[139:149,149:155,0] == brick_color_2) + \ 162 | 1 * np.all(s[139:149,149:155,0] == brick_color_3) + \ 163 | 1 * np.all(s[139:149,149:155,0] == brick_color_4) + \ 164 | 1 * np.all(s[139:149,149:155,0] == brick_color_5) 165 | 166 | # look for top left brick 167 | sweets_mat[i%hist_depth,3] = 1 * np.all(s[19:29,5:11,0] == brick_color_1) + \ 168 | 1 * np.all(s[19:29,5:11,0] == brick_color_2) + \ 169 | 1 * np.all(s[19:29,5:11,0] == brick_color_3) + \ 170 | 1 * np.all(s[19:29,5:11,0] == brick_color_4) + \ 171 | 1 * np.all(s[19:29,5:11,0] == brick_color_5) 172 | 173 | sweets_vec = 1 * np.any(sweets_mat, axis=0) 174 | 175 | if np.all(sweets_vec_ == 1) and restarted_flag == 0: 176 | ttime = 0 177 | restarted_flag = 1 178 | if not np.all(sweets_vec_ == 1): 179 | restarted_flag = 0 180 | ttime += 1 181 | 182 | # 6. ghost mode 183 | ghost_mode = np.any(1 * (s[:,:,0] == 149) + 1 * (s[:,:,0] == 212)) 184 | 185 | # 7. enemies map 186 | enemies_map = 1 * (s[:,:,0] == 180) + \ 187 | 1 * (s[:,:,0] == 149) + \ 188 | 1 * (s[:,:,0] == 212) + \ 189 | 1 * (s[:,:,0] == 128) + \ 190 | 1 * (s[:,:,0] == 232) + \ 191 | 1 * (s[:,:,0] == 204) 192 | enemies_map = enemies_map * enemies_mask 193 | 194 | enemies_dilate_map = scipy.ndimage.binary_dilation(enemies_map, iterations=1) 195 | labeled_array, nb_enemies = scipy.ndimage.label(enemies_dilate_map) 196 | 197 | min_dist = 999 198 | for j in range(nb_enemies): 199 | enemy_j_rows, enemy_j_cols = np.nonzero(labeled_array==j+1) 200 | e_j_y = np.mean(enemy_j_rows) 201 | e_j_x = np.mean(enemy_j_cols) 202 | dist_j = abs(player_x - e_j_x) + abs(player_y - e_j_y) 203 | if dist_j < min_dist: 204 | min_dist = dist_j 205 | 206 | # 8. bonus box 207 | # has_box = np.any(s[91:96,78:82,0]==210) 208 | has_box = 1 * np.all(s[90:97,77:83,0] == bonus_color_mask) 209 | 210 | #6. 
bonus bricks option 211 | # bottom left brick 212 | if sweets_vec[0] == 0 and sweets_vec_[0] == 1: 213 | bonus_times[0].append(ttime) 214 | # top right brick 215 | elif sweets_vec[1] == 0 and sweets_vec_[1] == 1: 216 | bonus_times[1].append(ttime) 217 | # bottom right brick 218 | elif sweets_vec[2] == 0 and sweets_vec_[2] == 1: 219 | bonus_times[2].append(ttime) 220 | # top left brick 221 | elif sweets_vec[3] == 0 and sweets_vec_[3] == 1: 222 | bonus_times[3].append(ttime) 223 | 224 | sweets_vec_ = sweets_vec 225 | 226 | if debug_mode: 227 | screen_plt.set_data(s) 228 | brown_plt.set_data(brown_frame) 229 | # bricks_plt.set_data(brown_frame) 230 | buf_line = ('Exqample %d: Player (x,y): (%0.2f,%0.2f), Player dir: %d, number of bricks: %d, number of lives: %d, sweets_vec: %s, ghost mode: %d, minimal enemy distance: %0.2f, bonus box: %d, time: %d') %\ 231 | (i, player_x, player_y, player_dir, nb_bricks_apprx, nb_lives, sweets_vec, ghost_mode, min_dist, has_box, ttime) 232 | print buf_line 233 | plt.pause(0.01) 234 | 235 | features['player_pos'].append([player_x, player_y]) 236 | features['player_dir'].append(player_dir) 237 | features['bricks'].append(nb_bricks_apprx) 238 | features['sweets'].append(sweets_vec) 239 | features['ghost'].append(ghost_mode) 240 | features['lives'].append(nb_lives) 241 | features['enemy_distance'].append(min_dist) 242 | features['box'].append(has_box) 243 | 244 | # histogram of sweets cllection times 245 | if 0: 246 | brick_bl_times = np.asarray(bonus_times[0]) 247 | brick_tr_times = np.asarray(bonus_times[1]) 248 | brick_br_times = np.asarray(bonus_times[2]) 249 | brick_tl_times = np.asarray(bonus_times[3]) 250 | 251 | h_bl = plt.hist(brick_bl_times, bins = 70, range=(0,800), normed=1, facecolor='black', alpha=0.75, label='bottom left') 252 | h_tr = plt.hist(brick_tr_times, bins = 70, range=(0,800), normed=1, facecolor='red', alpha=0.75, label='top right') 253 | h_br = plt.hist(brick_br_times, bins = 70, range=(0,800), normed=1, facecolor='green', alpha=0.75, label='bottom right') 254 | h_tl = plt.hist(brick_tl_times, bins = 70, range=(0,800), normed=1, facecolor='blue', alpha=0.75, label='top left') 255 | 256 | plt.legend(prop={'size':20}) 257 | plt.tick_params(axis='both',which='both',left='off',right='off', labelsize=20) 258 | plt.locator_params(axis='x',nbins=4) 259 | plt.locator_params(axis='y',nbins=8) 260 | 261 | plt.show() 262 | 263 | return features -------------------------------------------------------------------------------- /graying_the_box/hand_crafted_features/label_states_seaquest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import scipy.signal 4 | import scipy.ndimage 5 | 6 | def label_states(states, screens, termination_mat, debug_mode, num_lives): 7 | 8 | screens = np.reshape(np.transpose(screens), (3,210,160,-1)) 9 | 10 | screens = np.transpose(screens,(3,1,2,0)) 11 | 12 | features = { 13 | 'shooter_pos': [[-1,-1],[-1,-1]], 14 | 'shooter_dir': [-1,-1], 15 | 'racket': [-1,-1], 16 | 'oxygen': [0,0], 17 | 'divers': [0,0], 18 | 'taken_divers': [0,0], 19 | 'enemies': [0,0], 20 | 'lives': [3,3], 21 | } 22 | 23 | # shooter convolution 24 | yellow_frame = screens[0,:,:,0]==187 25 | shtr_mask = np.asarray([[1,1,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]]) 26 | shooter_conv_map = scipy.signal.convolve2d(yellow_frame, shtr_mask, mode='same') 27 | shtr_y_, shtr_x_ = np.unravel_index(np.argmax(shooter_conv_map),(210,160)) 28 | 29 | #divers 30 | 
divers_frame = np.ones_like(screens[0,:,:,0]) 31 | divers_frame[:160,:] = 1 * (screens[0,:160,:,0]==66) 32 | 33 | if debug_mode: 34 | fig1 = plt.figure('screens') 35 | ax1 = fig1.add_subplot(111) 36 | screen_plt = ax1.imshow(screens[0], interpolation='none') 37 | fig2 = plt.figure('divers') 38 | ax2 = fig2.add_subplot(111) 39 | divers_plt = ax2.imshow(divers_frame) 40 | # fig3 = plt.figure('enemies conv') 41 | # ax3 = fig3.add_subplot(111) 42 | # diver_conv_plt = ax3.imshow(diver_conv_map, interpolation='none') 43 | plt.ion() 44 | plt.show() 45 | 46 | shtr_dir_ = 0 47 | shtr_dir__ = 0 48 | 49 | for i,s in enumerate(screens[2:]): 50 | # for i,s in enumerate(screens[13210:]): 51 | 52 | # 1. shooter location 53 | yellow_frame = s[:,:,0]==187 54 | shooter_conv_map = scipy.signal.convolve2d(yellow_frame, shtr_mask,mode='same') 55 | shtr_y, shtr_x = np.unravel_index(np.argmax(shooter_conv_map),(210,160)) 56 | 57 | # if shtr_x==0 : 58 | # shtr_x = -1 59 | # if shtr_y==0: 60 | # shtr_y = -1 61 | 62 | # 2. shooter direction 63 | shtr_dir = 1 * (shtr_y >= shtr_y_) +\ 64 | 2 * (shtr_y < shtr_y_) 65 | 66 | coherent = (shtr_dir == shtr_dir_) * (shtr_dir == shtr_dir__) 67 | shtr_dir__ = shtr_dir_ 68 | shtr_dir_ = shtr_dir 69 | 70 | shtr_dir = shtr_dir * coherent 71 | 72 | shtr_x_ = shtr_x 73 | shtr_y_ = shtr_y 74 | 75 | # 3. oxygen 76 | oxgn_line = s[172,49:111,0] 77 | oxgn_lvl = np.sum(1*(oxgn_line==214))/float((111-49)) 78 | 79 | # 4. avaiable divers 80 | divers_frame = 0 * divers_frame 81 | divers_frame[:160,:] = 1 * (s[:160,:,0] == 66) 82 | erote_map = scipy.ndimage.binary_erosion(divers_frame, structure=np.ones((2,2)), iterations=1) 83 | diver_conv_map = scipy.ndimage.binary_dilation(erote_map, iterations=3) 84 | _, nb_divers = scipy.ndimage.label(diver_conv_map) 85 | 86 | # 5. taken divers 87 | # divers_frame = 0 * divers_frame 88 | # divers_frame[178:188,:] = 1 * (s[178:188,:,0] == 24) 89 | # diver_conv_map = scipy.ndimage.binary_dilation(divers_frame, iterations=1) 90 | # _, nb_taken_divers = scipy.ndimage.label(diver_conv_map) 91 | nb_taken_divers = np.sum(1 * s[178,:,0] == 24) 92 | 93 | # 6. 
enemies 94 | enemies_frame = 0 * divers_frame 95 | enemies_frame[:160,:] = 1 * (s[:160,:,0] == 92) + \ 96 | 2 * (s[:160,:,0] == 160) + \ 97 | 3 * (s[:160,:,0] == 170) + \ 98 | 4 * (s[:160,:,0] == 198) 99 | 100 | enemies_conv_map = scipy.ndimage.binary_dilation(enemies_frame, iterations=2) 101 | _, nb_enemies = scipy.ndimage.label(enemies_conv_map) 102 | 103 | # lives 104 | lives_slice = 1 * (s[18:30,:,0] == 210) 105 | _, nb_lives = scipy.ndimage.label(lives_slice) 106 | 107 | 108 | if debug_mode: 109 | screen_plt.set_data(s) 110 | # divers_plt.set_data(diver_conv_map) 111 | # diver_conv_plt.set_data(diver_conv_map) 112 | buf_line = ('Exqample %d: shooter (x,y): (%0.2f, %0.2f), shooter dir: %d, oxygen level: %0.2f, divers: %d, taken divers: %d, enemies: %d, lives: %d') %\ 113 | (i, shtr_x, shtr_y, shtr_dir, oxgn_lvl, nb_divers, nb_taken_divers, nb_enemies, nb_lives) 114 | print buf_line 115 | plt.pause(0.01) 116 | 117 | features['shooter_pos'].append([shtr_x, shtr_y]) 118 | features['shooter_dir'].append(shtr_dir) 119 | features['oxygen'].append(oxgn_lvl) 120 | features['divers'].append(nb_divers) 121 | features['taken_divers'].append(nb_taken_divers) 122 | features['enemies'].append(nb_enemies) 123 | features['lives'].append(nb_lives) 124 | 125 | return features -------------------------------------------------------------------------------- /graying_the_box/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('hand_crafted_features') 4 | 5 | from prepare_data import prepare_data 6 | from vis_tool import VIS_TOOL 7 | 8 | # Parameters 9 | run_dir = '/home/deep5/DQN_Shahar_Chen_oldpc/dqn_distill/' 10 | num_frames = 100000 11 | game_id = 0 # 0-breakout, 1-seaquest, 2-pacman 12 | load_data = 0 13 | debug_mode = 0 14 | cluster_method = 0 # 0-kmeans, 1-spectral_clustering, 2-EMHC (entropy minimization hierarchical clustering) 15 | n_clusters = 4 16 | window_size = 2 17 | n_iters = 8 18 | entropy_iters = 0 19 | 20 | cluster_params = { 21 | 'method': cluster_method, 22 | 'n_clusters': n_clusters, 23 | 'window_size': window_size, 24 | 'n_iters': n_iters, 25 | 'entropy_iters': entropy_iters 26 | } 27 | global_feats, hand_crafted_feats = prepare_data(game_id, run_dir, num_frames, load_data, debug_mode) 28 | 29 | vis_tool = VIS_TOOL(global_feats=global_feats, hand_craft_feats=hand_crafted_feats, game_id=game_id, cluster_params=cluster_params) 30 | 31 | vis_tool.show() 32 | -------------------------------------------------------------------------------- /graying_the_box/others/SaliencyScore.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/home/tom/OpenBox/bhtsne/') 3 | 4 | import numpy as np 5 | import h5py 6 | import matplotlib.image as mpimg 7 | 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | numframes = 30000 12 | Seaquestind = 352 13 | Breakoutind = 56 14 | Pacmanind = 1316 15 | 16 | im_size = 84 17 | 18 | 19 | print "loading states... 
" 20 | 21 | Breakout_state_file = h5py.File('/home/tom/OpenBox/tsne_res/6k/states.h5', 'r') 22 | Breakout_state_mat = Breakout_state_file['data'] 23 | Breakout_states = Breakout_state_mat[:numframes] 24 | Breakout_states = np.reshape(Breakout_states, (Breakout_states.shape[0], im_size,im_size)) 25 | 26 | Seaquest_state_file = h5py.File('/home/tom/OpenBox/tsne_res/seaquest/13k/states.h5', 'r') 27 | Seaquest_state_mat = Seaquest_state_file['data'] 28 | Seaquest_states = Seaquest_state_mat[:numframes] 29 | Seaquest_states = np.reshape(Seaquest_states, (Seaquest_states.shape[0], im_size,im_size)) 30 | 31 | Pacman_state_file = h5py.File('/home/tom/OpenBox/tsne_res/pacman/7k/states.h5', 'r') 32 | Pacman_state_mat = Pacman_state_file['data'] 33 | Pacman_states = Pacman_state_mat[:numframes] 34 | Pacman_states = np.reshape(Pacman_states, (Pacman_states.shape[0], im_size,im_size)) 35 | print "loading grads... " 36 | thresh = 0.1 37 | 38 | Breakout_grad_file = h5py.File('/home/tom/OpenBox/tsne_res/6k/grads.h5', 'r') 39 | Breakout_grad_mat = Breakout_grad_file['data'] 40 | Breakout_grads = Breakout_grad_mat[:numframes] 41 | Breakout_grads[np.abs(Breakout_grads)0] 92 | Seaquest_term = Seaquest_tsne[:,Seaquest_term>0] 93 | Pacman_term = Pacman_tsne[:,Pacman_term>0] 94 | 95 | axs.flat[0].annotate( 96 | 'Termination', 97 | xy = (Breakout_term[0,0], Breakout_term[1,0]), xytext = (-20, 20),size=8, 98 | textcoords = 'offset points', ha = 'right', va = 'bottom', 99 | bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5), 100 | arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0')) 101 | 102 | axs.flat[1].annotate( 103 | 'Termination', 104 | xy = (Seaquest_term[0,0], Seaquest_term[1,0]), xytext = (-20, 20),size=8, 105 | textcoords = 'offset points', ha = 'right', va = 'bottom', 106 | bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5), 107 | arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0')) 108 | 109 | axs.flat[2].annotate( 110 | 'Termination', 111 | xy = (Pacman_term[0,0], Pacman_term[1,0]), xytext = (-20, 20),size=8, 112 | textcoords = 'offset points', ha = 'right', va = 'bottom', 113 | bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5), 114 | arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0')) 115 | 116 | plt.show() 117 | -------------------------------------------------------------------------------- /graying_the_box/prepare_data.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | from prepare_global_features import prepare_global_features 3 | 4 | def prepare_data(game_id, run_dir, num_frames, load_data, debug_mode): 5 | # 1. switch games TEST 6 | if game_id == 0: #'breakout' 7 | num_actions = 10 8 | num_lives = 1 9 | data_dir = run_dir 10 | from label_states_breakout import label_states 11 | grad_thresh = 0.1 12 | 13 | elif game_id == 1: #'seaquest' 14 | num_actions = 18 15 | num_lives = 3 16 | data_dir = 'data/' + 'seaquest' + '/' + run_dir + '/' 17 | from label_states_seaquest import label_states 18 | grad_thresh = 0.05 19 | 20 | elif game_id == 2: #'pacman' 21 | num_actions = 5 22 | num_lives = 3 23 | data_dir = 'data/' + 'pacman' + '/' + run_dir + '/' 24 | from label_states_packman import label_states 25 | grad_thresh = 0.05 26 | 27 | # 2. 
load data 28 | if load_data: 29 | # 2.1 global features 30 | global_feats = pickle.load(file(data_dir + 'global_features.bin','rb')) 31 | 32 | # 2.2 hand craft features 33 | hand_craft_feats = pickle.load(file(data_dir + 'hand_craft_features.bin','rb')) 34 | 35 | # 3. prepare data 36 | else: 37 | # 3.1 global features 38 | global_feats = prepare_global_features(data_dir, num_frames, num_actions, num_lives, grad_thresh) 39 | # pickle.dump(global_feats,file(data_dir + 'global_features.bin','wb')) 40 | 41 | # 3.2 hand craft features 42 | hand_craft_feats = label_states(global_feats['states'], global_feats['screens'], global_feats['termination'], debug_mode=debug_mode, num_lives=num_lives) 43 | # pickle.dump(hand_craft_feats,file(data_dir + 'hand_craft_features.bin','wb')) 44 | 45 | return global_feats, hand_craft_feats 46 | -------------------------------------------------------------------------------- /graying_the_box/prepare_global_features.py: -------------------------------------------------------------------------------- 1 | import common 2 | import numpy as np 3 | import h5py 4 | 5 | def prepare_global_features(data_dir, num_frames, num_actions, num_lives, grad_thresh): 6 | 7 | print "Preparing global features... " 8 | 9 | # load hdf5 files 10 | print "Loading from hdf5 files... " 11 | states_hd5f = common.load_hdf5('statesClean', data_dir, num_frames) 12 | screens_hdf5 = np.ones((3, states_hd5f.shape[0], states_hd5f.shape[1])) 13 | screens_hdf5[0, :, :] = states_hd5f[:, :] 14 | screens_hdf5[1, :, :] = states_hd5f[:, :] 15 | screens_hdf5[2, :, :] = states_hd5f[:, :] 16 | #termination_hdf5 = common.load_hdf5('terminationClean', data_dir, num_frames) 17 | lowd_activation_hdf5 = common.load_hdf5('lowd_activations_800', data_dir, num_frames) 18 | lowd_activation_3d_hdf5 = np.zeros((lowd_activation_hdf5.shape[0], 3)) # common.load_hdf5('lowd_activations3d', data_dir, num_frames) 19 | q_hdf5 = common.load_hdf5('qvalsClean', data_dir, num_frames) 20 | a_hdf5 = np.array(common.load_hdf5('actionsClean', data_dir, num_frames)) 21 | reward_hdf5 = common.load_hdf5('rewardClean', data_dir, num_frames) 22 | with h5py.File('/home/deep5/DQN_Shahar_Chen_oldpc/dqn_distill/tsneMatch.h5', 'r') as hf: 23 | gaus_clust = (np.array(hf.get('data'))) 24 | with h5py.File('/home/deep5/DQN_Shahar_Chen_oldpc/dqn_distill/tsneRooms.h5', 'r') as hf: 25 | rooms = (np.array(hf.get('data'))) 26 | with h5py.File('/home/deep5/DQN_Shahar_Chen_oldpc/dqn_distill/tsneData.h5', 'r') as hf: 27 | single_clusters = (np.array(hf.get('data'))) 28 | 29 | # grads_hdf5 = common.load_hdf5('grads', data_dir, num_frames) 30 | grads_hdf5 = states_hd5f # take state image instead of saliency image until we have saliency for all data 31 | 32 | termination_hdf5 = np.sign(np.array(reward_hdf5)) + 1 33 | 34 | 35 | # 9. 
get Q and action, translate to numpy 36 | V = np.zeros(shape=num_frames) 37 | Q = np.zeros(shape=(num_frames,num_actions)) 38 | tsne = np.zeros(shape=(num_frames,2)) 39 | a = np.zeros(shape=num_frames) 40 | term = np.zeros(shape=num_frames) 41 | reward = np.zeros(shape=num_frames) 42 | TD = np.zeros(shape=num_frames) 43 | tsne3d = np.zeros(shape=(num_frames,3)) 44 | tsne3d_next = np.zeros(shape=(num_frames,3)) 45 | time = np.zeros(shape=num_frames) 46 | act_rep = np.zeros(shape=num_frames) 47 | trajectory_index = np.zeros(shape=num_frames) 48 | tsne3d_norm = np.zeros(shape=num_frames) 49 | 50 | counter = 0 51 | term_counter = 0 52 | trajectory_counter = 1 53 | 54 | for i in range(1,num_frames-1): 55 | V[i] = q_hdf5[i] 56 | a[i] = float(a_hdf5[i])-1 57 | reward[i] = reward_hdf5[i] 58 | term[i] = termination_hdf5[i] 59 | Q[i] = q_hdf5[i] 60 | tsne[i] = lowd_activation_hdf5[i] 61 | tsne3d[i] = lowd_activation_3d_hdf5[i] 62 | tsne3d_next[i] = lowd_activation_3d_hdf5[i+1]-lowd_activation_3d_hdf5[i] 63 | tsne3d_norm[i] = np.linalg.norm(tsne3d_next[i]) 64 | TD[i] = abs(Q[i-1,int(a[i-1])]-0.99*(Q[i,int(a[i])]+reward[i-1])) 65 | 66 | # calculate time and trajectory index 67 | time[i] = counter 68 | trajectory_index[i] = trajectory_counter 69 | 70 | if (term[i] > 0): # the fifth terminal 71 | term_counter += 1 72 | if term_counter % num_lives == 0: 73 | counter = 0 74 | trajectory_counter+=1 75 | else: 76 | counter += 1 77 | else: 78 | counter += 1 79 | 80 | Advantage = V-Q.T 81 | risk = np.sum(Advantage,axis=0) 82 | term_binary = term 83 | term_binary[np.nonzero(term_binary!=0)]=1 84 | 85 | global_feats = { 86 | 'tsne': tsne, 87 | 'states':states_hd5f, 88 | 'screens':screens_hdf5, 89 | 'value':V, 90 | 'actions':a, 91 | 'termination':term_binary, 92 | 'risk':risk, 93 | 'tsne3d':tsne3d, 94 | 'tsne3d_norm':tsne3d_norm, 95 | 'Advantage':Advantage, 96 | 'time':time, 97 | 'TD':TD, 98 | 'reward':reward, 99 | 'act_rep':act_rep, 100 | 'tsne3d_next':tsne3d_next, 101 | 'grads':grads_hdf5, 102 | 'trajectory_index':trajectory_index, 103 | 'data_dir':data_dir, 104 | 'gauss_clust':gaus_clust, 105 | 'rooms':rooms, 106 | 'single_clusters':single_clusters 107 | } 108 | 109 | return global_feats 110 |
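The SMDP class defined in smdp.py below models the clustered data as a semi-Markov decision process: P is the cluster-to-cluster transition matrix, r the mean discounted reward accumulated while inside a cluster (skill), and skill_time the mean number of frames a skill lasts. Its calc_v_smdp method solves v = (I - diag(gamma**k) P)^{-1} r. A tiny numpy check of that relation, with made-up numbers rather than anything taken from a run:

import numpy as np

gamma = 0.99
P = np.array([[0.0, 1.0],    # toy 2-cluster transition matrix
              [0.7, 0.3]])
r = np.array([1.0, 0.5])     # mean discounted reward per skill
k = np.array([3.0, 5.0])     # mean skill length in frames

GAMMA = np.diag(gamma ** k)
v = np.linalg.solve(np.eye(2) - GAMMA.dot(P), r)
print v                      # SMDP value of each cluster

calc_v_smdp uses a pseudo-inverse instead of solve, which also covers the case where the system is singular.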
-------------------------------------------------------------------------------- /graying_the_box/smdp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg 3 | import scipy.stats 4 | 5 | def divide_tt(X, tt_ratio): 6 | N = X.shape[0] 7 | X_train = X[:int(tt_ratio*N)] 8 | X_test = X[int(tt_ratio*N):] 9 | return X_train, X_test 10 | 11 | class SMDP(object): 12 | def __init__(self, labels, termination, rewards, values, n_clusters, tb=0, gamma=0.99, trunc_th = 0.1, k=5): 13 | 14 | self.k = k 15 | self.gamma = gamma 16 | self.trunc_th = trunc_th 17 | self.rewards = rewards 18 | if tb == 0: 19 | self.labels,self.n_clusters = self.remove_empty_clusters(labels,n_clusters) 20 | else: 21 | self.labels = labels 22 | self.n_clusters = n_clusters 23 | self.termination = termination 24 | self.TT = self.calculate_transition_matrix() 25 | self.P = self.calculate_prob_transition_matrix(self.TT) 26 | if tb == 0: 27 | self.check_empty_P() 28 | self.r, self.skill_time = self.smdp_reward() 29 | self.v_smdp = self.calc_v_smdp(self.P) 30 | self.v_dqn = self.calc_v_dqn(values) 31 | self.score = self.value_score() 32 | self.clusters_count = self.count_clusters() 33 | self.entropy = self.calc_entropy(self.P) 34 | self.edges = self.get_smdp_edges() 35 | 36 | ####### Methods ###### 37 | 38 | def check_empty_P(self): 39 | cluster_ind = 0 40 | for p in (self.P): 41 | if p.sum()==0: 42 | indices = np.nonzero(self.labels==cluster_ind)[0] 43 | for i in indices: 44 | self.labels[i]=self.labels[i-1] 45 | cluster_ind+=1 46 | self.labels, self.n_clusters = self.remove_empty_clusters(self.labels,self.n_clusters) 47 | self.TT = self.calculate_transition_matrix() 48 | self.P = self.calculate_prob_transition_matrix(self.TT) 49 | 50 | def count_clusters(self): 51 | cluster_count = np.zeros(self.n_clusters) 52 | for l in self.labels: 53 | cluster_count[l]+=1 54 | return cluster_count 55 | 56 | def remove_empty_clusters(self,labels,n_c): 57 | # remove empty clusters 58 | labels_i = np.copy(labels) 59 | cluster_flags = np.zeros(n_c,dtype=np.bool) 60 | cluster_count = np.zeros(n_c) 61 | for l in labels_i: 62 | cluster_flags[l] = True 63 | cluster_count[l]+=1 64 | shift_vec = np.cumsum(~cluster_flags) 65 | for ind,l in enumerate(labels_i): 66 | labels_i[ind] -= shift_vec[l] 67 | new_n_clusters = np.max(labels_i)+1 68 | 69 | return labels_i,new_n_clusters 70 | 71 | def calc_entropy(self,P): 72 | e = scipy.stats.entropy(P.T) 73 | e_finite_ind = np.isfinite(e) 74 | entropy = np.average(a=e[e_finite_ind],weights=self.clusters_count[e_finite_ind]) 75 | return entropy 76 | 77 | def truncate(self, M, th): 78 | M[np.nonzero(M < th)] = 0 114 | if t > self.k: 115 | mean_rewards[l_p,0] += total_r #/ (t+1) 116 | mean_rewards[l_p,1] += 1 117 | mean_times[l_p,0] += t 118 | mean_times[l_p,1] += 1 119 | l_p = l 120 | total_r = 0 121 | t = 0 122 | 123 | for mr,mt in zip(mean_rewards, mean_times): 124 | mr[0] = mr[0] / mr[1] 125 | mt[0] = mt[0] / mt[1] 126 | 127 | return mean_rewards[:,0], mean_times[:,0] 128 | 129 | def calc_v_smdp(self,P): 130 | GAMMA = np.diag(self.gamma**self.skill_time) 131 | v = np.dot(np.linalg.pinv(np.eye(self.n_clusters)-np.dot(GAMMA,P)),self.r) 132 | return v 133 | 134 | def calc_v_policy(self,P): 135 | skill_time = np.zeros(shape=(self.n_clusters,1)) 136 | r = np.zeros(shape=(self.n_clusters,1)) 137 | for cluster_ind in xrange(self.n_clusters): 138 | cluster_policy_skills = np.nonzero(P[cluster_ind,:]) 139 | cluster_skills = np.nonzero(self.P[cluster_ind,:]) 140 | n_skills = len(cluster_policy_skills[0]) 141 | 142 | for policy_skill_ind in xrange(n_skills): 143 | next_ind = cluster_policy_skills[0][policy_skill_ind] 144 | skill_ind = np.nonzero(cluster_skills[0]==next_ind)[0][0] 145 | r[cluster_ind] += P[cluster_ind,next_ind]*self.R_skills[cluster_ind][skill_ind,0] 146 | skill_time[cluster_ind] += P[cluster_ind,next_ind]*self.k_skills[cluster_ind][skill_ind,0] 147 | GAMMA = np.diag((self.gamma**skill_time)[:,0]) 148 | v = np.dot(np.linalg.pinv(np.eye(self.n_clusters)-np.dot(GAMMA,P)),r) 149 | return v 150 | 151 | def calc_v_dqn(self, values): 152 | 153 | value_vec = np.zeros(shape=(self.n_clusters,2)) 154 | for i,(l,v) in enumerate(zip(self.labels, values)): 155 | value_vec[l, 0] += v 156 | value_vec[l, 1] += 1 157 | 158 | # 1.5 normalize rewards 159 | for val in value_vec: 160 | val[0] = val[0]/val[1] 161 | 162 | return value_vec[:,0] 163 | 164 | def value_score(self): 165 | # v_dqn = (self.v_dqn-self.v_dqn.mean())/self.v_dqn.std() 166 | # v_smdp = (self.v_smdp-self.v_smdp.mean())/self.v_smdp.std() 167 | v_dqn = self.v_dqn 168 | v_smdp = self.v_smdp 169 | return np.linalg.norm(v_dqn-v_smdp)/np.linalg.norm(v_dqn) 170 | 171 | def policy_improvement(self): 172 | policy = [] 173 | for cluster_ind in xrange(self.n_clusters): 174 | n_skills = len(self.skills[cluster_ind][0]) 175 | val = np.zeros(shape=(n_skills,1)) 176 | for
skill_ind in xrange(n_skills): 177 | r = self.R_skills[cluster_ind][skill_ind,0] 178 | k = self.k_skills[cluster_ind][skill_ind,0] 179 | next_ind = self.skills[cluster_ind][0][skill_ind] 180 | val[skill_ind] = r+(self.gamma**k)*self.v_smdp[next_ind] 181 | policy.append((cluster_ind,self.skills[cluster_ind][0][np.argmax(val)])) 182 | return policy 183 | 184 | def get_smdp_edges(self): 185 | edges = [] 186 | for i in xrange(self.n_clusters): 187 | for j in xrange(self.n_clusters): 188 | if self.P[i,j]>0: 189 | edges.append((i,j)) 190 | return edges 191 | 192 | def evaluate_greedy_policy(self,l): 193 | PP = np.copy(self.P) 194 | for trans in self.greedy_policy: 195 | p = PP[trans[0],:] 196 | p_greedy = np.zeros_like(p) 197 | p_greedy[trans[1]]=1 198 | p = l*p+(1-l)*p_greedy 199 | PP[trans[0],:]=p 200 | self.v_greedy = self.calc_v_policy(PP) 201 | 202 | def create_skills_model(self,P): 203 | skills = [] 204 | for cluster_ind in xrange(self.n_clusters): 205 | skills.append(np.nonzero(P[cluster_ind,:])) 206 | l_p = int(self.labels[0]) 207 | total_r = 0 208 | t = 0 209 | R_skills = [] 210 | k_skills = [] 211 | for cluster_ind in xrange(len(skills)): 212 | R_skills.append(np.zeros(shape=(len(skills[cluster_ind][0]),2))) 213 | k_skills.append(np.zeros(shape=(len(skills[cluster_ind][0]),2))) 214 | 215 | rewards_clip = np.clip(self.rewards,-1,1) 216 | for i, (l, r) in enumerate(zip(self.labels[1:], rewards_clip[1:])): 217 | total_r += self.gamma**t * r 218 | if l == l_p: 219 | t += 1 220 | else: 221 | if self.P[l_p,l]>0: 222 | skill_index = np.flatnonzero(np.asarray(skills[l_p][0])==l)[0] 223 | R_skills[l_p][skill_index,0] += total_r #/ (t+1) 224 | R_skills[l_p][skill_index,1] += 1 225 | k_skills[l_p][skill_index,0] += t 226 | k_skills[l_p][skill_index,1] += 1 227 | l_p = int(l) 228 | total_r = 0 229 | t = 0 230 | 231 | for skill_r,skill_k in zip(R_skills, k_skills): 232 | skill_r[:,0] /= skill_r[:,1] 233 | skill_k[:,0] /= skill_k[:,1] 234 | 235 | return skills,R_skills,k_skills 236 | 237 | def calc_skill_indices(self): 238 | l_p = int(self.labels[0]) 239 | current_skill = [] 240 | skill_indices = [[] for i in range(self.n_clusters)] 241 | skill_list = [[] for i in range(self.n_clusters)] 242 | skill_time = [[] for i in range(self.n_clusters)] 243 | 244 | for i, l in enumerate(zip(self.labels[1:])): 245 | current_skill.append(i) 246 | if l[0] != l_p: 247 | if self.P[l_p,l[0]]>0: 248 | skill_index = np.nonzero(skill_list[l_p]==l[0])[0] #find skill index in list 249 | curr_length = len(current_skill) 250 | if curr_length > self.k: 251 | length = [] 252 | length.append(curr_length) 253 | 254 | if len(skill_index) == 0: #if not found - append 255 | skill_list[l_p].append(l[0]) 256 | skill_indices[l_p].append(current_skill) 257 | skill_time[l_p].append(length) 258 | else: 259 | skill_indices[l_p][skill_index].extend(current_skill) 260 | skill_time[l_p][skill_index].extend(length) 261 | 262 | l_p = l[0] 263 | current_skill = [] 264 | return skill_indices,skill_list,skill_time 265 | 266 | def complete_smdp(self): 267 | # not needed in spatio-temporal 268 | self.skills, self.R_skills, self.k_skills = self.create_skills_model(self.P) 269 | self.greedy_policy = self.policy_improvement() 270 | self.v_greedy = self.evaluate_greedy_policy(0) 271 | self.skill_indices, self.skill_list,self.skill_time = self.calc_skill_indices() 272 | -------------------------------------------------------------------------------- /graying_the_box/vis_tool.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib.widgets import Button, Slider, CheckButtons 4 | from add_global_features import add_buttons as add_global_buttons 5 | from control_buttons import add_buttons as add_control_buttons 6 | import common 7 | 8 | class VIS_TOOL(object): 9 | 10 | def __init__(self, global_feats, hand_craft_feats, game_id, cluster_params): 11 | 12 | # 0. connect arguments 13 | self.global_feats = global_feats 14 | self.game_id = game_id 15 | self.hand_craft_feats = hand_craft_feats 16 | self.num_points = global_feats['tsne'].shape[0] 17 | self.data_t = global_feats['tsne'].T 18 | screens = np.reshape(np.transpose(global_feats['screens']), (3,210,160,-1)) 19 | self.screens = np.transpose(screens,(3,1,2,0)) 20 | self.im_size = np.sqrt(global_feats['states'].shape[1]) 21 | self.states = np.reshape(global_feats['states'], (global_feats['states'].shape[0], self.im_size,self.im_size)) 22 | self.tsne3d_next = global_feats['tsne3d_next'].T 23 | self.tsnedata = global_feats['tsne3d'].T 24 | self.traj_index = global_feats['trajectory_index'] 25 | self.tsne3d_norm = global_feats['tsne3d_norm'] 26 | tmp = global_feats['value'] - np.amin(np.array(global_feats['value'])) 27 | self.color = tmp / np.amax(tmp) 28 | self.cluster_params = cluster_params 29 | 30 | # 1. Constants 31 | self.pnt_size = 2 32 | self.ind = 0 33 | self.prev_ind = 0 34 | 35 | # 2. Plots 36 | self.fig = plt.figure('tSNE') 37 | 38 | # 2.1 t-SNE 39 | self.ax_tsne = plt.subplot2grid((3,5),(0,0), rowspan=3, colspan=3) 40 | 41 | self.tsne_scat = self.ax_tsne.scatter(self.data_t[0], self.data_t[1], s= np.ones(self.num_points)*self.pnt_size,c = self.color,edgecolor='none',picker=5) 42 | 43 | self.ax_tsne.set_xticklabels([]) 44 | self.ax_tsne.set_yticklabels([]) 45 | 46 | # 2.1.5 colorbar 47 | cb_axes = self.fig.add_axes([0.253,0.13,0.2,0.01]) 48 | cbar = self.fig.colorbar(self.tsne_scat, cax=cb_axes, orientation='horizontal', ticks=[0,1]) 49 | cbar.ax.set_xticklabels(['Low','High']) 50 | 51 | # 2.2 game screen (state) 52 | self.ax_screen = plt.subplot2grid((3,5),(2,3), rowspan=1, colspan=1) 53 | 54 | self.screenplot = self.ax_screen.imshow(self.screens[self.ind], interpolation='none') 55 | 56 | self.ax_screen.set_xticklabels([]) 57 | self.ax_screen.set_yticklabels([]) 58 | 59 | # 2.3 gradient image (saliency map) 60 | self.ax_state = plt.subplot2grid((3,5),(2,4), rowspan=1, colspan=1) 61 | 62 | self.stateplot = self.ax_state.imshow(self.states[self.ind], interpolation='none', cmap='gray',picker=5) 63 | 64 | self.ax_state.set_xticklabels([]) 65 | self.ax_state.set_yticklabels([]) 66 | 67 | # 3. Global Features 68 | add_global_buttons(self, global_feats) 69 | 70 | # 4. 
Control buttons 71 | add_control_buttons(self) 72 | 73 | # 4.1 add game buttons 74 | if game_id == 0: # breakout 75 | from add_breakout_buttons import add_game_buttons, update_cond_vector 76 | if game_id == 1: # seaquest 77 | from add_seaquest_buttons import add_game_buttons, update_cond_vector 78 | if game_id == 2: # pacman 79 | from add_pacman_buttons import add_game_buttons, update_cond_vector 80 | 81 | add_game_buttons(self) 82 | 83 | self.update_cond_vector = update_cond_vector 84 | 85 | def add_color_button(self, pos, name, color): 86 | def set_color(event): 87 | self.color = self.COLORS[id(event.inaxes)] 88 | self.tsne_scat.set_array(self.color) 89 | 90 | ax = plt.axes(pos) 91 | setattr(self, name+'_button', Button(ax, name)) 92 | getattr(self, name+'_button').on_clicked(set_color) 93 | color = np.array(color) - np.amin(np.array(color)) 94 | self.COLORS[id(ax)] = color/np.amax(color) 95 | 96 | def update_sliders(self, val): 97 | for f in self.SLIDER_FUNCS: 98 | f() 99 | # self.update_cond_vector_breakout() 100 | 101 | def add_slider_button(self, pos, name, v_min, v_max): 102 | 103 | def update_slider(self, name, slider): 104 | def f(): 105 | setattr(self, name, slider.val) 106 | return f 107 | 108 | ax_min = plt.axes(pos) 109 | ax_max = plt.axes([pos[0], pos[1]-0.02, pos[2], pos[3]]) 110 | 111 | slider_min = Slider(ax_min, name+'_min', valmin=v_min, valmax=v_max, valinit=v_min) 112 | slider_max = Slider(ax_max, name+'_max', valmin=v_min, valmax=v_max, valinit=v_max) 113 | 114 | self.SLIDER_FUNCS.append(update_slider(self, name+'_min', slider_min)) 115 | self.SLIDER_FUNCS.append(update_slider(self, name+'_max', slider_max)) 116 | 117 | slider_min.on_changed(self.update_sliders) 118 | slider_max.on_changed(self.update_sliders) 119 | 120 | def add_check_button(self, pos, name, options, init_vals): 121 | def set_options(label): 122 | pass 123 | 124 | ax = plt.axes(pos) 125 | setattr(self, name+'_check_button', CheckButtons(ax, options, init_vals)) 126 | getattr(self, name+'_check_button').on_clicked(set_options) 127 | 128 | def on_scatter_pick(self,event): 129 | self.ind = event.ind[0] 130 | self.update_plot() 131 | self.prev_ind = self.ind 132 | 133 | def update_plot(self): 134 | self.screenplot.set_array(self.screens[self.ind]) 135 | self.stateplot.set_array(self.states[self.ind]) 136 | sizes = self.tsne_scat.get_sizes() 137 | sizes[self.ind] = 250 138 | sizes[self.prev_ind] = self.pnt_size 139 | self.tsne_scat.set_sizes(sizes) 140 | self.fig.canvas.draw() 141 | print 'chosen point: %d' % self.ind 142 | 143 | def show(self): 144 | plt.show(block=True) 145 | --------------------------------------------------------------------------------
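VIS_TOOL above creates its scatter plot with picker=5 and reacts to clicks through on_scatter_pick, which reads event.ind; the mpl_connect call itself is not visible in this file and presumably lives in control_buttons.py. For reference, a minimal standalone example of the matplotlib pick-event pattern the tool relies on (all names here are illustrative, not taken from the repo):

import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
pts = np.random.randn(100, 2)
scat = ax.scatter(pts[:, 0], pts[:, 1], picker=5)  # picker=5 -> 5-point hit radius

def on_pick(event):
    # event.ind holds the indices of every scatter point under the cursor
    print 'picked point %d' % event.ind[0]

fig.canvas.mpl_connect('pick_event', on_pick)
plt.show()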