├── README.md
├── data
│   ├── data.lua
│   ├── donkey_video2.lua
│   └── donkey_video3.lua
├── extra
│   └── stabilize_videos_many.py
├── generate.lua
├── main.lua
├── main_conditional.lua
└── main_ucf.lua

/README.md:
--------------------------------------------------------------------------------
Generating Videos with Scene Dynamics
=====================================

This repository contains an implementation of [Generating Videos with Scene Dynamics](http://carlvondrick.com/tinyvideo/) by Carl Vondrick, Hamed Pirsiavash, and Antonio Torralba, published at NIPS 2016. The model learns to generate tiny videos using adversarial networks.

Example Generations
-------------------
Below are some selected videos that are generated by our model. These videos are not real; they are hallucinated by a generative video model. While they are not photo-realistic, the motions are fairly reasonable for the scene category they are trained on.
Beach

*(animated example generations omitted; see the project website above for samples)*

Golf

*(examples omitted)*

Train Station

*(examples omitted)*

Baby

*(examples omitted)*
Training
--------

The code requires a Torch7 installation.

To train a generator for video, see main.lua. This file will construct the networks, start many threads to load data, and train the networks.

For the conditional version, see main_conditional.lua. This is similar to main.lua, except the input to the model is a static image.

To generate videos, see generate.lua. This file will also output intermediate layers, such as the mask and background image, which you can inspect manually.
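All of these scripts read their options from environment variables, so a training run on your own data might look like this (the paths are placeholders):

    data_root=/path/to/frames-stable-many data_list=/path/to/train_list.txt name=myexp th main.lua

Any field in the `opt` table at the top of each script, such as `lr` or `batchSize`, can be overridden the same way.
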
Data
----
The data loading is designed assuming videos have been stabilized and flattened into JPEG images. We do this for efficiency. Stabilization is computationally slow and must be done offline, and reading one file per video is more efficient on NFS.

For our stabilization code, see the 'extra' directory. Essentially, this will convert each video into an image of vertically concatenated frames. After doing this, you create a text file listing all the frames, which you pass into the data loader.
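For reference, the loaders read one whitespace-separated record per line: donkey_video2.lua uses only the first field, a path relative to `data_root`, while donkey_video3.lua additionally reads an integer class label from the second field. An illustrative list (file names made up):

    frames-stable-many/vid001/0000.jpg
    frames-stable-many/vid001/0001.jpg
    frames-stable-many/vid002/0000.jpg
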
Models
------
You can download our pre-trained models [here](https://drive.google.com/file/d/0B-xMJ5CYz_F9QS1BTE5yWl9aUWs/view?usp=sharing) (1 GB ZIP file).

Notes
-----
The code is based on [DCGAN](https://github.com/soumith/dcgan.torch) and our [starter code](https://github.com/cvondrick/torch-starter) in [Torch7](https://github.com/torch/torch7).

If you find this useful for your research, please consider citing our NIPS paper.

License
-------
MIT

--------------------------------------------------------------------------------
/data/data.lua:
--------------------------------------------------------------------------------
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

local data = {}

local result = {}
local unpack = unpack or table.unpack

function data.new(n, dataset_name, opt_)
  opt_ = opt_ or {}
  local self = {}
  for k,v in pairs(data) do
    self[k] = v
  end

  self.randomize = opt_.randomize

  local donkey_file
  if dataset_name == 'simple' then
    donkey_file = 'donkey_simple.lua'
  elseif dataset_name == 'video2' then
    donkey_file = 'donkey_video2.lua'
  elseif dataset_name == 'video3' then
    donkey_file = 'donkey_video3.lua'
  else
    error('Unknown dataset: ' .. dataset_name)
  end

  if n > 0 then
    local options = opt_
    self.threads = Threads(n,
      function() require 'torch' end,
      function(idx)
        opt = options
        tid = idx
        local seed = (opt.manualSeed and opt.manualSeed or 0) + idx
        torch.manualSeed(seed)
        torch.setnumthreads(1)
        print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
        assert(options, 'options not found')
        assert(opt, 'opt not given')
        paths.dofile(donkey_file)
      end
    )
  else
    if donkey_file then paths.dofile(donkey_file) end
    self.threads = {}
    function self.threads:addjob(f1, f2) f2(f1()) end
    function self.threads:dojob() end
    function self.threads:synchronize() end
  end

  local nSamples = 0
  self.threads:addjob(function() return trainLoader:size() end,
                      function(c) nSamples = c end)
  self.threads:synchronize()
  self._size = nSamples

  self.jobCount = 0
  for i = 1, n do
    self:queueJob()
  end

  return self
end
function data:queueJob()
  self.jobCount = self.jobCount + 1

  if self.randomize > 0 then
    self.threads:addjob(function()
                          return trainLoader:sample(opt.batchSize)
                        end,
                        self._pushResult)
  else
    local indexStart = (self.jobCount-1) * opt.batchSize + 1
    local indexEnd = (indexStart + opt.batchSize - 1)
    if indexEnd <= self:size() then
      self.threads:addjob(function()
                            return trainLoader:get(indexStart, indexEnd)
                          end,
                          self._pushResult)
    end
  end
end

-- callback run on the main thread; stashes the donkey's return values
function data._pushResult(...)
  local res = {...}
  result[1] = res
end

function data:getBatch()
  -- queue another job
  local res
  repeat
    self:queueJob()
    self.threads:dojob()
    res = result[1]
    result[1] = nil
  until torch.type(res) == 'table'
  return unpack(res)
end

function data:size()
  return self._size
end
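-- Usage sketch, mirroring how the training scripts consume this module:
--   local DataLoader = paths.dofile('data/data.lua')
--   local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
--   local batch = data:getBatch() -- batchSize x 3 x frameSize x fineSize x fineSize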
return data

--------------------------------------------------------------------------------
/data/donkey_video2.lua:
--------------------------------------------------------------------------------
--[[
Copyright (c) 2015-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

-- Heavily modified by Carl to make it simpler

require 'torch'
require 'image'
tds = require 'tds'
torch.setdefaulttensortype('torch.FloatTensor')
local class = require('pl.class')

local dataset = torch.class('dataLoader')

-- this function reads in the data files
function dataset:__init(args)
  for k,v in pairs(args) do self[k] = v end

  assert(self.frameSize > 0)

  if self.filenamePad == nil then
    self.filenamePad = 8
  end

  -- read text file consisting of frame directories and counts of frames
  self.data = tds.Vec()
  print('reading ' .. args.data_list)
  for line in io.lines(args.data_list) do
    local split = {}
    for k in string.gmatch(line, "%S+") do table.insert(split, k) end
    self.data:insert(split[1])
  end

  print('found ' .. #self.data .. ' videos')
end

function dataset:size()
  return #self.data
end

-- converts a table of samples (and corresponding labels) to a clean tensor
function dataset:tableToOutput(dataTable, extraTable)
  local data, scalarLabels, labels
  local quantity = #dataTable
  assert(dataTable[1]:dim() == 4)
  data = torch.Tensor(quantity, 3, self.frameSize, self.fineSize, self.fineSize)
  for i=1,#dataTable do
    data[i]:copy(dataTable[i])
  end
  return data, extraTable
end

-- sampler, samples with replacement from the training set
function dataset:sample(quantity)
  assert(quantity)
  local dataTable = {}
  local extraTable = {}
  for i=1,quantity do
    local idx = torch.random(1, #self.data)
    local data_path = self.data_root .. '/' .. self.data[idx]

    local out = self:trainHook(data_path)
    table.insert(dataTable, out)
    table.insert(extraTable, self.data[idx])
  end
  return self:tableToOutput(dataTable, extraTable)
end

-- gets data in a certain range
function dataset:get(start_idx, stop_idx)
  local dataTable = {}
  local extraTable = {}
  for idx=start_idx,stop_idx do
    local data_path = self.data_root .. '/' .. self.data[idx]

    local out = self:trainHook(data_path)
    table.insert(dataTable, out)
    table.insert(extraTable, self.data[idx])
  end
  return self:tableToOutput(dataTable, extraTable)
end
-- function to load the image, jitter it appropriately (random crops etc.)
function dataset:trainHook(path)
  collectgarbage()

  local oW = self.fineSize
  local oH = self.fineSize

  local out = torch.zeros(3, self.frameSize, oW, oH)

  local ok,input = pcall(image.load, path, 3, 'float')
  if not ok then
    print('warning: failed loading: ' .. path)
    return out
  end

  local count = input:size(2) / opt.loadSize
  local t1 = 1

  for fr=1,self.frameSize do
    local off
    if fr <= count then
      off = (fr+t1-2) * opt.loadSize+1
    else
      off = (count+t1-2)*opt.loadSize+1 -- repeat the last frame
    end
    local crop = input[{ {}, {off, off+opt.loadSize-1}, {} }]
    out[{ {}, fr, {}, {} }]:copy(image.scale(crop, opt.fineSize, opt.fineSize))
  end

  out:mul(2):add(-1) -- rescale [0, 1] -> [-1, 1]

  -- subtract mean
  for c=1,3 do
    out[{ c, {}, {} }]:add(-self.mean[c])
  end

  return out
end

-- data.lua expects a variable called trainLoader
trainLoader = dataLoader(opt)
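-- Layout note: each JPEG stores a clip's frames stacked vertically, so frame
-- fr begins at row (fr-1)*loadSize + 1; e.g. with loadSize = 128, a 32-frame
-- clip is a 4096-row image, and trainHook slices one 128-row band per frame
-- and rescales it to fineSize x fineSize.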
--------------------------------------------------------------------------------
/data/donkey_video3.lua:
--------------------------------------------------------------------------------
--[[
Copyright (c) 2015-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--

-- Heavily modified by Carl to make it simpler

require 'torch'
require 'image'
tds = require 'tds'
torch.setdefaulttensortype('torch.FloatTensor')
local class = require('pl.class')

local dataset = torch.class('dataLoader')

-- this function reads in the data files
function dataset:__init(args)
  for k,v in pairs(args) do self[k] = v end

  assert(self.frameSize > 0)

  if self.filenamePad == nil then
    self.filenamePad = 8
  end

  -- read text file consisting of frame directories and counts of frames
  self.data = tds.Vec()
  self.category = tds.Vec()
  print('reading ' .. args.data_list)
  for line in io.lines(args.data_list) do
    local split = {}
    for k in string.gmatch(line, "%S+") do table.insert(split, k) end
    self.data:insert(split[1])
    self.category:insert(split[2])
  end

  print('found ' .. #self.data .. ' videos')
end

function dataset:size()
  return #self.data
end

-- converts a table of samples (and corresponding labels) to a clean tensor
function dataset:tableToOutput(dataTable, scalarTable, extraTable)
  local data, scalarLabels, labels
  local quantity = #dataTable
  assert(dataTable[1]:dim() == 4)
  data = torch.Tensor(quantity, 3, self.frameSize, self.fineSize, self.fineSize)
  label = torch.Tensor(quantity)
  for i=1,#dataTable do
    data[i]:copy(dataTable[i])
    label[i] = scalarTable[i]
  end
  return data, label, extraTable
end

-- sampler, samples with replacement from the training set
function dataset:sample(quantity)
  assert(quantity)
  local dataTable = {}
  local scalarTable = {}
  local extraTable = {}
  for i=1,quantity do
    local idx = torch.random(1, #self.data)
    local data_path = self.data_root .. '/' .. self.data[idx]

    local out = self:trainHook(data_path)
    table.insert(dataTable, out)
    table.insert(extraTable, self.data[idx])
    table.insert(scalarTable, self.category[idx])
  end
  return self:tableToOutput(dataTable, scalarTable, extraTable)
end

-- gets data in a certain range (not supported by this loader)
function dataset:get(start_idx, stop_idx)
  assert(false)
end
-- function to load the image, jitter it appropriately (random crops etc.)
function dataset:trainHook(path)
  collectgarbage()

  local oW = self.fineSize
  local oH = self.fineSize

  local out = torch.zeros(3, self.frameSize, oW, oH)

  local ok,input = pcall(image.load, path, 3, 'float')
  if not ok then
    print('warning: failed loading: ' .. path)
    return out
  end

  local count = input:size(2) / opt.loadSize
  local t1 = 1

  for fr=1,self.frameSize do
    local off
    if fr <= count then
      off = (fr+t1-2) * opt.loadSize+1
    else
      off = (count+t1-2)*opt.loadSize+1 -- repeat the last frame
    end

    local crop
    if off+opt.loadSize-1 <= input:size(2) and off > 0 then
      crop = input[{ {}, {off, off+opt.loadSize-1}, {} }]
    else
      print('*** WARNING ***')
      print(' bad size')
      print(' path: ' .. path)
      crop = torch.zeros(3, opt.fineSize, opt.fineSize)
    end
    out[{ {}, fr, {}, {} }]:copy(image.scale(crop, opt.fineSize, opt.fineSize))
  end

  out:mul(2):add(-1) -- rescale [0, 1] -> [-1, 1]

  -- subtract mean
  for c=1,3 do
    out[{ c, {}, {} }]:add(-self.mean[c])
  end

  return out
end

-- data.lua expects a variable called trainLoader
trainLoader = dataLoader(opt)
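-- List format note: this loader expects each data_list line to contain two
-- whitespace-separated fields, a frame path relative to data_root and an
-- integer class label (e.g. "SomeAction/clip01/0000.jpg 5", a hypothetical
-- line), which main_ucf.lua consumes for 101-way classification.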
--------------------------------------------------------------------------------
/extra/stabilize_videos_many.py:
--------------------------------------------------------------------------------
# hackery for opencv to use sift
import sys
sys.path.insert(0, "/data/vision/torralba/commonsense/future/opencv-2.4.11/install/lib/python2.7/dist-packages")

import numpy as np
import cv2
import json
import os
import argparse
import subprocess
import random
from scipy.ndimage.filters import gaussian_filter

MIN_MATCH_COUNT = 10
VIDEO_SIZE = 128
CROP_SIZE = 128
MAX_FRAMES = 33
MIN_FRAMES = 16
FRAMES_DELAY = 2

def get_video_info(video):
    stats = subprocess.check_output("ffprobe -select_streams v -v error -show_entries stream=width,height,duration -of default=noprint_wrappers=1 {}".format(video), shell = True)
    info = dict(x.split("=") for x in stats.strip().split("\n"))
    print info
    return {"width": int(info['width']),
            "height": int(info['height']),
            "duration": float(info['duration'])}

class FrameReader(object):
    def __init__(self, video):
        self.info = get_video_info(video)

        command = [ "ffmpeg",
                    '-i', video,
                    '-f', 'image2pipe',
                    '-pix_fmt', 'rgb24',
                    '-vcodec', 'rawvideo',
                    '-']
        self.pipe = subprocess.Popen(command, stdout = subprocess.PIPE, bufsize=10**8)

    def __iter__(self):
        return self

    def next(self):
        raw_image = self.pipe.stdout.read(self.info['width']*self.info['height']*3)
        # transform the bytes read into a numpy array
        image = np.fromstring(raw_image, dtype='uint8')
        try:
            image = image.reshape((self.info['height'],self.info['width'],3))
        except:
            raise StopIteration()
        # throw away the data in the pipe's buffer
        self.pipe.stdout.flush()

        image = image[:, :, (2,1,0)]  # RGB -> BGR for OpenCV

        return image

    def close(self):
        self.pipe.stdout.close()
        self.pipe.kill()
def process_im(im):
    h = im.shape[0]
    w = im.shape[1]

    if w > h:
        scale = float(VIDEO_SIZE) / h
    else:
        scale = float(VIDEO_SIZE) / w

    new_h = int(h * scale)
    new_w = int(w * scale)

    im = cv2.resize(im, (new_w, new_h))

    h = im.shape[0]
    w = im.shape[1]

    h_start = h / 2 - CROP_SIZE / 2
    h_stop = h_start + CROP_SIZE

    w_start = w / 2 - CROP_SIZE / 2
    w_stop = w_start + CROP_SIZE

    im = im[h_start:h_stop, w_start:w_stop, :]

    return im

def compute(video, frame_dir):
    try:
        frames = FrameReader(video)
    except subprocess.CalledProcessError:
        print "failed due to CalledProcessError"
        return False

    for _ in range(FRAMES_DELAY):
        try:
            frames.next()
        except StopIteration:
            return False

    # Initiate SIFT detector
    sift = cv2.SIFT()
    #sift = cv2.ORB()
    #sift = cv2.BRISK()

    FLANN_INDEX_KDTREE = 0
    index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5)
    search_params = dict(checks = 50)
    flann = cv2.FlannBasedMatcher(index_params, search_params)

    movie_clip = 0
    movie_clip_files = []
    for _ in range(100):
        try:
            img2 = frames.next()
        except StopIteration:
            print "end of stream"
            break

        bg_img = process_im(img2.copy())
        kp2, des2 = sift.detectAndCompute(img2,None)

        Ms = []

        movie = [bg_img.copy()]

        failed = False

        bigM = np.eye(3)

        for fr, img1 in enumerate(frames):
            #img1 = cv2.imread(im1,0)
            kp1, des1 = sift.detectAndCompute(img1,None)
            if des1 is None or des2 is None:
                print "Empty matches"
                M = np.eye(3)
                failed = True
            elif len(kp1) < 2 or len(kp2) < 2:
                print "Not enough key points"
                M = np.eye(3)
                failed = True
            else:
                matches = flann.knnMatch(des1.astype("float32"),des2.astype("float32"),k=2)
                # store all the good matches as per Lowe's ratio test
                good = []
                for m,n in matches:
                    if m.distance < 0.7*n.distance:
                        good.append(m)

                if len(good)>=MIN_MATCH_COUNT:
                    src_pts = np.float32([ kp1[m.queryIdx].pt for m in good ]).reshape(-1,1,2)
                    dst_pts = np.float32([ kp2[m.trainIdx].pt for m in good ]).reshape(-1,1,2)

                    M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
                else:
                    print "Not enough matches are found - %d/%d" % (len(good),MIN_MATCH_COUNT)

                    M = np.eye(3)
                    failed = True

            Ms.append(M)
            bigM = np.dot(bigM, M)

            mask = (np.ones((img2.shape[0], img2.shape[1], 3)) * 255).astype('uint8')
            mask = cv2.warpPerspective(mask, bigM, (img2.shape[1], img2.shape[0]))
            mask = cv2.erode(mask / 255, np.ones((5,5),np.uint8), iterations=1) * 255
            mask = process_im(mask).astype("float") / 255.

            if (mask > 0).any():
                save_im = cv2.warpPerspective(img1, bigM, (img2.shape[1], img2.shape[0]))

                save_im = bg_img * (1-mask) + process_im(save_im) * mask
                movie.append(save_im.copy())
                #cv2.imwrite(frame_dir + ("/%08d.jpg"%(frame_counter)), save_im)

                bg_img = save_im.copy()

            else: # homography has gone out of frame, so just abort; comment these lines to keep trying
                break

            if len(movie) > MAX_FRAMES:
                break

            img2 = img1
            kp2 = kp1
            des2 = des1

            if failed:
                break

        if len(movie) < MIN_FRAMES:
            print "this movie clip is too short, causing fail"
            failed = True

        if failed:
            print "aborting movie clip due to failure"
        else:
            # write a column-stacked image so it can be loaded at once, which
            # will hopefully reduce IO significantly
            stacked = np.vstack(movie)
            movie_clip_filename = frame_dir + "/%04d.jpg" % movie_clip
            movie_clip_files.append(movie_clip_filename)
            print "writing {}".format(movie_clip_filename)
            cv2.imwrite(movie_clip_filename, stacked)
            movie_clip += 1

    frames.close()

    open(frame_dir + "/list.txt", "w").write("\n".join(movie_clip_files))
def get_stable_path(video):
    #return "frames-stable/{}".format(video)
    return "frames-stable-many/{}".format(video)

work = [x.strip() for x in open("scene_extract/job_list.txt")]
random.shuffle(work)

for video in work:
    stable_path = get_stable_path(video)
    lock_file = stable_path + ".lock"

    if os.path.exists(stable_path) or os.path.exists(lock_file):
        print "already done: {}".format(stable_path)
        continue

    try:
        os.makedirs(os.path.dirname(stable_path))
    except OSError:
        pass
    try:
        os.makedirs(stable_path)
    except OSError:
        pass
    try:
        os.mkdir(lock_file)
    except OSError:
        pass

    print video

    #result = compute("videos/" + video, stable_path)
    result = compute(video, stable_path)

    try:
        os.rmdir(lock_file)
    except:
        pass
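# Pipeline note: this script reads video paths from scene_extract/job_list.txt,
# aligns consecutive frames with SIFT matches + RANSAC homographies, composites
# each warped frame over the running background, and writes every surviving
# clip as one vertically stacked JPEG (frames-stable-many/<video>/0000.jpg,
# 0001.jpg, ...) plus a list.txt index; the .lock directories let multiple
# workers share the job list.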
--------------------------------------------------------------------------------
/generate.lua:
--------------------------------------------------------------------------------
require 'torch'
require 'nn'
require 'image'
require 'cunn'
require 'cudnn'

opt = {
  model = 'models/golf/iter65000_net.t7',
  batchSize = 128,
  gpu = 1,
  cudnn = 1,
}

-- one-line argument parser. parses environment variables to override the defaults
for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
print(opt)

torch.manualSeed(0)
torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')

-- if using GPU, select indicated one
cutorch.setDevice(opt.gpu)

net = torch.load(opt.model)
net:evaluate()
net:cuda()
net = cudnn.convert(net, cudnn)

print('Generator:')
print(net)

local noise = torch.Tensor(opt.batchSize, 100):normal():cuda()

local gen = net:forward(noise)
local video = net.modules[2].output[1]:float()
local mask = net.modules[2].output[2]:float()
local static = net.modules[2].output[3]:float()
mask = mask:repeatTensor(1,3,1,1,1)

function WriteGif(filename, movie)
  for fr=1,movie:size(3) do
    image.save(filename .. '.' .. string.format('%08d', fr) .. '.png', image.toDisplayTensor(movie:select(3,fr)))
  end
  cmd = "ffmpeg -f image2 -i " .. filename .. ".%08d.png -y " .. filename
  print('==> ' .. cmd)
  sys.execute(cmd)
  for fr=1,movie:size(3) do
    os.remove(filename .. '.' .. string.format('%08d', fr) .. '.png')
  end
end

paths.mkdir('vis/')
WriteGif('vis/gen.gif', gen)
WriteGif('vis/video.gif', video)
WriteGif('vis/videomask.gif', torch.cmul(video, mask))
WriteGif('vis/mask.gif', mask)
image.save('vis/static.jpg', image.toDisplayTensor(static))
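-- Usage sketch: override `model` with any generator checkpoint produced by
-- main.lua, e.g.
--   model=checkpoints/beach100/iter10000_net.t7 th generate.lua
-- (an illustrative path); the sampled GIFs and the intermediate mask,
-- foreground, and background visualizations are written to vis/.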
--------------------------------------------------------------------------------
/main.lua:
--------------------------------------------------------------------------------
require 'torch'
require 'nn'
require 'optim'

-- to specify these at runtime, you can do, e.g.:
--   $ lr=0.001 th main.lua
opt = {
  dataset = 'video2',   -- indicates which dataset loader to use (in data.lua)
  nThreads = 32,        -- how many threads to pre-fetch data
  batchSize = 64,       -- self-explanatory
  loadSize = 128,       -- when loading images, resize first to this size
  fineSize = 64,        -- crop this size from the loaded image
  frameSize = 32,
  lr = 0.0002,          -- learning rate
  lr_decay = 1000,      -- how often to decay learning rate (in epochs)
  lambda = 0.1,
  beta1 = 0.5,          -- momentum term for adam
  meanIter = 0,         -- how many iterations to retrieve for mean estimation
  saveIter = 1000,      -- write checkpoint on this interval
  niter = 100,          -- number of iterations through dataset
  ntrain = math.huge,   -- how big one epoch should be
  gpu = 1,              -- which GPU to use; consider using CUDA_VISIBLE_DEVICES instead
  cudnn = 1,            -- whether to use cudnn or not
  finetune = '',        -- if set, will load this network instead of starting from scratch
  name = 'beach100',    -- the name of the experiment
  randomize = 1,        -- whether to shuffle the data file or not
  cropping = 'random',  -- options for data augmentation
  display_port = 8001,  -- port to push graphs
  display_id = 1,       -- window ID when pushing graphs
  mean = {0,0,0},
  data_root = '/data/vision/torralba/crossmodal/flickr_videos/',
  data_list = '/data/vision/torralba/crossmodal/flickr_videos/scene_extract/lists-full/_b_beach.txt.train',
}

-- one-line argument parser. parses environment variables to override the defaults
for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
print(opt)

torch.manualSeed(0)
torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')

-- if using GPU, select indicated one
if opt.gpu > 0 then
  require 'cunn'
  cutorch.setDevice(opt.gpu)
end

-- create data loader
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
print("Dataset: " .. opt.dataset, " Size: ", data:size())

-- define the model
local net
local netD
local mask_net
local motion_net
local static_net
local penalty_net
if opt.finetune == '' then -- build network from scratch
  net = nn.Sequential()

  static_net = nn.Sequential()
  static_net:add(nn.View(-1, 100, 1, 1))
  static_net:add(nn.SpatialFullConvolution(100, 512, 4,4))
  static_net:add(nn.SpatialBatchNormalization(512)):add(nn.ReLU(true))
  static_net:add(nn.SpatialFullConvolution(512, 256, 4,4, 2,2, 1,1))
  static_net:add(nn.SpatialBatchNormalization(256)):add(nn.ReLU(true))
  static_net:add(nn.SpatialFullConvolution(256, 128, 4,4, 2,2, 1,1))
  static_net:add(nn.SpatialBatchNormalization(128)):add(nn.ReLU(true))
  static_net:add(nn.SpatialFullConvolution(128, 64, 4,4, 2,2, 1,1))
  static_net:add(nn.SpatialBatchNormalization(64)):add(nn.ReLU(true))
  static_net:add(nn.SpatialFullConvolution(64, 3, 4,4, 2,2, 1,1))
  static_net:add(nn.Tanh())

  local net_video = nn.Sequential()
  net_video:add(nn.View(-1, 100, 1, 1, 1))
  net_video:add(nn.VolumetricFullConvolution(100, 512, 2,4,4))
  net_video:add(nn.VolumetricBatchNormalization(512)):add(nn.ReLU(true))
  net_video:add(nn.VolumetricFullConvolution(512, 256, 4,4,4, 2,2,2, 1,1,1))
  net_video:add(nn.VolumetricBatchNormalization(256)):add(nn.ReLU(true))
  net_video:add(nn.VolumetricFullConvolution(256, 128, 4,4,4, 2,2,2, 1,1,1))
  net_video:add(nn.VolumetricBatchNormalization(128)):add(nn.ReLU(true))
  net_video:add(nn.VolumetricFullConvolution(128, 64, 4,4,4, 2,2,2, 1,1,1))
  net_video:add(nn.VolumetricBatchNormalization(64)):add(nn.ReLU(true))

  local mask_out = nn.VolumetricFullConvolution(64,1, 4,4,4, 2,2,2, 1,1,1)
  penalty_net = nn.L1Penalty(opt.lambda, true)
  mask_net = nn.Sequential():add(mask_out):add(nn.Sigmoid()):add(penalty_net)
  gen_net = nn.Sequential():add(nn.VolumetricFullConvolution(64,3, 4,4,4, 2,2,2, 1,1,1)):add(nn.Tanh())
  net_video:add(nn.ConcatTable():add(gen_net):add(mask_net))

  -- [1] is generated video, [2] is mask, and [3] is static
  net:add(nn.ConcatTable():add(net_video):add(static_net)):add(nn.FlattenTable())

  -- video .* mask (with repmat on mask)
  motion_net = nn.Sequential():add(nn.ConcatTable():add(nn.SelectTable(1))
                                                   :add(nn.Sequential():add(nn.SelectTable(2))
                                                                       :add(nn.Squeeze())
                                                                       :add(nn.Replicate(3, 2)))) -- for color chan
                              :add(nn.CMulTable())

  -- static .* (1-mask) (then repmatted)
  local sta_part = nn.Sequential():add(nn.ConcatTable():add(nn.Sequential():add(nn.SelectTable(3))
                                                                           :add(nn.Replicate(opt.frameSize, 3))) -- for time
                                                       :add(nn.Sequential():add(nn.SelectTable(2))
                                                                           :add(nn.Squeeze())
                                                                           :add(nn.MulConstant(-1))
                                                                           :add(nn.AddConstant(1))
                                                                           :add(nn.Replicate(3, 2)))) -- for color chan
                                  :add(nn.CMulTable())

  net:add(nn.ConcatTable():add(motion_net):add(sta_part)):add(nn.CAddTable())
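  -- Composition sketch: with mask m in [0,1], foreground stream f from gen_net,
  -- and the static background b replicated across time, the output above is
  --   G(z) = m .* f + (1 - m) .* b
  -- The L1Penalty on the mask pushes m toward 0, so the model prefers the
  -- static background except where motion is needed.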
  netD = nn.Sequential()

  netD:add(nn.VolumetricConvolution(3,64, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(64,128, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.VolumetricBatchNormalization(128,1e-3)):add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(128,256, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.VolumetricBatchNormalization(256,1e-3)):add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(256,512, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.VolumetricBatchNormalization(512,1e-3)):add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(512,2, 2,4,4, 1,1,1, 0,0,0))
  netD:add(nn.View(2):setNumInputDims(4))

  -- initialize the model
  local function weights_init(m)
    local name = torch.type(m)
    if name:find('Convolution') then
      m.weight:normal(0.0, 0.01)
      m.bias:fill(0)
    elseif name:find('BatchNormalization') then
      if m.weight then m.weight:normal(1.0, 0.02) end
      if m.bias then m.bias:fill(0) end
    end
  end
  net:apply(weights_init) -- loop over all layers, applying weights_init
  netD:apply(weights_init)

  mask_out.weight:normal(0, 0.01)
  mask_out.bias:fill(0)

else -- load in existing network
  print('loading ' .. opt.finetune)
  net = torch.load(opt.finetune)
end

print('Generator:')
print(net)
print('Discriminator:')
print(netD)

-- define the loss
local criterion = nn.CrossEntropyCriterion()
local real_label = 1
local fake_label = 2

-- create the data placeholders
local noise = torch.Tensor(opt.batchSize, 100)
local target = torch.Tensor(opt.batchSize, 3, opt.frameSize, opt.fineSize, opt.fineSize)
local label = torch.Tensor(opt.batchSize)
local err, errD

-- timers to roughly profile performance
local tm = torch.Timer()
local data_tm = torch.Timer()

-- ship everything to GPU if needed
if opt.gpu > 0 then
  noise = noise:cuda()
  target = target:cuda()
  label = label:cuda()
  net:cuda()
  netD:cuda()
  criterion:cuda()
end

-- convert to cudnn if needed
-- if this errors on you, you can disable, but it will be slightly slower
if opt.gpu > 0 and opt.cudnn > 0 then
  require 'cudnn'
  net = cudnn.convert(net, cudnn)
  netD = cudnn.convert(netD, cudnn)
end

-- get a vector of parameters
local parameters, gradParameters = net:getParameters()
local parametersD, gradParametersD = netD:getParameters()

-- show graphics
disp = require 'display'
disp.url = 'http://localhost:' .. opt.display_port .. '/events'

-- optimization closure
-- the optimizer will call this function to get the gradients
local data_im,data_label
local fDx = function(x)
  gradParametersD:zero()

  -- fetch data
  data_tm:reset(); data_tm:resume()
  data_im = data:getBatch()
  data_tm:stop()

  -- ship to GPU
  noise:normal()
  target:copy(data_im)
  label:fill(real_label)

  -- forward/backwards real examples
  local output = netD:forward(target)
  errD = criterion:forward(output, label)
  local df_do = criterion:backward(output, label)
  netD:backward(target, df_do)

  -- generate fake examples
  local fake = net:forward(noise)
  target:copy(fake)
  label:fill(fake_label)

  -- forward/backwards fake examples
  local output = netD:forward(target)
  errD = errD + criterion:forward(output, label)
  local df_do = criterion:backward(output, label)
  netD:backward(target, df_do)

  errD = errD / 2

  return errD, gradParametersD
end

local fx = function(x)
  gradParameters:zero()

  label:fill(real_label)
  local output = netD.output
  err = criterion:forward(output, label)
  local df_do = criterion:backward(output, label)
  local df_dg = netD:updateGradInput(target, df_do)

  net:backward(noise, df_dg)

  return err, gradParameters
end
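-- Note: fx reuses netD.output from the fake pass inside fDx rather than
-- re-running the discriminator, which saves one forward pass per iteration
-- but means fDx must always run first, as it does in the optim.adam calls
-- below.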
local counter = 0
local history = {}

-- parameters for the optimization
-- very important: you must only create this table once!
-- the optimizer will add fields to this table (such as momentum)
local optimState = {
  learningRate = opt.lr,
  beta1 = opt.beta1,
}
local optimStateD = {
  learningRate = opt.lr,
  beta1 = opt.beta1,
}

-- train main loop
for epoch = 1,opt.niter do -- for each epoch
  for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do -- for each mini-batch
    collectgarbage() -- necessary sometimes

    tm:reset()

    -- do one iteration
    optim.adam(fDx, parametersD, optimStateD)
    optim.adam(fx, parameters, optimState)

    if counter % 10 == 0 then
      table.insert(history, {counter, err, errD})
      disp.plot(history, {win=opt.display_id+1, title=opt.name, labels = {"iteration", "err", "errD"}})
    end

    if counter % 100 == 0 then
      local vis = net.output:float()
      local vis_tab = {}
      for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, {}, i, {}, {} }]) end
      disp.image(torch.cat(vis_tab, 3), {win=opt.display_id, title=(opt.name .. ' gen')})

      local vis = motion_net.output:float()
      local vis_tab = {}
      for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, {}, i, {}, {} }]) end
      disp.image(torch.cat(vis_tab, 3), {win=opt.display_id+3, title=(opt.name .. ' motion')})

      local vis = static_net.output:float()
      disp.image(vis, {win=opt.display_id+4, title=(opt.name .. ' static')})

      local vis = mask_net.output:float():squeeze()
      local vis_lo = vis:min()
      local vis_hi = vis:max()
      local vis_tab = {}
      for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, i, {}, {} }]) end
      disp.image(torch.cat(vis_tab, 2), {win=opt.display_id+2, title=(opt.name .. ' mask ' .. string.format('%.2f %.2f', vis_lo, vis_hi))})
    end
    counter = counter + 1

    print(('%s: Epoch: [%d][%8d / %8d]\t Time: %.3f  DataTime: %.3f  '
              .. '  Err: %.4f  ErrD: %.4f  L1: %.4f'):format(
            opt.name, epoch, ((i-1) / opt.batchSize),
            math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
            tm:time().real, data_tm:time().real,
            err and err or -1, errD and errD or -1, penalty_net.loss))

    -- save checkpoint
    -- :clearState() compacts the model so it takes less space on disk
    if counter % opt.saveIter == 0 then
      print('Saving ' .. opt.name .. '/iter' .. counter .. '_net.t7')
      paths.mkdir('checkpoints')
      paths.mkdir('checkpoints/' .. opt.name)
      torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_net.t7', net:clearState())
      torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_netD.t7', netD:clearState())
      torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_history.t7', history)
    end
  end

  -- decay the learning rate, if requested
  if opt.lr_decay > 0 and epoch % opt.lr_decay == 0 then
    opt.lr = opt.lr / 10
    print('Decreasing learning rate to ' .. opt.lr)

    -- create new optimState to reset momentum
    optimState = {
      learningRate = opt.lr,
      beta1 = opt.beta1,
    }
  end
end

--------------------------------------------------------------------------------
/main_conditional.lua:
--------------------------------------------------------------------------------
require 'torch'
require 'nn'
require 'optim'

-- to specify these at runtime, you can do, e.g.:
--   $ lr=0.001 th main.lua
opt = {
  dataset = 'video2',    -- indicates which dataset loader to use (in data.lua)
  nThreads = 32,         -- how many threads to pre-fetch data
  batchSize = 32,        -- self-explanatory
  loadSize = 128,        -- when loading images, resize first to this size
  fineSize = 64,         -- crop this size from the loaded image
  frameSize = 32,
  lr = 0.0002,           -- learning rate
  lr_decay = 1000,       -- how often to decay learning rate (in epochs)
  lambda = 10,
  beta1 = 0.5,           -- momentum term for adam
  meanIter = 0,          -- how many iterations to retrieve for mean estimation
  saveIter = 1000,       -- write checkpoint on this interval
  niter = 100,           -- number of iterations through dataset
  ntrain = math.huge,    -- how big one epoch should be
  gpu = 1,               -- which GPU to use; consider using CUDA_VISIBLE_DEVICES instead
  cudnn = 1,             -- whether to use cudnn or not
  finetune = '',         -- if set, will load this network instead of starting from scratch
  name = 'condbeach7',   -- the name of the experiment
  randomize = 1,         -- whether to shuffle the data file or not
  cropping = 'random',   -- options for data augmentation
  display_port = 8000,   -- port to push graphs
  display_id = 1,        -- window ID when pushing graphs
  mean = {0,0,0},
  data_root = '/data/vision/torralba/crossmodal/flickr_videos/',
  data_list = '/data/vision/torralba/crossmodal/flickr_videos/scene_extract/lists-full/_b_beach.txt.train',
}

-- one-line argument parser. parses environment variables to override the defaults
for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
print(opt)

torch.manualSeed(0)
torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')

-- if using GPU, select indicated one
if opt.gpu > 0 then
  require 'cunn'
  cutorch.setDevice(opt.gpu)
end

-- create data loader
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
print("Dataset: " .. opt.dataset, " Size: ", data:size())

-- define the model
local net
local netD
local mask_net
local motion_net
local static_net
if opt.finetune == '' then -- build network from scratch
  net = nn.Sequential()

  local encode_net = nn.Sequential()
  encode_net:add(nn.SpatialConvolution(3,128, 4,4, 2,2, 1,1))
  encode_net:add(nn.ReLU(true))
  encode_net:add(nn.SpatialConvolution(128,256, 4,4, 2,2, 1,1))
  encode_net:add(nn.SpatialBatchNormalization(256,1e-3)):add(nn.ReLU(true))
  encode_net:add(nn.SpatialConvolution(256,512, 4,4, 2,2, 1,1))
  encode_net:add(nn.SpatialBatchNormalization(512,1e-3)):add(nn.ReLU(true))
  encode_net:add(nn.SpatialConvolution(512,1024, 4,4, 2,2, 1,1))
  encode_net:add(nn.SpatialBatchNormalization(1024,1e-3)):add(nn.ReLU(true))
  net:add(encode_net)

  static_net = nn.Sequential()
  static_net:add(nn.SpatialFullConvolution(1024, 512, 4,4, 2,2, 1,1))
  static_net:add(nn.SpatialBatchNormalization(512)):add(nn.ReLU(true))
  static_net:add(nn.SpatialFullConvolution(512, 256, 4,4, 2,2, 1,1))
  static_net:add(nn.SpatialBatchNormalization(256)):add(nn.ReLU(true))
  static_net:add(nn.SpatialFullConvolution(256, 128, 4,4, 2,2, 1,1))
  static_net:add(nn.SpatialBatchNormalization(128)):add(nn.ReLU(true))
  static_net:add(nn.SpatialFullConvolution(128, 3, 4,4, 2,2, 1,1))
  static_net:add(nn.Tanh())

  local net_video = nn.Sequential()
  net_video:add(nn.View(-1, 1024, 1, 4, 4))
  net_video:add(nn.VolumetricFullConvolution(1024, 1024, 2,1,1))
  net_video:add(nn.VolumetricBatchNormalization(1024)):add(nn.ReLU(true))
  net_video:add(nn.VolumetricFullConvolution(1024, 512, 4,4,4, 2,2,2, 1,1,1))
  net_video:add(nn.VolumetricBatchNormalization(512)):add(nn.ReLU(true))
  net_video:add(nn.VolumetricFullConvolution(512, 256, 4,4,4, 2,2,2, 1,1,1))
  net_video:add(nn.VolumetricBatchNormalization(256)):add(nn.ReLU(true))
  net_video:add(nn.VolumetricFullConvolution(256, 128, 4,4,4, 2,2,2, 1,1,1))
  net_video:add(nn.VolumetricBatchNormalization(128)):add(nn.ReLU(true))

  local mask_out = nn.VolumetricFullConvolution(128,1, 4,4,4, 2,2,2, 1,1,1)
  mask_net = nn.Sequential():add(mask_out):add(nn.Sigmoid())
  gen_net = nn.Sequential():add(nn.VolumetricFullConvolution(128,3, 4,4,4, 2,2,2, 1,1,1)):add(nn.Tanh())
  net_video:add(nn.ConcatTable():add(gen_net):add(mask_net))

  -- [1] is generated video, [2] is mask, and [3] is static
  net:add(nn.ConcatTable():add(net_video):add(static_net)):add(nn.FlattenTable())

  -- video .* mask (with repmat on mask)
  motion_net = nn.Sequential():add(nn.ConcatTable():add(nn.SelectTable(1))
                                                   :add(nn.Sequential():add(nn.SelectTable(2))
                                                                       :add(nn.Squeeze())
                                                                       :add(nn.Replicate(3, 2)))) -- for color chan
                              :add(nn.CMulTable())

  -- static .* (1-mask) (then repmatted)
  local sta_part = nn.Sequential():add(nn.ConcatTable():add(nn.Sequential():add(nn.SelectTable(3))
                                                                           :add(nn.Replicate(opt.frameSize, 3))) -- for time
                                                       :add(nn.Sequential():add(nn.SelectTable(2))
                                                                           :add(nn.Squeeze())
                                                                           :add(nn.MulConstant(-1))
                                                                           :add(nn.AddConstant(1))
                                                                           :add(nn.Replicate(3, 2)))) -- for color chan
                                  :add(nn.CMulTable())

  net:add(nn.ConcatTable():add(motion_net):add(sta_part)):add(nn.CAddTable())
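  -- Conditioning sketch: unlike main.lua, the generator input is not noise;
  -- encode_net compresses the input frame to a 1024x4x4 code that seeds both
  -- the static and video streams, and fx below adds an L1 reconstruction term
  -- (weight opt.lambda) tying frame 1 of the generation back to that input.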
  netD = nn.Sequential()

  netD:add(nn.VolumetricConvolution(3,128, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(128,256, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.VolumetricBatchNormalization(256,1e-3)):add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(256,512, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.VolumetricBatchNormalization(512,1e-3)):add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(512,1024, 4,4,4, 2,2,2, 1,1,1))
  netD:add(nn.VolumetricBatchNormalization(1024,1e-3)):add(nn.LeakyReLU(0.2, true))
  netD:add(nn.VolumetricConvolution(1024,2, 2,4,4, 1,1,1, 0,0,0))
  netD:add(nn.View(2):setNumInputDims(4))

  -- initialize the model
  local function weights_init(m)
    local name = torch.type(m)
    if name:find('Convolution') then
      m.weight:normal(0.0, 0.01)
      m.bias:fill(0)
    elseif name:find('BatchNormalization') then
      if m.weight then m.weight:normal(1.0, 0.02) end
      if m.bias then m.bias:fill(0) end
    end
  end
  net:apply(weights_init) -- loop over all layers, applying weights_init
  netD:apply(weights_init)

  mask_out.weight:normal(0, 0.01)
  mask_out.bias:fill(0)

else -- load in existing network
  print('loading ' .. opt.finetune)
  net = torch.load(opt.finetune)
end

print('Generator:')
print(net)
print('Discriminator:')
print(netD)

-- define the loss
local criterion = nn.CrossEntropyCriterion()
local criterionReg = nn.AbsCriterion()
local real_label = 1
local fake_label = 2

-- create the data placeholders
local input = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)
local target = torch.Tensor(opt.batchSize, 3, opt.frameSize, opt.fineSize, opt.fineSize)
local video = torch.Tensor(opt.batchSize, 3, opt.frameSize, opt.fineSize, opt.fineSize)
local label = torch.Tensor(opt.batchSize)
local err, errD, errReg

-- timers to roughly profile performance
local tm = torch.Timer()
local data_tm = torch.Timer()

-- ship everything to GPU if needed
if opt.gpu > 0 then
  input = input:cuda()
  target = target:cuda()
  video = video:cuda()
  label = label:cuda()
  net:cuda()
  netD:cuda()
  criterion:cuda()
  criterionReg:cuda()
end

-- convert to cudnn if needed
-- if this errors on you, you can disable, but it will be slightly slower
if opt.gpu > 0 and opt.cudnn > 0 then
  require 'cudnn'
  net = cudnn.convert(net, cudnn)
  netD = cudnn.convert(netD, cudnn)
end

-- get a vector of parameters
local parameters, gradParameters = net:getParameters()
local parametersD, gradParametersD = netD:getParameters()

-- show graphics
disp = require 'display'
disp.url = 'http://localhost:' .. opt.display_port .. '/events'

-- optimization closure
-- the optimizer will call this function to get the gradients
local data_im,data_label
local fDx = function(x)
  gradParametersD:zero()

  -- fetch data
  data_tm:reset(); data_tm:resume()
  data_im = data:getBatch()
  data_tm:stop()

  -- ship to GPU; the condition is the first frame of the real video
  input:copy(data_im:select(3,1))
  target:copy(data_im)
  video:copy(data_im)
  label:fill(real_label)

  -- forward/backwards real examples
  local output = netD:forward(video)
  errD = criterion:forward(output, label)
  local df_do = criterion:backward(output, label)
  netD:backward(video, df_do)

  -- generate fake examples
  local fake = net:forward(input)
  video:copy(fake)
  label:fill(fake_label)

  -- forward/backwards fake examples
  local output = netD:forward(video)
  errD = errD + criterion:forward(output, label)
  local df_do = criterion:backward(output, label)
  netD:backward(video, df_do)

  errD = errD / 2

  return errD, gradParametersD
end

local fx = function(x)
  gradParameters:zero()

  label:fill(real_label)
  local output = netD.output
  err = criterion:forward(output, label)
  local df_do = criterion:backward(output, label)
  local df_dg = netD:updateGradInput(video, df_do)

  errReg = criterionReg:forward(video:select(3,1), target:select(3,1)) * opt.lambda
  local df_reg = criterionReg:backward(video:select(3,1), target:select(3,1)) * opt.lambda

  df_dg[{ {}, {}, 1, {}, {} }]:add(df_reg)

  net:backward(input, df_dg)

  return err + errReg, gradParameters
end

local counter = 0
local history = {}

-- parameters for the optimization
-- very important: you must only create this table once!
-- the optimizer will add fields to this table (such as momentum)
local optimState = {
  learningRate = opt.lr,
  beta1 = opt.beta1,
}
local optimStateD = {
  learningRate = opt.lr,
  beta1 = opt.beta1,
}

-- train main loop
for epoch = 1,opt.niter do -- for each epoch
  for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do -- for each mini-batch
    collectgarbage() -- necessary sometimes

    tm:reset()

    -- do one iteration
    optim.adam(fDx, parametersD, optimStateD)
    optim.adam(fx, parameters, optimState)

    if counter % 10 == 0 then
      table.insert(history, {counter, err, errD, errReg})
      disp.plot(history, {win=opt.display_id+1, title=opt.name, labels = {"iteration", "err", "errD", "errR"}})
    end

    if counter % 100 == 0 then
      local vis = net.output:float()
      local vis_tab = {}
      for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, {}, i, {}, {} }]) end
      disp.image(torch.cat(vis_tab, 3), {win=opt.display_id, title=(opt.name .. ' gen')})

      local vis = motion_net.output:float()
      local vis_tab = {}
      for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, {}, i, {}, {} }]) end
      disp.image(torch.cat(vis_tab, 3), {win=opt.display_id+3, title=(opt.name .. ' motion')})

      local vis = static_net.output:float()
      disp.image(vis, {win=opt.display_id+4, title=(opt.name .. ' static')})

      local vis = mask_net.output:float():squeeze()
      local vis_lo = vis:min()
      local vis_hi = vis:max()
      local vis_tab = {}
      for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, i, {}, {} }]) end
      disp.image(torch.cat(vis_tab, 2), {win=opt.display_id+2, title=(opt.name .. ' mask ' .. string.format('%.2f %.2f', vis_lo, vis_hi))})
    end
    counter = counter + 1

    print(('%s: Epoch: [%d][%8d / %8d]\t Time: %.3f  DataTime: %.3f  '
              .. '  Err: %.4f  ErrD: %.4f  ErrR: %.4f'):format(
            opt.name, epoch, ((i-1) / opt.batchSize),
            math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
            tm:time().real, data_tm:time().real,
            err and err or -1, errD and errD or -1, errReg and errReg or -1))

    -- save checkpoint
    -- :clearState() compacts the model so it takes less space on disk
    if counter % opt.saveIter == 0 then
      print('Saving ' .. opt.name .. '/iter' .. counter .. '_net.t7')
      paths.mkdir('checkpoints')
      paths.mkdir('checkpoints/' .. opt.name)
      torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_net.t7', net:clearState())
      torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_netD.t7', netD:clearState())
      torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_history.t7', history)
    end
  end

  -- decay the learning rate, if requested
  if opt.lr_decay > 0 and epoch % opt.lr_decay == 0 then
    opt.lr = opt.lr / 10
    print('Decreasing learning rate to ' .. opt.lr)

    -- create new optimState to reset momentum
    optimState = {
      learningRate = opt.lr,
      beta1 = opt.beta1,
    }
  end
end
--------------------------------------------------------------------------------
/main_ucf.lua:
--------------------------------------------------------------------------------
require 'torch'
require 'nn'
require 'optim'

-- to specify these at runtime, you can do, e.g.:
--   $ lr=0.001 th main.lua
opt = {
  dataset = 'video3',   -- indicates which dataset loader to use (in data.lua)
  nThreads = 16,        -- how many threads to pre-fetch data
  batchSize = 256,      -- self-explanatory
  loadSize = 256,       -- when loading images, resize first to this size
  fineSize = 64,        -- crop this size from the loaded image
  frameSize = 32,
  lr = 0.0002,          -- learning rate
  lr_decay = 1000,      -- how often to decay learning rate (in epochs)
  lambda = 0.1,
  beta1 = 0.5,          -- momentum term for adam
  meanIter = 0,         -- how many iterations to retrieve for mean estimation
  saveIter = 100,       -- write checkpoint on this interval
  niter = 100,          -- number of iterations through dataset
  max_iter = 1000,
  ntrain = math.huge,   -- how big one epoch should be
  gpu = 1,              -- which GPU to use; consider using CUDA_VISIBLE_DEVICES instead
  cudnn = 1,            -- whether to use cudnn or not
  name = 'ucf101',      -- the name of the experiment
  randomize = 1,        -- whether to shuffle the data file or not
  cropping = 'random',  -- options for data augmentation
  display_port = 8001,  -- port to push graphs
  display_id = 1,       -- window ID when pushing graphs
  mean = {0,0,0},
  data_root = '/data/vision/torralba/hallucination/UCF101/frames-stable-nofail/videos',
  data_list = '/data/vision/torralba/hallucination/UCF101/gan/train.txt'
}

-- one-line argument parser. parses environment variables to override the defaults
for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
print(opt)

torch.manualSeed(0)
torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')

-- if using GPU, select indicated one
if opt.gpu > 0 then
  require 'cunn'
  require 'cudnn'
  cutorch.setDevice(opt.gpu)
end

-- create data loader
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
print("Dataset: " .. opt.dataset, " Size: ", data:size())

-- define the model
local net = torch.load("checkpoints/all100/iter95000_netD.t7")
net:remove(#net.modules)
net:remove(#net.modules)
--for i=1,#net.modules do net.modules[i].accGradParameters = function() end end -- freeze all but last
net:add(nn.VolumetricDropout(0.5))
net:add(nn.VolumetricConvolution(512,101, 2,4,4, 1,1,1, 0,0,0))
net:add(nn.View(-1, 101))

-- net:reset() -- random weights?
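-- Fine-tuning sketch: the model starts from a saved GAN discriminator, drops
-- its last two layers (the 2-way real/fake head), and stacks dropout plus a
-- fresh 101-way convolutional classifier for UCF101. The commented-out loop
-- above shows how to freeze all pre-trained layers and train only the head.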
print('Net:')
print(net)

-- define the loss
local criterion = nn.CrossEntropyCriterion()

-- create the data placeholders
local input = torch.Tensor(opt.batchSize, 3, opt.frameSize, opt.fineSize, opt.fineSize)
local label = torch.Tensor(opt.batchSize)
local err
local accuracy

-- timers to roughly profile performance
local tm = torch.Timer()
local data_tm = torch.Timer()

-- ship everything to GPU if needed
if opt.gpu > 0 then
  input = input:cuda()
  label = label:cuda()
  net:cuda()
  criterion:cuda()
end

-- convert to cudnn if needed
-- if this errors on you, you can disable, but it will be slightly slower
if opt.gpu > 0 and opt.cudnn > 0 then
  net = cudnn.convert(net, cudnn)
end

-- get a vector of parameters
local parameters, gradParameters = net:getParameters()

-- show graphics
disp = require 'display'
disp.url = 'http://localhost:' .. opt.display_port .. '/events'

-- optimization closure
-- the optimizer will call this function to get the gradients
local data_im,data_label
local fx = function(x)
  gradParameters:zero()

  -- fetch data
  data_tm:reset(); data_tm:resume()
  data_im,data_label = data:getBatch()
  data_tm:stop()

  input:copy(data_im)
  label:copy(data_label)

  local output = net:forward(input)
  err = criterion:forward(output, label)
  local df_do = criterion:backward(output, label)
  net:backward(input, df_do)

  _,pred_cat = torch.max(output:float(), 2)
  accuracy = pred_cat:float():eq(data_label):float():mean()

  return err, gradParameters
end

local counter = 0
local history = {}

-- parameters for the optimization
-- very important: you must only create this table once!
-- the optimizer will add fields to this table (such as momentum)
local optimState = {
  learningRate = opt.lr,
  beta1 = opt.beta1,
}

-- train main loop
for epoch = 1,opt.niter do -- for each epoch
  for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do -- for each mini-batch
    collectgarbage() -- necessary sometimes

    tm:reset()

    -- do one iteration
    optim.adam(fx, parameters, optimState)

    if counter % 10 == 0 then
      table.insert(history, {counter, err})
      disp.plot(history, {win=opt.display_id+1, title=opt.name, labels = {"iteration", "err"}})
    end

    counter = counter + 1

    print(('%s: Epoch: [%d][%8d / %8d]\t Time: %.3f  DataTime: %.3f  '
              .. '  Err: %.4f  Accuracy: %.4f'):format(
            opt.name, epoch, ((i-1) / opt.batchSize),
            math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
            tm:time().real, data_tm:time().real,
            err and err or -1, accuracy))

    -- save checkpoint
    -- :clearState() compacts the model so it takes less space on disk
    if counter % opt.saveIter == 0 then
      print('Saving ' .. opt.name .. '/iter' .. counter .. '_net.t7')
      paths.mkdir('checkpoints')
      paths.mkdir('checkpoints/' .. opt.name)
      torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_net.t7', net:clearState())
      torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_history.t7', history)
    end
  end

  -- decay the learning rate, if requested
  if opt.lr_decay > 0 and epoch % opt.lr_decay == 0 then
    opt.lr = opt.lr / 10
    print('Decreasing learning rate to ' .. opt.lr)

    -- create new optimState to reset momentum
    optimState = {
      learningRate = opt.lr,
      beta1 = opt.beta1,
    }
  end

  if counter > opt.max_iter then
    break
  end
end
--------------------------------------------------------------------------------