├── coop_marl ├── models │ ├── __init__.py │ └── modules.py ├── envs │ ├── overcooked │ │ ├── gym_cooking │ │ │ ├── misc │ │ │ │ ├── __init__.py │ │ │ │ └── game │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── graphics │ │ │ │ │ ├── Plate.png │ │ │ │ │ ├── plate.png │ │ │ │ │ ├── blender.png │ │ │ │ │ ├── FreshOnion.png │ │ │ │ │ ├── agent-blue.png │ │ │ │ │ ├── arrow_down.png │ │ │ │ │ ├── arrow_left.png │ │ │ │ │ ├── arrow_up.png │ │ │ │ │ ├── blender2.png │ │ │ │ │ ├── blender3.png │ │ │ │ │ ├── cutboard.png │ │ │ │ │ ├── delivery.png │ │ │ │ │ ├── ChoppedOnion.png │ │ │ │ │ ├── FreshCarrot.png │ │ │ │ │ ├── FreshLettuce.png │ │ │ │ │ ├── FreshTomato.png │ │ │ │ │ ├── MashedCarrot.png │ │ │ │ │ ├── agent-green.png │ │ │ │ │ ├── agent-yellow.png │ │ │ │ │ ├── arrow_right.png │ │ │ │ │ ├── ChoppedCarrot.png │ │ │ │ │ ├── ChoppedLettuce.png │ │ │ │ │ ├── ChoppedTomato.png │ │ │ │ │ ├── agent-magenta.png │ │ │ │ │ ├── InProgressCarrot.png │ │ │ │ │ ├── ChoppedLettuce-ChoppedOnion.png │ │ │ │ │ ├── ChoppedOnion-ChoppedTomato.png │ │ │ │ │ ├── ChoppedLettuce-ChoppedTomato.png │ │ │ │ │ └── ChoppedLettuce-ChoppedOnion-ChoppedTomato.png │ │ │ │ │ ├── screenshots │ │ │ │ │ ├── open_room_blender_agents1_03-01-22_01-42-09.png │ │ │ │ │ └── open_room_blender_agents1_03-01-22_01-42-33.png │ │ │ │ │ └── utils.py │ │ │ ├── utils │ │ │ │ ├── __init__.py │ │ │ │ └── new_style_level │ │ │ │ │ ├── open_room_tomato_salad.json │ │ │ │ │ ├── open_room_salad_easy.json │ │ │ │ │ ├── full_divider_salad_easy.json │ │ │ │ │ ├── open_room_tomato_salad_r.json │ │ │ │ │ ├── open_room_blender.json │ │ │ │ │ ├── full_divider_salad_more_ingred.json │ │ │ │ │ ├── full_divider_salad_static.json │ │ │ │ │ ├── full_divider_salad_2.json │ │ │ │ │ ├── full_divider_salad_3.json │ │ │ │ │ ├── full_divider_salad_4.json │ │ │ │ │ ├── open_room_salad.json │ │ │ │ │ └── full_divider_salad.json │ │ │ ├── cooking_book │ │ │ │ ├── __init__.py │ │ │ │ ├── recipe.py │ │ │ │ └── recipe_drawer.py │ │ │ ├── cooking_world │ │ │ │ ├── __init__.py │ │ │ │ ├── constants.py │ │ │ │ ├── abstract_classes.py │ │ │ │ └── world_objects.py │ │ │ ├── environment │ │ │ │ ├── game │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── game.py │ │ │ │ ├── graphics │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── graphic_store.py │ │ │ │ ├── __init__.py │ │ │ │ └── environment.py │ │ │ ├── __init__.py │ │ │ ├── test.py │ │ │ └── demo_multiplayer_gameplay.py │ │ ├── setup.py │ │ └── overcooked_maker.py │ ├── __init__.py │ ├── gym_maker.py │ ├── mpe │ │ └── _mpe_utils │ │ │ └── simple_env.py │ └── one_step_matrix.py ├── worker │ └── __init__.py ├── evaluation │ └── __init__.py ├── controllers │ ├── __init__.py │ └── controllers.py ├── runners │ ├── __init__.py │ └── runners.py ├── utils │ ├── __init__.py │ ├── rebar │ │ ├── README.md │ │ ├── setup.py │ │ └── rebar │ │ │ ├── __init__.py │ │ │ ├── interrupting.py │ │ │ ├── storing.py │ │ │ ├── contextlib.py │ │ │ ├── widgets.py │ │ │ ├── stats │ │ │ ├── gpu.py │ │ │ ├── __init__.py │ │ │ ├── categories.py │ │ │ ├── writing.py │ │ │ └── reading.py │ │ │ ├── recurrence.py │ │ │ ├── paths.py │ │ │ ├── parallel.py │ │ │ ├── numpy.py │ │ │ ├── queuing.py │ │ │ └── logging.py │ ├── nn.py │ ├── metrics.py │ ├── logger.py │ ├── parser.py │ └── utils.py ├── agents │ ├── __init__.py │ └── agent.py └── trainers │ ├── __init__.py │ └── trainer.py ├── .gitignore ├── config ├── envs │ ├── rendezvous.yaml │ ├── one_step_matrix.yaml │ └── overcooked.yaml └── algs │ ├── maven │ ├── overcooked.yaml │ ├── rendezvous.yaml │ ├── one_step_matrix.yaml │ └── 
default.yaml │ ├── multi_sp │ ├── rendezvous.yaml │ ├── overcooked.yaml │ ├── one_step_matrix.yaml │ └── default.yaml │ ├── sp_mi │ ├── rendezvous.yaml │ ├── overcooked.yaml │ ├── one_step_matrix.yaml │ └── default.yaml │ ├── trajedi │ ├── rendezvous.yaml │ ├── overcooked.yaml │ ├── one_step_matrix.yaml │ └── default.yaml │ ├── incompat │ ├── rendezvous.yaml │ ├── overcooked.yaml │ ├── one_step_matrix.yaml │ └── default.yaml │ ├── meta │ ├── overcooked.yaml │ └── default.yaml │ └── default.yaml ├── setup.py ├── scripts ├── overcooked │ ├── multi_sp.sh │ ├── lipo.sh │ ├── trajedi.sh │ ├── multi_maven.sh │ └── multi_sp_mi.sh ├── pmr-c │ ├── multi_sp.sh │ ├── multi_maven.sh │ ├── sp_mi.sh │ ├── maven.sh │ ├── multi_sp_mi.sh │ ├── lipo.sh │ └── trajedi.sh ├── pmr-l │ ├── multi_sp.sh │ ├── sp_mi.sh │ ├── maven.sh │ ├── lipo.sh │ ├── multi_sp_mi.sh │ ├── multi_maven.sh │ └── trajedi.sh ├── cmg-h │ ├── multi_sp.sh │ ├── lipo.sh │ ├── sp_mi.sh │ ├── trajedi.sh │ ├── multi_sp_mi.sh │ ├── multi_maven.sh │ └── maven.sh └── cmg-s │ ├── multi_sp.sh │ ├── sp_mi.sh │ ├── multi_sp_mi.sh │ ├── trajedi.sh │ ├── lipo.sh │ ├── multi_maven.sh │ └── maven.sh ├── requirements.txt ├── install.sh ├── main.py ├── README.md └── gif_view.py /coop_marl/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/cooking_book/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/cooking_world/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/environment/game/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coop_marl/worker/__init__.py: -------------------------------------------------------------------------------- 1 | from .worker import RolloutWorker 2 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/environment/graphics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coop_marl/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from coop_marl.evaluation.eval import * 2 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | 2 | *.sublime-project 3 | 4 | *.sublime-workspace 5 | 6 | *__pycache__/ 7 | *.pyc 8 | *.egg-info/ 9 | 10 | *results* 11 | tmp/ 12 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/Plate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/Plate.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/plate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/plate.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/blender.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/blender.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/FreshOnion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/FreshOnion.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/agent-blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/agent-blue.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/arrow_down.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/arrow_down.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/arrow_left.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/arrow_left.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/arrow_up.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/arrow_up.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/blender2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/blender2.png -------------------------------------------------------------------------------- 
/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/blender3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/blender3.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/cutboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/cutboard.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/delivery.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/delivery.png -------------------------------------------------------------------------------- /config/envs/rendezvous.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # trainer_name: 3 | # k: v 4 | 5 | name: rendezvous 6 | horizon: 50 7 | n_landmarks: 4 8 | partner_obs: True 9 | mode: easy 10 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedOnion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedOnion.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/FreshCarrot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/FreshCarrot.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/FreshLettuce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/FreshLettuce.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/FreshTomato.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/FreshTomato.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/MashedCarrot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/MashedCarrot.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/agent-green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/agent-green.png 
-------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/agent-yellow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/agent-yellow.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/arrow_right.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/arrow_right.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedCarrot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedCarrot.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedLettuce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedLettuce.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedTomato.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedTomato.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/agent-magenta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/agent-magenta.png -------------------------------------------------------------------------------- /coop_marl/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | from coop_marl.controllers.controllers import RandomController, PSController, MappingController 2 | 3 | __all__ = ['RandomController', 'PSController', 'MappingController'] -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/InProgressCarrot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/InProgressCarrot.png -------------------------------------------------------------------------------- /coop_marl/runners/__init__.py: -------------------------------------------------------------------------------- 1 | from coop_marl.runners.runners import EpisodesRunner, StepsRunner 2 | 3 | __all__ = ['EpisodesRunner', 'StepsRunner'] 4 | 5 | registered_runners = {a:eval(a) for a in __all__} 6 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedLettuce-ChoppedOnion.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedLettuce-ChoppedOnion.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedOnion-ChoppedTomato.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedOnion-ChoppedTomato.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/environment/__init__.py: -------------------------------------------------------------------------------- 1 | from gym_cooking.environment.environment import GymCookingEnvironment 2 | from gym_cooking.environment.cooking_zoo import CookingEnvironment as CookingZooEnvironment 3 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedLettuce-ChoppedTomato.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedLettuce-ChoppedTomato.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedLettuce-ChoppedOnion-ChoppedTomato.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/graphics/ChoppedLettuce-ChoppedOnion-ChoppedTomato.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup(name='coop_marl', 6 | version='0.0.1', 7 | description='', 8 | packages=find_packages(), 9 | install_requires=[] 10 | ) -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/screenshots/open_room_blender_agents1_03-01-22_01-42-09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/screenshots/open_room_blender_agents1_03-01-22_01-42-09.png -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/screenshots/open_room_blender_agents1_03-01-22_01-42-33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/51616/marl-lipo/HEAD/coop_marl/envs/overcooked/gym_cooking/misc/game/screenshots/open_room_blender_agents1_03-01-22_01-42-33.png -------------------------------------------------------------------------------- /config/algs/maven/overcooked.yaml: -------------------------------------------------------------------------------- 1 | def_config: !include config/algs/maven/default.yaml 2 | 3 | lr: 0.0005 4 | z_dim: 4 5 | n_iter: 30000 6 | n_sp_episodes: 8 7 | eval_interval: 10000 8 | discrim_coef: 0.1 9 | gamma: 0.99 10 | start_e: 1 11 | end_e: 0.05 12 | explore_decay_ts: 1000000 13 
| -------------------------------------------------------------------------------- /config/algs/maven/rendezvous.yaml: -------------------------------------------------------------------------------- 1 | def_config: !include config/algs/maven/default.yaml 2 | 3 | lr: 0.0003 4 | z_dim: 4 5 | n_iter: 1000 6 | n_sp_episodes: 15 7 | eval_interval: 1000 8 | discrim_coef: 0.1 9 | gamma: 0.99 10 | start_e: 1 11 | end_e: 0.05 12 | explore_decay_ts: 200000 13 | -------------------------------------------------------------------------------- /config/algs/multi_sp/rendezvous.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | #   k: v 4 | # config: 5 | #   k: v 6 | def_config: !include config/algs/multi_sp/default.yaml 7 | 8 | pop_size: 4 9 | n_iter: 300 10 | eval_interval: 100 11 | epochs: 10 12 | num_mb: 2 13 | lr: 0.0003 14 | -------------------------------------------------------------------------------- /coop_marl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from rebar import dotdict, arrdict 2 | 3 | Dotdict = dotdict.dotdict 4 | Arrdict = arrdict.arrdict 5 | 6 | from .utils import * 7 | from .nn import * 8 | from .rl import * 9 | from .logger import * 10 | from .parser import * 11 | from .metrics import * 12 | 13 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/README.md: -------------------------------------------------------------------------------- 1 | # rebar 2 | Reinforcement learning utils from https://github.com/andyljones/megastep. 3 | 4 | rebar is Andy Jones’s personal reinforcement learning toolbox. 5 | I'm merely using his toolkit in my project.
All credit goes to [Andy Jones](https://andyljones.com/). 6 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register(id="cookingEnv-v1", 4 | entry_point="gym_cooking.environment:GymCookingEnvironment") 5 | register(id="cookingZooEnv-v0", 6 | entry_point="gym_cooking.environment:CookingZooEnvironment") 7 | -------------------------------------------------------------------------------- /config/algs/sp_mi/rendezvous.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | #   k: v 4 | # config: 5 | #   k: v 6 | 7 | def_config: !include config/algs/sp_mi/default.yaml 8 | 9 | n_iter: 300 10 | eval_interval: 200 11 | z_dim: 4 12 | discrim_coef: 5.0 13 | epochs: 10 14 | num_mb: 2 15 | lr: 0.0003 16 | -------------------------------------------------------------------------------- /config/envs/one_step_matrix.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # trainer_name: 3 | #   k: v 4 | 5 | name: one_step_matrix 6 | n_conventions: 32 7 | k: [32,31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1] # 8 8 | payoffs: np.ones(n_conventions) # np.linspace(1.0,0.5,n_conventions) 9 | -------------------------------------------------------------------------------- /scripts/overcooked/multi_sp.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 444 555 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/multi_sp/overcooked.yaml \ 4 | --env_config_file config/envs/overcooked.yaml \ 5 | --config '{"pop_size": 8, "render": 0, "save_folder": "training_partners_8"}' --seed $seed 6 | done 7 | -------------------------------------------------------------------------------- /config/algs/trajedi/rendezvous.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | #   k: v 4 | # config: 5 | #   k: v 6 | def_config: !include config/algs/trajedi/default.yaml 7 | 8 | n_iter: 300 9 | eval_interval: 200 10 | pop_size: 4 11 | diverse_coef: 10.0 12 | kernel_gamma: 0.0 13 | lr: 0.0003 14 | num_mb: 2 15 | epochs: 10 16 | -------------------------------------------------------------------------------- /config/algs/maven/one_step_matrix.yaml: -------------------------------------------------------------------------------- 1 | def_config: !include config/algs/maven/default.yaml 2 | 3 | render: False 4 | get_q_values: True 5 | lr: 0.0003 6 | z_dim: 32 7 | n_iter: 300 8 | n_sp_episodes: 100 9 | eval_interval: 500 10 | buffer_size: 10000 11 | discrim_coef: 10 12 | start_e: 1 13 | end_e: 0.05 14 | explore_decay_ts: 20000 15 | -------------------------------------------------------------------------------- /config/algs/multi_sp/overcooked.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | #   k: v 4 | # config: 5 | #   k: v 6 | def_config: !include config/algs/multi_sp/default.yaml 7 | 8 | runner: StepsRunner 9 | pop_size: 4 10 | n_iter: 2000 11 | n_sp_ts: 10000 12 | eval_interval: 500 13 | epochs: 15 14 | num_mb: 5 15 | lr: 0.0005 16 | ent_coef: 0.03 17 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mlagents_envs==0.27.0 2 | PettingZoo==1.9.0 3 | pygame==2.0.1 4 | pymunk==6.0.0 5 | ray==1.9.0 6 | pygifsicle==1.0.5 7 | pyyaml-include==1.2.post2 8 | pyglet==1.5.15 9 | opencv-python==3.4.14.53 10 | gym==0.18.3 11 | tensorboard==2.8.0 12 | xvfbwrapper==0.2.9 13 | jupyterlab 14 | wandb 15 | matplotlib 16 | plotly 17 | tqdm 18 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup(name='rebar', 6 | version='0.0.1', 7 | description='rebar is Andy Jones’s personal reinforcement learning toolbox.', 8 | packages=find_packages(), 9 | python_requires='>=3.6', 10 | install_requires=[] 11 | ) 12 | -------------------------------------------------------------------------------- /scripts/overcooked/lipo.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 444 555 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/incompat/overcooked.yaml \ 4 | --env_config_file config/envs/overcooked.yaml \ 5 | --config '{"discrim_coef": 0.5, "pop_size": 8, "render": 0, "save_folder": "training_partners_8", "xp_coef": 0.3, "z_dim": 8}' \ 6 | --seed $seed 7 | done -------------------------------------------------------------------------------- /scripts/overcooked/trajedi.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 444 555 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/trajedi/overcooked.yaml \ 4 | --env_config_file config/envs/overcooked.yaml \ 5 | --config '{"diverse_coef": 5, "kernel_gamma": 0.5, "pop_size": 8, "render": 0, "save_folder": "training_partners_8"}' --env_config '{}' --seed $seed 6 | done 7 | -------------------------------------------------------------------------------- /config/algs/sp_mi/overcooked.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | #   k: v 4 | # config: 5 | #   k: v 6 | 7 | def_config: !include config/algs/sp_mi/default.yaml 8 | 9 | runner: StepsRunner 10 | n_iter: 2000 11 | n_sp_ts: 10000 12 | eval_interval: 500 13 | z_dim: 4 14 | discrim_coef: 10 15 | epochs: 15 16 | num_mb: 5 17 | lr: 0.0005 18 | ent_coef: 0.03 19 | -------------------------------------------------------------------------------- /config/algs/incompat/rendezvous.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | #   k: v 4 | # config: 5 | #   k: v 6 | 7 | def_config: !include config/algs/incompat/default.yaml 8 | 9 | trainer: incompat 10 | n_iter: 300 11 | eval_interval: 300 12 | pop_size: 8 13 | z_dim: 4 14 | xp_coef: 1.0 15 | 16 | epochs: 10 17 | num_mb: 2 18 | lr: 0.0003 19 | discrim_coef: 0.5 20 | -------------------------------------------------------------------------------- /config/algs/trajedi/overcooked.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | #   k: v 4 | # config: 5 | #   k: v 6 | def_config: !include config/algs/trajedi/default.yaml 7 | 8 | runner: StepsRunner 9 | n_iter: 2000 10 | 
n_sp_ts: 10000 11 | eval_interval: 500 12 | pop_size: 4 13 | diverse_coef: 10.0 14 | kernel_gamma: 0.0 15 | lr: 0.0005 16 | num_mb: 5 17 | epochs: 15 18 | ent_coef: 0.03 19 | -------------------------------------------------------------------------------- /scripts/pmr-c/multi_sp.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 1 2 4 8 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/multi_sp/rendezvous.yaml \ 6 | --env_config_file config/envs/rendezvous.yaml \ 7 | --config '{"pop_size": '"${pop_size}"', "save_folder": "results_sweep_rendezvous"}' \ 8 | --env_config '{"mode": "easy"}' --seed $seed 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /scripts/pmr-l/multi_sp.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 1 2 4 8 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/multi_sp/rendezvous.yaml \ 6 | --env_config_file config/envs/rendezvous.yaml \ 7 | --config '{"pop_size": '"${pop_size}"', "save_folder": "results_sweep_rendezvous"}' \ 8 | --env_config '{"mode": "hard"}' --seed $seed 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='cooking-gym', 4 | version='0.0.1', 5 | description='Cooking gym with graphics and ideas based on: "Too Many Cooks: Overcooked environment"', 6 | author='David Rother, Rose E. Wang', 7 | author_email='david@edv-drucksysteme.de', 8 | packages=find_packages() + [""], 9 | install_requires=[] 10 | ) 11 | -------------------------------------------------------------------------------- /scripts/overcooked/multi_maven.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 444 555 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/maven/overcooked.yaml \ 4 | --env_config_file config/envs/overcooked.yaml \ 5 | --config '{"algo_name": "multi_maven", "discrim_coef": 5, "n_iter": 30000, "n_sp_episodes": 4, "pop_size": 8, "render": 0, "save_folder": "training_partners_8", "trainer": "incompat", "z_dim": 8}' --seed $seed 6 | done 7 | -------------------------------------------------------------------------------- /scripts/overcooked/multi_sp_mi.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 444 555 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/sp_mi/overcooked.yaml \ 4 | --env_config_file config/envs/overcooked.yaml \ 5 | --config '{"algo_name": "multi_sp_mi", "discrim_coef": 5, "n_sp_ts": 20000, "pop_size": 8, "render": 0, "save_folder": "training_partners_8", "trainer": "incompat", "z_dim": 8}' \ 6 | --env_config '{}' --seed $seed 7 | done 8 | -------------------------------------------------------------------------------- /scripts/cmg-h/multi_sp.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 8 16 32 64 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/multi_sp/one_step_matrix.yaml \ 6 | --env_config_file config/envs/one_step_matrix.yaml \ 7 | --config '{"pop_size": '"${pop_size}"', "save_folder": 
"results_sweep_one_step_matrix_uneven_m32"}' \ 8 | --env_config '{}' --seed $seed 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /scripts/cmg-s/multi_sp.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 8 16 32 64 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/multi_sp/one_step_matrix.yaml \ 6 | --env_config_file config/envs/one_step_matrix.yaml \ 7 | --config '{"pop_size": '"${pop_size}"', "save_folder": "results_sweep_one_step_matrix_k_8"}' \ 8 | --env_config '{"k": 8}' --seed $seed 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /config/algs/sp_mi/one_step_matrix.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | # k: v 4 | # config: 5 | # k: v 6 | 7 | def_config: !include config/algs/sp_mi/default.yaml 8 | 9 | render: False 10 | get_act_dist: True 11 | n_iter: 300 12 | eval_interval: 300 13 | n_sp_episodes: 100 14 | n_xp_episodes: 100 15 | pop_size: 1 16 | z_dim: 32 17 | lr: 0.0003 18 | ent_coef: 0.0 19 | discrim_coef: 1.0 20 | xp_coef: 0.0 21 | num_mb: 2 22 | epochs: 10 23 | -------------------------------------------------------------------------------- /config/algs/incompat/overcooked.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | # k: v 4 | # config: 5 | # k: v 6 | 7 | def_config: !include config/algs/incompat/default.yaml 8 | 9 | runner: StepsRunner 10 | trainer: incompat 11 | n_iter: 2000 12 | n_sp_ts: 10000 13 | n_xp_ts: 10000 14 | eval_interval: 500 15 | ent_coef: 0.03 16 | pop_size: 4 17 | z_dim: 8 18 | xp_coef: 0.2 19 | 20 | epochs: 15 21 | num_mb: 5 22 | lr: 0.0005 23 | discrim_coef: 0.1 24 | -------------------------------------------------------------------------------- /config/algs/multi_sp/one_step_matrix.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | # k: v 4 | # config: 5 | # k: v 6 | 7 | def_config: !include config/algs/multi_sp/default.yaml 8 | 9 | render: False 10 | get_act_dist: True 11 | n_iter: 300 12 | eval_interval: 300 13 | n_sp_episodes: 100 14 | n_xp_episodes: 100 15 | pop_size: 32 16 | z_dim: 8 17 | lr: 0.0003 18 | ent_coef: 0.0 19 | discrim_coef: 0.0 20 | xp_coef: 0.0 21 | 22 | num_mb: 2 23 | epochs: 10 24 | -------------------------------------------------------------------------------- /scripts/cmg-h/lipo.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 8 16 32 64 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/incompat/one_step_matrix.yaml \ 6 | --env_config_file config/envs/one_step_matrix.yaml \ 7 | --config '{"num_xp_pair_sample": 64, "pop_size": '"${pop_size}"', "save_folder": "results_sweep_one_step_matrix_uneven_m32", "xp_coef": 1}' --env_config '{}' --seed $seed 8 | done 9 | done 10 | -------------------------------------------------------------------------------- /config/algs/trajedi/one_step_matrix.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | # k: v 4 | # config: 5 | # k: v 6 | def_config: !include config/algs/trajedi/default.yaml 7 | 8 | render: False 9 | get_act_dist: 
True 10 | n_iter: 300 11 | eval_interval: 300 12 | n_sp_episodes: 100 13 | n_xp_episodes: 100 14 | pop_size: 32 15 | z_dim: 8 16 | ent_coef: 0.0 17 | discrim_coef: 0.0 18 | diverse_coef: 10.0 19 | kernel_gamma: 0.0 20 | lr: 0.0003 21 | num_mb: 2 22 | epochs: 10 23 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/cooking_world/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ChopFoodStates(Enum): 5 | FRESH = "Fresh" 6 | CHOPPED = "Chopped" 7 | 8 | 9 | class BlenderFoodStates(Enum): 10 | FRESH = "Fresh" 11 | IN_PROGRESS = "InProgress" 12 | MASHED = "Mashed" 13 | 14 | 15 | ONION_INIT_STATE = ChopFoodStates.FRESH 16 | TOMATO_INIT_STATE = ChopFoodStates.FRESH 17 | LETTUCE_INIT_STATE = ChopFoodStates.FRESH 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /config/algs/incompat/one_step_matrix.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | # k: v 4 | # config: 5 | # k: v 6 | 7 | def_config: !include config/algs/incompat/default.yaml 8 | 9 | render: False 10 | eval_all_pairs: False 11 | get_act_dist: True 12 | n_iter: 300 13 | eval_interval: 300 14 | n_sp_episodes: 100 15 | n_xp_episodes: 100 16 | pop_size: 32 17 | z_dim: 8 18 | lr: 0.0003 19 | ent_coef: 0.0 20 | discrim_coef: 0.0 21 | xp_coef: 1.0 22 | num_xp_pair_sample: 32 23 | num_mb: 2 24 | epochs: 10 25 | -------------------------------------------------------------------------------- /scripts/cmg-h/sp_mi.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 8 16 32 64 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/sp_mi/one_step_matrix.yaml \ 6 | --env_config_file config/envs/one_step_matrix.yaml \ 7 | --config '{"discrim_coef": 50, "n_sp_episodes": 6400, "n_workers": 16, "save_folder": "results_sweep_one_step_matrix_uneven_m32", "trainer": "simple", "z_dim": '"${pop_size}"'}' \ 8 | --env_config '{}' --seed $seed 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /scripts/cmg-s/sp_mi.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 8 16 32 64 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/sp_mi/one_step_matrix.yaml \ 6 | --env_config_file config/envs/one_step_matrix.yaml \ 7 | --config '{"discrim_coef": 50, "n_sp_episodes": 6400, "n_workers": 16, "save_folder": "results_sweep_one_step_matrix_k_8", "trainer": "simple", "z_dim": '"${pop_size}"'}' \ 8 | --env_config '{"k": 8}' --seed $seed 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /config/envs/overcooked.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # trainer_name: 3 | # k: v 4 | 5 | name: overcooked 6 | mode: full_divider_salad_4 7 | horizon: 200 8 | recipes: ["LettuceSalad", 9 | "TomatoSalad", 10 | "TomatoLettuceSalad", 11 | "TomatoCarrotSalad", 12 | "ChoppedCarrot", 13 | "ChoppedOnion" 14 | ] 15 | obs_spaces: dense 16 | num_agents: 2 17 | interact_reward: 0.5 18 | progress_reward: 1.0 19 | complete_reward: 10.0 20 | step_cost: 0.05 21 | 
-------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/__init__.py: -------------------------------------------------------------------------------- 1 | """**rebar** helps with reinforcement. That's why it's called rebar! It's a toolkit that has evolved 2 | as I've worked on RL projects. 3 | 4 | Unlike the :mod:`megastep` module, which is stable, documented and feature-complete, rebar is an unstable, 5 | undocumented work-in-progress. It's in the megastep repo because megastep itself uses two of rebar's most useful components: 6 | :class:`~rebar.dotdict.dotdict` and :class:`~rebar.arrdict.arrdict`, while the demo uses a whole lot more. 7 | """ -------------------------------------------------------------------------------- /coop_marl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from coop_marl.envs.gym_maker import GymMaker 2 | from coop_marl.envs.mpe.rendezvous import Rendezvous 3 | from coop_marl.envs.overcooked.overcooked_maker import OvercookedMaker 4 | from coop_marl.envs.one_step_matrix import OneStepMatrixGame 5 | 6 | registered_envs = {} 7 | registered_envs['gym_maker'] = GymMaker.make_env 8 | registered_envs['rendezvous'] = Rendezvous.make_env 9 | registered_envs['overcooked'] = OvercookedMaker.make_env 10 | registered_envs['one_step_matrix'] = OneStepMatrixGame.make_env 11 | -------------------------------------------------------------------------------- /scripts/cmg-s/multi_sp_mi.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 1 2 4 8 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/sp_mi/one_step_matrix.yaml \ 6 | --env_config_file config/envs/one_step_matrix.yaml \ 7 | --config '{"algo_name": "multi_sp_mi", "discrim_coef": 50, "n_sp_episodes": 800, "n_workers": 16, "pop_size": '"${pop_size}"', "save_folder": "results_sweep_one_step_matrix_k_8", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 8 | --env_config '{"k": 8}' --seed $seed 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /scripts/pmr-c/multi_maven.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 1 2 4 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/maven/rendezvous.yaml \ 6 | --env_config_file config/envs/rendezvous.yaml \ 7 | --config '{"algo_name": "multi_maven", "discrim_coef": 10, "n_iter": 4000, "n_sp_episodes": 30, "n_workers": 16, "pop_size": '"${pop_size}"', "save_folder": "results_sweep_rendezvous", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 8 | --env_config '{"mode": "easy"}' --seed $seed 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /coop_marl/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from coop_marl.agents.agent import Agent 2 | 3 | from coop_marl.agents.qmix import QMIXAgent 4 | from coop_marl.agents.incompat_mappo_z import IncompatMAPPOZ 5 | from coop_marl.agents.mappo_trajedi import MAPPOTrajeDiAgent 6 | from coop_marl.agents.mappo_rl2 import MAPPORL2Agent 7 | 8 | __all__ = ['Agent', 9 | 'QMIXAgent', 10 | 'IncompatMAPPOZ', 11 | 'MAPPOTrajeDiAgent', 12 | 'MAPPORL2Agent'] 13 | 14 | registered_agents = {a: eval(a) for a in __all__} # map each exported agent class name to its class 15 | 
-------------------------------------------------------------------------------- /coop_marl/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from coop_marl.trainers.trainer import Trainer, trainer_setup, population_based_setup, population_evaluation, collect_data 2 | from coop_marl.trainers.simple import SimplePSTrainer 3 | from coop_marl.trainers.incompat import IncompatTrainer 4 | from coop_marl.trainers.trajedi import TrajeDiTrainer 5 | from coop_marl.trainers.meta import MetaTrainer 6 | 7 | registered_trainers = {} 8 | 9 | registered_trainers['simple'] = SimplePSTrainer 10 | registered_trainers['incompat'] = IncompatTrainer 11 | registered_trainers['trajedi'] = TrajeDiTrainer 12 | registered_trainers['meta'] = MetaTrainer 13 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | sudo apt-get install gifsicle xvfb -y 3 | conda install -c conda-forge cudatoolkit=11.1 cudnn=8.4.1 4 | pip install --upgrade pip 5 | pip install setuptools==66 6 | pip install wheel==0.38.4 7 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 8 | pip install -r requirements.txt 9 | pip install protobuf==3.20 10 | 11 | # install rebar 12 | cd coop_marl/utils/ 13 | git clone https://github.com/51616/rebar.git 14 | cd rebar 15 | python setup.py develop 16 | cd ../../.. 17 | 18 | # install the overcooked environment 19 | cd coop_marl/envs/overcooked/ 20 | pip install -e . 21 | cd ../../.. 22 | 23 | # install coop_marl 24 | python setup.py develop 25 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/test.py: -------------------------------------------------------------------------------- 1 | from gym_cooking.environment import cooking_zoo 2 | 3 | n_agents = 1 4 | num_humans = 1 5 | max_steps = 100 6 | render = False 7 | 8 | level = 'open_room_salad_easy' 9 | seed = 1 10 | record = False 11 | max_num_timesteps = 1000 12 | recipes = ["LettuceSalad", "LettuceSalad"] 13 | 14 | env = cooking_zoo.parallel_env(level=level, num_agents=n_agents, record=record, 15 | max_steps=max_num_timesteps, recipes=recipes, obs_spaces=["dense"], 16 | interact_reward=0.5, progress_reward=1.0, complete_reward=10.0, 17 | step_cost=0.05) 18 | obs = env.reset() 19 | print(obs) 20 | -------------------------------------------------------------------------------- /config/algs/meta/overcooked.yaml: -------------------------------------------------------------------------------- 1 | 2 | def_config: !include config/algs/meta/default.yaml 3 | 4 | runner: StepsRunner 5 | n_iter: 500 6 | n_ts: 320000 7 | eval_interval: 100 8 | n_eval_ep: 2 9 | render: False 10 | training_device: 'cuda' 11 | 12 | n_workers: 16 # 8 13 | critic_use_local_obs: True 14 | anneal_lr: True 15 | num_anneal_iter: 500 16 | min_anneal_lr: 0.0003 17 | lr: 0.0005 18 | gamma: 0.99 19 | gae_lambda: 0.95 20 | ent_coef: 0.03 21 | clip_param: 0.3 22 | vf_clip_param: 10 23 | vf_coef: 1.0 24 | max_len: 50 25 | num_seq_mb: 1600 # num_seq_mb * max_len timesteps per minibatch -> n_ts/(max_len*num_seq_mb) minibatches (here 320000/(50*1600) = 4) 26 | num_mb: 0 27 | mb_size: 0 28 | epochs: 15 29 | env_wrappers: [ZWrapper, AgentIDWrapper, StateWrapper] 30 | z_dim: 8 31 | 32 | partner_dir: [] 33 | partner_iterations: null 34 | 
-------------------------------------------------------------------------------- /scripts/cmg-h/trajedi.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 8 16 64 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/trajedi/one_step_matrix.yaml \ 6 | --env_config_file config/envs/one_step_matrix.yaml \ 7 | --config '{"diverse_coef": 0.1, "pop_size": '"${pop_size}"', "save_folder": "results_sweep_one_step_matrix_uneven_m32"}' \ 8 | --env_config '{}' --seed $seed 9 | done 10 | done 11 | 12 | for seed in 111 222 333 13 | do 14 | xvfb-run -a python main.py --config_file config/algs/trajedi/one_step_matrix.yaml \ 15 | --env_config_file config/envs/one_step_matrix.yaml \ 16 | --config '{"diverse_coef": 0.2, "pop_size": 32, "save_folder": "results_sweep_one_step_matrix_uneven_m32"}' \ 17 | --env_config '{}' --seed $seed 18 | done 19 | -------------------------------------------------------------------------------- /scripts/cmg-s/trajedi.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/trajedi/one_step_matrix.yaml \ 4 | --env_config_file config/envs/one_step_matrix.yaml \ 5 | --config '{"diverse_coef": 0.05, "pop_size": 8, "save_folder": "results_sweep_one_step_matrix_k_8"}' \ 6 | --env_config '{"k": 8}' --seed $seed 7 | done 8 | 9 | for pop_size in 16 32 64 10 | do 11 | for seed in 111 222 333 12 | do 13 | xvfb-run -a python main.py --config_file config/algs/trajedi/one_step_matrix.yaml \ 14 | --env_config_file config/envs/one_step_matrix.yaml \ 15 | --config '{"diverse_coef": 0.01, "pop_size": '"${pop_size}"', "save_folder": "results_sweep_one_step_matrix_k_8"}' \ 16 | --env_config '{"k": 8}' --seed $seed 17 | done 18 | done 19 | -------------------------------------------------------------------------------- /coop_marl/agents/agent.py: -------------------------------------------------------------------------------- 1 | from coop_marl.utils import Arrdict 2 | 3 | class Agent: 4 | # Every agent should have these functions 5 | def __init__(self, config): 6 | self.validate_config(config) 7 | 8 | def act(self, inp): 9 | raise NotImplementedError 10 | 11 | def preprocess(self, traj): 12 | raise NotImplementedError 13 | 14 | def train(self, batch): 15 | raise NotImplementedError 16 | 17 | # def get_dummy_decision(self): 18 | # raise NotImplementedError 19 | 20 | def get_prev_decision_view(self): 21 | # raise NotImplementedError 22 | return Arrdict() 23 | 24 | def reset(self): 25 | raise NotImplementedError 26 | 27 | def validate_config(self, config): 28 | raise NotImplementedError -------------------------------------------------------------------------------- /scripts/cmg-s/lipo.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/incompat/one_step_matrix.yaml \ 4 | --env_config_file config/envs/one_step_matrix.yaml \ 5 | --config '{"num_xp_pair_sample": 64, "pop_size": 8, "save_folder": "results_sweep_one_step_matrix_k_8", "xp_coef": 0.5}' \ 6 | --env_config '{"k": 8}' --seed $seed 7 | done 8 | 9 | for pop_size in 16 32 64 10 | do 11 | for seed in 111 222 333 12 | do 13 | xvfb-run -a python main.py --config_file config/algs/incompat/one_step_matrix.yaml \ 14 | --env_config_file config/envs/one_step_matrix.yaml \ 15 | --config '{"num_xp_pair_sample": 64, 
"pop_size":'"${pop_size}"', "save_folder": "results_sweep_one_step_matrix_k_8", "xp_coef": 1}' \ 16 | --env_config '{"k": 8}' --seed $seed 17 | done 18 | done 19 | -------------------------------------------------------------------------------- /config/algs/default.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # trainer_name: 3 | # k: v 4 | trainer: simple 5 | runner: EpisodesRunner 6 | render: True 7 | render_only_sp: False 8 | render_mode: 'rgb_array' 9 | vary_z_eval: False 10 | n_workers: 16 11 | flatten_traj: True 12 | num_cpus: 1 13 | use_gpu: False 14 | debug: False 15 | device: 'cpu' 16 | n_grad_cum: 0 17 | training_device: 'cpu' 18 | save_folder: 'results' 19 | checkpoint: Null 20 | run_name: '' 21 | save_interval: 0 22 | 23 | anneal_lr: False 24 | l2_reg_coef: 0.0 25 | min_action: -1 26 | max_action: 1 27 | norm_obs: True 28 | norm_ret: True 29 | norm_state: True 30 | anneal_lr: False 31 | use_value_norm: False 32 | pol_init_var: 1.0 33 | hidden_size: 64 34 | num_hidden: 2 35 | clip_v_loss: True 36 | save_dir: Null 37 | shared_z: True 38 | 39 | use_bandit: False 40 | uniform_selector_keep_last: False 41 | -------------------------------------------------------------------------------- /config/algs/trajedi/default.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | # k: v 4 | # config: 5 | # k: v 6 | def_config: !include config/algs/default.yaml 7 | 8 | render: True 9 | render_mode: 'rgb_array' 10 | num_cpus: 1 11 | use_gpu: False 12 | debug: False 13 | 14 | algo_name: trajedi 15 | trainer: trajedi 16 | runner: EpisodesRunner 17 | agent_name: MAPPOTrajeDiAgent 18 | 19 | use_br: False 20 | n_iter: 500 21 | pop_size: 2 22 | diverse_coef: 1.0 23 | kernel_gamma: 0.0 24 | flatten_traj: True 25 | 26 | eval_interval: 50 27 | n_sp_episodes: 50 28 | n_xp_episodes: 50 29 | n_eval_ep: 10 30 | 31 | z_dim: 4 32 | z_discrete: True 33 | gamma: 0.99 34 | lr: 0.0001 35 | vf_coef: 0.5 36 | ent_coef: 0.03 37 | epochs: 5 38 | num_mb: 3 39 | gae_lambda: 0.95 40 | clip_param: 0.3 41 | vf_clip_param: 10.0 42 | 43 | env_wrappers: [ZWrapper, AgentIDWrapper, StateWrapper] -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/misc/game/utils.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | 3 | 4 | class Color: 5 | BLACK = (0, 0, 0) 6 | FLOOR = (245, 230, 210) # light gray 7 | COUNTER = (220, 170, 110) # tan/gray 8 | COUNTER_BORDER = (114, 93, 51) # darker tan 9 | DELIVERY = (96, 96, 96) # grey 10 | 11 | 12 | KeyToTuple = { 13 | pygame.K_UP: (0, -1), # 273 14 | pygame.K_DOWN: (0, 1), # 274 15 | pygame.K_RIGHT: (1, 0), # 275 16 | pygame.K_LEFT: (-1, 0), # 276 17 | } 18 | 19 | KeyToTuple_human1 = { 20 | pygame.K_UP: 4, # 273 21 | pygame.K_DOWN: 3, # 274 22 | pygame.K_RIGHT: 2, # 275 23 | pygame.K_LEFT: 1, # 276 24 | pygame.K_SPACE: 0, 25 | pygame.K_f: 5 26 | } 27 | 28 | KeyToTuple_human2 = { 29 | pygame.K_w: (0, -1), 30 | pygame.K_s: (0, 1), 31 | pygame.K_d: (1, 0), 32 | pygame.K_a: (-1, 0), 33 | } 34 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/interrupting.py: -------------------------------------------------------------------------------- 1 | import signal 2 | from .contextlib import maybeasynccontextmanager 3 | import logging 4 | 5 | log = 
logging.getLogger(__name__) 6 | 7 | class Interrupter: 8 | 9 | def __init__(self): 10 | self._is_set = False 11 | 12 | def check(self): 13 | if self._is_set: 14 | self.reset() 15 | raise KeyboardInterrupt() 16 | 17 | def handle(self, signum, frame): 18 | log.info('Setting interrupt flag') 19 | self._is_set = True 20 | 21 | def reset(self): 22 | self._is_set = False 23 | 24 | _INTERRUPTER = Interrupter() 25 | 26 | @maybeasynccontextmanager 27 | def interrupter(): 28 | old = signal.signal(signal.SIGINT, _INTERRUPTER.handle) 29 | try: 30 | yield _INTERRUPTER 31 | finally: 32 | _INTERRUPTER.reset() 33 | signal.signal(signal.SIGINT, old) -------------------------------------------------------------------------------- /scripts/pmr-c/sp_mi.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 1 4 8 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/sp_mi/rendezvous.yaml \ 6 | --env_config_file config/envs/rendezvous.yaml \ 7 | --config '{"discrim_coef": 1, "n_sp_episodes": 400, "n_workers": 16, "save_folder": "results_sweep_rendezvous", "trainer": "simple", "z_dim": '"${pop_size}"'}' \ 8 | --env_config '{"mode": "easy"}' --seed $seed 9 | done 10 | done 11 | 12 | 13 | for seed in 111 222 333 14 | do 15 | xvfb-run -a python main.py --config_file config/algs/sp_mi/rendezvous.yaml \ 16 | --env_config_file config/envs/rendezvous.yaml \ 17 | --config '{"discrim_coef": 10, "n_sp_episodes": 400, "n_workers": 16, "save_folder": "results_sweep_rendezvous", "trainer": "simple", "z_dim": 2}' \ 18 | --env_config '{"mode": "easy"}' --seed $seed 19 | done 20 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/open_room_tomato_salad.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- -\n- -\n- -\n- -\n- -\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 1, "X_POSITION": [1], "Y_POSITION": [6]}}, 4 | {"CutBoard": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [5]}}, 5 | {"DeliverSquare": {"COUNT": 1, "X_POSITION": [3], "Y_POSITION": [0]}}], 6 | "DYNAMIC_OBJECTS": [{"Plate": {"COUNT": 1, "X_POSITION": [4], "Y_POSITION": [0]}}, 7 | {"Plate": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [3]}}, 8 | {"Tomato": {"COUNT": 1, "X_POSITION": [2], "Y_POSITION": [6]}}, 9 | {"Tomato": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [4]}} 10 | ], 11 | "AGENTS": [{"MAX_COUNT": 4, "X_POSITION": [2, 3, 4], "Y_POSITION": [2, 3, 4]}] 12 | } -------------------------------------------------------------------------------- /config/algs/maven/default.yaml: -------------------------------------------------------------------------------- 1 | def_config: !include config/algs/default.yaml 2 | 3 | algo_name: maven 4 | trainer: simple 5 | agent_name: QMIXAgent 6 | runner: EpisodesRunner 7 | flatten_traj: False 8 | vary_z_eval: True 9 | n_iter: 500 10 | n_sp_episodes: 10 11 | n_eval_ep: 10 # for each z value 12 | eval_interval: 50 13 | hidden_dim: 64 14 | mixing_embed_dim: 32 15 | hypernet_embed: 128 16 | buffer_size: 1000 17 | batch_size: 128 18 | 19 | maven: True 20 | discrim_coef: 0.1 21 | z_dim: 4 22 | z_discrete: True 23 | discrim_hidden_dim: 64 24 | z_policy: False 25 | 26 | lr: 0.001 27 | gamma: 0.99 28 | start_e: 1 29 | end_e: 0.05 30 | explore_decay_ts: 100000 31 | target_update_freq: 25 32 | env_wrappers: [ZWrapper, AgentIDWrapper, StateWrapper] 33 | 34 | # for 
incompat trainer 35 | pop_size: 1 36 | num_xp_pair_sample: 0 37 | use_bandit: False 38 | pg_xp_max_only: False 39 | value_xp_max_only: False 40 | eval_all_pairs: False 41 | -------------------------------------------------------------------------------- /scripts/pmr-c/maven.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 1 2 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/maven/rendezvous.yaml \ 6 | --env_config_file config/envs/rendezvous.yaml \ 7 | --config '{"discrim_coef": 10, "n_iter": 4000, "n_sp_episodes": 30, "n_workers": 16, "save_folder": "results_sweep_rendezvous", "trainer": "simple", "z_dim": '"${pop_size}"'}' --env_config '{"mode": "easy"}' --seed $seed 8 | done 9 | done 10 | 11 | for pop_size in 4 8 12 | do 13 | for seed in 111 222 333 14 | do 15 | xvfb-run -a python main.py --config_file config/algs/maven/rendezvous.yaml \ 16 | --env_config_file config/envs/rendezvous.yaml \ 17 | --config '{"discrim_coef": 1, "n_iter": 4000, "n_sp_episodes": 30, "n_workers": 16, "save_folder": "results_sweep_rendezvous", "trainer": "simple", "z_dim": '"${pop_size}"'}' --env_config '{"mode": "easy"}' --seed $seed 18 | done 19 | done 20 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/open_room_salad_easy.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- -\n- -\n- -\n- -\n- -\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 2, "X_POSITION": [0,6], "Y_POSITION": [3]}}, 4 | {"DeliverSquare": {"COUNT": 2, "X_POSITION": [0,6], "Y_POSITION": [4]}} 5 | ], 6 | "DYNAMIC_OBJECTS": [ 7 | {"Plate": {"COUNT": 2, "X_POSITION": [0,6], "Y_POSITION": [1]}}, 8 | {"Lettuce": {"COUNT": 2, "X_POSITION": [0,6], "Y_POSITION": [5]}}, 9 | {"Tomato": {"COUNT": 2, "X_POSITION": [1,2], "Y_POSITION": [0]}}, 10 | {"Onion": {"COUNT": 2, "X_POSITION": [4,5], "Y_POSITION": [0]}}, 11 | {"Carrot": {"COUNT": 2, "X_POSITION": [2,4], "Y_POSITION": [6]}} 12 | ], 13 | "AGENTS": [{"MAX_COUNT": 4, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [1 ,2, 3, 4, 5]}] 14 | } -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/full_divider_salad_easy.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- - -\n- - -\n- - -\n- - -\n- - -\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 2, "X_POSITION": [0,6], "Y_POSITION": [3]}}, 4 | {"DeliverSquare": {"COUNT": 2, "X_POSITION": [0,6], "Y_POSITION": [4]}} 5 | ], 6 | "DYNAMIC_OBJECTS": [ 7 | {"Plate": {"COUNT": 2, "X_POSITION": [0,6], "Y_POSITION": [1]}}, 8 | {"Lettuce": {"COUNT": 2, "X_POSITION": [0,6], "Y_POSITION": [5]}}, 9 | {"Tomato": {"COUNT": 2, "X_POSITION": [1,2], "Y_POSITION": [0]}}, 10 | {"Onion": {"COUNT": 2, "X_POSITION": [4,5], "Y_POSITION": [0]}}, 11 | {"Carrot": {"COUNT": 2, "X_POSITION": [2,4], "Y_POSITION": [6]}} 12 | ], 13 | "AGENTS": [{"MAX_COUNT": 4, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [1 ,2, 3, 4, 5]}] 14 | } -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/open_room_tomato_salad_r.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- -\n- -\n- -\n- -\n- 
-\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}, 4 | {"CutBoard": {"COUNT": 1, "X_POSITION": [0, 6], "Y_POSITION": [1 ,2, 3, 4, 5]}}, 5 | {"DeliverSquare": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}], 6 | "DYNAMIC_OBJECTS": [{"Plate": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}, 7 | {"Plate": {"COUNT": 1, "X_POSITION": [0, 6], "Y_POSITION": [1 ,2, 3, 4, 5]}}, 8 | {"Tomato": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}, 9 | {"Tomato": {"COUNT": 1, "X_POSITION": [0, 6], "Y_POSITION": [1 ,2, 3, 4, 5]}} 10 | ], 11 | "AGENTS": [{"MAX_COUNT": 4, "X_POSITION": [2, 3, 4], "Y_POSITION": [2, 3, 4]}] 12 | } -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/open_room_blender.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- -\n- -\n- -\n- -\n- -\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}, 4 | {"Blender": {"COUNT": 1, "X_POSITION": [0, 6], "Y_POSITION": [1 ,2, 3, 4, 5]}}, 5 | {"DeliverSquare": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}], 6 | "DYNAMIC_OBJECTS": [{"Plate": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}, 7 | {"Plate": {"COUNT": 1, "X_POSITION": [0, 6], "Y_POSITION": [1 ,2, 3, 4, 5]}}, 8 | {"Lettuce": {"COUNT": 1, "X_POSITION": [0, 6], "Y_POSITION": [1 ,2, 3, 4, 5]}}, 9 | {"Carrot": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}} 10 | ], 11 | "AGENTS": [{"MAX_COUNT": 4, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [1 ,2, 3, 4, 5]}] 12 | } -------------------------------------------------------------------------------- /scripts/pmr-c/multi_sp_mi.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/sp_mi/rendezvous.yaml \ 4 | --env_config_file config/envs/rendezvous.yaml \ 5 | --config '{"algo_name": "multi_sp_mi", "discrim_coef": 1, "n_sp_episodes": 400, "n_workers": 16, "pop_size": 1, "save_folder": "results_sweep_rendezvous", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 6 | --env_config '{"mode": "easy"}' --seed $seed 7 | done 8 | 9 | for pop_size in 2 4 10 | do 11 | for seed in 111 222 333 12 | do 13 | xvfb-run -a python main.py --config_file config/algs/sp_mi/rendezvous.yaml \ 14 | --env_config_file config/envs/rendezvous.yaml \ 15 | --config '{"algo_name": "multi_sp_mi", "discrim_coef": 1, "n_sp_episodes": 400, "n_workers": 16, "pop_size": '"${pop_size}"', "save_folder": "results_sweep_rendezvous", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 16 | --env_config '{"mode": "easy"}' --seed $seed 17 | done 18 | done 19 | -------------------------------------------------------------------------------- /scripts/cmg-s/multi_maven.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/maven/one_step_matrix.yaml \ 4 | --env_config_file config/envs/one_step_matrix.yaml \ 5 | --config '{"algo_name": "multi_maven", "discrim_coef": 5, "n_sp_episodes": 800, "n_workers": 16, "pop_size": 1, "save_folder": "results_sweep_one_step_matrix_k_8", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 6 | 
--env_config '{"k": 8}' --seed $seed 7 | done 8 | 9 | for pop_size in 2 4 8 10 | do 11 | for seed in 111 222 333 12 | do 13 | xvfb-run -a python main.py --config_file config/algs/maven/one_step_matrix.yaml \ 14 | --env_config_file config/envs/one_step_matrix.yaml \ 15 | --config '{"algo_name": "multi_maven", "discrim_coef": 1, "n_sp_episodes": 800, "n_workers": 16, "pop_size": '"${pop_size}"', "save_folder": "results_sweep_one_step_matrix_k_8", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 16 | --env_config '{"k": 8}' --seed $seed 17 | done 18 | done -------------------------------------------------------------------------------- /scripts/cmg-h/multi_sp_mi.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/sp_mi/one_step_matrix.yaml \ 4 | --env_config_file config/envs/one_step_matrix.yaml \ 5 | --config '{"algo_name": "multi_sp_mi", "discrim_coef": 10, "n_sp_episodes": 800, "n_workers": 16, "pop_size": 1, "save_folder": "results_sweep_one_step_matrix_uneven_m32", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 6 | --env_config {} --seed $seed 7 | done 8 | 9 | for pop_size in 2 4 8 10 | do 11 | for seed in 111 222 333 12 | do 13 | xvfb-run -a python main.py --config_file config/algs/sp_mi/one_step_matrix.yaml \ 14 | --env_config_file config/envs/one_step_matrix.yaml \ 15 | --config '{"algo_name": "multi_sp_mi", "discrim_coef": 50, "n_sp_episodes": 800, "n_workers": 16, "pop_size": '"${pop_size}"', "save_folder": "results_sweep_one_step_matrix_uneven_m32", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 16 | --env_config {} --seed $seed 17 | done 18 | done 19 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/environment/graphics/graphic_store.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym_cooking.cooking_world.world_objects import * 3 | from collections import namedtuple 4 | 5 | 6 | GraphicScaling = namedtuple("GraphicScaling", ["holding_scale", "container_scale"]) 7 | 8 | 9 | class GraphicStore: 10 | 11 | OBJECT_PROPERTIES = {Blender: GraphicScaling(None, 0.5)} 12 | 13 | def __init__(self, world_height, world_width): 14 | self.scale = 80 # num pixels per tile 15 | self.holding_scale = 0.5 16 | self.container_scale = 0.7 17 | self.width = self.scale * world_width 18 | self.height = self.scale * world_height 19 | self.tile_size = (self.scale, self.scale) 20 | self.holding_size = tuple((self.holding_scale * np.asarray(self.tile_size)).astype(int)) 21 | self.container_size = tuple((self.container_scale * np.asarray(self.tile_size)).astype(int)) 22 | self.holding_container_size = tuple((self.container_scale * np.asarray(self.holding_size)).astype(int)) 23 | 24 | 25 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/full_divider_salad_more_ingred.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- - -\n- - -\n- - -\n- - -\n- - -\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [3]}}, 4 | {"DeliverSquare": {"COUNT": 1, "X_POSITION": [6], "Y_POSITION": [4]}} 5 | ], 6 | "DYNAMIC_OBJECTS": [ 7 | {"Plate": {"COUNT": 2, "X_POSITION": [0,6], "Y_POSITION": [1]}}, 8 | {"Lettuce": {"COUNT": 1, "X_POSITION": 
[6], "Y_POSITION": [5]}}, 9 | {"Tomato": {"COUNT": 2, "X_POSITION": [6], "Y_POSITION": [2,3]}}, 10 | {"Onion": {"COUNT": 2, "X_POSITION": [4,5], "Y_POSITION": [0]}}, 11 | {"Carrot": {"COUNT": 2, "X_POSITION": [4,5], "Y_POSITION": [6]}} 12 | ], 13 | "AGENTS": [{"MAX_COUNT": 1, "X_POSITION": [1 ,2], "Y_POSITION": [1 ,2, 3, 4, 5]}, 14 | {"MAX_COUNT": 1, "X_POSITION": [4 ,5], "Y_POSITION": [1 ,2, 3, 4, 5]} 15 | ] 16 | } -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/full_divider_salad_static.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- - -\n- - -\n- - -\n- - -\n- - -\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [3]}}, 4 | {"DeliverSquare": {"COUNT": 1, "X_POSITION": [6], "Y_POSITION": [3]}} 5 | ], 6 | "DYNAMIC_OBJECTS": [ 7 | {"Plate": {"COUNT": 1, "X_POSITION": [3], "Y_POSITION": [3]}}, 8 | {"Lettuce": {"COUNT": 2, "X_POSITION": [0,6], "Y_POSITION": [5]}}, 9 | {"Tomato": {"COUNT": 2, "X_POSITION": [1,2], "Y_POSITION": [0]}}, 10 | {"Onion": {"COUNT": 2, "X_POSITION": [4,5], "Y_POSITION": [0]}}, 11 | {"Carrot": {"COUNT": 2, "X_POSITION": [2,4], "Y_POSITION": [6]}} 12 | ], 13 | "AGENTS": [{"MAX_COUNT": 1, "X_POSITION": [1 ,2], "Y_POSITION": [1 ,2, 3, 4, 5]}, 14 | {"MAX_COUNT": 1, "X_POSITION": [4 ,5], "Y_POSITION": [1 ,2, 3, 4, 5]} 15 | ] 16 | } 17 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/storing.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pickle 3 | from . import paths 4 | import time 5 | 6 | def store_latest(run_name, objs, throttle=0): 7 | path = paths.path(run_name, 'storing').with_suffix('.pkl') 8 | if path.exists(): 9 | if (time.time() - path.lstat().st_mtime) < throttle: 10 | return False 11 | 12 | state_dicts = {k: v.state_dict() for k, v in objs.items()} 13 | bs = pickle.dumps(state_dicts) 14 | path.with_suffix('.tmp').write_bytes(bs) 15 | path.with_suffix('.tmp').rename(path) 16 | 17 | return True 18 | 19 | def runs(): 20 | return paths.runs() 21 | 22 | def stored(run_name=-1): 23 | ps = paths.subdirectory(run_name, 'storing').glob('*.pkl') 24 | infos = [] 25 | for p in ps: 26 | infos.append({ 27 | **paths.parse(p), 28 | 'path': p}) 29 | 30 | return pd.DataFrame(infos) 31 | 32 | def load(run_name=-1, procname='MainProcess'): 33 | path = stored(run_name).loc[lambda df: df.procname == procname].iloc[-1].path 34 | return pickle.loads(path.read_bytes()) -------------------------------------------------------------------------------- /config/algs/meta/default.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | # k: v 4 | # config: 5 | # k: v 6 | def_config: !include config/algs/default.yaml 7 | 8 | algo_name: meta_rl # used for creating a save directory 9 | trainer: meta 10 | runner: EpisodesRunner 11 | agent_name: MAPPORL2Agent 12 | render: True 13 | render_mode: 'rgb_array' 14 | eval_interval: 50 15 | n_eval_ep: 10 16 | num_cpus: 1 17 | use_gpu: False 18 | debug: False 19 | 20 | n_workers: 4 21 | n_iter: 400 22 | eval_interval: 50 23 | n_episodes: 50 24 | n_ts: 5000 25 | n_eval_ep: 10 26 | z_dim: 4 27 | z_discrete: True 28 | flatten_traj: True 29 | critic_use_local_obs: False 30 | 31 | hidden_size: 256 32 | lr: 0.001 33 | gamma: 0.99 34 | 
gae_lambda: 0.95 35 | ent_coef: 0.01 36 | clip_param: 0.3 37 | vf_clip_param: 10 38 | vf_coef: 1.0 39 | max_len: 50 40 | num_seq_mb: 100 # 100*50 timesteps per minibatch 41 | num_mb: 0 42 | mb_size: 0 43 | epochs: 10 44 | env_wrappers: [ZWrapper, AgentIDWrapper, StateWrapper] 45 | shared_z: False 46 | 47 | partner_dir: [] 48 | partner_iterations: [] 49 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/full_divider_salad_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- - -\n- - -\n- - -\n- - -\n- - -\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [3]}}, 4 | {"DeliverSquare": {"COUNT": 1, "X_POSITION": [6], "Y_POSITION": [3]}} 5 | ], 6 | "DYNAMIC_OBJECTS": [ 7 | {"Plate": {"COUNT": 1, "X_POSITION": [3], "Y_POSITION": [1,2,3,4,5]}}, 8 | {"Lettuce": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [1,2,3,4,5]}}, 9 | {"Tomato": {"COUNT": 1, "X_POSITION": [6], "Y_POSITION": [1,2,3,4,5]}}, 10 | {"Onion": {"COUNT": 1, "X_POSITION": [6], "Y_POSITION": [1,2,3,4,5]}}, 11 | {"Carrot": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [1,2,3,4,5]}} 12 | ], 13 | "AGENTS": [{"MAX_COUNT": 1, "X_POSITION": [1 ,2], "Y_POSITION": [1 ,2, 3, 4, 5]}, 14 | {"MAX_COUNT": 1, "X_POSITION": [4 ,5], "Y_POSITION": [1 ,2, 3, 4, 5]} 15 | ] 16 | } -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/full_divider_salad_3.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- - -\n- - -\n- - -\n- - -\n- - -\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [3]}}, 4 | {"DeliverSquare": {"COUNT": 1, "X_POSITION": [6], "Y_POSITION": [3]}} 5 | ], 6 | "DYNAMIC_OBJECTS": [ 7 | {"Plate": {"COUNT": 1, "X_POSITION": [3], "Y_POSITION": [1,2,3,4,5]}}, 8 | {"Lettuce": {"COUNT": 2, "X_POSITION": [0], "Y_POSITION": [1,2,3,4,5]}}, 9 | {"Tomato": {"COUNT": 2, "X_POSITION": [6], "Y_POSITION": [1,2,3,4,5]}}, 10 | {"Onion": {"COUNT": 2, "X_POSITION": [6], "Y_POSITION": [1,2,3,4,5]}}, 11 | {"Carrot": {"COUNT": 2, "X_POSITION": [0], "Y_POSITION": [1,2,3,4,5]}} 12 | ], 13 | "AGENTS": [{"MAX_COUNT": 1, "X_POSITION": [1 ,2], "Y_POSITION": [1 ,2, 3, 4, 5]}, 14 | {"MAX_COUNT": 1, "X_POSITION": [4 ,5], "Y_POSITION": [1 ,2, 3, 4, 5]} 15 | ] 16 | } -------------------------------------------------------------------------------- /scripts/cmg-h/multi_maven.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 1 4 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/maven/one_step_matrix.yaml \ 6 | --env_config_file config/envs/one_step_matrix.yaml \ 7 | --config '{"algo_name": "multi_maven", "discrim_coef": 1, "n_sp_episodes": 800, "n_workers": 16, "pop_size": '"${pop_size}"', "save_folder": "results_sweep_one_step_matrix_uneven_m32", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 8 | --env_config '{}' --seed $seed 9 | done 10 | done 11 | 12 | for pop_size in 2 8 13 | do 14 | for seed in 111 222 333 15 | do 16 | xvfb-run -a python main.py --config_file config/algs/maven/one_step_matrix.yaml \ 17 | --env_config_file config/envs/one_step_matrix.yaml \ 18 | --config '{"algo_name": "multi_maven", "discrim_coef": 5, "n_sp_episodes": 800, 
"n_workers": 16, "pop_size": '"${pop_size}"', "save_folder": "results_sweep_one_step_matrix_uneven_m32", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 19 | --env_config '{}' --seed $seed 20 | done 21 | done -------------------------------------------------------------------------------- /scripts/pmr-c/lipo.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 1 4 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/incompat/rendezvous.yaml \ 6 | --env_config_file config/envs/rendezvous.yaml \ 7 | --config '{"pop_size": '"${pop_size}"', "save_folder": "results_sweep_rendezvous", "xp_coef": 0.5}' \ 8 | --env_config '{"mode": "easy"}' --seed $seed 9 | done 10 | done 11 | 12 | for seed in 111 222 333 13 | do 14 | xvfb-run -a python main.py --config_file config/algs/incompat/rendezvous.yaml \ 15 | --env_config_file config/envs/rendezvous.yaml \ 16 | --config '{"pop_size": 2, "save_folder": "results_sweep_rendezvous", "xp_coef": 0.1}' \ 17 | --env_config '{"mode": "easy"}' --seed $seed 18 | done 19 | 20 | for seed in 111 222 333 21 | do 22 | xvfb-run -a python main.py --config_file config/algs/incompat/rendezvous.yaml \ 23 | --env_config_file config/envs/rendezvous.yaml \ 24 | --config '{"pop_size": 8, "save_folder": "results_sweep_rendezvous", "xp_coef": 0.25}' \ 25 | --env_config '{"mode": "easy"}' --seed $seed 26 | done 27 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/full_divider_salad_4.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- - -\n- - -\n- - -\n- - -\n- - -\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [3]}}, 4 | {"DeliverSquare": {"COUNT": 1, "X_POSITION": [6], "Y_POSITION": [3]}} 5 | ], 6 | "DYNAMIC_OBJECTS": [ 7 | {"Plate": {"COUNT": 1, "X_POSITION": [3], "Y_POSITION": [1,2,3,4,5]}}, 8 | {"Lettuce": {"COUNT": 1, "X_POSITION": [0,3], "Y_POSITION": [1,2,3,4,5]}}, 9 | {"Tomato": {"COUNT": 1, "X_POSITION": [3,6], "Y_POSITION": [1,2,3,4,5]}}, 10 | {"Onion": {"COUNT": 1, "X_POSITION": [0,3], "Y_POSITION": [1,2,3,4,5]}}, 11 | {"Carrot": {"COUNT": 1, "X_POSITION": [3,6], "Y_POSITION": [1,2,3,4,5]}} 12 | ], 13 | "AGENTS": [{"MAX_COUNT": 1, "X_POSITION": [1 ,2], "Y_POSITION": [1 ,2, 3, 4, 5]}, 14 | {"MAX_COUNT": 1, "X_POSITION": [4 ,5], "Y_POSITION": [1 ,2, 3, 4, 5]} 15 | ] 16 | } -------------------------------------------------------------------------------- /coop_marl/utils/nn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | FLOAT_MIN = -3.4e38 5 | FLOAT_MAX = 3.4e38 6 | 7 | # https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_continuous_action.py 8 | def ortho_layer_init(layer, std=np.sqrt(2), bias_const=0.0): 9 | torch.nn.init.orthogonal_(layer.weight, std) 10 | torch.nn.init.constant_(layer.bias, bias_const) 11 | return layer 12 | 13 | def k_uniform_init(layer, a=0, mode='fan_in', nonlinearity='leaky_relu'): 14 | torch.nn.init.kaiming_uniform_(layer.weight, a, mode, nonlinearity) 15 | torch.nn.init.constant_(layer.bias, 0) 16 | return layer 17 | 18 | def dict_to_tensor(obs_dict, device, axis=0, dtype=torch.float): 19 | # takes a dict of obs (e.g. 
player->obs) and returns tensor of obs as [N_player, obs_dim] 20 | return torch.stack([torch.as_tensor(o, dtype=dtype, device=device) for o in obs_dict.values()], axis=axis) 21 | 22 | def dict_to_np(obs_dict,*, axis=0, dtype=np.float32): 23 | # takes a dict of obs (e.g. player->obs) and returns tensor of obs as [N_player, obs_dim] 24 | return np.stack([np.array(o, dtype=dtype) for o in obs_dict.values()], axis=axis) -------------------------------------------------------------------------------- /config/algs/sp_mi/default.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | # k: v 4 | # config: 5 | # k: v 6 | def_config: !include config/algs/default.yaml 7 | 8 | trainer: incompat 9 | runner: EpisodesRunner 10 | algo_name: sp_mi # used for creating a save directory 11 | render: True 12 | render_mode: 'rgb_array' 13 | eval_interval: 50 14 | n_eval_ep: 10 15 | num_cpus: 1 16 | use_gpu: False 17 | debug: False 18 | agent_name: IncompatMAPPOZ 19 | flatten_traj: False 20 | vary_z_eval: True 21 | eval_all_pairs: False 22 | num_xp_pair_sample: 0 23 | 24 | parent_only: False 25 | n_iter: 400 26 | eval_interval: 50 27 | n_sp_episodes: 50 28 | n_xp_episodes: 50 29 | n_eval_ep: 10 30 | z_dim: 4 31 | z_discrete: True 32 | pop_size: 1 33 | flatten_traj: False 34 | pg_xp_max_only: True 35 | value_xp_max_only: False 36 | anneal_xp: False 37 | 38 | lr: 0.0001 39 | use_hypernet: False 40 | # hyper_l2_reg_coef: 0.0001 41 | xp_coef: 0.0 42 | discrim_coef: 0.1 43 | gamma: 0.99 44 | gae_lambda: 0.95 45 | ent_coef: 0.03 46 | clip_param: 0.3 47 | vf_clip_param: 10 48 | vf_coef: 0.5 49 | num_mb: 5 50 | epochs: 3 51 | env_wrappers: [ZWrapper, AgentIDWrapper, StateWrapper] 52 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/open_room_salad.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- -\n- -\n- -\n- -\n- -\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}, 4 | {"CutBoard": {"COUNT": 1, "X_POSITION": [0, 6], "Y_POSITION": [1 ,2, 3, 4, 5]}}, 5 | {"DeliverSquare": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}], 6 | "DYNAMIC_OBJECTS": [{"Plate": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}, 7 | {"Plate": {"COUNT": 1, "X_POSITION": [0, 6], "Y_POSITION": [1 ,2, 3, 4, 5]}}, 8 | {"Lettuce": {"COUNT": 1, "X_POSITION": [0, 6], "Y_POSITION": [1 ,2, 3, 4, 5]}}, 9 | {"Tomato": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}, 10 | {"Onion": {"COUNT": 1, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [0, 6]}}, 11 | {"Tomato": {"COUNT": 1, "X_POSITION": [0, 6], "Y_POSITION": [1 ,2, 3, 4, 5]}} 12 | ], 13 | "AGENTS": [{"MAX_COUNT": 4, "X_POSITION": [1 ,2, 3, 4, 5], "Y_POSITION": [1 ,2, 3, 4, 5]}] 14 | } -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/environment/environment.py: -------------------------------------------------------------------------------- 1 | from gym_cooking.environment import cooking_zoo 2 | from gym.utils import seeding 3 | 4 | import gym 5 | 6 | 7 | class GymCookingEnvironment(gym.Env): 8 | """Environment object for Overcooked.""" 9 | 10 | metadata = {'render.modes': ['human'], 'name': "cooking_zoo"} 11 | 12 | def __init__(self, level, record, max_steps, recipe, 
obs_spaces=["numeric"]): 13 | super().__init__() 14 | self.num_agents = 1 15 | self.zoo_env = cooking_zoo.parallel_env(level=level, num_agents=self.num_agents, record=record, 16 | max_steps=max_steps, recipes=[recipe], obs_spaces=obs_spaces) 17 | self.observation_space = self.zoo_env.observation_spaces["player_0"] 18 | self.action_space = self.zoo_env.action_spaces["player_0"] 19 | 20 | def step(self, action): 21 | converted_action = {"player_0": action} 22 | obs, reward, done, info = self.zoo_env.step(converted_action) 23 | return obs["player_0"], reward["player_0"], done["player_0"], info["player_0"] 24 | 25 | def reset(self): 26 | return self.zoo_env.reset()["player_0"] 27 | 28 | def render(self, mode='human'): 29 | pass 30 | 31 | -------------------------------------------------------------------------------- /scripts/pmr-l/sp_mi.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 1 2 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/sp_mi/rendezvous.yaml \ 6 | --env_config_file config/envs/rendezvous.yaml \ 7 | --config '{"discrim_coef": 10, "n_sp_episodes": 400, "n_workers": 16, "save_folder": "results_sweep_rendezvous", "trainer": "simple", "z_dim": '"${pop_size}"'}' \ 8 | --env_config '{"mode": "hard"}' --seed $seed 9 | done 10 | done 11 | 12 | for seed in 111 222 333 13 | do 14 | xvfb-run -a python main.py --config_file config/algs/sp_mi/rendezvous.yaml \ 15 | --env_config_file config/envs/rendezvous.yaml \ 16 | --config '{"discrim_coef": 5, "n_sp_episodes": 400, "n_workers": 16, "save_folder": "results_sweep_rendezvous", "trainer": "simple", "z_dim": 4}' \ 17 | --env_config '{"mode": "hard"}' --seed $seed 18 | done 19 | 20 | for seed in 111 222 333 21 | do 22 | xvfb-run -a python main.py --config_file config/algs/sp_mi/rendezvous.yaml \ 23 | --env_config_file config/envs/rendezvous.yaml \ 24 | --config '{"discrim_coef": 1, "n_sp_episodes": 400, "n_workers": 16, "save_folder": "results_sweep_rendezvous", "trainer": "simple", "z_dim": 8}' \ 25 | --env_config '{"mode": "hard"}' --seed $seed 26 | done 27 | -------------------------------------------------------------------------------- /scripts/cmg-h/maven.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/maven/one_step_matrix.yaml \ 4 | --env_config_file config/envs/one_step_matrix.yaml \ 5 | --config '{"discrim_coef": 5, "n_sp_episodes": 6400, "n_workers": 16, "save_folder": "results_sweep_one_step_matrix_uneven_m32", "trainer": "simple", "z_dim": 8}' --env_config '{}' --seed $seed 6 | done 7 | 8 | for seed in 111 222 333 9 | do 10 | xvfb-run -a python main.py --config_file config/algs/maven/one_step_matrix.yaml \ 11 | --env_config_file config/envs/one_step_matrix.yaml \ 12 | --config '{"discrim_coef": 10, "n_sp_episodes": 6400, "n_workers": 16, "save_folder": "results_sweep_one_step_matrix_uneven_m32", "trainer": "simple", "z_dim": 16}' --env_config '{}' --seed $seed 13 | done 14 | 15 | for pop_size in 32 64 16 | do 17 | for seed in 111 222 333 18 | do 19 | xvfb-run -a python main.py --config_file config/algs/maven/one_step_matrix.yaml \ 20 | --env_config_file config/envs/one_step_matrix.yaml \ 21 | --config '{"discrim_coef": 50, "n_sp_episodes": 6400, "n_workers": 16, "save_folder": "results_sweep_one_step_matrix_uneven_m32", "trainer": "simple", "z_dim": '"${pop_size}"'}' --env_config '{}' --seed $seed 
22 | done 23 | done 24 | -------------------------------------------------------------------------------- /scripts/cmg-s/maven.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/maven/one_step_matrix.yaml \ 4 | --env_config_file config/envs/one_step_matrix.yaml \ 5 | --config '{"discrim_coef": 1, "n_sp_episodes": 6400, "n_workers": 16, "save_folder": "results_sweep_one_step_matrix_k_8", "trainer": "simple", "z_dim": 8}' \ 6 | --env_config '{"k": 8}' --seed $seed 7 | done 8 | 9 | for pop_size in 16 32 10 | do 11 | for seed in 111 222 333 12 | do 13 | xvfb-run -a python main.py --config_file config/algs/maven/one_step_matrix.yaml \ 14 | --env_config_file config/envs/one_step_matrix.yaml \ 15 | --config '{"discrim_coef": 5, "n_sp_episodes": 6400, "n_workers": 16, "save_folder": "results_sweep_one_step_matrix_k_8", "trainer": "simple", "z_dim": '"${pop_size}"'}' \ 16 | --env_config '{"k": 8}' --seed $seed 17 | done 18 | done 19 | 20 | for seed in 111 222 333 21 | do 22 | xvfb-run -a python main.py --config_file config/algs/maven/one_step_matrix.yaml \ 23 | --env_config_file config/envs/one_step_matrix.yaml \ 24 | --config '{"discrim_coef": 10, "n_sp_episodes": 6400, "n_workers": 16, "save_folder": "results_sweep_one_step_matrix_k_8", "trainer": "simple", "z_dim": 64}' \ 25 | --env_config '{"k": 8}' --seed $seed 26 | done 27 | -------------------------------------------------------------------------------- /scripts/pmr-l/maven.sh: -------------------------------------------------------------------------------- 1 | for pop_size in 1 8 2 | do 3 | for seed in 111 222 333 4 | do 5 | xvfb-run -a python main.py --config_file config/algs/maven/rendezvous.yaml \ 6 | --env_config_file config/envs/rendezvous.yaml \ 7 | --config '{"discrim_coef": 1, "n_iter": 4000, "n_sp_episodes": 30, "n_workers": 16, "save_folder": "results_sweep_rendezvous", "trainer": "simple", "z_dim": '"${pop_size}"'}' \ 8 | --env_config '{"mode": "hard"}' --seed $seed 9 | done 10 | done 11 | 12 | for seed in 111 222 333 13 | do 14 | xvfb-run -a python main.py --config_file config/algs/maven/rendezvous.yaml \ 15 | --env_config_file config/envs/rendezvous.yaml \ 16 | --config '{"discrim_coef": 0.005, "n_iter": 4000, "n_sp_episodes": 30, "n_workers": 16, "save_folder": "results_sweep_rendezvous", "trainer": "simple", "z_dim": 2}' \ 17 | --env_config '{"mode": "hard"}' --seed $seed 18 | done 19 | 20 | for seed in 111 222 333 21 | do 22 | xvfb-run -a python main.py --config_file config/algs/maven/rendezvous.yaml \ 23 | --env_config_file config/envs/rendezvous.yaml \ 24 | --config '{"discrim_coef": 5, "n_iter": 4000, "n_sp_episodes": 30, "n_workers": 16, "save_folder": "results_sweep_rendezvous", "trainer": "simple", "z_dim": 4}' \ 25 | --env_config '{"mode": "hard"}' --seed $seed 26 | done 27 | -------------------------------------------------------------------------------- /scripts/pmr-l/lipo.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/incompat/rendezvous.yaml \ 4 | --env_config_file config/envs/rendezvous.yaml \ 5 | --config '{"pop_size": 1, "save_folder": "results_sweep_rendezvous", "xp_coef": 0.5}' --env_config '{"mode": "hard"}' --seed $seed 6 | done 7 | 8 | for seed in 111 222 333 9 | do 10 | xvfb-run -a python main.py --config_file config/algs/incompat/rendezvous.yaml \ 11 | 
--env_config_file config/envs/rendezvous.yaml \ 12 | --config '{"pop_size": 2, "save_folder": "results_sweep_rendezvous", "xp_coef": 0.25}' --env_config '{"mode": "hard"}' --seed $seed 13 | done 14 | 15 | for seed in 111 222 333 16 | do 17 | xvfb-run -a python main.py --config_file config/algs/incompat/rendezvous.yaml \ 18 | --env_config_file config/envs/rendezvous.yaml \ 19 | --config '{"pop_size": 4, "save_folder": "results_sweep_rendezvous", "xp_coef": 0.25}' --env_config '{"mode": "hard"}' --seed $seed 20 | done 21 | 22 | for seed in 111 222 333 23 | do 24 | xvfb-run -a python main.py --config_file config/algs/incompat/rendezvous.yaml \ 25 | --env_config_file config/envs/rendezvous.yaml \ 26 | --config '{"pop_size": 8, "save_folder": "results_sweep_rendezvous", "xp_coef": 0.25}' --env_config '{"mode": "hard"}' --seed $seed 27 | done 28 | -------------------------------------------------------------------------------- /scripts/pmr-l/multi_sp_mi.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/sp_mi/rendezvous.yaml \ 4 | --env_config_file config/envs/rendezvous.yaml \ 5 | --config '{"algo_name": "multi_sp_mi", "discrim_coef": 1, "n_sp_episodes": 400, "n_workers": 16, "pop_size": 1, "save_folder": "results_sweep_rendezvous", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' --env_config '{"mode": "hard"}' --seed $seed 6 | done 7 | 8 | for seed in 111 222 333 9 | do 10 | xvfb-run -a python main.py --config_file config/algs/sp_mi/rendezvous.yaml \ 11 | --env_config_file config/envs/rendezvous.yaml \ 12 | --config '{"algo_name": "multi_sp_mi", "discrim_coef": 5, "n_sp_episodes": 400, "n_workers": 16, "pop_size": 2, "save_folder": "results_sweep_rendezvous", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' --env_config '{"mode": "hard"}' --seed $seed 13 | done 14 | 15 | for seed in 111 222 333 16 | do 17 | xvfb-run -a python main.py --config_file config/algs/sp_mi/rendezvous.yaml \ 18 | --env_config_file config/envs/rendezvous.yaml \ 19 | --config '{"algo_name": "multi_sp_mi", "discrim_coef": 1, "n_sp_episodes": 400, "n_workers": 16, "pop_size": 4, "save_folder": "results_sweep_rendezvous", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' --env_config '{"mode": "hard"}' --seed $seed 20 | done 21 | -------------------------------------------------------------------------------- /config/algs/multi_sp/default.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | # k: v 4 | # config: 5 | # k: v 6 | def_config: !include config/algs/default.yaml 7 | 8 | render: True 9 | render_mode: 'rgb_array' 10 | num_cpus: 1 11 | use_gpu: False 12 | debug: False 13 | 14 | algo_name: multi_sp 15 | trainer: incompat 16 | runner: EpisodesRunner 17 | agent_name: IncompatMAPPOZ 18 | 19 | use_gpu: False 20 | debug: False 21 | flatten_traj: False 22 | parent_only: False 23 | eval_all_pairs: False 24 | z_dim: 4 25 | z_discrete: True 26 | pop_size: 1 27 | 28 | pg_xp_max_only: True 29 | value_xp_max_only: False 30 | shared_z: False 31 | use_bandit: False 32 | anneal_xp: False 33 | 34 | discrim_coef: 0.0 35 | xp_coef: 0.0 36 | use_hypernet: False 37 | num_xp_pair_sample: 0 38 | 39 | # these two reduce the trajedi trainer to just be a multi-run trainer 40 | use_br: False 41 | diverse_coef: 0.0 42 | 43 | n_iter: 100 44 | pop_size: 2 45 | kernel_gamma: 0.0 46 | flatten_traj: True 47 | 48 | eval_interval: 
50 49 | n_sp_episodes: 50 50 | n_xp_episodes: 50 51 | n_sp_ts: 5000 52 | n_xp_ts: 5000 53 | n_eval_ep: 10 54 | hidden_size: 64 55 | num_xp_pair_sample: 0 56 | 57 | gamma: 0.99 58 | lr: 0.0001 59 | vf_coef: 0.5 60 | ent_coef: 0.03 61 | epochs: 5 62 | num_mb: 3 63 | mb_size: 0 64 | gae_lambda: 0.95 65 | clip_param: 0.3 66 | vf_clip_param: 10.0 67 | env_wrappers: [ZWrapper, AgentIDWrapper, StateWrapper] 68 | -------------------------------------------------------------------------------- /scripts/pmr-l/multi_maven.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/maven/rendezvous.yaml \ 4 | --env_config_file config/envs/rendezvous.yaml \ 5 | --config '{"algo_name": "multi_maven", "discrim_coef": 10, "n_iter": 4000, "n_sp_episodes": 30, "n_workers": 16, "pop_size": 1, "save_folder": "results_sweep_rendezvous", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 6 | --env_config '{"mode": "hard"}' --seed $seed 7 | done 8 | 9 | for seed in 111 222 333 10 | do 11 | xvfb-run -a python main.py --config_file config/algs/maven/rendezvous.yaml \ 12 | --env_config_file config/envs/rendezvous.yaml \ 13 | --config '{"algo_name": "multi_maven", "discrim_coef": 10, "n_iter": 4000, "n_sp_episodes": 30, "n_workers": 16, "pop_size": 2, "save_folder": "results_sweep_rendezvous", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 14 | --env_config '{"mode": "hard"}' --seed $seed 15 | done 16 | 17 | for seed in 111 222 333 18 | do 19 | xvfb-run -a python main.py --config_file config/algs/maven/rendezvous.yaml \ 20 | --env_config_file config/envs/rendezvous.yaml \ 21 | --config '{"algo_name": "multi_maven", "discrim_coef": 5, "n_iter": 4000, "n_sp_episodes": 30, "n_workers": 16, "pop_size": 4, "save_folder": "results_sweep_rendezvous", "trainer": "incompat", "vary_z_eval": 1, "z_dim": 8}' \ 22 | --env_config '{"mode": "hard"}' --seed $seed 23 | done 24 | -------------------------------------------------------------------------------- /config/algs/incompat/default.yaml: -------------------------------------------------------------------------------- 1 | # Format goes like this 2 | # def_config: 3 | # k: v 4 | # config: 5 | # k: v 6 | def_config: !include config/algs/default.yaml 7 | 8 | algo_name: incompat # used for creating a save directory 9 | trainer: incompat 10 | runner: EpisodesRunner 11 | agent_name: IncompatMAPPOZ 12 | render: True 13 | render_only_sp: True 14 | render_mode: 'rgb_array' 15 | eval_interval: 50 16 | n_eval_ep: 10 17 | num_cpus: 1 18 | use_gpu: False 19 | debug: False 20 | flatten_traj: False 21 | training_device: cuda 22 | 23 | parent_only: False 24 | eval_all_pairs: True 25 | n_iter: 400 26 | num_xp_pair_sample: 1000 # sample all pairs by default 27 | eval_interval: 50 28 | n_sp_episodes: 50 29 | n_xp_episodes: 50 30 | n_sp_ts: 5000 31 | n_xp_ts: 5000 32 | n_eval_ep: 10 33 | z_dim: 4 34 | z_discrete: True 35 | pop_size: 2 36 | flatten_traj: False 37 | 38 | lr: 0.0001 39 | use_hypernet: False 40 | pg_xp_max_only: True 41 | value_xp_max_only: False 42 | xp_coef: 1.0 43 | discrim_coef: 0.1 44 | gamma: 0.99 45 | gae_lambda: 0.95 46 | ent_coef: 0.03 47 | clip_param: 0.3 48 | vf_clip_param: 10 49 | vf_coef: 1.0 # 0.5 50 | num_mb: 5 51 | mb_size: 0 52 | epochs: 3 53 | env_wrappers: [ZWrapper, AgentIDWrapper, StateWrapper] 54 | shared_z: False 55 | use_bandit: False 56 | bandit_eps: 0.1 57 | bandit_window_size: 3 58 | uniform_selector_keep_last: False 59 | 60 | 
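# XP-coefficient annealing. When anneal_xp is True, the trainer ramps xp_coef
# from xp_coef_start to xp_coef_stop over the first n_anneal_iter iterations;
# assuming a linear schedule (the exact shape is defined by the trainer), roughly:
#   xp_coef(t) = xp_coef_start + (xp_coef_stop - xp_coef_start) * min(t / n_anneal_iter, 1)
# The sweep scripts in scripts/ leave annealing off and set xp_coef directly.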
anneal_xp: False 61 | xp_coef_stop: 0.1 62 | xp_coef_start: 0.0 63 | n_anneal_iter: 1000 64 | -------------------------------------------------------------------------------- /scripts/pmr-c/trajedi.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/trajedi/rendezvous.yaml \ 4 | --env_config_file config/envs/rendezvous.yaml \ 5 | --config '{"diverse_coef": 10, "kernel_gamma": 0.1, "pop_size": 1, "save_folder": "results_sweep_rendezvous"}' \ 6 | --env_config '{"mode": "easy"}' --seed $seed 7 | done 8 | 9 | for seed in 111 222 333 10 | do 11 | xvfb-run -a python main.py --config_file config/algs/trajedi/rendezvous.yaml \ 12 | --env_config_file config/envs/rendezvous.yaml \ 13 | --config '{"diverse_coef": 50, "kernel_gamma": 0.1, "pop_size": 2, "save_folder": "results_sweep_rendezvous"}' \ 14 | --env_config '{"mode": "easy"}' --seed $seed 15 | done 16 | 17 | for seed in 111 222 333 18 | do 19 | xvfb-run -a python main.py --config_file config/algs/trajedi/rendezvous.yaml \ 20 | --env_config_file config/envs/rendezvous.yaml \ 21 | --config '{"diverse_coef": 1, "kernel_gamma": 0, "pop_size": 4, "save_folder": "results_sweep_rendezvous"}' \ 22 | --env_config '{"mode": "easy"}' --seed $seed 23 | done 24 | 25 | for seed in 111 222 333 26 | do 27 | xvfb-run -a python main.py --config_file config/algs/trajedi/rendezvous.yaml \ 28 | --env_config_file config/envs/rendezvous.yaml \ 29 | --config '{"diverse_coef": 10, "kernel_gamma": 0.5, "pop_size": 8, "save_folder": "results_sweep_rendezvous"}' \ 30 | --env_config '{"mode": "easy"}' --seed $seed 31 | done 32 | -------------------------------------------------------------------------------- /scripts/pmr-l/trajedi.sh: -------------------------------------------------------------------------------- 1 | for seed in 111 222 333 2 | do 3 | xvfb-run -a python main.py --config_file config/algs/trajedi/rendezvous.yaml \ 4 | --env_config_file config/envs/rendezvous.yaml \ 5 | --config '{"diverse_coef": 50, "kernel_gamma": 0.1, "pop_size": 1, "save_folder": "results_sweep_rendezvous"}' \ 6 | --env_config '{"mode": "hard"}' --seed $seed 7 | done 8 | 9 | for seed in 111 222 333 10 | do 11 | xvfb-run -a python main.py --config_file config/algs/trajedi/rendezvous.yaml \ 12 | --env_config_file config/envs/rendezvous.yaml \ 13 | --config '{"diverse_coef": 50, "kernel_gamma": 0, "pop_size": 2, "save_folder": "results_sweep_rendezvous"}' \ 14 | --env_config '{"mode": "hard"}' --seed $seed 15 | done 16 | 17 | for seed in 111 222 333 18 | do 19 | xvfb-run -a python main.py --config_file config/algs/trajedi/rendezvous.yaml \ 20 | --env_config_file config/envs/rendezvous.yaml \ 21 | --config '{"diverse_coef": 5, "kernel_gamma": 0.1, "pop_size": 4, "save_folder": "results_sweep_rendezvous"}' \ 22 | --env_config '{"mode": "hard"}' --seed $seed 23 | done 24 | 25 | for seed in 111 222 333 26 | do 27 | xvfb-run -a python main.py --config_file config/algs/trajedi/rendezvous.yaml \ 28 | --env_config_file config/envs/rendezvous.yaml \ 29 | --config '{"diverse_coef": 5, "kernel_gamma": 0.1, "pop_size": 8, "save_folder": "results_sweep_rendezvous"}' \ 30 | --env_config '{"mode": "hard"}' --seed $seed 31 | done 32 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/contextlib.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 
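# This module lets one generator be used as both a sync and an async context
# manager. A sketch of the intended call pattern (`session` is a hypothetical name):
#
#   @maybeasynccontextmanager
#   def session():
#       # ...setup...
#       yield
#       # ...teardown...
#
#   with session(): ...        # sync callers
#   async with session(): ...  # async callers run the same body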
from functools import wraps 3 | 4 | class MaybeAsyncGeneratorContextManager: 5 | 6 | def __init__(self, func, args, kwargs): 7 | self._func = func 8 | self._args = args 9 | self._kwargs = kwargs 10 | self._sync = None 11 | self._async = None 12 | 13 | def __enter__(self): 14 | if self._sync is None: 15 | syncfunc = contextmanager(self._func) 16 | self._sync = syncfunc(*self._args, **self._kwargs) 17 | return self._sync.__enter__() 18 | 19 | def __exit__(self, t, v, tb): 20 | return self._sync.__exit__(t, v, tb) 21 | 22 | def __aenter__(self): 23 | if self._async is None: 24 | # Hide this 3.8 import; most users will never hit it 25 | from contextlib import asynccontextmanager 26 | 27 | @asynccontextmanager 28 | async def asyncfunc(*args, **kwargs): 29 | with contextmanager(self._func)(*args, **kwargs): 30 | yield 31 | self._async = asyncfunc(*self._args, **self._kwargs) 32 | return self._async.__aenter__() 33 | 34 | def __aexit__(self, t, v, tb): 35 | return self._async.__aexit__(t, v, tb) 36 | 37 | def maybeasynccontextmanager(func): 38 | @wraps(func) 39 | def helper(*args, **kwds): 40 | return MaybeAsyncGeneratorContextManager(func, args, kwds) 41 | return helper -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/utils/new_style_level/full_divider_salad.json: -------------------------------------------------------------------------------- 1 | { 2 | "LEVEL_LAYOUT": "-------\n- - -\n- - -\n- - -\n- - -\n- - -\n-------", 3 | "STATIC_OBJECTS": [{"CutBoard": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [3]}}, 4 | {"DeliverSquare": {"COUNT": 1, "X_POSITION": [6], "Y_POSITION": [3]}} 5 | ], 6 | "DYNAMIC_OBJECTS": [ 7 | {"Plate": {"COUNT": 1, "X_POSITION": [3], "Y_POSITION": [1,2,3,4,5]}}, 8 | {"Lettuce": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [1,2,3,4,5]}}, 9 | {"Lettuce": {"COUNT": 1, "X_POSITION": [6], "Y_POSITION": [1,2,3,4,5]}}, 10 | {"Tomato": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [1,2,3,4,5]}}, 11 | {"Tomato": {"COUNT": 1, "X_POSITION": [1,2], "Y_POSITION": [0]}}, 12 | {"Onion": {"COUNT": 1, "X_POSITION": [4,5], "Y_POSITION": [0]}}, 13 | {"Onion": {"COUNT": 1, "X_POSITION": [6], "Y_POSITION": [1,2,3,4,5]}}, 14 | {"Carrot": {"COUNT": 1, "X_POSITION": [0], "Y_POSITION": [1,2,3,4,5]}}, 15 | {"Carrot": {"COUNT": 1, "X_POSITION": [6], "Y_POSITION": [1,2,3,4,5]}} 16 | ], 17 | "AGENTS": [{"MAX_COUNT": 1, "X_POSITION": [1 ,2], "Y_POSITION": [1 ,2, 3, 4, 5]}, 18 | {"MAX_COUNT": 1, "X_POSITION": [4 ,5], "Y_POSITION": [1 ,2, 3, 4, 5]} 19 | ] 20 | } -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/widgets.py: -------------------------------------------------------------------------------- 1 | import ipywidgets as widgets 2 | from IPython.display import display, clear_output 3 | import threading 4 | 5 | WRITE_LOCK = threading.RLock() 6 | 7 | class Output: 8 | 9 | def __init__(self, compositor, output, lines): 10 | self._compositor = compositor 11 | self._output = output 12 | self.lines = lines 13 | 14 | def refresh(self, content): 15 | # This is not thread-safe, but the recommended way to do 16 | # thread-safeness - to use append_stdout - causes flickering 17 | with WRITE_LOCK, self._output: 18 | clear_output(wait=True) 19 | print(content) 20 | 21 | def close(self): 22 | self._compositor.remove(self._output) 23 | 24 | class Compositor: 25 | 26 | def __init__(self, lines=80): 27 | self.lines = lines 28 | self._box = widgets.HBox( 29 | 
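# children are attached later by output() and removed by remove();
# align_items='stretch' keeps the side-by-side panes the same height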
layout=widgets.Layout(align_items='stretch')) 30 | display(self._box) 31 | 32 | def output(self): 33 | output = widgets.Output( 34 | layout=widgets.Layout(width='100%')) 35 | self._box.children = (*self._box.children, output) 36 | 37 | return Output(self, output, self.lines) 38 | 39 | def remove(self, child): 40 | child.close() 41 | self._box.children = tuple(c for c in self._box.children if c != child) 42 | 43 | def clear(self): 44 | for child in self._box.children: 45 | self.remove(child) 46 | 47 | 48 | def test(): 49 | compositor = Compositor() 50 | first = compositor.output() 51 | second = compositor.output() 52 | 53 | first.refresh('left') 54 | second.refresh('right') -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/demo_multiplayer_gameplay.py: -------------------------------------------------------------------------------- 1 | from gym_cooking.environment.game.game import Game 2 | 3 | from gym_cooking.environment import cooking_zoo 4 | 5 | n_agents = 2 6 | num_humans = 1 7 | max_steps = 100 8 | render = False 9 | 10 | level = 'full_divider_salad_4' # 'open_room_salad_easy' 11 | seed = 3 12 | record = False 13 | max_num_timesteps = 1000 14 | recipes = [ 15 | "LettuceSalad", 16 | "TomatoSalad", 17 | "ChoppedCarrot", 18 | "ChoppedOnion", 19 | "TomatoLettuceSalad", 20 | "TomatoCarrotSalad" 21 | ] 22 | 23 | parallel_env = cooking_zoo.parallel_env(level=level, num_agents=n_agents, record=record, 24 | max_steps=max_num_timesteps, recipes=recipes, obs_spaces=["dense"], 25 | interact_reward=0.5, progress_reward=1.0, complete_reward=10.0, 26 | step_cost=0.05) 27 | 28 | action_spaces = parallel_env.action_spaces 29 | 30 | 31 | class CookingAgent: 32 | 33 | def __init__(self, action_space): 34 | self.action_space = action_space 35 | 36 | def get_action(self, observation) -> int: 37 | return self.action_space.sample() 38 | 39 | player_2_action_space = action_spaces["player_1"] 40 | cooking_agent = CookingAgent(player_2_action_space) 41 | game = Game(parallel_env, num_humans, [cooking_agent], max_steps, render=False) 42 | store = game.on_execute() 43 | 44 | # game = Game(parallel_env, num_humans, [], max_steps, render=False) 45 | # store = game.on_execute() 46 | 47 | # game = Game(parallel_env, 0, [cooking_agent,cooking_agent], max_steps) 48 | # store = game.on_execute_ai_only_with_delay() 49 | 50 | print("done") 51 | -------------------------------------------------------------------------------- /coop_marl/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from coop_marl.utils import Dotdict, chop_into_episodes # , get_logger 4 | 5 | def get_avg_metrics(metrics): 6 | # assume a list of dotdicts of structure {player_name: dotdict} 7 | # each corresponds to one agent 8 | temp = [defaultdict(list) for _ in range(len(metrics))] # list of dicts of lists 9 | for i, m in enumerate(metrics): 10 | for p in m: 11 | for k,v in m[p].items(): 12 | if isinstance(v, list): 13 | temp[i][k].extend(v) 14 | else: 15 | temp[i][k].append(v) 16 | 17 | out = defaultdict(list) 18 | for t in temp: 19 | for k in t: 20 | out[k].append(sum(t[k])/len(t[k])) # mean 21 | 22 | return Dotdict(out) 23 | 24 | def get_info(episodes): 25 | out = Dotdict() 26 | ret = defaultdict(int) 27 | n_ep = defaultdict(int) 28 | n_ts = defaultdict(int) 29 | players = list(episodes[0].inp.data.keys()) 30 | 31 | for ep in episodes: 32 | for p in players: 33 | dones = 
getattr(ep.outcome.done, p) 34 | if dones[-1]: 35 | rews = ep.outcome.reward[p] 36 | if 'reward_unnorm' in ep.outcome[p]: 37 | rews = ep.outcome[p].reward_unnorm 38 | ret[p] += sum(rews) 39 | n_ts[p] += rews.shape[0] 40 | n_ep[p] += 1 41 | # overcooked log complete dishes 42 | 43 | for p in players: 44 | out[p] = Dotdict() 45 | out[p]['avg_ret'] = ret[p]/n_ep[p] 46 | out[p]['avg_rew_per_ts'] = ret[p]/n_ts[p] 47 | out[p]['avg_ep_len'] = n_ts[p]/n_ep[p] 48 | return out 49 | 50 | def get_traj_info(traj): 51 | episodes = chop_into_episodes(traj) 52 | infos = get_info(episodes) 53 | return infos -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | def wait(): 2 | from coop_marl.utils import input_with_timeout 3 | try: 4 | t = 30 5 | input_with_timeout(f'Press Enter (or wait {t} seconds) to continue...', timeout=t) 6 | except Exception: 7 | print('Input timed out, executing the next command.') 8 | 9 | def main(): 10 | 11 | from coop_marl.utils import pblock, parse_args, create_parser 12 | args, conf, env_conf, trainer = parse_args(create_parser()) 13 | import sys 14 | from coop_marl.utils import get_logger, set_random_seed 15 | logger = get_logger() 16 | logger.info(pblock(' '.join(sys.argv), 'Argv...')) 17 | logger.info(pblock(args, 'CLI arguments...')) 18 | logger.info(pblock(conf, 'Training config...')) 19 | logger.info(pblock(env_conf, 'Environment config...')) 20 | # wait() 21 | set_random_seed(args.seed) 22 | # wandb.init(project=env_name, name=run_name, dir=conf['save_dir'], mode='offline', resume=True) 23 | # import wandb 24 | # wandb.init(project=..., 25 | # name= 26 | # pytorch=True) 27 | 28 | from tqdm import tqdm 29 | from coop_marl.trainers import registered_trainers 30 | trainer = registered_trainers[trainer](conf, env_conf) 31 | start_iter = trainer.iter 32 | save_interval = conf.save_interval if conf.save_interval else conf.eval_interval 33 | for i in tqdm(range(start_iter,conf.n_iter)): 34 | _ = trainer.train() # collect data and update the agents 35 | if ((i+1) % conf.eval_interval==0) or ((i+1)==conf.n_iter) or (i==0): 36 | _ = trainer.evaluate() 37 | if ((i+1) % save_interval==0) or ((i+1)==conf.n_iter) or (i==0): 38 | trainer.save() 39 | try: 40 | import ray 41 | ray.shutdown() 42 | logger.info(f'Ray is shutdown...') 43 | except Exception as e: 44 | logger.error(e) 45 | # wandb.finish() 46 | logger.close() 47 | if conf.render: 48 | import subprocess 49 | subprocess.run([f'python gif_view.py --path {conf["save_dir"]}'], shell=True) 50 | 51 | if __name__=='__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/stats/gpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | from io import BytesIO 4 | from subprocess import check_output 5 | from . 
import writing 6 | import time 7 | 8 | 9 | def memory(device=0): 10 | total_mem = torch.cuda.get_device_properties(f'cuda:{device}').total_memory 11 | writing.max(f'gpu-memory/cache/{device}', torch.cuda.max_memory_cached(device)/total_mem) 12 | torch.cuda.reset_max_memory_cached() 13 | writing.max(f'gpu-memory/alloc/{device}', torch.cuda.max_memory_allocated(device)/total_mem) 14 | torch.cuda.reset_max_memory_allocated() 15 | torch.cuda.reset_max_memory_cached() 16 | 17 | def dataframe(): 18 | """Use `nvidia-smi --help-query-gpu` to get a list of query params""" 19 | params = { 20 | 'device': 'index', 21 | 'compute': 'utilization.gpu', 'access': 'utilization.memory', 22 | 'memused': 'memory.used', 'memtotal': 'memory.total', 23 | 'fan': 'fan.speed', 'power': 'power.draw', 'temp': 'temperature.gpu'} 24 | command = f"""nvidia-smi --format=csv,nounits,noheader --query-gpu={','.join(params.values())}""" 25 | df = pd.read_csv(BytesIO(check_output(command, shell=True)), header=None) 26 | df.columns = list(params.keys()) 27 | df = df.set_index('device') 28 | df = df.apply(pd.to_numeric, errors='coerce') 29 | return df 30 | 31 | _last = -1 32 | def vitals(device=None, throttle=0): 33 | # This is a fairly expensive op, so let's avoid doing it too often 34 | global _last 35 | if time.time() - _last < throttle: 36 | return 37 | _last = time.time() 38 | 39 | df = dataframe() 40 | if device is None: 41 | pass 42 | elif isinstance(device, int): 43 | df = df.loc[[device]] 44 | else: 45 | df = df.loc[device] 46 | 47 | fields = ['compute', 'access', 'fan', 'power', 'temp'] 48 | for (device, field), value in df[fields].stack().iteritems(): 49 | writing.mean(f'gpu/{field}/{device}', value) 50 | 51 | for device in df.index: 52 | writing.mean(f'gpu/memory/{device}', 100*df.loc[device, 'memused']/df.loc[device, 'memtotal']) -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/recurrence.py: -------------------------------------------------------------------------------- 1 | from . 
import arrdict 2 | from torch import nn 3 | from contextlib import contextmanager 4 | 5 | class State: 6 | 7 | def __init__(self): 8 | super().__init__() 9 | 10 | self._value = None 11 | 12 | def get(self, factory=None): 13 | if self._value is None and factory is not None: 14 | self._value = factory() 15 | return self._value 16 | 17 | def set(self, value): 18 | self._value = value 19 | 20 | def clear(self): 21 | self._value = None 22 | 23 | def __repr__(self): 24 | return f'State({self._value})' 25 | 26 | def __str__(self): 27 | return repr(self) 28 | 29 | def states(net): 30 | substates = {k: states(v) for k, v in net.named_children()} 31 | ownstates = {k: getattr(net, k) for k in dir(net) if isinstance(getattr(net, k), State)} 32 | return arrdict.arrdict({k: v for k, v in {**ownstates, **substates}.items() if v}) 33 | 34 | def _nonnull(x): 35 | y = type(x)() 36 | for k, v in x.items(): 37 | if isinstance(v, dict): 38 | subtree = _nonnull(v) 39 | if subtree: 40 | y[k] = subtree 41 | elif v is not None: 42 | y[k] = v 43 | return y 44 | 45 | def get(net): 46 | return _nonnull(states(net).map(lambda s: s.get())) 47 | 48 | def set(net, state): 49 | state.starmap(lambda r, n: n.set(r), states(net)) 50 | 51 | def clear(net): 52 | states(net).map(lambda s: s.clear()) 53 | 54 | @contextmanager 55 | def temp_clear(net): 56 | original = get(net) 57 | clear(net) 58 | try: 59 | yield 60 | finally: 61 | set(net, original) 62 | 63 | @contextmanager 64 | def temp_set(net, state): 65 | original = get(net) 66 | set(net, state) 67 | try: 68 | yield 69 | finally: 70 | set(net, original) 71 | 72 | @contextmanager 73 | def temp_clear_set(net, state): 74 | with temp_clear(net), temp_set(net, state): 75 | yield net 76 | 77 | class Sequential(nn.Sequential): 78 | 79 | def forward(self, input, **kwargs): 80 | for module in self: 81 | input = module(input, **kwargs) 82 | return input -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generating Diverse Cooperative Agents by Learning Incompatible Policies 2 | 3 | [Paper](https://openreview.net/forum?id=UkU05GOH7_6) | [Project page](https://sites.google.com/view/iclr-lipo-2023) 4 | 5 | ### Installation 6 | 7 | If you don't have conda, install conda first. 8 | ``` 9 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 10 | bash Miniconda3-latest-Linux-x86_64.sh 11 | ``` 12 | 13 | Create and activate a new environment. 14 | ``` 15 | conda create -n coop_marl python=3.8 16 | conda activate coop_marl 17 | ``` 18 | 19 | The installer will use CUDA 11.1, so make sure that your current Nvidia driver supports that. 20 | You can install the Nvidia driver with `sudo apt install nvidia-driver-515`. You can change the number to install a different version. 21 | 22 | If torch can't see your GPU, add this to your `~/.bashrc`. 23 | ``` 24 | export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/:$LD_LIBRARY_PATH 25 | ``` 26 | 27 | Install the dependencies. 28 | ``` 29 | ./install.sh 30 | ``` 31 | 32 | ### Training 33 | 34 | Example training commands are in `scripts/`. Each command uses the best hyperparameters found for the corresponding algorithm and environment. If you want to run the scripts, make sure you run them from the root of the project. 
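For example, `scripts/pmr-c/lipo.sh` trains LIPO populations on the easy rendezvous task; one of its runs boils down to this command (`xvfb-run -a` provides a virtual display so rendering works on headless machines):
```
xvfb-run -a python main.py --config_file config/algs/incompat/rendezvous.yaml \
    --env_config_file config/envs/rendezvous.yaml \
    --config '{"pop_size": 2, "save_folder": "results_sweep_rendezvous", "xp_coef": 0.1}' \
    --env_config '{"mode": "easy"}' --seed 111
```
Here `--config` and `--env_config` take JSON snippets that override values from the YAML files passed via `--config_file` and `--env_config_file`.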
35 | 36 | For generalist agents, you can use the following command: 37 | ``` 38 | xvfb-run -a python main.py --config_file config/algs/meta/overcooked.yaml \ 39 | --env_config_file config/envs/overcooked.yaml --config {"partner_dir": ["..."], "render": 0} 40 | ``` 41 | where `partner_dir` is the path to the training partners, e.g., `training_partners_8/overcooked_full_divider_salad_4/trajedi/20220919-233301`. 42 | Generalist agents can only be trained after you have obtained the training partners. 43 | 44 | ### BibTeX 45 | ``` 46 | @inproceedings{charakorn2023generating, 47 | title={Generating Diverse Cooperative Agents by Learning Incompatible Policies}, 48 | author={Rujikorn Charakorn and Poramate Manoonpong and Nat Dilokthanakul}, 49 | booktitle={The Eleventh International Conference on Learning Representations}, 50 | year={2023}, 51 | url={https://openreview.net/forum?id=UkU05GOH7_6} 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /coop_marl/envs/gym_maker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | from copy import copy 4 | 5 | from coop_marl.utils import Arrdict, Dotdict 6 | from coop_marl.envs.wrappers import SARDConsistencyChecker 7 | 8 | class GymMaker: 9 | def __init__(self, env_name): 10 | self._env = gym.make(env_name) 11 | self.players = ['player_0'] 12 | self.action_spaces = Dotdict({self.players[0]:copy(self._env.action_space)}) 13 | self.observation_spaces = Dotdict({self.players[0]:Dotdict(obs=copy(self._env.observation_space))}) 14 | self.total_steps = 0 15 | 16 | def get_action_space(self): 17 | return self._env.action_space 18 | 19 | def get_observation_space(self): 20 | return Dotdict(obs=self._env.observation_space) 21 | 22 | def reset(self): 23 | obs = self._env.reset() 24 | data = Arrdict() 25 | data[self.players[0]] = Arrdict(obs=obs.astype(np.float32), reward=np.float32(0), done=False) 26 | return data 27 | 28 | def step(self,decision): 29 | self.total_steps += 1 30 | action = decision[self.players[0]]['action'] 31 | obs, reward, done, info = self._env.step(action) 32 | data = Arrdict() 33 | data[self.players[0]] = Arrdict(obs=obs.astype(np.float32), reward=np.float32(reward), done=done) # gym rewards are plain floats, not arrays 34 | return data, Dotdict(info) 35 | 36 | def render(self, mode): 37 | return self._env.render(mode) 38 | 39 | @staticmethod 40 | def make_env(*args,**kwargs): 41 | env = GymMaker(*args,**kwargs) 42 | env = SARDConsistencyChecker(env) 43 | return env 44 | 45 | if __name__ == '__main__': 46 | from coop_marl.controllers import RandomController 47 | from coop_marl.runners import StepsRunner 48 | import argparse 49 | 50 | parser = argparse.ArgumentParser(description='GymMaker smoke test') 51 | # Common arguments 52 | parser.add_argument('--env_name', type=str, default='CartPole-v1', 53 | help='name of the env') 54 | args = parser.parse_args() 55 | 56 | env = GymMaker(args.env_name) 57 | action_spaces = env.action_spaces 58 | controller = RandomController(action_spaces) 59 | runner = StepsRunner(env, controller) 60 | for i in range(20): 61 | traj, *_ = runner.rollout(1) 62 | print(traj.data.player_0) # the only player GymMaker defines 63 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/stats/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from contextlib import contextmanager 4 | from functools import partial 5 | from torch import nn 6 | import
logging 7 | import pandas as pd 8 | from .. import paths 9 | 10 | log = logging.getLogger(__name__) 11 | 12 | # For re-export 13 | from .writing import * 14 | from .writing import to_dir, mean 15 | from .reading import from_dir, Reader 16 | from . import gpu 17 | 18 | @contextmanager 19 | def via_dir(run_name, *args, **kwargs): 20 | with to_dir(run_name), from_dir(run_name, *args, **kwargs): 21 | yield 22 | 23 | def normhook(name, t): 24 | 25 | def hook(grad): 26 | mean(name, grad.pow(2).sum().pow(.5)) 27 | 28 | t.register_hook(hook) 29 | 30 | def total_gradient_norm(params): 31 | if isinstance(params, nn.Module): 32 | return total_gradient_norm(params.parameters()) 33 | norms = [p.grad.data.float().pow(2).sum() for p in params if p.grad is not None] 34 | return torch.sum(torch.tensor(norms)).pow(.5) 35 | 36 | def total_norm(params): 37 | if isinstance(params, nn.Module): 38 | return total_norm(params.parameters()) 39 | return sum([p.data.float().pow(2).sum() for p in params if p is not None]).pow(.5) 40 | 41 | def rel_gradient_norm(name, agent): 42 | mean(name, total_gradient_norm(agent), total_norm(agent)) 43 | 44 | def funcduty(name): 45 | def factory(f): 46 | def g(self, *args, **kwargs): 47 | start = time.time() 48 | result = f(self, *args, **kwargs) 49 | record('duty', f'duty/{name}', time.time() - start) 50 | return result 51 | return g 52 | return factory 53 | 54 | def compare(run_names=[-1], prefix='', rule='60s'): 55 | return pd.concat({paths.resolve(run): Reader(run, prefix).resample(rule) for run in run_names}, 1) 56 | 57 | ## TESTS 58 | 59 | def test_from_dir(): 60 | from .. import paths, widgets, logging 61 | paths.clear('test-run', 'stats') 62 | paths.clear('test-run', 'logs') 63 | 64 | compositor = widgets.Compositor() 65 | with logging.from_dir('test-run', compositor), \ 66 | to_dir('test-run'), \ 67 | from_dir('test-run', compositor): 68 | for i in range(10): 69 | mean('count', i) 70 | mean('twocount', 2*i) 71 | time.sleep(.25) 72 | -------------------------------------------------------------------------------- /gif_view.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | 3 | def get_it(path): 4 | return int(path.split('/')[-3][3:]) 5 | 6 | def get_pair(path): 7 | home,away = path.split('.')[0].split('/')[-1].split('-') 8 | return (int(home),int(away)) 9 | 10 | def sort_by_it(gifs): 11 | dec = [(-get_it(g),get_pair(g),g) for g in gifs] 12 | dec.sort() 13 | out = [g for it,pair,g in dec] 14 | return out 15 | 16 | def generate(path=None, folder_name=None, env_name=None, algo_name=None, run_name=None): 17 | if path is None: 18 | # get the latest run 19 | # the path is env_name/algo_name/path 20 | path = sorted(glob(f'{folder_name}/{env_name}/{algo_name}/{run_name}'), key=os.path.getmtime)[-1] 21 | 22 | with open(f'{path}/gif_view.html', 'w') as f: 23 | cur_it = -1 24 | cur_player = -1 25 | f.write(f'
<html><body><h2>Run name: {path}</h2><hr>') 26 | # gifs = glob(f'{path}/renders/*.gif') 27 | gifs = glob(f'{path}/*/renders/*.gif') 28 | # for gif in sorted(gifs, key=lambda x:-int(x.split('/')[-3][3:])): 29 | for gif in sort_by_it(gifs): 30 | it = int(gif.split('/')[-3][3:]) 31 | if it != cur_it: 32 | cur_it = it 33 | f.write(f'<h3>Iteration: {it}</h3><hr>') 34 | player = int(gif.split('/')[-1].split('-')[0]) 35 | if cur_player != player: 36 | cur_player = player 37 | f.write('<br>
') 38 | img_path = "/".join(gif.split("/")[-3:]) 39 | f.write(f'<img src="{img_path}">') 40 | return os.path.abspath(path) 41 | 42 | if __name__=='__main__': 43 | import argparse 44 | import webbrowser 45 | import os 46 | 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--path', type=str, default=None) 49 | parser.add_argument('--folder_name', type=str, default='results') 50 | parser.add_argument('--env_name', type=str, default='*') 51 | parser.add_argument('--algo_name', type=str, default='*') 52 | parser.add_argument('--run_name', type=str, default='*') 53 | args = parser.parse_args() 54 | 55 | # assert args.path ^ (args.env_name | args.algo_name | args.run_name) 56 | # path = args.path 57 | # if path is None: 58 | # path = f'{args.env_name}/{args.algo_name}/{args.run_name}' 59 | path = generate(args.path, args.folder_name, args.env_name, args.algo_name, args.run_name) 60 | print(path) 61 | # cwd = os.path.dirname(os.path.realpath(__file__)) 62 | webbrowser.open_new_tab(f'{path}/gif_view.html') -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/cooking_book/recipe.py: -------------------------------------------------------------------------------- 1 | from gym_cooking.cooking_world.world_objects import * 2 | from gym_cooking.cooking_world.cooking_world import CookingWorld 3 | 4 | import numpy as np 5 | 6 | 7 | class NodeTypes(Enum): 8 | CHECKPOINT = "Checkpoint" 9 | ACTION = "Action" 10 | 11 | 12 | class RecipeNode: 13 | 14 | def __init__(self, root_type, id_num, name, parent=None, conditions=None, contains=None, 15 | node_type=NodeTypes.CHECKPOINT): 16 | self.parent = parent 17 | self.achieved = False 18 | self.id_num = id_num 19 | self.root_type = root_type 20 | self.conditions = conditions or [] 21 | self.contains = contains or [] 22 | self.world_objects = [] 23 | self.name = name 24 | self.node_type = node_type 25 | 26 | def is_leaf(self): 27 | return not bool(self.contains) 28 | 29 | 30 | class Recipe: 31 | 32 | def __init__(self, root_node: RecipeNode, name=None): 33 | self.root_node = root_node 34 | self.node_list = [root_node] + self.expand_child_nodes(root_node) 35 | self.name = name 36 | 37 | def goals_completed(self, num_goals): 38 | goals = np.zeros(num_goals, dtype=np.int32) 39 | for node in self.node_list: 40 | goals[node.id_num] = int(not node.achieved) 41 | return goals 42 | 43 | def completed(self): 44 | return self.root_node.achieved 45 | 46 | def update_recipe_state(self, world: CookingWorld): 47 | for node in reversed(self.node_list): 48 | node.achieved = False 49 | node.world_objects = [] 50 | if not all((contains.achieved for contains in node.contains)): 51 | continue 52 | for obj in world.world_objects[node.name]: 53 | # check for all conditions 54 | if self.check_conditions(node, obj): 55 | node.world_objects.append(obj) 56 | node.achieved = True 57 | 58 | def expand_child_nodes(self, node: RecipeNode): 59 | child_nodes = [] 60 | for child in node.contains: 61 | child_nodes.extend(self.expand_child_nodes(child)) 62 | return node.contains + child_nodes 63 | 64 | @staticmethod 65 | def check_conditions(node: RecipeNode, world_object): 66 | for condition in node.conditions: 67 | if getattr(world_object, condition[0]) != condition[1]: 68 | return False 69 | else: # for-else: runs only if no condition failed 70 | all_contained = [] 71 | for contains in node.contains: 72 | all_contained.append(any([obj.location == world_object.location for obj in contains.world_objects])) 73 | return all(all_contained) 74 | 
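For reference, a minimal sketch of how `RecipeNode` and `Recipe` compose into a goal tree. This is an illustrative, hand-written example (the `id_num`s are assigned by hand here rather than by the `id_num_generator` used in `recipe_drawer.py`, where the real recipes live):
```
from gym_cooking.cooking_world.world_objects import *  # Tomato, Plate, DeliverSquare, ChopFoodStates
from gym_cooking.cooking_book.recipe import Recipe, RecipeNode

# Leaf node: a tomato that must reach the CHOPPED state.
chopped_tomato = RecipeNode(root_type=Tomato, id_num=0, name="Tomato",
                            conditions=[("chop_state", ChopFoodStates.CHOPPED)])
# Middle node: a plate that must be co-located with the chopped tomato.
tomato_plate = RecipeNode(root_type=Plate, id_num=1, name="Plate", contains=[chopped_tomato])
# Root node: a delivery square holding that plate.
delivered = RecipeNode(root_type=DeliverSquare, id_num=2, name="DeliverSquare", contains=[tomato_plate])

recipe = Recipe(delivered, name="TomatoSalad")
# recipe.node_list == [delivered, tomato_plate, chopped_tomato]
# recipe.completed() becomes True once update_recipe_state() marks the root as achieved.
```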
-------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import shutil 4 | import multiprocessing as mp 5 | import pandas as pd 6 | import re 7 | from . import dotdict 8 | 9 | ROOT = 'output/traces' 10 | 11 | def resolve(run_name): 12 | if isinstance(run_name, str): 13 | return run_name 14 | if isinstance(run_name, int): 15 | times = {p: p.stat().st_ctime for p in Path(ROOT).iterdir()} 16 | paths = sorted(times, key=times.__getitem__) 17 | return paths[run_name].parts[-1] 18 | raise ValueError(f'Can\'t find a run corresponding to {run_name}') 19 | 20 | def run_dir(run_name): 21 | run_name = resolve(run_name) 22 | return Path(ROOT) / run_name 23 | 24 | def subdirectory(run_name, group, channel=''): 25 | if channel: 26 | return run_dir(run_name) / group / channel 27 | else: 28 | return run_dir(run_name) / group 29 | 30 | def clear(run_name, group=None): 31 | if group is None: 32 | shutil.rmtree(run_dir(run_name), ignore_errors=True) 33 | else: 34 | shutil.rmtree(subdirectory(run_name, group), ignore_errors=True) 35 | 36 | def path(run_name, group, channel=''): 37 | # Python's idea of a process name is different from the system's idea. Dunno where 38 | # the difference comes from. 39 | run_name = resolve(run_name) 40 | 41 | proc = mp.current_process() 42 | 43 | for x in [run_name, group]: 44 | for c in ['_', os.sep]: 45 | assert c not in x, f'Can\'t have "{c}" in the file path' 46 | 47 | path = subdirectory(run_name, group, channel) / f'{proc.name}-{proc.pid}' 48 | 49 | path.parent.mkdir(exist_ok=True, parents=True) 50 | return path 51 | 52 | def glob(run_name, group, channel='', pattern='*'): 53 | paths = subdirectory(run_name, group, channel).glob(pattern) 54 | return sorted(paths, key=lambda p: p.stat().st_mtime) 55 | 56 | def parse(path): 57 | parts = path.relative_to(ROOT).with_suffix('').parts 58 | procname, pid = re.match(r'^(.*)-(.*)$', parts[-1]).groups() 59 | return dotdict.dotdict( 60 | run_name=parts[0], 61 | group=parts[1], 62 | channel='/'.join(parts[2:-1]), 63 | filename=parts[-1], 64 | procname=procname, 65 | pid=pid) 66 | 67 | def runs(): 68 | paths = [] 69 | for p in Path(ROOT).iterdir(): 70 | paths.append({ 71 | 'path': p, 72 | 'created': pd.Timestamp(p.stat().st_ctime, unit='s'), 73 | 'run_name': p.parts[-1]}) 74 | return pd.DataFrame(paths).sort_values('created').reset_index(drop=True) 75 | 76 | def size(run_name, group): 77 | run_name = resolve(run_name) 78 | b = sum(item.stat().st_size for item in subdirectory(run_name, group).glob('**/*.*')) 79 | return b/1e6 80 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/stats/categories.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import logging 7 | 8 | log = logging.getLogger(__name__) 9 | 10 | CATEGORIES = {} 11 | def category(M): 12 | CATEGORIES[M.__name__.lower()] = M 13 | return M 14 | 15 | @category 16 | def last(x): 17 | def resample(**kwargs): 18 | return x.resample(**kwargs).last() 19 | return resample 20 | 21 | @category 22 | def max(x): 23 | def resample(**kwargs): 24 | return x.resample(**kwargs).max() 25 | return resample 26 | 27 | @category 28 | def mean(total, count=1): 29 | def resample(**kwargs): 30 | 
return total.resample(**kwargs).mean()/count.resample(**kwargs).mean() 31 | return resample 32 | 33 | @category 34 | def std(x): 35 | def resample(**kwargs): 36 | return x.resample(**kwargs).std() 37 | return resample 38 | 39 | @category 40 | def cumsum(total=1): 41 | def resample(**kwargs): 42 | return total.resample(**kwargs).sum().cumsum() 43 | return resample 44 | 45 | @category 46 | def timeaverage(x): 47 | def resample(**kwargs): 48 | # TODO: To do this properly, I need to get individual per-device streams 49 | y = x.sort_index() 50 | dt = y.index.to_series().diff().dt.total_seconds() 51 | return (y*dt).resample(**kwargs).mean()/dt.resample(**kwargs).mean() 52 | return resample 53 | 54 | @category 55 | def duty(duration): 56 | def resample(**kwargs): 57 | sums = duration.resample(**kwargs).sum() 58 | periods = sums.index.to_series().diff().dt.total_seconds() 59 | return sums/periods 60 | return resample 61 | 62 | @category 63 | def maxrate(duration, count=1): 64 | def resample(**kwargs): 65 | return count.resample(**kwargs).mean()/duration.resample(**kwargs).mean() 66 | return resample 67 | 68 | @category 69 | def rate(count=1): 70 | def resample(**kwargs): 71 | counts = count.resample(**kwargs).sum() 72 | dt = pd.to_timedelta(counts.index.freq).total_seconds() 73 | dt = min(dt, (count.index[-1] - count.index[0]).total_seconds()) 74 | return counts/dt 75 | return resample 76 | 77 | @category 78 | def period(count=1): 79 | def resample(**kwargs): 80 | counts = count.resample(**kwargs).sum() 81 | dt = pd.to_timedelta(counts.index.freq).total_seconds() 82 | dt = min(dt, (count.index[-1] - count.index[0]).total_seconds()) 83 | return dt/counts 84 | return resample 85 | 86 | @category 87 | def dist(samples, size=10000): 88 | return samples 89 | 90 | @category 91 | def noisescale(S, G2, B): 92 | def resample(**kwargs): 93 | return S.resample(**kwargs).mean()/G2.resample(**kwargs).mean() 94 | return resample -------------------------------------------------------------------------------- /coop_marl/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | from torch.utils.tensorboard import SummaryWriter 5 | 6 | logger = None 7 | 8 | def get_logger(log_dir='tmp/', debug=False): 9 | global logger 10 | if logger is None: 11 | logger = Logger(log_dir, debug) 12 | return logger 13 | 14 | def create_logger(log_dir='', debug=False): 15 | '''Create a global logger that logs INFO level messages to stdout and DEBUG ones to debug.log''' 16 | logger = logging.getLogger() 17 | log_formatter = logging.Formatter(fmt='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 18 | stream_level = logging.DEBUG if debug else logging.INFO 19 | stream_handler = logging.StreamHandler() 20 | stream_handler.setFormatter(logging.Formatter(fmt=None)) 21 | stream_handler.setLevel(stream_level) 22 | logger.addHandler(stream_handler) 23 | 24 | loc = 'debug.log' 25 | if log_dir: 26 | os.makedirs(log_dir, exist_ok=True) 27 | loc = f'{log_dir}/debug.log' 28 | debug_handler = logging.FileHandler(loc, delay=True) 29 | debug_handler.setLevel(logging.DEBUG) 30 | debug_handler.setFormatter(log_formatter) 31 | logger.addHandler(debug_handler) 32 | logger.setLevel(logging.DEBUG) 33 | return logger 34 | 35 | def pblock(msg, msg_header=''): 36 | out = [] 37 | out.append('='*60) 38 | out.append(msg_header) 39 | out.append(str(msg)) 40 | out.append('='*60) 41 | return '\n'.join(out) 42 | 43 | class Logger: 44 | """Combine logger and 
SummaryWriter into a single object""" 45 | def __init__(self, log_dir, debug): 46 | self.writer = SummaryWriter(log_dir) 47 | self.logger = create_logger(log_dir, debug) 48 | 49 | def __getattr__(self, attr): 50 | if attr in self.__dict__: 51 | return self.__dict__[attr] 52 | elif getattr(self.logger, attr, None): 53 | return getattr(self.logger, attr) 54 | else: 55 | return getattr(self.writer, attr) 56 | 57 | def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None): 58 | for k,v in tag_scalar_dict.items(): 59 | self.writer.add_scalar(f'{main_tag}/{k}', v, global_step, walltime) 60 | 61 | def close(self): 62 | logging.shutdown() 63 | self.writer.close() 64 | 65 | # taken from https://stackoverflow.com/questions/287871/how-to-print-colored-text-to-the-terminal 66 | # https://svn.blender.org/svnroot/bf-blender/trunk/blender/build_files/scons/tools/bcolors.py 67 | class bcolors: 68 | HEADER = '\033[95m' 69 | OKBLUE = '\033[94m' 70 | OKGREEN = '\033[92m' 71 | WARNING = '\033[93m' 72 | FAIL = '\033[91m' 73 | ENDC = '\033[0m' 74 | 75 | def disable(self): 76 | self.HEADER = '' 77 | self.OKBLUE = '' 78 | self.OKGREEN = '' 79 | self.WARNING = '' 80 | self.FAIL = '' 81 | self.ENDC = '' -------------------------------------------------------------------------------- /coop_marl/envs/mpe/_mpe_utils/simple_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv as BaseSimpleEnv 3 | from pettingzoo.mpe._mpe_utils import rendering 4 | 5 | class SimpleEnv(BaseSimpleEnv): 6 | def render(self, mode='human'): 7 | if self.viewer is None: 8 | self.viewer = rendering.Viewer(200, 200) 9 | # fix the camera position with initial positions 10 | all_poses = [entity.state.p_pos for entity in self.world.entities] 11 | cam_range = np.max(np.abs(np.array(all_poses))) + 1 12 | self.viewer.set_max_size(cam_range) 13 | 14 | # create rendering geometry 15 | if self.render_geoms is None: 16 | self.render_geoms = [] 17 | self.render_geoms_xform = [] 18 | for entity in self.world.entities: 19 | geom = rendering.make_circle(entity.size) 20 | xform = rendering.Transform() 21 | if 'agent' in entity.name: 22 | geom.set_color(*entity.color[:3], alpha=0.5) 23 | else: 24 | geom.set_color(*entity.color[:3]) 25 | geom.add_attr(xform) 26 | self.render_geoms.append(geom) 27 | self.render_geoms_xform.append(xform) 28 | 29 | # add geoms to viewer 30 | self.viewer.geoms = [] 31 | for geom in self.render_geoms: 32 | self.viewer.add_geom(geom) 33 | 34 | self.viewer.text_lines = [] 35 | idx = 0 36 | for agent in self.world.agents: 37 | if not agent.silent: 38 | tline = rendering.TextLine(self.viewer.window, idx) 39 | self.viewer.text_lines.append(tline) 40 | idx += 1 41 | 42 | alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 43 | for idx, other in enumerate(self.world.agents): 44 | if other.silent: 45 | continue 46 | if np.all(other.state.c == 0): 47 | word = '_' 48 | elif self.continuous_actions: 49 | word = '[' + ",".join([f"{comm:.2f}" for comm in other.state.c]) + "]" 50 | else: 51 | word = alphabet[np.argmax(other.state.c)] 52 | 53 | message = (other.name + ' sends ' + word + ' ') 54 | 55 | self.viewer.text_lines[idx].set_text(message) 56 | 57 | # update bounds to center around agent 58 | # all_poses = [entity.state.p_pos for entity in self.world.entities] 59 | # cam_range = np.max(np.abs(np.array(all_poses))) + 1 60 | # self.viewer.set_max_size(cam_range) 61 | # update geometry positions 62 | for e, 
entity in enumerate(self.world.entities): 63 | self.render_geoms_xform[e].set_translation(*entity.state.p_pos) 64 | # render to display or array 65 | return self.viewer.render(return_rgb_array=mode == 'rgb_array') -------------------------------------------------------------------------------- /coop_marl/envs/one_step_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Box, Discrete 3 | 4 | from coop_marl.utils import Arrdict, Dotdict 5 | from coop_marl.envs.wrappers import SARDConsistencyChecker 6 | 7 | class OneStepMatrixGame: 8 | def __init__(self,n_conventions,payoffs,k,*args,**kwargs): 9 | self.players = ['player_0','player_1'] 10 | if isinstance(k, list): 11 | assert len(k)==n_conventions 12 | if isinstance(k, int): 13 | k = [k] * n_conventions 14 | 15 | self.n_actions = sum(k) 16 | print(f'N actions: {self.n_actions}') 17 | self.payoff_matrix = np.zeros([self.n_actions, self.n_actions], dtype=np.float32) 18 | if isinstance(payoffs, str): 19 | payoffs = eval(payoffs) 20 | for m in range(n_conventions): 21 | start = sum(k[:m]) 22 | end = sum(k[:m+1]) 23 | self.payoff_matrix[start:end, start:end] = payoffs[m] 24 | 25 | self.action_spaces = Dotdict({player: self.get_action_space() for player in self.players}) 26 | self.observation_spaces = Dotdict({player: self.get_observation_space() for player in self.players}) 27 | 28 | def get_action_space(self): 29 | return Discrete(self.n_actions) 30 | 31 | def get_observation_space(self): 32 | return Dotdict(obs=Box(low=0, high=1, shape=[1], dtype=np.float32)) 33 | 34 | def set_players(self, players): 35 | self.players = players 36 | 37 | def reset(self): 38 | data = Arrdict() 39 | for p in self.players: 40 | data[p] = Arrdict(obs=np.array([1.0], dtype=np.float32), 41 | reward=np.float32(0), 42 | done=False) 43 | return data 44 | 45 | def step(self, decision): 46 | done = True 47 | joint_actions = list(decision.action.values()) 48 | a1, a2 = joint_actions 49 | reward = self.payoff_matrix[a1,a2] 50 | 51 | data = Arrdict() 52 | for p in self.players: 53 | data[p] = Arrdict(obs=np.array([0.0], dtype=np.float32), 54 | reward=np.float32(reward), 55 | done=done) 56 | return data, {} 57 | 58 | def render(self, *args, **kwargs): 59 | return 60 | 61 | @staticmethod 62 | def make_env(*args,**kwargs): 63 | env = OneStepMatrixGame(*args,**kwargs) 64 | env = SARDConsistencyChecker(env) 65 | return env 66 | 67 | if __name__ == '__main__': 68 | env = OneStepMatrixGame.make_env(n_conventions=3,k=[3,2,1],payoffs='np.linspace(1,0.5,n_conventions)') 69 | data = env.reset() 70 | action_spaces = env.action_spaces 71 | 72 | while True: 73 | decision = Arrdict() 74 | for p in action_spaces: 75 | a = action_spaces[p].sample() 76 | decision[p] = Arrdict(action=a) 77 | for p in decision: 78 | print(f'{p} obs: {data.obs[p]}') 79 | print(f'{p} action: {decision[p].action}') 80 | data, *_= env.step(decision) 81 | for p in data: 82 | print(f'{p} reward: {data.reward[p]}') 83 | print() 84 | if data.player_0.done: 85 | for p in data: 86 | print(f'{p} obs: {data.obs[p]}') 87 | break 88 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/cooking_world/abstract_classes.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABC 2 | from gym_cooking.cooking_world.constants import * 3 | 4 | 5 | class Object(ABC): 6 | 7 | def __init__(self, location, movable, 
walkable): 8 | self.location = location 9 | self.movable = movable # you can pick this one up 10 | self.walkable = walkable # you can walk on it 11 | 12 | def name(self) -> str: 13 | return type(self).__name__ 14 | 15 | def move_to(self, new_location): 16 | self.location = new_location 17 | 18 | @abstractmethod 19 | def file_name(self) -> str: 20 | pass 21 | 22 | 23 | class ActionObject(ABC): 24 | 25 | @abstractmethod 26 | def action(self, objects): 27 | pass 28 | 29 | 30 | class ProgressingObject(ABC): 31 | 32 | @abstractmethod 33 | def progress(self, dynamic_objects): 34 | pass 35 | 36 | 37 | class StaticObject(Object): 38 | 39 | def __init__(self, location, walkable): 40 | super().__init__(location, False, walkable) 41 | 42 | def move_to(self, new_location): 43 | raise Exception(f"Can't move static object {self.name()}") 44 | 45 | @abstractmethod 46 | def accepts(self, dynamic_objects) -> bool: 47 | pass 48 | 49 | 50 | class DynamicObject(Object, ABC): 51 | 52 | def __init__(self, location): 53 | super().__init__(location, True, False) 54 | 55 | 56 | class Container(DynamicObject, ABC): 57 | 58 | def __init__(self, location, content=None): 59 | super().__init__(location) 60 | self.content = content or [] 61 | 62 | def move_to(self, new_location): 63 | for content in self.content: 64 | content.move_to(new_location) 65 | self.location = new_location 66 | 67 | def add_content(self, content): 68 | self.content.append(content) 69 | 70 | 71 | class Food: 72 | 73 | @abstractmethod 74 | def done(self): 75 | pass 76 | 77 | 78 | class ChopFood(DynamicObject, Food, ABC): 79 | 80 | def __init__(self, location): 81 | super().__init__(location) 82 | self.chop_state = ChopFoodStates.FRESH 83 | 84 | def chop(self): 85 | if self.done(): 86 | return False 87 | self.chop_state = ChopFoodStates.CHOPPED 88 | return True 89 | 90 | 91 | class BlenderFood(DynamicObject, Food, ABC): 92 | 93 | def __init__(self, location): 94 | super().__init__(location) 95 | self.current_progress = 10 96 | self.max_progress = 0 97 | self.min_progress = 10 98 | self.blend_state = BlenderFoodStates.FRESH 99 | 100 | def blend(self): 101 | if self.done(): 102 | return False 103 | if self.blend_state == BlenderFoodStates.FRESH or self.blend_state == BlenderFoodStates.IN_PROGRESS: 104 | self.current_progress -= 1 105 | self.blend_state = BlenderFoodStates.IN_PROGRESS if self.current_progress > self.max_progress \ 106 | else BlenderFoodStates.MASHED 107 | return True 108 | 109 | 110 | ABSTRACT_GAME_CLASSES = (ActionObject, ProgressingObject, Container, Food, ChopFood, DynamicObject, StaticObject, 111 | BlenderFood) 112 | 113 | STATEFUL_GAME_CLASSES = (ChopFood, BlenderFood) 114 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/parallel.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | from contextlib import contextmanager 3 | import multiprocessing 4 | import types 5 | from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, Future, _base, as_completed 6 | import logging 7 | 8 | log = logging.getLogger(__name__) 9 | 10 | class SerialExecutor(_base.Executor): 11 | """An executor that runs things on the main process/thread - meaning stack traces are interpretable 12 | and the debugger works! 
13 | """ 14 | 15 | def __init__(self, *args, **kwargs): 16 | pass 17 | 18 | def submit(self, f, *args, **kwargs): 19 | future = Future() 20 | future.set_result(f(*args, **kwargs)) 21 | return future 22 | 23 | @contextmanager 24 | def VariableExecutor(N=None, processes=True, **kwargs): 25 | """An executor that can be easily switched between serial, thread and parallel execution. 26 | If N=0, a serial executor will be used. 27 | """ 28 | 29 | N = multiprocessing.cpu_count() if N is None else N 30 | 31 | if N == 0: 32 | executor = SerialExecutor 33 | elif processes: 34 | executor = ProcessPoolExecutor 35 | else: 36 | executor = ThreadPoolExecutor 37 | 38 | log.debug('Launching a {} with {} processes'.format(executor.__name__, N)) 39 | with executor(N, **kwargs) as pool: 40 | yield pool 41 | 42 | @contextmanager 43 | def parallel(f, progress=True, **kwargs): 44 | """Sugar for using the VariableExecutor. Call as 45 | 46 | with parallel(f) as g: 47 | ys = g.wait({x: g(x) for x in xs}) 48 | and f'll be called in parallel on each x, and the results collected in a dictionary. 49 | A fantastic additional feature is that if you pass `parallel(f, N=0)`, everything will be run on 50 | the host process, so you can `import pdb; pdb.pm()` any errors. 51 | """ 52 | 53 | with VariableExecutor(**kwargs) as pool: 54 | 55 | def reraise(f, futures={}): 56 | e = f.exception() 57 | if e: 58 | log.warning('Exception raised on "{}"'.format(futures[f]), exc_info=e) 59 | raise e 60 | return f.result() 61 | 62 | submitted = set() 63 | 64 | def submit(*args, **kwargs): 65 | fut = pool.submit(f, *args, **kwargs) 66 | submitted.add(fut) 67 | fut.add_done_callback(submitted.discard) # Try to avoid memory leak 68 | return fut 69 | 70 | def wait(c): 71 | # Recurse on list-likes 72 | if type(c) in (list, tuple, types.GeneratorType): 73 | ctor = list if isinstance(c, types.GeneratorType) else type(c) 74 | results = wait(dict(enumerate(c))) 75 | return ctor(results[k] for k in sorted(results)) 76 | 77 | # Now can be sure we've got a dict-like 78 | futures = {fut: k for k, fut in c.items()} 79 | 80 | results = {} 81 | for fut in tqdm(as_completed(futures), total=len(c), disable=not progress): 82 | results[futures[fut]] = reraise(fut, futures) 83 | 84 | return results 85 | 86 | def cancel(): 87 | while True: 88 | remaining = list(submitted) 89 | for fut in remaining: 90 | fut.cancel() 91 | submitted.discard(fut) 92 | if not remaining: 93 | break 94 | 95 | try: 96 | submit.wait = wait 97 | yield submit 98 | finally: 99 | cancel() -------------------------------------------------------------------------------- /coop_marl/utils/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import argparse 4 | import random 5 | from datetime import datetime 6 | 7 | import torch 8 | from yamlinclude import YamlIncludeConstructor 9 | 10 | from coop_marl.utils import Dotdict, update_existing_keys, get_logger, pblock 11 | 12 | YamlIncludeConstructor.add_to_loader_class(loader_class=yaml.FullLoader) 13 | 14 | DEF_CONFIG = 'def_config' 15 | 16 | def save_yaml(conf, path): 17 | os.makedirs('/'.join(path.split('/')[:-1]), exist_ok=True) 18 | with open(f'{path}.yaml', 'w') as f: 19 | yaml.dump(conf, f, default_flow_style=False) 20 | 21 | def create_parser(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--config_file', type=argparse.FileType(mode='r'), required=True) 24 | parser.add_argument('--env_config_file', type=argparse.FileType(mode='r'), required=True) 25
| parser.add_argument('--config', default={}, type=yaml.load) 26 | parser.add_argument('--env_config', default={}, type=yaml.load) 27 | parser.add_argument('--seed', default=-1, type=int) 28 | parser.add_argument('--run_name', default='', type=str) 29 | return parser 30 | 31 | def get_def_conf(data, init_call=False): 32 | if DEF_CONFIG not in data: 33 | if init_call: 34 | return {} 35 | return data 36 | cur_level = {k:v for k,v in data.items() if k != DEF_CONFIG} 37 | next_level = get_def_conf(data[DEF_CONFIG]) 38 | next_level.update(cur_level) 39 | return next_level 40 | 41 | def parse_nested_yaml(yaml): 42 | def_conf = get_def_conf(yaml, True) 43 | conf = {k:v for k,v in yaml.items() if k != DEF_CONFIG} 44 | if def_conf is not None: 45 | def_conf.update(conf) 46 | conf = def_conf 47 | return conf 48 | 49 | def parse_args(parser): 50 | args = parser.parse_args() 51 | data = yaml.load(args.config_file, Loader=yaml.FullLoader) 52 | conf = parse_nested_yaml(data) 53 | 54 | env_conf = yaml.load(args.env_config_file, Loader=yaml.FullLoader) 55 | 56 | # replace the config params with hparams from console args 57 | unused_param = [None, None] 58 | conf_names = ['config', 'env config'] 59 | for i, (cli_conf, yaml_conf, text) in enumerate(zip([args.config, args.env_config], [conf, env_conf], conf_names)): 60 | yaml_conf, unused_param[i] = update_existing_keys(yaml_conf, cli_conf) 61 | 62 | run_name = datetime.now().strftime("%Y%m%d-%H%M%S") 63 | if len(args.run_name)>0: 64 | run_name = args.run_name 65 | conf['run_name'] = run_name 66 | 67 | if not getattr(conf, 'save_dir', ''): 68 | env_folder = f'{env_conf["name"]}_{env_conf["mode"]}' if 'mode' in env_conf else f'{env_conf["name"]}' 69 | save_folder = conf['save_folder'] 70 | conf['save_dir'] = f'{save_folder}/{env_folder}/{conf["algo_name"]}/{run_name}' 71 | logger = get_logger(log_dir=conf['save_dir'], debug=conf['debug']) 72 | [logger.info(pblock(unused_param[i], f'Unused {conf_names[i]} parameters')) for i in range(2)] 73 | if args.seed==-1: 74 | args.seed = random.randint(1,int(2**31-1)) 75 | 76 | if conf['use_gpu']: 77 | conf['device'] = 'cuda' 78 | 79 | if conf['training_device'] == 'cuda': 80 | if not torch.cuda.is_available(): 81 | logger.info('CUDA is not available, using CPU for training instead.') 82 | conf['training_device'] = 'cpu' 83 | 84 | delattr(args, 'env_config_file') 85 | delattr(args, 'config_file') 86 | [save_yaml(c, f'{conf["save_dir"]}/{name}') for c, name in zip([conf, env_conf], ['conf','env_conf'])] 87 | return [Dotdict(x) for x in [vars(args), conf, env_conf]] + [conf['trainer']] 88 | 89 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/overcooked_maker.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from gym_cooking.environment import cooking_zoo 6 | from gym_cooking.environment.game.graphic_pipeline import GraphicPipeline 7 | from coop_marl.utils import Arrdict, Dotdict, arrdict 8 | from coop_marl.envs.wrappers import SARDConsistencyChecker 9 | 10 | class OvercookedMaker: 11 | def __init__(self, *, mode, horizon, recipes, obs_spaces, num_agents=2, 12 | interact_reward=0.5, progress_reward=1.0, complete_reward=10.0, 13 | step_cost=0.1, **kwargs): 14 | if not isinstance(obs_spaces, list): 15 | obs_spaces = [obs_spaces] 16 | self._env = cooking_zoo.parallel_env(level=mode, num_agents=num_agents, record=False, 17 | max_steps=horizon, 
recipes=recipes, obs_spaces=obs_spaces, 18 | interact_reward=interact_reward, progress_reward=progress_reward, 19 | complete_reward=complete_reward, step_cost=step_cost) 20 | 21 | self.players = self._env.possible_agents 22 | self.action_spaces = Dotdict(self._env.action_spaces) 23 | self.observation_spaces = Dotdict((k,Dotdict(obs=v)) for k,v in self._env.observation_spaces.items()) 24 | self.graphic_pipeline = GraphicPipeline(self._env, display=False) # do not create a display window 25 | self.graphic_pipeline.on_init() 26 | 27 | def get_action_space(self): 28 | return gym.spaces.Discrete(6) 29 | 30 | def get_observation_space(self): 31 | # agent observation size 32 | if isinstance(self._env.unwrapped.obs_size, int): 33 | return Dotdict(obs=gym.spaces.Box(-1,1,shape=self._env.unwrapped.obs_size)) 34 | else: 35 | return Dotdict(obs=gym.spaces.Box(0,10,shape=self._env.unwrapped.obs_size)) 36 | 37 | def reset(self): 38 | obs = self._env.reset() 39 | data = Arrdict() 40 | for p,k in zip(self.players,obs): 41 | data[p] = Arrdict(obs=obs[k], 42 | reward=np.float32(0), 43 | done=False 44 | ) 45 | return data 46 | 47 | def step(self,decision): 48 | actions = {} 49 | for a,p in zip(self._env.agents,decision.action): 50 | actions[a] = decision.action[p] 51 | 52 | obs, reward, done, info = self._env.step(actions) 53 | data= Arrdict() 54 | for k in obs.keys(): 55 | data[k] = Arrdict(obs=obs[k], 56 | reward=np.float32(reward[k]), 57 | done=done[k] 58 | ) 59 | 60 | return data, Dotdict(info) 61 | 62 | def render(self, mode): 63 | return self.graphic_pipeline.on_render(mode) 64 | 65 | @staticmethod 66 | def make_env(*args,**kwargs): 67 | env = OvercookedMaker(*args,**kwargs) 68 | env = SARDConsistencyChecker(env) 69 | return env 70 | 71 | if __name__ == '__main__': 72 | from coop_marl.controllers import RandomController 73 | from coop_marl.runners import StepsRunner 74 | import argparse 75 | 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument('--level', type=str, default='simple') 78 | args = parser.parse_args() 79 | 80 | level = 'full_divider_salad_4' 81 | horizon = 200 82 | recipes = [ 83 | "LettuceSalad", 84 | "TomatoSalad", 85 | "ChoppedCarrot", 86 | "ChoppedOnion", 87 | "TomatoLettuceSalad", 88 | "TomatoCarrotSalad" 89 | ] 90 | 91 | env = OvercookedMaker.make_env(obs_spaces='dense', mode=level, horizon=horizon, recipes=recipes) 92 | action_spaces = env.action_spaces 93 | controller = RandomController(action_spaces) 94 | runner = StepsRunner(env, controller) 95 | buffer = [] 96 | for i in range(2): 97 | traj, infos, frames = runner.rollout(100, render=True) 98 | print(traj.data) 99 | buffer.append(traj) 100 | batch = arrdict.cat(buffer) 101 | plt.imshow(frames[0]) 102 | plt.show() 103 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/stats/writing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import inspect 4 | from ..contextlib import maybeasynccontextmanager 5 | from .. import numpy 6 | from . 
import categories 7 | from functools import partial 8 | 9 | __all__ = ['to_dir', 'defer', 'record'] 10 | 11 | WRITER = None 12 | 13 | @maybeasynccontextmanager 14 | def to_dir(run_name): 15 | try: 16 | global WRITER 17 | old = WRITER 18 | WRITER = numpy.Writer(run_name, 'stats') 19 | yield 20 | finally: 21 | WRITER = old 22 | def clean(x): 23 | if isinstance(x, torch.Tensor): 24 | x = x.detach().cpu().numpy() 25 | if isinstance(x, np.ndarray) and x.ndim == 0: 26 | x = x.item() 27 | if isinstance(x, dict): 28 | return {k: clean(v) for k, v in x.items()} 29 | return x 30 | 31 | def eager_record(category, field, *args, **kwargs): 32 | if WRITER is None: 33 | return 34 | if not isinstance(field, str): 35 | raise ValueError(f'Field should be a string, is actually {field}') 36 | 37 | args = tuple(clean(a) for a in args) 38 | kwargs = {k: clean(v) for k, v in kwargs.items()} 39 | 40 | func = categories.CATEGORIES[category] 41 | call = inspect.getcallargs(func, *args, **kwargs) 42 | call = {'_time': np.datetime64('now'), **call} 43 | 44 | WRITER.write(f'{category}/{field}', call) 45 | 46 | _record = eager_record 47 | QUEUE = None 48 | 49 | def record(*args, **kwargs): 50 | return _record(*args, **kwargs) 51 | def deferred_record(category, field, *args, **kwargs): 52 | if not isinstance(field, str): 53 | raise ValueError(f'Field should be a string, is actually {field}') 54 | QUEUE.append((category, field, args, kwargs)) 55 | 56 | def _mono_getter(collection, x): 57 | dtype = x.dtype 58 | if dtype not in collection: 59 | collection[dtype] = [] 60 | start = sum(c.nelement() for c in collection[dtype]) 61 | end = start + x.nelement() 62 | collection[dtype].append(x.flatten()) 63 | 64 | def f(collection): 65 | return collection[dtype][start:end].reshape(x.shape) 66 | return f 67 | 68 | def _dummy_getter(x): 69 | def f(collection): 70 | return x 71 | return f 72 | 73 | def _multi_getter(collection, *args, **kwargs): 74 | arggetters = [] 75 | for a in args: 76 | if isinstance(a, torch.Tensor) and a.device.type != 'cpu': 77 | arggetters.append(_mono_getter(collection, a)) 78 | else: 79 | arggetters.append(_dummy_getter(a)) 80 | 81 | kwarggetters = {} 82 | for k, v in kwargs.items(): 83 | if isinstance(v, torch.Tensor) and v.device.type != 'cpu': 84 | kwarggetters[k] = _mono_getter(collection, v) 85 | else: 86 | kwarggetters[k] = _dummy_getter(v) 87 | 88 | def f(collection): 89 | args = tuple(g(collection) for g in arggetters) 90 | kwargs = {k: g(collection) for k, g in kwarggetters.items()} 91 | return args, kwargs 92 | return f 93 | 94 | def _gather(queue): 95 | collection = {} 96 | getters = [] 97 | for category, field, args, kwargs in queue: 98 | getters.append((category, field, _multi_getter(collection, *args, **kwargs))) 99 | collection = {k: torch.cat(v).detach().cpu() for k, v in collection.items()} 100 | return collection, getters 101 | 102 | @maybeasynccontextmanager 103 | def defer(): 104 | global _record 105 | global QUEUE 106 | _record = deferred_record 107 | QUEUE = [] 108 | try: 109 | yield 110 | finally: 111 | collection, getters = _gather(QUEUE) 112 | 113 | for (category, field, getter) in getters: 114 | args, kwargs = getter(collection) 115 | args = tuple(clean(a) for a in args) 116 | kwargs = {k: clean(v) for k, v in kwargs.items()} 117 | func = categories.CATEGORIES[category] 118 | call = inspect.getcallargs(func, *args, **kwargs) 119 | call = {'_time': np.datetime64('now'), **call} 120 | 121 | if WRITER is not None: 122 | WRITER.write(f'{category}/{field}', call) 123 | 124 | QUEUE = 
None 125 | _record = eager_record 126 | 127 | for c in categories.CATEGORIES: 128 | locals()[c] = partial(record, c) 129 | __all__.append(c) -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/cooking_book/recipe_drawer.py: -------------------------------------------------------------------------------- 1 | from gym_cooking.cooking_world.world_objects import * 2 | from gym_cooking.cooking_book.recipe import Recipe, RecipeNode 3 | from copy import deepcopy 4 | 5 | 6 | def id_num_generator(): 7 | num = 0 8 | while True: 9 | yield num 10 | num += 1 11 | 12 | 13 | id_generator = id_num_generator() 14 | 15 | # Basic food Items 16 | # root_type, id_num, parent=None, conditions=None, contains=None 17 | ChoppedLettuce = RecipeNode(root_type=Lettuce, id_num=next(id_generator), name="Lettuce", 18 | conditions=[("chop_state", ChopFoodStates.CHOPPED)]) 19 | ChoppedOnion = RecipeNode(root_type=Onion, id_num=next(id_generator), name="Onion", 20 | conditions=[("chop_state", ChopFoodStates.CHOPPED)]) 21 | ChoppedTomato = RecipeNode(root_type=Tomato, id_num=next(id_generator), name="Tomato", 22 | conditions=[("chop_state", ChopFoodStates.CHOPPED)]) 23 | ChoppedCarrot = RecipeNode(root_type=Carrot, id_num=next(id_generator), name="Carrot", 24 | conditions=[("chop_state", ChopFoodStates.CHOPPED)]) 25 | MashedCarrot = RecipeNode(root_type=Carrot, id_num=next(id_generator), name="Carrot", 26 | conditions=[("blend_state", BlenderFoodStates.MASHED)]) 27 | 28 | # Salad Plates 29 | LettuceSaladPlate = RecipeNode(root_type=Plate, id_num=next(id_generator), name="Plate", conditions=None, 30 | contains=[ChoppedLettuce]) 31 | TomatoSaladPlate = RecipeNode(root_type=Plate, id_num=next(id_generator), name="Plate", conditions=None, 32 | contains=[ChoppedTomato]) 33 | TomatoLettucePlate = RecipeNode(root_type=Plate, id_num=next(id_generator), name="Plate", conditions=None, 34 | contains=[ChoppedTomato, ChoppedLettuce]) 35 | TomatoCarrotPlate = RecipeNode(root_type=Plate, id_num=next(id_generator), name="Plate", conditions=None, 36 | contains=[ChoppedTomato, ChoppedCarrot]) 37 | TomatoLettuceOnionPlate = RecipeNode(root_type=Plate, id_num=next(id_generator), name="Plate", conditions=None, 38 | contains=[ChoppedTomato, ChoppedLettuce, ChoppedOnion]) 39 | ChoppedOnionPlate = RecipeNode(root_type=Plate, id_num=next(id_generator), name="Plate", conditions=None, 40 | contains=[ChoppedOnion]) 41 | ChoppedCarrotPlate = RecipeNode(root_type=Plate, id_num=next(id_generator), name="Plate", conditions=None, 42 | contains=[ChoppedCarrot]) 43 | MasedCarrotPlate = RecipeNode(root_type=Plate, id_num=next(id_generator), name="Plate", conditions=None, 44 | contains=[MashedCarrot]) 45 | 46 | # Delivered Salads 47 | LettuceSalad = RecipeNode(root_type=DeliverSquare, id_num=next(id_generator), name="DeliverSquare", conditions=None, 48 | contains=[LettuceSaladPlate]) 49 | TomatoSalad = RecipeNode(root_type=DeliverSquare, id_num=next(id_generator), name="DeliverSquare", conditions=None, 50 | contains=[TomatoSaladPlate]) 51 | TomatoLettuceSalad = RecipeNode(root_type=DeliverSquare, id_num=next(id_generator), name="DeliverSquare", 52 | conditions=None, contains=[TomatoLettucePlate]) 53 | TomatoCarrotSalad = RecipeNode(root_type=DeliverSquare, id_num=next(id_generator), name="DeliverSquare", 54 | conditions=None, contains=[TomatoCarrotPlate]) 55 | ChoppedOnion = RecipeNode(root_type=DeliverSquare, id_num=next(id_generator), name="DeliverSquare", 56 | conditions=None, 
contains=[ChoppedOnionPlate]) 57 | ChoppedCarrot = RecipeNode(root_type=DeliverSquare, id_num=next(id_generator), name="DeliverSquare", 58 | conditions=None, contains=[ChoppedCarrotPlate]) 59 | MashedCarrot = RecipeNode(root_type=DeliverSquare, id_num=next(id_generator), name="DeliverSquare", 60 | conditions=None, contains=[MasedCarrotPlate]) 61 | 62 | # this one increments one further and is thus the amount of ids we have given since 63 | # we started counting at zero. 64 | NUM_GOALS = next(id_generator) 65 | 66 | RECIPES = { 67 | "LettuceSalad":lambda: deepcopy(Recipe(LettuceSalad, name='LettuceSalad')), 68 | "TomatoSalad": lambda: deepcopy(Recipe(TomatoSalad, name='TomatoSalad')), 69 | "TomatoLettuceSalad": lambda: deepcopy(Recipe(TomatoLettuceSalad, name='TomatoLettuceSalad')), 70 | "TomatoCarrotSalad": lambda: deepcopy(Recipe(TomatoCarrotSalad, name='TomatoCarrotSalad')), 71 | # "TomatoLettuceOnionSalad": lambda: deepcopy(Recipe(TomatoLettuceOnionSalad, name='TomatoLettuceOnionSalad')), 72 | "ChoppedCarrot": lambda: deepcopy(Recipe(ChoppedCarrot, name='ChoppedCarrot')), 73 | "ChoppedOnion": lambda: deepcopy(Recipe(ChoppedOnion, name='ChoppedOnion')), 74 | "MashedCarrot": lambda: deepcopy(Recipe(MashedCarrot, name='MashedCarrot'))} 75 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/numpy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.lib import format as npformat 3 | from . import paths 4 | from io import BytesIO 5 | from datetime import datetime 6 | from collections import defaultdict 7 | import time 8 | 9 | def infer_dtype(exemplar): 10 | return np.dtype([(k, v.dtype if isinstance(v, np.generic) else type(v)) for k, v in exemplar.items()]) 11 | 12 | def make_header(dtype): 13 | """ 14 | Ref: https://numpy.org/devdocs/reference/generated/numpy.lib.format.html 15 | We're doing version 3. Only difference is the zero shape, since we're 16 | going to deduce the array size from the filesize. 
17 | """ 18 | assert not dtype.hasobject, 'Arrays with objects in get pickled, so can\'t be appended to' 19 | 20 | bs = BytesIO() 21 | npformat._write_array_header(bs, { 22 | 'descr': dtype.descr, 23 | 'fortran_order': False, 24 | 'shape': (0,)}, 25 | version=(3, 0)) 26 | return bs.getvalue() 27 | 28 | class FileWriter: 29 | 30 | def __init__(self, path, period=1): 31 | self._path = path 32 | self._file = None 33 | self._period = 5 34 | self._next = time.time() 35 | 36 | def _init(self, exemplar): 37 | self._file = self._path.open('wb', buffering=4096) 38 | self._dtype = infer_dtype(exemplar) 39 | self._file.write(make_header(self._dtype)) 40 | self._file.flush() 41 | 42 | def write(self, d): 43 | if self._file is None: 44 | self._init(d) 45 | assert set(d) == set(self._dtype.names) 46 | row = np.array([tuple(v for v in d.values())], self._dtype) 47 | self._file.write(row.tobytes()) 48 | self._file.flush() 49 | 50 | def close(self): 51 | self._file.close() 52 | self._file = None 53 | 54 | class Writer: 55 | 56 | def __init__(self, run_name, group): 57 | self._run_name = run_name 58 | self._group = group 59 | self._writers = {} 60 | 61 | def write(self, channel, d): 62 | if channel not in self._writers: 63 | path = paths.path(self._run_name, self._group, channel).with_suffix('.npr') 64 | self._writers[channel] = FileWriter(path) 65 | self._writers[channel].write(d) 66 | 67 | def write_many(self, ds): 68 | for channel, d in ds.items(): 69 | if channel not in self._writers: 70 | path = paths.path(self._run_name, self._group, channel).with_suffix('.npr') 71 | self._writers[channel] = FileWriter(path) 72 | self._writers[channel].write(d) 73 | 74 | def close(self): 75 | for _, w in self._writers.items(): 76 | w.close() 77 | self._writers = {} 78 | 79 | class FileReader: 80 | 81 | def __init__(self, path): 82 | self._path = path 83 | self._file = None 84 | 85 | def _init(self): 86 | #TODO: Can speed this up with PAG's regex header parser 87 | self._file = self._path.open('rb') 88 | version = npformat.read_magic(self._file) 89 | _, _, dtype = npformat._read_array_header(self._file, version) 90 | self._dtype = dtype 91 | 92 | def read(self): 93 | if self._file is None: 94 | self._init() 95 | return np.fromfile(self._file, dtype=self._dtype) 96 | 97 | def close(self): 98 | self._file.close() 99 | self._file = None 100 | 101 | class Reader: 102 | 103 | def __init__(self, run_name, group): 104 | self._run_name = paths.resolve(run_name) 105 | self._group = group 106 | self._readers = {} 107 | 108 | def read(self): 109 | for path in paths.subdirectory(self._run_name, self._group).glob('**/*.npr'): 110 | parts = paths.parse(path) 111 | if (parts.channel, parts.filename) not in self._readers: 112 | self._readers[parts.channel, parts.filename] = FileReader(path) 113 | 114 | results = defaultdict(lambda: []) 115 | for (channel, _), reader in self._readers.items(): 116 | arr = reader.read() 117 | if len(arr) > 0: 118 | results[channel].append(arr) 119 | 120 | return results 121 | 122 | 123 | def test_file_write_read(): 124 | d = {'total': 65536, 'count': 14, '_time': np.datetime64('now')} 125 | 126 | paths.clear('test', 'stats') 127 | path = paths.path('test', 'stats', 'mean/traj-length').with_suffix('.npr') 128 | 129 | writer = FileWriter(path) 130 | writer.write(d) 131 | 132 | reader = FileReader(path) 133 | r = reader.read() 134 | 135 | assert len(r) == 1 136 | 137 | def test_write_read(): 138 | paths.clear('test', 'stats') 139 | 140 | writer = Writer('test', 'stats') 141 | 
writer.write('mean/traj-length', {'total': 65536, 'count': 14, '_time': np.datetime64('now')}) 142 | writer.write('max/reward', {'total': 50000.5, 'count': 50, '_time': np.datetime64('now')}) 143 | 144 | reader = Reader('test', 'stats') 145 | r = reader.read() 146 | 147 | assert len(r) == 2 -------------------------------------------------------------------------------- /coop_marl/models/modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from coop_marl.utils.nn import ortho_layer_init 7 | 8 | class FCLayers(nn.Module): 9 | def __init__(self, input_dim, hidden_size, num_hidden, output_size, 10 | activation_fn=nn.ELU, base_std=np.sqrt(2), head_std=0.01, 11 | last_linear=True, layer_init_fn=ortho_layer_init): 12 | assert num_hidden > 0 13 | super().__init__() 14 | layer_construct = nn.Linear 15 | 16 | pre_hidden = [layer_init_fn(layer_construct(input_dim, hidden_size), base_std), 17 | activation_fn() 18 | ] 19 | base = pre_hidden 20 | if num_hidden>1: 21 | hiddens = [] 22 | for i in range(num_hidden-1): 23 | hiddens.append(layer_init_fn(layer_construct(hidden_size, hidden_size), base_std)) 24 | hiddens.append(activation_fn()) 25 | base.extend(hiddens) 26 | out = [layer_init_fn(layer_construct(hidden_size,output_size), head_std)] 27 | if not last_linear: 28 | out.append(activation_fn()) 29 | self.layers = nn.Sequential(*base, *out) 30 | 31 | def forward(self, x): 32 | return self.layers(x) 33 | 34 | # taken from: https://github.com/keynans/HypeRL/blob/main/PEARL/torch/sac/hyper_network.py 35 | class HyperHead(nn.Module): 36 | def __init__(self, base_inp_dim, meta_hidden_dim, output_size, stddev=0.05): 37 | super().__init__() 38 | self.output_size = output_size 39 | self.base_inp_dim = base_inp_dim 40 | self.w = nn.Linear(meta_hidden_dim, base_inp_dim * output_size) 41 | self.b = nn.Linear(meta_hidden_dim, output_size) 42 | self.init_layers(stddev) 43 | 44 | def forward(self, x): 45 | w = self.w(x).view(-1, self.output_size, self.base_inp_dim) 46 | b = self.b(x).view(-1, self.output_size, 1) 47 | return w, b 48 | 49 | def init_layers(self,stddev): 50 | ortho_layer_init(self.w, stddev) 51 | ortho_layer_init(self.b, stddev) 52 | 53 | class ResBlock(nn.Module): 54 | 55 | def __init__(self, in_size, out_size): 56 | super(ResBlock, self).__init__() 57 | self.layers = nn.Sequential( 58 | nn.ELU(), 59 | ortho_layer_init(nn.Linear(in_size, out_size)), 60 | nn.ELU(), 61 | ortho_layer_init(nn.Linear(out_size, out_size)), 62 | ) 63 | 64 | def forward(self, x): 65 | h = self.layers(x) 66 | return x + h 67 | 68 | class MetaResNet(nn.Module): 69 | def __init__(self, meta_dim, hidden_size): 70 | super(MetaResNet, self).__init__() 71 | 72 | self.hidden_size = hidden_size 73 | self.layers = nn.Sequential( 74 | nn.Linear(meta_dim, hidden_size), 75 | ResBlock(hidden_size, hidden_size), 76 | nn.ELU() 77 | ) 78 | 79 | self.init_layers() 80 | 81 | def forward(self, meta_v): 82 | return self.layers(meta_v) 83 | 84 | def init_layers(self): 85 | for module in self.layers.modules(): 86 | if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)): 87 | ortho_layer_init(module) 88 | 89 | class HyperNet(nn.Module): 90 | def __init__(self, base_inp_dim, meta_inp_dim, hidden_size, num_heads=1, head_sizes=[256], head_std=[0.05]): 91 | super().__init__() 92 | assert isinstance(head_sizes, list) 93 | assert isinstance(head_std, list) 94 | self.encoder = MetaResNet(meta_inp_dim, 
hidden_size) 95 | self.heads = nn.ModuleList() 96 | self.heads.append(HyperHead(base_inp_dim, hidden_size, head_sizes[0], stddev=head_std[0])) 97 | for i in range(1,num_heads): 98 | self.heads.append(HyperHead(head_sizes[i-1], hidden_size, head_sizes[i], stddev=head_std[i])) 99 | 100 | def produce_wb(self, z): 101 | if len(z.shape)==3: 102 | z = z.view(-1,z.shape[2]) 103 | meta_h = self.encoder(z) 104 | wb = [] 105 | for head in self.heads: 106 | w,b = head(meta_h) 107 | wb.append((w,b)) 108 | return wb 109 | 110 | def forward(self, base_inp, meta_inp): 111 | wb = self.produce_wb(meta_inp) 112 | # base_inp for ff net -> [bs, feature] -> [bs, feature ,1] 113 | # base_inp for rnn -> [bs, ts, feature] -> [bs * ts, feature, 1] 114 | if len(base_inp.shape)==2: 115 | # ff net 116 | h = base_inp.unsqueeze(-1) 117 | elif len(base_inp.shape)==3: 118 | # rnn 119 | h = base_inp.reshape(base_inp.shape[0]*base_inp.shape[1],base_inp.shape[2],1) 120 | 121 | for w,b in wb[:-1]: 122 | h = F.elu(torch.bmm(w, h) + b) 123 | w,b = wb[-1] 124 | out = torch.bmm(w,h) + b 125 | if len(base_inp.shape)==3: 126 | out = out.view(*base_inp.shape[:2],-1) 127 | return out, wb 128 | return out.squeeze(-1), wb 129 | -------------------------------------------------------------------------------- /coop_marl/runners/runners.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import random 3 | 4 | from coop_marl.utils import Arrdict, arrdict 5 | 6 | ''' 7 | All runner assume simultaneous action except RoundRobinTurnBasedEpisodesRunner 8 | ''' 9 | 10 | class Runner(): 11 | def __init__(self,env,controller,*,possible_teams=None): 12 | self.env = env 13 | self.controller = controller 14 | self.possible_teams = possible_teams 15 | self._setup() 16 | 17 | def _setup(self): 18 | self.total_timestep = 0 19 | 20 | def rollout(self, *args, **kwargs): 21 | # each runner should have its own rollout logic 22 | raise NotImplementedError 23 | 24 | class EpisodesRunner(Runner): 25 | def _setup(self): 26 | super()._setup() 27 | # each player might have different keys in the decision 28 | # make sure that the controller takes care of this 29 | self.dummy_decision = self.controller.get_prev_decision_view() 30 | 31 | def rollout(self, num_episodes, render=False, render_mode='rgb_array', resample_players=False): 32 | frames = [] 33 | buffer = [] 34 | infos = [] 35 | for _ in range(num_episodes): 36 | if resample_players: 37 | team = random.choice(self.possible_teams) 38 | self.env.set_players(team) 39 | self.controller.reset() 40 | outcome = self.env.reset() 41 | decision = Arrdict({p:self.dummy_decision[p] for p in outcome}) 42 | while True: 43 | if render: 44 | frames.append(self.env.render(mode=render_mode)) 45 | inp = Arrdict(data=outcome, prev_decision=decision) 46 | decision = self.controller.select_actions(inp) 47 | transition = Arrdict(inp=inp, decision=decision) 48 | # env step 49 | outcome, info = self.env.step(decision) 50 | # use s_t, r_t, a_t, d_t convention 51 | transition['outcome'] = outcome 52 | # add transition to buffer 53 | buffer.append(transition) 54 | infos.append(info) 55 | self.total_timestep += 1 56 | # last time step data will not be collected 57 | # check terminal condition 58 | done_agents = set(k for k,v in outcome.done.items() if v) 59 | if done_agents==set(outcome.keys()): 60 | break 61 | traj = arrdict.stack(buffer) 62 | return traj, infos, frames 63 | 64 | class StepsRunner(Runner): 65 | def _setup(self): 66 | super()._setup() 67 | 
64 | class StepsRunner(Runner): 65 | def _setup(self): 66 | super()._setup() 67 | self.last_outcome = self.env.reset() 68 | # each player might have different keys in the decision 69 | # make sure that the controller takes care of this 70 | self.dummy_decision = self.controller.get_prev_decision_view() 71 | self.last_decision = Arrdict({p:self.dummy_decision[p] for p in self.last_outcome}) # None 72 | 73 | # has to keep env state between calls 74 | def rollout(self, num_timesteps, render=False, render_mode='rgb_array', resample_players=False): 75 | # keep rolling out until num_timesteps frames are collected (through multiple resets if needed) 76 | frames = [] 77 | buffer = [] 78 | infos = [] 79 | 80 | outcome = self.last_outcome 81 | decision = deepcopy(self.last_decision) 82 | # remove keys that won't be used in prev_decision 83 | for p,d in self.last_decision.items(): 84 | for k in d: 85 | if k not in self.dummy_decision[p]: 86 | del decision[p][k] 87 | 88 | for _ in range(num_timesteps): 89 | if render: 90 | frames.append(self.env.render(mode=render_mode)) 91 | 92 | inp = Arrdict(data=outcome, prev_decision=decision) 93 | decision = self.controller.select_actions(inp) 94 | transition = Arrdict(inp=inp, decision=decision) 95 | # env step 96 | outcome, info = self.env.step(decision) 97 | # use s_t, r_t, a_t, d_t convention 98 | transition['outcome'] = outcome 99 | # add transition to buffer 100 | buffer.append(transition) 101 | infos.append(info) 102 | self.total_timestep += 1 103 | # last time step data will not be collected 104 | # check terminal condition 105 | done_agents = set(k for k,v in outcome.done.items() if v) 106 | if done_agents==set(outcome.keys()): 107 | if resample_players: 108 | team = random.choice(self.possible_teams) 109 | self.env.set_players(team) 110 | self.controller.reset() 111 | outcome = self.env.reset() 112 | decision = Arrdict({p:self.dummy_decision[p] for p in outcome}) # None 113 | 114 | traj = arrdict.stack(buffer) 115 | self.last_outcome = outcome 116 | self.last_decision = decision 117 | 118 | return traj, infos, frames 119 | 120 | 121 | -------------------------------------------------------------------------------- /coop_marl/utils/utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from functools import wraps 3 | import inspect 4 | import sys 5 | import select 6 | import os 7 | import random 8 | import yaml 9 | from copy import deepcopy 10 | 11 | 12 | import torch 13 | import numpy as np 14 | from PIL import Image 15 | from pygifsicle import optimize 16 | from yamlinclude import YamlIncludeConstructor 17 | 18 | from coop_marl.utils import Arrdict, arrdict 19 | 20 | '''taken from: https://stackoverflow.com/questions/1389180/automatically-initialize-instance-variables''' 21 | def auto_assign(func): 22 | """ 23 | Automatically assigns the parameters. 24 | 25 | >>> class process: 26 | ... @auto_assign 27 | ... def __init__(self, cmd, reachable=False, user='root'): 28 |
pass 29 | >>> p = process('halt', True) 30 | >>> p.cmd, p.reachable, p.user 31 | ('halt', True, 'root') 32 | """ 33 | names, varargs, keywords, defaults, *_ = inspect.getfullargspec(func)  # getargspec was removed in Python 3.11 34 | 35 | @wraps(func) 36 | def wrapper(self, *args, **kargs): 37 | for name, arg in list(zip(names[1:], args)) + list(kargs.items()): 38 | setattr(self, name, arg) 39 | 40 | if defaults is not None: 41 | for i in range(len(defaults)): 42 | index = -(i + 1) 43 | if not hasattr(self, names[index]): 44 | setattr(self, names[index], defaults[index]) 45 | 46 | func(self, *args, **kargs) 47 | 48 | return wrapper 49 | 50 | def reverse_dict(d): 51 | ''' 52 | Reverse a dict. Returns a reversed dict with its values as lists 53 | ''' 54 | out = defaultdict(list) 55 | for k,v in d.items(): 56 | out[v].append(k) 57 | return out 58 | 59 | def update_existing_keys(target_dict, source_dict): 60 | unused_param = {} 61 | for k,v in source_dict.items(): 62 | if k in target_dict: 63 | if isinstance(target_dict[k],dict) and isinstance(v,dict): 64 | target_dict[k].update(v) 65 | else: 66 | target_dict[k] = v 67 | else: 68 | unused_param[k] = v 69 | return target_dict, unused_param 70 | 71 | def merge_dict(d): 72 | assert len(d)>0 73 | if len(d)==1: 74 | return d[0] 75 | 76 | if len(d[0])==0: 77 | return merge_dict(d[1:]) 78 | 79 | out = deepcopy(d[0]) 80 | d1 = d[1] 81 | for k,v in d1.items(): 82 | if k in out: 83 | if isinstance(v, dict): 84 | out[k] = merge_dict([out[k], v]) 85 | elif isinstance(out[k], list): 86 | if isinstance(v, list): 87 | out[k].extend(v) 88 | else: 89 | out[k].append(v) 90 | elif isinstance(v, list): 91 | out[k] = [out[k]] + v 92 | else: 93 | out[k] = [out[k], v] 94 | else: 95 | out[k] = v 96 | return merge_dict([out] + d[2:]) 97 | 98 | 99 | 100 | class TimeoutExpired(Exception): 101 | pass 102 | 103 | def input_with_timeout(prompt, timeout=10): 104 | sys.stdout.write(prompt) 105 | sys.stdout.flush() 106 | ready, _, _ = select.select([sys.stdin], [],[], timeout) 107 | if ready: 108 | return sys.stdin.readline().rstrip('\n') # expect stdin to be line-buffered 109 | raise TimeoutExpired 110 | 111 | def wait(): 112 | try: 113 | t = 30 114 | input_with_timeout(f'Press Enter (or wait {t} seconds) to continue...', timeout=t) 115 | except TimeoutExpired: 116 | print('Input timed out, executing the next command.') 117 | 118 | def save_gif(imgs, path, fps=30, size=None): 119 | if (imgs is None) or len(imgs)==0: 120 | return 121 | folder = path.split('/')[:-1] 122 | if len(folder) > 0: 123 | os.makedirs('/'.join(path.split('/')[:-1]), exist_ok=True) 124 | if imgs[0].shape[-1]==1: 125 | imgs = np.array(imgs) 126 | imgs = np.tile(imgs,(1,1,1,3)) 127 | imgs = [Image.fromarray(img) for img in imgs] 128 | if size is not None: 129 | imgs = [i.resize(size) for i in imgs] 130 | imgs[0].save(path, save_all=True, append_images=imgs[1:], duration=1000/fps, loop=0) 131 | optimize(path) 132 | 133 | def set_random_seed(seed): 134 | torch.manual_seed(seed) 135 | random.seed(seed) 136 | np.random.seed(seed) 137 | 138 | def create_ph_list(*args, **kwargs): 139 | return 140 | 141 | def load_yaml(path): 142 | YamlIncludeConstructor.add_to_loader_class(loader_class=yaml.FullLoader) 143 | with open(path) as f: 144 | out = yaml.load(f, Loader=yaml.FullLoader) 145 | return out 146 | 147 | def flatten_traj(traj): 148 | # remove player keys 149 | out = Arrdict() 150 | for p in traj.inp.data: 151 | 152 | batch = getattr(traj, p) # remove player keys from traj 153 | out = arrdict.merge_and_cat([out, batch]) 154 | return out 155 |
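# Worked example for merge_dict (illustrative): clashing scalar values are
# accumulated into lists while nested dicts are merged recursively.
#
#   >>> merge_dict([{'a': 1, 'c': {'x': 1}}, {'a': 2, 'b': 3, 'c': {'y': 4}}])
#   {'a': [1, 2], 'c': {'x': 1, 'y': 4}, 'b': 3}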
156 | def safe_log(val, replace_val=-50): 157 | log_val = torch.log(val + 1e-8) 158 | replace_bool = torch.isnan(log_val) | torch.isinf(log_val) 159 | return torch.where(replace_bool, replace_val * torch.ones_like(log_val), log_val) 160 | 161 | if __name__ == '__main__': 162 | a = {'a':1,'b':2,'c':1} 163 | d = reverse_dict(a) 164 | print(d) 165 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/stats/reading.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from .. import numpy, paths, widgets, logging 4 | import re 5 | import numpy as np 6 | from .. import arrdict 7 | from . import categories 8 | import pandas as pd 9 | import threading 10 | from contextlib import contextmanager 11 | import _thread 12 | 13 | log = logging.getLogger(__name__) 14 | 15 | def format(v): 16 | if isinstance(v, int): 17 | return f'{v}' 18 | if isinstance(v, float): 19 | return f'{v:.6g}' 20 | if isinstance(v, list): 21 | return ', '.join(format(vv) for vv in v) 22 | if isinstance(v, dict): 23 | return '{' + ', '.join(f'{k}: {format(vv)}' for k, vv in v.items()) + '}' 24 | return str(v) 25 | 26 | def adaptive_rule(df): 27 | timespan = (df.index[-1] - df.index[0]).total_seconds() 28 | if timespan < 600: 29 | return '15s' 30 | elif timespan < 7200: 31 | return '1min' 32 | else: 33 | return '10min' 34 | 35 | class Reader: 36 | 37 | def __init__(self, run_name, prefix=''): 38 | self._reader = numpy.Reader(run_name, 'stats') 39 | self._prefix = prefix 40 | self._arrs = {} 41 | 42 | def arrays(self): 43 | #TODO: If this gets slow, do amortized allocation of arrays x2 as big as needed 44 | for channel, new in self._reader.read().items(): 45 | category, field = re.match(r'^(.*?)/(.*)$', channel).groups() 46 | if field.startswith(self._prefix): 47 | current = [self._arrs[category, field]] if (category, field) in self._arrs else [] 48 | self._arrs[category, field] = np.concatenate(current + new) 49 | return arrdict.arrdict(self._arrs) 50 | 51 | def pandas(self): 52 | arrs = self.arrays() 53 | 54 | dfs = {} 55 | for (category, field), arr in arrs.items(): 56 | df = pd.DataFrame.from_records(arr, index='_time') 57 | df.index.name = 'time' 58 | dfs[category, field] = df 59 | return arrdict.arrdict(dfs) 60 | 61 | def resample(self, rule='60s', **kwargs): 62 | kwargs = {'rule': rule, **kwargs} 63 | 64 | results = {} 65 | for (category, field), df in self.pandas().items(): 66 | func = getattr(categories, category) 67 | results[field] = func(**{k: df[k] for k in df})(**kwargs) 68 | 69 | if results: 70 | df = pd.concat(results, axis=1) 71 | df.index = df.index - df.index[0] 72 | return df 73 | else: 74 | return pd.DataFrame(index=pd.TimedeltaIndex([], name='time')) 75 | 76 | def arrays(prefix='', run_name=-1): 77 | return Reader(run_name, prefix).arrays() 78 | 79 | def pandas(name, run_name=-1): 80 | dfs = Reader(run_name, name).pandas() 81 | for (_, field), df in dfs.items(): 82 | return df 83 | raise KeyError(f'Couldn\'t find a statistic matching {name}') 84 | 85 | def resample(prefix='', run_name=-1, rule='60s'): 86 | return Reader(run_name, prefix).resample(rule) 87 | 88 | def tdformat(td): 89 | """How is this not in Python, numpy or pandas?""" 90 | x = td.total_seconds() 91 | x, _ = divmod(x, 1) 92 | x, s = divmod(x, 60) 93 | if x < 1: 94 | return f'{s:.0f}s' 95 | h, m = divmod(x, 60) 96 | if h < 1: 97 | return f'{m:.0f}m{s:02.0f}s' 98 | else: 99 | return f'{h:.0f}h{m:02.0f}m{s:02.0f}s'
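# Illustrative sketch of reading stats back with the Reader above; the run
# name 'test' mirrors the reader test shown earlier and is otherwise an
# assumption:
#
#   reader = Reader('test')
#   arrs = reader.arrays()            # {(category, field): structured ndarray}
#   dfs = reader.pandas()             # same keys, DataFrames indexed by time
#   df = reader.resample(rule='60s')  # single frame, one column per field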
100 | def __from_dir(canceller, run_name, out, rule, throttle=1): 101 | reader = Reader(run_name) 102 | start = pd.Timestamp.now() 103 | 104 | nxt = time.time() 105 | while True: 106 | if time.time() > nxt: 107 | nxt = nxt + throttle 108 | 109 | # Base slightly into the future, else by the time the resample actually happens you're 110 | # left with an almost-empty last interval. 111 | base = int(time.time() % 60) + 5 112 | values = reader.resample(rule=rule, base=base) 113 | 114 | if len(values) > 0: 115 | values = values.ffill(limit=1).iloc[-1].to_dict() 116 | key_length = max([len(str(k)) for k in values], default=0)+1 117 | content = '\n'.join(f'{{:{key_length}s}} {{}}'.format(k, format(values[k])) for k in sorted(values)) 118 | else: 119 | content = 'No stats yet' 120 | 121 | size = paths.size(run_name, 'stats') 122 | age = pd.Timestamp.now() - start 123 | out.refresh(f'{run_name}: {tdformat(age)} old, {rule} rule, {size:.0f}MB on disk\n\n{content}') 124 | 125 | if canceller.is_set(): 126 | break 127 | 128 | time.sleep(.1) 129 | 130 | def _from_dir(*args, **kwargs): 131 | try: 132 | __from_dir(*args, **kwargs) 133 | except KeyboardInterrupt: 134 | log.info('Interrupting main') 135 | _thread.interrupt_main() 136 | 137 | @contextmanager 138 | def from_dir(run_name, compositor=None, rule='60s'): 139 | if logging.in_ipython(): 140 | try: 141 | canceller = threading.Event() 142 | out = (compositor or widgets.Compositor()).output() 143 | thread = threading.Thread(target=_from_dir, args=(canceller, run_name, out, rule)) 144 | thread.start() 145 | yield 146 | finally: 147 | canceller.set() 148 | thread.join(1) 149 | if thread.is_alive(): 150 | log.error('Stat display thread won\'t die') 151 | else: 152 | log.info('Stat display thread cancelled') 153 | 154 | # Want to leave the outputs open so you can see the final stats 155 | # out.close() 156 | else: 157 | log.info('No stats emitted in console mode') 158 | yield 159 | -------------------------------------------------------------------------------- /coop_marl/trainers/trainer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from coop_marl.agents import registered_agents 4 | from coop_marl.envs import registered_envs 5 | from coop_marl.envs.wrappers import wrap 6 | from coop_marl.controllers import MappingController 7 | from coop_marl.runners import registered_runners, EpisodesRunner, StepsRunner 8 | from coop_marl.utils import get_logger, Dotdict, get_traj_info, auto_assign 9 | 10 | logger = get_logger() 11 | 12 | def population_based_setup(trainer, conf): 13 | agent_cls = registered_agents[conf.agent_name] 14 | trainer.agent_pop = [agent_cls(conf) for _ in range(conf.pop_size)] 15 | trainer.agents_dict = {p:agent_cls(conf) for p in trainer.env.players} 16 | trainer.controller = MappingController(action_spaces=trainer.env.action_spaces, 17 | agents_dict=trainer.agents_dict, 18 | policy_mapping_fn=lambda name: name, 19 | ) 20 | 21 | 22 | def trainer_setup(trainer, conf, env_conf, eval_env_conf=None): 23 | trainer.iter = 0 24 | reg_env_name = env_conf.name 25 | del env_conf['name'] 26 | env = registered_envs[reg_env_name](**env_conf) 27 | if eval_env_conf is not None: 28 | eval_env = registered_envs[reg_env_name](**eval_env_conf) 29 | else: 30 | eval_env = registered_envs[reg_env_name](**env_conf) 31 | 32 | if 'env_wrappers' in conf: 33 | env = wrap(env=env, wrappers=conf.env_wrappers, **conf) 34 | eval_env = wrap(env=eval_env, wrappers=conf.env_wrappers, **conf) 35 | 36 |
conf['obs_space'] = env.get_observation_space()['obs'] 37 | conf['act_space'] = env.get_action_space() 38 | conf['n_agents'] = len(env.players) 39 | try: 40 | conf['state_shape'] = env.get_state_shape() 41 | except AttributeError: 42 | conf['state_shape'] = None 43 | 44 | trainer.env = env 45 | trainer.eval_env = eval_env 46 | trainer.agent = registered_agents[conf.agent_name](conf) 47 | 48 | 49 | def population_evaluation(trainer, render=False, render_mode='rgb_array'): 50 | pop_size = trainer.conf.pop_size 51 | trajs = [[None]* pop_size for _ in range(pop_size)] 52 | infos = [[None]* pop_size for _ in range(pop_size)] 53 | frames = [[None]* pop_size for _ in range(pop_size)] 54 | metrics = [[None]* pop_size for _ in range(pop_size)] 55 | 56 | for i in range(pop_size): 57 | param_i = trainer.agent_pop[i].get_param() 58 | for a in trainer.controller.agents_dict.values(): 59 | a.set_param(param_i) 60 | trajs[i][i], infos[i][i], frames[i][i], metrics[i][i] = evaluate(trainer, trainer.conf.n_eval_ep, render=render, render_mode=render_mode) 61 | 62 | for j in range(pop_size): 63 | if i!=j: 64 | param_j = trainer.agent_pop[j].get_param() 65 | for k in range(trainer.conf.n_agents): 66 | for a in trainer.controller.agents_dict.values(): 67 | a.set_param(param_j) 68 | player_k = list(trainer.controller.agents_dict.values())[k] 69 | player_k.set_param(param_i) 70 | trajs[i][j], infos[i][j], frames[i][j], metrics[i][j] = evaluate(trainer, trainer.conf.n_eval_ep, render=render, render_mode=render_mode) 71 | 72 | return trajs, infos, frames, metrics 73 | 74 | def evaluate(trainer, n=None, render=False, render_mode='rgb_array'): 75 | eval_traj, infos, frames, eval_metrics = trainer._collect_data(n, eval=True, render=render, render_mode=render_mode) 76 | return eval_traj, infos, frames, eval_metrics 77 | 78 | def collect_data(runner, n=None, render=False, render_mode='rgb_array'): 79 | traj, infos, frames = runner.rollout(n, render=render, render_mode=render_mode) 80 | metrics = Dotdict({}) 81 | traj_info = get_traj_info(traj) 82 | metrics.update(traj_info) 83 | return traj, infos, frames, metrics 84 | 85 | 86 | class Trainer(ABC): 87 | @auto_assign 88 | def __init__(self, conf, env_conf, eval_env_conf=None): 89 | self._setup(conf, env_conf, eval_env_conf) 90 | 91 | def _setup(self, conf, env_conf, eval_env_conf): 92 | self.iter = 0 93 | reg_env_name = env_conf.name 94 | del env_conf['name'] 95 | env = registered_envs[reg_env_name](**env_conf) 96 | if eval_env_conf is not None: 97 | eval_env = registered_envs[reg_env_name](**eval_env_conf) 98 | else: 99 | eval_env = registered_envs[reg_env_name](**env_conf) 100 | 101 | if 'env_wrappers' in conf: 102 | env = wrap(env=env, wrappers=conf.env_wrappers, **conf) 103 | eval_env = wrap(env=eval_env, wrappers=conf.env_wrappers, **conf) 104 | 105 | conf['obs_space'] = env.get_observation_space()['obs'] 106 | conf['act_space'] = env.get_action_space() 107 | conf['n_agents'] = len(env.players) 108 | try: 109 | conf['state_shape'] = env.get_state_shape() 110 | except AttributeError: 111 | conf['state_shape'] = None 112 | 113 | self.env = env 114 | self.eval_env = eval_env 115 | 116 | def _runner_setup(self): 117 | runner_cls = registered_runners[self.conf.runner] 118 | self.runner = runner_cls(env=self.env, controller=self.controller) 119 | self.eval_runner = EpisodesRunner(env=self.env, controller=self.controller) 120 | 121 | def _collect_data(self, n=None, eval=False, render=False, render_mode='rgb_array'): 122 | runner = self.runner 123 | if eval: 124 | runner = 
self.eval_runner 125 | if n is None: 126 | if isinstance(runner, EpisodesRunner): 127 | n = self.conf.n_ep 128 | elif isinstance(runner, StepsRunner): 129 | n = self.conf.n_ts 130 | return collect_data(runner, n, render, render_mode) 131 | 132 | @abstractmethod 133 | def train(self): 134 | raise NotImplementedError() 135 | 136 | @abstractmethod 137 | def evaluate(self, n=None, render=False, render_mode='rgb_array'): 138 | eval_traj, infos, frames, eval_metrics = self._collect_data(n, eval=True, render=render, render_mode=render_mode) 139 | return eval_traj, infos, frames, eval_metrics
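# Illustrative sketch of a concrete trainer on top of the helpers above; the
# agent construction and loop body are assumptions about a typical setup, not
# code from this repo:
#
#   class MyTrainer(Trainer):
#       def _setup(self, conf, env_conf, eval_env_conf):
#           super()._setup(conf, env_conf, eval_env_conf)
#           self.agent = registered_agents[conf.agent_name](conf)
#           self.controller = MappingController(self.env.action_spaces,
#                                               {'shared': self.agent},
#                                               lambda player: 'shared')
#           self._runner_setup()
#
#       def train(self):
#           traj, infos, frames, metrics = self._collect_data()
#           self.controller.train(traj)
#           self.iter += 1
#
#       def evaluate(self, n=None, render=False, render_mode='rgb_array'):
#           return self._collect_data(n, eval=True, render=render, render_mode=render_mode)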
-------------------------------------------------------------------------------- /coop_marl/controllers/controllers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.spaces import Discrete 3 | 4 | from coop_marl.utils import Arrdict, arrdict, reverse_dict 5 | from coop_marl.agents import Agent 6 | 7 | class Controller: 8 | def select_actions(self, inp): 9 | raise NotImplementedError 10 | def train(self, *args, **kwargs): 11 | raise NotImplementedError 12 | def reset(self): 13 | raise NotImplementedError 14 | def get_prev_decision_view(self): 15 | raise NotImplementedError 16 | 17 | # Agent-less controller, e.g., a random agent can be implemented without calling each agent explicitly 18 | class RandomController(Controller): 19 | # this factors out the agents 20 | def __init__(self, action_spaces): 21 | self.action_spaces = action_spaces 22 | 23 | def select_actions(self, inp): 24 | # data is arrdict 25 | action_dict = Arrdict() 26 | for k in inp.data.keys(): 27 | if 'action_mask' in inp.data[k].keys() and isinstance(self.action_spaces[k],Discrete): 28 | # handle action mask in discrete action space 29 | a = np.random.choice(np.where(inp.data[k].action_mask==1)[0]) 30 | action_dict[k] = Arrdict(action=a) 31 | else: 32 | action_dict[k] = Arrdict(action=self.action_spaces[k].sample()) 33 | return action_dict 34 | 35 | def train(self, *args, **kwargs): 36 | pass 37 | 38 | def reset(self): 39 | pass 40 | 41 | def get_prev_decision_view(self): 42 | return {p:Arrdict() for p,space in self.action_spaces.items()} 43 | 44 | 45 | # Agent-based controller needs to call each agent independently to construct action_dict 46 | class PSController(Controller): 47 | ''' 48 | Parameter sharing controller: one agent controls all the players in the game 49 | ''' 50 | def __init__(self, action_spaces, agent): 51 | assert isinstance(agent, Agent), f'agent must be an instance of Agent. Got {type(agent)}' 52 | self.action_spaces = action_spaces 53 | self.agent = agent 54 | 55 | def select_actions(self, inp): 56 | # act returns list of arrdict (each arrdict represents one player's action) 57 | actions = self.agent.act(inp) 58 | action_dict = Arrdict((k,v) for k,v in zip(inp.data.keys(), actions)) 59 | return action_dict 60 | 61 | def train(self, traj, *args, **kwargs): 62 | if kwargs.get('flatten',True): 63 | # concat and preprocess all agents traj to a single batch 64 | batch = [] 65 | for player in traj.inp.data: 66 | player_traj = getattr(traj,player) 67 | self.agent.preprocess(player_traj) 68 | batch.append(player_traj) 69 | batch = arrdict.cat(batch) 70 | self.agent.train(batch, *args, **kwargs) 71 | else: 72 | self.agent.preprocess(traj) 73 | self.agent.train(traj, *args, **kwargs) 74 | 75 | def reset(self): 76 | # called at the beginning of an episode 77 | self.agent.reset() 78 | 79 | def get_prev_decision_view(self): 80 | dummy_decision = self.agent.get_prev_decision_view() 81 | return Arrdict({p:dummy_decision for p in self.action_spaces}) 82 | 83 | 84 | class MappingController(Controller): 85 | ''' 86 | Maps player to agent using policy_mapping_fn 87 | ''' 88 | # currently does not support stochastic mapping (training logic assumes mapping is static between episodes) 89 | def __init__(self, action_spaces, agents_dict, policy_mapping_fn, possible_teams=None): 90 | self.action_spaces = action_spaces 91 | self.agents_dict = agents_dict 92 | self.policy_mapping_fn = policy_mapping_fn 93 | self.possible_teams = possible_teams 94 | self._player_to_agent = {} 95 | 96 | # ray/rllib/evaluation/episode.py 97 | def agent_for(self, player): 98 | if player not in self._player_to_agent: 99 | self._player_to_agent[player] = self.policy_mapping_fn(player) 100 | return self._player_to_agent[player] 101 | 102 | def select_actions(self, inp): 103 | player_to_agent = {p:self.agent_for(p) for p in inp.data.keys()} 104 | agent_to_player = reverse_dict(player_to_agent) 105 | 106 | # create inp_dict for each agent to use as inputs 107 | agent_in = Arrdict() 108 | for agent, player_list in agent_to_player.items(): 109 | for player in player_list: 110 | if agent not in agent_in: 111 | agent_in[agent] = Arrdict() 112 | agent_in[agent][player] = getattr(inp,player) 113 | else: 114 | assert player not in agent_in[agent], ('each player has its own dict key ' 115 | 'and no one player should appear twice') 116 | agent_in[agent][player] = getattr(inp,player) 117 | 118 | # act 119 | action_dict = Arrdict() 120 | for agent, player_list in agent_to_player.items(): 121 | actions = self.agents_dict[agent].act(agent_in[agent]) 122 | for a,player in zip(actions, player_list): 123 | action_dict[player] = a 124 | return action_dict 125 | 126 | def train(self, traj, *args, **kwargs): 127 | # preprocess traj 128 | player_to_agent = {p:self.agent_for(p) for p in traj.inp.data.keys()} 129 | agent_batch = Arrdict() 130 | for player,agent in player_to_agent.items(): 131 | player_traj = getattr(traj,player) 132 | # per-agent preprocess 133 | self.agents_dict[agent].preprocess(player_traj) 134 | if agent not in agent_batch: 135 | agent_batch[agent] = player_traj 136 | else: 137 | agent_batch[agent] = arrdict.cat([agent_batch[agent], player_traj]) 138 | 139 | for agent in agent_batch.keys(): 140 | self.agents_dict[agent].train(agent_batch[agent], *args, **kwargs) 141 | 142 | def reset(self): 143 | self._player_to_agent = {} 144 | for a in self.agents_dict.values(): 145 | a.reset() 146 | 147 | def
get_prev_decision_view(self): 148 | decision = Arrdict() 149 | if self.possible_teams is None: 150 | for p in self.action_spaces: 151 | decision[p] = self.agents_dict[self.agent_for(p)].get_prev_decision_view() 152 | else: 153 | for team in self.possible_teams: 154 | for p in team: 155 | decision[p] = self.agents_dict[self.agent_for(p)].get_prev_decision_view() 156 | return decision 157 | -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/queuing.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from torch import multiprocessing as mp 4 | import queue 5 | from contextlib import contextmanager, asynccontextmanager 6 | import traceback 7 | import asyncio 8 | from functools import wraps 9 | from torch import nn 10 | import logging 11 | from . import dotdict 12 | 13 | log = logging.getLogger(__name__) 14 | 15 | class SerialQueue: 16 | 17 | def __init__(self): 18 | self._queue = [] 19 | self._put_end = False 20 | self._got_end = False 21 | 22 | def get(self): 23 | if len(self._queue) > 0: 24 | item = self._queue[0] 25 | self._queue = self._queue[1:] 26 | 27 | if isinstance(item, str) and (item == '__END__'): 28 | log.info('Got END') 29 | self._got_end = True 30 | return None 31 | else: 32 | return item 33 | else: 34 | return None 35 | 36 | def put(self, item): 37 | # Not safe to test `in` directly since `item` is likely to be a tensor 38 | if isinstance(item, (str, type(None))) and (item in ('__END__', None)): 39 | raise ValueError(f'Tried to put sentinel value "{item}"') 40 | if len(self._queue) < 1: 41 | self._queue.append(item) 42 | return True 43 | else: 44 | return False 45 | 46 | def put_end(self): 47 | if self._put_end: 48 | return True 49 | else: 50 | if len(self._queue) < 1: 51 | self._queue.append('__END__') 52 | log.info('Put END') 53 | self._put_end = True 54 | return True 55 | else: 56 | return False 57 | 58 | def get_end(self): 59 | self.get() 60 | return self._got_end 61 | 62 | def join(self, timeout=None): 63 | if len(self._queue) == 0: 64 | return True 65 | else: 66 | return False 67 | 68 | class MultiprocessQueue: 69 | 70 | def __init__(self): 71 | self.queue = mp.JoinableQueue(1) 72 | self._put_end = False 73 | self._got_end = False 74 | 75 | def get(self): 76 | try: 77 | item = self.queue.get_nowait() 78 | if isinstance(item, str) and (item == '__END__'): 79 | log.info('Got END') 80 | self._got_end = True 81 | self.queue.task_done() 82 | return None 83 | else: 84 | self.queue.task_done() 85 | return item 86 | except queue.Empty: 87 | return None 88 | 89 | def put(self, item): 90 | # Not safe to test `in` directly since `item` is likely to be a tensor 91 | if isinstance(item, (str, type(None))) and (item in ('__END__', None)): 92 | raise ValueError(f'Tried to put sentinel value "{item}"') 93 | try: 94 | self.queue.put_nowait(item) 95 | return True 96 | except queue.Full: 97 | return False 98 | 99 | def put_end(self): 100 | try: 101 | if not self._put_end: 102 | self.queue.put_nowait('__END__') 103 | log.info('Put END') 104 | self._put_end = True 105 | return True 106 | except queue.Full: 107 | return False 108 | 109 | def get_end(self): 110 | self.get() 111 | return self._got_end 112 | 113 | def join(self, timeout=None): 114 | try: 115 | with self.queue._cond: 116 | if not self.queue._unfinished_tasks._semlock._is_zero(): 117 | self.queue._cond.wait(timeout=timeout) 118 | return True 119 | except RuntimeError: 120 | return False 121 | 122 | async 
def close(intakes, outputs, timeout=5): 123 | """Strategy: 124 | * Wait until you can send an END through each output queue 125 | * Drain the intake queues until you get an END from each one 126 | * Wait for each output queue to drain 127 | """ 128 | log.info(f'Closing; draining intakes and waiting to send ENDs. {timeout}s timeout.') 129 | cutoff = time.time() + timeout 130 | while True: 131 | # Avoid a deadlock where everyone's queues are full so ENDs can't be sent 132 | for intake in intakes: 133 | intake.get() 134 | 135 | if all(o.put_end() for o in outputs): 136 | break 137 | if time.time() > cutoff: 138 | log.warning('Timed out while waiting to send ENDs') 139 | return 140 | 141 | # We're not actually running in a proper scheduler here, so can't sleep via it 142 | await asyncio.sleep(0) 143 | time.sleep(.1) 144 | 145 | log.info(f'Sent ENDs to outputs; waiting to get ENDs from intakes') 146 | while True: 147 | if all(i.get_end() for i in intakes): 148 | break 149 | if time.time() > cutoff: 150 | log.warning('Timed out while waiting to get ENDs') 151 | return 152 | 153 | # We're not actually running in a proper scheduler here, so can't sleep via it 154 | await asyncio.sleep(0) 155 | time.sleep(.1) 156 | 157 | log.info(f'Intakes emptied; waiting for outputs to drain') 158 | while True: 159 | if all(o.join(.1) for o in outputs): 160 | break 161 | if time.time() > cutoff: 162 | log.warning('Timed out while waiting to drain outputs') 163 | return 164 | 165 | # We're not actually running in a proper scheduler here, so can't sleep via it 166 | await asyncio.sleep(0) 167 | time.sleep(.1) 168 | 169 | log.info('Outputs drained.') 170 | 171 | def create(x, serial=False): 172 | if isinstance(x, dict): 173 | return dotdict.dotdict({n: create(v, serial) for n, v in x.items()}) 174 | elif isinstance(x, (list, tuple)): 175 | return dotdict.dotdict({n: create(n, serial) for n in x}) 176 | elif isinstance(x, str): 177 | return SerialQueue() if serial else MultiprocessQueue() 178 | raise ValueError(f'Can\'t handle {type(x)}') 179 | 180 | @asynccontextmanager 181 | async def cleanup(intakes, outputs): 182 | intakes = [intakes] if isinstance(intakes, (SerialQueue, MultiprocessQueue)) else intakes 183 | outputs = [outputs] if isinstance(outputs, (SerialQueue, MultiprocessQueue)) else outputs 184 | try: 185 | yield 186 | except: 187 | log.info(f'Got an exception, cleaning up queues:\n{traceback.format_exc()}') 188 | raise 189 | finally: 190 | await close(intakes, outputs) 191 | 192 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/cooking_world/world_objects.py: -------------------------------------------------------------------------------- 1 | from gym_cooking.cooking_world.abstract_classes import * 2 | from gym_cooking.cooking_world.constants import * 3 | from typing import List 4 | 5 | 6 | class Floor(StaticObject): 7 | 8 | def __init__(self, location): 9 | super().__init__(location, True) 10 | 11 | def accepts(self, dynamic_objects) -> bool: 12 | return False 13 | 14 | def file_name(self) -> str: 15 | return "floor" 16 | 17 | 18 | class Counter(StaticObject): 19 | 20 | def __init__(self, location): 21 | super().__init__(location, False) 22 | 23 | def accepts(self, dynamic_objects) -> bool: 24 | return True 25 | 26 | def file_name(self) -> str: 27 | return "counter" 28 | 29 | 30 | class DeliverSquare(StaticObject): 31 | 32 | def __init__(self, location): 33 | super().__init__(location, False) 34 | 35 | def accepts(self, dynamic_objects) ->
bool: 36 | return True 37 | 38 | def file_name(self) -> str: 39 | return "delivery" 40 | 41 | 42 | class CutBoard(StaticObject, ActionObject): 43 | 44 | def __init__(self, location): 45 | super().__init__(location, False) 46 | 47 | def action(self, dynamic_objects: List): 48 | if len(dynamic_objects) == 1: 49 | try: 50 | return dynamic_objects[0].chop() 51 | except AttributeError: 52 | return False 53 | return False 54 | 55 | def accepts(self, dynamic_objects) -> bool: 56 | return len(dynamic_objects) == 1 and isinstance(dynamic_objects[0], ChopFood) 57 | 58 | def file_name(self) -> str: 59 | return "cutboard" 60 | 61 | 62 | class Blender(StaticObject, ProgressingObject): 63 | 64 | def __init__(self, location): 65 | super().__init__(location, False) 66 | self.content = None 67 | 68 | def progress(self, dynamic_objects): 69 | assert len(dynamic_objects) < 2, "Too many Dynamic Objects placed into the Blender" 70 | if not dynamic_objects: 71 | self.content = None 72 | return 73 | elif not self.content: 74 | self.content = dynamic_objects 75 | elif self.content: 76 | if self.content[0] == dynamic_objects[0]: 77 | self.content[0].blend() 78 | else: 79 | self.content = dynamic_objects 80 | 81 | def accepts(self, dynamic_objects) -> bool: 82 | return len(dynamic_objects) == 1 and isinstance(dynamic_objects[0], BlenderFood) 83 | 84 | def file_name(self) -> str: 85 | return "blender3" 86 | 87 | 88 | class Plate(Container): 89 | 90 | def __init__(self, location): 91 | super().__init__(location) 92 | 93 | def add_content(self, content): 94 | if not isinstance(content, Food): 95 | raise TypeError(f"Only Food can be added to a plate! Tried to add {content.name()}") 96 | if not content.done(): 97 | raise Exception(f"Can't add food in unprepared state.") 98 | self.content.append(content) 99 | 100 | def file_name(self) -> str: 101 | return "Plate" 102 | 103 | 104 | class Onion(ChopFood): 105 | 106 | def __init__(self, location): 107 | super().__init__(location) 108 | 109 | def done(self): 110 | if self.chop_state == ChopFoodStates.CHOPPED: 111 | return True 112 | else: 113 | return False 114 | 115 | def file_name(self) -> str: 116 | if self.done(): 117 | return "ChoppedOnion" 118 | else: 119 | return "FreshOnion" 120 | 121 | 122 | class Tomato(ChopFood): 123 | 124 | def __init__(self, location): 125 | super().__init__(location) 126 | 127 | def done(self): 128 | if self.chop_state == ChopFoodStates.CHOPPED: 129 | return True 130 | else: 131 | return False 132 | 133 | def file_name(self) -> str: 134 | if self.done(): 135 | return "ChoppedTomato" 136 | else: 137 | return "FreshTomato" 138 | 139 | 140 | class Lettuce(ChopFood): 141 | 142 | def __init__(self, location): 143 | super().__init__(location) 144 | 145 | def done(self): 146 | if self.chop_state == ChopFoodStates.CHOPPED: 147 | return True 148 | else: 149 | return False 150 | 151 | def file_name(self) -> str: 152 | if self.done(): 153 | return "ChoppedLettuce" 154 | else: 155 | return "FreshLettuce" 156 | 157 | 158 | class Carrot(BlenderFood, ChopFood): 159 | 160 | def __init__(self, location): 161 | super().__init__(location) 162 | 163 | def done(self): 164 | if self.chop_state == ChopFoodStates.CHOPPED or self.blend_state == BlenderFoodStates.MASHED: 165 | return True 166 | else: 167 | return False 168 | 169 | def file_name(self) -> str: 170 | if self.done(): 171 | return "ChoppedCarrot" 172 | else: 173 | return "FreshCarrot" 174 | 175 | 176 | class Agent(Object): 177 | 178 | def __init__(self, location, color, name): 179 | 
super().__init__(location, False, False) 180 | self.holding = None 181 | self.color = color 182 | self.name = name 183 | self.orientation = 1 184 | 185 | def grab(self, obj: DynamicObject): 186 | self.holding = obj 187 | obj.move_to(self.location) 188 | 189 | def put_down(self, location): 190 | self.holding.move_to(location) 191 | self.holding = None 192 | 193 | def move_to(self, new_location): 194 | self.location = new_location 195 | if self.holding: 196 | self.holding.move_to(new_location) 197 | 198 | def change_orientation(self, new_orientation): 199 | assert 0 < new_orientation < 5 200 | self.orientation = new_orientation 201 | 202 | def file_name(self) -> str: 203 | pass 204 | 205 | 206 | StringToClass = { 207 | "Floor": Floor, 208 | "Counter": Counter, 209 | "CutBoard": CutBoard, 210 | "DeliverSquare": DeliverSquare, 211 | "Tomato": Tomato, 212 | "Lettuce": Lettuce, 213 | "Onion": Onion, 214 | "Plate": Plate, 215 | "Agent": Agent, 216 | "Blender": Blender, 217 | "Carrot": Carrot 218 | } 219 | 220 | ClassToString = { 221 | Floor: "Floor", 222 | Counter: "Counter", 223 | CutBoard: "CutBoard", 224 | DeliverSquare: "DeliverSquare", 225 | Tomato: 'Tomato', 226 | Lettuce: "Lettuce", 227 | Onion: "Onion", 228 | Plate: "Plate", 229 | Agent: "Agent", 230 | Blender: "Blender", 231 | Carrot: "Carrot" 232 | } 233 | 234 | GAME_CLASSES = [Floor, Counter, CutBoard, DeliverSquare, Tomato, Lettuce, Onion, Plate, Agent, Blender, Carrot] 235 | GAME_CLASSES_STATE_LENGTH = [(Floor, 1), (Counter, 1), (CutBoard, 1), (DeliverSquare, 1), (Tomato, 2), 236 | (Lettuce, 2), (Onion, 2), (Plate, 1), (Agent, 5), (Blender, 1), (Carrot, 3)] 237 | GAME_CLASSES_HOLDABLE_IDX = {cls:i for i, cls in enumerate(["Tomato", "Lettuce", "Onion", "Plate", "Carrot"])} 238 | FOOD_CLASSES = ["Tomato", "Lettuce", "Onion", "Carrot"] 239 | FOOD_CLASSES_IDX = {cls:i for i, cls in enumerate(FOOD_CLASSES)} 240 | OBJ_IDX = {ClassToString[cls]:i for i, cls in enumerate(GAME_CLASSES[1:])} -------------------------------------------------------------------------------- /coop_marl/utils/rebar/rebar/logging.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import time 4 | from pathlib import Path 5 | from collections import defaultdict, deque 6 | import logging.handlers 7 | import ipywidgets as widgets 8 | from contextlib import contextmanager 9 | import psutil 10 | from . 
import widgets, paths 11 | from .contextlib import maybeasynccontextmanager 12 | import sys 13 | import traceback 14 | import _thread 15 | import threading 16 | 17 | # for re-export 18 | from logging import getLogger 19 | 20 | log = getLogger(__name__) 21 | 22 | #TODO: This shouldn't be at the top level 23 | logging.basicConfig( 24 | stream=sys.stdout, 25 | level=logging.INFO, 26 | format='%(asctime)s %(levelname)s %(name)s: %(message)s', 27 | datefmt=r'%Y-%m-%d %H:%M:%S') 28 | logging.getLogger('parso').setLevel('WARN') # Jupyter's autocomplete spams the output if this isn't set 29 | log.info('Set log params') 30 | 31 | def in_ipython(): 32 | try: 33 | __IPYTHON__ 34 | return True 35 | except NameError: 36 | return False 37 | 38 | class StdoutRenderer: 39 | 40 | def __init__(self): 41 | super().__init__() 42 | 43 | def emit(self, path, line): 44 | source = '{procname}/#{pid}'.format(**paths.parse(path)) 45 | print(f'{source}: {line}') 46 | 47 | def close(self): 48 | pass 49 | 50 | class IPythonRenderer: 51 | 52 | def __init__(self, compositor=None): 53 | super().__init__() 54 | self._out = (compositor or widgets.Compositor()).output() 55 | self._next = time.time() 56 | self._lasts = {} 57 | self._buffers = defaultdict(lambda: deque(['']*self._out.lines, maxlen=self._out.lines)) 58 | 59 | def _format_block(self, name): 60 | n_lines = max(self._out.lines//(len(self._buffers) + 2), 1) 61 | lines = '\n'.join(list(self._buffers[name])[-n_lines:]) 62 | return f'{name}:\n{lines}' 63 | 64 | def _display(self, force=False): 65 | content = '\n\n'.join([self._format_block(n) for n in self._buffers]) 66 | self._out.refresh(content) 67 | 68 | for name, last in list(self._lasts.items()): 69 | if time.time() - last > 120: 70 | del self._buffers[name] 71 | del self._lasts[name] 72 | 73 | def emit(self, path, line): 74 | source = '{procname}/#{pid}'.format(**paths.parse(path)) 75 | self._buffers[source].append(line) 76 | self._lasts[source] = time.time() 77 | self._display() 78 | 79 | def close(self): 80 | self._display(force=True) 81 | # Want to leave the outputs open so you can see the final messages 82 | # self._out.close() 83 | 84 | 85 | @contextmanager 86 | def handlers(*new_handlers): 87 | logger = logging.getLogger() 88 | old_handlers = [*logger.handlers] 89 | try: 90 | logger.handlers = new_handlers 91 | yield 92 | finally: 93 | for h in new_handlers: 94 | try: 95 | h.acquire() 96 | h.flush() 97 | h.close() 98 | except (OSError, ValueError): 99 | pass 100 | finally: 101 | h.release() 102 | 103 | logger.handlers = old_handlers 104 | 105 | @maybeasynccontextmanager 106 | def to_dir(run_name): 107 | path = paths.path(run_name, 'logs').with_suffix('.txt') 108 | handler = logging.FileHandler(path) 109 | handler.setLevel(logging.INFO) 110 | handler.setFormatter(logging.Formatter( 111 | fmt='%(asctime)s %(levelname)s %(name)s: %(message)s', 112 | datefmt=r'%H:%M:%S')) 113 | 114 | with handlers(handler): 115 | try: 116 | yield 117 | except: 118 | log.info(f'Trace:\n{traceback.format_exc()}') 119 | raise 120 | 121 | class Reader: 122 | 123 | def __init__(self, run_name): 124 | self._dir = paths.subdirectory(run_name, 'logs') 125 | self._files = {} 126 | 127 | def read(self): 128 | for path in self._dir.glob('*.txt'): 129 | if path not in self._files: 130 | self._files[path] = path.open('r') 131 | 132 | for path, f in self._files.items(): 133 | for line in f.readlines(): 134 | yield path, line.rstrip('\n') 135 |
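# Illustrative sketch: from_dir (defined below) both registers a file handler
# for this run and forwards the log files back to the active renderer while
# work happens inside the context; the run name 'test' matches the tests below:
#
#   with from_dir('test'):
#       log.info('this line is written to disk and re-rendered')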
136 | def __from_dir(canceller, renderer, reader): 137 | while True: 138 | for path, line in reader.read(): 139 | renderer.emit(path, line) 140 | 141 | if canceller.is_set(): 142 | break 143 | 144 | time.sleep(.01) 145 | 146 | def _from_dir(canceller, renderer, reader): 147 | try: 148 | __from_dir(canceller, renderer, reader) 149 | except KeyboardInterrupt: 150 | log.info('Interrupting main') 151 | _thread.interrupt_main() 152 | 153 | 154 | @contextmanager 155 | def from_dir(run_name, compositor=None): 156 | if in_ipython(): 157 | renderer = IPythonRenderer(compositor) 158 | else: 159 | renderer = StdoutRenderer() 160 | 161 | with to_dir(run_name): 162 | try: 163 | reader = Reader(run_name) 164 | canceller = threading.Event() 165 | thread = threading.Thread(target=_from_dir, args=(canceller, renderer, reader)) 166 | thread.start() 167 | yield 168 | finally: 169 | log.info('Cancelling log forwarding thread') 170 | time.sleep(.25) 171 | canceller.set() 172 | thread.join(1) 173 | if thread.is_alive(): 174 | log.error('Logging thread won\'t die') 175 | else: 176 | log.info('Log forwarding thread cancelled') 177 | 178 | @contextmanager 179 | def via_dir(run_name, compositor=None): 180 | with to_dir(run_name), from_dir(run_name, compositor): 181 | yield 182 | 183 | ### TESTS 184 | 185 | def test_in_process(): 186 | paths.clear('test', 'logs') 187 | 188 | with from_dir('test'): 189 | for _ in range(10): 190 | log.info('hello') 191 | time.sleep(.1) 192 | 193 | def _test_multiprocess(run_name): 194 | with to_dir(run_name): 195 | for i in range(10): 196 | log.info(str(i)) 197 | time.sleep(.5) 198 | 199 | def test_multiprocess(): 200 | paths.clear('test', 'logs') 201 | 202 | import multiprocessing as mp 203 | with from_dir('test'): 204 | ps = [] 205 | for _ in range(3): 206 | p = mp.Process(target=_test_multiprocess, args=('test',)) 207 | p.start() 208 | ps.append(p) 209 | 210 | while any(p.is_alive() for p in ps): 211 | time.sleep(.5) 212 | 213 | def _test_error(run_name): 214 | with to_dir(run_name): 215 | log.info('Alive') 216 | time.sleep(2) 217 | raise ValueError('Last gasp') 218 | 219 | def test_error(): 220 | paths.clear('test', 'logs') 221 | 222 | import multiprocessing as mp 223 | with from_dir('test'): 224 | ps = [] 225 | for _ in range(1): 226 | p = mp.Process(target=_test_error, args=('test',)) 227 | p.start() 228 | ps.append(p) 229 | 230 | while any(p.is_alive() for p in ps): 231 | time.sleep(.5) 232 | -------------------------------------------------------------------------------- /coop_marl/envs/overcooked/gym_cooking/environment/game/game.py: -------------------------------------------------------------------------------- 1 | import os 2 | from gym_cooking.environment.game import graphic_pipeline 3 | from gym_cooking.misc.game.utils import * 4 | 5 | import pygame 6 | import matplotlib.pyplot as plt 7 | 8 | import os.path 9 | from collections import defaultdict 10 | from datetime import datetime 11 | from time import sleep 12 | 13 | 14 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = 'hide' 15 | 16 | 17 | class Game: 18 | 19 | def __init__(self, env, num_humans, ai_policies, max_steps=100, render=False): 20 | self._running = True 21 | self.env = env 22 | self.play = bool(num_humans) 23 | self.render = render or self.play 24 | # Visual parameters 25 | self.graphics_pipeline = graphic_pipeline.GraphicPipeline(env, self.render) 26 | self.save_dir = 'misc/game/screenshots' 27 | self.store = defaultdict(list) 28 | self.num_humans = num_humans 29 | self.ai_policies = ai_policies 30 | self.max_steps = max_steps 31 |
self.current_step = 0 32 | self.last_obs = env.reset() 33 | self.step_done = False 34 | self.yielding_action_dict = {} 35 | assert len(ai_policies) == len(env.unwrapped.world.agents) - num_humans 36 | if not os.path.exists(self.save_dir): 37 | os.makedirs(self.save_dir) 38 | 39 | def on_init(self): 40 | pygame.init() 41 | self.graphics_pipeline.on_init() 42 | return True 43 | 44 | def on_event(self, event): 45 | self.step_done = False 46 | if event.type == pygame.QUIT: 47 | self._running = False 48 | self.store["observation"].append(self.last_obs) 49 | elif event.type == pygame.KEYDOWN: 50 | # exit the game 51 | if event.key == pygame.K_ESCAPE: 52 | self._running = False 53 | self.store["observation"].append(self.last_obs) 54 | # Save current image 55 | if event.key == pygame.K_RETURN: 56 | image_name = '{}_{}.png'.format(self.env.unwrapped.filename, datetime.now().strftime('%m-%d-%y_%H-%M-%S')) 57 | pygame.image.save(self.graphics_pipeline.screen, '{}/{}'.format(self.save_dir, image_name)) 58 | print('Saved image {} to {}'.format(image_name, self.save_dir)) 59 | return 60 | 61 | # Control current human agent 62 | if event.key in KeyToTuple_human1 and self.num_humans > 0: 63 | store_action_dict = {} 64 | action = KeyToTuple_human1[event.key] 65 | for i in range(self.num_humans): 66 | self.env.unwrapped.world.agents[i].action = action 67 | store_action_dict[self.env.unwrapped.world.agents[i]] = action 68 | 69 | self.store["observation"].append(self.last_obs) 70 | self.store["agent_states"].append([agent.location for agent in self.env.unwrapped.world.agents]) 71 | for idx, agent in enumerate(self.env.unwrapped.world.agents): 72 | if idx >= self.num_humans: 73 | ai_policy = self.ai_policies[idx - self.num_humans] 74 | env_agent = self.env.unwrapped.world_agent_to_env_agent_mapping[agent] 75 | last_obs_raw = self.last_obs[env_agent] 76 | ai_action = ai_policy.get_action(last_obs_raw) 77 | store_action_dict[agent] = ai_action 78 | self.env.unwrapped.world.agents[idx].action = ai_action 79 | 80 | self.yielding_action_dict = {agent: self.env.unwrapped.world_agent_mapping[agent].action 81 | for agent in self.env.agents} 82 | observations, rewards, dones, infos = self.env.step(self.yielding_action_dict) 83 | # print(observations['player_0']) 84 | print(rewards) 85 | print(infos) 86 | 87 | self.store["actions"].append(store_action_dict) 88 | self.store["info"].append(infos) 89 | self.store["rewards"].append(rewards) 90 | self.store["done"].append(dones) 91 | 92 | self.last_obs = observations 93 | self.step_done = True 94 | 95 | if all(dones.values()): 96 | self._running = False 97 | self.store["observation"].append(self.last_obs) 98 | 99 | def ai_only_event(self): 100 | self.step_done = False 101 | 102 | store_action_dict = {} 103 | 104 | self.store["observation"].append(self.last_obs) 105 | self.store["agent_states"].append([agent.location for agent in self.env.unwrapped.world.agents]) 106 | for idx, agent in enumerate(self.env.unwrapped.world.agents): 107 | if idx >= self.num_humans: 108 | ai_policy = self.ai_policies[idx - self.num_humans] # .agent 109 | env_agent = self.env.unwrapped.world_agent_to_env_agent_mapping[agent] 110 | last_obs_raw = self.last_obs[env_agent] 111 | ai_action = ai_policy.get_action(last_obs_raw) 112 | store_action_dict[agent] = ai_action 113 | self.env.unwrapped.world.agents[idx].action = ai_action 114 | 115 | self.yielding_action_dict = {agent: self.env.unwrapped.world_agent_mapping[agent].action 116 | for agent in self.env.agents} 117 | observations, rewards, dones, 
infos = self.env.step(self.yielding_action_dict) 118 | self.store["actions"].append(store_action_dict) 119 | self.store["info"].append(infos) 120 | self.store["rewards"].append(rewards) 121 | self.store["done"].append(dones) 122 | self.last_obs = observations 123 | self.step_done = True 124 | 125 | if all(dones.values()): 126 | self._running = False 127 | self.store["observation"].append(self.last_obs) 128 | 129 | def on_execute(self): 130 | self._running = self.on_init() 131 | 132 | while self._running: 133 | for event in pygame.event.get(): 134 | self.on_event(event) 135 | self.on_render() 136 | self.on_cleanup() 137 | 138 | return self.store 139 | 140 | def on_execute_yielding(self): 141 | self._running = self.on_init() 142 | 143 | while self._running: 144 | for event in pygame.event.get(): 145 | self.on_event(event) 146 | self.on_render() 147 | if self.step_done: 148 | self.step_done = False 149 | yield self.store["observation"][-1], self.store["done"][-1], self.store["info"][-1], \ 150 | self.store["rewards"][-1], self.yielding_action_dict 151 | self.on_cleanup() 152 | 153 | def on_execute_ai_only_with_delay(self): 154 | self._running = self.on_init() 155 | 156 | while self._running: 157 | sleep(0.2) 158 | self.ai_only_event() 159 | self.on_render() 160 | self.on_cleanup() 161 | 162 | return self.store 163 | 164 | def on_render(self, mode=''): 165 | return self.graphics_pipeline.on_render(mode) 166 | 167 | @staticmethod 168 | def on_cleanup(): 169 | # pygame.display.quit() 170 | pygame.quit() 171 | 172 | def get_image_obs(self): 173 | return self.graphics_pipeline.get_image_obs() 174 | 175 | def save_image_obs(self, t): 176 | self.graphics_pipeline.save_image_obs(t) 177 | --------------------------------------------------------------------------------
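# Illustrative sketch of running the Game loop above in AI-only mode;
# `make_env()` and `policy` are placeholders for a gym_cooking environment
# and any object exposing get_action(obs):
#
#   env = make_env()
#   game = Game(env, num_humans=0, ai_policies=[policy], max_steps=100, render=True)
#   store = game.on_execute_ai_only_with_delay()
#   # `store` accumulates observation, agent_states, actions, rewards, done and info per step.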