├── pytorch ├── __init__.py ├── QNetwork.py ├── ReplayBuffer.py ├── Mario.py ├── Federator.py ├── Env.py ├── DDQN.py ├── DQN.py └── Agent.py ├── results ├── Mario │ ├── 1.gif │ ├── 3.gif │ ├── rewards.png │ └── rewards │ │ ├── rewards0.pkl │ │ ├── rewards1.pkl │ │ ├── rewards2.pkl │ │ └── rewards3.pkl ├── CartPole │ ├── 1 │ │ ├── CartPole.png │ │ ├── fed_rewards.npy │ │ └── single_rewards.npy │ └── 2 │ │ ├── Figure_2.png │ │ ├── fed_rewards.npy │ │ └── single_rewards.npy └── LunarLander │ ├── fed_rewards.npy │ ├── lunarlander.png │ └── single_rewards.npy ├── single-agent-cart.py ├── single-agent-lun.py ├── README.md ├── main-lun.py ├── main-cart.py ├── .gitignore └── Mario.ipynb /pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/Mario/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/Mario/1.gif -------------------------------------------------------------------------------- /results/Mario/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/Mario/3.gif -------------------------------------------------------------------------------- /results/Mario/rewards.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/Mario/rewards.png -------------------------------------------------------------------------------- /results/CartPole/1/CartPole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/CartPole/1/CartPole.png -------------------------------------------------------------------------------- /results/CartPole/2/Figure_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/CartPole/2/Figure_2.png -------------------------------------------------------------------------------- /results/CartPole/1/fed_rewards.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/CartPole/1/fed_rewards.npy -------------------------------------------------------------------------------- /results/CartPole/2/fed_rewards.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/CartPole/2/fed_rewards.npy -------------------------------------------------------------------------------- /results/LunarLander/fed_rewards.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/LunarLander/fed_rewards.npy -------------------------------------------------------------------------------- /results/LunarLander/lunarlander.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/LunarLander/lunarlander.png -------------------------------------------------------------------------------- /results/Mario/rewards/rewards0.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/Mario/rewards/rewards0.pkl -------------------------------------------------------------------------------- /results/Mario/rewards/rewards1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/Mario/rewards/rewards1.pkl -------------------------------------------------------------------------------- /results/Mario/rewards/rewards2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/Mario/rewards/rewards2.pkl -------------------------------------------------------------------------------- /results/Mario/rewards/rewards3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/Mario/rewards/rewards3.pkl -------------------------------------------------------------------------------- /results/CartPole/1/single_rewards.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/CartPole/1/single_rewards.npy -------------------------------------------------------------------------------- /results/CartPole/2/single_rewards.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/CartPole/2/single_rewards.npy -------------------------------------------------------------------------------- /results/LunarLander/single_rewards.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TroddenSpade/Federated-DRL/HEAD/results/LunarLander/single_rewards.npy -------------------------------------------------------------------------------- /single-agent-cart.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import matplotlib.pyplot as plt 4 | 5 | from pytorch.DQN import Agent 6 | from pytorch.QNetwork import FCQ 7 | from pytorch.ReplayBuffer import ReplayBuffer 8 | 9 | if __name__ == "__main__": 10 | args = { 11 | "env_fn": lambda : gym.make("CartPole-v1"), 12 | "Qnet": FCQ, 13 | "buffer": ReplayBuffer, 14 | 15 | "net_args": { 16 | "hidden_layers":(64,64), 17 | "activation_fn":torch.nn.functional.relu, 18 | "optimizer":torch.optim.Adam, 19 | "learning_rate":0.0005, 20 | }, 21 | 22 | "max_epsilon": 1.0, 23 | "min_epsilon": 0.1, 24 | "decay_steps": 5000, 25 | "gamma": 0.99, 26 | "target_update_rate": 15, 27 | "min_buffer": 64 28 | } 29 | 30 | agent = Agent(**args) 31 | agent.train(300) 32 | print(agent.episode_count) 33 | 34 | plt.plot(agent.rewards) 35 | # plt.plot(evals) 36 | plt.show() -------------------------------------------------------------------------------- /single-agent-lun.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | from pytorch.DQN import Agent 7 | from pytorch.QNetwork import FCQ 8 | from pytorch.ReplayBuffer import ReplayBuffer 9 | 10 | if __name__ == "__main__": 11 | args = { 12 | "env_fn": lambda : gym.make("LunarLander-v2"), 13 | "Qnet": FCQ, 14 | "buffer": 
ReplayBuffer, 15 | 16 | "net_args": { 17 | "hidden_layers":(512, 256, 128), 18 | "activation_fn":torch.nn.functional.relu, 19 | "optimizer":torch.optim.Adam, 20 | "learning_rate":0.0005, 21 | }, 22 | 23 | "max_epsilon": 1.0, 24 | "min_epsilon": 0.1, 25 | "decay_steps": 5000, 26 | "gamma": 0.99, 27 | "target_update_rate": 15, 28 | "min_buffer": 64 29 | } 30 | 31 | rewards = np.zeros(200) 32 | for i in range(10): 33 | agent = Agent(**args) 34 | agent.train(200) 35 | print(agent.step_count) 36 | rewards += agent.rewards 37 | 38 | plt.plot(rewards/10) 39 | # plt.plot(evals) 40 | plt.show() -------------------------------------------------------------------------------- /pytorch/QNetwork.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class FCQ(torch.nn.Module): 4 | def __init__(self, states_input_size, actions_size, 5 | hidden_layers, 6 | activation_fn=torch.nn.functional.relu, 7 | optimizer=torch.optim.Adam, learning_rate=0.0005) -> None: 8 | super().__init__() 9 | self.activation_fn = activation_fn 10 | 11 | self.hidden_layers = torch.nn.ModuleList() 12 | prev_size = states_input_size 13 | for layer_size in hidden_layers: 14 | self.hidden_layers.append(torch.nn.Linear(prev_size, layer_size)) 15 | prev_size = layer_size 16 | 17 | self.output_layer = torch.nn.Linear(prev_size, actions_size) 18 | 19 | self.optimizer = optimizer(self.parameters(), lr=learning_rate) 20 | 21 | 22 | def format_(self, states): 23 | if not isinstance(states, torch.Tensor): 24 | states = torch.tensor(states, dtype=torch.float32) 25 | return states 26 | 27 | 28 | def forward(self, states): 29 | x = self.format_(states) 30 | for hidden_layer in self.hidden_layers: 31 | x = self.activation_fn(hidden_layer(x)) 32 | return self.output_layer(x) 33 | 34 | 35 | def optimize(self, loss): 36 | self.optimizer.zero_grad() 37 | loss.backward() 38 | self.optimizer.step() 39 | 40 | 41 | @staticmethod 42 | def reset_weights(m): 43 | for layer in m.children(): 44 | if hasattr(layer, 'reset_parameters'): 45 | layer.reset_parameters() 46 | 47 | def reset(self): 48 | self.apply(FCQ.reset_weights) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Parallel Deep Reinforcement Learning with Federated Learning Framework 2 | This project assesses the effect of training multiple Deep Reinforcement Learning agents in parallel using the Federated Averaging (FedAvg) algorithm: after the agents train for a fixed number of timesteps, all of the Deep Q-Network models are aggregated by averaging their parameters, and the averaged model is then broadcast back to every agent for further rounds of training.
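For reference, the aggregation used by `pytorch/Federator.py` boils down to a score-weighted average of the agents' PyTorch `state_dict`s. Below is a minimal, self-contained sketch of that step; the helper name `fedavg_state_dicts` and the uniform `scores` are illustrative rather than part of the repository's API:

```python
import torch

def fedavg_state_dicts(state_dicts, scores):
    """FedAvg: weighted average of several PyTorch state_dicts."""
    total = float(sum(scores))
    avg = {key: torch.zeros_like(value) for key, value in state_dicts[0].items()}
    for sd, score in zip(state_dicts, scores):
        for key in avg:
            avg[key] += (score / total) * sd[key]
    return avg

# Broadcast the averaged weights back to every local agent before the next round:
# avg_sd = fedavg_state_dicts([a.online_net.state_dict() for a in agents],
#                             scores=[1.0] * len(agents))
# for a in agents:
#     a.online_net.load_state_dict(avg_sd)
```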
3 | 4 | ### Environments 5 | * CartPole 6 | * Lunar Lander 7 | * Super Mario Bros 8 | 9 | ### Deep Reinforcement Learning Methods 10 | * Deep Q Network 11 | * Double Deep Q Network 12 | 13 | ## Experiments 14 | ### 3 DQN Agents on Cartpole Environment 15 | ![CP1](https://github.com/TroddenSpade/Federated-DQN/blob/main/results/CartPole/1/CartPole.png?raw=true) 16 | ![CP2](https://github.com/TroddenSpade/Federated-DQN/blob/main/results/CartPole/2/Figure_2.png?raw=true) 17 | 18 | ### 3 DQN Agents on Lunar Lander Environment 19 | ![LL](https://github.com/TroddenSpade/Federated-DQN/blob/main/results/LunarLander/lunarlander.png?raw=true) 20 | 21 | ### 4 DDQN Agents on Super Mario Bros 1-1 to 1-4 22 | ![SMB](https://github.com/TroddenSpade/Federated-DQN/blob/main/results/Mario/rewards.png?raw=true) 23 | 24 | | Env 1-1 | Env 1-2 | 25 | | :---: | :---: | 26 | |![1-2](https://github.com/TroddenSpade/Federated-DQN/blob/main/results/Mario/0.gif?raw=true) | ![1-4](https://github.com/TroddenSpade/Federated-DQN/blob/main/results/Mario/1.gif?raw=true) | 27 | | Env 1-3 | Env 1-4 | 28 | |![1-2](https://github.com/TroddenSpade/Federated-DQN/blob/main/results/Mario/2.gif?raw=true) | ![1-4](https://github.com/TroddenSpade/Federated-DQN/blob/main/results/Mario/3.gif?raw=true) | 29 | -------------------------------------------------------------------------------- /main-lun.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from tqdm import tqdm 6 | 7 | from pytorch.Federator import Federator 8 | from pytorch.DQN import Agent 9 | from pytorch.QNetwork import FCQ 10 | from pytorch.ReplayBuffer import ReplayBuffer 11 | 12 | if __name__ == "__main__": 13 | 14 | args = { 15 | "env_fn": lambda : gym.make("LunarLander-v2"), 16 | "Qnet": FCQ, 17 | "buffer": ReplayBuffer, 18 | 19 | "net_args" : { 20 | "hidden_layers":(512, 256, 128), 21 | "activation_fn":torch.nn.functional.relu, 22 | "optimizer":torch.optim.Adam, 23 | "learning_rate":0.0005, 24 | }, 25 | 26 | "max_epsilon": 1.0, 27 | "min_epsilon": 0.1, 28 | "decay_steps": 2000, 29 | "gamma": 0.99, 30 | "target_update_rate": 100, 31 | "min_buffer": 64 32 | } 33 | 34 | n_runs = 350 35 | n_agents = 3 36 | n_iterations = 5 37 | update_rate = 300 38 | 39 | fed_rewards = np.zeros(n_runs) 40 | for i in range(n_iterations): 41 | fed = Federator(n_agents=n_agents, update_rate=update_rate, args=args) 42 | fed_rewards += fed.train(n_runs) 43 | fed_rewards /= n_iterations 44 | fed.print_episode_lengths() 45 | with open('fed_rewards.npy', 'wb') as f: 46 | np.save(f, fed_rewards) 47 | 48 | single_rewards = np.zeros(n_runs) 49 | for i in range(n_iterations): 50 | ag = Agent(**args) 51 | for r in tqdm(range(n_runs)): 52 | ag.step(update_rate) 53 | single_rewards[r] += ag.evaluate() 54 | single_rewards /= n_iterations 55 | with open('single_rewards.npy', 'wb') as f: 56 | np.save(f, single_rewards) 57 | 58 | plt.plot(fed_rewards, color="b", label="federated") 59 | plt.plot(single_rewards, color="r", label="single") 60 | plt.legend() 61 | plt.show() 62 | -------------------------------------------------------------------------------- /main-cart.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from tqdm import tqdm 6 | 7 | from pytorch.Federator import Federator 8 | from pytorch.DQN import Agent 9 | from pytorch.QNetwork import FCQ 
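# NOTE: as in main-lun.py above, this script compares federated training against a single
# agent: in each of the n_runs rounds every local agent runs `update_rate` environment
# steps, the Federator averages the networks into the global agent, and that global agent
# is evaluated for one greedy episode; both reward curves are averaged over n_iterations
# repetitions and written to fed_rewards.npy / single_rewards.npy.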
10 | from pytorch.ReplayBuffer import ReplayBuffer 11 | 12 | if __name__ == "__main__": 13 | 14 | args = { 15 | "env_fn": lambda : gym.make("CartPole-v1"), 16 | "Qnet": FCQ, 17 | "buffer": ReplayBuffer, 18 | 19 | "net_args" : { 20 | "hidden_layers":(512, 128), 21 | "activation_fn":torch.nn.functional.relu, 22 | "optimizer":torch.optim.Adam, # torch.optim.RMSprop 23 | "learning_rate":0.0005, 24 | }, 25 | 26 | "max_epsilon": 1.0, 27 | "min_epsilon": 0.1, 28 | "decay_steps": 2000, 29 | "gamma": 0.99, 30 | "target_update_rate": 15, 31 | "min_buffer": 64 32 | } 33 | 34 | n_runs = 2000 35 | n_agents = 3 36 | n_iterations = 10 37 | update_rate = 30 38 | 39 | fed_rewards = np.zeros(n_runs) 40 | for i in range(n_iterations): 41 | fed = Federator(n_agents=n_agents, update_rate=update_rate, args=args) 42 | fed_rewards += fed.train(n_runs) 43 | fed_rewards /= n_iterations 44 | fed.print_episode_lengths() 45 | with open('fed_rewards.npy', 'wb') as f: 46 | np.save(f, fed_rewards) 47 | 48 | 49 | single_rewards = np.zeros(n_runs) 50 | for i in range(n_iterations): 51 | ag = Agent(**args) 52 | for r in tqdm(range(n_runs)): 53 | ag.step(update_rate) 54 | single_rewards[r] += ag.evaluate() 55 | single_rewards /= n_iterations 56 | with open('single_rewards.npy', 'wb') as f: 57 | np.save(f, single_rewards) 58 | 59 | plt.plot(fed_rewards, color="b", label="federated") 60 | plt.plot(single_rewards, color="r", label="single") 61 | plt.legend() 62 | plt.show() 63 | -------------------------------------------------------------------------------- /pytorch/ReplayBuffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class ReplayBuffer: 4 | def __init__(self, state_shape, action_space, 5 | batch_size=64, max_size=50000): 6 | self.next = 0 7 | self.size = 0 8 | self.max_size = max_size 9 | self.batch_size = batch_size 10 | 11 | self.states = np.empty(shape=(max_size, *state_shape)) 12 | self.actions = np.empty(shape=(max_size, 1), dtype=np.int64) 13 | self.rewards = np.empty(shape=(max_size)) 14 | self.states_p = np.empty(shape=(max_size, *state_shape)) 15 | self.is_terminals = np.empty(shape=(max_size), dtype=float) 16 | 17 | def __len__(self): return self.size 18 | 19 | def store(self, state, action, reward, state_p, is_terminal): 20 | self.states[self.next] = state 21 | self.actions[self.next] = action 22 | self.rewards[self.next] = reward 23 | self.states_p[self.next] = state_p 24 | self.is_terminals[self.next] = is_terminal 25 | 26 | self.next += 1 27 | self.size = min(self.size + 1, self.max_size) 28 | self.next = self.next % self.max_size 29 | 30 | def sample(self, batch_size=None): 31 | batch_size = self.batch_size \ 32 | if batch_size is None else batch_size 33 | indices = np.random.choice(self.size, size=batch_size, 34 | replace=False) 35 | return self.states[indices], \ 36 | self.actions[indices], \ 37 | self.rewards[indices], \ 38 | self.states_p[indices], \ 39 | self.is_terminals[indices] 40 | 41 | def clear(self): 42 | self.next = 0 43 | self.size = 0 44 | self.states = np.empty_like(self.states) 45 | self.actions = np.empty_like(self.actions) 46 | self.rewards = np.empty_like(self.rewards) 47 | self.states_p = np.empty_like(self.states_p) 48 | self.is_terminals = np.empty_like(self.is_terminals) 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 
| *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /pytorch/Mario.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from Agent import Agent 6 | from QNetwork import QNetwork 7 | 8 | 9 | class Mario(Agent): 10 | def __init__(self, env_names, env_fn, Qnet=QNetwork, load=False, path=None) -> None: 11 | self.path = path + "global/" 12 | self.envs = [] 13 | for name in env_names: 14 | self.envs.append(env_fn(name)) 15 | self.n_actions = self.envs[0].action_space.n 16 | self.state_shape = self.envs[0].observation_space.shape 17 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 18 | 19 | self.online_net = Qnet(self.state_shape, self.n_actions).to(self.device) 20 | self.target_net = Qnet(self.state_shape, self.n_actions).to(self.device) 21 | 22 | if load: 23 | self.load() 24 | else: 25 | self.update_target_network() 26 | 27 | 28 | def load(self): 29 | self.online_net.load_state_dict(torch.load(self.path + "online_net.pt", 30 | map_location=torch.device(self.device))) 31 | self.target_net.load_state_dict(torch.load(self.path + "target_net.pt", 32 | map_location=torch.device(self.device))) 33 | 34 | 35 | def save(self): 36 | os.makedirs(os.path.dirname(self.path), exist_ok=True) 37 | torch.save(self.online_net.state_dict(), self.path + "online_net.pt") 38 | torch.save(self.target_net.state_dict(), self.path + "target_net.pt") 39 | 40 | 41 | def get_score(self): 42 | # return np.mean(self.rewards[-5:]) 43 | return 1 44 | 45 | 46 | def test(self): 47 | rewards = np.zeros(len(self.envs)) 48 | for i in range(len(self.envs)): 49 | r = self.evaluate(i) 50 | rewards[i] = r 51 | return rewards 52 | 53 | 54 | def evaluate(self, i, render=False): 55 | rewards = 0 56 | state = self.envs[i].reset() 57 | while True: 58 | action = self.greedyPolicy(state) 59 | state_p, reward, done, _ = self.envs[i].step(action) 60 | if render: 61 | self.envs[i].render() 62 | rewards += reward 63 | if done: 64 | break 65 | state = state_p 66 | return rewards 67 | 68 | 69 | def greedyPolicy(self, state): 70 | with torch.no_grad(): 71 | state = state.__array__() 72 | state = torch.tensor(state).unsqueeze(0).to(self.device) 73 | action = self.target_net(state).argmax().item() 74 | return action -------------------------------------------------------------------------------- /pytorch/Federator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | from pytorch.DQN import Agent 5 | from pytorch.QNetwork import FCQ 6 | from pytorch.ReplayBuffer import ReplayBuffer 7 | 8 | class Federator: 9 | def __init__(self, n_agents, update_rate, args) -> None: 10 | 11 | self.env = args["env_fn"]() 12 | self.n_actions = self.env.action_space.n 13 | self.state_shape = self.env.observation_space.shape 14 | 15 | self.global_agent = Agent(**args) 
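# The global agent is never trained directly: each round it receives the score-weighted
# average of the local agents' online/target networks (aggregate_networks), broadcasts it
# back to every local agent (set_local_networks), and runs one greedy evaluation episode.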
16 | 17 | self.update_rate = update_rate 18 | self.n_agents = n_agents 19 | self.agents = [] 20 | for _ in range(n_agents): 21 | agent = Agent(**args) 22 | self.agents.append(agent) 23 | 24 | self.set_local_networks() 25 | 26 | 27 | def print_episode_lengths(self): 28 | for a in self.agents: 29 | print(a.episode_count) 30 | 31 | def train(self, n_runs): 32 | rewards = np.zeros(n_runs) 33 | for r in tqdm(range(n_runs)): 34 | scores = [] 35 | for agent in self.agents: 36 | agent.step(self.update_rate) 37 | scores.append(agent.get_score()) 38 | self.aggregate_networks(scores) 39 | self.set_local_networks() 40 | rewards[r] = self.global_agent.evaluate() 41 | return rewards 42 | 43 | 44 | def aggregate_networks(self, scores): 45 | sd_online = self.global_agent.online_net.state_dict() 46 | sd_target = self.global_agent.target_net.state_dict() 47 | 48 | online_dicts = [] 49 | target_dicts = [] 50 | for agent in self.agents: 51 | online_dicts.append(agent.online_net.state_dict()) 52 | target_dicts.append(agent.target_net.state_dict()) 53 | 54 | for key in sd_online: 55 | sd_online[key] -= sd_online[key] 56 | for i, dict in enumerate(online_dicts): 57 | sd_online[key] += scores[i] * dict[key] 58 | sd_online[key] /= sum(scores) 59 | 60 | for key in sd_target: 61 | sd_target[key] -= sd_target[key] 62 | for i, dict in enumerate(target_dicts): 63 | sd_target[key] += scores[i] * dict[key] 64 | sd_target[key] /= sum(scores) 65 | 66 | self.global_agent.online_net.load_state_dict(sd_online) 67 | self.global_agent.target_net.load_state_dict(sd_target) 68 | 69 | 70 | def set_local_networks(self): 71 | for agent in self.agents: 72 | agent.online_net.load_state_dict( 73 | self.global_agent.online_net.state_dict()) 74 | agent.target_net.load_state_dict( 75 | self.global_agent.target_net.state_dict()) -------------------------------------------------------------------------------- /pytorch/Env.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as T 3 | import numpy as np 4 | 5 | import gym 6 | from gym.wrappers import FrameStack 7 | 8 | from nes_py.wrappers import JoypadSpace 9 | import gym_super_mario_bros 10 | from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT 11 | 12 | class SkipFrame(gym.Wrapper): 13 | def __init__(self, env, skip): 14 | """Return only every `skip`-th frame""" 15 | super().__init__(env) 16 | self._skip = skip 17 | 18 | def step(self, action): 19 | """Repeat action, and sum reward""" 20 | total_reward = 0.0 21 | done = False 22 | for i in range(self._skip): 23 | # Accumulate reward and repeat the same action 24 | obs, reward, done, info = self.env.step(action) 25 | total_reward += reward 26 | if done: 27 | break 28 | return obs, total_reward, done, info 29 | 30 | class ResizeObservation(gym.ObservationWrapper): 31 | def __init__(self, env, shape): 32 | super().__init__(env) 33 | if isinstance(shape, int): 34 | self.shape = (shape, shape) 35 | else: 36 | self.shape = tuple(shape) 37 | obs_shape = self.shape + self.observation_space.shape[2:] 38 | self.observation_space = gym.spaces.Box(low=0, high=255, 39 | shape=obs_shape, dtype=np.uint8) 40 | 41 | def observation(self, observation): 42 | transforms = T.Compose( 43 | [T.Resize(self.shape), T.ToTensor(), T.Normalize((0,), (255,))] 44 | ) 45 | observation = transforms(observation).squeeze(0) 46 | return observation 47 | 48 | class GrayScaleObservation(gym.ObservationWrapper): 49 | def __init__(self, env): 50 | super().__init__(env) 51 | 
obs_shape = self.observation_space.shape[:2] 52 | self.observation_space = gym.spaces.Box(low=0, high=255, 53 | shape=obs_shape, dtype=np.uint8) 54 | 55 | def permute_orientation(self, observation): 56 | # permute [H, W, C] array to [C, H, W] tensor 57 | observation = np.transpose(observation, (2, 0, 1)) 58 | observation = torch.tensor(observation.copy(), dtype=torch.float) 59 | return observation 60 | 61 | def observation(self, observation): 62 | observation = self.permute_orientation(observation) 63 | transforms = T.Compose( 64 | [T.ToPILImage(), T.Grayscale()] 65 | ) 66 | observation = transforms(observation) 67 | return observation 68 | 69 | def create_mario_env(env_name): 70 | env = gym_super_mario_bros.make(env_name) 71 | env = SkipFrame(env, skip=4) 72 | env = GrayScaleObservation(env) 73 | env = ResizeObservation(env, shape=84) 74 | env = FrameStack(env, num_stack=4) 75 | return JoypadSpace(env, SIMPLE_MOVEMENT) -------------------------------------------------------------------------------- /pytorch/DDQN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | 5 | class Agent(): 6 | def __init__(self, env_fn, Qnet, buffer, net_args, 7 | max_epsilon, min_epsilon, 8 | gamma, decay_steps, 9 | target_update_rate, min_buffer) -> None: 10 | 11 | self.env = env_fn() 12 | self.env_test = env_fn() 13 | self.n_actions = self.env.action_space.n 14 | self.state_shape = self.env.observation_space.shape 15 | 16 | self.online_net = Qnet(self.state_shape[0], self.n_actions, **net_args) 17 | self.target_net = Qnet(self.state_shape[0], self.n_actions, **net_args) 18 | self.update_target_network() 19 | 20 | self.buffer = buffer(self.state_shape, self.n_actions) 21 | self.min_buffer = min_buffer 22 | 23 | self.epsilon = max_epsilon 24 | self.min_epsilon = min_epsilon 25 | self.epsilon_decay = (max_epsilon - min_epsilon)/decay_steps 26 | self.gamma = gamma 27 | self.target_update_rate = target_update_rate 28 | 29 | self.step_count = 0 30 | self.episode_reward = 0 31 | self.episode_count = 1 32 | self.state = self.env.reset() 33 | self.rewards = [] 34 | 35 | 36 | def step(self, steps): 37 | for _ in range(steps): 38 | self.step_count += 1 39 | 40 | action = self.epsilonGreedyPolicy(self.state) 41 | state_p, reward, done, info = self.env.step(action) 42 | self.episode_reward += reward 43 | 44 | is_truncated = 'TimeLimit.truncated' in info and info['TimeLimit.truncated'] 45 | is_failure = done and not is_truncated 46 | self.buffer.store(self.state, action, reward, state_p, float(is_failure)) 47 | 48 | if len(self.buffer) >= self.min_buffer: 49 | self.update() 50 | if self.step_count % self.target_update_rate == 0: 51 | self.update_target_network() 52 | 53 | self.state = state_p 54 | if done: 55 | self.episode_count += 1 56 | self.state = self.env.reset() 57 | self.rewards.append(self.episode_reward) 58 | self.episode_reward = 0 59 | 60 | 61 | def train(self, n_episodes): 62 | for _ in range(n_episodes): 63 | self.episode_reward = 0 64 | self.state = self.env.reset() 65 | 66 | while True: 67 | self.step_count += 1 68 | action = self.epsilonGreedyPolicy(self.state) 69 | state_p, reward, done, info = self.env.step(action) 70 | self.episode_reward += reward 71 | 72 | is_truncated = 'TimeLimit.truncated' in info and info['TimeLimit.truncated'] 73 | is_failure = done and not is_truncated 74 | self.buffer.store(self.state, action, reward, state_p, float(is_failure)) 75 | 76 | if len(self.buffer) >= self.min_buffer: 
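# learning only starts once the replay buffer holds at least `min_buffer` transitions;
# the target network is re-synced to the online network every `target_update_rate` steps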
77 | self.update() 78 | if self.step_count % self.target_update_rate == 0: 79 | self.update_target_network() 80 | 81 | self.state = state_p 82 | if done: 83 | print(self.episode_reward) 84 | self.episode_count += 1 85 | self.rewards.append(self.episode_reward) 86 | break 87 | 88 | 89 | def get_score(self): 90 | return np.mean(self.rewards[-5:]) 91 | 92 | 93 | def evaluate(self): 94 | rewards = 0 95 | state = self.env_test.reset() 96 | while True: 97 | action = self.greedyPolicy(state) 98 | state_p, reward, done, _ = self.env_test.step(action) 99 | rewards += reward 100 | if done: 101 | break 102 | state = state_p 103 | return rewards 104 | 105 | 106 | def update(self): 107 | states, actions, rewards, states_p, is_terminals = self.buffer.sample() 108 | actions = torch.tensor(actions) 109 | is_terminals = torch.tensor(is_terminals) 110 | rewards = torch.tensor(rewards) 111 | q_states = self.online_net(states).gather(1, actions).squeeze() 112 | with torch.no_grad(): 113 | q_states_p = self.target_net(states_p) 114 | q_target = rewards + self.gamma * (1-is_terminals) * q_states_p.max(1)[0] 115 | 116 | td_error = q_states - q_target 117 | loss = td_error.pow(2).mean() 118 | 119 | self.online_net.optimize(loss) 120 | 121 | 122 | def update_epsilon(self): 123 | self.epsilon = max(self.min_epsilon, 124 | self.epsilon - self.epsilon_decay) 125 | 126 | 127 | def update_target_network(self): 128 | self.target_net.load_state_dict(self.online_net.state_dict()) 129 | 130 | 131 | def epsilonGreedyPolicy(self, state): 132 | if np.random.rand() < self.epsilon: 133 | action = np.random.randint(self.n_actions) 134 | else: 135 | with torch.no_grad(): 136 | action = self.online_net(state).argmax().item() 137 | self.update_epsilon() 138 | return action 139 | 140 | 141 | def greedyPolicy(self, state): 142 | with torch.no_grad(): 143 | action = self.target_net(state).argmax() 144 | return action.item() -------------------------------------------------------------------------------- /pytorch/DQN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | 5 | class Agent(): 6 | def __init__(self, env_fn, Qnet, buffer, net_args, 7 | max_epsilon, min_epsilon, 8 | gamma, decay_steps, 9 | target_update_rate, min_buffer) -> None: 10 | 11 | self.env = env_fn() 12 | self.env_test = env_fn() 13 | self.n_actions = self.env.action_space.n 14 | self.state_shape = self.env.observation_space.shape 15 | 16 | self.online_net = Qnet(self.state_shape[0], self.n_actions, **net_args) 17 | self.target_net = Qnet(self.state_shape[0], self.n_actions, **net_args) 18 | self.update_target_network() 19 | 20 | self.buffer = buffer(self.state_shape, self.n_actions) 21 | self.min_buffer = min_buffer 22 | 23 | self.epsilon = max_epsilon 24 | self.min_epsilon = min_epsilon 25 | self.epsilon_decay = (max_epsilon - min_epsilon)/decay_steps 26 | self.gamma = gamma 27 | self.target_update_rate = target_update_rate 28 | 29 | self.step_count = 0 30 | self.episode_reward = 0 31 | self.episode_count = 1 32 | self.state = self.env.reset() 33 | self.rewards = [] 34 | 35 | 36 | def step(self, steps): 37 | for _ in range(steps): 38 | self.step_count += 1 39 | 40 | action = self.epsilonGreedyPolicy(self.state) 41 | state_p, reward, done, info = self.env.step(action) 42 | self.episode_reward += reward 43 | 44 | is_truncated = 'TimeLimit.truncated' in info and info['TimeLimit.truncated'] 45 | is_failure = done and not is_truncated 46 | self.buffer.store(self.state, 
action, reward, state_p, float(is_failure)) 47 | 48 | if len(self.buffer) >= self.min_buffer: 49 | self.update() 50 | if self.step_count % self.target_update_rate == 0: 51 | self.update_target_network() 52 | 53 | self.state = state_p 54 | if done: 55 | self.episode_count += 1 56 | self.state = self.env.reset() 57 | self.rewards.append(self.episode_reward) 58 | self.episode_reward = 0 59 | 60 | 61 | def train(self, n_episodes): 62 | for _ in range(n_episodes): 63 | self.episode_reward = 0 64 | self.state = self.env.reset() 65 | 66 | while True: 67 | self.step_count += 1 68 | action = self.epsilonGreedyPolicy(self.state) 69 | state_p, reward, done, info = self.env.step(action) 70 | self.episode_reward += reward 71 | 72 | is_truncated = 'TimeLimit.truncated' in info and info['TimeLimit.truncated'] 73 | is_failure = done and not is_truncated 74 | self.buffer.store(self.state, action, reward, state_p, float(is_failure)) 75 | 76 | if len(self.buffer) >= self.min_buffer: 77 | self.update() 78 | if self.step_count % self.target_update_rate == 0: 79 | self.update_target_network() 80 | 81 | self.state = state_p 82 | if done: 83 | print(self.episode_reward) 84 | self.episode_count += 1 85 | self.rewards.append(self.episode_reward) 86 | break 87 | 88 | 89 | def get_score(self): 90 | return np.mean(self.rewards[-5:]) 91 | 92 | 93 | def evaluate(self): 94 | rewards = 0 95 | state = self.env_test.reset() 96 | while True: 97 | action = self.greedyPolicy(state) 98 | state_p, reward, done, _ = self.env_test.step(action) 99 | rewards += reward 100 | if done: 101 | break 102 | state = state_p 103 | return rewards 104 | 105 | 106 | def update(self): 107 | states, actions, rewards, states_p, is_terminals = self.buffer.sample() 108 | actions = torch.tensor(actions) 109 | is_terminals = torch.tensor(is_terminals) 110 | rewards = torch.tensor(rewards) 111 | q_states = self.online_net(states).gather(1, actions).squeeze() 112 | with torch.no_grad(): 113 | q_states_p = self.target_net(states_p) 114 | q_target = rewards + self.gamma * (1-is_terminals) * q_states_p.max(1)[0] 115 | 116 | td_error = q_states - q_target 117 | loss = td_error.pow(2).mean() 118 | 119 | self.online_net.optimize(loss) 120 | 121 | 122 | def update_epsilon(self): 123 | self.epsilon = max(self.min_epsilon, 124 | self.epsilon - self.epsilon_decay) 125 | 126 | 127 | def update_target_network(self): 128 | self.target_net.load_state_dict(self.online_net.state_dict()) 129 | 130 | 131 | def epsilonGreedyPolicy(self, state): 132 | if np.random.rand() < self.epsilon: 133 | action = np.random.randint(self.n_actions) 134 | else: 135 | with torch.no_grad(): 136 | action = self.online_net(state).argmax().item() 137 | self.update_epsilon() 138 | return action 139 | 140 | 141 | def greedyPolicy(self, state): 142 | with torch.no_grad(): 143 | action = self.target_net(state).argmax() 144 | return action.item() -------------------------------------------------------------------------------- /pytorch/Agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from QNetwork import QNetwork 8 | from ReplayBuffer import ReplayBuffer 9 | 10 | class Agent(): 11 | def __init__(self, id, env_name, env_fn, Qnet=QNetwork, buffer=ReplayBuffer, 12 | max_epsilon=1, min_epsilon=0.05, epsilon_decay=0.99, gamma=0.9, 13 | target_update_rate=2000, min_buffer=100, 14 | load=False, path=None) -> None: 15 | self.id = id 16 | self.path = path + 
str(id) + "/" 17 | 18 | self.env = env_fn(env_name) 19 | self.env_fn = env_fn 20 | self.n_actions = self.env.action_space.n 21 | self.state_shape = self.env.observation_space.shape 22 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 23 | 24 | self.min_buffer = min_buffer 25 | self.min_epsilon = min_epsilon 26 | self.epsilon_decay = epsilon_decay 27 | self.gamma = gamma 28 | self.target_update_rate = target_update_rate 29 | self.buffer = buffer(self.state_shape, self.n_actions, 30 | load=load, path=self.path) 31 | 32 | self.online_net = Qnet(self.state_shape, self.n_actions).to(self.device) 33 | self.target_net = Qnet(self.state_shape, self.n_actions).to(self.device) 34 | 35 | if load: 36 | self.load() 37 | else: 38 | self.update_target_network() 39 | self.epsilon = max_epsilon 40 | self.step_count = 0 41 | self.episode_count = 0 42 | self.rewards = [] 43 | 44 | 45 | def load(self): 46 | with open(self.path + "step_count.pkl", 'rb') as f: 47 | self.step_count = pickle.load(f) 48 | with open(self.path + "episode_count.pkl", 'rb') as f: 49 | self.episode_count = pickle.load(f) 50 | with open(self.path + "rewards.pkl", 'rb') as f: 51 | self.rewards = pickle.load(f) 52 | with open(self.path + "epsilon.pkl", 'rb') as f: 53 | self.epsilon = pickle.load(f) 54 | self.online_net.load_state_dict(torch.load(self.path + "online_net.pt", 55 | map_location=torch.device(self.device))) 56 | self.target_net.load_state_dict(torch.load(self.path + "target_net.pt", 57 | map_location=torch.device(self.device))) 58 | 59 | def save(self): 60 | os.makedirs(os.path.dirname(self.path), exist_ok=True) 61 | self.buffer.save() 62 | with open(self.path + "step_count.pkl", "wb") as f: 63 | pickle.dump(self.step_count, f) 64 | with open(self.path + "episode_count.pkl", "wb") as f: 65 | pickle.dump(self.episode_count, f) 66 | with open(self.path + "rewards.pkl", "wb") as f: 67 | pickle.dump(self.rewards, f) 68 | with open(self.path + "epsilon.pkl", "wb") as f: 69 | pickle.dump(self.epsilon, f) 70 | torch.save(self.online_net.state_dict(), self.path + "online_net.pt") 71 | torch.save(self.target_net.state_dict(), self.path + "target_net.pt") 72 | 73 | 74 | 75 | def train(self, n_episodes): 76 | for i in tqdm(range(n_episodes)): 77 | episode_reward = 0 78 | state = self.env.reset() 79 | 80 | while True: 81 | self.step_count += 1 82 | action = self.epsilonGreedyPolicy(state) 83 | state_p, reward, done, info = self.env.step(action) 84 | episode_reward += reward 85 | 86 | is_truncated = 'TimeLimit.truncated' in info and info['TimeLimit.truncated'] 87 | is_failure = done and not is_truncated 88 | self.buffer.store(state, action, reward, state_p, float(is_failure)) 89 | 90 | if len(self.buffer) >= self.min_buffer: 91 | self.update() 92 | if self.step_count % self.target_update_rate == 0: 93 | self.update_target_network() 94 | 95 | state = state_p 96 | if done: 97 | self.episode_count += 1 98 | self.rewards.append(episode_reward) 99 | break 100 | 101 | print("Agent-{} Episode {} Step {} score = {}, average score = {}"\ 102 | .format(self.id, self.episode_count, self.step_count, self.rewards[-1], np.mean(self.rewards))) 103 | 104 | 105 | def get_score(self): 106 | # return np.mean(self.rewards[-5:]) 107 | return 1 108 | 109 | 110 | def update(self): 111 | states, actions, rewards, states_p, is_terminals = self.buffer.sample() 112 | states = states.to(self.device) 113 | actions = actions.to(self.device) 114 | rewards = rewards.to(self.device) 115 | states_p = states_p.to(self.device) 116 | is_terminals = 
is_terminals.to(self.device) 117 | 118 | td_estimate = self.online_net(states).gather(1, actions) 119 | 120 | actions_p = self.online_net(states).argmax(axis=1, keepdim=True) 121 | with torch.no_grad(): 122 | q_states_p = self.target_net(states_p) 123 | q_state_p_action_p = q_states_p.gather(1, actions_p) 124 | td_target = rewards + (1-is_terminals) * self.gamma * q_state_p_action_p 125 | 126 | self.online_net.update_netowrk(td_estimate, td_target) 127 | self.update_epsilon() 128 | 129 | 130 | def update_epsilon(self): 131 | self.epsilon *= self.epsilon_decay 132 | self.epsilon = max(self.epsilon, self.min_epsilon) 133 | 134 | 135 | def update_target_network(self): 136 | self.target_net.load_state_dict(self.online_net.state_dict()) 137 | 138 | 139 | def epsilonGreedyPolicy(self, state): 140 | if np.random.rand() < self.epsilon: 141 | action = np.random.randint(self.n_actions) 142 | else: 143 | state = state.__array__() 144 | state = torch.tensor(state).unsqueeze(0).to(self.device) 145 | with torch.no_grad(): 146 | action = self.online_net(state).argmax().item() 147 | return action 148 | -------------------------------------------------------------------------------- /Mario.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Mario.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "source": [ 23 | "NES Environment" 24 | ], 25 | "metadata": { 26 | "id": "U5Zpi_aKcVYR" 27 | } 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "metadata": { 33 | "id": "tQ0ChH2FI1mu", 34 | "colab": { 35 | "base_uri": "https://localhost:8080/" 36 | }, 37 | "outputId": "769ecf0b-1f0c-4698-f411-aa1d9020e100" 38 | }, 39 | "outputs": [ 40 | { 41 | "output_type": "stream", 42 | "name": "stdout", 43 | "text": [ 44 | "Requirement already satisfied: gym-super-mario-bros==7.3.0 in /usr/local/lib/python3.7/dist-packages (7.3.0)\n", 45 | "Requirement already satisfied: nes-py>=8.0.0 in /usr/local/lib/python3.7/dist-packages (from gym-super-mario-bros==7.3.0) (8.1.8)\n", 46 | "Requirement already satisfied: numpy>=1.18.5 in /usr/local/lib/python3.7/dist-packages (from nes-py>=8.0.0->gym-super-mario-bros==7.3.0) (1.21.5)\n", 47 | "Requirement already satisfied: gym>=0.17.2 in /usr/local/lib/python3.7/dist-packages (from nes-py>=8.0.0->gym-super-mario-bros==7.3.0) (0.17.3)\n", 48 | "Requirement already satisfied: pyglet<=1.5.11,>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from nes-py>=8.0.0->gym-super-mario-bros==7.3.0) (1.5.0)\n", 49 | "Requirement already satisfied: tqdm>=4.48.2 in /usr/local/lib/python3.7/dist-packages (from nes-py>=8.0.0->gym-super-mario-bros==7.3.0) (4.62.3)\n", 50 | "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from gym>=0.17.2->nes-py>=8.0.0->gym-super-mario-bros==7.3.0) (1.4.1)\n", 51 | "Requirement already satisfied: cloudpickle<1.7.0,>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from gym>=0.17.2->nes-py>=8.0.0->gym-super-mario-bros==7.3.0) (1.3.0)\n", 52 | "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from pyglet<=1.5.11,>=1.4.0->nes-py>=8.0.0->gym-super-mario-bros==7.3.0) (0.16.0)\n", 53 | "Requirement already satisfied: tqdm 
in /usr/local/lib/python3.7/dist-packages (4.62.3)\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "!pip install gym-super-mario-bros==7.3.0\n", 59 | "!pip install tqdm" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "source": [ 65 | "import pickle\n", 66 | "import os\n", 67 | "from tqdm import tqdm\n", 68 | "\n", 69 | "from nes_py.wrappers import JoypadSpace\n", 70 | "import gym_super_mario_bros\n", 71 | "from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT\n", 72 | "\n", 73 | "import gym\n", 74 | "from gym.wrappers import FrameStack\n", 75 | "\n", 76 | "import numpy as np\n", 77 | "import torch\n", 78 | "from torchvision import transforms as T\n", 79 | "import matplotlib.pyplot as plt" 80 | ], 81 | "metadata": { 82 | "id": "QU1-XZL_cewu" 83 | }, 84 | "execution_count": 2, 85 | "outputs": [] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "source": [ 90 | "%matplotlib inline" 91 | ], 92 | "metadata": { 93 | "id": "4EBzBrX_GNWD" 94 | }, 95 | "execution_count": 3, 96 | "outputs": [] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "source": [ 101 | "env_test = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')\n", 102 | "print(env_test.observation_space.shape)\n", 103 | "print(env_test.action_space.n)" 104 | ], 105 | "metadata": { 106 | "colab": { 107 | "base_uri": "https://localhost:8080/" 108 | }, 109 | "id": "J9onIgm6dImw", 110 | "outputId": "8c6a5874-17d6-4ebd-b512-d7c2f1c2b0ab" 111 | }, 112 | "execution_count": 4, 113 | "outputs": [ 114 | { 115 | "output_type": "stream", 116 | "name": "stdout", 117 | "text": [ 118 | "(240, 256, 3)\n", 119 | "256\n" 120 | ] 121 | } 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "source": [ 127 | "class SkipFrame(gym.Wrapper):\n", 128 | " def __init__(self, env, skip):\n", 129 | " \"\"\"Return only every `skip`-th frame\"\"\"\n", 130 | " super().__init__(env)\n", 131 | " self._skip = skip\n", 132 | "\n", 133 | " def step(self, action):\n", 134 | " \"\"\"Repeat action, and sum reward\"\"\"\n", 135 | " total_reward = 0.0\n", 136 | " done = False\n", 137 | " for i in range(self._skip):\n", 138 | " # Accumulate reward and repeat the same action\n", 139 | " obs, reward, done, info = self.env.step(action)\n", 140 | " total_reward += reward\n", 141 | " if done:\n", 142 | " break\n", 143 | " return obs, total_reward, done, info" 144 | ], 145 | "metadata": { 146 | "id": "9Q_m_d2DCRyG" 147 | }, 148 | "execution_count": 5, 149 | "outputs": [] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "source": [ 154 | "class ResizeObservation(gym.ObservationWrapper):\n", 155 | " def __init__(self, env, shape):\n", 156 | " super().__init__(env)\n", 157 | " if isinstance(shape, int):\n", 158 | " self.shape = (shape, shape)\n", 159 | " else:\n", 160 | " self.shape = tuple(shape)\n", 161 | " obs_shape = self.shape + self.observation_space.shape[2:]\n", 162 | " self.observation_space = gym.spaces.Box(low=0, high=255, \n", 163 | " shape=obs_shape, dtype=np.uint8)\n", 164 | "\n", 165 | " def observation(self, observation):\n", 166 | " transforms = T.Compose(\n", 167 | " [T.Resize(self.shape), T.Normalize(0, 255)]\n", 168 | " )\n", 169 | " observation = transforms(observation).squeeze(0)\n", 170 | " return observation" 171 | ], 172 | "metadata": { 173 | "id": "eu_BDqyBCfvE" 174 | }, 175 | "execution_count": 6, 176 | "outputs": [] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "source": [ 181 | "class GrayScaleObservation(gym.ObservationWrapper):\n", 182 | " def __init__(self, env):\n", 183 | " super().__init__(env)\n", 184 | " obs_shape = 
self.observation_space.shape[:2]\n", 185 | " self.observation_space = gym.spaces.Box(low=0, high=255, \n", 186 | " shape=obs_shape, dtype=np.uint8)\n", 187 | "\n", 188 | " def permute_orientation(self, observation):\n", 189 | " # permute [H, W, C] array to [C, H, W] tensor\n", 190 | " observation = np.transpose(observation, (2, 0, 1))\n", 191 | " observation = torch.tensor(observation.copy(), dtype=torch.float)\n", 192 | " return observation\n", 193 | "\n", 194 | " def observation(self, observation):\n", 195 | " observation = self.permute_orientation(observation)\n", 196 | " transform = T.Grayscale()\n", 197 | " observation = transform(observation)\n", 198 | " return observation" 199 | ], 200 | "metadata": { 201 | "id": "pkXKwY5Adbuz" 202 | }, 203 | "execution_count": 7, 204 | "outputs": [] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "source": [ 209 | "def create_mario_env(env_name):\n", 210 | " env = gym_super_mario_bros.make(env_name)\n", 211 | " env = SkipFrame(env, skip=4)\n", 212 | " env = GrayScaleObservation(env)\n", 213 | " env = ResizeObservation(env, shape=84)\n", 214 | " env = FrameStack(env, num_stack=4)\n", 215 | " return JoypadSpace(env, SIMPLE_MOVEMENT)" 216 | ], 217 | "metadata": { 218 | "id": "BMdORGcbYcmt" 219 | }, 220 | "execution_count": 8, 221 | "outputs": [] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "source": [ 226 | "import numpy as np\n", 227 | "\n", 228 | "class ReplayBuffer:\n", 229 | " def __init__(self, state_shape, action_space, batch_size=32, max_size=10000,\n", 230 | " load=False, path=None):\n", 231 | " self.path = path + 'buffer/'\n", 232 | " self.max_size = max_size\n", 233 | " self.batch_size = batch_size\n", 234 | "\n", 235 | " if load:\n", 236 | " self.load()\n", 237 | " else:\n", 238 | " self.next = 0\n", 239 | " self.size = 0\n", 240 | "\n", 241 | " self.states = torch.empty((max_size, *state_shape))\n", 242 | " self.actions = torch.empty((max_size, 1), dtype=torch.int64)\n", 243 | " self.rewards = torch.empty((max_size, 1))\n", 244 | " self.states_p = torch.empty((max_size, *state_shape))\n", 245 | " self.is_terminals = torch.empty((max_size, 1), dtype=torch.float)\n", 246 | "\n", 247 | "\n", 248 | " def __len__(self): return self.size\n", 249 | " \n", 250 | "\n", 251 | " def store(self, state, action, reward, state_p, is_terminal):\n", 252 | " state = state.__array__()\n", 253 | " state_p = state_p.__array__()\n", 254 | "\n", 255 | " self.states[self.next] = torch.tensor(state)\n", 256 | " self.actions[self.next] = action\n", 257 | " self.rewards[self.next] = reward\n", 258 | " self.states_p[self.next] = torch.tensor(state_p)\n", 259 | " self.is_terminals[self.next] = is_terminal\n", 260 | "\n", 261 | " self.size = min(self.size + 1, self.max_size)\n", 262 | " self.next = (self.next + 1) % self.max_size\n", 263 | "\n", 264 | "\n", 265 | " def sample(self):\n", 266 | " indices = np.random.choice(self.size, size=self.batch_size, \n", 267 | " replace=False)\n", 268 | " return self.states[indices], \\\n", 269 | " self.actions[indices], \\\n", 270 | " self.rewards[indices], \\\n", 271 | " self.states_p[indices], \\\n", 272 | " self.is_terminals[indices]\n", 273 | "\n", 274 | "\n", 275 | " def clear(self):\n", 276 | " self.next = 0\n", 277 | " self.size = 0\n", 278 | " self.states = torch.empty_like(self.states)\n", 279 | " self.actions = torch.empty_like(self.actions)\n", 280 | " self.rewards = torch.empty_like(self.rewards)\n", 281 | " self.states_p = torch.empty_like(self.states_p)\n", 282 | " self.is_terminals = 
torch.empty_like(self.is_terminals)\n", 283 | "\n", 284 | "\n", 285 | " def load(self):\n", 286 | " with open(self.path + \"next.pkl\", 'rb') as f:\n", 287 | " self.next = pickle.load(f)\n", 288 | " with open(self.path + \"size.pkl\", 'rb') as f:\n", 289 | " self.size = pickle.load(f)\n", 290 | " self.states = torch.load(self.path + \"states.pt\")\n", 291 | " self.actions = torch.load(self.path + \"actions.pt\")\n", 292 | " self.rewards = torch.load(self.path + \"rewards.pt\")\n", 293 | " self.states_p = torch.load(self.path + \"states_p.pt\")\n", 294 | " self.is_terminals = torch.load(self.path + \"is_terminals.pt\")\n", 295 | "\n", 296 | "\n", 297 | " def save(self):\n", 298 | " os.makedirs(os.path.dirname(self.path), exist_ok=True)\n", 299 | " with open(self.path + \"next.pkl\", \"wb\") as f:\n", 300 | " pickle.dump(self.next, f)\n", 301 | " with open(self.path + \"size.pkl\", \"wb\") as f:\n", 302 | " pickle.dump(self.size, f)\n", 303 | " torch.save(self.states, self.path + \"states.pt\")\n", 304 | " torch.save(self.actions, self.path + \"actions.pt\")\n", 305 | " torch.save(self.rewards, self.path + \"rewards.pt\")\n", 306 | " torch.save(self.states_p, self.path + \"states_p.pt\")\n", 307 | " torch.save(self.is_terminals, self.path + \"is_terminals.pt\")" 308 | ], 309 | "metadata": { 310 | "id": "9S04urIOAL4e" 311 | }, 312 | "execution_count": null, 313 | "outputs": [] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "source": [ 318 | "class QNetwork(torch.nn.Module):\n", 319 | " def __init__(self, input_shape, actions_size, \n", 320 | " optimizer=torch.optim.Adam, learning_rate=0.00025):\n", 321 | " super().__init__()\n", 322 | " self.personalized = torch.nn.Sequential(\n", 323 | " torch.nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),\n", 324 | " torch.nn.ReLU(),\n", 325 | " torch.nn.Conv2d(32, 64, kernel_size=4, stride=2),\n", 326 | " torch.nn.ReLU(),\n", 327 | " )\n", 328 | " self.shared = torch.nn.Sequential(\n", 329 | " torch.nn.Conv2d(64, 64, kernel_size=3, stride=1),\n", 330 | " torch.nn.ReLU(),\n", 331 | " torch.nn.Flatten(),\n", 332 | " torch.nn.Linear(3136, 512),\n", 333 | " torch.nn.ReLU(),\n", 334 | " torch.nn.Linear(512, actions_size)\n", 335 | " )\n", 336 | " self.optimizer = optimizer(self.parameters(), lr=learning_rate)\n", 337 | " self.loss_fn = torch.nn.SmoothL1Loss()\n", 338 | "\n", 339 | "\n", 340 | " def format_(self, states):\n", 341 | " if not isinstance(states, torch.Tensor):\n", 342 | " states = torch.tensor(states, dtype=torch.float32)\n", 343 | " return states\n", 344 | "\n", 345 | "\n", 346 | " def forward(self, x):\n", 347 | " states = self.format_(x)\n", 348 | " out = self.personalized(states)\n", 349 | " out = self.shared(out)\n", 350 | " return out\n", 351 | "\n", 352 | "\n", 353 | " def update_netowrk(self, td_estimate, td_target):\n", 354 | " loss = self.loss_fn(td_estimate, td_target)\n", 355 | " self.optimizer.zero_grad()\n", 356 | " loss.backward()\n", 357 | " self.optimizer.step()" 358 | ], 359 | "metadata": { 360 | "id": "zifT_IS_0cj_" 361 | }, 362 | "execution_count": null, 363 | "outputs": [] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "source": [ 368 | "class Agent():\n", 369 | " def __init__(self, id, env_name, env_fn, Qnet=QNetwork, buffer=ReplayBuffer,\n", 370 | " max_epsilon=1, min_epsilon=0.05, epsilon_decay=0.99, gamma=0.9,\n", 371 | " target_update_rate=2000, min_buffer=100, \n", 372 | " load=False, path=None) -> None:\n", 373 | " self.id = id\n", 374 | " self.path = path + str(id) + \"/\"\n", 375 | "\n", 376 | " self.env 
= env_fn(env_name)\n", 377 | " self.env_fn = env_fn\n", 378 | " self.n_actions = self.env.action_space.n\n", 379 | " self.state_shape = self.env.observation_space.shape\n", 380 | " self.device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 381 | "\n", 382 | " self.min_buffer = min_buffer\n", 383 | " self.min_epsilon = min_epsilon\n", 384 | " self.epsilon_decay = epsilon_decay\n", 385 | " self.gamma = gamma\n", 386 | " self.target_update_rate = target_update_rate\n", 387 | " self.buffer = buffer(self.state_shape, self.n_actions,\n", 388 | " load=load, path=self.path)\n", 389 | "\n", 390 | " self.online_net = Qnet(self.state_shape, self.n_actions).to(self.device)\n", 391 | " self.target_net = Qnet(self.state_shape, self.n_actions).to(self.device)\n", 392 | "\n", 393 | " if load:\n", 394 | " self.load()\n", 395 | " else:\n", 396 | " self.update_target_network()\n", 397 | " self.epsilon = max_epsilon\n", 398 | " self.step_count = 0\n", 399 | " self.episode_count = 0\n", 400 | " self.rewards = []\n", 401 | "\n", 402 | " \n", 403 | " def load(self):\n", 404 | " with open(self.path + \"step_count.pkl\", 'rb') as f:\n", 405 | " self.step_count = pickle.load(f)\n", 406 | " with open(self.path + \"episode_count.pkl\", 'rb') as f:\n", 407 | " self.episode_count = pickle.load(f)\n", 408 | " with open(self.path + \"rewards.pkl\", 'rb') as f:\n", 409 | " self.rewards = pickle.load(f)\n", 410 | " with open(self.path + \"epsilon.pkl\", 'rb') as f:\n", 411 | " self.epsilon = pickle.load(f)\n", 412 | " self.online_net.load_state_dict(torch.load(self.path + \"online_net.pt\", \n", 413 | " map_location=torch.device(self.device)))\n", 414 | " self.target_net.load_state_dict(torch.load(self.path + \"target_net.pt\", \n", 415 | " map_location=torch.device(self.device)))\n", 416 | "\n", 417 | " def save(self):\n", 418 | " os.makedirs(os.path.dirname(self.path), exist_ok=True)\n", 419 | " self.buffer.save()\n", 420 | " with open(self.path + \"step_count.pkl\", \"wb\") as f:\n", 421 | " pickle.dump(self.step_count, f)\n", 422 | " with open(self.path + \"episode_count.pkl\", \"wb\") as f:\n", 423 | " pickle.dump(self.episode_count, f)\n", 424 | " with open(self.path + \"rewards.pkl\", \"wb\") as f:\n", 425 | " pickle.dump(self.rewards, f)\n", 426 | " with open(self.path + \"epsilon.pkl\", \"wb\") as f:\n", 427 | " pickle.dump(self.epsilon, f)\n", 428 | " torch.save(self.online_net.state_dict(), self.path + \"online_net.pt\")\n", 429 | " torch.save(self.target_net.state_dict(), self.path + \"target_net.pt\")\n", 430 | "\n", 431 | "\n", 432 | "\n", 433 | " def train(self, n_episodes):\n", 434 | " for i in tqdm(range(n_episodes)):\n", 435 | " episode_reward = 0\n", 436 | " state = self.env.reset()\n", 437 | "\n", 438 | " while True:\n", 439 | " self.step_count += 1\n", 440 | " action = self.epsilonGreedyPolicy(state)\n", 441 | " state_p, reward, done, info = self.env.step(action)\n", 442 | " episode_reward += reward\n", 443 | "\n", 444 | " is_truncated = 'TimeLimit.truncated' in info and info['TimeLimit.truncated']\n", 445 | " is_failure = done and not is_truncated\n", 446 | " self.buffer.store(state, action, reward, state_p, float(is_failure))\n", 447 | "\n", 448 | " if len(self.buffer) >= self.min_buffer:\n", 449 | " self.update()\n", 450 | " if self.step_count % self.target_update_rate == 0:\n", 451 | " self.update_target_network()\n", 452 | "\n", 453 | " state = state_p\n", 454 | " if done:\n", 455 | " self.episode_count += 1\n", 456 | " self.rewards.append(episode_reward)\n", 457 | " break\n", 458 | "\n", 
459 | " print(\"Agent-{} Episode {} Step {} score = {}, average score = {}\"\\\n", 460 | " .format(self.id, self.episode_count, self.step_count, self.rewards[-1], np.mean(self.rewards)))\n", 461 | "\n", 462 | "\n", 463 | " def get_score(self):\n", 464 | " # return np.mean(self.rewards[-5:])\n", 465 | " return 1\n", 466 | "\n", 467 | "\n", 468 | " def update(self):\n", 469 | " states, actions, rewards, states_p, is_terminals = self.buffer.sample()\n", 470 | " states = states.to(self.device)\n", 471 | " actions = actions.to(self.device)\n", 472 | " rewards = rewards.to(self.device)\n", 473 | " states_p = states_p.to(self.device)\n", 474 | " is_terminals = is_terminals.to(self.device)\n", 475 | "\n", 476 | " td_estimate = self.online_net(states).gather(1, actions)\n", 477 | "\n", 478 | " actions_p = self.online_net(states).argmax(axis=1, keepdim=True)\n", 479 | " with torch.no_grad():\n", 480 | " q_states_p = self.target_net(states_p)\n", 481 | " q_state_p_action_p = q_states_p.gather(1, actions_p)\n", 482 | " td_target = rewards + (1-is_terminals) * self.gamma * q_state_p_action_p\n", 483 | "\n", 484 | " self.online_net.update_netowrk(td_estimate, td_target)\n", 485 | " self.update_epsilon()\n", 486 | "\n", 487 | "\n", 488 | " def update_epsilon(self):\n", 489 | " self.epsilon *= self.epsilon_decay\n", 490 | " self.epsilon = max(self.epsilon, self.min_epsilon)\n", 491 | "\n", 492 | "\n", 493 | " def update_target_network(self):\n", 494 | " self.target_net.load_state_dict(self.online_net.state_dict())\n", 495 | "\n", 496 | "\n", 497 | " def epsilonGreedyPolicy(self, state):\n", 498 | " if np.random.rand() < self.epsilon:\n", 499 | " action = np.random.randint(self.n_actions)\n", 500 | " else:\n", 501 | " state = state.__array__()\n", 502 | " state = torch.tensor(state).unsqueeze(0).to(self.device)\n", 503 | " with torch.no_grad():\n", 504 | " action = self.online_net(state).argmax().item()\n", 505 | " return action" 506 | ], 507 | "metadata": { 508 | "id": "_EBKN7-pBmDP" 509 | }, 510 | "execution_count": null, 511 | "outputs": [] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "source": [ 516 | "class Mario(Agent):\n", 517 | " def __init__(self, env_names, env_fn, Qnet=QNetwork, load=False, path=None) -> None:\n", 518 | " self.path = path + \"global/\"\n", 519 | " self.envs = []\n", 520 | " for name in env_names:\n", 521 | " self.envs.append(env_fn(name))\n", 522 | " self.n_actions = self.envs[0].action_space.n\n", 523 | " self.state_shape = self.envs[0].observation_space.shape\n", 524 | " self.device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 525 | "\n", 526 | " self.online_net = Qnet(self.state_shape, self.n_actions).to(self.device)\n", 527 | " self.target_net = Qnet(self.state_shape, self.n_actions).to(self.device)\n", 528 | "\n", 529 | " if load:\n", 530 | " self.load()\n", 531 | " else:\n", 532 | " self.update_target_network()\n", 533 | "\n", 534 | "\n", 535 | " def load(self):\n", 536 | " self.online_net.load_state_dict(torch.load(self.path + \"online_net.pt\", \n", 537 | " map_location=torch.device(self.device)))\n", 538 | " self.target_net.load_state_dict(torch.load(self.path + \"target_net.pt\", \n", 539 | " map_location=torch.device(self.device)))\n", 540 | "\n", 541 | "\n", 542 | " def save(self):\n", 543 | " os.makedirs(os.path.dirname(self.path), exist_ok=True)\n", 544 | " torch.save(self.online_net.state_dict(), self.path + \"online_net.pt\")\n", 545 | " torch.save(self.target_net.state_dict(), self.path + \"target_net.pt\")\n", 546 | "\n", 547 | "\n", 548 | " def 
get_score(self):\n", 549 | " # return np.mean(self.rewards[-5:])\n", 550 | " return 1\n", 551 | "\n", 552 | "\n", 553 | " def test(self):\n", 554 | " rewards = np.zeros(len(self.envs))\n", 555 | " for i in range(len(self.envs)):\n", 556 | " r = self.evaluate(i)\n", 557 | " rewards[i] = r\n", 558 | " return rewards\n", 559 | "\n", 560 | "\n", 561 | " def evaluate(self, i):\n", 562 | " rewards = 0\n", 563 | " state = self.envs[i].reset()\n", 564 | " while True:\n", 565 | " action = self.greedyPolicy(state)\n", 566 | " state_p, reward, done, _ = self.envs[i].step(action)\n", 567 | " rewards += reward\n", 568 | " if done:\n", 569 | " break\n", 570 | " state = state_p\n", 571 | " return rewards\n", 572 | "\n", 573 | "\n", 574 | " def greedyPolicy(self, state):\n", 575 | " with torch.no_grad():\n", 576 | " state = state.__array__()\n", 577 | " state = torch.tensor(state).unsqueeze(0).to(self.device)\n", 578 | " action = self.target_net(state).argmax().item()\n", 579 | " return action" 580 | ], 581 | "metadata": { 582 | "id": "weiDR_iqmYkB" 583 | }, 584 | "execution_count": null, 585 | "outputs": [] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "source": [ 590 | "class Federator:\n", 591 | " def __init__(self, env_fn, update_rate, path=\"./Mario/\", load=False) -> None:\n", 592 | " self.path = path\n", 593 | " self.envs = [\n", 594 | " 'SuperMarioBros-1-1-v0',\n", 595 | " 'SuperMarioBros-1-2-v0',\n", 596 | " 'SuperMarioBros-1-3-v0',\n", 597 | " 'SuperMarioBros-1-4-v0'\n", 598 | " ]\n", 599 | " self.global_agent = Mario(self.envs, env_fn, load=load, path=self.path)\n", 600 | "\n", 601 | " self.update_rate = update_rate\n", 602 | " self.n_agents = 4\n", 603 | " self.agents = []\n", 604 | " for i in range(self.n_agents):\n", 605 | " agent = Agent(i, self.envs[i], env_fn, load=load, path=self.path)\n", 606 | " self.agents.append(agent)\n", 607 | "\n", 608 | " if load:\n", 609 | " self.load()\n", 610 | " else:\n", 611 | " self.set_local_networks()\n", 612 | " self.rewards = []\n", 613 | "\n", 614 | "\n", 615 | " def load(self):\n", 616 | " with open(self.path + \"rewards.pkl\", 'rb') as f:\n", 617 | " self.rewards = pickle.load(f)\n", 618 | "\n", 619 | "\n", 620 | " def save(self):\n", 621 | " os.makedirs(os.path.dirname(self.path), exist_ok=True)\n", 622 | " with open(self.path + \"rewards.pkl\", \"wb\") as f:\n", 623 | " pickle.dump(self.rewards, f)\n", 624 | " self.global_agent.save()\n", 625 | " for agent in self.agents:\n", 626 | " agent.save()\n", 627 | " print(\"All Saved to \" + self.path)\n", 628 | "\n", 629 | " def train(self, n_runs):\n", 630 | " rewards = np.zeros((n_runs, len(self.envs)))\n", 631 | " for i in range(n_runs):\n", 632 | " print(\"Iteration: {}\".format(i+1))\n", 633 | " scores = []\n", 634 | " for agent in self.agents:\n", 635 | " agent.train(self.update_rate)\n", 636 | " scores.append(agent.get_score())\n", 637 | " self.aggregate_networks(scores)\n", 638 | " self.set_local_networks()\n", 639 | " rewards[i] = self.global_agent.test()\n", 640 | " print(rewards[i])\n", 641 | " self.save()\n", 642 | "\n", 643 | "\n", 644 | " def aggregate_networks(self, scores):\n", 645 | " sd_online = self.global_agent.online_net.state_dict()\n", 646 | " sd_target = self.global_agent.target_net.state_dict()\n", 647 | "\n", 648 | " online_dicts = []\n", 649 | " target_dicts = []\n", 650 | " for agent in self.agents:\n", 651 | " online_dicts.append(agent.online_net.state_dict())\n", 652 | " target_dicts.append(agent.target_net.state_dict())\n", 653 | "\n", 654 | " for key in sd_online:\n", 
655 | " sd_online[key] = torch.zeros_like(sd_online[key])\n", 656 | " for i, dict in enumerate(online_dicts):\n", 657 | " sd_online[key] += scores[i] * dict[key]\n", 658 | " sd_online[key] /= sum(scores)\n", 659 | "\n", 660 | " for key in sd_target:\n", 661 | " sd_target[key] = torch.zeros_like(sd_target[key])\n", 662 | " for i, dict in enumerate(target_dicts):\n", 663 | " sd_target[key] += scores[i] * dict[key]\n", 664 | " sd_target[key] /= sum(scores)\n", 665 | "\n", 666 | " self.global_agent.online_net.load_state_dict(sd_online)\n", 667 | " self.global_agent.target_net.load_state_dict(sd_target)\n", 668 | "\n", 669 | "\n", 670 | " def set_local_networks(self):\n", 671 | " for agent in self.agents:\n", 672 | " agent.online_net.load_state_dict(\n", 673 | " self.global_agent.online_net.state_dict())\n", 674 | " agent.target_net.load_state_dict(\n", 675 | " self.global_agent.target_net.state_dict())" 676 | ], 677 | "metadata": { 678 | "id": "RvupvVEmfga_" 679 | }, 680 | "execution_count": null, 681 | "outputs": [] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "source": [ 686 | "agent = Federator(create_mario_env, 200, load=True)\n", 687 | "agent.train(5)" 688 | ], 689 | "metadata": { 690 | "id": "13Si-HfXFpen" 691 | }, 692 | "execution_count": null, 693 | "outputs": [] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "source": [ 698 | "! cp -r ./Mario/ /content/drive/Shareddrives/Sam/" 699 | ], 700 | "metadata": { 701 | "id": "xfbDQKYyH9qP" 702 | }, 703 | "execution_count": null, 704 | "outputs": [] 705 | } 706 | ] 707 | } --------------------------------------------------------------------------------