├── WORKSPACE ├── .gitignore ├── vaeseq ├── __init__.py ├── examples │ ├── __init__.py │ ├── midi │ │ ├── __init__.py │ │ ├── BUILD │ │ ├── hparams.py │ │ ├── model_test.py │ │ ├── dataset_test.py │ │ ├── dataset.py │ │ ├── midi.py │ │ └── model.py │ ├── play │ │ ├── __init__.py │ │ ├── hparams.py │ │ ├── model_test.py │ │ ├── BUILD │ │ ├── environment_test.py │ │ ├── play.py │ │ ├── agent.py │ │ ├── model.py │ │ ├── codec.py │ │ └── environment.py │ └── text │ │ ├── __init__.py │ │ ├── BUILD │ │ ├── hparams.py │ │ ├── model_test.py │ │ ├── dataset.py │ │ ├── model.py │ │ ├── dataset_test.py │ │ └── text.py ├── vae │ ├── BUILD │ ├── __init__.py │ ├── rnn.py │ ├── independent_sequence.py │ ├── srnn.py │ └── vae_test.py ├── latent.py ├── hparams.py ├── context_test.py ├── BUILD ├── model_test.py ├── util_test.py ├── vae_module.py ├── dist_module.py ├── batch_dist.py ├── train.py ├── batch_dist_test.py ├── util.py ├── model.py ├── codec.py └── context.py ├── CONTRIBUTING.md ├── setup.py ├── README.md └── LICENSE /WORKSPACE: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *.pyc 3 | bazel-* 4 | build/ 5 | dist/ 6 | *.egg* 7 | *#* 8 | -------------------------------------------------------------------------------- /vaeseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /vaeseq/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /vaeseq/examples/midi/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /vaeseq/examples/play/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /vaeseq/examples/text/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution, 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 
24 | -------------------------------------------------------------------------------- /vaeseq/vae/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | py_library( 4 | name = "independent_sequence", 5 | srcs = ["independent_sequence.py"], 6 | deps = [ 7 | "//vaeseq:latent", 8 | "//vaeseq:util", 9 | "//vaeseq:vae_module", 10 | ], 11 | ) 12 | 13 | py_library( 14 | name = "rnn", 15 | srcs = ["rnn.py"], 16 | deps = [ 17 | "//vaeseq:util", 18 | "//vaeseq:vae_module", 19 | ], 20 | ) 21 | 22 | py_library( 23 | name = "srnn", 24 | srcs = ["srnn.py"], 25 | deps = [ 26 | "//vaeseq:latent", 27 | "//vaeseq:util", 28 | "//vaeseq:vae_module", 29 | ], 30 | ) 31 | 32 | py_library( 33 | name = "vae", 34 | srcs = ["__init__.py"], 35 | deps = [ 36 | ":independent_sequence", 37 | ":rnn", 38 | ":srnn", 39 | ], 40 | ) 41 | 42 | py_test( 43 | name = "vae_test", 44 | srcs = ["vae_test.py"], 45 | deps = [ 46 | ":vae", 47 | "//vaeseq:codec", 48 | "//vaeseq:context", 49 | "//vaeseq:hparams", 50 | "//vaeseq:util", 51 | ], 52 | ) 53 | -------------------------------------------------------------------------------- /vaeseq/vae/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Registry for different VAE implementations.""" 16 | 17 | from . import independent_sequence 18 | from . import rnn 19 | from . 
import srnn 20 | 21 | VAE_TYPES = {} 22 | VAE_TYPES["ISEQ"] = independent_sequence.IndependentSequence 23 | VAE_TYPES["RNN"] = rnn.RNN 24 | VAE_TYPES["SRNN"] = srnn.SRNN 25 | 26 | def make(hparams, *args, **kwargs): 27 | """Create a VAE instance according to hparams.vae_type.""" 28 | vae_type = VAE_TYPES[hparams.vae_type] 29 | return vae_type(hparams, *args, **kwargs) 30 | -------------------------------------------------------------------------------- /vaeseq/examples/text/BUILD: -------------------------------------------------------------------------------- 1 | test_suite( 2 | name = "tests", 3 | tests = [ 4 | "dataset_test", 5 | "model_test", 6 | ], 7 | ) 8 | 9 | py_library( 10 | name = "dataset", 11 | srcs = ["dataset.py"], 12 | deps = [], 13 | ) 14 | 15 | py_test( 16 | name = "dataset_test", 17 | srcs = ["dataset_test.py"], 18 | deps = [ 19 | ":dataset", 20 | ], 21 | ) 22 | 23 | py_library( 24 | name = "hparams", 25 | srcs = ["hparams.py"], 26 | deps = [ 27 | "//vaeseq:hparams", 28 | ], 29 | ) 30 | 31 | py_library( 32 | name = "model", 33 | srcs = ["model.py"], 34 | deps = [ 35 | ":dataset", 36 | "//vaeseq:codec", 37 | "//vaeseq:context", 38 | "//vaeseq:model", 39 | "//vaeseq:util", 40 | ], 41 | ) 42 | 43 | py_test( 44 | name = "model_test", 45 | srcs = ["model_test.py"], 46 | deps = [ 47 | ":hparams", 48 | ":model", 49 | "//vaeseq:model_test", 50 | ], 51 | ) 52 | 53 | py_binary( 54 | name = "text", 55 | srcs = ["text.py"], 56 | deps = [ 57 | ":hparams", 58 | ":model", 59 | ], 60 | ) 61 | -------------------------------------------------------------------------------- /vaeseq/examples/midi/BUILD: -------------------------------------------------------------------------------- 1 | test_suite( 2 | name = "tests", 3 | tests = [ 4 | "dataset_test", 5 | "model_test", 6 | ], 7 | ) 8 | 9 | py_library( 10 | name = "dataset", 11 | srcs = ["dataset.py"], 12 | deps = [], 13 | ) 14 | 15 | py_test( 16 | name = "dataset_test", 17 | srcs = ["dataset_test.py"], 18 | deps = [ 19 | ":dataset", 20 | ], 21 | ) 22 | 23 | py_library( 24 | name = "hparams", 25 | srcs = ["hparams.py"], 26 | deps = [ 27 | "//vaeseq:hparams", 28 | ], 29 | ) 30 | 31 | py_library( 32 | name = "model", 33 | srcs = ["model.py"], 34 | deps = [ 35 | ":dataset", 36 | "//vaeseq:codec", 37 | "//vaeseq:context", 38 | "//vaeseq:model", 39 | "//vaeseq:util", 40 | ], 41 | ) 42 | 43 | py_test( 44 | name = "model_test", 45 | srcs = ["model_test.py"], 46 | deps = [ 47 | ":dataset", 48 | ":hparams", 49 | ":model", 50 | "//vaeseq:model_test", 51 | "//vaeseq:util", 52 | ], 53 | ) 54 | 55 | py_binary( 56 | name = "midi", 57 | srcs = ["midi.py"], 58 | deps = [ 59 | ":hparams", 60 | ":model", 61 | ], 62 | ) 63 | -------------------------------------------------------------------------------- /vaeseq/examples/text/hparams.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
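# Note: the example-level defaults below are merged on top of the library-wide
# defaults in vaeseq/hparams.py (via hparams_mod.make_hparams), so only the
# settings specific to this example need to be listed here.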
14 | 15 | """Hyperparameters for this example.""" 16 | 17 | from vaeseq import hparams as hparams_mod 18 | 19 | _DEFAULTS = dict( 20 | latent_size=16, 21 | sequence_size=40, 22 | rnn_hidden_sizes=[512, 512, 512], 23 | obs_encoder_fc_layers=[64, 32], 24 | obs_decoder_fc_hidden_layers=[64], 25 | embed_size=100, 26 | vocab_size=100, 27 | oov_buckets=1, 28 | ) 29 | 30 | 31 | def make_hparams(flag_value=None, **kwargs): 32 | """Initialize HParams with the defaults in this module.""" 33 | init = dict(_DEFAULTS) 34 | init.update(kwargs) 35 | ret = hparams_mod.make_hparams(flag_value=flag_value, **init) 36 | return ret 37 | -------------------------------------------------------------------------------- /vaeseq/examples/midi/hparams.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Hyperparameters for this example.""" 16 | 17 | from vaeseq import hparams as hparams_mod 18 | 19 | _DEFAULTS = dict( 20 | rnn_hidden_sizes=[512, 512, 512], 21 | obs_encoder_fc_layers=[128, 128, 128], 22 | history_encoder_fc_layers=[128, 128, 128], 23 | obs_decoder_fc_hidden_layers=[128, 128], 24 | latent_size=16, 25 | sequence_size=64, 26 | history_size=20, 27 | rate=32, 28 | l2_regularization=0.01,) 29 | 30 | 31 | def make_hparams(flag_value=None, **kwargs): 32 | """Initialize HParams with the defaults in this module.""" 33 | init = dict(_DEFAULTS) 34 | init.update(kwargs) 35 | ret = hparams_mod.make_hparams(flag_value=flag_value, **init) 36 | return ret 37 | -------------------------------------------------------------------------------- /vaeseq/examples/play/hparams.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Hyperparameters for this example.""" 16 | 17 | from vaeseq import hparams as hparams_mod 18 | 19 | _DEFAULTS = dict( 20 | latent_size=16, 21 | sequence_size=40, 22 | obs_encoder_fc_layers=[64, 64], 23 | obs_decoder_fc_hidden_layers=[64], 24 | latent_decoder_fc_layers=[64], 25 | rnn_hidden_sizes=[64], 26 | game="CartPole-v0", 27 | game_output_size=[4], 28 | game_action_space=2, 29 | batch_size=32, 30 | explore_temp=0.5, 31 | l2_regularization=0.1, 32 | ) 33 | 34 | 35 | def make_hparams(flag_value=None, **kwargs): 36 | """Initialize HParams with the defaults in this module.""" 37 | init = dict(_DEFAULTS) 38 | init.update(kwargs) 39 | ret = hparams_mod.make_hparams(flag_value=flag_value, **init) 40 | return ret 41 | -------------------------------------------------------------------------------- /vaeseq/examples/play/model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for training and generating graphs.""" 16 | 17 | import tensorflow as tf 18 | from vaeseq import model_test 19 | 20 | from vaeseq.examples.play import hparams as hparams_mod 21 | from vaeseq.examples.play import model as model_mod 22 | 23 | 24 | class ModelTest(model_test.ModelTest): 25 | 26 | def _setup_model(self, session_params): 27 | self.train_dataset = True 28 | self.valid_dataset = None 29 | self.hparams = hparams_mod.make_hparams( 30 | rnn_hidden_sizes=[4, 4], 31 | obs_encoder_fc_layers=[32, 16], 32 | obs_decoder_fc_hidden_layers=[32], 33 | latent_decoder_fc_layers=[32], 34 | check_numerics=True) 35 | self.model = model_mod.Model(self.hparams, session_params) 36 | 37 | 38 | if __name__ == "__main__": 39 | tf.test.main() 40 | -------------------------------------------------------------------------------- /vaeseq/examples/play/BUILD: -------------------------------------------------------------------------------- 1 | test_suite( 2 | name = "tests", 3 | tests = [ 4 | "environment_test", 5 | "model_test", 6 | ], 7 | ) 8 | 9 | py_library( 10 | name = "hparams", 11 | srcs = ["hparams.py"], 12 | deps = [ 13 | "//vaeseq:hparams", 14 | ], 15 | ) 16 | 17 | py_library( 18 | name = "agent", 19 | srcs = ["agent.py"], 20 | deps = [ 21 | "//vaeseq:context", 22 | "//vaeseq:util", 23 | ], 24 | ) 25 | 26 | py_library( 27 | name = "environment", 28 | srcs = ["environment.py"], 29 | deps = [ 30 | "//vaeseq:util", 31 | ], 32 | ) 33 | 34 | py_test( 35 | name = "environment_test", 36 | srcs = ["environment_test.py"], 37 | deps = [ 38 | ":environment", 39 | "//vaeseq:util", 40 | ], 41 | ) 42 | 43 | py_library( 44 | name = "codec", 45 | srcs = ["codec.py"], 46 | deps = [ 47 | "//vaeseq:batch_dist", 48 | "//vaeseq:codec", 49 | "//vaeseq:dist_module", 50 | "//vaeseq:util", 51 | ], 52 | ) 53 | 54 | py_library( 55 | name = "model", 56 | srcs = ["model.py"], 57 | deps = [ 58 | ":agent", 59 | ":codec", 60 | ":environment", 61 | "//vaeseq:model", 62 | 
"//vaeseq:util", 63 | "//vaeseq:train", 64 | ], 65 | ) 66 | 67 | py_test( 68 | name = "model_test", 69 | srcs = ["model_test.py"], 70 | deps = [ 71 | ":hparams", 72 | ":model", 73 | "//vaeseq:model_test", 74 | ], 75 | ) 76 | 77 | py_binary( 78 | name = "play", 79 | srcs = ["play.py"], 80 | deps = [ 81 | ":hparams", 82 | ":model", 83 | ], 84 | ) 85 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from setuptools import setup, find_packages 16 | import unittest 17 | 18 | 19 | def tests(): 20 | """Used by test_suite below.""" 21 | return unittest.TestLoader().discover( 22 | "vaeseq/", "*_test.py", top_level_dir=".") 23 | 24 | 25 | setup( 26 | name="vae-seq", 27 | author="Yury Sulsky", 28 | author_email="yury.sulsky@gmail.com", 29 | version="0.1", 30 | description="Generative Sequence Models", 31 | long_description=open("README.md").read(), 32 | packages=find_packages(), 33 | install_requires=[ 34 | "dm-sonnet>=1.10", 35 | "future>=0.16.0", 36 | "gym>=0.9.3", 37 | "numpy>=1.12.0", 38 | "pretty-midi>=0.2.8", 39 | "scipy>=0.16.0", 40 | "six>=1.0.0", 41 | ], 42 | extras_require={ 43 | "tf": ["tensorflow>=1.4.0"], 44 | "tf_gpu": ["tensorflow-gpu>=1.4.0"], 45 | }, 46 | entry_points={ 47 | "console_scripts": [ 48 | "vaeseq-text = vaeseq.examples.text.text:main", 49 | "vaeseq-midi = vaeseq.examples.midi.midi:main", 50 | "vaeseq-play = vaeseq.examples.play.play:main", 51 | ], 52 | }, 53 | test_suite="setup.tests", 54 | ) 55 | -------------------------------------------------------------------------------- /vaeseq/examples/midi/model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for training and generating graphs.""" 16 | 17 | import os.path 18 | import tensorflow as tf 19 | from vaeseq import model_test 20 | 21 | from vaeseq.examples.midi import dataset as dataset_mod 22 | from vaeseq.examples.midi import hparams as hparams_mod 23 | from vaeseq.examples.midi import model as model_mod 24 | 25 | 26 | class ModelTest(model_test.ModelTest): 27 | 28 | def _write_midi(self, note): 29 | """Write a temporary MIDI file with a note playing for one second.""" 30 | temp_path = os.path.join(self.get_temp_dir(), 31 | "note_{}.mid".format(note)) 32 | dataset_mod.write_test_note(temp_path, 1.0, note) 33 | return temp_path 34 | 35 | def _setup_model(self, session_params): 36 | self.train_dataset = [self._write_midi(5), self._write_midi(7)] 37 | self.valid_dataset = [self._write_midi(5), self._write_midi(6)] 38 | self.hparams = hparams_mod.make_hparams( 39 | rnn_hidden_sizes=[4, 4], 40 | obs_encoder_fc_layers=[32, 16], 41 | obs_decoder_fc_hidden_layers=[32], 42 | latent_decoder_fc_layers=[32], 43 | check_numerics=True) 44 | self.model = model_mod.Model(self.hparams, session_params) 45 | 46 | 47 | if __name__ == "__main__": 48 | tf.test.main() 49 | -------------------------------------------------------------------------------- /vaeseq/examples/text/model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
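# These model tests subclass vaeseq.model_test.ModelTest and can be run on
# their own, e.g. (assuming a working Bazel or Python/TensorFlow setup):
#   bazel test //vaeseq/examples/text:model_test
#   python -m vaeseq.examples.text.model_test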
14 | 15 | """Tests for training and generating graphs.""" 16 | 17 | import io 18 | import os.path 19 | import tensorflow as tf 20 | from vaeseq import model_test 21 | 22 | from vaeseq.examples.text import hparams as hparams_mod 23 | from vaeseq.examples.text import model as model_mod 24 | 25 | 26 | class ModelTest(model_test.ModelTest): 27 | 28 | def _write_corpus(self, text): 29 | """Writes the given text to a temporary file and returns the path.""" 30 | temp_path = os.path.join(self.get_temp_dir(), "corpus.txt") 31 | with io.open(temp_path, "w", encoding="utf-8") as temp_file: 32 | temp_file.write(tf.compat.as_text(text)) 33 | return temp_path 34 | 35 | def _setup_model(self, session_params): 36 | self.train_dataset = self._write_corpus("1234567890" * 100) 37 | self.valid_dataset = self._write_corpus("123" * 20) 38 | self.hparams = hparams_mod.make_hparams( 39 | vocab_size=5, 40 | rnn_hidden_sizes=[4, 4], 41 | obs_encoder_fc_layers=[32, 16], 42 | obs_decoder_fc_hidden_layers=[32], 43 | latent_decoder_fc_layers=[32], 44 | check_numerics=True) 45 | vocab_corpus = self.train_dataset 46 | self.model = model_mod.Model(self.hparams, session_params, vocab_corpus) 47 | 48 | 49 | if __name__ == "__main__": 50 | tf.test.main() 51 | -------------------------------------------------------------------------------- /vaeseq/latent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Module parameterizing latent variables.""" 16 | 17 | import sonnet as snt 18 | import tensorflow as tf 19 | 20 | from . import dist_module 21 | from . 
import util 22 | 23 | 24 | class LatentDecoder(dist_module.DistModule): 25 | """Inputs -> P(latent | inputs)""" 26 | 27 | def __init__(self, hparams, name=None): 28 | super(LatentDecoder, self).__init__(name=name) 29 | self._hparams = hparams 30 | 31 | @property 32 | def event_dtype(self): 33 | """The data type of the latent variables.""" 34 | return tf.float32 35 | 36 | @property 37 | def event_size(self): 38 | """The size of the latent variables.""" 39 | return tf.TensorShape([self._hparams.latent_size]) 40 | 41 | def dist(self, params, name=None): 42 | loc, scale_diag = params 43 | name = name or self.module_name + "_dist" 44 | return tf.contrib.distributions.MultivariateNormalDiag( 45 | loc, scale_diag, name=name) 46 | 47 | def _build(self, *inputs): 48 | hparams = self._hparams 49 | mlp = util.make_mlp( 50 | hparams, 51 | hparams.latent_decoder_fc_layers + [hparams.latent_size * 2]) 52 | dist_params = mlp(util.concat_features(inputs)) 53 | loc = dist_params[:, :hparams.latent_size] 54 | scale = util.positive_projection(hparams)( 55 | dist_params[:, hparams.latent_size:]) 56 | return (loc, scale) 57 | -------------------------------------------------------------------------------- /vaeseq/hparams.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Hyperparameters used in this library.""" 16 | 17 | import tensorflow as tf 18 | 19 | _DEFAULTS = dict( 20 | # Number of latent units per time step 21 | latent_size=4, 22 | 23 | # Model parameters 24 | obs_encoder_fc_layers=[256, 128], 25 | obs_decoder_fc_hidden_layers=[256], 26 | latent_decoder_fc_layers=[256], 27 | rnn_hidden_sizes=[32], 28 | 29 | # Default activation (relu/elu/etc.) 30 | activation='relu', 31 | 32 | # Postivitity constraint (softplus/exp/etc.) 33 | positive_projection='softplus', 34 | positive_eps=1e-5, 35 | 36 | # VAE params 37 | divergence_strength_start=1e-5, # scale on divergence penalty. 38 | divergence_strength_half=1e5, # in global-steps. 39 | vae_type='SRNN', # see vae.VAE_TYPES. 40 | use_monte_carlo_kl=False, 41 | srnn_use_res_q=True, 42 | 43 | # Training parameters 44 | learning_rate=0.0001, 45 | l1_regularization=0.0, 46 | l2_regularization=0.01, 47 | batch_size=32, 48 | sequence_size=5, 49 | clip_gradient_norm=1., 50 | check_numerics=True, 51 | 52 | # Evaluation parameters 53 | log_prob_samples=10, # number of latent samples to average over. 
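    # Any of the defaults above can be overridden per example (see
    # vaeseq/examples/*/hparams.py) or at runtime via the flag string parsed
    # in make_hparams() below, e.g. make_hparams("vae_type=RNN,latent_size=8").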
54 | ) 55 | 56 | def make_hparams(flag_value=None, **kwargs): 57 | """Initialize HParams with the defaults in this module.""" 58 | init = dict(_DEFAULTS) 59 | init.update(kwargs) 60 | ret = tf.contrib.training.HParams(**init) 61 | if flag_value: 62 | ret.parse(flag_value) 63 | return ret 64 | -------------------------------------------------------------------------------- /vaeseq/examples/play/environment_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for environment.py.""" 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from vaeseq import util 21 | from vaeseq.examples.play import environment as env_mod 22 | 23 | 24 | class EnvironmentTest(tf.test.TestCase): 25 | 26 | def test_environment(self): 27 | hparams = tf.contrib.training.HParams( 28 | game="CartPole-v0", 29 | game_output_size=[4]) 30 | left_logits = [-100., 100.] 31 | right_logits = [100., -100.] 32 | actions = [[left_logits, right_logits] * 10, 33 | [right_logits, left_logits] * 10] 34 | batch_size = len(actions) 35 | actions = tf.constant(actions, dtype=tf.float32) 36 | env = env_mod.Environment(hparams) 37 | initial_state = env.initial_state(batch_size=batch_size) 38 | output_dtypes = env.output_dtype 39 | observed, _ = util.heterogeneous_dynamic_rnn( 40 | env, actions, 41 | initial_state=initial_state, 42 | output_dtypes=output_dtypes) 43 | with self.test_session() as sess: 44 | observed = sess.run(observed) 45 | outputs = observed["output"] 46 | game_over = observed["game_over"] 47 | scores = observed["score"] 48 | self.assertTrue(np.all(scores[game_over > 0] == 1)) 49 | nonzero_gameover_scores = scores[np.nonzero(scores[game_over > 0])] 50 | nonzero_gameover_outs = outputs[np.nonzero(outputs[game_over > 0])] 51 | self.assertLessEqual(len(nonzero_gameover_scores), batch_size) 52 | self.assertLessEqual(len(nonzero_gameover_outs), batch_size * 4) 53 | 54 | 55 | if __name__ == "__main__": 56 | tf.test.main() 57 | -------------------------------------------------------------------------------- /vaeseq/context_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
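# The tests below double as small usage examples: context_mod.as_context(...)
# wraps constant per-step inputs, EncodeObserved(encoder) feeds the encoded
# previous observation forward as the next step's context, and Chain([...])
# combines several context modules.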
14 | 15 | """Tests for context modules.""" 16 | 17 | import tensorflow as tf 18 | 19 | from vaeseq import codec 20 | from vaeseq import context as context_mod 21 | 22 | 23 | class ContextTest(tf.test.TestCase): 24 | 25 | def testConstant(self): 26 | context = context_mod.as_context(tf.constant([[1,2,3]])) 27 | observed = tf.constant([["a", "b", "c", "d", "e"]]) 28 | contexts = context.from_observations(observed) 29 | with self.test_session() as sess: 30 | contexts = sess.run(contexts) 31 | self.assertAllEqual(contexts, [[1, 2, 3, 0, 0]]) 32 | 33 | def testEncodeObserved(self): 34 | encoder = codec.FlattenEncoder(input_size=tf.TensorShape([1])) 35 | context = context_mod.EncodeObserved(encoder) 36 | observed = tf.constant([[[1.], [2.], [3.]]]) 37 | contexts = context.from_observations(observed) 38 | with self.test_session() as sess: 39 | contexts = sess.run(contexts) 40 | self.assertAllClose(contexts, [[[0.], [1.], [2.]]]) 41 | 42 | def testChain(self): 43 | inputs = tf.constant([[[10.], [20.]]]) 44 | observed = tf.constant([[[1.], [2.], [3.]]]) 45 | encoder = codec.FlattenEncoder(input_size=tf.TensorShape([1])) 46 | context = context_mod.Chain([ 47 | context_mod.as_context(inputs), 48 | context_mod.EncodeObserved(encoder, input_encoder=encoder), 49 | ]) 50 | contexts = context.from_observations(observed) 51 | with self.test_session() as sess: 52 | inputs, contexts = sess.run(contexts) 53 | self.assertAllClose(inputs, [[[10.], [20.], [0.]]]) 54 | self.assertAllClose(contexts, [[[0.], [1.], [2.]]]) 55 | 56 | 57 | if __name__ == '__main__': 58 | tf.test.main() 59 | -------------------------------------------------------------------------------- /vaeseq/examples/midi/dataset_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
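# Note: these tests require the pretty_midi package (listed under
# install_requires in setup.py).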
14 | 15 | """Tests for dataset.py functionality.""" 16 | 17 | import os.path 18 | import numpy as np 19 | import pretty_midi 20 | import tensorflow as tf 21 | 22 | from vaeseq.examples.midi import dataset as dataset_mod 23 | 24 | 25 | class DatasetTest(tf.test.TestCase): 26 | 27 | def _write_midi(self, note): 28 | """Write a temporary MIDI file with a note playing for one second.""" 29 | temp_path = os.path.join(self.get_temp_dir(), 30 | "note_{}.mid".format(note)) 31 | dataset_mod.write_test_note(temp_path, 1.0, note) 32 | return temp_path 33 | 34 | def test_piano_roll_sequences(self): 35 | filenames = [self._write_midi(5), self._write_midi(7)] 36 | batch_size = 2 37 | sequence_size = 3 38 | rate = 2 39 | dataset = dataset_mod.piano_roll_sequences( 40 | filenames, batch_size, sequence_size, rate) 41 | iterator = dataset.make_initializable_iterator() 42 | batch = iterator.get_next() 43 | with self.test_session() as sess: 44 | sess.run(iterator.initializer) 45 | batch = sess.run(batch) 46 | self.assertAllEqual(batch.shape, [batch_size, sequence_size, 128]) 47 | batch_idx, time_idx, note_idx = np.where(batch) 48 | self.assertAllEqual(batch_idx, [0, 0, 1, 1]) 49 | self.assertAllEqual(time_idx, [0, 1, 0, 1]) 50 | self.assertEqual(note_idx[0], note_idx[1]) 51 | self.assertIn(note_idx[0], (5, 7)) 52 | self.assertEqual(note_idx[2], note_idx[3]) 53 | self.assertIn(note_idx[1], (5, 7)) 54 | 55 | def test_piano_roll_to_midi(self): 56 | np.random.seed(0) 57 | piano_roll = np.random.uniform(size=(200, 128)) > 0.5 58 | midi = dataset_mod.piano_roll_to_midi(piano_roll, 2) 59 | self.assertAllEqual(piano_roll.T, midi.get_piano_roll(2) > 0) 60 | 61 | 62 | if __name__ == "__main__": 63 | tf.test.main() 64 | -------------------------------------------------------------------------------- /vaeseq/examples/text/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
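# Illustrative usage sketch (the corpus path is hypothetical; dataset_test.py
# exercises the real behaviour):
#   char_to_id, id_to_char = vocabulary("corpus.txt", max_size=100)
#   batches = characters("corpus.txt", batch_size=32, sequence_size=40)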
14 | 15 | """Dataset for iterating over text.""" 16 | 17 | import collections 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | 22 | def _split_string(string): 23 | """Splits a byte string into an array of character bytes.""" 24 | text = tf.compat.as_text(string) 25 | ret = np.empty(len(text), dtype=np.object) 26 | for i, char in enumerate(text): 27 | ret[i] = tf.compat.as_bytes(char) 28 | return ret 29 | 30 | 31 | def vocabulary(filename, max_size=None, num_oov_buckets=1): 32 | """Builds vocabulary and ID lookup tables from the given file.""" 33 | 34 | def _unique_chars(filename): 35 | """Returns the used alphabet as an array of strings.""" 36 | counts = collections.Counter() 37 | with tf.gfile.Open(filename) as file_: 38 | for line in file_: 39 | counts.update(_split_string(line)) 40 | alphabet = [k for (k, _) in counts.most_common(max_size)] 41 | alphabet.sort() 42 | return np.asarray(alphabet, dtype=np.object) 43 | 44 | chars, = tf.py_func(_unique_chars, [filename], [tf.string]) 45 | char_to_id = tf.contrib.lookup.index_table_from_tensor( 46 | chars, num_oov_buckets=num_oov_buckets) 47 | id_to_char = tf.contrib.lookup.index_to_string_table_from_tensor(chars, " ") 48 | return char_to_id, id_to_char 49 | 50 | 51 | def characters(filename, batch_size, sequence_size): 52 | """Returns a dataset of characters from the given file.""" 53 | 54 | def _to_chars(line): 55 | """string scalar -> Dataset of characters (string scalars).""" 56 | chars, = tf.py_func(_split_string, [line + "\n"], [tf.string]) 57 | chars.set_shape([None]) 58 | return tf.data.Dataset.from_tensor_slices(chars) 59 | 60 | return (tf.data.TextLineDataset([filename]) 61 | .flat_map(_to_chars) 62 | .repeat() 63 | .batch(tf.to_int64(sequence_size)) 64 | .shuffle(1000) 65 | .batch(tf.to_int64(batch_size))) 66 | -------------------------------------------------------------------------------- /vaeseq/BUILD: -------------------------------------------------------------------------------- 1 | # Everything depends on TensorFlow and Sonnet. 
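# The aggregate suite below can be used to exercise everything at once, e.g.
# `bazel test //vaeseq:tests` (assuming the empty WORKSPACE is extended with
# the TensorFlow/Sonnet dependencies mentioned above).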
2 | package(default_visibility = ["//visibility:public"]) 3 | 4 | test_suite( 5 | name = "tests", 6 | tests = [ 7 | "batch_dist_test", 8 | "context_test", 9 | "model_test", 10 | "util_test", 11 | "//vaeseq/examples/midi:tests", 12 | "//vaeseq/examples/play:tests", 13 | "//vaeseq/examples/text:tests", 14 | "//vaeseq/vae:vae_test", 15 | ], 16 | ) 17 | 18 | py_library( 19 | name = "dist_module", 20 | srcs = ["dist_module.py"], 21 | deps = [":util"], 22 | ) 23 | 24 | py_library( 25 | name = "vae_module", 26 | srcs = ["vae_module.py"], 27 | deps = [ 28 | ":dist_module", 29 | ":util", 30 | ], 31 | ) 32 | 33 | py_library( 34 | name = "batch_dist", 35 | srcs = ["batch_dist.py"], 36 | ) 37 | 38 | py_test( 39 | name = "batch_dist_test", 40 | srcs = ["batch_dist_test.py"], 41 | deps = [ 42 | ":batch_dist", 43 | ], 44 | ) 45 | 46 | py_library( 47 | name = "latent", 48 | srcs = ["latent.py"], 49 | deps = [ 50 | ":dist_module", 51 | ":util", 52 | ], 53 | ) 54 | 55 | py_library( 56 | name = "hparams", 57 | srcs = ["hparams.py"], 58 | ) 59 | 60 | py_library( 61 | name = "util", 62 | srcs = ["util.py"], 63 | ) 64 | 65 | py_test( 66 | name = "util_test", 67 | srcs = ["util_test.py"], 68 | deps = [ 69 | ":hparams", 70 | ":util", 71 | ], 72 | ) 73 | 74 | py_library( 75 | name = "context", 76 | srcs = ["context.py"], 77 | deps = [ 78 | ":util", 79 | ], 80 | ) 81 | 82 | py_test( 83 | name = "context_test", 84 | srcs = ["context_test.py"], 85 | deps = [ 86 | ":codec", 87 | ":context", 88 | ], 89 | ) 90 | 91 | py_library( 92 | name = "codec", 93 | srcs = ["codec.py"], 94 | deps = [ 95 | ":batch_dist", 96 | ":dist_module", 97 | ":util", 98 | ], 99 | ) 100 | 101 | py_library( 102 | name = "train", 103 | srcs = ["train.py"], 104 | deps = [ 105 | ":util", 106 | ], 107 | ) 108 | 109 | py_library( 110 | name = "model", 111 | srcs = ["model.py"], 112 | deps = [ 113 | ":context", 114 | ":train", 115 | ":util", 116 | ], 117 | ) 118 | 119 | py_test( 120 | name = "model_test", 121 | srcs = ["model_test.py"], 122 | deps = [ 123 | ":codec", 124 | ":context", 125 | ":hparams", 126 | ":model", 127 | ":util", 128 | "//vaeseq/vae", 129 | ], 130 | ) 131 | -------------------------------------------------------------------------------- /vaeseq/vae/rnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """We can view an RNN as a VAE with no latent variables: 16 | 17 | Notation: 18 | - d_1:T are the (deterministic) RNN outputs. 19 | - x_1:T are the observed states. 20 | - c_1:T are per-timestep inputs. 21 | 22 | Generative model 23 | ===================== 24 | x_1 x_t 25 | ^ ^ 26 | | | 27 | d_1 ------------> d_t 28 | ^ ^ 29 | | | 30 | c_1 c_t 31 | """ 32 | 33 | import tensorflow as tf 34 | from tensorflow.contrib import distributions 35 | 36 | from .. import util 37 | from .. 
import vae_module 38 | 39 | class RNN(vae_module.VAECore): 40 | """Implementation of an RNN as a sequential VAE where all latent 41 | variables are deterministic.""" 42 | 43 | def __init__(self, hparams, obs_encoder, obs_decoder, name=None): 44 | super(RNN, self).__init__(hparams, obs_encoder, obs_decoder, name) 45 | with self._enter_variable_scope(): 46 | self._d_core = util.make_rnn(hparams, name="d_core") 47 | 48 | @property 49 | def state_size(self): 50 | return self._d_core.state_size 51 | 52 | def _next_state(self, d_state, event=None): 53 | del event # Not used. 54 | return d_state 55 | 56 | def _initial_state(self, batch_size): 57 | return self._d_core.initial_state(batch_size) 58 | 59 | def _build(self, input_, d_state): 60 | d_out, d_state = self._d_core(util.concat_features(input_), d_state) 61 | return self._obs_decoder(d_out), d_state 62 | 63 | def _infer_latents(self, inputs, observed): 64 | """Because the RNN latent state is fully deterministic, there's no 65 | need to do two passes over the training data.""" 66 | del inputs # Not used. 67 | batch_size = util.batch_size_from_nested_tensors(observed) 68 | sequence_size = util.sequence_size_from_nested_tensors(observed) 69 | divs = tf.zeros([batch_size, sequence_size], name="divergences") 70 | return None, divs 71 | -------------------------------------------------------------------------------- /vaeseq/examples/text/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions to build up training and generation graphs.""" 16 | 17 | from __future__ import print_function 18 | 19 | import sonnet as snt 20 | import tensorflow as tf 21 | 22 | from vaeseq import context as context_mod 23 | from vaeseq import codec as codec_mod 24 | from vaeseq import model as model_mod 25 | from vaeseq import util 26 | 27 | from . 
import dataset as dataset_mod 28 | 29 | 30 | class Model(model_mod.ModelBase): 31 | """Putting everything together.""" 32 | 33 | def __init__(self, hparams, session_params, vocab_corpus): 34 | self._char_to_id, self._id_to_char = dataset_mod.vocabulary( 35 | vocab_corpus, 36 | max_size=hparams.vocab_size, 37 | num_oov_buckets=hparams.oov_buckets) 38 | super(Model, self).__init__(hparams, session_params) 39 | 40 | def _make_encoder(self): 41 | """Constructs an encoding for a single character ID.""" 42 | embed = snt.Embed( 43 | vocab_size=self.hparams.vocab_size + self.hparams.oov_buckets, 44 | embed_dim=self.hparams.embed_size) 45 | mlp = codec_mod.MLPObsEncoder(self.hparams) 46 | return codec_mod.EncoderSequence([embed, mlp], name="obs_encoder") 47 | 48 | def _make_decoder(self): 49 | """Constructs a decoding for a single character ID.""" 50 | return codec_mod.MLPObsDecoder( 51 | self.hparams, 52 | decoder=codec_mod.CategoricalDecoder(), 53 | param_size=self.hparams.vocab_size + self.hparams.oov_buckets, 54 | name="obs_decoder") 55 | 56 | def _make_dataset(self, corpus): 57 | dataset = dataset_mod.characters(corpus, 58 | util.batch_size(self.hparams), 59 | util.sequence_size(self.hparams)) 60 | iterator = dataset.make_initializable_iterator() 61 | tf.add_to_collection(tf.GraphKeys.LOCAL_INIT_OP, iterator.initializer) 62 | observed = self._char_to_id.lookup(iterator.get_next()) 63 | inputs = None 64 | return inputs, observed 65 | 66 | def _make_output_summary(self, tag, observed): 67 | return tf.summary.text(tag, self._render(observed), collections=[]) 68 | 69 | def _render(self, observed): 70 | """Returns a batch of strings corresponding to the ID sequences.""" 71 | # Note, tf.reduce_sum doesn't work on strings. 72 | return tf.py_func(lambda chars: chars.sum(axis=-1), 73 | [self._id_to_char.lookup(tf.to_int64(observed))], 74 | [tf.string])[0] 75 | 76 | def generate(self): 77 | """Return UTF-8 strings rather than bytes.""" 78 | for string in super(Model, self).generate(): 79 | yield tf.compat.as_text(string) 80 | -------------------------------------------------------------------------------- /vaeseq/examples/text/dataset_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # -*- coding: utf-8 -*- 16 | """Tests for dataset.py functionality.""" 17 | 18 | import io 19 | import os.path 20 | import tensorflow as tf 21 | 22 | from vaeseq.examples.text import dataset as dataset_mod 23 | 24 | 25 | class DatasetTest(tf.test.TestCase): 26 | 27 | def _write_corpus(self, text): 28 | """Save text to a temporary file and return the path.""" 29 | temp_path = os.path.join(self.get_temp_dir(), "corpus.txt") 30 | with io.open(temp_path, "w", encoding="utf-8") as temp_file: 31 | temp_file.write(text) 32 | return temp_path 33 | 34 | def test_vocabulary(self): 35 | text = u"hello\nこんにちは" 36 | vocab_size = len(set(text)) 37 | corpus = self._write_corpus(text) 38 | char_to_id, id_to_char = dataset_mod.vocabulary(corpus) 39 | ids = char_to_id.lookup( 40 | tf.constant([tf.compat.as_bytes(c) 41 | for c in ["X", "l", "\n", u"こ"]])) 42 | chars = id_to_char.lookup(tf.constant([0, 100], dtype=tf.int64)) 43 | with self.test_session() as sess: 44 | sess.run(tf.tables_initializer()) 45 | ids, chars = sess.run([ids, chars]) 46 | self.assertEqual(ids[0], vocab_size) 47 | self.assertTrue(0 <= ids[1] < vocab_size) 48 | self.assertTrue(0 <= ids[2] < vocab_size) 49 | self.assertTrue(0 <= ids[3] < vocab_size) 50 | chars = [tf.compat.as_text(c) for c in chars] 51 | self.assertTrue(chars[0] in text) 52 | self.assertEqual(chars[1], " ") 53 | 54 | def test_vocabulary_capped(self): 55 | text = u"hello\nこんにちは" 56 | corpus = self._write_corpus(text) 57 | char_to_id, id_to_char = dataset_mod.vocabulary(corpus, max_size=1, 58 | num_oov_buckets=1) 59 | ids = char_to_id.lookup( 60 | tf.constant([tf.compat.as_bytes(c) 61 | for c in ["X", "l", "\n", u"こ"]])) 62 | chars = id_to_char.lookup(tf.constant([0, 2], dtype=tf.int64)) 63 | with self.test_session() as sess: 64 | sess.run(tf.tables_initializer()) 65 | ids, chars = sess.run([ids, chars]) 66 | self.assertAllEqual(ids, [1, 0, 1, 1]) 67 | self.assertAllEqual(chars, [b"l", b" "]) 68 | 69 | def test_characters(self): 70 | tf.set_random_seed(1) 71 | text = u"hello\nこんにちは" 72 | dataset = dataset_mod.characters(self._write_corpus(text), 2, 6) 73 | iterator = dataset.make_initializable_iterator() 74 | batch = iterator.get_next() 75 | with self.test_session() as sess: 76 | sess.run(iterator.initializer) 77 | self.assertAllEqual( 78 | sess.run(batch), 79 | [[tf.compat.as_bytes(c) for c in u"こんにちは\n"], 80 | [tf.compat.as_bytes(c) for c in u"hello\n"]]) 81 | 82 | 83 | if __name__ == "__main__": 84 | tf.test.main() 85 | -------------------------------------------------------------------------------- /vaeseq/examples/midi/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
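# Illustrative usage sketch (file names and the boolean `roll` array are
# hypothetical; dataset_test.py exercises the real behaviour):
#   sequences = piano_roll_sequences(["a.mid", "b.mid"],
#                                    batch_size=32, sequence_size=64, rate=32)
#   midi = piano_roll_to_midi(roll, sample_rate=32)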
14 | 15 | """Dataset for iterating over MIDI files.""" 16 | 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import pretty_midi 21 | import tensorflow as tf 22 | 23 | 24 | def piano_roll_sequences(filenames, batch_size, sequence_size, rate=100): 25 | """Returns a dataset of piano roll sequences from the given files..""" 26 | 27 | def _to_piano_roll(filename, sequence_size): 28 | """Load a file and return consecutive piano roll sequences.""" 29 | try: 30 | midi = pretty_midi.PrettyMIDI(tf.compat.as_text(filename)) 31 | except Exception: 32 | print("Skipping corrupt MIDI file", filename) 33 | return np.zeros([0, sequence_size, 128], dtype=np.bool) 34 | roll = np.asarray(midi.get_piano_roll(rate).transpose(), dtype=np.bool) 35 | assert roll.shape[1] == 128 36 | # Pad the roll to a multiple of sequence_size 37 | length = len(roll) 38 | remainder = length % sequence_size 39 | if remainder: 40 | new_length = length + sequence_size - remainder 41 | roll = np.resize(roll, (new_length, 128)) 42 | roll[length:, :] = False 43 | length = new_length 44 | return np.reshape(roll, (length // sequence_size, sequence_size, 128)) 45 | 46 | def _to_piano_roll_dataset(filename): 47 | """Filename (string scalar) -> Dataset of piano roll sequences.""" 48 | sequences, = tf.py_func(_to_piano_roll, 49 | [filename, sequence_size], 50 | [tf.bool]) 51 | sequences.set_shape([None, None, 128]) 52 | return tf.data.Dataset.from_tensor_slices(sequences) 53 | 54 | batch_size = tf.to_int64(batch_size) 55 | return (tf.data.Dataset.from_tensor_slices(filenames) 56 | .interleave(_to_piano_roll_dataset, 57 | cycle_length=batch_size * 5, 58 | block_length=1) 59 | .repeat() 60 | .shuffle(1000) 61 | .batch(batch_size)) 62 | 63 | 64 | def piano_roll_to_midi(piano_roll, sample_rate): 65 | """Convert the piano roll to a PrettyMIDI object. 66 | See: http://github.com/craffel/examples/reverse_pianoroll.py 67 | """ 68 | midi = pretty_midi.PrettyMIDI() 69 | instrument = pretty_midi.Instrument(0) 70 | midi.instruments.append(instrument) 71 | padded_roll = np.pad(piano_roll, [(1, 1), (0, 0)], mode='constant') 72 | changes = np.diff(padded_roll, axis=0) 73 | notes = np.full(piano_roll.shape[1], -1, dtype=np.int) 74 | for tick, pitch in zip(*np.where(changes)): 75 | prev = notes[pitch] 76 | if prev == -1: 77 | notes[pitch] = tick 78 | continue 79 | notes[pitch] = -1 80 | instrument.notes.append(pretty_midi.Note( 81 | velocity=100, 82 | pitch=pitch, 83 | start=prev / float(sample_rate), 84 | end=tick / float(sample_rate))) 85 | return midi 86 | 87 | 88 | def write_test_note(path, duration, note): 89 | midi = pretty_midi.PrettyMIDI() 90 | instrument = pretty_midi.Instrument(0) 91 | instrument.notes.append(pretty_midi.Note(100, note, 0.0, duration)) 92 | midi.instruments.append(instrument) 93 | midi.write(path) 94 | -------------------------------------------------------------------------------- /vaeseq/examples/play/play.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Model Cart-Pole and train an Agent via policy gradient.""" 16 | 17 | import argparse 18 | import time 19 | import sys 20 | import tensorflow as tf 21 | 22 | from vaeseq.examples.play import environment as env_mod 23 | from vaeseq.examples.play import hparams as hparams_mod 24 | from vaeseq.examples.play import model as model_mod 25 | from vaeseq import util 26 | 27 | 28 | def train(flags): 29 | model = model_mod.Model( 30 | hparams=hparams_mod.make_hparams(flags.hparams), 31 | session_params=flags) 32 | model.train("train", flags.num_steps) 33 | 34 | def run(flags): 35 | hparams = hparams_mod.make_hparams(flags.hparams) 36 | hparams.batch_size = 1 37 | hparams.sequence_size = flags.max_moves 38 | batch_size = util.batch_size(hparams) 39 | model = model_mod.Model(hparams=hparams, session_params=flags) 40 | if flags.agent == "trained": 41 | agent = model.agent 42 | elif flags.agent == "random": 43 | agent = model_mod.agent_mod.RandomAgent(hparams) 44 | else: 45 | raise ValueError("I don't understand --agent " + flags.agent) 46 | outputs = agent.drive_rnn( 47 | model.env, 48 | sequence_size=util.sequence_size(hparams), 49 | initial_state=agent.initial_state(batch_size=batch_size), 50 | cell_initial_state=model.env.initial_state(batch_size=batch_size)) 51 | score = tf.reduce_sum(outputs["score"]) 52 | with model.eval_session() as sess: 53 | model.env.start_render_thread() 54 | for _ in range(flags.num_games): 55 | print("Score: ", sess.run(score)) 56 | sys.stdout.flush() 57 | model.env.stop_render_thread() 58 | 59 | 60 | # Argument parsing code below. 61 | 62 | def common_args(args): 63 | model_mod.Model.SessionParams.add_parser_arguments(args) 64 | args.add_argument( 65 | "--hparams", default="", 66 | help="Model hyperparameter overrides.") 67 | 68 | 69 | def train_args(args): 70 | common_args(args) 71 | args.add_argument( 72 | "--num-steps", type=int, default=int(1e6), 73 | help="Number of training iterations.") 74 | args.set_defaults(entry=train) 75 | 76 | 77 | def run_args(args): 78 | common_args(args) 79 | args.add_argument( 80 | "--max-moves", type=int, default=1000, 81 | help="Maximum number of moves per game.") 82 | args.add_argument( 83 | "--num-games", type=int, default=1, 84 | help="Number of games to play.") 85 | args.add_argument( 86 | "--agent", default="trained", choices=["trained", "random"], 87 | help="Which agent to use.") 88 | args.set_defaults(entry=run) 89 | 90 | 91 | def main(): 92 | args = argparse.ArgumentParser() 93 | subcommands = args.add_subparsers(title="subcommands") 94 | train_args(subcommands.add_parser( 95 | "train", help="Train a model.")) 96 | run_args(subcommands.add_parser( 97 | "run", help="Run a traned model.")) 98 | flags, unparsed_args = args.parse_known_args(sys.argv[1:]) 99 | if not hasattr(flags, "entry"): 100 | args.print_help() 101 | return 1 102 | tf.logging.set_verbosity(tf.logging.INFO) 103 | tf.app.run(main=lambda _unused_argv: flags.entry(flags), 104 | argv=[sys.argv[0]] + unparsed_args) 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /vaeseq/model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ModelBase.""" 16 | 17 | import itertools 18 | import os.path 19 | import tensorflow as tf 20 | import sonnet as snt 21 | 22 | from vaeseq import codec 23 | from vaeseq import context as context_mod 24 | from vaeseq import hparams as hparams_mod 25 | from vaeseq import model as model_mod 26 | from vaeseq import util 27 | 28 | 29 | class ModelTest(tf.test.TestCase): 30 | 31 | def setUp(self): 32 | super(ModelTest, self).setUp() 33 | log_dir = os.path.join(self.get_temp_dir(), "log_dir") 34 | session_config = tf.ConfigProto() 35 | session_config.device_count["GPU"] = 0 36 | session_params = model_mod.ModelBase.SessionParams( 37 | log_dir=log_dir, session_config=session_config) 38 | self._setup_model(session_params) 39 | 40 | def _setup_model(self, session_params): 41 | self.train_dataset = "train" 42 | self.valid_dataset = "valid" 43 | self.hparams = hparams_mod.make_hparams() 44 | self.model = MockModel(self.hparams, session_params) 45 | 46 | def _train(self): 47 | return self.model.train(self.train_dataset, num_steps=20, 48 | valid_dataset=self.valid_dataset) 49 | 50 | def _evaluate(self): 51 | return self.model.evaluate(self.train_dataset, num_steps=20) 52 | 53 | def test_inputs(self): 54 | train_inputs, observed = self.model.dataset(self.train_dataset) 55 | train_inputs = context_mod.as_tensors(train_inputs, observed) 56 | gen_inputs = self.model.inputs 57 | gen_inputs = context_mod.as_tensors(gen_inputs, observed) 58 | with self.model.eval_session() as sess: 59 | train_inputs, gen_inputs = sess.run((train_inputs, gen_inputs)) 60 | def _inputs_compatible(inp1, inp2): 61 | self.assertEqual(inp1.dtype, inp2.dtype) 62 | self.assertEqual(inp1.shape, inp2.shape) 63 | snt.nest.map(_inputs_compatible, train_inputs, gen_inputs) 64 | 65 | def test_training_and_eval(self): 66 | train_debug1 = self._train() 67 | eval_debug1 = self._evaluate() 68 | train_debug2 = self._train() 69 | eval_debug2 = self._evaluate() 70 | self.assertLess(train_debug2["elbo_loss"], train_debug1["elbo_loss"]) 71 | self.assertGreater(eval_debug2["log_prob"], eval_debug1["log_prob"]) 72 | 73 | def test_genaration(self): 74 | # Just make sure the graph executes without error. 
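        # model.generate() yields sampled sequences one at a time, so
        # itertools.islice below caps this smoke test at ten of them.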
75 | for seq in itertools.islice(self.model.generate(), 10): 76 | tf.logging.debug("Generated: %r", seq) 77 | 78 | 79 | class MockModel(model_mod.ModelBase): 80 | """Modeling zeros for testing.""" 81 | 82 | def _make_encoder(self): 83 | return codec.MLPObsEncoder(self.hparams) 84 | 85 | def _make_decoder(self): 86 | return codec.BatchDecoder( 87 | codec.MLPObsDecoder( 88 | self.hparams, 89 | codec.NormalDecoder( 90 | positive_projection=util.positive_projection(self.hparams)), 91 | param_size=4), 92 | event_size=[2]) 93 | 94 | def _make_dataset(self, dataset): 95 | batch_size = util.batch_size(self.hparams) 96 | sequence_size = util.sequence_size(self.hparams) 97 | observed = tf.zeros([batch_size, sequence_size, 2]) 98 | inputs = None 99 | return inputs, observed 100 | 101 | def _make_output_summary(self, tag, observed): 102 | return tf.summary.histogram(tag, observed, collections=[]) 103 | 104 | 105 | if __name__ == "__main__": 106 | tf.test.main() 107 | -------------------------------------------------------------------------------- /vaeseq/examples/play/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Game-playing agent.""" 16 | 17 | import abc 18 | import sonnet as snt 19 | import tensorflow as tf 20 | 21 | from vaeseq import context as context_mod 22 | from vaeseq import util 23 | 24 | 25 | class AgentBase(context_mod.Context): 26 | """Base class for input agents.""" 27 | 28 | def __init__(self, hparams, name=None): 29 | super(AgentBase, self).__init__(name=name) 30 | self._hparams = hparams 31 | self._num_actions = tf.TensorShape([self._hparams.game_action_space]) 32 | 33 | @property 34 | def output_size(self): 35 | return self._num_actions 36 | 37 | @property 38 | def output_dtype(self): 39 | return tf.float32 40 | 41 | @abc.abstractmethod 42 | def get_variables(self): 43 | """Returns the variables used by this Agent.""" 44 | 45 | 46 | class RandomAgent(AgentBase): 47 | """Produces actions randomly, for exploration.""" 48 | 49 | def __init__(self, hparams, name=None): 50 | super(RandomAgent, self).__init__(hparams, name=name) 51 | self._dist = tf.distributions.Dirichlet(tf.ones(self._num_actions)) 52 | 53 | @property 54 | def state_size(self): 55 | return tf.TensorShape([0]) 56 | 57 | @property 58 | def state_dtype(self): 59 | return tf.float32 60 | 61 | def observe(self, observation, state): 62 | return state 63 | 64 | def get_variables(self): 65 | return None 66 | 67 | def _build(self, input_, state): 68 | del input_ # Not used. 
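        # The random agent ignores its input entirely; it only needs the
        # batch dimension, recovered from `state` below, to draw a batch of
        # samples from the Dirichlet over actions.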
69 | batch_size = tf.shape(state)[0] 70 | return self._dist.sample(batch_size), state 71 | 72 | 73 | class TrainableAgent(AgentBase): 74 | """Produces actions from a policy RNN.""" 75 | 76 | def __init__(self, hparams, obs_encoder, name=None): 77 | super(TrainableAgent, self).__init__(hparams, name=name) 78 | self._agent_variables = None 79 | self._obs_encoder = obs_encoder 80 | with self._enter_variable_scope(): 81 | self._policy_rnn = util.make_rnn(hparams, name="policy_rnn") 82 | self._project_act = util.make_mlp( 83 | hparams, layers=[hparams.game_action_space], name="policy_proj") 84 | 85 | @property 86 | def state_size(self): 87 | return dict(policy=self._policy_rnn.state_size, 88 | action_logits=self._num_actions, 89 | obs_enc=self._obs_encoder.output_size) 90 | 91 | @property 92 | def state_dtype(self): 93 | return snt.nest.map(lambda _: tf.float32, self.state_size) 94 | 95 | def get_variables(self): 96 | if self._agent_variables is None: 97 | raise ValueError("Agent variables haven't been constructed yet.") 98 | return self._agent_variables 99 | 100 | def observe(self, observation, state): 101 | obs_enc = self._obs_encoder(observation) 102 | rnn_state = state["policy"] 103 | hidden, rnn_state = self._policy_rnn(obs_enc, rnn_state) 104 | action_logits = self._project_act(hidden) 105 | if self._agent_variables is None: 106 | self._agent_variables = snt.nest.flatten( 107 | (self._policy_rnn.get_variables(), 108 | self._project_act.get_variables())) 109 | if self._hparams.explore_temp > 0: 110 | dist = tf.contrib.distributions.ExpRelaxedOneHotCategorical( 111 | self._hparams.explore_temp, 112 | logits=action_logits) 113 | action_logits = dist.sample() 114 | return dict(policy=rnn_state, 115 | action_logits=action_logits, 116 | obs_enc=obs_enc) 117 | 118 | def _build(self, input_, state): 119 | if input_ is not None: 120 | raise ValueError("I don't know how to encode any inputs.") 121 | return state["action_logits"], state 122 | -------------------------------------------------------------------------------- /vaeseq/util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for utility functions in util.py.""" 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from vaeseq import hparams as hparams_mod 21 | from vaeseq import util 22 | 23 | 24 | def _add_sub_core(): 25 | """Creates an RNN Core: (a, b, c), s -> (a+s, b-s, c+s), -s""" 26 | return util.WrapRNNCore( 27 | lambda inp, state: (inp + state, -state), 28 | state_size=tf.TensorShape([]), 29 | output_size=tf.TensorShape([1]), 30 | name="AddSubCore") 31 | 32 | 33 | def _identity_core(input_shape): 34 | """Creates an RNN Core that just propagates its inputs.""" 35 | return util.WrapRNNCore( 36 | lambda inp, _: (inp, ()), 37 | state_size=(), 38 | output_size=input_shape, 39 | name="IdentityCore") 40 | 41 | 42 | class UtilTest(tf.test.TestCase): 43 | 44 | def test_calc_kl_analytical(self): 45 | hparams = hparams_mod.make_hparams(use_monte_carlo_kl=False) 46 | dist_a = tf.distributions.Bernoulli(probs=0.5) 47 | dist_b = tf.distributions.Bernoulli(probs=0.3) 48 | kl_div = util.calc_kl(hparams, dist_a.sample(), dist_a, dist_b) 49 | with self.test_session(): 50 | self.assertAllClose( 51 | kl_div.eval(), 52 | 0.5 * (np.log(0.5 / 0.3) + np.log(0.5 / 0.7))) 53 | 54 | def test_calc_kl_mc(self): 55 | tf.set_random_seed(0) 56 | hparams = hparams_mod.make_hparams(use_monte_carlo_kl=True) 57 | samples = 1000 58 | dist_a = tf.distributions.Bernoulli(probs=tf.fill([samples], 0.5)) 59 | dist_b = tf.distributions.Bernoulli(probs=tf.fill([samples], 0.3)) 60 | kl_div = tf.reduce_mean( 61 | util.calc_kl(hparams, dist_a.sample(), dist_a, dist_b), 62 | axis=0) 63 | with self.test_session(): 64 | self.assertAllClose( 65 | kl_div.eval(), 66 | 0.5 * (np.log(0.5 / 0.3) + np.log(0.5 / 0.7)), 67 | atol=0.05) 68 | 69 | def test_concat_features(self): 70 | feature1 = tf.constant([[1, 2]]) 71 | feature2 = tf.constant([[3]]) 72 | feature3 = tf.constant([[4, 5]]) 73 | with self.test_session(): 74 | self.assertAllEqual( 75 | util.concat_features((feature1, (feature2, feature3))).eval(), 76 | [[1, 2, 3, 4, 5]]) 77 | 78 | def test_wrap_rnn_core(self): 79 | core = _add_sub_core() 80 | input_ = tf.constant([[[1], [2], [3]]]) 81 | state = tf.constant(5) 82 | output, out_state = tf.nn.dynamic_rnn(core, input_, initial_state=state) 83 | with self.test_session(): 84 | self.assertEqual(out_state.eval(), -5) 85 | self.assertAllEqual(output.eval(), [[[1 + 5], [2 - 5], [3 + 5]]]) 86 | 87 | def test_reverse_dynamic_rnn(self): 88 | core = _add_sub_core() 89 | input_ = tf.constant([[[1], [2]]]) 90 | state = tf.constant(5) 91 | output, _ = util.reverse_dynamic_rnn( 92 | core, input_, initial_state=state) 93 | with self.test_session(): 94 | self.assertAllEqual(output.eval(), [[[1 - 5], [2 + 5]]]) 95 | 96 | def test_heterogeneous_dynamic_rnn(self): 97 | inputs = (tf.constant([[["hi"], ["there"]]]), 98 | tf.constant([[[1, 2], [3, 4]]], dtype=tf.int32)) 99 | core = _identity_core((tf.TensorShape([1]), tf.TensorShape([2]))) 100 | outputs, _ = util.heterogeneous_dynamic_rnn( 101 | core, inputs, initial_state=(), output_dtypes=(tf.string, tf.int32)) 102 | with self.test_session() as sess: 103 | inputs, outputs = sess.run((inputs, outputs)) 104 | self.assertAllEqual(inputs[0], outputs[0]) 105 | self.assertAllEqual(inputs[1], outputs[1]) 106 | 107 | 108 | if __name__ == "__main__": 109 | tf.test.main() 110 | -------------------------------------------------------------------------------- /vaeseq/examples/midi/midi.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, 
Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Model MIDI sequences.""" 16 | 17 | import argparse 18 | import itertools 19 | import os.path 20 | import sys 21 | 22 | import scipy.io.wavfile 23 | import tensorflow as tf 24 | 25 | from vaeseq.examples.midi import hparams as hparams_mod 26 | from vaeseq.examples.midi import model as model_mod 27 | 28 | 29 | def train(flags): 30 | model = model_mod.Model( 31 | hparams=hparams_mod.make_hparams(flags.hparams), 32 | session_params=flags) 33 | model.train(flags.train_files, flags.num_steps, 34 | valid_dataset=flags.valid_files) 35 | 36 | 37 | def evaluate(flags): 38 | model = model_mod.Model( 39 | hparams=hparams_mod.make_hparams(flags.hparams), 40 | session_params=flags) 41 | model.evaluate(flags.eval_files, flags.num_steps) 42 | 43 | 44 | def generate(flags): 45 | hparams = hparams_mod.make_hparams(flags.hparams) 46 | hparams.sequence_size = int(hparams.rate * flags.length) 47 | model = model_mod.Model(hparams=hparams, session_params=flags) 48 | samples = itertools.islice(model.generate(), flags.num_samples) 49 | for i, wav in enumerate(samples): 50 | basename = "generated_{:02}.wav".format(i + 1) 51 | tf.logging.info("Writing %s.", basename) 52 | out_path = os.path.join(flags.out_dir, basename) 53 | scipy.io.wavfile.write(out_path, model_mod.Model.SYNTHESIZED_RATE, wav) 54 | 55 | 56 | # Argument parsing code below. 
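# Each subcommand parser below registers its handler via
# args.set_defaults(entry=...), and main() dispatches to flags.entry(flags)
# inside tf.app.run. A typical invocation might look like the following
# (file paths are placeholders; Model.SessionParams.add_parser_arguments
# contributes additional session flags):
#
#   python -m vaeseq.examples.midi.midi train --train-files a.mid b.mid
#   python -m vaeseq.examples.midi.midi generate --out-dir /tmp/samples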
57 | 58 | def common_args(args): 59 | model_mod.Model.SessionParams.add_parser_arguments(args) 60 | args.add_argument( 61 | "--hparams", default="", 62 | help="Model hyperparameter overrides.") 63 | 64 | 65 | def train_args(args): 66 | common_args(args) 67 | args.add_argument( 68 | "--train-files", nargs="+", 69 | help="MIDI files to train on.", 70 | required=True) 71 | args.add_argument( 72 | "--valid-files", nargs="+", 73 | help="MIDI files to evaluate while training.") 74 | args.add_argument( 75 | "--num-steps", type=int, default=int(1e6), 76 | help="Number of training iterations.") 77 | args.set_defaults(entry=train) 78 | 79 | 80 | def eval_args(args): 81 | common_args(args) 82 | args.add_argument( 83 | "--eval-files", nargs="+", 84 | help="MIDI files to evaluate.", 85 | required=True) 86 | args.add_argument( 87 | "--num-steps", type=int, default=int(1e3), 88 | help="Number of eval iterations.") 89 | args.set_defaults(entry=evaluate) 90 | 91 | 92 | def generate_args(args): 93 | common_args(args) 94 | args.add_argument( 95 | "--out-dir", 96 | help="Where to store the generated sequences.", 97 | required=True) 98 | args.add_argument( 99 | "--length", type=float, default=5., 100 | help="Length of the generated sequences, in seconds.") 101 | args.add_argument( 102 | "--num-samples", type=int, default=20, 103 | help="Number of sequences to generate.") 104 | args.set_defaults(entry=generate) 105 | 106 | 107 | def main(): 108 | args = argparse.ArgumentParser() 109 | subcommands = args.add_subparsers(title="subcommands") 110 | train_args(subcommands.add_parser( 111 | "train", help="Train a model.")) 112 | eval_args(subcommands.add_parser( 113 | "evaluate", help="Evaluate a trained model.")) 114 | generate_args(subcommands.add_parser( 115 | "generate", help="Generate some music.")) 116 | flags, unparsed_args = args.parse_known_args(sys.argv[1:]) 117 | if not hasattr(flags, "entry"): 118 | args.print_help() 119 | return 1 120 | tf.logging.set_verbosity(tf.logging.INFO) 121 | tf.app.run(main=lambda _unused_argv: flags.entry(flags), 122 | argv=[sys.argv[0]] + unparsed_args) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /vaeseq/examples/text/text.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Model sequences of text, character-by-character.""" 16 | 17 | from __future__ import print_function 18 | 19 | import argparse 20 | import itertools 21 | import sys 22 | import tensorflow as tf 23 | 24 | from vaeseq.examples.text import hparams as hparams_mod 25 | from vaeseq.examples.text import model as model_mod 26 | 27 | 28 | def train(flags): 29 | if flags.vocab_corpus is None: 30 | print("NOTE: no --vocab-corpus supplied; using", 31 | repr(flags.train_corpus), "for vocabulary.") 32 | model = model_mod.Model( 33 | hparams=hparams_mod.make_hparams(flags.hparams), 34 | session_params=flags, 35 | vocab_corpus=flags.vocab_corpus or flags.train_corpus) 36 | model.train(flags.train_corpus, flags.num_steps, 37 | valid_dataset=flags.valid_corpus) 38 | 39 | 40 | def evaluate(flags): 41 | model = model_mod.Model( 42 | hparams=hparams_mod.make_hparams(flags.hparams), 43 | session_params=flags, 44 | vocab_corpus=flags.vocab_corpus) 45 | model.evaluate(flags.eval_corpus, flags.num_steps) 46 | 47 | 48 | def generate(flags): 49 | hparams = hparams_mod.make_hparams(flags.hparams) 50 | hparams.sequence_size = flags.length 51 | model = model_mod.Model( 52 | hparams=hparams, 53 | session_params=flags, 54 | vocab_corpus=flags.vocab_corpus) 55 | for i, string in enumerate(itertools.islice(model.generate(), 56 | flags.num_samples)): 57 | print("#{:02d}: {}\n".format(i + 1, string)) 58 | 59 | 60 | # Argument parsing code below. 61 | 62 | def common_args(args, require_vocab): 63 | model_mod.Model.SessionParams.add_parser_arguments(args) 64 | args.add_argument( 65 | "--hparams", default="", 66 | help="Model hyperparameter overrides.") 67 | args.add_argument( 68 | "--vocab-corpus", 69 | help="Path to the corpus used for vocabulary generation.", 70 | required=require_vocab) 71 | 72 | 73 | def train_args(args): 74 | common_args(args, require_vocab=False) 75 | args.add_argument( 76 | "--train-corpus", 77 | help="Location of the training text.", 78 | required=True) 79 | args.add_argument( 80 | "--valid-corpus", 81 | help="Location of the validation text.") 82 | args.add_argument( 83 | "--num-steps", type=int, default=int(1e6), 84 | help="Number of training iterations.") 85 | args.set_defaults(entry=train) 86 | 87 | 88 | def eval_args(args): 89 | common_args(args, require_vocab=True) 90 | args.add_argument( 91 | "--eval-corpus", 92 | help="Location of the training text.", 93 | required=True) 94 | args.add_argument( 95 | "--num-steps", type=int, default=int(1e3), 96 | help="Number of eval iterations.") 97 | args.set_defaults(entry=evaluate) 98 | 99 | 100 | def generate_args(args): 101 | common_args(args, require_vocab=True) 102 | args.add_argument( 103 | "--length", type=int, default=1000, 104 | help="Length of the generated strings.") 105 | args.add_argument( 106 | "--num-samples", type=int, default=20, 107 | help="Number of strings to generate.") 108 | args.set_defaults(entry=generate) 109 | 110 | 111 | def main(): 112 | args = argparse.ArgumentParser() 113 | subcommands = args.add_subparsers(title="subcommands") 114 | train_args(subcommands.add_parser( 115 | "train", help="Train a model.")) 116 | eval_args(subcommands.add_parser( 117 | "evaluate", help="Evaluate a trained model.")) 118 | generate_args(subcommands.add_parser( 119 | "generate", help="Generate some text.")) 120 | flags, unparsed_args = args.parse_known_args(sys.argv[1:]) 121 | if not hasattr(flags, "entry"): 122 | args.print_help() 123 | return 1 124 | tf.logging.set_verbosity(tf.logging.INFO) 125 | tf.app.run(main=lambda _unused_argv: 
flags.entry(flags), 126 | argv=[sys.argv[0]] + unparsed_args) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /vaeseq/vae_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Base classes for modules that implement sequential VAEs.""" 16 | 17 | import abc 18 | import tensorflow as tf 19 | import sonnet as snt 20 | 21 | from . import context as context_mod 22 | from . import dist_module 23 | from . import util 24 | 25 | 26 | class VAECore(dist_module.DistCore): 27 | """Base class for sequential VAE implementations.""" 28 | 29 | def __init__(self, hparams, obs_encoder, obs_decoder, name=None): 30 | super(VAECore, self).__init__(name=name) 31 | self._hparams = hparams 32 | self._obs_encoder = obs_encoder 33 | self._obs_decoder = obs_decoder 34 | 35 | @abc.abstractmethod 36 | def _infer_latents(self, inputs, observed): 37 | """Returns a sequence of latent states and their divergences.""" 38 | 39 | def infer_latents(self, inputs, observed): 40 | inputs = context_mod.as_tensors(inputs, observed) 41 | return self._infer_latents(inputs, observed) 42 | 43 | @property 44 | def event_size(self): 45 | return self._obs_decoder.event_size 46 | 47 | @property 48 | def event_dtype(self): 49 | return self._obs_decoder.event_dtype 50 | 51 | def dist(self, params, name=None): 52 | return self._obs_decoder.dist(params, name=name) 53 | 54 | def evaluate(self, inputs, observed, 55 | latents=None, initial_state=None, samples=1): 56 | """Evaluates the log-probabilities of each given observation.""" 57 | inputs = context_mod.as_tensors(inputs, observed) 58 | cell, inputs = self.log_probs, (inputs, observed) 59 | if latents is not None: 60 | if initial_state is not None: 61 | raise ValueError("Cannot specify initial state and latents.") 62 | cell, inputs = util.use_recorded_state_rnn(cell), (inputs, latents) 63 | cell, inputs = util.add_support_for_scalar_rnn_inputs(cell, inputs) 64 | def _make_initial_state(): 65 | if initial_state is not None: 66 | return initial_state 67 | batch_size = util.batch_size_from_nested_tensors(observed) 68 | return self.initial_state(batch_size) 69 | return _average_runs(samples, cell, inputs, _make_initial_state) 70 | 71 | def generate(self, 72 | inputs, 73 | batch_size=None, # defaults to hparams.batch_size 74 | sequence_size=None, # defaults to hparams.sequence_size 75 | initial_state=None, 76 | inputs_initial_state=None): 77 | """Generates a sequence of observations.""" 78 | inputs = context_mod.as_context(inputs) 79 | if sequence_size is None: 80 | sequence_size = util.sequence_size(self._hparams) 81 | 82 | # Create initial states. 
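        # The batch size is resolved in stages below: use the explicit
        # `batch_size` argument if given, otherwise fall back to the hparams
        # batch size to build the inputs' initial state, and finally re-infer
        # it from that state's tensors so that a caller-supplied
        # `inputs_initial_state` determines the batch size.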
83 | infer_batch_size = batch_size 84 | if batch_size is None: 85 | infer_batch_size = util.batch_size(self._hparams) 86 | if inputs_initial_state is None: 87 | inputs_initial_state = inputs.initial_state(infer_batch_size) 88 | if batch_size is None: 89 | infer_batch_size = util.batch_size_from_nested_tensors( 90 | inputs_initial_state) 91 | if initial_state is None: 92 | initial_state = self.initial_state(infer_batch_size) 93 | 94 | cell = util.state_recording_rnn(self.samples) 95 | cell_output_observations = lambda out: out[0] 96 | return inputs.drive_rnn( 97 | cell, 98 | sequence_size=sequence_size, 99 | initial_state=inputs_initial_state, 100 | cell_initial_state=initial_state, 101 | cell_output_dtype=(self.event_dtype, self.state_dtype), 102 | cell_output_observations=cell_output_observations) 103 | 104 | 105 | def _average_runs(num_runs, cell, inputs, make_initial_state): 106 | """Run the RNN outputs over num_run runs.""" 107 | def _run(unused_arg): 108 | del unused_arg 109 | return tf.nn.dynamic_rnn( 110 | cell, inputs, 111 | initial_state=make_initial_state(), 112 | dtype=tf.float32)[0] 113 | if num_runs == 1: 114 | return _run(None) 115 | runs = tf.map_fn(_run, tf.zeros([num_runs, 0])) 116 | return tf.reduce_mean(runs, axis=0) 117 | -------------------------------------------------------------------------------- /vaeseq/vae/independent_sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Simple extension of VAE to a sequential setting. 16 | 17 | Notation: 18 | - z_1:T are hidden states, random variables. 19 | - d_1:T, e_1:T, and f_1:T are deterministic RNN outputs. 20 | - x_1:T are the observed states. 21 | - c_1:T are per-timestep inputs. 22 | 23 | Generative model Inference model 24 | ===================== ===================== 25 | x_1 x_t z_1 z_t 26 | ^ ^ ^ ^ 27 | | | | | 28 | d_1 ------------> d_t f_1 <----- f_t 29 | ^ ^ ^ ^ 30 | | | | | 31 | [c_1, z_1] [c_t, z_t] e_1 -----> e_t 32 | ^ ^ 33 | | | 34 | [c_1, x_1] [c_t, x_t] 35 | """ 36 | 37 | import sonnet as snt 38 | import tensorflow as tf 39 | 40 | from .. import latent as latent_mod 41 | from .. import util 42 | from .. 
import vae_module 43 | 44 | class IndependentSequence(vae_module.VAECore): 45 | """Implementation of a Sequential VAE with independent latent variables.""" 46 | 47 | def __init__(self, hparams, obs_encoder, obs_decoder, name=None): 48 | super(IndependentSequence, self).__init__( 49 | hparams, obs_encoder, obs_decoder, name) 50 | with self._enter_variable_scope(): 51 | self._d_core = util.make_rnn(hparams, name="d_core") 52 | self._e_core = util.make_rnn(hparams, name="e_core") 53 | self._f_core = util.make_rnn(hparams, name="f_core") 54 | self._q_z = latent_mod.LatentDecoder(hparams, name="latent_q") 55 | 56 | @property 57 | def state_size(self): 58 | return (self._d_core.state_size, self._q_z.event_size) 59 | 60 | def _build(self, input_, state): 61 | d_state, latent = state 62 | d_out, d_state = self._d_core( 63 | util.concat_features((input_, latent)), d_state) 64 | return self._obs_decoder(d_out), d_state 65 | 66 | def _next_state(self, d_state, event=None): 67 | del event # Not used. 68 | batch_size = util.batch_size_from_nested_tensors(d_state) 69 | latent_dist = _latent_prior(self._hparams, batch_size) 70 | return (d_state, latent_dist) 71 | 72 | def _initial_state(self, batch_size): 73 | return self._next_state( 74 | self._d_core.initial_state(batch_size), event=None) 75 | 76 | def _infer_latents(self, inputs, observed): 77 | hparams = self._hparams 78 | batch_size = util.batch_size_from_nested_tensors(observed) 79 | enc_observed = snt.BatchApply(self._obs_encoder, n_dims=2)(observed) 80 | e_outs, _ = tf.nn.dynamic_rnn( 81 | self._e_core, 82 | util.concat_features((inputs, enc_observed)), 83 | initial_state=self._e_core.initial_state(batch_size)) 84 | f_outs, _ = util.reverse_dynamic_rnn( 85 | self._f_core, 86 | e_outs, 87 | initial_state=self._f_core.initial_state(batch_size)) 88 | q_zs = self._q_z.dist( 89 | snt.BatchApply(self._q_z, n_dims=2)(f_outs), 90 | name="q_zs") 91 | latents = q_zs.sample() 92 | p_zs = tf.contrib.distributions.MultivariateNormalDiag( 93 | loc=tf.zeros_like(latents), 94 | scale_diag=tf.ones_like(latents), 95 | name="p_zs") 96 | divs = util.calc_kl(hparams, latents, q_zs, p_zs) 97 | (_unused_d_outs, d_states), _ = tf.nn.dynamic_rnn( 98 | util.state_recording_rnn(self._d_core), 99 | util.concat_features((inputs, latents)), 100 | initial_state=self._d_core.initial_state(batch_size)) 101 | return (d_states, latents), divs 102 | 103 | 104 | def _latent_prior(hparams, batch_size): 105 | dims = tf.stack([batch_size, hparams.latent_size]) 106 | loc = tf.zeros(dims) 107 | loc.set_shape([None, hparams.latent_size]) 108 | scale_diag = tf.ones(dims) 109 | scale_diag.set_shape([None, hparams.latent_size]) 110 | return tf.contrib.distributions.MultivariateNormalDiag( 111 | loc=loc, scale_diag=scale_diag, name="latent") 112 | -------------------------------------------------------------------------------- /vaeseq/examples/play/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions to build up training and generation graphs.""" 16 | 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | import sonnet as snt 22 | 23 | from vaeseq import context as context_mod 24 | from vaeseq import model as model_mod 25 | from vaeseq import train as train_mod 26 | from vaeseq import util 27 | from vaeseq import vae as vae_mod 28 | 29 | from . import agent as agent_mod 30 | from . import codec as codec_mod 31 | from . import environment 32 | 33 | 34 | class Model(model_mod.ModelBase): 35 | """Putting everything together.""" 36 | 37 | def __init__(self, hparams, session_params): 38 | self.env = environment.Environment(hparams) 39 | super(Model, self).__init__(hparams, session_params) 40 | 41 | def _make_encoder(self): 42 | return codec_mod.ObsEncoder(self.hparams) 43 | 44 | def _make_decoder(self): 45 | return codec_mod.ObsDecoder(self.hparams) 46 | 47 | def _make_agent(self): 48 | return agent_mod.TrainableAgent(self.hparams, self.encoder) 49 | 50 | def _make_feedback(self): 51 | # Inputs are agent action logits; pass them through as context. 52 | input_encoder = codec_mod.InputEncoder(self.hparams) 53 | return context_mod.EncodeObserved(self.encoder, 54 | input_encoder=input_encoder) 55 | 56 | def _make_dataset(self, dataset): 57 | del dataset # Not used. 58 | cell = self.env 59 | cell = util.input_recording_rnn( 60 | cell, 61 | input_size=self.agent.output_size) 62 | cell_output_dtype = (self.decoder.event_dtype, 63 | self.agent.output_dtype) 64 | cell_output_observations = lambda out: out[0] 65 | sequence_size = util.sequence_size(self.hparams) 66 | def _drive_env(agent, batch_size): 67 | cell_initial_state = self.env.initial_state(batch_size) 68 | observed, inputs = agent.drive_rnn( 69 | cell=cell, 70 | sequence_size=sequence_size, 71 | initial_state=agent.initial_state(batch_size), 72 | cell_initial_state=cell_initial_state, 73 | cell_output_dtype=cell_output_dtype, 74 | cell_output_observations=cell_output_observations) 75 | return inputs, observed 76 | 77 | batch_size = util.batch_size(self.hparams) 78 | train_batch_size = batch_size // 2 79 | random_batch_size = batch_size - train_batch_size 80 | inputs1, observed1 = _drive_env(self.agent, train_batch_size) 81 | inputs2, observed2 = _drive_env(agent_mod.RandomAgent(self.hparams), 82 | random_batch_size) 83 | tf.summary.histogram("actions", tf.argmax(inputs1, axis=-1)) 84 | inputs, observed = snt.nest.map( 85 | lambda t1, t2: tf.concat([t1, t2], axis=0), 86 | (inputs1, observed1), 87 | (inputs2, observed2)) 88 | return inputs, observed 89 | 90 | def _make_output_summary(self, tag, observed): 91 | return tf.summary.scalar( 92 | tag + "/score", 93 | tf.reduce_mean(tf.reduce_sum(observed["score"], axis=1), axis=0), 94 | collections=[]) 95 | 96 | def _make_elbo_trainer(self): 97 | global_step = tf.train.get_or_create_global_step() 98 | loss = train_mod.ELBOLoss(self.hparams, self.vae) 99 | def _variables(): 100 | agent_vars = set(self.agent.get_variables()) 101 | return [var for var in tf.trainable_variables() 102 | if var not in agent_vars] 103 | return train_mod.Trainer(self.hparams, global_step=global_step, 104 | loss=loss, variables=_variables, 105 | name="elbo_trainer") 106 | 107 | def _make_agent_trainer(self): 108 | global_step = None # Do not increment global step twice per turn. 
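        # This trainer updates only the policy variables returned by
        # self.agent.get_variables, using the observed "score" as the reward
        # signal; the ELBO trainer above excludes those same variables, so
        # the two optimizers do not step on each other's parameters.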
109 | loss = train_mod.RewardLoss( 110 | self.hparams, self.inputs, self.vae, 111 | reward=lambda observed: observed["score"]) 112 | return train_mod.Trainer(self.hparams, global_step=global_step, 113 | loss=loss, variables=self.agent.get_variables, 114 | name="agent_trainer") 115 | 116 | def _make_trainer(self): 117 | return train_mod.Group([self._make_elbo_trainer(), 118 | self._make_agent_trainer()]) 119 | -------------------------------------------------------------------------------- /vaeseq/examples/play/codec.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Coders/Decoders for game observations.""" 16 | 17 | import numpy as np 18 | import sonnet as snt 19 | import tensorflow as tf 20 | 21 | from vaeseq import batch_dist 22 | from vaeseq import codec 23 | from vaeseq import dist_module 24 | from vaeseq import util 25 | 26 | 27 | ObsEncoder = codec.MLPObsEncoder 28 | 29 | 30 | class InputEncoder(codec.FlattenEncoder): 31 | """Passes through the input action logits.""" 32 | 33 | def __init__(self, hparams, name=None): 34 | input_size = tf.TensorShape([hparams.game_action_space]) 35 | super(InputEncoder, self).__init__(input_size=input_size, name=name) 36 | 37 | 38 | class ObsDecoder(dist_module.DistModule): 39 | """Parameterizes a set of distributions for outputs, score and game-over. 40 | 41 | We're modeling three components for each observation: 42 | 43 | * The game outputs modeled by a diagonal multivariate normal. 44 | * The current score (a Normal distribution). 45 | * Whether the game is over next step. For simplicity, modeled as a normal 46 | with -1/+1 labels. 47 | 48 | All output distributions are reparameterizable, so there is a 49 | pathwise derivative w.r.t. their parameters. 
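    In practice this means a loss computed on sampled observations can be
    differentiated through the sampling step back to the decoder parameters.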
50 | """ 51 | 52 | def __init__(self, hparams, name=None): 53 | super(ObsDecoder, self).__init__(name=name) 54 | self._hparams = hparams 55 | 56 | @property 57 | def event_size(self): 58 | return dict(output=tf.TensorShape(self._hparams.game_output_size), 59 | score=tf.TensorShape([]), 60 | game_over=tf.TensorShape([])) 61 | 62 | @property 63 | def event_dtype(self): 64 | return dict(output=tf.float32, 65 | score=tf.float32, 66 | game_over=tf.float32) 67 | 68 | def dist(self, params, name=None): 69 | """The output distribution.""" 70 | name = name or self.module_name + "_dist" 71 | with tf.name_scope(name): 72 | params_output, params_score, params_game_over = params 73 | components = dict( 74 | output=self._dist_output(params_output), 75 | score=self._dist_score(params_score), 76 | game_over=self._dist_game_over(params_game_over)) 77 | return batch_dist.GroupDistribution(components, name=name) 78 | 79 | def _dist_output(self, params): 80 | """Distribution over the game outputs.""" 81 | loc, scale_diag = params 82 | return tf.contrib.distributions.MultivariateNormalDiag( 83 | loc, scale_diag, name="game_output") 84 | 85 | def _dist_score(self, params): 86 | """Distribution for the game score.""" 87 | loc, scale = params 88 | return tf.distributions.Normal(loc, scale, name="score") 89 | 90 | def _dist_game_over(self, params): 91 | """Distribution for the game over flag.""" 92 | loc, scale = params 93 | return tf.distributions.Normal(loc, scale, name="game_over") 94 | 95 | def _build(self, inputs): 96 | hparams = self._hparams 97 | hidden = snt.Sequential([ 98 | util.concat_features, 99 | util.make_mlp( 100 | hparams, 101 | hparams.obs_decoder_fc_hidden_layers, 102 | activate_final=True), 103 | ])(inputs) 104 | return (self._build_game_output(hidden), 105 | self._build_score(hidden), 106 | self._build_game_over(hidden)) 107 | 108 | def _build_game_output(self, hidden): 109 | """Parameters for the game output prediction.""" 110 | game_outputs = np.product(self._hparams.game_output_size) 111 | lin = snt.Linear(2 * game_outputs, name="game_obs") 112 | loc, scale_diag_unproj = tf.split(lin(hidden), 2, axis=-1) 113 | scale_diag = util.positive_projection(self._hparams)(scale_diag_unproj) 114 | return loc, scale_diag 115 | 116 | def _build_score(self, hidden): 117 | """Parameters for the game score prediction.""" 118 | lin = snt.Linear(2, name="score") 119 | loc, scale_unproj = tf.unstack(lin(hidden), axis=-1) 120 | scale = util.positive_projection(self._hparams)(scale_unproj) 121 | return loc, scale 122 | 123 | def _build_game_over(self, hidden): 124 | """Parameters for the game over prediction.""" 125 | lin = snt.Linear(2, name="game_over") 126 | loc, scale_unproj = tf.unstack(lin(hidden), axis=-1) 127 | scale = util.positive_projection(self._hparams)(scale_unproj) 128 | return loc, scale 129 | -------------------------------------------------------------------------------- /vaeseq/examples/play/environment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Environment that runs OpenAI Gym games.""" 16 | 17 | import threading 18 | import time 19 | 20 | import gym 21 | import numpy as np 22 | import sonnet as snt 23 | import tensorflow as tf 24 | 25 | from vaeseq import util 26 | 27 | 28 | class Environment(snt.RNNCore): 29 | """Plays a batch of games.""" 30 | 31 | def __init__(self, hparams, name=None): 32 | super(Environment, self).__init__(name=name) 33 | self._hparams = hparams 34 | self._games = {} 35 | self._games_lock = threading.Lock() 36 | self._next_id = 1 37 | self._id_lock = threading.Lock() 38 | self._step_time = None 39 | self._render_thread = None 40 | 41 | @property 42 | def output_size(self): 43 | return dict(output=tf.TensorShape(self._hparams.game_output_size), 44 | score=tf.TensorShape([]), 45 | game_over=tf.TensorShape([])) 46 | 47 | @property 48 | def output_dtype(self): 49 | return dict(output=tf.float32, 50 | score=tf.float32, 51 | game_over=tf.float32) 52 | 53 | @property 54 | def state_size(self): 55 | """The state is a game ID, or 0 if the game is over.""" 56 | return tf.TensorShape([]) 57 | 58 | @property 59 | def state_dtype(self): 60 | """The state is a game ID, or 0 if the game is over.""" 61 | return tf.int64 62 | 63 | def initial_state(self, batch_size): 64 | def _make_games(batch_size): 65 | """Produces a serialized batch of randomized games.""" 66 | with self._id_lock: 67 | first_id = self._next_id 68 | self._next_id += batch_size 69 | game_ids = range(first_id, self._next_id) 70 | updates = [] 71 | for game_id in game_ids: 72 | game = gym.make(self._hparams.game) 73 | game.reset() 74 | updates.append((game_id, game)) 75 | with self._games_lock: 76 | self._games.update(updates) 77 | return np.asarray(game_ids, dtype=np.int64) 78 | 79 | state, = tf.py_func(_make_games, [batch_size], [tf.int64]) 80 | state.set_shape([None]) 81 | return state 82 | 83 | def _build(self, input_, state): 84 | actions = tf.distributions.Categorical(logits=input_).sample() 85 | 86 | def _step_games(actions, state): 87 | """Take a step in a single game.""" 88 | score = np.zeros(len(state), dtype=np.float32) 89 | output = np.zeros([len(state)] + self._hparams.game_output_size, 90 | dtype=np.float32) 91 | games = [None] * len(state) 92 | with self._games_lock: 93 | for i, game_id in enumerate(state): 94 | if game_id: 95 | games[i] = self._games[game_id] 96 | finished_games = [] 97 | for i, game in enumerate(games): 98 | if game is None: 99 | continue 100 | output[i], score[i], game_over, _ = game.step(actions[i]) 101 | if game_over: 102 | finished_games.append(state[i]) 103 | state[i] = 0 104 | if finished_games: 105 | with self._games_lock: 106 | for game_id in finished_games: 107 | del self._games[game_id] 108 | if self._render_thread is not None: 109 | time.sleep(0.1) 110 | return output, score, state 111 | 112 | output, score, state = tf.py_func( 113 | _step_games, [actions, state], 114 | [tf.float32, tf.float32, tf.int64]) 115 | output = dict(output=output, score=score, 116 | game_over=2. * tf.to_float(tf.equal(state, 0)) - 1.) 117 | # Fix up the inferred shapes. 
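        # tf.py_func returns tensors with no static shape information, so the
        # known per-example shapes are reattached here (plus a leading batch
        # dimension) before the outputs flow into the rest of the graph.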
118 | util.set_tensor_shapes(output, self.output_size, add_batch_dims=1) 119 | util.set_tensor_shapes(state, self.state_size, add_batch_dims=1) 120 | return output, state 121 | 122 | def start_render_thread(self): 123 | if self._render_thread is not None: 124 | return self._render_thread 125 | self._render_thread = threading.Thread(target=self._render_games_loop) 126 | self._render_thread.start() 127 | 128 | def stop_render_thread(self): 129 | if self._render_thread is None: 130 | return 131 | tmp = self._render_thread 132 | self._render_thread = None 133 | tmp.join() 134 | 135 | def _render_games_loop(self): 136 | while (self._render_thread is not None and 137 | threading.current_thread().ident == self._render_thread.ident): 138 | with self._games_lock: 139 | games = list(self._games.values()) 140 | for game in games: 141 | game.render() 142 | time.sleep(0.05) 143 | -------------------------------------------------------------------------------- /vaeseq/dist_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Base classes for modules that return distributions.""" 16 | 17 | import abc 18 | import functools 19 | import tensorflow as tf 20 | import sonnet as snt 21 | 22 | from . import util 23 | 24 | 25 | class DistModule(snt.AbstractModule): 26 | """A module that returns parameters for a Distribution.""" 27 | 28 | @abc.abstractproperty 29 | def event_dtype(self): 30 | """Returns the output distribution event dtypes.""" 31 | 32 | @abc.abstractproperty 33 | def event_size(self): 34 | """Returns the output distribution event sizes.""" 35 | 36 | @abc.abstractmethod 37 | def dist(self, params, name=None): 38 | """Constructs a Distribution parameterized by the module output. 39 | This method is separate from _build (which returns Tensors) 40 | to allow batch application: module.dist(BatchApply(module)(...)). 41 | """ 42 | 43 | 44 | class DistCore(DistModule): 45 | """Like an RNNCore, but outputs distributions.""" 46 | 47 | @abc.abstractproperty 48 | def state_size(self): 49 | """Returns the non-batched sizes of Tensors returned from next_state.""" 50 | 51 | @property 52 | def state_dtype(self): 53 | """Returns the types of the (possibly sampled) state Tensors.""" 54 | return snt.nest.map(lambda _: tf.float32, self.state_size) 55 | 56 | @abc.abstractmethod 57 | def _initial_state(self, batch_size): 58 | """Creates the initial state Tensors and Distributions.""" 59 | 60 | def initial_state(self, batch_size, sampled=True): 61 | """Creates the initial state. 
If samped is True, 62 | then all distributions are sampled.""" 63 | with self._enter_variable_scope(): 64 | with tf.name_scope("initial_state"): 65 | state = self._initial_state(batch_size) 66 | if not sampled: 67 | return state 68 | return _sample_distributions(state) 69 | 70 | @abc.abstractmethod 71 | def _next_state(self, state_arg, event=None): 72 | """Produces the next state given a state_arg and event. 73 | NOTE: this function shouldn't allocate variables.""" 74 | 75 | def next_state(self, state_arg, event=None, sampled=True): 76 | """Produces the next state given a state_arg and event, optionally 77 | replacing all distribution components of the state with samples.""" 78 | with self._enter_variable_scope(): 79 | with tf.name_scope("next_state"): 80 | state = self._next_state(state_arg, event=event) 81 | if not sampled: 82 | return state 83 | return _sample_distributions(state) 84 | 85 | def next_sample(self, input_, state, with_log_prob=False): 86 | """Returns the next sample and state from the distribution.""" 87 | dist_arg, state_arg = self(input_, state) 88 | dist = self.dist(dist_arg) 89 | event = dist.sample() 90 | util.set_tensor_shapes(event, dist.event_shape, add_batch_dims=1) 91 | state = self.next_state(state_arg, event) 92 | if with_log_prob: 93 | return (event, dist.log_prob(event)), state 94 | return event, state 95 | 96 | def next_log_prob(self, input_and_observed, state): 97 | """Returns the log-prob(observed) for the next step.""" 98 | input_, observed = input_and_observed 99 | dist_arg, state_arg = self(input_, state) 100 | dist = self.dist(dist_arg) 101 | state = self.next_state(state_arg, observed) 102 | return dist.log_prob(observed), state 103 | 104 | @util.lazy_property 105 | def samples(self): 106 | """Returns an RNNCore that produces a sequence of samples.""" 107 | return util.WrapRNNCore( 108 | functools.partial(self.next_sample, with_log_prob=False), 109 | self.state_size, 110 | output_size=self.event_size, 111 | name=self.module_name + "/samples") 112 | 113 | @util.lazy_property 114 | def samples_with_log_probs(self): 115 | """Returns an RNNCore that produces (sample, log-prob(sample)).""" 116 | return util.WrapRNNCore( 117 | functools.partial(self.next_sample, with_log_prob=True), 118 | self.state_size, 119 | output_size=(self.event_size, tf.TensorShape([])), 120 | name=self.module_name + "/samples_with_log_probs") 121 | 122 | @util.lazy_property 123 | def log_probs(self): 124 | """Returns an RNNCore that evaluates the log-prob of the input.""" 125 | return util.WrapRNNCore( 126 | self.next_log_prob, 127 | self.state_size, 128 | output_size=tf.TensorShape([]), 129 | name=self.module_name + "/log_probs") 130 | 131 | 132 | def _sample_distributions(components): 133 | """Samples all distributions within components.""" 134 | def _sample_component(component): 135 | if isinstance(component, tf.distributions.Distribution): 136 | return component.sample() 137 | return component 138 | return snt.nest.map(_sample_component, components) 139 | -------------------------------------------------------------------------------- /vaeseq/batch_dist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Distributions over independent sets of events.""" 16 | 17 | import sonnet as snt 18 | import tensorflow as tf 19 | 20 | 21 | class BatchDistribution(tf.distributions.Distribution): 22 | """Wrap a distribution to shift batch dimensions into the event shape.""" 23 | 24 | def __init__(self, distribution, ndims=1, name=None): 25 | parameters = locals() 26 | self._dist = distribution 27 | self._ndims = ndims 28 | super(BatchDistribution, self).__init__( 29 | dtype=distribution.dtype, 30 | reparameterization_type=distribution.reparameterization_type, 31 | validate_args=distribution.validate_args, 32 | allow_nan_stats=distribution.allow_nan_stats, 33 | parameters=parameters, 34 | graph_parents=distribution._graph_parents, 35 | name=name or "batch_" + distribution.name 36 | ) 37 | 38 | def _sample_n(self, n, seed=None): 39 | return self._dist._sample_n(n, seed=seed) 40 | 41 | def _batch_shape_tensor(self): 42 | return self._dist.batch_shape_tensor()[:-self._ndims] 43 | 44 | def _batch_shape(self): 45 | return self._dist.batch_shape[:-self._ndims] 46 | 47 | def _event_shape_tensor(self): 48 | batch_dims = self._dist.batch_shape_tensor()[-self._ndims:] 49 | return tf.concat([batch_dims, self._dist.event_shape_tensor()], 0) 50 | 51 | def _event_shape(self): 52 | batch_dims = self._dist.batch_shape[-self._ndims:] 53 | return batch_dims.concatenate(self._dist.event_shape) 54 | 55 | def _log_prob(self, event): 56 | log_probs = self._dist._log_prob(event) 57 | return tf.reduce_sum(log_probs, axis=list(range(-self._ndims, 0))) 58 | 59 | def _prob(self, event): 60 | probs = self._dist._prob(event) 61 | return tf.reduce_prod(probs, axis=list(range(-self._ndims, 0))) 62 | 63 | 64 | class GroupDistribution(tf.distributions.Distribution): 65 | """Group together several independent distributions. 66 | 67 | Note, the batch shapes of the component distributions must match. 
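
    For example, the play example's ObsDecoder groups a MultivariateNormalDiag
    over game outputs with Normal distributions for the score and game-over
    flag; log_prob on the group then sums the per-component log-probabilities.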
68 | """ 69 | 70 | def __init__(self, distributions, name=None): 71 | parameters = locals() 72 | self._dists = distributions 73 | self._flat_dists = snt.nest.flatten(distributions) 74 | dtype = snt.nest.map(lambda dist: dist.dtype, distributions) 75 | r16n_type = tf.distributions.FULLY_REPARAMETERIZED 76 | for dist in self._flat_dists: 77 | r16n_type = dist.reparameterization_type 78 | if r16n_type is not tf.distributions.FULLY_REPARAMETERIZED: 79 | break 80 | validate_args = all([dist.validate_args for dist in self._flat_dists]) 81 | allow_nan_stats = all( 82 | [dist.allow_nan_stats for dist in self._flat_dists]) 83 | graph_parents = snt.nest.flatten( 84 | [dist._graph_parents for dist in self._flat_dists]) 85 | name = name or "_".join([dist.name for dist in self._flat_dists]) 86 | super(GroupDistribution, self).__init__( 87 | dtype=dtype, 88 | reparameterization_type=r16n_type, 89 | validate_args=validate_args, 90 | allow_nan_stats=allow_nan_stats, 91 | parameters=parameters, 92 | graph_parents=graph_parents, 93 | name=name) 94 | 95 | @property 96 | def batch_shape(self): 97 | return snt.nest.map(lambda dist: dist.batch_shape, self._dists) 98 | 99 | def batch_shape_tensor(self, name="batch_shape_tensor"): 100 | with self._name_scope(name): 101 | return snt.nest.map( 102 | lambda dist: dist.batch_shape_tensor(name), self._dists) 103 | 104 | @property 105 | def event_shape(self): 106 | return snt.nest.map(lambda dist: dist.event_shape, self._dists) 107 | 108 | def event_shape_tensor(self, name="event_shape_tensor"): 109 | with self._name_scope(name): 110 | return snt.nest.map( 111 | lambda dist: dist.event_shape_tensor(name), self._dists) 112 | 113 | def _is_scalar_helper(self, *args, **kwargs): 114 | if not self._flat_dists: 115 | return True 116 | if len(self._flat_dists) == 1: 117 | return self._flat_dists[0]._is_scalar_helper(*args, **kwargs) 118 | return False 119 | 120 | def sample(self, *args, **kwargs): 121 | return snt.nest.map( 122 | lambda dist: dist.sample(*args, **kwargs), 123 | self._dists) 124 | 125 | def log_prob(self, value, name="log_prob"): 126 | flat_values = snt.nest.flatten(value) 127 | with self._name_scope(name, values=flat_values): 128 | return tf.reduce_sum( 129 | [dist.log_prob(val) 130 | for dist, val in zip(self._flat_dists, flat_values)], 131 | axis=0) 132 | 133 | def prob(self, value, name="prob"): 134 | flat_values = snt.nest.flatten(value) 135 | with self._name_scope(name, values=flat_values): 136 | return tf.reduce_prod( 137 | [dist.prob(val) 138 | for dist, val in zip(self._flat_dists, flat_values)], 139 | axis=0) 140 | -------------------------------------------------------------------------------- /vaeseq/vae/srnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google, Inc., 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r""""SRNN as described in: 17 | 18 | Marco Fraccaro, Søren Kaae Sønderby, Ulrich Paquet, Ole Winther. 
19 | Sequential Neural Models with Stochastic Layers. 20 | https://arxiv.org/abs/1605.07571 21 | 22 | Notation: 23 | - z_0:T are hidden states, random variables 24 | - d_1:T and e_1:T are deterministic RNN outputs 25 | - x_1:T are the observed states 26 | - c_1:T are the per-timestep inputs 27 | 28 | Generative model Inference model 29 | ===================== ===================== 30 | z_0 -> z_1 -----> z_t z_0 -> z_1 ---------> z_t 31 | | ^ | ^ ^ ^ 32 | v | v | | | 33 | x_1 <-. | x_t <-. | | | 34 | \| \| e_1 <--------- e_t 35 | * * / ^ / ^ 36 | | | x_1 | x_t | 37 | d_1 -----> d_t d_1 ---------> d_t 38 | ^ ^ ^ ^ 39 | | | | | 40 | c_1 c_t c_1 c_t 41 | """ 42 | 43 | import sonnet as snt 44 | import tensorflow as tf 45 | 46 | from .. import latent as latent_mod 47 | from .. import util 48 | from .. import vae_module 49 | 50 | class SRNN(vae_module.VAECore): 51 | """Implementation of SRNN (see module description).""" 52 | 53 | def __init__(self, hparams, obs_encoder, obs_decoder, name=None): 54 | super(SRNN, self).__init__(hparams, obs_encoder, obs_decoder, name) 55 | with self._enter_variable_scope(): 56 | self._d_core = util.make_rnn(hparams, name="d_core") 57 | self._e_core = util.make_rnn(hparams, name="e_core") 58 | self._latent_p = latent_mod.LatentDecoder(hparams, name="latent_p") 59 | self._latent_q = latent_mod.LatentDecoder(hparams, name="latent_q") 60 | 61 | @property 62 | def state_size(self): 63 | return (self._d_core.state_size, self._latent_p.event_size) 64 | 65 | def _build(self, input_, state): 66 | d_state, latent = state 67 | d_out, d_state = self._d_core(util.concat_features(input_), d_state) 68 | latent_params = self._latent_p(d_out, latent) 69 | return (self._obs_decoder(util.concat_features((d_out, latent))), 70 | (d_state, latent_params)) 71 | 72 | def _next_state(self, state_arg, event=None): 73 | del event # Not used. 
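        # The prior over the next latent is fully determined by the
        # deterministic RNN state and the previous latent (latent_params was
        # computed in _build), so the emitted observation is not needed here.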
74 | d_state, latent_params = state_arg 75 | return d_state, self._latent_p.dist(latent_params, name="latent") 76 | 77 | def _initial_state(self, batch_size): 78 | d_state = self._d_core.initial_state(batch_size) 79 | latent_input_sizes = (self._d_core.output_size, 80 | self._latent_p.event_size) 81 | latent_inputs = snt.nest.map( 82 | lambda size: tf.zeros( 83 | [batch_size] + tf.TensorShape(size).as_list(), 84 | name="latent_input"), 85 | latent_input_sizes) 86 | latent_params = self._latent_p(latent_inputs) 87 | return self._next_state((d_state, latent_params), event=None) 88 | 89 | def _infer_latents(self, inputs, observed): 90 | hparams = self._hparams 91 | batch_size = util.batch_size_from_nested_tensors(observed) 92 | d_initial, z_initial = self.initial_state(batch_size) 93 | (d_outs, d_states), _ = tf.nn.dynamic_rnn( 94 | util.state_recording_rnn(self._d_core), 95 | util.concat_features(inputs), 96 | initial_state=d_initial) 97 | enc_observed = snt.BatchApply(self._obs_encoder, n_dims=2)(observed) 98 | e_outs, _ = util.reverse_dynamic_rnn( 99 | self._e_core, 100 | util.concat_features((enc_observed, inputs)), 101 | initial_state=self._e_core.initial_state(batch_size)) 102 | 103 | def _inf_step(d_e_outputs, prev_latent): 104 | """Iterate over d_1:T and e_1:T to produce z_1:T.""" 105 | d_out, e_out = d_e_outputs 106 | p_z_params = self._latent_p(d_out, prev_latent) 107 | p_z = self._latent_p.dist(p_z_params) 108 | q_loc, q_scale = self._latent_q(e_out, prev_latent) 109 | if hparams.srnn_use_res_q: 110 | q_loc += p_z.loc 111 | q_z = self._latent_q.dist((q_loc, q_scale), name="q_z_dist") 112 | latent = q_z.sample() 113 | divergence = util.calc_kl(hparams, latent, q_z, p_z) 114 | return (latent, divergence), latent 115 | 116 | inf_core = util.WrapRNNCore( 117 | _inf_step, 118 | state_size=tf.TensorShape(hparams.latent_size), # prev_latent 119 | output_size=(tf.TensorShape(hparams.latent_size), # latent 120 | tf.TensorShape([]),), # divergence 121 | name="inf_z_core") 122 | (latents, kls), _ = util.heterogeneous_dynamic_rnn( 123 | inf_core, 124 | (d_outs, e_outs), 125 | initial_state=z_initial, 126 | output_dtypes=(self._latent_q.event_dtype, tf.float32)) 127 | return (d_states, latents), kls 128 | -------------------------------------------------------------------------------- /vaeseq/examples/midi/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """The model for MIDI music. 16 | 17 | At each time step, we predict a pair of: 18 | * 128 independent Beta variables, assigning scores to each note. 19 | * K in [0,10], a Categorical variable counting the number of notes played. 20 | 21 | When generating music, we emit the top K notes per timestep. 
22 | """ 23 | 24 | from __future__ import print_function 25 | 26 | import numpy as np 27 | import tensorflow as tf 28 | 29 | from vaeseq import codec as codec_mod 30 | from vaeseq import context as context_mod 31 | from vaeseq import model as model_mod 32 | from vaeseq import util 33 | 34 | from . import dataset as dataset_mod 35 | 36 | 37 | class Model(model_mod.ModelBase): 38 | """Putting everything together.""" 39 | 40 | def _make_encoder(self): 41 | """Constructs an encoder for a single observation.""" 42 | return codec_mod.MLPObsEncoder(self.hparams, name="obs_encoder") 43 | 44 | def _make_decoder(self): 45 | """Constructs a decoder for a single observation.""" 46 | # We need 2 * 128 (note beta) + 11 (count categorical) parameters. 47 | params = util.make_mlp( 48 | self.hparams, 49 | self.hparams.obs_decoder_fc_hidden_layers + [128 * 2 + 11]) 50 | def _split_params(inp): 51 | note_params, count_param = tf.split(inp, [128 * 2, 11], axis=-1) 52 | return (note_params, count_param) # Note: returning a tuple. 53 | single_note_decoder = codec_mod.BetaDecoder( 54 | positive_projection=util.positive_projection(self.hparams)) 55 | notes_decoder = codec_mod.BatchDecoder( 56 | single_note_decoder, event_size=[128], name="notes_decoder") 57 | count_decoder = codec_mod.CategoricalDecoder(name="count_decoder") 58 | full_decoder = codec_mod.GroupDecoder((notes_decoder, count_decoder)) 59 | return codec_mod.DecoderSequence( 60 | [params, _split_params], full_decoder, name="decoder") 61 | 62 | def _make_feedback(self): 63 | """Constructs the feedback Context.""" 64 | history_combiner = codec_mod.EncoderSequence( 65 | [codec_mod.FlattenEncoder(), 66 | util.make_mlp(self.hparams, 67 | self.hparams.history_encoder_fc_layers)], 68 | name="history_combiner" 69 | ) 70 | return context_mod.Accumulate( 71 | obs_encoder=self.encoder, 72 | history_size=self.hparams.history_size, 73 | history_combiner=history_combiner) 74 | 75 | def _make_dataset(self, files): 76 | dataset = dataset_mod.piano_roll_sequences( 77 | files, 78 | util.batch_size(self.hparams), 79 | util.sequence_size(self.hparams), 80 | rate=self.hparams.rate) 81 | iterator = dataset.make_initializable_iterator() 82 | tf.add_to_collection(tf.GraphKeys.LOCAL_INIT_OP, iterator.initializer) 83 | piano_roll = iterator.get_next() 84 | shape = tf.shape(piano_roll) 85 | notes = tf.where(piano_roll, tf.fill(shape, 0.95), tf.fill(shape, 0.05)) 86 | counts = tf.minimum(10, tf.reduce_sum(tf.to_int32(piano_roll), axis=-1)) 87 | observed = (notes, counts) 88 | inputs = None 89 | return inputs, observed 90 | 91 | # Samples per second when generating audio output. 
92 | SYNTHESIZED_RATE = 16000 93 | def _render(self, observed): 94 | """Returns a batch of wave forms corresponding to the observations.""" 95 | notes, counts = observed 96 | 97 | def _synthesize(notes, counts): 98 | """Use pretty_midi to synthesize a wave form.""" 99 | piano_roll = np.zeros((len(counts), 128), dtype=np.bool) 100 | top_notes = np.argsort(notes) 101 | for roll_t, top_notes_t, k in zip(piano_roll, top_notes, counts): 102 | if k > 0: 103 | for i in top_notes_t[-k:]: 104 | roll_t[i] = True 105 | rate = self.hparams.rate 106 | midi = dataset_mod.piano_roll_to_midi(piano_roll, rate) 107 | wave = midi.synthesize(self.SYNTHESIZED_RATE) 108 | wave_len = len(wave) 109 | expect_len = (len(piano_roll) * self.SYNTHESIZED_RATE) // rate 110 | if wave_len < expect_len: 111 | wave = np.pad(wave, [0, expect_len - wave_len], mode='constant') 112 | else: 113 | wave = wave[:expect_len] 114 | return np.float32(wave) 115 | 116 | # Apply synthesize_roll on all elements of the batch. 117 | def _map_batch_elem(notes_counts): 118 | notes, counts = notes_counts 119 | return tf.py_func(_synthesize, [notes, counts], [tf.float32])[0] 120 | return tf.map_fn(_map_batch_elem, (notes, counts), dtype=tf.float32) 121 | 122 | def _make_output_summary(self, tag, observed): 123 | notes, counts = observed 124 | return tf.summary.merge( 125 | [tf.summary.audio( 126 | tag + "/audio", 127 | self._render(observed), 128 | self.SYNTHESIZED_RATE, 129 | collections=[]), 130 | tf.summary.scalar( 131 | tag + "/note_avg", 132 | tf.reduce_mean(notes)), 133 | tf.summary.scalar( 134 | tag + "/note_count", 135 | tf.reduce_mean(tf.to_float(counts)))]) 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VAE-Seq 2 | 3 | VAE-Seq is a library for modeling sequences of observations. 4 | 5 | ## Background 6 | 7 | One tool that's commonly used to model sequential data is the 8 | Recurrent Neural Network (RNN), or gated variations of it such as the 9 | Long Short-Term Memory cell or the Gated Recurrent Unit cell. 10 | 11 | RNNs in general are essentially trainable transition functions: 12 | `(input, state) -> (output, state')`, and by themselves don't specify 13 | a complete model. We additionally need to specify a family of 14 | distributions that describes our observations; common choices here are 15 | `Categorical` distributions for discrete observations such as text or 16 | `Normal` distributions for real-valued observations. 17 | 18 | The `output` of the RNN specifies the parameters of the observation 19 | distribution (e.g. the logits of a `Categorical` or the mean and 20 | variance of a `Normal`). But the size of the RNN `output` and the 21 | number of parameters that we need don't necessarily match up. To solve 22 | this, we project `output` into the appropriate shape via a Neural 23 | Network we'll call a decoder. 24 | 25 | And what about the `input` of the RNN? It can be empty, but we might 26 | want to include side information from the environment (e.g. actions 27 | when modeling a game or a metronome when modeling 28 | music). Additionally, the observation from the previous step(s) is 29 | almost always an important feature to include. Here, we'll use another 30 | Neural Network we'll call an encoder to summarize the observation 31 | into a more digestible form. 
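As a rough sketch (not this library's actual API; `rnn`, `encoder`, `decoder`, and `sample` below are placeholder callables), ancestral sampling from such a model proceeds one step at a time:

```python
def generate(rnn, encoder, decoder, sample, inputs, initial_obs, initial_state):
    """Schematic ancestral sampling: encode the previous observation,
    step the RNN, decode its output into distribution parameters, sample."""
    obs, state, sequence = initial_obs, initial_state, []
    for input_ in inputs:                     # per-step side information (may be empty)
        features = (encoder(obs), input_)     # encoded previous observation + input
        output, state = rnn(features, state)  # the trainable transition function
        params = decoder(output)              # project output to distribution parameters
        obs = sample(params)                  # draw the next observation
        sequence.append(obs)
    return sequence
```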
32 | 33 | Together, these components specify a factored (by time step) 34 | probability distribution that we can train in the usual way: by 35 | maximizing the probability of the network weights given the 36 | observations in your training data and your priors over those 37 | weights. Once trained, you can use ancestral sampling to generate new 38 | sequences. 39 | 40 | ## Motivation 41 | 42 | This library allows you to express the type of model described 43 | above. It handles the plumbing for you: you define the encoder, the 44 | decoder, and the observation distribution. The resulting model can 45 | be trained on a `Dataset` of observation sequences, queried for the 46 | probability of a given sequence, or queried to generate new sequences. 47 | 48 | But the model above also has a limitation: the family of observation 49 | distributions we pick is the only source of non-determinism in the 50 | model. If it can't express the true distribution of observations, the 51 | model won't be able to learn or generate the true range of observation 52 | sequences. For example, consider a sequence of black/white images. If 53 | we pick the observation distribution to be a set of independent 54 | `Bernoulli` distributions over pixel values, the first generated image 55 | would always look like a noisy average over images in the training 56 | set. Subsequent images might get more creative since they are 57 | conditioned on a noisy input, but that depends on how images vary 58 | between steps in the training data. 59 | 60 | The issue in the example above is that the observation distribution we 61 | picked wasn't expressive enough: pixels in an image aren't 62 | independent. One way to fix this is to design very expressive 63 | observation distributions that can model images. Another way is to 64 | condition the simple distribution on a latent variable to produce a 65 | hierarchical output distribution. This latter type of model is known 66 | as a Variational Auto encoder (VAE). 67 | 68 | There are different ways to incorporate latent variables in a 69 | sequential model (see the supported architectures below) but the 70 | general approach we take here is to view the RNN `state` as a 71 | collection of stochastic and deterministic variables. 72 | 73 | ## Usage 74 | 75 | To define a model, subclass `ModelBase` to define an encoder, a 76 | decoder, and the output distribution. The decoder and output 77 | distribution are packaged together into a `DistModule` (see: 78 | [vaeseq/codec.py](vaeseq/codec.py)). 79 | 80 | The following model architectures are currently available (see: 81 | [vaeseq/vae](vaeseq/vae)): 82 | 83 | * An RNN with no latent variables other than a deterministic state. 84 | * A VAE where the stochastic latent variables are independent across 85 | time steps. 86 | * An implementation of SRNN (https://arxiv.org/abs/1605.07571) 87 | 88 | There are lots of hyper-parameters packaged into an `HParams` object 89 | (see: [vaeseq/hparams.py](vaeseq/hparams.py)). You can select among 90 | the architectures above by setting the `vae_type` parameter. 91 | 92 | ## Examples 93 | 94 | When you build and install this library via `python setup.py install`, 95 | the following example programs are installed as well. See: 96 | [vaeseq/examples](vaeseq/examples). 97 | 98 | ### Text 99 | 100 | A character-sequence model that can be used to generate nonsense text 101 | or to evaluate the probability that a given piece of text was written 102 | by a given author. 
103 | 104 | To train on Andrej Karpathy's "Tiny Shakespeare" dataset: 105 | ```shell 106 | $ wget https://github.com/karpathy/char-rnn/raw/master/data/tinyshakespeare/input.txt 107 | $ vaeseq-text train --log-dir /tmp/text --train-corpus input.txt \ 108 | --num-steps 1000000 109 | ``` 110 | 111 | After training has completed, you can generate text: 112 | ```shell 113 | $ vaeseq-text generate --log-dir /tmp/text --vocab-corpus input.txt \ 114 | --length 1000 115 | --num-samples 20 116 | ``` 117 | 118 | Or you can tell how likely a piece of text is to be Shakespearean: 119 | ```shell 120 | $ vaeseq-text evaluate --log-dir /tmp/text --vocab-corpus input.txt \ 121 | --eval-corpus foo.txt 122 | ``` 123 | 124 | ### MIDI 125 | 126 | Similar to the text example above, but now modeling MIDI music 127 | (specifically, piano rolls). Installed under `vaeseq-midi`. Don't 128 | expect it to sound great. 129 | 130 | ### Play 131 | 132 | An experiment modeling a game environment and using that to train an 133 | agent via policy gradient. This example uses the OpenAI Gym 134 | module. Installed under `vaeseq-play`. 135 | 136 | ## Disclaimer 137 | 138 | This is not an official Google product. 139 | -------------------------------------------------------------------------------- /vaeseq/vae/vae_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Basic tests for all of the VAE implementations.""" 16 | 17 | import tensorflow as tf 18 | import sonnet as snt 19 | 20 | from vaeseq import codec 21 | from vaeseq import context as context_mod 22 | from vaeseq import hparams as hparams_mod 23 | from vaeseq import util 24 | from vaeseq import vae as vae_mod 25 | 26 | 27 | def _inputs_and_vae(hparams): 28 | """Constructs a VAE.""" 29 | obs_encoder = codec.MLPObsEncoder(hparams) 30 | obs_decoder = codec.MLPObsDecoder( 31 | hparams, 32 | codec.BernoulliDecoder(squeeze_input=True), 33 | param_size=1) 34 | inputs = context_mod.EncodeObserved(obs_encoder) 35 | vae = vae_mod.make(hparams, obs_encoder, obs_decoder) 36 | return inputs, vae 37 | 38 | 39 | def _observed(hparams): 40 | """Test observations.""" 41 | return tf.zeros([util.batch_size(hparams), util.sequence_size(hparams)], 42 | dtype=tf.int32, name="test_obs") 43 | 44 | 45 | def _inf_tensors(hparams, inputs, vae): 46 | """Simple inference graph.""" 47 | with tf.name_scope("inf"): 48 | observed = _observed(hparams) 49 | latents, divs = vae.infer_latents(inputs, observed) 50 | log_probs = vae.evaluate(inputs, observed, latents=latents) 51 | elbo = tf.reduce_sum(log_probs - divs) 52 | return [observed, latents, divs, log_probs, elbo] 53 | 54 | 55 | def _gen_tensors(hparams, inputs, vae): 56 | """Samples observations and latent variables from the VAE.""" 57 | del hparams # Unused, just passed for consistency. 
58 | with tf.name_scope("gen"): 59 | generated, latents = vae.generate(inputs) 60 | return [generated, latents] 61 | 62 | 63 | def _eval_tensors(hparams, inputs, vae): 64 | """Calculates the log-probabilities of the observations.""" 65 | with tf.name_scope("eval"): 66 | observed = _observed(hparams) 67 | log_probs = vae.evaluate(inputs, observed, samples=100) 68 | return [log_probs] 69 | 70 | 71 | def _test_assertions(inf_tensors, gen_tensors, eval_tensors): 72 | """Returns in-graph assertions for testing.""" 73 | observed, latents, divs, log_probs, elbo = inf_tensors 74 | generated, sampled_latents = gen_tensors 75 | eval_log_probs, = eval_tensors 76 | 77 | # For RNN, we return None from infer_latents as an optimization. 78 | if latents is None: 79 | latents = sampled_latents 80 | 81 | def _same_batch_and_sequence_size_asserts(t1, name1, t2, name2): 82 | return [ 83 | tf.assert_equal( 84 | util.batch_size_from_nested_tensors(t1), 85 | util.batch_size_from_nested_tensors(t2), 86 | message="Batch: " + name1 + " vs " + name2), 87 | tf.assert_equal( 88 | util.sequence_size_from_nested_tensors(t1), 89 | util.sequence_size_from_nested_tensors(t2), 90 | message="Steps: " + name1 + " vs " + name2), 91 | ] 92 | 93 | def _same_shapes(nested1, nested2): 94 | return snt.nest.flatten(snt.nest.map( 95 | lambda t1, t2: tf.assert_equal( 96 | tf.shape(t1), tf.shape(t2), 97 | message="Shapes: " + t1.name + " vs " + t2.name), 98 | nested1, nested2)) 99 | 100 | def _all_same_batch_and_sequence_sizes(nested): 101 | batch_size = util.batch_size_from_nested_tensors(nested) 102 | sequence_size = util.sequence_size_from_nested_tensors(nested) 103 | return [ 104 | tf.assert_equal(tf.shape(tensor)[0], batch_size, 105 | message="Batch: " + tensor.name) 106 | for tensor in snt.nest.flatten(nested) 107 | ] + [ 108 | tf.assert_equal(tf.shape(tensor)[1], sequence_size, 109 | message="Steps: " + tensor.name) 110 | for tensor in snt.nest.flatten(nested) 111 | ] 112 | 113 | assertions = [ 114 | tf.assert_non_negative(divs), 115 | tf.assert_non_positive(log_probs), 116 | ] + _same_shapes( 117 | (log_probs, log_probs, observed, latents), 118 | (divs, eval_log_probs, generated, sampled_latents) 119 | ) + _all_same_batch_and_sequence_sizes( 120 | (observed, latents, divs) 121 | ) + _all_same_batch_and_sequence_sizes( 122 | (generated, sampled_latents) 123 | ) 124 | vars_ = tf.trainable_variables() 125 | grads = tf.gradients(-elbo, vars_) 126 | for (var, grad) in zip(vars_, grads): 127 | assertions.append(tf.check_numerics(grad, "Gradient for " + var.name)) 128 | return assertions 129 | 130 | 131 | def _all_tensors(hparams, inputs, vae): 132 | """All tensors to evaluate in tests.""" 133 | gen_tensors = _gen_tensors(hparams, inputs, vae) 134 | inf_tensors = _inf_tensors(hparams, inputs, vae) 135 | eval_tensors = _eval_tensors(hparams, inputs, vae) 136 | assertions = _test_assertions(inf_tensors, gen_tensors, eval_tensors) 137 | all_tensors = inf_tensors + gen_tensors + eval_tensors + assertions 138 | return [x for x in all_tensors if x is not None] 139 | 140 | 141 | class VAETest(tf.test.TestCase): 142 | 143 | def _test_vae(self, vae_type): 144 | """Make sure that all tensors and assertions evaluate without error.""" 145 | hparams = hparams_mod.make_hparams(vae_type=vae_type) 146 | inputs, vae = _inputs_and_vae(hparams) 147 | tensors = _all_tensors(hparams, inputs, vae) 148 | with self.test_session() as sess: 149 | sess.run(tf.global_variables_initializer()) 150 | sess.run(tensors) 151 | 152 | def test_iseq(self): 153 | 
self._test_vae("ISEQ") 154 | 155 | def test_rnn(self): 156 | self._test_vae("RNN") 157 | 158 | def test_srnn(self): 159 | self._test_vae("SRNN") 160 | 161 | 162 | if __name__ == "__main__": 163 | tf.test.main() 164 | -------------------------------------------------------------------------------- /vaeseq/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Training subgraph for a VAE.""" 16 | 17 | import functools 18 | import numpy as np 19 | import sonnet as snt 20 | import tensorflow as tf 21 | 22 | from . import util 23 | 24 | 25 | class Trainer(snt.AbstractModule): 26 | """Wraps an optimizer and an objective.""" 27 | 28 | def __init__(self, hparams, global_step, loss, 29 | variables=tf.trainable_variables, 30 | name=None): 31 | super(Trainer, self).__init__(name=name) 32 | self._hparams = hparams 33 | self._global_step = global_step 34 | self._loss = loss 35 | if callable(variables): 36 | self._variables = variables 37 | else: 38 | self._variables = lambda: variables 39 | 40 | @util.lazy_property 41 | def optimizer(self): 42 | return self._make_optimizer() 43 | 44 | def _make_optimizer(self): 45 | return tf.train.AdamOptimizer(self._hparams.learning_rate) 46 | 47 | def _transform_gradients(self, gradients_to_variables): 48 | """Transform gradients before applying the optimizer.""" 49 | if self._hparams.clip_gradient_norm > 0: 50 | gradients_to_variables = tf.contrib.training.clip_gradient_norms( 51 | gradients_to_variables, 52 | self._hparams.clip_gradient_norm) 53 | return gradients_to_variables 54 | 55 | def _build(self, inputs, observed): 56 | loss, debug_tensors = self._loss(inputs, observed) 57 | variables = self._variables() 58 | if not variables: 59 | raise ValueError("No trainable variables found.") 60 | # Summarize the magnitudes of the model variables to see whether we need 61 | # regularization. 62 | for var in variables: 63 | tf.summary.histogram(name=var.op.name + "/values", values=var) 64 | # Unfortunately regularization losses are all stored together 65 | # so we can't segment them by those that come from variables. 
66 | reg_loss = tf.losses.get_regularization_loss() 67 | tf.summary.scalar(name="reg_loss", tensor=reg_loss) 68 | train_op = tf.contrib.training.create_train_op( 69 | loss + reg_loss, 70 | self.optimizer, 71 | global_step=self._global_step, 72 | variables_to_train=variables, 73 | transform_grads_fn=self._transform_gradients, 74 | summarize_gradients=True, 75 | check_numerics=self._hparams.check_numerics) 76 | return train_op, debug_tensors 77 | 78 | 79 | class Group(snt.AbstractModule): 80 | """Trainer that joins multiple trainers together.""" 81 | 82 | def __init__(self, trainers, name=None): 83 | super(Group, self).__init__(name=name) 84 | self._trainers = trainers 85 | 86 | def _build(self, inputs, observed): 87 | train_ops = [] 88 | debug_tensors = {} 89 | for trainer in self._trainers: 90 | train_op, debug = trainer(inputs, observed) 91 | train_ops.append(train_op) 92 | debug_tensors.update(debug) 93 | return tf.group(*train_ops), debug_tensors 94 | 95 | 96 | class ELBOLoss(snt.AbstractModule): 97 | """Calculates an objective for maximizing the evidence lower bound.""" 98 | 99 | def __init__(self, hparams, vae, name=None): 100 | super(ELBOLoss, self).__init__(name=name) 101 | self._hparams = hparams 102 | self._vae = vae 103 | 104 | def _build(self, inputs, observed): 105 | debug_tensors = {} 106 | scalar_summary = functools.partial(_scalar_summary, debug_tensors) 107 | 108 | latents, divs = self._vae.infer_latents(inputs, observed) 109 | log_probs = self._vae.evaluate(inputs, observed, latents=latents) 110 | log_prob = tf.reduce_mean(log_probs) 111 | divergence = tf.reduce_mean(divs) 112 | scalar_summary("log_prob", log_prob) 113 | scalar_summary("divergence", divergence) 114 | scalar_summary("ELBO", log_prob - divergence) 115 | 116 | # We soften the divergence penalty at the start of training. 117 | temp_start = -np.log(self._hparams.divergence_strength_start) 118 | temp_decay = ((-np.log(0.5) / temp_start) ** 119 | (1. / self._hparams.divergence_strength_half)) 120 | global_step = tf.to_double(tf.train.get_or_create_global_step()) 121 | divergence_strength = tf.to_float( 122 | tf.exp(-temp_start * tf.pow(temp_decay, global_step))) 123 | scalar_summary("divergence_strength", divergence_strength) 124 | relaxed_elbo = log_prob - divergence * divergence_strength 125 | loss = -relaxed_elbo 126 | scalar_summary(self.module_name, loss) 127 | return loss, debug_tensors 128 | 129 | 130 | class RewardLoss(snt.AbstractModule): 131 | """Sums component losses.""" 132 | 133 | def __init__(self, hparams, inputs, vae, reward, name=None): 134 | super(RewardLoss, self).__init__(name=name) 135 | self._inputs = inputs 136 | self._vae = vae 137 | self._reward = reward 138 | 139 | def _build(self, inputs, observed): 140 | del inputs, observed # We only use the generated reward. 
141 | debug_tensors = {} 142 | scalar_summary = functools.partial(_scalar_summary, debug_tensors) 143 | 144 | generated = self._vae.generate(self._inputs)[0] 145 | mean_reward = tf.reduce_mean(self._reward(generated)) 146 | scalar_summary("mean_reward", mean_reward) 147 | #loss = -self._neg_log_reward_dist.log_prob(-tf.log(mean_reward + 1e-5)) 148 | loss = -mean_reward 149 | scalar_summary(self.module_name, loss) 150 | return loss, debug_tensors 151 | 152 | 153 | def _scalar_summary(debug_tensors, name, tensor): 154 | """Add a summary and a debug output tensor.""" 155 | tensor = tf.convert_to_tensor(tensor, name=name) 156 | debug_tensors[name] = tensor 157 | tf.summary.scalar(name, tensor) 158 | -------------------------------------------------------------------------------- /vaeseq/batch_dist_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for BatchDistribution.""" 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from vaeseq import batch_dist 21 | 22 | 23 | class BatchDistributionTest(tf.test.TestCase): 24 | 25 | def test_sample(self): 26 | dist = tf.distributions.Bernoulli(logits=tf.zeros((4, 5, 6))) 27 | batch_dist1 = batch_dist.BatchDistribution(dist) 28 | with self.test_session() as sess: 29 | sess.run(tf.assert_equal(dist.sample(seed=123), 30 | batch_dist1.sample(seed=123))) 31 | 32 | def test_log_prob(self): 33 | dist = tf.distributions.Bernoulli(logits=tf.zeros((4, 5, 6))) 34 | batch_dist1 = batch_dist.BatchDistribution(dist) 35 | batch_dist2 = batch_dist.BatchDistribution(dist, ndims=2) 36 | event = tf.zeros((4, 5, 6)) 37 | with self.test_session() as sess: 38 | self.assertAllClose(sess.run(batch_dist1.log_prob(event)), 39 | np.full((4, 5), 6 * np.log(0.5))) 40 | self.assertAllClose(sess.run(batch_dist2.log_prob(event)), 41 | np.full((4,), 30 * np.log(0.5))) 42 | 43 | def test_prob(self): 44 | dist = tf.distributions.Bernoulli(probs=0.5 * tf.ones((4, 5, 6))) 45 | batch_dist1 = batch_dist.BatchDistribution(dist) 46 | batch_dist2 = batch_dist.BatchDistribution(dist, ndims=2) 47 | event = tf.zeros((4, 5, 6)) 48 | with self.test_session() as sess: 49 | self.assertAllClose(sess.run(batch_dist1.prob(event)), 50 | np.full((4, 5), 0.5 ** 6)) 51 | self.assertAllClose(sess.run(batch_dist2.prob(event)), 52 | np.full((4,), 0.5 ** 30)) 53 | 54 | def test_is_scalar(self): 55 | dist = tf.distributions.Bernoulli(probs=0.5 * tf.ones((4, 5))) 56 | batch_dist1 = batch_dist.BatchDistribution(dist) 57 | batch_dist2 = batch_dist.BatchDistribution(dist, ndims=2) 58 | with self.test_session() as sess: 59 | self.assertAllEqual( 60 | sess.run([batch_dist1.is_scalar_event(), 61 | batch_dist2.is_scalar_event()]), 62 | [False, False]) 63 | self.assertAllEqual( 64 | sess.run([batch_dist1.is_scalar_batch(), 65 | batch_dist2.is_scalar_batch()]), 66 | [False, True]) 67 | 68 | 69 | class 
GroupDistributionTest(tf.test.TestCase): 70 | 71 | def test_sample(self): 72 | components = { 73 | 'dist_a': tf.distributions.Bernoulli(logits=tf.zeros(3)), 74 | 'dist_b': tf.distributions.Normal(tf.zeros(3), tf.ones(3)), 75 | } 76 | dist = batch_dist.GroupDistribution(components) 77 | with self.test_session() as sess: 78 | val = sess.run(dist.sample()) 79 | self.assertAllEqual(val['dist_a'].shape, [3]) 80 | self.assertAllEqual(val['dist_b'].shape, [3]) 81 | 82 | def test_log_prob(self): 83 | components = { 84 | 'dist_a': tf.distributions.Bernoulli(logits=tf.zeros(3)), 85 | 'dist_b': tf.distributions.Normal(tf.zeros(3), tf.ones(3)), 86 | } 87 | dist = batch_dist.GroupDistribution(components) 88 | with self.test_session() as sess: 89 | a_log_prob, b_log_prob, group_log_prob = sess.run([ 90 | components['dist_a'].log_prob(0.), 91 | components['dist_b'].log_prob(0.), 92 | dist.log_prob({'dist_a': 0., 'dist_b': 0.})]) 93 | self.assertAllClose(a_log_prob, np.log([0.5] * 3)) 94 | self.assertAllClose(b_log_prob, np.log([(2 * np.pi) ** -0.5] * 3)) 95 | self.assertAllClose(group_log_prob, 96 | np.log([0.5 * (2 * np.pi) ** -0.5] * 3)) 97 | 98 | def test_prob(self): 99 | components = { 100 | 'dist_a': tf.distributions.Bernoulli(logits=tf.zeros(3)), 101 | 'dist_b': tf.distributions.Normal(tf.zeros(3), tf.ones(3)), 102 | } 103 | dist = batch_dist.GroupDistribution(components) 104 | with self.test_session() as sess: 105 | a_prob, b_prob, group_prob = sess.run([ 106 | components['dist_a'].prob(0.), 107 | components['dist_b'].prob(0.), 108 | dist.prob({'dist_a': 0., 'dist_b': 0.})]) 109 | self.assertAllClose(a_prob, [0.5] * 3) 110 | self.assertAllClose(b_prob, [(2 * np.pi) ** -0.5] * 3) 111 | self.assertAllClose(group_prob, [0.5 * (2 * np.pi) ** -0.5] * 3) 112 | 113 | def test_is_scalar(self): 114 | assertions = [ 115 | dist.is_scalar_event() for dist in [ 116 | batch_dist.GroupDistribution((((), ()))), 117 | batch_dist.GroupDistribution( 118 | tf.distributions.Bernoulli(probs=0.5)), 119 | batch_dist.GroupDistribution( 120 | tf.distributions.Bernoulli(probs=[0.5, 0.3])), 121 | ] 122 | ] + [ 123 | tf.logical_not(dist.is_scalar_event()) for dist in [ 124 | batch_dist.GroupDistribution( 125 | (tf.distributions.Bernoulli(probs=0.5), 126 | tf.distributions.Bernoulli(probs=0.5))), 127 | batch_dist.GroupDistribution( 128 | tf.contrib.distributions.MultivariateNormalDiag( 129 | loc=tf.zeros(2), scale_diag=tf.ones(2))), 130 | ] 131 | ] + [ 132 | dist.is_scalar_batch() for dist in [ 133 | batch_dist.GroupDistribution((((), ()))), 134 | batch_dist.GroupDistribution( 135 | tf.distributions.Bernoulli(probs=0.5)), 136 | batch_dist.GroupDistribution( 137 | tf.contrib.distributions.MultivariateNormalDiag( 138 | loc=tf.zeros(2), scale_diag=tf.ones(2))), 139 | ] 140 | ] + [ 141 | tf.logical_not(dist.is_scalar_batch()) for dist in [ 142 | batch_dist.GroupDistribution( 143 | (tf.distributions.Bernoulli(probs=0.5), 144 | tf.distributions.Bernoulli(probs=0.5))), 145 | batch_dist.GroupDistribution( 146 | tf.distributions.Bernoulli(probs=[0.5, 0.3])), 147 | ] 148 | ] 149 | with self.test_session() as sess: 150 | self.assertTrue(np.all(sess.run(assertions))) 151 | 152 | 153 | if __name__ == '__main__': 154 | tf.test.main() 155 | -------------------------------------------------------------------------------- /vaeseq/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may 
not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities used elsewhere in this library.""" 16 | 17 | import functools 18 | import sonnet as snt 19 | import tensorflow as tf 20 | from tensorflow.contrib import distributions 21 | 22 | 23 | def calc_kl(hparams, a_sample, dist_a, dist_b): 24 | """Calculates KL(a||b), either analytically or via MC estimate.""" 25 | if hparams.use_monte_carlo_kl: 26 | return dist_a.log_prob(a_sample) - dist_b.log_prob(a_sample) 27 | return distributions.kl_divergence(dist_a, dist_b) 28 | 29 | 30 | def activation(hparams): 31 | """Returns the activation function selected in hparams.""" 32 | return { 33 | "relu": tf.nn.relu, 34 | "elu": tf.nn.elu, 35 | }[hparams.activation] 36 | 37 | 38 | def positive_projection(hparams): 39 | """Returns the positive projection selected in hparams.""" 40 | proj = { 41 | "exp": tf.exp, 42 | "softplus": tf.nn.softplus, 43 | }[hparams.positive_projection] 44 | return lambda tensor: proj(tensor) + hparams.positive_eps 45 | 46 | 47 | def regularizer(hparams): 48 | def _apply_regularizer(tensor): 49 | with tf.control_dependencies(None): 50 | return tf.contrib.layers.l1_l2_regularizer( 51 | scale_l1=hparams.l1_regularization, 52 | scale_l2=hparams.l2_regularization)(tensor) 53 | return _apply_regularizer 54 | 55 | def make_rnn(hparams, name): 56 | """Constructs a DeepRNN using hparams.rnn_hidden_sizes.""" 57 | regularizers = { 58 | snt.LSTM.W_GATES: regularizer(hparams) 59 | } 60 | with tf.variable_scope(name): 61 | layers = [snt.LSTM(size, regularizers=regularizers) 62 | for size in hparams.rnn_hidden_sizes] 63 | return snt.DeepRNN(layers, skip_connections=False, name=name) 64 | 65 | 66 | def make_mlp(hparams, layers, name=None, **kwargs): 67 | """Constructs an MLP with the given layers, using hparams.activation.""" 68 | regularizers = { 69 | "w": regularizer(hparams) 70 | } 71 | return snt.nets.MLP( 72 | layers, 73 | activation=activation(hparams), 74 | regularizers=regularizers, 75 | name=name or "MLP", 76 | **kwargs) 77 | 78 | 79 | def concat_features(tensors): 80 | """Concatenates nested tensors along the last dimension.""" 81 | tensors = snt.nest.flatten(tensors) 82 | if len(tensors) == 1: 83 | return tensors[0] 84 | return tf.concat(tensors, axis=-1) 85 | 86 | 87 | class WrapRNNCore(snt.RNNCore): 88 | """Wrap a transition function into an RNNCore.""" 89 | 90 | def __init__(self, step, state_size, output_size, name=None): 91 | super(WrapRNNCore, self).__init__(name=name) 92 | self._step = step 93 | self._state_size = state_size 94 | self._output_size = output_size 95 | 96 | @property 97 | def output_size(self): 98 | """RNN output sizes.""" 99 | return self._output_size 100 | 101 | @property 102 | def state_size(self): 103 | """RNN state sizes.""" 104 | return self._state_size 105 | 106 | def _build(self, input_, state): 107 | return self._step(input_, state) 108 | 109 | 110 | def add_support_for_scalar_rnn_inputs(cell, inputs): 111 | """Wraps a cell to add support for scalar RNN inputs.""" 112 | flat_inputs = snt.nest.flatten(inputs) 113 | flat_input_is_scalar = 
[inp.get_shape().ndims == 2 for inp in flat_inputs] 114 | if not any(flat_input_is_scalar): 115 | return cell, inputs 116 | inputs = snt.nest.pack_sequence_as( 117 | inputs, 118 | [tf.expand_dims(inp, axis=-1) if is_scalar else inp 119 | for inp, is_scalar in zip(flat_inputs, flat_input_is_scalar)]) 120 | is_scalar = snt.nest.pack_sequence_as(inputs, flat_input_is_scalar) 121 | 122 | def _squeeze(input_, is_scalar): 123 | return tf.squeeze(input_, axis=-1) if is_scalar else input_ 124 | 125 | ret_cell = WrapRNNCore( 126 | lambda inp, state: cell(snt.nest.map(_squeeze, inp, is_scalar), state), 127 | state_size=cell.state_size, 128 | output_size=cell.output_size) 129 | return ret_cell, inputs 130 | 131 | 132 | def input_recording_rnn(cell, input_size): 133 | """Transforms the cell to emit both the output and input.""" 134 | def _step(input_, state): 135 | output, state = cell(input_, state) 136 | return (output, input_), state 137 | return WrapRNNCore( 138 | _step, 139 | state_size=cell.state_size, 140 | output_size=(cell.output_size, input_size)) 141 | 142 | 143 | def state_recording_rnn(cell): 144 | """Transforms the cell to emit both the output and the state.""" 145 | def _step(input_, state): 146 | output, next_state = cell(input_, state) 147 | return (output, state), next_state 148 | return WrapRNNCore( 149 | _step, 150 | state_size=cell.state_size, 151 | output_size=(cell.output_size, cell.state_size)) 152 | 153 | 154 | def use_recorded_state_rnn(cell): 155 | """Transforms the cell to use the recorded state in the input.""" 156 | def _step(input_state, state): 157 | del state # unused 158 | input_, state = input_state 159 | return cell(input_, state) 160 | return WrapRNNCore(_step, cell.state_size, cell.output_size) 161 | 162 | 163 | def heterogeneous_dynamic_rnn( 164 | cell, inputs, initial_state=None, time_major=False, 165 | output_dtypes=None, **kwargs): 166 | """Wrapper around tf.nn.dynamic_rnn that supports heterogeneous outputs.""" 167 | time_axis = 0 if time_major else 1 168 | batch_axis = 1 if time_major else 0 169 | if initial_state is None: 170 | batch_size = batch_size_from_nested_tensors(inputs) 171 | initial_state = cell.zero_state(batch_size, output_dtypes) 172 | flat_dtypes = snt.nest.flatten(output_dtypes) 173 | flat_output_size = snt.nest.flatten(cell.output_size) 174 | # The first output will be returned the normal way; the rest will 175 | # be returned via state TensorArrays. 
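    # (Plain tf.nn.dynamic_rnn assumes all outputs share a single dtype, so the
    # extra outputs are threaded through the loop state as TensorArrays and
    # stacked back into full sequence tensors after the loop.)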
176 | input_length = sequence_size_from_nested_tensors(inputs) 177 | aux_output_tas = [ 178 | tf.TensorArray( 179 | dtype, 180 | size=input_length, 181 | element_shape=tf.TensorShape([None]).concatenate(out_size)) 182 | for dtype, out_size in zip(flat_dtypes[1:], flat_output_size[1:]) 183 | ] 184 | aux_state = (0, aux_output_tas, initial_state) 185 | 186 | def _step(input_, aux_state): 187 | """Wrap the cell to return the first output and store the rest.""" 188 | step, aux_output_tas, state = aux_state 189 | outputs, state = cell(input_, state) 190 | flat_outputs = snt.nest.flatten(outputs) 191 | aux_output_tas = [ 192 | ta.write(step, output) 193 | for ta, output in zip(aux_output_tas, flat_outputs[1:]) 194 | ] 195 | return flat_outputs[0], (step + 1, aux_output_tas, state) 196 | 197 | first_output, (_, aux_output_tas, state) = tf.nn.dynamic_rnn( 198 | WrapRNNCore(_step, state_size=None, output_size=flat_output_size[0]), 199 | inputs, 200 | initial_state=aux_state, 201 | dtype=flat_dtypes[0], 202 | time_major=time_major, 203 | **kwargs) 204 | first_output_shape = first_output.get_shape().with_rank_at_least(2) 205 | time_and_batch = tf.TensorShape([first_output_shape[time_axis], 206 | first_output_shape[batch_axis]]) 207 | outputs = [first_output] 208 | for aux_output_ta in aux_output_tas: 209 | output = aux_output_ta.stack() 210 | output.set_shape(time_and_batch.concatenate(output.get_shape()[2:])) 211 | if not time_major: 212 | output = transpose_time_batch(output) 213 | outputs.append(output) 214 | return snt.nest.pack_sequence_as(output_dtypes, outputs), state 215 | 216 | 217 | def transpose_time_batch(tensor): 218 | """Transposes the first two dimensions of a Tensor.""" 219 | perm = list(range(tensor.get_shape().with_rank_at_least(2).ndims)) 220 | perm[0], perm[1] = 1, 0 221 | return tf.transpose(tensor, perm=perm) 222 | 223 | 224 | def reverse_dynamic_rnn(cell, inputs, time_major=False, **kwargs): 225 | """Runs tf.nn.dynamic_rnn backwards.""" 226 | time_axis = 0 if time_major else 1 227 | reverse_seq = lambda x: tf.reverse(x, axis=[time_axis]) 228 | inputs = snt.nest.map(reverse_seq, inputs) 229 | output, state = tf.nn.dynamic_rnn( 230 | cell, inputs, time_major=time_major, **kwargs) 231 | return snt.nest.map(reverse_seq, output), state 232 | 233 | 234 | def dynamic_hparam(key, value): 235 | """Returns a memoized, non-constant Tensor that allows feeding.""" 236 | collection = tf.get_collection_ref("HPARAMS_" + key) 237 | if len(collection) > 1: 238 | raise ValueError("Dynamic hparams ollection should contain one item.") 239 | if not collection: 240 | with tf.name_scope(""): 241 | default_value = tf.convert_to_tensor(value, name=key + "_default") 242 | tensor = tf.placeholder_with_default( 243 | default_value, 244 | default_value.get_shape(), 245 | name=key) 246 | collection.append(tensor) 247 | return collection[0] 248 | 249 | 250 | def batch_size_from_nested_tensors(tensors): 251 | """Returns the batch dimension from the first non-scalar tensor given.""" 252 | for tensor in snt.nest.flatten(tensors): 253 | if tensor.get_shape().ndims > 0: 254 | return tf.shape(tensor)[0] 255 | return None 256 | 257 | 258 | def sequence_size_from_nested_tensors(tensors): 259 | """Returns the time dimension from the first K-tensor given where K > 1.""" 260 | for tensor in snt.nest.flatten(tensors): 261 | if tensor.get_shape().ndims > 1: 262 | return tf.shape(tensor)[1] 263 | return None 264 | 265 | 266 | def batch_size(hparams): 267 | """Returns a non-constant Tensor that evaluates to 
hparams.batch_size.""" 268 | return dynamic_hparam("batch_size", hparams.batch_size) 269 | 270 | 271 | def sequence_size(hparams): 272 | """Returns a non-constant Tensor that evaluates to hparams.sequence_size.""" 273 | return dynamic_hparam("sequence_size", hparams.sequence_size) 274 | 275 | 276 | def set_tensor_shapes(tensors, shapes, add_batch_dims=0): 277 | """Set static shape information for nested tuples of tensors and shapes.""" 278 | if add_batch_dims: 279 | batch_dims = tf.TensorShape([None] * add_batch_dims) 280 | shapes = snt.nest.map(batch_dims.concatenate, shapes) 281 | snt.nest.map(lambda tensor, shape: tensor.set_shape(shape), 282 | tensors, shapes) 283 | 284 | 285 | def lazy_property(fn): 286 | """A property decorator that caches the returned value.""" 287 | key = fn.__name__ + "_cache_val_" 288 | 289 | @functools.wraps(fn) 290 | def _lazy(self): 291 | """Very simple cache.""" 292 | if not hasattr(self, key): 293 | setattr(self, key, fn(self)) 294 | return getattr(self, key) 295 | 296 | return property(_lazy) 297 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /vaeseq/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google, Inc., 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Model base class used by the examples.""" 16 | 17 | from __future__ import print_function 18 | from builtins import range 19 | 20 | import abc 21 | import six 22 | import tensorflow as tf 23 | 24 | from google.protobuf import text_format 25 | 26 | from . import context as context_mod 27 | from . import train as train_mod 28 | from . import util 29 | from . import vae as vae_mod 30 | 31 | 32 | @six.add_metaclass(abc.ABCMeta) 33 | class ModelBase(object): 34 | """Common functionality for training/generation/evaluation/etc.""" 35 | 36 | def __init__(self, hparams, session_params): 37 | self._hparams = hparams 38 | self._session_params = session_params 39 | with tf.name_scope("model") as ns: 40 | self._name_scope = ns 41 | 42 | class SessionParams(object): 43 | """Utility class for commonly used session parameters.""" 44 | 45 | def __init__(self, log_dir=None, master="", task=0, 46 | session_config=None): 47 | self.log_dir = log_dir 48 | self.master = master 49 | self.task = task 50 | self.session_config = session_config 51 | 52 | @classmethod 53 | def add_parser_arguments(cls, parser): 54 | """Add ArgParse argument parsers for session flags. 55 | 56 | The result of argument parsing can be passed into the 57 | ModelBase constructor instead of a SessionParams object. 58 | """ 59 | defaults = cls() 60 | parser.add_argument("--log-dir", required=True, 61 | help="Checkpoint directory") 62 | parser.add_argument("--master", default=defaults.master, 63 | help="Session master.") 64 | parser.add_argument("--task", default=defaults.task, 65 | help="Worker task number.") 66 | def _parse_config_proto(msg): 67 | return text_format.Parse(msg, tf.ConfigProto()) 68 | parser.add_argument("--session-config", default=None, 69 | type=_parse_config_proto, 70 | help="Session ConfigProto.") 71 | 72 | @property 73 | def hparams(self): 74 | return self._hparams 75 | 76 | def name_scope(self, name, default_name=None, values=None): 77 | with tf.name_scope(self._name_scope): 78 | # Capture the sub-namescope as an absolute path. 
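            # (Re-entering that absolute path later reopens the same scope
            # rather than creating uniquified copies such as "encoder_1".)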
79 | with tf.name_scope(name, default_name, values) as ns: 80 | return tf.name_scope(ns) 81 | 82 | @util.lazy_property 83 | def encoder(self): 84 | with self.name_scope("encoder"): 85 | return self._make_encoder() 86 | 87 | @util.lazy_property 88 | def decoder(self): 89 | with self.name_scope("decoder"): 90 | return self._make_decoder() 91 | 92 | @util.lazy_property 93 | def feedback(self): 94 | with self.name_scope("feedback"): 95 | return self._make_feedback() 96 | 97 | @util.lazy_property 98 | def agent(self): 99 | with self.name_scope("agent"): 100 | return self._make_agent() 101 | 102 | @util.lazy_property 103 | def inputs(self): 104 | with self.name_scope("inputs"): 105 | return self._make_full_input_context(self.agent) 106 | 107 | @util.lazy_property 108 | def trainer(self): 109 | with self.name_scope("trainer"): 110 | return self._make_trainer() 111 | 112 | @util.lazy_property 113 | def vae(self): 114 | with self.name_scope("vae"): 115 | return vae_mod.make(self.hparams, self.encoder, self.decoder) 116 | 117 | def dataset(self, dataset, name=None): 118 | """Returns inputs and observations for the given dataset.""" 119 | with self.name_scope("dataset", name): 120 | inputs, observed = self._make_dataset(dataset) 121 | return self._make_full_input_context(inputs), observed 122 | 123 | def training_session(self, hooks=None): 124 | scaffold = self._make_scaffold() 125 | return tf.train.MonitoredTrainingSession( 126 | master=self._session_params.master, 127 | config=self._session_params.session_config, 128 | is_chief=(self._session_params.task == 0), 129 | scaffold=scaffold, 130 | hooks=hooks, 131 | checkpoint_dir=self._session_params.log_dir) 132 | 133 | def eval_session(self, hooks=None): 134 | scaffold = self._make_scaffold() 135 | if self._session_params.task == 0: 136 | session_creator = tf.train.ChiefSessionCreator( 137 | master=self._session_params.master, 138 | config=self._session_params.session_config, 139 | scaffold=scaffold, 140 | checkpoint_dir=self._session_params.log_dir) 141 | else: 142 | session_creator = tf.train.WorkerSessionCreator( 143 | master=self._session_params.master, 144 | config=self._session_params.session_config, 145 | scaffold=scaffold) 146 | return tf.train.MonitoredSession( 147 | hooks=hooks, 148 | session_creator=session_creator) 149 | 150 | def evaluate(self, dataset, num_steps): 151 | """Calculates the mean log-prob for the given sequences.""" 152 | with tf.name_scope("evaluate"): 153 | inputs, observed = self.dataset(dataset) 154 | log_probs = self.vae.evaluate( 155 | inputs, observed, 156 | samples=self.hparams.log_prob_samples) 157 | mean_log_prob, update = tf.metrics.mean(log_probs) 158 | hooks = [tf.train.LoggingTensorHook({"log_prob": mean_log_prob}, 159 | every_n_secs=10., at_end=True)] 160 | latest = None 161 | with self.eval_session(hooks=hooks) as sess: 162 | for _ in range(num_steps): 163 | if sess.should_stop(): 164 | break 165 | latest = sess.run(update) 166 | if latest is not None: 167 | return {"log_prob": latest} 168 | return None 169 | 170 | def train(self, dataset, num_steps, valid_dataset=None): 171 | """Trains/continues training the model.""" 172 | global_step = tf.train.get_or_create_global_step() 173 | inputs, observed = self.dataset(dataset, name="train_dataset") 174 | train_op, debug = self.trainer(inputs, observed) 175 | debug["global_step"] = global_step 176 | hooks = [tf.train.LoggingTensorHook(debug, every_n_secs=60.)] 177 | if self._session_params.log_dir: 178 | # Add metric summaries to be computed at a slower rate. 
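            # The pattern below: tf.metrics.mean() keeps a running average in
            # local variables, and the scalar summary is created under a
            # control dependency on the metric's update op, so the average is
            # refreshed every time the SummarySaverHook (save_steps=100)
            # evaluates the merged summary.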
179 |             slow_summaries = []
180 |             def _add_to_slow_summaries(name, inputs, observed):
181 |                 """Creates a self-updating metric summary op."""
182 |                 with tf.name_scope(name):
183 |                     log_probs = self.vae.evaluate(
184 |                         inputs, observed,
185 |                         samples=self.hparams.log_prob_samples)
186 |                     mean, update = tf.metrics.mean(log_probs)
187 |                     with tf.control_dependencies([update]):
188 |                         slow_summaries.append(
189 |                             tf.summary.scalar("mean_log_prob",
190 |                                               mean, collections=[]))
191 |             _add_to_slow_summaries("train_eval", inputs, observed)
192 |             if valid_dataset is not None:
193 |                 vinputs, vobserved = self.dataset(valid_dataset,
194 |                                                   name="valid_dataset")
195 |                 _add_to_slow_summaries("valid_eval", vinputs, vobserved)
196 |             hooks.append(tf.train.SummarySaverHook(
197 |                 save_steps=100,
198 |                 output_dir=self._session_params.log_dir,
199 |                 summary_op=tf.summary.merge(slow_summaries)))
200 | 
201 |             # Add sample generated sequences.
202 |             generated, _unused_latents = self.vae.generate(
203 |                 inputs=inputs,
204 |                 batch_size=util.batch_size_from_nested_tensors(observed),
205 |                 sequence_size=util.sequence_size_from_nested_tensors(observed))
206 |             hooks.append(tf.train.SummarySaverHook(
207 |                 save_steps=1000,
208 |                 output_dir=self._session_params.log_dir,
209 |                 summary_op=tf.summary.merge([
210 |                     self._make_output_summary("observed", observed),
211 |                     self._make_output_summary("generated", generated),
212 |                 ])))
213 | 
214 |         debug_vals = None
215 |         with self.training_session(hooks=hooks) as sess:
216 |             for local_step in range(num_steps):
217 |                 if sess.should_stop():
218 |                     break
219 |                 if local_step < num_steps - 1:
220 |                     sess.run(train_op)
221 |                 else:
222 |                     _, debug_vals = sess.run((train_op, debug))
223 |         return debug_vals
224 | 
225 |     def generate(self):
226 |         """Generates sequences from a trained model."""
227 |         generated, _unused_latents = self.vae.generate(self.inputs)
228 |         rendered = self._render(generated)
229 |         with self.eval_session() as sess:
230 |             while True:
231 |                 batch = sess.run(rendered)
232 |                 for sequence in batch:
233 |                     yield sequence
234 | 
235 |     def _make_full_input_context(self, inputs):
236 |         """Chains the agent with feedback to produce the VAE input."""
237 |         if inputs is None:
238 |             return self.feedback
239 |         inputs = context_mod.as_context(inputs)
240 |         return context_mod.Chain([inputs, self.feedback])
241 | 
242 |     def _make_scaffold(self):
243 |         local_init_op = tf.group(
244 |             tf.local_variables_initializer(),
245 |             tf.tables_initializer(),
246 |             *tf.get_collection(tf.GraphKeys.LOCAL_INIT_OP))
247 |         return tf.train.Scaffold(local_init_op=local_init_op)
248 | 
249 |     def _render(self, observed):
250 |         """Returns a rendering of the modeled observation for output."""
251 |         return observed
252 | 
253 |     def _make_feedback(self):
254 |         """Constructs the feedback Context."""
255 |         # Default to an encoding of the previous observation.
256 |         return context_mod.EncodeObserved(self.encoder)
257 | 
258 |     def _make_agent(self):
259 |         """Constructs a Context used for generating inputs."""
260 |         return None  # No inputs.
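    # Rough sketch of a concrete subclass (illustrative only; "ToyModel" is a
    # made-up name, while MLPObsEncoder, MLPObsDecoder and BernoulliDecoder
    # are the helpers defined in vaeseq/codec.py, assumed imported here as
    # `codec`). A subclass only has to fill in the abstract hooks declared
    # below:
    #
    #     class ToyModel(ModelBase):
    #
    #         def _make_encoder(self):
    #             return codec.MLPObsEncoder(self.hparams)
    #
    #         def _make_decoder(self):
    #             return codec.MLPObsDecoder(
    #                 self.hparams,
    #                 codec.BernoulliDecoder(squeeze_input=True),
    #                 param_size=1)
    #
    #         def _make_dataset(self, dataset):
    #             inputs, observed = ...  # sequence Tensors from `dataset`
    #             return inputs, observed
    #
    #         def _make_output_summary(self, tag, observed):
    #             return tf.summary.histogram(tag, observed, collections=[])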
261 | 
262 |     def _make_trainer(self):
263 |         global_step = tf.train.get_or_create_global_step()
264 |         loss = train_mod.ELBOLoss(self.hparams, self.vae)
265 |         return train_mod.Trainer(self.hparams, global_step=global_step,
266 |                                  loss=loss, variables=tf.trainable_variables)
267 | 
268 |     @abc.abstractmethod
269 |     def _make_encoder(self):
270 |         """Constructs the observation encoder."""
271 | 
272 |     @abc.abstractmethod
273 |     def _make_decoder(self):
274 |         """Constructs the observation decoder DistModule."""
275 | 
276 |     @abc.abstractmethod
277 |     def _make_dataset(self, dataset):
278 |         """Returns inputs (can be None) and outputs as sequence Tensors."""
279 | 
280 |     @abc.abstractmethod
281 |     def _make_output_summary(self, tag, observed):
282 |         """Returns a tf.summary to display this sequence."""
283 | 
--------------------------------------------------------------------------------
/vaeseq/codec.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google, Inc.,
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Modules for encoding and decoding observations."""
16 | 
17 | import sonnet as snt
18 | import tensorflow as tf
19 | 
20 | from . import batch_dist
21 | from . import dist_module
22 | from . import util
23 | 
24 | 
25 | class EncoderSequence(snt.Sequential):
26 |     """A wrapper around snt.Sequential that also implements output_size."""
27 | 
28 |     @property
29 |     def output_size(self):
30 |         return self.layers[-1].output_size
31 | 
32 | 
33 | class FlattenEncoder(snt.AbstractModule):
34 |     """Forwards the flattened input."""
35 | 
36 |     def __init__(self, input_size=None, name=None):
37 |         super(FlattenEncoder, self).__init__(name=name)
38 |         self._input_size = None
39 |         if input_size is not None:
40 |             self._merge_input_sizes(input_size)
41 | 
42 |     def _merge_input_sizes(self, input_size):
43 |         if self._input_size is None:
44 |             self._input_size = snt.nest.map(tf.TensorShape, input_size)
45 |             return
46 |         self._input_size = snt.nest.map(
47 |             lambda cur_size, inp_size: cur_size.merge_with(inp_size),
48 |             self._input_size,
49 |             input_size)
50 | 
51 |     @property
52 |     def output_size(self):
53 |         """Returns the output Tensor shapes."""
54 |         if self._input_size is None:
55 |             return tf.TensorShape([None])
56 |         flattened_size = 0
57 |         for inp_size in snt.nest.flatten(self._input_size):
58 |             num_elements = inp_size.num_elements()
59 |             if num_elements is None:
60 |                 return tf.TensorShape([None])
61 |             flattened_size += num_elements
62 |         return tf.TensorShape([flattened_size])
63 | 
64 | 
65 |     def _build(self, inp):
66 |         input_sizes = snt.nest.map(lambda inp_i: inp_i.get_shape()[1:], inp)
67 |         self._merge_input_sizes(input_sizes)
68 |         flatten = snt.BatchFlatten(preserve_dims=1)
69 |         flat_inp = snt.nest.map(lambda inp_i: tf.to_float(flatten(inp_i)), inp)
70 |         ret = util.concat_features(flat_inp)
71 |         util.set_tensor_shapes(ret, self.output_size, add_batch_dims=1)
72 |         return ret
73 | 
74 | 
75 | def MLPObsEncoder(hparams, name=None):
76 |     """Observation -> encoded, flat observation."""
77 |     name = name or "mlp_obs_encoder"
78 |     mlp = util.make_mlp(hparams, hparams.obs_encoder_fc_layers,
79 |                         name=name + "/mlp")
80 |     return EncoderSequence([FlattenEncoder(), mlp], name=name)
81 | 
82 | 
83 | class DecoderSequence(dist_module.DistModule):
84 |     """A sequence of zero or more AbstractModules, followed by a DistModule."""
85 | 
86 |     def __init__(self, input_encoders, decoder, name=None):
87 |         super(DecoderSequence, self).__init__(name=name)
88 |         self._input_encoders = input_encoders
89 |         self._decoder = decoder
90 | 
91 |     @property
92 |     def event_dtype(self):
93 |         return self._decoder.event_dtype
94 | 
95 |     @property
96 |     def event_size(self):
97 |         return self._decoder.event_size
98 | 
99 |     def dist(self, params, name=None):
100 |         return self._decoder.dist(params, name=name)
101 | 
102 |     def _build(self, inputs):
103 |         if self._input_encoders:
104 |             inputs = snt.Sequential(self._input_encoders)(inputs)
105 |         return self._decoder(inputs)
106 | 
107 | 
108 | def MLPObsDecoder(hparams, decoder, param_size, name=None):
109 |     """Inputs -> decoder(obs; mlp(inputs))."""
110 |     name = name or "mlp_" + decoder.module_name
111 |     layers = hparams.obs_decoder_fc_hidden_layers + [param_size]
112 |     mlp = util.make_mlp(hparams, layers, name=name + "/mlp")
113 |     return DecoderSequence([util.concat_features, mlp], decoder, name=name)
114 | 
115 | 
116 | class BernoulliDecoder(dist_module.DistModule):
117 |     """Inputs -> Bernoulli(obs; logits=inputs)."""
118 | 
119 |     def __init__(self, dtype=tf.int32, squeeze_input=False, name=None):
120 |         self._dtype = dtype
121 |         self._squeeze_input = squeeze_input
122 |         super(BernoulliDecoder, self).__init__(name=name)
123 | 
124 |     @property
125 |     def event_dtype(self):
126 |         return self._dtype
127 | 
128 | 
@property 129 | def event_size(self): 130 | return tf.TensorShape([]) 131 | 132 | def _build(self, inputs): 133 | if self._squeeze_input: 134 | inputs = tf.squeeze(inputs, axis=-1) 135 | return inputs 136 | 137 | def dist(self, params, name=None): 138 | return tf.distributions.Bernoulli( 139 | logits=params, 140 | dtype=self._dtype, 141 | name=name or self.module_name + "_dist") 142 | 143 | 144 | class BetaDecoder(dist_module.DistModule): 145 | """Inputs -> Beta(obs; conc1, conc0).""" 146 | 147 | def __init__(self, 148 | positive_projection=None, 149 | squeeze_input=False, 150 | name=None): 151 | self._positive_projection = positive_projection 152 | self._squeeze_input = squeeze_input 153 | super(BetaDecoder, self).__init__(name=name) 154 | 155 | @property 156 | def event_dtype(self): 157 | return tf.float32 158 | 159 | @property 160 | def event_size(self): 161 | return tf.TensorShape([]) 162 | 163 | def _build(self, inputs): 164 | conc1, conc0 = tf.split(inputs, 2, axis=-1) 165 | if self._positive_projection is not None: 166 | conc1 = self._positive_projection(conc1) 167 | conc0 = self._positive_projection(conc0) 168 | if self._squeeze_input: 169 | conc1 = tf.squeeze(conc1, axis=-1) 170 | conc0 = tf.squeeze(conc0, axis=-1) 171 | return (conc1, conc0) 172 | 173 | def dist(self, params, name=None): 174 | conc1, conc0 = params 175 | return tf.distributions.Beta( 176 | conc1, conc0, 177 | name=name or self.module_name + "_dist") 178 | 179 | 180 | class _BinomialDist(tf.contrib.distributions.Binomial): 181 | """Work around missing functionality in Binomial.""" 182 | 183 | def __init__(self, total_count, logits=None, probs=None, name=None): 184 | self._total_count = total_count 185 | super(_BinomialDist, self).__init__( 186 | total_count=tf.to_float(total_count), 187 | logits=logits, probs=probs, 188 | name=name or "Binomial") 189 | 190 | def _log_prob(self, counts): 191 | return super(_BinomialDist, self)._log_prob(tf.to_float(counts)) 192 | 193 | def _sample_n(self, n, seed=None): 194 | all_counts = tf.to_float(tf.range(self._total_count + 1)) 195 | for batch_dim in range(self.batch_shape.ndims): 196 | all_counts = tf.expand_dims(all_counts, axis=-1) 197 | all_cdfs = tf.map_fn(self.cdf, all_counts) 198 | shape = tf.concat([[n], self.batch_shape_tensor()], 0) 199 | uniform = tf.random_uniform(shape, seed=seed) 200 | return tf.foldl( 201 | lambda acc, cdfs: tf.where(uniform > cdfs, acc + 1, acc), 202 | all_cdfs, 203 | initializer=tf.zeros(shape, dtype=tf.int32)) 204 | 205 | 206 | class BinomialDecoder(dist_module.DistModule): 207 | """Inputs -> Binomial(obs; total_count, logits).""" 208 | 209 | def __init__(self, total_count=None, squeeze_input=False, name=None): 210 | self._total_count = total_count 211 | self._squeeze_input = squeeze_input 212 | super(BinomialDecoder, self).__init__(name=name) 213 | 214 | @property 215 | def event_dtype(self): 216 | return tf.int32 217 | 218 | @property 219 | def event_size(self): 220 | return tf.TensorShape([]) 221 | 222 | def _build(self, inputs): 223 | if self._squeeze_input: 224 | inputs = tf.squeeze(inputs, axis=-1) 225 | return inputs 226 | 227 | def dist(self, params, name=None): 228 | return _BinomialDist( 229 | self._total_count, 230 | logits=params, 231 | name=name or self.module_name + "_dist") 232 | 233 | 234 | class CategoricalDecoder(dist_module.DistModule): 235 | """Inputs -> Categorical(obs; logits=inputs).""" 236 | 237 | def __init__(self, dtype=tf.int32, name=None): 238 | self._dtype = dtype 239 | super(CategoricalDecoder, 
self).__init__(name=name)
240 | 
241 |     @property
242 |     def event_dtype(self):
243 |         return self._dtype
244 | 
245 |     @property
246 |     def event_size(self):
247 |         return tf.TensorShape([])
248 | 
249 |     def _build(self, inputs):
250 |         return inputs
251 | 
252 |     def dist(self, params, name=None):
253 |         return tf.distributions.Categorical(
254 |             logits=params,
255 |             dtype=self._dtype,
256 |             name=name or self.module_name + "_dist")
257 | 
258 | 
259 | class NormalDecoder(dist_module.DistModule):
260 |     """Inputs -> Normal(obs; loc=half(inputs), scale=project(half(inputs)))"""
261 | 
262 |     def __init__(self, positive_projection=None, name=None):
263 |         self._positive_projection = positive_projection
264 |         super(NormalDecoder, self).__init__(name=name)
265 | 
266 |     @property
267 |     def event_dtype(self):
268 |         return tf.float32
269 | 
270 |     @property
271 |     def event_size(self):
272 |         return tf.TensorShape([])
273 | 
274 |     def _build(self, inputs):
275 |         loc, scale = tf.split(inputs, 2, axis=-1)
276 |         if self._positive_projection is not None:
277 |             scale = self._positive_projection(scale)
278 |         return loc, scale
279 | 
280 |     def dist(self, params, name=None):
281 |         loc, scale = params
282 |         return tf.distributions.Normal(
283 |             loc=loc,
284 |             scale=scale,
285 |             name=name or self.module_name + "_dist")
286 | 
287 | 
288 | class BatchDecoder(dist_module.DistModule):
289 |     """Wrap a decoder to model batches of events."""
290 | 
291 |     def __init__(self, decoder, event_size, name=None):
292 |         self._decoder = decoder
293 |         self._event_size = tf.TensorShape(event_size)
294 |         super(BatchDecoder, self).__init__(name=name)
295 | 
296 |     @property
297 |     def event_dtype(self):
298 |         return self._decoder.event_dtype
299 | 
300 |     @property
301 |     def event_size(self):
302 |         return self._event_size
303 | 
304 |     def _build(self, inputs):
305 |         return self._decoder(inputs)
306 | 
307 |     def dist(self, params, name=None):
308 |         return batch_dist.BatchDistribution(
309 |             self._decoder.dist(params, name=name),
310 |             ndims=self._event_size.ndims)
311 | 
312 | 
313 | class GroupDecoder(dist_module.DistModule):
314 |     """Group up decoders to model a set of independent events."""
315 | 
316 |     def __init__(self, decoders, name=None):
317 |         self._decoders = decoders
318 |         super(GroupDecoder, self).__init__(name=name)
319 | 
320 |     @property
321 |     def event_dtype(self):
322 |         return snt.nest.map(lambda dec: dec.event_dtype, self._decoders)
323 | 
324 |     @property
325 |     def event_size(self):
326 |         return snt.nest.map(lambda dec: dec.event_size, self._decoders)
327 | 
328 |     def _build(self, inputs):
329 |         return snt.nest.map_up_to(
330 |             self._decoders,
331 |             lambda dec, input_: dec(input_),
332 |             self._decoders, inputs)
333 | 
334 |     def dist(self, params, name=None):
335 |         with self._enter_variable_scope():
336 |             with tf.name_scope(name or "group"):
337 |                 dists = snt.nest.map_up_to(
338 |                     self._decoders,
339 |                     lambda dec, param: dec.dist(param),
340 |                     self._decoders, params)
341 |                 return batch_dist.GroupDistribution(dists, name=name)
342 | 
--------------------------------------------------------------------------------
/vaeseq/context.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google, Inc.,
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Context modules summarize inputs and previous observations.""" 16 | 17 | import abc 18 | import sonnet as snt 19 | import tensorflow as tf 20 | 21 | from . import util 22 | 23 | 24 | def as_context(context, name=None): 25 | """Takes Tensors | Context and returns a Context.""" 26 | if context is None: 27 | raise ValueError("Please supply a Context or a set of nested tensors.") 28 | if isinstance(context, Context): 29 | return context 30 | return Constant(context, name=name) 31 | 32 | 33 | def as_tensors(context, observed): 34 | """Takes Tensors | Context and returns Tensors.""" 35 | if context is None: 36 | raise ValueError("Please supply a Context or a set of nested tensors.") 37 | if isinstance(context, Context): 38 | context = context.from_observations(observed) 39 | return context 40 | 41 | 42 | def _from_observations_cache_key(observations, initial_state): 43 | """Cache key used to memoize repeated calls to Context.from_observations.""" 44 | flat_obs = snt.nest.flatten(observations) 45 | obs_names = tuple([obs.name for obs in flat_obs]) 46 | state_names = None 47 | if initial_state is not None: 48 | flat_state = snt.nest.flatten(initial_state) 49 | state_names = tuple([st.name for st in flat_state]) 50 | return (obs_names, state_names) 51 | 52 | 53 | class Context(snt.RNNCore): 54 | """Context interface.""" 55 | 56 | def __init__(self, name=None): 57 | super(Context, self).__init__(name=name) 58 | self._from_observations_cache = {} 59 | 60 | @abc.abstractproperty 61 | def output_size(self): 62 | """The non-batch sizes of the context Tensors.""" 63 | 64 | @property 65 | def output_dtype(self): 66 | """The context Tensor types.""" 67 | return snt.nest.map(lambda _: tf.float32, self.output_size) 68 | 69 | @abc.abstractproperty 70 | def state_size(self): 71 | """The non-batch sizes of this module's state Tensors.""" 72 | 73 | @abc.abstractproperty 74 | def state_dtype(self): 75 | """The types of this module's state Tensors.""" 76 | 77 | def initial_state(self, batch_size): 78 | def _zero_state(size, dtype): 79 | return tf.zeros([batch_size] + tf.TensorShape(size).as_list(), 80 | dtype=dtype) 81 | return snt.nest.map(_zero_state, self.state_size, self.state_dtype) 82 | 83 | @abc.abstractmethod 84 | def _build(self, input_, state): 85 | """Returns a context for the current time step.""" 86 | 87 | @abc.abstractmethod 88 | def observe(self, observation, state): 89 | """Returns the updated state.""" 90 | 91 | def finished(self, state): 92 | """Returns whether each sequence in the batch has completed.""" 93 | return False 94 | 95 | def drive_rnn(self, 96 | cell, 97 | sequence_size, 98 | initial_state, 99 | cell_initial_state, 100 | cell_output_dtype=None, 101 | cell_output_observations=lambda out: out): 102 | """Equivalent to tf.nn.dynamic_rnn, with inputs from this Context.""" 103 | if cell_output_dtype is None: 104 | cell_output_dtype = snt.nest.map( 105 | lambda _: tf.float32, cell.output_size) 106 | def _loop_fn(time, cell_output, cell_state, ctx_state): 107 | if cell_state is None: 108 | cell_state = cell_initial_state 109 | if ctx_state 
is None: 110 | ctx_state = initial_state 111 | if cell_output is not None: 112 | obs = cell_output_observations(cell_output) 113 | ctx_state = self.observe(obs, ctx_state) 114 | finished = tf.logical_or(time >= sequence_size, 115 | self.finished(ctx_state)) 116 | ctx, ctx_state = self(None, ctx_state) 117 | if cell_output is None: 118 | # tf.nn.raw_rnn uses the first cell_output as a dummy 119 | # to determine the output types and shapes. We need to 120 | # specify this to use heterogeneous output dtypes. 121 | # Note that the tensors here do not include the batch 122 | # dimension. 123 | with tf.name_scope("dummy"): 124 | cell_output = snt.nest.map( 125 | tf.zeros, 126 | cell.output_size, 127 | cell_output_dtype) 128 | return (finished, ctx, cell_state, cell_output, ctx_state) 129 | output_tas = tf.nn.raw_rnn(cell, _loop_fn)[0] 130 | outputs = snt.nest.map( 131 | lambda ta: util.transpose_time_batch(ta.stack()), 132 | output_tas) 133 | util.set_tensor_shapes(outputs, cell.output_size, add_batch_dims=2) 134 | return outputs 135 | 136 | def from_observations(self, observed, initial_state=None): 137 | """Generate contexts for a static sequence of observations.""" 138 | cache_key = _from_observations_cache_key(observed, initial_state) 139 | if cache_key in self._from_observations_cache: 140 | return self._from_observations_cache[cache_key] 141 | with self._enter_variable_scope(): 142 | with tf.name_scope("from_observations"): 143 | batch_size = util.batch_size_from_nested_tensors(observed) 144 | if initial_state is None: 145 | initial_state = self.initial_state(batch_size) 146 | def _step(obs, state): 147 | ctx, state = self(None, state) 148 | state = self.observe(obs, state) 149 | return ctx, state 150 | cell = util.WrapRNNCore( 151 | _step, 152 | state_size=self.state_size, 153 | output_size=self.output_size) 154 | cell, observed = util.add_support_for_scalar_rnn_inputs( 155 | cell, observed) 156 | contexts, _ = util.heterogeneous_dynamic_rnn( 157 | cell, observed, 158 | initial_state=initial_state, 159 | output_dtypes=self.output_dtype) 160 | self._from_observations_cache[cache_key] = contexts 161 | return contexts 162 | 163 | 164 | class Constant(Context): 165 | """Constant context wrapping a nested tuple of tensors.""" 166 | 167 | def __init__(self, tensors, name=None): 168 | super(Constant, self).__init__(name=name) 169 | self._batch_size = util.batch_size_from_nested_tensors(tensors) 170 | self._sequence_size = util.sequence_size_from_nested_tensors(tensors) 171 | self._tensors = tensors 172 | 173 | @property 174 | def output_size(self): 175 | return snt.nest.map(lambda tensor: tensor.get_shape()[2:], 176 | self._tensors) 177 | 178 | @property 179 | def output_dtype(self): 180 | return snt.nest.map(lambda tensor: tensor.dtype, self._tensors) 181 | 182 | @property 183 | def state_size(self): 184 | return tf.TensorShape([]) 185 | 186 | @property 187 | def state_dtype(self): 188 | return tf.int32 189 | 190 | def initial_state(self, batch_size): 191 | del batch_size # Ignore the requested batch size. 192 | return super(Constant, self).initial_state(self._batch_size) 193 | 194 | def observe(self, observation, state): 195 | del observation # Not used. 
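        # A Constant context just replays precomputed tensors; its state is
        # only a per-batch time index, so observations never affect it.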
196 | return state 197 | 198 | def finished(self, state): 199 | return state >= self._sequence_size 200 | 201 | def _build(self, input_, state): 202 | if input_ is not None: 203 | raise ValueError("I don't know how to encode any inputs.") 204 | finished = self.finished(state) 205 | state = tf.minimum(state, self._sequence_size - 1) 206 | indices = tf.concat([tf.expand_dims(tf.range(tf.shape(state)[0]), 1), 207 | tf.expand_dims(state, 1)], axis=1) 208 | outputs = snt.nest.map(lambda tensor: tf.gather_nd(tensor, indices), 209 | self._tensors) 210 | util.set_tensor_shapes(outputs, self.output_size, add_batch_dims=1) 211 | zero_outputs = snt.nest.map(tf.zeros_like, outputs) 212 | outputs = snt.nest.map(lambda zero, out: tf.where(finished, zero, out), 213 | zero_outputs, outputs) 214 | return outputs, state + 1 215 | 216 | 217 | class Chain(Context): 218 | """Compose a list of contexts.""" 219 | 220 | def __init__(self, contexts, name=None): 221 | super(Chain, self).__init__(name=name) 222 | self._contexts = contexts 223 | 224 | @property 225 | def output_size(self): 226 | return self._contexts[-1].output_size 227 | 228 | @property 229 | def output_dtype(self): 230 | return self._contexts[-1].output_dtype 231 | 232 | @property 233 | def state_size(self): 234 | return tuple([ctx.state_size for ctx in self._contexts]) 235 | 236 | @property 237 | def state_dtype(self): 238 | return tuple([ctx.state_dtype for ctx in self._contexts]) 239 | 240 | def initial_state(self, batch_size): 241 | return [ctx.initial_state(batch_size) for ctx in self._contexts] 242 | 243 | def observe(self, observation, state): 244 | ret = [] 245 | for context, ctx_state in zip(self._contexts, state): 246 | ret.append(context.observe(observation, ctx_state)) 247 | return ret 248 | 249 | def finished(self, state): 250 | finished = False 251 | for context, ctx_state in zip(self._contexts, state): 252 | finished = tf.logical_or(finished, context.finished(ctx_state)) 253 | return finished 254 | 255 | def _build(self, input_, state): 256 | ctx_out = input_ 257 | ctx_states = [] 258 | for context, ctx_state in zip(self._contexts, state): 259 | ctx_out, ctx_state = context(ctx_out, ctx_state) 260 | ctx_states.append(ctx_state) 261 | return ctx_out, ctx_states 262 | 263 | 264 | class EncodeObserved(Context): 265 | """Simple context that encodes the input and previous observation.""" 266 | 267 | def __init__(self, obs_encoder, input_encoder=None, name=None): 268 | super(EncodeObserved, self).__init__(name=name) 269 | self._input_encoder = input_encoder 270 | self._obs_encoder = obs_encoder 271 | 272 | @property 273 | def output_size(self): 274 | if self._input_encoder is None: 275 | return self._obs_encoder.output_size 276 | return (self._input_encoder.output_size, 277 | self._obs_encoder.output_size) 278 | 279 | @property 280 | def state_size(self): 281 | return self._obs_encoder.output_size 282 | 283 | @property 284 | def state_dtype(self): 285 | return tf.float32 286 | 287 | def observe(self, observation, state): 288 | del state # Not used. 
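        # The state is just the most recent observation's encoding, so each
        # call replaces it wholesale and the previous value can be dropped.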
289 | return self._obs_encoder(observation) 290 | 291 | def _build(self, input_, state): 292 | if input_ is not None and self._input_encoder is None: 293 | raise ValueError("I don't know how to encode any inputs.") 294 | if self._input_encoder is None: 295 | ret = state 296 | else: 297 | ret = (self._input_encoder(input_), state) 298 | return ret, state 299 | 300 | 301 | class Accumulate(Context): 302 | """Accumulates the last N observation encodings.""" 303 | 304 | def __init__(self, obs_encoder, history_size, history_combiner, name=None): 305 | super(Accumulate, self).__init__(name=name) 306 | self._obs_encoder = obs_encoder 307 | self._history_size = history_size 308 | self._history_combiner = history_combiner 309 | 310 | @property 311 | def output_size(self): 312 | return self._history_combiner.output_size 313 | 314 | @property 315 | def state_size(self): 316 | obs_size = self._obs_encoder.output_size 317 | history_size = tf.TensorShape([self._history_size]) 318 | return snt.nest.map(lambda size: history_size.concatenate(size), 319 | obs_size) 320 | 321 | @property 322 | def state_dtype(self): 323 | return snt.nest.map(lambda _: tf.float32, self.state_size) 324 | 325 | def observe(self, observation, state): 326 | enc_obs = tf.expand_dims(self._obs_encoder(observation), axis=1) 327 | return snt.nest.map( 328 | lambda hist, obs: tf.concat([hist[:, 1:, :], obs], axis=1), 329 | state, enc_obs) 330 | 331 | def _build(self, input_, state): 332 | if input_ is not None: 333 | raise ValueError("I don't know how to encode any inputs.") 334 | return self._history_combiner(state), state 335 | --------------------------------------------------------------------------------