├── data
│   ├── raw
│   │   ├── push
│   │   │   ├── push_train
│   │   │   │   └── .gitkeep
│   │   │   ├── push_testnovel
│   │   │   │   └── .gitkeep
│   │   │   └── push_testseen
│   │   │       └── .gitkeep
│   │   ├── download_data.sh
│   │   └── push_datafiles.txt
│   └── processed
│       └── push
│           ├── push_testnovel
│           │   └── .gitkeep
│           ├── push_testseen
│           │   └── .gitkeep
│           └── push_train
│               └── .gitkeep
├── .gitignore
├── train.py
├── predict.py
├── README.md
├── options.py
├── data.py
├── tfrecord_to_dataset.py
├── model.py
└── networks.py

--------------------------------------------------------------------------------
/data/raw/push/push_train/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/raw/push/push_testnovel/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/raw/push/push_testseen/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/processed/push/push_testnovel/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/processed/push/push_testseen/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/processed/push/push_train/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.npy
**/.DS_Store
*.tfrecord-*
.idea
__pycache__/
*.py[cod]

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from options import Options
from model import Model


def train():
    opt = Options().parse()
    model = Model(opt)

    model.train()


if __name__ == '__main__':
    train()
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import torch

from options import Options
from model import Model
from torchvision.transforms import functional as F

opt = Options().parse()


def save_to_local(tensor_list, folder):
    # create the output folder if needed; img.save() does not create directories
    os.makedirs(folder, exist_ok=True)
    for idx_, tensor in enumerate(tensor_list):
        img = F.to_pil_image(tensor.squeeze())
        img.save(os.path.join(folder, "predict_%s.jpg" % idx_))


def predict(net, data, save_path=None):
    images, actions, states = data
    images = [F.to_tensor(F.resize(F.to_pil_image(im), (opt.height, opt.width))).unsqueeze(0).to(opt.device)
              for im in torch.from_numpy(images).unbind(0)]
    actions = [ac.unsqueeze(0).to(opt.device) for ac in torch.from_numpy(actions).unbind(0)]
    states = [st.unsqueeze(0).to(opt.device) for st in torch.from_numpy(states).unbind(0)]

    with torch.no_grad():
        gen_images, gen_states = net(images, actions, states[0])
    # keep the ground-truth context frames, then append the predicted frames
    save_images = images[:opt.context_frames] + gen_images[opt.context_frames-1:]
    if save_path:
        save_to_local(save_images, save_path)


if __name__ == '__main__':
    images, actions, states = np.load("data/processed/push/push_testseen/image/batch_1_0.npy"), \
                              np.load("data/processed/push/push_testseen/action/batch_1_0.npy"), \
                              np.load("data/processed/push/push_testseen/state/batch_1_0.npy")

    m = Model(opt)
    # Model() already calls load_weight() when --pretrained_model is set;
    # without a checkpoint the predictions come from random weights.
    net = m.net

    predict(net, (images, actions, states), save_path="predict/")
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Unsupervised Learning for Physical Interaction through Video Prediction
==============================

Based on the paper by C. Finn, I. Goodfellow and S. Levine: [*"Unsupervised Learning for Physical Interaction through Video Prediction"*](https://papers.nips.cc/paper/6161-unsupervised-learning-for-physical-interaction-through-video-prediction.pdf), implemented in PyTorch.

Prepare the data needed for training
------------
```bash
$ cd data/raw
$ sh download_data.sh push_datafiles.txt   # downloads the push dataset from Google Cloud Storage into data/raw
$ cd ../..
$ python ./tfrecord_to_dataset.py          # run once per split; edit DATA_DIR/OUT_DIR inside the script
```

Training
------------
The inline comments below document each flag; strip them before running, since bash does not allow a comment after a `\` line continuation.
```bash
$ python ./train.py \
    --data_dir data/processed/push \   # path to the training set
    --model CDNA \                     # the model type to use - DNA, CDNA, or STP
    --output_dir ./weights \           # where to save model checkpoints
    --pretrained_model model \         # path to a model to initialize from; random init if empty
    --sequence_length 10 \             # total number of frames in a sequence
    --context_frames 2 \               # number of ground-truth frames to pass in at the start
    --num_masks 10 \                   # number of transformations and corresponding masks
    --schedsamp_k 900.0 \              # constant for scheduled sampling, or -1 to disable it
    --batch_size 32 \                  # training batch size
    --learning_rate 0.001 \            # initial learning rate for the Adam optimizer
    --epochs 10 \                      # total number of training epochs
    --print_interval 10 \              # iterations between loss printouts
    --device cuda \                    # device used for training
    --use_state                        # condition on actions and the initial state (on by default)
```
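
Prediction
------------
`predict.py` loads one processed sequence (by default `data/processed/push/push_testseen/{image,action,state}/batch_1_0.npy`), feeds the context frames to the network, and writes the context plus predicted frames to `predict/`. A minimal sketch of an invocation, assuming a checkpoint written by training (checkpoints are saved as `net_epoch_<epoch>.pth` under `--output_dir`):
```bash
$ python ./predict.py \
    --pretrained_model weights/net_epoch_9.pth \
    --model CDNA --device cpu
```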
--------------------------------------------------------------------------------
/data/raw/download_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


# Example:
#
#   download_data.sh datafiles.txt ./tmp
#
# will download all of the files listed in the file, datafiles.txt, into
# a directory, "./tmp".
#
# Each line of the datafiles.txt file should contain the path from the
# bucket root to a file.

ARGC="$#"
LISTING_FILE=push_datafiles.txt
if [ "${ARGC}" -ge 1 ]; then
  LISTING_FILE=$1
fi
OUTPUT_DIR="./"
if [ "${ARGC}" -ge 2 ]; then
  OUTPUT_DIR=$2
fi

echo "OUTPUT_DIR=$OUTPUT_DIR"

mkdir -p "${OUTPUT_DIR}"

function download_file {
  FILE=$1
  BUCKET="https://storage.googleapis.com/brain-robotics-data"
  URL="${BUCKET}/${FILE}"
  OUTPUT_FILE="${OUTPUT_DIR}/${FILE}"
  DIRECTORY=$(dirname "${OUTPUT_FILE}")
  echo "DIRECTORY=${DIRECTORY}"
  mkdir -p "${DIRECTORY}"
  curl --output "${OUTPUT_FILE}" "${URL}"
}

while read filename; do
  download_file "$filename"
done < "${LISTING_FILE}"
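
The listing files are plain text, one bucket-relative path per line; for example, the first entries of the bundled push_datafiles.txt:

    push/push_testnovel/push_testnovel.tfrecord-00000-of-00005
    push/push_testnovel/push_testnovel.tfrecord-00001-of-00005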
41 | """ 42 | self.opt = self.parser.parse_args() 43 | if not os.path.exists(self.opt.output_dir): 44 | os.makedirs(self.opt.output_dir) 45 | return self.opt 46 | 47 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision import transforms 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | 7 | IMG_EXTENSIONS = ('.npy',) 8 | 9 | 10 | def make_dataset(path): 11 | image_folders = os.path.join(path, 'image') 12 | action_folders = os.path.join(path, 'action') 13 | state_folders = os.path.join(path, 'state') 14 | 15 | if os.path.exists(image_folders) + os.path.exists(action_folders) + os.path.exists(state_folders) != 3: 16 | raise FileExistsError('some subfolders from data set do not exists!') 17 | 18 | samples = [] 19 | for sample in os.listdir(image_folders): 20 | image, action, state = os.path.join(image_folders, sample), os.path.join(action_folders, sample), os.path.join(state_folders, sample) 21 | samples.append((image, action, state)) 22 | return samples 23 | 24 | 25 | def npy_loader(path): 26 | samples = torch.from_numpy(np.load(path)) 27 | return samples 28 | 29 | 30 | class PushDataset(Dataset): 31 | def __init__(self, root, image_transform=None, action_transform=None, state_transform=None, loader=npy_loader, device='cpu'): 32 | if not os.path.exists(root): 33 | raise FileExistsError('{0} does not exists!'.format(root)) 34 | # self.subfolders = [f[0] for f in os.walk(root)][1:] 35 | self.image_transform = image_transform 36 | self.action_transform = action_transform 37 | self.state_transform = state_transform 38 | self.samples = make_dataset(root) 39 | if len(self.samples) == 0: 40 | raise (RuntimeError("Found 0 images in subfolders of: " + root + "\n" 41 | "Supported image extensions are: " + ",".join( 42 | IMG_EXTENSIONS))) 43 | self.loader = loader 44 | self.device = device 45 | def __getitem__(self, index): 46 | image, action, state = self.samples[index] 47 | image, action, state = self.loader(image), self.loader(action), self.loader(state) 48 | 49 | if self.image_transform is not None: 50 | image = torch.cat([self.image_transform(single_image).unsqueeze(0) for single_image in image.unbind(0)], dim=0) 51 | if self.action_transform is not None: 52 | action = torch.cat([self.action_transform(single_action).unsqueeze(0) for single_action in action.unbind(0)], dim=0) 53 | if self.state_transform is not None: 54 | state = torch.cat([self.state_transform(single_state).unsqueeze(0) for single_state in state.unbind(0)], dim=0) 55 | 56 | return image.to(self.device), action.to(self.device), state.to(self.device) 57 | 58 | def __len__(self): 59 | return len(self.samples) 60 | 61 | 62 | def build_dataloader(opt): 63 | image_transform = transforms.Compose([ 64 | transforms.ToPILImage(), 65 | transforms.Resize((opt.height, opt.width)), 66 | transforms.ToTensor() 67 | ]) 68 | 69 | train_ds = PushDataset( 70 | root=os.path.join(opt.data_dir, 'push_train'), 71 | image_transform=image_transform, 72 | loader=npy_loader, 73 | device=opt.device 74 | ) 75 | 76 | testseen_ds = PushDataset( 77 | root=os.path.join(opt.data_dir, 'push_testseen'), 78 | image_transform=image_transform, 79 | loader=npy_loader, 80 | device=opt.device 81 | ) 82 | 83 | train_dl = DataLoader(dataset=train_ds, batch_size=opt.batch_size, shuffle=True, drop_last=False) 84 | testseen_dl = DataLoader(dataset=testseen_ds, 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np

IMG_EXTENSIONS = ('.npy',)


def make_dataset(path):
    image_folders = os.path.join(path, 'image')
    action_folders = os.path.join(path, 'action')
    state_folders = os.path.join(path, 'state')

    if not (os.path.exists(image_folders) and os.path.exists(action_folders) and os.path.exists(state_folders)):
        raise FileNotFoundError('some dataset subfolders (image/action/state) do not exist!')

    samples = []
    for sample in os.listdir(image_folders):
        image, action, state = os.path.join(image_folders, sample), os.path.join(action_folders, sample), os.path.join(state_folders, sample)
        samples.append((image, action, state))
    return samples


def npy_loader(path):
    samples = torch.from_numpy(np.load(path))
    return samples


class PushDataset(Dataset):
    def __init__(self, root, image_transform=None, action_transform=None, state_transform=None, loader=npy_loader, device='cpu'):
        if not os.path.exists(root):
            raise FileNotFoundError('{0} does not exist!'.format(root))
        self.image_transform = image_transform
        self.action_transform = action_transform
        self.state_transform = state_transform
        self.samples = make_dataset(root)
        if len(self.samples) == 0:
            raise (RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
        self.loader = loader
        self.device = device

    def __getitem__(self, index):
        image, action, state = self.samples[index]
        image, action, state = self.loader(image), self.loader(action), self.loader(state)

        if self.image_transform is not None:
            image = torch.cat([self.image_transform(single_image).unsqueeze(0) for single_image in image.unbind(0)], dim=0)
        if self.action_transform is not None:
            action = torch.cat([self.action_transform(single_action).unsqueeze(0) for single_action in action.unbind(0)], dim=0)
        if self.state_transform is not None:
            state = torch.cat([self.state_transform(single_state).unsqueeze(0) for single_state in state.unbind(0)], dim=0)

        return image.to(self.device), action.to(self.device), state.to(self.device)

    def __len__(self):
        return len(self.samples)


def build_dataloader(opt):
    image_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((opt.height, opt.width)),
        transforms.ToTensor()
    ])

    train_ds = PushDataset(
        root=os.path.join(opt.data_dir, 'push_train'),
        image_transform=image_transform,
        loader=npy_loader,
        device=opt.device
    )

    testseen_ds = PushDataset(
        root=os.path.join(opt.data_dir, 'push_testseen'),
        image_transform=image_transform,
        loader=npy_loader,
        device=opt.device
    )

    train_dl = DataLoader(dataset=train_ds, batch_size=opt.batch_size, shuffle=True, drop_last=False)
    testseen_dl = DataLoader(dataset=testseen_ds, batch_size=opt.batch_size, shuffle=False, drop_last=False)
    return train_dl, testseen_dl
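
The loaders yield batch-major tensors; model.py then permutes them into time-major lists of frames. A minimal shape check (a sketch, assuming the processed push data is already in place and the default options):

    from options import Options
    from data import build_dataloader

    opt = Options().parse()
    train_dl, testseen_dl = build_dataloader(opt)

    images, actions, states = next(iter(train_dl))
    print(images.shape)   # torch.Size([32, 10, 3, 64, 64])  (batch, time, C, H, W)
    print(actions.shape)  # torch.Size([32, 10, 5])
    print(states.shape)   # torch.Size([32, 10, 5])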
--------------------------------------------------------------------------------
/tfrecord_to_dataset.py:
--------------------------------------------------------------------------------
"""Convert the tfrecord files into .npy arrays that PyTorch can read."""

import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile


# tfrecord data location; run once per split (push_train, push_testseen,
# push_testnovel), editing DATA_DIR and OUT_DIR accordingly:
DATA_DIR = 'data/raw/push/push_testnovel'

OUT_DIR = 'data/processed/push/push_testnovel'

SEQUENCE_LENGTH = 10

ORIGINAL_WIDTH = 640
ORIGINAL_HEIGHT = 512
COLOR_CHAN = 3

# Dimension of the state and action.
STATE_DIM = 5
ACTION_DIM = 5

IMG_WIDTH = 64
IMG_HEIGHT = 64


def convert():
    config = tf.ConfigProto(
        device_count={'GPU': 0}
    )
    with tf.Session(config=config) as sess:
        files = gfile.Glob(os.path.join(DATA_DIR, '*'))
        queue = tf.train.string_input_producer(files, shuffle=False)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(queue)
        image_seq, state_seq, action_seq = [], [], []

        for i in range(SEQUENCE_LENGTH):
            image_name = 'move/' + str(i) + '/image/encoded'
            action_name = 'move/' + str(i) + '/commanded_pose/vec_pitch_yaw'
            state_name = 'move/' + str(i) + '/endeffector/vec_pitch_yaw'

            features = {
                image_name: tf.FixedLenFeature([1], tf.string),
                action_name: tf.FixedLenFeature([ACTION_DIM], tf.float32),
                state_name: tf.FixedLenFeature([STATE_DIM], tf.float32)
            }

            features = tf.parse_single_example(serialized_example, features=features)
            image_buffer = tf.reshape(features[image_name], shape=[])
            image = tf.image.decode_jpeg(image_buffer, channels=COLOR_CHAN)
            image.set_shape([ORIGINAL_HEIGHT, ORIGINAL_WIDTH, COLOR_CHAN])

            crop_size = min(ORIGINAL_WIDTH, ORIGINAL_HEIGHT)
            image = tf.image.resize_image_with_crop_or_pad(image, crop_size, crop_size)
            image = tf.reshape(image, [1, crop_size, crop_size, COLOR_CHAN])
            # scale to [0, 1] so torchvision's ToPILImage/ToTensor pipeline in
            # data.py round-trips the pixel values correctly
            image = tf.image.resize_bicubic(image, [IMG_HEIGHT, IMG_WIDTH]) / 255.0
            image_seq.append(image)

            state = tf.reshape(features[state_name], shape=[1, STATE_DIM])
            state_seq.append(state)
            action = tf.reshape(features[action_name], shape=[1, ACTION_DIM])
            action_seq.append(action)

        image_seq = tf.concat(axis=0, values=image_seq)
        state_seq = tf.concat(axis=0, values=state_seq)
        action_seq = tf.concat(axis=0, values=action_seq)

        [image_batch, action_batch, state_batch] = tf.train.batch(
            [image_seq, action_seq, state_seq],
            1,
            num_threads=1,
            capacity=100 * 64,
            allow_smaller_final_batch=True)

        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for sub in ('image', 'state', 'action'):
            if not os.path.exists(os.path.join(OUT_DIR, sub)):
                os.makedirs(os.path.join(OUT_DIR, sub))

        for j in range(len(files)):
            data_length = sum(1 for _ in tf.python_io.tf_record_iterator(files[j]))
            for i in range(data_length):
                imgs, acts, stas = sess.run([image_batch, action_batch, state_batch])
                imgs = imgs.squeeze().transpose([0, 3, 1, 2])
                acts = acts.squeeze()
                stas = stas.squeeze()
                np.save(os.path.join(OUT_DIR, 'image', 'batch_{0}_{1}'.format(j, i)), imgs)
                np.save(os.path.join(OUT_DIR, 'action', 'batch_{0}_{1}'.format(j, i)), acts)
                np.save(os.path.join(OUT_DIR, 'state', 'batch_{0}_{1}'.format(j, i)), stas)

        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
    convert()
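
Each saved file holds one sequence. The resulting shapes, assuming the constants above (batch_0_0 is the first sequence of the first tfrecord shard):

    import numpy as np

    imgs = np.load('data/processed/push/push_testnovel/image/batch_0_0.npy')
    acts = np.load('data/processed/push/push_testnovel/action/batch_0_0.npy')
    stas = np.load('data/processed/push/push_testnovel/state/batch_0_0.npy')

    print(imgs.shape)  # (10, 3, 64, 64): SEQUENCE_LENGTH x COLOR_CHAN x IMG_HEIGHT x IMG_WIDTH
    print(acts.shape)  # (10, 5):         SEQUENCE_LENGTH x ACTION_DIM
    print(stas.shape)  # (10, 5):         SEQUENCE_LENGTH x STATE_DIM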
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import os
import torch
from torch import nn
from networks import network
from data import build_dataloader
from torch.nn import functional as F


def peak_signal_to_noise_ratio(true, pred):
    # PSNR = 10 * log10(1 / MSE) for images scaled to [0, 1]
    return 10.0 * torch.log(torch.tensor(1.0) / F.mse_loss(true, pred)) / torch.log(torch.tensor(10.0))


class Model():
    def __init__(self, opt):
        self.opt = opt
        self.device = self.opt.device

        train_dataloader, valid_dataloader = build_dataloader(opt)
        self.dataloader = {'train': train_dataloader, 'valid': valid_dataloader}

        self.net = network(self.opt.channels, self.opt.height, self.opt.width, -1, self.opt.schedsamp_k,
                           self.opt.use_state, self.opt.num_masks, self.opt.model == 'STP',
                           self.opt.model == 'CDNA', self.opt.model == 'DNA', self.opt.context_frames)
        self.net.to(self.device)
        self.mse_loss = nn.MSELoss()
        self.w_state = 1e-4
        if self.opt.pretrained_model:
            self.load_weight()
        self.optimizer = torch.optim.Adam(self.net.parameters(), self.opt.learning_rate)

    def train_epoch(self, epoch):
        print("--------------------start training epoch %2d--------------------" % epoch)
        for iter_, (images, actions, states) in enumerate(self.dataloader['train']):
            self.net.zero_grad()
            # reshape from batch-major (N, T, ...) to time-major lists of frames
            images = images.permute([1, 0, 2, 3, 4]).unbind(0)
            actions = actions.permute([1, 0, 2]).unbind(0)
            states = states.permute([1, 0, 2]).unbind(0)
            gen_images, gen_states = self.net(images, actions, states[0])

            loss, psnr = 0.0, 0.0
            for i, (image, gen_image) in enumerate(zip(images[self.opt.context_frames:], gen_images[self.opt.context_frames-1:])):
                recon_loss = self.mse_loss(image, gen_image)
                psnr_i = peak_signal_to_noise_ratio(image, gen_image)
                loss += recon_loss
                psnr += psnr_i

            for i, (state, gen_state) in enumerate(zip(states[self.opt.context_frames:], gen_states[self.opt.context_frames-1:])):
                state_loss = self.mse_loss(state, gen_state) * self.w_state
                loss += state_loss
            loss /= torch.tensor(self.opt.sequence_length - self.opt.context_frames)
            loss.backward()
            self.optimizer.step()

            if iter_ % self.opt.print_interval == 0:
                print("training epoch: %3d, iterations: %3d/%3d loss: %6f" %
                      (epoch, iter_, len(self.dataloader['train'].dataset)//self.opt.batch_size, loss))

            # the scheduled-sampling schedule in the network is driven by this counter
            self.net.iter_num += 1

    def train(self):
        for epoch_i in range(0, self.opt.epochs):
            self.train_epoch(epoch_i)
            self.evaluate(epoch_i)
            self.save_weight(epoch_i)

    def evaluate(self, epoch):
        with torch.no_grad():
            recon_loss, state_loss = 0.0, 0.0
            for iter_, (images, actions, states) in enumerate(self.dataloader['valid']):
                images = images.permute([1, 0, 2, 3, 4]).unbind(0)
                actions = actions.permute([1, 0, 2]).unbind(0)
                states = states.permute([1, 0, 2]).unbind(0)
                gen_images, gen_states = self.net(images, actions, states[0])
                for i, (image, gen_image) in enumerate(
                        zip(images[self.opt.context_frames:], gen_images[self.opt.context_frames - 1:])):
                    recon_loss += self.mse_loss(image, gen_image)

                for i, (state, gen_state) in enumerate(
                        zip(states[self.opt.context_frames:], gen_states[self.opt.context_frames - 1:])):
                    state_loss += self.mse_loss(state, gen_state) * self.w_state
            recon_loss /= (torch.tensor(self.opt.sequence_length - self.opt.context_frames) * len(self.dataloader['valid'].dataset)/self.opt.batch_size)
            state_loss /= (torch.tensor(self.opt.sequence_length - self.opt.context_frames) * len(self.dataloader['valid'].dataset)/self.opt.batch_size)

            print("evaluation epoch: %3d, recon_loss: %6f, state_loss: %6f" % (epoch, recon_loss, state_loss))

    def save_weight(self, epoch):
        torch.save(self.net.state_dict(), os.path.join(self.opt.output_dir, "net_epoch_%d.pth" % epoch))

    def load_weight(self):
        # map_location lets a checkpoint trained on cuda be loaded on cpu and vice versa
        self.net.load_state_dict(torch.load(self.opt.pretrained_model, map_location=self.device))
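
With --schedsamp_k set, networks.network.forward mixes ground-truth and generated frames after the context: per batch, num_ground_truth = round(batch_size * k / (exp(iter_num / k) + k)). A quick sketch of that inverse-sigmoid decay with the defaults (k = 900, batch_size = 32):

    import math

    k, batch_size = 900.0, 32
    for iter_num in (0, 4000, 8000, 12000):
        num_ground_truth = round(batch_size * (k / (math.exp(iter_num / k) + k)))
        print(iter_num, num_ground_truth)   # 0 32, 4000 29, 8000 4, 12000 0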
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
import math
import torch
from torch import nn
from torch.nn import functional as F

RELU_SHIFT = 1e-12
DNA_KERN_SIZE = 5
STATE_DIM = 5
ACTION_DIM = 5


class ConvLSTM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=5, forget_bias=1.0, padding=0):
        super(ConvLSTM, self).__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels=out_channels + in_channels, out_channels=4 * out_channels, kernel_size=kernel_size, stride=1, padding=padding)
        self.forget_bias = forget_bias

    def forward(self, inputs, states):
        if states is None:
            states = (torch.zeros([inputs.shape[0], self.out_channels, inputs.shape[2], inputs.shape[3]], device=inputs.device),
                      torch.zeros([inputs.shape[0], self.out_channels, inputs.shape[2], inputs.shape[3]], device=inputs.device))
        if not isinstance(states, tuple):
            raise TypeError("states must be a (cell, hidden) tuple")

        c, h = states
        if not (len(c.shape) == 4 and len(h.shape) == 4 and len(inputs.shape) == 4):
            raise TypeError("inputs and states must be 4-D (N, C, H, W) tensors")

        inputs_h = torch.cat((inputs, h), dim=1)
        # a single convolution produces the input, cell, forget, and output gates
        i_j_f_o = self.conv(inputs_h)
        i, j, f, o = torch.split(i_j_f_o, self.out_channels, dim=1)

        new_c = c * torch.sigmoid(f + self.forget_bias) + torch.sigmoid(i) * torch.tanh(j)
        new_h = torch.tanh(new_c) * torch.sigmoid(o)

        return new_h, (new_c, new_h)
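
# Shape sketch (illustrative): with padding = kernel_size // 2 the cell preserves
# spatial size, so
#     cell = ConvLSTM(in_channels=3, out_channels=32, kernel_size=5, padding=2)
#     h, state = cell(torch.randn(4, 3, 32, 32), None)    # h: (4, 32, 32, 32)
#     h, state = cell(torch.randn(4, 3, 32, 32), state)   # reuse the (c, h) tuple
# The first call initializes the recurrent state to zeros; later calls pass it back in.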

class network(nn.Module):
    def __init__(self, channels=3,
                 height=64,
                 width=64,
                 iter_num=-1.0,
                 k=-1,
                 use_state=True,
                 num_masks=10,
                 stp=False,
                 cdna=True,
                 dna=False,
                 context_frames=2):
        super(network, self).__init__()
        if stp + cdna + dna != 1:
            raise ValueError('More than one, or no network option specified.')
        lstm_size = [32, 32, 64, 64, 128, 64, 32]
        self.dna = dna
        self.stp = stp
        self.cdna = cdna
        self.channels = channels
        self.use_state = use_state
        self.num_masks = num_masks
        self.height = height
        self.width = width
        self.context_frames = context_frames
        self.k = k
        self.iter_num = iter_num

        self.STATE_DIM = STATE_DIM
        self.ACTION_DIM = ACTION_DIM
        if not self.use_state:
            self.STATE_DIM = 0
            self.ACTION_DIM = 0
        # N * 3 * H * W -> N * 32 * H/2 * W/2
        self.enc0 = nn.Conv2d(in_channels=channels, out_channels=lstm_size[0], kernel_size=5, stride=2, padding=2)
        self.enc0_norm = nn.LayerNorm([lstm_size[0], self.height//2, self.width//2])
        # N * 32 * H/2 * W/2 -> N * 32 * H/2 * W/2
        self.lstm1 = ConvLSTM(in_channels=32, out_channels=lstm_size[0], kernel_size=5, padding=2)
        self.lstm1_norm = nn.LayerNorm([lstm_size[0], self.height//2, self.width//2])
        # N * 32 * H/2 * W/2 -> N * 32 * H/2 * W/2
        self.lstm2 = ConvLSTM(in_channels=lstm_size[0], out_channels=lstm_size[1], kernel_size=5, padding=2)
        self.lstm2_norm = nn.LayerNorm([lstm_size[1], self.height//2, self.width//2])

        # N * 32 * H/2 * W/2 -> N * 32 * H/4 * W/4
        self.enc1 = nn.Conv2d(in_channels=lstm_size[1], out_channels=lstm_size[1], kernel_size=3, stride=2, padding=1)
        # N * 32 * H/4 * W/4 -> N * 64 * H/4 * W/4
        self.lstm3 = ConvLSTM(in_channels=lstm_size[1], out_channels=lstm_size[2], kernel_size=5, padding=2)
        self.lstm3_norm = nn.LayerNorm([lstm_size[2], self.height//4, self.width//4])
        # N * 64 * H/4 * W/4 -> N * 64 * H/4 * W/4
        self.lstm4 = ConvLSTM(in_channels=lstm_size[2], out_channels=lstm_size[3], kernel_size=5, padding=2)
        self.lstm4_norm = nn.LayerNorm([lstm_size[3], self.height//4, self.width//4])
        # pass in state and action

        # N * 64 * H/4 * W/4 -> N * 64 * H/8 * W/8
        self.enc2 = nn.Conv2d(in_channels=lstm_size[3], out_channels=lstm_size[3], kernel_size=3, stride=2, padding=1)
        # N * (64+10) * H/8 * W/8 -> N * 64 * H/8 * W/8
        self.enc3 = nn.Conv2d(in_channels=lstm_size[3]+self.STATE_DIM+self.ACTION_DIM, out_channels=lstm_size[3], kernel_size=1, stride=1)
        # N * 64 * H/8 * W/8 -> N * 128 * H/8 * W/8
        self.lstm5 = ConvLSTM(in_channels=lstm_size[3], out_channels=lstm_size[4], kernel_size=5, padding=2)
        self.lstm5_norm = nn.LayerNorm([lstm_size[4], self.height//8, self.width//8])
        # N * 128 * H/8 * W/8 -> N * 128 * H/4 * W/4
        self.enc4 = nn.ConvTranspose2d(in_channels=lstm_size[4], out_channels=lstm_size[4], kernel_size=3, stride=2, output_padding=1, padding=1)
        # N * 128 * H/4 * W/4 -> N * 64 * H/4 * W/4
        self.lstm6 = ConvLSTM(in_channels=lstm_size[4], out_channels=lstm_size[5], kernel_size=5, padding=2)
        self.lstm6_norm = nn.LayerNorm([lstm_size[5], self.height//4, self.width//4])

        # N * (64+32) * H/4 * W/4 -> N * (64+32) * H/2 * W/2  (skip connection from enc1)
        self.enc5 = nn.ConvTranspose2d(in_channels=lstm_size[5]+lstm_size[1], out_channels=lstm_size[5]+lstm_size[1], kernel_size=3, stride=2, output_padding=1, padding=1)
        # N * (64+32) * H/2 * W/2 -> N * 32 * H/2 * W/2
        self.lstm7 = ConvLSTM(in_channels=lstm_size[5]+lstm_size[1], out_channels=lstm_size[6], kernel_size=5, padding=2)
        self.lstm7_norm = nn.LayerNorm([lstm_size[6], self.height//2, self.width//2])
        # N * (32+32) * H/2 * W/2 -> N * 32 * H * W  (skip connection from enc0)
        self.enc6 = nn.ConvTranspose2d(in_channels=lstm_size[6]+lstm_size[0], out_channels=lstm_size[6], kernel_size=3, stride=2, output_padding=1, padding=1)
        self.enc6_norm = nn.LayerNorm([lstm_size[6], self.height, self.width])

        if self.dna:
            # N * 32 * H * W -> N * (DNA_KERN_SIZE*DNA_KERN_SIZE) * H * W
            self.enc7 = nn.ConvTranspose2d(in_channels=lstm_size[6], out_channels=DNA_KERN_SIZE**2, kernel_size=1, stride=1)
        else:
            # N * 32 * H * W -> N * 3 * H * W
            self.enc7 = nn.ConvTranspose2d(in_channels=lstm_size[6], out_channels=channels, kernel_size=1, stride=1)
        if self.cdna:
            # a reshape from lstm5: N * 128 * H/8 * W/8 -> N * (128 * H/8 * W/8)
            # N * (128 * H/8 * W/8) -> N * (10 * 5 * 5)
            in_dim = int(lstm_size[4] * self.height * self.width / 64)
            self.fc = nn.Linear(in_dim, DNA_KERN_SIZE * DNA_KERN_SIZE * self.num_masks)
        else:
            in_dim = int(lstm_size[4] * self.height * self.width / 64)
            self.fc = nn.Linear(in_dim, 100)
            self.fc_stp = nn.Linear(100, (self.num_masks-1) * 6)
        # N * 32 * H * W -> N * (num_masks+1) * H * W
        self.maskout = nn.ConvTranspose2d(lstm_size[6], self.num_masks+1, kernel_size=1, stride=1)
        self.stateout = nn.Linear(STATE_DIM+ACTION_DIM, STATE_DIM)

    def forward(self, images, actions, init_state):
        '''
        :param images: list of T tensors, each N * C * H * W
        :param actions: list of T tensors, each N * ACTION_DIM
        :param init_state: tensor of shape N * STATE_DIM
        :return: lists of the T-1 generated images and states
        '''

        lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None
        lstm_state5, lstm_state6, lstm_state7 = None, None, None
        gen_images, gen_states = [], []
        current_state = init_state
        if self.k == -1:
            # always feed the generated image back in after the context frames
            feedself = True
        else:
            # scheduled sampling: the number of ground-truth samples per batch
            # decays with the training iteration (inverse sigmoid in iter_num)
            num_ground_truth = round(images[0].shape[0] * (self.k / (math.exp(self.iter_num/self.k) + self.k)))
            feedself = False

        for image, action in zip(images[:-1], actions[:-1]):

            done_warm_start = len(gen_images) >= self.context_frames

            if feedself and done_warm_start:
                # feed in the generated image
                image = gen_images[-1]
            elif done_warm_start:
                # scheduled sampling: mix ground-truth and generated frames
                image = self.scheduled_sample(image, gen_images[-1], num_ground_truth)
            # else: always feed in the ground-truth context frames

            enc0 = self.enc0_norm(torch.relu(self.enc0(image)))

            lstm1, lstm_state1 = self.lstm1(enc0, lstm_state1)
            lstm1 = self.lstm1_norm(lstm1)

            lstm2, lstm_state2 = self.lstm2(lstm1, lstm_state2)
            lstm2 = self.lstm2_norm(lstm2)

            enc1 = torch.relu(self.enc1(lstm2))

            lstm3, lstm_state3 = self.lstm3(enc1, lstm_state3)
            lstm3 = self.lstm3_norm(lstm3)

            lstm4, lstm_state4 = self.lstm4(lstm3, lstm_state4)
            lstm4 = self.lstm4_norm(lstm4)

            enc2 = torch.relu(self.enc2(lstm4))

            # pass in state and action by smearing them spatially over enc2
            state_action = torch.cat([action, current_state], dim=1)
            smear = torch.reshape(state_action, list(state_action.shape)+[1, 1])
            smear = smear.repeat(1, 1, enc2.shape[2], enc2.shape[3])
            if self.use_state:
                enc2 = torch.cat([enc2, smear], dim=1)
            enc3 = torch.relu(self.enc3(enc2))

            lstm5, lstm_state5 = self.lstm5(enc3, lstm_state5)
            lstm5 = self.lstm5_norm(lstm5)
            enc4 = torch.relu(self.enc4(lstm5))

            lstm6, lstm_state6 = self.lstm6(enc4, lstm_state6)
            lstm6 = self.lstm6_norm(lstm6)
            # skip connection
            lstm6 = torch.cat([lstm6, enc1], dim=1)

            enc5 = torch.relu(self.enc5(lstm6))

            lstm7, lstm_state7 = self.lstm7(enc5, lstm_state7)
            lstm7 = self.lstm7_norm(lstm7)
            # skip connection
            lstm7 = torch.cat([lstm7, enc0], dim=1)

            enc6 = self.enc6_norm(torch.relu(self.enc6(lstm7)))

            enc7 = torch.relu(self.enc7(enc6))

            if self.dna:
                if self.num_masks != 1:
                    raise ValueError('Only one mask is supported for DNA model.')
                transformed = [self.dna_transformation(image, enc7)]
            else:
                transformed = [torch.sigmoid(enc7)]
                _input = lstm5.view(lstm5.shape[0], -1)
                if self.cdna:
                    transformed += self.cdna_transformation(image, _input)
                else:
                    transformed += self.stp_transformation(image, _input)

            masks = torch.relu(self.maskout(enc6))
            masks = torch.softmax(masks, dim=1)
            mask_list = torch.split(masks, split_size_or_sections=1, dim=1)

            # composite the transformed images with the predicted masks;
            # mask_list[0] keeps pixels from the previous (input) image
            output = mask_list[0] * image
            for layer, mask in zip(transformed, mask_list[1:]):
                output += layer * mask

            gen_images.append(output)

            current_state = self.stateout(state_action)
            gen_states.append(current_state)

        return gen_images, gen_states

    def stp_transformation(self, image, stp_input):
        identity_params = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], dtype=torch.float32,
                                       device=stp_input.device).unsqueeze(1).repeat(1, self.num_masks-1)

        stp_input = self.fc(stp_input)
        stp_input = self.fc_stp(stp_input)
        stp_input = stp_input.view(-1, 6, self.num_masks-1) + identity_params
        params = torch.unbind(stp_input, dim=-1)

        transformed = [F.grid_sample(image, F.affine_grid(param.view(-1, 3, 2), image.size())) for param in params]
        return transformed

    def cdna_transformation(self, image, cdna_input):
        batch_size, height, width = image.shape[0], image.shape[2], image.shape[3]

        # predict one DNA_KERN_SIZE x DNA_KERN_SIZE kernel per mask and normalize it
        cdna_kerns = self.fc(cdna_input)
        cdna_kerns = cdna_kerns.view(batch_size, self.num_masks, 1, DNA_KERN_SIZE, DNA_KERN_SIZE)
        cdna_kerns = torch.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
        norm_factor = torch.sum(cdna_kerns, dim=[2, 3, 4], keepdim=True)
        cdna_kerns /= norm_factor

        # apply every sample's kernels to its own image via a grouped convolution
        cdna_kerns = cdna_kerns.view(batch_size*self.num_masks, 1, DNA_KERN_SIZE, DNA_KERN_SIZE)
        image = image.permute([1, 0, 2, 3])

        transformed = torch.conv2d(image, cdna_kerns, stride=1, padding=[2, 2], groups=batch_size)

        transformed = transformed.view(self.channels, batch_size, self.num_masks, height, width)
        transformed = transformed.permute([1, 0, 3, 4, 2])
        transformed = torch.unbind(transformed, dim=-1)

        return transformed

    def dna_transformation(self, image, dna_input):
        # pad the image so each pixel has a full DNA_KERN_SIZE x DNA_KERN_SIZE neighbourhood
        image_pad = F.pad(image, [2, 2, 2, 2, 0, 0, 0, 0], "constant", 0)
        height, width = image.shape[2], image.shape[3]

        inputs = []

        for xkern in range(DNA_KERN_SIZE):
            for ykern in range(DNA_KERN_SIZE):
                inputs.append(image_pad[:, :, xkern:xkern+height, ykern:ykern+width].clone().unsqueeze(dim=1))
        # N * (DNA_KERN_SIZE**2) * C * H * W: every shifted copy of the image
        inputs = torch.cat(inputs, dim=1)

        # normalize the per-pixel kernels over the DNA_KERN_SIZE**2 dimension
        kernel = torch.relu(dna_input-RELU_SHIFT)+RELU_SHIFT
        kernel = (kernel / torch.sum(kernel, dim=1, keepdim=True)).unsqueeze(2)

        return torch.sum(kernel*inputs, dim=1, keepdim=False)

    def scheduled_sample(self, ground_truth_x, generated_x, num_ground_truth):
        # keep ground truth for the first num_ground_truth samples of the batch,
        # use the model's previous prediction for the rest
        generated_examps = torch.cat([ground_truth_x[:num_ground_truth, ...], generated_x[num_ground_truth:, ...]], dim=0)
        return generated_examps
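
To see how it fits together, a shape check of one forward pass with random tensors (a sketch; it mirrors the time-major lists that model.py builds):

    import torch
    from networks import network

    net = network(channels=3, height=64, width=64, iter_num=-1.0, k=-1,
                  use_state=True, num_masks=10, stp=False, cdna=True, dna=False,
                  context_frames=2)

    T, N = 10, 4
    images = [torch.randn(N, 3, 64, 64) for _ in range(T)]  # T frames of N samples
    actions = [torch.randn(N, 5) for _ in range(T)]         # ACTION_DIM = 5
    init_state = torch.randn(N, 5)                          # STATE_DIM = 5

    gen_images, gen_states = net(images, actions, init_state)
    print(len(gen_images), gen_images[0].shape)  # 9 torch.Size([4, 3, 64, 64])
    print(len(gen_states), gen_states[0].shape)  # 9 torch.Size([4, 5])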
--------------------------------------------------------------------------------
/data/raw/push_datafiles.txt:
--------------------------------------------------------------------------------
1 | push/push_testnovel/push_testnovel.tfrecord-00000-of-00005 2 | push/push_testnovel/push_testnovel.tfrecord-00001-of-00005 3 |
push/push_testnovel/push_testnovel.tfrecord-00002-of-00005 4 | push/push_testnovel/push_testnovel.tfrecord-00003-of-00005 5 | push/push_testnovel/push_testnovel.tfrecord-00004-of-00005 6 | push/push_testseen/push_testseen.tfrecord-00000-of-00005 7 | push/push_testseen/push_testseen.tfrecord-00001-of-00005 8 | push/push_testseen/push_testseen.tfrecord-00002-of-00005 9 | push/push_testseen/push_testseen.tfrecord-00003-of-00005 10 | push/push_testseen/push_testseen.tfrecord-00004-of-00005 11 | push/push_train/push_train.tfrecord-00000-of-00264 12 | push/push_train/push_train.tfrecord-00001-of-00264 13 | push/push_train/push_train.tfrecord-00002-of-00264 14 | push/push_train/push_train.tfrecord-00003-of-00264 15 | push/push_train/push_train.tfrecord-00004-of-00264 16 | push/push_train/push_train.tfrecord-00005-of-00264 17 | push/push_train/push_train.tfrecord-00006-of-00264 18 | push/push_train/push_train.tfrecord-00007-of-00264 19 | push/push_train/push_train.tfrecord-00008-of-00264 20 | push/push_train/push_train.tfrecord-00009-of-00264 21 | push/push_train/push_train.tfrecord-00010-of-00264 22 | push/push_train/push_train.tfrecord-00011-of-00264 23 | push/push_train/push_train.tfrecord-00012-of-00264 24 | push/push_train/push_train.tfrecord-00013-of-00264 25 | push/push_train/push_train.tfrecord-00014-of-00264 26 | push/push_train/push_train.tfrecord-00015-of-00264 27 | push/push_train/push_train.tfrecord-00016-of-00264 28 | push/push_train/push_train.tfrecord-00017-of-00264 29 | push/push_train/push_train.tfrecord-00018-of-00264 30 | push/push_train/push_train.tfrecord-00019-of-00264 31 | push/push_train/push_train.tfrecord-00020-of-00264 32 | push/push_train/push_train.tfrecord-00021-of-00264 33 | push/push_train/push_train.tfrecord-00022-of-00264 34 | push/push_train/push_train.tfrecord-00023-of-00264 35 | push/push_train/push_train.tfrecord-00024-of-00264 36 | push/push_train/push_train.tfrecord-00025-of-00264 37 | push/push_train/push_train.tfrecord-00026-of-00264 38 | push/push_train/push_train.tfrecord-00027-of-00264 39 | push/push_train/push_train.tfrecord-00028-of-00264 40 | push/push_train/push_train.tfrecord-00029-of-00264 41 | push/push_train/push_train.tfrecord-00030-of-00264 42 | push/push_train/push_train.tfrecord-00031-of-00264 43 | push/push_train/push_train.tfrecord-00032-of-00264 44 | push/push_train/push_train.tfrecord-00033-of-00264 45 | push/push_train/push_train.tfrecord-00034-of-00264 46 | push/push_train/push_train.tfrecord-00035-of-00264 47 | push/push_train/push_train.tfrecord-00036-of-00264 48 | push/push_train/push_train.tfrecord-00037-of-00264 49 | push/push_train/push_train.tfrecord-00038-of-00264 50 | push/push_train/push_train.tfrecord-00039-of-00264 51 | push/push_train/push_train.tfrecord-00040-of-00264 52 | push/push_train/push_train.tfrecord-00041-of-00264 53 | push/push_train/push_train.tfrecord-00042-of-00264 54 | push/push_train/push_train.tfrecord-00043-of-00264 55 | push/push_train/push_train.tfrecord-00044-of-00264 56 | push/push_train/push_train.tfrecord-00045-of-00264 57 | push/push_train/push_train.tfrecord-00046-of-00264 58 | push/push_train/push_train.tfrecord-00047-of-00264 59 | push/push_train/push_train.tfrecord-00048-of-00264 60 | push/push_train/push_train.tfrecord-00049-of-00264 61 | push/push_train/push_train.tfrecord-00050-of-00264 62 | push/push_train/push_train.tfrecord-00051-of-00264 63 | push/push_train/push_train.tfrecord-00052-of-00264 64 | push/push_train/push_train.tfrecord-00053-of-00264 65 | 
push/push_train/push_train.tfrecord-00054-of-00264 66 | push/push_train/push_train.tfrecord-00055-of-00264 67 | push/push_train/push_train.tfrecord-00056-of-00264 68 | push/push_train/push_train.tfrecord-00057-of-00264 69 | push/push_train/push_train.tfrecord-00058-of-00264 70 | push/push_train/push_train.tfrecord-00059-of-00264 71 | push/push_train/push_train.tfrecord-00060-of-00264 72 | push/push_train/push_train.tfrecord-00061-of-00264 73 | push/push_train/push_train.tfrecord-00062-of-00264 74 | push/push_train/push_train.tfrecord-00063-of-00264 75 | push/push_train/push_train.tfrecord-00064-of-00264 76 | push/push_train/push_train.tfrecord-00065-of-00264 77 | push/push_train/push_train.tfrecord-00066-of-00264 78 | push/push_train/push_train.tfrecord-00067-of-00264 79 | push/push_train/push_train.tfrecord-00068-of-00264 80 | push/push_train/push_train.tfrecord-00069-of-00264 81 | push/push_train/push_train.tfrecord-00070-of-00264 82 | push/push_train/push_train.tfrecord-00071-of-00264 83 | push/push_train/push_train.tfrecord-00072-of-00264 84 | push/push_train/push_train.tfrecord-00073-of-00264 85 | push/push_train/push_train.tfrecord-00074-of-00264 86 | push/push_train/push_train.tfrecord-00075-of-00264 87 | push/push_train/push_train.tfrecord-00076-of-00264 88 | push/push_train/push_train.tfrecord-00077-of-00264 89 | push/push_train/push_train.tfrecord-00078-of-00264 90 | push/push_train/push_train.tfrecord-00079-of-00264 91 | push/push_train/push_train.tfrecord-00080-of-00264 92 | push/push_train/push_train.tfrecord-00081-of-00264 93 | push/push_train/push_train.tfrecord-00082-of-00264 94 | push/push_train/push_train.tfrecord-00083-of-00264 95 | push/push_train/push_train.tfrecord-00084-of-00264 96 | push/push_train/push_train.tfrecord-00085-of-00264 97 | push/push_train/push_train.tfrecord-00086-of-00264 98 | push/push_train/push_train.tfrecord-00087-of-00264 99 | push/push_train/push_train.tfrecord-00088-of-00264 100 | push/push_train/push_train.tfrecord-00089-of-00264 101 | push/push_train/push_train.tfrecord-00090-of-00264 102 | push/push_train/push_train.tfrecord-00091-of-00264 103 | push/push_train/push_train.tfrecord-00092-of-00264 104 | push/push_train/push_train.tfrecord-00093-of-00264 105 | push/push_train/push_train.tfrecord-00094-of-00264 106 | push/push_train/push_train.tfrecord-00095-of-00264 107 | push/push_train/push_train.tfrecord-00096-of-00264 108 | push/push_train/push_train.tfrecord-00097-of-00264 109 | push/push_train/push_train.tfrecord-00098-of-00264 110 | push/push_train/push_train.tfrecord-00099-of-00264 111 | push/push_train/push_train.tfrecord-00100-of-00264 112 | push/push_train/push_train.tfrecord-00101-of-00264 113 | push/push_train/push_train.tfrecord-00102-of-00264 114 | push/push_train/push_train.tfrecord-00103-of-00264 115 | push/push_train/push_train.tfrecord-00104-of-00264 116 | push/push_train/push_train.tfrecord-00105-of-00264 117 | push/push_train/push_train.tfrecord-00106-of-00264 118 | push/push_train/push_train.tfrecord-00107-of-00264 119 | push/push_train/push_train.tfrecord-00108-of-00264 120 | push/push_train/push_train.tfrecord-00109-of-00264 121 | push/push_train/push_train.tfrecord-00110-of-00264 122 | push/push_train/push_train.tfrecord-00111-of-00264 123 | push/push_train/push_train.tfrecord-00112-of-00264 124 | push/push_train/push_train.tfrecord-00113-of-00264 125 | push/push_train/push_train.tfrecord-00114-of-00264 126 | push/push_train/push_train.tfrecord-00115-of-00264 127 | push/push_train/push_train.tfrecord-00116-of-00264 128 
| push/push_train/push_train.tfrecord-00117-of-00264 129 | push/push_train/push_train.tfrecord-00118-of-00264 130 | push/push_train/push_train.tfrecord-00119-of-00264 131 | push/push_train/push_train.tfrecord-00120-of-00264 132 | push/push_train/push_train.tfrecord-00121-of-00264 133 | push/push_train/push_train.tfrecord-00122-of-00264 134 | push/push_train/push_train.tfrecord-00123-of-00264 135 | push/push_train/push_train.tfrecord-00124-of-00264 136 | push/push_train/push_train.tfrecord-00125-of-00264 137 | push/push_train/push_train.tfrecord-00126-of-00264 138 | push/push_train/push_train.tfrecord-00127-of-00264 139 | push/push_train/push_train.tfrecord-00128-of-00264 140 | push/push_train/push_train.tfrecord-00129-of-00264 141 | push/push_train/push_train.tfrecord-00130-of-00264 142 | push/push_train/push_train.tfrecord-00131-of-00264 143 | push/push_train/push_train.tfrecord-00132-of-00264 144 | push/push_train/push_train.tfrecord-00133-of-00264 145 | push/push_train/push_train.tfrecord-00134-of-00264 146 | push/push_train/push_train.tfrecord-00135-of-00264 147 | push/push_train/push_train.tfrecord-00136-of-00264 148 | push/push_train/push_train.tfrecord-00137-of-00264 149 | push/push_train/push_train.tfrecord-00138-of-00264 150 | push/push_train/push_train.tfrecord-00139-of-00264 151 | push/push_train/push_train.tfrecord-00140-of-00264 152 | push/push_train/push_train.tfrecord-00141-of-00264 153 | push/push_train/push_train.tfrecord-00142-of-00264 154 | push/push_train/push_train.tfrecord-00143-of-00264 155 | push/push_train/push_train.tfrecord-00144-of-00264 156 | push/push_train/push_train.tfrecord-00145-of-00264 157 | push/push_train/push_train.tfrecord-00146-of-00264 158 | push/push_train/push_train.tfrecord-00147-of-00264 159 | push/push_train/push_train.tfrecord-00148-of-00264 160 | push/push_train/push_train.tfrecord-00149-of-00264 161 | push/push_train/push_train.tfrecord-00150-of-00264 162 | push/push_train/push_train.tfrecord-00151-of-00264 163 | push/push_train/push_train.tfrecord-00152-of-00264 164 | push/push_train/push_train.tfrecord-00153-of-00264 165 | push/push_train/push_train.tfrecord-00154-of-00264 166 | push/push_train/push_train.tfrecord-00155-of-00264 167 | push/push_train/push_train.tfrecord-00156-of-00264 168 | push/push_train/push_train.tfrecord-00157-of-00264 169 | push/push_train/push_train.tfrecord-00158-of-00264 170 | push/push_train/push_train.tfrecord-00159-of-00264 171 | push/push_train/push_train.tfrecord-00160-of-00264 172 | push/push_train/push_train.tfrecord-00161-of-00264 173 | push/push_train/push_train.tfrecord-00162-of-00264 174 | push/push_train/push_train.tfrecord-00163-of-00264 175 | push/push_train/push_train.tfrecord-00164-of-00264 176 | push/push_train/push_train.tfrecord-00165-of-00264 177 | push/push_train/push_train.tfrecord-00166-of-00264 178 | push/push_train/push_train.tfrecord-00167-of-00264 179 | push/push_train/push_train.tfrecord-00168-of-00264 180 | push/push_train/push_train.tfrecord-00169-of-00264 181 | push/push_train/push_train.tfrecord-00170-of-00264 182 | push/push_train/push_train.tfrecord-00171-of-00264 183 | push/push_train/push_train.tfrecord-00172-of-00264 184 | push/push_train/push_train.tfrecord-00173-of-00264 185 | push/push_train/push_train.tfrecord-00174-of-00264 186 | push/push_train/push_train.tfrecord-00175-of-00264 187 | push/push_train/push_train.tfrecord-00176-of-00264 188 | push/push_train/push_train.tfrecord-00177-of-00264 189 | push/push_train/push_train.tfrecord-00178-of-00264 190 | 
push/push_train/push_train.tfrecord-00179-of-00264 191 | push/push_train/push_train.tfrecord-00180-of-00264 192 | push/push_train/push_train.tfrecord-00181-of-00264 193 | push/push_train/push_train.tfrecord-00182-of-00264 194 | push/push_train/push_train.tfrecord-00183-of-00264 195 | push/push_train/push_train.tfrecord-00184-of-00264 196 | push/push_train/push_train.tfrecord-00185-of-00264 197 | push/push_train/push_train.tfrecord-00186-of-00264 198 | push/push_train/push_train.tfrecord-00187-of-00264 199 | push/push_train/push_train.tfrecord-00188-of-00264 200 | push/push_train/push_train.tfrecord-00189-of-00264 201 | push/push_train/push_train.tfrecord-00190-of-00264 202 | push/push_train/push_train.tfrecord-00191-of-00264 203 | push/push_train/push_train.tfrecord-00192-of-00264 204 | push/push_train/push_train.tfrecord-00193-of-00264 205 | push/push_train/push_train.tfrecord-00194-of-00264 206 | push/push_train/push_train.tfrecord-00195-of-00264 207 | push/push_train/push_train.tfrecord-00196-of-00264 208 | push/push_train/push_train.tfrecord-00197-of-00264 209 | push/push_train/push_train.tfrecord-00198-of-00264 210 | push/push_train/push_train.tfrecord-00199-of-00264 211 | push/push_train/push_train.tfrecord-00200-of-00264 212 | push/push_train/push_train.tfrecord-00201-of-00264 213 | push/push_train/push_train.tfrecord-00202-of-00264 214 | push/push_train/push_train.tfrecord-00203-of-00264 215 | push/push_train/push_train.tfrecord-00204-of-00264 216 | push/push_train/push_train.tfrecord-00205-of-00264 217 | push/push_train/push_train.tfrecord-00206-of-00264 218 | push/push_train/push_train.tfrecord-00207-of-00264 219 | push/push_train/push_train.tfrecord-00208-of-00264 220 | push/push_train/push_train.tfrecord-00209-of-00264 221 | push/push_train/push_train.tfrecord-00210-of-00264 222 | push/push_train/push_train.tfrecord-00211-of-00264 223 | push/push_train/push_train.tfrecord-00212-of-00264 224 | push/push_train/push_train.tfrecord-00213-of-00264 225 | push/push_train/push_train.tfrecord-00214-of-00264 226 | push/push_train/push_train.tfrecord-00215-of-00264 227 | push/push_train/push_train.tfrecord-00216-of-00264 228 | push/push_train/push_train.tfrecord-00217-of-00264 229 | push/push_train/push_train.tfrecord-00218-of-00264 230 | push/push_train/push_train.tfrecord-00219-of-00264 231 | push/push_train/push_train.tfrecord-00220-of-00264 232 | push/push_train/push_train.tfrecord-00221-of-00264 233 | push/push_train/push_train.tfrecord-00222-of-00264 234 | push/push_train/push_train.tfrecord-00223-of-00264 235 | push/push_train/push_train.tfrecord-00224-of-00264 236 | push/push_train/push_train.tfrecord-00225-of-00264 237 | push/push_train/push_train.tfrecord-00226-of-00264 238 | push/push_train/push_train.tfrecord-00227-of-00264 239 | push/push_train/push_train.tfrecord-00228-of-00264 240 | push/push_train/push_train.tfrecord-00229-of-00264 241 | push/push_train/push_train.tfrecord-00230-of-00264 242 | push/push_train/push_train.tfrecord-00231-of-00264 243 | push/push_train/push_train.tfrecord-00232-of-00264 244 | push/push_train/push_train.tfrecord-00233-of-00264 245 | push/push_train/push_train.tfrecord-00234-of-00264 246 | push/push_train/push_train.tfrecord-00235-of-00264 247 | push/push_train/push_train.tfrecord-00236-of-00264 248 | push/push_train/push_train.tfrecord-00237-of-00264 249 | push/push_train/push_train.tfrecord-00238-of-00264 250 | push/push_train/push_train.tfrecord-00239-of-00264 251 | push/push_train/push_train.tfrecord-00240-of-00264 252 | 
push/push_train/push_train.tfrecord-00241-of-00264 253 | push/push_train/push_train.tfrecord-00242-of-00264 254 | push/push_train/push_train.tfrecord-00243-of-00264 255 | push/push_train/push_train.tfrecord-00244-of-00264 256 | push/push_train/push_train.tfrecord-00245-of-00264 257 | push/push_train/push_train.tfrecord-00246-of-00264 258 | push/push_train/push_train.tfrecord-00247-of-00264 259 | push/push_train/push_train.tfrecord-00248-of-00264 260 | push/push_train/push_train.tfrecord-00249-of-00264 261 | push/push_train/push_train.tfrecord-00250-of-00264 262 | push/push_train/push_train.tfrecord-00251-of-00264 263 | push/push_train/push_train.tfrecord-00252-of-00264 264 | push/push_train/push_train.tfrecord-00253-of-00264 265 | push/push_train/push_train.tfrecord-00254-of-00264 266 | push/push_train/push_train.tfrecord-00255-of-00264 267 | push/push_train/push_train.tfrecord-00256-of-00264 268 | push/push_train/push_train.tfrecord-00257-of-00264 269 | push/push_train/push_train.tfrecord-00258-of-00264 270 | push/push_train/push_train.tfrecord-00259-of-00264 271 | push/push_train/push_train.tfrecord-00260-of-00264 272 | push/push_train/push_train.tfrecord-00261-of-00264 273 | push/push_train/push_train.tfrecord-00262-of-00264 274 | push/push_train/push_train.tfrecord-00263-of-00264 --------------------------------------------------------------------------------