├── README.md
├── data
│   ├── data.lua
│   ├── donkey_video2.lua
│   └── donkey_video3.lua
├── extra
│   └── stabilize_videos_many.py
├── generate.lua
├── main.lua
├── main_conditional.lua
└── main_ucf.lua
/README.md:
--------------------------------------------------------------------------------
1 | Generating Videos with Scene Dynamics
2 | =====================================
3 |
4 | This repository contains an implementation of [Generating Videos with Scene Dynamics](http://carlvondrick.com/tinyvideo/) by Carl Vondrick, Hamed Pirsiavash, and Antonio Torralba, to appear at NIPS 2016. The model learns to generate tiny videos using adversarial networks.
5 |
6 | Example Generations
7 | -------------------
8 | Below are some selected videos generated by our model. These videos are not real; they are hallucinated by a generative video model. While they are not photo-realistic, the motions are fairly reasonable for the scene categories the model was trained on.
9 |
10 | [Example generation GIFs: Beach, Golf, Train Station, and Baby; see the project website for the animations.]
11 |
82 | Training
83 | --------
84 |
85 | The code requires a Torch7 installation, along with the threads, tds, and display packages (and cunn/cudnn for GPU training).
86 |
87 | To train a video generator, see main.lua. This file constructs the networks, starts many threads to load data, and trains the networks; a typical invocation is shown below.
88 |
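All of the options at the top of main.lua can be overridden with environment variables, which the one-line parser in that file reads at startup; the other training scripts accept the same style of overrides. For example (the paths below are placeholders for your own data):

```bash
dataset=video2 nThreads=8 batchSize=64 \
data_root=/path/to/stabilized/frames \
data_list=/path/to/train_list.txt \
th main.lua
```
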
89 | For the conditional version, see main_conditional.lua. This is similar to main.lua, except the input to the model is a static image.
90 |
91 | To generate videos, see generate.lua. This file also outputs intermediate layers,
92 | such as the mask and background image, which you can inspect manually.
93 |
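For example, to sample from a downloaded checkpoint (the model path below is generate.lua's default; point it at any generator snapshot):

```bash
model=models/golf/iter65000_net.t7 gpu=1 th generate.lua
```

This writes gen.gif, video.gif, videomask.gif, mask.gif, and static.jpg to the vis/ directory.
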
94 | Data
95 | ----
96 | The data loading is designed assuming videos have been stabilized and flattened
97 | into JPEG images. We do this for efficiency. Stabilization is computationally slow and
98 | must be done offline, and reading one file per video is more efficient on NFS.
99 |
100 | For our stabilization code, see the `extra` directory.
101 | It converts each video into a single image of vertically
102 | concatenated frames. After doing this, create a text file listing
103 | all of the flattened videos (format shown below) and pass it to the data loader.
104 |
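Each line of that text file gives the path of one flattened video, relative to the data_root option; for the labeled loader (donkey_video3.lua), a numeric category follows the path. An illustrative (made-up) example:

```
beach/video0001/0000.jpg
beach/video0002/0000.jpg
```

and, with labels:

```
v_ApplyEyeMakeup_g08_c01/0000.jpg 1
```
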
105 | Models
106 | ------
107 | You can download our pre-trained models [here](https://drive.google.com/file/d/0B-xMJ5CYz_F9QS1BTE5yWl9aUWs/view?usp=sharing) (1 GB ZIP file).
108 |
109 | Notes
110 | -----
111 | The code is based on [DCGAN](https://github.com/soumith/dcgan.torch) and our [starter code](https://github.com/cvondrick/torch-starter) in [Torch7](https://github.com/torch/torch7).
112 |
113 | If you find this useful for your research, please consider citing our NIPS
114 | paper.
115 |
116 | License
117 | -------
118 | MIT
119 |
--------------------------------------------------------------------------------
/data/data.lua:
--------------------------------------------------------------------------------
1 | local Threads = require 'threads'
2 | Threads.serialization('threads.sharedserialize')
3 |
4 | local data = {}
5 |
6 | local result = {}
7 | local unpack = unpack and unpack or table.unpack
8 |
9 | function data.new(n, dataset_name, opt_)
10 | opt_ = opt_ or {}
11 | local self = {}
12 | for k,v in pairs(data) do
13 | self[k] = v
14 | end
15 |
16 |    self.randomize = opt_.randomize or 0 -- default to sequential reads when unset
17 |
18 | local donkey_file
19 | if dataset_name == 'simple' then
20 | donkey_file = 'donkey_simple.lua'
21 | elseif dataset_name == 'video2' then
22 | donkey_file = 'donkey_video2.lua'
23 | elseif dataset_name == 'video3' then
24 | donkey_file = 'donkey_video3.lua'
25 | else
26 | error('Unknown dataset: ' .. dataset_name)
27 | end
28 |
29 | if n > 0 then
30 | local options = opt_
31 | self.threads = Threads(n,
32 | function() require 'torch' end,
33 | function(idx)
34 | opt = options
35 | tid = idx
36 | local seed = (opt.manualSeed and opt.manualSeed or 0) + idx
37 | torch.manualSeed(seed)
38 | torch.setnumthreads(1)
39 | print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
40 | assert(options, 'options not found')
41 | assert(opt, 'opt not given')
42 | paths.dofile(donkey_file)
43 | end
44 | )
45 | else
46 | if donkey_file then paths.dofile(donkey_file) end
47 | self.threads = {}
48 | function self.threads:addjob(f1, f2) f2(f1()) end
49 | function self.threads:dojob() end
50 | function self.threads:synchronize() end
51 | end
52 |
53 | local nSamples = 0
54 | self.threads:addjob(function() return trainLoader:size() end,
55 | function(c) nSamples = c end)
56 | self.threads:synchronize()
57 | self._size = nSamples
58 |
59 | self.jobCount = 0
60 | for i = 1, n do
61 | self:queueJob()
62 | end
63 |
64 | return self
65 | end
66 |
67 | function data:queueJob()
68 | self.jobCount = self.jobCount + 1
69 |
70 | if self.randomize > 0 then
71 | self.threads:addjob(function()
72 | return trainLoader:sample(opt.batchSize)
73 | end,
74 | self._pushResult)
75 | else
76 | local indexStart = (self.jobCount-1) * opt.batchSize + 1
77 | local indexEnd = (indexStart + opt.batchSize - 1)
78 | if indexEnd <= self:size() then
79 | self.threads:addjob(function()
80 | return trainLoader:get(indexStart, indexEnd)
81 | end,
82 | self._pushResult)
83 | end
84 | end
85 | end
86 |
87 | function data._pushResult(...)
88 |    local res = {...}
89 |    result[1] = res
90 | end
94 |
95 | function data:getBatch()
96 | -- queue another job
97 | local res
98 | repeat
99 | self:queueJob()
100 | self.threads:dojob()
101 | res = result[1]
102 | result[1] = nil
103 | until torch.type(res) == 'table'
104 | return unpack(res)
105 | end
106 |
107 | function data:size()
108 | return self._size
109 | end
110 |
111 | return data
112 |
--------------------------------------------------------------------------------
/data/donkey_video2.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2015-present, Facebook, Inc.
3 | All rights reserved.
4 |
5 | This source code is licensed under the BSD-style license found in the
6 | LICENSE file in the root directory of this source tree. An additional grant
7 | of patent rights can be found in the PATENTS file in the same directory.
8 | ]]--
9 |
10 | -- Heavily modified by Carl to make it simpler
11 |
12 | require 'torch'
13 | require 'image'
14 | tds = require 'tds'
15 | torch.setdefaulttensortype('torch.FloatTensor')
16 | local class = require('pl.class')
17 |
18 | local dataset = torch.class('dataLoader')
19 |
20 | -- this function reads in the data files
21 | function dataset:__init(args)
22 | for k,v in pairs(args) do self[k] = v end
23 |
24 | assert(self.frameSize > 0)
25 |
26 | if self.filenamePad == nil then
27 | self.filenamePad = 8
28 | end
29 |
30 |   -- read a text file listing the flattened videos (one relative path per line)
31 | self.data = tds.Vec()
32 | print('reading ' .. args.data_list)
33 | for line in io.lines(args.data_list) do
34 | local split = {}
35 | for k in string.gmatch(line, "%S+") do table.insert(split, k) end
36 | self.data:insert(split[1])
37 | end
38 |
39 | print('found ' .. #self.data .. ' videos')
40 |
41 | end
42 |
43 | function dataset:size()
44 | return #self.data
45 | end
46 |
47 | -- converts a table of samples (and corresponding labels) to a clean tensor
48 | function dataset:tableToOutput(dataTable, extraTable)
49 |    local data
50 | local quantity = #dataTable
51 | assert(dataTable[1]:dim() == 4)
52 | data = torch.Tensor(quantity, 3, self.frameSize, self.fineSize, self.fineSize)
53 | for i=1,#dataTable do
54 | data[i]:copy(dataTable[i])
55 | end
56 | return data, extraTable
57 | end
58 |
59 | -- sampler, samples with replacement from the training set.
60 | function dataset:sample(quantity)
61 | assert(quantity)
62 | local dataTable = {}
63 | local extraTable = {}
64 | for i=1,quantity do
65 | local idx = torch.random(1, #self.data)
66 | local data_path = self.data_root .. '/' .. self.data[idx]
67 |
68 | local out = self:trainHook(data_path)
69 | table.insert(dataTable, out)
70 | table.insert(extraTable, self.data[idx])
71 | end
72 | return self:tableToOutput(dataTable,extraTable)
73 | end
74 |
75 | -- gets data in a certain range
76 | function dataset:get(start_idx,stop_idx)
77 | local dataTable = {}
78 | local extraTable = {}
79 | for idx=start_idx,stop_idx do
80 | local data_path = self.data_root .. '/' .. self.data[idx]
81 |
82 | local out = self:trainHook(data_path)
83 | table.insert(dataTable, out)
84 | table.insert(extraTable, self.data[idx])
85 | end
86 | return self:tableToOutput(dataTable,extraTable)
87 |
88 | end
89 |
90 | -- function to load the image, jitter it appropriately (random crops etc.)
91 | function dataset:trainHook(path)
92 | collectgarbage()
93 |
94 | local oW = self.fineSize
95 | local oH = self.fineSize
98 |
99 | local out = torch.zeros(3, self.frameSize, oW, oH)
100 |
101 | local ok,input = pcall(image.load, path, 3, 'float')
102 | if not ok then
103 | print('warning: failed loading: ' .. path)
104 | return out
105 | end
106 |
107 | local count = input:size(2) / opt.loadSize
108 | local t1 = 1
109 |
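  -- the flattened video is a vertical strip of frames, each opt.loadSize
  -- pixels tall, so frame fr begins at row (fr+t1-2)*opt.loadSize + 1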
110 | for fr=1,self.frameSize do
111 | local off
112 | if fr <= count then
113 | off = (fr+t1-2) * opt.loadSize+1
114 | else
115 | off = (count+t1-2)*opt.loadSize+1 -- repeat the last frame
116 | end
117 | local crop = input[{ {}, {off, off+opt.loadSize-1}, {} }]
118 | out[{ {}, fr, {}, {} }]:copy(image.scale(crop, opt.fineSize, opt.fineSize))
119 | end
120 |
121 | out:mul(2):add(-1) -- make it [0, 1] -> [-1, 1]
122 |
123 | -- subtract mean
124 | for c=1,3 do
125 | out[{ c, {}, {} }]:add(-self.mean[c])
126 | end
127 |
128 | return out
129 | end
130 |
131 | -- data.lua expects a variable called trainLoader
132 | trainLoader = dataLoader(opt)
133 |
--------------------------------------------------------------------------------
/data/donkey_video3.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2015-present, Facebook, Inc.
3 | All rights reserved.
4 |
5 | This source code is licensed under the BSD-style license found in the
6 | LICENSE file in the root directory of this source tree. An additional grant
7 | of patent rights can be found in the PATENTS file in the same directory.
8 | ]]--
9 |
10 | -- Heavily modified by Carl to make it simpler
11 |
12 | require 'torch'
13 | require 'image'
14 | tds = require 'tds'
15 | torch.setdefaulttensortype('torch.FloatTensor')
16 | local class = require('pl.class')
17 |
18 | local dataset = torch.class('dataLoader')
19 |
20 | -- this function reads in the data files
21 | function dataset:__init(args)
22 | for k,v in pairs(args) do self[k] = v end
23 |
24 | assert(self.frameSize > 0)
25 |
26 | if self.filenamePad == nil then
27 | self.filenamePad = 8
28 | end
29 |
30 |   -- read a text file listing image paths and category labels
31 | self.data = tds.Vec()
32 | self.category = tds.Vec()
33 | print('reading ' .. args.data_list)
34 | for line in io.lines(args.data_list) do
35 | local split = {}
36 | for k in string.gmatch(line, "%S+") do table.insert(split, k) end
37 | self.data:insert(split[1])
38 | self.category:insert(split[2])
39 | end
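  -- each line is "<relative image path> <numeric category>", e.g.
  -- "somedir/video0001/0000.jpg 5" (illustrative)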
40 |
41 | print('found ' .. #self.data .. ' videos')
42 |
43 | end
44 |
45 | function dataset:size()
46 | return #self.data
47 | end
48 |
49 | -- converts a table of samples (and corresponding labels) to a clean tensor
50 | function dataset:tableToOutput(dataTable, scalarTable, extraTable)
51 |    local data, label
52 | local quantity = #dataTable
53 | assert(dataTable[1]:dim() == 4)
54 | data = torch.Tensor(quantity, 3, self.frameSize, self.fineSize, self.fineSize)
55 | label = torch.Tensor(quantity)
56 | for i=1,#dataTable do
57 | data[i]:copy(dataTable[i])
58 | label[i] = scalarTable[i]
59 | end
60 | return data, label, extraTable
61 | end
62 |
63 | -- sampler, samples with replacement from the training set.
64 | function dataset:sample(quantity)
65 | assert(quantity)
66 | local dataTable = {}
67 | local scalarTable = {}
68 | local extraTable = {}
69 | for i=1,quantity do
70 | local idx = torch.random(1, #self.data)
71 | local data_path = self.data_root .. '/' .. self.data[idx]
72 |
73 | local out = self:trainHook(data_path)
74 | table.insert(dataTable, out)
75 | table.insert(extraTable, self.data[idx])
76 | table.insert(scalarTable, self.category[idx])
77 | end
78 | return self:tableToOutput(dataTable, scalarTable, extraTable)
79 | end
80 |
81 | -- gets data in a certain range
82 | function dataset:get(start_idx,stop_idx)
83 |    assert(false, 'get() is not supported by the video3 loader; use sample()')
84 | end
85 |
86 | -- function to load the image, jitter it appropriately (random crops etc.)
87 | function dataset:trainHook(path)
88 | collectgarbage()
89 |
90 | local oW = self.fineSize
91 | local oH = self.fineSize
94 |
95 | local out = torch.zeros(3, self.frameSize, oW, oH)
96 |
97 | local ok,input = pcall(image.load, path, 3, 'float')
98 | if not ok then
99 | print('warning: failed loading: ' .. path)
100 | return out
101 | end
102 |
103 | local count = input:size(2) / opt.loadSize
104 | local t1 = 1
105 |
106 | for fr=1,self.frameSize do
107 | local off
108 | if fr <= count then
109 | off = (fr+t1-2) * opt.loadSize+1
110 | else
111 | off = (count+t1-2)*opt.loadSize+1 -- repeat the last frame
112 | end
113 |
114 | local crop
115 | if off+opt.loadSize-1 <= input:size(2) and off > 0 then
116 | crop = input[{ {}, {off, off+opt.loadSize-1}, {} }]
117 | else
118 | print('*** WARNING ***')
119 | print(' bad size')
120 | print(' path: ' .. path)
121 | crop = torch.zeros(3, opt.fineSize, opt.fineSize)
122 | end
123 | out[{ {}, fr, {}, {} }]:copy(image.scale(crop, opt.fineSize, opt.fineSize))
124 | end
125 |
126 | out:mul(2):add(-1) -- make it [0, 1] -> [-1, 1]
127 |
128 | -- subtract mean
129 | for c=1,3 do
130 | out[{ c, {}, {} }]:add(-self.mean[c])
131 | end
132 |
133 | return out
134 | end
135 |
136 | -- data.lua expects a variable called trainLoader
137 | trainLoader = dataLoader(opt)
138 |
--------------------------------------------------------------------------------
/extra/stabilize_videos_many.py:
--------------------------------------------------------------------------------
1 | # hackery for opencv to use sift
2 | import sys
3 | sys.path.insert(0, "/data/vision/torralba/commonsense/future/opencv-2.4.11/install/lib/python2.7/dist-packages")
4 |
5 | import numpy as np
6 | import cv2
7 | import json
8 | import os
9 | import argparse
10 | import subprocess
11 | import random
12 | from scipy.ndimage.filters import gaussian_filter
13 |
14 | MIN_MATCH_COUNT = 10
15 | VIDEO_SIZE = 128
16 | CROP_SIZE = 128
17 | MAX_FRAMES = 33
18 | MIN_FRAMES = 16
19 | FRAMES_DELAY = 2
20 |
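# Overview: frames are read through an ffmpeg pipe, consecutive frames are
# registered using SIFT matches and RANSAC homographies, each frame is warped
# back into the clip's starting coordinate frame, and every stabilized clip is
# written out as a single vertically stacked JPEG plus a list.txt index.
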
21 | def get_video_info(video):
22 | stats = subprocess.check_output("ffprobe -select_streams v -v error -show_entries stream=width,height,duration -of default=noprint_wrappers=1 {}".format(video), shell = True)
23 | info = dict(x.split("=") for x in stats.strip().split("\n"))
24 | print info
25 | return {"width": int(info['width']),
26 | "height": int(info['height']),
27 | "duration": float(info['duration'])}
28 |
29 | class FrameReader(object):
30 | def __init__(self, video):
31 | self.info = get_video_info(video)
32 |
33 | command = [ "ffmpeg",
34 | '-i', video,
35 | '-f', 'image2pipe',
36 | '-pix_fmt', 'rgb24',
37 | '-vcodec', 'rawvideo',
38 | '-']
39 | self.pipe = subprocess.Popen(command, stdout = subprocess.PIPE, bufsize=10**8)
40 |
41 | def __iter__(self):
42 | return self
43 |
44 | def next(self):
45 | raw_image = self.pipe.stdout.read(self.info['width']*self.info['height']*3)
46 | # transform the byte read into a numpy array
47 | image = np.fromstring(raw_image, dtype='uint8')
48 | try:
49 | image = image.reshape((self.info['height'],self.info['width'],3))
50 |         except ValueError:
51 | raise StopIteration()
52 | # throw away the data in the pipe's buffer.
53 | self.pipe.stdout.flush()
54 |
55 | image = image[:, :, (2,1,0)]
56 |
57 | return image
58 |
59 | def close(self):
60 | self.pipe.stdout.close()
61 | self.pipe.kill()
62 |
63 | def process_im(im):
64 | h = im.shape[0]
65 | w = im.shape[1]
66 |
67 | if w > h:
68 | scale = float(VIDEO_SIZE) / h
69 | else:
70 | scale = float(VIDEO_SIZE) / w
71 |
72 | new_h = int(h * scale)
73 | new_w = int(w * scale)
74 |
75 | im = cv2.resize(im, (new_w, new_h))
76 |
77 | h = im.shape[0]
78 | w = im.shape[1]
79 |
80 | h_start = h / 2 - CROP_SIZE / 2
81 | h_stop = h_start + CROP_SIZE
82 |
83 | w_start = w / 2 - CROP_SIZE / 2
84 | w_stop = w_start + CROP_SIZE
85 |
86 | im = im[h_start:h_stop, w_start:w_stop, :]
87 |
88 | return im
89 |
90 | def compute(video, frame_dir):
91 | try:
92 | frames = FrameReader(video)
93 | except subprocess.CalledProcessError:
94 | print "failed due to CalledProcessError"
95 | return False
96 |
97 | for _ in range(FRAMES_DELAY):
98 | try:
99 | frames.next()
100 | except StopIteration:
101 | return False
102 |
103 | # Initiate SIFT detector
104 | sift = cv2.SIFT()
105 | #sift = cv2.ORB()
106 | #sift = cv2.BRISK()
107 |
108 | FLANN_INDEX_KDTREE = 0
109 | index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5)
110 | search_params = dict(checks = 50)
111 | flann = cv2.FlannBasedMatcher(index_params, search_params)
112 |
113 | movie_clip = 0
114 | movie_clip_files = []
115 | for _ in range(100):
116 | try:
117 | img2 = frames.next()
118 | except StopIteration:
119 | print "end of stream"
120 | break
121 |
122 | bg_img = process_im(img2.copy())
123 | kp2, des2 = sift.detectAndCompute(img2,None)
124 |
125 | Ms = []
126 |
127 | movie = [bg_img.copy()]
128 |
129 | failed = False
130 |
131 | bigM = np.eye(3)
132 |
133 | for fr, img1 in enumerate(frames):
134 | #img1 = cv2.imread(im1,0)
135 | kp1, des1 = sift.detectAndCompute(img1,None)
136 | if des1 is None or des2 is None:
137 | print "Empty matches"
138 | M = np.eye(3)
139 | failed = True
140 | elif len(kp1) < 2 or len(kp2) < 2:
141 | print "Not enough key points"
142 | M = np.eye(3)
143 | failed = True
144 | else:
145 | matches = flann.knnMatch(des1.astype("float32"),des2.astype("float32"),k=2)
146 | # store all the good matches as per Lowe's ratio test.
147 | good = []
148 | for m,n in matches:
149 | if m.distance < 0.7*n.distance:
150 | good.append(m)
151 |
152 | if len(good)>=MIN_MATCH_COUNT:
153 | src_pts = np.float32([ kp1[m.queryIdx].pt for m in good ]).reshape(-1,1,2)
154 | dst_pts = np.float32([ kp2[m.trainIdx].pt for m in good ]).reshape(-1,1,2)
155 |
156 | M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC,5.0)
157 | else:
158 | print "Not enough matches are found - %d/%d" % (len(good),MIN_MATCH_COUNT)
160 | M = np.eye(3)
161 | failed = True
162 |
163 | Ms.append(M)
164 | bigM = np.dot(bigM, M)
165 |
166 | mask = (np.ones((img2.shape[0], img2.shape[1], 3)) * 255).astype('uint8')
167 | mask = cv2.warpPerspective(mask, bigM, (img2.shape[1], img2.shape[0]))
168 | mask = cv2.erode(mask / 255, np.ones((5,5),np.uint8), iterations=1) * 255
169 | mask = process_im(mask).astype("float") / 255.
170 |
171 | if (mask > 0).any():
172 | save_im = cv2.warpPerspective(img1, bigM, (img2.shape[1], img2.shape[0]))
173 |
174 | save_im = bg_img * (1-mask) + process_im(save_im) * mask
175 | movie.append(save_im.copy())
176 | #cv2.imwrite(frame_dir + ("/%08d.jpg"%(frame_counter)), save_im)
177 |
178 | bg_img = save_im.copy()
179 |
180 | else: # homography has gone out of frame, so just abort, comment these lines to keep trying
181 | break
182 |
183 | if len(movie) > MAX_FRAMES:
184 | break
185 |
186 | img2 = img1
187 | kp2 = kp1
188 | des2 = des1
189 |
190 | if failed:
191 | break
192 |
193 | if len(movie) < MIN_FRAMES:
194 | print "this movie clip is too short, causing fail"
195 | failed = True
196 |
197 | if failed:
198 | print "aborting movie clip due to failure"
199 | else:
200 | # write a column stacked image so it can be loaded at once, which
201 | # will hopefully reduce IO significantly
202 | stacked = np.vstack(movie)
203 | movie_clip_filename = frame_dir + "/%04d.jpg" % movie_clip
204 | movie_clip_files.append(movie_clip_filename)
205 | print "writing {}".format(movie_clip_filename)
206 | cv2.imwrite(movie_clip_filename, stacked)
207 | movie_clip += 1
208 |
209 | frames.close()
210 |
211 | open(frame_dir + "/list.txt", "w").write("\n".join(movie_clip_files))
212 |
213 | def get_stable_path(video):
214 | #return "frames-stable/{}".format(video)
215 | return "frames-stable-many/{}".format(video)
216 |
217 | work = [x.strip() for x in open("scene_extract/job_list.txt")]
218 | random.shuffle(work)
219 |
220 | for video in work:
221 | stable_path = get_stable_path(video)
222 | lock_file = stable_path + ".lock"
223 |
224 | if os.path.exists(stable_path) or os.path.exists(lock_file):
225 | print "already done: {}".format(stable_path)
226 | continue
227 |
228 | try:
229 | os.makedirs(os.path.dirname(stable_path))
230 | except OSError:
231 | pass
232 | try:
233 | os.makedirs(stable_path)
234 | except OSError:
235 | pass
236 | try:
237 | os.mkdir(lock_file)
238 | except OSError:
239 | pass
240 |
241 | print video
242 |
243 | #result = compute("videos/" + video, stable_path)
244 | result = compute(video, stable_path)
245 |
246 | try:
247 | os.rmdir(lock_file)
248 |         except OSError:
249 | pass
250 |
--------------------------------------------------------------------------------
/generate.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | require 'image'
4 | require 'cunn'
5 | require 'cudnn'
6 | require 'sys' -- for sys.execute in WriteGif below
6 |
7 | opt = {
8 | model = 'models/golf/iter65000_net.t7',
9 | batchSize = 128,
10 | gpu = 1,
11 | cudnn = 1,
12 | }
13 |
14 | -- one-line argument parser. parses environment variables to override the defaults
15 | for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
16 | print(opt)
17 |
18 | torch.manualSeed(0)
19 | torch.setnumthreads(1)
20 | torch.setdefaulttensortype('torch.FloatTensor')
21 |
22 | -- if using GPU, select indicated one
23 | cutorch.setDevice(opt.gpu)
24 |
25 | net = torch.load(opt.model)
26 | net:evaluate()
27 | net:cuda()
28 | net = cudnn.convert(net, cudnn)
29 |
30 | print('Generator:')
31 | print(net)
32 |
33 | local noise = torch.Tensor(opt.batchSize, 100):normal():cuda()
34 |
35 | local gen = net:forward(noise)
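-- net.modules[2] is the FlattenTable from training: output[1] is the raw
-- foreground video, output[2] the mask, and output[3] the static background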
36 | local video = net.modules[2].output[1]:float()
37 | local mask = net.modules[2].output[2]:float()
38 | local static = net.modules[2].output[3]:float()
39 | mask = mask:repeatTensor(1,3,1,1,1) -- replicate the 1-channel mask across RGB
40 |
41 | function WriteGif(filename, movie)
42 | for fr=1,movie:size(3) do
43 | image.save(filename .. '.' .. string.format('%08d', fr) .. '.png', image.toDisplayTensor(movie:select(3,fr)))
44 | end
45 | cmd = "ffmpeg -f image2 -i " .. filename .. ".%08d.png -y " .. filename
46 | print('==> ' .. cmd)
47 | sys.execute(cmd)
48 | for fr=1,movie:size(3) do
49 | os.remove(filename .. '.' .. string.format('%08d', fr) .. '.png')
50 | end
51 | end
52 |
53 | paths.mkdir('vis/')
54 | WriteGif('vis/gen.gif', gen)
55 | WriteGif('vis/video.gif', video)
56 | WriteGif('vis/videomask.gif', torch.cmul(video, mask))
57 | WriteGif('vis/mask.gif', mask)
58 | image.save('vis/static.jpg', image.toDisplayTensor(static))
59 |
--------------------------------------------------------------------------------
/main.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | require 'optim'
4 |
5 | -- to specify these at runtime, you can do, e.g.:
6 | -- $ lr=0.001 th main.lua
7 | opt = {
8 |   dataset = 'video2',       -- which data loader to use (see data/data.lua)
9 | nThreads = 32, -- how many threads to pre-fetch data
10 | batchSize = 64, -- self-explanatory
11 | loadSize = 128, -- when loading images, resize first to this size
12 | fineSize = 64, -- crop this size from the loaded image
13 | frameSize = 32,
14 | lr = 0.0002, -- learning rate
15 |   lr_decay = 1000,          -- how often to decay learning rate (in epochs)
16 | lambda = 0.1,
17 | beta1 = 0.5, -- momentum term for adam
18 | meanIter = 0, -- how many iterations to retrieve for mean estimation
19 |   saveIter = 1000,    -- write a checkpoint at this interval
20 | niter = 100, -- number of iterations through dataset
21 | ntrain = math.huge, -- how big one epoch should be
22 | gpu = 1, -- which GPU to use; consider using CUDA_VISIBLE_DEVICES instead
23 | cudnn = 1, -- whether to use cudnn or not
24 | finetune = '', -- if set, will load this network instead of starting from scratch
25 | name = 'beach100', -- the name of the experiment
26 | randomize = 1, -- whether to shuffle the data file or not
27 | cropping = 'random', -- options for data augmentation
28 | display_port = 8001, -- port to push graphs
29 | display_id = 1, -- window ID when pushing graphs
30 | mean = {0,0,0},
31 | data_root = '/data/vision/torralba/crossmodal/flickr_videos/',
32 | data_list = '/data/vision/torralba/crossmodal/flickr_videos/scene_extract/lists-full/_b_beach.txt.train',
33 | }
34 |
35 | -- one-line argument parser. parses environment variables to override the defaults
36 | for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
37 | print(opt)
38 |
39 | torch.manualSeed(0)
40 | torch.setnumthreads(1)
41 | torch.setdefaulttensortype('torch.FloatTensor')
42 |
43 | -- if using GPU, select indicated one
44 | if opt.gpu > 0 then
45 | require 'cunn'
46 | cutorch.setDevice(opt.gpu)
47 | end
48 |
49 | -- create data loader
50 | local DataLoader = paths.dofile('data/data.lua')
51 | local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
52 | print("Dataset: " .. opt.dataset, " Size: ", data:size())
53 |
54 | -- define the model
55 | local net
56 | local netD
57 | local mask_net
58 | local motion_net
59 | local static_net
60 | local penalty_net
61 | if opt.finetune == '' then -- build network from scratch
62 | net = nn.Sequential()
63 |
64 | static_net = nn.Sequential()
65 | static_net:add(nn.View(-1, 100, 1, 1))
66 | static_net:add(nn.SpatialFullConvolution(100, 512, 4,4))
67 | static_net:add(nn.SpatialBatchNormalization(512)):add(nn.ReLU(true))
68 | static_net:add(nn.SpatialFullConvolution(512, 256, 4,4, 2,2, 1,1))
69 | static_net:add(nn.SpatialBatchNormalization(256)):add(nn.ReLU(true))
70 | static_net:add(nn.SpatialFullConvolution(256, 128, 4,4, 2,2, 1,1))
71 | static_net:add(nn.SpatialBatchNormalization(128)):add(nn.ReLU(true))
72 | static_net:add(nn.SpatialFullConvolution(128, 64, 4,4, 2,2, 1,1))
73 | static_net:add(nn.SpatialBatchNormalization(64)):add(nn.ReLU(true))
74 | static_net:add(nn.SpatialFullConvolution(64, 3, 4,4, 2,2, 1,1))
75 | static_net:add(nn.Tanh())
76 |
77 | local net_video = nn.Sequential()
78 | net_video:add(nn.View(-1, 100, 1, 1, 1))
79 | net_video:add(nn.VolumetricFullConvolution(100, 512, 2,4,4))
80 | net_video:add(nn.VolumetricBatchNormalization(512)):add(nn.ReLU(true))
81 | net_video:add(nn.VolumetricFullConvolution(512, 256, 4,4,4, 2,2,2, 1,1,1))
82 | net_video:add(nn.VolumetricBatchNormalization(256)):add(nn.ReLU(true))
83 | net_video:add(nn.VolumetricFullConvolution(256, 128, 4,4,4, 2,2,2, 1,1,1))
84 | net_video:add(nn.VolumetricBatchNormalization(128)):add(nn.ReLU(true))
85 | net_video:add(nn.VolumetricFullConvolution(128, 64, 4,4,4, 2,2,2, 1,1,1))
86 | net_video:add(nn.VolumetricBatchNormalization(64)):add(nn.ReLU(true))
87 |
88 | local mask_out = nn.VolumetricFullConvolution(64,1, 4,4,4, 2,2,2, 1,1,1)
89 | penalty_net = nn.L1Penalty(opt.lambda, true)
90 | mask_net = nn.Sequential():add(mask_out):add(nn.Sigmoid()):add(penalty_net)
91 | gen_net = nn.Sequential():add(nn.VolumetricFullConvolution(64,3, 4,4,4, 2,2,2, 1,1,1)):add(nn.Tanh())
92 | net_video:add(nn.ConcatTable():add(gen_net):add(mask_net))
93 |
94 | -- [1] is generated video, [2] is mask, and [3] is static
95 | net:add(nn.ConcatTable():add(net_video):add(static_net)):add(nn.FlattenTable())
96 |
97 | -- video .* mask (with repmat on mask)
98 | motion_net = nn.Sequential():add(nn.ConcatTable():add(nn.SelectTable(1))
99 | :add(nn.Sequential():add(nn.SelectTable(2))
100 | :add(nn.Squeeze())
101 | :add(nn.Replicate(3, 2)))) -- for color chan
102 | :add(nn.CMulTable())
103 |
104 | -- static .* (1-mask) (then repmatted)
105 | local sta_part = nn.Sequential():add(nn.ConcatTable():add(nn.Sequential():add(nn.SelectTable(3))
106 | :add(nn.Replicate(opt.frameSize, 3))) -- for time
107 | :add(nn.Sequential():add(nn.SelectTable(2))
108 | :add(nn.Squeeze())
109 | :add(nn.MulConstant(-1))
110 | :add(nn.AddConstant(1))
111 | :add(nn.Replicate(3, 2)))) -- for color chan
112 | :add(nn.CMulTable())
113 |
114 | net:add(nn.ConcatTable():add(motion_net):add(sta_part)):add(nn.CAddTable())
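   -- so the final output is video .* mask + static .* (1 - mask): a per-pixel,
   -- per-frame blend of the moving foreground over a stationary background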
115 |
116 | netD = nn.Sequential()
117 |
118 | netD:add(nn.VolumetricConvolution(3,64, 4,4,4, 2,2,2, 1,1,1))
119 | netD:add(nn.LeakyReLU(0.2, true))
120 | netD:add(nn.VolumetricConvolution(64,128, 4,4,4, 2,2,2, 1,1,1))
121 | netD:add(nn.VolumetricBatchNormalization(128,1e-3)):add(nn.LeakyReLU(0.2, true))
122 | netD:add(nn.VolumetricConvolution(128,256, 4,4,4, 2,2,2, 1,1,1))
123 | netD:add(nn.VolumetricBatchNormalization(256,1e-3)):add(nn.LeakyReLU(0.2, true))
124 | netD:add(nn.VolumetricConvolution(256,512, 4,4,4, 2,2,2, 1,1,1))
125 | netD:add(nn.VolumetricBatchNormalization(512,1e-3)):add(nn.LeakyReLU(0.2, true))
126 | netD:add(nn.VolumetricConvolution(512,2, 2,4,4, 1,1,1, 0,0,0))
127 | netD:add(nn.View(2):setNumInputDims(4))
128 |
129 | -- initialize the model
130 | local function weights_init(m)
131 | local name = torch.type(m)
132 | if name:find('Convolution') then
133 | m.weight:normal(0.0, 0.01)
134 | m.bias:fill(0)
135 | elseif name:find('BatchNormalization') then
136 | if m.weight then m.weight:normal(1.0, 0.02) end
137 | if m.bias then m.bias:fill(0) end
138 | end
139 | end
140 | net:apply(weights_init) -- loop over all layers, applying weights_init
141 | netD:apply(weights_init)
142 |
143 | mask_out.weight:normal(0, 0.01)
144 | mask_out.bias:fill(0)
145 |
146 | else -- load in existing network
147 | print('loading ' .. opt.finetune)
148 | net = torch.load(opt.finetune)
149 | end
150 |
151 | print('Generator:')
152 | print(net)
153 | print('Discriminator:')
154 | print(netD)
155 |
156 | -- define the loss
157 | local criterion = nn.CrossEntropyCriterion()
158 | local real_label = 1
159 | local fake_label = 2
160 |
161 | -- create the data placeholders
162 | local noise = torch.Tensor(opt.batchSize, 100)
163 | local target = torch.Tensor(opt.batchSize, 3, opt.frameSize, opt.fineSize, opt.fineSize)
164 | local label = torch.Tensor(opt.batchSize)
165 | local err, errD
166 |
167 | -- timers to roughly profile performance
168 | local tm = torch.Timer()
169 | local data_tm = torch.Timer()
170 |
171 | -- ship everything to GPU if needed
172 | if opt.gpu > 0 then
173 | noise = noise:cuda()
174 | target = target:cuda()
175 | label = label:cuda()
176 | net:cuda()
177 | netD:cuda()
178 | criterion:cuda()
179 | end
180 |
181 | -- convert to cudnn if needed
182 | -- if this errors, you can disable cudnn here, but training will be slightly slower
183 | if opt.gpu > 0 and opt.cudnn > 0 then
184 | require 'cudnn'
185 | net = cudnn.convert(net, cudnn)
186 | netD = cudnn.convert(netD, cudnn)
187 | end
188 |
189 | -- get a vector of parameters
190 | local parameters, gradParameters = net:getParameters()
191 | local parametersD, gradParametersD = netD:getParameters()
192 |
193 | -- show graphics
194 | disp = require 'display'
195 | disp.url = 'http://localhost:' .. opt.display_port .. '/events'
196 |
197 | -- optimization closure
198 | -- the optimizer will call this function to get the gradients
199 | local data_im,data_label
200 | local fDx = function(x)
201 | gradParametersD:zero()
202 |
203 | -- fetch data
204 | data_tm:reset(); data_tm:resume()
205 | data_im = data:getBatch()
206 | data_tm:stop()
207 |
208 | -- ship to GPU
209 | noise:normal()
210 | target:copy(data_im)
211 | label:fill(real_label)
212 |
213 | -- forward/backwards real examples
214 | local output = netD:forward(target)
215 | errD = criterion:forward(output, label)
216 | local df_do = criterion:backward(output, label)
217 | netD:backward(target, df_do)
218 |
219 | -- generate fake examples
220 | local fake = net:forward(noise)
221 | target:copy(fake)
222 | label:fill(fake_label)
223 |
224 | -- forward/backwards fake examples
225 | local output = netD:forward(target)
226 | errD = errD + criterion:forward(output, label)
227 | local df_do = criterion:backward(output, label)
228 | netD:backward(target, df_do)
229 |
230 | errD = errD / 2
231 |
232 | return errD, gradParametersD
233 | end
234 |
235 | local fx = function(x)
236 | gradParameters:zero()
237 |
238 | label:fill(real_label)
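   -- reuse the discriminator's output: the most recent forward pass in fDx
   -- was on the generated batch, so netD.output already scores the fakes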
239 | local output = netD.output
240 | err = criterion:forward(output, label)
241 | local df_do = criterion:backward(output, label)
242 | local df_dg = netD:updateGradInput(target, df_do)
243 |
244 | net:backward(noise, df_dg)
245 |
246 | return err, gradParameters
247 | end
248 |
249 | local counter = 0
250 | local history = {}
251 |
252 | -- parameters for the optimization
253 | -- very important: you must only create this table once!
254 | -- the optimizer will add fields to this table (such as momentum)
255 | local optimState = {
256 | learningRate = opt.lr,
257 | beta1 = opt.beta1,
258 | }
259 | local optimStateD = {
260 | learningRate = opt.lr,
261 | beta1 = opt.beta1,
262 | }
263 |
264 | -- train main loop
265 | for epoch = 1,opt.niter do -- for each epoch
266 | for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do -- for each mini-batch
267 | collectgarbage() -- necessary sometimes
268 |
269 | tm:reset()
270 |
271 | -- do one iteration
272 | optim.adam(fDx, parametersD, optimStateD)
273 | optim.adam(fx, parameters, optimState)
274 |
275 | if counter % 10 == 0 then
276 | table.insert(history, {counter, err, errD})
277 | disp.plot(history, {win=opt.display_id+1, title=opt.name, labels = {"iteration", "err", "errD"}})
278 | end
279 |
280 | if counter % 100 == 0 then
281 | local vis = net.output:float()
282 | local vis_tab = {}
283 | for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, {}, i, {}, {} }]) end
284 | disp.image(torch.cat(vis_tab, 3), {win=opt.display_id, title=(opt.name .. ' gen')})
285 |
286 | local vis = motion_net.output:float()
287 | local vis_tab = {}
288 | for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, {}, i, {}, {} }]) end
289 | disp.image(torch.cat(vis_tab, 3), {win=opt.display_id+3, title=(opt.name .. ' motion')})
290 |
291 | local vis = static_net.output:float()
292 | disp.image(vis, {win=opt.display_id+4, title=(opt.name .. ' static')})
293 |
294 | local vis = mask_net.output:float():squeeze()
295 | local vis_lo = vis:min()
296 | local vis_hi = vis:max()
297 | local vis_tab = {}
298 | for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, i, {}, {} }]) end
299 | disp.image(torch.cat(vis_tab, 2), {win=opt.display_id+2, title=(opt.name .. ' mask ' .. string.format('%.2f %.2f', vis_lo, vis_hi))})
300 | end
301 | counter = counter + 1
302 |
303 | print(('%s: Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f '
304 |             .. ' Err: %.4f ErrD: %.4f L1: %.4f'):format(
305 | opt.name, epoch, ((i-1) / opt.batchSize),
306 | math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
307 | tm:time().real, data_tm:time().real,
308 | err and err or -1, errD and errD or -1, penalty_net.loss))
309 |
310 | -- save checkpoint
311 | -- :clearState() compacts the model so it takes less space on disk
312 | if counter % opt.saveIter == 0 then
313 | print('Saving ' .. opt.name .. '/iter' .. counter .. '_net.t7')
314 | paths.mkdir('checkpoints')
315 | paths.mkdir('checkpoints/' .. opt.name)
316 | torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_net.t7', net:clearState())
317 | torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_netD.t7', netD:clearState())
318 | torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_history.t7', history)
319 | end
320 | end
321 |
322 | -- decay the learning rate, if requested
323 | if opt.lr_decay > 0 and epoch % opt.lr_decay == 0 then
324 | opt.lr = opt.lr / 10
325 | print('Decreasing learning rate to ' .. opt.lr)
326 |
327 | -- create new optimState to reset momentum
328 | optimState = {
329 | learningRate = opt.lr,
330 | beta1 = opt.beta1,
331 | }
332 | end
333 | end
334 |
--------------------------------------------------------------------------------
/main_conditional.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | require 'optim'
4 |
5 | -- to specify these at runtime, you can do, e.g.:
6 | -- $ lr=0.001 th main.lua
7 | opt = {
8 |   dataset = 'video2',       -- which data loader to use (see data/data.lua)
9 | nThreads = 32, -- how many threads to pre-fetch data
10 | batchSize = 32, -- self-explanatory
11 | loadSize = 128, -- when loading images, resize first to this size
12 | fineSize = 64, -- crop this size from the loaded image
13 | frameSize = 32,
14 | lr = 0.0002, -- learning rate
15 |   lr_decay = 1000,          -- how often to decay learning rate (in epochs)
16 | lambda = 10,
17 | beta1 = 0.5, -- momentum term for adam
18 | meanIter = 0, -- how many iterations to retrieve for mean estimation
19 |   saveIter = 1000,    -- write a checkpoint at this interval
20 | niter = 100, -- number of iterations through dataset
21 | ntrain = math.huge, -- how big one epoch should be
22 | gpu = 1, -- which GPU to use; consider using CUDA_VISIBLE_DEVICES instead
23 | cudnn = 1, -- whether to use cudnn or not
24 | finetune = '', -- if set, will load this network instead of starting from scratch
25 | name = 'condbeach7', -- the name of the experiment
26 | randomize = 1, -- whether to shuffle the data file or not
27 | cropping = 'random', -- options for data augmentation
28 | display_port = 8000, -- port to push graphs
29 | display_id = 1, -- window ID when pushing graphs
30 | mean = {0,0,0},
31 | data_root = '/data/vision/torralba/crossmodal/flickr_videos/',
32 | data_list = '/data/vision/torralba/crossmodal/flickr_videos/scene_extract/lists-full/_b_beach.txt.train',
33 | }
34 |
35 | -- one-line argument parser. parses environment variables to override the defaults
36 | for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
37 | print(opt)
38 |
39 | torch.manualSeed(0)
40 | torch.setnumthreads(1)
41 | torch.setdefaulttensortype('torch.FloatTensor')
42 |
43 | -- if using GPU, select indicated one
44 | if opt.gpu > 0 then
45 | require 'cunn'
46 | cutorch.setDevice(opt.gpu)
47 | end
48 |
49 | -- create data loader
50 | local DataLoader = paths.dofile('data/data.lua')
51 | local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
52 | print("Dataset: " .. opt.dataset, " Size: ", data:size())
53 |
54 | -- define the model
55 | local net
56 | local netD
57 | local mask_net
58 | local motion_net
59 | local static_net
60 | if opt.finetune == '' then -- build network from scratch
61 | net = nn.Sequential()
62 |
63 | local encode_net = nn.Sequential()
64 | encode_net:add(nn.SpatialConvolution(3,128, 4,4, 2,2, 1,1))
65 | encode_net:add(nn.ReLU(true))
66 | encode_net:add(nn.SpatialConvolution(128,256, 4,4, 2,2, 1,1))
67 | encode_net:add(nn.SpatialBatchNormalization(256,1e-3)):add(nn.ReLU(true))
68 | encode_net:add(nn.SpatialConvolution(256,512, 4,4, 2,2, 1,1))
69 | encode_net:add(nn.SpatialBatchNormalization(512,1e-3)):add(nn.ReLU(true))
70 | encode_net:add(nn.SpatialConvolution(512,1024, 4,4, 2,2, 1,1))
71 | encode_net:add(nn.SpatialBatchNormalization(1024,1e-3)):add(nn.ReLU(true))
72 | net:add(encode_net)
73 |
74 | static_net = nn.Sequential()
75 | static_net:add(nn.SpatialFullConvolution(1024, 512, 4,4, 2,2, 1,1))
76 | static_net:add(nn.SpatialBatchNormalization(512)):add(nn.ReLU(true))
77 | static_net:add(nn.SpatialFullConvolution(512, 256, 4,4, 2,2, 1,1))
78 | static_net:add(nn.SpatialBatchNormalization(256)):add(nn.ReLU(true))
79 | static_net:add(nn.SpatialFullConvolution(256, 128, 4,4, 2,2, 1,1))
80 | static_net:add(nn.SpatialBatchNormalization(128)):add(nn.ReLU(true))
81 | static_net:add(nn.SpatialFullConvolution(128, 3, 4,4, 2,2, 1,1))
82 | static_net:add(nn.Tanh())
83 |
84 | local net_video = nn.Sequential()
85 | net_video:add(nn.View(-1, 1024, 1, 4, 4))
86 | net_video:add(nn.VolumetricFullConvolution(1024, 1024, 2,1,1))
87 | net_video:add(nn.VolumetricBatchNormalization(1024)):add(nn.ReLU(true))
88 | net_video:add(nn.VolumetricFullConvolution(1024, 512, 4,4,4, 2,2,2, 1,1,1))
89 | net_video:add(nn.VolumetricBatchNormalization(512)):add(nn.ReLU(true))
90 | net_video:add(nn.VolumetricFullConvolution(512, 256, 4,4,4, 2,2,2, 1,1,1))
91 | net_video:add(nn.VolumetricBatchNormalization(256)):add(nn.ReLU(true))
92 | net_video:add(nn.VolumetricFullConvolution(256, 128, 4,4,4, 2,2,2, 1,1,1))
93 | net_video:add(nn.VolumetricBatchNormalization(128)):add(nn.ReLU(true))
94 |
95 | local mask_out = nn.VolumetricFullConvolution(128,1, 4,4,4, 2,2,2, 1,1,1)
96 | mask_net = nn.Sequential():add(mask_out):add(nn.Sigmoid())
97 | gen_net = nn.Sequential():add(nn.VolumetricFullConvolution(128,3, 4,4,4, 2,2,2, 1,1,1)):add(nn.Tanh())
98 | net_video:add(nn.ConcatTable():add(gen_net):add(mask_net))
99 |
100 | -- [1] is generated video, [2] is mask, and [3] is static
101 | net:add(nn.ConcatTable():add(net_video):add(static_net)):add(nn.FlattenTable())
102 |
103 | -- video .* mask (with repmat on mask)
104 | motion_net = nn.Sequential():add(nn.ConcatTable():add(nn.SelectTable(1))
105 | :add(nn.Sequential():add(nn.SelectTable(2))
106 | :add(nn.Squeeze())
107 | :add(nn.Replicate(3, 2)))) -- for color chan
108 | :add(nn.CMulTable())
109 |
110 | -- static .* (1-mask) (then repmatted)
111 | local sta_part = nn.Sequential():add(nn.ConcatTable():add(nn.Sequential():add(nn.SelectTable(3))
112 | :add(nn.Replicate(opt.frameSize, 3))) -- for time
113 | :add(nn.Sequential():add(nn.SelectTable(2))
114 | :add(nn.Squeeze())
115 | :add(nn.MulConstant(-1))
116 | :add(nn.AddConstant(1))
117 | :add(nn.Replicate(3, 2)))) -- for color chan
118 | :add(nn.CMulTable())
119 |
120 | net:add(nn.ConcatTable():add(motion_net):add(sta_part)):add(nn.CAddTable())
121 |
122 | netD = nn.Sequential()
123 |
124 | netD:add(nn.VolumetricConvolution(3,128, 4,4,4, 2,2,2, 1,1,1))
125 | netD:add(nn.LeakyReLU(0.2, true))
126 | netD:add(nn.VolumetricConvolution(128,256, 4,4,4, 2,2,2, 1,1,1))
127 | netD:add(nn.VolumetricBatchNormalization(256,1e-3)):add(nn.LeakyReLU(0.2, true))
128 | netD:add(nn.VolumetricConvolution(256,512, 4,4,4, 2,2,2, 1,1,1))
129 | netD:add(nn.VolumetricBatchNormalization(512,1e-3)):add(nn.LeakyReLU(0.2, true))
130 | netD:add(nn.VolumetricConvolution(512,1024, 4,4,4, 2,2,2, 1,1,1))
131 | netD:add(nn.VolumetricBatchNormalization(1024,1e-3)):add(nn.LeakyReLU(0.2, true))
132 | netD:add(nn.VolumetricConvolution(1024,2, 2,4,4, 1,1,1, 0,0,0))
133 | netD:add(nn.View(2):setNumInputDims(4))
134 |
135 | -- initialize the model
136 | local function weights_init(m)
137 | local name = torch.type(m)
138 | if name:find('Convolution') then
139 | m.weight:normal(0.0, 0.01)
140 | m.bias:fill(0)
141 | elseif name:find('BatchNormalization') then
142 | if m.weight then m.weight:normal(1.0, 0.02) end
143 | if m.bias then m.bias:fill(0) end
144 | end
145 | end
146 | net:apply(weights_init) -- loop over all layers, applying weights_init
147 | netD:apply(weights_init)
148 |
149 | mask_out.weight:normal(0, 0.01)
150 | mask_out.bias:fill(0)
151 |
152 | else -- load in existing network
153 | print('loading ' .. opt.finetune)
154 | net = torch.load(opt.finetune)
155 | end
156 |
157 | print('Generator:')
158 | print(net)
159 | print('Discriminator:')
160 | print(netD)
161 |
162 | -- define the loss
163 | local criterion = nn.CrossEntropyCriterion()
164 | local criterionReg = nn.AbsCriterion()
165 | local real_label = 1
166 | local fake_label = 2
167 |
168 | -- create the data placeholders
169 | local input = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)
170 | local target = torch.Tensor(opt.batchSize, 3, opt.frameSize, opt.fineSize, opt.fineSize)
171 | local video = torch.Tensor(opt.batchSize, 3, opt.frameSize, opt.fineSize, opt.fineSize)
172 | local label = torch.Tensor(opt.batchSize)
173 | local err, errD, errReg
174 |
175 | -- timers to roughly profile performance
176 | local tm = torch.Timer()
177 | local data_tm = torch.Timer()
178 |
179 | -- ship everything to GPU if needed
180 | if opt.gpu > 0 then
181 | input = input:cuda()
182 | target = target:cuda()
183 | video = video:cuda()
184 | label = label:cuda()
185 | net:cuda()
186 | netD:cuda()
187 | criterion:cuda()
188 | criterionReg:cuda()
189 | end
190 |
191 | -- convert to cudnn if needed
192 | -- if this errors, you can disable cudnn here, but training will be slightly slower
193 | if opt.gpu > 0 and opt.cudnn > 0 then
194 | require 'cudnn'
195 | net = cudnn.convert(net, cudnn)
196 | netD = cudnn.convert(netD, cudnn)
197 | end
198 |
199 | -- get a vector of parameters
200 | local parameters, gradParameters = net:getParameters()
201 | local parametersD, gradParametersD = netD:getParameters()
202 |
203 | -- show graphics
204 | disp = require 'display'
205 | disp.url = 'http://localhost:' .. opt.display_port .. '/events'
206 |
207 | -- optimization closure
208 | -- the optimizer will call this function to get the gradients
209 | local data_im,data_label
210 | local fDx = function(x)
211 | gradParametersD:zero()
212 |
213 | -- fetch data
214 | data_tm:reset(); data_tm:resume()
215 | data_im = data:getBatch()
216 | data_tm:stop()
217 |
218 | -- ship to GPU
219 | input:copy(data_im:select(3,1))
220 | target:copy(data_im)
221 | video:copy(data_im)
222 | label:fill(real_label)
223 |
224 | -- forward/backwards real examples
225 | local output = netD:forward(video)
226 | errD = criterion:forward(output, label)
227 | local df_do = criterion:backward(output, label)
228 | netD:backward(video, df_do)
229 |
230 | -- generate fake examples
231 | local fake = net:forward(input)
232 | video:copy(fake)
233 | label:fill(fake_label)
234 |
235 | -- forward/backwards fake examples
236 | local output = netD:forward(video)
237 | errD = errD + criterion:forward(output, label)
238 | local df_do = criterion:backward(output, label)
239 | netD:backward(video, df_do)
240 |
241 | errD = errD / 2
242 |
243 | return errD, gradParametersD
244 | end
245 |
246 | local fx = function(x)
247 | gradParameters:zero()
248 |
249 | label:fill(real_label)
250 | local output = netD.output
251 | err = criterion:forward(output, label)
252 | local df_do = criterion:backward(output, label)
253 | local df_dg = netD:updateGradInput(video, df_do)
254 |
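   -- L1 reconstruction term on the first frame: pushes the generated video
   -- to start at the conditioning image, weighted by opt.lambda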
255 | errReg = criterionReg:forward(video:select(3,1), target:select(3,1)) * opt.lambda
256 | local df_reg = criterionReg:backward(video:select(3,1), target:select(3,1)) * opt.lambda
257 |
258 | df_dg[{ {}, {}, 1, {}, {} }]:add(df_reg)
259 |
260 | net:backward(input, df_dg)
261 |
262 | return err + errReg, gradParameters
263 | end
264 |
265 | local counter = 0
266 | local history = {}
267 |
268 | -- parameters for the optimization
269 | -- very important: you must only create this table once!
270 | -- the optimizer will add fields to this table (such as momentum)
271 | local optimState = {
272 | learningRate = opt.lr,
273 | beta1 = opt.beta1,
274 | }
275 | local optimStateD = {
276 | learningRate = opt.lr,
277 | beta1 = opt.beta1,
278 | }
279 |
280 | -- train main loop
281 | for epoch = 1,opt.niter do -- for each epoch
282 | for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do -- for each mini-batch
283 | collectgarbage() -- necessary sometimes
284 |
285 | tm:reset()
286 |
287 | -- do one iteration
288 | optim.adam(fDx, parametersD, optimStateD)
289 | optim.adam(fx, parameters, optimState)
290 |
291 | if counter % 10 == 0 then
292 | table.insert(history, {counter, err, errD, errReg})
293 | disp.plot(history, {win=opt.display_id+1, title=opt.name, labels = {"iteration", "err", "errD", "errR"}})
294 | end
295 |
296 | if counter % 100 == 0 then
297 | local vis = net.output:float()
298 | local vis_tab = {}
299 | for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, {}, i, {}, {} }]) end
300 | disp.image(torch.cat(vis_tab, 3), {win=opt.display_id, title=(opt.name .. ' gen')})
301 |
302 | local vis = motion_net.output:float()
303 | local vis_tab = {}
304 | for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, {}, i, {}, {} }]) end
305 | disp.image(torch.cat(vis_tab, 3), {win=opt.display_id+3, title=(opt.name .. ' motion')})
306 |
307 | local vis = static_net.output:float()
308 | disp.image(vis, {win=opt.display_id+4, title=(opt.name .. ' static')})
309 |
310 | local vis = mask_net.output:float():squeeze()
311 | local vis_lo = vis:min()
312 | local vis_hi = vis:max()
313 | local vis_tab = {}
314 | for i=1,opt.frameSize do table.insert(vis_tab, vis[{ {}, i, {}, {} }]) end
315 | disp.image(torch.cat(vis_tab, 2), {win=opt.display_id+2, title=(opt.name .. ' mask ' .. string.format('%.2f %.2f', vis_lo, vis_hi))})
316 | end
317 | counter = counter + 1
318 |
319 | print(('%s: Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f '
320 | .. ' Err: %.4f ErrD: %.4f ErrR: %.4f'):format(
321 | opt.name, epoch, ((i-1) / opt.batchSize),
322 | math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
323 | tm:time().real, data_tm:time().real,
324 | err and err or -1, errD and errD or -1, errReg and errReg or -1))
325 |
326 | -- save checkpoint
327 | -- :clearState() compacts the model so it takes less space on disk
328 | if counter % opt.saveIter == 0 then
329 | print('Saving ' .. opt.name .. '/iter' .. counter .. '_net.t7')
330 | paths.mkdir('checkpoints')
331 | paths.mkdir('checkpoints/' .. opt.name)
332 | torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_net.t7', net:clearState())
333 | torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_netD.t7', netD:clearState())
334 | torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_history.t7', history)
335 | end
336 | end
337 |
338 | -- decay the learning rate, if requested
339 | if opt.lr_decay > 0 and epoch % opt.lr_decay == 0 then
340 | opt.lr = opt.lr / 10
341 | print('Decreasing learning rate to ' .. opt.lr)
342 |
343 | -- create new optimState to reset momentum
344 | optimState = {
345 | learningRate = opt.lr,
346 | beta1 = opt.beta1,
347 | }
348 | end
349 | end
350 |
--------------------------------------------------------------------------------
/main_ucf.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | require 'optim'
4 |
5 | -- to specify these at runtime, you can do, e.g.:
6 | -- $ lr=0.001 th main.lua
7 | opt = {
8 |   dataset = 'video3',       -- which data loader to use (see data/data.lua)
9 | nThreads = 16, -- how many threads to pre-fetch data
10 | batchSize = 256, -- self-explanatory
11 | loadSize = 256, -- when loading images, resize first to this size
12 | fineSize = 64, -- crop this size from the loaded image
13 | frameSize = 32,
14 | lr = 0.0002, -- learning rate
15 |   lr_decay = 1000,          -- how often to decay learning rate (in epochs)
16 | lambda = 0.1,
17 | beta1 = 0.5, -- momentum term for adam
18 | meanIter = 0, -- how many iterations to retrieve for mean estimation
19 |   saveIter = 100,     -- write a checkpoint at this interval
20 | niter = 100, -- number of iterations through dataset
21 |   max_iter = 1000,    -- stop training after this many iterations
22 | ntrain = math.huge, -- how big one epoch should be
23 | gpu = 1, -- which GPU to use; consider using CUDA_VISIBLE_DEVICES instead
24 | cudnn = 1, -- whether to use cudnn or not
25 | name = 'ucf101', -- the name of the experiment
26 | randomize = 1, -- whether to shuffle the data file or not
27 | cropping = 'random', -- options for data augmentation
28 | display_port = 8001, -- port to push graphs
29 | display_id = 1, -- window ID when pushing graphs
30 | mean = {0,0,0},
31 | data_root = '/data/vision/torralba/hallucination/UCF101/frames-stable-nofail/videos',
32 | data_list = '/data/vision/torralba/hallucination/UCF101/gan/train.txt'
33 | }
34 |
35 | -- one-line argument parser. parses environment variables to override the defaults
36 | for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
37 | print(opt)
38 |
39 | torch.manualSeed(0)
40 | torch.setnumthreads(1)
41 | torch.setdefaulttensortype('torch.FloatTensor')
42 |
43 | -- if using GPU, select indicated one
44 | if opt.gpu > 0 then
45 | require 'cunn'
46 | require 'cudnn'
47 | cutorch.setDevice(opt.gpu)
48 | end
49 |
50 | -- create data loader
51 | local DataLoader = paths.dofile('data/data.lua')
52 | local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
53 | print("Dataset: " .. opt.dataset, " Size: ", data:size())
54 |
55 | -- define the model
56 | local net = torch.load("checkpoints/all100/iter95000_netD.t7")
57 | net:remove(#net.modules)
58 | net:remove(#net.modules)
59 | --for i=1,#net.modules do net.modules[i].accGradParameters = function() end end -- freeze all but last
60 | net:add(nn.VolumetricDropout(0.5))
61 | net:add(nn.VolumetricConvolution(512,101, 2,4,4, 1,1,1, 0,0,0))
62 | net:add(nn.View(-1, 101))
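-- the two removed modules were the discriminator's final 2-way real/fake
-- classifier (a VolumetricConvolution and a View); dropout and a fresh
-- 101-way convolution are added for UCF101 action classification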
63 |
64 | -- net:reset() -- random weights?
65 |
66 | print('Net:')
67 | print(net)
68 |
69 | -- define the loss
70 | local criterion = nn.CrossEntropyCriterion()
71 |
72 | -- create the data placeholders
73 | local input = torch.Tensor(opt.batchSize, 3, opt.frameSize, opt.fineSize, opt.fineSize)
74 | local label = torch.Tensor(opt.batchSize)
75 | local err
76 | local accuracy
77 |
78 | -- timers to roughly profile performance
79 | local tm = torch.Timer()
80 | local data_tm = torch.Timer()
81 |
82 | -- ship everything to GPU if needed
83 | if opt.gpu > 0 then
84 | input = input:cuda()
85 | label = label:cuda()
86 | net:cuda()
87 | criterion:cuda()
88 | end
89 |
90 | -- convert to cudnn if needed
91 | -- if this errors, you can disable cudnn here, but training will be slightly slower
92 | if opt.gpu > 0 and opt.cudnn > 0 then
93 | net = cudnn.convert(net, cudnn)
94 | end
95 |
96 | -- get a vector of parameters
97 | local parameters, gradParameters = net:getParameters()
98 |
99 | -- show graphics
100 | disp = require 'display'
101 | disp.url = 'http://localhost:' .. opt.display_port .. '/events'
102 |
103 | -- optimization closure
104 | -- the optimizer will call this function to get the gradients
105 | local data_im,data_label
106 | local fx = function(x)
107 | gradParameters:zero()
108 |
109 | -- fetch data
110 | data_tm:reset(); data_tm:resume()
111 | data_im,data_label = data:getBatch()
112 | data_tm:stop()
113 |
114 | input:copy(data_im)
115 | label:copy(data_label)
116 |
117 | local output = net:forward(input)
118 | err = criterion:forward(output, label)
119 | local df_do = criterion:backward(output, label)
120 | net:backward(input, df_do)
121 |
122 | _,pred_cat = torch.max(output:float(), 2)
123 | accuracy = pred_cat:float():eq(data_label):float():mean()
124 |
125 | return err, gradParameters
126 | end
127 |
128 | local counter = 0
129 | local history = {}
130 |
131 | -- parameters for the optimization
132 | -- very important: you must only create this table once!
133 | -- the optimizer will add fields to this table (such as momentum)
134 | local optimState = {
135 | learningRate = opt.lr,
136 | beta1 = opt.beta1,
137 | }
138 |
139 | -- train main loop
140 | for epoch = 1,opt.niter do -- for each epoch
141 | for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do -- for each mini-batch
142 | collectgarbage() -- necessary sometimes
143 |
144 | tm:reset()
145 |
146 | -- do one iteration
147 | optim.adam(fx, parameters, optimState)
148 |
149 | if counter % 10 == 0 then
150 |       table.insert(history, {counter, err})
151 | disp.plot(history, {win=opt.display_id+1, title=opt.name, labels = {"iteration", "err"}})
152 | end
153 |
154 | counter = counter + 1
155 |
156 | print(('%s: Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f '
157 | .. ' Err: %.4f Accuracy: %.4f'):format(
158 | opt.name, epoch, ((i-1) / opt.batchSize),
159 | math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
160 | tm:time().real, data_tm:time().real,
161 | err and err or -1, accuracy))
162 |
163 | -- save checkpoint
164 | -- :clearState() compacts the model so it takes less space on disk
165 | if counter % opt.saveIter == 0 then
166 | print('Saving ' .. opt.name .. '/iter' .. counter .. '_net.t7')
167 | paths.mkdir('checkpoints')
168 | paths.mkdir('checkpoints/' .. opt.name)
169 | torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_net.t7', net:clearState())
170 | torch.save('checkpoints/' .. opt.name .. '/iter' .. counter .. '_history.t7', history)
171 | end
172 | end
173 |
174 | -- decay the learning rate, if requested
175 | if opt.lr_decay > 0 and epoch % opt.lr_decay == 0 then
176 | opt.lr = opt.lr / 10
177 | print('Decreasing learning rate to ' .. opt.lr)
178 |
179 | -- create new optimState to reset momentum
180 | optimState = {
181 | learningRate = opt.lr,
182 | beta1 = opt.beta1,
183 | }
184 | end
185 |
186 | if counter > opt.max_iter then
187 | break
188 | end
189 | end
190 |
--------------------------------------------------------------------------------