├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── other.xml
│   ├── modules.xml
│   ├── RLPycharmProj.iml
│   └── misc.xml
├── figures_for_README
│   ├── Pong_16.gif
│   ├── Pong_21.gif
│   ├── Breakout_425.gif
│   ├── Breakout_756.gif
│   ├── Iterations:1200000-Time:04-08-2020-22-06-38.jpg
│   └── Iterations:5000000-Time:03-30-2020-02-57-22.jpg
├── __pycache__
│   ├── DQNs.cpython-35.pyc
│   ├── DQNs.cpython-36.pyc
│   ├── utils.cpython-35.pyc
│   ├── utils.cpython-36.pyc
│   ├── sumtree.cpython-36.pyc
│   ├── EnvManagers.cpython-35.pyc
│   └── EnvManagers.cpython-36.pyc
├── Agent.py
├── plot_result.py
├── DDQN_params.json
├── sumtree.py
├── TestonHeldout.py
├── evaluation.py
├── CartPole.py
├── README.md
├── DQNs.py
├── EnvManagers.py
├── continue_training.py
├── Q_Learning_IceLake.ipynb
├── Breakout.py
└── utils.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
--------------------------------------------------------------------------------
/figures_for_README/Pong_16.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Pong_16.gif
--------------------------------------------------------------------------------
/figures_for_README/Pong_21.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Pong_21.gif
--------------------------------------------------------------------------------
/__pycache__/DQNs.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/DQNs.cpython-35.pyc
--------------------------------------------------------------------------------
/__pycache__/DQNs.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/DQNs.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/utils.cpython-35.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/sumtree.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/sumtree.cpython-36.pyc
--------------------------------------------------------------------------------
/figures_for_README/Breakout_425.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Breakout_425.gif
--------------------------------------------------------------------------------
/figures_for_README/Breakout_756.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Breakout_756.gif
--------------------------------------------------------------------------------
/__pycache__/EnvManagers.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/EnvManagers.cpython-35.pyc
--------------------------------------------------------------------------------
/__pycache__/EnvManagers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/__pycache__/EnvManagers.cpython-36.pyc
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/figures_for_README/Iterations:1200000-Time:04-08-2020-22-06-38.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Iterations:1200000-Time:04-08-2020-22-06-38.jpg
--------------------------------------------------------------------------------
/figures_for_README/Iterations:5000000-Time:03-30-2020-02-57-22.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jasonbian97/Deep-Q-Learning-Atari-Pytorch/HEAD/figures_for_README/Iterations:5000000-Time:03-30-2020-02-57-22.jpg
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/RLPycharmProj.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/Agent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | class Agent():
4 | def __init__(self, strategy, num_actions, device):
5 | self.current_step = 0
6 | self.strategy = strategy
7 | self.num_actions = num_actions
8 | self.device = device
9 |
10 | def select_action(self, state, policy_net):
11 | rate = self.strategy.get_exploration_rate(self.current_step)
12 | self.current_step += 1
13 | # print("eps = ",rate)
14 | if rate > random.random():
15 | action = random.randrange(self.num_actions)
16 | return torch.tensor([action]).to(self.device) # explore
17 | else:
18 | with torch.no_grad():
19 | return policy_net(state).argmax(dim=1).to(self.device) # exploit
20 |
--------------------------------------------------------------------------------
/plot_result.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import torch.optim as optim
3 | import time
4 | # customized import
5 | from DQNs import *
6 | from utils import *
7 | from EnvManagers import AtariEnvManager
8 | from Agent import *
9 |
10 | print("="*100)
11 | print("loading training result file...")
12 | result_pkl_path = "./Results/ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28.pkl"
13 | with open(result_pkl_path, 'rb') as fresult:
14 | tracker_dict = pickle.load(fresult)
15 | plt.plot(tracker_dict["rewards_hist"])
16 | plt.show()
17 |
18 | print("="*100)
19 | print("loading evaluation result file...")
20 | result_pkl_path = "./Results/ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28-Eval.pkl"
21 | with open(result_pkl_path, 'rb') as fresult:
22 | tracker_dict = pickle.load(fresult)
23 | plt.plot(tracker_dict["eval_reward_list"])
24 | plt.show()
25 |
26 |
--------------------------------------------------------------------------------
/DDQN_params.json:
--------------------------------------------------------------------------------
1 | {
2 | "config": {
3 | "MODEL_NAME": "DQN_CNN_2015",
4 | "GAME_NAME": "Breakout",
5 | "GAME_ENV": "BreakoutDeterministic-v4",
6 | "IS_USE_ADDITIONAL_ENDING_CRITERION": false,
7 | "CHECK_POINT_PATH": "./checkpoints/",
8 | "FIGURES_PATH": "./figures/",
9 | "MIDDLE_POINT_PATH": "./MiddlePoints/",
10 | "TENSORBOARD_PATH": "./Tensorboard/",
11 | "RESULT_PATH": "./Results/",
12 | "HELDOUT_SET_DIR": "./Heldout/",
13 | "HELDOUT_SET_MAX_PER_BATCH": 5000,
14 | "HELDOUT_SAVE_RATE": 0.005,
15 | "MAX_ITERATION": 100000,
16 | "MAX_FRAMES": 22000000,
17 | "UPDATE_PER_CHECKPOINT": 20000,
18 | "DATE_FORMAT": "%m-%d-%Y-%H-%M-%S",
19 | "IS_GENERATE_HELDOUT": false,
20 | "IS_SAVE_MIDDLE_POINT": false,
21 | "IS_VISUALIZE_STATE": false,
22 | "IS_RENDER_GAME_PROCESS": true,
23 | "IS_BREAK_BY_MAX_FRAMES": false,
24 | "IS_USED_TENSORBOARD": false,
25 | "IS_BREAK_BY_MAX_ITERATION": true,
26 | "IS_USED_PER": false
27 | },
28 | "hyperparams": {
29 | "batch_size": 32,
30 | "gamma": 0.99,
31 | "eps_start": 1.0,
32 | "eps_end": 0.1,
33 | "eps_final": 0.01,
34 | "eps_startpoint": 50000,
35 | "eps_kneepoint": 1000000,
36 | "eps_final_knee": 22000000,
37 | "alpha": 0.6,
38 | "beta_start": 0.4,
39 | "beta_startpoint": 50000,
40 | "beta_kneepoint": 1000000,
41 | "error_epsilon": 1e-5,
42 | "lr_PER": 0.000005,
43 | "target_update": 2500,
44 | "memory_size": 1000000,
45 | "lr": 0.00005,
46 | "num_episodes": 100000,
47 | "replay_start_size": 50000,
48 | "action_repeat": 4
49 | },
50 | "eval": {
51 | "EVAL_EPISODE": 10,
52 | "GIF_SAVE_PATH": "./GIF_Reuslts/",
53 | "ACTIONS_PER_EVAL": 100000,
54 | "EVAL_MODEL_LIST_TXT_PATH": "./eval_model_list_txt/",
55 | "IS_RENDER_GAME_PROCESS": true,
56 | "IS_VISUALIZE_STATE": false,
57 | "IS_USE_ADDITIONAL_ENDING_CRITERION": true
58 | }
59 | }
--------------------------------------------------------------------------------
/sumtree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | # Stores the priorities of experiences.
3 | # The priorities of memory samples are stored in the leaf nodes, and the value of each parent node is the sum of its children.
4 | class Sumtree():
5 | def __init__(self, capacity):
6 | self.capacity = capacity
7 | self.tree = np.zeros(2*capacity - 1) # store the priorities of memory
8 | self.tree[capacity - 1] = 1
9 | self.stored = [False] * (2*capacity - 1) # indicates whether this node currently stores a value
10 | # self.cur_point = 0
11 | self.length = 0 # maximum length is capacity
12 | self.push_count = 0
13 |
14 | def update_node(self, index, change):
15 | # update sum tree from leaf node if the priority of leaf node changed
16 | parent = (index-1)//2
17 | self.tree[parent] += change
18 | self.stored[parent] = True
19 | if parent > 0:
20 | self.update_node(parent, change)
21 |
22 | def update(self, index_memory, p):
23 | # update sum tree from new priority
24 | index = index_memory + self.capacity - 1
25 | change = p - self.tree[index]
26 | self.tree[index] = p
27 | self.stored[index] = True
28 | self.update_node(index, change)
29 |
30 | def get_p_total(self):
31 | # return total priorities
32 | return self.tree[0]
33 |
34 | def get_p_min(self):
35 | return min(self.tree[self.capacity-1:self.length+self.capacity-1])
36 |
37 | def get_by_priority(self, index, s):
38 | # get index of node by priority s
39 | left_child = index*2 + 1
40 | right_child = index*2 + 2
41 | if left_child >= self.tree.shape[0]:
42 | return index
43 | if self.stored[left_child] == False:
44 | return self.get_by_priority(right_child, s-self.tree[left_child])
45 | if self.stored[right_child] == False:
46 | return self.get_by_priority(left_child, s)
47 | if s <= self.tree[left_child]:
48 | return self.get_by_priority(left_child, s)
49 | else:
50 | return self.get_by_priority(right_child, s-self.tree[left_child])
51 |
52 | def sample(self, s):
53 | # sample node by priority s, return the index and priority of experience
54 | self.stored[self.length + self.capacity - 2] = False # cannot sample the latest state
55 | index = self.get_by_priority(0, s)
56 | return index - self.capacity + 1, self.tree[index]
57 |
58 | def push(self):
59 | # push experience, the initial priority is the maximum priority in sum tree
60 | index_memory = self.push_count % self.capacity
61 | if self.length < self.capacity:
62 | self.length += 1
63 | self.update(index_memory, np.max(self.tree[self.capacity-1 : self.capacity+self.length-1]))
64 | self.push_count += 1
65 |
66 |
67 |
68 |
69 |
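70 | if __name__ == "__main__":
71 |     # Minimal usage sketch (not used by the training code): leaves hold per-sample
72 |     # priorities and each internal node holds the sum of its children, so drawing a
73 |     # uniform number in [0, get_p_total()) and descending the tree samples memory
74 |     # indices proportionally to their priority.
75 |     tree = Sumtree(capacity=8)
76 |     for _ in range(4):
77 |         tree.push()                 # new samples enter with the current max priority
78 |     tree.update(2, 0.5)             # e.g. a TD-error-based priority for memory slot 2
79 |     s = np.random.uniform(0, tree.get_p_total())
80 |     memory_index, priority = tree.sample(s)
81 |     print(memory_index, priority)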
--------------------------------------------------------------------------------
/TestonHeldout.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import torch.optim as optim
3 | import time
4 | # customized import
5 | from DQNs import *
6 | from utils import *
7 | from EnvManagers import AtariEnvManager
8 | from Agent import *
9 |
10 |
11 | param_json_fname = "DDQN_params.json"
12 | # checkpoints to evaluate
13 | model_list_fname = "./eval_model_list_txt/ModelName_2015_CNN_DQN-GameName_Breakout-Time_03-30-2020-02-57-36.txt"
14 |
15 | config_dict, hyperparams_dict, eval_dict = read_json(param_json_fname)
16 |
17 | # If there is no txt file storing the checkpoint names, they can be loaded directly from the checkpoint folder with the following code:
18 | """ model_list = os.listdir(config_dict["CHECK_POINT_PATH"]+config_dict["GAME_NAME"])
19 | model_list = [config_dict["CHECK_POINT_PATH"]+config_dict["GAME_NAME"]+'/'+x for x in model_list] """
20 |
21 | # load model list
22 | with open(model_list_fname) as f:
23 | model_list = f.readlines()
24 | model_list = [x.strip() for x in model_list] # remove whitespace characters like `\n` at the end of each line
25 |
26 | subfolder = model_list_fname.split("/")[-1][:-4]
27 | # get the update iterations of each checkpoint
28 | iterations = [int(x.split("/")[-1].split("_")[1].split("-")[0])/config_dict["UPDATE_PER_CHECKPOINT"] for x in model_list]
29 | iterations.sort()
30 | model_list.sort(key = lambda x: int(x.split("/")[-1].split("_")[1].split("-")[0]))
31 |
32 | # set environment
33 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34 | em = AtariEnvManager(device, config_dict["GAME_ENV"], config_dict["IS_USE_ADDITIONAL_ENDING_CRITERION"])
35 |
36 | # set policy net
37 | if config_dict["MODEL_NAME"] == "DQN_CNN_2015":
38 | policy_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=True).to(device)
39 | target_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=True).to(device)
40 | elif config_dict["MODEL_NAME"] == "Dueling_DQN_2016_Modified":
41 | policy_net = Dueling_DQN_2016_Modified(num_classes=em.num_actions_available(), init_weights=True).to(device)
42 | target_net = Dueling_DQN_2016_Modified(num_classes=em.num_actions_available(), init_weights=True).to(device)
43 | else:
44 | print("No such model! Please check your configuration in .json file")
45 |
46 | # Auxiliary variables
47 | tracker_dict = {}
48 | tracker_dict["UPDATE_PER_CHECKPOINT"] = config_dict["UPDATE_PER_CHECKPOINT"]
49 | tracker_dict["Qvalue_average_list"] = []
50 | for model_fpath in model_list:
51 | print("testing: ",model_fpath)
52 | # load model from file
53 | policy_net.load_state_dict(torch.load(model_fpath))
54 | policy_net.eval()
55 | Qvalue_model = []
56 | # load heldout set
57 | hfiles = os.listdir(config_dict["HELDOUT_SET_DIR"])
58 | for hfile in hfiles:
59 | with open(config_dict["HELDOUT_SET_DIR"]+'/'+hfile,'rb') as f:
60 | heldout_set = pickle.load(f)
61 | for state in heldout_set:
62 | # compute Qvalue by loaded model
63 | with torch.no_grad():
64 | Qvalue = policy_net(state.float().cuda()/255).detach().max(dim=1)[0].cpu().item()
65 | Qvalue_model.append(Qvalue)
66 | tracker_dict["Qvalue_average_list"].append(sum(Qvalue_model)/len(Qvalue_model))
67 |
68 |
69 | if not os.path.exists(config_dict["RESULT_PATH"]):
70 | os.makedirs(config_dict["RESULT_PATH"])
71 |
72 | # save the figure
73 | plt.figure()
74 | plt.plot(iterations, tracker_dict["Qvalue_average_list"])
75 | plt.title("Average Q on held out set")
76 | plt.xlabel("Training iterations")
77 | plt.ylabel("Average Q value")
78 | plt.savefig(config_dict["RESULT_PATH"] + subfolder + "Average_Q_value.jpg")
79 |
80 | tracker_fname = subfolder + "-Eval.pkl"
81 | with open(config_dict["RESULT_PATH"] + tracker_fname, 'wb') as f:
82 | pickle.dump(tracker_dict, f)
83 |
84 |
85 |
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import torch.optim as optim
3 | import time
4 | # customized import
5 | from DQNs import *
6 | from utils import *
7 | from EnvManagers import AtariEnvManager
8 | from Agent import *
9 |
10 |
11 | param_json_fname = "DDQN_params.json"
12 | model_list_fname = "./eval_model_list_txt/ModelName:Dueling_DQN_2016_Modified-GameName:Pong-Time:04-07-2020-18-30-30.txt" #TODO
13 |
14 | config_dict, hyperparams_dict, eval_dict = read_json(param_json_fname)
15 |
16 | with open(model_list_fname) as f:
17 | model_list = f.readlines()
18 | model_list = [x.strip() for x in model_list] # remove whitespace characters like `\n` at the end of each line
19 | subfolder = model_list_fname.split("/")[-1][:-4]
20 | # setup params
21 |
22 | # Auxiliary variables
23 | tracker_dict = {}
24 | tracker_dict["UPDATE_PER_CHECKPOINT"] = config_dict["UPDATE_PER_CHECKPOINT"]
25 | tracker_dict["eval_reward_list"] = []
26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27 | # evaluate each checkpoint model
28 | for model_fpath in model_list:
29 | print("testing: ",model_fpath)
30 |
31 | # load model from file
32 | em = AtariEnvManager(device,game_env=config_dict["GAME_ENV"],
33 | is_use_additional_ending_criterion= eval_dict["IS_USE_ADDITIONAL_ENDING_CRITERION"])
34 | if config_dict["MODEL_NAME"] == "DQN_CNN_2015":
35 | policy_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=False).to(device)
36 | elif config_dict["MODEL_NAME"] == "Dueling_DQN_2016_Modified":
37 | policy_net = Dueling_DQN_2016_Modified(num_classes=em.num_actions_available(), init_weights=False).to(device)
38 | else:
39 | print("No such model! Please check your configuration in .json file")
40 | policy_net.load_state_dict(torch.load(model_fpath))
41 | policy_net.eval() # this network will only be used for inference.
42 | # setup greedy strategy and Agent class
43 | strategy = FullGreedyStrategy(0.01)
44 | agent = Agent(strategy, em.num_actions_available(), device)
45 |
46 | best_frames_for_gif = None
47 | best_reward = -99999
48 |
49 | reward_list_episodes = []
50 | for episode in range(eval_dict["EVAL_EPISODE"]):
51 | em.reset()
52 | state = em.get_state() # initialize state
53 | reward_per_episode = 0
54 | frames_for_gif = []
55 |
56 | while(1):
57 | if eval_dict["IS_RENDER_GAME_PROCESS"]: em.env.render() # render in 'human' mode. BZX: will this slow down the speed?
58 | if eval_dict["IS_VISUALIZE_STATE"]: visualize_state(state)
59 | frame = em.render('rgb_array')
60 | frames_for_gif.append(frame)
61 | # Given s, select a by strategy
62 | if em.done or em.is_additional_ending or em.is_initial_action():
63 | action = torch.tensor([1])
64 | else:
65 | action = agent.select_action(state, policy_net)
66 | # take action
67 | reward = em.take_action(action)
68 | # collect unclipped reward from env along the action
69 | reward_per_episode += reward
70 | # after taking action a, get s'
71 | next_state = em.get_state()
72 | # update current state
73 | state = next_state
74 |
75 | if em.done:
76 | reward_list_episodes.append(reward_per_episode.cpu().item())
77 | break
78 | # track the best episode for the gif
79 | if reward_per_episode > best_reward:
80 | best_reward = reward_per_episode
81 | #update best frames list
82 | best_frames_for_gif = frames_for_gif.copy()
83 |
84 | tracker_dict["eval_reward_list"].append(np.median(reward_list_episodes))
85 | # print( tracker_dict["eval_reward_list"])
86 | # save results
87 | if best_reward > 0.8 * np.median(tracker_dict["eval_reward_list"]):
88 | model_name = model_fpath.split("/")[-1][:-4]
89 | generate_gif(eval_dict["GIF_SAVE_PATH"] + subfolder + "/", model_name, best_frames_for_gif, best_reward.cpu().item())
90 |
91 | if not os.path.exists(config_dict["RESULT_PATH"]):
92 | os.makedirs(config_dict["RESULT_PATH"])
93 |
94 | tracker_fname = subfolder + "-Eval.pkl"
95 | with open(config_dict["RESULT_PATH"] + tracker_fname, 'wb') as f:
96 | pickle.dump(tracker_dict, f)
97 |
98 |
99 |
100 |
101 |
102 |
103 |
--------------------------------------------------------------------------------
/CartPole.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import math
3 | import random
4 | import numpy as np
5 | import matplotlib
6 | import matplotlib.pyplot as plt
7 | import os
8 | import sys
9 | import torch.nn.functional as F
10 | import datetime
11 | from itertools import count
12 | import torch.nn.functional as F
13 | from PIL import Image
14 | import torch
15 | import torch.optim as optim
16 | import torchvision.transforms as T
17 | # customized import
18 | from DQNs import DQN
19 | from utils import *
20 | from EnvManagers import CartPoleEnvManager
21 | import pickle
22 |
23 |
24 |
25 | class Agent():
26 | def __init__(self, strategy, num_actions, device):
27 | self.current_step = 0
28 | self.strategy = strategy
29 | self.num_actions = num_actions
30 | self.device = device
31 |
32 | def select_action(self, state, policy_net):
33 | rate = self.strategy.get_exploration_rate(self.current_step)
34 | self.current_step += 1
35 |
36 | if rate > random.random():
37 | action = random.randrange(self.num_actions)
38 | return torch.tensor([action]).to(self.device) # explore
39 | else:
40 | with torch.no_grad():
41 | return policy_net(state).argmax(dim=1).to(self.device) # exploit
42 |
43 | # Configuration:
44 | CHECK_POINT_PATH = "./checkpoints/"
45 | FIGURES_PATH = "./figures/"
46 | GAME_NAME = "CartPole"
47 | DATE_FORMAT = "%m-%d-%Y-%H-%M-%S"
48 | EPISODES_PER_CHECKPOINT = 1000
49 |
50 | # Hyperparameters
51 | batch_size = 32
52 | gamma = 0.99
53 | eps_start = 1
54 | eps_end = 0.1
55 | # eps_decay = 0.001
56 | eps_kneepoint = 500000
57 | target_update = 10
58 | memory_size = 100000
59 | lr = 0.0005
60 | num_episodes = 100000
61 |
62 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63 | em = CartPoleEnvManager(device)
64 | strategy = EpsilonGreedyStrategyLinear(eps_start, eps_end, eps_kneepoint)
65 | agent = Agent(strategy, em.num_actions_available(), device)
66 | memory = ReplayMemory(memory_size)
67 |
68 | policy_net = DQN(em.get_screen_height(), em.get_screen_width()).to(device)
69 | target_net = DQN(em.get_screen_height(), em.get_screen_width()).to(device)
70 | target_net.load_state_dict(policy_net.state_dict())
71 | target_net.eval() # this network will only be used for inference.
72 | optimizer = optim.Adam(params=policy_net.parameters(), lr=lr)
73 | criterion = torch.nn.SmoothL1Loss()
74 |
75 | heldoutset_counter = 0
76 | HELD_OUT_SET = []
77 | episode_durations = []
78 | running_reward = 0
79 | plt.figure()
80 | # for episode in range(num_episodes):
81 | # em.reset()
82 | # state = em.get_state()
83 | #
84 | # for timestep in count():
85 | # action = agent.select_action(state, policy_net)
86 | # reward = em.take_action(action)
87 | # next_state = em.get_state()
88 | # if random.random() < 0.005:
89 | # HELD_OUT_SET.append(next_state.cpu().numpy())
90 | # if len(HELD_OUT_SET) == 2000:
91 | # heldoutset_file = open('heldoutset-CartPole-{}'.format(heldoutset_counter), 'wb')
92 | # pickle.dump(HELD_OUT_SET, heldoutset_file)
93 | # heldoutset_file.close()
94 | # HELD_OUT_SET = []
95 | # heldoutset_counter += 1
96 | #
97 | # memory.push(Experience(state, action, next_state, reward))
98 | # state = next_state
99 | #
100 | # if memory.can_provide_sample(batch_size):
101 | # experiences = memory.sample(batch_size)
102 | # states, actions, rewards, next_states = extract_tensors(experiences)
103 | #
104 | # current_q_values = QValues.get_current(policy_net, states, actions)
105 | # next_q_values = QValues.get_next(target_net, next_states)
106 | # target_q_values = (next_q_values * gamma) + rewards
107 | #
108 | # # loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))
109 | # loss = criterion(current_q_values, target_q_values.unsqueeze(1))
110 | # optimizer.zero_grad()
111 | # loss.backward()
112 | # optimizer.step()
113 | #
114 | # if em.done:
115 | # episode_durations.append(timestep)
116 | # running_reward = plot(episode_durations, 100)
117 | # break
118 | #
119 | # # BZX: checkpoint model
120 | # if episode % EPISODES_PER_CHECKPOINT == 0:
121 | # path = CHECK_POINT_PATH + GAME_NAME + "/"
122 | # if not os.path.exists(path):
123 | # os.makedirs(path)
124 | # torch.save(policy_net.state_dict(),
125 | # path + "Episodes:{}-Reward:{:.2f}-Time:".format(episode, running_reward) + \
126 | # datetime.datetime.now().strftime(DATE_FORMAT) + ".pth")
127 | # plt.savefig(FIGURES_PATH + "Episodes:{}-Time:".format(episode) + datetime.datetime.now().strftime(
128 | # DATE_FORMAT) + ".jpg")
129 | #
130 | # if episode % target_update == 0:
131 | # target_net.load_state_dict(policy_net.state_dict())
132 | #
133 | # em.close()
134 | em.get_state()
135 | for i in range(5):
136 | em.take_action(torch.tensor([1]))
137 | screen = em.get_state()
138 |
139 | plt.figure()
140 | plt.imshow(screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')
141 | plt.title('Non starting state example')
142 | plt.show()
143 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | # RL-Atari-gym
7 | Reinforcement Learning on Atari Games and Control
8 |
9 | Entry point of the program:
10 | - Breakout.py
11 |
12 |
13 | # How to run
14 |
15 | (1). Check DDQN_params.json and make sure every parameter is set correctly.
16 |
17 | ```markdown
18 | GAME_NAME      # The game's name. Used to create a new directory for saving your results.
19 | MODEL_NAME     # The algorithm/model you are using. This only affects how result files are named, so you still
20 |                # need to change the model instance manually.
21 | MAX_ITERATION  # In the original paper this is 25,000,000. Here we use 5,000,000 for Breakout (2,500,000 suffices for Pong).
22 | num_episodes   # Max number of episodes. By default this is a huge number, so this stop condition is normally never hit.
23 | 
24 | # The program stops when either of the above conditions is met.
25 | ```
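
For reference, this is how the scripts load these settings (a minimal sketch; `read_json` is the helper from `utils.py` that `evaluation.py` and `TestonHeldout.py` already use):

```python
# Minimal sketch: load DDQN_params.json the same way evaluation.py / TestonHeldout.py do.
from utils import read_json

config_dict, hyperparams_dict, eval_dict = read_json("DDQN_params.json")
print(config_dict["GAME_ENV"])        # e.g. "BreakoutDeterministic-v4"
print(config_dict["MAX_ITERATION"])   # training stop condition
print(hyperparams_dict["lr"])         # Adam learning rate
```
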
26 | (2). Select the **model** and **game environment instance** manually. Currently, we are mainly focusing on `DQN_CNN_2015` and `Dueling_DQN_2016_Modified`.
27 |
28 | (3). Run and pray :)
29 |
30 | NOTE: When the program is running, wait a couple of minutes and look at the estimated time printed in the
31 | console. Stop early and decrease `MAX_ITERATION` if you cannot wait that long. (Recommendation: 24h is typically
32 | a reasonable running time for a first training run. Since you can continue training your model later, give both
33 | yourself and your computer a rest, and check the saved figures to see whether your model has a promising future. Hope so ~ )
34 |
35 | # How to continue training the model
36 |
37 | Breakout.py automatically saves the middle-point state and variables for you if the program exits without an exception.
38 | 
39 | 1. Set the middle-point JSON file path (see the sketch below).
40 | 
41 | 2. Check DDQN_params.json and make sure every parameter is set correctly. Typically, you need to set a new `MAX_ITERATION`
42 | or `num_episodes`.
43 | 
44 | 3. Run and pray :)
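
For reference, a minimal sketch of the middle-point path file: the four keys below are exactly what `continue_training.py` writes into `tmp_middle_point_file_path.json` when it finishes (the `<timestamp>` parts are placeholders for the real file names):

```python
# Sketch only: these four keys mirror what continue_training.py saves on exit;
# the <timestamp> parts are placeholders.
import json

md_path_dict = {
    "mdMemFileName":       "./MiddlePoints/MiddlePoint_Memory_<timestamp>.pkl",
    "mdStateFileName":     "./MiddlePoints/MiddlePoint_State_<timestamp>.pkl",
    "md_Policy_Net_fName": "./MiddlePoints/MiddlePoint_Policy_Net_<timestamp>.pth",
    "md_Target_Net_fName": "./MiddlePoints/MiddlePoint_Target_Net_<timestamp>.pth",
}
with open("tmp_middle_point_file_path.json", "w") as fp:
    json.dump(md_path_dict, fp)
```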
45 |
46 | # How to evaluate the Model
47 |
48 | `evaluation.py` helps you evaluate the model. First, modify `param_json_fname` and `model_list_fname` to point to your
49 | directory. Second, change the game environment instance and the model instance. Then run.
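
For example, the two paths to edit at the top of `evaluation.py` (the checkpoint-list name below is a placeholder):

```python
# The two paths to edit at the top of evaluation.py; the .txt file lists the
# checkpoint files to evaluate, one path per line.
param_json_fname = "DDQN_params.json"
model_list_fname = "./eval_model_list_txt/<your-checkpoint-list>.txt"  # placeholder
```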
50 |
51 | # Results Structure
52 |
53 | The program will automatically create a directory structure like this:
54 |
55 | ```markdown
56 | ├── GIF_Reuslts
57 | │ └── ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28
58 | │     ├── Iterations:100000-Reward:0.69-Time:03-28-2020-18-20-27-EvalReward:0.0.gif
59 | │     └── Iterations:200000-Reward:0.69-Time:03-28-2020-18-20-27-EvalReward:1.0.gif
60 | ├── Results
61 | │ ├── ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28-Eval.pkl
62 | │ └── ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28.pkl
63 | ├── DDQN_params.json
64 |
65 | ```
66 |
67 | Please zip these three files/folders and upload them to our shared Google Drive. Rename the archive, e.g. `ModelName:2015_CNN_DQN-GameName:Breakout-Time:03-28-2020-18-20-28`.
68 |
69 | PS:
70 |
71 | `GIF_Reuslts` records the game process.
72 | 
73 | `Results` contains the history of the training and evaluation process, which can be used for visualization later.
74 |
75 | `DDQN_params.json` contains your algorithm settings, which should match your `Results` and `GIF_Reuslts`.
76 |
77 | # TODO list:
78 | -[ ] Write env classes for Pong, CartPole and other games. (Attention: the cropped bbox might need to change for different
79 | games.)
80 |
81 | -[ ] Write validation script on heldout sets. Load models and heldout sets, track average max Q value on heldout sets.
82 | (NOTE: load and test models in time sequence indicated by the name of model file.)
83 |
84 | -[ ] Design experiment table. Test two more Atari games. Give average performance(reward) and write .gif file. Store other figures & model for
85 | writing final report.
86 |
87 | -[ ] Implement policy gradient for Atari games. [TBD]
88 |
89 | -[ ] Possible bugs: initialization & final state representation?
90 |
91 | -[x] Implement Priority Queue and compare the performance.
92 |
93 | -[x] Evaluation script. Load model and take greedy strategy to interact with environment.
94 | Test a few epochs and give average performance. Write the best one to .gif file for presentation.
95 |
96 | -[x] Implement continuous training script.
97 |
98 | -[x] Fix eps manager (change final eps from 0.1 to 0.01); add an evaluation step in the training loop;
99 | write test results into gifs; update Target_Net according to the # of actions instead of the # of updates.
100 | Rewrite the image preprocessing class to handle more general game input (crop at (34,0,160,160)).
101 |
102 | # File Instructions (TO BE COMPLETED SOON):
103 | 
104 | `EnvManagers.py` includes the environment classes for different games. They wrap `gym.env` and its interface.
105 | 
106 | `DQNs.py` includes the different deep learning architectures for feature extraction and Q-value regression.
107 | 
108 | `utils.py` includes utility functions and classes. Specifically, it includes:
109 | - Experience (namedtuple)
110 | - ReplayMemory (class)
111 | - EpsilonGreedyStrategy (class)
112 | - plot (func)
113 | - extract_tensors (func)
114 | - QValues (class)
115 |
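A rough sketch of how these pieces fit together (call signatures follow their use in `continue_training.py` and the commented loop in `CartPole.py`; the dummy tensors below just stand in for real states from an environment manager):

```python
# Rough sketch only: mirrors the replay-memory flow used by the training scripts.
import torch
from utils import Experience, ReplayMemory, extract_tensors

memory = ReplayMemory(1000)
state      = torch.zeros(1, 4, 84, 84)   # dummy stand-in for a stacked-frame state
action     = torch.tensor([1])
reward     = torch.tensor([0.0])
next_state = torch.zeros(1, 4, 84, 84)

memory.push(Experience(state, action, next_state, reward))
if memory.can_provide_sample(32, 50000):  # (batch_size, replay_start_size)
    experiences = memory.sample(32)
    states, actions, rewards, next_states = extract_tensors(experiences)
```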
--------------------------------------------------------------------------------
/DQNs.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 | class DQN_NN_Naive(nn.Module):
6 | def __init__(self, img_height, img_width):
7 | super().__init__()
8 |
9 | self.fc1 = nn.Linear(in_features=img_height * img_width * 3, out_features=24)
10 | self.fc2 = nn.Linear(in_features=24, out_features=32)
11 | self.out = nn.Linear(in_features=32, out_features=2)
12 |
13 | def forward(self, t):
14 | t = t.flatten(start_dim=1)
15 | t = F.relu(self.fc1(t))
16 | t = F.relu(self.fc2(t))
17 | t = self.out(t)
18 | return t
19 |
20 |
21 | class DQN_CNN_2013(nn.Module):
22 | def __init__(self, num_classes=4, init_weights=True):
23 | super().__init__()
24 |
25 | self.cnn = nn.Sequential(nn.Conv2d(4, 16, kernel_size=8, stride=4),
26 | nn.ReLU(True),
27 | nn.Conv2d(16, 32, kernel_size=4, stride=2),
28 | nn.ReLU(True)
29 | )
30 | self.classifier = nn.Sequential(nn.Linear(9*9*32, 256),
31 | nn.ReLU(True),
32 | nn.Linear(256, num_classes)
33 | )
34 | # nn.Dropout(0.3), # BZX: optional [TRY]
35 | if init_weights:
36 | self._initialize_weights()
37 |
38 | def forward(self, x):
39 | x = self.cnn(x)
40 | x = torch.flatten(x, start_dim=1)
41 | x = self.classifier(x)
42 | return x
43 |
44 | def _initialize_weights(self):
45 | for m in self.modules():
46 | if isinstance(m, nn.Conv2d):
47 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
48 | if m.bias is not None:
49 | nn.init.constant_(m.bias, 0.0)
50 | elif isinstance(m, nn.BatchNorm2d):
51 | nn.init.constant_(m.weight, 1.0)
52 | nn.init.constant_(m.bias, 0.0)
53 | elif isinstance(m, nn.Linear):
54 | nn.init.normal_(m.weight, 0.0, 0.01)
55 | nn.init.constant_(m.bias, 0.0)
56 |
57 | class DQN_CNN_2015(nn.Module):
58 | def __init__(self, num_classes=4, init_weights=True):
59 | super().__init__()
60 |
61 | self.cnn = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4),
62 | nn.ReLU(True),
63 | nn.Conv2d(32, 64, kernel_size=4, stride=2),
64 | nn.ReLU(True),
65 | nn.Conv2d(64, 64, kernel_size=3, stride=1),
66 | nn.ReLU(True)
67 | )
68 | self.classifier = nn.Sequential(nn.Linear(7*7*64, 512),
69 | nn.ReLU(True),
70 | nn.Linear(512, num_classes)
71 | )
72 | if init_weights:
73 | self._initialize_weights()
74 |
75 | def forward(self, x):
76 | x = self.cnn(x)
77 | x = torch.flatten(x, start_dim=1)
78 | x = self.classifier(x)
79 | return x
80 |
81 | def _initialize_weights(self):
82 | for m in self.modules():
83 | if isinstance(m, nn.Conv2d):
84 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
85 | if m.bias is not None:
86 | nn.init.constant_(m.bias, 0.0)
87 | elif isinstance(m, nn.BatchNorm2d):
88 | nn.init.constant_(m.weight, 1.0)
89 | nn.init.constant_(m.bias, 0.0)
90 | elif isinstance(m, nn.Linear):
91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
92 | nn.init.constant_(m.bias, 0.0)
93 |
94 | #TODO
95 | class Dueling_DQN_2016_Modified(nn.Module):
96 | def __init__(self, num_classes=4, init_weights=True):
97 | super().__init__()
98 |
99 | self.cnn = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4,bias=False),
100 | nn.ReLU(True),
101 | nn.Conv2d(32, 64, kernel_size=4, stride=2,bias=False),
102 | nn.ReLU(True),
103 | nn.Conv2d(64, 64, kernel_size=3, stride=1,bias=False),
104 | nn.ReLU(True),
105 | nn.Conv2d(64,1024,kernel_size=7,stride=1,bias=False),
106 | nn.ReLU(True)
107 | )
108 | self.streamA = nn.Linear(512, num_classes)
109 | self.streamV = nn.Linear(512, 1)
110 |
111 | if init_weights:
112 | self._initialize_weights()
113 |
114 | def forward(self, x):
115 | x = self.cnn(x)
116 | sA,sV = torch.split(x,512,dim = 1)
117 | sA = torch.flatten(sA,start_dim=1)
118 | sV = torch.flatten(sV, start_dim=1)
119 | sA = self.streamA(sA) #(B,4)
120 | sV = self.streamV(sV) #(B,1)
121 | # combine the two streams into the Q-value
122 | Q_value = sV + (sA - torch.mean(sA,dim=1,keepdim=True))
123 | return Q_value #(B,4)
124 |
125 | def _initialize_weights(self):
126 | for m in self.modules():
127 | if isinstance(m, nn.Conv2d):
128 | nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
129 | if m.bias is not None:
130 | nn.init.constant_(m.bias, 0.0)
131 | elif isinstance(m, nn.BatchNorm2d):
132 | nn.init.constant_(m.weight, 1.0)
133 | nn.init.constant_(m.bias, 0.0)
134 | elif isinstance(m, nn.Linear):
135 | nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
136 | nn.init.constant_(m.bias, 0.0)
137 |
--------------------------------------------------------------------------------
/EnvManagers.py:
--------------------------------------------------------------------------------
1 | """
2 | 1. Wraps the Gym env object.
3 | 2. Like a dataloader, this class does the image preprocessing.
4 | """
5 | import gym
6 | import numpy as np
7 | import torch
8 | import torchvision.transforms as T
9 |
10 | class AtariEnvManager():
11 | def __init__(self, device, game_env, is_use_additional_ending_criterion):
12 | "available game envs: PongDeterministic-v4, BreakoutDeterministic-v4"
13 | self.device = device
14 | self.game_env = game_env
15 | # PongDeterministic-v4
16 | self.env = gym.make(game_env).unwrapped
17 | # self.env = gym.make('BreakoutDeterministic-v4').unwrapped #BZX: v4 automatically return skipped screens
18 | self.env.reset()
19 | self.current_screen = None
20 | self.done = False
21 | # BZX: running_K: stacked K images together to present a state
22 | # BZx: running_queue: maintain the latest running_K images
23 | self.running_K = 4
24 | self.running_queue = []
25 | self.is_additional_ending = False # This may change to True during the game. 2 possible reasons: loss of a life; negative reward.
26 | self.current_lives = None
27 | self.is_use_additional_ending_criterion = is_use_additional_ending_criterion
28 |
29 | def reset(self):
30 | self.env.reset()
31 | self.current_screen = None
32 | self.running_queue = [] #BZX: clear the state
33 | self.is_additional_ending = False
34 | # self.current_lives = 5
35 |
36 | def close(self):
37 | self.env.close()
38 |
39 | def render(self, mode='human'):
40 | return self.env.render(mode)
41 |
42 | def num_actions_available(self):
43 | return self.env.action_space.n
44 |
45 | def print_action_meanings(self):
46 | print(self.env.get_action_meanings())
47 |
48 | def take_action(self, action):
49 | _, reward, self.done, lives = self.env.step(action.item())
50 | # for Pong: lives is always 0.0, so self.is_additional_ending will always be false
51 | if self.is_use_additional_ending_criterion:
52 | self.is_additonal_ending_criterion_met(lives,reward)
53 | return torch.tensor([reward], device=self.device)
54 |
55 | def just_starting(self):
56 | return self.current_screen is None
57 | def is_initial_action(self):
58 | return sum(self.running_queue).sum() == 0
59 |
60 | def init_running_queue(self):
61 | """
62 | initialize running queue with K black images
63 | :return:
64 | """
65 | self.current_screen = self.get_processed_screen()
66 | black_screen = torch.zeros_like(self.current_screen)
67 | for _ in range(self.running_K):
68 | self.running_queue.append(black_screen)
69 |
70 | def get_state(self):
71 | if self.just_starting():
72 | self.init_running_queue()
73 | elif self.done or self.is_additional_ending:
74 | self.current_screen = self.get_processed_screen()
75 | black_screen = torch.zeros_like(self.current_screen)
76 | # BZX: update running_queue
77 | self.running_queue.pop(0)
78 | self.running_queue.append(black_screen)
79 | else: #BZX: normal case
80 | s2 = self.get_processed_screen()
81 | self.current_screen = s2
82 | # BZx: update running_queue with s2
83 | self.running_queue.pop(0)
84 | self.running_queue.append(s2)
85 |
86 | return torch.stack(self.running_queue,dim=1).squeeze(2) #BZX: check if shape is (1KHW)
87 |
88 | def is_additonal_ending_criterion_met(self,lives,reward):
89 | "for different Atari games, design different ending-state criteria"
90 | if self.game_env == "BreakoutDeterministic-v4":
91 | if self.is_initial_action():
92 | self.current_lives = lives['ale.lives']
93 | elif lives['ale.lives'] < self.current_lives:
94 | self.is_additional_ending = True
95 | else:
96 | self.is_additional_ending = False
97 | self.current_lives = lives['ale.lives']
98 | if self.game_env == "PongDeterministic-v4":
99 | if reward < 0: # missing one ball leads to an ending state.
100 | self.is_additional_ending = True
101 | else:
102 | self.is_additional_ending = False
103 | return False
104 |
105 | def get_screen_height(self):
106 | screen = self.get_processed_screen()
107 | return screen.shape[2]
108 |
109 | def get_screen_width(self):
110 | screen = self.get_processed_screen()
111 | return screen.shape[3]
112 |
113 | def get_processed_screen(self):
114 | screen = self.render('rgb_array').transpose((2, 0, 1)) # PyTorch expects CHW
115 | # screen = self.crop_screen(screen)
116 | return self.transform_screen_data(screen) #shape is [1,1,110,84]
117 |
118 | def crop_screen(self, screen):
119 | if self.game_env == "BreakoutDeterministic-v4" or self.game_env == "PongDeterministic-v4":
120 | bbox = [34,0,160,160] #(x,y,delta_x,delta_y)
121 | screen = screen[:, bbox[0]:bbox[2]+bbox[0], bbox[1]:bbox[3]+bbox[1]] #BZX:(CHW)
122 | if self.game_env == "GopherDeterministic-v4":
123 | bbox = [110,0,120,160]
124 | screen = screen[:, bbox[0]:bbox[2]+bbox[0], bbox[1]:bbox[3]+bbox[1]]
125 | return screen
126 |
127 | def transform_screen_data(self, screen):
128 | # Convert to float, rescale, convert to tensor
129 | screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
130 | screen = torch.from_numpy(screen)
131 | screen = self.crop_screen(screen)
132 | # Use torchvision package to compose image transforms
133 | resize = T.Compose([
134 | T.ToPILImage()
135 | , T.Grayscale()
136 | , T.Resize((84, 84)) # BZX: the original paper's setting is (110,84); however, for simplicity we...
137 | , T.ToTensor()
138 | ])
139 | # add a batch dimension (BCHW)
140 | screen = resize(screen)
141 |
142 | return screen.unsqueeze(0).to(self.device) # BZX: Pay attention to the shape here. should be [1,1,84,84]
143 |
144 |
145 | class CartPoleEnvManager():
146 |
147 | def __init__(self, device):
148 | self.device = device
149 | self.env = gym.make('CartPole-v0').unwrapped
150 | self.env.reset()
151 | self.current_screen = None
152 | self.done = False
153 |
154 | def reset(self):
155 | self.env.reset()
156 | self.current_screen = None
157 |
158 | def close(self):
159 | self.env.close()
160 |
161 | def render(self, mode='human'):
162 | return self.env.render(mode)
163 |
164 | def num_actions_available(self):
165 | return self.env.action_space.n
166 |
167 | def take_action(self, action):
168 | _, reward, self.done, _ = self.env.step(action.item())
169 | return torch.tensor([reward], device=self.device)
170 |
171 | def just_starting(self):
172 | return self.current_screen is None
173 |
174 | def get_state(self):
175 | if self.just_starting() or self.done:
176 | self.current_screen = self.get_processed_screen()
177 | black_screen = torch.zeros_like(self.current_screen)
178 | return black_screen
179 | else:
180 | s1 = self.current_screen
181 | s2 = self.get_processed_screen()
182 | self.current_screen = s2
183 | return s2 - s1
184 |
185 | def get_screen_height(self):
186 | screen = self.get_processed_screen()
187 | return screen.shape[2]
188 |
189 | def get_screen_width(self):
190 | screen = self.get_processed_screen()
191 | return screen.shape[3]
192 |
193 | def get_processed_screen(self):
194 | screen = self.render('rgb_array').transpose((2, 0, 1)) # PyTorch expects CHW
195 | screen = self.crop_screen(screen)
196 | return self.transform_screen_data(screen)
197 |
198 | def crop_screen(self, screen):
199 | screen_height = screen.shape[1]
200 |
201 | # Strip off top and bottom
202 | top = int(screen_height * 0.4)
203 | bottom = int(screen_height * 0.8)
204 | screen = screen[:, top:bottom, :]
205 | return screen
206 |
207 | def transform_screen_data(self, screen):
208 | # Convert to float, rescale, convert to tensor
209 | screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
210 | screen = torch.from_numpy(screen)
211 |
212 | # Use torchvision package to compose image transforms
213 | resize = T.Compose([
214 | T.ToPILImage()
215 | , T.Resize((40, 90))
216 | , T.ToTensor()
217 | ])
218 |
219 | return resize(screen).unsqueeze(0).to(self.device) # add a batch dimension (BCHW)
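220 | 
221 | if __name__ == "__main__":
222 |     # Minimal usage sketch (assumes gym's Atari environments are installed):
223 |     # step the Breakout wrapper a few frames and check the stacked-frame state
224 |     # shape that the DQNs expect, (1, 4, 84, 84).
225 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
226 |     em = AtariEnvManager(device, "BreakoutDeterministic-v4", is_use_additional_ending_criterion=False)
227 |     em.reset()
228 |     state = em.get_state()
229 |     print(state.shape)                          # torch.Size([1, 4, 84, 84])
230 |     reward = em.take_action(torch.tensor([1]))  # action 1 is FIRE in Breakout
231 |     print(reward.item())
232 |     em.close()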
--------------------------------------------------------------------------------
/continue_training.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import torch.optim as optim
3 | import time
4 | # customized import
5 | from DQNs import *
6 | from utils import *
7 | from EnvManagers import AtariEnvManager
8 | from Agent import *
9 |
10 | # load params
11 | param_json_fname = "DDQN_params.json" #TODO: please make sure the params are set right
12 | config_dict, hyperparams_dict, eval_dict = read_json(param_json_fname)
13 |
14 | # load middle point file path
15 | print("="*100)
16 | print("loading mid point file...")
17 | Middle_Point_json = "tmp_middle_point_file_path.json" #TODO
18 | md_path_dict = load_Middle_Point(Middle_Point_json)
19 |
20 | heldout_saver = HeldoutSaver(config_dict["HELDOUT_SET_DIR"],
21 | config_dict["HELDOUT_SET_MAX_PER_BATCH"],
22 | config_dict["HELDOUT_SAVE_RATE"])
23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 | em = AtariEnvManager(device, config_dict["GAME_ENV"], config_dict["IS_USE_ADDITIONAL_ENDING_CRITERION"])
25 |
26 | # load other states
27 | with open(md_path_dict["mdStateFileName"], 'rb') as middle_point_state_file:
28 | midddle_point = pickle.load(middle_point_state_file)
29 | agent = midddle_point["agent"]
30 | tracker_dict = midddle_point["tracker_dict"]
31 | heldout_saver.set_batch_counter(midddle_point["heldout_batch_counter"])
32 | strategy = midddle_point["strategy"]
33 |
34 | # load memory
35 | with open(md_path_dict["mdMemFileName"], 'rb') as middle_point_mem_file:
36 | memory = pickle.load(middle_point_mem_file)
37 |
38 | # load 2 networks
39 | policy_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=False).to(device)
40 | target_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=False).to(device)
41 | policy_net.load_state_dict(torch.load(md_path_dict["md_Policy_Net_fName"]))
42 | target_net.load_state_dict(torch.load(md_path_dict["md_Target_Net_fName"]))
43 | target_net.eval() # this network will only be used for inference.
44 |
45 | print("successfully load all middle point files")
46 | print("="*100)
47 |
48 | # initialize optimizer and criterion
49 | optimizer = optim.Adam(params=policy_net.parameters(), lr=hyperparams_dict["lr"])
50 | criterion = torch.nn.SmoothL1Loss()
51 |
52 | plt.figure()
53 |
54 | t1,t2 = time.time(),time.time() # for estimating the time
55 | num_target_update = 0 # auxiliary variable for estimating the time
56 |
57 | for episode in range(hyperparams_dict["num_episodes"]):
58 | em.reset()
59 | state = em.get_state() # initialize state
60 | tol_reward = 0
61 | while(1):
62 | # Visualization of game process and state
63 | if config_dict["IS_RENDER_GAME_PROCESS"]: em.env.render() # BZX: will this slow down the speed?
64 | if config_dict["IS_VISUALIZE_STATE"]: visualize_state(state)
65 | if config_dict["IS_GENERATE_HELDOUT"]: heldout_saver.append(state) # generate heldout set for offline eval
66 |
67 | # Given s, select a by either policy_net or random
68 | action = agent.select_action(state, policy_net)
69 | # collect reward from env along the action
70 | reward = em.take_action(action)
71 | tol_reward += reward
72 | # after took a, get s'
73 | next_state = em.get_state()
74 | # push (s,a,s',r) into memory
75 | memory.push(Experience(state[0,-1,:,:].clone(), action, "", reward))
76 | # update current state
77 | state = next_state
78 |
79 | # After memory has been filled with enough samples, we update policy_net every 4 agent steps.
80 | if (agent.current_step % hyperparams_dict["action_repeat"] == 0) and \
81 | memory.can_provide_sample(hyperparams_dict["batch_size"], hyperparams_dict["replay_start_size"]):
82 |
83 | experiences = memory.sample(hyperparams_dict["batch_size"])
84 | states, actions, rewards, next_states = extract_tensors(experiences)
85 | current_q_values = QValues.get_current(policy_net, states, actions) # checked
86 | # next_q_values = QValues.DQN_get_next(target_net, next_states) # for DQN
87 | next_q_values = QValues.DDQN_get_next(policy_net,target_net, next_states)
88 | target_q_values = (next_q_values * hyperparams_dict["gamma"]) + rewards
89 | # calculate loss and update policy_net
90 | optimizer.zero_grad()
91 | loss = criterion(current_q_values, target_q_values.unsqueeze(1))
92 | loss.backward()
93 | optimizer.step()
94 |
95 | tracker_dict["loss_hist"].append(loss.item())
96 | tracker_dict["minibatch_updates_counter"] += 1
97 |
98 | # update target_net
99 | if tracker_dict["minibatch_updates_counter"] % hyperparams_dict["target_update"] == 0:
100 | target_net.load_state_dict(policy_net.state_dict())
101 |
102 | # estimate time
103 | num_target_update += 1
104 | if num_target_update % 2 == 0: t1 = time.time()
105 | if num_target_update % 2 == 1: t2 = time.time()
106 | print("=" * 50)
107 | remaining_update_times = (config_dict["MAX_ITERATION"] - tracker_dict["minibatch_updates_counter"])// \
108 | hyperparams_dict["target_update"]
109 | time_sec = abs(t1-t2) * remaining_update_times
110 | print("estimated remaining time = {}h-{}min".format(time_sec//3600,(time_sec%3600)//60))
111 | print("len of replay memory:", len(memory.memory))
112 | print("minibatch_updates_counter = ", tracker_dict["minibatch_updates_counter"])
113 | print("current_step of agent = ", agent.current_step)
114 | print("exploration rate = ", strategy.get_exploration_rate(agent.current_step))
115 | print("=" * 50)
116 |
117 | # save checkpoint model
118 | if tracker_dict["minibatch_updates_counter"] % config_dict["UPDATE_PER_CHECKPOINT"] == 0:
119 | save_model(policy_net, tracker_dict, config_dict)
120 |
121 | plt.savefig(config_dict["FIGURES_PATH"] + "Iterations:{}-Time:".format(tracker_dict["minibatch_updates_counter"]) + datetime.datetime.now().strftime(
122 | config_dict["DATE_FORMAT"]) + ".jpg")
123 |
124 | if em.done:
125 | tracker_dict["rewards_hist"].append(tol_reward)
126 | tracker_dict["running_reward"] = plot(tracker_dict["rewards_hist"], 100)
127 | break
128 | if config_dict["IS_BREAK_BY_MAX_ITERATION"] and \
129 | tracker_dict["minibatch_updates_counter"] > config_dict["MAX_ITERATION"]:
130 | break
131 |
132 | em.close()
133 | # save loss figure
134 | plt.figure()
135 | plt.plot(tracker_dict["loss_hist"])
136 | plt.title("loss")
137 | plt.xlabel("iterations")
138 | plt.savefig(config_dict["FIGURES_PATH"] + "Loss-Iterations:{}-Time:".format(tracker_dict["minibatch_updates_counter"]) + datetime.datetime.now().strftime(
139 | config_dict["DATE_FORMAT"]) + ".jpg")
140 |
141 | if config_dict["IS_SAVE_MIDDLE_POINT"]:
142 | # save core instances
143 | if not os.path.exists(config_dict["MIDDLE_POINT_PATH"]):
144 | os.makedirs(config_dict["MIDDLE_POINT_PATH"])
145 |
146 | mdMemFileName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Memory_" + datetime.datetime.now().strftime(
147 | config_dict["DATE_FORMAT"]) + ".pkl"
148 | middle_mem_file = open(mdMemFileName, 'wb')
149 | pickle.dump(memory, middle_mem_file)
150 | middle_mem_file.close()
151 | del memory # make more memory space
152 |
153 | midddle_point = {}
154 | midddle_point["agent"] = agent
155 | midddle_point["tracker_dict"] = tracker_dict
156 | midddle_point["heldout_batch_counter"] = heldout_saver.batch_counter
157 | midddle_point["strategy"] = strategy
158 | mdStateFileName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_State_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pkl"
159 |
160 | middle_point_file = open(mdStateFileName, 'wb')
161 | pickle.dump(midddle_point, middle_point_file)
162 | middle_point_file.close()
163 |
164 | # save policy_net and target_net
165 | md_Policy_Net_fName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Policy_Net_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pth"
166 | torch.save(policy_net.state_dict(),md_Policy_Net_fName)
167 | md_Target_Net_fName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Target_Net_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pth"
168 | torch.save(target_net.state_dict(), md_Target_Net_fName)
169 |
170 | # save middle point files' path for continuous training
171 | md_path_dict = {}
172 | md_path_dict["mdMemFileName"] = mdMemFileName
173 | md_path_dict["mdStateFileName"] = mdStateFileName
174 | md_path_dict["md_Policy_Net_fName"] = md_Policy_Net_fName
175 | md_path_dict["md_Target_Net_fName"] = md_Target_Net_fName
176 | with open('tmp_middle_point_file_path.json', 'w') as fp:
177 | json.dump(md_path_dict, fp)
--------------------------------------------------------------------------------
/Q_Learning_IceLake.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {
7 | "collapsed": true,
8 | "pycharm": {
9 | "is_executing": false
10 | }
11 | },
12 | "outputs": [],
13 | "source": [
14 | "import numpy as np\n",
15 | "import gym\n",
16 | "import random\n",
17 | "import time\n",
18 | "from IPython.display import clear_output\n"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 10,
24 | "outputs": [
25 | {
26 | "name": "stdout",
27 | "text": [
28 | "[[0. 0. 0. 0.]\n",
29 | " [0. 0. 0. 0.]\n",
30 | " [0. 0. 0. 0.]\n",
31 | " [0. 0. 0. 0.]\n",
32 | " [0. 0. 0. 0.]\n",
33 | " [0. 0. 0. 0.]\n",
34 | " [0. 0. 0. 0.]\n",
35 | " [0. 0. 0. 0.]\n",
36 | " [0. 0. 0. 0.]\n",
37 | " [0. 0. 0. 0.]\n",
38 | " [0. 0. 0. 0.]\n",
39 | " [0. 0. 0. 0.]\n",
40 | " [0. 0. 0. 0.]\n",
41 | " [0. 0. 0. 0.]\n",
42 | " [0. 0. 0. 0.]\n",
43 | " [0. 0. 0. 0.]]\n"
44 | ],
45 | "output_type": "stream"
46 | }
47 | ],
48 | "source": [
49 | "env = gym.make(\"FrozenLake-v0\")\n",
50 | "action_space_size = env.action_space.n\n",
51 | "state_space_size = env.observation_space.n\n",
52 | "\n",
53 | "q_table = np.zeros((state_space_size, action_space_size))\n",
54 | "print(q_table)"
55 | ],
56 | "metadata": {
57 | "collapsed": false,
58 | "pycharm": {
59 | "name": "#%%\n",
60 | "is_executing": false
61 | }
62 | }
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 11,
67 | "outputs": [],
68 | "source": [
69 | "num_episodes = 20000\n",
70 | "max_steps_per_episode = 100\n",
71 | "\n",
72 | "learning_rate = 0.1\n",
73 | "discount_rate = 0.99\n",
74 | "\n",
75 | "exploration_rate = 1\n",
76 | "max_exploration_rate = 1\n",
77 | "min_exploration_rate = 0.01\n",
78 | "exploration_decay_rate = 0.01"
79 | ],
80 | "metadata": {
81 | "collapsed": false,
82 | "pycharm": {
83 | "name": "#%% fff\n",
84 | "is_executing": false
85 | }
86 | }
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "source": [
91 | "### Q-learning algorithm\n",
92 | "```\n",
93 | "for episode in range(num_episodes):\n",
94 | " # initialize new episode params\n",
95 | " \n",
96 | " for step in range(max_steps_per_episode): \n",
97 | " # Exploration-exploitation trade-off\n",
98 | " # Take new action\n",
99 | " # Update Q-table\n",
100 | " # Set new state\n",
101 | " # Add new reward \n",
102 | "\n",
103 | " # Exploration rate decay \n",
104 | " # Add current episode reward to total rewards list\n",
105 | "```"
106 | ],
107 | "metadata": {
108 | "collapsed": false
109 | }
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 12,
114 | "outputs": [],
115 | "source": [
116 | "rewards_all_episodes = []\n",
117 | "for episode in range(num_episodes):\n",
118 | " state = env.reset()\n",
119 | " done = False\n",
120 | " rewards_current_episode = 0\n",
121 | "\n",
122 | " for step in range(max_steps_per_episode): \n",
123 | " # Exploration-exploitation trade-off\n",
124 | " exploration_rate_threshold = random.uniform(0, 1)\n",
125 | " if exploration_rate_threshold > exploration_rate:\n",
126 | " action = np.argmax(q_table[state,:]) \n",
127 | " else:\n",
128 | " action = env.action_space.sample()\n",
129 | " #take new action\n",
130 | " new_state, reward, done, info = env.step(action)\n",
131 | " # Update Q-table for Q(s,a)\n",
132 | " q_table[state, action] = q_table[state, action] * (1 - learning_rate) + \\\n",
133 | " learning_rate * (reward + discount_rate * np.max(q_table[new_state, :]))\n",
134 | " \n",
135 | " state = new_state\n",
136 | " rewards_current_episode += reward \n",
137 | " \n",
138 | " if done == True: \n",
139 | " break\n",
140 | " \n",
141 | " # Exploration rate decay\n",
142 | " exploration_rate = min_exploration_rate + \\\n",
143 | " (max_exploration_rate - min_exploration_rate) * np.exp(-exploration_decay_rate*episode)\n",
144 | " rewards_all_episodes.append(rewards_current_episode)\n",
145 | " "
146 | ],
147 | "metadata": {
148 | "collapsed": false,
149 | "pycharm": {
150 | "name": "#%%\n",
151 | "is_executing": false
152 | }
153 | }
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 13,
158 | "outputs": [
159 | {
160 | "name": "stdout",
161 | "text": [
162 | "********Average reward per thousand episodes********\n",
163 | "\n",
164 | "1000 : 0.04100000000000003\n",
165 | "2000 : 0.05500000000000004\n",
166 | "3000 : 0.08400000000000006\n",
167 | "4000 : 0.10800000000000008\n",
168 | "5000 : 0.10900000000000008\n",
169 | "6000 : 0.1320000000000001\n",
170 | "7000 : 0.1460000000000001\n",
171 | "8000 : 0.17500000000000013\n",
172 | "9000 : 0.24000000000000019\n",
173 | "10000 : 0.2620000000000002\n",
174 | "11000 : 0.45800000000000035\n",
175 | "12000 : 0.6610000000000005\n",
176 | "13000 : 0.6670000000000005\n",
177 | "14000 : 0.6930000000000005\n",
178 | "15000 : 0.6980000000000005\n",
179 | "16000 : 0.6840000000000005\n",
180 | "17000 : 0.6650000000000005\n",
181 | "18000 : 0.7030000000000005\n",
182 | "19000 : 0.6820000000000005\n",
183 | "20000 : 0.6890000000000005\n"
184 | ],
185 | "output_type": "stream"
186 | }
187 | ],
188 | "source": [
189 | "# Calculate and print the average reward per thousand episodes\n",
190 |     "rewards_per_thousand_episodes = np.split(np.array(rewards_all_episodes), num_episodes//1000)\n",
191 | "count = 1000\n",
192 | "\n",
193 | "print(\"********Average reward per thousand episodes********\\n\")\n",
194 |     "for r in rewards_per_thousand_episodes:\n",
195 | " print(count, \": \", str(sum(r/1000)))\n",
196 | " count += 1000\n"
197 | ],
198 | "metadata": {
199 | "collapsed": false,
200 | "pycharm": {
201 | "name": "#%%\n",
202 | "is_executing": false
203 | }
204 | }
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": 15,
209 | "outputs": [
210 | {
211 | "name": "stdout",
212 | "text": [
213 | "\n",
214 | "\n",
215 | "********Q-table********\n",
216 | "\n",
217 | "[[0.56155668 0.50676359 0.51113332 0.51633663]\n",
218 | " [0.31394613 0.31218741 0.24458661 0.49321051]\n",
219 | " [0.41637602 0.37275899 0.41689343 0.45542588]\n",
220 | " [0.20679537 0.29910297 0.36394872 0.42988882]\n",
221 | " [0.58086203 0.41773475 0.44126002 0.42910216]\n",
222 | " [0. 0. 0. 0. ]\n",
223 | " [0.33825872 0.14911566 0.18237918 0.0651859 ]\n",
224 | " [0. 0. 0. 0. ]\n",
225 | " [0.33209066 0.48196698 0.49282896 0.64355292]\n",
226 | " [0.2918537 0.70621805 0.53639371 0.44925847]\n",
227 | " [0.65793949 0.41190176 0.40708876 0.33627327]\n",
228 | " [0. 0. 0. 0. ]\n",
229 | " [0. 0. 0. 0. ]\n",
230 | " [0.4509215 0.62594686 0.76507311 0.51494189]\n",
231 | " [0.74501853 0.84665135 0.77416626 0.77714155]\n",
232 | " [0. 0. 0. 0. ]]\n"
233 | ],
234 | "output_type": "stream"
235 | }
236 | ],
237 | "source": [
238 | "# Print updated Q-table\n",
239 | "print(\"\\n\\n********Q-table********\\n\")\n",
240 | "print(q_table)"
241 | ],
242 | "metadata": {
243 | "collapsed": false,
244 | "pycharm": {
245 | "name": "#%%\n",
246 | "is_executing": false
247 | }
248 | }
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "source": [
253 | "### The code to watch the agent play the game \n",
254 | "```\n",
255 | "# Watch our agent play Frozen Lake by playing the best action \n",
256 | "# from each state according to the Q-table\n",
257 | "\n",
258 | "for episode in range(3):\n",
259 | " # initialize new episode params\n",
260 | "\n",
261 | " for step in range(max_steps_per_episode): \n",
262 | " # Show current state of environment on screen\n",
263 | " # Choose action with highest Q-value for current state \n",
264 | " # Take new action\n",
265 | " \n",
266 | " if done:\n",
267 | " if reward == 1:\n",
268 | " # Agent reached the goal and won episode\n",
269 | " else:\n",
270 | " # Agent stepped in a hole and lost episode \n",
271 | " \n",
272 | " # Set new state\n",
273 | " \n",
274 | "env.close()\n",
275 | "```\n",
276 | "\n"
277 | ],
278 | "metadata": {
279 | "collapsed": false,
280 | "pycharm": {
281 | "name": "#%% md\n"
282 | }
283 | }
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": 16,
288 | "outputs": [
289 | {
290 | "name": "stdout",
291 | "text": [
292 | " (Down)\n",
293 | "SFFF\n",
294 | "FHFH\n",
295 | "FFFH\n",
296 | "HFF\u001b[41mG\u001b[0m\n",
297 | "****You reached the goal!****\n"
298 | ],
299 | "output_type": "stream"
300 | }
301 | ],
302 | "source": [
303 | "for episode in range(3):\n",
304 | " state = env.reset()\n",
305 | " done = False\n",
306 | " print(\"*****EPISODE \", episode+1, \"*****\\n\\n\\n\\n\")\n",
307 | " time.sleep(1)\n",
308 | " \n",
309 | " for step in range(max_steps_per_episode):\n",
310 | " clear_output(wait=True)\n",
311 | " env.render()\n",
312 | " time.sleep(0.3)\n",
313 | " \n",
314 | " action = np.argmax(q_table[state,:]) \n",
315 | " new_state, reward, done, info = env.step(action)\n",
316 | " \n",
317 | " if done:\n",
318 | " clear_output(wait=True)\n",
319 | " env.render()\n",
320 | " if reward == 1:\n",
321 | " print(\"****You reached the goal!****\")\n",
322 | " time.sleep(3)\n",
323 | " else:\n",
324 | " print(\"****You fell through a hole!****\")\n",
325 | " time.sleep(3)\n",
326 | " clear_output(wait=True)\n",
327 | " break\n",
328 | " state = new_state\n",
329 | "\n",
330 | "env.close()"
331 | ],
332 | "metadata": {
333 | "collapsed": false,
334 | "pycharm": {
335 | "name": "#%%\n",
336 | "is_executing": false
337 | }
338 | }
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": 8,
343 | "outputs": [],
344 | "source": [
345 | "\n"
346 | ],
347 | "metadata": {
348 | "collapsed": false,
349 | "pycharm": {
350 | "name": "#%%\n",
351 | "is_executing": false
352 | }
353 | }
354 | }
355 | ],
356 | "metadata": {
357 | "kernelspec": {
358 | "display_name": "Python 3",
359 | "language": "python",
360 | "name": "python3"
361 | },
362 | "language_info": {
363 | "codemirror_mode": {
364 | "name": "ipython",
365 | "version": 2
366 | },
367 | "file_extension": ".py",
368 | "mimetype": "text/x-python",
369 | "name": "python",
370 | "nbconvert_exporter": "python",
371 | "pygments_lexer": "ipython2",
372 | "version": "2.7.6"
373 | },
374 | "pycharm": {
375 | "stem_cell": {
376 | "cell_type": "raw",
377 | "source": [],
378 | "metadata": {
379 | "collapsed": false
380 | }
381 | }
382 | }
383 | },
384 | "nbformat": 4,
385 | "nbformat_minor": 0
386 | }
--------------------------------------------------------------------------------
/Breakout.py:
--------------------------------------------------------------------------------
1 |
2 | import datetime
3 | import torch.optim as optim
4 | import time
5 | # customized import
6 | from DQNs import *
7 | from utils import *
8 | from EnvManagers import AtariEnvManager
9 | from Agent import *
10 |
11 |
12 | param_json_fname = "DDQN_params.json"
13 | config_dict, hyperparams_dict, eval_dict = read_json(param_json_fname)
14 | if config_dict["IS_USED_TENSORBOARD"]:
15 | from torch.utils.tensorboard import SummaryWriter
16 |
17 | ### core classes
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 | print("using:",device)
20 | em = AtariEnvManager(device, config_dict["GAME_ENV"], config_dict["IS_USE_ADDITIONAL_ENDING_CRITERION"])
21 | em.print_action_meanings()
22 | strategy = EpsilonGreedyStrategyLinear(hyperparams_dict["eps_start"], hyperparams_dict["eps_end"], hyperparams_dict["eps_final"],
23 | hyperparams_dict["eps_startpoint"], hyperparams_dict["eps_kneepoint"],hyperparams_dict["eps_final_knee"])
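    | # EpsilonGreedyStrategyLinear keeps epsilon at 1.0 for the first eps_startpoint agent steps,
    | # then decays it linearly to eps_end over the next eps_kneepoint steps, with an optional
    | # final segment toward eps_final (see utils.py).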
24 | agent = Agent(strategy, em.num_actions_available(), device)
25 | if config_dict["IS_USED_PER"]:
26 | memory = ReplayMemory_economy_PER(hyperparams_dict["memory_size"])
27 | else:
28 | memory = ReplayMemory_economy(hyperparams_dict["memory_size"])
29 |
30 | # available models: DQN_CNN_2013, DQN_CNN_2015, Dueling_DQN_2016_Modified
31 | if config_dict["MODEL_NAME"] == "DQN_CNN_2015":
32 | policy_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=True).to(device)
33 | target_net = DQN_CNN_2015(num_classes=em.num_actions_available(),init_weights=True).to(device)
34 | elif config_dict["MODEL_NAME"] == "Dueling_DQN_2016_Modified":
35 | policy_net = Dueling_DQN_2016_Modified(num_classes=em.num_actions_available(), init_weights=True).to(device)
36 | target_net = Dueling_DQN_2016_Modified(num_classes=em.num_actions_available(), init_weights=True).to(device)
37 | else:
38 |     raise ValueError("No such model: {}. Please check MODEL_NAME in the .json config file.".format(config_dict["MODEL_NAME"]))
39 |
40 | target_net.load_state_dict(policy_net.state_dict())
41 | target_net.eval() # this network will only be used for inference.
42 | optimizer = optim.Adam(params=policy_net.parameters(), lr=hyperparams_dict["lr"])
43 | criterion = torch.nn.SmoothL1Loss()
44 | # can use tensorboard to track the reward
45 | if config_dict["IS_USED_TENSORBOARD"]:
46 | PATH_to_log_dir = config_dict["TENSORBOARD_PATH"] + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"])
47 | writer = SummaryWriter(PATH_to_log_dir)
48 |
49 | # print("num_actions_available: ",em.num_actions_available())
50 | # print("action_meanings:" ,em.env.get_action_meanings())
51 |
52 | # Auxiliary variables
53 | heldout_saver = HeldoutSaver(config_dict["HELDOUT_SET_DIR"],
54 | config_dict["HELDOUT_SET_MAX_PER_BATCH"],
55 | config_dict["HELDOUT_SAVE_RATE"])
56 | tracker_dict = init_tracker_dict()
57 |
58 | plt.figure()
59 | # for estimating the time
60 | t1,t2 = time.time(),time.time()
61 | num_target_update = 0
62 |
63 | for episode in range(hyperparams_dict["num_episodes"]):
64 | em.reset()
65 |     state = em.get_state()  # initialize state
66 | tol_reward = 0
67 | while(1):
68 | # Visualization of game process and state
69 | if config_dict["IS_RENDER_GAME_PROCESS"]: em.env.render() # BZX: will this slow down the speed?
70 | if config_dict["IS_VISUALIZE_STATE"]: visualize_state(state)
71 | if config_dict["IS_GENERATE_HELDOUT"]: heldout_saver.append(state) # generate heldout set for offline eval
72 |
73 | # Given s, select a by either policy_net or random
74 | action = agent.select_action(state, policy_net)
75 | # print(action)
76 | # collect unclipped reward from env along the action
77 | reward = em.take_action(action)
78 | tol_reward += reward
79 | tracker_dict["actions_counter"] += 1
80 |         # after taking action a, observe the next state s'
81 | next_state = em.get_state()
82 | # push (s,a,s',r) into memory
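    |         # Only the newest frame is stored; the economy replay memory rebuilds the stacked
    |         # state (and its successor) from consecutive frames at sample time, so the
    |         # next_state field is left as a placeholder. Rewards are clipped to their sign (-1/0/+1).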
83 | memory.push(Experience(state[0,-1,:,:].clone(), action, "", torch.sign(reward))) #clip reward!!!
84 | # update current state
85 | state = next_state
86 |
87 |         # Once the replay memory holds enough samples, update policy_net every action_repeat agent steps.
88 | if (agent.current_step % hyperparams_dict["action_repeat"] == 0) and \
89 | memory.can_provide_sample(hyperparams_dict["batch_size"], hyperparams_dict["replay_start_size"]):
90 |
91 | if config_dict["IS_USED_PER"]:
92 | experiences, experiences_index, weights = memory.sample(hyperparams_dict["batch_size"])
93 | else:
94 | experiences = memory.sample(hyperparams_dict["batch_size"])
95 | states, actions, rewards, next_states = extract_tensors(experiences)
96 | current_q_values = QValues.get_current(policy_net, states, actions) # checked
97 | # next_q_values = QValues.DQN_get_next(target_net, next_states) # for DQN
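    |             # Double DQN target: actions are chosen greedily by policy_net but evaluated by
    |             # target_net, i.e. target = r + gamma * Q_target(s', argmax_a Q_policy(s', a)),
    |             # which reduces the Q-value overestimation of vanilla DQN.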
98 | next_q_values = QValues.DDQN_get_next(policy_net,target_net, next_states)
99 | target_q_values = (next_q_values * hyperparams_dict["gamma"]) + rewards
100 | # calculate loss and update policy_net
101 | optimizer.zero_grad()
102 | if config_dict["IS_USED_PER"]:
103 | # compute TD error
104 | TD_errors = torch.abs(current_q_values - target_q_values.unsqueeze(1)).detach().cpu().numpy()
105 | # update priorities
106 | memory.update_priority(experiences_index, TD_errors.squeeze(1))
107 | # compute loss
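    |                 # scale each squared TD error by its importance-sampling weight to correct
    |                 # for the bias introduced by prioritized (non-uniform) sampling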
108 | loss = torch.mean(weights.detach() * (current_q_values - target_q_values.unsqueeze(1))**2)
109 | else:
110 | loss = criterion(current_q_values, target_q_values.unsqueeze(1))
111 | loss.backward()
112 | optimizer.step()
113 |
114 | tracker_dict["loss_hist"].append(loss.item())
115 | tracker_dict["minibatch_updates_counter"] += 1
116 |
117 | # update target_net
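    |             # copy the online weights into the frozen target network every target_update
    |             # minibatch updates; keeping the bootstrap targets fixed in between stabilizes training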
118 | if tracker_dict["minibatch_updates_counter"] % hyperparams_dict["target_update"] == 0:
119 | target_net.load_state_dict(policy_net.state_dict())
120 |
121 | # estimate time
122 | num_target_update += 1
123 | if num_target_update % 2 == 0: t1 = time.time()
124 | if num_target_update % 2 == 1: t2 = time.time()
125 | print("=" * 50)
126 | remaining_update_times = (config_dict["MAX_ITERATION"] - tracker_dict["minibatch_updates_counter"])// \
127 | hyperparams_dict["target_update"]
128 | time_sec = abs(t1-t2) * remaining_update_times
129 | print("estimated remaining time = {}h-{}min".format(time_sec//3600,(time_sec%3600)//60))
130 | print("len of replay memory:", len(memory.memory))
131 | print("minibatch_updates_counter = ", tracker_dict["minibatch_updates_counter"])
132 | print("current_step of agent = ", agent.current_step)
133 | print("exploration rate = ", strategy.get_exploration_rate(agent.current_step))
134 | print("=" * 50)
135 |
136 | # save checkpoint model
137 | if tracker_dict["minibatch_updates_counter"] % config_dict["UPDATE_PER_CHECKPOINT"] == 0:
138 | save_model(policy_net, tracker_dict, config_dict)
139 | if not os.path.exists(config_dict["FIGURES_PATH"]):
140 | os.makedirs(config_dict["FIGURES_PATH"])
141 | plt.savefig(config_dict["FIGURES_PATH"] + "Iterations_{}-Time_".format(tracker_dict["minibatch_updates_counter"]) + datetime.datetime.now().strftime(
142 | config_dict["DATE_FORMAT"]) + ".jpg")
143 |
144 | if em.done:
145 | tracker_dict["rewards_hist"].append(tol_reward)
146 | tracker_dict["rewards_hist_update_axis"].append(tracker_dict["minibatch_updates_counter"])
147 | tracker_dict["running_reward"] = plot(tracker_dict["rewards_hist"], 100)
148 | # use tensorboard to track the reward
149 | if config_dict["IS_USED_TENSORBOARD"]:
150 | moving_avg_period = 100
151 | tracker_dict["moving_avg"] = get_moving_average(moving_avg_period, tracker_dict["rewards_hist"])
152 | writer.add_scalars('reward', {'reward': tracker_dict["rewards_hist"][-1],
153 | 'reward_average': tracker_dict["moving_avg"][-1]}, episode)
154 | break
155 |
156 | if config_dict["IS_BREAK_BY_MAX_ITERATION"] and \
157 | tracker_dict["minibatch_updates_counter"] > config_dict["MAX_ITERATION"]:
158 | break
159 |
160 | em.close()
161 | if config_dict["IS_USED_TENSORBOARD"]:
162 | writer.close()
163 | # save loss figure
164 | plt.figure()
165 | plt.plot(tracker_dict["loss_hist"])
166 | plt.title("loss")
167 | plt.xlabel("iterations")
168 | plt.savefig(config_dict["FIGURES_PATH"] + "Loss-Iterations_{}-Time_".format(tracker_dict["minibatch_updates_counter"]) + datetime.datetime.now().strftime(
169 | config_dict["DATE_FORMAT"]) + ".jpg")
170 |
171 | # save tracker_dict["eval_model_list_txt"] to txt file
172 | if not os.path.exists(eval_dict["EVAL_MODEL_LIST_TXT_PATH"]):
173 | os.makedirs(eval_dict["EVAL_MODEL_LIST_TXT_PATH"])
174 | txt_fname = "ModelName_{}-GameName_{}-Time_".format(config_dict["MODEL_NAME"],config_dict["GAME_NAME"]) + datetime.datetime.now().strftime(
175 | config_dict["DATE_FORMAT"]) + ".txt"
176 | with open( eval_dict["EVAL_MODEL_LIST_TXT_PATH"] + txt_fname,'w') as f:
177 | f.write('\n'.join(tracker_dict["eval_model_list_txt"]))
178 |
179 | # pickle tracker_dict for report figures
180 | print("="*100)
181 | print("saving results...")
182 | print("=" * 100)
183 | if not os.path.exists(config_dict["RESULT_PATH"]):
184 | os.makedirs(config_dict["RESULT_PATH"])
185 | tracker_fname = "ModelName_{}-GameName_{}-Time_".format(config_dict["MODEL_NAME"],config_dict["GAME_NAME"]) + datetime.datetime.now().strftime(
186 | config_dict["DATE_FORMAT"]) + ".pkl"
187 | with open(config_dict["RESULT_PATH"] + tracker_fname,'wb') as f:
188 | pickle.dump(tracker_dict, f)
189 |
190 |
191 | if config_dict["IS_SAVE_MIDDLE_POINT"]:
192 | print("="*100)
193 |     print("saving middle point...")
194 | print("=" * 100)
195 | # save core instances
196 | if not os.path.exists(config_dict["MIDDLE_POINT_PATH"]):
197 | os.makedirs(config_dict["MIDDLE_POINT_PATH"])
198 |
199 | mdMemFileName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Memory_" + datetime.datetime.now().strftime(
200 | config_dict["DATE_FORMAT"]) + ".pkl"
201 | middle_mem_file = open(mdMemFileName, 'wb')
202 | pickle.dump(memory, middle_mem_file)
203 | middle_mem_file.close()
204 | del memory
205 | # del memory # make more memory space
206 |
207 |     middle_point = {}
208 |     middle_point["agent"] = agent
209 |     middle_point["tracker_dict"] = tracker_dict
210 |     middle_point["heldout_batch_counter"] = heldout_saver.batch_counter
211 |     middle_point["strategy"] = strategy
212 |     mdStateFileName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_State_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pkl"
213 | 
214 |     middle_point_file = open(mdStateFileName, 'wb')
215 |     pickle.dump(middle_point, middle_point_file)
216 | middle_point_file.close()
217 |
218 | # save policy_net and target_net
219 | md_Policy_Net_fName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Policy_Net_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pth"
220 | torch.save(policy_net.state_dict(),md_Policy_Net_fName)
221 | md_Target_Net_fName = config_dict["MIDDLE_POINT_PATH"] + "MiddlePoint_Target_Net_" + datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pth"
222 |     torch.save(target_net.state_dict(), md_Target_Net_fName)
223 |
224 | # save middle point files' path for continuous training
225 | md_path_dict = {}
226 | md_path_dict["mdMemFileName"] = mdMemFileName
227 | md_path_dict["mdStateFileName"] = mdStateFileName
228 | md_path_dict["md_Policy_Net_fName"] = md_Policy_Net_fName
229 | md_path_dict["md_Target_Net_fName"] = md_Target_Net_fName
230 | with open('tmp_middle_point_file_path.json', 'w') as fp:
231 | json.dump(md_path_dict, fp)
232 |
233 |
234 |
235 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import torch
4 | import matplotlib
5 | import matplotlib.pyplot as plt
6 | from collections import namedtuple
7 | import numpy as np
8 | import pickle
9 | import json
10 | import os
11 | import datetime
12 | import imageio
13 | from skimage.transform import resize as skimage_resize
14 | from sumtree import *
15 |
16 | is_ipython = 'inline' in matplotlib.get_backend()
17 | if is_ipython:
18 | from IPython import display
19 | else:
20 | matplotlib.use('TkAgg')
21 | # matplotlib.use('Agg')
22 |
23 | Experience = namedtuple(
24 | 'Experience',
25 | ('state', 'action', 'next_state', 'reward')
26 | )
27 |
28 | Eco_Experience = namedtuple(
29 | 'Eco_Experience',
30 | ('state', 'action', 'reward')
31 | )
32 |
33 |
34 |
35 |
36 |
37 | class ReplayMemory():
38 |     # simple replay memory with uniform random sampling
39 | def __init__(self, capacity):
40 | self.capacity = capacity
41 | self.memory = []
42 | self.push_count = 0
43 | # self.dtype = torch.uint8
44 |
45 | def push(self, experience):
46 | if len(self.memory) < self.capacity:
47 | self.memory.append(experience)
48 | else:
49 | self.memory[self.push_count % self.capacity] = experience
50 | self.push_count += 1
51 |
52 | def sample(self, batch_size):
53 | return random.sample(self.memory, batch_size)
54 |
55 | def can_provide_sample(self, batch_size):
56 | return len(self.memory) >= batch_size
57 |
58 | class ReplayMemory_economy():
59 |     # store a single frame per experience to reduce memory usage
60 | def __init__(self, capacity):
61 | self.capacity = capacity
62 | self.memory = []
63 | self.push_count = 0
64 | self.dtype = torch.uint8
65 |
66 | def push(self, experience):
67 | state = (experience.state * 255).type(self.dtype).cpu()
68 | # next_state = (experience.next_state * 255).type(self.dtype)
69 | new_experience = Eco_Experience(state,experience.action,experience.reward)
70 |
71 | if len(self.memory) < self.capacity:
72 | self.memory.append(new_experience)
73 | else:
74 | self.memory[self.push_count % self.capacity] = new_experience
75 | # print(id(experience))
76 | # print(id(self.memory[0]))
77 | self.push_count += 1
78 |
79 | def sample(self, batch_size):
80 | # randomly sample experiences
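    |         # indices start at 3 so that a full 4-frame state (frames index-3..index) and its
    |         # successor (frames index-2..index+1) can be rebuilt from consecutively stored frames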
81 | experience_index = np.random.randint(3, len(self.memory)-1, size = batch_size)
82 | # memory_arr = np.array(self.memory)
83 | experiences = []
84 | for index in experience_index:
85 | if self.push_count > self.capacity:
86 | state = torch.stack(([self.memory[index+j].state for j in range(-3,1)])).unsqueeze(0)
87 | next_state = torch.stack(([self.memory[index+1+j].state for j in range(-3,1)])).unsqueeze(0)
88 | else:
89 |                 state = torch.stack(([self.memory[max(index+j, 0)].state for j in range(-3,1)])).unsqueeze(0)
90 |                 next_state = torch.stack(([self.memory[max(index+1+j, 0)].state for j in range(-3,1)])).unsqueeze(0)
91 | experiences.append(Experience(state.float().cuda()/255, self.memory[index].action, next_state.float().cuda()/255, self.memory[index].reward))
92 | # return random.sample(self.memory, batch_size)
93 | return experiences
94 |
95 | def can_provide_sample(self, batch_size, replay_start_size):
96 | return (len(self.memory) >= replay_start_size) and (len(self.memory) >= batch_size + 3)
97 |
98 | class ReplayMemory_economy_PER():
99 |     # Replay memory with prioritized experience replay (PER)
100 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101 | def __init__(self, capacity, alpha=0.6, beta_start=0.4, beta_startpoint=50000, beta_kneepoint = 1000000, error_epsilon=1e-5):
102 | self.capacity = capacity
103 | self.memory = []
104 | self.priority_tree = Sumtree(self.capacity) # store priorities
105 | self.alpha = alpha
106 | self.beta = beta_start
107 | self.beta_increase = 1/(beta_kneepoint - beta_startpoint)
108 | self.error_epsilon = error_epsilon
109 | self.push_count = 0
110 | self.dtype = torch.uint8
111 |
112 | def push(self, experience):
113 | state = (experience.state * 255).type(self.dtype).cpu()
114 | # next_state = (experience.next_state * 255).type(self.dtype)
115 | new_experience = Eco_Experience(state,experience.action,experience.reward)
116 |
117 | if len(self.memory) < self.capacity:
118 | self.memory.append(new_experience)
119 | else:
120 | self.memory[self.push_count % self.capacity] = new_experience
121 | self.push_count += 1
122 | # push new state to priority tree
123 | self.priority_tree.push()
124 |
125 | def sample(self, batch_size):
126 | # get indices of experience by priorities
127 | experience_index = []
128 | experiences = []
129 | priorities = []
130 | segment = self.priority_tree.get_p_total()/batch_size
131 | self.beta = np.min([1., self.beta + self.beta_increase])
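    |         # stratified sampling: the total priority mass is divided into batch_size equal
    |         # segments and one experience is drawn from each, with probability proportional
    |         # to its priority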
132 | for i in range(batch_size):
133 | low = segment * i
134 | high = segment * (i+1)
135 | s = random.uniform(low, high)
136 | index, p = self.priority_tree.sample(s)
137 | experience_index.append(index)
138 | priorities.append(p)
139 | # get experience from index
140 | if self.push_count > self.capacity:
141 | state = torch.stack(([self.memory[index+j].state for j in range(-3,1)])).unsqueeze(0)
142 | next_state = torch.stack(([self.memory[index+1+j].state for j in range(-3,1)])).unsqueeze(0)
143 | else:
144 |                 state = torch.stack(([self.memory[max(index+j, 0)].state for j in range(-3,1)])).unsqueeze(0)
145 |                 next_state = torch.stack(([self.memory[max(index+1+j, 0)].state for j in range(-3,1)])).unsqueeze(0)
146 | experiences.append(Experience(state.float().cuda()/255, self.memory[index].action, next_state.float().cuda()/255, self.memory[index].reward))
147 | # compute weight
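    |         # importance-sampling weights: w_i = (N * P(i))^(-beta), normalized by the largest
    |         # weight so that all weights are <= 1; beta is annealed toward 1 during training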
148 | possibilities = priorities / self.priority_tree.get_p_total()
149 | min_possibility = self.priority_tree.get_p_min()
150 | weight = np.power(self.priority_tree.length * possibilities, -self.beta)
151 | max_weight = np.power(self.priority_tree.length * min_possibility, -self.beta)
152 | weight = weight/max_weight
153 | weight = torch.tensor(weight[:,np.newaxis], dtype = torch.float).to(ReplayMemory_economy_PER.device)
154 | return experiences, experience_index, weight
155 |
156 | def update_priority(self, index_list, TD_error_list):
157 | # update priorities from TD error
158 | # priorities_list = np.abs(TD_error_list) + self.error_epsilon
159 | priorities_list = (np.abs(TD_error_list) + self.error_epsilon) ** self.alpha
160 | for index, priority in zip(index_list, priorities_list):
161 | self.priority_tree.update(index, priority)
162 |
163 | def can_provide_sample(self, batch_size, replay_start_size):
164 | return (len(self.memory) >= replay_start_size) and (len(self.memory) >= batch_size + 3)
165 |
166 | class EpsilonGreedyStrategyExp():
167 |     # compute epsilon for the epsilon-greedy policy with exponential decay
168 | def __init__(self, start, end, decay):
169 | self.start = start
170 | self.end = end
171 | self.decay = decay
172 |
173 | def get_exploration_rate(self, current_step):
174 | return self.end + (self.start - self.end) * \
175 | math.exp(-1. * current_step * self.decay)
176 |
177 | class EpsilonGreedyStrategyLinear():
178 | def __init__(self, start, end, final_eps = None, startpoint = 50000, kneepoint=1000000, final_knee_point = None):
179 |         # compute epsilon for the epsilon-greedy policy with piecewise-linear decay
180 | self.start = start
181 | self.end = end
182 | self.final_eps = final_eps
183 | self.kneepoint = kneepoint
184 | self.startpoint = startpoint
185 | self.final_knee_point = final_knee_point
186 |
187 | def get_exploration_rate(self, current_step):
188 | if current_step < self.startpoint:
189 | return 1.
190 | mid_seg = self.end + \
191 | np.maximum(0, (1-self.end)-(1-self.end)/self.kneepoint * (current_step-self.startpoint))
192 | if not self.final_eps:
193 | return mid_seg
194 | else:
195 |             if self.final_eps and self.final_knee_point and (current_step > self.kneepoint):
196 |                 # final linear segment: keep annealing from self.end toward self.final_eps,
197 |                 # reaching it at self.final_knee_point
198 |                 return np.maximum(self.final_eps,
199 |                                   self.end - (self.end - self.final_eps) *
200 |                                   (current_step - self.kneepoint) / (self.final_knee_point - self.kneepoint))
201 |             return mid_seg
    | 
281 | def get_moving_average(period, values):
282 |     values = torch.tensor(values, dtype=torch.float)
283 |     if len(values) >= period:
284 | moving_avg = values.unfold(dimension=0, size=period, step=1) \
285 | .mean(dim=1).flatten(start_dim=0)
286 | moving_avg = torch.cat((torch.zeros(period-1), moving_avg))
287 | return moving_avg.numpy()
288 | else:
289 | moving_avg = torch.zeros(len(values))
290 | return moving_avg.numpy()
291 |
292 | def plot(values, moving_avg_period):
293 | """
294 | test: plot(np.random.rand(300), 100)
295 | :param values: numpy 1D vector
296 | :param moving_avg_period:
297 | :return: None
298 | """
299 | # plt.figure()
300 | plt.clf()
301 | plt.title('Training...')
302 | plt.xlabel('Episode')
303 | plt.ylabel('Reward')
304 | plt.plot(values)
305 | moving_avg = get_moving_average(moving_avg_period, values)
306 | plt.plot(moving_avg)
307 | print("Episode", len(values), "\n",moving_avg_period, "episode moving avg:", moving_avg[-1])
308 | plt.pause(0.0001)
309 | # if is_ipython: display.clear_output(wait=True)
310 | return moving_avg[-1]
311 |
312 | def extract_tensors(experiences):
313 | # Convert batch of Experiences to Experience of batches
314 | batch = Experience(*zip(*experiences))
315 |
316 | t1 = torch.cat(batch.state)
317 | t2 = torch.cat(batch.action)
318 | t3 = torch.cat(batch.reward)
319 | t4 = torch.cat(batch.next_state)
320 |
321 | return (t1,t2,t3,t4)
322 |
323 | def visualize_state(state):
324 | # settings
325 | nrows, ncols = 1, 4 # array of sub-plots
326 | figsize = [8, 4] # figure size, inches
327 |
328 | # prep (x,y) for extra plotting on selected sub-plots
329 | # xs = np.linspace(0, 2 * np.pi, 60) # from 0 to 2pi
330 | # ys = np.abs(np.sin(xs)) # absolute of sine
331 |
332 | # create figure (fig), and array of axes (ax)
333 | fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
334 |
335 | # plot simple raster image on each sub-plot
336 | for i, axi in enumerate(ax.flat):
337 | # i runs from 0 to (nrows*ncols-1)
338 | # axi is equivalent with ax[rowid][colid]
339 | img = state.squeeze(0)[i,None]
340 | cpu_img = img.squeeze(0).cpu()
341 | axi.imshow(cpu_img*255,cmap='gray', vmin=0, vmax=255)
342 |
343 | # get indices of row/column
344 | rowid = i // ncols
345 | colid = i % ncols
346 | # write row/col indices as axes' title for identification
347 | # axi.set_title("Row:" + str(rowid) + ", Col:" + str(colid))
348 |
349 | # one can access the axes by ax[row_id][col_id]
350 | # do additional plotting on ax[row_id][col_id] of your choice
351 | # ax[0][2].plot(xs, 3 * ys, color='red', linewidth=3)
352 | # ax[4][3].plot(ys ** 2, xs, color='green', linewidth=3)
353 |
354 | plt.tight_layout(True)
355 | plt.show()
356 |
357 | def init_tracker_dict():
358 |     """init auxiliary variables"""
359 | tracker = {}
360 | tracker["minibatch_updates_counter"] = 1
361 | tracker["actions_counter"] = 1
362 | tracker["running_reward"] = 0
363 | tracker["rewards_hist"] = []
364 | tracker["loss_hist"] = []
365 | tracker["eval_model_list_txt"] = []
366 | tracker["rewards_hist_update_axis"] = []
367 | # only used in evaluation script
368 | tracker["eval_reward_list"] = []
369 | tracker["best_frame_for_gif"] = []
370 | tracker["best_reward"] = 0
371 | return tracker
372 |
373 | def save_model(policy_net, tracker_dict, config_dict):
374 | path = config_dict["CHECK_POINT_PATH"] + config_dict["GAME_NAME"] + "/"
375 | if not os.path.exists(path):
376 | os.makedirs(path)
377 | fname = "Iterations_{}-Reward{:.2f}-Time_".format(tracker_dict["minibatch_updates_counter"],
378 | tracker_dict["running_reward"]) + \
379 | datetime.datetime.now().strftime(config_dict["DATE_FORMAT"]) + ".pth"
380 | torch.save(policy_net.state_dict(), path + fname)
381 | tracker_dict["eval_model_list_txt"].append(path + fname)
382 |
383 | def read_json(param_json_fname):
384 | with open(param_json_fname) as fp:
385 | params_dict = json.load(fp)
386 |
387 | config_dict = params_dict["config"]
388 | hyperparams_dict = params_dict["hyperparams"]
389 | eval_dict = params_dict["eval"]
390 | return config_dict, hyperparams_dict, eval_dict
391 |
392 |
393 | def load_Middle_Point(md_json_file_path):
394 | with open(md_json_file_path) as fp:
395 | md_path_dict = json.load(fp)
396 | return md_path_dict
397 |
398 | def generate_gif(gif_save_path,model_name, frames_for_gif, reward):
399 | """
400 | Args:
401 | frames_for_gif: A sequence of (210, 160, 3) frames of an Atari game in RGB
402 | reward: Integer, Total reward of the episode that es ouputted as a gif
403 | """
404 | if not os.path.exists(gif_save_path):
405 | os.makedirs(gif_save_path)
406 | for idx, frame_idx in enumerate(frames_for_gif):
407 | frames_for_gif[idx] = skimage_resize(frame_idx, (420, 320, 3),
408 | preserve_range=True, order=0).astype(np.uint8)
409 | fname = gif_save_path + model_name + "-EvalReward_{}.gif".format(reward)
410 | imageio.mimsave(fname, frames_for_gif, duration=1 / 30)
411 |
--------------------------------------------------------------------------------