├── README.md ├── data_generator.py ├── convex_hull.py ├── data_preprocess.py ├── .gitignore ├── u_net.py ├── run_dqn_img.py ├── logz.py ├── env.py ├── local_env.py ├── dqn_utils.py └── dqn.py /README.md: -------------------------------------------------------------------------------- 1 | # Deep RL Segmentation 2 | Project for Berkeley Deep RL course: using deep reinforcement learning for segmentation of medical images 3 | 4 | # Data pre-processing 5 | For the data pre-processing script to work: 6 | - Clone cocoapi inside the deeprl_segmentation folder, and follow the instructions to install it (usually just need to run Make 7 | inside the PythonAPI folder) 8 | - Download your coco dataset (for example, val2017) inside the deeprl_segmentation folder 9 | - Download the corresponding annotations, and place them inside a folder called annotations inside the deeprl_segmentation folder 10 | -------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | from data_preprocess import PROC_DATA_DIR, DATA_TYPE 2 | images_dir = '%s/%s/images/'%(PROC_DATA_DIR, DATA_TYPE) 3 | masks_dirs = '%s/%s/masks/'%(PROC_DATA_DIR, DATA_TYPE) 4 | 5 | import os 6 | import random 7 | import numpy as np 8 | import skimage.io as io 9 | import cv2 10 | 11 | def getRandomFile(path): 12 | """ 13 | Returns a random filename, chosen among the files of the given path. 14 | """ 15 | files = os.listdir(path) 16 | index = random.randrange(0, len(files)) 17 | return files[index] 18 | 19 | def generator_fn(num_processes=4, batch_size=128): 20 | while True: 21 | files = [getRandomFile(images_dir) for i in range(batch_size)] 22 | mask_files = [masks_dirs + f + ".npy" for f in files] 23 | img_mask_pairs = [(io.imread(fname=images_dir + img_file), np.load(mask_file)) for img_file, mask_file in zip(files, mask_files)] 24 | for img, mask in img_mask_pairs: 25 | if img.shape == (256, 256, 3): 26 | yield img, mask 27 | #yield cv2.resize(img, dsize=(32, 32), interpolation=cv2.INTER_CUBIC), cv2.resize(mask, dsize=(32, 32), interpolation=cv2.INTER_CUBIC) 28 | -------------------------------------------------------------------------------- /convex_hull.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial import ConvexHull 2 | import numpy as np 3 | 4 | class ConvexHullPolicy(): 5 | def __init__(self, img_size): 6 | self.done = True 7 | self.pen_up = False 8 | self.img_size = img_size 9 | self.mask = None 10 | 11 | def get_action(self, state, true_segmentation): 12 | if self.mask is None or not np.array_equal(true_segmentation, self.mask): 13 | self.done = True 14 | self.pen_up = False 15 | 16 | if self.pen_up: 17 | self.pen_up = False 18 | self.done = True 19 | return 1 # Finish drawing 20 | 21 | if self.done: 22 | self.mask = true_segmentation 23 | assert(self.mask.shape == state.shape[:2]) 24 | points = np.argwhere(true_segmentation) 25 | if len(points) == 0: 26 | return 1 27 | try: 28 | hull = ConvexHull(points) 29 | except: 30 | return 1 31 | self.vertices = [points[vertex] for vertex in hull.vertices] 32 | self.i = 0 33 | self.done = False 34 | 35 | if self.i >= len(self.vertices): 36 | self.pen_up = True 37 | return 0 # Pen Up 38 | 39 | x, y = self.vertices[self.i] 40 | action = int(2 + x * self.img_size + y) 41 | 42 | self.i += 1 43 | return action -------------------------------------------------------------------------------- /data_preprocess.py: 
-------------------------------------------------------------------------------- 1 | from pycocotools.coco import COCO 2 | import numpy as np 3 | import skimage.io as io 4 | import os 5 | import time 6 | 7 | DATA_TYPE = "train2017" 8 | DATA_DIR="" 9 | annFile='annotations/instances_{}.json'.format(DATA_TYPE) 10 | PROC_DATA_DIR = "proc" 11 | SIZE = 256 12 | 13 | 14 | def merge_masks(masks): 15 | # Merges the binary 0-1 masks using bitwise OR 16 | return np.expand_dims(np.bitwise_or.reduce(masks, axis=2),axis=2) 17 | 18 | def main(): 19 | coco = COCO(annFile) 20 | catIds = coco.getCatIds(catNms=['person']) 21 | imgIds = coco.getImgIds(catIds=catIds ) 22 | print("Person images : ", len(imgIds)) 23 | images = coco.loadImgs(imgIds) 24 | count = 0 25 | iter_count = 0 26 | start_time = time.time() 27 | os.makedirs('%s/%s/images/'%(PROC_DATA_DIR,DATA_TYPE), exist_ok=True) 28 | os.makedirs('%s/%s/masks/'%(PROC_DATA_DIR,DATA_TYPE), exist_ok=True) 29 | for img in images: 30 | iter_count +=1 31 | if iter_count % 100 == 0: 32 | print("Time since started : " , time.time() - start_time) 33 | if img['height'] < SIZE or img['width'] < SIZE: 34 | continue 35 | count +=1 36 | I = io.imread('%s%s/%s'%(DATA_DIR,DATA_TYPE,img['file_name'])) 37 | cropped = I[:SIZE, :SIZE] 38 | annotation_ids = coco.getAnnIds(imgIds=img['id']) 39 | annotations = coco.loadAnns(ids=annotation_ids) 40 | masks = [coco.annToMask(ann) for ann in annotations] 41 | cropped_masks = merge_masks(np.stack([mask[:SIZE, :SIZE] for mask in masks], axis=-1)) 42 | io.imsave('%s/%s/images/%s'%(PROC_DATA_DIR,DATA_TYPE,img['file_name']), cropped) 43 | np.save('%s/%s/masks/%s'%(PROC_DATA_DIR,DATA_TYPE,img['file_name']), cropped_masks) 44 | print("Processed " + str(count) + " images") 45 | 46 | 47 | if __name__ == '__main__': 48 | main() -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | # Data and COCO related stuff
107 | /annotations
108 | /val2017
109 | /cocoapi
110 | /proc/
111 | 
112 | # Visual Studio Code stuff
113 | .vscode/
--------------------------------------------------------------------------------
/u_net.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | 
4 | #Performs a 2D convolution with the given number of filters and SAME padding
5 | def conv(input_layer, filter_size, kernel_size, name, strides, padding = "SAME", activation = tf.nn.relu):
6 |     output = tf.layers.conv2d(input_layer, filters = filter_size, kernel_size = kernel_size, name = name, strides = strides, padding = padding, activation = activation)
7 |     return output
8 | 
9 | def deconv(input_layer, filter_size, output_size, out_channel, in_channel, name, strides = [1, 1, 1, 1], padding = "SAME"):
10 |     batch_size = tf.shape(input_layer)[0]
11 |     output = tf.nn.conv2d_transpose(input_layer, tf.get_variable(name = name, shape = [filter_size, filter_size, out_channel, in_channel]), tf.stack([batch_size, output_size, output_size, out_channel]), strides = strides, padding = padding)
12 |     return output
13 | 
14 | #Defines the U-Net-style Q-network
15 | #Input: img_input is a tf Tensor of shape [batch_size, H, W, 6] (RGB image plus the three state-map channels), dtype 'float32'. With the layers currently enabled, H = W = 32 (the window size used in dqn.py); the commented-out conv/deconv layers belong to the full 256 x 256 variant.
16 | #Optional Args: scope name, reuse (added for compatibility purposes)
17 | #Output: Returns a tensor of shape [batch_size, 2 + H*W]: two pen-state q-values followed by one q-value per pixel.
18 | def build_unet(img_input, scope = "default", reuse = False): 19 | with tf.variable_scope(scope, reuse=reuse): 20 | res = img_input #256 x 256 21 | #res = conv(res, 32 ,3, 'F0', strides = (2,2)) #128 x 128 22 | #res = conv(res, 64, 3, 'F1', strides = (2,2)) #64 x 64 23 | #res = conv(res, 128, 3, 'F2', strides = (2,2)) #32 x 32 24 | res = conv(res, 256, 3, 'F3', strides = (2,2)) #16 x 16 25 | 26 | #2 FC layers to get the convolved tensor down to 3 values for pen-state 27 | pen_states = tf.contrib.layers.flatten(res) 28 | pen_states = tf.contrib.layers.fully_connected(pen_states, 300) 29 | pen_states = tf.contrib.layers.fully_connected(pen_states, 2) 30 | 31 | #Up-convolve 32 | res = deconv(res, 1, 16, 256, 256, 'B0') #16 x 16 33 | res = tf.nn.relu(res) 34 | 35 | res = deconv(res, 2, 32, 128, 256, 'B1', [1,2,2,1]) #32 x 32 36 | res = tf.nn.relu(res) 37 | 38 | #res = deconv(res, 2, 64, 64, 128, 'B2', [1,2,2,1]) #64 x 64 39 | #res = tf.nn.relu(res) 40 | 41 | #res = deconv(res, 2, 128, 32, 64, 'B3', [1,2,2,1]) #128 x 128 42 | #res = tf.nn.relu(res) 43 | 44 | #res = deconv(res, 2, 256, 16, 32, 'B4', [1,2,2,1]) #256 x 256 45 | #res = tf.nn.relu(res) 46 | 47 | res = deconv(res, 1, 32, 1, 128, 'B5') 48 | res = tf.contrib.layers.flatten(res) 49 | return tf.concat([pen_states, res], axis=1) 50 | 51 | #Example usage 52 | def main(): 53 | img = tf.convert_to_tensor(np.random.uniform(0, 1, size = (10, 256, 256, 6)).astype('float32')) 54 | ans = build_unet(img) 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /run_dqn_img.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import random 5 | import numpy as np 6 | import tensorflow as tf 7 | import tensorflow.contrib.layers as layers 8 | #from env import Environment 9 | from local_env import Environment 10 | import dqn 11 | from dqn_utils import * 12 | from u_net import build_unet 13 | from data_generator import generator_fn, DATA_TYPE 14 | 15 | TRAIN_DATA_DIR = "train" 16 | 17 | def img_segment_learn(env, 18 | session, 19 | num_timesteps, 20 | progress_dir): 21 | # This is just a rough estimate 22 | # Store logged images in progress_dir 23 | num_iterations = float(num_timesteps) / 4.0 24 | 25 | lr_multiplier = 1.0 26 | lr_schedule = PiecewiseSchedule([ 27 | (0, 1e-4 * lr_multiplier), 28 | (num_iterations / 10, 1e-4 * lr_multiplier), 29 | (num_iterations / 2, 5e-5 * lr_multiplier), 30 | ], 31 | outside_value=5e-5 * lr_multiplier) 32 | optimizer = dqn.OptimizerSpec( 33 | constructor=tf.train.AdamOptimizer, 34 | kwargs=dict(epsilon=1e-4), 35 | lr_schedule=lr_schedule 36 | ) 37 | 38 | exploration_schedule = PiecewiseSchedule( 39 | [ 40 | (0, 1.0), 41 | (1e6, 0.1), 42 | (num_iterations / 2, 0.01), 43 | ], outside_value=0.01 44 | ) 45 | 46 | return dqn.learn( 47 | env=env, 48 | q_func=build_unet, 49 | optimizer_spec=optimizer, 50 | session=session, 51 | exploration=exploration_schedule, 52 | replay_buffer_size=50000, 53 | batch_size=32, 54 | gamma=0.99, 55 | learning_starts=50000, 56 | learning_freq=4, 57 | target_update_freq=10000, 58 | grad_norm_clipping=10, 59 | double_q=False, 60 | progress_dir=progress_dir 61 | ) 62 | 63 | def get_available_gpus(): 64 | from tensorflow.python.client import device_lib 65 | local_device_protos = device_lib.list_local_devices() 66 | return [x.physical_device_desc for x in local_device_protos if x.device_type == 'GPU'] 67 | 68 | def 
get_session(): 69 | tf.reset_default_graph() 70 | tf_config = tf.ConfigProto( 71 | inter_op_parallelism_threads=1, 72 | intra_op_parallelism_threads=1) 73 | session = tf.Session(config=tf_config) 74 | print("AVAILABLE GPUS: ", get_available_gpus()) 75 | return session 76 | 77 | 78 | def main(): 79 | # Run training 80 | env = Environment(generator_fn(), img_shape=(256,256,3)) 81 | #test_env = Environment(test_generator_fn) 82 | session = get_session() 83 | training_result_dir = '%s/%s/results'%(TRAIN_DATA_DIR,DATA_TYPE) 84 | training_progress_dir = '%s/%s/progress'%(TRAIN_DATA_DIR,DATA_TYPE) 85 | #os.makedirs(training_progress_dir) 86 | #os.makedirs(training_result_dir) 87 | alg = img_segment_learn(env, session,num_timesteps=2e8, progress_dir=training_progress_dir) 88 | training_results, training_rewards = alg.test(env, num_test_samples = 1000) 89 | i = 0 90 | reward_sum = 0 91 | for result, reward in zip(training_results, training_rewards): 92 | result_file_name = "result_" + str(i) + ".npy" 93 | np.save('%s/%s'%(training_result_dir, result_file_name), result) 94 | reward_sum += reward 95 | i += 1 96 | print("Average reward ", str(reward_sum)) 97 | 98 | 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- /logz.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | """ 4 | 5 | Some simple logging functionality, inspired by rllab's logging. 6 | Assumes that each diagnostic gets logged each iteration 7 | 8 | Call logz.configure_output_dir() to start logging to a 9 | tab-separated-values file (some_folder_name/log.txt) 10 | 11 | To load the learning curves, you can do, for example 12 | 13 | A = np.genfromtxt('/tmp/expt_1468984536/log.txt',delimiter='\t',dtype=None, names=True) 14 | A['EpRewMean'] 15 | 16 | """ 17 | 18 | import os.path as osp, shutil, time, atexit, os, subprocess 19 | import pickle 20 | import tensorflow as tf 21 | 22 | color2num = dict( 23 | gray=30, 24 | red=31, 25 | green=32, 26 | yellow=33, 27 | blue=34, 28 | magenta=35, 29 | cyan=36, 30 | white=37, 31 | crimson=38 32 | ) 33 | 34 | def colorize(string, color, bold=False, highlight=False): 35 | attr = [] 36 | num = color2num[color] 37 | if highlight: num += 10 38 | attr.append(str(num)) 39 | if bold: attr.append('1') 40 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string) 41 | 42 | class G: 43 | output_dir = None 44 | output_file = None 45 | first_row = True 46 | log_headers = [] 47 | log_current_row = {} 48 | 49 | def configure_output_dir(d=None): 50 | """ 51 | Set output directory to d, or to /tmp/somerandomnumber if d is None 52 | """ 53 | G.output_dir = d or "/tmp/experiments/%i"%int(time.time()) 54 | assert not osp.exists(G.output_dir), "Log dir %s already exists! Delete it first or use a different dir"%G.output_dir 55 | os.makedirs(G.output_dir) 56 | G.output_file = open(osp.join(G.output_dir, "log.txt"), 'w') 57 | atexit.register(G.output_file.close) 58 | print(colorize("Logging data to %s"%G.output_file.name, 'green', bold=True)) 59 | 60 | def log_tabular(key, val): 61 | """ 62 | Log a value of some diagnostic 63 | Call this once for each diagnostic quantity, each iteration 64 | """ 65 | if G.first_row: 66 | G.log_headers.append(key) 67 | else: 68 | assert key in G.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration"%key 69 | assert key not in G.log_current_row, "You already set %s this iteration. 
Maybe you forgot to call dump_tabular()"%key 70 | G.log_current_row[key] = val 71 | 72 | def save_params(params): 73 | with open(osp.join(G.output_dir, "params.json"), 'w') as out: 74 | out.write(json.dumps(params, separators=(',\n','\t:\t'), sort_keys=True)) 75 | 76 | def pickle_tf_vars(): 77 | """ 78 | Saves tensorflow variables 79 | Requires them to be initialized first, also a default session must exist 80 | """ 81 | _dict = {v.name : v.eval() for v in tf.global_variables()} 82 | with open(osp.join(G.output_dir, "vars.pkl"), 'wb') as f: 83 | pickle.dump(_dict, f) 84 | 85 | 86 | def dump_tabular(): 87 | """ 88 | Write all of the diagnostics from the current iteration 89 | """ 90 | vals = [] 91 | key_lens = [len(key) for key in G.log_headers] 92 | max_key_len = max(15,max(key_lens)) 93 | keystr = '%'+'%d'%max_key_len 94 | fmt = "| " + keystr + "s | %15s |" 95 | n_slashes = 22 + max_key_len 96 | print("-"*n_slashes) 97 | for key in G.log_headers: 98 | val = G.log_current_row.get(key, "") 99 | if hasattr(val, "__float__"): valstr = "%8.3g"%val 100 | else: valstr = val 101 | print(fmt%(key, valstr)) 102 | vals.append(val) 103 | print("-"*n_slashes) 104 | if G.output_file is not None: 105 | if G.first_row: 106 | G.output_file.write("\t".join(G.log_headers)) 107 | G.output_file.write("\n") 108 | G.output_file.write("\t".join(map(str,vals))) 109 | G.output_file.write("\n") 110 | G.output_file.flush() 111 | G.log_current_row.clear() 112 | G.first_row=False 113 | -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage.filters import gaussian_filter 3 | from scipy.ndimage.morphology import binary_fill_holes 4 | from skimage import feature 5 | 6 | PEN_DOWN = 2 7 | PEN_UP = 0 8 | FINISH = 1 9 | 10 | class Environment(): 11 | # Pulls in new images with generator_fn 12 | # generator_fn should return a preprocessed image and a segmentation mask 13 | def __init__(self, generator, gaussian_std=2.0, img_shape=(256,256), alpha=0.05, max_line_len=50): 14 | self.generator = generator 15 | self.gaussian_std = gaussian_std 16 | self.img_shape = img_shape 17 | self.alpha = alpha 18 | self.max_line_len = max_line_len 19 | 20 | self.curr_image = None 21 | self.curr_mask = None 22 | self.curr_blurred_mask = None 23 | self.state_map = None 24 | self.last_action = None 25 | self.first_vertex = None 26 | 27 | self.reset() 28 | 29 | # Returns (new_state, reward, done) 30 | # Action should be int: 0 = pen up, 1 = finish, other = index into array 31 | def step(self, action): 32 | action_class = PEN_DOWN if action > 1 else action 33 | coord_x, coord_y = (-1,-1) 34 | if action_class == PEN_DOWN: 35 | coord_x = (action - 2) // self.img_shape[0] 36 | coord_y = (action - 2) % self.img_shape[0] 37 | 38 | if self.last_action == PEN_UP: 39 | if action_class == PEN_UP: 40 | return self._get_state(), -1.0, False 41 | elif action_class == PEN_DOWN: 42 | self.first_vertex = (coord_x, coord_y) 43 | self.state_map[2,:,:] = 0 44 | self.state_map[2, coord_x, coord_y] = 1 45 | self.state_map[1, coord_x, coord_y] = 1 46 | rew = self.curr_blurred_mask[coord_x, coord_y] / self.alpha 47 | self.last_action = PEN_DOWN 48 | return self._get_state(), rew, False 49 | else: 50 | self.last_action = FINISH 51 | return self._get_state(), -1.0, True 52 | elif self.last_action == PEN_DOWN: 53 | if action_class == PEN_UP: 54 | rew = self._finish_polygon() 55 | self.first_vertex = None 56 | 
self.last_action = PEN_UP 57 | return self._get_state(), rew, False 58 | 59 | elif action_class == PEN_DOWN: 60 | prev_vertex_x, prev_vertex_y = np.where(self.state_map[2] == 1) 61 | prev_vertex_x = prev_vertex_x[0] 62 | prev_vertex_y = prev_vertex_y[0] 63 | 64 | # Penalize illegal placements 65 | #if np.hypot(coord_x - prev_vertex_x, coord_y - prev_vertex_y) > self.max_line_len: 66 | # return self._get_state(), -1, False 67 | 68 | line_x, line_y = self._get_line_coordinates(prev_vertex_x, prev_vertex_y, coord_x, coord_y) 69 | rew = self._contour_reward(line_x, line_y) 70 | for x, y in zip(line_x, line_y): 71 | self.state_map[1, x, y] = 1 72 | 73 | self.state_map[2, prev_vertex_x, prev_vertex_y] = 0 74 | self.state_map[2, coord_x, coord_y] = 1 75 | 76 | self.last_action = PEN_DOWN 77 | return self._get_state(), rew, False 78 | 79 | else: 80 | rew = self._finish_polygon() 81 | self.last_action = FINISH 82 | return self._get_state(), rew, True 83 | 84 | else: 85 | self.reset() 86 | return self.step(action) 87 | # raise Exception('Environment is done, should have been reset') 88 | 89 | # Returns initial state 90 | def reset(self): 91 | self.curr_image, self.curr_mask = next(self.generator) 92 | if len(self.curr_mask.shape) == 3: 93 | self.curr_mask = self.curr_mask[:,:,0] 94 | assert(self.curr_image.shape == self.img_shape) 95 | assert(self.curr_mask.shape == self.img_shape[:2]) 96 | 97 | mask_outline = feature.canny(self.curr_mask.astype(np.float32), sigma=2).astype(np.float32) 98 | self.curr_blurred_mask = gaussian_filter(mask_outline, self.gaussian_std) - 0.1 99 | self.curr_mask = self.curr_mask.astype(np.bool_) 100 | self.state_map = np.zeros((3, self.img_shape[0], self.img_shape[1]), dtype=np.int16) 101 | 102 | self.last_action = PEN_UP 103 | self.first_vertex = None 104 | 105 | first_state = self._get_state() 106 | return first_state 107 | 108 | def _get_state(self): 109 | return np.concatenate((self.curr_image, np.transpose(self.state_map)), axis=-1) 110 | 111 | def _contour_reward(self, line_x, line_y): 112 | rew = 0.0 113 | for x, y in zip(line_x, line_y): 114 | rew += self.curr_blurred_mask[x, y] 115 | self.curr_blurred_mask[x, y] = 0.0 # Can't get contour reward twice 116 | return rew / self.alpha 117 | 118 | def _region_reward(self): 119 | assert(self.curr_mask.dtype == np.bool_) 120 | mask = self.state_map[1].astype(np.bool_) 121 | intersection = (mask * self.curr_mask).sum() 122 | union = (mask + self.curr_mask).sum() 123 | iou = float(intersection) / float(union) 124 | return iou / self.alpha 125 | 126 | def _get_line_coordinates(self, x0, y0, x1, y1): 127 | length = int(np.hypot(x1 - x0, y1 - y0)) 128 | x, y = np.linspace(x0, x1, length), np.linspace(y0, y1, length) 129 | return x.astype(np.int), y.astype(np.int) 130 | 131 | # Returns contour reward + region reward for a finished polygon 132 | def _finish_polygon(self): 133 | last_x, last_y = np.where(self.state_map[2] == 1) 134 | last_x = last_x[0] 135 | last_y = last_y[0] 136 | last_line_x, last_line_y = self._get_line_coordinates(last_x, last_y, self.first_vertex[0], self.first_vertex[1]) 137 | rew = self._contour_reward(last_line_x, last_line_y) 138 | for x, y in zip(last_line_x, last_line_y): 139 | self.state_map[1, x, y] = 1 140 | # Fill in polygon 141 | self.state_map[1] = binary_fill_holes(self.state_map[1]) 142 | rew += self._region_reward() 143 | 144 | # Add polygon to overall segmentation mask 145 | polys = self.state_map[0] 146 | polys += self.state_map[1] 147 | polys[polys > 1] = 1 148 | 149 | 
self.state_map[1,:,:] = 0 150 | self.state_map[2,:,:] = 0 151 | 152 | return rew 153 | -------------------------------------------------------------------------------- /local_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage.filters import gaussian_filter 3 | from scipy.ndimage.morphology import binary_fill_holes 4 | from skimage import feature 5 | import random 6 | 7 | PEN_DOWN = 2 8 | PEN_UP = 0 9 | FINISH = 1 10 | 11 | class Environment(): 12 | # Pulls in new images with generator_fn 13 | # generator_fn should return a preprocessed image and a segmentation mask 14 | def __init__(self, generator, gaussian_std=2.0, img_shape=(256,256), window_size=32, alpha=0.05, max_line_len=50): 15 | self.generator = generator 16 | self.gaussian_std = gaussian_std 17 | self.img_shape = img_shape 18 | self.window_size = window_size 19 | self.alpha = alpha 20 | self.max_line_len = max_line_len 21 | 22 | self.curr_image = None 23 | self.curr_mask = None 24 | self.curr_blurred_mask = None 25 | self.state_map = None 26 | self.last_action = None 27 | self.first_vertex = None 28 | 29 | self.reset() 30 | 31 | # Returns (new_state, reward, done) 32 | # Action should be int: 0 = pen up, 1 = finish, other = index into array 33 | def step(self, action): 34 | action_class = PEN_DOWN if action > 1 else action 35 | coord_x, coord_y = (-1,-1) 36 | min_x, min_y, _, _ = self._get_window_bounds() 37 | if action_class == PEN_DOWN: 38 | coord_x = min_x + ((action - 2) // self.window_size) 39 | coord_y = min_y + ((action - 2) % self.window_size) 40 | 41 | if self.last_action == PEN_UP: 42 | if action_class == PEN_UP: 43 | return self._get_state(), -1.0, True 44 | elif action_class == PEN_DOWN: 45 | self.first_vertex = (coord_x, coord_y) 46 | self.state_map[2,:,:] = 0 47 | self.state_map[2, coord_x, coord_y] = 1 48 | self.state_map[1, coord_x, coord_y] = 1 49 | rew = self.curr_blurred_mask[coord_x, coord_y] / self.alpha 50 | self.last_action = PEN_DOWN 51 | return self._get_state(), rew, False 52 | else: 53 | self.last_action = FINISH 54 | return self._get_state(), -1.0, True 55 | elif self.last_action == PEN_DOWN: 56 | if action_class == PEN_UP: 57 | rew = self._finish_polygon() 58 | self.first_vertex = None 59 | self.last_action = PEN_UP 60 | return self._get_state(), rew, True 61 | 62 | elif action_class == PEN_DOWN: 63 | prev_vertex_x, prev_vertex_y = self._get_last_state() 64 | 65 | # Penalize illegal placements 66 | #if np.hypot(coord_x - prev_vertex_x, coord_y - prev_vertex_y) > self.max_line_len: 67 | # return self._get_state(), -1, False 68 | 69 | line_x, line_y = self._get_line_coordinates(prev_vertex_x, prev_vertex_y, coord_x, coord_y) 70 | rew = self._contour_reward(line_x, line_y) 71 | for x, y in zip(line_x, line_y): 72 | self.state_map[1, x, y] = 1 73 | 74 | self.state_map[2, prev_vertex_x, prev_vertex_y] = 0 75 | self.state_map[2, coord_x, coord_y] = 1 76 | 77 | self.last_action = PEN_DOWN 78 | return self._get_state(), rew, False 79 | 80 | else: 81 | rew = self._finish_polygon() 82 | self.last_action = FINISH 83 | return self._get_state(), rew, True 84 | 85 | else: 86 | self.reset() 87 | return self.step(action) 88 | # raise Exception('Environment is done, should have been reset') 89 | 90 | # Returns initial state 91 | def reset(self): 92 | self.curr_image, self.curr_mask = next(self.generator) 93 | if len(self.curr_mask.shape) == 3: 94 | self.curr_mask = self.curr_mask[:,:,0] 95 | assert(self.curr_image.shape == 
self.img_shape) 96 | assert(self.curr_mask.shape == self.img_shape[:2]) 97 | 98 | mask_outline = feature.canny(self.curr_mask.astype(np.float32), sigma=2).astype(np.float32) 99 | xs, ys = np.where(mask_outline == 1) 100 | if len(xs) == 0: 101 | return self.reset() 102 | i = random.choice(range(len(xs))) 103 | 104 | 105 | self.curr_blurred_mask = gaussian_filter(mask_outline, self.gaussian_std) 106 | self.curr_mask = self.curr_mask.astype(np.bool_) 107 | self.state_map = np.zeros((3, self.img_shape[0], self.img_shape[1]), dtype=np.int16) 108 | self.state_map[2, xs[i], ys[i]] = 1 109 | 110 | self.last_action = PEN_UP 111 | self.first_vertex = None 112 | 113 | first_state = self._get_state() 114 | return first_state 115 | 116 | def _get_last_state(self): 117 | prev_vertex_x, prev_vertex_y = np.where(self.state_map[2] == 1) 118 | x = prev_vertex_x[0] 119 | y = prev_vertex_y[0] 120 | return x, y 121 | 122 | def _get_window_bounds(self): 123 | x, y = self._get_last_state() 124 | if x < self.window_size / 2: 125 | min_x = max(0, x - self.window_size / 2) 126 | max_x = min_x + self.window_size 127 | else: 128 | max_x = min(x + self.window_size / 2, self.img_shape[0]) 129 | min_x = max_x - self.window_size 130 | if y < self.window_size / 2: 131 | min_y = max(0, y - self.window_size / 2) 132 | max_y = min_y + self.window_size 133 | else: 134 | max_y = min(y + self.window_size / 2, self.img_shape[1]) 135 | min_y = max_y - self.window_size 136 | return int(min_x), int(min_y), int(max_x), int(max_y) 137 | 138 | def _get_state(self): 139 | min_x, min_y, max_x, max_y = self._get_window_bounds() 140 | return np.concatenate((self.curr_image, np.transpose(self.state_map)), axis=-1)[min_x:max_x, min_y:max_y, :] 141 | 142 | def get_full_state(self): 143 | return np.concatenate((self.curr_image, np.transpose(self.state_map)), axis=-1) 144 | 145 | def _contour_reward(self, line_x, line_y): 146 | rew = 0.0 147 | for x, y in zip(line_x, line_y): 148 | rew += self.curr_blurred_mask[x, y] 149 | self.curr_blurred_mask[x, y] = 0.0 # Can't get contour reward twice 150 | return rew / self.alpha 151 | 152 | def _region_reward(self): 153 | assert(self.curr_mask.dtype == np.bool_) 154 | mask = self.state_map[1].astype(np.bool_) 155 | intersection = (mask * self.curr_mask).sum() 156 | union = (mask + self.curr_mask).sum() 157 | iou = float(intersection) / float(union) 158 | return iou / self.alpha 159 | 160 | def _get_line_coordinates(self, x0, y0, x1, y1): 161 | length = int(np.hypot(x1 - x0, y1 - y0)) 162 | x, y = np.linspace(x0, x1, length), np.linspace(y0, y1, length) 163 | return x.astype(np.int), y.astype(np.int) 164 | 165 | # Returns contour reward + region reward for a finished polygon 166 | def _finish_polygon(self): 167 | last_x, last_y = self._get_last_state() 168 | last_line_x, last_line_y = self._get_line_coordinates(last_x, last_y, self.first_vertex[0], self.first_vertex[1]) 169 | rew = self._contour_reward(last_line_x, last_line_y) 170 | for x, y in zip(last_line_x, last_line_y): 171 | self.state_map[1, x, y] = 1 172 | # Fill in polygon 173 | self.state_map[1] = binary_fill_holes(self.state_map[1]) 174 | rew += self._region_reward() 175 | 176 | # Add polygon to overall segmentation mask 177 | polys = self.state_map[0] 178 | polys += self.state_map[1] 179 | polys[polys > 1] = 1 180 | 181 | #self.state_map[1,:,:] = 0 182 | #self.state_map[2,:,:] = 0 183 | 184 | return rew 185 | -------------------------------------------------------------------------------- /dqn_utils.py: 
-------------------------------------------------------------------------------- 1 | """This file includes a collection of utility functions that are useful for 2 | implementing DQN.""" 3 | import gym 4 | import tensorflow as tf 5 | import numpy as np 6 | import random 7 | 8 | def huber_loss(x, delta=1.0): 9 | # https://en.wikipedia.org/wiki/Huber_loss 10 | return tf.where( 11 | tf.abs(x) < delta, 12 | tf.square(x) * 0.5, 13 | delta * (tf.abs(x) - 0.5 * delta) 14 | ) 15 | 16 | def sample_n_unique(sampling_f, n): 17 | """Helper function. Given a function `sampling_f` that returns 18 | comparable objects, sample n such unique objects. 19 | """ 20 | res = [] 21 | while len(res) < n: 22 | candidate = sampling_f() 23 | if candidate not in res: 24 | res.append(candidate) 25 | return res 26 | 27 | class Schedule(object): 28 | def value(self, t): 29 | """Value of the schedule at time t""" 30 | raise NotImplementedError() 31 | 32 | class ConstantSchedule(object): 33 | def __init__(self, value): 34 | """Value remains constant over time. 35 | Parameters 36 | ---------- 37 | value: float 38 | Constant value of the schedule 39 | """ 40 | self._v = value 41 | 42 | def value(self, t): 43 | """See Schedule.value""" 44 | return self._v 45 | 46 | def linear_interpolation(l, r, alpha): 47 | return l + alpha * (r - l) 48 | 49 | class PiecewiseSchedule(object): 50 | def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None): 51 | """Piecewise schedule. 52 | endpoints: [(int, int)] 53 | list of pairs `(time, value)` meanining that schedule should output 54 | `value` when `t==time`. All the values for time must be sorted in 55 | an increasing order. When t is between two times, e.g. `(time_a, value_a)` 56 | and `(time_b, value_b)`, such that `time_a <= t < time_b` then value outputs 57 | `interpolation(value_a, value_b, alpha)` where alpha is a fraction of 58 | time passed between `time_a` and `time_b` for time `t`. 59 | interpolation: lambda float, float, float: float 60 | a function that takes value to the left and to the right of t according 61 | to the `endpoints`. Alpha is the fraction of distance from left endpoint to 62 | right endpoint that t has covered. See linear_interpolation for example. 63 | outside_value: float 64 | if the value is requested outside of all the intervals sepecified in 65 | `endpoints` this value is returned. If None then AssertionError is 66 | raised when outside value is requested. 67 | """ 68 | idxes = [e[0] for e in endpoints] 69 | assert idxes == sorted(idxes) 70 | self._interpolation = interpolation 71 | self._outside_value = outside_value 72 | self._endpoints = endpoints 73 | 74 | def value(self, t): 75 | """See Schedule.value""" 76 | for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]): 77 | if l_t <= t and t < r_t: 78 | alpha = float(t - l_t) / (r_t - l_t) 79 | return self._interpolation(l, r, alpha) 80 | 81 | # t does not belong to any of the pieces, so doom. 82 | assert self._outside_value is not None 83 | return self._outside_value 84 | 85 | class LinearSchedule(object): 86 | def __init__(self, schedule_timesteps, final_p, initial_p=1.0): 87 | """Linear interpolation between initial_p and final_p over 88 | schedule_timesteps. After this many timesteps pass final_p is 89 | returned. 
90 | Parameters 91 | ---------- 92 | schedule_timesteps: int 93 | Number of timesteps for which to linearly anneal initial_p 94 | to final_p 95 | initial_p: float 96 | initial output value 97 | final_p: float 98 | final output value 99 | """ 100 | self.schedule_timesteps = schedule_timesteps 101 | self.final_p = final_p 102 | self.initial_p = initial_p 103 | 104 | def value(self, t): 105 | """See Schedule.value""" 106 | fraction = min(float(t) / self.schedule_timesteps, 1.0) 107 | return self.initial_p + fraction * (self.final_p - self.initial_p) 108 | 109 | def compute_exponential_averages(variables, decay): 110 | """Given a list of tensorflow scalar variables 111 | create ops corresponding to their exponential 112 | averages 113 | Parameters 114 | ---------- 115 | variables: [tf.Tensor] 116 | List of scalar tensors. 117 | Returns 118 | ------- 119 | averages: [tf.Tensor] 120 | List of scalar tensors corresponding to averages 121 | of al the `variables` (in order) 122 | apply_op: tf.runnable 123 | Op to be run to update the averages with current value 124 | of variables. 125 | """ 126 | averager = tf.train.ExponentialMovingAverage(decay=decay) 127 | apply_op = averager.apply(variables) 128 | return [averager.average(v) for v in variables], apply_op 129 | 130 | def minimize_and_clip(optimizer, objective, var_list, clip_val=10): 131 | """Minimized `objective` using `optimizer` w.r.t. variables in 132 | `var_list` while ensure the norm of the gradients for each 133 | variable is clipped to `clip_val` 134 | """ 135 | gradients = optimizer.compute_gradients(objective, var_list=var_list) 136 | for i, (grad, var) in enumerate(gradients): 137 | if grad is not None: 138 | gradients[i] = (tf.clip_by_norm(grad, clip_val), var) 139 | return optimizer.apply_gradients(gradients) 140 | 141 | def initialize_interdependent_variables(session, vars_list, feed_dict): 142 | """Initialize a list of variables one at a time, which is useful if 143 | initialization of some variables depends on initialization of the others. 144 | """ 145 | vars_left = vars_list 146 | while len(vars_left) > 0: 147 | new_vars_left = [] 148 | for v in vars_left: 149 | try: 150 | # If using an older version of TensorFlow, uncomment the line 151 | # below and comment out the line after it. 152 | #session.run(tf.initialize_variables([v]), feed_dict) 153 | session.run(tf.variables_initializer([v]), feed_dict) 154 | except tf.errors.FailedPreconditionError: 155 | new_vars_left.append(v) 156 | if len(new_vars_left) >= len(vars_left): 157 | # This can happend if the variables all depend on each other, or more likely if there's 158 | # another variable outside of the list, that still needs to be initialized. This could be 159 | # detected here, but life's finite. 160 | raise Exception("Cycle in variable dependencies, or extenrnal precondition unsatisfied.") 161 | else: 162 | vars_left = new_vars_left 163 | 164 | def get_wrapper_by_name(env, classname): 165 | currentenv = env 166 | while True: 167 | if classname in currentenv.__class__.__name__: 168 | return currentenv 169 | elif isinstance(env, gym.Wrapper): 170 | currentenv = currentenv.env 171 | else: 172 | raise ValueError("Couldn't find wrapper named %s"%classname) 173 | 174 | class ReplayBuffer(object): 175 | def __init__(self, size): 176 | """This is a memory efficient implementation of the replay buffer. 
177 | adapted for the purposes of running image segmentation via RL 178 | 179 | The sepecific memory optimizations use here are: 180 | - only store each frame once rather than k time 181 | even if every observation normally consists of k last frames 182 | - store frames as np.uint8 (actually it is most time-performance 183 | to cast them back to float32 on GPU to minimize memory transfer 184 | time) 185 | - store frame_t and frame_(t+1) in the same buffer. 186 | 187 | For the typical use case in Atari Deep RL buffer with 1M frames the total 188 | memory footprint of this buffer is 10^6 * 84 * 84 bytes ~= 7 gigabytes 189 | 190 | In our case, given that the dimension of our state space is (256,256,6) 191 | and the dimension of the action space is (3, 50, 50), 192 | the memory footprint for each (s, a, r, d) tuple (since we store s, s' in the same buffer) 193 | is 3.2 MB 194 | 195 | Warning! Assumes that returning frame of zeros at the beginning 196 | of the episode, when there is less frames than `frame_history_len`, 197 | is acceptable. 198 | 199 | Parameters 200 | ---------- 201 | size: int 202 | Max number of transitions to store in the buffer. When the buffer 203 | overflows the old memories are dropped. 204 | """ 205 | 206 | self.size = size 207 | 208 | self.next_idx = 0 209 | self.num_in_buffer = 0 210 | 211 | self.obs = None 212 | self.action = None 213 | self.reward = None 214 | self.done = None 215 | 216 | def can_sample(self, batch_size): 217 | """Returns true if `batch_size` different transitions can be sampled from the buffer.""" 218 | return batch_size + 1 <= self.num_in_buffer 219 | 220 | def get_sample(self, idxes): 221 | obs_batch = np.array([self.obs[idx] for idx in idxes]) 222 | act_batch = self.action[idxes] 223 | rew_batch = self.reward[idxes] 224 | next_obs_batch = np.array([self.obs[idx + 1] for idx in idxes]) 225 | done_mask = np.array([1.0 if self.done[idx] else 0.0 for idx in idxes], dtype=np.float32) 226 | 227 | return obs_batch, act_batch, rew_batch, next_obs_batch, done_mask 228 | 229 | 230 | def sample(self, batch_size): 231 | """Sample `batch_size` different transitions. 232 | 233 | i-th sample transition is the following: 234 | 235 | when observing `obs_batch[i]`, action `act_batch[i]` was taken, 236 | after which reward `rew_batch[i]` was received and subsequent 237 | observation next_obs_batch[i] was observed, unless the epsiode 238 | was done which is represented by `done_mask[i]` which is equal 239 | to 1 if episode has ended as a result of that action. 240 | 241 | Parameters 242 | ---------- 243 | batch_size: int 244 | How many transitions to sample. 245 | 246 | Returns 247 | ------- 248 | obs_batch: np.array 249 | Array of shape 250 | (batch_size, img_h, img_w, img_c + 3) 251 | and dtype np.uint8 252 | act_batch: np.array 253 | Array of shape (batch_size) and dtype np.int32 254 | rew_batch: np.array 255 | Array of shape (batch_size,) and dtype np.float32 256 | next_obs_batch: np.array 257 | Array of shape 258 | (batch_size, img_h, img_w, img_c + 3) 259 | and dtype np.uint8 260 | done_mask: np.array 261 | Array of shape (batch_size,) and dtype np.float32 262 | """ 263 | assert self.can_sample(batch_size) 264 | idxes = sample_n_unique(lambda: random.randint(0, self.num_in_buffer - 2), batch_size) 265 | return self.get_sample(idxes) 266 | 267 | def store_observation(self, new_obs): 268 | """Store a single observation in the buffer at the next available index, overwriting 269 | old frames if necessary. 
270 | 271 | Parameters 272 | ---------- 273 | new_obs: np.array 274 | Array of shape (img_h, img_w, img_c+3) 275 | The observation (image, state maps) 276 | 277 | Returns 278 | ------- 279 | idx: int 280 | Index at which the frame is stored. To be used for `store_effect` later. 281 | """ 282 | if self.obs is None: 283 | print('Obs expected size, ', [self.size] + list(new_obs.shape)) 284 | self.obs = np.empty([self.size] + list(new_obs.shape), dtype=np.uint8) 285 | self.action = np.empty([self.size], dtype=np.uint8) 286 | self.reward = np.empty([self.size], dtype=np.float32) 287 | self.done = np.empty([self.size], dtype=np.bool) 288 | self.obs[self.next_idx] = new_obs 289 | ret = self.next_idx 290 | self.next_idx = (self.next_idx + 1) % self.size 291 | self.num_in_buffer = min(self.size, self.num_in_buffer + 1) 292 | return ret 293 | 294 | def store_effect(self, idx, action, reward, done): 295 | """Store effects of action taken after obeserving frame stored 296 | at index idx. The reason `store_frame` and `store_effect` is broken 297 | up into two functions is so that once can call `encode_recent_observation` 298 | in between. 299 | 300 | Paramters 301 | --------- 302 | idx: int 303 | Index in buffer of recently observed frame (returned by `store_frame`). 304 | action: ((class), x, y) 305 | Tuple representing pen-down/pen-up/draw-finish and the corresponding next pen location 306 | reward: float 307 | Reward that was received when the actions was performed. 308 | done: bool 309 | True if episode was finished after performing that action. 310 | """ 311 | self.action[idx] = action 312 | self.reward[idx] = reward 313 | self.done[idx] = done 314 | 315 | -------------------------------------------------------------------------------- /dqn.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import time 3 | import pickle 4 | import sys 5 | import itertools 6 | import numpy as np 7 | import random 8 | import tensorflow as tf 9 | import tensorflow.contrib.layers as layers 10 | from collections import namedtuple 11 | from dqn_utils import * 12 | import json 13 | import os 14 | import time 15 | import logz 16 | 17 | from convex_hull import ConvexHullPolicy 18 | 19 | OptimizerSpec = namedtuple("OptimizerSpec", ["constructor", "kwargs", "lr_schedule"]) 20 | img_size = 256 21 | window_size = 32 22 | 23 | class QLearner(object): 24 | 25 | def __init__( 26 | self, 27 | env, 28 | q_func, 29 | optimizer_spec, 30 | session, 31 | exploration=LinearSchedule(1000000, 0.1), 32 | total_time_steps=2000000, 33 | replay_buffer_size=1000, 34 | batch_size=32, 35 | gamma=0.99, 36 | learning_starts=500, 37 | learning_freq=4, 38 | target_update_freq=10000, 39 | grad_norm_clipping=10, 40 | pixel_limit=50, 41 | rew_file=None, 42 | double_q=False, 43 | progress_dir=None): 44 | """Run Deep Q-learning algorithm. 45 | 46 | You can specify your own convnet using q_func. 47 | 48 | All schedules are w.r.t. total number of steps taken in the environment. 49 | 50 | Parameters 51 | ---------- 52 | env: Image_Env 53 | Environment to train on. 54 | q_func: function 55 | Model to use for computing the q functions. It should accept the 56 | following named arguments: 57 | img_in: tf.Tensor 58 | tensorflow tensor representing the input image 59 | scope: str 60 | scope in which all the model related variables 61 | should be created 62 | reuse: bool 63 | whether previously created variables should be reused. 
64 | Returns two tensors: 65 | q_class: (3) 66 | q_map: (256, 256) 67 | optimizer_spec: OptimizerSpec 68 | Specifying the constructor and kwargs, as well as learning rate schedule 69 | for the optimizer 70 | session: tf.Session 71 | tensorflow session to use. 72 | exploration: rl_algs.deepq.utils.schedules.Schedule 73 | schedule for probability of chosing random action. 74 | total_time_steps: int 75 | Total time steps we plan to run the RL algorithm for; replaces the stopping criterion 76 | replay_buffer_size: int 77 | How many memories to store in the replay buffer. 78 | batch_size: int 79 | How many transitions to sample each time experience is replayed. 80 | gamma: float 81 | Discount Factor 82 | learning_starts: int 83 | After how many environment steps to start replaying experiences 84 | learning_freq: int 85 | How many steps of environment to take between every experience replay 86 | target_update_freq: int 87 | How many experience replay rounds (not steps!) to perform between 88 | each update to the target Q network 89 | grad_norm_clipping: float or None 90 | If not None gradients' norms are clipped to this value. 91 | pixel_limit: int 92 | Number of pixels we limit the drawing step to 93 | double_q: bool 94 | If True, then use double Q-learning to compute target values. Otherwise, use vanilla DQN. 95 | https://papers.nips.cc/paper/3964-double-q-learning.pdf 96 | progress_dir: str 97 | Place to store logged image+masks for reference (helps if you have to terminate early) 98 | """ 99 | self.target_update_freq = target_update_freq 100 | self.optimizer_spec = optimizer_spec 101 | self.batch_size = batch_size 102 | self.learning_freq = learning_freq 103 | self.learning_starts = learning_starts 104 | self.total_time_steps=total_time_steps 105 | self.env = env 106 | self.session = session 107 | self.exploration = exploration 108 | self.rew_file = str(uuid.uuid4()) + '.pkl' if rew_file is None else rew_file 109 | self.pixel_limit = pixel_limit 110 | self.progress_dir = progress_dir 111 | 112 | self.hull_policy = ConvexHullPolicy(img_size) 113 | 114 | ############### 115 | # BUILD MODEL # 116 | ############### 117 | 118 | 119 | input_shape = (window_size, window_size, 6) 120 | action_dim = 2 + window_size * window_size 121 | 122 | # set up placeholders 123 | # placeholder for current observation (or state) 124 | self.obs_t_ph = tf.placeholder(tf.uint8, [None] + list(input_shape)) 125 | # placeholder for current action 126 | self.act_t_ph = tf.placeholder(tf.int32, [None]) 127 | # placeholder for current reward 128 | self.rew_t_ph = tf.placeholder(tf.float32, [None]) 129 | # placeholder for next observation (or state) 130 | self.obs_tp1_ph = tf.placeholder(tf.uint8, [None] + list(input_shape)) 131 | # placeholder for end of episode mask 132 | # this value is 1 if the next state corresponds to the end of an episode, 133 | # in which case there is no Q-value at the next state; at the end of an 134 | # episode, only the current state reward contributes to the target, not the 135 | # next state Q-value (i.e. target is just rew_t_ph, not rew_t_ph + gamma * q_tp1) 136 | self.done_mask_ph = tf.placeholder(tf.float32, [None]) 137 | 138 | # casting to float on GPU ensures lower data transfer times. 139 | self.obs_t_float = tf.cast(self.obs_t_ph, tf.float32) / 255.0 140 | self.obs_tp1_float = tf.cast(self.obs_tp1_ph, tf.float32) / 255.0 141 | 142 | # Here, you should fill in your own code to compute the Bellman error. 
This requires
143 |         # evaluating the current and next Q-values and constructing the corresponding error.
144 |         # TensorFlow will differentiate this error for you, you just need to pass it to the
145 |         # optimizer. See assignment text for details.
146 |         # Your code should produce one scalar-valued tensor: total_error
147 |         # This will be passed to the optimizer in the provided code below.
148 |         # Your code should also produce two collections of variables:
149 |         # q_func_vars
150 |         # target_q_func_vars
151 |         # These should hold all of the variables of the Q-function network and target network,
152 |         # respectively. A convenient way to get these is to make use of TF's "scope" feature.
153 |         # For example, you can create your Q-function network with the scope "q_func" like this:
154 |         # = q_func(obs_t_float, num_actions, scope="q_func", reuse=False)
155 |         # And then you can obtain the variables like this:
156 |         # q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='q_func')
157 |         # Older versions of TensorFlow may require using "VARIABLES" instead of "GLOBAL_VARIABLES"
158 |         # Tip: use huber_loss (from dqn_utils) instead of squared error when defining self.total_error
159 |         ######
160 | 
161 |         # YOUR CODE HERE
162 |         curr_q_eval = q_func(self.obs_t_float, scope="q_func", reuse=False)
163 |         self.q_action = q_func(self.obs_t_float, 'q_func', reuse=True)
164 |         target_q_action = q_func(self.obs_tp1_float, 'target_func', reuse=False)
165 |         if double_q:
166 |             target_actions = tf.argmax(q_func(self.obs_tp1_float, 'q_func', reuse=True), axis=1, output_type=tf.int32)  # double DQN: the online net picks a' for s', the target net scores it
167 |             action_idx = tf.stack([tf.range(0, tf.shape(target_q_action)[0]), target_actions], axis=1)
168 |             gamma_max_future_q_targets = tf.scalar_mul(gamma, tf.gather_nd(target_q_action, action_idx))
169 |         else:
170 |             gamma_max_future_q_targets = tf.scalar_mul(gamma, tf.reduce_max(target_q_action, axis=1))  # per-sample max over actions
171 | 
172 |         q_targets = tf.stop_gradient(self.rew_t_ph + (1.0 - self.done_mask_ph) * gamma_max_future_q_targets)
173 |         idx = tf.range(0, tf.shape(self.act_t_ph)[0])
174 |         cat_idx = tf.stack([idx, self.act_t_ph], axis=1)
175 |         current_q_values = tf.gather_nd(curr_q_eval, cat_idx)
176 |         self.total_error = tf.reduce_sum(huber_loss(current_q_values - q_targets))
177 |         q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='q_func')
178 |         target_q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_func')
179 |         ######
180 | 
181 |         # construct optimization op (with gradient clipping)
182 |         self.learning_rate = tf.placeholder(tf.float32, (), name="learning_rate")
183 |         optimizer = self.optimizer_spec.constructor(learning_rate=self.learning_rate, **self.optimizer_spec.kwargs)
184 |         self.train_fn = minimize_and_clip(optimizer, self.total_error,
185 |                         var_list=q_func_vars, clip_val=grad_norm_clipping)
186 | 
187 |         # update_target_fn will be called periodically to copy Q network to target Q network
188 |         update_target_fn = []
189 |         for var, var_target in zip(sorted(q_func_vars, key=lambda v: v.name),
190 |                                    sorted(target_q_func_vars, key=lambda v: v.name)):
191 |             update_target_fn.append(var_target.assign(var))
192 |         self.update_target_fn = tf.group(*update_target_fn)
193 | 
194 |         # construct the replay buffer
195 |         self.replay_buffer = ReplayBuffer(replay_buffer_size)
196 |         self.replay_buffer_idx = None
197 | 
198 |         ###############
199 |         # RUN ENV #
200 |         ###############
201 |         self.model_initialized = False
202 |         self.num_param_updates = 0
203 |         self.mean_episode_reward = -float('nan')
204 |         self.best_mean_episode_reward
= -float('inf') 205 | self.last_obs = self.env.reset() 206 | self.log_every_n_steps = 1000 207 | 208 | self.start_time = None 209 | self.t = 0 210 | 211 | ############### 212 | # SETUP LOGGING # 213 | ############### 214 | if not(os.path.exists('data')): 215 | os.makedirs('data') 216 | logdir = 'img-' 217 | logdir = logdir + time.strftime("%d-%m-%Y_%H-%M-%S") 218 | logdir = os.path.join('data', logdir) 219 | logz.configure_output_dir(logdir) 220 | 221 | 222 | def stopping_criterion_met(self): 223 | return self.t >= self.total_time_steps 224 | 225 | def choose_random_action(self, last_obs, epsilon): 226 | # Randomly selects a legal action 227 | epsilon_flip = np.random.uniform() 228 | if epsilon_flip < 1/3: 229 | # Pen down 230 | pen_loc_map = last_obs[:, :, 5] 231 | x_dim, y_dim = pen_loc_map.shape 232 | pen_x, pen_y = next((idx for idx, val in np.ndenumerate(pen_loc_map) if val==1), (-1, -1)) 233 | if pen_x == -1: 234 | # 1 not in last state map => pen up was the last action 235 | x_rnd, y_rnd = np.random.randint(0, x_dim-1), np.random.randint(0, y_dim-1) 236 | else: 237 | x_min, y_min = max(0, pen_x - self.pixel_limit//2), max(0, pen_y - self.pixel_limit//2) 238 | x_max, y_max = min(pen_x + self.pixel_limit//2, x_dim - 1), min(pen_y + self.pixel_limit//2, y_dim - 1) 239 | distance_based_exploration_th = 0.01 240 | if epsilon > distance_based_exploration_th: 241 | # Explore geometrically, giving points further from pen current x and y locations preference, 242 | # weighted by the current exploration factor 243 | x_range, y_range = np.arange(x_min, x_max), np.arange(y_min, y_max) 244 | smoothing_factor = 1 245 | x_weights, y_weights = np.abs(pen_x - x_range)*epsilon + smoothing_factor, np.abs(pen_y - y_range)*epsilon + smoothing_factor 246 | x_p, y_p = x_weights/sum(x_weights), y_weights/sum(y_weights) 247 | x_rnd, y_rnd = np.random.choice(x_range, p=x_p), np.random.choice(y_range, p=y_p) 248 | else: 249 | x_rnd, y_rnd = np.random.randint(x_min, x_max), np.random.randint(y_min, y_max) 250 | return 2 + x_rnd * window_size + y_rnd 251 | elif epsilon_flip < 2/3: 252 | # Pen up 253 | return 0 254 | else: 255 | # Draw finish 256 | return 1 257 | 258 | def step_env(self): 259 | ### 2. Step the env and store the transition 260 | # At this point, "self.last_obs" contains the latest observation that was 261 | # recorded from the simulator. Here, your code needs to store this 262 | # observation and its outcome (reward, next observation, etc.) into 263 | # the replay buffer while stepping the simulator forward one step. 264 | # At the end of this block of code, the simulator should have been 265 | # advanced one step, and the replay buffer should contain one more 266 | # transition. 267 | # Specifically, self.last_obs must point to the new latest observation. 268 | # Useful functions you'll need to call: 269 | # obs, reward, done, info = env.step(action) 270 | # this steps the environment forward one step 271 | # obs = env.reset() 272 | # this resets the environment if you reached an episode boundary. 273 | # Don't forget to call env.reset() to get a new observation if done 274 | # is true!! 275 | # Don't forget to include epsilon greedy exploration! 276 | # And remember that the first time you enter this loop, the model 277 | # may not yet have been initialized (but of course, the first step 278 | # might as well be random, since you haven't trained your net...) 
279 | 
280 |         #####
281 | 
282 |         # YOUR CODE HERE
283 |         buf_idx = self.replay_buffer.store_observation(self.last_obs)
284 |         epsilon = self.exploration.value(self.t)
285 |         if not self.model_initialized:
286 |             # Completely random
287 |             action = self.choose_random_action(self.last_obs, epsilon)
288 |             #action = self.hull_policy.get_action(self.last_obs, self.env.curr_mask)
289 |         else:
290 |             epsilon_flip = np.random.binomial(1, epsilon)
291 |             if epsilon_flip == 1:
292 |                 action = self.choose_random_action(self.last_obs, epsilon)
293 |             else:
294 |                 q_values = self.session.run(self.q_action, {self.obs_t_ph: np.expand_dims(self.last_obs, axis=0)})  # reuse the existing op; building tf.squeeze here would add new graph nodes every step
295 |                 action = np.argmax(q_values)
296 |         obs, reward, done = self.env.step(action)
297 |         self.replay_buffer.store_effect(buf_idx, action, reward, done)
298 |         if done:
299 |             self.last_obs = self.env.reset()
300 |         else:
301 |             self.last_obs = obs
302 | 
303 |     def update_model(self):
304 |         ### 3. Perform experience replay and train the network.
305 |         # note that this is only done if the replay buffer contains enough samples
306 |         # for us to learn something useful -- until then, the model will not be
307 |         # initialized and random actions should be taken
308 |         if (self.t > self.learning_starts and \
309 |                 self.t % self.learning_freq == 0 and \
310 |                 self.replay_buffer.can_sample(self.batch_size)):
311 | 
312 |             # Here, you should perform training. Training consists of four steps:
313 |             # 3.a: use the replay buffer to sample a batch of transitions (see the
314 |             # replay buffer code for function definition, each batch that you sample
315 |             # should consist of current observations, current actions, rewards,
316 |             # next observations, and done indicator).
317 |             # 3.b: initialize the model if it has not been initialized yet; to do
318 |             # that, call
319 |             # initialize_interdependent_variables(self.session, tf.global_variables(), {
320 |             #     self.obs_t_ph: obs_t_batch,
321 |             #     self.obs_tp1_ph: obs_tp1_batch,
322 |             # })
323 |             # where obs_t_batch and obs_tp1_batch are the batches of observations at
324 |             # the current and next time step. The boolean variable model_initialized
325 |             # indicates whether or not the model has been initialized.
326 |             # Remember that you have to update the target network too (see 3.d)!
327 |             # 3.c: train the model. To do this, you'll need to use the self.train_fn and
328 |             # self.total_error ops that were created earlier: self.total_error is what you
329 |             # created to compute the total Bellman error in a batch, and self.train_fn
330 |             # will actually perform a gradient step and update the network parameters
331 |             # to reduce total_error.
When calling self.session.run on these you'll need to 332 | # populate the following placeholders: 333 | # self.obs_t_ph 334 | # self.act_t_ph 335 | # self.rew_t_ph 336 | # self.obs_tp1_ph 337 | # self.done_mask_ph 338 | # (this is needed for computing self.total_error) 339 | # self.learning_rate -- you can get this from self.optimizer_spec.lr_schedule.value(t) 340 | # (this is needed by the optimizer to choose the learning rate) 341 | # 3.d: periodically update the target network by calling 342 | # self.session.run(self.update_target_fn) 343 | # you should update every target_update_freq steps, and you may find the 344 | # variable self.num_param_updates useful for this (it was initialized to 0) 345 | ##### 346 | 347 | # YOUR CODE HERE 348 | obs_t_batch, act_t_batch, rew_t_batch, obs_tp1_batch, done_mask = self.replay_buffer.sample(self.batch_size) 349 | if not self.model_initialized: 350 | initialize_interdependent_variables(self.session, tf.global_variables(), 351 | { 352 | self.obs_t_ph: obs_t_batch, 353 | self.obs_tp1_ph: obs_tp1_batch, 354 | }) 355 | self.model_initialized = True 356 | self.session.run([self.total_error, self.train_fn], { 357 | self.obs_t_ph: obs_t_batch, 358 | self.act_t_ph: act_t_batch, 359 | self.rew_t_ph: rew_t_batch, 360 | self.obs_tp1_ph: obs_tp1_batch, 361 | self.done_mask_ph: done_mask, 362 | self.learning_rate: self.optimizer_spec.lr_schedule.value(self.t) 363 | }) 364 | self.num_param_updates += 1 365 | if self.num_param_updates % self.target_update_freq == 0: 366 | self.session.run(self.update_target_fn) 367 | self.t += 1 368 | 369 | def log_progress(self): 370 | if self.t % self.log_every_n_steps == 0 and self.model_initialized: 371 | log_episodes = 5 372 | episodes = [self.predict(self.env) for i in range(log_episodes)] 373 | episode_results = [prediction[0] for prediction in episodes] 374 | episode_rewards = [prediction[1] for prediction in episodes] 375 | episode_lengths = [prediction[2] for prediction in episodes] 376 | self.mean_episode_reward = sum(episode_rewards)/len(episode_rewards) 377 | self.best_mean_episode_reward = max(self.best_mean_episode_reward, self.mean_episode_reward) 378 | print("Timestep %d" % (self.t,)) 379 | print("mean reward (5 episodes) %f" % self.mean_episode_reward) 380 | print("best mean reward %f" % self.best_mean_episode_reward) 381 | print("exploration %f" % self.exploration.value(self.t)) 382 | print("learning_rate %f" % self.optimizer_spec.lr_schedule.value(self.t)) 383 | print("Episode lengths ", episode_lengths) 384 | if self.start_time is not None: 385 | print("running time %f" % ((time.time() - self.start_time) / 60.)) 386 | self.start_time = time.time() 387 | logz.log_tabular("Timestep", self.t) 388 | logz.log_tabular("Mean Reward (5 episodes)", self.mean_episode_reward) 389 | logz.log_tabular("Best Mean Reward", self.best_mean_episode_reward) 390 | logz.dump_tabular() 391 | sys.stdout.flush() 392 | with open(self.rew_file, 'wb') as f: 393 | pickle.dump(episode_rewards, f, pickle.HIGHEST_PROTOCOL) 394 | if self.progress_dir is None: 395 | return 396 | for count, result in enumerate(episode_results): 397 | result_file_name = "result_" + str(count) + "_t_" + str(self.t) + ".npy" 398 | np.save('%s/%s'%(self.progress_dir, result_file_name), result) 399 | 400 | 401 | def predict(self, test_env): 402 | # Runs the prediction algorithm on one image, and returns [img_c_1, img_c_2, img_c_3, img_mask], reward 403 | # since we no longer need the last two, and to keep the reward for logging purposes 404 | done = False 405 | 
self.last_obs = test_env.reset()
406 |         count = 0
407 |         reward_sum = 0
408 |         while not done and count < 100:
409 |             q_values = self.session.run(self.q_action, {self.obs_t_ph: np.expand_dims(self.last_obs, axis=0)})  # reuse the existing op instead of building tf.squeeze each call
410 |             action = np.argmax(q_values)
411 |             obs, reward, done = test_env.step(action)
412 |             self.last_obs = obs
413 |             reward_sum += reward
414 |             count += 1
415 |         if not done:
416 |             # Run a pen finish
417 |             action = 1
418 |             obs, reward, done = test_env.step(action)
419 |             self.last_obs = obs
420 |             reward_sum += reward
421 |         #return self.last_obs[:,:,:4], reward_sum
422 |         return test_env.get_full_state(), reward_sum, count
423 | 
424 |     def test(self, test_env, num_test_samples):
425 |         results, rewards = [], []
426 |         for sample in range(num_test_samples):
427 |             curr_result, curr_reward, _ = self.predict(test_env)
428 |             results.append(curr_result)
429 |             rewards.append(curr_reward)
430 |         return results, rewards
431 | 
432 | 
433 | def learn(*args, **kwargs):
434 |     alg = QLearner(*args, **kwargs)
435 |     while not alg.stopping_criterion_met():
436 |         alg.step_env()
437 |         # at this point, the environment should have been advanced one step (and
438 |         # reset if done was true), and self.last_obs should point to the new latest
439 |         # observation
440 |         alg.update_model()
441 |         alg.log_progress()
442 |     return alg
443 | 
444 | 
445 | 
446 | 
447 | 
--------------------------------------------------------------------------------
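Note on the action space shared by env.py, local_env.py, convex_hull.py and dqn.py: an action is a single integer, where 0 means pen up, 1 means finish, and 2 + x * width + y means draw a line to pixel (x, y); in local_env.py the (x, y) pair is relative to a 32 x 32 window around the current pen position, which is why build_unet outputs 2 + H*W q-values. The sketch below illustrates the encoding and the window-to-absolute conversion; it is plain Python with no TensorFlow dependency, and the helper names (encode_pen_down, decode_action, window_origin) are illustrative only, not part of the repository.

# Standalone illustration of the flat action encoding used by env.py / local_env.py.
PEN_UP, FINISH = 0, 1  # same integer codes as in the environments

def encode_pen_down(x, y, width):
    # Draw-to-(x, y) actions are packed after the two special actions.
    return 2 + x * width + y

def decode_action(action, width):
    # Returns ('pen_up',), ('finish',) or ('draw', x, y).
    if action == PEN_UP:
        return ('pen_up',)
    if action == FINISH:
        return ('finish',)
    return ('draw', (action - 2) // width, (action - 2) % width)

def window_origin(pen_x, pen_y, window_size, img_shape):
    # Mirrors local_env.Environment._get_window_bounds: the window is clamped so it
    # stays inside the image; its top-left corner converts a window-relative draw
    # action into absolute pixel coordinates.
    def origin(c, limit):
        if c < window_size / 2:
            return max(0, c - window_size // 2)
        return min(c + window_size // 2, limit) - window_size
    return int(origin(pen_x, img_shape[0])), int(origin(pen_y, img_shape[1]))

if __name__ == '__main__':
    # env.py / convex_hull.py: actions index the full 256 x 256 image.
    a = encode_pen_down(10, 20, 256)
    print(a, decode_action(a, 256))    # 2582 ('draw', 10, 20)

    # local_env.py / dqn.py: actions index a 32 x 32 window around the current pen position.
    min_x, min_y = window_origin(200, 5, 32, (256, 256))
    _, dx, dy = decode_action(encode_pen_down(7, 3, 32), 32)
    print(min_x + dx, min_y + dy)      # 191 3 -> absolute pixel the pen moves to

Flattening the two special actions and the per-pixel actions into one vector is what lets a single argmax over the network output (as in step_env and predict) choose both the pen state and the target pixel at once.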