├── seac
│   ├── configs
│   │   ├── rware1.yaml
│   │   ├── rware2.yaml
│   │   ├── rware4.yaml
│   │   ├── rware3.yaml
│   │   ├── foraging1.yaml
│   │   ├── foraging2.yaml
│   │   ├── foraging3.yaml
│   │   ├── foraging5.yaml
│   │   └── foraging4.yaml
│   ├── pretrained
│   │   └── rware-small-4ag
│   │       ├── agent0
│   │       │   └── models.pt
│   │       ├── agent1
│   │       │   └── models.pt
│   │       ├── agent2
│   │       │   └── models.pt
│   │       └── agent3
│   │           └── models.pt
│   ├── utils.py
│   ├── distributions.py
│   ├── evaluate.py
│   ├── envs.py
│   ├── run.py
│   ├── wrappers.py
│   ├── a2c.py
│   ├── model.py
│   ├── train.py
│   └── storage.py
├── .flake8
├── seql
│   ├── model.py
│   ├── utilities
│   │   ├── model_saver.py
│   │   └── logger.py
│   ├── marl_algorithm.py
│   ├── lbf_train.py
│   ├── rware_train.py
│   ├── marl_utils.py
│   ├── agent.py
│   ├── wrappers.py
│   ├── baseline_buffer.py
│   ├── iql.py
│   └── train.py
├── requirements.txt
├── README.md
└── .gitignore
/seac/configs/rware1.yaml:
--------------------------------------------------------------------------------
1 | env_name: rware-tiny-4ag-v1
2 | num_env_steps: 40000000.0
3 | time_limit: 500
4 | 
--------------------------------------------------------------------------------
/seac/configs/rware2.yaml:
--------------------------------------------------------------------------------
1 | env_name: rware-tiny-2ag-v1
2 | num_env_steps: 80000000.0
3 | time_limit: 500
4 | 
--------------------------------------------------------------------------------
/seac/configs/rware4.yaml:
--------------------------------------------------------------------------------
1 | env_name: rware-small-4ag-v1
2 | num_env_steps: 150000000.0
3 | time_limit: 500
4 | 
--------------------------------------------------------------------------------
/seac/configs/rware3.yaml:
--------------------------------------------------------------------------------
1 | env_name: rware-tiny-2ag-hard-v1
2 | num_env_steps: 120000000.0
3 | time_limit: 500
4 | 
--------------------------------------------------------------------------------
/seac/configs/foraging1.yaml:
--------------------------------------------------------------------------------
1 | env_name: Foraging-10x10-3p-3f-v0
2 | num_env_steps: 50000000.0
3 | time_limit: 25
4 | 
--------------------------------------------------------------------------------
/seac/configs/foraging2.yaml:
--------------------------------------------------------------------------------
1 | env_name: Foraging-15x15-3p-4f-v0
2 | num_env_steps: 50000000.0
3 | time_limit: 25
4 | 
--------------------------------------------------------------------------------
/seac/configs/foraging3.yaml:
--------------------------------------------------------------------------------
1 | env_name: Foraging-12x12-2p-1f-v0
2 | num_env_steps: 50000000.0
3 | time_limit: 25
4 | 
--------------------------------------------------------------------------------
/seac/configs/foraging5.yaml:
--------------------------------------------------------------------------------
1 | env_name: Foraging-8x8-2p-2f-coop-v0
2 | num_env_steps: 50000000.0
3 | time_limit: 25
4 | 
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E226,E302,E41, F841
3 | max-line-length = 89
4 | exclude = tests/*
5 | max-complexity = 10
--------------------------------------------------------------------------------
/seac/configs/foraging4.yaml:
--------------------------------------------------------------------------------
1 | env_name: Foraging-12x12-2p-1f-coop-v0
2 | num_env_steps: 50000000.0
3 | time_limit: 25
4 | 
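The config files above each pin just three fields: `env_name`, `num_env_steps`, and `time_limit`. A minimal sketch of how such a file lines up with the Sacred-style `env_name=... time_limit=...` overrides listed in the README below, assuming PyYAML (pinned in requirements.txt) and using `rware1.yaml` purely as an example path:

```python
# Sketch only: expand a YAML config into "train.py with ..." overrides.
import yaml

with open("seac/configs/rware1.yaml") as f:
    cfg = yaml.safe_load(f)  # {'env_name': 'rware-tiny-4ag-v1', 'num_env_steps': 40000000.0, 'time_limit': 500}

overrides = " ".join(f"{key}={value}" for key, value in cfg.items())
print(f"python train.py with {overrides}")
# python train.py with env_name=rware-tiny-4ag-v1 num_env_steps=40000000.0 time_limit=500
```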
-------------------------------------------------------------------------------- /seac/pretrained/rware-small-4ag/agent0/models.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/semitable/seac/HEAD/seac/pretrained/rware-small-4ag/agent0/models.pt -------------------------------------------------------------------------------- /seac/pretrained/rware-small-4ag/agent1/models.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/semitable/seac/HEAD/seac/pretrained/rware-small-4ag/agent1/models.pt -------------------------------------------------------------------------------- /seac/pretrained/rware-small-4ag/agent2/models.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/semitable/seac/HEAD/seac/pretrained/rware-small-4ag/agent2/models.pt -------------------------------------------------------------------------------- /seac/pretrained/rware-small-4ag/agent3/models.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/semitable/seac/HEAD/seac/pretrained/rware-small-4ag/agent3/models.pt -------------------------------------------------------------------------------- /seac/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | def init(module, weight_init, bias_init, gain=1): 8 | weight_init(module.weight.data, gain=gain) 9 | bias_init(module.bias.data) 10 | return module 11 | 12 | 13 | def cleanup_log_dir(log_dir): 14 | try: 15 | os.makedirs(log_dir) 16 | except OSError: 17 | files = glob.glob(os.path.join(log_dir, "*.monitor.csv")) 18 | for f in files: 19 | os.remove(f) 20 | -------------------------------------------------------------------------------- /seac/distributions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from utils import init 8 | 9 | """ 10 | Modify standard PyTorch distributions so they are compatible with this code. 
11 | """ 12 | # Categorical 13 | class FixedCategorical(torch.distributions.Categorical): 14 | def sample(self): 15 | return super().sample().unsqueeze(-1) 16 | 17 | def log_probs(self, actions): 18 | return ( 19 | super() 20 | .log_prob(actions.squeeze(-1)) 21 | .view(actions.size(0), -1) 22 | .sum(-1) 23 | .unsqueeze(-1) 24 | ) 25 | 26 | def mode(self): 27 | return self.probs.argmax(dim=-1, keepdim=True) 28 | 29 | 30 | class Categorical(nn.Module): 31 | def __init__(self, num_inputs, num_outputs): 32 | super(Categorical, self).__init__() 33 | 34 | init_ = lambda m: init( 35 | m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), gain=0.01 36 | ) 37 | 38 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 39 | 40 | def forward(self, x): 41 | x = self.linear(x) 42 | return FixedCategorical(logits=x) 43 | -------------------------------------------------------------------------------- /seql/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | device = "cpu" #torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class QNetwork(nn.Module): 9 | """Deep Q-Network""" 10 | 11 | def __init__( 12 | self, state_size, action_size, hidden_dim, nonlin=F.relu 13 | ): 14 | """ 15 | Initialize parameters and build model. 16 | :param state_size: Dimension of each state 17 | :param action_size: Dimension of each action 18 | :param hidden_dim: dimension of hidden layers 19 | :param nonlin: nonlinearity to use 20 | """ 21 | super(QNetwork, self).__init__() 22 | self.fc1 = nn.Linear(state_size, hidden_dim) 23 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 24 | self.fc3 = nn.Linear(hidden_dim, action_size) 25 | self.nonlin = nonlin 26 | 27 | def forward(self, state): 28 | """ 29 | Compute forward pass over QNetwork 30 | :param state: state representation for input state 31 | :return: forward pass result 32 | """ 33 | x = self.nonlin(self.fc1(state)) 34 | x = self.nonlin(self.fc2(x)) 35 | return self.fc3(x) 36 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.12.0 2 | attrs==20.3.0 3 | cachetools==4.2.1 4 | certifi==2020.12.5 5 | chardet==4.0.0 6 | cloudpickle==1.6.0 7 | colorama==0.4.4 8 | cycler==0.10.0 9 | decorator==4.4.2 10 | docopt==0.6.2 11 | future==0.18.2 12 | gitdb==4.0.7 13 | GitPython==3.1.14 14 | google-auth==1.28.1 15 | google-auth-oauthlib==0.4.4 16 | grpcio==1.37.0 17 | gym==0.18.0 18 | idna==2.10 19 | iniconfig==1.1.1 20 | jsonpickle==1.5.2 21 | kiwisolver==1.3.1 22 | lbforaging==1.0.15 23 | Markdown==3.3.4 24 | matplotlib==3.4.1 25 | munch==2.5.0 26 | networkx==2.5.1 27 | numpy==1.20.2 28 | oauthlib==3.1.0 29 | packaging==20.9 30 | pandas==1.2.4 31 | Pillow==7.2.0 32 | pluggy==0.13.1 33 | protobuf==3.15.8 34 | py==1.10.0 35 | py-cpuinfo==8.0.0 36 | pyasn1==0.4.8 37 | pyasn1-modules==0.2.8 38 | pyglet==1.5.0 39 | pyparsing==2.4.7 40 | pytest==6.2.3 41 | python-dateutil==2.8.1 42 | pytz==2021.1 43 | PyYAML==5.4.1 44 | requests==2.25.1 45 | requests-oauthlib==1.3.0 46 | rsa==4.7.2 47 | sacred==0.8.2 48 | scipy==1.6.2 49 | six==1.15.0 50 | smmap==4.0.0 51 | stable-baselines3==1.0 52 | tensorboard==2.4.1 53 | tensorboard-plugin-wit==1.8.0 54 | toml==0.10.2 55 | torch==1.8.1 56 | tqdm==4.60.0 57 | typing-extensions==3.7.4.3 58 | urllib3==1.26.4 59 | Werkzeug==1.0.1 60 | wrapt==1.12.1 61 | 
-------------------------------------------------------------------------------- /seac/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import rware 3 | import lbforaging 4 | import gym 5 | 6 | from a2c import A2C 7 | from wrappers import RecordEpisodeStatistics, TimeLimit, Monitor 8 | 9 | path = "pretrained/rware-small-4ag" 10 | env_name = "rware-small-4ag-v1" 11 | time_limit = 500 # 25 for LBF 12 | 13 | EPISODES = 5 14 | 15 | env = gym.make(env_name) 16 | agents = [ 17 | A2C(i, osp, asp, 0.1, 0.1, False, 1, 1, "cpu") 18 | for i, (osp, asp) in enumerate(zip(env.observation_space, env.action_space)) 19 | ] 20 | for agent in agents: 21 | agent.restore(path + f"/agent{agent.agent_id}") 22 | 23 | for ep in range(EPISODES): 24 | env = gym.make(env_name) 25 | env = Monitor(env, f"seac_rware-small-4ag_eval/video_ep{ep+1}", mode="evaluation") 26 | env = TimeLimit(env, time_limit) 27 | env = RecordEpisodeStatistics(env) 28 | 29 | obs = env.reset() 30 | done = [False] * len(agents) 31 | 32 | while not all(done): 33 | obs = [torch.from_numpy(o) for o in obs] 34 | _, actions, _ , _ = zip(*[agent.model.act(obs[agent.agent_id], None, None) for agent in agents]) 35 | actions = [a.item() for a in actions] 36 | env.render() 37 | obs, _, done, info = env.step(actions) 38 | obs = env.reset() 39 | print("--- Episode Finished ---") 40 | print(f"Episode rewards: {sum(info['episode_reward'])}") 41 | print(info) 42 | print(" --- ") 43 | -------------------------------------------------------------------------------- /seql/utilities/model_saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | 6 | class ModelSaver: 7 | """ 8 | Class to save model parameters 9 | """ 10 | 11 | def __init__(self, models_dir="models", run_name="default"): 12 | self.models_dir = models_dir 13 | self.run_name = run_name 14 | 15 | def clear_models(self): 16 | """ 17 | Remove model files in model dir 18 | """ 19 | if not os.path.isdir(self.models_dir): 20 | return 21 | model_dir = os.path.join(self.models_dir, self.run_name) 22 | if not os.path.isdir(model_dir): 23 | return 24 | for f in os.listdir(model_dir): 25 | f_path = os.path.join(model_dir, f) 26 | if not os.path.isfile(f_path): 27 | continue 28 | os.remove(f_path) 29 | 30 | def save_models(self, alg, extension): 31 | """ 32 | generate and save networks 33 | :param model_dir_path: path of model directory 34 | :param run_name: name of run 35 | :param alg_name: name of used algorithm 36 | :param alg: training object of trained algorithm 37 | :param extension: name extension 38 | """ 39 | if not os.path.isdir(self.models_dir): 40 | os.mkdir(self.models_dir) 41 | model_dir = os.path.join(self.models_dir, self.run_name) 42 | if not os.path.isdir(model_dir): 43 | os.mkdir(model_dir) 44 | 45 | for i, agent in enumerate(alg.agents): 46 | name = "iql_agent%d_params_" % i 47 | name += extension 48 | torch.save(agent.model.state_dict(), os.path.join(model_dir, name)) 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Shared Experience Actor Critic 2 | 3 | This repository is the official implementation of [Shared Experience Actor Critic](https://arxiv.org/abs/2006.07169). 
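A minimal sketch of driving the `ModelSaver` utility listed above, assuming a stand-in `alg` object (anything with an `agents` list whose members expose a `.model` module, as the IQL trainer in `seql` does) and an invented run name:

```python
# Sketch only: checkpointing per-agent networks with ModelSaver.
from types import SimpleNamespace

import torch.nn as nn

from utilities.model_saver import ModelSaver  # seql/utilities/model_saver.py

# Stand-in for a trained IQL object; in the repo, `alg` would come from seql/train.py.
alg = SimpleNamespace(agents=[SimpleNamespace(model=nn.Linear(4, 2)) for _ in range(2)])

saver = ModelSaver(models_dir="models", run_name="example_run")  # invented run name
saver.clear_models()             # drop any stale checkpoints for this run
saver.save_models(alg, "final")  # writes models/example_run/iql_agent<i>_params_final
```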
4 | 
5 | ## Requirements
6 | 
7 | For the experiments in LBF and RWARE, please install the environments from:
8 | - [Level Based Foraging Official Repo](https://github.com/uoe-agents/lb-foraging)
9 | - [Multi-Robot Warehouse Official Repo](https://github.com/uoe-agents/robotic-warehouse)
10 | 
11 | Also requires PyTorch 1.6+.
12 | 
13 | ## Training - SEAC
14 | To train the agents in the paper, navigate to the seac directory:
15 | ```
16 | cd seac
17 | ```
18 | 
19 | And run:
20 | 
21 | ```train
22 | python train.py with <environment config>
23 | ```
24 | 
25 | Valid environment configs are:
26 | - `env_name=Foraging-15x15-3p-4f-v0 time_limit=25`
27 | - ...
28 | - `env_name=Foraging-12x12-2p-1f-v0 time_limit=25` or any other foraging environment size/configuration.
29 | - `env_name=rware-tiny-2ag-v1 time_limit=500`
30 | - `env_name=rware-tiny-4ag-v1 time_limit=500`
31 | - ...
32 | - `env_name=rware-tiny-2ag-hard-v1 time_limit=500` or any other rware environment size/configuration.
33 | ## Training - SEQL
34 | 
35 | To train the agents in the paper, navigate to the seql directory:
36 | ```
37 | cd seql
38 | ```
39 | 
40 | And run the training script. Possible options are:
41 | - `python lbf_train.py --env Foraging-12x12-2p-1f-v0`
42 | - ...
43 | - `python lbf_train.py --env Foraging-15x15-3p-4f-v0` or any other foraging environment size/configuration.
44 | - `python rware_train.py --env "rware-tiny-2ag-v1"`
45 | - ...
46 | - `python rware_train.py --env "rware-tiny-4ag-v1"` or any other rware environment size/configuration.
47 | 
48 | ## Evaluation/Visualization - SEAC
49 | 
50 | To load and render the pretrained models in SEAC, run the following in the seac directory:
51 | 
52 | ```eval
53 | python evaluate.py
54 | ```
55 | 
56 | ## Citation
57 | ```
58 | @inproceedings{christianos2020shared,
59 | title={Shared Experience Actor-Critic for Multi-Agent Reinforcement Learning},
60 | author={Christianos, Filippos and Sch{\"a}fer, Lukas and Albrecht, Stefano V},
61 | booktitle = {Advances in Neural Information Processing Systems},
62 | year={2020}
63 | }
64 | ```
65 | 
--------------------------------------------------------------------------------
/seac/envs.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import gym
4 | import numpy as np
5 | import torch
6 | from gym.spaces.box import Box
7 | from gym.wrappers import Monitor
8 | 
9 | from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnvWrapper
10 | from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
11 | from stable_baselines3.common.vec_env.vec_normalize import VecNormalize as VecNormalize_
12 | from wrappers import TimeLimit, Monitor
13 | 
14 | 
15 | class MADummyVecEnv(DummyVecEnv):
16 |     def __init__(self, env_fns):
17 |         super().__init__(env_fns)
18 |         agents = len(self.observation_space)
19 |         # change this because we want >1 reward
20 |         self.buf_rews = np.zeros((self.num_envs, agents), dtype=np.float32)
21 | 
22 | def make_env(env_id, seed, rank, time_limit, wrappers, monitor_dir):
23 |     def _thunk():
24 | 
25 |         env = gym.make(env_id)
26 |         env.seed(seed + rank)
27 | 
28 |         if time_limit:
29 |             env = TimeLimit(env, time_limit)
30 |         for wrapper in wrappers:
31 |             env = wrapper(env)
32 | 
33 |         if monitor_dir:
34 |             env = Monitor(env, monitor_dir, lambda ep: int(ep==0), force=True, uid=str(rank))
35 | 
36 |         return env
37 | 
38 |     return _thunk
39 | 
40 | 
41 | def make_vec_envs(
42 |     env_name, seed, dummy_vecenv, parallel, time_limit, wrappers, device, monitor_dir=None
43 | ):
44 |     envs = [
45 |         make_env(env_name,
seed, i, time_limit, wrappers,monitor_dir) for i in range(parallel) 46 | ] 47 | 48 | if dummy_vecenv or len(envs) == 1 or monitor_dir: 49 | envs = MADummyVecEnv(envs) 50 | else: 51 | envs = SubprocVecEnv(envs, start_method="fork") 52 | 53 | envs = VecPyTorch(envs, device) 54 | return envs 55 | 56 | 57 | class VecPyTorch(VecEnvWrapper): 58 | def __init__(self, venv, device): 59 | """Return only every `skip`-th frame""" 60 | super(VecPyTorch, self).__init__(venv) 61 | self.device = device 62 | # TODO: Fix data types 63 | 64 | def reset(self): 65 | obs = self.venv.reset() 66 | return [torch.from_numpy(o).to(self.device) for o in obs] 67 | return obs 68 | 69 | def step_async(self, actions): 70 | actions = [a.squeeze().cpu().numpy() for a in actions] 71 | actions = list(zip(*actions)) 72 | return self.venv.step_async(actions) 73 | 74 | def step_wait(self): 75 | obs, rew, done, info = self.venv.step_wait() 76 | return ( 77 | [torch.from_numpy(o).float().to(self.device) for o in obs], 78 | torch.from_numpy(rew).float().to(self.device), 79 | torch.from_numpy(done).float().to(self.device), 80 | info, 81 | ) 82 | 83 | -------------------------------------------------------------------------------- /seql/marl_algorithm.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | class MarlAlgorithm: 6 | """ 7 | abstract class for MARL algorithm 8 | """ 9 | 10 | def __init__(self, n_agents, observation_sizes, action_sizes, params): 11 | """ 12 | Initialise parameters for MARL training 13 | :param n_agents: number of agents 14 | :param observation_sizes: dimension of observation for each agent 15 | :param action_sizes: dimension of action for each agent 16 | :param params: parsed arglist parameter list 17 | """ 18 | self.n_agents = n_agents 19 | self.observation_sizes = observation_sizes 20 | self.action_sizes = action_sizes 21 | self.params = params 22 | self.batch_size = params.batch_size 23 | self.gamma = params.gamma 24 | self.tau = params.tau 25 | self.learning_rate = params.lr 26 | self.epsilon = params.epsilon 27 | self.decay_factor = params.decay_factor 28 | self.seed = params.seed 29 | 30 | if self.seed is not None: 31 | random.seed(self.seed) 32 | np.random.seed(self.seed) 33 | torch.manual_seed(self.seed) 34 | torch.cuda.manual_seed(self.seed) 35 | if torch.cuda.is_available(): 36 | torch.backends.cudnn.deterministic = True 37 | torch.backends.cudnn.benchmark = False 38 | 39 | self.t_steps = 0 40 | 41 | def reset(self, episode): 42 | """ 43 | Reset algorithm for new episode 44 | :param episode: new episode number 45 | """ 46 | raise NotImplementedError 47 | 48 | def step(self, observations, explore=False, available_actions=None): 49 | """ 50 | Take a step forward in environment with all agents 51 | :param observations: list of observations for each agent 52 | :param explore: flag whether or not to add exploration noise 53 | :param available_actions: binary vector (n_agents, n_actions) where each list contains 54 | binary values indicating whether action is applicable 55 | :return: list of actions for each agent 56 | """ 57 | raise NotImplementedError 58 | 59 | def update(self, memory, use_cuda=False): 60 | """ 61 | Train agent models based on memory samples 62 | :param memory: replay buffer memory to sample experience from 63 | :param use_cuda: flag if cuda/ gpus should be used 64 | :return: tuple of loss lists 65 | """ 66 | raise NotImplementedError 67 | 68 | def load_model_networks(self, directory, 
extension="_final"): 69 | """ 70 | Load model networks of all agents 71 | :param directory: path to directory where to load models from 72 | """ 73 | raise NotImplementedError 74 | -------------------------------------------------------------------------------- /seac/run.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import subprocess 3 | from pathlib import Path 4 | from itertools import product 5 | from collections import defaultdict 6 | import re 7 | 8 | import click 9 | 10 | _CPU_COUNT = multiprocessing.cpu_count() - 1 11 | 12 | 13 | def _find_named_configs(): 14 | configs = defaultdict(list) 15 | for c in Path("configs/").glob("**/*.yaml"): 16 | parent = str(c.relative_to("configs/").parent) 17 | name = c.stem 18 | if parent == ".": 19 | parent = None 20 | configs[parent].append(name) 21 | return configs 22 | 23 | 24 | _NAMED_CONFIGS = _find_named_configs() 25 | 26 | 27 | def _get_ingredient_from_mask(mask): 28 | if "/" in mask: 29 | return mask.split("/") 30 | return None, mask 31 | 32 | 33 | def _validate_config_mask(ctx, param, values): 34 | for v in values: 35 | ingredient, _ = _get_ingredient_from_mask(v) 36 | if ingredient not in _NAMED_CONFIGS: 37 | raise click.BadParameter( 38 | f"Invalid ingredient '{ingredient}'. Valid ingredients are: {list(_NAMED_CONFIGS.keys())}" 39 | ) 40 | return values 41 | 42 | 43 | def _filter_configs(configs, mask): 44 | ingredient, mask = _get_ingredient_from_mask(mask) 45 | regex = re.compile(mask) 46 | configs[ingredient] = list(filter(regex.search, configs[ingredient])) 47 | return configs 48 | 49 | 50 | def work(cmd): 51 | cmd = cmd.split(" ") 52 | return subprocess.call(cmd, shell=False) 53 | 54 | 55 | @click.command() 56 | @click.option("--seeds", default=3, show_default=True, help="How many seeds to run") 57 | @click.option( 58 | "--cpus", 59 | default=_CPU_COUNT, 60 | show_default=True, 61 | help="How many processes to run in parallel", 62 | ) 63 | @click.option( 64 | "--config-mask", 65 | "-c", 66 | multiple=True, 67 | callback=_validate_config_mask, 68 | help="Regex mask to filter configs/. Ingredient separator with forward slash \ 69 | '/'. E.g. 'algorithm/rware*'. By default all configs found are used.", 70 | ) 71 | def main(seeds, cpus, config_mask): 72 | pool = multiprocessing.Pool(processes=cpus) 73 | 74 | configs = _NAMED_CONFIGS 75 | for mask in config_mask: 76 | configs = _filter_configs(configs, mask) 77 | configs = [[f"{k}.{i}" if k else str(i) for i in v] for k, v in configs.items()] 78 | configs += [[f"seed={seed}" for seed in range(seeds)]] 79 | 80 | click.echo("Running following combinations: ") 81 | click.echo(click.style(" X ", fg="red", bold=True).join([str(s) for s in configs])) 82 | 83 | configs = list(product(*configs)) 84 | if len(configs) == 0: 85 | click.echo("No valid combinations. Aborted!") 86 | exit(1) 87 | 88 | click.confirm( 89 | f"There are {click.style(str(len(configs)), fg='red')} combinations of configurations. Up to {cpus} will run in parallel. 
Continue?", 90 | abort=True, 91 | ) 92 | 93 | configs = [ 94 | "python train.py -u with dummy_vecenv=True " + " ".join(c) for c in configs 95 | ] 96 | 97 | print(pool.map(work, configs)) 98 | 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- /seql/lbf_train.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | 5 | import gym 6 | 7 | import lbforaging 8 | 9 | from wrappers import RecordEpisodeStatistics 10 | 11 | from train import Train 12 | 13 | 14 | class LBFTrain(Train): 15 | """ 16 | Training environment for the level-based foraging environment (LBF) 17 | """ 18 | 19 | def __init__(self): 20 | """ 21 | Create LBF Train instance 22 | """ 23 | super(LBFTrain, self).__init__() 24 | 25 | def parse_args(self): 26 | """ 27 | parse own arguments including default args and rware specific args 28 | """ 29 | self.parse_default_args() 30 | self.parser.add_argument( 31 | "--env", type=str, default=None, help="name of the lbf environment" 32 | ) 33 | 34 | def create_environment(self): 35 | """ 36 | Create environment instance 37 | :return: environment (gym interface), env_name, task_name, n_agents, observation_sizes, 38 | action_sizes 39 | """ 40 | # load scenario from script 41 | env = gym.make(self.arglist.env) 42 | env = RecordEpisodeStatistics(env, deque_size=10) 43 | 44 | task_name = self.arglist.env 45 | 46 | n_agents = env.n_agents 47 | 48 | print("Observation spaces: ", [env.observation_space[i] for i in range(n_agents)]) 49 | print("Action spaces: ", [env.action_space[i] for i in range(n_agents)]) 50 | observation_sizes = self.extract_sizes(env.observation_space) 51 | action_sizes = self.extract_sizes(env.action_space) 52 | 53 | return ( 54 | env, 55 | "lbf", 56 | task_name, 57 | n_agents, 58 | env.observation_space, 59 | env.action_space, 60 | observation_sizes, 61 | action_sizes, 62 | ) 63 | 64 | def reset_environment(self): 65 | """ 66 | Reset environment for new episode 67 | :return: observation (as torch tensor) 68 | """ 69 | obs = self.env.reset() 70 | obs = [np.expand_dims(o, axis=0) for o in obs] 71 | return obs 72 | 73 | def select_actions(self, obs, explore=True): 74 | """ 75 | Select actions for agents 76 | :param obs: joint observations for agents 77 | :return: actions, onehot_actions 78 | """ 79 | # get actions as torch Variables 80 | torch_agent_actions = self.alg.step(obs, explore) 81 | # convert actions to numpy arrays 82 | onehot_actions = [ac.data.numpy() for ac in torch_agent_actions] 83 | # convert onehot to ints 84 | actions = np.argmax(onehot_actions, axis=-1) 85 | 86 | return actions, onehot_actions 87 | 88 | def environment_step(self, actions): 89 | """ 90 | Take step in the environment 91 | :param actions: actions to apply for each agent 92 | :return: reward, done, next_obs (as Pytorch tensors), info 93 | """ 94 | # environment step 95 | next_obs, reward, done, info = self.env.step(actions) 96 | next_obs = [np.expand_dims(o, axis=0) for o in next_obs] 97 | return reward, done, next_obs, info 98 | 99 | def environment_render(self): 100 | """ 101 | Render visualisation of environment 102 | """ 103 | self.env.render() 104 | time.sleep(0.1) 105 | 106 | 107 | if __name__ == "__main__": 108 | train = LBFTrain() 109 | train.train() 110 | -------------------------------------------------------------------------------- /seql/rware_train.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | 5 | import gym 6 | 7 | import robotic_warehouse 8 | 9 | from wrappers import RecordEpisodeStatistics 10 | 11 | from train import Train 12 | 13 | 14 | class RWARETrain(Train): 15 | """ 16 | Training environment for the Robotic warehouse environment (RWARE) 17 | """ 18 | 19 | def __init__(self): 20 | """ 21 | Create RWARE Train instance 22 | """ 23 | super(RWARETrain, self).__init__() 24 | 25 | def parse_args(self): 26 | """ 27 | parse own arguments including default args and rware specific args 28 | """ 29 | self.parse_default_args() 30 | self.parser.add_argument( 31 | "--env", type=str, default=None, help="name of the rware environment" 32 | ) 33 | 34 | def create_environment(self): 35 | """ 36 | Create environment instance 37 | :return: environment (gym interface), env_name, task_name, n_agents, observation_sizes, 38 | action_sizes 39 | """ 40 | # load scenario from script 41 | env = gym.make(self.arglist.env) 42 | env = RecordEpisodeStatistics(env, deque_size=10) 43 | 44 | task_name = self.arglist.env 45 | 46 | n_agents = env.n_agents 47 | 48 | print("Observation spaces: ", [env.observation_space[i] for i in range(n_agents)]) 49 | print("Action spaces: ", [env.action_space[i] for i in range(n_agents)]) 50 | observation_sizes = self.extract_sizes(env.observation_space) 51 | action_sizes = self.extract_sizes(env.action_space) 52 | 53 | return ( 54 | env, 55 | "rware", 56 | task_name, 57 | n_agents, 58 | env.observation_space, 59 | env.action_space, 60 | observation_sizes, 61 | action_sizes, 62 | ) 63 | 64 | def reset_environment(self): 65 | """ 66 | Reset environment for new episode 67 | :return: observation (as torch tensor) 68 | """ 69 | obs = self.env.reset() 70 | obs = [np.expand_dims(o, axis=0) for o in obs] 71 | return obs 72 | 73 | def select_actions(self, obs, explore=True): 74 | """ 75 | Select actions for agents 76 | :param obs: joint observations for agents 77 | :return: actions, onehot_actions 78 | """ 79 | # get actions as torch Variables 80 | torch_agent_actions = self.alg.step(obs, explore) 81 | # convert actions to numpy arrays 82 | onehot_actions = [ac.data.numpy() for ac in torch_agent_actions] 83 | # convert onehot to ints 84 | actions = np.argmax(onehot_actions, axis=-1) 85 | 86 | return actions, onehot_actions 87 | 88 | def environment_step(self, actions): 89 | """ 90 | Take step in the environment 91 | :param actions: actions to apply for each agent 92 | :return: reward, done, next_obs (as Pytorch tensors), info 93 | """ 94 | # environment step 95 | next_obs, reward, done, info = self.env.step(actions) 96 | next_obs = [np.expand_dims(o, axis=0) for o in next_obs] 97 | return reward, done, next_obs, info 98 | 99 | def environment_render(self): 100 | """ 101 | Render visualisation of environment 102 | """ 103 | self.env.render() 104 | time.sleep(0.1) 105 | 106 | 107 | if __name__ == "__main__": 108 | train = RWARETrain() 109 | train.train() 110 | -------------------------------------------------------------------------------- /seql/marl_utils.py: -------------------------------------------------------------------------------- 1 | # https://github.com/shariqiqbal2810/maddpg-pytorch/blob/master/utils/misc.py 2 | 3 | import os 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.distributed as dist 7 | from torch.autograd import Variable 8 | import numpy as np 9 | 10 | # 
https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L11 11 | def soft_update(target, source, tau): 12 | """ 13 | Perform DDPG soft update (move target params toward source based on weight 14 | factor tau) 15 | Inputs: 16 | target (torch.nn.Module): Net to copy parameters to 17 | source (torch.nn.Module): Net whose parameters to copy 18 | tau (float, 0 < x < 1): Weight factor for update 19 | """ 20 | for target_param, param in zip(target.parameters(), source.parameters()): 21 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 22 | 23 | 24 | # https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L15 25 | def hard_update(target, source): 26 | """ 27 | Copy network parameters from source to target 28 | Inputs: 29 | target (torch.nn.Module): Net to copy parameters to 30 | source (torch.nn.Module): Net whose parameters to copy 31 | """ 32 | for target_param, param in zip(target.parameters(), source.parameters()): 33 | target_param.data.copy_(param.data) 34 | 35 | 36 | def onehot_from_logits(logits, eps=0.0): 37 | """ 38 | Given batch of logits, return one-hot sample using epsilon greedy strategy 39 | (based on given epsilon) 40 | """ 41 | # get best (according to current policy) actions in one-hot form 42 | argmax_acs = (logits == logits.max(1, keepdim=True)[0]).float() 43 | if eps == 0.0: 44 | return argmax_acs 45 | # get random actions in one-hot form 46 | rand_acs = Variable( 47 | torch.eye(logits.shape[1])[ 48 | [np.random.choice(range(logits.shape[1]), size=logits.shape[0])] 49 | ], 50 | requires_grad=False, 51 | ) 52 | # chooses between best and random actions using epsilon greedy 53 | return torch.stack( 54 | [ 55 | argmax_acs[i] if r > eps else rand_acs[i] 56 | for i, r in enumerate(torch.rand(logits.shape[0])) 57 | ] 58 | ) 59 | 60 | 61 | # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb 62 | def sample_gumbel(shape, eps=1e-20, use_cuda=False): 63 | """Sample from Gumbel(0, 1)""" 64 | if use_cuda: 65 | tens_type = torch.cuda.FloatTensor 66 | else: 67 | tens_type = torch.FloatTensor 68 | U = Variable(tens_type(*shape).uniform_(), requires_grad=False) 69 | return -torch.log(-torch.log(U + eps) + eps) 70 | 71 | 72 | # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb 73 | def gumbel_softmax_sample(logits, temperature): 74 | """ Draw a sample from the Gumbel-Softmax distribution""" 75 | y = logits + sample_gumbel(logits.shape, use_cuda=logits.is_cuda) 76 | return F.softmax(y / temperature, dim=1) 77 | 78 | 79 | # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb 80 | def gumbel_softmax(logits, temperature=1.0, hard=False): 81 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 82 | Args: 83 | logits: [batch_size, n_class] unnormalized log-probs 84 | temperature: non-negative scalar 85 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 86 | Returns: 87 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 
88 | If hard=True, then the returned sample will be one-hot, otherwise it will 89 | be a probabilitiy distribution that sums to 1 across classes 90 | """ 91 | y = gumbel_softmax_sample(logits, temperature) 92 | if hard: 93 | y_hard = onehot_from_logits(y) 94 | y = (y_hard - y).detach() + y 95 | return y 96 | -------------------------------------------------------------------------------- /seql/agent.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | 4 | import numpy as np 5 | 6 | import torch 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | 10 | from model import QNetwork 11 | from marl_utils import hard_update, soft_update, onehot_from_logits 12 | 13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 14 | FloatTensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor 15 | 16 | 17 | class Agent: 18 | """ 19 | Class for individual IQL agent 20 | """ 21 | 22 | def __init__(self, observation_size, action_size, params): 23 | """ 24 | Initialise parameters for agent 25 | :param observation_size: dimensions of observations 26 | :param action_size: dimensions of actions 27 | :param params: parsed arglist parameter list 28 | """ 29 | self.observation_size = observation_size 30 | self.action_size = action_size 31 | self.params = params 32 | 33 | self.epsilon = params.epsilon 34 | self.epsilon_anneal_slow = params.epsilon_anneal_slow 35 | if self.epsilon_anneal_slow: 36 | self.goal_epsilon = params.goal_epsilon 37 | self.epsilon_decay = params.epsilon_decay 38 | self.decay_factor = params.decay_factor 39 | self.current_decay = params.decay_factor 40 | else: 41 | self.decay_factor = params.decay_factor 42 | 43 | # create Q-Learning networks 44 | self.model = QNetwork(observation_size, action_size, params.hidden_dim) 45 | 46 | if params.seed is not None: 47 | random.seed(params.seed) 48 | np.random.seed(params.seed) 49 | torch.manual_seed(params.seed) 50 | torch.cuda.manual_seed(params.seed) 51 | if torch.cuda.is_available(): 52 | torch.backends.cudnn.deterministic = True 53 | torch.backends.cudnn.benchmark = False 54 | 55 | self.target_model = QNetwork( 56 | observation_size, action_size, params.hidden_dim 57 | ) 58 | 59 | if params.seed is not None: 60 | random.seed(params.seed) 61 | np.random.seed(params.seed) 62 | torch.manual_seed(params.seed) 63 | torch.cuda.manual_seed(params.seed) 64 | if torch.cuda.is_available(): 65 | torch.backends.cudnn.deterministic = True 66 | torch.backends.cudnn.benchmark = False 67 | 68 | hard_update(self.target_model, self.model) 69 | 70 | # create optimizer 71 | self.optimizer = optim.Adam(self.model.parameters(), lr=params.lr) 72 | 73 | if params.seed is not None: 74 | random.seed(params.seed) 75 | np.random.seed(params.seed) 76 | torch.manual_seed(params.seed) 77 | torch.cuda.manual_seed(params.seed) 78 | if torch.cuda.is_available(): 79 | torch.backends.cudnn.deterministic = True 80 | torch.backends.cudnn.benchmark = False 81 | 82 | self.t_step = 0 83 | 84 | def step(self, obs, explore=False, available_actions=None): 85 | """ 86 | Take a step forward in environment for a minibatch of observations 87 | :param obs (PyTorch Variable): Observations for this agent 88 | :param explore (boolean): Whether or not to add exploration noise 89 | :param available_actions: binary vector (n_agents, n_actions) where each list contains 90 | binary values indicating whether action is applicable 91 | :return: action (PyTorch Variable) 
Actions for this agent 92 | """ 93 | qvals = self.model(obs) 94 | self.t_step += 1 95 | 96 | if available_actions is not None: 97 | assert self.discrete_actions 98 | available_mask = torch.ByteTensor(list(map(lambda a: a == 1, available_actions))) 99 | negative_tensor = torch.ones(qvals.shape) * -1e9 100 | negative_tensor[:, available_mask] = qvals[:, available_mask] 101 | qvals = negative_tensor 102 | if explore: 103 | action = onehot_from_logits(qvals, self.epsilon) 104 | else: 105 | # use small epsilon in evaluation even 106 | action = onehot_from_logits(qvals, 0.01) 107 | 108 | if self.epsilon_anneal_slow: 109 | self.current_decay *= self.decay_factor 110 | self.epsilon = max(0.1 + (self.epsilon_decay - self.current_decay)/ self.epsilon_decay, self.goal_epsilon) 111 | else: 112 | self.epsilon *= self.decay_factor 113 | 114 | return action 115 | -------------------------------------------------------------------------------- /seql/wrappers.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from collections import deque 4 | from time import perf_counter 5 | 6 | import gym 7 | import numpy as np 8 | from gym import ObservationWrapper, spaces 9 | from gym.wrappers import TimeLimit as GymTimeLimit 10 | from gym.wrappers import Monitor as GymMonitor 11 | 12 | 13 | class RecordEpisodeStatistics(gym.Wrapper): 14 | """ Multi-agent version of RecordEpisodeStatistics gym wrapper""" 15 | 16 | def __init__(self, env, deque_size=100): 17 | super().__init__(env) 18 | self.t0 = perf_counter() 19 | self.episode_reward = np.zeros(self.n_agents) 20 | self.episode_length = 0 21 | self.reward_queue = deque(maxlen=deque_size) 22 | self.length_queue = deque(maxlen=deque_size) 23 | 24 | def reset(self, **kwargs): 25 | observation = super().reset(**kwargs) 26 | self.episode_reward = 0 27 | self.episode_length = 0 28 | self.t0 = perf_counter() 29 | 30 | return observation 31 | 32 | def step(self, action): 33 | observation, reward, done, info = super().step(action) 34 | self.episode_reward += np.array(reward, dtype=np.float64) 35 | self.episode_length += 1 36 | if done: 37 | info["episode_reward"] = self.episode_reward 38 | for i, agent_reward in enumerate(self.episode_reward): 39 | info[f"agent{i}/episode_reward"] = agent_reward 40 | info["episode_length"] = self.episode_length 41 | info["episode_time"] = perf_counter() - self.t0 42 | 43 | self.reward_queue.append(self.episode_reward) 44 | self.length_queue.append(self.episode_length) 45 | return observation, reward, done, info 46 | 47 | 48 | class FlattenObservation(ObservationWrapper): 49 | r"""Observation wrapper that flattens the observation of individual agents.""" 50 | 51 | def __init__(self, env): 52 | super(FlattenObservation, self).__init__(env) 53 | 54 | ma_spaces = [] 55 | 56 | for sa_obs in env.observation_space: 57 | flatdim = spaces.flatdim(sa_obs) 58 | ma_spaces += [ 59 | spaces.Box( 60 | low=-float("inf"), 61 | high=float("inf"), 62 | shape=(flatdim,), 63 | dtype=np.float32, 64 | ) 65 | ] 66 | 67 | self.observation_space = spaces.Tuple(tuple(ma_spaces)) 68 | 69 | def observation(self, observation): 70 | return tuple([ 71 | spaces.flatten(obs_space, obs) 72 | for obs_space, obs in zip(self.env.observation_space, observation) 73 | ]) 74 | 75 | 76 | class SquashDones(gym.Wrapper): 77 | r"""Wrapper that squashes multiple dones to a single one using all(dones)""" 78 | 79 | def step(self, action): 80 | observation, reward, done, info = self.env.step(action) 81 | return observation, reward, 
all(done), info 82 | 83 | 84 | class GlobalizeReward(gym.RewardWrapper): 85 | def reward(self, reward): 86 | return self.n_agents * [sum(reward)] 87 | 88 | 89 | class TimeLimit(GymTimeLimit): 90 | def __init__(self, env, max_episode_steps=None): 91 | super().__init__(env) 92 | if max_episode_steps is None and self.env.spec is not None: 93 | max_episode_steps = env.spec.max_episode_steps 94 | # if self.env.spec is not None: 95 | # self.env.spec.max_episode_steps = max_episode_steps 96 | self._max_episode_steps = max_episode_steps 97 | self._elapsed_steps = None 98 | 99 | def step(self, action): 100 | assert self._elapsed_steps is not None, "Cannot call env.step() before calling reset()" 101 | observation, reward, done, info = self.env.step(action) 102 | self._elapsed_steps += 1 103 | if self._elapsed_steps >= self._max_episode_steps: 104 | info['TimeLimit.truncated'] = not all(done) 105 | done = len(observation) * [True] 106 | return observation, reward, done, info 107 | 108 | class ClearInfo(gym.Wrapper): 109 | def step(self, action): 110 | observation, reward, done, info = self.env.step(action) 111 | return observation, reward, done, {} 112 | 113 | 114 | class Monitor(GymMonitor): 115 | def _after_step(self, observation, reward, done, info): 116 | if not self.enabled: return done 117 | 118 | if done and self.env_semantics_autoreset: 119 | # For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode 120 | self.reset_video_recorder() 121 | self.episode_id += 1 122 | self._flush() 123 | 124 | # Record stats 125 | self.stats_recorder.after_step(observation, sum(reward), done, info) 126 | # Record video 127 | self.video_recorder.capture_frame() 128 | 129 | return done -------------------------------------------------------------------------------- /seac/wrappers.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from collections import deque 4 | from time import perf_counter 5 | 6 | import gym 7 | import numpy as np 8 | from gym import ObservationWrapper, spaces 9 | from gym.wrappers import TimeLimit as GymTimeLimit 10 | from gym.wrappers import Monitor as GymMonitor 11 | 12 | 13 | class RecordEpisodeStatistics(gym.Wrapper): 14 | """ Multi-agent version of RecordEpisodeStatistics gym wrapper""" 15 | 16 | def __init__(self, env, deque_size=100): 17 | super().__init__(env) 18 | self.t0 = perf_counter() 19 | self.episode_reward = np.zeros(self.n_agents) 20 | self.episode_length = 0 21 | self.reward_queue = deque(maxlen=deque_size) 22 | self.length_queue = deque(maxlen=deque_size) 23 | 24 | def reset(self, **kwargs): 25 | observation = super().reset(**kwargs) 26 | self.episode_reward = 0 27 | self.episode_length = 0 28 | self.t0 = perf_counter() 29 | 30 | return observation 31 | 32 | def step(self, action): 33 | observation, reward, done, info = super().step(action) 34 | self.episode_reward += np.array(reward, dtype=np.float64) 35 | self.episode_length += 1 36 | if all(done): 37 | info["episode_reward"] = self.episode_reward 38 | for i, agent_reward in enumerate(self.episode_reward): 39 | info[f"agent{i}/episode_reward"] = agent_reward 40 | info["episode_length"] = self.episode_length 41 | info["episode_time"] = perf_counter() - self.t0 42 | 43 | self.reward_queue.append(self.episode_reward) 44 | self.length_queue.append(self.episode_length) 45 | return observation, reward, done, info 46 | 47 | 48 | class FlattenObservation(ObservationWrapper): 49 | r"""Observation wrapper that flattens the 
observation of individual agents.""" 50 | 51 | def __init__(self, env): 52 | super(FlattenObservation, self).__init__(env) 53 | 54 | ma_spaces = [] 55 | 56 | for sa_obs in env.observation_space: 57 | flatdim = spaces.flatdim(sa_obs) 58 | ma_spaces += [ 59 | spaces.Box( 60 | low=-float("inf"), 61 | high=float("inf"), 62 | shape=(flatdim,), 63 | dtype=np.float32, 64 | ) 65 | ] 66 | 67 | self.observation_space = spaces.Tuple(tuple(ma_spaces)) 68 | 69 | def observation(self, observation): 70 | return tuple([ 71 | spaces.flatten(obs_space, obs) 72 | for obs_space, obs in zip(self.env.observation_space, observation) 73 | ]) 74 | 75 | 76 | class SquashDones(gym.Wrapper): 77 | r"""Wrapper that squashes multiple dones to a single one using all(dones)""" 78 | 79 | def step(self, action): 80 | observation, reward, done, info = self.env.step(action) 81 | return observation, reward, all(done), info 82 | 83 | 84 | class GlobalizeReward(gym.RewardWrapper): 85 | def reward(self, reward): 86 | return self.n_agents * [sum(reward)] 87 | 88 | 89 | class TimeLimit(GymTimeLimit): 90 | def __init__(self, env, max_episode_steps=None): 91 | super().__init__(env) 92 | if max_episode_steps is None and self.env.spec is not None: 93 | max_episode_steps = env.spec.max_episode_steps 94 | # if self.env.spec is not None: 95 | # self.env.spec.max_episode_steps = max_episode_steps 96 | self._max_episode_steps = max_episode_steps 97 | self._elapsed_steps = None 98 | 99 | def step(self, action): 100 | assert self._elapsed_steps is not None, "Cannot call env.step() before calling reset()" 101 | observation, reward, done, info = self.env.step(action) 102 | self._elapsed_steps += 1 103 | if self._elapsed_steps >= self._max_episode_steps: 104 | info['TimeLimit.truncated'] = not all(done) 105 | done = len(observation) * [True] 106 | return observation, reward, done, info 107 | 108 | class ClearInfo(gym.Wrapper): 109 | def step(self, action): 110 | observation, reward, done, info = self.env.step(action) 111 | return observation, reward, done, {} 112 | 113 | 114 | class Monitor(GymMonitor): 115 | def _after_step(self, observation, reward, done, info): 116 | if not self.enabled: return done 117 | 118 | if all(done) and self.env_semantics_autoreset: 119 | # For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode 120 | self.reset_video_recorder() 121 | self.episode_id += 1 122 | self._flush() 123 | 124 | # Record stats 125 | self.stats_recorder.after_step(observation, sum(reward), all(done), info) 126 | # Record video 127 | self.video_recorder.capture_frame() 128 | 129 | return done 130 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | seac/results 2 | seql/models 3 | seql/logs 4 | # Created by https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode 5 | # Edit at https://www.gitignore.io/?templates=linux,python,windows,pycharm+all,visualstudiocode 6 | .vscode 7 | ### Linux ### 8 | *~ 9 | 10 | # temporary files which can be created if a process still has a handle open of a deleted file 11 | .fuse_hidden* 12 | 13 | # KDE directory preferences 14 | .directory 15 | 16 | # Linux trash folder which might appear on any partition or disk 17 | .Trash-* 18 | 19 | # .nfs files are created when an open file is removed but is still being accessed 20 | .nfs* 21 | 22 | ### PyCharm+all ### 23 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, 
AppCode, PyCharm, CLion, Android Studio and WebStorm 24 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 25 | 26 | # User-specific stuff 27 | .idea/**/workspace.xml 28 | .idea/**/tasks.xml 29 | .idea/**/usage.statistics.xml 30 | .idea/**/dictionaries 31 | .idea/**/shelf 32 | 33 | # Generated files 34 | .idea/**/contentModel.xml 35 | 36 | # Sensitive or high-churn files 37 | .idea/**/dataSources/ 38 | .idea/**/dataSources.ids 39 | .idea/**/dataSources.local.xml 40 | .idea/**/sqlDataSources.xml 41 | .idea/**/dynamic.xml 42 | .idea/**/uiDesigner.xml 43 | .idea/**/dbnavigator.xml 44 | 45 | # Gradle 46 | .idea/**/gradle.xml 47 | .idea/**/libraries 48 | 49 | # Gradle and Maven with auto-import 50 | # When using Gradle or Maven with auto-import, you should exclude module files, 51 | # since they will be recreated, and may cause churn. Uncomment if using 52 | # auto-import. 53 | # .idea/modules.xml 54 | # .idea/*.iml 55 | # .idea/modules 56 | # *.iml 57 | # *.ipr 58 | 59 | # CMake 60 | cmake-build-*/ 61 | 62 | # Mongo Explorer plugin 63 | .idea/**/mongoSettings.xml 64 | 65 | # File-based project format 66 | *.iws 67 | 68 | # IntelliJ 69 | out/ 70 | 71 | # mpeltonen/sbt-idea plugin 72 | .idea_modules/ 73 | 74 | # JIRA plugin 75 | atlassian-ide-plugin.xml 76 | 77 | # Cursive Clojure plugin 78 | .idea/replstate.xml 79 | 80 | # Crashlytics plugin (for Android Studio and IntelliJ) 81 | com_crashlytics_export_strings.xml 82 | crashlytics.properties 83 | crashlytics-build.properties 84 | fabric.properties 85 | 86 | # Editor-based Rest Client 87 | .idea/httpRequests 88 | 89 | # Android studio 3.1+ serialized cache file 90 | .idea/caches/build_file_checksums.ser 91 | 92 | ### PyCharm+all Patch ### 93 | # Ignores the whole .idea folder and all .iml files 94 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 95 | 96 | .idea/ 97 | 98 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 99 | 100 | *.iml 101 | modules.xml 102 | .idea/misc.xml 103 | *.ipr 104 | 105 | # Sonarlint plugin 106 | .idea/sonarlint 107 | 108 | ### Python ### 109 | # Byte-compiled / optimized / DLL files 110 | __pycache__/ 111 | *.py[cod] 112 | *$py.class 113 | 114 | # C extensions 115 | *.so 116 | 117 | # Distribution / packaging 118 | .Python 119 | build/ 120 | develop-eggs/ 121 | dist/ 122 | downloads/ 123 | eggs/ 124 | .eggs/ 125 | lib/ 126 | lib64/ 127 | parts/ 128 | sdist/ 129 | var/ 130 | wheels/ 131 | pip-wheel-metadata/ 132 | share/python-wheels/ 133 | *.egg-info/ 134 | .installed.cfg 135 | *.egg 136 | MANIFEST 137 | 138 | # PyInstaller 139 | # Usually these files are written by a python script from a template 140 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 141 | *.manifest 142 | *.spec 143 | 144 | # Installer logs 145 | pip-log.txt 146 | pip-delete-this-directory.txt 147 | 148 | # Unit test / coverage reports 149 | htmlcov/ 150 | .tox/ 151 | .nox/ 152 | .coverage 153 | .coverage.* 154 | .cache 155 | nosetests.xml 156 | coverage.xml 157 | *.cover 158 | .hypothesis/ 159 | .pytest_cache/ 160 | 161 | # Translations 162 | *.mo 163 | *.pot 164 | 165 | # Scrapy stuff: 166 | .scrapy 167 | 168 | # Sphinx documentation 169 | docs/_build/ 170 | 171 | # PyBuilder 172 | target/ 173 | 174 | # pyenv 175 | .python-version 176 | 177 | # pipenv 178 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
179 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 180 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 181 | # install all needed dependencies. 182 | #Pipfile.lock 183 | 184 | # celery beat schedule file 185 | celerybeat-schedule 186 | 187 | # SageMath parsed files 188 | *.sage.py 189 | 190 | # Spyder project settings 191 | .spyderproject 192 | .spyproject 193 | 194 | # Rope project settings 195 | .ropeproject 196 | 197 | # Mr Developer 198 | .mr.developer.cfg 199 | .project 200 | .pydevproject 201 | 202 | # mkdocs documentation 203 | /site 204 | 205 | # mypy 206 | .mypy_cache/ 207 | .dmypy.json 208 | dmypy.json 209 | 210 | # Pyre type checker 211 | .pyre/ 212 | 213 | ### VisualStudioCode ### 214 | .vscode/* 215 | !.vscode/settings.json 216 | !.vscode/tasks.json 217 | !.vscode/launch.json 218 | !.vscode/extensions.json 219 | 220 | ### VisualStudioCode Patch ### 221 | # Ignore all local history of files 222 | .history 223 | 224 | ### Windows ### 225 | # Windows thumbnail cache files 226 | Thumbs.db 227 | Thumbs.db:encryptable 228 | ehthumbs.db 229 | ehthumbs_vista.db 230 | 231 | # Dump file 232 | *.stackdump 233 | 234 | # Folder config file 235 | [Dd]esktop.ini 236 | 237 | # Recycle Bin used on file shares 238 | $RECYCLE.BIN/ 239 | 240 | # Windows Installer files 241 | *.cab 242 | *.msi 243 | *.msix 244 | *.msm 245 | *.msp 246 | 247 | # Windows shortcuts 248 | *.lnk 249 | 250 | # End of https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode -------------------------------------------------------------------------------- /seql/baseline_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | 5 | class ReplayBuffer(object): 6 | def __init__(self, size): 7 | """Create Replay buffer. 8 | Parameters 9 | ---------- 10 | size: int 11 | Max number of transitions to store in the buffer. When the buffer 12 | overflows the old memories are dropped. 13 | """ 14 | self._storage = [] 15 | self._maxsize = size 16 | self._next_idx = 0 17 | 18 | def __len__(self): 19 | return len(self._storage) 20 | 21 | def add(self, obs_t, action, reward, obs_tp1, done): 22 | data = (obs_t, action, reward, obs_tp1, done) 23 | 24 | if self._next_idx >= len(self._storage): 25 | self._storage.append(data) 26 | else: 27 | self._storage[self._next_idx] = data 28 | self._next_idx = (self._next_idx + 1) % self._maxsize 29 | 30 | def _encode_sample(self, idxes): 31 | obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] 32 | for i in idxes: 33 | data = self._storage[i] 34 | obs_t, action, reward, obs_tp1, done = data 35 | obses_t.append(np.array(obs_t, copy=False)) 36 | actions.append(np.array(action, copy=False)) 37 | rewards.append(reward) 38 | obses_tp1.append(np.array(obs_tp1, copy=False)) 39 | dones.append(done) 40 | return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones) 41 | 42 | def sample(self, batch_size): 43 | """Sample a batch of experiences. 44 | Parameters 45 | ---------- 46 | batch_size: int 47 | How many transitions to sample. 
48 | Returns 49 | ------- 50 | obs_batch: np.array 51 | batch of observations 52 | act_batch: np.array 53 | batch of actions executed given obs_batch 54 | rew_batch: np.array 55 | rewards received as results of executing act_batch 56 | next_obs_batch: np.array 57 | next set of observations seen after executing act_batch 58 | done_mask: np.array 59 | done_mask[i] = 1 if executing act_batch[i] resulted in 60 | the end of an episode and 0 otherwise. 61 | """ 62 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 63 | return self._encode_sample(idxes) 64 | 65 | 66 | class MARLReplayBuffer(object): 67 | def __init__(self, size, num_agents): 68 | """Create Replay buffer. 69 | Parameters 70 | ---------- 71 | size: int 72 | Max number of transitions to store in the buffer. When the buffer 73 | overflows the old memories are dropped. 74 | num_agents: int 75 | Number of agents 76 | """ 77 | self.size = size 78 | self.num_agents = num_agents 79 | self.buffers = [ReplayBuffer(size) for _ in range(num_agents)] 80 | 81 | def __len__(self): 82 | return len(self.buffers[0]) 83 | 84 | def add(self, observations, actions, rewards, next_observations, dones): 85 | for i, (o, a, r, no, d) in enumerate(zip(observations, actions, rewards, next_observations, dones)): 86 | self.buffers[i].add(o, a, r, no, d) 87 | 88 | def sample(self, batch_size, agent_i): 89 | """Sample a batch of experiences. 90 | Parameters 91 | ---------- 92 | batch_size: int 93 | How many transitions to sample. 94 | agent_i: int 95 | Index of agent to sample for 96 | Returns 97 | ------- 98 | obs_batch: np.array 99 | batch of observations 100 | act_batch: np.array 101 | batch of actions executed given obs_batch 102 | rew_batch: np.array 103 | rewards received as results of executing act_batch 104 | next_obs_batch: np.array 105 | next set of observations seen after executing act_batch 106 | done_mask: np.array 107 | done_mask[i] = 1 if executing act_batch[i] resulted in 108 | the end of an episode and 0 otherwise. 109 | """ 110 | cast = lambda x: torch.from_numpy(x).float() 111 | obs, act, rew, next_obs, done = self.buffers[agent_i].sample(batch_size) 112 | obs = cast(obs).squeeze() 113 | act = cast(act) 114 | rew = cast(rew) 115 | next_obs = cast(next_obs).squeeze() 116 | done = cast(done) 117 | return obs, act, rew, next_obs, done 118 | 119 | def sample_shared(self, batch_size): 120 | """Sample a batch of experiences. 121 | Parameters 122 | ---------- 123 | batch_size: int 124 | How many transitions to sample. 125 | Returns 126 | ------- 127 | obs_batch: np.array 128 | batch of observations 129 | act_batch: np.array 130 | batch of actions executed given obs_batch 131 | rew_batch: np.array 132 | rewards received as results of executing act_batch 133 | next_obs_batch: np.array 134 | next set of observations seen after executing act_batch 135 | done_mask: np.array 136 | done_mask[i] = 1 if executing act_batch[i] resulted in 137 | the end of an episode and 0 otherwise. 
138 | """ 139 | batch_size_each = batch_size // self.num_agents 140 | obs = [] 141 | act = [] 142 | rew = [] 143 | next_obs = [] 144 | done = [] 145 | for agent_i in range(self.num_agents): 146 | o, a, r, no, d = self.buffers[agent_i].sample(batch_size_each) 147 | obs.append(o) 148 | act.append(a) 149 | rew.append(r) 150 | next_obs.append(no) 151 | done.append(d) 152 | cast = lambda x: torch.from_numpy(x).float() 153 | obs = cast(np.vstack(obs)).squeeze() 154 | act = cast(np.vstack(act)) 155 | rew = cast(np.vstack(rew)) 156 | next_obs = cast(np.vstack(next_obs)).squeeze() 157 | done = cast(np.vstack(done)) 158 | return obs, act, rew, next_obs, done 159 | -------------------------------------------------------------------------------- /seac/a2c.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | 8 | import numpy as np 9 | 10 | import gym 11 | from model import Policy, FCNetwork 12 | from gym.spaces.utils import flatdim 13 | from storage import RolloutStorage 14 | from sacred import Ingredient 15 | 16 | algorithm = Ingredient("algorithm") 17 | 18 | 19 | @algorithm.config 20 | def config(): 21 | lr = 3e-4 22 | adam_eps = 0.001 23 | gamma = 0.99 24 | use_gae = False 25 | gae_lambda = 0.95 26 | entropy_coef = 0.01 27 | value_loss_coef = 0.5 28 | max_grad_norm = 0.5 29 | 30 | use_proper_time_limits = True 31 | recurrent_policy = False 32 | use_linear_lr_decay = False 33 | 34 | seac_coef = 1.0 35 | 36 | num_processes = 4 37 | num_steps = 5 38 | 39 | device = "cpu" 40 | 41 | 42 | class A2C: 43 | @algorithm.capture() 44 | def __init__( 45 | self, 46 | agent_id, 47 | obs_space, 48 | action_space, 49 | lr, 50 | adam_eps, 51 | recurrent_policy, 52 | num_steps, 53 | num_processes, 54 | device, 55 | ): 56 | self.agent_id = agent_id 57 | self.obs_size = flatdim(obs_space) 58 | self.action_size = flatdim(action_space) 59 | self.obs_space = obs_space 60 | self.action_space = action_space 61 | 62 | self.model = Policy( 63 | obs_space, action_space, base_kwargs={"recurrent": recurrent_policy}, 64 | ) 65 | 66 | self.storage = RolloutStorage( 67 | obs_space, 68 | action_space, 69 | self.model.recurrent_hidden_state_size, 70 | num_steps, 71 | num_processes, 72 | ) 73 | 74 | self.model.to(device) 75 | self.optimizer = optim.Adam(self.model.parameters(), lr, eps=adam_eps) 76 | 77 | # self.intr_stats = RunningStats() 78 | self.saveables = { 79 | "model": self.model, 80 | "optimizer": self.optimizer, 81 | } 82 | 83 | def save(self, path): 84 | torch.save(self.saveables, os.path.join(path, "models.pt")) 85 | 86 | def restore(self, path): 87 | checkpoint = torch.load(os.path.join(path, "models.pt")) 88 | for k, v in self.saveables.items(): 89 | v.load_state_dict(checkpoint[k].state_dict()) 90 | 91 | @algorithm.capture 92 | def compute_returns(self, use_gae, gamma, gae_lambda, use_proper_time_limits): 93 | with torch.no_grad(): 94 | next_value = self.model.get_value( 95 | self.storage.obs[-1], 96 | self.storage.recurrent_hidden_states[-1], 97 | self.storage.masks[-1], 98 | ).detach() 99 | 100 | self.storage.compute_returns( 101 | next_value, use_gae, gamma, gae_lambda, use_proper_time_limits, 102 | ) 103 | 104 | @algorithm.capture 105 | def update( 106 | self, 107 | storages, 108 | value_loss_coef, 109 | entropy_coef, 110 | seac_coef, 111 | max_grad_norm, 112 | device, 113 | ): 114 | 115 | obs_shape = self.storage.obs.size()[2:] 116 | action_shape = 
self.storage.actions.size()[-1] 117 | num_steps, num_processes, _ = self.storage.rewards.size() 118 | 119 | values, action_log_probs, dist_entropy, _ = self.model.evaluate_actions( 120 | self.storage.obs[:-1].view(-1, *obs_shape), 121 | self.storage.recurrent_hidden_states[0].view( 122 | -1, self.model.recurrent_hidden_state_size 123 | ), 124 | self.storage.masks[:-1].view(-1, 1), 125 | self.storage.actions.view(-1, action_shape), 126 | ) 127 | 128 | values = values.view(num_steps, num_processes, 1) 129 | action_log_probs = action_log_probs.view(num_steps, num_processes, 1) 130 | 131 | advantages = self.storage.returns[:-1] - values 132 | 133 | policy_loss = -(advantages.detach() * action_log_probs).mean() 134 | value_loss = advantages.pow(2).mean() 135 | 136 | 137 | # calculate prediction loss for the OTHER actor 138 | other_agent_ids = [x for x in range(len(storages)) if x != self.agent_id] 139 | seac_policy_loss = 0 140 | seac_value_loss = 0 141 | for oid in other_agent_ids: 142 | 143 | other_values, logp, _, _ = self.model.evaluate_actions( 144 | storages[oid].obs[:-1].view(-1, *obs_shape), 145 | storages[oid] 146 | .recurrent_hidden_states[0] 147 | .view(-1, self.model.recurrent_hidden_state_size), 148 | storages[oid].masks[:-1].view(-1, 1), 149 | storages[oid].actions.view(-1, action_shape), 150 | ) 151 | other_values = other_values.view(num_steps, num_processes, 1) 152 | logp = logp.view(num_steps, num_processes, 1) 153 | other_advantage = ( 154 | storages[oid].returns[:-1] - other_values 155 | ) # or storages[oid].rewards 156 | 157 | importance_sampling = ( 158 | logp.exp() / (storages[oid].action_log_probs.exp() + 1e-7) 159 | ).detach() 160 | # importance_sampling = 1.0 161 | seac_value_loss += ( 162 | importance_sampling * other_advantage.pow(2) 163 | ).mean() 164 | seac_policy_loss += ( 165 | -importance_sampling * logp * other_advantage.detach() 166 | ).mean() 167 | 168 | self.optimizer.zero_grad() 169 | ( 170 | policy_loss 171 | + value_loss_coef * value_loss 172 | - entropy_coef * dist_entropy 173 | + seac_coef * seac_policy_loss 174 | + seac_coef * value_loss_coef * seac_value_loss 175 | ).backward() 176 | 177 | nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm) 178 | 179 | self.optimizer.step() 180 | 181 | return { 182 | "policy_loss": policy_loss.item(), 183 | "value_loss": value_loss_coef * value_loss.item(), 184 | "dist_entropy": entropy_coef * dist_entropy.item(), 185 | "importance_sampling": importance_sampling.mean().item(), 186 | "seac_policy_loss": seac_coef * seac_policy_loss.item(), 187 | "seac_value_loss": seac_coef 188 | * value_loss_coef 189 | * seac_value_loss.item(), 190 | } 191 | -------------------------------------------------------------------------------- /seac/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from distributions import Categorical 7 | from utils import init 8 | 9 | 10 | class Flatten(nn.Module): 11 | def forward(self, x): 12 | return x.view(x.size(0), -1) 13 | 14 | 15 | class FCNetwork(nn.Module): 16 | def __init__(self, dims, out_layer=None): 17 | """ 18 | Creates a network using ReLUs between layers and no activation at the end 19 | :param dims: tuple in the form of (100, 100, ..., 5). 
for dim sizes 20 | """ 21 | super().__init__() 22 | input_size = dims[0] 23 | h_sizes = dims[1:] 24 | 25 | mods = [nn.Linear(input_size, h_sizes[0])] 26 | for i in range(len(h_sizes) - 1): 27 | mods.append(nn.ReLU()) 28 | mods.append(nn.Linear(h_sizes[i], h_sizes[i + 1])) 29 | 30 | if out_layer: 31 | mods.append(out_layer) 32 | 33 | self.layers = nn.Sequential(*mods) 34 | 35 | def forward(self, x): 36 | # Feedforward 37 | return self.layers(x) 38 | 39 | def hard_update(self, source): 40 | for target_param, source_param in zip(self.parameters(), source.parameters()): 41 | target_param.data.copy_(source_param.data) 42 | 43 | def soft_update(self, source, t): 44 | for target_param, source_param in zip(self.parameters(), source.parameters()): 45 | target_param.data.copy_((1 - t) * target_param.data + t * source_param.data) 46 | 47 | 48 | class Policy(nn.Module): 49 | def __init__(self, obs_space, action_space, base=None, base_kwargs=None): 50 | super(Policy, self).__init__() 51 | 52 | obs_shape = obs_space.shape 53 | 54 | if base_kwargs is None: 55 | base_kwargs = {} 56 | 57 | self.base = MLPBase(obs_shape[0], **base_kwargs) 58 | 59 | num_outputs = action_space.n 60 | self.dist = Categorical(self.base.output_size, num_outputs) 61 | 62 | @property 63 | def is_recurrent(self): 64 | return self.base.is_recurrent 65 | 66 | @property 67 | def recurrent_hidden_state_size(self): 68 | """Size of rnn_hx.""" 69 | return self.base.recurrent_hidden_state_size 70 | 71 | def forward(self, inputs, rnn_hxs, masks): 72 | raise NotImplementedError 73 | 74 | def act(self, inputs, rnn_hxs, masks, deterministic=False): 75 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 76 | dist = self.dist(actor_features) 77 | 78 | if deterministic: 79 | action = dist.mode() 80 | else: 81 | action = dist.sample() 82 | 83 | action_log_probs = dist.log_probs(action) 84 | dist_entropy = dist.entropy().mean() 85 | 86 | return value, action, action_log_probs, rnn_hxs 87 | 88 | def get_value(self, inputs, rnn_hxs, masks): 89 | value, _, _ = self.base(inputs, rnn_hxs, masks) 90 | return value 91 | 92 | def evaluate_actions(self, inputs, rnn_hxs, masks, action): 93 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 94 | dist = self.dist(actor_features) 95 | 96 | action_log_probs = dist.log_probs(action) 97 | dist_entropy = dist.entropy().mean() 98 | 99 | return value, action_log_probs, dist_entropy, rnn_hxs 100 | 101 | 102 | class NNBase(nn.Module): 103 | def __init__(self, recurrent, recurrent_input_size, hidden_size): 104 | super(NNBase, self).__init__() 105 | 106 | self._hidden_size = hidden_size 107 | self._recurrent = recurrent 108 | 109 | if recurrent: 110 | self.gru = nn.GRU(recurrent_input_size, hidden_size) 111 | for name, param in self.gru.named_parameters(): 112 | if "bias" in name: 113 | nn.init.constant_(param, 0) 114 | elif "weight" in name: 115 | nn.init.orthogonal_(param) 116 | 117 | @property 118 | def is_recurrent(self): 119 | return self._recurrent 120 | 121 | @property 122 | def recurrent_hidden_state_size(self): 123 | if self._recurrent: 124 | return self._hidden_size 125 | return 1 126 | 127 | @property 128 | def output_size(self): 129 | return self._hidden_size 130 | 131 | def _forward_gru(self, x, hxs, masks): 132 | if x.size(0) == hxs.size(0): 133 | x, hxs = self.gru(x.unsqueeze(0), (hxs * masks).unsqueeze(0)) 134 | x = x.squeeze(0) 135 | hxs = hxs.squeeze(0) 136 | else: 137 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1) 138 | N = hxs.size(0) 139 | T = 
int(x.size(0) / N) 140 | 141 | # unflatten 142 | x = x.view(T, N, x.size(1)) 143 | 144 | # Same deal with masks 145 | masks = masks.view(T, N) 146 | 147 | # Let's figure out which steps in the sequence have a zero for any agent 148 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 149 | has_zeros = (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze().cpu() 150 | 151 | # +1 to correct the masks[1:] 152 | if has_zeros.dim() == 0: 153 | # Deal with scalar 154 | has_zeros = [has_zeros.item() + 1] 155 | else: 156 | has_zeros = (has_zeros + 1).numpy().tolist() 157 | 158 | # add t=0 and t=T to the list 159 | has_zeros = [0] + has_zeros + [T] 160 | 161 | hxs = hxs.unsqueeze(0) 162 | outputs = [] 163 | for i in range(len(has_zeros) - 1): 164 | # We can now process steps that don't have any zeros in masks together! 165 | # This is much faster 166 | start_idx = has_zeros[i] 167 | end_idx = has_zeros[i + 1] 168 | 169 | rnn_scores, hxs = self.gru( 170 | x[start_idx:end_idx], hxs * masks[start_idx].view(1, -1, 1) 171 | ) 172 | 173 | outputs.append(rnn_scores) 174 | 175 | # assert len(outputs) == T 176 | # x is a (T, N, -1) tensor 177 | x = torch.cat(outputs, dim=0) 178 | # flatten 179 | x = x.view(T * N, -1) 180 | hxs = hxs.squeeze(0) 181 | 182 | return x, hxs 183 | 184 | 185 | class MLPBase(NNBase): 186 | def __init__(self, num_inputs, recurrent=False, hidden_size=64): 187 | super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size) 188 | 189 | if recurrent: 190 | num_inputs = hidden_size 191 | 192 | init_ = lambda m: init( 193 | m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2) 194 | ) 195 | 196 | self.actor = nn.Sequential( 197 | init_(nn.Linear(num_inputs, hidden_size)), 198 | nn.ReLU(), 199 | init_(nn.Linear(hidden_size, hidden_size)), 200 | nn.ReLU(), 201 | ) 202 | 203 | self.critic = nn.Sequential( 204 | init_(nn.Linear(num_inputs, hidden_size)), 205 | nn.ReLU(), 206 | init_(nn.Linear(hidden_size, hidden_size)), 207 | nn.ReLU(), 208 | ) 209 | 210 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 211 | 212 | self.train() 213 | 214 | def forward(self, inputs, rnn_hxs, masks): 215 | x = inputs 216 | 217 | if self.is_recurrent: 218 | x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks) 219 | 220 | hidden_critic = self.critic(x) 221 | hidden_actor = self.actor(x) 222 | 223 | return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs 224 | -------------------------------------------------------------------------------- /seac/train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import shutil 5 | import time 6 | from collections import deque 7 | from os import path 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import torch 12 | from sacred import Experiment 13 | from sacred.observers import ( # noqa 14 | FileStorageObserver, 15 | MongoObserver, 16 | QueuedMongoObserver, 17 | QueueObserver, 18 | ) 19 | from torch.utils.tensorboard import SummaryWriter 20 | 21 | import utils 22 | from a2c import A2C, algorithm 23 | from envs import make_vec_envs 24 | from wrappers import RecordEpisodeStatistics, SquashDones 25 | from model import Policy 26 | 27 | import robotic_warehouse # noqa 28 | import lbforaging # noqa 29 | 30 | ex = Experiment(ingredients=[algorithm]) 31 | ex.captured_out_filter = lambda captured_output: "Output capturing turned off." 
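# --- Editor's note: illustrative sketch, not part of the original script. ---
# The helper below shows how the Policy defined in model.py (listed just above) is
# typically queried. The observation width (30), number of parallel processes (4) and
# action count (5) are made-up values; only Policy's constructor and act() signature
# are taken from model.py, and nothing in the training code calls this helper.
def _example_policy_act():
    from gym.spaces import Box, Discrete

    example_policy = Policy(
        Box(low=-1.0, high=1.0, shape=(30,)),  # hypothetical flat observation space
        Discrete(5),                           # hypothetical discrete action space
        base_kwargs={"recurrent": False},
    )
    obs = torch.zeros(4, 30)                   # one observation row per process
    rnn_hxs = torch.zeros(4, example_policy.recurrent_hidden_state_size)
    masks = torch.ones(4, 1)                   # 1.0 = episode still running
    # act() returns value estimates, sampled actions, their log-probs and hidden states
    value, action, action_log_probs, rnn_hxs = example_policy.act(obs, rnn_hxs, masks)
    return value, action, action_log_probs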
32 | ex.observers.append(FileStorageObserver("./results/sacred")) 33 | 34 | logging.basicConfig( 35 | level=logging.INFO, 36 | format="(%(process)d) [%(levelname).1s] - (%(asctime)s) - %(name)s >> %(message)s", 37 | datefmt="%m/%d %H:%M:%S", 38 | ) 39 | 40 | 41 | @ex.config 42 | def config(): 43 | env_name = None 44 | time_limit = None 45 | wrappers = ( 46 | RecordEpisodeStatistics, 47 | SquashDones, 48 | ) 49 | dummy_vecenv = False 50 | 51 | num_env_steps = 100e6 52 | 53 | eval_dir = "./results/video/{id}" 54 | loss_dir = "./results/loss/{id}" 55 | save_dir = "./results/trained_models/{id}" 56 | 57 | log_interval = 2000 58 | save_interval = int(1e6) 59 | eval_interval = int(1e6) 60 | episodes_per_eval = 8 61 | 62 | 63 | for conf in glob.glob("configs/*.yaml"): 64 | name = f"{Path(conf).stem}" 65 | ex.add_named_config(name, conf) 66 | 67 | def _squash_info(info): 68 | info = [i for i in info if i] 69 | new_info = {} 70 | keys = set([k for i in info for k in i.keys()]) 71 | keys.discard("TimeLimit.truncated") 72 | for key in keys: 73 | mean = np.mean([np.array(d[key]).sum() for d in info if key in d]) 74 | new_info[key] = mean 75 | return new_info 76 | 77 | 78 | @ex.capture 79 | def evaluate( 80 | agents, 81 | monitor_dir, 82 | episodes_per_eval, 83 | env_name, 84 | seed, 85 | wrappers, 86 | dummy_vecenv, 87 | time_limit, 88 | algorithm, 89 | _log, 90 | ): 91 | device = algorithm["device"] 92 | 93 | eval_envs = make_vec_envs( 94 | env_name, 95 | seed, 96 | dummy_vecenv, 97 | episodes_per_eval, 98 | time_limit, 99 | wrappers, 100 | device, 101 | monitor_dir=monitor_dir, 102 | ) 103 | 104 | n_obs = eval_envs.reset() 105 | n_recurrent_hidden_states = [ 106 | torch.zeros( 107 | episodes_per_eval, agent.model.recurrent_hidden_state_size, device=device 108 | ) 109 | for agent in agents 110 | ] 111 | masks = torch.zeros(episodes_per_eval, 1, device=device) 112 | 113 | all_infos = [] 114 | 115 | while len(all_infos) < episodes_per_eval: 116 | with torch.no_grad(): 117 | _, n_action, _, n_recurrent_hidden_states = zip( 118 | *[ 119 | agent.model.act( 120 | n_obs[agent.agent_id], recurrent_hidden_states, masks 121 | ) 122 | for agent, recurrent_hidden_states in zip( 123 | agents, n_recurrent_hidden_states 124 | ) 125 | ] 126 | ) 127 | 128 | # Obser reward and next obs 129 | n_obs, _, done, infos = eval_envs.step(n_action) 130 | 131 | n_masks = torch.tensor( 132 | [[0.0] if done_ else [1.0] for done_ in done], 133 | dtype=torch.float32, 134 | device=device, 135 | ) 136 | all_infos.extend([i for i in infos if i]) 137 | 138 | eval_envs.close() 139 | info = _squash_info(all_infos) 140 | _log.info( 141 | f"Evaluation using {len(all_infos)} episodes: mean reward {info['episode_reward']:.5f}\n" 142 | ) 143 | 144 | 145 | @ex.automain 146 | def main( 147 | _run, 148 | _log, 149 | num_env_steps, 150 | env_name, 151 | seed, 152 | algorithm, 153 | dummy_vecenv, 154 | time_limit, 155 | wrappers, 156 | save_dir, 157 | eval_dir, 158 | loss_dir, 159 | log_interval, 160 | save_interval, 161 | eval_interval, 162 | ): 163 | 164 | if loss_dir: 165 | loss_dir = path.expanduser(loss_dir.format(id=str(_run._id))) 166 | utils.cleanup_log_dir(loss_dir) 167 | writer = SummaryWriter(loss_dir) 168 | else: 169 | writer = None 170 | 171 | eval_dir = path.expanduser(eval_dir.format(id=str(_run._id))) 172 | save_dir = path.expanduser(save_dir.format(id=str(_run._id))) 173 | 174 | utils.cleanup_log_dir(eval_dir) 175 | utils.cleanup_log_dir(save_dir) 176 | 177 | torch.set_num_threads(1) 178 | envs = make_vec_envs( 179 | env_name, 180 
| seed, 181 | dummy_vecenv, 182 | algorithm["num_processes"], 183 | time_limit, 184 | wrappers, 185 | algorithm["device"], 186 | ) 187 | 188 | agents = [ 189 | A2C(i, osp, asp) 190 | for i, (osp, asp) in enumerate(zip(envs.observation_space, envs.action_space)) 191 | ] 192 | obs = envs.reset() 193 | 194 | for i in range(len(obs)): 195 | agents[i].storage.obs[0].copy_(obs[i]) 196 | agents[i].storage.to(algorithm["device"]) 197 | 198 | start = time.time() 199 | num_updates = ( 200 | int(num_env_steps) // algorithm["num_steps"] // algorithm["num_processes"] 201 | ) 202 | 203 | all_infos = deque(maxlen=10) 204 | 205 | for j in range(1, num_updates + 1): 206 | 207 | for step in range(algorithm["num_steps"]): 208 | # Sample actions 209 | with torch.no_grad(): 210 | n_value, n_action, n_action_log_prob, n_recurrent_hidden_states = zip( 211 | *[ 212 | agent.model.act( 213 | agent.storage.obs[step], 214 | agent.storage.recurrent_hidden_states[step], 215 | agent.storage.masks[step], 216 | ) 217 | for agent in agents 218 | ] 219 | ) 220 | # Obser reward and next obs 221 | obs, reward, done, infos = envs.step(n_action) 222 | # envs.envs[0].render() 223 | 224 | # If done then clean the history of observations. 225 | masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]) 226 | 227 | bad_masks = torch.FloatTensor( 228 | [ 229 | [0.0] if info.get("TimeLimit.truncated", False) else [1.0] 230 | for info in infos 231 | ] 232 | ) 233 | for i in range(len(agents)): 234 | agents[i].storage.insert( 235 | obs[i], 236 | n_recurrent_hidden_states[i], 237 | n_action[i], 238 | n_action_log_prob[i], 239 | n_value[i], 240 | reward[:, i].unsqueeze(1), 241 | masks, 242 | bad_masks, 243 | ) 244 | 245 | for info in infos: 246 | if info: 247 | all_infos.append(info) 248 | 249 | # value_loss, action_loss, dist_entropy = agent.update(rollouts) 250 | for agent in agents: 251 | agent.compute_returns() 252 | 253 | for agent in agents: 254 | loss = agent.update([a.storage for a in agents]) 255 | for k, v in loss.items(): 256 | if writer: 257 | writer.add_scalar(f"agent{agent.agent_id}/{k}", v, j) 258 | 259 | for agent in agents: 260 | agent.storage.after_update() 261 | 262 | if j % log_interval == 0 and len(all_infos) > 1: 263 | squashed = _squash_info(all_infos) 264 | 265 | total_num_steps = ( 266 | (j + 1) * algorithm["num_processes"] * algorithm["num_steps"] 267 | ) 268 | end = time.time() 269 | _log.info( 270 | f"Updates {j}, num timesteps {total_num_steps}, FPS {int(total_num_steps / (end - start))}" 271 | ) 272 | _log.info( 273 | f"Last {len(all_infos)} training episodes mean reward {squashed['episode_reward'].sum():.3f}" 274 | ) 275 | 276 | for k, v in squashed.items(): 277 | _run.log_scalar(k, v, j) 278 | all_infos.clear() 279 | 280 | if save_interval is not None and ( 281 | j > 0 and j % save_interval == 0 or j == num_updates 282 | ): 283 | cur_save_dir = path.join(save_dir, f"u{j}") 284 | for agent in agents: 285 | save_at = path.join(cur_save_dir, f"agent{agent.agent_id}") 286 | os.makedirs(save_at, exist_ok=True) 287 | agent.save(save_at) 288 | archive_name = shutil.make_archive(cur_save_dir, "xztar", save_dir, f"u{j}") 289 | shutil.rmtree(cur_save_dir) 290 | _run.add_artifact(archive_name) 291 | 292 | if eval_interval is not None and ( 293 | j > 0 and j % eval_interval == 0 or j == num_updates 294 | ): 295 | evaluate( 296 | agents, os.path.join(eval_dir, f"u{j}"), 297 | ) 298 | videos = glob.glob(os.path.join(eval_dir, f"u{j}") + "/*.mp4") 299 | for i, v in enumerate(videos): 300 | 
_run.add_artifact(v, f"u{j}.{i}.mp4") 301 | envs.close() 302 | -------------------------------------------------------------------------------- /seql/iql.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | from marl_algorithm import MarlAlgorithm 9 | from marl_utils import soft_update 10 | from agent import Agent 11 | 12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 13 | 14 | MSELoss = torch.nn.MSELoss() 15 | 16 | 17 | class IQL(MarlAlgorithm): 18 | """ 19 | (Deep) Independent Q-Learning (IQL) class 20 | 21 | Original IQL paper: 22 | Tan, M. (1993). 23 | Multi-agent reinforcement learning: Independent vs. cooperative agents. 24 | In Proceedings of the tenth international conference on machine learning (pp. 330-337). 25 | 26 | Link: http://web.mit.edu/16.412j/www/html/Advanced%20lectures/2004/Multi-AgentReinforcementLearningIndependentVersusCooperativeAgents.pdf 27 | 28 | Deep Q-Learning (DQN) paper: 29 | Mnih, V., Kavukcuoglu, K., Silver, D., Rusu, A. A., Veness, J., Bellemare, M. G., ... & Petersen, S. (2015). 30 | Human-level control through deep reinforcement learning. 31 | Nature, 518(7540), 529. 32 | 33 | Link: https://www.nature.com/articles/nature14236?wm=book_wap_0005 34 | """ 35 | 36 | def __init__(self, n_agents, observation_sizes, action_sizes, params): 37 | """ 38 | Initialise parameters for IQL training 39 | :param n_agents: number of agents 40 | :param observation_sizes: dimension of observation for each agent 41 | :param action_sizes: dimension of action for each agent 42 | :param params: parsed arglist parameter list 43 | """ 44 | super(IQL, self).__init__( 45 | n_agents, observation_sizes, action_sizes, params 46 | ) 47 | 48 | self.shared_experience = params.shared_experience 49 | self.shared_lambda = params.shared_lambda 50 | self.targets_type = params.targets 51 | self.model_dev = device # device for model 52 | self.trgt_model_dev = device # device for target model 53 | 54 | self.agents = [ 55 | Agent(observation_sizes[i], action_sizes[i], params) 56 | for i in range(n_agents) 57 | ] 58 | 59 | def reset(self, episode): 60 | """ 61 | Reset algorithm for new episode 62 | :param episode: new episode number 63 | """ 64 | self.prep_rollouts(device=device) 65 | 66 | def prep_rollouts(self, device=device): 67 | """ 68 | Prepare networks for rollout steps and use given device 69 | :param device: device to cast networks to 70 | """ 71 | for a in self.agents: 72 | a.model.eval() 73 | if device == "gpu": 74 | fn = lambda x: x.cuda() 75 | else: 76 | fn = lambda x: x.cpu() 77 | if not self.model_dev == device: 78 | for a in self.agents: 79 | a.model = fn(a.model) 80 | self.model_dev = device 81 | 82 | def step(self, observations, explore=False, available_actions=None): 83 | """ 84 | Take a step forward in environment with all agents 85 | :param observations: list of observations for each agent 86 | :param explore: flag whether or not to add exploration noise 87 | :param available_actions: binary vector (n_agents, n_actions) where each list contains 88 | binary values indicating whether action is applicable 89 | :return: list of actions for each agent 90 | """ 91 | if available_actions is None: 92 | return [a.step(obs, explore)[0] for a, obs in zip(self.agents, observations)] 93 | else: 94 | return [ 95 | a.step(obs, explore, available_actions[i])[0] 96 | for i, (a, obs) in enumerate(zip(self.agents, observations)) 97 | ] 98 | self.t_steps
+= 1 99 | 100 | def update_all_targets(self): 101 | """ 102 | Update all target networks (called after normal updates have been 103 | performed for each agent) 104 | """ 105 | for a in self.agents: 106 | soft_update(a.target_model, a.model, self.params.tau) 107 | 108 | def prep_training(self, device="gpu"): 109 | """ 110 | Prepare networks for training and use given device 111 | :param device: device to cast networks to 112 | """ 113 | for a in self.agents: 114 | a.model.train() 115 | a.target_model.train() 116 | if device == "gpu": 117 | fn = lambda x: x.cuda() 118 | else: 119 | fn = lambda x: x.cpu() 120 | if not self.model_dev == device: 121 | for a in self.agents: 122 | a.model = fn(a.model) 123 | self.model_dev = device 124 | if not self.trgt_model_dev == device: 125 | for a in self.agents: 126 | a.target_model = fn(a.target_model) 127 | self.trgt_model_dev = device 128 | 129 | def update_agent(self, sample, agent_i, use_cuda): 130 | """ 131 | Update parameters of agent model based on sample from replay buffer 132 | :param sample: tuple of (observations, actions, rewards, next 133 | observations, and episode end masks) sampled randomly from 134 | the replay buffer 135 | :param agent_i: index of agent to update 136 | :param use_cuda: flag if cuda/ gpus should be used 137 | :return: q loss 138 | """ 139 | # timer = time.process_time() 140 | obs, acs, rews, next_obs, dones = sample 141 | curr_agent = self.agents[agent_i] 142 | 143 | curr_agent.optimizer.zero_grad() 144 | 145 | if self.targets_type == "simple": 146 | q_next_states = curr_agent.target_model(next_obs) 147 | target_next_states = q_next_states.max(-1)[0] 148 | elif self.targets_type == "double": 149 | q_tp1_values = curr_agent.model(next_obs).detach() 150 | _, a_prime = q_tp1_values.max(1) 151 | q_next_states = curr_agent.target_model(next_obs) 152 | target_next_states = q_next_states.gather(1, a_prime.unsqueeze(1)) 153 | elif self.targets_type == "our-double": 154 | # this does not use target network but instead uses the network of another agent 155 | other_agent = self.agents[int(not agent_i)] # or sample any other agent except agent_i (if agents>2) 156 | q_tp1_values = curr_agent.model(next_obs).detach() 157 | _, a_prime = q_tp1_values.max(1) 158 | q_next_states = other_agent.model(next_obs).detach() 159 | target_next_states = q_next_states.gather(1, a_prime.unsqueeze(1)) 160 | elif self.targets_type == "our-clipped": 161 | # uses TD3's clipped q networks by taking the min of all agents models 162 | target_next_states = torch.cat([a.model(next_obs).detach().max(dim=1)[0].unsqueeze(1) for a in self.agents], dim=1).min(dim=1)[0] 163 | 164 | 165 | # compute Q-targets for current states 166 | target_states = ( 167 | rews.view(-1, 1) 168 | + self.gamma * target_next_states.view(-1, 1)# * (1 - dones.view(-1, 1)) 169 | ) 170 | 171 | # target_timer = time.process_time() - timer 172 | # print(f"\t\tTarget computation time: {target_timer}") 173 | # timer = time.process_time() 174 | 175 | 176 | # local Q-values 177 | all_q_states = curr_agent.model(obs) 178 | q_states = torch.sum(all_q_states * acs, dim=1).view(-1, 1) 179 | 180 | # q_timer = time.process_time() - timer 181 | # print(f"\t\tQ-values computation time: {q_timer}") 182 | # timer = time.process_time() 183 | 184 | if self.shared_experience: 185 | batch_size_agent = self.batch_size // self.n_agents 186 | agent_mask = np.arange(batch_size_agent * agent_i, batch_size_agent * (agent_i + 1)) 187 | other_agents_mask = np.concatenate([np.arange(0, batch_size_agent * agent_i), 
np.arange(batch_size_agent * (agent_i + 1), self.batch_size)]) 188 | qloss = MSELoss(q_states[agent_mask], target_states[agent_mask].detach()) 189 | qloss += self.shared_lambda * MSELoss(q_states[other_agents_mask], target_states[other_agents_mask].detach()) 190 | else: 191 | qloss = MSELoss(q_states, target_states.detach()) 192 | 193 | # loss_timer = time.process_time() - timer 194 | # print(f"\t\tLoss computation time: {loss_timer}") 195 | qloss.backward() 196 | torch.nn.utils.clip_grad_norm_(curr_agent.model.parameters(), 0.5) 197 | curr_agent.optimizer.step() 198 | 199 | return qloss 200 | 201 | def update(self, memory, use_cuda=False): 202 | """ 203 | Train agent models based on memory samples 204 | :param memory: replay buffer memory to sample experience from 205 | :param use_cuda: flag if cuda/ gpus should be used 206 | :return: qnetwork losses 207 | """ 208 | q_losses = [] 209 | if use_cuda: 210 | self.prep_training(device="gpu") 211 | else: 212 | self.prep_training(device=device) 213 | if self.shared_experience: 214 | samples = memory.sample_shared(self.params.batch_size) 215 | for a_i in range(self.n_agents): 216 | # print(f"\tUpdate agent {a_i}:") 217 | # timer = time.process_time() 218 | if not self.shared_experience: 219 | samples = memory.sample(self.params.batch_size, a_i) 220 | # sample_time = time.process_time() - timer 221 | # print(f"\t\tSample time from memory: {sample_time}") 222 | q_loss = self.update_agent(samples, a_i, use_cuda=False) 223 | q_losses.append(q_loss) 224 | self.update_all_targets() 225 | self.prep_rollouts(device=device) 226 | 227 | return q_losses 228 | 229 | def load_model_networks(self, directory, extension="_final"): 230 | """ 231 | Load model networks of all agents 232 | :param directory: path to directory where to load models from 233 | """ 234 | for i, agent in enumerate(self.agents): 235 | name = "iql_agent%d_params" % i 236 | name += extension 237 | agent.model.load_state_dict( 238 | torch.load(os.path.join(directory, name), map_location=device) 239 | ) 240 | -------------------------------------------------------------------------------- /seac/storage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 3 | 4 | 5 | def _flatten_helper(T, N, _tensor): 6 | return _tensor.view(T * N, *_tensor.size()[2:]) 7 | 8 | class RolloutStorage(object): 9 | def __init__( 10 | self, 11 | obs_space, 12 | action_space, 13 | recurrent_hidden_state_size, 14 | num_steps, 15 | num_processes, 16 | ): 17 | obs_shape = obs_space.shape 18 | self.obs = torch.zeros(num_steps + 1, num_processes, *obs_shape) 19 | self.recurrent_hidden_states = torch.zeros( 20 | num_steps + 1, num_processes, recurrent_hidden_state_size 21 | ) 22 | self.rewards = torch.zeros(num_steps, num_processes, 1) 23 | self.value_preds = torch.zeros(num_steps + 1, num_processes, 1) 24 | self.returns = torch.zeros(num_steps + 1, num_processes, 1) 25 | self.action_log_probs = torch.zeros(num_steps, num_processes, 1) 26 | if action_space.__class__.__name__ == "Discrete": 27 | action_shape = 1 28 | else: 29 | action_shape = action_space.shape[0] 30 | self.actions = torch.zeros(num_steps, num_processes, action_shape) 31 | if action_space.__class__.__name__ == "Discrete": 32 | self.actions = self.actions.long() 33 | self.masks = torch.ones(num_steps + 1, num_processes, 1) 34 | 35 | # Masks that indicate whether it's a true terminal state 36 | # or time limit end state 37 | 
self.bad_masks = torch.ones(num_steps + 1, num_processes, 1) 38 | 39 | self.num_steps = num_steps 40 | self.step = 0 41 | 42 | def to(self, device): 43 | self.obs = self.obs.to(device) 44 | self.recurrent_hidden_states = self.recurrent_hidden_states.to(device) 45 | self.rewards = self.rewards.to(device) 46 | self.value_preds = self.value_preds.to(device) 47 | self.returns = self.returns.to(device) 48 | self.action_log_probs = self.action_log_probs.to(device) 49 | self.actions = self.actions.to(device) 50 | self.masks = self.masks.to(device) 51 | self.bad_masks = self.bad_masks.to(device) 52 | 53 | def insert( 54 | self, 55 | obs, 56 | recurrent_hidden_states, 57 | actions, 58 | action_log_probs, 59 | value_preds, 60 | rewards, 61 | masks, 62 | bad_masks, 63 | ): 64 | self.obs[self.step + 1].copy_(obs) 65 | self.recurrent_hidden_states[self.step + 1].copy_(recurrent_hidden_states) 66 | self.actions[self.step].copy_(actions) 67 | self.action_log_probs[self.step].copy_(action_log_probs) 68 | self.value_preds[self.step].copy_(value_preds) 69 | self.rewards[self.step].copy_(rewards) 70 | self.masks[self.step + 1].copy_(masks) 71 | self.bad_masks[self.step + 1].copy_(bad_masks) 72 | 73 | self.step = (self.step + 1) % self.num_steps 74 | 75 | def after_update(self): 76 | self.obs[0].copy_(self.obs[-1]) 77 | self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1]) 78 | self.masks[0].copy_(self.masks[-1]) 79 | self.bad_masks[0].copy_(self.bad_masks[-1]) 80 | 81 | def compute_returns( 82 | self, next_value, use_gae, gamma, gae_lambda, use_proper_time_limits=True 83 | ): 84 | if use_proper_time_limits: 85 | if use_gae: 86 | self.value_preds[-1] = next_value 87 | gae = 0 88 | for step in reversed(range(self.rewards.size(0))): 89 | delta = ( 90 | self.rewards[step] 91 | + gamma * self.value_preds[step + 1] * self.masks[step + 1] 92 | - self.value_preds[step] 93 | ) 94 | gae = delta + gamma * gae_lambda * self.masks[step + 1] * gae 95 | gae = gae * self.bad_masks[step + 1] 96 | self.returns[step] = gae + self.value_preds[step] 97 | else: 98 | self.returns[-1] = next_value 99 | for step in reversed(range(self.rewards.size(0))): 100 | self.returns[step] = ( 101 | ( 102 | self.returns[step + 1] * gamma * self.masks[step + 1] 103 | + self.rewards[step] 104 | ) 105 | * self.bad_masks[step + 1] 106 | + (1 - self.bad_masks[step + 1]) * self.value_preds[step] 107 | ) 108 | else: 109 | if use_gae: 110 | self.value_preds[-1] = next_value 111 | gae = 0 112 | for step in reversed(range(self.rewards.size(0))): 113 | delta = ( 114 | self.rewards[step] 115 | + gamma * self.value_preds[step + 1] * self.masks[step + 1] 116 | - self.value_preds[step] 117 | ) 118 | gae = delta + gamma * gae_lambda * self.masks[step + 1] * gae 119 | self.returns[step] = gae + self.value_preds[step] 120 | else: 121 | self.returns[-1] = next_value 122 | for step in reversed(range(self.rewards.size(0))): 123 | self.returns[step] = ( 124 | self.returns[step + 1] * gamma * self.masks[step + 1] 125 | + self.rewards[step] 126 | ) 127 | 128 | def feed_forward_generator( 129 | self, advantages, num_mini_batch=None, mini_batch_size=None 130 | ): 131 | num_steps, num_processes = self.rewards.size()[0:2] 132 | batch_size = num_processes * num_steps 133 | 134 | if mini_batch_size is None: 135 | assert batch_size >= num_mini_batch, ( 136 | "PPO requires the number of processes ({}) " 137 | "* number of steps ({}) = {} " 138 | "to be greater than or equal to the number of PPO mini batches ({})." 
139 | "".format( 140 | num_processes, num_steps, num_processes * num_steps, num_mini_batch 141 | ) 142 | ) 143 | mini_batch_size = batch_size // num_mini_batch 144 | sampler = BatchSampler( 145 | SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=True 146 | ) 147 | for indices in sampler: 148 | obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices] 149 | recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view( 150 | -1, self.recurrent_hidden_states.size(-1) 151 | )[indices] 152 | actions_batch = self.actions.view(-1, self.actions.size(-1))[indices] 153 | value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices] 154 | return_batch = self.returns[:-1].view(-1, 1)[indices] 155 | masks_batch = self.masks[:-1].view(-1, 1)[indices] 156 | old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices] 157 | if advantages is None: 158 | adv_targ = None 159 | else: 160 | adv_targ = advantages.view(-1, 1)[indices] 161 | 162 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ 163 | 164 | def recurrent_generator(self, advantages, num_mini_batch): 165 | num_processes = self.rewards.size(1) 166 | assert num_processes >= num_mini_batch, ( 167 | "PPO requires the number of processes ({}) " 168 | "to be greater than or equal to the number of " 169 | "PPO mini batches ({}).".format(num_processes, num_mini_batch) 170 | ) 171 | num_envs_per_batch = num_processes // num_mini_batch 172 | perm = torch.randperm(num_processes) 173 | for start_ind in range(0, num_processes, num_envs_per_batch): 174 | obs_batch = [] 175 | recurrent_hidden_states_batch = [] 176 | actions_batch = [] 177 | value_preds_batch = [] 178 | return_batch = [] 179 | masks_batch = [] 180 | old_action_log_probs_batch = [] 181 | adv_targ = [] 182 | 183 | for offset in range(num_envs_per_batch): 184 | ind = perm[start_ind + offset] 185 | obs_batch.append(self.obs[:-1, ind]) 186 | recurrent_hidden_states_batch.append( 187 | self.recurrent_hidden_states[0:1, ind] 188 | ) 189 | actions_batch.append(self.actions[:, ind]) 190 | value_preds_batch.append(self.value_preds[:-1, ind]) 191 | return_batch.append(self.returns[:-1, ind]) 192 | masks_batch.append(self.masks[:-1, ind]) 193 | old_action_log_probs_batch.append(self.action_log_probs[:, ind]) 194 | adv_targ.append(advantages[:, ind]) 195 | 196 | T, N = self.num_steps, num_envs_per_batch 197 | # These are all tensors of size (T, N, -1) 198 | obs_batch = torch.stack(obs_batch, 1) 199 | actions_batch = torch.stack(actions_batch, 1) 200 | value_preds_batch = torch.stack(value_preds_batch, 1) 201 | return_batch = torch.stack(return_batch, 1) 202 | masks_batch = torch.stack(masks_batch, 1) 203 | old_action_log_probs_batch = torch.stack(old_action_log_probs_batch, 1) 204 | adv_targ = torch.stack(adv_targ, 1) 205 | 206 | # States is just a (N, -1) tensor 207 | recurrent_hidden_states_batch = torch.stack( 208 | recurrent_hidden_states_batch, 1 209 | ).view(N, -1) 210 | 211 | # Flatten the (T, N, ...) tensors to (T * N, ...) 
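# (Editor's note, illustrative: _flatten_helper at the top of this file collapses the
# first two dimensions, e.g. a (5, 4, 64) tensor of 5 steps x 4 envs becomes (20, 64),
# matching the flattened layout evaluate_actions in a2c.py consumes.)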
212 | obs_batch = _flatten_helper(T, N, obs_batch) 213 | actions_batch = _flatten_helper(T, N, actions_batch) 214 | value_preds_batch = _flatten_helper(T, N, value_preds_batch) 215 | return_batch = _flatten_helper(T, N, return_batch) 216 | masks_batch = _flatten_helper(T, N, masks_batch) 217 | old_action_log_probs_batch = _flatten_helper( 218 | T, N, old_action_log_probs_batch 219 | ) 220 | adv_targ = _flatten_helper(T, N, adv_targ) 221 | 222 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ 223 | -------------------------------------------------------------------------------- /seql/utilities/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from collections import namedtuple 4 | 5 | import numpy as np 6 | 7 | 8 | class Logger: 9 | """ 10 | Class to log training information 11 | """ 12 | 13 | def __init__( 14 | self, 15 | n_agents, 16 | task_name="mape", 17 | run_name="default", 18 | log_path="logs", 19 | ): 20 | """ 21 | Create Logger instance 22 | :param n_agents: number of agents 23 | :param task_name: name of task 24 | :param run_name: name of run iteration 25 | :param log_path: path where logs should be saved 26 | """ 27 | self.n_agents = n_agents 28 | self.task_name = task_name 29 | self.run_name = run_name 30 | self.log_path = log_path 31 | 32 | # episode info 33 | self.episode = namedtuple("Ep", "number returns variances epsilon") 34 | self.current_episode = 0 35 | self.episodes = [] 36 | 37 | # loss info 38 | self.loss = namedtuple("Loss", "name episode mean variance") 39 | 40 | # training returns 41 | self.training_returns = [] 42 | self.training_agent_returns = [] 43 | 44 | # parameters in arrays (for efficiency) 45 | self.returns_means = [] 46 | self.returns_vars = [] 47 | self.epsilons = [] 48 | self.alg_losses_list = [[] for i in range(n_agents)] 49 | # store current episode 50 | self.current_alg_losses_list = [[] for i in range(n_agents)] 51 | 52 | # alg losses 53 | self.alg_losses = [] 54 | for _ in range(n_agents): 55 | losses = {} 56 | losses["qnetwork"] = [] 57 | self.alg_losses.append(losses) 58 | 59 | def log_episode(self, ep, returns_means, returns_vars, epsilon): 60 | """ 61 | Save episode information 62 | :param ep: episode number 63 | :param returns_means: average returns during episode (for each agent) 64 | :param returns_vars: variance of returns during episode (for each agent) 65 | :param epsilon: value for exploration 66 | """ 67 | ep = self.episode(ep, returns_means, returns_vars, epsilon) 68 | self.episodes.append(ep) 69 | self.returns_means.append(returns_means) 70 | self.returns_vars.append(returns_vars) 71 | self.epsilons.append(epsilon) 72 | 73 | self.current_episode = ep 74 | 75 | n_losses = 0 76 | for l in self.current_alg_losses_list: 77 | n_losses += l.__len__() 78 | if n_losses == 0: 79 | return 80 | 81 | for i in range(self.n_agents): 82 | q_losses = np.array(self.current_alg_losses_list[i]) 83 | q_loss_mean = q_losses.mean() 84 | q_loss = self.loss("qnetwork", ep.number, q_loss_mean, q_losses.var()) 85 | self.alg_losses[i]["qnetwork"].append(q_loss) 86 | self.alg_losses_list[i].append(q_loss_mean) 87 | 88 | # empty current episode lists 89 | self.current_alg_losses_list = [[] for i in range(self.n_agents)] 90 | 91 | def log_training_returns(self, timestep, ret, rets): 92 | """ 93 | Save mean return over last x episodes 94 | :param timestep (int): timestep of returns 95 | :param 
ret (float): mean cumulative return over last 10 episodes 96 | :param rets (List[float]): mean returns over last 10 episodes for each agent 97 | """ 98 | self.training_returns.append((timestep, ret)) 99 | self.training_agent_returns.append(rets) 100 | 101 | def log_losses(self, ep, losses): 102 | """ 103 | Save loss information 104 | :param ep: episode number 105 | :param losses: losses of algorithm 106 | """ 107 | qnet_loss = losses 108 | if len(qnet_loss) > 0: 109 | for i in range(self.n_agents): 110 | q_loss = qnet_loss[i].item() 111 | self.current_alg_losses_list[i].append(q_loss) 112 | 113 | def dump_episodes(self, num=None): 114 | """ 115 | Output episode info 116 | :param num: number of last episodes to output info for (or all if None) 117 | """ 118 | if num is None: 119 | start_idx = 0 120 | else: 121 | start_idx = -num 122 | print("\n\nEpisode\t\t\treturns\t\t\tvariances\t\t\t\texploration") 123 | for ep in self.episodes[start_idx:]: 124 | line = str(ep.number) + "\t\t\t" 125 | for ret in ep.returns: 126 | line += "%.3f " % ret 127 | line = line[:-1] + "\t\t" 128 | for var in ep.variances: 129 | line += "%.3f " % var 130 | line = line[:-1] + "\t\t\t" 131 | line += "%.3f" % ep.epsilon 132 | print(line) 133 | print() 134 | 135 | def __format_time(self, time): 136 | """ 137 | format time from seconds to string 138 | :param time: time in seconds (float) 139 | :return: time_string 140 | """ 141 | hours = time // 3600 142 | time -= hours * 3600 143 | minutes = time // 60 144 | time -= minutes * 60 145 | time_string = "%d:%d:%.2f" % (hours, minutes, time) 146 | return time_string 147 | 148 | def dump_train_progress(self, ep, num_episodes, duration): 149 | """ 150 | Output training progress info 151 | :param ep: current episode number 152 | :param num_episodes: number of episodes to complete 153 | :param duration: training duration so far (in seconds) 154 | """ 155 | print( 156 | "Training progress:\tepisodes: %d/%d\t\t\t\tduration: %s" 157 | % (ep + 1, num_episodes, self.__format_time(duration)) 158 | ) 159 | progress_percent = (ep + 1) / num_episodes 160 | remaining_duration = duration * (1 - progress_percent) / progress_percent 161 | 162 | arrow_len = 50 163 | arrow_progress = int(progress_percent * arrow_len) 164 | arrow_string = "|" + arrow_progress * "=" + ">" + (arrow_len - arrow_progress) * " " + "|" 165 | print( 166 | "%.2f%%\t%s\tremaining duration: %s\n" 167 | % (progress_percent * 100, arrow_string, self.__format_time(remaining_duration)) 168 | ) 169 | 170 | def dump_losses(self, num=None): 171 | """ 172 | Output loss info 173 | :param num: number of last loss entries to output (or all if None) 174 | """ 175 | num_entries = len(self.alg_losses[0]["qnetwork"]) 176 | start_idx = 0 177 | if num is not None: 178 | start_idx = num_entries - num 179 | 180 | if num_entries == 0: 181 | print("No loss values stored yet!") 182 | return 183 | 184 | # build header 185 | header = "Episode index\t\tagent_id:\t\t" 186 | header += "q_loss " 187 | print(header) 188 | 189 | for i in range(start_idx, num_entries): 190 | for j in range(self.n_agents): 191 | alg_loss = self.alg_losses[j] 192 | line = "" 193 | q_loss = alg_loss["qnetwork"][i] 194 | line += str(q_loss.episode) + "\t\t\t" + str(j + 1) + ":\t\t\t" 195 | line += "%.5f\t\t" % q_loss.mean 196 | print(line) 197 | 198 | def clear_logs(self): 199 | """ 200 | Remove log files in log dir 201 | """ 202 | if not os.path.isdir(self.log_path): 203 | return 204 | log_dir = os.path.join(self.log_path, self.run_name) 205 | if not 
os.path.isdir(log_dir): 206 | return 207 | for f in os.listdir(log_dir): 208 | f_path = os.path.join(log_dir, f) 209 | if not os.path.isfile(f_path): 210 | continue 211 | os.remove(f_path) 212 | 213 | def save_episodes(self, num=None, extension="final"): 214 | """ 215 | Save episode information in CSV file 216 | :param num: number of last episodes to save (or all if None) 217 | :param extension: extension name of csv file 218 | """ 219 | if not os.path.isdir(self.log_path): 220 | os.mkdir(self.log_path) 221 | log_dir = os.path.join(self.log_path, self.run_name) 222 | if not os.path.isdir(log_dir): 223 | os.mkdir(log_dir) 224 | 225 | csv_name = "iql_" + self.task_name + "_epinfo_" + extension + ".csv" 226 | csv_path = os.path.join(log_dir, csv_name) 227 | 228 | with open(csv_path, "w") as csv_file: 229 | # write header line 230 | h = "number,returns,variances,epsilon\n" 231 | csv_file.write(h) 232 | 233 | if num is None: 234 | start_idx = 0 235 | else: 236 | start_idx = -num 237 | for ep in self.episodes[start_idx:]: 238 | line = "" 239 | line += str(ep.number) + "," 240 | if len(ep.returns) > 1: 241 | line += "[" 242 | for r in ep.returns: 243 | line += "%.5f " % r 244 | line = line[:-1] + "]," 245 | else: 246 | line += str(ep.returns) + "," 247 | if len(ep.variances) > 1: 248 | line += "[" 249 | for v in ep.variances: 250 | line += "%.5f " % v 251 | line = line[:-1] + "]," 252 | else: 253 | line += str(ep.variances) + "," 254 | line += str(ep.epsilon) + "\n" 255 | csv_file.write(line) 256 | 257 | def save_training_returns(self, extension="final"): 258 | """ 259 | Save training returns so far in file 260 | :param extension: extension name of csv file 261 | """ 262 | if not os.path.isdir(self.log_path): 263 | os.mkdir(self.log_path) 264 | log_dir = os.path.join(self.log_path, self.run_name) 265 | if not os.path.isdir(log_dir): 266 | os.mkdir(log_dir) 267 | 268 | file_name = "iql_" + self.task_name + "_training_returns" + extension + ".csv" 269 | csv_path = os.path.join(log_dir, file_name) 270 | 271 | with open(csv_path, "w") as csv_file: 272 | # write header line 273 | h = "timestep,return," 274 | for i in range(self.n_agents): 275 | h += f"ag{i + 1}_return," 276 | h = h[:-1] + "\n" 277 | csv_file.write(h) 278 | 279 | for i in range(len(self.training_returns)): 280 | timestep, ret = self.training_returns[i] 281 | rets = self.training_agent_returns[i] 282 | line = f"{timestep},{ret}," 283 | for ret in rets: 284 | line += "%.5f," % ret 285 | line = line[:-1] + "\n" 286 | csv_file.write(line) 287 | 288 | 289 | def save_losses(self, num=None, extension="final"): 290 | """ 291 | Save loss information in CSV file 292 | :param num: number of last episodes to save (or all if None) 293 | :param extension: extension name of csv file 294 | """ 295 | if not os.path.isdir(self.log_path): 296 | os.mkdir(self.log_path) 297 | log_dir = os.path.join(self.log_path, self.run_name) 298 | if not os.path.isdir(log_dir): 299 | os.mkdir(log_dir) 300 | 301 | csv_name = "iql_" + self.task_name + "_lossinfo_" + extension + ".csv" 302 | csv_path = os.path.join(log_dir, csv_name) 303 | 304 | with open(csv_path, "w") as csv_file: 305 | # write header line 306 | h = "iteration,episode," 307 | for i in range(self.n_agents): 308 | h += f"ag{i + 1}_iql_loss," 309 | h = h[:-1] + "\n" 310 | csv_file.write(h) 311 | 312 | num_entries = len(self.alg_losses[0]["qnetwork"]) 313 | start_idx = 0 314 | if num is not None: 315 | start_idx = num_entries - num 316 | 317 | for i in range(start_idx, num_entries): 318 | line = str(i) + "," 
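# (Editor's note: each CSV row written below is the loss-entry index, the episode number
# taken from agent 0's entry, then every agent's mean q-loss for that entry.)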
319 | for j in range(self.n_agents): 320 | alg_loss = self.alg_losses[j] 321 | q_loss = alg_loss["qnetwork"][i] 322 | if j == 0: 323 | line += str(q_loss.episode) + "," 324 | line += "%.5f," % q_loss.mean 325 | line = line[:-1] + "\n" 326 | csv_file.write(line) 327 | 328 | def save_duration_cuda(self, duration, cuda): 329 | """ 330 | Store mini log file with duration and if cuda was used 331 | :param duration: duration of run in seconds 332 | :param cuda: flag if cuda was used 333 | """ 334 | if not os.path.isdir(self.log_path): 335 | os.mkdir(self.log_path) 336 | log_dir = os.path.join(self.log_path, self.run_name) 337 | if not os.path.isdir(log_dir): 338 | os.mkdir(log_dir) 339 | 340 | log_name = "iql_" + self.task_name + ".log" 341 | log_path = os.path.join(log_dir, log_name) 342 | 343 | with open(log_path, "w") as log_file: 344 | log_file.write("duration: %.2fs\n" % duration) 345 | log_file.write("cuda: %s\n" % str(cuda)) 346 | 347 | def save_parameters( 348 | self, env, task, n_agents, observation_sizes, action_sizes, arglist 349 | ): 350 | """ 351 | Store mini csv file with used parameters 352 | :param env: environment name 353 | :param task: task name 354 | :param n_agents: number of agents 355 | :param observation_sizes: dimension of observation for each agent 356 | :param action_sizes: dimension of action for each agent 357 | :param arglist: parsed arglist of parameters 358 | """ 359 | if not os.path.isdir(self.log_path): 360 | os.mkdir(self.log_path) 361 | log_dir = os.path.join(self.log_path, self.run_name) 362 | if not os.path.isdir(log_dir): 363 | os.mkdir(log_dir) 364 | 365 | log_name = "iql_" + self.task_name + "_parameters.csv" 366 | log_path = os.path.join(log_dir, log_name) 367 | 368 | with open(log_path, "w") as log_file: 369 | log_file.write("param,value\n") 370 | log_file.write("env,%s\n" % env) 371 | log_file.write("task,%s\n" % task) 372 | log_file.write("n_agents,%d\n" % n_agents) 373 | log_file.write("observation_sizes,%s\n" % observation_sizes) 374 | log_file.write("action_sizes,%s\n" % action_sizes) 375 | for arg in vars(arglist): 376 | log_file.write(arg + ",%s\n" % str(getattr(arglist, arg))) 377 | -------------------------------------------------------------------------------- /seql/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete 10 | 11 | from iql import IQL 12 | from baseline_buffer import MARLReplayBuffer 13 | 14 | from utilities.model_saver import ModelSaver 15 | from utilities.logger import Logger 16 | 17 | USE_CUDA = False #torch.cuda.is_available() 18 | TARGET_TYPES = ["simple", "double", "our-double", "our-clipped"] 19 | 20 | 21 | class Train: 22 | def __init__(self): 23 | self.parser = argparse.ArgumentParser( 24 | "Reinforcement Learning experiments for multiagent environments" 25 | ) 26 | self.parse_args() 27 | self.arglist = self.parser.parse_args() 28 | 29 | def parse_default_args(self): 30 | """ 31 | Parse default arguments for MARL training script 32 | """ 33 | # algorithm 34 | self.parser.add_argument("--hidden_dim", default=128, type=int) 35 | self.parser.add_argument("--shared_experience", action="store_true", default=False) 36 | self.parser.add_argument("--shared_lambda", default=1.0, type=float) 37 | self.parser.add_argument( 38 | "--targets", type=str, default="simple", help="target 
computation used for DQN" 39 | ) 40 | 41 | # training length 42 | self.parser.add_argument( 43 | "--num_episodes", type=int, default=120000, help="number of episodes" 44 | ) 45 | self.parser.add_argument( 46 | "--max_episode_len", type=int, default=25, help="maximum episode length" 47 | ) 48 | 49 | # core training parameters 50 | self.parser.add_argument( 51 | "--n_training_threads", default=1, type=int, help="number of training threads" 52 | ) 53 | self.parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") 54 | self.parser.add_argument( 55 | "--tau", type=float, default=0.05, help="tau as stepsize for target network updates" 56 | ) 57 | self.parser.add_argument( 58 | "--lr", type=float, default=0.0001, help="learning rate for Adam optimizer" #use 5e-5 for RWARE 59 | ) 60 | self.parser.add_argument( 61 | "--seed", type=int, default=None, help="random seed used throughout training" 62 | ) 63 | self.parser.add_argument( 64 | "--steps_per_update", type=int, default=1, help="number of steps before updates" 65 | ) 66 | 67 | self.parser.add_argument( 68 | "--buffer_capacity", type=int, default=int(1e6), help="Replay buffer capacity" 69 | ) 70 | self.parser.add_argument( 71 | "--batch_size", 72 | type=int, 73 | default=128, 74 | help="number of episodes to optimize at the same time", 75 | ) 76 | self.parser.add_argument( 77 | "--epsilon", type=float, default=1.0, help="epsilon value" 78 | ) 79 | self.parser.add_argument( 80 | "--goal_epsilon", type=float, default=0.01, help="epsilon target value" 81 | ) 82 | self.parser.add_argument( 83 | "--epsilon_decay", type=float, default=10, help="epsilon decay value" 84 | ) 85 | self.parser.add_argument( 86 | "--epsilon_anneal_slow", action="store_true", default=False, help="anneal epsilon slowly" 87 | ) 88 | 89 | # visualisation 90 | self.parser.add_argument("--render", action="store_true", default=False) 91 | self.parser.add_argument( 92 | "--eval_frequency", default=50, type=int, help="frequency of evaluation episodes" 93 | ) 94 | self.parser.add_argument( 95 | "--eval_episodes", default=5, type=int, help="number of evaluation episodes" 96 | ) 97 | self.parser.add_argument( 98 | "--run", type=str, default="default", help="run name for stored paths" 99 | ) 100 | self.parser.add_argument("--save_interval", default=100, type=int) 101 | self.parser.add_argument("--training_returns_freq", default=100, type=int) 102 | 103 | def parse_args(self): 104 | """ 105 | parse own arguments 106 | """ 107 | self.parse_default_args() 108 | 109 | def extract_sizes(self, spaces): 110 | """ 111 | Extract space dimensions 112 | :param spaces: list of Gym spaces 113 | :return: list of ints with sizes for each agent 114 | """ 115 | sizes = [] 116 | for space in spaces: 117 | if isinstance(space, Box): 118 | size = sum(space.shape) 119 | elif isinstance(space, Dict): 120 | size = sum(self.extract_sizes(space.values())) 121 | elif isinstance(space, Discrete) or isinstance(space, MultiBinary): 122 | size = space.n 123 | elif isinstance(space, MultiDiscrete): 124 | size = sum(space.nvec) 125 | else: 126 | raise ValueError("Unknown class of space: ", type(space)) 127 | sizes.append(size) 128 | return sizes 129 | 130 | def create_environment(self): 131 | """ 132 | Create environment instance 133 | :return: environment (gym interface), env_name, task_name, n_agents, observation_sizes, 134 | action_sizes, discrete_actions 135 | """ 136 | raise NotImplementedError() 137 | 138 | def reset_environment(self): 139 | """ 140 | Reset environment for new episode 141 
| :return: observation (as torch tensor) 142 | """ 143 | raise NotImplementedError 144 | 145 | def select_actions(self, obs, explore=True): 146 | """ 147 | Select actions for agents 148 | :param obs: joint observation 149 | :param explore: flag if exploration should be used 150 | :return: action_tensor, action_list 151 | """ 152 | raise NotImplementedError() 153 | 154 | def environment_step(self, actions): 155 | """ 156 | Take step in the environment 157 | :param actions: actions to apply for each agent 158 | :return: reward, done, next_obs (as Pytorch tensors) 159 | """ 160 | raise NotImplementedError() 161 | 162 | def environment_render(self): 163 | """ 164 | Render visualisation of environment 165 | """ 166 | raise NotImplementedError() 167 | 168 | def fill_buffer(self, timesteps): 169 | """ 170 | Randomly sample actions and store experience in buffer 171 | :param timesteps: number of timesteps 172 | """ 173 | t = 0 174 | while t < timesteps: 175 | done = False 176 | obs = self.reset_environment() 177 | while not done and t < timesteps: 178 | actions = [space.sample() for space in self.action_spaces] 179 | rewards, dones, next_obs, _ = self.environment_step(actions) 180 | onehot_actions = np.zeros((len(actions), self.action_sizes[0])) 181 | onehot_actions[np.arange(len(actions)), actions] = 1 182 | self.memory.add(obs, onehot_actions, rewards, next_obs, dones) 183 | obs = next_obs 184 | t += 1 185 | done = all(dones) 186 | 187 | def eval(self, ep, n_agents): 188 | """ 189 | Execute evaluation episode without exploration 190 | :param ep: episode number 191 | :param n_agents: number of agents in task 192 | :return: returns, episode_length, done 193 | """ 194 | obs = self.reset_environment() 195 | self.alg.reset(ep) 196 | 197 | episode_returns = np.array([0.0] * n_agents) 198 | episode_length = 0 199 | done = False 200 | 201 | while not done and episode_length < self.arglist.max_episode_len: 202 | torch_obs = [ 203 | Variable(torch.Tensor(obs[i]), requires_grad=False) for i in range(n_agents) 204 | ] 205 | 206 | actions, _ = self.select_actions(torch_obs, False) 207 | rewards, dones, next_obs, _ = self.environment_step(actions) 208 | 209 | episode_returns += rewards 210 | 211 | obs = next_obs 212 | episode_length += 1 213 | done = all(dones) 214 | 215 | return episode_returns, episode_length, done 216 | 217 | def set_seeds(self, seed): 218 | """ 219 | Set random seeds before model creation 220 | :param seed (int): seed to use 221 | """ 222 | if seed is not None: 223 | random.seed(seed) 224 | np.random.seed(seed) 225 | torch.manual_seed(seed) 226 | torch.cuda.manual_seed(seed) 227 | if torch.cuda.is_available(): 228 | torch.backends.cudnn.deterministic = True 229 | torch.backends.cudnn.benchmark = False 230 | 231 | 232 | def train(self): 233 | """ 234 | Abstract training flow 235 | """ 236 | # set random seeds before model creation 237 | self.set_seeds(self.arglist.seed) 238 | 239 | # use number of threads if no GPUs are available 240 | if not USE_CUDA: 241 | torch.set_num_threads(self.arglist.n_training_threads) 242 | 243 | env, env_name, task_name, n_agents, observation_spaces, action_spaces, observation_sizes, action_sizes = ( 244 | self.create_environment() 245 | ) 246 | self.env = env 247 | self.n_agents = n_agents 248 | self.observation_spaces = observation_spaces 249 | self.action_spaces = action_spaces 250 | self.observation_sizes = observation_sizes 251 | self.action_sizes = action_sizes 252 | 253 | if self.arglist.max_episode_len == 25: 254 | steps = self.arglist.num_episodes * 20 

        print("Observation sizes: ", observation_sizes)
        print("Action sizes: ", action_sizes)

        target_type = self.arglist.targets
        if target_type not in TARGET_TYPES:
            print(f"Invalid target type {target_type}!")
            return
        else:
            if target_type == "simple":
                print("Simple target computation used")
            elif target_type == "double":
                print("Double target computation used")
            elif target_type == "our-double":
                print("Agent-double target computation used")
            elif target_type == "our-clipped":
                print("Agent-clipped target computation used")

        # create algorithm trainer
        self.alg = IQL(
            n_agents, observation_sizes, action_sizes, self.arglist
        )

        # IQL assumes homogeneous agents: all observation and action sizes must match
        obs_size = observation_sizes[0]
        for o_size in observation_sizes[1:]:
            assert obs_size == o_size
        act_size = action_sizes[0]
        for a_size in action_sizes[1:]:
            assert act_size == a_size

        self.memory = MARLReplayBuffer(
            self.arglist.buffer_capacity,
            n_agents,
        )

        # set random seeds again after model creation
        self.set_seeds(self.arglist.seed)

        self.model_saver = ModelSaver("models", self.arglist.run)
        self.logger = Logger(
            n_agents,
            task_name,
            self.arglist.run,
        )

        # warm up the replay buffer with random-policy experience
        self.fill_buffer(5000)

        print("Starting iterations...")
        start_time = time.process_time()
        # timer = time.process_time()
        # env_time = 0
        # step_time = 0
        # update_time = 0
        # after_ep_time = 0

        t = 0
        training_returns_saved = 0

        episode_returns = []
        episode_agent_returns = []
        for ep in range(self.arglist.num_episodes):
            obs = self.reset_environment()
            self.alg.reset(ep)

            # episode_returns = np.array([0.0] * n_agents)
            episode_length = 0
            done = False

            while not done and episode_length < self.arglist.max_episode_len:
                torch_obs = [
                    Variable(torch.Tensor(obs[i]), requires_grad=False)
                    for i in range(n_agents)
                ]

                # env_time += time.process_time() - timer
                # timer = time.process_time()
                actions, onehot_actions = self.select_actions(torch_obs)
                # step_time += time.process_time() - timer
                # timer = time.process_time()
                rewards, dones, next_obs, info = self.environment_step(actions)

                # episode_returns += rewards

                self.memory.add(obs, onehot_actions, rewards, next_obs, dones)

                t += 1

                # env_time += time.process_time() - timer
                # timer = time.process_time()
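                # Note on update cadence (illustrative numbers, not from the configs):
                # since fill_buffer(5000) pre-populates the replay buffer, the
                # len(self.memory) >= batch_size check below is typically already
                # satisfied at the start of training (assuming batch_size <= 5000),
                # so gradient updates happen every steps_per_update environment
                # steps, e.g. every 100 steps for steps_per_update = 100.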
                if (
                    len(self.memory) >= self.arglist.batch_size
                    and (t % self.arglist.steps_per_update) == 0
                ):
                    losses = self.alg.update(self.memory, USE_CUDA)
                    self.logger.log_losses(ep, losses)
                    # self.logger.dump_losses(1)

                # update_time += time.process_time() - timer
                # timer = time.process_time()
                # for displaying learned policies
                if self.arglist.render:
                    self.environment_render()

                obs = next_obs
                episode_length += 1
                done = all(dones)

            if done or episode_length == self.arglist.max_episode_len:
                episode_returns.append(info["episode_reward"])
                agent_returns = []
                for i in range(n_agents):
                    agent_returns.append(info[f"agent{i}/episode_reward"])
                episode_agent_returns.append(agent_returns)

            # env_time += time.process_time() - timer
            # timer = time.process_time()
            # log mean training returns (over the last 10 episodes) roughly every
            # training_returns_freq environment steps
            if t >= (training_returns_saved + 1) * self.arglist.training_returns_freq:
                training_returns_saved += 1
                returns = np.array(episode_returns[-10:])
                mean_return = returns.mean()
                agent_returns = np.array(episode_agent_returns[-10:])
                mean_agent_return = agent_returns.mean(axis=0)

                self.logger.log_training_returns(t, mean_return, mean_agent_return)

            # periodically evaluate the greedy policy (no exploration)
            if ep % self.arglist.eval_frequency == 0:
                eval_returns = np.zeros((self.arglist.eval_episodes, n_agents))
                for i in range(self.arglist.eval_episodes):
                    ep_returns, _, _ = self.eval(ep, n_agents)
                    eval_returns[i, :] = ep_returns
                self.logger.log_episode(
                    ep, eval_returns.mean(0), eval_returns.var(0), self.alg.agents[0].epsilon
                )
                self.logger.dump_episodes(1)
            if ep % 100 == 0 and ep > 0:
                duration = time.process_time() - start_time
                self.logger.dump_train_progress(ep, self.arglist.num_episodes, duration)

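            # Checkpointing note: the block below clears the previously saved
            # intermediate models and logs before writing new ones, so (assuming
            # ModelSaver.clear_models and Logger.clear_logs delete everything saved
            # for this run so far) only the most recent intermediate checkpoint,
            # tagged with its episode number, is kept on disk at any time.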
            if ep % self.arglist.save_interval == 0 and ep > 0:
                # save models
                print("Remove previous models")
                self.model_saver.clear_models()
                print("Saving intermediate models")
                self.model_saver.save_models(self.alg, str(ep))
                # save logs
                print("Remove previous logs")
                self.logger.clear_logs()
                print("Saving intermediate logs")
                self.logger.save_training_returns(extension=str(ep))
                self.logger.save_episodes(extension=str(ep))
                self.logger.save_losses(extension=str(ep))
                # save parameter log
                self.logger.save_parameters(
                    env_name,
                    task_name,
                    n_agents,
                    observation_sizes,
                    action_sizes,
                    self.arglist,
                )

            # after_ep_time += time.process_time() - timer
            # timer = time.process_time()
            # print(f"Episode {ep} times:")
            # print(f"\tEnv time: {env_time}s")
            # print(f"\tStep time: {step_time}s")
            # print(f"\tUpdate time: {update_time}s")
            # print(f"\tAfter Ep time: {after_ep_time}s")
            # env_time = 0
            # step_time = 0
            # update_time = 0
            # after_ep_time = 0

        duration = time.process_time() - start_time
        print("Overall duration: %.2fs" % duration)

        # save models
        print("Remove previous models")
        self.model_saver.clear_models()
        print("Saving final models")
        self.model_saver.save_models(self.alg, "final")

        # save logs
        print("Remove previous logs")
        self.logger.clear_logs()
        print("Saving final logs")
        self.logger.save_episodes(extension="final")
        self.logger.save_losses(extension="final")
        self.logger.save_duration_cuda(duration, torch.cuda.is_available())

        # save parameter log
        self.logger.save_parameters(
            env_name,
            task_name,
            n_agents,
            observation_sizes,
            action_sizes,
            self.arglist,
        )

        env.close()


if __name__ == "__main__":
    train = Train()
    train.train()

--------------------------------------------------------------------------------