├── .gitignore
├── README.md
├── build_venv.sh
├── connectivity_utils.py
├── download_dataset.sh
├── gns.sh
├── gns.submit
├── graph_network.py
├── images
│   └── water_ramps_rollout.gif
├── learned_simulator.py
├── model_demo.py
├── noise_utils.py
├── reading_utils.py
├── render_rollout.py
├── requirements.txt
├── run.sh
├── slurm_scripts
│   ├── render.sh
│   ├── rollout.sh
│   └── train.sh
├── start_venv.sh
├── tfrecord.ipynb
├── train.py
└── write_tfrecord.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | venv
3 | scratch
4 |
5 | slurm_scripts/logs_sandramps
6 | slurm_scripts/logs_sand-20M
7 |
8 | slurm_scripts/*.e*
9 | slurm_scripts/*.o*
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning to Simulate Complex Physics with Graph Networks (ICML 2020)
2 |
3 | ICML poster: [icml.cc/virtual/2020/poster/6849](https://icml.cc/virtual/2020/poster/6849)
4 |
5 | Video site: [sites.google.com/view/learning-to-simulate](https://sites.google.com/view/learning-to-simulate)
6 |
7 | ArXiv: [arxiv.org/abs/2002.09405](https://arxiv.org/abs/2002.09405)
8 |
9 | If you use the code here, please cite this paper:
10 |
11 |     @inproceedings{sanchezgonzalez2020learning,
12 |       title={Learning to Simulate Complex Physics with Graph Networks},
13 |       author={Alvaro Sanchez-Gonzalez and
14 |               Jonathan Godwin and
15 |               Tobias Pfaff and
16 |               Rex Ying and
17 |               Jure Leskovec and
18 |               Peter W. Battaglia},
19 |       booktitle={International Conference on Machine Learning},
20 |       year={2020}
21 |     }
22 |
23 | ## Install on TACC
24 | ```
25 | module load cuda/10.0
26 | module load cudnn/7.6.2
27 | ```
28 |
29 | ## Example usage: train a model and display a trajectory
30 |
31 | ![WaterRamps rollout](images/water_ramps_rollout.gif)
32 |
33 | After downloading the repo, and from its parent directory, install dependencies:
34 |
35 |     pip install -r learning_to_simulate/requirements.txt
36 |     mkdir -p /tmp/rollouts
37 |
38 | Download dataset (e.g. WaterRamps):
39 |
40 |     mkdir -p /tmp/datasets
41 |     bash ./learning_to_simulate/download_dataset.sh WaterRamps /tmp/datasets
42 |
43 | Train a model:
44 |
45 |     mkdir -p /tmp/models
46 |     python -m learning_to_simulate.train \
47 |         --data_path=/tmp/datasets/WaterRamps \
48 |         --model_path=/tmp/models/WaterRamps
49 |
50 | Generate some trajectory rollouts on the test set:
51 |
52 |     mkdir -p /tmp/rollouts
53 |     python -m learning_to_simulate.train \
54 |         --mode="eval_rollout" \
55 |         --data_path=/tmp/datasets/WaterRamps \
56 |         --model_path=/tmp/models/WaterRamps \
57 |         --output_path=/tmp/rollouts/WaterRamps
58 |
59 | Plot a trajectory:
60 |
61 |     python -m learning_to_simulate.render_rollout \
62 |         --rollout_path=/tmp/rollouts/WaterRamps/rollout_test_0.pkl
63 |
64 |
65 | ## Datasets
66 |
67 | Datasets are available to download via:
68 |
69 | * Metadata file with dataset information (sequence length, dimensionality, box bounds, default connectivity radius, statistics for normalization, ...):
70 |
71 |   `https://storage.googleapis.com/learning-to-simulate-complex-physics/Datasets/{DATASET_NAME}/metadata.json`
72 |
73 | * TFRecords containing data for all trajectories (particle types, positions, global context, ...):
74 |
75 |   `https://storage.googleapis.com/learning-to-simulate-complex-physics/Datasets/{DATASET_NAME}/{DATASET_SPLIT}.tfrecord`
76 |
77 | Where:
78 |
79 | * `{DATASET_SPLIT}` is one of:
80 |   * `train`
81 |   * `valid`
82 |   * `test`
83 |
84 | * `{DATASET_NAME}` is one of the datasets, following the naming used in the paper:
85 |   * `WaterDrop`
86 |   * `Water`
87 |   * `Sand`
88 |   * `Goop`
89 |   * `MultiMaterial`
90 |   * `RandomFloor`
91 |   * `WaterRamps`
92 |   * `SandRamps`
93 |   * `FluidShake`
94 |   * `FluidShakeBox`
95 |   * `Continuous`
96 |   * `WaterDrop-XL`
97 |   * `Water-3D`
98 |   * `Sand-3D`
99 |   * `Goop-3D`
100 |
101 | The provided script `./download_dataset.sh` may be used to download all files from a dataset into a folder given its name.
102 |
103 | An additional smaller dataset `WaterDropSample`, which includes only the first two trajectories of `WaterDrop` for each split, is provided for debugging purposes.
104 |
105 |
106 | ## Code structure
107 |
108 | * `train.py`: Script for training, evaluating and generating rollout trajectories.
109 | * `learned_simulator.py`: Implementation of the learnable one-step model that returns the next position of the particles given inputs. It includes data preprocessing, Euler integration, and a helper method for building normalized training outputs and targets.
110 | * `graph_network.py`: Implementation of the graph network used at the core of the learnable part of the model.
111 | * `render_rollout.py`: Visualization code for displaying rollouts such as the example animation.
112 | * `{noise/connectivity/reading}_utils.py`: Util modules for adding noise to the inputs, computing graph connectivity, and reading datasets from TFRecords.
113 | * `model_demo.py`: Example connecting the model to dummy input data.
114 |
115 | Note that this is a reference implementation, not designed to scale up to TPUs (unlike the one used for the paper). We have tested that the model can be trained with a batch size of 2 on a single NVIDIA V100 to reach similar qualitative performance (except for the XL and 3D datasets, which run out of memory).
116 |
--------------------------------------------------------------------------------
/build_venv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Fail on any error.
4 | set -e 5 | 6 | # Display commands being run. 7 | set -x 8 | 9 | module load cuda/10.0 10 | module load cudnn/7.6.2 11 | 12 | virtualenv --python=python3.6 venv 13 | source venv/bin/activate 14 | 15 | # Install dependencies. 16 | pip install -r requirements.txt 17 | 18 | -------------------------------------------------------------------------------- /connectivity_utils.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # pylint: disable=g-bad-file-header 3 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================ 17 | """Tools to compute the connectivity of the graph.""" 18 | 19 | import functools 20 | 21 | import numpy as np 22 | from sklearn import neighbors 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | def _compute_connectivity(positions, radius, add_self_edges): 27 | """Get the indices of connected edges with radius connectivity. 28 | 29 | Args: 30 | positions: Positions of nodes in the graph. Shape: 31 | [num_nodes_in_graph, num_dims]. 32 | radius: Radius of connectivity. 33 | add_self_edges: Whether to include self edges or not. 34 | 35 | Returns: 36 | senders indices [num_edges_in_graph] 37 | receiver indices [num_edges_in_graph] 38 | 39 | """ 40 | tree = neighbors.KDTree(positions) 41 | receivers_list = tree.query_radius(positions, r=radius) 42 | num_nodes = len(positions) 43 | senders = np.repeat(range(num_nodes), [len(a) for a in receivers_list]) 44 | receivers = np.concatenate(receivers_list, axis=0) 45 | 46 | if not add_self_edges: 47 | # Remove self edges. 48 | mask = senders != receivers 49 | senders = senders[mask] 50 | receivers = receivers[mask] 51 | 52 | return senders, receivers 53 | 54 | 55 | def _compute_connectivity_for_batch( 56 | positions, n_node, radius, add_self_edges): 57 | """`compute_connectivity` for a batch of graphs. 58 | 59 | Args: 60 | positions: Positions of nodes in the batch of graphs. Shape: 61 | [num_nodes_in_batch, num_dims]. 62 | n_node: Number of nodes for each graph in the batch. Shape: 63 | [num_graphs in batch]. 64 | radius: Radius of connectivity. 65 | add_self_edges: Whether to include self edges or not. 66 | 67 | Returns: 68 | senders indices [num_edges_in_batch] 69 | receiver indices [num_edges_in_batch] 70 | number of edges per graph [num_graphs_in_batch] 71 | 72 | """ 73 | 74 | # TODO(alvarosg): Consider if we want to support batches here or not. 75 | # Separate the positions corresponding to particles in different graphs. 76 | positions_per_graph_list = np.split(positions, np.cumsum(n_node[:-1]), axis=0) 77 | receivers_list = [] 78 | senders_list = [] 79 | n_edge_list = [] 80 | num_nodes_in_previous_graphs = 0 81 | 82 | # Compute connectivity for each graph in the batch. 
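  # For example (an illustrative note on shapes): with n_node = [3, 2],
  # np.cumsum(n_node[:-1]) == [3], so `positions` was split above into rows
  # [0:3] and [3:5], and the edge indices of the second graph are offset by 3
  # in the loop below.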
83 | for positions_graph_i in positions_per_graph_list: 84 | senders_graph_i, receivers_graph_i = _compute_connectivity( 85 | positions_graph_i, radius, add_self_edges) 86 | 87 | num_edges_graph_i = len(senders_graph_i) 88 | n_edge_list.append(num_edges_graph_i) 89 | 90 | # Because the inputs will be concatenated, we need to add offsets to the 91 | # sender and receiver indices according to the number of nodes in previous 92 | # graphs in the same batch. 93 | receivers_list.append(receivers_graph_i + num_nodes_in_previous_graphs) 94 | senders_list.append(senders_graph_i + num_nodes_in_previous_graphs) 95 | 96 | num_nodes_graph_i = len(positions_graph_i) 97 | num_nodes_in_previous_graphs += num_nodes_graph_i 98 | 99 | # Concatenate all of the results. 100 | senders = np.concatenate(senders_list, axis=0).astype(np.int32) 101 | receivers = np.concatenate(receivers_list, axis=0).astype(np.int32) 102 | n_edge = np.stack(n_edge_list).astype(np.int32) 103 | 104 | return senders, receivers, n_edge 105 | 106 | 107 | def compute_connectivity_for_batch_pyfunc( 108 | positions, n_node, radius, add_self_edges=True): 109 | """`_compute_connectivity_for_batch` wrapped in a pyfunc.""" 110 | partial_fn = functools.partial( 111 | _compute_connectivity_for_batch, add_self_edges=add_self_edges) 112 | senders, receivers, n_edge = tf.py_function( 113 | partial_fn, 114 | [positions, n_node, radius], 115 | [tf.int32, tf.int32, tf.int32]) 116 | senders.set_shape([None]) 117 | receivers.set_shape([None]) 118 | n_edge.set_shape(n_node.get_shape()) 119 | return senders, receivers, n_edge 120 | 121 | -------------------------------------------------------------------------------- /download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2020 Deepmind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | # Usage: 17 | # bash download_dataset.sh ${DATASET_NAME} ${OUTPUT_DIR} 18 | # Example: 19 | # bash download_dataset.sh WaterDrop /tmp/ 20 | 21 | set -e 22 | 23 | DATASET_NAME="${1}" 24 | OUTPUT_DIR="${2}/${DATASET_NAME}" 25 | 26 | BASE_URL="https://storage.googleapis.com/learning-to-simulate-complex-physics/Datasets/${DATASET_NAME}/" 27 | 28 | mkdir -p ${OUTPUT_DIR} 29 | for file in metadata.json train.tfrecord valid.tfrecord test.tfrecord 30 | do 31 | wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}" 32 | done 33 | -------------------------------------------------------------------------------- /gns.sh: -------------------------------------------------------------------------------- 1 | cd /work/05873/kks32/frontera 2 | TMP_DIR="/work/05873/kks32/frontera/" 3 | source "${TMP_DIR}/learning_to_simulate/bin/activate" 4 | module load cuda/10.0 5 | module load cudnn/7.6.2 6 | DATASET_NAME="WaterDropSample" 7 | DATA_PATH="${TMP_DIR}/datasets/${DATASET_NAME}" 8 | MODEL_PATH="${TMP_DIR}/models/${DATASET_NAME}" 9 | ROLLOUT_PATH="${TMP_DIR}/rollouts/${DATASET_NAME}" 10 | python -m learning_to_simulate.train --data_path=${DATA_PATH} --model_path=${MODEL_PATH} --num_steps=120000 11 | 12 | -------------------------------------------------------------------------------- /gns.submit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -J gns # job name 3 | #SBATCH -o gns.o%j # output and error file name (%j expands to jobID) 4 | #SBATCH -N 1 # number of nodes requested 5 | #SBATCH -n 1 # total number of mpi tasks requested 6 | #SBATCH -p rtx # queue (partition) -- normal, development, etc. 7 | #SBATCH -A BCS20003 # Job project 8 | #SBATCH -t 01:00:00 # run time (hh:mm:ss) - 1 hour 9 | # Slurm email notifications 10 | #SBATCH --mail-user=userid@utexas.edu 11 | #SBATCH --mail-type=begin # email me when the job starts 12 | #SBATCH --mail-type=end # email me when the job finishes 13 | # run the executable named a.out 14 | ibrun sh ./gns.sh 15 | 16 | -------------------------------------------------------------------------------- /graph_network.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # pylint: disable=g-bad-file-header 3 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================ 17 | 18 | """Graph network implementation accompanying ICML 2020 submission. 19 | 20 | "Learning to Simulate Complex Physics with Graph Networks" 21 | 22 | Alvaro Sanchez-Gonzalez*, Jonathan Godwin*, Tobias Pfaff*, Rex Ying, 23 | Jure Leskovec, Peter W. Battaglia 24 | 25 | https://arxiv.org/abs/2002.09405 26 | 27 | The Sonnet `EncodeProcessDecode` module provided here implements the learnable 28 | parts of the model. 
29 | It assumes an encoder preprocessor has already built a graph with 30 | connectivity and features as described in the paper, with features normalized 31 | to zero-mean unit-variance. 32 | 33 | Dependencies include Tensorflow 1.x, Sonnet 1.x and the Graph Nets 1.1 library. 34 | """ 35 | from typing import Callable 36 | 37 | import graph_nets as gn 38 | import sonnet as snt 39 | import tensorflow as tf 40 | 41 | Reducer = Callable[[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor] 42 | 43 | 44 | def build_mlp( 45 | hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module: 46 | """Builds an MLP.""" 47 | return snt.nets.MLP( 48 | output_sizes=[hidden_size] * num_hidden_layers + [output_size]) 49 | 50 | 51 | class EncodeProcessDecode(snt.AbstractModule): 52 | """Encode-Process-Decode function approximator for learnable simulator.""" 53 | 54 | def __init__( 55 | self, 56 | latent_size: int, 57 | mlp_hidden_size: int, 58 | mlp_num_hidden_layers: int, 59 | num_message_passing_steps: int, 60 | output_size: int, 61 | reducer: Reducer = tf.math.unsorted_segment_sum, 62 | name: str = "EncodeProcessDecode"): 63 | """Inits the model. 64 | 65 | Args: 66 | latent_size: Size of the node and edge latent representations. 67 | mlp_hidden_size: Hidden layer size for all MLPs. 68 | mlp_num_hidden_layers: Number of hidden layers in all MLPs. 69 | num_message_passing_steps: Number of message passing steps. 70 | output_size: Output size of the decode node representations as required 71 | by the downstream update function. 72 | reducer: Reduction to be used when aggregating the edges in the nodes in 73 | the interaction network. This should be a callable whose signature 74 | matches tf.math.unsorted_segment_sum. 75 | name: Name of the model. 76 | """ 77 | 78 | super().__init__(name=name) 79 | 80 | self._latent_size = latent_size 81 | self._mlp_hidden_size = mlp_hidden_size 82 | self._mlp_num_hidden_layers = mlp_num_hidden_layers 83 | self._num_message_passing_steps = num_message_passing_steps 84 | self._output_size = output_size 85 | self._reducer = reducer 86 | 87 | with self._enter_variable_scope(): 88 | self._networks_builder() 89 | 90 | def _build(self, input_graph: gn.graphs.GraphsTuple) -> tf.Tensor: 91 | """Forward pass of the learnable dynamics model.""" 92 | 93 | # Encode the input_graph. 94 | latent_graph_0 = self._encode(input_graph) 95 | 96 | # Do `m` message passing steps in the latent graphs. 97 | latent_graph_m = self._process(latent_graph_0) 98 | 99 | # Decode from the last latent graph. 100 | return self._decode(latent_graph_m) 101 | 102 | def _networks_builder(self): 103 | """Builds the networks.""" 104 | 105 | def build_mlp_with_layer_norm(): 106 | mlp = build_mlp( 107 | hidden_size=self._mlp_hidden_size, 108 | num_hidden_layers=self._mlp_num_hidden_layers, 109 | output_size=self._latent_size) 110 | return snt.Sequential([mlp, snt.LayerNorm()]) 111 | 112 | # The encoder graph network independently encodes edge and node features. 113 | encoder_kwargs = dict( 114 | edge_model_fn=build_mlp_with_layer_norm, 115 | node_model_fn=build_mlp_with_layer_norm) 116 | self._encoder_network = gn.modules.GraphIndependent(**encoder_kwargs) 117 | 118 | # Create `num_message_passing_steps` graph networks with unshared parameters 119 | # that update the node and edge latent features. 120 | # Note that we can use `modules.InteractionNetwork` because 121 | # it also outputs the messages as updated edge latent features. 
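    # (Each `InteractionNetwork` updates an edge with an MLP of the edge
    # feature and its sender/receiver node features, then updates a node with
    # an MLP of the node feature and its aggregated incoming edges, using
    # `reducer` for the aggregation.)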
122 | self._processor_networks = [] 123 | for _ in range(self._num_message_passing_steps): 124 | self._processor_networks.append( 125 | gn.modules.InteractionNetwork( 126 | edge_model_fn=build_mlp_with_layer_norm, 127 | node_model_fn=build_mlp_with_layer_norm, 128 | reducer=self._reducer)) 129 | 130 | # The decoder MLP decodes node latent features into the output size. 131 | self._decoder_network = build_mlp( 132 | hidden_size=self._mlp_hidden_size, 133 | num_hidden_layers=self._mlp_num_hidden_layers, 134 | output_size=self._output_size) 135 | 136 | def _encode( 137 | self, input_graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple: 138 | """Encodes the input graph features into a latent graph.""" 139 | 140 | # Copy the globals to all of the nodes, if applicable. 141 | if input_graph.globals is not None: 142 | broadcasted_globals = gn.blocks.broadcast_globals_to_nodes(input_graph) 143 | input_graph = input_graph.replace( 144 | nodes=tf.concat([input_graph.nodes, broadcasted_globals], axis=-1), 145 | globals=None) 146 | 147 | # Encode the node and edge features. 148 | latent_graph_0 = self._encoder_network(input_graph) 149 | return latent_graph_0 150 | 151 | def _process( 152 | self, latent_graph_0: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple: 153 | """Processes the latent graph with several steps of message passing.""" 154 | 155 | # Do `m` message passing steps in the latent graphs. 156 | # (In the shared parameters case, just reuse the same `processor_network`) 157 | latent_graph_prev_k = latent_graph_0 158 | latent_graph_k = latent_graph_0 159 | for processor_network_k in self._processor_networks: 160 | latent_graph_k = self._process_step( 161 | processor_network_k, latent_graph_prev_k) 162 | latent_graph_prev_k = latent_graph_k 163 | 164 | latent_graph_m = latent_graph_k 165 | return latent_graph_m 166 | 167 | def _process_step( 168 | self, processor_network_k: snt.Module, 169 | latent_graph_prev_k: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple: 170 | """Single step of message passing with node/edge residual connections.""" 171 | 172 | # One step of message passing. 173 | latent_graph_k = processor_network_k(latent_graph_prev_k) 174 | 175 | # Add residuals. 176 | latent_graph_k = latent_graph_k.replace( 177 | nodes=latent_graph_k.nodes+latent_graph_prev_k.nodes, 178 | edges=latent_graph_k.edges+latent_graph_prev_k.edges) 179 | return latent_graph_k 180 | 181 | def _decode(self, latent_graph: gn.graphs.GraphsTuple) -> tf.Tensor: 182 | """Decodes from the latent graph.""" 183 | return self._decoder_network(latent_graph.nodes) 184 | -------------------------------------------------------------------------------- /images/water_ramps_rollout.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kks32/learning_to_simulate/73ea967f5f2a569607e16e67ff5c2ec3303b69da/images/water_ramps_rollout.gif -------------------------------------------------------------------------------- /learned_simulator.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # pylint: disable=g-bad-file-header 3 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================ 17 | """Full model implementation accompanying ICML 2020 submission. 18 | 19 | "Learning to Simulate Complex Physics with Graph Networks" 20 | 21 | Alvaro Sanchez-Gonzalez*, Jonathan Godwin*, Tobias Pfaff*, Rex Ying, 22 | Jure Leskovec, Peter W. Battaglia 23 | 24 | https://arxiv.org/abs/2002.09405 25 | 26 | """ 27 | 28 | import graph_nets as gn 29 | import sonnet as snt 30 | import tensorflow.compat.v1 as tf 31 | 32 | from learning_to_simulate import connectivity_utils 33 | from learning_to_simulate import graph_network 34 | 35 | STD_EPSILON = 1e-8 36 | 37 | 38 | class LearnedSimulator(snt.AbstractModule): 39 | """Learned simulator from https://arxiv.org/pdf/2002.09405.pdf.""" 40 | 41 | def __init__( 42 | self, 43 | num_dimensions, 44 | connectivity_radius, 45 | graph_network_kwargs, 46 | boundaries, 47 | normalization_stats, 48 | num_particle_types, 49 | particle_type_embedding_size, 50 | name="LearnedSimulator"): 51 | """Inits the model. 52 | 53 | Args: 54 | num_dimensions: Dimensionality of the problem. 55 | connectivity_radius: Scalar with the radius of connectivity. 56 | graph_network_kwargs: Keyword arguments to pass to the learned part 57 | of the graph network `model.EncodeProcessDecode`. 58 | boundaries: List of 2-tuples, containing the lower and upper boundaries of 59 | the cuboid containing the particles along each dimensions, matching 60 | the dimensionality of the problem. 61 | normalization_stats: Dictionary with statistics with keys "acceleration" 62 | and "velocity", containing a named tuple for each with mean and std 63 | fields, matching the dimensionality of the problem. 64 | num_particle_types: Number of different particle types. 65 | particle_type_embedding_size: Embedding size for the particle type. 66 | name: Name of the Sonnet module. 67 | 68 | """ 69 | super().__init__(name=name) 70 | 71 | self._connectivity_radius = connectivity_radius 72 | self._num_particle_types = num_particle_types 73 | self._boundaries = boundaries 74 | self._normalization_stats = normalization_stats 75 | with self._enter_variable_scope(): 76 | self._graph_network = graph_network.EncodeProcessDecode( 77 | output_size=num_dimensions, **graph_network_kwargs) 78 | 79 | if self._num_particle_types > 1: 80 | self._particle_type_embedding = tf.get_variable( 81 | "particle_embedding", 82 | [self._num_particle_types, particle_type_embedding_size], 83 | trainable=True, use_resource=True) 84 | 85 | def _build(self, position_sequence, n_particles_per_example, 86 | global_context=None, particle_types=None): 87 | """Produces a model step, outputting the next position for each particle. 88 | 89 | Args: 90 | position_sequence: Sequence of positions for each node in the batch, 91 | with shape [num_particles_in_batch, sequence_length, num_dimensions] 92 | n_particles_per_example: Number of particles for each graph in the batch 93 | with shape [batch_size] 94 | global_context: Tensor of shape [batch_size, context_size], with global 95 | context. 
96 |       particle_types: Integer tensor of shape [num_particles_in_batch] with
97 |         the integer types of the particles, from 0 to `num_particle_types - 1`.
98 |         If None, we assume all particles are the same type.
99 |
100 |     Returns:
101 |       Next position with shape [num_particles_in_batch, num_dimensions] for one
102 |       step into the future from the input sequence.
103 |     """
104 |     input_graphs_tuple = self._encoder_preprocessor(
105 |         position_sequence, n_particles_per_example, global_context,
106 |         particle_types)
107 |
108 |     normalized_acceleration = self._graph_network(input_graphs_tuple)
109 |
110 |     next_position = self._decoder_postprocessor(
111 |         normalized_acceleration, position_sequence)
112 |
113 |     return next_position
114 |
115 |   def _encoder_preprocessor(
116 |       self, position_sequence, n_node, global_context, particle_types):
117 |     # Extract important features from the position_sequence.
118 |     most_recent_position = position_sequence[:, -1]
119 |     velocity_sequence = time_diff(position_sequence)  # Finite-difference.
120 |
121 |     # Get connectivity of the graph.
122 |     (senders, receivers, n_edge
123 |      ) = connectivity_utils.compute_connectivity_for_batch_pyfunc(
124 |          most_recent_position, n_node, self._connectivity_radius)
125 |
126 |     # Collect node features.
127 |     node_features = []
128 |
129 |     # Normalized velocity sequence, merging spatial and time axes.
130 |     velocity_stats = self._normalization_stats["velocity"]
131 |     normalized_velocity_sequence = (
132 |         velocity_sequence - velocity_stats.mean) / velocity_stats.std
133 |
134 |     flat_velocity_sequence = snt.MergeDims(start=1, size=2)(
135 |         normalized_velocity_sequence)
136 |     node_features.append(flat_velocity_sequence)
137 |
138 |     # Normalized clipped distances to lower and upper boundaries.
139 |     # `boundaries` is an array of shape [num_dimensions, 2], where the second
140 |     # axis provides the lower/upper boundaries.
141 |     boundaries = tf.constant(self._boundaries, dtype=tf.float32)
142 |     distance_to_lower_boundary = (
143 |         most_recent_position - tf.expand_dims(boundaries[:, 0], 0))
144 |     distance_to_upper_boundary = (
145 |         tf.expand_dims(boundaries[:, 1], 0) - most_recent_position)
146 |     distance_to_boundaries = tf.concat(
147 |         [distance_to_lower_boundary, distance_to_upper_boundary], axis=1)
148 |     normalized_clipped_distance_to_boundaries = tf.clip_by_value(
149 |         distance_to_boundaries / self._connectivity_radius, -1., 1.)
150 |     node_features.append(normalized_clipped_distance_to_boundaries)
151 |
152 |     # Particle type.
153 |     if self._num_particle_types > 1:
154 |       particle_type_embeddings = tf.nn.embedding_lookup(
155 |           self._particle_type_embedding, particle_types)
156 |       node_features.append(particle_type_embeddings)
157 |
158 |     # Collect edge features.
159 |     edge_features = []
160 |
161 |     # Relative displacement and distances normalized to radius.
162 |     normalized_relative_displacements = (
163 |         tf.gather(most_recent_position, senders) -
164 |         tf.gather(most_recent_position, receivers)) / self._connectivity_radius
165 |     edge_features.append(normalized_relative_displacements)
166 |
167 |     normalized_relative_distances = tf.norm(
168 |         normalized_relative_displacements, axis=-1, keepdims=True)
169 |     edge_features.append(normalized_relative_distances)
170 |
171 |     # Normalize the global context.
172 |     if global_context is not None:
173 |       context_stats = self._normalization_stats["context"]
174 |       # The context in some datasets is all zeros, so add an epsilon for
175 |       # numerical stability.
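      # That is, normalized_context = (context - mean) / max(std, STD_EPSILON),
      # with STD_EPSILON = 1e-8 guarding against division by zero.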
176 |       global_context = (global_context - context_stats.mean) / tf.math.maximum(
177 |           context_stats.std, STD_EPSILON)
178 |
179 |     return gn.graphs.GraphsTuple(
180 |         nodes=tf.concat(node_features, axis=-1),
181 |         edges=tf.concat(edge_features, axis=-1),
182 |         globals=global_context,  # self._graph_network will append this to the nodes.
183 |         n_node=n_node,
184 |         n_edge=n_edge,
185 |         senders=senders,
186 |         receivers=receivers,
187 |     )
188 |
189 |   def _decoder_postprocessor(self, normalized_acceleration, position_sequence):
190 |
191 |     # The model produces the output in normalized space, so we apply inverse
192 |     # normalization.
193 |     acceleration_stats = self._normalization_stats["acceleration"]
194 |     acceleration = (
195 |         normalized_acceleration * acceleration_stats.std
196 |     ) + acceleration_stats.mean
197 |
198 |     # Use an Euler integrator to go from acceleration to position, assuming
199 |     # a dt=1 corresponding to the size of the finite difference.
200 |     most_recent_position = position_sequence[:, -1]
201 |     most_recent_velocity = most_recent_position - position_sequence[:, -2]
202 |
203 |     new_velocity = most_recent_velocity + acceleration  # * dt = 1
204 |     new_position = most_recent_position + new_velocity  # * dt = 1
205 |     return new_position
206 |
207 |   def get_predicted_and_target_normalized_accelerations(
208 |       self, next_position, position_sequence_noise, position_sequence,
209 |       n_particles_per_example, global_context=None, particle_types=None):  # pylint: disable=g-doc-args
210 |     """Produces normalized predicted and target accelerations.
211 |
212 |     Args:
213 |       next_position: Tensor of shape [num_particles_in_batch, num_dimensions]
214 |         with the positions the model should output given the inputs.
215 |       position_sequence_noise: Tensor of the same shape as `position_sequence`
216 |         with the noise to apply to each particle.
217 |       position_sequence, n_particles_per_example, global_context,
218 |         particle_types: Inputs to the model as defined by `_build`.
219 |
220 |     Returns:
221 |       Tensors of shape [num_particles_in_batch, num_dimensions] with the
222 |         predicted and target normalized accelerations.
223 |     """
224 |
225 |     # Add noise to the input position sequence.
226 |     noisy_position_sequence = position_sequence + position_sequence_noise
227 |
228 |     # Perform the forward pass with the noisy position sequence.
229 |     input_graphs_tuple = self._encoder_preprocessor(
230 |         noisy_position_sequence, n_particles_per_example, global_context,
231 |         particle_types)
232 |     predicted_normalized_acceleration = self._graph_network(input_graphs_tuple)
233 |
234 |     # Calculate the target acceleration, using a `next_position_adjusted` that
235 |     # is shifted by the noise in the last input position.
236 |     next_position_adjusted = next_position + position_sequence_noise[:, -1]
237 |     target_normalized_acceleration = self._inverse_decoder_postprocessor(
238 |         next_position_adjusted, noisy_position_sequence)
239 |     # As a result, the inverted Euler update in `_inverse_decoder_postprocessor` produces:
240 |     # * A target acceleration that does not explicitly correct for the noise in
241 |     #   the input positions, as `next_position_adjusted` is different
242 |     #   from the true `next_position`.
243 |     # * A target acceleration that exactly corrects noise in the input velocity,
244 |     #   since the target next velocity calculated by the inverse Euler update
245 |     #   as `next_position_adjusted - noisy_position_sequence[:, -1]`
246 |     #   matches the ground truth next velocity (noise cancels out).
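    # In symbols: with noise n_t added to the last input position x_t, the
    # inverse Euler update computes the target next velocity as
    #   (x_{t+1} + n_t) - (x_t + n_t) = x_{t+1} - x_t,
    # which is exactly the noise-free next velocity.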
247 |
248 |     return predicted_normalized_acceleration, target_normalized_acceleration
249 |
250 |   def _inverse_decoder_postprocessor(self, next_position, position_sequence):
251 |     """Inverse of `_decoder_postprocessor`."""
252 |
253 |     previous_position = position_sequence[:, -1]
254 |     previous_velocity = previous_position - position_sequence[:, -2]
255 |     next_velocity = next_position - previous_position
256 |     acceleration = next_velocity - previous_velocity
257 |
258 |     acceleration_stats = self._normalization_stats["acceleration"]
259 |     normalized_acceleration = (
260 |         acceleration - acceleration_stats.mean) / acceleration_stats.std
261 |     return normalized_acceleration
262 |
263 |
264 | def time_diff(input_sequence):
265 |   return input_sequence[:, 1:] - input_sequence[:, :-1]
266 |
--------------------------------------------------------------------------------
/model_demo.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | # pylint: disable=g-bad-file-header
3 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ============================================================================
17 |
18 | """Example script accompanying ICML 2020 submission.
19 |
20 | "Learning to Simulate Complex Physics with Graph Networks"
21 |
22 | Alvaro Sanchez-Gonzalez*, Jonathan Godwin*, Tobias Pfaff*, Rex Ying,
23 | Jure Leskovec, Peter W. Battaglia
24 |
25 | https://arxiv.org/abs/2002.09405
26 |
27 | Here we provide the utility function `sample_random_position_sequence()` which
28 | returns a sequence of positions for a variable number of particles, similar to
29 | what a real dataset would provide, and connect the model to it in both
30 | single-step inference and training mode.
31 |
32 | Dependencies include Tensorflow 1.x, Sonnet 1.x and the Graph Nets 1.1 library.
33 | """
34 |
35 | import collections
36 |
37 | from learning_to_simulate import learned_simulator
38 | from learning_to_simulate import noise_utils
39 | import numpy as np
40 | import tensorflow.compat.v1 as tf
41 |
42 | INPUT_SEQUENCE_LENGTH = 6
43 | SEQUENCE_LENGTH = INPUT_SEQUENCE_LENGTH + 1  # add one target position.
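# With 6 input positions the encoder sees 5 finite-difference velocities; the
# 7th position in each sampled sequence serves as the prediction target.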
44 | NUM_DIMENSIONS = 3 45 | NUM_PARTICLE_TYPES = 6 46 | BATCH_SIZE = 5 47 | GLOBAL_CONTEXT_SIZE = 6 48 | 49 | Stats = collections.namedtuple("Stats", ["mean", "std"]) 50 | 51 | DUMMY_STATS = Stats( 52 | mean=np.zeros([NUM_DIMENSIONS], dtype=np.float32), 53 | std=np.ones([NUM_DIMENSIONS], dtype=np.float32)) 54 | DUMMY_CONTEXT_STATS = Stats( 55 | mean=np.zeros([GLOBAL_CONTEXT_SIZE], dtype=np.float32), 56 | std=np.ones([GLOBAL_CONTEXT_SIZE], dtype=np.float32)) 57 | DUMMY_BOUNDARIES = [(-1., 1.)] * NUM_DIMENSIONS 58 | 59 | 60 | def sample_random_position_sequence(): 61 | """Returns mock data mimicking the input features collected by the encoder.""" 62 | num_particles = tf.random_uniform( 63 | shape=(), minval=50, maxval=1000, dtype=tf.int32) 64 | position_sequence = tf.random.normal( 65 | shape=[num_particles, SEQUENCE_LENGTH, NUM_DIMENSIONS]) 66 | return position_sequence 67 | 68 | 69 | def main(): 70 | 71 | # Build the model. 72 | learnable_model = learned_simulator.LearnedSimulator( 73 | num_dimensions=NUM_DIMENSIONS, 74 | connectivity_radius=0.05, 75 | graph_network_kwargs=dict( 76 | latent_size=128, 77 | mlp_hidden_size=128, 78 | mlp_num_hidden_layers=2, 79 | num_message_passing_steps=10, 80 | ), 81 | boundaries=DUMMY_BOUNDARIES, 82 | normalization_stats={"acceleration": DUMMY_STATS, 83 | "velocity": DUMMY_STATS, 84 | "context": DUMMY_CONTEXT_STATS,}, 85 | num_particle_types=NUM_PARTICLE_TYPES, 86 | particle_type_embedding_size=16, 87 | ) 88 | 89 | # Sample a batch of particle sequences with shape: 90 | # [TOTAL_NUM_PARTICLES, SEQUENCE_LENGTH, NUM_DIMENSIONS] 91 | sampled_position_sequences = [ 92 | sample_random_position_sequence() for _ in range(BATCH_SIZE)] 93 | position_sequence_batch = tf.concat(sampled_position_sequences, axis=0) 94 | 95 | # Count how many particles are present in each element in the batch. 96 | # [BATCH_SIZE] 97 | n_particles_per_example = tf.stack( 98 | [tf.shape(seq)[0] for seq in sampled_position_sequences], axis=0) 99 | 100 | # Sample particle types. 101 | # [TOTAL_NUM_PARTICLES] 102 | particle_types = tf.random_uniform( 103 | [tf.shape(position_sequence_batch)[0]], 104 | 0, NUM_PARTICLE_TYPES, dtype=tf.int32) 105 | 106 | # Sample global context. 107 | global_context = tf.random_uniform( 108 | [BATCH_SIZE, GLOBAL_CONTEXT_SIZE], -1., 1., dtype=tf.float32) 109 | 110 | # Separate input sequence from target sequence. 111 | # [TOTAL_NUM_PARTICLES, INPUT_SEQUENCE_LENGTH, NUM_DIMENSIONS] 112 | input_position_sequence = position_sequence_batch[:, :-1] 113 | # [TOTAL_NUM_PARTICLES, NUM_DIMENSIONS] 114 | target_next_position = position_sequence_batch[:, -1] 115 | 116 | # Single step of inference with the model to predict next position for each 117 | # particle [TOTAL_NUM_PARTICLES, NUM_DIMENSIONS]. 118 | predicted_next_position = learnable_model( 119 | input_position_sequence, n_particles_per_example, global_context, 120 | particle_types) 121 | print(f"Per-particle output tensor: {predicted_next_position}") 122 | 123 | # Obtaining predicted and target normalized accelerations for training. 
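  # The random-walk noise perturbs the input velocities so that their standard
  # deviation at the last input step equals `noise_std_last_step`; see
  # noise_utils.get_random_walk_noise_for_position_sequence.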
124 |   position_sequence_noise = (
125 |       noise_utils.get_random_walk_noise_for_position_sequence(
126 |           input_position_sequence, noise_std_last_step=6.7e-4))
127 |
128 |   # Both with shape [TOTAL_NUM_PARTICLES, NUM_DIMENSIONS]
129 |   predicted_normalized_acceleration, target_normalized_acceleration = (
130 |       learnable_model.get_predicted_and_target_normalized_accelerations(
131 |           target_next_position, position_sequence_noise,
132 |           input_position_sequence, n_particles_per_example, global_context,
133 |           particle_types))
134 |   print(f"Predicted norm. acceleration: {predicted_normalized_acceleration}")
135 |   print(f"Target norm. acceleration: {target_normalized_acceleration}")
136 |
137 |   with tf.train.SingularMonitoredSession() as sess:
138 |     sess.run([predicted_next_position,
139 |               predicted_normalized_acceleration,
140 |               target_normalized_acceleration])
141 |
142 |
143 | if __name__ == "__main__":
144 |   tf.disable_v2_behavior()
145 |   main()
--------------------------------------------------------------------------------
/noise_utils.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | # pylint: disable=g-bad-file-header
3 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ============================================================================
17 | """Methods to calculate input noise."""
18 |
19 | import tensorflow.compat.v1 as tf
20 |
21 | from learning_to_simulate import learned_simulator
22 |
23 |
24 | def get_random_walk_noise_for_position_sequence(
25 |     position_sequence, noise_std_last_step):
26 |   """Returns random-walk noise in the velocity applied to the position."""
27 |
28 |   velocity_sequence = learned_simulator.time_diff(position_sequence)
29 |
30 |   # We want the noise scale in the velocity at the last step to be fixed.
31 |   # Because we are going to compose noise at each step using a random walk:
32 |   # std_last_step**2 = num_velocities * std_each_step**2
33 |   # so to keep `std_last_step` fixed, we apply at each step:
34 |   # std_each_step = std_last_step / np.sqrt(num_input_velocities)
35 |   # TODO(alvarosg): Make sure this is consistent with the value and
36 |   # description provided in the paper.
37 |   num_velocities = velocity_sequence.shape.as_list()[1]
38 |   velocity_sequence_noise = tf.random.normal(
39 |       tf.shape(velocity_sequence),
40 |       stddev=noise_std_last_step / num_velocities ** 0.5,
41 |       dtype=position_sequence.dtype)
42 |
43 |   # Apply the random walk.
44 |   velocity_sequence_noise = tf.cumsum(velocity_sequence_noise, axis=1)
45 |
46 |   # Integrate the noise in the velocity to the positions, assuming
47 |   # an Euler integrator and a dt = 1, and adding no noise to the very first
48 |   # position (since that will only be used to calculate the first position
49 |   # change).
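  # The cumulative sum below integrates the velocity noise into position
  # noise; prepending one step of zeros restores the original sequence length
  # (e.g. 5 velocity steps + 1 zero step -> 6 position steps).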
50 | position_sequence_noise = tf.concat([ 51 | tf.zeros_like(velocity_sequence_noise[:, 0:1]), 52 | tf.cumsum(velocity_sequence_noise, axis=1)], axis=1) 53 | 54 | return position_sequence_noise 55 | -------------------------------------------------------------------------------- /reading_utils.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # pylint: disable=g-bad-file-header 3 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================ 17 | """Utilities for reading open sourced Learning Complex Physics data.""" 18 | 19 | import functools 20 | import numpy as np 21 | import tensorflow.compat.v1 as tf 22 | 23 | # Create a description of the features. 24 | _FEATURE_DESCRIPTION = { 25 | 'position': tf.io.VarLenFeature(tf.string), 26 | } 27 | 28 | _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT = _FEATURE_DESCRIPTION.copy() 29 | _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT['step_context'] = tf.io.VarLenFeature( 30 | tf.string) 31 | 32 | _FEATURE_DTYPES = { 33 | 'position': { 34 | 'in': np.float32, 35 | 'out': tf.float32 36 | }, 37 | 'step_context': { 38 | 'in': np.float32, 39 | 'out': tf.float32 40 | } 41 | } 42 | 43 | _CONTEXT_FEATURES = { 44 | 'key': tf.io.FixedLenFeature([], tf.int64, default_value=0), 45 | 'particle_type': tf.io.VarLenFeature(tf.string) 46 | } 47 | 48 | 49 | def convert_to_tensor(x, encoded_dtype): 50 | if len(x) == 1: 51 | out = np.frombuffer(x[0].numpy(), dtype=encoded_dtype) 52 | else: 53 | out = [] 54 | for el in x: 55 | out.append(np.frombuffer(el.numpy(), dtype=encoded_dtype)) 56 | out = tf.convert_to_tensor(np.array(out)) 57 | return out 58 | 59 | 60 | def parse_serialized_simulation_example(example_proto, metadata): 61 | """Parses a serialized simulation tf.SequenceExample. 62 | 63 | Args: 64 | example_proto: A string encoding of the tf.SequenceExample proto. 65 | metadata: A dict of metadata for the dataset. 66 | 67 | Returns: 68 | context: A dict, with features that do not vary over the trajectory. 69 | parsed_features: A dict of tf.Tensors representing the parsed examples 70 | across time, where axis zero is the time axis. 
71 |
72 |   """
73 |   if 'context_mean' in metadata:
74 |     feature_description = _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT
75 |   else:
76 |     feature_description = _FEATURE_DESCRIPTION
77 |   context, parsed_features = tf.io.parse_single_sequence_example(
78 |       example_proto,
79 |       context_features=_CONTEXT_FEATURES,
80 |       sequence_features=feature_description)
81 |   for feature_key, item in parsed_features.items():
82 |     convert_fn = functools.partial(
83 |         convert_to_tensor, encoded_dtype=_FEATURE_DTYPES[feature_key]['in'])
84 |     parsed_features[feature_key] = tf.py_function(
85 |         convert_fn, inp=[item.values], Tout=_FEATURE_DTYPES[feature_key]['out'])
86 |
87 |   # There is an extra frame at the beginning so we can calculate pos change
88 |   # for all frames used in the paper.
89 |   position_shape = [metadata['sequence_length'] + 1, -1, metadata['dim']]
90 |
91 |   # Reshape positions to correct dim:
92 |   parsed_features['position'] = tf.reshape(parsed_features['position'],
93 |                                            position_shape)
94 |   # Set correct shapes of the remaining tensors.
95 |   sequence_length = metadata['sequence_length'] + 1
96 |   if 'context_mean' in metadata:
97 |     context_feat_len = len(metadata['context_mean'])
98 |     parsed_features['step_context'] = tf.reshape(
99 |         parsed_features['step_context'],
100 |         [sequence_length, context_feat_len])
101 |   # Decode particle type explicitly.
102 |   context['particle_type'] = tf.py_function(
103 |       functools.partial(convert_fn, encoded_dtype=np.int64),
104 |       inp=[context['particle_type'].values],
105 |       Tout=[tf.int64])
106 |   context['particle_type'] = tf.reshape(context['particle_type'], [-1])
107 |   print("context: ", context)
108 |   print("features: ", parsed_features)
109 |   return context, parsed_features
110 |
111 |
112 | def split_trajectory(context, features, window_length=7):
113 |   """Splits trajectory into sliding windows."""
114 |   # Our strategy is to make sure all the leading dimensions are the same size,
115 |   # then we can use from_tensor_slices.
116 |
117 |   trajectory_length = features['position'].get_shape().as_list()[0]
118 |
119 |   # We then stack `window_length` positions per window, so the final number of
120 |   # windows will be `trajectory_length - window_length + 1` (the +1 to make
121 |   # sure we get the last split).
122 |   input_trajectory_length = trajectory_length - window_length + 1
123 |
124 |   model_input_features = {}
125 |   # Prepare the context features per step.
126 |   model_input_features['particle_type'] = tf.tile(
127 |       tf.expand_dims(context['particle_type'], axis=0),
128 |       [input_trajectory_length, 1])
129 |
130 |   if 'step_context' in features:
131 |     global_stack = []
132 |     for idx in range(input_trajectory_length):
133 |       global_stack.append(
134 |           features['step_context'][idx:idx + window_length])
135 |     model_input_features['step_context'] = tf.stack(global_stack)
136 |
137 |   pos_stack = []
138 |   for idx in range(input_trajectory_length):
139 |     pos_stack.append(features['position'][idx:idx + window_length])
140 |   # Get the corresponding positions.
141 |   model_input_features['position'] = tf.stack(pos_stack)
142 |
143 |   return tf.data.Dataset.from_tensor_slices(model_input_features)
--------------------------------------------------------------------------------
/render_rollout.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | # pylint: disable=g-bad-file-header
3 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================ 17 | """Simple matplotlib rendering of a rollout prediction against ground truth. 18 | 19 | Usage (from parent directory): 20 | 21 | `python -m learning_to_simulate.render_rollout --rollout_path={OUTPUT_PATH}/rollout_test_1.pkl` 22 | 23 | Where {OUTPUT_PATH} is the output path passed to `train.py` in "eval_rollout" 24 | mode. 25 | 26 | It may require installing Tkinter with `sudo apt-get install python3.7-tk`. 27 | 28 | """ # pylint: disable=line-too-long 29 | 30 | import pickle 31 | 32 | from absl import app 33 | from absl import flags 34 | 35 | from matplotlib import animation 36 | import matplotlib.pyplot as plt 37 | import numpy as np 38 | 39 | flags.DEFINE_string("rollout_path", None, help="Path to rollout pickle file") 40 | flags.DEFINE_integer("step_stride", 3, help="Stride of steps to skip.") 41 | flags.DEFINE_boolean("block_on_show", True, help="For test purposes.") 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | TYPE_TO_COLOR = { 46 | 3: "black", # Boundary particles. 47 | 0: "green", # Rigid solids. 48 | 7: "magenta", # Goop. 49 | 6: "gold", # Sand. 50 | 5: "blue", # Water. 51 | } 52 | 53 | 54 | def main(unused_argv): 55 | 56 | if not FLAGS.rollout_path: 57 | raise ValueError("A `rollout_path` must be passed.") 58 | with open(FLAGS.rollout_path, "rb") as file: 59 | rollout_data = pickle.load(file) 60 | 61 | fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 62 | 63 | plot_info = [] 64 | for ax_i, (label, rollout_field) in enumerate( 65 | [("MPM", "ground_truth_rollout"), 66 | ("GNS", "predicted_rollout")]): 67 | # Append the initial positions to get the full trajectory. 68 | trajectory = np.concatenate([ 69 | rollout_data["initial_positions"], 70 | rollout_data[rollout_field]], axis=0) 71 | ax = axes[ax_i] 72 | ax.set_title(label) 73 | bounds = rollout_data["metadata"]["bounds"] 74 | ax.set_xlim(bounds[0][0], bounds[0][1]) 75 | ax.set_ylim(bounds[1][0], bounds[1][1]) 76 | ax.set_xticks([]) 77 | ax.set_yticks([]) 78 | ax.set_aspect(1.) 
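    # One Line2D artist per particle type, so each material keeps its own
    # color; the `update` function below only refreshes the x/y data of these
    # artists at each animation frame.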
79 | points = { 80 | particle_type: ax.plot([], [], "o", ms=2, color=color)[0] 81 | for particle_type, color in TYPE_TO_COLOR.items()} 82 | plot_info.append((ax, trajectory, points)) 83 | 84 | num_steps = trajectory.shape[0] 85 | 86 | def update(step_i): 87 | outputs = [] 88 | for _, trajectory, points in plot_info: 89 | for particle_type, line in points.items(): 90 | mask = rollout_data["particle_types"] == particle_type 91 | line.set_data(trajectory[step_i, mask, 0], 92 | trajectory[step_i, mask, 1]) 93 | outputs.append(line) 94 | return outputs 95 | 96 | unused_animation = animation.FuncAnimation( 97 | fig, update, 98 | frames=np.arange(0, num_steps, FLAGS.step_stride), interval=10) 99 | 100 | unused_animation.save('rollout.gif', dpi=80, fps=30, writer='imagemagick') 101 | plt.show(block=FLAGS.block_on_show) 102 | 103 | 104 | if __name__ == "__main__": 105 | app.run(main) 106 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | autopep8 3 | graph-nets>=1.1 4 | tensorflow==1.15 5 | tensorflow-gpu==1.15 6 | numpy 7 | dm-sonnet<2 8 | tensorflow_probability<0.9 9 | sklearn 10 | dm-tree 11 | matplotlib 12 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2020 Deepmind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Fail on any error. 17 | set -e 18 | 19 | # Display commands being run. 20 | set -x 21 | 22 | module load cuda/10.0 23 | module load cudnn/7.6.2 24 | 25 | TMP_DIR="/work/05873/kks32/frontera/" 26 | 27 | virtualenv --python=python3.6 "${TMP_DIR}/learning_to_simulate/env" 28 | source "${TMP_DIR}/learning_to_simulate/env/bin/activate" 29 | 30 | # Install dependencies. 31 | pip install -r requirements.txt 32 | 33 | # Run the simple demo with dummy inputs. 34 | #python -m learning_to_simulate.model_demo 35 | 36 | # Run some training and evaluation in one of the dataset samples. 37 | 38 | # Download a sample of a dataset. 39 | DATASET_NAME="WaterDropSample" 40 | 41 | bash ./download_dataset.sh ${DATASET_NAME} "${TMP_DIR}/datasets" 42 | 43 | # Train for a few steps. 44 | DATA_PATH="${TMP_DIR}/datasets/${DATASET_NAME}" 45 | MODEL_PATH="${TMP_DIR}/models/${DATASET_NAME}" 46 | python -m learning_to_simulate.train --data_path=${DATA_PATH} --model_path=${MODEL_PATH} --num_steps=10 47 | 48 | # Evaluate on validation split. 49 | python -m learning_to_simulate.train --data_path=${DATA_PATH} --model_path=${MODEL_PATH} --mode="eval" --eval_split="valid" 50 | 51 | # Generate test rollouts. 
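# Rollouts are written as pickle files named rollout_test_<i>.pkl inside
# ${ROLLOUT_PATH}.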
52 | ROLLOUT_PATH="${TMP_DIR}/rollouts/${DATASET_NAME}"
53 | mkdir -p ${ROLLOUT_PATH}
54 | python -m learning_to_simulate.train --data_path=${DATA_PATH} --model_path=${MODEL_PATH} --mode="eval_rollout" --output_path=${ROLLOUT_PATH}
55 |
56 | # Plot the first rollout.
57 | python -m learning_to_simulate.render_rollout --rollout_path="${ROLLOUT_PATH}/rollout_test_0.pkl" --block_on_show=False
58 |
59 | # Clean up. Disabled here: ${TMP_DIR} points at persistent $WORK storage, so
60 | # removing it would delete the downloaded datasets, trained models and the
61 | # virtualenv.
62 | # rm -r ${TMP_DIR}
--------------------------------------------------------------------------------
/slurm_scripts/render.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH -J pyt_render       # Job name
4 | #SBATCH -o pyt_render.o%j   # Name of stdout output file
5 | #SBATCH -e pyt_render.e%j   # Name of stderr error file
6 | #SBATCH -p rtx              # Queue (partition) name
7 | #SBATCH -N 1                # Total # of nodes (must be 1 for serial)
8 | #SBATCH -n 1                # Total # of mpi tasks (should be 1 for serial)
9 | #SBATCH -t 15:00:00         # Run time (hh:mm:ss)
10 | #SBATCH --mail-type=all    # Send email at begin and end of job
11 | #SBATCH -A BCS20003        # Project/Allocation name (req'd if you have more than 1)
12 |
13 | # fail on error
14 | set -e
15 |
16 | # start in slurm_scripts
17 | cd ..
18 | source start_venv.sh
19 |
20 | cd ..
21 |
22 | # assume the Sand data is already downloaded (paths are hardcoded below)
23 | python3 -m learning_to_simulate.render_rollout\
24 |     --rollout_path="${WORK}/gns_tensorflow/Sand/rollouts/rollout_test_0.pkl"
25 |
--------------------------------------------------------------------------------
/slurm_scripts/rollout.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH -J pyt_roll         # Job name
4 | #SBATCH -o pyt_roll.o%j     # Name of stdout output file
5 | #SBATCH -e pyt_roll.e%j     # Name of stderr error file
6 | #SBATCH -p rtx              # Queue (partition) name
7 | #SBATCH -N 1                # Total # of nodes (must be 1 for serial)
8 | #SBATCH -n 1                # Total # of mpi tasks (should be 1 for serial)
9 | #SBATCH -t 15:00:00         # Run time (hh:mm:ss)
10 | #SBATCH --mail-type=all    # Send email at begin and end of job
11 | #SBATCH -A BCS20003        # Project/Allocation name (req'd if you have more than 1)
12 |
13 | # fail on error
14 | set -e
15 |
16 | # start in slurm_scripts
17 | cd ..
18 | source start_venv.sh
19 |
20 | cd ..
21 |
22 | # assume the Sand data is already downloaded (paths are hardcoded below)
23 | python3 -m learning_to_simulate.train\
24 |     --mode="eval_rollout"\
25 |     --data_path=$WORK/gns_tensorflow/Sand/dataset\
26 |     --model_path=$WORK/gns_tensorflow/Sand/models\
27 |     --output_path=$WORK/gns_tensorflow/Sand/rollouts
--------------------------------------------------------------------------------
/slurm_scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH -J tf_train         # Job name
4 | #SBATCH -o tf_train.o%j     # Name of stdout output file
5 | #SBATCH -e tf_train.e%j     # Name of stderr error file
6 | #SBATCH -p rtx              # Queue (partition) name
7 | #SBATCH -N 1                # Total # of nodes (must be 1 for serial)
8 | #SBATCH -n 1                # Total # of mpi tasks (should be 1 for serial)
9 | #SBATCH -t 48:00:00         # Run time (hh:mm:ss)
10 | #SBATCH --mail-type=all    # Send email at begin and end of job
11 | #SBATCH --mail-user=jvantassel@tacc.utexas.edu
12 | #SBATCH -A BCS20003        # Project/Allocation name (req'd if you have more than 1)
13 |
14 | # fail on error
15 | set -e
16 |
17 | # start in slurm_scripts
18 | cd ..
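# (start_venv.sh loads the cuda/cudnn modules and activates the virtualenv
# created by build_venv.sh.)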
19 | source start_venv.sh
20 | 
21 | cd ..
22 | 
23 | # assume data is already downloaded; paths below are hardcoded for the Sand dataset
24 | data="Sand"
25 | DATA_PATH="${WORK}/gns_tensorflow/${data}/dataset"
26 | MODEL_PATH="${WORK}/gns_tensorflow/${data}/models"
27 | 
28 | python3 -m learning_to_simulate.train \
29 |     --data_path=${DATA_PATH} \
30 |     --model_path=${MODEL_PATH} \
31 |     --num_steps="1000000"
32 | 
33 | 
--------------------------------------------------------------------------------
/start_venv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | module load cuda/10.0
4 | module load cudnn/7.6.2
5 | 
6 | source venv/bin/activate
7 | 
--------------------------------------------------------------------------------
/tfrecord.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Read and Write TFRecord\n",
8 | "\n",
9 | "> Krishna Kumar and Joseph Vantassel, The University of Texas at Austin"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "# Import modules. Run this notebook from the directory that contains the learning_to_simulate code folder.\n",
19 | "import functools\n",
20 | "import os\n",
21 | "import json\n",
22 | "import pickle\n",
23 | "\n",
24 | "import tensorflow.compat.v1 as tf\n",
25 | "import numpy as np\n",
26 | "\n",
27 | "from learning_to_simulate import reading_utils"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "## Read Metadata"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 2,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "# Set datapath and validation set\n",
44 | "data_path = './datasets/WaterRamps'\n",
45 | "filename = 'valid.tfrecord'"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 3,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "# Read metadata\n",
55 | "def _read_metadata(data_path):\n",
56 | "    with open(os.path.join(data_path, 'metadata.json'), 'rt') as fp:\n",
57 | "        return json.loads(fp.read())\n",
58 | "\n",
59 | "# Fetch metadata\n",
60 | "metadata = _read_metadata(data_path)"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 4,
66 | "metadata": {},
67 | "outputs": [
68 | {
69 | "name": "stdout",
70 | "output_type": "stream",
71 | "text": [
72 | "{'bounds': [[0.1, 0.9], [0.1, 0.9]], 'sequence_length': 600, 'default_connectivity_radius': 0.015, 'dim': 2, 'dt': 0.0025, 'vel_mean': [-6.141567458658365e-08, -0.0007425391691160353], 'vel_std': [0.0022381126134429557, 0.0022664486850394443], 'acc_mean': [-1.713503820317499e-07, -2.1448168008479274e-07], 'acc_std': [0.00016824548701156486, 0.0001819676291787043]}\n"
73 | ]
74 | }
75 | ],
76 | "source": [
77 | "print(metadata)"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "## Read TFRecord"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {},
90 | "source": [
91 | "### Read All Entries"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 5,
97 | "metadata": {},
98 | "outputs": [
99 | {
100 | "name": "stdout",
101 | "output_type": "stream",
102 | "text": [
103 | "context: {'particle_type': <tf.Tensor ...>, 'key': <tf.Tensor ...>}\n",
104 | "features: {'position': <tf.Tensor ...>}\n"
105 | ]
106 | }
107 | ],
108 | "source": [
109 | "ds_org = 
tf.data.TFRecordDataset([os.path.join(data_path, filename)])\n", 110 | "ds = ds_org.map(functools.partial(reading_utils.parse_serialized_simulation_example, metadata=metadata))" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "### Read Single Entry" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 6, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "# raw_dataset = tf.data.TFRecordDataset(\"datasets/WaterRamps/valid.tfrecord\")\n", 127 | "\n", 128 | "# for raw_record in raw_dataset.take(1):\n", 129 | "# example = tf.train.SequenceExample()\n", 130 | "# example.ParseFromString(raw_record.numpy())\n", 131 | "# a_true, b_true = example.ListFields()" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "### Convert to list" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 7, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "lds = list(ds)\n", 148 | "\n", 149 | "particle_types = []\n", 150 | "keys = []\n", 151 | "positions = []\n", 152 | "for _ds in ds:\n", 153 | " context, features = _ds\n", 154 | " particle_types.append(context[\"particle_type\"].numpy().astype(np.int64))\n", 155 | " keys.append(context[\"key\"].numpy().astype(np.int64))\n", 156 | " positions.append(features[\"position\"].numpy().astype(np.float32))" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "## Write New TFRecord" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 8, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "# The following functions can be used to convert a value to a type compatible\n", 173 | "# with tf.train.Example.\n", 174 | "\n", 175 | "def _bytes_feature(value):\n", 176 | " \"\"\"Returns a bytes_list from a string / byte.\"\"\"\n", 177 | " if isinstance(value, type(tf.constant(0))):\n", 178 | " value = value.numpy() # BytesList won't unpack a string from an EagerTensor.\n", 179 | " return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n", 180 | "\n", 181 | "def _float_feature(value):\n", 182 | " \"\"\"Returns a float_list from a float / double.\"\"\"\n", 183 | " return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))\n", 184 | "\n", 185 | "def _int64_feature(value):\n", 186 | " \"\"\"Returns an int64_list from a bool / enum / int / uint.\"\"\"\n", 187 | " return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))\n" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 9, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "with tf.python_io.TFRecordWriter('test.tfrecord') as writer:\n", 197 | " \n", 198 | " for step, (particle_type, key, position) in enumerate(zip(particle_types, keys, positions)):\n", 199 | " seq = tf.train.SequenceExample(\n", 200 | " context=tf.train.Features(feature={\n", 201 | " \"particle_type\": _bytes_feature(particle_type.tobytes()),\n", 202 | " \"key\": _int64_feature(key)\n", 203 | " }),\n", 204 | " feature_lists=tf.train.FeatureLists(feature_list={\n", 205 | " 'position': tf.train.FeatureList(\n", 206 | " feature=[_bytes_feature(position.flatten().tobytes())],\n", 207 | " ),\n", 208 | " 'step_context': tf.train.FeatureList(\n", 209 | " feature=[_bytes_feature(np.float32(step).tobytes())]\n", 210 | " ),\n", 211 | " })\n", 212 | " )\n", 213 | "\n", 214 | " writer.write(seq.SerializeToString())" 215 | ] 216 | }, 
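{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Editor's note (added):* the cell above stores each `position` array as raw bytes of the flattened float32 data, so the TFRecord itself carries no shape information. Reading it back below therefore reuses `reading_utils.parse_serialized_simulation_example` with the same `metadata` dict, which supplies what the parser needs (e.g. `dim`) to restore the `[steps, num_particles, dim]` layout."
]
},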
217 | {
218 | "cell_type": "markdown",
219 | "metadata": {},
220 | "source": [
221 | "## Read New TFRecord"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 10,
227 | "metadata": {},
228 | "outputs": [
229 | {
230 | "name": "stdout",
231 | "output_type": "stream",
232 | "text": [
233 | "context: {'particle_type': <tf.Tensor ...>, 'key': <tf.Tensor ...>}\n",
234 | "features: {'position': <tf.Tensor ...>}\n"
235 | ]
236 | }
237 | ],
238 | "source": [
239 | "dt = tf.data.TFRecordDataset(['test.tfrecord'])\n",
240 | "dt = dt.map(functools.partial(reading_utils.parse_serialized_simulation_example, metadata=metadata))"
241 | ]
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "metadata": {},
246 | "source": [
247 | "## Compare Original and New TFRecord"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": 12,
253 | "metadata": {},
254 | "outputs": [
255 | {
256 | "name": "stdout",
257 | "output_type": "stream",
258 | "text": [
259 | "TFRecords are similar!\n"
260 | ]
261 | }
262 | ],
263 | "source": [
264 | "for ((_ds_context, _ds_feature), (_dt_context, _dt_feature)) in zip(ds, dt):\n",
265 | "    if not np.allclose(_ds_context[\"key\"].numpy(), _dt_context[\"key\"].numpy()):\n",
266 | "        break\n",
267 | "\n",
268 | "    if not np.allclose(_ds_context[\"particle_type\"].numpy(), _dt_context[\"particle_type\"].numpy()):\n",
269 | "        break\n",
270 | "    \n",
271 | "    if not np.allclose(_ds_feature[\"position\"].numpy(), _dt_feature[\"position\"].numpy()):\n",
272 | "        break\n",
273 | "\n",
274 | "else:\n",
275 | "    print(\"TFRecords are similar!\")"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": null,
281 | "metadata": {},
282 | "outputs": [],
283 | "source": []
284 | }
285 | ],
286 | "metadata": {
287 | "kernelspec": {
288 | "display_name": "Python 3",
289 | "language": "python",
290 | "name": "python3"
291 | },
292 | "language_info": {
293 | "codemirror_mode": {
294 | "name": "ipython",
295 | "version": 3
296 | },
297 | "file_extension": ".py",
298 | "mimetype": "text/x-python",
299 | "name": "python",
300 | "nbconvert_exporter": "python",
301 | "pygments_lexer": "ipython3",
302 | "version": "3.7.5"
303 | }
304 | },
305 | "nbformat": 4,
306 | "nbformat_minor": 4
307 | }
308 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | # pylint: disable=g-bad-file-header
3 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #      http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ============================================================================
17 | # pylint: disable=line-too-long
18 | """Training script for https://arxiv.org/pdf/2002.09405.pdf.
19 | 20 | Example usage (from parent directory): 21 | `python -m learning_to_simulate.train --data_path={DATA_PATH} --model_path={MODEL_PATH}` 22 | 23 | Evaluate model from checkpoint (from parent directory): 24 | `python -m learning_to_simulate.train --data_path={DATA_PATH} --model_path={MODEL_PATH} --mode=eval` 25 | 26 | Produce rollouts (from parent directory): 27 | `python -m learning_to_simulate.train --data_path={DATA_PATH} --model_path={MODEL_PATH} --output_path={OUTPUT_PATH} --mode=eval_rollout` 28 | 29 | 30 | """ 31 | # pylint: enable=line-too-long 32 | import collections 33 | import functools 34 | import json 35 | import os 36 | import pickle 37 | 38 | from absl import app 39 | from absl import flags 40 | from absl import logging 41 | import numpy as np 42 | import tensorflow.compat.v1 as tf 43 | import tree 44 | 45 | 46 | from learning_to_simulate import learned_simulator 47 | from learning_to_simulate import noise_utils 48 | from learning_to_simulate import reading_utils 49 | 50 | 51 | flags.DEFINE_enum( 52 | 'mode', 'train', ['train', 'eval', 'eval_rollout'], 53 | help='Train model, one step evaluation or rollout evaluation.') 54 | flags.DEFINE_enum('eval_split', 'test', ['train', 'valid', 'test'], 55 | help='Split to use when running evaluation.') 56 | flags.DEFINE_string('data_path', None, help='The dataset directory.') 57 | flags.DEFINE_integer('batch_size', 2, help='The batch size.') 58 | flags.DEFINE_integer('num_steps', int(2e7), help='Number of steps of training.') 59 | flags.DEFINE_float('noise_std', 6.7e-4, help='The std deviation of the noise.') 60 | flags.DEFINE_string('model_path', None, 61 | help=('The path for saving checkpoints of the model. ' 62 | 'Defaults to a temporary directory.')) 63 | flags.DEFINE_string('output_path', None, 64 | help='The path for saving outputs (e.g. rollouts).') 65 | 66 | 67 | FLAGS = flags.FLAGS 68 | 69 | Stats = collections.namedtuple('Stats', ['mean', 'std']) 70 | 71 | INPUT_SEQUENCE_LENGTH = 6 # So we can calculate the last 5 velocities. 72 | NUM_PARTICLE_TYPES = 9 73 | KINEMATIC_PARTICLE_ID = 3 74 | 75 | 76 | def get_kinematic_mask(particle_types): 77 | """Returns a boolean mask, set to true for kinematic (obstacle) particles.""" 78 | return tf.equal(particle_types, KINEMATIC_PARTICLE_ID) 79 | 80 | 81 | def prepare_inputs(tensor_dict): 82 | """Prepares a single stack of inputs by calculating inputs and targets. 83 | 84 | Computes n_particles_per_example, which is a tensor that contains information 85 | about how to partition the axis - i.e. which nodes belong to which graph. 86 | 87 | Adds a batch axis to `n_particles_per_example` and `step_context` so they can 88 | later be batched using `batch_concat`. This batch will be the same as if the 89 | elements had been batched via stacking. 90 | 91 | Note that all other tensors have a variable size particle axis, 92 | and in this case they will simply be concatenated along that 93 | axis. 94 | 95 | 96 | 97 | Args: 98 | tensor_dict: A dict of tensors containing positions, and step context ( 99 | if available). 100 | 101 | Returns: 102 | A tuple of input features and target positions. 103 | 104 | """ 105 | # Position is encoded as [sequence_length, num_particles, dim] but the model 106 | # expects [num_particles, sequence_length, dim]. 107 | pos = tensor_dict['position'] 108 | pos = tf.transpose(pos, perm=[1, 0, 2]) 109 | 110 | # The target position is the final step of the stack of positions. 111 | target_position = pos[:, -1] 112 | 113 | # Remove the target from the input. 
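  # (With INPUT_SEQUENCE_LENGTH = 6, each 7-step window thus keeps positions
  # 0..5 as the input sequence and uses position 6 as the target.)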
114 |   tensor_dict['position'] = pos[:, :-1]
115 | 
116 |   # Compute the number of particles per example.
117 |   num_particles = tf.shape(pos)[0]
118 |   # Add an extra dimension for stacking via concat.
119 |   tensor_dict['n_particles_per_example'] = num_particles[tf.newaxis]
120 | 
121 |   if 'step_context' in tensor_dict:
122 |     # Take the input global context. We have a stack of global contexts,
123 |     # and we take the penultimate since the final is the target.
124 |     tensor_dict['step_context'] = tensor_dict['step_context'][-2]
125 |     # Add an extra dimension for stacking via concat.
126 |     tensor_dict['step_context'] = tensor_dict['step_context'][tf.newaxis]
127 |   return tensor_dict, target_position
128 | 
129 | 
130 | def prepare_rollout_inputs(context, features):
131 |   """Prepares an inputs trajectory for rollout."""
132 |   out_dict = {**context}
133 |   # Position is encoded as [sequence_length, num_particles, dim] but the model
134 |   # expects [num_particles, sequence_length, dim].
135 |   pos = tf.transpose(features['position'], [1, 0, 2])
136 |   # The target position is the final step of the stack of positions.
137 |   target_position = pos[:, -1]
138 |   # Remove the target from the input.
139 |   out_dict['position'] = pos[:, :-1]
140 |   # Compute the number of nodes.
141 |   out_dict['n_particles_per_example'] = [tf.shape(pos)[0]]
142 |   if 'step_context' in features:
143 |     out_dict['step_context'] = features['step_context']
144 |   out_dict['is_trajectory'] = tf.constant([True], tf.bool)
145 |   return out_dict, target_position
146 | 
147 | 
148 | def batch_concat(dataset, batch_size):
149 |   """We implement batching as concatenating on the leading axis."""
150 | 
151 |   # We create a dataset of datasets of length batch_size.
152 |   windowed_ds = dataset.window(batch_size)
153 | 
154 |   # The plan is then to reduce every nested dataset by concatenating. We can
155 |   # do this using tf.data.Dataset.reduce. This requires an initial state, and
156 |   # then incrementally reduces by running through the dataset.
157 | 
158 |   # Get initial state. In this case this will be empty tensors of the
159 |   # correct shape.
160 |   initial_state = tree.map_structure(
161 |       lambda spec: tf.zeros(  # pylint: disable=g-long-lambda
162 |           shape=[0] + spec.shape.as_list()[1:], dtype=spec.dtype),
163 |       dataset.element_spec)
164 | 
165 |   # We run through the nest and concatenate each entry with the previous state.
166 |   def reduce_window(initial_state, ds):
167 |     return ds.reduce(initial_state, lambda x, y: tf.concat([x, y], axis=0))
168 | 
169 |   return windowed_ds.map(
170 |       lambda *x: tree.map_structure(reduce_window, initial_state, x))
171 | 
172 | 
173 | def get_input_fn(data_path, batch_size, mode, split):
174 |   """Gets the learning simulation input function for tf.estimator.Estimator.
175 | 
176 |   Args:
177 |     data_path: the path to the dataset directory.
178 |     batch_size: the number of graphs in a batch.
179 |     mode: either 'one_step_train', 'one_step' or 'rollout'.
180 |     split: either 'train', 'valid' or 'test'.
181 | 
182 |   Returns:
183 |     The input function for the learning simulation model.
184 |   """
185 |   def input_fn():
186 |     """Input function for learning simulation."""
187 |     # Loads the metadata of the dataset.
188 |     metadata = _read_metadata(data_path)
189 |     # Create a tf.data.Dataset from the TFRecord.
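    # (Each record holds one full trajectory serialized as a
    # tf.train.SequenceExample; see tfrecord.ipynb / write_tfrecord.py.)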
190 | ds = tf.data.TFRecordDataset([os.path.join(data_path, f'{split}.tfrecord')]) 191 | ds = ds.map(functools.partial( 192 | reading_utils.parse_serialized_simulation_example, metadata=metadata)) 193 | if mode.startswith('one_step'): 194 | # Splits an entire trajectory into chunks of 7 steps. 195 | # Previous 5 velocities, current velocity and target. 196 | split_with_window = functools.partial( 197 | reading_utils.split_trajectory, 198 | window_length=INPUT_SEQUENCE_LENGTH + 1) 199 | ds = ds.flat_map(split_with_window) 200 | # Splits a chunk into input steps and target steps 201 | ds = ds.map(prepare_inputs) 202 | # If in train mode, repeat dataset forever and shuffle. 203 | if mode == 'one_step_train': 204 | ds = ds.repeat() 205 | ds = ds.shuffle(512) 206 | # Custom batching on the leading axis. 207 | ds = batch_concat(ds, batch_size) 208 | elif mode == 'rollout': 209 | # Rollout evaluation only available for batch size 1 210 | assert batch_size == 1 211 | ds = ds.map(prepare_rollout_inputs) 212 | else: 213 | raise ValueError(f'mode: {mode} not recognized') 214 | return ds 215 | 216 | return input_fn 217 | 218 | 219 | def rollout(simulator, features, num_steps): 220 | """Rolls out a trajectory by applying the model in sequence.""" 221 | initial_positions = features['position'][:, 0:INPUT_SEQUENCE_LENGTH] 222 | ground_truth_positions = features['position'][:, INPUT_SEQUENCE_LENGTH:] 223 | global_context = features.get('step_context') 224 | def step_fn(step, current_positions, predictions): 225 | 226 | if global_context is None: 227 | global_context_step = None 228 | else: 229 | global_context_step = global_context[ 230 | step + INPUT_SEQUENCE_LENGTH - 1][tf.newaxis] 231 | 232 | next_position = simulator( 233 | current_positions, 234 | n_particles_per_example=features['n_particles_per_example'], 235 | particle_types=features['particle_type'], 236 | global_context=global_context_step) 237 | 238 | # Update kinematic particles from prescribed trajectory. 239 | kinematic_mask = get_kinematic_mask(features['particle_type']) 240 | next_position_ground_truth = ground_truth_positions[:, step] 241 | next_position = tf.where(kinematic_mask, next_position_ground_truth, 242 | next_position) 243 | updated_predictions = predictions.write(step, next_position) 244 | 245 | # Shift `current_positions`, removing the oldest position in the sequence 246 | # and appending the next position at the end. 
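    # (e.g. a window [p0, ..., p5] becomes [p1, ..., p5, next_position].)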
247 | next_positions = tf.concat([current_positions[:, 1:], 248 | next_position[:, tf.newaxis]], axis=1) 249 | 250 | return (step + 1, next_positions, updated_predictions) 251 | 252 | predictions = tf.TensorArray(size=num_steps, dtype=tf.float32) 253 | _, _, predictions = tf.while_loop( 254 | cond=lambda step, state, prediction: tf.less(step, num_steps), 255 | body=step_fn, 256 | loop_vars=(0, initial_positions, predictions), 257 | back_prop=False, 258 | parallel_iterations=1) 259 | 260 | output_dict = { 261 | 'initial_positions': tf.transpose(initial_positions, [1, 0, 2]), 262 | 'predicted_rollout': predictions.stack(), 263 | 'ground_truth_rollout': tf.transpose(ground_truth_positions, [1, 0, 2]), 264 | 'particle_types': features['particle_type'], 265 | } 266 | 267 | if global_context is not None: 268 | output_dict['global_context'] = global_context 269 | return output_dict 270 | 271 | 272 | def _combine_std(std_x, std_y): 273 | return np.sqrt(std_x**2 + std_y**2) 274 | 275 | 276 | def _get_simulator(model_kwargs, metadata, acc_noise_std, vel_noise_std): 277 | """Instantiates the simulator.""" 278 | # Cast statistics to numpy so they are arrays when entering the model. 279 | cast = lambda v: np.array(v, dtype=np.float32) 280 | acceleration_stats = Stats( 281 | cast(metadata['acc_mean']), 282 | _combine_std(cast(metadata['acc_std']), acc_noise_std)) 283 | velocity_stats = Stats( 284 | cast(metadata['vel_mean']), 285 | _combine_std(cast(metadata['vel_std']), vel_noise_std)) 286 | normalization_stats = {'acceleration': acceleration_stats, 287 | 'velocity': velocity_stats} 288 | if 'context_mean' in metadata: 289 | context_stats = Stats( 290 | cast(metadata['context_mean']), cast(metadata['context_std'])) 291 | normalization_stats['context'] = context_stats 292 | 293 | simulator = learned_simulator.LearnedSimulator( 294 | num_dimensions=metadata['dim'], 295 | connectivity_radius=metadata['default_connectivity_radius'], 296 | graph_network_kwargs=model_kwargs, 297 | boundaries=metadata['bounds'], 298 | num_particle_types=NUM_PARTICLE_TYPES, 299 | normalization_stats=normalization_stats, 300 | particle_type_embedding_size=16) 301 | return simulator 302 | 303 | 304 | def get_one_step_estimator_fn(data_path, 305 | noise_std, 306 | latent_size=128, 307 | hidden_size=128, 308 | hidden_layers=2, 309 | message_passing_steps=10): 310 | """Gets one step model for training simulation.""" 311 | metadata = _read_metadata(data_path) 312 | 313 | model_kwargs = dict( 314 | latent_size=latent_size, 315 | mlp_hidden_size=hidden_size, 316 | mlp_num_hidden_layers=hidden_layers, 317 | num_message_passing_steps=message_passing_steps) 318 | 319 | def estimator_fn(features, labels, mode): 320 | target_next_position = labels 321 | simulator = _get_simulator(model_kwargs, metadata, 322 | vel_noise_std=noise_std, 323 | acc_noise_std=noise_std) 324 | # Sample the noise to add to the inputs to the model during training. 325 | sampled_noise = noise_utils.get_random_walk_noise_for_position_sequence( 326 | features['position'], noise_std_last_step=noise_std) 327 | non_kinematic_mask = tf.logical_not( 328 | get_kinematic_mask(features['particle_type'])) 329 | noise_mask = tf.cast( 330 | non_kinematic_mask, sampled_noise.dtype)[:, tf.newaxis, tf.newaxis] 331 | sampled_noise *= noise_mask 332 | 333 | # Get the predictions and target accelerations. 
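    # Both come back normalized by the dataset's acceleration statistics, whose
    # std was combined with the training noise std in _get_simulator above.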
334 |     pred_target = simulator.get_predicted_and_target_normalized_accelerations(
335 |         next_position=target_next_position,
336 |         position_sequence=features['position'],
337 |         position_sequence_noise=sampled_noise,
338 |         n_particles_per_example=features['n_particles_per_example'],
339 |         particle_types=features['particle_type'],
340 |         global_context=features.get('step_context'))
341 |     pred_acceleration, target_acceleration = pred_target
342 | 
343 |     # Calculate the loss and mask out loss on kinematic particles.
344 |     loss = (pred_acceleration - target_acceleration)**2
345 | 
346 |     num_non_kinematic = tf.reduce_sum(
347 |         tf.cast(non_kinematic_mask, tf.float32))
348 |     loss = tf.where(non_kinematic_mask, loss, tf.zeros_like(loss))
349 |     loss = tf.reduce_sum(loss) / tf.reduce_sum(num_non_kinematic)
350 |     global_step = tf.train.get_global_step()
351 |     # Set learning rate to decay from 1e-4 to 1e-6 exponentially.
352 |     min_lr = 1e-6
353 |     lr = tf.train.exponential_decay(learning_rate=1e-4 - min_lr,
354 |                                     global_step=global_step,
355 |                                     decay_steps=int(5e6),
356 |                                     decay_rate=0.1) + min_lr
357 |     opt = tf.train.AdamOptimizer(learning_rate=lr)
358 |     train_op = opt.minimize(loss, global_step)
359 | 
360 |     # Calculate next position and add some additional eval metrics (only eval).
361 |     predicted_next_position = simulator(
362 |         position_sequence=features['position'],
363 |         n_particles_per_example=features['n_particles_per_example'],
364 |         particle_types=features['particle_type'],
365 |         global_context=features.get('step_context'))
366 | 
367 |     predictions = {'predicted_next_position': predicted_next_position}
368 | 
369 |     eval_metrics_ops = {
370 |         'loss_mse': tf.metrics.mean_squared_error(
371 |             pred_acceleration, target_acceleration),
372 |         'one_step_position_mse': tf.metrics.mean_squared_error(
373 |             predicted_next_position, target_next_position)
374 |     }
375 |     return tf.estimator.EstimatorSpec(
376 |         mode=mode,
377 |         train_op=train_op,
378 |         loss=loss,
379 |         predictions=predictions,
380 |         eval_metric_ops=eval_metrics_ops)
381 | 
382 |   return estimator_fn
383 | 
384 | 
385 | def get_rollout_estimator_fn(data_path,
386 |                              noise_std,
387 |                              latent_size=128,
388 |                              hidden_size=128,
389 |                              hidden_layers=2,
390 |                              message_passing_steps=10):
391 |   """Gets the model function for tf.estimator.Estimator."""
392 |   metadata = _read_metadata(data_path)
393 | 
394 |   model_kwargs = dict(
395 |       latent_size=latent_size,
396 |       mlp_hidden_size=hidden_size,
397 |       mlp_num_hidden_layers=hidden_layers,
398 |       num_message_passing_steps=message_passing_steps)
399 | 
400 |   def estimator_fn(features, labels, mode):
401 |     del labels  # Labels to conform to estimator spec.
402 |     simulator = _get_simulator(model_kwargs, metadata,
403 |                                acc_noise_std=noise_std,
404 |                                vel_noise_std=noise_std)
405 | 
406 |     num_steps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
407 |     rollout_op = rollout(simulator, features, num_steps=num_steps)
408 |     squared_error = (rollout_op['predicted_rollout'] -
409 |                      rollout_op['ground_truth_rollout']) ** 2
410 |     loss = tf.reduce_mean(squared_error)
411 |     eval_ops = {'rollout_error_mse': tf.metrics.mean_squared_error(
412 |         rollout_op['predicted_rollout'], rollout_op['ground_truth_rollout'])}
413 | 
414 |     # Add a leading axis, since Estimator's predict method insists that all
415 |     # tensors have a shared leading batch axis of the same dims.
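    # (Rollouts run with batch_size=1, so this just prepends an axis of size 1.)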
416 |     rollout_op = tree.map_structure(lambda x: x[tf.newaxis], rollout_op)
417 |     return tf.estimator.EstimatorSpec(
418 |         mode=mode,
419 |         train_op=None,
420 |         loss=loss,
421 |         predictions=rollout_op,
422 |         eval_metric_ops=eval_ops)
423 | 
424 |   return estimator_fn
425 | 
426 | 
427 | def _read_metadata(data_path):
428 |   with open(os.path.join(data_path, 'metadata.json'), 'rt') as fp:
429 |     return json.loads(fp.read())
430 | 
431 | 
432 | def main(_):
433 |   """Trains or evaluates the model."""
434 | 
435 |   if FLAGS.mode in ['train', 'eval']:
436 |     estimator = tf.estimator.Estimator(
437 |         get_one_step_estimator_fn(FLAGS.data_path, FLAGS.noise_std),
438 |         model_dir=FLAGS.model_path)
439 |     if FLAGS.mode == 'train':
440 |       # Train all the way through.
441 |       estimator.train(
442 |           input_fn=get_input_fn(FLAGS.data_path, FLAGS.batch_size,
443 |                                 mode='one_step_train', split='train'),
444 |           max_steps=FLAGS.num_steps)
445 |     else:
446 |       # One-step evaluation from checkpoint.
447 |       eval_metrics = estimator.evaluate(input_fn=get_input_fn(
448 |           FLAGS.data_path, FLAGS.batch_size,
449 |           mode='one_step', split=FLAGS.eval_split))
450 |       logging.info('Evaluation metrics:')
451 |       logging.info(eval_metrics)
452 |   elif FLAGS.mode == 'eval_rollout':
453 |     if not FLAGS.output_path:
454 |       raise ValueError('A rollout path must be provided.')
455 |     rollout_estimator = tf.estimator.Estimator(
456 |         get_rollout_estimator_fn(FLAGS.data_path, FLAGS.noise_std),
457 |         model_dir=FLAGS.model_path)
458 | 
459 |     # Iterate through rollouts saving them one by one.
460 |     metadata = _read_metadata(FLAGS.data_path)
461 |     rollout_iterator = rollout_estimator.predict(
462 |         input_fn=get_input_fn(FLAGS.data_path, batch_size=1,
463 |                               mode='rollout', split=FLAGS.eval_split))
464 | 
465 |     for example_index, example_rollout in enumerate(rollout_iterator):
466 |       example_rollout['metadata'] = metadata
467 |       filename = f'rollout_{FLAGS.eval_split}_{example_index}.pkl'
468 |       filename = os.path.join(FLAGS.output_path, filename)
469 |       logging.info('Saving: %s.', filename)
470 |       if not os.path.exists(FLAGS.output_path):
471 |         os.mkdir(FLAGS.output_path)
472 |       with open(filename, 'wb') as file:
473 |         pickle.dump(example_rollout, file)
474 | 
475 | if __name__ == '__main__':
476 |   tf.disable_v2_behavior()
477 |   app.run(main)
478 | 
--------------------------------------------------------------------------------
/write_tfrecord.py:
--------------------------------------------------------------------------------
1 | # Import modules. This script lives outside the learning_to_simulate code folder; run it from that parent directory.
2 | import functools
3 | import os
4 | import json
5 | import pickle
6 | 
7 | import tensorflow.compat.v1 as tf
8 | import numpy as np
9 | 
10 | from learning_to_simulate import reading_utils
11 | 
12 | # Set datapath and validation set
13 | data_path = './datasets/WaterRamps'
14 | filename = 'valid.tfrecord'
15 | 
16 | # Read metadata
17 | def _read_metadata(data_path):
18 |     with open(os.path.join(data_path, 'metadata.json'), 'rt') as fp:
19 |         return json.loads(fp.read())
20 | 
21 | # Fetch metadata
22 | metadata = _read_metadata(data_path)
23 | 
24 | print(metadata)
25 | 
26 | # Read TFRecord
27 | ds_org = tf.data.TFRecordDataset([os.path.join(data_path, filename)])
28 | ds = ds_org.map(functools.partial(reading_utils.parse_serialized_simulation_example, metadata=metadata))
29 | 
30 | # Convert to list
31 | lds = list(ds)
32 | 
33 | particle_types = []
34 | keys = []
35 | positions = []
36 | for _ds in ds:
37 |     context, features = _ds
38 |     particle_types.append(context["particle_type"].numpy().astype(np.int64))
39 |     keys.append(context["key"].numpy().astype(np.int64))
40 |     positions.append(features["position"].numpy().astype(np.float32))
41 | 
42 | # The following functions can be used to convert a value to a type compatible
43 | # with tf.train.Example.
44 | 
45 | def _bytes_feature(value):
46 |     """Returns a bytes_list from a string / byte."""
47 |     if isinstance(value, type(tf.constant(0))):
48 |         value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
49 |     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
50 | 
51 | def _float_feature(value):
52 |     """Returns a float_list from a float / double."""
53 |     return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
54 | 
55 | def _int64_feature(value):
56 |     """Returns an int64_list from a bool / enum / int / uint."""
57 |     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
58 | 
59 | 
60 | # Write TFRecord
61 | with tf.python_io.TFRecordWriter('test.tfrecord') as writer:
62 | 
63 |     for step, (particle_type, key, position) in enumerate(zip(particle_types, keys, positions)):
64 |         seq = tf.train.SequenceExample(
65 |             context=tf.train.Features(feature={
66 |                 "particle_type": _bytes_feature(particle_type.tobytes()),
67 |                 "key": _int64_feature(key)
68 |             }),
69 |             feature_lists=tf.train.FeatureLists(feature_list={
70 |                 'position': tf.train.FeatureList(
71 |                     feature=[_bytes_feature(position.flatten().tobytes())],
72 |                 ),
73 |                 'step_context': tf.train.FeatureList(
74 |                     feature=[_bytes_feature(np.float32(step).tobytes())]
75 |                 ),
76 |             })
77 |         )
78 | 
79 |         writer.write(seq.SerializeToString())
80 | 
81 | 
82 | dt = tf.data.TFRecordDataset(['test.tfrecord'])
83 | dt = dt.map(functools.partial(reading_utils.parse_serialized_simulation_example, metadata=metadata))
84 | 
85 | 
86 | for ((_ds_context, _ds_feature), (_dt_context, _dt_feature)) in zip(ds, dt):
87 |     if not np.allclose(_ds_context["key"].numpy(), _dt_context["key"].numpy()):
88 |         break
89 | 
90 |     if not np.allclose(_ds_context["particle_type"].numpy(), _dt_context["particle_type"].numpy()):
91 |         break
92 | 
93 |     if not np.allclose(_ds_feature["position"].numpy(), _dt_feature["position"].numpy()):
94 |         break
95 | 
96 | else:
97 |     print("TFRecords are similar!")
98 | 
--------------------------------------------------------------------------------
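Editor's appendix: the rollout pickles written by train.py in `eval_rollout` mode (and consumed by render_rollout.py) are plain Python dicts, so they can be sanity-checked without plotting. Below is a minimal sketch assuming a rollout generated as in the README; the file name `inspect_rollout.py` and the path are illustrative, not part of the original repo.

```python
# inspect_rollout.py -- editor's sketch, not part of the original repo.
# Loads one rollout pickle written by train.py (eval_rollout mode) and
# reports its keys, array shape and overall rollout MSE.
import pickle

import numpy as np

# Illustrative path; point this at any rollout_<split>_<i>.pkl you generated.
ROLLOUT_PATH = '/tmp/rollouts/WaterRamps/rollout_test_0.pkl'

with open(ROLLOUT_PATH, 'rb') as fp:
    rollout = pickle.load(fp)

# Keys written by train.py: 'initial_positions', 'predicted_rollout',
# 'ground_truth_rollout', 'particle_types', 'metadata' (plus 'global_context'
# for datasets that have one).
print(sorted(rollout.keys()))

predicted = np.asarray(rollout['predicted_rollout'])        # [steps, particles, dim]
ground_truth = np.asarray(rollout['ground_truth_rollout'])  # same shape
print('rollout shape:', predicted.shape)
print('rollout MSE  :', np.mean((predicted - ground_truth) ** 2))
print('sequence_length in metadata:', rollout['metadata']['sequence_length'])
```

Note that `tf.estimator.Estimator.predict` unbatches its outputs, so the saved arrays carry no leading batch axis despite the `tf.newaxis` added in `get_rollout_estimator_fn`; render_rollout.py indexes them the same way (`trajectory[step, particles, dim]`).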