├── README.md
├── act_norm.py
├── array_util.py
├── coupling.py
├── dataloader.py
├── glow.py
├── inv_conv.py
├── modules.py
├── optim_util.py
├── sample_100.py
├── shell_util.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch-VideoFlow
2 | PyTorch implementation of the VideoFlow paper: https://arxiv.org/abs/1903.01434
3 | Glow code adapted from https://github.com/chaiyujin/glow-pytorch
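
## Usage (sketch)

Training is driven by [sacred](https://github.com/IDSIA/sacred) via `@ex.automain`; the configuration lives in the `@ex.config` block of `train.py`. Note that the checkpoint/resume paths in `train.py` and `sample_100.py` are hard-coded and need to be edited before running:

```
# assumes torch, torchvision, tensorboardX, sacred and tqdm are installed
python train.py        # train on the synthetic MovingObjects dataset
python sample_100.py   # roll out 100 two-frame samples from a saved snapshot
```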
4 | Work in progress! 5 | -------------------------------------------------------------------------------- /act_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from array_util import mean_dim 5 | 6 | 7 | class ActNorm(nn.Module): 8 | """Activation normalization for 2D inputs. 9 | 10 | The bias and scale get initialized using the mean and variance of the 11 | first mini-batch. After the init, bias and scale are trainable parameters. 12 | 13 | Adapted from: 14 | > https://github.com/openai/glow 15 | """ 16 | def __init__(self, num_features, scale=1., return_ldj=False): 17 | super(ActNorm, self).__init__() 18 | self.register_buffer('is_initialized', torch.zeros(1)) 19 | self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 20 | self.logs = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 21 | 22 | self.num_features = num_features 23 | self.scale = float(scale) 24 | self.eps = 1e-6 25 | self.return_ldj = return_ldj 26 | 27 | def initialize_parameters(self, x): 28 | if not self.training: 29 | return 30 | 31 | with torch.no_grad(): 32 | bias = -mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True) 33 | v = mean_dim((x.clone() + bias) ** 2, dim=[0, 2, 3], keepdims=True) 34 | logs = (self.scale / (v.sqrt() + self.eps)).log() 35 | self.bias.data.copy_(bias.data) 36 | self.logs.data.copy_(logs.data) 37 | self.is_initialized += 1. 38 | 39 | def _center(self, x, reverse=False): 40 | if reverse: 41 | return x - self.bias 42 | else: 43 | return x + self.bias 44 | 45 | def _scale(self, x, sldj, reverse=False): 46 | logs = self.logs 47 | if reverse: 48 | x = x * logs.mul(-1).exp() 49 | else: 50 | x = x * logs.exp() 51 | 52 | if sldj is not None: 53 | ldj = logs.sum() * x.size(2) * x.size(3) 54 | if reverse: 55 | sldj = sldj - ldj 56 | else: 57 | sldj = sldj + ldj 58 | 59 | return x, sldj 60 | 61 | def forward(self, x, ldj=None, reverse=False): 62 | if not self.is_initialized: 63 | self.initialize_parameters(x) 64 | 65 | if reverse: 66 | x, ldj = self._scale(x, ldj, reverse) 67 | x = self._center(x, reverse) 68 | else: 69 | x = self._center(x, reverse) 70 | x, ldj = self._scale(x, ldj, reverse) 71 | 72 | if self.return_ldj: 73 | return x, ldj 74 | 75 | return x 76 | -------------------------------------------------------------------------------- /array_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mean_dim(tensor, dim=None, keepdims=False): 5 | """Take the mean along multiple dimensions. 6 | 7 | Args: 8 | tensor (torch.Tensor): Tensor of values to average. 9 | dim (list): List of dimensions along which to take the mean. 10 | keepdims (bool): Keep dimensions rather than squeezing. 11 | 12 | Returns: 13 | mean (torch.Tensor): New tensor of mean value(s). 
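    Example:
        For input of shape (N, C, H, W), mean_dim(tensor, dim=[0, 2, 3],
        keepdims=True) returns shape (1, C, 1, 1), as used by ActNorm.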
14 | """ 15 | if dim is None: 16 | return tensor.mean() 17 | else: 18 | if isinstance(dim, int): 19 | dim = [dim] 20 | dim = sorted(dim) 21 | for d in dim: 22 | tensor = tensor.mean(dim=d, keepdim=True) 23 | if not keepdims: 24 | for i, d in enumerate(dim): 25 | tensor.squeeze_(d-i) 26 | return tensor 27 | -------------------------------------------------------------------------------- /coupling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from act_norm import ActNorm 6 | 7 | 8 | class AddCoupling(nn.Module): 9 | def __init__(self, in_channels, mid_channels, use_act_norm=True): 10 | super(AddCoupling, self).__init__() 11 | self.nn = NN(in_channels, mid_channels, in_channels, use_act_norm) 12 | # self.scale = nn.Parameter(torch.ones(in_channels, 1, 1)) 13 | 14 | def forward(self, x, ldj, reverse=False): 15 | x_change, x_id = x.chunk(2, dim=1) 16 | shift = self.nn(x_id) 17 | if not reverse: 18 | x_change = x_change + shift 19 | 20 | else: 21 | x_change = x_change - shift #z2 22 | 23 | return torch.cat([x_id, x_change], dim=1), ldj+0.0 24 | 25 | 26 | class AffineCoupling(nn.Module): 27 | """Affine coupling layer originally used in Real NVP and described by Glow. 28 | 29 | Note: The official Glow implementation (https://github.com/openai/glow) 30 | uses a different affine coupling formulation than described in the paper. 31 | This implementation follows the paper and Real NVP. 32 | 33 | Args: 34 | in_channels (int): Number of channels in the input. 35 | mid_channels (int): Number of channels in the intermediate activation 36 | in NN. 37 | """ 38 | def __init__(self, in_channels, mid_channels): 39 | super(AffineCoupling, self).__init__() 40 | self.nn = NN(in_channels, mid_channels, 2 * in_channels) 41 | self.scale = nn.Parameter(torch.ones(in_channels, 1, 1)) 42 | 43 | def forward(self, x, ldj, reverse=False): 44 | x_change, x_id = x.chunk(2, dim=1) 45 | 46 | st = self.nn(x_id) 47 | s, t = st[:, 0::2, ...], st[:, 1::2, ...] 48 | s = self.scale * torch.tanh(s) 49 | 50 | # Scale and translate 51 | if reverse: 52 | x_change = x_change * s.mul(-1).exp() - t 53 | ldj = ldj - s.flatten(1).sum(-1) 54 | else: 55 | x_change = (x_change + t) * s.exp() 56 | ldj = ldj + s.flatten(1).sum(-1) 57 | 58 | x = torch.cat((x_change, x_id), dim=1) 59 | 60 | return x, ldj 61 | 62 | 63 | class NN(nn.Module): 64 | """Small convolutional network used to compute scale and translate factors. 65 | 66 | Args: 67 | in_channels (int): Number of channels in the input. 68 | mid_channels (int): Number of channels in the hidden activations. 69 | out_channels (int): Number of channels in the output. 70 | use_act_norm (bool): Use activation norm rather than batch norm. 
71 | """ 72 | def __init__(self, in_channels, mid_channels, out_channels, 73 | use_act_norm=False): 74 | super(NN, self).__init__() 75 | norm_fn = ActNorm if use_act_norm else nn.BatchNorm2d 76 | 77 | self.in_norm = norm_fn(in_channels) 78 | self.in_conv = nn.Conv2d(in_channels, mid_channels, 79 | kernel_size=3, padding=1, bias=False) 80 | nn.init.normal_(self.in_conv.weight, 0., 0.05) 81 | 82 | self.mid_norm = norm_fn(mid_channels) 83 | self.mid_conv = nn.Conv2d(mid_channels, mid_channels, 84 | kernel_size=1, padding=0, bias=False) 85 | nn.init.normal_(self.mid_conv.weight, 0., 0.05) 86 | 87 | self.out_norm = norm_fn(mid_channels) 88 | self.out_conv = nn.Conv2d(mid_channels, out_channels, 89 | kernel_size=3, padding=1, bias=True) 90 | nn.init.zeros_(self.out_conv.weight) 91 | nn.init.zeros_(self.out_conv.bias) 92 | 93 | def forward(self, x): 94 | x = self.in_norm(x) 95 | x = F.relu(x) 96 | x = self.in_conv(x) 97 | 98 | x = self.mid_norm(x) 99 | x = F.relu(x) 100 | x = self.mid_conv(x) 101 | 102 | x = self.out_norm(x) 103 | x = F.relu(x) 104 | x = self.out_conv(x) 105 | 106 | return x 107 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | import random, PIL 5 | from PIL import Image, ImageDraw 6 | from torchvision import datasets, transforms 7 | 8 | 9 | class MovingObjects(Dataset): 10 | def __init__(self, mode, transform, my_seed, dummy_len=30000): 11 | print("NOTE: no data normalization and data range is [0,1]") 12 | if mode is "train": 13 | random.seed(my_seed) 14 | torch.manual_seed(my_seed) 15 | np.random.seed(my_seed) 16 | #else: 17 | # random.seed(1234) 18 | # torch.manual_seed(1234) 19 | # np.random.seed(1234) 20 | 21 | # constant speed of 4 pixels 22 | self.seq_len = 3 23 | # 8 possible direction of movement 24 | pix = 8 25 | self.dummy_len = dummy_len 26 | self.deltaxy = [(pix,pix), (pix,0), (pix,-pix), (0,pix), (0,-pix), (-pix,-pix), (-pix,0), (-pix,pix)] 27 | self.shapes = ['circle', 'rectangle', 'polygon'] 28 | self.size_range = [18] #, 26] # range(10, 20) 29 | 30 | self.center_xy = np.array([32, 32]) 31 | self.transform = transform 32 | 33 | def __len__(self): 34 | return self.dummy_len 35 | 36 | def __getitem__(self, idx): 37 | """idx is a dummy value""" 38 | # pick a shape 39 | shape = random.choice(self.shapes) 40 | 41 | # pick a direction 42 | deltax, deltay = random.choice(self.deltaxy) 43 | 44 | # pick size 45 | size = random.choice(self.size_range) 46 | 47 | # pick color 48 | r = random.choice(range(0,256,200)) 49 | g = random.choice(range(0,256,200)) 50 | b = random.choice(range(0,256,200)) 51 | 52 | frames = [] 53 | img1 = PIL.Image.new(mode='RGB', size=(64,64), color='gray') 54 | if shape is 'circle': 55 | for i in range(self.seq_len): 56 | 57 | c_img = img1.copy() 58 | c_draw = ImageDraw.Draw(c_img) 59 | c1 = tuple(self.center_xy - size/2 + np.array([deltax, deltay])*i) 60 | c2 = tuple(self.center_xy + size/2 + np.array([deltax, deltay])*i) 61 | 62 | c_draw.ellipse([c1, c2], fill=(r ,g , b)) 63 | frames.append(self.transform(c_img)) 64 | 65 | elif shape is 'rectangle': 66 | for i in range(self.seq_len): 67 | 68 | c_img = img1.copy() 69 | c_draw = ImageDraw.Draw(c_img) 70 | c1 = tuple(self.center_xy - size/2 + np.array([deltax, deltay])*i) 71 | c2 = tuple(self.center_xy + size/2 + np.array([deltax, deltay])*i) 72 | 73 | c_draw.rectangle([c1, c2], fill=(r 
74 |                 frames.append(self.transform(c_img))
75 | 
76 |         elif shape == 'polygon':
77 |             for i in range(self.seq_len):
78 |                 c_img = img1.copy()
79 |                 c_draw = ImageDraw.Draw(c_img)
80 | 
81 |                 c1 = tuple(self.center_xy - np.array([0, size/2]) + np.array([deltax, deltay])*i)
82 |                 c2 = tuple(self.center_xy + np.array([size/2, size/3]) + np.array([deltax, deltay])*i)
83 |                 c3 = tuple(self.center_xy + np.array([-size/2, size/3]) + np.array([deltax, deltay])*i)
84 | 
85 |                 c_draw.polygon([c1, c2, c3], fill=(r, g, b))
86 | 
87 |                 # change to tensor
88 |                 frames.append(self.transform(c_img))
89 | 
90 |         else:
91 |             raise NotImplementedError()
92 | 
93 |         frames_tensor = torch.stack(frames, dim=0)
94 |         return frames_tensor
95 | 
96 | 
97 | class MovingMNIST(object):
98 | 
99 |     """Data Handler that creates Bouncing MNIST dataset on the fly."""
100 | 
101 |     def __init__(self, train, data_root, seq_len=20, num_digits=2, image_size=64, deterministic=True):
102 |         path = data_root
103 |         self.seq_len = seq_len
104 |         self.num_digits = num_digits
105 |         self.image_size = image_size
106 |         self.step_length = 0.1
107 |         self.digit_size = 32
108 |         self.deterministic = deterministic
109 |         self.seed_is_set = False  # multi threaded loading
110 |         self.channels = 1
111 | 
112 |         self.data = datasets.MNIST(
113 |             path,
114 |             train=train,
115 |             download=True,
116 |             transform=transforms.Compose(
117 |                 [transforms.Resize(self.digit_size),  # transforms.Scale is deprecated
118 |                  transforms.ToTensor()]))
119 | 
120 |         self.N = len(self.data)
121 | 
122 |     def set_seed(self, seed):
123 |         if not self.seed_is_set:
124 |             self.seed_is_set = True
125 |             np.random.seed(seed)
126 | 
127 |     def __len__(self):
128 |         return self.N
129 | 
130 |     def __getitem__(self, index):
131 |         self.set_seed(index)
132 |         image_size = self.image_size
133 |         digit_size = self.digit_size
134 |         x = np.zeros((self.seq_len,
135 |                       image_size,
136 |                       image_size,
137 |                       self.channels),
138 |                      dtype=np.float32)
139 |         for n in range(self.num_digits):
140 |             idx = np.random.randint(self.N)
141 |             digit, _ = self.data[idx]
142 | 
143 |             sx = np.random.randint(image_size-digit_size)
144 |             sy = np.random.randint(image_size-digit_size)
145 |             dx = np.random.randint(-4, 5)
146 |             dy = np.random.randint(-4, 5)
147 |             for t in range(self.seq_len):
148 |                 if sy < 0:
149 |                     sy = 0
150 |                     if self.deterministic:
151 |                         dy = -dy
152 |                     else:
153 |                         dy = np.random.randint(1, 5)
154 |                         dx = np.random.randint(-4, 5)
155 |                 elif sy >= image_size-digit_size:
156 |                     sy = image_size-digit_size-1
157 |                     if self.deterministic:
158 |                         dy = -dy
159 |                     else:
160 |                         dy = np.random.randint(-4, 0)
161 |                         dx = np.random.randint(-4, 5)
162 | 
163 |                 if sx < 0:
164 |                     sx = 0
165 |                     if self.deterministic:
166 |                         dx = -dx
167 |                     else:
168 |                         dx = np.random.randint(1, 5)
169 |                         dy = np.random.randint(-4, 5)
170 |                 elif sx >= image_size-digit_size:
171 |                     sx = image_size-digit_size-1
172 |                     if self.deterministic:
173 |                         dx = -dx
174 |                     else:
175 |                         dx = np.random.randint(-4, 0)
176 |                         dy = np.random.randint(-4, 5)
177 | 
178 |                 x[t, sy:sy+digit_size, sx:sx+digit_size, 0] += digit.numpy().squeeze()
179 |                 sy += dy
180 |                 sx += dx
181 | 
182 |         x[x > 1] = 1.  # overlapping digits are summed, so clamp back to [0, 1]
183 |         # t, h, w, c --> t, c, h, w
184 |         return x.transpose([0,3,1,2])
185 | 
--------------------------------------------------------------------------------
/glow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from act_norm import ActNorm
6 | from coupling import *
7 | from inv_conv import InvConv
8 | 
9 | 
10 | class _Glow(nn.Module):
11 |     """One level of Glow: a chain of `num_steps` flow steps at a fixed scale.
12 | 
13 |     Args:
14 |         in_channels (int): Number of channels in the input.
15 |         mid_channels (int): Number of channels in hidden layers of each step.
16 |         num_levels (int): Unused; kept only for interface compatibility.
17 |         num_steps (int): Number of steps of flow for each level.
18 |     """
19 |     def __init__(self, in_channels, mid_channels, num_steps, num_levels=None):
20 |         super(_Glow, self).__init__()
21 |         self.steps = nn.ModuleList([_FlowStep(in_channels=in_channels,
22 |                                               mid_channels=mid_channels)
23 |                                     for _ in range(num_steps)])
24 | 
25 |     def forward(self, x, sldj, reverse=False):
26 |         if not reverse:
27 |             for step in self.steps:
28 |                 x, sldj = step(x, sldj, reverse)
29 | 
30 |         if reverse:
31 |             for step in reversed(self.steps):
32 |                 x, sldj = step(x, sldj, reverse)
33 | 
34 |         return x, sldj
35 | 
36 | 
37 | class _FlowStep(nn.Module):
38 |     def __init__(self, in_channels, mid_channels, coupling='affine', use_act_norm_in_coupling=True):
39 |         super(_FlowStep, self).__init__()
40 | 
41 |         # Activation normalization, invertible 1x1 convolution, affine coupling
42 |         self.norm = ActNorm(in_channels, return_ldj=True)
43 |         self.conv = InvConv(in_channels)
44 |         if coupling == "additive":
45 |             self.coup = AddCoupling(in_channels // 2, mid_channels, use_act_norm_in_coupling)
46 |         else:
47 |             self.coup = AffineCoupling(in_channels // 2, mid_channels)
48 | 
49 |     def forward(self, x, sldj=None, reverse=False):
50 |         if reverse:
51 |             x, sldj = self.coup(x, sldj, reverse)
52 |             x, sldj = self.conv(x, sldj, reverse)
53 |             x, sldj = self.norm(x, sldj, reverse)
54 |         else:
55 |             x, sldj = self.norm(x, sldj, reverse)
56 |             x, sldj = self.conv(x, sldj, reverse)
57 |             x, sldj = self.coup(x, sldj, reverse)
58 | 
59 |         return x, sldj
60 | 
61 | 
62 | class PreProcess(nn.Module):
63 |     def __init__(self, bound=0.9):
64 |         super(PreProcess, self).__init__()
65 |         self.register_buffer('bounds', torch.tensor([bound], dtype=torch.float32))
66 | 
67 |     def forward(self, x):
68 |         """Dequantize the input image `x` and convert to logits.
69 | 
70 |         See Also:
71 |             - Dequantization: https://arxiv.org/abs/1511.01844, Section 3.1
72 |             - Modeling logits: https://arxiv.org/abs/1605.08803, Section 4.1
73 | 
74 |         Args:
75 |             x (torch.Tensor): Input image.
76 | 
77 |         Returns:
78 |             y (torch.Tensor): Dequantized logits of `x`.
79 |         """
80 |         y = (x * 255. + torch.rand_like(x)) / 256.
81 |         y = (2 * y - 1) * self.bounds
82 |         y = (y + 1) / 2
83 |         y = y.log() - (1. - y).log()
84 | 
85 |         # Save log-determinant of Jacobian of initial transform
86 |         ldj = F.softplus(y) + F.softplus(-y) \
87 |             - F.softplus((1. - self.bounds).log() - self.bounds.log())
88 |         sldj = ldj.flatten(1).sum(-1)
89 |         #sldj = torch.zeros(x.shape[0]).cuda()
90 | 
91 |         return y, sldj
92 | 
93 | 
94 | def squeeze(x, reverse=False):
95 |     """Trade spatial extent for channels. In forward direction, convert each
96 |     1x2x2 volume of input into a 4x1x1 volume of output.
97 | 
98 |     Args:
99 |         x (torch.Tensor): Input to squeeze or unsqueeze.
100 |         reverse (bool): Reverse the operation, i.e., unsqueeze.
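        Example (shape only): a (B, 3, 64, 64) input becomes (B, 12, 32, 32)
            in the forward direction; reverse=True inverts this exactly.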
101 | 
102 |     Returns:
103 |         x (torch.Tensor): Squeezed or unsqueezed tensor.
104 |     """
105 |     b, c, h, w = x.size()
106 |     if reverse:
107 |         # Unsqueeze
108 |         x = x.view(b, c // 4, 2, 2, h, w)
109 |         x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
110 |         x = x.view(b, c // 4, h * 2, w * 2)
111 |     else:
112 |         # Squeeze
113 |         x = x.view(b, c, h // 2, 2, w // 2, 2)
114 |         x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
115 |         x = x.view(b, c * 2 * 2, h // 2, w // 2)
116 | 
117 |     return x
118 | 
119 | if __name__ == "__main__":
120 |     # Smoke test: run one flow level forward on a squeezed random input.
121 |     model = _Glow(in_channels=12, mid_channels=512, num_steps=4)
122 |     test_in = squeeze(torch.rand(1, 3, 64, 64))  # -> (1, 12, 32, 32)
123 |     sldj = torch.zeros(1)
124 |     out, sldj = model(test_in, sldj, reverse=False)
125 |     print(out.shape, sldj.item())
126 | 
--------------------------------------------------------------------------------
/inv_conv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | class InvConv(nn.Module):
8 |     """Invertible 1x1 Convolution for 2D inputs. Originally described in Glow
9 |     (https://arxiv.org/abs/1807.03039). Does not support LU-decomposed version.
10 | 
11 |     Args:
12 |         num_channels (int): Number of channels in the input and output.
13 |     """
14 |     def __init__(self, num_channels):
15 |         super(InvConv, self).__init__()
16 |         self.num_channels = num_channels
17 | 
18 |         # Initialize with a random orthogonal matrix
19 |         w_init = np.random.randn(num_channels, num_channels)
20 |         w_init = np.linalg.qr(w_init)[0].astype(np.float32)
21 |         self.weight = nn.Parameter(torch.from_numpy(w_init))
22 | 
23 |     def forward(self, x, sldj, reverse=False):
24 |         ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
25 | 
26 |         if reverse:
27 |             weight = torch.inverse(self.weight.double()).float()
28 |             sldj = sldj - ldj
29 |         else:
30 |             weight = self.weight
31 |             sldj = sldj + ldj
32 | 
33 |         weight = weight.view(self.num_channels, self.num_channels, 1, 1)
34 |         z = F.conv2d(x, weight)
35 | 
36 |         return z, sldj
37 | 
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | class NNTheta(nn.Module):
8 |     def __init__(self, encoder_ch_in, encoder_mode, num_blocks, h_ch_in=None):
9 |         super(NNTheta, self).__init__()
10 |         self.encoder_mode = encoder_mode
11 | 
12 |         if h_ch_in is not None:
13 |             self.conv1 = nn.Conv2d(in_channels=h_ch_in, out_channels=h_ch_in, kernel_size=1)
14 |             initialize(self.conv1, mode='gaussian')
15 | 
16 |         dilations = [1, 2]
17 |         self.latent_encoder = nn.ModuleList()
18 |         for i in range(num_blocks):
19 |             self.latent_encoder.append(nn.ModuleList(
20 |                 [self.latent_dist_encoder(encoder_ch_in, dilation=d, mode=encoder_mode) for d in dilations]))
21 |         # print("latent encoder:", self.latent_encoder)
22 | 
23 |         if h_ch_in:
24 |             self.conv2 = nn.Conv2d(in_channels=encoder_ch_in, out_channels=encoder_ch_in, kernel_size=1)
25 |             initialize(self.conv2, mode='zeros')
26 |         else:
27 |             self.conv2 = nn.Conv2d(in_channels=encoder_ch_in, out_channels=2 * encoder_ch_in, kernel_size=1)
28 |             initialize(self.conv2, mode='zeros')
29 | 
30 |     def forward(self, z_past, h=None):
31 |         if h is not None:
32 |             h = self.conv1(h)
33 |             encoder_input = torch.cat([z_past, h], dim=1)
34 |         else:
35 |             encoder_input = z_past.clone()
36 | 
37 |         for block in self.latent_encoder:
38 |             parallel_outs = [pb(encoder_input) for pb in block]
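            # Each block applies its two dilated-conv branches (dilations 1 and 2)
            # to the same input in parallel; the branch outputs are then summed
            # together with the input itself, forming a multi-dilation residual update.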
39 | 
40 |             parallel_outs.append(encoder_input)
41 |             encoder_input = sum(parallel_outs)
42 | 
43 |         last_t = self.conv2(encoder_input)
44 |         deltaz_t, logsigma_t = last_t[:, 0::2, ...], last_t[:, 1::2, ...]
45 | 
46 |         # assert deltaz_t.shape == z_past.shape
47 |         logsigma_t = torch.clamp(logsigma_t, min=-15., max=15.)
48 |         mu_t = deltaz_t + z_past
49 |         return mu_t, logsigma_t
50 | 
51 |     @staticmethod
52 |     def latent_dist_encoder(ch_in, dilation, mode):
53 |         if mode != "conv_net":
54 |             raise NotImplementedError("encoder mode '{}' is not implemented".format(mode))
55 |         layer1 = nn.Conv2d(in_channels=ch_in, out_channels=512, kernel_size=(3, 3),
56 |                            dilation=(dilation, dilation), padding=(dilation, dilation))
57 |         initialize(layer1, mode='gaussian')
58 |         layer2 = GATU2D(channels=512)
59 |         layer3 = nn.Conv2d(in_channels=512, out_channels=ch_in, kernel_size=(3, 3),
60 |                            dilation=(dilation, dilation), padding=(dilation, dilation))
61 |         initialize(layer3, mode='zeros')
62 | 
63 |         block = nn.Sequential(*[layer1, nn.ReLU(inplace=True), layer2, layer3])
64 | 
65 |         return block
66 | 
67 | 
68 | class GATU2D(nn.Module):
69 |     """Gated activation unit: tanh(conv(x)) * sigmoid(conv(x))."""
70 |     def __init__(self, channels):
71 |         super(GATU2D, self).__init__()
72 |         self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1)
73 |         initialize(self.conv1, mode='gaussian')
74 |         self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1)
75 |         initialize(self.conv2, mode='gaussian')
76 | 
77 |     def forward(self, x):
78 |         out1 = torch.tanh(self.conv1(x))
79 |         out2 = torch.sigmoid(self.conv2(x))
80 |         return out1 * out2
81 | 
82 | 
83 | class NLLLossVF(nn.Module):
84 |     def __init__(self, k=256):
85 |         super(NLLLossVF, self).__init__()
86 |         self.k = k
87 | 
88 |     def forward(self, gaussian_1, gaussian_2, gaussian_3, z, sldj, input_dim):
89 | 
90 |         prior_ll_l3 = torch.sum(gaussian_3.log_prob(z.l3), [1, 2, 3])
91 |         prior_ll_l2 = torch.sum(gaussian_2.log_prob(z.l2), [1, 2, 3])
92 |         prior_ll_l1 = torch.sum(gaussian_1.log_prob(z.l1), [1, 2, 3])
93 | 
94 |         prior_ll = prior_ll_l1 + prior_ll_l2 + prior_ll_l3 - np.log(self.k) * np.prod(input_dim[1:])
95 |         ll = prior_ll + sldj
96 |         nll = -ll.mean()
97 |         return nll
98 | 
99 | 
100 | class GlowLoss(nn.Module):
101 |     def __init__(self, k=256):
102 |         super(GlowLoss, self).__init__()
103 |         self.k = k
104 | 
105 |     def forward(self, z, sldj):
106 |         prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
107 |         prior_ll = prior_ll.flatten(1).sum(-1) \
108 |             - np.log(self.k) * np.prod(z.size()[1:])
109 |         ll = prior_ll + sldj
110 |         nll = -ll.mean()
111 | 
112 |         return nll
113 | 
114 | 
115 | def initialize(layer, mode):
116 |     if mode == 'gaussian':
117 |         nn.init.normal_(layer.weight, 0., 0.05)
118 |         nn.init.normal_(layer.bias, 0., 0.05)
119 | 
120 |     elif mode == 'zeros':
121 |         nn.init.zeros_(layer.weight)
122 |         nn.init.zeros_(layer.bias)
123 | 
124 |     else:
125 |         raise NotImplementedError("initialization mode '{}' is not implemented".format(mode))
126 | 
--------------------------------------------------------------------------------
/optim_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.utils as utils
5 | 
6 | 
7 | def bits_per_dim(x, nll):
8 |     """Get the bits per dimension implied by using a model with NLL `nll`
9 |     for compressing `x`, assuming each entry can take on `k` discrete values.
10 | 
11 |     Args:
12 |         x (torch.Tensor): Input to the model. Just used for dimensions.
13 |         nll (torch.Tensor): Scalar negative log-likelihood loss tensor.
14 | 
15 |     Returns:
16 |         bpd (torch.Tensor): Bits per dimension implied if compressing `x`.
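    Example:
        For a batch of 3x64x64 inputs, dim = 3 * 64 * 64 = 12288, so
        bpd = nll / (log(2) * 12288).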
17 | """ 18 | dim = np.prod(x.size()[1:]) 19 | bpd = nll / (np.log(2) * dim) 20 | 21 | return bpd 22 | 23 | 24 | def clip_grad_norm(optimizer, max_norm, norm_type=2): 25 | """Clip the norm of the gradients for all parameters under `optimizer`. 26 | 27 | Args: 28 | optimizer (torch.optim.Optimizer): 29 | max_norm (float): The maximum allowable norm of gradients. 30 | norm_type (int): The type of norm to use in computing gradient norms. 31 | """ 32 | for group in optimizer.param_groups: 33 | utils.clip_grad_norm_(group['params'], max_norm, norm_type) 34 | 35 | 36 | def plot_grad_flow(named_parameters): 37 | '''Plots the gradients flowing through different layers in the net during training. 38 | Can be used for checking for possible gradient vanishing / exploding problems. 39 | 40 | Usage: Plug this function in Trainer class after loss.backwards() as 41 | "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow''' 42 | ave_grads = [] 43 | max_grads= [] 44 | layers = [] 45 | for n, p in named_parameters: 46 | if(p.requires_grad) and ("bias" not in n): 47 | layers.append(n) 48 | ave_grads.append(p.grad.abs().mean()) 49 | max_grads.append(p.grad.abs().max()) 50 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c") 51 | plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b") 52 | plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" ) 53 | plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical") 54 | plt.xlim(left=0, right=len(ave_grads)) 55 | plt.ylim(bottom = -0.001, top=0.02) # zoom in on the lower gradient regions 56 | plt.xlabel("Layers") 57 | plt.ylabel("average gradient") 58 | plt.title("Gradient flow") 59 | plt.grid(True) 60 | plt.legend([Line2D([0], [0], color="c", lw=4), 61 | Line2D([0], [0], color="b", lw=4), 62 | Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient']) 63 | -------------------------------------------------------------------------------- /sample_100.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import random 4 | import torchvision 5 | import torch.nn as nn 6 | from torchvision import transforms 7 | from torch.utils.data import DataLoader 8 | import torch.nn.utils as utils 9 | from torch.distributions.normal import Normal 10 | 11 | from collections import namedtuple 12 | 13 | from glow import _Glow, PreProcess, squeeze 14 | from modules import * 15 | from shell_util import AverageMeter, save_model 16 | from optim_util import bits_per_dim 17 | from dataloader import MovingObjects 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | pre_process = PreProcess().to(device) 21 | 22 | Z_splits = namedtuple('Z_splits', 'l3 l2 l1') 23 | Glow = namedtuple('Glow', 'l3 l2 l1') 24 | NN_Theta = namedtuple('NNTheta', 'l3 l2 l1') 25 | 26 | 27 | def flow_forward(x, flow): 28 | if x.min() < 0 or x.max() > 1: 29 | raise ValueError('Expected x in [0, 1], got min/max {}/{}' 30 | .format(x.min(), x.max())) 31 | 32 | # pre-process 33 | x, sldj = pre_process(x) 34 | # L3 35 | x3 = squeeze(x, reverse=False) 36 | x3, sldj = flow.l3(x3, sldj, reverse=False) 37 | x3, x_split3 = x3.chunk(2, dim=1) 38 | # L2 39 | x2 = squeeze(x3, reverse=False) 40 | x2, sldj = flow.l2(x2, sldj, reverse=False) 41 | x2, x_split2 = x2.chunk(2, dim=1) 42 | # L1 43 | x1 = squeeze(x2, reverse=False) 44 | x1, sldj = flow.l1(x1, sldj) 45 | 46 | partition_out = Z_splits(l3=x_split3, l2=x_split2, l1=x1) 47 | partition_h = 
48 | 
49 |     return partition_out, partition_h, sldj
50 | 
51 | 
52 | def sample_100(context, glow, nn_theta, temperature=0.5):
53 |     for net in glow:
54 |         net.eval()
55 |     for net in nn_theta:
56 |         net.eval()
57 | 
58 |     b_s = context.size(0)
59 |     os.makedirs('samples_100/', exist_ok=True)  # two frames are generated per context
60 |     torchvision.utils.save_image(context[:, 0, ...].squeeze(), 'samples_100/context.png')
61 | 
62 |     for m in range(100):
63 |         print(m)
64 |         context_frame = context[:, 0, ...]
65 |         for n in range(2):
66 |             t0_zi, _, _ = flow_forward(context_frame, glow)
67 | 
68 |             mu_l1, logsigma_l1 = nn_theta.l1(t0_zi.l1)
69 |             g1 = Normal(loc=mu_l1, scale=temperature * torch.exp(logsigma_l1))
70 |             z1_sample = g1.sample()
71 |             sldj = torch.zeros(b_s, device=device)
72 | 
73 |             # Inverse L1
74 |             h1, sldj = glow.l1(z1_sample, sldj, reverse=True)
75 |             h1 = squeeze(h1, reverse=True)
76 | 
77 |             # Sample z2
78 |             mu_l2, logsigma_l2 = nn_theta.l2(t0_zi.l2, h1)
79 |             g2 = Normal(loc=mu_l2, scale=temperature * torch.exp(logsigma_l2))
80 |             z2_sample = g2.sample()
81 |             h12 = torch.cat([h1, z2_sample], dim=1)
82 |             h12, sldj = glow.l2(h12, sldj, reverse=True)
83 |             h12 = squeeze(h12, reverse=True)
84 | 
85 |             # Sample z3
86 |             mu_l3, logsigma_l3 = nn_theta.l3(t0_zi.l3, h12)
87 |             g3 = Normal(loc=mu_l3, scale=temperature * torch.exp(logsigma_l3))
88 |             z3_sample = g3.sample()
89 | 
90 |             x_t = torch.cat([h12, z3_sample], dim=1)
91 |             x_t, sldj = glow.l3(x_t, sldj, reverse=True)
92 |             x_t = squeeze(x_t, reverse=True)
93 | 
94 |             x_t = torch.sigmoid(x_t)
95 | 
96 |             if not os.path.exists('samples_100/'):
97 |                 os.mkdir('samples_100/')
98 | 
99 |             torchvision.utils.save_image(x_t, 'samples_100/sample{}_{}.png'.format(m, n+1))
100 | 
101 |             assert context_frame.shape == x_t.shape
102 |             context_frame = x_t.clone()
103 | 
104 | 
105 | def main():
106 |     import torch.nn as nn
107 | 
108 |     seed = 123
109 |     random.seed(seed)
110 |     np.random.seed(seed)
111 |     torch.manual_seed(seed)
112 |     torch.cuda.manual_seed_all(seed)
113 | 
114 |     tr = transforms.Compose([transforms.ToTensor()])
115 |     train_data = MovingObjects("train", tr, seed)
116 |     train_loader = DataLoader(train_data,
117 |                               num_workers=1,
118 |                               batch_size=1,
119 |                               shuffle=False,
120 |                               pin_memory=True)
121 | 
122 |     in_chs = 3
123 |     flow_l3 = nn.DataParallel(_Glow(in_channels=4 * in_chs, mid_channels=512, num_steps=24)).to(device)
124 |     flow_l2 = nn.DataParallel(_Glow(in_channels=8 * in_chs, mid_channels=512, num_steps=24)).to(device)
125 |     flow_l1 = nn.DataParallel(_Glow(in_channels=16 * in_chs, mid_channels=512, num_steps=24)).to(device)
126 | 
127 |     nntheta3 = nn.DataParallel(
128 |         NNTheta(encoder_ch_in=4 * in_chs, encoder_mode='conv_net', h_ch_in=2 * in_chs,
129 |                 num_blocks=5)).to(device)  # z at L3: 2*in_chs x 32 x 32
130 |     nntheta2 = nn.DataParallel(
131 |         NNTheta(encoder_ch_in=8 * in_chs, encoder_mode='conv_net', h_ch_in=4 * in_chs,
132 |                 num_blocks=5)).to(device)  # z at L2: 4*in_chs x 16 x 16
133 |     nntheta1 = nn.DataParallel(NNTheta(encoder_ch_in=16 * in_chs, encoder_mode='conv_net',
134 |                                        num_blocks=5)).to(device)
135 | 
136 |     model_path = '/b_test/azimi/results/VideoFlow/SMovement/exp10/sacred/snapshots/109.pth'
137 |     if True:  # always load a trained snapshot before sampling
138 |         print('model loading ...')
139 |         flow_l3.load_state_dict(torch.load(model_path)['glow_l3'])
140 |         flow_l2.load_state_dict(torch.load(model_path)['glow_l2'])
141 |         flow_l1.load_state_dict(torch.load(model_path)['glow_l1'])
142 |         nntheta3.load_state_dict(torch.load(model_path)['nn_theta_l3'])
143 |         nntheta2.load_state_dict(torch.load(model_path)['nn_theta_l2'])
144 |         nntheta1.load_state_dict(torch.load(model_path)['nn_theta_l1'])
145 | 
146 |     glow = Glow(l3=flow_l3, l2=flow_l2, l1=flow_l1)
147 |     nn_theta = NN_Theta(l3=nntheta3, l2=nntheta2, l1=nntheta1)
148 | 
149 |     context = next(iter(train_loader)).to(device)
150 |     sample_100(context, glow, nn_theta)
151 | 
152 | 
153 | if __name__ == "__main__":
154 |     main()
155 | 
--------------------------------------------------------------------------------
/shell_util.py:
--------------------------------------------------------------------------------
1 | import torch, os
2 | 
3 | class AverageMeter(object):
4 |     """Computes and stores the average and current value.
5 | 
6 |     Adapted from: https://github.com/pytorch/examples/blob/master/imagenet/train.py
7 |     """
8 |     def __init__(self):
9 |         self.val = 0.
10 |         self.avg = 0.
11 |         self.sum = 0.
12 |         self.count = 0.
13 | 
14 |     def reset(self):
15 |         self.val = 0.
16 |         self.avg = 0.
17 |         self.sum = 0.
18 |         self.count = 0.
19 | 
20 |     def update(self, val, n=1):
21 |         self.val = val
22 |         self.sum += val * n
23 |         self.count += n
24 |         self.avg = self.sum / self.count
25 | 
26 | 
27 | def save_model(glow, nn_theta, optimizer, epoch, save_path):
28 |     if not os.path.exists(save_path + 'snapshots/'):
29 |         os.mkdir(save_path + 'snapshots/')
30 | 
31 |     torch.save({
32 |         'glow_l3': glow.l3.state_dict(),
33 |         'glow_l2': glow.l2.state_dict(),
34 |         'glow_l1': glow.l1.state_dict(),
35 |         'nn_theta_l3': nn_theta.l3.state_dict(),
36 |         'nn_theta_l2': nn_theta.l2.state_dict(),
37 |         'nn_theta_l1': nn_theta.l1.state_dict(),
38 |         'optimizer': optimizer.state_dict()
39 |     }, save_path+'snapshots/{}.pth'.format(epoch))
40 | 
41 | 
42 | 
43 | def track_grads(nn_theta, glow, writer, step):
44 |     for name, param in nn_theta.l1.named_parameters():
45 |         if param.requires_grad:
46 |             writer.add_scalar('data/nntheta1_max', torch.max(param.grad.data), step)
47 |             writer.add_scalar('data/nntheta1_min', torch.min(param.grad.data), step)
48 | 
49 |     for name, param in nn_theta.l2.named_parameters():
50 |         if param.requires_grad:
51 |             writer.add_scalar('data/nntheta2_max', torch.max(param.grad.data), step)
52 |             writer.add_scalar('data/nntheta2_min', torch.min(param.grad.data), step)
53 | 
54 |     for name, param in nn_theta.l3.named_parameters():
55 |         if param.requires_grad:
56 |             writer.add_scalar('data/nntheta3_max', torch.max(param.grad.data), step)
57 |             writer.add_scalar('data/nntheta3_min', torch.min(param.grad.data), step)
58 | 
59 |     for name, param in glow.l1.named_parameters():
60 |         if param.requires_grad:
61 |             writer.add_scalar('data/glow1_max', torch.max(param.grad.data), step)
62 |             writer.add_scalar('data/glow1_min', torch.min(param.grad.data), step)
63 | 
64 |     for name, param in glow.l2.named_parameters():
65 |         if param.requires_grad:
66 |             writer.add_scalar('data/glow2_max', torch.max(param.grad.data), step)
67 |             writer.add_scalar('data/glow2_min', torch.min(param.grad.data), step)
68 | 
69 |     for name, param in glow.l3.named_parameters():
70 |         if param.requires_grad:
71 |             writer.add_scalar('data/glow3_max', torch.max(param.grad.data), step)
72 |             writer.add_scalar('data/glow3_min', torch.min(param.grad.data), step)
73 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import torchvision
4 | import torch.nn as nn
5 | from torchvision import transforms
6 | from torch.utils.data import DataLoader
7 | import torch.optim.lr_scheduler as sched
8 | import torch.nn.utils as utils
9 | from torch.distributions.normal import Normal
10
| 11 | from collections import namedtuple 12 | from tqdm import tqdm 13 | 14 | from glow import _Glow, PreProcess, squeeze 15 | from modules import * 16 | from shell_util import AverageMeter, save_model 17 | from optim_util import bits_per_dim 18 | from dataloader import MovingObjects 19 | 20 | from tensorboardX import SummaryWriter 21 | from sacred import Experiment 22 | from sacred.observers import FileStorageObserver 23 | 24 | global_step = 0 25 | 26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | pre_process = PreProcess().to(device) 28 | 29 | Z_splits = namedtuple('Z_splits', 'l3 l2 l1') 30 | Glow = namedtuple('Glow', 'l3 l2 l1') 31 | NN_Theta = namedtuple('NNTheta', 'l3 l2 l1') 32 | 33 | PATH = './sacred/' 34 | writer = SummaryWriter(PATH) 35 | ex = Experiment() 36 | ex.observers.append(FileStorageObserver.create(PATH)) 37 | 38 | 39 | @ex.config 40 | def config(): 41 | tr_conf = { 42 | 'encoder_mode': 'conv_net', 43 | 'enc_depth': 5, 44 | 'n_epoch': 600, 45 | 'b_s': 26, 46 | 'lr': 1e-4, 47 | 'k': 256, 48 | 'input_channels': 3, 49 | 'resume': True, 50 | 'starting_epoch': 56 51 | } 52 | 53 | 54 | def train_smovement(train_loader, glow, nn_theta, loss_fn, optimizer, scheduler, epoch): 55 | print("ID: exp12_1 testing lr 1e-4 and only one step movement, no glow loss with random patch") 56 | global global_step 57 | loss_meter = AverageMeter() 58 | # loss_fn_glow = GlowLoss() 59 | for net in glow: 60 | net.train() 61 | for net in nn_theta: 62 | net.train() 63 | 64 | with tqdm(total=len(train_loader.dataset)) as progress_bar: 65 | for itr, sequence in enumerate(train_loader): 66 | sequence = sequence.to(device) 67 | b_s = sequence.size(0) 68 | 69 | # start_index = torch.LongTensor(1).random_(0, 2) 70 | # random_patch = sequence[:, start_index:start_index + 2, :, :, :] 71 | 72 | random_patch = [] 73 | for n in range(b_s): 74 | start_index = torch.LongTensor(1).random_(0, 2) 75 | random_patch.append(sequence[n, start_index:start_index + 2, :, :, :]) 76 | random_patch = torch.stack(random_patch, dim=0) 77 | 78 | t0_zi, _, sldj_0 = flow_forward(random_patch[:, 0, :, :, :], glow) 79 | # z_glow = recover_z_shape(t0_zi) 80 | # loss_glow = loss_fn_glow(z_glow, sldj_0) 81 | 82 | t1_zi_out, t1_zi_h, sldj_1 = flow_forward(random_patch[:, 1, :, :, :], glow) 83 | h12 = t1_zi_h.l3 84 | 85 | mu_l3, logsigma_l3 = nn_theta.l3(t0_zi.l3, h12) 86 | g3 = Normal(loc=mu_l3, scale=torch.exp(logsigma_l3)) 87 | 88 | h1 = t1_zi_h.l2 89 | mu_l2, logsigma_l2 = nn_theta.l2(t0_zi.l2, h1) 90 | g2 = Normal(loc=mu_l2, scale=torch.exp(logsigma_l2)) 91 | 92 | mu_l1, logsigma_l1 = nn_theta.l1(t0_zi.l1) 93 | g1 = Normal(loc=mu_l1, scale=torch.exp(logsigma_l1)) 94 | 95 | total_loss = loss_fn(g1, g2, g3, z=t1_zi_out, sldj=sldj_1, 96 | input_dim=random_patch[:, 1, :, :, :].size()) 97 | 98 | # total_loss = loss #+ loss_glow 99 | total_loss.backward() 100 | 101 | clip_grad_value(optimizer) 102 | 103 | optimizer.step() 104 | optimizer.zero_grad() 105 | if scheduler is not None: 106 | scheduler.step(global_step) 107 | 108 | loss_meter.update(total_loss.item(), b_s) 109 | progress_bar.set_postfix(nll=loss_meter.avg, 110 | bpd=bits_per_dim(random_patch[:, 1, :, :, :], loss_meter.avg), 111 | lr=optimizer.param_groups[0]['lr']) 112 | progress_bar.update(b_s) 113 | global_step += 1 114 | 115 | print("global step:", global_step) 116 | 117 | torch.cuda.empty_cache() 118 | #save_model(glow, nn_theta, optimizer, scheduler, epoch, PATH) 119 | save_model(glow, nn_theta, optimizer, epoch, PATH) 120 | 
writer.add_scalar('data/train_loss', loss_meter.avg, epoch)
121 |     writer.add_scalar('data/lr', get_lr(optimizer), epoch)
122 | 
123 |     context = next(iter(train_loader)).to(device)
124 |     flow_inverse_smovement(context, glow, nn_theta, epoch)
125 | 
126 | 
127 | def recover_z_shape(t_z):
128 |     z1 = squeeze(t_z.l1, reverse=True)
129 |     z2 = torch.cat([z1, t_z.l2], dim=1)
130 |     z2 = squeeze(z2, reverse=True)
131 |     z3 = torch.cat([z2, t_z.l3], dim=1)
132 |     z3 = squeeze(z3, reverse=True)
133 |     return z3
134 | 
135 | 
136 | def clip_grad_value(optimizer, max_val=10.):
137 |     for group in optimizer.param_groups:
138 |         utils.clip_grad_value_(group['params'], max_val)
139 | 
140 | 
141 | def flow_forward(x, flow):
142 |     if x.min() < 0 or x.max() > 1:
143 |         raise ValueError('Expected x in [0, 1], got min/max {}/{}'
144 |                          .format(x.min(), x.max()))
145 | 
146 |     # pre-process
147 |     x, sldj = pre_process(x)
148 |     # L3
149 |     x3 = squeeze(x, reverse=False)
150 |     x3, sldj = flow.l3(x3, sldj, reverse=False)
151 |     x3, x_split3 = x3.chunk(2, dim=1)
152 |     # L2
153 |     x2 = squeeze(x3, reverse=False)
154 |     x2, sldj = flow.l2(x2, sldj, reverse=False)
155 |     x2, x_split2 = x2.chunk(2, dim=1)
156 |     # L1
157 |     x1 = squeeze(x2, reverse=False)
158 |     x1, sldj = flow.l1(x1, sldj)
159 | 
160 |     partition_out = Z_splits(l3=x_split3, l2=x_split2, l1=x1)
161 |     partition_h = Z_splits(l3=x3, l2=x2, l1=None)
162 | 
163 |     return partition_out, partition_h, sldj
164 | 
165 | 
166 | def flow_inverse_smovement(context, glow, nn_theta, epoch):
167 |     for net in glow:
168 |         net.eval()
169 |     for net in nn_theta:
170 |         net.eval()
171 |     import os  # local import, only needed here to create the output directory
172 |     # pre-process the context frame
173 |     b_s = context.size(0)
174 |     context_frame = context[:, 0, ...]
175 |     t0_zi, _, _ = flow_forward(context_frame, glow)
176 | 
177 |     mu_l1, logsigma_l1 = nn_theta.l1(t0_zi.l1)
178 |     g1 = Normal(loc=mu_l1, scale=torch.exp(logsigma_l1))
179 |     z1_sample = g1.sample()
180 |     print("z1", z1_sample.shape)
181 |     sldj = torch.zeros(b_s, device=device)
182 | 
183 |     # Inverse L1
184 |     h1, sldj = glow.l1(z1_sample, sldj, reverse=True)
185 |     h1 = squeeze(h1, reverse=True)
186 | 
187 |     # Sample z2
188 |     mu_l2, logsigma_l2 = nn_theta.l2(t0_zi.l2, h1)
189 |     g2 = Normal(loc=mu_l2, scale=torch.exp(logsigma_l2))
190 |     z2_sample = g2.sample()
191 |     h12 = torch.cat([h1, z2_sample], dim=1)
192 |     h12, sldj = glow.l2(h12, sldj, reverse=True)
193 |     h12 = squeeze(h12, reverse=True)
194 | 
195 |     # Sample z3
196 |     mu_l3, logsigma_l3 = nn_theta.l3(t0_zi.l3, h12)
197 |     g3 = Normal(loc=mu_l3, scale=torch.exp(logsigma_l3))
198 |     z3_sample = g3.sample()
199 | 
200 |     x_t = torch.cat([h12, z3_sample], dim=1)
201 |     x_t, sldj = glow.l3(x_t, sldj, reverse=True)
202 |     x_t = squeeze(x_t, reverse=True)
203 | 
204 |     x_t = torch.sigmoid(x_t)
205 |     os.makedirs('samples', exist_ok=True)
206 |     torchvision.utils.save_image(x_t, 'samples/sample{}.png'.format(epoch))
207 |     torchvision.utils.save_image(context[:, 0, ...].squeeze(), 'samples/context{}.png'.format(epoch))
208 |     torchvision.utils.save_image(context[:, 1, ...].squeeze(), 'samples/gt{}.png'.format(epoch))
209 | 
210 | 
211 | def get_lr(optimizer):
212 |     for param_group in optimizer.param_groups:
213 |         return param_group['lr']
214 | 
215 | 
216 | @ex.automain
217 | def main(tr_conf):
218 |     import torch.nn as nn
219 | 
220 |     seed = 12345
221 |     random.seed(seed)
222 |     np.random.seed(seed)
223 |     torch.manual_seed(seed)
224 |     torch.cuda.manual_seed_all(seed)
225 | 
226 |     tr = transforms.Compose([transforms.ToTensor()])
227 |     train_data = MovingObjects("train", tr, seed)
228 |     train_loader = DataLoader(train_data,
229 |                               num_workers=tr_conf['b_s'],
230 |                               batch_size=tr_conf['b_s'],
231 |                               shuffle=True,
232 |                               pin_memory=True)
233 | 
234 |     param_list = []
235 |     in_chs = tr_conf['input_channels']
236 |     flow_l3 = nn.DataParallel(_Glow(in_channels=4 * in_chs, mid_channels=512, num_steps=24)).to(device)
237 |     flow_l2 = nn.DataParallel(_Glow(in_channels=8 * in_chs, mid_channels=512, num_steps=24)).to(device)
238 |     flow_l1 = nn.DataParallel(_Glow(in_channels=16 * in_chs, mid_channels=512, num_steps=24)).to(device)
239 | 
240 |     nntheta3 = nn.DataParallel(
241 |         NNTheta(encoder_ch_in=4 * in_chs, encoder_mode=tr_conf['encoder_mode'], h_ch_in=2 * in_chs,
242 |                 num_blocks=tr_conf['enc_depth'])).to(device)  # z at L3: 2*in_chs x 32 x 32
243 |     nntheta2 = nn.DataParallel(
244 |         NNTheta(encoder_ch_in=8 * in_chs, encoder_mode=tr_conf['encoder_mode'], h_ch_in=4 * in_chs,
245 |                 num_blocks=tr_conf['enc_depth'])).to(device)  # z at L2: 4*in_chs x 16 x 16
246 |     nntheta1 = nn.DataParallel(NNTheta(encoder_ch_in=16 * in_chs, encoder_mode=tr_conf['encoder_mode'],
247 |                                        num_blocks=tr_conf['enc_depth'])).to(device)
248 | 
249 |     model_path = '/b_test/azimi/results/VideoFlow/SMovement/exp12_2/sacred/snapshots/55.pth'
250 |     if tr_conf['resume']:
251 |         print('model loading ...')
252 |         flow_l3.load_state_dict(torch.load(model_path)['glow_l3'])
253 |         flow_l2.load_state_dict(torch.load(model_path)['glow_l2'])
254 |         flow_l1.load_state_dict(torch.load(model_path)['glow_l1'])
255 |         nntheta3.load_state_dict(torch.load(model_path)['nn_theta_l3'])
256 |         nntheta2.load_state_dict(torch.load(model_path)['nn_theta_l2'])
257 |         nntheta1.load_state_dict(torch.load(model_path)['nn_theta_l1'])
258 |         print("****LOAD THE OPTIMIZER")
259 | 
260 |     glow = Glow(l3=flow_l3, l2=flow_l2, l1=flow_l1)
261 |     nn_theta = NN_Theta(l3=nntheta3, l2=nntheta2, l1=nntheta1)
262 | 
263 |     for f_level in glow:
264 |         param_list += list(f_level.parameters())
265 | 
266 |     for net in nn_theta:
267 |         param_list += list(net.parameters())
268 | 
269 |     loss_fn = NLLLossVF()
270 | 
271 |     optimizer = torch.optim.Adam(param_list, lr=tr_conf['lr'])
272 |     if tr_conf['resume']:  # only restore the optimizer state when resuming
273 |         optimizer.load_state_dict(torch.load(model_path)['optimizer'])
274 |     optimizer.zero_grad()
275 |     # scheduler_step = sched.StepLR(optimizer, step_size=1, gamma=0.99)
276 |     # linear_decay = sched.LambdaLR(optimizer, lambda s: 1. - s / 150000. )
277 |     # linear_decay.step(global_step)
278 | 
279 |     # scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / 10000))
280 |     # optimizer.load_state_dict(torch.load(model_path)['optimizer'])
281 | 
282 |     for epoch in range(tr_conf['starting_epoch'], tr_conf['n_epoch']):
283 |         print("the learning rate for epoch {} is {}".format(epoch, get_lr(optimizer)))
284 |         train_smovement(train_loader, glow, nn_theta, loss_fn, optimizer, None, epoch)
285 |         #scheduler_step.step()
--------------------------------------------------------------------------------
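
For orientation, a minimal shape-only sketch of how `flow_forward` (duplicated in `train.py` and `sample_100.py`) partitions a 64x64 RGB frame across the three levels; the `_Glow` levels preserve tensor shape, so they are omitted here:

```python
import torch

from glow import PreProcess, squeeze

# Shape bookkeeping for a single 64x64 RGB frame (in_chs = 3).
x = torch.rand(1, 3, 64, 64)        # frame in [0, 1]
y, sldj = PreProcess()(x)           # dequantize + logit; shape unchanged
x3 = squeeze(y)                     # (1, 12, 32, 32) -> flow_l3 input
x3, split3 = x3.chunk(2, dim=1)     # h: (1, 6, 32, 32), z at L3: (1, 6, 32, 32)
x2 = squeeze(x3)                    # (1, 24, 16, 16) -> flow_l2 input
x2, split2 = x2.chunk(2, dim=1)     # h: (1, 12, 16, 16), z at L2: (1, 12, 16, 16)
x1 = squeeze(x2)                    # (1, 48, 8, 8) -> the deepest latent, z at L1
print(split3.shape, split2.shape, x1.shape)
```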