├── example.gif
├── LICENSE
├── src
│   ├── utils.py
│   ├── visual_encoders.py
│   ├── legacy
│   │   ├── data_prep.py
│   │   ├── text_processing.py
│   │   ├── main.py
│   │   ├── trainer.py
│   │   └── resnet.py
│   ├── convgru.py
│   ├── blocks.py
│   ├── data_prep2.py
│   ├── text_processing2.py
│   ├── models.py
│   ├── process3-5.ipynb
│   └── pipeline.ipynb
├── references.md
└── README.md
/example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukoshkin/text2video/HEAD/example.gif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Vladislav Lukoshkin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import autograd 3 | 4 | def to_video(tensor): 5 | """ 6 | Takes generator output and converts it to 7 | the format which can later be written by 8 | SummaryWriter's instance 9 | 10 | input: tensor of shape (N, C, D, H, W) 11 | obtained passed through nn.Tanh 12 | """ 13 | generated = (tensor + 1) / 2 * 255 14 | generated = generated.cpu() \ 15 | .numpy() \ 16 | .transpose(0, 2, 1, 3, 4) 17 | 18 | return generated.astype('uint8') 19 | 20 | 21 | def selectFramesRandomly(N, k): 22 | """ 23 | N - total number of frames 24 | k - number to choose 25 | """ 26 | frame_ids = torch.multinomial(torch.ones(N), k) 27 | frame_ids, _ = torch.sort(frame_ids) 28 | return frame_ids 29 | 30 | 31 | def calc_grad_penalty(real_samples, fake_samples, net_D, condition): 32 | """ 33 | Evaluates the D's gradient penalty and allows other gradients 34 | to backpropogate through the penalty term 35 | Args: 36 | real_samples - a tensor (presumably, without `grad` attribute) 37 | fake_samples - tensor of the same shape as `real_samples` 38 | net_D - conditional discriminator 39 | """ 40 | alpha = real_samples.new( 41 | real_samples.size(0), 42 | *([1]*(real_samples.dim()-1)) 43 | ).uniform_().expand(*real_samples.shape) 44 | 45 | inputs = alpha * real_samples + (1-alpha) * fake_samples.detach() 46 | inputs.requires_grad_(True) 47 | outputs = net_D(inputs, condition) 48 | jacobian = autograd.grad( 49 | outputs=outputs, inputs=inputs, 50 | grad_outputs=torch.ones_like(outputs), 51 | create_graph=True)[0] 52 | 53 | # flatten each sample grad. and apply 2nd norm to it 54 | jacobian = jacobian.view(jacobian.size(0), -1) 55 | return (jacobian.norm(dim=1)**2).mean() 56 | 57 | 58 | def vanilla_DLoss(M, GM, E, NE, net_D): 59 | pos, neg, gen = map(net_D, (M, M, GM), (E, NE, E)) 60 | L = pos.log().mean() + (-neg).log1p().mean() + (-gen).log1p().mean() 61 | return -.33 * L 62 | 63 | def vanilla_GLoss1(GM, E, net_D): 64 | return (-net_D(GM, E)).log1p().mean() 65 | 66 | def vanilla_GLoss2(GM, E, net_D): 67 | return -net_D(GM, E).log().mean() 68 | 69 | 70 | eps = torch.tensor(1e-12) 71 | 72 | def batchGAN_DLoss(u, multibatch, net_D): 73 | E, M = multibatch 74 | v = net_D(M, E).view(len(u), -1).mean(1) 75 | L = u * ((u+eps)/(v+eps)).log() + (1-u) * ((1-u+eps)/(1-v+eps)).log() 76 | return L.mean() 77 | 78 | def batchMGAN_DLoss(u, multibatch, net_D): 79 | E, M = multibatch 80 | v = net_D(M, E) 81 | u = u.repeat_interleave(len(v)//len(u)) 82 | L = u * ((u+eps)/(v+eps)).log() + (1-u) * ((1-u+eps)/(1-v+eps)).log() 83 | return L.mean() 84 | 85 | def batchGAN_GLoss(multibatch, net_D): 86 | E, M = multibatch 87 | return vanilla_GLoss2(M, E, net_D) 88 | -------------------------------------------------------------------------------- /src/visual_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils import spectral_norm 4 | 5 | from blocks import DBlock 6 | from convgru import ConvGRU 7 | from functools import partial 8 | 9 | 10 | def SN(sn): 11 | return spectral_norm if sn else lambda x: x 12 | 13 | 14 | class VideoEncoder(nn.Module): 15 | def __init__(self, in_colors=3, base_width=32, bn=True, sn=False): 16 | super().__init__() 17 | block2d = partial(DBlock, '2d', bn=bn, sn=sn) 18 | block3d = partial(DBlock, '3d', bn=bn, sn=sn) 19 
| self.downsampler1 = nn.Sequential( 20 | SN(sn)(nn.Conv3d(in_colors, base_width, 1)), 21 | block3d(base_width, base_width*2, 2), 22 | block3d(base_width*2, base_width*4, (1,2,2))) 23 | self.cgru = ConvGRU( 24 | base_width*4, base_width*4, 3, spectral_norm=sn) 25 | self.downsampler2 = nn.Sequential( 26 | block2d(base_width*4, base_width*8, 2), 27 | block2d(base_width*8, base_width*16, 2), 28 | block2d(base_width*16, base_width*32, 2)) 29 | 30 | def forward(self, video): 31 | H = self.downsampler1(video) 32 | _, last = self.cgru(H) 33 | H = self.downsampler2(last) 34 | 35 | return H.view(H.size(0), -1) 36 | 37 | 38 | class ProjectionVideoDiscriminator(VideoEncoder): 39 | def __init__(self, cond_size, in_colors=3, base_width=32, logits=True): 40 | super().__init__(in_colors, base_width, bn=False, sn=True) 41 | self.proj = nn.Sequential( 42 | SN(True)(nn.Linear(cond_size, base_width*32)), 43 | nn.LeakyReLU(.2, inplace=True)) 44 | self.pool = SN(True)(nn.Linear(base_width*32, 1)) 45 | if logits: self.activation = nn.Sequential() 46 | else: self.activation = torch.sigmoid 47 | 48 | def forward(self, video, embedding): 49 | E = self.proj(embedding) 50 | H = super().forward(video) 51 | out = self.pool(H).squeeze(1) 52 | out += torch.einsum('ij,ij->i', E, H) 53 | 54 | return self.activation(out) 55 | 56 | 57 | class ImageEncoder(nn.Module): 58 | def __init__(self, in_colors=3, base_width=32, bn=True, sn=False): 59 | super().__init__() 60 | block2d = partial(DBlock, '2d', bn=bn, sn=sn) 61 | self.downsampler = nn.Sequential( 62 | SN(sn)(nn.Conv2d(in_colors, base_width, 1)), 63 | block2d(base_width, base_width*2, 2), 64 | block2d(base_width*2, base_width*4, 2), 65 | block2d(base_width*4, base_width*8, 2), 66 | block2d(base_width*8, base_width*16, 2), 67 | block2d(base_width*16, base_width*32, 2)) 68 | 69 | def forward(self, images): 70 | """ 71 | images 72 | """ 73 | k = images.size(1) 74 | images = torch.flatten(images, 0, 1) 75 | H = self.downsampler(images) 76 | 77 | # images.shape (N, k, C, H, W) 78 | # images.shape (N*k, C, H, W) 79 | # H.shape (N*k, base_width*32, 1, 1) 80 | # output.shape (N, k, base_width*32) 81 | 82 | return H.view(H.size(0)//k, k, -1) 83 | 84 | 85 | class ProjectionImageDiscriminator(ImageEncoder): 86 | def __init__(self, cond_size, in_colors=3, base_width=32, logits=True): 87 | super().__init__(in_colors, base_width, bn=False, sn=True) 88 | self.proj = nn.Sequential( 89 | SN(True)(nn.Linear(cond_size, base_width*32)), 90 | nn.LeakyReLU(.2, inplace=True)) 91 | self.pool = SN(True)(nn.Linear(base_width*32, 1)) 92 | if logits: self.activation = nn.Sequential() 93 | else: self.activation = torch.sigmoid 94 | 95 | def forward(self, video, embedding): 96 | E = self.proj(embedding) 97 | H = super().forward(video) 98 | out = self.pool(H).sum([1, 2]) 99 | out += torch.einsum('ij,ikj->i', E, H) 100 | 101 | return self.activation(out) 102 | -------------------------------------------------------------------------------- /src/legacy/data_prep.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import pickle 3 | import numpy as np 4 | 5 | from tqdm import tqdm 6 | from pathlib import Path 7 | from smart_open import open 8 | from torch.utils.data import Dataset 9 | from text_processing import doTextPart, sen2vec 10 | 11 | 12 | class LabeledVideoDataset(Dataset): 13 | def __init__( 14 | self, path, cache, 15 | video_shape=(32, 64, 64, 3), step=2, 16 | min_word_freq = 2, check_spell=False, 17 | transform=None, ext='webm'): 18 | 
self.transform = transform if transform else lambda x: x 19 | 20 | path = Path(path) 21 | file_name = path.stem 22 | 23 | cache = Path(cache) 24 | if (cache / f"{file_name}.db").exists(): 25 | with open(cache / f"{file_name}.db", 'rb') as fp: 26 | self.data = pickle.load(fp) 27 | self.i2i = pickle.load(fp) 28 | else: 29 | self.data = [] 30 | cache.mkdir(parents=True, exist_ok=True) 31 | 32 | max_len, t2i, df = doTextPart ( 33 | path, cache, 34 | min_word_freq, 35 | check_spell 36 | ) 37 | index = 0 38 | self.i2i = {} 39 | 40 | mult = [] 41 | corrupted = 0 42 | D, H, W, C = video_shape 43 | folder = path.parents[1] 44 | 45 | pbar = tqdm(df.iterrows(), "Preparing dataset", len(df)) 46 | for _, sample in pbar: 47 | video = folder / 'video' / f"{sample['id']}.{ext}" 48 | ViCap = cv2.VideoCapture(str(video)) 49 | _D = ViCap.get(cv2.CAP_PROP_FRAME_COUNT) 50 | 51 | if int(_D) < D : continue 52 | mult.append(int(_D) // D) 53 | 54 | CNT = 0 55 | frames = [] 56 | success = True 57 | while success and (CNT < D * mult[-1]): 58 | success, image = ViCap.read() 59 | if success: 60 | image = cv2.resize( 61 | image, (H, W), 62 | interpolation=cv2.INTER_AREA) 63 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 64 | frames += [image] 65 | CNT += 1 66 | 67 | ViCap.release() 68 | cv2.destroyAllWindows() 69 | 70 | if CNT == D * mult[-1]: 71 | frames = np.array( 72 | frames, 'float32').transpose(3,0,1,2) / 255 73 | sen_len = len(sample['label']) 74 | numerated = sen2vec(sample['label'], t2i, max_len) 75 | self.data.append( 76 | (sen_len, numerated, frames[:, ::step * mult[-1]])) 77 | self.i2i[sample['id']] = index 78 | index += 1 79 | else: 80 | corrupted += 1 81 | mult.pop() 82 | 83 | self.data = np.array( 84 | self.data, 85 | [('', 'int64'), 86 | ('', 'int64', max_len), 87 | ('', 'float32', (C, D//step, H, W))]) 88 | 89 | print('No of corrupted videos', corrupted) 90 | print(f'Caching database to {file_name}.db') 91 | with open(cache / f'{file_name}.db', 'wb') as fp: 92 | pickle.dump(self.data, fp) 93 | pickle.dump(self.i2i, fp) 94 | print('Done!') 95 | 96 | def __getitem__(self, index): 97 | sen_len, label, video = self.data[index] 98 | 99 | return {'slens': sen_len, 100 | 'label': label, 101 | 'video': self.transform(video)} 102 | 103 | def __len__(self): 104 | return len(self.data) 105 | 106 | def getById(self, video_ids): 107 | ids = map(self.i2i.get, video_ids) 108 | selected = np.take(self.data, list(ids)) 109 | 110 | return np.rec.array(selected) 111 | -------------------------------------------------------------------------------- /src/legacy/text_processing.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import copy 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from tqdm import tqdm 8 | from pathlib import Path 9 | from smart_open import open 10 | from collections import Counter 11 | from nltk import wordpunct_tokenize 12 | from spellchecker import SpellChecker 13 | 14 | 15 | def selectTemplates(path, templates, new_name): 16 | """ 17 | Remove from the database (json file specified by 'path' argument) 18 | label categories which are not in 'templates' list. The new file 19 | is created under the same folder as the old one. 
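For example (hypothetical paths):
selectTemplates('../data/labels.json',
                ['Pushing [something] from left to right'],
                '1-template.pkl')
would write ../data/1-template.pkl containing only that category
and return the new path.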
20 | """ 21 | path = Path(path) 22 | df = pd.read_json(path) 23 | 24 | mask = df.template.isin(templates) 25 | new_df = df[mask] 26 | 27 | new_df.index = np.arange(mask.sum()) 28 | new_path = path.parent / new_name 29 | new_df.to_pickle(new_path) 30 | 31 | return new_path 32 | 33 | 34 | def doTextPart(path, cache, min_freq=2, check_spell=False): 35 | path = Path(path) 36 | file_name = path.stem 37 | ext = path.suffix[1:] 38 | 39 | cache = Path(cache) 40 | if ((cache / 'vocab.pkl').exists() and 41 | (cache / f'{file_name}.pkl').exists()): 42 | df = pd.read_pickle(cache / f'{file_name}.pkl') 43 | with open(cache / 'vocab.pkl', 'rb') as fp: 44 | t2i = pickle.load(fp) 45 | max_len = int(fp.readline()) 46 | 47 | return max_len, t2i, df 48 | 49 | cache.mkdir(parents=True, exist_ok=True) 50 | 51 | if ext == 'json': 52 | df = pd.read_json(path) 53 | elif ext == 'pkl': 54 | df = pd.read_pickle(path) 55 | else: 56 | raise TypeError("'json' and 'pkl' are only supported") 57 | 58 | sentences = list(df.label.apply(wordpunct_tokenize).values) 59 | token_counts = Counter() 60 | token_counts.update(np.concatenate(sentences)) 61 | 62 | vague_words = [ 63 | x for x in token_counts.keys() 64 | if token_counts[x] <= min_freq 65 | ] 66 | 67 | if check_spell: 68 | spell = SpellChecker() 69 | checked_words = [] 70 | for sen in tqdm(sentences, 'Spell-check'): 71 | for i, w in enumerate(sen): 72 | if w in vague_words: 73 | sen[i] = spell.correction(w) 74 | checked_words.append(sen[i]) 75 | 76 | token_counts.update(checked_words) 77 | 78 | vague_words = [ 79 | x for x in token_counts.keys() 80 | if token_counts[x] <= min_freq 81 | ] 82 | 83 | mended = copy.deepcopy(sentences) 84 | pbar = tqdm(sentences, "Removing 'bad samples'") 85 | k = 0 86 | for i, sen in enumerate(pbar): 87 | flag = True 88 | for w in sen: 89 | if w in vague_words: 90 | df.drop(i, inplace=True) 91 | del mended[k] 92 | flag = False 93 | break 94 | if flag: 95 | k += 1 96 | 97 | df.label = mended 98 | df.to_pickle(cache / f'{file_name}.pkl') 99 | 100 | for w in vague_words: 101 | token_counts.pop(w) 102 | 103 | tokens = ['PAD'] + list(token_counts.keys()) 104 | t2i = {t: i for i, t in enumerate(tokens)} 105 | max_len = max(map(len, mended)) 106 | with open(cache / 'vocab.pkl', 'wb') as fp: 107 | pickle.dump(t2i, fp) 108 | fp.write(b'%d' % max_len) 109 | 110 | return (max_len, t2i, df) 111 | 112 | 113 | def getGloveEmbeddings(folder, cache, t2i, emb_size=50): 114 | cache = Path(cache) 115 | file_path = cache / 'emb_matrix.npy' 116 | if file_path.exists(): 117 | return np.load(file_path) 118 | 119 | folder = Path(folder) 120 | with open(folder / f'glove.6B.{emb_size}d.txt') as fp: 121 | raw_data = fp.readlines() 122 | 123 | glove = {} 124 | pbar = tqdm(raw_data, 'Reading glove embeddings') 125 | for line in pbar: 126 | t, *v = line.split() 127 | glove[t] = np.array(v, 'float32') 128 | 129 | emb_matrix = np.empty((len(t2i), emb_size), 'float32') 130 | for t, i in t2i.items(): 131 | try: 132 | emb_matrix[i] = glove[t] 133 | except KeyError: 134 | emb_matrix[i] = .6 * np.random.randn(emb_size) 135 | emb_matrix[0] = 0 136 | 137 | np.save(cache / 'emb_matrix', emb_matrix) 138 | 139 | return emb_matrix 140 | 141 | 142 | def sen2vec(sen, t2i, max_len): 143 | """ 144 | Converts a sentence to a sequence of positive 145 | integers of length 'max_len' (according to 't2i' 146 | dictionary), padded with zeros where necessary 147 | 148 | Output type: int64 - necessary for nn.Embedding 149 | """ 150 | numerated = np.zeros(max_len, 'int') 151 | filling = [t2i[w] 
for w in sen] 152 | numerated[:len(filling)] = filling 153 | 154 | return numerated 155 | -------------------------------------------------------------------------------- /references.md: -------------------------------------------------------------------------------- 1 | # References to Some of the Ideas Used in This Work 2 | 3 | **0-GP**: [Improving Generalization and Stability of Generative Adversarial Networks](https://arxiv.org/pdf/1902.03984.pdf) 4 | - Gradient exploding in the discriminator\* can lead to mode collapse in the generator (math. justification in the article) 5 | - The number of modes in the distribution grows linearly with the size of the discriminator -> higher capacity discriminators are needed for 6 | better approximation of the target distribution. 7 | - Generalization is guaranted if the discriminator set is small enough. 8 | - To smooth out the loss surface one can build a discriminator that makes the judgement on a mixed batch of fake and real samples, determining 9 | the proportion between them (Lucas et al., 2018) 10 | - VEEGAN (Srivastava et al., 2017) uses the inverse mapping of the generator to map the data to the prior distribution. The mismatch between 11 | the inverse mapping and the prior is used to detect mode collapse. It is not able to help, if the generator can remember the entire dataset 12 | - Generalization capability of the discriminator can be estimated by measuring the difference between its performance on the training dataset 13 | and a held-out dataset 14 | - When generator starts to produce samples of the same quality as the real ones, we come to the situation where the discriminator has to deal 15 | with mislabeled data: generated samples, regardless of how good they are, are still labeled as bad ones, so the discriminator trained on such 16 | dataset will overfit and not be able to teach the generator 17 | - Heuristically, overfitting can be alleviated by limiting the number of discriminator updates per generator update. Goodfellow et al. (2014) 18 | recommended to update the discriminator once every generator update 19 | - It is observed that the norm of the gradient w.r.t. the discriminator’s parameters decreases as fakes samples approach real samples. If the 20 | discriminator’s learning rate is fixed, then the number of gradient descent steps that the discriminator has to take to reach eps-optimal 21 | state should increase. *Alternating gradient descent with the same learning rate for discriminator and generator, and fixed number of 22 | discriminator updates per generator update (Fixed-Alt-GD) cannot maintain the (empirical) optimality of the discriminator*. In GANs trained 23 | with Two Timescale Update Rule (TTUR) (Heusel et al., 2017), the ratio between the learning rate of the discriminator and that of the 24 | generator goes to infinity as the iteration number goes to infinity. Therefore, the discriminator can learn much faster than the generator 25 | and might be able to maintain its optimality throughout the learning process. 26 | ____________________________________ 27 | **\*** in case of emperically optimal D 28 | NOTE: All references to the authors in the text block above 29 | are direct copies of references that can be found in [the article](https://arxiv.org/pdf/1902.03984.pdf) 30 | 31 | 32 | **CBN**: 33 | 1. [Modulating early visual processing by language](https://arxiv.org/pdf/1707.00683.pdf) 34 | 2. 
[A Learned Representation For Artistic Style](https://arxiv.org/pdf/1610.07629.pdf) 35 | **BN**: [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf) 36 | **ResBlocks**: [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) 37 | **ProjDisc**: [cGANs with Projection Discriminator](https://arxiv.org/pdf/1802.05637.pdf) 38 | **ConvGRU**: [Convolutional Gated Recurrent Networks for Video Segmentation](https://arxiv.org/pdf/1611.05435.pdf) 39 | **Basic Ideas for Text Encoders**: [Realistic Image Generation using Region-phrase Attention](https://arxiv.org/pdf/1902.05395.pdf) 40 | **D/G Blocks' Structure**: [Large Scale GAN Training for High Fidelity Natural Image Synthesis](https://arxiv.org/pdf/1809.11096.pdf) 41 | - Employing Spectral Normalization in G improves stability, allowing for fewer D steps per iteration. 42 | - Greater batch size can help dealing with mode collapse and impove the network performance, though it might lead to training collapse (NaNs) 43 | **Joint Structured Embeddings**: 44 | 1. [Learning Deep Representations of Fine-Grained Visual Descriptions](https://arxiv.org/pdf/1605.05395.pdf) 45 | 2. [also](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Akata_Evaluation_of_Output_2015_CVPR_paper.pdf) 46 | **Concatenate by Stacking**: [StackGAN++: Realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/pdf/1710.10916.pdf) 47 | **Self-Attention**: [A Structured Self-attentive Sentence Embedding](https://arxiv.org/pdf/1703.03130.pdf) 48 | -------------------------------------------------------------------------------- /src/legacy/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | main.py [options] [] 4 | 5 | path path to a json or pkl file with data 6 | cache folder where you store the processed data 7 | and the generator weights 8 | 9 | Options: 10 | --from-scratch remove cached files before training begins 11 | --device= stands for gpu node number 12 | --num_workers= number of cpu processes to load batches on gpu [default: 8] 13 | --batch-size= batch size (currently, all types of batches are of the same size) [default: 3] 14 | 15 | --noise when specified, Gaussian noise is added to both discriminators 16 | --sigma= if there is a noise, controls its standard deviation [default: .1] 17 | 18 | --video-length= original length of videos in the video batch [default: 32] 19 | --training-time= number of training epochs [default: 100000] 20 | 21 | --encoder= type of encoder: simple, mere, joint [default: mere] 22 | --pp= print period [default: 0] 23 | --lp= log period [default: 20] 24 | """ 25 | import docopt 26 | import pickle 27 | from pathlib import Path 28 | 29 | import torch 30 | from torch.utils.data import DataLoader 31 | from torch import optim 32 | 33 | import models 34 | from trainer import Trainer 35 | from data_prep import LabeledVideoDataset 36 | from text_processing import getGloveEmbeddings, selectTemplates 37 | 38 | 39 | if __name__ == "__main__": 40 | args = docopt.docopt(__doc__) 41 | print(args) 42 | 43 | cache = Path(args['']) 44 | cache = cache if cache else Path('../logdir') 45 | if args['--from-scratch']: 46 | if cache.exists(): 47 | for f in cache.glob('*'): 48 | f.unlink() 49 | 50 | device = args['--device'] 51 | device = torch.device(f'cuda:{device}' if device else 'cpu') 52 | 53 | templates = ['Pushing [something] from left to right'] 54 | 
new_path = selectTemplates( 55 | args[''], templates, 56 | f'{len(templates)}-template.pkl') 57 | 58 | vlen = int(args['--video-length']) 59 | video_dataset = LabeledVideoDataset( 60 | new_path, cache, (vlen, 64, 64, 3), check_spell=True) 61 | 62 | video_loader = DataLoader( 63 | video_dataset, int(args['--batch-size']), 64 | shuffle=True, num_workers=int(args['--num_workers']), 65 | pin_memory=True, drop_last=True) 66 | 67 | val_samples = [168029, 157604, 71563, 82109] 68 | val_samples = video_dataset.getById(val_samples) 69 | lens = torch.tensor(val_samples.f0, device=device) 70 | texts = torch.tensor(val_samples.f1, device=device) 71 | movies = torch.tensor(val_samples.f2, device=device) 72 | 73 | with open(cache / 'vocab.pkl', 'rb') as fp: 74 | t2i = pickle.load(fp) 75 | max_sen_len = int(fp.readline()) 76 | 77 | device = torch.device(device) 78 | emb_weights = getGloveEmbeddings('../embeddings', cache, t2i) 79 | emb_weights = torch.tensor(emb_weights, device=device) 80 | 81 | if args['--encoder'] == 'simple': 82 | emb_size = 50 83 | text_encoder = models.SimpleTextEncoder(emb_weights) 84 | elif args['--encoder'] == 'mere': 85 | emb_size = 64 86 | text_encoder = models.TextEncoder(emb_weights, proj=True) 87 | elif args['--encoder'] == 'joint': 88 | pass 89 | else: 90 | raise TypeError('Invalid encoder type') 91 | 92 | dim_Z = 50 93 | generator = models.VideoGenerator(dim_Z, emb_size) 94 | 95 | image_discriminator = models.ImageDiscriminator( 96 | cond_size=emb_size, noise=args['--noise'], 97 | sigma=float(args['--sigma'])) 98 | video_discriminator = models.VideoDiscriminator( 99 | cond_size=emb_size, noise=args['--noise'], 100 | sigma=float(args['--sigma'])) 101 | 102 | generator.to(device) 103 | text_encoder.to(device) 104 | image_discriminator.to(device) 105 | video_discriminator.to(device) 106 | 107 | dis_dict = {'image': image_discriminator, 108 | 'video': video_discriminator} 109 | 110 | opt_list = [ 111 | optim.Adam( 112 | generator.parameters(), lr=5e-5, 113 | betas=(.3, .999), weight_decay=1e-5), 114 | optim.Adam( 115 | dis_dict['image'].parameters(), lr=2e-4, 116 | betas=(.3, .999), weight_decay=1e-5), 117 | optim.Adam( 118 | dis_dict['video'].parameters(), lr=2e-4, 119 | betas=(.3, .999), weight_decay=1e-5), 120 | ] 121 | 122 | train_enc = (args['--encoder'] == 'mere') 123 | if train_enc: 124 | opt_list += [optim.Adam( 125 | text_encoder.parameters(), lr=2e-4, 126 | betas=(.3, .999), weight_decay=1e-5)] 127 | 128 | trainer = Trainer ( 129 | text_encoder, dis_dict, generator, 130 | opt_list, video_loader, cache, 131 | train_enc, int(args['--training-time']), 132 | ) 133 | trainer.train(lens, texts, movies, int(args['--pp']), int(args['--lp'])) 134 | -------------------------------------------------------------------------------- /src/legacy/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | from torch import autograd 5 | from torch.utils.tensorboard import SummaryWriter 6 | from torch.nn.utils.rnn import pack_padded_sequence 7 | 8 | 9 | def to_video(tensor): 10 | """ 11 | Takes generator output and converts it to 12 | the format which can later be written by 13 | SummaryWriter's instance 14 | 15 | input: tensor of shape (N, C, D, H, W) 16 | obtained passed through nn.Tanh 17 | """ 18 | generated = (tensor + 1) / 2 * 255 19 | generated = generated.cpu() \ 20 | .numpy() \ 21 | .transpose(0, 2, 1, 3, 4) 22 | 23 | return generated.astype('uint8') 24 | 25 | 26 | class Trainer: 27 | """ 28 | Args: 29 
| dis_dict dict object containing image and video 30 | discriminators under the keys 'image' 31 | and 'video' respectively 32 | opt_list list with the optimizers of all model 33 | components 34 | log_folder folder where to save the generator's 35 | weights after training is over 36 | train_enc whether to train encoder or not 37 | """ 38 | def __init__( 39 | self, encoder, dis_dict, generator, opt_list, 40 | video_loader, log_folder, train_enc, num_epochs=100000): 41 | 42 | self.encoder = encoder 43 | self.dis_dict = dis_dict 44 | self.generator = generator 45 | self.opt_list = opt_list 46 | self.train_enc = train_enc 47 | 48 | self.vloader = video_loader 49 | self.num_epochs = num_epochs 50 | self.log_folder = log_folder 51 | 52 | self.pairs = {} 53 | self.logs = {'image dis': 0, 54 | 'video dis': 0, 55 | 'generator': 0} 56 | 57 | def formTrainingPairs(self, videos, condition): 58 | sp = len(condition) // 2 59 | self.pairs['pos'] = (videos[:sp], condition[:sp]) 60 | roll_condition = torch.roll(condition[sp:], 1, 0) 61 | self.pairs['neg'] = (videos[sp:], roll_condition) 62 | 63 | art_condition = condition.detach()[:sp] 64 | fake_videos = self.generator(art_condition) 65 | fake_videos.register_hook(lambda grad: -grad) 66 | self.pairs['gen'] = (fake_videos, art_condition) 67 | 68 | def calculateBaseLossTerms(self, kind): 69 | pos_scores = self.dis_dict[kind](*self.pairs['pos']) 70 | neg_scores = self.dis_dict[kind](*self.pairs['neg']) 71 | gen_scores = self.dis_dict[kind](*self.pairs['gen']) 72 | 73 | L1 = torch.log(pos_scores).mean() 74 | L2 = torch.log1p(-neg_scores).mean() 75 | L3 = torch.log1p(-gen_scores).mean() 76 | 77 | self.logs[f'{kind} dis'] += (L1 + L2 + L3).item() 78 | self.logs['generator'] += L3.item() 79 | 80 | return -(L1 + L2 + L3) 81 | 82 | def passBatchThroughNetwork(self, labels, videos, senlen): 83 | with torch.set_grad_enabled(self.train_enc): 84 | condition = self.encoder(labels, senlen) 85 | self.formTrainingPairs(videos, condition) 86 | 87 | loss = videos.new(1).fill_(0) 88 | for opt in self.opt_list: 89 | opt.zero_grad() 90 | for kind in ['image', 'video']: 91 | loss += self.calculateBaseLossTerms(kind) 92 | loss.backward() 93 | for opt in self.opt_list: 94 | opt.step() 95 | 96 | def train(self, lens, texts, movies, pp=0, lp=20): 97 | """ 98 | lp: log period 99 | pp: print period (0 - no print to stdout) 100 | """ 101 | writer = SummaryWriter() 102 | writer.add_video("Real Clips", to_video(movies)) 103 | device = next(self.generator.parameters()).device 104 | 105 | time_per_epoch =- time.time() 106 | for epoch in range(self.num_epochs): 107 | for No, batch in enumerate(self.vloader): 108 | labels = batch['label'].to(device, non_blocking=True) 109 | videos = batch['video'].to(device, non_blocking=True) 110 | senlen = batch['slens'].to(device, non_blocking=True) 111 | self.passBatchThroughNetwork(labels, videos, senlen) 112 | 113 | if pp and epoch % pp == 0: 114 | time_per_epoch += time.time() 115 | print(f'Epoch {epoch}/{self.num_epochs}') 116 | for k, v in self.logs.items(): 117 | print("\t%s:\t%5.4f" % (k, v/(No+1))) 118 | self.logs[k] = v / (No+1) 119 | print('Completed in %.f s' % time_per_epoch) 120 | time_per_epoch =- time.time() 121 | 122 | if epoch % lp == 0: 123 | self.generator.eval() 124 | with torch.no_grad(): 125 | condition = self.encoder(texts, lens) 126 | movies = self.generator(condition) 127 | writer.add_scalars('Loss', self.logs, epoch) 128 | writer.add_video('Fakes', to_video(movies), epoch) 129 | self.generator.train() 130 | self.logs = 
dict.fromkeys(self.logs, 0) 131 | 132 | torch.save( 133 | self.generator.state_dict(), 134 | self.log_folder / ('gen_%05d.pytorch' % epoch)) 135 | print('Training has been completed successfully!') 136 | -------------------------------------------------------------------------------- /src/legacy/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Noise(nn.Module): 6 | def __init__(self, noise, sigma=.1): 7 | super().__init__() 8 | self.noise = noise 9 | self.sigma = sigma 10 | 11 | def forward(self, x): 12 | if self.noise: 13 | return x + self.sigma * torch.randn_like(x) 14 | return x 15 | 16 | 17 | def block1x1(Conv, BatchNorm, inC, outC, noise, sigma): 18 | """ 19 | Returns a layer which performs a 1x1 convolution 20 | (1x1x1 in case of 3d) with subsequent normalization 21 | and rectification 22 | """ 23 | block_list = [Noise(noise, sigma)] 24 | block_list += [Conv(inC, outC, 1, bias=False)] 25 | if BatchNorm is not None: 26 | block_list += [BatchNorm(outC)] 27 | 28 | return nn.Sequential(*block_list) 29 | 30 | 31 | def block3x3(Conv, BatchNorm, inC, outC, stride, noise, sigma): 32 | """ 33 | Returns a layer which performs a 3x3 convolution 34 | (3x3x3 in case of 3d) with subsequent normalization 35 | and rectification 36 | """ 37 | block_list = [Noise(noise, sigma)] 38 | block_list += [Conv(inC, outC, 3, stride, 1, bias=False)] 39 | if BatchNorm is not None: 40 | block_list += [BatchNorm(outC)] 41 | 42 | return nn.Sequential(*block_list) 43 | 44 | 45 | class BasicBlock(nn.Module): 46 | """ 47 | Args: 48 | type 2d or 3d 49 | stride convolution stride (int or tuple of ints) 50 | noise boolen flag: use Noise layer or do not 51 | sigma standard deviation of the gaussian noise 52 | used in Noise layer 53 | bn whether to add BatchNorm layer 54 | """ 55 | def __init__( 56 | self, type, in_channels, out_channels, 57 | stride=1, bn=True, noise=False, sigma=.2): 58 | super().__init__() 59 | 60 | if type == '3d': 61 | Conv = nn.Conv3d 62 | AvgPool = nn.AvgPool3d 63 | BatchNorm = nn.BatchNorm3d if bn else None 64 | elif type == '2d': 65 | Conv = nn.Conv2d 66 | AvgPool = nn.AvgPool2d 67 | BatchNorm = nn.BatchNorm2d if bn else None 68 | else: 69 | raise TypeError ( 70 | "__init__(): argument 'type' " 71 | "must be '2d' or '3d'" 72 | ) 73 | proj_list = [] 74 | self.proj = None 75 | assert (torch.tensor(stride) <= 2).all() 76 | if in_channels != out_channels: 77 | proj_list += [Conv(in_channels, out_channels, 1)] 78 | if (torch.tensor(stride) > 1).any(): 79 | proj_list += [AvgPool(stride)] 80 | self.proj = nn.Sequential(*proj_list) 81 | 82 | self.leaky = nn.LeakyReLU(.2, inplace=True) 83 | self.main = nn.Sequential ( 84 | block3x3( 85 | Conv, BatchNorm, in_channels, 86 | out_channels, stride, noise, sigma), 87 | self.leaky, 88 | block3x3( 89 | Conv, BatchNorm, out_channels, 90 | out_channels, 1, noise, sigma), 91 | ) 92 | 93 | def forward(self, x): 94 | y = self.main(x) 95 | if self.proj is not None: 96 | x = self.proj(x) 97 | 98 | return self.leaky(y + x) 99 | 100 | 101 | class Bottleneck(nn.Module): 102 | """ 103 | Args: 104 | type 2d or 3d 105 | width width of bottleneck 106 | stride convolution stride (int or tuple of ints) 107 | noise boolen flag: use Noise layer or do not 108 | sigma standard deviation of the gaussian noise 109 | used in Noise layer 110 | bn whether to add BatchNorm layer 111 | """ 112 | def __init__( 113 | self, type, in_channels, out_channels, stride=1, 114 | bn=True, width=None, 
noise=False, sigma=.2): 115 | super().__init__() 116 | 117 | if type == '3d': 118 | Conv = nn.Conv3d 119 | AvgPool = nn.AvgPool3d 120 | BatchNorm = nn.BatchNorm3d if bn else None 121 | elif type == '2d': 122 | Conv = nn.Conv2d 123 | AvgPool = nn.AvgPool2d 124 | BatchNorm = nn.BatchNorm2d if bn else None 125 | else: 126 | raise TypeError ( 127 | "__init__(): argument 'type' " 128 | "must be '2d' or '3d'" 129 | ) 130 | proj_list = [] 131 | self.proj = None 132 | assert (torch.tensor(stride) <= 2).all() 133 | if in_channels != out_channels: 134 | proj_list += [Conv(in_channels, out_channels, 1)] 135 | if (torch.tensor(stride) > 1).any(): 136 | proj_list += [AvgPool(stride)] 137 | self.proj = nn.Sequential(*proj_list) 138 | 139 | if not width: 140 | width = (in_channels + out_channels) // 4 141 | 142 | self.leaky = nn.LeakyReLU(.2, inplace=True) 143 | self.main = nn.Sequential ( 144 | block1x1(Conv, BatchNorm, in_channels, width, noise, sigma), 145 | self.leaky, 146 | block3x3(Conv, BatchNorm, width, width, stride, noise, sigma), 147 | self.leaky, 148 | block1x1(Conv, BatchNorm, width, out_channels, noise, sigma) 149 | ) 150 | 151 | def forward(self, x): 152 | y = self.main(x) 153 | if self.proj is not None: 154 | x = self.proj(x) 155 | 156 | return self.leaky(y + x) 157 | -------------------------------------------------------------------------------- /src/convgru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils import spectral_norm as SN 4 | 5 | class GeneralConvGRUCell(nn.Module): 6 | def __init__(self, conv_layers): 7 | super().__init__() 8 | self.ConvZ, self.ConvR, self.ConvH = conv_layers 9 | 10 | def forward(self, x, h_prev): 11 | XH = torch.cat([x, h_prev], dim=1) 12 | z = torch.sigmoid(self.ConvZ(XH)) 13 | r = torch.sigmoid(self.ConvR(XH)) 14 | XRH = torch.cat([x, r*h_prev], dim=1) 15 | h_new = torch.tanh(self.ConvH(XRH)) 16 | 17 | return (1 - z) * h_prev + z * h_new 18 | 19 | 20 | class ConvGRUCell(GeneralConvGRUCell): 21 | def __init__( 22 | self, in_channels, hidden_planes, 23 | kernel_size, spectral_norm): 24 | padding = kernel_size // 2 25 | ConvZ = nn.Conv2d( 26 | in_channels+hidden_planes, 27 | hidden_planes, kernel_size, 28 | padding=padding) 29 | ConvR = nn.Conv2d( 30 | in_channels+hidden_planes, 31 | hidden_planes, kernel_size, 32 | padding=padding) 33 | ConvH = nn.Conv2d( 34 | in_channels+hidden_planes, 35 | hidden_planes, kernel_size, 36 | padding=padding) 37 | 38 | if spectral_norm: 39 | ConvZ = SN(ConvZ) 40 | ConvR = SN(ConvR) 41 | ConvH = SN(ConvH) 42 | else: 43 | nn.init.orthogonal_(ConvZ.weight) 44 | nn.init.orthogonal_(ConvR.weight) 45 | nn.init.orthogonal_(ConvH.weight) 46 | nn.init.constant_(ConvZ.bias, 0.) 47 | nn.init.constant_(ConvR.bias, 0.) 48 | nn.init.constant_(ConvH.bias, 0.) 49 | super().__init__((ConvZ, ConvR, ConvH)) 50 | 51 | 52 | class ConvGRU(nn.Module): 53 | """ 54 | Args: 55 | seq_first if seq_first is True, the layer 56 | takes an input of the shape (N, T, C, H, W) 57 | returns a tensor of the shape (N*T, K, H, W) 58 | and the last hidden state of the size (N, K, H, W). 
59 | Otherwise, it expects the shape (N, C, T, H, W) 60 | and outputs the shapes (N, K, T, H, W) and (N, K, H, W) 61 | K - number of hidden planes 62 | C - number of the input channels 63 | T - number of frames in a video 64 | """ 65 | def __init__( 66 | self, in_channels, hidden_planes, 67 | kernel_size, seq_first=False, 68 | spectral_norm=False): 69 | super().__init__() 70 | self.C = hidden_planes 71 | self.seq_first = seq_first 72 | self.CGRUcell = ConvGRUCell( 73 | in_channels, hidden_planes, 74 | kernel_size, spectral_norm) 75 | 76 | def _defaultCall(self, X, h_init): 77 | H = [h_init] 78 | for i in range(X.size(2)): 79 | H.append(self.CGRUcell(X[:, :, i], H[-1])) 80 | return torch.stack(H[1:], dim=2), H[-1] 81 | 82 | def _seqFirstCall(self, X, h_init): 83 | H = [h_init] 84 | for i in range(X.size(1)): 85 | H.append(self.CGRUcell(X[:, i], H[-1])) 86 | return torch.flatten(torch.stack(H[1:], dim=1), 0, 1), H[-1] 87 | 88 | def forward(self, X, h_init=None): 89 | if h_init is None: 90 | h_init = X.new(X.size(0), self.C, *X.shape[3:]).fill_(0) 91 | if self.seq_first: 92 | return self._seqFirstCall(X, h_init) 93 | else: 94 | return self._defaultCall(X, h_init) 95 | 96 | 97 | class AdvancedConvGRU(nn.Module): 98 | """ 99 | Args: 100 | deepConv2d block of 2d convolutions made with `partial` 101 | for specifying only 'in' and 'out' channels 102 | seq_first if seq_first is True, the layer 103 | takes an input of the shape (N, T, C, H, W) 104 | returns a tensor of the shape (N*T, K, H, W) 105 | and the last hidden state of the size (N, K, H, W). 106 | Otherwise, it expects the shape (N, C, T, H, W) 107 | and outputs the shapes (N, K, T, H, W) and (N, K, H, W) 108 | K - number of hidden planes 109 | C - number of the input channels 110 | T - number of frames in a video 111 | """ 112 | def __init__( 113 | self, deepConv2d, in_channels, hidden_planes, seq_first=False): 114 | super().__init__() 115 | self._C = hidden_planes 116 | self.seq_first = seq_first 117 | in_channels += hidden_planes 118 | conv_layers = [deepConv2d(in_channels, hidden_planes)] 119 | conv_layers += [deepConv2d(in_channels, hidden_planes)] 120 | conv_layers += [deepConv2d(in_channels, hidden_planes)] 121 | self.CGRUcell = GeneralConvGRUCell(conv_layers) 122 | 123 | def _defaultCall(self, X, h_init): 124 | H = [h_init] 125 | for i in range(X.size(2)): 126 | H.append(self.CGRUcell(X[:, :, i], H[-1])) 127 | return torch.stack(H[1:], dim=2), H[-1] 128 | 129 | def _seqFirstCall(self, X, h_init): 130 | H = [h_init] 131 | for i in range(X.size(1)): 132 | H.append(self.CGRUcell(X[:, i], H[-1])) 133 | return torch.flatten(torch.stack(H[1:], dim=1), 0, 1), H[-1] 134 | 135 | def forward(self, X, h_init=None): 136 | if h_init is None: 137 | h_init = X.new(X.size(0), self._C, *X.shape[3:]).fill_(0) 138 | if self.seq_first: 139 | return self._seqFirstCall(X, h_init) 140 | else: 141 | return self._defaultCall(X, h_init) 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Video Generation Based on Short Text Description (2019) 2 | 3 | Over a year latter (in 2020), I have decided to add README to the repository, since 4 | some people find it useful even without a description. I hope this 5 | step will make the results of my work more usable for those who are interested 6 | in the problem and stumble upon the repository when browsing the topic on GitHub. 
7 | 8 | ## Example of Generated Video 9 | 10 | Unfortunately, I have not saved the videos generated by the network, 11 | since all the results remained on the working laptop which I handed over 12 | at the end of the internship. The only thing left is the recording 13 | that I made on my cellphone (***Sorry if this makes your eyes bleed***). 14 | 15 | What is on the gif? There are 5 blocks of images stacked horizontally. 16 | Each block contains 4 objects, selected from the '20bn-something-something-v2' 17 | dataset and belonging to the same category _"Pushing [something] from left to right"_ 18 | **1** (~1000 samples). They are _book_ (top left window), _box_ (top right window), 19 | _mug_ (bottom left window), and _marker_ (bottom right window), pushed along the surface by hand. 20 | The number of occurrences in the data subset for the corresponding objects is 57, 43, 9, and 55. 21 | 22 | The generated videos are diverse (thanks to the [zero-gradient penalty](https://arxiv.org/abs/1902.03984)) 23 | and of about the same quality as the videos from the training data. However, **no tests were 24 | conducted on the validation data.** Regarding the gif below, since all the objects 25 | belong to the same category, only single-word conditioning (i.e., on the object) is used. 26 | Still, the repository provides tools for encoding the whole sentence. 27 | 28 |

29 | ![Example of generated videos](example.gif)

31 | 32 | --- 33 | **1** Yep, exactly "from left to right" and not the other way 34 | around that you can read on the gif (the latter is a typo). However, for validation 35 | purposes it is useful to make new labels with the reversed direction of 36 | movement, or with new (but "similar", e.g., in the space of embeddings) objects 37 | from the unchanged category. 38 | 39 | ## Navigating Through SRC Files 40 | 41 |
42 | 43 | 44 | 45 | 46 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 65 | 66 | 67 | 68 | 69 | 70 | 71 |
| File | Description |
| ---- | ----------- |
| `data_prep2.py` | Video and text processing (based on `text_processing.py`) |
| `blocks.py` | Building blocks used in `models.py` |
| `visual_encoders.py` | Advanced building blocks for image and video discriminators |
| `process3-5.ipynb` | Pipeline for the training process on multiple GPUs (3-5 is the hardcoded range of GPUs involved) |
| `pipeline.ipynb` | Previously served the same purpose as `process3-5.ipynb`, but the hardcoded range was 0-2. Now it is an unfinished implementation of mixed batches |
| `legacy` (= obsolete) | Early attempts and ideas |
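To see how these pieces fit together end to end, below is a minimal sketch of the training setup. It mirrors the calls made in `src/legacy/main.py` (the most compact entry point in the repo), so the paths are placeholders and the `models` classes are assumed to keep the interface they have there.

```python
import pickle
import torch

import models                                  # VideoGenerator, TextEncoder, discriminators
from data_prep import LabeledVideoDataset      # legacy loader; data_prep2.py is the newer one
from text_processing import getGloveEmbeddings, selectTemplates

cache = '../logdir'                            # placeholder cache folder
labels = selectTemplates(                      # keep a single label category, as in main.py
    '../data/labels.json',                     # placeholder path to the dataset labels
    ['Pushing [something] from left to right'],
    '1-template.pkl')

# builds (and caches) the video/text database; video_shape is (D, H, W, C)
dataset = LabeledVideoDataset(labels, cache, (32, 64, 64, 3), check_spell=True)

with open(f'{cache}/vocab.pkl', 'rb') as fp:   # vocabulary written during dataset preparation
    t2i = pickle.load(fp)

emb_weights = torch.tensor(getGloveEmbeddings('../embeddings', cache, t2i))
text_encoder = models.TextEncoder(emb_weights, proj=True)  # the 'mere' encoder option
generator = models.VideoGenerator(50, 64)                  # (dim_Z, emb_size), as in main.py

# Training alternates the image/video discriminators and the generator (see legacy/trainer.py).
# At sampling time, a tokenized caption becomes the condition for the generator:
#   condition = text_encoder(text_batch, sentence_lengths)
#   videos = generator(condition)   # (N, C, D, H, W) in [-1, 1], cf. to_video in utils.py
```

Note that the newer `data_prep2.LabeledVideoDataset` takes `path2labels` and `path2videos` separately and returns 'major'/'minor' sample groups, so the notebooks wire things up slightly differently.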
72 | 73 |
74 | 75 | There is also a collection of [references](https://github.com/lukoshkin/text2video/blob/develop/references.md) 76 | to articles relevant (at the time of 2019) to the text2video generation problem. 77 | 78 | ## Update of 2021 79 | 80 | This year, I have decided to make the results reproducible. It has turned out 81 | that if one has not dealt with this repo before, it is a tough call for them 82 | to get it up and running again. Quite surprising, huh? Especially, considering 83 | I uploaded everything hastily and as it was. Now, at least you can follow 84 | the instructions below. This is how you reach the state of what I left in 2019. 85 | To move further, a great deal of effort is required. Good luck! 86 | 87 | ## Setting Everything Up 88 | 89 | 1. Clone the repository 90 | ``` 91 | git clone https://github.com/lukoshkin/text2video.git 92 | ``` 93 | 94 | 2. Retrieve the docker image 95 | ``` 96 | cd docker 97 | docker build -t lukoshkin/text2video:base . 98 | ``` 99 | or 100 | ``` 101 | docker pull lukoshkin/text2video:base 102 | ``` 103 | 104 | If using singularity, one can obtain the image by typing 105 | ``` 106 | singularity build t2v-base.simg docker://lukoshkin/text2video:base 107 | ``` 108 | 109 | 3. Get an access to GPU. For cluster-folks, it may look like: 110 | ``` 111 | salloc -p gpu_a100 -N 1 -n 4 --gpus=1 --mem-per-gpu=20G --time=12:00:00 112 | ## prints the name of allocated node, e.g., gn26 113 | ssh gn26 114 | ``` 115 | 116 | 4. Cd to the directory where everything is located and create a container 117 | (9999 is a port exposed for Jupyter outside the container. That is, if running 118 | directly on your computer, `localhost:9999` is your access point in a browser. 119 | You may need one more port for TensorBoard as well; note, 8888 is the default 120 | one Jupyter tries first, if the port is busy, you should specify it manually 121 | in the `jupyter`-command with `--port` option ) 122 | ``` 123 | nvidia-docker run --name t2v \ 124 | -p 9999:8888 -v "$PWD":/home/depp/project 125 | -d lukoshkin/text2video:base \ 126 | 'jupyter-notebook --ip=0.0.0.0 --no-browser' 127 | ``` 128 | 129 | Singularity users are like: 130 | ``` 131 | singularity exec \ 132 | --no-home -B "$PWD:$HOME" --nv t2v-base.simg \ 133 | jupyter notebook --ip 0.0.0.0 --no-browser 134 | ``` 135 | 136 | If accustomed to work in JupyterLab, please, use it readily 137 | by rewriting the commands to the proper form first. 138 | 139 | For running everything on a HPC cluster, one should forward the ports. 140 | You type one of the following. Which one? - depends on whether you 141 | ssh to calculation nodes (gn26) on your server and whether you 142 | set up a `nickname` for the latter in `.ssh/config`. 
143 | ``` 144 | ssh -NL 9999:gn26:8888 nickname 145 | ssh -NL 9999:localhost:8888 user@server 146 | ``` 147 | -------------------------------------------------------------------------------- /src/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils import spectral_norm 4 | 5 | class DBlock(nn.Module): 6 | """ 7 | Discriminator's block 8 | """ 9 | def __init__( 10 | self, type, 11 | in_channels, out_channels, stride, 12 | bn=False, sn=True, mbs=None): 13 | super().__init__() 14 | if type == '2d': 15 | Conv = nn.Conv2d 16 | AvgPool = nn.AvgPool2d 17 | BatchNorm = nn.BatchNorm2d 18 | elif type == '3d': 19 | Conv = nn.Conv3d 20 | AvgPool = nn.AvgPool3d 21 | BatchNorm = nn.BatchNorm3d 22 | else: 23 | raise TypeError( 24 | "__init__(): argument 'type' " 25 | "must be '2d' or '3d'") 26 | 27 | SN = spectral_norm if sn else lambda x: x 28 | leaky = nn.LeakyReLU(.2, inplace=True) 29 | if mbs is not None: 30 | conv1 = PermInvariantLayer( 31 | Conv, in_channels, out_channels, 32 | mbs, sn, kernel_size=3, padding=1) 33 | conv2 = PermInvariantLayer( 34 | Conv, out_channels, out_channels, 35 | mbs, sn, kernel_size=3, padding=1) 36 | else: 37 | conv1 = SN(Conv(in_channels, out_channels, 3, 1, 1)) 38 | conv2 = SN(Conv(out_channels, out_channels, 3, 1, 1)) 39 | 40 | main_list = [leaky, conv1, leaky, conv2] 41 | if bn: 42 | main_list.insert(0, BatchNorm(in_channels)) 43 | main_list.insert(3, BatchNorm(out_channels)) 44 | 45 | proj_list = [] 46 | self.proj = None 47 | assert (torch.tensor(stride) <= 2).all() 48 | if in_channels != out_channels: 49 | proj_list += [SN(Conv(in_channels, out_channels, 1))] 50 | if (torch.tensor(stride) > 1).any(): 51 | main_list += [AvgPool(stride)] 52 | proj_list += [AvgPool(stride)] 53 | self.proj = nn.Sequential(*proj_list) 54 | self.main = nn.Sequential(*main_list) 55 | 56 | def forward(self, x): 57 | return self.main(x) + self.proj(x) 58 | 59 | 60 | #global variable for the rest of the classes 61 | SN = spectral_norm 62 | 63 | 64 | class GBlock(nn.Module): 65 | """ 66 | Generator's block 67 | """ 68 | def __init__(self, type, in_channels, out_channels, stride): 69 | super().__init__() 70 | if type == '2d': 71 | Conv = nn.Conv2d 72 | AvgPool = nn.AvgPool2d 73 | BatchNorm = nn.BatchNorm2d 74 | elif type == '3d': 75 | Conv = nn.Conv3d 76 | AvgPool = nn.AvgPool3d 77 | BatchNorm = nn.BatchNorm3d 78 | else: 79 | raise TypeError( 80 | "__init__(): argument 'type' " 81 | "must be '2d' or '3d'") 82 | main_list = [ 83 | BatchNorm(in_channels), 84 | nn.ReLU(inplace=True), 85 | SN(Conv(in_channels, out_channels, 3, 1, 1)), 86 | BatchNorm(out_channels), 87 | nn.ReLU(inplace=True), 88 | SN(Conv(out_channels, out_channels, 3, 1, 1)), 89 | ] 90 | proj_list = [] 91 | self.proj = None 92 | assert (torch.tensor(stride) <= 2).all() 93 | if in_channels != out_channels: 94 | proj_list += [SN(Conv(in_channels, out_channels, 1))] 95 | if (torch.tensor(stride) > 1).any(): 96 | main_list.insert(2, nn.Upsample(scale_factor=stride)) 97 | proj_list.insert(0, nn.Upsample(scale_factor=stride)) 98 | self.proj = nn.Sequential(*proj_list) 99 | self.main = nn.Sequential(*main_list) 100 | 101 | def forward(self, x): 102 | return self.main(x) + self.proj(x) 103 | 104 | 105 | class CGBlock(nn.Module): 106 | """ 107 | Conditional generator's block 108 | """ 109 | def __init__(self, type, cond_size, in_channels, out_channels, stride): 110 | super().__init__() 111 | if type == '2d': 112 | Conv = nn.Conv2d 113 | 
AvgPool = nn.AvgPool2d 114 | elif type == '3d': 115 | Conv = nn.Conv3d 116 | AvgPool = nn.AvgPool3d 117 | else: 118 | raise TypeError ( 119 | "__init__(): argument 'type' " 120 | "must be '2d' or '3d'" 121 | ) 122 | self.bn1 = CBN(type, cond_size, in_channels) 123 | self.bn2 = CBN(type, cond_size, out_channels) 124 | self.conv1 = SN(Conv(in_channels, out_channels, 3, 1, 1)) 125 | self.conv2 = SN(Conv(out_channels, out_channels, 3, 1, 1)) 126 | self.upsample = nn.Upsample(scale_factor=stride) 127 | self.relu = nn.ReLU(inplace=True) 128 | 129 | proj_list = [] 130 | self.proj = None 131 | assert (torch.tensor(stride) <= 2).all() 132 | if in_channels != out_channels: 133 | proj_list += [SN(Conv(in_channels, out_channels, 1))] 134 | if (torch.tensor(stride) > 1).any(): 135 | proj_list.insert(0, nn.Upsample(scale_factor=stride)) 136 | self.proj = nn.Sequential(*proj_list) 137 | 138 | def forward(self, x, y): 139 | h = self.relu(self.bn1(x, y)) 140 | h = self.upsample(h) 141 | h = self.conv1(h) 142 | h = self.relu(self.bn2(h, y)) 143 | h = self.conv2(h) 144 | 145 | return h + self.proj(x) 146 | 147 | 148 | class CBN(nn.Module): 149 | """ 150 | Conditional BatchNorm 151 | """ 152 | def __init__( 153 | self, type, cond_size, out_channels, eps=1e-5, momentum=0.1): 154 | super().__init__() 155 | self.gain = SN(nn.Linear(cond_size, out_channels, bias=False)) 156 | self.bias = SN(nn.Linear(cond_size, out_channels, bias=False)) 157 | if type == '2d': 158 | self.type = 2 159 | self.bn = nn.BatchNorm2d( 160 | out_channels, eps, momentum, affine=False) 161 | else: 162 | self.type = 3 163 | self.bn = nn.BatchNorm3d( 164 | out_channels, eps, momentum, affine=False) 165 | 166 | def forward(self, x, y): 167 | gain = self.gain(y).view(y.size(0), -1, *(self.type*[1])) 168 | bias = self.bias(y).view(y.size(0), -1, *(self.type*[1])) 169 | return gain * self.bn(x) + bias 170 | 171 | 172 | class PermInvariantLayer(nn.Module): 173 | """ 174 | The implementation of https://arxiv.org/pdf/1806.07185.pdf 175 | Args: 176 | ----- 177 | layer `nn.Linear` or `nn.ConvNd` 178 | kwargs all other keyword arguments relevant to the specified 179 | layer (except `bias` - this will raise a `TypeError`) 180 | minibatch_size size of mini-batches that result from splitting along 181 | the batch dimension. The actual batch size must be 182 | divisible by this number. 
183 | """ 184 | def __init__( 185 | self, layer, in_channels, 186 | out_channels, minibatch_size, sn=True, **kwargs): 187 | super().__init__() 188 | if 'bias' in kwargs: 189 | raise TypeError("`bias` argument is redundant here") 190 | self.mbs = minibatch_size 191 | self.batch_mixing = layer( 192 | in_channels, out_channels, bias=False, **kwargs) 193 | self.non_mixing = layer(in_channels, out_channels, **kwargs) 194 | with torch.no_grad(): 195 | self.batch_mixing.weight.mul_(1./(self.mbs+1)) 196 | self.non_mixing.weight.mul_(self.mbs/(self.mbs+1)) 197 | 198 | if sn: 199 | self.batch_mixing = SN(self.batch_mixing) 200 | self.non_mixing = SN(self.non_mixing) 201 | 202 | def forward(self, x): 203 | out = self.non_mixing(x) 204 | out = out.view(-1, self.mbs, *out.shape[1:]) 205 | x = x.view(-1, self.mbs, *x.shape[1:]).mean(1) 206 | out += self.batch_mixing(x).unsqueeze(1) 207 | 208 | # x.shape: (N, the rest) 209 | # out.shape: (N/self.mbs, self.mbs, the rest) 210 | # x.shape: (N/self.mbs, the rest) 211 | # the last operation broadcasting: 212 | # (N/self.mbs, self.mbs, the rest) 213 | # + (N/self.mbs, 1, the rest) 214 | 215 | return torch.flatten(out,0,1) 216 | -------------------------------------------------------------------------------- /src/data_prep2.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import pickle 3 | import numpy as np 4 | 5 | from tqdm import tqdm 6 | from pathlib import Path 7 | from smart_open import open 8 | from torch.utils.data import Dataset 9 | from text_processing2 import TextProcessor 10 | 11 | 12 | class LabeledVideoDataset(TextProcessor, Dataset): 13 | """ 14 | Args: 15 | ----- 16 | path2lables - first required positional argument. 17 | path2videos - second required positional argument. 18 | cache - folder to preserve intermediate and final results. 19 | video_shape=(D, H, W, C), where D is the number of frames. 20 | step - if a video has multiple of D frames, then this var 21 | is applied to reduce their number. 22 | mode - the way of screening samples based a sentence object. 23 | If 'toy', only sentences with single word objects are selected. 24 | If 'simple' - with single-word and hyphenated compound word objects. 25 | If 'casual' - any, including those presented by multiple word objects. 26 | check_spell - whether or not to correct words spelling. 27 | min_word_freq - maximum word frequency of rare words. Sentences 28 | containing them will be discarded from the dataset. 29 | glove_folder - path to GloVe embeddings (default: ../embeddings). 30 | emb_size - dimensionality of GloVE embeddings. 31 | glove_filtration - remove samples (sentences) containing 32 | 'out-of-GloVe-dictionary' words. 33 | transform - function which should be applied on a video. 
34 | """ 35 | def __init__( 36 | self, path2labels, path2videos, cache, 37 | video_shape=(32, 32, 32, 3), step=2, 38 | mode='toy', check_spell=True, min_word_freq=2, 39 | glove_folder='../embeddings', emb_size=50, 40 | glove_filtration=True, transform=None): 41 | 42 | super().__init__( 43 | path2labels, cache, mode, 44 | check_spell, min_word_freq, 45 | glove_folder, emb_size, glove_filtration) 46 | 47 | self.transform = transform if transform else lambda x: x 48 | self._save_db_as = Path(f'{self._path.stem}.db') 49 | self._video_shape = video_shape 50 | self._step = step 51 | 52 | if (cache/self._save_db_as).exists(): 53 | with open(cache/self._save_db_as, 'rb') as fp: 54 | self.data = pickle.load(fp) 55 | self._i2i = pickle.load(fp) 56 | else: 57 | self._i2i = {} 58 | self.data = {'major': [], 'minor': []} 59 | 60 | self._prepareDatabase(Path(path2videos)) 61 | print('Caching database to', self._save_db_as) 62 | with open(cache/self._save_db_as, 'wb') as fp: 63 | pickle.dump(self.data, fp, pickle.HIGHEST_PROTOCOL) 64 | pickle.dump(self._i2i, fp) 65 | print('Done!') 66 | 67 | def __len__(self): 68 | return len(self.data['minor']) 69 | 70 | def __getitem__(self, index): 71 | """ 72 | 'minor' is a set of extracted and encoded 73 | object, action performed on the object, 74 | and the number of tokens in the sentence 75 | presenting this action. 76 | 77 | 'major' is a set of video, its description, 78 | and the number of words in the latter. 79 | """ 80 | lbl, lbl_len, video = self.data['major'][index] 81 | obj_vec, act_vec, act_len = self.data['minor'][index] 82 | 83 | return {'major': 84 | {'label': lbl, 85 | 'lbllen': lbl_len, 86 | 'video': self.transform(video)}, 87 | 'minor': 88 | {'object': obj_vec, 89 | 'action': act_vec, 90 | 'actlen': act_len}} 91 | 92 | def getById(self, video_ids): 93 | """ 94 | Extracts the samples' data by their video ids 95 | """ 96 | ids = list(map(self._i2i.get, video_ids)) 97 | selected_major = np.take(self.data['major'], ids) 98 | selected_minor = np.take(self.data['minor'], ids) 99 | return np.rec.array(selected_major), np.rec.array(selected_minor) 100 | 101 | def sen2vec(self, sen, mode): 102 | """ 103 | Converts a sentence to a sequence of positive integers 104 | according to 't2i' dictionary. Depending on the mode, 105 | the result may be padded to the required length. 
106 | 107 | Output type: int64 - necessary for nn.Embedding 108 | """ 109 | if mode == 'toy': 110 | return self.t2i[sen[0]] 111 | 112 | filling = [self.t2i[w] for w in sen] 113 | if mode == 'simple': 114 | return np.array(filling) 115 | 116 | max_len = self._max_len if mode == 'action' else self._act_max_len 117 | numerated = np.zeros(max_len, 'int') 118 | numerated[:len(filling)] = filling 119 | return numerated 120 | 121 | def _prepareDatabase(self, path2videos): 122 | """ 123 | Fetches videos and corresponding text representations to `self.data` 124 | Prepares `self._i2i` for the later use by the `self.getById` method 125 | """ 126 | D, H, W, C = self._video_shape 127 | new_index, corrupted, mult = 0, 0, [] 128 | pbar = tqdm( 129 | self.df.iterrows(), "Preparing dataset", 130 | len(self.df), bar_format=self._tqdmBF) 131 | 132 | for old_index, sample in pbar: 133 | video = path2videos/f"{sample.id}.webm" 134 | ViCap = cv2.VideoCapture(str(video)) 135 | _D = ViCap.get(cv2.CAP_PROP_FRAME_COUNT) 136 | if int(_D) < D: 137 | self.df.drop(old_index, inplace=True) 138 | continue 139 | 140 | mult.append(int(_D) // D) 141 | frames, CNT = self._extractFrames(ViCap, D*mult[-1]) 142 | if CNT == D * mult[-1]: 143 | frames = np.array(frames, 'f4').transpose(3,0,1,2) 144 | frames = frames[:, ::self._step*mult[-1]] / 255 145 | self._processSample(frames, sample) 146 | self._i2i[sample.id] = new_index 147 | new_index += 1 148 | else: 149 | corrupted += 1 150 | mult.pop() 151 | self.df.drop(old_index, inplace=True) 152 | self.df.index = np.arange(len(self.df)) 153 | print('No of corrupted videos:', corrupted) 154 | self.data['major'] = np.array( 155 | self.data['major'], 156 | [('', 'i8', self._max_len), ('', 'i8'), 157 | ('', 'f4', (C, D//self._step, H, W))]) 158 | self.data['minor'] = np.array( 159 | self.data['minor'], 160 | [('', 'O'), ('', 'i8', self._act_max_len), ('', 'i8')]) 161 | 162 | def _extractFrames(self, ViCap, length): 163 | """ 164 | Retrieves cropped to the required shape frames from the stream 165 | `ViCap`. The maximum number of extracted frames is limited 166 | to the `length` value. In addition to the list of frames, 167 | it returns the length of the resulting video - `CNT` 168 | """ 169 | CNT = 0 170 | frames = [] 171 | success = True 172 | while success and (CNT < length): 173 | success, image = ViCap.read() 174 | if success: 175 | image = cv2.resize( 176 | image, self._video_shape[1:-1], 177 | interpolation=cv2.INTER_AREA) 178 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 179 | frames += [image] 180 | CNT += 1 181 | 182 | ViCap.release() 183 | cv2.destroyAllWindows() 184 | return frames, CNT 185 | 186 | def _processSample(self, frames, sample): 187 | """ 188 | Obtains objects, action and the entire sentence (all vectorized) 189 | from the corresponding to frames row in `self.df`. 
Adds the earlier 190 | extracted attributes to `self.data` 191 | """ 192 | lbl, obj, act = sample[1:4] 193 | obj_len, act_len = map(len, [obj, act]) 194 | lbl_len = act_len + obj_len - 1 195 | 196 | obj_vec = self.sen2vec(obj, self._mode) 197 | act_vec = self.sen2vec(act, 'action') 198 | lbl_vec = self.sen2vec(lbl, 'label') 199 | 200 | self.data['minor'].append((obj_vec, act_vec, act_len)) 201 | self.data['major'].append((lbl_vec, lbl_len, frames)) 202 | -------------------------------------------------------------------------------- /src/text_processing2.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pickle 3 | 4 | import copy 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from tqdm import tqdm 9 | from pathlib import Path 10 | from smart_open import open 11 | from collections import Counter 12 | from nltk import wordpunct_tokenize 13 | from spellchecker import SpellChecker 14 | 15 | 16 | def selectTemplates(path, templates, new_name): 17 | """ 18 | Remove from the database (json file specified by `path` argument) 19 | label categories which are not in `templates` list. The new file 20 | is created under the same folder as the old one. 21 | """ 22 | path = Path(path) 23 | df = pd.read_json(path) 24 | 25 | mask = df.template.isin(templates) 26 | new_df = df[mask] 27 | 28 | new_df.index = np.arange(mask.sum()) 29 | new_path = path.parent / new_name 30 | new_df.to_pickle(new_path) 31 | 32 | return new_path 33 | 34 | 35 | class TextProcessor: 36 | def __init__( 37 | self, path, cache, mode='toy', 38 | check_spell=True, min_freq=2, 39 | glove_folder='../embeddings', 40 | emb_size=50, glove_filtration=True): 41 | self._mode = mode 42 | self._path = Path(path) 43 | self._cache = Path(cache) 44 | 45 | self._emb_size = emb_size 46 | self._min_freq = min_freq 47 | self._check_spell = check_spell 48 | self._save_df_as = f'{self._path.stem}.pkl' 49 | self._save_t2i_as = f'{mode}_vocab.pkl' 50 | self._tqdmBF = '{l_bar}{bar:10}{r_bar}{bar:-10b}' 51 | self._GF = glove_filtration 52 | 53 | if mode not in ['toy', 'simple', 'casual']: 54 | raise TypeError("mode must be 'toy', 'simple' or 'casual'") 55 | if mode == 'casual': 56 | raise NotImplemented('This feature is not added') 57 | 58 | self._prepareGloveDict(glove_folder, emb_size) 59 | self._doTextPart() 60 | 61 | def getGloveEmbeddings(self, monitor_unk=False): 62 | """ 63 | Prepares word embedding matrix of width `emb_size` 64 | with `self.t2i` vocab. from the corresponding file in the `folder`. 65 | Caches the result at `self.cache` 66 | """ 67 | if (self._cache/'emb_matrix.npy').exists(): 68 | return np.load(self._cache/'emb_matrix.npy') 69 | 70 | not_found = 0 71 | emb_matrix = np.empty((len(self.t2i), self._emb_size), 'float32') 72 | for t, i in self.t2i.items(): 73 | try: 74 | emb_matrix[i] = self._glove[t] 75 | except KeyError: 76 | if monitor_unk: print('UNK:', t) 77 | emb_matrix[i] = .6 * np.random.randn(self._emb_size) 78 | not_found += 1 79 | emb_matrix[0] = 0 80 | 81 | if not_found: 82 | print('No of missed tokens in glove dict:', not_found) 83 | np.save(self._cache/'emb_matrix', emb_matrix) 84 | return emb_matrix 85 | 86 | def _prepareGloveDict(self, folder, emb_size): 87 | """ 88 | Constructs GloVe dictionary. 
Cashes the result 89 | """ 90 | save_as = self._cache/f'glove.{emb_size}d.pkl' 91 | if save_as.exists(): 92 | with open(save_as, 'rb') as fp: 93 | self._glove = pickle.load(fp) 94 | return 95 | 96 | folder = Path(folder) 97 | with open(folder/f'glove.6B.{emb_size}d.txt') as fp: 98 | raw_data = fp.readlines() 99 | 100 | pbar = tqdm( 101 | raw_data, 'Reading glove embeddings', 102 | bar_format=self._tqdmBF) 103 | 104 | self._glove = {} 105 | for line in pbar: 106 | t, *v = line.split() 107 | self._glove[t] = np.array(v, 'float32') 108 | 109 | with open(save_as, 'wb') as fp: 110 | pickle.dump(self._glove, fp) 111 | 112 | def _doTextPart(self): 113 | """ 114 | Does all the work related to text part: 115 | filtration by word frequency, spell correction, 116 | data frame columns update. Also, caches the results 117 | """ 118 | if ((self._cache/self._save_t2i_as).exists() and 119 | (self._cache/self._save_df_as).exists()): 120 | self.df = pd.read_pickle(self._cache/self._save_df_as) 121 | with open(self._cache/self._save_t2i_as, 'rb') as fp: 122 | self.t2i = pickle.load(fp) 123 | self._max_len, self._act_max_len = pickle.load(fp) 124 | return 125 | 126 | ext = self._path.suffix 127 | if ext == '.json': 128 | self.df = pd.read_json(self._path) 129 | elif ext == '.pkl': 130 | self.df = pd.read_pickle(self._path) 131 | else: 132 | raise TypeError("'json' and 'pkl' are only supported") 133 | 134 | tokens = ['PAD'] 135 | self._extractObjects(self._mode, tokens) 136 | self._extractActions(tokens) 137 | self.df.label = self.df.apply( 138 | lambda x: re.sub('something', 139 | ' '.join(x.placeholders), 140 | ' '.join(x.template)), axis=1) 141 | self.df.label = self.df.label.map(str.split) 142 | self._act_max_len = max(map(len, self.df.template)) 143 | if self._mode == 'toy': self._max_len = self._act_max_len 144 | else: self._max_len = max(map(len, self.df.label)) 145 | 146 | self.t2i = {t: i for i, t in enumerate(tokens)} 147 | self._cache.mkdir(parents=True, exist_ok=True) 148 | self.df.to_pickle(self._cache/self._save_df_as) 149 | with open(self._cache/self._save_t2i_as, 'wb') as fp: 150 | pickle.dump(self.t2i, fp) 151 | pickle.dump((self._max_len, self._act_max_len), fp) 152 | 153 | def _extractActions(self, tokens): 154 | """ 155 | Updates the `self.df.template` series and collects tokens 156 | encountered in the column into `tokens` 157 | """ 158 | self.df.template = self.df.template.map( 159 | #lambda x: re.sub('\[|\]', '', x)) 160 | lambda x: re.sub('\[.*?\]', 'something', x)) 161 | # there are strings where there are more than just 162 | # word 'something'. In this case, the words characterize 163 | # the object better, however, it may not help to 164 | # train the network 165 | self.df.template = self.df.template.map(str.lower) 166 | sentences = self.df.template.unique() 167 | if self._GF: 168 | bad_sens, bad_ids = [], [] 169 | for i, sen in enumerate(sentences): 170 | for w in sen.split(): 171 | if w not in self._glove: 172 | bad_sens.append(sen) 173 | bad_ids.append(i) 174 | break 175 | mask = self.df.template.isin(bad_sens) 176 | self.df = self.df[~mask] 177 | sentences = np.delete(sentences, bad_ids) 178 | self.df.index = np.arange(len(self.df)) 179 | self.df.template = self.df.template.map(wordpunct_tokenize) 180 | token_counts = Counter( 181 | np.concatenate(list(map(str.split, sentences)))) 182 | tokens += list(token_counts.keys()) 183 | 184 | def _extractObjects(self, mode, tokens): 185 | """ 186 | Updates the `self.df.placeholders` series according to the processing 187 | `mode` policy. 
Collects tokens encountered in the column into `tokens` 188 | """ 189 | sentences = self.df.placeholders 190 | if mode in ['simple', 'toy']: 191 | mask = (sentences.map(len) == 1) 192 | sentences = sentences[mask] 193 | self.df = self.df[mask] 194 | if mode == 'toy': 195 | mask = (sentences.map( 196 | lambda x: len(wordpunct_tokenize(x[0])) == 1)) 197 | sentences = sentences[mask] 198 | self.df = self.df[mask] 199 | sentences = sentences.map(lambda x: x[0]) 200 | mask = sentences.map(lambda x: 'something' not in x) 201 | sentences = sentences[mask] 202 | self.df = self.df[mask] 203 | self.df.index = np.arange(mask.sum()) 204 | 205 | sentences = list(sentences.map(wordpunct_tokenize).values) 206 | # << converting to list to be able to delete elements 207 | token_counts = self._findTypos(sentences) 208 | self._updateObjects(sentences) 209 | 210 | for w in self._vague_words: 211 | token_counts.pop(w) 212 | for w in token_counts: 213 | if w in self._glove: 214 | tokens.append(w) 215 | 216 | def _findTypos(self, sentences): 217 | """ 218 | Colects suspicious (rare) words in `self._vague_words`. 219 | `self._min_freq` defines the rarity extent. 220 | If `self._check_spell` == True, then the words found 221 | will be corrected with pyspellchecker 222 | 223 | Returns counts of the words in the sentences 224 | """ 225 | def freqFilter(C): 226 | return [] if self._min_freq < 1 else [ 227 | x for x in C.keys() if C[x] <= self._min_freq] 228 | 229 | token_counts = Counter(np.concatenate(sentences)) 230 | self._vague_words = freqFilter(token_counts) 231 | 232 | if not (self._check_spell and self._vague_words): 233 | return token_counts 234 | 235 | spell = SpellChecker() 236 | checked_words = [] 237 | pbar = tqdm( 238 | sentences, 239 | 'Spell-check', 240 | bar_format=self._tqdmBF) 241 | 242 | for sen in pbar: 243 | for i, w in enumerate(sen): 244 | if w in self._vague_words: 245 | sen[i] = spell.correction(w) 246 | checked_words.append(sen[i]) 247 | token_counts.update(checked_words) 248 | self._vague_words = freqFilter(token_counts) 249 | return token_counts 250 | 251 | def _updateObjects(self, sentences): 252 | """ 253 | Subtitutes 'mispelled' words (those that are in `self._vague_words`) 254 | in the `self.df.placeholders` with the `sentences` given 255 | """ 256 | def to_remove(word): 257 | vague = word in self._vague_words 258 | if self._GF: 259 | not_found = word not in self._glove 260 | return vague or not_found 261 | return vague 262 | 263 | mended = copy.deepcopy(sentences) 264 | if self._vague_words or self._GF: 265 | pbar = tqdm( 266 | sentences, 267 | "Removing 'bad samples'", 268 | bar_format=self._tqdmBF) 269 | 270 | k = 0 271 | for i, sen in enumerate(pbar): 272 | flag = True 273 | for w in sen: 274 | if to_remove(w): 275 | self.df.drop(i, inplace=True) 276 | del mended[k] 277 | flag = False 278 | break 279 | if flag: 280 | k += 1 281 | 282 | self.df.placeholders = mended 283 | self.df.index = np.arange(len(self.df)) 284 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data 4 | 5 | from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence 6 | from torch.nn.utils import spectral_norm as SN 7 | from functools import partial 8 | from blocks import DBlock, GBlock, CGBlock, PermInvariantLayer 9 | from convgru import ConvGRU, AdvancedConvGRU 10 | 11 | 12 | class SimpleTextEncoder: 13 | """ 
14 | Embedding of a sentence is obtained as the average 15 | of pretrained GloVe embeddings of the words 16 | that make up the sentence 17 | """ 18 | def __init__(self, emb_weights): 19 | self.embed = nn.Embedding.from_pretrained( 20 | emb_weights, padding_idx=0) 21 | 22 | def __call__(self, text_ids, lengths): 23 | if lengths == 1: 24 | return self.embed(text_ids) 25 | return self.embed(text_ids).sum(1) / lengths[:, None].float() 26 | 27 | 28 | class TextEncoder(nn.Module): 29 | """ 30 | Args: 31 | emb_weights matrix of size (n_tokens, emb_dim) 32 | proj project to lower dimension space 33 | train_embs whether to train embeddings or not 34 | (by default, the value is False, 35 | i.e. emb_weights are frozen) 36 | 37 | hyppar[0] number of gru hidden units 38 | hyppar[1] projection dimensionality 39 | """ 40 | def __init__( 41 | self, emb_weights, proj=False, 42 | train_embs=False, hyppar=(64,64)): 43 | super().__init__() 44 | self.embed = nn.Embedding.from_pretrained ( 45 | emb_weights, 46 | freeze=(not train_embs), 47 | padding_idx=0 48 | ) 49 | self.gru = nn.GRU ( 50 | emb_weights.size(1), hyppar[0], 51 | batch_first=True, bidirectional=True 52 | ) 53 | if proj: 54 | self.proj = nn.Sequential ( 55 | nn.Linear(hyppar[0]*2, hyppar[1]), 56 | nn.LeakyReLU(.2, inplace=True) 57 | ) 58 | else: 59 | self.proj = lambda x: x 60 | 61 | def forward(self, text_ids, lengths): 62 | lengths, sortbylen = lengths.sort(0, descending=True) 63 | H = self.embed(text_ids[sortbylen]) 64 | H = pack_padded_sequence(H, lengths, batch_first=True) 65 | 66 | _, last = self.gru(H) 67 | out = self.proj(torch.cat(tuple(last), 1)) 68 | _, unsort = sortbylen.sort(0) 69 | 70 | return out[unsort] 71 | 72 | 73 | class StackImageDiscriminator(nn.Module): 74 | def __init__( 75 | self, in_channels=3, cond_size=64, 76 | base_width=32, noise=False, sigma=.2): 77 | super().__init__() 78 | self.D1 = nn.Sequential ( 79 | SN(nn.Conv2d(in_channels, base_width, 1)), 80 | DBlock('2d', base_width, base_width, 2), 81 | DBlock('2d', base_width, base_width*2, 2), 82 | DBlock('2d', base_width*2, base_width*4, 2), 83 | ) 84 | # << output size: (-1, base_width*4, 4, 4) 85 | 86 | cat_dim = base_width*4 + cond_size 87 | self.D2 = nn.Sequential ( 88 | DBlock('2d', cat_dim, cat_dim*2, 1), 89 | SN(nn.Conv2d(cat_dim*2, 1, 4)), 90 | nn.Sigmoid() 91 | ) 92 | 93 | def forward(self, x, c): 94 | """ 95 | x: images; tensor of shape (N, k, C, H, W) 96 | c: condition (do not confuse with C - number of filters) 97 | """ 98 | k = x.size(1) 99 | x = torch.flatten(x, 0, 1) 100 | x = self.D1(x) 101 | 102 | c = c[None, ..., None, None].expand(k, *c.shape, 4, 4) 103 | x = torch.cat((x, torch.flatten(c,0,1)), 1) 104 | return self.D2(x).view(-1, k) 105 | 106 | 107 | class StackVideoDiscriminator(nn.Module): 108 | def __init__( 109 | self, in_channels=3, cond_size=64, 110 | base_width=32, noise=False, sigma=.2): 111 | super().__init__() 112 | self.D1 = nn.Sequential ( 113 | SN(nn.Conv3d(in_channels, base_width, 1)), 114 | DBlock('3d', base_width, base_width, 2), 115 | DBlock('3d', base_width, base_width*2, 2), 116 | DBlock('3d', base_width*2, base_width*4, 2), 117 | DBlock('3d', base_width*4, base_width*8, (2,1,1)), 118 | ) 119 | # << ouput size: (-1, base_width*4, 1, 4, 4) 120 | 121 | cat_dim = base_width*8 + cond_size 122 | self.D2 = nn.Sequential ( 123 | DBlock('3d', cat_dim, cat_dim*2, 1), 124 | SN(nn.Conv3d(cat_dim*2, 1, (1, 4, 4))), 125 | nn.Sigmoid() 126 | ) 127 | 128 | def forward(self, x, c): 129 | """ 130 | x: video 131 | c: condition 132 | """ 133 | x = 
self.D1(x) 134 | c = c.view(*c.shape,1,1,1).expand(*c.shape,1,4,4) 135 | x = torch.cat((x, c), 1) 136 | 137 | return self.D2(x).view(-1) 138 | 139 | class BatchVideoDiscriminator(nn.Module): 140 | def __init__(self, mbs, cond_size, in_colors=3, base_width=32): 141 | super().__init__() 142 | block2d = partial(DBlock, '2d', mbs=mbs) 143 | block3d = partial(DBlock, '3d', mbs=mbs) 144 | cgru_heart = partial( 145 | PermInvariantLayer, nn.Conv2d, 146 | minibatch_size=mbs, sn=True, kernel_size=3, padding=1) 147 | self.cgru = AdvancedConvGRU(cgru_heart, base_width*4, base_width*4) 148 | self.k1_conv = PermInvariantLayer( 149 | nn.Conv3d, in_colors, base_width, 150 | mbs, sn=True, kernel_size=1) 151 | proj = PermInvariantLayer( 152 | nn.Linear, cond_size, base_width*32, mbs, sn=True) 153 | self.proj = nn.Sequential(proj, nn.LeakyReLU(.2, True)) 154 | self.pool = PermInvariantLayer( 155 | nn.Linear, base_width*32, 1, mbs, sn=True) 156 | self.downsampler1 = nn.Sequential( 157 | block3d(base_width, base_width*2, 2), 158 | block3d(base_width*2, base_width*4, (1,2,2))) 159 | self.downsampler2 = nn.Sequential( 160 | block2d(base_width*4, base_width*8, 2), 161 | block2d(base_width*8, base_width*16, 2), 162 | block2d(base_width*16, base_width*32, 2)) 163 | 164 | def forward(self, video, embedding): 165 | H = self.k1_conv(video) 166 | sys.stdout.flush() 167 | H = self.downsampler1(H) 168 | _, last = self.cgru(H) 169 | H = self.downsampler2(last) 170 | H = H.view(H.size(0), -1) 171 | E = self.proj(embedding) 172 | out = self.pool(H).squeeze() 173 | out += torch.einsum('ij,ij->i', E, H) 174 | 175 | return torch.sigmoid(out) 176 | 177 | 178 | class SimpleVideoGenerator(nn.Module): 179 | """ 180 | Args: 181 | dim_Z noise dimensionality 182 | cond_size condition size 183 | """ 184 | def __init__( 185 | self, dim_Z, cond_size=64, 186 | n_colors=3, base_width=128, video_length=16): 187 | super().__init__() 188 | self.dim_Z = dim_Z 189 | self.n_colors = n_colors 190 | self.vlen = video_length 191 | self.code_size = dim_Z + cond_size 192 | 193 | self.gru = nn.GRU( 194 | self.code_size, self.code_size, batch_first=True) 195 | 196 | GB = partial(GBlock, '3d', stride=(1,2,2)) 197 | 198 | self.main = nn.Sequential( 199 | GB(self.code_size, base_width*8), 200 | GB(base_width*8, base_width*4), 201 | GB(base_width*4, base_width*2), 202 | GB(base_width*2, base_width), 203 | GB(base_width, self.n_colors), 204 | nn.Tanh()) 205 | 206 | def forward(self, c, vlen=None): 207 | """ 208 | c: condition (batch_size, cond_size) 209 | vlen: video length 210 | """ 211 | vlen = vlen if vlen else self.vlen 212 | 213 | code = c.new(len(c), vlen, self.code_size).normal_() 214 | code[..., self.dim_Z:] = c[:, None, :] 215 | 216 | H,_ = self.gru(code) 217 | H = H.permute(0, 2, 1)[..., None, None] 218 | 219 | return self.main(H) 220 | 221 | 222 | class TestVideoGenerator(nn.Module): 223 | """ 224 | Args: 225 | dim_Z noise dimensionality 226 | cond_size condition size 227 | """ 228 | def __init__( 229 | self, dim_Z, cond_size=64, 230 | n_colors=3, base_width=128, video_length=16): 231 | super().__init__() 232 | self.dim_Z = dim_Z 233 | self.n_colors = n_colors 234 | self.vlen = video_length 235 | self.code_size = dim_Z + cond_size 236 | 237 | self.gru = nn.GRU( 238 | self.code_size, self.code_size, batch_first=True) 239 | 240 | #GB = partial(GBlock, '2d', stride=1) 241 | CGB = partial(CGBlock, '3d', self.code_size, stride=(1,2,2)) 242 | 243 | self.gblock1 = CGB(self.code_size, base_width*8) 244 | self.gblock2 = CGB(base_width*8, base_width*4) 
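        # each CGB here is a CGBlock built with stride=(1,2,2): it preserves the
        # temporal length and (presumably via upsampling) doubles H and W per
        # block, and it takes the per-sample noise+condition vector `vcon` as a
        # second input in forward() below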
245 | self.gblock3 = CGB(base_width*4, base_width*2) 246 | #self.cgru = AdvancedConvGRU(GB, base_width*2, base_width*2) 247 | self.cgru = ConvGRU( 248 | base_width*2, base_width*2, 3, spectral_norm=True) 249 | self.gblock4 = CGB(base_width*2, base_width) 250 | self.gblock5 = CGB(base_width, self.n_colors) 251 | 252 | def forward(self, c, vlen=None): 253 | """ 254 | c: condition (batch_size, cond_size) 255 | vlen: video length 256 | """ 257 | vlen = vlen if vlen else self.vlen 258 | 259 | code = c.new(len(c), vlen, self.code_size).normal_() 260 | code[..., self.dim_Z:] = c[:, None, :] 261 | 262 | vcon = c.new(len(c), self.code_size).normal_() 263 | vcon[:, self.dim_Z:] = c 264 | 265 | H,_ = self.gru(code) 266 | H = H.permute(0, 2, 1)[..., None, None] 267 | 268 | H = self.gblock1(H, vcon) 269 | H = self.gblock2(H, vcon) 270 | H = self.gblock3(H, vcon) 271 | H,_ = self.cgru(H) 272 | H = self.gblock4(H, vcon) 273 | H = self.gblock5(H, vcon) 274 | 275 | return torch.tanh(H) 276 | 277 | 278 | class MultiConditionalVideoGenerator(nn.Module): 279 | """ 280 | Args: 281 | dim_Z noise dimensionality 282 | cond_sizes condition sizes 283 | """ 284 | def __init__( 285 | self, dim_Z, cond_sizes=(64,64,64), 286 | n_colors=3, base_width=128, video_length=16): 287 | super().__init__() 288 | self.dim_Z = dim_Z 289 | self.n_colors = n_colors 290 | self.vlen = video_length 291 | self.code_size = dim_Z + cond_sizes[0] 292 | self.imcond_size = dim_Z + cond_sizes[1] 293 | self.vicond_size = dim_Z + cond_sizes[2] 294 | 295 | self.gru = nn.GRU( 296 | self.code_size, self.code_size, batch_first=True) 297 | 298 | iCGB = partial(CGBlock, '2d', self.imcond_size, stride=2) 299 | viCGB = partial(CGBlock, '3d', self.imcond_size, stride=2) 300 | vCGB = partial(CGBlock, '3d', self.vicond_size, stride=(1,2,2)) 301 | 302 | self.gblock1 = viCGB(self.code_size, base_width*8) 303 | self.gblock2 = vCGB(base_width*8, base_width*4) 304 | self.gblock3 = vCGB(base_width*4, base_width*2) 305 | self.cgru = ConvGRU( 306 | base_width*2, base_width*2, 3, spectral_norm=True) 307 | self.gblock4 = iCGB(base_width*2, base_width) 308 | self.gblock5 = iCGB(base_width, self.n_colors) 309 | 310 | def forward(self, lc, ic, vc, vlen=None): 311 | """ 312 | vlen: video length 313 | lc: label condition (batch_size, cond_sizes[0]) 314 | ic: image condition (batch_size, cond_sizes[1]) 315 | vc: video condition (batch_size, cond_sizes[2]) 316 | """ 317 | N = len(lc) 318 | vlen = vlen if vlen else self.vlen 319 | 320 | # >> basic conditioning 321 | code = lc.new(N, vlen//2, self.code_size).normal_() 322 | code[..., self.dim_Z:] = lc[:, None, :] 323 | 324 | # >> frame-level conditioning 325 | imcond = ic.new(N*vlen, self.imcond_size).normal_() 326 | imcond[:, self.dim_Z:] = ic.repeat_interleave(vlen, 0) 327 | hicond = ic.new(N, self.imcond_size).normal_() 328 | hicond[:, self.dim_Z:] = ic 329 | 330 | # >> action-capturing conditioning 331 | vicond = vc.new(N, self.vicond_size).normal_() 332 | vicond[:, self.dim_Z:] = vc 333 | 334 | H,_ = self.gru(code) 335 | H = H.permute(0, 2, 1)[..., None, None] 336 | # code.shape: (N, code_size) 337 | # H.shape: (N, code_size, vlen//2, 1, 1) 338 | 339 | H = self.gblock1(H, hicond) 340 | H = self.gblock2(H, vicond) 341 | H = self.gblock3(H, vicond) 342 | H,_ = self.cgru(H) 343 | # H.shape: (N, base_width*2, vlen, 8, 8) 344 | 345 | H = torch.flatten(H.permute(0,2,1,3,4), 0, 1) 346 | H = self.gblock4(H, imcond) 347 | H = self.gblock5(H, imcond) 348 | H = H.view(N, vlen, H.size(1), H.size(2), -1) 349 | # H.shape: (N*vlen, 
base_width*2, 8, 8) 350 | # H.shape: (N*vlen, n_colors, 32, 32) 351 | # return shape: (N, n_colors, vlen, 32, 32) 352 | 353 | return torch.tanh(H.permute(0,2,1,3,4)) 354 | -------------------------------------------------------------------------------- /src/process3-5.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import shutil\n", 10 | "import numpy as np\n", 11 | "from pathlib import Path\n", 12 | "\n", 13 | "import cv2\n", 14 | "import torch\n", 15 | "import torch.nn as nn\n", 16 | "import torch.optim as optim\n", 17 | "\n", 18 | "from torch.utils.data import Dataset, DataLoader\n", 19 | "from torch.utils.tensorboard import SummaryWriter\n", 20 | "\n", 21 | "from data_prep2 import LabeledVideoDataset\n", 22 | "from text_processing2 import selectTemplates\n", 23 | "\n", 24 | "from models import *\n", 25 | "from utils import to_video, selectFramesRandomly, calc_grad_penalty\n", 26 | "from visual_encoders import *" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "path2labels = '../../20bn-sth-v2/labels/train.json'\n", 36 | "path2videos = '../../20bn-sth-v2/videos'\n", 37 | "cache = '../.cache'\n", 38 | "!mkdir -p {cache}\n", 39 | "\n", 40 | "glove_folder = '../embeddings'\n", 41 | "!wget -Nq -P .. https://nlp.stanford.edu/data/glove.6B.zip\n", 42 | "!unzip -nq -d {glove_folder} ../glove.6B.zip\n", 43 | "\n", 44 | "# !rm -f {cache}/*\n", 45 | "templates = ['Pushing [something] from left to right']\n", 46 | "path2labels = selectTemplates(path2labels, templates, '1-template.pkl')\n", 47 | "\n", 48 | "lvds = LabeledVideoDataset(\n", 49 | " path2labels, path2videos, cache,\n", 50 | " video_shape=(32, 32, 32, 3), glove_folder=glove_folder)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# device = torch.device(\"cuda:3\")\n", 60 | "device = torch.device(\"cuda\")\n", 61 | "batch_size = 128\n", 62 | "\n", 63 | "major_data = lvds.data['major']\n", 64 | "minor_data = lvds.data['minor']\n", 65 | "major_data = np.rec.array(major_data)\n", 66 | "minor_data = np.rec.array(minor_data)\n", 67 | "\n", 68 | "# If working with a small dataset, transfer it entirely on the device\n", 69 | "\n", 70 | "#label = torch.tensor(major_data.f0, device=device)\n", 71 | "#slens = torch.tensor(major_data.f1, device=device)\n", 72 | "videos = torch.tensor(major_data.f2, device=device)\n", 73 | "\n", 74 | "obj_vec = torch.tensor(minor_data.f0.astype('int'), device=device)\n", 75 | "# act_vec = torch.tensor(minor_data.f1.astype('int'), device=device)\n", 76 | "# act_len = torch.tensor(minor_data.f2.astype('int'), device=device)\n", 77 | "\n", 78 | "emb_weights = lvds.getGloveEmbeddings(True)\n", 79 | "emb_weights = torch.tensor(emb_weights, device=device)\n", 80 | "encoder = SimpleTextEncoder(emb_weights)\n", 81 | "embeddings = encoder(obj_vec, 1)\n", 82 | "\n", 83 | "class BN20sthsth(Dataset):\n", 84 | " def __init__(self, embeddings, videos):\n", 85 | " self.embs = embeddings\n", 86 | " self.gifs = videos\n", 87 | " \n", 88 | " def __getitem__(self, idx):\n", 89 | " return self.embs[idx], self.gifs[idx]\n", 90 | " \n", 91 | " def __len__(self):\n", 92 | " return len(self.embs)\n", 93 | "\n", 94 | "ds = BN20sthsth(embeddings, videos)\n", 95 | "dl = 
DataLoader(ds, batch_size, shuffle=True, drop_last=True)\n", 96 | " \n", 97 | "# validation: book, green cup, red candle, an orange bowl (not ordered?)\n", 98 | "# val_samples = [168029, 157604, 71563, 82109] \n", 99 | "# train: book, box, mug, marker (ordered)\n", 100 | "val_samples = [118889, 65005, 162293, 73929] \n", 101 | "\n", 102 | "major_val, minor_val = lvds.getById(val_samples)\n", 103 | "val_obj= torch.tensor(minor_val.f0.astype('int'), device=device)\n", 104 | "test = encoder(val_obj, 1)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "gen = TestVideoGenerator(dim_Z=50, cond_size=50)\n", 114 | "i_dis = ProjectionImageDiscriminator(50, logits=False)\n", 115 | "v_dis = ProjectionVideoDiscriminator(50, logits=False)\n", 116 | "\n", 117 | "gen = gen.to(device)\n", 118 | "i_dis = i_dis.to(device)\n", 119 | "v_dis = v_dis.to(device)\n", 120 | "\n", 121 | "# models should be on the first device in the device_ids list \n", 122 | "# before making a call to nn.DataParallel class\n", 123 | "# gen = nn.DataParallel(gen, device_ids=[3, 4, 5])\n", 124 | "# i_dis = nn.DataParallel(i_dis, device_ids=[3, 4, 5])\n", 125 | "# v_dis = nn.DataParallel(v_dis, device_ids=[3, 4, 5])\n", 126 | "lr_g = 0.00005\n", 127 | "lr_d = 0.0002\n", 128 | "\n", 129 | "g_opt = optim.Adam(gen.parameters(), lr=lr_g, betas=(0.3, 0.999))\n", 130 | "i_opt = optim.Adam(i_dis.parameters(), lr=lr_d, betas=(0.3, 0.999))\n", 131 | "v_opt = optim.Adam(v_dis.parameters(), lr=lr_d, betas=(0.3, 0.999))\n", 132 | "\n", 133 | "logdir = Path('../runs/process_3')\n", 134 | "experiment_name = 'testing_proj_discs'" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "# Start from scratch\n", 144 | "if logdir.exists():\n", 145 | " shutil.rmtree(logdir)\n", 146 | "\n", 147 | "CNT = 0\n", 148 | "start = 1" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "k_d, k_g = 2, 1\n", 158 | "hyppar = 100\n", 159 | "num_epochs = 10000\n", 160 | "submit_period = 1\n", 161 | "save_period = 200\n", 162 | "n_cp = 10\n", 163 | "\n", 164 | "loss = {}\n", 165 | "epoch_loss = {'D': 0, 'G': 0}\n", 166 | "writer = SummaryWriter(logdir)\n", 167 | "\n", 168 | "\n", 169 | "for epoch in range(start, num_epochs+1):\n", 170 | " print(f'EPOCH: {epoch} / {num_epochs}')\n", 171 | " for No, (e, v) in enumerate(dl):\n", 172 | " ne = torch.roll(e, 1, 0)\n", 173 | " f_ids = selectFramesRandomly(16, 8)\n", 174 | " i = v[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)\n", 175 | " for _ in range(k_d):\n", 176 | " with torch.no_grad():\n", 177 | " gv = gen(e)\n", 178 | " gi = gv[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)\n", 179 | " # i = image, v = video, e = embedding, g = generated,\n", 180 | " # s = scores, n = negative, p = positive\n", 181 | " \n", 182 | " pis, nis, gis = i_dis(i, e), i_dis(i, ne), i_dis(gi, e)\n", 183 | " pvs, nvs, gvs = v_dis(v, e), v_dis(v, ne), v_dis(gv, e)\n", 184 | " l1 = pis.log().mean() + pvs.log().mean()\n", 185 | " l2 = (-nis).log1p().mean() + (-nvs).log1p().mean()\n", 186 | " l3 = (-gis).log1p().mean() + (-gvs).log1p().mean()\n", 187 | " l4 = calc_grad_penalty(v, gv, v_dis, e)\n", 188 | " l5 = calc_grad_penalty(i, gi, i_dis, e)\n", 189 | " L = -.33*(l1+l2+l3) + hyppar*(l4+l5)\n", 190 | " \n", 191 | " i_opt.zero_grad()\n", 192 | " v_opt.zero_grad()\n", 
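"            # one backward pass through the joint loss L updates both the image and video discriminators\n",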
193 | " L.backward()\n", 194 | " i_opt.step()\n", 195 | " v_opt.step()\n", 196 | " \n", 197 | " epoch_loss['D'] += L.item()\n", 198 | " \n", 199 | " for _ in range(k_g):\n", 200 | " gv = gen(e)\n", 201 | " gi = gv[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)\n", 202 | " L = -.5 * (i_dis(gi, e).log().mean() + v_dis(gv, e).log().mean())\n", 203 | " \n", 204 | " g_opt.zero_grad()\n", 205 | " L.backward()\n", 206 | " g_opt.step()\n", 207 | " \n", 208 | " epoch_loss['G'] += L.item()\n", 209 | " \n", 210 | " for k, v in epoch_loss.items():\n", 211 | " loss[k] = v / (No + 1)\n", 212 | " epoch_loss = dict.fromkeys(epoch_loss, 0)\n", 213 | " \n", 214 | " if epoch % submit_period == 0:\n", 215 | " gen.eval()\n", 216 | " with torch.no_grad():\n", 217 | " gv = gen(test)\n", 218 | " writer.add_scalars('Loss_ap', loss, epoch)\n", 219 | " writer.add_video('Fakes_ap', to_video(gv), epoch)\n", 220 | " gen.train()\n", 221 | " \n", 222 | " if epoch % save_period == 0:\n", 223 | " print('Saving the progress')\n", 224 | " # Save models themselves as well (in case you're gonna change them latter)\n", 225 | " checkpoint = {\n", 226 | " 'epoch': epoch,\n", 227 | " 'next_checkpoint_No': CNT+1,\n", 228 | " 'gen_state': gen.state_dict(), \n", 229 | " 'i_dis_state': i_dis.state_dict(),\n", 230 | " 'v_dis_state': v_dis.state_dict(),\n", 231 | " 'gen_model': gen,\n", 232 | " 'i_dis_model': i_dis,\n", 233 | " 'v_dis_model': v_dis,\n", 234 | " 'g_opt_dict': g_opt.state_dict(),\n", 235 | " 'i_opt_dict': i_opt.state_dict(),\n", 236 | " 'v_opt_dict': v_opt.state_dict()}\n", 237 | " save_as = f'../checkpoints/{experiment_name}-cp{CNT%n_cp}.tar'\n", 238 | " torch.save(checkpoint, save_as)\n", 239 | " CNT += 1\n", 240 | " print('Done!')" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "# MANUAL SAVING AFTER INTERRUPTING THE KERNEL\n", 250 | "checkpoint = {\n", 251 | " 'epoch': epoch, \n", 252 | " 'next_checkpoint_No': CNT+1,\n", 253 | " 'gen_state': gen.state_dict(), \n", 254 | " 'i_dis_state': i_dis.state_dict(),\n", 255 | " 'v_dis_state': v_dis.state_dict(),\n", 256 | " 'gen_model': gen,\n", 257 | " 'i_dis_model': i_dis,\n", 258 | " 'v_dis_model': v_dis,\n", 259 | " 'g_opt_dict': g_opt.state_dict(),\n", 260 | " 'i_opt_dict': i_opt.state_dict(),\n", 261 | " 'v_opt_dict': v_opt.state_dict()}\n", 262 | "save_as = f'../checkpoints/{experiment_name}-cp{CNT%n_cp}.tar'\n", 263 | "torch.save(checkpoint, save_as)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "# RESUMING THE TRAINING\n", 273 | "# >> initialize models and optimizers >> \n", 274 | "#gen = \n", 275 | "#i_dis =\n", 276 | "#v_dis =\n", 277 | "# << initialize models <<\n", 278 | "\n", 279 | "# if `save_as` contains a string defined during training, then use it\n", 280 | "# otherwise\n", 281 | "load_this = !ls -lt ../checkpoints/{experiment_name}* | head -1 | rev | cut -d ' ' -f1 | rev\n", 282 | "print(load_this[0])\n", 283 | "checkpoint = torch.load(load_this[0])\n", 284 | "gen.load_state_dict(checkpoint['gen_state'])\n", 285 | "i_dis.load_state_dict(checkpoint['i_dis_state'])\n", 286 | "v_dis.load_state_dict(checkpoint['v_dis_state'])\n", 287 | "g_opt.load_state_dict(checkpoint['g_opt_dict'])\n", 288 | "i_opt.load_state_dict(checkpoint['i_opt_dict'])\n", 289 | "v_opt.load_state_dict(checkpoint['v_opt_dict'])\n", 290 | "start = checkpoint['epoch']\n", 291 | "CNT = 
checkpoint['next_checkpoint_No']\n", 292 | "gen.train();\n", 293 | "\n", 294 | "# now you can run the block with training" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "# LOADDING FOR INFERENCE\n", 304 | "# checkpoint = torch.load(f'../checkpoints/{experiment_name}-cp1.tar')\n", 305 | "load_this = !ls -lt ../checkpoints/{experiment_name}* | head -1 | rev | cut -d ' ' -f1 | rev\n", 306 | "checkpoint = torch.load(load_this[0])\n", 307 | "\n", 308 | "# IF THE ORIGINAL ARCHITECTURE IS PRESERVED \n", 309 | "gen.load_state_dict(checkpoint['gen_state'])\n", 310 | "# gen.eval()\n", 311 | "# do something\n", 312 | "\n", 313 | "# -- OTHERWISE --\n", 314 | "#gen = checkpoint['gen_model']\n", 315 | "#gen.eval()\n", 316 | "# do something" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "# LOADING IF SOME OF MODEL COMPONENTS HAVE BEEN CHANGED\n", 326 | "load_this = !ls -lt ../checkpoints/{experiment_name}* | head -1 | rev | cut -d ' ' -f1 | rev\n", 327 | "checkpoint = torch.load(load_this[0])\n", 328 | "\n", 329 | "gen = checkpoint['gen_model']" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "test_obj = 'penny'\n", 339 | "\n", 340 | "books = torch.tensor(lvds._glove['books']).cuda(device)\n", 341 | "box = torch.tensor(lvds._glove['box']).cuda(device)\n", 342 | "carton = torch.tensor(lvds._glove['carton']).cuda(device)\n", 343 | "paper = torch.tensor(lvds._glove['paper']).cuda(device)\n", 344 | "notebook = torch.tensor(lvds._glove['notebook']).cuda(device)\n", 345 | "album = torch.tensor(lvds._glove['album']).cuda(device)\n", 346 | "book = torch.tensor(lvds._glove['book']).cuda(device)\n", 347 | "cent = torch.tensor(lvds._glove['cent']).cuda(device)\n", 348 | "penny = torch.tensor(lvds._glove['penny']).cuda(device)\n", 349 | "coin = torch.tensor(lvds._glove['coin']).cuda(device)\n", 350 | "test_embs = torch.stack((penny, coin, album, books), dim=0)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "cos_sim = nn.CosineSimilarity(dim=0)\n", 360 | "cos_sim(books, book)" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "# INFERENCE\n", 370 | "gen.eval()\n", 371 | "with torch.no_grad():\n", 372 | " infer = gen(test_embs)\n", 373 | "\n", 374 | "generated_video = ((infer+1) / 2 * 255).permute(0, 2, 3, 4, 1).cpu().numpy().astype('uint8')" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "# WRITE VIDEOS TO FILES\n", 384 | "fourcc = cv2.VideoWriter_fourcc(*'XVID')\n", 385 | "im_size = (128, 128)\n", 386 | "subjects = ['book', 'box', 'mug', 'marker']\n", 387 | "folder = Path('generated_video')\n", 388 | "if not folder.exists():\n", 389 | " folder.mkdir()\n", 390 | "for i, subj in enumerate(subjects):\n", 391 | " out = cv2.VideoWriter(f'{folder}/{subj}.avi', fourcc, 4., im_size)\n", 392 | " for frame in generated_video[i]:\n", 393 | " frame = cv2.resize(frame, im_size)\n", 394 | " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", 395 | " out.write(frame)\n", 396 | " out.release()" 397 | ] 398 | }, 399 | { 400 | 
"cell_type": "code", 401 | "execution_count": null, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "# WRITE VIDEOS TO FILES\n", 406 | "fourcc = cv2.VideoWriter_fourcc(*'XVID')\n", 407 | "im_size = (128, 128)\n", 408 | "subjects = ['penny', 'coin', 'album', 'books']\n", 409 | "folder = Path('generated_video')\n", 410 | "if not folder.exists():\n", 411 | " folder.mkdir()\n", 412 | "for i, subj in enumerate(subjects):\n", 413 | " out = cv2.VideoWriter(f'{folder}/{subj}.avi', fourcc, 4., im_size)\n", 414 | " for frame in generated_video[i]:\n", 415 | " frame = cv2.resize(frame, im_size)\n", 416 | " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", 417 | " out.write(frame)\n", 418 | " out.release()" 419 | ] 420 | } 421 | ], 422 | "metadata": { 423 | "kernelspec": { 424 | "display_name": "Python 3 (ipykernel)", 425 | "language": "python", 426 | "name": "python3" 427 | }, 428 | "language_info": { 429 | "codemirror_mode": { 430 | "name": "ipython", 431 | "version": 3 432 | }, 433 | "file_extension": ".py", 434 | "mimetype": "text/x-python", 435 | "name": "python", 436 | "nbconvert_exporter": "python", 437 | "pygments_lexer": "ipython3", 438 | "version": "3.8.10" 439 | }, 440 | "toc": { 441 | "base_numbering": 1, 442 | "nav_menu": {}, 443 | "number_sections": true, 444 | "sideBar": true, 445 | "skip_h1_title": false, 446 | "title_cell": "Table of Contents", 447 | "title_sidebar": "Contents", 448 | "toc_cell": false, 449 | "toc_position": {}, 450 | "toc_section_display": true, 451 | "toc_window_display": false 452 | } 453 | }, 454 | "nbformat": 4, 455 | "nbformat_minor": 2 456 | } 457 | -------------------------------------------------------------------------------- /src/pipeline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import pickle\n", 11 | "from pathlib import Path\n", 12 | "import shutil\n", 13 | "import numpy as np\n", 14 | "\n", 15 | "import torch\n", 16 | "import torch.nn as nn\n", 17 | "import torch.optim as optim\n", 18 | "\n", 19 | "\n", 20 | "from torch.utils.data import Dataset, DataLoader\n", 21 | "from torch.utils.tensorboard import SummaryWriter\n", 22 | "\n", 23 | "from data_prep2 import LabeledVideoDataset\n", 24 | "from text_processing2 import selectTemplates\n", 25 | "\n", 26 | "from models import *\n", 27 | "from utils import *\n", 28 | "from visual_encoders import *" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "path2labels = '../../20bn-sth-v2/labels/train.json'\n", 38 | "path2videos = '../../20bn-sth-v2/videos'\n", 39 | "cache = '../.cache'\n", 40 | "!mkdir -p {cache}\n", 41 | "\n", 42 | "glove_folder = '../embeddings'\n", 43 | "!wget -Nq -P .. 
https://nlp.stanford.edu/data/glove.6B.zip\n", 44 | "!unzip -nq -d {glove_folder} ../glove.6B.zip\n", 45 | "\n", 46 | "# !rm -f {cache}/*\n", 47 | "templates = ['Pushing [something] from left to right']\n", 48 | "path2labels = selectTemplates(path2labels, templates, '1-template.pkl')\n", 49 | "\n", 50 | "lvds = LabeledVideoDataset(\n", 51 | " path2labels, path2videos, cache,\n", 52 | " video_shape=(32, 32, 32, 3), glove_folder=glove_folder)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "device = torch.device(\"cuda:0\")\n", 62 | "batch_size = 128\n", 63 | "\n", 64 | "major_data = lvds.data['major']\n", 65 | "minor_data = lvds.data['minor']\n", 66 | "major_data = np.rec.array(major_data)\n", 67 | "minor_data = np.rec.array(minor_data)\n", 68 | "\n", 69 | "# If working with a small dataset, transfer it entirely on the device\n", 70 | "\n", 71 | "#label = torch.tensor(major_data.f0, device=device)\n", 72 | "#slens = torch.tensor(major_data.f1, device=device)\n", 73 | "video = torch.tensor(major_data.f2, device=device)\n", 74 | "obj_vec = torch.tensor(minor_data.f0.astype('int'), device=device)\n", 75 | "#act_vec = torch.tensor(minor_data.f1.astype('int'), device=device)\n", 76 | "#act_len = torch.tensor(minor_data.f2.astype('int'), device=device)\n", 77 | "\n", 78 | "emb_weights = lvds.getGloveEmbeddings(True)\n", 79 | "emb_weights = torch.tensor(emb_weights, device=device)\n", 80 | "encoder = SimpleTextEncoder(emb_weights)\n", 81 | "embeddings = encoder(obj_vec, 1)\n", 82 | "\n", 83 | "class BN20sthsth(Dataset):\n", 84 | " def __init__(self, embeddings, video):\n", 85 | " self.embs = embeddings\n", 86 | " self.gifs = video\n", 87 | " \n", 88 | " def __getitem__(self, idx):\n", 89 | " return self.embs[idx], self.gifs[idx]\n", 90 | " \n", 91 | " def __len__(self):\n", 92 | " return len(self.embs)\n", 93 | " \n", 94 | "ds = BN20sthsth(embeddings, video)\n", 95 | "dl = DataLoader(ds, batch_size, shuffle=True, drop_last=True)\n", 96 | " \n", 97 | "# validation: book, green cup, red candle, an orange bowl (not ordered?)\n", 98 | "#val_samples = [168029, 157604, 71563, 82109] \n", 99 | "# train: book, box, mug, marker (ordered)\n", 100 | "val_samples = [118889, 65005, 162293, 73929] \n", 101 | "\n", 102 | "major_val, minor_val = lvds.getById(val_samples)\n", 103 | "val_obj= torch.tensor(minor_val.f0.astype('int'), device=device)\n", 104 | "test = encoder(val_obj, 1)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "mbs = 16\n", 114 | "dim_Z = 50\n", 115 | "cond_size = 50\n", 116 | "\n", 117 | "gen = TestVideoGenerator(dim_Z=dim_Z, cond_size=cond_size)\n", 118 | "i_dis = ProjectionImageDiscriminator(cond_size=cond_size, logits=False)\n", 119 | "av_dis = ProjectionVideoDiscriminator(cond_size=cond_size, logits=False)\n", 120 | "lv_dis = BatchVideoDiscriminator(mbs=mbs, cond_size=cond_size)\n", 121 | "\n", 122 | "gen = gen.to(device)\n", 123 | "i_dis = i_dis.to(device)\n", 124 | "av_dis = av_dis.to(device)\n", 125 | "lv_dis = lv_dis.to(device)\n", 126 | "\n", 127 | "# models should be on the first device in the device_ids list \n", 128 | "# before making a call to nn.DataParallel class\n", 129 | "gen = nn.DataParallel(gen, device_ids=[0, 1, 2, 3, 4, 5])\n", 130 | "i_dis = nn.DataParallel(i_dis, device_ids=[0, 1, 2, 3, 4, 5])\n", 131 | "av_dis = nn.DataParallel(av_dis, device_ids=[0, 1, 2, 3, 4, 5])\n", 132 
| "lv_dis = nn.DataParallel(lv_dis, device_ids=[0, 1, 2, 3, 4, 5])\n", 133 | "lr_g = 0.00005\n", 134 | "lr_d = 0.0002\n", 135 | "\n", 136 | "g_opt = optim.Adam(gen.parameters(), lr=lr_g, betas=(0.3, 0.999))\n", 137 | "i_opt = optim.Adam(i_dis.parameters(), lr=lr_d, betas=(0.3, 0.999))\n", 138 | "av_opt = optim.Adam(av_dis.parameters(), lr=lr_d, betas=(0.3, 0.999))\n", 139 | "lv_opt = optim.Adam(lv_dis.parameters(), lr=lr_d, betas=(0.3, 0.999))\n", 140 | "\n", 141 | "experiment_name = 'one_cat-1000samples'\n", 142 | "logdir = Path('runs/process_0')" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "video_shape = (3, 16, 32, 32)\n", 152 | "batch_size = 256\n", 153 | "minibatch_size = 16\n", 154 | "gamma = .3\n", 155 | "\n", 156 | "def truncated_uniform(gamma=.3):\n", 157 | " r = gamma * torch.rand(1)\n", 158 | " if torch.randint(2, (1,)): return r\n", 159 | " else: return 1-r" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "# Problem in this cell: some of the variables that require grad are reused\n", 169 | "hyppar = 100\n", 170 | "real_buffer = [torch.tensor([], device=device)] * 2\n", 171 | "fake_buffer = [torch.tensor([], device=device)] * 3\n", 172 | "\n", 173 | "loss = {}\n", 174 | "epoch_loss = {'D': 0, 'G': 0}\n", 175 | "\n", 176 | "for No, (E, V) in enumerate(dl):\n", 177 | " real_buffer[0] = torch.cat([real_buffer[0], E])\n", 178 | " real_buffer[1] = torch.cat([real_buffer[1], V])\n", 179 | " \n", 180 | " shuffle = torch.randperm(batch_size)\n", 181 | " FE = E[shuffle]\n", 182 | " fake_buffer[0] = torch.cat([fake_buffer[0], FE])\n", 183 | " \n", 184 | " with torch.no_grad():\n", 185 | " FV = gen(FE)\n", 186 | " fake_buffer[1] = torch.cat([fake_buffer[1], FV])\n", 187 | " \n", 188 | " FV = gen(FE)\n", 189 | " fake_buffer[2] = torch.cat([fake_buffer[2], FV])\n", 190 | " \n", 191 | " multibatch = (E.new(batch_size, cond_size), E.new(batch_size, *video_shape))\n", 192 | " while (len(real_buffer[0]) >= batch_size and\n", 193 | " len(fake_buffer[0]) >= batch_size):\n", 194 | " n = batch_size // minibatch_size\n", 195 | " u = E.new(n)\n", 196 | " B = []\n", 197 | " for i in range(n):\n", 198 | " p = truncated_uniform(gamma)\n", 199 | " beta = torch.bernoulli(torch.empty(minibatch_size), p.item()).bool()\n", 200 | " u[i] = beta.sum().float() / minibatch_size\n", 201 | " B.append(beta)\n", 202 | " \n", 203 | " B = torch.cat(B)\n", 204 | " x = B.sum()\n", 205 | " y = batch_size - x\n", 206 | " \n", 207 | " E = real_buffer[0][:x].clone()\n", 208 | " V = real_buffer[1][:x].clone()\n", 209 | " multibatch[0][B] = E\n", 210 | " multibatch[1][B] = V\n", 211 | " multibatch[0][~B] = fake_buffer[0][:y].clone()\n", 212 | " multibatch[1][~B] = fake_buffer[1][:y].clone()\n", 213 | " \n", 214 | " real_buffer[0] = real_buffer[0][x:]\n", 215 | " real_buffer[1] = real_buffer[1][x:]\n", 216 | " fake_buffer[0] = fake_buffer[0][y:]\n", 217 | " fake_buffer[1] = fake_buffer[1][y:]\n", 218 | " \n", 219 | " with torch.no_grad():\n", 220 | " GV = gen(E)\n", 221 | " \n", 222 | " torch.cuda.empty_cache()\n", 223 | " NE = torch.roll(E, 1, 0)\n", 224 | " ids = selectFramesRandomly(16, 8)\n", 225 | " I = V[:, :, ids, ...].permute(0, 2, 1, 3, 4)\n", 226 | " GI = GV[:, :, ids, ...].permute(0, 2, 1, 3, 4)\n", 227 | " \n", 228 | " L = batchGAN_DLoss(u, multibatch, lv_dis)\n", 229 | " L += .5 * vanilla_DLoss(V, GV, E, NE, 
av_dis)\n", 230 | " L += .5 * vanilla_DLoss(I, GI, E, NE, i_dis)\n", 231 | " L += hyppar * sum(\n", 232 | " map(calc_grad_penalty, (V, I), (GV, GI), (av_dis, i_dis), (E, E)))\n", 233 | " \n", 234 | " sys.stdout.flush()\n", 235 | " i_opt.zero_grad()\n", 236 | " av_opt.zero_grad()\n", 237 | " lv_opt.zero_grad()\n", 238 | " L.backward(retain_graph=True)\n", 239 | " i_opt.step()\n", 240 | " av_opt.step()\n", 241 | " lv_opt.step()\n", 242 | " \n", 243 | " epoch_loss['D'] += L.item()\n", 244 | " \n", 245 | " GV = gen(E)\n", 246 | " GI = GV[:, :, ids, ...].permute(0, 2, 1, 3, 4)\n", 247 | " multibatch[1][~B] = fake_buffer[2][:y]\n", 248 | " fake_buffer[2] = fake_buffer[2][y:]\n", 249 | " L = batchGAN_GLoss(multibatch, lv_dis)\n", 250 | " L += .5 * vanilla_GLoss2(GV, E, av_dis)\n", 251 | " L += .5 * vanilla_GLoss2(GI, E, i_dis)\n", 252 | " \n", 253 | " epoch_loss['G'] += L.item()\n", 254 | " for k, v in epoch_loss.items():\n", 255 | " loss[k] = v / (No + 1)\n", 256 | " epoch_loss = dict.fromkeys(epoch_loss, 0)\n", 257 | " torch.cuda.empty_cache()" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "# START TRAINING FROM SCRATCH\n", 267 | "if logdir.exists():\n", 268 | " shutil.rmtree(logdir)\n", 269 | "\n", 270 | "CNT = 0\n", 271 | "start = 1" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "k_d, k_g = 2, 1\n", 281 | "hyppar = 1000\n", 282 | "num_epochs = 10000\n", 283 | "submit_period = 1\n", 284 | "save_period = 200\n", 285 | "n_cp = 10\n", 286 | "\n", 287 | "loss = {}\n", 288 | "epoch_loss = {'D': 0, 'G': 0}\n", 289 | "writer = SummaryWriter(logdir)\n", 290 | "\n", 291 | "\n", 292 | "for epoch in range(start, num_epochs+1):\n", 293 | " print(f'EPOCH: {epoch} / {num_epochs}')\n", 294 | " for No, (e, v) in enumerate(dl):\n", 295 | " ne = torch.roll(e, 1, 0)\n", 296 | " f_ids = selectFramesRandomly(16, 8)\n", 297 | " i = v[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)\n", 298 | " for _ in range(k_d):\n", 299 | " with torch.no_grad():\n", 300 | " gv = gen(e)\n", 301 | " gi = gv[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)\n", 302 | " # i = image, v = video, e = embedding, g = generated,\n", 303 | " # s = scores, n = negative, p = positive\n", 304 | " \n", 305 | " pis, nis, gis = i_dis(i, e), i_dis(i, ne), i_dis(gi, e)\n", 306 | " pvs, nvs, gvs = v_dis(v, e), v_dis(v, ne), v_dis(gv, e)\n", 307 | " l1 = pis.log().mean() + pvs.log().mean()\n", 308 | " l2 = (-nis).log1p().mean() + (-nvs).log1p().mean()\n", 309 | " l3 = (-gis).log1p().mean() + (-gvs).log1p().mean()\n", 310 | " l4 = calc_grad_penalty(v, gv, v_dis, e)\n", 311 | " l5 = calc_grad_penalty(i, gi, i_dis, e)\n", 312 | " L = -.33*(l1+l2+l3) + hyppar*(l4+l5)\n", 313 | " \n", 314 | " i_opt.zero_grad()\n", 315 | " v_opt.zero_grad()\n", 316 | " L.backward()\n", 317 | " i_opt.step()\n", 318 | " v_opt.step()\n", 319 | " \n", 320 | " epoch_loss['D'] += L.item()\n", 321 | " \n", 322 | " for _ in range(k_g):\n", 323 | " gv = gen(e)\n", 324 | " gi = gv[:, :, f_ids, ...].permute(0, 2, 1, 3, 4)\n", 325 | " L = -.5 * (i_dis(gi, e).log().mean() + v_dis(gv, e).log().mean())\n", 326 | " \n", 327 | " g_opt.zero_grad()\n", 328 | " L.backward()\n", 329 | " g_opt.step()\n", 330 | " \n", 331 | " epoch_loss['G'] += L.item()\n", 332 | " \n", 333 | " for k, v in epoch_loss.items():\n", 334 | " loss[k] = v / (No + 1)\n", 335 | " epoch_loss = dict.fromkeys(epoch_loss, 0)\n", 336 | " 
\n", 337 | " if epoch % submit_period == 0:\n", 338 | " gen.eval()\n", 339 | " with torch.no_grad():\n", 340 | " gv = gen(test)\n", 341 | " writer.add_scalars('Loss_ap', loss, epoch)\n", 342 | " writer.add_video('Fakes_ap', to_video(gv), epoch)\n", 343 | " gen.train()\n", 344 | " \n", 345 | " if epoch % save_period == 0:\n", 346 | " print('Saving the progress')\n", 347 | " # Save models themselves as well (in case you're gonna change them latter)\n", 348 | " checkpoint = {\n", 349 | " 'epoch': epoch,\n", 350 | " 'next_checkpoint_No': CNT+1,\n", 351 | " 'gen_state': gen.state_dict(), \n", 352 | " 'i_dis_state': i_dis.state_dict(),\n", 353 | " 'v_dis_state': v_dis.state_dict(),\n", 354 | " 'gen_model': gen,\n", 355 | " 'i_dis_model': i_dis,\n", 356 | " 'v_dis_model': v_dis,\n", 357 | " 'g_opt_dict': g_opt.state_dict(),\n", 358 | " 'i_opt_dict': i_opt.state_dict(),\n", 359 | " 'v_opt_dict': v_opt.state_dict()}\n", 360 | " save_as = f'../checkpoints/{experiment_name}-cp{CNT%n_cp}.tar'\n", 361 | " torch.save(checkpoint, save_as)\n", 362 | " CNT += 1\n", 363 | " print('Done!')" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "# MANUAL SAVING AFTER INTERRUPTING THE KERNEL\n", 373 | "checkpoint = {\n", 374 | " 'epoch': epoch, \n", 375 | " 'next_checkpoint_No': CNT+1,\n", 376 | " 'gen_state': gen.state_dict(), \n", 377 | " 'i_dis_state': i_dis.state_dict(),\n", 378 | " 'v_dis_state': v_dis.state_dict(),\n", 379 | " 'gen_model': gen,\n", 380 | " 'i_dis_model': i_dis,\n", 381 | " 'v_dis_model': v_dis,\n", 382 | " 'g_opt_dict': g_opt.state_dict(),\n", 383 | " 'i_opt_dict': i_opt.state_dict(),\n", 384 | " 'v_opt_dict': v_opt.state_dict()}\n", 385 | "save_as = f'../checkpoints/{experiment_name}-cp{CNT%n_cp}.tar'\n", 386 | "torch.save(checkpoint, save_as)" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "# RESUMING THE TRAINING\n", 396 | "# >> initialize models and optimizers >> \n", 397 | "#gen = \n", 398 | "#i_dis =\n", 399 | "#v_dis =\n", 400 | "# << initialize models <<\n", 401 | "\n", 402 | "# if `save_as` contains a string defined during training, then use it\n", 403 | "# otherwise\n", 404 | "load_this = !ls -lt ../checkpoints/{experiment_name}* | head -1 | rev | cut -d ' ' -f1 | rev\n", 405 | "checkpoint = torch.load(load_this[0])\n", 406 | "gen.load_state_dict(checkpoint['gen_state'])\n", 407 | "i_dis.load_state_dict(checkpoint['i_dis_state'])\n", 408 | "v_dis.load_state_dict(checkpoint['v_dis_state'])\n", 409 | "g_opt.load_state_dict(checkpoint['g_opt_dict'])\n", 410 | "i_opt.load_state_dict(checkpoint['i_opt_dict'])\n", 411 | "v_opt.load_state_dict(checkpoint['v_opt_dict'])\n", 412 | "start = checkpoint['epoch']\n", 413 | "CNT = checkpoint['next_checkpoint_No']\n", 414 | "gen.train();\n", 415 | "\n", 416 | "# now you can run the block with training" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "# LOADDING FOR INFERENCE\n", 426 | "checkpoint = torch.load(f'../checkpoints/{experiment_name}-cp1.tar')\n", 427 | "\n", 428 | "# IF THE ORIGINAL ARCHITECTURE IS PRESERVED \n", 429 | "gen.load_state_dict(checkpoint['gen_state'])\n", 430 | "# gen.eval()\n", 431 | "# do something\n", 432 | "\n", 433 | "# -- OTHERWISE --\n", 434 | "#gen = checkpoint['gen_model']\n", 435 | "#gen.eval()\n", 436 | "# do 
something" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "# INFERENCE\n", 446 | "gen.eval()\n", 447 | "with torch.no_grad():\n", 448 | " infer = gen(test)\n", 449 | "\n", 450 | "generated_video = ((infer+1) / 2 * 255).permute(0, 2, 3, 4, 1).cpu().numpy().astype('uint8')" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": null, 456 | "metadata": {}, 457 | "outputs": [], 458 | "source": [ 459 | "# WRITE VIDEOS TO FILES\n", 460 | "fourcc = cv2.VideoWriter_fourcc(*'XVID')\n", 461 | "im_size = (128, 128)\n", 462 | "subjects = ['book', 'box', 'mug', 'marker']\n", 463 | "folder = Path('generated_video')\n", 464 | "if not folder.exists():\n", 465 | " folder.mkdir()\n", 466 | "for i, subj in enumerate(subjects):\n", 467 | " out = cv2.VideoWriter(f'{folder}/{subj}.avi', fourcc, 4., im_size)\n", 468 | " for frame in generated_video[i]:\n", 469 | " frame = cv2.resize(frame, im_size)\n", 470 | " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", 471 | " out.write(frame)\n", 472 | " out.release()" 473 | ] 474 | } 475 | ], 476 | "metadata": { 477 | "kernelspec": { 478 | "display_name": "Python 3 (ipykernel)", 479 | "language": "python", 480 | "name": "python3" 481 | }, 482 | "language_info": { 483 | "codemirror_mode": { 484 | "name": "ipython", 485 | "version": 3 486 | }, 487 | "file_extension": ".py", 488 | "mimetype": "text/x-python", 489 | "name": "python", 490 | "nbconvert_exporter": "python", 491 | "pygments_lexer": "ipython3", 492 | "version": "3.8.10" 493 | }, 494 | "toc": { 495 | "base_numbering": 1, 496 | "nav_menu": {}, 497 | "number_sections": true, 498 | "sideBar": true, 499 | "skip_h1_title": false, 500 | "title_cell": "Table of Contents", 501 | "title_sidebar": "Contents", 502 | "toc_cell": false, 503 | "toc_position": {}, 504 | "toc_section_display": true, 505 | "toc_window_display": false 506 | } 507 | }, 508 | "nbformat": 4, 509 | "nbformat_minor": 2 510 | } 511 | --------------------------------------------------------------------------------