├── .idea ├── .gitignore ├── vcs.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── other.xml ├── modules.xml ├── RLPycharmProj.iml └── misc.xml ├── figures_for_README ├── Pong_16.gif ├── Pong_21.gif ├── Breakout_425.gif ├── Breakout_756.gif ├── Iterations:1200000-Time:04-08-2020-22-06-38.jpg └── Iterations:5000000-Time:03-30-2020-02-57-22.jpg ├── __pycache__ ├── DQNs.cpython-35.pyc ├── DQNs.cpython-36.pyc ├── utils.cpython-35.pyc ├── utils.cpython-36.pyc ├── sumtree.cpython-36.pyc ├── EnvManagers.cpython-35.pyc └── EnvManagers.cpython-36.pyc ├── Agent.py ├── plot_result.py ├── DDQN_params.json ├── sumtree.py ├── TestonHeldout.py ├── evaluation.py ├── CartPole.py ├── README.md ├── DQNs.py ├── EnvManagers.py ├── continue_training.py ├── Q_Learning_IceLake.ipynb ├── Breakout.py └── utils.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml -------------------------------------------------------------------------------- /figures_for_README/Pong_16.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Pong_16.gif -------------------------------------------------------------------------------- /figures_for_README/Pong_21.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Pong_21.gif -------------------------------------------------------------------------------- /__pycache__/DQNs.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/DQNs.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/DQNs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/DQNs.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/sumtree.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/sumtree.cpython-36.pyc -------------------------------------------------------------------------------- /figures_for_README/Breakout_425.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Breakout_425.gif -------------------------------------------------------------------------------- 
/figures_for_README/Breakout_756.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Breakout_756.gif -------------------------------------------------------------------------------- /__pycache__/EnvManagers.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/EnvManagers.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/EnvManagers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/EnvManagers.cpython-36.pyc -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /figures_for_README/Iterations:1200000-Time:04-08-2020-22-06-38.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Iterations:1200000-Time:04-08-2020-22-06-38.jpg -------------------------------------------------------------------------------- /figures_for_README/Iterations:5000000-Time:03-30-2020-02-57-22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Iterations:5000000-Time:03-30-2020-02-57-22.jpg -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/RLPycharmProj.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /Agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | class Agent(): 4 | def __init__(self, strategy, num_actions, device): 5 | self.current_step = 0 6 | self.strategy = strategy 7 | self.num_actions = num_actions 8 | self.device = device 9 | 10 | def select_action(self, state, policy_net): 11 | rate = self.strategy.get_exploration_rate(self.current_step) 12 | self.current_step += 1 13 | # print("eps = ",rate) 14 | if rate > random.random(): 15 | action = 
random.randrange(self.num_actions) 16 | return torch.tensor([action]).to(self.device) # explore 17 | else: 18 | with torch.no_grad(): 19 | return policy_net(state).argmax(dim=1).to(self.device) # exploit 20 | -------------------------------------------------------------------------------- /plot_result.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import torch.optim as optim 3 | import time 4 | # customized import 5 | from DQNs import * 6 | from utils import * 7 | from EnvManagers import BreakoutEnvManager 8 | from Agent import * 9 | 10 | print("="*100) 11 | print("loading mid point file...") 12 | result_pkl_path = "./Results/ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28.pkl" 13 | with open(result_pkl_path, 'rb') as fresult: 14 | tracker_dict = pickle.load(fresult) 15 | plt.plot(tracker_dict["rewards_hist"]) 16 | plt.show() 17 | 18 | print("="*100) 19 | print("loading mid point file...") 20 | result_pkl_path = "./Results/ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28-Eval.pkl" 21 | with open(result_pkl_path, 'rb') as fresult: 22 | tracker_dict = pickle.load(fresult) 23 | plt.plot(tracker_dict["eval_reward_list"]) 24 | plt.show() 25 | 26 | -------------------------------------------------------------------------------- /DDQN_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "config": { 3 | "MODEL_NAME": "DQN_CNN_2015", 4 | "GAME_NAME": "Breakout", 5 | "GAME_ENV": "BreakoutDeterministic-v4", 6 | "IS_USE_ADDITIONAL_ENDING_CRITERION": false, 7 | "CHECK_POINT_PATH": "./checkpoints/", 8 | "FIGURES_PATH": "./figures/", 9 | "MIDDLE_POINT_PATH": "./MiddlePoints/", 10 | "TENSORBOARD_PATH": "./Tensorboard/", 11 | "RESULT_PATH": "./Results/", 12 | "HELDOUT_SET_DIR": "./Heldout/", 13 | "HELDOUT_SET_MAX_PER_BATCH": 5000, 14 | "HELDOUT_SAVE_RATE": 0.005, 15 | "MAX_ITERATION": 100000, 16 | "MAX_FRAMES": 22000000, 17 | "UPDATE_PER_CHECKPOINT": 20000, 18 | "DATE_FORMAT": "%m-%d-%Y-%H-%M-%S", 19 | "IS_GENERATE_HELDOUT": false, 20 | "IS_SAVE_MIDDLE_POINT": false, 21 | "IS_VISUALIZE_STATE": false, 22 | "IS_RENDER_GAME_PROCESS": true, 23 | "IS_BREAK_BY_MAX_FRAMES": false, 24 | "IS_USED_TENSORBOARD": false, 25 | "IS_BREAK_BY_MAX_ITERATION": true, 26 | "IS_USED_PER": false 27 | }, 28 | "hyperparams": { 29 | "batch_size": 32, 30 | "gamma": 0.99, 31 | "eps_start": 1.0, 32 | "eps_end": 0.1, 33 | "eps_final": 0.01, 34 | "eps_startpoint": 50000, 35 | "eps_kneepoint": 1000000, 36 | "eps_final_knee": 22000000, 37 | "alpha": 0.6, 38 | "beta_start": 0.4, 39 | "beta_startpoint": 50000, 40 | "beta_kneepoint": 1000000, 41 | "error_epsilon": 1e-5, 42 | "lr_PER": 0.000005, 43 | "target_update": 2500, 44 | "memory_size": 1000000, 45 | "lr": 0.00005, 46 | "num_episodes": 100000, 47 | "replay_start_size": 50000, 48 | "action_repeat": 4 49 | }, 50 | "eval": { 51 | "EVAL_EPISODE": 10, 52 | "GIF_SAVE_PATH": "./GIF_Reuslts/", 53 | "ACTIONS_PER_EVAL": 100000, 54 | "EVAL_MODEL_LIST_TXT_PATH": "./eval_model_list_txt/", 55 | "IS_RENDER_GAME_PROCESS": true, 56 | "IS_VISUALIZE_STATE": false, 57 | "IS_USE_ADDITIONAL_ENDING_CRITERION": true 58 | } 59 | } -------------------------------------------------------------------------------- /sumtree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # store the priorities of experience 3 | # the priorities of memory samples are stored in leaf node, and the value of parent node is the 
sum of its children 4 | class Sumtree(): 5 | def __init__(self, capacity): 6 | self.capacity = capacity 7 | self.tree = np.zeros(2*capacity - 1) # store the priorities of memory 8 | self.tree[capacity - 1] = 1 9 | self.stored = [False] * (2*capacity - 1) # indicate whether this node is used to store 10 | # self.cur_point = 0 11 | self.length = 0 # maximum length is capacity 12 | self.push_count = 0 13 | 14 | def update_node(self, index, change): 15 | # update sum tree from leaf node if the priority of leaf node changed 16 | parent = (index-1)//2 17 | self.tree[parent] += change 18 | self.stored[parent] = True 19 | if parent > 0: 20 | self.update_node(parent, change) 21 | 22 | def update(self, index_memory, p): 23 | # update sum tree from new priority 24 | index = index_memory + self.capacity - 1 25 | change = p - self.tree[index] 26 | self.tree[index] = p 27 | self.stored[index] = True 28 | self.update_node(index, change) 29 | 30 | def get_p_total(self): 31 | # return total priorities 32 | return self.tree[0] 33 | 34 | def get_p_min(self): 35 | return min(self.tree[self.capacity-1:self.length+self.capacity-1]) 36 | 37 | def get_by_priority(self, index, s): 38 | # get index of node by priority s 39 | left_child = index*2 + 1 40 | right_child = index*2 + 2 41 | if left_child >= self.tree.shape[0]: 42 | return index 43 | if self.stored[left_child] == False: 44 | return self.get_by_priority(right_child, s-self.tree[left_child]) 45 | if self.stored[right_child] == False: 46 | return self.get_by_priority(left_child, s) 47 | if s <= self.tree[left_child]: 48 | return self.get_by_priority(left_child, s) 49 | else: 50 | return self.get_by_priority(right_child, s-self.tree[left_child]) 51 | 52 | def sample(self, s): 53 | # sample node by priority s, return the index and priority of experience 54 | self.stored[self.length + self.capacity - 2] = False # cannot sample the latest state 55 | index = self.get_by_priority(0, s) 56 | return index - self.capacity + 1, self.tree[index] 57 | 58 | def push(self): 59 | # push experience, the initial priority is the maximum priority in sum tree 60 | index_memory = self.push_count % self.capacity 61 | if self.length < self.capacity: 62 | self.length += 1 63 | self.update(index_memory, np.max(self.tree[self.capacity-1 : self.capacity+self.length-1])) 64 | self.push_count += 1 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /TestonHeldout.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import torch.optim as optim 3 | import time 4 | # customized import 5 | from DQNs import * 6 | from utils import * 7 | from EnvManagers import AtariEnvManager 8 | from Agent import * 9 | 10 | 11 | param_json_fname = "DDQN_params.json" 12 | # checkpoints to evaluate 13 | model_list_fname = "./eval_model_list_txt/ModelName_2015_CNN_DQN-GameName_Breakout-Time_03-30-2020-02-57-36.txt" 14 | 15 | config_dict, hyperparams_dict, eval_dict = read_json(param_json_fname) 16 | 17 | # if there is not the txt file to store name of checkpoints, they can be directly loaded from folder by following code 18 | """ model_list = os.listdir(config_dict["CHECK_POINT_PATH"]+config_dict["GAME_NAME"]) 19 | model_list = [config_dict["CHECK_POINT_PATH"]+config_dict["GAME_NAME"]+'/'+x for x in model_list] """ 20 | 21 | # load model list 22 | with open(model_list_fname) as f: 23 | model_list = f.readlines() 24 | model_list = [x.strip() for x in model_list] # remove whitespace characters like `\n` 
at the end of each line 25 | 26 | subfolder = model_list_fname.split("/")[-1][:-4] 27 | # get the update iterations of each checkpoint 28 | iterations = [int(x.split("/")[-1].split("_")[1].split("-")[0])/config_dict["UPDATE_PER_CHECKPOINT"] for x in model_list] 29 | iterations.sort() 30 | model_list.sort(key = lambda x: int(x.split("/")[-1].split("_")[1].split("-")[0])) 31 | 32 | # set environment 33 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 34 | em = AtariEnvManager(device, config_dict["GAME_ENV"], config_dict["IS_USE_ADDITIONAL_ENDING_CRITERION"]) 35 | 36 | # set policy net 37 | if config_dict["MODEL_NAME"] == "DQN_CNN_2015": 38 | policy_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=True).to(device) 39 | target_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=True).to(device) 40 | elif config_dict["MODEL_NAME"] == "Dueling_DQN_2016_Modified": 41 | policy_net = Dueling_DQN_2016_Modified(num_classes=em.num_actions_available(), init_weights=True).to(device) 42 | target_net = Dueling_DQN_2016_Modified(num_classes=em.num_actions_available(), init_weights=True).to(device) 43 | else: 44 | print("No such model! Please check your configuration in .json file") 45 | 46 | # Auxilary variables 47 | tracker_dict = {} 48 | tracker_dict["UPDATE_PER_CHECKPOINT"] = config_dict["UPDATE_PER_CHECKPOINT"] 49 | tracker_dict["Qvalue_average_list"] = [] 50 | for model_fpath in model_list: 51 | print("testing: ",model_fpath) 52 | # load model from file 53 | policy_net.load_state_dict(torch.load(model_fpath)) 54 | policy_net.eval() 55 | Qvalue_model = [] 56 | # load heldout set 57 | hfiles = os.listdir(config_dict["HELDOUT_SET_DIR"]) 58 | for hfile in hfiles: 59 | with open(config_dict["HELDOUT_SET_DIR"]+'/'+hfile,'rb') as f: 60 | heldout_set = pickle.load(f) 61 | for state in heldout_set: 62 | # compute Qvalue by loaded model 63 | with torch.no_grad(): 64 | Qvalue = policy_net(state.float().cuda()/255).detach().max(dim=1)[0].cpu().item() 65 | Qvalue_model.append(Qvalue) 66 | tracker_dict["Qvalue_average_list"].append(sum(Qvalue_model)/len(Qvalue_model)) 67 | 68 | 69 | if not os.path.exists(config_dict["RESULT_PATH"]): 70 | os.makedirs(config_dict["RESULT_PATH"]) 71 | 72 | # save the figure 73 | plt.figure() 74 | plt.plot(iterations, tracker_dict["Qvalue_average_list"]) 75 | plt.title("Average Q on held out set") 76 | plt.xlabel("Training iterations") 77 | plt.ylabel("Average Q value") 78 | plt.savefig(config_dict["RESULT_PATH"] + subfolder + "Average_Q_value.jpg") 79 | 80 | tracker_fname = subfolder + "-Eval.pkl" 81 | with open(config_dict["RESULT_PATH"] + tracker_fname, 'wb') as f: 82 | pickle.dump(tracker_dict, f) 83 | 84 | 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import torch.optim as optim 3 | import time 4 | # customized import 5 | from DQNs import * 6 | from utils import * 7 | from EnvManagers import AtariEnvManager 8 | from Agent import * 9 | 10 | 11 | param_json_fname = "DDQN_params.json" 12 | model_list_fname = "./eval_model_list_txt/ModelName:Dueling_DQN_2016_Modified-GameName:Pong-Time:04-07-2020-18-30-30.txt" #TODO 13 | 14 | config_dict, hyperparams_dict, eval_dict = read_json(param_json_fname) 15 | 16 | with open(model_list_fname) as f: 17 | model_list = f.readlines() 18 | model_list = [x.strip() for x in model_list] # remove whitespace 
characters like `\n` at the end of each line 19 | subfolder = model_list_fname.split("/")[-1][:-4] 20 | # setup params 21 | 22 | # Auxilary variables 23 | tracker_dict = {} 24 | tracker_dict["UPDATE_PER_CHECKPOINT"] = config_dict["UPDATE_PER_CHECKPOINT"] 25 | tracker_dict["eval_reward_list"] = [] 26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | # evaluate each checkpoint model 28 | for model_fpath in model_list: 29 | print("testing: ",model_fpath) 30 | 31 | # load model from file 32 | em = AtariEnvManager(device,game_env=config_dict["GAME_ENV"], 33 | is_use_additional_ending_criterion= eval_dict["IS_USE_ADDITIONAL_ENDING_CRITERION"]) 34 | if config_dict["MODEL_NAME"] == "DQN_CNN_2015": 35 | policy_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=False).to(device) 36 | elif config_dict["MODEL_NAME"] == "Dueling_DQN_2016_Modified": 37 | policy_net = Dueling_DQN_2016_Modified(num_classes=em.num_actions_available(), init_weights=False).to(device) 38 | else: 39 | print("No such model! Please check your configuration in .json file") 40 | policy_net.load_state_dict(torch.load(model_fpath)) 41 | policy_net.eval() # this network will only be used for inference. 42 | # setup greedy strategy and Agent class 43 | strategy = FullGreedyStrategy(0.01) 44 | agent = Agent(strategy, em.num_actions_available(), device) 45 | 46 | best_frames_for_gif = None 47 | best_reward = -99999 48 | 49 | reward_list_episodes = [] 50 | for episode in range(eval_dict["EVAL_EPISODE"]): 51 | em.reset() 52 | state = em.get_state() # initialize sate 53 | reward_per_episode = 0 54 | frames_for_gif = [] 55 | 56 | while(1): 57 | if eval_dict["IS_RENDER_GAME_PROCESS"]: em.env.render() #render in 'human mode. BZX: will this slow down the speed? 
58 | if eval_dict["IS_VISUALIZE_STATE"]: visualize_state(state) 59 | frame = em.render('rgb_array') 60 | frames_for_gif.append(frame) 61 | # Given s, select a by strategy 62 | if em.done or em.is_additional_ending or em.is_initial_action(): 63 | action = torch.tensor([1]) 64 | else: 65 | action = agent.select_action(state, policy_net) 66 | # take action 67 | reward = em.take_action(action) 68 | # collect unclipped reward from env along the action 69 | reward_per_episode += reward 70 | # after took a, get s' 71 | next_state = em.get_state() 72 | # update current state 73 | state = next_state 74 | 75 | if em.done: 76 | reward_list_episodes.append(reward_per_episode.cpu().item()) 77 | break 78 | #write gif 79 | if reward_per_episode > best_reward: 80 | best_reward = reward_per_episode 81 | #update best frames list 82 | best_frames_for_gif = frames_for_gif.copy() 83 | 84 | tracker_dict["eval_reward_list"].append(np.median(reward_list_episodes)) 85 | # print( tracker_dict["eval_reward_list"]) 86 | # save results 87 | if best_reward > 0.8 * np.median(tracker_dict["eval_reward_list"]): 88 | model_name = model_fpath.split("/")[-1][:-4] 89 | generate_gif(eval_dict["GIF_SAVE_PATH"] + subfolder + "/", model_name, best_frames_for_gif, best_reward.cpu().item()) 90 | 91 | if not os.path.exists(config_dict["RESULT_PATH"]): 92 | os.makedirs(config_dict["RESULT_PATH"]) 93 | 94 | tracker_fname = subfolder + "-Eval.pkl" 95 | with open(config_dict["RESULT_PATH"] + tracker_fname, 'wb') as f: 96 | pickle.dump(tracker_dict, f) 97 | 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /CartPole.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import math 3 | import random 4 | import numpy as np 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | import os 8 | import sys 9 | import torch.nn.functional as F 10 | import datetime 11 | from itertools import count 12 | import torch.nn.functional as F 13 | from PIL import Image 14 | import torch 15 | import torch.optim as optim 16 | import torchvision.transforms as T 17 | # customized import 18 | from DQNs import DQN 19 | from utils import * 20 | from EnvManagers import CartPoleEnvManager 21 | import pickle 22 | 23 | 24 | 25 | class Agent(): 26 | def __init__(self, strategy, num_actions, device): 27 | self.current_step = 0 28 | self.strategy = strategy 29 | self.num_actions = num_actions 30 | self.device = device 31 | 32 | def select_action(self, state, policy_net): 33 | rate = self.strategy.get_exploration_rate(self.current_step) 34 | self.current_step += 1 35 | 36 | if rate > random.random(): 37 | action = random.randrange(self.num_actions) 38 | return torch.tensor([action]).to(self.device) # explore 39 | else: 40 | with torch.no_grad(): 41 | return policy_net(state).argmax(dim=1).to(self.device) # exploit 42 | 43 | # Configuration: 44 | CHECK_POINT_PATH = "./checkpoints/" 45 | FIGURES_PATH = "./figures/" 46 | GAME_NAME = "CartPole" 47 | DATE_FORMAT = "%m-%d-%Y-%H-%M-%S" 48 | EPISODES_PER_CHECKPOINT = 1000 49 | 50 | # Hyperparameters 51 | batch_size = 32 52 | gamma = 0.99 53 | eps_start = 1 54 | eps_end = 0.1 55 | # eps_decay = 0.001 56 | eps_kneepoint = 500000 57 | target_update = 10 58 | memory_size = 100000 59 | lr = 0.0005 60 | num_episodes = 100000 61 | 62 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 63 | em = CartPoleEnvManager(device) 64 | strategy = EpsilonGreedyStrategyLinear(eps_start, eps_end, 
eps_kneepoint) 65 | agent = Agent(strategy, em.num_actions_available(), device) 66 | memory = ReplayMemory(memory_size) 67 | 68 | policy_net = DQN(em.get_screen_height(), em.get_screen_width()).to(device) 69 | target_net = DQN(em.get_screen_height(), em.get_screen_width()).to(device) 70 | target_net.load_state_dict(policy_net.state_dict()) 71 | target_net.eval() # this network will only be used for inference. 72 | optimizer = optim.Adam(params=policy_net.parameters(), lr=lr) 73 | criterion = torch.nn.SmoothL1Loss() 74 | 75 | heldoutset_counter = 0 76 | HELD_OUT_SET = [] 77 | episode_durations = [] 78 | running_reward = 0 79 | plt.figure() 80 | # for episode in range(num_episodes): 81 | # em.reset() 82 | # state = em.get_state() 83 | # 84 | # for timestep in count(): 85 | # action = agent.select_action(state, policy_net) 86 | # reward = em.take_action(action) 87 | # next_state = em.get_state() 88 | # if random.random() < 0.005: 89 | # HELD_OUT_SET.append(next_state.cpu().numpy()) 90 | # if len(HELD_OUT_SET) == 2000: 91 | # heldoutset_file = open('heldoutset-CartPole-{}'.format(heldoutset_counter), 'wb') 92 | # pickle.dump(HELD_OUT_SET, heldoutset_file) 93 | # heldoutset_file.close() 94 | # HELD_OUT_SET = [] 95 | # heldoutset_counter += 1 96 | # 97 | # memory.push(Experience(state, action, next_state, reward)) 98 | # state = next_state 99 | # 100 | # if memory.can_provide_sample(batch_size): 101 | # experiences = memory.sample(batch_size) 102 | # states, actions, rewards, next_states = extract_tensors(experiences) 103 | # 104 | # current_q_values = QValues.get_current(policy_net, states, actions) 105 | # next_q_values = QValues.get_next(target_net, next_states) 106 | # target_q_values = (next_q_values * gamma) + rewards 107 | # 108 | # # loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1)) 109 | # loss = criterion(current_q_values, target_q_values.unsqueeze(1)) 110 | # optimizer.zero_grad() 111 | # loss.backward() 112 | # optimizer.step() 113 | # 114 | # if em.done: 115 | # episode_durations.append(timestep) 116 | # running_reward = plot(episode_durations, 100) 117 | # break 118 | # 119 | # # BZX: checkpoint model 120 | # if episode % EPISODES_PER_CHECKPOINT == 0: 121 | # path = CHECK_POINT_PATH + GAME_NAME + "/" 122 | # if not os.path.exists(path): 123 | # os.makedirs(path) 124 | # torch.save(policy_net.state_dict(), 125 | # path + "Episodes:{}-Reward:{:.2f}-Time:".format(episode, running_reward) + \ 126 | # datetime.datetime.now().strftime(DATE_FORMAT) + ".pth") 127 | # plt.savefig(FIGURES_PATH + "Episodes:{}-Time:".format(episode) + datetime.datetime.now().strftime( 128 | # DATE_FORMAT) + ".jpg") 129 | # 130 | # if episode % target_update == 0: 131 | # target_net.load_state_dict(policy_net.state_dict()) 132 | # 133 | # em.close() 134 | em.get_state() 135 | for i in range(5): 136 | em.take_action(torch.tensor([1])) 137 | screen = em.get_state() 138 | 139 | plt.figure() 140 | plt.imshow(screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none') 141 | plt.title('Non starting state example') 142 | plt.show() 143 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | # RL-Atari-gym 7 | Reinforcement Learning on Atari Games and Control 8 | 9 | Entrance of program: 10 | - Breakout.py 11 | 12 | 13 | # How to run 14 | 15 | (1). Check DDQN_params.json, make sure that every parameter is set right. 
16 | 
17 | ```markdown
18 | GAME_NAME # Set the game's name. This will help you create a new dir to save your results.
19 | MODEL_NAME # Set the algorithm and model you are using. This is only used for renaming your result file, so you still need
20 | to change the model instance manually.
21 | MAX_ITERATION # In the original paper, this is set to 25,000,000. Here we set it to 5,000,000 for Breakout (2,500,000 will suffice for Pong).
22 | num_episodes # Max number of episodes. We set it to a huge number by default, so this stop condition
23 | usually won't be satisfied.
24 | # The program stops when one of the above conditions is met.
25 | ```
26 | (2). Select the **model** and **game environment instance** manually. Currently, we are mainly focusing on `DQN_CNN_2015` and `Dueling_DQN_2016_Modified`.
27 | 
28 | (3). Run and pray :)
29 | 
30 | NOTE: When the program is running, wait a couple of minutes and take a look at the estimated time printed in the
31 | console. Stop early and decrease `MAX_ITERATION` if you cannot wait that long. (Recommendation: typically,
32 | 24h is a reasonable running time for your first training run. Since you can continue training your model later, give
33 | both yourself and your computer a rest and check the saved figures to see whether your model has a promising future. Hope so ~ )
34 | 
35 | # How to continue training the model
36 | 
37 | Breakout.py will automatically save the mid-point state and variables for you if the program exits without an exception.
38 | 
39 | 1. Set the middle_point_json file path.
40 | 
41 | 2. Check DDQN_params.json and make sure that every parameter is set right. Typically, you need to set a new `MAX_ITERATION`
42 | or `num_episodes`.
43 | 
44 | 3. Run and pray :)
45 | 
46 | # How to evaluate the Model
47 | 
48 | `evaluation.py` helps you evaluate the model. First, modify `param_json_fname` and `model_list_fname` to point to your
49 | directory. Second, change the game environment instance and the model instance. Then run.
50 | 
51 | # Results Structure
52 | 
53 | The program will automatically create a directory structure like this:
54 | 
55 | ```markdown
56 | ├── GIF_Reuslts
57 | │   └── ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28
58 | │   ├── Iterations:100000-Reward:0.69-Time:03-28-2020-18-20-27-EvalReward:0.0.gif
59 | │   ├── Iterations:200000-Reward:0.69-Time:03-28-2020-18-20-27-EvalReward:1.0.gif
60 | ├── Results
61 | │   ├── ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28-Eval.pkl
62 | │   └── ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28.pkl
63 | ├── DDQN_params.json
64 | 
65 | ```
66 | 
67 | Please zip these three files/folders and upload them to our shared Google Drive. Rename the archive, e.g. `ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28`.
68 | 
69 | PS:
70 | 
71 | `GIF_Reuslts` records the game process.
72 | 
73 | `Results` contains the training and evaluation history, which can be visualized later.
74 | 
75 | `DDQN_params.json` contains your algorithm settings, which should match your `Results` and `GIF_Reuslts`.
76 | 
77 | # TODO list:
78 | -[ ] Write env classes for Pong, CartPole and other games. (Attention: the cropped bbox might need to be changed for different
79 | games.)
80 | 
81 | -[ ] Write a validation script on heldout sets. Load models and heldout sets, and track the average max Q value on the heldout sets.
82 | (NOTE: load and test models in the time sequence indicated by the model file names.)
83 | 
84 | -[ ] Design the experiment table. Test two more Atari games. Give average performance (reward) and write .gif files. Store other figures & models for
85 | writing the final report.
86 | 
87 | -[ ] Implement policy gradient for Atari games. [TBD]
88 | 
89 | -[ ] Possible bugs: initialization & final state representation?
90 | 
91 | -[x] Implement Priority Queue and compare the performance.
92 | 
93 | -[x] Evaluation script. Load a model and use a greedy strategy to interact with the environment.
94 | Test a few episodes and give the average performance. Write the best one to a .gif file for presentation.
95 | 
96 | -[x] Implement continue-training script.
97 | 
98 | -[x] Fix eps manager (change final eps from 0.1 to 0.01); add an evaluation step in the training loop;
99 | write test results into gifs; update Target_Net according to the # of actions instead of the # of updates.
100 | Rewrite the image preprocessing class to handle more general game input (crop at (34,0,160,160)).
101 | 
102 | # File Instructions (TO BE COMPLETED SOON):
103 | 
104 | 'EnvManagers.py' includes the environment classes for the different games. They wrap gym.env and its interface.
105 | 
106 | 'DQNs.py' includes the different deep learning architectures used for feature extraction and regression.
107 | 
108 | 'utils.py' includes tool functions and classes. To be specific, it includes:
109 | - Experience (namedtuple)
110 | - ReplayMemory (class)
111 | - EpsilonGreedyStrategy (class)
112 | - plot (func)
113 | - extract_tensors (func)
114 | - QValues (class)
115 | 
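As a quick orientation, the sketch below shows how these pieces fit together in a single evaluation episode, following the pattern used in `evaluation.py`. It is only a minimal sketch: it assumes `read_json` and `FullGreedyStrategy` are exported by `utils.py` (as the wildcard imports in the scripts suggest), and the checkpoint path is a placeholder.

```python
import torch
from DQNs import DQN_CNN_2015
from EnvManagers import AtariEnvManager
from Agent import Agent
from utils import read_json, FullGreedyStrategy  # assumed to live in utils.py

config_dict, hyperparams_dict, eval_dict = read_json("DDQN_params.json")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
em = AtariEnvManager(device, config_dict["GAME_ENV"],
                     eval_dict["IS_USE_ADDITIONAL_ENDING_CRITERION"])

policy_net = DQN_CNN_2015(num_classes=em.num_actions_available(), init_weights=False).to(device)
policy_net.load_state_dict(torch.load("./checkpoints/Breakout/some_checkpoint.pth"))  # placeholder path
policy_net.eval()

agent = Agent(FullGreedyStrategy(0.01), em.num_actions_available(), device)
em.reset()
state = em.get_state()
episode_reward = 0.0
while not em.done:
    # fire to (re)start the ball, as evaluation.py does at the first action and after a life loss
    if em.is_additional_ending or em.is_initial_action():
        action = torch.tensor([1])
    else:
        action = agent.select_action(state, policy_net)
    episode_reward += em.take_action(action).item()
    state = em.get_state()
em.close()
print("episode reward:", episode_reward)
```

For the full checkpoint-by-checkpoint evaluation (median reward over several episodes, .gif export), use `evaluation.py` as described above.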
--------------------------------------------------------------------------------
/DQNs.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | 
5 | class DQN_NN_Naive(nn.Module):
6 |     def __init__(self, img_height, img_width):
7 |         super().__init__()
8 | 
9 |         self.fc1 = nn.Linear(in_features=img_height * img_width * 3, out_features=24)
10 |         self.fc2 = nn.Linear(in_features=24, out_features=32)
11 |         self.out = nn.Linear(in_features=32, out_features=2)
12 | 
13 |     def forward(self, t):
14 |         t = t.flatten(start_dim=1)
15 |         t = F.relu(self.fc1(t))
16 |         t = F.relu(self.fc2(t))
17 |         t = self.out(t)
18 |         return t
19 | 
20 | 
21 | class DQN_CNN_2013(nn.Module):
22 |     def __init__(self, num_classes=4, init_weights=True):
23 |         super().__init__()
24 | 
25 |         self.cnn = nn.Sequential(nn.Conv2d(4, 16, kernel_size=8, stride=4),
26 |                                  nn.ReLU(True),
27 |                                  nn.Conv2d(16, 32, kernel_size=4, stride=2),
28 |                                  nn.ReLU(True)
29 |                                  )
30 |         self.classifier = nn.Sequential(nn.Linear(9*9*32, 256),
31 |                                         nn.ReLU(True),
32 |                                         nn.Linear(256, num_classes)
33 |                                         )
34 |         # nn.Dropout(0.3), # BZX: optional [TRY]
35 |         if init_weights:
36 |             self._initialize_weights()
37 | 
38 |     def forward(self, x):
39 |         x = self.cnn(x)
40 |         x = torch.flatten(x, start_dim=1)
41 |         x = self.classifier(x)
42 |         return x
43 | 
44 |     def _initialize_weights(self):
45 |         for m in self.modules():
46 |             if isinstance(m, nn.Conv2d):
47 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
48 |                 if m.bias is not None:
49 |                     nn.init.constant_(m.bias, 0.0)
50 |             elif isinstance(m, nn.BatchNorm2d):
51 |                 nn.init.constant_(m.weight, 1.0)
52 |                 nn.init.constant_(m.bias, 0.0)
53 |             elif isinstance(m, nn.Linear):
54 |                 nn.init.normal_(m.weight, 0.0, 0.01)
55 |                 nn.init.constant_(m.bias, 0.0)
56 | 
57 | class DQN_CNN_2015(nn.Module):
58 |     def __init__(self, num_classes=4, init_weights=True):
59 |         super().__init__()
60 | 
61 |         self.cnn = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4),
62 |                                  nn.ReLU(True),
63 |                                  nn.Conv2d(32, 64, kernel_size=4, stride=2),
64 |                                  nn.ReLU(True),
65 |                                  nn.Conv2d(64, 64, kernel_size=3, stride=1),
66 |                                  nn.ReLU(True)
67 |                                  )
68 |         self.classifier = nn.Sequential(nn.Linear(7*7*64, 512),
69 |                                         nn.ReLU(True),
70 |                                         nn.Linear(512, num_classes)
71 |                                         )
72 |         if init_weights:
73 |             self._initialize_weights()
74 | 
75 |     def forward(self, x):
76 |         x = self.cnn(x)
77 |         x = torch.flatten(x, start_dim=1)
78 |         x = self.classifier(x)
79 |         return x
80 | 
81 |     def _initialize_weights(self):
82 |         for m in self.modules():
83 |             if isinstance(m, nn.Conv2d):
84 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
85 |                 if m.bias is not None:
86 |                     nn.init.constant_(m.bias, 0.0)
87 |             elif isinstance(m, nn.BatchNorm2d):
88 |                 nn.init.constant_(m.weight, 1.0)
89 |                 nn.init.constant_(m.bias, 0.0)
90 |             elif isinstance(m, nn.Linear):
91 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
92 |                 nn.init.constant_(m.bias, 0.0)
93 | 
94 | #TODO
95 | class Dueling_DQN_2016_Modified(nn.Module):
96 |     def __init__(self, num_classes=4, init_weights=True):
97 |         super().__init__()
98 | 
99 |         self.cnn = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4, bias=False),
100 |                                  nn.ReLU(True),
101 |                                  nn.Conv2d(32, 64, kernel_size=4, stride=2, bias=False),
102 |                                  nn.ReLU(True),
103 |                                  nn.Conv2d(64, 64, kernel_size=3, stride=1, bias=False),
104 |                                  nn.ReLU(True),
105 |                                  nn.Conv2d(64, 1024, kernel_size=7, stride=1, bias=False),
106 |                                  nn.ReLU(True)
107 |                                  )
108 |         self.streamA = nn.Linear(512, num_classes)
109 |         self.streamV = nn.Linear(512, 1)
110 | 
111 |         if init_weights:
112 |             self._initialize_weights()
113 | 
114 |     def forward(self, x):
115 |         x = self.cnn(x)
116 |         sA, sV = torch.split(x, 512, dim=1)
117 |         sA = torch.flatten(sA, start_dim=1)
118 |         sV = torch.flatten(sV, start_dim=1)
119 |         sA = self.streamA(sA)  # (B,4)
120 |         sV = self.streamV(sV)  # (B,1)
121 |         # combine these two values
122 |         Q_value = sV + (sA - torch.mean(sA, dim=1, keepdim=True))
123 |         return Q_value  # (B,4)
124 | 
125 |     def _initialize_weights(self):
126 |         for m in self.modules():
127 |             if isinstance(m, nn.Conv2d):
128 |                 nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
129 |                 if m.bias is not None:
130 |                     nn.init.constant_(m.bias, 0.0)
131 |             elif isinstance(m, nn.BatchNorm2d):
132 |                 nn.init.constant_(m.weight, 1.0)
133 |                 nn.init.constant_(m.bias, 0.0)
134 |             elif isinstance(m, nn.Linear):
135 |                 nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
136 |                 nn.init.constant_(m.bias, 0.0)
137 | 
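# Note on the dueling head above: the final 7x7 convolution produces 1024 channels at spatial
# size 1x1; torch.split then yields two 512-d streams, an advantage stream A(s, a) and a value
# stream V(s), which are recombined as Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)), where
# subtracting the mean keeps the decomposition identifiable.
# A quick shape-check sketch, assuming the (batch, 4, 84, 84) stacked-frame states produced by
# the preprocessing in EnvManagers.py:
if __name__ == "__main__":
    dummy_state = torch.zeros(2, 4, 84, 84)  # batch of 2 states, 4 stacked 84x84 frames
    for Net in (DQN_CNN_2013, DQN_CNN_2015, Dueling_DQN_2016_Modified):
        q_values = Net(num_classes=4, init_weights=True)(dummy_state)
        print(Net.__name__, tuple(q_values.shape))  # expected: (2, 4), one Q-value per action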
--------------------------------------------------------------------------------
/EnvManagers.py:
--------------------------------------------------------------------------------
1 | """
2 | 1. Wraps the gym env object.
3 | 2. Like a dataloader, this class does the image preprocessing.
4 | """ 5 | import gym 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as T 9 | 10 | class AtariEnvManager(): 11 | def __init__(self, device, game_env, is_use_additional_ending_criterion): 12 | "avaliable game env: PongDeterministic-v4, BreakoutDeterministic-v4" 13 | self.device = device 14 | self.game_env = game_env 15 | # PongDeterministic-v4 16 | self.env = gym.make(game_env).unwrapped 17 | # self.env = gym.make('BreakoutDeterministic-v4').unwrapped #BZX: v4 automatically return skipped screens 18 | self.env.reset() 19 | self.current_screen = None 20 | self.done = False 21 | # BZX: running_K: stacked K images together to present a state 22 | # BZx: running_queue: maintain the latest running_K images 23 | self.running_K = 4 24 | self.running_queue = [] 25 | self.is_additional_ending = False # This may change to True along the game. 2 possible reason: loss of lives; negative reward. 26 | self.current_lives = None 27 | self.is_use_additional_ending_criterion = is_use_additional_ending_criterion 28 | 29 | def reset(self): 30 | self.env.reset() 31 | self.current_screen = None 32 | self.running_queue = [] #BZX: clear the state 33 | self.is_additional_ending = False 34 | # self.current_lives = 5 35 | 36 | def close(self): 37 | self.env.close() 38 | 39 | def render(self, mode='human'): 40 | return self.env.render(mode) 41 | 42 | def num_actions_available(self): 43 | return self.env.action_space.n 44 | 45 | def print_action_meanings(self): 46 | print(self.env.get_action_meanings()) 47 | 48 | def take_action(self, action): 49 | _, reward, self.done, lives = self.env.step(action.item()) 50 | # for Pong: lives is always 0.0, so self.is_additional_ending will always be false 51 | if self.is_use_additional_ending_criterion: 52 | self.is_additonal_ending_criterion_met(lives,reward) 53 | return torch.tensor([reward], device=self.device) 54 | 55 | def just_starting(self): 56 | return self.current_screen is None 57 | def is_initial_action(self): 58 | return sum(self.running_queue).sum() == 0 59 | 60 | def init_running_queue(self): 61 | """ 62 | initialize running queue with K black images 63 | :return: 64 | """ 65 | self.current_screen = self.get_processed_screen() 66 | black_screen = torch.zeros_like(self.current_screen) 67 | for _ in range(self.running_K): 68 | self.running_queue.append(black_screen) 69 | 70 | def get_state(self): 71 | if self.just_starting(): 72 | self.init_running_queue() 73 | elif self.done or self.is_additional_ending: 74 | self.current_screen = self.get_processed_screen() 75 | black_screen = torch.zeros_like(self.current_screen) 76 | # BZX: update running_queue 77 | self.running_queue.pop(0) 78 | self.running_queue.append(black_screen) 79 | else: #BZX: normal case 80 | s2 = self.get_processed_screen() 81 | self.current_screen = s2 82 | # BZx: update running_queue with s2 83 | self.running_queue.pop(0) 84 | self.running_queue.append(s2) 85 | 86 | return torch.stack(self.running_queue,dim=1).squeeze(2) #BZX: check if shape is (1KHW) 87 | 88 | def is_additonal_ending_criterion_met(self,lives,reward): 89 | "for different atari game, design different ending state criterion" 90 | if self.game_env == "BreakoutDeterministic-v4": 91 | if self.is_initial_action(): 92 | self.current_lives = lives['ale.lives'] 93 | elif lives['ale.lives'] < self.current_lives: 94 | self.is_additional_ending = True 95 | else: 96 | self.is_additional_ending = False 97 | self.current_lives = lives['ale.lives'] 98 | if self.game_env == "PongDeterministic-v4": 99 | if reward < 0: #miss 
one ball will lead to a ending sate. 100 | self.is_additional_ending = True 101 | else: 102 | self.is_additional_ending = False 103 | return False 104 | 105 | def get_screen_height(self): 106 | screen = self.get_processed_screen() 107 | return screen.shape[2] 108 | 109 | def get_screen_width(self): 110 | screen = self.get_processed_screen() 111 | return screen.shape[3] 112 | 113 | def get_processed_screen(self): 114 | screen = self.render('rgb_array').transpose((2, 0, 1)) # PyTorch expects CHW 115 | # screen = self.crop_screen(screen) 116 | return self.transform_screen_data(screen) #shape is [1,1,110,84] 117 | 118 | def crop_screen(self, screen): 119 | if self.game_env == "BreakoutDeterministic-v4" or self.game_env == "PongDeterministic-v4": 120 | bbox = [34,0,160,160] #(x,y,delta_x,delta_y) 121 | screen = screen[:, bbox[0]:bbox[2]+bbox[0], bbox[1]:bbox[3]+bbox[1]] #BZX:(CHW) 122 | if self.game_env == "GopherDeterministic-v4": 123 | bbox = [110,0,120,160] 124 | screen = screen[:, bbox[0]:bbox[2]+bbox[0], bbox[1]:bbox[3]+bbox[1]] 125 | return screen 126 | 127 | def transform_screen_data(self, screen): 128 | # Convert to float, rescale, convert to tensor 129 | screen = np.ascontiguousarray(screen, dtype=np.float32) / 255 130 | screen = torch.from_numpy(screen) 131 | screen = self.crop_screen(screen) 132 | # Use torchvision package to compose image transforms 133 | resize = T.Compose([ 134 | T.ToPILImage() 135 | , T.Grayscale() 136 | , T.Resize((84, 84)) #BZX: the original paper's settings:(110,84), however, for simplicty we... 137 | , T.ToTensor() 138 | ]) 139 | # add a batch dimension (BCHW) 140 | screen = resize(screen) 141 | 142 | return screen.unsqueeze(0).to(self.device) # BZX: Pay attention to the shape here. should be [1,1,84,84] 143 | 144 | 145 | class CartPoleEnvManager(): 146 | 147 | def __init__(self, device): 148 | self.device = device 149 | self.env = gym.make('CartPole-v0').unwrapped 150 | self.env.reset() 151 | self.current_screen = None 152 | self.done = False 153 | 154 | def reset(self): 155 | self.env.reset() 156 | self.current_screen = None 157 | 158 | def close(self): 159 | self.env.close() 160 | 161 | def render(self, mode='human'): 162 | return self.env.render(mode) 163 | 164 | def num_actions_available(self): 165 | return self.env.action_space.n 166 | 167 | def take_action(self, action): 168 | _, reward, self.done, _ = self.env.step(action.item()) 169 | return torch.tensor([reward], device=self.device) 170 | 171 | def just_starting(self): 172 | return self.current_screen is None 173 | 174 | def get_state(self): 175 | if self.just_starting() or self.done: 176 | self.current_screen = self.get_processed_screen() 177 | black_screen = torch.zeros_like(self.current_screen) 178 | return black_screen 179 | else: 180 | s1 = self.current_screen 181 | s2 = self.get_processed_screen() 182 | self.current_screen = s2 183 | return s2 - s1 184 | 185 | def get_screen_height(self): 186 | screen = self.get_processed_screen() 187 | return screen.shape[2] 188 | 189 | def get_screen_width(self): 190 | screen = self.get_processed_screen() 191 | return screen.shape[3] 192 | 193 | def get_processed_screen(self): 194 | screen = self.render('rgb_array').transpose((2, 0, 1)) # PyTorch expects CHW 195 | screen = self.crop_screen(screen) 196 | return self.transform_screen_data(screen) 197 | 198 | def crop_screen(self, screen): 199 | screen_height = screen.shape[1] 200 | 201 | # Strip off top and bottom 202 | top = int(screen_height * 0.4) 203 | bottom = int(screen_height * 0.8) 204 | screen = 
screen[:, top:bottom, :] 205 | return screen 206 | 207 | def transform_screen_data(self, screen): 208 | # Convert to float, rescale, convert to tensor 209 | screen = np.ascontiguousarray(screen, dtype=np.float32) / 255 210 | screen = torch.from_numpy(screen) 211 | 212 | # Use torchvision package to compose image transforms 213 | resize = T.Compose([ 214 | T.ToPILImage() 215 | , T.Resize((40, 90)) 216 | , T.ToTensor() 217 | ]) 218 | 219 | return resize(screen).unsqueeze(0).to(self.device) # add a batch dimension (BCHW) -------------------------------------------------------------------------------- /continue_training.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import torch.optim as optim 3 | import time 4 | # customized import 5 | from DQNs import * 6 | from utils import * 7 | from EnvManagers import BreakoutEnvManager 8 | from Agent import * 9 | 10 | # load params 11 | param_json_fname = "DDQN_params.json" #TODO: please make sure the params are set right 12 | config_dict, hyperparams_dict = read_json(param_json_fname) 13 | 14 | # load middle point file path 15 | print("="*100) 16 | print("loading mid point file...") 17 | Middle_Point_json = "tmp_middle_point_file_path.json" #TODO 18 | md_path_dict = load_Middle_Point(Middle_Point_json) 19 | 20 | heldout_saver = HeldoutSaver(config_dict["HELDOUT_SET_DIR"], 21 | config_dict["HELDOUT_SET_MAX_PER_BATCH"], 22 | config_dict["HELDOUT_SAVE_RATE"]) 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | em = BreakoutEnvManager(device) 25 | 26 | # load other states 27 | with open(md_path_dict["mdStateFileName"], 'rb') as middle_point_state_file: 28 | midddle_point = pickle.load(middle_point_state_file) 29 | agent = midddle_point["agent"] 30 | tracker_dict = midddle_point["tracker_dict"] 31 | heldout_saver.set_batch_counter(midddle_point["heldout_batch_counter"]) 32 | strategy = midddle_point["strategy"] 33 | 34 | # load memory 35 | with open(md_path_dict["mdMemFileName"], 'rb') as middle_point_mem_file: 36 | memory = pickle.load(middle_point_mem_file) 37 | 38 | # load 2 networks 39 | policy_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=False).to(device) 40 | target_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=False).to(device) 41 | policy_net.load_state_dict(torch.load(md_path_dict["md_Policy_Net_fName"])) 42 | target_net.load_state_dict(torch.load(md_path_dict["md_Target_Net_fName"])) 43 | target_net.eval() # this network will only be used for inference. 44 | 45 | print("successfully load all middle point files") 46 | print("="*100) 47 | 48 | # initialize optimizer and criterion 49 | optimizer = optim.Adam(params=policy_net.parameters(), lr=hyperparams_dict["lr"]) 50 | criterion = torch.nn.SmoothL1Loss() 51 | 52 | plt.figure() 53 | 54 | t1,t2 = time.time(),time.time() # for estimating the time 55 | num_target_update = 0 # auxillary variable for estimating the time 56 | 57 | for episode in range(hyperparams_dict["num_episodes"]): 58 | em.reset() 59 | state = em.get_state() # initialize sate 60 | tol_reward = 0 61 | while(1): 62 | # Visualization of game process and state 63 | if config_dict["IS_RENDER_GAME_PROCESS"]: em.env.render() # BZX: will this slow down the speed? 
64 | if config_dict["IS_VISUALIZE_STATE"]: visualize_state(state) 65 | if config_dict["IS_GENERATE_HELDOUT"]: heldout_saver.append(state) # generate heldout set for offline eval 66 | 67 | # Given s, select a by either policy_net or random 68 | action = agent.select_action(state, policy_net) 69 | # collect reward from env along the action 70 | reward = em.take_action(action) 71 | tol_reward += reward 72 | # after took a, get s' 73 | next_state = em.get_state() 74 | # push (s,a,s',r) into memory 75 | memory.push(Experience(state[0,-1,:,:].clone(), action, "", reward)) 76 | # update current state 77 | state = next_state 78 | 79 | # After memory have been filled with enough samples, we update policy_net every 4 agent steps. 80 | if (agent.current_step % hyperparams_dict["action_repeat"] == 0) and \ 81 | memory.can_provide_sample(hyperparams_dict["batch_size"], hyperparams_dict["replay_start_size"]): 82 | 83 | experiences = memory.sample(hyperparams_dict["batch_size"]) 84 | states, actions, rewards, next_states = extract_tensors(experiences) 85 | current_q_values = QValues.get_current(policy_net, states, actions) # checked 86 | # next_q_values = QValues.DQN_get_next(target_net, next_states) # for DQN 87 | next_q_values = QValues.DDQN_get_next(policy_net,target_net, next_states) 88 | target_q_values = (next_q_values * hyperparams_dict["gamma"]) + rewards 89 | # calculate loss and update policy_net 90 | optimizer.zero_grad() 91 | loss = criterion(current_q_values, target_q_values.unsqueeze(1)) 92 | loss.backward() 93 | optimizer.step() 94 | 95 | tracker_dict["loss_hist"].append(loss.item()) 96 | tracker_dict["minibatch_updates_counter"] += 1 97 | 98 | # update target_net 99 | if tracker_dict["minibatch_updates_counter"] % hyperparams_dict["target_update"] == 0: 100 | target_net.load_state_dict(policy_net.state_dict()) 101 | 102 | # estimate time 103 | num_target_update += 1 104 | if num_target_update % 2 == 0: t1 = time.time() 105 | if num_target_update % 2 == 1: t2 = time.time() 106 | print("=" * 50) 107 | remaining_update_times = (config_dict["MAX_ITERATION"] - tracker_dict["minibatch_updates_counter"])// \ 108 | hyperparams_dict["target_update"] 109 | time_sec = abs(t1-t2) * remaining_update_times 110 | print("estimated remaining time = {}h-{}min".format(time_sec//3600,(time_sec%3600)//60)) 111 | print("len of replay memory:", len(memory.memory)) 112 | print("minibatch_updates_counter = ", tracker_dict["minibatch_updates_counter"]) 113 | print("current_step of agent = ", agent.current_step) 114 | print("exploration rate = ", strategy.get_exploration_rate(agent.current_step)) 115 | print("=" * 50) 116 | 117 | # save checkpoint model 118 | if tracker_dict["minibatch_updates_counter"] % config_dict["UPDATE_PER_CHECKPOINT"] == 0: 119 | save_model(policy_net, tracker_dict, config_dict) 120 | 121 | plt.savefig(config_dict["FIGURES_PATH"] + "Iterations:{}-Time:".format(tracker_dict["minibatch_updates_counter"]) + datetime.datetime.now().strftime( 122 | config_dict["DATE_FORMAT"]) + ".jpg") 123 | 124 | if em.done: 125 | tracker_dict["rewards_hist"].append(tol_reward) 126 | tracker_dict["running_reward"] = plot(tracker_dict["rewards_hist"], 100) 127 | break 128 | if config_dict["IS_BREAK_BY_MAX_ITERATION"] and \ 129 | tracker_dict["minibatch_updates_counter"] > config_dict["MAX_ITERATION"]: 130 | break 131 | 132 | em.close() 133 | # save loss figure 134 | plt.figure() 135 | plt.plot(tracker_dict["loss_hist"]) 136 | plt.title("loss") 137 | plt.xlabel("iterations") 138 | 
plt.savefig(config_dict["FIGURES_PATH"] + "Loss-Iterations:{}-Time:".format(tracker_dict["minibatch_updates_counter"]) + datetime.datetime.now().strftime( 139 | config_dict["DATE_FORMAT"]) + ".jpg") 140 | 141 | if config_dict["IS_SAVE_MIDDLE_POINT"]: 142 | # save core instances 143 | if not os.path.exists(config_dict["MIDDLE_POINT_PATH"]): 144 | os.makedirs(config_dict["MIDDLE_POINT_PATH"]) 145 | 146 | mdMemFileName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Memory_" + datetime.datetime.now().strftime( 147 | config_dict["DATE_FORMAT"]) + ".pkl" 148 | middle_mem_file = open(mdMemFileName, 'wb') 149 | pickle.dump(memory, middle_mem_file) 150 | middle_mem_file.close() 151 | del memory # make more memory space 152 | 153 | midddle_point = {} 154 | midddle_point["agent"] = agent 155 | midddle_point["tracker_dict"] = tracker_dict 156 | midddle_point["heldout_batch_counter"] = heldout_saver.batch_counter 157 | midddle_point["strategy"] = strategy 158 | mdStateFileName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_State_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pkl" 159 | 160 | middle_point_file = open(mdStateFileName, 'wb') 161 | pickle.dump(midddle_point, middle_point_file) 162 | middle_point_file.close() 163 | 164 | # save policy_net and target_net 165 | md_Policy_Net_fName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Policy_Net_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pth" 166 | torch.save(policy_net.state_dict(),md_Policy_Net_fName) 167 | md_Target_Net_fName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Target_Net_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pth" 168 | torch.save(policy_net.state_dict(), md_Target_Net_fName) 169 | 170 | # save middle point files' path for continuous training 171 | md_path_dict = {} 172 | md_path_dict["mdMemFileName"] = mdMemFileName 173 | md_path_dict["mdStateFileName"] = mdStateFileName 174 | md_path_dict["md_Policy_Net_fName"] = md_Policy_Net_fName 175 | md_path_dict["md_Target_Net_fName"] = md_Target_Net_fName 176 | with open('tmp_middle_point_file_path.json', 'w') as fp: 177 | json.dump(md_path_dict, fp) -------------------------------------------------------------------------------- /Q_Learning_IceLake.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": { 7 | "collapsed": true, 8 | "pycharm": { 9 | "is_executing": false 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "import numpy as np\n", 15 | "import gym\n", 16 | "import random\n", 17 | "import time\n", 18 | "from IPython.display import clear_output\n" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 10, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "text": [ 28 | "[[0. 0. 0. 0.]\n", 29 | " [0. 0. 0. 0.]\n", 30 | " [0. 0. 0. 0.]\n", 31 | " [0. 0. 0. 0.]\n", 32 | " [0. 0. 0. 0.]\n", 33 | " [0. 0. 0. 0.]\n", 34 | " [0. 0. 0. 0.]\n", 35 | " [0. 0. 0. 0.]\n", 36 | " [0. 0. 0. 0.]\n", 37 | " [0. 0. 0. 0.]\n", 38 | " [0. 0. 0. 0.]\n", 39 | " [0. 0. 0. 0.]\n", 40 | " [0. 0. 0. 0.]\n", 41 | " [0. 0. 0. 0.]\n", 42 | " [0. 0. 0. 0.]\n", 43 | " [0. 0. 0. 
0.]]\n" 44 | ], 45 | "output_type": "stream" 46 | } 47 | ], 48 | "source": [ 49 | "env = gym.make(\"FrozenLake-v0\")\n", 50 | "action_space_size = env.action_space.n\n", 51 | "state_space_size = env.observation_space.n\n", 52 | "\n", 53 | "q_table = np.zeros((state_space_size, action_space_size))\n", 54 | "print(q_table)" 55 | ], 56 | "metadata": { 57 | "collapsed": false, 58 | "pycharm": { 59 | "name": "#%%\n", 60 | "is_executing": false 61 | } 62 | } 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 11, 67 | "outputs": [], 68 | "source": [ 69 | "num_episodes = 20000\n", 70 | "max_steps_per_episode = 100\n", 71 | "\n", 72 | "learning_rate = 0.1\n", 73 | "discount_rate = 0.99\n", 74 | "\n", 75 | "exploration_rate = 1\n", 76 | "max_exploration_rate = 1\n", 77 | "min_exploration_rate = 0.01\n", 78 | "exploration_decay_rate = 0.01" 79 | ], 80 | "metadata": { 81 | "collapsed": false, 82 | "pycharm": { 83 | "name": "#%% fff\n", 84 | "is_executing": false 85 | } 86 | } 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "source": [ 91 | "### Q-learning algorithm\n", 92 | "```\n", 93 | "for episode in range(num_episodes):\n", 94 | " # initialize new episode params\n", 95 | " \n", 96 | " for step in range(max_steps_per_episode): \n", 97 | " # Exploration-exploitation trade-off\n", 98 | " # Take new action\n", 99 | " # Update Q-table\n", 100 | " # Set new state\n", 101 | " # Add new reward \n", 102 | "\n", 103 | " # Exploration rate decay \n", 104 | " # Add current episode reward to total rewards list\n", 105 | "```" 106 | ], 107 | "metadata": { 108 | "collapsed": false 109 | } 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 12, 114 | "outputs": [], 115 | "source": [ 116 | "rewards_all_episodes = []\n", 117 | "for episode in range(num_episodes):\n", 118 | " state = env.reset()\n", 119 | " done = False\n", 120 | " rewards_current_episode = 0\n", 121 | "\n", 122 | " for step in range(max_steps_per_episode): \n", 123 | " # Exploration-exploitation trade-off\n", 124 | " exploration_rate_threshold = random.uniform(0, 1)\n", 125 | " if exploration_rate_threshold > exploration_rate:\n", 126 | " action = np.argmax(q_table[state,:]) \n", 127 | " else:\n", 128 | " action = env.action_space.sample()\n", 129 | " #take new action\n", 130 | " new_state, reward, done, info = env.step(action)\n", 131 | " # Update Q-table for Q(s,a)\n", 132 | " q_table[state, action] = q_table[state, action] * (1 - learning_rate) + \\\n", 133 | " learning_rate * (reward + discount_rate * np.max(q_table[new_state, :]))\n", 134 | " \n", 135 | " state = new_state\n", 136 | " rewards_current_episode += reward \n", 137 | " \n", 138 | " if done == True: \n", 139 | " break\n", 140 | " \n", 141 | " # Exploration rate decay\n", 142 | " exploration_rate = min_exploration_rate + \\\n", 143 | " (max_exploration_rate - min_exploration_rate) * np.exp(-exploration_decay_rate*episode)\n", 144 | " rewards_all_episodes.append(rewards_current_episode)\n", 145 | " " 146 | ], 147 | "metadata": { 148 | "collapsed": false, 149 | "pycharm": { 150 | "name": "#%%\n", 151 | "is_executing": false 152 | } 153 | } 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 13, 158 | "outputs": [ 159 | { 160 | "name": "stdout", 161 | "text": [ 162 | "********Average reward per thousand episodes********\n", 163 | "\n", 164 | "1000 : 0.04100000000000003\n", 165 | "2000 : 0.05500000000000004\n", 166 | "3000 : 0.08400000000000006\n", 167 | "4000 : 0.10800000000000008\n", 168 | "5000 : 0.10900000000000008\n", 169 | "6000 : 
0.1320000000000001\n", 170 | "7000 : 0.1460000000000001\n", 171 | "8000 : 0.17500000000000013\n", 172 | "9000 : 0.24000000000000019\n", 173 | "10000 : 0.2620000000000002\n", 174 | "11000 : 0.45800000000000035\n", 175 | "12000 : 0.6610000000000005\n", 176 | "13000 : 0.6670000000000005\n", 177 | "14000 : 0.6930000000000005\n", 178 | "15000 : 0.6980000000000005\n", 179 | "16000 : 0.6840000000000005\n", 180 | "17000 : 0.6650000000000005\n", 181 | "18000 : 0.7030000000000005\n", 182 | "19000 : 0.6820000000000005\n", 183 | "20000 : 0.6890000000000005\n" 184 | ], 185 | "output_type": "stream" 186 | } 187 | ], 188 | "source": [ 189 | "# Calculate and print the average reward per thousand episodes\n", 190 | "rewards_per_thosand_episodes = np.split(np.array(rewards_all_episodes),num_episodes/1000)\n", 191 | "count = 1000\n", 192 | "\n", 193 | "print(\"********Average reward per thousand episodes********\\n\")\n", 194 | "for r in rewards_per_thosand_episodes:\n", 195 | " print(count, \": \", str(sum(r/1000)))\n", 196 | " count += 1000\n" 197 | ], 198 | "metadata": { 199 | "collapsed": false, 200 | "pycharm": { 201 | "name": "#%%\n", 202 | "is_executing": false 203 | } 204 | } 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 15, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "text": [ 213 | "\n", 214 | "\n", 215 | "********Q-table********\n", 216 | "\n", 217 | "[[0.56155668 0.50676359 0.51113332 0.51633663]\n", 218 | " [0.31394613 0.31218741 0.24458661 0.49321051]\n", 219 | " [0.41637602 0.37275899 0.41689343 0.45542588]\n", 220 | " [0.20679537 0.29910297 0.36394872 0.42988882]\n", 221 | " [0.58086203 0.41773475 0.44126002 0.42910216]\n", 222 | " [0. 0. 0. 0. ]\n", 223 | " [0.33825872 0.14911566 0.18237918 0.0651859 ]\n", 224 | " [0. 0. 0. 0. ]\n", 225 | " [0.33209066 0.48196698 0.49282896 0.64355292]\n", 226 | " [0.2918537 0.70621805 0.53639371 0.44925847]\n", 227 | " [0.65793949 0.41190176 0.40708876 0.33627327]\n", 228 | " [0. 0. 0. 0. ]\n", 229 | " [0. 0. 0. 0. ]\n", 230 | " [0.4509215 0.62594686 0.76507311 0.51494189]\n", 231 | " [0.74501853 0.84665135 0.77416626 0.77714155]\n", 232 | " [0. 0. 0. 0. 
]]\n" 233 | ], 234 | "output_type": "stream" 235 | } 236 | ], 237 | "source": [ 238 | "# Print updated Q-table\n", 239 | "print(\"\\n\\n********Q-table********\\n\")\n", 240 | "print(q_table)" 241 | ], 242 | "metadata": { 243 | "collapsed": false, 244 | "pycharm": { 245 | "name": "#%%\n", 246 | "is_executing": false 247 | } 248 | } 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "source": [ 253 | "### The code to watch the agent play the game \n", 254 | "```\n", 255 | "# Watch our agent play Frozen Lake by playing the best action \n", 256 | "# from each state according to the Q-table\n", 257 | "\n", 258 | "for episode in range(3):\n", 259 | " # initialize new episode params\n", 260 | "\n", 261 | " for step in range(max_steps_per_episode): \n", 262 | " # Show current state of environment on screen\n", 263 | " # Choose action with highest Q-value for current state \n", 264 | " # Take new action\n", 265 | " \n", 266 | " if done:\n", 267 | " if reward == 1:\n", 268 | " # Agent reached the goal and won episode\n", 269 | " else:\n", 270 | " # Agent stepped in a hole and lost episode \n", 271 | " \n", 272 | " # Set new state\n", 273 | " \n", 274 | "env.close()\n", 275 | "```\n", 276 | "\n" 277 | ], 278 | "metadata": { 279 | "collapsed": false, 280 | "pycharm": { 281 | "name": "#%% md\n" 282 | } 283 | } 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 16, 288 | "outputs": [ 289 | { 290 | "name": "stdout", 291 | "text": [ 292 | " (Down)\n", 293 | "SFFF\n", 294 | "FHFH\n", 295 | "FFFH\n", 296 | "HFF\u001b[41mG\u001b[0m\n", 297 | "****You reached the goal!****\n" 298 | ], 299 | "output_type": "stream" 300 | } 301 | ], 302 | "source": [ 303 | "for episode in range(3):\n", 304 | " state = env.reset()\n", 305 | " done = False\n", 306 | " print(\"*****EPISODE \", episode+1, \"*****\\n\\n\\n\\n\")\n", 307 | " time.sleep(1)\n", 308 | " \n", 309 | " for step in range(max_steps_per_episode):\n", 310 | " clear_output(wait=True)\n", 311 | " env.render()\n", 312 | " time.sleep(0.3)\n", 313 | " \n", 314 | " action = np.argmax(q_table[state,:]) \n", 315 | " new_state, reward, done, info = env.step(action)\n", 316 | " \n", 317 | " if done:\n", 318 | " clear_output(wait=True)\n", 319 | " env.render()\n", 320 | " if reward == 1:\n", 321 | " print(\"****You reached the goal!****\")\n", 322 | " time.sleep(3)\n", 323 | " else:\n", 324 | " print(\"****You fell through a hole!****\")\n", 325 | " time.sleep(3)\n", 326 | " clear_output(wait=True)\n", 327 | " break\n", 328 | " state = new_state\n", 329 | "\n", 330 | "env.close()" 331 | ], 332 | "metadata": { 333 | "collapsed": false, 334 | "pycharm": { 335 | "name": "#%%\n", 336 | "is_executing": false 337 | } 338 | } 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 8, 343 | "outputs": [], 344 | "source": [ 345 | "\n" 346 | ], 347 | "metadata": { 348 | "collapsed": false, 349 | "pycharm": { 350 | "name": "#%%\n", 351 | "is_executing": false 352 | } 353 | } 354 | } 355 | ], 356 | "metadata": { 357 | "kernelspec": { 358 | "display_name": "Python 3", 359 | "language": "python", 360 | "name": "python3" 361 | }, 362 | "language_info": { 363 | "codemirror_mode": { 364 | "name": "ipython", 365 | "version": 2 366 | }, 367 | "file_extension": ".py", 368 | "mimetype": "text/x-python", 369 | "name": "python", 370 | "nbconvert_exporter": "python", 371 | "pygments_lexer": "ipython2", 372 | "version": "2.7.6" 373 | }, 374 | "pycharm": { 375 | "stem_cell": { 376 | "cell_type": "raw", 377 | "source": [], 378 | "metadata": { 379 | "collapsed": 
false 380 | } 381 | } 382 | } 383 | }, 384 | "nbformat": 4, 385 | "nbformat_minor": 0 386 | } -------------------------------------------------------------------------------- /Breakout.py: -------------------------------------------------------------------------------- 1 | 2 | import datetime 3 | import torch.optim as optim 4 | import time 5 | # customized import 6 | from DQNs import * 7 | from utils import * 8 | from EnvManagers import AtariEnvManager 9 | from Agent import * 10 | 11 | 12 | param_json_fname = "DDQN_params.json" 13 | config_dict, hyperparams_dict, eval_dict = read_json(param_json_fname) 14 | if config_dict["IS_USED_TENSORBOARD"]: 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | ### core classes 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | print("using:",device) 20 | em = AtariEnvManager(device, config_dict["GAME_ENV"], config_dict["IS_USE_ADDITIONAL_ENDING_CRITERION"]) 21 | em.print_action_meanings() 22 | strategy = EpsilonGreedyStrategyLinear(hyperparams_dict["eps_start"], hyperparams_dict["eps_end"], hyperparams_dict["eps_final"], 23 | hyperparams_dict["eps_startpoint"], hyperparams_dict["eps_kneepoint"],hyperparams_dict["eps_final_knee"]) 24 | agent = Agent(strategy, em.num_actions_available(), device) 25 | if config_dict["IS_USED_PER"]: 26 | memory = ReplayMemory_economy_PER(hyperparams_dict["memory_size"]) 27 | else: 28 | memory = ReplayMemory_economy(hyperparams_dict["memory_size"]) 29 | 30 | # availible models: DQN_CNN_2013,DQN_CNN_2015, Dueling_DQN_2016_Modified 31 | if config_dict["MODEL_NAME"] == "DQN_CNN_2015": 32 | policy_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=True).to(device) 33 | target_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=True).to(device) 34 | elif config_dict["MODEL_NAME"] == "Dueling_DQN_2016_Modified": 35 | policy_net = Dueling_DQN_2016_Modified(num_classes=em.num_actions_available(), init_weights=True).to(device) 36 | target_net = Dueling_DQN_2016_Modified(num_classes=em.num_actions_available(), init_weights=True).to(device) 37 | else: 38 | print("No such model! Please check your configuration in .json file") 39 | 40 | target_net.load_state_dict(policy_net.state_dict()) 41 | target_net.eval() # this network will only be used for inference. 42 | optimizer = optim.Adam(params=policy_net.parameters(), lr=hyperparams_dict["lr"]) 43 | criterion = torch.nn.SmoothL1Loss() 44 | # can use tensorboard to track the reward 45 | if config_dict["IS_USED_TENSORBOARD"]: 46 | PATH_to_log_dir = config_dict["TENSORBOARD_PATH"] + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) 47 | writer = SummaryWriter(PATH_to_log_dir) 48 | 49 | # print("num_actions_available: ",em.num_actions_available()) 50 | # print("action_meanings:" ,em.env.get_action_meanings()) 51 | 52 | # Auxilarty variables 53 | heldout_saver = HeldoutSaver(config_dict["HELDOUT_SET_DIR"], 54 | config_dict["HELDOUT_SET_MAX_PER_BATCH"], 55 | config_dict["HELDOUT_SAVE_RATE"]) 56 | tracker_dict = init_tracker_dict() 57 | 58 | plt.figure() 59 | # for estimating the time 60 | t1,t2 = time.time(),time.time() 61 | num_target_update = 0 62 | 63 | for episode in range(hyperparams_dict["num_episodes"]): 64 | em.reset() 65 | state = em.get_state() # initialize sate 66 | tol_reward = 0 67 | while(1): 68 | # Visualization of game process and state 69 | if config_dict["IS_RENDER_GAME_PROCESS"]: em.env.render() # BZX: will this slow down the speed? 
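        # Illustrative note (a hedged sketch, not part of the original script): the minibatch update
        # further below builds a Double-DQN target through QValues.DDQN_get_next(policy_net, target_net,
        # next_states) from utils.py. Assuming the usual Double-DQN rule that the name suggests, the
        # online network selects the greedy next action and the target network evaluates it, roughly:
        #
        #     best_a   = policy_net(next_states).argmax(dim=1, keepdim=True)   # action selection (online net)
        #     next_q   = target_net(next_states).gather(1, best_a).squeeze(1)  # action evaluation (target net)
        #     target_q = rewards + hyperparams_dict["gamma"] * next_q          # with next_q forced to 0 for terminal states
        #
        # Variable names here are illustrative only; the exact terminal-state handling lives in the
        # repository's own QValues implementation.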
70 | if config_dict["IS_VISUALIZE_STATE"]: visualize_state(state) 71 | if config_dict["IS_GENERATE_HELDOUT"]: heldout_saver.append(state) # generate heldout set for offline eval 72 | 73 | # Given s, select a by either policy_net or random 74 | action = agent.select_action(state, policy_net) 75 | # print(action) 76 | # collect unclipped reward from env along the action 77 | reward = em.take_action(action) 78 | tol_reward += reward 79 | tracker_dict["actions_counter"] += 1 80 | # after took a, get s' 81 | next_state = em.get_state() 82 | # push (s,a,s',r) into memory 83 | memory.push(Experience(state[0,-1,:,:].clone(), action, "", torch.sign(reward))) #clip reward!!! 84 | # update current state 85 | state = next_state 86 | 87 | # After memory have been filled with enough samples, we update policy_net every 4 agent steps. 88 | if (agent.current_step % hyperparams_dict["action_repeat"] == 0) and \ 89 | memory.can_provide_sample(hyperparams_dict["batch_size"], hyperparams_dict["replay_start_size"]): 90 | 91 | if config_dict["IS_USED_PER"]: 92 | experiences, experiences_index, weights = memory.sample(hyperparams_dict["batch_size"]) 93 | else: 94 | experiences = memory.sample(hyperparams_dict["batch_size"]) 95 | states, actions, rewards, next_states = extract_tensors(experiences) 96 | current_q_values = QValues.get_current(policy_net, states, actions) # checked 97 | # next_q_values = QValues.DQN_get_next(target_net, next_states) # for DQN 98 | next_q_values = QValues.DDQN_get_next(policy_net,target_net, next_states) 99 | target_q_values = (next_q_values * hyperparams_dict["gamma"]) + rewards 100 | # calculate loss and update policy_net 101 | optimizer.zero_grad() 102 | if config_dict["IS_USED_PER"]: 103 | # compute TD error 104 | TD_errors = torch.abs(current_q_values - target_q_values.unsqueeze(1)).detach().cpu().numpy() 105 | # update priorities 106 | memory.update_priority(experiences_index, TD_errors.squeeze(1)) 107 | # compute loss 108 | loss = torch.mean(weights.detach() * (current_q_values - target_q_values.unsqueeze(1))**2) 109 | else: 110 | loss = criterion(current_q_values, target_q_values.unsqueeze(1)) 111 | loss.backward() 112 | optimizer.step() 113 | 114 | tracker_dict["loss_hist"].append(loss.item()) 115 | tracker_dict["minibatch_updates_counter"] += 1 116 | 117 | # update target_net 118 | if tracker_dict["minibatch_updates_counter"] % hyperparams_dict["target_update"] == 0: 119 | target_net.load_state_dict(policy_net.state_dict()) 120 | 121 | # estimate time 122 | num_target_update += 1 123 | if num_target_update % 2 == 0: t1 = time.time() 124 | if num_target_update % 2 == 1: t2 = time.time() 125 | print("=" * 50) 126 | remaining_update_times = (config_dict["MAX_ITERATION"] - tracker_dict["minibatch_updates_counter"])// \ 127 | hyperparams_dict["target_update"] 128 | time_sec = abs(t1-t2) * remaining_update_times 129 | print("estimated remaining time = {}h-{}min".format(time_sec//3600,(time_sec%3600)//60)) 130 | print("len of replay memory:", len(memory.memory)) 131 | print("minibatch_updates_counter = ", tracker_dict["minibatch_updates_counter"]) 132 | print("current_step of agent = ", agent.current_step) 133 | print("exploration rate = ", strategy.get_exploration_rate(agent.current_step)) 134 | print("=" * 50) 135 | 136 | # save checkpoint model 137 | if tracker_dict["minibatch_updates_counter"] % config_dict["UPDATE_PER_CHECKPOINT"] == 0: 138 | save_model(policy_net, tracker_dict, config_dict) 139 | if not os.path.exists(config_dict["FIGURES_PATH"]): 140 | 
os.makedirs(config_dict["FIGURES_PATH"]) 141 | plt.savefig(config_dict["FIGURES_PATH"] + "Iterations_{}-Time_".format(tracker_dict["minibatch_updates_counter"]) + datetime.datetime.now().strftime( 142 | config_dict["DATE_FORMAT"]) + ".jpg") 143 | 144 | if em.done: 145 | tracker_dict["rewards_hist"].append(tol_reward) 146 | tracker_dict["rewards_hist_update_axis"].append(tracker_dict["minibatch_updates_counter"]) 147 | tracker_dict["running_reward"] = plot(tracker_dict["rewards_hist"], 100) 148 | # use tensorboard to track the reward 149 | if config_dict["IS_USED_TENSORBOARD"]: 150 | moving_avg_period = 100 151 | tracker_dict["moving_avg"] = get_moving_average(moving_avg_period, tracker_dict["rewards_hist"]) 152 | writer.add_scalars('reward', {'reward': tracker_dict["rewards_hist"][-1], 153 | 'reward_average': tracker_dict["moving_avg"][-1]}, episode) 154 | break 155 | 156 | if config_dict["IS_BREAK_BY_MAX_ITERATION"] and \ 157 | tracker_dict["minibatch_updates_counter"] > config_dict["MAX_ITERATION"]: 158 | break 159 | 160 | em.close() 161 | if config_dict["IS_USED_TENSORBOARD"]: 162 | writer.close() 163 | # save loss figure 164 | plt.figure() 165 | plt.plot(tracker_dict["loss_hist"]) 166 | plt.title("loss") 167 | plt.xlabel("iterations") 168 | plt.savefig(config_dict["FIGURES_PATH"] + "Loss-Iterations_{}-Time_".format(tracker_dict["minibatch_updates_counter"]) + datetime.datetime.now().strftime( 169 | config_dict["DATE_FORMAT"]) + ".jpg") 170 | 171 | # save tracker_dict["eval_model_list_txt"] to txt file 172 | if not os.path.exists(eval_dict["EVAL_MODEL_LIST_TXT_PATH"]): 173 | os.makedirs(eval_dict["EVAL_MODEL_LIST_TXT_PATH"]) 174 | txt_fname = "ModelName_{}-GameName_{}-Time_".format(config_dict["MODEL_NAME"],config_dict["GAME_NAME"]) + datetime.datetime.now().strftime( 175 | config_dict["DATE_FORMAT"]) + ".txt" 176 | with open( eval_dict["EVAL_MODEL_LIST_TXT_PATH"] + txt_fname,'w') as f: 177 | f.write('\n'.join(tracker_dict["eval_model_list_txt"])) 178 | 179 | # pickle tracker_dict for report figures 180 | print("="*100) 181 | print("saving results...") 182 | print("=" * 100) 183 | if not os.path.exists(config_dict["RESULT_PATH"]): 184 | os.makedirs(config_dict["RESULT_PATH"]) 185 | tracker_fname = "ModelName_{}-GameName_{}-Time_".format(config_dict["MODEL_NAME"],config_dict["GAME_NAME"]) + datetime.datetime.now().strftime( 186 | config_dict["DATE_FORMAT"]) + ".pkl" 187 | with open(config_dict["RESULT_PATH"] + tracker_fname,'wb') as f: 188 | pickle.dump(tracker_dict, f) 189 | 190 | 191 | if config_dict["IS_SAVE_MIDDLE_POINT"]: 192 | print("="*100) 193 | print("saving middel point...") 194 | print("=" * 100) 195 | # save core instances 196 | if not os.path.exists(config_dict["MIDDLE_POINT_PATH"]): 197 | os.makedirs(config_dict["MIDDLE_POINT_PATH"]) 198 | 199 | mdMemFileName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Memory_" + datetime.datetime.now().strftime( 200 | config_dict["DATE_FORMAT"]) + ".pkl" 201 | middle_mem_file = open(mdMemFileName, 'wb') 202 | pickle.dump(memory, middle_mem_file) 203 | middle_mem_file.close() 204 | del memory 205 | # del memory # make more memory space 206 | 207 | midddle_point = {} 208 | midddle_point["agent"] = agent 209 | midddle_point["tracker_dict"] = tracker_dict 210 | midddle_point["heldout_batch_counter"] = heldout_saver.batch_counter 211 | midddle_point["strategy"] = strategy 212 | mdStateFileName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_State_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pkl" 213 | 214 | 
middle_point_file = open(mdStateFileName, 'wb') 215 | pickle.dump(midddle_point, middle_point_file) 216 | middle_point_file.close() 217 | 218 | # save policy_net and target_net 219 | md_Policy_Net_fName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Policy_Net_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pth" 220 | torch.save(policy_net.state_dict(),md_Policy_Net_fName) 221 | md_Target_Net_fName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Target_Net_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pth" 222 | torch.save(target_net.state_dict(), md_Target_Net_fName) 223 | 224 | # save middle point files' path for continued training 225 | md_path_dict = {} 226 | md_path_dict["mdMemFileName"] = mdMemFileName 227 | md_path_dict["mdStateFileName"] = mdStateFileName 228 | md_path_dict["md_Policy_Net_fName"] = md_Policy_Net_fName 229 | md_path_dict["md_Target_Net_fName"] = md_Target_Net_fName 230 | with open('tmp_middle_point_file_path.json', 'w') as fp: 231 | json.dump(md_path_dict, fp) 232 | 233 | 234 | 235 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | from collections import namedtuple 7 | import numpy as np 8 | import pickle 9 | import json 10 | import os 11 | import datetime 12 | import imageio 13 | from skimage.transform import resize as skimage_resize 14 | from sumtree import * 15 | 16 | is_ipython = 'inline' in matplotlib.get_backend() 17 | if is_ipython: 18 | from IPython import display 19 | else: 20 | matplotlib.use('TkAgg') 21 | # matplotlib.use('Agg') 22 | 23 | Experience = namedtuple( 24 | 'Experience', 25 | ('state', 'action', 'next_state', 'reward') 26 | ) 27 | 28 | Eco_Experience = namedtuple( 29 | 'Eco_Experience', 30 | ('state', 'action', 'reward') 31 | ) 32 | 33 | 34 | 35 | 36 | 37 | class ReplayMemory(): 38 | # basic ring-buffer replay memory 39 | def __init__(self, capacity): 40 | self.capacity = capacity 41 | self.memory = [] 42 | self.push_count = 0 43 | # self.dtype = torch.uint8 44 | 45 | def push(self, experience): 46 | if len(self.memory) < self.capacity: 47 | self.memory.append(experience) 48 | else: 49 | self.memory[self.push_count % self.capacity] = experience 50 | self.push_count += 1 51 | 52 | def sample(self, batch_size): 53 | return random.sample(self.memory, batch_size) 54 | 55 | def can_provide_sample(self, batch_size): 56 | return len(self.memory) >= batch_size 57 | 58 | class ReplayMemory_economy(): 59 | # store only the newest frame of each experience to reduce memory usage; frame stacks are rebuilt in sample() 60 | def __init__(self, capacity): 61 | self.capacity = capacity 62 | self.memory = [] 63 | self.push_count = 0 64 | self.dtype = torch.uint8 65 | 66 | def push(self, experience): 67 | state = (experience.state * 255).type(self.dtype).cpu() 68 | # next_state = (experience.next_state * 255).type(self.dtype) 69 | new_experience = Eco_Experience(state,experience.action,experience.reward) 70 | 71 | if len(self.memory) < self.capacity: 72 | self.memory.append(new_experience) 73 | else: 74 | self.memory[self.push_count % self.capacity] = new_experience 75 | # print(id(experience)) 76 | # print(id(self.memory[0])) 77 | self.push_count += 1 78 | 79 | def sample(self, batch_size): 80 | # randomly sample experiences 81 | experience_index = np.random.randint(3, len(self.memory)-1, size = batch_size) 82 | # memory_arr = np.array(self.memory) 83 
| experiences = [] 84 | for index in experience_index: 85 | if self.push_count > self.capacity: 86 | state = torch.stack(([self.memory[index+j].state for j in range(-3,1)])).unsqueeze(0) 87 | next_state = torch.stack(([self.memory[index+1+j].state for j in range(-3,1)])).unsqueeze(0) 88 | else: 89 | state = torch.stack(([self.memory[np.max(index+j, 0)].state for j in range(-3,1)])).unsqueeze(0) 90 | next_state = torch.stack(([self.memory[np.max(index+1+j, 0)].state for j in range(-3,1)])).unsqueeze(0) 91 | experiences.append(Experience(state.float().cuda()/255, self.memory[index].action, next_state.float().cuda()/255, self.memory[index].reward)) 92 | # return random.sample(self.memory, batch_size) 93 | return experiences 94 | 95 | def can_provide_sample(self, batch_size, replay_start_size): 96 | return (len(self.memory) >= replay_start_size) and (len(self.memory) >= batch_size + 3) 97 | 98 | class ReplayMemory_economy_PER(): 99 | # Memory replay with priorited experience replay 100 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 101 | def __init__(self, capacity, alpha=0.6, beta_start=0.4, beta_startpoint=50000, beta_kneepoint = 1000000, error_epsilon=1e-5): 102 | self.capacity = capacity 103 | self.memory = [] 104 | self.priority_tree = Sumtree(self.capacity) # store priorities 105 | self.alpha = alpha 106 | self.beta = beta_start 107 | self.beta_increase = 1/(beta_kneepoint - beta_startpoint) 108 | self.error_epsilon = error_epsilon 109 | self.push_count = 0 110 | self.dtype = torch.uint8 111 | 112 | def push(self, experience): 113 | state = (experience.state * 255).type(self.dtype).cpu() 114 | # next_state = (experience.next_state * 255).type(self.dtype) 115 | new_experience = Eco_Experience(state,experience.action,experience.reward) 116 | 117 | if len(self.memory) < self.capacity: 118 | self.memory.append(new_experience) 119 | else: 120 | self.memory[self.push_count % self.capacity] = new_experience 121 | self.push_count += 1 122 | # push new state to priority tree 123 | self.priority_tree.push() 124 | 125 | def sample(self, batch_size): 126 | # get indices of experience by priorities 127 | experience_index = [] 128 | experiences = [] 129 | priorities = [] 130 | segment = self.priority_tree.get_p_total()/batch_size 131 | self.beta = np.min([1., self.beta + self.beta_increase]) 132 | for i in range(batch_size): 133 | low = segment * i 134 | high = segment * (i+1) 135 | s = random.uniform(low, high) 136 | index, p = self.priority_tree.sample(s) 137 | experience_index.append(index) 138 | priorities.append(p) 139 | # get experience from index 140 | if self.push_count > self.capacity: 141 | state = torch.stack(([self.memory[index+j].state for j in range(-3,1)])).unsqueeze(0) 142 | next_state = torch.stack(([self.memory[index+1+j].state for j in range(-3,1)])).unsqueeze(0) 143 | else: 144 | state = torch.stack(([self.memory[np.max(index+j, 0)].state for j in range(-3,1)])).unsqueeze(0) 145 | next_state = torch.stack(([self.memory[np.max(index+1+j, 0)].state for j in range(-3,1)])).unsqueeze(0) 146 | experiences.append(Experience(state.float().cuda()/255, self.memory[index].action, next_state.float().cuda()/255, self.memory[index].reward)) 147 | # compute weight 148 | possibilities = priorities / self.priority_tree.get_p_total() 149 | min_possibility = self.priority_tree.get_p_min() 150 | weight = np.power(self.priority_tree.length * possibilities, -self.beta) 151 | max_weight = np.power(self.priority_tree.length * min_possibility, -self.beta) 152 | weight = 
weight/max_weight 153 | weight = torch.tensor(weight[:,np.newaxis], dtype = torch.float).to(ReplayMemory_economy_PER.device) 154 | return experiences, experience_index, weight 155 | 156 | def update_priority(self, index_list, TD_error_list): 157 | # update priorities from TD error 158 | # priorities_list = np.abs(TD_error_list) + self.error_epsilon 159 | priorities_list = (np.abs(TD_error_list) + self.error_epsilon) ** self.alpha 160 | for index, priority in zip(index_list, priorities_list): 161 | self.priority_tree.update(index, priority) 162 | 163 | def can_provide_sample(self, batch_size, replay_start_size): 164 | return (len(self.memory) >= replay_start_size) and (len(self.memory) >= batch_size + 3) 165 | 166 | class EpsilonGreedyStrategyExp(): 167 | # compute epsilon in epsilon-greedy algorithm by exponentially decrement 168 | def __init__(self, start, end, decay): 169 | self.start = start 170 | self.end = end 171 | self.decay = decay 172 | 173 | def get_exploration_rate(self, current_step): 174 | return self.end + (self.start - self.end) * \ 175 | math.exp(-1. * current_step * self.decay) 176 | 177 | class EpsilonGreedyStrategyLinear(): 178 | def __init__(self, start, end, final_eps = None, startpoint = 50000, kneepoint=1000000, final_knee_point = None): 179 | # compute epsilon in epsilon-greedy algorithm by linearly decrement 180 | self.start = start 181 | self.end = end 182 | self.final_eps = final_eps 183 | self.kneepoint = kneepoint 184 | self.startpoint = startpoint 185 | self.final_knee_point = final_knee_point 186 | 187 | def get_exploration_rate(self, current_step): 188 | if current_step < self.startpoint: 189 | return 1. 190 | mid_seg = self.end + \ 191 | np.maximum(0, (1-self.end)-(1-self.end)/self.kneepoint * (current_step-self.startpoint)) 192 | if not self.final_eps: 193 | return mid_seg 194 | else: 195 | if self.final_eps and self.final_knee_point and (current_step= period: 284 | moving_avg = values.unfold(dimension=0, size=period, step=1) \ 285 | .mean(dim=1).flatten(start_dim=0) 286 | moving_avg = torch.cat((torch.zeros(period-1), moving_avg)) 287 | return moving_avg.numpy() 288 | else: 289 | moving_avg = torch.zeros(len(values)) 290 | return moving_avg.numpy() 291 | 292 | def plot(values, moving_avg_period): 293 | """ 294 | test: plot(np.random.rand(300), 100) 295 | :param values: numpy 1D vector 296 | :param moving_avg_period: 297 | :return: None 298 | """ 299 | # plt.figure() 300 | plt.clf() 301 | plt.title('Training...') 302 | plt.xlabel('Episode') 303 | plt.ylabel('Reward') 304 | plt.plot(values) 305 | moving_avg = get_moving_average(moving_avg_period, values) 306 | plt.plot(moving_avg) 307 | print("Episode", len(values), "\n",moving_avg_period, "episode moving avg:", moving_avg[-1]) 308 | plt.pause(0.0001) 309 | # if is_ipython: display.clear_output(wait=True) 310 | return moving_avg[-1] 311 | 312 | def extract_tensors(experiences): 313 | # Convert batch of Experiences to Experience of batches 314 | batch = Experience(*zip(*experiences)) 315 | 316 | t1 = torch.cat(batch.state) 317 | t2 = torch.cat(batch.action) 318 | t3 = torch.cat(batch.reward) 319 | t4 = torch.cat(batch.next_state) 320 | 321 | return (t1,t2,t3,t4) 322 | 323 | def visualize_state(state): 324 | # settings 325 | nrows, ncols = 1, 4 # array of sub-plots 326 | figsize = [8, 4] # figure size, inches 327 | 328 | # prep (x,y) for extra plotting on selected sub-plots 329 | # xs = np.linspace(0, 2 * np.pi, 60) # from 0 to 2pi 330 | # ys = np.abs(np.sin(xs)) # absolute of sine 331 | 332 | # create figure 
(fig), and array of axes (ax) 333 | fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize) 334 | 335 | # plot simple raster image on each sub-plot 336 | for i, axi in enumerate(ax.flat): 337 | # i runs from 0 to (nrows*ncols-1) 338 | # axi is equivalent with ax[rowid][colid] 339 | img = state.squeeze(0)[i,None] 340 | cpu_img = img.squeeze(0).cpu() 341 | axi.imshow(cpu_img*255,cmap='gray', vmin=0, vmax=255) 342 | 343 | # get indices of row/column 344 | rowid = i // ncols 345 | colid = i % ncols 346 | # write row/col indices as axes' title for identification 347 | # axi.set_title("Row:" + str(rowid) + ", Col:" + str(colid)) 348 | 349 | # one can access the axes by ax[row_id][col_id] 350 | # do additional plotting on ax[row_id][col_id] of your choice 351 | # ax[0][2].plot(xs, 3 * ys, color='red', linewidth=3) 352 | # ax[4][3].plot(ys ** 2, xs, color='green', linewidth=3) 353 | 354 | plt.tight_layout(True) 355 | plt.show() 356 | 357 | def init_tracker_dict(): 358 | " init auxilary variables" 359 | tracker = {} 360 | tracker["minibatch_updates_counter"] = 1 361 | tracker["actions_counter"] = 1 362 | tracker["running_reward"] = 0 363 | tracker["rewards_hist"] = [] 364 | tracker["loss_hist"] = [] 365 | tracker["eval_model_list_txt"] = [] 366 | tracker["rewards_hist_update_axis"] = [] 367 | # only used in evaluation script 368 | tracker["eval_reward_list"] = [] 369 | tracker["best_frame_for_gif"] = [] 370 | tracker["best_reward"] = 0 371 | return tracker 372 | 373 | def save_model(policy_net, tracker_dict, config_dict): 374 | path = config_dict["CHECK_POINT_PATH"] + config_dict["GAME_NAME"] + "/" 375 | if not os.path.exists(path): 376 | os.makedirs(path) 377 | fname = "Iterations_{}-Reward{:.2f}-Time_".format(tracker_dict["minibatch_updates_counter"], 378 | tracker_dict["running_reward"]) + \ 379 | datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pth" 380 | torch.save(policy_net.state_dict(), path + fname) 381 | tracker_dict["eval_model_list_txt"].append(path + fname) 382 | 383 | def read_json(param_json_fname): 384 | with open(param_json_fname) as fp: 385 | params_dict = json.load(fp) 386 | 387 | config_dict = params_dict["config"] 388 | hyperparams_dict = params_dict["hyperparams"] 389 | eval_dict = params_dict["eval"] 390 | return config_dict, hyperparams_dict, eval_dict 391 | 392 | 393 | def load_Middle_Point(md_json_file_path): 394 | with open(md_json_file_path) as fp: 395 | md_path_dict = json.load(fp) 396 | return md_path_dict 397 | 398 | def generate_gif(gif_save_path,model_name, frames_for_gif, reward): 399 | """ 400 | Args: 401 | frames_for_gif: A sequence of (210, 160, 3) frames of an Atari game in RGB 402 | reward: Integer, Total reward of the episode that es ouputted as a gif 403 | """ 404 | if not os.path.exists(gif_save_path): 405 | os.makedirs(gif_save_path) 406 | for idx, frame_idx in enumerate(frames_for_gif): 407 | frames_for_gif[idx] = skimage_resize(frame_idx, (420, 320, 3), 408 | preserve_range=True, order=0).astype(np.uint8) 409 | fname = gif_save_path + model_name + "-EvalReward_{}.gif".format(reward) 410 | imageio.mimsave(fname, frames_for_gif, duration=1 / 30) 411 | --------------------------------------------------------------------------------
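For reference, here is a minimal sketch of how the middle-point files written by Breakout.py above (replay memory, pickled agent/tracker/strategy state, and the two network state dicts) could be reloaded to resume training. It is only an illustration of the save format, not the repository's actual continue_training.py logic; it assumes the same DDQN_params.json configuration and the DQN_CNN_2015 model, and all helper names and dictionary keys are taken from the code above.

import pickle
import torch
from utils import load_Middle_Point, read_json
from DQNs import DQN_CNN_2015
from EnvManagers import AtariEnvManager

# Paths to the saved middle-point files were written to this JSON by the training script.
md_paths = load_Middle_Point("tmp_middle_point_file_path.json")
config_dict, hyperparams_dict, eval_dict = read_json("DDQN_params.json")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Recreate the environment to recover the action-space size for the network heads.
em = AtariEnvManager(device, config_dict["GAME_ENV"], config_dict["IS_USE_ADDITIONAL_ENDING_CRITERION"])

# Restore the replay memory and the pickled training state.
with open(md_paths["mdMemFileName"], "rb") as f:
    memory = pickle.load(f)
with open(md_paths["mdStateFileName"], "rb") as f:
    middle_point = pickle.load(f)
agent = middle_point["agent"]
tracker_dict = middle_point["tracker_dict"]
strategy = middle_point["strategy"]

# Restore the network weights that were saved with torch.save(...state_dict()...).
policy_net = DQN_CNN_2015(num_classes=em.num_actions_available(), init_weights=True).to(device)
target_net = DQN_CNN_2015(num_classes=em.num_actions_available(), init_weights=True).to(device)
policy_net.load_state_dict(torch.load(md_paths["md_Policy_Net_fName"], map_location=device))
target_net.load_state_dict(torch.load(md_paths["md_Target_Net_fName"], map_location=device))
target_net.eval()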