├── .gitignore ├── README.md ├── classifier.py ├── dataset ├── create-lpc-dataset.py └── style-transfer-set.py ├── model.py ├── test ├── cosine-similarity │ ├── image1.png │ ├── image2.png │ ├── set1 │ │ ├── image1.png │ │ └── image2.png │ ├── set10 │ │ ├── image1.png │ │ └── image2.png │ ├── set11 │ │ ├── image1.png │ │ └── image2.png │ ├── set12 │ │ ├── image1.png │ │ └── image2.png │ ├── set2 │ │ ├── image1.png │ │ └── image2.png │ ├── set3 │ │ ├── image1.png │ │ └── image2.png │ ├── set4 │ │ ├── image1.png │ │ └── image2.png │ ├── set5 │ │ ├── image1.png │ │ └── image2.png │ ├── set6 │ │ ├── image1.png │ │ └── image2.png │ ├── set7 │ │ ├── image1.png │ │ └── image2.png │ ├── set8 │ │ ├── image1.png │ │ └── image2.png │ └── set9 │ │ ├── image1.png │ │ └── image2.png ├── style-transfer │ ├── set1 │ │ ├── image1.png │ │ ├── image1_body_image2_motion.png │ │ ├── image2.png │ │ └── image2_body_image1_motion.png │ ├── set2 │ │ ├── image1.png │ │ ├── image1_body_image2_motion.png │ │ ├── image2.png │ │ └── image2_body_image1_motion.png │ ├── set3 │ │ ├── image1.png │ │ ├── image1_body_image2_motion.png │ │ ├── image2.png │ │ └── image2_body_image1_motion.png │ ├── set4 │ │ ├── image1.png │ │ ├── image1_body_image2_motion.png │ │ ├── image2.png │ │ └── image2_body_image1_motion.png │ ├── set5 │ │ ├── image1.png │ │ ├── image1_body_image2_motion.png │ │ ├── image2.png │ │ └── image2_body_image1_motion.png │ ├── set6 │ │ ├── image1.png │ │ ├── image1_body_image2_motion.png │ │ ├── image2.png │ │ └── image2_body_image1_motion.png │ └── set7 │ │ ├── image1.png │ │ ├── image1_body_image2_motion.png │ │ ├── image2.png │ │ └── image2_body_image1_motion.png └── test_similarity.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.model 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Disentangled Sequential Autoencoder 2 | Reproduction of the ICML 2018 publication [Disentangled Sequential Autoencoder by Yinghen Li and Stephen Mandt](https://arxiv.org/abs/1803.02991), a Variational Autoencoder Architecture for learning latent representations of high dimensional sequential data by approximately disentangling the time invariant and the time variable features, without any modification to the ELBO objective. 3 | 4 | # Network Architecture 5 | 6 | ## Prior of z: 7 | 8 | The prior of `z` is a Gaussian with mean and variance computed by the LSTM as follows 9 | ``` 10 | h_t, c_t = prior_lstm(z_t-1, (h_t, c_t)) where h_t is the hidden state and c_t is the cell state 11 | ``` 12 | Now the hidden state ```h_t``` is used to compute the mean and variance of ```z_t``` using an affine transform 13 | ``` 14 | z_mean, z_log_variance = affine_mean(h_t), affine_logvar(h_t) 15 | z = reparameterize(z_mean, z_log_variance) 16 | ``` 17 | The hidden state has dimension 512 and z has dimension 32 18 | 19 | ## Convolutional Encoder: 20 | 21 | The convolutional encoder consists of 4 convolutional layers with 256 layers and a kernel size of 5 22 | Each convolution is followed by a batch normalization layer and a LeakyReLU(0.2) nonlinearity. 
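As an illustration, here is a minimal PyTorch sketch of one of the intermediate strided encoder blocks (the repository implements this pattern as `ConvUnit` in `model.py`; the exact shapes below are only an example):

```
import torch
import torch.nn as nn

# One encoder block: convolution -> batch normalization -> LeakyReLU(0.2)
block = nn.Sequential(
    nn.Conv2d(256, 256, kernel_size=5, stride=2, padding=2),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2))

x = torch.randn(8, 256, 64, 64)  # (frames, channels, height, width)
print(block(x).shape)            # torch.Size([8, 256, 32, 32]) -- each strided block halves the spatial size
```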
23 | For the 3,64,64 frames (all image dimensions are in channel, width, height) in the sprites dataset the following dimension changes take place 24 | 25 | ```3,64,64 -> 256,64,64 -> 256,32,32 -> 256,16,16 -> 256,8,8 (where each -> consists of a convolution, batch normalization followed by LeakyReLU(0.2))``` 26 | 27 | The 8,8,256 tensor is unrolled into a vector of size ```8*8*256``` which is then made to undergo the following tansformations 28 | 29 | ```8*8*256 -> 4096 -> 2048 (where each -> consists of an affine transformation, batch normalization followed by LeakyReLU(0.2)) ``` 30 | 31 | ## Approximate Posterior For f: 32 | 33 | The approximate posterior is parameterized by a bidirectional LSTM that takes the entire sequence of transformed ```x_t```s (after being fed into the convolutional encoder) 34 | as input in each timestep. The hidden layer dimension is 512 35 | 36 | Then the features from the unit corresponding to the last timestep of the forward LSTM and the unit corresponding to the first timestep of the 37 | backward LSTM (as shown in the diagram in the paper) are concatenated and fed to two affine layers (without any added nonlinearity) to compute 38 | the mean and variance of the Gaussian posterior for f 39 | 40 | ## Approximate Posterior for z (Factorized q) 41 | 42 | Each ```x_t``` is first fed into an affine layer followed by a LeakyReLU(0.2) nonlinearity to generate an intermediate feature vector of dimension 512, 43 | which is then followed by two affine layers (without any added nonlinearity) to compute the mean and variance of the Gaussian Posterior of each ```z_t``` 44 | 45 | ``` 46 | inter_t = intermediate_affine(x_t) 47 | z_mean_t, z_log_variance_t = affine_mean(inter_t), affine_logvar(inter_t) 48 | z = reparameterize(z_mean_t, z_log_variance_t) 49 | ``` 50 | 51 | ## Approximate Posterior for z (FULL q) 52 | 53 | The vector ```f``` is concatenated to each ```v_t``` where ```v_t``` is the encodings generated for each frame ```x_t``` by the convolutional encoder. This entire sequence is fed into a bi-LSTM 54 | of hidden layer dimension 512. Then the features of the forward and backward LSTMs are fed into an RNN having a hidden layer dimension 512. The output ```h_t``` of each timestep 55 | of this RNN transformed by two affine transformations (without any added nonlinearity) to compute the mean and variance of the Gaussian Posterior of each ```z_t``` 56 | 57 | ``` 58 | g_t = [v_t, f] for each timestep 59 | forward_features, backward_features = lstm(g_t for all timesteps) 60 | h_t = rnn([forward_features, backward_features]) 61 | z_mean_t, z_log_variance_t = affine_mean(h_t), affine_logvar(h_t) 62 | z = reparameterize(z_mean_t, z_log_variance_t) 63 | ``` 64 | 65 | ## Convolutional Decoder For Conditional Distribution 66 | 67 | The architecture is symmetric to that of the convolutional encoder. 
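As a rough sketch of that symmetry (shapes assumed from the description below; the repository builds this from `ConvUnitTranspose` blocks in `model.py`), each upsampling step is a transposed convolution followed by batch normalization and LeakyReLU(0.2):

```
import torch
import torch.nn as nn

# One decoder block: transposed convolution -> batch normalization -> LeakyReLU(0.2)
up_block = nn.Sequential(
    nn.ConvTranspose2d(256, 256, kernel_size=5, stride=2, padding=2, output_padding=1),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2))

x = torch.randn(8, 256, 8, 8)  # (frames, channels, height, width)
print(up_block(x).shape)       # torch.Size([8, 256, 16, 16]) -- each block doubles the spatial size
```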
The vector ```f``` is concatenated to each ```z_t```, which then undergoes two subsequent 68 | affine transforms, causing the following change in dimensions 69 | 70 | ```256 + 32 -> 4096 -> 8*8*256``` (where each -> consists of an affine transformation, batch normalization followed by LeakyReLU(0.2)) 71 | 72 | The ```8*8*256``` tensor is reshaped into a tensor of shape 256,8,8 and then undergoes the following dimension changes 73 | 74 | ```256,8,8 -> 256,16,16 -> 256,32,32 -> 256,64,64 -> 3,64,64``` (where each -> consists of a transposed convolution, batch normalization followed by LeakyReLU(0.2) 75 | with the exception of the last layer that does not have batchnorm and uses tanh nonlinearity) 76 | 77 | # Optimizer 78 | The model is trained with the Adam optimizer with a learning rate of 0.0002, betas of 0.9 and 0.999, with a batch size of 25 for 200 epochs 79 | 80 | # Hyperparameters: 81 | 82 | * Dimension of the content encoding f : 256 83 | * Dimension of the dynamics encoding of a frame z_t : 32 84 | * Number of frames in the video : 8 85 | * Dimension of the hidden states of the RNNs : 512 86 | * Nonlinearity used in convolutional and deconvolutional layers : LeakyReLU(0.2) in intermediate layers, Tanh in last layer of deconvolutional (Chosen arbitrarily, not stated in the paper) 87 | * Number of channels in the convolutional and deconvolutional layers : 256 88 | * Dimension of convolutional encoding generated from the video frames: 2048 (Chosen arbitrarily, not stated in the paper) 89 | -------------------------------------------------------------------------------- /classifier.py: -------------------------------------------------------------------------------- 1 | from tqdm import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data as data 5 | import torch.optim 6 | 7 | 8 | class Sprites(data.Dataset): 9 | def __init__(self, path, size): 10 | self.path = path 11 | self.length = size 12 | 13 | def __len__(self): 14 | return self.length 15 | 16 | def __getitem__(self, idx): 17 | item = torch.load(self.path+'/%d.sprite' % (idx+1)) 18 | return item['body'], item['shirt'], item['pant'], item['hair'], item['action'], item['sprite'] 19 | 20 | 21 | class SpriteClassifier(nn.Module): 22 | def __init__(self, n_bodies=7, n_shirts=4, n_pants=5, n_hairstyles=6, n_actions=3, 23 | num_frames=8, in_size=64, channels=64, code_dim=1024, hidden_dim=512, nonlinearity=None): 24 | super(SpriteClassifier, self).__init__() 25 | nl = nn.LeakyReLU(0.2) if nonlinearity is None else nonlinearity 26 | encoding_conv = [] 27 | encoding_conv.append(nn.Sequential(nn.Conv2d(3, channels, 5, 4, 1, bias=False), nl)) 28 | size = in_size // 4 29 | self.num_frames = num_frames 30 | while size > 4: 31 | encoding_conv.append(nn.Sequential( 32 | nn.Conv2d(channels, channels * 2, 5, 4, 1, bias=False), 33 | nn.BatchNorm2d(channels * 2), nl)) 34 | size = size // 4 35 | channels *= 2 36 | self.encoding_conv = nn.Sequential(*encoding_conv) 37 | self.final_size = size 38 | self.final_channels = channels 39 | self.code_dim = code_dim 40 | self.hidden_dim = hidden_dim 41 | self.encoding_fc = nn.Sequential( 42 | nn.Linear(size * size * channels, code_dim), 43 | nn.BatchNorm1d(code_dim), nl) 44 | # The last hidden state of a convolutional LSTM over the scenes is used for classification 45 | self.classifier_lstm = nn.LSTM(code_dim, hidden_dim, batch_first=True, bidirectional=False) 46 | self.body = nn.Sequential( 47 | nn.Linear(hidden_dim, hidden_dim // 2), 48 | nn.BatchNorm1d(hidden_dim // 2), nl, 49 | 
nn.Linear(hidden_dim // 2, n_bodies)) 50 | self.shirt = nn.Sequential( 51 | nn.Linear(hidden_dim, hidden_dim // 2), 52 | nn.BatchNorm1d(hidden_dim // 2), nl, 53 | nn.Linear(hidden_dim // 2, n_shirts)) 54 | self.pants = nn.Sequential( 55 | nn.Linear(hidden_dim, hidden_dim // 2), 56 | nn.BatchNorm1d(hidden_dim // 2), nl, 57 | nn.Linear(hidden_dim // 2, n_pants)) 58 | self.hairstyles = nn.Sequential( 59 | nn.Linear(hidden_dim, hidden_dim // 2), 60 | nn.BatchNorm1d(hidden_dim // 2), nl, 61 | nn.Linear(hidden_dim // 2, n_hairstyles)) 62 | self.action = nn.Sequential( 63 | nn.Linear(hidden_dim, hidden_dim // 2), 64 | nn.BatchNorm1d(hidden_dim // 2), nl, 65 | nn.Linear(hidden_dim // 2, n_actions)) 66 | 67 | def forward(self, x): 68 | x = x.view(-1, x.size(2), x.size(3), x.size(4)) 69 | x = self.encoding_conv(x) 70 | x = x.view(-1, self.final_channels * (self.final_size ** 2)) 71 | x = self.encoding_fc(x) 72 | x = x.view(-1, self.num_frames, self.code_dim) 73 | # Classifier output depends on last layer of LSTM: Can also change this to a bi-LSTM if required 74 | _, (hidden, _) = self.classifier_lstm(x) 75 | hidden = hidden.view(-1, self.hidden_dim) 76 | return self.body(hidden), self.shirt(hidden), self.pants(hidden), self.hairstyles(hidden), self.action(hidden) 77 | 78 | 79 | def save_model(model, optim, epoch, path): 80 | torch.save({ 81 | 'epoch': epoch + 1, 82 | 'state_dict': model.state_dict(), 83 | 'optimizer': optim.state_dict()}, path) 84 | 85 | def check_accuracy(model, test, device): 86 | total = 0 87 | correct_body = 0 88 | correct_shirt = 0 89 | correct_pant = 0 90 | correct_hair = 0 91 | correct_action = 0 92 | with torch.no_grad(): 93 | for item in test: 94 | body, shirt, pant, hair, action, image = item 95 | image = image.to(device) 96 | body = body.to(device) 97 | shirt = shirt.to(device) 98 | pant = pant.to(device) 99 | hair = hair.to(device) 100 | action = action.to(device) 101 | pred_body, pred_shirt, pred_pant, pred_hair, pred_action = model(image) 102 | _, pred_body = torch.max(pred_body.data, 1) 103 | _, pred_shirt = torch.max(pred_shirt.data, 1) 104 | _, pred_pant = torch.max(pred_pant.data, 1) 105 | _, pred_hair = torch.max(pred_hair.data, 1) 106 | _, pred_action = torch.max(pred_action.data, 1) 107 | total += body.size(0) 108 | correct_body += (pred_body == body).sum().item() 109 | correct_shirt += (pred_shirt == shirt).sum().item() 110 | correct_pant += (pred_pant == pant).sum().item() 111 | correct_hair += (pred_hair == hair).sum().item() 112 | correct_action += (pred_action == action).sum().item() 113 | print('Accuracy, Body : {} Shirt : {} Pant : {} Hair : {} Action {}'.format(correct_body/total, correct_shirt/total, correct_pant/total, correct_hair/total, correct_action/total)) 114 | 115 | 116 | def train_classifier(model, optim, dataset, device, epochs, path, test, start=0): 117 | model.train() 118 | criterion = nn.CrossEntropyLoss() 119 | for epoch in range(start, epochs): 120 | running_loss = 0.0 121 | for i, item in tqdm(enumerate(dataset, 1)): 122 | body, shirt, pant, hair, action, image = item 123 | image = image.to(device) 124 | body = body.to(device) 125 | shirt = shirt.to(device) 126 | pant = pant.to(device) 127 | hair = hair.to(device) 128 | action = action.to(device) 129 | pred_body, pred_shirt, pred_pant, pred_hair, pred_action = model(image) 130 | loss = criterion(pred_body, body) + criterion(pred_shirt, shirt) + criterion(pred_pant, pant) + criterion(pred_hair, hair) + criterion(pred_action, action) 131 | loss.backward() 132 | optim.step() 133 | 
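            optim.zero_grad()  # suggested addition: the surrounding loop never clears gradients, so without this call gradients from previous batches would accumulate across iterations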
running_loss += loss.item() 134 | print('Epoch {} Avg Loss {}'.format(epoch + 1, running_loss / i)) 135 | save_model(model, optim, epoch, path) 136 | check_accuracy(model, test, device) 137 | 138 | device = torch.device('cuda:0') 139 | model = SpriteClassifier() 140 | model.to(device) 141 | optim = torch.optim.Adam(model.parameters(), lr=0.0003) 142 | sprites_train = Sprites('./dataset/lpc-dataset/train', 6759) 143 | sprites_test = Sprites('./dataset/lpc-dataset/test', 801) 144 | loader = data.DataLoader(sprites_train, batch_size=32, shuffle=True, num_workers=4) 145 | loader_test = data.DataLoader(sprites_test, batch_size=64, shuffle=True, num_workers=4) 146 | train_classifier(model, optim, loader, device, 20, './checkpoint_classifier.pth', loader_test) 147 | -------------------------------------------------------------------------------- /dataset/create-lpc-dataset.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import time 3 | from selenium import webdriver 4 | from selenium.webdriver.support.ui import Select 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as transforms 9 | slice_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) 10 | def prepare_tensor(path): 11 | img = Image.open(path) 12 | img = img.convert("RGB") 13 | img = np.array(img) 14 | actions = { 15 | 'walk' : { 16 | 'range': [(9,10),(10,11),(11,12)], 17 | 'frames': [(0,1),(1,2),(2,3),(3,4),(4,5),(5,6),(6,7),(7,8)] 18 | }, 19 | 'spellcast': { 20 | 'range': [(1,2),(2,3),(3,4)], 21 | 'frames': [(0,1),(1,2),(2,3),(3,4),(4,5),(5,6),(6,7),(6,7)] 22 | }, 23 | 'slash': { 24 | 'range': [(14,15),(15,16),(16,17)], 25 | 'frames': [(0,1),(1,2),(2,3),(3,4),(4,5),(5,6),(5,6),(5,6)] 26 | } 27 | } 28 | slices = [] 29 | for action,params in actions.items(): 30 | for row in params['range']: 31 | sprite = [] 32 | for col in params['frames']: 33 | sprite.append(slice_transform(img[64*row[0]:64*row[1],64*col[0]:64*col[1],:])) 34 | slices.append(torch.stack(sprite)) 35 | return slices 36 | 37 | driver = webdriver.Firefox() 38 | driver.get("http://gaurav.munjal.us/Universal-LPC-Spritesheet-Character-Generator/") 39 | driver.maximize_window() 40 | 41 | bodies = ['light','dark','dark2','darkelf','darkelf2','tanned','tanned2'] 42 | shirts = ['longsleeve_brown','longsleeve_teal','longsleeve_maroon','longsleeve_white'] 43 | hairstyles = ['green','blue','pink','raven','white','dark_blonde'] 44 | pants = ['magenta','red','teal','white','robe_skirt'] 45 | train = 0 46 | test = 0 47 | for id_body, body in enumerate(bodies): 48 | driver.execute_script("return arguments[0].click();",driver.find_element_by_id('body-'+body)) 49 | time.sleep(0.5) 50 | for id_shirt, shirt in enumerate(shirts): 51 | driver.execute_script("return arguments[0].click();",driver.find_element_by_id('clothes-'+shirt)) 52 | time.sleep(0.5) 53 | for id_pant, pant in enumerate(pants): 54 | if pant=='robe_skirt': 55 | driver.execute_script("return arguments[0].click();",driver.find_element_by_id('legs-'+pant)) 56 | else: 57 | driver.execute_script("return arguments[0].click();",driver.find_element_by_id('legs-pants_'+pant)) 58 | time.sleep(0.5) 59 | for id_hair, hair in enumerate(hairstyles): 60 | driver.execute_script("return arguments[0].click();",driver.find_element_by_id('hair-plain_'+hair)) 61 | time.sleep(0.5) 62 | name = body+"_"+shirt+"_"+pant+"_"+hair 63 | print("Creating character: " + "'" + name) 64 | canvas = 
driver.find_element_by_id('spritesheet') 65 | canvas_base64 = driver.execute_script("return arguments[0].toDataURL('image/png').substring(21);",canvas) 66 | canvas_png = base64.b64decode(canvas_base64) 67 | with open(str(name) + ".png","wb") as f: 68 | f.write(canvas_png) 69 | slices = prepare_tensor(str(name) + ".png") 70 | p = torch.rand(1).item() <= 0.1 #Randomly add 10% of the characters created in the test set 71 | for id_action, sprites in enumerate(slices): 72 | if p is True: 73 | test += 1 74 | path = './lpc-dataset/test/%d.sprite' % test 75 | else: 76 | train += 1 77 | path = './lpc-dataset/train/%d.sprite' % train 78 | 79 | final_sprite = { 80 | 'body': id_body, 81 | 'shirt': id_shirt, 82 | 'pant': id_pant, 83 | 'hair': id_hair, 84 | 'action': id_action // 3, 85 | 'sprite': sprites 86 | } 87 | print(id_body,id_shirt,id_pant,id_hair,id_action // 3) 88 | torch.save(final_sprite, path) 89 | 90 | print("Dataset is Ready.Training Set Size : %d. Test Set Size : %d " % (train,test)) 91 | -------------------------------------------------------------------------------- /dataset/style-transfer-set.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import * 3 | import base64 4 | import time 5 | from selenium import webdriver 6 | from selenium.webdriver.support.ui import Select 7 | from PIL import Image 8 | import numpy as np 9 | import torch 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | slice_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) 13 | def prepare_tensor(path,save_path): 14 | img = Image.open(path) 15 | img = img.convert("RGB") 16 | img = np.array(img) 17 | actions = { 18 | 'walk' : { 19 | 'range': [(9,10),(10,11),(11,12)], 20 | 'frames': [(0,1),(1,2),(2,3),(3,4),(4,5),(5,6),(6,7),(7,8)] 21 | }, 22 | 'spellcast': { 23 | 'range': [(1,2),(2,3),(3,4)], 24 | 'frames': [(0,1),(1,2),(2,3),(3,4),(4,5),(5,6),(6,7),(6,7)] 25 | }, 26 | 'slash': { 27 | 'range': [(14,15),(15,16),(16,17)], 28 | 'frames': [(0,1),(1,2),(2,3),(3,4),(4,5),(5,6),(5,6),(5,6)] 29 | } 30 | } 31 | for action,params in actions.items(): 32 | i = 0 33 | for row in params['range']: 34 | sprite = [] 35 | for col in params['frames']: 36 | sprite.append(slice_transform(img[64*row[0]:64*row[1],64*col[0]:64*col[1],:])) 37 | os.makedirs(os.path.dirname(save_path+'/{}/{}.sprite'.format(action,i)),exist_ok=True) 38 | sprite_tensor = torch.stack(sprite) 39 | torch.save(sprite_tensor,save_path+'/{}/{}.sprite'.format(action,i)) 40 | torchvision.utils.save_image(sprite_tensor,save_path+'/{}/{}.png'.format(action,i)) 41 | i += 1 42 | 43 | 44 | driver = webdriver.Firefox() 45 | driver.get("http://gaurav.munjal.us/Universal-LPC-Spritesheet-Character-Generator/") 46 | driver.maximize_window() 47 | 48 | bodies = ['light','dark','dark2','darkelf','darkelf2','tanned','tanned2'] 49 | shirts = ['longsleeve_brown','longsleeve_teal','longsleeve_maroon','longsleeve_white'] 50 | hairstyles = ['green','blue','pink','raven','white','dark_blonde'] 51 | pants = ['magenta','red','teal','white','robe_skirt'] 52 | for body in tqdm(bodies): 53 | driver.execute_script("return arguments[0].click();",driver.find_element_by_id('body-'+body)) 54 | time.sleep(0.5) 55 | for shirt in shirts: 56 | driver.execute_script("return arguments[0].click();",driver.find_element_by_id('clothes-'+shirt)) 57 | time.sleep(0.5) 58 | for pant in pants: 59 | if pant=='robe_skirt': 60 | driver.execute_script("return 
arguments[0].click();",driver.find_element_by_id('legs-'+pant)) 61 | else: 62 | driver.execute_script("return arguments[0].click();",driver.find_element_by_id('legs-pants_'+pant)) 63 | time.sleep(0.5) 64 | for hair in hairstyles: 65 | driver.execute_script("return arguments[0].click();",driver.find_element_by_id('hair-plain_'+hair)) 66 | time.sleep(0.5) 67 | name = body+"_"+shirt+"_"+pant+"_"+hair 68 | canvas = driver.find_element_by_id('spritesheet') 69 | canvas_base64 = driver.execute_script("return arguments[0].toDataURL('image/png').substring(21);",canvas) 70 | canvas_png = base64.b64decode(canvas_base64) 71 | with open(str(name) + ".png","wb") as f: 72 | f.write(canvas_png) 73 | save_path = './style-transfer/{}/{}/{}/{}'.format(body,shirt,pant,hair) 74 | slices = prepare_tensor(str(name) + ".png",save_path) 75 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # A block consisting of convolution, batch normalization (optional) followed by a nonlinearity (defaults to Leaky ReLU) 7 | class ConvUnit(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel, stride=1, padding=0, batchnorm=True, nonlinearity=nn.LeakyReLU(0.2)): 9 | super(ConvUnit, self).__init__() 10 | if batchnorm is True: 11 | self.model = nn.Sequential( 12 | nn.Conv2d(in_channels, out_channels, kernel, stride, padding), 13 | nn.BatchNorm2d(out_channels), nonlinearity) 14 | else: 15 | self.model = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel, stride, padding), nonlinearity) 16 | 17 | def forward(self, x): 18 | return self.model(x) 19 | 20 | # A block consisting of a transposed convolution, batch normalization (optional) followed by a nonlinearity (defaults to Leaky ReLU) 21 | class ConvUnitTranspose(nn.Module): 22 | def __init__(self, in_channels, out_channels, kernel, stride=1, padding=0, out_padding=0, batchnorm=True, nonlinearity=nn.LeakyReLU(0.2)): 23 | super(ConvUnitTranspose, self).__init__() 24 | if batchnorm is True: 25 | self.model = nn.Sequential( 26 | nn.ConvTranspose2d(in_channels, out_channels, kernel, stride, padding, out_padding), 27 | nn.BatchNorm2d(out_channels), nonlinearity) 28 | else: 29 | self.model = nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, kernel, stride, padding, out_padding), nonlinearity) 30 | 31 | def forward(self, x): 32 | return self.model(x) 33 | 34 | # A block consisting of an affine layer, batch normalization (optional) followed by a nonlinearity (defaults to Leaky ReLU) 35 | class LinearUnit(nn.Module): 36 | def __init__(self, in_features, out_features, batchnorm=True, nonlinearity=nn.LeakyReLU(0.2)): 37 | super(LinearUnit, self).__init__() 38 | if batchnorm is True: 39 | self.model = nn.Sequential( 40 | nn.Linear(in_features, out_features), 41 | nn.BatchNorm1d(out_features), nonlinearity) 42 | else: 43 | self.model = nn.Sequential( 44 | nn.Linear(in_features, out_features), nonlinearity) 45 | 46 | def forward(self, x): 47 | return self.model(x) 48 | 49 | 50 | class DisentangledVAE(nn.Module): 51 | """ 52 | Network Architecture: 53 | PRIOR OF Z: 54 | The prior of z is a Gaussian with mean and variance computed by the LSTM as follows 55 | h_t, c_t = prior_lstm(z_t-1, (h_t, c_t)) where h_t is the hidden state and c_t is the cell state 56 | Now the hidden state h_t is used to compute the mean and variance of z_t using an affine transform 
57 | z_mean, z_log_variance = affine_mean(h_t), affine_logvar(h_t) 58 | z = reparameterize(z_mean, z_log_variance) 59 | The hidden state has dimension 512 and z has dimension 32 60 | 61 | CONVOLUTIONAL ENCODER: 62 | The convolutional encoder consists of 4 convolutional layers with 256 layers and a kernel size of 5 63 | Each convolution is followed by a batch normalization layer and a LeakyReLU(0.2) nonlinearity. 64 | For the 3,64,64 frames (all image dimensions are in channel, width, height) in the sprites dataset the following dimension changes take place 65 | 66 | 3,64,64 -> 256,64,64 -> 256,32,32 -> 256,16,16 -> 256,8,8 (where each -> consists of a convolution, batch normalization followed by LeakyReLU(0.2)) 67 | 68 | The 8,8,256 tensor is unrolled into a vector of size 8*8*256 which is then made to undergo the following tansformations 69 | 70 | 8*8*256 -> 4096 -> 2048 (where each -> consists of an affine transformation, batch normalization followed by LeakyReLU(0.2)) 71 | 72 | APPROXIMATE POSTERIOR FOR f: 73 | The approximate posterior is parameterized by a bidirectional LSTM that takes the entire sequence of transformed x_ts (after being fed into the convolutional encoder) 74 | as input in each timestep. The hidden layer dimension is 512 75 | 76 | Then the features from the unit corresponding to the last timestep of the forward LSTM and the unit corresponding to the first timestep of the 77 | backward LSTM (as shown in the diagram in the paper) are concatenated and fed to two affine layers (without any added nonlinearity) to compute 78 | the mean and variance of the Gaussian posterior for f 79 | 80 | APPROXIMATE POSTERIOR FOR z (FACTORIZED q) 81 | Each x_t is first fed into an affine layer followed by a LeakyReLU(0.2) nonlinearity to generate an intermediate feature vector of dimension 512, 82 | which is then followed by two affine layers (without any added nonlinearity) to compute the mean and variance of the Gaussian Posterior of each z_t 83 | 84 | inter_t = intermediate_affine(x_t) 85 | z_mean_t, z_log_variance_t = affine_mean(inter_t), affine_logvar(inter_t) 86 | z = reparameterize(z_mean_t, z_log_variance_t) 87 | 88 | APPROXIMATE POSTERIOR FOR z (FULL q) 89 | The vector f is concatenated to each v_t where v_t is the encodings generated for each frame x_t by the convolutional encoder. This entire sequence is fed into a bi-LSTM 90 | of hidden layer dimension 512. Then the features of the forward and backward LSTMs are fed into an RNN having a hidden layer dimension 512. The output h_t of each timestep 91 | of this RNN transformed by two affine transformations (without any added nonlinearity) to compute the mean and variance of the Gaussian Posterior of each z_t 92 | 93 | g_t = [v_t, f] for each timestep 94 | forward_features, backward_features = lstm(g_t for all timesteps) 95 | h_t = rnn([forward_features, backward_features]) 96 | z_mean_t, z_log_variance_t = affine_mean(h_t), affine_logvar(h_t) 97 | z = reparameterize(z_mean_t, z_log_variance_t) 98 | 99 | CONVOLUTIONAL DECODER FOR CONDITIONAL DISTRIBUTION p(x_t | f, z_t) 100 | The architecture is symmetric to that of the convolutional encoder. 
The vector f is concatenated to each z_t, which then undergoes two subsequent 101 | affine transforms, causing the following change in dimensions 102 | 103 | 256 + 32 -> 4096 -> 8*8*256 (where each -> consists of an affine transformation, batch normalization followed by LeakyReLU(0.2)) 104 | 105 | The 8*8*256 tensor is reshaped into a tensor of shape 256,8,8 and then undergoes the following dimension changes 106 | 107 | 256,8,8 -> 256,16,16 -> 256,32,32 -> 256,64,64 -> 3,64,64 (where each -> consists of a transposed convolution, batch normalization followed by LeakyReLU(0.2) 108 | with the exception of the last layer that does not have batchnorm and uses tanh nonlinearity) 109 | 110 | Hyperparameters: 111 | f_dim: Dimension of the content encoding f. f has the shape (batch_size, f_dim) 112 | z_dim: Dimension of the dynamics encoding of a frame z_t. z has the shape (batch_size, frames, z_dim) 113 | frames: Number of frames in the video. 114 | hidden_dim: Dimension of the hidden states of the RNNs 115 | nonlinearity: Nonlinearity used in convolutional and deconvolutional layers, defaults to LeakyReLU(0.2) 116 | in_size: Height and width of each frame in the video (assumed square) 117 | step: Number of channels in the convolutional and deconvolutional layers 118 | conv_dim: The convolutional encoder converts each frame into an intermediate encoding vector of size conv_dim, i.e, 119 | The initial video tensor (batch_size, frames, num_channels, in_size, in_size) is converted to (batch_size, frames, conv_dim) 120 | factorised: Toggles between full and factorised posterior for z as discussed in the paper 121 | 122 | Optimization: 123 | The model is trained with the Adam optimizer with a learning rate of 0.0002, betas of 0.9 and 0.999, with a batch size of 25 for 200 epochs 124 | 125 | """ 126 | def __init__(self, f_dim=256, z_dim=32, conv_dim=2048, step=256, in_size=64, hidden_dim=512, 127 | frames=8, nonlinearity=None, factorised=False, device=torch.device('cpu')): 128 | super(DisentangledVAE, self).__init__() 129 | self.device = device 130 | self.f_dim = f_dim 131 | self.z_dim = z_dim 132 | self.frames = frames 133 | self.conv_dim = conv_dim 134 | self.hidden_dim = hidden_dim 135 | self.factorised = factorised 136 | self.step = step 137 | self.in_size = in_size 138 | nl = nn.LeakyReLU(0.2) if nonlinearity is None else nonlinearity 139 | 140 | # Prior of content is a uniform Gaussian and prior of the dynamics is an LSTM 141 | self.z_prior_lstm = nn.LSTMCell(self.z_dim, self.hidden_dim) 142 | self.z_prior_mean = nn.Linear(self.hidden_dim, self.z_dim) 143 | self.z_prior_logvar = nn.Linear(self.hidden_dim, self.z_dim) 144 | # POSTERIOR DISTRIBUTION NETWORKS 145 | # ------------------------------- 146 | self.f_lstm = nn.LSTM(self.conv_dim, self.hidden_dim, 1, 147 | bidirectional=True, batch_first=True) 148 | # TODO: Check if only one affine transform is sufficient. Paper says distribution is parameterised by LSTM 149 | self.f_mean = LinearUnit(self.hidden_dim * 2, self.f_dim, False) 150 | self.f_logvar = LinearUnit(self.hidden_dim * 2, self.f_dim, False) 151 | 152 | if self.factorised is True: 153 | # Paper says : 1 Hidden Layer MLP. Last layers shouldn't have any nonlinearities 154 | self.z_inter = LinearUnit(self.conv_dim, self.hidden_dim, batchnorm=False) 155 | self.z_mean = nn.Linear(self.hidden_dim, self.z_dim) 156 | self.z_logvar = nn.Linear(self.hidden_dim, self.z_dim) 157 | else: 158 | # TODO: Check if one affine transform is sufficient. 
Paper says distribution is parameterised by RNN over LSTM. Last layer shouldn't have any nonlinearities 159 | self.z_lstm = nn.LSTM(self.conv_dim + self.f_dim, self.hidden_dim, 1, bidirectional=True, batch_first=True) 160 | self.z_rnn = nn.RNN(self.hidden_dim * 2, self.hidden_dim, batch_first=True) 161 | # Each timestep is for each z so no reshaping and feature mixing 162 | self.z_mean = nn.Linear(self.hidden_dim, self.z_dim) 163 | self.z_logvar = nn.Linear(self.hidden_dim, self.z_dim) 164 | 165 | self.conv = nn.Sequential( 166 | ConvUnit(3, step, 5, 1, 2), # 3*64*64 -> 256*64*64 167 | ConvUnit(step, step, 5, 2, 2), # 256,64,64 -> 256,32,32 168 | ConvUnit(step, step, 5, 2, 2), # 256,32,32 -> 256,16,16 169 | ConvUnit(step, step, 5, 2, 2), # 256,16,16 -> 256,8,8 170 | ) 171 | self.final_conv_size = in_size // 8 172 | self.conv_fc = nn.Sequential(LinearUnit(step * (self.final_conv_size ** 2), self.conv_dim * 2), 173 | LinearUnit(self.conv_dim * 2, self.conv_dim)) 174 | 175 | self.deconv_fc = nn.Sequential(LinearUnit(self.f_dim + self.z_dim, self.conv_dim * 2, False), 176 | LinearUnit(self.conv_dim * 2, step * (self.final_conv_size ** 2), False)) 177 | self.deconv = nn.Sequential( 178 | ConvUnitTranspose(step, step, 5, 2, 2, 1), 179 | ConvUnitTranspose(step, step, 5, 2, 2, 1), 180 | ConvUnitTranspose(step, step, 5, 2, 2, 1), 181 | ConvUnitTranspose(step, 3, 5, 1, 2, 0, nonlinearity=nn.Tanh())) 182 | 183 | for m in self.modules(): 184 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 185 | nn.init.constant_(m.weight, 1) 186 | nn.init.constant_(m.bias, 1) 187 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): 188 | nn.init.kaiming_normal_(m.weight) 189 | 190 | # If random sampling is true, reparametrization occurs else z_t is just set to the mean 191 | def sample_z(self, batch_size, random_sampling=True): 192 | z_out = None # This will ultimately store all z_s in the format [batch_size, frames, z_dim] 193 | z_means = None 194 | z_logvars = None 195 | 196 | # All states are initially set to 0, especially z_0 = 0 197 | z_t = torch.zeros(batch_size, self.z_dim, device=self.device) 198 | z_mean_t = torch.zeros(batch_size, self.z_dim, device=self.device) 199 | z_logvar_t = torch.zeros(batch_size, self.z_dim, device=self.device) 200 | h_t = torch.zeros(batch_size, self.hidden_dim, device=self.device) 201 | c_t = torch.zeros(batch_size, self.hidden_dim, device=self.device) 202 | 203 | for _ in range(self.frames): 204 | h_t, c_t = self.z_prior_lstm(z_t, (h_t, c_t)) 205 | z_mean_t = self.z_prior_mean(h_t) 206 | z_logvar_t = self.z_prior_logvar(h_t) 207 | z_t = self.reparameterize(z_mean_t, z_logvar_t, random_sampling) 208 | if z_out is None: 209 | # If z_out is none it means z_t is z_1, hence store it in the format [batch_size, 1, z_dim] 210 | z_out = z_t.unsqueeze(1) 211 | z_means = z_mean_t.unsqueeze(1) 212 | z_logvars = z_logvar_t.unsqueeze(1) 213 | else: 214 | # If z_out is not none, z_t is not the initial z and hence append it to the previous z_ts collected in z_out 215 | z_out = torch.cat((z_out, z_t.unsqueeze(1)), dim=1) 216 | z_means = torch.cat((z_means, z_mean_t.unsqueeze(1)), dim=1) 217 | z_logvars = torch.cat((z_logvars, z_logvar_t.unsqueeze(1)), dim=1) 218 | 219 | return z_means, z_logvars, z_out 220 | 221 | 222 | def encode_frames(self, x): 223 | # The frames are unrolled into the batch dimension for batch processing such that x goes from 224 | # [batch_size, frames, channels, size, size] to [batch_size * frames, channels, 
size, size] 225 | x = x.view(-1, 3, self.in_size, self.in_size) 226 | x = self.conv(x) 227 | x = x.view(-1, self.step * (self.final_conv_size ** 2)) 228 | x = self.conv_fc(x) 229 | # The frame dimension is reintroduced and x shape becomes [batch_size, frames, conv_dim] 230 | # This technique is repeated at several points in the code 231 | x = x.view(-1, self.frames, self.conv_dim) 232 | return x 233 | 234 | def decode_frames(self, zf): 235 | x = self.deconv_fc(zf) 236 | x = x.view(-1, self.step, self.final_conv_size, self.final_conv_size) 237 | x = self.deconv(x) 238 | return x.view(-1, self.frames, 3, self.in_size, self.in_size) 239 | 240 | def reparameterize(self, mean, logvar, random_sampling=True): 241 | # Reparametrization occurs only if random sampling is set to true, otherwise mean is returned 242 | if random_sampling is True: 243 | eps = torch.randn_like(logvar) 244 | std = torch.exp(0.5*logvar) 245 | z = mean + eps*std 246 | return z 247 | else: 248 | return mean 249 | 250 | def encode_f(self, x): 251 | lstm_out, _ = self.f_lstm(x) 252 | # The features of the last timestep of the forward RNN is stored at the end of lstm_out in the first half, and the features 253 | # of the "first timestep" of the backward RNN is stored at the beginning of lstm_out in the second half 254 | # For a detailed explanation, check: https://gist.github.com/ceshine/bed2dadca48fe4fe4b4600ccce2fd6e1 255 | backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim] 256 | frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim] 257 | lstm_out = torch.cat((frontal, backward), dim=1) 258 | mean = self.f_mean(lstm_out) 259 | logvar = self.f_logvar(lstm_out) 260 | return mean, logvar, self.reparameterize(mean, logvar, self.training) 261 | 262 | def encode_z(self, x, f): 263 | if self.factorised is True: 264 | features = self.z_inter(x) 265 | else: 266 | # The expansion is done to match the dimension of x and f, used for concatenating f to each x_t 267 | f_expand = f.unsqueeze(1).expand(-1, self.frames, self.f_dim) 268 | lstm_out, _ = self.z_lstm(torch.cat((x, f_expand), dim=2)) 269 | features, _ = self.z_rnn(lstm_out) 270 | mean = self.z_mean(features) 271 | logvar = self.z_logvar(features) 272 | return mean, logvar, self.reparameterize(mean, logvar, self.training) 273 | 274 | def forward(self, x): 275 | z_mean_prior, z_logvar_prior, _ = self.sample_z(x.size(0), random_sampling=self.training) 276 | conv_x = self.encode_frames(x) 277 | f_mean, f_logvar, f = self.encode_f(conv_x) 278 | z_mean, z_logvar, z = self.encode_z(conv_x, f) 279 | f_expand = f.unsqueeze(1).expand(-1, self.frames, self.f_dim) 280 | zf = torch.cat((z, f_expand), dim=2) 281 | recon_x = self.decode_frames(zf) 282 | return f_mean, f_logvar, f, z_mean, z_logvar, z, z_mean_prior, z_logvar_prior, recon_x 283 | -------------------------------------------------------------------------------- /test/cosine-similarity/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/image2.png 
-------------------------------------------------------------------------------- /test/cosine-similarity/set1/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set1/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/set1/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set1/image2.png -------------------------------------------------------------------------------- /test/cosine-similarity/set10/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set10/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/set10/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set10/image2.png -------------------------------------------------------------------------------- /test/cosine-similarity/set11/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set11/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/set11/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set11/image2.png -------------------------------------------------------------------------------- /test/cosine-similarity/set12/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set12/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/set12/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set12/image2.png -------------------------------------------------------------------------------- /test/cosine-similarity/set2/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set2/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/set2/image2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set2/image2.png -------------------------------------------------------------------------------- /test/cosine-similarity/set3/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set3/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/set3/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set3/image2.png -------------------------------------------------------------------------------- /test/cosine-similarity/set4/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set4/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/set4/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set4/image2.png -------------------------------------------------------------------------------- /test/cosine-similarity/set5/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set5/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/set5/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set5/image2.png -------------------------------------------------------------------------------- /test/cosine-similarity/set6/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set6/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/set6/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set6/image2.png -------------------------------------------------------------------------------- /test/cosine-similarity/set7/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set7/image1.png -------------------------------------------------------------------------------- 
/test/cosine-similarity/set7/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set7/image2.png -------------------------------------------------------------------------------- /test/cosine-similarity/set8/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set8/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/set8/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set8/image2.png -------------------------------------------------------------------------------- /test/cosine-similarity/set9/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set9/image1.png -------------------------------------------------------------------------------- /test/cosine-similarity/set9/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/cosine-similarity/set9/image2.png -------------------------------------------------------------------------------- /test/style-transfer/set1/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set1/image1.png -------------------------------------------------------------------------------- /test/style-transfer/set1/image1_body_image2_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set1/image1_body_image2_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set1/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set1/image2.png -------------------------------------------------------------------------------- /test/style-transfer/set1/image2_body_image1_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set1/image2_body_image1_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set2/image1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set2/image1.png -------------------------------------------------------------------------------- /test/style-transfer/set2/image1_body_image2_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set2/image1_body_image2_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set2/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set2/image2.png -------------------------------------------------------------------------------- /test/style-transfer/set2/image2_body_image1_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set2/image2_body_image1_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set3/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set3/image1.png -------------------------------------------------------------------------------- /test/style-transfer/set3/image1_body_image2_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set3/image1_body_image2_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set3/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set3/image2.png -------------------------------------------------------------------------------- /test/style-transfer/set3/image2_body_image1_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set3/image2_body_image1_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set4/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set4/image1.png -------------------------------------------------------------------------------- /test/style-transfer/set4/image1_body_image2_motion.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set4/image1_body_image2_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set4/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set4/image2.png -------------------------------------------------------------------------------- /test/style-transfer/set4/image2_body_image1_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set4/image2_body_image1_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set5/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set5/image1.png -------------------------------------------------------------------------------- /test/style-transfer/set5/image1_body_image2_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set5/image1_body_image2_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set5/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set5/image2.png -------------------------------------------------------------------------------- /test/style-transfer/set5/image2_body_image1_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set5/image2_body_image1_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set6/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set6/image1.png -------------------------------------------------------------------------------- /test/style-transfer/set6/image1_body_image2_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set6/image1_body_image2_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set6/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set6/image2.png 
-------------------------------------------------------------------------------- /test/style-transfer/set6/image2_body_image1_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set6/image2_body_image1_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set7/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set7/image1.png -------------------------------------------------------------------------------- /test/style-transfer/set7/image1_body_image2_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set7/image1_body_image2_motion.png -------------------------------------------------------------------------------- /test/style-transfer/set7/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set7/image2.png -------------------------------------------------------------------------------- /test/style-transfer/set7/image2_body_image1_motion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yatindandi/Disentangled-Sequential-Autoencoder/4650724de2d8804413b566631f32e2a7bccf52c7/test/style-transfer/set7/image2_body_image1_motion.png -------------------------------------------------------------------------------- /test/test_similarity.py: -------------------------------------------------------------------------------- 1 | from disVAE import FullQDisentangledVAE 2 | import torch 3 | vae = FullQDisentangledVAE(frames=8, f_dim=64, z_dim=32, hidden_dim=512, conv_dim=1024) 4 | device = torch.device('cuda:0') 5 | vae.to(device) 6 | checkpoint = torch.load('disentangled-vae.model') 7 | vae.load_state_dict(checkpoint['state_dict']) 8 | vae.eval() 9 | 10 | for imageset in ('set1', 'set2', 'set3', 'set4', 'set5', 'set6', 'set7', 'set8', 'set9', 'set10', 'set11', 'set12'): 11 | print(imageset) 12 | path = './cosine-similarity/'+imageset+'/' 13 | image1 = torch.load(path + 'image1.sprite') 14 | image2 = torch.load(path + 'image2.sprite') 15 | image1 = image1.to(device) 16 | image2 = image2.to(device) 17 | image1 = torch.unsqueeze(image1,0) 18 | image2= torch.unsqueeze(image2,0) 19 | with torch.no_grad(): 20 | conv1 = vae.encode_frames(image1) 21 | conv2 = vae.encode_frames(image2) 22 | 23 | _,_,image1_f = vae.encode_f(conv1) 24 | _,_,image1_z = vae.encode_z(conv1,image1_f) 25 | 26 | image1_f = image1_f.view(64) 27 | image1_z = image1_z.view(256) 28 | 29 | _,_,image2_f = vae.encode_f(conv2) 30 | _,_,image2_z = vae.encode_z(conv2,image2_f) 31 | image2_f = image2_f.view(64) 32 | image2_z = image2_z.view(256) 33 | 34 | similarity_f = image1_f.dot(image2_f) / (image1_f.norm(2) * image2_f.norm(2)) 35 | similarity_z = image1_z.dot(image2_z) / (image1_z.norm(2) * image2_z.norm(2)) 36 | print('{} : Cosine similarity of f : {} Cosine similarity of z : {}'.format(imageset, 
37 | 
38 | 
39 | 
40 | 
41 | 
42 | 
43 | 
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import torch.utils.data
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | import numpy as np
8 | from model import *
9 | from tqdm import *
10 | from dataset import *
11 | 
12 | __all__ = ['loss_fn', 'Trainer']
13 | 
14 | 
15 | def loss_fn(original_seq, recon_seq, f_mean, f_logvar, z_post_mean, z_post_logvar, z_prior_mean, z_prior_logvar):
16 |     """
17 |     The loss consists of 3 parts: the reconstruction term, i.e. the MSE loss between the generated and the original images,
18 |     the KL divergence of f, and the sum over the KL divergences of each z_t, with the total divided by batch_size:
19 | 
20 |     Loss = {mse + KL of f + sum(KL of z_t)} / batch_size
21 |     The prior of f is a spherical zero-mean, unit-variance Gaussian, and the prior of each z_t is a Gaussian whose mean and variance
22 |     are given by the LSTM (a numerical sanity check of the closed-form KL terms is sketched after this file).
23 |     """
24 |     batch_size = original_seq.size(0)
25 |     mse = F.mse_loss(recon_seq, original_seq, reduction='sum')
26 |     kld_f = -0.5 * torch.sum(1 + f_logvar - torch.pow(f_mean, 2) - torch.exp(f_logvar))  # closed-form KL(q(f|x_1:T) || N(0, I))
27 |     z_post_var = torch.exp(z_post_logvar)
28 |     z_prior_var = torch.exp(z_prior_logvar)
29 |     kld_z = 0.5 * torch.sum(z_prior_logvar - z_post_logvar + ((z_post_var + torch.pow(z_post_mean - z_prior_mean, 2)) / z_prior_var) - 1)  # KL between the diagonal Gaussians q(z_t|x_1:T) and p(z_t|z_<t), summed over t
30 |     return (mse + kld_f + kld_z) / batch_size, kld_f / batch_size, kld_z / batch_size
31 | 
32 | 
33 | class Trainer(object):
34 |     def __init__(self, model, train, test, trainloader, testloader, test_f_expand,
35 |                  epochs=100, batch_size=64, learning_rate=0.001, nsamples=1, sample_path='./sample',
36 |                  recon_path='./recon/', transfer_path='./transfer/',
37 |                  checkpoints='model.pth', style1='image1.sprite', style2='image2.sprite', device=torch.device('cuda:0')):
38 |         self.trainloader = trainloader
39 |         self.train = train
40 |         self.test = test
41 |         self.testloader = testloader
42 |         self.start_epoch = 0
43 |         self.epochs = epochs
44 |         self.device = device
45 |         self.batch_size = batch_size
46 |         self.model = model
47 |         self.model.to(device)
48 |         self.learning_rate = learning_rate
49 |         self.checkpoints = checkpoints
50 |         self.optimizer = optim.Adam(self.model.parameters(), self.learning_rate)
51 |         self.samples = nsamples
52 |         self.sample_path = sample_path
53 |         self.recon_path = recon_path
54 |         self.transfer_path = transfer_path
55 |         self.test_f_expand = test_f_expand
56 |         self.epoch_losses = []
57 | 
58 |         self.image1 = torch.load(self.transfer_path + style1)['sprite']  # use the style1/style2 arguments rather than hard-coded filenames
59 |         self.image2 = torch.load(self.transfer_path + style2)['sprite']
60 |         self.image1 = self.image1.to(device)
61 |         self.image2 = self.image2.to(device)
62 |         self.image1 = torch.unsqueeze(self.image1, 0)
63 |         self.image2 = torch.unsqueeze(self.image2, 0)
64 | 
65 |     def save_checkpoint(self, epoch):
66 |         torch.save({
67 |             'epoch': epoch + 1,
68 |             'state_dict': self.model.state_dict(),
69 |             'optimizer': self.optimizer.state_dict(),
70 |             'losses': self.epoch_losses},
71 |             self.checkpoints)
72 | 
73 |     def load_checkpoint(self):
74 |         try:
75 |             print("Loading Checkpoint from '{}'".format(self.checkpoints))
76 |             checkpoint = torch.load(self.checkpoints)
77 |             self.start_epoch = checkpoint['epoch']
78 |             self.model.load_state_dict(checkpoint['state_dict'])
79 |             self.optimizer.load_state_dict(checkpoint['optimizer'])
80 |             self.epoch_losses = checkpoint['losses']
81 |             print("Resuming Training From Epoch {}".format(self.start_epoch))
82 |         except FileNotFoundError:  # no checkpoint on disk yet, so start from scratch
83 |             print("No Checkpoint Exists At '{}'. Starting Fresh Training".format(self.checkpoints))
84 |             self.start_epoch = 0
85 | 
86 |     def sample_frames(self, epoch):
87 |         with torch.no_grad():
88 |             _, _, test_z = self.model.sample_z(1, random_sampling=False)
89 |             print(test_z.shape)
90 |             print(self.test_f_expand.shape)
91 |             test_zf = torch.cat((test_z, self.test_f_expand), dim=2)
92 |             recon_x = self.model.decode_frames(test_zf)
93 |             recon_x = recon_x.view(self.samples * 8, 3, 64, 64)
94 |             torchvision.utils.save_image(recon_x, '%s/epoch%d.png' % (self.sample_path, epoch))
95 | 
96 |     def recon_frame(self, epoch, original):
97 |         with torch.no_grad():
98 |             _, _, _, _, _, _, _, _, recon = self.model(original)
99 |             image = torch.cat((original, recon), dim=0)
100 |             image = image.view(2 * 8, 3, 64, 64)
101 |             os.makedirs(os.path.dirname('%s/epoch%d.png' % (self.recon_path, epoch)), exist_ok=True)
102 |             torchvision.utils.save_image(image, '%s/epoch%d.png' % (self.recon_path, epoch))
103 | 
104 |     def style_transfer(self, epoch):
105 |         with torch.no_grad():
106 |             conv1 = self.model.encode_frames(self.image1)
107 |             conv2 = self.model.encode_frames(self.image2)
108 |             _, _, image1_f = self.model.encode_f(conv1)
109 |             image1_f_expand = image1_f.unsqueeze(1).expand(-1, self.model.frames, self.model.f_dim)
110 |             _, _, image1_z = self.model.encode_z(conv1, image1_f)
111 |             _, _, image2_f = self.model.encode_f(conv2)
112 |             image2_f_expand = image2_f.unsqueeze(1).expand(-1, self.model.frames, self.model.f_dim)
113 |             _, _, image2_z = self.model.encode_z(conv2, image2_f)
114 |             image1swap_zf = torch.cat((image2_z, image1_f_expand), dim=2)  # body (f) of image1 combined with motion (z) of image2
115 |             image1_body_image2_motion = self.model.decode_frames(image1swap_zf)
116 |             image1_body_image2_motion = torch.squeeze(image1_body_image2_motion, 0)
117 |             image2swap_zf = torch.cat((image1_z, image2_f_expand), dim=2)  # body (f) of image2 combined with motion (z) of image1
118 |             image2_body_image1_motion = self.model.decode_frames(image2swap_zf)
119 |             image2_body_image1_motion = torch.squeeze(image2_body_image1_motion, 0)
120 |             image1 = torch.squeeze(self.image1, 0)
121 |             image2 = torch.squeeze(self.image2, 0)
122 |             os.makedirs(os.path.dirname('%s/epoch%d/image1_body_image2_motion.png' % (self.transfer_path, epoch)), exist_ok=True)
123 |             torchvision.utils.save_image(image1, '%s/epoch%d/image1.png' % (self.transfer_path, epoch))
124 |             torchvision.utils.save_image(image2, '%s/epoch%d/image2.png' % (self.transfer_path, epoch))
125 |             torchvision.utils.save_image(image1_body_image2_motion, '%s/epoch%d/image1_body_image2_motion.png' % (self.transfer_path, epoch))
126 |             torchvision.utils.save_image(image2_body_image1_motion, '%s/epoch%d/image2_body_image1_motion.png' % (self.transfer_path, epoch))
127 | 
128 |     def train_model(self):
129 |         self.model.train()
130 |         for epoch in range(self.start_epoch, self.epochs):
131 |             losses = []
132 |             kld_fs = []
133 |             kld_zs = []
134 |             print("Running Epoch : {}".format(epoch + 1))
135 |             for i, dataitem in tqdm(enumerate(self.trainloader, 1)):
136 |                 _, _, _, _, _, _, data = dataitem
137 |                 data = data.to(self.device)
138 |                 self.optimizer.zero_grad()
139 |                 f_mean, f_logvar, f, z_post_mean, z_post_logvar, z, z_prior_mean, z_prior_logvar, recon_x = self.model(data)
140 |                 loss, kld_f, kld_z = loss_fn(data, recon_x, f_mean, f_logvar, z_post_mean, z_post_logvar, z_prior_mean, z_prior_logvar)
141 |                 loss.backward()
142 |                 self.optimizer.step()
143 |                 losses.append(loss.item())
144 |                 kld_fs.append(kld_f.item())
145 |                 kld_zs.append(kld_z.item())
146 |             meanloss = np.mean(losses)
147 |             meanf = np.mean(kld_fs)
148 |             meanz = np.mean(kld_zs)
149 |             self.epoch_losses.append(meanloss)
150 |             print("Epoch {} : Average Loss: {} KL of f : {} KL of z : {}".format(epoch + 1, meanloss, meanf, meanz))
151 |             self.save_checkpoint(epoch)
152 |             self.model.eval()
153 |             self.sample_frames(epoch + 1)
154 |             _, _, _, _, _, _, sample = self.test[int(torch.randint(0, len(self.test), (1,)).item())]
155 |             sample = torch.unsqueeze(sample, 0)
156 |             sample = sample.to(self.device)
157 |             self.recon_frame(epoch + 1, sample)
158 |             self.style_transfer(epoch + 1)
159 |             self.model.train()
160 |         print("Training is complete")
161 | 
162 | sprite = Sprites('./dataset/lpc-dataset/train', 6767)
163 | sprite_test = Sprites('./dataset/lpc-dataset/test', 791)
164 | batch_size = 25
165 | loader = torch.utils.data.DataLoader(sprite, batch_size=batch_size, shuffle=True, num_workers=4)
166 | device = torch.device('cuda:1')
167 | vae = DisentangledVAE(f_dim=256, z_dim=32, step=256, factorised=True, device=device)
168 | test_f = torch.rand(1, 256, device=device)
169 | test_f = test_f.unsqueeze(1).expand(1, 8, 256)
170 | trainer = Trainer(vae, sprite, sprite_test, loader, None, test_f, batch_size=25, epochs=500, learning_rate=0.0002, device=device)
171 | trainer.load_checkpoint()
172 | trainer.train_model()
173 | 
--------------------------------------------------------------------------------
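
Note on the KL terms in `loss_fn` (an illustrative sketch, not a file in this repository): `kld_f` and `kld_z` above are the standard closed-form KL divergences between diagonal Gaussians, namely KL(q(f|x_1:T) || N(0, I)) and the sum over t of KL(q(z_t|x_1:T) || p(z_t|z_<t)). The minimal script below, with tensor shapes chosen arbitrarily for illustration, checks those closed-form expressions against `torch.distributions.kl_divergence`.

```python
# Illustrative sanity check (not part of the repository): the closed-form KL terms
# used in trainer.loss_fn should agree with torch.distributions.kl_divergence.
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
d = torch.float64  # double precision so both computations agree to allclose's default tolerance
f_mean, f_logvar = torch.randn(4, 256, dtype=d), torch.randn(4, 256, dtype=d)            # q(f|x_1:T), one f per sequence
z_post_mean, z_post_logvar = torch.randn(4, 8, 32, dtype=d), torch.randn(4, 8, 32, dtype=d)    # q(z_t|x_1:T), one z per frame
z_prior_mean, z_prior_logvar = torch.randn(4, 8, 32, dtype=d), torch.randn(4, 8, 32, dtype=d)  # p(z_t|z_<t) from the prior LSTM

# Closed form, exactly as written in loss_fn
kld_f = -0.5 * torch.sum(1 + f_logvar - torch.pow(f_mean, 2) - torch.exp(f_logvar))
z_post_var, z_prior_var = torch.exp(z_post_logvar), torch.exp(z_prior_logvar)
kld_z = 0.5 * torch.sum(z_prior_logvar - z_post_logvar
                        + ((z_post_var + torch.pow(z_post_mean - z_prior_mean, 2)) / z_prior_var) - 1)

# Reference values from torch.distributions; Normal expects a standard deviation, i.e. exp(0.5 * logvar)
kld_f_ref = kl_divergence(Normal(f_mean, torch.exp(0.5 * f_logvar)),
                          Normal(torch.zeros_like(f_mean), torch.ones_like(f_mean))).sum()
kld_z_ref = kl_divergence(Normal(z_post_mean, torch.exp(0.5 * z_post_logvar)),
                          Normal(z_prior_mean, torch.exp(0.5 * z_prior_logvar))).sum()

print(torch.allclose(kld_f, kld_f_ref), torch.allclose(kld_z, kld_z_ref))  # expected: True True
```

If either check fails, the most likely culprit is a variance/standard-deviation mix-up, since `Normal` takes a standard deviation while the model works with log-variances.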