├── README.md
├── act_norm.py
├── array_util.py
├── coupling.py
├── dataloader.py
├── glow.py
├── inv_conv.py
├── modules.py
├── optim_util.py
├── sample_100.py
├── shell_util.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch-VideoFlow
2 | PyTorch implementation of the VideoFlow paper (https://arxiv.org/abs/1903.01434).
3 | Glow code adapted from https://github.com/chaiyujin/glow-pytorch
4 | Work in progress!
5 |
--------------------------------------------------------------------------------
/act_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from array_util import mean_dim
5 |
6 |
7 | class ActNorm(nn.Module):
8 | """Activation normalization for 2D inputs.
9 |
10 | The bias and scale get initialized using the mean and variance of the
11 | first mini-batch. After the init, bias and scale are trainable parameters.
12 |
13 | Adapted from:
14 | > https://github.com/openai/glow
15 | """
16 | def __init__(self, num_features, scale=1., return_ldj=False):
17 | super(ActNorm, self).__init__()
18 | self.register_buffer('is_initialized', torch.zeros(1))
19 | self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
20 | self.logs = nn.Parameter(torch.zeros(1, num_features, 1, 1))
21 |
22 | self.num_features = num_features
23 | self.scale = float(scale)
24 | self.eps = 1e-6
25 | self.return_ldj = return_ldj
26 |
27 | def initialize_parameters(self, x):
28 | if not self.training:
29 | return
30 |
31 | with torch.no_grad():
32 | bias = -mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True)
33 | v = mean_dim((x.clone() + bias) ** 2, dim=[0, 2, 3], keepdims=True)
34 | logs = (self.scale / (v.sqrt() + self.eps)).log()
35 | self.bias.data.copy_(bias.data)
36 | self.logs.data.copy_(logs.data)
37 | self.is_initialized += 1.
38 |
39 | def _center(self, x, reverse=False):
40 | if reverse:
41 | return x - self.bias
42 | else:
43 | return x + self.bias
44 |
45 | def _scale(self, x, sldj, reverse=False):
46 | logs = self.logs
47 | if reverse:
48 | x = x * logs.mul(-1).exp()
49 | else:
50 | x = x * logs.exp()
51 |
52 | if sldj is not None:
53 | ldj = logs.sum() * x.size(2) * x.size(3)
54 | if reverse:
55 | sldj = sldj - ldj
56 | else:
57 | sldj = sldj + ldj
58 |
59 | return x, sldj
60 |
61 | def forward(self, x, ldj=None, reverse=False):
62 | if not self.is_initialized:
63 | self.initialize_parameters(x)
64 |
65 | if reverse:
66 | x, ldj = self._scale(x, ldj, reverse)
67 | x = self._center(x, reverse)
68 | else:
69 | x = self._center(x, reverse)
70 | x, ldj = self._scale(x, ldj, reverse)
71 |
72 | if self.return_ldj:
73 | return x, ldj
74 |
75 | return x
76 |
--------------------------------------------------------------------------------
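
As a quick check of the data-dependent initialization described in the docstring, a minimal sketch (illustrative sizes; assumes the repo modules are importable): after one training-mode pass, activations are standardized per channel.

    import torch

    from act_norm import ActNorm

    norm = ActNorm(num_features=8, return_ldj=True)   # fresh module is in training mode
    x = torch.randn(4, 8, 16, 16) * 3. + 1.           # arbitrary per-channel statistics
    y, ldj = norm(x, torch.zeros(4))                  # first call triggers the init
    print(y.mean().item(), y.std().item())            # roughly 0 and 1 after init
    print(ldj.shape)                                  # per-sample log-determinant: (4,)
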
/array_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def mean_dim(tensor, dim=None, keepdims=False):
5 | """Take the mean along multiple dimensions.
6 |
7 | Args:
8 | tensor (torch.Tensor): Tensor of values to average.
9 | dim (list): List of dimensions along which to take the mean.
10 | keepdims (bool): Keep dimensions rather than squeezing.
11 |
12 | Returns:
13 | mean (torch.Tensor): New tensor of mean value(s).
14 | """
15 | if dim is None:
16 | return tensor.mean()
17 | else:
18 | if isinstance(dim, int):
19 | dim = [dim]
20 | dim = sorted(dim)
21 | for d in dim:
22 | tensor = tensor.mean(dim=d, keepdim=True)
23 | if not keepdims:
24 | for i, d in enumerate(dim):
25 | tensor.squeeze_(d-i)
26 | return tensor
27 |
--------------------------------------------------------------------------------
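
For reference, a minimal sketch of `mean_dim` showing how `keepdims` affects the output shape:

    import torch

    from array_util import mean_dim

    x = torch.arange(24.).view(2, 3, 4)
    print(mean_dim(x, dim=[0, 2], keepdims=True).shape)   # torch.Size([1, 3, 1])
    print(mean_dim(x, dim=[0, 2]).shape)                  # torch.Size([3])
    print(mean_dim(x).shape)                              # scalar: torch.Size([])
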
/coupling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from act_norm import ActNorm
6 |
7 |
8 | class AddCoupling(nn.Module):
9 | def __init__(self, in_channels, mid_channels, use_act_norm=True):
10 | super(AddCoupling, self).__init__()
11 | self.nn = NN(in_channels, mid_channels, in_channels, use_act_norm)
12 | # self.scale = nn.Parameter(torch.ones(in_channels, 1, 1))
13 |
14 |     def forward(self, x, ldj, reverse=False):
15 |         # Additive coupling has a unit Jacobian, so ldj passes through unchanged.
16 |         if not reverse:
17 |             x_change, x_id = x.chunk(2, dim=1)
18 |             x_change = x_change + self.nn(x_id)
19 |             return torch.cat([x_id, x_change], dim=1), ldj
20 |         else:  # mirror the forward output order: identity half comes first
21 |             x_id, x_change = x.chunk(2, dim=1)
22 |             x_change = x_change - self.nn(x_id)
23 |             return torch.cat([x_change, x_id], dim=1), ldj
24 |
25 |
26 | class AffineCoupling(nn.Module):
27 | """Affine coupling layer originally used in Real NVP and described by Glow.
28 |
29 | Note: The official Glow implementation (https://github.com/openai/glow)
30 | uses a different affine coupling formulation than described in the paper.
31 | This implementation follows the paper and Real NVP.
32 |
33 | Args:
34 | in_channels (int): Number of channels in the input.
35 | mid_channels (int): Number of channels in the intermediate activation
36 | in NN.
37 | """
38 | def __init__(self, in_channels, mid_channels):
39 | super(AffineCoupling, self).__init__()
40 | self.nn = NN(in_channels, mid_channels, 2 * in_channels)
41 | self.scale = nn.Parameter(torch.ones(in_channels, 1, 1))
42 |
43 | def forward(self, x, ldj, reverse=False):
44 | x_change, x_id = x.chunk(2, dim=1)
45 |
46 | st = self.nn(x_id)
47 | s, t = st[:, 0::2, ...], st[:, 1::2, ...]
48 | s = self.scale * torch.tanh(s)
49 |
50 | # Scale and translate
51 | if reverse:
52 | x_change = x_change * s.mul(-1).exp() - t
53 | ldj = ldj - s.flatten(1).sum(-1)
54 | else:
55 | x_change = (x_change + t) * s.exp()
56 | ldj = ldj + s.flatten(1).sum(-1)
57 |
58 | x = torch.cat((x_change, x_id), dim=1)
59 |
60 | return x, ldj
61 |
62 |
63 | class NN(nn.Module):
64 | """Small convolutional network used to compute scale and translate factors.
65 |
66 | Args:
67 | in_channels (int): Number of channels in the input.
68 | mid_channels (int): Number of channels in the hidden activations.
69 | out_channels (int): Number of channels in the output.
70 | use_act_norm (bool): Use activation norm rather than batch norm.
71 | """
72 | def __init__(self, in_channels, mid_channels, out_channels,
73 | use_act_norm=False):
74 | super(NN, self).__init__()
75 | norm_fn = ActNorm if use_act_norm else nn.BatchNorm2d
76 |
77 | self.in_norm = norm_fn(in_channels)
78 | self.in_conv = nn.Conv2d(in_channels, mid_channels,
79 | kernel_size=3, padding=1, bias=False)
80 | nn.init.normal_(self.in_conv.weight, 0., 0.05)
81 |
82 | self.mid_norm = norm_fn(mid_channels)
83 | self.mid_conv = nn.Conv2d(mid_channels, mid_channels,
84 | kernel_size=1, padding=0, bias=False)
85 | nn.init.normal_(self.mid_conv.weight, 0., 0.05)
86 |
87 | self.out_norm = norm_fn(mid_channels)
88 | self.out_conv = nn.Conv2d(mid_channels, out_channels,
89 | kernel_size=3, padding=1, bias=True)
90 | nn.init.zeros_(self.out_conv.weight)
91 | nn.init.zeros_(self.out_conv.bias)
92 |
93 | def forward(self, x):
94 | x = self.in_norm(x)
95 | x = F.relu(x)
96 | x = self.in_conv(x)
97 |
98 | x = self.mid_norm(x)
99 | x = F.relu(x)
100 | x = self.mid_conv(x)
101 |
102 | x = self.out_norm(x)
103 | x = F.relu(x)
104 | x = self.out_conv(x)
105 |
106 | return x
107 |
--------------------------------------------------------------------------------
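
A round-trip sketch of `AffineCoupling` (illustrative sizes). Note that `out_conv` is zero-initialized, so a fresh coupling starts as the identity map (the standard Glow trick for stable training); the check below therefore passes trivially at init and remains valid after training.

    import torch

    from coupling import AffineCoupling

    torch.manual_seed(0)
    coup = AffineCoupling(in_channels=4, mid_channels=32).eval()  # eval() freezes BatchNorm
    x = torch.randn(2, 8, 8, 8)        # the layer chunks 8 channels into two halves of 4
    y, ldj = coup(x, torch.zeros(2))
    x_rec, ldj_rec = coup(y, ldj, reverse=True)
    print((x - x_rec).abs().max().item())   # ~0: reverse inverts forward
    print(ldj_rec.abs().max().item())       # ~0: log-det contributions cancel
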
/dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset
4 | import random, PIL
5 | from PIL import Image, ImageDraw
6 | from torchvision import datasets, transforms
7 |
8 |
9 | class MovingObjects(Dataset):
10 | def __init__(self, mode, transform, my_seed, dummy_len=30000):
11 | print("NOTE: no data normalization and data range is [0,1]")
12 |         if mode == "train":
13 | random.seed(my_seed)
14 | torch.manual_seed(my_seed)
15 | np.random.seed(my_seed)
16 | #else:
17 | # random.seed(1234)
18 | # torch.manual_seed(1234)
19 | # np.random.seed(1234)
20 |
21 |         self.seq_len = 3
22 |         # constant speed of 8 pixels per frame
23 |         pix = 8
24 |         self.dummy_len = dummy_len
25 |         # 8 possible directions of movement
26 |         self.deltaxy = [(pix,pix), (pix,0), (pix,-pix), (0,pix), (0,-pix), (-pix,-pix), (-pix,0), (-pix,pix)]
27 | self.shapes = ['circle', 'rectangle', 'polygon']
28 | self.size_range = [18] #, 26] # range(10, 20)
29 |
30 | self.center_xy = np.array([32, 32])
31 | self.transform = transform
32 |
33 | def __len__(self):
34 | return self.dummy_len
35 |
36 | def __getitem__(self, idx):
37 | """idx is a dummy value"""
38 | # pick a shape
39 | shape = random.choice(self.shapes)
40 |
41 | # pick a direction
42 | deltax, deltay = random.choice(self.deltaxy)
43 |
44 | # pick size
45 | size = random.choice(self.size_range)
46 |
47 | # pick color
48 | r = random.choice(range(0,256,200))
49 | g = random.choice(range(0,256,200))
50 | b = random.choice(range(0,256,200))
51 |
52 | frames = []
53 | img1 = PIL.Image.new(mode='RGB', size=(64,64), color='gray')
54 |         if shape == 'circle':
55 | for i in range(self.seq_len):
56 |
57 | c_img = img1.copy()
58 | c_draw = ImageDraw.Draw(c_img)
59 | c1 = tuple(self.center_xy - size/2 + np.array([deltax, deltay])*i)
60 | c2 = tuple(self.center_xy + size/2 + np.array([deltax, deltay])*i)
61 |
62 |                 c_draw.ellipse([c1, c2], fill=(r, g, b))
63 | frames.append(self.transform(c_img))
64 |
65 |         elif shape == 'rectangle':
66 | for i in range(self.seq_len):
67 |
68 | c_img = img1.copy()
69 | c_draw = ImageDraw.Draw(c_img)
70 | c1 = tuple(self.center_xy - size/2 + np.array([deltax, deltay])*i)
71 | c2 = tuple(self.center_xy + size/2 + np.array([deltax, deltay])*i)
72 |
73 |                 c_draw.rectangle([c1, c2], fill=(r, g, b))
74 | frames.append(self.transform(c_img))
75 |
76 |         elif shape == 'polygon':
77 | for i in range(self.seq_len):
78 | c_img = img1.copy()
79 | c_draw = ImageDraw.Draw(c_img)
80 |
81 | c1 = tuple(self.center_xy - np.array([0, size/2]) + np.array([deltax, deltay])*i)
82 | c2 = tuple(self.center_xy + np.array([size/2, size/3]) + np.array([deltax, deltay])*i)
83 | c3 = tuple(self.center_xy + np.array([-size/2, size/3]) + np.array([deltax, deltay])*i)
84 |
85 |                 c_draw.polygon([c1, c2, c3], fill=(r, g, b))
86 |
87 | #change to tensor
88 | frames.append(self.transform(c_img))
89 |
90 | else:
91 | raise NotImplementedError()
92 |
93 | frames_tensor = torch.stack(frames, dim=0)
94 | return frames_tensor
95 |
96 |
97 | class MovingMNIST(object):
98 |
99 | """Data Handler that creates Bouncing MNIST dataset on the fly."""
100 |
101 | def __init__(self, train, data_root, seq_len=20, num_digits=2, image_size=64, deterministic=True):
102 | path = data_root
103 | self.seq_len = seq_len
104 | self.num_digits = num_digits
105 | self.image_size = image_size
106 | self.step_length = 0.1
107 | self.digit_size = 32
108 | self.deterministic = deterministic
109 | self.seed_is_set = False # multi threaded loading
110 | self.channels = 1
111 |
112 | self.data = datasets.MNIST(
113 | path,
114 | train=train,
115 | download=True,
116 | transform=transforms.Compose(
117 |                 [transforms.Resize(self.digit_size),
118 | transforms.ToTensor()]))
119 |
120 | self.N = len(self.data)
121 |
122 | def set_seed(self, seed):
123 | if not self.seed_is_set:
124 | self.seed_is_set = True
125 | np.random.seed(seed)
126 |
127 | def __len__(self):
128 | return self.N
129 |
130 | def __getitem__(self, index):
131 | self.set_seed(index)
132 | image_size = self.image_size
133 | digit_size = self.digit_size
134 | x = np.zeros((self.seq_len,
135 | image_size,
136 | image_size,
137 | self.channels),
138 | dtype=np.float32)
139 | for n in range(self.num_digits):
140 | idx = np.random.randint(self.N)
141 | digit, _ = self.data[idx]
142 |
143 | sx = np.random.randint(image_size-digit_size)
144 | sy = np.random.randint(image_size-digit_size)
145 | dx = np.random.randint(-4, 5)
146 | dy = np.random.randint(-4, 5)
147 | for t in range(self.seq_len):
148 | if sy < 0:
149 | sy = 0
150 | if self.deterministic:
151 | dy = -dy
152 | else:
153 | dy = np.random.randint(1, 5)
154 | dx = np.random.randint(-4, 5)
155 |             elif sy >= image_size-digit_size:
156 |                 sy = image_size-digit_size-1
157 | if self.deterministic:
158 | dy = -dy
159 | else:
160 | dy = np.random.randint(-4, 0)
161 | dx = np.random.randint(-4, 5)
162 |
163 | if sx < 0:
164 | sx = 0
165 | if self.deterministic:
166 | dx = -dx
167 | else:
168 | dx = np.random.randint(1, 5)
169 | dy = np.random.randint(-4, 5)
170 |             elif sx >= image_size-digit_size:
171 |                 sx = image_size-digit_size-1
172 | if self.deterministic:
173 | dx = -dx
174 | else:
175 | dx = np.random.randint(-4, 0)
176 | dy = np.random.randint(-4, 5)
177 |
178 |                 x[t, sy:sy+digit_size, sx:sx+digit_size, 0] += digit.numpy().squeeze()
179 | sy += dy
180 | sx += dx
181 |
182 | x[x>1] = 1.
183 | # t, w, h, c --> t, c, w, h
184 | return x.transpose([0,3,1,2])
185 |
--------------------------------------------------------------------------------
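
`MovingObjects` synthesizes sequences on the fly, so it plugs directly into a `DataLoader`; each item is a `(seq_len, C, H, W)` tensor in `[0, 1]`. A minimal sketch (illustrative seed and dataset length):

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms

    from dataloader import MovingObjects

    tr = transforms.Compose([transforms.ToTensor()])
    data = MovingObjects("train", tr, my_seed=0, dummy_len=8)
    loader = DataLoader(data, batch_size=4)
    batch = next(iter(loader))
    print(batch.shape)                              # (4, 3, 3, 64, 64): batch, time, C, H, W
    print(batch.min().item(), batch.max().item())   # data stays in [0, 1]
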
/glow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from act_norm import ActNorm
6 | from coupling import *
7 | from inv_conv import InvConv
8 |
9 |
10 | class _Glow(nn.Module):
11 | """Flow per level of Glow
12 |
13 | Args:
14 | in_channels (int): Number of channels in the input.
15 | mid_channels (int): Number of channels in hidden layers of each step.
16 |         num_steps (int): Number of steps of flow in this level.
17 |         num_levels (int): Unused; retained for compatibility with the multi-level API.
18 | """
19 | def __init__(self, in_channels, mid_channels, num_steps, num_levels=None):
20 | super(_Glow, self).__init__()
21 | self.steps = nn.ModuleList([_FlowStep(in_channels=in_channels,
22 | mid_channels=mid_channels)
23 | for _ in range(num_steps)])
24 |
25 |     def forward(self, x, sldj, reverse=False):
26 |         if reverse:
27 |             for step in reversed(self.steps):
28 |                 x, sldj = step(x, sldj, reverse)
29 |         else:
30 |             for step in self.steps:
31 |                 x, sldj = step(x, sldj, reverse)
32 |
33 |         return x, sldj
35 |
36 |
37 | class _FlowStep(nn.Module):
38 | def __init__(self, in_channels, mid_channels, coupling='affine', use_act_norm_in_coupling=True):
39 | super(_FlowStep, self).__init__()
40 |
41 | # Activation normalization, invertible 1x1 convolution, affine coupling
42 | self.norm = ActNorm(in_channels, return_ldj=True)
43 | self.conv = InvConv(in_channels)
44 |         if coupling == "additive":
45 | self.coup = AddCoupling(in_channels // 2, mid_channels, use_act_norm_in_coupling)
46 | else:
47 | self.coup = AffineCoupling(in_channels // 2, mid_channels)
48 |
49 | def forward(self, x, sldj=None, reverse=False):
50 | if reverse:
51 | x, sldj = self.coup(x, sldj, reverse)
52 | x, sldj = self.conv(x, sldj, reverse)
53 | x, sldj = self.norm(x, sldj, reverse)
54 | else:
55 | x, sldj = self.norm(x, sldj, reverse)
56 | x, sldj = self.conv(x, sldj, reverse)
57 | x, sldj = self.coup(x, sldj, reverse)
58 |
59 | return x, sldj
60 |
61 |
62 | class PreProcess(nn.Module):
63 | def __init__(self, bound=0.9):
64 | super(PreProcess, self).__init__()
65 | self.register_buffer('bounds', torch.tensor([bound], dtype=torch.float32))
66 |
67 | def forward(self, x):
68 | """Dequantize the input image `x` and convert to logits.
69 |
70 | See Also:
71 | - Dequantization: https://arxiv.org/abs/1511.01844, Section 3.1
72 | - Modeling logits: https://arxiv.org/abs/1605.08803, Section 4.1
73 |
74 | Args:
75 | x (torch.Tensor): Input image.
76 |
77 | Returns:
78 | y (torch.Tensor): Dequantized logits of `x`.
79 | """
80 | y = (x * 255. + torch.rand_like(x)) / 256.
81 | y = (2 * y - 1) * self.bounds
82 | y = (y + 1) / 2
83 | y = y.log() - (1. - y).log()
84 |
85 | # Save log-determinant of Jacobian of initial transform
86 | ldj = F.softplus(y) + F.softplus(-y) \
87 | - F.softplus((1. - self.bounds).log() - self.bounds.log())
88 | sldj = ldj.flatten(1).sum(-1)
89 | #sldj = torch.zeros(x.shape[0]).cuda()
90 |
91 | return y, sldj
92 |
93 |
94 | def squeeze(x, reverse=False):
95 |     """Trade spatial extent for channels. In forward direction, map each
96 |     (B, C, H, W) input to (B, 4C, H/2, W/2): every 1x2x2 volume becomes 4x1x1.
97 |
98 | Args:
99 | x (torch.Tensor): Input to squeeze or unsqueeze.
100 | reverse (bool): Reverse the operation, i.e., unsqueeze.
101 |
102 | Returns:
103 | x (torch.Tensor): Squeezed or unsqueezed tensor.
104 | """
105 | b, c, h, w = x.size()
106 | if reverse:
107 | # Unsqueeze
108 | x = x.view(b, c // 4, 2, 2, h, w)
109 | x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
110 | x = x.view(b, c // 4, h * 2, w * 2)
111 | else:
112 | # Squeeze
113 | x = x.view(b, c, h // 2, 2, w // 2, 2)
114 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
115 | x = x.view(b, c * 2 * 2, h // 2, w // 2)
116 |
117 | return x
118 |
119 | if __name__ == "__main__":
120 |     # Smoke test: dequantize, squeeze, and run one flow level.
121 |     pre = PreProcess()
122 |     flow = _Glow(in_channels=12, mid_channels=512, num_steps=4)
123 |     test_in = torch.rand(1, 3, 64, 64)
124 |     y, sldj = pre(test_in)
125 |     z, sldj = flow(squeeze(y), sldj)
126 |     print(z.shape, sldj.shape)
126 |
--------------------------------------------------------------------------------
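
A forward/inverse consistency sketch for one flow level (illustrative sizes). `ActNorm` needs one training-mode pass to initialize; `eval()` then freezes the BatchNorm statistics inside the couplings, so the reverse pass inverts the forward pass up to float precision:

    import torch

    from glow import _Glow, squeeze

    torch.manual_seed(0)
    flow = _Glow(in_channels=12, mid_channels=64, num_steps=2)
    y = squeeze(torch.rand(2, 3, 8, 8))          # (2, 12, 4, 4)
    flow(y, torch.zeros(2))                      # training-mode pass: ActNorm init
    flow.eval()                                  # freeze BatchNorm statistics
    z, sldj = flow(y, torch.zeros(2))
    y_rec, sldj_rec = flow(z, sldj, reverse=True)
    print((y - y_rec).abs().max().item())        # small: inverse recovers the input
    print(sldj_rec.abs().max().item())           # ~0: log-det cancels on the way back
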
/inv_conv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class InvConv(nn.Module):
8 | """Invertible 1x1 Convolution for 2D inputs. Originally described in Glow
9 | (https://arxiv.org/abs/1807.03039). Does not support LU-decomposed version.
10 |
11 | Args:
12 | num_channels (int): Number of channels in the input and output.
13 | """
14 | def __init__(self, num_channels):
15 | super(InvConv, self).__init__()
16 | self.num_channels = num_channels
17 |
18 | # Initialize with a random orthogonal matrix
19 | w_init = np.random.randn(num_channels, num_channels)
20 | w_init = np.linalg.qr(w_init)[0].astype(np.float32)
21 | self.weight = nn.Parameter(torch.from_numpy(w_init))
22 |
23 | def forward(self, x, sldj, reverse=False):
24 | ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
25 |
26 | if reverse:
27 | weight = torch.inverse(self.weight.double()).float()
28 | sldj = sldj - ldj
29 | else:
30 | weight = self.weight
31 | sldj = sldj + ldj
32 |
33 | weight = weight.view(self.num_channels, self.num_channels, 1, 1)
34 | z = F.conv2d(x, weight)
35 |
36 | return z, sldj
37 |
--------------------------------------------------------------------------------
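
A minimal invertibility sketch for `InvConv`; since the weight is initialized to a random orthogonal matrix, its log-determinant starts near zero:

    import torch

    from inv_conv import InvConv

    torch.manual_seed(0)
    conv = InvConv(num_channels=4)
    x = torch.randn(2, 4, 8, 8)
    z, sldj = conv(x, torch.zeros(2))
    x_rec, sldj_rec = conv(z, sldj, reverse=True)
    print((x - x_rec).abs().max().item())   # small: 1x1 conv inverts exactly up to numerics
    print(sldj_rec.abs().max().item())      # ~0: log-det cancels
    print(sldj[0].item() / (8 * 8))         # ~0 at init: orthogonal weight has |det| = 1
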
/modules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class NNTheta(nn.Module):
8 | def __init__(self, encoder_ch_in, encoder_mode, num_blocks, h_ch_in=None):
9 | super(NNTheta, self).__init__()
10 | self.encoder_mode = encoder_mode
11 |
12 | if h_ch_in is not None:
13 | self.conv1 = nn.Conv2d(in_channels=h_ch_in, out_channels=h_ch_in, kernel_size=1)
14 | initialize(self.conv1, mode='gaussian')
15 |
16 | dilations = [1, 2]
17 | self.latent_encoder = nn.ModuleList()
18 | for i in range(num_blocks):
19 | self.latent_encoder.append(nn.ModuleList(
20 | [self.latent_dist_encoder(encoder_ch_in, dilation=d, mode=encoder_mode) for d in dilations]))
21 | # print("latent encoder:", self.latent_encoder)
22 |
23 | if h_ch_in:
24 | self.conv2 = nn.Conv2d(in_channels=encoder_ch_in, out_channels=encoder_ch_in, kernel_size=1)
25 | initialize(self.conv2, mode='zeros')
26 | else:
27 | self.conv2 = nn.Conv2d(in_channels=encoder_ch_in, out_channels=2 * encoder_ch_in, kernel_size=1)
28 | initialize(self.conv2, mode='zeros')
29 |
30 | def forward(self, z_past, h=None):
31 | if h is not None:
32 | h = self.conv1(h)
33 | encoder_input = torch.cat([z_past, h], dim=1)
34 | else:
35 | encoder_input = z_past.clone()
36 |
37 | for block in self.latent_encoder:
38 | parallel_outs = [pb(encoder_input) for pb in block]
39 |
40 | parallel_outs.append(encoder_input)
41 | encoder_input = sum(parallel_outs)
42 |
43 | last_t = self.conv2(encoder_input)
44 | deltaz_t, logsigma_t = last_t[:, 0::2, ...], last_t[:, 1::2, ...]
45 |
46 | # assert deltaz_t.shape == z_past.shape
47 | logsigma_t = torch.clamp(logsigma_t, min=-15., max=15.)
48 | mu_t = deltaz_t + z_past
49 | return mu_t, logsigma_t
50 |
51 | @staticmethod
52 | def latent_dist_encoder(ch_in, dilation, mode):
53 |
54 |         if mode == "conv_net":
55 | layer1 = nn.Conv2d(in_channels=ch_in, out_channels=512, kernel_size=(3, 3),
56 | dilation=(dilation, dilation), padding=(dilation, dilation))
57 | initialize(layer1, mode='gaussian')
58 | layer2 = GATU2D(channels=512)
59 | layer3 = nn.Conv2d(in_channels=512, out_channels=ch_in, kernel_size=(3, 3),
60 | dilation=(dilation, dilation), padding=(dilation, dilation))
61 | initialize(layer3, mode='zeros')
62 |
63 | block = nn.Sequential(*[layer1, nn.ReLU(inplace=True), layer2, layer3])
64 |
65 |             return block
66 |
67 |         raise NotImplementedError("encoder mode '{}' is not implemented".format(mode))
68 |
68 | class GATU2D(nn.Module):
69 |
70 | def __init__(self, channels):
71 | super(GATU2D, self).__init__()
72 | self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1)
73 | initialize(self.conv1, mode='gaussian')
74 | self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1)
75 | initialize(self.conv2, mode='gaussian')
76 |
77 | def forward(self, x):
78 | out1 = torch.tanh(self.conv1(x))
79 | out2 = torch.sigmoid(self.conv2(x))
80 | return out1 * out2
81 |
82 |
83 | class NLLLossVF(nn.Module):
84 | def __init__(self, k=256):
85 | super(NLLLossVF, self).__init__()
86 | self.k = k
87 |
88 | def forward(self, gaussian_1, gaussian_2, gaussian_3, z, sldj, input_dim):
89 |
90 | prior_ll_l3 = torch.sum(gaussian_3.log_prob(z.l3), [1, 2, 3])
91 | prior_ll_l2 = torch.sum(gaussian_2.log_prob(z.l2), [1, 2, 3])
92 | prior_ll_l1 = torch.sum(gaussian_1.log_prob(z.l1), [1, 2, 3])
93 |
94 | prior_ll = prior_ll_l1 + prior_ll_l2 + prior_ll_l3 - np.log(self.k) * np.prod(input_dim[1:])
95 | ll = prior_ll + sldj
96 | nll = -ll.mean()
97 | return nll
98 |
99 |
100 | class GlowLoss(nn.Module):
101 | def __init__(self, k=256):
102 | super(GlowLoss, self).__init__()
103 | self.k = k
104 |
105 | def forward(self, z, sldj):
106 | prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
107 | prior_ll = prior_ll.flatten(1).sum(-1) \
108 | - np.log(self.k) * np.prod(z.size()[1:])
109 | ll = prior_ll + sldj
110 | nll = -ll.mean()
111 |
112 | return nll
113 |
114 |
115 | def initialize(layer, mode):
116 | if mode == 'gaussian':
117 | nn.init.normal_(layer.weight, 0., 0.05)
118 | nn.init.normal_(layer.bias, 0., 0.05)
119 |
120 | elif mode == 'zeros':
121 | nn.init.zeros_(layer.weight)
122 | nn.init.zeros_(layer.bias)
123 |
124 | else:
125 | raise NotImplementedError("To be implemented")
126 |
--------------------------------------------------------------------------------
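
`NNTheta` predicts a per-pixel Gaussian over the next latent as a residual on the current one; because `conv2` is zero-initialized, it starts at `mu_t = z_past` with `sigma = 1` (an identity prediction). A minimal sketch with illustrative channel counts (`encoder_ch_in` must equal the channels of `z_past` plus those of `h`):

    import torch

    from modules import NNTheta

    z_past = torch.randn(2, 6, 16, 16)
    h = torch.randn(2, 6, 16, 16)
    net = NNTheta(encoder_ch_in=12, encoder_mode='conv_net', num_blocks=2, h_ch_in=6)
    mu, logsigma = net(z_past, h)
    print(mu.shape, logsigma.shape)               # both (2, 6, 16, 16)
    print((mu - z_past).abs().max().item())       # 0 at init: conv2 is zero-initialized
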
/optim_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.utils as utils
5 |
6 |
7 | def bits_per_dim(x, nll):
8 |     """Get the bits per dimension implied by using a model with per-image
9 |     negative log-likelihood `nll` (in nats) for compressing `x`.
10 |
11 | Args:
12 | x (torch.Tensor): Input to the model. Just used for dimensions.
13 | nll (torch.Tensor): Scalar negative log-likelihood loss tensor.
14 |
15 | Returns:
16 | bpd (torch.Tensor): Bits per dimension implied if compressing `x`.
17 | """
18 | dim = np.prod(x.size()[1:])
19 | bpd = nll / (np.log(2) * dim)
20 |
21 | return bpd
22 |
23 |
24 | def clip_grad_norm(optimizer, max_norm, norm_type=2):
25 | """Clip the norm of the gradients for all parameters under `optimizer`.
26 |
27 | Args:
28 | optimizer (torch.optim.Optimizer):
29 | max_norm (float): The maximum allowable norm of gradients.
30 | norm_type (int): The type of norm to use in computing gradient norms.
31 | """
32 | for group in optimizer.param_groups:
33 | utils.clip_grad_norm_(group['params'], max_norm, norm_type)
34 |
35 |
36 | def plot_grad_flow(named_parameters):
37 | '''Plots the gradients flowing through different layers in the net during training.
38 | Can be used for checking for possible gradient vanishing / exploding problems.
39 |
40 |     Usage: Plug this function in Trainer class after loss.backward() as
41 | "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
42 |     import matplotlib.pyplot as plt
43 |     from matplotlib.lines import Line2D
44 |     ave_grads, max_grads, layers = [], [], []
45 | for n, p in named_parameters:
46 | if(p.requires_grad) and ("bias" not in n):
47 | layers.append(n)
48 |             ave_grads.append(p.grad.abs().mean().item())
49 |             max_grads.append(p.grad.abs().max().item())
50 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
51 | plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
52 | plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
53 | plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
54 | plt.xlim(left=0, right=len(ave_grads))
55 | plt.ylim(bottom = -0.001, top=0.02) # zoom in on the lower gradient regions
56 | plt.xlabel("Layers")
57 | plt.ylabel("average gradient")
58 | plt.title("Gradient flow")
59 | plt.grid(True)
60 | plt.legend([Line2D([0], [0], color="c", lw=4),
61 | Line2D([0], [0], color="b", lw=4),
62 | Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
63 |
--------------------------------------------------------------------------------
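
`bits_per_dim` converts a per-image NLL in nats to bits per dimension; an NLL of `dim * ln 2` nats corresponds to exactly 1 bpd:

    import numpy as np
    import torch

    from optim_util import bits_per_dim

    x = torch.zeros(4, 3, 64, 64)                  # only the shape is used
    nll = torch.tensor(3 * 64 * 64 * np.log(2.))   # dim * ln(2) nats
    print(bits_per_dim(x, nll).item())             # 1.0
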
/sample_100.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import random
4 | import torchvision
5 | import torch.nn as nn
6 | from torchvision import transforms
7 | from torch.utils.data import DataLoader
8 | import torch.nn.utils as utils
9 | from torch.distributions.normal import Normal
10 |
11 | from collections import namedtuple
12 |
13 | from glow import _Glow, PreProcess, squeeze
14 | from modules import *
15 | from shell_util import AverageMeter, save_model
16 | from optim_util import bits_per_dim
17 | from dataloader import MovingObjects
18 |
19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 | pre_process = PreProcess().to(device)
21 |
22 | Z_splits = namedtuple('Z_splits', 'l3 l2 l1')
23 | Glow = namedtuple('Glow', 'l3 l2 l1')
24 | NN_Theta = namedtuple('NNTheta', 'l3 l2 l1')
25 |
26 |
27 | def flow_forward(x, flow):
28 | if x.min() < 0 or x.max() > 1:
29 | raise ValueError('Expected x in [0, 1], got min/max {}/{}'
30 | .format(x.min(), x.max()))
31 |
32 | # pre-process
33 | x, sldj = pre_process(x)
34 | # L3
35 | x3 = squeeze(x, reverse=False)
36 | x3, sldj = flow.l3(x3, sldj, reverse=False)
37 | x3, x_split3 = x3.chunk(2, dim=1)
38 | # L2
39 | x2 = squeeze(x3, reverse=False)
40 | x2, sldj = flow.l2(x2, sldj, reverse=False)
41 | x2, x_split2 = x2.chunk(2, dim=1)
42 | # L1
43 | x1 = squeeze(x2, reverse=False)
44 | x1, sldj = flow.l1(x1, sldj)
45 |
46 | partition_out = Z_splits(l3=x_split3, l2=x_split2, l1=x1)
47 | partition_h = Z_splits(l3=x3, l2=x2, l1=None)
48 |
49 | return partition_out, partition_h, sldj
50 |
51 |
52 | def sample_100(context, glow, nn_theta, temperature=0.5):
53 | for net in glow:
54 | net.eval()
55 | for net in nn_theta:
56 | net.eval()
57 |
58 |     b_s = context.size(0)
59 |     # create the output directory before the first save
60 |     if not os.path.exists('samples_100/'):
61 |         os.mkdir('samples_100/')
62 |     # generate two frames per sample
63 |     torchvision.utils.save_image(context[:, 0, ...].squeeze(), 'samples_100/context.png')
61 |
62 | for m in range(100):
63 | print(m)
64 | context_frame = context[:, 0, ...]
65 | for n in range(2):
66 | t0_zi, _, _ = flow_forward(context_frame, glow)
67 |
68 | mu_l1, logsigma_l1 = nn_theta.l1(t0_zi.l1)
69 | g1 = Normal(loc=mu_l1, scale=temperature * torch.exp(logsigma_l1))
70 | z1_sample = g1.sample()
71 | sldj = torch.zeros(b_s, device=device)
72 |
73 | # Inverse L1
74 | h1, sldj = glow.l1(z1_sample, sldj, reverse=True)
75 | h1 = squeeze(h1, reverse=True)
76 |
77 | # Sample z2
78 | mu_l2, logsigma_l2 = nn_theta.l2(t0_zi.l2, h1)
79 | g2 = Normal(loc=mu_l2, scale=temperature * torch.exp(logsigma_l2))
80 | z2_sample = g2.sample()
81 | h12 = torch.cat([h1, z2_sample], dim=1)
82 | h12, sldj = glow.l2(h12, sldj, reverse=True)
83 | h12 = squeeze(h12, reverse=True)
84 |
85 | # Sample z3
86 | mu_l3, logsigma_l3 = nn_theta.l3(t0_zi.l3, h12)
87 | g3 = Normal(loc=mu_l3, scale=temperature * torch.exp(logsigma_l3))
88 | z3_sample = g3.sample()
89 |
90 | x_t = torch.cat([h12, z3_sample], dim=1)
91 | x_t, sldj = glow.l3(x_t, sldj, reverse=True)
92 | x_t = squeeze(x_t, reverse=True)
93 |
94 | x_t = torch.sigmoid(x_t)
95 |
98 |
99 | torchvision.utils.save_image(x_t, 'samples_100/sample{}_{}.png'.format(m, n+1))
100 |
101 | assert context_frame.shape == x_t.shape
102 | context_frame = x_t.clone()
103 |
104 |
105 | def main():
106 | import torch.nn as nn
107 |
108 | seed = 123
109 | random.seed(seed)
110 | np.random.seed(seed)
111 | torch.manual_seed(seed)
112 | torch.cuda.manual_seed_all(seed)
113 |
114 | tr = transforms.Compose([transforms.ToTensor()])
115 |     train_data = MovingObjects("train", tr, seed)
116 | train_loader = DataLoader(train_data,
117 | num_workers=1,
118 | batch_size=1,
119 | shuffle=False,
120 | pin_memory=True)
121 |
122 | in_chs = 3
123 | flow_l3 = nn.DataParallel(_Glow(in_channels=4 * in_chs, mid_channels=512, num_steps=24)).to(device)
124 | flow_l2 = nn.DataParallel(_Glow(in_channels=8 * in_chs, mid_channels=512, num_steps=24)).to(device)
125 | flow_l1 = nn.DataParallel(_Glow(in_channels=16 * in_chs, mid_channels=512, num_steps=24)).to(device)
126 |
127 | nntheta3 = nn.DataParallel(
128 | NNTheta(encoder_ch_in=4 * in_chs, encoder_mode='conv_net', h_ch_in=2 * in_chs,
129 | num_blocks=5)).to(device) # z1:2x32x32
130 | nntheta2 = nn.DataParallel(
131 | NNTheta(encoder_ch_in=8 * in_chs, encoder_mode='conv_net', h_ch_in=4 * in_chs,
132 | num_blocks=5)).to(device) # z2:4x16x16
133 | nntheta1 = nn.DataParallel(NNTheta(encoder_ch_in=16 * in_chs, encoder_mode='conv_net',
134 | num_blocks=5)).to(device)
135 |
136 | model_path = '/b_test/azimi/results/VideoFlow/SMovement/exp10/sacred/snapshots/109.pth'
137 | if True:
138 | print('model loading ...')
139 | flow_l3.load_state_dict(torch.load(model_path)['glow_l3'])
140 | flow_l2.load_state_dict(torch.load(model_path)['glow_l2'])
141 | flow_l1.load_state_dict(torch.load(model_path)['glow_l1'])
142 | nntheta3.load_state_dict(torch.load(model_path)['nn_theta_l3'])
143 | nntheta2.load_state_dict(torch.load(model_path)['nn_theta_l2'])
144 | nntheta1.load_state_dict(torch.load(model_path)['nn_theta_l1'])
145 |
146 | glow = Glow(l3=flow_l3, l2=flow_l2, l1=flow_l1)
147 | nn_theta = NN_Theta(l3=nntheta3, l2=nntheta2, l1=nntheta1)
148 |
149 |     context = next(iter(train_loader)).to(device)
150 | sample_100(context, glow, nn_theta)
151 |
152 |
153 | if __name__ == "__main__":
154 | main()
155 |
--------------------------------------------------------------------------------
/shell_util.py:
--------------------------------------------------------------------------------
1 | import torch, os
2 |
3 | class AverageMeter(object):
4 | """Computes and stores the average and current value.
5 |
6 | Adapted from: https://github.com/pytorch/examples/blob/master/imagenet/train.py
7 | """
8 | def __init__(self):
9 | self.val = 0.
10 | self.avg = 0.
11 | self.sum = 0.
12 | self.count = 0.
13 |
14 | def reset(self):
15 | self.val = 0.
16 | self.avg = 0.
17 | self.sum = 0.
18 | self.count = 0.
19 |
20 | def update(self, val, n=1):
21 | self.val = val
22 | self.sum += val * n
23 | self.count += n
24 | self.avg = self.sum / self.count
25 |
26 |
27 | def save_model(glow, nn_theta, optimizer, epoch, save_path):
28 | if not os.path.exists(save_path + 'snapshots/'):
29 | os.mkdir(save_path + 'snapshots/')
30 |
31 | torch.save({
32 | 'glow_l3': glow.l3.state_dict(),
33 | 'glow_l2': glow.l2.state_dict(),
34 | 'glow_l1': glow.l1.state_dict(),
35 | 'nn_theta_l3': nn_theta.l3.state_dict(),
36 | 'nn_theta_l2': nn_theta.l2.state_dict(),
37 | 'nn_theta_l1': nn_theta.l1.state_dict(),
38 | 'optimizer': optimizer.state_dict()
39 | }, save_path+'snapshots/{}.pth'.format(epoch))
40 |
41 |
42 |
43 | def track_grads(nn_theta, glow, step, writer):
44 |     """Log min/max gradients per sub-network. `writer` is a SummaryWriter,
45 |     passed in explicitly since this module does not create one."""
46 |     modules = [('data/nntheta1', nn_theta.l1), ('data/nntheta2', nn_theta.l2),
47 |                ('data/nntheta3', nn_theta.l3), ('data/glow1', glow.l1),
48 |                ('data/glow2', glow.l2), ('data/glow3', glow.l3)]
49 |     for tag, module in modules:
50 |         for name, param in module.named_parameters():
51 |             if param.requires_grad and param.grad is not None:
52 |                 writer.add_scalar(tag, torch.max(param.grad.data), step)
53 |                 writer.add_scalar(tag, torch.min(param.grad.data), step)
73 |
--------------------------------------------------------------------------------
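
`AverageMeter` keeps a running weighted average, which is what `train.py` feeds to the progress bar:

    from shell_util import AverageMeter

    meter = AverageMeter()
    meter.update(2.0, n=4)   # batch of 4 with mean loss 2.0
    meter.update(1.0, n=4)   # batch of 4 with mean loss 1.0
    print(meter.avg)         # 1.5
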
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
3 | import torchvision
4 | import torch.nn as nn
5 | from torchvision import transforms
6 | from torch.utils.data import DataLoader
7 | import torch.optim.lr_scheduler as sched
8 | import torch.nn.utils as utils
9 | from torch.distributions.normal import Normal
10 |
11 | from collections import namedtuple
12 | from tqdm import tqdm
13 |
14 | from glow import _Glow, PreProcess, squeeze
15 | from modules import *
16 | from shell_util import AverageMeter, save_model
17 | from optim_util import bits_per_dim
18 | from dataloader import MovingObjects
19 |
20 | from tensorboardX import SummaryWriter
21 | from sacred import Experiment
22 | from sacred.observers import FileStorageObserver
23 |
24 | global_step = 0
25 |
26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27 | pre_process = PreProcess().to(device)
28 |
29 | Z_splits = namedtuple('Z_splits', 'l3 l2 l1')
30 | Glow = namedtuple('Glow', 'l3 l2 l1')
31 | NN_Theta = namedtuple('NNTheta', 'l3 l2 l1')
32 |
33 | PATH = './sacred/'
34 | writer = SummaryWriter(PATH)
35 | ex = Experiment()
36 | ex.observers.append(FileStorageObserver.create(PATH))
37 |
38 |
39 | @ex.config
40 | def config():
41 | tr_conf = {
42 | 'encoder_mode': 'conv_net',
43 | 'enc_depth': 5,
44 | 'n_epoch': 600,
45 | 'b_s': 26,
46 | 'lr': 1e-4,
47 | 'k': 256,
48 | 'input_channels': 3,
49 | 'resume': True,
50 | 'starting_epoch': 56
51 | }
52 |
53 |
54 | def train_smovement(train_loader, glow, nn_theta, loss_fn, optimizer, scheduler, epoch):
55 | print("ID: exp12_1 testing lr 1e-4 and only one step movement, no glow loss with random patch")
56 | global global_step
57 | loss_meter = AverageMeter()
58 | # loss_fn_glow = GlowLoss()
59 | for net in glow:
60 | net.train()
61 | for net in nn_theta:
62 | net.train()
63 |
64 | with tqdm(total=len(train_loader.dataset)) as progress_bar:
65 | for itr, sequence in enumerate(train_loader):
66 | sequence = sequence.to(device)
67 | b_s = sequence.size(0)
68 |
69 | # start_index = torch.LongTensor(1).random_(0, 2)
70 | # random_patch = sequence[:, start_index:start_index + 2, :, :, :]
71 |
72 | random_patch = []
73 | for n in range(b_s):
74 | start_index = torch.LongTensor(1).random_(0, 2)
75 | random_patch.append(sequence[n, start_index:start_index + 2, :, :, :])
76 | random_patch = torch.stack(random_patch, dim=0)
77 |
78 | t0_zi, _, sldj_0 = flow_forward(random_patch[:, 0, :, :, :], glow)
79 | # z_glow = recover_z_shape(t0_zi)
80 | # loss_glow = loss_fn_glow(z_glow, sldj_0)
81 |
82 | t1_zi_out, t1_zi_h, sldj_1 = flow_forward(random_patch[:, 1, :, :, :], glow)
83 | h12 = t1_zi_h.l3
84 |
85 | mu_l3, logsigma_l3 = nn_theta.l3(t0_zi.l3, h12)
86 | g3 = Normal(loc=mu_l3, scale=torch.exp(logsigma_l3))
87 |
88 | h1 = t1_zi_h.l2
89 | mu_l2, logsigma_l2 = nn_theta.l2(t0_zi.l2, h1)
90 | g2 = Normal(loc=mu_l2, scale=torch.exp(logsigma_l2))
91 |
92 | mu_l1, logsigma_l1 = nn_theta.l1(t0_zi.l1)
93 | g1 = Normal(loc=mu_l1, scale=torch.exp(logsigma_l1))
94 |
95 | total_loss = loss_fn(g1, g2, g3, z=t1_zi_out, sldj=sldj_1,
96 | input_dim=random_patch[:, 1, :, :, :].size())
97 |
98 | # total_loss = loss #+ loss_glow
99 | total_loss.backward()
100 |
101 | clip_grad_value(optimizer)
102 |
103 | optimizer.step()
104 | optimizer.zero_grad()
105 | if scheduler is not None:
106 | scheduler.step(global_step)
107 |
108 | loss_meter.update(total_loss.item(), b_s)
109 | progress_bar.set_postfix(nll=loss_meter.avg,
110 | bpd=bits_per_dim(random_patch[:, 1, :, :, :], loss_meter.avg),
111 | lr=optimizer.param_groups[0]['lr'])
112 | progress_bar.update(b_s)
113 | global_step += 1
114 |
115 | print("global step:", global_step)
116 |
117 | torch.cuda.empty_cache()
118 | #save_model(glow, nn_theta, optimizer, scheduler, epoch, PATH)
119 | save_model(glow, nn_theta, optimizer, epoch, PATH)
120 | writer.add_scalar('data/train_loss', loss_meter.avg, epoch)
121 | writer.add_scalar('data/lr', get_lr(optimizer), epoch)
122 |
123 |     context = next(iter(train_loader)).to(device)
124 | flow_inverse_smovement(context, glow, nn_theta, epoch)
125 |
126 |
127 | def recover_z_shape(t_z):
128 | z1 = squeeze(t_z.l1, reverse=True)
129 | z2 = torch.cat([z1, t_z.l2], dim=1)
130 | z2 = squeeze(z2, reverse=True)
131 | z3 = torch.cat([z2, t_z.l3], dim=1)
132 | z3 = squeeze(z3, reverse=True)
133 | return z3
134 |
135 |
136 | def clip_grad_value(optimizer, max_val=10.):
137 | for group in optimizer.param_groups:
138 | utils.clip_grad_value_(group['params'], max_val)
139 |
140 |
141 | def flow_forward(x, flow):
142 | if x.min() < 0 or x.max() > 1:
143 | raise ValueError('Expected x in [0, 1], got min/max {}/{}'
144 | .format(x.min(), x.max()))
145 |
146 | # pre-process
147 | x, sldj = pre_process(x)
148 | # L3
149 | x3 = squeeze(x, reverse=False)
150 | x3, sldj = flow.l3(x3, sldj, reverse=False)
151 | x3, x_split3 = x3.chunk(2, dim=1)
152 | # L2
153 | x2 = squeeze(x3, reverse=False)
154 | x2, sldj = flow.l2(x2, sldj, reverse=False)
155 | x2, x_split2 = x2.chunk(2, dim=1)
156 | # L1
157 | x1 = squeeze(x2, reverse=False)
158 | x1, sldj = flow.l1(x1, sldj)
159 |
160 | partition_out = Z_splits(l3=x_split3, l2=x_split2, l1=x1)
161 | partition_h = Z_splits(l3=x3, l2=x2, l1=None)
162 |
163 | return partition_out, partition_h, sldj
164 |
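165 | # Shape walk-through for a 64x64 RGB input (B, 3, 64, 64):
166 | #   squeeze -> (B, 12, 32, 32); after flow.l3, chunk -> x3, x_split3: (B, 6, 32, 32) each
167 | #   squeeze(x3) -> (B, 24, 16, 16); after flow.l2, chunk -> x2, x_split2: (B, 12, 16, 16) each
168 | #   squeeze(x2) -> (B, 48, 8, 8); flow.l1 output x1: (B, 48, 8, 8)
169 | # These widths match the encoder_ch_in / h_ch_in values used to build NNTheta in main().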
165 |
166 | def flow_inverse_smovement(context, glow, nn_theta, epoch):
167 | for net in glow:
168 | net.eval()
169 | for net in nn_theta:
170 | net.eval()
171 |
172 | # pre-process the context frame
173 | b_s = context.size(0)
174 | context_frame = context[:, 0, ...]
175 | t0_zi, _, _ = flow_forward(context_frame, glow)
176 |
177 | mu_l1, logsigma_l1 = nn_theta.l1(t0_zi.l1)
178 | g1 = Normal(loc=mu_l1, scale=torch.exp(logsigma_l1))
179 | z1_sample = g1.sample()
180 | print("z1", z1_sample.shape)
181 | sldj = torch.zeros(b_s, device=device)
182 |
183 | # Inverse L1
184 | h1, sldj = glow.l1(z1_sample, sldj, reverse=True)
185 | h1 = squeeze(h1, reverse=True)
186 |
187 | # Sample z2
188 | mu_l2, logsigma_l2 = nn_theta.l2(t0_zi.l2, h1)
189 | g2 = Normal(loc=mu_l2, scale=torch.exp(logsigma_l2))
190 | z2_sample = g2.sample()
191 | h12 = torch.cat([h1, z2_sample], dim=1)
192 | h12, sldj = glow.l2(h12, sldj, reverse=True)
193 | h12 = squeeze(h12, reverse=True)
194 |
195 | # Sample z3
196 | mu_l3, logsigma_l3 = nn_theta.l3(t0_zi.l3, h12)
197 | g3 = Normal(loc=mu_l3, scale=torch.exp(logsigma_l3))
198 | z3_sample = g3.sample()
199 |
200 | x_t = torch.cat([h12, z3_sample], dim=1)
201 | x_t, sldj = glow.l3(x_t, sldj, reverse=True)
202 | x_t = squeeze(x_t, reverse=True)
203 |
204 | x_t = torch.sigmoid(x_t)
205 |
206 |     os.makedirs('samples/', exist_ok=True)
207 |     torchvision.utils.save_image(x_t, 'samples/sample{}.png'.format(epoch))
207 | torchvision.utils.save_image(context[:, 0, ...].squeeze(), 'samples/context{}.png'.format(epoch))
208 | torchvision.utils.save_image(context[:, 1, ...].squeeze(), 'samples/gt{}.png'.format(epoch))
209 |
210 |
211 | def get_lr(optimizer):
212 | for param_group in optimizer.param_groups:
213 | return param_group['lr']
214 |
215 |
216 | @ex.automain
217 | def main(tr_conf):
218 | import torch.nn as nn
219 |
220 | seed = 12345
221 | random.seed(seed)
222 | np.random.seed(seed)
223 | torch.manual_seed(seed)
224 | torch.cuda.manual_seed_all(seed)
225 |
226 | tr = transforms.Compose([transforms.ToTensor()])
227 | train_data = MovingObjects("train", tr, seed)
228 | train_loader = DataLoader(train_data,
229 | num_workers=tr_conf['b_s'],
230 | batch_size=tr_conf['b_s'],
231 | shuffle=True,
232 | pin_memory=True)
233 |
234 | param_list = []
235 | in_chs = tr_conf['input_channels']
236 | flow_l3 = nn.DataParallel(_Glow(in_channels=4 * in_chs, mid_channels=512, num_steps=24)).to(device)
237 | flow_l2 = nn.DataParallel(_Glow(in_channels=8 * in_chs, mid_channels=512, num_steps=24)).to(device)
238 | flow_l1 = nn.DataParallel(_Glow(in_channels=16 * in_chs, mid_channels=512, num_steps=24)).to(device)
239 |
240 | nntheta3 = nn.DataParallel(
241 | NNTheta(encoder_ch_in=4 * in_chs, encoder_mode=tr_conf['encoder_mode'], h_ch_in=2 * in_chs,
242 | num_blocks=tr_conf['enc_depth'])).to(device) # z1:2x32x32
243 | nntheta2 = nn.DataParallel(
244 | NNTheta(encoder_ch_in=8 * in_chs, encoder_mode=tr_conf['encoder_mode'], h_ch_in=4 * in_chs,
245 | num_blocks=tr_conf['enc_depth'])).to(device) # z2:4x16x16
246 | nntheta1 = nn.DataParallel(NNTheta(encoder_ch_in=16 * in_chs, encoder_mode=tr_conf['encoder_mode'],
247 | num_blocks=tr_conf['enc_depth'])).to(device)
248 |
249 | model_path = '/b_test/azimi/results/VideoFlow/SMovement/exp12_2/sacred/snapshots/55.pth'
250 | if tr_conf['resume']:
251 | print('model loading ...')
252 | flow_l3.load_state_dict(torch.load(model_path)['glow_l3'])
253 | flow_l2.load_state_dict(torch.load(model_path)['glow_l2'])
254 | flow_l1.load_state_dict(torch.load(model_path)['glow_l1'])
255 | nntheta3.load_state_dict(torch.load(model_path)['nn_theta_l3'])
256 | nntheta2.load_state_dict(torch.load(model_path)['nn_theta_l2'])
257 | nntheta1.load_state_dict(torch.load(model_path)['nn_theta_l1'])
258 | print("****LOAD THE OPTIMIZER")
259 |
260 | glow = Glow(l3=flow_l3, l2=flow_l2, l1=flow_l1)
261 | nn_theta = NN_Theta(l3=nntheta3, l2=nntheta2, l1=nntheta1)
262 |
263 | for f_level in glow:
264 | param_list += list(f_level.parameters())
265 |
266 |     for net in nn_theta:
267 |         param_list += list(net.parameters())
268 |
269 | loss_fn = NLLLossVF()
270 |
271 |     optimizer = torch.optim.Adam(param_list, lr=tr_conf['lr'])
272 |     if tr_conf['resume']:
273 |         optimizer.load_state_dict(torch.load(model_path)['optimizer'])
274 |     optimizer.zero_grad()
274 |
275 | # scheduler_step = sched.StepLR(optimizer, step_size=1, gamma=0.99)
276 | # linear_decay = sched.LambdaLR(optimizer, lambda s: 1. - s / 150000. )
277 | # linear_decay.step(global_step)
278 |
279 | # scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / 10000))
280 | # optimizer.load_state_dict(torch.load(model_path)['optimizer'])
281 |
282 | for epoch in range(tr_conf['starting_epoch'], tr_conf['n_epoch']):
283 | print("the learning rate for epoch {} is {}".format(epoch, get_lr(optimizer)))
284 | train_smovement(train_loader, glow, nn_theta, loss_fn, optimizer, None, epoch)
285 | #scheduler_step.step()
286 |
--------------------------------------------------------------------------------