├── .gitignore ├── .luacheckrc ├── tests ├── changelimiter.lua ├── encoder-decoder.lua └── load_checkpoint_interactive.lua ├── ChangeLimiter.lua ├── data_utils.lua ├── Noise.lua ├── Print.lua ├── DownsampledDecoder.lua ├── Decoder.lua ├── AtariDecoder.lua ├── ActionDecoder.lua ├── Scale.lua ├── create_video.py ├── resize.py ├── ScheduledWeightSharpener.lua ├── analyze_loss.py ├── DownsampledAutoencoder.lua ├── KITTIEncoder.lua ├── DownsampledEncoder.lua ├── AtariEncoder.lua ├── ActionEncoder.lua ├── UnsupervisedEncoder.lua ├── BallsEncoder.lua ├── action_data_converter.lua ├── utils.py ├── action_data_converter.py ├── action_data_converter2.lua ├── action_data_converter_all.lua ├── balls_generator.py ├── runner.py ├── render_action_examples.lua ├── render_examples.lua ├── atari_runner.py ├── render_balls_examples.lua ├── kitti_data_converter.lua ├── action_runner.py ├── MotionBCECriterion.lua ├── balls_runner.py ├── render_atari_examples.lua ├── render_downsampled_examples.lua ├── downsampled_runner.py ├── data_loaders.lua ├── vis.lua ├── val_vis.py ├── utils.lua ├── render_generalization_face.lua ├── bouncing_balls.py ├── render_generalization_atari.lua ├── render_generalization_downsampled.lua ├── render_generalization_balls.lua ├── render_generalization_action.lua ├── balls_main.lua ├── action_main.lua └── downsampled_main.lua /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | */.DS_Store 3 | *.ipynb 4 | .ipynb_checkpoints 5 | sample.lua 6 | data/ 7 | slurm_logs/ 8 | slurm_scripts/ 9 | networks/ 10 | old_networks/ 11 | logs 12 | reports 13 | -------------------------------------------------------------------------------- /.luacheckrc: -------------------------------------------------------------------------------- 1 | globals = { 2 | "torch", 3 | "nn", 4 | "optim", 5 | "paths", 6 | "opt", 7 | "image", 8 | } 9 | 10 | ignore = { 11 | "opt", 12 | } 13 | 14 | allow_defined = true 15 | allow_defined_top = 
true 16 | -------------------------------------------------------------------------------- /tests/changelimiter.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | require 'ChangeLimiter' 4 | 5 | torch.manualSeed(1) 6 | 7 | -- parameters 8 | local precision = 1e-5 9 | local jac = nn.Jacobian 10 | 11 | -- define inputs and module 12 | local input = torch.rand(3, 200) 13 | 14 | local network = nn.Sequential() 15 | network:add(nn.SplitTable(1)) 16 | network:add(nn.ChangeLimiter()) 17 | 18 | -- test backprop, with Jacobian 19 | local err = jac.testJacobian(network, input) 20 | print('==> error: ' .. err) 21 | if err < precision then 22 | print('==> module OK') 23 | else 24 | print('==> error too large, incorrect implementation') 25 | end 26 | -------------------------------------------------------------------------------- /ChangeLimiter.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | ChangeLimiter = torch.class('nn.ChangeLimiter', 'nn.Module') 4 | 5 | function ChangeLimiter:updateOutput(input) 6 | -- print(input) 7 | local distribution, input1, input2 = table.unpack(input) 8 | self.output = torch.cmul(input1, ((distribution * -1) + 1)) + torch.cmul(input2, distribution) -- why don't you do an AddTable? 
9 | -- print(self.output) 10 | return self.output 11 | end 12 | 13 | function ChangeLimiter:updateGradInput(input, gradOutput) 14 | local distribution, input1, input2 = table.unpack(input) 15 | self.gradInput = { 16 | torch.cmul(gradOutput, (input2 - input1)), 17 | torch.cmul(gradOutput, ((distribution * -1) + 1)), 18 | torch.cmul(gradOutput, distribution), 19 | } 20 | 21 | return self.gradInput 22 | end 23 | -------------------------------------------------------------------------------- /data_utils.lua: -------------------------------------------------------------------------------- 1 | require 'hdf5' 2 | require 'paths' 3 | require 'math' 4 | require 'xlua' 5 | 6 | -- train-val-test: 70-15-15 split 7 | function split_batches(examples) 8 | local num_test = math.floor(#examples * 0.15) 9 | local num_val = num_test 10 | local num_train = #examples - 2*num_test 11 | 12 | local test = {} 13 | local val = {} 14 | local train = {} 15 | 16 | -- shuffle examples 17 | local ridxs = torch.randperm(#examples) 18 | for i = 1, ridxs:size(1) do 19 | xlua.progress(i, ridxs:size(1)) 20 | local batch = examples[ridxs[i]] 21 | if i <= num_train then 22 | table.insert(train, batch) 23 | elseif i <= num_train + num_val then 24 | table.insert(val, batch) 25 | else 26 | table.insert(test, batch) 27 | end 28 | end 29 | return {train, val, test} 30 | end 31 | -------------------------------------------------------------------------------- /Noise.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | Noise, parent = torch.class('nn.Noise', 'nn.Module') 4 | 5 | function Noise:__init(variance) 6 | parent.__init(self) 7 | self.variance = variance 8 | self.active = true 9 | end 10 | 11 | function Noise:updateOutput(input) 12 | if self.active then 13 | local noise = input:clone() 14 | if self.variance == 0 then 15 | noise:fill(0) 16 | else 17 | noise:normal(0, self.variance) 18 | end 19 | 20 | self.output = input + noise 21 | else 22 | 
self.output = input 23 | end 24 | return self.output 25 | end 26 | 27 | function Noise:updateGradInput(_, gradOutput) 28 | self.gradInput = gradOutput:clone() 29 | return self.gradInput 30 | end 31 | 32 | function Noise:training() 33 | self.active = true 34 | end 35 | 36 | function Noise:evaluate() 37 | self.active = false 38 | end 39 | -------------------------------------------------------------------------------- /Print.lua: -------------------------------------------------------------------------------- 1 | -- require 'nn' 2 | local Print = torch.class('nn.Print', 'nn.Module') 3 | 4 | function Print:__init(name, just_dimensions) 5 | self.name = name 6 | self.just_dimensions = just_dimensions or false 7 | end 8 | 9 | function Print:updateOutput(input) 10 | if self.just_dimensions then 11 | print(self.name.." input dimensions: ") 12 | if type(input) == 'table' then 13 | print("table:", input) 14 | else 15 | print(input:size()) 16 | end 17 | else 18 | print(self.name.." input: ") 19 | print(input) 20 | end 21 | self.output = input 22 | return input 23 | end 24 | 25 | function Print:updateGradInput(_, gradOutput) 26 | if self.just_dimensions then 27 | print(self.name.." gradOutput dimensions:") 28 | if type(gradOutput) == 'table' then 29 | print("table:", #gradOutput) 30 | else 31 | print(gradOutput:size()) 32 | end 33 | else 34 | print(self.name.." 
gradOutput:") 35 | print(gradOutput) 36 | end 37 | self.gradInput = gradOutput 38 | return self.gradInput 39 | end 40 | -------------------------------------------------------------------------------- /DownsampledDecoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | -- transforms () 4 | local DownsampledDecoder = function(dim_hidden, color_channels, feature_maps) 5 | local filter_size = 6 6 | local encoded_size = 10 7 | local decoder = nn.Sequential() 8 | decoder:add(nn.Linear(dim_hidden, (feature_maps/4)*encoded_size*encoded_size )) 9 | decoder:add(nn.Threshold(0,1e-6)) 10 | 11 | decoder:add(nn.Reshape((feature_maps/4),encoded_size,encoded_size)) 12 | 13 | decoder:add(nn.SpatialUpSamplingNearest(2)) 14 | decoder:add(nn.SpatialConvolution(feature_maps/4,feature_maps/2, filter_size, filter_size)) 15 | decoder:add(nn.Threshold(0,1e-6)) 16 | 17 | decoder:add(nn.SpatialUpSamplingNearest(2)) 18 | decoder:add(nn.SpatialConvolution(feature_maps/2,feature_maps,filter_size,filter_size)) 19 | decoder:add(nn.Threshold(0,1e-6)) 20 | 21 | decoder:add(nn.SpatialUpSamplingNearest(2)) 22 | decoder:add(nn.SpatialConvolution(feature_maps,feature_maps,filter_size,filter_size)) 23 | decoder:add(nn.Threshold(0,1e-6)) 24 | 25 | decoder:add(nn.SpatialUpSamplingNearest(2)) 26 | decoder:add(nn.SpatialConvolution(feature_maps,color_channels,filter_size+1,filter_size+1)) 27 | decoder:add(nn.Sigmoid()) 28 | return decoder 29 | end 30 | 31 | return DownsampledDecoder 32 | -------------------------------------------------------------------------------- /tests/encoder-decoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cutorch' 3 | require 'cunn' 4 | 5 | local UnsupervisedEncoder = require 'UnsupervisedEncoder' 6 | local Decoder = require 'Decoder' 7 | 8 | torch.manualSeed(1) 9 | -- torch.setdefaulttensortype('torch.CudaTensor') 10 | 11 | -- parameters 12 | 
local precision = 1e-5 13 | local jac = nn.Jacobian 14 | 15 | local dim_hidden = 200 16 | local color_channels = 1 17 | local feature_maps = 96 18 | local filter_size = 5 19 | 20 | local image_size = 150 21 | 22 | iteration = 1 23 | 24 | -- define inputs and module 25 | local input = torch.rand(2, 1, image_size, image_size):cuda() 26 | 27 | local network = nn.Sequential() 28 | network:add(nn.SplitTable(1)) 29 | network:add(nn.ParallelTable():add(nn.Reshape(1, image_size, image_size)):add(nn.Reshape(1, image_size, image_size))) 30 | network:add(UnsupervisedEncoder(dim_hidden, color_channels, feature_maps, filter_size, 1)) 31 | network:add(Decoder(dim_hidden, color_channels, feature_maps, filter_size)) 32 | 33 | network:cuda() 34 | 35 | network:evaluate() 36 | 37 | -- test backprop, with Jacobian 38 | local err = jac.testJacobian(network, input) 39 | print('==> error: ' .. err) 40 | if err < precision then 41 | print('==> module OK') 42 | else 43 | print('==> error too large, incorrect implementation') 44 | end 45 | -------------------------------------------------------------------------------- /Decoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | local Decoder = function(dim_hidden, color_channels, feature_maps, batch_norm) 4 | local decoder = nn.Sequential() 5 | decoder:add(nn.Linear(dim_hidden, (feature_maps/4)*15*15 )) 6 | decoder:add(nn.Threshold(0,1e-6)) 7 | 8 | decoder:add(nn.Reshape((feature_maps/4),15,15)) 9 | 10 | decoder:add(nn.SpatialUpSamplingNearest(2)) 11 | decoder:add(nn.SpatialConvolution(feature_maps/4,feature_maps/2, 7, 7)) 12 | if batch_norm then 13 | decoder:add(nn.SpatialBatchNormalization(feature_maps/2)) 14 | end 15 | decoder:add(nn.Threshold(0,1e-6)) 16 | 17 | decoder:add(nn.SpatialUpSamplingNearest(2)) 18 | decoder:add(nn.SpatialConvolution(feature_maps/2,feature_maps,7,7)) 19 | if batch_norm then 20 | decoder:add(nn.SpatialBatchNormalization(feature_maps)) 21 | end 22 | decoder:add(nn.Threshold(0,1e-6)) 23 
| 24 | decoder:add(nn.SpatialUpSamplingNearest(2)) 25 | decoder:add(nn.SpatialConvolution(feature_maps,feature_maps,7,7)) 26 | if batch_norm then 27 | decoder:add(nn.SpatialBatchNormalization(feature_maps)) 28 | end 29 | decoder:add(nn.Threshold(0,1e-6)) 30 | 31 | decoder:add(nn.SpatialUpSamplingNearest(2)) 32 | decoder:add(nn.SpatialConvolution(feature_maps,color_channels,7,7)) 33 | decoder:add(nn.Sigmoid()) 34 | return decoder 35 | end 36 | 37 | return Decoder 38 | -------------------------------------------------------------------------------- /AtariDecoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | -- transforms () 4 | local AtariDecoder = function(dim_hidden, color_channels, feature_maps, batch_norm) 5 | local decoder = nn.Sequential() 6 | decoder:add(nn.Linear(dim_hidden, (feature_maps/4)*19*16 )) 7 | decoder:add(nn.Threshold(0,1e-6)) 8 | 9 | decoder:add(nn.Reshape((feature_maps/4),19,16)) 10 | 11 | decoder:add(nn.SpatialUpSamplingNearest(2)) 12 | decoder:add(nn.SpatialConvolution(feature_maps/4,feature_maps/2, 7, 7)) 13 | if batch_norm then 14 | decoder:add(nn.SpatialBatchNormalization(feature_maps/2)) 15 | end 16 | decoder:add(nn.Threshold(0,1e-6)) 17 | 18 | decoder:add(nn.SpatialUpSamplingNearest(2)) 19 | decoder:add(nn.SpatialConvolution(feature_maps/2,feature_maps,8,8)) 20 | if batch_norm then 21 | decoder:add(nn.SpatialBatchNormalization(feature_maps)) 22 | end 23 | decoder:add(nn.Threshold(0,1e-6)) 24 | 25 | decoder:add(nn.SpatialUpSamplingNearest(2)) 26 | decoder:add(nn.SpatialConvolution(feature_maps,feature_maps,8,7)) 27 | if batch_norm then 28 | decoder:add(nn.SpatialBatchNormalization(feature_maps)) 29 | end 30 | decoder:add(nn.Threshold(0,1e-6)) 31 | 32 | decoder:add(nn.SpatialUpSamplingNearest(2)) 33 | decoder:add(nn.SpatialConvolution(feature_maps,color_channels,7,7)) 34 | decoder:add(nn.Sigmoid()) 35 | return decoder 36 | end 37 | 38 | return AtariDecoder 39 | 
-------------------------------------------------------------------------------- /ActionDecoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | -- (14,16) --> (120, 160) 4 | local ActionDecoder = function(dim_hidden, color_channels, feature_maps, batch_norm) 5 | local decoder = nn.Sequential() 6 | decoder:add(nn.Linear(dim_hidden, (feature_maps/4)*14*16 )) 7 | decoder:add(nn.Threshold(0,1e-6)) 8 | 9 | decoder:add(nn.Reshape((feature_maps/4),14,16)) 10 | 11 | decoder:add(nn.SpatialUpSamplingNearest(2)) 12 | decoder:add(nn.SpatialConvolution(feature_maps/4,feature_maps/2, 7, 8)) 13 | if batch_norm then 14 | decoder:add(nn.SpatialBatchNormalization(feature_maps/2)) 15 | end 16 | decoder:add(nn.Threshold(0,1e-6)) 17 | 18 | decoder:add(nn.SpatialUpSamplingNearest(2)) 19 | decoder:add(nn.SpatialConvolution(feature_maps/2,feature_maps,8,8)) 20 | if batch_norm then 21 | decoder:add(nn.SpatialBatchNormalization(feature_maps)) 22 | end 23 | decoder:add(nn.Threshold(0,1e-6)) 24 | 25 | decoder:add(nn.SpatialUpSamplingNearest(2)) 26 | decoder:add(nn.SpatialConvolution(feature_maps,feature_maps,8,8)) 27 | if batch_norm then 28 | decoder:add(nn.SpatialBatchNormalization(feature_maps)) 29 | end 30 | decoder:add(nn.Threshold(0,1e-6)) 31 | 32 | decoder:add(nn.SpatialUpSamplingNearest(2)) 33 | decoder:add(nn.SpatialConvolution(feature_maps,color_channels,7,7)) 34 | decoder:add(nn.Sigmoid()) 35 | return decoder 36 | end 37 | 38 | return ActionDecoder 39 | -------------------------------------------------------------------------------- /Scale.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2014 Google Inc. 3 | 4 | See LICENSE file for full terms of limited license. 
5 | ]] 6 | 7 | require "nn" 8 | local image = require "image" 9 | 10 | local scale = torch.class('nn.Scale', 'nn.Module') 11 | 12 | 13 | function scale:__init(height, width) 14 | self.height = height 15 | self.width = width 16 | end 17 | 18 | function scale:scale_one(input) 19 | -- output:zero():add(0.299, input[1]):add(0.587, input[2]):add(0.114, input[3]) 20 | local output = image.rgb2y(input:float()) -- turn it into grayscale (luminance) 21 | output = (image.scale(output, self.width, self.height, 'bilinear')) 22 | return output 23 | end 24 | 25 | function scale:forward(x) 26 | local is_cuda = (x:type() == "torch.CudaTensor") 27 | -- self.output = x 28 | if x:dim() > 3 then 29 | self.output = torch.Tensor(x:size(1), 1, self.width, self.height):float() 30 | for i = 1, x:size(1) do 31 | -- puts the scaled version directly in output 32 | self.output[i] = self:scale_one(x[i]) 33 | end 34 | else 35 | -- self.output = torch.Tensor(1, self.width, self.height) 36 | self.output = self:scale_one(x) 37 | end 38 | if is_cuda then 39 | self.output = self.output:cuda() 40 | end 41 | 42 | return self.output 43 | end 44 | 45 | function scale:updateOutput(input) 46 | return self:forward(input) 47 | end 48 | 49 | function scale:float() 50 | end 51 | -------------------------------------------------------------------------------- /create_video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | from images2gif import writeGif 5 | from PIL import Image 6 | import os 7 | import pprint 8 | 9 | # Usage: Download images2gif 10 | # Change 11 | # for im in images: 12 | # palettes.append( getheader(im)[1] ) 13 | # To 14 | # for im in images: 15 | # palettes.append(im.palette.getdata()[1]) 16 | 17 | def create_gif(images_root): 18 | """ 19 | writeGif(filename, images, duration=0.1, loops=0, dither=1) 20 | Write an animated gif from the specified images. 
21 | images should be a list of numpy arrays of PIL images. 22 | Numpy images of type float should have pixels between 0 and 1. 23 | Numpy images of other types are expected to have values between 0 and 255. 24 | """ 25 | def img_id(filename): 26 | begin = len('changing_') 27 | end = filename.find('_amount') 28 | return int(filename[begin:end]) 29 | 30 | file_names = sorted([fn for fn in os.listdir(images_root) if fn.endswith('.png')], key=lambda x: img_id(x)) 31 | images = [Image.open(os.path.join(images_root,fn)) for fn in file_names] 32 | filename = os.path.join(images_root, "gif.GIF") 33 | # print filename 34 | writeGif(filename, images, duration=0.05) 35 | 36 | if __name__ == '__main__': 37 | root = '/Users/MichaelChang/Dropbox (MIT Solar Car Team)/MacHD/Documents/Important/MIT/Research/SuperUROP/Code/unsupervised-dcign/renderings/mutation' 38 | images_root = 'ballsgss3_Feb_23_08_10' 39 | for exp in [f for f in os.listdir(os.path.join(root,images_root)) if '.txt' not in f and not f.startswith('.')]: 40 | for demo in [f for f in os.listdir(os.path.join(*[root,images_root,exp])) if not f.startswith('.')]: 41 | print demo 42 | create_gif(os.path.join(*[root, images_root, exp, demo])) 43 | -------------------------------------------------------------------------------- /resize.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from resizeimage import resizeimage 4 | from progressbar import ProgressBar 5 | 6 | def resize(infolder, outfolder, imagefile, size): 7 | """ 8 | infolder: folder that input image lives 9 | outfolder: folder that output image will live 10 | imagefile: image filename, like img.png 11 | size: list: [150,150] 12 | """ 13 | fd_img = open(os.path.join(infolder,imagefile),'r') 14 | img = Image.open(fd_img) 15 | img = resizeimage.resize_crop(img, size) 16 | img.save(os.path.join(outfolder,imagefile), img.format) 17 | fd_img.close() 18 | 19 | def 
resize_in_folder(folder,size): 20 | for img in os.listdir(folder): 21 | parent = os.path.dirname(folder) 22 | outfolder = os.path.join(parent,os.path.basename(folder) + '_resize') 23 | if not os.path.exists(outfolder): os.mkdir(outfolder) 24 | resize(folder, outfolder, img, size) 25 | 26 | def resize_kitti(): 27 | settings = ['road', 'campus', 'residential', 'city'] 28 | for setting in settings: 29 | root = '/om/data/public/mbchang/udcign-data/kitti/raw/videos/' + setting + '/' 30 | ch = 'image_02' # color 31 | size = [150,150] 32 | pbar = ProgressBar().start() 33 | i = 0 34 | for folder in os.listdir(root): 35 | img_folder = os.path.join(root, folder + '/' + ch + '/data') 36 | resize_in_folder(img_folder, size) 37 | 38 | pbar.update(i + 1) 39 | i += 1 40 | pbar.finish() 41 | 42 | def resize_toyota(): 43 | root = '/om/data/public/mbchang/udcign-data/toyota/pics' 44 | size = [150,150] 45 | pbar = ProgressBar().start() 46 | i = 0 47 | for folder in os.listdir(root): 48 | resize_in_folder(os.path.join(root,folder), size) 49 | pbar.update(i + 1) 50 | i += 1 51 | pbar.finish() 52 | 53 | if __name__ == '__main__': 54 | resize_toyota() 55 | -------------------------------------------------------------------------------- /ScheduledWeightSharpener.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Input: A table {x, y} of a Tensor x and a scalar y. 
3 | Output: x^y (elementwise) 4 | 5 | taken from https://github.com/kaishengtai/torch-ntm/blob/master/layers/PowTable.lua 6 | --]] 7 | 8 | local ScheduledWeightSharpener, parent = torch.class('nn.ScheduledWeightSharpener', 'nn.Module') 9 | 10 | function ScheduledWeightSharpener:__init(sharpening_rate) 11 | parent.__init(self) 12 | self.slope = sharpening_rate 13 | end 14 | 15 | function ScheduledWeightSharpener:getP() 16 | local iteration = opt.current_scheduler_iteration 17 | return math.min(1 + (iteration / 10000) * self.slope, 100) 18 | end 19 | 20 | function ScheduledWeightSharpener:updateOutput(input) 21 | local v = input:clone() 22 | v:clamp(0,1000000) 23 | 24 | -- smoothly increase the sharpening from 1 to 100 25 | -- iteration is defined globally in the training loop 26 | local p = self:getP() 27 | -- print('v:', v) 28 | -- print('p:', p) 29 | self.output = torch.pow(v, p) 30 | if self.output[1][1] ~= self.output[1][1] then 31 | print('Made a nan set of weights.') 32 | print('v:', v) 33 | print('p:', p) 34 | os.exit(1) 35 | end 36 | -- print(self.output) 37 | return self.output 38 | end 39 | 40 | function ScheduledWeightSharpener:updateGradInput(input, gradOutput) 41 | local v = input:clone() 42 | v:clamp(0,1000000) 43 | local p = self:getP() 44 | 45 | self.gradInput = torch.cmul(gradOutput, torch.pow(v, p - 1)) * p 46 | 47 | if self.gradInput[1][1] ~= self.gradInput[1][1] then 48 | print('Made a nan set of gradients.') 49 | print('v:', v) 50 | print('p:', p) 51 | print('gradInput:', self.gradInput) 52 | print('gradOutput:', gradOutput) 53 | os.exit(1) 54 | end 55 | -- local pgrad = 0 56 | -- for i = 1, v:size(1) do 57 | -- if v[i] > 0 then 58 | -- pgrad = pgrad + math.log(v[i]) * self.output[1][i] * gradOutput[1][i] 59 | -- end 60 | -- end 61 | -- pgrad = pgrad + 0.001 62 | -- print('pgrad: ', pgrad, 'modified pgrad: ', ) 63 | -- self.gradInput[2][1] = pgrad 64 | return self.gradInput 65 | end 66 | 
-------------------------------------------------------------------------------- /tests/load_checkpoint_interactive.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cutorch' 3 | require 'cunn' 4 | require 'image' 5 | 6 | vis = require 'vis' 7 | data_loaders = require 'data_loaders' 8 | Encoder = require 'AtariEncoder' 9 | Decoder = require 'AtariDecoder' 10 | 11 | base_directory = "/om/user/wwhitney/unsupervised-dcign/networks" 12 | 13 | function getLastSnapshot(network_name) 14 | local res_file = io.popen("ls -t "..paths.concat(base_directory, network_name).." | grep -i epoch | head -n 1") 15 | local result = res_file:read():match( "^%s*(.-)%s*$" ) 16 | res_file:close() 17 | return result 18 | end 19 | 20 | 21 | network_name = "down_motion_scale_3_noise_0.1_heads_3_sharpening_rate_10_gpu_learning_rate_0.0002_model_disentangled_dataset_name_space_invaders_frame_interval_1" 22 | epoch = getLastSnapshot(network_name) 23 | checkpoint = torch.load('networks/'..network_name..'/'..epoch) 24 | opt = checkpoint.opt 25 | 26 | model = checkpoint.model 27 | encoder = model.modules[1] 28 | decoder = model.modules[2] 29 | 30 | model:evaluate() 31 | 32 | print(model) 33 | 34 | -- weight_predictor = encoder:findModules('nn.Normalize')[1] 35 | -- previous_embedding = encoder:findModules('nn.Linear')[1] 36 | -- current_embedding = encoder:findModules('nn.Linear')[2] 37 | 38 | -- for i, mod in ipairs(encoder:listModules()) do print(i, mod) end 39 | 40 | -- batch = data_loaders.load_atari_batch(339, 'test') 41 | -- output = model:forward(batch):clone() 42 | 43 | 44 | -- test_input_frame_index = 14 45 | 46 | -- weights = weight_predictor.output[test_input_frame_index] 47 | -- mx, idx = weights:max(1) 48 | -- mx = mx[1] 49 | -- idx = idx[1] 50 | 51 | -- print("mx: ", mx, "idx: ", idx) 52 | 53 | -- for i = 1, weights:size(1) do 54 | -- print(i, vis.simplestr(weights[i])) 55 | -- end 56 | 57 | -- base_embedding = 
previous_embedding.output[test_input_frame_index]:clone() 58 | 59 | -- function render_changing_index(changing_index) 60 | -- output_dir = 'reports/renderings/mutate_'..network_name 61 | -- os.execute('mkdir -p '..output_dir) 62 | -- i = 0 63 | -- for change = -4, 1.5, 0.05 do 64 | -- changed_embedding = base_embedding:clone() 65 | -- changed_embedding[changing_index] = changed_embedding[changing_index] + change 66 | -- image.save(output_dir..'/changing_'..i..'_amount_'..change..'.png', decoder:forward(changed_embedding:reshape(1, 200))[1]:float():clone()) 67 | -- i = i + 1 68 | -- end 69 | -- end 70 | 71 | -- render_changing_index(idx) 72 | -------------------------------------------------------------------------------- /analyze_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pprint 3 | import heapq as H 4 | 5 | lastvals = [] 6 | bestvals = [] 7 | lastvalsdict = {} 8 | bestvalsdict = {} 9 | k =10 10 | suppress_errors = True 11 | tag = 'ballsreg' 12 | 13 | for exp in os.listdir('.'): 14 | if '.py' not in exp and tag in exp: 15 | lines = open(os.path.join(exp,'val_loss.txt'), 'r').readlines() 16 | best = (float('inf'),0) 17 | for i in xrange(len(lines)): 18 | val_loss = float(lines[i]) 19 | if val_loss < best[0]: 20 | best = (val_loss, i) 21 | H.heappush(lastvals, (val_loss,exp)) 22 | H.heappush(bestvals, (best[0],(exp,best[1]))) 23 | lastvalsdict[exp] = val_loss 24 | bestvalsdict[exp] = best 25 | 26 | def orderbest(): 27 | global bestvals,lastvalsdict 28 | print '\nbest', k, 'vals' 29 | nbestvals = H.nsmallest(k,bestvals) 30 | for pair in nbestvals: 31 | bestval, info = pair 32 | exp, epcnum = info 33 | print exp,'\tbestval',bestval,'at valtest',epcnum,'\tlastval',lastvalsdict[exp] 34 | 35 | def orderlast(): 36 | global lastvals,bestvalsdict 37 | print '\nlast', k, 'vals' 38 | nlastvals = H.nsmallest(k,lastvals) 39 | for pair in nlastvals: 40 | lastval, exp = pair 41 | bestval, epcnum = 
bestvalsdict[exp] 42 | print exp,'\tbestval',bestval,'at valtest',epcnum,'\tlastval',lastval 43 | 44 | def orderall(): 45 | """ Find intersection of bestvals and lastvals """ 46 | global bestvals, lastvals,bestvalsdict,lastvalsdict 47 | print '\nbest', k, 'vals overall' 48 | nbestvals = H.nsmallest(k,bestvals) 49 | nlastvals = H.nsmallest(k,lastvals) 50 | bestexps = [] 51 | 52 | exps = {} 53 | for i in xrange(len(nlastvals)): 54 | lastval, explast = nlastvals[i] 55 | exps[explast] = i 56 | for j in xrange(len(nbestvals)): 57 | bestval, info = nbestvals[j] 58 | expbest, epcnum = info 59 | if expbest == explast: 60 | exps[expbest] += j 61 | for pair in exps.items(): 62 | H.heappush(bestexps,(pair[1],pair[0])) 63 | nbestexps = H.nsmallest(k,bestexps) 64 | for pair in nbestexps: 65 | rank, exp = pair 66 | lastval = lastvalsdict[exp] 67 | bestval, epcnum = bestvalsdict[exp] 68 | print exp,'\tbestval',bestval,'at valtest',epcnum,'\tlastval',lastval 69 | 70 | 71 | orderbest() 72 | orderlast() 73 | orderall() 74 | 75 | -------------------------------------------------------------------------------- /DownsampledAutoencoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | -- transforms () 4 | local DownsampledAutoencoder = function(dim_hidden, color_channels, feature_maps) 5 | local encoder_filter_size = 5 6 | local stride = 1 7 | local padding = 2 8 | local encoded_size = 10 9 | 10 | local model = nn.Sequential() 11 | local encoder = nn.Sequential() 12 | 13 | encoder:add(nn.SpatialConvolution(color_channels, feature_maps, encoder_filter_size, encoder_filter_size, stride, stride, padding, padding)) 14 | encoder:add(nn.SpatialMaxPooling(2,2,2,2)) 15 | encoder:add(nn.Threshold(0,1e-6)) 16 | 17 | encoder:add(nn.SpatialConvolution(feature_maps, feature_maps/2, encoder_filter_size, encoder_filter_size, stride, stride, padding, padding)) 18 | encoder:add(nn.SpatialMaxPooling(2,2,2,2)) 19 | 
encoder:add(nn.Threshold(0,1e-6)) 20 | 21 | encoder:add(nn.SpatialConvolution(feature_maps/2, feature_maps/4, encoder_filter_size, encoder_filter_size, stride, stride, padding, padding)) 22 | encoder:add(nn.SpatialMaxPooling(2,2,2,2)) 23 | encoder:add(nn.Threshold(0,1e-6)) 24 | 25 | encoder:add(nn.Reshape((feature_maps/4) * encoded_size * encoded_size)) 26 | encoder:add(nn.Linear((feature_maps/4) * encoded_size * encoded_size, dim_hidden)) 27 | 28 | local decoder_filter_size = 6 29 | local decoder = nn.Sequential() 30 | decoder:add(nn.Linear(dim_hidden, (feature_maps/4)*encoded_size*encoded_size )) 31 | decoder:add(nn.Threshold(0,1e-6)) 32 | 33 | decoder:add(nn.Reshape((feature_maps/4),encoded_size,encoded_size)) 34 | 35 | decoder:add(nn.SpatialUpSamplingNearest(2)) 36 | decoder:add(nn.SpatialConvolution(feature_maps/4,feature_maps/2, decoder_filter_size, decoder_filter_size)) 37 | decoder:add(nn.Threshold(0,1e-6)) 38 | 39 | decoder:add(nn.SpatialUpSamplingNearest(2)) 40 | decoder:add(nn.SpatialConvolution(feature_maps/2,feature_maps,decoder_filter_size,decoder_filter_size)) 41 | decoder:add(nn.Threshold(0,1e-6)) 42 | 43 | decoder:add(nn.SpatialUpSamplingNearest(2)) 44 | decoder:add(nn.SpatialConvolution(feature_maps,feature_maps,decoder_filter_size,decoder_filter_size)) 45 | decoder:add(nn.Threshold(0,1e-6)) 46 | 47 | decoder:add(nn.SpatialUpSamplingNearest(2)) 48 | decoder:add(nn.SpatialConvolution(feature_maps,color_channels,decoder_filter_size+1,decoder_filter_size+1)) 49 | decoder:add(nn.Sigmoid()) 50 | 51 | model:add(encoder) 52 | model:add(decoder) 53 | 54 | return model 55 | end 56 | 57 | return DownsampledAutoencoder 58 | -------------------------------------------------------------------------------- /KITTIEncoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | 4 | require 'Print' 5 | require 'ChangeLimiter' 6 | require 'Noise' 7 | require 'ScheduledWeightSharpener' 8 | 9 | 
local KITTIEncoder = function(dim_hidden, color_channels, feature_maps, noise, sharpening_rate, scheduler_iteration, batch_norm, num_heads) 10 | 11 | local filter_size = 5 12 | local inputs = { 13 | nn.Identity()():annotate{name="input1"}, 14 | nn.Identity()():annotate{name="input2"}, 15 | } 16 | 17 | -- make two copies of an encoder 18 | 19 | local enc1 = nn.Sequential() 20 | enc1:add(nn.SpatialConvolution(color_channels, feature_maps, filter_size, filter_size)) 21 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 22 | if batch_norm then 23 | enc1:add(nn.SpatialBatchNormalization(feature_maps)) 24 | end 25 | enc1:add(nn.Threshold(0,1e-6)) 26 | 27 | enc1:add(nn.SpatialConvolution(feature_maps, feature_maps/2, filter_size, filter_size)) 28 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 29 | if batch_norm then 30 | enc1:add(nn.SpatialBatchNormalization(feature_maps/2)) 31 | end 32 | enc1:add(nn.Threshold(0,1e-6)) 33 | 34 | enc1:add(nn.SpatialConvolution(feature_maps/2, feature_maps/4, filter_size, filter_size)) 35 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 36 | if batch_norm then 37 | enc1:add(nn.SpatialBatchNormalization(feature_maps/4)) 38 | end 39 | enc1:add(nn.Threshold(0,1e-6)) 40 | 41 | enc1:add(nn.Reshape((feature_maps/4) * 15*15)) 42 | enc1:add(nn.Linear((feature_maps/4) * 15*15, dim_hidden)) 43 | 44 | local enc2 = enc1:clone('weight', 'bias', 'gradWeight', 'gradBias') 45 | enc1 = enc1(inputs[1]) 46 | enc2 = enc2(inputs[2]) 47 | 48 | 49 | -- make the heads to analyze the encodings 50 | local heads = {} 51 | heads[1] = nn.Sequential() 52 | heads[1]:add(nn.JoinTable(2)) 53 | heads[1]:add(nn.Linear(dim_hidden * 2, dim_hidden)) 54 | heads[1]:add(nn.Sigmoid()) 55 | heads[1]:add(nn.Noise(noise)) 56 | heads[1]:add(nn.ScheduledWeightSharpener(sharpening_rate, scheduler_iteration)) 57 | heads[1]:add(nn.AddConstant(1e-20)) 58 | heads[1]:add(nn.Normalize(1, 1e-100)) 59 | 60 | for i = 2, num_heads do -- won't execute if num_heads < 2 61 | heads[i] = heads[1]:clone() 62 | end 63 | 64 | 
for i = 1, num_heads do 65 | heads[i] = heads[i]{enc1, enc2} 66 | end 67 | 68 | -- combine the distributions from all heads 69 | local dist_adder = nn.CAddTable()(heads) 70 | local dist_clamp = nn.Clamp(0, 1)(dist_adder) 71 | 72 | -- and use it to filter the encodings 73 | local change_limiter = nn.ChangeLimiter()({dist_clamp, enc1, enc2}):annotate{name="change_limiter"} 74 | 75 | 76 | local output = {change_limiter} 77 | return nn.gModule(inputs, output) 78 | end 79 | 80 | return KITTIEncoder 81 | -------------------------------------------------------------------------------- /DownsampledEncoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | 4 | require 'Print' 5 | require 'ChangeLimiter' 6 | require 'Noise' 7 | require 'ScheduledWeightSharpener' 8 | 9 | local DownsampledEncoder = function(dim_hidden, color_channels, feature_maps, noise, sharpening_rate, scheduler_iteration, num_heads) 10 | 11 | local filter_size = 5 12 | local stride = 1 13 | local padding = 2 14 | local encoded_size = 10 15 | local inputs = { 16 | nn.Identity()():annotate{name="input1"}, 17 | nn.Identity()():annotate{name="input2"}, 18 | } 19 | 20 | -- make two copies of an encoder 21 | 22 | local enc1 = nn.Sequential() 23 | enc1:add(nn.SpatialConvolution(color_channels, feature_maps, filter_size, filter_size, stride, stride, padding, padding)) 24 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 25 | enc1:add(nn.Threshold(0,1e-6)) 26 | 27 | enc1:add(nn.SpatialConvolution(feature_maps, feature_maps/2, filter_size, filter_size, stride, stride, padding, padding)) 28 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 29 | enc1:add(nn.Threshold(0,1e-6)) 30 | 31 | enc1:add(nn.SpatialConvolution(feature_maps/2, feature_maps/4, filter_size, filter_size, stride, stride, padding, padding)) 32 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 33 | enc1:add(nn.Threshold(0,1e-6)) 34 | 35 | enc1:add(nn.Reshape((feature_maps/4) * encoded_size * 
encoded_size)) 36 | enc1:add(nn.Linear((feature_maps/4) * encoded_size * encoded_size, dim_hidden)) 37 | 38 | local enc2 = enc1:clone('weight', 'bias', 'gradWeight', 'gradBias') 39 | enc1 = enc1(inputs[1]) 40 | enc2 = enc2(inputs[2]) 41 | 42 | 43 | -- make the heads to analyze the encodings 44 | local heads = {} 45 | heads[1] = nn.Sequential() 46 | heads[1]:add(nn.JoinTable(2)) 47 | heads[1]:add(nn.Linear(dim_hidden * 2, dim_hidden)) 48 | heads[1]:add(nn.Sigmoid()) 49 | heads[1]:add(nn.Noise(noise)) 50 | heads[1]:add(nn.ScheduledWeightSharpener(sharpening_rate, scheduler_iteration)) 51 | heads[1]:add(nn.AddConstant(1e-20)) 52 | heads[1]:add(nn.Normalize(1, 1e-100)) 53 | 54 | for i = 2, num_heads do 55 | heads[i] = heads[1]:clone() 56 | end 57 | 58 | for i = 1, num_heads do 59 | heads[i] = heads[i]{enc1, enc2} 60 | end 61 | 62 | local dist 63 | if num_heads > 1 then 64 | -- combine the distributions from all heads 65 | local dist_adder = nn.CAddTable()(heads) 66 | local dist_clamp = nn.Clamp(0, 1)(dist_adder) 67 | dist = dist_clamp 68 | else 69 | dist = heads[1] 70 | end 71 | 72 | -- and use it to filter the encodings 73 | local change_limiter = nn.ChangeLimiter()({dist, enc1, enc2}):annotate{name="change_limiter"} 74 | 75 | local output = {change_limiter} 76 | return nn.gModule(inputs, output) 77 | end 78 | 79 | return DownsampledEncoder 80 | -------------------------------------------------------------------------------- /AtariEncoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | 4 | require 'Print' 5 | require 'ChangeLimiter' 6 | require 'Noise' 7 | require 'ScheduledWeightSharpener' 8 | 9 | local AtariEncoder = function(dim_hidden, color_channels, feature_maps, noise, sharpening_rate, scheduler_iteration, batch_norm, num_heads) 10 | 11 | local filter_size = 5 12 | local inputs = { 13 | nn.Identity()():annotate{name="input1"}, 14 | nn.Identity()():annotate{name="input2"}, 15 | } 16 
| 17 | -- make two copies of an encoder 18 | 19 | local enc1 = nn.Sequential() 20 | enc1:add(nn.SpatialConvolution(color_channels, feature_maps, filter_size, filter_size)) 21 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 22 | if batch_norm then 23 | enc1:add(nn.SpatialBatchNormalization(feature_maps)) 24 | end 25 | enc1:add(nn.Threshold(0,1e-6)) 26 | 27 | enc1:add(nn.SpatialConvolution(feature_maps, feature_maps/2, filter_size, filter_size)) 28 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 29 | if batch_norm then 30 | enc1:add(nn.SpatialBatchNormalization(feature_maps/2)) 31 | end 32 | enc1:add(nn.Threshold(0,1e-6)) 33 | 34 | enc1:add(nn.SpatialConvolution(feature_maps/2, feature_maps/4, filter_size, filter_size)) 35 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 36 | if batch_norm then 37 | enc1:add(nn.SpatialBatchNormalization(feature_maps/4)) 38 | end 39 | enc1:add(nn.Threshold(0,1e-6)) 40 | 41 | enc1:add(nn.Reshape((feature_maps/4) * 22*16)) 42 | enc1:add(nn.Linear((feature_maps/4) * 22*16, dim_hidden)) 43 | 44 | local enc2 = enc1:clone('weight', 'bias', 'gradWeight', 'gradBias') 45 | enc1 = enc1(inputs[1]) 46 | enc2 = enc2(inputs[2]) 47 | 48 | 49 | -- make the heads to analyze the encodings 50 | local heads = {} 51 | heads[1] = nn.Sequential() 52 | heads[1]:add(nn.JoinTable(2)) 53 | heads[1]:add(nn.Linear(dim_hidden * 2, dim_hidden)) 54 | heads[1]:add(nn.Sigmoid()) 55 | heads[1]:add(nn.Noise(noise)) 56 | heads[1]:add(nn.ScheduledWeightSharpener(sharpening_rate, scheduler_iteration)) 57 | heads[1]:add(nn.AddConstant(1e-20)) 58 | heads[1]:add(nn.Normalize(1, 1e-100)) 59 | 60 | for i = 2, num_heads do 61 | heads[i] = heads[1]:clone() 62 | end 63 | 64 | for i = 1, num_heads do 65 | heads[i] = heads[i]{enc1, enc2} 66 | end 67 | 68 | local dist 69 | if num_heads > 1 then 70 | -- combine the distributions from all heads 71 | local dist_adder = nn.CAddTable()(heads) 72 | local dist_clamp = nn.Clamp(0, 1)(dist_adder) 73 | dist = dist_clamp 74 | else 75 | dist = heads[1] 76 | end 77 
| 78 | -- and use it to filter the encodings 79 | local change_limiter = nn.ChangeLimiter()({dist, enc1, enc2}):annotate{name="change_limiter"} 80 | 81 | local output = {change_limiter} 82 | return nn.gModule(inputs, output) 83 | end 84 | 85 | return AtariEncoder 86 | -------------------------------------------------------------------------------- /ActionEncoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | 4 | require 'Print' 5 | require 'ChangeLimiter' 6 | require 'Noise' 7 | require 'ScheduledWeightSharpener' 8 | 9 | local ActionEncoder = function(dim_hidden, color_channels, feature_maps, noise, sharpening_rate, scheduler_iteration, batch_norm, num_heads) 10 | 11 | local filter_size = 5 12 | local inputs = { 13 | nn.Identity()():annotate{name="input1"}, 14 | nn.Identity()():annotate{name="input2"}, 15 | } 16 | 17 | -- make two copies of an encoder 18 | 19 | local enc1 = nn.Sequential() 20 | enc1:add(nn.SpatialConvolution(color_channels, feature_maps, filter_size, filter_size)) 21 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 22 | if batch_norm then 23 | enc1:add(nn.SpatialBatchNormalization(feature_maps)) 24 | end 25 | enc1:add(nn.Threshold(0,1e-6)) 26 | 27 | enc1:add(nn.SpatialConvolution(feature_maps, feature_maps/2, filter_size, filter_size)) 28 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 29 | if batch_norm then 30 | enc1:add(nn.SpatialBatchNormalization(feature_maps/2)) 31 | end 32 | enc1:add(nn.Threshold(0,1e-6)) 33 | 34 | enc1:add(nn.SpatialConvolution(feature_maps/2, feature_maps/4, filter_size, filter_size)) 35 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 36 | if batch_norm then 37 | enc1:add(nn.SpatialBatchNormalization(feature_maps/4)) 38 | end 39 | enc1:add(nn.Threshold(0,1e-6)) 40 | 41 | enc1:add(nn.Reshape((feature_maps/4) * 16*11)) 42 | enc1:add(nn.Linear((feature_maps/4) * 16*11, dim_hidden)) 43 | 44 | local enc2 = enc1:clone('weight', 'bias', 'gradWeight', 'gradBias') 45 | 
enc1 = enc1(inputs[1]) 46 | enc2 = enc2(inputs[2]) 47 | 48 | 49 | -- make the heads to analyze the encodings 50 | local heads = {} 51 | heads[1] = nn.Sequential() 52 | heads[1]:add(nn.JoinTable(2)) 53 | heads[1]:add(nn.Linear(dim_hidden * 2, dim_hidden)) 54 | heads[1]:add(nn.Sigmoid()) 55 | heads[1]:add(nn.Noise(noise)) 56 | heads[1]:add(nn.ScheduledWeightSharpener(sharpening_rate, scheduler_iteration)) 57 | heads[1]:add(nn.AddConstant(1e-20)) 58 | heads[1]:add(nn.Normalize(1, 1e-100)) 59 | 60 | for i = 2, num_heads do 61 | heads[i] = heads[1]:clone() 62 | end 63 | 64 | for i = 1, num_heads do 65 | heads[i] = heads[i]{enc1, enc2} 66 | end 67 | 68 | local dist 69 | if num_heads > 1 then 70 | -- combine the distributions from all heads 71 | local dist_adder = nn.CAddTable()(heads) 72 | local dist_clamp = nn.Clamp(0, 1)(dist_adder) 73 | dist = dist_clamp 74 | else 75 | dist = heads[1] 76 | end 77 | 78 | -- and use it to filter the encodings 79 | local change_limiter = nn.ChangeLimiter()({dist, enc1, enc2}):annotate{name="change_limiter"} 80 | 81 | 82 | local output = {change_limiter} 83 | return nn.gModule(inputs, output) 84 | end 85 | 86 | return ActionEncoder 87 | -------------------------------------------------------------------------------- /UnsupervisedEncoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | 4 | require 'Print' 5 | require 'ChangeLimiter' 6 | require 'Noise' 7 | require 'ScheduledWeightSharpener' 8 | 9 | local UnsupervisedEncoder = function(dim_hidden, color_channels, feature_maps, noise, sharpening_rate, scheduler_iteration, batch_norm, num_heads) 10 | 11 | local filter_size = 5 12 | local inputs = { 13 | nn.Identity()():annotate{name="input1"}, 14 | nn.Identity()():annotate{name="input2"}, 15 | } 16 | 17 | -- make two copies of an encoder 18 | 19 | local enc1 = nn.Sequential() 20 | enc1:add(nn.SpatialConvolution(color_channels, feature_maps, filter_size, 
filter_size)) 21 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 22 | if batch_norm then 23 | enc1:add(nn.SpatialBatchNormalization(feature_maps)) 24 | end 25 | enc1:add(nn.Threshold(0,1e-6)) 26 | 27 | enc1:add(nn.SpatialConvolution(feature_maps, feature_maps/2, filter_size, filter_size)) 28 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 29 | if batch_norm then 30 | enc1:add(nn.SpatialBatchNormalization(feature_maps/2)) 31 | end 32 | enc1:add(nn.Threshold(0,1e-6)) 33 | 34 | enc1:add(nn.SpatialConvolution(feature_maps/2, feature_maps/4, filter_size, filter_size)) 35 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 36 | if batch_norm then 37 | enc1:add(nn.SpatialBatchNormalization(feature_maps/4)) 38 | end 39 | enc1:add(nn.Threshold(0,1e-6)) 40 | 41 | enc1:add(nn.Reshape((feature_maps/4) * 15*15)) 42 | enc1:add(nn.Linear((feature_maps/4) * 15*15, dim_hidden)) 43 | 44 | local enc2 = enc1:clone('weight', 'bias', 'gradWeight', 'gradBias') 45 | enc1 = enc1(inputs[1]) 46 | enc2 = enc2(inputs[2]) 47 | 48 | 49 | -- make the heads to analyze the encodings 50 | local heads = {} 51 | heads[1] = nn.Sequential() 52 | heads[1]:add(nn.JoinTable(2)) 53 | heads[1]:add(nn.Linear(dim_hidden * 2, dim_hidden)) 54 | heads[1]:add(nn.Sigmoid()) 55 | heads[1]:add(nn.Noise(noise)) 56 | heads[1]:add(nn.ScheduledWeightSharpener(sharpening_rate, scheduler_iteration)) 57 | heads[1]:add(nn.AddConstant(1e-20)) 58 | heads[1]:add(nn.Normalize(1, 1e-100)) 59 | 60 | for i = 2, num_heads do 61 | heads[i] = heads[1]:clone() 62 | end 63 | 64 | for i = 1, num_heads do 65 | heads[i] = heads[i]{enc1, enc2} 66 | end 67 | 68 | local dist 69 | if num_heads > 1 then 70 | -- combine the distributions from all heads 71 | local dist_adder = nn.CAddTable()(heads) 72 | local dist_clamp = nn.Clamp(0, 1)(dist_adder) 73 | dist = dist_clamp 74 | else 75 | dist = heads[1] 76 | end 77 | 78 | -- and use it to filter the encodings 79 | local change_limiter = nn.ChangeLimiter()({dist, enc1, enc2}):annotate{name="change_limiter"} 80 | 81 | 
local output = {change_limiter} 82 | return nn.gModule(inputs, output) 83 | end 84 | 85 | return UnsupervisedEncoder 86 | -------------------------------------------------------------------------------- /BallsEncoder.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | 4 | require 'Print' 5 | require 'ChangeLimiter' 6 | require 'Noise' 7 | require 'ScheduledWeightSharpener' 8 | 9 | local BallsEncoder = function(dim_hidden, color_channels, feature_maps, noise, sharpening_rate, scheduler_iteration, batch_norm, num_heads) 10 | 11 | local filter_size = 5 12 | local inputs = { 13 | nn.Identity()():annotate{name="input1"}, 14 | nn.Identity()():annotate{name="input2"}, 15 | } 16 | 17 | -- make two copies of an encoder 18 | 19 | local enc1 = nn.Sequential() 20 | enc1:add(nn.SpatialConvolution(color_channels, feature_maps, filter_size, filter_size)) 21 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 22 | if batch_norm then 23 | enc1:add(nn.SpatialBatchNormalization(feature_maps)) 24 | end 25 | enc1:add(nn.Threshold(0,1e-6)) 26 | 27 | enc1:add(nn.SpatialConvolution(feature_maps, feature_maps/2, filter_size, filter_size)) 28 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 29 | if batch_norm then 30 | enc1:add(nn.SpatialBatchNormalization(feature_maps/2)) 31 | end 32 | enc1:add(nn.Threshold(0,1e-6)) 33 | 34 | enc1:add(nn.SpatialConvolution(feature_maps/2, feature_maps/4, filter_size, filter_size)) 35 | enc1:add(nn.SpatialMaxPooling(2,2,2,2)) 36 | if batch_norm then 37 | enc1:add(nn.SpatialBatchNormalization(feature_maps/4)) 38 | end 39 | enc1:add(nn.Threshold(0,1e-6)) 40 | 41 | enc1:add(nn.Reshape((feature_maps/4) * 15*15)) 42 | enc1:add(nn.Linear((feature_maps/4) * 15*15, dim_hidden)) 43 | 44 | local enc2 = enc1:clone('weight', 'bias', 'gradWeight', 'gradBias') 45 | enc1 = enc1(inputs[1]) 46 | enc2 = enc2(inputs[2]) 47 | 48 | 49 | -- make the heads to analyze the encodings 50 | local heads = {} 51 | heads[1] = 
nn.Sequential() 52 | heads[1]:add(nn.JoinTable(2)) 53 | heads[1]:add(nn.Linear(dim_hidden * 2, dim_hidden)) 54 | heads[1]:add(nn.Sigmoid()) 55 | heads[1]:add(nn.Noise(noise)) 56 | heads[1]:add(nn.ScheduledWeightSharpener(sharpening_rate, scheduler_iteration)) 57 | heads[1]:add(nn.AddConstant(1e-20)) 58 | heads[1]:add(nn.Normalize(1, 1e-100)) 59 | 60 | for i = 2, num_heads do 61 | heads[i] = heads[1]:clone() 62 | end 63 | 64 | for i = 1, num_heads do 65 | heads[i] = heads[i]{enc1, enc2} 66 | end 67 | 68 | local dist 69 | if num_heads > 1 then 70 | -- combine the distributions from all heads 71 | local dist_adder = nn.CAddTable()(heads) 72 | local dist_clamp = nn.Clamp(0, 1)(dist_adder) -- TODO is clamp the right way to go about it? 73 | dist = dist_clamp 74 | else 75 | dist = heads[1] 76 | end 77 | 78 | -- and use it to filter the encodings 79 | local change_limiter = nn.ChangeLimiter()({dist, enc1, enc2}):annotate{name="change_limiter"} 80 | 81 | local output = {change_limiter} 82 | return nn.gModule(inputs, output) 83 | end 84 | 85 | return BallsEncoder 86 | -------------------------------------------------------------------------------- /action_data_converter.lua: -------------------------------------------------------------------------------- 1 | require 'hdf5' 2 | require 'paths' 3 | require 'math' 4 | require 'xlua' 5 | 6 | --[[ 7 | Usage: Need a .h5 file below, generated from 8 | action_data_converter.py, which has been saved in 9 | 10 | 11 | --]] 12 | 13 | 14 | local dataset_name = 'actions_2_frame_subsample_5.h5' 15 | local dataset_folder = '/om/data/public/mbchang/udcign-data/action/raw/hdf5' 16 | local to_save_folder = '/om/data/public/mbchang/udcign-data/action/all' 17 | 18 | 19 | -- local dataset_name = 'test_actions_2_frame_subsample_5.h5' 20 | -- local dataset_folder = '/Users/MichaelChang/Documents/Researchlink/SuperUROP/Code/unsupervised-dcign/data/actions/raw/videos' 21 | -- local to_save_folder = 
'/Users/MichaelChang/Documents/Researchlink/SuperUROP/Code/unsupervised-dcign/data/actions/float' 22 | 23 | local bsize = 30 24 | 25 | function load_data(dataset_name, dataset_folder) 26 | local dataset_file = hdf5.open(dataset_folder .. '/' .. dataset_name, 'r') 27 | local examples = {} 28 | for action,data in pairs(dataset_file:all()) do 29 | for j = 1,data:size(1)-bsize,bsize do -- the frames in data re 30 | local batch = data[{{j,j+bsize-1}}] 31 | table.insert(examples, batch) -- each batch is a contiguous sequence (bsize, height, width) 32 | end 33 | end 34 | return examples 35 | end 36 | 37 | function split_batches(examples, bsize) 38 | local num_test = math.floor(#examples * 0.15) 39 | local num_val = num_test 40 | local num_train = #examples - 2*num_test 41 | 42 | local test = {} 43 | local val = {} 44 | local train = {} 45 | 46 | -- shuffle examples 47 | local ridxs = torch.randperm(#examples) 48 | for i = 1, ridxs:size(1) do 49 | xlua.progress(i, ridxs:size(1)) 50 | local batch = examples[ridxs[i]] 51 | if i <= num_train then 52 | table.insert(train, batch) 53 | elseif i <= num_train + num_val then 54 | table.insert(val, batch) 55 | else 56 | table.insert(test, batch) 57 | end 58 | end 59 | return {train, val, test} 60 | end 61 | 62 | 63 | function save_batches(datasets, savefolder) 64 | local train, val, test = unpack(datasets) 65 | local data_table = {train=train, val=val, test=test} 66 | for dname,data in pairs(data_table) do 67 | local subfolder = paths.concat(savefolder,dname) 68 | if not paths.dirp(subfolder) then paths.mkdir(subfolder) end 69 | local i = 1 70 | for _,b in pairs(data) do 71 | xlua.progress(i, #data) 72 | b = b:float() 73 | local batchname = paths.concat(subfolder, 'batch'..i) 74 | torch.save(batchname, b) 75 | i = i + 1 76 | end 77 | end 78 | end 79 | 80 | -- main 81 | print('loading data') 82 | local ex = load_data(dataset_name, dataset_folder) 83 | print('splitting batches') 84 | local train, val, test = unpack(split_batches(ex, 
bsize)) 85 | print('saving batches') 86 | save_batches({train, val, test}, to_save_folder) 87 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import os 4 | 5 | def num_to_one_hot(array, discrete_values): 6 | """ 7 | The values in the array come from the list discrete_values. 8 | For example, if discrete_values is [0.33, 1.0, 3.0] then all the values in this 9 | array are in [0.33, 1.0, 3.0]. 10 | 11 | This method adds another axis (last axis) and makes these values into a one-hot 12 | encoding of those discrete values. For example, if the array was shape (4,10) 13 | and len(discrete_values) was 3, then this method will produce an array 14 | with shape (4,10,3) 15 | """ 16 | n_values = len(discrete_values) 17 | broadcast = tuple([1 for i in xrange(array.ndim)] + [n_values]) 18 | array = np.tile(np.expand_dims(array,array.ndim+1), broadcast) 19 | for i in xrange(n_values): array[...,i] = array[...,i] == discrete_values[i] 20 | return array 21 | 22 | def one_hot_to_num(one_hot_vector, discrete_values): 23 | """ 24 | one_hot_vector: (n,) one hot vector 25 | discrete_values is a list of values that the onehot represents 26 | 27 | assumes that the one_hot_vector only as one 1 28 | 29 | return the VALUE in discrete_values that the one_hot_vector refers to 30 | """ 31 | # print one_hot_vector 32 | # TODO: this should return the actual value, not the index! 
33 | assert sum(one_hot_vector) == 1 # it had better have one 1 34 | return discrete_values[int(np.nonzero(one_hot_vector)[0])] 35 | 36 | def stack(list_of_nparrays): 37 | """ 38 | input 39 | :nparray: list of numpy arrays 40 | output 41 | :stack each numpy array along a new dimension: axis=0 42 | """ 43 | st = lambda a: np.vstack(([np.expand_dims(x,axis=0) for x in a])) 44 | stacked = np.vstack(([np.expand_dims(x,axis=0) for x in list_of_nparrays])) 45 | assert stacked.shape[0] == len(list_of_nparrays) 46 | assert stacked.shape[1:] == list_of_nparrays[0].shape 47 | return stacked 48 | 49 | def save_dict_to_hdf5(dataset, dataset_name, dataset_folder): 50 | print '\nSaving', dataset_name 51 | h = h5py.File(os.path.join(dataset_folder, dataset_name + '.h5'), 'w') 52 | print dataset.keys() 53 | for k, v in dataset.items(): 54 | print 'Saving', k 55 | h.create_dataset(k, data=v, dtype='float64') 56 | h.close() 57 | print 'Reading saved file' 58 | g = h5py.File(os.path.join(dataset_folder, dataset_name + '.h5'), 'r') 59 | for k in g.keys(): 60 | print k 61 | print g[k][:].shape 62 | g.close() 63 | 64 | def load_hdf5(filename, datapath): 65 | """ 66 | Loads the data stored in the datapath stored in filename as a numpy array 67 | """ 68 | data = load_dict_from_hdf5(filename) 69 | return data[datapath] 70 | 71 | def load_dict_from_hdf5(filepath): 72 | data = {} 73 | g = h5py.File(filepath, 'r') 74 | for k in g.keys(): 75 | data[k] = g[k][:] 76 | return data 77 | 78 | def subtensor_equal(subtensor, tensor, dim): 79 | """ 80 | Return if subtensor, when broaadcasted along dim 81 | """ 82 | num_copies = tensor.shape[dim] 83 | 84 | subtensor_stack = np.concatenate([subtensor for s in num_copies], dim=dim) 85 | 86 | return subtensor_stack == tensor 87 | -------------------------------------------------------------------------------- /action_data_converter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import 
copy 4 | import utils as u 5 | import numpy as np 6 | import itertools 7 | from progressbar import ProgressBar 8 | 9 | # Usage: The videos are saved in '/om/data/public/mbchang/udcign-data/action/raw/videos' 10 | # I haven't figured out how to use cv2 on openmind yet, so copy the videos 11 | # to your local computer as the root folder below. There should be a folder for 12 | # each of the actions below under the videos folder 13 | 14 | # pc 15 | root = '/Users/MichaelChang/Documents/Researchlink/SuperUROP/Code/data/udcign/action/videos' 16 | out = '/Users/MichaelChang/Documents/Researchlink/SuperUROP/Code/data/udcign/action/hdf5' 17 | 18 | # openmind 19 | # root = '/om/data/public/mbchang/udcign-data/action/raw/videos' 20 | # out = '/om/data/public/mbchang/udcign-data/action/raw/hdf5' 21 | 22 | actions = ['boxing', 'handclapping', 'handwaving', 'jogging', 'running', 'walking'] 23 | scenario = 'd4' # d4 means outdoors 24 | subsample = 1 25 | gray = True 26 | 27 | action_data = {} 28 | 29 | for action in actions: 30 | print action 31 | action_vid_folder = os.path.join(root,action) 32 | action_vids = {} 33 | pbar = ProgressBar() 34 | for i in pbar(range(len(os.listdir(action_vid_folder)))): 35 | vid = os.listdir(action_vid_folder)[i] 36 | if scenario not in vid: continue 37 | # print vid 38 | # continue 39 | cap = cv2.VideoCapture(os.path.join(action_vid_folder, vid)) 40 | num_frames = cap.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT) 41 | video = [] 42 | # pbar = ProgressBar() 43 | 44 | # the frames are guaranteed to be consecutive 45 | for i in range(int(num_frames)): 46 | ret, frame = cap.read() 47 | # print frame 48 | 49 | if gray: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) # shape: (height, width) it is gray anyway 50 | if i % subsample == 0: 51 | frame = frame/float(255) # normalize 52 | frame = frame.astype('float32') 53 | # cv2.imshow('frame',frame) 54 | if gray: frame = np.tile(frame,(1,1,1)) # give it the channel dim 55 | video.append(frame) 56 | 57 | # import 
time 58 | # time.sleep(0.5) 59 | if cv2.waitKey(1) & 0xFF == ord('q'): 60 | break 61 | cap.release() 62 | # cv2.destroyAllWindows() 63 | 64 | video = u.stack(video) 65 | action_vids[vid] = video # video, subsampled, evenly spaced, consecutive video 66 | pbar.finish() 67 | u.save_dict_to_hdf5(dataset=action_vids, dataset_name=action+'_subsamp='+str(subsample)+'_scenario='+scenario, dataset_folder=out) 68 | 69 | # action_vids = np.vstack(action_vids) # consecutive video ACTUALLY THIS MIGHT NOT BE TRUE. WE NEED THE VIDEOS TO BE SEPARATE! 70 | 71 | # randomly permute -- don't do this! 72 | # tm1s = np.random.permutation(range(0,len(action_vids)-1,2)) 73 | # ts = np.array([i+1 for i in tm1s]) 74 | # shuffle_idxs = list(it.next() for it in itertools.cycle([iter(tm1s), iter(ts)])) # groups of 2 75 | # action_vids = action_vids[np.array(shuffle_idxs),:,:] 76 | 77 | # action_data[action] = action_vids 78 | 79 | # save 80 | # u.save_dict_to_hdf5(action_data, 'actions_2_frame_subsample_' + str(subsample), root) 81 | -------------------------------------------------------------------------------- /action_data_converter2.lua: -------------------------------------------------------------------------------- 1 | require 'hdf5' 2 | require 'paths' 3 | require 'math' 4 | require 'xlua' 5 | 6 | --[[ 7 | Usage: Need a .h5 file below, generated from 8 | action_data_converter.py, which has been saved in 9 | 10 | The difference between this and action_data_converter.lua is that this 11 | works with h5 files in which each h5 file is an action, rather than all 12 | the actions inside one h5 file 13 | --]] 14 | 15 | 16 | 17 | 18 | function load_data(dataset_name, dataset_folder, bsize) 19 | local dataset_file = hdf5.open(dataset_folder .. '/' .. 
dataset_name, 'r') 20 | local examples = {} 21 | for action,data in pairs(dataset_file:all()) do 22 | for j = 1,data:size(1)-bsize,bsize do -- the frames in data re 23 | local batch = data[{{j,j+bsize-1}}] 24 | table.insert(examples, batch) -- each batch is a contiguous sequence (bsize, height, width) 25 | end 26 | end 27 | return examples 28 | end 29 | 30 | function split_batches(examples) 31 | local num_test = math.floor(#examples * 0.15) 32 | local num_val = num_test 33 | local num_train = #examples - 2*num_test 34 | 35 | local test = {} 36 | local val = {} 37 | local train = {} 38 | 39 | -- shuffle examples 40 | local ridxs = torch.randperm(#examples) 41 | for i = 1, ridxs:size(1) do 42 | xlua.progress(i, ridxs:size(1)) 43 | local batch = examples[ridxs[i]] 44 | if i <= num_train then 45 | table.insert(train, batch) 46 | elseif i <= num_train + num_val then 47 | table.insert(val, batch) 48 | else 49 | table.insert(test, batch) 50 | end 51 | end 52 | return {train, val, test} 53 | end 54 | 55 | 56 | function save_batches(datasets, savefolder) 57 | local train, val, test = unpack(datasets) 58 | local data_table = {train=train, val=val, test=test} 59 | for dname,data in pairs(data_table) do 60 | local subfolder = paths.concat(savefolder,dname) 61 | if not paths.dirp(subfolder) then paths.mkdir(subfolder) end 62 | local i = 1 63 | for _,b in pairs(data) do 64 | xlua.progress(i, #data) 65 | b = b:float() 66 | local batchname = paths.concat(subfolder, 'batch'..i) 67 | -- torch.save(batchname, b) 68 | i = i + 1 69 | end 70 | end 71 | end 72 | 73 | function main() 74 | local actions = {'running', 'jogging', 'walking', 'handclapping', 'handwaving', 'boxing'} 75 | local dataset_folder = '/om/data/public/mbchang/udcign-data/action/raw/hdf5' 76 | local bsize = 30 77 | 78 | for _,action in pairs(actions) do 79 | local dataset_name = action .. 
'_subsamp=1_scenario=d4.h5' 80 | local to_save_folder = '/om/data/public/mbchang/udcign-data/action/'..action..'_d4' 81 | if not paths.dirp(to_save_folder) then paths.mkdir(to_save_folder) end 82 | print('dataset:'..dataset_name) 83 | print('to save:'..to_save_folder) 84 | 85 | -- main 86 | print('loading data') 87 | local ex = load_data(dataset_name, dataset_folder, bsize) 88 | print('splitting batches') 89 | local train, val, test = unpack(split_batches(ex)) 90 | print('saving batches') 91 | save_batches({train, val, test}, to_save_folder) 92 | end 93 | end 94 | 95 | 96 | main() 97 | -------------------------------------------------------------------------------- /action_data_converter_all.lua: -------------------------------------------------------------------------------- 1 | require 'hdf5' 2 | require 'paths' 3 | require 'math' 4 | require 'xlua' 5 | 6 | --[[ 7 | Usage: Need a .h5 file below, generated from 8 | action_data_converter.py, which has been saved in 9 | 10 | The difference between this and action_data_converter.lua is that this 11 | works with h5 files in which each h5 file is an action, rather than all 12 | the actions inside one h5 file 13 | --]] 14 | 15 | 16 | function load_data(dataset_name, dataset_folder, bsize) 17 | local dataset_file = hdf5.open(dataset_folder .. '/' .. 
dataset_name, 'r') 18 | local examples = {} 19 | for action,data in pairs(dataset_file:all()) do 20 | for j = 1,data:size(1)-bsize,bsize do -- the frames in data re 21 | local batch = data[{{j,j+bsize-1}}] 22 | table.insert(examples, batch) -- each batch is a contiguous sequence (bsize, height, width) 23 | end 24 | end 25 | return examples 26 | end 27 | 28 | function split_batches(examples) 29 | local num_test = math.floor(#examples * 0.15) 30 | local num_val = num_test 31 | local num_train = #examples - 2*num_test 32 | 33 | local test = {} 34 | local val = {} 35 | local train = {} 36 | 37 | -- shuffle examples 38 | local ridxs = torch.randperm(#examples) 39 | for i = 1, ridxs:size(1) do 40 | xlua.progress(i, ridxs:size(1)) 41 | local batch = examples[ridxs[i]] 42 | if i <= num_train then 43 | table.insert(train, batch) 44 | elseif i <= num_train + num_val then 45 | table.insert(val, batch) 46 | else 47 | table.insert(test, batch) 48 | end 49 | end 50 | return {train, val, test} 51 | end 52 | 53 | function save_batches(datasets, savefolder, idxs) 54 | local train, val, test = unpack(datasets) 55 | local data_table = {train=train, val=val, test=test} 56 | for dname,data in pairs(data_table) do 57 | local subfolder = paths.concat(savefolder,dname) 58 | if not paths.dirp(subfolder) then paths.mkdir(subfolder) end 59 | for _,b in pairs(data) do 60 | -- xlua.progress(idxs[dname], #data) 61 | local batchname = paths.concat(subfolder, 'batch'..idxs[dname]) 62 | print(dname..': '..batchname) 63 | torch.save(batchname, b:float()) 64 | idxs[dname] = idxs[dname] + 1 65 | end 66 | end 67 | return idxs 68 | end 69 | 70 | function main_all() 71 | local scenario = 'd4' 72 | local actions = {'running', 'jogging', 'walking', 'handclapping', 'handwaving', 'boxing'} 73 | local dataset_folder = '/om/data/public/mbchang/udcign-data/action/raw/hdf5' 74 | local to_save_folder = '/om/data/public/mbchang/udcign-data/action/allactions'..scenario 75 | if not paths.dirp(to_save_folder) then 
paths.mkdir(to_save_folder) end 76 | local bsize = 30 77 | local idxs = {train=1,val=1,test=1} -- train, val, test 78 | print(idxs) 79 | 80 | for _,action in pairs(actions) do 81 | local dataset_name = action .. '_subsamp=1_scenario=d4.h5' 82 | print('dataset:'..dataset_name) 83 | print('to save:'..to_save_folder) 84 | 85 | -- main 86 | print('loading data') 87 | local ex = load_data(dataset_name, dataset_folder, bsize) 88 | print('splitting batches') 89 | local train, val, test = unpack(split_batches(ex)) 90 | print('saving batches') 91 | idxs = save_batches({train, val, test}, to_save_folder, idxs) 92 | print(idxs) 93 | end 94 | end 95 | 96 | 97 | main_all() 98 | -------------------------------------------------------------------------------- /balls_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | dry_run = '--dry-run' in sys.argv 5 | local = '--local' in sys.argv 6 | detach = '--detach' in sys.argv 7 | 8 | dry_run = True 9 | local = False 10 | detach = True 11 | 12 | if not os.path.exists("slurm_logs"): 13 | os.makedirs("slurm_logs") 14 | 15 | if not os.path.exists("slurm_scripts"): 16 | os.makedirs("slurm_scripts") 17 | 18 | 19 | # networks_prefix = "networks" 20 | 21 | base_networks = { 22 | } 23 | 24 | 25 | # Don't give it a save name - that gets generated for you 26 | # jobs = [ 27 | # { 28 | # "import": "onestep", 29 | # }, 30 | # 31 | # 32 | # ] 33 | 34 | jobs = [{'mode':m, 'subsample':s, 'num_balls':n} 35 | for m in ['train', 'val', 'test'] 36 | for s in [3] 37 | for n in [1,2,3,4,5,6]] 38 | 39 | 40 | if dry_run: 41 | print "NOT starting jobs:" 42 | else: 43 | print "Starting jobs:" 44 | 45 | for job in jobs: 46 | jobname = "ballsgen" 47 | flagstring = "" 48 | for flag in job: 49 | if isinstance(job[flag], bool): 50 | if job[flag]: 51 | jobname = jobname + "_" + flag 52 | flagstring = flagstring + " --" + flag 53 | else: 54 | print "WARNING: Excluding 'False' flag " + flag 
55 | elif flag == 'import': 56 | imported_network_name = job[flag] 57 | if imported_network_name in base_networks.keys(): 58 | network_location = base_networks[imported_network_name] 59 | jobname = jobname + "_" + flag + "_" + str(imported_network_name) 60 | flagstring = flagstring + " --" + flag + " " + str(network_location) 61 | else: 62 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 63 | flagstring = flagstring + " --" + flag + " " + networks_prefix + "/" + str(job[flag]) 64 | else: 65 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 66 | flagstring = flagstring + " --" + flag + " " + str(job[flag]) 67 | flagstring = flagstring #+ " --name " + jobname 68 | 69 | jobcommand = "python bouncing_balls.py" + flagstring 70 | 71 | print(jobcommand) 72 | if local and not dry_run: 73 | if detach: 74 | os.system(jobcommand + ' 2> slurm_logs/' + jobname + '.err 1> slurm_logs/' + jobname + '.out &') 75 | else: 76 | os.system(jobcommand) 77 | 78 | else: 79 | with open('slurm_scripts/' + jobname + '.slurm', 'w') as slurmfile: 80 | slurmfile.write("#!/bin/bash\n") 81 | slurmfile.write("#SBATCH --job-name"+"=" + jobname + "\n") 82 | slurmfile.write("#SBATCH --output=slurm_logs/" + jobname + ".out\n") 83 | # slurmfile.write("#SBATCH --error=slurm_logs/" + jobname + ".err\n") 84 | slurmfile.write("#SBATCH -N 1\n") 85 | slurmfile.write("#SBATCH -c 2\n") 86 | # slurmfile.write("#SBATCH -p gpu\n") 87 | # slurmfile.write("#SBATCH --gres=gpu:1\n") 88 | slurmfile.write("#SBATCH --mem=3000\n") 89 | slurmfile.write("#SBATCH --time=6-23:00:00\n") 90 | slurmfile.write("#SBATCH -x node027\n") 91 | slurmfile.write(jobcommand) 92 | 93 | if not dry_run: 94 | # if 'gpu' in job and job['gpu']: 95 | # os.system("sbatch -N 1 -c 2 --gres=gpu:1 -p gpu --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 96 | # else: 97 | # os.system("sbatch -N 1 -c 2 --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 98 | os.system("sbatch slurm_scripts/" + jobname + 
".slurm &") 99 | -------------------------------------------------------------------------------- /runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | dry_run = '--dry-run' in sys.argv 5 | local = '--local' in sys.argv 6 | detach = '--detach' in sys.argv 7 | 8 | if not os.path.exists("slurm_logs"): 9 | os.makedirs("slurm_logs") 10 | 11 | if not os.path.exists("slurm_scripts"): 12 | os.makedirs("slurm_scripts") 13 | 14 | 15 | networks_prefix = "networks" 16 | 17 | base_networks = { 18 | } 19 | 20 | 21 | # Don't give it a save name - that gets generated for you 22 | # jobs = [ 23 | # { 24 | # "import": "onestep", 25 | # }, 26 | # 27 | # 28 | # ] 29 | 30 | jobs = [] 31 | 32 | noise_options = [0.1] 33 | sharpening_rate_options = [8] 34 | learning_rate_options = [3e-4] 35 | heads_options = [1] 36 | # L2_options = [1e-2, 1e-3, 1e-4] 37 | 38 | for noise in noise_options: 39 | for sharpening_rate in sharpening_rate_options: 40 | for learning_rate in learning_rate_options: 41 | for heads in heads_options: 42 | job = { 43 | "noise": noise, 44 | "sharpening_rate": sharpening_rate, 45 | "learning_rate": learning_rate, 46 | "heads": heads, 47 | 48 | "dual_objectives": True, 49 | 50 | "gpu": True, 51 | } 52 | jobs.append(job) 53 | 54 | 55 | if dry_run: 56 | print "NOT starting jobs:" 57 | else: 58 | print "Starting jobs:" 59 | 60 | for job in jobs: 61 | jobname = "unsup" 62 | flagstring = "" 63 | for flag in job: 64 | if isinstance(job[flag], bool): 65 | if job[flag]: 66 | jobname = jobname + "_" + flag 67 | flagstring = flagstring + " --" + flag 68 | else: 69 | print "WARNING: Excluding 'False' flag " + flag 70 | elif flag == 'import': 71 | imported_network_name = job[flag] 72 | if imported_network_name in base_networks.keys(): 73 | network_location = base_networks[imported_network_name] 74 | jobname = jobname + "_" + flag + "_" + str(imported_network_name) 75 | flagstring = flagstring + " --" + flag + 
" " + str(network_location) 76 | else: 77 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 78 | flagstring = flagstring + " --" + flag + " " + networks_prefix + "/" + str(job[flag]) 79 | else: 80 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 81 | flagstring = flagstring + " --" + flag + " " + str(job[flag]) 82 | flagstring = flagstring + " --name " + jobname 83 | 84 | jobcommand = "th main.lua" + flagstring 85 | 86 | print(jobcommand) 87 | if local and not dry_run: 88 | if detach: 89 | os.system(jobcommand + ' 2> slurm_logs/' + jobname + '.err 1> slurm_logs/' + jobname + '.out &') 90 | else: 91 | os.system(jobcommand) 92 | 93 | else: 94 | with open('slurm_scripts/' + jobname + '.slurm', 'w') as slurmfile: 95 | slurmfile.write("#!/bin/bash\n") 96 | slurmfile.write("#SBATCH --job-name"+"=" + jobname + "\n") 97 | slurmfile.write("#SBATCH --output=slurm_logs/" + jobname + ".out\n") 98 | slurmfile.write("#SBATCH --error=slurm_logs/" + jobname + ".err\n") 99 | slurmfile.write(jobcommand) 100 | 101 | if not dry_run: 102 | if 'gpu' in job and job['gpu']: 103 | os.system("sbatch -N 1 -c 2 --gres=gpu:1 -p gpu --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 104 | else: 105 | os.system("sbatch -N 1 -c 2 --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 106 | -------------------------------------------------------------------------------- /render_action_examples.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cutorch' 3 | require 'cunn' 4 | require 'paths' 5 | require 'lfs' 6 | 7 | vis = require 'vis' 8 | require 'ActionEncoder' 9 | require 'ActionDecoder' 10 | local data_loaders = require 'data_loaders' 11 | 12 | name = arg[1] 13 | -- dataset_name = arg[2] or name 14 | networks = {} 15 | while true do 16 | local line = io.read() 17 | if line == nil then 18 | break 19 | elseif not (string.match(line, 'NAME')) then 20 | -- strip whitespace 21 | line = 
string.gsub(line, "%s+", "") 22 | 23 | table.insert(networks, line) 24 | end 25 | end 26 | 27 | -- opt = { 28 | -- datasetdir = '/om/user/wwhitney/deep-game-engine', 29 | -- dataset_name = dataset_name, 30 | -- gpu = true, 31 | -- } 32 | 33 | base_directory = "/home/mbchang/code/unsupervised-dcign/logslink" -- TODO 34 | 35 | local jobname = name ..'_'.. os.date("%b_%d_%H_%M") 36 | local output_path = 'renderings/'..jobname 37 | os.execute('mkdir -p '..output_path) 38 | 39 | 40 | function getLastSnapshot(network_name) 41 | local res_file = io.popen("ls -t "..paths.concat(base_directory, network_name).." | grep -i epoch | head -n 1") 42 | local result = res_file:read():match( "^%s*(.-)%s*$" ) 43 | res_file:close() 44 | return result 45 | end 46 | 47 | for _, network in ipairs(networks) do 48 | collectgarbage() 49 | 50 | print('') 51 | print(network) 52 | local checkpoint = torch.load(paths.concat(base_directory, network, getLastSnapshot(network))) 53 | opt = checkpoint.opt 54 | local model = checkpoint.model 55 | local scheduler_iteration = torch.Tensor{checkpoint.step} 56 | model:evaluate() 57 | 58 | local encoder = model.modules[1] 59 | local sharpener = encoder:findModules('nn.ScheduledWeightSharpener')[1] 60 | sharpener.iteration_container = scheduler_iteration 61 | print("Current sharpening: ", sharpener:getP()) 62 | 63 | local weight_predictor = encoder:findModules('nn.Normalize')[1] 64 | local previous_embedding = encoder:findModules('nn.Linear')[1] 65 | local current_embedding = encoder:findModules('nn.Linear')[2] 66 | local decoder = model.modules[2] 67 | 68 | local images = {} 69 | for i = 100, 100 do -- for now only render one batch 70 | -- fetch a batch 71 | local input = data_loaders.load_action_batch(i, 'test') 72 | local output = model:forward(input):clone() 73 | local embedding_from_previous = previous_embedding.output:clone() 74 | local embedding_from_current = current_embedding.output:clone() 75 | 76 | local reconstruction_from_previous = 
decoder:forward(embedding_from_previous):clone() 77 | local reconstruction_from_current = decoder:forward(embedding_from_current):clone() 78 | 79 | local weight_norms = torch.zeros(output:size(1)) 80 | for input_index = 1, output:size(1) do 81 | local weights = weight_predictor.output[input_index] 82 | local max_weight, varying_index = weights:max(1) 83 | print("Varying index: " .. vis.simplestr(varying_index), "Weight: " .. vis.simplestr(max_weight)) 84 | 85 | -- local embedding_change = embedding_from_current[input_index] - embedding_from_previous[input_index] 86 | -- local normalized_embedding_change = embedding_change / embedding_change:norm(1) 87 | -- print("Independence of embedding change: ", normalized_embedding_change:norm()) 88 | -- print("Distance between timesteps: ", embedding_change:norm()) 89 | 90 | weight_norms[input_index] = weights:norm() 91 | 92 | local image_row = {} 93 | table.insert(image_row, input[1][input_index]:float()) 94 | table.insert(image_row, input[2][input_index]:float()) 95 | table.insert(image_row, reconstruction_from_previous[input_index]:float()) 96 | table.insert(image_row, reconstruction_from_current[input_index]:float()) 97 | table.insert(image_row, output[input_index]:float()) 98 | table.insert(images, image_row) 99 | end 100 | print("Mean independence of weights: ", weight_norms:mean()) 101 | 102 | collectgarbage() 103 | end 104 | vis.save_image_grid(paths.concat(output_path, network ..'.png'), images) 105 | end 106 | 107 | 108 | print("done") 109 | -------------------------------------------------------------------------------- /render_examples.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cutorch' 3 | require 'cunn' 4 | require 'paths' 5 | require 'lfs' 6 | 7 | vis = require 'vis' 8 | require 'UnsupervisedEncoder' 9 | require 'Decoder' 10 | local data_loaders = require 'data_loaders' 11 | 12 | name = arg[1] 13 | networks = {} 14 | while true do 15 | 
local line = io.read() 16 | if line == nil then break end 17 | 18 | -- strip whitespace 19 | line = string.gsub(line, "%s+", "") 20 | 21 | table.insert(networks, line) 22 | end 23 | 24 | opt = { 25 | datasetdir = '/om/user/wwhitney/facegen/CNN_DATASET', 26 | gpu = true, 27 | } 28 | 29 | base_directory = "/om/user/wwhitney/unsupervised-dcign/networks" 30 | 31 | local jobname = name ..'_'.. os.date("%b_%d_%H_%M") 32 | local output_path = 'reports/renderings/'..jobname 33 | os.execute('mkdir -p '..output_path) 34 | 35 | local dataset_types = {"AZ_VARIED", "EL_VARIED", "LIGHT_AZ_VARIED"} 36 | 37 | 38 | function getLastSnapshot(network_name) 39 | local res_file = io.popen("ls -t "..paths.concat(base_directory, network_name).." | grep -i epoch | head -n 1") 40 | local result = res_file:read():match( "^%s*(.-)%s*$" ) 41 | res_file:close() 42 | return result 43 | end 44 | 45 | for _, network in ipairs(networks) do 46 | collectgarbage() 47 | print('') 48 | print(network) 49 | local checkpoint = torch.load(paths.concat(base_directory, network, getLastSnapshot(network))) 50 | local model = checkpoint.model 51 | local scheduler_iteration = torch.Tensor{checkpoint.step} 52 | model:evaluate() 53 | 54 | local encoder = model.modules[1] 55 | local sharpener = encoder:findModules('nn.ScheduledWeightSharpener')[1] 56 | sharpener.iteration_container = scheduler_iteration 57 | print("Current sharpening: ", sharpener:getP()) 58 | 59 | local weight_predictor = encoder:findModules('nn.Normalize')[1] 60 | local previous_embedding = encoder:findModules('nn.Linear')[1] 61 | local current_embedding = encoder:findModules('nn.Linear')[2] 62 | local decoder = model.modules[2] 63 | 64 | for _, variation in ipairs(dataset_types) do 65 | local images = {} 66 | for i = 1, 1 do -- for now only render one batch 67 | -- fetch a batch 68 | local input = data_loaders.load_mv_batch(i, variation, 'FT_test') 69 | local output = model:forward(input):clone() 70 | local embedding_from_previous = 
previous_embedding.output:clone() 71 | local embedding_from_current = current_embedding.output:clone() 72 | 73 | local reconstruction_from_previous = decoder:forward(embedding_from_previous):clone() 74 | local reconstruction_from_current = decoder:forward(embedding_from_current):clone() 75 | 76 | local weight_norms = torch.zeros(output:size(1)) 77 | for input_index = 1, output:size(1) do 78 | local weights = weight_predictor.output[input_index] 79 | local max_weight, varying_index = weights:max(1) 80 | print("Varying index: " .. vis.simplestr(varying_index), "Weight: " .. vis.simplestr(max_weight)) 81 | 82 | -- local embedding_change = embedding_from_current[input_index] - embedding_from_previous[input_index] 83 | -- local normalized_embedding_change = embedding_change / embedding_change:norm(1) 84 | -- print("Independence of embedding change: ", normalized_embedding_change:norm()) 85 | -- print("Distance between timesteps: ", embedding_change:norm()) 86 | 87 | weight_norms[input_index] = weights:norm() 88 | 89 | local image_row = {} 90 | table.insert(image_row, input[1][input_index]:float()) 91 | table.insert(image_row, input[2][input_index]:float()) 92 | table.insert(image_row, reconstruction_from_previous[input_index]:float()) 93 | table.insert(image_row, reconstruction_from_current[input_index]:float()) 94 | table.insert(image_row, output[input_index]:float()) 95 | table.insert(images, image_row) 96 | end 97 | print("Mean independence of weights: ", weight_norms:mean()) 98 | 99 | end 100 | vis.save_image_grid(paths.concat(output_path, network .."-"..variation..'.png'), images) 101 | end 102 | collectgarbage() 103 | end 104 | 105 | 106 | print("done") 107 | -------------------------------------------------------------------------------- /atari_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | dry_run = '--dry-run' in sys.argv 5 | local = '--local' in sys.argv 6 | detach = '--detach' 
in sys.argv 7 | 8 | if not os.path.exists("slurm_logs"): 9 | os.makedirs("slurm_logs") 10 | 11 | if not os.path.exists("slurm_scripts"): 12 | os.makedirs("slurm_scripts") 13 | 14 | 15 | networks_prefix = "networks" 16 | 17 | base_networks = { 18 | } 19 | 20 | 21 | # Don't give it a save name - that gets generated for you 22 | # jobs = [ 23 | # { 24 | # "import": "onestep", 25 | # }, 26 | # 27 | # 28 | # ] 29 | 30 | jobs = [] 31 | 32 | noise_options = [0.1] 33 | sharpening_rate_options = [10] 34 | learning_rate_options = [2e-4] 35 | heads_options = [1, 3] 36 | motion_scale_options = [3] 37 | frame_interval_options = [3] 38 | dataset_name_options = ["breakout", "space_invaders"] 39 | # L2_options = [1e-2, 1e-3, 1e-4] 40 | 41 | for noise in noise_options: 42 | for sharpening_rate in sharpening_rate_options: 43 | for learning_rate in learning_rate_options: 44 | for heads in heads_options: 45 | for motion_scale in motion_scale_options: 46 | for frame_interval in frame_interval_options: 47 | for dataset_name in dataset_name_options: 48 | job = { 49 | "noise": noise, 50 | "sharpening_rate": sharpening_rate, 51 | "learning_rate": learning_rate, 52 | "heads": heads, 53 | "motion_scale": motion_scale, 54 | "frame_interval": frame_interval, 55 | "dataset_name": dataset_name, 56 | 57 | "dual_objectives": True, 58 | "gpu": True, 59 | } 60 | jobs.append(job) 61 | 62 | 63 | if dry_run: 64 | print "NOT starting jobs:" 65 | else: 66 | print "Starting jobs:" 67 | 68 | for job in jobs: 69 | jobname = "atari" 70 | flagstring = "" 71 | for flag in job: 72 | if isinstance(job[flag], bool): 73 | if job[flag]: 74 | jobname = jobname + "_" + flag 75 | flagstring = flagstring + " --" + flag 76 | else: 77 | print "WARNING: Excluding 'False' flag " + flag 78 | elif flag == 'import': 79 | imported_network_name = job[flag] 80 | if imported_network_name in base_networks.keys(): 81 | network_location = base_networks[imported_network_name] 82 | jobname = jobname + "_" + flag + "_" + 
str(imported_network_name) 83 | flagstring = flagstring + " --" + flag + " " + str(network_location) 84 | else: 85 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 86 | flagstring = flagstring + " --" + flag + " " + networks_prefix + "/" + str(job[flag]) 87 | else: 88 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 89 | flagstring = flagstring + " --" + flag + " " + str(job[flag]) 90 | flagstring = flagstring + " --name " + jobname 91 | 92 | jobcommand = "th atari_main.lua" + flagstring 93 | 94 | print(jobcommand) 95 | if local and not dry_run: 96 | if detach: 97 | os.system(jobcommand + ' 2> slurm_logs/' + jobname + '.err 1> slurm_logs/' + jobname + '.out &') 98 | else: 99 | os.system(jobcommand) 100 | 101 | else: 102 | with open('slurm_scripts/' + jobname + '.slurm', 'w') as slurmfile: 103 | slurmfile.write("#!/bin/bash\n") 104 | slurmfile.write("#SBATCH --job-name"+"=" + jobname + "\n") 105 | slurmfile.write("#SBATCH --output=slurm_logs/" + jobname + ".out\n") 106 | slurmfile.write("#SBATCH --error=slurm_logs/" + jobname + ".err\n") 107 | slurmfile.write(jobcommand) 108 | 109 | if not dry_run: 110 | if 'gpu' in job and job['gpu']: 111 | os.system("sbatch -N 1 -c 2 --gres=gpu:1 -p gpu --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 112 | else: 113 | os.system("sbatch -N 1 -c 2 --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 114 | -------------------------------------------------------------------------------- /render_balls_examples.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cutorch' 3 | require 'cunn' 4 | require 'paths' 5 | require 'lfs' 6 | 7 | vis = require 'vis' 8 | require 'BallsEncoder' 9 | require 'Decoder' 10 | local data_loaders = require 'data_loaders' 11 | 12 | name = arg[1] 13 | -- dataset_name = arg[2] or name 14 | networks = {} 15 | while true do 16 | local line = io.read() 17 | if line == nil then 18 | break 19 | elseif 
not (string.match(line, 'NAME')) then 20 | -- strip whitespace 21 | line = string.gsub(line, "%s+", "") 22 | 23 | table.insert(networks, line) 24 | end 25 | end 26 | 27 | -- opt = { 28 | -- datasetdir = '/om/user/wwhitney/deep-game-engine', 29 | -- dataset_name = dataset_name, 30 | -- gpu = true, 31 | -- } 32 | 33 | base_directory = '/home/mbchang/code/unsupervised-dcign/logslink' 34 | 35 | local jobname = name ..'_'.. os.date("%b_%d_%H_%M") 36 | local output_path = 'renderings/'..jobname 37 | os.execute('mkdir -p '..output_path) 38 | 39 | 40 | function getLastSnapshot(network_name) 41 | local res_file = io.popen("ls -t "..paths.concat(base_directory, network_name).." | grep -i epoch | head -n 1") 42 | local status, result = pcall(function() return res_file:read():match( "^%s*(.-)%s*$" ) end) 43 | res_file:close() 44 | if not status then 45 | return false 46 | else 47 | return result 48 | end 49 | end 50 | 51 | for _, network in ipairs(networks) do 52 | collectgarbage() 53 | 54 | print('') 55 | print(network) 56 | local checkpoint = torch.load(paths.concat(base_directory, network, getLastSnapshot(network))) 57 | opt = checkpoint.opt 58 | local model = checkpoint.model 59 | local scheduler_iteration = torch.Tensor{checkpoint.step} 60 | model:evaluate() 61 | 62 | local encoder = model.modules[1] 63 | local sharpener = encoder:findModules('nn.ScheduledWeightSharpener')[1] 64 | sharpener.iteration_container = scheduler_iteration 65 | print("Current sharpening: ", sharpener:getP()) 66 | 67 | local weight_predictor = encoder:findModules('nn.Normalize')[1] 68 | local previous_embedding = encoder:findModules('nn.Linear')[1] 69 | local current_embedding = encoder:findModules('nn.Linear')[2] 70 | local decoder = model.modules[2] 71 | 72 | local images = {} 73 | for i = 10, 10 do -- for now only render one batch 74 | -- fetch a batch 75 | local input = data_loaders.load_balls_batch(i, 'test') 76 | local output = model:forward(input):clone() 77 | local embedding_from_previous 
= previous_embedding.output:clone() 78 | local embedding_from_current = current_embedding.output:clone() 79 | 80 | local reconstruction_from_previous = decoder:forward(embedding_from_previous):clone() 81 | local reconstruction_from_current = decoder:forward(embedding_from_current):clone() 82 | 83 | local weight_norms = torch.zeros(output:size(1)) 84 | for input_index = 1, output:size(1) do 85 | local weights = weight_predictor.output[input_index] 86 | local max_weight, varying_index = weights:max(1) 87 | print("Varying index: " .. vis.simplestr(varying_index), "Weight: " .. vis.simplestr(max_weight)) 88 | 89 | -- local embedding_change = embedding_from_current[input_index] - embedding_from_previous[input_index] 90 | -- local normalized_embedding_change = embedding_change / embedding_change:norm(1) 91 | -- print("Independence of embedding change: ", normalized_embedding_change:norm()) 92 | -- print("Distance between timesteps: ", embedding_change:norm()) 93 | 94 | weight_norms[input_index] = weights:norm() 95 | 96 | local image_row = {} 97 | table.insert(image_row, input[1][input_index]:float()) 98 | table.insert(image_row, input[2][input_index]:float()) 99 | table.insert(image_row, reconstruction_from_previous[input_index]:float()) 100 | table.insert(image_row, reconstruction_from_current[input_index]:float()) 101 | table.insert(image_row, output[input_index]:float()) 102 | table.insert(images, image_row) 103 | end 104 | print("Mean independence of weights: ", weight_norms:mean()) 105 | 106 | collectgarbage() 107 | end 108 | vis.save_image_grid(paths.concat(output_path, network ..'.png'), images) 109 | end 110 | 111 | 112 | print("done") 113 | -------------------------------------------------------------------------------- /kitti_data_converter.lua: -------------------------------------------------------------------------------- 1 | require 'paths' 2 | require 'image' 3 | require 'data_utils' 4 | require 'utils' 5 | 6 | local raw_root = 
'/om/data/public/mbchang/udcign-data/kitti/raw/videos' 7 | local out_root = '/om/data/public/mbchang/udcign-data/kitti/data/all' 8 | local bsize = 30 9 | local ch = 'image_02' 10 | local data_root = '/om/data/public/mbchang/udcign-data/kitti/data_bsize'..bsize 11 | local dim = 150 12 | local id = 0 13 | local subsample = 1 14 | 15 | -- create all the batches for one setting 16 | -- a setting can be 'road', 'campus', etc 17 | function create_batches(setting_folder, bsize, idxs) 18 | print(idxs) 19 | local setting_ex = {} 20 | for group in paths.iterdirs(setting_folder) do 21 | local img_folder = paths.concat(setting_folder,group,ch,'data_resize') 22 | local ex = get_examples(img_folder) 23 | local groups2 = duplicate(ex) -- now in groups of 2 24 | assert(#groups2 == #ex - 1) 25 | setting_ex = extend(setting_ex,groups2) 26 | print(#groups2,#setting_ex) 27 | end 28 | print(#setting_ex) 29 | -- setting_ex is now a huge table of 2-tables for this particular setting 30 | local setting_batches = group2batches(setting_ex,bsize) 31 | -- tables of examples of (1,150,150) 32 | print('split batches: '..#setting_batches) 33 | local train, val, test = unpack(split_batches(setting_batches, bsize)) 34 | 35 | -- now save 36 | print('save batches: train '..#train..' val '..#val..' 
test '..#test) 37 | local new_idxs = save_batches({train, val, test}, out_root, idxs) 38 | 39 | return new_idxs 40 | end 41 | -- Split examples into consecutive slices examples[j .. j+bsize-1]; a trailing partial batch is dropped. 42 | function group2batches(examples, bsize) 43 | local batches = {} 44 | for j = 1, #examples-bsize+1, bsize do -- +1: the last full batch starts at #examples-bsize+1 45 | local batch = subrange(examples, j, j+bsize-1) 46 | table.insert(batches, batch) -- each batch is table of 2-tables 47 | end 48 | return batches 49 | end 50 | 51 | -- Load every subsample-th image file in folder (in sorted path order) and return the :float() images as a table. 52 | function get_examples(folder) 53 | -- first get the imgs inside this folder 54 | print(folder) 55 | local imgs = {} 56 | for img in paths.iterfiles(folder) do 57 | local imgpath = paths.concat(folder, img) -- local: do not leak a global 58 | imgs[#imgs+1] = imgpath 59 | end 60 | table.sort(imgs) -- consecutive 61 | 62 | -- then subsample these images 63 | local subsampled_imgs = {} 64 | for k,img_file in ipairs(imgs) do -- ipairs: pairs() order is unspecified and frame order matters 65 | if k % subsample == 0 then 66 | local img = image.load(img_file) 67 | subsampled_imgs[#subsampled_imgs+1] = img:float() 68 | end 69 | end 70 | return subsampled_imgs 71 | end 72 | 73 | -- turns n examples into n-1 example pairs 74 | -- return a table of size n-1 of tables of size 2 75 | function duplicate(examples) 76 | local tm1s = subrange(examples,1,#examples-1) 77 | local ts = subrange(examples,2,#examples) 78 | local g2 = {} 79 | for i = 1, #tm1s do 80 | g2[#g2+1] = {tm1s[i],ts[i]} -- t is second element 81 | end 82 | return g2 83 | end 84 | 85 | 86 | function save_batches(datasets, savefolder, idxs) 87 | local train, val, test = unpack(datasets) 88 | local data_table = {train=train, val=val, test=test} 89 | for dname,data in pairs(data_table) do 90 | local subfolder = paths.concat(savefolder,dname) 91 | if not paths.dirp(subfolder) then paths.mkdir(subfolder) end 92 | for _,b in pairs(data) do 93 | -- xlua.progress(idxs[dname], #data) 94 | local batchname = paths.concat(subfolder, 'batch'..idxs[dname]) 95 | print(dname..': '..batchname) 96 | -- torch.save(batchname, b) 97 | idxs[dname] = idxs[dname] + 1 98 | end 99 | end 100 | return idxs 101 | end 102 | 103 | 104 | --
main 105 | function main() 106 | local idxs = {train=1,val=1,test=1} -- train, val, test 107 | for setting in paths.iterdirs(raw_root) do 108 | print('setting '..setting) 109 | idxs = create_batches(paths.concat(raw_root,setting), bsize, idxs) 110 | 111 | -- 112 | -- for group in paths.iterdirs(paths.concat(raw_root,setting)) do 113 | -- print(idxs) 114 | -- local img_folder = paths.concat(raw_root,setting,group,ch,'data_resize') 115 | -- idxs = create_batches(img_folder, bsize, idxs) -- create train, val, test for this setting 116 | -- end 117 | 118 | end 119 | print(idxs) 120 | end 121 | 122 | main() 123 | 124 | 125 | -- to use: first call resize.py, then call this file 126 | -------------------------------------------------------------------------------- /action_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | dry_run = '--dry-run' in sys.argv 5 | local = '--local' in sys.argv 6 | detach = '--detach' in sys.argv 7 | 8 | dry_run = False 9 | local = False 10 | detach = True 11 | 12 | if not os.path.exists("slurm_logs"): 13 | os.makedirs("slurm_logs") 14 | 15 | if not os.path.exists("slurm_scripts"): 16 | os.makedirs("slurm_scripts") 17 | 18 | 19 | networks_prefix = "networks" 20 | 21 | base_networks = { 22 | } 23 | 24 | 25 | # Don't give it a save name - that gets generated for you 26 | # jobs = [ 27 | # { 28 | # "import": "onestep", 29 | # }, 30 | # 31 | # 32 | # ] 33 | 34 | jobs = [] 35 | 36 | # noise_options = [0.1] 37 | # sharpening_rate_options = [10] 38 | learning_rate_options = [20e-5] 39 | # motion_scale_options = [3] 40 | heads_options = [1,2,3] 41 | dataset_name_options = ['allactionsd4'] 42 | # L2_options = [1e-2, 1e-3, 1e-4] 43 | 44 | # for noise in noise_options: 45 | # for sharpening_rate in sharpening_rate_options: 46 | for learning_rate in learning_rate_options: 47 | for heads in heads_options: 48 | for dataset_name in dataset_name_options: 49 | # for motion_scale in 
motion_scale_options: 50 | job = { 51 | # "noise": noise, 52 | # "sharpening_rate": sharpening_rate, 53 | "learning_rate": learning_rate, 54 | "heads": heads, 55 | "dataset_name": dataset_name, 56 | # "motion_scale": motion_scale 57 | # "gpu": True, 58 | } 59 | jobs.append(job) 60 | 61 | 62 | if dry_run: 63 | print "NOT starting jobs:" 64 | else: 65 | print "Starting jobs:" 66 | 67 | for job in jobs: 68 | jobname = "actiond4" 69 | flagstring = "" 70 | for flag in job: 71 | if isinstance(job[flag], bool): 72 | if job[flag]: 73 | jobname = jobname + "_" + flag 74 | flagstring = flagstring + " --" + flag 75 | else: 76 | print "WARNING: Excluding 'False' flag " + flag 77 | elif flag == 'import': 78 | imported_network_name = job[flag] 79 | if imported_network_name in base_networks.keys(): 80 | network_location = base_networks[imported_network_name] 81 | jobname = jobname + "_" + flag + "_" + str(imported_network_name) 82 | flagstring = flagstring + " --" + flag + " " + str(network_location) 83 | else: 84 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 85 | flagstring = flagstring + " --" + flag + " " + networks_prefix + "/" + str(job[flag]) 86 | else: 87 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 88 | flagstring = flagstring + " --" + flag + " " + str(job[flag]) 89 | flagstring = flagstring + " --name " + jobname 90 | 91 | jobcommand = "th action_main.lua" + flagstring 92 | 93 | print(jobcommand) 94 | if local and not dry_run: 95 | if detach: 96 | os.system(jobcommand + ' 2> slurm_logs/' + jobname + '.err 1> slurm_logs/' + jobname + '.out &') 97 | else: 98 | os.system(jobcommand) 99 | 100 | else: 101 | with open('slurm_scripts/' + jobname + '.slurm', 'w') as slurmfile: 102 | slurmfile.write("#!/bin/bash\n") 103 | slurmfile.write("#SBATCH --job-name"+"=" + jobname + "\n") 104 | slurmfile.write("#SBATCH --output=slurm_logs/" + jobname + ".out\n") 105 | # slurmfile.write("#SBATCH --error=slurm_logs/" + jobname + ".err\n") 106 | slurmfile.write("#SBATCH 
-N 1\n") 107 | slurmfile.write("#SBATCH -c 2\n") 108 | slurmfile.write("#SBATCH -p gpu\n") 109 | slurmfile.write("#SBATCH --gres=gpu:1\n") 110 | slurmfile.write("#SBATCH --mem=5000\n") 111 | slurmfile.write("#SBATCH --time=6-23:00:00\n") 112 | slurmfile.write(jobcommand) 113 | 114 | if not dry_run: 115 | # if 'gpu' in job and job['gpu']: 116 | # os.system("sbatch -N 1 -c 2 --gres=gpu:1 -p gpu --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 117 | # else: 118 | # os.system("sbatch -N 1 -c 2 --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 119 | os.system("sbatch slurm_scripts/" + jobname + ".slurm &") 120 | -------------------------------------------------------------------------------- /MotionBCECriterion.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | This class increases the error function's sensitivity to elements of the 3 | target which change from frame to frame within a batch. 4 | 5 | It assumes that each batch takes the form of successive 6 | frames of video. 7 | 8 | After calculating the pointwise cross-entropy, it applies a multiplicative 9 | mask, causing points which have changed from frame to frame to have a much 10 | greater impact on the summed BCE. 11 | 12 | We treat the first and last frame in each batch as only having motion 13 | relative to the second and next-to-last frames, respectively. 14 | All other frames have increased sensitivity in regions which differ either 15 | from the previous frame or from the following frame. 
16 | --]] 17 | 18 | local MotionBCECriterion, parent = torch.class('nn.MotionBCECriterion', 'nn.Criterion') 19 | 20 | local eps = 1e-12 21 | 22 | function MotionBCECriterion:__init(motionScale) 23 | parent.__init(self) 24 | self.sizeAverage = true 25 | self.motionScale = motionScale 26 | self.mask = torch.Tensor() 27 | end 28 | 29 | function MotionBCECriterion:updateOutput(input, target) 30 | -- print("input") 31 | -- print(input:size()) 32 | -- print("target") 33 | -- print(target:size()) 34 | -- log(input) * target + log(1 - input) * (1 - target) 35 | 36 | self.term1 = self.term1 or input.new() 37 | self.term2 = self.term2 or input.new() 38 | self.term3 = self.term3 or input.new() 39 | 40 | self.term1:resizeAs(input) 41 | self.term2:resizeAs(input) 42 | self.term3:resizeAs(input) 43 | 44 | self.term1:fill(1):add(-1,target) 45 | self.term2:fill(1):add(-1,input):add(eps):log():cmul(self.term1) 46 | 47 | self.term3:copy(input):add(eps):log():cmul(target) 48 | self.term3:add(self.term2) 49 | 50 | if self.sizeAverage then 51 | self.term3:div(target:nElement()) 52 | end 53 | 54 | -- the error is Sum[(error at each point) * (importance of that point)] 55 | self:updateScalingMask(target) 56 | self.term3:cmul(self.mask) 57 | self.output = - self.term3:sum() 58 | 59 | return self.output 60 | end 61 | 62 | function MotionBCECriterion:updateGradInput(input, target) 63 | -- target / input - (1 - target) / (1 - input) 64 | 65 | self.term1 = self.term1 or input.new() 66 | self.term2 = self.term2 or input.new() 67 | self.term3 = self.term3 or input.new() 68 | 69 | self.term1:resizeAs(input) 70 | self.term2:resizeAs(input) 71 | self.term3:resizeAs(input) 72 | 73 | self.term1:fill(1):add(-1,target) 74 | self.term2:fill(1):add(-1,input) 75 | 76 | self.term2:add(eps) 77 | self.term1:cdiv(self.term2) 78 | 79 | self.term3:copy(input):add(eps) 80 | 81 | self.gradInput:resizeAs(input) 82 | self.gradInput:copy(target):cdiv(self.term3) 83 | 84 | self.gradInput:add(-1,self.term1) 85 | 86 | 
if self.sizeAverage then 87 | self.gradInput:div(target:nElement()) 88 | end 89 | 90 | self.gradInput:mul(-1) 91 | 92 | -- as stated above, 93 | -- the error is Sum[(error at each point) * (importance of that point)] 94 | -- so the gradient is grad(error at each point) * (importance of that point) 95 | self:updateScalingMask(target) 96 | self.gradInput:cmul(self.mask) 97 | 98 | return self.gradInput 99 | end 100 | -- Rebuild self.mask for target: 1 + motionScale wherever a point differs from the previous or next frame (dim 1), 1 elsewhere. 101 | function MotionBCECriterion:updateScalingMask(target) 102 | self.mask:resizeAs(target):fill(0) 103 | local nBatches = target:size(1) 104 | if nBatches > 1 then -- a lone frame has no neighbour; the {{2, n}} / {{1, n-1}} ranges would be invalid 105 | -- find all the places in each frame that changed since the frame before 106 | -- all of the "forward in time" changes 107 | self.mask[{{2, nBatches}}] = target[{{2, nBatches}}] - target[{{1, nBatches - 1}}] 108 | self.mask:abs() 109 | 110 | -- also find all the "backward in time" changes 111 | -- these are the same places in the frames; 112 | -- we want to highlight the importance of regions that **will** change too 113 | self.mask[{{1, nBatches - 1}}] = self.mask[{{1, nBatches - 1}}] 114 | + self.mask[{{2, nBatches}}]--:clone() 115 | end -- nBatches > 1 116 | self.mask:abs():sign() 117 | -- normalize the tmask to be 1 where things changed, 0 otherwise 118 | -- self.mask:apply(function(el) 119 | -- if el > 0 then 120 | -- return 1 121 | -- else 122 | -- return 0 123 | -- end 124 | -- end) 125 | 126 | -- scale by the importance we assign to motion 127 | self.mask:mul(self.motionScale) -- in place: no new tensor per call 128 | 129 | -- add 1 so we can just do self.mask * BCE 130 | self.mask:add(1) 131 | return self.mask 132 | end 133 | -------------------------------------------------------------------------------- /balls_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | dry_run = '--dry-run' in sys.argv 5 | local = '--local' in sys.argv 6 | detach = '--detach' in sys.argv 7 | 8 | dry_run = True 9 | local = False 10 | detach = True 11 | 12 | if not
os.path.exists("slurm_logs"): 13 | os.makedirs("slurm_logs") 14 | 15 | if not os.path.exists("slurm_scripts"): 16 | os.makedirs("slurm_scripts") 17 | 18 | 19 | networks_prefix = "networks" 20 | 21 | base_networks = { 22 | } 23 | 24 | 25 | # Don't give it a save name - that gets generated for you 26 | # jobs = [ 27 | # { 28 | # "import": "onestep", 29 | # }, 30 | # 31 | # 32 | # ] 33 | 34 | jobs = [] 35 | 36 | # noise_options = [0.1] 37 | # sharpening_rate_options = [10] 38 | learning_rate_options = [40e-5, 50e-5] 39 | heads_options = [1,2,3] 40 | numballs_options = [2,3] 41 | subsamp_options = [3] 42 | # L2_options = [1e-2, 1e-3, 1e-4] 43 | 44 | # for noise in noise_options: 45 | # for sharpening_rate in sharpening_rate_options: 46 | for learning_rate in learning_rate_options: 47 | for numballs in numballs_options: 48 | for heads in heads_options: 49 | for subsamp in subsamp_options: 50 | job = { 51 | # "noise": noise, 52 | # "sharpening_rate": sharpening_rate, 53 | "learning_rate": learning_rate, 54 | "heads": heads, 55 | "numballs": numballs, 56 | "subsample":subsamp 57 | # "dataset_name": dataset_name 58 | # "gpu": True, 59 | } 60 | jobs.append(job) 61 | 62 | 63 | if dry_run: 64 | print "NOT starting jobs:" 65 | else: 66 | print "Starting jobs:" 67 | 68 | for job in jobs: 69 | jobname = "balls" 70 | flagstring = "" 71 | for flag in job: 72 | if isinstance(job[flag], bool): 73 | if job[flag]: 74 | jobname = jobname + "_" + flag 75 | flagstring = flagstring + " --" + flag 76 | else: 77 | print "WARNING: Excluding 'False' flag " + flag 78 | elif flag == 'import': 79 | imported_network_name = job[flag] 80 | if imported_network_name in base_networks.keys(): 81 | network_location = base_networks[imported_network_name] 82 | jobname = jobname + "_" + flag + "_" + str(imported_network_name) 83 | flagstring = flagstring + " --" + flag + " " + str(network_location) 84 | else: 85 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 86 | flagstring = flagstring + " --" + 
flag + " " + networks_prefix + "/" + str(job[flag]) 87 | else: 88 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 89 | flagstring = flagstring + " --" + flag + " " + str(job[flag]) 90 | flagstring = flagstring + " --name " + jobname 91 | 92 | jobcommand = "th balls_main.lua" + flagstring 93 | 94 | print(jobcommand) 95 | if local and not dry_run: 96 | if detach: 97 | os.system(jobcommand + ' 2> slurm_logs/' + jobname + '.err 1> slurm_logs/' + jobname + '.out &') 98 | else: 99 | os.system(jobcommand) 100 | 101 | else: 102 | with open('slurm_scripts/' + jobname + '.slurm', 'w') as slurmfile: 103 | slurmfile.write("#!/bin/bash\n") 104 | slurmfile.write("#SBATCH --job-name"+"=" + jobname + "\n") 105 | slurmfile.write("#SBATCH --output=slurm_logs/" + jobname + ".out\n") 106 | # slurmfile.write("#SBATCH --error=slurm_logs/" + jobname + ".err\n") 107 | slurmfile.write("#SBATCH -N 1\n") 108 | slurmfile.write("#SBATCH -c 2\n") 109 | slurmfile.write("#SBATCH -p gpu\n") 110 | slurmfile.write("#SBATCH --gres=gpu:1\n") 111 | slurmfile.write("#SBATCH --mem=30000\n") 112 | slurmfile.write("#SBATCH --time=6-23:00:00\n") 113 | slurmfile.write("#SBATCH -x node027\n") 114 | slurmfile.write(jobcommand) 115 | 116 | if not dry_run: 117 | # if 'gpu' in job and job['gpu']: 118 | # os.system("sbatch -N 1 -c 2 --gres=gpu:1 -p gpu --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 119 | # else: 120 | # os.system("sbatch -N 1 -c 2 --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 121 | os.system("sbatch slurm_scripts/" + jobname + ".slurm &") 122 | -------------------------------------------------------------------------------- /render_atari_examples.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cutorch' 3 | require 'cunn' 4 | require 'paths' 5 | require 'lfs' 6 | 7 | vis = require 'vis' 8 | require 'AtariEncoder' 9 | require 'AtariDecoder' 10 | local data_loaders = require 
'data_loaders' 11 | 12 | name = arg[1] 13 | -- dataset_name = arg[2] or name 14 | networks = {} 15 | while true do 16 | local line = io.read() 17 | if line == nil then break end 18 | 19 | -- strip whitespace 20 | line = string.gsub(line, "%s+", "") 21 | 22 | table.insert(networks, line) 23 | end 24 | 25 | -- opt = { 26 | -- datasetdir = '/om/user/wwhitney/deep-game-engine', 27 | -- dataset_name = dataset_name, 28 | -- gpu = true, 29 | -- } 30 | 31 | base_directory = "/om/user/wwhitney/unsupervised-dcign/networks" 32 | 33 | local jobname = name ..'_'.. os.date("%b_%d_%H_%M") 34 | local output_path = 'reports/renderings/'..jobname 35 | os.execute('mkdir -p '..output_path) 36 | 37 | 38 | function getLastSnapshot(network_name) 39 | local res_file = io.popen("ls -t "..paths.concat(base_directory, network_name).." | grep -i epoch | head -n 1") 40 | local status, result = pcall(function() return res_file:read():match( "^%s*(.-)%s*$" ) end) 41 | -- print(status, result) 42 | res_file:close() 43 | if not status then 44 | return false 45 | else 46 | return result 47 | end 48 | end 49 | 50 | for _, network in ipairs(networks) do 51 | collectgarbage() 52 | 53 | print('') 54 | print(network) 55 | local snapshot_name = getLastSnapshot(network) 56 | if snapshot_name then 57 | local checkpoint = torch.load(paths.concat(base_directory, network, snapshot_name)) 58 | opt = checkpoint.opt 59 | local model = checkpoint.model 60 | local scheduler_iteration = torch.Tensor{checkpoint.step} 61 | model:evaluate() 62 | 63 | local encoder = model.modules[1] 64 | local sharpener = encoder:findModules('nn.ScheduledWeightSharpener')[1] 65 | sharpener.iteration_container = scheduler_iteration 66 | print("Current sharpening: ", sharpener:getP()) 67 | 68 | local weight_predictor = encoder:findModules('nn.Normalize')[1] 69 | local previous_embedding = encoder:findModules('nn.Linear')[1] 70 | local current_embedding = encoder:findModules('nn.Linear')[2] 71 | local decoder = model.modules[2] 72 | 73 | 
for i = 339, 343 do 74 | local images = {} 75 | 76 | -- fetch a batch 77 | local input = data_loaders.load_atari_batch(i, 'test') 78 | local output = model:forward(input):clone() 79 | local embedding_from_previous = previous_embedding.output:clone() 80 | local embedding_from_current = current_embedding.output:clone() 81 | 82 | local reconstruction_from_previous = decoder:forward(embedding_from_previous):clone() 83 | local reconstruction_from_current = decoder:forward(embedding_from_current):clone() 84 | 85 | local weight_norms = torch.zeros(output:size(1)) 86 | local weight_norms = torch.zeros(output:size(1)) 87 | for input_index = 1, output:size(1) do 88 | weights = weight_predictor.output[input_index]:clone() 89 | weight_norms[input_index] = weights:norm() 90 | end 91 | print("Mean independence of weights: ", weight_norms:mean()) 92 | 93 | for input_index = 1, math.min(30, output:size(1)), 3 do 94 | local weights = weight_predictor.output[input_index]:clone() 95 | local max_weight, varying_index = weights:max(1) 96 | -- print("Varying index: " .. vis.simplestr(varying_index), "Weight: " .. vis.simplestr(max_weight)) 97 | 98 | -- local embedding_change = embedding_from_current[input_index] - embedding_from_previous[input_index] 99 | -- local normalized_embedding_change = embedding_change / embedding_change:norm(1) 100 | -- print("Independence of embedding change: ", normalized_embedding_change:norm()) 101 | -- print("Distance between timesteps: ", embedding_change:norm()) 102 | 103 | 104 | local image_row = {} 105 | table.insert(image_row, input[1][input_index]:float()) 106 | table.insert(image_row, input[2][input_index]:float()) 107 | table.insert(image_row, reconstruction_from_previous[input_index]:float()) 108 | table.insert(image_row, reconstruction_from_current[input_index]:float()) 109 | table.insert(image_row, output[input_index]:float()) 110 | table.insert(images, image_row) 111 | end 112 | vis.save_image_grid(paths.concat(output_path, network .. 
'_batch_'..i..'.png'), images) 113 | 114 | collectgarbage() 115 | end 116 | end 117 | end 118 | 119 | 120 | print("done") 121 | -------------------------------------------------------------------------------- /render_downsampled_examples.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cutorch' 3 | require 'cunn' 4 | require 'paths' 5 | require 'lfs' 6 | 7 | vis = require 'vis' 8 | require 'DownsampledEncoder' 9 | require 'DownsampledDecoder' 10 | local data_loaders = require 'data_loaders' 11 | 12 | require 'Scale' 13 | 14 | name = arg[1] 15 | -- dataset_name = arg[2] or name 16 | networks = {} 17 | while true do 18 | local line = io.read() 19 | if line == nil then break end 20 | 21 | -- strip whitespace 22 | line = string.gsub(line, "%s+", "") 23 | 24 | table.insert(networks, line) 25 | end 26 | 27 | -- opt = { 28 | -- datasetdir = '/om/user/wwhitney/deep-game-engine', 29 | -- dataset_name = dataset_name, 30 | -- gpu = true, 31 | -- } 32 | 33 | base_directory = "/om/user/wwhitney/unsupervised-dcign/networks" 34 | 35 | local jobname = name ..'_'.. os.date("%b_%d_%H_%M") 36 | local output_path = 'reports/renderings/'..jobname 37 | os.execute('mkdir -p '..output_path) 38 | 39 | scale = nn.Scale(84, 84, true) 40 | 41 | function getLastSnapshot(network_name) 42 | local res_file = io.popen("ls -t "..paths.concat(base_directory, network_name).." 
| grep -i epoch | head -n 1") 43 | local status, result = pcall(function() return res_file:read():match( "^%s*(.-)%s*$" ) end) 44 | -- print(status, result) 45 | res_file:close() 46 | if not status then 47 | return false 48 | else 49 | return result 50 | end 51 | end 52 | 53 | for _, network in ipairs(networks) do 54 | collectgarbage() 55 | 56 | print('') 57 | print(network) 58 | local snapshot_name = getLastSnapshot(network) 59 | if snapshot_name then 60 | local checkpoint = torch.load(paths.concat(base_directory, network, snapshot_name)) 61 | opt = checkpoint.opt 62 | local model = checkpoint.model 63 | local scheduler_iteration = torch.Tensor{checkpoint.step} 64 | model:evaluate() 65 | 66 | local encoder = model.modules[1] 67 | local sharpener = encoder:findModules('nn.ScheduledWeightSharpener')[1] 68 | sharpener.iteration_container = scheduler_iteration 69 | print("Current sharpening: ", sharpener:getP()) 70 | 71 | local weight_predictor = encoder:findModules('nn.Normalize')[1] 72 | local previous_embedding = encoder:findModules('nn.Linear')[1] 73 | local current_embedding = encoder:findModules('nn.Linear')[2] 74 | local decoder = model.modules[2] 75 | 76 | for i = 339, 343 do 77 | local images = {} 78 | 79 | -- fetch a batch 80 | local input = data_loaders.load_atari_batch(i, 'test') 81 | input = { 82 | scale:forward(input[1]), 83 | scale:forward(input[2]), 84 | } 85 | local output = model:forward(input):clone() 86 | local embedding_from_previous = previous_embedding.output:clone() 87 | local embedding_from_current = current_embedding.output:clone() 88 | 89 | local reconstruction_from_previous = decoder:forward(embedding_from_previous):clone() 90 | local reconstruction_from_current = decoder:forward(embedding_from_current):clone() 91 | 92 | local weight_norms = torch.zeros(output:size(1)) 93 | local weight_norms = torch.zeros(output:size(1)) 94 | for input_index = 1, output:size(1) do 95 | weights = weight_predictor.output[input_index]:clone() 96 | 
weight_norms[input_index] = weights:norm() 97 | end 98 | print("Mean independence of weights: ", weight_norms:mean()) 99 | 100 | for input_index = 1, math.min(30, output:size(1)), 3 do 101 | local weights = weight_predictor.output[input_index]:clone() 102 | local max_weight, varying_index = weights:max(1) 103 | -- print("Varying index: " .. vis.simplestr(varying_index), "Weight: " .. vis.simplestr(max_weight)) 104 | 105 | -- local embedding_change = embedding_from_current[input_index] - embedding_from_previous[input_index] 106 | -- local normalized_embedding_change = embedding_change / embedding_change:norm(1) 107 | -- print("Independence of embedding change: ", normalized_embedding_change:norm()) 108 | -- print("Distance between timesteps: ", embedding_change:norm()) 109 | 110 | 111 | local image_row = {} 112 | table.insert(image_row, input[1][input_index]:float()) 113 | table.insert(image_row, input[2][input_index]:float()) 114 | table.insert(image_row, reconstruction_from_previous[input_index]:float()) 115 | table.insert(image_row, reconstruction_from_current[input_index]:float()) 116 | table.insert(image_row, output[input_index]:float()) 117 | table.insert(images, image_row) 118 | end 119 | vis.save_image_grid(paths.concat(output_path, network .. 
'_batch_'..i..'.png'), images) 120 | 121 | collectgarbage() 122 | end 123 | end 124 | end 125 | 126 | 127 | print("done") 128 | -------------------------------------------------------------------------------- /downsampled_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | dry_run = '--dry-run' in sys.argv 5 | local = '--local' in sys.argv 6 | detach = '--detach' in sys.argv 7 | 8 | if not os.path.exists("slurm_logs"): 9 | os.makedirs("slurm_logs") 10 | 11 | if not os.path.exists("slurm_scripts"): 12 | os.makedirs("slurm_scripts") 13 | 14 | 15 | networks_prefix = "networks" 16 | 17 | base_networks = { 18 | } 19 | 20 | 21 | # Don't give it a save name - that gets generated for you 22 | jobs = [ 23 | # { 24 | # "noise": 0.1, 25 | # "sharpening_rate": 10, 26 | # "learning_rate": 2e-4, 27 | # "heads": 3, 28 | # "motion_scale": 3, 29 | # "frame_interval": 1, 30 | # "dataset_name": "space_invaders", 31 | # "model": "autoencoder", 32 | 33 | # "gpu": True, 34 | # }, 35 | { 36 | "noise": 0.1, 37 | "sharpening_rate": 10, 38 | "learning_rate": 2e-4, 39 | "heads": 3, 40 | "motion_scale": 3, 41 | "frame_interval": 1, 42 | "dataset_name": "breakout", 43 | "model": "disentangled", 44 | 45 | "gpu": True, 46 | }, 47 | { 48 | "noise": 0.1, 49 | "sharpening_rate": 10, 50 | "learning_rate": 2e-4, 51 | "heads": 3, 52 | "motion_scale": 3, 53 | "frame_interval": 1, 54 | "dataset_name": "breakout", 55 | "model": "vanilla", 56 | 57 | "gpu": True, 58 | }, 59 | 60 | 61 | ] 62 | 63 | # jobs = [] 64 | 65 | # noise_options = [0.1] 66 | # sharpening_rate_options = [10] 67 | # learning_rate_options = [2e-4] 68 | # heads_options = [3] 69 | # motion_scale_options = [3] 70 | # frame_interval_options = [1] 71 | # dataset_name_options = ["space_invaders"] 72 | # model_options = ["disentangled", "autoencoder"] 73 | # # L2_options = [1e-2, 1e-3, 1e-4] 74 | 75 | # for noise in noise_options: 76 | # for sharpening_rate in 
sharpening_rate_options: 77 | # for learning_rate in learning_rate_options: 78 | # for heads in heads_options: 79 | # for motion_scale in motion_scale_options: 80 | # for frame_interval in frame_interval_options: 81 | # for dataset_name in dataset_name_options: 82 | # for model in model_options: 83 | # job = { 84 | # "noise": noise, 85 | # "sharpening_rate": sharpening_rate, 86 | # "learning_rate": learning_rate, 87 | # "heads": heads, 88 | # "motion_scale": motion_scale, 89 | # "frame_interval": frame_interval, 90 | # "dataset_name": dataset_name, 91 | # "model": model, 92 | 93 | # "gpu": True, 94 | # } 95 | # jobs.append(job) 96 | 97 | 98 | if dry_run: 99 | print "NOT starting jobs:" 100 | else: 101 | print "Starting jobs:" 102 | 103 | for job in jobs: 104 | jobname = "down" 105 | flagstring = "" 106 | for flag in job: 107 | if isinstance(job[flag], bool): 108 | if job[flag]: 109 | jobname = jobname + "_" + flag 110 | flagstring = flagstring + " --" + flag 111 | else: 112 | print "WARNING: Excluding 'False' flag " + flag 113 | elif flag == 'import': 114 | imported_network_name = job[flag] 115 | if imported_network_name in base_networks.keys(): 116 | network_location = base_networks[imported_network_name] 117 | jobname = jobname + "_" + flag + "_" + str(imported_network_name) 118 | flagstring = flagstring + " --" + flag + " " + str(network_location) 119 | else: 120 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 121 | flagstring = flagstring + " --" + flag + " " + networks_prefix + "/" + str(job[flag]) 122 | else: 123 | jobname = jobname + "_" + flag + "_" + str(job[flag]) 124 | flagstring = flagstring + " --" + flag + " " + str(job[flag]) 125 | flagstring = flagstring + " --name " + jobname 126 | 127 | jobcommand = "th downsampled_main.lua" + flagstring 128 | 129 | print(jobcommand) 130 | if local and not dry_run: 131 | if detach: 132 | os.system(jobcommand + ' 2> slurm_logs/' + jobname + '.err 1> slurm_logs/' + jobname + '.out &') 133 | else: 134 | 
os.system(jobcommand) 135 | 136 | else: 137 | with open('slurm_scripts/' + jobname + '.slurm', 'w') as slurmfile: 138 | slurmfile.write("#!/bin/bash\n") 139 | slurmfile.write("#SBATCH --job-name"+"=" + jobname + "\n") 140 | slurmfile.write("#SBATCH --output=slurm_logs/" + jobname + ".out\n") 141 | slurmfile.write("#SBATCH --error=slurm_logs/" + jobname + ".err\n") 142 | # slurmfile.write("luarocks install cutorch\n") 143 | # slurmfile.write("luarocks install cunn\n") 144 | slurmfile.write(jobcommand) 145 | 146 | if not dry_run: 147 | if 'gpu' in job and job['gpu']: 148 | os.system("sbatch -N 1 -c 2 --gres=gpu:titan-x:1 --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 149 | else: 150 | os.system("sbatch -N 1 -c 2 --mem=8000 --time=6-23:00:00 slurm_scripts/" + jobname + ".slurm &") 151 | -------------------------------------------------------------------------------- /data_loaders.lua: -------------------------------------------------------------------------------- 1 | require 'image' 2 | 3 | local data_loaders = {} 4 | 5 | opt = {} 6 | 7 | function data_loaders.load_mv_batch(id, dataset_name, mode) 8 | local data = torch.load(opt.datasetdir .. '/th_' .. dataset_name .. '/' .. mode .. '/batch' .. 
id) 9 | 10 | local input1s = torch.zeros(19, 1, 150, 150) 11 | local input2s = torch.zeros(19, 1, 150, 150) 12 | 13 | if opt.gpu then 14 | data = data:cuda() 15 | input1s = input1s:cuda() 16 | input2s = input2s:cuda() 17 | end 18 | 19 | for i = 1, 19 do 20 | input1s[i] = data[i] 21 | input2s[i] = data[i + 1] 22 | end 23 | return {input1s, input2s} 24 | end 25 | 26 | function data_loaders.load_random_mv_batch(mode) 27 | local variation_type = math.random(3) 28 | local variation_name = "" 29 | if variation_type == 1 then 30 | variation_name = "AZ_VARIED" 31 | elseif variation_type == 2 then 32 | variation_name = "EL_VARIED" 33 | elseif variation_type == 3 then 34 | variation_name = "LIGHT_AZ_VARIED" 35 | end 36 | 37 | local id, mode_name 38 | if mode == 'train' then 39 | mode_name = 'FT_training' 40 | id = math.random(opt.num_train_batches_per_type) 41 | elseif mode == 'test' then 42 | mode_name = 'FT_test' 43 | id = math.random(opt.num_test_batches_per_type) 44 | end 45 | return data_loaders.load_mv_batch(id, variation_name, mode_name), variation_type 46 | end 47 | 48 | 49 | 50 | function data_loaders.load_atari_batch(id, mode) 51 | local data = torch.load(opt.datasetdir .. '/dataset_DQN_' .. opt.dataset_name .. '_trained/' .. mode .. '/images_batch_' .. 
id) 52 | 53 | local frame_interval = opt.frame_interval or 1 54 | 55 | local num_inputs = data:size(1) - frame_interval 56 | 57 | local input1s = torch.zeros(num_inputs, 3, 210, 160) 58 | local input2s = torch.zeros(num_inputs, 3, 210, 160) 59 | 60 | if opt.gpu then 61 | data = data:cuda() 62 | input1s = input1s:cuda() 63 | input2s = input2s:cuda() 64 | end 65 | 66 | for i = 1, num_inputs do 67 | input1s[i] = data[i] 68 | input2s[i] = data[i + frame_interval] 69 | end 70 | return {input1s, input2s} 71 | end 72 | 73 | function data_loaders.load_random_atari_batch(mode) 74 | local id 75 | if mode == 'train' then 76 | id = math.random(opt.num_train_batches) 77 | elseif mode == 'test' then 78 | id = math.random(opt.num_test_batches) 79 | end 80 | return data_loaders.load_atari_batch(id, mode) 81 | end 82 | 83 | 84 | function data_loaders.load_action_batch(id, mode) 85 | local data = torch.load(opt.datasetdir .. '/' .. opt.dataset_name ..'/'.. mode .. '/batch' .. id) 86 | 87 | local input1s = torch.zeros(29, 1, 120, 160) 88 | local input2s = torch.zeros(29, 1, 120, 160) 89 | 90 | if opt.gpu then 91 | data = data:cuda() 92 | input1s = input1s:cuda() 93 | input2s = input2s:cuda() 94 | end 95 | 96 | for i = 1, 29 do 97 | input1s[i] = data[i] 98 | input2s[i] = data[i + 1] 99 | end 100 | return {input1s, input2s} 101 | end 102 | 103 | 104 | function data_loaders.load_random_action_batch(mode) 105 | local id 106 | if mode == 'train' then 107 | id = math.random(opt.num_train_batches) 108 | elseif mode == 'test' then 109 | id = math.random(opt.num_test_batches) 110 | end 111 | return data_loaders.load_action_batch(id, mode) 112 | end 113 | 114 | 115 | 116 | function data_loaders.load_balls_batch(id, mode) 117 | local adjusted_id = id-1 -- adjust from python indexing 118 | local batch_folder 119 | if opt.subsample then 120 | batch_folder = opt.datasetdir .. '/' .. mode .. 
'_nb='..opt.numballs..'_bsize=30_imsize=150_subsamp='..opt.subsample..'/batch' ..adjusted_id 121 | else 122 | batch_folder = opt.datasetdir .. '/' .. mode .. '_nb='..opt.numballs..'_bsize=30_imsize=150/batch' ..adjusted_id 123 | end 124 | 125 | -- now open and sort the images. images go from 0 to 29 126 | local data = torch.zeros(30,1,150,150) 127 | for i = 0,29 do -- because of python indexing 128 | local img = image.load(batch_folder ..'/' .. i ..'.png') 129 | -- img = img/255 -- 130 | data[i+1] = img 131 | end 132 | 133 | local input1s = torch.zeros(29, 1, 150, 150) 134 | local input2s = torch.zeros(29, 1, 150, 150) 135 | 136 | if opt.gpu then 137 | data = data:cuda() 138 | input1s = input1s:cuda() 139 | input2s = input2s:cuda() 140 | end 141 | 142 | for i = 1, 29 do 143 | input1s[i] = data[i] 144 | input2s[i] = data[i + 1] 145 | end 146 | return {input1s, input2s} 147 | end 148 | 149 | 150 | function data_loaders.load_random_balls_batch(mode) 151 | local id 152 | if mode == 'train' then 153 | id = math.random(opt.num_train_batches) 154 | elseif mode == 'test' then 155 | id = math.random(opt.num_test_batches) 156 | end 157 | return data_loaders.load_balls_batch(id, mode) -- becuase I saved it as python indexing 158 | end 159 | 160 | 161 | -- function data_loaders.load_kitti_batch(id, mode) 162 | -- local data = torch.load(opt.datasetdir .. '/' .. mode .. '/batch' .. 
id) 163 | -- -- data = data:reshape(data:size(1),1,data:size(2),data:size(3)) -- one channel 164 | -- print(data:size()) 165 | -- 166 | -- local input1s = torch.zeros(29, 1, 150, 150) 167 | -- local input2s = torch.zeros(29, 1, 150, 150) 168 | -- 169 | -- if opt.gpu then 170 | -- data = data:cuda() 171 | -- input1s = input1s:cuda() 172 | -- input2s = input2s:cuda() 173 | -- end 174 | -- 175 | -- for i = 1, 29 do 176 | -- input1s[i] = data[i] 177 | -- input2s[i] = data[i + 1] 178 | -- end 179 | -- return {input1s, input2s} 180 | -- end 181 | -- 182 | -- function data_loaders.load_random_kitti_batch(mode) 183 | -- local id 184 | -- if mode == 'train' then 185 | -- id = math.random(opt.num_train_batches) 186 | -- elseif mode == 'test' then 187 | -- id = math.random(opt.num_train_batches) 188 | -- end 189 | -- return data_loaders.load_kitti_batch(id, mode) 190 | -- end 191 | 192 | return data_loaders 193 | -------------------------------------------------------------------------------- /vis.lua: -------------------------------------------------------------------------------- 1 | require 'image' 2 | 3 | local vis = {} 4 | 5 | function vis.save_image_grid(filepath, images) 6 | if images ~= nil and images[1] ~= nil then 7 | colors = images[1][1]:size(1) 8 | image_width = images[1][1]:size(3) 9 | image_height = images[1][1]:size(2) 10 | -- print(image_width, image_height) 11 | -- print(images[1][1]:size()) 12 | padding = 5 13 | images_across = #images[1] 14 | images_down = #images 15 | -- print(images_down, images_across) 16 | 17 | image_output = torch.zeros( 18 | colors, 19 | image_height * images_down + (images_down - 1) * padding, 20 | image_width * images_across + (images_across - 1) * padding) 21 | -- print(image_output:size()) 22 | for i, image_row in ipairs(images) do 23 | for j, image in ipairs(image_row) do 24 | -- print(image:sum()) 25 | y_index = j - 1 26 | y_location = y_index * image_width + y_index * padding 27 | x_index = i - 1 28 | x_location = (x_index) * 
image_height + x_index * padding 29 | 30 | -- print({{x_location + 1, x_location + image_height}, 31 | -- {y_location + 1, y_location + image_width}}) 32 | 33 | image_output[{{}, 34 | {x_location + 1, x_location + image_height}, 35 | {y_location + 1, y_location + image_width}}] = image 36 | end 37 | end 38 | -- image_output = image_output:reshape(colors, image_output:size()[1], image_output:size()[2]) 39 | image.save(filepath, image_output) 40 | else 41 | error("Invalid images:", images) 42 | end 43 | end 44 | 45 | vis.colors = { 46 | HEADER = '\27[95m', 47 | OKBLUE = '\27[94m', 48 | OKGREEN = '\27[92m', 49 | WARNING = '\27[93m', 50 | FAIL = '\27[91m', 51 | ENDC = '\27[0m', 52 | BOLD = '\27[1m', 53 | UNDERLINE = '\27[4m', 54 | RESET = '\27[0m' 55 | } 56 | 57 | vis.decimalPlaces = 4 58 | 59 | function vis.lines(str) 60 | local t = {} 61 | local function helper(line) table.insert(t, line) return "" end 62 | helper((str:gsub("(.-)\r?\n", helper))) 63 | return t 64 | end 65 | 66 | function vis.flatten(tensor) 67 | return tensor:reshape(1, tensor:nElement()) 68 | end 69 | 70 | function vis.round(tensor, places) 71 | places = places or vis.decimalPlaces 72 | local tensorClone = tensor:clone() 73 | local offset = 0 74 | if tensor:sum() ~= 0 then 75 | offset = - math.floor(math.log10(torch.abs(tensorClone):mean())) + (places - 1) 76 | end 77 | 78 | tensorClone = tensorClone * (10 ^ offset) 79 | tensorClone:round() 80 | tensorClone = tensorClone / (10 ^ offset) 81 | 82 | if tostring(tensorClone[1]) == tostring(0/0) then 83 | print(tensor) 84 | print(math.floor(math.log10(torch.abs(tensorClone):mean()))) 85 | print(offset) 86 | error("got nan") 87 | end 88 | 89 | return tensorClone 90 | end 91 | 92 | function vis.simplestr(input) 93 | if type(input) == 'number' then 94 | local str = string.format("%." .. vis.decimalPlaces .. 
"f", input) 95 | return str 96 | end 97 | -- local rounded = vis.round(input) 98 | -- -- local rounded = input:clone() 99 | -- 100 | -- local strTable = vis.lines(tostring(vis.flatten(rounded))) 101 | -- table.remove(strTable, #strTable) 102 | -- table.remove(strTable, #strTable) 103 | -- 104 | -- local str = "" 105 | -- for i, line in ipairs(strTable) do 106 | -- str = str..line 107 | -- end 108 | -- return str 109 | -- print("input", input) 110 | -- print("tensor1", input[1]) 111 | -- print(input) 112 | local str = string.format("%." .. vis.decimalPlaces .. "f", input[1]) 113 | for i = 2, input:size(1) do 114 | str = str .. string.format(" %." .. vis.decimalPlaces .. "f", input[i]) 115 | end 116 | return str 117 | end 118 | 119 | function vis.prettySingleError(number) 120 | local str = tostring(number) 121 | if math.abs(number) < 1e-10 then 122 | return '0.0000' 123 | else 124 | return vis.colors.FAIL..str..vis.colors.RESET 125 | end 126 | end 127 | 128 | function vis.prettyError(err) 129 | if type(err) == 'number' then 130 | return vis.prettySingleError(err) 131 | elseif type(err) == 'table' then 132 | local str = '' 133 | for _, val in ipairs(err) do 134 | str = str .. ' ' .. vis.prettySingleError(val) 135 | end 136 | return str 137 | elseif err.size then -- assume tensor 138 | local rounded = vis.round(err) 139 | if rounded:nDimension() ~= 1 then 140 | error("Only able to pretty-print 1D tensors.") 141 | else 142 | local str = '' 143 | for i = 1, rounded:size(1) do 144 | str = str .. ' ' .. vis.prettySingleError(rounded[i]) 145 | end 146 | return str 147 | end 148 | else 149 | error("Not sure what to do with this object.") 150 | end 151 | end 152 | 153 | function vis.diff(a, b) 154 | local str 155 | if type(a) == 'number' and type(b) == 'number' then 156 | str = vis.prettySingleError(a - b) 157 | elseif type(a) == 'table' and type(b) == 'table' then 158 | str = '' 159 | for i, _ in ipairs(a) do 160 | str = str .. ' ' .. 
vis.prettySingleError(a[i] - b[i]) 161 | end 162 | elseif a.size and b.size then -- assume tensor 163 | local rounded = vis.round(a - b) 164 | if rounded:nDimension() ~= 1 then 165 | error("Only able to pretty-print 1D tensors.") 166 | else 167 | str = '' 168 | for i = 1, rounded:size(1) do 169 | str = str .. ' ' .. vis.prettySingleError(rounded[i]) 170 | end 171 | end 172 | else 173 | error("Not sure what to do with this object.") 174 | end 175 | print(str) 176 | end 177 | 178 | function vis.hist(a) 179 | tensor = a:clone() 180 | tensor = tensor / tensor:clone():abs():max() 181 | -- print(tensor:min()) 182 | tensor = tensor + (-tensor:min()) 183 | tensor:mul(10) 184 | local str = vis.simplestr(tensor) 185 | -- print(str) 186 | os.execute('spark ' .. str) 187 | end 188 | 189 | return vis 190 | -------------------------------------------------------------------------------- /val_vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | # from matplotlib import pyplot as plt 4 | import seaborn 5 | 6 | import sys 7 | import os 8 | import copy 9 | import pprint 10 | from matplotlib import gridspec 11 | 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser(description='Plot dem results.') 15 | parser.add_argument('--name', default='default') 16 | # parser.add_argument('--keep_losers', default=False) 17 | parser.add_argument('--hide_losers', action='store_true', default=False) 18 | parser.add_argument('--keep_young', action='store_true', default=False) 19 | parser.add_argument('--loser_threshold', default=1) 20 | args = parser.parse_args() 21 | 22 | output_dir = "reports/" + args.name 23 | 24 | pp = pprint.PrettyPrinter(indent=4) 25 | 26 | def mean(l): 27 | return sum(l) / float(len(l)) 28 | 29 | networks = {} 30 | for name in sys.stdin: 31 | network_name = name.strip() 32 | # print(network_name) 33 | opt_path = "networks/" + network_name + "/opt.txt" 34 | loss_path = "networks/" 
+ network_name + "/val_loss.txt" 35 | # print(os.path.isfile(opt_path)) 36 | # print(os.path.isfile(loss_path)) 37 | try: 38 | if os.path.isfile(opt_path) and os.path.isfile(loss_path): 39 | network_data = {} 40 | with open(opt_path) as opt_file: 41 | options = {} 42 | for line in opt_file: 43 | k, v = line.split(": ") 44 | options[k] = v.strip() 45 | network_data['options'] = options 46 | network_data['options']['name'] = network_name 47 | 48 | with open(loss_path) as loss_file: 49 | losses = [] 50 | for line in loss_file: 51 | losses.append(float(line)) 52 | network_data['losses'] = losses 53 | 54 | networks[network_name] = network_data 55 | 56 | except IOError as e: 57 | pass 58 | 59 | 60 | if not args.keep_young: 61 | network_ages = [] 62 | for network_name in networks: 63 | network = networks[network_name] 64 | network_ages.append(len(network['losses'])) 65 | 66 | mean_network_age = mean(network_ages) 67 | 68 | new_networks = {} 69 | for network_name in networks: 70 | network = networks[network_name] 71 | if len(network['losses']) < (3 * mean_network_age / 4.): 72 | print("Network is too young. Excluding: " + network_name) 73 | else: 74 | new_networks[network_name] = network 75 | 76 | networks = new_networks 77 | 78 | if args.hide_losers: 79 | new_networks = {} 80 | for network_name in networks: 81 | network = networks[network_name] 82 | if network['losses'][-1] > args.loser_threshold: 83 | print("Network's loss is too high: " + str(network['losses'][-1]) + ". 
Excluding: " + network_name) 84 | else: 85 | new_networks[network_name] = network 86 | 87 | networks = new_networks 88 | 89 | same_options = copy.deepcopy(networks[networks.keys()[0]]['options']) 90 | diff_options = [] 91 | for network_name in networks: 92 | network = networks[network_name] 93 | options = network['options'] 94 | for option in options: 95 | if option not in diff_options: 96 | if option not in same_options: 97 | diff_options.append(option) 98 | else: 99 | if options[option] != same_options[option]: 100 | diff_options.append(option) 101 | same_options.pop(option, None) 102 | 103 | print(diff_options) 104 | # print(same_options) 105 | 106 | # don't separate them by name 107 | # diff_options.remove("name") 108 | 109 | per_option_loss_lists = {} 110 | 111 | for option in diff_options: 112 | option_loss_lists = {} 113 | for network_name in networks: 114 | network = networks[network_name] 115 | 116 | option_value = 'none' 117 | if option in network['options'] and network['options'][option] != '': 118 | option_value = network['options'][option] 119 | 120 | if option_value not in option_loss_lists: 121 | option_loss_lists[option_value] = [] 122 | 123 | option_loss_lists[option_value].append(network['losses']) 124 | 125 | per_option_loss_lists[option] = option_loss_lists 126 | 127 | 128 | # per_option_mean_losses = {} 129 | # for option in per_option_loss_lists: 130 | # per_value_mean_losses = {} 131 | # for option_value in per_option_loss_lists[option]: 132 | # loss_lists = per_option_loss_lists[option][option_value] 133 | # 134 | # last_losses = [losses[-1] for losses in loss_lists] 135 | # mean_loss = mean(last_losses) 136 | # per_value_mean_losses[option_value] = mean_loss 137 | # 138 | # per_option_mean_losses[option] = per_value_mean_losses 139 | 140 | per_option_last_losses = {} 141 | for option in per_option_loss_lists: 142 | per_value_last_losses = {} 143 | for option_value in per_option_loss_lists[option]: 144 | loss_lists = 
per_option_loss_lists[option][option_value] 145 | per_value_last_losses[option_value] = [losses[-1] for losses in loss_lists] 146 | 147 | per_option_last_losses[option] = per_value_last_losses 148 | 149 | 150 | if not os.path.exists(output_dir): 151 | os.makedirs(output_dir) 152 | 153 | for option in per_option_last_losses: 154 | 155 | if option == 'name' or option == 'import': 156 | fig = seaborn.plt.figure(figsize=(15,15)) 157 | else: 158 | fig = seaborn.plt.figure(figsize=(15,10)) 159 | fig.add_subplot() 160 | 161 | df = pd.DataFrame(columns=["option", option, "loss"]) 162 | i = 0 163 | for option_value in per_option_last_losses[option]: 164 | for value in per_option_last_losses[option][option_value]: 165 | df.loc[i] = [option, option_value, value] 166 | i += 1 167 | 168 | # seaborn.set(font_scale=0.5) 169 | print(df) 170 | if option == 'name': 171 | g = seaborn.barplot(data=df, x=option, y="loss") 172 | else: 173 | g = seaborn.boxplot(data=df, x=option, y="loss") 174 | seaborn.stripplot(data=df, x=option, y="loss", ax=g, color="black") 175 | 176 | if option == 'name' or option == 'import': 177 | for item in g.get_xticklabels(): 178 | item.set_fontsize(5) 179 | 180 | 181 | seaborn.plt.xticks(rotation=90) 182 | g.set(title=option) 183 | g.set_yscale('log') 184 | 185 | seaborn.plt.tight_layout() 186 | seaborn.plt.savefig(output_dir + "/" + option + ".pdf", dpi=300) 187 | seaborn.plt.close() 188 | -------------------------------------------------------------------------------- /utils.lua: -------------------------------------------------------------------------------- 1 | local T = require 'pl.tablex' 2 | 3 | -- From https://gist.github.com/cwarden/1207556 4 | function catch(what) 5 | return what[1] 6 | end 7 | 8 | -- From https://gist.github.com/cwarden/1207556 9 | function try(what) 10 | status, result = pcall(what[1]) 11 | if not status then 12 | what[2](result) 13 | end 14 | return result 15 | end 16 | 17 | function subrange(t, first, last) 18 | local sub = {} 
19 | for i=first,last do 20 | sub[#sub + 1] = t[i] 21 | end 22 | return sub 23 | end 24 | 25 | -- merge t2 into t1 26 | function merge_tables(t1, t2) 27 | -- Merges t2 and t1, overwriting t1 keys by t2 keys when applicable 28 | merged_table = T.deepcopy(t1) 29 | for k,v in pairs(t2) do 30 | -- if merged_table[k] then 31 | -- error('t1 and t2 both contain the key: ' .. k) 32 | -- end 33 | merged_table[k] = v 34 | end 35 | return merged_table 36 | end 37 | 38 | -- merge t2 into t1 39 | -- TODO do set functions 40 | function merge_tables_by_value(t1, t2) 41 | -- Merges t2 and t1, overwriting t1 keys by t2 keys when applicable 42 | for k,v in pairs(t1) do assert(type(k) == 'number') end 43 | merged_table = T.deepcopy(t1) 44 | for _,v in pairs(t2) do 45 | if not isin(v, merged_table) then 46 | merged_table[#merged_table+1] = v -- just append 47 | end 48 | end 49 | return merged_table 50 | end 51 | 52 | function intersect(t1, t2) 53 | local intersect_table = {} 54 | for k,v1 in pairs(t1) do 55 | if isin(v1, t2) then 56 | intersect_table[#intersect_table+1] = v1 57 | end 58 | end 59 | return intersect_table 60 | end 61 | 62 | function is_subset(small_table, big_table) 63 | for _, el in pairs(small_table) do 64 | if not isin(el, big_table) then 65 | return false 66 | end 67 | end 68 | return true 69 | end 70 | 71 | function isin(element, table) 72 | for _,v in pairs(table) do 73 | if v == element then 74 | return true 75 | end 76 | end 77 | return false 78 | end 79 | 80 | function is_empty(table) 81 | if next(table) == nil then return true end 82 | end 83 | 84 | -- BUG! If the arg is nil, then it won't get passed into args_table! 
--- Report whether `args_table` holds exactly `num_args` arguments.
-- A nil argument can never be stored in a table, so iterating the table
-- cannot reveal it; a missing argument only shows up as a length mismatch
-- (and `#` is unreliable when the nil sits in the middle of the list).
-- @param args_table sequence of collected arguments
-- @param num_args   expected number of arguments
-- @return true when the lengths agree, false otherwise
function all_args_exist(args_table, num_args)
    -- The original also looped over pairs(args_table) looking for nil values;
    -- pairs() never yields a nil value, so that loop could not change the
    -- result and the length comparison is the whole check.
    return #args_table == num_args
end

--- Test whether `substring` occurs literally inside `string`.
-- @param substring text to look for (taken literally, not as a pattern)
-- @param string    text to search in
-- @return true if `substring` occurs in `string`, false otherwise
function is_substring(substring, string)
    -- plain = true: magic characters such as "." or "(" in `substring` must
    -- not be interpreted as a Lua pattern
    return string:find(substring, 1, true) ~= nil
end

--- Return true for every value except nil (false counts as "not nil").
function notnil(x)
    return x ~= nil
end

-- from http://lua-users.org/wiki/FunctionalLibrary
-- map(function, table)
-- e.g: map(double, {1,2,3}) -> {2,4,6}
--
-- Returns a new table with func(v) stored under the same key as v.
function map(func, tbl)
    local newtbl = {}
    for key, value in pairs(tbl) do
        newtbl[key] = func(value)
    end
    return newtbl
end

-- from http://lua-users.org/wiki/FunctionalLibrary
-- filter(function, table)
-- e.g: filter(is_even, {1,2,3,4}) -> {[2]=2, [4]=4}
--
-- Returns a new table holding the entries of tbl whose value satisfies func.
-- Entries keep their ORIGINAL keys, so filtering a sequence can leave holes
-- and `#` on the result is not reliable.
function filter(func, tbl)
    local newtbl = {}
    for key, value in pairs(tbl) do
        if func(value) then
            newtbl[key] = value
        end
    end
    return newtbl
end

-- from http://lua-users.org/wiki/FunctionalLibrary
-- head(table)
-- e.g: head({1,2,3}) -> 1
--
-- Returns nil for an empty table.
function head(tbl)
    return tbl[1]
end

-- from http://lua-users.org/wiki/FunctionalLibrary
-- tail(table)
-- e.g: tail({1,2,3}) -> {2,3}
--
-- XXX This is a BAD and ugly implementation.
-- should return the address to next pointer, like in C (arr+1)
--
-- Builds and returns a NEW sequence holding elements 2..#tbl of tbl.
-- Returns nil when tbl is empty and an empty table when tbl has one element.
function tail(tbl)
    -- `#` instead of table.getn: getn is deprecated in Lua 5.1 and does not
    -- exist in Lua 5.2 and later
    local tblsize = #tbl
    if tblsize < 1 then
        return nil
    end
    local newtbl = {}
    for i = 2, tblsize do
        newtbl[i - 1] = tbl[i]
    end
    return newtbl
end

-- from http://lua-users.org/wiki/FunctionalLibrary
-- foldr(function, default_value, table)
-- e.g: foldr(operator.mul, 1, {1,2,3,4,5}) -> 120
--
-- Accumulates val = func(val, v) over every value of tbl and returns the
-- final val (val itself when tbl is empty). Values are visited in pairs()
-- order, which Lua leaves unspecified, so func should not depend on the
-- visiting order when tbl is not a plain sequence.
function foldr(func, val, tbl)
    for _, v in pairs(tbl) do
        val = func(val, v)
    end
    return val
end

-- from http://lua-users.org/wiki/FunctionalLibrary
-- reduce(function, table)
-- e.g: reduce(operator.add, {1,2,3,4}) -> 10
--
-- Folds tbl with func, seeding the accumulator with the first element.
-- Returns nil for an empty table and the sole element for a one-element table.
function reduce(func, tbl)
    local rest = tail(tbl)
    if rest == nil then
        -- empty input: there is nothing to combine (this used to crash inside
        -- foldr because tail() hands back nil for an empty table)
        return nil
    end
    -- tbl[1] is the head of the sequence
    return foldr(func, tbl[1], rest)
end

-- range(start) returns an iterator from 1 to a (step = 1)
-- range(start, stop) returns an iterator from a to b (step = 1)
-- range(start, stop, step) returns an iterator from a to b, counting by step.
177 | -- from http://lua-users.org/wiki/RangeIterator 178 | function range (i, to, inc) 179 | if i == nil then return end -- range(--[[ no args ]]) -> return "nothing" to fail the loop in the caller 180 | 181 | if not to then 182 | to = i 183 | i = to == 0 and 0 or (to > 0 and 1 or -1) 184 | end 185 | 186 | -- we don't have to do the to == 0 check 187 | -- 0 -> 0 with any inc would never iterate 188 | inc = inc or (i < to and 1 or -1) 189 | 190 | -- step back (once) before we start 191 | i = i - inc 192 | 193 | return function () if i == to then return nil end i = i + inc return i, i end 194 | end 195 | 196 | -- the elements of t2 should go after their corresponding t1 elements 197 | function interleave(t1, t2) 198 | assert(#t1 == #t2) 199 | local interleaved = {} 200 | for i = 1, #t1 do 201 | interleaved[#interleaved+1] = t1[i] 202 | interleaved[#interleaved+1] = t2[i] -- t2's elements comes after t1 203 | end 204 | assert(#interleaved %2 == 0) 205 | return T.deepcopy(interleaved) 206 | end 207 | 208 | -- extends t1 by the elements of t2 209 | function extend(t1, t2) 210 | local total = T.deepcopy(t1) 211 | for i = 1, #t2 do 212 | total[#total+1] = t2[i] 213 | end 214 | assert(#total == #t1 + #t2) 215 | return T.deepcopy(total) 216 | end 217 | 218 | -- print(merge_tables_by_value({['a']=1}, {['b'] = 2, ['c'] = 5})) 219 | 220 | -- print(intersect({'a','b','c'}, {'d','b','c'})) 221 | 222 | -- a = {10,20,30,40,50,60} 223 | -- print(subrange(a,1,#a-1)) 224 | -- print(a) 225 | -------------------------------------------------------------------------------- /render_generalization_face.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cutorch' 3 | require 'cunn' 4 | require 'paths' 5 | require 'lfs' 6 | 7 | vis = require 'vis' 8 | require 'AtariEncoder' 9 | require 'AtariDecoder' 10 | local data_loaders = require 'data_loaders' 11 | 12 | name = arg[1] 13 | -- dataset_name = arg[2] or name 14 | networks = 
{} 15 | while true do 16 | local line = io.read() 17 | if line == nil then break end 18 | 19 | -- strip whitespace 20 | line = string.gsub(line, "%s+", "") 21 | 22 | table.insert(networks, line) 23 | end 24 | 25 | -- opt = { 26 | -- datasetdir = '/om/user/wwhitney/deep-game-engine', 27 | -- dataset_name = dataset_name, 28 | -- gpu = true, 29 | -- } 30 | 31 | base_directory = "/om/user/wwhitney/unsupervised-dcign/networks" 32 | 33 | local jobname = name ..'_'.. os.date("%b_%d_%H_%M") 34 | local output_path = 'reports/renderings/mutation/'..jobname 35 | os.execute('mkdir -p '..output_path) 36 | 37 | 38 | function getLastSnapshot(network_name) 39 | local res_file = io.popen("ls -t "..paths.concat(base_directory, network_name).." | grep -i epoch | head -n 1") 40 | local status, result = pcall(function() return res_file:read():match( "^%s*(.-)%s*$" ) end) 41 | -- print(status, result) 42 | res_file:close() 43 | if not status then 44 | return false 45 | else 46 | return result 47 | end 48 | end 49 | 50 | for _, network in ipairs(networks) do 51 | collectgarbage() 52 | 53 | print('') 54 | print(network) 55 | local snapshot_name = getLastSnapshot(network) 56 | if snapshot_name then 57 | local checkpoint = torch.load(paths.concat(base_directory, network, snapshot_name)) 58 | opt = checkpoint.opt 59 | local model = checkpoint.model 60 | local scheduler_iteration = torch.Tensor{checkpoint.step} 61 | model:evaluate() 62 | 63 | local encoder = model.modules[1] 64 | local sharpener = encoder:findModules('nn.ScheduledWeightSharpener')[1] 65 | sharpener.iteration_container = scheduler_iteration 66 | print("Current sharpening: ", sharpener:getP()) 67 | 68 | local weight_predictor = encoder:findModules('nn.Normalize')[1] 69 | local previous_embedding = encoder:findModules('nn.Linear')[1] 70 | -- local current_embedding = encoder:findModules('nn.Linear')[2] 71 | local decoder = model.modules[2] 72 | 73 | for _, mode in ipairs{"AZ_VARIED", "EL_VARIED", "LIGHT_AZ_VARIED"} do 74 | for 
_, batch_index in ipairs{347, 400} do 75 | print("Batch index: ", batch_index) 76 | -- local images = {} 77 | 78 | -- fetch a batch 79 | local input = data_loaders.load_mv_batch(batch_index, mode, 'FT_test') 80 | local output = model:forward(input):clone() 81 | local embedding_from_previous = previous_embedding.output:clone() 82 | -- local embedding_from_current = current_embedding.output:clone() 83 | 84 | -- local reconstruction_from_previous = decoder:forward(embedding_from_previous):clone() 85 | -- local reconstruction_from_current = decoder:forward(embedding_from_current):clone() 86 | 87 | local weight_norms = torch.zeros(output:size(1)) 88 | for input_index = 1, output:size(1) do 89 | local weights = weight_predictor.output[input_index]:clone() 90 | weight_norms[input_index] = weights:norm() 91 | end 92 | print("Mean independence of weights: ", weight_norms:mean()) 93 | 94 | 95 | 96 | 97 | -- local max_indices = {} 98 | -- for input_index = 1, output:size(1) do 99 | -- local weights = weight_predictor.output[input_index]:clone() 100 | -- local _, idx = weights:max(1) 101 | -- max_indices[idx[1]] = true 102 | -- end 103 | 104 | for input_index = 1, 2 do 105 | collectgarbage() 106 | print("Input index: ", input_index) 107 | local base_embedding = embedding_from_previous[input_index]:clone():float() 108 | 109 | local weights = weight_predictor.output[input_index]:clone():float() 110 | local max_indices = {} 111 | for nth_max = 1, 3 do 112 | local _, idx = weights:max(1) 113 | idx = idx[1] 114 | max_indices[idx] = true 115 | weights[idx] = 0 116 | end 117 | 118 | 119 | for max_index, _ in pairs(max_indices) do 120 | 121 | collectgarbage() 122 | -- local weights = weight_predictor.output[input_index]:clone() 123 | -- local max_weight, varying_index = weights:max(1) 124 | 125 | local num_frames = 40 126 | local min_change = -1.5 127 | local max_change = 1.5 128 | 129 | local mutated_input = torch.Tensor(num_frames, base_embedding:size(1)) 130 | 131 | for i = 1, 
num_frames do 132 | local change = min_change + (i-1) * (max_change-min_change)/num_frames 133 | mutated_input[i] = base_embedding:clone() 134 | mutated_input[i][max_index] = mutated_input[i][max_index] + change 135 | end 136 | 137 | local mutated_renders = decoder:forward(mutated_input:cuda()):clone() 138 | 139 | local output_directory = paths.concat( 140 | output_path, 141 | network, 142 | mode..'_batch_'..batch_index..'_input_'..input_index..'_along_'..max_index) 143 | os.execute('mkdir -p '..output_directory) 144 | 145 | for i = 1, num_frames do 146 | local change = min_change + (i-1) * (max_change-min_change)/num_frames 147 | local output_filename = paths.concat( 148 | output_directory, 149 | 'changing_'..i..'_amount_'..vis.simplestr(change)..'.png') 150 | image.save(output_filename, mutated_renders[i]) 151 | end 152 | 153 | collectgarbage() 154 | end 155 | end 156 | end 157 | -- print("Mean independence of weights: ", weight_norms:mean()) 158 | -- vis.save_image_grid(paths.concat(output_path, network .. 
'_batch_'..batch_index..'.png'), images) 159 | 160 | 161 | end 162 | end 163 | end 164 | 165 | 166 | print("done") 167 | -------------------------------------------------------------------------------- /bouncing_balls.py: -------------------------------------------------------------------------------- 1 | from pylab import * 2 | import pdb 3 | import os 4 | import cv2 5 | from progressbar import ProgressBar 6 | from argparse import ArgumentParser 7 | 8 | # Adapted from Ruben Villegas, U Mich 9 | 10 | class BouncingBallDataHandler(object): 11 | def __init__(self, num_balls, seq_length, batch_size, image_size): 12 | self.SIZE = 10 13 | self.T = seq_length 14 | self.n = num_balls 15 | self.res = image_size 16 | self.batch_size = batch_size 17 | 18 | def norm(self, x): return sqrt((x**2).sum()) 19 | def sigmoid(self, x): return 1./(1.+exp(-x)) 20 | 21 | def new_speeds(self, m1, m2, v1, v2): 22 | new_v2 = (2*m1*v1 + v2*(m2-m1))/(m1+m2) 23 | new_v1 = new_v2 + (v2 - v1) 24 | return new_v1, new_v2 25 | 26 | # size of bounding box: self.SIZE X self.SIZE. 27 | 28 | def bounce_n(self, T=128, n=2, r=None, m=None): 29 | if r==None: r=array([1.2]*n) 30 | if m==None: m=array([1]*n) 31 | # r is to be rather small. 32 | X=zeros((T, n, 2), dtype='float') 33 | v = randn(n,2) 34 | v = v / self.norm(v)*.5 35 | good_config=False 36 | while not good_config: 37 | x = 2+rand(n,2)*8 38 | good_config=True 39 | for i in range(n): 40 | for z in range(2): 41 | if x[i][z]-r[i]<0: good_config=False 42 | if x[i][z]+r[i]>self.SIZE: good_config=False 43 | 44 | # that's the main part. 
45 | for i in range(n): 46 | for j in range(i): 47 | if self.norm(x[i]-x[j])self.SIZE: v[i][z]=-abs(v[i][z]) # want negative 67 | 68 | 69 | for i in range(n): 70 | for j in range(i): 71 | if self.norm(x[i]-x[j])1]=1 103 | return A 104 | 105 | def bounce_mat(self, res, n=2, T=128, r =None): 106 | if r==None: r=array([1.2]*n) 107 | x = self.bounce_n(T,n,r); 108 | A = self.matricize(x,res,r) 109 | return A 110 | 111 | def bounce_vec(self, res, n=2, T=128, r =None, m =None): 112 | if r==None: r=array([1.2]*n) 113 | x = self.bounce_n(T,n,r,m); 114 | V = self.matricize(x,res,r) 115 | return V.reshape(T, res**2) 116 | 117 | def show_single_V(self, V): 118 | res = int(sqrt(shape(V)[0])) 119 | show(V.reshape(res, res)) 120 | 121 | def show_V(self, V): 122 | T = len(V) 123 | res = int(sqrt(shape(V)[1])) 124 | for t in range(T): 125 | print t 126 | show(V[t].reshape(res, res)) 127 | 128 | def unsigmoid(self, x): return log (x) - log (1-x) 129 | 130 | def show_A(self, A): 131 | T = len(A) 132 | for t in range(T): 133 | show(A[t]) 134 | 135 | def GetBatch(self): 136 | seq_batch = np.zeros( ( self.T, 137 | self.batch_size, 138 | self.res,self.res ) ) 139 | 140 | for i in xrange(self.batch_size): 141 | seq = self.bounce_mat(self.res, self.n, self.T) 142 | seq_batch[:,i,:] = seq.reshape( seq.shape[0], 143 | seq.shape[1],seq.shape[2] ) 144 | return seq_batch.astype('float32') 145 | 146 | def make_dataset(root, mode, num_batches, num_balls, batch_size, image_size, subsample): 147 | """ Each sequence is a batch """ 148 | handler = BouncingBallDataHandler(num_balls=num_balls, seq_length=batch_size*subsample, batch_size=1, image_size=150) 149 | data_root = os.path.join(root, mode) + '_nb=' + str(num_balls) + '_bsize=' + str(batch_size) + '_imsize=' + str(image_size) + '_subsamp=' + str(subsample) 150 | os.mkdir(data_root) 151 | pbar = ProgressBar() 152 | for i in pbar(range(num_batches)): 153 | batch_folder = os.path.join(data_root,'batch'+str(i)) 154 | os.mkdir(batch_folder) 155 | x = 
handler.GetBatch() # this is normalized between 0 and 1 156 | for j in range(0,batch_size*subsample,subsample): 157 | # pass 158 | cv2.imwrite(os.path.join(batch_folder,str(j/subsample)+'.png'),x[j,0,:,:]*255) 159 | pbar.finish() 160 | 161 | if __name__ == "__main__": 162 | 163 | # handler = BouncingBallDataHandler(num_balls=3, seq_length=30, batch_size=1, image_size=150) 164 | # x = handler.GetBatch() 165 | 166 | parser = ArgumentParser() 167 | parser.add_argument('-m','--mode', type=str, default='train', 168 | help='train | test | val') 169 | parser.add_argument('-b','--batch_size', type=int, default=30, 170 | help='batch size') 171 | parser.add_argument('-n','--num_balls', type=int, default=3, 172 | help='Number of balls') 173 | parser.add_argument('-s','--subsample', type=int, default=3, 174 | help='How many frames to subample') 175 | parser.add_argument('-i','--image_size', type=int, default=150, 176 | help='Dimension of square image') 177 | # parser.add_argument('-e','--name', type=str, default='defaultballsname', 178 | # help='name of job') 179 | args = parser.parse_args() 180 | 181 | root = '/om/data/public/mbchang/udcign-data/balls' 182 | # root = '/Users/MichaelChang/Documents/Researchlink/SuperUROP/Code/data/udcign/balls' 183 | datasets = {'train':9000,'test':1000,'val':1000} 184 | 185 | mode = args.mode #'val' 186 | num_balls = args.num_balls#6 187 | batch_size = args.batch_size#30 188 | image_size = args.image_size#150 189 | subsample = args.subsample#5 190 | print mode, num_balls 191 | print(args) 192 | # TODO add subsample! 
193 | make_dataset(root, mode, datasets[mode], num_balls, batch_size, image_size, subsample) 194 | -------------------------------------------------------------------------------- /render_generalization_atari.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cutorch' 3 | require 'cunn' 4 | require 'paths' 5 | require 'lfs' 6 | 7 | vis = require 'vis' 8 | require 'AtariEncoder' 9 | require 'AtariDecoder' 10 | local data_loaders = require 'data_loaders' 11 | 12 | name = arg[1] 13 | -- dataset_name = arg[2] or name 14 | networks = {} 15 | while true do 16 | local line = io.read() 17 | if line == nil then break end 18 | 19 | -- strip whitespace 20 | line = string.gsub(line, "%s+", "") 21 | 22 | table.insert(networks, line) 23 | end 24 | 25 | -- opt = { 26 | -- datasetdir = '/om/user/wwhitney/deep-game-engine', 27 | -- dataset_name = dataset_name, 28 | -- gpu = true, 29 | -- } 30 | 31 | base_directory = "/om/user/wwhitney/unsupervised-dcign/networks" 32 | 33 | local jobname = name ..'_'.. os.date("%b_%d_%H_%M") 34 | local output_path = 'reports/renderings/mutation/'..jobname 35 | os.execute('mkdir -p '..output_path) 36 | 37 | 38 | function getLastSnapshot(network_name) 39 | local res_file = io.popen("ls -t "..paths.concat(base_directory, network_name).." 
| grep -i epoch | head -n 1") 40 | local status, result = pcall(function() return res_file:read():match( "^%s*(.-)%s*$" ) end) 41 | -- print(status, result) 42 | res_file:close() 43 | if not status then 44 | return false 45 | else 46 | return result 47 | end 48 | end 49 | 50 | for _, network in ipairs(networks) do 51 | collectgarbage() 52 | 53 | print('') 54 | print(network) 55 | local snapshot_name = getLastSnapshot(network) 56 | if snapshot_name then 57 | local checkpoint = torch.load(paths.concat(base_directory, network, snapshot_name)) 58 | opt = checkpoint.opt 59 | local model = checkpoint.model 60 | local scheduler_iteration = torch.Tensor{checkpoint.step} 61 | model:evaluate() 62 | 63 | local encoder = model.modules[1] 64 | local sharpener = encoder:findModules('nn.ScheduledWeightSharpener')[1] 65 | sharpener.iteration_container = scheduler_iteration 66 | print("Current sharpening: ", sharpener:getP()) 67 | 68 | local weight_predictor = encoder:findModules('nn.Normalize')[1] 69 | local previous_embedding = encoder:findModules('nn.Linear')[1] 70 | -- local current_embedding = encoder:findModules('nn.Linear')[2] 71 | local decoder = model.modules[2] 72 | 73 | for _, batch_index in ipairs{347, 400, 420} do 74 | print("Batch index: ", batch_index) 75 | -- local images = {} 76 | 77 | -- fetch a batch 78 | local input = data_loaders.load_atari_batch(batch_index, 'test') 79 | local output = model:forward(input):clone() 80 | local embedding_from_previous = previous_embedding.output:clone() 81 | -- local embedding_from_current = current_embedding.output:clone() 82 | 83 | -- local reconstruction_from_previous = decoder:forward(embedding_from_previous):clone() 84 | -- local reconstruction_from_current = decoder:forward(embedding_from_current):clone() 85 | 86 | local weight_norms = torch.zeros(output:size(1)) 87 | for input_index = 1, output:size(1) do 88 | local weights = weight_predictor.output[input_index]:clone() 89 | weight_norms[input_index] = weights:norm() 90 | 
end 91 | print("Mean independence of weights: ", weight_norms:mean()) 92 | 93 | 94 | 95 | 96 | -- local max_indices = {} 97 | -- for input_index = 1, output:size(1) do 98 | -- local weights = weight_predictor.output[input_index]:clone() 99 | -- local _, idx = weights:max(1) 100 | -- max_indices[idx[1]] = true 101 | -- end 102 | 103 | for input_index = 1, 2 do 104 | collectgarbage() 105 | print("Input index: ", input_index) 106 | local base_embedding = embedding_from_previous[input_index]:clone():float() 107 | 108 | local weights = weight_predictor.output[input_index]:clone():float() 109 | local max_indices = {} 110 | for nth_max = 1, 3 do 111 | local _, idx = weights:max(1) 112 | idx = idx[1] 113 | max_indices[idx] = true 114 | weights[idx] = 0 115 | end 116 | 117 | 118 | for max_index, _ in pairs(max_indices) do 119 | collectgarbage() 120 | -- local weights = weight_predictor.output[input_index]:clone() 121 | -- local max_weight, varying_index = weights:max(1) 122 | 123 | local num_frames = 40 124 | local min_change = -1.5 125 | local max_change = 1.5 126 | 127 | local mutated_input = torch.Tensor(num_frames, base_embedding:size(1)) 128 | 129 | for i = 1, num_frames do 130 | local change = min_change + (i-1) * (max_change-min_change)/num_frames 131 | mutated_input[i] = base_embedding:clone() 132 | mutated_input[i][max_index] = mutated_input[i][max_index] + change 133 | end 134 | 135 | local mutated_renders = decoder:forward(mutated_input:cuda()):clone() 136 | 137 | local output_directory = paths.concat( 138 | output_path, 139 | network, 140 | 'batch_'..batch_index..'_input_'..input_index..'_along_'..max_index) 141 | os.execute('mkdir -p '..output_directory) 142 | 143 | for i = 1, num_frames do 144 | local change = min_change + (i-1) * (max_change-min_change)/num_frames 145 | local output_filename = paths.concat( 146 | output_directory, 147 | 'changing_'..i..'_amount_'..vis.simplestr(change)..'.png') 148 | image.save(output_filename, mutated_renders[i]) 149 | end 
150 | 151 | -- for change = -3, 3, 0.1 do 152 | -- local output_directory = paths.concat( 153 | -- output_path, 154 | -- network, 155 | -- 'batch_'..batch_index..'_input_'..input_index..'_along_'..max_index) 156 | -- local output_filename = paths.concat( 157 | -- output_directory, 158 | -- 'changing_'..i..'_amount_'..vis.simplestr(change)..'.png') 159 | -- os.execute('mkdir -p '..output_directory) 160 | 161 | -- local changed_embedding = base_embedding:clone() 162 | -- changed_embedding[max_index] = changed_embedding[max_index] + change 163 | 164 | -- local rendering = decoder:forward(changed_embedding:reshape(1, 200))[1]:clone() 165 | -- image.save(output_filename, rendering:float()) 166 | -- i = i + 1 167 | -- end 168 | 169 | -- weight_norms[input_index] = weights:norm() 170 | 171 | -- local image_row = {} 172 | -- table.insert(image_row, input[1][input_index]:float()) 173 | -- table.insert(image_row, input[2][input_index]:float()) 174 | -- table.insert(image_row, reconstruction_from_previous[input_index]:float()) 175 | -- table.insert(image_row, reconstruction_from_current[input_index]:float()) 176 | -- table.insert(image_row, output[input_index]:float()) 177 | -- table.insert(images, image_row) 178 | collectgarbage() 179 | end 180 | end 181 | -- print("Mean independence of weights: ", weight_norms:mean()) 182 | -- vis.save_image_grid(paths.concat(output_path, network .. 
'_batch_'..batch_index..'.png'), images) 183 | 184 | 185 | end 186 | end 187 | end 188 | 189 | 190 | print("done") 191 | -------------------------------------------------------------------------------- /render_generalization_downsampled.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'cutorch' 3 | require 'cunn' 4 | require 'paths' 5 | require 'lfs' 6 | 7 | vis = require 'vis' 8 | require 'DownsampledEncoder' 9 | require 'DownsampledDecoder' 10 | local data_loaders = require 'data_loaders' 11 | 12 | require 'Scale' 13 | 14 | name = arg[1] 15 | -- dataset_name = arg[2] or name 16 | networks = {} 17 | while true do 18 | local line = io.read() 19 | if line == nil then break end 20 | 21 | -- strip whitespace 22 | line = string.gsub(line, "%s+", "") 23 | 24 | table.insert(networks, line) 25 | end 26 | 27 | -- opt = { 28 | -- datasetdir = '/om/user/wwhitney/deep-game-engine', 29 | -- dataset_name = dataset_name, 30 | -- gpu = true, 31 | -- } 32 | 33 | base_directory = "/om/user/wwhitney/unsupervised-dcign/networks" 34 | 35 | local jobname = name ..'_'.. os.date("%b_%d_%H_%M") 36 | local output_path = 'reports/renderings/mutation/'..jobname 37 | os.execute('mkdir -p '..output_path) 38 | 39 | scale = nn.Scale(84, 84, true) 40 | 41 | 42 | function getLastSnapshot(network_name) 43 | local res_file = io.popen("ls -t "..paths.concat(base_directory, network_name).." 
| grep -i epoch | head -n 1") 44 | local status, result = pcall(function() return res_file:read():match( "^%s*(.-)%s*$" ) end) 45 | -- print(status, result) 46 | res_file:close() 47 | if not status then 48 | return false 49 | else 50 | return result 51 | end 52 | end 53 | 54 | for _, network in ipairs(networks) do 55 | collectgarbage() 56 | 57 | print('') 58 | print(network) 59 | local snapshot_name = getLastSnapshot(network) 60 | if snapshot_name then 61 | local checkpoint = torch.load(paths.concat(base_directory, network, snapshot_name)) 62 | opt = checkpoint.opt 63 | local model = checkpoint.model 64 | local scheduler_iteration = torch.Tensor{checkpoint.step} 65 | model:evaluate() 66 | 67 | local encoder = model.modules[1] 68 | local sharpener = encoder:findModules('nn.ScheduledWeightSharpener')[1] 69 | -- sharpener.iteration_container = scheduler_iteration 70 | print("Current sharpening: ", sharpener:getP()) 71 | 72 | local weight_predictor = encoder:findModules('nn.Normalize')[1] 73 | local previous_embedding = encoder:findModules('nn.Linear')[1] 74 | -- local current_embedding = encoder:findModules('nn.Linear')[2] 75 | local decoder = model.modules[2] 76 | 77 | for _, batch_index in ipairs{347, 400, 420} do 78 | print("Batch index: ", batch_index) 79 | -- local images = {} 80 | 81 | -- fetch a batch 82 | local input = data_loaders.load_atari_batch(batch_index, 'test') 83 | input = { 84 | scale:forward(input[1]), 85 | scale:forward(input[2]), 86 | } 87 | 88 | local output = model:forward(input):clone() 89 | local embedding_from_previous = previous_embedding.output:clone() 90 | -- local embedding_from_current = current_embedding.output:clone() 91 | 92 | -- local reconstruction_from_previous = decoder:forward(embedding_from_previous):clone() 93 | -- local reconstruction_from_current = decoder:forward(embedding_from_current):clone() 94 | 95 | local weight_norms = torch.zeros(output:size(1)) 96 | for input_index = 1, output:size(1) do 97 | local weights = 
weight_predictor.output[input_index]:clone() 98 | weight_norms[input_index] = weights:norm() 99 | end 100 | print("Mean independence of weights: ", weight_norms:mean()) 101 | 102 | 103 | -- local max_indices = {} 104 | -- for input_index = 1, output:size(1) do 105 | -- local weights = weight_predictor.output[input_index]:clone() 106 | -- local _, idx = weights:max(1) 107 | -- max_indices[idx[1]] = true 108 | -- end 109 | 110 | for input_index = 1, 2 do 111 | collectgarbage() 112 | print("Input index: ", input_index) 113 | local base_embedding = embedding_from_previous[input_index]:clone():float() 114 | 115 | local weights = weight_predictor.output[input_index]:clone():float() 116 | local max_indices = {} 117 | for nth_max = 1, 3 do 118 | local _, idx = weights:max(1) 119 | idx = idx[1] 120 | max_indices[idx] = true 121 | weights[idx] = 0 122 | end 123 | 124 | 125 | for max_index, _ in pairs(max_indices) do 126 | collectgarbage() 127 | -- local weights = weight_predictor.output[input_index]:clone() 128 | -- local max_weight, varying_index = weights:max(1) 129 | 130 | local num_frames = 40 131 | local min_change = -1.5 132 | local max_change = 1.5 133 | 134 | local mutated_input = torch.Tensor(num_frames, base_embedding:size(1)) 135 | 136 | for i = 1, num_frames do 137 | local change = min_change + (i-1) * (max_change-min_change)/num_frames 138 | mutated_input[i] = base_embedding:clone() 139 | mutated_input[i][max_index] = mutated_input[i][max_index] + change 140 | end 141 | 142 | local mutated_renders = decoder:forward(mutated_input:cuda()):clone() 143 | 144 | local output_directory = paths.concat( 145 | output_path, 146 | network, 147 | 'batch_'..batch_index..'_input_'..input_index..'_along_'..max_index) 148 | os.execute('mkdir -p '..output_directory) 149 | 150 | for i = 1, num_frames do 151 | local change = min_change + (i-1) * (max_change-min_change)/num_frames 152 | local output_filename = paths.concat( 153 | output_directory, 154 | 
'changing_'..i..'_amount_'..vis.simplestr(change)..'.png') 155 | image.save(output_filename, mutated_renders[i]) 156 | end 157 | 158 | -- for change = -3, 3, 0.1 do 159 | -- local output_directory = paths.concat( 160 | -- output_path, 161 | -- network, 162 | -- 'batch_'..batch_index..'_input_'..input_index..'_along_'..max_index) 163 | -- local output_filename = paths.concat( 164 | -- output_directory, 165 | -- 'changing_'..i..'_amount_'..vis.simplestr(change)..'.png') 166 | -- os.execute('mkdir -p '..output_directory) 167 | 168 | -- local changed_embedding = base_embedding:clone() 169 | -- changed_embedding[max_index] = changed_embedding[max_index] + change 170 | 171 | -- local rendering = decoder:forward(changed_embedding:reshape(1, 200))[1]:clone() 172 | -- image.save(output_filename, rendering:float()) 173 | -- i = i + 1 174 | -- end 175 | 176 | -- weight_norms[input_index] = weights:norm() 177 | 178 | -- local image_row = {} 179 | -- table.insert(image_row, input[1][input_index]:float()) 180 | -- table.insert(image_row, input[2][input_index]:float()) 181 | -- table.insert(image_row, reconstruction_from_previous[input_index]:float()) 182 | -- table.insert(image_row, reconstruction_from_current[input_index]:float()) 183 | -- table.insert(image_row, output[input_index]:float()) 184 | -- table.insert(images, image_row) 185 | collectgarbage() 186 | end 187 | end 188 | -- print("Mean independence of weights: ", weight_norms:mean()) 189 | -- vis.save_image_grid(paths.concat(output_path, network .. 
--     '_batch_'..batch_index..'.png'), images)
        end
    end
end

print("done")
--------------------------------------------------------------------------------
/render_generalization_balls.lua:
--------------------------------------------------------------------------------
--- Renders latent "mutation" sweeps for bouncing-balls networks.
-- Usage: pass a job name as the first CLI argument and feed network names on
-- stdin, one per line. For every network the newest snapshot whose file name
-- contains "epoch" is loaded, a few training batches are pushed through the
-- model, and for selected examples the three embedding components with the
-- largest predicted weight are perturbed over a sweep. Each decoded frame is
-- written as a PNG below renderings/mutation/<job name>/<network>/.
require 'nn'
require 'cutorch'
require 'cunn'
require 'paths'
require 'lfs'
require 'image' -- image.save is called below; load it explicitly instead of relying on another module

vis = require 'vis'
require 'BallsEncoder'
require 'Decoder'
local data_loaders = require 'data_loaders'

-- NOTE(review): `name`, `networks`, `base_directory`, `vis` and `opt` are kept
-- as globals, as in the original; other project modules may read them
-- (confirm before localizing).
name = arg[1]
if name == nil then
    -- Fail early with a readable message; otherwise the jobname concatenation
    -- below dies with "attempt to concatenate a nil value".
    error("usage: render_generalization_balls.lua <job_name>  (network names are read from stdin, one per line)", 0)
end
-- dataset_name = arg[2] or name

-- Collect network names from stdin, with all whitespace stripped.
networks = {}
for line in io.lines() do
    local stripped = line:gsub("%s+", "")
    -- Skip blank lines so they do not turn into empty network names.
    if stripped ~= "" then
        networks[#networks + 1] = stripped
    end
end

-- opt = {
--     datasetdir = '/om/user/wwhitney/deep-game-engine',
--     dataset_name = dataset_name,
--     gpu = true,
-- }

base_directory = "/home/mbchang/code/unsupervised-dcign/logslink"

--- Wrap a string in single quotes so the shell treats it as one literal word.
-- @param text value to quote (converted with tostring)
-- @return the quoted string
local function shell_quote(text)
    return "'" .. (tostring(text):gsub("'", "'\\''")) .. "'"
end

local jobname = name ..'_'.. os.date("%b_%d_%H_%M")
local output_path = 'renderings/mutation/'..jobname
os.execute('mkdir -p '..shell_quote(output_path))

--- Find the most recently modified snapshot of a network.
-- Lists `base_directory/network_name` newest-first and keeps the first entry
-- whose name contains "epoch" (case-insensitive).
-- @param network_name directory name of the network below base_directory
-- @return the trimmed file name, or false when nothing matched or the
--         listing could not be started
function getLastSnapshot(network_name)
    local network_dir = paths.concat(base_directory, network_name)
    local listing = io.popen("ls -t "..shell_quote(network_dir).." | grep -i epoch | head -n 1")
    if listing == nil then
        return false
    end
    local newest = listing:read()
    listing:close()
    if newest == nil then
        return false
    end
    return newest:match("^%s*(.-)%s*$")
end

for _, network in ipairs(networks) do
    collectgarbage()

    print('')
    print(network)
    local snapshot_name = getLastSnapshot(network)
    if snapshot_name then
        local checkpoint = torch.load(paths.concat(base_directory, network, snapshot_name))
        opt = checkpoint.opt -- global on purpose (see note at the top)
        local model = checkpoint.model
        local scheduler_iteration = torch.Tensor{checkpoint.step}
        model:evaluate()

        local encoder = model.modules[1]
        local sharpener = encoder:findModules('nn.ScheduledWeightSharpener')[1]
        -- Point the sharpener at the step stored in the checkpoint.
        sharpener.iteration_container = scheduler_iteration
        print("Current sharpening: ", sharpener:getP())

        local weight_predictor = encoder:findModules('nn.Normalize')[1]
        local previous_embedding = encoder:findModules('nn.Linear')[1]
        -- local current_embedding = encoder:findModules('nn.Linear')[2]
        local decoder = model.modules[2]

        for _, batch_index in ipairs{25,50,75,100} do
            print("Batch index: ", batch_index)
            -- local images = {}

            -- fetch a batch
            local input = data_loaders.load_balls_batch(batch_index, 'train')
            local output = model:forward(input):clone()
            local embedding_from_previous = previous_embedding.output:clone()
            -- local embedding_from_current = current_embedding.output:clone()

            -- local reconstruction_from_previous = decoder:forward(embedding_from_previous):clone()
            -- local reconstruction_from_current = decoder:forward(embedding_from_current):clone()

            local weight_norms = torch.zeros(output:size(1))
            for input_index = 1, output:size(1) do
                local weights = weight_predictor.output[input_index]:clone()
                weight_norms[input_index] = weights:norm()
            end
            print("Mean independence of weights: ", weight_norms:mean())

            -- local max_indices = {}
            -- for input_index = 1, output:size(1) do
            --     local weights = weight_predictor.output[input_index]:clone()
            --     local _, idx = weights:max(1)
            --     max_indices[idx[1]] = true
            -- end

            for _, input_index in ipairs{1,6,11,16,21,26} do -- example in the batch
                collectgarbage()
                print("Input index: ", input_index)
                local base_embedding = embedding_from_previous[input_index]:clone():float()

                local weights = weight_predictor.output[input_index]:clone():float()
                local max_indices = {}
                for nth_max = 1, 3 do -- take the largest 3 components
                    local _, idx = weights:max(1)
                    idx = idx[1]
                    print(nth_max..'th biggest:'..idx)
                    max_indices[idx] = true
                    -- zero the winner so the next pass finds the next-largest entry
                    weights[idx] = 0
                end

                for max_index, _ in pairs(max_indices) do
                    collectgarbage()
                    -- local weights = weight_predictor.output[input_index]:clone()
                    -- local max_weight, varying_index = weights:max(1)

                    local num_frames = 80 -- how many frames to predict?
                    local min_change = -4.5 -- low
                    local max_change = 4.5 -- high

                    local mutated_input = torch.Tensor(num_frames, base_embedding:size(1))

                    -- Compute each frame's offset once and reuse it for the file names.
                    -- NOTE(review): the divisor is num_frames, so the sweep stops one
                    -- step short of max_change; kept as in the original.
                    local changes = {}
                    for i = 1, num_frames do
                        changes[i] = min_change + (i-1) * (max_change-min_change)/num_frames
                        mutated_input[i] = base_embedding:clone()
                        mutated_input[i][max_index] = mutated_input[i][max_index] + changes[i]
                    end

                    local mutated_renders = decoder:forward(mutated_input:cuda()):clone()

                    local output_directory = paths.concat(
                        output_path,
                        network,
                        'batch_'..batch_index..'_input_'..input_index..'_along_'..max_index)
                    os.execute('mkdir -p '..shell_quote(output_directory))

                    for i = 1, num_frames do
                        local output_filename = paths.concat(
                            output_directory,
                            'changing_'..i..'_amount_'..vis.simplestr(changes[i])..'.png')
                        image.save(output_filename, mutated_renders[i])
                    end

                    -- for change = -3, 3, 0.1 do
                    --     local output_directory = paths.concat(
                    --         output_path,
                    --         network,
                    --         'batch_'..batch_index..'_input_'..input_index..'_along_'..max_index)
                    --     local output_filename = paths.concat(
                    --         output_directory,
                    --         'changing_'..i..'_amount_'..vis.simplestr(change)..'.png')
                    --     os.execute('mkdir -p '..output_directory)

                    --     local changed_embedding = base_embedding:clone()
                    --     changed_embedding[max_index] = changed_embedding[max_index] + change

                    --     local rendering = decoder:forward(changed_embedding:reshape(1, 200))[1]:clone()
                    --     image.save(output_filename, rendering:float())
                    --     i = i + 1
                    -- end

                    -- weight_norms[input_index] = weights:norm()

                    -- local image_row = {}
                    -- table.insert(image_row, input[1][input_index]:float())
                    -- table.insert(image_row, 
--     input[2][input_index]:float())
                    -- table.insert(image_row, reconstruction_from_previous[input_index]:float())
                    -- table.insert(image_row, reconstruction_from_current[input_index]:float())
                    -- table.insert(image_row, output[input_index]:float())
                    -- table.insert(images, image_row)
                    collectgarbage()
                end
            end
            -- print("Mean independence of weights: ", weight_norms:mean())
            -- vis.save_image_grid(paths.concat(output_path, network .. '_batch_'..batch_index..'.png'), images)
        end
    end
end

print("done")
--------------------------------------------------------------------------------
/render_generalization_action.lua:
--------------------------------------------------------------------------------
--- Renders latent "mutation" sweeps for action-dataset networks.
-- Usage: pass a job name as the first CLI argument and feed network names on
-- stdin, one per line. For every network the newest snapshot whose file name
-- contains "epoch" is loaded, a set of test batches is pushed through the
-- model, and for selected examples the three embedding components with the
-- largest predicted weight are perturbed over a sweep. Each decoded frame is
-- written as a PNG below renderings/mutation/<job name>/<network>/.
require 'nn'
require 'cutorch'
require 'cunn'
require 'paths'
require 'lfs'
require 'image' -- image.save is called below; load it explicitly instead of relying on another module

vis = require 'vis'
require 'ActionEncoder'
require 'ActionDecoder'
local data_loaders = require 'data_loaders'

-- NOTE(review): `name`, `networks`, `base_directory`, `vis` and `opt` are kept
-- as globals, as in the original; other project modules may read them
-- (confirm before localizing).
name = arg[1]
if name == nil then
    -- Fail early with a readable message; otherwise the jobname concatenation
    -- below dies with "attempt to concatenate a nil value".
    error("usage: render_generalization_action.lua <job_name>  (network names are read from stdin, one per line)", 0)
end
-- dataset_name = arg[2] or name

-- Collect network names from stdin, with all whitespace stripped.
networks = {}
for line in io.lines() do
    local stripped = line:gsub("%s+", "")
    -- Skip blank lines so they do not turn into empty network names.
    if stripped ~= "" then
        networks[#networks + 1] = stripped
    end
end

-- opt = {
--     datasetdir = '/om/user/wwhitney/deep-game-engine',
--     dataset_name = dataset_name,
--     gpu = true,
-- }

base_directory = "/home/mbchang/code/unsupervised-dcign/logslink"

--- Wrap a string in single quotes so the shell treats it as one literal word.
-- @param text value to quote (converted with tostring)
-- @return the quoted string
local function shell_quote(text)
    return "'" .. (tostring(text):gsub("'", "'\\''")) .. "'"
end

local jobname = name ..'_'.. os.date("%b_%d_%H_%M")
local output_path = 'renderings/mutation/'..jobname
os.execute('mkdir -p '..shell_quote(output_path))

--- Find the most recently modified snapshot of a network.
-- Lists `base_directory/network_name` newest-first and keeps the first entry
-- whose name contains "epoch" (case-insensitive).
-- @param network_name directory name of the network below base_directory
-- @return the trimmed file name, or false when nothing matched or the
--         listing could not be started
function getLastSnapshot(network_name)
    local network_dir = paths.concat(base_directory, network_name)
    local listing = io.popen("ls -t "..shell_quote(network_dir).." | grep -i epoch | head -n 1")
    if listing == nil then
        return false
    end
    local newest = listing:read()
    listing:close()
    if newest == nil then
        return false
    end
    return newest:match("^%s*(.-)%s*$")
end

for _, network in ipairs(networks) do
    collectgarbage()

    print('')
    print(network)
    local snapshot_name = getLastSnapshot(network)
    if snapshot_name then
        local checkpoint = torch.load(paths.concat(base_directory, network, snapshot_name))
        opt = checkpoint.opt -- global on purpose (see note at the top)
        local model = checkpoint.model
        local scheduler_iteration = torch.Tensor{checkpoint.step}
        model:evaluate()

        local encoder = model.modules[1]
        local sharpener = encoder:findModules('nn.ScheduledWeightSharpener')[1]
        -- Point the sharpener at the step stored in the checkpoint.
        sharpener.iteration_container = scheduler_iteration
        print("Current sharpening: ", sharpener:getP())

        local weight_predictor = encoder:findModules('nn.Normalize')[1]
        local previous_embedding = encoder:findModules('nn.Linear')[1]
        -- local current_embedding = encoder:findModules('nn.Linear')[2]
        local decoder = model.modules[2]

        for _, batch_index in ipairs{10,20,30,40,50,60,70,80,90,100} do
            print("Batch index: ", batch_index)
            -- local images = {}

            -- fetch a batch
            local input = data_loaders.load_action_batch(batch_index, 'test')
            local output = model:forward(input):clone()
            local embedding_from_previous = previous_embedding.output:clone()
            -- local embedding_from_current = current_embedding.output:clone()

            -- local reconstruction_from_previous = decoder:forward(embedding_from_previous):clone()
            -- local reconstruction_from_current = decoder:forward(embedding_from_current):clone()

            local weight_norms = torch.zeros(output:size(1))
            for input_index = 1, output:size(1) do
                local weights = weight_predictor.output[input_index]:clone()
                weight_norms[input_index] = weights:norm()
            end
            print("Mean independence of weights: ", weight_norms:mean())

            -- local max_indices = {}
            -- for input_index = 1, output:size(1) do
            --     local weights = weight_predictor.output[input_index]:clone()
            --     local _, idx = weights:max(1)
            --     max_indices[idx[1]] = true
            -- end

            for _, input_index in ipairs{1, 15} do -- example in the batch
                collectgarbage()
                print("Input index: ", input_index)
                local base_embedding = embedding_from_previous[input_index]:clone():float()

                local weights = weight_predictor.output[input_index]:clone():float()
                local max_indices = {}
                for nth_max = 1, 3 do -- take the largest 3 components
                    local _, idx = weights:max(1)
                    idx = idx[1]
                    print(nth_max..'th biggest:'..idx)
                    max_indices[idx] = true
                    -- zero the winner so the next pass finds the next-largest entry
                    weights[idx] = 0
                end

                for max_index, _ in pairs(max_indices) do
                    collectgarbage()
                    -- local weights = weight_predictor.output[input_index]:clone()
                    -- local max_weight, varying_index = weights:max(1)

                    local num_frames = 50 -- how many frames to predict?
                    local min_change = -1.0 -- low
                    local max_change = 1.0 -- high

                    local mutated_input = torch.Tensor(num_frames, base_embedding:size(1))

                    -- Compute each frame's offset once and reuse it for the file names.
                    -- NOTE(review): the divisor is num_frames, so the sweep stops one
                    -- step short of max_change; kept as in the original.
                    local changes = {}
                    for i = 1, num_frames do
                        changes[i] = min_change + (i-1) * (max_change-min_change)/num_frames
                        mutated_input[i] = base_embedding:clone()
                        mutated_input[i][max_index] = mutated_input[i][max_index] + changes[i]
                    end

                    local mutated_renders = decoder:forward(mutated_input:cuda()):clone()

                    local output_directory = paths.concat(
                        output_path,
                        network,
                        'batch_'..batch_index..'_input_'..input_index..'_along_'..max_index)
                    os.execute('mkdir -p '..shell_quote(output_directory))

                    for i = 1, num_frames do
                        local output_filename = paths.concat(
                            output_directory,
                            'changing_'..i..'_amount_'..vis.simplestr(changes[i])..'.png')
                        image.save(output_filename, mutated_renders[i])
                    end

                    -- for change = -3, 3, 0.1 do
                    --     local output_directory = paths.concat(
                    --         output_path,
                    --         network,
                    --         'batch_'..batch_index..'_input_'..input_index..'_along_'..max_index)
                    --     local output_filename = paths.concat(
                    --         output_directory,
                    --         'changing_'..i..'_amount_'..vis.simplestr(change)..'.png')
                    --     os.execute('mkdir -p '..output_directory)

                    --     local changed_embedding = base_embedding:clone()
                    --     changed_embedding[max_index] = changed_embedding[max_index] + change

                    --     local rendering = decoder:forward(changed_embedding:reshape(1, 200))[1]:clone()
                    --     image.save(output_filename, rendering:float())
                    --     i = i + 1
                    -- end

                    -- weight_norms[input_index] = weights:norm()

                    -- local image_row = {}
                    -- table.insert(image_row, input[1][input_index]:float())
                    -- table.insert(image_row, 
--     input[2][input_index]:float())
                    -- table.insert(image_row, reconstruction_from_previous[input_index]:float())
                    -- table.insert(image_row, reconstruction_from_current[input_index]:float())
                    -- table.insert(image_row, output[input_index]:float())
                    -- table.insert(images, image_row)
                    collectgarbage()
                end
            end
            -- print("Mean independence of weights: ", weight_norms:mean())
            -- vis.save_image_grid(paths.concat(output_path, network .. '_batch_'..batch_index..'.png'), images)
        end
    end
end

print("done")
--------------------------------------------------------------------------------
/balls_main.lua:
--------------------------------------------------------------------------------
--- Training entry point for the bouncing-balls encoder/decoder model.
-- Parses command-line options into the global `opt`, builds an
-- Encoder -> Decoder nn.Sequential, trains it with optim.rmsprop, and
-- periodically evaluates on the 'test' split and writes checkpoints plus logs
-- into <checkpoint_dir>/<name>/.
require 'nn'
require 'optim'

require 'MotionBCECriterion'

local Encoder = require 'BallsEncoder'
local Decoder = require 'Decoder'

local data_loaders = require 'data_loaders'

local cmd = torch.CmdLine()

cmd:option('--name', 'net', 'filename to autosave the checkpont to. Will be inside checkpoint_dir/')
cmd:option('--checkpoint_dir', 'logslink', 'output directory where checkpoints get written')
-- NOTE(review): single dash here, every other option uses two; the value is not read anywhere in this script.
cmd:option('-import', '', 'initialize network parameters from checkpoint at this path')

-- data
cmd:option('--datasetdir', '/om/data/public/mbchang/udcign-data/balls', 'dataset source directory') -- change
-- NOTE(review): help text below looks copy-pasted from --datasetdir.
cmd:option('--numballs', 1, 'dataset source directory')
cmd:option('--subsample', 3, 'subsample') -- hard code this into data_loader
cmd:option('--frame_interval', 1, 'the number of timesteps between input[1] and input[2]')

-- optimization
cmd:option('--learning_rate', 30e-5, 'learning rate')
cmd:option('--learning_rate_decay', 0.97, 'learning rate decay')
cmd:option('--learning_rate_decay_after', 18000, 'in number of examples, when to start decaying the learning rate')
cmd:option('--learning_rate_decay_interval', 4000, 'in number of examples, how often to decay the learning rate')
cmd:option('--decay_rate', 0.95, 'decay rate for rmsprop')
cmd:option('--grad_clip', 3, 'clip gradients at this value')

cmd:option('--L2', 0, 'amount of L2 regularization')
cmd:option('--criterion', 'BCE', 'criterion to use')
cmd:option('--batch_norm', false, 'use model with batch normalization')

cmd:option('--heads', 1, 'how many filtering heads to use')
cmd:option('--motion_scale', 3, 'how much to accentuate loss on changing pixels')

cmd:option('--dim_hidden', 200, 'dimension of the representation layer')
cmd:option('--feature_maps', 72, 'number of feature maps')
cmd:option('--color_channels', 1, '1 for grayscale, 3 for color')
-- NOTE(review): help text below looks copy-pasted from --feature_maps.
cmd:option('--sharpening_rate', 10, 'number of feature maps')
cmd:option('--noise', 0.1, 'variance of added Gaussian noise')

cmd:option('--max_epochs', 50, 'number of full passes through the training data')

-- bookkeeping
cmd:option('--seed', 123, 'torch manual random number generator seed')
cmd:option('--print_every', 10, 'how many steps/minibatches between printing out the loss')
cmd:option('--eval_val_every', 4500, 'every how many iterations should we evaluate on validation data?') -- CHANGE

-- data
cmd:option('--num_train_batches', 9000, 'number of batches to train with per epoch') -- CHANGE
cmd:option('--num_test_batches', 1000, 'number of batches to test with') -- CHANGE

-- GPU/CPU
-- NOTE(review): the default is a boolean and is only tested for truthiness below, despite the help text.
cmd:option('--gpu', true, 'which gpu to use. -1 = use CPU')
cmd:text()

-- parse input params
opt = cmd:parse(arg) -- global, as in the original
torch.manualSeed(opt.seed)

print(opt)
print(opt.gpu)

if opt.gpu then
    require 'cutorch'
    require 'cunn'
end

-- With the default name, derive a unique run name from the raw CLI arguments
-- and the current time.
if opt.name == 'net' then
    local parts = {'unsup'}
    for _, v in ipairs(arg) do
        parts[#parts + 1] = tostring(v)
    end
    parts[#parts + 1] = os.date("%b_%d_%H_%M_%S")
    opt.name = table.concat(parts, '_')
end

local savedir = string.format('%s/%s', opt.checkpoint_dir, opt.name)
print("Saving output to "..savedir)
os.execute('mkdir -p '..savedir)
-- WARNING: deletes everything already inside the run directory.
os.execute(string.format('rm %s/*', savedir))

-- log out the options used for creating this network to a file in the save directory.
-- super useful when you're moving folders around so you don't lose track of things.
local f = assert(io.open(savedir .. '/opt.txt', 'w'))
for key, val in pairs(opt) do
    f:write(tostring(key) .. ": " .. tostring(val) .. "\n")
end
f:flush()
f:close()

local logfile = assert(io.open(savedir .. '/output.log', 'w'))
true_print = print
--- Replacement for the global print: echoes every argument through the
-- original print (one call per argument) and appends the stringified
-- arguments, followed by a newline, to output.log.
print = function(...)
    -- select('#', ...) counts nil arguments too, so a nil no longer
    -- truncates the rest of the message as ipairs{...} did.
    for i = 1, select('#', ...) do
        local v = select(i, ...)
        true_print(v)
        logfile:write(tostring(v))
    end
    logfile:write("\n")
    logfile:flush()
end

-- Shared 1-element tensor holding the current training step; handed to the encoder.
local scheduler_iteration = torch.zeros(1)

local model = nn.Sequential()
model:add(Encoder(opt.dim_hidden, opt.color_channels, opt.feature_maps, opt.noise, opt.sharpening_rate, scheduler_iteration, opt.batch_norm, opt.heads))
model:add(Decoder(opt.dim_hidden, opt.color_channels, opt.feature_maps, opt.batch_norm))

-- graph.dot(model.modules[1].fg, 'encoder', 'reports/encoder')

if opt.criterion == 'MSE' then
    criterion = nn.MSECriterion()
elseif opt.criterion == 'BCE' then
    criterion = nn.BCECriterion()
    -- criterion = nn.MotionBCECriterion(opt.motion_scale)
else
    error("Invalid criterion specified!")
end

if opt.gpu then
    model:cuda()
    criterion:cuda()
end
params, grad_params = model:getParameters()

--- Mean criterion value over the first opt.num_test_batches 'test' batches.
-- Switches the model to evaluation mode.
-- @return average loss (number)
function validate()
    local total = 0
    model:evaluate()

    for i = 1, opt.num_test_batches do -- iterate over batches in the split
        -- fetch a batch
        local input = data_loaders.load_balls_batch(i, 'test')
        local prediction = model:forward(input)
        total = total + criterion:forward(prediction, input[2])
    end

    return total / opt.num_test_batches
end

--- Objective closure for optim.rmsprop: one forward/backward pass on a random
-- training batch.
-- @param x parameter tensor; must be the flattened `params` tensor itself
-- @return loss, grad_params
function feval(x)
    if x ~= params then
        -- The unreachable params:copy(x) that followed this error was removed.
        error("Params not equal to given feval argument.")
    end
    grad_params:zero()

    ------------------ get minibatch -------------------
    local input = data_loaders.load_random_balls_batch('train')

    ------------------- forward pass -------------------
    model:training() -- make sure we are in correct mode

    -- These three were accidental globals in the original.
    local prediction = model:forward(input)
    local loss = criterion:forward(prediction, input[2])
    local grad_output = criterion:backward(prediction, input[2]):clone()

    ------------------ backward pass -------------------
    model:backward(input, grad_output)

    ------------------ regularize -------------------
    if opt.L2 > 0 then
        -- Loss: the original read the undefined opt.coefL2 here, which
        -- raised an arithmetic-on-nil error for any --L2 > 0.
        loss = loss + opt.L2 * params:norm(2)^2/2
        -- Gradients:
        grad_params:add( params:clone():mul(opt.L2) )
    end

    grad_params:clamp(-opt.grad_clip, opt.grad_clip)

    collectgarbage()
    return loss, grad_params
end

train_losses = {}
val_losses = {}
local optim_state = {learningRate = opt.learning_rate, alpha = opt.decay_rate}
local iterations = opt.max_epochs * opt.num_train_batches
-- local iterations_per_epoch = opt.num_train_batches
local loss0 = nil

for step = 1, iterations do
    scheduler_iteration[1] = step
    epoch = step / opt.num_train_batches

    local timer = torch.Timer()

    local _, loss = optim.rmsprop(feval, params, optim_state)
    -- print(params:norm()) -- params are definitely getting updated

    local time = timer:time().real

    local train_loss = loss[1] -- the loss is inside a list, pop it
    train_losses[step] = train_loss

    -- exponential learning rate decay
    if step % opt.learning_rate_decay_interval == 0
            and opt.learning_rate_decay < 1
            and step >= opt.learning_rate_decay_after then
        local decay_factor = opt.learning_rate_decay
        optim_state.learningRate = optim_state.learningRate * decay_factor -- decay it
        print('decayed function learning rate by a factor ' .. decay_factor .. ' to ' .. optim_state.learningRate)
    end

    if step % opt.print_every == 0 then
        print(string.format("%d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.2fs", step, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time))
    end

    -- every now and then or on last iteration
    if step % opt.eval_val_every == 0 or step == iterations then
        -- evaluate loss on validation data
        local val_loss = validate() -- 2 = validation
        val_losses[step] = val_loss
        print(string.format('[epoch %.3f] Validation loss: %6.8f', epoch, val_loss))

        local model_file = string.format('%s/epoch%.2f_%.4f.t7', savedir, epoch, val_loss)
        print('saving checkpoint to ' .. model_file)
        local checkpoint = {
            model = model,
            opt = opt,
            train_losses = train_losses,
            val_loss = val_loss,
            val_losses = val_losses,
            step = step,
            epoch = epoch,
        }
        torch.save(model_file, checkpoint)

        local val_loss_log = assert(io.open(savedir ..'/val_loss.txt', 'a'))
        val_loss_log:write(val_loss .. "\n")
        val_loss_log:flush()
        val_loss_log:close()
    end

    if step % 10 == 0 then collectgarbage() end

    -- handle early stopping if things are going really bad
    if loss[1] ~= loss[1] then -- NaN is the only value not equal to itself
        print('loss is NaN. This usually indicates a bug. Please check the issues page for existing issues, or create a new issue, if none exist. 
Ideally, please state: your operating system, 32-bit/64-bit, your blas version, cpu/cuda/cl?')
        break -- halt
    end
    if loss0 == nil then
        loss0 = loss[1]
    end
    -- if loss[1] > loss0 * 8 then
    --     print('loss is exploding, aborting.')
    --     print("loss0:", loss0, "loss[1]:", loss[1])
    --     break -- halt
    -- end
end
--]]
--------------------------------------------------------------------------------
/action_main.lua:
--------------------------------------------------------------------------------
--- Training entry point for the action-dataset encoder/decoder model.
-- Parses command-line options into the global `opt`, overrides the batch
-- counts from a per-dataset table, builds an Encoder -> Decoder
-- nn.Sequential, trains it with optim.rmsprop, and evaluates on the 'test'
-- split and writes a checkpoint every opt.eval_val_every steps.
require 'nn'
require 'optim'

require 'MotionBCECriterion'

local Encoder = require 'ActionEncoder'
local Decoder = require 'ActionDecoder'

local data_loaders = require 'data_loaders'

local cmd = torch.CmdLine()

cmd:option('--name', 'net', 'filename to autosave the checkpont to. Will be inside checkpoint_dir/')
cmd:option('--checkpoint_dir', 'logslink', 'output directory where checkpoints get written')
-- NOTE(review): single dash here, every other option uses two; the value is not read anywhere in this script.
cmd:option('-import', '', 'initialize network parameters from checkpoint at this path')

-- data
cmd:option('--datasetdir', '/om/data/public/mbchang/udcign-data/action', 'dataset source directory') -- change
-- NOTE(review): help text below looks copy-pasted from --datasetdir.
cmd:option('--dataset_name', 'allactionsd4', 'dataset source directory')
cmd:option('--frame_interval', 1, 'the number of timesteps between input[1] and input[2]')

-- optimization
cmd:option('--learning_rate', 1e-4, 'learning rate')
cmd:option('--learning_rate_decay', 0.97, 'learning rate decay')
cmd:option('--learning_rate_decay_after', 7000, 'in number of examples, when to start decaying the learning rate')
cmd:option('--learning_rate_decay_interval', 1000, 'in number of examples, how often to decay the learning rate')
cmd:option('--decay_rate', 0.95, 'decay rate for rmsprop')
cmd:option('--grad_clip', 3, 'clip gradients at this value')

cmd:option('--L2', 0, 'amount of L2 regularization')
cmd:option('--criterion', 'BCE', 'criterion to use')
cmd:option('--batch_norm', false, 'use model with batch normalization')

cmd:option('--heads', 1, 'how many filtering heads to use')
cmd:option('--motion_scale', 3, 'how much to accentuate loss on changing pixels')

cmd:option('--dim_hidden', 200, 'dimension of the representation layer')
cmd:option('--feature_maps', 72, 'number of feature maps')
cmd:option('--color_channels', 1, '1 for grayscale, 3 for color')
-- NOTE(review): help text below looks copy-pasted from --feature_maps.
cmd:option('--sharpening_rate', 10, 'number of feature maps')
cmd:option('--noise', 0.1, 'variance of added Gaussian noise')

cmd:option('--max_epochs', 100, 'number of full passes through the training data')

-- bookkeeping
cmd:option('--seed', 123, 'torch manual random number generator seed')
cmd:option('--print_every', 10, 'how many steps/minibatches between printing out the loss')
cmd:option('--eval_val_every', 1395, 'every how many iterations should we evaluate on validation data?') -- CHANGE

-- data
cmd:option('--num_train_batches', 6527, 'number of batches to train with per epoch') -- CHANGE
cmd:option('--num_test_batches', 1395, 'number of batches to test with') -- CHANGE

-- GPU/CPU
-- NOTE(review): the default is a boolean and is only tested for truthiness below, despite the help text.
cmd:option('--gpu', true, 'which gpu to use. -1 = use CPU')
cmd:text()

-- parse input params
opt = cmd:parse(arg) -- global, as in the original
torch.manualSeed(opt.seed)

-- Batch counts per dataset. These override --num_train_batches,
-- --num_test_batches and --eval_val_every given on the command line.
local dsizes = {
    walking      = {num_train_batches=1347, num_test_batches=288},
    running      = {num_train_batches=860,  num_test_batches=183},
    jogging      = {num_train_batches=985,  num_test_batches=210},
    handclapping = {num_train_batches=957,  num_test_batches=205},
    handwaving   = {num_train_batches=1216, num_test_batches=260},
    boxing       = {num_train_batches=1015, num_test_batches=217},
    allactions   = {num_train_batches=6527, num_test_batches=1395},
    allactionsd4 = {num_train_batches=1468, num_test_batches=311},
}
local dataset_sizes = dsizes[opt.dataset_name]
if dataset_sizes == nil then
    -- The original indexed nil here and failed with an opaque message.
    error("Unknown --dataset_name: " .. tostring(opt.dataset_name))
end
opt.num_train_batches = dataset_sizes.num_train_batches
opt.num_test_batches = dataset_sizes.num_test_batches
opt.eval_val_every = opt.num_train_batches

print(opt)
print(opt.gpu)

if opt.gpu then
    require 'cutorch'
    require 'cunn'
end

-- With the default name, derive a unique run name from the raw CLI arguments
-- and the current time.
if opt.name == 'net' then
    local parts = {'unsup'}
    for _, v in ipairs(arg) do
        parts[#parts + 1] = tostring(v)
    end
    parts[#parts + 1] = os.date("%b_%d_%H_%M_%S")
    opt.name = table.concat(parts, '_')
end

local savedir = string.format('%s/%s', opt.checkpoint_dir, opt.name)
print("Saving output to "..savedir)
os.execute('mkdir -p '..savedir)
-- WARNING: deletes everything already inside the run directory.
os.execute(string.format('rm %s/*', savedir))

-- log out the options used for creating this network to a file in the save directory.
-- super useful when you're moving folders around so you don't lose track of things.
local f = assert(io.open(savedir .. '/opt.txt', 'w'))
for key, val in pairs(opt) do
    f:write(tostring(key) .. ": " .. tostring(val) .. "\n")
end
f:flush()
f:close()

local logfile = assert(io.open(savedir .. '/output.log', 'w'))
true_print = print
--- Replacement for the global print: echoes every argument through the
-- original print (one call per argument) and appends the stringified
-- arguments, followed by a newline, to output.log.
print = function(...)
    -- select('#', ...) counts nil arguments too, so a nil no longer
    -- truncates the rest of the message as ipairs{...} did.
    for i = 1, select('#', ...) do
        local v = select(i, ...)
        true_print(v)
        logfile:write(tostring(v))
    end
    logfile:write("\n")
    logfile:flush()
end

-- Shared 1-element tensor holding the current training step; handed to the encoder.
local scheduler_iteration = torch.zeros(1)

local model = nn.Sequential()
model:add(Encoder(opt.dim_hidden, opt.color_channels, opt.feature_maps, opt.noise, opt.sharpening_rate, scheduler_iteration, opt.batch_norm, opt.heads))
model:add(Decoder(opt.dim_hidden, opt.color_channels, opt.feature_maps, opt.batch_norm))

-- graph.dot(model.modules[1].fg, 'encoder', 'reports/encoder')

if opt.criterion == 'MSE' then
    criterion = nn.MSECriterion()
elseif opt.criterion == 'BCE' then
    --criterion = nn.BCECriterion()
    criterion = nn.MotionBCECriterion(opt.motion_scale)
else
    error("Invalid criterion specified!")
end

if opt.gpu then
    model:cuda()
    criterion:cuda()
end
params, grad_params = model:getParameters()

--- Mean criterion value over the first opt.num_test_batches 'test' batches.
-- Switches the model to evaluation mode.
-- @return average loss (number)
function validate()
    local total = 0
    model:evaluate()

    for i = 1, opt.num_test_batches do -- iterate over batches in the split
        -- fetch a batch
        local input = data_loaders.load_action_batch(i, 'test')
        local prediction = model:forward(input)
        total = total + criterion:forward(prediction, input[2])
    end

    return total / opt.num_test_batches
end

--- Objective closure for optim.rmsprop: one forward/backward pass on a random
-- training batch.
-- @param x parameter tensor; must be the flattened `params` tensor itself
-- @return loss, grad_params
function feval(x)
    if x ~= params then
        -- The unreachable params:copy(x) that followed this error was removed.
        error("Params not equal to given feval argument.")
    end
    grad_params:zero()

    ------------------ get minibatch -------------------
    local input = data_loaders.load_random_action_batch('train')

    ------------------- forward pass -------------------
    model:training() -- make sure we are in correct mode

    -- These three were accidental globals in the original.
    local prediction = model:forward(input)
    local loss = criterion:forward(prediction, input[2])
    local grad_output = criterion:backward(prediction, input[2]):clone()

    ------------------ backward pass -------------------
    model:backward(input, grad_output)

    ------------------ regularize -------------------
    if opt.L2 > 0 then
        -- Loss: the original read the undefined opt.coefL2 here, which
        -- raised an arithmetic-on-nil error for any --L2 > 0.
        loss = loss + opt.L2 * params:norm(2)^2/2
        -- Gradients:
        grad_params:add( params:clone():mul(opt.L2) )
    end

    grad_params:clamp(-opt.grad_clip, opt.grad_clip)

    collectgarbage()
    return loss, grad_params
end

train_losses = {}
val_losses = {}
local optim_state = {learningRate = opt.learning_rate, alpha = opt.decay_rate}
local iterations = opt.max_epochs * opt.num_train_batches
-- local iterations_per_epoch = opt.num_train_batches
local loss0 = nil

for step = 1, iterations do
    scheduler_iteration[1] = step
    epoch = step / opt.num_train_batches

    local timer = torch.Timer()

    local _, loss = optim.rmsprop(feval, params, optim_state)
    -- print(params:norm()) -- params are definitely getting updated

    local time = timer:time().real

    local train_loss = loss[1] -- the loss is inside a list, pop it
    train_losses[step] = train_loss

    -- exponential learning rate decay
    if step % opt.learning_rate_decay_interval == 0
            and opt.learning_rate_decay < 1
            and step >= opt.learning_rate_decay_after then
        local decay_factor = opt.learning_rate_decay
        optim_state.learningRate = optim_state.learningRate * decay_factor -- decay it
        print('decayed function learning rate by a factor ' .. decay_factor .. ' to ' .. optim_state.learningRate)
    end

    if step % opt.print_every == 0 then
        print(string.format("%d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.2fs", step, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time))
    end

    -- every now and then or on last iteration
    if step % opt.eval_val_every == 0 or step == iterations then
        -- evaluate loss on validation data
        local val_loss = validate() -- 2 = validation
        val_losses[step] = val_loss
        print(string.format('[epoch %.3f] Validation loss: %6.8f', epoch, val_loss))

        local model_file = string.format('%s/epoch%.2f_%.4f.t7', savedir, epoch, val_loss)
        print('saving checkpoint to ' .. model_file)
        local checkpoint = {
            model = model,
            opt = opt,
            train_losses = train_losses,
            val_loss = val_loss,
            val_losses = val_losses,
            step = step,
            epoch = epoch,
        }
        torch.save(model_file, checkpoint)

        local val_loss_log = assert(io.open(savedir ..'/val_loss.txt', 'a'))
        val_loss_log:write(val_loss .. "\n")
        val_loss_log:flush()
        val_loss_log:close()
    end

    if step % 10 == 0 then collectgarbage() end

    -- handle early stopping if things are going really bad
    if loss[1] ~= loss[1] then -- NaN is the only value not equal to itself
        print('loss is NaN. This usually indicates a bug. Please check the issues page for existing issues, or create a new issue, if none exist. 
Ideally, please state: your operating system, 32-bit/64-bit, your blas version, cpu/cuda/cl?') 263 | break -- halt 264 | end 265 | if loss0 == nil then 266 | loss0 = loss[1] 267 | end 268 | -- if loss[1] > loss0 * 8 then 269 | -- print('loss is exploding, aborting.') 270 | -- print("loss0:", loss0, "loss[1]:", loss[1]) 271 | -- break -- halt 272 | -- end 273 | end 274 | --]] 275 | -------------------------------------------------------------------------------- /downsampled_main.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'optim' 3 | 4 | require 'MotionBCECriterion' 5 | require 'Scale' 6 | 7 | Encoder = require 'DownsampledEncoder' 8 | Decoder = require 'DownsampledDecoder' 9 | 10 | Autoencoder = require 'DownsampledAutoencoder' 11 | 12 | data_loaders = require 'data_loaders' 13 | 14 | local cmd = torch.CmdLine() 15 | 16 | cmd:option('--name', 'net', 'filename to autosave the checkpont to. Will be inside checkpoint_dir/') 17 | cmd:option('--checkpoint_dir', 'networks', 'output directory where checkpoints get written') 18 | cmd:option('--import', '', 'initialize network parameters from checkpoint at this path') 19 | 20 | 21 | cmd:option('--model', 'disentangled', 'which model to use. 
disentangled or autoencoder') 22 | 23 | -- data 24 | cmd:option('--datasetdir', '/om/user/wwhitney/deep-game-engine', 'dataset source directory') 25 | cmd:option('--dataset_name', 'space_invaders', 'dataset source directory') 26 | cmd:option('--frame_interval', 1, 'the number of timesteps between input[1] and input[2]') 27 | 28 | -- optimization 29 | cmd:option('--learning_rate', 1e-4, 'learning rate') 30 | cmd:option('--learning_rate_decay', 0.97, 'learning rate decay') 31 | cmd:option('--learning_rate_decay_after', 18000, 'in number of examples, when to start decaying the learning rate') 32 | cmd:option('--learning_rate_decay_interval', 4000, 'in number of examples, how often to decay the learning rate') 33 | cmd:option('--decay_rate', 0.95, 'decay rate for rmsprop') 34 | cmd:option('--grad_clip', 3, 'clip gradients at this value') 35 | 36 | cmd:option('--criterion', 'BCE', 'criterion to use') 37 | 38 | cmd:option('--heads', 1, 'how many filtering heads to use') 39 | cmd:option('--motion_scale', 1, 'how much to accentuate loss on changing pixels') 40 | 41 | cmd:option('--dim_hidden', 200, 'dimension of the representation layer') 42 | cmd:option('--feature_maps', 72, 'number of feature maps') 43 | cmd:option('--color_channels', 1, '1 for grayscale, 3 for color') 44 | cmd:option('--sharpening_rate', 10, 'number of feature maps') 45 | cmd:option('--noise', 0.1, 'variance of added Gaussian noise') 46 | 47 | 48 | cmd:option('--max_epochs', 50, 'number of full passes through the training data') 49 | 50 | -- bookkeeping 51 | cmd:option('--seed', 123, 'torch manual random number generator seed') 52 | cmd:option('--print_every', 1, 'how many steps/minibatches between printing out the loss') 53 | cmd:option('--eval_val_every', 9000, 'every how many iterations should we evaluate on validation data?') 54 | 55 | -- data 56 | cmd:option('--num_train_batches', 8000, 'number of batches to train with per epoch') 57 | cmd:option('--num_test_batches', 900, 'number of batches to 
test with') 58 | 59 | -- GPU/CPU 60 | cmd:option('--gpu', false, 'which gpu to use. -1 = use CPU') 61 | cmd:text() 62 | 63 | 64 | -- parse input params 65 | opt = cmd:parse(arg) 66 | torch.manualSeed(opt.seed) 67 | 68 | if opt.gpu then 69 | require 'cutorch' 70 | require 'cunn' 71 | end 72 | 73 | if opt.name == 'net' then 74 | local name = 'downsample_' 75 | for _, v in ipairs(arg) do 76 | name = name .. tostring(v) .. '_' 77 | end 78 | opt.name = name .. os.date("%b_%d_%H_%M_%S") 79 | end 80 | 81 | local savedir = string.format('%s/%s', opt.checkpoint_dir, opt.name) 82 | print("Saving output to "..savedir) 83 | os.execute('mkdir -p '..savedir) 84 | os.execute(string.format('rm %s/*', savedir)) 85 | 86 | -- log out the options used for creating this network to a file in the save directory. 87 | -- super useful when you're moving folders around so you don't lose track of things. 88 | local f = io.open(savedir .. '/opt.txt', 'w') 89 | for key, val in pairs(opt) do 90 | f:write(tostring(key) .. ": " .. tostring(val) .. "\n") 91 | end 92 | f:flush() 93 | f:close() 94 | 95 | local logfile = io.open(savedir .. '/output.log', 'w') 96 | true_print = print 97 | print = function(...) 
98 | for _, v in ipairs{...} do 99 | true_print(v) 100 | logfile:write(tostring(v)) 101 | end 102 | logfile:write("\n") 103 | logfile:flush() 104 | end 105 | 106 | -- this is dumb, but it's the easiest and most portable way 107 | -- to make this a global variable 108 | opt.current_scheduler_iteration = 0 109 | 110 | -- local model 111 | if opt.model == 'disentangled' then 112 | model = nn.Sequential() 113 | local encoder = Encoder(opt.dim_hidden, opt.color_channels, opt.feature_maps, opt.noise, opt.sharpening_rate, scheduler_iteration, opt.heads) 114 | local decoder = Decoder(opt.dim_hidden, opt.color_channels, opt.feature_maps) 115 | model:add(encoder) 116 | model:add(decoder) 117 | elseif opt.model == 'autoencoder' then 118 | model = Autoencoder(opt.dim_hidden, opt.color_channels, opt.feature_maps) 119 | else 120 | error("Invalid model type: " ..opt.model) 121 | end 122 | 123 | print(model) 124 | 125 | scale = nn.Scale(84, 84, true) 126 | 127 | -- local encoder1 = encoder:findModules('nn.Sequential')[1] 128 | -- local encoder2 = encoder:findModules('nn.Sequential')[2] 129 | 130 | -- graph.dot(model.modules[1].fg, 'encoder', 'reports/encoder') 131 | 132 | if opt.criterion == 'MSE' then 133 | criterion = nn.MSECriterion() 134 | elseif opt.criterion == 'BCE' then 135 | criterion = nn.MotionBCECriterion(opt.motion_scale) 136 | else 137 | error("Invalid criterion specified!") 138 | end 139 | 140 | if opt.gpu then 141 | model:cuda() 142 | criterion:cuda() 143 | end 144 | params, grad_params = model:getParameters() 145 | 146 | 147 | local sharpener = model.modules[1]:findModules('nn.ScheduledWeightSharpener')[1] 148 | function validate() 149 | local loss = 0 150 | model:evaluate() 151 | 152 | for i = 1, opt.num_test_batches do -- iterate over batches in the split 153 | -- fetch a batch 154 | local input = data_loaders.load_atari_batch(i, 'test') 155 | input = { 156 | scale:forward(input[1]), 157 | scale:forward(input[2]), 158 | } 159 | local target = input[2] 160 | 161 
| if opt.model == 'autoencoder' then 162 | input = input[1] 163 | end 164 | output = model:forward(input) 165 | 166 | local step_loss = criterion:forward(output, target) 167 | loss = loss + step_loss 168 | end 169 | 170 | loss = loss / opt.num_test_batches 171 | return loss 172 | end 173 | 174 | -- do fwd/bwd and return loss, grad_params 175 | function feval(x) 176 | if x ~= params then 177 | error("Params not equal to given feval argument.") 178 | params:copy(x) 179 | end 180 | grad_params:zero() 181 | 182 | ------------------ get minibatch ------------------- 183 | local input = data_loaders.load_random_atari_batch('train') 184 | input = { 185 | scale:forward(input[1]), 186 | scale:forward(input[2]), 187 | } 188 | local target = input[2] 189 | 190 | ------------------- forward pass ------------------- 191 | model:training() -- make sure we are in correct mode 192 | 193 | 194 | local loss 195 | if opt.model == 'autoencoder' then 196 | input = input[1] 197 | end 198 | -- print(input:size()) 199 | local output = model:forward(input) 200 | -- print(output:size()) 201 | loss = criterion:forward(output, target) 202 | local grad_output = criterion:backward(output, target):clone() 203 | 204 | model:backward(input, grad_output) 205 | 206 | grad_params:clamp(-opt.grad_clip, opt.grad_clip) 207 | 208 | collectgarbage() 209 | return loss, grad_params 210 | end 211 | 212 | 213 | train_losses = {} 214 | val_losses = {} 215 | local optim_state = {learningRate = opt.learning_rate, alpha = opt.decay_rate} 216 | local iterations = opt.max_epochs * opt.num_train_batches 217 | -- local iterations_per_epoch = opt.num_train_batches 218 | local loss0 = nil 219 | 220 | for step = 1, iterations do 221 | opt.current_scheduler_iteration = step 222 | epoch = step / opt.num_train_batches 223 | 224 | local timer = torch.Timer() 225 | 226 | local _, loss = optim.rmsprop(feval, params, optim_state) 227 | 228 | local time = timer:time().real 229 | 230 | local train_loss = loss[1] -- the loss is 
inside a list, pop it 231 | train_losses[step] = train_loss 232 | 233 | -- exponential learning rate decay 234 | if step % opt.learning_rate_decay_interval == 0 and opt.learning_rate_decay < 1 then 235 | if step >= opt.learning_rate_decay_after then 236 | local decay_factor = opt.learning_rate_decay 237 | optim_state.learningRate = optim_state.learningRate * decay_factor -- decay it 238 | print('decayed function learning rate by a factor ' .. decay_factor .. ' to ' .. optim_state.learningRate) 239 | end 240 | end 241 | 242 | if step % opt.print_every == 0 then 243 | print(string.format("%d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.2fs", step, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time)) 244 | end 245 | 246 | -- every now and then or on last iteration 247 | if step % opt.eval_val_every == 0 or step == iterations then 248 | print(string.format("Weight sharpener exponent at epoch %.3f: %.12f", epoch, sharpener:getP())) 249 | 250 | -- evaluate loss on validation data 251 | local val_loss = validate() -- 2 = validation 252 | val_losses[step] = val_loss 253 | print(string.format('[epoch %.3f] Validation loss: %6.8f', epoch, val_loss)) 254 | 255 | local model_file = string.format('%s/epoch%.2f_%.4f.t7', savedir, epoch, val_loss) 256 | print('saving checkpoint to ' .. model_file) 257 | local checkpoint = {} 258 | checkpoint.model = model 259 | checkpoint.opt = opt 260 | checkpoint.train_losses = train_losses 261 | checkpoint.val_loss = val_loss 262 | checkpoint.val_losses = val_losses 263 | checkpoint.step = step 264 | checkpoint.epoch = epoch 265 | torch.save(model_file, checkpoint) 266 | 267 | local val_loss_log = io.open(savedir ..'/val_loss.txt', 'a') 268 | val_loss_log:write(val_loss .. 
"\n") 269 | val_loss_log:flush() 270 | val_loss_log:close() 271 | end 272 | 273 | if step % 10 == 0 then collectgarbage() end 274 | 275 | -- handle early stopping if things are going really bad 276 | if loss[1] ~= loss[1] then 277 | print('loss is NaN. This usually indicates a bug. Please check the issues page for existing issues, or create a new issue, if none exist. Ideally, please state: your operating system, 32-bit/64-bit, your blas version, cpu/cuda/cl?') 278 | break -- halt 279 | end 280 | if loss0 == nil then 281 | loss0 = loss[1] 282 | end 283 | -- if loss[1] > loss0 * 8 then 284 | -- print('loss is exploding, aborting.') 285 | -- print("loss0:", loss0, "loss[1]:", loss[1]) 286 | -- break -- halt 287 | -- end 288 | end 289 | --]] 290 | --------------------------------------------------------------------------------