├── assets
│   ├── model.png
│   ├── tensorboard.png
│   ├── early_random.png
│   ├── early_converged.png
│   └── brushes
│       ├── charcoal.myb
│       └── dry_brush.myb
├── models
│   ├── __init__.py
│   ├── discriminator.py
│   └── policy.py
├── requirements.txt
├── utils
│   ├── __init__.py
│   ├── args.py
│   ├── image.py
│   ├── misc.py
│   ├── tf.py
│   ├── io.py
│   ├── train.py
│   └── logging.py
├── envs
│   ├── __init__.py
│   ├── utils.py
│   ├── base.py
│   ├── mypaint_utils.py
│   ├── simple.py
│   └── mnist.py
├── install.sh
├── replay.py
├── .gitignore
├── LICENSE
├── main.py
├── README.md
├── run.py
├── config.py
├── trainer.py
├── rl_utils.py
└── agent.py

/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/SPIRAL-tensorflow/HEAD/assets/model.png
--------------------------------------------------------------------------------

/models/__init__.py:
--------------------------------------------------------------------------------
from .policy import Policy
from .discriminator import Discriminator
--------------------------------------------------------------------------------

/assets/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/SPIRAL-tensorflow/HEAD/assets/tensorboard.png
--------------------------------------------------------------------------------

/assets/early_random.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/SPIRAL-tensorflow/HEAD/assets/early_random.png
--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
ipdb
tqdm
numpy
scipy
pillow
pathlib
cloudpickle
future-fstrings
--------------------------------------------------------------------------------

/assets/early_converged.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/SPIRAL-tensorflow/HEAD/assets/early_converged.png
--------------------------------------------------------------------------------

/utils/__init__.py:
--------------------------------------------------------------------------------
from . import io
from . import tf
from . import args
from . import misc
from . import train
from . import logging
--------------------------------------------------------------------------------
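requirements.txt above pins `future-fstrings`, which backports f-string syntax to Python 2.7; the modules below opt in through a coding header on their first line. A minimal sketch of the mechanism:

```python
# -*- coding: future_fstrings -*-
# With the future-fstrings package installed, this header registers a source
# codec that rewrites f-strings at import time, so this also runs on 2.7.
name = "SPIRAL"
print(f"agent: {name}")
```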
/utils/args.py:
--------------------------------------------------------------------------------
def str2bool(v):
    # ('true') is a plain string, so the original membership test matched any
    # substring of "true"; a one-element tuple gives the intended comparison.
    return v.lower() in ('true',)

def str_list(value):
    if not value:
        return value
    return value.split(',')

def int_list(value):
    return [int(num) for num in value.split(',')]

def add_argument_group(parser, name):
    arg = parser.add_argument_group(name)
    return arg
--------------------------------------------------------------------------------

/envs/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: future_fstrings -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .simple import Simple
from .mnist import MNIST, SimpleMNIST


def create_env(args):
    env = args.env.lower()
    if env == 'simple':
        env = Simple(args)
    elif env == 'simple_mnist':
        env = SimpleMNIST(args)
    elif env == 'mnist':
        env = MNIST(args)
    else:
        raise Exception("Unknown environment: {}".format(args.env))
    return env
--------------------------------------------------------------------------------

/envs/utils.py:
--------------------------------------------------------------------------------
import numpy as np

def rgb2gray(rgb):
    rgb = np.dot(rgb[...,:3], [0.299, 0.587, 0.114])
    return np.expand_dims(rgb, -1)

def l2(mat1, mat2):
    return np.sqrt(np.sum((mat1 - mat2)**2))

def uniform_locations(screen_size, location_size, object_radius,
                      normalize=False):
    x = np.linspace(object_radius, screen_size-object_radius, location_size)
    grid = np.meshgrid(x, x)
    # Build a (location_size**2, 2) array of (x, y) pairs; wrapping zip in
    # list() keeps this working on Python 3 as well as Python 2.
    out = np.array(list(zip(*np.vstack([np.ravel(g) for g in grid]))))
    if normalize:
        div = location_size**2 / 2
        out = (out - div) / div
    return out
--------------------------------------------------------------------------------

/assets/brushes/charcoal.myb:
--------------------------------------------------------------------------------
# mypaint brush file
color 0 0 0
opaque 0.4 | pressure 1.0 0.4 0.0 0.0 0.0 0.0 0.0 0.0
opaque_multiply 0.0 | pressure 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
radius_logarithmic 0.7
hardness 0.2
dabs_per_basic_radius 0.0
dabs_per_actual_radius 5.0
dabs_per_second 0.0
radius_by_random 0.0
speed1_slowness 0.04
speed2_slowness 0.8
offset_by_random 1.6 | pressure 1.0 -1.4 0.0 0.0 0.0 0.0 0.0 0.0
offset_by_speed 0.0
offset_by_speed_slowness 1.0
slow_tracking 2.0
slow_tracking_per_dab 0.0
color_value 0.0
color_saturation 0.0
color_hue 0.0
adapt_color_from_image 0.0
change_radius 0.0
stroke_treshold 0.0
stroke_duration_logarithmic 4.0
stroke_holdtime 0.0
opaque_linearize 0.0
--------------------------------------------------------------------------------

/utils/image.py:
--------------------------------------------------------------------------------
import tensorflow as tf
tfgan = tf.contrib.gan


def get_image_grid(images, batch_size, num_classes, num_images_per_class):
    # Code from https://github.com/tensorflow/models/blob/master/research/gan/cifar/util.py
    images.shape[0:1].assert_is_compatible_with([batch_size])
    if batch_size < num_classes * num_images_per_class:
        raise ValueError('Not enough images in batch to show the desired number of '
                         'images.')
    if batch_size % num_classes != 0:
        raise ValueError('`batch_size` must be divisible by `num_classes`.')

    # Only get a certain number of images per class.
    num_batches = batch_size // num_classes
    indices = [i * num_batches + j for i in range(num_classes)
               for j in range(num_images_per_class)]
    sampled_images = tf.gather(images, indices)
    return tfgan.eval.image_reshaper(
        sampled_images, num_cols=num_images_per_class)
--------------------------------------------------------------------------------

/install.sh:
--------------------------------------------------------------------------------
#!/bin/sh
sudo apt-get install -y build-essential
sudo apt-get install -y libjson-c-dev libgirepository1.0-dev libglib2.0-dev
sudo apt-get install -y python2.7 autotools-dev intltool gettext libtool

sudo apt-get install -y git swig python-setuptools gettext g++
sudo apt-get install -y python-dev python-numpy
sudo apt-get install -y libgtk-3-dev python-gi-dev
sudo apt-get install -y libpng-dev liblcms2-dev libjson-c-dev
sudo apt-get install -y gir1.2-gtk-3.0 python-gi-cairo

mkdir libs
cd libs

if [ ! -d libmypaint ]; then
    wget https://github.com/mypaint/libmypaint/releases/download/v1.3.0/libmypaint-1.3.0.tar.xz
    tar -xvf libmypaint-1.3.0.tar.xz
    mv libmypaint-1.3.0 libmypaint
    cd libmypaint
    ./configure
    sudo make install
    cd ..
fi

if [ ! -d mypaint ]; then
    wget https://github.com/mypaint/mypaint/releases/download/v1.2.1/mypaint-1.2.1.tar.xz
    tar -xvf mypaint-1.2.1.tar.xz
    mv mypaint-1.2.1 mypaint
    cd mypaint
    scons
    sudo scons install
    cd ..
33 | fi 34 | 35 | sudo ldconfig 36 | -------------------------------------------------------------------------------- /assets/brushes/dry_brush.myb: -------------------------------------------------------------------------------- 1 | # mypaint brush file 2 | # you can edit this file and then select the brush in mypaint (again) to reload 3 | version 2 4 | opaque 0.8 | pressure (0.000000 0.000000), (1.000000 0.200000) 5 | opaque_multiply 0.0 | pressure (0.000000 0.000000), (1.000000 1.000000) 6 | opaque_linearize 0.0 7 | radius_logarithmic 0.6 | speed2 (0.000000 0.042857), (4.000000 -0.300000) 8 | hardness 0.2 9 | dabs_per_basic_radius 6.0 10 | dabs_per_actual_radius 6.0 11 | dabs_per_second 0.0 12 | radius_by_random 0.1 13 | speed1_slowness 0.04 14 | speed2_slowness 0.8 15 | speed1_gamma 4.0 16 | speed2_gamma 4.0 17 | offset_by_random 0.0 | pressure (0.000000 0.000000), (1.000000 1.400000) 18 | offset_by_speed 0.0 19 | offset_by_speed_slowness 1.0 20 | slow_tracking 2.0 21 | slow_tracking_per_dab 0.0 22 | tracking_noise 0.0 23 | color_h 0.0 24 | color_s 0.0 25 | color_v 0.0 26 | change_color_h 0.0 27 | change_color_l 0.0 28 | change_color_hsl_s 0.0 29 | change_color_v 0.0 30 | change_color_hsv_s 0.0 31 | smudge 0.0 32 | smudge_length 0.5 33 | stroke_treshold 0.0 34 | stroke_duration_logarithmic 4.0 35 | stroke_holdtime 0.0 36 | custom_input 0.0 37 | custom_input_slowness 0.0 38 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import sys 7 | import uuid 8 | from datetime import datetime 9 | from collections import defaultdict 10 | from tensorflow.python.client import device_lib 11 | 12 | 13 | def count_gpu(): 14 | devices = device_lib.list_local_devices() 15 | return sum(1 for d in devices if d.device_type == 'GPU') 16 | 17 | def get_hash(length): 18 | assert length <= 32, "length of hash should be less than 32" 19 | hash_text = uuid.uuid4().hex 20 | return hash_text[:length] 21 | 22 | def get_time(): 23 | return datetime.now().strftime("%m%d_%H%M%S") 24 | 25 | def progress(count, total, status=''): 26 | bar_len = 60 27 | filled_len = int(round(bar_len * count / float(total))) 28 | 29 | percents = round(100.0 * count / float(total), 1) 30 | bar = '=' * filled_len + '-' * (bar_len - filled_len) 31 | 32 | sys.stdout.write(f'[{status}] {bar}| [{count}/{total}] {percents}%\r') 33 | sys.stdout.flush() 34 | 35 | 36 | class keydefaultdict(defaultdict): 37 | def __missing__(self, key): 38 | if self.default_factory is None: 39 | raise KeyError(key) 40 | else: 41 | ret = self[key] = self.default_factory(key) 42 | return ret 43 | -------------------------------------------------------------------------------- /replay.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | 8 | 9 | class ReplayBuffer(object): 10 | # Code based on https://github.com/carpedm20/simulated-unsupervised-tensorflow/blob/master/replay.py 11 | def __init__(self, args, observation_shape): 12 | self.buffer_size = args.buffer_batch_num * args.disc_batch_size 13 | self.batch_size = args.disc_batch_size 14 | 15 | self.rng = 
np.random.RandomState(args.seed) 16 | 17 | self.idx = 0 18 | replay_shape = [self.buffer_size] + observation_shape 19 | self.data = np.zeros(replay_shape, dtype=np.uint8) 20 | 21 | self.most_recent = None 22 | 23 | def push(self, batches): 24 | batch_size = len(batches) 25 | if self.idx + batch_size > self.buffer_size: 26 | self.data[:-batch_size] = self.data[batch_size:] 27 | self.data[-batch_size:] = batches 28 | else: 29 | self.data[self.idx:self.idx+batch_size] = batches 30 | self.idx += int(batch_size) 31 | 32 | def sample(self, n): 33 | while self.idx < n: 34 | pass 35 | random_idx = self.rng.choice( 36 | self.idx, self.batch_size) 37 | return self.data[random_idx].astype(np.float32) 38 | -------------------------------------------------------------------------------- /envs/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Environment(object): 5 | 6 | def __init__(self, args): 7 | self.args = args 8 | 9 | if not args.jump and 'jump' in self.action_sizes: 10 | del self.action_sizes['jump'] 11 | 12 | if not args.curve and 'control' in self.action_sizes: 13 | del self.action_sizes['control'] 14 | 15 | # terminal 16 | self.episode_length = args.episode_length 17 | 18 | # screen 19 | self.screen_size = args.screen_size 20 | self.height, self.width = self.screen_size, self.screen_size 21 | self.observation_shape = [ 22 | self.height, self.width, args.color_channel] 23 | 24 | # location 25 | self.location_size = args.location_size 26 | self.location_shape = [self.location_size, self.location_size] 27 | 28 | for name, value in self.action_sizes.items(): 29 | if value is None: 30 | self.action_sizes[name] = self.location_shape 31 | 32 | self.acs = list(self.action_sizes.keys()) 33 | self.ac_idx = { 34 | ac:idx for idx, ac in enumerate(self.acs) 35 | } 36 | 37 | self.conditional = args.conditional 38 | 39 | def random_action(self): 40 | action = [] 41 | for ac in self.acs: 42 | size = self.action_sizes[ac] 43 | sample = np.random.randint(np.prod(size)) 44 | action.append(sample) 45 | return action 46 | 47 | @property 48 | def initial_action(self): 49 | return [-1] * len(self.acs) 50 | 51 | def norm(self, img): 52 | return (np.array(img) - 127.5) / 127.5 53 | 54 | def denorm(self, img): 55 | return img * 127.5 + 127.5 56 | -------------------------------------------------------------------------------- /utils/tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import tensorflow as tf 7 | from tensorflow.python.client import device_lib 8 | 9 | 10 | # Disables write_meta_graph argument, which freezes entire process and is mostly useless. 
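# A minimal usage sketch (hypothetical names; a `variables_to_save` list like
# this is built in trainer.py below):
#
#     saver = FastSaver(variables_to_save)
#     saver.save(sess, "logs/model.ckpt", global_step=step)  # no .meta written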
class FastSaver(tf.train.Saver):
    def save(self, sess, save_path, global_step=None, latest_filename=None,
             meta_graph_suffix="meta", write_meta_graph=True):
        super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
                                    meta_graph_suffix, False)


def get_all_variables(scope_name):
    return tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES,
        scope=scope_name)

def get_sync_op(from_list, to_list):
    # This message needs the f-prefix; without it the placeholders would be
    # printed literally.
    assert len(from_list) == len(to_list), \
        f"lengths of variable lists should be the same ({len(from_list)} != {len(to_list)})"
    syncs = []
    for from_v, to_v in zip(from_list, to_list):
        assert from_v.get_shape() == to_v.get_shape(), \
            f"{from_v.get_shape()} != {to_v.get_shape()}" \
            f" ({from_v.name}, {to_v.name})"
        sync = to_v.assign(from_v)
        syncs.append(sync)
    return tf.group(*syncs)

def cluster_spec(num_workers, num_ps, port=12222):
    cluster = {}

    all_ps = []
    host = '127.0.0.1'
    for _ in range(num_ps):
        all_ps.append('{}:{}'.format(host, port))
        port += 1
    cluster['ps'] = all_ps

    all_workers = []
    for _ in range(num_workers):
        all_workers.append('{}:{}'.format(host, port))
        port += 1
    cluster['worker'] = all_workers
    return cluster

def int_shape(tensor):
    shape = tensor.get_shape().as_list()
    return [num if num is not None else -1 for num in shape]
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
# Library
libs

# Data
*.png
*.gif
*.tar.gz
data/cifar-10-batches-py

# ipython checkpoints
.ipynb_checkpoints

# Log
logs

# ETC
.DS_Store
.vscode

# Created by https://www.gitignore.io/api/python,vim

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *,cover 69 | .hypothesis/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # celery beat schedule file 99 | celerybeat-schedule 100 | 101 | # dotenv 102 | .env 103 | 104 | # virtualenv 105 | .venv/ 106 | venv/ 107 | ENV/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | 116 | ### Vim ### 117 | # swap 118 | [._]*.s[a-v][a-z] 119 | [._]*.sw[a-p] 120 | [._]s[a-v][a-z] 121 | [._]sw[a-p] 122 | # session 123 | Session.vim 124 | # temporary 125 | .netrwhist 126 | *~ 127 | # auto-generated tag files 128 | tags 129 | 130 | # End of https://www.gitignore.io/api/python,vim 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Taehoon Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | MIT License 25 | 26 | Copyright (c) 2016 openai 27 | 28 | Permission is hereby granted, free of charge, to any person obtaining a copy 29 | of this software and associated documentation files (the "Software"), to deal 30 | in the Software without restriction, including without limitation the rights 31 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 32 | copies of the Software, and to permit persons to whom the Software is 33 | furnished to do so, subject to the following conditions: 34 | 35 | The above copyright notice and this permission notice shall be included in all 36 | copies or substantial portions of the Software. 37 | 38 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 39 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 40 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/utils/io.py:
--------------------------------------------------------------------------------
# -*- coding: future_fstrings -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import csv
import json
import math
import timeit
import numpy as np
import cloudpickle
from six.moves import shlex_quote
from tqdm import trange

from . import logging

logger = logging.get_logger()

try:
    import scipy.misc
    imread = scipy.misc.imread
    imresize = scipy.misc.imresize
    imsave = imwrite = scipy.misc.imsave
except ImportError:
    import cv2
    imread = cv2.imread
    # cv2 has no `imresize`; note cv2.resize takes a (width, height) dsize
    # rather than scipy's size argument.
    imresize = cv2.resize
    imsave = imwrite = cv2.imwrite


def get_cmd(as_list=False):
    args = [shlex_quote(arg) for arg in sys.argv][1:]
    if as_list:
        return args
    return ' '.join(args)

class Timer:
    """Example:
    with ut.io.Timer("tokenize"):
        for text, score in ut.io.read_csv(path):
            chars = korean.tokenize(text) + ['']
            for char in chars:
                self.dictionary.add_char(char)
    """
    def __init__(self, desc=""):
        self.desc = desc

    def __enter__(self):
        self.start = timeit.default_timer()
        return self

    def __exit__(self, *args):
        self.end = timeit.default_timer()
        self.interval = self.end - self.start
        logger.debug(f"[tt] {self.desc}: {self.interval:.3f}s")

def makedirs(path):
    path = str(path)
    if not os.path.exists(path):
        logger.info(f"Make directories: {path}")
        os.makedirs(path)
    else:
        logger.warning(f"Skip making directories: {path}")

def remove_file(path):
    if os.path.exists(path):
        os.remove(path)
        logger.info(f"Removed: {path}")

#####################
# Pickle
#####################

def add_postfix(path, idx):
    # Helper assumed by dump_pickle/load_pickle below (it was missing from
    # this file): "a/b.pkl" -> "a/b_0.pkl".
    base, ext = os.path.splitext(str(path))
    return f"{base}_{idx}{ext}"

def _dump_pickle(path, data):
    path = str(path)
    with open(path, 'wb') as f, Timer(f"Dumped pickle: {path}"):
        cloudpickle.dump(data, f)

def dump_pickle(path, data, num_split=1):
    if num_split == 1:
        _dump_pickle(path, data)
    else:
        pivot = 0
        window_size = int(math.ceil(len(data) / num_split))
        for idx in trange(num_split, desc=f"Dump {num_split} pickles"):
            new_path = add_postfix(path, idx)
            _dump_pickle(new_path, data[pivot:pivot+window_size])
            pivot += window_size

def _load_pickle(path):
    path = str(path)
    with open(path, 'rb') as f, Timer(f"Loaded pickle: {path}"):
        data = cloudpickle.load(f)
    return data

def load_pickle(path, num_split=1):
    if num_split == 1:
        data = _load_pickle(path)
    else:
        data = []
        for idx in trange(num_split, desc=f"Load {num_split} pickles"):
            new_path = add_postfix(path, idx)
            tmp_data = _load_pickle(new_path)
            data.extend(tmp_data)
    return data
--------------------------------------------------------------------------------
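A hedged usage sketch of the I/O helpers above; the path and payload are hypothetical:

```python
import utils as ut

with ut.io.Timer("dump episode"):
    ut.io.dump_pickle("logs/episode.pkl", {"reward": 1.0})
data = ut.io.load_pickle("logs/episode.pkl")
```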
/main.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import sys, signal
import tensorflow as tf

import trainer
import utils as ut
from envs import create_env

logger = ut.logging.get_logger()


def main(_):
    from config import get_args
    args = get_args()

    ut.train.set_global_seed(args.seed + args.task)

    spec = ut.tf.cluster_spec(
        args.num_workers, 1, args.start_port)
    cluster = tf.train.ClusterSpec(spec)
    cluster_def = cluster.as_cluster_def()

    def shutdown(signum, frame):
        logger.warning('Received signal %s: exiting', signum)
        sys.exit(128 + signum)

    signal.signal(signal.SIGHUP, shutdown)
    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)

    #############################
    # Prepare common envs
    #############################

    env = create_env(args)

    queue_shapes = [
        ['actions', [len(env.action_sizes)]],
        ['states', env.observation_shape],
        ['rewards', []],
        ['values', [1]],
        ['features', [2, args.lstm_size]],
    ]
    if args.conditional:
        queue_shapes.append(['conditions', env.observation_shape])
    else:
        queue_shapes.append(['z', [args.z_dim]])

    for idx, (name, shape) in enumerate(queue_shapes):
        length = env.episode_length
        if name == 'states':
            length += 1
        queue_shapes[idx][1] = [length] + shape

    queue_shapes.extend([
        ('r', []),
    ])

    trajectory_queue_size = \
        args.policy_batch_size * max(5, args.num_workers)
    replay_queue_size = \
        args.disc_batch_size * max(5, args.num_workers)

    #############################
    # Run
    #############################

    if args.task == 0:
        ut.train.save_args(args)

    if args.job_name == "worker":
        gpu_options = tf.GPUOptions(allow_growth=True)

        tf_config = tf.ConfigProto(
            allow_soft_placement=True,
            intra_op_parallelism_threads=1,
            inter_op_parallelism_threads=2,
            gpu_options=gpu_options)

        server = tf.train.Server(
            cluster_def,
            job_name="worker",
            task_index=args.task,
            config=tf_config)
        trainer.train(args, server, cluster, env, queue_shapes,
                      trajectory_queue_size, replay_queue_size)
    else:
        del env

        server = tf.train.Server(
            cluster_def, job_name="ps", task_index=args.task,
            config=tf.ConfigProto(device_filters=["/job:ps"]))

        with tf.device("/job:ps/task:{}".format(args.task)):
            queue_size = args.policy_batch_size * args.num_workers

            queue = tf.FIFOQueue(
                trajectory_queue_size,
                [tf.float32] * len(queue_shapes),
                shapes=[shape for _, shape in queue_shapes],
                shared_name='queue')

            replay = tf.FIFOQueue(
                replay_queue_size,
                tf.float32,
                shapes=dict(queue_shapes)['states'][1:],
                shared_name='replay')

        while True:
            time.sleep(1000)

if __name__ == "__main__":
    tf.app.run()
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# SPIRAL in TensorFlow (in progress)

TensorFlow implementation of [Synthesizing Programs for Images using Reinforced Adversarial Learning](https://deepmind.com/blog/learning-to-generate-images/) (**SPIRAL**).

![model](assets/model.png)

**SPIRAL** is an adversarially trained agent that generates a program which is executed by a graphics engine to interpret and sample images. The agent is trained to fool a discriminator using distributed reinforcement learning, without any supervision.

In short, Distributed RL + GAN + Program synthesis.


## Prerequisites

- *Python 2.7*
- [MyPaint 1.2.x](https://github.com/mypaint/mypaint/tree/v1.2.x)
- [TensorFlow 1.6.0](https://www.tensorflow.org/)


## Usage

Install prerequisites with:

    ./install.sh
    pip install -r requirements.txt

To debug a **SPIRAL** model:

    python run.py --num_workers 8 --env simple --episode_length=1 \
        --location_size=8 --conditional=True \
        --loss=l2 --policy_batch_size=1

To train a **SPIRAL** model:

    python run.py --num_workers 16 --env simple_mnist --episode_length=3 \
        --color_channel=1 --location_size=32 --loss=gan --num_gpu=1 \
        --disc_dim=8 --conditional=False \
        --mnist_nums=1,7 --jump=False --curve=False

    python run.py --num_workers 24 --env simple_mnist --episode_length=6 \
        --color_channel=1 --location_size=32 --loss=gan --num_gpu=2 \
        --disc_dim=64 --conditional=False \
        --mnist_nums=0,1,2,3,4,5,6,7,8,9 --jump=True

    python run.py --num_workers 12 --env simple_mnist --episode_length=2 \
        --color_channel=1 --location_size=32 --conditional=True \
        --mnist_nums=1 --loss=gan

    python run.py --num_workers 24 --env simple_mnist --episode_length=3 \
        --color_channel=1 --location_size=32 --conditional=True \
        --mnist_nums=1,2,7 --loss=l2

    python run.py --num_workers 24 --env simple_mnist --episode_length=3 \
        --color_channel=1 --location_size=32 --conditional=True \
        --mnist_nums=1,2,7 --loss=gan --num_gpu=2

    python run.py --num_workers 24 --env simple_mnist --episode_length=5 \
        --color_channel=1 --location_size=32 --conditional=True \
        --mnist_nums=0,1,2,7 --loss=gan --num_gpu=2


## Results

(in progress)

Randomly generated samples at an early stage:

![early random samples](assets/early_random.png)

Incorrectly converged samples at an early stage:

![early converged samples](assets/early_converged.png)

Tensorboard:

![tensorboard](assets/tensorboard.png)


## To-do

- [x] IMPALA A2C
- [ ] IMPALA V-trace
- [x] Simple environment (debugging)
- [x] Find a correct libmypaint setting
- [x] MNIST environment
- [x] ReplayThread (`--loss=gan`)
- [x] `--num_gpu=2` test
- [x] `--conditional=True` (need more details)
- [ ] Replay memory needs more detailed information
- [ ] Population Based Training (to be honest, I don't have any plan for this)


## References

*This code is heavily based on [openai/universe-starter-agent](https://github.com/openai/universe-starter-agent).*

- [Population Based Training of Neural Networks](https://arxiv.org/abs/1711.09846)
- [Asynchronous Methods for Deep Reinforcement Learning](https://arxiv.org/abs/1602.01783)
- [IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/abs/1802.01561)


## Author

Taehoon Kim / [@carpedm20](http://carpedm20.github.io/)
--------------------------------------------------------------------------------

/run.py:
-------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | import tensorflow as tf 9 | from six.moves import shlex_quote 10 | 11 | import utils as ut 12 | 13 | 14 | def new_cmd(session, name, cmd, load_path, shell): 15 | if isinstance(cmd, (list, tuple)): 16 | cmd = " ".join(shlex_quote(str(v)) for v in cmd) 17 | return name, "tmux send-keys -t {}:{} {} Enter".format(session, name, shlex_quote(cmd)) 18 | 19 | 20 | def create_commands(args, shell='bash'): 21 | ut.train.prepare_dirs(args) 22 | 23 | actual_args = ut.io.get_cmd(as_list=True) 24 | actual_cmd = ' '.join(actual_args) 25 | 26 | # for launching the TF workers and for launching tensorboard 27 | base_cmd = [ 28 | 'CUDA_VISIBLE_DEVICES=', 29 | sys.executable, 'main.py', 30 | '--load_path', args.load_path, 31 | '--start_port', args.start_port, 32 | '--num_gpu', ut.misc.count_gpu(), 33 | ] + actual_args 34 | 35 | cmds_map = [ 36 | ("dummy", "tmux send-keys -t {}:0 Enter".format(args.tag)), 37 | new_cmd(args.tag, "ps", base_cmd + ["--job_name", "ps"], args.load_path, shell), 38 | ] 39 | 40 | if args.loss == 'l2': 41 | gpu_task_num = 1 42 | elif args.loss == 'gan': 43 | gpu_task_num = 2 44 | 45 | for idx in range(args.num_workers): 46 | if idx < gpu_task_num and args.num_gpu > 0: # gpu workers 47 | cmd = [base_cmd[0] + str(min(args.num_gpu, max(0, args.num_gpu - idx - 1)))] + base_cmd[1:] 48 | else: 49 | cmd = base_cmd[:] 50 | 51 | cmd += ["--job_name", "worker", "--task", str(idx)] 52 | cmds_map += [new_cmd(args.tag, "w-%d" % idx, cmd, args.load_path, shell)] 53 | 54 | tmp_tb_dir = "/".join(sys.executable.split('/')[:-1]) 55 | tmp_tb_path = os.path.join(tmp_tb_dir, "tensorboard") 56 | 57 | if os.path.exists(tmp_tb_path): 58 | tb = tmp_tb_dir + "/tensorboard" 59 | else: 60 | tb = "tensorboard" 61 | tb_args = [tb, "--logdir", args.log_dir, "--port", "12345"] 62 | 63 | cmds_map += [new_cmd(args.tag, "tb", tb_args, args.load_path, shell)] 64 | cmds_map += [new_cmd(args.tag, "htop", ["htop"], args.load_path, shell)] 65 | 66 | windows = [v[0] for v in cmds_map] 67 | 68 | notes = [] 69 | cmds = [] 70 | 71 | notes += ["Use `tmux attach -t {}` to watch process output".format(args.tag)] 72 | notes += ["Use `tmux kill-session -t {}` to kill the job".format(args.tag)] 73 | 74 | notes += ["Point your browser to http://localhost:12345 to see Tensorboard"] 75 | 76 | cmds += [ 77 | # kill any process using tensorboard's port 78 | f"kill $( lsof -i:{args.tb_port} -t ) > /dev/null 2>&1", 79 | # kill any processes using ps / worker ports 80 | f"kill $( lsof -i:{args.start_port}-{args.num_workers + args.start_port} -t ) > /dev/null 2>&1", 81 | f"tmux kill-session -t {args.tag}", 82 | f"tmux new-session -s {args.tag} -n {windows[0]} -d {shell}", 83 | ] 84 | for w in windows[1:]: 85 | cmds += ["tmux new-window -t {} -n {} {}".format(args.tag, w, shell)] 86 | cmds += ["sleep 1"] 87 | 88 | for window, cmd in cmds_map: 89 | cmds += [cmd] 90 | 91 | return cmds, notes 92 | 93 | 94 | def run(args): 95 | cmds, notes = create_commands(args) 96 | if args.dry_run: 97 | print("Dry-run mode due to -n flag, otherwise the following commands would be executed:") 98 | else: 99 | print("Executing the following commands:") 100 | 101 | print("\n".join(cmds)) 102 | print("") 103 | 104 | if not args.dry_run: 105 | os.environ["TMUX"] = "" 106 | 
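        # Clearing $TMUX allows the `tmux new-session` command in cmds to run
        # even when run.py is itself launched from inside a tmux session.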
os.system("\n".join(cmds)) 107 | 108 | print('\n'.join(notes)) 109 | 110 | 111 | if __name__ == "__main__": 112 | from config import get_args 113 | args = get_args() 114 | run(args) 115 | -------------------------------------------------------------------------------- /utils/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import json 8 | import random 9 | import numpy as np 10 | from pathlib import Path 11 | 12 | from . import io 13 | from . import misc 14 | from . import logging 15 | 16 | PARAM_FNAME = "params.json" 17 | 18 | 19 | logger = logging.get_logger() 20 | 21 | def set_global_seed(seed, tensorflow=False, tf=False, pytorch=False): 22 | if tf or tensorflow: 23 | try: 24 | import tensorflow as tf 25 | except ImportError: 26 | pass 27 | else: 28 | tf.set_random_seed(seed) 29 | 30 | if pytorch: 31 | try: 32 | import torch as th 33 | except ImportError: 34 | pass 35 | else: 36 | th.manual_seed(seed) 37 | if th.cuda.is_available(): 38 | th.cuda.manual_seed(seed) 39 | 40 | random.seed(seed) 41 | np.random.seed(seed) 42 | 43 | def prepare_dirs(args): 44 | if args.load_path: 45 | if str(args.load_path).startswith(str(args.log_dir)): 46 | load_path = Path(args.load_path) 47 | else: 48 | load_path = Path(f"{str(args.log_dir)}/{args.load_path}") 49 | args.model_name = load_path.name 50 | else: 51 | model_desc = io.get_cmd().replace(" ", "|") 52 | hash_text = misc.get_hash(6) 53 | args.hash_text = hash_text 54 | 55 | args.model_name = f"{args.env}_{misc.get_time()}{model_desc}_{hash_text}" 56 | load_path = args.log_dir / args.model_name 57 | 58 | args.load_path = load_path 59 | io.makedirs(args.load_path) 60 | 61 | # create directories 62 | for key, path in vars(args).items(): 63 | if key.endswith('_dir') and not os.path.exists(str(path)): 64 | io.makedirs(path) 65 | 66 | def save_args(args): 67 | load_path = Path(args.load_path) 68 | param_path = load_path / PARAM_FNAME 69 | 70 | info = { k:str(v) if isinstance(v, Path) else v for k, v in args.__dict__.items() } 71 | with open(str(param_path), 'w') as f: 72 | json.dump(info, f, indent=4, sort_keys=True) 73 | 74 | cmd_path = load_path / "cmd.sh" 75 | with open(str(cmd_path), 'w') as f: 76 | f.write(io.get_cmd()) 77 | 78 | logger.info(f"Saved {PARAM_FNAME}: {param_path}") 79 | 80 | def update_args(args, key, new_value): 81 | load_path = Path(args.load_path) 82 | param_path = load_path / PARAM_FNAME 83 | 84 | if param_path.exists(): 85 | with open(param_path) as f: 86 | saved_args = json.load(f) 87 | 88 | original_value = saved_args.get(key, None) 89 | saved_args[key] = new_value 90 | 91 | with open(param_path, 'w') as f: 92 | json.dump(saved_args, f, indent=4, sort_keys=True) 93 | 94 | logger.info(f"Update {param_path}: {key} ({original_value} -> {new_value})") 95 | else: 96 | raise FileNotFoundError(f"{param_path} not exists") 97 | 98 | # XXX: actually `skip_list` is quite important during test time 99 | def load_args(args, skip_list=['load_path', 'test_epoch', 'test_dataset', 'train']): 100 | args_keys = vars(args).keys() 101 | args_path = os.path.join(args.load_path, PARAM_FNAME) 102 | 103 | with open(args_path) as f: 104 | saved_args = json.load(f) 105 | 106 | for saved_key, saved_value in saved_args.items(): 107 | if saved_key in skip_list: 108 | continue 109 | 110 | if hasattr(args, saved_key): 111 | args_value = 
getattr(args, saved_key) 112 | if args_value != saved_value: 113 | if isinstance(args_value, Path): 114 | # we don't have to print this 115 | saved_value = Path(saved_value) 116 | else: 117 | logger.info(f"[UPDATE] args `{saved_key}`: {args_value} -> {saved_value}") 118 | setattr(args, saved_key, saved_value) 119 | else: # if key is programmatically generated, skip and let the code make the key by itself 120 | pass 121 | 122 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import sys 7 | import logging 8 | 9 | BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) 10 | 11 | RESET_SEQ = "\033[0m" 12 | COLOR_SEQ = "\033[1;%dm" 13 | BOLD_SEQ = "\033[1m" 14 | 15 | 16 | class JacLogFormatter(logging.Formatter): 17 | # Code from https://github.com/vacancy/Jacinle/blob/master/jacinle/logging/logger.py 18 | log_fout = None 19 | date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] ' 20 | date = '%(asctime)s ' 21 | msg = '%(message)s' 22 | max_lines = 256 23 | 24 | def _color_dbg(self, msg): 25 | return '\x1b[36m{}\x1b[0m'.format(msg) 26 | 27 | def _color_warn(self, msg): 28 | return '\x1b[1;31m{}\x1b[0m'.format(msg) 29 | 30 | def _color_err(self, msg): 31 | return '\x1b[1;4;31m{}\x1b[0m'.format(msg) 32 | 33 | def _color_omitted(self, msg): 34 | return '\x1b[35m{}\x1b[0m'.format(msg) 35 | 36 | def _color_normal(self, msg): 37 | return msg 38 | 39 | def _color_date(self, msg): 40 | return '\x1b[32m{}\x1b[0m'.format(msg) 41 | 42 | def format(self, record): 43 | if record.levelno == logging.DEBUG: 44 | mcl, mtxt = self._color_dbg, '' 45 | elif record.levelno == logging.WARNING: 46 | mcl, mtxt = self._color_warn, '' 47 | elif record.levelno == logging.ERROR: 48 | mcl, mtxt = self._color_err, '' 49 | else: 50 | mcl, mtxt = self._color_normal, '' 51 | 52 | if mtxt: 53 | mtxt += ' ' 54 | 55 | if self.log_fout: 56 | self.__set_fmt(self.date_full + mtxt + self.msg) 57 | formatted = super(JacLogFormatter, self).format(record) 58 | nr_line = formatted.count('\n') + 1 59 | if nr_line >= self.max_lines: 60 | head, body = formatted.split('\n', 1) 61 | formatted = '\n'.join([ 62 | head, 63 | 'BEGIN_LONG_LOG_{}_LINES{{'.format(nr_line - 1), 64 | body, 65 | '}}END_LONG_LOG_{}_LINES'.format(nr_line - 1) 66 | ]) 67 | self.log_fout.write(formatted) 68 | self.log_fout.write('\n') 69 | self.log_fout.flush() 70 | 71 | self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg)) 72 | formatted = super(JacLogFormatter, self).format(record) 73 | nr_line = formatted.count('\n') + 1 74 | if nr_line >= self.max_lines: 75 | lines = formatted.split('\n') 76 | remain = self.max_lines//2 77 | removed = len(lines) - remain * 2 78 | if removed > 0: 79 | mid_msg = self._color_omitted( 80 | '[{} log lines omitted (would be written to output file if set_output_file() has been called;\n' 81 | ' the threshold can be set at TALogFormatter.max_lines)].'.format(removed)) 82 | formatted = '\n'.join( 83 | lines[:remain] + [mid_msg] + lines[-remain:]) 84 | 85 | return formatted 86 | 87 | if sys.version_info.major < 3: 88 | def __set_fmt(self, fmt): 89 | self._fmt = fmt 90 | else: 91 | def __set_fmt(self, fmt): 92 | self._style._fmt = fmt 93 | 94 | log_formatter = JacLogFormatter( 95 | "%(asctime)s:%(levelname)s::%(message)s", 
"%m-%d %H:%M:%S") 96 | 97 | 98 | def get_logger(name=__file__, level=logging.INFO): 99 | logger = logging.getLogger(name) 100 | 101 | if getattr(logger, '_init_done__', None): 102 | logger.setLevel(level) 103 | return logger 104 | 105 | logger._init_done__ = True 106 | logger.propagate = False 107 | logger.setLevel(level) 108 | 109 | del logger.handlers[:] 110 | 111 | handler = logging.StreamHandler() 112 | handler.setFormatter(log_formatter) 113 | 114 | logger.addHandler(handler) 115 | 116 | return logger 117 | 118 | logger = get_logger() 119 | 120 | def add_file_handler(args): 121 | handler = logging.FileHandler(f"{args.model_dir}/log.log", 'a') 122 | handler.setFormatter(log_formatter) 123 | logger.addHandler(handler) 124 | -------------------------------------------------------------------------------- /envs/mypaint_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | ## Curve Math 5 | def point_on_curve_1(t, cx, cy, sx, sy, x1, y1, x2, y2): 6 | ratio = t/100.0 7 | x3, y3 = multiply_add(sx, sy, x1, y1, ratio) 8 | x4, y4 = multiply_add(cx, cy, x2, y2, ratio) 9 | x5, y5 = difference(x3, y3, x4, y4) 10 | x, y = multiply_add(x3, y3, x5, y5, ratio) 11 | return x, y 12 | 13 | 14 | def point_on_curve_2(t, cx, cy, sx, sy, kx, ky, x1, y1, x2, y2, x3, y3): 15 | ratio = t/100.0 16 | x4, y4 = multiply_add(sx, sy, x1, y1, ratio) 17 | x5, y5 = multiply_add(cx, cy, x2, y2, ratio) 18 | x6, y6 = multiply_add(kx, ky, x3, y3, ratio) 19 | x1, y1 = difference(x4, y4, x5, y5) 20 | x2, y2 = difference(x5, y5, x6, y6) 21 | x4, y4 = multiply_add(x4, y4, x1, y1, ratio) 22 | x5, y5 = multiply_add(x5, y5, x2, y2, ratio) 23 | x1, y1 = difference(x4, y4, x5, y5) 24 | x, y = multiply_add(x4, y4, x1, y1, ratio) 25 | return x, y 26 | 27 | 28 | ## Ellipse Math 29 | def starting_point_for_ellipse(x, y, rotate): 30 | # Rotate starting point 31 | r = math.radians(rotate) 32 | sin = math.sin(r) 33 | cos = math.cos(r) 34 | x, y = rotate_ellipse(x, y, cos, sin) 35 | return x, y, sin, cos 36 | 37 | 38 | def point_in_ellipse(x, y, r_sin, r_cos, degree): 39 | # Find point in ellipse 40 | r2 = math.radians(degree) 41 | cos = math.cos(r2) 42 | sin = math.sin(r2) 43 | x = x * cos 44 | y = y * sin 45 | # Rotate Ellipse 46 | x, y = rotate_ellipse(y, x, r_sin, r_cos) 47 | return x, y 48 | 49 | 50 | def rotate_ellipse(x, y, sin, cos): 51 | x1, y1 = multiply(x, y, sin) 52 | x2, y2 = multiply(x, y, cos) 53 | x = x2 - y1 54 | y = y2 + x1 55 | return x, y 56 | 57 | 58 | ## Vector Math 59 | def get_angle(x1, y1, x2, y2): 60 | dot = dot_product(x1, y1, x2, y2) 61 | if abs(dot) < 1.0: 62 | angle = math.acos(dot) * 180/math.pi 63 | else: 64 | angle = 0.0 65 | return angle 66 | 67 | 68 | def constrain_to_angle(x, y, sx, sy): 69 | length, nx, ny = length_and_normal(sx, sy, x, y) 70 | # dot = nx*1 + ny*0 therefore nx 71 | angle = math.acos(nx) * 180/math.pi 72 | angle = constraint_angle(angle) 73 | ax, ay = angle_normal(ny, angle) 74 | x = sx + ax*length 75 | y = sy + ay*length 76 | return x, y 77 | 78 | 79 | def constraint_angle(angle): 80 | n = angle//15 81 | n1 = n*15 82 | rem = angle - n1 83 | if rem < 7.5: 84 | angle = n*15.0 85 | else: 86 | angle = (n+1)*15.0 87 | return angle 88 | 89 | 90 | def angle_normal(ny, angle): 91 | if ny < 0.0: 92 | angle = 360.0 - angle 93 | radians = math.radians(angle) 94 | x = math.cos(radians) 95 | y = math.sin(radians) 96 | return x, y 97 | 98 | 99 | def length_and_normal(x1, y1, x2, y2): 100 | x, y = difference(x1, y1, x2, y2) 101 | length = 
vector_length(x, y) 102 | if length == 0.0: 103 | x, y = 0.0, 0.0 104 | else: 105 | x, y = x/length, y/length 106 | return length, x, y 107 | 108 | 109 | def normal(x1, y1, x2, y2): 110 | junk, x, y = length_and_normal(x1, y1, x2, y2) 111 | return x, y 112 | 113 | 114 | def vector_length(x, y): 115 | length = math.sqrt(x*x + y*y) 116 | return length 117 | 118 | 119 | def distance(x1, y1, x2, y2): 120 | x, y = difference(x1, y1, x2, y2) 121 | length = vector_length(x, y) 122 | return length 123 | 124 | 125 | def dot_product(x1, y1, x2, y2): 126 | return x1*x2 + y1*y2 127 | 128 | 129 | def multiply_add(x1, y1, x2, y2, d): 130 | x3, y3 = multiply(x2, y2, d) 131 | x, y = add(x1, y1, x3, y3) 132 | return x, y 133 | 134 | 135 | def multiply(x, y, d): 136 | # Multiply vector 137 | x = x*d 138 | y = y*d 139 | return x, y 140 | 141 | 142 | def add(x1, y1, x2, y2): 143 | # Add vectors 144 | x = x1+x2 145 | y = y1+y2 146 | return x, y 147 | 148 | 149 | def difference(x1, y1, x2, y2): 150 | # Difference in x and y between two points 151 | x = x2-x1 152 | y = y2-y1 153 | return x, y 154 | 155 | 156 | def midpoint(x1, y1, x2, y2): 157 | # Midpoint between to points 158 | x = (x1+x2)/2.0 159 | y = (y1+y2)/2.0 160 | return x, y 161 | 162 | 163 | def perpendicular(x1, y1): 164 | # Swap x and y, then flip one sign to give vector at 90 degree 165 | x = -y1 166 | y = x1 167 | return x, y 168 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import argparse 7 | from pathlib import Path 8 | 9 | import utils as ut 10 | 11 | logger = ut.logging.get_logger() 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | # model 17 | model_arg = ut.args.add_argument_group(parser, 'model') 18 | model_arg.add_argument('--lstm_size', default=256, type=int) 19 | model_arg.add_argument('--scale', default=1.0, type=float) 20 | model_arg.add_argument('--z_dim', default=10, type=int) 21 | model_arg.add_argument('--dynamic_channel', default=False, type=ut.args.str2bool) 22 | model_arg.add_argument('--disc_dim', default=64, type=int) 23 | model_arg.add_argument('--disc_batch_norm', default=True, type=ut.args.str2bool) 24 | model_arg.add_argument('--loss', default='gan', type=str, 25 | choices=['l2', 'gan']) 26 | 27 | 28 | # environment 29 | env_arg = ut.args.add_argument_group(parser, 'environment') 30 | env_arg.add_argument('--env', default="simple_mnist") 31 | env_arg.add_argument('--jump', default=True, type=ut.args.str2bool) 32 | env_arg.add_argument('--curve', default=True, type=ut.args.str2bool) 33 | env_arg.add_argument('--episode_length', default=5, type=int) 34 | env_arg.add_argument('--screen_size', default=64, type=int) 35 | env_arg.add_argument('--location_size', default=32, type=int) 36 | env_arg.add_argument('--color_channel', default=3, type=int, choices=[3, 1]) 37 | env_arg.add_argument('--mnist_nums', default='0,1,2,3,4,5,6,7,8,9', type=ut.args.int_list) 38 | env_arg.add_argument('--brush_path', default='assets/brushes/dry_brush.myb', type=str) 39 | env_arg.add_argument('--conditional', default=True, type=ut.args.str2bool) 40 | 41 | 42 | # train 43 | train_arg = ut.args.add_argument_group(parser, 'train') 44 | train_arg.add_argument('--policy_lr', default=1e-5, type=float) 45 | train_arg.add_argument('--disc_lr', 
                       default=1e-4, type=float)
train_arg.add_argument('--clip_disc_weights', default=False, type=ut.args.str2bool)
train_arg.add_argument('--entropy_coeff', default=0.01, type=float)
train_arg.add_argument('--grad_clip', default=40, type=int)
train_arg.add_argument('--policy_batch_size', default=64, type=int)
train_arg.add_argument('--disc_batch_size', default=64, type=int)
train_arg.add_argument('--replay_size', default=10, type=int)
train_arg.add_argument('--buffer_batch_num', default=20, type=int)
train_arg.add_argument('--wgan_lambda', default=20, type=float)
train_arg.add_argument('--train', default=True, type=ut.args.str2bool)


# distributed
dist_arg = ut.args.add_argument_group(parser, 'distributed')
dist_arg.add_argument('--task', default=0, type=int)
dist_arg.add_argument('--job_name', default="worker")
dist_arg.add_argument('--num_workers', default=4, type=int)
dist_arg.add_argument('--start_port', default=13333, type=int)
dist_arg.add_argument('--tag', default='spiral', type=str)


# Misc
misc_arg = ut.args.add_argument_group(parser, 'misc')
misc_arg.add_argument('--debug', type=ut.args.str2bool, default=False)
misc_arg.add_argument('--num_gpu', type=int, default=1,
                      choices=[0, 1, 2])
misc_arg.add_argument('--policy_log_step', type=int, default=20)
misc_arg.add_argument('--disc_log_step', type=int, default=50)
misc_arg.add_argument('--data_dir', type=Path, default='.data')
misc_arg.add_argument('--log_dir', type=Path, default='logs')
misc_arg.add_argument('--load_path', type=Path, default=None)
misc_arg.add_argument('--log_level', type=str, default='INFO',
                      choices=['INFO', 'DEBUG', 'WARN'])
misc_arg.add_argument('--seed', type=int, default=123)
misc_arg.add_argument('--dry_run', action='store_true')
misc_arg.add_argument('--tb_port', type=int, default=12345)


def get_args(group_name=None, parse_unknown=False):
    if parse_unknown:
        args, unknown = parser.parse_known_args()
    else:
        args = parser.parse_args()

    ##############################
    # Preprocess or filter args
    ##############################
    if args.loss == 'gan':
        args.conditional = False
        assert args.num_workers > 2, "num_workers should be larger than 2 (policy, discriminator, worker)"
    elif args.loss == 'l2':
        args.conditional = True
        assert args.num_workers > 1, "num_workers should be larger than 1 (policy, worker)"

    if parse_unknown:
        return args, unknown
    return args
--------------------------------------------------------------------------------

/trainer.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import sys, signal
import tensorflow as tf

import utils as ut
from agent import Agent
from envs import create_env

logger = ut.logging.get_logger()


def train(args, server, cluster, env, queue_shapes,
          trajectory_queue_size, replay_queue_size):

    agent = Agent(args, server, cluster, env, queue_shapes,
                  trajectory_queue_size, replay_queue_size)

    # Variable names that start with "local" are not saved in checkpoints.
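    # Each non-chief task keeps such "local" copies and refreshes them from
    # the shared parameters with a sync op; a sketch using utils/tf.py (the
    # scope names here are illustrative):
    #
    #     sync = ut.tf.get_sync_op(
    #         from_list=ut.tf.get_all_variables("global"),
    #         to_list=ut.tf.get_all_variables("local"))
    #     sess.run(sync)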
    variables_to_save = [
        v for v in tf.global_variables() if not v.name.startswith("local")]

    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = ut.tf.FastSaver(variables_to_save)

    var_list = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)

    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    devices = ["/job:ps"]
    if args.task == 0:
        devices += ["/job:worker/task:{}/gpu:0".format(args.task),
                    "/job:worker/task:{}/cpu:0".format(args.task)]
    elif args.task == 1:
        devices += ["/job:worker/task:{}/gpu:{}".format(args.task, 1 if args.num_gpu > 1 else 0),
                    "/job:worker/task:{}/cpu:0".format(args.task)]
    else:
        devices += ["/job:worker/task:{}/cpu:0".format(args.task)]

    config = tf.ConfigProto(device_filters=devices, allow_soft_placement=True)
    logger.info("Events directory: %s_%s", args.load_path, args.task)

    summary_writer = tf.summary.FileWriter(
        "{}_{}".format(args.load_path, args.task))
    agent.summary_writer = summary_writer

    uninitialized_variables = tf.report_uninitialized_variables(variables_to_save)

    if args.task == 1 and args.loss == 'gan':
        local_init_op = tf.variables_initializer(agent.local_disc.var_list)
    else:
        local_init_op = None

    sv = tf.train.Supervisor(
        is_chief=args.task == 0,
        logdir=str(args.load_path),
        saver=saver,
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        local_init_op=local_init_op,
        summary_writer=summary_writer,
        # very useful when sv.managed_session hangs
        #ready_op=tf.constant([], dtype=tf.string),
        ready_op=uninitialized_variables,
        global_step=agent.policy_step,
        save_model_secs=30,
        save_summaries_secs=30)

    num_policy_steps = 100000000

    logger.info(
        "Starting session. If this hangs, we're most likely waiting"
        " to connect to the parameter server. One common cause is that"
        " the parameter server DNS name isn't resolving yet, or is misspecified.")

    with sv.managed_session(server.target, config=config) as sess, \
            sess.as_default():

        ###############################
        # Run thread
        ###############################

        if args.task == 1 and args.loss == 'gan':
            # master_disc -> local_disc
            sess.run(agent.disc_initializer)
            agent.start_replay_thread(sess, summary_writer)
        elif args.task >= 1:
            sess.run(agent.policy_sync)
            agent.start_worker_thread(sess, summary_writer)

        policy_step = sess.run(agent.policy_step)
        logger.info("Starting training at step=%d", policy_step)

        while not sv.should_stop() and ( \
                not num_policy_steps or policy_step < num_policy_steps):

            if args.task == 0:
                agent.train_policy(sess)
            elif args.task == 1 and args.loss == 'gan':
                # local_disc -> master_disc
                sess.run(agent.disc_sync)
                agent.train_gan(sess)
            else:
                sess.run(agent.policy_sync)
            policy_step = sess.run(agent.policy_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', policy_step)
--------------------------------------------------------------------------------
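For reference, the `cluster_spec` helper in utils/tf.py (above) defines the process layout that `trainer.train` connects to; a worked example of its output:

```python
from utils.tf import cluster_spec

# one parameter server plus two workers on consecutive ports
print(cluster_spec(num_workers=2, num_ps=1, port=12222))
# {'ps': ['127.0.0.1:12222'], 'worker': ['127.0.0.1:12223', '127.0.0.1:12224']}
```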

/envs/simple.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from PIL import Image, ImageDraw

from . import utils
from .base import Environment


class Simple(Environment):

    action_sizes = {
        'color': [1],
        'shape': [2],
        #'location': None,
    }

    def __init__(self, args):
        super(Simple, self).__init__(args)

        assert self.conditional, \
            "Don't train a Simple env with --conditional=False"

        # object
        self.object_radius = 3
        self.object_size = self.object_radius * 2
        self.background_color = (255, 255, 255)

        self.colors = [
            (0, 0, 0),        # black
            (173, 181, 189),  # gray 5
            (255, 224, 102),  # yellow 3
            #(102, 217, 232), # cyan 3
        ]
        if 'color' in self.action_sizes:
            self.colors = self.colors[:self.action_sizes['color'][0]]

        self.shapes = ['circle', 'rectangle']
        if 'shape' in self.action_sizes:
            self.shapes = self.shapes[:self.action_sizes['shape'][0]]

        self.locations = utils.uniform_locations(
            self.screen_size, self.location_size,
            self.object_radius)

        self.image = None
        self.drawer = None
        self.random_target = None

    def reset(self):
        self.random_target = self.get_random_target()
        self.image = Image.new(
            'RGB', (self.width, self.height), self.background_color)
        self.drawer = ImageDraw.Draw(self.image)
        self._step = 0

        # TODO(taehoon): z
        self.z = None
        return self.state, self.random_target, self.z

    def draw(self, ac, drawer=None):
        if drawer is None:
            drawer = self.drawer

        r = self.object_radius
        x, y = self.locations[0]
        color, shape = self.colors[0], self.shapes[0]

        for name in self.action_sizes:
            named_ac = ac[self.ac_idx[name]]
            value = getattr(self, name + "s")[named_ac]

            if name == 'location':
                x, y = value
            elif name == 'color':
                color = value
            elif name == 'shape':
                shape = value

        if shape == 'circle':
            drawer.ellipse((x-r, y-r, x+r, y+r), fill=color)
        elif shape == 'rectangle':
            drawer.rectangle((x-r, y-r, x+r, y+r), fill=color)
        else:
            raise Exception("Unknown shape: {}".format(shape))

    def get_random_target(self):
        image = Image.new(
            'RGB', (self.width, self.height), self.background_color)
        drawer = ImageDraw.Draw(image)

        locations = []
        for _ in range(self.episode_length):
            ac = self.random_action(locations=locations)
            self.draw(ac, drawer)
            if 'location' in self.ac_idx:
                locations.append(ac[self.ac_idx['location']])
            else:
                locations.append(self.locations[0])

        return np.array(self.norm(image))

    def random_action(self, locations=None):
        # A mutable default argument would be shared across calls, so the
        # empty list is created here instead.
        if locations is None:
            locations = []
        action = []
        for name in self.acs:
            size = self.action_sizes[name]
            while True:
                sample = np.random.randint(np.prod(size))
                # Resample only when a location was already used; the original
                # check compared against 'locations', which never matched.
                if name == 'location' and sample in locations:
                    continue
                break
            action.append(sample)
        return action

    def step(self, acs):
        self.draw(acs, self.drawer)
        self._step += 1
        terminal = (self._step == self.episode_length)
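        # Worked instance of the conditional reward below: for a 64x64 canvas
        # with --color_channel=1, reward = -l2(state, target) / (64*64*1) * 100.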
123 | if terminal: 124 | if self.conditional: 125 | reward = - utils.l2(self.state, self.random_target) \ 126 | / np.prod(self.observation_shape) * 100 127 | else: 128 | reward = None 129 | else: 130 | reward = 0 131 | 132 | # XXX: DEBUG -- forces a +1 reward on every non-terminal step; drop this for real training 133 | if reward == 0: reward = 1 134 | 135 | # state, reward, terminal, info 136 | return self.state, reward, terminal, {} 137 | 138 | def save_image(self, path): 139 | self.image.save(path) 140 | 141 | @property 142 | def state(self): 143 | return np.array(self.norm(self.image)) 144 | 145 | 146 | if __name__ == '__main__': 147 | from config import get_args 148 | args = get_args() 149 | 150 | env = Simple(args) 151 | 152 | for ep_idx in range(10): 153 | step = 0 154 | env.reset() 155 | 156 | while True: 157 | action = env.random_action() 158 | print("[Step {}] ac: {}".format(step, action)) 159 | state, reward, terminal, info = env.step(action) 160 | step += 1 161 | 162 | if terminal: 163 | print("Ep #{} finished.".format(ep_idx)) 164 | env.save_image("simple{}.png".format(ep_idx)) 165 | break 166 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | # Code based on https://github.com/tensorflow/models/blob/master/research/slim/nets/dcgan.py 2 | from math import log 3 | import tensorflow as tf 4 | 5 | import utils as ut 6 | 7 | tl = tf.layers 8 | logger = ut.logging.get_logger() 9 | 10 | 11 | class Discriminator(object): 12 | def __init__(self, args, step, image_shape, norm_fn, scope_name): 13 | self.args = args 14 | self.step = step 15 | self.scope_name = scope_name 16 | 17 | self.fake = fake = tf.placeholder( 18 | tf.float32, [None] + list(image_shape), name='fake') 19 | self.real = real = tf.placeholder( 20 | tf.float32, [None] + list(image_shape), name='real') 21 | 22 | # NHWC -> NCHW 23 | if args.num_gpu > 0: 24 | fake = tf.transpose(fake, [0, 3, 1, 2]) 25 | real = tf.transpose(real, [0, 3, 1, 2]) 26 | self.data_format = "channels_first" 27 | channel_idx = 1 28 | else: 29 | self.data_format = "channels_last" 30 | channel_idx = -1 31 | 32 | if norm_fn is not None: 33 | fake = norm_fn(fake) 34 | real = norm_fn(real) 35 | 36 | if self.args.conditional: 37 | fake = tf.concat([fake, real], axis=channel_idx) # conditional D sees a (canvas, target) pair 38 | real = tf.concat([real, real], axis=channel_idx) # the "real" pair is (target, target) 39 | 40 | self.fake_in = fake 41 | self.real_in = real 42 | 43 | 44 | self.real_probs, self.real_logits = self.build_model(self.real_in) 45 | self.var_list = tf.trainable_variables(self.scope_name) 46 | 47 | self.fake_probs, self.fake_logits = self.build_model(self.fake_in, reuse=True) 48 | 49 | self.build_optim() 50 | 51 | def build_model(self, 52 | inputs, 53 | is_training=True, 54 | reuse=False): 55 | 56 | inp_shape = inputs.get_shape().as_list()[2] 57 | 58 | with tf.variable_scope(self.scope_name, values=[inputs], reuse=reuse) as scope: 59 | x = inputs 60 | 61 | layer_num = int(log(inp_shape, 2)) 62 | for idx in range(layer_num): 63 | cur_depth = self.args.disc_dim * 2**idx 64 | 65 | x = tl.conv2d( 66 | x, cur_depth, 5, 67 | strides=(2, 2), 68 | padding='same', 69 | activation=None, 70 | data_format=self.data_format, 71 | kernel_initializer=tf.keras.initializers.he_normal(), 72 | name="conv{}".format(idx)) 73 | 74 | logger.info("conv: {} ({})".format(x.name, x.get_shape())) 75 | 76 | if idx > 0 and self.args.disc_batch_norm: 77 | x = tl.batch_normalization( 78 | x, axis=1 if self.data_format == "channels_first" else -1, 79 | fused=True, training=True) 80 | 81 | x = 
tf.nn.leaky_relu(x) 82 | 83 | x = tl.flatten(x) 84 | logits = tl.dense( 85 | x, 1, 86 | activation=None, 87 | kernel_initializer=tf.keras.initializers.glorot_normal(), 88 | name="dense") 89 | 90 | logger.info("logits: {} ({})".format(logits.name, logits.get_shape())) 91 | 92 | logits = tf.reshape(logits, [-1]) 93 | probs = tf.nn.sigmoid(logits) 94 | 95 | return probs, logits 96 | 97 | def build_optim(self): 98 | self.g_loss = -tf.reduce_mean(self.fake_logits) # generator loss: raise the critic's score on fakes 99 | self.critic_loss = \ 100 | tf.reduce_mean(self.fake_logits) - tf.reduce_mean(self.real_logits) 101 | 102 | alpha = tf.random_uniform( 103 | [self.args.disc_batch_size, 1], 104 | minval=0.0, 105 | maxval=1.0) 106 | 107 | fake_data = tl.flatten(self.fake_in) 108 | real_data = tl.flatten(self.real_in) 109 | 110 | differences = fake_data - real_data 111 | interpolates = real_data + (alpha*differences) 112 | 113 | diff_in = tf.reshape(interpolates, ut.tf.int_shape(self.fake_in)) 114 | diff_probs, diff_logits = self.build_model(diff_in, reuse=True) 115 | # WGAN-GP penalizes the gradient of the critic output (the logits, not the sigmoid probs) 116 | gradients = tf.gradients(diff_logits, [interpolates])[0] 117 | slopes = tf.sqrt(1e-8+tf.reduce_sum( 118 | tf.square(gradients), axis=[1])) 119 | 120 | self.gradient_penalty = tf.reduce_mean((slopes-1.)**2) 121 | self.d_loss = self.critic_loss + \ 122 | self.args.wgan_lambda * self.gradient_penalty 123 | 124 | self.opt = tf.train.AdamOptimizer( 125 | self.args.disc_lr, beta1=0.5, beta2=0.9, 126 | name="disc_optim") 127 | 128 | if self.args.disc_batch_norm: 129 | update_ops = tf.get_collection( 130 | tf.GraphKeys.UPDATE_OPS, scope=self.scope_name) 131 | with tf.control_dependencies(update_ops): 132 | self.train_op = self.opt.minimize( 133 | self.d_loss, self.step, var_list=self.var_list) 134 | else: 135 | self.train_op = self.opt.minimize( 136 | self.d_loss, self.step, var_list=self.var_list) 137 | 138 | self.var_list += self.opt.variables() 139 | 140 | scope_vars = tf.get_collection( 141 | tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope_name) 142 | self.var_list += [ 143 | v for v in scope_vars if "batch_normalization" in v.name] 144 | 145 | ################## 146 | # Summaries 147 | ################## 148 | 149 | if self.scope_name == 'local': 150 | self.summary_op = tf.summary.merge([ 151 | tf.summary.scalar("gan/critic_loss", self.critic_loss), 152 | tf.summary.scalar("gan/disc_loss", self.d_loss), 153 | tf.summary.scalar("gan/penalty", self.gradient_penalty), 154 | tf.summary.scalar("gan/gen_loss", self.g_loss), 155 | ]) 156 | 157 | def predict(self, images): 158 | sess = tf.get_default_session() 159 | feed_dict = { 160 | self.real: images, 161 | } 162 | probs = sess.run(self.real_probs, feed_dict) 163 | return probs 164 | 165 | 166 | if __name__ == '__main__': 167 | from config import get_args 168 | args = get_args() 169 | 170 | noise = tf.random_normal([args.disc_batch_size, 128]) 171 | 172 | x = tl.dense( 173 | noise, 4*4*4*2048, 174 | activation=tf.nn.relu, 175 | kernel_initializer=tf.keras.initializers.glorot_normal(), 176 | name="dense") 177 | 178 | print(x) # smoke test: shows the tensor built from random noise 179 | -------------------------------------------------------------------------------- /rl_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import threading 6 | import numpy as np 7 | import scipy.signal 8 | from collections import namedtuple 9 | import tensorflow as tf # needed below by env_runner for tf.Summary 10 | import utils as ut 11 | 12 | 13 | logger = ut.logging.get_logger() 14 | 15 | Batch = 
namedtuple("Batch", ["si", "a", "adv", "r", "features", "c", "z"]) 16 | 17 | 18 | def discount(x, gamma): 19 | return scipy.signal.lfilter( 20 | [1], [1, -gamma], x[:,::-1], axis=1)[:,::-1] 21 | 22 | def flatten_first_two(x): 23 | return np.reshape(x, [-1] + list(x.shape)[2:]) 24 | 25 | def multiple_process_rollout(rollout, gamma, lambda_=1.0): 26 | """ 27 | given a rollout, compute its returns and the advantage 28 | """ 29 | batch_si = np.asarray(rollout['states']) 30 | batch_a = np.asarray(rollout['actions']) 31 | rewards = np.asarray(rollout['rewards']) 32 | vpred_t = np.hstack( 33 | [rollout['values'][:,:,0], np.expand_dims(rollout['r'], -1)]) 34 | 35 | rewards_plus_v = np.hstack( 36 | [rollout['rewards'], np.expand_dims(rollout['r'], -1)]) 37 | batch_r = discount(rewards_plus_v, gamma)[:,:-1] 38 | delta_t = rewards + gamma * vpred_t[:,1:] - vpred_t[:,:-1] 39 | 40 | batch_adv = discount(delta_t, gamma * lambda_) 41 | 42 | features = rollout['features'][:,0] 43 | 44 | if 'conditions' in rollout: 45 | batch_c = np.asarray(rollout['conditions']) 46 | batch_z = None 47 | else: 48 | batch_c = None 49 | batch_z = np.asarray(rollout['z']) 50 | 51 | #batch_a = flatten_first_two(batch_a) 52 | #batch_r = flatten_first_two(batch_r) 53 | #batch_si = flatten_first_two(batch_si) 54 | #batch_adv = flatten_first_two(batch_adv) 55 | #features = features[:,:,0,:] 56 | 57 | return Batch(batch_si, batch_a, batch_adv, batch_r, features, batch_c, batch_z) 58 | 59 | 60 | class PartialRollout(object): 61 | """ 62 | a piece of a complete rollout. We run our agent, and process its experience 63 | once it has processed enough steps. 64 | """ 65 | def __init__(self): 66 | self.states = [] 67 | self.actions = [] 68 | self.rewards = [] 69 | self.values = [] 70 | self.r = 0.0 71 | self.features = [] 72 | self.conditions = None 73 | self.z = None 74 | 75 | def add(self, state, action, reward, value, features, conditions=None, z=None): 76 | self.states += [state] 77 | self.actions += [action] 78 | self.rewards += [reward] 79 | self.values += [value] 80 | self.features += [features] 81 | 82 | if conditions is not None: 83 | if self.conditions is None: 84 | self.conditions = [] 85 | self.conditions += [conditions] 86 | 87 | if z is not None: 88 | if self.z is None: 89 | self.z = [] 90 | self.z += [z] 91 | 92 | 93 | class WorkerThread(threading.Thread): 94 | def __init__(self, env, policy, 95 | traj_enqueues, traj_placeholders, traj_size, 96 | replay_enqueue, replay_placeholder, replay_size): 97 | threading.Thread.__init__(self) 98 | 99 | self.env = env 100 | self.sess = None 101 | self.daemon = True 102 | self.policy = policy 103 | self.last_features = None 104 | self.summary_writer = None 105 | self.num_local_steps = env.episode_length 106 | 107 | self.traj_enqueues = traj_enqueues 108 | self.traj_placeholders = traj_placeholders 109 | self.traj_size = traj_size 110 | 111 | self.replay_enqueue = replay_enqueue 112 | self.replay_placeholder = replay_placeholder 113 | self.replay_size = replay_size 114 | 115 | def start_thread(self, sess, summary_writer): 116 | self.sess = sess 117 | self.summary_writer = summary_writer 118 | self.start() 119 | 120 | def run(self): 121 | with self.sess.as_default(): 122 | self._run() 123 | 124 | def _run(self): 125 | rollout_provider = env_runner( 126 | self.env, self.policy, 127 | self.num_local_steps, self.summary_writer) 128 | while True: 129 | out = next(rollout_provider) 130 | 131 | feed_dict = { 132 | self.traj_placeholders['actions']: out.actions, 133 | 
self.traj_placeholders['states']: out.states, 134 | self.traj_placeholders['rewards']: out.rewards, 135 | self.traj_placeholders['values']: out.values, 136 | self.traj_placeholders['features']: out.features, 137 | self.traj_placeholders['r']: out.r, 138 | } 139 | if self.env.conditional: 140 | feed_dict.update({ 141 | self.traj_placeholders['conditions']: out.conditions, 142 | }) 143 | else: 144 | feed_dict.update({ 145 | self.traj_placeholders['z']: out.z, 146 | }) 147 | 148 | for k, v in feed_dict.items(): 149 | if isinstance(v, list): 150 | feed_dict[k] = np.array(v) 151 | 152 | fetches = [ 153 | self.traj_enqueues, 154 | ] 155 | if self.replay_enqueue is not None: 156 | fetches.append(self.replay_enqueue) 157 | feed_dict.update({ 158 | self.replay_placeholder: out.states[-1], 159 | }) 160 | 161 | out = self.sess.run(fetches, feed_dict) 162 | 163 | 164 | class ReplayThread(threading.Thread): 165 | def __init__(self, replay, replay_dequeue): 166 | threading.Thread.__init__(self) 167 | 168 | self.replay = replay 169 | self.replay_dequeue = replay_dequeue 170 | 171 | def start_thread(self, sess): 172 | self.sess = sess 173 | self.start() 174 | 175 | def run(self): 176 | with self.sess.as_default(): 177 | self._run() 178 | 179 | def _run(self): 180 | while True: 181 | generated = self.sess.run(self.replay_dequeue) 182 | self.replay.push(generated) 183 | 184 | 185 | def env_runner(env, policy, num_local_steps, summary_writer): 186 | last_state, condition, z = env.reset() 187 | last_features = policy.get_initial_features(1, flat=True) 188 | 189 | length = 0 190 | rewards = 0 191 | 192 | while True: 193 | rollout = PartialRollout() 194 | 195 | last_action = env.initial_action 196 | 197 | for _ in range(num_local_steps): 198 | c, h = last_features 199 | 200 | fetched = policy.act( 201 | last_state, last_action, c, h, condition, z) 202 | action, value_, features = fetched[0], fetched[1], fetched[2:4] 203 | 204 | action = [np.argmax(action[name]) for name in env.acs] 205 | state, reward, terminal, info = env.step(action) 206 | 207 | # collect the experience 208 | rollout.add(last_state, action, reward, 209 | value_, last_features, condition, z) 210 | length += 1 211 | 212 | # TODO: discriminator communication to get reward 213 | rewards += reward 214 | 215 | last_state = state 216 | last_action = action 217 | last_features = features 218 | 219 | if info: 220 | summary = tf.Summary() 221 | for k, v in info.items(): 222 | summary.value.add(tag=k, simple_value=float(v)) 223 | summary_writer.add_summary(summary, policy.global_step.eval()) 224 | summary_writer.flush() 225 | 226 | last_state, condition, z = env.reset() 227 | logger.debug( 228 | "Episode finished. Sum of rewards: {:.5f}." 
\ 229 | " Length: {}.".format(rewards, length)) 230 | 231 | length = 0 232 | rewards = 0 233 | 234 | rollout.states += [state] 235 | 236 | # once we have enough experience, yield it, 237 | # and have the ThreadRunner place it on a queue 238 | yield rollout 239 | -------------------------------------------------------------------------------- /envs/mnist.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | import numpy as np 7 | from tqdm import tqdm 8 | import tensorflow as tf 9 | from PIL import Image, ImageDraw 10 | from collections import defaultdict 11 | 12 | # MyPaint 13 | sys.path.append('libs/mypaint') 14 | from lib import surface, tiledsurface, brush 15 | 16 | import utils as ut 17 | from . import utils 18 | from .mypaint_utils import * 19 | from .base import Environment 20 | 21 | 22 | class MNIST(Environment): 23 | head = 0.25 24 | tail = 0.75 25 | 26 | action_sizes = { 27 | 'pressure': [2], 28 | 'jump': [2], 29 | 'size': [2], 30 | 'control': None, 31 | 'end': None, 32 | } 33 | 34 | size = 0.2 35 | pressure = 0.3 36 | 37 | def __init__(self, args): 38 | super(MNIST, self).__init__(args) 39 | self.mnist_nums = args.mnist_nums 40 | self.colorize = not args.train 41 | 42 | self.prepare_mnist() 43 | 44 | # jump 45 | self.jumps = [0, 1] 46 | 47 | # size 48 | self.sizes = np.arange(0.2, 2.0, 0.5) 49 | 50 | if 'size' in self.action_sizes: 51 | self.sizes = \ 52 | self.sizes[:self.action_sizes['size'][0]] 53 | 54 | # pressure 55 | self.pressures = np.arange(0.8, 0, -0.3) 56 | if 'pressure' in self.action_sizes: 57 | self.pressures = \ 58 | self.pressures[:self.action_sizes['pressure'][0]] 59 | 60 | self.colors = [ 61 | (0., 0., 0.), # black 62 | (102., 217., 232.), # cyan 3 63 | (173., 181., 189.), # gray 5 64 | (255., 224., 102.), # yellow 3 65 | (229., 153., 247.), # grape 3 66 | (99., 230., 190.), # teal 3 67 | (255., 192., 120.), # orange 3 68 | (255., 168., 168.), # red 3 69 | ] 70 | self.colors = np.array(self.colors) / 255.
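# the palette is scaled to [0, 1] floats because brushinfo.set_color_rgb (used in draw below) takes unit-range RGB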
71 | 72 | if 'color' in self.action_sizes: 73 | self.colors = self.colors[:self.action_sizes['color'][0]] 74 | 75 | self.controls = utils.uniform_locations( 76 | self.screen_size, self.location_size, 0) 77 | 78 | self.ends = utils.uniform_locations( 79 | self.screen_size, self.location_size, 0) 80 | 81 | def reset(self): 82 | self.entry_pressure = np.min(self.pressures) 83 | 84 | if self.conditional: 85 | self.random_target = self.get_random_target(num=1, squeeze=True) 86 | else: 87 | self.random_target = None 88 | 89 | self.s = tiledsurface.Surface() 90 | self.s.flood_fill(0, 0, (255, 255, 255), (0, 0, self.width, self.height), 0, self.s) # fill the whole canvas white (was hard-coded to 64x64) 91 | self.s.begin_atomic() 92 | 93 | with open(self.args.brush_path) as fp: 94 | self.bi = brush.BrushInfo(fp.read()) 95 | self.b = brush.Brush(self.bi) 96 | 97 | self._step = 0 98 | self.s_x, self.s_y = None, None 99 | 100 | if self.args.conditional: 101 | self.z = None 102 | else: 103 | self.z = np.random.uniform(-0.5, 0.5, size=self.args.z_dim) 104 | 105 | return self.state, self.random_target, self.z 106 | 107 | def draw(self, ac, s=None, dtime=1): 108 | if s is None: 109 | s = self.s 110 | 111 | jump = 0 112 | x, y = self.ends[0] 113 | c_x, c_y = self.controls[0] 114 | color = self.colors[0] 115 | pressure, size = self.pressure, self.size 116 | 117 | for name in self.action_sizes: 118 | named_ac = ac[self.ac_idx[name]] 119 | value = getattr(self, name + "s")[named_ac] 120 | 121 | if name == 'end': 122 | x, y = value 123 | elif name == 'control': 124 | c_x, c_y = value 125 | elif name == 'pressure': 126 | pressure = value 127 | elif name == 'size': 128 | size = value 129 | elif name == 'jump': 130 | jump = value 131 | 132 | if self.colorize: 133 | self.b.brushinfo.set_color_rgb(self.colors[self._step]) 134 | if 'size' in self.action_sizes: 135 | self.b.brushinfo.set_base_value('radius_logarithmic', size) 136 | 137 | if (self.s_x is None and self.s_y is None): 138 | # when self._step == 0 139 | pressure = 0 140 | self.s_x, self.s_y = 0, 0 141 | self._stroke_to(self.s_x, self.s_y, pressure) 142 | elif 'jump' in self.action_sizes and jump: 143 | pressure = 0 144 | self._stroke_to(self.s_x, self.s_y, pressure) 145 | else: 146 | self._stroke_to(self.s_x, self.s_y, pressure) 147 | 148 | self._draw(x, y, c_x, c_y, pressure, size, color, dtime) 149 | 150 | def _draw(self, x, y, c_x, c_y, 151 | pressure, size, color, dtime): 152 | end_pressure = pressure 153 | 154 | # if straight line or jump 155 | if 'control' not in self.action_sizes or pressure == 0: 156 | self.b.stroke_to( 157 | self.s.backend, x, y, pressure, 0, 0, dtime) 158 | else: 159 | end_pressure = self.curve( 160 | c_x, c_y, self.s_x, self.s_y, x, y, pressure) 161 | 162 | self.entry_pressure = end_pressure 163 | 164 | self.s_x, self.s_y = x, y 165 | 166 | self.s.end_atomic() 167 | self.s.begin_atomic() 168 | 169 | # sx, sy = starting point 170 | # ex, ey = end point 171 | # kx, ky = curve point from last line 172 | # lx, ly = last point from InteractionMode update 173 | def curve(self, cx, cy, sx, sy, ex, ey, pressure): 174 | #entry_p, midpoint_p, junk, prange2, head, tail 175 | entry_p, midpoint_p, prange1, prange2, h, t = \ 176 | self._line_settings(pressure) 177 | 178 | points_in_curve = 100 179 | mx, my = midpoint(sx, sy, ex, ey) 180 | length, nx, ny = length_and_normal(mx, my, cx, cy) 181 | cx, cy = multiply_add(mx, my, nx, ny, length*2) 182 | x1, y1 = difference(sx, sy, cx, cy) 183 | x2, y2 = difference(cx, cy, ex, ey) 184 | head = points_in_curve * h 185 | head_range = int(head)+1 186 | tail = points_in_curve * t 187 | tail_range = int(tail)+1 188 | tail_length = points_in_curve - tail 189 | 190 | # Beginning 191 | px, py = point_on_curve_1(1, cx, cy, sx, sy, x1, y1, x2, y2) 192 | length, nx, ny = length_and_normal(sx, sy, px, py) 193 | bx, by = multiply_add(sx, sy, nx, ny, 0.25) 194 | self._stroke_to(bx, by, entry_p) 195 | pressure = abs(1/head * prange1 + entry_p) 196 | self._stroke_to(px, py, pressure) 197 | 198 | for i in range(2, head_range): 199 | px, py = point_on_curve_1(i, cx, cy, sx, sy, x1, y1, x2, y2) 200 | pressure = abs(i/head * prange1 + entry_p) 201 | self._stroke_to(px, py, pressure) 202 | 203 | # Middle 204 | for i in range(head_range, tail_range): 205 | px, py = point_on_curve_1(i, cx, cy, sx, sy, x1, y1, x2, y2) 206 | self._stroke_to(px, py, midpoint_p) 207 | 208 | # End 209 | for i in range(tail_range, points_in_curve+1): 210 | px, py = point_on_curve_1(i, cx, cy, sx, sy, x1, y1, x2, y2) 211 | pressure = abs((i-tail)/tail_length * prange2 + midpoint_p) 212 | self._stroke_to(px, py, pressure) 213 | 214 | return pressure 215 | 216 | def _stroke_to(self, x, y, pressure, duration=0.1): 217 | self.b.stroke_to( 218 | self.s.backend, 219 | x, y, 220 | pressure, 221 | 0.0, 0.0, 222 | duration) 223 | 224 | def get_random_target(self, num=1, squeeze=False): 225 | random_idxes = np.random.choice(self.real_data.shape[0], num, replace=False) 226 | random_image = self.real_data[random_idxes] 227 | if squeeze: 228 | random_image = np.squeeze(random_image, 0) 229 | return random_image 230 | 231 | def step(self, acs): 232 | self.draw(acs, self.s) 233 | self._step += 1 234 | terminal = (self._step == self.episode_length) 235 | if terminal: 236 | if self.conditional: 237 | reward = 1 238 | reward += - utils.l2(self.state, self.random_target) \ 239 | / np.prod(self.observation_shape) 240 | else: 241 | reward = 0 242 | else: 243 | reward = 0 244 | # state, reward, terminal, info 245 | return self.state, reward, terminal, {} 246 | 247 | def save_image(self, path="test.png"): 248 | Image.fromarray(self.image.astype(np.uint8).squeeze()).save(path) 249 | #self.s.save_as_png(path, alpha=False) 250 | 251 | @property 252 | def image(self): 253 | rect = [0, 0, self.height, self.width] 254 | scanline_strips = \ 255 | surface.scanline_strips_iter(self.s, rect) 256 | return next(scanline_strips) 257 | 258 | @property 259 | def state(self): 260 | return utils.rgb2gray(self.image) 261 | 262 | def get_action_desc(self, ac): 263 | desc = [] 264 | for name in self.action_sizes: 265 | named_ac = ac[self.ac_idx[name]] 266 | actual_ac = getattr(self, name+"s")[named_ac] 267 | desc.append("{}: {} ({})".format(name, actual_ac, named_ac)) 268 | return "\n".join(desc) 269 | 270 | def _line_settings(self, pressure): 271 | p1 = self.entry_pressure 272 | p2 = (self.entry_pressure + pressure) / 2 273 | p3 = pressure 274 | if self.head == 0.0001: 275 | p1 = p2 276 | prange1 = p2 - p1 277 | prange2 = p3 - p2 278 | return p1, p2, prange1, prange2, self.head, self.tail 279 | 280 | def prepare_mnist(self): 281 | ut.io.makedirs(self.args.data_dir) 282 | 283 | # ground truth MNIST data 284 | mnist_dir = self.args.data_dir / 'mnist' 285 | mnist = tf.contrib.learn.datasets.DATASETS['mnist'](str(mnist_dir)) 286 | 287 | pkl_path = mnist_dir / 'mnist_dict.pkl' 288 | 289 | if pkl_path.exists(): 290 | mnist_dict = ut.io.load_pickle(pkl_path) 291 | else: 292 | mnist_dict = defaultdict(lambda: defaultdict(list)) 293 | for name in ['train', 'test', 'valid']: 294 | # the 'valid' split is stored as mnist.validation; the original indexed mnist.train for every split 295 | split = getattr(mnist, 'validation' if name == 'valid' else name) 296 | for num in self.args.mnist_nums: 297 | filtered_data = \
298 | split.images[split.labels == num] 299 | filtered_data = \ 300 | np.reshape(filtered_data, [-1, 28, 28]) 301 | iterator = tqdm(filtered_data, 302 | desc="[{}] Processing {}".format(name, num)) 303 | for idx, image in enumerate(iterator): 304 | # XXX: don't know which way would be the best 305 | resized_image = ut.io.imresize( 306 | image, [self.height, self.width], 307 | interp='cubic') 308 | mnist_dict[name][num].append( 309 | np.expand_dims(resized_image, -1)) 310 | ut.io.dump_pickle(pkl_path, mnist_dict) 311 | mnist_dict = mnist_dict['train' if self.args.train else 'test'] 312 | 313 | data = [] 314 | for num in self.args.mnist_nums: 315 | data.append(mnist_dict[int(num)]) 316 | 317 | self.real_data = 255 - np.concatenate(data) # invert so strokes are dark on a white background 318 | 319 | 320 | class SimpleMNIST(MNIST): 321 | 322 | action_sizes = { 323 | #'pressure': [2], 324 | 'jump': [2], 325 | #'color': [4], 326 | #'size': [2], 327 | 'control': None, 328 | 'end': None, 329 | } 330 | 331 | def __init__(self, args): 332 | super(SimpleMNIST, self).__init__(args) 333 | 334 | 335 | if __name__ == '__main__': 336 | import utils as ut 337 | from config import get_args 338 | 339 | args = get_args() 340 | ut.train.set_global_seed(args.seed) 341 | 342 | env = args.env.lower() 343 | 344 | if env == 'mnist': 345 | env = MNIST(args) 346 | elif env == 'simple_mnist': 347 | env = SimpleMNIST(args) 348 | else: 349 | raise Exception("Unknown environment: {}".format(args.env)) 350 | 351 | for ep_idx in range(10): 352 | step = 0 353 | env.reset() 354 | 355 | while True: 356 | action = env.random_action() 357 | print("[Step {}] ac: {}".format( 358 | step, env.get_action_desc(action))) 359 | state, reward, terminal, info = env.step(action) 360 | env.save_image("mnist{}_{}.png".format(ep_idx, step)) 361 | 362 | if terminal: 363 | print("Ep #{} finished ==> Reward: {}".format(ep_idx, reward)) 364 | break 365 | 366 | step += 1 367 | -------------------------------------------------------------------------------- /models/policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import utils as ut 5 | 6 | tl = tf.layers 7 | tc = tf.nn.rnn_cell 8 | 9 | 10 | class Policy(object): 11 | 12 | def __init__(self, args, env, scope_name, 13 | image_shape, action_sizes, data_format): 14 | self.args = args 15 | scale = args.scale 16 | self.lstm_size = args.lstm_size 17 | 18 | if data_format == 'channels_first' and args.dynamic_channel: 19 | self.image_shape = list(image_shape[-1:] + image_shape[:-1]) 20 | else: 21 | self.image_shape = list(image_shape) 22 | 23 | self.action_sizes = action_sizes 24 | self.data_format = data_format 25 | 26 | action_num = len(action_sizes) 27 | 28 | with tf.variable_scope(scope_name) as scope: 29 | # [B, max_time, H, W, C] 30 | self.x = x = tf.placeholder( 31 | tf.float32, [None, None] + self.image_shape, name='x') 32 | 33 | # last is only used for summary 34 | x = x[:,:env.episode_length] 35 | 36 | # Flatten multiple episodes 37 | # XXX: important difference from openai/universe-starter-agent 38 | x_shape = tf.shape(x) 39 | batch_size, max_time = x_shape[0], x_shape[1] 40 | 41 | # [B, max_time, action_num] 42 | self.ac = ac = tf.placeholder( 43 | tf.float32, [None, None, action_num], name='ac') 44 | 45 | if args.conditional: 46 | # [B, max_time, H, W, C] 47 | self.c = c = tf.placeholder( 48 | tf.float32, [None, None] + self.image_shape, name='c') 49 | # TODO: need to get confirmed from the authors 50 | x = 
tf.concat([x, c], axis=-1) 51 | x_shape = list(self.image_shape) 52 | x_shape[-1] = int(x.get_shape()[-1]) 53 | 54 | self.z = None 55 | else: 56 | self.c = None 57 | x_shape = self.image_shape 58 | 59 | # [B, max_time, z_dim] 60 | self.z = z = tf.placeholder( 61 | tf.float32, [None, None, self.args.z_dim], name='z') 62 | 63 | z_enc = mlp( 64 | z, 65 | self.lstm_size, 66 | name="z_enc") 67 | 68 | x = tf.reshape(x, [-1] + x_shape) 69 | ac = tf.reshape(ac, [-1, action_num]) 70 | 71 | if data_format == 'channels_first' and args.dynamic_channel: 72 | x = tf.transpose(x, [0, 3, 1, 2]) 73 | 74 | ################################ 75 | # Beginning of policy network 76 | ################################ 77 | 78 | a_enc = mlp( 79 | tf.expand_dims(ac, -1), 80 | int(16), 81 | name="a_enc") 82 | a_concat = tf.reshape( 83 | a_enc, [-1, int(16) * action_num]) 84 | a_fc = tl.dense( 85 | a_concat, int(32), 86 | activation=tf.nn.relu, 87 | name="a_concat_fc") 88 | 89 | # [B, 1, 1, 32] 90 | a_expand = tf.expand_dims(tf.expand_dims(a_fc, 1), 1) 91 | if data_format == 'channels_first' and args.dynamic_channel: 92 | a_expand = tf.transpose(a_expand, [0, 3, 1, 2]) 93 | 94 | x_enc = tl.conv2d( 95 | x, int(32), 5, 96 | padding='same', 97 | activation=tf.nn.relu, 98 | data_format=self.data_format, 99 | name="x_c_enc" if args.conditional else "x_enc") 100 | 101 | add = x_enc + a_expand 102 | 103 | for idx in range(int(3)): 104 | add = tl.conv2d( 105 | add, int(32), 4, strides=(2, 2), 106 | padding='valid', 107 | activation=tf.nn.relu, 108 | data_format=self.data_format, 109 | name="add_enc_{}".format(idx)) 110 | 111 | for idx in range(int(8*scale)): 112 | add = res_block( 113 | add, 32, 3, self.data_format, 114 | name="encoder_res_{}".format(idx)) 115 | 116 | flat = tl.flatten(add) 117 | 118 | out = tl.dense( 119 | flat, self.lstm_size, 120 | activation=tf.nn.relu, 121 | name="flat_fc") 122 | 123 | # [batch_size, max_time, ...] 124 | flat_out = tl.flatten(out) 125 | lstm_in_shape = [batch_size, max_time, flat_out.get_shape()[-1]] 126 | lstm_in = tf.reshape(flat_out, lstm_in_shape, name="lstm_in") 127 | 128 | if not self.args.conditional: 129 | lstm_in += z_enc 130 | 131 | self.lstm = tc.BasicLSTMCell(self.lstm_size, state_is_tuple=True) 132 | 133 | def make_init(batch_size): 134 | c_init = np.zeros((batch_size, self.lstm.state_size.c), np.float32) 135 | h_init = np.zeros((batch_size, self.lstm.state_size.h), np.float32) 136 | return [c_init, h_init] 137 | 138 | self.state_init = ut.misc.keydefaultdict(make_init) 139 | 140 | c_in = tf.placeholder( 141 | tf.float32, 142 | [None, self.lstm.state_size.c], 143 | name="lstm_c_in") 144 | h_in = tf.placeholder( 145 | tf.float32, 146 | [None, self.lstm.state_size.h], 147 | name="lstm_h_in") 148 | self.state_in = [c_in, h_in] 149 | state_in = tc.LSTMStateTuple(c_in, h_in) 150 | 151 | lstm_out, lstm_state = tf.nn.dynamic_rnn( 152 | self.lstm, 153 | # [batch_size, max_time, ...] 
154 | lstm_in, 155 | # [batch_size, cell.state_size] 156 | initial_state=state_in, 157 | time_major=False) 158 | 159 | # [batch_size, max_time, action_size] 160 | self.one_hot_samples, self.samples, self.logits = self.decoder( 161 | tf.nn.relu(lstm_out), self.action_sizes, 162 | self.data_format, self.lstm_size, scale) 163 | 164 | lstm_c, lstm_h = lstm_state 165 | self.state_out = [lstm_c, lstm_h] 166 | 167 | self.vf = tl.dense( 168 | lstm_out, 1, 169 | activation=None, 170 | name="value")[:,:,0] 171 | #kernel_initializer=normalized_columns_initializer(1.0))[:,:,0] 172 | 173 | self.var_list = tf.trainable_variables(scope=scope_name) 174 | 175 | def get_initial_features(self, batch_size, flat=False): 176 | assert batch_size == 1 and flat, \ 177 | "Use flat=True only when batch_size == 1" 178 | out = self.state_init[batch_size] 179 | if flat: 180 | out = [out[0][0], out[1][0]] 181 | return out 182 | 183 | def get_feed_dict(self, ob, ac, c, h, condition, z): 184 | feed_dict = { 185 | self.x: [[ob]], # fake batch, time axis 186 | self.ac: [[ac]], 187 | self.state_in[0]: [c], # fake batch axis 188 | self.state_in[1]: [h], 189 | } 190 | if condition is not None: 191 | feed_dict.update({ self.c: [[condition]] }) 192 | if z is not None: 193 | feed_dict.update({ self.z: [[z]] }) 194 | return feed_dict 195 | 196 | def act(self, ob, ac, c, h, condition=None, z=None): 197 | sess = tf.get_default_session() 198 | 199 | feed_dict = self.get_feed_dict(ob, ac, c, h, condition, z) 200 | 201 | fetches = [self.one_hot_samples, self.vf] + self.state_out 202 | if not self.args.conditional: 203 | fetches += [self.z] 204 | 205 | out = sess.run(fetches, feed_dict) 206 | 207 | # TODO: need to extract one 208 | for idx, item in enumerate(out): 209 | if isinstance(item, dict): 210 | for name in item: 211 | item[name] = item[name][0] 212 | else: 213 | item = item[0] 214 | out[idx] = item 215 | return out 216 | 217 | def decoder(self, z, action_sizes, data_format, lstm_size, scale=1): 218 | # [batch, max_time, lstm_size] 219 | z_shape = tf.shape(z) 220 | batch_size, max_time = z_shape[0], z_shape[1] 221 | 222 | one_hot_samples, samples, logits = {}, {}, {} 223 | 224 | for action_idx, (name, action_size) in enumerate(action_sizes.items()): 225 | # [batch*max_time, lstm_size] 226 | z_flat = tf.reshape(z, [-1, self.lstm_size]) 227 | 228 | with tf.variable_scope("decoder_{}".format(name)): 229 | if len(action_size) == 1: 230 | N = action_size[0] 231 | logit = tl.dense( 232 | z_flat, N, 233 | activation=None, 234 | name="action{}".format(name), 235 | kernel_initializer= \ 236 | normalized_columns_initializer(0.01)) 237 | else: 238 | # format: NHWC 239 | reshape = tf.reshape(z_flat, [-1, 4, 4, int(lstm_size / 16)]) 240 | 241 | # format: NHWC 242 | res = deconv = tl.conv2d_transpose( 243 | reshape, int(32), 4, 244 | strides=(2, 2), 245 | padding='same', 246 | activation=tf.nn.relu, 247 | data_format='channels_last') 248 | 249 | if data_format == 'channels_first' \ 250 | and self.args.dynamic_channel: 251 | # format: NHWC -> NCHW 252 | res = tf.transpose(res, [0, 3, 1, 2]) 253 | 254 | # format: each 255 | for res_idx in range(int(8*scale)): 256 | res = res_block( 257 | res, int(32), 3, data_format, 258 | name="decoder_res_{}".format(res_idx)) 259 | 260 | # format: NHWC 261 | deconv = res 262 | transposed = False 263 | for deconv_idx in range(int(2)): 264 | deconv_width = int(deconv.get_shape()[2]) 265 | if deconv_width == action_size[0]: 266 | break 267 | 268 | # format: NCHW -> NHWC 269 | if deconv_idx == 0 and data_format == 'channels_first' \ 270 | and self.args.dynamic_channel: 271 | transposed = True 272 | deconv = tf.transpose(deconv, [0, 2, 3, 1]) 273 | 274 | # format: NHWC 275 | deconv = tl.conv2d_transpose( 276 | deconv, int(32), 4, 277 | strides=(2, 2), 278 | padding='same', 279 | activation=tf.nn.relu, 280 | data_format='channels_last', 281 | name="deconv_{}".format(deconv_idx)) 282 | 283 | # format: each 284 | if data_format == 'channels_first' and transposed \ 285 | and self.args.dynamic_channel: 286 | # format: NHWC -> NCHW 287 | deconv = tf.transpose(deconv, [0, 3, 1, 2]) 288 | 289 | # format: each 290 | conv = tl.conv2d( 291 | deconv, 1, 3, 292 | padding='same', 293 | activation=None, 294 | data_format=data_format, 295 | name="conv_1x1") 296 | 297 | logit = tl.flatten(conv) 298 | 299 | logits[name] = tf.reshape( 300 | logit, [batch_size, max_time, -1]) 301 | 302 | action_one_hot, action = \ 303 | categorical_sample(logit, np.prod(action_size)) 304 | 305 | # [batch, max_time, action_size[name]] 306 | one_hot_samples[name] = tf.reshape( 307 | action_one_hot, [batch_size, max_time, -1], 308 | name="one_hot_samples_{}".format(name)) 309 | # [batch, max_time, 1] 310 | samples[name] = tf.reshape( 311 | action, [batch_size, max_time], 312 | name="samples_{}".format(name)) 313 | 314 | if action_idx < len(action_sizes) - 1: 315 | # the sampled action is embedded and fed back so later action heads are conditioned on it 316 | out = mlp( 317 | tf.expand_dims(samples[name], -1), int(16), 318 | name='sample_mlp') 319 | # [batch, max_time, lstm_size] 320 | z = tl.dense( 321 | tf.concat([z, out], -1), int(lstm_size), 322 | activation=tf.nn.relu, 323 | name="concat_z_fc") 324 | 325 | return one_hot_samples, samples, logits 326 | 327 | #def value(self, ob, c, h): 328 | # sess = tf.get_default_session() 329 | # feed_dict = { 330 | # self.x: [[ob]], 331 | # self.state_in[0]: c, 332 | # self.state_in[1]: h, 333 | # } 334 | # return sess.run(self.vf, feed_dict)[0][0] 335 | 336 | # TODO: not sure what this architecture is (1) 337 | def mlp(x, dim, hid_dim=64, num_layers=3, name=None): 338 | for idx in range(num_layers-1): 339 | x = tl.dense( 340 | x, hid_dim, 341 | activation=tf.nn.relu, 342 | name="{}_{}".format(name, idx)) 343 | x = tl.dense( 344 | x, dim, 345 | activation=tf.nn.relu, 346 | name="{}_{}".format(name, idx+1)) 347 | return x 348 | 349 | def res_block(x, channel, size, data_format, name): 350 | with tf.variable_scope(name): 351 | enc_x = tl.conv2d( 352 | x, channel, size, 353 | padding='same', 354 | activation=tf.nn.relu, 355 | data_format=data_format) 356 | 357 | res = tl.conv2d( 358 | enc_x, channel, size, 359 | padding='same', 360 | activation=None, 361 | data_format=data_format) + x 362 | return res 363 | 364 | def categorical_sample(logits, d): 365 | out = tf.multinomial(logits - \ 366 | tf.reduce_max(logits, [1], keepdims=True), 1) # max-subtraction keeps sampling numerically stable 367 | value = tf.squeeze(out, [1]) 368 | return tf.one_hot(value, d), tf.cast(value, tf.float32) 369 | 370 | def normalized_columns_initializer(std=1.0): 371 | def _initializer(shape, dtype=None, partition_info=None): 372 | out = np.random.randn(*shape).astype(np.float32) 373 | out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) 374 | return tf.constant(out) 375 | return _initializer 376 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | 
import tensorflow as tf 7 | import queue # assumed Py3 stdlib name ('Queue' on Py2); referenced by pull_batch_from_queue below 8 | import models 9 | import replay 10 | import rl_utils 11 | import utils as ut 12 | 13 | logger = ut.logging.get_logger() 14 | image_reshaper = tf.contrib.gan.eval.eval_utils.image_reshaper 15 | 16 | 17 | class Agent(object): 18 | 19 | def __init__(self, args, server, cluster, env, queue_shapes, 20 | trajectory_queue_size, replay_queue_size): 21 | self.env = env 22 | self.args = args 23 | self.task = args.task 24 | self.queue_shapes = queue_shapes 25 | self.trajectory_queue_size = trajectory_queue_size 26 | self.replay_queue_size = replay_queue_size 27 | 28 | self.action_sizes = env.action_sizes 29 | self.input_shape = list(self.env.observation_shape) 30 | 31 | # used for summary 32 | self._disc_step = 0 33 | self._policy_step = 0 34 | 35 | ################################## 36 | # Queue pipelines (ps/task=0~) 37 | ################################## 38 | with tf.device('/job:ps/task:0'): 39 | # TODO: we may need more than 1 queue 40 | #for i in range(cluster.num_tasks('ps')): 41 | if args.task != 1 or args.loss == 'l2': 42 | self.trajectory_queue = tf.FIFOQueue( 43 | self.trajectory_queue_size, 44 | [tf.float32] * len(self.queue_shapes), 45 | shapes=[shape for _, shape in self.queue_shapes], 46 | names=[name for name, _ in self.queue_shapes], 47 | shared_name='queue') 48 | self.trajectory_queue_size_op = self.trajectory_queue.size() 49 | 50 | if args.loss == 'gan': 51 | self.replay_queue = tf.FIFOQueue( 52 | self.replay_queue_size, 53 | tf.float32, 54 | shapes=dict(self.queue_shapes)['states'][1:], 55 | shared_name='replay') 56 | self.replay_queue_size_op = self.replay_queue.size() 57 | else: 58 | self.replay_queue = None 59 | self.replay_queue_size_op = None 60 | 61 | ########################### 62 | # Master policy (task!=1) 63 | ########################### 64 | 65 | device = 'gpu' if self.task == 0 else 'cpu' 66 | master_gpu = "/job:worker/task:{}/{}:0".format(self.args.task, device) 67 | master_gpu_replica = tf.train. \ 68 | replica_device_setter(1, worker_device=master_gpu) 69 | 70 | with tf.device(master_gpu_replica): 71 | with tf.variable_scope("global"): 72 | self.policy_step = tf.get_variable( 73 | "policy_step", [], tf.int32, 74 | initializer=tf.constant_initializer(0, dtype=tf.int32), 75 | trainable=False) 76 | 77 | self.disc_step = tf.get_variable( 78 | "disc_step", [], tf.int32, 79 | initializer=tf.constant_initializer(0, dtype=tf.int32), 80 | trainable=False) 81 | 82 | #master_cpu = "/job:worker/task:{}/cpu:0".format(self.args.task, device) 83 | #master_cpu_replica = tf.train. 
\ 84 | # replica_device_setter(1, worker_device=master_cpu) 85 | 86 | #with tf.device(master_cpu_replica): 87 | # master should initialize discriminator 88 | if args.task < 2 and args.loss == 'gan': 89 | self.global_disc = models.Discriminator( 90 | self.args, self.disc_step, self.input_shape, 91 | self.env.norm, "global") 92 | 93 | if args.task != 1 or args.loss == 'l2': 94 | logger.debug(master_gpu) 95 | 96 | with tf.device(master_gpu_replica): 97 | self.prepare_master_network() 98 | 99 | ########################### 100 | # Master policy network 101 | ########################### 102 | if self.args.task == 0: 103 | policy_batch_size = self.args.policy_batch_size 104 | # XXX: may need this if you are lack of GPU memory 105 | #policy_batch_size = int(self.args.policy_batch_size \ 106 | # / self.env.episode_length) 107 | 108 | worker_device = "/job:worker/task:{}/cpu:0".format(self.task) 109 | logger.debug(worker_device) 110 | 111 | with tf.device(worker_device): 112 | with tf.variable_scope("global"): 113 | self.trajectory_dequeue = self.trajectory_queue. \ 114 | dequeue_many(policy_batch_size) 115 | 116 | ########################### 117 | # Discriminator (task=1) 118 | ########################### 119 | elif self.args.task == 1 and self.args.loss == 'gan': 120 | device = 'gpu' if args.num_gpu > 0 else 'cpu' 121 | worker_device = "/job:worker/task:{}/{}:0".format(self.task, device) 122 | logger.debug(worker_device) 123 | 124 | with tf.device(worker_device): 125 | self.prepare_gan() 126 | 127 | worker_device = "/job:worker/task:{}/cpu:0".format(self.task) 128 | logger.debug(worker_device) 129 | 130 | with tf.device(worker_device): 131 | with tf.variable_scope("global"): 132 | self.replay_dequeue = self.replay_queue. \ 133 | dequeue_many(self.args.disc_batch_size) 134 | 135 | ##################################################### 136 | # Local policy network (task >= 2 (gan) or 1 (l2)) 137 | ##################################################### 138 | elif self.args.task >= 1: 139 | worker_device = "/job:worker/task:{}/cpu:0".format(self.task) 140 | logger.debug(worker_device) 141 | 142 | with tf.device(worker_device): 143 | self.prepare_local_network() 144 | 145 | def prepare_master_network(self): 146 | self.global_network = pi = models.Policy( 147 | self.args, self.env, "global", 148 | self.input_shape, self.action_sizes, 149 | data_format='channels_first' \ 150 | if self.args.dynamic_channel \ 151 | else 'channels_last') 152 | 153 | self.acs, acs = {}, {} 154 | for idx, (name, action_size) in enumerate( 155 | self.action_sizes.items()): 156 | # [B, action_size] 157 | self.acs[name] = tf.placeholder( 158 | tf.int32, [None, None], name="{}_in".format(name)) 159 | acs[name] = tf.one_hot(self.acs[name], np.prod(action_size)) 160 | 161 | self.adv = adv = tf.placeholder( 162 | tf.float32, [None, self.env.episode_length], name="adv") 163 | self.r = r = tf.placeholder( 164 | tf.float32, [None, self.env.episode_length], name="r") 165 | 166 | bsz = tf.to_float(tf.shape(pi.x)[0]) 167 | 168 | ######################## 169 | # Building optimizer 170 | ######################## 171 | 172 | self.loss = 0 173 | self.pi_loss, self.vf_loss, self.entropy = 0, 0, 0 174 | 175 | for name in self.action_sizes: 176 | ac = acs[name] 177 | logit = pi.logits[name] 178 | 179 | log_prob_tf = tf.nn.log_softmax(logit) 180 | prob_tf = tf.nn.softmax(logit) 181 | 182 | pi_loss = - tf.reduce_sum( 183 | tf.reduce_sum(log_prob_tf * ac, [-1]) * adv) 184 | 185 | # loss of value function 186 | vf_loss = 0.5 * 
tf.reduce_sum(tf.square(pi.vf - r)) 187 | entropy = - tf.reduce_sum(prob_tf * log_prob_tf) 188 | 189 | self.loss += pi_loss + 0.5 * vf_loss - \ 190 | entropy * self.args.entropy_coeff 191 | 192 | self.pi_loss += pi_loss 193 | self.vf_loss += vf_loss 194 | self.entropy += entropy 195 | 196 | grads = tf.gradients(self.loss, pi.var_list) 197 | 198 | ################## 199 | # Summaries 200 | ################## 201 | 202 | # summarize only the last state 203 | last_state = self.env.denorm(pi.x[:,-1]) 204 | last_state.set_shape( 205 | [self.args.policy_batch_size] + ut.tf.int_shape(last_state)[1:]) 206 | 207 | summaries = [ 208 | tf.summary.image("last_state", image_reshaper(last_state)), 209 | tf.summary.scalar("env/r", tf.reduce_mean(self.r[:,-1])), 210 | tf.summary.scalar("model/policy_loss", self.pi_loss / bsz), 211 | tf.summary.scalar("model/value_loss", self.vf_loss / bsz), 212 | tf.summary.scalar("model/entropy", self.entropy / bsz), 213 | tf.summary.scalar("model/grad_global_norm", tf.global_norm(grads)), 214 | tf.summary.scalar("model/var_global_norm", tf.global_norm(pi.var_list)), 215 | ] 216 | 217 | if pi.c is not None: 218 | target = self.env.denorm(pi.c[:,-1]) 219 | target.set_shape( 220 | [self.args.policy_batch_size] + ut.tf.int_shape(target)[1:]) 221 | 222 | summaries.append( 223 | tf.summary.image("target", image_reshaper(target))) 224 | 225 | self.l2_loss = tf.sqrt(1e-8 + 226 | tf.reduce_sum(((pi.x[:,-1] - pi.c[:,-1])/255.)**2, [-3,-2,-1])) 227 | summaries.append( 228 | tf.summary.scalar("model/l2_loss", tf.reduce_mean(self.l2_loss))) 229 | 230 | self.summary_op = tf.summary.merge(summaries) 231 | grads, _ = tf.clip_by_global_norm(grads, self.args.grad_clip) 232 | 233 | grads_and_vars = list(zip(grads, self.global_network.var_list)) 234 | 235 | # each worker has a different set of adam optimizer parameters 236 | opt = tf.train.AdamOptimizer( 237 | self.args.policy_lr, name="policy_optim") 238 | 239 | self.train_op = opt.apply_gradients(grads_and_vars, self.policy_step) 240 | self.summary_writer = None 241 | 242 | def prepare_local_network(self): 243 | self.local_network = models.Policy( 244 | self.args, self.env, "local", 245 | self.input_shape, self.action_sizes, 246 | data_format='channels_last') 247 | 248 | ########################## 249 | # Trajectory queue 250 | ########################## 251 | self.trajectory_placeholders = { 252 | name:tf.placeholder( 253 | tf.float32, dict(self.queue_shapes)[name], 254 | name="{}_in".format(name)) \ 255 | for name, shape in self.queue_shapes 256 | } 257 | self.trajectory_enqueues = self.trajectory_queue.enqueue( 258 | { name:self.trajectory_placeholders[name] \ 259 | for name, _ in self.queue_shapes }) 260 | 261 | ########################## 262 | # Replay queue 263 | ########################## 264 | if self.args.loss == 'gan': 265 | self.replay_placeholder = tf.placeholder( 266 | tf.float32, self.input_shape, 267 | name="replay_in") 268 | self.replay_enqueue = self.replay_queue.enqueue( 269 | self.replay_placeholder) 270 | else: 271 | self.replay_placeholder = None 272 | self.replay_enqueue = None 273 | 274 | ############################### 275 | # Thread dealing with queues 276 | ############################### 277 | self.worker_thread = rl_utils.WorkerThread( 278 | self.env, 279 | self.local_network, 280 | self.trajectory_enqueues, 281 | self.trajectory_placeholders, 282 | self.trajectory_queue_size_op, 283 | self.replay_enqueue, 284 | self.replay_placeholder, 285 | self.replay_queue_size_op) 286 | 287 | # copy weights from the 
parameter server to the local model 288 | self.policy_sync = ut.tf.get_sync_op( 289 | from_list=self.global_network.var_list, 290 | to_list=self.local_network.var_list) 291 | 292 | def prepare_gan(self): 293 | self.replay = replay.ReplayBuffer(self.args, self.input_shape) 294 | self.replay_dequeue = \ 295 | self.replay_queue.dequeue_many(self.args.disc_batch_size) 296 | 297 | self.replay_thread = rl_utils.ReplayThread( 298 | self.replay, self.replay_dequeue) 299 | 300 | self.local_disc = models.Discriminator( 301 | self.args, self.disc_step, self.input_shape, 302 | self.env.norm, "local") 303 | 304 | self.disc_sync = ut.tf.get_sync_op( 305 | from_list=self.local_disc.var_list, 306 | to_list=self.global_disc.var_list) 307 | 308 | self.disc_initializer = ut.tf.get_sync_op( 309 | from_list=self.global_disc.var_list, 310 | to_list=self.local_disc.var_list) 311 | 312 | def start_worker_thread(self, sess, summary_writer): 313 | self.worker_thread.start_thread(sess, summary_writer) 314 | self.summary_writer = summary_writer 315 | 316 | def start_replay_thread(self, sess, summary_writer): 317 | self.replay_thread.start_thread(sess) 318 | self.summary_writer = summary_writer 319 | # NOTE: unused legacy helper (from universe-starter-agent); WorkerThread here has no .queue attribute -- trajectories arrive through the TF queue instead 320 | def pull_batch_from_queue(self): 321 | rollout = self.worker_thread.queue.get(timeout=600.0) 322 | while not rollout.terminal: 323 | try: 324 | rollout.extend(self.worker_thread.queue.get_nowait()) 325 | except queue.Empty: 326 | break 327 | return rollout 328 | 329 | ########################### 330 | # Master policy (task=0) 331 | ########################### 332 | 333 | def train_policy(self, sess): 334 | rollout = sess.run(self.trajectory_dequeue) 335 | 336 | if self.args.loss == 'gan': 337 | probs = self.global_disc.predict(rollout['states'][:,-1]) 338 | rollout['rewards'][:,-1] = probs # the discriminator's score on the final canvas becomes the episode reward 339 | 340 | batch = rl_utils.multiple_process_rollout( 341 | rollout, gamma=0.99, lambda_=1.0) 342 | 343 | ################# 344 | # Feed ops 345 | ################# 346 | 347 | feed_dict = { 348 | # [B, ep_len] 349 | self.r: batch.r, 350 | self.adv: batch.adv, 351 | self.global_network.x: batch.si, 352 | # [B, ep_len, action_size] 353 | self.global_network.ac: batch.a, 354 | self.global_network.state_in[0]: batch.features[:,0], 355 | self.global_network.state_in[1]: batch.features[:,1], 356 | } 357 | for name in self.action_sizes: 358 | name_a = batch.a[:,:,self.env.ac_idx[name]] 359 | feed_dict.update({ 360 | self.acs[name]: name_a, 361 | }) 362 | if name in self.global_network.samples: 363 | feed_dict.update({ 364 | self.global_network.samples[name]: name_a, 365 | }) 366 | 367 | if self.args.conditional: 368 | feed_dict.update({ 369 | self.global_network.c: batch.c, 370 | }) 371 | else: 372 | feed_dict.update({ 373 | self.global_network.z: batch.z, 374 | }) 375 | 376 | ################# 377 | # Fetch ops 378 | ################# 379 | 380 | fetches = { 381 | 'train': self.train_op, 382 | 'step': self.policy_step, 383 | } 384 | if self._policy_step % self.args.policy_log_step == 0: 385 | fetches.update({ 386 | 'summary': self.summary_op, 387 | 'policy_size': self.trajectory_queue_size_op, 388 | }) 389 | 390 | out = sess.run(fetches, feed_dict=feed_dict) 391 | 392 | if self._policy_step % self.args.policy_log_step == 0: 393 | self.summary_writer.add_summary( 394 | tf.Summary.FromString(out['summary']), out['step']) 395 | self.summary_writer.flush() 396 | 397 | debug_text = "# traj: {}".format(out['policy_size']) 398 | if self.task == 0: 399 | logger.info(debug_text) 400 | else: 401 | logger.debug(debug_text) 402 | 
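403 | # remember the global step so the log-cadence checks above stay in sync across calls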
self._policy_step = out['step'] 404 | 405 | ########################### 406 | # Discriminator (task=1) 407 | ########################### 408 | 409 | def train_gan(self, sess): 410 | fakes = self.replay.sample( 411 | self.args.disc_batch_size) 412 | 413 | feed_dict = { 414 | self.local_disc.fake: fakes, 415 | self.local_disc.real: self.env.get_random_target(self.args.disc_batch_size), 416 | } 417 | 418 | fetches = { 419 | 'train': self.local_disc.train_op, 420 | 'step': self.local_disc.step, 421 | } 422 | if self._disc_step % self.args.disc_log_step == 0: 423 | fetches.update({ 424 | 'summary': self.local_disc.summary_op, 425 | 'replay_size': self.replay_queue_size_op, 426 | }) 427 | 428 | out = sess.run(fetches, feed_dict=feed_dict) 429 | 430 | if self._disc_step % self.args.disc_log_step == 0: 431 | self.summary_writer.add_summary( 432 | tf.Summary.FromString(out['summary']), out['step']) 433 | self.summary_writer.flush() 434 | 435 | logger.info("# replay: {}".format(out['replay_size'])) 436 | 437 | self._disc_step = out['step'] 438 | 439 | 440 | def weights_before_after(before, after, var_to_test): 441 | print(" [*] Weight change check") 442 | 443 | for idx, (bef, aft, var) in \ 444 | enumerate(zip(before, after, var_to_test)): 445 | assert bef.shape == aft.shape, \ 446 | "Shape [{}] is not same: {}, {}".format( 447 | var.name, bef.shape, aft.shape) 448 | 449 | bef_sum, aft_sum = bef.sum(), aft.sum() 450 | same_or_not = "SAME" if bef_sum == aft_sum else " " 451 | 452 | print(" [{}] {}: {} ({}, {})". \ 453 | format(idx, var.name, same_or_not, bef_sum, aft_sum)) 454 | 455 | --------------------------------------------------------------------------------