├── nn
│   ├── __init__.py
│   ├── utils
│   │   ├── math.py
│   │   ├── misc.py
│   │   └── viz.py
│   ├── datasets
│   │   ├── iterators.py
│   │   └── generators.py
│   └── network
│       ├── blocks.py
│       ├── cells.py
│       ├── base.py
│       ├── stn.py
│       └── physics_models.py
├── requirements-gpu.txt
├── requirements.txt
├── LICENSE
├── runners
│   ├── run_base.py
│   └── run_physics.py
├── .gitignore
└── README.md
/nn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/nn/utils/math.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def sigmoid(x):
4 |     return 1 / (1 + np.exp(-x))
--------------------------------------------------------------------------------
/requirements-gpu.txt:
--------------------------------------------------------------------------------
1 | tensorflow-gpu==1.12.0
2 | Pillow==4.2.1
3 | requests==2.18.4
4 | scipy==0.19.1
5 | six==1.10.0
6 | matplotlib==2.1.0
7 | pytest==3.3.2
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.12.0
2 | Pillow==4.2.1
3 | requests==2.18.4
4 | scipy==0.19.1
5 | six==1.10.0
6 | matplotlib==2.1.0
7 | tinydb==3.7.0
8 | pytest==3.3.2
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Miguel Jaques
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/nn/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import inspect
4 | import numpy as np
5 | import zipfile
6 |
7 | def log_metrics(logger, prefix, metrics):
8 |     metrics_string = " ".join([k+"=%s"%metrics[k] for k in sorted(metrics.keys())])
9 |     string = prefix + " " + metrics_string
10 |     logger.info(string)
11 |
12 | def classes_in_module(module):
13 |     classes = {}
14 |     for name, obj in inspect.getmembers(module):
15 |         if inspect.isclass(obj):
16 |             if obj.__module__ == module.__name__:
17 |                 classes[name] = obj
18 |     return classes
19 |
20 | def rgb2gray(rgb):
21 |     return np.dot(rgb[...,:3], [0.299, 0.587, 0.114])
22 |
23 | def zipdir(path, save_dir):
24 |     zipf = zipfile.ZipFile(os.path.join(save_dir, 'code.zip'), 'w', zipfile.ZIP_DEFLATED)
25 |
26 |     # zipf is the zipfile handle
27 |     for root, dirs, files in os.walk(path):
28 |         for file in files:
29 |             if file.split(".")[-1] == "py":
30 |                 zipf.write(os.path.join(root, file),
31 |                            os.path.relpath(os.path.join(root, file), os.path.join(path, '..')))
32 |
33 |     zipf.close()
--------------------------------------------------------------------------------
/runners/run_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import tensorflow as tf
4 |
5 | tf.app.flags.DEFINE_integer("epochs", 10, "Epochs to train.")
6 | tf.app.flags.DEFINE_integer("batch_size", 100, "Training batch size.")
7 | tf.app.flags.DEFINE_string("save_dir", "", "Directory to save checkpoint and logs.")
8 | tf.app.flags.DEFINE_bool("use_ckpt", False, "Whether to start from scratch or from a checkpoint.")
9 | tf.app.flags.DEFINE_string("ckpt_dir", "", "Checkpoint dir to use.")
10 | tf.app.flags.DEFINE_float("base_lr", 1e-3, "Base learning rate.")
11 | tf.app.flags.DEFINE_bool("anneal_lr", True, "Whether to anneal lr after 0.75 of total epochs.")
12 | tf.app.flags.DEFINE_string("optimizer", "rmsprop", "Optimizer to use.")
13 | tf.app.flags.DEFINE_integer("save_every_n_epochs", 5, "Epochs between checkpoint saves.")
14 | tf.app.flags.DEFINE_integer("eval_every_n_epochs", 1, "Epochs between validation runs.")
15 | tf.app.flags.DEFINE_integer("print_interval", 10, "Print train metrics every n mini-batches.")
16 | tf.app.flags.DEFINE_bool("debug", False, "If true, eval is not run before training.")
17 | tf.app.flags.DEFINE_bool("test_mode", False, "If true, only run test set.")
18 |
19 | logger = logging.getLogger("tf")
20 | logger.setLevel(logging.DEBUG)
21 | # create console handler
22 | ch = logging.StreamHandler()
23 | ch.setLevel(logging.DEBUG)
24 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s')
25 | ch.setFormatter(formatter)
26 | logger.addHandler(ch)
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32
| # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | -------------------------------------------------------------------------------- /nn/utils/viz.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | def gallery(array, ncols=3): 5 | nindex, height, width, intensity = array.shape 6 | 7 | bordered = 0.5*np.ones([nindex, height+2, width+2, intensity]) 8 | for i in range(nindex): 9 | bordered[i,1:-1,1:-1,:] = array[i] 10 | 11 | array = bordered 12 | nindex, height, width, intensity = array.shape 13 | 14 | nrows = nindex//ncols 15 | assert nindex == nrows*ncols 16 | # want result.shape = (height*nrows, width*ncols, intensity) 17 | result = (array.reshape(nrows, ncols, height, width, intensity) 18 | .swapaxes(1,2) 19 | .reshape(height*nrows, width*ncols, intensity)) 20 | return result 21 | 22 | def gif(filename, array, fps=10, scale=1.0): 23 | from moviepy.editor import ImageSequenceClip 24 | """Creates a gif given a stack of images using moviepy 25 | Notes 26 | ----- 27 | works with current Github version of moviepy (not the pip version) 28 | https://github.com/Zulko/moviepy/commit/d4c9c37bc88261d8ed8b5d9b7c317d13b2cdf62e 29 | Usage 30 | ----- 31 | >>> X = randn(100, 64, 64) 32 | >>> gif('test.gif', X) 33 | Parameters 34 | ---------- 35 | filename : string 36 | The filename of the gif to write to 37 | array : array_like 38 | A numpy array that contains a sequence of images 39 | fps : int 40 | frames per second (default: 10) 41 | scale : float 42 | how much to rescale each image by (default: 1.0) 43 | """ 44 | 45 | # ensure that the file has the .gif extension 46 | fname, _ = os.path.splitext(filename) 47 | filename = fname + '.gif' 48 | 49 | # copy into the color dimension if the images are black and white 50 | if array.ndim == 3: 51 | array = array[..., np.newaxis] * np.ones(3) 52 | 53 | # make the moviepy clip 54 | clip = ImageSequenceClip(list(array), fps=fps).resize(scale) 55 | clip.write_gif(filename, fps=fps) 56 | return clip -------------------------------------------------------------------------------- /nn/datasets/iterators.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | class DataIterator: 7 | 8 | def __init__(self, X, Y=None): 9 | 
self.X = X
10 |         self.Y = Y
11 |
12 |         self.num_examples = self.X.shape[0]
13 |         self.epochs_completed = 0
14 |         self.indices = np.arange(self.num_examples)
15 |         self.reset_iteration()
16 |
17 |     def reset_iteration(self):
18 |         np.random.shuffle(self.indices)
19 |         self.start_idx = 0
20 |
21 |     def get_epoch(self):
22 |         return self.epochs_completed
23 |
24 |     def reset_epoch(self):
25 |         self.reset_iteration()
26 |         self.epochs_completed = 0
27 |
28 |     def next_batch(self, batch_size, data_type="train", shuffle=True):
29 |         assert data_type in ["train", "val", "test"], \
30 |             "data_type must be 'train', 'val', or 'test'."
31 |
32 |         idx = self.indices[self.start_idx:self.start_idx + batch_size]
33 |
34 |         batch_x = self.X[idx]
35 |         batch_y = self.Y[idx] if self.Y is not None else self.Y
36 |         self.start_idx += batch_size
37 |
38 |         if self.start_idx + batch_size > self.num_examples:
39 |             self.reset_iteration()
40 |             self.epochs_completed += 1
41 |
42 |         return (batch_x, batch_y)
43 |
44 |     def sample_random_batch(self, batch_size):
45 |         start_idx = np.random.randint(0, self.num_examples - batch_size)
46 |         batch_x = self.X[start_idx:start_idx + batch_size]
47 |         batch_y = self.Y[start_idx:start_idx + batch_size] if self.Y is not None else self.Y
48 |
49 |         return (batch_x, batch_y)
50 |
51 |
52 | def get_iterators(file, conv=False, datapoints=0):
53 |     data = np.load(file)
54 |     if conv:
55 |         img_shape = data["train_x"][0,0].shape
56 |     else:
57 |         img_shape = data["train_x"][0,0].flatten().shape
58 |     # datapoints=0 uses the full training set; otherwise keep only the first `datapoints` sequences
59 |     train_x = data["train_x"][:datapoints] if datapoints > 0 else data["train_x"]
60 |     train_it = DataIterator(X=train_x.reshape(train_x.shape[:2]+img_shape)/255)
61 |     valid_it = DataIterator(X=data["valid_x"].reshape(data["valid_x"].shape[:2]+img_shape)/255)
62 |     test_it = DataIterator(X=data["test_x"].reshape(data["test_x"].shape[:2]+img_shape)/255)
63 |     return train_it, valid_it, test_it
64 |
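A minimal sketch of how `DataIterator` is meant to be driven (the random array here is stand-in data; real arrays come from `get_iterators` above and are shaped `[num_sequences, seq_len, height, width, channels]`):

```python
import numpy as np
from nn.datasets.iterators import DataIterator

X = np.random.rand(20, 12, 32, 32, 3)   # 20 sequences of 12 frames, values already in [0, 1]
it = DataIterator(X=X)

batch_x, batch_y = it.next_batch(batch_size=5)
assert batch_x.shape == (5, 12, 32, 32, 3)
assert batch_y is None                  # no targets were supplied
print(it.get_epoch())                   # epochs completed so far
```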
--------------------------------------------------------------------------------
/runners/run_physics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import inspect
4 | import tensorflow as tf
5 | from nn.network import physics_models
6 | from nn.utils.misc import classes_in_module
7 | from nn.datasets.iterators import get_iterators
8 | import runners.run_base
9 |
10 | tf.app.flags.DEFINE_string("task", "", "Type of task.")
11 | tf.app.flags.DEFINE_string("model", "PhysicsNet", "Model to use.")
12 | tf.app.flags.DEFINE_integer("recurrent_units", 100, "Number of units for each lstm, if using black-box dynamics.")
13 | tf.app.flags.DEFINE_integer("lstm_layers", 1, "Number of lstm cells to use, if using black-box dynamics.")
14 | tf.app.flags.DEFINE_string("cell_type", "", "Type of ODE cell to use.")
15 | tf.app.flags.DEFINE_string("encoder_type", "conv_encoder", "Type of encoder to use.")
16 | tf.app.flags.DEFINE_string("decoder_type", "conv_st_decoder", "Type of decoder to use.")
17 |
18 | tf.app.flags.DEFINE_float("autoencoder_loss", 0.0, "Autoencoder loss weighing.")
19 | tf.app.flags.DEFINE_bool("alt_vel", False, "Whether to use linear velocity computation.")
20 | tf.app.flags.DEFINE_bool("color", False, "Whether images are rgb or grayscale.")
21 | tf.app.flags.DEFINE_integer("datapoints", 0, "How many datapoints from the dataset to use. \
22 |     Useful for measuring data efficiency. default=0 uses all data.")
23 |
24 | FLAGS = tf.app.flags.FLAGS
25 |
26 | model_classes = classes_in_module(physics_models)
27 | Model = model_classes[FLAGS.model]
28 |
29 | data_file, test_data_file, cell_type, seq_len, test_seq_len, input_steps, pred_steps, input_size = {
30 |     "bouncing_balls": (
31 |         "bouncing/color_bounce_vx8_vy8_sl12_r2.npz",
32 |         "bouncing/color_bounce_vx8_vy8_sl30_r2.npz",
33 |         "bouncing_ode_cell",
34 |         12, 30, 4, 6, 32*32),
35 |     "spring_color": (
36 |         "spring_color/color_spring_vx8_vy8_sl12_r2_k4_e6.npz",
37 |         "spring_color/color_spring_vx8_vy8_sl30_r2_k4_e6.npz",
38 |         "spring_ode_cell",
39 |         12, 30, 4, 6, 32*32),
40 |     "spring_color_half": (
41 |         "spring_color_half/color_spring_vx4_vy4_sl12_r2_k4_e6_halfpane.npz",
42 |         "spring_color_half/color_spring_vx4_vy4_sl30_r2_k4_e6_halfpane.npz",
43 |         "spring_ode_cell",
44 |         12, 30, 4, 6, 32*32),
45 |     "3bp_color": (
46 |         "3bp_color/color_3bp_vx2_vy2_sl20_r2_g60_m1_dt05.npz",
47 |         "3bp_color/color_3bp_vx2_vy2_sl40_r2_g60_m1_dt05.npz",
48 |         "gravity_ode_cell",
49 |         20, 40, 4, 12, 36*36),
50 |     "mnist_spring_color": (
51 |         "mnist_spring_color/color_mnist_spring_vx8_vy8_sl12_r2_k2_e12.npz",
52 |         "mnist_spring_color/color_mnist_spring_vx8_vy8_sl30_r2_k2_e12.npz",
53 |         "spring_ode_cell",
54 |         12, 30, 3, 7, 64*64)
55 | }[FLAGS.task]
56 |
57 | if __name__ == "__main__":
58 |     if not FLAGS.test_mode:
59 |         network = Model(FLAGS.task, FLAGS.recurrent_units, FLAGS.lstm_layers, cell_type,
60 |                         seq_len, input_steps, pred_steps,
61 |                         FLAGS.autoencoder_loss, FLAGS.alt_vel, FLAGS.color,
62 |                         input_size, FLAGS.encoder_type, FLAGS.decoder_type)
63 |
64 |         network.build_graph()
65 |         network.build_optimizer(FLAGS.base_lr, FLAGS.optimizer, FLAGS.anneal_lr)
66 |         network.initialize_graph(FLAGS.save_dir, FLAGS.use_ckpt, FLAGS.ckpt_dir)
67 |
68 |         data_iterators = get_iterators(
69 |             os.path.join(
70 |                 os.path.dirname(os.path.realpath(__file__)),
71 |                 "../data/datasets/%s"%data_file), conv=True, datapoints=FLAGS.datapoints)
72 |         network.get_data(data_iterators)
73 |         network.train(FLAGS.epochs, FLAGS.batch_size, FLAGS.save_every_n_epochs, FLAGS.eval_every_n_epochs,
74 |                       FLAGS.print_interval, FLAGS.debug)
75 |
76 |         tf.reset_default_graph()
77 |
78 |     network = Model(FLAGS.task, FLAGS.recurrent_units, FLAGS.lstm_layers, cell_type,
79 |                     test_seq_len, input_steps, pred_steps,
80 |                     FLAGS.autoencoder_loss, FLAGS.alt_vel, FLAGS.color,
81 |                     input_size, FLAGS.encoder_type, FLAGS.decoder_type)
82 |
83 |     network.build_graph()
84 |     network.build_optimizer(FLAGS.base_lr, FLAGS.optimizer, FLAGS.anneal_lr)
85 |     network.initialize_graph(FLAGS.save_dir, True, FLAGS.ckpt_dir)
86 |
87 |     data_iterators = get_iterators(
88 |         os.path.join(
89 |             os.path.dirname(os.path.realpath(__file__)),
90 |             "../data/datasets/%s"%test_data_file), conv=True, datapoints=FLAGS.datapoints)
91 |     network.get_data(data_iterators)
92 |     network.train(0, FLAGS.batch_size, FLAGS.save_every_n_epochs, FLAGS.eval_every_n_epochs,
93 |                   FLAGS.print_interval, FLAGS.debug)
94 |
--------------------------------------------------------------------------------
/nn/network/blocks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | """ Useful subnetwork components """
5 |
6 |
7 | def unet(inp, base_channels, out_channels, upsamp=True):
8 |     h = inp
9 |     h = tf.layers.conv2d(h, base_channels, 3, activation=tf.nn.relu, padding="SAME")
10 |     h1 = tf.layers.conv2d(h, base_channels, 3, activation=tf.nn.relu, padding="SAME")
11 |     h = 
tf.layers.max_pooling2d(h1, 2, 2) 12 | h = tf.layers.conv2d(h, base_channels*2, 3, activation=tf.nn.relu, padding="SAME") 13 | h2 = tf.layers.conv2d(h, base_channels*2, 3, activation=tf.nn.relu, padding="SAME") 14 | h = tf.layers.max_pooling2d(h2, 2, 2) 15 | h = tf.layers.conv2d(h, base_channels*4, 3, activation=tf.nn.relu, padding="SAME") 16 | h3 = tf.layers.conv2d(h, base_channels*4, 3, activation=tf.nn.relu, padding="SAME") 17 | h = tf.layers.max_pooling2d(h3, 2, 2) 18 | h = tf.layers.conv2d(h, base_channels*8, 3, activation=tf.nn.relu, padding="SAME") 19 | h4 = tf.layers.conv2d(h, base_channels*8, 3, activation=tf.nn.relu, padding="SAME") 20 | if upsamp: 21 | h = tf.image.resize_bilinear(h4, h3.get_shape()[1:3]) 22 | h = tf.layers.conv2d(h, base_channels*2, 3, activation=None, padding="SAME") 23 | else: 24 | h = tf.layers.conv2d_transpose(h, base_channels*4, 3, 2, activation=None, padding="SAME") 25 | h = tf.concat([h, h3], axis=-1) 26 | h = tf.layers.conv2d(h, base_channels*4, 3, activation=tf.nn.relu, padding="SAME") 27 | h = tf.layers.conv2d(h, base_channels*4, 3, activation=tf.nn.relu, padding="SAME") 28 | if upsamp: 29 | h = tf.image.resize_bilinear(h, h2.get_shape()[1:3]) 30 | h = tf.layers.conv2d(h, base_channels*2, 3, activation=None, padding="SAME") 31 | else: 32 | h = tf.layers.conv2d_transpose(h, base_channels*2, 3, 2, activation=None, padding="SAME") 33 | h = tf.concat([h, h2], axis=-1) 34 | h = tf.layers.conv2d(h, base_channels*2, 3, activation=tf.nn.relu, padding="SAME") 35 | h = tf.layers.conv2d(h, base_channels*2, 3, activation=tf.nn.relu, padding="SAME") 36 | if upsamp: 37 | h = tf.image.resize_bilinear(h, h1.get_shape()[1:3]) 38 | h = tf.layers.conv2d(h, base_channels*2, 3, activation=None, padding="SAME") 39 | else: 40 | h = tf.layers.conv2d_transpose(h, base_channels, 3, 2, activation=None, padding="SAME") 41 | h = tf.concat([h, h1], axis=-1) 42 | h = tf.layers.conv2d(h, base_channels, 3, activation=tf.nn.relu, padding="SAME") 43 | h = tf.layers.conv2d(h, base_channels, 3, activation=tf.nn.relu, padding="SAME") 44 | 45 | h = tf.layers.conv2d(h, out_channels, 1, activation=None, padding="SAME") 46 | return h 47 | 48 | 49 | def shallow_unet(inp, base_channels, out_channels, upsamp=True): 50 | h = inp 51 | h = tf.layers.conv2d(h, base_channels, 3, activation=tf.nn.relu, padding="SAME") 52 | h1 = tf.layers.conv2d(h, base_channels, 3, activation=tf.nn.relu, padding="SAME") 53 | h = tf.layers.max_pooling2d(h1, 2, 2) 54 | h = tf.layers.conv2d(h, base_channels*2, 3, activation=tf.nn.relu, padding="SAME") 55 | h2 = tf.layers.conv2d(h, base_channels*2, 3, activation=tf.nn.relu, padding="SAME") 56 | h = tf.layers.max_pooling2d(h2, 2, 2) 57 | h = tf.layers.conv2d(h, base_channels*4, 3, activation=tf.nn.relu, padding="SAME") 58 | h = tf.layers.conv2d(h, base_channels*4, 3, activation=tf.nn.relu, padding="SAME") 59 | #h = tf.concat([h, h3], axis=-1) 60 | #h = tf.layers.conv2d(h, base_channels*4, 3, activation=tf.nn.relu, padding="SAME") 61 | #h = tf.layers.conv2d(h, base_channels*4, 3, activation=tf.nn.relu, padding="SAME") 62 | if upsamp: 63 | h = tf.image.resize_bilinear(h, h2.get_shape()[1:3]) 64 | h = tf.layers.conv2d(h, base_channels*2, 3, activation=None, padding="SAME") 65 | else: 66 | h = tf.layers.conv2d_transpose(h, base_channels*2, 3, 2, activation=None, padding="SAME") 67 | h = tf.concat([h, h2], axis=-1) 68 | h = tf.layers.conv2d(h, base_channels*2, 3, activation=tf.nn.relu, padding="SAME") 69 | h = tf.layers.conv2d(h, base_channels*2, 3, activation=tf.nn.relu, 
padding="SAME") 70 | if upsamp: 71 | h = tf.image.resize_bilinear(h, h1.get_shape()[1:3]) 72 | h = tf.layers.conv2d(h, base_channels*2, 3, activation=None, padding="SAME") 73 | else: 74 | h = tf.layers.conv2d_transpose(h, base_channels, 3, 2, activation=None, padding="SAME") 75 | h = tf.concat([h, h1], axis=-1) 76 | h = tf.layers.conv2d(h, base_channels, 3, activation=tf.nn.relu, padding="SAME") 77 | h = tf.layers.conv2d(h, base_channels, 3, activation=tf.nn.relu, padding="SAME") 78 | 79 | h = tf.layers.conv2d(h, out_channels, 1, activation=None, padding="SAME") 80 | return h 81 | 82 | 83 | def variable_from_network(shape): 84 | # Produces a variable from a vector of 1's. 85 | # Improves learning speed of contents and masks. 86 | var = tf.ones([1,10]) 87 | var = tf.layers.dense(var, 200, activation=tf.tanh) 88 | var = tf.layers.dense(var, np.prod(shape), activation=None) 89 | var = tf.reshape(var, shape) 90 | return var 91 | -------------------------------------------------------------------------------- /nn/network/cells.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | class ode_cell(tf.nn.rnn_cell.BasicRNNCell): 6 | 7 | @property 8 | def state_size(self): 9 | return self._num_units, self._num_units 10 | 11 | def zero_state(self, batch_size, dtype): 12 | x_0 = tf.zeros([batch_size, self._num_units], dtype=dtype) 13 | v_0 = tf.zeros([batch_size, self._num_units], dtype=dtype) 14 | return x_0, v_0 15 | 16 | 17 | class bouncing_ode_cell(ode_cell): 18 | """ Assumes there are 2 objects """ 19 | 20 | def build(self, inputs_shape): 21 | if inputs_shape[-1] is None: 22 | raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" 23 | % str(inputs_shape)) 24 | 25 | input_depth = inputs_shape[-1] 26 | h_depth = self._num_units 27 | assert h_depth == input_depth 28 | 29 | self.dt = self.add_variable("dt_x", shape=[], initializer=tf.constant_initializer(0.3), trainable=False) 30 | self.built = True 31 | 32 | def call(self, poss, vels): 33 | poss = tf.split(poss, 2, 1) 34 | vels = tf.split(vels, 2, 1) 35 | for i in range(5): 36 | poss[0] = poss[0] + self.dt/5*vels[0] 37 | poss[1] = poss[1] + self.dt/5*vels[1] 38 | 39 | for j in range(2): 40 | # Compute wall collisions. Image boundaries are hard-coded. 
41 |                 vels[j] = tf.where(tf.greater(poss[j]+2, 32), -vels[j], vels[j])
42 |                 vels[j] = tf.where(tf.greater(0.0, poss[j]-2), -vels[j], vels[j])
43 |                 poss[j] = tf.where(tf.greater(poss[j]+2, 32), 32-(poss[j]+2-32)-2, poss[j])
44 |                 poss[j] = tf.where(tf.greater(0.0, poss[j]-2), -(poss[j]-2)+2, poss[j])
45 |
46 |         poss = tf.concat(poss, axis=1)
47 |         vels = tf.concat(vels, axis=1)
48 |         return poss, vels
49 |
50 |
51 | class spring_ode_cell(ode_cell):
52 |     """ Assumes there are 2 objects """
53 |
54 |     def build(self, inputs_shape):
55 |         if inputs_shape[-1] is None:
56 |             raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
57 |                              % str(inputs_shape))
58 |
59 |         input_depth = inputs_shape[-1]
60 |         h_depth = self._num_units
61 |         assert h_depth == input_depth
62 |
63 |         self.dt = self.add_variable("dt_x", shape=[], initializer=tf.constant_initializer(0.3), trainable=False)
64 |         self.k = self.add_variable("log_k", shape=[], initializer=tf.constant_initializer(np.log(1.0)), trainable=True)
65 |         self.equil = self.add_variable("log_l", shape=[], initializer=tf.constant_initializer(np.log(1.0)), trainable=True)
66 |         self.built = True
67 |
68 |     def call(self, poss, vels):
69 |         poss = tf.split(poss, 2, 1)
70 |         vels = tf.split(vels, 2, 1)
71 |         for i in range(5):
72 |             norm = tf.sqrt(tf.abs(tf.reduce_sum(tf.square(poss[0]-poss[1]), axis=-1, keepdims=True)))
73 |             direction = (poss[0]-poss[1])/(norm+1e-4)
74 |             F = tf.exp(self.k)*(norm-2*tf.exp(self.equil))*direction
75 |             vels[0] = vels[0] - self.dt/5*F
76 |             vels[1] = vels[1] + self.dt/5*F
77 |
78 |             poss[0] = poss[0] + self.dt/5*vels[0]
79 |             poss[1] = poss[1] + self.dt/5*vels[1]
80 |
81 |         poss = tf.concat(poss, axis=1)
82 |         vels = tf.concat(vels, axis=1)
83 |         return poss, vels
84 |
85 |
86 | class gravity_ode_cell(ode_cell):
87 |     """ Assumes there are 3 objects """
88 |
89 |     def build(self, inputs_shape):
90 |         if inputs_shape[-1] is None:
91 |             raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
92 |                              % str(inputs_shape))
93 |
94 |         input_depth = inputs_shape[-1]
95 |         h_depth = self._num_units
96 |         assert h_depth == input_depth
97 |
98 |         self.dt = self.add_variable("dt_x", shape=[], initializer=tf.constant_initializer(0.5), trainable=False)
99 |         self.g = self.add_variable("log_g", shape=[], initializer=tf.constant_initializer(np.log(1.0)), trainable=True)
100 |         self.m = self.add_variable("log_m", shape=[], initializer=tf.constant_initializer(np.log(1.0)), trainable=False)
101 |         self.A = tf.exp(self.g)*tf.exp(2*self.m)
102 |         self.built = True
103 |
104 |     def call(self, poss, vels):
105 |         for i in range(5):
106 |             vecs = [poss[:,0:2]-poss[:,2:4], poss[:,2:4]-poss[:,4:6], poss[:,4:6]-poss[:,0:2]]
107 |             norms = [tf.sqrt(tf.clip_by_value(tf.reduce_sum(tf.square(vec), axis=-1, keepdims=True), 1e-1, 1e5)) for vec in vecs]
108 |             F = [vec/tf.pow(tf.clip_by_value(norm, 1, 170), 3) for vec, norm in zip(vecs, norms)]
109 |             F = [F[0]-F[2], F[1]-F[0], F[2]-F[1]]
110 |             F = [-self.A*f for f in F]
111 |             F = tf.concat(F, axis=1)
112 |             vels = vels + self.dt/5*F
113 |             poss = poss + self.dt/5*vels
114 |
115 |         poss = tf.concat(poss, axis=1)
116 |         vels = tf.concat(vels, axis=1)
117 |         return poss, vels
118 |
--------------------------------------------------------------------------------
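For intuition, a minimal numpy sketch of the dynamics implemented by `spring_ode_cell` above: a Hooke force between two unit-mass points, integrated with 5 forward-Euler substeps of `dt/5`. Note the cell stores log-parameters (`log_k`, `log_l`) and exponentiates them to keep stiffness and rest length positive; this sketch takes `k` and `equil` directly.

```python
import numpy as np

def spring_step(p0, p1, v0, v1, k=1.0, equil=1.0, dt=0.3, substeps=5):
    """One cell step for two unit-mass points joined by a spring
    (rest length is 2*equil, as in spring_ode_cell.call)."""
    for _ in range(substeps):
        diff = p0 - p1
        norm = np.linalg.norm(diff) + 1e-4
        F = k * (norm - 2 * equil) * diff / norm   # Hooke's law along the spring axis
        v0 = v0 - dt / substeps * F
        v1 = v1 + dt / substeps * F
        p0 = p0 + dt / substeps * v0
        p1 = p1 + dt / substeps * v1
    return p0, p1, v0, v1

p0, p1, v0, v1 = spring_step(np.array([10., 16.]), np.array([20., 16.]),
                             np.zeros(2), np.zeros(2))
```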
/README.md:
--------------------------------------------------------------------------------
1 | # Physics-as-Inverse-Graphics
2 |
3 | This repo contains the code for the paper Physics-as-Inverse-Graphics: Unsupervised Physical Parameter Estimation from Video (https://arxiv.org/abs/1905.11169).
4 |
5 | ## Running experiments
6 |
7 | To train, run:
8 |
9 | ```
10 | PYTHONPATH=. python runners/run_physics.py --task=spring_color --model=PhysicsNet --epochs=500
11 | --batch_size=100 --save_dir=<save_dir> --autoencoder_loss=3.0 --base_lr=3e-4 --anneal_lr=true
12 | --color=true --eval_every_n_epochs=10 --print_interval=100 --debug=false --use_ckpt=false
13 | ```
14 |
15 | This will automatically run on the test set (evaluation with extrapolation range) at the end of training.
16 | To run only evaluation on a previously trained model, use the extra flags `--test_mode` and `--use_ckpt`:
17 |
18 | ```
19 | PYTHONPATH=. python runners/run_physics.py --task=spring_color --model=PhysicsNet --epochs=500
20 | --batch_size=100 --save_dir=<save_dir> --autoencoder_loss=3.0 --base_lr=3e-4
21 | --color=true --eval_every_n_epochs=10 --print_interval=100 --debug=false
22 | --use_ckpt=true --test_mode=true
23 | ```
24 |
25 | This will use the checkpoint found in `<save_dir>`. To evaluate a checkpoint from a different folder use `--ckpt_dir`:
26 |
27 | ```
28 | PYTHONPATH=. python runners/run_physics.py --task=spring_color --model=PhysicsNet --epochs=500
29 | --batch_size=100 --save_dir=<save_dir> --autoencoder_loss=3.0 --base_lr=3e-4
30 | --color=true --eval_every_n_epochs=10 --print_interval=100 --debug=false
31 | --use_ckpt=true --test_mode=true --ckpt_dir=<ckpt_dir>
32 | ```
33 |
34 | To keep training a model from a checkpoint, simply use the same command as above, but with `--test_mode=false`. Note that in this case `base_lr` will be used as the starting learning rate - there is no global learning rate variable saved in the checkpoint - so if you restart training after annealing was applied, be sure to change the `base_lr` accordingly.
35 |
36 | Notes on flags, hyperparameters, and general training behavior:
37 | * Using `--anneal_lr=true` will reduce the base learning rate by a factor of 5 after 75% of the epochs are completed. To change this find the corresponding code in `nn/network/base.py`, in the class method `BaseNet.train()`.
38 | * When using `autoencoder_loss`, the encoder and decoder parts of the model are learned fairly early in training. The rest of training mostly improves the physical parameters, but this can take a long time. I recommend training between 500 and 1000 epochs (higher for the `3bp_color` dataset, lower for the `spring` datasets).
39 |
40 |
41 | ## Tasks
42 |
43 | There are currently 5 tasks implemented in this repo:
44 |
45 | * `bouncing_balls`: Two colored balls bouncing off the image boundaries (here there are no learnable physical parameters).
46 | * `spring_color`: Two colored balls connected by a spring.
47 | * `spring_color_half`: Same as above, but in the input and prediction range the balls never leave half of the image. They only move to the other half of the image in the extrapolation range of the test set.
48 | * `mnist_spring_color`: Two colored MNIST digits connected by a spring, in a CIFAR background.
49 | * `3bp_color`: Three colored balls connected by gravitational force (`3bp` stands for 3-body-problem).
50 |
51 | The input, prediction and extrapolation steps are preset for each task, and correspond to the values described in the paper (see 1st paragraph of Section 4.1).
52 |
53 | ## Data
54 |
55 | The datasets for the tasks above can be downloaded from [this Google Drive](https://drive.google.com/open?id=16uvdhZiv2CkoDDDNGRG4l_T7LEZXzfyA). These datasets should be placed in a folder called `/data/datasets` in order to be automatically fetched by the code.
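As a sanity check after downloading, each `.npz` file can be inspected directly; the loader in `nn/datasets/iterators.py` expects the keys `train_x`, `valid_x` and `test_x`, each holding an array of shape `[num_sequences, seq_len, height, width, channels]` with raw 0-255 pixel values (they are divided by 255 at load time). A minimal sketch, assuming the `spring_color` file has been placed as described above:

```python
import numpy as np

data = np.load("data/datasets/spring_color/color_spring_vx8_vy8_sl12_r2_k4_e6.npz")
for split in ["train_x", "valid_x", "test_x"]:
    print(split, data[split].shape, data[split].dtype)
```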
56 |
57 | ## Hyperparameters
58 |
59 | For the tasks above, the recommended `base_lr` and `autoencoder_loss` parameters are:
60 | * `bouncing_balls`: `--base_lr=3e-4 --autoencoder_loss=2.0`
61 | * `spring_color`: `--base_lr=6e-4 --autoencoder_loss=3.0`
62 | * `spring_color_half`: `--base_lr=6e-4 --autoencoder_loss=3.0`
63 | * `mnist_spring_color`: `--base_lr=6e-4 --autoencoder_loss=3.0`
64 | * `3bp_color`: `--base_lr=1e-3 --autoencoder_loss=5.0`
65 |
66 | ## Interpreting results in the `log.txt` file
67 |
68 | When tracking training progress from the `log.txt` file, a value of `eval_recons_loss` below 1.5 indicates that the encoder and decoder have correctly discovered the objects in the scene, and a value of `eval_pred_loss` below 3.0 and 30.0 (for the balls and mnist datasets, respectively) indicates that the velocity estimator and the physical parameters have been learned correctly. Due to the dependency on initialization, it is possible that even with the hyperparameters above the model gets stuck in a local minimum and never gets below the aforementioned values, either by failing to discover all the objects or by failing to learn the correct physical parameters/velocity estimator (this is common in unsupervised object discovery methods). I am working on improving convergence stability.
69 |
70 | ## Reading other results
71 |
72 | The `example%d.jpg` files show random rollouts from the validation/test set. The top row corresponds to the model prediction, the middle row to the ground-truth, and the bottom row to the reconstructed frames (as used by the autoencoder loss - this can be used to evaluate whether the objects have been discovered even if the dynamics have not been learned yet).
73 |
74 | The `templates.jpg` file shows the learned contents (top) and masks (bottom).
75 |
76 | ## Note on reproducibility
77 |
78 | This model shows seed dependency when it comes to discovering the objects. For datasets with two objects it works most of the time, whereas for the `3bp_color` dataset it is harder to find a seed that works. It is possible this might be solved by tweaking hyperparameters and network structure, but we have not explored that extensively.
79 |
--------------------------------------------------------------------------------
/nn/network/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import logging
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from nn.utils.misc import log_metrics, zipdir
9 |
10 | logger = logging.getLogger("tf")
11 | root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..")
12 |
13 | OPTIMIZERS = {
14 |     "adam": tf.train.AdamOptimizer,
15 |     "rmsprop": tf.train.RMSPropOptimizer,
16 |     "momentum": lambda x: tf.train.MomentumOptimizer(x, 0.9),
17 |     "sgd": tf.train.GradientDescentOptimizer
18 | }
19 |
20 |
21 | class BaseNet:
22 |
23 |     def __init__(self):
24 |         self.train_metrics = {}
25 |         self.eval_metrics = {}
26 |
27 |         # Extra functions to be run at train/valid/test time
28 |         # that can be defined by the children
29 |         # Should have the format:
30 |         # self.extra_valid_fns = [
31 |         #     (valid_fn1, args, kwargs),
32 |         #     ...
33 | # ] 34 | self.extra_train_fns = [] 35 | self.extra_valid_fns = [] 36 | self.extra_test_fns = [] 37 | 38 | self.sess = tf.Session() 39 | 40 | def run_extra_fns(self, type): 41 | if type == "train": 42 | extra_fns = self.extra_train_fns 43 | elif type == "valid": 44 | extra_fns = self.extra_valid_fns 45 | else: 46 | extra_fns = self.extra_test_fns 47 | 48 | for fn, args, kwargs in extra_fns: 49 | fn(*args, **kwargs) 50 | 51 | def feedforward(self): 52 | raise NotImplementedError 53 | 54 | def compute_loss(self): 55 | raise NotImplementedError 56 | 57 | def build_graph(self): 58 | raise NotImplementedError 59 | 60 | def get_data(self, data_iterators): 61 | self.train_iterator, self.valid_iterator, self.test_iterator = data_iterators 62 | 63 | def get_iterator(self, type): 64 | if type == "train": 65 | eval_iterator = self.train_iterator 66 | elif type == "valid": 67 | eval_iterator = self.valid_iterator 68 | elif type == "test": 69 | eval_iterator = self.test_iterator 70 | return eval_iterator 71 | 72 | def initialize_graph(self, 73 | save_dir, 74 | use_ckpt, 75 | ckpt_dir=""): 76 | 77 | self.save_dir = save_dir 78 | self.saver = tf.train.Saver() 79 | if os.path.exists(save_dir): 80 | if use_ckpt: 81 | restore = True 82 | if ckpt_dir: 83 | restore_dir = ckpt_dir 84 | else: 85 | restore_dir = save_dir 86 | else: 87 | logger.info("Folder exists, deleting...") 88 | shutil.rmtree(save_dir) 89 | os.makedirs(save_dir) 90 | restore = False 91 | else: 92 | os.makedirs(save_dir) 93 | if use_ckpt: 94 | restore = True 95 | restore_dir = ckpt_dir 96 | else: 97 | restore = False 98 | 99 | if restore: 100 | self.saver.restore(self.sess, os.path.join(restore_dir, "model.ckpt")) 101 | self.sess.run(self.lr.assign(self.base_lr)) 102 | else: 103 | self.sess.run(tf.global_variables_initializer()) 104 | 105 | def build_optimizer(self, base_lr, optimizer="adam", anneal_lr=True): 106 | self.base_lr = base_lr 107 | self.anneal_lr = anneal_lr 108 | self.lr = tf.Variable(base_lr, trainable=False, name="base_lr") 109 | self.optimizer = OPTIMIZERS[optimizer](self.lr) 110 | self.train_op = self.optimizer.minimize(self.loss) 111 | 112 | def get_batch(self, batch_size, iterator): 113 | batch_x, batch_y = iterator.next_batch(batch_size) 114 | if batch_y is None: 115 | feed_dict = {self.input:batch_x} 116 | else: 117 | feed_dict = {self.input:batch_x, self.target:batch_y} 118 | return feed_dict, (batch_x, batch_y) 119 | 120 | def add_train_logger(self): 121 | log_path = os.path.join(self.save_dir, "log.txt") 122 | fh = logging.FileHandler(log_path) 123 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s') 124 | fh.setFormatter(formatter) 125 | logger.addHandler(fh) 126 | 127 | def train(self, 128 | epochs, 129 | batch_size, 130 | save_every_n_epochs, 131 | eval_every_n_epochs, 132 | print_interval, 133 | debug=False): 134 | 135 | self.add_train_logger() 136 | zipdir(root_path, self.save_dir) 137 | logger.info("\n".join(sys.argv)) 138 | 139 | step = 0 140 | 141 | # Run validation once before starting training 142 | if not debug and epochs > 0: 143 | valid_metrics_results = self.eval(batch_size, type='valid') 144 | log_metrics(logger, "valid - epoch=%s"%0, valid_metrics_results) 145 | 146 | for ep in range(1, epochs+1): 147 | if self.anneal_lr: 148 | if ep == int(0.75*epochs): 149 | self.sess.run(tf.assign(self.lr, self.lr/5)) 150 | while self.train_iterator.epochs_completed < ep: 151 | feed_dict, _ = self.get_batch(batch_size, self.train_iterator) 152 | results, _ = self.sess.run( 153 | 
[self.train_metrics, self.train_op], feed_dict=feed_dict) 154 | 155 | self.run_extra_fns("train") 156 | 157 | if step % print_interval == 0: 158 | log_metrics(logger, "train - iter=%s"%step, results) 159 | step += 1 160 | 161 | if ep % eval_every_n_epochs == 0: 162 | valid_metrics_results = self.eval(batch_size, type='valid') 163 | log_metrics(logger, "valid - epoch=%s"%ep, valid_metrics_results) 164 | 165 | if ep % save_every_n_epochs == 0: 166 | self.saver.save(self.sess, os.path.join(self.save_dir, "model.ckpt")) 167 | 168 | test_metrics_results = self.eval(batch_size, type='test') 169 | log_metrics(logger, "test - epoch=%s"%epochs, test_metrics_results) 170 | 171 | def eval(self, 172 | batch_size, 173 | type='valid'): 174 | 175 | eval_metrics_results = {k:[] for k in self.eval_metrics.keys()} 176 | eval_outputs = {"input":[], "output":[]} 177 | 178 | eval_iterator = self.get_iterator(type) 179 | eval_iterator.reset_epoch() 180 | 181 | while eval_iterator.get_epoch() < 1: 182 | if eval_iterator.X.shape[0] < 100: 183 | batch_size = eval_iterator.X.shape[0] 184 | feed_dict, _ = self.get_batch(batch_size, eval_iterator) 185 | fetches = {k:v for k, v in self.eval_metrics.items()} 186 | fetches["output"] = self.output 187 | fetches["input"] = self.input 188 | results = self.sess.run(fetches, feed_dict=feed_dict) 189 | 190 | for k in self.eval_metrics.keys(): 191 | eval_metrics_results[k].append(results[k]) 192 | eval_outputs["input"].append(results["input"]) 193 | eval_outputs["output"].append(results["output"]) 194 | 195 | eval_metrics_results = {k:np.mean(v, axis=0) for k,v in eval_metrics_results.items()} 196 | np.savez_compressed(os.path.join(self.save_dir, "outputs.npz"), 197 | input=np.concatenate(eval_outputs["input"], axis=0), 198 | output=np.concatenate(eval_outputs["output"], axis=0)) 199 | 200 | self.run_extra_fns(type) 201 | 202 | return eval_metrics_results 203 | -------------------------------------------------------------------------------- /nn/network/stn.py: -------------------------------------------------------------------------------- 1 | from six.moves import xrange 2 | import tensorflow as tf 3 | 4 | def stn(U, theta, out_size, name='SpatialTransformer', **kwargs): 5 | """Spatial Transformer Layer 6 | Implements a spatial transformer layer as described in [1]_. 7 | Based on [2]_ and edited by David Dao for Tensorflow. 8 | Parameters 9 | ---------- 10 | U : float 11 | The output of a convolutional net should have the 12 | shape [num_batch, height, width, num_channels]. 13 | theta: float 14 | The output of the 15 | localisation network should be [num_batch, 6]. 16 | out_size: tuple of two ints 17 | The size of the output of the network (height, width) 18 | References 19 | ---------- 20 | .. [1] Spatial Transformer Networks 21 | Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu 22 | Submitted on 5 Jun 2015 23 | .. 
[2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py 24 | Notes 25 | ----- 26 | To initialize the network to the identity transform init 27 | ``theta`` to : 28 | identity = np.array([[1., 0., 0.], 29 | [0., 1., 0.]]) 30 | identity = identity.flatten() 31 | theta = tf.Variable(initial_value=identity) 32 | """ 33 | 34 | def _repeat(x, n_repeats): 35 | with tf.variable_scope('_repeat'): 36 | rep = tf.transpose( 37 | tf.expand_dims(tf.ones([n_repeats, ]), 1), [1, 0]) 38 | rep = tf.cast(rep, 'int32') 39 | x = tf.matmul(tf.reshape(x, (-1, 1)), rep) 40 | return tf.reshape(x, [-1]) 41 | 42 | def _interpolate(im, x, y, out_size): 43 | with tf.variable_scope('_interpolate'): 44 | # constants 45 | num_batch = tf.shape(im)[0] 46 | height = tf.shape(im)[1] 47 | width = tf.shape(im)[2] 48 | channels = tf.shape(im)[3] 49 | 50 | x = tf.cast(x, 'float32') 51 | y = tf.cast(y, 'float32') 52 | height_f = tf.cast(height, 'float32') 53 | width_f = tf.cast(width, 'float32') 54 | out_height = out_size[0] 55 | out_width = out_size[1] 56 | zero = tf.zeros([], dtype='int32') 57 | max_y = tf.cast(tf.shape(im)[1] - 1, 'int32') 58 | max_x = tf.cast(tf.shape(im)[2] - 1, 'int32') 59 | 60 | #x = tf.Print(x,[x],message="x: ", summarize=1000) 61 | 62 | # scale indices from [-1, 1] to [0, width/height] 63 | x = (x + 1.0)*(width_f-1.01) / 2.0 64 | #x = tf.Print(x,[x],message="x_floor: ", summarize=1000) 65 | 66 | y = (y + 1.0)*(height_f-1.01) / 2.0 67 | 68 | # do sampling 69 | x0 = tf.cast(tf.floor(x), 'int32') 70 | #x0 = tf.Print(x0,[x0],message="x0: ", summarize=1000) 71 | 72 | x1 = x0 + 1 73 | 74 | y0 = tf.cast(tf.floor(y), 'int32') 75 | #y0 = tf.Print(y0,[y0],message="y0: ", summarize=1000) 76 | 77 | y1 = y0 + 1 78 | 79 | x0 = tf.clip_by_value(x0, zero, max_x) 80 | #x0 = tf.Print(x0,[x0],message="x0_clip: ", summarize=1000) 81 | 82 | x1 = tf.clip_by_value(x1, zero, max_x) 83 | #x1 = tf.Print(x1,[x1],message="x1_clip: ", summarize=1000) 84 | 85 | y0 = tf.clip_by_value(y0, zero, max_y) 86 | y1 = tf.clip_by_value(y1, zero, max_y) 87 | dim2 = width 88 | dim1 = width*height 89 | base = _repeat(tf.range(num_batch)*dim1, out_height*out_width) 90 | base_y0 = base + y0*dim2 91 | base_y1 = base + y1*dim2 92 | idx_a = base_y0 + x0 93 | idx_b = base_y1 + x0 94 | idx_c = base_y0 + x1 95 | idx_d = base_y1 + x1 96 | 97 | # use indices to lookup pixels in the flat image and restore 98 | # channels dim 99 | im_flat = tf.reshape(im, [-1, channels]) 100 | im_flat = tf.cast(im_flat, 'float32') 101 | Ia = tf.gather(im_flat, idx_a) 102 | Ib = tf.gather(im_flat, idx_b) 103 | Ic = tf.gather(im_flat, idx_c) 104 | Id = tf.gather(im_flat, idx_d) 105 | 106 | # and finally calculate interpolated values 107 | x0_f = tf.cast(x0, 'float32') 108 | x1_f = tf.cast(x1, 'float32') 109 | y0_f = tf.cast(y0, 'float32') 110 | y1_f = tf.cast(y1, 'float32') 111 | wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1) 112 | wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1) 113 | wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1) 114 | wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1) 115 | 116 | """ 117 | wa = tf.Print(wa,[wa],message="wa: ", summarize=1000) 118 | wb = tf.Print(wb,[wb],message="wb: ", summarize=1000) 119 | wc = tf.Print(wc,[wc],message="wc: ", summarize=1000) 120 | wd = tf.Print(wd,[wd],message="wd: ", summarize=1000) 121 | """ 122 | 123 | a = wa*Ia 124 | b = wb*Ib 125 | c = wc*Ic 126 | d = wd*Id 127 | 128 | """ 129 | a = tf.Print(a,[a],message="wa*Ia: ", summarize=1000) 130 | b = tf.Print(b,[b],message="wb*Ib: ", summarize=1000) 131 | c = 
tf.Print(c,[c],message="wc*Ic: ", summarize=1000)
132 |             d = tf.Print(d,[d],message="wd*Id: ", summarize=1000)
133 |             """
134 |
135 |             output = tf.add_n([a,b,c,d])
136 |
137 |             #output = tf.Print(output,[output],message="output1: ", summarize=1000)
138 |
139 |             return output
140 |
141 |     def _meshgrid(height, width):
142 |         with tf.variable_scope('_meshgrid'):
143 |             # This should be equivalent to:
144 |             #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
145 |             #                         np.linspace(-1, 1, height))
146 |             #  ones = np.ones(np.prod(x_t.shape))
147 |             #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
148 |             x_t = tf.matmul(tf.ones(shape=[height, 1]),
149 |                             tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
150 |             y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
151 |                             tf.ones(shape=[1, width]))
152 |
153 |             x_t_flat = tf.reshape(x_t, (1, -1))
154 |             y_t_flat = tf.reshape(y_t, (1, -1))
155 |
156 |             ones = tf.ones_like(x_t_flat)
157 |             grid = tf.concat([x_t_flat, y_t_flat, ones], 0)
158 |             return grid
159 |
160 |     def _transform(theta, input_dim, out_size):
161 |         with tf.variable_scope('_transform'):
162 |             num_batch = tf.shape(input_dim)[0]
163 |             height = tf.shape(input_dim)[1]
164 |             width = tf.shape(input_dim)[2]
165 |             num_channels = tf.shape(input_dim)[3]
166 |             theta = tf.reshape(theta, (-1, 2, 3))
167 |             theta = tf.cast(theta, 'float32')
168 |
169 |             # grid of (x_t, y_t, 1), eq (1) in ref [1]
170 |             height_f = tf.cast(height, 'float32')
171 |             width_f = tf.cast(width, 'float32')
172 |             out_height = out_size[0]
173 |             out_width = out_size[1]
174 |             grid = _meshgrid(out_height, out_width)
175 |             grid = tf.expand_dims(grid, 0)
176 |             grid = tf.reshape(grid, [-1])
177 |             grid = tf.tile(grid, [num_batch])
178 |             grid = tf.reshape(grid, [num_batch, 3, -1])
179 |
180 |             # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
181 |             T_g = tf.matmul(theta, grid)
182 |
183 |             x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
184 |             y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
185 |             x_s_flat = tf.reshape(x_s, [-1])
186 |             y_s_flat = tf.reshape(y_s, [-1])
187 |
188 |             input_transformed = _interpolate(
189 |                 input_dim, x_s_flat, y_s_flat,
190 |                 out_size)
191 |
192 |             #input_transformed = tf.Print(input_transformed,[input_transformed],message="input_transformed: ", summarize=1000)
193 |
194 |             output = tf.reshape(
195 |                 input_transformed, [num_batch, out_height, out_width, num_channels])
196 |
197 |             #output = tf.Print(output,[output],message="output: ", summarize=1000)
198 |
199 |             return output
200 |
201 |     with tf.variable_scope(name):
202 |         output = _transform(theta, U, out_size)
203 |         return output
204 |
205 |
206 | def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
207 |     """Batch Spatial Transformer Layer
208 |     Parameters
209 |     ----------
210 |     U : float
211 |         tensor of inputs [num_batch,height,width,num_channels]
212 |     thetas : float
213 |         a set of transformations for each input [num_batch,num_transforms,6]
214 |     out_size : int
215 |         the size of the output [out_height,out_width]
216 |     Returns: float
217 |         Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
218 |     """
219 |     with tf.variable_scope(name):
220 |         num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
221 |         indices = [[i]*num_transforms for i in xrange(num_batch)]
222 |         input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
223 |         return stn(input_repeated, thetas, out_size)
--------------------------------------------------------------------------------
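For orientation, a minimal sketch of calling `stn` directly (TF1 style, matching this repo). The shapes and theta values are illustrative; theta rows are laid out as `[sx, 0, tx, 0, sy, ty]`, following `_transform`'s reshape to `(-1, 2, 3)` and the identity-initialization note in the docstring above.

```python
import numpy as np
import tensorflow as tf
from nn.network.stn import stn

# Paste an 8x8 single-channel template into a 32x32 output frame.
template = tf.placeholder(tf.float32, [None, 8, 8, 1])
# Identity would be [1, 0, 0, 0, 1, 0]; tx/ty translate the sampling grid
# in normalized [-1, 1] coordinates.
theta = tf.constant([[1., 0., 0.5, 0., 1., 0.]])
out = stn(template, theta, out_size=(32, 32))

with tf.Session() as sess:
    img = sess.run(out, {template: np.ones((1, 8, 8, 1), np.float32)})
    print(img.shape)  # (1, 32, 32, 1)
```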
/nn/network/physics_models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import logging
4 | import numpy as np
5 | import tensorflow as tf
6 | from pprint import pprint
7 | import inspect
8 |
9 | from nn.network.base import BaseNet, OPTIMIZERS
10 | from nn.network.cells import bouncing_ode_cell, spring_ode_cell, gravity_ode_cell
11 | from nn.network.stn import stn
12 | from nn.network.blocks import unet, shallow_unet, variable_from_network
13 | from nn.utils.misc import log_metrics
14 | from nn.utils.viz import gallery, gif
15 | from nn.utils.math import sigmoid
16 | import matplotlib.pyplot as plt
17 | import matplotlib.cm as cm
18 | plt.switch_backend('agg')
19 |
20 | logger = logging.getLogger("tf")
21 |
22 | CELLS = {
23 |     "bouncing_ode_cell": bouncing_ode_cell,
24 |     "spring_ode_cell": spring_ode_cell,
25 |     "gravity_ode_cell": gravity_ode_cell,
26 |     "lstm": tf.nn.rnn_cell.LSTMCell
27 | }
28 |
29 | # total number of latent units for each dataset
30 | # coord_units = num_objects*num_dimensions*2
31 | COORD_UNITS = {
32 |     "bouncing_balls": 8,
33 |     "spring_color": 8,
34 |     "spring_color_half": 8,
35 |     "3bp_color": 12,
36 |     "mnist_spring_color": 8
37 | }
38 |
39 | class PhysicsNet(BaseNet):
40 |     def __init__(self,
41 |                  task="",
42 |                  recurrent_units=128,
43 |                  lstm_layers=1,
44 |                  cell_type="",
45 |                  seq_len=20,
46 |                  input_steps=3,
47 |                  pred_steps=5,
48 |                  autoencoder_loss=0.0,
49 |                  alt_vel=False,
50 |                  color=False,
51 |                  input_size=36*36,
52 |                  encoder_type="conv_encoder",
53 |                  decoder_type="conv_st_decoder"):
54 |
55 |         super(PhysicsNet, self).__init__()
56 |
57 |         assert task in COORD_UNITS
58 |         self.task = task
59 |
60 |         # Only used when using black-box dynamics (baselines)
61 |         self.recurrent_units = recurrent_units
62 |         self.lstm_layers = lstm_layers
63 |
64 |         self.cell_type = cell_type
65 |         self.cell = CELLS[self.cell_type]
66 |         self.color = color
67 |         self.conv_ch = 3 if color else 1
68 |         self.input_size = input_size
69 |
70 |         self.conv_input_shape = [int(np.sqrt(input_size))]*2+[self.conv_ch]
71 |         self.input_shape = [int(np.sqrt(input_size))]*2+[self.conv_ch] # same as conv_input_shape, just here for backward compatibility
72 |
73 |         self.encoder = {name: method for name, method in \
74 |             inspect.getmembers(self, predicate=inspect.ismethod) if "encoder" in name
75 |             }[encoder_type]
76 |         self.decoder = {name: method for name, method in \
77 |             inspect.getmembers(self, predicate=inspect.ismethod) if "decoder" in name
78 |             }[decoder_type]
79 |
80 |         self.output_shape = self.input_shape
81 |
82 |         assert seq_len > input_steps + pred_steps
83 |         assert input_steps >= 1
84 |         assert pred_steps >= 1
85 |         self.seq_len = seq_len
86 |         self.input_steps = input_steps
87 |         self.pred_steps = pred_steps
88 |         self.extrap_steps = self.seq_len-self.input_steps-self.pred_steps
89 |
90 |         self.alt_vel = alt_vel
91 |         self.autoencoder_loss = autoencoder_loss
92 |
93 |         self.coord_units = COORD_UNITS[self.task]
94 |         self.n_objs = self.coord_units//4
95 |
96 |         self.extra_valid_fns.append((self.visualize_sequence,[],{}))
97 |         self.extra_test_fns.append((self.visualize_sequence,[],{}))
98 |
99 |     def get_batch(self, batch_size, iterator):
100 |         batch_x, _ = iterator.next_batch(batch_size)
101 |         batch_len = batch_x.shape[1]
102 |         feed_dict = {self.input: batch_x}
103 |         return feed_dict, (batch_x, None)
104 |
105 |     def compute_loss(self):
106 |
107 |         # Compute reconstruction loss
108 |         recons_target = self.input[:,:self.input_steps+self.pred_steps]
109 |         recons_loss = 
tf.square(recons_target-self.recons_out) 110 | #recons_ce_loss = -(recons_target*tf.log(self.recons_out+1e-7) + (1.0-recons_target)*tf.log(1.0-self.recons_out+1e-7)) 111 | recons_loss = tf.reduce_sum(recons_loss, axis=[2,3,4]) 112 | 113 | self.recons_loss = tf.reduce_mean(recons_loss) 114 | 115 | target = self.input[:,self.input_steps:] 116 | #ce_loss = -(target*tf.log(self.output+1e-7) + (1.0-target)*tf.log(1.0-self.output+1e-7)) 117 | loss = tf.square(target-self.output) 118 | loss = tf.reduce_sum(loss, axis=[2,3,4]) 119 | 120 | # Compute prediction losses. pred_loss is used for training, extrap_loss is used for evaluation 121 | self.pred_loss = tf.reduce_mean(loss[:,:self.pred_steps]) 122 | self.extrap_loss = tf.reduce_mean(loss[:,self.pred_steps:]) 123 | 124 | train_loss = self.pred_loss 125 | if self.autoencoder_loss > 0.0: 126 | train_loss += self.autoencoder_loss*self.recons_loss 127 | 128 | eval_losses = [self.pred_loss, self.extrap_loss, self.recons_loss] 129 | return train_loss, eval_losses 130 | 131 | def build_graph(self): 132 | self.input = tf.placeholder(tf.float32, shape=[None, self.seq_len]+self.input_shape) 133 | self.output = self.conv_feedforward() 134 | 135 | self.train_loss, self.eval_losses = self.compute_loss() 136 | self.train_metrics["train_loss"] = self.train_loss 137 | self.eval_metrics["eval_pred_loss"] = self.eval_losses[0] 138 | self.eval_metrics["eval_extrap_loss"] = self.eval_losses[1] 139 | self.eval_metrics["eval_recons_loss"] = self.eval_losses[2] 140 | self.loss = self.train_loss 141 | 142 | def build_optimizer(self, base_lr, optimizer="rmsprop", anneal_lr=True): 143 | # Uncomment lines below to have different learning rates for physics and vision components 144 | 145 | self.base_lr = base_lr 146 | self.anneal_lr = anneal_lr 147 | self.lr = tf.Variable(base_lr, trainable=False, name="base_lr") 148 | self.optimizer = OPTIMIZERS[optimizer](self.lr) 149 | #self.dyn_optimizer = OPTIMIZERS[optimizer](1e-3) 150 | 151 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 152 | with tf.control_dependencies(update_ops): 153 | gvs = self.optimizer.compute_gradients(self.loss, var_list=tf.trainable_variables()) 154 | gvs = [(tf.clip_by_value(grad, -1.0, 1.0), var) for grad, var in gvs if grad is not None] 155 | self.train_op = self.optimizer.apply_gradients(gvs) 156 | 157 | # self.train_op = self.optimizer.apply_gradients([gv for gv in gvs if "cell" not in gv[1].name]) 158 | # if len([gv for gv in gvs if "cell" in gv[1].name]) > 0: 159 | # self.dyn_train_op = self.dyn_optimizer.apply_gradients([gv for gv in gvs if "cell" in gv[1].name]) 160 | # self.train_op = tf.group(self.train_op, self.dyn_train_op) 161 | 162 | def conv_encoder(self, inp, scope=None, reuse=tf.AUTO_REUSE): 163 | with tf.variable_scope(scope or tf.get_variable_scope(), reuse=reuse): 164 | with tf.variable_scope("encoder"): 165 | rang = tf.range(self.conv_input_shape[0], dtype=tf.float32) 166 | grid_x, grid_y = tf.meshgrid(rang, rang) 167 | grid = tf.concat([grid_x[:,:,None], grid_y[:,:,None]], axis=2) 168 | grid = tf.tile(grid[None,:,:,:], [tf.shape(inp)[0], 1, 1, 1]) 169 | 170 | if self.input_shape[0] < 40: 171 | h = inp 172 | h = shallow_unet(h, 8, self.n_objs, upsamp=True) 173 | 174 | h = tf.concat([h, tf.ones_like(h[:,:,:,:1])], axis=-1) 175 | h = tf.nn.softmax(h, axis=-1) 176 | self.enc_masks = h 177 | self.masked_objs = [self.enc_masks[:,:,:,i:i+1]*inp for i in range(self.n_objs)] 178 | 179 | h = tf.concat(self.masked_objs, axis=0) 180 | h = tf.reshape(h, [tf.shape(h)[0], 
self.input_shape[0]*self.input_shape[0]*self.conv_ch]) 181 | h = tf.layers.dense(h, 200, activation=tf.nn.relu) 182 | h = tf.layers.dense(h, 200, activation=tf.nn.relu) 183 | h = tf.layers.dense(h, 2, activation=None) 184 | h = tf.concat(tf.split(h, self.n_objs, 0), axis=1) 185 | h = tf.tanh(h)*(self.conv_input_shape[0]/2)+(self.conv_input_shape[0]/2) 186 | else: 187 | h = inp 188 | h = unet(h, 16, self.n_objs, upsamp=True) 189 | 190 | h = tf.concat([h, tf.ones_like(h[:,:,:,:1])], axis=-1) 191 | h = tf.nn.softmax(h, axis=-1) 192 | self.enc_masks = h 193 | self.masked_objs = [self.enc_masks[:,:,:,i:i+1]*inp for i in range(self.n_objs)] 194 | h = tf.concat(self.masked_objs, axis=0) 195 | h = tf.layers.average_pooling2d(h, 2, 2) 196 | #h = tf.reduce_mean(h, axis=-1) 197 | 198 | h = tf.layers.flatten(h) 199 | h = tf.layers.dense(h, 200, activation=tf.nn.relu) 200 | h = tf.layers.dense(h, 200, activation=tf.nn.relu) 201 | h = tf.layers.dense(h, 2, activation=None) 202 | h = tf.concat(tf.split(h, self.n_objs, 0), axis=1) 203 | h = tf.tanh(h)*(self.conv_input_shape[0]/2)+(self.conv_input_shape[0]/2) 204 | return h 205 | 206 | def vel_encoder(self, inp, scope=None, reuse=tf.AUTO_REUSE): 207 | with tf.variable_scope(scope or tf.get_variable_scope(), reuse=reuse): 208 | with tf.variable_scope("init_vel"): 209 | if self.alt_vel: 210 | # Computes velocity as a linear combination of the differences 211 | # between previous time-steps 212 | h = tf.split(inp, self.input_steps, 1) 213 | h = [h[i+1]-h[i] for i in range(self.input_steps-1)] 214 | h = tf.concat(h, axis=1) 215 | h = tf.split(h, self.n_objs, 2) 216 | h = tf.concat(h, axis=0) 217 | h = tf.reshape(h, [tf.shape(h)[0], (self.input_steps-1)*2]) 218 | h = tf.layers.dense(h, 2, activation=None) 219 | h = tf.split(h, self.n_objs, 0) 220 | h = tf.concat(h, axis=1) 221 | else: 222 | # Computes velocity using an MLP with positions as input 223 | h = tf.split(inp, self.n_objs, 2) 224 | h = tf.concat(h, axis=0) 225 | h = tf.reshape(h, [tf.shape(h)[0], self.input_steps*self.coord_units//self.n_objs//2]) 226 | h = tf.layers.dense(h, 100, activation=tf.tanh) 227 | h = tf.layers.dense(h, 100, activation=tf.tanh) 228 | h = tf.layers.dense(h, self.coord_units//self.n_objs//2, activation=None) 229 | h = tf.split(h, self.n_objs, 0) 230 | h = tf.concat(h, axis=1) 231 | return h 232 | 233 | def conv_st_decoder(self, inp, scope=None, reuse=tf.AUTO_REUSE): 234 | with tf.variable_scope(scope or tf.get_variable_scope(), reuse=reuse): 235 | with tf.variable_scope("decoder"): 236 | 237 | batch_size = tf.shape(inp)[0] 238 | tmpl_size = self.conv_input_shape[0]//2 239 | 240 | # This parameter can be played with. 241 | # Setting it to log(2.0) makes the attention window half the size, which might make 242 | # it easier for the model to discover objects in some cases. 243 | # I haven't found this to make a consistent difference though. 
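                # How sigma enters the spatial transformer: theta below encodes the
                # affine map [[sigma, 0, tx], [0, sigma, ty]] from output-image
                # coordinates to template coordinates, so a larger sigma squeezes the
                # template into a smaller window of the frame (halved for log(2.0)).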
244 | logsigma = tf.get_variable("logsigma", shape=[], initializer=tf.constant_initializer(np.log(1.0)), trainable=True) 245 | sigma = tf.exp(logsigma) 246 | 247 | template = variable_from_network([self.n_objs, tmpl_size, tmpl_size, 1]) 248 | self.template = template 249 | template = tf.tile(template, [1,1,1,3])+5 250 | 251 | contents = variable_from_network([self.n_objs, tmpl_size, tmpl_size, self.conv_ch]) 252 | self.contents = contents 253 | contents = tf.nn.sigmoid(contents) 254 | joint = tf.concat([template, contents], axis=-1) 255 | 256 | c2t = tf.convert_to_tensor 257 | out_temp_cont = [] 258 | for loc, join in zip(tf.split(inp, self.n_objs, -1), tf.split(joint, self.n_objs, 0)): 259 | theta0 = tf.tile(c2t([sigma]), [tf.shape(inp)[0]]) 260 | theta1 = tf.tile(c2t([0.0]), [tf.shape(inp)[0]]) 261 | theta2 = (self.conv_input_shape[0]/2-loc[:,0])/tmpl_size*sigma 262 | theta3 = tf.tile(c2t([0.0]), [tf.shape(inp)[0]]) 263 | theta4 = tf.tile(c2t([sigma]), [tf.shape(inp)[0]]) 264 | theta5 = (self.conv_input_shape[0]/2-loc[:,1])/tmpl_size*sigma 265 | theta = tf.stack([theta0, theta1, theta2, theta3, theta4, theta5], axis=1) 266 | 267 | out_join = stn(tf.tile(join, [tf.shape(inp)[0], 1, 1, 1]), theta, self.conv_input_shape[:2]) 268 | out_temp_cont.append(tf.split(out_join, 2, -1)) 269 | 270 | background_content = variable_from_network([1]+self.input_shape) 271 | self.background_content = tf.nn.sigmoid(background_content) 272 | background_content = tf.tile(self.background_content, [batch_size, 1, 1, 1]) 273 | contents = [p[1] for p in out_temp_cont] 274 | contents.append(background_content) 275 | self.transf_contents = contents 276 | 277 | background_mask = tf.ones_like(out_temp_cont[0][0]) 278 | masks = tf.stack([p[0]-5 for p in out_temp_cont]+[background_mask], axis=-1) 279 | masks = tf.nn.softmax(masks, axis=-1) 280 | masks = tf.unstack(masks, axis=-1) 281 | self.transf_masks = masks 282 | 283 | out = tf.add_n([m*c for m, c in zip(masks, contents)]) 284 | 285 | return out 286 | 287 | def conv_feedforward(self): 288 | with tf.variable_scope("net") as tvs: 289 | lstms = [tf.nn.rnn_cell.LSTMCell(self.recurrent_units) for i in range(self.lstm_layers)] 290 | states = [lstm.zero_state(tf.shape(self.input)[0], dtype=tf.float32) for lstm in lstms] 291 | rollout_cell = self.cell(self.coord_units//2) 292 | 293 | # Encode all the input and train frames 294 | h = tf.reshape(self.input[:,:self.input_steps+self.pred_steps], [-1]+self.input_shape) 295 | enc_pos = self.encoder(h, scope=tvs) 296 | 297 | # decode the input and pred frames 298 | recons_out = self.decoder(enc_pos, scope=tvs) 299 | 300 | self.recons_out = tf.reshape(recons_out, 301 | [tf.shape(self.input)[0], self.input_steps+self.pred_steps]+self.input_shape) 302 | self.enc_pos = tf.reshape(enc_pos, 303 | [tf.shape(self.input)[0], self.input_steps+self.pred_steps, self.coord_units//2]) 304 | 305 | if self.input_steps > 1: 306 | vel = self.vel_encoder(self.enc_pos[:,:self.input_steps], scope=tvs) 307 | else: 308 | vel = tf.zeros([tf.shape(self.input)[0], self.coord_units//2]) 309 | 310 | pos = self.enc_pos[:,self.input_steps-1] 311 | output_seq = [] 312 | pos_vel_seq = [] 313 | pos_vel_seq.append(tf.concat([pos, vel], axis=1)) 314 | 315 | # rollout ODE and decoder 316 | for t in range(self.pred_steps+self.extrap_steps): 317 | # rollout 318 | pos, vel = rollout_cell(pos, vel) 319 | 320 | # decode 321 | out = self.decoder(pos, scope=tvs) 322 | 323 | pos_vel_seq.append(tf.concat([pos, vel], axis=1)) 324 | output_seq.append(out) 325 | 326 | 
current_scope = tf.get_default_graph().get_name_scope() 327 | self.network_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 328 | scope=current_scope) 329 | logger.info(self.network_vars) 330 | 331 | output_seq = tf.stack(output_seq) 332 | pos_vel_seq = tf.stack(pos_vel_seq) 333 | output_seq = tf.transpose(output_seq, (1,0,2,3,4)) 334 | self.pos_vel_seq = tf.transpose(pos_vel_seq, (1,0,2)) 335 | return output_seq 336 | 337 | def visualize_sequence(self): 338 | batch_size = 5 339 | 340 | feed_dict, (batch_x, _) = self.get_batch(batch_size, self.test_iterator) 341 | fetches = [self.output, self.recons_out] 342 | if hasattr(self, 'pos_vel_seq'): 343 | fetches.append(self.pos_vel_seq) 344 | 345 | res = self.sess.run(fetches, feed_dict=feed_dict) 346 | output_seq = res[0] 347 | recons_seq = res[1] 348 | if hasattr(self, 'pos_vel_seq'): 349 | pos_vel_seq = res[2] 350 | output_seq = np.concatenate([batch_x[:,:self.input_steps], output_seq], axis=1) 351 | recons_seq = np.concatenate([recons_seq, np.zeros((batch_size, self.extrap_steps)+recons_seq.shape[2:])], axis=1) 352 | 353 | # Plot a grid with prediction sequences 354 | for i in range(batch_x.shape[0]): 355 | #if hasattr(self, 'pos_vel_seq'): 356 | # if i == 0 or i == 1: 357 | # logger.info(pos_vel_seq[i]) 358 | 359 | to_concat = [output_seq[i],batch_x[i],recons_seq[i]] 360 | total_seq = np.concatenate(to_concat, axis=0) 361 | 362 | total_seq = total_seq.reshape([total_seq.shape[0], 363 | self.input_shape[0], 364 | self.input_shape[1], self.conv_ch]) 365 | 366 | result = gallery(total_seq, ncols=batch_x.shape[1]) 367 | 368 | norm = plt.Normalize(0.0, 1.0) 369 | 370 | figsize = (result.shape[1]//self.input_shape[1], result.shape[0]//self.input_shape[0]) 371 | fig, ax = plt.subplots(figsize=figsize) 372 | ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm) 373 | ax.get_xaxis().set_visible(False) 374 | ax.get_yaxis().set_visible(False) 375 | fig.tight_layout() 376 | fig.savefig(os.path.join(self.save_dir, "example%d.jpg"%i)) 377 | 378 | # Make a gif from the sequences 379 | bordered_output_seq = 0.5*np.ones([batch_size, self.seq_len, 380 | self.conv_input_shape[0]+2, self.conv_input_shape[1]+2, 3]) 381 | bordered_batch_x = 0.5*np.ones([batch_size, self.seq_len, 382 | self.conv_input_shape[0]+2, self.conv_input_shape[1]+2, 3]) 383 | output_seq = output_seq.reshape([batch_size, self.seq_len]+self.input_shape) 384 | batch_x = batch_x.reshape([batch_size, self.seq_len]+self.input_shape) 385 | bordered_output_seq[:,:,1:-1,1:-1] = output_seq 386 | bordered_batch_x[:,:,1:-1,1:-1] = batch_x 387 | output_seq = bordered_output_seq 388 | batch_x = bordered_batch_x 389 | output_seq = np.concatenate(np.split(output_seq, batch_size, 0), axis=-2).squeeze() 390 | batch_x = np.concatenate(np.split(batch_x, batch_size, 0), axis=-2).squeeze() 391 | frames = np.concatenate([output_seq, batch_x], axis=1) 392 | 393 | gif(os.path.join(self.save_dir, "animation%d.gif"%i), 394 | frames*255, fps=7, scale=3) 395 | 396 | # Save extra tensors for visualization 397 | fetches = {"contents": self.contents, 398 | "templates": self.template, 399 | "background_content": self.background_content, 400 | "transf_contents": self.transf_contents, 401 | "transf_masks": self.transf_masks, 402 | "enc_masks": self.enc_masks, 403 | "masked_objs": self.masked_objs} 404 | results = self.sess.run(fetches, feed_dict=feed_dict) 405 | np.savez_compressed(os.path.join(self.save_dir, "extra_outputs.npz"), **results) 406 | contents = results["contents"] 407 | 
templates = results["templates"] 408 | contents = 1/(1+np.exp(-contents)) 409 | templates = 1/(1+np.exp(-(templates-5))) 410 | if self.conv_ch == 1: 411 | contents = np.tile(contents, [1,1,1,3]) 412 | templates = np.tile(templates, [1,1,1,3]) 413 | total_seq = np.concatenate([contents, templates], axis=0) 414 | result = gallery(total_seq, ncols=self.n_objs) 415 | fig, ax = plt.subplots(figsize=figsize) 416 | ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm) 417 | ax.get_xaxis().set_visible(False) 418 | ax.get_yaxis().set_visible(False) 419 | fig.tight_layout() 420 | fig.savefig(os.path.join(self.save_dir, "templates.jpg")) 421 | 422 | logger.info([(v.name, self.sess.run(v)) for v in tf.trainable_variables() if "ode_cell" in v.name or "sigma" in v.name]) 423 | 424 | -------------------------------------------------------------------------------- /nn/datasets/generators.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import matplotlib.cm as cm 5 | from itertools import combinations 6 | 7 | from nn.utils.viz import gallery 8 | from nn.utils.misc import rgb2gray 9 | 10 | def generate_bouncing_ball_dataset(dest, 11 | train_set_size, 12 | valid_set_size, 13 | test_set_size, 14 | seq_len, 15 | box_size): 16 | np.random.seed(0) 17 | 18 | def verify_collision(x, v): 19 | if x[0] + v[0] > box_size or x[0] + v[0] < 0.0: 20 | v[0] = -v[0] 21 | if x[1] + v[1] > box_size or x[1] + v[1] < 0.0: 22 | v[1] = -v[1] 23 | return v 24 | 25 | def generate_trajectory(steps): 26 | traj = [] 27 | x = np.random.rand(2)*box_size 28 | speed = np.random.rand()+1 29 | angle = np.random.rand()*2*np.pi 30 | v = np.array([speed*np.cos(angle), speed*np.sin(angle)]) 31 | for _ in range(steps): 32 | traj.append(x) 33 | v = verify_collision(x, v) 34 | x = x + v 35 | return traj 36 | 37 | trajectories = [] 38 | for i in range(train_set_size+valid_set_size+test_set_size): 39 | trajectories.append(generate_trajectory(seq_len)) 40 | trajectories = np.array(trajectories) 41 | 42 | np.savez_compressed(dest, 43 | train_x=trajectories[:train_set_size], 44 | valid_x=trajectories[train_set_size:train_set_size+valid_set_size], 45 | test_x=trajectories[train_set_size+valid_set_size:]) 46 | print("Saved to file %s" % dest) 47 | 48 | 49 | def compute_wall_collision(pos, vel, radius, img_size): 50 | if pos[1]-radius <= 0: 51 | vel[1] = -vel[1] 52 | pos[1] = -(pos[1]-radius)+radius 53 | if pos[1]+radius >= img_size[1]: 54 | vel[1] = -vel[1] 55 | pos[1] = img_size[1]-(pos[1]+radius-img_size[1])-radius 56 | if pos[0]-radius <= 0: 57 | vel[0] = -vel[0] 58 | pos[0] = -(pos[0]-radius)+radius 59 | if pos[0]+radius >= img_size[0]: 60 | vel[0] = -vel[0] 61 | pos[0] = img_size[0]-(pos[0]+radius-img_size[0])-radius 62 | return pos, vel 63 | 64 | 65 | def verify_wall_collision(pos, vel, radius, img_size): 66 | if pos[1]-radius <= 0: 67 | return True 68 | if pos[1]+radius >= img_size[1]: 69 | return True 70 | if pos[0]-radius <= 0: 71 | return True 72 | if pos[0]+radius >= img_size[0]: 73 | return True 74 | return False 75 | 76 | 77 | def verify_object_collision(poss, radius): 78 | for pos1, pos2 in combinations(poss, 2): 79 | if np.linalg.norm(pos1-pos2) <= radius: 80 | return True 81 | return False 82 | 83 | 84 | def generate_falling_ball_dataset(dest, 85 | train_set_size, 86 | valid_set_size, 87 | test_set_size, 88 | seq_len, 89 | img_size=None, 90 | radius=3, 91 | dt=0.15, 92 | g=9.8, 93 | ode_steps=10): 
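# Integration scheme: each rendered frame advances the state by ode_steps
# sub-steps of size dt/ode_steps, updating velocity before position
# (semi-implicit Euler), which is more stable over long sequences than a
# plain forward-Euler step.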
94 | 95 | from skimage.draw import circle 96 | from nn.utils.viz import gallery 97 | import matplotlib.cm as cm 98 | if img_size is None: 99 | img_size = [32,32] 100 | 101 | def generate_sequence(): 102 | seq = [] 103 | # sample initial position, with v=0 104 | pos = np.random.rand(2) 105 | pos[0] = radius+(img_size[0]-2*radius)*pos[0] 106 | pos[1] = radius + (img_size[1]-2*radius)/2*pos[1] 107 | vel = np.array([0.0,0.0]) 108 | 109 | for i in range(seq_len): 110 | assert pos[1]+radius < img_size[1] 111 | 112 | frame = np.zeros(img_size+[1], dtype=np.uint8) 113 | rr, cc = circle(int(pos[1]), int(pos[0]), radius) 114 | frame[rr, cc, 0] = 255 115 | 116 | seq.append(frame) 117 | 118 | # rollout physics 119 | for _ in range(ode_steps): 120 | vel[1] = vel[1] + dt/ode_steps*g 121 | pos[1] = pos[1] + dt/ode_steps*vel[1] 122 | 123 | return seq 124 | 125 | sequences = [] 126 | for i in range(train_set_size+valid_set_size+test_set_size): 127 | if i % 100 == 0: 128 | print(i) 129 | sequences.append(generate_sequence()) 130 | sequences = np.array(sequences, dtype=np.uint8) 131 | 132 | np.savez_compressed(dest, 133 | train_x=sequences[:train_set_size], 134 | valid_x=sequences[train_set_size:train_set_size+valid_set_size], 135 | test_x=sequences[train_set_size+valid_set_size:]) 136 | print("Saved to file %s" % dest) 137 | 138 | # Save 10 samples 139 | result = gallery(np.concatenate(sequences[:10]/255), ncols=sequences.shape[1]) 140 | 141 | norm = plt.Normalize(0.0, 1.0) 142 | fig, ax = plt.subplots(figsize=(10, 10)) 143 | ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm) 144 | ax.get_xaxis().set_visible(False) 145 | ax.get_yaxis().set_visible(False) 146 | fig.tight_layout() 147 | fig.savefig(dest.split(".")[0]+"_samples.jpg") 148 | 149 | 150 | def generate_falling_bouncing_ball_dataset(dest, 151 | train_set_size, 152 | valid_set_size, 153 | test_set_size, 154 | seq_len, 155 | img_size=None, 156 | radius=3, 157 | dt=0.30, 158 | g=9.8, 159 | vx0_max=0.0, 160 | vy0_max=0.0, 161 | cifar_background=False, 162 | ode_steps=10): 163 | 164 | if cifar_background: 165 | import tensorflow as tf 166 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() 167 | 168 | from skimage.draw import circle 169 | from skimage.transform import resize 170 | 171 | if img_size is None: 172 | img_size = [32,32] 173 | scale = 10 174 | scaled_img_size = [img_size[0]*scale, img_size[1]*scale] 175 | 176 | def generate_sequence(): 177 | seq = [] 178 | # sample initial position and a random initial velocity direction 179 | pos = np.random.rand(2) 180 | pos[0] = radius + (img_size[0]-2*radius)*pos[0] 181 | if g == 0.0: 182 | pos[1] = radius + (img_size[1]-2*radius)*pos[1] 183 | else: 184 | pos[1] = radius + (img_size[1]-2*radius)/2*pos[1] 185 | angle = np.random.rand()*2*np.pi 186 | vel = np.array([np.cos(angle)*vx0_max, 187 | np.sin(angle)*vy0_max]) 188 | 189 | if cifar_background: 190 | cifar_img = x_train[np.random.randint(50000)] 191 | 192 | for i in range(seq_len): 193 | if cifar_background: 194 | frame = cifar_img 195 | frame = rgb2gray(frame)/255 196 | frame = resize(frame, scaled_img_size) 197 | frame = np.clip(frame-0.2, 0.0, 1.0) # darken image a bit 198 | else: 199 | frame = np.zeros(scaled_img_size, dtype=np.float32) 200 | 201 | rr, cc = circle(int(pos[1]*scale), int(pos[0]*scale), radius*scale, scaled_img_size) 202 | frame[rr, cc] = 1.0 203 | frame = resize(frame, img_size, anti_aliasing=True) 204 | frame = (frame[:,:,None]*255).astype(np.uint8) 205 | 206 | seq.append(frame) 207 | 208 | # rollout
physics 209 | for _ in range(ode_steps): 210 | vel[1] = vel[1] + dt/ode_steps*g 211 | pos[1] = pos[1] + dt/ode_steps*vel[1] 212 | 213 | pos[0] = pos[0] + dt/ode_steps*vel[0] 214 | 215 | # verify wall collisions 216 | pos, vel = compute_wall_collision(pos, vel, radius, img_size) 217 | return seq 218 | 219 | sequences = [] 220 | for i in range(train_set_size+valid_set_size+test_set_size): 221 | if i % 100 == 0: 222 | print(i) 223 | sequences.append(generate_sequence()) 224 | sequences = np.array(sequences, dtype=np.uint8) 225 | 226 | np.savez_compressed(dest, 227 | train_x=sequences[:train_set_size], 228 | valid_x=sequences[train_set_size:train_set_size+valid_set_size], 229 | test_x=sequences[train_set_size+valid_set_size:]) 230 | print("Saved to file %s" % dest) 231 | 232 | # Save 10 samples 233 | result = gallery(np.concatenate(sequences[:10]/255), ncols=sequences.shape[1]) 234 | 235 | norm = plt.Normalize(0.0, 1.0) 236 | fig, ax = plt.subplots(figsize=(sequences.shape[1], 10)) 237 | ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm) 238 | ax.get_xaxis().set_visible(False) 239 | ax.get_yaxis().set_visible(False) 240 | fig.tight_layout() 241 | fig.savefig(dest.split(".")[0]+"_samples.jpg") 242 | 243 | 244 | def generate_spring_balls_dataset(dest, 245 | train_set_size, 246 | valid_set_size, 247 | test_set_size, 248 | seq_len, 249 | img_size=None, 250 | radius=3, 251 | dt=0.3, 252 | k=3, 253 | equil=5, 254 | vx0_max=0.0, 255 | vy0_max=0.0, 256 | color=False, 257 | cifar_background=False, 258 | ode_steps=10): 259 | 260 | if cifar_background: 261 | import tensorflow as tf 262 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() 263 | 264 | from skimage.draw import circle 265 | from skimage.transform import resize 266 | 267 | if img_size is None: 268 | img_size = [32,32] 269 | scale = 10 270 | scaled_img_size = [img_size[0]*scale, img_size[1]*scale] 271 | 272 | def generate_sequence(): 273 | # sample initial position of the center of mass, then sample 274 | # position of each object relative to that. 
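# Rejection sampling: the while-loop below redraws the whole initial state
# whenever a ball would touch a wall at any step, so every saved sequence is
# collision-free. The spring force follows Hooke's law about a rest length
# of 2*equil,
#     F = k * (||x0 - x1|| - 2*equil) * (x0 - x1) / ||x0 - x1||,
# applied with opposite signs to the two balls.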
275 | 276 | collision = True 277 | while collision == True: 278 | seq = [] 279 | 280 | cm_pos = np.random.rand(2) 281 | cm_pos[0] = radius+equil + (img_size[0]-2*(radius+equil))*cm_pos[0] 282 | cm_pos[1] = radius+equil + (img_size[1]-2*(radius+equil))*cm_pos[1] 283 | 284 | angle = np.random.rand()*2*np.pi 285 | # calculate position of both objects 286 | r = np.random.rand()+0.5 287 | poss = [[np.cos(angle)*equil*r+cm_pos[0], np.sin(angle)*equil*r+cm_pos[1]], 288 | [np.cos(angle+np.pi)*equil*r+cm_pos[0], np.sin(angle+np.pi)*equil*r+cm_pos[1]]] 289 | poss = np.array(poss) 290 | angles = np.random.rand(2)*2*np.pi 291 | vels = [[np.cos(angles[0])*vx0_max, np.sin(angles[0])*vy0_max], 292 | [np.cos(angles[1])*vx0_max, np.sin(angles[1])*vy0_max]] 293 | vels = np.array(vels) 294 | 295 | if cifar_background: 296 | cifar_img = x_train[np.random.randint(50000)] 297 | 298 | for i in range(seq_len): 299 | if cifar_background: 300 | frame = cifar_img 301 | frame = rgb2gray(frame)/255 302 | frame = resize(frame, scaled_img_size) 303 | frame = np.clip(frame-0.2, 0.0, 1.0) # darken image a bit 304 | else: 305 | if color: 306 | frame = np.zeros(scaled_img_size+[3], dtype=np.float32) 307 | else: 308 | frame = np.zeros(scaled_img_size+[1], dtype=np.float32) 309 | 310 | 311 | for j, pos in enumerate(poss): 312 | rr, cc = circle(int(pos[1]*scale), int(pos[0]*scale), radius*scale, scaled_img_size) 313 | if color: 314 | frame[rr, cc, 2-j] = 1.0 315 | else: 316 | frame[rr, cc, 0] = 1.0 317 | 318 | frame = resize(frame, img_size, anti_aliasing=True) 319 | frame = (frame*255).astype(np.uint8) 320 | 321 | seq.append(frame) 322 | 323 | # rollout physics 324 | for _ in range(ode_steps): 325 | norm = np.linalg.norm(poss[0]-poss[1]) 326 | direction = (poss[0]-poss[1])/norm 327 | F = k*(norm-2*equil)*direction 328 | vels[0] = vels[0] - dt/ode_steps*F 329 | vels[1] = vels[1] + dt/ode_steps*F 330 | poss = poss + dt/ode_steps*vels 331 | 332 | collision = verify_wall_collision(poss[0], vels[0], radius, img_size) or \ 333 | verify_wall_collision(poss[1], vels[1], radius, img_size) 334 | if collision: 335 | break 336 | #poss[0], vels[0] = compute_wall_collision(poss[0], vels[0], radius, img_size) 337 | #poss[1], vels[1] = compute_wall_collision(poss[1], vels[1], radius, img_size) 338 | if collision: 339 | break 340 | 341 | return seq 342 | 343 | sequences = [] 344 | for i in range(train_set_size+valid_set_size+test_set_size): 345 | if i % 100 == 0: 346 | print(i) 347 | sequences.append(generate_sequence()) 348 | sequences = np.array(sequences, dtype=np.uint8) 349 | 350 | np.savez_compressed(dest, 351 | train_x=sequences[:train_set_size], 352 | valid_x=sequences[train_set_size:train_set_size+valid_set_size], 353 | test_x=sequences[train_set_size+valid_set_size:]) 354 | print("Saved to file %s" % dest) 355 | 356 | # Save 10 samples 357 | result = gallery(np.concatenate(sequences[:10]/255), ncols=sequences.shape[1]) 358 | 359 | norm = plt.Normalize(0.0, 1.0) 360 | fig, ax = plt.subplots(figsize=(sequences.shape[1], 10)) 361 | ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm) 362 | ax.get_xaxis().set_visible(False) 363 | ax.get_yaxis().set_visible(False) 364 | fig.tight_layout() 365 | fig.savefig(dest.split(".")[0]+"_samples.jpg") 366 | 367 | 368 | def generate_spring_mnist_dataset(dest, 369 | train_set_size, 370 | valid_set_size, 371 | test_set_size, 372 | seq_len, 373 | img_size=None, 374 | radius=3, 375 | dt=0.3, 376 | k=3, 377 | equil=5, 378 | vx0_max=0.0, 379 | vy0_max=0.0, 380 | color=False, 
381 | cifar_background=False, 382 | ode_steps=10): 383 | 384 | # A single CIFAR image is used for background 385 | # Only 2 mnist digits are used 386 | import tensorflow as tf 387 | from skimage.draw import circle 388 | from skimage.transform import resize 389 | 390 | scale = 5 391 | if img_size is None: 392 | img_size = [32,32] 393 | scaled_img_size = [img_size[0]*scale, img_size[1]*scale] 394 | 395 | if cifar_background: 396 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() 397 | cifar_img = x_train[1] 398 | 399 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() 400 | digits = x_train[0:2, 3:-3, 3:-3]/255 401 | digits = [resize(d, [22*scale, 22*scale]) for d in digits] 402 | radius = 11 403 | 404 | def generate_sequence(): 405 | # sample initial position of the center of mass, then sample 406 | # position of each object relative to that. 407 | 408 | collision = True 409 | while collision == True: 410 | seq = [] 411 | 412 | cm_pos = np.random.rand(2) 413 | cm_pos[0] = radius+equil + (img_size[0]-2*(radius+equil))*cm_pos[0] 414 | cm_pos[1] = radius+equil + (img_size[1]-2*(radius+equil))*cm_pos[1] 415 | 416 | angle = np.random.rand()*2*np.pi 417 | # calculate position of both objects 418 | r = np.random.rand()+0.5 419 | poss = [[np.cos(angle)*equil*r+cm_pos[0], np.sin(angle)*equil*r+cm_pos[1]], 420 | [np.cos(angle+np.pi)*equil*r+cm_pos[0], np.sin(angle+np.pi)*equil*r+cm_pos[1]]] 421 | poss = np.array(poss) 422 | angles = np.random.rand(2)*2*np.pi 423 | vels = [[np.cos(angles[0])*vx0_max, np.sin(angles[0])*vy0_max], 424 | [np.cos(angles[1])*vx0_max, np.sin(angles[1])*vy0_max]] 425 | vels = np.array(vels) 426 | 427 | for i in range(seq_len): 428 | if cifar_background: 429 | frame = cifar_img 430 | if not color: 431 | frame = rgb2gray(frame) 432 | frame = frame[:,:,None] 433 | frame = frame/255 434 | frame = resize(frame, scaled_img_size) 435 | frame = np.clip(frame-0.2, 0.0, 1.0) # darken image a bit 436 | else: 437 | if color: 438 | frame = np.zeros(scaled_img_size+[3], dtype=np.float32) 439 | else: 440 | frame = np.zeros(scaled_img_size+[1], dtype=np.float32) 441 | 442 | 443 | for j, pos in enumerate(poss): 444 | rr, cc = circle(int(pos[1]*scale), int(pos[0]*scale), radius*scale, scaled_img_size) 445 | frame_coords = np.array([[max(0, (pos[1]-radius)*scale), min(scaled_img_size[1], (pos[1]+radius)*scale)], 446 | [max(0, (pos[0]-radius)*scale), min(scaled_img_size[0], (pos[0]+radius)*scale)]]) 447 | digit_coords = np.array([[max(0, (radius-pos[1])*scale), min(2*radius*scale, scaled_img_size[1]-(pos[1]-radius)*scale)], 448 | [max(0, (radius-pos[0])*scale), min(2*radius*scale, scaled_img_size[0]-(pos[0]-radius)*scale)]]) 449 | frame_coords = np.round(frame_coords).astype(np.int32) 450 | digit_coords = np.round(digit_coords).astype(np.int32) 451 | 452 | digit_slice = digits[j][digit_coords[0,0]:digit_coords[0,1], 453 | digit_coords[1,0]:digit_coords[1,1]] 454 | if color: 455 | for l in range(3): 456 | frame_slice = frame[frame_coords[0,0]:frame_coords[0,1], 457 | frame_coords[1,0]:frame_coords[1,1], l] 458 | c = 1.0 if l == j else 0.0 459 | frame[frame_coords[0,0]:frame_coords[0,1], 460 | frame_coords[1,0]:frame_coords[1,1], l] = digit_slice*c + (1-digit_slice)*frame_slice 461 | 462 | else: 463 | frame_slice = frame[frame_coords[0,0]:frame_coords[0,1], 464 | frame_coords[1,0]:frame_coords[1,1], 0] 465 | frame[frame_coords[0,0]:frame_coords[0,1], 466 | frame_coords[1,0]:frame_coords[1,1], 0] = digit_slice + (1-digit_slice)*frame_slice 467 | 
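# The digit intensities act as a per-pixel alpha mask in the blends above:
# out = digit*color + (1 - digit)*background, so digits occlude the
# background smoothly instead of overwriting a square patch.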
468 | frame = resize(frame, img_size, anti_aliasing=True) 469 | frame = (frame*255).astype(np.uint8) 470 | 471 | seq.append(frame) 472 | 473 | # rollout physics 474 | for _ in range(ode_steps): 475 | norm = np.linalg.norm(poss[0]-poss[1]) 476 | direction = (poss[0]-poss[1])/norm 477 | F = k*(norm-2*equil)*direction 478 | vels[0] = vels[0] - dt/ode_steps*F 479 | vels[1] = vels[1] + dt/ode_steps*F 480 | poss = poss + dt/ode_steps*vels 481 | 482 | collision = verify_wall_collision(poss[0], vels[0], 2, img_size) or \ 483 | verify_wall_collision(poss[1], vels[1], 2, img_size) 484 | if collision: 485 | break 486 | #poss[0], vels[0] = compute_wall_collision(poss[0], vels[0], radius, img_size) 487 | #poss[1], vels[1] = compute_wall_collision(poss[1], vels[1], radius, img_size) 488 | if collision: 489 | break 490 | 491 | return seq 492 | 493 | sequences = [] 494 | for i in range(train_set_size+valid_set_size+test_set_size): 495 | if i % 100 == 0: 496 | print(i) 497 | sequences.append(generate_sequence()) 498 | sequences = np.array(sequences, dtype=np.uint8) 499 | 500 | np.savez_compressed(dest, 501 | train_x=sequences[:train_set_size], 502 | valid_x=sequences[train_set_size:train_set_size+valid_set_size], 503 | test_x=sequences[train_set_size+valid_set_size:]) 504 | print("Saved to file %s" % dest) 505 | 506 | # Save 10 samples 507 | result = gallery(np.concatenate(sequences[:10]/255), ncols=sequences.shape[1]) 508 | 509 | norm = plt.Normalize(0.0, 1.0) 510 | fig, ax = plt.subplots(figsize=(sequences.shape[1], 10)) 511 | ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm) 512 | ax.get_xaxis().set_visible(False) 513 | ax.get_yaxis().set_visible(False) 514 | fig.tight_layout() 515 | fig.savefig(dest.split(".")[0]+"_samples.jpg") 516 | 517 | 518 | def generate_3_body_problem_dataset(dest, 519 | train_set_size, 520 | valid_set_size, 521 | test_set_size, 522 | seq_len, 523 | img_size=None, 524 | radius=3, 525 | dt=0.3, 526 | g=9.8, 527 | m=1.0, 528 | vx0_max=0.0, 529 | vy0_max=0.0, 530 | color=False, 531 | cifar_background=False, 532 | ode_steps=10): 533 | 534 | if cifar_background: 535 | import tensorflow as tf 536 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() 537 | 538 | from skimage.draw import circle 539 | from skimage.transform import resize 540 | 541 | if img_size is None: 542 | img_size = [32,32] 543 | scale = 10 544 | scaled_img_size = [img_size[0]*scale, img_size[1]*scale] 545 | 546 | def generate_sequence(): 547 | # sample initial position of the center of mass, then sample 548 | # position of each object relative to that. 
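# Initial configuration: the three bodies are placed roughly 120 degrees
# apart on a circle around the image center, and each gets a near-tangential
# velocity (its position angle rotated by +/- pi/2, plus small shared noise),
# which favors orbit-like motion over immediate collisions; colliding draws
# are rejected below.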
549 | 550 | collision = True 551 | while collision == True: 552 | seq = [] 553 | 554 | cm_pos = np.random.rand(2) 555 | cm_pos = np.array(img_size)/2 # center of mass fixed at the image center (overrides the random draw above) 556 | angle1 = np.random.rand()*2*np.pi 557 | angle2 = angle1 + 2*np.pi/3+(np.random.rand()-0.5)/2 558 | angle3 = angle1 + 4*np.pi/3+(np.random.rand()-0.5)/2 559 | 560 | angles = [angle1, angle2, angle3] 561 | # calculate positions of the three objects 562 | r = (np.random.rand()/2+0.75)*img_size[0]/4 563 | poss = [[np.cos(angle)*r+cm_pos[0], np.sin(angle)*r+cm_pos[1]] for angle in angles] 564 | poss = np.array(poss) 565 | 566 | #angles = np.random.rand(3)*2*np.pi 567 | #vels = [[np.cos(angle)*vx0_max, np.sin(angle)*vy0_max] for angle in angles] 568 | #vels = np.array(vels) 569 | r = np.random.randint(0,2)*2-1 570 | angles = [angle+r*np.pi/2 for angle in angles] 571 | noise = np.random.rand(2)-0.5 572 | vels = [[np.cos(angle)*vx0_max+noise[0], np.sin(angle)*vy0_max+noise[1]] for angle in angles] 573 | vels = np.array(vels) 574 | 575 | if cifar_background: 576 | cifar_img = x_train[np.random.randint(50000)] 577 | 578 | for i in range(seq_len): 579 | if cifar_background: 580 | frame = cifar_img 581 | frame = rgb2gray(frame)/255 582 | frame = resize(frame, scaled_img_size) 583 | frame = np.clip(frame-0.2, 0.0, 1.0) # darken image a bit 584 | else: 585 | if color: 586 | frame = np.zeros(scaled_img_size+[3], dtype=np.float32) 587 | else: 588 | frame = np.zeros(scaled_img_size+[1], dtype=np.float32) 589 | 590 | for j, pos in enumerate(poss): 591 | rr, cc = circle(int(pos[1]*scale), int(pos[0]*scale), radius*scale, scaled_img_size) 592 | if color: 593 | frame[rr, cc, 2-j] = 1.0 594 | else: 595 | frame[rr, cc, 0] = 1.0 596 | 597 | frame = resize(frame, img_size, anti_aliasing=True) 598 | frame = (frame*255).astype(np.uint8) 599 | 600 | seq.append(frame) 601 | 602 | # rollout physics 603 | for _ in range(ode_steps): 604 | norm01 = np.linalg.norm(poss[0]-poss[1]) 605 | norm12 = np.linalg.norm(poss[1]-poss[2]) 606 | norm20 = np.linalg.norm(poss[2]-poss[0]) 607 | vec01 = (poss[0]-poss[1]) 608 | vec12 = (poss[1]-poss[2]) 609 | vec20 = (poss[2]-poss[0]) 610 | 611 | # Compute the net inverse-square-law force on each body 612 | F = [vec01/norm01**3-vec20/norm20**3, 613 | vec12/norm12**3-vec01/norm01**3, 614 | vec20/norm20**3-vec12/norm12**3] 615 | F = np.array(F) 616 | F = -g*m*m*F 617 | 618 | vels = vels + dt/ode_steps*F 619 | poss = poss + dt/ode_steps*vels 620 | 621 | collision = any([verify_wall_collision(pos, vel, radius, img_size) for pos, vel in zip(poss, vels)]) or \ 622 | verify_object_collision(poss, radius+1) 623 | if collision: 624 | break 625 | 626 | if collision: 627 | break 628 | 629 | return seq 630 | 631 | sequences = [] 632 | for i in range(train_set_size+valid_set_size+test_set_size): 633 | if i % 100 == 0: 634 | print(i) 635 | sequences.append(generate_sequence()) 636 | sequences = np.array(sequences, dtype=np.uint8) 637 | 638 | np.savez_compressed(dest, 639 | train_x=sequences[:train_set_size], 640 | valid_x=sequences[train_set_size:train_set_size+valid_set_size], 641 | test_x=sequences[train_set_size+valid_set_size:]) 642 | print("Saved to file %s" % dest) 643 | 644 | # Save 10 samples 645 | result = gallery(np.concatenate(sequences[:10]/255), ncols=sequences.shape[1]) 646 | 647 | norm = plt.Normalize(0.0, 1.0) 648 | fig, ax = plt.subplots(figsize=(sequences.shape[1], 10)) 649 | ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm) 650 | ax.get_xaxis().set_visible(False) 651 | ax.get_yaxis().set_visible(False) 652 | fig.tight_layout() 653 | 
fig.savefig(dest.split(".")[0]+"_samples.jpg") 654 | --------------------------------------------------------------------------------
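A minimal usage sketch for the generators above; the output path, set sizes, and sequence length are illustrative values, not taken from the repo:

import os
from nn.datasets.generators import generate_falling_ball_dataset

os.makedirs("data", exist_ok=True)
# Writes train/valid/test splits into one compressed .npz and saves a grid
# of sample sequences as data/falling_ball_samples.jpg next to it.
generate_falling_ball_dataset(os.path.join("data", "falling_ball.npz"),
                              train_set_size=5000,
                              valid_set_size=500,
                              test_set_size=500,
                              seq_len=20)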