├── README.md
├── div
│   ├── utils.py
│   ├── config_template.py
│   ├── render.py
│   └── run.py
├── test.py
├── requirements.txt
├── rl_algos
│   ├── _ALL_AGENTS.py
│   ├── DQN_TP.py
│   ├── AGENT.py
│   ├── DQN.py
│   ├── PPO.py
│   ├── REINFORCE.py
│   └── ACTOR_CRITIC.py
├── utils.py
├── .gitignore
├── main.py
├── METRICS.py
├── MEMORY.py
└── snake_env.py

/README.md:
--------------------------------------------------------------------------------
1 | # SnakeRL
2 | Reinforcement-learning agents (DQN, REINFORCE, Actor-Critic, PPO) for a custom Snake environment.
3 | 
--------------------------------------------------------------------------------
/div/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | 
3 | def s():
4 |     sys.exit()
--------------------------------------------------------------------------------
/div/config_template.py:
--------------------------------------------------------------------------------
1 | project="your_wandb_project_name"
2 | entity="your_wandb_account_name"
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #Snake Tutorial Python
2 | 
3 | from ia import Policy
4 | 
5 | 
6 | policy = Policy()
7 | policy.getReward(0)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | torchaudio
4 | torchsummary
5 | numpy
6 | matplotlib
7 | moviepy
8 | opencv-python
9 | pygame
10 | # tkinter ships with the Python standard library and cannot be installed with pip
11 | 
12 | gym
13 | box2d-py
14 | pyglet
15 | imageio-ffmpeg
16 | wandb
--------------------------------------------------------------------------------
/rl_algos/_ALL_AGENTS.py:
--------------------------------------------------------------------------------
1 | from rl_algos.REINFORCE import REINFORCE, REINFORCE_OFFPOLICY
2 | from rl_algos.DQN import DQN
3 | from rl_algos.ACTOR_CRITIC import ACTOR_CRITIC
4 | from rl_algos.PPO import PPO
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import random
4 | 
5 | import pygame
6 | import tkinter as tk
7 | from tkinter import messagebox
8 | 
9 | n_actions = 4
10 | 
11 | def OHE(vector): #np vector to a list containing a single 1 (at the argmax) and 0s elsewhere.
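    #e.g. OHE(np.array([0.1, 0.7, 0.2])) returns [0, 1, 0].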
12 |     indMax = np.argmax(vector)
13 |     return [int(k == indMax) for k in range(len(vector))]
14 | 
15 | def sumListOfGradients(listGradients):
16 |     listGradientsPure = [gradient for gradient in listGradients if gradient is not None]
17 |     if listGradientsPure == []: return None
18 |     return [sum([gradient[k] for gradient in listGradientsPure]) for k in range(len(listGradientsPure[0]))]
19 | 
20 | def multiplyGradients(grad, factor):
21 |     if grad is None:
22 |         return None
23 |     return [elem * factor for elem in grad]
24 | 
25 | def actionToVector(action):
26 |     actionVector = [0 for _ in range(n_actions)]
27 |     actionVector[action] = 1
28 |     return np.array(actionVector)
29 | 
30 | def actionToNumber(vector):
31 |     return np.argmax(vector)
32 | 
33 | def callModel(model, input):
34 |     #For model being V and input being s, return the float V(s)
35 |     return model(np.array([input]))[0]
--------------------------------------------------------------------------------
/div/render.py:
--------------------------------------------------------------------------------
1 | import gym
2 | 
3 | def render_agent(agent, env = gym.make("CartPole-v0"), episodes = 10, show_metrics = False):
4 |     '''Show some episodes of an agent performing on env.
5 |     agent : an agent
6 |     env : a gym env
7 |     episodes : number of episodes played
8 |     show_metrics : boolean, whether metrics returned by the agent are printed
9 |     '''
10 |     for episode in range(episodes):
11 |         done = False
12 |         obs = env.reset()
13 | 
14 |         while not done:
15 |             action = agent.act(obs)
16 |             next_obs, reward, done, info = env.step(action)
17 |             env.render()
18 |             metrics1 = agent.remember(obs, action, reward, done, next_obs, info)
19 |             metrics2 = agent.learn()
20 | 
21 |             if show_metrics:
22 |                 # print("\n\tMETRICS : ")
23 |                 for metrics in (metrics1 or []) + (metrics2 or []):
24 |                     for key, value in metrics.items():
25 |                         print(f"{key}: {value}")
26 | 
27 |             #If the episode ended, the while loop stops; otherwise move on to the next state.
28 |             if not done:
29 |                 obs = next_obs
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | venv*
2 | models
3 | rbdl
4 | videos
5 | config.py
6 | CONFIGS.py
7 | test*
8 | 
9 | # Byte-compiled / optimized / DLL files
10 | __pycache__/
11 | *.py[cod]
12 | *$py.class
13 | 
14 | # C extensions
15 | *.so
16 | 
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | pip-wheel-metadata/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 | 
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /rl_algos/DQN_TP.py: -------------------------------------------------------------------------------- 1 | from copy import copy, deepcopy 2 | import numpy as np 3 | import math 4 | import gym 5 | import sys 6 | import random 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.nn.functional as F 13 | import torchvision.transforms as T 14 | 15 | from MEMORY import Memory 16 | from CONFIGS import DQN_CONFIG 17 | from METRICS import * 18 | from rl_algos.AGENT import AGENT 19 | 20 | class DQN_TP(AGENT): 21 | '''DQN to fill for didactic purposes 22 | ''' 23 | 24 | def __init__(self, action_value : nn.Module): 25 | metrics = [MetricS_On_Learn, Metric_Reward, Metric_Total_Reward, Metric_Performances, Metric_Action_Frequencies] 26 | super().__init__(config = DQN_CONFIG, metrics = metrics) 27 | self.memory = ... 28 | 29 | self.action_value = ... 30 | self.action_value_target = ... 31 | self.opt = ... 32 | self.f_eps = lambda s : max(s.exploration_final, s.exploration_initial + (s.exploration_final - s.exploration_initial) * (s.step / s.exploration_timesteps)) 33 | 34 | 35 | def act(self, observation, greedy=False, mask = None): 36 | '''Ask the agent to take a decision given an observation. 37 | observation : an (n_obs,) shaped observation. 38 | greedy : whether the agent always choose the best Q values according to himself. 39 | mask : a binary list containing 1 where corresponding actions are forbidden. 40 | return : an int corresponding to an action 41 | ''' 42 | 43 | #Batching observation 44 | observations = ... 45 | 46 | # Q(s) 47 | Q = ... 
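        #One possible way to fill in the two blanks above (a sketch; rl_algos/DQN.py implements the finished agent) :
        #    observations = torch.Tensor(observation).unsqueeze(0)    # batch of one observation : (1, *obs_shape)
        #    Q = self.action_value(observations)                      # Q value of each action : (1, n_actions)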
48 | 49 | #Greedy policy 50 | epsilon = self.f_eps(self) 51 | if greedy or np.random.rand() > epsilon: 52 | ... 53 | 54 | #Exploration 55 | else : 56 | ... 57 | 58 | #Save metrics 59 | self.add_metric(mode = 'act') 60 | 61 | # Action 62 | return action 63 | 64 | 65 | def learn(self): 66 | '''Do one step of learning. 67 | ''' 68 | values = dict() 69 | self.step += 1 70 | 71 | #Learn only every train_freq steps 72 | #\optional 73 | 74 | #Learn only after learning_starts steps 75 | #\optional 76 | 77 | #Sample trajectories 78 | observations, actions, rewards, dones, next_observations = ... 79 | actions = actions.to(dtype = torch.int64) 80 | 81 | #Scaling the rewards 82 | #\optional 83 | 84 | # Estimated Q values 85 | Q_s_predicted = ... 86 | 87 | #Gradient descent on Q network 88 | criterion = nn.SmoothL1Loss() 89 | for _ in range(self.gradients_steps): 90 | ... 91 | 92 | #Update target network 93 | if self.update_method == "periodic": 94 | ... 95 | elif self.update_method == "soft": 96 | ... 97 | else: 98 | print(f"Error : update_method {self.update_method} not implemented.") 99 | sys.exit() 100 | 101 | #Save metrics* 102 | values["critic_loss"] = loss.detach().numpy() 103 | values["value"] = Q_s.mean().detach().numpy() 104 | self.add_metric(mode = 'learn', **values) 105 | 106 | 107 | def remember(self, observation, action, reward, done, next_observation, info={}, **param): 108 | '''Save elements inside memory. 109 | *arguments : elements to remember, as numerous and in the same order as in self.memory.MEMORY_KEYS 110 | ''' 111 | ... 112 | 113 | #Save metrics 114 | values = {"obs" : observation, "action" : action, "reward" : reward, "done" : done, "next_obs" : next_observation} 115 | self.add_metric(mode = 'remember', **values) 116 | -------------------------------------------------------------------------------- /div/run.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import torchvision.transforms as T 7 | 8 | import sys 9 | import math 10 | import random 11 | import os 12 | from moviepy.editor import * 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | 16 | import gym 17 | from gym.wrappers import Monitor as GymMonitor 18 | from stable_baselines3.common.monitor import Monitor 19 | from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder 20 | import wandb 21 | from wandb.integration.sb3 import WandbCallback 22 | import wandb 23 | 24 | def run(agent, env, episodes, wandb_cb = True, plt_cb = True, video_cb = True, 25 | n_ep_save_video = 100, 26 | n_render = 20 27 | ): 28 | 29 | print("Run starts.") 30 | config = agent.config 31 | ################### FEEDBACK ##################### 32 | if wandb_cb: 33 | try: 34 | from config import project, entity 35 | except ImportError: 36 | raise Exception("You need to specify your WandB ids in config.py\nConfig template is available at div/config_template.py") 37 | run = wandb.init(project=project, 38 | entity=entity, 39 | config=config 40 | ) 41 | if video_cb: 42 | videos_path = "./div/videos/rl-video/" 43 | env = GymMonitor(env, videos_path, video_callable=lambda ep: ep % n_ep_save_video == 0, force = True) 44 | if plt_cb: 45 | logs = dict() 46 | ##################### END FEEDBACK ################### 47 | 48 | 49 | for episode in range(1, episodes+1): 50 | done = False 51 | obs = env.reset() 52 | 53 | while not done: 54 | action = 
agent.act(obs) 55 | next_obs, reward, done, info = env.step(action) 56 | metrics1 = agent.remember(obs, action, reward, done, next_obs, info) 57 | metrics2 = agent.learn() 58 | 59 | ###### Feedback ###### 60 | print(f"Episode n°{episode} - Total step n°{agent.step} ...", end = '\r') 61 | if episode % n_render == 0: 62 | env.render() 63 | for metric in metrics1 + metrics2: 64 | if wandb_cb: 65 | wandb.log(metric, step = agent.step) 66 | if plt_cb: 67 | for key, value in metric.items(): 68 | try: 69 | logs[key]["steps"].append(agent.step) 70 | logs[key]["values"].append(value) 71 | except KeyError: 72 | logs[key] = {"steps": [agent.step], "values": [value]} 73 | plt.clf() 74 | plt.plot(logs[key]["steps"][-100:], logs[key]["values"][-100:], '-b') 75 | plt.title(key) 76 | plt.savefig(f"figures/{key}") 77 | ###### End Feedback ###### 78 | 79 | #If episode ended. 80 | if done: 81 | break 82 | else: 83 | obs = next_obs 84 | 85 | if wandb_cb: run.finish() 86 | print("End of run.") 87 | 88 | 89 | 90 | def run_for_sb3(create_agent, config, env, episodes, wandb_cb = True, video_cb = True): 91 | print("Run for SB3 agent starts.") 92 | 93 | env1 = deepcopy(env) 94 | def make_env(): 95 | env2 = deepcopy(env1) 96 | env2.reset() 97 | env2 = Monitor(env2) 98 | return env2 99 | env = DummyVecEnv([make_env]) 100 | 101 | #Wandb 102 | if wandb_cb: 103 | try: 104 | from config import project, entity 105 | except ImportError: 106 | raise Exception("You need to specify your WandB ids in config.py\nConfig template is available at div/config_template.py") 107 | run = wandb.init(project=project, 108 | entity=entity, 109 | sync_tensorboard=True, 110 | monitor_gym=True, 111 | config=config 112 | ) 113 | #Save videos of agent 114 | if video_cb: 115 | n_save = 5000 116 | video_path = f"div/videos/rl-videos-sb3/{run.id}" 117 | env = VecVideoRecorder(env, video_folder= video_path, record_video_trigger=lambda step: step % n_save == 0) 118 | 119 | agent = create_agent(env = env) 120 | agent.learn(total_timesteps=config["total_timesteps"], 121 | callback=WandbCallback( 122 | gradient_save_freq=100, 123 | model_save_path=f"div/models/{run.id}", 124 | verbose=2, 125 | ) if wandb_cb else None, 126 | ) 127 | 128 | 129 | if wandb_cb: run.finish() 130 | print("End of run.") 131 | return agent -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #Torch for deep learning 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import torchvision.transforms as T 7 | from torchsummary import summary 8 | #Python library 9 | import sys 10 | import math 11 | import random 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | #Gym for environments, WandB for feedback 15 | from snake_env import SnakeEnv, rows 16 | import gym 17 | import wandb 18 | #RL agents 19 | from div.utils import * 20 | try: 21 | from config import agent_name, steps, wandb_cb, n_render 22 | except ImportError: 23 | raise Exception("You need to specify your config in config.py\nConfig template is available at div/config_template.py") 24 | from rl_algos._ALL_AGENTS import REINFORCE, REINFORCE_OFFPOLICY, DQN, ACTOR_CRITIC, PPO 25 | from rl_algos.AGENT import RANDOM_AGENT 26 | 27 | 28 | def run(agent, env, steps, wandb_cb = True, 29 | n_render = 20 30 | ): 31 | '''Train an agent on an env. 
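    Each environment step below chains agent.act -> env.step -> agent.remember -> agent.learn,
    and logs the agent's metrics to WandB when wandb_cb is True.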
32 | agent : an AGENT instance (with methods act, learn and remember implemented) 33 | env : a gym env (with methods reset, step, render) 34 | steps : int, number of steps of training 35 | wandb_cb : bool, whether metrics are logged in WandB 36 | n_render : int, one episode on n_render is rendered 37 | ''' 38 | 39 | print("Run starts.") 40 | ################### FEEDBACK ##################### 41 | if n_render == None: n_render = float('inf') 42 | if wandb_cb: 43 | try: 44 | from config import project, entity 45 | except ImportError: 46 | raise Exception("You need to specify your WandB ids in config.py\nConfig template is available at div/config_template.py") 47 | run = wandb.init(project=project, 48 | entity=entity, 49 | config=agent.config, 50 | ) 51 | ##################### END FEEDBACK ################### 52 | episode = 1 53 | step = 0 54 | while step < steps: 55 | done = False 56 | obs = env.reset() 57 | 58 | 59 | while not done and step < steps: 60 | action = agent.act(obs) #Agent acts 61 | next_obs, reward, done, info = env.step(action) #Env reacts 62 | agent.remember(obs, action, reward, done, next_obs, info) #Agent saves previous transition in its memory 63 | agent.learn() #Agent learn (eventually) 64 | 65 | ###### Feedback ###### 66 | print(f"Episode n°{episode} - Total step n°{step} ...", end = '\r') 67 | if episode % n_render == 0: 68 | env.render() 69 | if wandb_cb: 70 | agent.log_metrics() 71 | ###### End Feedback ###### 72 | 73 | #If episode ended, reset env, else change state 74 | if done: 75 | step += 1 76 | episode += 1 77 | break 78 | else: 79 | step += 1 80 | obs = next_obs 81 | 82 | if wandb_cb: run.finish() #End wandb run. 83 | print("End of run.") 84 | 85 | 86 | 87 | 88 | if __name__ == "__main__": 89 | #ENV 90 | env = SnakeEnv() 91 | n_actions = 4 92 | n_flatten = (rows - 2 - 2)**2 93 | 94 | #ACTOR PI 95 | actor = nn.Sequential( 96 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1), 97 | nn.ReLU(), 98 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1), 99 | nn.ReLU(), 100 | nn.Flatten(), 101 | nn.Linear(n_flatten * 64, 32), 102 | nn.ReLU(), 103 | nn.Linear(32, n_actions), 104 | nn.Softmax(), 105 | ) 106 | 107 | #CRITIC Q 108 | action_value = nn.Sequential( 109 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1), 110 | nn.ReLU(), 111 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1), 112 | nn.ReLU(), 113 | nn.Flatten(), 114 | nn.Linear(n_flatten * 64, 32), 115 | nn.ReLU(), 116 | nn.Linear(32, n_actions), 117 | ) 118 | 119 | #STATE VALUE V 120 | state_value = nn.Sequential( 121 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1), 122 | nn.ReLU(), 123 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1), 124 | nn.ReLU(), 125 | nn.Flatten(), 126 | nn.Linear(n_flatten * 64, 32), 127 | nn.ReLU(), 128 | nn.Linear(32, 1), 129 | ) 130 | 131 | #AGENT 132 | agents = {'dqn' : DQN(action_value=action_value), 133 | 'reinforce' : REINFORCE(actor=actor), 134 | 'reinforce_offpolicy' : REINFORCE_OFFPOLICY(actor = actor), 135 | 'ppo' : PPO(actor = actor, state_value = state_value), 136 | 'ac' : ACTOR_CRITIC(actor = actor, state_value = state_value), 137 | 'random_agent' : RANDOM_AGENT(n_actions = 2), 138 | } 139 | agent = agents[agent_name] 140 | 141 | #RUN 142 | run(agent, 143 | env = env, 144 | steps=steps, 145 | wandb_cb = wandb_cb, 146 | n_render = n_render, 147 | ) 148 | -------------------------------------------------------------------------------- /METRICS.py: 
-------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from time import time 3 | 4 | 5 | class Metric(): 6 | def __init__(self): 7 | pass 8 | 9 | def on_learn(self, **kwargs): 10 | return dict() 11 | 12 | def on_remember(self, **kwargs): 13 | return dict() 14 | 15 | def on_act(self, **kwargs): 16 | return dict() 17 | 18 | 19 | class MetricS_On_Learn(Metric): 20 | '''Log every metrics whose name match classical RL important values such as Q_value, actor_loss ...''' 21 | metric_names = ["value", "Q_value", "V_value", "actor_loss", "critic_loss", ] 22 | def __init__(self, agent): 23 | super().__init__() 24 | self.agent = agent 25 | 26 | def on_learn(self, **kwargs): 27 | return {metric_name : kwargs[metric_name] for metric_name in self.metric_names if metric_name in kwargs} 28 | 29 | 30 | class MetricS_On_Learn_Numerical(Metric): 31 | '''Log every numerical metrics.''' 32 | def __init__(self, agent): 33 | super().__init__() 34 | self.agent = agent 35 | 36 | def on_learn(self, **kwargs): 37 | return {metric_name : kwargs[metric_name] for metric_name, value in kwargs.items() if isinstance(value, Number)} 38 | 39 | 40 | class Metric_Total_Reward(Metric): 41 | '''Log total reward (sum of reward over an episode) at every episode.''' 42 | def __init__(self, agent): 43 | super().__init__() 44 | self.agent = agent 45 | self.total_reward = 0 46 | self.new_episode = False 47 | 48 | def on_remember(self, **kwargs): 49 | try: 50 | if self.new_episode: 51 | self.total_reward = 0 52 | self.new_episode = False 53 | self.total_reward += kwargs["reward"] 54 | 55 | if kwargs["done"]: 56 | self.new_episode = True 57 | return {"total_reward" : self.total_reward} 58 | else: 59 | return dict() 60 | except KeyError: 61 | return dict() 62 | 63 | 64 | class Metric_Reward(Metric): 65 | def __init__(self, agent): 66 | super().__init__() 67 | self.agent = agent 68 | 69 | def on_remember(self, **kwargs): 70 | try: 71 | return {"reward" : kwargs["reward"]} 72 | except: 73 | return dict() 74 | 75 | 76 | class Metric_Epsilon(Metric): 77 | '''Log exploration factor.''' 78 | def __init__(self, agent): 79 | super().__init__() 80 | self.agent = agent 81 | 82 | def on_learn(self, **kwargs): 83 | try: 84 | return {"epsilon" : self.agent.f_eps(self.agent)} 85 | except: 86 | return dict() 87 | 88 | 89 | class Metric_Critic_Value_Unnormalized(Metric): 90 | '''Log value not scaled.''' 91 | def __init__(self, agent): 92 | super().__init__() 93 | self.agent = agent 94 | self.is_normalized = not( hasattr(agent, "reward_scaler") or agent.reward_scaler is None ) 95 | 96 | def on_learn(self, **kwargs): 97 | try: 98 | if self.is_normalized: 99 | return {"value_unnormalized" : self.agent.reward_scaler * kwargs["value"]} 100 | else: 101 | return {"value_unnormalized" : kwargs["value"]} 102 | except KeyError: 103 | return dict() 104 | 105 | 106 | class Metric_Action_Frequencies(Metric): 107 | '''Log action frequency in one episode for each action possible.''' 108 | def __init__(self, agent): 109 | super().__init__() 110 | self.agent = agent 111 | self.frequencies = dict() 112 | self.new_episode = False 113 | 114 | def on_remember(self, **kwargs): 115 | try: 116 | if self.new_episode: 117 | self.frequencies = dict() 118 | self.ep_lenght = 0 119 | self.new_episode = False 120 | action = kwargs["action"] 121 | if action not in self.frequencies: 122 | self.frequencies[action] = 0 123 | self.frequencies[action] += 1 124 | 125 | if kwargs["done"]: 126 | self.new_episode = True 127 | 
ep_lenght = sum(self.frequencies.values()) 128 | return {f"action_{a}_freq" : n_actions / ep_lenght for a, n_actions in self.frequencies.items()} 129 | else: 130 | return dict() 131 | except KeyError: 132 | return dict() 133 | 134 | 135 | class Metric_Count_Episodes(Metric): 136 | def __init__(self, agent): 137 | super().__init__() 138 | self.agent = agent 139 | self.n_episodes = 0 140 | 141 | def on_remember(self, **kwargs): 142 | try: 143 | if kwargs["done"]: 144 | self.n_episodes += 1 145 | return {"n_episodes" : self.n_episodes} 146 | else: 147 | return dict() 148 | except KeyError: 149 | return dict() 150 | 151 | 152 | class Metric_Time_Count(Metric): 153 | def __init__(self, agent): 154 | super().__init__() 155 | self.agent = agent 156 | self.t0 = time() 157 | 158 | def on_learn(self, **kwargs): 159 | return {"time" : round((time() - self.t0) / 60, 2)} 160 | 161 | 162 | class Metric_Performances(Metric): 163 | def __init__(self, agent): 164 | super().__init__() 165 | self.agent = agent 166 | self.t0 = time() 167 | def on_x(self, step_of_training : str): 168 | dur = time() - self.t0 169 | self.t0 = time() 170 | if self.agent.step < 10: 171 | return dict() 172 | return {step_of_training: dur} 173 | def on_act(self, **kwargs): 174 | return self.on_x("time : ACTING + LOGGING (+ RENDERING)") 175 | def on_remember(self, **kwargs): 176 | return self.on_x("time : ENV REACTING + REMEMBERING") 177 | def on_learn(self, **kwargs): 178 | return self.on_x("time : SAMPLING + LEARNING") 179 | 180 | 181 | -------------------------------------------------------------------------------- /rl_algos/AGENT.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | import wandb 4 | from random import randint 5 | from METRICS import * 6 | from div.utils import pr_and_raise, pr_shape 7 | 8 | class AGENT(ABC): 9 | 10 | def __init__(self, config = dict(), metrics = list()): 11 | self.step = 0 12 | self.episode = 0 13 | self.metrics = [Metric(self) for Metric in metrics] 14 | self.config = config 15 | for name, value in config.items(): 16 | setattr(self, name, value) 17 | self.metrics_saved = list() 18 | 19 | @abstractmethod 20 | def act(self, obs): 21 | pass 22 | 23 | @abstractmethod 24 | def learn(self): 25 | pass 26 | 27 | @abstractmethod 28 | def remember(self, **kwargs): 29 | pass 30 | 31 | def add_metric(self, mode, **values): 32 | if mode == 'act': 33 | for metric in self.metrics: 34 | self.metrics_saved.append(metric.on_act(**values)) 35 | if mode == 'remember': 36 | for metric in self.metrics: 37 | self.metrics_saved.append(metric.on_remember(**values)) 38 | if mode == 'learn': 39 | for metric in self.metrics: 40 | self.metrics_saved.append(metric.on_learn(**values)) 41 | 42 | def log_metrics(self): 43 | for metric in self.metrics_saved: 44 | wandb.log(metric, step = self.step) 45 | self.metrics_saved = list() 46 | 47 | def compute_TD(self, rewards, observations): 48 | '''Compute the n_step TD estimate of a state value, where n_step is an attribute of agent. 49 | rewards : a (T, 1) shaped torch tensor representing rewards 50 | observations : a (T, *dims) shaped torch tensor reprensting observations* 51 | 52 | ''' 53 | n = self.n_step 54 | 55 | #We compute the discounted sum of the n next rewards dynamically. 
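        #Target being estimated (a sketch of the recursion below) :
        #    V_target(s_t) = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1} + gamma^n * V(s_{t+n})
        #The backward pass keeps n_next_rewards[t] = sum_{k=0..n-1} gamma^k * r_{t+k} by adding
        #gamma * n_next_rewards[t+1] and subtracting gamma^n * r_{t+n} whenever that reward exists.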
56 | T = len(rewards) 57 | rewards = rewards[:, 0].numpy() 58 | n_next_rewards = [0 for _ in range(T)] + [0] 59 | t = T - 1 60 | while t >= 0: 61 | if t >= T - n: 62 | n_next_rewards[t] = rewards[t] + self.gamma * n_next_rewards[t+1] 63 | else: 64 | n_next_rewards[t] = rewards[t] + self.gamma * n_next_rewards[t+1] - (self.gamma ** n) * rewards[t+n] 65 | t -= 1 66 | n_next_rewards.pop(-1) 67 | n_next_rewards = torch.Tensor(n_next_rewards).unsqueeze(-1) 68 | 69 | #We compute the state value, and shift them forward in order to add them or not to the estimate. 70 | state_values = self.state_value(observations) 71 | state_values_to_add = torch.concat((state_values, torch.zeros(n, 1)), axis = 0)[n:] 72 | 73 | V_targets = n_next_rewards + state_values_to_add 74 | return V_targets 75 | 76 | 77 | def compute_critic(self, method, rewards, dones = None, observations = None, next_observations = None): 78 | '''Estimate some critic values such as advantage function A_pi_st_at of each transitions (s, a, r)t, noted A_s, on an episode. 79 | method : string, method used for advantage estimation, in (TD, MC, n_step, GAE) or (total_reward) 80 | *args : elements for computing A_s, torch tensor. 81 | return : A_s as a torch tensor of shape (T, 1) 82 | ''' 83 | if method == 'TD': 84 | rewards = rewards[:, 0] #(T,) 85 | values = list() 86 | t = len(rewards) - 1 87 | next_reward = 0 88 | while t >= 0: 89 | next_reward = rewards[t] + self.gamma * next_reward 90 | values.insert(0, next_reward) 91 | t -= 1 92 | res = torch.Tensor(values).unsqueeze(-1) 93 | 94 | elif method == 'A_MC': 95 | rewards = rewards[:, 0] #(T,) 96 | advantages = list() 97 | t = len(rewards) - 1 98 | next_reward = 0 99 | while t >= 0: 100 | next_reward = rewards[t] + self.gamma * next_reward 101 | advantages.insert(0, next_reward) 102 | t -= 1 103 | res = torch.Tensor(advantages).unsqueeze(-1) - self.state_value(observations) 104 | 105 | elif method == 'V_TD': 106 | res = rewards + (1 - dones) * self.gamma * self.state_value(next_observations) 107 | 108 | elif method == 'A_TD': 109 | res = rewards + (1 - dones) * self.gamma * self.state_value(next_observations) - self.state_value(observations) 110 | 111 | 112 | 113 | else: 114 | raise Exception(f"Method '{method}' for computing advantage estimate is not implemented.") 115 | 116 | return res 117 | 118 | #Use the following agent as a model for minimum restrictions on AGENT subclasses : 119 | class RANDOM_AGENT(AGENT): 120 | '''A random agent evolving in a discrete environment. 121 | n_actions : int, n of action space 122 | ''' 123 | def __init__(self, n_actions): 124 | super().__init__(metrics=[MetricS_On_Learn_Numerical, Metric_Performances]) #Choose metrics here 125 | self.n_actions = n_actions #For RandomAgent only 126 | 127 | def act(self, obs): 128 | #Choose action here 129 | ... 130 | action = randint(0, self.n_actions - 1) 131 | #Save metrics 132 | values = {"my_metric_name1" : 22, "my_metric_name2" : 42} 133 | self.add_metric(mode = 'act', **values) 134 | 135 | return action 136 | 137 | def learn(self): 138 | #Learn here 139 | ... 140 | #Save metrics 141 | self.step += 1 142 | values = {"my_metric_name1" : 22, "my_metric_name2" : 42} 143 | self.add_metric(mode = 'learn', **values) 144 | 145 | def remember(self, *args): 146 | #Save kwargs in memory here 147 | ... 
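        #A learning agent would typically store the transition in its replay memory at this point,
        #e.g. (sketch of the pattern used by the other agents of this repo) :
        #    self.memory.remember((observation, action, reward, done, next_observation))
        #RANDOM_AGENT keeps no memory, so nothing is saved here.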
148 | #Save metrics 149 | values = {"my_metric_name1" : 22, "my_metric_name2" : 42} 150 | self.add_metric(mode = 'remember', **values) -------------------------------------------------------------------------------- /rl_algos/DQN.py: -------------------------------------------------------------------------------- 1 | from copy import copy, deepcopy 2 | import numpy as np 3 | import math 4 | import gym 5 | import sys 6 | import random 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.nn.functional as F 13 | import torchvision.transforms as T 14 | 15 | from MEMORY import Memory 16 | from CONFIGS import DQN_CONFIG 17 | from METRICS import * 18 | from rl_algos.AGENT import AGENT 19 | 20 | class DQN(AGENT): 21 | 22 | def __init__(self, action_value : nn.Module): 23 | metrics = [MetricS_On_Learn, Metric_Total_Reward] 24 | super().__init__(config = DQN_CONFIG, metrics = metrics) 25 | self.memory = Memory(MEMORY_KEYS = ['observation', 'action','reward', 'done', 'next_observation']) 26 | 27 | self.action_value = action_value 28 | self.action_value_target = deepcopy(action_value) 29 | self.opt = optim.Adam(lr = self.learning_rate, params=action_value.parameters()) 30 | self.f_eps = lambda s : max(s.exploration_final, s.exploration_initial + (s.exploration_final - s.exploration_initial) * (s.step / s.exploration_timesteps)) 31 | 32 | 33 | def act(self, observation, greedy=False, mask = None): 34 | '''Ask the agent to take a decision given an observation. 35 | observation : an (n_obs,) shaped observation. 36 | greedy : whether the agent always choose the best Q values according to himself. 37 | mask : a binary list containing 1 where corresponding actions are forbidden. 38 | return : an int corresponding to an action 39 | ''' 40 | 41 | #Batching observation 42 | observations = torch.Tensor(observation) 43 | observations = observations.unsqueeze(0) # (1, observation_space) 44 | 45 | # Q(s) 46 | Q = self.action_value(observations) # (1, action_space) 47 | 48 | #Greedy policy 49 | epsilon = self.f_eps(self) 50 | if greedy or np.random.rand() > epsilon: 51 | with torch.no_grad(): 52 | if mask is not None: 53 | Q = Q - 10000.0 * torch.Tensor([mask]) # So that forbidden action won't ever be selected by the argmax. 54 | action = torch.argmax(Q, axis = -1).numpy()[0] 55 | 56 | #Exploration 57 | else : 58 | if mask is None: 59 | action = torch.randint(size = (1,), low = 0, high = Q.shape[-1]).numpy()[0] #Choose random action 60 | else: 61 | authorized_actions = [i for i in range(len(mask)) if mask[i] == 0] #Choose random action among authorized ones 62 | action = random.choice(authorized_actions) 63 | 64 | #Save metrics 65 | self.add_metric(mode = 'act') 66 | 67 | # Action 68 | return action 69 | 70 | 71 | def learn(self): 72 | '''Do one step of learning. 
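        Sketch of the update performed below : every train_freq steps (and only after learning_starts steps),
        sample a random batch from the replay memory, build the target
            y = r + gamma * max_a' Q_target(s', a') * (1 - done)        (or its Double-Q-Learning variant),
        minimize the Huber loss between Q(s, a) and y for gradients_steps gradient steps,
        then update the target network (periodic copy or soft update).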
73 | ''' 74 | values = dict() 75 | self.step += 1 76 | 77 | #Learn only every train_freq steps 78 | if self.step % self.train_freq != 0: 79 | return 80 | 81 | #Learn only after learning_starts steps 82 | if self.step < self.learning_starts: 83 | return 84 | 85 | #Sample trajectories 86 | observations, actions, rewards, dones, next_observations = self.memory.sample( 87 | sample_size=self.sample_size, 88 | method = "random", 89 | ) 90 | actions = actions.to(dtype = torch.int64) 91 | 92 | # print(observations.shape, actions, rewards, dones, sep = '\n\n') 93 | # raise 94 | 95 | #Scaling the rewards 96 | if self.reward_scaler is not None: 97 | rewards = rewards / self.reward_scaler 98 | 99 | # Estimated Q values 100 | if not self.double_q_learning: 101 | #Simple learning : Q(s,a) = r + gamma * max_a'(Q_target(s_next, a')) * (1-d) | s_next and r being the result of action a taken in observation s 102 | future_Q_s_a = self.action_value_target(next_observations) 103 | future_Q_s, bests_a = torch.max(future_Q_s_a, dim = 1, keepdim=True) 104 | Q_s_predicted = rewards + self.gamma * future_Q_s * (1 - dones) #(n_sampled,) 105 | else: 106 | #Double Q Learning : Q(s,a) = r + gamma * Q_target(s_next, argmax_a'(Q(s_next, a'))) 107 | future_Q_s_a = self.action_value(next_observations) 108 | future_Q_s, bests_a = torch.max(future_Q_s_a, dim = 1, keepdim=True) 109 | future_Q_s_a_target = self.action_value_target(next_observations) 110 | future_Q_s_target = torch.gather(future_Q_s_a_target, dim = 1, index= bests_a) 111 | 112 | Q_s_predicted = rewards + self.gamma * future_Q_s_target * (1 - dones) 113 | 114 | #Gradient descent on Q network 115 | criterion = nn.SmoothL1Loss() 116 | for _ in range(self.gradients_steps): 117 | self.opt.zero_grad() 118 | Q_s_a = self.action_value(observations) 119 | Q_s = Q_s_a.gather(dim = 1, index = actions) 120 | loss = criterion(Q_s, Q_s_predicted) 121 | loss.backward(retain_graph = True) 122 | if self.clipping is not None: 123 | for param in self.action_value.parameters(): 124 | param.grad.data.clamp_(-self.clipping, self.clipping) 125 | self.opt.step() 126 | 127 | #Update target network 128 | if self.update_method == "periodic": 129 | if self.step % self.target_update_interval == 0: 130 | self.action_value_target = deepcopy(self.action_value) 131 | elif self.update_method == "soft": 132 | for phi, phi_target in zip(self.action_value.parameters(), self.action_value_target.parameters()): 133 | phi_target.data = self.tau * phi_target.data + (1-self.tau) * phi.data 134 | else: 135 | print(f"Error : update_method {self.update_method} not implemented.") 136 | sys.exit() 137 | 138 | #Save metrics* 139 | values["critic_loss"] = loss.detach().numpy() 140 | values["value"] = Q_s.mean().detach().numpy() 141 | self.add_metric(mode = 'learn', **values) 142 | 143 | 144 | def remember(self, observation, action, reward, done, next_observation, info={}, **param): 145 | '''Save elements inside memory. 
146 | *arguments : elements to remember, as numerous and in the same order as in self.memory.MEMORY_KEYS 147 | ''' 148 | self.memory.remember((observation, action, reward, done, next_observation)) 149 | 150 | #Save metrics 151 | values = {"obs" : observation, "action" : action, "reward" : reward, "done" : done, "next_obs" : next_observation} 152 | self.add_metric(mode = 'remember', **values) 153 | 154 | -------------------------------------------------------------------------------- /MEMORY.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | import numpy as np 4 | import random as rd 5 | 6 | #Memory using tensor 7 | class Memory(): 8 | '''Memory class for keeping observations, actions, rewards, ... in memory. 9 | MEMORY_KEYS : a list of string, each string being the name of a kind of element to remember. 10 | max_memory_len : maximum memory lenght, no limit by default. 11 | ''' 12 | 13 | def __init__(self, MEMORY_KEYS: list, max_memory_len: int=None): 14 | self.max_memory_len = max_memory_len 15 | self.MEMORY_KEYS = MEMORY_KEYS 16 | self.trajectory = {} 17 | 18 | def remember(self, transition: tuple): 19 | '''Memorizes a transition and add it to the buffer. 20 | transition : a tuple of element corresponding to self.MEMORY_KEYS. 21 | ''' 22 | for val, key in zip(transition, self.MEMORY_KEYS): 23 | if type(val) == bool: 24 | val = int(val) 25 | val = torch.tensor(val) 26 | val = torch.unsqueeze(val, 0) 27 | if len(val.shape) == 1: 28 | val = torch.unsqueeze(val, 0) 29 | try: 30 | self.trajectory[key] = torch.concat((val, self.trajectory[key]), axis = 0) #(memory_lenght, n_?) 31 | except KeyError: 32 | self.trajectory[key] = val 33 | self.memory_len = len(self.trajectory[self.MEMORY_KEYS[0]]) 34 | 35 | 36 | 37 | def sample(self, sample_size=None, pos_start=None, method='last', func = None): 38 | '''Samples several transitions from memory, using different methods. 39 | sample_size : the number of transitions to sample, default all. 40 | pos_start : the position in the memory of the first transition sampled, default 0. 41 | method : the method of sampling in "all", "last", "random", "all_shuffled", "batch_shuffled". 42 | func : a function applied to each elements of the transitions, usually converting a np array into a pytorch/tf/jax tensor. 43 | return : a list containing a list of size sample_size for each kind of element stored. 
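        Illustrative use (keys as in this repo's agents) :
            memory = Memory(MEMORY_KEYS = ['observation', 'action', 'reward', 'done', 'next_observation'])
            memory.remember((obs, action, reward, done, next_obs))                     #once per env step
            obs_b, a_b, r_b, d_b, next_obs_b = memory.sample(sample_size = 32, method = 'random')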
44 | ''' 45 | if method == 'all': 46 | trajectory = [self.trajectory[key] for key in self.MEMORY_KEYS] 47 | 48 | elif method == 'last': 49 | trajectory = [self.trajectory[key][-sample_size:] 50 | for key in self.MEMORY_KEYS] 51 | 52 | elif method == 'random': 53 | indexes = np.random.permutation(self.memory_len)[:sample_size] 54 | trajectory = [self.trajectory[key][indexes] 55 | for key in self.MEMORY_KEYS] 56 | 57 | elif method == 'all_shuffled': 58 | trajectory = [self.trajectory[key][self.max_memory_len] 59 | for key in self.MEMORY_KEYS] 60 | 61 | elif method == 'batch_shuffled': 62 | trajectory = [self.trajectory[key][pos_start: pos_start + 63 | sample_size] 64 | for key in self.MEMORY_KEYS] 65 | 66 | else: 67 | raise NotImplementedError('Not implemented sample') 68 | 69 | if func is not None: 70 | trajectory = [func(elem) for elem in trajectory] 71 | return trajectory 72 | 73 | def __len__(self): 74 | return self.memory_len 75 | 76 | def __empty__(self): 77 | self.trajectory = {} 78 | 79 | 80 | #Memory using lists 81 | class Memory(): 82 | def __init__(self, MEMORY_KEYS: list, max_memory_len: int=float('inf')): 83 | self.max_memory_len = max_memory_len 84 | self.MEMORY_KEYS = MEMORY_KEYS 85 | self.trajectory = {key : list() for key in MEMORY_KEYS} 86 | 87 | def remember(self, transition: tuple): 88 | '''Memorizes a transition and add it to the buffer. Complexity = O(size_transition) 89 | transition : a tuple of element corresponding to self.MEMORY_KEYS. 90 | ''' 91 | for val, key in zip(transition, self.MEMORY_KEYS): 92 | if type(val) == bool: val = int(val) 93 | self.trajectory[key].append(val) 94 | if len(self) > self.max_memory_len: 95 | for val, key in zip(transition, self.MEMORY_KEYS): 96 | self.trajectory[key].pop() 97 | 98 | def sample(self, sample_size=None, pos_start=None, method='last'): 99 | '''Samples several transitions from memory, using different methods. Complexity = O(sample_size x transition_size) 100 | sample_size : the number of transitions to sample, default all. 101 | pos_start : the position in the memory of the first transition sampled, default 0. 102 | method : the method of sampling in "all", "last", "random", "all_shuffled", "batch_shuffled", "batch". 103 | return : a list containing a list of size sample_size for each kind of element stored. 104 | ''' 105 | if sample_size is None: 106 | sample_size = len(self) 107 | else: 108 | sample_size = min(sample_size, len(self)) 109 | 110 | if method == 'all': 111 | #Each elements in order. 112 | indexes = [n for n in range(len(self))] 113 | 114 | elif method == 'last': 115 | #sample_size last elements in order. 116 | indexes = [n for n in range(len(self) - sample_size, len(self))] 117 | 118 | elif method == 'random': 119 | #sample_size elements sampled. 120 | indexes = [rd.randint(0, len(self) - 1) for _ in range(sample_size)] 121 | 122 | elif method == 'all_shuffled': 123 | #Each elements suffled. 124 | indexes = [n for n in range(len(self))] 125 | rd.shuffle(indexes) 126 | 127 | elif method == "batch": 128 | #Element n° pos_start and sample_size next elements, in order. 129 | indexes = [pos_start + n for n in range(sample_size)] 130 | 131 | elif method == 'batch_shuffled': 132 | #Element n° pos_start and sample_size next elements, shuffled. 
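            #e.g. pos_start = 10 and sample_size = 4 give a shuffled permutation of [10, 11, 12, 13].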
133 | indexes = [pos_start + n for n in range(sample_size)] 134 | rd.shuffle(indexes) 135 | 136 | else: 137 | raise NotImplementedError('Not implemented sample') 138 | 139 | trajectory = list() 140 | for elements in self.trajectory.values(): 141 | sampled_elements = torch.tensor(np.array([elements[idx] for idx in indexes])) 142 | if len(sampled_elements.shape) == 1: 143 | sampled_elements = torch.unsqueeze(sampled_elements, -1) 144 | trajectory.append(sampled_elements) 145 | 146 | return trajectory 147 | 148 | def __len__(self): 149 | return len(self.trajectory[self.MEMORY_KEYS[0]]) 150 | 151 | def __empty__(self): 152 | self.trajectory = {key : list() for key in self.MEMORY_KEYS} 153 | 154 | -------------------------------------------------------------------------------- /snake_env.py: -------------------------------------------------------------------------------- 1 | from time import sleep 2 | import gym 3 | import math 4 | import numpy as np 5 | import random 6 | import matplotlib.pyplot as plt 7 | import pygame 8 | from utils import * 9 | 10 | n_actions = 4 11 | rows = 5 12 | width = 500 13 | 14 | snackReward = 30 15 | nothingReward = -1 16 | deathReward = 0 17 | delay = 200 18 | GREEN = (0,255,0) 19 | RED = (255,0,0) 20 | 21 | 22 | 23 | class cube(object): 24 | 25 | def __init__(self,start,dirnx=1,dirny=0,color=(255,0,0)): 26 | self.pos = start 27 | self.w = width 28 | self.rows = rows 29 | self.dirnx = 1 30 | self.dirny = 0 31 | self.color = color 32 | 33 | 34 | def move(self, dirnx, dirny): 35 | self.dirnx = dirnx 36 | self.dirny = dirny 37 | self.pos = (self.pos[0] + self.dirnx, self.pos[1] + self.dirny) 38 | 39 | def draw(self, surface, eyes=False): 40 | dis = self.w // rows 41 | i = self.pos[0] 42 | j = self.pos[1] 43 | 44 | pygame.draw.rect(surface, self.color, (i*dis+1,j*dis+1, dis-2, dis-2)) 45 | if eyes: 46 | centre = dis//2 47 | radius = 3 48 | circleMiddle = (i*dis+centre-radius,j*dis+8) 49 | circleMiddle2 = (i*dis + dis -radius*2, j*dis+8) 50 | pygame.draw.circle(surface, (0,0,0), circleMiddle, radius) 51 | pygame.draw.circle(surface, (0,0,0), circleMiddle2, radius) 52 | 53 | 54 | class snake(object): 55 | body = [] 56 | turns = {} 57 | def __init__(self, color, pos): 58 | self.color = color 59 | self.head = cube(pos) 60 | self.body.append(self.head) 61 | self.dirnx = 0 62 | self.dirny = 1 63 | 64 | def move(self, action): 65 | 66 | for event in pygame.event.get(): 67 | if event.type == pygame.QUIT: 68 | pygame.quit() 69 | 70 | #Choose action 71 | keyNumbers = [pygame.K_LEFT, pygame.K_RIGHT, pygame.K_UP, pygame.K_DOWN] 72 | keys = {keyNumbers[k] : int(k == action) for k in range(n_actions)} 73 | 74 | # Move accordingly to action chosen 75 | if keys[pygame.K_LEFT]: 76 | self.dirnx = -1 77 | self.dirny = 0 78 | self.turns[self.head.pos[:]] = [self.dirnx, self.dirny] 79 | 80 | elif keys[pygame.K_RIGHT]: 81 | self.dirnx = 1 82 | self.dirny = 0 83 | self.turns[self.head.pos[:]] = [self.dirnx, self.dirny] 84 | 85 | elif keys[pygame.K_UP]: 86 | self.dirnx = 0 87 | self.dirny = -1 88 | self.turns[self.head.pos[:]] = [self.dirnx, self.dirny] 89 | 90 | elif keys[pygame.K_DOWN]: 91 | self.dirnx = 0 92 | self.dirny = 1 93 | self.turns[self.head.pos[:]] = [self.dirnx, self.dirny] 94 | 95 | #Move the entire snake body 96 | for i, c in enumerate(self.body): 97 | p = c.pos[:] 98 | if p in self.turns: 99 | turn = self.turns[p] 100 | 101 | c.dirnx, c.dirny = turn 102 | if c.dirnx == -1 and c.pos[0] <= 0: c.pos = (c.rows-1, c.pos[1]) 103 | elif c.dirnx == 1 and c.pos[0] >= c.rows-1: c.pos = 
(0,c.pos[1]) 104 | elif c.dirny == 1 and c.pos[1] >= c.rows-1: c.pos = (c.pos[0], 0) 105 | elif c.dirny == -1 and c.pos[1] <= 0: c.pos = (c.pos[0],c.rows-1) 106 | else: c.move(c.dirnx,c.dirny) 107 | 108 | if i == len(self.body)-1: 109 | self.turns.pop(p) 110 | else: 111 | if c.dirnx == -1 and c.pos[0] <= 0: c.pos = (c.rows-1, c.pos[1]) 112 | elif c.dirnx == 1 and c.pos[0] >= c.rows-1: c.pos = (0,c.pos[1]) 113 | elif c.dirny == 1 and c.pos[1] >= c.rows-1: c.pos = (c.pos[0], 0) 114 | elif c.dirny == -1 and c.pos[1] <= 0: c.pos = (c.pos[0],c.rows-1) 115 | else: c.move(c.dirnx,c.dirny) 116 | 117 | 118 | 119 | def reset(self, pos): 120 | self.head = cube(pos) 121 | self.body = [] 122 | self.body.append(self.head) 123 | self.turns = {} 124 | self.dirnx = 0 125 | self.dirny = 1 126 | 127 | 128 | def addCube(self): 129 | tail = self.body[-1] 130 | dx, dy = tail.dirnx, tail.dirny 131 | 132 | if dx == 1 and dy == 0: 133 | self.body.append(cube(((tail.pos[0]-1)%rows,(tail.pos[1])%rows))) 134 | elif dx == -1 and dy == 0: 135 | self.body.append(cube(((tail.pos[0]+1)%rows,(tail.pos[1])%rows))) 136 | elif dx == 0 and dy == 1: 137 | self.body.append(cube(((tail.pos[0])%rows,(tail.pos[1]-1)%rows))) 138 | elif dx == 0 and dy == -1: 139 | self.body.append(cube(((tail.pos[0])%rows,(tail.pos[1]+1)%rows))) 140 | 141 | self.body[-1].dirnx = dx 142 | self.body[-1].dirny = dy 143 | 144 | 145 | def draw(self, surface): 146 | for i, c in enumerate(self.body): 147 | if i ==0: 148 | c.draw(surface, True) 149 | else: 150 | c.draw(surface) 151 | 152 | 153 | 154 | 155 | 156 | def drawGrid(w, rows, surface): 157 | sizeBtwn = w // rows 158 | 159 | x = 0 160 | y = 0 161 | for l in range(rows): 162 | x = x + sizeBtwn 163 | y = y + sizeBtwn 164 | 165 | pygame.draw.line(surface, (255,255,255), (x,0),(x,w)) 166 | pygame.draw.line(surface, (255,255,255), (0,y),(w,y)) 167 | 168 | 169 | def randomSnack(rows, item): 170 | 171 | positions = item.body 172 | 173 | while True: 174 | x = random.randrange(rows) 175 | y = random.randrange(rows) 176 | if len(list(filter(lambda z:z.pos == (x,y), positions))) > 0: 177 | continue 178 | else: 179 | break 180 | 181 | return (x,y) 182 | 183 | 184 | 185 | 186 | class SnakeEnv(gym.Env): 187 | metadata = {'render.modes': ['human']} 188 | 189 | def __init__(self) -> None: 190 | super().__init__() 191 | self.win = pygame.display.set_mode((width, width)) 192 | self.snake = snake(RED, (rows//2,rows//2)) 193 | self.snack = cube(randomSnack(rows, self.snake), color=GREEN) 194 | 195 | 196 | def step(self, action): 197 | done = False 198 | s = self.snake 199 | 200 | s.move(action) 201 | #If snack meet his own body, the game end. 202 | for x in range(len(s.body)): 203 | if s.body[x].pos in list(map(lambda z:z.pos,s.body[x+1:])): 204 | reward = deathReward 205 | s.reset((rows//2,rows//2)) 206 | done = True 207 | break 208 | 209 | #If the snake meet a snack, he gains a piece of body, a new snack is generated and AI get rewarded. 210 | if not done and s.body[0].pos == self.snack.pos: 211 | if len(s.body) >= rows**2 - 2: #If map is full of the body, game end. 
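                #With rows = 5 this triggers once the snake occupies rows**2 - 2 = 23 of the 25 cells.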
212 | done = True 213 | s.addCube() 214 | self.snack = cube(randomSnack(rows, s), color=GREEN) 215 | reward = snackReward 216 | else: 217 | reward = nothingReward 218 | 219 | next_obs = self.readState() 220 | info = dict() 221 | return next_obs, reward, done, info 222 | 223 | 224 | def reset(self): 225 | self.snake.reset((rows//2,rows//2)) 226 | obs = self.readState() 227 | return obs 228 | 229 | def render(self): 230 | surface = self.win 231 | surface.fill((0,0,0)) 232 | self.snake.draw(surface) 233 | self.snack.draw(surface) 234 | drawGrid(width,rows, surface) 235 | pygame.display.update() 236 | sleep(0.02) 237 | 238 | def readState(self): 239 | arr = np.array(np.zeros((3, rows, rows)), dtype=np.float32) 240 | s = self.snake 241 | arr[0, s.head.pos[0],s.head.pos[1]] = 1. 242 | for bodyPart in s.body: 243 | arr[1, bodyPart.pos[0], bodyPart.pos[1]] = 1. 244 | arr[2, self.snack.pos[0], self.snack.pos[1]] = 1. 245 | return arr 246 | 247 | 248 | -------------------------------------------------------------------------------- /rl_algos/PPO.py: -------------------------------------------------------------------------------- 1 | from copy import copy, deepcopy 2 | import numpy as np 3 | import math 4 | import gym 5 | import sys 6 | import random 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.nn.functional as F 13 | import torchvision.transforms as T 14 | from torch.distributions.categorical import Categorical 15 | 16 | from div.utils import * 17 | from MEMORY import Memory 18 | from CONFIGS import PPO_CONFIG 19 | from METRICS import * 20 | from rl_algos.AGENT import AGENT 21 | 22 | class PPO(AGENT): 23 | '''PPO updates its networks without changing too much the policy, which increases stability. 24 | NN trained : Actor Critic 25 | Policy used : Off-policy 26 | Online : Yes 27 | Stochastic : Yes 28 | Actions : discrete (continuous not implemented) 29 | States : continuous (discrete not implemented) 30 | ''' 31 | 32 | def __init__(self, actor : nn.Module, state_value : nn.Module): 33 | metrics = [MetricS_On_Learn, Metric_Total_Reward, Metric_Count_Episodes] 34 | super().__init__(config = PPO_CONFIG, metrics = metrics) 35 | self.memory_transition = Memory(MEMORY_KEYS = ['observation', 'action','reward', 'done', 'prob']) 36 | self.memory_episodes = Memory(MEMORY_KEYS = ['episode']) 37 | 38 | self.state_value = state_value 39 | self.state_value_target = deepcopy(state_value) 40 | self.opt_critic = optim.Adam(lr = self.learning_rate_critic, params=self.state_value.parameters()) 41 | 42 | self.policy = actor 43 | 44 | self.last_prob = None 45 | self.episode_ended = False 46 | 47 | 48 | def act(self, observation, mask = None): 49 | '''Ask the agent to take a decision given an observation. 50 | observation : an (n_obs,) shaped numpy observation. 51 | mask : a binary list containing 1 where corresponding actions are forbidden. 52 | return : an int corresponding to an action 53 | ''' 54 | 55 | #Batching observation 56 | observation = torch.Tensor(observation) 57 | observations = observation.unsqueeze(0) # (1, observation_space) 58 | probs = self.policy(observations) # (1, n_actions) 59 | distribs = Categorical(probs = probs) 60 | actions = distribs.sample() 61 | action = actions.numpy()[0] 62 | 63 | #Save metrics 64 | self.add_metric(mode = 'act') 65 | 66 | # Action 67 | self.last_prob = probs[0, action].detach() 68 | return action 69 | 70 | 71 | def learn(self): 72 | '''Do one step of learning. 
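        Sketch of the update performed below : every train_freq_episode episodes, replay the last n_episodes
        episodes, compute advantages A_t and n-step V targets, then maximize over `epochs` epochs of
        minibatches of size batch_size the clipped objective
            J = E[ min(r_t * A_t, clip(r_t, 1 - epsilon_clipper, 1 + epsilon_clipper) * A_t) ]
                - c_critic * critic_loss + c_entropy * H,
        with r_t = pi_new(a_t|s_t) / pi_old(a_t|s_t), before replacing the old policy by the new one
        and updating the target state-value network.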
73 | ''' 74 | values = dict() 75 | self.step += 1 76 | 77 | #Learn every n end of episodes 78 | if not self.episode_ended: 79 | return 80 | self.episode += 1 81 | self.episode_ended = False 82 | if self.episode % self.train_freq_episode != 0: 83 | return 84 | 85 | #Sample trajectories 86 | episodes = self.memory_episodes.sample( 87 | method = "last", 88 | sample_size=self.n_episodes, 89 | as_tensor=False, 90 | ) 91 | episodes = episodes[0] 92 | 93 | #Compute A_s and V_s estimates and concatenate trajectories. 94 | advantages = list() 95 | V_targets = list() 96 | for observations, actions, rewards, dones, probs in episodes: 97 | #Scaling the rewards 98 | if self.reward_scaler is not None: 99 | rewards = rewards / self.reward_scaler 100 | #Compute V and A 101 | advantages.append(self.compute_critic(self.compute_advantage_method, observations = observations, rewards = rewards, dones = dones)) 102 | V_targets = self.compute_TD(rewards, observations) 103 | advantages = torch.concat(advantages, axis = 0).detach() 104 | V_targets = torch.concat(V_targets, axis = 0).detach() 105 | observations, actions, rewards, dones, probs = [torch.concat([episode[elem] for episode in episodes], axis = 0) for elem in range(len(episodes[0]))] 106 | 107 | #Shuffling data 108 | indexes = torch.randperm(len(rewards)) 109 | observations, actions, rewards, dones, probs, advantages, V_targets = \ 110 | [element[indexes] for element in [observations, 111 | actions, 112 | rewards, 113 | dones, 114 | probs, 115 | advantages, 116 | V_targets, 117 | ]] 118 | 119 | #Type bug fixes 120 | actions = actions.to(dtype = torch.int64) 121 | rewards = rewards.float() 122 | 123 | #We perform gradient descent on K epochs on T datas with minibatch of size M <= T. 124 | policy_new = deepcopy(self.policy) 125 | opt_policy = optim.Adam(lr = self.learning_rate_actor, params=policy_new.parameters()) 126 | n_batch = math.ceil(len(observations) / self.batch_size) 127 | 128 | for _ in range(self.epochs): 129 | for i in range(n_batch): 130 | #Batching data 131 | observations_batch = observations[i * self.batch_size : (i+1) * self.batch_size] 132 | actions_batch = actions[i * self.batch_size : (i+1) * self.batch_size] 133 | probs_batch = probs[i * self.batch_size : (i+1) * self.batch_size] 134 | advantages_batch = advantages[i * self.batch_size : (i+1) * self.batch_size] 135 | V_targets_batch = V_targets[i * self.batch_size : (i+1) * self.batch_size] 136 | 137 | #Objective function : J_clip = min(r*A, clip(r,1-e,1+e)A) where r = pi_theta_new/pi_theta_old and A advantage function 138 | pi_theta_new_s_a = policy_new(observations_batch) 139 | pi_theta_new_s = torch.gather(pi_theta_new_s_a, dim = 1, index = actions_batch) 140 | ratio_s = pi_theta_new_s / probs_batch 141 | ratio_s_clipped = torch.clamp(ratio_s, 1 - self.epsilon_clipper, 1 + self.epsilon_clipper) 142 | J_clip = torch.minimum(ratio_s * advantages_batch, ratio_s_clipped * advantages_batch).mean() 143 | 144 | #Error on critic : L = L(V(s), V_target) with V_target = r + gamma * (1-d) * V_target(s_next) 145 | V_s = self.state_value(observations_batch) 146 | critic_loss = F.smooth_l1_loss(V_s, V_targets_batch).mean() 147 | 148 | #Entropy : H = sum_a(- log(p) * p) where p = pi_theta(a|s) 149 | log_pi_theta_s_a = torch.log(pi_theta_new_s_a) 150 | pmlogp_s_a = - log_pi_theta_s_a * pi_theta_new_s_a 151 | H_s = torch.sum(pmlogp_s_a, dim = 1) 152 | H = H_s.mean() 153 | 154 | #Total objective function 155 | J = J_clip - self.c_critic * critic_loss + self.c_entropy * H 156 | loss = - J 157 | 158 | 
#Gradient descend 159 | opt_policy.zero_grad() 160 | self.opt_critic.zero_grad() 161 | loss.backward(retain_graph = True) 162 | opt_policy.step() 163 | self.opt_critic.step() 164 | 165 | 166 | #Update policy 167 | self.policy = deepcopy(policy_new) 168 | 169 | #Update target network 170 | if self.update_method == "periodic": 171 | if self.step % self.target_update_interval == 0: 172 | self.state_value_target = deepcopy(self.state_value) 173 | elif self.update_method == "soft": 174 | for phi, phi_target in zip(self.state_value.parameters(), self.state_value_target.parameters()): 175 | phi_target.data = self.tau * phi_target.data + (1-self.tau) * phi.data 176 | else: 177 | print(f"Error : update_method {self.update_method} not implemented.") 178 | sys.exit() 179 | 180 | #Save metrics 181 | values["critic_loss"] = critic_loss.detach().numpy() 182 | values["J_clip"] = J_clip.detach().numpy() 183 | values["value"] = V_s.mean().detach().numpy() 184 | values["entropy"] = H.mean().detach().numpy() 185 | self.add_metric(mode = 'learn', **values) 186 | 187 | 188 | def remember(self, observation, action, reward, done, next_observation, info={}, **param): 189 | '''Save elements inside memory. 190 | *arguments : elements to remember, as numerous and in the same order as in self.memory.MEMORY_KEYS 191 | return : metrics, a list of metrics computed during this remembering step. 192 | ''' 193 | prob = self.last_prob.detach() 194 | self.memory_transition.remember((observation, action, reward, done, prob, info)) 195 | if done: 196 | self.episode_ended = True 197 | episode = self.memory_transition.sample(method = 'all', as_tensor=True) 198 | self.memory_transition.__empty__() 199 | self.memory_episodes.remember((episode,)) 200 | 201 | #Save metrics 202 | values = {"obs" : observation, "action" : action, "reward" : reward, "done" : done} 203 | self.add_metric(mode = 'remember', **values) -------------------------------------------------------------------------------- /rl_algos/REINFORCE.py: -------------------------------------------------------------------------------- 1 | from copy import copy, deepcopy 2 | import numpy as np 3 | import math 4 | import gym 5 | import sys 6 | import random 7 | import matplotlib.pyplot as plt 8 | from div.utils import * 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | import torchvision.transforms as T 15 | from torch.distributions.categorical import Categorical 16 | 17 | from MEMORY import Memory 18 | from CONFIGS import REINFORCE_CONFIG 19 | from METRICS import Metric_Performances, Metric_Time_Count, Metric_Total_Reward, MetricS_On_Learn 20 | from rl_algos.AGENT import AGENT 21 | 22 | 23 | class REINFORCE(AGENT): 24 | '''REINFORCE agent is an actor RL agent that performs gradient ascends on the estimated objective function to maximize. 
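    The gradient estimate used in learn() is grad(J) ~ sum_t G_t * grad(log pi(a_t|s_t)),
    with G_t the discounted return from step t, averaged over a batch of batch_size episodes.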
25 |     NN trained : Actor
26 |     Policy used : On-policy
27 |     Stochastic : Yes
28 |     Actions : discrete (continuous not implemented)
29 |     States : continuous (discrete not implemented)
30 |     '''
31 | 
32 |     def __init__(self, actor : nn.Module):
33 |         metrics = [MetricS_On_Learn, Metric_Total_Reward, Metric_Time_Count]
34 |         super().__init__(config = REINFORCE_CONFIG, metrics = metrics)
35 |         self.memory = Memory(MEMORY_KEYS = ['observation', 'action','reward', 'done'])
36 | 
37 |         self.policy = actor
38 |         self.opt = optim.Adam(lr = 1e-4, params=self.policy.parameters())
39 | 
40 |         self.episode = 0
41 |         self.must_learn = False
42 | 
43 |     def act(self, observation, mask = None):
44 |         '''Ask the agent to take a decision given an observation.
45 |         observation : an (n_obs,) shaped numpy observation.
46 |         mask : a binary list containing 1 where corresponding actions are forbidden.
47 |         return : an int corresponding to an action
48 |         '''
49 | 
50 |         #Batching observation
51 |         observation = torch.Tensor(observation)
52 |         observations = observation.unsqueeze(0) # (1, observation_space)
53 |         probs = self.policy(observations) # (1, n_actions)
54 |         distribs = Categorical(probs = probs)
55 |         actions = distribs.sample()
56 |         action = actions.numpy()[0]
57 | 
58 |         #Save metrics
59 |         self.add_metric(mode = 'act')
60 | 
61 |         # Action
62 |         return action
63 | 
64 | 
65 |     def learn(self):
66 |         '''Do one step of learning.
67 |         return : metrics, a list of metrics computed during this learning step.
68 |         '''
69 |         values = dict()
70 |         self.step += 1
71 | 
72 |         #Learn every batch_size episodes
73 |         if not self.must_learn:
74 |             return
75 |         self.must_learn = False
76 | 
77 |         #Sample trajectories
78 |         batches = self.memory.sample(
79 |             method = "episodic_batches",
80 |             )
81 | 
82 |         #Compute mean value of gradients over a batch
83 |         for _ in range(self.gradient_steps):
84 |             loss_mean = torch.tensor(0.)
85 | 
86 |             for observations, actions, rewards, dones in batches:
87 | 
88 |                 #Some actions dtype problem
89 |                 actions = actions.to(dtype = torch.int64)
90 | 
91 |                 #Scaling the rewards (not in place, so the stored tensors are not rescaled again on later gradient steps)
92 |                 if self.reward_scaler is not None:
93 |                     rewards = rewards / self.reward_scaler
94 | 
95 |                 #Compute Gt the discounted sum of future rewards : G_t = r_t + gamma * G_{t+1}, built backwards
96 |                 ep_length = rewards.shape[0] #T
97 |                 rewards = rewards[:,0].numpy().tolist()
98 |                 G = [rewards[-1]]
99 |                 for i in range(1, ep_length):
100 |                     t = ep_length - 1 - i
101 |                     previous_G_t = rewards[t] + self.gamma * G[0]
102 |                     G.insert(0, previous_G_t)
103 |                 G = torch.tensor(G)
104 | 
105 |                 #Compute log probs
106 |                 probs = self.policy(observations) #(T, n_actions)
107 |                 probs = torch.gather(probs, dim = 1, index = actions) #(T, 1)
108 |                 log_probs = torch.log(probs)[:,0] #(T,)
109 | 
110 |                 #Compute loss = -sum_t( G_t * log_proba_t ) and add it to mean loss
111 |                 loss = torch.multiply(log_probs, G)
112 |                 loss = - torch.sum(loss)
113 |                 loss_mean += loss / self.batch_size
114 | 
115 |             #Backpropagate to improve policy
116 |             self.opt.zero_grad()
117 |             loss_mean.backward()
118 |             self.opt.step()
119 | 
120 |         self.memory.__empty__()
121 | 
122 |         #Save metrics
123 |         values["actor_loss"] = loss.detach().numpy()
124 |         self.add_metric(mode = 'learn', **values)
125 | 
126 | 
127 | 
128 |     def remember(self, observation, action, reward, done, next_observation, info={}, **param):
129 |         '''Save elements inside memory.
130 |         *arguments : elements to remember, as numerous and in the same order as in self.memory.MEMORY_KEYS
131 |         return : metrics, a list of metrics computed during this remembering step.
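        Example (hypothetical usage sketch, not taken from this repository ; it assumes a Gym-style
        environment env and an actor network my_actor_network whose forward pass returns action probabilities) :
            agent = REINFORCE(actor=my_actor_network)
            obs = env.reset()
            done = False
            while not done:
                action = agent.act(obs)
                next_obs, reward, done, info = env.step(action)
                agent.remember(obs, action, reward, done, next_obs, info)
                agent.learn()
                obs = next_obs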
132 | ''' 133 | self.memory.remember((observation, action, reward, done, next_observation, info)) 134 | if done: 135 | self.episode += 1 136 | if self.episode % self.batch_size == 0: 137 | self.must_learn = True 138 | 139 | #Save metrics 140 | values = {"obs" : observation, "action" : action, "reward" : reward, "done" : done} 141 | self.add_metric(mode = 'remember', **values) 142 | 143 | 144 | 145 | 146 | 147 | 148 | class REINFORCE_OFFPOLICY(AGENT): 149 | '''REINFORCE agent is an actor RL agent that performs gradient ascends on the estimated objective function to maximize. 150 | The offpolicy version add importances weights to keep an unbiased gradient. 151 | NN trained : Actor 152 | Policy used : Off-policy 153 | Stochastic : Yes 154 | Actions : discrete (continuous not implemented) 155 | States : continuous 156 | ''' 157 | 158 | def __init__(self, actor : nn.Module): 159 | metrics = [MetricS_On_Learn, Metric_Total_Reward, Metric_Time_Count] 160 | super().__init__(config = REINFORCE_CONFIG, metrics = metrics) 161 | self.memory_transition = Memory(MEMORY_KEYS = ['observation', 'action','reward', 'done', 'prob']) 162 | self.memory_episodes = Memory(MEMORY_KEYS = ['episode']) 163 | 164 | self.policy = actor 165 | self.opt = optim.Adam(lr = 1e-4, params=self.policy.parameters()) 166 | 167 | self.last_prob = None 168 | self.episode_ended = False 169 | 170 | def act(self, observation, mask = None): 171 | '''Ask the agent to take a decision given an observation. 172 | observation : an (n_obs,) shaped nummpy observation. 173 | mask : a binary list containing 1 where corresponding actions are forbidden. 174 | return : an int corresponding to an action 175 | ''' 176 | with torch.no_grad(): 177 | #Batching observation 178 | observation = torch.Tensor(observation) 179 | observations = observation.unsqueeze(0) # (1, observation_space) 180 | probs = self.policy(observations) # (1, n_actions) 181 | distribs = Categorical(probs = probs) 182 | actions = distribs.sample() 183 | action = actions.numpy()[0] 184 | 185 | #Save metrics 186 | self.add_metric(mode = 'act') 187 | 188 | # Action 189 | self.last_prob = probs[0, action] 190 | return action 191 | 192 | 193 | def learn(self): 194 | '''Do one step of learning. 195 | return : metrics, a list of metrics computed during this learning step. 196 | ''' 197 | values = dict() 198 | self.step += 1 199 | 200 | #Learn when a done=True 201 | if not self.episode_ended: 202 | return 203 | self.episode_ended = False 204 | 205 | #Sample trajectories 206 | episodes = self.memory_episodes.sample( 207 | method = "last", 208 | sample_size=128, 209 | as_tensor=False, 210 | ) 211 | episodes = episodes[0] 212 | 213 | #Compute mean value of gradients over a batch 214 | for _ in range(self.gradient_steps): 215 | loss_mean = torch.tensor(0.) 
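            #Off-policy correction (explanatory note) : the episodes below were generated by an older
            #policy pi_old, so each term is reweighted by the importance ratio
            #    rho_t = pi_theta(a_t|s_t) / pi_old(a_t|s_t)
            #"ratio_ln" keeps the REINFORCE form and multiplies log pi_theta by a detached rho_t, while
            #"ratio" differentiates the (optionally clipped) ratio itself.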
216 |             batch_size = len(episodes)
217 | 
218 |             for observations, actions, rewards, dones, old_probs in episodes:
219 | 
220 |                 #Some actions dtype problem
221 |                 actions = actions.to(dtype = torch.int64)
222 | 
223 |                 #Scaling the rewards (not in place, so the stored tensors are not rescaled again on later gradient steps)
224 |                 if self.reward_scaler is not None:
225 |                     rewards = rewards / self.reward_scaler
226 | 
227 |                 #Compute Gt the discounted sum of future rewards : G_t = r_t + gamma * G_{t+1}, built backwards
228 |                 ep_length = rewards.shape[0] #T
229 |                 rewards = rewards[:,0].numpy().tolist()
230 |                 G = [rewards[-1]]
231 |                 for i in range(1, ep_length):
232 |                     t = ep_length - 1 - i
233 |                     previous_G_t = rewards[t] + self.gamma * G[0]
234 |                     G.insert(0, previous_G_t)
235 |                 G = torch.tensor(G)
236 | 
237 |                 #Compute loss
238 |                 if self.J_method == "ratio_ln":
239 |                     probs = self.policy(observations) #(T, n_actions)
240 |                     probs = torch.gather(probs, dim = 1, index = actions)[:,0] #(T,)
241 |                     log_probs = torch.log(probs) #(T,)
242 | 
243 |                     old_probs = old_probs[:, 0] #(T,)
244 |                     ratios = (probs / old_probs).detach()
245 |                     # ratios = torch.clamp(ratios, 1 - self.epsilon_clipper, 1 + self.epsilon_clipper)
246 |                     log_probs = torch.multiply(log_probs, ratios)
247 | 
248 |                     loss = torch.multiply(log_probs, G)
249 |                     loss = - torch.sum(loss)
250 |                     loss_mean += loss / batch_size
251 | 
252 |                 elif self.J_method == "ratio":
253 |                     #Compute probs of the taken actions
254 |                     probs = self.policy(observations) #(T, n_actions)
255 |                     probs = torch.gather(probs, dim = 1, index = actions)[:,0] #(T,)
256 |                     ratios = probs / old_probs[:, 0] #(T,), same shape as probs
257 |                     ratios = torch.clamp(ratios, 1 - self.epsilon_clipper, 1 + self.epsilon_clipper)
258 | 
259 |                     loss = - torch.multiply(ratios, G)
260 |                     loss = torch.sum(loss)
261 |                     loss_mean += loss / batch_size
262 | 
263 | 
264 | 
265 |             #Backpropagate to improve policy
266 |             self.opt.zero_grad()
267 |             loss_mean.backward()
268 |             self.opt.step()
269 | 
270 | 
271 |         #Save metrics
272 |         values["actor_loss"] = loss.detach().numpy()
273 |         self.add_metric(mode = 'learn', **values)
274 | 
275 | 
276 | 
277 |     def remember(self, observation, action, reward, done, next_observation, info={}, **param):
278 |         '''Save elements inside memory.
279 |         *arguments : elements to remember, as numerous and in the same order as in self.memory.MEMORY_KEYS
280 |         return : metrics, a list of metrics computed during this remembering step.
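        Note : the probability pi_theta(a|s) of the action actually taken (self.last_prob, set in act())
        is stored with each transition, so that learn() can later form the importance ratio
        pi_theta_new(a|s) / pi_theta_old(a|s).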
281 | ''' 282 | prob = self.last_prob 283 | self.memory_transition.remember((observation, action, reward, done, prob, info)) 284 | if done: 285 | self.episode_ended = True 286 | episode = self.memory_transition.sample(method = 'all') 287 | self.memory_transition.__empty__() 288 | self.memory_episodes.remember((episode,)) 289 | 290 | #Save metrics 291 | values = {"obs" : observation, "action" : action, "reward" : reward, "done" : done, "prob" : prob} 292 | self.add_metric(mode = 'remember', **values) 293 | 294 | -------------------------------------------------------------------------------- /rl_algos/ACTOR_CRITIC.py: -------------------------------------------------------------------------------- 1 | from copy import copy, deepcopy 2 | from operator import index 3 | import numpy as np 4 | import math 5 | import gym 6 | import sys 7 | import random 8 | import matplotlib.pyplot as plt 9 | from div.utils import * 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import torch.nn.functional as F 15 | import torchvision.transforms as T 16 | from torch.distributions.categorical import Categorical 17 | 18 | from MEMORY import Memory 19 | from CONFIGS import ACTOR_CRITIC_CONFIG 20 | from METRICS import * 21 | from rl_algos.AGENT import AGENT 22 | 23 | class ACTOR_CRITIC(AGENT): 24 | 25 | def __init__(self, 26 | actor : nn.Module, 27 | action_value : nn.Module = None, 28 | state_value : nn.Module = None, 29 | advantage_value : nn.Module = None, 30 | ): 31 | '''A general Actor Critic algorithm, using a policy network that learns to act and a critic network that learns the model. 32 | The parameter compute_gain_method defines the method for defining the weight of the policy gradients which can be (or not be) offline, centered, causal and using a critic. 33 | 34 | memory : a Memory object 35 | actor : a nn.Module neural network, pi_theta : s --> [p(a|s) for a] 36 | action_value : a nn.Module neural network, Q_omega : s --> [Q(s,a) for a] 37 | state_value : a nn.Module neural network, V_phi : s --> V(s) 38 | metrics : a list of Metrics objects 39 | **config : a config dictionnary for Actor_Critic 40 | ''' 41 | metrics = [MetricS_On_Learn, Metric_Total_Reward, Metric_Count_Episodes] 42 | super().__init__(config = ACTOR_CRITIC_CONFIG, metrics = metrics) 43 | self.memory = Memory(MEMORY_KEYS = ['observation', 'action','reward', 'done', 'next_observation']) 44 | self.step = 0 45 | 46 | self.setup_critic(action_value, state_value, advantage_value) 47 | self.policy = actor 48 | self.opt_policy = optim.Adam(lr = self.learning_rate_actor, params=self.policy.parameters()) 49 | 50 | 51 | 52 | def setup_critic(self, Q, V, A): 53 | '''Method for preparing the ACTOR_CRITIC object to use and train its critic depending of the method used 54 | for computing gain. 
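        Summary of the requirements enforced below : the "total_reward*" methods need no critic network
        (only "total_reward_minus_leaky_mean" keeps a running scalar baseline V_0) ; "state_value",
        "state_value_centered", "total_future_reward_minus_state_value" and "GAE" require the state-value
        network V ; "action_value", "action_value_centered" and "total_future_reward_minus_action_value"
        require the action-value network Q ; "advantage_value" would require A but is not implemented.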
55 |         '''
56 |         if self.compute_gain_method in ("total_reward", "total_future_reward", "total_reward_minus_leaky_mean", "total_reward_minus_MC_mean"):
57 |             self.use_Q, self.use_V, self.use_A = False, False, False
58 |             if self.compute_gain_method == "total_reward_minus_leaky_mean":
59 |                 self.alpha_0 = self.config["alpha_0"] if "alpha_0" in self.config else 1e-2
60 |                 self.V_0 = 0
61 |         elif self.compute_gain_method in ("state_value", "state_value_centered", "total_future_reward_minus_state_value", "GAE"):
62 |             if V is None:
63 |                 raise Exception(f"Using method {self.compute_gain_method} requires a state value network V.")
64 |             self.use_Q, self.use_V, self.use_A = False, True, False
65 |             self.state_value = V
66 |             self.state_value_target = deepcopy(V)
67 |             self.opt_critic = optim.Adam(lr = self.learning_rate_critic, params = self.state_value.parameters())
68 |         elif self.compute_gain_method in ("action_value", "action_value_centered", "total_future_reward_minus_action_value"):
69 |             if Q is None:
70 |                 raise Exception(f"Using method {self.compute_gain_method} requires an action value network Q.")
71 |             self.use_Q, self.use_V, self.use_A = True, False, False
72 |             self.action_value = Q
73 |             self.opt_critic = optim.Adam(lr = self.learning_rate_critic, params=self.action_value.parameters())
74 |         elif self.compute_gain_method == "advantage_value":
75 |             if A is None:
76 |                 raise Exception(f"Using method {self.compute_gain_method} requires an advantage value network A.")
77 |             self.use_Q, self.use_V, self.use_A = False, False, True
78 |             self.advantage_value = A
79 |             self.opt_critic = optim.Adam(lr = self.learning_rate_critic, params=self.advantage_value.parameters())
80 |             raise NotImplementedError("The advantage_value critic is not implemented yet.")
81 | 
82 |         else:
83 |             raise NotImplementedError(f"Method {self.compute_gain_method} is not implemented.")
84 | 
85 | 
86 | 
87 | 
88 | 
89 |     def act(self, observation, mask = None):
90 |         '''Ask the agent to take a decision given an observation.
91 |         observation : an (n_obs,) shaped observation.
92 |         mask : a binary list containing 1 where corresponding actions are forbidden.
93 |         return : an int corresponding to an action
94 |         '''
95 | 
96 | 
97 |         #Batching observation
98 |         observations = torch.Tensor(observation)
99 |         observations = observations.unsqueeze(0) # (1, observation_space)
100 |         probs = self.policy(observations) # (1, n_actions)
101 |         distribs = Categorical(probs = probs)
102 |         actions = distribs.sample()
103 |         action = actions.numpy()[0]
104 | 
105 |         #Save metrics
106 |         self.add_metric(mode = 'act')
107 | 
108 |         # Action
109 |         return action
110 | 
111 | 
112 |     def learn(self):
113 |         '''Do one step of learning.
114 |         return : metrics, a list of metrics computed during this learning step.
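        The policy loss is - sum_t( G_t * log pi_theta(a_t|s_t) ) with G_t given by compute_gain() ;
        depending on the method, the critic is regressed toward the one-step bootstrap target
        r + gamma * (1 - done) * V_target(s_next) for V, or r + gamma * (1 - done) * max_a' Q(s_next, a') for Q.

        Example (hypothetical construction sketch ; n_obs, n_actions and the layer sizes are placeholders,
        and the gain method is read from ACTOR_CRITIC_CONFIG, not values taken from this repository) :
            actor = nn.Sequential(nn.Linear(n_obs, 64), nn.Tanh(), nn.Linear(64, n_actions), nn.Softmax(dim=-1))
            critic = nn.Sequential(nn.Linear(n_obs, 64), nn.Tanh(), nn.Linear(64, 1))
            agent = ACTOR_CRITIC(actor=actor, state_value=critic)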
115 | ''' 116 | self.step += 1 117 | values = dict() 118 | 119 | #Sample trajectories 120 | observations, actions, rewards, dones, next_observations = self.memory.sample( 121 | method = "all", 122 | ) 123 | actions = actions.to(dtype = torch.int64) 124 | 125 | #Learn only at end of episode 126 | if not dones[-1]: 127 | return 128 | 129 | #Scaling the rewards 130 | if self.reward_scaler is not None: 131 | rewards /= self.reward_scaler 132 | 133 | #Updating the policy 134 | if self.step % self.batch_size == 0: 135 | #Loss = - sum_t(G_t * ln(pi(a_t|s_t))) 136 | #G_t can be estimated by various methods 137 | with torch.no_grad(): 138 | G = self.compute_gain(observations, actions, rewards, dones, next_observations, method = self.compute_gain_method) 139 | for _ in range(self.gradient_steps_policy): 140 | self.opt_policy.zero_grad() 141 | probs = self.policy(observations) #(T, n_actions) 142 | probs = torch.gather(probs, dim = 1, index = actions) #(T, 1) 143 | log_probs = torch.log(probs)[:,0] #(T,) 144 | loss_pi = torch.multiply(log_probs, G) 145 | loss_pi = - torch.sum(loss_pi) 146 | #Backpropagate to improve policy 147 | loss_pi.backward(retain_graph = True) 148 | self.opt_policy.step() 149 | #Empty memory of previous episode 150 | self.memory.__empty__() 151 | values["actor_loss"] = loss_pi.detach().numpy().mean() 152 | 153 | #Updating the action value 154 | if self.use_Q: 155 | #Bootsrapping : Q(s,a) = r + gamma * max_a'(Q(s_next, a')) * (1-d) 156 | criterion = nn.MSELoss() 157 | with torch.no_grad(): 158 | Q_s_a_next = self.action_value(next_observations) 159 | Q_s_next, bests_a = torch.max(Q_s_a_next, dim = 1, keepdim=True) 160 | Q_s_estimated = rewards + (1-dones) * self.gamma * Q_s_next 161 | for _ in range(self.gradient_steps_critic): 162 | self.opt_critic.zero_grad() 163 | Q_s_a = self.action_value(observations) 164 | Q_s = Q_s_a.gather(dim = 1, index = actions) 165 | loss_Q = criterion(Q_s_estimated, Q_s) 166 | loss_Q.backward(retain_graph = True) 167 | if self.clipping is not None: 168 | for param in self.action_value.parameters(): 169 | param.grad.data.clamp_(-self.clipping, self.clipping) 170 | self.opt_critic.step() 171 | values["critic_loss"] = loss_Q.detach().numpy() 172 | values["value"] = Q_s.detach().numpy().mean() 173 | 174 | #Updating the state value 175 | if self.use_V: 176 | #Bootstrapping : V(s) = r + gamma * V(s_next) * (1-d) or MC estimation 177 | criterion = nn.MSELoss() 178 | with torch.no_grad(): 179 | V_s_estimated = rewards + (1-dones) * self.gamma * self.state_value_target(next_observations) 180 | V_s_estimated = V_s_estimated.to(torch.float32) 181 | for _ in range(self.gradient_steps_critic): 182 | self.opt_critic.zero_grad() 183 | V_s = self.state_value(observations) 184 | loss_V = criterion(V_s, V_s_estimated) 185 | loss_V.backward() 186 | if self.clipping is not None: 187 | for param in self.state_value.parameters(): 188 | param.grad.data.clamp_(-self.clipping, self.clipping) 189 | self.opt_critic.step() 190 | values["critic_loss"] = loss_V.detach().numpy() 191 | values["value"] = V_s.detach().numpy().mean() 192 | 193 | #Updating the advantage value 194 | if self.use_A and self.step == 0: 195 | #to implement if possible 196 | raise 197 | 198 | #Updating V_0 199 | if self.compute_gain_method == "total_reward_minus_leaky_mean": 200 | ep_lenght = rewards.shape[0] 201 | weigths_gamma = torch.Tensor([self.gamma ** t for t in range(ep_lenght)]) 202 | rewards_weighted = torch.multiply(rewards, weigths_gamma) 203 | total_reward = torch.sum(rewards_weighted) 204 | 
self.V_0 += self.alpha_0 * (total_reward - self.V_0)
205 | 
206 |         #Metrics
207 |         self.add_metric(mode = 'learn', **values)
208 | 
209 | 
210 | 
211 | 
212 | 
213 | 
214 |     def remember(self, observation, action, reward, done, next_observation, info={}, **param):
215 |         '''Save elements inside memory.
216 |         *arguments : elements to remember, as numerous and in the same order as in self.memory.MEMORY_KEYS
217 |         return : metrics, a list of metrics computed during this remembering step.
218 |         '''
219 |         self.memory.remember((observation, action, reward, done, next_observation))
220 | 
221 |         values = {"obs" : observation, "action" : action, "reward" : reward, "done" : done, "next_obs" : next_observation}
222 |         self.add_metric(mode = 'remember', **values)
223 | 
224 | 
225 | 
226 | 
227 |     def compute_gain(self, observations, actions, rewards, dones, next_observations, method):
228 |         '''Compute the "gain" or the "advantage function" that will be applied as weights to the gradients of ln(pi).
229 |         *args : the elements of a trajectory during one episode
230 |         method : the method used for computing the gain
231 |         return : a tensor of shape the length of the previous episode, containing the gain/advantage at each step t.
232 |         '''
233 |         ep_length = rewards.shape[0]
234 |         if method == "total_reward":
235 |             weights_gamma = torch.Tensor([self.gamma ** t for t in range(ep_length)])
236 |             rewards_weighted = torch.multiply(rewards[:, 0], weights_gamma) #(T,), rewards is of shape (T, 1)
237 |             total_reward = torch.sum(rewards_weighted)
238 |             G = total_reward.repeat(repeats = (ep_length,))
239 |         elif method == "total_future_reward":
240 |             G = self.compute_future_total_rewards(rewards)
241 |         elif method == "total_reward_minus_MC_mean":
242 |             #Would subtract a Monte-Carlo estimate of the mean total reward, to implement
243 |             raise NotImplementedError("Method total_reward_minus_MC_mean is not implemented yet.")
244 |         elif method == "total_reward_minus_leaky_mean":
245 |             G = self.compute_future_total_rewards(rewards) - self.V_0
246 |         elif method == "total_future_reward_minus_state_value":
247 |             G = self.compute_future_total_rewards(rewards) - self.state_value(observations)[:, 0]
248 |         elif method == "state_value":
249 |             G = self.state_value(observations)[:,0]
250 |         elif method == "state_value_centered":
251 |             V_s = self.state_value(observations)
252 |             G = rewards + (1-dones) * self.gamma * self.state_value(next_observations) - V_s
253 |             G = G[:, 0]
254 |         elif method == "GAE":
255 |             delta = (rewards + self.gamma * (1-dones) * self.state_value(next_observations) - self.state_value(observations)).detach().numpy()[:, 0] #TD residuals delta_t = r_t + gamma * (1-d_t) * V(s_t+1) - V(s_t)
256 |             A_GAE = [None for _ in range(ep_length - 1)] + [delta[-1]]
257 |             for u in range(1, ep_length):
258 |                 t = ep_length - 1 - u
259 |                 A_GAE[t] = self.gamma * self.lmbda * A_GAE[t+1] + delta[t]
260 |             G = torch.tensor(A_GAE, dtype=torch.float)
261 | 
262 |         elif method == "action_value":
263 |             Q_s_a = self.action_value(observations)
264 |             G = torch.gather(Q_s_a, dim = 1, index=actions)[:, 0]
265 |         elif method == "action_value_centered":
266 |             Q_s_a = self.action_value(observations)
267 |             Q_s = torch.gather(Q_s_a, dim = 1, index=actions)[:, 0]
268 |             PI_s_a = self.policy(observations)
269 |             Q_s_a_weighted = Q_s_a * PI_s_a
270 |             Q_s_mean = torch.sum(Q_s_a_weighted, dim = 1)
271 |             G = Q_s - Q_s_mean
272 |         elif method == "total_future_reward_minus_action_value":
273 |             total_future_rewards = self.compute_future_total_rewards(rewards)
274 |             Q_s_a = self.action_value(observations)
275 |             PI_s_a = self.policy(observations)
276 |             Q_s_a_weighted = Q_s_a * PI_s_a
277 |             Q_s_mean = torch.sum(Q_s_a_weighted, dim = 1)
278 |             G = total_future_rewards - Q_s_mean
279 | 
280 | 
281 |         else:
282 |             raise NotImplementedError("Method for computing gain is not recognized.")
283 | 
284 |         return G.detach()
285 | 
286 | 
287 |     def compute_future_total_rewards(self, rewards):
288 |         '''Compute [G_t for t] where G_t is the sum of future rewards weighted by the discount factor.
289 |         rewards : a tensor of shape (duration_last_episode, 1)
290 |         return : a tensor of shape (duration_last_episode,), [G_t for t] where G_t is the sum for t' >= t of r_t' * gamma^(t' - t)
291 |         '''
292 |         ep_length = rewards.shape[0]
293 |         rewards = rewards[:, 0] #(T,)
294 |         #Backward recursion : G_t = r_t + gamma * G_{t+1}, which matches the closed form above
295 |         future_total_rewards = [0. for _ in range(ep_length)]
296 |         G_next = 0.
297 |         for t in reversed(range(ep_length)):
298 |             G_next = float(rewards[t]) + self.gamma * G_next
299 |             future_total_rewards[t] = G_next
300 |         return torch.Tensor(future_total_rewards)
301 | 
302 | 
--------------------------------------------------------------------------------
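As a quick reference for the discounted-return computations used by the REINFORCE and ACTOR_CRITIC agents above, here is a small standalone sanity check (not part of the repository) showing that the backward recursion G_t = r_t + gamma * G_{t+1} matches the closed form G_t = sum_{t' >= t} gamma^(t' - t) * r_t'. Plain Python, no torch required:

    gamma = 0.9
    rewards = [1.0, 0.0, 2.0, 3.0]

    # Backward recursion
    G_rec = [0.0] * len(rewards)
    G_next = 0.0
    for t in reversed(range(len(rewards))):
        G_next = rewards[t] + gamma * G_next
        G_rec[t] = G_next

    # Closed form
    G_closed = [sum(gamma ** (tp - t) * rewards[tp] for tp in range(t, len(rewards)))
                for t in range(len(rewards))]

    assert all(abs(a - b) < 1e-9 for a, b in zip(G_rec, G_closed))
    print(G_rec)  # [4.807, 4.23, 4.7, 3.0]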