├── skill
│   ├── __init__.py
│   ├── goto_goal.py
│   ├── explore.py
│   └── base_skill.py
├── algos
│   ├── __init__.py
│   ├── base.py
│   ├── ppo.py
│   ├── model.py
│   └── buffer.py
├── utils
│   ├── __init__.py
│   ├── global_param.py
│   ├── log.py
│   ├── logo.py
│   ├── format.py
│   ├── env.py
│   └── chatglm-api.py
├── skill_multi_step
│   ├── __init__.py
│   ├── toggle.py
│   ├── pickup.py
│   ├── goto_goal.py
│   ├── explore.py
│   └── base_skill.py
├── env
│   ├── __init__.py
│   ├── doorkey.py
│   ├── twodoor.py
│   ├── lavadoorkey.py
│   ├── coloreddoorkey.py
│   └── historicalobs.py
├── prompt
│   └── task_info.json
├── .gitignore
├── main.py
├── Readme.md
├── teacher_policy.py
├── planner.py
├── mediator.py
└── Game.py
/skill/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_skill import * 2 | from .explore import * 3 | from .goto_goal import * --------------------------------------------------------------------------------
/algos/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .ppo import * 3 | from .model import * 4 | from .buffer import * --------------------------------------------------------------------------------
/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .env import * 2 | from .format import * 3 | from .log import * 4 | from .logo import * 5 | from .global_param import * --------------------------------------------------------------------------------
/skill_multi_step/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_skill import * 2 | from .explore import * 3 | from .goto_goal import * 4 | from .pickup import Pickup 5 | from .drop import Drop 6 | from .toggle import Toggle --------------------------------------------------------------------------------
/utils/global_param.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : global_param.py 5 | @Time : 2023/05/23 10:55:41 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : global parameter dictionary for skill 9 | ''' 10 | 11 | def init(): 12 | global _global_dict 13 | _global_dict = {} 14 | 15 | def set_value(key, value): 16 | _global_dict[key] = value 17 | 18 | def get_value(key, defValue=None): 19 | try: 20 | return _global_dict[key] 21 | except KeyError: 22 | return defValue --------------------------------------------------------------------------------
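A minimal usage sketch of the `global_param` module above — it is simply a process-wide dictionary used to share values between modules; the `"agent_view_size"` key is a hypothetical example, and only `init`, `set_value`, and `get_value` come from the file itself:

```python
from utils import global_param

global_param.init()                                   # create the module-level dict once, at program start
global_param.set_value("agent_view_size", 3)          # any module importing utils.global_param sees this value
view = global_param.get_value("agent_view_size", 7)   # returns the stored value, or 7 if the key was never set
```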
/skill_multi_step/toggle.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from .goto_goal import GoTo_Goal 7 | from .base_skill import BaseSkill 8 | 9 | class Toggle(BaseSkill): 10 | def __init__(self, init_obs, target_obj): 11 | self.obs = init_obs 12 | self.unpack_obs(self.obs) 13 | self.target_obj = target_obj 14 | 15 | def __call__(self): 16 | # if self.obs[self.target_pos[0], self.target_pos[1], 2] == 1: # open 17 | # return None, True, False 18 | 19 | # get the coordinate of target_obj 20 | target_pos = tuple(np.argwhere(self.map==self.target_obj)[0]) 21 | print(target_pos) 22 | action, _, _ = GoTo_Goal(self.obs, target_pos)(False) 23 | return action, False, False --------------------------------------------------------------------------------
/skill_multi_step/pickup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from .goto_goal import GoTo_Goal 7 | from .base_skill import BaseSkill 8 | 9 | class Pickup(BaseSkill): 10 | def __init__(self, init_obs, target_obj): 11 | self.obs = init_obs 12 | self.unpack_obs(self.obs) 13 | self.target_obj = target_obj 14 | 15 | def __call__(self): 16 | if self.carrying == self.target_obj: 17 | return None, True, False 18 | else: 19 | # get the coordinate of target_obj 20 | target_pos = tuple(np.argwhere(self.map==self.target_obj)[0]) 21 | print(target_pos) 22 | action, _, _ = GoTo_Goal(self.obs, target_pos)(False) 23 | return action, False, False 24 | 25 | # fwd_pos = tuple(self.agent_pos + DIR_TO_VEC[self.agent_dir]) 26 | # if fwd_pos == target_pos: 27 | # return 3, False, False --------------------------------------------------------------------------------
/algos/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : base.py 5 | @Time : 2023/05/18 09:42:14 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | from abc import ABC, abstractmethod 12 | import os 13 | import torch 14 | from .model import MLPBase, LSTMBase 15 | 16 | class Base(ABC): 17 | """The base class for RL algorithms.""" 18 | 19 | def __init__(self, model, obs_space, action_space, device, save_path, recurrent): 20 | self.device = device 21 | self.save_path = save_path 22 | self.recurrent = recurrent 23 | 24 | if model: 25 | self.model = model.to(self.device) 26 | elif self.recurrent: 27 | print("use LSTM......") 28 | self.model = LSTMBase(obs_space, action_space).to(self.device) 29 | else: 30 | print("use MLP......") 31 | self.model = MLPBase(obs_space, action_space).to(self.device) 32 | 33 | def save(self, name="acmodel"): 34 | try: 35 | os.makedirs(self.save_path) 36 | except OSError: 37 | pass 38 | filetype = ".pt" # pytorch model 39 | torch.save(self.model, os.path.join(self.save_path, name + filetype)) 40 | 41 | @abstractmethod 42 | def update_policy(self): 43 | pass --------------------------------------------------------------------------------
/env/__init__.py: -------------------------------------------------------------------------------- 1 | from .historicalobs import * 2 | from .doorkey import * 3 | from .lavadoorkey import * 4 | from .twodoor import * 5 | from .coloreddoorkey import * 6 | import gymnasium as gym 7 | 8 | gym.envs.register( 9 | id='MiniGrid-SimpleDoorKey-Min5-Max10-View3', 10 | entry_point='env.doorkey:DoorKeyEnv', 11 | kwargs={'minRoomSize' : 5, \ 12 | 'maxRoomSize' : 10, \ 13 | 'agent_view_size' : 3, \ 14 | 'max_steps': 150}, 15 | ) 16 | 17 | gym.envs.register( 18 | id='MiniGrid-LavaDoorKey-Min5-Max10-View3', 19 | entry_point='env.lavadoorkey:LavaDoorKeyEnv', 20 | kwargs={'minRoomSize' : 5, \ 21 | 'maxRoomSize' : 10, \ 22 | 'agent_view_size' : 3, \ 23 | 'max_steps': 150}, 24 | ) 25 | 26 | gym.envs.register( 27 | id='MiniGrid-ColoredDoorKey-Min5-Max10-View3', 28 | entry_point='env.coloreddoorkey:ColoredDoorKeyEnv', 29 | kwargs={'minRoomSize' : 5, \ 30 | 'maxRoomSize' : 10, \ 31 | 'minNumKeys' : 2, \ 32 | 'maxNumKeys' : 2, \ 33 | 'agent_view_size' : 3, \ 34 | 'max_steps' : 150}, 35 | ) 36 | 37 | gym.envs.register( 38 | id='MiniGrid-TwoDoor-Min20-Max20', 39 | entry_point='env.twodoor:TwoDoorEnv', 40 | kwargs={'minRoomSize' : 20, \ 41 | 'maxRoomSize' : 20, \ 42 | 'agent_view_size' : 3, \ 43 | 'max_steps' : 150}, 44 | ) --------------------------------------------------------------------------------
/utils/log.py:
-------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os, pickle 3 | import json 4 | 5 | class color: 6 | BOLD = '\033[1m\033[48m' 7 | END = '\033[0m' 8 | ORANGE = '\033[38;5;202m' 9 | BLACK = '\033[38;5;240m' 10 | 11 | 12 | def create_logger(args, training=True): 13 | from torch.utils.tensorboard import SummaryWriter 14 | """Use hyperparms to set a directory to output diagnostic files.""" 15 | 16 | arg_dict = args.__dict__ 17 | 18 | assert "logdir" in arg_dict, \ 19 | "You must provide a 'logdir' key in your command line arguments." 20 | 21 | arg_dict = OrderedDict(sorted(arg_dict.items(), key=lambda t: t[0])) 22 | # logdir = str(arg_dict.pop('logdir')) 23 | if training: 24 | output_dir = os.path.join(args.logdir, args.policy, args.task, args.savedir) 25 | else: 26 | output_dir = os.path.join(args.logdir, args.policy, args.task, args.loaddir, args.savedir) 27 | 28 | os.makedirs(output_dir, exist_ok=True) 29 | 30 | # Create a file with all the hyperparam settings in plaintext 31 | info_path = os.path.join(output_dir, "config.json") 32 | 33 | with open(info_path, 'wt') as f: 34 | json.dump(arg_dict, f, indent=4) 35 | 36 | logger = SummaryWriter(output_dir, flush_secs=0.1) 37 | print("Logging to " + color.BOLD + color.ORANGE + str(output_dir) + color.END) 38 | 39 | logger.name = args.savedir 40 | logger.dir = output_dir 41 | return logger 42 | 43 | def parse_previous(args): 44 | if args.previous is not None: 45 | run_args = pickle.load(open(args.previous + "config.json", "rb")) 46 | args.recurrent = run_args.recurrent 47 | args.env_name = run_args.env_name 48 | args.command_profile = run_args.command_profile 49 | args.input_profile = run_args.input_profile 50 | args.learn_gains = run_args.learn_gains 51 | args.traj = run_args.traj 52 | args.no_delta = run_args.no_delta 53 | args.ik_baseline = run_args.ik_baseline 54 | if args.exchange_reward is not None: 55 | args.reward = args.exchange_reward 56 | args.run_name = run_args.run_name + "_NEW-" + args.reward 57 | else: 58 | args.reward = run_args.reward 59 | args.run_name = run_args.run_name + "--cont" 60 | return args 61 | -------------------------------------------------------------------------------- /utils/logo.py: -------------------------------------------------------------------------------- 1 | class color: 2 | BOLD = '\033[1m\033[48m' 3 | END = '\033[0m' 4 | ORANGE = '\033[38;5;202m' 5 | BLACK = '\033[38;5;240m' 6 | 7 | def print_logo(subtitle=""): 8 | print(color.BOLD, end="") 9 | print(color.ORANGE, end="") 10 | print(" #@@# ") 11 | print(" #@ @# ") 12 | print("#@@# #@@# #@@# #@@# #@ @# #@@@@@@@@@@@@# #@@# ") 13 | print("#@@# #@@# #@@@# #@@@# #@ @# #@@@@@@@@@@@@@@@# #@@# ") 14 | print("#@@# #@@# #@@#@# #@#@@# #@@@@@@@@@# #@@# @@# #@@# ") 15 | print("#@@# #@@# #@@ #@# #@# @@# @# #@@# @@# #@@# ") 16 | print("#@@# #@@# #@@ #@# #@# @@# @# #@@# @@# #@@# ") 17 | print("#@@# #@@# #@@ #@# #@# @@# @# #@@# @@# #@@# ") 18 | print("#@@# #@@# #@@ #@# #@# @@# #@@@@@@@@@@@@@@@@# #@@# ") 19 | print("#@@# #@@# #@@ #@##@# @@# #@@@@@@@@@@@@@@@# #@@# ") 20 | print("#@@# #@@# #@@ #@@# @@# #@@# @@# #@@# ") 21 | print("#@@# #@@# #@@ @@# #@@# @@# #@@# ") 22 | print("#@@# #@@# #@@ @@# #@@# @@# #@@# ") 23 | print("#@@# #@@# #@@ @@# #@@# @@# #@@# ") 24 | print("#@@# #@@# #@@ @@# #@@# @@# #@@# ") 25 | print("#@@@@@@@@@@@@@@@@@# #@@@@@@@@@@@@@@@@@# #@@ @@# #@@# @@# #@@@@@@@@@@@@@@@@@#") 26 | print("#@@@@@@@@@@@@@@@@@# #@@@@@@@@@@@@@@@@@# #@@ @@# #@@# @@# #@@@@@@@@@@@@@@@@@#") 27 | 28 | 
29 | print(color.END) 30 | print(subtitle + "\n\n") 31 | -------------------------------------------------------------------------------- /utils/format.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy 4 | import re 5 | import torch 6 | import torch_ac 7 | import gymnasium as gym 8 | 9 | import utils 10 | 11 | 12 | def get_obss_preprocessor(obs_space): 13 | # Check if obs_space is an image space 14 | if isinstance(obs_space, gym.spaces.Box): 15 | obs_space = {"image": obs_space.shape} 16 | 17 | def preprocess_obss(obss, device=None): 18 | return torch_ac.DictList({ 19 | "image": preprocess_images(obss, device=device) 20 | }) 21 | 22 | # Check if it is a MiniGrid observation space 23 | elif isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces.keys(): 24 | obs_space = {"image": obs_space.spaces["image"].shape, "text": 100} 25 | 26 | vocab = Vocabulary(obs_space["text"]) 27 | 28 | def preprocess_obss(obss, device=None): 29 | return torch_ac.DictList({ 30 | "image": preprocess_images([obs["image"] for obs in obss], device=device), 31 | "text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device) 32 | }) 33 | 34 | preprocess_obss.vocab = vocab 35 | 36 | else: 37 | raise ValueError("Unknown observation space: " + str(obs_space)) 38 | 39 | return obs_space, preprocess_obss 40 | 41 | 42 | def preprocess_images(images, device=None): 43 | # Bug of Pytorch: very slow if not first converted to numpy array 44 | images = numpy.array(images) 45 | return torch.tensor(images, device=device, dtype=torch.float) 46 | 47 | 48 | def preprocess_texts(texts, vocab, device=None): 49 | var_indexed_texts = [] 50 | max_text_len = 0 51 | 52 | for text in texts: 53 | tokens = re.findall("([a-z]+)", text.lower()) 54 | var_indexed_text = numpy.array([vocab[token] for token in tokens]) 55 | var_indexed_texts.append(var_indexed_text) 56 | max_text_len = max(len(var_indexed_text), max_text_len) 57 | 58 | indexed_texts = numpy.zeros((len(texts), max_text_len)) 59 | 60 | for i, indexed_text in enumerate(var_indexed_texts): 61 | indexed_texts[i, :len(indexed_text)] = indexed_text 62 | 63 | return torch.tensor(indexed_texts, device=device, dtype=torch.long) 64 | 65 | 66 | class Vocabulary: 67 | """A mapping from tokens to ids with a capacity of `max_size` words. 
68 | It can be saved in a `vocab.json` file.""" 69 | 70 | def __init__(self, max_size): 71 | self.max_size = max_size 72 | self.vocab = {} 73 | 74 | def load_vocab(self, vocab): 75 | self.vocab = vocab 76 | 77 | def __getitem__(self, token): 78 | if not token in self.vocab.keys(): 79 | if len(self.vocab) >= self.max_size: 80 | raise ValueError("Maximum vocabulary capacity reached") 81 | self.vocab[token] = len(self.vocab) + 1 82 | return self.vocab[token] 83 | -------------------------------------------------------------------------------- /utils/env.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import numpy as np 3 | from collections import deque 4 | from gymnasium import spaces 5 | 6 | def make_env_fn(env_key, render_mode=None, frame_stack=1): 7 | def _f(): 8 | env = gym.make(env_key, render_mode=render_mode) 9 | if frame_stack > 1: 10 | env = FrameStack(env, frame_stack) 11 | return env 12 | return _f 13 | 14 | 15 | # Gives a vectorized interface to a single environment 16 | class WrapEnv: 17 | def __init__(self, env_fn): 18 | self.env = env_fn() 19 | 20 | def __getattr__(self, attr): 21 | return getattr(self.env, attr) 22 | 23 | def step(self, action): 24 | # assert action.ndim == 1 25 | env_return = self.env.step(action) 26 | if len(env_return) == 4: 27 | state, reward, terminated, info = env_return 28 | else: 29 | state, reward, terminated, truncated, info = env_return 30 | if isinstance(state, dict): 31 | state = state['image'] 32 | return np.array([state]), np.array([reward]), np.array([terminated]), np.array([info]) 33 | 34 | def render(self): 35 | self.env.render() 36 | 37 | def reset(self, seed=None): 38 | state, *_ = self.env.reset(seed=seed) 39 | if isinstance(state, tuple): 40 | ## gym state is tuple type 41 | return np.array([state[0]]) 42 | elif isinstance(state, dict): 43 | ## minigrid state is dict type 44 | return np.array([state['image']]) 45 | else: 46 | return np.array([state]) 47 | 48 | 49 | class FrameStack(gym.Wrapper): 50 | def __init__(self, env, k): 51 | """Stack k last frames. 52 | Returns lazy array, which is much more memory efficient. 
53 | See Also 54 | -------- 55 | baselines.common.atari_wrappers.LazyFrames 56 | """ 57 | gym.Wrapper.__init__(self, env) 58 | self.k = k 59 | self.frames = deque([], maxlen=k) 60 | shp = env.observation_space['image'].shape 61 | self.observation_space = spaces.Box( 62 | low=0, 63 | high=255, 64 | shape=(shp[:-1] + (shp[-1] * k,)), 65 | dtype=env.observation_space['image'].dtype) 66 | 67 | def reset(self, seed=None): 68 | ob = self.env.reset(seed=seed)[0] 69 | ob = ob['image'] 70 | for _ in range(self.k): 71 | self.frames.append(ob) 72 | return self._get_ob(), {} 73 | 74 | def step(self, action): 75 | ob, reward, terminated, truncated, info = self.env.step(action) 76 | # ob, reward, done, info = self.env.step(action) 77 | ob = ob['image'] 78 | self.frames.append(ob) 79 | return self._get_ob(), reward, terminated, truncated, info 80 | 81 | def _get_ob(self): 82 | assert len(self.frames) == self.k 83 | return np.concatenate(self.frames, axis=-1) --------------------------------------------------------------------------------
/utils/chatglm-api.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Dict, List, Any, Optional 2 | 3 | import argparse 4 | import json 5 | import logging 6 | 7 | import fastapi 8 | import httpx 9 | import uvicorn 10 | 11 | 12 | from pydantic import BaseModel 13 | 14 | class Conversation: 15 | system: str 16 | roles: List[str] 17 | messages: List[List[str]] 18 | 19 | def __init__(self, system, roles, messages): 20 | self.system = system 21 | self.roles = roles 22 | self.messages = messages 23 | 24 | conv = Conversation( 25 | system="A chat between a curious user and an artificial intelligence assistant. ", 26 | roles=["USER", "Assistant"], 27 | messages=[], 28 | ) 29 | 30 | 31 | app = fastapi.FastAPI() 32 | 33 | headers = {"User-Agent": "FastChat API Server"} 34 | 35 | 36 | 37 | class Request(BaseModel): 38 | model: str 39 | prompt: List[Dict[str, str]] 40 | top_p: Optional[float] = 0.7 41 | temperature: Optional[float] = 0.7 42 | 43 | 44 | @app.post("/v1/chat/completions") 45 | async def chat_completion(request: Request): 46 | """Creates a completion for the chat message""" 47 | conv.messages = []  # start a fresh conversation history for this request 48 | payload = generate_payload(request.prompt) 49 | content = await invoke_example(request.model, payload) 50 | 51 | generate_payload(content) 52 | 53 | 54 | print('a', content) 55 | return content[0]['content'] 56 | 57 | 58 | 59 | import zhipuai 60 | 61 | # your api key 62 | zhipuai.api_key = "your_api_key" 63 | 64 | async def invoke_example(model, prompt): 65 | 66 | response = zhipuai.model_api.invoke( 67 | model=model, 68 | prompt=prompt, 69 | top_p=0.7, 70 | temperature=0.7, 71 | ) 72 | 73 | return response['data']['choices'] 74 | 75 | 76 | 77 | def generate_payload(messages: List[Dict[str, str]]): 78 | 79 | # append the incoming messages to the shared conversation history 80 | for message in messages: 81 | 82 | msg_role = message["role"] 83 | 84 | if msg_role == "user": 85 | conv.messages.append({'role': conv.roles[0], 'content': message["content"]}) 86 | elif msg_role == "assistant": 87 | conv.messages.append({'role': conv.roles[1], 'content': message["content"]}) 88 | else: 89 | raise ValueError(f"Unknown role: {msg_role}") 90 | 91 | return conv.messages 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser(description="ChatGLM-compatible Restful API server.") 95 | parser.add_argument("--host", type=str, default="10.109.116.3", help="host name") 96 | parser.add_argument("--port", type=int, default=6000, help="port number") 97 |
98 | args = parser.parse_args() 99 | uvicorn.run("chatglm-api:app", host=args.host, port=args.port, reload=False) 100 | 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /prompt/task_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "twodoor":{ 3 | "episode": 150, 4 | "level": "easy", 5 | "description": "In a locked 2D grid room, there is an agent whose task is to open the door. The door can only be opened while agent holds the key. The agent can perform the following actions: explore, go to , pick up , drop , or open . Please clarify which object the agent has seen and is holding given the information provided. Then, please inference what the agent should do in current state. Finally, please select the correct agent action.", 6 | "example": "\n Example: \n Agent see , holds . \n 1. What agent sees and holds: Agent sees key, holds nothing, has not seen door yet. \n 2. What should the agent do: the agent should go to the key and then, pick it up. \n 3. Action: {go to , pick up } \n.", 7 | "configurations": "MiniGrid-TwoDoor-Min20-Max20" 8 | }, 9 | 10 | "simpledoorkey":{ 11 | "episode": 150, 12 | "level": "easy", 13 | "description": "In a locked 2D grid room, there is an agent whose task is to open the door. The door can only be opened while agent holds the key. The agent can perform the following actions: explore, go to , pick up , drop , or open . Please clarify which object the agent has seen and is holding given the information provided. Then, please inference what the agent should do in current state. Finally, please select the correct agent action.", 14 | "example": "\n Example: \n Agent see , holds . \n 1. What agent sees and holds: Agent sees key, holds nothing, has not seen door yet. \n 2. What should the agent do: the agent should go to the key and then, pick it up. \n 3. Action: {go to , pick up } \n.", 15 | "configurations": "MiniGrid-SimpleDoorKey-Min5-Max10-View3" 16 | }, 17 | 18 | "lavadoorkey":{ 19 | "episode": 150, 20 | "level": "hard", 21 | "description": "In a locked 2D grid room, there is an agent whose task is to open the door. The door can only be opened while agent holds the key. The agent can perform the following actions: explore, go to , pick up , drop , or open . Please clarify which object the agent has seen and is holding given the information provided. Then, please inference what the agent should do in current state. Finally, please select the correct agent action.", 22 | "example": "\n Example: \n Agent see , holds . \n 1. What agent sees and holds: Agent sees key, holds nothing, has not seen door yet. \n 2. What should the agent do: the agent should go to the key and then, pick it up. \n 3. Action: {go to , pick up } \n.", 23 | "configurations": "MiniGrid-LavaDoorKey-Min5-Max10-View3" 24 | }, 25 | 26 | "coloreddoorkey":{ 27 | "episode": 150, 28 | "level": "medium", 29 | "description": "In a locked 2D grid room, agent can only open door while holding a key that matches color of door. Agent can perform following actions: explore, go to , pick up , drop , or open . Agent can only hold one object. Please clarify which object agent sees and holds given information provided. Then, please inference what agent can do in current state. Finally, please select the correct actions.", 30 | "example": "\n Example: \n Agent sees , holds . \n 1. What agent sees and holds: Agent sees key, holds nothing, has not seen door yet. \n 2. 
What should agent do: Agent should first go to key, and then pick up key. \n 3. Action: {go to , pick up } \n.", 31 | "configurations": "MiniGrid-ColoredDoorKey-Min5-Max10-View3" 32 | } 33 | 34 | } -------------------------------------------------------------------------------- /env/doorkey.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from minigrid.core.constants import COLOR_NAMES 3 | from minigrid.core.grid import Grid 4 | from minigrid.core.mission import MissionSpace 5 | from minigrid.core.world_object import Door, Key, Wall 6 | from .historicalobs import HistoricalObsEnv 7 | 8 | 9 | class DoorKeyEnv(HistoricalObsEnv): 10 | def __init__( 11 | self, 12 | minRoomSize: int = 5, 13 | maxRoomSize: int = 10, 14 | agent_view_size: int = 7, 15 | max_steps: int | None = None, 16 | **kwargs, 17 | ): 18 | 19 | self.minRoomSize = minRoomSize 20 | self.maxRoomSize = maxRoomSize 21 | 22 | mission_space = MissionSpace(mission_func=self._gen_mission) 23 | 24 | if max_steps is None: 25 | max_steps = maxRoomSize ** 2 26 | 27 | super().__init__( 28 | mission_space=mission_space, 29 | width=maxRoomSize, 30 | height=maxRoomSize, 31 | agent_view_size=agent_view_size, 32 | max_steps=max_steps, 33 | **kwargs, 34 | ) 35 | 36 | @staticmethod 37 | def _gen_mission(): 38 | return "open the door" 39 | 40 | def _gen_grid(self, width, height): 41 | 42 | # Create the grid 43 | self.grid = Grid(width, height) 44 | 45 | # Choose the room size randomly 46 | sizeX = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 47 | sizeY = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 48 | topX, topY = 0, 0 49 | 50 | # Draw the top and bottom walls 51 | wall = Wall() 52 | for i in range(0, sizeX): 53 | self.grid.set(topX + i, topY, wall) 54 | self.grid.set(topX + i, topY + sizeY - 1, wall) 55 | 56 | # Draw the left and right walls 57 | for j in range(0, sizeY): 58 | self.grid.set(topX, topY + j, wall) 59 | self.grid.set(topX + sizeX - 1, topY + j, wall) 60 | 61 | # Pick which wall to place the out door on 62 | wallSet = {0, 1, 2, 3} 63 | exitDoorWall = self._rand_elem(sorted(wallSet)) 64 | 65 | # Pick the exit door position 66 | # Exit on right wall 67 | if exitDoorWall == 0: 68 | exitDoorPos = (topX + sizeX - 1, topY + self._rand_int(1, sizeY - 1)) 69 | # Exit on south wall 70 | elif exitDoorWall == 1: 71 | exitDoorPos = (topX + self._rand_int(1, sizeX - 1), topY + sizeY - 1) 72 | # Exit on left wall 73 | elif exitDoorWall == 2: 74 | exitDoorPos = (topX, topY + self._rand_int(1, sizeY - 1)) 75 | # Exit on north wall 76 | elif exitDoorWall == 3: 77 | exitDoorPos = (topX + self._rand_int(1, sizeX - 1), topY) 78 | else: 79 | assert False 80 | 81 | # Place the door 82 | doorColor = self._rand_elem(sorted(set(COLOR_NAMES))) 83 | exitDoor = Door(doorColor, is_locked=True) 84 | self.door = exitDoor 85 | self.grid.set(exitDoorPos[0], exitDoorPos[1], exitDoor) 86 | 87 | # Randomize the starting agent position and direction 88 | self.place_agent((topX, topY), (sizeX, sizeY), rand_dir=True) 89 | 90 | # Randomize the key position 91 | key = Key(doorColor) 92 | self.place_obj(key, (topX, topY), (sizeX, sizeY)) 93 | 94 | self.mission = "open the door" 95 | 96 | def step(self, action): 97 | obs, reward, terminated, truncated, info = super().step(action) 98 | 99 | if action == self.actions.toggle: 100 | if self.door.is_open: 101 | reward = self._reward() 102 | terminated = True 103 | 104 | return obs, reward, terminated, truncated, info 105 | 106 | 
-------------------------------------------------------------------------------- /env/twodoor.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from minigrid.core.constants import COLOR_NAMES 3 | from minigrid.core.grid import Grid 4 | from minigrid.core.mission import MissionSpace 5 | from minigrid.core.world_object import Door, Key, Wall 6 | from .historicalobs import HistoricalObsEnv 7 | 8 | 9 | class TwoDoorEnv(HistoricalObsEnv): 10 | def __init__( 11 | self, 12 | minRoomSize: int = 20, 13 | maxRoomSize: int = 20, 14 | agent_view_size: int = 7, 15 | max_steps: int | None = None, 16 | **kwargs, 17 | ): 18 | 19 | self.minRoomSize = minRoomSize 20 | self.maxRoomSize = maxRoomSize 21 | 22 | mission_space = MissionSpace(mission_func=self._gen_mission) 23 | 24 | if max_steps is None: 25 | max_steps = maxRoomSize ** 2 26 | 27 | super().__init__( 28 | mission_space=mission_space, 29 | width=maxRoomSize, 30 | height=maxRoomSize, 31 | agent_view_size=agent_view_size, 32 | max_steps=max_steps, 33 | **kwargs, 34 | ) 35 | 36 | @staticmethod 37 | def _gen_mission(): 38 | return "open the door" 39 | 40 | def _gen_grid(self, width, height): 41 | 42 | # Create the grid 43 | self.grid = Grid(width, height) 44 | 45 | # Choose the room size randomly 46 | sizeX = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 47 | sizeY = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 48 | topX, topY = 0, 0 49 | 50 | # Draw the top and bottom walls 51 | wall = Wall() 52 | for i in range(0, sizeX): 53 | self.grid.set(topX + i, topY, wall) 54 | self.grid.set(topX + i, topY + sizeY - 1, wall) 55 | 56 | # Draw the left and right walls 57 | for j in range(0, sizeY): 58 | self.grid.set(topX, topY + j, wall) 59 | self.grid.set(topX + sizeX - 1, topY + j, wall) 60 | 61 | # Pick which wall to place the out door on 62 | wallSet = {0, 1} 63 | exitDoorWall = self._rand_elem(sorted(wallSet)) 64 | 65 | # Pick the exit door position 66 | # Exit on right and left wall 67 | if exitDoorWall == 0: 68 | exitDoorPos = [(topX + sizeX - 1, topY + self._rand_int(1, sizeY - 1)), 69 | (topX, topY + self._rand_int(1, sizeY - 1))] 70 | # Exit on south and north wall 71 | elif exitDoorWall == 1: 72 | exitDoorPos = [(topX + self._rand_int(1, sizeX - 1), topY + sizeY - 1), 73 | (topX + self._rand_int(1, sizeX - 1), topY)] 74 | else: 75 | assert False 76 | 77 | # Place the door 78 | doorColor = self._rand_elem(sorted(set(COLOR_NAMES))) 79 | self.doors = [] 80 | for i in range(2): 81 | exitDoor = Door(doorColor, is_locked=True) 82 | self.doors.append(exitDoor) 83 | self.grid.set(exitDoorPos[i][0], exitDoorPos[i][1], exitDoor) 84 | 85 | # Randomize the key position 86 | key = Key(doorColor) 87 | self.place_obj(key, (sizeX//2-1, sizeY//2-1), (2, 2)) 88 | 89 | # Randomize the starting agent position and direction 90 | self.place_agent((topX, topY), (sizeX, sizeY), rand_dir=True) 91 | 92 | self.mission = "open the door" 93 | 94 | def step(self, action): 95 | obs, reward, terminated, truncated, info = super().step(action) 96 | 97 | if action == self.actions.toggle: 98 | if self.doors[0].is_open or self.doors[1].is_open: 99 | reward = self._reward() 100 | terminated = True 101 | 102 | return obs, reward, terminated, truncated, info 103 | 104 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | 
*.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | 148 | # pytype static type analyzer 149 | .pytype/ 150 | 151 | # Cython debug symbols 152 | cython_debug/ 153 | 154 | # PyCharm 155 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 156 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 157 | # and can be added to the global gitignore or merged into this file. For a more nuclear 158 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 159 | #.idea/ 160 | --------------------------------------------------------------------------------
/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, json, sys 3 | import numpy as np 4 | # single gpu 5 | 6 | os.system('nvidia-smi -q -d Memory | grep -A5 GPU | grep Free > tmp.txt') 7 | memory_gpu = [int(x.split()[2]) for x in open('tmp.txt', 'r').readlines()] 8 | os.environ["CUDA_VISIBLE_DEVICES"] = str(np.argmax(memory_gpu)) 9 | os.system('rm tmp.txt') 10 | 11 | import torch 12 | import utils 13 | from Game import Game 14 | 15 | def train(args): 16 | base_savedir = args.savedir  # keep the user-supplied name; build a per-seed directory from it each iteration 17 | for i in args.seed_list: 18 | args.savedir = base_savedir + "-" + str(i) 19 | args.seed = i 20 | game = Game(args) 21 | game.train() 22 | def evaluate(args): 23 | assert args.loaddir 24 | print("env name: %s for %s" %(args.task, args.loaddir)) 25 | args.seed = args.seed_list[0] 26 | game = Game(args, training=False) 27 | eval_returns = [] 28 | eval_lens = [] 29 | eval_success = [] 30 | 31 | if len(args.env_seed_list) == 0: 32 | env_seed_list = [None] * args.num_eval 33 | elif len(args.env_seed_list) == 1: 34 | env_seed_list = [args.env_seed_list[0] + i for i in range(args.num_eval)] 35 | else: 36 | env_seed_list = args.env_seed_list 37 | 38 | for i in env_seed_list: 39 | eval_outputs = game.evaluate(seed = i, teacher_policy = args.eval_teacher) 40 | eval_returns.append(eval_outputs[0]) 41 | eval_lens.append(eval_outputs[1]) 42 | eval_success.append(eval_outputs[2]) 43 | 44 | print("Mean return:", np.mean(eval_returns)) 45 | print("Mean length:", np.mean(eval_lens)) 46 | print("Success rate:", np.mean(eval_success)) 47 | 48 | 49 | if __name__ == "__main__": 50 | utils.print_logo(subtitle="Maintained by Research Center for Applied Mathematics and Machine Intelligence, Zhejiang Lab") 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--task", type=str, default="SimpleDoorKey", help="SimpleDoorKey, KeyInBox, RandomBoxKey, ColoredDoorKey, DynamicDoorKey") 53 | 54 | # parser.add_argument("--env_seed", type=int, default=0) 55 | parser.add_argument("--env_seed_list", type=int, nargs="*", default=[0], help="Seeds for evaluation environments") 56 | parser.add_argument("--seed_list", type=int, nargs="*", default=[0], help="Seeds for Numpy, Torch and LLM") 57 | parser.add_argument("--frame_stack", type=int, default=1) 58 | parser.add_argument("--device",
type=str, default="cuda" if torch.cuda.is_available() else "cpu") 59 | parser.add_argument("--policy", type=str, default='ppo') 60 | parser.add_argument("--n_itr", type=int, default=20000, help="Number of iterations of the learning algorithm") 61 | parser.add_argument("--traj_per_itr", type=int, default=10) 62 | parser.add_argument("--batch_size", type=int, default=128) 63 | parser.add_argument("--lam", type=float, default=0.95, help="Generalized advantage estimate discount") 64 | parser.add_argument("--gamma", type=float, default=0.99, help="MDP discount") 65 | parser.add_argument("--recurrent", default=False, action='store_true') 66 | 67 | parser.add_argument("--logdir", type=str, default="log") # Where to log diagnostics to 68 | parser.add_argument("--loaddir", type=str, default=None) 69 | parser.add_argument("--loadmodel", type=str, default="acmodel") 70 | parser.add_argument("--savedir", type=str, required=True, help="path to folder containing policy and run details") 71 | 72 | parser.add_argument("--offline_planner", default=False, action='store_true') 73 | parser.add_argument("--soft_planner", default=False, action='store_true') 74 | parser.add_argument("--eval_teacher", default=False, action='store_true') 75 | parser.add_argument("--num_eval", type=int, default=10) 76 | parser.add_argument("--eval_interval", type=int, default=10) 77 | parser.add_argument("--save_interval", type=int, default=100) 78 | 79 | if sys.argv[1] == 'eval': 80 | sys.argv.remove(sys.argv[1]) 81 | args = parser.parse_args() 82 | evaluate(args) 83 | elif sys.argv[1] == 'train': 84 | sys.argv.remove(sys.argv[1]) 85 | args = parser.parse_args() 86 | train(args) 87 | else: 88 | print("Invalid option '{}'".format(sys.argv[1])) -------------------------------------------------------------------------------- /env/lavadoorkey.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from minigrid.core.constants import COLOR_NAMES 3 | from minigrid.core.grid import Grid 4 | from minigrid.core.mission import MissionSpace 5 | from minigrid.core.world_object import Door, Key, Wall, Lava 6 | from .historicalobs import HistoricalObsEnv 7 | 8 | 9 | class LavaDoorKeyEnv(HistoricalObsEnv): 10 | def __init__( 11 | self, 12 | minRoomSize: int = 5, 13 | maxRoomSize: int = 10, 14 | agent_view_size: int = 7, 15 | max_steps: int | None = None, 16 | **kwargs, 17 | ): 18 | 19 | self.minRoomSize = minRoomSize 20 | self.maxRoomSize = maxRoomSize 21 | 22 | mission_space = MissionSpace(mission_func=self._gen_mission) 23 | 24 | if max_steps is None: 25 | max_steps = maxRoomSize ** 2 26 | 27 | super().__init__( 28 | mission_space=mission_space, 29 | width=maxRoomSize, 30 | height=maxRoomSize, 31 | agent_view_size=agent_view_size, 32 | max_steps=max_steps, 33 | **kwargs, 34 | ) 35 | 36 | @staticmethod 37 | def _gen_mission(): 38 | return "open the door" 39 | 40 | def _gen_grid(self, width, height): 41 | 42 | # Create the grid 43 | self.grid = Grid(width, height) 44 | 45 | # Choose the room size randomly 46 | sizeX = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 47 | sizeY = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 48 | topX, topY = 0, 0 49 | 50 | # Draw the top and bottom walls 51 | wall = Wall() 52 | for i in range(0, sizeX): 53 | self.grid.set(topX + i, topY, wall) 54 | self.grid.set(topX + i, topY + sizeY - 1, wall) 55 | 56 | # Draw the left and right walls 57 | for j in range(0, sizeY): 58 | self.grid.set(topX, topY + j, wall) 59 | self.grid.set(topX + 
sizeX - 1, topY + j, wall) 60 | 61 | # Pick which wall to place the out door on 62 | wallSet = {0, 1, 2, 3} 63 | exitDoorWall = self._rand_elem(sorted(wallSet)) 64 | 65 | # Pick the exit door position 66 | # Exit on right wall 67 | if exitDoorWall == 0: 68 | rand_int = self._rand_int(1, sizeY - 1) 69 | exitDoorPos = (topX + sizeX - 1, topY + rand_int) 70 | rejectPos = (topX + sizeX - 2, topY + rand_int) 71 | # Exit on south wall 72 | elif exitDoorWall == 1: 73 | rand_int = self._rand_int(1, sizeX - 1) 74 | exitDoorPos = (topX + rand_int, topY + sizeY - 1) 75 | rejectPos = (topX + rand_int, topY + sizeY - 2) 76 | # Exit on left wall 77 | elif exitDoorWall == 2: 78 | rand_int = self._rand_int(1, sizeY - 1) 79 | exitDoorPos = (topX, topY + rand_int) 80 | rejectPos = (topX + 1, topY + rand_int) 81 | # Exit on north wall 82 | elif exitDoorWall == 3: 83 | rand_int = self._rand_int(1, sizeX - 1) 84 | exitDoorPos = (topX + rand_int, topY) 85 | rejectPos = (topX + rand_int, topY + 1) 86 | else: 87 | assert False 88 | 89 | # Place the door 90 | doorColor = self._rand_elem(sorted(set(COLOR_NAMES))) 91 | exitDoor = Door(doorColor, is_locked=True) 92 | self.door = exitDoor 93 | self.grid.set(exitDoorPos[0], exitDoorPos[1], exitDoor) 94 | 95 | # Randomize the starting agent position and direction 96 | self.place_agent((topX, topY), (sizeX, sizeY), rand_dir=True) 97 | 98 | # Randomize the key position 99 | key = Key(doorColor) 100 | self.place_obj(key, (topX, topY), (sizeX, sizeY)) 101 | 102 | # Randomize the lava position 103 | reject_fn = lambda env, pos: pos == rejectPos 104 | lava = Lava() 105 | self.place_obj(lava, (topX, topY), (sizeX, sizeY), reject_fn = reject_fn) 106 | 107 | self.mission = "open the door" 108 | 109 | def step(self, action): 110 | obs, reward, terminated, truncated, info = super().step(action) 111 | 112 | if action == self.actions.toggle: 113 | if self.door.is_open: 114 | reward = self._reward() 115 | terminated = True 116 | 117 | return obs, reward, terminated, truncated, info 118 | 119 | -------------------------------------------------------------------------------- /env/coloreddoorkey.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from minigrid.core.grid import Grid 3 | from minigrid.core.mission import MissionSpace 4 | from minigrid.core.world_object import Door, Key, Wall 5 | from .historicalobs import HistoricalObsEnv 6 | 7 | 8 | class ColoredDoorKeyEnv(HistoricalObsEnv): 9 | def __init__( 10 | self, 11 | minRoomSize: int = 5, 12 | maxRoomSize: int = 10, 13 | minNumKeys: int = 2, 14 | maxNumKeys: int = 3, 15 | agent_view_size: int = 7, 16 | max_steps: int | None = None, 17 | **kwargs, 18 | ): 19 | 20 | assert minNumKeys > 0 21 | self.minRoomSize = minRoomSize 22 | self.maxRoomSize = maxRoomSize 23 | self.minNumKeys = minNumKeys 24 | self.maxNumKeys = maxNumKeys 25 | 26 | mission_space = MissionSpace(mission_func=self._gen_mission) 27 | 28 | if max_steps is None: 29 | max_steps = maxRoomSize ** 2 30 | 31 | super().__init__( 32 | mission_space=mission_space, 33 | width=maxRoomSize, 34 | height=maxRoomSize, 35 | agent_view_size=agent_view_size, 36 | max_steps=max_steps, 37 | **kwargs, 38 | ) 39 | 40 | @staticmethod 41 | def _gen_mission(): 42 | return "open the door" 43 | 44 | def _gen_grid(self, width, height): 45 | 46 | # Create the grid 47 | self.grid = Grid(width, height) 48 | 49 | # Choose the room size randomly 50 | sizeX = self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 51 | sizeY = 
self._rand_int(self.minRoomSize, self.maxRoomSize + 1) 52 | topX, topY = 0, 0 53 | 54 | # Draw the top and bottom walls 55 | wall = Wall() 56 | for i in range(0, sizeX): 57 | self.grid.set(topX + i, topY, wall) 58 | self.grid.set(topX + i, topY + sizeY - 1, wall) 59 | 60 | # Draw the left and right walls 61 | for j in range(0, sizeY): 62 | self.grid.set(topX, topY + j, wall) 63 | self.grid.set(topX + sizeX - 1, topY + j, wall) 64 | 65 | # Pick which wall to place the out door on 66 | wallSet = {0, 1, 2, 3} 67 | exitDoorWall = self._rand_elem(sorted(wallSet)) 68 | 69 | # Pick the exit door position 70 | # Exit on right wall 71 | if exitDoorWall == 0: 72 | rand_int = self._rand_int(1, sizeY - 1) 73 | exitDoorPos = (topX + sizeX - 1, topY + rand_int) 74 | rejectPos = (topX + sizeX - 2, topY + rand_int) 75 | # Exit on south wall 76 | elif exitDoorWall == 1: 77 | rand_int = self._rand_int(1, sizeX - 1) 78 | exitDoorPos = (topX + rand_int, topY + sizeY - 1) 79 | rejectPos = (topX + rand_int, topY + sizeY - 2) 80 | # Exit on left wall 81 | elif exitDoorWall == 2: 82 | rand_int = self._rand_int(1, sizeY - 1) 83 | exitDoorPos = (topX, topY + rand_int) 84 | rejectPos = (topX + 1, topY + rand_int) 85 | # Exit on north wall 86 | elif exitDoorWall == 3: 87 | rand_int = self._rand_int(1, sizeX - 1) 88 | exitDoorPos = (topX + rand_int, topY) 89 | rejectPos = (topX + rand_int, topY + 1) 90 | else: 91 | assert False 92 | 93 | # Place the door 94 | doorColor = 'blue' 95 | exitDoor = Door(doorColor, is_locked=True) 96 | self.door = exitDoor 97 | self.grid.set(exitDoorPos[0], exitDoorPos[1], exitDoor) 98 | 99 | # Randomize the starting agent position and direction 100 | self.place_agent((topX, topY), (sizeX, sizeY), rand_dir=True) 101 | 102 | # Randomize the key position 103 | reject_fn = lambda env, pos: pos == rejectPos 104 | numKeys = self._rand_int(self.minNumKeys, self.maxNumKeys + 1) 105 | key = Key(doorColor) 106 | self.place_obj(key, (topX, topY), (sizeX, sizeY), reject_fn = reject_fn) 107 | keyColor = 'red' 108 | for i in range(numKeys - 1): 109 | key = Key(keyColor) 110 | self.place_obj(key, (topX, topY), (sizeX, sizeY), reject_fn = reject_fn) 111 | 112 | self.mission = "open the door" 113 | 114 | def step(self, action): 115 | obs, reward, terminated, truncated, info = super().step(action) 116 | 117 | if action == self.actions.toggle: 118 | if self.door.is_open: 119 | reward = self._reward() 120 | terminated = True 121 | 122 | return obs, reward, terminated, truncated, info -------------------------------------------------------------------------------- /skill_multi_step/goto_goal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from .base_skill import BaseSkill 7 | 8 | ''' 9 | OBJECT_TO_IDX = { 10 | "unseen": 0, 11 | "empty": 1, 12 | "wall": 2, 13 | "floor": 3, 14 | "door": 4, 15 | "key": 5, 16 | "ball": 6, 17 | "box": 7, 18 | "goal": 8, 19 | "lava": 9, 20 | "agent": 10, 21 | } 22 | 23 | STATE_TO_IDX = { 24 | "open": 0, 25 | "closed": 1, 26 | "locked": 2, 27 | } 28 | ''' 29 | 30 | DIRECTION = { 31 | 0: [1, 0], 32 | 1: [0, 1], 33 | 2: [-1, 0], 34 | 3: [0, -1], 35 | } 36 | 37 | 38 | def check_go_through(pos, maps): 39 | x, y = pos 40 | width, height, _ = maps.shape 41 | if x<0 or x>=width or y<0 or y>=height: 42 | return False 43 | # return (maps[x, y, 0] in [1, 8] or (maps[x, y, 0] == 4 and maps[x, y, 2]==0) ) 44 | return (maps[x, y, 0] in [1, 8, 
9] or (maps[x, y, 0] == 4 and maps[x, y, 2]==0) ) 45 | 46 | def get_neighbors(pos_and_dir, maps): 47 | x, y, direction = pos_and_dir 48 | next_dir_left = direction - 1 if direction > 0 else 3 49 | next_dir_right = direction + 1 if direction < 3 else 0 50 | neighbor_list = [(x,y,next_dir_left), (x,y,next_dir_right)] 51 | forward_x, forward_y = DIRECTION[direction] 52 | new_x,new_y = (x+forward_x, y+forward_y) 53 | 54 | if check_go_through((new_x,new_y), maps): 55 | neighbor_list.append((new_x, new_y, direction)) 56 | 57 | 58 | assert not len(neighbor_list)==0 59 | 60 | return neighbor_list 61 | 62 | 63 | class GoTo_Goal(BaseSkill): 64 | def __init__(self, init_obs, target_pos): 65 | self.obs = init_obs[:, :, -4:] 66 | self.unpack_obs(self.obs) 67 | self.path = self.plan(target_pos) 68 | 69 | def plan(self, target_pos): 70 | start_node = (self.agent_pos[0], self.agent_pos[1], self.agent_dir) 71 | x, y = target_pos 72 | target_pos_and_dir = [(x-1, y, 0), (x, y-1, 1), (x+1, y, 2), (x, y+1, 3)] 73 | 74 | open_list = set([start_node]) 75 | closed_list = set([]) 76 | 77 | g = {} 78 | g[start_node] = 0 79 | 80 | parents = {} 81 | parents[start_node] = start_node 82 | 83 | while len(open_list) > 0: 84 | n = None 85 | 86 | for v in open_list: 87 | if n is None or g[v] < g[n]: 88 | n = v 89 | 90 | if n == None: 91 | assert False, "no action found" 92 | return None 93 | 94 | ### reconstruct and return the path when the node is the goal position 95 | if n in target_pos_and_dir: 96 | reconst_path = [] 97 | while parents[n] != n: 98 | reconst_path.append(n) 99 | n = parents[n] 100 | 101 | reconst_path.append(start_node) 102 | reconst_path.reverse() 103 | return reconst_path 104 | 105 | for m in get_neighbors(n, self.obs): 106 | if m not in open_list and m not in closed_list: 107 | open_list.add(m) 108 | parents[m] = n 109 | g[m] = g[n] + 1 110 | 111 | else: 112 | if g[m] > g[n]+1: 113 | g[m] = g[n]+1 114 | parents[m] = n 115 | 116 | if m in closed_list: 117 | closed_list.remove(m) 118 | open_list.add(m) 119 | 120 | open_list.remove(n) 121 | closed_list.add(n) 122 | 123 | # print(start_node, target_pos) 124 | # print(self.map) 125 | # print("no action found") 126 | return [[(None, None, 6)]] 127 | 128 | def __call__(self, can_truncate): 129 | if len(self.path) == 1: 130 | return None, True, False 131 | # return 6, True, False 132 | else: 133 | cur_dir = self.path[0][2] 134 | next_dir = self.path[1][2] 135 | angle = (cur_dir - next_dir) % 4 136 | if angle == 1: 137 | action = 0 138 | elif angle == 3: 139 | action = 1 140 | elif angle == 0: 141 | action = 2 142 | else: 143 | assert False, "'wrong path: cannot trun twice in a step!'" 144 | 145 | return action, False, False -------------------------------------------------------------------------------- /skill/goto_goal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from .base_skill import BaseSkill 7 | 8 | ''' 9 | OBJECT_TO_IDX = { 10 | "unseen": 0, 11 | "empty": 1, 12 | "wall": 2, 13 | "floor": 3, 14 | "door": 4, 15 | "key": 5, 16 | "ball": 6, 17 | "box": 7, 18 | "goal": 8, 19 | "lava": 9, 20 | "agent": 10, 21 | } 22 | 23 | STATE_TO_IDX = { 24 | "open": 0, 25 | "closed": 1, 26 | "locked": 2, 27 | } 28 | ''' 29 | 30 | DIRECTION = { 31 | 0: [1, 0], 32 | 1: [0, 1], 33 | 2: [-1, 0], 34 | 3: [0, -1], 35 | } 36 | 37 | 38 | def check_go_through(pos, maps): 39 | x, y = pos 40 | width, height, _ = maps.shape 41 | 
if x<0 or x>=width or y<0 or y>=height: 42 | return False 43 | # return (maps[x, y, 0] in [1, 8] or (maps[x, y, 0] == 4 and maps[x, y, 2]==0) ) 44 | return (maps[x, y, 0] in [1, 8, 9] or (maps[x, y, 0] == 4 and maps[x, y, 2]==0) ) 45 | 46 | def get_neighbors(pos_and_dir, maps): 47 | x, y, direction = pos_and_dir 48 | next_dir_left = direction - 1 if direction > 0 else 3 49 | next_dir_right = direction + 1 if direction < 3 else 0 50 | neighbor_list = [(x,y,next_dir_left), (x,y,next_dir_right)] 51 | forward_x, forward_y = DIRECTION[direction] 52 | new_x,new_y = (x+forward_x, y+forward_y) 53 | 54 | if check_go_through((new_x,new_y), maps): 55 | neighbor_list.append((new_x, new_y, direction)) 56 | 57 | 58 | assert not len(neighbor_list)==0 59 | 60 | return neighbor_list 61 | 62 | 63 | class GoTo_Goal(BaseSkill): 64 | def __init__(self, target_pos): 65 | self.target_pos = target_pos 66 | self.message = "none" 67 | 68 | def plan(self): 69 | start_node = (self.agent_pos[0], self.agent_pos[1], self.agent_dir) 70 | x, y = self.target_pos 71 | target_pos_and_dir = [(x-1, y, 0), (x, y-1, 1), (x+1, y, 2), (x, y+1, 3)] 72 | 73 | open_list = set([start_node]) 74 | closed_list = set([]) 75 | 76 | g = {} 77 | g[start_node] = 0 78 | 79 | parents = {} 80 | parents[start_node] = start_node 81 | 82 | while len(open_list) > 0: 83 | n = None 84 | 85 | for v in open_list: 86 | if n is None or g[v] < g[n]: 87 | n = v 88 | 89 | if n == None: 90 | self.message = "no action found" 91 | return None 92 | 93 | ### reconstruct and return the path when the node is the goal position 94 | if n in target_pos_and_dir: 95 | reconst_path = [] 96 | while parents[n] != n: 97 | reconst_path.append(n) 98 | n = parents[n] 99 | 100 | reconst_path.append(start_node) 101 | reconst_path.reverse() 102 | return reconst_path 103 | 104 | for m in get_neighbors(n, self.obs): 105 | if m not in open_list and m not in closed_list: 106 | open_list.add(m) 107 | parents[m] = n 108 | g[m] = g[n] + 1 109 | 110 | else: 111 | if g[m] > g[n]+1: 112 | g[m] = g[n]+1 113 | parents[m] = n 114 | 115 | if m in closed_list: 116 | closed_list.remove(m) 117 | open_list.add(m) 118 | 119 | open_list.remove(n) 120 | closed_list.add(n) 121 | 122 | # print(start_node, self.target_pos) 123 | # print(self.map) 124 | # print("no action found") 125 | # return [[(None, None, 6)]] 126 | self.message = "no action found" 127 | return None 128 | 129 | def __call__(self, obs): 130 | self.unpack_obs(obs) 131 | path = self.plan() 132 | 133 | if not path: 134 | return None, False 135 | elif len(path) == 1: 136 | return None, True 137 | else: 138 | cur_dir = path[0][2] 139 | next_dir = path[1][2] 140 | angle = (cur_dir - next_dir) % 4 141 | if angle == 1: 142 | action = 0 143 | elif angle == 3: 144 | action = 1 145 | elif angle == 0: 146 | action = 2 147 | else: 148 | assert False, "'wrong path: cannot trun twice in a step!'" 149 | 150 | return action, False -------------------------------------------------------------------------------- /algos/ppo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : ppo.py 5 | @Time : 2023/07/14 11:23:57 6 | @Author : Zhou Zihao 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | from .base import Base 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | class PPO(Base): 17 | 18 | def __init__(self, 19 | model, 20 | obs_space, 21 | action_space, 22 | device, 23 | save_path, 24 | recurrent=False, 25 | 
lr=0.001, 26 | max_grad_norm=0.5, 27 | adam_eps=1e-8, 28 | clip_eps=0.2, 29 | entropy_coef=0.001, 30 | kickstarting_coef_initial=10., 31 | kickstarting_coef_decent=0.005, 32 | kickstarting_coef_minimum=0.1, 33 | iter_with_ks=3000, 34 | value_loss_coef=.5, 35 | batch_size=128, 36 | num_worker=4, 37 | epoch=3, 38 | ): 39 | # model 40 | super().__init__(model, obs_space, action_space, device, save_path, recurrent) 41 | 42 | # optimizer 43 | self.lr = lr 44 | self.grad_clip = max_grad_norm 45 | self.eps = adam_eps 46 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr = self.lr, eps=self.eps) 47 | 48 | # loss 49 | self.clip = clip_eps 50 | self.entropy_coef = entropy_coef 51 | self.iter_with_ks = iter_with_ks 52 | self.ks_coef = kickstarting_coef_initial 53 | self.ks_coef_minimum = kickstarting_coef_minimum 54 | self.ks_coef_descent = kickstarting_coef_decent 55 | self.value_loss_coef = value_loss_coef 56 | 57 | # other settings 58 | self.batch_size = batch_size 59 | self.epochs = epoch 60 | self.num_worker = num_worker 61 | self.iter = 0 62 | 63 | def __call__(self, obs, mask, states): 64 | return self.model(obs, mask, states) 65 | 66 | def update_kickstarting_coef(self): 67 | self.iter += 1 68 | if self.ks_coef <= self.ks_coef_minimum: 69 | self.ks_coef = self.ks_coef_minimum 70 | else: 71 | self.ks_coef -= self.ks_coef_descent 72 | # if self.iter == self.iter_with_ks: 73 | # self.optimizer = torch.optim.Adam(self.model.parameters(), lr = self.lr, eps=self.eps) 74 | # self.model.actor[-1].reset_parameters() 75 | 76 | def update_policy(self, buffer): 77 | losses = [] 78 | 79 | for _ in range(self.epochs): 80 | for batch in buffer.sample(self.batch_size, self.recurrent): 81 | obs_batch, action_batch, return_batch, advantage_batch, values_batch, mask, log_prob_batch, teacher_prob_batch = batch 82 | 83 | # get policy 84 | states = self.model.init_states(self.device, obs_batch.size()[1]) if self.recurrent else None 85 | pdf, value, _ = self.model(obs_batch, mask, states) 86 | 87 | # update 88 | entropy_loss = (pdf.entropy() * mask).mean() 89 | kickstarting_loss = -(pdf.logits * teacher_prob_batch).sum(dim=-1).mean() 90 | 91 | ratio = torch.exp(pdf.log_prob(action_batch) - log_prob_batch) 92 | surr1 = ratio * advantage_batch * mask 93 | surr2 = torch.clamp(ratio, 1.0 - self.clip, 1.0 + self.clip) * advantage_batch * mask 94 | policy_loss = -torch.min(surr1, surr2).mean() 95 | 96 | value_clipped = values_batch + torch.clamp(value - values_batch, -self.clip, self.clip) 97 | surr1 = ((value - return_batch)*mask).pow(2) 98 | surr2 = ((value_clipped - return_batch)*mask).pow(2) 99 | value_loss = torch.max(surr1, surr2).mean() 100 | 101 | loss = policy_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_loss 102 | if self.iter < self.iter_with_ks: 103 | loss += self.ks_coef * kickstarting_loss 104 | 105 | # Update actor-critic 106 | self.optimizer.zero_grad() 107 | loss.backward() 108 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) 109 | self.optimizer.step() 110 | 111 | # Update logger 112 | losses.append([loss.item(), entropy_loss.item(), kickstarting_loss.item(), policy_loss.item(), value_loss.item()]) 113 | 114 | if self.iter < self.iter_with_ks: 115 | # Update kickstarting coefficient 116 | self.update_kickstarting_coef() 117 | 118 | mean_losses = np.mean(losses, axis=0) 119 | return mean_losses 120 | -------------------------------------------------------------------------------- /Readme.md: 
-------------------------------------------------------------------------------- 1 | # [Large Language Model as a Policy Teacher for Training Reinforcement Learning Agents](https://arxiv.org/abs/2311.13373) 2 | 3 | ## Abstract 4 | Recent studies have uncovered the potential of Large Language Models (LLMs) in addressing complex sequential decision-making tasks through the provision of high-level instructions. However, LLM-based agents lack specialization in tackling specific target problems, particularly in real-time dynamic environments. Additionally, deploying an LLM-based agent in practical scenarios can be both costly and time-consuming. On the other hand, reinforcement learning (RL) approaches train agents that specialize in the target task but often suffer from low sampling efficiency and high exploration costs. In this paper, we introduce a novel framework that addresses these challenges by training a smaller, specialized student RL agent using instructions from an LLM-based teacher agent. By incorporating the guidance from the teacher agent, the student agent can distill the prior knowledge of the LLM into its own model. Consequently, the student agent can be trained with significantly less data. Moreover, through further training with environment feedback, the student agent surpasses the capabilities of its teacher for completing the target task. We conducted experiments on challenging MiniGrid and Habitat environments, specifically designed for embodied AI research, to evaluate the effectiveness of our framework. The results clearly demonstrate that our approach achieves superior performance compared to strong baseline methods. Our code is available at https://github.com/ZJLAB-AMMI/LLM4Teach. 5 | 6 | ## Purpose 7 | This repo is intended to serve as a foundation with which you can reproduce the results of the experiments detailed in our paper, [Large Language Model as a Policy Teacher for Training Reinforcement Learning Agents](https://arxiv.org/abs/2311.13373). 8 | 9 | 10 | ## Running experiments 11 | ### Setup the LLMs 12 | 13 | 1. For ChatGLM models, please use your own api_key and run the following code to launch the API 14 | ```bash 15 | python3 -m utils.chatglm_api --host --port 16 | ``` 17 | 18 | 2. For Vicuna models, please follow the instruction from [FastChat](https://github.com/lm-sys/FastChat) to install Vicuna model on local sever. Here are the commands to launch the API in terminal: 19 | 20 | ```bash 21 | python3 -m fastchat.serve.controller --host localhost --port ### Launch the controller 22 | python3 -m fastchat.serve.model_worker --model-name '' --model-path --controller http://localhost: --port --worker_address http://localhost: ### Launch the model worker 23 | python3 -m fastchat.serve.api --host --port ### Launch the API 24 | ``` 25 | 26 | 27 | ### Train and evaluate the models 28 | Any algorithm can be run from the main.py entry point. 29 | 30 | To train on a SimpleDoorKey environment, 31 | 32 | ```bash 33 | python main.py train --task SimpleDoorKey --savedir train 34 | ``` 35 | 36 | 41 | 42 | To evaluate the trained model, 43 | 44 | ```bash 45 | python main.py eval --task SimpleDoorKey --loaddir train --savedir eval 46 | ``` 47 | 48 | To evaluate the LLM-based teacher baseline, 49 | ```bash 50 | python main.py eval --task SimpleDoorKey --loaddir train --savedir eval --eval_teacher 51 | ``` 52 | 53 | ## Logging details 54 | Tensorboard logging is enabled by default for all algorithms. 
The logger expects that you supply an argument named ```logdir```, containing the root directory you want to store your logfiles 55 | 56 | The resulting directory tree would look something like this: 57 | ``` 58 | log/ # directory with all of the saved models and tensorboard 59 | └── ppo # algorithm name 60 | └── simpledoorkey # environment name 61 | └── save_name # unique save name 62 | ├── acmodel.pt # actor and critic network for algo 63 | ├── events.out.tfevents # tensorboard binary file 64 | └── config.json # readable hyperparameters for this run 65 | ``` 66 | 67 | Using tensorboard makes it easy to compare experiments and resume training later on. 68 | 69 | To see live training progress 70 | 71 | Run ```$ tensorboard --logdir=log``` then navigate to ```http://localhost:6006/``` in your browser 72 | 73 | ## Citation 74 | If you find [our work](https://arxiv.org/abs/2311.13373) useful, please kindly cite: 75 | ```bibtex 76 | @inproceedings{zhou2024large, 77 | title={Large Language Model as a Policy Teacher for Training Reinforcement Learning Agents}, 78 | author={Zhou, Zihao and Hu, Bin and Zhao, Chenyang and Zhang, Pu and Liu, Bin}, 79 | booktitle={The 33rd International Joint Conference on Artificial Intelligence (IJCAI 2024)}, 80 | year={2024} 81 | } 82 | ``` 83 | 84 | ## Acknowledgements 85 | This work is supported by Exploratory Research Project (No.2022RC0AN02) of Zhejiang Lab. 86 | -------------------------------------------------------------------------------- /teacher_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from planner import Planner 7 | from skill import GoTo_Goal, Explore, Pickup, Drop, Toggle, Wait 8 | from mediator import IDX_TO_SKILL, IDX_TO_OBJECT 9 | 10 | # single step (can handle soft planner) 11 | class TeacherPolicy(): 12 | def __init__(self, task, offline, soft, prefix, action_space, agent_view_size): 13 | self.planner = Planner(task, offline, soft, prefix) 14 | self.agent_view_size = agent_view_size 15 | self.action_space = action_space 16 | 17 | def get_skill_name(self, skill): 18 | try: 19 | return IDX_TO_SKILL[skill["action"]] + " " + IDX_TO_OBJECT[skill["object"]] 20 | except AttributeError: 21 | return "None" 22 | 23 | def reset(self): 24 | self.skill = None 25 | self.skill_list = [] 26 | self.skill_teminated = False 27 | self.planner.reset() 28 | 29 | def skill2teacher(self, skill): 30 | skill_action = skill['action'] 31 | if skill_action == 0: 32 | teacher = Explore(self.agent_view_size) 33 | elif skill_action == 1: 34 | teacher = GoTo_Goal(skill['coordinate']) 35 | elif skill_action == 2: 36 | teacher = Pickup(skill['object']) 37 | elif skill_action == 3: 38 | teacher = Drop(skill['object']) 39 | elif skill_action == 4: 40 | teacher = Toggle(skill['object']) 41 | elif skill_action == 6: 42 | teacher = Wait() 43 | else: 44 | assert False, "invalid skill" 45 | return teacher 46 | 47 | def get_action(self, skill_list, obs): 48 | teminated = True 49 | action = None 50 | while not action and teminated and len(skill_list) > 0: 51 | skill = skill_list.pop(0) 52 | teacher = self.skill2teacher(skill) 53 | action, teminated = teacher(obs) 54 | 55 | if action == None: 56 | 57 | action = 6 58 | 59 | action = np.array([i == action for i in range(self.action_space)], dtype=np.float32) 60 | 61 | return action 62 | 63 | def __call__(self, obs): 64 | skill_list, probs = self.planner(obs) 65 | action = 
np.zeros(self.action_space) 66 | for skills, prob in zip(skill_list, probs): 67 | action += self.get_action(skills, obs) * prob 68 | return action 69 | 70 | 71 | # class TeacherPolicy(): 72 | # def __init__(self, task, ideal, seed, agent_view_size): 73 | # self.planner = Planner(task, ideal, seed) 74 | # self.agent_view_size = agent_view_size 75 | 76 | # @property 77 | # def current_skill(self): 78 | # try: 79 | # return IDX_TO_SKILL[self.skill["action"]] + " " + IDX_TO_OBJECT[self.skill["object"]] 80 | # except AttributeError: 81 | # return "None" 82 | 83 | # def reset(self): 84 | # self.skill = None 85 | # self.skill_list = [] 86 | # self.skill_teminated = False 87 | # self.skill_truncated = False 88 | # self.planner.reset() 89 | 90 | # def initial_planning(self, decription, task_example): 91 | # self.planner.initial_planning(decription, task_example) 92 | 93 | # def skill2teacher(self, obs): 94 | # skill_action = self.skill['action'] 95 | # if skill_action == 0: 96 | # teacher = Explore(obs, self.agent_view_size) 97 | # elif skill_action == 1: 98 | # teacher = GoTo_Goal(obs, self.skill['coordinate']) 99 | # elif skill_action == 2: 100 | # teacher = Pickup(obs, self.skill['object']) 101 | # elif skill_action == 3: 102 | # teacher = Drop(obs, self.skill['object']) 103 | # elif skill_action == 4: 104 | # teacher = Toggle(obs, self.skill['object']) 105 | # else: 106 | # assert False, "invalid skill" 107 | # return teacher 108 | 109 | # def switch_skill(self, obs): 110 | # if self.skill_truncated or len(self.skill_list) == 0: 111 | # self.skill_list = self.planner(obs) # ask LLM 112 | # self.can_truncated = False 113 | # self.skill = self.skill_list.pop(0) 114 | 115 | # def __call__(self, obs, max_tries=5, always_ask=True): 116 | # if always_ask: 117 | # self.skill_truncated = True 118 | # action = None 119 | # tries = 0 120 | # self.can_truncate = True 121 | # while not action and tries <= max_tries: 122 | # if self.skill_teminated or self.skill_truncated: 123 | # self.switch_skill(obs) 124 | # teacher = self.skill2teacher(obs) 125 | # action, self.skill_teminated, self.skill_truncated = teacher(self.can_truncated) 126 | # tries += 1 127 | 128 | # if action == None: 129 | # print(obs[:, :, 0]) 130 | # print(obs[:, :, 2]) 131 | # print(obs[:, :, 3]) 132 | # print(teacher.message) 133 | # assert False, "teacher {} cannot give an action".format(self.current_skill) 134 | 135 | # return action 136 | 137 | 138 | -------------------------------------------------------------------------------- /skill/explore.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from minigrid.core.constants import DIR_TO_VEC 7 | from .base_skill import BaseSkill 8 | 9 | 10 | class Explore(BaseSkill): 11 | def __init__(self, agent_view_size): 12 | assert agent_view_size >= 3 13 | self.agent_view_size = agent_view_size 14 | self.message = "none" 15 | 16 | def get_room_boundary(self): 17 | width = self.map.shape[0] 18 | height = self.map.shape[1] 19 | self.botX, self.botY = width, height 20 | for i in range(1, width-1): 21 | for j in range(1, height-1): 22 | if self.map[i, j] not in (2, 4): 23 | pass 24 | elif self.botX == width and self.map[i, j + 1] in (2, 4): 25 | self.botX = i + 1 26 | elif self.botY == height and self.map[i + 1, j] in (2, 4): 27 | self.botY = j + 1 28 | else: 29 | pass 30 | if self.botX != width and self.botY != height: 31 | break 32 | 33 | def 
get_view(self, agent_dir, agent_pos=None): 34 | agent_pos = agent_pos if agent_pos else self.agent_pos 35 | 36 | # Facing right 37 | if agent_dir == 0: 38 | topX = agent_pos[0] 39 | topY = agent_pos[1] - self.agent_view_size // 2 40 | # Facing down 41 | elif agent_dir == 1: 42 | topX = agent_pos[0] - self.agent_view_size // 2 43 | topY = agent_pos[1] 44 | # Facing left 45 | elif agent_dir == 2: 46 | topX = agent_pos[0] - self.agent_view_size + 1 47 | topY = agent_pos[1] - self.agent_view_size // 2 48 | # Facing up 49 | elif agent_dir == 3: 50 | topX = agent_pos[0] - self.agent_view_size // 2 51 | topY = agent_pos[1] - self.agent_view_size + 1 52 | else: 53 | assert False, "invalid agent direction" 54 | 55 | # clip by room boundary 56 | topX = max(0, topX) 57 | topY = max(0, topY) 58 | botX = min(topX + self.agent_view_size, self.botX) 59 | botY = min(topY + self.agent_view_size, self.botY) 60 | # print("[{}:{}, {}:{}]".format(topX, botX, topY, botY)) 61 | 62 | return self.map[topX:botX, topY:botY] 63 | 64 | def get_grid_slice(self, agent_dir, agent_pos=None): 65 | agent_pos = agent_pos if agent_pos else self.agent_pos 66 | topX = 0 67 | topY = 0 68 | botX = self.botX 69 | botY = self.botY 70 | 71 | # Facing right 72 | if agent_dir == 0: 73 | topX = agent_pos[0] + self.agent_view_size // 2 + 1 74 | # Facing down 75 | elif agent_dir == 1: 76 | topY = agent_pos[1] + self.agent_view_size // 2 + 1 77 | # Facing left 78 | elif agent_dir == 2: 79 | botX = agent_pos[0] - self.agent_view_size // 2 80 | # Facing up 81 | elif agent_dir == 3: 82 | botY = agent_pos[1] - self.agent_view_size // 2 83 | else: 84 | assert False, "invalid agent direction" 85 | # print("[{}:{}, {}:{}]".format(topX, botX, topY, botY)) 86 | 87 | return self.map[topX:botX, topY:botY] 88 | 89 | def object_forward(self, agent_dir, agent_pos=None): 90 | x, y = self.agent_pos + DIR_TO_VEC[agent_dir] 91 | fwd_obj = self.map[x, y] 92 | if fwd_obj in (2, 4): # wall, door 93 | return 1 94 | # elif fwd_obj in (5, 6, 7, 9): # key, ball, box or lava 95 | elif fwd_obj in (5, 6, 7): # key, ball, box or lava 96 | return 2 97 | else: 98 | return 0 99 | 100 | def count_unseen_grid(self, agent_dir, agent_pos=None): 101 | grid = self.get_grid_slice(agent_dir, agent_pos) 102 | if grid.size == 0: 103 | # print("Wall ahead in dir {}".format(agent_dir)) 104 | return 0 105 | else: 106 | return np.count_nonzero(grid == 0) 107 | 108 | def __call__(self, obs): 109 | self.unpack_obs(obs) 110 | self.get_room_boundary() 111 | 112 | terminated = False 113 | # avoid object 114 | if self.object_forward(self.agent_dir) == 2: 115 | if self.object_forward((self.agent_dir - 1) % 4) in (1,2): # object/wall on the left 116 | action = 1 117 | else: 118 | action = 0 119 | else: 120 | # unseen grid in forward direction? 121 | if self.count_unseen_grid(self.agent_dir) > 0: 122 | action = 2 123 | # unseen grid in leftward direction? 124 | elif self.count_unseen_grid((self.agent_dir - 1) % 4) > 0: 125 | action = 0 126 | # unseen grid in rightward direction? 127 | elif self.count_unseen_grid((self.agent_dir + 1) % 4) > 0: 128 | action = 1 129 | # unseen grid in backward direction? 
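# (Descriptive note: the reference cell passed below is shifted one step forward with DIR_TO_VEC so that the
# backward slice reaches the cells directly behind the agent, which the view-size margin would otherwise cut off;
# action 0 turns left, and two successive turns face the agent backward, so a right turn (1) would work equally well.)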
130 | elif self.count_unseen_grid((self.agent_dir + 2) % 4, tuple(self.agent_pos + DIR_TO_VEC[self.agent_dir])) > 0: 131 | action = 0 # or 1 132 | # no unseen grid 133 | else: 134 | action = None 135 | terminated = True 136 | self.message = "no unseen grid" 137 | 138 | return action, terminated 139 | 140 | -------------------------------------------------------------------------------- /skill_multi_step/explore.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from minigrid.core.constants import DIR_TO_VEC 7 | from .base_skill import BaseSkill 8 | 9 | 10 | class Explore(BaseSkill): 11 | def __init__(self, obs, agent_view_size): 12 | assert agent_view_size >= 3 13 | self.agent_view_size = agent_view_size 14 | self.unpack_obs(obs) 15 | self.get_room_boundary() 16 | self.action = None 17 | self.message = "none" 18 | 19 | def get_room_boundary(self): 20 | width = self.map.shape[0] 21 | height = self.map.shape[1] 22 | self.botX, self.botY = width, height 23 | for i in range(1, width-1): 24 | for j in range(1, height-1): 25 | if self.map[i, j] not in (2, 4): 26 | pass 27 | elif self.botX == width and self.map[i, j + 1] in (2, 4): 28 | self.botX = i + 1 29 | elif self.botY == height and self.map[i + 1, j] in (2, 4): 30 | self.botY = j + 1 31 | else: 32 | pass 33 | if self.botX != width and self.botY != height: 34 | break 35 | 36 | def get_view(self, agent_dir, agent_pos=None): 37 | agent_pos = agent_pos if agent_pos else self.agent_pos 38 | 39 | # Facing right 40 | if agent_dir == 0: 41 | topX = agent_pos[0] 42 | topY = agent_pos[1] - self.agent_view_size // 2 43 | # Facing down 44 | elif agent_dir == 1: 45 | topX = agent_pos[0] - self.agent_view_size // 2 46 | topY = agent_pos[1] 47 | # Facing left 48 | elif agent_dir == 2: 49 | topX = agent_pos[0] - self.agent_view_size + 1 50 | topY = agent_pos[1] - self.agent_view_size // 2 51 | # Facing up 52 | elif agent_dir == 3: 53 | topX = agent_pos[0] - self.agent_view_size // 2 54 | topY = agent_pos[1] - self.agent_view_size + 1 55 | else: 56 | assert False, "invalid agent direction" 57 | 58 | # clip by room boundary 59 | topX = max(0, topX) 60 | topY = max(0, topY) 61 | botX = min(topX + self.agent_view_size, self.botX) 62 | botY = min(topY + self.agent_view_size, self.botY) 63 | # print("[{}:{}, {}:{}]".format(topX, botX, topY, botY)) 64 | 65 | return self.map[topX:botX, topY:botY] 66 | 67 | def get_grid_slice(self, agent_dir, agent_pos=None): 68 | agent_pos = agent_pos if agent_pos else self.agent_pos 69 | topX = 0 70 | topY = 0 71 | botX = self.botX 72 | botY = self.botY 73 | 74 | # Facing right 75 | if agent_dir == 0: 76 | topX = agent_pos[0] + self.agent_view_size // 2 + 1 77 | # Facing down 78 | elif agent_dir == 1: 79 | topY = agent_pos[1] + self.agent_view_size // 2 + 1 80 | # Facing left 81 | elif agent_dir == 2: 82 | botX = agent_pos[0] - self.agent_view_size // 2 83 | # Facing up 84 | elif agent_dir == 3: 85 | botY = agent_pos[1] - self.agent_view_size // 2 86 | else: 87 | assert False, "invalid agent direction" 88 | # print("[{}:{}, {}:{}]".format(topX, botX, topY, botY)) 89 | 90 | return self.map[topX:botX, topY:botY] 91 | 92 | def object_in_sight(self, agent_dir, agent_pos=None): 93 | grid = self.get_view(agent_dir, agent_pos) 94 | for i in np.nditer(grid): 95 | # if i in (5, 6, 7, 9): # key, ball, box or lava 96 | if i in (5, 6, 7): # key, ball, box 97 | return True 98 | return False 99 
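# Return-code convention for object_forward below: 0 = the forward cell is free, 1 = a wall or door is directly
# ahead, 2 = a key/ball/box is directly ahead and the explore routine should sidestep it rather than bump into it.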
| 100 | def object_forward(self, agent_dir, agent_pos=None): 101 | x, y = self.agent_pos + DIR_TO_VEC[agent_dir] 102 | fwd_obj = self.map[x, y] 103 | if fwd_obj in (2, 4): # wall, door 104 | return 1 105 | # elif fwd_obj in (5, 6, 7, 9): # key, ball, box or lava 106 | elif fwd_obj in (5, 6, 7): # key, ball, box or lava 107 | return 2 108 | else: 109 | return 0 110 | 111 | def count_unseen_grid(self, agent_dir, agent_pos=None): 112 | grid = self.get_grid_slice(agent_dir, agent_pos) 113 | if grid.size == 0: 114 | # print("Wall ahead in dir {}".format(agent_dir)) 115 | return 0 116 | else: 117 | return np.count_nonzero(grid == 0) 118 | 119 | def __call__(self, can_truncate): 120 | # object in view? 121 | if self.object_in_sight(self.agent_dir): 122 | if can_truncate: 123 | self.message = "object in sight" 124 | return None, True, False # ask LLM 125 | elif self.object_forward(self.agent_dir) == 2: # avoid object 126 | if self.object_forward((self.agent_dir - 1) % 4) in (1,2): 127 | return 1, False, False 128 | else: 129 | return 0, False, False 130 | 131 | terminated = False 132 | truncated = False 133 | # unseen grid in forward direction? 134 | if self.count_unseen_grid(self.agent_dir) > 0: 135 | action = 2 136 | # unseen grid in leftward direction? 137 | elif self.count_unseen_grid((self.agent_dir - 1) % 4) > 0: 138 | action = 0 139 | # unseen grid in rightward direction? 140 | elif self.count_unseen_grid((self.agent_dir + 1) % 4) > 0: 141 | action = 1 142 | # unseen grid in backward direction? 143 | elif self.count_unseen_grid((self.agent_dir + 2) % 4, tuple(self.agent_pos + DIR_TO_VEC[self.agent_dir])) > 0: 144 | action = 0 # or 1 145 | # no unseen grid 146 | else: 147 | action = None 148 | terminated = True 149 | self.message = "no unseen grid" 150 | 151 | return action, terminated, truncated 152 | 153 | -------------------------------------------------------------------------------- /algos/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : student_net.py 5 | @Time : 2023/07/14 16:34:11 6 | @Author : Zhou Zihao 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.distributions.categorical import Categorical 14 | import torch.nn.functional as F 15 | import numpy as np 16 | 17 | 18 | class NNBase(nn.Module): 19 | def __init__(self, obs_space, action_space): 20 | super().__init__() 21 | width, height, channel = obs_space["image"] 22 | 23 | # Define image embedding 24 | self.image_conv = nn.Sequential( 25 | nn.Conv2d(channel, 16, (3, 3), padding=1), 26 | nn.ReLU(), 27 | # nn.MaxPool2d((2, 2)), 28 | nn.Conv2d(16, 32, (3, 3), padding=1), 29 | nn.ReLU(), 30 | nn.Conv2d(32, 64, (3, 3), padding=1), 31 | nn.ReLU() 32 | ) 33 | 34 | dummy_x = torch.zeros((1, channel, height, width)) 35 | embedding_size = np.prod(self.image_conv(dummy_x).shape) 36 | return embedding_size, action_space 37 | 38 | def forward(self, obs, masks=None, states=None): 39 | raise NotImplementedError 40 | 41 | 42 | class MLPBase(NNBase): 43 | def __init__(self, obs_space, action_space): 44 | embedding_size, action_space = super().__init__(obs_space, action_space) 45 | 46 | # self.fc = nn.Sequential( 47 | # nn.Linear(embedding_size, 64), 48 | # nn.Tanh() 49 | # ) 50 | 51 | # # Define actor's model 52 | # self.actor = nn.Linear(64, action_space) 53 | # # Define critic's model 54 | # self.critic = nn.Linear(64, 1) 55 | 56 | # # Define actor's model 57 | # 
self.actor = nn.Sequential( 58 | # nn.Linear(64, 16), 59 | # nn.Tanh(), 60 | # nn.Linear(16, action_space) 61 | # ) 62 | # # Define critic's model 63 | # self.critic = nn.Sequential( 64 | # nn.Linear(64, 16), 65 | # nn.Tanh(), 66 | # nn.Linear(16, 1) 67 | # ) 68 | 69 | # Define actor's model 70 | self.actor = nn.Sequential( 71 | nn.Linear(embedding_size, 64), 72 | nn.ReLU(), 73 | nn.Linear(64, action_space) 74 | ) 75 | # Define critic's model 76 | self.critic = nn.Sequential( 77 | nn.Linear(embedding_size, 64), 78 | nn.ReLU(), 79 | nn.Linear(64, 1) 80 | ) 81 | 82 | def init_states(self, device=None, num_trajs=1): 83 | return None 84 | 85 | def forward(self, obs, masks=None, states=None): 86 | input_dim = len(obs.size()) 87 | assert input_dim == 4, "observation dimension expected to be 4, but got {}.".format(input_dim) 88 | 89 | # feature extractor 90 | x = obs.transpose(1, 3) # [num_trans, channels, height, width] 91 | x = self.image_conv(x) 92 | x = x.reshape(x.shape[0], -1) # [num_trans, -1] 93 | embedding = x 94 | # embedding = self.fc(x) 95 | 96 | # actor-critic 97 | value = self.critic(embedding).squeeze(1) 98 | action_logits = self.actor(embedding) 99 | dist = Categorical(logits=action_logits) 100 | 101 | return dist, value, embedding 102 | 103 | 104 | class LSTMBase(NNBase): 105 | def __init__(self, obs_space, action_space): 106 | embedding_size, action_space = super().__init__(obs_space, action_space) 107 | 108 | self.fc = nn.Sequential( 109 | nn.Linear(embedding_size, 256), 110 | nn.ReLU() 111 | ) 112 | self.core = nn.LSTM(256, 256, 2) 113 | 114 | # Define actor's model 115 | self.actor = nn.Sequential( 116 | nn.Linear(256, 64), 117 | nn.ReLU(), 118 | nn.Linear(64, action_space) 119 | ) 120 | # Define critic's model 121 | self.critic = nn.Sequential( 122 | nn.Linear(256, 64), 123 | nn.ReLU(), 124 | nn.Linear(64, 1) 125 | ) 126 | 127 | def init_states(self, device, num_trajs=1): 128 | return (torch.zeros(self.core.num_layers, num_trajs, self.core.hidden_size).to(device), 129 | torch.zeros(self.core.num_layers, num_trajs, self.core.hidden_size).to(device)) 130 | 131 | def forward(self, obs, masks, states): 132 | input_dim = len(obs.size()) 133 | if input_dim == 4: 134 | unroll_length = obs.shape[0] 135 | num_trajs = 1 136 | elif len(obs.size()) == 5: 137 | unroll_length, num_trajs, *_ = obs.shape 138 | obs = torch.flatten(obs, 0, 1) # [unroll_length * num_trajs, width, height, channels] 139 | else: 140 | assert False, "observation dimension expected to be 4 or 5, but got {}.".format(input_dim) 141 | 142 | # feature extractor 143 | x = obs.transpose(1, 3) # [unroll_length * num_trajs, channels, height, width] 144 | x = self.image_conv(x) 145 | x = x.reshape(unroll_length * num_trajs, -1) # [unroll_length * num_trajs, -1] 146 | x = self.fc(x) 147 | 148 | # LSTM 149 | core_input = x.view(unroll_length, num_trajs, -1) # [unroll_length, num_trajs, -1] 150 | masks = masks.view(unroll_length, 1, num_trajs, 1) # [unroll_length, 1, num_trajs, 1] 151 | core_output_list = [] 152 | for inp, mask in zip(core_input.unbind(), masks.unbind()): 153 | states = tuple(mask * s for s in states) 154 | output, states = self.core(inp.unsqueeze(0), states) 155 | core_output_list.append(output) 156 | core_output = torch.cat(core_output_list) # [unroll_length, num_trajs, -1] 157 | core_output = core_output.view(unroll_length * num_trajs, -1) # [unroll_length * num_trajs, -1] 158 | 159 | # actor-critic 160 | action_logits = self.actor(core_output) 161 | dist = Categorical(logits=action_logits) 162 | value = 
self.critic(core_output).squeeze(1) 163 | 164 | return dist, value, states 165 | -------------------------------------------------------------------------------- /skill_multi_step/base_skill.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from minigrid.core.constants import DIR_TO_VEC 7 | 8 | ''' 9 | OBJECT_TO_IDX = { 10 | "unseen": 0, 11 | "empty": 1, 12 | "wall": 2, 13 | "floor": 3, 14 | "door": 4, 15 | "key": 5, 16 | "ball": 6, 17 | "box": 7, 18 | "goal": 8, 19 | "lava": 9, 20 | "agent": 10, 21 | } 22 | 23 | STATE_TO_IDX = { 24 | "open": 0, 25 | "closed": 1, 26 | "locked": 2, 27 | } 28 | ''' 29 | 30 | DIRECTION = { 31 | 0: [1, 0], 32 | 1: [0, 1], 33 | 2: [-1, 0], 34 | 3: [0, -1], 35 | } 36 | 37 | class BaseSkill(): 38 | def __init__(self): 39 | pass 40 | 41 | def unpack_obs(self, obs): 42 | self.obs = obs 43 | agent_map = obs[:, :, 3] 44 | self.agent_pos = np.argwhere(agent_map != 4)[0] 45 | self.agent_dir = obs[self.agent_pos[0], self.agent_pos[1], 3] 46 | self.map = obs[:, :, 0].copy() 47 | self.carrying = self.map[self.agent_pos[0], self.agent_pos[1]] 48 | self.map[self.agent_pos[0], self.agent_pos[1]] = 10 49 | 50 | 51 | # class Pickup(BaseSkill): 52 | # def __init__(self, init_obs): 53 | # init_obs = init_obs[:,:,-4:] 54 | # self.path_prefix = [] 55 | # self.path_suffix = [] 56 | # self.plan(init_obs) 57 | 58 | # def plan(self, init_obs, max_tries=30): 59 | # self.unpack_obs(init_obs) 60 | 61 | # if self.map[self.agent_pos[0], self.agent_pos[1]] == 1: #not carrying 62 | # self.path = [3] 63 | # else: 64 | # angle_list = [0, 2, 1, 3] 65 | # angle_list.remove(0) 66 | # goto_angle = None 67 | # finish = False 68 | # tries = 0 69 | # while not finish: 70 | # search_angle = angle_list.pop(0) 71 | # _drop, _goto = self.can_drop(search_angle) 72 | # tries += 1 73 | # if _drop: 74 | # self.update_path(search_angle) 75 | # self.path = self.path2action(self.path_prefix) + [4] + self.path2action(self.path_suffix) + [3] 76 | # finish = True 77 | # else: 78 | # # since there is only 1 door, there is at most 1 angle can go to but cannot drop 79 | # if _goto: 80 | # goto_angle = search_angle 81 | 82 | # if len(angle_list) == 0: 83 | # if goto_angle or tries < max_tries: 84 | # self.update_path(goto_angle, forward=True) 85 | # self.agent_dir = (self.agent_dir + goto_angle) % 4 86 | # self.agent_pos = self.agent_pos + DIR_TO_VEC[self.agent_dir] 87 | # angle_list = [0, 2, 1, 3] 88 | # angle_list.remove(2) # not search backward 89 | # goto_angle = None 90 | # else: 91 | # finish = True 92 | # self.path = [] 93 | # print("path not found!") 94 | 95 | # def can_drop(self, angle, distance=1): 96 | # target_dir = (self.agent_dir + angle) % 4 97 | # target_pos = self.agent_pos + DIR_TO_VEC[target_dir] * distance 98 | # target_obj = self.map[target_pos[0], target_pos[1]] 99 | # if target_obj != 1: # not empty 100 | # _drop, _goto = False, False 101 | # else: 102 | # _drop, _goto = True, True 103 | # for i in range(4): 104 | # nearby_pos = target_pos + DIR_TO_VEC[i] 105 | # if self.map[nearby_pos[0], nearby_pos[1]] == 4: # near a door 106 | # _drop = False 107 | # return _drop, _goto 108 | 109 | # def update_path(self, angle, forward=False): 110 | # if forward: 111 | # if angle == 2: 112 | # self.path_prefix += [2, 'f'] 113 | # self.path_suffix = [2, 'f'] + self.path_suffix 114 | # elif angle == 1: 115 | # self.path_prefix += [1, 'f'] 116 | # 
self.path_suffix = [2, 'f', 1] + self.path_suffix 117 | # elif angle == 3: 118 | # self.path_prefix += [3, 'f'] 119 | # self.path_suffix = [2, 'f', 3] + self.path_suffix 120 | # else: 121 | # self.path_prefix += ['f'] 122 | # self.path_suffix = [2, 'f', 2] + self.path_suffix 123 | # else: 124 | # if angle == 2: 125 | # self.path_prefix += [2] 126 | # self.path_suffix = [2] + self.path_suffix 127 | # elif angle == 1: 128 | # self.path_prefix += [1] 129 | # self.path_suffix = [3] + self.path_suffix 130 | # elif angle == 3: 131 | # self.path_prefix += [3] 132 | # self.path_suffix = [1] + self.path_suffix 133 | # else: 134 | # pass 135 | 136 | # def path2action(self, path): 137 | # angle = 0 138 | # action_list = [] 139 | # path.append('f') 140 | # for i in path: 141 | # if i == 'f': 142 | # angle = angle % 4 143 | # if angle == 1: 144 | # action_list.append(1) 145 | # elif angle == 3: 146 | # action_list.append(0) 147 | # elif angle == 2: 148 | # action_list.extend([0, 0]) 149 | # else: 150 | # pass 151 | # angle = 0 152 | # action_list.append(2) 153 | # else: 154 | # angle += i 155 | # return action_list[:-1] 156 | 157 | # def step(self, obs): 158 | # action = self.path.pop(0) 159 | # terminated = self.done_check() 160 | # return action, terminated, False 161 | 162 | # def done_check(self): 163 | # return len(self.path) == 0 -------------------------------------------------------------------------------- /skill/base_skill.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import gymnasium as gym 5 | import minigrid 6 | from minigrid.core.constants import DIR_TO_VEC 7 | 8 | ''' 9 | OBJECT_TO_IDX = { 10 | "unseen": 0, 11 | "empty": 1, 12 | "wall": 2, 13 | "floor": 3, 14 | "door": 4, 15 | "key": 5, 16 | "ball": 6, 17 | "box": 7, 18 | "goal": 8, 19 | "lava": 9, 20 | "agent": 10, 21 | } 22 | 23 | STATE_TO_IDX = { 24 | "open": 0, 25 | "closed": 1, 26 | "locked": 2, 27 | } 28 | ''' 29 | 30 | DIRECTION = { 31 | 0: [1, 0], 32 | 1: [0, 1], 33 | 2: [-1, 0], 34 | 3: [0, -1], 35 | } 36 | 37 | class BaseSkill(): 38 | def __init__(self): 39 | pass 40 | 41 | def unpack_obs(self, obs): 42 | self.obs = obs 43 | agent_map = obs[:, :, 3] 44 | self.agent_pos = np.argwhere(agent_map != 4)[0] 45 | self.agent_dir = obs[self.agent_pos[0], self.agent_pos[1], 3] 46 | self.map = obs[:, :, 0].copy() 47 | self.carrying = self.map[self.agent_pos[0], self.agent_pos[1]] 48 | self.map[self.agent_pos[0], self.agent_pos[1]] = 10 49 | 50 | 51 | class Pickup(BaseSkill): 52 | def __init__(self, target_obj=None): 53 | pass 54 | 55 | def __call__(self, obs): 56 | return 3, True 57 | 58 | class Drop(BaseSkill): 59 | def __init__(self, target_obj=None): 60 | pass 61 | 62 | def __call__(self, obs): 63 | return 4, True 64 | 65 | class Toggle(BaseSkill): 66 | def __init__(self, target_obj=None): 67 | pass 68 | 69 | def __call__(self, obs): 70 | return 5, True 71 | 72 | class Wait(BaseSkill): 73 | def __init__(self, target_obj=None): 74 | pass 75 | 76 | def __call__(self, obs): 77 | return 6, True 78 | 79 | 80 | # class Pickup(BaseSkill): 81 | # def __init__(self, init_obs): 82 | # init_obs = init_obs[:,:,-4:] 83 | # self.path_prefix = [] 84 | # self.path_suffix = [] 85 | # self.plan(init_obs) 86 | 87 | # def plan(self, init_obs, max_tries=30): 88 | # self.unpack_obs(init_obs) 89 | 90 | # if self.map[self.agent_pos[0], self.agent_pos[1]] == 1: #not carrying 91 | # self.path = [3] 92 | # else: 93 | # angle_list = [0, 2, 1, 3] 94 | # 
angle_list.remove(0) 95 | # goto_angle = None 96 | # finish = False 97 | # tries = 0 98 | # while not finish: 99 | # search_angle = angle_list.pop(0) 100 | # _drop, _goto = self.can_drop(search_angle) 101 | # tries += 1 102 | # if _drop: 103 | # self.update_path(search_angle) 104 | # self.path = self.path2action(self.path_prefix) + [4] + self.path2action(self.path_suffix) + [3] 105 | # finish = True 106 | # else: 107 | # # since there is only 1 door, there is at most 1 angle can go to but cannot drop 108 | # if _goto: 109 | # goto_angle = search_angle 110 | 111 | # if len(angle_list) == 0: 112 | # if goto_angle or tries < max_tries: 113 | # self.update_path(goto_angle, forward=True) 114 | # self.agent_dir = (self.agent_dir + goto_angle) % 4 115 | # self.agent_pos = self.agent_pos + DIR_TO_VEC[self.agent_dir] 116 | # angle_list = [0, 2, 1, 3] 117 | # angle_list.remove(2) # not search backward 118 | # goto_angle = None 119 | # else: 120 | # finish = True 121 | # self.path = [] 122 | # print("path not found!") 123 | 124 | # def can_drop(self, angle, distance=1): 125 | # target_dir = (self.agent_dir + angle) % 4 126 | # target_pos = self.agent_pos + DIR_TO_VEC[target_dir] * distance 127 | # target_obj = self.map[target_pos[0], target_pos[1]] 128 | # if target_obj != 1: # not empty 129 | # _drop, _goto = False, False 130 | # else: 131 | # _drop, _goto = True, True 132 | # for i in range(4): 133 | # nearby_pos = target_pos + DIR_TO_VEC[i] 134 | # if self.map[nearby_pos[0], nearby_pos[1]] == 4: # near a door 135 | # _drop = False 136 | # return _drop, _goto 137 | 138 | # def update_path(self, angle, forward=False): 139 | # if forward: 140 | # if angle == 2: 141 | # self.path_prefix += [2, 'f'] 142 | # self.path_suffix = [2, 'f'] + self.path_suffix 143 | # elif angle == 1: 144 | # self.path_prefix += [1, 'f'] 145 | # self.path_suffix = [2, 'f', 1] + self.path_suffix 146 | # elif angle == 3: 147 | # self.path_prefix += [3, 'f'] 148 | # self.path_suffix = [2, 'f', 3] + self.path_suffix 149 | # else: 150 | # self.path_prefix += ['f'] 151 | # self.path_suffix = [2, 'f', 2] + self.path_suffix 152 | # else: 153 | # if angle == 2: 154 | # self.path_prefix += [2] 155 | # self.path_suffix = [2] + self.path_suffix 156 | # elif angle == 1: 157 | # self.path_prefix += [1] 158 | # self.path_suffix = [3] + self.path_suffix 159 | # elif angle == 3: 160 | # self.path_prefix += [3] 161 | # self.path_suffix = [1] + self.path_suffix 162 | # else: 163 | # pass 164 | 165 | # def path2action(self, path): 166 | # angle = 0 167 | # action_list = [] 168 | # path.append('f') 169 | # for i in path: 170 | # if i == 'f': 171 | # angle = angle % 4 172 | # if angle == 1: 173 | # action_list.append(1) 174 | # elif angle == 3: 175 | # action_list.append(0) 176 | # elif angle == 2: 177 | # action_list.extend([0, 0]) 178 | # else: 179 | # pass 180 | # angle = 0 181 | # action_list.append(2) 182 | # else: 183 | # angle += i 184 | # return action_list[:-1] 185 | 186 | # def step(self, obs): 187 | # action = self.path.pop(0) 188 | # terminated = self.done_check() 189 | # return action, terminated, False 190 | 191 | # def done_check(self): 192 | # return len(self.path) == 0 -------------------------------------------------------------------------------- /algos/buffer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : buffer.py 5 | @Time : 2023/04/19 09:40:36 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 
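# Rollout buffer for the PPO student: each step stores the observation, action, reward, value estimate,
# log-prob and the teacher's full action distribution. finish_path() converts each trajectory's rewards into
# discounted returns (the `lam` parameter is kept but unused, so there is no GAE), and sample() yields shuffled
# minibatches: whole padded trajectories when the policy is recurrent, independent transitions otherwise.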
import numpy as np 11 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 12 | from torch.nn.utils.rnn import pad_sequence 13 | import torch 14 | 15 | 16 | # def Merge_Buffers(buffers, device='cpu'): 17 | # merged = Buffer(device=device) 18 | # for buf in buffers: 19 | # offset = len(merged) 20 | 21 | # merged.obs += buf.obs 22 | # merged.actions += buf.actions 23 | # merged.rewards += buf.rewards 24 | # merged.values += buf.values 25 | # merged.returns += buf.returns 26 | # merged.log_probs += buf.log_probs 27 | # merged.teacher_probs += buf.teacher_probs 28 | 29 | # merged.ep_returns += buf.ep_returns 30 | # merged.ep_lens += buf.ep_lens 31 | 32 | 33 | # merged.traj_idx += [offset + i for i in buf.traj_idx[1:]] 34 | # merged.ptr += buf.ptr 35 | 36 | # return merged 37 | 38 | 39 | class Buffer: 40 | """ 41 | A buffer for storing trajectory data and calculating returns for the policy 42 | and critic updates. 43 | """ 44 | def __init__(self, gamma=0.99, lam=0.95, device='cpu'): 45 | self.gamma = gamma 46 | self.lam = lam # unused 47 | self.device = device 48 | 49 | def __len__(self): 50 | return self.ptr 51 | 52 | def clear(self): 53 | self.obs = [] 54 | self.actions = [] 55 | self.rewards = [] 56 | self.values = [] 57 | self.log_probs = [] 58 | self.teacher_probs = [] 59 | 60 | self.ptr = 0 61 | self.traj_idx = [0] 62 | self.returns = [] 63 | self.ep_returns = [] # for logging 64 | self.ep_lens = [] 65 | 66 | def store(self, state, action, reward, value, log_probs, teacher_probs): 67 | """ 68 | Append one timestep of agent-environment interaction to the buffer. 69 | """ 70 | # TODO: make sure these dimensions really make sense 71 | # print(obs.shape, action.shape, reward.shape, value.shape, log_probs.shape) 72 | self.obs += [state.squeeze(0)] 73 | self.actions += [action.squeeze()] 74 | self.rewards += [reward.squeeze()] 75 | self.values += [value.squeeze()] 76 | self.log_probs += [log_probs.squeeze()] 77 | self.teacher_probs += [teacher_probs] 78 | self.ptr += 1 79 | 80 | def finish_path(self, last_val=None): 81 | self.traj_idx += [self.ptr] 82 | rewards = self.rewards[self.traj_idx[-2]:self.traj_idx[-1]] 83 | 84 | returns = [] 85 | R = last_val 86 | for reward in reversed(rewards): 87 | R = self.gamma * R + reward 88 | returns.insert(0, R) 89 | 90 | self.returns += returns 91 | self.ep_returns += [np.sum(rewards)] 92 | self.ep_lens += [len(rewards)] 93 | 94 | def get(self): 95 | return( 96 | np.array(self.obs), 97 | np.array(self.actions), 98 | np.array(self.returns), 99 | np.array(self.values), 100 | np.array(self.log_probs), 101 | np.array(self.teacher_probs) 102 | ) 103 | 104 | def sample(self, batch_size=64, recurrent=False): 105 | if recurrent: 106 | random_indices = np.random.permutation(len(self.ep_lens)) 107 | last_index = random_indices[-1] 108 | sampler = [] 109 | indices = [] 110 | num_sample = 0 111 | for i in random_indices: 112 | indices.append(i) 113 | num_sample += self.ep_lens[i] 114 | if num_sample > batch_size or i == last_index: 115 | sampler.append(indices) 116 | indices = [] 117 | num_sample = 0 118 | # random_indices = SubsetRandomSampler(range(len(self.traj_idx)-1)) 119 | # sampler = BatchSampler(random_indices, batch_size, drop_last=False) 120 | else: 121 | random_indices = SubsetRandomSampler(range(self.ptr)) 122 | sampler = BatchSampler(random_indices, batch_size, drop_last=True) 123 | 124 | observations, actions, returns, values, log_probs, teacher_probs = map(torch.Tensor, self.get()) 125 | 126 | advantages = returns - values 127 | 
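# Advantages are plain (return - value) residuals; the next line normalizes them to zero mean and unit variance
# over the whole rollout so the clipped PPO objective stays on a consistent scale (1e-5 avoids division by zero).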
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5) 128 | 129 | for indices in sampler: 130 | if recurrent: 131 | obs_batch = [observations[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 132 | action_batch = [actions[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 133 | return_batch = [returns[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 134 | advantage_batch = [advantages[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 135 | values_batch = [values[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 136 | mask = [torch.ones_like(r) for r in return_batch] 137 | log_prob_batch = [log_probs[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 138 | teacher_prob_batch = [teacher_probs[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices] 139 | 140 | obs_batch = pad_sequence(obs_batch, batch_first=False) # [unroll_length, num_trajs, ...] 141 | action_batch = pad_sequence(action_batch, batch_first=False).flatten(0,1) 142 | return_batch = pad_sequence(return_batch, batch_first=False).flatten(0,1) 143 | advantage_batch = pad_sequence(advantage_batch, batch_first=False).flatten(0,1) 144 | values_batch = pad_sequence(values_batch, batch_first=False).flatten(0,1) 145 | mask = pad_sequence(mask, batch_first=False).flatten(0,1) 146 | log_prob_batch = pad_sequence(log_prob_batch, batch_first=False).flatten(0,1) 147 | teacher_prob_batch = pad_sequence(teacher_prob_batch, batch_first=False).flatten(0,1) 148 | else: 149 | obs_batch = observations[indices] 150 | action_batch = actions[indices] 151 | return_batch = returns[indices] 152 | advantage_batch = advantages[indices] 153 | values_batch = values[indices] 154 | mask = torch.FloatTensor([1]) 155 | log_prob_batch = log_probs[indices] 156 | teacher_prob_batch = teacher_probs[indices] 157 | 158 | 159 | yield obs_batch.to(self.device), action_batch.to(self.device), return_batch.to(self.device), advantage_batch.to(self.device), values_batch.to(self.device), mask.to(self.device), log_prob_batch.to(self.device), teacher_prob_batch.to(self.device) -------------------------------------------------------------------------------- /env/historicalobs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium import spaces 3 | from minigrid.core.grid import Grid 4 | from minigrid.core.world_object import Wall, WorldObj 5 | from minigrid.minigrid_env import MiniGridEnv 6 | from minigrid.core.mission import MissionSpace 7 | from minigrid.core.constants import TILE_PIXELS 8 | from minigrid.utils.rendering import fill_coords, point_in_rect, point_in_circle 9 | 10 | class Unseen(WorldObj): 11 | def __init__(self): 12 | super().__init__("unseen", "grey") 13 | 14 | def render(self, img): 15 | fill_coords(img, point_in_rect(0, 1, 0, 1), np.array([200, 200, 200])) 16 | # fill_coords(img, point_in_rect(0, 1, 0.45, 0.55), np.array([200, 200, 200])) 17 | 18 | class Footprint(WorldObj): 19 | def __init__(self): 20 | super().__init__("empty", "red") 21 | 22 | def render(self, img): 23 | fill_coords(img, point_in_circle(0.5, 0.5, 0.15), np.array([225, 0, 0])) 24 | 25 | class FlexibleGrid(Grid): 26 | def __init__(self, width: int, height: int): 27 | assert width >= 1 28 | assert height >= 1 29 | 30 | self.width: int = width 31 | self.height: int = height 32 | 33 | self.grid: list[WorldObj | None] = [None] * (width * height) 34 | self.mask: np.ndarray = np.zeros(shape=(self.width, self.height), dtype=bool) 35 | 36 | def looking_vertical(self, i, direction, 
out_of_bound): 37 | for j in range(0, self.height - 1): # north to south 38 | if not self.mask[i, j]: 39 | continue 40 | 41 | cell = self.get(i, j) 42 | if cell and not cell.see_behind(): 43 | continue 44 | 45 | # print(i, j) 46 | self.mask[i, j + 1] = True 47 | if not out_of_bound: 48 | self.mask[i + direction, j] = True 49 | self.mask[i + direction, j + 1] = True 50 | # print(self.mask.T) 51 | 52 | for j in reversed(range(1, self.height)): # south to north 53 | if not self.mask[i, j]: 54 | continue 55 | 56 | cell = self.get(i, j) 57 | if cell and not cell.see_behind(): 58 | continue 59 | 60 | # print(i, j) 61 | self.mask[i, j - 1] = True 62 | if not out_of_bound: 63 | self.mask[i + direction, j] = True 64 | self.mask[i + direction, j - 1] = True 65 | # print(self.mask.T) 66 | 67 | def looking_horizontal(self, j, direction, out_of_bound): 68 | for i in range(0, self.width - 1): # west to east 69 | if not self.mask[i, j]: 70 | continue 71 | 72 | cell = self.get(i, j) 73 | if cell and not cell.see_behind(): 74 | continue 75 | 76 | # print(i, j) 77 | self.mask[i + 1, j] = True 78 | if not out_of_bound: 79 | self.mask[i, j + direction] = True 80 | self.mask[i + 1, j + direction] = True 81 | # print(self.mask.T) 82 | 83 | for i in reversed(range(1, self.width)): # east to west 84 | if not self.mask[i, j]: 85 | continue 86 | 87 | cell = self.get(i, j) 88 | if cell and not cell.see_behind(): 89 | continue 90 | 91 | # print(i, j) 92 | self.mask[i - 1, j] = True 93 | if not out_of_bound: 94 | self.mask[i, j + direction] = True 95 | self.mask[i - 1, j + direction] = True 96 | # print(self.mask.T) 97 | 98 | def process_vis(self, agent_pos, agent_dir): 99 | self.mask[agent_pos[0], agent_pos[1]] = True 100 | 101 | if agent_dir == 0: # east 102 | for i in range(0, self.width): 103 | out_of_bound = i == self.width - 1 104 | self.looking_vertical(i, 1, out_of_bound) 105 | 106 | elif agent_dir == 2: # west 107 | for i in reversed(range(0, self.width)): 108 | out_of_bound = i == 0 109 | self.looking_vertical(i, -1, out_of_bound) 110 | 111 | elif agent_dir == 1: # south 112 | for j in range(0, self.height): 113 | out_of_bound = j == self.height - 1 114 | self.looking_horizontal(j, 1, out_of_bound) 115 | 116 | elif agent_dir == 3: # north 117 | for j in reversed(range(0, self.height)): 118 | out_of_bound = j == 0 119 | self.looking_horizontal(j, -1, out_of_bound) 120 | return self.mask 121 | 122 | 123 | class HistoricalObsEnv(MiniGridEnv): 124 | def __init__( 125 | self, 126 | mission_space: MissionSpace, 127 | width: int | None = None, 128 | height: int | None = None, 129 | max_steps: int = 100, 130 | **kwargs, 131 | ): 132 | super().__init__( 133 | mission_space=mission_space, 134 | width=width, 135 | height=height, 136 | max_steps=max_steps, 137 | **kwargs, 138 | ) 139 | 140 | image_observation_space = spaces.Box( 141 | low=0, 142 | high=255, 143 | shape=(self.width, self.height, 4), 144 | dtype="uint8", 145 | ) 146 | mission_space = self.observation_space["mission"] 147 | self.observation_space = spaces.Dict( 148 | { 149 | "image": image_observation_space, 150 | "mission": mission_space, 151 | } 152 | ) 153 | 154 | self.mask = np.zeros(shape=(self.width, self.height), dtype=bool) 155 | 156 | def reset(self, seed=None): 157 | super().reset(seed=seed) 158 | self.mask = np.zeros(shape=(self.width, self.height), dtype=bool) 159 | topX, topY, botX, botY = self.get_view_exts(agent_view_size=None, clip=True) 160 | vis_grid = self.slice_grid(topX, topY, botX - topX, botY - topY) 161 | if not 
self.see_through_walls: 162 | vis_mask = vis_grid.process_vis((self.agent_pos[0] - topX, self.agent_pos[1] - topY), self.agent_dir) 163 | else: 164 | vis_mask = np.ones(shape=(botX - topX, botY - topY), dtype=bool) 165 | 166 | self.mask[topX:botX, topY:botY] += vis_mask 167 | obs = self.gen_obs() 168 | 169 | return obs, {} 170 | 171 | def slice_grid(self, topX, topY, width, height) -> FlexibleGrid: 172 | """ 173 | Get a subset of the grid 174 | """ 175 | 176 | vis_grid = FlexibleGrid(width, height) 177 | 178 | for j in range(0, height): 179 | for i in range(0, width): 180 | x = topX + i 181 | y = topY + j 182 | 183 | if 0 <= x < self.grid.width and 0 <= y < self.grid.height: 184 | v = self.grid.get(x, y) 185 | else: 186 | v = Wall() 187 | 188 | vis_grid.set(i, j, v) 189 | 190 | return vis_grid 191 | 192 | def get_view_exts(self, agent_view_size=None, clip=False): 193 | """ 194 | Get the extents of the square set of tiles visible to the agent 195 | if agent_view_size is None, use self.agent_view_size 196 | """ 197 | 198 | topX, topY, botX, botY = super().get_view_exts(agent_view_size) 199 | if clip: 200 | topX = max(0, topX) 201 | topY = max(0, topY) 202 | botX = min(botX, self.width) 203 | botY = min(botY, self.height) 204 | return topX, topY, botX, botY 205 | 206 | def gen_hist_obs_grid(self, agent_view_size=None): 207 | topX, topY, botX, botY = self.get_view_exts(agent_view_size, clip=True) 208 | grid = self.grid.copy() 209 | vis_grid = self.slice_grid(topX, topY, botX - topX, botY - topY) 210 | if not self.see_through_walls: 211 | vis_mask = vis_grid.process_vis((self.agent_pos[0] - topX, self.agent_pos[1] - topY), self.agent_dir) 212 | else: 213 | vis_mask = np.ones(shape=(botX - topX, botY - topY), dtype=bool) 214 | 215 | self.mask[topX:botX, topY:botY] += vis_mask 216 | 217 | # Make it so the agent sees what it's carrying 218 | if self.carrying: 219 | grid.set(*self.agent_pos, self.carrying) 220 | else: 221 | grid.set(*self.agent_pos, None) 222 | 223 | return grid 224 | 225 | def gen_obs(self): 226 | grid = self.gen_hist_obs_grid() 227 | 228 | image = grid.encode(self.mask) 229 | 230 | agent_pos_dir = np.zeros((self.width, self.height), dtype="uint8") + 4 231 | agent_pos_dir[self.agent_pos] = self.agent_dir 232 | 233 | obs = {"image": np.concatenate((image, agent_pos_dir[:,:,None]), axis=2), "mission": self.mission} 234 | return obs 235 | 236 | def get_full_render(self, highlight, tile_size): 237 | grid = self.gen_hist_obs_grid() 238 | img = grid.render( 239 | tile_size, 240 | agent_pos=self.agent_pos, 241 | agent_dir=self.agent_dir, 242 | highlight_mask=self.mask, 243 | ) 244 | return img 245 | 246 | def get_mask_render(self, path_mask=None, tile_size=TILE_PIXELS): 247 | grid = self.gen_hist_obs_grid() 248 | unseen_mask = np.ones(shape=(self.width, self.height), dtype=bool) ^ self.mask 249 | 250 | if path_mask is None: 251 | path_mask = np.zeros(shape=(self.width, self.height), dtype=bool) 252 | 253 | # Compute the total grid size 254 | width_px = self.width * tile_size 255 | height_px = self.height * tile_size 256 | 257 | img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8) 258 | 259 | # Render the grid 260 | for j in range(0, self.height): 261 | for i in range(0, self.width): 262 | if unseen_mask[i, j]: 263 | cell = Unseen() 264 | elif path_mask[i, j]: 265 | cell = Footprint() 266 | else: 267 | cell = grid.get(i, j) 268 | 269 | agent_here = np.array_equal(self.agent_pos, (i, j)) 270 | tile_img = Grid.render_tile( 271 | cell, 272 | agent_dir=self.agent_dir if agent_here 
else None, 273 | highlight=False, 274 | tile_size=tile_size, 275 | ) 276 | 277 | ymin = j * tile_size 278 | ymax = (j + 1) * tile_size 279 | xmin = i * tile_size 280 | xmax = (i + 1) * tile_size 281 | img[ymin:ymax, xmin:xmax, :] = tile_img 282 | 283 | return img -------------------------------------------------------------------------------- /planner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : planner.py 5 | @Time : 2023/05/16 09:12:11 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | 12 | import os, requests 13 | from typing import Any 14 | from mediator import * 15 | from utils import global_param 16 | 17 | 18 | from abc import ABC, abstractmethod 19 | 20 | class Base_Planner(ABC): 21 | """The base class for Planner.""" 22 | 23 | def __init__(self, offline=True, soft=False, prefix=''): 24 | super().__init__() 25 | self.offline = offline 26 | self.soft = soft 27 | self.prompt_prefix = prefix 28 | self.plans_dict = {} 29 | self.mediator = None 30 | 31 | self.dialogue_system = '' 32 | self.dialogue_user = '' 33 | self.dialogue_logger = '' 34 | self.show_dialogue = False 35 | 36 | if not offline: 37 | self.llm_model = "vicuna-33b" 38 | self.llm_url = 'http://localhost:3300/v1/chat/completions' 39 | # self.llm_model = "chatglm_Turbo" 40 | # self.llm_url = 'http://10.109.116.3:6000/chat' 41 | self.plans_dict = {} 42 | if self.llm_model == "vicuna-33b": 43 | self.init_llm() 44 | 45 | def reset(self, show=False): 46 | self.dialogue_user = '' 47 | self.dialogue_logger = '' 48 | self.show_dialogue = show 49 | ## reset dialogue 50 | if self.show_dialogue: 51 | print(self.dialogue_system) 52 | self.mediator.reset() 53 | # if not self.offline: 54 | # self.online_planning("reset") 55 | 56 | def init_llm(self): 57 | self.dialogue_system += self.prompt_prefix 58 | 59 | ## set system part 60 | server_error_cnt = 0 61 | while server_error_cnt < 10: 62 | try: 63 | headers = {'Content-Type': 'application/json'} 64 | 65 | data = {'model': self.llm_model, "messages":[{"role": "system", "content": self.prompt_prefix}]} 66 | response = requests.post(self.llm_url, headers=headers, json=data) 67 | 68 | if response.status_code == 200: 69 | result = response.json() 70 | break 71 | else: 72 | assert False, f"fail to initialize: status code {response.status_code}" 73 | 74 | except Exception as e: 75 | server_error_cnt += 1 76 | print(f"fail to initialize: {e}") 77 | 78 | def query_codex(self, prompt_text): 79 | server_error_cnt = 0 80 | while server_error_cnt < 10: 81 | try: 82 | #response = openai.Completion.create(prompt_text) 83 | headers = {'Content-Type': 'application/json'} 84 | 85 | # print(f"user prompt:{prompt_text}") 86 | if self.llm_model == "chatglm_Turbo": 87 | data = {'model': self.llm_model, "prompt":[{"role": "user", "content": self.prompt_prefix + prompt_text}]} 88 | elif self.llm_model == "vicuna-33b": 89 | data = {'model': self.llm_model, "messages":[{"role": "user", "content": prompt_text}]} 90 | response = requests.post(self.llm_url, headers=headers, json=data) 91 | 92 | if response.status_code == 200: 93 | result = response.json() 94 | break 95 | else: 96 | assert False, f"fail to query: status code {response.status_code}" 97 | 98 | except Exception as e: 99 | server_error_cnt += 1 100 | print(f"fail to query: {e}") 101 | 102 | try: 103 | plan = re.search("Action[s]*\:\s*\{([\w\s\<\>\,]*)\}", result, re.I | re.M).group(1) 104 | return plan 105 | except: 
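# Note: `result` holds response.json(), i.e. the parsed JSON body rather than the reply text. For an
# OpenAI-compatible endpoint the generated message normally sits at result["choices"][0]["message"]["content"],
# so the regex above presumably needs to be run on that string; otherwise this except branch fires on every
# call and query_codex() recurses without bound.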
106 | print(f"LLM response invalid format: '{result}'.") 107 | return self.query_codex(prompt_text) 108 | 109 | def plan(self, text, n_ask=10): 110 | if text in self.plans_dict.keys(): 111 | plans, probs = self.plans_dict[text] 112 | else: 113 | print(f"new obs: {text}") 114 | plans = {} 115 | for _ in range(n_ask): 116 | plan = self.query_codex(text) 117 | if plan in plans.keys(): 118 | plans[plan] += 1/n_ask 119 | else: 120 | plans[plan] = 1/n_ask 121 | 122 | plans, probs = list(plans.keys()), list(plans.values()) 123 | self.plans_dict[text] = (plans, probs) 124 | 125 | for k, v in self.plans_dict.items(): 126 | print(f"{k}:{v}") 127 | 128 | return plans, probs 129 | 130 | def __call__(self, obs): 131 | # self.mediator.reset() 132 | text = self.mediator.RL2LLM(obs) 133 | plans, probs = self.plan(text) 134 | self.dialogue_user = text + "\n" + str(plans) + "\n" + str(probs) 135 | if self.show_dialogue: 136 | print(self.dialogue_user) 137 | skill_list, probs = self.mediator.LLM2RL(plans, probs) 138 | 139 | return skill_list, probs 140 | 141 | 142 | 143 | class SimpleDoorKey_Planner(Base_Planner): 144 | def __init__(self, offline, soft, prefix): 145 | super().__init__(offline, soft, prefix) 146 | self.mediator = SimpleDoorKey_Mediator(soft) 147 | if offline: 148 | self.plans_dict = { 149 | "Agent sees , holds ." : [["explore"], [1.0]], 150 | "Agent sees , holds ." : [["explore"], [1.0]], 151 | "Agent sees , holds ." : [["go to , pick up ", "pick up "], [0.98, 0.02]], 152 | "Agent sees , holds ." : [["explore", "go to , open ", "explore, go to , open ", "explore, go to ", "explore, open ", "go to , pick up , use "], [0.68, 0.22, 0.04, 0.02, 0.02, 0.02]], 153 | "Agent sees , holds ." : [["go to , open with ", "go to , open ", "go to , pick up , go to , open ", "explore, go to "], [0.62, 0.3, 0.06, 0.02]], 154 | "Agent sees , , holds ." : [["go to , pick up , go to , open ", "go to , pick up , open ", "pick up , go to , open ", "go to , go to , use ", "go to , pick up , explore"], [0.84, 0.08, 0.04, 0.02, 0.02]] 155 | } 156 | 157 | 158 | class ColoredDoorKey_Planner(Base_Planner): 159 | def __init__(self, offline, soft, prefix): 160 | super().__init__(offline, soft, prefix) 161 | self.mediator = ColoredDoorKey_Mediator(soft) 162 | if offline: 163 | self.plans_dict = { 164 | "Agent sees , holds ." : [["explore"],[1]], 165 | "Agent sees , holds ." : [["explore","go to east"], [0.94,0.06]], 166 | "Agent sees , holds ." : [["go to , pick up ","pick up "],[0.87,0.13]], 167 | "Agent sees , holds ." 
: [["explore"],[1.0]], 168 | "Agent sees , holds .": [["go to , open ","open "],[0.72,0.28]], 169 | "Agent sees , holds .": [["explore", "go to "],[0.98,0.02]], 170 | "Agent sees , holds .": [["drop , go to , pick up ","drop , pick up "],[0.87,0.13]], 171 | "Agent sees , , holds .": [["go to , pick up ","pick up "],[0.81,0.19]], 172 | "Agent sees , , holds .": [["go to , pick up ","pick up "],[0.73,0.27]], 173 | "Agent sees , , holds .": [["go to , pick up ","pick up "],[0.84,0.16]], 174 | "Agent sees , , holds .": [["drop , go to , pick up ","drop , pick up "],[0.79,0.21]], 175 | "Agent sees , , holds .": [["drop , go to , pick up ", "go to , open "],[0.71,0.29]], 176 | "Agent sees , , , holds .": [["go to , pick up ","pick up ","go to , pick up "],[0.72,0.24,0.04]], 177 | "Agent sees , , , holds .": [["go to , pick up "," pick up "],[0.94,0.06]], 178 | } 179 | 180 | def plan(self, text): 181 | pattern= r'\b(blue|green|grey|purple|red|yellow)\b' 182 | color_words = re.findall(pattern, text) 183 | 184 | words = list(set(color_words)) 185 | words.sort(key=color_words.index) 186 | color_words = words 187 | color_index =['color1','color2'] 188 | if color_words != []: 189 | for i in range(len(color_words)): 190 | text = text.replace(color_words[i], color_index[i]) 191 | 192 | plans, probs = super().plan(text) 193 | 194 | plans = str(plans) 195 | for i in range(len(color_words)): 196 | plans = plans.replace(color_index[i], color_words[i]) 197 | plans = eval(plans) 198 | 199 | return plans, probs 200 | 201 | 202 | class TwoDoor_Planner(Base_Planner): 203 | def __init__(self, offline, soft, prefix): 204 | super().__init__(offline, soft, prefix) 205 | self.mediator = TwoDoor_Mediator(soft) 206 | if offline: 207 | self.plans_dict = { 208 | "Agent sees , holds ." : [["explore"], [1.0]], 209 | "Agent sees , holds ." : [["explore"], [1.0]], 210 | "Agent sees , holds ." : [["go to , pick up "], [1.0]], 211 | "Agent sees , holds ." : [["explore"], [1.0]], 212 | "Agent sees , holds ." : [["go to , open "], [1.0]], 213 | "Agent sees , , holds ." : [["go to , pick up "], [1.0]], 214 | "Agent sees , , holds ." 
: [["explore"], [1.0]], 215 | "Agent sees , , , holds .": [["go to , pick up "], [1.0]], 216 | "Agent sees , , holds .": [["go to , open ", "go to , open "], [0.5, 0.5]], 217 | } 218 | 219 | 220 | def Planner(task, offline=True, soft=False, prefix=''): 221 | if task.lower() == "simpledoorkey": 222 | planner = SimpleDoorKey_Planner(offline, soft, prefix) 223 | elif task.lower() == "lavadoorkey": 224 | planner = SimpleDoorKey_Planner(offline, soft, prefix) 225 | elif task.lower() == "coloreddoorkey": 226 | planner = ColoredDoorKey_Planner(offline, soft, prefix) 227 | elif task.lower() == "twodoor": 228 | planner = TwoDoor_Planner(offline, soft, prefix) 229 | return planner 230 | 231 | -------------------------------------------------------------------------------- /mediator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : mediator.py 5 | @Time : 2023/05/16 10:22:36 6 | @Author : Hu Bin 7 | @Version : 1.0 8 | @Desc : None 9 | ''' 10 | 11 | import numpy as np 12 | import re 13 | import copy 14 | from abc import ABC, abstractmethod 15 | @staticmethod 16 | def get_minigrid_words(): 17 | colors = ["red", "green", "blue", "yellow", "purple", "grey"] 18 | objects = [ 19 | "unseen", 20 | "empty", 21 | "wall", 22 | "floor", 23 | "box", 24 | "key", 25 | "ball", 26 | "door", 27 | "goal", 28 | "agent", 29 | "lava", 30 | ] 31 | 32 | verbs = [ 33 | "pick", 34 | "avoid", 35 | "get", 36 | "find", 37 | "put", 38 | "use", 39 | "open", 40 | "go", 41 | "fetch", 42 | "reach", 43 | "unlock", 44 | "traverse", 45 | ] 46 | 47 | extra_words = [ 48 | "up", 49 | "the", 50 | "a", 51 | "at", 52 | ",", 53 | "square", 54 | "and", 55 | "then", 56 | "to", 57 | "of", 58 | "rooms", 59 | "near", 60 | "opening", 61 | "must", 62 | "you", 63 | "matching", 64 | "end", 65 | "hallway", 66 | "object", 67 | "from", 68 | "room", 69 | ] 70 | 71 | all_words = colors + objects + verbs + extra_words 72 | assert len(all_words) == len(set(all_words)) 73 | return {word: i for i, word in enumerate(all_words)} 74 | 75 | # Map of agent direction, 0: East; 1: South; 2: West; 3: North 76 | DIRECTION = { 77 | 0: [1, 0], 78 | 1: [0, 1], 79 | 2: [-1, 0], 80 | 3: [0, -1], 81 | } 82 | 83 | # Map of object type to integers 84 | OBJECT_TO_IDX = { 85 | "unseen": 0, 86 | "empty": 1, 87 | "wall": 2, 88 | "floor": 3, 89 | "door": 4, 90 | "key": 5, 91 | "ball": 6, 92 | "box": 7, 93 | "goal": 8, 94 | "lava": 9, 95 | "agent": 10, 96 | } 97 | IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys())) 98 | 99 | # Used to map colors to integers 100 | COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5} 101 | IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys())) 102 | 103 | # Map of state names to integers 104 | STATE_TO_IDX = { 105 | "open": 0, 106 | "closed": 1, 107 | "locked": 2, 108 | } 109 | IDX_TO_STATE = dict(zip(STATE_TO_IDX.values(), STATE_TO_IDX.keys())) 110 | 111 | # Map of skill names to integers 112 | SKILL_TO_IDX = {"explore": 0, "go to object": 1, "pickup": 2, "drop": 3, "toggle": 4} 113 | IDX_TO_SKILL = dict(zip(SKILL_TO_IDX.values(), SKILL_TO_IDX.keys())) 114 | 115 | 116 | 117 | 118 | class Base_Mediator(ABC): 119 | """The base class for Base_Mediator.""" 120 | 121 | def __init__(self, soft): 122 | super().__init__() 123 | self.soft = soft 124 | self.obj_coordinate = {} 125 | 126 | # obs to natural language 127 | def RL2LLM(self, obs, color_info=True): 128 | context = '' 129 | if 
        if len(obs.shape) == 4:
            obs = obs[0,:,:,-4:]
        obs_object = copy.deepcopy(obs[:,:,0])
        agent_map = obs[:, :, 3]
        agent_pos = np.argwhere(agent_map != 4)[0]
        agent_dir = agent_map[agent_pos[0],agent_pos[1]]

        key_list = np.argwhere(obs_object==5)
        door_list = np.argwhere(obs_object==4)

        carrying = "nothing"
        if len(key_list):
            for key in key_list:
                i, j = key
                if color_info:
                    color = obs[i,j,1]
                    obj = f"{IDX_TO_COLOR[color]} key"
                else:
                    obj = "key"

                if (agent_pos == key).all():
                    carrying = obj
                else:
                    context += f"<{obj}>, "
                    self.obj_coordinate[obj] = (i,j)

        if len(door_list):
            for door in door_list:
                i, j = door
                if color_info:
                    color = obs[i,j,1]
                    obj = f"{IDX_TO_COLOR[color]} door"
                else:
                    obj = "door"

                context += f"<{obj}>, "
                self.obj_coordinate[obj] = (i,j)

        if context == '':
            context += ", "
        context += f"holds <{carrying}>."

        context = f"Agent sees {context}"
        return context

    def LLM2RL(self, plans, probs):
        if self.soft:
            skill_list = [self.parser(plan) for plan in plans]
        else:
            plan = np.random.choice(plans, p=probs)
            skill_list = [self.parser(plan)]
            probs = [1.]

        return skill_list, probs

    def reset(self):
        self.obj_coordinate = {}


class SimpleDoorKey_Mediator(Base_Mediator):
    def __init__(self, soft):
        super().__init__(soft)

    def RL2LLM(self, obs):
        return super().RL2LLM(obs, color_info=False)

    # translate one natural-language plan into a list of skill dicts
    def parser(self, plan):
        skill_list = []
        skills = plan.split(',')
        for text in skills:
            # action:
            if "explore" in text:
                act = SKILL_TO_IDX["explore"]
            elif "go to" in text:
                act = SKILL_TO_IDX["go to object"]
            elif "pick up" in text:
                act = SKILL_TO_IDX["pickup"]
            elif "drop" in text:
                act = SKILL_TO_IDX["drop"]
            elif "open" in text:
                act = SKILL_TO_IDX["toggle"]
            else:
                # print("Unknown Planning :", text)
                act = 6  # do nothing
            # object:
            try:
                if "door" in text:
                    obj = OBJECT_TO_IDX["door"]
                    coordinate = self.obj_coordinate["door"]
                elif "key" in text:
                    obj = OBJECT_TO_IDX["key"]
                    coordinate = self.obj_coordinate["key"]
                elif "explore" in text:
                    obj = OBJECT_TO_IDX["empty"]
                    coordinate = None
                else:
                    assert False
            except (KeyError, AssertionError):
                # print("Unknown Planning :", text)
                act = 6  # do nothing
                obj = OBJECT_TO_IDX["empty"]
                coordinate = None

            skill = {"action": act,
                     "object": obj,
                     "coordinate": coordinate,}
            skill_list.append(skill)

        return skill_list


class ColoredDoorKey_Mediator(Base_Mediator):
    def __init__(self, soft):
        super().__init__(soft)

    def RL2LLM(self, obs):
        return super().RL2LLM(obs)

    def parser(self, plan):
        skill_list = []
        skills = plan.split(',')
        for text in skills:
            # action:
            if "explore" in text:
                act = SKILL_TO_IDX["explore"]
            elif "go to" in text:
                act = SKILL_TO_IDX["go to object"]
            elif "pick up" in text:
                act = SKILL_TO_IDX["pickup"]
            elif "drop" in text:
                act = SKILL_TO_IDX["drop"]
            elif "open" in text:
                act = SKILL_TO_IDX["toggle"]
            else:
                print("Unknown Planning :", text)
                act = 6  # do nothing
            # object:
            try:
                if "door" in text:
                    obj = OBJECT_TO_IDX["door"]
                    words = text.split(' ')
                    filter_words = []
                    for w in words:
                        w1 = "".join(c for c in w if c.isalpha())
                        filter_words.append(w1)
                    object_word = filter_words[-2] + " " + filter_words[-1]
                    coordinate = self.obj_coordinate[object_word]
                elif "key" in text:
                    obj = OBJECT_TO_IDX["key"]
                    words = text.split(' ')
                    filter_words = []
                    for w in words:
                        w1 = "".join(c for c in w if c.isalpha())
                        filter_words.append(w1)
                    object_word = filter_words[-2] + " " + filter_words[-1]
                    coordinate = self.obj_coordinate[object_word]
                elif "explore" in text:
                    obj = OBJECT_TO_IDX["empty"]
                    coordinate = None
                else:
                    assert False
            except (KeyError, IndexError, AssertionError):
                print("Unknown Planning :", text)
                act = 6  # do nothing
                obj = OBJECT_TO_IDX["empty"]
                coordinate = None

            skill = {"action": act,
                     "object": obj,
                     "coordinate": coordinate,}
            skill_list.append(skill)

        return skill_list


class TwoDoor_Mediator(Base_Mediator):
    def __init__(self, soft):
        super().__init__(soft)

    def RL2LLM(self, obs):
        context = ''
        if len(obs.shape) == 4:
            obs = obs[0,:,:,-4:]
        obs_object = copy.deepcopy(obs[:,:,0])
        agent_map = obs[:, :, 3]
        agent_pos = np.argwhere(agent_map != 4)[0]
        agent_dir = agent_map[agent_pos[0],agent_pos[1]]

        key_list = np.argwhere(obs_object==5)
        door_list = np.argwhere(obs_object==4)

        carrying = "nothing"
        if len(key_list):
            for key in key_list:
                i, j = key
                obj = "key"

                if (agent_pos == key).all():
                    carrying = obj
                else:
                    context += f"<{obj}>, "
                    self.obj_coordinate[obj] = (i,j)

        if len(door_list):
            n = 1
            for door in door_list:
                i, j = door
                obj = f"door{n}"
                n += 1

                context += f"<{obj}>, "
                self.obj_coordinate[obj] = (i,j)

        if context == '':
            context += ", "
        context += f"holds <{carrying}>."
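
        # Example of the final prompt for a hypothetical scene with one visible
        # key and two visible doors: "Agent sees <key>, <door1>, <door2>, holds <nothing>."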
        context = f"Agent sees {context}"
        return context

    def parser(self, plan):
        skill_list = []
        skills = plan.split(',')
        for text in skills:
            # action:
            if "explore" in text:
                act = SKILL_TO_IDX["explore"]
            elif "go to" in text:
                act = SKILL_TO_IDX["go to object"]
            elif "pick up" in text:
                act = SKILL_TO_IDX["pickup"]
            elif "drop" in text:
                act = SKILL_TO_IDX["drop"]
            elif "open" in text:
                act = SKILL_TO_IDX["toggle"]
            else:
                # print("Unknown Planning :", text)
                act = 6  # do nothing
            # object:
            try:
                if "door1" in text:
                    obj = OBJECT_TO_IDX["door"]
                    coordinate = self.obj_coordinate["door1"]
                elif "door2" in text:
                    obj = OBJECT_TO_IDX["door"]
                    coordinate = self.obj_coordinate["door2"]
                elif "key" in text:
                    obj = OBJECT_TO_IDX["key"]
                    coordinate = self.obj_coordinate["key"]
                elif "explore" in text:
                    obj = OBJECT_TO_IDX["empty"]
                    coordinate = None
                else:
                    assert False
            except (KeyError, AssertionError):
                # print("Unknown Planning :", text)
                act = 6  # do nothing
                obj = OBJECT_TO_IDX["empty"]
                coordinate = None

            skill = {"action": act,
                     "object": obj,
                     "coordinate": coordinate,}
            skill_list.append(skill)

        return skill_list


if __name__ == "__main__":
    word = get_minigrid_words()

--------------------------------------------------------------------------------
/Game.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : Game.py
@Time : 2023/07/14 11:06:59
@Author : Zhou Zihao
@Version : 1.0
@Desc : None
'''

import os, json, sys
import gymnasium as gym
import numpy as np
import torch
import cv2
import time

import env
import algos
import skill
import utils
from teacher_policy import TeacherPolicy

prefix = os.getcwd()
task_info_json = os.path.join(prefix, "prompt/task_info.json")


class Game:
    def __init__(self, args, training=True):
        # init seed
        self.seed = args.seed
        self.setup_seed(args.seed)

        # init env
        self.load_task_info(args.task, args.frame_stack, args.offline_planner, args.soft_planner)

        # init logger
        self.logger = utils.create_logger(args, training)

        # init policy
        if args.loaddir:
            model_dir = os.path.join(args.logdir, args.policy, args.task, args.loaddir, args.loadmodel)
            policy = torch.load(model_dir)
        else:
            policy = None
        self.device = args.device
        self.batch_size = args.batch_size
        self.recurrent = args.recurrent
        # self.student_policy = policy
        self.student_policy = algos.PPO(policy,
                                        self.obs_space,
                                        self.action_space,
                                        self.device,
                                        self.logger.dir,
                                        batch_size=self.batch_size,
                                        recurrent=self.recurrent)

        # init buffer
        self.gamma = args.gamma
        self.lam = args.lam
        self.buffer = algos.Buffer(self.gamma, self.lam, self.device)

        # other settings
        self.n_itr = args.n_itr
        self.traj_per_itr = args.traj_per_itr
        self.num_eval = args.num_eval
        self.eval_interval = args.eval_interval
        self.save_interval = args.save_interval
        self.total_steps = 0


    def setup_seed(self, seed):
        # setup seed for Numpy, Torch and LLM, not for env
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        torch.backends.cudnn.deterministic = True


    def load_task_info(self, task, frame_stack, offline, soft):
        print(f"[INFO]: resetting the task: {task}")
        with open(task_info_json, 'r') as f:
            task_info = json.load(f)
        task = task.lower()

        env_fn = utils.make_env_fn(task_info[task]['configurations'],
                                   render_mode="rgb_array",
                                   frame_stack=frame_stack)
        self.env = utils.WrapEnv(env_fn)
        self.obs_space = utils.get_obss_preprocessor(self.env.observation_space)[0]
        self.action_space = self.env.action_space.n
        self.max_ep_len = self.env.max_steps

        prefix = task_info[task]['description'] + task_info[task]['example']
        self.teacher_policy = TeacherPolicy(task, offline, soft, prefix, self.action_space, self.env.agent_view_size)


    def train(self):
        start_time = time.time()
        for itr in range(self.n_itr):
            print("********** Iteration {} ************".format(itr))
            print("time elapsed: {:.2f} s".format(time.time() - start_time))

            ## collecting ##
            sample_start = time.time()
            self.buffer.clear()
            n_traj = self.traj_per_itr
            for _ in range(n_traj):
                self.collect()
            while len(self.buffer) < self.batch_size * 2:
                self.collect()
                n_traj += 1
            total_steps = len(self.buffer)
            samp_time = time.time() - sample_start
            print("{:.2f} s to collect {:6n} timesteps | {:3.2f} samples/s.".format(samp_time, total_steps, (total_steps) / samp_time))
            self.total_steps += total_steps

            ## training ##
            optimizer_start = time.time()
            mean_losses = self.student_policy.update_policy(self.buffer)
            opt_time = time.time() - optimizer_start
            try:
                print("{:.2f} s to optimize | loss {:6.3f}, entropy {:6.3f}, kickstarting {:6.3f}.".format(opt_time, mean_losses[0], mean_losses[1], mean_losses[2]))
            except (TypeError, IndexError):
                print(mean_losses)

            ## evaluate ##
            if itr % self.eval_interval == 0 and itr > 0:
                evaluate_start = time.time()
                eval_returns = []
                eval_lens = []
                eval_success = []
                for i in range(self.num_eval):
                    eval_outputs = self.evaluate(itr, record_frames=False)
                    eval_returns.append(eval_outputs[0])
                    eval_lens.append(eval_outputs[1])
                    eval_success.append(eval_outputs[2])
                eval_time = time.time() - evaluate_start
                print("{:.2f} s to evaluate.".format(eval_time))

            if itr % self.save_interval == 0 and itr > 0:
                self.student_policy.save(str(itr))

            ## log ##
            if self.logger is not None:
                avg_len = np.mean(self.buffer.ep_lens)
                avg_reward = np.mean(self.buffer.ep_returns)
                std_reward = np.std(self.buffer.ep_returns)
                success_rate = sum(i > 0 for i in self.buffer.ep_returns) / n_traj
                sys.stdout.write("-" * 49 + "\n")
                sys.stdout.write("| %25s | %15s |" % ('Timesteps', self.total_steps) + "\n")
                sys.stdout.write("| %25s | %15s |" % ('Return (train)', round(avg_reward, 2)) + "\n")
                sys.stdout.write("| %25s | %15s |" % ('Episode Length (train)', round(avg_len, 2)) + "\n")
                sys.stdout.write("| %25s | %15s |" % ('Success Rate (train)', round(success_rate, 2)) + "\n")
                if itr % self.eval_interval == 0 and itr > 0:
                    avg_eval_reward = np.mean(eval_returns)
                    avg_eval_len = np.mean(eval_lens)
                    eval_success_rate = np.sum(eval_success) / self.num_eval
                    sys.stdout.write("| %25s | %15s |" % ('Return (eval)', round(avg_eval_reward, 2)) + "\n")
                    sys.stdout.write("| %25s | %15s |" % ('Episode Length (eval)', round(avg_eval_len, 2)) + "\n")
                    sys.stdout.write("| %25s | %15s |" % ('Success Rate (eval)', round(eval_success_rate, 2)) + "\n")
                    self.logger.add_scalar("Test/Return", avg_eval_reward, itr)
                    self.logger.add_scalar("Test/Eplen", avg_eval_len, itr)
                    self.logger.add_scalar("Test/Success Rate", eval_success_rate, itr)
                sys.stdout.write("-" * 49 + "\n")
                sys.stdout.flush()

                self.logger.add_scalar("Train/Return Mean", avg_reward, itr)
                self.logger.add_scalar("Train/Return Std", std_reward, itr)
                self.logger.add_scalar("Train/Eplen", avg_len, itr)
                self.logger.add_scalar("Train/Success Rate", success_rate, itr)
                self.logger.add_scalar("Train/Loss", mean_losses[0], itr)
                self.logger.add_scalar("Train/Mean Entropy", mean_losses[1], itr)
                self.logger.add_scalar("Train/Kickstarting Loss", mean_losses[2], itr)
                self.logger.add_scalar("Train/Policy Loss", mean_losses[3], itr)
                self.logger.add_scalar("Train/Value Loss", mean_losses[4], itr)
                self.logger.add_scalar("Train/Kickstarting Coef", self.student_policy.ks_coef, itr)

        self.student_policy.save()


    def collect(self):
        '''
        collect episodic data.
        '''
        with torch.no_grad():
            obs = self.env.reset()
            done = False
            ep_len = 0

            # reset student policy
            mask = torch.FloatTensor([1]).to(self.device)  # not done until episode ends
            states = self.student_policy.model.init_states(self.device) if self.recurrent else None

            # reset teacher policy
            self.teacher_policy.reset()

            while not done and ep_len < self.max_ep_len:
                # get action from student policy
                dist, value, states = self.student_policy(torch.Tensor(obs).to(self.device),
                                                          mask, states)
                action = dist.sample()
                log_probs = dist.log_prob(action)
                action = action.to("cpu").numpy()

                # get action from teacher policy
                teacher_probs = self.teacher_policy(obs[0])

                # interact with env
                next_obs, reward, done, info = self.env.step(action)

                # store in buffer
                self.buffer.store(obs,
                                  action,
                                  reward,
                                  value.to("cpu").numpy(),
                                  log_probs.to("cpu").numpy(),
                                  teacher_probs)
                obs = next_obs
                ep_len += 1
            if done:
                value = 0.
            else:
                value = self.student_policy(torch.Tensor(obs).to(self.device),
                                            mask, states)[1].to("cpu").item()
            self.buffer.finish_path(last_val=value)


    def evaluate(self, itr=None, seed=None, record_frames=True, deterministic=False, teacher_policy=False):
        with torch.no_grad():
            # init env
            seed = seed if seed else np.random.randint(1000000)
            obs = self.env.reset(seed)
            done = False
            ep_len = 0
            ep_return = 0.
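
            # Two evaluation modes: roll out the LLM-guided teacher policy directly,
            # or the distilled student (PPO) policy trained above.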
            if teacher_policy:
                # init teacher policy
                self.teacher_policy.reset()
            else:
                # init student policy
                mask = torch.Tensor([1.]).to(self.device)  # not done until episode ends
                states = self.student_policy.model.init_states(self.device) if self.recurrent else None

            # init video directory
            if record_frames:
                img_array = []
                img = self.env.get_mask_render()
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                img_array.append(img)

                dir_name = 'teacher video' if teacher_policy else 'video'
                dir_path = os.path.join(self.logger.dir, dir_name)
                try:
                    os.makedirs(dir_path)
                except OSError:
                    pass

            while not done and ep_len < self.max_ep_len:
                if teacher_policy:
                    # get action from teacher policy
                    probs = self.teacher_policy(obs[0])
                    if deterministic:
                        action = np.argmax(probs)
                    else:
                        action = np.random.choice(self.action_space, p=probs)
                else:
                    # get action from student policy
                    dist, value, states = self.student_policy(torch.Tensor(obs).to(self.device), mask, states)
                    if deterministic:
                        action = torch.argmax(dist.probs).unsqueeze(0).to("cpu").numpy()
                    else:
                        action = dist.sample().to("cpu").numpy()

                # interact with env
                obs, reward, done, info = self.env.step(action)
                ep_return += float(reward)
                ep_len += 1

                if record_frames:
                    img = self.env.get_mask_render()
                    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                    img_array.append(img)

            ep_success = 1 if ep_return > 0 else 0

            # save video
            if record_frames:
                height, width, layers = img.shape
                size = (width, height)
                video_name = "%s-%s.avi" % (itr, seed) if itr else "%s.avi" % seed
                video_path = os.path.join(dir_path, video_name)
                out = cv2.VideoWriter(video_path,
                                      fourcc=cv2.VideoWriter_fourcc(*'DIVX'),
                                      fps=3,
                                      frameSize=size)

                for img in img_array:
                    out.write(img)
                out.release()

            return ep_return, ep_len, ep_success


if __name__ == '__main__':
    pass


--------------------------------------------------------------------------------
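Usage sketch (illustrative only, not part of the repository; the actual entry point is main.py, whose argument parser is not shown here). It assumes an argparse-style namespace exposing the attributes Game reads above; every value below is a placeholder, not the project's default.

from types import SimpleNamespace
import torch
from Game import Game

args = SimpleNamespace(
    seed=0, task="SimpleDoorKey", frame_stack=1,           # task and observation stacking
    offline_planner=True, soft_planner=False,              # LLM planner options
    logdir="log", policy="ppo", loaddir=None, loadmodel="acmodel.pt",
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=256, recurrent=False,                        # student (PPO) settings
    gamma=0.99, lam=0.95,                                   # GAE buffer settings
    n_itr=1000, traj_per_itr=10,                            # training loop
    num_eval=10, eval_interval=20, save_interval=100,
)

game = Game(args, training=True)
game.train()                                                # distill teacher guidance into the student policy
ret, ep_len, success = game.evaluate(record_frames=True)    # roll out one episode and save a video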